From 8988c895f06cf9a3be42030e5fdd7c4e0f5a079c Mon Sep 17 00:00:00 2001 From: Akhil Goel Date: Wed, 10 Jul 2024 09:31:32 -0700 Subject: [PATCH 001/483] Prevent AMP exit --- .../core/grappler/optimizers/auto_mixed_precision.cc | 6 ++++-- .../grappler/optimizers/auto_mixed_precision_test.cc | 5 ++++- tensorflow/core/util/util.cc | 10 ++++++++++ tensorflow/core/util/util.h | 2 ++ 4 files changed, 20 insertions(+), 3 deletions(-) diff --git a/tensorflow/core/grappler/optimizers/auto_mixed_precision.cc b/tensorflow/core/grappler/optimizers/auto_mixed_precision.cc index 1c7a7712efefed..f8c3785d0f8946 100644 --- a/tensorflow/core/grappler/optimizers/auto_mixed_precision.cc +++ b/tensorflow/core/grappler/optimizers/auto_mixed_precision.cc @@ -2310,9 +2310,11 @@ Status AutoMixedPrecision::Optimize(Cluster* cluster, const GrapplerItem& item, << " graph optimizer"; return absl::OkStatus(); } - // Check if CPU supports FP16 + // Check if CPU supports FP16, oneDNN supports FP16 on + // some platforms by converting to and from FP32 if (mode_ == AutoMixedPrecisionMode::FP16_CPU && - !IsAMXDataTypeSupportedByOneDNNOnThisCPU(DT_HALF)) { + !IsAMXDataTypeSupportedByOneDNNOnThisCPU(DT_HALF) && + !IsAVXConvertSupportedByOneDNNOnThisCPU()) { VLOG(1) << "No support for " << name() << " graph optimizer on CPU"; return absl::OkStatus(); } diff --git a/tensorflow/core/grappler/optimizers/auto_mixed_precision_test.cc b/tensorflow/core/grappler/optimizers/auto_mixed_precision_test.cc index 0b855be91f8099..f3def370cee3a6 100644 --- a/tensorflow/core/grappler/optimizers/auto_mixed_precision_test.cc +++ b/tensorflow/core/grappler/optimizers/auto_mixed_precision_test.cc @@ -132,7 +132,10 @@ class AutoMixedPrecisionTest : public GrapplerTest { bool is_fp16_enabled_on_cpu = false; #if defined(INTEL_MKL) && defined(ENABLE_ONEDNN_V3) - is_fp16_enabled_on_cpu = IsAMXDataTypeSupportedByOneDNNOnThisCPU(DT_HALF); + // oneDNN supports FP16 on some platforms by converting to and from FP32 + is_fp16_enabled_on_cpu = + IsAMXDataTypeSupportedByOneDNNOnThisCPU(DT_HALF) || + IsAVXConvertSupportedByOneDNNOnThisCPU(); #endif // INTEL_MKL && ENABLE_ONEDNN_V3 if (!IsMKLEnabled() || !is_fp16_enabled_on_cpu) { GTEST_SKIP() << "This device doesn't support FP16"; diff --git a/tensorflow/core/util/util.cc b/tensorflow/core/util/util.cc index 6e9c20d0a39671..0c2cec78500d97 100644 --- a/tensorflow/core/util/util.cc +++ b/tensorflow/core/util/util.cc @@ -187,4 +187,14 @@ bool IsAMXDataTypeSupportedByOneDNNOnThisCPU(const DataType& dt) { return result; } +// Check if oneDNN supports AVX-NE-CONVERT on CPU +bool IsAVXConvertSupportedByOneDNNOnThisCPU() { + bool result = false; +#if defined(INTEL_MKL) && defined(ENABLE_ONEDNN_V3) + using port::TestCPUFeature; + result = TestCPUFeature(port::CPUFeature::AVX_NE_CONVERT); +#endif // INTEL_MKL && ENABLE_ONEDNN_V3 + return result; +} + } // namespace tensorflow diff --git a/tensorflow/core/util/util.h b/tensorflow/core/util/util.h index 8bcbba6c6cf52b..701c423045da8f 100644 --- a/tensorflow/core/util/util.h +++ b/tensorflow/core/util/util.h @@ -71,6 +71,8 @@ bool IsDataTypeSupportedByOneDNNOnThisCPU(const DataType& dt); // Check if input type supports AMX on CPU when oneDNN is enabled bool IsAMXDataTypeSupportedByOneDNNOnThisCPU(const DataType& dt); +bool IsAVXConvertSupportedByOneDNNOnThisCPU(); + } // namespace tensorflow #endif // TENSORFLOW_CORE_UTIL_UTIL_H_ From 174f9e1d0757b64a040768db44f411a13e6934e5 Mon Sep 17 00:00:00 2001 From: Ye Huang Date: Mon, 19 Aug 2024 22:41:02 +0800 Subject: [PATCH 002/483] Update tpu_cluster_resolver.py FIX: TPUClusterResolver.connect missing arguments --- .../distribute/cluster_resolver/tpu/tpu_cluster_resolver.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/distribute/cluster_resolver/tpu/tpu_cluster_resolver.py b/tensorflow/python/distribute/cluster_resolver/tpu/tpu_cluster_resolver.py index 4af2925ebd8650..6e69ba14152cf2 100644 --- a/tensorflow/python/distribute/cluster_resolver/tpu/tpu_cluster_resolver.py +++ b/tensorflow/python/distribute/cluster_resolver/tpu/tpu_cluster_resolver.py @@ -144,7 +144,7 @@ def connect(tpu=None, """ resolver = TPUClusterResolver(tpu, zone, project) remote.connect_to_cluster(resolver) - tpu_strategy_util.initialize_tpu_system_impl(resolver) + tpu_strategy_util.initialize_tpu_system_impl(resolver, TPUClusterResolver) return resolver @staticmethod From 1f15bdb721aab6b6c9dc076de34eee6fdc0b286f Mon Sep 17 00:00:00 2001 From: Venkat6871 Date: Fri, 13 Sep 2024 13:08:06 +0530 Subject: [PATCH 003/483] Fix typos in documentation strings --- tensorflow/python/autograph/operators/control_flow.py | 2 +- tensorflow/python/autograph/tests/loop_basic_test.py | 4 ++-- tensorflow/python/client/tf_session_helper.cc | 2 +- .../python/compiler/tensorrt/model_tests/model_handler.py | 6 +++--- tensorflow/python/compiler/tensorrt/test/base_test.py | 2 +- .../python/compiler/tensorrt/test/biasadd_matmul_test.py | 2 +- tensorflow/python/compiler/tensorrt/test/cast_test.py | 2 +- .../compiler/tensorrt/test/dynamic_input_shapes_test.py | 2 +- tensorflow/python/compiler/tensorrt/test/int32_test.py | 2 +- .../python/compiler/tensorrt/test/shape_output_test.py | 2 +- .../compiler/tensorrt/test/tf_trt_integration_test_base.py | 4 ++-- tensorflow/python/compiler/tensorrt/utils.py | 2 +- 12 files changed, 16 insertions(+), 16 deletions(-) diff --git a/tensorflow/python/autograph/operators/control_flow.py b/tensorflow/python/autograph/operators/control_flow.py index 8386c0c9edac54..88ab9a8ed3d63b 100644 --- a/tensorflow/python/autograph/operators/control_flow.py +++ b/tensorflow/python/autograph/operators/control_flow.py @@ -856,7 +856,7 @@ def guarded_test(): 1, 'Caught error while evaluating while loop condition', exc_info=True) - # TODO(mdan): distinguish beteen these two cases. + # TODO(mdan): distinguish between these two cases. raise NotImplementedError( 'The condition of while loop started as non-Tensor, then changed to' ' Tensor. This may happen either because variables changed type, or' diff --git a/tensorflow/python/autograph/tests/loop_basic_test.py b/tensorflow/python/autograph/tests/loop_basic_test.py index 45848329dfe591..29804a51c1c1f1 100644 --- a/tensorflow/python/autograph/tests/loop_basic_test.py +++ b/tensorflow/python/autograph/tests/loop_basic_test.py @@ -249,7 +249,7 @@ def test_while_one_var(self, n, type_): )) def test_for_one_var(self, l, type_, xla): if type_ is _int_dataset and xla: - self.skipTest('Datsets not supported in XLA') + self.skipTest('Datasets not supported in XLA') if type_ is _int_tensor and xla and not l: self.skipTest('Empty loops not supported in XLA') @@ -298,7 +298,7 @@ def test_while_two_vars(self, n, type_): )) def test_for_two_vars(self, l, type_, xla): if type_ is _int_dataset and xla: - self.skipTest('Datsets not supported in XLA') + self.skipTest('Datasets not supported in XLA') if type_ is _int_tensor and xla and not l: self.skipTest('Empty loops not supported in XLA') diff --git a/tensorflow/python/client/tf_session_helper.cc b/tensorflow/python/client/tf_session_helper.cc index f06c65e7837d2e..585d770d7c735e 100644 --- a/tensorflow/python/client/tf_session_helper.cc +++ b/tensorflow/python/client/tf_session_helper.cc @@ -191,7 +191,7 @@ void MakeCallableHelper(tensorflow::Session* session, callable_options->length)) { tsl::Set_TF_Status_from_Status( out_status, - absl::InvalidArgumentError("Unparseable CallableOptions proto")); + absl::InvalidArgumentError("Unparsable CallableOptions proto")); return; } tensorflow::Session::CallableHandle handle; diff --git a/tensorflow/python/compiler/tensorrt/model_tests/model_handler.py b/tensorflow/python/compiler/tensorrt/model_tests/model_handler.py index 3bb52085084a41..c12bc56f44abd4 100644 --- a/tensorflow/python/compiler/tensorrt/model_tests/model_handler.py +++ b/tensorflow/python/compiler/tensorrt/model_tests/model_handler.py @@ -557,7 +557,7 @@ def run(self, class _ModelHandlerManagerBase(metaclass=abc.ABCMeta): - """Manages a series of ModelHandlers for aggregrated testing/benchmarking.""" + """Manages a series of ModelHandlers for aggregated testing/benchmarking.""" def __init__( self, name: str, model_config: ModelConfig, @@ -657,14 +657,14 @@ def run_model(model, **kwargs): class ModelHandlerManagerV1(_ModelHandlerManagerBase): - """Manages a series of ModelHandlers for aggregrated testing/benchmarking in TF1.""" + """Manages a series of ModelHandlers for aggregated testing/benchmarking in TF1.""" model_handler_cls = ModelHandlerV1 trt_model_handler_cls = TrtModelHandlerV1 class ModelHandlerManagerV2(_ModelHandlerManagerBase): - """Manages a series of ModelHandlers for aggregrated testing/benchmarking in TF2.""" + """Manages a series of ModelHandlers for aggregated testing/benchmarking in TF2.""" model_handler_cls = ModelHandlerV2 trt_model_handler_cls = TrtModelHandlerV2 diff --git a/tensorflow/python/compiler/tensorrt/test/base_test.py b/tensorflow/python/compiler/tensorrt/test/base_test.py index 203ea5cb8d38ce..2cd8ad5990a64d 100644 --- a/tensorflow/python/compiler/tensorrt/test/base_test.py +++ b/tensorflow/python/compiler/tensorrt/test/base_test.py @@ -120,7 +120,7 @@ def ExpectedEnginesToBuild(self, run_params): def setUp(self): super().setUp() # Disable layout optimizer, since it will convert BiasAdd with NHWC - # format to NCHW format under four dimentional input. + # format to NCHW format under four dimensional input. self.DisableNonTrtOptimizers() def ShouldRunTest(self, run_params): diff --git a/tensorflow/python/compiler/tensorrt/test/biasadd_matmul_test.py b/tensorflow/python/compiler/tensorrt/test/biasadd_matmul_test.py index 25576870c54a31..b37760eefcc836 100644 --- a/tensorflow/python/compiler/tensorrt/test/biasadd_matmul_test.py +++ b/tensorflow/python/compiler/tensorrt/test/biasadd_matmul_test.py @@ -105,7 +105,7 @@ def GetParams(self): def setUp(self): super().setUp() # Disable layout optimizer, since it will convert BiasAdd with NHWC - # format to NCHW format under four dimentional input. + # format to NCHW format under four dimensional input. self.DisableNonTrtOptimizers() def GetMaxBatchSize(self, run_params): diff --git a/tensorflow/python/compiler/tensorrt/test/cast_test.py b/tensorflow/python/compiler/tensorrt/test/cast_test.py index c5d95ade05ca7c..bb33746f04ca31 100644 --- a/tensorflow/python/compiler/tensorrt/test/cast_test.py +++ b/tensorflow/python/compiler/tensorrt/test/cast_test.py @@ -25,7 +25,7 @@ class CastInt32ToFp32Test(trt_test.TfTrtIntegrationTestBase): - """Tests cast to FP32 are splitted in FP16 mode.""" + """Tests cast to FP32 are split in FP16 mode.""" def _ConstOp(self, shape, dtype): return constant_op.constant(np.random.randn(*shape), dtype=dtype) diff --git a/tensorflow/python/compiler/tensorrt/test/dynamic_input_shapes_test.py b/tensorflow/python/compiler/tensorrt/test/dynamic_input_shapes_test.py index f6e26ffac02f30..f9aaef20625f56 100644 --- a/tensorflow/python/compiler/tensorrt/test/dynamic_input_shapes_test.py +++ b/tensorflow/python/compiler/tensorrt/test/dynamic_input_shapes_test.py @@ -79,7 +79,7 @@ def GetParams(self): def setUp(self): super().setUp() # Disable layout optimizer, since it will convert BiasAdd with NHWC - # format to NCHW format under four dimentional input. + # format to NCHW format under four dimensional input. self.DisableNonTrtOptimizers() def ExpectedEnginesToBuild(self, run_params): diff --git a/tensorflow/python/compiler/tensorrt/test/int32_test.py b/tensorflow/python/compiler/tensorrt/test/int32_test.py index 21517e884f08d9..88100cc00cb0fc 100644 --- a/tensorflow/python/compiler/tensorrt/test/int32_test.py +++ b/tensorflow/python/compiler/tensorrt/test/int32_test.py @@ -45,7 +45,7 @@ def GetParams(self): def setUp(self): super().setUp() # Disable layout optimizer, since it will convert BiasAdd with NHWC - # format to NCHW format under four dimentional input. + # format to NCHW format under four dimensional input. self.DisableNonTrtOptimizers() def GetMaxBatchSize(self, run_params): diff --git a/tensorflow/python/compiler/tensorrt/test/shape_output_test.py b/tensorflow/python/compiler/tensorrt/test/shape_output_test.py index 1485f5918b0d61..eab70a4d74dd47 100644 --- a/tensorflow/python/compiler/tensorrt/test/shape_output_test.py +++ b/tensorflow/python/compiler/tensorrt/test/shape_output_test.py @@ -229,7 +229,7 @@ def ExpectedEnginesToBuild(self, run_params): return [] def ShouldRunTest(self, run_params): - # We cannot calibrate without bulding the engine, we turn of INT8 test. + # We cannot calibrate without building the engine, we turn of INT8 test. return (run_params.dynamic_shape and run_params.precision_mode != "INT8", "no calibration dynamic shape") diff --git a/tensorflow/python/compiler/tensorrt/test/tf_trt_integration_test_base.py b/tensorflow/python/compiler/tensorrt/test/tf_trt_integration_test_base.py index 06784c0910641c..c3b1f97884d836 100644 --- a/tensorflow/python/compiler/tensorrt/test/tf_trt_integration_test_base.py +++ b/tensorflow/python/compiler/tensorrt/test/tf_trt_integration_test_base.py @@ -736,7 +736,7 @@ def _VerifyMaxBatchSizeAnnotations( original_gdef: GraphDef. The graph def before TensorRT conversion. converted_gdef: GraphDef. The graph def after TensorRT conversion. default_max_batch_size: The default maximum batch size to use if no node - inside a segment is annoted with a customized max batch size. This value + inside a segment is annotated with a customized max batch size. This value is None when the graph is converted to TF-TRT with dynamic engines. expected_max_batch_sizes: Optional. A sequence of max batch sizes for all the engines. `None` if does not check enforce max batch sizes. @@ -769,7 +769,7 @@ def _DetectStaticBatchSize(node_def): It is incorrect to use the output shapes to find the batch size of an operation, as the segmenter actually uses the input shapes. However, it is - a simplication and works for most of the cases for the test purposes. + a simplification and works for most of the cases for the test purposes. Args: node_def: `tf.NodeDef`. The target node for analysis. diff --git a/tensorflow/python/compiler/tensorrt/utils.py b/tensorflow/python/compiler/tensorrt/utils.py index a908f920b14996..7eadddae2f1517 100644 --- a/tensorflow/python/compiler/tensorrt/utils.py +++ b/tensorflow/python/compiler/tensorrt/utils.py @@ -242,7 +242,7 @@ def draw_graphdef_as_graphviz(graphdef, dot_output_filename): print(" }", file=f) - # Step 3: Alignement of the legend with the graph. + # Step 3: Alignment of the legend with the graph. print("\n edge[style=\"invisible\", dir=\"none\"];", file=f) for dtype in dtype_index.keys(): for node_name in nodes_with_no_inputs: From 5dac5825a7b084a1f37513819945aa10c3008339 Mon Sep 17 00:00:00 2001 From: Jerry Ge Date: Tue, 3 Sep 2024 17:39:46 -0700 Subject: [PATCH 004/483] [Tosa] Update Sin/Cos operators legalization - with the introduction of tosa.sin and tosa.cos ops - update the legalization to do direct mapping Signed-off-by: Jerry Ge --- .../mlir/tosa/tests/tf-to-tosa-pipeline.mlir | 42 +------ .../mlir/tosa/tests/tfl-to-tosa-pipeline.mlir | 44 +------- .../mlir/tosa/transforms/legalize_common.cc | 103 ------------------ .../mlir/tosa/transforms/legalize_common.h | 4 - .../mlir/tosa/transforms/legalize_tf.cc | 37 ++----- .../mlir/tosa/transforms/legalize_tfl.cc | 35 ++---- 6 files changed, 28 insertions(+), 237 deletions(-) diff --git a/tensorflow/compiler/mlir/tosa/tests/tf-to-tosa-pipeline.mlir b/tensorflow/compiler/mlir/tosa/tests/tf-to-tosa-pipeline.mlir index 11648c9572b63c..3f3b7bcc9ef7a9 100644 --- a/tensorflow/compiler/mlir/tosa/tests/tf-to-tosa-pipeline.mlir +++ b/tensorflow/compiler/mlir/tosa/tests/tf-to-tosa-pipeline.mlir @@ -418,56 +418,18 @@ func.func @test_rsqrt(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { // ----- // CHECK-LABEL: test_sin -// CHECK-SAME: -> tensor<10xf32> +// CHECK: %[[VAR0:.*]] = tosa.sin %arg0 func.func @test_sin(%arg0: tensor<10xf32>) -> tensor<*xf32> { - // CHECK-DAG: %[[ONE:.+]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor<1xf32>}> - // CHECK-DAG: %[[TWO:.+]] = "tosa.const"() <{value = dense<2.000000e+00> : tensor<1xf32>}> - // CHECK-DAG: %[[RESULT_SCALE:.+]] = "tosa.const"() <{value = dense<2.38418579E-7> : tensor<1xf32>}> - // CHECK-DAG: %[[INT_MAX:.+]] = "tosa.const"() <{value = dense<3.276700e+04> : tensor<1xf32>}> - // CHECK-DAG: %[[IN_SCALE:.+]] = "tosa.const"() <{value = dense<0.159154937> : tensor<1xf32>}> - // CHECK-DAG: %[[TBLVAL:.+]] = "tosa.const"() <{value = dense<{{.+}} : tensor<513xi16>}> - // CHECK-DAG: %[[IN_SCALED:.+]] = tosa.mul %arg0, %[[IN_SCALE]] - // CHECK-DAG: %[[FLOOR:.+]] = tosa.floor %[[IN_SCALED]] - // CHECK-DAG: %[[SUB1:.+]] = tosa.sub %[[IN_SCALED]], %[[FLOOR]] - // CHECK-DAG: %[[MUL1:.+]] = tosa.mul %[[SUB1]], %[[TWO]] - // CHECK-DAG: %[[SUB2:.+]] = tosa.sub %[[MUL1]], %[[ONE]] - // CHECK-DAG: %[[MUL2:.+]] = tosa.mul %[[SUB2]], %[[INT_MAX]] - // CHECK-DAG: %[[TO_INT:.+]] = tosa.cast %[[MUL2]] - // CHECK-DAG: %[[TABLE:.+]] = tosa.table %[[TO_INT]], %[[TBLVAL]] - // CHECK-DAG: %[[TABLE_CAST:.+]] = tosa.cast %[[TABLE]] - // CHECK-DAG: %[[RESULT:.+]] = tosa.mul %[[TABLE_CAST:.+]], %[[RESULT_SCALE]] %0 = "tf.Sin"(%arg0) : (tensor<10xf32>) -> tensor<*xf32> - - // CHECK: return %[[RESULT]] func.return %0 : tensor<*xf32> } // ----- // CHECK-LABEL: test_cos -// CHECK-SAME: -> tensor<10xf32> +// CHECK: %[[VAR0:.*]] = tosa.cos %arg0 func.func @test_cos(%arg0: tensor<10xf32>) -> tensor<*xf32> { - // CHECK-DAG: %[[RESULT_SCALE:.+]] = "tosa.const"() <{value = dense<2.38418579E-7> : tensor<1xf32>}> - // CHECK-DAG: %[[INT_MAX:.+]] = "tosa.const"() <{value = dense<3.276700e+04> : tensor<1xf32>}> - // CHECK-DAG: %[[ONE:.+]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor<1xf32>}> - // CHECK-DAG: %[[TWO:.+]] = "tosa.const"() <{value = dense<2.000000e+00> : tensor<1xf32>}> - // CHECK-DAG: %[[IN_SCALE:.+]] = "tosa.const"() <{value = dense<0.159154937> : tensor<1xf32>}> - // CHECK-DAG: %[[HALF_PI:.+]] = "tosa.const"() <{value = dense<1.57079637> : tensor<1xf32>}> - // CHECK-DAG: %[[TBLVAL:.+]] = "tosa.const"() <{value = dense<{{.+}} : tensor<513xi16>}> - // CHECK-DAG: %[[IN_TRANSLATE:.+]] = tosa.add %arg0, %[[HALF_PI]] - // CHECK-DAG: %[[IN_SCALED:.+]] = tosa.mul %[[IN_TRANSLATE]], %[[IN_SCALE]] - // CHECK-DAG: %[[FLOOR:.+]] = tosa.floor %[[IN_SCALED]] - // CHECK-DAG: %[[SUB1:.+]] = tosa.sub %[[IN_SCALED]], %[[FLOOR]] - // CHECK-DAG: %[[MUL1:.+]] = tosa.mul %[[SUB1]], %[[TWO]] - // CHECK-DAG: %[[SUB2:.+]] = tosa.sub %[[MUL1]], %[[ONE]] - // CHECK-DAG: %[[MUL2:.+]] = tosa.mul %[[SUB2]], %[[INT_MAX]] - // CHECK-DAG: %[[TO_INT:.+]] = tosa.cast %[[MUL2]] - // CHECK-DAG: %[[TABLE:.+]] = tosa.table %[[TO_INT]], %[[TBLVAL]] - // CHECK-DAG: %[[TABLE_CAST:.+]] = tosa.cast %[[TABLE]] - // CHECK-DAG: %[[RESULT:.+]] = tosa.mul %[[TABLE_CAST:.+]], %[[RESULT_SCALE]] %0 = "tf.Cos"(%arg0) : (tensor<10xf32>) -> tensor<*xf32> - - // CHECK: return %[[RESULT]] func.return %0 : tensor<*xf32> } diff --git a/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir b/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir index 19f0c6e216c259..e12c0a9ae0b38e 100644 --- a/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir +++ b/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir @@ -818,56 +818,20 @@ func.func @test_sign(%arg0: tensor<21x45xi32>) -> tensor<21x45xi32> { // ----- // CHECK-LABEL: test_sin -// CHECK-SAME: -> tensor<10xf32> +// CHECK-SAME: %[[VAL_0:.*]]: tensor<10xf32> +// CHECK: %[[VAL_1:.*]] = tosa.sin %[[VAL_0]] : (tensor<10xf32> func.func @test_sin(%arg0: tensor<10xf32>) -> tensor<*xf32> { - // CHECK-DAG: %[[ONE:.+]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor<1xf32>}> - // CHECK-DAG: %[[TWO:.+]] = "tosa.const"() <{value = dense<2.000000e+00> : tensor<1xf32>}> - // CHECK-DAG: %[[RESULT_SCALE:.+]] = "tosa.const"() <{value = dense<2.38418579E-7> : tensor<1xf32>}> - // CHECK-DAG: %[[INT_MAX:.+]] = "tosa.const"() <{value = dense<3.276700e+04> : tensor<1xf32>}> - // CHECK-DAG: %[[IN_SCALE:.+]] = "tosa.const"() <{value = dense<0.159154937> : tensor<1xf32>}> - // CHECK-DAG: %[[TBLVAL:.+]] = "tosa.const"() <{value = dense<{{.+}}> : tensor<513xi16>}> - // CHECK-DAG: %[[IN_SCALED:.+]] = tosa.mul %arg0, %[[IN_SCALE]] - // CHECK-DAG: %[[FLOOR:.+]] = tosa.floor %[[IN_SCALED]] - // CHECK-DAG: %[[SUB1:.+]] = tosa.sub %[[IN_SCALED]], %[[FLOOR]] - // CHECK-DAG: %[[MUL1:.+]] = tosa.mul %[[SUB1]], %[[TWO]] - // CHECK-DAG: %[[SUB2:.+]] = tosa.sub %[[MUL1]], %[[ONE]] - // CHECK-DAG: %[[MUL2:.+]] = tosa.mul %[[SUB2]], %[[INT_MAX]] - // CHECK-DAG: %[[TO_INT:.+]] = tosa.cast %[[MUL2]] - // CHECK-DAG: %[[TABLE:.+]] = tosa.table %[[TO_INT]], %[[TBLVAL]] - // CHECK-DAG: %[[TABLE_CAST:.+]] = tosa.cast %[[TABLE]] - // CHECK-DAG: %[[RESULT:.+]] = tosa.mul %[[TABLE_CAST:.+]], %[[RESULT_SCALE]] %0 = "tfl.sin"(%arg0) : (tensor<10xf32>) -> tensor<*xf32> - - // CHECK: return %[[RESULT]] func.return %0 : tensor<*xf32> } // ----- // CHECK-LABEL: test_cos -// CHECK-SAME: -> tensor<10xf32> +// CHECK-SAME: %[[VAL_0:.*]]: tensor<10xf32> +// CHECK: %[[VAL_1:.*]] = tosa.cos %[[VAL_0]] : (tensor<10xf32> func.func @test_cos(%arg0: tensor<10xf32>) -> tensor<*xf32> { - // CHECK-DAG: %[[RESULT_SCALE:.+]] = "tosa.const"() <{value = dense<2.38418579E-7> : tensor<1xf32>}> - // CHECK-DAG: %[[INT_MAX:.+]] = "tosa.const"() <{value = dense<3.276700e+04> : tensor<1xf32>}> - // CHECK-DAG: %[[ONE:.+]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor<1xf32>}> - // CHECK-DAG: %[[TWO:.+]] = "tosa.const"() <{value = dense<2.000000e+00> : tensor<1xf32>}> - // CHECK-DAG: %[[IN_SCALE:.+]] = "tosa.const"() <{value = dense<0.159154937> : tensor<1xf32>}> - // CHECK-DAG: %[[HALF_PI:.+]] = "tosa.const"() <{value = dense<1.57079637> : tensor<1xf32>}> - // CHECK-DAG: %[[TBLVAL:.+]] = "tosa.const"() <{value = dense<{{.+}}> : tensor<513xi16>}> - // CHECK-DAG: %[[IN_TRANSLATE:.+]] = tosa.add %arg0, %[[HALF_PI]] - // CHECK-DAG: %[[IN_SCALED:.+]] = tosa.mul %[[IN_TRANSLATE]], %[[IN_SCALE]] - // CHECK-DAG: %[[FLOOR:.+]] = tosa.floor %[[IN_SCALED]] - // CHECK-DAG: %[[SUB1:.+]] = tosa.sub %[[IN_SCALED]], %[[FLOOR]] - // CHECK-DAG: %[[MUL1:.+]] = tosa.mul %[[SUB1]], %[[TWO]] - // CHECK-DAG: %[[SUB2:.+]] = tosa.sub %[[MUL1]], %[[ONE]] - // CHECK-DAG: %[[MUL2:.+]] = tosa.mul %[[SUB2]], %[[INT_MAX]] - // CHECK-DAG: %[[TO_INT:.+]] = tosa.cast %[[MUL2]] - // CHECK-DAG: %[[TABLE:.+]] = tosa.table %[[TO_INT]], %[[TBLVAL]] - // CHECK-DAG: %[[TABLE_CAST:.+]] = tosa.cast %[[TABLE]] - // CHECK-DAG: %[[RESULT:.+]] = tosa.mul %[[TABLE_CAST:.+]], %[[RESULT_SCALE]] %0 = "tfl.cos"(%arg0) : (tensor<10xf32>) -> tensor<*xf32> - - // CHECK: return %[[RESULT]] func.return %0 : tensor<*xf32> } diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc b/tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc index 25707c2bde1331..8d6b423b20b38f 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc +++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc @@ -4586,109 +4586,6 @@ std::optional convertOneHotOp(PatternRewriter& rewriter, Operation* op, .getResult(); } -// Lowers Sin operator to a sequence of TOSA ops. -std::optional convertSinOp(PatternRewriter& rewriter, Operation* op, - Value input, ShapedType output_type) { - RankedTensorType input_type = dyn_cast(input.getType()); - Location loc = op->getLoc(); - - Type input_ety = input_type.getElementType(); - Type output_ety = output_type.getElementType(); - - if (!input) return std::nullopt; - - if (input_ety != output_ety) { - (void)rewriter.notifyMatchFailure(op, - "input/output element type must match"); - return std::nullopt; - } - - bool input_is_fp = input_ety.isF32(); - bool output_is_fp = output_ety.isF32(); - - if (!input_is_fp || !output_is_fp) { - (void)rewriter.notifyMatchFailure(op, "input/result must be fp32"); - return std::nullopt; - } - - // To perform a sin operation we remap the sin domain to be over a single - // period of the function, remapping to the domain of the table function. - // We then remap the range of the table function to map to the range of the - // sin operation. - - // 1. Normalize the period of the domain from [0, 2π) to [0, 1). - auto fp_scalar_ty = RankedTensorType::get({}, rewriter.getF32Type()); - Value fp_scale = rewriter.create( - loc, fp_scalar_ty, - DenseElementsAttr::get(fp_scalar_ty, {static_cast(0.5 / M_PI)})); - - // 2. Remap the periodic behavior of the domain to line up within [0, 1). - Value fp_scaled = CreateOpAndInfer( - rewriter, loc, input_type, input, fp_scale, rewriter.getI8IntegerAttr(0)); - auto floored = - CreateOpAndInfer(rewriter, loc, input_type, fp_scaled); - auto repeated = CreateOpAndInfer(rewriter, loc, input_type, - fp_scaled, floored); - - // 3. Scale and translate the normalized domain to the table domain. This - // includes a translating and scaling to [-int16_max, int16_max] and casting - // to an i16. - Value one = rewriter.create( - loc, fp_scalar_ty, DenseElementsAttr::get(fp_scalar_ty, {1.0f})); - - Value two = rewriter.create( - loc, fp_scalar_ty, DenseElementsAttr::get(fp_scalar_ty, {2.0f})); - auto scale_up = CreateOpAndInfer( - rewriter, loc, input_type, repeated, two, rewriter.getI8IntegerAttr(0)); - auto translate = - CreateOpAndInfer(rewriter, loc, input_type, scale_up, one); - - Value int_limit = rewriter.create( - loc, fp_scalar_ty, - DenseElementsAttr::get( - fp_scalar_ty, - {static_cast(std::numeric_limits::max())})); - auto int_scaled = - CreateOpAndInfer(rewriter, loc, input_type, translate, - int_limit, rewriter.getI8IntegerAttr(0)); - - auto int16_ty = input_type.clone(rewriter.getIntegerType(16)); - auto casted = - CreateOpAndInfer(rewriter, loc, int16_ty, int_scaled); - - // 4. Compute the lookup table using the range of [-255, 255] for sin. - llvm::SmallVector values; - const int num_values = 513; - values.resize(num_values, 0); - // First and last values should be 0; - for (int i = 1; i < num_values - 1; ++i) - values[i] = std::numeric_limits::max() * - sin(static_cast(i) * 2.0 * M_PI / (num_values - 1.0)); - - auto table_ty = - RankedTensorType::get({num_values}, rewriter.getIntegerType(16)); - Value table = rewriter.create( - loc, table_ty, DenseElementsAttr::get(table_ty, llvm::ArrayRef(values))); - - auto table_result_ty = input_type.clone(rewriter.getIntegerType(32)); - auto table_result = CreateOpAndInfer( - rewriter, loc, table_result_ty, casted, table); - - // 5. The range of table is a 23-bit two's compliment value. Normalize the - // range by casting to an fp32 and dividing by 2^22. - auto table_result_fp = - CreateOpAndInfer(rewriter, loc, input_type, table_result); - auto output_scale = rewriter.create( - loc, fp_scalar_ty, - DenseElementsAttr::get( - fp_scalar_ty, - {static_cast(1.0 / static_cast(1 << 22))})); - - return CreateOpAndInfer(rewriter, loc, output_type, table_result_fp, - output_scale, rewriter.getI8IntegerAttr(0)) - .getResult(); -} - // Lowers Sign operator to a sequence of TOSA ops. std::optional convertSignOp(PatternRewriter& rewriter, Operation* op, Value input, RankedTensorType output_type) { diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_common.h b/tensorflow/compiler/mlir/tosa/transforms/legalize_common.h index 20dbfb19d44702..cfe063408edea0 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/legalize_common.h +++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_common.h @@ -298,10 +298,6 @@ std::optional convertOneHotOp(PatternRewriter& rewriter, Operation* op, Value on_value, Value off_value, int32_t depth, int32_t axis); -// Lowers 32-bit floating sin operator to a sequence of TOSA ops. -std::optional convertSinOp(PatternRewriter& rewriter, Operation* op, - Value input, ShapedType output_type); - // Lowers Sign operator to a sequence of TOSA ops. std::optional convertSignOp(PatternRewriter& rewriter, Operation* op, Value input, RankedTensorType output_type); diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/tosa/transforms/legalize_tf.cc index 904394d370bcce..01a134bced0ca6 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_tf.cc @@ -270,43 +270,28 @@ LogicalResult ConvertTFSignOp::matchAndRewrite( LogicalResult ConvertTFSinOp::matchAndRewrite(Operation* op, PatternRewriter& rewriter) const { auto tf_sin_op = cast(op); - ShapedType output_type = - mlir::cast(tf_sin_op.getResult().getType()); - std::optional result = - convertSinOp(rewriter, op, tf_sin_op.getX(), output_type); - if (!result) return failure(); + ShapedType output_type = dyn_cast(tf_sin_op.getResult().getType()); + if (!output_type) + return rewriter.notifyMatchFailure(op, "output_type required"); + + CreateReplaceOpAndInfer(rewriter, op, output_type, + tf_sin_op.getX()); - rewriter.replaceOp(op, {result.value()}); return success(); } LogicalResult ConvertTFCosOp::matchAndRewrite(Operation* op, PatternRewriter& rewriter) const { auto tf_cos_op = cast(op); - Value input = tf_cos_op.getX(); - RankedTensorType input_ty = dyn_cast(input.getType()); - ShapedType output_ty = dyn_cast(tf_cos_op.getResult().getType()); - if (!input_ty || !output_ty) return failure(); - - bool input_is_fp = mlir::isa(input_ty.getElementType()); - bool output_is_fp = mlir::isa(output_ty.getElementType()); - - if (!input_is_fp || !output_is_fp) { - return rewriter.notifyMatchFailure( - op, "ConvertTFCosOp: input/result must be fp."); - } + ShapedType output_type = dyn_cast(tf_cos_op.getResult().getType()); + if (!output_type) + return rewriter.notifyMatchFailure(op, "output_type required"); - // Replace with the equivalent sin operation: - // cos(x) = sin(x + π / 2). - auto fp_scalar_ty = RankedTensorType::get({}, rewriter.getF32Type()); - auto pi_2 = rewriter.create( - op->getLoc(), fp_scalar_ty, - DenseElementsAttr::get(fp_scalar_ty, {static_cast(M_PI_2)})); - auto offset = rewriter.create(op->getLoc(), input_ty, input, pi_2); + CreateReplaceOpAndInfer(rewriter, op, output_type, + tf_cos_op.getX()); - CreateReplaceOpAndInfer(rewriter, op, output_ty, offset); return success(); } diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_tfl.cc b/tensorflow/compiler/mlir/tosa/transforms/legalize_tfl.cc index e6e7bc98e8d613..717603da479339 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/legalize_tfl.cc +++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_tfl.cc @@ -3305,42 +3305,29 @@ LogicalResult ConvertTFLHardSwishOp::matchAndRewrite( LogicalResult ConvertTFLSinOp::matchAndRewrite( Operation* op, PatternRewriter& rewriter) const { auto tfl_sin_op = cast(op); - auto input = tfl_sin_op.getX(); + ShapedType output_type = dyn_cast(tfl_sin_op.getResult().getType()); - std::optional result = convertSinOp(rewriter, op, input, output_type); - if (!result) return failure(); + if (!output_type) + return rewriter.notifyMatchFailure(op, "output_type required"); + CreateReplaceOpAndInfer(rewriter, op, output_type, + tfl_sin_op.getX()); - rewriter.replaceOp(op, {result.value()}); return success(); } LogicalResult ConvertTFLCosOp::matchAndRewrite( Operation* op, PatternRewriter& rewriter) const { auto tfl_cos_op = cast(op); - Value input = tfl_cos_op.getX(); - RankedTensorType input_ty = dyn_cast(input.getType()); - ShapedType output_ty = dyn_cast(tfl_cos_op.getResult().getType()); - - if (!input_ty || !output_ty) return failure(); - - bool input_is_fp = mlir::isa(input_ty.getElementType()); - bool output_is_fp = mlir::isa(output_ty.getElementType()); - if (!input_is_fp || !output_is_fp) { - return rewriter.notifyMatchFailure(op, "input/result must be fp"); - } - - // Replace with the equivalent sin operation: - // cos(x) = sin(x + π / 2). - auto fp_scalar_ty = RankedTensorType::get({}, rewriter.getF32Type()); - auto pi_2 = rewriter.create( - op->getLoc(), fp_scalar_ty, - DenseElementsAttr::get(fp_scalar_ty, {static_cast(M_PI_2)})); - auto offset = rewriter.create(op->getLoc(), input_ty, input, pi_2); + ShapedType output_type = + dyn_cast(tfl_cos_op.getResult().getType()); + if (!output_type) + return rewriter.notifyMatchFailure(op, "output_type required"); + CreateReplaceOpAndInfer(rewriter, op, output_type, + tfl_cos_op.getX()); - CreateReplaceOpAndInfer(rewriter, op, output_ty, offset); return success(); } From 372aa0b045d7ee24f5659f13a4d473d7e0e80634 Mon Sep 17 00:00:00 2001 From: Juliana Franco Date: Thu, 19 Sep 2024 08:25:01 -0700 Subject: [PATCH 005/483] Skip axes of size 1 when building new TensorShardingAttr. Without this CL, the test added would result in a sharding including the axis of size 1 (i.e. "x") in the first dimension's sharding. PiperOrigin-RevId: 676423383 --- .../service/spmd/shardy/mhlo_round_trip/BUILD | 20 +++++ .../shardy/mhlo_round_trip/mhlo_import.cc | 3 + .../mhlo_round_trip/mhlo_import_test.cc | 76 +++++++++++++++++++ 3 files changed, 99 insertions(+) create mode 100644 third_party/xla/xla/service/spmd/shardy/mhlo_round_trip/mhlo_import_test.cc diff --git a/third_party/xla/xla/service/spmd/shardy/mhlo_round_trip/BUILD b/third_party/xla/xla/service/spmd/shardy/mhlo_round_trip/BUILD index d2c8a4daa2318a..b8ab5769b90267 100644 --- a/third_party/xla/xla/service/spmd/shardy/mhlo_round_trip/BUILD +++ b/third_party/xla/xla/service/spmd/shardy/mhlo_round_trip/BUILD @@ -1,6 +1,7 @@ # Import/Export passes for going from `sdy.sharding`s to `mhlo.sharding`s and vice versa. load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") +load("//xla:xla.bzl", "xla_cc_test") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -121,6 +122,25 @@ cc_library( ], ) +xla_cc_test( + name = "mhlo_import_test", + srcs = ["mhlo_import_test.cc"], + deps = [ + ":mhlo_import", + "//xla/hlo/ir:hlo", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Parser", + "@llvm-project//mlir:Support", + "@local_tsl//tsl/platform:status_matchers", + "@local_tsl//tsl/platform:test", + "@shardy//shardy/dialect/sdy/ir:dialect", + "@shardy//shardy/dialect/sdy/ir:register", + ], +) + cc_library( name = "shard_map_import", srcs = ["shard_map_import.cc"], diff --git a/third_party/xla/xla/service/spmd/shardy/mhlo_round_trip/mhlo_import.cc b/third_party/xla/xla/service/spmd/shardy/mhlo_round_trip/mhlo_import.cc index 02553d14e9a98b..50177a1268f9e0 100644 --- a/third_party/xla/xla/service/spmd/shardy/mhlo_round_trip/mhlo_import.cc +++ b/third_party/xla/xla/service/spmd/shardy/mhlo_round_trip/mhlo_import.cc @@ -441,6 +441,9 @@ TensorShardingAttr convertToSdySharding( // break it when we find common mesh axes. while (product < localAxisSize) { MeshAxisAttr axisAttr = globalMesh.getAxes()[globalAxisIndex++]; + if (axisAttr.getSize() == 1) { + continue; + } globalAxes.push_back(AxisRefAttr::get(ctx, axisAttr.getName())); product *= axisAttr.getSize(); } diff --git a/third_party/xla/xla/service/spmd/shardy/mhlo_round_trip/mhlo_import_test.cc b/third_party/xla/xla/service/spmd/shardy/mhlo_round_trip/mhlo_import_test.cc new file mode 100644 index 00000000000000..f1635a804ab332 --- /dev/null +++ b/third_party/xla/xla/service/spmd/shardy/mhlo_round_trip/mhlo_import_test.cc @@ -0,0 +1,76 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/spmd/shardy/mhlo_round_trip/mhlo_import.h" + +#include + +#include +#include "llvm/ADT/DenseMap.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/Parser/Parser.h" +#include "mlir/Support/LLVM.h" +#include "shardy/dialect/sdy/ir/dialect.h" +#include "shardy/dialect/sdy/ir/register.h" +#include "shardy/dialect/sdy/ir/utils.h" +#include "xla/hlo/ir/hlo_sharding.h" +#include "tsl/platform/test.h" + +namespace mlir::sdy { + +namespace { + +TEST(MhloImportTest, SkipFirstAxisOfSize1) { + MLIRContext context; + loadAllRequiredDialects(&context); + SmallVector axes; + axes.emplace_back(mlir::sdy::MeshAxisAttr::get(&context, "x", 1)); + axes.emplace_back(mlir::sdy::MeshAxisAttr::get(&context, "y", 4)); + axes.emplace_back(mlir::sdy::MeshAxisAttr::get(&context, "z", 2)); + auto mesh = sdy::MeshAttr::get(&context, axes); + + TensorShardingAttr sharding = xla::sdy::convertToSdySharding( + /*hloSharding=*/xla::HloSharding::IotaTile({4, 2}), + /*globalMesh=*/mesh, + /*deviceIdToMaximalMeshName=*/ + llvm::SmallDenseMap(), /*rank=*/2, + /*openDims=*/true); + EXPECT_EQ(attributeToString(sharding), + "#sdy.sharding<@mesh, [{\"y\", ?}, {\"z\", ?}]>"); +} + +// As above, but the middle axis is the one with size 1. +TEST(MhloImportTest, SkipSecondAxisOfSize1) { + MLIRContext context; + loadAllRequiredDialects(&context); + SmallVector axes; + axes.emplace_back(mlir::sdy::MeshAxisAttr::get(&context, "y", 4)); + axes.emplace_back(mlir::sdy::MeshAxisAttr::get(&context, "x", 1)); + axes.emplace_back(mlir::sdy::MeshAxisAttr::get(&context, "z", 2)); + auto mesh = sdy::MeshAttr::get(&context, axes); + + TensorShardingAttr sharding = xla::sdy::convertToSdySharding( + /*hloSharding=*/xla::HloSharding::IotaTile({4, 2}), + /*globalMesh=*/mesh, + /*deviceIdToMaximalMeshName=*/ + llvm::SmallDenseMap(), /*rank=*/2, + /*openDims=*/true); + EXPECT_EQ(attributeToString(sharding), + "#sdy.sharding<@mesh, [{\"y\", ?}, {\"z\", ?}]>"); +} + +} // namespace +} // namespace mlir::sdy From e9b9ce36adcbf2856eda21aa77082c35cda85108 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 19 Sep 2024 08:27:11 -0700 Subject: [PATCH 006/483] Integrate Triton up to [f4c48a92](https://github.com/openai/triton/commits/f4c48a9233957903e30474bae6443bf3d3a79bf7) PiperOrigin-RevId: 676424019 --- third_party/triton/workspace.bzl | 4 ++-- third_party/xla/third_party/triton/workspace.bzl | 4 ++-- .../service/gpu/fusions/triton/compilation_pipeline_rocm.cc | 2 ++ 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/third_party/triton/workspace.bzl b/third_party/triton/workspace.bzl index 8f68ab621acb1d..f952345e93b979 100644 --- a/third_party/triton/workspace.bzl +++ b/third_party/triton/workspace.bzl @@ -8,8 +8,8 @@ load("//third_party/triton:xla_extensions/series.bzl", "extensions_files_patch_l def repo(): """Imports Triton.""" - TRITON_COMMIT = "cl673813747" - TRITON_SHA256 = "3e901c1b441407b1b7ac601092f64a9141571879b00a1ff54437c8e9370a365f" + TRITON_COMMIT = "cl675928942" + TRITON_SHA256 = "4e31bfdd10d3e9c6277a47b9af9d64fc60a2cc1b81330da3cb7d01d938be1d36" tf_http_archive( name = "triton", sha256 = TRITON_SHA256, diff --git a/third_party/xla/third_party/triton/workspace.bzl b/third_party/xla/third_party/triton/workspace.bzl index 8f68ab621acb1d..f952345e93b979 100644 --- a/third_party/xla/third_party/triton/workspace.bzl +++ b/third_party/xla/third_party/triton/workspace.bzl @@ -8,8 +8,8 @@ load("//third_party/triton:xla_extensions/series.bzl", "extensions_files_patch_l def repo(): """Imports Triton.""" - TRITON_COMMIT = "cl673813747" - TRITON_SHA256 = "3e901c1b441407b1b7ac601092f64a9141571879b00a1ff54437c8e9370a365f" + TRITON_COMMIT = "cl675928942" + TRITON_SHA256 = "4e31bfdd10d3e9c6277a47b9af9d64fc60a2cc1b81330da3cb7d01d938be1d36" tf_http_archive( name = "triton", sha256 = TRITON_SHA256, diff --git a/third_party/xla/xla/service/gpu/fusions/triton/compilation_pipeline_rocm.cc b/third_party/xla/xla/service/gpu/fusions/triton/compilation_pipeline_rocm.cc index 2a95ea833f4bcc..a48e65ab3a6953 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/compilation_pipeline_rocm.cc +++ b/third_party/xla/xla/service/gpu/fusions/triton/compilation_pipeline_rocm.cc @@ -96,6 +96,8 @@ absl::Status CreateTritonPipeline( if (block_level_parameters.num_stages != kAmdDoubleBuffering) { pm.addPass(mt::gpu::createTritonGPUReorderInstructions()); } + pm.addPass(mlir::createTritonAMDGPUCanonicalizePointersPass()); + pm.addPass(mlir::createCanonicalizerPass()); pm.addPass(mlir::createCSEPass()); pm.addPass(mlir::createSymbolDCEPass()); From 746c9616698cd5e0eb1742edd587e3b38389e187 Mon Sep 17 00:00:00 2001 From: Matt Bahr Date: Thu, 19 Sep 2024 12:17:22 -0400 Subject: [PATCH 007/483] fix integer overflow in range function --- tensorflow/core/kernels/ragged_range_op.cc | 12 ++++++++++-- tensorflow/core/kernels/sequence_ops.cc | 12 ++++++++++-- tensorflow/core/ops/math_ops.cc | 12 ++++++++++-- tensorflow/python/ops/math_ops_test.py | 8 ++++++++ tensorflow/python/ops/ragged/ragged_range_op_test.py | 9 +++++++++ 5 files changed, 47 insertions(+), 6 deletions(-) diff --git a/tensorflow/core/kernels/ragged_range_op.cc b/tensorflow/core/kernels/ragged_range_op.cc index 90c2060c33f386..f2c24534f3fd6e 100644 --- a/tensorflow/core/kernels/ragged_range_op.cc +++ b/tensorflow/core/kernels/ragged_range_op.cc @@ -87,8 +87,16 @@ class RaggedRangeOp : public OpKernel { size = 0; } else if constexpr (std::is_integral::value) { // The following is copied from tensorflow::RangeOp::Compute(). - size = Eigen::divup(Eigen::numext::abs(limit - start), - Eigen::numext::abs(delta)); + uint64_t range; + if ((limit > 0 && start < 0) || (limit < 0 && start > 0)) { + range = static_cast(Eigen::numext::abs(limit)) + + static_cast(Eigen::numext::abs(start)); + } else { + range = static_cast(Eigen::numext::abs(limit - start)); + } + + size = Eigen::divup(range, + static_cast(Eigen::numext::abs(delta))); } else { // The following is copied from tensorflow::RangeOp::Compute(). auto size_auto = diff --git a/tensorflow/core/kernels/sequence_ops.cc b/tensorflow/core/kernels/sequence_ops.cc index 5256db35a1f228..1fa09a0e09d54a 100644 --- a/tensorflow/core/kernels/sequence_ops.cc +++ b/tensorflow/core/kernels/sequence_ops.cc @@ -93,8 +93,16 @@ class RangeOp : public OpKernel { } int64_t size; if constexpr (std::is_integral::value) { - size = Eigen::divup(Eigen::numext::abs(limit - start), - Eigen::numext::abs(delta)); + uint64_t range; + if ((limit > 0 && start < 0) || (limit < 0 && start > 0)) { + range = static_cast(Eigen::numext::abs(limit)) + + static_cast(Eigen::numext::abs(start)); + } else { + range = static_cast(Eigen::numext::abs(limit - start)); + } + + size = Eigen::divup(range, + static_cast(Eigen::numext::abs(delta))); } else { auto size_auto = Eigen::numext::ceil(Eigen::numext::abs((limit - start) / delta)); diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc index 25e4e373d5aace..65a98857cb0310 100644 --- a/tensorflow/core/ops/math_ops.cc +++ b/tensorflow/core/ops/math_ops.cc @@ -1513,8 +1513,16 @@ Status RangeSize(const Tensor* start_t, const Tensor* limit_t, int64_t size; if (std::is_integral::value) { - size = Eigen::divup(static_cast(Eigen::numext::abs(limit - start)), - static_cast(Eigen::numext::abs(delta))); + uint64_t range; + if ((limit > 0 && start < 0) || (limit < 0 && start > 0)) { + range = static_cast(Eigen::numext::abs(limit)) + + static_cast(Eigen::numext::abs(start)); + } else { + range = static_cast(Eigen::numext::abs(limit - start)); + } + + size = Eigen::divup(range, + static_cast(Eigen::numext::abs(delta))); } else { auto size_auto = Eigen::numext::ceil(Eigen::numext::abs((limit - start) / delta)); diff --git a/tensorflow/python/ops/math_ops_test.py b/tensorflow/python/ops/math_ops_test.py index c2665986d18ab7..c3685de0c896d7 100644 --- a/tensorflow/python/ops/math_ops_test.py +++ b/tensorflow/python/ops/math_ops_test.py @@ -1360,6 +1360,14 @@ def testInputsNearInt64Max(self): self.assertAllEqual( (0,), self.evaluate(x)) # smallest input with potential overflow + def testInt32Overflow(self): + start = 1136033460 + end = -2110457150 + step = -1849827689 + expected = np.arange(start, end, step) + actual = math_ops.range(start, end, step) + self.assertAllEqual(expected, self.evaluate(actual)) + @test_util.run_all_in_graph_and_eager_modes class ErfcinvTest(test_util.TensorFlowTestCase): diff --git a/tensorflow/python/ops/ragged/ragged_range_op_test.py b/tensorflow/python/ops/ragged/ragged_range_op_test.py index c759b8254ac167..d3201b58c21275 100644 --- a/tensorflow/python/ops/ragged/ragged_range_op_test.py +++ b/tensorflow/python/ops/ragged/ragged_range_op_test.py @@ -13,6 +13,7 @@ # limitations under the License. # ============================================================================== """Tests for ragged_range op.""" +import numpy as np from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -129,6 +130,14 @@ def testShape(self): self.assertAllEqual( ragged_math_ops.range([1, 2, 3], [4, 5, 6]).shape.as_list(), [3, None]) + def testInt32Overflow(self): + start = 1136033460 + end = -2110457150 + step = -1849827689 + expected = [np.arange(start, end, step)] + actual = ragged_math_ops.range(start, end, step) + self.assertAllEqual(expected, self.evaluate(actual)) + if __name__ == '__main__': googletest.main() From 47be8a42d772e65db6f8572689b660a350ee95c3 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 19 Sep 2024 09:14:31 -0700 Subject: [PATCH 008/483] Reverts bb1e54dd75230256ce943f22ce0e4a3113830fa8 PiperOrigin-RevId: 676440146 --- third_party/xla/xla/stream_executor/launch_dim.h | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/third_party/xla/xla/stream_executor/launch_dim.h b/third_party/xla/xla/stream_executor/launch_dim.h index f8a408bfc7211e..59b935c1ac7574 100644 --- a/third_party/xla/xla/stream_executor/launch_dim.h +++ b/third_party/xla/xla/stream_executor/launch_dim.h @@ -58,6 +58,10 @@ struct ThreadDim : internal::Dim3D { struct BlockDim : internal::Dim3D { explicit BlockDim(uint64_t x = 1, uint64_t y = 1, uint64_t z = 1) : internal::Dim3D({x, y, z}) {} + + std::string ToString() const { + return absl::StrCat("BlockDim{", x, ", ", y, ", ", z, "}"); + } }; // Cluster dimensionality for use in a kernel launch. From da6730e15a5800a2cd1b9433d1b124c9862efa3d Mon Sep 17 00:00:00 2001 From: "Dimitar (Mitko) Asenov" Date: Thu, 19 Sep 2024 09:23:12 -0700 Subject: [PATCH 009/483] [XLA:GPU] Fix crash in `SoftmaxRewriterTriton` when the HLO contains a broadcast from scalar. Example crash observed: ``` b = f32[64] broadcast(f32[] param), dimensions={} ``` PiperOrigin-RevId: 676443178 --- .../gpu/transforms/softmax_rewriter_triton.cc | 11 ++++++-- .../softmax_rewriter_triton_test.cc | 28 +++++++++++++++++++ 2 files changed, 36 insertions(+), 3 deletions(-) diff --git a/third_party/xla/xla/service/gpu/transforms/softmax_rewriter_triton.cc b/third_party/xla/xla/service/gpu/transforms/softmax_rewriter_triton.cc index 31a1b7e30cbb56..00f00e7bba4f17 100644 --- a/third_party/xla/xla/service/gpu/transforms/softmax_rewriter_triton.cc +++ b/third_party/xla/xla/service/gpu/transforms/softmax_rewriter_triton.cc @@ -157,6 +157,10 @@ inline bool HasOneUse(const HloInstruction* instr) { // Unsupported case #4: // p = f32[a,b] parameter(0) // b = f32[a,x,b] broadcast(p), dimensions={0,2} +// +// Unsupported case #5: +// p = f32[] parameter(0) +// b = f32[x] broadcast(p), dimensions={} bool IsBatchOrReductionDimBroadcast(const HloInstruction& hlo) { CHECK_EQ(hlo.opcode(), HloOpcode::kBroadcast) << "Expected broadcast " << hlo.ToShortString(); @@ -169,9 +173,10 @@ bool IsBatchOrReductionDimBroadcast(const HloInstruction& hlo) { const HloParameterInstruction* parameter = Cast(hlo.operand(0)); - // Support only one dim broadcast. - if (parameter->shape().dimensions_size() + 1 != - broadcast->shape().dimensions_size()) { + // Support only one dim broadcast. Scalar parameters are handled elsewhere. + if (broadcast->dimensions().empty() || + parameter->shape().dimensions_size() + 1 != + broadcast->shape().dimensions_size()) { return false; } diff --git a/third_party/xla/xla/service/gpu/transforms/softmax_rewriter_triton_test.cc b/third_party/xla/xla/service/gpu/transforms/softmax_rewriter_triton_test.cc index 80bed2552becf7..57052282cbaa4d 100644 --- a/third_party/xla/xla/service/gpu/transforms/softmax_rewriter_triton_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/softmax_rewriter_triton_test.cc @@ -1620,6 +1620,34 @@ ENTRY main { } } +TEST_F(SoftmaxRewriterTritonTest, DoesNotCrashOnScalarBroadcast) { + const std::string hlo_string = R"( +HloModule softmax +max_computation { + arg_0 = f32[] parameter(0) + arg_1 = f32[] parameter(1) + ROOT maximum = f32[] maximum(arg_0, arg_1) +} +ENTRY main { + param_0 = f32[127,125]{1,0} parameter(0) + param_1 = f32[] parameter(1) + broadcast_from_scalar = f32[127] broadcast(param_1), dimensions={} + constant_neg_inf = f32[] constant(-inf) + reduce = f32[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation + add = f32[127]{0} add(broadcast_from_scalar, reduce) + broadcast = f32[127,125]{1,0} broadcast(add), dimensions={0} + subtract = f32[127,125]{1,0} subtract(param_0, broadcast) + ROOT abs = f32[127,125]{1,0} abs(subtract) +} +)"; + auto module = ParseAndReturnVerifiedModule(hlo_string).value(); + EXPECT_TRUE(fusion_rewriter_.Run(module.get()).value()); + EXPECT_TRUE(verifier().Run(module.get()).status().ok()); + EXPECT_THAT(module->entry_computation()->root_instruction(), + GmockMatch(m::Fusion(m::Parameter(), m::Parameter()) + .WithPredicate(HasBlockLevelFusionConfig))); +} + } // anonymous namespace } // namespace gpu } // namespace xla From 784164302f709dc2369aa3159d4c8a8666094a6a Mon Sep 17 00:00:00 2001 From: Dirk Hornung Date: Thu, 19 Sep 2024 10:02:52 -0700 Subject: [PATCH 010/483] Add custom kernel fusion to gemm fusion autotuner. The GemmFusionAutotuner currently takes a fusion and compares its runtime on different backends (Triton, CuBLAS and CuDNN). We add CustomKernelFusions (mostly Cutlass kernels) to the autotuner. PiperOrigin-RevId: 676458461 --- third_party/xla/xla/autotuning.proto | 7 +- .../xla/xla/service/gpu/autotuning/BUILD | 11 +- .../gpu/autotuning/gemm_fusion_autotuner.cc | 293 +++++++++++++----- .../gpu/autotuning/gemm_fusion_autotuner.h | 21 +- .../autotuning/gemm_fusion_autotuner_test.cc | 181 ++++++++++- 5 files changed, 423 insertions(+), 90 deletions(-) diff --git a/third_party/xla/xla/autotuning.proto b/third_party/xla/xla/autotuning.proto index a7ffcbb57ae6ef..4cadf6dbb250eb 100644 --- a/third_party/xla/xla/autotuning.proto +++ b/third_party/xla/xla/autotuning.proto @@ -83,6 +83,10 @@ message AutotuneResult { int64 num_ctas = 7; } + message CustomKernelFusionKey { + int64 kernel_index = 1; + } + int64 scratch_bytes = 8; google.protobuf.Duration run_time = 9; @@ -93,10 +97,11 @@ message AutotuneResult { GemmKey gemm = 6; TritonGemmKey triton = 17; CudaConvPlanKey cuda_conv_plan = 15; + CustomKernelFusionKey custom_kernel_fusion = 18; stream_executor.dnn.AlgorithmProto algorithm = 16; } - // Next ID: 17 + // Next ID: 19 } message AutotuningLog { diff --git a/third_party/xla/xla/service/gpu/autotuning/BUILD b/third_party/xla/xla/service/gpu/autotuning/BUILD index be63f3888442af..d162b1e8f0aded 100644 --- a/third_party/xla/xla/service/gpu/autotuning/BUILD +++ b/third_party/xla/xla/service/gpu/autotuning/BUILD @@ -45,9 +45,11 @@ cc_library( "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", + "//xla/hlo/pass:hlo_pass_pipeline", "//xla/hlo/utils:hlo_query", "//xla/pjrt/distributed:key_value_store_interface", "//xla/service:algorithm_util", + "//xla/service:call_inliner", "//xla/service:dump", "//xla/service:executable", "//xla/service:float_normalization", @@ -58,12 +60,15 @@ cc_library( "//xla/service/gpu:backend_configs_cc", "//xla/service/gpu:buffer_comparator", "//xla/service/gpu:gpu_float_support", - "//xla/service/gpu:hlo_traversal", "//xla/service/gpu:ir_emission_utils", "//xla/service/gpu:matmul_utils", "//xla/service/gpu:split_k_gemm_rewriter", "//xla/service/gpu:stream_executor_util", + "//xla/service/gpu/kernels:custom_kernel", + "//xla/service/gpu/kernels:custom_kernel_fusion", + "//xla/service/gpu/kernels:custom_kernel_fusion_pattern", "//xla/service/gpu/transforms:cudnn_fusion_compiler", + "//xla/service/gpu/transforms:custom_kernel_fusion_rewriter", "//xla/service/gpu/transforms:fusion_wrapper", "//xla/service/gpu/transforms:gemm_rewriter", "//xla/service/gpu/transforms:priority_fusion", @@ -72,11 +77,9 @@ cc_library( "//xla/stream_executor:device_memory", "//xla/stream_executor:semantic_version", "//xla/stream_executor:stream_executor_memory_allocator", - "//xla/stream_executor/gpu:redzone_allocator", "//xla/tools:hlo_decomposer_lib", "//xla/tsl/lib/core:bits", "//xla/tsl/util/proto:proto_utils", - "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", @@ -137,6 +140,8 @@ xla_test( "//xla/stream_executor:device_description", "//xla/stream_executor:device_description_proto_cc", "//xla/stream_executor:semantic_version", + "//xla/stream_executor:stream_executor_h", + "//xla/stream_executor/gpu:gpu_executor_header", "//xla/tests:filecheck", "//xla/tests:hlo_test_base", "//xla/tests:test_utils", diff --git a/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc index 79524924584c97..8f041d6e8d27f2 100644 --- a/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc +++ b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc @@ -19,7 +19,6 @@ limitations under the License. #include #include #include -#include #include #include #include @@ -27,7 +26,6 @@ limitations under the License. #include #include -#include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/log/check.h" @@ -51,24 +49,28 @@ limitations under the License. #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/pass/hlo_pass_pipeline.h" #include "xla/hlo/utils/hlo_query.h" #include "xla/pjrt/distributed/key_value_store_interface.h" #include "xla/primitive_util.h" #include "xla/service/algorithm_util.h" +#include "xla/service/call_inliner.h" #include "xla/service/dump.h" -#include "xla/service/executable.h" #include "xla/service/float_normalization.h" #include "xla/service/gpu/autotuning/autotuner_compile_util.h" #include "xla/service/gpu/autotuning/autotuner_util.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/buffer_comparator.h" #include "xla/service/gpu/gpu_float_support.h" -#include "xla/service/gpu/hlo_traversal.h" #include "xla/service/gpu/ir_emission_utils.h" +#include "xla/service/gpu/kernels/custom_kernel.h" +#include "xla/service/gpu/kernels/custom_kernel_fusion.h" +#include "xla/service/gpu/kernels/custom_kernel_fusion_pattern.h" #include "xla/service/gpu/matmul_utils.h" #include "xla/service/gpu/split_k_gemm_rewriter.h" #include "xla/service/gpu/stream_executor_util.h" #include "xla/service/gpu/transforms/cudnn_fusion_compiler.h" +#include "xla/service/gpu/transforms/custom_kernel_fusion_rewriter.h" #include "xla/service/gpu/transforms/fusion_wrapper.h" #include "xla/service/gpu/transforms/gemm_rewriter.h" #include "xla/service/gpu/transforms/priority_fusion.h" @@ -82,7 +84,6 @@ limitations under the License. #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/device_memory_allocator.h" -#include "xla/stream_executor/gpu/redzone_allocator.h" #include "xla/stream_executor/semantic_version.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor_memory_allocator.h" @@ -140,76 +141,6 @@ constexpr std::array kNumCtas = {1, 2, 4, 8, 16}; using AutoTuneCacheKeyCount = absl::flat_hash_map; -class GemmFusionAutotunerVisitor : public DfsHloRewriteVisitor { - public: - explicit GemmFusionAutotunerVisitor(const AutotuneConfig& config) - : config_(config) {} - - absl::Status HandleFusion(HloInstruction* hlo) override { - TF_ASSIGN_OR_RETURN(auto gpu_config, - hlo->backend_config()); - FusionBackendConfig& backend_config = - *gpu_config.mutable_fusion_backend_config(); - if (backend_config.kind() != kTritonGemmFusionKind && - backend_config.kind() != kCuDnnFusionKind) { - return absl::OkStatus(); - } - - VLOG(4) << "Processing " << hlo->ToString(); - if (!backend_config.has_triton_gemm_config() && - !backend_config.has_cudnn_fusion_config()) { - TF_ASSIGN_OR_RETURN( - AutotuneResult autotune_result, - AutotunerUtil::Autotune( - hlo, config_, [&]() -> absl::StatusOr { - if (config_.IsDeviceless()) { - return absl::InternalError(absl::StrCat( - "Expect autotune result cache hit for deviceless " - "compilation (HLO: ", - hlo->ToString(), ")")); - } - return absl::InternalError("Expect autotune result cache hit."); - })); - VLOG(4) << "Result: " << autotune_result.ShortDebugString(); - - if (autotune_result.has_triton()) { - *backend_config.mutable_triton_gemm_config() = autotune_result.triton(); - TF_RETURN_IF_ERROR(hlo->set_backend_config(gpu_config)); - } else if (autotune_result.has_gemm()) { - // Falling back to cuBLAS: Converting the fusion to a Call, so that it - // can be inlined back again. - HloComputation* const computation = hlo->parent(); - HloInstruction* const call = computation->AddInstruction( - HloInstruction::CreateCall(hlo->shape(), hlo->operands(), - hlo->fused_instructions_computation())); - TF_RETURN_IF_ERROR(computation->ReplaceInstruction(hlo, call)); - hlo = call; - } else { - CHECK(autotune_result.has_algorithm()); - backend_config.set_kind(std::string(kCuDnnFusionKind)); - backend_config.mutable_cudnn_fusion_config()->set_plan_id( - autotune_result.algorithm().algo_id()); - TF_RETURN_IF_ERROR(hlo->set_backend_config(gpu_config)); - } - } - - if (backend_config.has_triton_gemm_config()) { - TF_ASSIGN_OR_RETURN( - const TritonGemmConfig config, - TritonGemmConfig::FromProto(backend_config.triton_gemm_config())); - if (config.split_k > 1) { - TF_RETURN_IF_ERROR(MakeDotSplitKBatch(hlo, config)); - } - } - - MarkAsChanged(); - return absl::OkStatus(); - } - - private: - AutotuneConfig config_; -}; - class GemmConfigSetCollector : public ConstDfsHloVisitorWithDefault { public: explicit GemmConfigSetCollector(GemmFusionAutotunerImpl* impl) @@ -259,7 +190,9 @@ class GemmConfigSetCollector : public ConstDfsHloVisitorWithDefault { bool missing_config = (backend_config.kind() == kTritonGemmFusionKind && !backend_config.has_triton_gemm_config()) || (backend_config.kind() == kCuDnnFusionKind && - !backend_config.has_cudnn_fusion_config()); + !backend_config.has_cudnn_fusion_config()) || + (backend_config.kind() == kCustomFusionKind && + !backend_config.has_custom_fusion_config()); if (missing_config) { if (error_out_on_cache_miss_) { return absl::NotFoundError(absl::StrCat( @@ -427,6 +360,46 @@ absl::StatusOr> CublasGemmAutotuneExtractor( return new_module; } +absl::Status UpdateFusionInstructionKernelIndex( + HloInstruction* fusion_instruction, int kernel_index) { + GpuBackendConfig gpu_config = + fusion_instruction->backend_config().value(); + gpu_config.mutable_fusion_backend_config() + ->mutable_custom_fusion_config() + ->set_kernel_index(kernel_index); + TF_RETURN_IF_ERROR(fusion_instruction->set_backend_config(gpu_config)); + + return absl::OkStatus(); +} + +absl::StatusOr> CustomFusionKernelAutotuneExtractor( + const GemmFusionAutotunerImpl::CustomKernelFusionConfig& cutlass_config, + const AutotuneConfig& config, const se::SemanticVersion& toolkit_version, + const HloFusionInstruction* fusion, const DebugOptions& debug_opts) { + const HloComputation* fusion_computation = fusion->called_computation(); + std::unique_ptr new_module = + ExtractComputationIntoNewModule(*fusion_computation); + new_module->mutable_config().set_debug_options(debug_opts); + + CustomKernelFusionRewriter rewriter( + &config.GetExecutor()->GetDeviceDescription()); + PriorityFusion fusion_pass( + /*thread_pool=*/nullptr, config.GetExecutor()->GetDeviceDescription(), + PriorityFusionOptions()); + TF_RETURN_IF_ERROR(rewriter.Run(new_module.get()).status()); + TF_RETURN_IF_ERROR(fusion_pass.Run(new_module.get()).status()); + + // Select custom kernel fusion kernel. + HloInstruction* custom_kernel_fusion = + hlo_query::GetFirstInstructionWithOpcode(*new_module->entry_computation(), + HloOpcode::kFusion); + int64_t kernel_index = cutlass_config.kernel_index; + TF_RETURN_IF_ERROR( + UpdateFusionInstructionKernelIndex(custom_kernel_fusion, kernel_index)); + + return new_module; +} + absl::StatusOr> FusionExtractor( const HloFusionInstruction& fusion, const DebugOptions& debug_opts) { std::unique_ptr module = ExtractInstructionIntoNewModule(fusion); @@ -475,6 +448,11 @@ AutotuneResult FromConfig(const BackendConfig& config) { AutotuneResult res; if (std::holds_alternative(config)) { res.mutable_gemm()->set_algorithm(CUBLAS_GEMM_DEFAULT); + } else if (std::holds_alternative< + GemmFusionAutotunerImpl::CustomKernelFusionConfig>(config)) { + res.mutable_custom_kernel_fusion()->set_kernel_index( + std::get(config) + .kernel_index); } else if (std::holds_alternative( config)) { res.mutable_algorithm()->set_algo_id( @@ -574,6 +552,98 @@ std::string Serialize(const BackendConfig& config) { } // anonymous namespace +absl::Status RewriteGemmFusionToCall(HloInstruction* fusion_instr) { + // Falling back to cuBLAS: Converting the fusion to a Call, so that it + // can be inlined back again. + HloComputation* const computation = fusion_instr->parent(); + HloInstruction* const call = + computation->AddInstruction(HloInstruction::CreateCall( + fusion_instr->shape(), fusion_instr->operands(), + fusion_instr->fused_instructions_computation())); + return computation->ReplaceInstruction(fusion_instr, call); +} + +absl::Status RewriteGemmFusionToCustomKernelFusion( + HloInstruction* fusion_instr, se::DeviceDescription device_description, + int64_t kernel_index) { + // Rewrites gemm fusion to custom kernel fusion. + // First convert the fusion to a call. Then inlines the call. Then + // rewrites to custom kernel fusion. + HloComputation* const computation = fusion_instr->parent(); + HloInstruction* const call = + computation->AddInstruction(HloInstruction::CreateCall( + fusion_instr->shape(), fusion_instr->operands(), + fusion_instr->fused_instructions_computation())); + TF_RETURN_IF_ERROR(computation->ReplaceInstruction(fusion_instr, call)); + HloPassPipeline pipeline("autotuner_custom_kernel_fusion_rewriter"); + pipeline.AddPass(); + pipeline.AddPass(&device_description, + kernel_index); + HloModule* hlo_module = call->GetModule(); + return pipeline.Run(hlo_module).status(); +} + +absl::Status GemmFusionAutotunerRewriterVisitor::HandleFusion( + HloInstruction* fusion_instr) { + TF_ASSIGN_OR_RETURN(auto gpu_config, + fusion_instr->backend_config()); + FusionBackendConfig& backend_config = + *gpu_config.mutable_fusion_backend_config(); + if (backend_config.kind() != kTritonGemmFusionKind && + backend_config.kind() != kCuDnnFusionKind && + backend_config.kind() != kCustomFusionKind) { + return absl::OkStatus(); + } + + VLOG(4) << "Processing " << fusion_instr->ToString(); + if (!backend_config.has_triton_gemm_config() && + !backend_config.has_cudnn_fusion_config() && + !backend_config.has_custom_fusion_config()) { + TF_ASSIGN_OR_RETURN( + AutotuneResult autotune_result, + AutotunerUtil::Autotune( + fusion_instr, config_, [&]() -> absl::StatusOr { + if (config_.IsDeviceless()) { + return absl::InternalError(absl::StrCat( + "Expect autotune result cache hit for deviceless " + "compilation (HLO: ", + fusion_instr->ToString(), ")")); + } + return absl::InternalError("Expect autotune result cache hit."); + })); + VLOG(4) << "Result: " << autotune_result.ShortDebugString(); + + if (autotune_result.has_triton()) { + *backend_config.mutable_triton_gemm_config() = autotune_result.triton(); + TF_RETURN_IF_ERROR(fusion_instr->set_backend_config(gpu_config)); + } else if (autotune_result.has_gemm()) { + TF_RETURN_IF_ERROR(RewriteGemmFusionToCall(fusion_instr)); + } else if (autotune_result.has_custom_kernel_fusion()) { + TF_RETURN_IF_ERROR(RewriteGemmFusionToCustomKernelFusion( + fusion_instr, config_.GetExecutor()->GetDeviceDescription(), + autotune_result.custom_kernel_fusion().kernel_index())); + } else { + CHECK(autotune_result.has_algorithm()); + backend_config.set_kind(std::string(kCuDnnFusionKind)); + backend_config.mutable_cudnn_fusion_config()->set_plan_id( + autotune_result.algorithm().algo_id()); + TF_RETURN_IF_ERROR(fusion_instr->set_backend_config(gpu_config)); + } + } + + if (backend_config.has_triton_gemm_config()) { + TF_ASSIGN_OR_RETURN( + const TritonGemmConfig config, + TritonGemmConfig::FromProto(backend_config.triton_gemm_config())); + if (config.split_k > 1) { + TF_RETURN_IF_ERROR(MakeDotSplitKBatch(fusion_instr, config)); + } + } + + MarkAsChanged(); + return absl::OkStatus(); +} + // Methods required for sorting the configs. bool GemmFusionAutotunerImpl::CuBlasConfig::operator<( const CuBlasConfig& other) const { @@ -583,6 +653,10 @@ bool GemmFusionAutotunerImpl::CuDnnConfig::operator<( const CuDnnConfig& other) const { return plan_id < other.plan_id; } +bool GemmFusionAutotunerImpl::CustomKernelFusionConfig::operator<( + const CustomKernelFusionConfig& other) const { + return false; +} bool GemmFusionAutotunerImpl::IsAutotuningEnabled() const { return debug_options_.xla_gpu_autotune_level() > 0 && @@ -603,6 +677,48 @@ bool GemmFusionAutotunerImpl::IsAutotuningEnabled() const { } } +std::vector GenerateCustomKernelFusionConfigs( + const HloFusionInstruction& fusion, + se::DeviceDescription device_description) { + std::vector configs; + const CustomKernelFusionPatternRegistry* patterns = + CustomKernelFusionPatternRegistry::Default(); + HloComputation* computation = fusion.called_computation(); + // Get the first dot instruction in the fusion body. + HloInstruction* dot_instruction = + hlo_query::GetFirstInstructionWithOpcode(*computation, HloOpcode::kDot); + std::vector match = + patterns->Match(device_description, dot_instruction); + + // For Cutlass we expect only one match for a gemm fusion. + if (match.size() == 1) { + CustomKernelFusionRegistry* registry = + CustomKernelFusionRegistry::Default(); + auto* custom_kernel_fusion = registry->Lookup(match[0].config().name()); + + // If custom fusion is not found it means that some of the build targets + // might not be statically linked into the binary. + if (custom_kernel_fusion != nullptr) { + // Load custom kernels that can implement a fusion computation. + absl::StatusOr> kernels = + custom_kernel_fusion->LoadKernels( + device_description, fusion.fused_instructions_computation()); + if (!kernels.ok()) { + VLOG(2) << "Skip custom kernel config. Failed to load custom kernels: " + << kernels.status(); + } else { + for (int i = 0; i < kernels.value().size(); ++i) { + GemmFusionAutotunerImpl::CustomKernelFusionConfig config{ + /*kernel_index=*/i}; + configs.push_back(config); + } + } + } + } + + return configs; +} + absl::StatusOr> GemmFusionAutotunerImpl::GenerateConfigs(const HloFusionInstruction& fusion) { const HloDotInstruction* dot = @@ -642,6 +758,19 @@ GemmFusionAutotunerImpl::GenerateConfigs(const HloFusionInstruction& fusion) { } } + // Add CustomKernelFusion (Cutlass) configs, if available. + // Go through all the instructions in the fusion body try to match them to + // a custom kernel fusion pattern. + if ((IsFusionKind(fusion, kCustomFusionKind) || + IsFusionKind(fusion, kTritonGemmFusionKind)) && + IsAutotuningEnabled() && !config_.IsDeviceless()) { + std::vector custom_kernel_fusion_configs = + GenerateCustomKernelFusionConfigs( + fusion, config_.GetExecutor()->GetDeviceDescription()); + configs.insert(configs.end(), custom_kernel_fusion_configs.begin(), + custom_kernel_fusion_configs.end()); + } + // Add triton configs. TF_ASSIGN_OR_RETURN(std::vector triton_configs, GenerateTritonConfigs(*dot)); @@ -805,6 +934,14 @@ GemmFusionAutotunerImpl::CompileAll(AutotunerCompileUtil& compile_util, config_, config_.GetExecutor()->GetDeviceDescription(), toolkit_version_, fusion, opts); })); + } else if (std::holds_alternative(config)) { + TF_ASSIGN_OR_RETURN(executable, + compile_util.Compile([&](const DebugOptions& opts) { + return CustomFusionKernelAutotuneExtractor( + std::get(config), + config_, toolkit_version_, fusion, opts); + })); + } else { LOG(FATAL) << "Unsupported config type: " << config.index(); } @@ -1305,8 +1442,8 @@ absl::StatusOr GemmFusionAutotuner::Run( } } - return GemmFusionAutotunerVisitor(config_).RunOnModule(module, - execution_threads); + return GemmFusionAutotunerRewriterVisitor(config_).RunOnModule( + module, execution_threads); } } // namespace gpu diff --git a/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.h b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.h index 7c262ffc8c613b..17272607532c20 100644 --- a/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.h +++ b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.h @@ -29,7 +29,9 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/autotuning.pb.h" +#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" #include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/pass/hlo_pass_interface.h" @@ -46,6 +48,18 @@ limitations under the License. namespace xla { namespace gpu { +// Uses profile results to rewrite a gemm fusion to use the best backend. +class GemmFusionAutotunerRewriterVisitor : public DfsHloRewriteVisitor { + public: + explicit GemmFusionAutotunerRewriterVisitor(const AutotuneConfig& config) + : config_(config) {} + + absl::Status HandleFusion(HloInstruction* fusion_instr) override; + + private: + AutotuneConfig config_; +}; + // Takes a gemm fusion and chooses between cuBLAS, cuDNN, and Triton backends. // In the case of Triton, it also chooses the best tiling configuration. // @@ -99,8 +113,13 @@ class GemmFusionAutotunerImpl { int64_t plan_id; bool operator<(const CuDnnConfig& other) const; }; + struct CustomKernelFusionConfig { + int64_t kernel_index; + bool operator<(const CustomKernelFusionConfig& other) const; + }; using BackendConfig = - std::variant; + std::variant; using BackendConfigs = std::vector< std::pair>>; diff --git a/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner_test.cc b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner_test.cc index f47003ecea4256..cb6309ffe9b8ea 100644 --- a/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner_test.cc +++ b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner_test.cc @@ -50,7 +50,9 @@ limitations under the License. #include "xla/service/pattern_matcher_gmock.h" #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/device_description.pb.h" +#include "xla/stream_executor/gpu/gpu_executor.h" #include "xla/stream_executor/semantic_version.h" +#include "xla/stream_executor/stream_executor.h" #include "xla/tests/filecheck.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/test_utils.h" @@ -195,6 +197,25 @@ class GemmFusionAutotunerTest : public StatelessAutotunerTest { .cuda_compute_capability(); } + absl::StatusOr> + GetPossibleMatmulAutotuneConfigs( + const HloFusionInstruction& fusion, + const se::CudaComputeCapability& compute_capability, + const se::SemanticVersion& toolkit_version, + const DebugOptions& debug_options) { + se::GpuDeviceInfoProto deviceless_proto; + auto ccc = deviceless_proto.mutable_cuda_compute_capability(); + ccc->set_major(compute_capability.major); + ccc->set_minor(compute_capability.minor); + + DeviceConfig test_config{backend().default_stream_executor(), + backend().memory_allocator()}; + AutotuneConfig autotune_config{test_config, debug_options}; + GemmFusionAutotunerImpl autotuner(autotune_config, toolkit_version, + debug_options, nullptr); + return autotuner.GenerateConfigs(fusion); + } + void CheckTritonAutotuning(absl::string_view hlo, absl::string_view expected) { HloPassPipeline pipeline("gemm_rewrite"); @@ -247,7 +268,8 @@ class GemmFusionAutotunerTestWithMorePreciseReduction } }; -absl::StatusOr> GetPossibleMatmulAutotuneConfigs( +absl::StatusOr> +GetPossibleMatmulAutotuneTritonConfigs( const HloDotInstruction& dot, const se::CudaComputeCapability& compute_capability, const se::SemanticVersion& toolkit_version, @@ -276,7 +298,7 @@ ENTRY e { se::CudaComputeCapability::AMPERE, /*minor=*/0}; TF_ASSERT_OK_AND_ASSIGN( const std::vector configs, - GetPossibleMatmulAutotuneConfigs( + GetPossibleMatmulAutotuneTritonConfigs( *Cast( module->entry_computation()->root_instruction()), compute_capability, GetToolkitVersion(), GetDebugOptionsForTest())); @@ -298,7 +320,7 @@ ENTRY e { se::CudaComputeCapability::AMPERE, /*minor=*/0}; TF_ASSERT_OK_AND_ASSIGN( const std::vector configs, - GetPossibleMatmulAutotuneConfigs( + GetPossibleMatmulAutotuneTritonConfigs( *Cast( module->entry_computation()->root_instruction()), compute_capability, GetToolkitVersion(), GetDebugOptionsForTest())); @@ -320,7 +342,7 @@ ENTRY e { se::CudaComputeCapability::AMPERE, /*minor=*/0}; TF_ASSERT_OK_AND_ASSIGN( const std::vector configs, - GetPossibleMatmulAutotuneConfigs( + GetPossibleMatmulAutotuneTritonConfigs( *Cast( module->entry_computation()->root_instruction()), compute_capability, GetToolkitVersion(), GetDebugOptionsForTest())); @@ -875,7 +897,7 @@ ENTRY e { se::CudaComputeCapability::AMPERE, /*minor=*/0}; TF_ASSERT_OK_AND_ASSIGN( const std::vector configs, - GetPossibleMatmulAutotuneConfigs( + GetPossibleMatmulAutotuneTritonConfigs( *Cast( module->entry_computation()->root_instruction()), compute_capability, GetToolkitVersion(), GetDebugOptionsForTest())); @@ -907,7 +929,7 @@ ENTRY e { se::CudaComputeCapability::AMPERE, /*minor=*/0}; TF_ASSERT_OK_AND_ASSIGN( const std::vector configs, - GetPossibleMatmulAutotuneConfigs( + GetPossibleMatmulAutotuneTritonConfigs( *Cast( module->entry_computation()->root_instruction()), compute_capability, GetToolkitVersion(), GetDebugOptionsForTest())); @@ -938,7 +960,7 @@ ENTRY wais { TF_ASSERT_OK_AND_ASSIGN( const std::vector configs, - GetPossibleMatmulAutotuneConfigs( + GetPossibleMatmulAutotuneTritonConfigs( *Cast( module->entry_computation()->root_instruction()), compute_capability, GetToolkitVersion(), debug_options)); @@ -1002,6 +1024,151 @@ ENTRY entry { CHECK_OK(autotuner.CompileAll(*compile_util, configs)); } +TEST_F(GemmFusionAutotunerTest, CreatesCustomKernelFusionConfigs) { + const std::string kHlo = R"( + HloModule module, entry_computation_layout={(bf16[1024,1024]{1,0}, bf16[1024,1024]{1,0})->f32[1024,1024]{1,0}} + + %gemm_fusion_r_computation { + %parameter_0 = bf16[1024,1024]{1,0} parameter(0) + %convert.2 = f32[1024,1024]{1,0} convert(%parameter_0) + %parameter_1 = bf16[1024,1024]{1,0} parameter(1) + %convert.3 = f32[1024,1024]{1,0} convert(%parameter_1) + ROOT %r.1 = f32[1024,1024]{1,0} dot(%convert.2, %convert.3), lhs_contracting_dims={1}, rhs_contracting_dims={0} + } + + ENTRY main { + %p0 = bf16[1024,1024]{1,0} parameter(0) + %p1 = bf16[1024,1024]{1,0} parameter(1) + ROOT %gemm_fusion_r = f32[1024,1024]{1,0} fusion(%p0, %p1), kind=kCustom, calls=gemm_fusion_r_computation, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"fusion_backend_config":{"kind":"__triton_gemm"},"force_earliest_schedule":false} + })"; + + std::unique_ptr module = + ParseAndReturnVerifiedModule(kHlo).value(); + const se::CudaComputeCapability compute_capability{ + se::CudaComputeCapability::AMPERE, /*minor=*/0}; + + TF_ASSERT_OK_AND_ASSIGN( + const std::vector configs, + GetPossibleMatmulAutotuneConfigs( + *Cast( + module->entry_computation()->root_instruction()), + compute_capability, GetToolkitVersion(), GetDebugOptionsForTest())); + EXPECT_TRUE(std::any_of( + configs.begin(), configs.end(), + [](const GemmFusionAutotunerImpl::BackendConfig& config) { + return std::holds_alternative< + GemmFusionAutotunerImpl::CustomKernelFusionConfig>(config); + })); +} + +TEST_F(GemmFusionAutotunerTest, + IgnoreCustomKernelFusionConfigIfKernelNotFound) { + // There are cases where the custom kernel fusion pattern is matched, but + // the kernel is not found. Make sure that the autotuner ignores this case. + const std::string kHlo = R"( + HloModule module + + %gemm_fusion_r_computation (parameter_0.1: f32[1,256,4,4096], parameter_1.1: bf16[1,4,4096,4096]) -> bf16[1048576] { + %parameter_0.1 = f32[1,256,4,4096]{3,2,1,0} parameter(0) + %bitcast.60 = f32[256,16384]{1,0} bitcast(f32[1,256,4,4096]{3,2,1,0} %parameter_0.1) + %parameter_1.1 = bf16[1,4,4096,4096]{3,2,1,0} parameter(1) + %bitcast.61 = bf16[16384,4096]{1,0} bitcast(bf16[1,4,4096,4096]{3,2,1,0} %parameter_1.1) + %convert.22 = f32[16384,4096]{1,0} convert(bf16[16384,4096]{1,0} %bitcast.61) + %dot.5 = f32[256,4096]{1,0} dot(f32[256,16384]{1,0} %bitcast.60, f32[16384,4096]{1,0} %convert.22), lhs_contracting_dims={1}, rhs_contracting_dims={0}, metadata={op_name="jit(_einsum)/jit(main)/dot_general" source_file="gdm/jax/tokamax/xla/utils.py" source_line=33} + %convert.23 = bf16[256,4096]{1,0} convert(f32[256,4096]{1,0} %dot.5), metadata={op_name="jit(_einsum)/jit(main)/convert_element_type" source_file="gdm/jax/tokamax/xla/utils.py" source_line=33} + %bitcast.62 = bf16[1,256,4096]{2,1,0} bitcast(bf16[256,4096]{1,0} %convert.23) + %transpose.18 = bf16[1,4096,256]{2,1,0} transpose(bf16[1,256,4096]{2,1,0} %bitcast.62), dimensions={0,2,1}, metadata={op_name="jit(_einsum)/jit(main)/convert_element_type" source_file="gdm/jax/tokamax/xla/utils.py" source_line=33} + ROOT %bitcast.63 = bf16[1048576]{0} bitcast(bf16[1,4096,256]{2,1,0} %transpose.18) + } + + ENTRY main { + %p0 = f32[1,256,4,4096] parameter(0) + %p1 = bf16[1,4,4096,4096] parameter(1) + ROOT %gemm_fusion_r = bf16[1048576] fusion(%p0, %p1), kind=kCustom, + calls=gemm_fusion_r_computation, + backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"fusion_backend_config":{"kind":"__triton_gemm"},"force_earliest_schedule":false} + } +)"; + + std::unique_ptr module = + ParseAndReturnVerifiedModule(kHlo).value(); + const se::CudaComputeCapability compute_capability{ + se::CudaComputeCapability::AMPERE, /*minor=*/0}; + + TF_ASSERT_OK_AND_ASSIGN( + const std::vector configs, + GetPossibleMatmulAutotuneConfigs( + *Cast( + module->entry_computation()->root_instruction()), + compute_capability, GetToolkitVersion(), GetDebugOptionsForTest())); + EXPECT_TRUE(std::none_of( + configs.begin(), configs.end(), + [](const GemmFusionAutotunerImpl::BackendConfig& config) { + return std::holds_alternative< + GemmFusionAutotunerImpl::CustomKernelFusionConfig>(config); + })); +} + +TEST_F(GemmFusionAutotunerTest, RewritesTritonFusionToCustomKernelFusion) { + const std::string kHlo = R"( + HloModule module, entry_computation_layout={(bf16[1024,1024]{1,0}, bf16[1024,1024]{1,0})->f32[1024,1024]{1,0}} + + %gemm_fusion_r_computation { + %parameter_0 = bf16[1024,1024]{1,0} parameter(0) + %convert.2 = f32[1024,1024]{1,0} convert(%parameter_0) + %parameter_1 = bf16[1024,1024]{1,0} parameter(1) + %convert.3 = f32[1024,1024]{1,0} convert(%parameter_1) + ROOT %r.1 = f32[1024,1024]{1,0} dot(%convert.2, %convert.3), lhs_contracting_dims={1}, rhs_contracting_dims={0} + } + + ENTRY main { + %p0 = bf16[1024,1024]{1,0} parameter(0) + %p1 = bf16[1024,1024]{1,0} parameter(1) + ROOT %gemm_fusion_r = f32[1024,1024]{1,0} fusion(%p0, %p1), kind=kCustom, calls=gemm_fusion_r_computation, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"fusion_backend_config":{"kind":"__triton_gemm"},"force_earliest_schedule":false} + } +)"; + + std::unique_ptr module = + ParseAndReturnVerifiedModule(kHlo).value(); + + DebugOptions opts; + AutotuneConfig autotune_config{ + DeviceConfig{backend().default_stream_executor(), + backend().memory_allocator()}, + opts}; + AutotuneCacheKey cache_key(autotune_config.GetModelStr(), + *module->entry_computation()->root_instruction()); + TF_ASSERT_OK_AND_ASSIGN(AutotuneResults autotune_results_override, + ParseTextProto(R"pb( + version: 3 + results { + device: "..." + hlo: "..." + result { + custom_kernel_fusion { kernel_index: 1 } + run_time { nanos: 14 } + } + })pb")); + autotune_results_override.mutable_results(0)->set_device( + std::string(cache_key.GetModelStr())); + autotune_results_override.mutable_results(0)->set_hlo( + std::string(cache_key.GetHlo())); + + GemmFusionAutotunerRewriterVisitor visitor(autotune_config); + + CHECK_OK(AutotunerUtil::LoadAutotuneResults(autotune_results_override)); + visitor.RunOnModule(module.get(), {}).value(); + std::string pattern = R"( + CHECK: ROOT %cutlass_gemm_with_upcast + CHECK-SAME: fusion + CHECK-SAME: kind=kCustom + CHECK-SAME: "kernel_index":1 + )"; + TF_ASSERT_OK_AND_ASSIGN(bool file_check_matches, + RunFileCheck(module->ToString(), pattern)); + EXPECT_TRUE(file_check_matches); +} + } // namespace } // namespace gpu } // namespace xla From cac0ae5790ac5b87bb313bfa57d2fcea9c752372 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 19 Sep 2024 10:04:41 -0700 Subject: [PATCH 011/483] Replace any ;s that the unique name the Tf2Xla rewriter creates for new nodes. PiperOrigin-RevId: 676459644 --- .../mlir/tf2xla/transforms/tf2xla_rewriter.cc | 21 ++++++++++++++++++- .../mlir/tf2xla/transforms/tf2xla_rewriter.h | 2 +- .../tf2xla/transforms/tf2xla_rewriter_test.cc | 16 ++++++++++++++ 3 files changed, 37 insertions(+), 2 deletions(-) diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter.cc b/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter.cc index 2709f9dada21a7..228df862f82d3c 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter.cc @@ -23,6 +23,7 @@ limitations under the License. #include "absl/container/inlined_vector.h" #include "absl/memory/memory.h" +#include "absl/strings/str_replace.h" #include "absl/strings/string_view.h" #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/STLExtras.h" @@ -94,6 +95,22 @@ using ::tensorflow::Tensor; using ::tsl::StatusOr; using ::xla::XlaComputation; +// The OpOrArgLocNameMapper adds invalid characters to the name of the op when +// concatenating locations. This version removes those characters to make the +// name valid for NodeDef. +class OpOrArgLocNameMapperWithoutInvalidCharacters + : public tensorflow::OpOrArgLocNameMapper { + public: + OpOrArgLocNameMapperWithoutInvalidCharacters() = default; + ~OpOrArgLocNameMapperWithoutInvalidCharacters() override = default; + + protected: + std::string GetName(tensorflow::OpOrVal op_or_val) override { + std::string name = OpOrArgLocNameMapper::GetName(op_or_val); + return absl::StrReplaceAll(name, {{";", "."}}); + } +}; + static std::unique_ptr CreateDeviceMgr( const std::string& device_type) { // Register compilation kernels for all registered XLA backends. @@ -125,6 +142,8 @@ Tf2XlaRewriter::Tf2XlaRewriter(Operation* op, PatternRewriter& rewriter, : op_(op), device_type_(device_type), rewriter_(rewriter), + name_mapper_( + std::make_unique()), context_(nullptr), xla_builder_(op_->getName().getStringRef().str()) {} @@ -319,7 +338,7 @@ LogicalResult Tf2XlaRewriter::LegalizeOp() { } auto nodedef_or = tensorflow::ConvertTFDialectOpToNodeDef( - op_, name_mapper_.GetUniqueName(op_), + op_, name_mapper_->GetUniqueName(op_), /*ignore_unregistered_attrs=*/true); if (!nodedef_or.ok()) { return op_->emitRemark() << "failed to convert op to NodeDef: " diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter.h b/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter.h index 2b8c52750a6c44..7fcf0bafafb0a3 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter.h +++ b/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter.h @@ -106,7 +106,7 @@ class Tf2XlaRewriter { std::string device_type_; mlir::PatternRewriter& rewriter_; - tensorflow::OpOrArgLocNameMapper name_mapper_; + std::unique_ptr name_mapper_; tensorflow::XlaContext* context_; // Ref-counted. diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter_test.cc b/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter_test.cc index 8d0f0404b8980f..46da448fe8f301 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter_test.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter_test.cc @@ -317,6 +317,22 @@ TEST_F(Tf2XlaRewriterTest, CreatesDefaultValues) { TF_ASSERT_OK(LegalizeModule(kModuleWithOpWithoutValuesThatShouldBeDefaulted)); } +TEST_F(Tf2XlaRewriterTest, OpWithLocationDoesntBreakNodeDefName) { + // A named location 'Name(Source)' causes the GetNameFromLoc method to append + // all the other locations to the name with a ';' separator. This test ensures + // that the name used for the NodeDef does not contain that invalid character. + static constexpr char kModuleWithOpWithoutValuesThatShouldBeDefaulted[] = + R"mlir( + module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 1610 : i32}} { + func.func @main(%arg0: tensor<2xf32>) -> tensor<2xf32> { + %0 = "tf.Exp"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> loc(fused["exp"("exp"), "exp"]) + func.return %0 : tensor<2xf32> + } + })mlir"; + + TF_ASSERT_OK(LegalizeModule(kModuleWithOpWithoutValuesThatShouldBeDefaulted)); +} + TEST_F(Tf2XlaRewriterTest, ErrorsWithInvalidNumberOfParametersToArgs) { XlaBuilder builder("test_builder"); XlaComputation to_apply; From 1beccb733259cf9430152559e534c5be8f444ef7 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 19 Sep 2024 10:08:30 -0700 Subject: [PATCH 012/483] Reduce the set of strategies generated for scatter ops. PiperOrigin-RevId: 676461609 --- .../auto_sharding/auto_sharding_strategy.cc | 39 ++++++++++--------- 1 file changed, 21 insertions(+), 18 deletions(-) diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc index 91d40860c73255..c20cd9af5b3680 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc @@ -139,32 +139,37 @@ ComputeSliceShardingAndCommunicationCostFromOperand( // the original implementation), but it should be easy to generalize if needed. void GenerateScatterShardingFromOperands( const HloScatterInstruction* scatter, const HloSharding& data_sharding, - const HloSharding& indices_sharding, const HloSharding& update_sharding, - const CallGraph& call_graph, + const HloSharding& update_sharding, const CallGraph& call_graph, absl::FunctionRef yield_sharding) { + absl::flat_hash_set scatter_shardings; CHECK_EQ(scatter->scatter_operand_count(), 1); const HloInstruction* scatter_data = scatter->scatter_operands()[0]; const HloInstruction* scatter_indices = scatter->scatter_indices(); const HloInstruction* scatter_update = scatter->scatter_updates()[0]; - yield_sharding(data_sharding, indices_sharding, update_sharding, - data_sharding); + const HloSharding& indices_sharding = hlo_sharding_util:: + ScatterIndexShardingFromUpdateIndexPassthroughDimensions(update_sharding, + scatter); + scatter_shardings.insert(data_sharding); if (std::optional maybe_from_update = hlo_sharding_util::ScatterOutputShardingFromUpdate(update_sharding, *scatter)) { - yield_sharding(data_sharding, indices_sharding, update_sharding, - *maybe_from_update); + scatter_shardings.insert(*maybe_from_update); } std::optional scatter_parallel_dims = hlo_sharding_util::GetScatterParallelBatchDims(*scatter, call_graph); if (!scatter_parallel_dims) { + for (const HloSharding& sharding : scatter_shardings) { + yield_sharding(data_sharding, indices_sharding, update_sharding, + sharding); + } return; } @@ -178,29 +183,30 @@ void GenerateScatterShardingFromOperands( aligned_operand_parallel_dims; // Infer output sharding from scatter operand sharding. const Shape& shape = scatter->shape(); - yield_sharding( - data_sharding, indices_sharding, update_sharding, + scatter_shardings.insert( hlo_sharding_util::InferGatherScatterParallelShardingFromOperandSharding( data_sharding, scatter_data->shape(), shape, absl::MakeConstSpan(aligned_operand_parallel_dims), absl::MakeConstSpan(output_parallel_dims))); // Infer output sharding from scatter indices sharding. - HloSharding parallel_sharding_from_indices = + scatter_shardings.insert( hlo_sharding_util::InferGatherScatterParallelShardingFromOperandSharding( indices_sharding, scatter_indices->shape(), shape, absl::MakeConstSpan(scatter_parallel_dims->indices_parallel_dims), - absl::MakeConstSpan(output_parallel_dims)); - yield_sharding(data_sharding, indices_sharding, update_sharding, - parallel_sharding_from_indices); + absl::MakeConstSpan(output_parallel_dims))); // Infer output sharding from scatter update sharding. - yield_sharding( - data_sharding, indices_sharding, update_sharding, + scatter_shardings.insert( hlo_sharding_util::InferGatherScatterParallelShardingFromOperandSharding( update_sharding, scatter_update->shape(), shape, absl::MakeConstSpan(update_parallel_dims), absl::MakeConstSpan(output_parallel_dims))); + + for (const HloSharding& scatter_sharding : scatter_shardings) { + yield_sharding(data_sharding, indices_sharding, update_sharding, + scatter_sharding); + } } // NOLINTBEGIN(readability/fn_size) @@ -359,18 +365,15 @@ BuildStrategyAndCost( const HloScatterInstruction* scatter = Cast(ins); const HloInstruction* scatter_data = scatter->scatter_operands()[0]; - const HloInstruction* scatter_indices = scatter->scatter_indices(); const HloInstruction* scatter_update = scatter->scatter_updates()[0]; ForEachInCartesianProduct( {strategy_map.at(scatter_data)->GetStrategies(), - strategy_map.at(scatter_indices)->GetStrategies(), strategy_map.at(scatter_update)->GetStrategies()}, [&](const std::vector& operand_shardings) { GenerateScatterShardingFromOperands( scatter, operand_shardings[0].output_sharding, - operand_shardings[1].output_sharding, - operand_shardings[2].output_sharding, call_graph, + operand_shardings[1].output_sharding, call_graph, add_scatter_sharding); }); From d8f20aebe72c9307289ac8c0ccac74c21240e58f Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Thu, 19 Sep 2024 10:16:34 -0700 Subject: [PATCH 013/483] [JAX] Switch host_callback to use MLIR lowering instead of the older direct HLO translation rules. Change in preparation for removing XlaBuilder from Python bindings. PiperOrigin-RevId: 676465019 --- third_party/xla/xla/python/BUILD | 2 + .../xla/xla/python/outfeed_receiver.cc | 48 ++++++++++++------- third_party/xla/xla/python/outfeed_receiver.h | 2 + .../xla/xla/python/outfeed_receiver_py.cc | 24 ++++++++++ third_party/xla/xla/python/xla_client.py | 2 +- .../python/xla_extension/outfeed_receiver.pyi | 7 +++ 6 files changed, 68 insertions(+), 17 deletions(-) diff --git a/third_party/xla/xla/python/BUILD b/third_party/xla/xla/python/BUILD index 011900dbe17b98..6f60b9242747b4 100644 --- a/third_party/xla/xla/python/BUILD +++ b/third_party/xla/xla/python/BUILD @@ -707,6 +707,7 @@ cc_library( "//xla/pjrt:pjrt_executable", "//xla/python/pjrt_ifrt", "//xla/service:computation_placer_hdr", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", @@ -864,6 +865,7 @@ cc_library( "@llvm-project//llvm:Support", "@nanobind", "//xla:literal", + "//xla:shape_util", "//xla/client:executable_build_options", "//xla/client:xla_builder", "//xla/pjrt:status_casters", diff --git a/third_party/xla/xla/python/outfeed_receiver.cc b/third_party/xla/xla/python/outfeed_receiver.cc index 539d1f2df6c308..63a5a03aff237d 100644 --- a/third_party/xla/xla/python/outfeed_receiver.cc +++ b/third_party/xla/xla/python/outfeed_receiver.cc @@ -196,6 +196,8 @@ class OutfeedReceiverImpl { std::vector arrays, uint32_t device_idx); + absl::Status RegisterOutfeed(uint32_t consumer_id, const Shape& shape); + private: bool CallbackQueueHasSpace() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) { return callback_queue_size_bytes_ < max_callback_queue_size_bytes_; @@ -465,34 +467,39 @@ absl::Status OutfeedReceiverImpl::SendShutdownOutfeedHeader(int device_idx) { return absl::OkStatus(); } -absl::StatusOr OutfeedReceiverImpl::AddOutfeedToBuilder( - XlaBuilder* builder, XlaOp token, uint32_t consumer_id, - std::vector arrays, uint32_t device_idx) { - XlaOp data = Tuple(builder, std::move(arrays)); - Shape shape_with_layout = builder->GetShape(data).value(); - ShapeUtil::ForEachMutableSubshape( - &shape_with_layout, [](Shape* subshape, const ShapeIndex&) { - if (!subshape->has_layout()) { - LayoutUtil::SetToDefaultLayout(subshape); - } - }); +absl::Status OutfeedReceiverImpl::RegisterOutfeed(uint32_t consumer_id, + const Shape& shape) { VLOG(2) << "RegisterShape cons=" << consumer_id - << "; shape=" << shape_with_layout.ToString(); + << "; shape=" << shape.ToString(); { absl::MutexLock lock(&mu_); auto found = shape_registry_.find(consumer_id); if (found != shape_registry_.end()) { - if (!ShapeUtil::Equal(shape_with_layout, found->second)) { + if (!ShapeUtil::Equal(shape, found->second)) { return InvalidArgument( "Shape %s does not match previous shape %s used " "for consumer id %d", - shape_with_layout.DebugString(), found->second.DebugString(), - consumer_id); + shape.DebugString(), found->second.DebugString(), consumer_id); } } else { - shape_registry_.insert({consumer_id, shape_with_layout}); + shape_registry_.insert({consumer_id, shape}); } } + return absl::OkStatus(); +} + +absl::StatusOr OutfeedReceiverImpl::AddOutfeedToBuilder( + XlaBuilder* builder, XlaOp token, uint32_t consumer_id, + std::vector arrays, uint32_t device_idx) { + XlaOp data = Tuple(builder, std::move(arrays)); + Shape shape_with_layout = builder->GetShape(data).value(); + ShapeUtil::ForEachMutableSubshape( + &shape_with_layout, [](Shape* subshape, const ShapeIndex&) { + if (!subshape->has_layout()) { + LayoutUtil::SetToDefaultLayout(subshape); + } + }); + TF_RETURN_IF_ERROR(RegisterOutfeed(consumer_id, shape_with_layout)); std::vector header{kOutfeedHeaderStart, consumer_id}; XlaOp header_op = ConstantR1(builder, header); @@ -532,4 +539,13 @@ absl::StatusOr OutfeedReceiver::AddOutfeedToBuilder( device_idx); } +absl::Status OutfeedReceiver::RegisterOutfeed(uint32_t consumer_id, + const Shape& shape) { + if (consumer_id == kOutfeedCidShutdown) { + return InvalidArgument("Consumer ID cannot be a reserved value: %d", + consumer_id); + } + return p_impl_->RegisterOutfeed(consumer_id, shape); +} + } // namespace xla diff --git a/third_party/xla/xla/python/outfeed_receiver.h b/third_party/xla/xla/python/outfeed_receiver.h index a9f47280c56ee6..8330dbb2d3de12 100644 --- a/third_party/xla/xla/python/outfeed_receiver.h +++ b/third_party/xla/xla/python/outfeed_receiver.h @@ -75,6 +75,8 @@ class OutfeedReceiver { std::vector arrays, uint32_t device_idx); + absl::Status RegisterOutfeed(uint32_t consumer_id, const Shape& shape); + private: std::unique_ptr p_impl_; }; diff --git a/third_party/xla/xla/python/outfeed_receiver_py.cc b/third_party/xla/xla/python/outfeed_receiver_py.cc index e5ce3f9e6bf9df..81a327bda1aaa8 100644 --- a/third_party/xla/xla/python/outfeed_receiver_py.cc +++ b/third_party/xla/xla/python/outfeed_receiver_py.cc @@ -35,6 +35,7 @@ limitations under the License. #include "nanobind/stl/vector.h" // IWYU pragma: keep #include "xla/client/executable_build_options.h" #include "xla/client/xla_builder.h" +#include "xla/layout_util.h" #include "xla/literal.h" #include "xla/pjrt/status_casters.h" #include "xla/python/ifrt/device.h" @@ -43,6 +44,8 @@ limitations under the License. #include "xla/python/pjrt_ifrt/pjrt_client.h" #include "xla/python/py_client.h" #include "xla/python/types.h" +#include "xla/shape.h" +#include "xla/shape_util.h" #include "tsl/platform/logging.h" namespace xla { @@ -126,6 +129,18 @@ class OutfeedReceiverForPython { arrays, device_idx); } + absl::Status RegisterOutfeed(uint32_t consumer_id, + const std::vector& shapes) { + Shape shape = ShapeUtil::MakeTupleShape(shapes); + ShapeUtil::ForEachMutableSubshape( + &shape, [](Shape* subshape, const ShapeIndex&) { + if (!subshape->has_layout()) { + LayoutUtil::SetToDefaultLayout(subshape); + } + }); + return outfeed_receiver_->RegisterOutfeed(consumer_id, shape); + } + void Callback(ifrt::Device* device, uint32_t consumer_id, std::shared_ptr literal) { { @@ -210,6 +225,15 @@ void BuildOutfeedReceiverSubmodule(nb::module_& m) { ID. Returns error if the outfeed shape is not compatible with previously used shape for the same consumer ID.)", nb::call_guard()); + + outfeed_receiver_class.def( + "register_outfeed", + xla::ThrowIfErrorWrapper(&OutfeedReceiverForPython::RegisterOutfeed), + nb::arg("consumer_id"), nb::arg("shapes"), + R"(Registers the sent shape along with the consumer + ID. Returns error if the outfeed shape is not compatible with previously + used shape for the same consumer ID.)", + nb::call_guard()); } } // namespace xla diff --git a/third_party/xla/xla/python/xla_client.py b/third_party/xla/xla/python/xla_client.py index 89332de94b0b82..bcb7d7a96fd249 100644 --- a/third_party/xla/xla/python/xla_client.py +++ b/third_party/xla/xla/python/xla_client.py @@ -50,7 +50,7 @@ # Just an internal arbitrary increasing number to help with backward-compatible # changes. In JAX, reference this via jax._src.lib.xla_extension_version. -_version = 286 +_version = 287 # Version number for MLIR:Python components. mlir_api_version = 57 diff --git a/third_party/xla/xla/python/xla_extension/outfeed_receiver.pyi b/third_party/xla/xla/python/xla_extension/outfeed_receiver.pyi index b0850355de65a3..ba41631cfbb121 100644 --- a/third_party/xla/xla/python/xla_extension/outfeed_receiver.pyi +++ b/third_party/xla/xla/python/xla_extension/outfeed_receiver.pyi @@ -19,6 +19,7 @@ from typing import Any, Optional, Sequence from xla.python import xla_extension Client = xla_extension.Client +Shape = xla_extension.Shape XlaBuilder = xla_extension.XlaBuilder XlaOp = xla_extension.XlaOp @@ -44,3 +45,9 @@ class OutfeedReceiverForPython: device_idx: int, ) -> XlaOp: ... + + def register_outfeed( + consumer_id: int, + shapes: Sequence[Shape], + ) -> None: + ... From 9b4fae433241e2ddb9ebbe7706a4e4b02d1f8d44 Mon Sep 17 00:00:00 2001 From: Vladimir Belitskiy Date: Thu, 19 Sep 2024 10:28:41 -0700 Subject: [PATCH 014/483] Disable `xnn_enable_avxvnniint8` for Android. This is only supported on the very latest compilers at the moment. PiperOrigin-RevId: 676469843 --- .bazelrc | 2 ++ third_party/xla/.bazelrc | 2 ++ third_party/xla/third_party/tsl/.bazelrc | 2 ++ 3 files changed, 6 insertions(+) diff --git a/.bazelrc b/.bazelrc index fcbc9ff2772db4..8201ce4582a00f 100644 --- a/.bazelrc +++ b/.bazelrc @@ -150,6 +150,8 @@ build:android_x86_64 --fat_apk_cpu=x86_64 # Build everything statically for Android since all static libs are later # bundled together into a single .so for deployment. build:android --dynamic_mode=off +# TODO(belitskiy): Remove once on Clang 20. +build:android --define=xnn_enable_avxvnniint8=false # Sets the default Apple platform to macOS. build:macos --apple_platform_type=macos diff --git a/third_party/xla/.bazelrc b/third_party/xla/.bazelrc index fcbc9ff2772db4..8201ce4582a00f 100644 --- a/third_party/xla/.bazelrc +++ b/third_party/xla/.bazelrc @@ -150,6 +150,8 @@ build:android_x86_64 --fat_apk_cpu=x86_64 # Build everything statically for Android since all static libs are later # bundled together into a single .so for deployment. build:android --dynamic_mode=off +# TODO(belitskiy): Remove once on Clang 20. +build:android --define=xnn_enable_avxvnniint8=false # Sets the default Apple platform to macOS. build:macos --apple_platform_type=macos diff --git a/third_party/xla/third_party/tsl/.bazelrc b/third_party/xla/third_party/tsl/.bazelrc index fcbc9ff2772db4..8201ce4582a00f 100644 --- a/third_party/xla/third_party/tsl/.bazelrc +++ b/third_party/xla/third_party/tsl/.bazelrc @@ -150,6 +150,8 @@ build:android_x86_64 --fat_apk_cpu=x86_64 # Build everything statically for Android since all static libs are later # bundled together into a single .so for deployment. build:android --dynamic_mode=off +# TODO(belitskiy): Remove once on Clang 20. +build:android --define=xnn_enable_avxvnniint8=false # Sets the default Apple platform to macOS. build:macos --apple_platform_type=macos From 0b43808f90708849d0b7d6f141f85a81598e9740 Mon Sep 17 00:00:00 2001 From: Penporn Koanantakool Date: Thu, 19 Sep 2024 10:44:32 -0700 Subject: [PATCH 015/483] [xla:cpu] Add support for 17 sort inputs. Fixes https://github.com/google/jax/issues/23727 This is a temporary fix. We will add a fallback sort kernel soon. PiperOrigin-RevId: 676475983 --- .../xla/backends/cpu/runtime/sort_thunk.cc | 3 ++ third_party/xla/xla/tests/BUILD | 2 + third_party/xla/xla/tests/sort_test.cc | 52 +++++++++++++++++++ 3 files changed, 57 insertions(+) diff --git a/third_party/xla/xla/backends/cpu/runtime/sort_thunk.cc b/third_party/xla/xla/backends/cpu/runtime/sort_thunk.cc index 8d2df6f298cbcf..fa892d2df54134 100644 --- a/third_party/xla/xla/backends/cpu/runtime/sort_thunk.cc +++ b/third_party/xla/xla/backends/cpu/runtime/sort_thunk.cc @@ -468,6 +468,9 @@ static absl::Status SortInplace(absl::Span data, case 16: sort(std::integral_constant{}); break; + case 17: + sort(std::integral_constant{}); + break; case 25: sort(std::integral_constant{}); break; diff --git a/third_party/xla/xla/tests/BUILD b/third_party/xla/xla/tests/BUILD index 16d3b36efcbe4f..16f1f8e8f07c40 100644 --- a/third_party/xla/xla/tests/BUILD +++ b/third_party/xla/xla/tests/BUILD @@ -1803,6 +1803,8 @@ xla_test( ":test_macros_header", ":xla_internal_test_main", "//xla:error_spec", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "@com_google_googletest//:gtest_main", ], ) diff --git a/third_party/xla/xla/tests/sort_test.cc b/third_party/xla/xla/tests/sort_test.cc index b832dbdd0df0d5..3acef7e48170b7 100644 --- a/third_party/xla/xla/tests/sort_test.cc +++ b/third_party/xla/xla/tests/sort_test.cc @@ -13,9 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include +#include #include +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/strings/str_replace.h" #include "xla/error_spec.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/test_macros.h" @@ -85,5 +91,51 @@ XLA_TEST_F(SortTest, SortTwiceWithSameComparator) { EXPECT_TRUE(RunAndCompare(hlo_text_module, ErrorSpec{0.0, 0.0})); } +// TODO(penporn): Parameterize `num_inputs` and test several numbers when we +// have a generic fallback sort kernel. +XLA_TEST_F(SortTest, SortManyInputs) { + constexpr int num_inputs = 17; + std::string_view hlo_text_module_template = R"( + HloModule sort + + compare { + ${COMPARE_DECLARATIONS} + ROOT lt = pred[] compare(p0, p1), direction=LT + } + + ENTRY e { + ${SORT_DECLARATIONS} + ROOT sort = (${SORT_SHAPE}) sort(${SORT_PARAMS}), dimensions={0}, + to_apply=compare + } + )"; + + // Prepare values for template substitutions. + std::string sort_decls = ""; + std::vector param_names; + param_names.reserve(num_inputs * 2); + for (int i = 0; i < num_inputs; ++i) { + sort_decls += absl::StrFormat("p%d = f32[32,64] parameter(%d)\n", i, i); + param_names.emplace_back(absl::StrCat("p", i)); + } + std::string sort_params = absl::StrJoin(param_names, ", "); + std::string sort_shape = + absl::StrJoin(std::vector(num_inputs, "f32[32,64]"), ","); + std::string compare_decls = ""; + for (int i = 0; i < num_inputs * 2; ++i) { + compare_decls += absl::StrFormat("p%d = f32[] parameter(%d)\n", i, i); + } + std::string compare_params = absl::StrJoin(param_names, ", "); + + // Finalize HLO text. + std::string hlo_text_module = absl::StrReplaceAll( + hlo_text_module_template, {{"${SORT_DECLARATIONS}", sort_decls}, + {"${SORT_SHAPE}", sort_shape}, + {"${SORT_PARAMS}", sort_params}, + {"${COMPARE_DECLARATIONS}", compare_decls}}); + + EXPECT_TRUE(RunAndCompare(hlo_text_module, ErrorSpec{0.0, 0.0})); +} + } // namespace } // namespace xla From e081db68a13da5cb86b6bb14e5856866b45321f5 Mon Sep 17 00:00:00 2001 From: Joshua Lang Date: Thu, 19 Sep 2024 10:45:13 -0700 Subject: [PATCH 016/483] [xla:gpu] Enable Conditions in cuda graphs by default PiperOrigin-RevId: 676476297 --- third_party/xla/xla/debug_options_flags.cc | 1 + .../xla/service/gpu/transforms/command_buffer_scheduling.cc | 5 +++++ 2 files changed, 6 insertions(+) diff --git a/third_party/xla/xla/debug_options_flags.cc b/third_party/xla/xla/debug_options_flags.cc index f651949a82a14d..3af22151ea6d6c 100644 --- a/third_party/xla/xla/debug_options_flags.cc +++ b/third_party/xla/xla/debug_options_flags.cc @@ -116,6 +116,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.add_xla_gpu_enable_command_buffer(DebugOptions::CUBLAS); opts.add_xla_gpu_enable_command_buffer(DebugOptions::CUSTOM_CALL); opts.add_xla_gpu_enable_command_buffer(DebugOptions::CUDNN); + opts.add_xla_gpu_enable_command_buffer(DebugOptions::CONDITIONALS); opts.set_xla_gpu_graph_min_graph_size(5); opts.set_xla_gpu_graph_enable_concurrent_region(false); opts.set_xla_cmd_buffer_trace_cache_size(16); diff --git a/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling.cc b/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling.cc index c772ef052b5a1a..3d8f11dd0dc7b6 100644 --- a/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling.cc +++ b/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling.cc @@ -768,6 +768,11 @@ absl::StatusOr CommandBufferScheduling::Run( device_description_.driver_version()) < se::SemanticVersion{12, 3, 0}) { erase(kRequireTracing); // cuStreamBeginCaptureToGraph + } + if (std::min(device_description_.runtime_version(), + device_description_.driver_version()) < + se::SemanticVersion{12, 4, 0}) { + // Conditionals With Memsets require cuda 12.4.1. erase(kRequireConditionals); // on-device control flow } }; From e4facc9c959047916a2802695afda08b5aa2c904 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 19 Sep 2024 11:26:01 -0700 Subject: [PATCH 017/483] Remove some unused code that forces auto-sharding to generate data-parallel strategies. PiperOrigin-RevId: 676493922 --- .../auto_sharding/auto_sharding.cc | 161 +++--------------- .../auto_sharding/auto_sharding.h | 19 +-- .../auto_sharding_dot_handler.cc | 47 ++--- .../auto_sharding/auto_sharding_option.cc | 2 - .../auto_sharding/auto_sharding_option.h | 5 - .../auto_sharding/auto_sharding_strategy.cc | 49 +++--- .../auto_sharding/cluster_environment.cc | 12 -- 7 files changed, 63 insertions(+), 232 deletions(-) diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc index ec9dc5f0916c56..66370ab0020b32 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc @@ -959,8 +959,7 @@ void BuildStrategyAndCostForOp(const HloInstruction* ins, const Shape& shape, void EnumerateAllPartition( const HloInstruction* ins, const Shape& shape, const DeviceMesh& device_mesh, const ClusterEnvironment& cluster_env, - const StrategyMap& strategy_map, - const InstructionBatchDimMap& batch_dim_map, bool only_allow_divisible, + const StrategyMap& strategy_map, bool only_allow_divisible, bool allow_shardings_small_dims_across_many_devices, const CallGraph& call_graph, const int64_t partition_dimensions, const std::vector& tensor_dims, StrategyGroup& strategy_group) { @@ -971,15 +970,10 @@ void EnumerateAllPartition( strategy_group); return; } - auto iter = batch_dim_map.find(GetBatchDimMapKey(ins)); - int64_t batch_dim = -1; - if (iter != batch_dim_map.end()) { - batch_dim = iter->second; - } // Fully tile the buffer to the mesh for (int64_t i = 0; i < shape.rank(); ++i) { auto tensor_it = std::find(tensor_dims.begin(), tensor_dims.end(), i); - if ((batch_dim != -1 && batch_dim != i) || tensor_it != tensor_dims.end()) { + if (tensor_it != tensor_dims.end()) { continue; } if (!allow_shardings_small_dims_across_many_devices && @@ -993,7 +987,7 @@ void EnumerateAllPartition( std::vector next_tensor_dims = tensor_dims; next_tensor_dims.push_back(i); EnumerateAllPartition( - ins, shape, device_mesh, cluster_env, strategy_map, batch_dim_map, + ins, shape, device_mesh, cluster_env, strategy_map, only_allow_divisible, allow_shardings_small_dims_across_many_devices, call_graph, partition_dimensions, next_tensor_dims, strategy_group); } @@ -1127,7 +1121,6 @@ void BuildStrategyAndCostForReshape(const HloInstruction* ins, void EnumeratePartitionReshape( const HloInstruction* ins, const DeviceMesh& device_mesh, const ClusterEnvironment& cluster_env, const StrategyMap& strategy_map, - const InstructionBatchDimMap& batch_dim_map, const bool only_allow_divisible, const int64_t partition_dimensions, const std::vector& tensor_dims, StrategyGroup& strategy_group) { const auto tensor_dims_size = tensor_dims.size(); @@ -1136,16 +1129,10 @@ void EnumeratePartitionReshape( tensor_dims, strategy_group); return; } - auto iter = batch_dim_map.find(GetBatchDimMapKey(ins)); - int64_t batch_dim = -1; - if (iter != batch_dim_map.end()) { - batch_dim = iter->second; - } - // Split batch dim + another dim for (int64_t i = 0; i < ins->shape().rank(); ++i) { auto tensor_it = std::find(tensor_dims.begin(), tensor_dims.end(), i); - if ((batch_dim != -1 && batch_dim != i) || tensor_it != tensor_dims.end()) { + if (tensor_it != tensor_dims.end()) { continue; } if (ins->shape().dimensions(i) < device_mesh.dim(tensor_dims_size)) { @@ -1160,9 +1147,8 @@ void EnumeratePartitionReshape( std::vector next_tensor_dims = tensor_dims; next_tensor_dims.push_back(i); EnumeratePartitionReshape(ins, device_mesh, cluster_env, strategy_map, - batch_dim_map, only_allow_divisible, - partition_dimensions, next_tensor_dims, - strategy_group); + only_allow_divisible, partition_dimensions, + next_tensor_dims, strategy_group); } } @@ -1290,40 +1276,12 @@ bool AllowTieFollowing(const HloInstruction* ins) { return true; } -// 1. Disable mixed mesh shape if the batch dim is not divisible by the -// number of devices. -// 2. Disable force_batch_dim_to_mesh_dim if the batch dim is 1. In this case, -// the batch dim analysis can be wrong because the batch dim might be dropped. -void DisableIncompatibleMixedMeshShapeAndForceBatchDim( - const InstructionBatchDimMap& batch_dim_map, - const std::vector& instructions, int num_devices, - AutoShardingOption& option) { - int64_t batch_size = INT_MAX; - for (const auto& iter : batch_dim_map) { - batch_size = std::min(batch_size, FindInstruction(instructions, iter.first) - ->shape() - .dimensions(iter.second)); - } - - if (IsDivisible(batch_size, num_devices)) { - if (option.allow_mixed_mesh_shape) { - option.allow_mixed_mesh_shape = false; - LOG(WARNING) - << "Mixed mesh shape is disabled due to indivisible batch size."; - } - } - - if (batch_size == 1) { - option.force_batch_dim_to_mesh_dim = -1; - } -} - void FillAllStrategiesForArray( const HloInstruction* ins, const Shape& shape, const ClusterEnvironment& cluster_env, const StrategyMap& strategy_map, const AutoShardingOption& option, const double replicated_penalty, - const InstructionBatchDimMap& batch_dim_map, const CallGraph& call_graph, - const bool only_allow_divisible, const bool create_replicated_strategies, + const CallGraph& call_graph, const bool only_allow_divisible, + const bool create_replicated_strategies, const bool create_partially_replicated_strategies, StrategyGroup& strategy_group) { if (create_partially_replicated_strategies || cluster_env.IsDeviceMesh1D()) { @@ -1336,7 +1294,7 @@ void FillAllStrategiesForArray( // Split 2 dims if (cluster_env.IsDeviceMesh2D()) { EnumerateAllPartition(ins, shape, cluster_env.device_mesh_, cluster_env, - strategy_map, batch_dim_map, only_allow_divisible, + strategy_map, only_allow_divisible, option.allow_shardings_small_dims_across_many_devices, call_graph, /*partitions*/ 2, /*tensor_dims*/ {}, strategy_group); @@ -1344,7 +1302,7 @@ void FillAllStrategiesForArray( // Split 3 dims if (cluster_env.IsDeviceMesh3D()) { EnumerateAllPartition(ins, shape, cluster_env.device_mesh_, cluster_env, - strategy_map, batch_dim_map, only_allow_divisible, + strategy_map, only_allow_divisible, option.allow_shardings_small_dims_across_many_devices, call_graph, /*partitions*/ 3, /*tensor_dims*/ {}, strategy_group); @@ -1367,22 +1325,13 @@ void FillAllStrategiesForArray( AddReplicatedStrategy(ins, shape, cluster_env, strategy_map, replicated_penalty, {}, strategy_group); } - - // If force_batch_dim_to_mesh_dim is set, filter out invalid strategies - // and only keep the data parallel strategies. - if (option.force_batch_dim_to_mesh_dim >= 0 && - batch_dim_map.contains(GetBatchDimMapKey(ins))) { - CHECK_OK(FilterStrategy(ins, shape, cluster_env, batch_dim_map, option, - strategy_group)); - } } absl::StatusOr> CreateAllStrategiesGroup( const HloInstruction* ins, const Shape& shape, const size_t instruction_id, StrategyGroups& strategy_groups, const ClusterEnvironment& cluster_env, const StrategyMap& strategy_map, const AutoShardingOption& option, - const double replicated_penalty, - const InstructionBatchDimMap& batch_dim_map, const CallGraph& call_graph, + const double replicated_penalty, const CallGraph& call_graph, const bool only_allow_divisible, const bool create_replicated_strategies, const bool create_partially_replicated_strategies) { std::unique_ptr strategy_group; @@ -1390,12 +1339,11 @@ absl::StatusOr> CreateAllStrategiesGroup( strategy_group = CreateTupleStrategyGroup(instruction_id); for (size_t i = 0; i < shape.tuple_shapes_size(); ++i) { auto child_strategies = - CreateAllStrategiesGroup(ins, shape.tuple_shapes(i), instruction_id, - strategy_groups, cluster_env, strategy_map, - option, replicated_penalty, batch_dim_map, - call_graph, only_allow_divisible, - create_replicated_strategies, - create_partially_replicated_strategies) + CreateAllStrategiesGroup( + ins, shape.tuple_shapes(i), instruction_id, strategy_groups, + cluster_env, strategy_map, option, replicated_penalty, call_graph, + only_allow_divisible, create_replicated_strategies, + create_partially_replicated_strategies) .value(); child_strategies->tuple_element_idx = i; strategy_group->AddChild(std::move(child_strategies)); @@ -1405,9 +1353,8 @@ absl::StatusOr> CreateAllStrategiesGroup( strategy_groups); FillAllStrategiesForArray( ins, shape, cluster_env, strategy_map, option, replicated_penalty, - batch_dim_map, call_graph, only_allow_divisible, - create_replicated_strategies, create_partially_replicated_strategies, - *strategy_group); + call_graph, only_allow_divisible, create_replicated_strategies, + create_partially_replicated_strategies, *strategy_group); } else if (shape.IsToken()) { strategy_group = CreateLeafStrategyGroup(instruction_id, ins, strategy_map, strategy_groups); @@ -1820,7 +1767,6 @@ std::unique_ptr CreateReshapeStrategies( const size_t instruction_id, const HloInstruction* ins, const StrategyMap& strategy_map, const ClusterEnvironment& cluster_env, const bool only_allow_divisible, const double replicated_penalty, - const InstructionBatchDimMap& batch_dim_map, const AutoShardingOption& option, StrategyGroups& strategy_groups, const CallGraph& call_graph) { const DeviceMesh& device_mesh = cluster_env.device_mesh_; @@ -1881,7 +1827,7 @@ std::unique_ptr CreateReshapeStrategies( VLOG(2) << "Enumerating all strategies for reshape"; FillAllStrategiesForArray( ins, ins->shape(), cluster_env, strategy_map, option, - replicated_penalty, batch_dim_map, call_graph, only_allow_divisible, + replicated_penalty, call_graph, only_allow_divisible, /* create_replicated_strategies */ true, /* create_partially_replicated_strategies */ true, *strategy_group); } @@ -3235,58 +3181,6 @@ absl::Status GenerateReduceScatter( return absl::OkStatus(); } -// Filter strategies according to the option.force_batch_dim_to_mesh_dim. -// This can be used to forcibly generate data-parallel strategies. -absl::Status FilterStrategy(const HloInstruction* ins, const Shape& shape, - const ClusterEnvironment& cluster_env, - const InstructionBatchDimMap& batch_map, - const AutoShardingOption& option, - StrategyGroup& strategy_group) { - int mesh_dim = option.force_batch_dim_to_mesh_dim; - int batch_dim = batch_map.at(GetBatchDimMapKey(ins)); - const DeviceMesh& device_mesh = cluster_env.device_mesh_; - - if (shape.dimensions(batch_dim) % device_mesh.dim(mesh_dim) != 0) { - return absl::InvalidArgumentError( - "The length of batch dimension is " - "not divisible by the number of devices"); - } - - std::vector> new_strategies; - const auto& strategy_input_shardings = - strategy_group.GetStrategyInputShardings(); - for (size_t iid = 0; iid < strategy_input_shardings.size(); ++iid) { - const InputShardings& input_shardings = strategy_input_shardings[iid]; - const ShardingStrategy& strategy = - strategy_group.GetStrategyForInputShardings(iid); - const HloSharding& output_sharding = strategy.output_sharding; - const std::vector tensor_dim_to_mesh_dim = - cluster_env.GetTensorDimToMeshDimWrapper(shape, output_sharding); - - if (device_mesh.dim(mesh_dim) > 1) { - // If the mesh dim is not one, the output tensor must be - // tiled along the mesh dim. - if (tensor_dim_to_mesh_dim[batch_dim] == mesh_dim) { - new_strategies.push_back({strategy, input_shardings}); - } - } else { - // If the mesh dim is one, the output tensor must be replicated - // on the mesh dim. - if (tensor_dim_to_mesh_dim[batch_dim] == -1) { - new_strategies.push_back({strategy, input_shardings}); - } - } - } - CHECK(!new_strategies.empty()) - << ins->ToString() << " does not have any valid strategies"; - strategy_group.ClearStrategies(); - for (const auto& [strategy, input_shardings] : new_strategies) { - strategy_group.AddStrategy(strategy, input_shardings); - } - - return absl::OkStatus(); -} - // Return the output sharding of the reduce-scatter variant of a given strategy. HloSharding GetReduceScatterOutput(const HloInstruction* ins, const ShardingStrategy& strategy, @@ -3780,13 +3674,6 @@ absl::StatusOr AutoShardingImplementation::RunAutoSharding( instruction_execution_counts = spmd::ComputeInstructionExecutionCounts( module, option_.loop_iteration_count_estimate); - // ----- Analyze the batch dim ----- - spmd::InstructionBatchDimMap batch_dim_map; - // TODO(yuemmawang) Enable the batch_dim_map if it becomes helpful. This is - // supposed to make the solver faster, but it makes it much much slower for - // both 1D and 2D mesh shapes. - // batch_dim_map = spmd::BuildInstructionBatchDimMap(sequence); - // ----- Read parameters of device mesh ----- spmd::DeviceMesh original_device_mesh(option_.device_mesh_shape); original_device_mesh.SetValues(option_.device_mesh_ids); @@ -3871,12 +3758,6 @@ absl::StatusOr AutoShardingImplementation::RunAutoSharding( << option_.memory_budget_per_device; } - if (option_.force_batch_dim_to_mesh_dim >= 0) { - spmd::DisableIncompatibleMixedMeshShapeAndForceBatchDim( - batch_dim_map, sequence.instructions(), device_mesh.num_elements(), - option_); - } - // ----- Analyze depth ----- spmd::InstructionDepthMap ins_depth_map; ins_depth_map = spmd::BuildInstructionDepthMap(sequence); @@ -3889,8 +3770,8 @@ absl::StatusOr AutoShardingImplementation::RunAutoSharding( std::tie(strategy_map, strategy_groups, associative_dot_pairs), BuildStrategyAndCost(sequence, module, instructions_to_shard, instruction_execution_counts, ins_depth_map, - batch_dim_map, alias_map, cluster_env, option_, - *call_graph, hlo_cost_analysis, + alias_map, cluster_env, option_, *call_graph, + hlo_cost_analysis, option_.try_multiple_mesh_shapes)); spmd::AliasSet alias_set = spmd::BuildAliasSet(module, input_output_alias_config, strategy_map); diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.h b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.h index be983f15916eed..1a1b67c83873ff 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.h +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.h @@ -151,7 +151,6 @@ void RemoveDuplicatedStrategy(StrategyGroup& strategy_group); absl::Status FilterStrategy(const HloInstruction* ins, const Shape& shape, const ClusterEnvironment& cluster_env, - const InstructionBatchDimMap& batch_map, const AutoShardingOption& option, StrategyGroup& strategy_group); @@ -162,7 +161,6 @@ absl::Status HandleDot(std::unique_ptr& strategy_group, const HloInstructionSequence& instruction_sequence, const HloCostAnalysis& hlo_cost_analysis, const ClusterEnvironment& cluster_env, - const InstructionBatchDimMap& batch_map, const AutoShardingOption& option, const CallGraph& call_graph); @@ -173,7 +171,6 @@ absl::Status HandleConv(std::unique_ptr& strategy_group, const HloInstructionSequence& instruction_sequence, const HloCostAnalysis& hlo_cost_analysis, const ClusterEnvironment& cluster_env, - const InstructionBatchDimMap& batch_map, const AutoShardingOption& option, const CallGraph& call_graph); @@ -250,17 +247,16 @@ void FillAllStrategiesForArray( const HloInstruction* ins, const Shape& shape, const ClusterEnvironment& cluster_env, const StrategyMap& strategy_map, const AutoShardingOption& option, double replicated_penalty, - const InstructionBatchDimMap& batch_dim_map, const CallGraph& call_graph, - bool only_allow_divisible, bool create_replicated_strategies, + const CallGraph& call_graph, bool only_allow_divisible, + bool create_replicated_strategies, bool create_partially_replicated_strategies, StrategyGroup& strategy_group); absl::StatusOr> CreateAllStrategiesGroup( const HloInstruction* ins, const Shape& shape, size_t instruction_id, StrategyGroups& strategy_groups, const ClusterEnvironment& cluster_env, const StrategyMap& strategy_map, const AutoShardingOption& option, - double replicated_penalty, const InstructionBatchDimMap& batch_dim_map, - const CallGraph& call_graph, bool only_allow_divisible, - bool create_replicated_strategies, + double replicated_penalty, const CallGraph& call_graph, + bool only_allow_divisible, bool create_replicated_strategies, bool create_partially_replicated_strategies); // Enumerates sharding strategies for elementwise operators by following @@ -294,7 +290,6 @@ std::unique_ptr CreateReshapeStrategies( size_t instruction_id, const HloInstruction* ins, const StrategyMap& strategy_map, const ClusterEnvironment& cluster_env, bool only_allow_divisible, double replicated_penalty, - const InstructionBatchDimMap& batch_dim_map, const AutoShardingOption& option, StrategyGroups& strategy_groups, const CallGraph& call_graph); @@ -313,8 +308,7 @@ void EnumerateAll1DPartition( void EnumerateAllPartition( const HloInstruction* ins, const Shape& shape, const DeviceMesh& device_mesh, const ClusterEnvironment& cluster_env, - const StrategyMap& strategy_map, - const InstructionBatchDimMap& batch_dim_map, bool only_allow_divisible, + const StrategyMap& strategy_map, bool only_allow_divisible, bool allow_shardings_small_dims_across_many_devices, const CallGraph& call_graph, int64_t partition_dimensions, const std::vector& tensor_dims, StrategyGroup& strategy_group); @@ -368,8 +362,7 @@ BuildStrategyAndCost( const absl::flat_hash_set& instructions_to_shard, const absl::flat_hash_map& instruction_execution_counts, - const InstructionDepthMap& depth_map, - const InstructionBatchDimMap& batch_dim_map, const AliasMap& alias_map, + const InstructionDepthMap& depth_map, const AliasMap& alias_map, const ClusterEnvironment& cluster_env, AutoShardingOption& option, const CallGraph& call_graph, const HloCostAnalysis& hlo_cost_analysis, bool trying_multiple_mesh_shapes); diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc index f80958b099ff45..1916172aa70901 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc @@ -70,7 +70,6 @@ class HandlerBase { const HloInstructionSequence& instruction_sequence, const HloCostAnalysis& hlo_cost_analysis, const ClusterEnvironment& cluster_env, - const InstructionBatchDimMap& batch_map, const AutoShardingOption& option, const CallGraph& call_graph) : strategy_group_(strategy_group), strategy_map_(strategy_map), @@ -79,7 +78,6 @@ class HandlerBase { instruction_sequence_(instruction_sequence), hlo_cost_analysis_(hlo_cost_analysis), cluster_env_(cluster_env), - batch_map_(batch_map), option_(option), call_graph_(call_graph), device_mesh_(cluster_env.device_mesh_), @@ -221,7 +219,6 @@ class HandlerBase { const HloInstructionSequence& instruction_sequence_; const HloCostAnalysis& hlo_cost_analysis_; const ClusterEnvironment& cluster_env_; - const InstructionBatchDimMap& batch_map_; const AutoShardingOption& option_; const CallGraph& call_graph_; @@ -238,7 +235,6 @@ class DotHandler : public HandlerBase { const HloInstructionSequence& instruction_sequence, const HloCostAnalysis& hlo_cost_analysis, const ClusterEnvironment& cluster_env, - const InstructionBatchDimMap& batch_map, const AutoShardingOption& option, const CallGraph& call_graph); DotHandler( @@ -247,8 +243,7 @@ class DotHandler : public HandlerBase { const HloInstructionSequence& instruction_sequence, const HloCostAnalysis& hlo_cost_analysis, const dot_as_convolution_util::DotConvolutionDimsInfo& conv_as_dot_dims, - const ClusterEnvironment& cluster_env, - const InstructionBatchDimMap& batch_map, const AutoShardingOption& option, + const ClusterEnvironment& cluster_env, const AutoShardingOption& option, const CallGraph& call_graph); ~DotHandler() override = default; @@ -293,7 +288,6 @@ class ConvHandler : public HandlerBase { const HloInstructionSequence& instruction_sequence, const HloCostAnalysis& hlo_cost_analysis, const ClusterEnvironment& cluster_env, - const InstructionBatchDimMap& batch_map, const AutoShardingOption& option, const CallGraph& call_graph); ~ConvHandler() override = default; @@ -492,12 +486,11 @@ DotHandler::DotHandler(std::unique_ptr& strategy_group, const HloInstructionSequence& instruction_sequence, const HloCostAnalysis& hlo_cost_analysis, const ClusterEnvironment& cluster_env, - const InstructionBatchDimMap& batch_map, const AutoShardingOption& option, const CallGraph& call_graph) : HandlerBase(strategy_group, strategy_map, ins, instruction_id, - instruction_sequence, hlo_cost_analysis, cluster_env, - batch_map, option, call_graph), + instruction_sequence, hlo_cost_analysis, cluster_env, option, + call_graph), is_dot_(true), space_base_dim_(ins->dot_dimension_numbers().lhs_batch_dimensions_size()), lhs_con_dims_(ins->dot_dimension_numbers().lhs_contracting_dimensions()), @@ -525,12 +518,11 @@ DotHandler::DotHandler( const HloInstructionSequence& instruction_sequence, const HloCostAnalysis& hlo_cost_analysis, const dot_as_convolution_util::DotConvolutionDimsInfo& conv_as_dot_dims, - const ClusterEnvironment& cluster_env, - const InstructionBatchDimMap& batch_map, const AutoShardingOption& option, + const ClusterEnvironment& cluster_env, const AutoShardingOption& option, const CallGraph& call_graph) : HandlerBase(strategy_group, strategy_map, ins, instruction_id, - instruction_sequence, hlo_cost_analysis, cluster_env, - batch_map, option, call_graph), + instruction_sequence, hlo_cost_analysis, cluster_env, option, + call_graph), is_dot_(false), space_base_dim_(-1) { CHECK(conv_as_dot_dims.conv_spatial_dims.empty()); @@ -858,12 +850,11 @@ ConvHandler::ConvHandler(std::unique_ptr& strategy_group, const HloInstructionSequence& instruction_sequence, const HloCostAnalysis& hlo_cost_analysis, const ClusterEnvironment& cluster_env, - const InstructionBatchDimMap& batch_map, const AutoShardingOption& option, const CallGraph& call_graph) : HandlerBase(strategy_group, strategy_map, ins, instruction_id, - instruction_sequence, hlo_cost_analysis, cluster_env, - batch_map, option, call_graph), + instruction_sequence, hlo_cost_analysis, cluster_env, option, + call_graph), conv_dnums_(ins->convolution_dimension_numbers()) { lhs_batch_dim_ = conv_dnums_.input_batch_dimension(); lhs_in_channel_dim_ = conv_dnums_.input_feature_dimension(); @@ -969,14 +960,6 @@ absl::Status ConvHandler::RegisterStrategies() { 2, /*current_mesh_dim_idx=*/0, all_mesh_dims, /*current_dim_map=*/{}); - // If force_batch_dim_to_mesh_dim is set, filter out invalid strategies - // and only keep the data parallel strategies. - if (option_.force_batch_dim_to_mesh_dim >= 0 && - batch_map_.contains(GetBatchDimMapKey(ins_))) { - TF_RETURN_IF_ERROR(FilterStrategy(ins_, ins_->shape(), cluster_env_, - batch_map_, option_, *strategy_group_)); - } - SortStrategies(); return absl::OkStatus(); } @@ -1027,7 +1010,6 @@ absl::Status HandleDot(std::unique_ptr& strategy_group, const HloInstructionSequence& instruction_sequence, const HloCostAnalysis& hlo_cost_analysis, const ClusterEnvironment& cluster_env, - const InstructionBatchDimMap& batch_map, const AutoShardingOption& option, const CallGraph& call_graph) { strategy_group = CreateLeafStrategyGroup(instruction_id, ins, strategy_map, @@ -1035,7 +1017,7 @@ absl::Status HandleDot(std::unique_ptr& strategy_group, DotHandler handler(strategy_group, strategy_map, Cast(ins), instruction_id, instruction_sequence, hlo_cost_analysis, - cluster_env, batch_map, option, call_graph); + cluster_env, option, call_graph); TF_RETURN_IF_ERROR(handler.RegisterStrategies()); return absl::OkStatus(); } @@ -1048,7 +1030,6 @@ absl::Status HandleConv(std::unique_ptr& strategy_group, const HloInstructionSequence& instruction_sequence, const HloCostAnalysis& hlo_cost_analysis, const ClusterEnvironment& cluster_env, - const InstructionBatchDimMap& batch_map, const AutoShardingOption& option, const CallGraph& call_graph) { strategy_group = CreateLeafStrategyGroup(instruction_id, ins, strategy_map, @@ -1057,16 +1038,16 @@ absl::Status HandleConv(std::unique_ptr& strategy_group, const dot_as_convolution_util::DotConvolutionDimsInfo& conv_as_dot_dims = dot_as_convolution_util::ParseConvolutionDimsInfo(ins); if (conv_as_dot_dims.conv_spatial_dims.empty()) { - DotHandler handler( - strategy_group, strategy_map, Cast(ins), - instruction_id, instruction_sequence, hlo_cost_analysis, - conv_as_dot_dims, cluster_env, batch_map, option, call_graph); + DotHandler handler(strategy_group, strategy_map, + Cast(ins), instruction_id, + instruction_sequence, hlo_cost_analysis, + conv_as_dot_dims, cluster_env, option, call_graph); TF_RETURN_IF_ERROR(handler.RegisterStrategies()); } else { ConvHandler handler(strategy_group, strategy_map, ins, instruction_id, instruction_sequence, hlo_cost_analysis, cluster_env, - batch_map, option, call_graph); + option, call_graph); TF_RETURN_IF_ERROR(handler.RegisterStrategies()); } return absl::OkStatus(); diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_option.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_option.cc index f8c2d41614ccb5..3b35dec03b3cd4 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_option.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_option.cc @@ -66,8 +66,6 @@ std::string AutoShardingOption::ToString() const { lines.push_back(absl::StrCat("reduce_scatter_cost: ", reduce_scatter_cost)); } - lines.push_back(absl::StrCat("force_batch_dim_to_mesh_dim: ", - force_batch_dim_to_mesh_dim)); lines.push_back(absl::StrCat("allow_replicated_parameters: ", allow_replicated_parameters)); lines.push_back( diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_option.h b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_option.h index 468ab4aa8f3c79..5decdf1a24c5f8 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_option.h +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_option.h @@ -84,11 +84,6 @@ struct AutoShardingOption { bool force_override_reduce_scatter_cost = false; double reduce_scatter_cost = 0; - // Forcibly split the batch dimension and map it to a mesh dimension. - // This can force the auto-sharding pass to generate the data parallel - // strategy. - int force_batch_dim_to_mesh_dim = -1; - // If true, allow replicated parameters. bool allow_replicated_parameters = true; diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc index c20cd9af5b3680..3bfd5e2838cf54 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc @@ -217,8 +217,7 @@ BuildStrategyAndCost( const absl::flat_hash_set& instructions_to_shard, const absl::flat_hash_map& instruction_execution_counts, - const InstructionDepthMap& depth_map, - const InstructionBatchDimMap& batch_dim_map, const AliasMap& alias_map, + const InstructionDepthMap& depth_map, const AliasMap& alias_map, const ClusterEnvironment& cluster_env, AutoShardingOption& option, const CallGraph& call_graph, const HloCostAnalysis& hlo_cost_analysis, bool trying_multiple_mesh_shapes) { @@ -310,9 +309,8 @@ BuildStrategyAndCost( strategy_group = CreateAllStrategiesGroup( ins, ins->shape(), instruction_id, strategy_groups, cluster_env, - strategy_map, option, replicated_penalty, batch_dim_map, - call_graph, only_allow_divisible, - option.allow_replicated_parameters, + strategy_map, option, replicated_penalty, call_graph, + only_allow_divisible, option.allow_replicated_parameters, /* create_partially_replicated_strategies */ true) .value(); break; @@ -322,9 +320,8 @@ BuildStrategyAndCost( strategy_group = CreateAllStrategiesGroup( ins, ins->shape(), instruction_id, strategy_groups, cluster_env, - strategy_map, option, replicated_penalty, batch_dim_map, - call_graph, only_allow_divisible, - option.allow_replicated_parameters, + strategy_map, option, replicated_penalty, call_graph, + only_allow_divisible, option.allow_replicated_parameters, /* create_partially_replicated_strategies */ true) .value(); break; @@ -520,8 +517,8 @@ BuildStrategyAndCost( strategy_group = CreateAllStrategiesGroup( ins, ins->shape(), instruction_id, strategy_groups, cluster_env, - strategy_map, option, replicated_penalty, batch_dim_map, - call_graph, only_allow_divisible, + strategy_map, option, replicated_penalty, call_graph, + only_allow_divisible, /* create_replicated_strategies */ true, /* create_partially_replicated_strategies */ true) .value(); @@ -530,8 +527,8 @@ BuildStrategyAndCost( case HloOpcode::kReshape: { strategy_group = CreateReshapeStrategies( instruction_id, ins, strategy_map, cluster_env, - only_allow_divisible, replicated_penalty, batch_dim_map, option, - strategy_groups, call_graph); + only_allow_divisible, replicated_penalty, option, strategy_groups, + call_graph); break; } case HloOpcode::kTranspose: @@ -737,8 +734,8 @@ BuildStrategyAndCost( } else { strategy_group = CreateReshapeStrategies( instruction_id, ins, strategy_map, cluster_env, - only_allow_divisible, replicated_penalty, batch_dim_map, option, - strategy_groups, call_graph); + only_allow_divisible, replicated_penalty, option, strategy_groups, + call_graph); } break; } @@ -812,10 +809,9 @@ BuildStrategyAndCost( break; } case HloOpcode::kDot: { - TF_RETURN_IF_ERROR(HandleDot(strategy_group, strategy_groups, - strategy_map, ins, instruction_id, - sequence, hlo_cost_analysis, cluster_env, - batch_dim_map, option, call_graph)); + TF_RETURN_IF_ERROR(HandleDot( + strategy_group, strategy_groups, strategy_map, ins, instruction_id, + sequence, hlo_cost_analysis, cluster_env, option, call_graph)); if (option.allow_recompute_heavy_op) { AddReplicatedStrategy( @@ -827,10 +823,9 @@ BuildStrategyAndCost( break; } case HloOpcode::kConvolution: { - TF_RETURN_IF_ERROR(HandleConv(strategy_group, strategy_groups, - strategy_map, ins, instruction_id, - sequence, hlo_cost_analysis, cluster_env, - batch_dim_map, option, call_graph)); + TF_RETURN_IF_ERROR(HandleConv( + strategy_group, strategy_groups, strategy_map, ins, instruction_id, + sequence, hlo_cost_analysis, cluster_env, option, call_graph)); if (option.allow_recompute_heavy_op) { AddReplicatedStrategy( ins, ins->shape(), cluster_env, strategy_map, @@ -851,8 +846,8 @@ BuildStrategyAndCost( strategy_group = CreateAllStrategiesGroup( ins, ins->shape(), instruction_id, strategy_groups, cluster_env, - strategy_map, option, replicated_penalty, batch_dim_map, - call_graph, only_allow_divisible, + strategy_map, option, replicated_penalty, call_graph, + only_allow_divisible, /* create_replicated_strategies */ true, /* create_partially_replicated_strategies */ true) .value(); @@ -923,7 +918,7 @@ BuildStrategyAndCost( CreateAllStrategiesGroup( ins, ins->shape(), instruction_id, strategy_groups, cluster_env, strategy_map, option, replicated_penalty, - batch_dim_map, call_graph, only_allow_divisible, + call_graph, only_allow_divisible, /* create_replicated_strategies */ true, /* create_partially_replicated_strategies */ true) .value(); @@ -987,8 +982,8 @@ BuildStrategyAndCost( strategy_group = CreateAllStrategiesGroup( ins, ins->shape(), instruction_id, strategy_groups, cluster_env, - strategy_map, option, replicated_penalty, batch_dim_map, - call_graph, only_allow_divisible, + strategy_map, option, replicated_penalty, call_graph, + only_allow_divisible, /* create_replicated_strategies */ true, /* create_partially_replicated_strategies */ true) .value(); diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/cluster_environment.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/cluster_environment.cc index 42402e39a1496f..f9c1ce429a114c 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/cluster_environment.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/cluster_environment.cc @@ -42,12 +42,6 @@ double ClusterEnvironment::AllGatherCost(double num_bytes, int mesh_dim) const { num_bytes / 4, "float32"); } - if (auto_sharding_option_.force_batch_dim_to_mesh_dim == mesh_dim) { - // if data-parallel is forced on this dim, we only allow all-reduce - // in this dimension. - return kInfinityCost; - } - int64_t num_devices = device_mesh_.dim(mesh_dim); return (round(mesh_alpha_[mesh_dim] + mesh_beta_[mesh_dim] * (num_devices - 1) / num_devices * @@ -123,12 +117,6 @@ double ClusterEnvironment::AllToAllCost(double num_bytes, int mesh_dim) const { num_bytes / 4, "float32"); } - if (auto_sharding_option_.force_batch_dim_to_mesh_dim == mesh_dim) { - // if data-parallel is forced on this dim, we only allow all-reduce - // in this dimension. - return kInfinityCost; - } - int64_t num_devices = device_mesh_.dim(mesh_dim); return AllToAllCostUtil(num_bytes, mesh_dim, num_devices); } From 5cf0bbd9d8b0dacacb068716d6fb97ec63f4f263 Mon Sep 17 00:00:00 2001 From: Gunhyun Park Date: Thu, 19 Sep 2024 11:31:45 -0700 Subject: [PATCH 018/483] Rearrange definitions in alphabetical order. PiperOrigin-RevId: 676496366 --- .../hlo_to_mhlo/module_attributes_importer.cc | 26 +++++++++---------- .../hlo_to_mhlo/module_attributes_importer.h | 12 ++++----- 2 files changed, 19 insertions(+), 19 deletions(-) diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/module_attributes_importer.cc b/third_party/xla/xla/translate/hlo_to_mhlo/module_attributes_importer.cc index a60499f1b9bdf0..d445640802dc56 100644 --- a/third_party/xla/xla/translate/hlo_to_mhlo/module_attributes_importer.cc +++ b/third_party/xla/xla/translate/hlo_to_mhlo/module_attributes_importer.cc @@ -208,6 +208,19 @@ void ImportFrontendAttributes(const xla::HloModule& hlo_module, } } +void ImportInputOutputAlias(const xla::HloModule& hlo_module, + mlir::ModuleOp module, mlir::Builder builder) { + module->setAttr(kInputOutputAlias, + ConvertInputOutputAlias( + hlo_module.input_output_alias_config(), &builder)); +} + +void ImportIsDynamic(const xla::HloModule& hlo_module, mlir::ModuleOp module, + mlir::Builder builder) { + module->setAttr(kIsDynamic, mlir::BoolAttr::get(builder.getContext(), + hlo_module.is_dynamic())); +} + void ImportNumPartitions(const xla::HloModule& hlo_module, mlir::ModuleOp module, mlir::Builder builder) { const auto& config = hlo_module.config(); @@ -226,19 +239,6 @@ void ImportNumReplicas(const xla::HloModule& hlo_module, mlir::ModuleOp module, } } -void ImportInputOutputAlias(const xla::HloModule& hlo_module, - mlir::ModuleOp module, mlir::Builder builder) { - module->setAttr(kInputOutputAlias, - ConvertInputOutputAlias( - hlo_module.input_output_alias_config(), &builder)); -} - -void ImportIsDynamic(const xla::HloModule& hlo_module, mlir::ModuleOp module, - mlir::Builder builder) { - module->setAttr(kIsDynamic, mlir::BoolAttr::get(builder.getContext(), - hlo_module.is_dynamic())); -} - void ImportSpmdOutputSharding(const xla::HloModule& hlo_module, mlir::ModuleOp module, mlir::Builder builder) { if (hlo_module.has_spmd_output_sharding()) diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/module_attributes_importer.h b/third_party/xla/xla/translate/hlo_to_mhlo/module_attributes_importer.h index bd4580e5d315a5..b29c09c86e29a4 100644 --- a/third_party/xla/xla/translate/hlo_to_mhlo/module_attributes_importer.h +++ b/third_party/xla/xla/translate/hlo_to_mhlo/module_attributes_importer.h @@ -37,18 +37,18 @@ void ImportEntryComputationLayoutAndTiles(const HloModule& hlo_module, void ImportFrontendAttributes(const HloModule& hlo_module, mlir::ModuleOp module, mlir::Builder builder); -void ImportNumPartitions(const HloModule& hlo_module, mlir::ModuleOp module, - mlir::Builder builder); - -void ImportNumReplicas(const HloModule& hlo_module, mlir::ModuleOp module, - mlir::Builder builder); - void ImportInputOutputAlias(const HloModule& hlo_module, mlir::ModuleOp module, mlir::Builder builder); void ImportIsDynamic(const HloModule& hlo_module, mlir::ModuleOp module, mlir::Builder builder); +void ImportNumPartitions(const HloModule& hlo_module, mlir::ModuleOp module, + mlir::Builder builder); + +void ImportNumReplicas(const HloModule& hlo_module, mlir::ModuleOp module, + mlir::Builder builder); + void ImportSpmdOutputSharding(const HloModule& hlo_module, mlir::ModuleOp module, mlir::Builder builder); From 6c16a024cee4ce3a9d7a58080f66a49475e4230c Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 19 Sep 2024 11:59:47 -0700 Subject: [PATCH 019/483] Update TFRT dependency to use revision http://github.com/tensorflow/runtime/commit/4e088077c6e455610a39ed8d2a18fe195ada6137. PiperOrigin-RevId: 676507184 --- third_party/tf_runtime/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/tf_runtime/workspace.bzl b/third_party/tf_runtime/workspace.bzl index 3466def95fd60d..c51c05301e97d1 100644 --- a/third_party/tf_runtime/workspace.bzl +++ b/third_party/tf_runtime/workspace.bzl @@ -6,8 +6,8 @@ def repo(): """Imports TFRT.""" # Attention: tools parse and update these lines. - TFRT_COMMIT = "07992d7c1ead60f610c17b7c1f9e50b6898adc87" - TFRT_SHA256 = "e1de8d371248d3dfc6e9ebd0e4094b57ce04d9545ae3756b5a84c33482614d5f" + TFRT_COMMIT = "4e088077c6e455610a39ed8d2a18fe195ada6137" + TFRT_SHA256 = "b2c5a3651ad13074760781ad6d8339da09352a2ce24db4fe86c856f7ca610a16" tf_http_archive( name = "tf_runtime", From 9d9a238b75278f1f4137e3124c4644506a6cf3f5 Mon Sep 17 00:00:00 2001 From: Terry Heo Date: Thu, 19 Sep 2024 12:36:07 -0700 Subject: [PATCH 020/483] Update inference_diff to support more types PiperOrigin-RevId: 676520699 --- .../stages/inference_profiler_stage.cc | 33 ++++++++++++++++--- .../stages/inference_profiler_stage.h | 3 ++ 2 files changed, 31 insertions(+), 5 deletions(-) diff --git a/tensorflow/lite/tools/evaluation/stages/inference_profiler_stage.cc b/tensorflow/lite/tools/evaluation/stages/inference_profiler_stage.cc index 094d76134f3bac..b26af9f983a640 100644 --- a/tensorflow/lite/tools/evaluation/stages/inference_profiler_stage.cc +++ b/tensorflow/lite/tools/evaluation/stages/inference_profiler_stage.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/lite/tools/evaluation/stages/inference_profiler_stage.h" #include +#include #include #include #include @@ -101,12 +102,12 @@ TfLiteStatus InferenceProfilerStage::Init( for (int i = 0; i < model_info_->inputs.size(); ++i) { const TfLiteType model_input_type = model_info_->inputs[i]->type; if (model_input_type == kTfLiteUInt8 || model_input_type == kTfLiteInt8 || - model_input_type == kTfLiteInt64 || - model_input_type == kTfLiteFloat32 || + model_input_type == kTfLiteInt32 || model_input_type == kTfLiteInt64 || + model_input_type == kTfLiteBool || model_input_type == kTfLiteFloat32 || model_input_type == kTfLiteFloat16) { } else { LOG(ERROR) << "InferenceProfilerStage only supports " - "float16/float32/int8/uint8/int64 " + "float16/float32/int8/uint8/int32/int64/bool " "input types"; return kTfLiteError; } @@ -121,14 +122,18 @@ TfLiteStatus InferenceProfilerStage::Init( int8_tensors_.emplace_back(); float16_tensors_.emplace_back(); int64_tensors_.emplace_back(); + int32_tensors_.emplace_back(); + bool_tensors_.emplace_back(); } // Preprocess output metadata for calculating diffs later. for (int i = 0; i < model_info_->outputs.size(); ++i) { const TfLiteType model_output_type = model_info_->outputs[i]->type; if (model_output_type == kTfLiteUInt8 || model_output_type == kTfLiteInt8 || + model_output_type == kTfLiteInt32 || model_output_type == kTfLiteBool || model_output_type == kTfLiteFloat32) { } else { - LOG(ERROR) << "InferenceProfilerStage only supports float32/int8/uint8 " + LOG(ERROR) << "InferenceProfilerStage only supports " + "float32/int8/uint8/int32/bool " "output types"; return kTfLiteError; } @@ -160,11 +165,20 @@ TfLiteStatus InferenceProfilerStage::Run() { input_num_elements_[i], std::numeric_limits::min(), std::numeric_limits::max(), &int8_tensors_[i]); input_ptrs.push_back(int8_tensors_[i].data()); + } else if (model_input_type == kTfLiteInt32) { + GenerateRandomGaussianData( + input_num_elements_[i], std::numeric_limits::min(), + std::numeric_limits::max(), &int32_tensors_[i]); + input_ptrs.push_back(int32_tensors_[i].data()); } else if (model_input_type == kTfLiteInt64) { GenerateRandomGaussianData( input_num_elements_[i], std::numeric_limits::min(), std::numeric_limits::max(), &int64_tensors_[i]); input_ptrs.push_back(int64_tensors_[i].data()); + } else if (model_input_type == kTfLiteBool) { + GenerateRandomGaussianData(input_num_elements_[i], 0, 1, + &bool_tensors_[i]); + input_ptrs.push_back(bool_tensors_[i].data()); } else if (model_input_type == kTfLiteFloat32) { GenerateRandomGaussianData(input_num_elements_[i], -1, 1, &(float_tensors_[i])); @@ -179,7 +193,7 @@ TfLiteStatus InferenceProfilerStage::Run() { input_ptrs.push_back(float16_tensors_[i].data()); } else { LOG(ERROR) << "InferenceProfilerStage only supports " - "float16/float32/int8/uint8/int64 " + "float16/float32/int8/uint8/int32/int64/bool " "input types"; return kTfLiteError; } @@ -205,6 +219,15 @@ TfLiteStatus InferenceProfilerStage::Run() { output_diff = CalculateAverageError(static_cast(reference_ptr), static_cast(test_ptr), output_num_elements_[i]); + } else if (model_output_type == kTfLiteInt32) { + output_diff = CalculateAverageError(static_cast(reference_ptr), + static_cast(test_ptr), + output_num_elements_[i]); + } else if (model_output_type == kTfLiteBool) { + // Use int8_t* for bool tensors to use void* casting. + output_diff = CalculateAverageError(static_cast(reference_ptr), + static_cast(test_ptr), + output_num_elements_[i]); } else if (model_output_type == kTfLiteFloat32) { output_diff = CalculateAverageError(static_cast(reference_ptr), static_cast(test_ptr), diff --git a/tensorflow/lite/tools/evaluation/stages/inference_profiler_stage.h b/tensorflow/lite/tools/evaluation/stages/inference_profiler_stage.h index d48c836f035b44..a68049ed960b11 100644 --- a/tensorflow/lite/tools/evaluation/stages/inference_profiler_stage.h +++ b/tensorflow/lite/tools/evaluation/stages/inference_profiler_stage.h @@ -66,7 +66,10 @@ class InferenceProfilerStage : public EvaluationStage { std::vector> int8_tensors_; std::vector> uint8_tensors_; std::vector> float16_tensors_; + std::vector> int32_tensors_; std::vector> int64_tensors_; + // Use uint8_t for bool tensors to use void* casting. + std::vector> bool_tensors_; }; } // namespace evaluation From 3a8cec9aa2789c59137f6f25e4818d1189a7620b Mon Sep 17 00:00:00 2001 From: zoranjovanovic-ns <126815388+zoranjovanovic-ns@users.noreply.github.com> Date: Thu, 19 Sep 2024 13:08:13 -0700 Subject: [PATCH 021/483] PR #17142: [ROCm] Disable gemm triton fusions for ROCm Imported from GitHub PR https://github.com/openxla/xla/pull/17142 Until autotuner is functional, avoid preformance drop. Copybara import of the project: -- fbacfaff6580c93d4559489b3a0ec835c7b93451 by Zoran Jovanovic : [ROCm] Disable gemm triton fusions for ROCm, until autotuner is functional. Merging this change closes #17142 PiperOrigin-RevId: 676533105 --- .../triton_fusion_emitter_device_legacy_test.cc | 14 ++++++++++++++ .../triton/triton_fusion_emitter_large_test.cc | 14 ++++++++++++++ .../triton_fusion_emitter_parametrized_test.cc | 15 +++++++++++++++ third_party/xla/xla/service/gpu/gpu_compiler.cc | 5 ++--- .../xla/xla/service/gpu/gpu_compiler_test.cc | 11 +++++++++++ 5 files changed, 56 insertions(+), 3 deletions(-) diff --git a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_legacy_test.cc b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_legacy_test.cc index d07abdb4811224..7ac574504294a1 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_legacy_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_legacy_test.cc @@ -100,6 +100,20 @@ class TritonTest : public GpuCodegenTest { class TritonGemmTest : public TritonTest { public: + se::GpuComputeCapability GetGpuComputeCapability() { + return backend() + .default_stream_executor() + ->GetDeviceDescription() + .gpu_compute_capability(); + } + + void SetUp() override { + if (std::holds_alternative( + GetGpuComputeCapability())) { + GTEST_SKIP() << "Not supported on ROCm until Triton is re-enabled."; + } + } + DebugOptions GetDebugOptionsForTest() override { DebugOptions debug_options = TritonTest::GetDebugOptionsForTest(); // Do not fall back to cuBLAS, we are testing Triton. diff --git a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_large_test.cc b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_large_test.cc index a50776c6c54e9f..97634100ad3aa6 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_large_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_large_test.cc @@ -28,6 +28,20 @@ namespace { class TritonGemmTest : public GpuCodegenTest { public: + se::GpuComputeCapability GetGpuComputeCapability() { + return backend() + .default_stream_executor() + ->GetDeviceDescription() + .gpu_compute_capability(); + } + + void SetUp() override { + if (std::holds_alternative( + GetGpuComputeCapability())) { + GTEST_SKIP() << "Not supported on ROCm until Triton is re-enabled."; + } + } + DebugOptions GetDebugOptionsForTest() override { DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest(); debug_options.set_xla_gpu_cublas_fallback(false); diff --git a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_parametrized_test.cc b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_parametrized_test.cc index a3cafcdf40628f..11b9ae776a3157 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_parametrized_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_parametrized_test.cc @@ -53,6 +53,21 @@ struct MixTypeParams { class MixedTypeTest : public GpuCodegenTest, public ::testing::WithParamInterface { public: + se::GpuComputeCapability GetGpuComputeCapability() { + return backend() + .default_stream_executor() + ->GetDeviceDescription() + .gpu_compute_capability(); + } + + void SetUp() override { + if (std::holds_alternative( + GetGpuComputeCapability())) { + GTEST_SKIP() + << "Related fusions are not performed on ROCm without Triton."; + } + } + DebugOptions GetDebugOptionsForTest() override { DebugOptions debug_options = GpuCodegenTest::GetDebugOptionsForTest(); // We are testing Triton, remove cuBLAS fallback for these tests. diff --git a/third_party/xla/xla/service/gpu/gpu_compiler.cc b/third_party/xla/xla/service/gpu/gpu_compiler.cc index 86f07cfb649404..f4aafc5972a3e6 100644 --- a/third_party/xla/xla/service/gpu/gpu_compiler.cc +++ b/third_party/xla/xla/service/gpu/gpu_compiler.cc @@ -1472,9 +1472,8 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment( const auto* rocm_cc = std::get_if(&gpu_version); if (debug_options.xla_gpu_enable_triton_gemm() && - ((cuda_cc != nullptr && - cuda_cc->IsAtLeast(se::CudaComputeCapability::AMPERE)) || - rocm_cc != nullptr)) { + (cuda_cc != nullptr && + cuda_cc->IsAtLeast(se::CudaComputeCapability::AMPERE))) { pipeline.AddPass(); pipeline.AddPass(gpu_version); } else if (cuda_cc != nullptr && diff --git a/third_party/xla/xla/service/gpu/gpu_compiler_test.cc b/third_party/xla/xla/service/gpu/gpu_compiler_test.cc index 51b459e8a81a02..b37415c9d9382e 100644 --- a/third_party/xla/xla/service/gpu/gpu_compiler_test.cc +++ b/third_party/xla/xla/service/gpu/gpu_compiler_test.cc @@ -85,6 +85,13 @@ class GpuCompilerTest : public HloTestBase { return tensorflow::down_cast(compiler) ->RunPostSchedulingPipelines(module, 4 * 1024 * 1024, gpu_device_info); } + + const stream_executor::GpuComputeCapability& GpuComputeComp() { + return backend() + .default_stream_executor() + ->GetDeviceDescription() + .gpu_compute_capability(); + } }; TEST_F(GpuCompilerTest, CompiledProgramsCount) { @@ -970,6 +977,10 @@ using GpuCompilerPassTest = GpuCompilerTest; TEST_F(GpuCompilerPassTest, GpuCompilerRunsTritonGemmRewriterByDefaultFromAmpere) { + if (std::holds_alternative(GpuComputeComp())) { + GTEST_SKIP() << "TritonGemmRewriter disabled for ROCm until autotuner " + << "is included."; + } auto cc = backend() .default_stream_executor() ->GetDeviceDescription() From 071602bc316a12b61d8e3e2f36506f52bf069f96 Mon Sep 17 00:00:00 2001 From: Clive Verghese Date: Thu, 19 Sep 2024 13:28:57 -0700 Subject: [PATCH 022/483] Add device_utils and generic device types. PiperOrigin-RevId: 676541460 --- .../third_party/tsl/tsl/profiler/utils/BUILD | 24 ++++++++++ .../tsl/tsl/profiler/utils/device_utils.cc | 37 +++++++++++++++ .../tsl/tsl/profiler/utils/device_utils.h | 37 +++++++++++++++ .../tsl/profiler/utils/device_utils_test.cc | 45 +++++++++++++++++++ 4 files changed, 143 insertions(+) create mode 100644 third_party/xla/third_party/tsl/tsl/profiler/utils/device_utils.cc create mode 100644 third_party/xla/third_party/tsl/tsl/profiler/utils/device_utils.h create mode 100644 third_party/xla/third_party/tsl/tsl/profiler/utils/device_utils_test.cc diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/BUILD b/third_party/xla/third_party/tsl/tsl/profiler/utils/BUILD index 8ea8fd71837272..4c0eda1496c5b7 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/BUILD +++ b/third_party/xla/third_party/tsl/tsl/profiler/utils/BUILD @@ -568,3 +568,27 @@ tsl_cc_test( "@com_google_absl//absl/synchronization", ], ) + +cc_library( + name = "device_utils", + srcs = ["device_utils.cc"], + hdrs = ["device_utils.h"], + deps = [ + ":xplane_schema", + "//tsl/profiler/protobuf:xplane_proto_cc", + "@com_google_absl//absl/strings", + ], +) + +tsl_cc_test( + name = "device_utils_test", + srcs = ["device_utils_test.cc"], + deps = [ + ":device_utils", + ":xplane_schema", + "//tsl/platform:test", + "//tsl/platform:test_main", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + ], +) diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/device_utils.cc b/third_party/xla/third_party/tsl/tsl/profiler/utils/device_utils.cc new file mode 100644 index 00000000000000..9caedcc47be08c --- /dev/null +++ b/third_party/xla/third_party/tsl/tsl/profiler/utils/device_utils.cc @@ -0,0 +1,37 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tsl/profiler/utils/device_utils.h" + +#include "absl/strings/match.h" +#include "tsl/profiler/protobuf/xplane.pb.h" +#include "tsl/profiler/utils/xplane_schema.h" + +namespace tsl { +namespace profiler { + +DeviceType GetDeviceType(const tensorflow::profiler::XPlane& plane) { + if (plane.name() == kHostThreadsPlaneName) { + return DeviceType::kCpu; + } else if (absl::StartsWith(plane.name(), kTpuPlanePrefix)) { + return DeviceType::kTpu; + } else if (absl::StartsWith(plane.name(), kGpuPlanePrefix)) { + return DeviceType::kGpu; + } else { + return DeviceType::kUnknown; + } +} +} // namespace profiler +} // namespace tsl diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/device_utils.h b/third_party/xla/third_party/tsl/tsl/profiler/utils/device_utils.h new file mode 100644 index 00000000000000..33c331a0790a6e --- /dev/null +++ b/third_party/xla/third_party/tsl/tsl/profiler/utils/device_utils.h @@ -0,0 +1,37 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_TSL_PROFILER_UTILS_DEVICE_UTILS_H_ +#define TENSORFLOW_TSL_PROFILER_UTILS_DEVICE_UTILS_H_ + +#include "tsl/profiler/protobuf/xplane.pb.h" + +namespace tsl { +namespace profiler { + +enum class DeviceType { + kUnknown, + kCpu, + kTpu, + kGpu, +}; + +// Get DeviceType from XPlane. +DeviceType GetDeviceType(const tensorflow::profiler::XPlane& plane); + +} // namespace profiler +} // namespace tsl + +#endif // TENSORFLOW_TSL_PROFILER_UTILS_DEVICE_UTILS_H_ diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/device_utils_test.cc b/third_party/xla/third_party/tsl/tsl/profiler/utils/device_utils_test.cc new file mode 100644 index 00000000000000..e01680678c2b19 --- /dev/null +++ b/third_party/xla/third_party/tsl/tsl/profiler/utils/device_utils_test.cc @@ -0,0 +1,45 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tsl/profiler/utils/device_utils.h" + +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "tsl/platform/test.h" +#include "tsl/profiler/utils/xplane_schema.h" + +namespace tsl { +namespace profiler { +namespace { + +tensorflow::profiler::XPlane CreateXPlane(absl::string_view name) { + tensorflow::profiler::XPlane plane; + plane.set_name(name.data(), name.size()); + return plane; +} + +TEST(DeviceUtilsTest, GetDeviceType) { + EXPECT_EQ(GetDeviceType(CreateXPlane(kHostThreadsPlaneName)), + DeviceType::kCpu); + EXPECT_EQ(GetDeviceType(CreateXPlane(absl::StrCat(kTpuPlanePrefix, 0))), + DeviceType::kTpu); + EXPECT_EQ(GetDeviceType(CreateXPlane(absl::StrCat(kGpuPlanePrefix, 0))), + DeviceType::kGpu); + EXPECT_EQ(GetDeviceType(CreateXPlane("unknown")), DeviceType::kUnknown); +} + +} // namespace +} // namespace profiler +} // namespace tsl From 30133d4ac565fc0fc1589b5ceeb0dc20d51d6bec Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 19 Sep 2024 13:29:44 -0700 Subject: [PATCH 023/483] Delete all redundant environment variables from RBE configs. These environment variables are not used in any repository rules initialized in RBE configs. PiperOrigin-RevId: 676541760 --- .../toolchains/remote_config/configs.bzl | 96 ------------------- .../toolchains/remote_config/rbe_config.bzl | 32 +------ .../toolchains/remote_config/configs.bzl | 96 ------------------- .../toolchains/remote_config/rbe_config.bzl | 32 +------ .../toolchains/remote_config/configs.bzl | 96 ------------------- .../toolchains/remote_config/rbe_config.bzl | 32 +------ 6 files changed, 12 insertions(+), 372 deletions(-) diff --git a/tensorflow/tools/toolchains/remote_config/configs.bzl b/tensorflow/tools/toolchains/remote_config/configs.bzl index 2feee8960439e9..1182e52997fce0 100644 --- a/tensorflow/tools/toolchains/remote_config/configs.bzl +++ b/tensorflow/tools/toolchains/remote_config/configs.bzl @@ -9,7 +9,6 @@ def initialize_rbe_configs(): tensorflow_rbe_config( name = "ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn8.9", - compiler = "/usr/lib/llvm-18/bin/clang", cuda_version = "12.3.2", cudnn_version = "8.9.7.29", os = "ubuntu20.04-manylinux2014-multipython", @@ -17,7 +16,6 @@ def initialize_rbe_configs(): tensorflow_rbe_config( name = "ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1", - compiler = "/usr/lib/llvm-18/bin/clang", cuda_version = "12.3.2", cudnn_version = "9.1.1", os = "ubuntu20.04-manylinux2014-multipython", @@ -25,8 +23,6 @@ def initialize_rbe_configs(): tensorflow_rbe_config( name = "ubuntu20.04-gcc9_manylinux2014-cuda12.3-cudnn8.9", - compiler = "/dt9/usr/bin/gcc", - compiler_prefix = "/usr/bin", cuda_version = "12.3.2", cudnn_version = "8.9.7.29", os = "ubuntu20.04-manylinux2014-multipython", @@ -34,7 +30,6 @@ def initialize_rbe_configs(): tensorflow_rbe_config( name = "ubuntu22.04-clang_manylinux2014-cuda12.3-cudnn8.9", - compiler = "/usr/lib/llvm-18/bin/clang", cuda_version = "12.3.2", cudnn_version = "8.9.7.29", os = "ubuntu22.04-manylinux2014-multipython", @@ -42,8 +37,6 @@ def initialize_rbe_configs(): tensorflow_rbe_config( name = "ubuntu22.04-gcc9_manylinux2014-cuda12.3-cudnn8.9", - compiler = "/dt9/usr/bin/gcc", - compiler_prefix = "/usr/bin", cuda_version = "12.3.2", cudnn_version = "8.9.7.29", os = "ubuntu22.04-manylinux2014-multipython", @@ -70,25 +63,6 @@ def initialize_rbe_configs(): "sigbuild-r2.16-python3.11": "docker://gcr.io/tensorflow-sigs/build@sha256:842a5ba84d3658c5bf1f8a31e16284f7becc35409da0dfd71816afa3cd28d728", "sigbuild-r2.16-python3.12": "docker://gcr.io/tensorflow-sigs/build@sha256:40fcd1d05c672672b599d9cb3784dcf379d6aa876f043b46c6ab18237d5d4e10", }, - # Unclear why LIBC is set to 2.19 here, and yet manylinux2010 is 2.12 - # and manylinux2014 is 2.17. - env = { - "ABI_LIBC_VERSION": "glibc_2.19", - "ABI_VERSION": "gcc", - "BAZEL_COMPILER": "/dt9/usr/bin/gcc", - "BAZEL_HOST_SYSTEM": "i686-unknown-linux-gnu", - "BAZEL_TARGET_CPU": "k8", - "BAZEL_TARGET_LIBC": "glibc_2.19", - "BAZEL_TARGET_SYSTEM": "x86_64-unknown-linux-gnu", - "CC": "/dt9/usr/bin/gcc", - "CC_TOOLCHAIN_NAME": "linux_gnu_x86", - "CLEAR_CACHE": "1", - "GCC_HOST_COMPILER_PATH": "/dt9/usr/bin/gcc", - "GCC_HOST_COMPILER_PREFIX": "/usr/bin", - "HOST_CXX_COMPILER": "/dt9/usr/bin/gcc", - "HOST_C_COMPILER": "/dt9/usr/bin/gcc", - "TF_ENABLE_XLA": "1", - }, ) sigbuild_tf_configs( @@ -99,23 +73,6 @@ def initialize_rbe_configs(): "sigbuild-r2.16-clang-python3.11": "docker://gcr.io/tensorflow-sigs/build@sha256:842a5ba84d3658c5bf1f8a31e16284f7becc35409da0dfd71816afa3cd28d728", "sigbuild-r2.16-clang-python3.12": "docker://gcr.io/tensorflow-sigs/build@sha256:40fcd1d05c672672b599d9cb3784dcf379d6aa876f043b46c6ab18237d5d4e10", }, - # Unclear why LIBC is set to 2.19 here, and yet manylinux2010 is 2.12 - # and manylinux2014 is 2.17. - env = { - "ABI_LIBC_VERSION": "glibc_2.19", - "ABI_VERSION": "gcc", - "BAZEL_COMPILER": "/usr/lib/llvm-17/bin/clang", - "BAZEL_HOST_SYSTEM": "i686-unknown-linux-gnu", - "BAZEL_TARGET_CPU": "k8", - "BAZEL_TARGET_LIBC": "glibc_2.19", - "BAZEL_TARGET_SYSTEM": "x86_64-unknown-linux-gnu", - "CC": "/usr/lib/llvm-17/bin/clang", - "CC_TOOLCHAIN_NAME": "linux_gnu_x86", - "CLEAR_CACHE": "1", - "HOST_CXX_COMPILER": "/usr/lib/llvm-17/bin/clang", - "HOST_C_COMPILER": "/usr/lib/llvm-17/bin/clang", - "TF_ENABLE_XLA": "1", - }, ) sigbuild_tf_configs( @@ -126,25 +83,6 @@ def initialize_rbe_configs(): "sigbuild-r2.17-python3.11": "docker://gcr.io/tensorflow-sigs/build@sha256:b6f572a897a69fa3311773f949b9aa9e81bc393e4fbe2c0d56d8afb03a6de080", "sigbuild-r2.17-python3.12": "docker://gcr.io/tensorflow-sigs/build@sha256:8b856ad736147bb9c8bc9e1ec2c8e1ab17d36397905da7a5b63dadeff9310f0c", }, - # Unclear why LIBC is set to 2.19 here, and yet manylinux2010 is 2.12 - # and manylinux2014 is 2.17. - env = { - "ABI_LIBC_VERSION": "glibc_2.19", - "ABI_VERSION": "gcc", - "BAZEL_COMPILER": "/dt9/usr/bin/gcc", - "BAZEL_HOST_SYSTEM": "i686-unknown-linux-gnu", - "BAZEL_TARGET_CPU": "k8", - "BAZEL_TARGET_LIBC": "glibc_2.19", - "BAZEL_TARGET_SYSTEM": "x86_64-unknown-linux-gnu", - "CC": "/dt9/usr/bin/gcc", - "CC_TOOLCHAIN_NAME": "linux_gnu_x86", - "CLEAR_CACHE": "1", - "GCC_HOST_COMPILER_PATH": "/dt9/usr/bin/gcc", - "GCC_HOST_COMPILER_PREFIX": "/usr/bin", - "HOST_CXX_COMPILER": "/dt9/usr/bin/gcc", - "HOST_C_COMPILER": "/dt9/usr/bin/gcc", - "TF_ENABLE_XLA": "1", - }, ) sigbuild_tf_configs( @@ -155,23 +93,6 @@ def initialize_rbe_configs(): "sigbuild-r2.17-clang-python3.11": "docker://gcr.io/tensorflow-sigs/build@sha256:2d737fc9fe931507a89927eee792b1bb934215e6aaae58b1941586e3400e2645", "sigbuild-r2.17-clang-python3.12": "docker://gcr.io/tensorflow-sigs/build@sha256:45ea78e79305f91cdae5a26094f80233bba54bbfbc612623381012f097035b9a", }, - # Unclear why LIBC is set to 2.19 here, and yet manylinux2010 is 2.12 - # and manylinux2014 is 2.17. - env = { - "ABI_LIBC_VERSION": "glibc_2.19", - "ABI_VERSION": "gcc", - "BAZEL_COMPILER": "/usr/lib/llvm-18/bin/clang", - "BAZEL_HOST_SYSTEM": "i686-unknown-linux-gnu", - "BAZEL_TARGET_CPU": "k8", - "BAZEL_TARGET_LIBC": "glibc_2.19", - "BAZEL_TARGET_SYSTEM": "x86_64-unknown-linux-gnu", - "CC": "/usr/lib/llvm-18/bin/clang", - "CC_TOOLCHAIN_NAME": "linux_gnu_x86", - "CLEAR_CACHE": "1", - "HOST_CXX_COMPILER": "/usr/lib/llvm-18/bin/clang", - "HOST_C_COMPILER": "/usr/lib/llvm-18/bin/clang", - "TF_ENABLE_XLA": "1", - }, ) sigbuild_tf_configs( @@ -182,21 +103,4 @@ def initialize_rbe_configs(): "sigbuild-r2.17-clang-cudnn9-python3.11": "docker://gcr.io/tensorflow-sigs/build@sha256:daa5bdd802fe3def188e2200ed707c73d278f6f1930bf26c933d6ba041b0e027", "sigbuild-r2.17-clang-cudnn9-python3.12": "docker://gcr.io/tensorflow-sigs/build@sha256:23e477895dd02e45df1056d4a0a9c4229dec3a20c23fb2f3fb5832ecbd0a29bc", }, - # Unclear why LIBC is set to 2.19 here, and yet manylinux2010 is 2.12 - # and manylinux2014 is 2.17. - env = { - "ABI_LIBC_VERSION": "glibc_2.19", - "ABI_VERSION": "gcc", - "BAZEL_COMPILER": "/usr/lib/llvm-18/bin/clang", - "BAZEL_HOST_SYSTEM": "i686-unknown-linux-gnu", - "BAZEL_TARGET_CPU": "k8", - "BAZEL_TARGET_LIBC": "glibc_2.19", - "BAZEL_TARGET_SYSTEM": "x86_64-unknown-linux-gnu", - "CC": "/usr/lib/llvm-18/bin/clang", - "CC_TOOLCHAIN_NAME": "linux_gnu_x86", - "CLEAR_CACHE": "1", - "HOST_CXX_COMPILER": "/usr/lib/llvm-18/bin/clang", - "HOST_C_COMPILER": "/usr/lib/llvm-18/bin/clang", - "TF_ENABLE_XLA": "1", - }, ) diff --git a/tensorflow/tools/toolchains/remote_config/rbe_config.bzl b/tensorflow/tools/toolchains/remote_config/rbe_config.bzl index a916c10e77d634..8a6120efbbd69d 100644 --- a/tensorflow/tools/toolchains/remote_config/rbe_config.bzl +++ b/tensorflow/tools/toolchains/remote_config/rbe_config.bzl @@ -9,34 +9,13 @@ def _container_image_uri(container_name): container = containers[container_name] return "docker://%s/%s@%s" % (container["registry"], container["repository"], container["digest"]) -def _tensorflow_rbe_config(name, compiler, os, rocm_version = None, cuda_version = None, cudnn_version = None, compiler_prefix = None): +def _tensorflow_rbe_config(name, os, rocm_version = None, cuda_version = None, cudnn_version = None): if cuda_version != None and rocm_version != None: fail("Specifying both cuda_version and rocm_version is not supported.") - env = { - "ABI_VERSION": "gcc", - "ABI_LIBC_VERSION": "glibc_2.19", - "BAZEL_COMPILER": compiler, - "BAZEL_HOST_SYSTEM": "i686-unknown-linux-gnu", - "BAZEL_TARGET_LIBC": "glibc_2.19", - "BAZEL_TARGET_CPU": "k8", - "BAZEL_TARGET_SYSTEM": "x86_64-unknown-linux-gnu", - "CC_TOOLCHAIN_NAME": "linux_gnu_x86", - "CC": compiler, - "CLEAR_CACHE": "1", - "HOST_CXX_COMPILER": compiler, - "HOST_C_COMPILER": compiler, - } + env = {} if cuda_version != None: - # The cuda toolchain currently contains its own C++ toolchain definition, - # so we do not fetch local_config_cc. - env.update({ - "TF_ENABLE_XLA": "1", - "GCC_HOST_COMPILER_PATH": compiler if not compiler.endswith("clang") else "", - "GCC_HOST_COMPILER_PREFIX": compiler_prefix if compiler_prefix != None else "/usr/bin", - }) - cuda_version_in_container = ".".join(cuda_version.split(".")[:2]) cudnn_version_in_container = ".".join(cudnn_version.split(".")[:2]) container_name = "cuda%s-cudnn%s-%s" % ( @@ -49,13 +28,11 @@ def _tensorflow_rbe_config(name, compiler, os, rocm_version = None, cuda_version "container-image": container_image, "Pool": "default", } - elif rocm_version != None: # The rocm toolchain currently contains its own C++ toolchain definition, # so we do not fetch local_config_cc. env.update({ "TF_NEED_ROCM": "1", - "TF_ENABLE_XLA": "0", }) container_name = "rocm-%s" % (os) @@ -121,9 +98,8 @@ tensorflow_local_config = _tensorflow_local_config # Streamlined platform configuration for the SIG Build containers. # See //tensorflow/tools/tf_sig_build_dockerfiles -# These containers do not support ROCm and all have CUDA. We demand that the configuration -# provide all the env variables to remove hidden logic. -def sigbuild_tf_configs(name_container_map, env): +# These containers do not support ROCm and all have CUDA. +def sigbuild_tf_configs(name_container_map): for name, container in name_container_map.items(): exec_properties = { "container-image": container, diff --git a/third_party/xla/third_party/tsl/tools/toolchains/remote_config/configs.bzl b/third_party/xla/third_party/tsl/tools/toolchains/remote_config/configs.bzl index 492d591d208a81..83f52d9af9970a 100644 --- a/third_party/xla/third_party/tsl/tools/toolchains/remote_config/configs.bzl +++ b/third_party/xla/third_party/tsl/tools/toolchains/remote_config/configs.bzl @@ -9,7 +9,6 @@ def initialize_rbe_configs(): tensorflow_rbe_config( name = "ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn8.9", - compiler = "/usr/lib/llvm-18/bin/clang", cuda_version = "12.3.2", cudnn_version = "8.9.7.29", os = "ubuntu20.04-manylinux2014-multipython", @@ -17,7 +16,6 @@ def initialize_rbe_configs(): tensorflow_rbe_config( name = "ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1", - compiler = "/usr/lib/llvm-18/bin/clang", cuda_version = "12.3.2", cudnn_version = "9.1.1", os = "ubuntu20.04-manylinux2014-multipython", @@ -25,8 +23,6 @@ def initialize_rbe_configs(): tensorflow_rbe_config( name = "ubuntu20.04-gcc9_manylinux2014-cuda12.3-cudnn8.9", - compiler = "/dt9/usr/bin/gcc", - compiler_prefix = "/usr/bin", cuda_version = "12.3.2", cudnn_version = "8.9.7.29", os = "ubuntu20.04-manylinux2014-multipython", @@ -34,7 +30,6 @@ def initialize_rbe_configs(): tensorflow_rbe_config( name = "ubuntu22.04-clang_manylinux2014-cuda12.3-cudnn8.9", - compiler = "/usr/lib/llvm-18/bin/clang", cuda_version = "12.3.2", cudnn_version = "8.9.7.29", os = "ubuntu22.04-manylinux2014-multipython", @@ -42,8 +37,6 @@ def initialize_rbe_configs(): tensorflow_rbe_config( name = "ubuntu22.04-gcc9_manylinux2014-cuda12.3-cudnn8.9", - compiler = "/dt9/usr/bin/gcc", - compiler_prefix = "/usr/bin", cuda_version = "12.3.2", cudnn_version = "8.9.7.29", os = "ubuntu22.04-manylinux2014-multipython", @@ -70,25 +63,6 @@ def initialize_rbe_configs(): "sigbuild-r2.16-python3.11": "docker://gcr.io/tensorflow-sigs/build@sha256:842a5ba84d3658c5bf1f8a31e16284f7becc35409da0dfd71816afa3cd28d728", "sigbuild-r2.16-python3.12": "docker://gcr.io/tensorflow-sigs/build@sha256:40fcd1d05c672672b599d9cb3784dcf379d6aa876f043b46c6ab18237d5d4e10", }, - # Unclear why LIBC is set to 2.19 here, and yet manylinux2010 is 2.12 - # and manylinux2014 is 2.17. - env = { - "ABI_LIBC_VERSION": "glibc_2.19", - "ABI_VERSION": "gcc", - "BAZEL_COMPILER": "/dt9/usr/bin/gcc", - "BAZEL_HOST_SYSTEM": "i686-unknown-linux-gnu", - "BAZEL_TARGET_CPU": "k8", - "BAZEL_TARGET_LIBC": "glibc_2.19", - "BAZEL_TARGET_SYSTEM": "x86_64-unknown-linux-gnu", - "CC": "/dt9/usr/bin/gcc", - "CC_TOOLCHAIN_NAME": "linux_gnu_x86", - "CLEAR_CACHE": "1", - "GCC_HOST_COMPILER_PATH": "/dt9/usr/bin/gcc", - "GCC_HOST_COMPILER_PREFIX": "/usr/bin", - "HOST_CXX_COMPILER": "/dt9/usr/bin/gcc", - "HOST_C_COMPILER": "/dt9/usr/bin/gcc", - "TF_ENABLE_XLA": "1", - }, ) sigbuild_tf_configs( @@ -99,23 +73,6 @@ def initialize_rbe_configs(): "sigbuild-r2.16-clang-python3.11": "docker://gcr.io/tensorflow-sigs/build@sha256:842a5ba84d3658c5bf1f8a31e16284f7becc35409da0dfd71816afa3cd28d728", "sigbuild-r2.16-clang-python3.12": "docker://gcr.io/tensorflow-sigs/build@sha256:40fcd1d05c672672b599d9cb3784dcf379d6aa876f043b46c6ab18237d5d4e10", }, - # Unclear why LIBC is set to 2.19 here, and yet manylinux2010 is 2.12 - # and manylinux2014 is 2.17. - env = { - "ABI_LIBC_VERSION": "glibc_2.19", - "ABI_VERSION": "gcc", - "BAZEL_COMPILER": "/usr/lib/llvm-17/bin/clang", - "BAZEL_HOST_SYSTEM": "i686-unknown-linux-gnu", - "BAZEL_TARGET_CPU": "k8", - "BAZEL_TARGET_LIBC": "glibc_2.19", - "BAZEL_TARGET_SYSTEM": "x86_64-unknown-linux-gnu", - "CC": "/usr/lib/llvm-17/bin/clang", - "CC_TOOLCHAIN_NAME": "linux_gnu_x86", - "CLEAR_CACHE": "1", - "HOST_CXX_COMPILER": "/usr/lib/llvm-17/bin/clang", - "HOST_C_COMPILER": "/usr/lib/llvm-17/bin/clang", - "TF_ENABLE_XLA": "1", - }, ) sigbuild_tf_configs( @@ -126,25 +83,6 @@ def initialize_rbe_configs(): "sigbuild-r2.17-python3.11": "docker://gcr.io/tensorflow-sigs/build@sha256:b6f572a897a69fa3311773f949b9aa9e81bc393e4fbe2c0d56d8afb03a6de080", "sigbuild-r2.17-python3.12": "docker://gcr.io/tensorflow-sigs/build@sha256:8b856ad736147bb9c8bc9e1ec2c8e1ab17d36397905da7a5b63dadeff9310f0c", }, - # Unclear why LIBC is set to 2.19 here, and yet manylinux2010 is 2.12 - # and manylinux2014 is 2.17. - env = { - "ABI_LIBC_VERSION": "glibc_2.19", - "ABI_VERSION": "gcc", - "BAZEL_COMPILER": "/dt9/usr/bin/gcc", - "BAZEL_HOST_SYSTEM": "i686-unknown-linux-gnu", - "BAZEL_TARGET_CPU": "k8", - "BAZEL_TARGET_LIBC": "glibc_2.19", - "BAZEL_TARGET_SYSTEM": "x86_64-unknown-linux-gnu", - "CC": "/dt9/usr/bin/gcc", - "CC_TOOLCHAIN_NAME": "linux_gnu_x86", - "CLEAR_CACHE": "1", - "GCC_HOST_COMPILER_PATH": "/dt9/usr/bin/gcc", - "GCC_HOST_COMPILER_PREFIX": "/usr/bin", - "HOST_CXX_COMPILER": "/dt9/usr/bin/gcc", - "HOST_C_COMPILER": "/dt9/usr/bin/gcc", - "TF_ENABLE_XLA": "1", - }, ) sigbuild_tf_configs( @@ -155,23 +93,6 @@ def initialize_rbe_configs(): "sigbuild-r2.17-clang-python3.11": "docker://gcr.io/tensorflow-sigs/build@sha256:2d737fc9fe931507a89927eee792b1bb934215e6aaae58b1941586e3400e2645", "sigbuild-r2.17-clang-python3.12": "docker://gcr.io/tensorflow-sigs/build@sha256:45ea78e79305f91cdae5a26094f80233bba54bbfbc612623381012f097035b9a", }, - # Unclear why LIBC is set to 2.19 here, and yet manylinux2010 is 2.12 - # and manylinux2014 is 2.17. - env = { - "ABI_LIBC_VERSION": "glibc_2.19", - "ABI_VERSION": "gcc", - "BAZEL_COMPILER": "/usr/lib/llvm-18/bin/clang", - "BAZEL_HOST_SYSTEM": "i686-unknown-linux-gnu", - "BAZEL_TARGET_CPU": "k8", - "BAZEL_TARGET_LIBC": "glibc_2.19", - "BAZEL_TARGET_SYSTEM": "x86_64-unknown-linux-gnu", - "CC": "/usr/lib/llvm-18/bin/clang", - "CC_TOOLCHAIN_NAME": "linux_gnu_x86", - "CLEAR_CACHE": "1", - "HOST_CXX_COMPILER": "/usr/lib/llvm-18/bin/clang", - "HOST_C_COMPILER": "/usr/lib/llvm-18/bin/clang", - "TF_ENABLE_XLA": "1", - }, ) sigbuild_tf_configs( @@ -182,21 +103,4 @@ def initialize_rbe_configs(): "sigbuild-r2.17-clang-cudnn9-python3.11": "docker://gcr.io/tensorflow-sigs/build@sha256:daa5bdd802fe3def188e2200ed707c73d278f6f1930bf26c933d6ba041b0e027", "sigbuild-r2.17-clang-cudnn9-python3.12": "docker://gcr.io/tensorflow-sigs/build@sha256:23e477895dd02e45df1056d4a0a9c4229dec3a20c23fb2f3fb5832ecbd0a29bc", }, - # Unclear why LIBC is set to 2.19 here, and yet manylinux2010 is 2.12 - # and manylinux2014 is 2.17. - env = { - "ABI_LIBC_VERSION": "glibc_2.19", - "ABI_VERSION": "gcc", - "BAZEL_COMPILER": "/usr/lib/llvm-18/bin/clang", - "BAZEL_HOST_SYSTEM": "i686-unknown-linux-gnu", - "BAZEL_TARGET_CPU": "k8", - "BAZEL_TARGET_LIBC": "glibc_2.19", - "BAZEL_TARGET_SYSTEM": "x86_64-unknown-linux-gnu", - "CC": "/usr/lib/llvm-18/bin/clang", - "CC_TOOLCHAIN_NAME": "linux_gnu_x86", - "CLEAR_CACHE": "1", - "HOST_CXX_COMPILER": "/usr/lib/llvm-18/bin/clang", - "HOST_C_COMPILER": "/usr/lib/llvm-18/bin/clang", - "TF_ENABLE_XLA": "1", - }, ) diff --git a/third_party/xla/third_party/tsl/tools/toolchains/remote_config/rbe_config.bzl b/third_party/xla/third_party/tsl/tools/toolchains/remote_config/rbe_config.bzl index fe29d52ad1bef9..280b8d914283dd 100644 --- a/third_party/xla/third_party/tsl/tools/toolchains/remote_config/rbe_config.bzl +++ b/third_party/xla/third_party/tsl/tools/toolchains/remote_config/rbe_config.bzl @@ -9,34 +9,13 @@ def _container_image_uri(container_name): container = containers[container_name] return "docker://%s/%s@%s" % (container["registry"], container["repository"], container["digest"]) -def _tensorflow_rbe_config(name, compiler, os, rocm_version = None, cuda_version = None, cudnn_version = None, compiler_prefix = None): +def _tensorflow_rbe_config(name, os, rocm_version = None, cuda_version = None, cudnn_version = None): if cuda_version != None and rocm_version != None: fail("Specifying both cuda_version and rocm_version is not supported.") - env = { - "ABI_VERSION": "gcc", - "ABI_LIBC_VERSION": "glibc_2.19", - "BAZEL_COMPILER": compiler, - "BAZEL_HOST_SYSTEM": "i686-unknown-linux-gnu", - "BAZEL_TARGET_LIBC": "glibc_2.19", - "BAZEL_TARGET_CPU": "k8", - "BAZEL_TARGET_SYSTEM": "x86_64-unknown-linux-gnu", - "CC_TOOLCHAIN_NAME": "linux_gnu_x86", - "CC": compiler, - "CLEAR_CACHE": "1", - "HOST_CXX_COMPILER": compiler, - "HOST_C_COMPILER": compiler, - } + env = {} if cuda_version != None: - # The cuda toolchain currently contains its own C++ toolchain definition, - # so we do not fetch local_config_cc. - env.update({ - "TF_ENABLE_XLA": "1", - "GCC_HOST_COMPILER_PATH": compiler if not compiler.endswith("clang") else "", - "GCC_HOST_COMPILER_PREFIX": compiler_prefix if compiler_prefix != None else "/usr/bin", - }) - cuda_version_in_container = ".".join(cuda_version.split(".")[:2]) cudnn_version_in_container = ".".join(cudnn_version.split(".")[:2]) container_name = "cuda%s-cudnn%s-%s" % ( @@ -49,13 +28,11 @@ def _tensorflow_rbe_config(name, compiler, os, rocm_version = None, cuda_version "container-image": container_image, "Pool": "default", } - elif rocm_version != None: # The rocm toolchain currently contains its own C++ toolchain definition, # so we do not fetch local_config_cc. env.update({ "TF_NEED_ROCM": "1", - "TF_ENABLE_XLA": "0", }) container_name = "rocm-%s" % (os) @@ -121,9 +98,8 @@ tensorflow_local_config = _tensorflow_local_config # Streamlined platform configuration for the SIG Build containers. # See //tensorflow/tools/tf_sig_build_dockerfiles -# These containers do not support ROCm and all have CUDA. We demand that the configuration -# provide all the env variables to remove hidden logic. -def sigbuild_tf_configs(name_container_map, env): +# These containers do not support ROCm and all have CUDA. +def sigbuild_tf_configs(name_container_map): for name, container in name_container_map.items(): exec_properties = { "container-image": container, diff --git a/third_party/xla/tools/toolchains/remote_config/configs.bzl b/third_party/xla/tools/toolchains/remote_config/configs.bzl index 492d591d208a81..83f52d9af9970a 100644 --- a/third_party/xla/tools/toolchains/remote_config/configs.bzl +++ b/third_party/xla/tools/toolchains/remote_config/configs.bzl @@ -9,7 +9,6 @@ def initialize_rbe_configs(): tensorflow_rbe_config( name = "ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn8.9", - compiler = "/usr/lib/llvm-18/bin/clang", cuda_version = "12.3.2", cudnn_version = "8.9.7.29", os = "ubuntu20.04-manylinux2014-multipython", @@ -17,7 +16,6 @@ def initialize_rbe_configs(): tensorflow_rbe_config( name = "ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1", - compiler = "/usr/lib/llvm-18/bin/clang", cuda_version = "12.3.2", cudnn_version = "9.1.1", os = "ubuntu20.04-manylinux2014-multipython", @@ -25,8 +23,6 @@ def initialize_rbe_configs(): tensorflow_rbe_config( name = "ubuntu20.04-gcc9_manylinux2014-cuda12.3-cudnn8.9", - compiler = "/dt9/usr/bin/gcc", - compiler_prefix = "/usr/bin", cuda_version = "12.3.2", cudnn_version = "8.9.7.29", os = "ubuntu20.04-manylinux2014-multipython", @@ -34,7 +30,6 @@ def initialize_rbe_configs(): tensorflow_rbe_config( name = "ubuntu22.04-clang_manylinux2014-cuda12.3-cudnn8.9", - compiler = "/usr/lib/llvm-18/bin/clang", cuda_version = "12.3.2", cudnn_version = "8.9.7.29", os = "ubuntu22.04-manylinux2014-multipython", @@ -42,8 +37,6 @@ def initialize_rbe_configs(): tensorflow_rbe_config( name = "ubuntu22.04-gcc9_manylinux2014-cuda12.3-cudnn8.9", - compiler = "/dt9/usr/bin/gcc", - compiler_prefix = "/usr/bin", cuda_version = "12.3.2", cudnn_version = "8.9.7.29", os = "ubuntu22.04-manylinux2014-multipython", @@ -70,25 +63,6 @@ def initialize_rbe_configs(): "sigbuild-r2.16-python3.11": "docker://gcr.io/tensorflow-sigs/build@sha256:842a5ba84d3658c5bf1f8a31e16284f7becc35409da0dfd71816afa3cd28d728", "sigbuild-r2.16-python3.12": "docker://gcr.io/tensorflow-sigs/build@sha256:40fcd1d05c672672b599d9cb3784dcf379d6aa876f043b46c6ab18237d5d4e10", }, - # Unclear why LIBC is set to 2.19 here, and yet manylinux2010 is 2.12 - # and manylinux2014 is 2.17. - env = { - "ABI_LIBC_VERSION": "glibc_2.19", - "ABI_VERSION": "gcc", - "BAZEL_COMPILER": "/dt9/usr/bin/gcc", - "BAZEL_HOST_SYSTEM": "i686-unknown-linux-gnu", - "BAZEL_TARGET_CPU": "k8", - "BAZEL_TARGET_LIBC": "glibc_2.19", - "BAZEL_TARGET_SYSTEM": "x86_64-unknown-linux-gnu", - "CC": "/dt9/usr/bin/gcc", - "CC_TOOLCHAIN_NAME": "linux_gnu_x86", - "CLEAR_CACHE": "1", - "GCC_HOST_COMPILER_PATH": "/dt9/usr/bin/gcc", - "GCC_HOST_COMPILER_PREFIX": "/usr/bin", - "HOST_CXX_COMPILER": "/dt9/usr/bin/gcc", - "HOST_C_COMPILER": "/dt9/usr/bin/gcc", - "TF_ENABLE_XLA": "1", - }, ) sigbuild_tf_configs( @@ -99,23 +73,6 @@ def initialize_rbe_configs(): "sigbuild-r2.16-clang-python3.11": "docker://gcr.io/tensorflow-sigs/build@sha256:842a5ba84d3658c5bf1f8a31e16284f7becc35409da0dfd71816afa3cd28d728", "sigbuild-r2.16-clang-python3.12": "docker://gcr.io/tensorflow-sigs/build@sha256:40fcd1d05c672672b599d9cb3784dcf379d6aa876f043b46c6ab18237d5d4e10", }, - # Unclear why LIBC is set to 2.19 here, and yet manylinux2010 is 2.12 - # and manylinux2014 is 2.17. - env = { - "ABI_LIBC_VERSION": "glibc_2.19", - "ABI_VERSION": "gcc", - "BAZEL_COMPILER": "/usr/lib/llvm-17/bin/clang", - "BAZEL_HOST_SYSTEM": "i686-unknown-linux-gnu", - "BAZEL_TARGET_CPU": "k8", - "BAZEL_TARGET_LIBC": "glibc_2.19", - "BAZEL_TARGET_SYSTEM": "x86_64-unknown-linux-gnu", - "CC": "/usr/lib/llvm-17/bin/clang", - "CC_TOOLCHAIN_NAME": "linux_gnu_x86", - "CLEAR_CACHE": "1", - "HOST_CXX_COMPILER": "/usr/lib/llvm-17/bin/clang", - "HOST_C_COMPILER": "/usr/lib/llvm-17/bin/clang", - "TF_ENABLE_XLA": "1", - }, ) sigbuild_tf_configs( @@ -126,25 +83,6 @@ def initialize_rbe_configs(): "sigbuild-r2.17-python3.11": "docker://gcr.io/tensorflow-sigs/build@sha256:b6f572a897a69fa3311773f949b9aa9e81bc393e4fbe2c0d56d8afb03a6de080", "sigbuild-r2.17-python3.12": "docker://gcr.io/tensorflow-sigs/build@sha256:8b856ad736147bb9c8bc9e1ec2c8e1ab17d36397905da7a5b63dadeff9310f0c", }, - # Unclear why LIBC is set to 2.19 here, and yet manylinux2010 is 2.12 - # and manylinux2014 is 2.17. - env = { - "ABI_LIBC_VERSION": "glibc_2.19", - "ABI_VERSION": "gcc", - "BAZEL_COMPILER": "/dt9/usr/bin/gcc", - "BAZEL_HOST_SYSTEM": "i686-unknown-linux-gnu", - "BAZEL_TARGET_CPU": "k8", - "BAZEL_TARGET_LIBC": "glibc_2.19", - "BAZEL_TARGET_SYSTEM": "x86_64-unknown-linux-gnu", - "CC": "/dt9/usr/bin/gcc", - "CC_TOOLCHAIN_NAME": "linux_gnu_x86", - "CLEAR_CACHE": "1", - "GCC_HOST_COMPILER_PATH": "/dt9/usr/bin/gcc", - "GCC_HOST_COMPILER_PREFIX": "/usr/bin", - "HOST_CXX_COMPILER": "/dt9/usr/bin/gcc", - "HOST_C_COMPILER": "/dt9/usr/bin/gcc", - "TF_ENABLE_XLA": "1", - }, ) sigbuild_tf_configs( @@ -155,23 +93,6 @@ def initialize_rbe_configs(): "sigbuild-r2.17-clang-python3.11": "docker://gcr.io/tensorflow-sigs/build@sha256:2d737fc9fe931507a89927eee792b1bb934215e6aaae58b1941586e3400e2645", "sigbuild-r2.17-clang-python3.12": "docker://gcr.io/tensorflow-sigs/build@sha256:45ea78e79305f91cdae5a26094f80233bba54bbfbc612623381012f097035b9a", }, - # Unclear why LIBC is set to 2.19 here, and yet manylinux2010 is 2.12 - # and manylinux2014 is 2.17. - env = { - "ABI_LIBC_VERSION": "glibc_2.19", - "ABI_VERSION": "gcc", - "BAZEL_COMPILER": "/usr/lib/llvm-18/bin/clang", - "BAZEL_HOST_SYSTEM": "i686-unknown-linux-gnu", - "BAZEL_TARGET_CPU": "k8", - "BAZEL_TARGET_LIBC": "glibc_2.19", - "BAZEL_TARGET_SYSTEM": "x86_64-unknown-linux-gnu", - "CC": "/usr/lib/llvm-18/bin/clang", - "CC_TOOLCHAIN_NAME": "linux_gnu_x86", - "CLEAR_CACHE": "1", - "HOST_CXX_COMPILER": "/usr/lib/llvm-18/bin/clang", - "HOST_C_COMPILER": "/usr/lib/llvm-18/bin/clang", - "TF_ENABLE_XLA": "1", - }, ) sigbuild_tf_configs( @@ -182,21 +103,4 @@ def initialize_rbe_configs(): "sigbuild-r2.17-clang-cudnn9-python3.11": "docker://gcr.io/tensorflow-sigs/build@sha256:daa5bdd802fe3def188e2200ed707c73d278f6f1930bf26c933d6ba041b0e027", "sigbuild-r2.17-clang-cudnn9-python3.12": "docker://gcr.io/tensorflow-sigs/build@sha256:23e477895dd02e45df1056d4a0a9c4229dec3a20c23fb2f3fb5832ecbd0a29bc", }, - # Unclear why LIBC is set to 2.19 here, and yet manylinux2010 is 2.12 - # and manylinux2014 is 2.17. - env = { - "ABI_LIBC_VERSION": "glibc_2.19", - "ABI_VERSION": "gcc", - "BAZEL_COMPILER": "/usr/lib/llvm-18/bin/clang", - "BAZEL_HOST_SYSTEM": "i686-unknown-linux-gnu", - "BAZEL_TARGET_CPU": "k8", - "BAZEL_TARGET_LIBC": "glibc_2.19", - "BAZEL_TARGET_SYSTEM": "x86_64-unknown-linux-gnu", - "CC": "/usr/lib/llvm-18/bin/clang", - "CC_TOOLCHAIN_NAME": "linux_gnu_x86", - "CLEAR_CACHE": "1", - "HOST_CXX_COMPILER": "/usr/lib/llvm-18/bin/clang", - "HOST_C_COMPILER": "/usr/lib/llvm-18/bin/clang", - "TF_ENABLE_XLA": "1", - }, ) diff --git a/third_party/xla/tools/toolchains/remote_config/rbe_config.bzl b/third_party/xla/tools/toolchains/remote_config/rbe_config.bzl index fe29d52ad1bef9..280b8d914283dd 100644 --- a/third_party/xla/tools/toolchains/remote_config/rbe_config.bzl +++ b/third_party/xla/tools/toolchains/remote_config/rbe_config.bzl @@ -9,34 +9,13 @@ def _container_image_uri(container_name): container = containers[container_name] return "docker://%s/%s@%s" % (container["registry"], container["repository"], container["digest"]) -def _tensorflow_rbe_config(name, compiler, os, rocm_version = None, cuda_version = None, cudnn_version = None, compiler_prefix = None): +def _tensorflow_rbe_config(name, os, rocm_version = None, cuda_version = None, cudnn_version = None): if cuda_version != None and rocm_version != None: fail("Specifying both cuda_version and rocm_version is not supported.") - env = { - "ABI_VERSION": "gcc", - "ABI_LIBC_VERSION": "glibc_2.19", - "BAZEL_COMPILER": compiler, - "BAZEL_HOST_SYSTEM": "i686-unknown-linux-gnu", - "BAZEL_TARGET_LIBC": "glibc_2.19", - "BAZEL_TARGET_CPU": "k8", - "BAZEL_TARGET_SYSTEM": "x86_64-unknown-linux-gnu", - "CC_TOOLCHAIN_NAME": "linux_gnu_x86", - "CC": compiler, - "CLEAR_CACHE": "1", - "HOST_CXX_COMPILER": compiler, - "HOST_C_COMPILER": compiler, - } + env = {} if cuda_version != None: - # The cuda toolchain currently contains its own C++ toolchain definition, - # so we do not fetch local_config_cc. - env.update({ - "TF_ENABLE_XLA": "1", - "GCC_HOST_COMPILER_PATH": compiler if not compiler.endswith("clang") else "", - "GCC_HOST_COMPILER_PREFIX": compiler_prefix if compiler_prefix != None else "/usr/bin", - }) - cuda_version_in_container = ".".join(cuda_version.split(".")[:2]) cudnn_version_in_container = ".".join(cudnn_version.split(".")[:2]) container_name = "cuda%s-cudnn%s-%s" % ( @@ -49,13 +28,11 @@ def _tensorflow_rbe_config(name, compiler, os, rocm_version = None, cuda_version "container-image": container_image, "Pool": "default", } - elif rocm_version != None: # The rocm toolchain currently contains its own C++ toolchain definition, # so we do not fetch local_config_cc. env.update({ "TF_NEED_ROCM": "1", - "TF_ENABLE_XLA": "0", }) container_name = "rocm-%s" % (os) @@ -121,9 +98,8 @@ tensorflow_local_config = _tensorflow_local_config # Streamlined platform configuration for the SIG Build containers. # See //tensorflow/tools/tf_sig_build_dockerfiles -# These containers do not support ROCm and all have CUDA. We demand that the configuration -# provide all the env variables to remove hidden logic. -def sigbuild_tf_configs(name_container_map, env): +# These containers do not support ROCm and all have CUDA. +def sigbuild_tf_configs(name_container_map): for name, container in name_container_map.items(): exec_properties = { "container-image": container, From 2bc6a68fd7de1655349b2b5cee337637a1c41ba3 Mon Sep 17 00:00:00 2001 From: Ziyin Huang Date: Thu, 19 Sep 2024 14:35:59 -0700 Subject: [PATCH 024/483] Add an option to enable fast table initialization. PiperOrigin-RevId: 676567135 --- tensorflow/python/tpu/tpu_embedding_v3.py | 10 ++++++++++ ...ental.embedding.-sparse-core-embedding-config.pbtxt | 6 +++++- ...ental.embedding.-sparse-core-embedding-config.pbtxt | 6 +++++- 3 files changed, 20 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/tpu/tpu_embedding_v3.py b/tensorflow/python/tpu/tpu_embedding_v3.py index 1536801f77a74c..c822ee9ddae177 100644 --- a/tensorflow/python/tpu/tpu_embedding_v3.py +++ b/tensorflow/python/tpu/tpu_embedding_v3.py @@ -84,6 +84,7 @@ class SparseCoreEmbeddingConfig: max_unique_ids_per_table: Optional[Dict[str, int]] = None allow_id_dropping: bool = False initialize_tables_on_host: bool = True + enable_fast_table_initialization: bool = False class EmbeddingPipeliningContext(control_flow_ops.ControlFlowContext): @@ -812,8 +813,17 @@ def _create_variables( ) def table_initialize_fn(shape, dtype, shard_info=None): + # If enable fast table initialization, we will initialize the table + # directly on the device and use the initializer from the first table. + if self._sparse_core_embedding_config.enable_fast_table_initialization: + return stacked_tables[0].initializer( + shape=(shard_info.shape[0], stacked_tables[0].dim), + dtype=dtype, + ) + # Concat all the tables along the first axis. concat_tensors = [] + # Temporary patch, we need to initialize tables with the SC level # sharding. Note that we need to ensure that the vocab size is divisible # by the global number of SC. diff --git a/tensorflow/tools/api/golden/v1/tensorflow.tpu.experimental.embedding.-sparse-core-embedding-config.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.tpu.experimental.embedding.-sparse-core-embedding-config.pbtxt index 6a548287f35ce4..cd5e25b1908d72 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.tpu.experimental.embedding.-sparse-core-embedding-config.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.tpu.experimental.embedding.-sparse-core-embedding-config.pbtxt @@ -10,6 +10,10 @@ tf_class { name: "disable_table_stacking" mtype: "" } + member { + name: "enable_fast_table_initialization" + mtype: "" + } member { name: "initialize_tables_on_host" mtype: "" @@ -28,6 +32,6 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'disable_table_stacking\', \'max_ids_per_chip_per_sample\', \'max_ids_per_table\', \'max_unique_ids_per_table\', \'allow_id_dropping\', \'initialize_tables_on_host\'], varargs=None, keywords=None, defaults=[\'False\', \'64\', \'None\', \'None\', \'False\', \'True\'], " + argspec: "args=[\'self\', \'disable_table_stacking\', \'max_ids_per_chip_per_sample\', \'max_ids_per_table\', \'max_unique_ids_per_table\', \'allow_id_dropping\', \'initialize_tables_on_host\', \'enable_fast_table_initialization\'], varargs=None, keywords=None, defaults=[\'False\', \'64\', \'None\', \'None\', \'False\', \'True\', \'False\'], " } } diff --git a/tensorflow/tools/api/golden/v2/tensorflow.tpu.experimental.embedding.-sparse-core-embedding-config.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.tpu.experimental.embedding.-sparse-core-embedding-config.pbtxt index 6a548287f35ce4..cd5e25b1908d72 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.tpu.experimental.embedding.-sparse-core-embedding-config.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.tpu.experimental.embedding.-sparse-core-embedding-config.pbtxt @@ -10,6 +10,10 @@ tf_class { name: "disable_table_stacking" mtype: "" } + member { + name: "enable_fast_table_initialization" + mtype: "" + } member { name: "initialize_tables_on_host" mtype: "" @@ -28,6 +32,6 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'disable_table_stacking\', \'max_ids_per_chip_per_sample\', \'max_ids_per_table\', \'max_unique_ids_per_table\', \'allow_id_dropping\', \'initialize_tables_on_host\'], varargs=None, keywords=None, defaults=[\'False\', \'64\', \'None\', \'None\', \'False\', \'True\'], " + argspec: "args=[\'self\', \'disable_table_stacking\', \'max_ids_per_chip_per_sample\', \'max_ids_per_table\', \'max_unique_ids_per_table\', \'allow_id_dropping\', \'initialize_tables_on_host\', \'enable_fast_table_initialization\'], varargs=None, keywords=None, defaults=[\'False\', \'64\', \'None\', \'None\', \'False\', \'True\', \'False\'], " } } From 9da44b9d7a6db48304982a519fc2633e171ee789 Mon Sep 17 00:00:00 2001 From: Matthias Kramm Date: Thu, 19 Sep 2024 14:36:57 -0700 Subject: [PATCH 025/483] Fix lint warning. PiperOrigin-RevId: 676567429 --- third_party/xla/xla/literal_test.cc | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/third_party/xla/xla/literal_test.cc b/third_party/xla/xla/literal_test.cc index 3ce2f675b1a5a1..8a351e2262b69e 100644 --- a/third_party/xla/xla/literal_test.cc +++ b/third_party/xla/xla/literal_test.cc @@ -812,7 +812,14 @@ template class LiteralUtilTestTemplated : public ::testing::Test {}; using TestedTypes = ::testing::Types; -TYPED_TEST_SUITE(LiteralUtilTestTemplated, TestedTypes); +class TestNamer { + public: + template + static std::string GetName(int) { + return ::testing::internal::GetTypeName(); + } +}; +TYPED_TEST_SUITE(LiteralUtilTestTemplated, TestedTypes, TestNamer); TYPED_TEST(LiteralUtilTestTemplated, Relayout2x2) { // Make a non-integer for floating point types. From b472d6bf0b9d3dc2b141baa3b644e4527b296a31 Mon Sep 17 00:00:00 2001 From: Bixia Zheng Date: Thu, 19 Sep 2024 14:43:14 -0700 Subject: [PATCH 026/483] [xla:nfc] Remove a repeated check. PiperOrigin-RevId: 676569686 --- third_party/xla/xla/service/hlo_replication_analysis_test.cc | 2 -- 1 file changed, 2 deletions(-) diff --git a/third_party/xla/xla/service/hlo_replication_analysis_test.cc b/third_party/xla/xla/service/hlo_replication_analysis_test.cc index e57e7112226072..401cdccbf45fed 100644 --- a/third_party/xla/xla/service/hlo_replication_analysis_test.cc +++ b/third_party/xla/xla/service/hlo_replication_analysis_test.cc @@ -194,8 +194,6 @@ ENTRY entry { FindInstruction(module.get(), "subtract.2"), {})); EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt( FindInstruction(module.get(), "add"), {})); - EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt( - FindInstruction(module.get(), "add"), {})); EXPECT_TRUE(analysis->HloInstructionIsReplicatedAt( FindInstruction(module.get(), "replica-id"), {})); EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt( From eda2d321b216990d89ff308c4ae6b19d301c49f0 Mon Sep 17 00:00:00 2001 From: Jian Cai Date: Thu, 19 Sep 2024 14:53:40 -0700 Subject: [PATCH 027/483] [xla] Rename "original_value" attribute in an HLO instruction to "origin" Rename the attribute in HLO IR to make it clear the attribute is referring to the symbolic value produced in the input HLO module, instead of an actual runtime value. PiperOrigin-RevId: 676573775 --- third_party/xla/xla/hlo/ir/hlo_instruction.cc | 2 +- .../xla/service/add_original_value_test.cc | 22 ++++++------- third_party/xla/xla/service/hlo_parser.cc | 4 +-- .../xla/xla/service/hlo_parser_test.cc | 14 ++++---- .../service/propagate_original_value_test.cc | 32 +++++++++---------- 5 files changed, 37 insertions(+), 37 deletions(-) diff --git a/third_party/xla/xla/hlo/ir/hlo_instruction.cc b/third_party/xla/xla/hlo/ir/hlo_instruction.cc index 37d7a39d8ee0e0..2860693f88b2c0 100644 --- a/third_party/xla/xla/hlo/ir/hlo_instruction.cc +++ b/third_party/xla/xla/hlo/ir/hlo_instruction.cc @@ -3681,7 +3681,7 @@ void HloInstruction::PrintWithCanonicalNameMap( PrintExtraAttributes(attr_printer, options); if (original_value_) { - printer->Append(", original_value={"); + printer->Append(", origin={"); printer->Append(OriginalValueToString(*original_value())); printer->Append("}"); } diff --git a/third_party/xla/xla/service/add_original_value_test.cc b/third_party/xla/xla/service/add_original_value_test.cc index f69ba94cba440e..ecfbe354b52c84 100644 --- a/third_party/xla/xla/service/add_original_value_test.cc +++ b/third_party/xla/xla/service/add_original_value_test.cc @@ -68,11 +68,11 @@ ENTRY test (v1: f32[], v2: f32[3], v3: f32[2,3]) -> ((f32[], f32[3]{0}), f32[2,3 )"; RunAndFilecheckHloRewrite(hlo_string, AddOriginalValue(), R"( -CHECK: %[[V1:.*]] = f32[] parameter(0), original_value={{[{]}}{"[[V1]]"} -CHECK: %[[V2:.*]] = f32[3]{0} parameter(1), original_value={{[{]}}{"[[V2]]"} -CHECK: %[[TUPLE:.*]] = (f32[], f32[3]{0}) tuple(%[[V1]], %[[V2]]), original_value={({"[[V1]]"}, {"[[V2]]"})} -CHECK: %[[V3:.*]] = f32[2,3]{1,0} parameter(2), original_value={{[{]}}{"[[V3]]"} -CHECK: ((f32[], f32[3]{0}), f32[2,3]{1,0}) tuple(%[[TUPLE]], %[[V3]]), original_value={(({"v1"}, {"v2"}), {"v3"})} +CHECK: %[[V1:.*]] = f32[] parameter(0), origin={{[{]}}{"[[V1]]"} +CHECK: %[[V2:.*]] = f32[3]{0} parameter(1), origin={{[{]}}{"[[V2]]"} +CHECK: %[[TUPLE:.*]] = (f32[], f32[3]{0}) tuple(%[[V1]], %[[V2]]), origin={({"[[V1]]"}, {"[[V2]]"})} +CHECK: %[[V3:.*]] = f32[2,3]{1,0} parameter(2), origin={{[{]}}{"[[V3]]"} +CHECK: ((f32[], f32[3]{0}), f32[2,3]{1,0}) tuple(%[[TUPLE]], %[[V3]]), origin={(({"v1"}, {"v2"}), {"v3"})} )"); } @@ -90,10 +90,10 @@ ENTRY test { )"; RunAndFilecheckHloRewrite(hlo_string, AddOriginalValue(), R"( -CHECK: %[[CONSTANT1:.*]] = f32[3]{0} constant({1, 2, 3}), original_value={{[{]}}{"[[CONSTANT1]]"} -CHECK: %[[CONSTANT2:.*]] = s32[2,3]{1,0} constant({ { 1, 2, 3 }, { 4, 5, 6 } }), original_value={{[{]}}{"[[CONSTANT2]]"} -CHECK: %[[TUPLE:.*]] = (f32[3]{0}, s32[2,3]{1,0}) tuple(%[[CONSTANT1]], %[[CONSTANT2]]), original_value={({"[[CONSTANT1]]"}, {"[[CONSTANT2]]"})} -CHECK: s32[2,3]{1,0} get-tuple-element(%[[TUPLE]]), index=1, original_value={{[{]}}{"[[CONSTANT2]]"} +CHECK: %[[CONSTANT1:.*]] = f32[3]{0} constant({1, 2, 3}), origin={{[{]}}{"[[CONSTANT1]]"} +CHECK: %[[CONSTANT2:.*]] = s32[2,3]{1,0} constant({ { 1, 2, 3 }, { 4, 5, 6 } }), origin={{[{]}}{"[[CONSTANT2]]"} +CHECK: %[[TUPLE:.*]] = (f32[3]{0}, s32[2,3]{1,0}) tuple(%[[CONSTANT1]], %[[CONSTANT2]]), origin={({"[[CONSTANT1]]"}, {"[[CONSTANT2]]"})} +CHECK: s32[2,3]{1,0} get-tuple-element(%[[TUPLE]]), index=1, origin={{[{]}}{"[[CONSTANT2]]"} )"); } @@ -109,8 +109,8 @@ ENTRY test { )"; RunAndFilecheckHloRewrite(hlo_string, AddOriginalValue(), R"( -CHECK: %[[PARAM:.*]] = (f32[], s32[]) parameter(0), original_value={({"p" {0}{{[}]}}, {"p" {1}})} -CHECK: s32[] get-tuple-element(%[[PARAM]]), index=1, original_value={{[{]}}{"[[PARAM]]" {1} +CHECK: %[[PARAM:.*]] = (f32[], s32[]) parameter(0), origin={({"p" {0}{{[}]}}, {"p" {1}})} +CHECK: s32[] get-tuple-element(%[[PARAM]]), index=1, origin={{[{]}}{"[[PARAM]]" {1} )"); } diff --git a/third_party/xla/xla/service/hlo_parser.cc b/third_party/xla/xla/service/hlo_parser.cc index 9ec2d0f1510301..7977e011fda528 100644 --- a/third_party/xla/xla/service/hlo_parser.cc +++ b/third_party/xla/xla/service/hlo_parser.cc @@ -1392,8 +1392,8 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder, &predecessors}; optional> original_value; - attrs["original_value"] = {/*required=*/false, AttrTy::kOriginalValue, - &original_value}; + attrs["origin"] = {/*required=*/false, AttrTy::kOriginalValue, + &original_value}; optional metadata; attrs["metadata"] = {/*required=*/false, AttrTy::kMetadata, &metadata}; diff --git a/third_party/xla/xla/service/hlo_parser_test.cc b/third_party/xla/xla/service/hlo_parser_test.cc index d6783d516807fb..e7c52987492fe0 100644 --- a/third_party/xla/xla/service/hlo_parser_test.cc +++ b/third_party/xla/xla/service/hlo_parser_test.cc @@ -1530,11 +1530,11 @@ ENTRY %test (p: f32[100]) -> u32[100] { R"(HloModule test, entry_computation_layout={(f32[], f32[3]{0}, f32[2,3]{1,0})->((f32[], f32[3]{0}), f32[2,3]{1,0})} ENTRY %test (v1: f32[], v2: f32[3], v3: f32[2,3]) -> ((f32[], f32[3]), f32[2,3]) { - %v1 = f32[] parameter(0), original_value={{"v1"}} - %v2 = f32[3]{0} parameter(1), original_value={{"v2"}} - %tuple = (f32[], f32[3]{0}) tuple(f32[] %v1, f32[3]{0} %v2), original_value={({"v1"}, {"v2"})} - %v3 = f32[2,3]{1,0} parameter(2), original_value={{"v3"}} - ROOT %nested_tuple = ((f32[], f32[3]{0}), f32[2,3]{1,0}) tuple((f32[], f32[3]{0}) %tuple, f32[2,3]{1,0} %v3), original_value={(({"v1"}, {"v2"}), {"v3"})} + %v1 = f32[] parameter(0), origin={{"v1"}} + %v2 = f32[3]{0} parameter(1), origin={{"v2"}} + %tuple = (f32[], f32[3]{0}) tuple(f32[] %v1, f32[3]{0} %v2), origin={({"v1"}, {"v2"})} + %v3 = f32[2,3]{1,0} parameter(2), origin={{"v3"}} + ROOT %nested_tuple = ((f32[], f32[3]{0}), f32[2,3]{1,0}) tuple((f32[], f32[3]{0}) %tuple, f32[2,3]{1,0} %v3), origin={(({"v1"}, {"v2"}), {"v3"})} } )" @@ -5537,8 +5537,8 @@ TEST_F(HloParserTest, OriginalValueWithoutShape) { const std::string hlo_string = R"(HloModule test ENTRY %test { - %a = f32[2,10]{1,0} parameter(0), original_value={{"a"}} - ROOT %v = abs(%a), original_value={{"v"}} + %a = f32[2,10]{1,0} parameter(0), origin={{"a"}} + ROOT %v = abs(%a), origin={{"v"}} } diff --git a/third_party/xla/xla/service/propagate_original_value_test.cc b/third_party/xla/xla/service/propagate_original_value_test.cc index 37340843c2fb5d..764786ceef23dd 100644 --- a/third_party/xla/xla/service/propagate_original_value_test.cc +++ b/third_party/xla/xla/service/propagate_original_value_test.cc @@ -29,17 +29,17 @@ TEST_F(PropagateOriginalValueTest, InstructionFusion) { HloModule test, entry_computation_layout={(s32[]{:T(256)})->u32[2]{0:T(256)}} ENTRY test { - Arg_0 = s32[]{:T(256)} parameter(0), original_value={{"Arg_0"}}, metadata={op_name="seed"} - constant = s32[]{:T(256)} constant(32), original_value={{"constant"}} - shift-right-logical = s32[]{:T(256)} shift-right-logical(Arg_0, constant), original_value={{"shift-right-logical"}} - convert = u32[]{:T(256)} convert(shift-right-logical), original_value={{"convert"}} - bitcast = u32[1]{0:T(256)} bitcast(convert), original_value={{"reshape"}} + Arg_0 = s32[]{:T(256)} parameter(0), origin={{"Arg_0"}}, metadata={op_name="seed"} + constant = s32[]{:T(256)} constant(32), origin={{"constant"}} + shift-right-logical = s32[]{:T(256)} shift-right-logical(Arg_0, constant), origin={{"shift-right-logical"}} + convert = u32[]{:T(256)} convert(shift-right-logical), origin={{"convert"}} + bitcast = u32[1]{0:T(256)} bitcast(convert), origin={{"reshape"}} constant.1 = u32[]{:T(256)} constant(0) pad = u32[2]{0:T(256)} pad(bitcast, constant.1), padding=0_1 - convert.1 = u32[]{:T(256)} convert(Arg_0), original_value={{"convert.1"}} - bitcast.1 = u32[1]{0:T(256)} bitcast(convert.1), original_value={{"reshape.1"}} + convert.1 = u32[]{:T(256)} convert(Arg_0), origin={{"convert.1"}} + bitcast.1 = u32[1]{0:T(256)} bitcast(convert.1), origin={{"reshape.1"}} pad.1 = u32[2]{0:T(256)} pad(bitcast.1, constant.1), padding=1_0 - ROOT add = u32[2]{0:T(256)} add(pad, pad.1), original_value={{"concatenate"}} + ROOT add = u32[2]{0:T(256)} add(pad, pad.1), origin={{"concatenate"}} } )"; @@ -49,19 +49,19 @@ ENTRY test { R"( CHECK: %fused_computation CHECK: %[[PARAM:.*]] = s32[]{:T(256)} parameter(0) -CHECK: %[[CONSTANT:.*]] = s32[]{:T(256)} constant(32), original_value={{[{]}}{"constant"}} -CHECK: %[[SHIFT:.*]] = s32[]{:T(256)} shift-right-logical(%[[PARAM]], %[[CONSTANT]]), original_value={{[{]}}{"shift-right-logical"} -CHECK: %[[CONVERT:.*]] = u32[]{:T(256)} convert(%[[SHIFT]]), original_value={{[{]}}{"convert"} -CHECK: %[[BITCAST:.*]] = u32[1]{0:T(256)} bitcast(%[[CONVERT]]), original_value={{[{]}}{"reshape"} +CHECK: %[[CONSTANT:.*]] = s32[]{:T(256)} constant(32), origin={{[{]}}{"constant"}} +CHECK: %[[SHIFT:.*]] = s32[]{:T(256)} shift-right-logical(%[[PARAM]], %[[CONSTANT]]), origin={{[{]}}{"shift-right-logical"} +CHECK: %[[CONVERT:.*]] = u32[]{:T(256)} convert(%[[SHIFT]]), origin={{[{]}}{"convert"} +CHECK: %[[BITCAST:.*]] = u32[1]{0:T(256)} bitcast(%[[CONVERT]]), origin={{[{]}}{"reshape"} CHECK: %[[CONSTANT1:.*]] = u32[]{:T(256)} constant(0) CHECK: %[[PAD:.*]] = u32[2]{0:T(256)} pad(%[[BITCAST]], %[[CONSTANT1]]), padding=0_1 -CHECK: %[[CONVERT1:.*]] = u32[]{:T(256)} convert(%[[PARAM]]), original_value={{[{]}}{"convert.1"} -CHECK: %[[BITCAST1:.*]] = u32[1]{0:T(256)} bitcast(%[[CONVERT1]]), original_value={{[{]}}{"reshape.1"} +CHECK: %[[CONVERT1:.*]] = u32[]{:T(256)} convert(%[[PARAM]]), origin={{[{]}}{"convert.1"} +CHECK: %[[BITCAST1:.*]] = u32[1]{0:T(256)} bitcast(%[[CONVERT1]]), origin={{[{]}}{"reshape.1"} CHECK: %[[PAD1:.*]] = u32[2]{0:T(256)} pad(%[[BITCAST1]], %[[CONSTANT1]]), padding=1_0 -CHECK: ROOT %[[ADD:.*]] = u32[2]{0:T(256)} add(%[[PAD]], %[[PAD1]]), original_value={{[{]}}{"concatenate"} +CHECK: ROOT %[[ADD:.*]] = u32[2]{0:T(256)} add(%[[PAD]], %[[PAD1]]), origin={{[{]}}{"concatenate"} CHECK: ENTRY %test -CHECK: %Arg_0 = s32[]{:T(256)} parameter(0), original_value={{[{]}}{"Arg_0"} +CHECK: %Arg_0 = s32[]{:T(256)} parameter(0), origin={{[{]}}{"Arg_0"} CHECK: ROOT %fusion = u32[2]{0:T(256)} fusion(%Arg_0), kind=kLoop, calls=%fused_computation )"); } From cf3b94729e62e4e6d81338e14fdf2c5e28ad5261 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 19 Sep 2024 15:06:22 -0700 Subject: [PATCH 028/483] Add `nccl_headers` alias to the content of BUILD file in NCCL repository when NCCL stub is not used. PiperOrigin-RevId: 676578558 --- third_party/nccl/hermetic/nccl_configure.bzl | 9 +++++++++ .../tsl/third_party/nccl/hermetic/nccl_configure.bzl | 9 +++++++++ 2 files changed, 18 insertions(+) diff --git a/third_party/nccl/hermetic/nccl_configure.bzl b/third_party/nccl/hermetic/nccl_configure.bzl index 14469acdfc5aa1..c1e49a6b9f1dd2 100644 --- a/third_party/nccl/hermetic/nccl_configure.bzl +++ b/third_party/nccl/hermetic/nccl_configure.bzl @@ -60,6 +60,15 @@ alias( visibility = ["//visibility:public"], ) +alias( + name = "nccl_headers", + actual = select({ + "@local_config_cuda//cuda:cuda_tools_and_libs": "@cuda_nccl//:headers", + "//conditions:default": "@nccl_archive//:nccl_headers", + }), + visibility = ["//visibility:public"], +) + cc_library( name = "hermetic_nccl_config", hdrs = ["nccl_config.h"], diff --git a/third_party/xla/third_party/tsl/third_party/nccl/hermetic/nccl_configure.bzl b/third_party/xla/third_party/tsl/third_party/nccl/hermetic/nccl_configure.bzl index 14469acdfc5aa1..c1e49a6b9f1dd2 100644 --- a/third_party/xla/third_party/tsl/third_party/nccl/hermetic/nccl_configure.bzl +++ b/third_party/xla/third_party/tsl/third_party/nccl/hermetic/nccl_configure.bzl @@ -60,6 +60,15 @@ alias( visibility = ["//visibility:public"], ) +alias( + name = "nccl_headers", + actual = select({ + "@local_config_cuda//cuda:cuda_tools_and_libs": "@cuda_nccl//:headers", + "//conditions:default": "@nccl_archive//:nccl_headers", + }), + visibility = ["//visibility:public"], +) + cc_library( name = "hermetic_nccl_config", hdrs = ["nccl_config.h"], From 13389dcee4b9b29a2477d5ac38542f9f1f78d692 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 19 Sep 2024 15:25:52 -0700 Subject: [PATCH 029/483] Disable MLIR bridge for the tests that MLIR bridge silently fails PiperOrigin-RevId: 676585347 --- tensorflow/python/distribute/BUILD | 1 + tensorflow/python/distribute/vars_test.py | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD index 782f79e5bae63e..62f544970d4737 100644 --- a/tensorflow/python/distribute/BUILD +++ b/tensorflow/python/distribute/BUILD @@ -1810,6 +1810,7 @@ distribute_py_strict_test( "//tensorflow/python/eager:context", "//tensorflow/python/eager:def_function", "//tensorflow/python/eager:test", + "//tensorflow/python/framework:config", "//tensorflow/python/framework:constant_op", "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:indexed_slices", diff --git a/tensorflow/python/distribute/vars_test.py b/tensorflow/python/distribute/vars_test.py index 5dd2c5a3b1ae4a..4cf07ddd13d958 100644 --- a/tensorflow/python/distribute/vars_test.py +++ b/tensorflow/python/distribute/vars_test.py @@ -31,6 +31,7 @@ from tensorflow.python.eager import context from tensorflow.python.eager import def_function from tensorflow.python.eager import test +from tensorflow.python.framework import config from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import indexed_slices @@ -654,7 +655,7 @@ def scatter_update(v): @combinations.generate(ms_combination + tpu_combination) def testScatterOpsWithNoneAggregation(self, distribution): - + config.disable_mlir_bridge() def assert_close(v, op, delta, expect): scatter_op = getattr(v, op) From dc7b71583e95fdebf113098fab1f5df97d597dc2 Mon Sep 17 00:00:00 2001 From: Sandeep Dasgupta Date: Thu, 19 Sep 2024 15:31:49 -0700 Subject: [PATCH 030/483] [HLO Componentization] Create hlo/translate sub-component (Phase I). This CL takes care of 1. Migrating xla/translate --> xla/hlo/translate 2. Setting up build aliases in xla/translate ensuring external dependencies are still satisfied. Phase II will take care of migration of external projects dependencies from xla/translate --> xla/hlo/translate PiperOrigin-RevId: 676587696 --- third_party/xla/xla/hlo/translate/BUILD | 62 ++++ .../xla/xla/hlo/translate/hlo_to_mhlo/BUILD | 262 ++++++++++++++++ .../translate/hlo_to_mhlo/async_importer.cc | 6 +- .../translate/hlo_to_mhlo/async_importer.h | 6 +- .../hlo_to_mhlo/attribute_importer.cc | 2 +- .../hlo_to_mhlo/attribute_importer.h | 104 ++++++ .../hlo_to_mhlo/custom_call_importer.cc | 2 +- .../hlo_to_mhlo/custom_call_importer.h | 6 +- .../hlo_to_mhlo/hlo_function_importer.cc | 12 +- .../hlo_to_mhlo/hlo_function_importer.h | 258 +++++++++++++++ .../hlo_to_mhlo/hlo_module_importer.cc | 6 +- .../hlo_to_mhlo/hlo_module_importer.h | 64 ++++ .../translate/hlo_to_mhlo/hlo_to_mlir_hlo.cc | 4 +- .../translate/hlo_to_mhlo/hlo_to_mlir_hlo.h | 71 +++++ .../translate/hlo_to_mhlo/hlo_utils.cc | 2 +- .../xla/hlo/translate/hlo_to_mhlo/hlo_utils.h | 249 +++++++++++++++ .../translate/hlo_to_mhlo/hlo_utils_test.cc | 2 +- .../hlo_to_mhlo/location_importer.cc | 4 +- .../translate/hlo_to_mhlo/location_importer.h | 6 +- .../hlo_to_mhlo/module_attributes_importer.cc | 6 +- .../hlo_to_mhlo/module_attributes_importer.h | 6 +- .../hlo_to_mhlo/stack_location_utils.cc | 2 +- .../hlo_to_mhlo/stack_location_utils.h | 6 +- .../translate/hlo_to_mhlo/tests/BUILD | 2 +- .../hlo_to_mhlo/tests/attributes.hlo | 24 +- .../hlo_to_mhlo/tests/bool_compare.hlo | 0 .../hlo_to_mhlo/tests/case_conditional.hlo | 0 .../hlo_to_mhlo/tests/composite_call.hlo | 0 .../hlo_to_mhlo/tests/custom_call.hlo | 0 .../hlo_to_mhlo/tests/dynamic_param.hlo | 0 .../hlo_to_mhlo/tests/frontend_attributes.hlo | 0 .../tests/fully_connected_reference_model.hlo | 0 .../translate/hlo_to_mhlo/tests/fusion.hlo | 0 .../hlo_to_mhlo/tests/if_conditional.hlo | 0 .../translate/hlo_to_mhlo/tests/import.hlo | 0 .../hlo_to_mhlo/tests/import_async.hlo | 52 +-- .../hlo_to_mhlo/tests/import_async2.hlo | 0 .../tests/import_entry_computation_layout.hlo | 0 .../hlo_to_mhlo/tests/layouts_and_names.hlo | 0 .../translate/hlo_to_mhlo/tests/location.hlo | 0 .../hlo_to_mhlo/tests/module_attributes.hlo | 0 .../hlo_to_mhlo/tests/module_config.hlo | 0 .../translate/hlo_to_mhlo/tests/simple.hlo | 0 .../tests/spmd_module_sharding.hlo | 0 .../tests/stacktrace_to_location.hlo | 0 .../translate/hlo_to_mhlo/tests/types.hlo | 0 .../translate/hlo_to_mhlo/tests/while.hlo | 0 .../translate/hlo_to_mhlo/translate.cc | 4 +- .../xla/hlo/translate/hlo_to_mhlo/translate.h | 88 ++++++ .../hlo_to_mhlo/translate_registration.cc | 2 +- .../xla/xla/hlo/translate/mhlo_to_hlo/BUILD | 296 ++++++++++++++++++ .../mhlo_to_hlo/attribute_exporter.cc | 2 +- .../mhlo_to_hlo/attribute_exporter.h | 75 +++++ .../translate/mhlo_to_hlo/layout_util.cc | 2 +- .../hlo/translate/mhlo_to_hlo/layout_util.h | 85 +++++ .../mhlo_to_hlo/location_exporter.cc | 4 +- .../translate/mhlo_to_hlo/location_exporter.h | 44 +++ .../translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc | 16 +- .../translate/mhlo_to_hlo/mlir_hlo_to_hlo.h | 96 ++++++ .../mhlo_to_hlo/mlir_hlo_to_hlo_test.cc | 2 +- .../mhlo_to_hlo/module_attributes_exporter.cc | 2 +- .../mhlo_to_hlo/module_attributes_exporter.h | 6 +- .../mhlo_to_hlo/operator_writer_gen.cc | 0 .../mhlo_to_hlo/stack_frame_index_builder.cc | 2 +- .../mhlo_to_hlo/stack_frame_index_builder.h | 6 +- .../translate/mhlo_to_hlo/tests/BUILD | 2 +- .../translate/mhlo_to_hlo/tests/add.mlir | 0 .../mhlo_to_hlo/tests/attributes.mlir | 0 .../translate/mhlo_to_hlo/tests/case.mlir | 0 .../mhlo_to_hlo/tests/composite.mlir | 0 .../translate/mhlo_to_hlo/tests/dynamic.mlir | 0 .../tests/export-with-layouts.mlir | 0 .../translate/mhlo_to_hlo/tests/export.mlir | 0 .../tests/export_and_check_layouts.mlir | 0 .../mhlo_to_hlo/tests/export_async.mlir | 0 .../export_entry_computation_layout.mlir | 0 .../tests/export_large_constants.mlir | 0 .../mhlo_to_hlo/tests/export_replicas.mlir | 0 .../tests/frontend_attributes.mlir | 0 .../translate/mhlo_to_hlo/tests/fusion.mlir | 0 .../translate/mhlo_to_hlo/tests/if.mlir | 0 .../tests/input_output_aliasing.mlir | 0 .../translate/mhlo_to_hlo/tests/int4.mlir | 0 .../mhlo_to_hlo/tests/layouts_and_names.mlir | 0 .../tests/location_to_op_metadata.mlir | 0 .../tests/location_to_stacktrace.mlir | 0 .../mhlo_to_hlo/tests/missing_main.mlir | 0 .../mhlo_to_hlo/tests/module_attributes.mlir | 0 .../mhlo_to_hlo/tests/module_config.mlir | 0 .../tests/multiple_return_tuple.mlir | 0 .../tests/opaque_elements_attr.mlir | 0 .../tests/rng_get_and_update_state.mlir | 0 .../translate/mhlo_to_hlo/tests/sharding.mlir | 0 .../translate/mhlo_to_hlo/tests/simple.mlir | 0 .../mhlo_to_hlo/tests/unsupported_type.mlir | 0 .../translate/mhlo_to_hlo/tests/while.mlir | 0 .../mhlo_to_hlo/tests/while_free_vars.mlir | 0 .../translate/mhlo_to_hlo/translate.cc | 6 +- .../xla/hlo/translate/mhlo_to_hlo/translate.h | 50 +++ .../mhlo_to_hlo/translate_registration.cc | 4 +- .../mhlo_to_hlo/translate_registration.h | 6 +- .../translate/mhlo_to_hlo/type_to_shape.cc | 3 +- .../hlo/translate/mhlo_to_hlo/type_to_shape.h | 31 ++ .../mhlo_to_hlo/type_to_shape_test.cc | 4 +- .../xla/hlo/translate/stablehlo_to_hlo/BUILD | 46 +++ .../translate/stablehlo_to_hlo/tests/BUILD | 2 +- .../stablehlo_to_hlo/tests/simple.mlir | 0 .../translate/stablehlo_to_hlo/translate.cc | 4 +- .../translate/stablehlo_to_hlo/translate.h | 50 +++ .../translate_registration.cc | 4 +- .../{ => hlo}/translate/xla_translate_main.cc | 0 .../translate/xla_translate_opt_main.cc | 0 third_party/xla/xla/translate/BUILD | 54 +--- .../xla/xla/translate/hlo_to_mhlo/BUILD | 219 ++----------- .../hlo_to_mhlo/attribute_importer.h | 86 +---- .../hlo_to_mhlo/hlo_function_importer.h | 240 +------------- .../hlo_to_mhlo/hlo_module_importer.h | 46 +-- .../translate/hlo_to_mhlo/hlo_to_mlir_hlo.h | 53 +--- .../xla/xla/translate/hlo_to_mhlo/hlo_utils.h | 229 +------------- .../xla/xla/translate/hlo_to_mhlo/translate.h | 70 +---- .../xla/xla/translate/mhlo_to_hlo/BUILD | 259 ++------------- .../mhlo_to_hlo/attribute_exporter.h | 57 +--- .../xla/translate/mhlo_to_hlo/layout_util.h | 65 +--- .../translate/mhlo_to_hlo/location_exporter.h | 26 +- .../translate/mhlo_to_hlo/mlir_hlo_to_hlo.h | 78 +---- .../xla/xla/translate/mhlo_to_hlo/translate.h | 32 +- .../xla/translate/mhlo_to_hlo/type_to_shape.h | 13 +- .../xla/xla/translate/stablehlo_to_hlo/BUILD | 31 +- .../translate/stablehlo_to_hlo/translate.h | 32 +- 129 files changed, 2138 insertions(+), 1622 deletions(-) create mode 100644 third_party/xla/xla/hlo/translate/BUILD create mode 100644 third_party/xla/xla/hlo/translate/hlo_to_mhlo/BUILD rename third_party/xla/xla/{ => hlo}/translate/hlo_to_mhlo/async_importer.cc (99%) rename third_party/xla/xla/{ => hlo}/translate/hlo_to_mhlo/async_importer.h (95%) rename third_party/xla/xla/{ => hlo}/translate/hlo_to_mhlo/attribute_importer.cc (99%) create mode 100644 third_party/xla/xla/hlo/translate/hlo_to_mhlo/attribute_importer.h rename third_party/xla/xla/{ => hlo}/translate/hlo_to_mhlo/custom_call_importer.cc (99%) rename third_party/xla/xla/{ => hlo}/translate/hlo_to_mhlo/custom_call_importer.h (89%) rename third_party/xla/xla/{ => hlo}/translate/hlo_to_mhlo/hlo_function_importer.cc (99%) create mode 100644 third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_function_importer.h rename third_party/xla/xla/{ => hlo}/translate/hlo_to_mhlo/hlo_module_importer.cc (95%) create mode 100644 third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_module_importer.h rename third_party/xla/xla/{ => hlo}/translate/hlo_to_mhlo/hlo_to_mlir_hlo.cc (96%) create mode 100644 third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h rename third_party/xla/xla/{ => hlo}/translate/hlo_to_mhlo/hlo_utils.cc (99%) create mode 100644 third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_utils.h rename third_party/xla/xla/{ => hlo}/translate/hlo_to_mhlo/hlo_utils_test.cc (97%) rename third_party/xla/xla/{ => hlo}/translate/hlo_to_mhlo/location_importer.cc (93%) rename third_party/xla/xla/{ => hlo}/translate/hlo_to_mhlo/location_importer.h (85%) rename third_party/xla/xla/{ => hlo}/translate/hlo_to_mhlo/module_attributes_importer.cc (98%) rename third_party/xla/xla/{ => hlo}/translate/hlo_to_mhlo/module_attributes_importer.h (92%) rename third_party/xla/xla/{ => hlo}/translate/hlo_to_mhlo/stack_location_utils.cc (96%) rename third_party/xla/xla/{ => hlo}/translate/hlo_to_mhlo/stack_location_utils.h (85%) rename third_party/xla/xla/{ => hlo}/translate/hlo_to_mhlo/tests/BUILD (96%) rename third_party/xla/xla/{ => hlo}/translate/hlo_to_mhlo/tests/attributes.hlo (90%) rename third_party/xla/xla/{ => hlo}/translate/hlo_to_mhlo/tests/bool_compare.hlo (100%) rename third_party/xla/xla/{ => hlo}/translate/hlo_to_mhlo/tests/case_conditional.hlo (100%) rename third_party/xla/xla/{ => hlo}/translate/hlo_to_mhlo/tests/composite_call.hlo (100%) rename third_party/xla/xla/{ => hlo}/translate/hlo_to_mhlo/tests/custom_call.hlo (100%) rename third_party/xla/xla/{ => hlo}/translate/hlo_to_mhlo/tests/dynamic_param.hlo (100%) rename third_party/xla/xla/{ => hlo}/translate/hlo_to_mhlo/tests/frontend_attributes.hlo (100%) rename third_party/xla/xla/{ => hlo}/translate/hlo_to_mhlo/tests/fully_connected_reference_model.hlo (100%) rename third_party/xla/xla/{ => hlo}/translate/hlo_to_mhlo/tests/fusion.hlo (100%) rename third_party/xla/xla/{ => hlo}/translate/hlo_to_mhlo/tests/if_conditional.hlo (100%) rename third_party/xla/xla/{ => hlo}/translate/hlo_to_mhlo/tests/import.hlo (100%) rename third_party/xla/xla/{ => hlo}/translate/hlo_to_mhlo/tests/import_async.hlo (78%) rename third_party/xla/xla/{ => hlo}/translate/hlo_to_mhlo/tests/import_async2.hlo (100%) rename third_party/xla/xla/{ => hlo}/translate/hlo_to_mhlo/tests/import_entry_computation_layout.hlo (100%) rename third_party/xla/xla/{ => hlo}/translate/hlo_to_mhlo/tests/layouts_and_names.hlo (100%) rename third_party/xla/xla/{ => hlo}/translate/hlo_to_mhlo/tests/location.hlo (100%) rename third_party/xla/xla/{ => hlo}/translate/hlo_to_mhlo/tests/module_attributes.hlo (100%) rename third_party/xla/xla/{ => hlo}/translate/hlo_to_mhlo/tests/module_config.hlo (100%) rename third_party/xla/xla/{ => hlo}/translate/hlo_to_mhlo/tests/simple.hlo (100%) rename third_party/xla/xla/{ => hlo}/translate/hlo_to_mhlo/tests/spmd_module_sharding.hlo (100%) rename third_party/xla/xla/{ => hlo}/translate/hlo_to_mhlo/tests/stacktrace_to_location.hlo (100%) rename third_party/xla/xla/{ => hlo}/translate/hlo_to_mhlo/tests/types.hlo (100%) rename third_party/xla/xla/{ => hlo}/translate/hlo_to_mhlo/tests/while.hlo (100%) rename third_party/xla/xla/{ => hlo}/translate/hlo_to_mhlo/translate.cc (97%) create mode 100644 third_party/xla/xla/hlo/translate/hlo_to_mhlo/translate.h rename third_party/xla/xla/{ => hlo}/translate/hlo_to_mhlo/translate_registration.cc (98%) create mode 100644 third_party/xla/xla/hlo/translate/mhlo_to_hlo/BUILD rename third_party/xla/xla/{ => hlo}/translate/mhlo_to_hlo/attribute_exporter.cc (99%) create mode 100644 third_party/xla/xla/hlo/translate/mhlo_to_hlo/attribute_exporter.h rename third_party/xla/xla/{ => hlo}/translate/mhlo_to_hlo/layout_util.cc (98%) create mode 100644 third_party/xla/xla/hlo/translate/mhlo_to_hlo/layout_util.h rename third_party/xla/xla/{ => hlo}/translate/mhlo_to_hlo/location_exporter.cc (97%) create mode 100644 third_party/xla/xla/hlo/translate/mhlo_to_hlo/location_exporter.h rename third_party/xla/xla/{ => hlo}/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc (99%) create mode 100644 third_party/xla/xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h rename third_party/xla/xla/{ => hlo}/translate/mhlo_to_hlo/mlir_hlo_to_hlo_test.cc (97%) rename third_party/xla/xla/{ => hlo}/translate/mhlo_to_hlo/module_attributes_exporter.cc (99%) rename third_party/xla/xla/{ => hlo}/translate/mhlo_to_hlo/module_attributes_exporter.h (89%) rename third_party/xla/xla/{ => hlo}/translate/mhlo_to_hlo/operator_writer_gen.cc (100%) rename third_party/xla/xla/{ => hlo}/translate/mhlo_to_hlo/stack_frame_index_builder.cc (98%) rename third_party/xla/xla/{ => hlo}/translate/mhlo_to_hlo/stack_frame_index_builder.h (88%) rename third_party/xla/xla/{ => hlo}/translate/mhlo_to_hlo/tests/BUILD (97%) rename third_party/xla/xla/{ => hlo}/translate/mhlo_to_hlo/tests/add.mlir (100%) rename third_party/xla/xla/{ => hlo}/translate/mhlo_to_hlo/tests/attributes.mlir (100%) rename third_party/xla/xla/{ => hlo}/translate/mhlo_to_hlo/tests/case.mlir (100%) rename third_party/xla/xla/{ => hlo}/translate/mhlo_to_hlo/tests/composite.mlir (100%) rename third_party/xla/xla/{ => hlo}/translate/mhlo_to_hlo/tests/dynamic.mlir (100%) rename third_party/xla/xla/{ => hlo}/translate/mhlo_to_hlo/tests/export-with-layouts.mlir (100%) rename third_party/xla/xla/{ => hlo}/translate/mhlo_to_hlo/tests/export.mlir (100%) rename third_party/xla/xla/{ => hlo}/translate/mhlo_to_hlo/tests/export_and_check_layouts.mlir (100%) rename third_party/xla/xla/{ => hlo}/translate/mhlo_to_hlo/tests/export_async.mlir (100%) rename third_party/xla/xla/{ => hlo}/translate/mhlo_to_hlo/tests/export_entry_computation_layout.mlir (100%) rename third_party/xla/xla/{ => hlo}/translate/mhlo_to_hlo/tests/export_large_constants.mlir (100%) rename third_party/xla/xla/{ => hlo}/translate/mhlo_to_hlo/tests/export_replicas.mlir (100%) rename third_party/xla/xla/{ => hlo}/translate/mhlo_to_hlo/tests/frontend_attributes.mlir (100%) rename third_party/xla/xla/{ => hlo}/translate/mhlo_to_hlo/tests/fusion.mlir (100%) rename third_party/xla/xla/{ => hlo}/translate/mhlo_to_hlo/tests/if.mlir (100%) rename third_party/xla/xla/{ => hlo}/translate/mhlo_to_hlo/tests/input_output_aliasing.mlir (100%) rename third_party/xla/xla/{ => hlo}/translate/mhlo_to_hlo/tests/int4.mlir (100%) rename third_party/xla/xla/{ => hlo}/translate/mhlo_to_hlo/tests/layouts_and_names.mlir (100%) rename third_party/xla/xla/{ => hlo}/translate/mhlo_to_hlo/tests/location_to_op_metadata.mlir (100%) rename third_party/xla/xla/{ => hlo}/translate/mhlo_to_hlo/tests/location_to_stacktrace.mlir (100%) rename third_party/xla/xla/{ => hlo}/translate/mhlo_to_hlo/tests/missing_main.mlir (100%) rename third_party/xla/xla/{ => hlo}/translate/mhlo_to_hlo/tests/module_attributes.mlir (100%) rename third_party/xla/xla/{ => hlo}/translate/mhlo_to_hlo/tests/module_config.mlir (100%) rename third_party/xla/xla/{ => hlo}/translate/mhlo_to_hlo/tests/multiple_return_tuple.mlir (100%) rename third_party/xla/xla/{ => hlo}/translate/mhlo_to_hlo/tests/opaque_elements_attr.mlir (100%) rename third_party/xla/xla/{ => hlo}/translate/mhlo_to_hlo/tests/rng_get_and_update_state.mlir (100%) rename third_party/xla/xla/{ => hlo}/translate/mhlo_to_hlo/tests/sharding.mlir (100%) rename third_party/xla/xla/{ => hlo}/translate/mhlo_to_hlo/tests/simple.mlir (100%) rename third_party/xla/xla/{ => hlo}/translate/mhlo_to_hlo/tests/unsupported_type.mlir (100%) rename third_party/xla/xla/{ => hlo}/translate/mhlo_to_hlo/tests/while.mlir (100%) rename third_party/xla/xla/{ => hlo}/translate/mhlo_to_hlo/tests/while_free_vars.mlir (100%) rename third_party/xla/xla/{ => hlo}/translate/mhlo_to_hlo/translate.cc (98%) create mode 100644 third_party/xla/xla/hlo/translate/mhlo_to_hlo/translate.h rename third_party/xla/xla/{ => hlo}/translate/mhlo_to_hlo/translate_registration.cc (95%) rename third_party/xla/xla/{ => hlo}/translate/mhlo_to_hlo/translate_registration.h (91%) rename third_party/xla/xla/{ => hlo}/translate/mhlo_to_hlo/type_to_shape.cc (99%) create mode 100644 third_party/xla/xla/hlo/translate/mhlo_to_hlo/type_to_shape.h rename third_party/xla/xla/{ => hlo}/translate/mhlo_to_hlo/type_to_shape_test.cc (98%) create mode 100644 third_party/xla/xla/hlo/translate/stablehlo_to_hlo/BUILD rename third_party/xla/xla/{ => hlo}/translate/stablehlo_to_hlo/tests/BUILD (93%) rename third_party/xla/xla/{ => hlo}/translate/stablehlo_to_hlo/tests/simple.mlir (100%) rename third_party/xla/xla/{ => hlo}/translate/stablehlo_to_hlo/translate.cc (96%) create mode 100644 third_party/xla/xla/hlo/translate/stablehlo_to_hlo/translate.h rename third_party/xla/xla/{ => hlo}/translate/stablehlo_to_hlo/translate_registration.cc (95%) rename third_party/xla/xla/{ => hlo}/translate/xla_translate_main.cc (100%) rename third_party/xla/xla/{ => hlo}/translate/xla_translate_opt_main.cc (100%) diff --git a/third_party/xla/xla/hlo/translate/BUILD b/third_party/xla/xla/hlo/translate/BUILD new file mode 100644 index 00000000000000..9fd8ced45a9bd5 --- /dev/null +++ b/third_party/xla/xla/hlo/translate/BUILD @@ -0,0 +1,62 @@ +load("@bazel_skylib//rules:build_test.bzl", "build_test") +load("//xla:xla.bzl", "xla_cc_binary") +load("//xla/tsl:tsl.bzl", "internal_visibility") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = internal_visibility([ + "//learning/brain/mlir:tensorflow_friends", + "//learning/brain/mlir:xla_friends", + ]), + licenses = ["notice"], +) + +build_test( + name = "xla-translate_build_test", + targets = [ + ":xla-translate", + ], +) + +xla_cc_binary( + name = "xla-translate", + testonly = True, + srcs = ["xla_translate_main.cc"], + deps = [ + "//xla/hlo/translate/hlo_to_mhlo:translate_registration", + "//xla/hlo/translate/mhlo_to_hlo:translate_registration", + "//xla/hlo/translate/stablehlo_to_hlo:translate_registration", + "//xla/service/cpu:cpu_compiler", + "//xla/service/cpu:cpu_transfer_manager", + "//xla/stream_executor/host:host_platform", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TranslateLib", + "@local_tsl//tsl/platform:platform_port", + ], +) + +build_test( + name = "xla-translate-opt_build_test", + targets = [ + ":xla-translate-opt", + ], +) + +xla_cc_binary( + name = "xla-translate-opt", + testonly = True, + srcs = ["xla_translate_opt_main.cc"], + deps = [ + "//xla/mlir/framework/ir:xla_framework", + "//xla/mlir/framework/transforms:passes", + "//xla/mlir_hlo:hlo_dialect_registration", + "//xla/service:cpu_plugin", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:AllPassesAndDialects", + "@llvm-project//mlir:MlirOptLib", + "@local_tsl//tsl/platform:platform_port", + "@stablehlo//:register", + ], +) diff --git a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/BUILD b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/BUILD new file mode 100644 index 00000000000000..fb4f7b9fc662fd --- /dev/null +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/BUILD @@ -0,0 +1,262 @@ +load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") +load("//xla:xla.bzl", "xla_cc_test") +load("//xla/tsl:tsl.bzl", "internal_visibility") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = internal_visibility([ + "//learning/brain/mlir:tensorflow_friends", + "//learning/brain/mlir:xla_friends", + ]), + licenses = ["notice"], +) + +cc_library( + name = "attribute_importer", + srcs = ["attribute_importer.cc"], + hdrs = ["attribute_importer.h"], + deps = [ + "//xla:shape_util", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/mlir_hlo", + "//xla/service:hlo_proto_cc", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/types:span", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + ], +) + +cc_library( + name = "async_importer", + srcs = ["async_importer.cc"], + hdrs = ["async_importer.h"], + deps = [ + ":attribute_importer", + ":hlo_utils", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/mlir_hlo", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@local_tsl//tsl/platform:errors", + ], +) + +cc_library( + name = "custom_call_importer", + srcs = ["custom_call_importer.cc"], + hdrs = ["custom_call_importer.h"], + deps = [ + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/mlir_hlo", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:AsmParser", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:QuantOps", + "@llvm-project//mlir:Support", + ], +) + +cc_library( + name = "stack_location_utils", + srcs = ["stack_location_utils.cc"], + hdrs = ["stack_location_utils.h"], + deps = [ + "//xla/hlo/ir:hlo", + "//xla/service:hlo_proto_cc", + "@llvm-project//mlir:IR", + ], +) + +cc_library( + name = "hlo_function_importer", + srcs = ["hlo_function_importer.cc"], + hdrs = ["hlo_function_importer.h"], + deps = [ + ":async_importer", + ":attribute_importer", + ":custom_call_importer", + ":hlo_utils", + ":location_importer", + "//xla:comparison_util", + "//xla:literal", + "//xla:protobuf_util", + "//xla:shape_util", + "//xla:status_macros", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/mlir_hlo", + "//xla/service:hlo_proto_cc", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:AsmParser", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:SideEffectInterfaces", + "@llvm-project//mlir:SparseTensorDialect", + "@llvm-project//mlir:Support", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], +) + +cc_library( + name = "hlo_module_importer", + srcs = [ + "hlo_module_importer.cc", + ], + hdrs = [ + "hlo_module_importer.h", + ], + deps = [ + ":hlo_function_importer", + ":module_attributes_importer", + "//xla:xla_data_proto_cc", + "//xla:xla_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/mlir_hlo", + "@com_google_absl//absl/status", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:QuantOps", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], +) + +cc_library( + name = "hlo_to_mlir_hlo", + srcs = ["hlo_to_mlir_hlo.cc"], + hdrs = ["hlo_to_mlir_hlo.h"], + deps = [ + ":hlo_module_importer", + "//xla:status_macros", + "//xla/mlir/utils:error_util", + "//xla/service/llvm_ir:llvm_util", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@llvm-project//mlir:IR", + "@local_tsl//tsl/platform:errors", + ], +) + +cc_library( + name = "hlo_utils", + srcs = ["hlo_utils.cc"], + hdrs = ["hlo_utils.h"], + includes = ["include"], + deps = [ + "//xla:literal", + "//xla:shape_util", + "//xla:types", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/mlir/utils:type_util", + "//xla/mlir_hlo", + "@com_google_absl//absl/status:statusor", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:SparseTensorDialect", + "@llvm-project//mlir:SparseTensorEnums", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "hlo_utils_test", + srcs = ["hlo_utils_test.cc"], + deps = [ + ":hlo_utils", + "//xla:literal", + "//xla:literal_util", + "//xla:shape_util", + "//xla:test", + "//xla:types", + "//xla/tsl/lib/core:status_test_util", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "location_importer", + srcs = ["location_importer.cc"], + hdrs = ["location_importer.h"], + deps = [ + "stack_location_utils", + "//xla/hlo/ir:hlo", + "@llvm-project//mlir:IR", + ], +) + +cc_library( + name = "module_attributes_importer", + srcs = ["module_attributes_importer.cc"], + hdrs = ["module_attributes_importer.h"], + deps = [ + ":hlo_function_importer", + ":hlo_utils", + "//xla:shape_layout", + "//xla:shape_util", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/mlir_hlo", + "//xla/service:computation_layout", + "//xla/service:hlo_module_config", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/types:span", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + ], +) + +cc_library( + name = "translate", + srcs = ["translate.cc"], + hdrs = ["translate.h"], + deps = [ + ":hlo_to_mlir_hlo", + "//xla/mlir_hlo", + "//xla/mlir_hlo:mhlo_passes", + "//xla/service:hlo_parser", + "//xla/service:hlo_proto_cc", + "//xla/service/llvm_ir:llvm_util", + "@com_google_absl//absl/status", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@local_tsl//tsl/platform:protobuf", + ], +) + +cc_library( + name = "translate_registration", + testonly = True, + srcs = ["translate_registration.cc"], + deps = [ + ":translate", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:TranslateLib", + ], + alwayslink = 1, +) diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/async_importer.cc b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/async_importer.cc similarity index 99% rename from third_party/xla/xla/translate/hlo_to_mhlo/async_importer.cc rename to third_party/xla/xla/hlo/translate/hlo_to_mhlo/async_importer.cc index 57bc78a0ead971..f3cb7b09c9a80e 100644 --- a/third_party/xla/xla/translate/hlo_to_mhlo/async_importer.cc +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/async_importer.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/translate/hlo_to_mhlo/async_importer.h" +#include "xla/hlo/translate/hlo_to_mhlo/async_importer.h" #include #include @@ -34,9 +34,9 @@ limitations under the License. #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/translate/hlo_to_mhlo/attribute_importer.h" +#include "xla/hlo/translate/hlo_to_mhlo/hlo_utils.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" -#include "xla/translate/hlo_to_mhlo/attribute_importer.h" -#include "xla/translate/hlo_to_mhlo/hlo_utils.h" #include "xla/util.h" #include "tsl/platform/errors.h" diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/async_importer.h b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/async_importer.h similarity index 95% rename from third_party/xla/xla/translate/hlo_to_mhlo/async_importer.h rename to third_party/xla/xla/hlo/translate/hlo_to_mhlo/async_importer.h index efdd487c21f03d..906f9235f28498 100644 --- a/third_party/xla/xla/translate/hlo_to_mhlo/async_importer.h +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/async_importer.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_TRANSLATE_HLO_TO_MHLO_ASYNC_IMPORTER_H_ -#define XLA_TRANSLATE_HLO_TO_MHLO_ASYNC_IMPORTER_H_ +#ifndef XLA_HLO_TRANSLATE_HLO_TO_MHLO_ASYNC_IMPORTER_H_ +#define XLA_HLO_TRANSLATE_HLO_TO_MHLO_ASYNC_IMPORTER_H_ #include @@ -85,4 +85,4 @@ absl::StatusOr ImportAsyncOpDone( } // namespace xla -#endif // XLA_TRANSLATE_HLO_TO_MHLO_ASYNC_IMPORTER_H_ +#endif // XLA_HLO_TRANSLATE_HLO_TO_MHLO_ASYNC_IMPORTER_H_ diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/attribute_importer.cc b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/attribute_importer.cc similarity index 99% rename from third_party/xla/xla/translate/hlo_to_mhlo/attribute_importer.cc rename to third_party/xla/xla/hlo/translate/hlo_to_mhlo/attribute_importer.cc index 7e3ea9b3d9e282..b80645a7d8925f 100644 --- a/third_party/xla/xla/translate/hlo_to_mhlo/attribute_importer.cc +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/attribute_importer.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/translate/hlo_to_mhlo/attribute_importer.h" +#include "xla/hlo/translate/hlo_to_mhlo/attribute_importer.h" #include diff --git a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/attribute_importer.h b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/attribute_importer.h new file mode 100644 index 00000000000000..6a54b864e38f0d --- /dev/null +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/attribute_importer.h @@ -0,0 +1,104 @@ +/* Copyright 2019 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_TRANSLATE_HLO_TO_MHLO_ATTRIBUTE_IMPORTER_H_ +#define XLA_HLO_TRANSLATE_HLO_TO_MHLO_ATTRIBUTE_IMPORTER_H_ + +#include +#include +#include +#include + +#include "absl/status/statusor.h" +#include "absl/types/span.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" +#include "xla/service/hlo.pb.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/xla_data.pb.h" + +namespace xla { + +// Converts an XLA PrecisionConfig to the corresponding MLIR attribute. +mlir::ArrayAttr ConvertPrecisionConfig(const PrecisionConfig* config, + mlir::Builder* builder); + +// Converts the gather dimensions to attributes. +mlir::mhlo::GatherDimensionNumbersAttr ConvertGatherDimensionNumbers( + const xla::GatherDimensionNumbers& dnums, mlir::Builder* builder); + +// Converts the scatter dimensions to attributes. +mlir::mhlo::ScatterDimensionNumbersAttr ConvertScatterDimensionNumbers( + const xla::ScatterDimensionNumbers& dnums, mlir::Builder* builder); + +// Converts the dot algorithm to attributes. +mlir::mhlo::DotAlgorithmAttr ConvertDotAlgorithm( + PrecisionConfig::Algorithm algorithm, mlir::Builder* builder); + +// Converts the dot dimensions to attributes. +mlir::mhlo::DotDimensionNumbersAttr ConvertDotDimensionNumbers( + const DotDimensionNumbers& dnums, mlir::Builder* builder); + +// Converts the conv dimensions to attributes. +mlir::mhlo::ConvDimensionNumbersAttr ConvertConvDimensionNumbers( + const xla::ConvolutionDimensionNumbers& dnums, mlir::Builder* builder); + +// Converts the output operand aliasing to attributes. +mlir::ArrayAttr ConvertOutputOperandAliasing( + const std::vector>>& aliaInfo, + mlir::Builder* builder); + +// Converts the sparsity descriptor to attributes. +absl::StatusOr ConvertSparsityDescriptor( + xla::SparsityDescriptor sparsity_descriptor, mlir::Builder* builder); + +absl::StatusOr ConvertFftType(FftType type); +absl::StatusOr ConvertTranspose( + TriangularSolveOptions_Transpose transpose); + +absl::StatusOr ConvertCustomCallApiVersion( + xla::CustomCallApiVersion api_version); + +mlir::NamedAttribute ConvertChannelHandle(const ChannelHandle& channel, + mlir::Builder* builder); +mlir::NamedAttribute ConvertChannelHandle(std::optional channel_id, + mlir::Builder* builder); + +mlir::NamedAttribute ConvertReplicaGroups( + absl::Span replica_groups, mlir::Builder* builder); + +mlir::NamedAttribute ConvertSourceTargetPairs( + const std::vector>& source_target_pairs, + mlir::Builder* builder); + +mlir::NamedAttribute ConvertUseGlobalDeviceIds(mlir::Builder* builder); + +// Extracts layouts from shapes and converts it into layout attributes (array of +// rank-1 index tensors). Returns an error if any of the shapes is a tuple. +absl::StatusOr ExtractLayoutsFromShapes( + const absl::Span shapes_with_layouts, mlir::Builder* builder); + +// Extracts the layouts of each element from a tuple shape and returns them as +// an array of rank-1 index tensors. Returns an error in presence of nested +// tuple shapes. +absl::StatusOr ExtractLayoutsFromTuple(const xla::Shape shape, + mlir::Builder* builder); + +} // namespace xla + +#endif // XLA_HLO_TRANSLATE_HLO_TO_MHLO_ATTRIBUTE_IMPORTER_H_ diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/custom_call_importer.cc b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/custom_call_importer.cc similarity index 99% rename from third_party/xla/xla/translate/hlo_to_mhlo/custom_call_importer.cc rename to third_party/xla/xla/hlo/translate/hlo_to_mhlo/custom_call_importer.cc index 24f69b8ce5c595..962be80037c968 100644 --- a/third_party/xla/xla/translate/hlo_to_mhlo/custom_call_importer.cc +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/custom_call_importer.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/translate/hlo_to_mhlo/custom_call_importer.h" +#include "xla/hlo/translate/hlo_to_mhlo/custom_call_importer.h" #include #include diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/custom_call_importer.h b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/custom_call_importer.h similarity index 89% rename from third_party/xla/xla/translate/hlo_to_mhlo/custom_call_importer.h rename to third_party/xla/xla/hlo/translate/hlo_to_mhlo/custom_call_importer.h index 8ccf85c77b5f89..92424e85dd356f 100644 --- a/third_party/xla/xla/translate/hlo_to_mhlo/custom_call_importer.h +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/custom_call_importer.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_TRANSLATE_HLO_TO_MHLO_CUSTOM_CALL_IMPORTER_H_ -#define XLA_TRANSLATE_HLO_TO_MHLO_CUSTOM_CALL_IMPORTER_H_ +#ifndef XLA_HLO_TRANSLATE_HLO_TO_MHLO_CUSTOM_CALL_IMPORTER_H_ +#define XLA_HLO_TRANSLATE_HLO_TO_MHLO_CUSTOM_CALL_IMPORTER_H_ #include "absl/status/statusor.h" #include "mlir/IR/Builders.h" @@ -42,4 +42,4 @@ bool IsOpEncodedCustomCall(const HloCustomCallInstruction* instruction); } // namespace xla -#endif // XLA_TRANSLATE_HLO_TO_MHLO_CUSTOM_CALL_IMPORTER_H_ +#endif // XLA_HLO_TRANSLATE_HLO_TO_MHLO_CUSTOM_CALL_IMPORTER_H_ diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/hlo_function_importer.cc b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_function_importer.cc similarity index 99% rename from third_party/xla/xla/translate/hlo_to_mhlo/hlo_function_importer.cc rename to third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_function_importer.cc index 392794277900cc..b147a100900d42 100644 --- a/third_party/xla/xla/translate/hlo_to_mhlo/hlo_function_importer.cc +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_function_importer.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/translate/hlo_to_mhlo/hlo_function_importer.h" +#include "xla/hlo/translate/hlo_to_mhlo/hlo_function_importer.h" #include #include @@ -63,6 +63,11 @@ limitations under the License. #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_sharding.h" #include "xla/hlo/ir/hlo_sharding_metadata.h" +#include "xla/hlo/translate/hlo_to_mhlo/async_importer.h" +#include "xla/hlo/translate/hlo_to_mhlo/attribute_importer.h" +#include "xla/hlo/translate/hlo_to_mhlo/custom_call_importer.h" +#include "xla/hlo/translate/hlo_to_mhlo/hlo_utils.h" +#include "xla/hlo/translate/hlo_to_mhlo/location_importer.h" #include "xla/layout.h" #include "xla/literal.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" @@ -70,11 +75,6 @@ limitations under the License. #include "xla/service/hlo.pb.h" #include "xla/shape_util.h" #include "xla/status_macros.h" -#include "xla/translate/hlo_to_mhlo/async_importer.h" -#include "xla/translate/hlo_to_mhlo/attribute_importer.h" -#include "xla/translate/hlo_to_mhlo/custom_call_importer.h" -#include "xla/translate/hlo_to_mhlo/hlo_utils.h" -#include "xla/translate/hlo_to_mhlo/location_importer.h" #include "xla/util.h" #include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" diff --git a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_function_importer.h b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_function_importer.h new file mode 100644 index 00000000000000..c65c41e5bd9269 --- /dev/null +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_function_importer.h @@ -0,0 +1,258 @@ +/* Copyright 2019 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_TRANSLATE_HLO_TO_MHLO_HLO_FUNCTION_IMPORTER_H_ +#define XLA_HLO_TRANSLATE_HLO_TO_MHLO_HLO_FUNCTION_IMPORTER_H_ + +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Block.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/Region.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/IR/Value.h" +#include "mlir/IR/ValueRange.h" +#include "xla/comparison_util.h" +#include "xla/hlo/ir/hlo_input_output_alias_config.h" +#include "xla/hlo/ir/hlo_sharding.h" +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" +#include "xla/service/hlo.pb.h" +#include "xla/xla_data.pb.h" + +namespace xla { + +class HloModule; +class HloComputation; +class HloInstruction; +class Shape; + +// HLO bounded dynamic shapes can be converted to either MLIR dynamic shapes +// (which lose the bound information) or casted to static shape using the +// bounds. +enum class DynamicShapeHandlingMode { kDynamic, kConvertToStatic }; + +// Helper class for importing HloComputations. +class HloFunctionImporter { + public: + // Imports the given computation as a function in the given symbol table and + // returns the FuncOp. This also imports any computations referred by + // instructions in this computation. + static absl::StatusOr ImportAsFunc( + const HloComputation& computation, mlir::SymbolTable& symbol_table, + std::unordered_map* + function_map, + mlir::Builder* builder, bool is_main, + bool flatten_computation_args_result = false); + + // Imports the given hlo computation to the specified region. + // + // Flattens the tuple-typed region argument(s) and return value(s). + static absl::Status ImportAsRegion( + const HloComputation& computation, mlir::SymbolTable& symbol_table, + mlir::Region* region, mlir::Builder* builder, + bool flatten_computation_args_result = false); + + // Imports the given computation to the given place specified by `builder`. + // `arguments` contains values for all parameters. + static absl::StatusOr ImportInstructions( + const HloComputation& computation, + const llvm::SmallVectorImpl& arguments, + mlir::SymbolTable& symbol_table, mlir::OpBuilder* builder, + bool flatten_computation_args_result = false); + + static absl::StatusOr ImportInstruction( + const HloInstruction* instr, + const llvm::SmallVectorImpl& operands, + mlir::SymbolTable& symbol_table, mlir::OpBuilder* builder, + bool flatten_computation_args_result = false, + DynamicShapeHandlingMode mode = DynamicShapeHandlingMode::kDynamic); + + static void SetLayoutForMlir(mlir::Operation* op, const Shape& shape, + llvm::StringRef attr_name); + + // For mlir::IfOp or mlir::CaseOp, replace the uses of their region's block + // arguments with 'implicit_operands'. Here | implicit_operands | == sum of + // the number of arguments in all the regions in IfOp or CaseOp. + void ReplaceBlockArgumentsWithImplicitOperands( + mlir::Operation* op, llvm::ArrayRef implicit_operands); + + // FlattenTupleType flattens the types in (nested) tuple-type 'type' and + // stores them in 'flattened_types'. + static void FlattenTupleType( + mlir::Type type, llvm::SmallVectorImpl& flattened_types); + + // FlattenTupleValue flattens the values in (nested) tuple-typed 'value' and + // stores them in 'flattened_values'. + static void FlattenTupleValue( + mlir::OpBuilder* func_builder, mlir::Location loc, mlir::Value value, + llvm::SmallVectorImpl& flattened_values); + + // FlattenTupleValues flattens the values in (nested) tuple-typed 'values' and + // returns the flattened values. + static llvm::SmallVector FlattenTupleValues( + mlir::OpBuilder* func_builder, mlir::Location loc, + mlir::ValueRange values, std::optional reserve_size = std::nullopt); + + private: + HloFunctionImporter(mlir::SymbolTable& symbol_table, + std::unordered_map* function_map, + mlir::Builder* builder, + bool flatten_computation_args_result) + : context_(symbol_table.getOp()->getContext()), + symbol_table_(symbol_table), + builder_(builder), + function_map_(function_map), + flatten_computation_args_result_(flatten_computation_args_result) { + context_->loadDialect(); + context_->loadDialect(); + context_->loadDialect(); + context_->loadDialect(); + } + + // Imports the given computation as a new function, if it hasn't been already + // imported. + absl::StatusOr ImportAsFunc( + const HloComputation& computation, bool is_main); + + // Imports the given computation in the specified region. + absl::Status ImportAsRegion(const HloComputation& computation, + mlir::Region* region); + + // Imports instructions from the given computation in the specified block. + // Assumes that the block already has correct arguments populated. + absl::Status ImportInstructions(const HloComputation& computation, + mlir::Block* block); + absl::StatusOr ImportInstructionsImpl( + const HloComputation& computation, + const llvm::SmallVectorImpl& arguments, + mlir::OpBuilder* builder); + + // Imports an instruction. + absl::StatusOr ImportInstructionWithLayout( + const HloInstruction* instruction, + const llvm::SmallVectorImpl& operands, + mlir::OpBuilder* func_builder, + DynamicShapeHandlingMode mode = DynamicShapeHandlingMode::kDynamic); + + absl::StatusOr ImportInstructionImpl( + const HloInstruction* instruction, + const llvm::SmallVectorImpl& operands, + mlir::OpBuilder* func_builder, + DynamicShapeHandlingMode mode = DynamicShapeHandlingMode::kDynamic); + + // Gets the MLIR operand values from an HLO Instruction. + absl::StatusOr> GetOperands( + const HloInstruction* instruction); + + // Converts xla Tensor type to the corresponding MLIR type. + absl::StatusOr ConvertTensorType(const Shape& shape); + + // Converts an XLA shape/layout to the corresponding MLIR layout, in + // flattened_attr, while flattening the tuple layout. + absl::Status ConvertShapeToMlirLayout( + const Shape& shape, + llvm::SmallVectorImpl& flattened_attr); + + // Returns the output type of an HloInstruction. + absl::StatusOr GetReturnType(const HloInstruction* instruction); + + // Takes a list of HloInstructions and generates the list of types used for + // input, bypassing tuples to subsets. + absl::Status GetMlirTypes( + absl::Span instructions, + llvm::SmallVectorImpl* types); + + // Returns the Mlir Value for the corresponding HloInstruction. + absl::StatusOr GetMlirValue(const HloInstruction* instruction); + + // TODO(b/179166199): Move attribute converters to attribute_importer. + // Converts an XLA ComparisonDirection to the corresponding MLIR attribute. + mlir::NamedAttribute ConvertComparisonDirection( + ComparisonDirection direction); + + // Converts an XLA Comparison::Type to the corresponding MLIR attribute. + mlir::NamedAttribute ConvertComparisonType(Comparison::Type type); + + // Converts an XLA CustomCallSchedule to the corresponding MLIR attribute. + mlir::NamedAttribute ConvertCustomCallSchedule(CustomCallSchedule schedule); + + // Converts the dimensions of an HLO instruction into an MLIR attribute. + mlir::DenseIntElementsAttr ConvertDimensions( + absl::Span op_dimensions); + + // Converts Array ref to an DenseIntElementsAttr. + mlir::DenseIntElementsAttr Convert(llvm::ArrayRef elements); + + // Converts Array ref of bools to a DenseIntElementsAttr of I1 type. + mlir::DenseIntElementsAttr Convert(llvm::ArrayRef elements); + + // Converts Array ref to padding attribute. Input is a flattened list of + // padding low and padding high for each of the spatial dimensions. + mlir::NamedAttribute ConvertPadding(llvm::ArrayRef padding); + + mlir::MLIRContext* context_; + + // SymbolTable to which new functions should be inserted. + mlir::SymbolTable& symbol_table_; + + mlir::Builder* builder_; + + // Mapping from HloComputation to the created MLIR function. + std::unordered_map* function_map_; + + // Mapping from HloInstructions to the associative MLIR values. + std::unordered_map instruction_value_map_; + + bool flatten_computation_args_result_; +}; + +// Returns a StringAttr that carries a prettyprinted representation of the +// given HLO C++ input_output_alias_config. +// Always succeeds and returns a non-empty attribute. +mlir::Attribute ConvertInputOutputAlias(const HloInputOutputAliasConfig& alias, + mlir::Builder* builder); + +// Returns a StringAttr that carries a prettyprinted representation of the +// given HLO C++ sharding. +// Always succeeds and returns a non-empty attribute. +mlir::Attribute ConvertSharding(const HloSharding& sharding, + mlir::Builder* builder); + +// Returns a StringAttr that carries a prettyprinted representation of the +// given HLO proto sharding. +// Will fail and return an empty attribute if the proto sharding cannot be +// converted to the C++ sharding. +mlir::Attribute ConvertSharding(const OpSharding& sharding, + mlir::Builder* builder); + +} // namespace xla + +#endif // XLA_HLO_TRANSLATE_HLO_TO_MHLO_HLO_FUNCTION_IMPORTER_H_ diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/hlo_module_importer.cc b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_module_importer.cc similarity index 95% rename from third_party/xla/xla/translate/hlo_to_mhlo/hlo_module_importer.cc rename to third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_module_importer.cc index 8ad9d3844438e7..00e591751913fa 100644 --- a/third_party/xla/xla/translate/hlo_to_mhlo/hlo_module_importer.cc +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_module_importer.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/translate/hlo_to_mhlo/hlo_module_importer.h" +#include "xla/hlo/translate/hlo_to_mhlo/hlo_module_importer.h" #include @@ -27,9 +27,9 @@ limitations under the License. #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/translate/hlo_to_mhlo/hlo_function_importer.h" +#include "xla/hlo/translate/hlo_to_mhlo/module_attributes_importer.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" -#include "xla/translate/hlo_to_mhlo/hlo_function_importer.h" -#include "xla/translate/hlo_to_mhlo/module_attributes_importer.h" #include "xla/xla.pb.h" #include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" diff --git a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_module_importer.h b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_module_importer.h new file mode 100644 index 00000000000000..8937f673035a23 --- /dev/null +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_module_importer.h @@ -0,0 +1,64 @@ +/* Copyright 2019 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_TRANSLATE_HLO_TO_MHLO_HLO_MODULE_IMPORTER_H_ +#define XLA_HLO_TRANSLATE_HLO_TO_MHLO_HLO_MODULE_IMPORTER_H_ + +#include + +#include "absl/status/status.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/SymbolTable.h" +#include "xla/xla_data.pb.h" + +namespace xla { +class HloModule; +class HloModuleProto; +class HloComputation; +class HloInstruction; +class Shape; + +// Importer that takes an HloModule and imports it as an MLIR module in the XLA +// dialect. HloModuleImporter does not take ownership. +class HloModuleImporter { + public: + explicit HloModuleImporter(mlir::ModuleOp module, + bool import_all_computation = false, + bool flatten_computation_args_result = false); + + // Import the HloModule into the MLIR Module. + absl::Status Import(const xla::HloModule& module); + + // Import the HloModuleProto into the MLIR Module. + absl::Status Import(const xla::HloModuleProto& module); + + private: + bool import_all_computation_; + bool flatten_computation_args_result_; + mlir::SymbolTable symbol_table_; + mlir::Builder builder_; + + // Map for tracking which MLIR function map to which HLO Computation. This + // tracks functions as they are imported and provides a quick lookup for + // functions invoked by control flow related operations (e.g. while, call). + std::unordered_map + function_map_; +}; + +} // namespace xla + +#endif // XLA_HLO_TRANSLATE_HLO_TO_MHLO_HLO_MODULE_IMPORTER_H_ diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/hlo_to_mlir_hlo.cc b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_to_mlir_hlo.cc similarity index 96% rename from third_party/xla/xla/translate/hlo_to_mhlo/hlo_to_mlir_hlo.cc rename to third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_to_mlir_hlo.cc index d6dafe01300c82..d9d50f47e13448 100644 --- a/third_party/xla/xla/translate/hlo_to_mhlo/hlo_to_mlir_hlo.cc +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_to_mlir_hlo.cc @@ -13,16 +13,16 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h" +#include "xla/hlo/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h" #include "absl/status/statusor.h" #include "mlir/IR/Location.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/OwningOpRef.h" +#include "xla/hlo/translate/hlo_to_mhlo/hlo_module_importer.h" #include "xla/mlir/utils/error_util.h" #include "xla/service/llvm_ir/llvm_util.h" #include "xla/status_macros.h" -#include "xla/translate/hlo_to_mhlo/hlo_module_importer.h" #include "tsl/platform/errors.h" namespace xla { diff --git a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h new file mode 100644 index 00000000000000..2489106527569b --- /dev/null +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h @@ -0,0 +1,71 @@ +/* Copyright 2019 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_TRANSLATE_HLO_TO_MHLO_HLO_TO_MLIR_HLO_H_ +#define XLA_HLO_TRANSLATE_HLO_TO_MHLO_HLO_TO_MLIR_HLO_H_ + +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OwningOpRef.h" + +namespace mlir { +class ModuleOp; +} // namespace mlir + +namespace xla { +class HloModule; +class HloModuleProto; + +// Converts an HLO module proto to a MLIR module in HLO dialect. +// +// If `import_all_computation` is set to true, imports all computations +// irrespective if transitively called from entry computation. +// +// If `flatten_computation_args_result` is set to true, flattens all tuple +// arguments and result of every computation when importing them as func ops. +absl::StatusOr> ConvertHloToMlirHlo( + mlir::MLIRContext& ctx, xla::HloModuleProto const* hlo_module, + bool import_all_computations = false, + bool flatten_computation_args_result = false); + +absl::Status ConvertHloToMlirHlo(mlir::ModuleOp module, + xla::HloModuleProto const* hlo_module, + bool import_all_computations = false, + bool flatten_computation_args_result = false); + +// Converts an HLO module to a MLIR module in HLO dialect. +// +// If `import_all_computation` is set to true, imports all computations +// irrespective if transitively called from entry computation. +// +// If `flatten_computation_args_result` is set to true, flattens all tuple +// arguments and result of every computation when importing them as func ops. +absl::StatusOr> ConvertHloToMlirHlo( + mlir::MLIRContext& ctx, const xla::HloModule* hlo_module, + bool import_all_computations = false, + bool flatten_computation_args_result = false); + +absl::Status ConvertHloToMlirHlo(mlir::ModuleOp module, + const xla::HloModule* hlo_module, + bool import_all_computations = false, + bool flatten_computation_args_result = false); + +} // namespace xla + +#endif // XLA_HLO_TRANSLATE_HLO_TO_MHLO_HLO_TO_MLIR_HLO_H_ diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/hlo_utils.cc b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_utils.cc similarity index 99% rename from third_party/xla/xla/translate/hlo_to_mhlo/hlo_utils.cc rename to third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_utils.cc index e6004cfe5291d6..564440ac00edcb 100644 --- a/third_party/xla/xla/translate/hlo_to_mhlo/hlo_utils.cc +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_utils.cc @@ -15,7 +15,7 @@ limitations under the License. // This file defines helpers useful when creating or manipulating lhlo/hlo. -#include "xla/translate/hlo_to_mhlo/hlo_utils.h" +#include "xla/hlo/translate/hlo_to_mhlo/hlo_utils.h" #include #include diff --git a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_utils.h b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_utils.h new file mode 100644 index 00000000000000..3e9cf589d84797 --- /dev/null +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_utils.h @@ -0,0 +1,249 @@ +/* Copyright 2019 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file defines helpers useful when creating or manipulating lhlo/hlo. + +#ifndef XLA_HLO_TRANSLATE_HLO_TO_MHLO_HLO_UTILS_H_ +#define XLA_HLO_TRANSLATE_HLO_TO_MHLO_HLO_UTILS_H_ + +#include +#include +#include +#include + +#include "absl/status/statusor.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/Dialect/SparseTensor/IR/Enums.h" +#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/Value.h" +#include "mlir/IR/ValueRange.h" +#include "xla/layout.h" +#include "xla/layout_util.h" +#include "xla/literal.h" +#include "xla/mlir/utils/type_util.h" +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" +#include "xla/util.h" +#include "tsl/platform/statusor.h" + +namespace xla { + +absl::StatusOr CreateDenseElementsAttrFromLiteral( + const LiteralBase& literal, mlir::Builder builder); + +// Creates an DenseIntElementsAttr using the elements of the vector and the +// optional shape. +mlir::DenseIntElementsAttr CreateDenseIntElementsAttrFromVector( + const llvm::ArrayRef vector, mlir::Builder builder, + llvm::ArrayRef shape = {}); + +// Converts the given XLA shape for tensors to the template MLIR type. +template +static absl::StatusOr ConvertTensorShapeToType(const Shape& xla_ty, + mlir::Builder builder) { + auto element_type_or = + ConvertPrimitiveTypeToMlirType(xla_ty.element_type(), builder); + if (!element_type_or.ok()) return element_type_or.status(); + + bool is_bounded_dynamic = false; + int64_t rank = xla_ty.rank(); + llvm::SmallVector shape(rank, mlir::ShapedType::kDynamic); + llvm::SmallVector bounds(rank, mlir::ShapedType::kDynamic); + for (int64_t dim = 0; dim < rank; ++dim) { + int64_t dim_size = xla_ty.dimensions(dim); + if (xla_ty.is_dynamic_dimension(dim)) { + if (!xla_ty.is_unbounded_dynamic_dimension(dim)) { + bounds[dim] = dim_size; + is_bounded_dynamic = true; + } + } else { + shape[dim] = dim_size; + } + } + using mlir::mhlo::TypeExtensionsAttr; + mlir::Attribute encoding; + if (is_bounded_dynamic) { + encoding = TypeExtensionsAttr::get(builder.getContext(), bounds); + } + + using mlir::sparse_tensor::SparseTensorEncodingAttr; + // TODO(b/238903065): We don't yet support bounded dynamism shapes and + // sparsity at the same time, as we can currently only have one `encoding` on + // a RankedTensorType, and we don't currently have a meet of + // SparseTensorEncodingAttr and TypeExtensionsAttr (which holds bounds). + // + // For example, we wouldn't be able to represent the xla type + // `f32[4,<=4]{1,0:D(D,C)}`. + if (xla_ty.has_layout()) { + auto layout = xla_ty.layout(); + if (LayoutUtil::IsSparse(layout)) { + if (is_bounded_dynamic) + return Unimplemented( + "MHLO doesn't support bounded dynamic shapes for sparse tensors"); + llvm::SmallVector lts; + for (size_t i = 0, e = layout.dim_level_types_size(); i < e; ++i) { + auto dlt = layout.dim_level_type(i); + bool ordered = + i < layout.dim_ordered_size() ? layout.dim_ordered(i) : true; + bool unique = + i < layout.dim_unique_size() ? layout.dim_unique(i) : true; + switch (dlt) { + case DimLevelType::DIM_DENSE: + lts.push_back(*mlir::sparse_tensor::buildLevelType( + mlir::sparse_tensor::LevelFormat::Dense, ordered, unique)); + break; + case DimLevelType::DIM_COMPRESSED: + lts.push_back(*mlir::sparse_tensor::buildLevelType( + mlir::sparse_tensor::LevelFormat::Compressed, ordered, unique)); + break; + case DimLevelType::DIM_SINGLETON: + lts.push_back(*mlir::sparse_tensor::buildLevelType( + mlir::sparse_tensor::LevelFormat::Singleton, ordered, unique)); + break; + case DimLevelType::DIM_LOOSE_COMPRESSED: + lts.push_back(*mlir::sparse_tensor::buildLevelType( + mlir::sparse_tensor::LevelFormat::LooseCompressed, ordered, + unique)); + break; + default: + return InvalidArgument("Unknown DimLevelType from HLO"); + } + } + auto ordering = layout.minor_to_major(); + llvm::SmallVector major_to_minor = {ordering.rbegin(), + ordering.rend()}; + auto id_map = mlir::AffineMap::getPermutationMap(major_to_minor, + builder.getContext()); + // TODO(atondwal): support sizes other than 32 when XLA does + encoding = SparseTensorEncodingAttr::get( + builder.getContext(), lts, id_map, mlir::AffineMap(), 32, 32); + } + } + return TypeT::get(shape, element_type_or.value(), encoding); +} + +absl::StatusOr ConvertTensorShapeToMemRefType( + const Shape& shape, mlir::Builder builder); + +template <> +inline absl::StatusOr ConvertTensorShapeToType( + const Shape& shape, mlir::Builder builder) { + if (shape.is_dynamic()) { + return FailedPrecondition( // NOLINT + "MemRefType don't support dynamic shapes"); + } + return ConvertTensorShapeToMemRefType(shape, builder); +} + +// Converts the given XLA shape to the template MLIR type. +template +static absl::StatusOr ConvertShapeToType(const Shape& shape, + mlir::Builder builder) { + if (shape.IsTuple()) { + llvm::SmallVector contents; + contents.reserve(shape.tuple_shapes_size()); + for (const auto& subtype : shape.tuple_shapes()) { + TF_ASSIGN_OR_RETURN(auto mlir_subtype, + ConvertShapeToType(subtype, builder)); + contents.push_back(mlir_subtype); + } + return builder.getTupleType(contents); + } + if (shape.IsToken()) { + return mlir::mhlo::TokenType::get(builder.getContext()); + } + return ConvertTensorShapeToType(shape, builder); +} + +// CreateTupleValue creates a root TupleOp of (nested) tuple-type 'type' using +// the non-tuple-typed values in 'flatten_values'. +// +// e.g., Given 'flatten_values': [V1, V2, V3] &'type': tuple>, +// The function returns %t2 such that: +// %t1 = mhlo.tuple(V2,V3) : (T2,T3) -> tuple +// %t2 = mhlo.tuple(V1,%t1): (T1,tuple) -> tuple> +// +// Note: 1. FlattenTupleValue and CreateTupleValue is a pair of functions to +// resp. flatten and create tuples in the exact same order. +// 2. `flatten_values`, initially storing the flattened values, will be +// mutated to a 0-length array by the end of function invocation. +mlir::Value CreateTupleValue(mlir::OpBuilder* func_builder, mlir::Location loc, + mlir::ValueRange& flatten_values, mlir::Type type); + +// Create a TupleOp using the results of 'op' if 'type' is a mlir::TupleType. +// Otherwise, return 'op'. +mlir::Operation* CreateTupleFromOpResults(mlir::OpBuilder* func_builder, + mlir::Location loc, + mlir::Operation* op, mlir::Type type); + +mlir::TypeRange Untuple(const mlir::Type& type); + +static std::pair GetLayoutAttribute( + mlir::Builder& b, const Shape& shape, + std::optional maybe_layout = std::nullopt) { + if (shape.IsTuple()) { + llvm::SmallVector element_attrs; + llvm::SmallVector tile_attrs; + for (const auto& tuple_shape : shape.tuple_shapes()) { + // TODO here we do not dissect the layout of a tuple into sublayouts. + // Presently ShapeLayout cannot represent an explicit layout for a tuple + // type so this should never occur. However, if this function were to + // be used in another context where this assumption were to be lifted. + // users should be aware of this limitation which will use the default + // layout for tuple subshapes. + std::pair inner = + tuple_shape.has_layout() + ? GetLayoutAttribute(b, tuple_shape, tuple_shape.layout()) + : GetLayoutAttribute(b, tuple_shape); + element_attrs.push_back(inner.first); + tile_attrs.push_back(inner.second); + } + return std::make_pair((mlir::Attribute)b.getArrayAttr(element_attrs), + b.getArrayAttr(tile_attrs)); + } + + Layout layout = maybe_layout.value_or( + shape.has_layout() ? shape.layout() + : LayoutUtil::GetDefaultLayoutForShape(shape)); + + llvm::SmallVector vec_of_tiles; + for (const Tile& tile : layout.tiles()) { + llvm::SmallVector tile_vec = {tile.dimensions().begin(), + tile.dimensions().end()}; + vec_of_tiles.push_back(b.getIndexTensorAttr(tile_vec)); + } + llvm::SmallVector layout_vec = {layout.minor_to_major().begin(), + layout.minor_to_major().end()}; + return std::make_pair(b.getIndexTensorAttr(layout_vec), + b.getArrayAttr(vec_of_tiles)); +} + +static bool HasCustomLayout(const Shape& shape) { + if (shape.IsTuple()) { + return llvm::any_of(shape.tuple_shapes(), HasCustomLayout); + } + return shape.has_layout() && !shape.layout().minor_to_major().empty() && + shape.layout() != LayoutUtil::GetDefaultLayoutForShape(shape); +} + +} // namespace xla + +#endif // XLA_HLO_TRANSLATE_HLO_TO_MHLO_HLO_UTILS_H_ diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/hlo_utils_test.cc b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_utils_test.cc similarity index 97% rename from third_party/xla/xla/translate/hlo_to_mhlo/hlo_utils_test.cc rename to third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_utils_test.cc index b16e5870e99d79..87b3f685872ee8 100644 --- a/third_party/xla/xla/translate/hlo_to_mhlo/hlo_utils_test.cc +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_utils_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/translate/hlo_to_mhlo/hlo_utils.h" +#include "xla/hlo/translate/hlo_to_mhlo/hlo_utils.h" #include #include diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/location_importer.cc b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/location_importer.cc similarity index 93% rename from third_party/xla/xla/translate/hlo_to_mhlo/location_importer.cc rename to third_party/xla/xla/hlo/translate/hlo_to_mhlo/location_importer.cc index b39d971141240a..f5b92846b4ed83 100644 --- a/third_party/xla/xla/translate/hlo_to_mhlo/location_importer.cc +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/location_importer.cc @@ -13,13 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/translate/hlo_to_mhlo/location_importer.h" +#include "xla/hlo/translate/hlo_to_mhlo/location_importer.h" #include #include "mlir/IR/Builders.h" #include "mlir/IR/Location.h" -#include "xla/translate/hlo_to_mhlo/stack_location_utils.h" +#include "xla/hlo/translate/hlo_to_mhlo/stack_location_utils.h" namespace mlir { namespace mhlo { diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/location_importer.h b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/location_importer.h similarity index 85% rename from third_party/xla/xla/translate/hlo_to_mhlo/location_importer.h rename to third_party/xla/xla/hlo/translate/hlo_to_mhlo/location_importer.h index 23307e7fe135b7..08c313f478f29c 100644 --- a/third_party/xla/xla/translate/hlo_to_mhlo/location_importer.h +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/location_importer.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_TRANSLATE_HLO_TO_MHLO_LOCATION_IMPORTER_H_ -#define XLA_TRANSLATE_HLO_TO_MHLO_LOCATION_IMPORTER_H_ +#ifndef XLA_HLO_TRANSLATE_HLO_TO_MHLO_LOCATION_IMPORTER_H_ +#define XLA_HLO_TRANSLATE_HLO_TO_MHLO_LOCATION_IMPORTER_H_ #include "mlir/IR/Location.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -30,4 +30,4 @@ mlir::Location GenerateInstructionLocation( } // namespace mhlo } // namespace mlir -#endif // XLA_TRANSLATE_HLO_TO_MHLO_LOCATION_IMPORTER_H_ +#endif // XLA_HLO_TRANSLATE_HLO_TO_MHLO_LOCATION_IMPORTER_H_ diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/module_attributes_importer.cc b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/module_attributes_importer.cc similarity index 98% rename from third_party/xla/xla/translate/hlo_to_mhlo/module_attributes_importer.cc rename to third_party/xla/xla/hlo/translate/hlo_to_mhlo/module_attributes_importer.cc index d445640802dc56..02eb02fd869d8a 100644 --- a/third_party/xla/xla/translate/hlo_to_mhlo/module_attributes_importer.cc +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/module_attributes_importer.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/translate/hlo_to_mhlo/module_attributes_importer.h" +#include "xla/hlo/translate/hlo_to_mhlo/module_attributes_importer.h" #include #include @@ -30,6 +30,8 @@ limitations under the License. #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_sharding.h" +#include "xla/hlo/translate/hlo_to_mhlo/hlo_function_importer.h" +#include "xla/hlo/translate/hlo_to_mhlo/hlo_utils.h" #include "xla/layout.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/service/computation_layout.h" @@ -37,8 +39,6 @@ limitations under the License. #include "xla/shape.h" #include "xla/shape_layout.h" #include "xla/shape_util.h" -#include "xla/translate/hlo_to_mhlo/hlo_function_importer.h" -#include "xla/translate/hlo_to_mhlo/hlo_utils.h" #include "xla/util.h" #include "xla/xla_data.pb.h" diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/module_attributes_importer.h b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/module_attributes_importer.h similarity index 92% rename from third_party/xla/xla/translate/hlo_to_mhlo/module_attributes_importer.h rename to third_party/xla/xla/hlo/translate/hlo_to_mhlo/module_attributes_importer.h index b29c09c86e29a4..dfd6f7be702699 100644 --- a/third_party/xla/xla/translate/hlo_to_mhlo/module_attributes_importer.h +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/module_attributes_importer.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_TRANSLATE_HLO_TO_MHLO_MODULE_ATTRIBUTES_IMPORTER_H_ -#define XLA_TRANSLATE_HLO_TO_MHLO_MODULE_ATTRIBUTES_IMPORTER_H_ +#ifndef XLA_HLO_TRANSLATE_HLO_TO_MHLO_MODULE_ATTRIBUTES_IMPORTER_H_ +#define XLA_HLO_TRANSLATE_HLO_TO_MHLO_MODULE_ATTRIBUTES_IMPORTER_H_ #include "mlir/IR/AffineMap.h" #include "mlir/IR/BuiltinOps.h" @@ -63,4 +63,4 @@ void ImportUseAutoSpmdPartitioning(const HloModule& hlo_module, } // namespace xla -#endif // XLA_TRANSLATE_HLO_TO_MHLO_MODULE_ATTRIBUTES_IMPORTER_H_ +#endif // XLA_HLO_TRANSLATE_HLO_TO_MHLO_MODULE_ATTRIBUTES_IMPORTER_H_ diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/stack_location_utils.cc b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/stack_location_utils.cc similarity index 96% rename from third_party/xla/xla/translate/hlo_to_mhlo/stack_location_utils.cc rename to third_party/xla/xla/hlo/translate/hlo_to_mhlo/stack_location_utils.cc index d09b8ff1d56f18..ef274b4ede2121 100644 --- a/third_party/xla/xla/translate/hlo_to_mhlo/stack_location_utils.cc +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/stack_location_utils.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/translate/hlo_to_mhlo/stack_location_utils.h" +#include "xla/hlo/translate/hlo_to_mhlo/stack_location_utils.h" #include diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/stack_location_utils.h b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/stack_location_utils.h similarity index 85% rename from third_party/xla/xla/translate/hlo_to_mhlo/stack_location_utils.h rename to third_party/xla/xla/hlo/translate/hlo_to_mhlo/stack_location_utils.h index 09df9d91d148fb..f5210d558d9152 100644 --- a/third_party/xla/xla/translate/hlo_to_mhlo/stack_location_utils.h +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/stack_location_utils.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_TRANSLATE_HLO_TO_MHLO_STACK_LOCATION_UTILS_H_ -#define XLA_TRANSLATE_HLO_TO_MHLO_STACK_LOCATION_UTILS_H_ +#ifndef XLA_HLO_TRANSLATE_HLO_TO_MHLO_STACK_LOCATION_UTILS_H_ +#define XLA_HLO_TRANSLATE_HLO_TO_MHLO_STACK_LOCATION_UTILS_H_ #include "mlir/IR/Builders.h" #include "mlir/IR/Location.h" @@ -31,4 +31,4 @@ mlir::Location GetLocationFromFrameIndex(int frame_id, mlir::Builder &builder, } // namespace mhlo } // namespace mlir -#endif // XLA_TRANSLATE_HLO_TO_MHLO_STACK_LOCATION_UTILS_H_ +#endif // XLA_HLO_TRANSLATE_HLO_TO_MHLO_STACK_LOCATION_UTILS_H_ diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/tests/BUILD b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/BUILD similarity index 96% rename from third_party/xla/xla/translate/hlo_to_mhlo/tests/BUILD rename to third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/BUILD index 8e0d7d63df3551..ba0be262e73ab4 100644 --- a/third_party/xla/xla/translate/hlo_to_mhlo/tests/BUILD +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/BUILD @@ -39,7 +39,7 @@ lit_test_suite( ), cfg = "//xla:lit.cfg.py", tools = [ - "//xla/translate:xla-translate", + "//xla/hlo/translate:xla-translate", "@llvm-project//llvm:FileCheck", "@llvm-project//llvm:not", ], diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/tests/attributes.hlo b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/attributes.hlo similarity index 90% rename from third_party/xla/xla/translate/hlo_to_mhlo/tests/attributes.hlo rename to third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/attributes.hlo index 0e927b07c5ab29..a76befcd04bc13 100644 --- a/third_party/xla/xla/translate/hlo_to_mhlo/tests/attributes.hlo +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/attributes.hlo @@ -7,7 +7,7 @@ HloModule dot_algorithm_f8_f8_f32, entry_computation_layout={(f32[2,2,2]{2,1,0}, ENTRY %main.4 (Arg_0.1: f32[2,2,2], Arg_1.2: f32[2,2,2]) -> f32[2,2,2] { %Arg_0.1 = f32[2,2,2] parameter(0) %Arg_1.2 = f32[2,2,2] parameter(1) - ROOT %dot.3 = f32[2,2,2] dot(f32[2,2,2] %Arg_0.1, f32[2,2,2] %Arg_1.2), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1}, algorithm=dot_any_f8_any_f8_f32, metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/attributes.mlir:1 offset " source_line=7} + ROOT %dot.3 = f32[2,2,2] dot(f32[2,2,2] %Arg_0.1, f32[2,2,2] %Arg_1.2), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1}, algorithm=dot_any_f8_any_f8_f32, metadata={source_file="within split at third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo/tests/attributes.mlir:1 offset " source_line=7} } // ----- @@ -20,7 +20,7 @@ HloModule dot_algorithm_f8_f8_f32_fast_accum, entry_computation_layout={(f32[2,2 ENTRY %main.4 (Arg_0.1: f32[2,2,2], Arg_1.2: f32[2,2,2]) -> f32[2,2,2] { %Arg_0.1 = f32[2,2,2] parameter(0) %Arg_1.2 = f32[2,2,2] parameter(1) - ROOT %dot.3 = f32[2,2,2] dot(f32[2,2,2] %Arg_0.1, f32[2,2,2] %Arg_1.2), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1}, algorithm=dot_any_f8_any_f8_f32_fast_accum, metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/attributes.mlir:23 offset " source_line=7} + ROOT %dot.3 = f32[2,2,2] dot(f32[2,2,2] %Arg_0.1, f32[2,2,2] %Arg_1.2), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1}, algorithm=dot_any_f8_any_f8_f32_fast_accum, metadata={source_file="within split at third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo/tests/attributes.mlir:23 offset " source_line=7} } // ----- @@ -32,7 +32,7 @@ HloModule dot_algorithm_f16_f16_f16, entry_computation_layout={(f32[2,2,2]{2,1,0 ENTRY %main.4 (Arg_0.1: f32[2,2,2], Arg_1.2: f32[2,2,2]) -> f32[2,2,2] { %Arg_0.1 = f32[2,2,2] parameter(0) %Arg_1.2 = f32[2,2,2] parameter(1) - ROOT %dot.3 = f32[2,2,2] dot(f32[2,2,2] %Arg_0.1, f32[2,2,2] %Arg_1.2), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1}, algorithm=dot_f16_f16_f16, metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/attributes.mlir:45 offset " source_line=7} + ROOT %dot.3 = f32[2,2,2] dot(f32[2,2,2] %Arg_0.1, f32[2,2,2] %Arg_1.2), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1}, algorithm=dot_f16_f16_f16, metadata={source_file="within split at third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo/tests/attributes.mlir:45 offset " source_line=7} } // ----- @@ -44,7 +44,7 @@ HloModule dot_algorithm_f16_f16_f32, entry_computation_layout={(f32[2,2,2]{2,1,0 ENTRY %main.4 (Arg_0.1: f32[2,2,2], Arg_1.2: f32[2,2,2]) -> f32[2,2,2] { %Arg_0.1 = f32[2,2,2] parameter(0) %Arg_1.2 = f32[2,2,2] parameter(1) - ROOT %dot.3 = f32[2,2,2] dot(f32[2,2,2] %Arg_0.1, f32[2,2,2] %Arg_1.2), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1}, algorithm=dot_f16_f16_f32, metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/attributes.mlir:67 offset " source_line=7} + ROOT %dot.3 = f32[2,2,2] dot(f32[2,2,2] %Arg_0.1, f32[2,2,2] %Arg_1.2), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1}, algorithm=dot_f16_f16_f32, metadata={source_file="within split at third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo/tests/attributes.mlir:67 offset " source_line=7} } // ----- @@ -56,7 +56,7 @@ HloModule dot_algorithm_bf16_bf16_bf16, entry_computation_layout={(bf16[2,2,2]{2 ENTRY %main.4 (Arg_0.1: bf16[2,2,2], Arg_1.2: bf16[2,2,2]) -> bf16[2,2,2] { %Arg_0.1 = bf16[2,2,2] parameter(0) %Arg_1.2 = bf16[2,2,2] parameter(1) - ROOT %dot.3 = bf16[2,2,2] dot(bf16[2,2,2] %Arg_0.1, bf16[2,2,2] %Arg_1.2), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1}, algorithm=dot_bf16_bf16_bf16, metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/attributes.mlir:89 offset " source_line=7} + ROOT %dot.3 = bf16[2,2,2] dot(bf16[2,2,2] %Arg_0.1, bf16[2,2,2] %Arg_1.2), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1}, algorithm=dot_bf16_bf16_bf16, metadata={source_file="within split at third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo/tests/attributes.mlir:89 offset " source_line=7} } // ----- @@ -68,7 +68,7 @@ HloModule dot_algorithm_bf16_bf16_f32, entry_computation_layout={(bf16[2,2,2]{2, ENTRY %main.4 (Arg_0.1: bf16[2,2,2], Arg_1.2: bf16[2,2,2]) -> bf16[2,2,2] { %Arg_0.1 = bf16[2,2,2] parameter(0) %Arg_1.2 = bf16[2,2,2] parameter(1) - ROOT %dot.3 = bf16[2,2,2] dot(bf16[2,2,2] %Arg_0.1, bf16[2,2,2] %Arg_1.2), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1}, algorithm=dot_bf16_bf16_f32, metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/attributes.mlir:111 offset " source_line=7} + ROOT %dot.3 = bf16[2,2,2] dot(bf16[2,2,2] %Arg_0.1, bf16[2,2,2] %Arg_1.2), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1}, algorithm=dot_bf16_bf16_f32, metadata={source_file="within split at third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo/tests/attributes.mlir:111 offset " source_line=7} } // ----- @@ -80,7 +80,7 @@ HloModule dot_algorithm_bf16_bf16_f32_x3, entry_computation_layout={(bf16[2,2,2] ENTRY %main.4 (Arg_0.1: bf16[2,2,2], Arg_1.2: bf16[2,2,2]) -> bf16[2,2,2] { %Arg_0.1 = bf16[2,2,2] parameter(0) %Arg_1.2 = bf16[2,2,2] parameter(1) - ROOT %dot.3 = bf16[2,2,2] dot(bf16[2,2,2] %Arg_0.1, bf16[2,2,2] %Arg_1.2), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1}, algorithm=dot_bf16_bf16_f32_x3, metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/attributes.mlir:133 offset " source_line=7} + ROOT %dot.3 = bf16[2,2,2] dot(bf16[2,2,2] %Arg_0.1, bf16[2,2,2] %Arg_1.2), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1}, algorithm=dot_bf16_bf16_f32_x3, metadata={source_file="within split at third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo/tests/attributes.mlir:133 offset " source_line=7} } // ----- @@ -92,7 +92,7 @@ HloModule dot_algorithm_bf16_bf16_f32_x6, entry_computation_layout={(bf16[2,2,2] ENTRY %main.4 (Arg_0.1: bf16[2,2,2], Arg_1.2: bf16[2,2,2]) -> bf16[2,2,2] { %Arg_0.1 = bf16[2,2,2] parameter(0) %Arg_1.2 = bf16[2,2,2] parameter(1) - ROOT %dot.3 = bf16[2,2,2] dot(bf16[2,2,2] %Arg_0.1, bf16[2,2,2] %Arg_1.2), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1}, algorithm=dot_bf16_bf16_f32_x6, metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/attributes.mlir:155 offset " source_line=7} + ROOT %dot.3 = bf16[2,2,2] dot(bf16[2,2,2] %Arg_0.1, bf16[2,2,2] %Arg_1.2), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1}, algorithm=dot_bf16_bf16_f32_x6, metadata={source_file="within split at third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo/tests/attributes.mlir:155 offset " source_line=7} } // ----- @@ -104,7 +104,7 @@ HloModule dot_algorithm_tf32_tf32_f32, entry_computation_layout={(f32[2,2,2]{2,1 ENTRY %main.4 (Arg_0.1: f32[2,2,2], Arg_1.2: f32[2,2,2]) -> f32[2,2,2] { %Arg_0.1 = f32[2,2,2] parameter(0) %Arg_1.2 = f32[2,2,2] parameter(1) - ROOT %dot.3 = f32[2,2,2] dot(f32[2,2,2] %Arg_0.1, f32[2,2,2] %Arg_1.2), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1}, algorithm=dot_tf32_tf32_f32, metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/attributes.mlir:177 offset " source_line=7} + ROOT %dot.3 = f32[2,2,2] dot(f32[2,2,2] %Arg_0.1, f32[2,2,2] %Arg_1.2), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1}, algorithm=dot_tf32_tf32_f32, metadata={source_file="within split at third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo/tests/attributes.mlir:177 offset " source_line=7} } // ----- @@ -116,7 +116,7 @@ HloModule dot_algorithm_tf32_tf32_f32_x3, entry_computation_layout={(f32[2,2,2]{ ENTRY %main.4 (Arg_0.1: f32[2,2,2], Arg_1.2: f32[2,2,2]) -> f32[2,2,2] { %Arg_0.1 = f32[2,2,2] parameter(0) %Arg_1.2 = f32[2,2,2] parameter(1) - ROOT %dot.3 = f32[2,2,2] dot(f32[2,2,2] %Arg_0.1, f32[2,2,2] %Arg_1.2), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1}, algorithm=dot_tf32_tf32_f32_x3, metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/attributes.mlir:199 offset " source_line=7} + ROOT %dot.3 = f32[2,2,2] dot(f32[2,2,2] %Arg_0.1, f32[2,2,2] %Arg_1.2), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1}, algorithm=dot_tf32_tf32_f32_x3, metadata={source_file="within split at third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo/tests/attributes.mlir:199 offset " source_line=7} } // ----- @@ -128,7 +128,7 @@ HloModule dot_algorithm_f32_f32_f32, entry_computation_layout={(f32[2,2,2]{2,1,0 ENTRY %main.4 (Arg_0.1: f32[2,2,2], Arg_1.2: f32[2,2,2]) -> f32[2,2,2] { %Arg_0.1 = f32[2,2,2] parameter(0) %Arg_1.2 = f32[2,2,2] parameter(1) - ROOT %dot.3 = f32[2,2,2] dot(f32[2,2,2] %Arg_0.1, f32[2,2,2] %Arg_1.2), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1}, algorithm=dot_f32_f32_f32, metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/attributes.mlir:221 offset " source_line=7} + ROOT %dot.3 = f32[2,2,2] dot(f32[2,2,2] %Arg_0.1, f32[2,2,2] %Arg_1.2), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1}, algorithm=dot_f32_f32_f32, metadata={source_file="within split at third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo/tests/attributes.mlir:221 offset " source_line=7} } // ----- @@ -140,5 +140,5 @@ HloModule dot_algorithm_f64_f64_f64, entry_computation_layout={(f64[2,2,2]{2,1,0 ENTRY %main.4 (Arg_0.1: f64[2,2,2], Arg_1.2: f64[2,2,2]) -> f64[2,2,2] { %Arg_0.1 = f64[2,2,2] parameter(0) %Arg_1.2 = f64[2,2,2] parameter(1) - ROOT %dot.3 = f64[2,2,2] dot(f64[2,2,2] %Arg_0.1, f64[2,2,2] %Arg_1.2), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1}, algorithm=dot_f64_f64_f64, metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/attributes.mlir:243 offset " source_line=7} + ROOT %dot.3 = f64[2,2,2] dot(f64[2,2,2] %Arg_0.1, f64[2,2,2] %Arg_1.2), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1}, algorithm=dot_f64_f64_f64, metadata={source_file="within split at third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo/tests/attributes.mlir:243 offset " source_line=7} } diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/tests/bool_compare.hlo b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/bool_compare.hlo similarity index 100% rename from third_party/xla/xla/translate/hlo_to_mhlo/tests/bool_compare.hlo rename to third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/bool_compare.hlo diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/tests/case_conditional.hlo b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/case_conditional.hlo similarity index 100% rename from third_party/xla/xla/translate/hlo_to_mhlo/tests/case_conditional.hlo rename to third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/case_conditional.hlo diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/tests/composite_call.hlo b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/composite_call.hlo similarity index 100% rename from third_party/xla/xla/translate/hlo_to_mhlo/tests/composite_call.hlo rename to third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/composite_call.hlo diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/tests/custom_call.hlo b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/custom_call.hlo similarity index 100% rename from third_party/xla/xla/translate/hlo_to_mhlo/tests/custom_call.hlo rename to third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/custom_call.hlo diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/tests/dynamic_param.hlo b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/dynamic_param.hlo similarity index 100% rename from third_party/xla/xla/translate/hlo_to_mhlo/tests/dynamic_param.hlo rename to third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/dynamic_param.hlo diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/tests/frontend_attributes.hlo b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/frontend_attributes.hlo similarity index 100% rename from third_party/xla/xla/translate/hlo_to_mhlo/tests/frontend_attributes.hlo rename to third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/frontend_attributes.hlo diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/tests/fully_connected_reference_model.hlo b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/fully_connected_reference_model.hlo similarity index 100% rename from third_party/xla/xla/translate/hlo_to_mhlo/tests/fully_connected_reference_model.hlo rename to third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/fully_connected_reference_model.hlo diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/tests/fusion.hlo b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/fusion.hlo similarity index 100% rename from third_party/xla/xla/translate/hlo_to_mhlo/tests/fusion.hlo rename to third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/fusion.hlo diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/tests/if_conditional.hlo b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/if_conditional.hlo similarity index 100% rename from third_party/xla/xla/translate/hlo_to_mhlo/tests/if_conditional.hlo rename to third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/if_conditional.hlo diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/tests/import.hlo b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/import.hlo similarity index 100% rename from third_party/xla/xla/translate/hlo_to_mhlo/tests/import.hlo rename to third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/import.hlo diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/tests/import_async.hlo b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/import_async.hlo similarity index 78% rename from third_party/xla/xla/translate/hlo_to_mhlo/tests/import_async.hlo rename to third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/import_async.hlo index 4e9633014b332b..5aa09777f30022 100644 --- a/third_party/xla/xla/translate/hlo_to_mhlo/tests/import_async.hlo +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/import_async.hlo @@ -41,8 +41,8 @@ HloModule main, entry_computation_layout={(f32[128,32]{1,0})->f32[128,128]{1,0}} ENTRY %async_all_gather_test (Arg_0.1: f32[128,32]) -> f32[128,128] { %Arg_0.1 = f32[128,32] parameter(0) - %all-gather-start.2 = f32[128,128] all-gather-start(f32[128,32] %Arg_0.1), channel_id=1, replica_groups={{0,2,4,6},{1,3,5,7}}, constrain_layout=true, dimensions={1}, use_global_device_ids=true, metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:1 offset " source_line=16} - ROOT %all-gather-done.3 = f32[128,128] all-gather-done(f32[128,128] %all-gather-start.2), metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:1 offset " source_line=17} + %all-gather-start.2 = f32[128,128] all-gather-start(f32[128,32] %Arg_0.1), channel_id=1, replica_groups={{0,2,4,6},{1,3,5,7}}, constrain_layout=true, dimensions={1}, use_global_device_ids=true, metadata={source_file="within split at third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo/tests/export_async.mlir:1 offset " source_line=16} + ROOT %all-gather-done.3 = f32[128,128] all-gather-done(f32[128,128] %all-gather-start.2), metadata={source_file="within split at third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo/tests/export_async.mlir:1 offset " source_line=17} } // ----- @@ -52,7 +52,7 @@ HloModule main, entry_computation_layout={(f32[10]{0})->f32[10]{0}} %region_1.2 (Arg_0.3: f32[], Arg_1.4: f32[]) -> f32[] { %Arg_0.3 = f32[] parameter(0) %Arg_1.4 = f32[] parameter(1) - ROOT %maximum.5 = f32[] maximum(f32[] %Arg_0.3, f32[] %Arg_1.4), metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:30 offset " source_line=7} + ROOT %maximum.5 = f32[] maximum(f32[] %Arg_0.3, f32[] %Arg_1.4), metadata={source_file="within split at third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo/tests/export_async.mlir:30 offset " source_line=7} } // CHECK-LABEL: func.func private @all_reduce_ @@ -63,8 +63,8 @@ HloModule main, entry_computation_layout={(f32[10]{0})->f32[10]{0}} // CHECK: mhlo.async_done ENTRY %async_all_reduce_test (Arg_0.1: f32[10]) -> f32[10] { %Arg_0.1 = f32[10] parameter(0) - %all-reduce-start.6 = f32[10] all-reduce-start(f32[10] %Arg_0.1), channel_id=5, replica_groups={{0,2,4,6},{1,3,5,7}}, use_global_device_ids=true, to_apply=%region_1.2, metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:30 offset " source_line=22} - ROOT %all-reduce-done.7 = f32[10] all-reduce-done(f32[10] %all-reduce-start.6), metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:30 offset " source_line=23} + %all-reduce-start.6 = f32[10] all-reduce-start(f32[10] %Arg_0.1), channel_id=5, replica_groups={{0,2,4,6},{1,3,5,7}}, use_global_device_ids=true, to_apply=%region_1.2, metadata={source_file="within split at third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo/tests/export_async.mlir:30 offset " source_line=22} + ROOT %all-reduce-done.7 = f32[10] all-reduce-done(f32[10] %all-reduce-start.6), metadata={source_file="within split at third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo/tests/export_async.mlir:30 offset " source_line=23} } // ----- @@ -79,8 +79,8 @@ HloModule main, entry_computation_layout={(f32[128,32]{1,0})->f32[128,32]{1,0}} // CHECK: mhlo.async_done ENTRY %async_collective_permute_test (Arg_0.1: f32[128,32]) -> f32[128,32] { %Arg_0.1 = f32[128,32] parameter(0) - %collective-permute-start.2 = f32[128,32] collective-permute-start(f32[128,32] %Arg_0.1), channel_id=1, source_target_pairs={{0,1},{1,2},{2,3}}, metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:109 offset " source_line=13} - ROOT %collective-permute-done.3 = f32[128,32] collective-permute-done(f32[128,32] %collective-permute-start.2), metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:109 offset " source_line=14} + %collective-permute-start.2 = f32[128,32] collective-permute-start(f32[128,32] %Arg_0.1), channel_id=1, source_target_pairs={{0,1},{1,2},{2,3}}, metadata={source_file="within split at third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo/tests/export_async.mlir:109 offset " source_line=13} + ROOT %collective-permute-done.3 = f32[128,32] collective-permute-done(f32[128,32] %collective-permute-start.2), metadata={source_file="within split at third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo/tests/export_async.mlir:109 offset " source_line=14} } // ----- @@ -89,8 +89,8 @@ HloModule main, entry_computation_layout={(f32[128,32]{1,0})->f32[128,32]{1,0}} ENTRY %async_copy_test (Arg_0.1: f32[128,32]) -> f32[128,32] { %Arg_0.1 = f32[128,32] parameter(0) - %copy-start.2 = (f32[128,32], f32[128,32], u32[]) copy-start(f32[128,32] %Arg_0.1), cross_program_prefetch_index=0, metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:133 offset " source_line=10} - ROOT %copy-done.3 = f32[128,32] copy-done((f32[128,32], f32[128,32], u32[]) %copy-start.2), metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:133 offset " source_line=11} + %copy-start.2 = (f32[128,32], f32[128,32], u32[]) copy-start(f32[128,32] %Arg_0.1), cross_program_prefetch_index=0, metadata={source_file="within split at third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo/tests/export_async.mlir:133 offset " source_line=10} + ROOT %copy-done.3 = f32[128,32] copy-done((f32[128,32], f32[128,32], u32[]) %copy-start.2), metadata={source_file="within split at third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo/tests/export_async.mlir:133 offset " source_line=11} } // ----- @@ -99,10 +99,10 @@ HloModule main, entry_computation_layout={(token[])->(s32[3,4]{1,0}, token[])} ENTRY %async_recv_test_tuple (Arg_0.1: token[]) -> (s32[3,4], token[]) { %Arg_0.1 = token[] parameter(0) - %recv.2 = (s32[3,4], u32[], token[]) recv(token[] %Arg_0.1), channel_id=5, is_host_transfer=true, sharding={{maximal device=0}, {maximal device=0}, {maximal device=0}}, metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:179 offset " source_line=16} - %recv-done.3 = (s32[3,4], token[]) recv-done((s32[3,4], u32[], token[]) %recv.2), channel_id=5, is_host_transfer=true, sharding={{maximal device=0}, {maximal device=0}}, metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:179 offset " source_line=17} - %get-tuple-element.4 = s32[3,4] get-tuple-element((s32[3,4], token[]) %recv-done.3), index=0, sharding={maximal device=0}, metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:179 offset " source_line=17} - %get-tuple-element.5 = token[] get-tuple-element((s32[3,4], token[]) %recv-done.3), index=1, sharding={maximal device=0}, metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:179 offset " source_line=17} + %recv.2 = (s32[3,4], u32[], token[]) recv(token[] %Arg_0.1), channel_id=5, is_host_transfer=true, sharding={{maximal device=0}, {maximal device=0}, {maximal device=0}}, metadata={source_file="within split at third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo/tests/export_async.mlir:179 offset " source_line=16} + %recv-done.3 = (s32[3,4], token[]) recv-done((s32[3,4], u32[], token[]) %recv.2), channel_id=5, is_host_transfer=true, sharding={{maximal device=0}, {maximal device=0}}, metadata={source_file="within split at third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo/tests/export_async.mlir:179 offset " source_line=17} + %get-tuple-element.4 = s32[3,4] get-tuple-element((s32[3,4], token[]) %recv-done.3), index=0, sharding={maximal device=0}, metadata={source_file="within split at third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo/tests/export_async.mlir:179 offset " source_line=17} + %get-tuple-element.5 = token[] get-tuple-element((s32[3,4], token[]) %recv-done.3), index=1, sharding={maximal device=0}, metadata={source_file="within split at third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo/tests/export_async.mlir:179 offset " source_line=17} ROOT %tuple.6 = (s32[3,4], token[]) tuple(s32[3,4] %get-tuple-element.4, token[] %get-tuple-element.5) } @@ -113,8 +113,8 @@ HloModule main, entry_computation_layout={(s32[3,4]{1,0}, token[])->token[]} ENTRY %async_send_test (Arg_0.1: s32[3,4], Arg_1.2: token[]) -> token[] { %Arg_0.1 = s32[3,4] parameter(0) %Arg_1.2 = token[] parameter(1) - %send.3 = (s32[3,4], u32[], token[]) send(s32[3,4] %Arg_0.1, token[] %Arg_1.2), channel_id=5, is_host_transfer=true, metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:213 offset " source_line=16} - ROOT %send-done.4 = token[] send-done((s32[3,4], u32[], token[]) %send.3), channel_id=5, is_host_transfer=true, metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:213 offset " source_line=17} + %send.3 = (s32[3,4], u32[], token[]) send(s32[3,4] %Arg_0.1, token[] %Arg_1.2), channel_id=5, is_host_transfer=true, metadata={source_file="within split at third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo/tests/export_async.mlir:213 offset " source_line=16} + ROOT %send-done.4 = token[] send-done((s32[3,4], u32[], token[]) %send.3), channel_id=5, is_host_transfer=true, metadata={source_file="within split at third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo/tests/export_async.mlir:213 offset " source_line=17} } @@ -124,18 +124,18 @@ ENTRY %async_send_test (Arg_0.1: s32[3,4], Arg_1.2: token[]) -> token[] { // ENTRY %async_custom_call_test2 (Arg_0.1: f32[10]) -> (f32[20]) { // %Arg_0.1 = f32[10] parameter(0) -// %async-start.5 = ((f32[10]), f32[20], s32[]) custom-call-start(f32[10] %Arg_0.1), async_execution_thread="thread", custom_call_target="bar", metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:288 offset " source_line=21} -// %async-update.6 = ((f32[10]), f32[20], s32[]) custom-call-update(((f32[10]), f32[20], s32[]) %async-start.5), metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:288 offset " source_line=22} -// ROOT %async-done.7 = (f32[20]) custom-call-done(((f32[10]), f32[20], s32[]) %async-update.6), metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:288 offset " source_line=23} +// %async-start.5 = ((f32[10]), f32[20], s32[]) custom-call-start(f32[10] %Arg_0.1), async_execution_thread="thread", custom_call_target="bar", metadata={source_file="within split at third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo/tests/export_async.mlir:288 offset " source_line=21} +// %async-update.6 = ((f32[10]), f32[20], s32[]) custom-call-update(((f32[10]), f32[20], s32[]) %async-start.5), metadata={source_file="within split at third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo/tests/export_async.mlir:288 offset " source_line=22} +// ROOT %async-done.7 = (f32[20]) custom-call-done(((f32[10]), f32[20], s32[]) %async-update.6), metadata={source_file="within split at third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo/tests/export_async.mlir:288 offset " source_line=23} // } // HloModule main, entry_computation_layout={(f32[10]{0})->(f32[20]{0})} // ENTRY %async_custom_call_test (Arg_0.1: f32[10]) -> (f32[20]) { // %Arg_0.1 = f32[10] parameter(0) -// %async-start.5 = ((f32[10]), f32[20], s32[]) custom-call-start(f32[10] %Arg_0.1), async_execution_thread="thread", custom_call_target="foo", metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:265 offset " source_line=16} -// %async-update.6 = ((f32[10]), f32[20], s32[]) custom-call-update(((f32[10]), f32[20], s32[]) %async-start.5), metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:265 offset " source_line=18} -// ROOT %async-done.7 = (f32[20]) custom-call-done(((f32[10]), f32[20], s32[]) %async-update.6), metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:265 offset " source_line=20} +// %async-start.5 = ((f32[10]), f32[20], s32[]) custom-call-start(f32[10] %Arg_0.1), async_execution_thread="thread", custom_call_target="foo", metadata={source_file="within split at third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo/tests/export_async.mlir:265 offset " source_line=16} +// %async-update.6 = ((f32[10]), f32[20], s32[]) custom-call-update(((f32[10]), f32[20], s32[]) %async-start.5), metadata={source_file="within split at third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo/tests/export_async.mlir:265 offset " source_line=18} +// ROOT %async-done.7 = (f32[20]) custom-call-done(((f32[10]), f32[20], s32[]) %async-update.6), metadata={source_file="within split at third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo/tests/export_async.mlir:265 offset " source_line=20} // } @@ -146,17 +146,17 @@ ENTRY %async_send_test (Arg_0.1: s32[3,4], Arg_1.2: token[]) -> token[] { // HloModule main, entry_computation_layout={(token[])->token[]} // ENTRY %async_send_test_empty (Arg_0.1: token[]) -> token[] { -// %tuple.2 = () tuple(), metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:240 offset " source_line=15} +// %tuple.2 = () tuple(), metadata={source_file="within split at third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo/tests/export_async.mlir:240 offset " source_line=15} // %Arg_0.1 = token[] parameter(0) -// %send.3 = ((), u32[], token[]) send(() %tuple.2, token[] %Arg_0.1), channel_id=5, metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:240 offset " source_line=15} -// ROOT %send-done.4 = token[] send-done(((), u32[], token[]) %send.3), channel_id=5, metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:240 offset " source_line=16} +// %send.3 = ((), u32[], token[]) send(() %tuple.2, token[] %Arg_0.1), channel_id=5, metadata={source_file="within split at third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo/tests/export_async.mlir:240 offset " source_line=15} +// ROOT %send-done.4 = token[] send-done(((), u32[], token[]) %send.3), channel_id=5, metadata={source_file="within split at third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo/tests/export_async.mlir:240 offset " source_line=16} // } // HloModule main, entry_computation_layout={(token[])->((), token[])} // ENTRY %async_recv_test (Arg_0.1: token[]) -> ((), token[]) { // %Arg_0.1 = token[] parameter(0) -// %recv.2 = ((), u32[], token[]) recv(token[] %Arg_0.1), channel_id=5, metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:153 offset " source_line=17} -// ROOT %recv-done.3 = ((), token[]) recv-done(((), u32[], token[]) %recv.2), channel_id=5, metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:153 offset " source_line=18} +// %recv.2 = ((), u32[], token[]) recv(token[] %Arg_0.1), channel_id=5, metadata={source_file="within split at third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo/tests/export_async.mlir:153 offset " source_line=17} +// ROOT %recv-done.3 = ((), token[]) recv-done(((), u32[], token[]) %recv.2), channel_id=5, metadata={source_file="within split at third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo/tests/export_async.mlir:153 offset " source_line=18} // } diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/tests/import_async2.hlo b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/import_async2.hlo similarity index 100% rename from third_party/xla/xla/translate/hlo_to_mhlo/tests/import_async2.hlo rename to third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/import_async2.hlo diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/tests/import_entry_computation_layout.hlo b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/import_entry_computation_layout.hlo similarity index 100% rename from third_party/xla/xla/translate/hlo_to_mhlo/tests/import_entry_computation_layout.hlo rename to third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/import_entry_computation_layout.hlo diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/tests/layouts_and_names.hlo b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/layouts_and_names.hlo similarity index 100% rename from third_party/xla/xla/translate/hlo_to_mhlo/tests/layouts_and_names.hlo rename to third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/layouts_and_names.hlo diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/tests/location.hlo b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/location.hlo similarity index 100% rename from third_party/xla/xla/translate/hlo_to_mhlo/tests/location.hlo rename to third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/location.hlo diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/tests/module_attributes.hlo b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/module_attributes.hlo similarity index 100% rename from third_party/xla/xla/translate/hlo_to_mhlo/tests/module_attributes.hlo rename to third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/module_attributes.hlo diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/tests/module_config.hlo b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/module_config.hlo similarity index 100% rename from third_party/xla/xla/translate/hlo_to_mhlo/tests/module_config.hlo rename to third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/module_config.hlo diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/tests/simple.hlo b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/simple.hlo similarity index 100% rename from third_party/xla/xla/translate/hlo_to_mhlo/tests/simple.hlo rename to third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/simple.hlo diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/tests/spmd_module_sharding.hlo b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/spmd_module_sharding.hlo similarity index 100% rename from third_party/xla/xla/translate/hlo_to_mhlo/tests/spmd_module_sharding.hlo rename to third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/spmd_module_sharding.hlo diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/tests/stacktrace_to_location.hlo b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/stacktrace_to_location.hlo similarity index 100% rename from third_party/xla/xla/translate/hlo_to_mhlo/tests/stacktrace_to_location.hlo rename to third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/stacktrace_to_location.hlo diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/tests/types.hlo b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/types.hlo similarity index 100% rename from third_party/xla/xla/translate/hlo_to_mhlo/tests/types.hlo rename to third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/types.hlo diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/tests/while.hlo b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/while.hlo similarity index 100% rename from third_party/xla/xla/translate/hlo_to_mhlo/tests/while.hlo rename to third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/while.hlo diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/translate.cc b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/translate.cc similarity index 97% rename from third_party/xla/xla/translate/hlo_to_mhlo/translate.cc rename to third_party/xla/xla/hlo/translate/hlo_to_mhlo/translate.cc index 362e0e19ad8795..ec89c29cafa0ad 100644 --- a/third_party/xla/xla/translate/hlo_to_mhlo/translate.cc +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/translate.cc @@ -12,19 +12,19 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/translate/hlo_to_mhlo/translate.h" +#include "xla/hlo/translate/hlo_to_mhlo/translate.h" #include "absl/status/status.h" #include "llvm/Support/LogicalResult.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Location.h" #include "mlir/Pass/PassManager.h" +#include "xla/hlo/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/mlir_hlo/mhlo/transforms/passes.h" #include "xla/service/hlo.pb.h" #include "xla/service/hlo_parser.h" #include "xla/service/llvm_ir/llvm_util.h" -#include "xla/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h" #include "tsl/platform/protobuf.h" namespace xla { diff --git a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/translate.h b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/translate.h new file mode 100644 index 00000000000000..c07fd6485ccd16 --- /dev/null +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/translate.h @@ -0,0 +1,88 @@ +/* Copyright 2019 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_TRANSLATE_HLO_TO_MHLO_TRANSLATE_H_ +#define XLA_HLO_TRANSLATE_HLO_TO_MHLO_TRANSLATE_H_ + +namespace llvm { +class StringRef; +} // namespace llvm + +namespace mlir { +class MLIRContext; +class ModuleOp; +template +class OwningOpRef; +} // namespace mlir + +namespace xla { + +// Converts a HloModuleProto stored in the file with the given `input_filename` +// into a MHLO module. Creates MLIR entities into the given MLIR `context`. +// +// If `import_all_computation` is set to true, imports all computations +// irrespective if transitively called from entry computation. +// +// If `flatten_computation_args_result` is set to true, flattens all tuple +// arguments and result of every computation when importing them as func ops. +mlir::OwningOpRef HloToMlirHloTranslateFunction( + llvm::StringRef input, mlir::MLIRContext* context, + bool import_all_computations = false, + bool flatten_computation_args_result = false); + +// Converts a HloModule stored in text form for a file with the given +// `input_filename` into a MHLO module. Creates MLIR entities into the given +// MLIR `context`. +// +// If `import_all_computation` is set to true, imports all computations +// irrespective if transitively called from entry computation. +// +// If `flatten_computation_args_result` is set to true, flattens all tuple +// arguments and result of every computation when importing them as func ops. +mlir::OwningOpRef HloTextToMlirHloTranslateFunction( + llvm::StringRef input, mlir::MLIRContext* context, + bool import_all_computations = false, + bool flatten_computation_args_result = false); + +// Converts a HloModuleProto stored in the file with the given `input_filename` +// into a StableHLO module. Creates MLIR entities into the given MLIR `context`. +// +// If `import_all_computation` is set to true, imports all computations +// irrespective if transitively called from entry computation. +// +// If `flatten_computation_args_result` is set to true, flattens all tuple +// arguments and result of every computation when importing them as func ops. +mlir::OwningOpRef HloToStablehloTranslateFunction( + llvm::StringRef input, mlir::MLIRContext* context, + bool import_all_computations = false, + bool flatten_computation_args_result = false); + +// Converts a HloModule stored in text form for a file with the given +// `input_filename` into a StableHLO module. Creates MLIR entities into the +// given MLIR `context`. +// +// If `import_all_computation` is set to true, imports all computations +// irrespective if transitively called from entry computation. +// +// If `flatten_computation_args_result` is set to true, flattens all tuple +// arguments and result of every computation when importing them as func ops. +mlir::OwningOpRef HloTextToStablehloTranslateFunction( + llvm::StringRef input, mlir::MLIRContext* context, + bool import_all_computations = false, + bool flatten_computation_args_result = false); + +} // namespace xla + +#endif // XLA_HLO_TRANSLATE_HLO_TO_MHLO_TRANSLATE_H_ diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/translate_registration.cc b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/translate_registration.cc similarity index 98% rename from third_party/xla/xla/translate/hlo_to_mhlo/translate_registration.cc rename to third_party/xla/xla/hlo/translate/hlo_to_mhlo/translate_registration.cc index b3d8f2f97b414b..87aef2b743519a 100644 --- a/third_party/xla/xla/translate/hlo_to_mhlo/translate_registration.cc +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/translate_registration.cc @@ -15,7 +15,7 @@ limitations under the License. #include "mlir/IR/BuiltinOps.h" #include "mlir/Tools/mlir-translate/Translation.h" -#include "xla/translate/hlo_to_mhlo/translate.h" +#include "xla/hlo/translate/hlo_to_mhlo/translate.h" namespace { // NOLINTNEXTLINE diff --git a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/BUILD b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/BUILD new file mode 100644 index 00000000000000..c529ee12e035cd --- /dev/null +++ b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/BUILD @@ -0,0 +1,296 @@ +load("@bazel_skylib//rules:build_test.bzl", "build_test") +load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") +load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_binary", "cc_library") +load("//xla:xla.bzl", "xla_cc_test") +load("//xla/tsl:tsl.bzl", "internal_visibility") +load("//xla/tsl:tsl.default.bzl", "get_compatible_with_portable") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = internal_visibility([ + "//learning/brain/mlir:tensorflow_friends", + "//learning/brain/mlir:xla_friends", + ]), + licenses = ["notice"], +) + +cc_library( + name = "attribute_exporter", + srcs = ["attribute_exporter.cc"], + hdrs = ["attribute_exporter.h"], + deps = [ + "//xla:shape_util", + "//xla:types", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/mlir_hlo", + "//xla/service:hlo_parser", + "//xla/service:hlo_proto_cc", + "//xla/stream_executor:dnn", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/types:span", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + "@stablehlo//:base", + ], +) + +cc_library( + name = "layout_util", + srcs = ["layout_util.cc"], + hdrs = ["layout_util.h"], + deps = [ + "//xla:shape_util", + "//xla:xla_data_proto_cc", + "//xla/client:xla_builder", + "//xla/hlo/ir:hlo", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], +) + +cc_library( + name = "location_exporter", + srcs = ["location_exporter.cc"], + hdrs = ["location_exporter.h"], + deps = [ + ":stack_frame_index_builder", + "//xla:xla_data_proto_cc", + "@com_google_absl//absl/log", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + ], +) + +cc_library( + name = "module_attributes_exporter", + srcs = ["module_attributes_exporter.cc"], + hdrs = ["module_attributes_exporter.h"], + deps = [ + "//xla:xla_data_proto_cc", + "//xla/service:hlo_module_config", + "//xla/service:hlo_proto_cc", + "@com_google_absl//absl/status", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + ], +) + +cc_library( + name = "stack_frame_index_builder", + srcs = ["stack_frame_index_builder.cc"], + hdrs = ["stack_frame_index_builder.h"], + deps = [ + "//xla/service:hlo_proto_cc", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + ], +) + +cc_library( + name = "mlir_hlo_to_hlo", + srcs = [ + "mlir_hlo_to_hlo.cc", + "operator_writers.inc", + ], + hdrs = ["mlir_hlo_to_hlo.h"], + deps = [ + ":attribute_exporter", + ":layout_util", + ":location_exporter", + ":module_attributes_exporter", + ":operator_writer_inc", + ":stack_frame_index_builder", + ":type_to_shape", + "//xla:array", + "//xla:comparison_util", + "//xla:debug_options_flags", + "//xla:literal", + "//xla:literal_util", + "//xla:shape_util", + "//xla:status_macros", + "//xla:types", + "//xla:xla_data_proto_cc", + "//xla/client:xla_builder", + "//xla/client:xla_computation", + "//xla/client/lib:approx_topk", + "//xla/client/lib:approx_topk_shape", + "//xla/client/lib:matrix", + "//xla/client/lib:quantize", + "//xla/client/lib:slicing", + "//xla/hlo/ir:hlo", + "//xla/mlir/utils:error_util", + "//xla/mlir/utils:type_util", + "//xla/mlir_hlo", + "//xla/mlir_hlo:mhlo_passes", + "//xla/service:computation_layout", + "//xla/service:hlo_module_config", + "//xla/service:hlo_parser", + "//xla/service:hlo_proto_cc", + "//xla/service/gpu:backend_configs_cc", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:MemRefDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:ShapeDialect", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:TransformUtils", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:ml_dtypes", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:types", + "@stablehlo//:base", + "@stablehlo//:stablehlo_ops", + ], +) + +build_test( + name = "operator_writer_gen_build_test", + targets = [ + ":operator_writer_gen", + ], +) + +cc_binary( + name = "operator_writer_gen", + srcs = ["operator_writer_gen.cc"], + deps = [ + "@llvm-project//llvm:Support", + "@llvm-project//llvm:TableGen", + "@llvm-project//mlir:TableGen", + ], +) + +gentbl_cc_library( + name = "operator_writer_inc", + compatible_with = get_compatible_with_portable(), + tbl_outs = [([], "operator_writers.inc")], + tblgen = ":operator_writer_gen", + td_file = "//xla/mlir_hlo:mhlo/IR/hlo_ops.td", + deps = [ + "//xla/mlir_hlo:hlo_ops_td_files", + "@llvm-project//mlir:InferTypeOpInterfaceTdFiles", + "@llvm-project//mlir:OpBaseTdFiles", + "@llvm-project//mlir:SideEffectInterfacesTdFiles", + ], +) + +xla_cc_test( + name = "mlir_hlo_to_hlo_test", + srcs = ["mlir_hlo_to_hlo_test.cc"], + deps = [ + ":mlir_hlo_to_hlo", + "//xla/mlir/utils:error_util", + "//xla/tsl/lib/core:status_test_util", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Parser", + "@llvm-project//mlir:ShapeDialect", + "@local_tsl//tsl/platform:status_matchers", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", + "@stablehlo//:register", + ], +) + +cc_library( + name = "translate", + srcs = ["translate.cc"], + hdrs = ["translate.h"], + deps = [ + ":mlir_hlo_to_hlo", + ":type_to_shape", + "//xla:debug_options_flags", + "//xla:shape_util", + "//xla/client:xla_builder", + "//xla/client:xla_computation", + "//xla/hlo/ir:hlo", + "//xla/mlir_hlo:hlo_dialect_registration", + "//xla/service:hlo_module_config", + "//xla/service:hlo_proto_cc", + "//xla/service:hlo_proto_util", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Parser", + "@llvm-project//mlir:Support", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], +) + +cc_library( + name = "translate_registration", + testonly = True, + srcs = [ + "translate_registration.cc", + "translate_registration.h", + ], + deps = [ + ":translate", + "//xla/mlir_hlo:hlo_dialect_registration", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:TranslateLib", + ], + alwayslink = 1, +) + +cc_library( + name = "type_to_shape", + srcs = ["type_to_shape.cc"], + hdrs = ["type_to_shape.h"], + deps = [ + "//xla:shape_util", + "//xla:xla_data_proto_cc", + "//xla/mlir/utils:type_util", + "//xla/mlir_hlo", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:SparseTensorDialect", + "@llvm-project//mlir:SparseTensorEnums", + "@llvm-project//mlir:Support", + "@stablehlo//:stablehlo_ops", + ], +) + +xla_cc_test( + name = "type_to_shape_test", + srcs = ["type_to_shape_test.cc"], + deps = [ + ":type_to_shape", + "//xla:shape_util", + "//xla:test", + "//xla:xla_data_proto_cc", + "//xla/hlo/translate/hlo_to_mhlo:hlo_utils", + "//xla/mlir_hlo", + "@com_google_absl//absl/status:statusor", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@local_tsl//tsl/platform:protobuf", + "@local_tsl//tsl/platform:test_main", + ], +) diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/attribute_exporter.cc b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/attribute_exporter.cc similarity index 99% rename from third_party/xla/xla/translate/mhlo_to_hlo/attribute_exporter.cc rename to third_party/xla/xla/hlo/translate/mhlo_to_hlo/attribute_exporter.cc index ae8537283d9148..cd1f3f36d58a6e 100644 --- a/third_party/xla/xla/translate/mhlo_to_hlo/attribute_exporter.cc +++ b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/attribute_exporter.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/translate/mhlo_to_hlo/attribute_exporter.h" +#include "xla/hlo/translate/mhlo_to_hlo/attribute_exporter.h" #include diff --git a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/attribute_exporter.h b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/attribute_exporter.h new file mode 100644 index 00000000000000..bc8344ce11b01d --- /dev/null +++ b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/attribute_exporter.h @@ -0,0 +1,75 @@ +/* Copyright 2020 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_TRANSLATE_MHLO_TO_HLO_ATTRIBUTE_EXPORTER_H_ +#define XLA_HLO_TRANSLATE_MHLO_TO_HLO_ATTRIBUTE_EXPORTER_H_ + +#include + +#include "absl/status/statusor.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/StringRef.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/Support/LLVM.h" +#include "xla/hlo/ir/hlo_input_output_alias_config.h" +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" +#include "xla/service/hlo.pb.h" +#include "xla/shape_util.h" +#include "xla/stream_executor/dnn.h" +#include "xla/types.h" +#include "xla/xla_data.pb.h" + +namespace xla { + +// Converts the conv dimensions attribute to XLA HLO. +ConvolutionDimensionNumbers ConvertConvDimensionNumbers( + mlir::mhlo::ConvDimensionNumbersAttr input); + +// Converts the dot algorithm attribute to XLA HLO. +absl::StatusOr ConvertDotAlgorithm( + mlir::mhlo::DotAlgorithmAttr attr); + +absl::StatusOr> ConvertReplicaGroups( + mlir::DenseIntElementsAttr input); + +// Convert a (N, 2) dense attribute to a list of tuples. This is the way padding +// and source-target pairs are defined in HLO. +absl::StatusOr>> ConvertNx2Attribute( + std::optional optional_attr); + +absl::StatusOr ConvertTranspose( + llvm::StringRef transpose_string); + +absl::StatusOr ConvertCustomCallSchedule( + mlir::mhlo::CustomCallSchedule schedule); + +absl::StatusOr ConvertCustomCallApiVersion( + mlir::mhlo::CustomCallApiVersion api_version); + +absl::StatusOr< + std::vector>>> +ConvertOutputOperandAliasing(mlir::ArrayAttr aliasArrayAttr); + +// Returns an OpSharding that represents the result of parsing the given string: +// first, as serialized protobuf, and then as prettyprinted representation. +// Will fail if both attempts at parsing failed. +std::optional ConvertSharding(mlir::StringRef sharding); + +std::optional ConvertInputOutputAlias( + llvm::ArrayRef aliasing); + +} // namespace xla +#endif // XLA_HLO_TRANSLATE_MHLO_TO_HLO_ATTRIBUTE_EXPORTER_H_ diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/layout_util.cc b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/layout_util.cc similarity index 98% rename from third_party/xla/xla/translate/mhlo_to_hlo/layout_util.cc rename to third_party/xla/xla/hlo/translate/mhlo_to_hlo/layout_util.cc index a07bba4004b59f..ef76ba38c9e8ba 100644 --- a/third_party/xla/xla/translate/mhlo_to_hlo/layout_util.cc +++ b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/layout_util.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/translate/mhlo_to_hlo/layout_util.h" +#include "xla/hlo/translate/mhlo_to_hlo/layout_util.h" #include "absl/status/status.h" #include "absl/status/statusor.h" diff --git a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/layout_util.h b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/layout_util.h new file mode 100644 index 00000000000000..ca432b0f0e16eb --- /dev/null +++ b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/layout_util.h @@ -0,0 +1,85 @@ +/* Copyright 2017 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Utilities for working with XLA layout and shapes. + +#ifndef XLA_HLO_TRANSLATE_MHLO_TO_HLO_LAYOUT_UTIL_H_ +#define XLA_HLO_TRANSLATE_MHLO_TO_HLO_LAYOUT_UTIL_H_ + +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "xla/client/xla_builder.h" +#include "xla/hlo/ir/hlo_sharding.h" +#include "xla/shape.h" +#include "xla/xla_data.pb.h" + +namespace mlir { + +// XLA Layout preferences. Currently, when it comes to TPU, there are two +// primary layout choices for any XLA arguments (parameter or resource): (1) +// CompactChunkPadded and (2) Linear. CompactChunkPadded is the native TPU +// layout while Linear is native host (CPU) layout. +// This enum allows the caller of XLA to propagate layout preference to the XLA +// compiler. +// kNoPreference: the generic layout where the XLA compiler has the freedom +// to assign any layout. +// kTpuPreferCompactChunkPaddedLayout: use native TPU layout on TPU. +// kTpuPreferLinearLayout: use native CPU layout on TPU. The compiler may +// insert transformation TPU kernels. +// As the layout of any argument will change from a native host layout to a +// native TPU layout either on host or on device, XLA compiler and TPU runtime +// must be in coordination to transform the parameters in a consistent way. +enum class XlaLayoutPreference { + kNoPreference = 0, + kTpuPreferCompactChunkPaddedLayout = 1, + kTpuPreferLinearLayout = 2 +}; + +// The following defines the layout preference of an xla tensor. +// The return value of LayoutPreferenceFn can be used in +// ShapeRepresentationFn. +typedef std::function( + const xla::Shape& shape)> + LayoutPreferenceFn; + +typedef std::function( + const xla::Shape& shape, bool fast_mem, + XlaLayoutPreference layout_preference)> + ShapeRepresentationFn; + +// Return a LayoutPreferenceFn that always uses kNoPreference layout. +LayoutPreferenceFn UseNoPreferenceLayoutFn(); + +// Rewrites the layout of xla_shape if there is tiled sharding. +absl::Status RewriteLayoutWithShardedShape( + const std::optional& sharding, bool use_fast_memory, + const LayoutPreferenceFn& layout_preference_fn, + const ShapeRepresentationFn& shape_representation_fn, + xla::Shape* xla_shape); + +// Adds reshapes to fix the layout of an output, if a shape_representation_fn or +// sharding is present. +absl::StatusOr ReshapeWithCorrectRepresentationAndSharding( + xla::XlaBuilder* builder, xla::XlaOp original, xla::Shape original_shape, + const LayoutPreferenceFn& layout_preference_fn, + const ShapeRepresentationFn& shape_representation_fn, + std::optional sharding, bool fast_mem); + +} // namespace mlir + +#endif // XLA_HLO_TRANSLATE_MHLO_TO_HLO_LAYOUT_UTIL_H_ diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/location_exporter.cc b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/location_exporter.cc similarity index 97% rename from third_party/xla/xla/translate/mhlo_to_hlo/location_exporter.cc rename to third_party/xla/xla/hlo/translate/mhlo_to_hlo/location_exporter.cc index bb274f6bb99a51..c7f80898a21321 100644 --- a/third_party/xla/xla/translate/mhlo_to_hlo/location_exporter.cc +++ b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/location_exporter.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/translate/mhlo_to_hlo/location_exporter.h" +#include "xla/hlo/translate/mhlo_to_hlo/location_exporter.h" #include @@ -26,7 +26,7 @@ limitations under the License. #include "mlir/IR/Operation.h" #include "mlir/IR/Visitors.h" #include "mlir/Support/LLVM.h" -#include "xla/translate/mhlo_to_hlo/stack_frame_index_builder.h" +#include "xla/hlo/translate/mhlo_to_hlo/stack_frame_index_builder.h" #include "xla/xla_data.pb.h" namespace mlir { diff --git a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/location_exporter.h b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/location_exporter.h new file mode 100644 index 00000000000000..70ab1d6395076a --- /dev/null +++ b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/location_exporter.h @@ -0,0 +1,44 @@ +/* Copyright 2022 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_TRANSLATE_MHLO_TO_HLO_LOCATION_EXPORTER_H_ +#define XLA_HLO_TRANSLATE_MHLO_TO_HLO_LOCATION_EXPORTER_H_ + +#include + +#include "mlir/IR/Location.h" +#include "mlir/IR/Operation.h" +#include "xla/hlo/translate/mhlo_to_hlo/stack_frame_index_builder.h" +#include "xla/xla_data.pb.h" + +namespace mlir { +namespace mhlo { + +// Returns a OpMetadata proto based on the location of the op. If the location +// is unknown, an empty proto is returned. `op_name` are populated with the op +// location (converted). FileLineColLoc locations are populated by taking the +// file name and line number, and populating `source_file` and `source_line` +// respectively. +xla::OpMetadata CreateOpMetadataFromLocation( + Operation* op, StackFrameIndexBuilder* frame_index_builder); + +// Returns a name that can be used for debugging purposes, e.g., naming +// variable names in generated IR or producing logging output. +std::string GetDebugNameFromLocation(Location location); + +} // namespace mhlo +} // namespace mlir + +#endif // XLA_HLO_TRANSLATE_MHLO_TO_HLO_LOCATION_EXPORTER_H_ diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc similarity index 99% rename from third_party/xla/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc rename to third_party/xla/xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc index db21cfb6095d18..9b201f946ccaa5 100644 --- a/third_party/xla/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc +++ b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h" +#include "xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h" #include #include @@ -76,6 +76,12 @@ limitations under the License. #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_sharding.h" +#include "xla/hlo/translate/mhlo_to_hlo/attribute_exporter.h" +#include "xla/hlo/translate/mhlo_to_hlo/layout_util.h" +#include "xla/hlo/translate/mhlo_to_hlo/location_exporter.h" +#include "xla/hlo/translate/mhlo_to_hlo/module_attributes_exporter.h" +#include "xla/hlo/translate/mhlo_to_hlo/stack_frame_index_builder.h" +#include "xla/hlo/translate/mhlo_to_hlo/type_to_shape.h" #include "xla/layout.h" #include "xla/layout_util.h" #include "xla/literal.h" @@ -91,12 +97,6 @@ limitations under the License. #include "xla/service/hlo_parser.h" #include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/translate/mhlo_to_hlo/attribute_exporter.h" -#include "xla/translate/mhlo_to_hlo/layout_util.h" -#include "xla/translate/mhlo_to_hlo/location_exporter.h" -#include "xla/translate/mhlo_to_hlo/module_attributes_exporter.h" -#include "xla/translate/mhlo_to_hlo/stack_frame_index_builder.h" -#include "xla/translate/mhlo_to_hlo/type_to_shape.h" #include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" @@ -3020,7 +3020,7 @@ LogicalResult ExportXlaOp(MinimumBroadcastShapesOp op, OpLoweringContext ctx) { } // namespace mhlo } // namespace mlir -#include "xla/translate/mhlo_to_hlo/operator_writers.inc" +#include "xla/hlo/translate/mhlo_to_hlo/operator_writers.inc" namespace mlir { namespace { diff --git a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h new file mode 100644 index 00000000000000..7de22766fcfb0c --- /dev/null +++ b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h @@ -0,0 +1,96 @@ +/* Copyright 2019 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_TRANSLATE_MHLO_TO_HLO_MLIR_HLO_TO_HLO_H_ +#define XLA_HLO_TRANSLATE_MHLO_TO_HLO_MLIR_HLO_TO_HLO_H_ + +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "llvm/ADT/ArrayRef.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/Block.h" +#include "mlir/IR/BuiltinOps.h" +#include "xla/client/xla_builder.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/translate/mhlo_to_hlo/layout_util.h" +#include "xla/service/hlo.pb.h" +#include "xla/service/hlo_module_config.h" + +namespace mlir { + +struct MlirToHloConversionOptions { + // Best-effort propagation of the layouts. These layouts serve as performance + // hints to the backend. + // + // Note that non-array shapes are not carrying layouts, and users have to + // figure out the proper layouts of them through context. This is one of the + // reasons why the attribute-based solution is temporary. + // + // TODO(timshen): Investigate the necessity of having layouts in MHLO. + bool propagate_layouts = false; + + // Propagate the source and result layouts from mhlo bitcast op into the + // backend config for the bitcast. This is required for XLA:GPU backend to + // use elemental IR emitters for fused bitcasts without propagating layouts. + bool propagate_bitcast_layouts_to_backend_config = false; + + LayoutPreferenceFn layout_preference_fn; + ShapeRepresentationFn shape_representation_fn; + + // If use_tuple_args is set, then the entry computations's arguments are + // converted to a tuple and passed as a single parameter. + bool use_tuple_args = false; + + // If return tuple is true, then the entry function's return values + // are converted to a tuple even when there is only a single return value. + // Multiple return values are always converted to a tuple and returned as a + // single value. + bool return_tuple = true; +}; + +// Prefer `ConvertMlirHloToHloModule` over this method when possible, as it +// preserves more information and abstracts away the proto. This method is +// preserved for legacy reasons. +// TODO (b/345806521): Migrate callsites to ConvertMlirHloToHloModule, +// and delete this method. +// +// Converts a MLIR module in HLO dialect into a HloModuleProto. +// +absl::Status ConvertMlirHloToHlo(mlir::ModuleOp module, + ::xla::HloProto* hlo_proto, + bool use_tuple_args, bool return_tuple, + MlirToHloConversionOptions options = {}); + +// Converts a MLIR module in HLO dialect into a HloModule with HloModuleConfig. +// This method preserves config data stored in MHLO module attributes. +// +// See `MlirToHloConversionOptions` for details on conversion flags. +absl::StatusOr> ConvertMlirHloToHloModule( + mlir::ModuleOp module, MlirToHloConversionOptions options = {}); + +// Transforms a Block into HLO, where the HLO is represented as calls into an +// XlaBuilder. Callee functions are allowed in the Block's ancestor ModuleOp. +// xla_params are inputs to block. returns are the returned XlaOps. +absl::Status BuildHloFromMlirHlo(mlir::Block& block, xla::XlaBuilder& builder, + llvm::ArrayRef xla_params, + std::vector& returns, + MlirToHloConversionOptions options = {}); + +} // namespace mlir + +#endif // XLA_HLO_TRANSLATE_MHLO_TO_HLO_MLIR_HLO_TO_HLO_H_ diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo_test.cc b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo_test.cc similarity index 97% rename from third_party/xla/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo_test.cc rename to third_party/xla/xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo_test.cc index 10dd9cec91f529..ad96da29cd4cfc 100644 --- a/third_party/xla/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo_test.cc +++ b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h" +#include "xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h" #include diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/module_attributes_exporter.cc b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/module_attributes_exporter.cc similarity index 99% rename from third_party/xla/xla/translate/mhlo_to_hlo/module_attributes_exporter.cc rename to third_party/xla/xla/hlo/translate/mhlo_to_hlo/module_attributes_exporter.cc index 7022a110572d37..afdae4739d79df 100644 --- a/third_party/xla/xla/translate/mhlo_to_hlo/module_attributes_exporter.cc +++ b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/module_attributes_exporter.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/translate/mhlo_to_hlo/module_attributes_exporter.h" +#include "xla/hlo/translate/mhlo_to_hlo/module_attributes_exporter.h" #include #include diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/module_attributes_exporter.h b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/module_attributes_exporter.h similarity index 89% rename from third_party/xla/xla/translate/mhlo_to_hlo/module_attributes_exporter.h rename to third_party/xla/xla/hlo/translate/mhlo_to_hlo/module_attributes_exporter.h index ccff0f957e6406..2081f24aee8f97 100644 --- a/third_party/xla/xla/translate/mhlo_to_hlo/module_attributes_exporter.h +++ b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/module_attributes_exporter.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_TRANSLATE_MHLO_TO_HLO_MODULE_ATTRIBUTES_EXPORTER_H_ -#define XLA_TRANSLATE_MHLO_TO_HLO_MODULE_ATTRIBUTES_EXPORTER_H_ +#ifndef XLA_HLO_TRANSLATE_MHLO_TO_HLO_MODULE_ATTRIBUTES_EXPORTER_H_ +#define XLA_HLO_TRANSLATE_MHLO_TO_HLO_MODULE_ATTRIBUTES_EXPORTER_H_ #include "absl/status/status.h" #include "mlir/IR/BuiltinAttributes.h" @@ -48,4 +48,4 @@ absl::Status ExportModuleEntryComputationResultTiles( } // namespace mhlo } // namespace mlir -#endif // XLA_TRANSLATE_MHLO_TO_HLO_MODULE_ATTRIBUTES_EXPORTER_H_ +#endif // XLA_HLO_TRANSLATE_MHLO_TO_HLO_MODULE_ATTRIBUTES_EXPORTER_H_ diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/operator_writer_gen.cc b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/operator_writer_gen.cc similarity index 100% rename from third_party/xla/xla/translate/mhlo_to_hlo/operator_writer_gen.cc rename to third_party/xla/xla/hlo/translate/mhlo_to_hlo/operator_writer_gen.cc diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/stack_frame_index_builder.cc b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/stack_frame_index_builder.cc similarity index 98% rename from third_party/xla/xla/translate/mhlo_to_hlo/stack_frame_index_builder.cc rename to third_party/xla/xla/hlo/translate/mhlo_to_hlo/stack_frame_index_builder.cc index 790f606d6457fe..dc96c4192938c3 100644 --- a/third_party/xla/xla/translate/mhlo_to_hlo/stack_frame_index_builder.cc +++ b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/stack_frame_index_builder.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/translate/mhlo_to_hlo/stack_frame_index_builder.h" +#include "xla/hlo/translate/mhlo_to_hlo/stack_frame_index_builder.h" #include #include diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/stack_frame_index_builder.h b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/stack_frame_index_builder.h similarity index 88% rename from third_party/xla/xla/translate/mhlo_to_hlo/stack_frame_index_builder.h rename to third_party/xla/xla/hlo/translate/mhlo_to_hlo/stack_frame_index_builder.h index db584a3ff58d6a..b8bed27e2ab091 100644 --- a/third_party/xla/xla/translate/mhlo_to_hlo/stack_frame_index_builder.h +++ b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/stack_frame_index_builder.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_TRANSLATE_MHLO_TO_HLO_STACK_FRAME_INDEX_BUILDER_H_ -#define XLA_TRANSLATE_MHLO_TO_HLO_STACK_FRAME_INDEX_BUILDER_H_ +#ifndef XLA_HLO_TRANSLATE_MHLO_TO_HLO_STACK_FRAME_INDEX_BUILDER_H_ +#define XLA_HLO_TRANSLATE_MHLO_TO_HLO_STACK_FRAME_INDEX_BUILDER_H_ #include #include @@ -53,4 +53,4 @@ class StackFrameIndexBuilder { }; } // namespace mlir -#endif // XLA_TRANSLATE_MHLO_TO_HLO_STACK_FRAME_INDEX_BUILDER_H_ +#endif // XLA_HLO_TRANSLATE_MHLO_TO_HLO_STACK_FRAME_INDEX_BUILDER_H_ diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/tests/BUILD b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/BUILD similarity index 97% rename from third_party/xla/xla/translate/mhlo_to_hlo/tests/BUILD rename to third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/BUILD index 70abbacdf0394b..c80c5204195c93 100644 --- a/third_party/xla/xla/translate/mhlo_to_hlo/tests/BUILD +++ b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/BUILD @@ -49,7 +49,7 @@ lit_test_suite( cfg = "//xla:lit.cfg.py", data = [":test_utilities"], tools = [ - "//xla/translate:xla-translate", + "//xla/hlo/translate:xla-translate", "@llvm-project//llvm:FileCheck", "@llvm-project//llvm:not", ], diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/tests/add.mlir b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/add.mlir similarity index 100% rename from third_party/xla/xla/translate/mhlo_to_hlo/tests/add.mlir rename to third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/add.mlir diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/tests/attributes.mlir b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/attributes.mlir similarity index 100% rename from third_party/xla/xla/translate/mhlo_to_hlo/tests/attributes.mlir rename to third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/attributes.mlir diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/tests/case.mlir b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/case.mlir similarity index 100% rename from third_party/xla/xla/translate/mhlo_to_hlo/tests/case.mlir rename to third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/case.mlir diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/tests/composite.mlir b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/composite.mlir similarity index 100% rename from third_party/xla/xla/translate/mhlo_to_hlo/tests/composite.mlir rename to third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/composite.mlir diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/tests/dynamic.mlir b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/dynamic.mlir similarity index 100% rename from third_party/xla/xla/translate/mhlo_to_hlo/tests/dynamic.mlir rename to third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/dynamic.mlir diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/tests/export-with-layouts.mlir b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/export-with-layouts.mlir similarity index 100% rename from third_party/xla/xla/translate/mhlo_to_hlo/tests/export-with-layouts.mlir rename to third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/export-with-layouts.mlir diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/tests/export.mlir b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/export.mlir similarity index 100% rename from third_party/xla/xla/translate/mhlo_to_hlo/tests/export.mlir rename to third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/export.mlir diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/tests/export_and_check_layouts.mlir b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/export_and_check_layouts.mlir similarity index 100% rename from third_party/xla/xla/translate/mhlo_to_hlo/tests/export_and_check_layouts.mlir rename to third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/export_and_check_layouts.mlir diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/tests/export_async.mlir b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/export_async.mlir similarity index 100% rename from third_party/xla/xla/translate/mhlo_to_hlo/tests/export_async.mlir rename to third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/export_async.mlir diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/tests/export_entry_computation_layout.mlir b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/export_entry_computation_layout.mlir similarity index 100% rename from third_party/xla/xla/translate/mhlo_to_hlo/tests/export_entry_computation_layout.mlir rename to third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/export_entry_computation_layout.mlir diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/tests/export_large_constants.mlir b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/export_large_constants.mlir similarity index 100% rename from third_party/xla/xla/translate/mhlo_to_hlo/tests/export_large_constants.mlir rename to third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/export_large_constants.mlir diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/tests/export_replicas.mlir b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/export_replicas.mlir similarity index 100% rename from third_party/xla/xla/translate/mhlo_to_hlo/tests/export_replicas.mlir rename to third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/export_replicas.mlir diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/tests/frontend_attributes.mlir b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/frontend_attributes.mlir similarity index 100% rename from third_party/xla/xla/translate/mhlo_to_hlo/tests/frontend_attributes.mlir rename to third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/frontend_attributes.mlir diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/tests/fusion.mlir b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/fusion.mlir similarity index 100% rename from third_party/xla/xla/translate/mhlo_to_hlo/tests/fusion.mlir rename to third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/fusion.mlir diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/tests/if.mlir b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/if.mlir similarity index 100% rename from third_party/xla/xla/translate/mhlo_to_hlo/tests/if.mlir rename to third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/if.mlir diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/tests/input_output_aliasing.mlir b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/input_output_aliasing.mlir similarity index 100% rename from third_party/xla/xla/translate/mhlo_to_hlo/tests/input_output_aliasing.mlir rename to third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/input_output_aliasing.mlir diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/tests/int4.mlir b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/int4.mlir similarity index 100% rename from third_party/xla/xla/translate/mhlo_to_hlo/tests/int4.mlir rename to third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/int4.mlir diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/tests/layouts_and_names.mlir b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/layouts_and_names.mlir similarity index 100% rename from third_party/xla/xla/translate/mhlo_to_hlo/tests/layouts_and_names.mlir rename to third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/layouts_and_names.mlir diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/tests/location_to_op_metadata.mlir b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/location_to_op_metadata.mlir similarity index 100% rename from third_party/xla/xla/translate/mhlo_to_hlo/tests/location_to_op_metadata.mlir rename to third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/location_to_op_metadata.mlir diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/tests/location_to_stacktrace.mlir b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/location_to_stacktrace.mlir similarity index 100% rename from third_party/xla/xla/translate/mhlo_to_hlo/tests/location_to_stacktrace.mlir rename to third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/location_to_stacktrace.mlir diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/tests/missing_main.mlir b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/missing_main.mlir similarity index 100% rename from third_party/xla/xla/translate/mhlo_to_hlo/tests/missing_main.mlir rename to third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/missing_main.mlir diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/tests/module_attributes.mlir b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/module_attributes.mlir similarity index 100% rename from third_party/xla/xla/translate/mhlo_to_hlo/tests/module_attributes.mlir rename to third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/module_attributes.mlir diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/tests/module_config.mlir b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/module_config.mlir similarity index 100% rename from third_party/xla/xla/translate/mhlo_to_hlo/tests/module_config.mlir rename to third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/module_config.mlir diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/tests/multiple_return_tuple.mlir b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/multiple_return_tuple.mlir similarity index 100% rename from third_party/xla/xla/translate/mhlo_to_hlo/tests/multiple_return_tuple.mlir rename to third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/multiple_return_tuple.mlir diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/tests/opaque_elements_attr.mlir b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/opaque_elements_attr.mlir similarity index 100% rename from third_party/xla/xla/translate/mhlo_to_hlo/tests/opaque_elements_attr.mlir rename to third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/opaque_elements_attr.mlir diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/tests/rng_get_and_update_state.mlir b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/rng_get_and_update_state.mlir similarity index 100% rename from third_party/xla/xla/translate/mhlo_to_hlo/tests/rng_get_and_update_state.mlir rename to third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/rng_get_and_update_state.mlir diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/tests/sharding.mlir b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/sharding.mlir similarity index 100% rename from third_party/xla/xla/translate/mhlo_to_hlo/tests/sharding.mlir rename to third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/sharding.mlir diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/tests/simple.mlir b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/simple.mlir similarity index 100% rename from third_party/xla/xla/translate/mhlo_to_hlo/tests/simple.mlir rename to third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/simple.mlir diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/tests/unsupported_type.mlir b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/unsupported_type.mlir similarity index 100% rename from third_party/xla/xla/translate/mhlo_to_hlo/tests/unsupported_type.mlir rename to third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/unsupported_type.mlir diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/tests/while.mlir b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/while.mlir similarity index 100% rename from third_party/xla/xla/translate/mhlo_to_hlo/tests/while.mlir rename to third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/while.mlir diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/tests/while_free_vars.mlir b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/while_free_vars.mlir similarity index 100% rename from third_party/xla/xla/translate/mhlo_to_hlo/tests/while_free_vars.mlir rename to third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/while_free_vars.mlir diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/translate.cc b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/translate.cc similarity index 98% rename from third_party/xla/xla/translate/mhlo_to_hlo/translate.cc rename to third_party/xla/xla/hlo/translate/mhlo_to_hlo/translate.cc index 7c07582a46c794..38f1cf6c5596fa 100644 --- a/third_party/xla/xla/translate/mhlo_to_hlo/translate.cc +++ b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/translate.cc @@ -12,7 +12,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/translate/mhlo_to_hlo/translate.h" +#include "xla/hlo/translate/mhlo_to_hlo/translate.h" #include #include @@ -42,14 +42,14 @@ limitations under the License. #include "xla/hlo/ir/hlo_input_output_alias_config.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h" +#include "xla/hlo/translate/mhlo_to_hlo/type_to_shape.h" #include "xla/mlir_hlo/mhlo/IR/register.h" #include "xla/service/hlo.pb.h" #include "xla/service/hlo_module_config.h" #include "xla/service/hlo_proto_util.h" #include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h" -#include "xla/translate/mhlo_to_hlo/type_to_shape.h" #include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" diff --git a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/translate.h b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/translate.h new file mode 100644 index 00000000000000..064db33984b864 --- /dev/null +++ b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/translate.h @@ -0,0 +1,50 @@ +/* Copyright 2019 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_TRANSLATE_MHLO_TO_HLO_TRANSLATE_H_ +#define XLA_HLO_TRANSLATE_MHLO_TO_HLO_TRANSLATE_H_ + +#include +#include + +#include "llvm/Support/MemoryBuffer.h" +#include "llvm/Support/raw_os_ostream.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Support/LogicalResult.h" + +namespace xla { + +mlir::LogicalResult MlirHloToHloTranslateFunction(mlir::ModuleOp module, + llvm::raw_ostream& output, + bool emit_return_tuple, + bool emit_use_tuple_arg); + +mlir::LogicalResult MlirHloToHloTextTranslateFunction( + mlir::ModuleOp module, llvm::raw_ostream& output, bool emit_return_tuple, + bool emit_use_tuple_arg, bool print_layouts, bool print_large_constants, + bool print_sugar, bool via_builder, bool with_layouts); + +// Translate the MHLO program in in-memory file 'buffer' to a HLO program +// written in a file represented with handle 'output_stream'; +mlir::LogicalResult MlirHloToHloTextMain( + std::unique_ptr buffer, + llvm::raw_ostream& output_stream, bool emit_return_tuple, + bool emit_use_tuple_arg, bool print_layouts, bool print_large_constants, + bool print_sugar, bool via_builder, bool with_layouts); + +} // namespace xla + +#endif // XLA_HLO_TRANSLATE_MHLO_TO_HLO_TRANSLATE_H_ diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/translate_registration.cc b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/translate_registration.cc similarity index 95% rename from third_party/xla/xla/translate/mhlo_to_hlo/translate_registration.cc rename to third_party/xla/xla/hlo/translate/mhlo_to_hlo/translate_registration.cc index ec5954af59a25d..ed0be0bee7345e 100644 --- a/third_party/xla/xla/translate/mhlo_to_hlo/translate_registration.cc +++ b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/translate_registration.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/translate/mhlo_to_hlo/translate_registration.h" +#include "xla/hlo/translate/mhlo_to_hlo/translate_registration.h" #include "llvm/Support/raw_ostream.h" #include "mlir/Dialect/Arith/IR/Arith.h" @@ -22,8 +22,8 @@ limitations under the License. #include "mlir/IR/BuiltinOps.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Tools/mlir-translate/Translation.h" +#include "xla/hlo/translate/mhlo_to_hlo/translate.h" #include "xla/mlir_hlo/mhlo/IR/register.h" -#include "xla/translate/mhlo_to_hlo/translate.h" static mlir::LogicalResult MlirHloToHloTranslate(mlir::ModuleOp module, llvm::raw_ostream& output) { diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/translate_registration.h b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/translate_registration.h similarity index 91% rename from third_party/xla/xla/translate/mhlo_to_hlo/translate_registration.h rename to third_party/xla/xla/hlo/translate/mhlo_to_hlo/translate_registration.h index 6cd3e8e4fdd898..42c480a15fbc4d 100644 --- a/third_party/xla/xla/translate/mhlo_to_hlo/translate_registration.h +++ b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/translate_registration.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_TRANSLATE_MHLO_TO_HLO_TRANSLATE_REGISTRATION_H_ -#define XLA_TRANSLATE_MHLO_TO_HLO_TRANSLATE_REGISTRATION_H_ +#ifndef XLA_HLO_TRANSLATE_MHLO_TO_HLO_TRANSLATE_REGISTRATION_H_ +#define XLA_HLO_TRANSLATE_MHLO_TO_HLO_TRANSLATE_REGISTRATION_H_ #include "llvm/Support/CommandLine.h" @@ -60,4 +60,4 @@ llvm::cl::opt via_builder( "via-builder", llvm::cl::desc("Translate MHLO->XLA HLO via XLA Builder"), llvm::cl::init(false)); -#endif // XLA_TRANSLATE_MHLO_TO_HLO_TRANSLATE_REGISTRATION_H_ +#endif // XLA_HLO_TRANSLATE_MHLO_TO_HLO_TRANSLATE_REGISTRATION_H_ diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/type_to_shape.cc b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/type_to_shape.cc similarity index 99% rename from third_party/xla/xla/translate/mhlo_to_hlo/type_to_shape.cc rename to third_party/xla/xla/hlo/translate/mhlo_to_hlo/type_to_shape.cc index 82e79e7ff63197..89a3ab09b9f51e 100644 --- a/third_party/xla/xla/translate/mhlo_to_hlo/type_to_shape.cc +++ b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/type_to_shape.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/translate/mhlo_to_hlo/type_to_shape.h" +#include "xla/hlo/translate/mhlo_to_hlo/type_to_shape.h" #include #include @@ -53,7 +53,6 @@ using xla::PrimitiveType; namespace xla { - std::optional> ConvertDimLevelType( mlir::sparse_tensor::LevelType lt) { auto f = mlir::sparse_tensor::getLevelFormat(lt); diff --git a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/type_to_shape.h b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/type_to_shape.h new file mode 100644 index 00000000000000..eb641ce44e3440 --- /dev/null +++ b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/type_to_shape.h @@ -0,0 +1,31 @@ +/* Copyright 2019 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_TRANSLATE_MHLO_TO_HLO_TYPE_TO_SHAPE_H_ +#define XLA_HLO_TRANSLATE_MHLO_TO_HLO_TYPE_TO_SHAPE_H_ + +#include "llvm/ADT/STLExtras.h" +#include "mlir/IR/Types.h" +#include "xla/shape.h" +#include "xla/xla_data.pb.h" + +namespace xla { + +// Returns a XLA Shape equivalent of a MLIR Type, else returns empty shape. +Shape TypeToShape(mlir::Type type); + +} // namespace xla + +#endif // XLA_HLO_TRANSLATE_MHLO_TO_HLO_TYPE_TO_SHAPE_H_ diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/type_to_shape_test.cc b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/type_to_shape_test.cc similarity index 98% rename from third_party/xla/xla/translate/mhlo_to_hlo/type_to_shape_test.cc rename to third_party/xla/xla/hlo/translate/mhlo_to_hlo/type_to_shape_test.cc index 9d09c79eeaa507..464a6f21f9dfb6 100644 --- a/third_party/xla/xla/translate/mhlo_to_hlo/type_to_shape_test.cc +++ b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/type_to_shape_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/translate/mhlo_to_hlo/type_to_shape.h" +#include "xla/hlo/translate/mhlo_to_hlo/type_to_shape.h" #include #include @@ -24,11 +24,11 @@ limitations under the License. #include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/MLIRContext.h" +#include "xla/hlo/translate/hlo_to_mhlo/hlo_utils.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/test.h" -#include "xla/translate/hlo_to_mhlo/hlo_utils.h" #include "xla/xla_data.pb.h" #include "tsl/platform/protobuf.h" diff --git a/third_party/xla/xla/hlo/translate/stablehlo_to_hlo/BUILD b/third_party/xla/xla/hlo/translate/stablehlo_to_hlo/BUILD new file mode 100644 index 00000000000000..c06de76069abf0 --- /dev/null +++ b/third_party/xla/xla/hlo/translate/stablehlo_to_hlo/BUILD @@ -0,0 +1,46 @@ +load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") +load("//xla/tsl:tsl.bzl", "internal_visibility") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = internal_visibility([ + "//learning/brain/mlir:tensorflow_friends", + "//learning/brain/mlir:xla_friends", + ]), + licenses = ["notice"], +) + +cc_library( + name = "translate", + srcs = ["translate.cc"], + hdrs = ["translate.h"], + deps = [ + "//xla/hlo/translate/mhlo_to_hlo:translate", + "//xla/mlir_hlo:mhlo_passes", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Parser", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@stablehlo//:register", + ], +) + +cc_library( + name = "translate_registration", + testonly = True, + srcs = ["translate_registration.cc"], + deps = [ + ":translate", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:TranslateLib", + "@stablehlo//:register", + ], + alwayslink = 1, +) diff --git a/third_party/xla/xla/translate/stablehlo_to_hlo/tests/BUILD b/third_party/xla/xla/hlo/translate/stablehlo_to_hlo/tests/BUILD similarity index 93% rename from third_party/xla/xla/translate/stablehlo_to_hlo/tests/BUILD rename to third_party/xla/xla/hlo/translate/stablehlo_to_hlo/tests/BUILD index c68f2ea17d8423..a9535fa3e35b4a 100644 --- a/third_party/xla/xla/translate/stablehlo_to_hlo/tests/BUILD +++ b/third_party/xla/xla/hlo/translate/stablehlo_to_hlo/tests/BUILD @@ -19,7 +19,7 @@ lit_test_suite( cfg = "//xla:lit.cfg.py", data = [":test_utilities"], tools = [ - "//xla/translate:xla-translate", + "//xla/hlo/translate:xla-translate", "@llvm-project//llvm:FileCheck", "@llvm-project//llvm:not", ], diff --git a/third_party/xla/xla/translate/stablehlo_to_hlo/tests/simple.mlir b/third_party/xla/xla/hlo/translate/stablehlo_to_hlo/tests/simple.mlir similarity index 100% rename from third_party/xla/xla/translate/stablehlo_to_hlo/tests/simple.mlir rename to third_party/xla/xla/hlo/translate/stablehlo_to_hlo/tests/simple.mlir diff --git a/third_party/xla/xla/translate/stablehlo_to_hlo/translate.cc b/third_party/xla/xla/hlo/translate/stablehlo_to_hlo/translate.cc similarity index 96% rename from third_party/xla/xla/translate/stablehlo_to_hlo/translate.cc rename to third_party/xla/xla/hlo/translate/stablehlo_to_hlo/translate.cc index 5cf7b516f4229e..ee50467832e872 100644 --- a/third_party/xla/xla/translate/stablehlo_to_hlo/translate.cc +++ b/third_party/xla/xla/hlo/translate/stablehlo_to_hlo/translate.cc @@ -12,7 +12,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/translate/stablehlo_to_hlo/translate.h" +#include "xla/hlo/translate/stablehlo_to_hlo/translate.h" #include #include @@ -29,8 +29,8 @@ limitations under the License. #include "mlir/Pass/PassManager.h" #include "mlir/Support/LogicalResult.h" #include "stablehlo/dialect/Register.h" +#include "xla/hlo/translate/mhlo_to_hlo/translate.h" #include "xla/mlir_hlo/mhlo/transforms/passes.h" -#include "xla/translate/mhlo_to_hlo/translate.h" namespace xla { diff --git a/third_party/xla/xla/hlo/translate/stablehlo_to_hlo/translate.h b/third_party/xla/xla/hlo/translate/stablehlo_to_hlo/translate.h new file mode 100644 index 00000000000000..c3f0a86cb88340 --- /dev/null +++ b/third_party/xla/xla/hlo/translate/stablehlo_to_hlo/translate.h @@ -0,0 +1,50 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_TRANSLATE_STABLEHLO_TO_HLO_TRANSLATE_H_ +#define XLA_HLO_TRANSLATE_STABLEHLO_TO_HLO_TRANSLATE_H_ + +#include +#include + +#include "llvm/Support/MemoryBuffer.h" +#include "llvm/Support/raw_os_ostream.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Support/LogicalResult.h" + +namespace xla { + +mlir::LogicalResult StablehloToHloTranslateFunction(mlir::ModuleOp module, + llvm::raw_ostream& output, + bool emit_return_tuple, + bool emit_use_tuple_arg); + +mlir::LogicalResult StablehloToHloTextTranslateFunction( + mlir::ModuleOp module, llvm::raw_ostream& output, bool emit_return_tuple, + bool emit_use_tuple_arg, bool print_layouts, bool print_large_constants, + bool print_sugar, bool via_builder, bool with_layouts); + +// Translate the StableHLO program in in-memory file 'buffer' to a HLO program +// written in a file represented with handle 'output_stream'; +mlir::LogicalResult StablehloToHloTextMain( + std::unique_ptr buffer, + llvm::raw_ostream& output_stream, bool emit_return_tuple, + bool emit_use_tuple_arg, bool print_layouts, bool print_large_constants, + bool print_sugar, bool via_builder, bool with_layouts); + +} // namespace xla + +#endif // XLA_HLO_TRANSLATE_STABLEHLO_TO_HLO_TRANSLATE_H_ diff --git a/third_party/xla/xla/translate/stablehlo_to_hlo/translate_registration.cc b/third_party/xla/xla/hlo/translate/stablehlo_to_hlo/translate_registration.cc similarity index 95% rename from third_party/xla/xla/translate/stablehlo_to_hlo/translate_registration.cc rename to third_party/xla/xla/hlo/translate/stablehlo_to_hlo/translate_registration.cc index 38e827dac3475a..23258c1e212d6e 100644 --- a/third_party/xla/xla/translate/stablehlo_to_hlo/translate_registration.cc +++ b/third_party/xla/xla/hlo/translate/stablehlo_to_hlo/translate_registration.cc @@ -23,10 +23,10 @@ limitations under the License. #include "mlir/Support/LogicalResult.h" #include "mlir/Tools/mlir-translate/Translation.h" #include "stablehlo/dialect/Register.h" -#include "xla/translate/stablehlo_to_hlo/translate.h" +#include "xla/hlo/translate/stablehlo_to_hlo/translate.h" // The following symbols are defined in -// tensorflow/compiler/xla/translate/mhlo_to_hlo/translate_registration.h +// tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo/translate_registration.h extern llvm::cl::opt emit_use_tuple_arg; extern llvm::cl::opt emit_return_tuple; extern llvm::cl::opt with_layouts; diff --git a/third_party/xla/xla/translate/xla_translate_main.cc b/third_party/xla/xla/hlo/translate/xla_translate_main.cc similarity index 100% rename from third_party/xla/xla/translate/xla_translate_main.cc rename to third_party/xla/xla/hlo/translate/xla_translate_main.cc diff --git a/third_party/xla/xla/translate/xla_translate_opt_main.cc b/third_party/xla/xla/hlo/translate/xla_translate_opt_main.cc similarity index 100% rename from third_party/xla/xla/translate/xla_translate_opt_main.cc rename to third_party/xla/xla/hlo/translate/xla_translate_opt_main.cc diff --git a/third_party/xla/xla/translate/BUILD b/third_party/xla/xla/translate/BUILD index 4833529d409a24..a09a0a8ac6e142 100644 --- a/third_party/xla/xla/translate/BUILD +++ b/third_party/xla/xla/translate/BUILD @@ -1,5 +1,3 @@ -load("@bazel_skylib//rules:build_test.bzl", "build_test") -load("//xla:xla.bzl", "xla_cc_binary") load("//xla/tsl:tsl.bzl", "internal_visibility") package( @@ -11,52 +9,16 @@ package( licenses = ["notice"], ) -build_test( - name = "xla-translate_build_test", - targets = [ - ":xla-translate", - ], -) - -xla_cc_binary( +# Deprecated, use //third_party/tensorflow/compiler/xla/hlo/translate:xla-translate +# instead. +alias( name = "xla-translate", - testonly = True, - srcs = ["xla_translate_main.cc"], - deps = [ - "//xla/service/cpu:cpu_compiler", - "//xla/service/cpu:cpu_transfer_manager", - "//xla/stream_executor/host:host_platform", - "//xla/translate/hlo_to_mhlo:translate_registration", - "//xla/translate/mhlo_to_hlo:translate_registration", - "//xla/translate/stablehlo_to_hlo:translate_registration", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Support", - "@llvm-project//mlir:TranslateLib", - "@local_tsl//tsl/platform:platform_port", - ], -) - -build_test( - name = "xla-translate-opt_build_test", - targets = [ - ":xla-translate-opt", - ], + actual = "//xla/hlo/translate:xla-translate", ) -xla_cc_binary( +# Deprecated, use //third_party/tensorflow/compiler/xla/hlo/translate:xla-translate-opt +# instead. +alias( name = "xla-translate-opt", - testonly = True, - srcs = ["xla_translate_opt_main.cc"], - deps = [ - "//xla/mlir/framework/ir:xla_framework", - "//xla/mlir/framework/transforms:passes", - "//xla/mlir_hlo:hlo_dialect_registration", - "//xla/service:cpu_plugin", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:AllPassesAndDialects", - "@llvm-project//mlir:MlirOptLib", - "@local_tsl//tsl/platform:platform_port", - "@stablehlo//:register", - ], + actual = "//xla/hlo/translate:xla-translate-opt", ) diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/BUILD b/third_party/xla/xla/translate/hlo_to_mhlo/BUILD index fb4f7b9fc662fd..3699083550c76f 100644 --- a/third_party/xla/xla/translate/hlo_to_mhlo/BUILD +++ b/third_party/xla/xla/translate/hlo_to_mhlo/BUILD @@ -1,5 +1,4 @@ load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") -load("//xla:xla.bzl", "xla_cc_test") load("//xla/tsl:tsl.bzl", "internal_visibility") package( @@ -11,252 +10,76 @@ package( licenses = ["notice"], ) +# Deprecated, use //third_party/tensorflow/compiler/xla/hlo/translate/hlo_to_mhlo:attribute_importer # +# instead. cc_library( name = "attribute_importer", - srcs = ["attribute_importer.cc"], hdrs = ["attribute_importer.h"], deps = [ - "//xla:shape_util", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/mlir_hlo", - "//xla/service:hlo_proto_cc", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/types:span", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:IR", - ], -) - -cc_library( - name = "async_importer", - srcs = ["async_importer.cc"], - hdrs = ["async_importer.h"], - deps = [ - ":attribute_importer", - ":hlo_utils", - "//xla:util", - "//xla/hlo/ir:hlo", - "//xla/mlir_hlo", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:IR", - "@local_tsl//tsl/platform:errors", - ], -) - -cc_library( - name = "custom_call_importer", - srcs = ["custom_call_importer.cc"], - hdrs = ["custom_call_importer.h"], - deps = [ - "//xla:util", - "//xla/hlo/ir:hlo", - "//xla/mlir_hlo", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:AsmParser", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:QuantOps", - "@llvm-project//mlir:Support", - ], -) - -cc_library( - name = "stack_location_utils", - srcs = ["stack_location_utils.cc"], - hdrs = ["stack_location_utils.h"], - deps = [ - "//xla/hlo/ir:hlo", - "//xla/service:hlo_proto_cc", - "@llvm-project//mlir:IR", + "//xla/hlo/translate/hlo_to_mhlo:attribute_importer", ], ) +# Deprecated, use //third_party/tensorflow/compiler/xla/hlo/translate/hlo_to_mhlo:hlo_function_importer # +# instead. cc_library( name = "hlo_function_importer", - srcs = ["hlo_function_importer.cc"], hdrs = ["hlo_function_importer.h"], deps = [ - ":async_importer", - ":attribute_importer", - ":custom_call_importer", - ":hlo_utils", - ":location_importer", - "//xla:comparison_util", - "//xla:literal", - "//xla:protobuf_util", - "//xla:shape_util", - "//xla:status_macros", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/mlir_hlo", - "//xla/service:hlo_proto_cc", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/types:optional", - "@com_google_absl//absl/types:span", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:ArithDialect", - "@llvm-project//mlir:AsmParser", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:SideEffectInterfaces", - "@llvm-project//mlir:SparseTensorDialect", - "@llvm-project//mlir:Support", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", + "//xla/hlo/translate/hlo_to_mhlo:hlo_function_importer", ], ) +# Deprecated, use //third_party/tensorflow/compiler/xla/hlo/translate/hlo_to_mhlo:hlo_module_importer # +# instead. cc_library( name = "hlo_module_importer", - srcs = [ - "hlo_module_importer.cc", - ], hdrs = [ "hlo_module_importer.h", ], deps = [ - ":hlo_function_importer", - ":module_attributes_importer", - "//xla:xla_data_proto_cc", - "//xla:xla_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/mlir_hlo", - "@com_google_absl//absl/status", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:ArithDialect", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:QuantOps", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", + "//xla/hlo/translate/hlo_to_mhlo:hlo_module_importer", ], ) +# Deprecated, use //third_party/tensorflow/compiler/xla/hlo/translate/hlo_to_mhlo:hlo_to_mlir_hlo # +# instead. cc_library( name = "hlo_to_mlir_hlo", - srcs = ["hlo_to_mlir_hlo.cc"], hdrs = ["hlo_to_mlir_hlo.h"], deps = [ - ":hlo_module_importer", - "//xla:status_macros", - "//xla/mlir/utils:error_util", - "//xla/service/llvm_ir:llvm_util", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@llvm-project//mlir:IR", - "@local_tsl//tsl/platform:errors", + "//xla/hlo/translate/hlo_to_mhlo:hlo_to_mlir_hlo", ], ) +# Deprecated, use //third_party/tensorflow/compiler/xla/hlo/translate/hlo_to_mhlo:hlo_utils # +# instead. cc_library( name = "hlo_utils", - srcs = ["hlo_utils.cc"], hdrs = ["hlo_utils.h"], includes = ["include"], deps = [ - "//xla:literal", - "//xla:shape_util", - "//xla:types", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/mlir/utils:type_util", - "//xla/mlir_hlo", - "@com_google_absl//absl/status:statusor", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:SparseTensorDialect", - "@llvm-project//mlir:SparseTensorEnums", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_cc_test( - name = "hlo_utils_test", - srcs = ["hlo_utils_test.cc"], - deps = [ - ":hlo_utils", - "//xla:literal", - "//xla:literal_util", - "//xla:shape_util", - "//xla:test", - "//xla:types", - "//xla/tsl/lib/core:status_test_util", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Support", - "@local_tsl//tsl/platform:test_main", - ], -) - -cc_library( - name = "location_importer", - srcs = ["location_importer.cc"], - hdrs = ["location_importer.h"], - deps = [ - "stack_location_utils", - "//xla/hlo/ir:hlo", - "@llvm-project//mlir:IR", - ], -) - -cc_library( - name = "module_attributes_importer", - srcs = ["module_attributes_importer.cc"], - hdrs = ["module_attributes_importer.h"], - deps = [ - ":hlo_function_importer", - ":hlo_utils", - "//xla:shape_layout", - "//xla:shape_util", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/mlir_hlo", - "//xla/service:computation_layout", - "//xla/service:hlo_module_config", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/types:span", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:IR", + "//xla/hlo/translate/hlo_to_mhlo:hlo_utils", ], ) +# Deprecated, use //third_party/tensorflow/compiler/xla/hlo/translate/hlo_to_mhlo:translate # +# instead. cc_library( name = "translate", - srcs = ["translate.cc"], hdrs = ["translate.h"], deps = [ - ":hlo_to_mlir_hlo", - "//xla/mlir_hlo", - "//xla/mlir_hlo:mhlo_passes", - "//xla/service:hlo_parser", - "//xla/service:hlo_proto_cc", - "//xla/service/llvm_ir:llvm_util", - "@com_google_absl//absl/status", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Pass", - "@local_tsl//tsl/platform:protobuf", + "//xla/hlo/translate/hlo_to_mhlo:translate", ], ) +# Deprecated, use //third_party/tensorflow/compiler/xla/hlo/translate/hlo_to_mhlo:translate_registration # +# instead. cc_library( name = "translate_registration", testonly = True, - srcs = ["translate_registration.cc"], deps = [ - ":translate", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:TranslateLib", + "//xla/hlo/translate/hlo_to_mhlo:translate_registration", ], alwayslink = 1, ) diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/attribute_importer.h b/third_party/xla/xla/translate/hlo_to_mhlo/attribute_importer.h index ead3a3955fc79c..2b5f81982fd6d8 100644 --- a/third_party/xla/xla/translate/hlo_to_mhlo/attribute_importer.h +++ b/third_party/xla/xla/translate/hlo_to_mhlo/attribute_importer.h @@ -16,89 +16,7 @@ limitations under the License. #ifndef XLA_TRANSLATE_HLO_TO_MHLO_ATTRIBUTE_IMPORTER_H_ #define XLA_TRANSLATE_HLO_TO_MHLO_ATTRIBUTE_IMPORTER_H_ -#include -#include -#include -#include - -#include "absl/status/statusor.h" -#include "absl/types/span.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/Builders.h" -#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" -#include "xla/service/hlo.pb.h" -#include "xla/shape.h" -#include "xla/shape_util.h" -#include "xla/xla_data.pb.h" - -namespace xla { - -// Converts an XLA PrecisionConfig to the corresponding MLIR attribute. -mlir::ArrayAttr ConvertPrecisionConfig(const PrecisionConfig* config, - mlir::Builder* builder); - -// Converts the gather dimensions to attributes. -mlir::mhlo::GatherDimensionNumbersAttr ConvertGatherDimensionNumbers( - const xla::GatherDimensionNumbers& dnums, mlir::Builder* builder); - -// Converts the scatter dimensions to attributes. -mlir::mhlo::ScatterDimensionNumbersAttr ConvertScatterDimensionNumbers( - const xla::ScatterDimensionNumbers& dnums, mlir::Builder* builder); - -// Converts the dot algorithm to attributes. -mlir::mhlo::DotAlgorithmAttr ConvertDotAlgorithm( - PrecisionConfig::Algorithm algorithm, mlir::Builder* builder); - -// Converts the dot dimensions to attributes. -mlir::mhlo::DotDimensionNumbersAttr ConvertDotDimensionNumbers( - const DotDimensionNumbers& dnums, mlir::Builder* builder); - -// Converts the conv dimensions to attributes. -mlir::mhlo::ConvDimensionNumbersAttr ConvertConvDimensionNumbers( - const xla::ConvolutionDimensionNumbers& dnums, mlir::Builder* builder); - -// Converts the output operand aliasing to attributes. -mlir::ArrayAttr ConvertOutputOperandAliasing( - const std::vector>>& aliaInfo, - mlir::Builder* builder); - -// Converts the sparsity descriptor to attributes. -absl::StatusOr ConvertSparsityDescriptor( - xla::SparsityDescriptor sparsity_descriptor, mlir::Builder* builder); - -absl::StatusOr ConvertFftType(FftType type); -absl::StatusOr ConvertTranspose( - TriangularSolveOptions_Transpose transpose); - -absl::StatusOr ConvertCustomCallApiVersion( - xla::CustomCallApiVersion api_version); - -mlir::NamedAttribute ConvertChannelHandle(const ChannelHandle& channel, - mlir::Builder* builder); -mlir::NamedAttribute ConvertChannelHandle(std::optional channel_id, - mlir::Builder* builder); - -mlir::NamedAttribute ConvertReplicaGroups( - absl::Span replica_groups, mlir::Builder* builder); - -mlir::NamedAttribute ConvertSourceTargetPairs( - const std::vector>& source_target_pairs, - mlir::Builder* builder); - -mlir::NamedAttribute ConvertUseGlobalDeviceIds(mlir::Builder* builder); - -// Extracts layouts from shapes and converts it into layout attributes (array of -// rank-1 index tensors). Returns an error if any of the shapes is a tuple. -absl::StatusOr ExtractLayoutsFromShapes( - const absl::Span shapes_with_layouts, mlir::Builder* builder); - -// Extracts the layouts of each element from a tuple shape and returns them as -// an array of rank-1 index tensors. Returns an error in presence of nested -// tuple shapes. -absl::StatusOr ExtractLayoutsFromTuple(const xla::Shape shape, - mlir::Builder* builder); - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/translate/hlo_to_mhlo/attribute_importer.h" #endif // XLA_TRANSLATE_HLO_TO_MHLO_ATTRIBUTE_IMPORTER_H_ diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/hlo_function_importer.h b/third_party/xla/xla/translate/hlo_to_mhlo/hlo_function_importer.h index fa22a6d11f1086..0ebd37fa6af125 100644 --- a/third_party/xla/xla/translate/hlo_to_mhlo/hlo_function_importer.h +++ b/third_party/xla/xla/translate/hlo_to_mhlo/hlo_function_importer.h @@ -16,243 +16,7 @@ limitations under the License. #ifndef XLA_TRANSLATE_HLO_TO_MHLO_HLO_FUNCTION_IMPORTER_H_ #define XLA_TRANSLATE_HLO_TO_MHLO_HLO_FUNCTION_IMPORTER_H_ -#include -#include - -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "absl/types/optional.h" -#include "absl/types/span.h" -#include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/SmallVector.h" -#include "llvm/ADT/StringRef.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/Block.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/MLIRContext.h" -#include "mlir/IR/Operation.h" -#include "mlir/IR/Region.h" -#include "mlir/IR/SymbolTable.h" -#include "mlir/IR/Value.h" -#include "mlir/IR/ValueRange.h" -#include "xla/comparison_util.h" -#include "xla/hlo/ir/hlo_input_output_alias_config.h" -#include "xla/hlo/ir/hlo_sharding.h" -#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" -#include "xla/service/hlo.pb.h" -#include "xla/xla_data.pb.h" - -namespace xla { - -class HloModule; -class HloComputation; -class HloInstruction; -class Shape; - -// HLO bounded dynamic shapes can be converted to either MLIR dynamic shapes -// (which lose the bound information) or casted to static shape using the -// bounds. -enum class DynamicShapeHandlingMode { kDynamic, kConvertToStatic }; - -// Helper class for importing HloComputations. -class HloFunctionImporter { - public: - // Imports the given computation as a function in the given symbol table and - // returns the FuncOp. This also imports any computations referred by - // instructions in this computation. - static absl::StatusOr ImportAsFunc( - const HloComputation& computation, mlir::SymbolTable& symbol_table, - std::unordered_map* - function_map, - mlir::Builder* builder, bool is_main, - bool flatten_computation_args_result = false); - - // Imports the given hlo computation to the specified region. - // - // Flattens the tuple-typed region argument(s) and return value(s). - static absl::Status ImportAsRegion( - const HloComputation& computation, mlir::SymbolTable& symbol_table, - mlir::Region* region, mlir::Builder* builder, - bool flatten_computation_args_result = false); - - // Imports the given computation to the given place specified by `builder`. - // `arguments` contains values for all parameters. - static absl::StatusOr ImportInstructions( - const HloComputation& computation, - const llvm::SmallVectorImpl& arguments, - mlir::SymbolTable& symbol_table, mlir::OpBuilder* builder, - bool flatten_computation_args_result = false); - - static absl::StatusOr ImportInstruction( - const HloInstruction* instr, - const llvm::SmallVectorImpl& operands, - mlir::SymbolTable& symbol_table, mlir::OpBuilder* builder, - bool flatten_computation_args_result = false, - DynamicShapeHandlingMode mode = DynamicShapeHandlingMode::kDynamic); - - static void SetLayoutForMlir(mlir::Operation* op, const Shape& shape, - llvm::StringRef attr_name); - - // For mlir::IfOp or mlir::CaseOp, replace the uses of their region's block - // arguments with 'implicit_operands'. Here | implicit_operands | == sum of - // the number of arguments in all the regions in IfOp or CaseOp. - void ReplaceBlockArgumentsWithImplicitOperands( - mlir::Operation* op, llvm::ArrayRef implicit_operands); - - // FlattenTupleType flattens the types in (nested) tuple-type 'type' and - // stores them in 'flattened_types'. - static void FlattenTupleType( - mlir::Type type, llvm::SmallVectorImpl& flattened_types); - - // FlattenTupleValue flattens the values in (nested) tuple-typed 'value' and - // stores them in 'flattened_values'. - static void FlattenTupleValue( - mlir::OpBuilder* func_builder, mlir::Location loc, mlir::Value value, - llvm::SmallVectorImpl& flattened_values); - - // FlattenTupleValues flattens the values in (nested) tuple-typed 'values' and - // returns the flattened values. - static llvm::SmallVector FlattenTupleValues( - mlir::OpBuilder* func_builder, mlir::Location loc, - mlir::ValueRange values, std::optional reserve_size = std::nullopt); - - private: - HloFunctionImporter(mlir::SymbolTable& symbol_table, - std::unordered_map* function_map, - mlir::Builder* builder, - bool flatten_computation_args_result) - : context_(symbol_table.getOp()->getContext()), - symbol_table_(symbol_table), - builder_(builder), - function_map_(function_map), - flatten_computation_args_result_(flatten_computation_args_result) { - context_->loadDialect(); - context_->loadDialect(); - context_->loadDialect(); - context_->loadDialect(); - } - - // Imports the given computation as a new function, if it hasn't been already - // imported. - absl::StatusOr ImportAsFunc( - const HloComputation& computation, bool is_main); - - // Imports the given computation in the specified region. - absl::Status ImportAsRegion(const HloComputation& computation, - mlir::Region* region); - - // Imports instructions from the given computation in the specified block. - // Assumes that the block already has correct arguments populated. - absl::Status ImportInstructions(const HloComputation& computation, - mlir::Block* block); - absl::StatusOr ImportInstructionsImpl( - const HloComputation& computation, - const llvm::SmallVectorImpl& arguments, - mlir::OpBuilder* builder); - - // Imports an instruction. - absl::StatusOr ImportInstructionWithLayout( - const HloInstruction* instruction, - const llvm::SmallVectorImpl& operands, - mlir::OpBuilder* func_builder, - DynamicShapeHandlingMode mode = DynamicShapeHandlingMode::kDynamic); - - absl::StatusOr ImportInstructionImpl( - const HloInstruction* instruction, - const llvm::SmallVectorImpl& operands, - mlir::OpBuilder* func_builder, - DynamicShapeHandlingMode mode = DynamicShapeHandlingMode::kDynamic); - - // Gets the MLIR operand values from an HLO Instruction. - absl::StatusOr> GetOperands( - const HloInstruction* instruction); - - // Converts xla Tensor type to the corresponding MLIR type. - absl::StatusOr ConvertTensorType(const Shape& shape); - - // Converts an XLA shape/layout to the corresponding MLIR layout, in - // flattened_attr, while flattening the tuple layout. - absl::Status ConvertShapeToMlirLayout( - const Shape& shape, - llvm::SmallVectorImpl& flattened_attr); - - // Returns the output type of an HloInstruction. - absl::StatusOr GetReturnType(const HloInstruction* instruction); - - // Takes a list of HloInstructions and generates the list of types used for - // input, bypassing tuples to subsets. - absl::Status GetMlirTypes( - absl::Span instructions, - llvm::SmallVectorImpl* types); - - // Returns the Mlir Value for the corresponding HloInstruction. - absl::StatusOr GetMlirValue(const HloInstruction* instruction); - - // TODO(b/179166199): Move attribute converters to attribute_importer. - // Converts an XLA ComparisonDirection to the corresponding MLIR attribute. - mlir::NamedAttribute ConvertComparisonDirection( - ComparisonDirection direction); - - // Converts an XLA Comparison::Type to the corresponding MLIR attribute. - mlir::NamedAttribute ConvertComparisonType(Comparison::Type type); - - // Converts an XLA CustomCallSchedule to the corresponding MLIR attribute. - mlir::NamedAttribute ConvertCustomCallSchedule(CustomCallSchedule schedule); - - // Converts the dimensions of an HLO instruction into an MLIR attribute. - mlir::DenseIntElementsAttr ConvertDimensions( - absl::Span op_dimensions); - - // Converts Array ref to an DenseIntElementsAttr. - mlir::DenseIntElementsAttr Convert(llvm::ArrayRef elements); - - // Converts Array ref of bools to a DenseIntElementsAttr of I1 type. - mlir::DenseIntElementsAttr Convert(llvm::ArrayRef elements); - - // Converts Array ref to padding attribute. Input is a flattened list of - // padding low and padding high for each of the spatial dimensions. - mlir::NamedAttribute ConvertPadding(llvm::ArrayRef padding); - - mlir::MLIRContext* context_; - - // SymbolTable to which new functions should be inserted. - mlir::SymbolTable& symbol_table_; - - mlir::Builder* builder_; - - // Mapping from HloComputation to the created MLIR function. - std::unordered_map* function_map_; - - // Mapping from HloInstructions to the associative MLIR values. - std::unordered_map instruction_value_map_; - - bool flatten_computation_args_result_; -}; - -// Returns a StringAttr that carries a prettyprinted representation of the -// given HLO C++ input_output_alias_config. -// Always succeeds and returns a non-empty attribute. -mlir::Attribute ConvertInputOutputAlias(const HloInputOutputAliasConfig& alias, - mlir::Builder* builder); - -// Returns a StringAttr that carries a prettyprinted representation of the -// given HLO C++ sharding. -// Always succeeds and returns a non-empty attribute. -mlir::Attribute ConvertSharding(const HloSharding& sharding, - mlir::Builder* builder); - -// Returns a StringAttr that carries a prettyprinted representation of the -// given HLO proto sharding. -// Will fail and return an empty attribute if the proto sharding cannot be -// converted to the C++ sharding. -mlir::Attribute ConvertSharding(const OpSharding& sharding, - mlir::Builder* builder); - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/translate/hlo_to_mhlo/hlo_function_importer.h" #endif // XLA_TRANSLATE_HLO_TO_MHLO_HLO_FUNCTION_IMPORTER_H_ diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/hlo_module_importer.h b/third_party/xla/xla/translate/hlo_to_mhlo/hlo_module_importer.h index 0cc1a39d8eb003..8577e86dc93839 100644 --- a/third_party/xla/xla/translate/hlo_to_mhlo/hlo_module_importer.h +++ b/third_party/xla/xla/translate/hlo_to_mhlo/hlo_module_importer.h @@ -16,49 +16,7 @@ limitations under the License. #ifndef XLA_TRANSLATE_HLO_TO_MHLO_HLO_MODULE_IMPORTER_H_ #define XLA_TRANSLATE_HLO_TO_MHLO_HLO_MODULE_IMPORTER_H_ -#include - -#include "absl/status/status.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/SymbolTable.h" -#include "xla/xla_data.pb.h" - -namespace xla { -class HloModule; -class HloModuleProto; -class HloComputation; -class HloInstruction; -class Shape; - -// Importer that takes an HloModule and imports it as an MLIR module in the XLA -// dialect. HloModuleImporter does not take ownership. -class HloModuleImporter { - public: - explicit HloModuleImporter(mlir::ModuleOp module, - bool import_all_computation = false, - bool flatten_computation_args_result = false); - - // Import the HloModule into the MLIR Module. - absl::Status Import(const xla::HloModule& module); - - // Import the HloModuleProto into the MLIR Module. - absl::Status Import(const xla::HloModuleProto& module); - - private: - bool import_all_computation_; - bool flatten_computation_args_result_; - mlir::SymbolTable symbol_table_; - mlir::Builder builder_; - - // Map for tracking which MLIR function map to which HLO Computation. This - // tracks functions as they are imported and provides a quick lookup for - // functions invoked by control flow related operations (e.g. while, call). - std::unordered_map - function_map_; -}; - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/translate/hlo_to_mhlo/hlo_module_importer.h" #endif // XLA_TRANSLATE_HLO_TO_MHLO_HLO_MODULE_IMPORTER_H_ diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h b/third_party/xla/xla/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h index 775d6367dc8fc9..4943ef790d35f1 100644 --- a/third_party/xla/xla/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h +++ b/third_party/xla/xla/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h @@ -16,56 +16,7 @@ limitations under the License. #ifndef XLA_TRANSLATE_HLO_TO_MHLO_HLO_TO_MLIR_HLO_H_ #define XLA_TRANSLATE_HLO_TO_MHLO_HLO_TO_MLIR_HLO_H_ -#include - -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "mlir/IR/Location.h" -#include "mlir/IR/MLIRContext.h" -#include "mlir/IR/OwningOpRef.h" - -namespace mlir { -class ModuleOp; -} // namespace mlir - -namespace xla { -class HloModule; -class HloModuleProto; - -// Converts an HLO module proto to a MLIR module in HLO dialect. -// -// If `import_all_computation` is set to true, imports all computations -// irrespective if transitively called from entry computation. -// -// If `flatten_computation_args_result` is set to true, flattens all tuple -// arguments and result of every computation when importing them as func ops. -absl::StatusOr> ConvertHloToMlirHlo( - mlir::MLIRContext& ctx, xla::HloModuleProto const* hlo_module, - bool import_all_computations = false, - bool flatten_computation_args_result = false); - -absl::Status ConvertHloToMlirHlo(mlir::ModuleOp module, - xla::HloModuleProto const* hlo_module, - bool import_all_computations = false, - bool flatten_computation_args_result = false); - -// Converts an HLO module to a MLIR module in HLO dialect. -// -// If `import_all_computation` is set to true, imports all computations -// irrespective if transitively called from entry computation. -// -// If `flatten_computation_args_result` is set to true, flattens all tuple -// arguments and result of every computation when importing them as func ops. -absl::StatusOr> ConvertHloToMlirHlo( - mlir::MLIRContext& ctx, const xla::HloModule* hlo_module, - bool import_all_computations = false, - bool flatten_computation_args_result = false); - -absl::Status ConvertHloToMlirHlo(mlir::ModuleOp module, - const xla::HloModule* hlo_module, - bool import_all_computations = false, - bool flatten_computation_args_result = false); - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h" #endif // XLA_TRANSLATE_HLO_TO_MHLO_HLO_TO_MLIR_HLO_H_ diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/hlo_utils.h b/third_party/xla/xla/translate/hlo_to_mhlo/hlo_utils.h index bb9785d8242664..50e31028617463 100644 --- a/third_party/xla/xla/translate/hlo_to_mhlo/hlo_utils.h +++ b/third_party/xla/xla/translate/hlo_to_mhlo/hlo_utils.h @@ -18,232 +18,7 @@ limitations under the License. #ifndef XLA_TRANSLATE_HLO_TO_MHLO_HLO_UTILS_H_ #define XLA_TRANSLATE_HLO_TO_MHLO_HLO_UTILS_H_ -#include -#include -#include -#include - -#include "absl/status/statusor.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/SmallVector.h" -#include "mlir/Dialect/SparseTensor/IR/Enums.h" -#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinTypeInterfaces.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/Location.h" -#include "mlir/IR/Operation.h" -#include "mlir/IR/Value.h" -#include "mlir/IR/ValueRange.h" -#include "xla/layout.h" -#include "xla/layout_util.h" -#include "xla/literal.h" -#include "xla/mlir/utils/type_util.h" -#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" -#include "xla/util.h" -#include "tsl/platform/statusor.h" - -namespace xla { - -absl::StatusOr CreateDenseElementsAttrFromLiteral( - const LiteralBase& literal, mlir::Builder builder); - -// Creates an DenseIntElementsAttr using the elements of the vector and the -// optional shape. -mlir::DenseIntElementsAttr CreateDenseIntElementsAttrFromVector( - const llvm::ArrayRef vector, mlir::Builder builder, - llvm::ArrayRef shape = {}); - -// Converts the given XLA shape for tensors to the template MLIR type. -template -static absl::StatusOr ConvertTensorShapeToType(const Shape& xla_ty, - mlir::Builder builder) { - auto element_type_or = - ConvertPrimitiveTypeToMlirType(xla_ty.element_type(), builder); - if (!element_type_or.ok()) return element_type_or.status(); - - bool is_bounded_dynamic = false; - int64_t rank = xla_ty.rank(); - llvm::SmallVector shape(rank, mlir::ShapedType::kDynamic); - llvm::SmallVector bounds(rank, mlir::ShapedType::kDynamic); - for (int64_t dim = 0; dim < rank; ++dim) { - int64_t dim_size = xla_ty.dimensions(dim); - if (xla_ty.is_dynamic_dimension(dim)) { - if (!xla_ty.is_unbounded_dynamic_dimension(dim)) { - bounds[dim] = dim_size; - is_bounded_dynamic = true; - } - } else { - shape[dim] = dim_size; - } - } - using mlir::mhlo::TypeExtensionsAttr; - mlir::Attribute encoding; - if (is_bounded_dynamic) { - encoding = TypeExtensionsAttr::get(builder.getContext(), bounds); - } - - using mlir::sparse_tensor::SparseTensorEncodingAttr; - // TODO(b/238903065): We don't yet support bounded dynamism shapes and - // sparsity at the same time, as we can currently only have one `encoding` on - // a RankedTensorType, and we don't currently have a meet of - // SparseTensorEncodingAttr and TypeExtensionsAttr (which holds bounds). - // - // For example, we wouldn't be able to represent the xla type - // `f32[4,<=4]{1,0:D(D,C)}`. - if (xla_ty.has_layout()) { - auto layout = xla_ty.layout(); - if (LayoutUtil::IsSparse(layout)) { - if (is_bounded_dynamic) - return Unimplemented( - "MHLO doesn't support bounded dynamic shapes for sparse tensors"); - llvm::SmallVector lts; - for (size_t i = 0, e = layout.dim_level_types_size(); i < e; ++i) { - auto dlt = layout.dim_level_type(i); - bool ordered = - i < layout.dim_ordered_size() ? layout.dim_ordered(i) : true; - bool unique = - i < layout.dim_unique_size() ? layout.dim_unique(i) : true; - switch (dlt) { - case DimLevelType::DIM_DENSE: - lts.push_back(*mlir::sparse_tensor::buildLevelType( - mlir::sparse_tensor::LevelFormat::Dense, ordered, unique)); - break; - case DimLevelType::DIM_COMPRESSED: - lts.push_back(*mlir::sparse_tensor::buildLevelType( - mlir::sparse_tensor::LevelFormat::Compressed, ordered, unique)); - break; - case DimLevelType::DIM_SINGLETON: - lts.push_back(*mlir::sparse_tensor::buildLevelType( - mlir::sparse_tensor::LevelFormat::Singleton, ordered, unique)); - break; - case DimLevelType::DIM_LOOSE_COMPRESSED: - lts.push_back(*mlir::sparse_tensor::buildLevelType( - mlir::sparse_tensor::LevelFormat::LooseCompressed, ordered, - unique)); - break; - default: - return InvalidArgument("Unknown DimLevelType from HLO"); - } - } - auto ordering = layout.minor_to_major(); - llvm::SmallVector major_to_minor = {ordering.rbegin(), - ordering.rend()}; - auto id_map = mlir::AffineMap::getPermutationMap(major_to_minor, - builder.getContext()); - // TODO(atondwal): support sizes other than 32 when XLA does - encoding = SparseTensorEncodingAttr::get( - builder.getContext(), lts, id_map, mlir::AffineMap(), 32, 32); - } - } - return TypeT::get(shape, element_type_or.value(), encoding); -} - -absl::StatusOr ConvertTensorShapeToMemRefType( - const Shape& shape, mlir::Builder builder); - -template <> -inline absl::StatusOr ConvertTensorShapeToType( - const Shape& shape, mlir::Builder builder) { - if (shape.is_dynamic()) { - return FailedPrecondition( // NOLINT - "MemRefType don't support dynamic shapes"); - } - return ConvertTensorShapeToMemRefType(shape, builder); -} - -// Converts the given XLA shape to the template MLIR type. -template -static absl::StatusOr ConvertShapeToType(const Shape& shape, - mlir::Builder builder) { - if (shape.IsTuple()) { - llvm::SmallVector contents; - contents.reserve(shape.tuple_shapes_size()); - for (const auto& subtype : shape.tuple_shapes()) { - TF_ASSIGN_OR_RETURN(auto mlir_subtype, - ConvertShapeToType(subtype, builder)); - contents.push_back(mlir_subtype); - } - return builder.getTupleType(contents); - } - if (shape.IsToken()) { - return mlir::mhlo::TokenType::get(builder.getContext()); - } - return ConvertTensorShapeToType(shape, builder); -} - -// CreateTupleValue creates a root TupleOp of (nested) tuple-type 'type' using -// the non-tuple-typed values in 'flatten_values'. -// -// e.g., Given 'flatten_values': [V1, V2, V3] &'type': tuple>, -// The function returns %t2 such that: -// %t1 = mhlo.tuple(V2,V3) : (T2,T3) -> tuple -// %t2 = mhlo.tuple(V1,%t1): (T1,tuple) -> tuple> -// -// Note: 1. FlattenTupleValue and CreateTupleValue is a pair of functions to -// resp. flatten and create tuples in the exact same order. -// 2. `flatten_values`, initially storing the flattened values, will be -// mutated to a 0-length array by the end of function invocation. -mlir::Value CreateTupleValue(mlir::OpBuilder* func_builder, mlir::Location loc, - mlir::ValueRange& flatten_values, mlir::Type type); - -// Create a TupleOp using the results of 'op' if 'type' is a mlir::TupleType. -// Otherwise, return 'op'. -mlir::Operation* CreateTupleFromOpResults(mlir::OpBuilder* func_builder, - mlir::Location loc, - mlir::Operation* op, mlir::Type type); - -mlir::TypeRange Untuple(const mlir::Type& type); - -static std::pair GetLayoutAttribute( - mlir::Builder& b, const Shape& shape, - std::optional maybe_layout = std::nullopt) { - if (shape.IsTuple()) { - llvm::SmallVector element_attrs; - llvm::SmallVector tile_attrs; - for (const auto& tuple_shape : shape.tuple_shapes()) { - // TODO here we do not dissect the layout of a tuple into sublayouts. - // Presently ShapeLayout cannot represent an explicit layout for a tuple - // type so this should never occur. However, if this function were to - // be used in another context where this assumption were to be lifted. - // users should be aware of this limitation which will use the default - // layout for tuple subshapes. - std::pair inner = - tuple_shape.has_layout() - ? GetLayoutAttribute(b, tuple_shape, tuple_shape.layout()) - : GetLayoutAttribute(b, tuple_shape); - element_attrs.push_back(inner.first); - tile_attrs.push_back(inner.second); - } - return std::make_pair((mlir::Attribute)b.getArrayAttr(element_attrs), - b.getArrayAttr(tile_attrs)); - } - - Layout layout = maybe_layout.value_or( - shape.has_layout() ? shape.layout() - : LayoutUtil::GetDefaultLayoutForShape(shape)); - - llvm::SmallVector vec_of_tiles; - for (const Tile& tile : layout.tiles()) { - llvm::SmallVector tile_vec = {tile.dimensions().begin(), - tile.dimensions().end()}; - vec_of_tiles.push_back(b.getIndexTensorAttr(tile_vec)); - } - llvm::SmallVector layout_vec = {layout.minor_to_major().begin(), - layout.minor_to_major().end()}; - return std::make_pair(b.getIndexTensorAttr(layout_vec), - b.getArrayAttr(vec_of_tiles)); -} - -static bool HasCustomLayout(const Shape& shape) { - if (shape.IsTuple()) { - return llvm::any_of(shape.tuple_shapes(), HasCustomLayout); - } - return shape.has_layout() && !shape.layout().minor_to_major().empty() && - shape.layout() != LayoutUtil::GetDefaultLayoutForShape(shape); -} - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/translate/hlo_to_mhlo/hlo_utils.h" #endif // XLA_TRANSLATE_HLO_TO_MHLO_HLO_UTILS_H_ diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/translate.h b/third_party/xla/xla/translate/hlo_to_mhlo/translate.h index 2594aa17fc2d21..4ed0dc5c1ba216 100644 --- a/third_party/xla/xla/translate/hlo_to_mhlo/translate.h +++ b/third_party/xla/xla/translate/hlo_to_mhlo/translate.h @@ -16,73 +16,7 @@ limitations under the License. #ifndef XLA_TRANSLATE_HLO_TO_MHLO_TRANSLATE_H_ #define XLA_TRANSLATE_HLO_TO_MHLO_TRANSLATE_H_ -namespace llvm { -class StringRef; -} // namespace llvm - -namespace mlir { -class MLIRContext; -class ModuleOp; -template -class OwningOpRef; -} // namespace mlir - -namespace xla { - -// Converts a HloModuleProto stored in the file with the given `input_filename` -// into a MHLO module. Creates MLIR entities into the given MLIR `context`. -// -// If `import_all_computation` is set to true, imports all computations -// irrespective if transitively called from entry computation. -// -// If `flatten_computation_args_result` is set to true, flattens all tuple -// arguments and result of every computation when importing them as func ops. -mlir::OwningOpRef HloToMlirHloTranslateFunction( - llvm::StringRef input, mlir::MLIRContext* context, - bool import_all_computations = false, - bool flatten_computation_args_result = false); - -// Converts a HloModule stored in text form for a file with the given -// `input_filename` into a MHLO module. Creates MLIR entities into the given -// MLIR `context`. -// -// If `import_all_computation` is set to true, imports all computations -// irrespective if transitively called from entry computation. -// -// If `flatten_computation_args_result` is set to true, flattens all tuple -// arguments and result of every computation when importing them as func ops. -mlir::OwningOpRef HloTextToMlirHloTranslateFunction( - llvm::StringRef input, mlir::MLIRContext* context, - bool import_all_computations = false, - bool flatten_computation_args_result = false); - -// Converts a HloModuleProto stored in the file with the given `input_filename` -// into a StableHLO module. Creates MLIR entities into the given MLIR `context`. -// -// If `import_all_computation` is set to true, imports all computations -// irrespective if transitively called from entry computation. -// -// If `flatten_computation_args_result` is set to true, flattens all tuple -// arguments and result of every computation when importing them as func ops. -mlir::OwningOpRef HloToStablehloTranslateFunction( - llvm::StringRef input, mlir::MLIRContext* context, - bool import_all_computations = false, - bool flatten_computation_args_result = false); - -// Converts a HloModule stored in text form for a file with the given -// `input_filename` into a StableHLO module. Creates MLIR entities into the -// given MLIR `context`. -// -// If `import_all_computation` is set to true, imports all computations -// irrespective if transitively called from entry computation. -// -// If `flatten_computation_args_result` is set to true, flattens all tuple -// arguments and result of every computation when importing them as func ops. -mlir::OwningOpRef HloTextToStablehloTranslateFunction( - llvm::StringRef input, mlir::MLIRContext* context, - bool import_all_computations = false, - bool flatten_computation_args_result = false); - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/translate/hlo_to_mhlo/translate.h" #endif // XLA_TRANSLATE_HLO_TO_MHLO_TRANSLATE_H_ diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/BUILD b/third_party/xla/xla/translate/mhlo_to_hlo/BUILD index b486d47233cfb3..13e2c3afe7441d 100644 --- a/third_party/xla/xla/translate/mhlo_to_hlo/BUILD +++ b/third_party/xla/xla/translate/mhlo_to_hlo/BUILD @@ -1,9 +1,5 @@ -load("@bazel_skylib//rules:build_test.bzl", "build_test") -load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") -load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_binary", "cc_library") -load("//xla:xla.bzl", "xla_cc_test") +load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") load("//xla/tsl:tsl.bzl", "internal_visibility") -load("//xla/tsl:tsl.default.bzl", "get_compatible_with_portable") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -14,282 +10,87 @@ package( licenses = ["notice"], ) +# Deprecated, use //third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo:attribute_exporter +# instead. cc_library( name = "attribute_exporter", - srcs = ["attribute_exporter.cc"], hdrs = ["attribute_exporter.h"], deps = [ - "//xla:shape_util", - "//xla:types", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/mlir_hlo", - "//xla/service:hlo_parser", - "//xla/service:hlo_proto_cc", - "//xla/stream_executor:dnn", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/types:span", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Support", - "@stablehlo//:base", + "//xla/hlo/translate/mhlo_to_hlo:attribute_exporter", ], ) +# Deprecated, use //third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo:layout_util +# instead. cc_library( name = "layout_util", - srcs = ["layout_util.cc"], hdrs = ["layout_util.h"], deps = [ - "//xla:shape_util", - "//xla:xla_data_proto_cc", - "//xla/client:xla_builder", - "//xla/hlo/ir:hlo", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", + "//xla/hlo/translate/mhlo_to_hlo:layout_util", ], ) +# Deprecated, use //third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo:location_exporter +# instead. cc_library( name = "location_exporter", - srcs = ["location_exporter.cc"], hdrs = ["location_exporter.h"], deps = [ - ":stack_frame_index_builder", - "//xla:xla_data_proto_cc", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Support", + "//xla/hlo/translate/mhlo_to_hlo:location_exporter", ], ) -cc_library( +# Deprecated, use //third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo:module_attributes_exporter +# instead. +alias( name = "module_attributes_exporter", - srcs = ["module_attributes_exporter.cc"], - hdrs = ["module_attributes_exporter.h"], - deps = [ - "//xla:xla_data_proto_cc", - "//xla/service:hlo_module_config", - "//xla/service:hlo_proto_cc", - "@com_google_absl//absl/status", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Support", - ], + actual = "//xla/hlo/translate/mhlo_to_hlo:module_attributes_exporter", ) -cc_library( +# Deprecated, use //third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo:stack_frame_index_builder +# instead. +alias( name = "stack_frame_index_builder", - srcs = ["stack_frame_index_builder.cc"], - hdrs = ["stack_frame_index_builder.h"], - deps = [ - "//xla/service:hlo_proto_cc", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Support", - ], + actual = "//xla/hlo/translate/mhlo_to_hlo:stack_frame_index_builder", ) +# Deprecated, use //third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo:mlir_hlo_to_hlo +# instead. cc_library( name = "mlir_hlo_to_hlo", - srcs = [ - "mlir_hlo_to_hlo.cc", - "operator_writers.inc", - ], hdrs = ["mlir_hlo_to_hlo.h"], deps = [ - ":attribute_exporter", - ":layout_util", - ":location_exporter", - ":module_attributes_exporter", - ":operator_writer_inc", - ":stack_frame_index_builder", - ":type_to_shape", - "//xla:array", - "//xla:comparison_util", - "//xla:debug_options_flags", - "//xla:literal", - "//xla:literal_util", - "//xla:shape_util", - "//xla:status_macros", - "//xla:types", - "//xla:xla_data_proto_cc", - "//xla/client:xla_builder", - "//xla/client:xla_computation", - "//xla/client/lib:approx_topk", - "//xla/client/lib:approx_topk_shape", - "//xla/client/lib:matrix", - "//xla/client/lib:quantize", - "//xla/client/lib:slicing", - "//xla/hlo/ir:hlo", - "//xla/mlir/utils:error_util", - "//xla/mlir/utils:type_util", - "//xla/mlir_hlo", - "//xla/mlir_hlo:mhlo_passes", - "//xla/service:computation_layout", - "//xla/service:hlo_module_config", - "//xla/service:hlo_parser", - "//xla/service:hlo_proto_cc", - "//xla/service/gpu:backend_configs_cc", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:ArithDialect", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:MemRefDialect", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:ShapeDialect", - "@llvm-project//mlir:Support", - "@llvm-project//mlir:TensorDialect", - "@llvm-project//mlir:TransformUtils", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:ml_dtypes", - "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/platform:types", - "@stablehlo//:base", - "@stablehlo//:stablehlo_ops", - ], -) - -build_test( - name = "operator_writer_gen_build_test", - targets = [ - ":operator_writer_gen", - ], -) - -cc_binary( - name = "operator_writer_gen", - srcs = ["operator_writer_gen.cc"], - deps = [ - "@llvm-project//llvm:Support", - "@llvm-project//llvm:TableGen", - "@llvm-project//mlir:TableGen", - ], -) - -gentbl_cc_library( - name = "operator_writer_inc", - compatible_with = get_compatible_with_portable(), - tbl_outs = [([], "operator_writers.inc")], - tblgen = ":operator_writer_gen", - td_file = "//xla/mlir_hlo:mhlo/IR/hlo_ops.td", - deps = [ - "//xla/mlir_hlo:hlo_ops_td_files", - "@llvm-project//mlir:InferTypeOpInterfaceTdFiles", - "@llvm-project//mlir:OpBaseTdFiles", - "@llvm-project//mlir:SideEffectInterfacesTdFiles", - ], -) - -xla_cc_test( - name = "mlir_hlo_to_hlo_test", - srcs = ["mlir_hlo_to_hlo_test.cc"], - deps = [ - ":mlir_hlo_to_hlo", - "//xla/mlir/utils:error_util", - "//xla/tsl/lib/core:status_test_util", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Parser", - "@llvm-project//mlir:ShapeDialect", - "@local_tsl//tsl/platform:status_matchers", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_main", - "@stablehlo//:register", + "//xla/hlo/translate/mhlo_to_hlo:mlir_hlo_to_hlo", ], ) +# Deprecated, use //third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo:translate +# instead. cc_library( name = "translate", - srcs = ["translate.cc"], hdrs = ["translate.h"], deps = [ - ":mlir_hlo_to_hlo", - ":type_to_shape", - "//xla:debug_options_flags", - "//xla:shape_util", - "//xla/client:xla_builder", - "//xla/client:xla_computation", - "//xla/hlo/ir:hlo", - "//xla/mlir_hlo:hlo_dialect_registration", - "//xla/service:hlo_module_config", - "//xla/service:hlo_proto_cc", - "//xla/service:hlo_proto_util", - "@com_google_absl//absl/log", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Parser", - "@llvm-project//mlir:Support", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", + "//xla/hlo/translate/mhlo_to_hlo:translate", ], ) +# Deprecated, use //third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo:translate_registration +# instead. cc_library( name = "translate_registration", testonly = True, - srcs = [ - "translate_registration.cc", - "translate_registration.h", - ], deps = [ - ":translate", - "//xla/mlir_hlo:hlo_dialect_registration", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:ArithDialect", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Support", - "@llvm-project//mlir:TensorDialect", - "@llvm-project//mlir:TranslateLib", + "//xla/hlo/translate/mhlo_to_hlo:translate_registration", ], alwayslink = 1, ) +# Deprecated, use //third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo:type_to_shape +# instead. cc_library( name = "type_to_shape", - srcs = ["type_to_shape.cc"], hdrs = ["type_to_shape.h"], deps = [ - "//xla:shape_util", - "//xla:xla_data_proto_cc", - "//xla/mlir/utils:type_util", - "//xla/mlir_hlo", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:SparseTensorDialect", - "@llvm-project//mlir:SparseTensorEnums", - "@llvm-project//mlir:Support", - "@stablehlo//:stablehlo_ops", - ], -) - -xla_cc_test( - name = "type_to_shape_test", - srcs = ["type_to_shape_test.cc"], - deps = [ - ":type_to_shape", - "//xla:shape_util", - "//xla:test", - "//xla:xla_data_proto_cc", - "//xla/mlir_hlo", - "//xla/translate/hlo_to_mhlo:hlo_utils", - "@com_google_absl//absl/status:statusor", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:IR", - "@local_tsl//tsl/platform:protobuf", - "@local_tsl//tsl/platform:test_main", + "//xla/hlo/translate/mhlo_to_hlo:type_to_shape", ], ) diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/attribute_exporter.h b/third_party/xla/xla/translate/mhlo_to_hlo/attribute_exporter.h index 97c20f112af1ba..2caf77bf3a3d2a 100644 --- a/third_party/xla/xla/translate/mhlo_to_hlo/attribute_exporter.h +++ b/third_party/xla/xla/translate/mhlo_to_hlo/attribute_exporter.h @@ -16,60 +16,7 @@ limitations under the License. #ifndef XLA_TRANSLATE_MHLO_TO_HLO_ATTRIBUTE_EXPORTER_H_ #define XLA_TRANSLATE_MHLO_TO_HLO_ATTRIBUTE_EXPORTER_H_ -#include +// The current header will be deprecated in favour of the following. +#include "xla/hlo/translate/mhlo_to_hlo/attribute_exporter.h" -#include "absl/status/statusor.h" -#include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/StringRef.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/BuiltinAttributes.h" -#include "mlir/Support/LLVM.h" -#include "xla/hlo/ir/hlo_input_output_alias_config.h" -#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" -#include "xla/service/hlo.pb.h" -#include "xla/shape_util.h" -#include "xla/stream_executor/dnn.h" -#include "xla/types.h" -#include "xla/xla_data.pb.h" - -namespace xla { - -// Converts the conv dimensions attribute to XLA HLO. -ConvolutionDimensionNumbers ConvertConvDimensionNumbers( - mlir::mhlo::ConvDimensionNumbersAttr input); - -// Converts the dot algorithm attribute to XLA HLO. -absl::StatusOr ConvertDotAlgorithm( - mlir::mhlo::DotAlgorithmAttr attr); - -absl::StatusOr> ConvertReplicaGroups( - mlir::DenseIntElementsAttr input); - -// Convert a (N, 2) dense attribute to a list of tuples. This is the way padding -// and source-target pairs are defined in HLO. -absl::StatusOr>> ConvertNx2Attribute( - std::optional optional_attr); - -absl::StatusOr ConvertTranspose( - llvm::StringRef transpose_string); - -absl::StatusOr ConvertCustomCallSchedule( - mlir::mhlo::CustomCallSchedule schedule); - -absl::StatusOr ConvertCustomCallApiVersion( - mlir::mhlo::CustomCallApiVersion api_version); - -absl::StatusOr< - std::vector>>> -ConvertOutputOperandAliasing(mlir::ArrayAttr aliasArrayAttr); - -// Returns an OpSharding that represents the result of parsing the given string: -// first, as serialized protobuf, and then as prettyprinted representation. -// Will fail if both attempts at parsing failed. -std::optional ConvertSharding(mlir::StringRef sharding); - -std::optional ConvertInputOutputAlias( - llvm::ArrayRef aliasing); - -} // namespace xla #endif // XLA_TRANSLATE_MHLO_TO_HLO_ATTRIBUTE_EXPORTER_H_ diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/layout_util.h b/third_party/xla/xla/translate/mhlo_to_hlo/layout_util.h index 2ecd4e3ef3ba3d..6005d23d69e910 100644 --- a/third_party/xla/xla/translate/mhlo_to_hlo/layout_util.h +++ b/third_party/xla/xla/translate/mhlo_to_hlo/layout_util.h @@ -18,68 +18,7 @@ limitations under the License. #ifndef XLA_TRANSLATE_MHLO_TO_HLO_LAYOUT_UTIL_H_ #define XLA_TRANSLATE_MHLO_TO_HLO_LAYOUT_UTIL_H_ -#include -#include - -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "xla/client/xla_builder.h" -#include "xla/hlo/ir/hlo_sharding.h" -#include "xla/shape.h" -#include "xla/xla_data.pb.h" - -namespace mlir { - -// XLA Layout preferences. Currently, when it comes to TPU, there are two -// primary layout choices for any XLA arguments (parameter or resource): (1) -// CompactChunkPadded and (2) Linear. CompactChunkPadded is the native TPU -// layout while Linear is native host (CPU) layout. -// This enum allows the caller of XLA to propagate layout preference to the XLA -// compiler. -// kNoPreference: the generic layout where the XLA compiler has the freedom -// to assign any layout. -// kTpuPreferCompactChunkPaddedLayout: use native TPU layout on TPU. -// kTpuPreferLinearLayout: use native CPU layout on TPU. The compiler may -// insert transformation TPU kernels. -// As the layout of any argument will change from a native host layout to a -// native TPU layout either on host or on device, XLA compiler and TPU runtime -// must be in coordination to transform the parameters in a consistent way. -enum class XlaLayoutPreference { - kNoPreference = 0, - kTpuPreferCompactChunkPaddedLayout = 1, - kTpuPreferLinearLayout = 2 -}; - -// The following defines the layout preference of an xla tensor. -// The return value of LayoutPreferenceFn can be used in -// ShapeRepresentationFn. -typedef std::function( - const xla::Shape& shape)> - LayoutPreferenceFn; - -typedef std::function( - const xla::Shape& shape, bool fast_mem, - XlaLayoutPreference layout_preference)> - ShapeRepresentationFn; - -// Return a LayoutPreferenceFn that always uses kNoPreference layout. -LayoutPreferenceFn UseNoPreferenceLayoutFn(); - -// Rewrites the layout of xla_shape if there is tiled sharding. -absl::Status RewriteLayoutWithShardedShape( - const std::optional& sharding, bool use_fast_memory, - const LayoutPreferenceFn& layout_preference_fn, - const ShapeRepresentationFn& shape_representation_fn, - xla::Shape* xla_shape); - -// Adds reshapes to fix the layout of an output, if a shape_representation_fn or -// sharding is present. -absl::StatusOr ReshapeWithCorrectRepresentationAndSharding( - xla::XlaBuilder* builder, xla::XlaOp original, xla::Shape original_shape, - const LayoutPreferenceFn& layout_preference_fn, - const ShapeRepresentationFn& shape_representation_fn, - std::optional sharding, bool fast_mem); - -} // namespace mlir +// The current header will be deprecated in favour of the following. +#include "xla/hlo/translate/mhlo_to_hlo/layout_util.h" #endif // XLA_TRANSLATE_MHLO_TO_HLO_LAYOUT_UTIL_H_ diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/location_exporter.h b/third_party/xla/xla/translate/mhlo_to_hlo/location_exporter.h index d7ec94c1622918..b5c43ce49c481a 100644 --- a/third_party/xla/xla/translate/mhlo_to_hlo/location_exporter.h +++ b/third_party/xla/xla/translate/mhlo_to_hlo/location_exporter.h @@ -16,29 +16,7 @@ limitations under the License. #ifndef XLA_TRANSLATE_MHLO_TO_HLO_LOCATION_EXPORTER_H_ #define XLA_TRANSLATE_MHLO_TO_HLO_LOCATION_EXPORTER_H_ -#include - -#include "mlir/IR/Location.h" -#include "mlir/IR/Operation.h" -#include "xla/translate/mhlo_to_hlo/stack_frame_index_builder.h" -#include "xla/xla_data.pb.h" - -namespace mlir { -namespace mhlo { - -// Returns a OpMetadata proto based on the location of the op. If the location -// is unknown, an empty proto is returned. `op_name` are populated with the op -// location (converted). FileLineColLoc locations are populated by taking the -// file name and line number, and populating `source_file` and `source_line` -// respectively. -xla::OpMetadata CreateOpMetadataFromLocation( - Operation* op, StackFrameIndexBuilder* frame_index_builder); - -// Returns a name that can be used for debugging purposes, e.g., naming -// variable names in generated IR or producing logging output. -std::string GetDebugNameFromLocation(Location location); - -} // namespace mhlo -} // namespace mlir +// The current header will be deprecated in favour of the following. +#include "xla/hlo/translate/mhlo_to_hlo/location_exporter.h" #endif // XLA_TRANSLATE_MHLO_TO_HLO_LOCATION_EXPORTER_H_ diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h b/third_party/xla/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h index d55bfc15cf653d..1544b99e069571 100644 --- a/third_party/xla/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h +++ b/third_party/xla/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h @@ -16,81 +16,7 @@ limitations under the License. #ifndef XLA_TRANSLATE_MHLO_TO_HLO_MLIR_HLO_TO_HLO_H_ #define XLA_TRANSLATE_MHLO_TO_HLO_MLIR_HLO_TO_HLO_H_ -#include -#include - -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "llvm/ADT/ArrayRef.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/IR/Block.h" -#include "mlir/IR/BuiltinOps.h" -#include "xla/client/xla_builder.h" -#include "xla/hlo/ir/hlo_module.h" -#include "xla/service/hlo.pb.h" -#include "xla/service/hlo_module_config.h" -#include "xla/translate/mhlo_to_hlo/layout_util.h" - -namespace mlir { - -struct MlirToHloConversionOptions { - // Best-effort propagation of the layouts. These layouts serve as performance - // hints to the backend. - // - // Note that non-array shapes are not carrying layouts, and users have to - // figure out the proper layouts of them through context. This is one of the - // reasons why the attribute-based solution is temporary. - // - // TODO(timshen): Investigate the necessity of having layouts in MHLO. - bool propagate_layouts = false; - - // Propagate the source and result layouts from mhlo bitcast op into the - // backend config for the bitcast. This is required for XLA:GPU backend to - // use elemental IR emitters for fused bitcasts without propagating layouts. - bool propagate_bitcast_layouts_to_backend_config = false; - - LayoutPreferenceFn layout_preference_fn; - ShapeRepresentationFn shape_representation_fn; - - // If use_tuple_args is set, then the entry computations's arguments are - // converted to a tuple and passed as a single parameter. - bool use_tuple_args = false; - - // If return tuple is true, then the entry function's return values - // are converted to a tuple even when there is only a single return value. - // Multiple return values are always converted to a tuple and returned as a - // single value. - bool return_tuple = true; -}; - -// Prefer `ConvertMlirHloToHloModule` over this method when possible, as it -// preserves more information and abstracts away the proto. This method is -// preserved for legacy reasons. -// TODO (b/345806521): Migrate callsites to ConvertMlirHloToHloModule, -// and delete this method. -// -// Converts a MLIR module in HLO dialect into a HloModuleProto. -// -absl::Status ConvertMlirHloToHlo(mlir::ModuleOp module, - ::xla::HloProto* hlo_proto, - bool use_tuple_args, bool return_tuple, - MlirToHloConversionOptions options = {}); - -// Converts a MLIR module in HLO dialect into a HloModule with HloModuleConfig. -// This method preserves config data stored in MHLO module attributes. -// -// See `MlirToHloConversionOptions` for details on conversion flags. -absl::StatusOr> ConvertMlirHloToHloModule( - mlir::ModuleOp module, MlirToHloConversionOptions options = {}); - -// Transforms a Block into HLO, where the HLO is represented as calls into an -// XlaBuilder. Callee functions are allowed in the Block's ancestor ModuleOp. -// xla_params are inputs to block. returns are the returned XlaOps. -absl::Status BuildHloFromMlirHlo(mlir::Block& block, xla::XlaBuilder& builder, - llvm::ArrayRef xla_params, - std::vector& returns, - MlirToHloConversionOptions options = {}); - -} // namespace mlir +// The current header will be deprecated in favour of the following. +#include "xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h" #endif // XLA_TRANSLATE_MHLO_TO_HLO_MLIR_HLO_TO_HLO_H_ diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/translate.h b/third_party/xla/xla/translate/mhlo_to_hlo/translate.h index b65e7d9c78cdaa..373eaca3fca4f3 100644 --- a/third_party/xla/xla/translate/mhlo_to_hlo/translate.h +++ b/third_party/xla/xla/translate/mhlo_to_hlo/translate.h @@ -16,35 +16,7 @@ limitations under the License. #ifndef XLA_TRANSLATE_MHLO_TO_HLO_TRANSLATE_H_ #define XLA_TRANSLATE_MHLO_TO_HLO_TRANSLATE_H_ -#include -#include - -#include "llvm/Support/MemoryBuffer.h" -#include "llvm/Support/raw_os_ostream.h" -#include "llvm/Support/raw_ostream.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/Support/LogicalResult.h" - -namespace xla { - -mlir::LogicalResult MlirHloToHloTranslateFunction(mlir::ModuleOp module, - llvm::raw_ostream& output, - bool emit_return_tuple, - bool emit_use_tuple_arg); - -mlir::LogicalResult MlirHloToHloTextTranslateFunction( - mlir::ModuleOp module, llvm::raw_ostream& output, bool emit_return_tuple, - bool emit_use_tuple_arg, bool print_layouts, bool print_large_constants, - bool print_sugar, bool via_builder, bool with_layouts); - -// Translate the MHLO program in in-memory file 'buffer' to a HLO program -// written in a file represented with handle 'output_stream'; -mlir::LogicalResult MlirHloToHloTextMain( - std::unique_ptr buffer, - llvm::raw_ostream& output_stream, bool emit_return_tuple, - bool emit_use_tuple_arg, bool print_layouts, bool print_large_constants, - bool print_sugar, bool via_builder, bool with_layouts); - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/translate/mhlo_to_hlo/translate.h" #endif // XLA_TRANSLATE_MHLO_TO_HLO_TRANSLATE_H_ diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/type_to_shape.h b/third_party/xla/xla/translate/mhlo_to_hlo/type_to_shape.h index 6e75fc2b75df40..2e99276efe7c5b 100644 --- a/third_party/xla/xla/translate/mhlo_to_hlo/type_to_shape.h +++ b/third_party/xla/xla/translate/mhlo_to_hlo/type_to_shape.h @@ -16,16 +16,7 @@ limitations under the License. #ifndef XLA_TRANSLATE_MHLO_TO_HLO_TYPE_TO_SHAPE_H_ #define XLA_TRANSLATE_MHLO_TO_HLO_TYPE_TO_SHAPE_H_ -#include "llvm/ADT/STLExtras.h" -#include "mlir/IR/Types.h" -#include "xla/shape.h" -#include "xla/xla_data.pb.h" - -namespace xla { - -// Returns a XLA Shape equivalent of a MLIR Type, else returns empty shape. -Shape TypeToShape(mlir::Type type); - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/translate/mhlo_to_hlo/type_to_shape.h" #endif // XLA_TRANSLATE_MHLO_TO_HLO_TYPE_TO_SHAPE_H_ diff --git a/third_party/xla/xla/translate/stablehlo_to_hlo/BUILD b/third_party/xla/xla/translate/stablehlo_to_hlo/BUILD index fb6c1a7d961b9c..c432091352b87c 100644 --- a/third_party/xla/xla/translate/stablehlo_to_hlo/BUILD +++ b/third_party/xla/xla/translate/stablehlo_to_hlo/BUILD @@ -10,37 +10,12 @@ package( licenses = ["notice"], ) +# Deprecated, use //third_party/tensorflow/compiler/xla/hlo/translate/stablehlo_to_hlo:translate # +# instead. cc_library( name = "translate", - srcs = ["translate.cc"], hdrs = ["translate.h"], deps = [ - "//xla/mlir_hlo:mhlo_passes", - "//xla/translate/mhlo_to_hlo:translate", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Parser", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:Support", - "@stablehlo//:register", + "//xla/hlo/translate/stablehlo_to_hlo:translate", ], ) - -cc_library( - name = "translate_registration", - testonly = True, - srcs = ["translate_registration.cc"], - deps = [ - ":translate", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:ArithDialect", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Support", - "@llvm-project//mlir:TensorDialect", - "@llvm-project//mlir:TranslateLib", - "@stablehlo//:register", - ], - alwayslink = 1, -) diff --git a/third_party/xla/xla/translate/stablehlo_to_hlo/translate.h b/third_party/xla/xla/translate/stablehlo_to_hlo/translate.h index d21dcbdacfcf68..badaeeaa9acb30 100644 --- a/third_party/xla/xla/translate/stablehlo_to_hlo/translate.h +++ b/third_party/xla/xla/translate/stablehlo_to_hlo/translate.h @@ -16,35 +16,7 @@ limitations under the License. #ifndef XLA_TRANSLATE_STABLEHLO_TO_HLO_TRANSLATE_H_ #define XLA_TRANSLATE_STABLEHLO_TO_HLO_TRANSLATE_H_ -#include -#include - -#include "llvm/Support/MemoryBuffer.h" -#include "llvm/Support/raw_os_ostream.h" -#include "llvm/Support/raw_ostream.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/Support/LogicalResult.h" - -namespace xla { - -mlir::LogicalResult StablehloToHloTranslateFunction(mlir::ModuleOp module, - llvm::raw_ostream& output, - bool emit_return_tuple, - bool emit_use_tuple_arg); - -mlir::LogicalResult StablehloToHloTextTranslateFunction( - mlir::ModuleOp module, llvm::raw_ostream& output, bool emit_return_tuple, - bool emit_use_tuple_arg, bool print_layouts, bool print_large_constants, - bool print_sugar, bool via_builder, bool with_layouts); - -// Translate the StableHLO program in in-memory file 'buffer' to a HLO program -// written in a file represented with handle 'output_stream'; -mlir::LogicalResult StablehloToHloTextMain( - std::unique_ptr buffer, - llvm::raw_ostream& output_stream, bool emit_return_tuple, - bool emit_use_tuple_arg, bool print_layouts, bool print_large_constants, - bool print_sugar, bool via_builder, bool with_layouts); - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/translate/stablehlo_to_hlo/translate.h" #endif // XLA_TRANSLATE_STABLEHLO_TO_HLO_TRANSLATE_H_ From 6ed0d11d0c06cdaec096f6c4d5aa991f66cb6a41 Mon Sep 17 00:00:00 2001 From: Seher Ellis Date: Thu, 19 Sep 2024 15:50:26 -0700 Subject: [PATCH 031/483] [XLA:CollectivePipeliner] Avoid redundant broadcasts in the formatting ops of sunk collectives. Before this CL, the same broadcast could be added multiple times - to the formatting ops of a single sunk collective, and - to the modified HLO computation if the same broadcast appears in the formatting ops of different sunk collectives. PiperOrigin-RevId: 676593920 --- .../xla/xla/service/collective_pipeliner.cc | 15 ++-- .../xla/service/collective_pipeliner_test.cc | 83 +++++++++++++++++++ 2 files changed, 93 insertions(+), 5 deletions(-) diff --git a/third_party/xla/xla/service/collective_pipeliner.cc b/third_party/xla/xla/service/collective_pipeliner.cc index e32749ceb7f2f8..8b04e10f492868 100644 --- a/third_party/xla/xla/service/collective_pipeliner.cc +++ b/third_party/xla/xla/service/collective_pipeliner.cc @@ -275,8 +275,8 @@ bool CollectSimpleDependencies(HloInstruction* i, for (HloInstruction* op : i->mutable_operands()) { absl::InlinedVector to_add; if (op->opcode() == HloOpcode::kBroadcast) { - to_add.push_back(op); if (deps_set.insert(op).second) { + to_add.push_back(op); op = op->mutable_operand(0); if (op->opcode() == HloOpcode::kConstant) { if (deps_set.insert(op).second) { @@ -318,6 +318,7 @@ CheckStoreIntoSliceIsCompatible(HloInstruction* instr, absl::flat_hash_set added_instructions; HloInstruction* folded_instr = instr; std::vector formatting_ops; + absl::flat_hash_set formatting_set; // Returns if this is an acceptable user of a pipelined instruction. // Generic elementwise ops can have multiple operands that require the inputs // of being saved across the loop. So protect them through @@ -411,11 +412,12 @@ CheckStoreIntoSliceIsCompatible(HloInstruction* instr, auto& data = stack.back(); HloInstruction* instr = data.first; if (data.second == 0 && instr != folded_instr) { - if (!CollectSimpleDependencies(instr, formatting_ops, - added_instructions)) { + if (!CollectSimpleDependencies(instr, formatting_ops, formatting_set)) { return empty_pair; } - formatting_ops.push_back(instr); + if (formatting_set.insert(instr).second) { + formatting_ops.push_back(instr); + } } if (data.second == instr->user_count()) { stack.pop_back(); @@ -2330,9 +2332,9 @@ absl::Status TransformLoopForwardSink(const WhileLoopAnalysis& loop_analysis, // Create the new tuple with the original while tuple size. std::vector new_output_tuple; new_output_tuple.resize(operands_indices_count, nullptr); + InstructionMap pipelined_map; // Reproduce computation to the output after the loop on the full shape. for (auto& to_move : loop_analysis.GetMoveInfos()) { - InstructionMap pipelined_map; for (int64_t i = 0; i < to_move.collectives_to_move.size(); ++i) { HloInstruction* collective = to_move.collectives_to_move[i]; int64_t gte_index = collective_to_new_tuple_index[collective]; @@ -2419,6 +2421,9 @@ absl::Status TransformLoopForwardSink(const WhileLoopAnalysis& loop_analysis, // an effect on the instruction itself (like say broadcast, slices ... // etc). for (HloInstruction* formatting_op : to_move.formatting_ops) { + if (pipelined_map.contains(formatting_op)) { + continue; + } if (!to_add_batch_set.contains(formatting_op) && formatting_op->opcode() != HloOpcode::kBroadcast) { HloInstruction* cloned_not_to_batch = loop_computation->AddInstruction( diff --git a/third_party/xla/xla/service/collective_pipeliner_test.cc b/third_party/xla/xla/service/collective_pipeliner_test.cc index 73efde7ce53096..6c48208f65d8fd 100644 --- a/third_party/xla/xla/service/collective_pipeliner_test.cc +++ b/third_party/xla/xla/service/collective_pipeliner_test.cc @@ -3766,5 +3766,88 @@ ENTRY entry { while_instr->while_body()->root_instruction()->operand(8))); } +TEST_F(CollectivePipelinerTest, NoRedundantBroadcastsInFormattingOps) { + constexpr absl::string_view hlo_string = R"( +HloModule module + +add { + lhs = bf16[] parameter(0) + rhs = bf16[] parameter(1) + ROOT add = bf16[] add(lhs, rhs) +} + +add.1 { + lhs = bf16[] parameter(0) + rhs = bf16[] parameter(1) + ROOT add = bf16[] add(lhs, rhs) +} + +while_cond { + param = (s32[], bf16[3,8,128], bf16[3,8,128], bf16[3,8,128]) parameter(0) + gte = s32[] get-tuple-element(param), index=0 + constant.1 = s32[] constant(3) + ROOT cmp = pred[] compare(gte, constant.1), direction=LT +} + +while_body { + param = (s32[], bf16[3,8,128], bf16[3,8,128], bf16[3,8,128]) parameter(0) + get-tuple-element.394 = s32[] get-tuple-element(param), index=0 + get-tuple-element.395 = bf16[3,8,128] get-tuple-element(param), index=1 + get-tuple-element.396 = bf16[3,8,128] get-tuple-element(param), index=2 + get-tuple-element.35 = bf16[3,8,128] get-tuple-element(param), index=3 + constant.2557 = s32[] constant(1) + add.230 = s32[] add(get-tuple-element.394, constant.2557) + constant.2559 = s32[] constant(3) + subtract.139 = s32[] subtract(constant.2559, get-tuple-element.394) + constant.2560 = s32[] constant(-1) + add.231 = s32[] add(subtract.139, constant.2560) + constant.2561 = s32[] constant(0) + compare.747 = pred[] compare(add.231, constant.2561), direction=LT + constant.2562 = s32[] constant(2) + add.232 = s32[] add(subtract.139, constant.2562) + select.1348 = s32[] select(compare.747, add.232, add.231) + dynamic-slice.99 = bf16[1,8,128] dynamic-slice(get-tuple-element.35, select.1348, constant.2561, constant.2561), dynamic_slice_sizes={1,8,128} + mul = bf16[1,8,128] multiply(dynamic-slice.99, dynamic-slice.99) + ar.1 = bf16[1,8,128] all-reduce(mul), replica_groups={}, to_apply=add, channel_id=1 + convert = bf16[] convert(add.232) + broadcast = bf16[1,8,128] broadcast(convert) + add.1 = bf16[1,8,128] add(ar.1, broadcast) + dynamic-update-slice.35 = bf16[3,8,128] dynamic-update-slice(get-tuple-element.395, add.1, select.1348, constant.2561, constant.2561) + ar.2 = bf16[1,8,128] all-reduce(mul), replica_groups={}, to_apply=add.1, channel_id=2 + add.2 = bf16[1,8,128] add(ar.2, broadcast) + dynamic-update-slice.36 = bf16[3,8,128] dynamic-update-slice(get-tuple-element.396, add.2, select.1348, constant.2561, constant.2561) + ROOT tuple = (s32[], bf16[3,8,128], bf16[3,8,128], bf16[3,8,128]) tuple(add.230, dynamic-update-slice.35, dynamic-update-slice.36, get-tuple-element.35) +} + +ENTRY entry { + c0 = s32[] constant(0) + p0 = bf16[3,8,128] parameter(0) + tuple = (s32[], bf16[3,8,128], bf16[3,8,128], bf16[3,8,128]) tuple(c0, p0, p0, p0) + while = (s32[], bf16[3,8,128], bf16[3,8,128], bf16[3,8,128]) while(tuple), condition=while_cond, body=while_body + ROOT gte1 = bf16[3,8,128] get-tuple-element(while), index=1 +} +)"; + auto module = ParseAndReturnUnverifiedModule(hlo_string, config_).value(); + EXPECT_TRUE(RunOptimizer(module.get(), /*last_run=*/true, + /*level_to_operate_on=*/0, + /*pipeline_use_tree=*/true, + /*process_different_sized_ops=*/true, + CollectivePipeliner::kForwardSink) + .value()); + XLA_VLOG_LINES(1, module->ToString()); + // There should be only one broadcast instruction using a get-tuple-element + // from the while instruction. + EXPECT_EQ(absl::c_count_if(module->entry_computation()->instructions(), + [](const HloInstruction* instr) { + return instr->opcode() == + HloOpcode::kBroadcast && + instr->operand(0)->opcode() == + HloOpcode::kGetTupleElement && + instr->operand(0)->operand(0)->opcode() == + HloOpcode::kWhile; + }), + 1); +} + } // namespace } // namespace xla From bb06950a3e2f308367fe30ed6c5b29504a4c9fec Mon Sep 17 00:00:00 2001 From: Subhankar Shah Date: Thu, 19 Sep 2024 16:37:17 -0700 Subject: [PATCH 032/483] [XLA:TPU] Use heap simulator in memory bound loop optimizer to account for memory fragmentation. - Tests were updated in previous cl/666051782, so that they would break with this change. - Update tests and add tests for the change. PiperOrigin-RevId: 676609440 --- .../xla/service/memory_space_assignment/BUILD | 1 + .../memory_space_assignment/algorithm.cc | 11 +- .../memory_bound_loop_optimizer.cc | 440 ++++++++++++------ .../memory_bound_loop_optimizer.h | 78 ++-- .../memory_bound_loop_optimizer_test.cc | 91 +++- 5 files changed, 427 insertions(+), 194 deletions(-) diff --git a/third_party/xla/xla/service/memory_space_assignment/BUILD b/third_party/xla/xla/service/memory_space_assignment/BUILD index f3f989d083f8ea..a9505ecadc02e2 100644 --- a/third_party/xla/xla/service/memory_space_assignment/BUILD +++ b/third_party/xla/xla/service/memory_space_assignment/BUILD @@ -437,6 +437,7 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:errors", ], diff --git a/third_party/xla/xla/service/memory_space_assignment/algorithm.cc b/third_party/xla/xla/service/memory_space_assignment/algorithm.cc index 67371364f1cd0d..530c17775b064a 100644 --- a/third_party/xla/xla/service/memory_space_assignment/algorithm.cc +++ b/third_party/xla/xla/service/memory_space_assignment/algorithm.cc @@ -954,13 +954,10 @@ absl::Status MsaAlgorithm::OptimizeMemoryBoundLoop(int loop_start_idx, const int iteration_start_idx = loop_start_idx + loop_size; const int iteration_end_idx = iteration_start_idx + loop_size; - TF_ASSIGN_OR_RETURN( - std::unique_ptr optimizer, - MemoryBoundLoopOptimizer::Create( - iteration_start_idx, iteration_end_idx, options_.max_size_in_bytes, - options_.memory_bound_loop_optimizer_options, hlo_live_range_, - alias_analysis_, *options_.cost_analysis, options_.size_fn, - options_.reserved_scoped_memory_fn)); + TF_ASSIGN_OR_RETURN(std::unique_ptr optimizer, + MemoryBoundLoopOptimizer::Create( + iteration_start_idx, iteration_end_idx, + hlo_live_range_, alias_analysis_, options_)); optimizer->Optimize(); const int loop_optimized_allocations_original_size = diff --git a/third_party/xla/xla/service/memory_space_assignment/memory_bound_loop_optimizer.cc b/third_party/xla/xla/service/memory_space_assignment/memory_bound_loop_optimizer.cc index 9c4b1a2e8bd39b..0cfbc0db37f178 100644 --- a/third_party/xla/xla/service/memory_space_assignment/memory_bound_loop_optimizer.cc +++ b/third_party/xla/xla/service/memory_space_assignment/memory_bound_loop_optimizer.cc @@ -36,6 +36,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "absl/types/span.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -62,6 +63,21 @@ namespace xla { namespace memory_space_assignment { namespace { +struct LoopOptimizerChunkInterval { + int64_t begin_idx_in_loop; + int64_t end_idx_in_loop; + EvenOddChunkPair chunks; + + std::string ToString() const { + CHECK(chunks.first.has_value() && chunks.second.has_value()); + return absl::StrFormat( + "begin_idx_in_loop: %d, end_idx_in_loop: %d, even chunk: %s, odd " + "chunk: %s", + begin_idx_in_loop, end_idx_in_loop, chunks.first->ToString(), + chunks.second->ToString()); + } +}; + std::optional GetInstructionIndex( const HloInstruction* instruction, const absl::flat_hash_map& @@ -137,7 +153,7 @@ void LoopOptimizerBestFitHeap::RemoveEvenOddChunkPair( EvenOddChunkPair& chunks) { CheckAllocationIntervalValid(begin_idx_in_loop, end_idx_in_loop); ShiftAllocationIntervalIfRequired(begin_idx_in_loop, end_idx_in_loop); - auto [even_chunk, odd_chunk] = chunks; + auto& [even_chunk, odd_chunk] = chunks; RemoveEvenChunks(begin_idx_in_loop, end_idx_in_loop, even_chunk); RemoveOddChunks(begin_idx_in_loop, end_idx_in_loop, odd_chunk); } @@ -325,18 +341,17 @@ int64_t LoopOptimizerBestFitHeap::LastMemoryOffsetOccupied() const { } /*static*/ absl::StatusOr> -MemoryBoundLoopOptimizer::Create( - int loop_start, int loop_end, uint64_t alternate_memory_size, - const MemoryBoundLoopOptimizerOptions& options, - const HloLiveRange& hlo_live_range, const HloAliasAnalysis& alias_analysis, - const CostAnalysis& cost_analysis, - const BufferValue::SizeFunction& size_function, - const ReservedScopedMemoryFunction& reserved_scoped_memory_fn) { +MemoryBoundLoopOptimizer::Create(int loop_start, int loop_end, + const HloLiveRange& hlo_live_range, + const HloAliasAnalysis& alias_analysis, + const Options& options) { + CHECK(options.cost_analysis != nullptr); std::unique_ptr optimizer = absl::WrapUnique(new MemoryBoundLoopOptimizer( - loop_start, loop_end, alternate_memory_size, options, hlo_live_range, - alias_analysis, cost_analysis, size_function, - reserved_scoped_memory_fn)); + loop_start, loop_end, options.max_size_in_bytes, + options.memory_bound_loop_optimizer_options, hlo_live_range, + alias_analysis, *options.cost_analysis, options.size_fn, + options.reserved_scoped_memory_fn, options.alignment_in_bytes)); TF_RETURN_IF_ERROR(optimizer->Initialize()); return std::move(optimizer); } @@ -347,7 +362,8 @@ MemoryBoundLoopOptimizer::MemoryBoundLoopOptimizer( const HloLiveRange& hlo_live_range, const HloAliasAnalysis& alias_analysis, const CostAnalysis& cost_analysis, const BufferValue::SizeFunction& size_function, - const ReservedScopedMemoryFunction& reserved_scoped_memory_fn) + const ReservedScopedMemoryFunction& reserved_scoped_memory_fn, + int64_t alignment_in_bytes) : loop_start_(loop_start), loop_end_(loop_end), loop_size_(loop_end - loop_start), @@ -357,13 +373,17 @@ MemoryBoundLoopOptimizer::MemoryBoundLoopOptimizer( alias_analysis_(alias_analysis), cost_analysis_(cost_analysis), size_function_(size_function), - reserved_scoped_memory_fn_(reserved_scoped_memory_fn) {} + reserved_scoped_memory_fn_(reserved_scoped_memory_fn), + heap_(LoopOptimizerBestFitHeap(alternate_memory_size, + /*loop_size=*/loop_end - loop_start, + alignment_in_bytes)) {} absl::Status MemoryBoundLoopOptimizer::Initialize() { const auto& instruction_sequence = hlo_live_range_.flattened_instruction_sequence().instructions(); VLOG(3) << "MemoryBoundLoopOptimizer::Initialize, loop start: " << loop_start_ - << ", loop end: " << loop_end_ << ", loop size: " << loop_size_; + << ", loop end: " << loop_end_ << ", loop size: " << loop_size_ + << ", alternate memory size: " << alternate_memory_size_; const HloComputation* loop_computation = nullptr; // Initialize the remaining memory array with the size of the alternate // memory. Also populate instructions_in_loop_ and @@ -387,11 +407,21 @@ absl::Status MemoryBoundLoopOptimizer::Initialize() { } else { TF_RET_CHECK(loop_computation == loop_inst->parent()); } - remaining_memory_.push_back( - alternate_memory_size_ - + int64_t reserved_memory = reserved_scoped_memory_fn_(loop_inst, /*operands_in_alternate_memory=*/{}, - /*outputs_in_alternate_memory=*/{})); + /*outputs_in_alternate_memory=*/{}); + if (reserved_memory == 0) { + continue; + } + // Chunks for reserved scoped memory should always be found at offset 0. + EvenOddChunkPair chunks = heap_.AllocateEvenAndOddBetween( + i, i, reserved_memory, /*preferred_offsets=*/{0, 0}); + CHECK(chunks.first.has_value()); + CHECK(chunks.second.has_value()); + CHECK(chunks.first->size == reserved_memory); + VLOG(3) << "Reserved chunk: " << chunks.first->ToString() + << " loop index: " << i; } // Create a tree set to keep track of all the values that the loop @@ -809,11 +839,20 @@ std::string MemoryBoundLoopOptimizer::LoopValue::ToString() const { for (const auto& allocation : allocations) { absl::StrAppend(&allocations_str, "\n - ", allocation->ToString()); } + std::string chunk_str; + if (HasEvenAndOddChunks()) { + absl::StrAppend(&chunk_str, "\n", "even chunk: ", chunks.first->ToString()); + absl::StrAppend(&chunk_str, "\n", "odd chunk: ", chunks.second->ToString()); + absl::StrAppend(&chunk_str, "\n", "alternate memory begin idx in loop: ", + alternate_memory_begin_idx_in_loop.value()); + absl::StrAppend(&chunk_str, "\n", "alternate memory end idx in loop: ", + alternate_memory_end_idx_in_loop.value()); + } return absl::StrCat( "Size: ", size, " savings: ", savings, " savings per byte: ", savings_per_byte, - " allocation type: ", AllocationTypeToString(allocation_type), "\n", - values_str, "\n", allocations_str); + " allocation type: ", AllocationTypeToString(allocation_type), chunk_str, + "\n", values_str, "\n", allocations_str); } bool MemoryBoundLoopOptimizer::LoopValue::IsAllocationTypeSupported() const { @@ -822,6 +861,18 @@ bool MemoryBoundLoopOptimizer::LoopValue::IsAllocationTypeSupported() const { allocation_type == AllocationType::kPrefetch; } +bool MemoryBoundLoopOptimizer::LoopValue::HasEvenAndOddChunks() const { + return chunks.first.has_value() && chunks.second.has_value(); +} + +void MemoryBoundLoopOptimizer::LoopValue::SetChunkPairAndInterval( + EvenOddChunkPair chunk_pair, int64_t begin_idx_in_loop, + int64_t end_idx_in_loop) { + chunks = chunk_pair; + alternate_memory_begin_idx_in_loop = begin_idx_in_loop; + alternate_memory_end_idx_in_loop = end_idx_in_loop; +} + void MemoryBoundLoopOptimizer::SortLoopValues() { absl::c_stable_sort(loop_values_, [](const LoopValue& a, const LoopValue& b) { return a.savings_per_byte > b.savings_per_byte; @@ -850,9 +901,13 @@ void MemoryBoundLoopOptimizer::AllocateLoopValues() { VLOG(1) << "Unsupported allocation: " << value.ToString(); } } + VLOG(3) << "Heap after allocating temporaries:\n" + << heap_.MemoryUsageToAsciiArt(); VLOG(3) << "Execution time after allocating temporaries: " << CalculateExecutionTime(); AllocatePrefetches(absl::MakeSpan(prefetch_values)); + VLOG(3) << "Heap after allocating prefetches:\n" + << heap_.MemoryUsageToAsciiArt(); VLOG(3) << "Execution time after allocating prefetches: " << CalculateExecutionTime(); } @@ -897,26 +952,10 @@ void MemoryBoundLoopOptimizer::PostProcess() { value.allocations.back()->AddUse(use); } } + VLOG(3) << "LoopValue: " << value.ToString(); } } -bool MemoryBoundLoopOptimizer::AllocateBetween(int64_t begin_idx, - int64_t end_idx, int64_t size) { - int64_t end_idx_sentinel = end_idx; - if (end_idx < begin_idx) { - end_idx_sentinel += loop_size_; - } - for (int64_t i = begin_idx; i <= end_idx_sentinel; ++i) { - if (remaining_memory_[i % loop_size_] < size) { - return false; - } - } - for (int64_t i = begin_idx; i <= end_idx_sentinel; ++i) { - remaining_memory_[i % loop_size_] -= size; - } - return true; -} - bool MemoryBoundLoopOptimizer::AllocateTemporary(LoopValue& value) { VLOG(3) << "AllocateTemporary: " << value.ToString(); if (value.hlo_values.size() > 1) { @@ -925,37 +964,59 @@ bool MemoryBoundLoopOptimizer::AllocateTemporary(LoopValue& value) { } int64_t definition_idx = value.loop_positions.front().first; int64_t max_use_idx; + int64_t begin_idx_in_loop = definition_idx; + int64_t end_idx_in_loop; if (!value.next_iteration_uses.empty()) { max_use_idx = value.next_iteration_uses.back().first; // If max_use_idx >= definition_idx, then this is a loop carried dependence // and we should not have called this function. CHECK_LT(max_use_idx, definition_idx); + end_idx_in_loop = max_use_idx + loop_size_; } else { max_use_idx = value.loop_uses.back().first; + end_idx_in_loop = max_use_idx; } - bool success = AllocateBetween(definition_idx, max_use_idx, value.size); - if (success) { - VLOG(3) << "Pos: " << value.loop_positions[0].second; - value.allocations.push_back(std::make_unique( - value.loop_positions[0].second, MemorySpace::kAlternate, std::nullopt, - definition_idx, max_use_idx, - /*is_scoped_allocation=*/false)); - AddAllLoopPositionsAndUses(value, /*allocate_next_iteration_uses=*/true); + EvenOddChunkPair chunks = heap_.AllocateSameEvenAndOddBetween( + begin_idx_in_loop, end_idx_in_loop, value.size); + if (!chunks.first.has_value() || !chunks.second.has_value()) { + VLOG(3) << "Could not find Allocation for temporary value: " + << value.ToString(); + return false; } - return success; + value.SetChunkPairAndInterval(chunks, begin_idx_in_loop, end_idx_in_loop); + VLOG(3) << "Pos: " << value.loop_positions[0].second; + VLOG(3) << "Allocation found for temporary value: " << value.ToString(); + VLOG(3) << "Heap after allocating temporary value: " + << heap_.MemoryUsageToAsciiArt(); + value.allocations.push_back(std::make_unique( + value.loop_positions[0].second, MemorySpace::kAlternate, std::nullopt, + definition_idx, max_use_idx, + /*is_scoped_allocation=*/false)); + AddAllLoopPositionsAndUses(value, /*allocate_next_iteration_uses=*/true); + return true; } bool MemoryBoundLoopOptimizer::AllocatePinned(LoopValue& value) { - bool success = AllocateBetween(0, loop_size_ - 1, value.size); - if (success) { - CHECK(value.header_position); - value.allocations.push_back(std::make_unique( - *value.header_position, MemorySpace::kAlternate, std::nullopt, 0, - loop_size_, - /*is_scoped_allocation=*/false)); - AddAllLoopPositionsAndUses(value, /*allocate_next_iteration_uses=*/false); + int64_t begin_idx_in_loop = 0; + int64_t end_idx_in_loop = loop_size_ - 1; + EvenOddChunkPair chunks = heap_.AllocateSameEvenAndOddBetween( + begin_idx_in_loop, end_idx_in_loop, value.size); + if (!chunks.first.has_value() || !chunks.second.has_value()) { + VLOG(3) << "Could not find Allocation for pinned value: " + << value.ToString(); + return false; } - return success; + value.SetChunkPairAndInterval(chunks, begin_idx_in_loop, end_idx_in_loop); + CHECK(value.header_position); + VLOG(3) << "Allocation found for pinned value: " << value.ToString(); + VLOG(3) << "Heap after allocating pinned value: " + << heap_.MemoryUsageToAsciiArt(); + value.allocations.push_back(std::make_unique( + *value.header_position, MemorySpace::kAlternate, std::nullopt, 0, + loop_size_, + /*is_scoped_allocation=*/false)); + AddAllLoopPositionsAndUses(value, /*allocate_next_iteration_uses=*/false); + return true; } bool MemoryBoundLoopOptimizer::AllocatePrefetches( @@ -1005,8 +1066,6 @@ bool MemoryBoundLoopOptimizer::AllocatePrefetches( << *context.bandwidth_idle_times.rbegin(); } - context.additional_memory_used.resize(loop_size_, 0); - // Allocate prefetches by traversing the loop values in reverse order of // the first uses. for (int value_index : context.value_indices) { @@ -1014,10 +1073,6 @@ bool MemoryBoundLoopOptimizer::AllocatePrefetches( } for (int i = 0; i < loop_size_; ++i) { - remaining_memory_[i] -= context.additional_memory_used[i]; - VLOG(3) << "Additional memory [" << i - << "]: " << context.additional_memory_used[i]; - VLOG(3) << "Remaining memory [" << i << "]: " << remaining_memory_[i]; VLOG(3) << "Remaining bandwidth [" << i << "] : " << context.bandwidth_idle_times[i]; } @@ -1026,7 +1081,7 @@ bool MemoryBoundLoopOptimizer::AllocatePrefetches( bool MemoryBoundLoopOptimizer::AllocatePrefetch( int value_index, AllocatePrefetchesContext& context) { - LoopValue* value = context.values.at(value_index); + LoopValue* value = context.values[value_index]; VLOG(3) << "Allocating value: " << value->ToString(); int first_use_idx = value->loop_uses.front().first; int last_use_idx = value->loop_uses.back().first; @@ -1036,24 +1091,22 @@ bool MemoryBoundLoopOptimizer::AllocatePrefetch( last_use_idx_sentinel = last_use_idx + loop_size_; CHECK_LT(last_use_idx, first_use_idx); } - bool out_of_memory = false; - for (int i = first_use_idx; i <= last_use_idx_sentinel; ++i) { - int loop_idx = i % loop_size_; - if (context.additional_memory_used[loop_idx] + value->size > - remaining_memory_[loop_idx]) { - VLOG(3) << "Ran out of memory allocating for uses."; - out_of_memory = true; - } - } - if (out_of_memory) { - return false; - } float copy_resource = cost_analysis_.GetAsyncCopyElapsed(value->hlo_values.front()->shape()); VLOG(3) << "First use: " << value->loop_uses.begin()->second << " use idx: " << first_use_idx << " copy resource: " << copy_resource; - std::optional copy_start_time; + const auto& [even_chunk, odd_chunk] = heap_.FindEvenAndOddAllocationBetween( + first_use_idx, last_use_idx_sentinel, value->size); + if (!even_chunk.has_value() || !odd_chunk.has_value()) { + // Not enough memory to even fit the value in the alternate memory for the + // duration of its live range. + VLOG(3) << "Could not find Allocation for prefetch value: " + << value->ToString(); + return false; + } + + std::optional copy_start_loop_idx; // The general allocation algorithm for prefetches is to first calculate the // default-memory bandwidth idle times at each point (assuming all prefetches // succeeded). We show this pictorially below. We also show the previous @@ -1160,7 +1213,14 @@ bool MemoryBoundLoopOptimizer::AllocatePrefetch( float accumulated_copy_resource = 0; std::vector early_forced_prefetch_value_indices; int early_forced_prefetch_value_search_index = 0; - float early_forced_prefetch_additional_memory = 0; + VLOG(3) << "Memory usage before allocating prefetch value: " + << value->ToString() << "\n" + << heap_.MemoryUsageToAsciiArt(); + // NOTE: We can, in practice, run the following loop for loop_size + // iterations(one full loop), till first_use_idx - loop_size, as opposed to + // limiting it till last_use_idx_sentinel - loop_size. This will allow a + // prefetch to use all the idle bandwidth available during one full loop + // iteration. for (int i = first_use_idx - 1; i >= last_use_idx_sentinel - loop_size_; --i) { int loop_idx = (i + loop_size_) % loop_size_; @@ -1175,8 +1235,9 @@ bool MemoryBoundLoopOptimizer::AllocatePrefetch( ++early_forced_prefetch_value_search_index) { VLOG(3) << "Searching for early forced: " << early_forced_prefetch_value_search_index; - LoopValue* early_forced_value = context.values.at( - context.value_indices[early_forced_prefetch_value_search_index]); + LoopValue* early_forced_value = + context.values[context.value_indices + [early_forced_prefetch_value_search_index]]; if (early_forced_value->allocations.empty()) { continue; } @@ -1199,31 +1260,85 @@ bool MemoryBoundLoopOptimizer::AllocatePrefetch( } early_forced_prefetch_value_indices.push_back( early_forced_prefetch_value_search_index); - early_forced_prefetch_additional_memory += early_forced_value->size; - VLOG(3) << "Found early-forced prefetch value: " + VLOG(3) + << "Memory usage before removing prefetch value for early force: " + << early_forced_value->ToString() << "\n" + << heap_.MemoryUsageToAsciiArt(); + // Remove the original chunk from the heap. + heap_.RemoveEvenOddChunkPair( + early_forced_value->alternate_memory_begin_idx_in_loop.value(), + early_forced_value->alternate_memory_end_idx_in_loop.value(), + early_forced_value->chunks); + } + } + + VLOG(3) << "Loop idx:" << loop_idx << " Early force prefetch values: " + << early_forced_prefetch_value_indices.size(); + VLOG(3) << "Memory usage before adding pending chunks: \n" + << heap_.MemoryUsageToAsciiArt(); + std::vector pending_chunk_intervals; + for (int early_forced_prefetch_value_index : + early_forced_prefetch_value_indices) { + LoopValue* early_forced_value = + context + .values[context.value_indices[early_forced_prefetch_value_index]]; + int64_t begin_idx_in_loop = loop_idx; + int64_t end_idx_in_loop = + early_forced_value->alternate_memory_end_idx_in_loop.value(); + EvenOddChunkPair chunks = heap_.AllocateEvenAndOddBetween( + begin_idx_in_loop, end_idx_in_loop, early_forced_value->size); + if (!chunks.first.has_value() || !chunks.second.has_value()) { + VLOG(3) << "Could not allocate between " << begin_idx_in_loop << " and " + << end_idx_in_loop << " for early forced value: " << early_forced_value->ToString(); - VLOG(3) << "Early forced prefetch additional memory: " - << early_forced_prefetch_additional_memory; + VLOG(3) << "Memory usage after failed allocation: \n" + << heap_.MemoryUsageToAsciiArt(); + break; } + pending_chunk_intervals.push_back( + {begin_idx_in_loop, end_idx_in_loop, chunks}); + VLOG(3) << "Added pending chunk: " + << pending_chunk_intervals.back().ToString() + << " for value: " << early_forced_value->ToString(); } - // Overlap memory overhead only happens if the copy start overlaps with the - // first use (i.e. fully pipelined), so we'd need to account for 2X the - // buffer at this time. - int64_t overlap_memory_overhead = 0; - if (loop_idx == last_use_idx) { - overlap_memory_overhead = value->size; - VLOG(3) << "Loop idx == last use idx (" << loop_idx - << "), overlap memory overhead = " << overlap_memory_overhead; + if (pending_chunk_intervals.size() == + early_forced_prefetch_value_indices.size()) { + int64_t begin_idx_in_loop = i; + int64_t end_idx_in_loop = last_use_idx_sentinel; + EvenOddChunkPair chunks = heap_.AllocateEvenAndOddBetween( + begin_idx_in_loop, end_idx_in_loop, value->size); + if (chunks.first.has_value() && chunks.second.has_value()) { + pending_chunk_intervals.push_back( + {begin_idx_in_loop, end_idx_in_loop, chunks}); + VLOG(3) << "Added pending chunk: " + << pending_chunk_intervals.back().ToString() + << " for current value: " << value->ToString(); + } else { + VLOG(3) << "Could not allocate between " << begin_idx_in_loop << " and " + << end_idx_in_loop << " for value: " << value->ToString(); + VLOG(3) << "Memory usage after failed allocation: \n" + << heap_.MemoryUsageToAsciiArt(); + } + } + + bool out_of_memory = pending_chunk_intervals.size() < + early_forced_prefetch_value_indices.size() + 1; + + // Remove the pending chunks from the heap. + for (auto& pending_chunk_interval : pending_chunk_intervals) { + VLOG(3) << "Removing pending chunk: " + << pending_chunk_interval.ToString(); + heap_.RemoveEvenOddChunkPair(pending_chunk_interval.begin_idx_in_loop, + pending_chunk_interval.end_idx_in_loop, + pending_chunk_interval.chunks); } - // OOM; give up prefetch. - if (context.additional_memory_used[loop_idx] + value->size + - overlap_memory_overhead + early_forced_prefetch_additional_memory > - remaining_memory_[loop_idx]) { - VLOG(3) << "Ran out of memory. Accumulated copy resource " - << accumulated_copy_resource << " out of " << copy_resource - << " at " << loop_idx; + VLOG(3) << "Memory usage after removing pending chunks: " + << heap_.MemoryUsageToAsciiArt(); + + if (out_of_memory) { + VLOG(3) << "Ran out of memory for value: " << value->ToString(); break; } @@ -1243,16 +1358,16 @@ bool MemoryBoundLoopOptimizer::AllocatePrefetch( (copy_resource - accumulated_copy_resource)); if (bandwidth_idle_time >= copy_resource - accumulated_copy_resource) { accumulated_copy_resource = copy_resource; - copy_start_time = loop_idx; + copy_start_loop_idx = i; VLOG(3) << "Found the complete copy ratio and updated accumulated copy " "resource: " << accumulated_copy_resource; break; - } else if (!copy_start_time && + } else if (!copy_start_loop_idx.has_value() && accumulated_copy_resource + bandwidth_idle_time >= copy_resource * options_.desired_copy_ratio()) { accumulated_copy_resource += bandwidth_idle_time; - copy_start_time = loop_idx; + copy_start_loop_idx = i; VLOG(3) << "Found the desired copy ratio and updated accumulated copy " "resource: " << accumulated_copy_resource; @@ -1261,7 +1376,7 @@ bool MemoryBoundLoopOptimizer::AllocatePrefetch( // Even if desired resource isn't reached, and if the options allow it, // allow a fully pipelined prefetch. accumulated_copy_resource += bandwidth_idle_time; - copy_start_time = loop_idx; + copy_start_loop_idx = i; VLOG(3) << "Could not reach the desired copy ratio but scheduling " "fully pipelined prefetch anyway: " << accumulated_copy_resource; @@ -1274,26 +1389,43 @@ bool MemoryBoundLoopOptimizer::AllocatePrefetch( } // Could not find a suitable copy start time. - if (!copy_start_time) { + if (!copy_start_loop_idx.has_value()) { + // Restore original heap state as is. + VLOG(3) << "Could not find a suitable copy start time for value: " + << value->ToString(); + VLOG(3) << "Memory usage before restoring original state: " + << heap_.MemoryUsageToAsciiArt(); + for (int early_forced_prefetch_value_index : + early_forced_prefetch_value_indices) { + LoopValue* early_forced_value = + context + .values[context.value_indices[early_forced_prefetch_value_index]]; + // Allocate a chunk in at the same offset as the original prefetch. + EvenOddChunkPair chunks = heap_.AllocateEvenAndOddBetween( + early_forced_value->alternate_memory_begin_idx_in_loop.value(), + early_forced_value->alternate_memory_end_idx_in_loop.value(), + early_forced_value->size, + {early_forced_value->chunks.first->offset, + early_forced_value->chunks.second->offset}); + // The chunk should always be present as we are allocating at the same + // offset. + CHECK(chunks.first.has_value() && chunks.second.has_value()); + CHECK_EQ(chunks.first->offset, early_forced_value->chunks.first->offset); + CHECK_EQ(chunks.second->offset, + early_forced_value->chunks.second->offset); + } + VLOG(3) << "Memory usage after restoring original state: " + << heap_.MemoryUsageToAsciiArt(); return false; } - VLOG(3) << "Success: copy_start_time: " << *copy_start_time + VLOG(3) << "Success: copy_start_loop_idx: " << copy_start_loop_idx.value() << " leftover copy resource: " << (copy_resource - accumulated_copy_resource); - auto update_additional_memory_used = [&](int loop_idx, int64_t addition) { - VLOG(4) << "Updating additional memory used at " << loop_idx << ". " - << context.additional_memory_used[loop_idx] << " + " << addition - << " => " << (context.additional_memory_used[loop_idx] + addition) - << " (remaining: " << remaining_memory_[loop_idx] << ")"; - context.additional_memory_used[loop_idx] += addition; - CHECK_LE(context.additional_memory_used[loop_idx], - remaining_memory_[loop_idx]); - }; - for (int i = first_use_idx; i <= last_use_idx_sentinel; ++i) { - int loop_idx = i % loop_size_; - update_additional_memory_used(loop_idx, value->size); - } + // We are early forcing the prefetches of the previous iteration. This is the + // corresponding copy start index in the previous iteration. + int early_prefetch_copy_start_loop_idx = + (copy_start_loop_idx.value() + loop_size_) % loop_size_; // We reset accumulated copy resource and then reuse it to accumulate copy // resource time in order to replay the previous for loop. It is important // that we use the same arithmetic operations (as opposed to subtracting from @@ -1303,58 +1435,78 @@ bool MemoryBoundLoopOptimizer::AllocatePrefetch( --i) { int loop_idx = (i + loop_size_) % loop_size_; float& bandwidth_idle_time = context.bandwidth_idle_times[loop_idx]; - // Overlap memory overhead only happens if the copy start overlaps with the - // first use (i.e. fully pipelined), so we'd need to account for 2X the - // buffer at this time. - int64_t overlap_memory_overhead = 0; - update_additional_memory_used(loop_idx, - value->size + overlap_memory_overhead); if (bandwidth_idle_time < copy_resource - accumulated_copy_resource) { accumulated_copy_resource += bandwidth_idle_time; bandwidth_idle_time = 0; - if (loop_idx == *copy_start_time) { + if (loop_idx == early_prefetch_copy_start_loop_idx) { VLOG(3) << "Remaining copy resource: " << (copy_resource - accumulated_copy_resource); break; } } else { bandwidth_idle_time -= copy_resource - accumulated_copy_resource; - CHECK_EQ(loop_idx, *copy_start_time); + CHECK_EQ(loop_idx, early_prefetch_copy_start_loop_idx); break; } } - // Create the Allocation objects that correspond to the scheduled prefetch. - CHECK(value->header_position); - value->allocations.push_back(std::make_unique( - *value->header_position, MemorySpace::kDefault, std::nullopt, 0, - loop_size_, /*is_scoped_allocation=*/false)); - value->allocations.push_back(std::make_unique( - *value->allocations.back(), MemorySpace::kAlternate, std::nullopt, - ((*copy_start_time - 1) + loop_size_) % loop_size_, first_use_idx, - last_use_idx_sentinel)); - AddAllLoopPositionsAndUses(*value, /*allocate_next_iteration_uses=*/true); - // Account for the additional memory used by early forcing the already // scheduled prefetches. Also modify the start times of these to this // prefetch's copy start time. + // Allocate the force-early prefetches first, and allocate them in the same + // order as we did to check for out-of-memory, so we can reproduce the same + // allocation pattern. + // TODO(subhankarshah): Instead of depending on the order of allocation, store + // the offsets of the early forced prefetches and use that to allocate them. for (int early_forced_prefetch_value_index : early_forced_prefetch_value_indices) { - LoopValue* early_forced_value = context.values.at( - context.value_indices[early_forced_prefetch_value_index]); + LoopValue* early_forced_value = + context + .values[context.value_indices[early_forced_prefetch_value_index]]; CHECK(!early_forced_value->allocations.empty()); CopyAllocation* early_forced_prefetch = static_cast( early_forced_value->allocations.back().get()); - for (int index = early_forced_prefetch->copy_start_schedule_after(); - index >= *copy_start_time; --index) { - update_additional_memory_used(index, early_forced_value->size); - VLOG(3) << "Additional memory used: " << index << " " - << context.additional_memory_used[index]; - } + int64_t begin_idx_in_loop = early_prefetch_copy_start_loop_idx; + int64_t end_idx_in_loop = + early_forced_value->alternate_memory_end_idx_in_loop.value(); + EvenOddChunkPair chunks = heap_.AllocateEvenAndOddBetween( + begin_idx_in_loop, end_idx_in_loop, early_forced_value->size); + // The chunk should always be present as we reproducing the same allocation + // pattern as the out-of-memory check. + CHECK(chunks.first.has_value() && chunks.second.has_value()); + CHECK_LT(begin_idx_in_loop, + early_forced_value->alternate_memory_begin_idx_in_loop.value()); + early_forced_value->SetChunkPairAndInterval(chunks, begin_idx_in_loop, + end_idx_in_loop); early_forced_prefetch->set_copy_start_schedule_after( - ((*copy_start_time - 1) + loop_size_) % loop_size_); - VLOG(3) << "Updated prefetch: " << early_forced_prefetch->ToString(); + ((early_prefetch_copy_start_loop_idx - 1) + loop_size_) % loop_size_); + VLOG(3) << "Early forced prefetch: " << early_forced_value->ToString(); + VLOG(3) << "Memory usage after allocating early forced prefetch: " + << heap_.MemoryUsageToAsciiArt(); } + + // Create the Allocation objects that correspond to the scheduled prefetch. + CHECK(value->header_position); + value->allocations.push_back(std::make_unique( + *value->header_position, MemorySpace::kDefault, std::nullopt, 0, + loop_size_, /*is_scoped_allocation=*/false)); + int64_t begin_idx_in_loop = copy_start_loop_idx.value(); + int64_t end_idx_in_loop = last_use_idx_sentinel; + // The chunk should always be present as we reproducing the same allocation + // pattern as the out-of-memory check. + EvenOddChunkPair chunks = heap_.AllocateEvenAndOddBetween( + begin_idx_in_loop, end_idx_in_loop, value->size); + CHECK(chunks.first.has_value() && chunks.second.has_value()); + value->SetChunkPairAndInterval(chunks, begin_idx_in_loop, end_idx_in_loop); + value->allocations.push_back(std::make_unique( + *value->allocations.back(), MemorySpace::kAlternate, std::nullopt, + ((early_prefetch_copy_start_loop_idx - 1) + loop_size_) % loop_size_, + first_use_idx, last_use_idx_sentinel)); + VLOG(3) << "Allocation found for prefetch: " << value->ToString(); + VLOG(3) << "Memory usage after allocating prefetch: " << value->ToString() + << "\n" + << heap_.MemoryUsageToAsciiArt(); + AddAllLoopPositionsAndUses(*value, /*allocate_next_iteration_uses=*/true); return true; } diff --git a/third_party/xla/xla/service/memory_space_assignment/memory_bound_loop_optimizer.h b/third_party/xla/xla/service/memory_space_assignment/memory_bound_loop_optimizer.h index 5af196b4323af7..a1f0769fc14658 100644 --- a/third_party/xla/xla/service/memory_space_assignment/memory_bound_loop_optimizer.h +++ b/third_party/xla/xla/service/memory_space_assignment/memory_bound_loop_optimizer.h @@ -16,7 +16,6 @@ limitations under the License. #ifndef XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_MEMORY_BOUND_LOOP_OPTIMIZER_H_ #define XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_MEMORY_BOUND_LOOP_OPTIMIZER_H_ -#include #include #include #include @@ -26,6 +25,7 @@ limitations under the License. #include #include "absl/container/flat_hash_map.h" +#include "absl/log/check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/types/span.h" @@ -134,10 +134,10 @@ class LoopOptimizerBestFitHeap private: // REQUIRES: - // - begin_idx_in_loop <= end_idx_in_loop - // - begin_idx_in_loop is within [-loop_size loop_size) - // - end_idx_in_loop is within [0, 2 * loop_size) - // - end_idx_in_loop - begin_idx_in_loop + 1 <= 2 * loop_size (allocation + // * begin_idx_in_loop <= end_idx_in_loop + // * begin_idx_in_loop is within [-loop_size loop_size) + // * end_idx_in_loop is within [0, 2 * loop_size) + // * end_idx_in_loop - begin_idx_in_loop + 1 <= 2 * loop_size (allocation // colocated in even (or odd) iterations cannot span more than 2 loop // iterations) void CheckAllocationIntervalValid(int64_t begin_idx_in_loop, @@ -254,6 +254,7 @@ class MemoryBoundLoopOptimizer { // We represent each tensor used in the current iteration as a LoopValue, // wrapping the relevant information such as its HLO value, indices and // pointers to its use and position sites in different iterations. + // TODO(b/364621066): Make LoopValue a class. struct LoopValue { // An enum that encodes the allocation type that is suitable for this // LoopValue. See the comment above on what each of these mean. @@ -269,10 +270,20 @@ class MemoryBoundLoopOptimizer { static std::string AllocationTypeToString(AllocationType allocation_type); std::string ToString() const; + // Returns true if the LoopValue has chunks for even and odd loop + // iterations. + bool HasEvenAndOddChunks() const; + // Returns true if memory-bound loop optimizer supports allocating this type // of a loop value. bool IsAllocationTypeSupported() const; + // Sets the data members `chunks`, `alternate_memory_begin_idx_in_loop`, and + // `alternate_memory_end_idx_in_loop`. + void SetChunkPairAndInterval(EvenOddChunkPair chunk_pair, + int64_t begin_idx_in_loop, + int64_t end_idx_in_loop); + // The HloValues that correspond to this LoopValue. std::vector hlo_values; // The position in the header, if any. @@ -299,17 +310,25 @@ class MemoryBoundLoopOptimizer { float savings_per_byte; // The optimized AllocationSequence. AllocationSequence allocations; + // Chunks for even and odd iterations. If a loop value is double buffered + // then it must have different chunks for even and odd iterations. + EvenOddChunkPair chunks; + // Begin index of loop value in alternate memory. + // REQUIRES: + // * (-loop_size) <= alternate_memory_begin_idx_in_loop + // * alternate_memory_begin_idx_in_loop < loop_size + std::optional alternate_memory_begin_idx_in_loop = std::nullopt; + // End index of loop value in alternate memory. + // REQUIRES: + // * 0 <= alternate_memory_end_idx_in_loop + // * alternate_memory_end_idx_in_loop < 2*loop_size + std::optional alternate_memory_end_idx_in_loop = std::nullopt; }; // Factory method to create and initialize a MemoryBoundLoopOptimizer. static absl::StatusOr> Create( - int loop_start, int loop_end, uint64_t alternate_memory_size, - const MemoryBoundLoopOptimizerOptions& options, - const HloLiveRange& hlo_live_range, - const HloAliasAnalysis& alias_analysis_, - const CostAnalysis& cost_analysis, - const BufferValue::SizeFunction& size_function, - const ReservedScopedMemoryFunction& reserved_scoped_memory_fn); + int loop_start, int loop_end, const HloLiveRange& hlo_live_range, + const HloAliasAnalysis& alias_analysis, const Options& options); // Optimize the loop. Initialize must be called first. void Optimize(); @@ -324,13 +343,16 @@ class MemoryBoundLoopOptimizer { // Return the remaining memory vector for each point in time in the loop using // the allocation decisions so far. - const std::vector& remaining_memory() const { - return remaining_memory_; + std::vector RemainingMemory() const { + return heap_.RemainingMemoryByTime(); } int64_t MaxAlternateMemoryUsed() const { - return alternate_memory_size_ - *std::min_element(remaining_memory_.begin(), - remaining_memory_.end()); + return heap_.LastMemoryOffsetOccupied(); + } + + std::string MemoryUsageToAsciiArt() const { + return heap_.MemoryUsageToAsciiArt(); } // The loop start, end, and size accessors. @@ -344,15 +366,12 @@ class MemoryBoundLoopOptimizer { // The values that are requested to be prefetched. absl::Span values; - // A list of indices into values array, sorted by the start time of the - // first use. + // A list of indices into values array, sorted by the (descending) start + // time of the first use. std::vector value_indices; // Default memory remaining bandwidths assuming all prefetches succeeded. std::vector bandwidth_idle_times; - - // Additional memory used while performing prefetching. - std::vector additional_memory_used; }; MemoryBoundLoopOptimizer( @@ -362,7 +381,8 @@ class MemoryBoundLoopOptimizer { const HloAliasAnalysis& alias_analysis_, const CostAnalysis& cost_analysis, const BufferValue::SizeFunction& size_function, - const ReservedScopedMemoryFunction& reserved_scoped_memory_fn); + const ReservedScopedMemoryFunction& reserved_scoped_memory_fn, + int64_t alignment_in_bytes); // Initializes the data structures used by the optimizer. absl::Status Initialize(); @@ -384,9 +404,6 @@ class MemoryBoundLoopOptimizer { // Allocate LoopValues by dispatching to the correct Allocate method. void AllocateLoopValues(); - // Allocate and reserve memory between the given indices. - bool AllocateBetween(int64_t begin_idx, int64_t end_idx, int64_t size); - // Perform allocation type kTemporary. Return true if successful. bool AllocateTemporary(LoopValue& value); @@ -440,13 +457,22 @@ class MemoryBoundLoopOptimizer { absl::flat_hash_map instructions_in_next_iteration_; std::vector loop_values_; - std::vector remaining_memory_; absl::flat_hash_map>> uses_in_alternate_mem_; absl::flat_hash_map> positions_in_alternate_mem_; const ReservedScopedMemoryFunction& reserved_scoped_memory_fn_; + + // The heap used to allocate loop values. Since some loop values can be double + // buffered, between successive iterations, they must have different chunks + // for even and odd iterations. We model 4 iterations of the loop to allocate + // the loop values to alternate memory so we can model the buffers that cross + // one or two loop boundaries. The allocations in the 2nd and 3rd iterations + // represent the actual memory view. The 0th and 1st iteration serve to + // account for allocations, whose buffers cross one or two loop boundaries, + // into the 2nd and 3rd iterations. + LoopOptimizerBestFitHeap heap_; }; } // namespace memory_space_assignment diff --git a/third_party/xla/xla/service/memory_space_assignment/memory_bound_loop_optimizer_test.cc b/third_party/xla/xla/service/memory_space_assignment/memory_bound_loop_optimizer_test.cc index f241269bb6fa77..460793d8309ec8 100644 --- a/third_party/xla/xla/service/memory_space_assignment/memory_bound_loop_optimizer_test.cc +++ b/third_party/xla/xla/service/memory_space_assignment/memory_bound_loop_optimizer_test.cc @@ -314,12 +314,17 @@ class MemoryBoundLoopOptimizerTest : public HloTestBase { optimizer_options.set_enabled(true); optimizer_options.set_desired_copy_ratio(0.7); optimizer_options.set_allow_unsatisfied_fully_pipelined_prefetch(false); - TF_ASSIGN_OR_RETURN( - optimizer_, - MemoryBoundLoopOptimizer::Create( - loop_start, loop_end, alternate_memory_size, optimizer_options, - *live_range_, *alias_analysis_, *cost_analysis_, SizeFunction, - reserved_scoped_memory_fn)); + Options options; + options.max_size_in_bytes = alternate_memory_size; + options.alignment_in_bytes = 8; + options.alternate_memory_space = kAlternateMemorySpace; + options.cost_analysis = cost_analysis_.get(); + options.size_fn = SizeFunction; + options.reserved_scoped_memory_fn = reserved_scoped_memory_fn; + options.memory_bound_loop_optimizer_options = optimizer_options; + TF_ASSIGN_OR_RETURN(optimizer_, MemoryBoundLoopOptimizer::Create( + loop_start, loop_end, *live_range_, + *alias_analysis_, options)); return optimizer_.get(); } @@ -702,7 +707,10 @@ TEST_F(MemoryBoundLoopOptimizerTest, SimplePrefetch) { )"; int loop_start_idx; MemoryBoundLoopOptimizer* optimizer; - int64_t alternate_memory_size = 64; + // Although alternate_memory_size=64 is minimum memory needed to fit the copy + // of param0 with desired copy ratio. alternate_memory_size=80 memory will + // ensure complete copy of param0 to alternate memory. + int64_t alternate_memory_size = 80; TF_ASSERT_OK_AND_ASSIGN( auto module, ParseAndCreateOptimizer(hlo_loop_str, alternate_memory_size, loop_start_idx, &optimizer)); @@ -736,6 +744,55 @@ TEST_F(MemoryBoundLoopOptimizerTest, SimplePrefetch) { EXPECT_EQ(optimizer->MaxAlternateMemoryUsed(), alternate_memory_size); } +TEST_F(MemoryBoundLoopOptimizerTest, SimplePrefetch2) { + absl::string_view hlo_loop_str = R"( + $op0 = f32[1,4] add(f32[1,4] $prev_op3, f32[1,4] $prev_op4) + $op1 = f32[1,4] add(f32[1,4] $prev_op4, f32[1,4] $op0) + $op2 = f32[1,4] add(f32[1,4] $op0, f32[1,4] $op1) + $op3 = f32[1,4] add(f32[1,4] $op1, f32[1,4] $op2) + $op4 = f32[1,4] add(f32[1,4] $param0, f32[1,4] $op3) + ROOT $root = tuple($op4, $param0) + )"; + int loop_start_idx; + MemoryBoundLoopOptimizer* optimizer; + // alternate_memory_size=64 is minimum memory needed to fit the copy of param0 + // with desired copy ratio. + int64_t alternate_memory_size = 64; + TF_ASSERT_OK_AND_ASSIGN( + auto module, ParseAndCreateOptimizer(hlo_loop_str, alternate_memory_size, + loop_start_idx, &optimizer)); + + optimizer->Optimize(); + absl::flat_hash_set seen_uses; + for (const MemoryBoundLoopOptimizer::LoopValue& loop_value : + optimizer->loop_values()) { + LOG(INFO) << loop_value.ToString(); + if (loop_value.hlo_values.front() + ->defining_position() + .instruction->name() == "param0") { + EXPECT_TRUE(loop_value.allocations.back()->is_copy_allocation()); + } + for (const auto& allocation : loop_value.allocations) { + for (const HloUse& use : allocation->uses()) { + EXPECT_FALSE(seen_uses.contains(use)) << use.ToString(); + seen_uses.insert(use); + } + } + } + + // Ensure all of the uses in the loop have an associated use. + for (absl::string_view inst_name : {"op0", "op1", "op2", "op3", "op4"}) { + HloInstruction* inst = + module->entry_computation()->GetInstructionWithName(inst_name); + EXPECT_TRUE(seen_uses.contains(HloUse{inst, 0})) << inst_name; + EXPECT_TRUE(seen_uses.contains(HloUse{inst, 1})) << inst_name; + } + // Check that execution time has increased to 2 since we will wait on copy + // done for param0. + EXPECT_EQ(optimizer->CalculateExecutionTime(), 2); + EXPECT_EQ(optimizer->MaxAlternateMemoryUsed(), alternate_memory_size); +} + // Specify a ReservedScopedMemoryFunction to the loop optimizer that causes each // HLO to reserve the entire alternate memory. If the loop optimizer is // correctly accounting for reserved scoped memory, it should not put any @@ -773,10 +830,10 @@ TEST_F(MemoryBoundLoopOptimizerTest, ReservedScopedMemory) { // Check that a spurious GetTupleElement instruction in a later iteration of a // loop does not cause MSA to CHECK fail, when identifying loops. Prior to the -// change instroduced with this test, IdentifyAndOptimizeMemoryBoundLoops() +// change introduced with this test, IdentifyAndOptimizeMemoryBoundLoops() // would recognize 4 iterations to the loop thinking that gte is a repeat of // op2. Doing so triggers the CHECKs introduced by the change that added this -// test to fail. So, the point of this test is to verfiy that we do not check +// test to fail. So, the point of this test is to verify that we do not check // fail. TEST_F(MemoryBoundLoopOptimizerTest, GetTupleElement) { absl::string_view hlo_string = R"( @@ -909,7 +966,7 @@ TEST_F(MemoryBoundLoopOptimizerTest, PrefetchFifoOrderWithOverlap) { int loop_start_idx; MemoryBoundLoopOptimizer* optimizer; - int64_t alternate_memory_size = 432; + int64_t alternate_memory_size = 464; TF_ASSERT_OK_AND_ASSIGN( auto module, ParseAndCreateOptimizer(hlo_loop_str, alternate_memory_size, loop_start_idx, &optimizer)); @@ -985,7 +1042,7 @@ TEST_F(MemoryBoundLoopOptimizerTest, PrefetchFifoOrderWithOverlap) { EXPECT_EQ(optimizer->CalculateExecutionTime(), 12.5); // Check the memory used at each point of the loop. - const std::vector& remaining_memory = optimizer->remaining_memory(); + std::vector remaining_memory = optimizer->RemainingMemory(); // Time 0: 3 temporaries (16 B) + param0 (128 B) + param1 (128 B) EXPECT_EQ(remaining_memory.at(0), alternate_memory_size - (3 * 16 + 128 + 128)); @@ -1049,7 +1106,7 @@ TEST_F(MemoryBoundLoopOptimizerTest, PrefetchFifoOrderWithoutOverlap) { int loop_start_idx; MemoryBoundLoopOptimizer* optimizer; - int64_t alternate_memory_size = 192; + int64_t alternate_memory_size = 208; TF_ASSERT_OK_AND_ASSIGN( auto module, ParseAndCreateOptimizer(hlo_loop_str, alternate_memory_size, loop_start_idx, &optimizer)); @@ -1133,7 +1190,7 @@ TEST_F(MemoryBoundLoopOptimizerTest, PrefetchFifoOrderWithOverlap2) { int loop_start_idx; MemoryBoundLoopOptimizer* optimizer; - int64_t alternate_memory_size = 432; + int64_t alternate_memory_size = 464; TF_ASSERT_OK_AND_ASSIGN( auto module, ParseAndCreateOptimizer(hlo_loop_str, alternate_memory_size, loop_start_idx, &optimizer)); @@ -1302,13 +1359,13 @@ TEST_F(MemoryBoundLoopOptimizerTest, TempAndPinnedAllocations) { } )"; TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_str)); - int64_t alternate_memory_size = 64; + int64_t alternate_memory_size = 80; TF_ASSERT_OK_AND_ASSIGN( auto optimizer, CreateOptimizer(19, 24, module.get(), alternate_memory_size)); optimizer->Optimize(); - const std::vector& remaining_memory = optimizer->remaining_memory(); + std::vector remaining_memory = optimizer->RemainingMemory(); // Time 0: 3 temporaries (16 B) + 1 pinned (16 B) EXPECT_EQ(remaining_memory.at(0), alternate_memory_size - (3 * 16 + 16)); // Time 1: 3 temporaries (16 B) + 1 pinned (16 B) @@ -1373,12 +1430,12 @@ TEST_F(MemoryBoundLoopOptimizerTest, NegativeSavingNotPinned) { } )"; TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_str)); - int64_t alternate_memory_size = 52; + int64_t alternate_memory_size = 72; TF_ASSERT_OK_AND_ASSIGN( auto optimizer, CreateOptimizer(21, 27, module.get(), alternate_memory_size)); optimizer->Optimize(); - const std::vector& remaining_memory = optimizer->remaining_memory(); + std::vector remaining_memory = optimizer->RemainingMemory(); // We expect that pinned_prev_param0 would not get pinned due to negative // savings: 32(uses) - 28 * 16(size) = -416 Time 0: 3 temporaries (16 B) + 1 // pinned (4 B) From 0ae01ddde526cb0d8b828dcc167429b857bd9e9f Mon Sep 17 00:00:00 2001 From: Xuefei Jiang Date: Thu, 19 Sep 2024 16:54:23 -0700 Subject: [PATCH 033/483] PR #16938: Add NANOO FP8 support for collaborative communication unit tests Imported from GitHub PR https://github.com/openxla/xla/pull/16938 This PR adds support for NANOO FP8 data format in the collaborative communication unit tests. - For the context on OCP FP8 and NANOO FP8, please refer to this comment: https://github.com/google/flax/pull/3993#issue-2350000228 - The unit tests in this PR are similar to GEMM unit test introduced in the following PR to be able to deal with both OCP and NANOO fp8 formats: https://github.com/openxla/xla/pull/10488 Copybara import of the project: -- 0fc74ccae6cfcaf4e8627ea338ee03783af0626b by Wen Chen : [AMD] Added NCCL support for fp8e4m3fnuz and fp8e5m2fnuz. -- d247af5cd33fe42698bb55ef1c18f32df8a02a21 by scxfjiang : refactor tests for collective comm ops -- 6f8c418b3052f7c531896bd5f8cbbc7a766ef7fc by scxfjiang : rafactor collective comm e2e tests -- 8ecb6ecf08a1536c5b3f8ba87e0e9f8813b1b359 by scxfjiang : update: replace str -- 338d3af2ca1a32302fdfe9d7abee335d24539ee9 by scxfjiang : get rid of macros Merging this change closes #16938 PiperOrigin-RevId: 676615012 --- .../xla/xla/service/gpu/runtime/nccl_api.cc | 2 + .../gpu/runtime/nccl_collective_thunk.cc | 2 + third_party/xla/xla/tests/BUILD | 1 + .../xla/xla/tests/collective_ops_e2e_test.cc | 40 ++-- .../xla/xla/tests/collective_ops_test.cc | 179 ++++++++++-------- 5 files changed, 139 insertions(+), 85 deletions(-) diff --git a/third_party/xla/xla/service/gpu/runtime/nccl_api.cc b/third_party/xla/xla/service/gpu/runtime/nccl_api.cc index 77f022da6ec64f..15949ac9cae999 100644 --- a/third_party/xla/xla/service/gpu/runtime/nccl_api.cc +++ b/third_party/xla/xla/service/gpu/runtime/nccl_api.cc @@ -112,6 +112,8 @@ static absl::StatusOr ToNcclDataType(PrimitiveType dtype, case S8: case F8E5M2: case F8E4M3FN: + case F8E5M2FNUZ: + case F8E4M3FNUZ: return ncclInt8; case PRED: case U8: diff --git a/third_party/xla/xla/service/gpu/runtime/nccl_collective_thunk.cc b/third_party/xla/xla/service/gpu/runtime/nccl_collective_thunk.cc index 8e075c8d01c730..fb2282c8e73ae7 100644 --- a/third_party/xla/xla/service/gpu/runtime/nccl_collective_thunk.cc +++ b/third_party/xla/xla/service/gpu/runtime/nccl_collective_thunk.cc @@ -92,6 +92,8 @@ bool IsTypeSupportedByNccl(PrimitiveType element_type, // they involve actual computation and not just data movement. case F8E5M2: case F8E4M3FN: + case F8E5M2FNUZ: + case F8E4M3FNUZ: return !IsReductionCollective(reduction_op); default: return false; diff --git a/third_party/xla/xla/tests/BUILD b/third_party/xla/xla/tests/BUILD index 16f1f8e8f07c40..43e24cfd8d86f3 100644 --- a/third_party/xla/xla/tests/BUILD +++ b/third_party/xla/xla/tests/BUILD @@ -2374,6 +2374,7 @@ xla_test( "//xla/service:hlo_module_config", "//xla/service/gpu:backend_configs_cc", "//xla/stream_executor", + "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:string_view", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", diff --git a/third_party/xla/xla/tests/collective_ops_e2e_test.cc b/third_party/xla/xla/tests/collective_ops_e2e_test.cc index 01c60e5d6ac683..7f30b975ebfe02 100644 --- a/third_party/xla/xla/tests/collective_ops_e2e_test.cc +++ b/third_party/xla/xla/tests/collective_ops_e2e_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "absl/strings/str_replace.h" #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -54,6 +55,13 @@ DeviceAssignment MakeDeviceAssn(int64_t num_replicas) { class CollectiveOpsTestE2E : public HloTestBase { public: + CollectiveOpsTestE2E() { + replacements_[kF8E4M3DatatypePlaceholder] = + IsCuda() ? "f8e4m3fn" : "f8e4m3fnuz"; + replacements_[kF8E5M2DatatypePlaceholder] = + IsCuda() ? "f8e5m2" : "f8e5m2fnuz"; + } + bool IsCuda() { return std::holds_alternative(Capability()); } @@ -108,6 +116,13 @@ class CollectiveOpsTestE2E : public HloTestBase { /*argument_provider*/ [](int64_t, int64_t) { return nullptr; }, num_replicas, /*run_hlo_passes=*/false, &device_assignment); } + + protected: + absl::flat_hash_map replacements_; + + private: + static constexpr const char* kF8E4M3DatatypePlaceholder{"<>"}; + static constexpr const char* kF8E5M2DatatypePlaceholder{"<>"}; }; // E2E tests for collective ops. These will generally verify some HLO transform @@ -811,11 +826,11 @@ ENTRY main.12 { TEST_F(CollectiveOpsTestE2EWindowedNonWindowed, WindowedEinsumE2EAllGatherAndReduceScatterF8) { absl::string_view kModuleReplicatedStr = R"( -HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(f8e4m3fn[2,16,48]{2,1,0}, f8e4m3fn[48,192]{1,0}, f8e4m3fn[192,48]{1,0}, bf16[], bf16[], bf16[], bf16[], bf16[])->bf16[2,16,48]{2,1,0}}, allow_spmd_sharding_propagation_to_parameters={false,false,false,false}, num_partitions=4 +HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(<>[2,16,48]{2,1,0}, <>[48,192]{1,0}, <>[192,48]{1,0}, bf16[], bf16[], bf16[], bf16[], bf16[])->bf16[2,16,48]{2,1,0}}, allow_spmd_sharding_propagation_to_parameters={false,false,false,false}, num_partitions=4 ENTRY main.12 { - Arg_0.1 = f8e4m3fn[2,16,48]{2,1,0} parameter(0), sharding={devices=[1,4,1]<=[4]} - Arg_1.2 = f8e4m3fn[48,192]{1,0} parameter(1), sharding={devices=[1,4]<=[4]} + Arg_0.1 = <>[2,16,48]{2,1,0} parameter(0), sharding={devices=[1,4,1]<=[4]} + Arg_1.2 = <>[48,192]{1,0} parameter(1), sharding={devices=[1,4]<=[4]} Arg_2.3 = bf16[] parameter(3) Arg_3.4 = bf16[] parameter(4) broadcast = bf16[2,16,48]{2,1,0} broadcast(Arg_2.3), dimensions={} @@ -834,12 +849,12 @@ ENTRY main.12 { constant.1 = bf16[] constant(448.) broadcast.4 = bf16[2,16,192]{2,1,0} broadcast(constant.1), dimensions={} clamp = bf16[2,16,192]{2,1,0} clamp(broadcast.3, divide, broadcast.4) - convert.2 = f8e4m3fn[2,16,192]{2,1,0} convert(clamp) + convert.2 = <>[2,16,192]{2,1,0} convert(clamp) Arg_5.6 = bf16[] parameter(6) broadcast.5 = bf16[2,16,192]{2,1,0} broadcast(Arg_5.6), dimensions={} convert.3 = bf16[2,16,192]{2,1,0} convert(convert.2) multiply.2 = bf16[2,16,192]{2,1,0} multiply(convert.3, broadcast.5) - Arg_6.7 = f8e4m3fn[192,48]{1,0} parameter(2), sharding={devices=[4,1]<=[4]} + Arg_6.7 = <>[192,48]{1,0} parameter(2), sharding={devices=[4,1]<=[4]} Arg_7.8 = bf16[] parameter(7) broadcast.6 = bf16[192,48]{1,0} broadcast(Arg_7.8), dimensions={} convert.4 = bf16[192,48]{1,0} convert(Arg_6.7) @@ -852,8 +867,9 @@ ENTRY main.12 { // Disable the dot merger pass which can prevent the creation of FP8 GEMM // Custom Calls. - CollectiveOpsCompareWindowedNonWindowed(kModuleReplicatedStr, - /*disable_dot_merger=*/true); + CollectiveOpsCompareWindowedNonWindowed( + absl::StrReplaceAll(kModuleReplicatedStr, replacements_), + /*disable_dot_merger=*/true); // Verify the creation of FP8 GEMM Custom Calls on Hopper and newer // architectures. @@ -863,7 +879,8 @@ ENTRY main.12 { opts.set_xla_gpu_graph_min_graph_size(200); opts.set_xla_gpu_enable_triton_gemm(false); opts.add_xla_disable_hlo_passes("dot-merger"); - CollectiveOpsVerifyF8Matmul(kModuleReplicatedStr, opts); + CollectiveOpsVerifyF8Matmul( + absl::StrReplaceAll(kModuleReplicatedStr, replacements_), opts); } TEST_F(CollectiveOpsTestE2EWindowedNonWindowed, @@ -1023,7 +1040,7 @@ while_body { r = bf16[32,128] bitcast(dynamic-slice.k) a = bf16[32,128] add(r, r), control-predecessors={constant.2559} // A fp8 pattern of quant-dequant before the collective AG. - qa = f8e4m3fn[32,128] convert(a) + qa = <>[32,128] convert(a) dqa = bf16[32,128] convert(qa) a_scale = bf16[] get-tuple-element(param), index=3 a_scales = bf16[32,128] broadcast(a_scale), dimensions={} @@ -1031,7 +1048,7 @@ while_body { mb = bf16[128,128] all-gather(dqa_unscaled), channel_id=1, use_global_device_ids=true, dimensions={0}, replica_groups={{0,1,2,3}} ma = bf16[128,128] dynamic-slice(get-tuple-element.395, select.1348, constant.2561), dynamic_slice_sizes={128,128} - qma = f8e4m3fn[128,128] convert(ma) + qma = <>[128,128] convert(ma) dqma = bf16[128,128] convert(qma) ma_scale = bf16[] get-tuple-element(param), index=4 ma_scales = bf16[128,128] broadcast(ma_scale), dimensions={} @@ -1061,7 +1078,8 @@ ENTRY entry { opts.set_xla_gpu_run_post_layout_collective_pipeliner(true); opts.set_xla_gpu_enable_pipelined_collectives(true); opts.set_xla_gpu_enable_triton_gemm(false); - CollectiveOpsVerifyF8Matmul(kModuleReplicatedStr, opts); + CollectiveOpsVerifyF8Matmul( + absl::StrReplaceAll(kModuleReplicatedStr, replacements_), opts); } TEST_F(CollectiveOpsTestE2E, diff --git a/third_party/xla/xla/tests/collective_ops_test.cc b/third_party/xla/xla/tests/collective_ops_test.cc index 9cd874c9e03c13..fcecf8f4a66cef 100644 --- a/third_party/xla/xla/tests/collective_ops_test.cc +++ b/third_party/xla/xla/tests/collective_ops_test.cc @@ -1753,80 +1753,6 @@ XLA_TEST_F(CollectiveOpsTest, AllReduceBFloat16Min) { } } -XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(AllGather_8BitFloat)) { - const char* const kModuleStr = R"( - HloModule test - ENTRY test_computation { - a0 = f8e4m3fn[1,2] constant({{1,2}}) - allgather = f8e4m3fn[2, 2] all-gather(a0), dimensions={0} - p = f8e4m3fn[4] reshape(allgather) - ROOT out = f32[4] convert(p) - } - )"; - const int64_t kNumReplicas = 2; - HloModuleConfig config = - GetModuleConfigForTest(/*replica_count=*/kNumReplicas); - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(kModuleStr, config)); - TF_ASSERT_OK_AND_ASSIGN( - std::vector results, - ExecuteReplicated(std::move(module), absl::Span{}, - kNumReplicas, - /*use_threads=*/true, /*run_hlo_passes=*/true)); - ASSERT_EQ(results.size(), kNumReplicas); - for (const Literal& result : results) { - LiteralTestUtil::ExpectR1Equal({1, 2, 1, 2}, result); - } -} - -XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(AllToAll_8BitFloat)) { - const char* const kModuleStr = R"( - HloModule test - ENTRY test_computation { - a0 = f8e4m3fn[2] constant({1,2}) - a2a = f8e4m3fn[2] all-to-all(a0), dimensions={0} - ROOT out = f32[2] convert(a2a) - } - )"; - const int64_t kNumReplicas = 2; - HloModuleConfig config = - GetModuleConfigForTest(/*replica_count=*/kNumReplicas); - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(kModuleStr, config)); - TF_ASSERT_OK_AND_ASSIGN( - std::vector results, - ExecuteReplicated(std::move(module), absl::Span{}, - kNumReplicas, - /*use_threads=*/true, /*run_hlo_passes=*/true)); - ASSERT_EQ(results.size(), kNumReplicas); - LiteralTestUtil::ExpectR1Equal({1, 1}, results[0]); - LiteralTestUtil::ExpectR1Equal({2, 2}, results[1]); -} - -XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(CollectivePermute_8BitFloat)) { - const char* const kModuleStr = R"( - HloModule test - ENTRY test_computation { - a0 = f8e5m2[2] constant({1,2}) - a1 = f8e5m2[2] collective-permute(a0), source_target_pairs={{0,1}, {1,0}} - ROOT out = f32[2] convert(a1) - } - )"; - const int64_t kNumReplicas = 2; - HloModuleConfig config = - GetModuleConfigForTest(/*replica_count=*/kNumReplicas); - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(kModuleStr, config)); - TF_ASSERT_OK_AND_ASSIGN( - std::vector results, - ExecuteReplicated(std::move(module), absl::Span{}, - kNumReplicas, - /*use_threads=*/true, /*run_hlo_passes=*/true)); - ASSERT_EQ(results.size(), kNumReplicas); - LiteralTestUtil::ExpectR1Equal({1, 2}, results[0]); - LiteralTestUtil::ExpectR1Equal({1, 2}, results[1]); -} - XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(AsyncAllGather)) { const char* const kModuleStr = R"( HloModule test @@ -2273,5 +2199,110 @@ body { results[1])); } +class Fp8CollectiveOpsTest : public CollectiveOpsTest { + public: + Fp8CollectiveOpsTest() { + replacements_[kF8E4M3DatatypePlaceholder] = + IsCuda() ? "f8e4m3fn" : "f8e4m3fnuz"; + replacements_[kF8E5M2DatatypePlaceholder] = + IsCuda() ? "f8e5m2" : "f8e5m2fnuz"; + } + + protected: + bool IsCuda() { + return std::holds_alternative(Capability()); + } + + const se::GpuComputeCapability& Capability() { + return backend() + .default_stream_executor() + ->GetDeviceDescription() + .gpu_compute_capability(); + } + + absl::flat_hash_map replacements_; + + private: + static constexpr const char* kF8E4M3DatatypePlaceholder{"<>"}; + static constexpr const char* kF8E5M2DatatypePlaceholder{"<>"}; +}; + +XLA_TEST_F(Fp8CollectiveOpsTest, DISABLED_ON_CPU(AllGather_8BitFloat)) { + const char* const kModuleStr = R"( + HloModule test + ENTRY test_computation { + a0 = <>[1,2] constant({{1,2}}) + allgather = <>[2, 2] all-gather(a0), dimensions={0} + p = <>[4] reshape(allgather) + ROOT out = f32[4] convert(p) + } + )"; + const int64_t kNumReplicas = 2; + HloModuleConfig config = + GetModuleConfigForTest(/*replica_count=*/kNumReplicas); + TF_ASSERT_OK_AND_ASSIGN( + auto module, ParseAndReturnVerifiedModule( + absl::StrReplaceAll(kModuleStr, replacements_), config)); + TF_ASSERT_OK_AND_ASSIGN( + std::vector results, + ExecuteReplicated(std::move(module), absl::Span{}, + kNumReplicas, + /*use_threads=*/true, /*run_hlo_passes=*/true)); + ASSERT_EQ(results.size(), kNumReplicas); + for (const Literal& result : results) { + LiteralTestUtil::ExpectR1Equal({1, 2, 1, 2}, result); + } +} + +XLA_TEST_F(Fp8CollectiveOpsTest, DISABLED_ON_CPU(AllToAll_8BitFloat)) { + const char* const kModuleStr = R"( + HloModule test + ENTRY test_computation { + a0 = <>[2] constant({1,2}) + a2a = <>[2] all-to-all(a0), dimensions={0} + ROOT out = f32[2] convert(a2a) + } + )"; + const int64_t kNumReplicas = 2; + HloModuleConfig config = + GetModuleConfigForTest(/*replica_count=*/kNumReplicas); + TF_ASSERT_OK_AND_ASSIGN( + auto module, ParseAndReturnVerifiedModule( + absl::StrReplaceAll(kModuleStr, replacements_), config)); + TF_ASSERT_OK_AND_ASSIGN( + std::vector results, + ExecuteReplicated(std::move(module), absl::Span{}, + kNumReplicas, + /*use_threads=*/true, /*run_hlo_passes=*/true)); + ASSERT_EQ(results.size(), kNumReplicas); + LiteralTestUtil::ExpectR1Equal({1, 1}, results[0]); + LiteralTestUtil::ExpectR1Equal({2, 2}, results[1]); +} + +XLA_TEST_F(Fp8CollectiveOpsTest, DISABLED_ON_CPU(CollectivePermute_8BitFloat)) { + const char* const kModuleStr = R"( + HloModule test + ENTRY test_computation { + a0 = <>[2] constant({1,2}) + a1 = <>[2] collective-permute(a0), source_target_pairs={{0,1}, {1,0}} + ROOT out = f32[2] convert(a1) + } + )"; + const int64_t kNumReplicas = 2; + HloModuleConfig config = + GetModuleConfigForTest(/*replica_count=*/kNumReplicas); + TF_ASSERT_OK_AND_ASSIGN( + auto module, ParseAndReturnVerifiedModule( + absl::StrReplaceAll(kModuleStr, replacements_), config)); + TF_ASSERT_OK_AND_ASSIGN( + std::vector results, + ExecuteReplicated(std::move(module), absl::Span{}, + kNumReplicas, + /*use_threads=*/true, /*run_hlo_passes=*/true)); + ASSERT_EQ(results.size(), kNumReplicas); + LiteralTestUtil::ExpectR1Equal({1, 2}, results[0]); + LiteralTestUtil::ExpectR1Equal({1, 2}, results[1]); +} + } // namespace } // namespace xla From 6892ee6e8fd51df8ecea8def2f2afd1aaba30467 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 19 Sep 2024 16:56:24 -0700 Subject: [PATCH 034/483] Remove a couple unused functions for generating reshape strategies. PiperOrigin-RevId: 676615632 --- .../auto_sharding/auto_sharding.cc | 78 ------------------- 1 file changed, 78 deletions(-) diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc index 66370ab0020b32..3feff04d56e1b3 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc @@ -1110,84 +1110,6 @@ void EnumerateAll1DPartitionReshape(const HloInstruction* ins, } } -void BuildStrategyAndCostForReshape(const HloInstruction* ins, - const DeviceMesh& device_mesh, - const ClusterEnvironment& cluster_env, - const StrategyMap& strategy_map, - absl::Span tensor_dims, - StrategyGroup& strategy_group); - -// Enumerate all partitions for reshape. Batch dim is always partitioned. -void EnumeratePartitionReshape( - const HloInstruction* ins, const DeviceMesh& device_mesh, - const ClusterEnvironment& cluster_env, const StrategyMap& strategy_map, - const bool only_allow_divisible, const int64_t partition_dimensions, - const std::vector& tensor_dims, StrategyGroup& strategy_group) { - const auto tensor_dims_size = tensor_dims.size(); - if (tensor_dims_size == partition_dimensions) { - BuildStrategyAndCostForReshape(ins, device_mesh, cluster_env, strategy_map, - tensor_dims, strategy_group); - return; - } - // Split batch dim + another dim - for (int64_t i = 0; i < ins->shape().rank(); ++i) { - auto tensor_it = std::find(tensor_dims.begin(), tensor_dims.end(), i); - if (tensor_it != tensor_dims.end()) { - continue; - } - if (ins->shape().dimensions(i) < device_mesh.dim(tensor_dims_size)) { - continue; - } - if (only_allow_divisible && - !IsDivisible(ins->shape().dimensions(i), - device_mesh.dim(tensor_dims_size))) { - continue; - } - - std::vector next_tensor_dims = tensor_dims; - next_tensor_dims.push_back(i); - EnumeratePartitionReshape(ins, device_mesh, cluster_env, strategy_map, - only_allow_divisible, partition_dimensions, - next_tensor_dims, strategy_group); - } -} - -void BuildStrategyAndCostForReshape(const HloInstruction* ins, - const DeviceMesh& device_mesh, - const ClusterEnvironment& cluster_env, - const StrategyMap& strategy_map, - absl::Span tensor_dims, - StrategyGroup& strategy_group) { - const HloInstruction* operand = ins->operand(0); - const Shape& operand_shape = operand->shape(); - const StrategyGroup& operand_strategy_group = *strategy_map.at(operand); - std::vector mesh_dims(tensor_dims.size()); - std::iota(mesh_dims.begin(), mesh_dims.end(), 0); - const HloSharding output_spec = - Tile(ins->shape(), tensor_dims, mesh_dims, device_mesh); - std::optional input_spec = hlo_sharding_util::ReshapeSharding( - ins->shape(), operand_shape, output_spec); - if (!input_spec.has_value()) { // invalid reshape - return; - } - std::string name = - absl::StrFormat("S%s @ {%s}", absl::StrJoin(tensor_dims, ""), - absl::StrJoin(mesh_dims, ",")); - double compute_cost = 0, communication_cost = 0; - double memory_cost = ByteSizeOfShapeWithSharding(ins->shape(), output_spec); - - ReshardingCosts communication_resharding_costs{ - CommunicationReshardingCostVector(operand_strategy_group, operand_shape, - *input_spec, cluster_env)}; - ReshardingCosts memory_resharding_costs{MemoryReshardingCostVector( - operand_strategy_group, operand_shape, *input_spec, cluster_env)}; - strategy_group.AddStrategy( - ShardingStrategy({name, output_spec, compute_cost, communication_cost, - memory_cost, std::move(communication_resharding_costs), - std::move(memory_resharding_costs)}), - {*input_spec}); -} - // Return the maximum number of tiles among all strategies of an instruction. int64_t MaxNumTiles(const StrategyMap& strategy_map, const HloInstruction* ins) { From 2516447da9eae56946f08aa68f82f6b5102f5d29 Mon Sep 17 00:00:00 2001 From: Vlad Sytchenko Date: Thu, 19 Sep 2024 17:04:32 -0700 Subject: [PATCH 035/483] [XLA] Make async call names unique We need the instruction name to go through the name uniquer in order for it to be correctly parsable when being cloned, as we'll attempt to increment the trailing suffix digit. Otherwise the cloned instruction will have the same name as its original version. PiperOrigin-RevId: 676617905 --- third_party/xla/xla/hlo/ir/hlo_computation.cc | 10 +++--- third_party/xla/xla/service/BUILD | 1 + .../xla/xla/service/hlo_computation_test.cc | 36 +++++++++++++++++++ 3 files changed, 43 insertions(+), 4 deletions(-) diff --git a/third_party/xla/xla/hlo/ir/hlo_computation.cc b/third_party/xla/xla/hlo/ir/hlo_computation.cc index 6ab8e712c1eee7..34ea42a0536f9e 100644 --- a/third_party/xla/xla/hlo/ir/hlo_computation.cc +++ b/third_party/xla/xla/hlo/ir/hlo_computation.cc @@ -1156,7 +1156,8 @@ absl::StatusOr HloComputation::CreateAsyncInstructions( HloInstruction* root = builder.AddInstruction( instruction->CloneWithNewOperands(instruction->shape(), parameters)); if (override_names) { - root->SetAndSanitizeName(absl::StrCat(instruction->name(), ".cloned")); + parent()->SetAndUniquifyInstrName( + root, absl::StrCat(instruction->name(), ".cloned")); } HloComputation* async_computation = parent_->AddEmbeddedComputation(builder.Build(root)); @@ -1171,9 +1172,10 @@ absl::StatusOr HloComputation::CreateAsyncInstructions( async_done = AddInstruction( HloInstruction::CreateAsyncDone(root->shape(), async_start)); if (override_names) { - async_start->SetAndSanitizeName( - absl::StrCat(root->name(), ".call-start")); - async_done->SetAndSanitizeName(absl::StrCat(root->name(), ".call-done")); + parent()->SetAndUniquifyInstrName( + async_start, absl::StrCat(root->name(), ".call-start")); + parent()->SetAndUniquifyInstrName( + async_done, absl::StrCat(root->name(), ".call-done")); } } async_start->set_metadata(instruction->metadata()); diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD index 095fd1a1f21fc3..808f7e3e00440e 100644 --- a/third_party/xla/xla/service/BUILD +++ b/third_party/xla/xla/service/BUILD @@ -4549,6 +4549,7 @@ xla_cc_test( "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status", diff --git a/third_party/xla/xla/service/hlo_computation_test.cc b/third_party/xla/xla/service/hlo_computation_test.cc index a7190b33f2088d..87b09dbc4ce28d 100644 --- a/third_party/xla/xla/service/hlo_computation_test.cc +++ b/third_party/xla/xla/service/hlo_computation_test.cc @@ -41,6 +41,7 @@ limitations under the License. #include "xla/test.h" #include "xla/test_helpers.h" #include "xla/tests/hlo_test_base.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/status.h" #include "tsl/platform/statusor.h" @@ -972,5 +973,40 @@ TEST_F(HloComputationTest, CompositeCall) { EXPECT_EQ(composite_call->frontend_attributes().map().size(), 3); } +TEST_F(HloComputationTest, CloneComputationWithAsyncInstructions) { + constexpr std::string_view hlo = R"( +HloModule main + +comp.0 { + ROOT custom-call.0 = () custom-call(), custom_call_target="foo" +} + +ENTRY main { + in.0 = () parameter(0) + call.0 = () call(), to_apply=comp.0 + ROOT out.0 = () tuple() +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo)); + HloComputation* comp0 = FindComputation(module.get(), "comp.0"); + HloInstruction* custom_call = FindInstruction(module.get(), "custom-call.0"); + TF_ASSERT_OK(comp0->CreateAsyncInstructions( + custom_call, /*context_shapes=*/{ShapeUtil::MakeScalarShape(U32)}, + /*async_execution_thread=*/HloInstruction::kMainExecutionThread, + /*replace=*/true, + /*override_names=*/true)); + + HloComputation* comp1 = module->AddEmbeddedComputation(comp0->Clone()); + HloComputation* comp2 = module->AddEmbeddedComputation(comp0->Clone()); + EXPECT_NE(comp0->root_instruction()->name(), + comp1->root_instruction()->name()); + EXPECT_NE(comp0->root_instruction()->operand(0)->name(), + comp1->root_instruction()->operand(0)->name()); + EXPECT_NE(comp1->root_instruction()->name(), + comp2->root_instruction()->name()); + EXPECT_NE(comp1->root_instruction()->operand(0)->name(), + comp2->root_instruction()->operand(0)->name()); +} + } // namespace } // namespace xla From 37095221b92df1559123ed9ce0444ade6b1b6067 Mon Sep 17 00:00:00 2001 From: Ian Langmore Date: Thu, 19 Sep 2024 17:07:47 -0700 Subject: [PATCH 036/483] Tag `linalg:linear_operator_block_lower_triangular_test` with `no_gpu` because of toolchain change causing issues in fastbuild with GPU. PiperOrigin-RevId: 676619014 --- tensorflow/python/kernel_tests/linalg/BUILD | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tensorflow/python/kernel_tests/linalg/BUILD b/tensorflow/python/kernel_tests/linalg/BUILD index 4f12dc12ed3b7f..c84540eb30daa2 100644 --- a/tensorflow/python/kernel_tests/linalg/BUILD +++ b/tensorflow/python/kernel_tests/linalg/BUILD @@ -213,7 +213,10 @@ cuda_py_strict_test( size = "medium", srcs = ["linear_operator_block_lower_triangular_test.py"], shard_count = 8, - tags = ["optonly"], + tags = [ + "no_gpu", # Seg fault. http://b/365525243 + "optonly", + ], deps = [ "//tensorflow/python/eager:backprop", "//tensorflow/python/eager:context", From 1c86f98911b60f89ad5040031859bbcc450e8fd1 Mon Sep 17 00:00:00 2001 From: Sandeep Dasgupta Date: Thu, 19 Sep 2024 17:10:41 -0700 Subject: [PATCH 037/483] [HLO Componentization] Create hlo/translate sub-component (Phase II). This CL takes care of 1. Migrating external projects dependencies from xla/translate --> xla/hlo/translate Phase I takes care of 1. Migrating xla/translate --> xla/hlo/translate 2. Setting up build aliases in xla/translate ensuring external dependencies are still satisfied. PiperOrigin-RevId: 676619763 --- third_party/xla/xla/python/pjrt_ifrt/BUILD | 2 +- third_party/xla/xla/python/pjrt_ifrt/pjrt_executable.cc | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/xla/python/pjrt_ifrt/BUILD b/third_party/xla/xla/python/pjrt_ifrt/BUILD index a87d033d727d04..3bfb4fc90312d4 100644 --- a/third_party/xla/xla/python/pjrt_ifrt/BUILD +++ b/third_party/xla/xla/python/pjrt_ifrt/BUILD @@ -222,6 +222,7 @@ cc_library( "//xla:util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", + "//xla/hlo/translate/mhlo_to_hlo:type_to_shape", "//xla/pjrt:host_callback", "//xla/pjrt:pjrt_client", "//xla/pjrt:pjrt_common", @@ -238,7 +239,6 @@ cc_library( "//xla/python/ifrt:attribute_map", "//xla/python/ifrt/hlo:hlo_program", "//xla/service:hlo_proto_cc", - "//xla/translate/mhlo_to_hlo:type_to_shape", "//xla/tsl/concurrency:ref_count", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", diff --git a/third_party/xla/xla/python/pjrt_ifrt/pjrt_executable.cc b/third_party/xla/xla/python/pjrt_ifrt/pjrt_executable.cc index c19fdece16f028..c665cbb0cdd68e 100644 --- a/third_party/xla/xla/python/pjrt_ifrt/pjrt_executable.cc +++ b/third_party/xla/xla/python/pjrt_ifrt/pjrt_executable.cc @@ -32,6 +32,7 @@ limitations under the License. #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/BuiltinOps.h" #include "xla/hlo/ir/hlo_sharding.h" +#include "xla/hlo/translate/mhlo_to_hlo/type_to_shape.h" #include "xla/pjrt/host_callback.h" #include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/pjrt_executable.h" @@ -57,7 +58,6 @@ limitations under the License. #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/status_macros.h" -#include "xla/translate/mhlo_to_hlo/type_to_shape.h" #include "xla/tsl/concurrency/ref_count.h" #include "xla/util.h" #include "xla/xla_data.pb.h" From 35ce1c48f2165b514333e5351c6a141d17058dad Mon Sep 17 00:00:00 2001 From: Zixuan Jiang Date: Thu, 19 Sep 2024 17:26:29 -0700 Subject: [PATCH 038/483] Cleanup. Remove `index_parallel_in_dim` from `GatherScatterParallelDims` since `indices_parallel_dims` and `index_parallel_in_dim` represents the same stuff. We also remove `IndexAlignedOperandParallelDims(const GatherScatterParallelDims&)` since the `indices_parallel_dims` and `operand_parallel_dims` are aligned when we generating them. PiperOrigin-RevId: 676624119 --- .../auto_sharding/auto_sharding_strategy.cc | 6 +-- .../xla/xla/hlo/utils/hlo_sharding_util.cc | 37 +++++------------- .../xla/xla/hlo/utils/hlo_sharding_util.h | 7 ---- .../xla/xla/service/sharding_propagation.cc | 6 +-- .../xla/service/spmd/spmd_partitioner_util.cc | 38 ++++++++----------- .../xla/xla/service/spmd/spmd_prepare.cc | 4 +- 6 files changed, 30 insertions(+), 68 deletions(-) diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc index 3bfd5e2838cf54..5e05087c2da1cd 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc @@ -174,8 +174,7 @@ void GenerateScatterShardingFromOperands( } absl::InlinedVector aligned_operand_parallel_dims = - hlo_sharding_util::IndexAlignedOperandParallelDims( - *scatter_parallel_dims); + scatter_parallel_dims->operand_parallel_dims; absl::InlinedVector update_parallel_dims = hlo_sharding_util::GetScatterParallelUpdateDims(*scatter, *scatter_parallel_dims); @@ -436,8 +435,7 @@ BuildStrategyAndCost( HloSharding output_spec = indices_to_combine_spec; if (gather_parallel_dims) { auto aligned_operand_parallel_dims = - hlo_sharding_util::IndexAlignedOperandParallelDims( - *gather_parallel_dims); + gather_parallel_dims->operand_parallel_dims; auto output_parallel_dims = hlo_sharding_util::GetGatherParallelOutputDims( *ins, *gather_parallel_dims); diff --git a/third_party/xla/xla/hlo/utils/hlo_sharding_util.cc b/third_party/xla/xla/hlo/utils/hlo_sharding_util.cc index c3cd98219899ea..881744646f0baa 100644 --- a/third_party/xla/xla/hlo/utils/hlo_sharding_util.cc +++ b/third_party/xla/xla/hlo/utils/hlo_sharding_util.cc @@ -1533,7 +1533,7 @@ std::optional GatherOperandShardingFromOutputParallelDimensions( auto output_parallel_dims = GetGatherParallelOutputDims(gather, *parallel_dims); auto output_aligned_operand_parallel_dims = - IndexAlignedOperandParallelDims(*parallel_dims); + parallel_dims->operand_parallel_dims; const Shape gather_shape = gather.shape(); CHECK_EQ(output_parallel_dims.size(), output_aligned_operand_parallel_dims.size()); @@ -1769,7 +1769,7 @@ std::optional ScatterUpdateShardingFromOutputParallelDimensions( auto update_parallel_dims = GetScatterParallelUpdateDims(scatter, *parallel_dims); auto index_aligned_operand_parallel_dims = - IndexAlignedOperandParallelDims(*parallel_dims); + parallel_dims->operand_parallel_dims; auto operand_parallel_dims_sorted = index_aligned_operand_parallel_dims; absl::c_sort(operand_parallel_dims_sorted); auto operand_aligned_update_parallel_dims = AlignSmallContainers( @@ -2247,8 +2247,7 @@ std::optional GetGatherScatterBatchParallelDims( // %indices = concatenate(..., %iota.1, ...) // ... = gather(..., %indices) // is common for tf.reverse_sequence and would match this case. - const int num_indices = index_map.size(); - std::vector index_parallel_in_dim(num_indices, -1); + std::vector index_parallel_in_dim(index_map.size(), -1); // looks through any copies to find the concatenate. auto findConcatenate = [&](const HloInstruction* indices) { @@ -2320,8 +2319,8 @@ std::optional GetGatherScatterBatchParallelDims( } } if (!indices_parallel_dims.empty()) { - return GatherScatterParallelDims{ - indices_parallel_dims, operand_parallel_dims, index_parallel_in_dim}; + return GatherScatterParallelDims{indices_parallel_dims, + operand_parallel_dims}; } return std::nullopt; } @@ -2357,10 +2356,9 @@ std::optional GetScatterParallelBatchDims( static absl::InlinedVector GetGatherOutputOrScatterUpdateParallelDims( - const Shape& shape, const GatherScatterParallelDims& parallel_dim, + const Shape& shape, absl::Span indices_parallel_dims, int64_t index_vector_dim, absl::Span offset_or_window_dims) { absl::InlinedVector output_parallel_dims; - auto indices_parallel_dims = parallel_dim.indices_parallel_dims; for (int64_t indices_parallel_dim : indices_parallel_dims) { for (int i = 0, idx_dim = 0; i < shape.dimensions_size(); ++i) { if (absl::c_linear_search(offset_or_window_dims, i)) { @@ -2385,7 +2383,8 @@ absl::InlinedVector GetGatherParallelOutputDims( int64_t index_vector_dim = dnums.index_vector_dim(); const auto& offset_dims = dnums.offset_dims(); return GetGatherOutputOrScatterUpdateParallelDims( - output_shape, parallel_dim, index_vector_dim, offset_dims); + output_shape, parallel_dim.indices_parallel_dims, index_vector_dim, + offset_dims); } absl::InlinedVector GetScatterParallelUpdateDims( @@ -2397,7 +2396,8 @@ absl::InlinedVector GetScatterParallelUpdateDims( int64_t index_vector_dim = dnums.index_vector_dim(); const auto& window_dims = dnums.update_window_dims(); return GetGatherOutputOrScatterUpdateParallelDims( - update_shape, parallel_dim, index_vector_dim, window_dims); + update_shape, parallel_dim.indices_parallel_dims, index_vector_dim, + window_dims); } absl::InlinedVector GetGatherOperandPassthroughOperandDims( @@ -2537,23 +2537,6 @@ HloSharding InferGatherScatterParallelShardingFromOperandSharding( replicate_non_parallel_dims.metadata()); } -absl::InlinedVector IndexAlignedOperandParallelDims( - const GatherScatterParallelDims& parallel_dims) { - CHECK_EQ(parallel_dims.indices_parallel_dims.size(), - parallel_dims.operand_parallel_dims.size()); - std::vector index_parallel_in_dim = - parallel_dims.index_parallel_in_dim; - // Remove all -1s in `index_parallel_in_dim`. - index_parallel_in_dim.erase(std::remove(index_parallel_in_dim.begin(), - index_parallel_in_dim.end(), -1), - index_parallel_in_dim.end()); - // Populate the operand parallel dimensions based on the order of the index - // batch dims (which is the same order as the output). - return AlignSmallContainers(parallel_dims.operand_parallel_dims, - index_parallel_in_dim, - parallel_dims.indices_parallel_dims); -} - std::string GroupedSharding::ToString() const { auto result = absl::StrCat("group dims: ", absl::StrJoin(group_dims, ","), "\n"); diff --git a/third_party/xla/xla/hlo/utils/hlo_sharding_util.h b/third_party/xla/xla/hlo/utils/hlo_sharding_util.h index fc440f38f3215d..6b6463befb10a8 100644 --- a/third_party/xla/xla/hlo/utils/hlo_sharding_util.h +++ b/third_party/xla/xla/hlo/utils/hlo_sharding_util.h @@ -44,7 +44,6 @@ namespace hlo_sharding_util { struct GatherScatterParallelDims { absl::InlinedVector indices_parallel_dims; absl::InlinedVector operand_parallel_dims; - std::vector index_parallel_in_dim; }; // Determines if the first operand 'potential_subsharding' is a subsharding of @@ -364,12 +363,6 @@ HloSharding InferGatherScatterParallelShardingFromOperandSharding( absl::Span output_aligned_operand_parallel_dims, absl::Span output_parallel_dims); -// Returns the parallel dimensions of the data operand of a gather/scatter with -// the order of the parallel dimensions matching that of the parallel dimensions -// of the indices. -absl::InlinedVector IndexAlignedOperandParallelDims( - const GatherScatterParallelDims& parallel_dims); - // Represents grouping devices in a tiled sharding along certain dimensions. // Elements in group dimensions define different device groups, and the sharding // represents the in-group sharding. diff --git a/third_party/xla/xla/service/sharding_propagation.cc b/third_party/xla/xla/service/sharding_propagation.cc index 316644bf87ea8b..886a3d63a90aed 100644 --- a/third_party/xla/xla/service/sharding_propagation.cc +++ b/third_party/xla/xla/service/sharding_propagation.cc @@ -427,8 +427,7 @@ bool InferGatherParallelShardingFromOperands( bool may_combine_partial_sharding) { CHECK(DynCast(instruction)); bool changed = false; - auto aligned_operand_parallel_dims = - hlo_sharding_util::IndexAlignedOperandParallelDims(parallel_dims); + auto aligned_operand_parallel_dims = parallel_dims.operand_parallel_dims; auto output_parallel_dims = hlo_sharding_util::GetGatherParallelOutputDims( *instruction, parallel_dims); // Infer output sharding from scatter operand sharding. @@ -469,8 +468,7 @@ bool InferScatterParallelShardingFromOperands( auto scatter_indices = scatter->scatter_indices(); auto scatter_updates = scatter->scatter_updates(); bool changed = false; - auto aligned_operand_parallel_dims = - hlo_sharding_util::IndexAlignedOperandParallelDims(parallel_dims); + auto aligned_operand_parallel_dims = parallel_dims.operand_parallel_dims; auto update_parallel_dims = hlo_sharding_util::GetScatterParallelUpdateDims( *instruction, parallel_dims); auto output_parallel_dims = aligned_operand_parallel_dims; diff --git a/third_party/xla/xla/service/spmd/spmd_partitioner_util.cc b/third_party/xla/xla/service/spmd/spmd_partitioner_util.cc index 83c41ec91f840f..5034d168589bf9 100644 --- a/third_party/xla/xla/service/spmd/spmd_partitioner_util.cc +++ b/third_party/xla/xla/service/spmd/spmd_partitioner_util.cc @@ -2142,8 +2142,8 @@ std::optional GatherScatterOperandsShardedAcrossParallelDims( const HloInstruction& operand, const HloInstruction& indices, const hlo_sharding_util::GatherScatterParallelDims& parallel_dims) { - auto& indices_parallel_dims = parallel_dims.indices_parallel_dims; - auto& operand_parallel_dims = parallel_dims.operand_parallel_dims; + const auto& indices_parallel_dims = parallel_dims.indices_parallel_dims; + const auto& operand_parallel_dims = parallel_dims.operand_parallel_dims; if (indices_parallel_dims.size() != operand_parallel_dims.size()) { return std::nullopt; } @@ -2154,32 +2154,26 @@ GatherScatterOperandsShardedAcrossParallelDims( if (idx_parallel_tiles_num == 1 && op_parallel_tiles_num == 1) { return std::nullopt; } - absl::InlinedVector indices_parallel_dims_ordered_as_operand; - for (int idx : parallel_dims.index_parallel_in_dim) { - if (idx != -1) { - indices_parallel_dims_ordered_as_operand.push_back(idx); - } - } + if (new_index_shard.IsReplicated()) { return GatherScatterParallelDimSharding{ CreateMatchingShardingOnDims(indices.shape(), new_operand_shard, - indices_parallel_dims_ordered_as_operand, + indices_parallel_dims, operand_parallel_dims), new_operand_shard}; } if (new_operand_shard.IsReplicated()) { return GatherScatterParallelDimSharding{ - new_index_shard, - CreateMatchingShardingOnDims(operand.shape(), new_index_shard, - operand_parallel_dims, - indices_parallel_dims_ordered_as_operand)}; + new_index_shard, CreateMatchingShardingOnDims( + operand.shape(), new_index_shard, + operand_parallel_dims, indices_parallel_dims)}; } // Parallel dimension distribution needs to be the same, so try to steal // sharding from partial replication to compensate. if (idx_parallel_tiles_num != op_parallel_tiles_num) { auto to_adjust_dims = operand_parallel_dims; - auto target_dims = indices_parallel_dims_ordered_as_operand; + auto target_dims = indices_parallel_dims; HloSharding* target = &new_index_shard; HloSharding* to_adjust = &new_operand_shard; if (idx_parallel_tiles_num < op_parallel_tiles_num) { @@ -2231,19 +2225,17 @@ GatherScatterOperandsShardedAcrossParallelDims( std::vector operand_shard_tile_dims( new_operand_shard.tile_assignment().dimensions().begin(), new_operand_shard.tile_assignment().dimensions().end()); - for (int i = 0; i < indices_parallel_dims_ordered_as_operand.size(); ++i) { + for (int i = 0; i < indices_parallel_dims.size(); ++i) { operand_shard_tile_dims[operand_parallel_dims[i]] = - new_index_shard.tile_assignment().dim( - indices_parallel_dims_ordered_as_operand[i]); + new_index_shard.tile_assignment().dim(indices_parallel_dims[i]); } auto operand_shard_tiles = new_operand_shard.tile_assignment().Reshape(operand_shard_tile_dims); - new_operand_shard = - AlignShardingOnDims(new_operand_shard.ReplicateOnLastTileDim() - ? HloSharding::PartialTile(operand_shard_tiles) - : HloSharding::Tile(operand_shard_tiles), - operand_parallel_dims, new_index_shard, - indices_parallel_dims_ordered_as_operand); + new_operand_shard = AlignShardingOnDims( + new_operand_shard.ReplicateOnLastTileDim() + ? HloSharding::PartialTile(operand_shard_tiles) + : HloSharding::Tile(operand_shard_tiles), + operand_parallel_dims, new_index_shard, indices_parallel_dims); return GatherScatterParallelDimSharding{new_index_shard, new_operand_shard}; } diff --git a/third_party/xla/xla/service/spmd/spmd_prepare.cc b/third_party/xla/xla/service/spmd/spmd_prepare.cc index 51655b90861d48..83bf8495cb18ca 100644 --- a/third_party/xla/xla/service/spmd/spmd_prepare.cc +++ b/third_party/xla/xla/service/spmd/spmd_prepare.cc @@ -108,9 +108,7 @@ absl::StatusOr ProcessScatter(HloInstruction* hlo, if (lhs_parallel_dims->operand_parallel_dims != rhs_parallel_dims->operand_parallel_dims || lhs_parallel_dims->indices_parallel_dims != - rhs_parallel_dims->indices_parallel_dims || - lhs_parallel_dims->index_parallel_in_dim != - rhs_parallel_dims->index_parallel_in_dim) { + rhs_parallel_dims->indices_parallel_dims) { return false; } if (lhs_parallel_dims->operand_parallel_dims.size() != From db75103e284795e468160ca86442a7b14f9bd1d4 Mon Sep 17 00:00:00 2001 From: Hyeontaek Lim Date: Thu, 19 Sep 2024 17:51:22 -0700 Subject: [PATCH 039/483] [PjRt-IFRT] Refactor the functions converting IFRT DType and xla::PrimitiveType IFRT-XLA conversion functions `xla::ifrt::ToPrimitiveType()` and `xla::ifrt::ToDType()` are pulled out of `pjrt_array.{h,cc}` into `pjrt_dtype.{h,cc}` that have no other PjRt-IFRT dependencies. This makes it easier to use these conversion functions in subsequent CLs without pulling the dependency to the entire PjRt-IFRT. To make migration easy, `pjrt_array.h` includes `pjrt_dtype.h` to allow the functions to be transitively defined. The downstream user code will be migrated incrementally to use `pjrt_dtype.h` directly before this include is removed. PiperOrigin-RevId: 676631522 --- third_party/xla/xla/python/pjrt_ifrt/BUILD | 14 +++ .../xla/xla/python/pjrt_ifrt/pjrt_array.cc | 76 ------------- .../xla/xla/python/pjrt_ifrt/pjrt_array.h | 7 +- .../xla/xla/python/pjrt_ifrt/pjrt_dtype.cc | 103 ++++++++++++++++++ .../xla/xla/python/pjrt_ifrt/pjrt_dtype.h | 35 ++++++ 5 files changed, 153 insertions(+), 82 deletions(-) create mode 100644 third_party/xla/xla/python/pjrt_ifrt/pjrt_dtype.cc create mode 100644 third_party/xla/xla/python/pjrt_ifrt/pjrt_dtype.h diff --git a/third_party/xla/xla/python/pjrt_ifrt/BUILD b/third_party/xla/xla/python/pjrt_ifrt/BUILD index 3bfb4fc90312d4..ed0107e97be201 100644 --- a/third_party/xla/xla/python/pjrt_ifrt/BUILD +++ b/third_party/xla/xla/python/pjrt_ifrt/BUILD @@ -215,6 +215,7 @@ cc_library( deps = [ ":basic_string_array", ":pjrt_attribute_map_util", + ":pjrt_dtype", ":xla_ifrt", "//xla:literal", "//xla:shape_util", @@ -300,6 +301,19 @@ xla_cc_test( ], ) +cc_library( + name = "pjrt_dtype", + srcs = ["pjrt_dtype.cc"], + hdrs = ["pjrt_dtype.h"], + compatible_with = get_compatible_with_portable(), + deps = [ + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/python/ifrt", + "@com_google_absl//absl/status:statusor", + ], +) + cc_library( name = "basic_string_array", srcs = ["basic_string_array.cc"], diff --git a/third_party/xla/xla/python/pjrt_ifrt/pjrt_array.cc b/third_party/xla/xla/python/pjrt_ifrt/pjrt_array.cc index 4b0949165e75fe..88a0d858d33b97 100644 --- a/third_party/xla/xla/python/pjrt_ifrt/pjrt_array.cc +++ b/third_party/xla/xla/python/pjrt_ifrt/pjrt_array.cc @@ -130,82 +130,6 @@ absl::StatusOr GetMemoryKindFromPjRtBuffers( char PjRtCompatibleArray::ID = 0; char PjRtArray::ID = 0; -absl::StatusOr ToPrimitiveType(DType dtype) { - switch (dtype.kind()) { -#define CASE(DT, PT) \ - case DT: \ - static_assert(PT == \ - static_cast(static_cast(DT))); \ - return PT - CASE(DType::kInvalid, xla::PrimitiveType::PRIMITIVE_TYPE_INVALID); - CASE(DType::kPred, xla::PrimitiveType::PRED); - CASE(DType::kS2, xla::PrimitiveType::S2); - CASE(DType::kS4, xla::PrimitiveType::S4); - CASE(DType::kS8, xla::PrimitiveType::S8); - CASE(DType::kS16, xla::PrimitiveType::S16); - CASE(DType::kS32, xla::PrimitiveType::S32); - CASE(DType::kS64, xla::PrimitiveType::S64); - CASE(DType::kU2, xla::PrimitiveType::U2); - CASE(DType::kU4, xla::PrimitiveType::U4); - CASE(DType::kU8, xla::PrimitiveType::U8); - CASE(DType::kU16, xla::PrimitiveType::U16); - CASE(DType::kU32, xla::PrimitiveType::U32); - CASE(DType::kU64, xla::PrimitiveType::U64); - CASE(DType::kF8E4M3FN, xla::PrimitiveType::F8E4M3FN); - CASE(DType::kF8E4M3B11FNUZ, xla::PrimitiveType::F8E4M3B11FNUZ); - CASE(DType::kF8E4M3FNUZ, xla::PrimitiveType::F8E4M3FNUZ); - CASE(DType::kF8E5M2, xla::PrimitiveType::F8E5M2); - CASE(DType::kF8E5M2FNUZ, xla::PrimitiveType::F8E5M2FNUZ); - CASE(DType::kF16, xla::PrimitiveType::F16); - CASE(DType::kF32, xla::PrimitiveType::F32); - CASE(DType::kBF16, xla::PrimitiveType::BF16); - CASE(DType::kF64, xla::PrimitiveType::F64); - CASE(DType::kC64, xla::PrimitiveType::C64); - CASE(DType::kC128, xla::PrimitiveType::C128); - CASE(DType::kToken, xla::PrimitiveType::TOKEN); -#undef CASE - case DType::kString: - return InvalidArgument("Not supported as XLA PrimitiveType: %d", - static_cast(dtype.kind())); - } - return InvalidArgument("Invalid DType: %d", static_cast(dtype.kind())); -} - -absl::StatusOr ToDType(xla::PrimitiveType primitive_type) { - switch (primitive_type) { - case xla::PrimitiveType::PRIMITIVE_TYPE_INVALID: - case xla::PrimitiveType::PRED: - case xla::PrimitiveType::S2: - case xla::PrimitiveType::S4: - case xla::PrimitiveType::S8: - case xla::PrimitiveType::S16: - case xla::PrimitiveType::S32: - case xla::PrimitiveType::S64: - case xla::PrimitiveType::U2: - case xla::PrimitiveType::U4: - case xla::PrimitiveType::U8: - case xla::PrimitiveType::U16: - case xla::PrimitiveType::U32: - case xla::PrimitiveType::U64: - case xla::PrimitiveType::F8E4M3FN: - case xla::PrimitiveType::F8E4M3B11FNUZ: - case xla::PrimitiveType::F8E4M3FNUZ: - case xla::PrimitiveType::F8E5M2: - case xla::PrimitiveType::F8E5M2FNUZ: - case xla::PrimitiveType::F16: - case xla::PrimitiveType::F32: - case xla::PrimitiveType::BF16: - case xla::PrimitiveType::F64: - case xla::PrimitiveType::C64: - case xla::PrimitiveType::C128: - case xla::PrimitiveType::TOKEN: - return DType(static_cast(static_cast(primitive_type))); - default: - return InvalidArgument("Invalid XLA PrimitiveType: %d", - static_cast(primitive_type)); - } -} - MemoryKind MakeMemoryKindFromPjRtBuffer(PjRtBuffer* pjrt_buffer) { if (pjrt_buffer->memory_space() == nullptr) { return MemoryKind(); diff --git a/third_party/xla/xla/python/pjrt_ifrt/pjrt_array.h b/third_party/xla/xla/python/pjrt_ifrt/pjrt_array.h index f91de6531c234e..a48156d9ffd61b 100644 --- a/third_party/xla/xla/python/pjrt_ifrt/pjrt_array.h +++ b/third_party/xla/xla/python/pjrt_ifrt/pjrt_array.h @@ -39,17 +39,12 @@ limitations under the License. #include "xla/python/ifrt/shape.h" #include "xla/python/ifrt/sharding.h" #include "xla/python/pjrt_ifrt/pjrt_client.h" +#include "xla/python/pjrt_ifrt/pjrt_dtype.h" // IWYU pragma: keep // TODO(hyeontaek): Remove this include once downstream users are migrated to use the new header directly. #include "xla/tsl/concurrency/ref_count.h" namespace xla { namespace ifrt { -// Converts IFRT `DType` into `xla::PrimitiveType`. -absl::StatusOr ToPrimitiveType(DType dtype); - -// Converts `xla::PrimitiveType` into IFRT `DType`. -absl::StatusOr ToDType(xla::PrimitiveType primitive_type); - // Creates IFRT `MemoryKind` from an XLA `PjRtBuffer`. MemoryKind MakeMemoryKindFromPjRtBuffer(PjRtBuffer* pjrt_buffer); diff --git a/third_party/xla/xla/python/pjrt_ifrt/pjrt_dtype.cc b/third_party/xla/xla/python/pjrt_ifrt/pjrt_dtype.cc new file mode 100644 index 00000000000000..36d492f27569a9 --- /dev/null +++ b/third_party/xla/xla/python/pjrt_ifrt/pjrt_dtype.cc @@ -0,0 +1,103 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/python/pjrt_ifrt/pjrt_dtype.h" + +#include "absl/status/statusor.h" +#include "xla/python/ifrt/dtype.h" +#include "xla/util.h" +#include "xla/xla_data.pb.h" + +namespace xla { +namespace ifrt { + +absl::StatusOr ToPrimitiveType(DType dtype) { + switch (dtype.kind()) { +#define CASE(DT, PT) \ + case DT: \ + static_assert(PT == \ + static_cast(static_cast(DT))); \ + return PT + CASE(DType::kInvalid, xla::PrimitiveType::PRIMITIVE_TYPE_INVALID); + CASE(DType::kPred, xla::PrimitiveType::PRED); + CASE(DType::kS2, xla::PrimitiveType::S2); + CASE(DType::kS4, xla::PrimitiveType::S4); + CASE(DType::kS8, xla::PrimitiveType::S8); + CASE(DType::kS16, xla::PrimitiveType::S16); + CASE(DType::kS32, xla::PrimitiveType::S32); + CASE(DType::kS64, xla::PrimitiveType::S64); + CASE(DType::kU2, xla::PrimitiveType::U2); + CASE(DType::kU4, xla::PrimitiveType::U4); + CASE(DType::kU8, xla::PrimitiveType::U8); + CASE(DType::kU16, xla::PrimitiveType::U16); + CASE(DType::kU32, xla::PrimitiveType::U32); + CASE(DType::kU64, xla::PrimitiveType::U64); + CASE(DType::kF8E4M3FN, xla::PrimitiveType::F8E4M3FN); + CASE(DType::kF8E4M3B11FNUZ, xla::PrimitiveType::F8E4M3B11FNUZ); + CASE(DType::kF8E4M3FNUZ, xla::PrimitiveType::F8E4M3FNUZ); + CASE(DType::kF8E5M2, xla::PrimitiveType::F8E5M2); + CASE(DType::kF8E5M2FNUZ, xla::PrimitiveType::F8E5M2FNUZ); + CASE(DType::kF16, xla::PrimitiveType::F16); + CASE(DType::kF32, xla::PrimitiveType::F32); + CASE(DType::kBF16, xla::PrimitiveType::BF16); + CASE(DType::kF64, xla::PrimitiveType::F64); + CASE(DType::kC64, xla::PrimitiveType::C64); + CASE(DType::kC128, xla::PrimitiveType::C128); + CASE(DType::kToken, xla::PrimitiveType::TOKEN); +#undef CASE + case DType::kString: + return InvalidArgument("Not supported as XLA PrimitiveType: %d", + static_cast(dtype.kind())); + } + return InvalidArgument("Invalid DType: %d", static_cast(dtype.kind())); +} + +absl::StatusOr ToDType(xla::PrimitiveType primitive_type) { + switch (primitive_type) { + case xla::PrimitiveType::PRIMITIVE_TYPE_INVALID: + case xla::PrimitiveType::PRED: + case xla::PrimitiveType::S2: + case xla::PrimitiveType::S4: + case xla::PrimitiveType::S8: + case xla::PrimitiveType::S16: + case xla::PrimitiveType::S32: + case xla::PrimitiveType::S64: + case xla::PrimitiveType::U2: + case xla::PrimitiveType::U4: + case xla::PrimitiveType::U8: + case xla::PrimitiveType::U16: + case xla::PrimitiveType::U32: + case xla::PrimitiveType::U64: + case xla::PrimitiveType::F8E4M3FN: + case xla::PrimitiveType::F8E4M3B11FNUZ: + case xla::PrimitiveType::F8E4M3FNUZ: + case xla::PrimitiveType::F8E5M2: + case xla::PrimitiveType::F8E5M2FNUZ: + case xla::PrimitiveType::F16: + case xla::PrimitiveType::F32: + case xla::PrimitiveType::BF16: + case xla::PrimitiveType::F64: + case xla::PrimitiveType::C64: + case xla::PrimitiveType::C128: + case xla::PrimitiveType::TOKEN: + return DType(static_cast(static_cast(primitive_type))); + default: + return InvalidArgument("Invalid XLA PrimitiveType: %d", + static_cast(primitive_type)); + } +} + +} // namespace ifrt +} // namespace xla diff --git a/third_party/xla/xla/python/pjrt_ifrt/pjrt_dtype.h b/third_party/xla/xla/python/pjrt_ifrt/pjrt_dtype.h new file mode 100644 index 00000000000000..f0ace0292e82dd --- /dev/null +++ b/third_party/xla/xla/python/pjrt_ifrt/pjrt_dtype.h @@ -0,0 +1,35 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_PYTHON_PJRT_IFRT_PJRT_DTYPE_H_ +#define XLA_PYTHON_PJRT_IFRT_PJRT_DTYPE_H_ + +#include "absl/status/statusor.h" +#include "xla/python/ifrt/dtype.h" +#include "xla/xla_data.pb.h" + +namespace xla { +namespace ifrt { + +// Converts IFRT `DType` into `xla::PrimitiveType`. +absl::StatusOr ToPrimitiveType(DType dtype); + +// Converts `xla::PrimitiveType` into IFRT `DType`. +absl::StatusOr ToDType(xla::PrimitiveType primitive_type); + +} // namespace ifrt +} // namespace xla + +#endif // XLA_PYTHON_PJRT_IFRT_PJRT_DTYPE_H_ From f2aa4902471e5ec86619147b3e0aaf28001ba6ad Mon Sep 17 00:00:00 2001 From: Sandeep Dasgupta Date: Thu, 19 Sep 2024 18:05:34 -0700 Subject: [PATCH 040/483] [HLO Componentization] Create hlo/translate sub-component (Phase II). This CL takes care of 1. Migrating external projects dependencies from xla/translate --> xla/hlo/translate Phase I takes care of 1. Migrating xla/translate --> xla/hlo/translate 2. Setting up build aliases in xla/translate ensuring external dependencies are still satisfied. PiperOrigin-RevId: 676635487 --- tensorflow/compiler/mlir/quantization/stablehlo/BUILD | 2 +- .../stablehlo/passes/bridge/convert_tf_quant_ops_to_mhlo.cc | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/BUILD b/tensorflow/compiler/mlir/quantization/stablehlo/BUILD index 6d1f4ed5abe4ed..b95ec31e959b47 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/BUILD +++ b/tensorflow/compiler/mlir/quantization/stablehlo/BUILD @@ -340,9 +340,9 @@ cc_library( "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", "@local_xla//xla:xla_data_proto_cc", + "@local_xla//xla/hlo/translate/hlo_to_mhlo:attribute_importer", "@local_xla//xla/mlir_hlo", "@local_xla//xla/mlir_hlo:mhlo_passes", - "@local_xla//xla/translate/hlo_to_mhlo:attribute_importer", "@stablehlo//:chlo_ops", "@stablehlo//:stablehlo_passes", ], diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_tf_quant_ops_to_mhlo.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_tf_quant_ops_to_mhlo.cc index 5575a7516fccc9..527cec44d284f9 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_tf_quant_ops_to_mhlo.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_tf_quant_ops_to_mhlo.cc @@ -49,8 +49,8 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" #include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h" #include "tensorflow/compiler/mlir/tf2xla/transforms/utils.h" +#include "xla/hlo/translate/hlo_to_mhlo/attribute_importer.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" -#include "xla/translate/hlo_to_mhlo/attribute_importer.h" #include "xla/xla_data.pb.h" #include "tensorflow/core/framework/numeric_types.h" #include "tensorflow/core/framework/tensor.h" From 1f4cf24791a0f832a306154f205576ab5cef02dc Mon Sep 17 00:00:00 2001 From: Subhankar Shah Date: Thu, 19 Sep 2024 19:04:15 -0700 Subject: [PATCH 041/483] [XLA:TPU] Make minor code simplification in MSA algorithm. PiperOrigin-RevId: 676650005 --- .../xla/service/memory_space_assignment/BUILD | 1 + .../memory_space_assignment/algorithm.cc | 19 +++++++------------ .../memory_space_assignment/algorithm.h | 3 --- 3 files changed, 8 insertions(+), 15 deletions(-) diff --git a/third_party/xla/xla/service/memory_space_assignment/BUILD b/third_party/xla/xla/service/memory_space_assignment/BUILD index a9505ecadc02e2..9a16ff01017cf6 100644 --- a/third_party/xla/xla/service/memory_space_assignment/BUILD +++ b/third_party/xla/xla/service/memory_space_assignment/BUILD @@ -519,6 +519,7 @@ cc_library( "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", diff --git a/third_party/xla/xla/service/memory_space_assignment/algorithm.cc b/third_party/xla/xla/service/memory_space_assignment/algorithm.cc index 530c17775b064a..7a5ff4073692c5 100644 --- a/third_party/xla/xla/service/memory_space_assignment/algorithm.cc +++ b/third_party/xla/xla/service/memory_space_assignment/algorithm.cc @@ -30,7 +30,6 @@ limitations under the License. #include #include #include -#include #include #include #include @@ -40,6 +39,7 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/functional/any_invocable.h" #include "absl/log/check.h" +#include "absl/log/log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" @@ -960,14 +960,6 @@ absl::Status MsaAlgorithm::OptimizeMemoryBoundLoop(int loop_start_idx, hlo_live_range_, alias_analysis_, options_)); optimizer->Optimize(); - const int loop_optimized_allocations_original_size = - loop_optimized_allocations_.size(); - for (MemoryBoundLoopOptimizer::LoopValue& value : optimizer->loop_values()) { - if (!value.allocations.empty() && value.IsAllocationTypeSupported()) { - loop_optimized_allocations_.push_back(std::move(value.allocations)); - } - } - // Check if this unrolled loop is in a while loop. const auto& instruction_sequence = hlo_live_range_.flattened_instruction_sequence().instructions(); @@ -978,9 +970,12 @@ absl::Status MsaAlgorithm::OptimizeMemoryBoundLoop(int loop_start_idx, // Update the loop_optimized_allocations_map_ with the output of the // optimizer. - for (int i = loop_optimized_allocations_original_size; - i < loop_optimized_allocations_.size(); ++i) { - const AllocationSequence& sequence = loop_optimized_allocations_.at(i); + for (MemoryBoundLoopOptimizer::LoopValue& value : optimizer->loop_values()) { + if (value.allocations.empty() || !value.IsAllocationTypeSupported()) { + continue; + } + loop_optimized_allocations_.push_back(std::move(value.allocations)); + const AllocationSequence& sequence = loop_optimized_allocations_.back(); CHECK(!sequence.empty()); VLOG(3) << " alloc: " << sequence.back()->ToString(); for (const auto& allocation : sequence) { diff --git a/third_party/xla/xla/service/memory_space_assignment/algorithm.h b/third_party/xla/xla/service/memory_space_assignment/algorithm.h index 52d0f0ee563747..1cfcf1f6094938 100644 --- a/third_party/xla/xla/service/memory_space_assignment/algorithm.h +++ b/third_party/xla/xla/service/memory_space_assignment/algorithm.h @@ -25,7 +25,6 @@ limitations under the License. #include #include #include -#include #include #include #include @@ -36,13 +35,11 @@ limitations under the License. #endif #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" -#include "absl/functional/any_invocable.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/types/span.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/utils/hlo_live_range.h" -#include "xla/service/buffer_value.h" #include "xla/service/call_graph.h" #include "xla/service/heap_simulator/allocation_block.h" #include "xla/service/heap_simulator/heap_simulator.h" From ffa633c73420e77a63ef84758b1c16b19168dbfb Mon Sep 17 00:00:00 2001 From: Sandeep Dasgupta Date: Thu, 19 Sep 2024 19:42:27 -0700 Subject: [PATCH 042/483] [HLO Componentization] Create hlo/translate sub-component (Phase II). This CL takes care of 1. Migrating external projects dependencies from xla/translate --> xla/hlo/translate Phase I takes care of 1. Migrating xla/translate --> xla/hlo/translate 2. Setting up build aliases in xla/translate ensuring external dependencies are still satisfied. PiperOrigin-RevId: 676658784 --- third_party/xla/xla/service/spmd/shardy/BUILD | 4 ++-- .../xla/xla/service/spmd/shardy/mhlo_round_trip/BUILD | 5 ++--- .../service/spmd/shardy/mhlo_round_trip/export_shardings.cc | 2 +- .../xla/service/spmd/shardy/mhlo_round_trip/mhlo_import.cc | 2 +- .../xla/service/spmd/shardy/sdy_round_trip/test_utils/BUILD | 4 ++-- .../shardy/sdy_round_trip/test_utils/mhlo_to_hlo_to_mhlo.cc | 4 ++-- third_party/xla/xla/service/spmd/shardy/shardy_xla_pass.cc | 4 ++-- 7 files changed, 12 insertions(+), 13 deletions(-) diff --git a/third_party/xla/xla/service/spmd/shardy/BUILD b/third_party/xla/xla/service/spmd/shardy/BUILD index b5d040c5562195..fb0b39c4195ace 100644 --- a/third_party/xla/xla/service/spmd/shardy/BUILD +++ b/third_party/xla/xla/service/spmd/shardy/BUILD @@ -60,6 +60,8 @@ cc_library( "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", + "//xla/hlo/translate/hlo_to_mhlo:hlo_to_mlir_hlo", + "//xla/hlo/translate/mhlo_to_hlo:mlir_hlo_to_hlo", "//xla/hlo/utils:hlo_sharding_util", "//xla/mlir_hlo:mhlo_passes", "//xla/service:computation_layout", @@ -70,8 +72,6 @@ cc_library( "//xla/service/spmd/shardy/mhlo_round_trip:mhlo_export", "//xla/service/spmd/shardy/mhlo_round_trip:mhlo_import", "//xla/service/spmd/shardy/sdy_round_trip:pipelines", - "//xla/translate/hlo_to_mhlo:hlo_to_mlir_hlo", - "//xla/translate/mhlo_to_hlo:mlir_hlo_to_hlo", "//xla/tsl/framework/mlir:status_scoped_diagnostic_handler", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", diff --git a/third_party/xla/xla/service/spmd/shardy/mhlo_round_trip/BUILD b/third_party/xla/xla/service/spmd/shardy/mhlo_round_trip/BUILD index b8ab5769b90267..9dfe3b597e00e0 100644 --- a/third_party/xla/xla/service/spmd/shardy/mhlo_round_trip/BUILD +++ b/third_party/xla/xla/service/spmd/shardy/mhlo_round_trip/BUILD @@ -26,11 +26,10 @@ cc_library( "//xla:array", "//xla:shape_util", "//xla/hlo/ir:hlo", + "//xla/hlo/translate/mhlo_to_hlo:type_to_shape", "//xla/mlir_hlo", "//xla/service/spmd/shardy:constants", - "//xla/translate/mhlo_to_hlo:type_to_shape", "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/log:check", "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", @@ -103,11 +102,11 @@ cc_library( "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/ir:tile_assignment", + "//xla/hlo/translate/mhlo_to_hlo:attribute_exporter", "//xla/mlir_hlo", "//xla/mlir_hlo:mhlo_passes", "//xla/service/spmd/shardy:constants", "//xla/service/spmd/shardy/round_trip_common:pipeline_passes", - "//xla/translate/mhlo_to_hlo:attribute_exporter", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log:check", diff --git a/third_party/xla/xla/service/spmd/shardy/mhlo_round_trip/export_shardings.cc b/third_party/xla/xla/service/spmd/shardy/mhlo_round_trip/export_shardings.cc index 3ab00020fc21d8..2c3578998e69d5 100644 --- a/third_party/xla/xla/service/spmd/shardy/mhlo_round_trip/export_shardings.cc +++ b/third_party/xla/xla/service/spmd/shardy/mhlo_round_trip/export_shardings.cc @@ -58,11 +58,11 @@ limitations under the License. #include "shardy/dialect/sdy/ir/utils.h" #include "xla/array.h" #include "xla/hlo/ir/hlo_sharding.h" +#include "xla/hlo/translate/mhlo_to_hlo/type_to_shape.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/service/spmd/shardy/constants.h" #include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/translate/mhlo_to_hlo/type_to_shape.h" namespace xla { namespace sdy { diff --git a/third_party/xla/xla/service/spmd/shardy/mhlo_round_trip/mhlo_import.cc b/third_party/xla/xla/service/spmd/shardy/mhlo_round_trip/mhlo_import.cc index 50177a1268f9e0..e80d3cc285ac47 100644 --- a/third_party/xla/xla/service/spmd/shardy/mhlo_round_trip/mhlo_import.cc +++ b/third_party/xla/xla/service/spmd/shardy/mhlo_round_trip/mhlo_import.cc @@ -58,12 +58,12 @@ limitations under the License. #include "shardy/dialect/sdy/ir/dialect.h" #include "xla/hlo/ir/hlo_sharding.h" #include "xla/hlo/ir/tile_assignment.h" +#include "xla/hlo/translate/mhlo_to_hlo/attribute_exporter.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/mlir_hlo/mhlo/transforms/passes.h" #include "xla/service/spmd/shardy/constants.h" #include "xla/service/spmd/shardy/mhlo_round_trip/shard_map_import.h" #include "xla/service/spmd/shardy/round_trip_common/pipeline_passes.h" -#include "xla/translate/mhlo_to_hlo/attribute_exporter.h" #include "xla/util.h" #include "xla/xla_data.pb.h" diff --git a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/test_utils/BUILD b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/test_utils/BUILD index 448496e4e6de84..75974ab9c50c87 100644 --- a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/test_utils/BUILD +++ b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/test_utils/BUILD @@ -22,12 +22,12 @@ cc_library( deps = [ "//xla:shape_util", "//xla/hlo/ir:hlo", + "//xla/hlo/translate/hlo_to_mhlo:hlo_to_mlir_hlo", + "//xla/hlo/translate/mhlo_to_hlo:mlir_hlo_to_hlo", "//xla/mlir_hlo", "//xla/mlir_hlo:mhlo_passes", "//xla/service:hlo_module_config", "//xla/service:hlo_proto_cc", - "//xla/translate/hlo_to_mhlo:hlo_to_mlir_hlo", - "//xla/translate/mhlo_to_hlo:mlir_hlo_to_hlo", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", diff --git a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/test_utils/mhlo_to_hlo_to_mhlo.cc b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/test_utils/mhlo_to_hlo_to_mhlo.cc index b9c55aebcdbf6b..8c8692d34ef409 100644 --- a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/test_utils/mhlo_to_hlo_to_mhlo.cc +++ b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/test_utils/mhlo_to_hlo_to_mhlo.cc @@ -36,13 +36,13 @@ limitations under the License. #include "shardy/dialect/sdy/ir/dialect.h" #include "stablehlo/dialect/StablehloOps.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h" +#include "xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/mlir_hlo/mhlo/transforms/passes.h" #include "xla/service/hlo.pb.h" #include "xla/service/hlo_module_config.h" #include "xla/shape.h" -#include "xla/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h" -#include "xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h" #include "tsl/platform/errors.h" namespace xla { diff --git a/third_party/xla/xla/service/spmd/shardy/shardy_xla_pass.cc b/third_party/xla/xla/service/spmd/shardy/shardy_xla_pass.cc index c5c8997a6d51e2..19ed5d9a95292e 100644 --- a/third_party/xla/xla/service/spmd/shardy/shardy_xla_pass.cc +++ b/third_party/xla/xla/service/spmd/shardy/shardy_xla_pass.cc @@ -45,6 +45,8 @@ limitations under the License. #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_input_output_alias_config.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h" +#include "xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h" #include "xla/hlo/utils/hlo_sharding_util.h" #include "xla/layout.h" #include "xla/map_util.h" @@ -62,8 +64,6 @@ limitations under the License. #include "xla/shape.h" #include "xla/shape_layout.h" #include "xla/shape_util.h" -#include "xla/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h" -#include "xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h" #include "xla/tsl/framework/mlir/status_scoped_diagnostic_handler.h" #include "xla/util.h" #include "xla/xla_data.pb.h" From d5deb406cfb4386c78ec26948451c665a62114fa Mon Sep 17 00:00:00 2001 From: Subhankar Shah Date: Thu, 19 Sep 2024 19:54:12 -0700 Subject: [PATCH 043/483] [XLA:TPU] Use a struct instead of a std::pair for EvenOddChunkPair. PiperOrigin-RevId: 676660764 --- .../memory_bound_loop_optimizer.cc | 50 +++++++++---------- .../memory_bound_loop_optimizer.h | 14 +++--- .../memory_bound_loop_optimizer_test.cc | 13 +++-- 3 files changed, 38 insertions(+), 39 deletions(-) diff --git a/third_party/xla/xla/service/memory_space_assignment/memory_bound_loop_optimizer.cc b/third_party/xla/xla/service/memory_space_assignment/memory_bound_loop_optimizer.cc index 0cfbc0db37f178..11c35e74b21309 100644 --- a/third_party/xla/xla/service/memory_space_assignment/memory_bound_loop_optimizer.cc +++ b/third_party/xla/xla/service/memory_space_assignment/memory_bound_loop_optimizer.cc @@ -69,12 +69,12 @@ struct LoopOptimizerChunkInterval { EvenOddChunkPair chunks; std::string ToString() const { - CHECK(chunks.first.has_value() && chunks.second.has_value()); + CHECK(chunks.HasValues()); return absl::StrFormat( "begin_idx_in_loop: %d, end_idx_in_loop: %d, even chunk: %s, odd " "chunk: %s", - begin_idx_in_loop, end_idx_in_loop, chunks.first->ToString(), - chunks.second->ToString()); + begin_idx_in_loop, end_idx_in_loop, chunks.even_chunk->ToString(), + chunks.odd_chunk->ToString()); } }; @@ -417,10 +417,9 @@ absl::Status MemoryBoundLoopOptimizer::Initialize() { // Chunks for reserved scoped memory should always be found at offset 0. EvenOddChunkPair chunks = heap_.AllocateEvenAndOddBetween( i, i, reserved_memory, /*preferred_offsets=*/{0, 0}); - CHECK(chunks.first.has_value()); - CHECK(chunks.second.has_value()); - CHECK(chunks.first->size == reserved_memory); - VLOG(3) << "Reserved chunk: " << chunks.first->ToString() + CHECK(chunks.HasValues()); + CHECK(chunks.even_chunk->size == reserved_memory); + VLOG(3) << "Reserved chunk: " << chunks.even_chunk->ToString() << " loop index: " << i; } @@ -840,9 +839,11 @@ std::string MemoryBoundLoopOptimizer::LoopValue::ToString() const { absl::StrAppend(&allocations_str, "\n - ", allocation->ToString()); } std::string chunk_str; - if (HasEvenAndOddChunks()) { - absl::StrAppend(&chunk_str, "\n", "even chunk: ", chunks.first->ToString()); - absl::StrAppend(&chunk_str, "\n", "odd chunk: ", chunks.second->ToString()); + if (chunks.HasValues()) { + absl::StrAppend(&chunk_str, "\n", + "even chunk: ", chunks.even_chunk->ToString()); + absl::StrAppend(&chunk_str, "\n", + "odd chunk: ", chunks.odd_chunk->ToString()); absl::StrAppend(&chunk_str, "\n", "alternate memory begin idx in loop: ", alternate_memory_begin_idx_in_loop.value()); absl::StrAppend(&chunk_str, "\n", "alternate memory end idx in loop: ", @@ -861,10 +862,6 @@ bool MemoryBoundLoopOptimizer::LoopValue::IsAllocationTypeSupported() const { allocation_type == AllocationType::kPrefetch; } -bool MemoryBoundLoopOptimizer::LoopValue::HasEvenAndOddChunks() const { - return chunks.first.has_value() && chunks.second.has_value(); -} - void MemoryBoundLoopOptimizer::LoopValue::SetChunkPairAndInterval( EvenOddChunkPair chunk_pair, int64_t begin_idx_in_loop, int64_t end_idx_in_loop) { @@ -978,7 +975,7 @@ bool MemoryBoundLoopOptimizer::AllocateTemporary(LoopValue& value) { } EvenOddChunkPair chunks = heap_.AllocateSameEvenAndOddBetween( begin_idx_in_loop, end_idx_in_loop, value.size); - if (!chunks.first.has_value() || !chunks.second.has_value()) { + if (!chunks.HasValues()) { VLOG(3) << "Could not find Allocation for temporary value: " << value.ToString(); return false; @@ -1001,7 +998,7 @@ bool MemoryBoundLoopOptimizer::AllocatePinned(LoopValue& value) { int64_t end_idx_in_loop = loop_size_ - 1; EvenOddChunkPair chunks = heap_.AllocateSameEvenAndOddBetween( begin_idx_in_loop, end_idx_in_loop, value.size); - if (!chunks.first.has_value() || !chunks.second.has_value()) { + if (!chunks.HasValues()) { VLOG(3) << "Could not find Allocation for pinned value: " << value.ToString(); return false; @@ -1287,7 +1284,7 @@ bool MemoryBoundLoopOptimizer::AllocatePrefetch( early_forced_value->alternate_memory_end_idx_in_loop.value(); EvenOddChunkPair chunks = heap_.AllocateEvenAndOddBetween( begin_idx_in_loop, end_idx_in_loop, early_forced_value->size); - if (!chunks.first.has_value() || !chunks.second.has_value()) { + if (!chunks.HasValues()) { VLOG(3) << "Could not allocate between " << begin_idx_in_loop << " and " << end_idx_in_loop << " for early forced value: " << early_forced_value->ToString(); @@ -1308,7 +1305,7 @@ bool MemoryBoundLoopOptimizer::AllocatePrefetch( int64_t end_idx_in_loop = last_use_idx_sentinel; EvenOddChunkPair chunks = heap_.AllocateEvenAndOddBetween( begin_idx_in_loop, end_idx_in_loop, value->size); - if (chunks.first.has_value() && chunks.second.has_value()) { + if (chunks.HasValues()) { pending_chunk_intervals.push_back( {begin_idx_in_loop, end_idx_in_loop, chunks}); VLOG(3) << "Added pending chunk: " @@ -1405,14 +1402,15 @@ bool MemoryBoundLoopOptimizer::AllocatePrefetch( early_forced_value->alternate_memory_begin_idx_in_loop.value(), early_forced_value->alternate_memory_end_idx_in_loop.value(), early_forced_value->size, - {early_forced_value->chunks.first->offset, - early_forced_value->chunks.second->offset}); + {early_forced_value->chunks.even_chunk->offset, + early_forced_value->chunks.odd_chunk->offset}); // The chunk should always be present as we are allocating at the same // offset. - CHECK(chunks.first.has_value() && chunks.second.has_value()); - CHECK_EQ(chunks.first->offset, early_forced_value->chunks.first->offset); - CHECK_EQ(chunks.second->offset, - early_forced_value->chunks.second->offset); + CHECK(chunks.HasValues()); + CHECK_EQ(chunks.even_chunk->offset, + early_forced_value->chunks.even_chunk->offset); + CHECK_EQ(chunks.odd_chunk->offset, + early_forced_value->chunks.odd_chunk->offset); } VLOG(3) << "Memory usage after restoring original state: " << heap_.MemoryUsageToAsciiArt(); @@ -1473,7 +1471,7 @@ bool MemoryBoundLoopOptimizer::AllocatePrefetch( begin_idx_in_loop, end_idx_in_loop, early_forced_value->size); // The chunk should always be present as we reproducing the same allocation // pattern as the out-of-memory check. - CHECK(chunks.first.has_value() && chunks.second.has_value()); + CHECK(chunks.HasValues()); CHECK_LT(begin_idx_in_loop, early_forced_value->alternate_memory_begin_idx_in_loop.value()); early_forced_value->SetChunkPairAndInterval(chunks, begin_idx_in_loop, @@ -1496,7 +1494,7 @@ bool MemoryBoundLoopOptimizer::AllocatePrefetch( // pattern as the out-of-memory check. EvenOddChunkPair chunks = heap_.AllocateEvenAndOddBetween( begin_idx_in_loop, end_idx_in_loop, value->size); - CHECK(chunks.first.has_value() && chunks.second.has_value()); + CHECK(chunks.HasValues()); value->SetChunkPairAndInterval(chunks, begin_idx_in_loop, end_idx_in_loop); value->allocations.push_back(std::make_unique( *value->allocations.back(), MemorySpace::kAlternate, std::nullopt, diff --git a/third_party/xla/xla/service/memory_space_assignment/memory_bound_loop_optimizer.h b/third_party/xla/xla/service/memory_space_assignment/memory_bound_loop_optimizer.h index a1f0769fc14658..d87975c6384b65 100644 --- a/third_party/xla/xla/service/memory_space_assignment/memory_bound_loop_optimizer.h +++ b/third_party/xla/xla/service/memory_space_assignment/memory_bound_loop_optimizer.h @@ -49,8 +49,14 @@ namespace xla { namespace memory_space_assignment { // Pair of chunks for even and odd loop iterations. -using EvenOddChunkPair = std::pair, - std::optional>; +struct EvenOddChunkPair { + std::optional even_chunk; + std::optional odd_chunk; + + bool HasValues() const { + return even_chunk.has_value() && odd_chunk.has_value(); + } +}; // LoopOptimizerBestFitHeap extends GlobalDecreasingSizeBestFitHeap to track // allocated buffers and their live intervals for the MemoryBoundLoopOptimizer. @@ -270,10 +276,6 @@ class MemoryBoundLoopOptimizer { static std::string AllocationTypeToString(AllocationType allocation_type); std::string ToString() const; - // Returns true if the LoopValue has chunks for even and odd loop - // iterations. - bool HasEvenAndOddChunks() const; - // Returns true if memory-bound loop optimizer supports allocating this type // of a loop value. bool IsAllocationTypeSupported() const; diff --git a/third_party/xla/xla/service/memory_space_assignment/memory_bound_loop_optimizer_test.cc b/third_party/xla/xla/service/memory_space_assignment/memory_bound_loop_optimizer_test.cc index 460793d8309ec8..b01874003fb22a 100644 --- a/third_party/xla/xla/service/memory_space_assignment/memory_bound_loop_optimizer_test.cc +++ b/third_party/xla/xla/service/memory_space_assignment/memory_bound_loop_optimizer_test.cc @@ -97,7 +97,7 @@ class LoopOptimizerBestFitHeapTest : public ::testing::Test { int64_t size) { EvenOddChunkPair chunks = heap_.AllocateSameEvenAndOddBetween( begin_idx_in_loop, end_idx_in_loop, size); - return chunks.first.has_value() && chunks.second.has_value(); + return chunks.HasValues(); } bool CanFindSameEvenAndOddAllocationBetween(int64_t begin_idx_in_loop, @@ -105,7 +105,7 @@ class LoopOptimizerBestFitHeapTest : public ::testing::Test { int64_t size) { EvenOddChunkPair chunks = heap_.FindSameEvenAndOddAllocationBetween( begin_idx_in_loop, end_idx_in_loop, size); - return chunks.first.has_value() && chunks.second.has_value(); + return chunks.HasValues(); } bool IsAllocateEvenAndOddBetweenSuccessful(int64_t begin_idx_in_loop, @@ -113,7 +113,7 @@ class LoopOptimizerBestFitHeapTest : public ::testing::Test { int64_t size) { EvenOddChunkPair chunks = heap_.AllocateEvenAndOddBetween( begin_idx_in_loop, end_idx_in_loop, size); - return chunks.first.has_value() && chunks.second.has_value(); + return chunks.HasValues(); } bool CanFindEvenAndOddAllocationBetween(int64_t begin_idx_in_loop, @@ -121,7 +121,7 @@ class LoopOptimizerBestFitHeapTest : public ::testing::Test { int64_t size) { EvenOddChunkPair chunks = heap_.FindEvenAndOddAllocationBetween( begin_idx_in_loop, end_idx_in_loop, size); - return chunks.first.has_value() && chunks.second.has_value(); + return chunks.HasValues(); } std::string GetMemoryUsageAsciiArt() { return heap_.MemoryUsageToAsciiArt(); } @@ -193,10 +193,9 @@ TEST_F(LoopOptimizerBestFitHeapTest, TestAllocateEvenAndOddBetween) { TEST_F(LoopOptimizerBestFitHeapTest, TestRemoveChunk) { EvenOddChunkPair chunks = heap_.AllocateEvenAndOddBetween(3, 11, 16); - EXPECT_TRUE(chunks.first.has_value() && chunks.second.has_value()); + EXPECT_TRUE(chunks.HasValues()); EvenOddChunkPair second_chunks = heap_.AllocateEvenAndOddBetween(-3, 8, 16); - EXPECT_TRUE(second_chunks.first.has_value() && - second_chunks.second.has_value()); + EXPECT_TRUE(second_chunks.HasValues()); EXPECT_THAT(heap_.RemainingMemoryByTime(), ContainerEq(std::vector{16, 16, 16, 0, 0, 0})); EXPECT_EQ(heap_.LastMemoryOffsetOccupied(), 64); From 8a85f9cf7c8262a43b1560b982671abe6abaf362 Mon Sep 17 00:00:00 2001 From: Subhankar Shah Date: Thu, 19 Sep 2024 20:17:31 -0700 Subject: [PATCH 044/483] [XLA:TPU] Rename loop variable `i` to `current_idx` in memory_bound_loop_optimizer when it improves readability (multiple uses of `i` throughout a long loop). PiperOrigin-RevId: 676667952 --- .../memory_bound_loop_optimizer.cc | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/third_party/xla/xla/service/memory_space_assignment/memory_bound_loop_optimizer.cc b/third_party/xla/xla/service/memory_space_assignment/memory_bound_loop_optimizer.cc index 11c35e74b21309..f5b6c810cb6597 100644 --- a/third_party/xla/xla/service/memory_space_assignment/memory_bound_loop_optimizer.cc +++ b/third_party/xla/xla/service/memory_space_assignment/memory_bound_loop_optimizer.cc @@ -1218,15 +1218,15 @@ bool MemoryBoundLoopOptimizer::AllocatePrefetch( // limiting it till last_use_idx_sentinel - loop_size. This will allow a // prefetch to use all the idle bandwidth available during one full loop // iteration. - for (int i = first_use_idx - 1; i >= last_use_idx_sentinel - loop_size_; - --i) { - int loop_idx = (i + loop_size_) % loop_size_; + for (int current_idx = first_use_idx - 1; + current_idx >= last_use_idx_sentinel - loop_size_; --current_idx) { + int loop_idx = (current_idx + loop_size_) % loop_size_; // Check if this prefetch rolls over to the previous iteration, check if any // already-scheduled prefetches would violate the FIFO order, and if so, // "early-force" them to be co-scheduled with this prefetch to maintain the // FIFO order. This of course increases the required memory, so also keep // track of additional memory that would be consumed. - if (i < 0) { + if (current_idx < 0) { for (; context.value_indices[early_forced_prefetch_value_search_index] != value_index; ++early_forced_prefetch_value_search_index) { @@ -1301,7 +1301,7 @@ bool MemoryBoundLoopOptimizer::AllocatePrefetch( if (pending_chunk_intervals.size() == early_forced_prefetch_value_indices.size()) { - int64_t begin_idx_in_loop = i; + int64_t begin_idx_in_loop = current_idx; int64_t end_idx_in_loop = last_use_idx_sentinel; EvenOddChunkPair chunks = heap_.AllocateEvenAndOddBetween( begin_idx_in_loop, end_idx_in_loop, value->size); @@ -1355,7 +1355,7 @@ bool MemoryBoundLoopOptimizer::AllocatePrefetch( (copy_resource - accumulated_copy_resource)); if (bandwidth_idle_time >= copy_resource - accumulated_copy_resource) { accumulated_copy_resource = copy_resource; - copy_start_loop_idx = i; + copy_start_loop_idx = current_idx; VLOG(3) << "Found the complete copy ratio and updated accumulated copy " "resource: " << accumulated_copy_resource; @@ -1364,7 +1364,7 @@ bool MemoryBoundLoopOptimizer::AllocatePrefetch( accumulated_copy_resource + bandwidth_idle_time >= copy_resource * options_.desired_copy_ratio()) { accumulated_copy_resource += bandwidth_idle_time; - copy_start_loop_idx = i; + copy_start_loop_idx = current_idx; VLOG(3) << "Found the desired copy ratio and updated accumulated copy " "resource: " << accumulated_copy_resource; @@ -1373,7 +1373,7 @@ bool MemoryBoundLoopOptimizer::AllocatePrefetch( // Even if desired resource isn't reached, and if the options allow it, // allow a fully pipelined prefetch. accumulated_copy_resource += bandwidth_idle_time; - copy_start_loop_idx = i; + copy_start_loop_idx = current_idx; VLOG(3) << "Could not reach the desired copy ratio but scheduling " "fully pipelined prefetch anyway: " << accumulated_copy_resource; From 0b519beee989ea5d6ca646c97f4994869283171e Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 19 Sep 2024 21:06:50 -0700 Subject: [PATCH 045/483] Automated Code Change PiperOrigin-RevId: 676679763 --- tensorflow/core/transforms/cse/pass.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/transforms/cse/pass.cc b/tensorflow/core/transforms/cse/pass.cc index f84fdfaaf568e3..f940a2bff65702 100644 --- a/tensorflow/core/transforms/cse/pass.cc +++ b/tensorflow/core/transforms/cse/pass.cc @@ -17,11 +17,11 @@ limitations under the License. #include +#include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project -#include "mlir/Support/TypeID.h" // from @llvm-project #include "mlir/Transforms/Passes.h" // from @llvm-project #include "tensorflow/core/ir/dialect.h" #include "tensorflow/core/ir/ops.h" From 03b0dd411df9535d2d810115ee07c31c4d06789b Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 19 Sep 2024 23:04:24 -0700 Subject: [PATCH 046/483] Automated Code Change PiperOrigin-RevId: 676706989 --- tensorflow/lite/delegates/gpu/common/BUILD | 2 ++ tensorflow/lite/delegates/gpu/common/winograd_util.cc | 1 + tensorflow/lite/delegates/gpu/common/winograd_util_test.cc | 3 ++- 3 files changed, 5 insertions(+), 1 deletion(-) diff --git a/tensorflow/lite/delegates/gpu/common/BUILD b/tensorflow/lite/delegates/gpu/common/BUILD index 195124f269c7c8..df4cbb4306790d 100644 --- a/tensorflow/lite/delegates/gpu/common/BUILD +++ b/tensorflow/lite/delegates/gpu/common/BUILD @@ -533,6 +533,8 @@ cc_test( name = "winograd_util_test", srcs = ["winograd_util_test.cc"], deps = [ + ":operations", + ":shape", ":winograd_util", "@com_google_googletest//:gtest_main", ], diff --git a/tensorflow/lite/delegates/gpu/common/winograd_util.cc b/tensorflow/lite/delegates/gpu/common/winograd_util.cc index c499d1e9e3dd0b..3ebaef9a38503e 100644 --- a/tensorflow/lite/delegates/gpu/common/winograd_util.cc +++ b/tensorflow/lite/delegates/gpu/common/winograd_util.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include "tensorflow/lite/delegates/gpu/common/data_type.h" +#include "tensorflow/lite/delegates/gpu/common/operations.h" #include "tensorflow/lite/delegates/gpu/common/shape.h" #include "tensorflow/lite/delegates/gpu/common/tensor.h" diff --git a/tensorflow/lite/delegates/gpu/common/winograd_util_test.cc b/tensorflow/lite/delegates/gpu/common/winograd_util_test.cc index 81fb643d399a82..1c694488a33937 100644 --- a/tensorflow/lite/delegates/gpu/common/winograd_util_test.cc +++ b/tensorflow/lite/delegates/gpu/common/winograd_util_test.cc @@ -15,8 +15,9 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/common/winograd_util.h" -#include #include +#include "tensorflow/lite/delegates/gpu/common/operations.h" +#include "tensorflow/lite/delegates/gpu/common/shape.h" namespace tflite { namespace gpu { From 0c43cafe07d46be6fb551188672faffcd7fa7b2d Mon Sep 17 00:00:00 2001 From: Adrian Kuegel Date: Fri, 20 Sep 2024 00:13:41 -0700 Subject: [PATCH 047/483] Disable AllGatherDynamicSliceSimplifier Causing errors in some jax tests PiperOrigin-RevId: 676724285 --- third_party/xla/xla/service/gpu/BUILD | 1 - third_party/xla/xla/service/gpu/gpu_compiler.cc | 2 -- 2 files changed, 3 deletions(-) diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index b5b003dacb889e..22beda5486c796 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -1418,7 +1418,6 @@ cc_library( "//xla/service/gpu/runtime:thunk", "//xla/service/gpu/transforms:algebraic_simplifier", "//xla/service/gpu/transforms:algorithm_checker", - "//xla/service/gpu/transforms:all_gather_dynamic_slice_simplifier", "//xla/service/gpu/transforms:all_gather_optimizer", "//xla/service/gpu/transforms:all_reduce_blueconnect", "//xla/service/gpu/transforms:all_reduce_splitter", diff --git a/third_party/xla/xla/service/gpu/gpu_compiler.cc b/third_party/xla/xla/service/gpu/gpu_compiler.cc index f4aafc5972a3e6..4dee1d3f4fd57a 100644 --- a/third_party/xla/xla/service/gpu/gpu_compiler.cc +++ b/third_party/xla/xla/service/gpu/gpu_compiler.cc @@ -144,7 +144,6 @@ limitations under the License. #include "xla/service/gpu/stream_executor_util.h" #include "xla/service/gpu/transforms/algebraic_simplifier.h" #include "xla/service/gpu/transforms/algorithm_checker.h" -#include "xla/service/gpu/transforms/all_gather_dynamic_slice_simplifier.h" #include "xla/service/gpu/transforms/all_gather_optimizer.h" #include "xla/service/gpu/transforms/all_reduce_blueconnect.h" #include "xla/service/gpu/transforms/all_reduce_splitter.h" @@ -904,7 +903,6 @@ absl::Status RunCollectiveOptimizationPasses( HloPassPipeline collectives_pipeline("collective-optimizations"); collectives_pipeline.AddPass(); collectives_pipeline.AddPass(); - collectives_pipeline.AddPass(); collectives_pipeline.AddPass(); collectives_pipeline.AddPass( debug_options.xla_gpu_enable_reassociation_for_converted_ar()); From 82263a468ee98d398075990c43253de7b8388682 Mon Sep 17 00:00:00 2001 From: Ilia Sergachev Date: Fri, 20 Sep 2024 00:29:22 -0700 Subject: [PATCH 048/483] PR #17366: [NFC] Fix units and usage string in compute_cost tool. Imported from GitHub PR https://github.com/openxla/xla/pull/17366 Copybara import of the project: -- 80521637dd0c101a911968e5ddcf0a80b4317977 by Ilia Sergachev : [NFC] Fix units and usage string in compute_cost tool. Merging this change closes #17366 PiperOrigin-RevId: 676728602 --- third_party/xla/xla/tools/compute_cost.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/xla/tools/compute_cost.cc b/third_party/xla/xla/tools/compute_cost.cc index 9615ae01b59940..153c489ad61c9a 100644 --- a/third_party/xla/xla/tools/compute_cost.cc +++ b/third_party/xla/xla/tools/compute_cost.cc @@ -41,7 +41,7 @@ The input file can be obtained from XProf graph viewer by clicking Usage: - bazel run compute_cost -- -input=path/to/hlo_module -format=[hlo|pb|pbtxt] + bazel run compute_cost -- --input=path/to/hlo_module --format=[hlo|pb|pbtxt] )"; } // namespace @@ -71,6 +71,6 @@ int main(int argc, char** argv) { std::cout << std::setw(5) << std::setprecision(4) << analysis.flop_count() / (1e9) << " GFLOPS. " - << analysis.bytes_accessed() / (1e6) << " MiB." << std::endl; + << analysis.bytes_accessed() / (1e6) << " MB." << std::endl; return 0; } From dfff8e8d946484408c7668d41539e49ea82ac463 Mon Sep 17 00:00:00 2001 From: Alexander Pivovarov Date: Fri, 20 Sep 2024 01:11:34 -0700 Subject: [PATCH 049/483] PR #17383: Parameterize elemental_ir_emitter_test.cc float tests Imported from GitHub PR https://github.com/openxla/xla/pull/17383 `elemental_ir_emitter_test.cc` contains multiple similar tests for `bf16`, `f8e4m3fnuz` and , `f8e5m2fnuz`. Changes: - Parameterize the float tests in elemental_ir_emitter_test.cc. - Add additional types to the list of tested types - `f8e5m2`, `f8e4m3fn`, `f8e4m3b11fnuz`. Some tests failed for newly added types. Temporary use `GTEST_SKIP` for such cases: Related issues: - https://github.com/openxla/xla/issues/17323 - https://github.com/openxla/xla/issues/17324 Copybara import of the project: -- 47dcfcf43908584d453f63008c9d68b5e7dae9c3 by Alexander Pivovarov : Parameterize elemental_ir_emitter_test.cc float tests Merging this change closes #17383 PiperOrigin-RevId: 676739814 --- .../xla/service/elemental_ir_emitter_test.cc | 609 ++++++------------ 1 file changed, 181 insertions(+), 428 deletions(-) diff --git a/third_party/xla/xla/service/elemental_ir_emitter_test.cc b/third_party/xla/xla/service/elemental_ir_emitter_test.cc index 9ee2680065a26f..7c73dd3a1d0dd7 100644 --- a/third_party/xla/xla/service/elemental_ir_emitter_test.cc +++ b/third_party/xla/xla/service/elemental_ir_emitter_test.cc @@ -21,6 +21,7 @@ limitations under the License. #include #include +#include "absl/strings/str_replace.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "llvm/IR/IRBuilder.h" @@ -77,6 +78,23 @@ class ElementalIrEmitterExecutionTestWithoutFastMinMax } }; +template +class ElementalIrEmitterExecutionTypedTest + : public ElementalIrEmitterExecutionTest { + protected: + const std::string& TypeName() { + return primitive_util::LowercasePrimitiveTypeName( + primitive_util::NativeToPrimitiveType()); + } +}; + +using FloatTypes = + ::testing::Types; + +TYPED_TEST_SUITE(ElementalIrEmitterExecutionTypedTest, FloatTypes); + XLA_TEST_F(ElementalIrEmitterExecutionTest, DotFusion) { const std::string hlo_text = R"( HloModule FusedDot @@ -229,473 +247,208 @@ XLA_TEST_F(ElementalIrEmitterExecutionTest, EXPECT_TRUE(RunAndCompare(std::move(module), ErrorSpec{(0.)})); } -XLA_TEST_F(ElementalIrEmitterExecutionTest, ConvertFloatsToBF16) { - RunTypeConversionTest(R"( - HloModule convertToBF16 - ENTRY ConvertToBF16 - (f16_ f16[], f32_ f32[], f64_ f64[]) -> (bf16[], bf16[], bf16[]) { - f16_ = f16[] parameter(0) - f32_ = f32[] parameter(1) - f64_ = f64[] parameter(2) - converted_f16 = bf16[] convert(f16[] f16_) - converted_f32 = bf16[] convert(f32[] f32_) - converted_f64 = bf16[] convert(f64[] f64_) - ROOT tuple = (bf16[], bf16[], bf16[]) tuple(converted_f16, converted_f32, - converted_f64) - } - )"); -} - -XLA_TEST_F(ElementalIrEmitterExecutionTest, ConvertSignedToBF16) { - RunTypeConversionTest(R"( - HloModule convertToBF16 - ENTRY ConvertToBF16 (s8_ s8[], s16_ s16[], s32_ s32[], s64_ s64[]) -> - (bf16[], bf16[], bf16[], bf16[]) { - s8_ = s8[] parameter(0) - s16_ = s16[] parameter(1) - s32_ = s32[] parameter(2) - s64_ = s64[] parameter(3) - converted_s8 = bf16[] convert(s8[] s8_) - converted_s16 = bf16[] convert(s16[] s16_) - converted_s32 = bf16[] convert(s32[] s32_) - converted_s64 = bf16[] convert(s64[] s64_) - ROOT tuple = (bf16[], bf16[], bf16[], bf16[]) tuple( - converted_s8, converted_s16, converted_s32, converted_s64) - } - )"); -} - -XLA_TEST_F(ElementalIrEmitterExecutionTest, ConvertUnsignedToBF16) { - RunTypeConversionTest(R"( - HloModule convertToBF16 - ENTRY ConvertToBF16 (u8_ u8[], u16_ u16[], u32_ u32[], u64_ u64[]) -> - (bf16[], bf16[], bf16[], bf16[]) { - u8_ = u8[] parameter(0) - u16_ = u16[] parameter(1) - u32_ = u32[] parameter(2) - u64_ = u64[] parameter(3) - converted_u8 = bf16[] convert(u8[] u8_) - converted_u16 = bf16[] convert(u16[] u16_) - converted_u32 = bf16[] convert(u32[] u32_) - converted_u64 = bf16[] convert(u64[] u64_) - ROOT tuple = (bf16[], bf16[], bf16[], bf16[]) tuple( - converted_u8, converted_u16, converted_u32, converted_u64) - } - )"); -} - -XLA_TEST_F(ElementalIrEmitterExecutionTest, ConvertBF16ToFloat) { - RunTypeConversionTest(R"( - HloModule convertFromBF16 - ENTRY ConvertFromBF16 - (to_f16 bf16[], to_f32 bf16[], to_f64 bf16[]) -> (f16[], f32[], f64[]) { - to_f16 = bf16[] parameter(0) - to_f32 = bf16[] parameter(1) - to_f64 = bf16[] parameter(2) - f16_ = f16[] convert(bf16[] to_f16) - f32_ = f32[] convert(bf16[] to_f32) - f64_ = f64[] convert(bf16[] to_f64) - ROOT tuple = (f16[], f32[], f64[]) tuple(f16_, f32_, f64_) - } - )"); -} - -XLA_TEST_F(ElementalIrEmitterExecutionTest, ConvertBF16ToSigned) { - RunTypeConversionTest(R"( - HloModule convertFromBF16 - ENTRY ConvertFromBF16(to_s8 bf16[], to_s16 bf16[], to_s32 bf16[], - to_s64 bf16[]) -> (s8[], s16[], s32[], s64[]) { - to_s8 = bf16[] parameter(0) - to_s16 = bf16[] parameter(1) - to_s32 = bf16[] parameter(2) - to_s64 = bf16[] parameter(3) - s8_ = s8[] convert(bf16[] to_s8) - s16_ = s16[] convert(bf16[] to_s16) - s32_ = s32[] convert(bf16[] to_s32) - s64_ = s64[] convert(bf16[] to_s64) - ROOT tuple = (s8[], s16[], s32[], s64[]) tuple(s8_, s16_, s32_, s64_) - } - )"); -} - -XLA_TEST_F(ElementalIrEmitterExecutionTest, ConvertBF16ToUnsigned) { - RunTypeConversionTest(R"( - HloModule convertFromBF16 - ENTRY ConvertFromBF16(to_u8 bf16[], to_u16 bf16[], to_u32 bf16[], - to_u64 bf16[]) -> (u8[], u16[], u32[], u64[]) { - to_u8 = bf16[] parameter(0) - to_u16 = bf16[] parameter(1) - to_u32 = bf16[] parameter(2) - to_u64 = bf16[] parameter(3) - u8_ = u8[] convert(bf16[] to_u8) - u16_ = u16[] convert(bf16[] to_u16) - u32_ = u32[] convert(bf16[] to_u32) - u64_ = u64[] convert(bf16[] to_u64) - ROOT tuple = (u8[], u16[], u32[], u64[]) tuple(u8_, u16_, u32_, u64_) - } - )"); -} - -XLA_TEST_F(ElementalIrEmitterExecutionTest, ConvertBF16ToComplex) { - RunTypeConversionTest(R"( - HloModule convertFromBF16 - ENTRY ConvertFromBF16 - (to_c64 bf16[], to_c128 bf16[]) -> (c64[], c128[]) { - to_c64 = bf16[] parameter(0) - to_c128 = bf16[] parameter(1) - c64_ = c64[] convert(bf16[] to_c64) - c128_ = c128[] convert(bf16[] to_c128) - ROOT tuple = (c64[], c128[]) tuple(c64_, c128_) - } - )"); -} - -XLA_TEST_F(ElementalIrEmitterExecutionTest, CompareBF16) { - constexpr char hlo_text[] = R"( - HloModule compareBF16 - ENTRY main { - p0 = bf16[4] parameter(0) - p1 = bf16[4] parameter(1) - ROOT cmp = pred[4] compare(p0, p1), direction=LT -})"; - - Literal lhs = LiteralUtil::CreateR1({1, 2, 3, 4}); - Literal rhs = LiteralUtil::CreateR1({4, 3, 2, 1}); - lhs = LiteralUtil::ConvertF32ToBF16(lhs); - rhs = LiteralUtil::ConvertF32ToBF16(rhs); - RunTest(hlo_text, {&lhs, &rhs}); -} - -XLA_TEST_F(ElementalIrEmitterExecutionTest, IotaBF16) { - constexpr char hlo_text[] = R"( - HloModule IotaBF16 - ENTRY main { - ROOT iota_ = bf16[4] iota(), iota_dimension=0 +TYPED_TEST(ElementalIrEmitterExecutionTypedTest, ConvertFloatsToFloat) { + auto tname = this->TypeName(); + if (std::is_same() || + std::is_same()) { + GTEST_SKIP() << "Skipping test for type " << tname; } - )"; - - RunTest(hlo_text, {}); -} - -XLA_TEST_F(ElementalIrEmitterExecutionTest, BatchDotBF16) { - const char* const hlo_text = R"( - HloModule matmul - + const auto hlo_text = absl::StrReplaceAll(R"( + HloModule m ENTRY main { - x = bf16[8,16] parameter(0) - y = bf16[8,16,32] parameter(1) - ROOT dot = bf16[8,32] dot(x, y), lhs_batch_dims={0}, rhs_batch_dims={0}, lhs_contracting_dims={1}, rhs_contracting_dims={1} + f16_ = f16[] parameter(0) + f32_ = f32[] parameter(1) + f64_ = f64[] parameter(2) + bf16_ = bf16[] parameter(3) + converted_f16 = ${tname}[] convert(f16_) + converted_f32 = ${tname}[] convert(f32_) + converted_f64 = ${tname}[] convert(f64_) + converted_bf16 = ${tname}[] convert(bf16_) + ROOT tuple = (${tname}[], ${tname}[], ${tname}[], ${tname}[]) tuple( + converted_f16, converted_f32, converted_f64, converted_bf16) } - )"; - HloModuleConfig config; - DebugOptions debug_options = GetDebugOptionsForTest(); - config.set_debug_options(debug_options); - - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(hlo_text, config)); - EXPECT_TRUE(RunAndCompare(std::move(module), ErrorSpec{1e-5, 1e-5})); -} - -XLA_TEST_F(ElementalIrEmitterExecutionTest, ConvertFloatsToF8E4FNUZ) { - RunTypeConversionTest(R"( - HloModule convertToF8E4FNUZ - ENTRY ConvertToF8E4FNUZ - (f16_ f16[], f32_ f32[], f64_ f64[], bf16_ bf16[]) -> (f8e4m3fnuz[], f8e4m3fnuz[], f8e4m3fnuz[], f8e4m3fnuz[]) { - f16_ = f16[] parameter(0) - f32_ = f32[] parameter(1) - f64_ = f64[] parameter(2) - bf16_ = bf16[] parameter(3) - converted_f16 = f8e4m3fnuz[] convert(f16[] f16_) - converted_f32 = f8e4m3fnuz[] convert(f32[] f32_) - converted_f64 = f8e4m3fnuz[] convert(f64[] f64_) - converted_bf16 = f8e4m3fnuz[] convert(bf16[] bf16_) - ROOT tuple = (f8e4m3fnuz[], f8e4m3fnuz[], f8e4m3fnuz[], f8e4m3fnuz[]) tuple( - converted_f16, converted_f32, converted_f64, converted_bf16) - } - )"); + )", + {{"${tname}", tname}}); + ElementalIrEmitterExecutionTest::RunTypeConversionTest(hlo_text); } -XLA_TEST_F(ElementalIrEmitterExecutionTest, ConvertSignedToF8E4FNUZ) { - RunTypeConversionTest(R"( - HloModule convertToF8E4FNUZ - ENTRY ConvertToF8E4FNUZ (s8_ s8[], s16_ s16[], s32_ s32[], s64_ s64[]) -> - (f8e4m3fnuz[], f8e4m3fnuz[], f8e4m3fnuz[], f8e4m3fnuz[]) { +TYPED_TEST(ElementalIrEmitterExecutionTypedTest, ConvertSignedToFloat) { + auto tname = this->TypeName(); + const auto hlo_text = absl::StrReplaceAll(R"( + HloModule m + ENTRY main { s8_ = s8[] parameter(0) s16_ = s16[] parameter(1) s32_ = s32[] parameter(2) s64_ = s64[] parameter(3) - converted_s8 = f8e4m3fnuz[] convert(s8[] s8_) - converted_s16 = f8e4m3fnuz[] convert(s16[] s16_) - converted_s32 = f8e4m3fnuz[] convert(s32[] s32_) - converted_s64 = f8e4m3fnuz[] convert(s64[] s64_) - ROOT tuple = (f8e4m3fnuz[], f8e4m3fnuz[], f8e4m3fnuz[], f8e4m3fnuz[]) tuple( + converted_s8 = ${tname}[] convert(s8_) + converted_s16 = ${tname}[] convert(s16_) + converted_s32 = ${tname}[] convert(s32_) + converted_s64 = ${tname}[] convert(s64_) + ROOT tuple = (${tname}[], ${tname}[], ${tname}[], ${tname}[]) tuple( converted_s8, converted_s16, converted_s32, converted_s64) } - )"); + )", + {{"${tname}", tname}}); + ElementalIrEmitterExecutionTest::RunTypeConversionTest(hlo_text); } -XLA_TEST_F(ElementalIrEmitterExecutionTest, ConvertUnsignedToF8E4FNUZ) { - RunTypeConversionTest(R"( - HloModule convertToF8E4FNUZ - ENTRY ConvertToF8E4FNUZ (u8_ u8[], u16_ u16[], u32_ u32[], u64_ u64[]) -> - (f8e4m3fnuz[], f8e4m3fnuz[], f8e4m3fnuz[], f8e4m3fnuz[]) { +TYPED_TEST(ElementalIrEmitterExecutionTypedTest, ConvertUnsignedToFloat) { + auto tname = this->TypeName(); + const auto hlo_text = absl::StrReplaceAll(R"( + HloModule m + ENTRY main { u8_ = u8[] parameter(0) u16_ = u16[] parameter(1) u32_ = u32[] parameter(2) u64_ = u64[] parameter(3) - converted_u8 = f8e4m3fnuz[] convert(u8[] u8_) - converted_u16 = f8e4m3fnuz[] convert(u16[] u16_) - converted_u32 = f8e4m3fnuz[] convert(u32[] u32_) - converted_u64 = f8e4m3fnuz[] convert(u64[] u64_) - ROOT tuple = (f8e4m3fnuz[], f8e4m3fnuz[], f8e4m3fnuz[], f8e4m3fnuz[]) tuple( + converted_u8 = ${tname}[] convert(u8_) + converted_u16 = ${tname}[] convert(u16_) + converted_u32 = ${tname}[] convert(u32_) + converted_u64 = ${tname}[] convert(u64_) + ROOT tuple = (${tname}[], ${tname}[], ${tname}[], ${tname}[]) tuple( converted_u8, converted_u16, converted_u32, converted_u64) } - )"); -} - -XLA_TEST_F(ElementalIrEmitterExecutionTest, ConvertF8E4FNUZToFloat) { - RunTypeConversionTest(R"( - HloModule convertFromF8E4FNUZ - ENTRY ConvertFromF8E4FNUZ - (to_f16 f8e4m3fnuz[], to_f32 f8e4m3fnuz[], to_f64 f8e4m3fnuz[], to_bf16 f8e4m3fnuz[]) -> (f16[], f32[], f64[], bf16[]) { - to_f16 = f8e4m3fnuz[] parameter(0) - to_f32 = f8e4m3fnuz[] parameter(1) - to_f64 = f8e4m3fnuz[] parameter(2) - to_bf16 = f8e4m3fnuz[] parameter(3) - f16_ = f16[] convert(f8e4m3fnuz[] to_f16) - f32_ = f32[] convert(f8e4m3fnuz[] to_f32) - f64_ = f64[] convert(f8e4m3fnuz[] to_f64) - bf16_ = bf16[] convert(f8e4m3fnuz[] to_f64) + )", + {{"${tname}", tname}}); + ElementalIrEmitterExecutionTest::RunTypeConversionTest(hlo_text); +} + +TYPED_TEST(ElementalIrEmitterExecutionTypedTest, ConvertFloatToFloats) { + auto tname = this->TypeName(); + const auto hlo_text = absl::StrReplaceAll(R"( + HloModule m + ENTRY main { + to_f16 = ${tname}[] parameter(0) + to_f32 = ${tname}[] parameter(1) + to_f64 = ${tname}[] parameter(2) + to_bf16 = ${tname}[] parameter(3) + f16_ = f16[] convert(to_f16) + f32_ = f32[] convert(to_f32) + f64_ = f64[] convert(to_f64) + bf16_ = bf16[] convert(to_f64) ROOT tuple = (f16[], f32[], f64[], bf16[]) tuple(f16_, f32_, f64_, bf16_) } - )"); -} - -XLA_TEST_F(ElementalIrEmitterExecutionTest, ConvertF8E4FNUZToSigned) { - RunTypeConversionTest(R"( - HloModule convertFromF8E4FNUZ - ENTRY ConvertFromF8E4FNUZ(to_s8 f8e4m3fnuz[], to_s16 f8e4m3fnuz[], to_s32 f8e4m3fnuz[], - to_s64 f8e4m3fnuz[]) -> (s8[], s16[], s32[], s64[]) { - to_s8 = f8e4m3fnuz[] parameter(0) - to_s16 = f8e4m3fnuz[] parameter(1) - to_s32 = f8e4m3fnuz[] parameter(2) - to_s64 = f8e4m3fnuz[] parameter(3) - s8_ = s8[] convert(f8e4m3fnuz[] to_s8) - s16_ = s16[] convert(f8e4m3fnuz[] to_s16) - s32_ = s32[] convert(f8e4m3fnuz[] to_s32) - s64_ = s64[] convert(f8e4m3fnuz[] to_s64) + )", + {{"${tname}", tname}}); + ElementalIrEmitterExecutionTest::RunTypeConversionTest(hlo_text); +} + +TYPED_TEST(ElementalIrEmitterExecutionTypedTest, ConvertFloatToSigned) { + auto tname = this->TypeName(); + const auto hlo_text = absl::StrReplaceAll(R"( + HloModule m + ENTRY main { + to_s8 = ${tname}[] parameter(0) + to_s16 = ${tname}[] parameter(1) + to_s32 = ${tname}[] parameter(2) + to_s64 = ${tname}[] parameter(3) + s8_ = s8[] convert(to_s8) + s16_ = s16[] convert(to_s16) + s32_ = s32[] convert(to_s32) + s64_ = s64[] convert(to_s64) ROOT tuple = (s8[], s16[], s32[], s64[]) tuple(s8_, s16_, s32_, s64_) } - )"); -} - -XLA_TEST_F(ElementalIrEmitterExecutionTest, ConvertF8E4FNUZToUnsigned) { - RunTypeConversionTest(R"( - HloModule convertFromF8E4FNUZ - ENTRY ConvertFromF8E4FNUZ(to_u8 f8e4m3fnuz[], to_u16 f8e4m3fnuz[], to_u32 f8e4m3fnuz[], - to_u64 f8e4m3fnuz[]) -> (u8[], u16[], u32[], u64[]) { - to_u8 = f8e4m3fnuz[] parameter(0) - to_u16 = f8e4m3fnuz[] parameter(1) - to_u32 = f8e4m3fnuz[] parameter(2) - to_u64 = f8e4m3fnuz[] parameter(3) - u8_ = u8[] convert(f8e4m3fnuz[] to_u8) - u16_ = u16[] convert(f8e4m3fnuz[] to_u16) - u32_ = u32[] convert(f8e4m3fnuz[] to_u32) - u64_ = u64[] convert(f8e4m3fnuz[] to_u64) + )", + {{"${tname}", tname}}); + ElementalIrEmitterExecutionTest::RunTypeConversionTest(hlo_text); +} + +TYPED_TEST(ElementalIrEmitterExecutionTypedTest, ConvertFloatToUnsigned) { + auto tname = this->TypeName(); + const auto hlo_text = absl::StrReplaceAll(R"( + HloModule m + ENTRY main { + to_u8 = ${tname}[] parameter(0) + to_u16 = ${tname}[] parameter(1) + to_u32 = ${tname}[] parameter(2) + to_u64 = ${tname}[] parameter(3) + u8_ = u8[] convert(to_u8) + u16_ = u16[] convert(to_u16) + u32_ = u32[] convert(to_u32) + u64_ = u64[] convert(to_u64) ROOT tuple = (u8[], u16[], u32[], u64[]) tuple(u8_, u16_, u32_, u64_) } - )"); -} - -XLA_TEST_F(ElementalIrEmitterExecutionTest, ConvertF8E4FNUZToComplex) { - RunTypeConversionTest(R"( - HloModule convertFromF8E4FNUZ - ENTRY ConvertFromF8E4FNUZ - (to_c64 f8e4m3fnuz[], to_c128 f8e4m3fnuz[]) -> (c64[], c128[]) { - to_c64 = f8e4m3fnuz[] parameter(0) - to_c128 = f8e4m3fnuz[] parameter(1) - c64_ = c64[] convert(f8e4m3fnuz[] to_c64) - c128_ = c128[] convert(f8e4m3fnuz[] to_c128) + )", + {{"${tname}", tname}}); + ElementalIrEmitterExecutionTest::RunTypeConversionTest(hlo_text); +} + +TYPED_TEST(ElementalIrEmitterExecutionTypedTest, ConvertFloatToComplex) { + auto tname = this->TypeName(); + const auto hlo_text = absl::StrReplaceAll(R"( + HloModule m + ENTRY main { + to_c64 = ${tname}[] parameter(0) + to_c128 = ${tname}[] parameter(1) + c64_ = c64[] convert(to_c64) + c128_ = c128[] convert(to_c128) ROOT tuple = (c64[], c128[]) tuple(c64_, c128_) } - )"); + )", + {{"${tname}", tname}}); + ElementalIrEmitterExecutionTest::RunTypeConversionTest(hlo_text); } -XLA_TEST_F(ElementalIrEmitterExecutionTest, CompareF8E4FNUZ) { - constexpr char hlo_text[] = R"( - HloModule compareF8E4FNUZ +TYPED_TEST(ElementalIrEmitterExecutionTypedTest, CompareFloat) { + auto tname = this->TypeName(); + if (std::is_same()) { + GTEST_SKIP() << "Skipping test for type " << tname; + } + const auto hlo_text = absl::StrReplaceAll(R"( + HloModule m ENTRY main { - p0 = f8e4m3fnuz[4] parameter(0) - p1 = f8e4m3fnuz[4] parameter(1) + p0 = ${tname}[4] parameter(0) + p1 = ${tname}[4] parameter(1) ROOT cmp = pred[4] compare(p0, p1), direction=LT -})"; - - Literal lhs = LiteralUtil::CreateR1({1, 2, 3, 4}); - Literal rhs = LiteralUtil::CreateR1({4, 3, 2, 1}); - lhs = LiteralUtil::ConvertF32ToF8E4M3FNUZ(lhs); - rhs = LiteralUtil::ConvertF32ToF8E4M3FNUZ(rhs); - RunTest(hlo_text, {&lhs, &rhs}); -} - -XLA_TEST_F(ElementalIrEmitterExecutionTest, IotaF8E4FNUZ) { - constexpr char hlo_text[] = R"( - HloModule IotaF8E4FNUZ +})", + {{"${tname}", tname}}); + Literal lhs = LiteralUtil::CreateR1( + {TypeParam(1.), TypeParam(2.), TypeParam(3.), TypeParam(4.)}); + Literal rhs = LiteralUtil::CreateR1( + {TypeParam(4.), TypeParam(4.), TypeParam(2.), TypeParam(1.)}); + ElementalIrEmitterExecutionTest::RunTest(hlo_text, {&lhs, &rhs}); +} + +TYPED_TEST(ElementalIrEmitterExecutionTypedTest, IotaFloat) { + auto tname = this->TypeName(); + if (std::is_same() || + std::is_same() || + std::is_same()) { + GTEST_SKIP() << "Skipping test for type " << tname; + } + const auto hlo_text = absl::StrReplaceAll(R"( + HloModule m ENTRY main { - ROOT iota_ = f8e4m3fnuz[4] iota(), iota_dimension=0 + ROOT iota_ = ${tname}[4] iota(), iota_dimension=0 } - )"; - - RunTest(hlo_text, {}); -} - -XLA_TEST_F(ElementalIrEmitterExecutionTest, ConvertFloatsToF8E5FNUZ) { - RunTypeConversionTest(R"( - HloModule convertToF8E5FNUZ - ENTRY ConvertToF8E5FNUZ - (f16_ f16[], f32_ f32[], f64_ f64[], bf16_ bf16[]) -> (f8e5m2fnuz[], f8e5m2fnuz[], f8e5m2fnuz[], f8e5m2fnuz[]) { - f16_ = f16[] parameter(0) - f32_ = f32[] parameter(1) - f64_ = f64[] parameter(2) - bf16_ = bf16[] parameter(3) - converted_f16 = f8e5m2fnuz[] convert(f16[] f16_) - converted_f32 = f8e5m2fnuz[] convert(f32[] f32_) - converted_f64 = f8e5m2fnuz[] convert(f64[] f64_) - converted_bf16 = f8e5m2fnuz[] convert(bf16[] bf16_) - ROOT tuple = (f8e5m2fnuz[], f8e5m2fnuz[], f8e5m2fnuz[], f8e5m2fnuz[]) tuple( - converted_f16, converted_f32, converted_f64, converted_bf16) - } - )"); -} - -XLA_TEST_F(ElementalIrEmitterExecutionTest, ConvertSignedToF8E5FNUZ) { - RunTypeConversionTest(R"( - HloModule convertToF8E5FNUZ - ENTRY ConvertToF8E5FNUZ (s8_ s8[], s16_ s16[], s32_ s32[], s64_ s64[]) -> - (f8e5m2fnuz[], f8e5m2fnuz[], f8e5m2fnuz[], f8e5m2fnuz[]) { - s8_ = s8[] parameter(0) - s16_ = s16[] parameter(1) - s32_ = s32[] parameter(2) - s64_ = s64[] parameter(3) - converted_s8 = f8e5m2fnuz[] convert(s8[] s8_) - converted_s16 = f8e5m2fnuz[] convert(s16[] s16_) - converted_s32 = f8e5m2fnuz[] convert(s32[] s32_) - converted_s64 = f8e5m2fnuz[] convert(s64[] s64_) - ROOT tuple = (f8e5m2fnuz[], f8e5m2fnuz[], f8e5m2fnuz[], f8e5m2fnuz[]) tuple( - converted_s8, converted_s16, converted_s32, converted_s64) - } - )"); + )", + {{"${tname}", tname}}); + ElementalIrEmitterExecutionTest::RunTest(hlo_text, {}); } -XLA_TEST_F(ElementalIrEmitterExecutionTest, ConvertUnsignedToF8E5FNUZ) { - RunTypeConversionTest(R"( - HloModule convertToF8E5FNUZ - ENTRY ConvertToF8E5FNUZ (u8_ u8[], u16_ u16[], u32_ u32[], u64_ u64[]) -> - (f8e5m2fnuz[], f8e5m2fnuz[], f8e5m2fnuz[], f8e5m2fnuz[]) { - u8_ = u8[] parameter(0) - u16_ = u16[] parameter(1) - u32_ = u32[] parameter(2) - u64_ = u64[] parameter(3) - converted_u8 = f8e5m2fnuz[] convert(u8[] u8_) - converted_u16 = f8e5m2fnuz[] convert(u16[] u16_) - converted_u32 = f8e5m2fnuz[] convert(u32[] u32_) - converted_u64 = f8e5m2fnuz[] convert(u64[] u64_) - ROOT tuple = (f8e5m2fnuz[], f8e5m2fnuz[], f8e5m2fnuz[], f8e5m2fnuz[]) tuple( - converted_u8, converted_u16, converted_u32, converted_u64) - } - )"); -} - -XLA_TEST_F(ElementalIrEmitterExecutionTest, ConvertF8E5FNUZToFloat) { - RunTypeConversionTest(R"( - HloModule convertFromF8E5FNUZ - ENTRY ConvertFromF8E5FNUZ - (to_f16 f8e5m2fnuz[], to_f32 f8e5m2fnuz[], to_f64 f8e5m2fnuz[]) -> (f16[], f32[], f64[]) { - to_f16 = f8e5m2fnuz[] parameter(0) - to_f32 = f8e5m2fnuz[] parameter(1) - to_f64 = f8e5m2fnuz[] parameter(2) - f16_ = f16[] convert(f8e5m2fnuz[] to_f16) - f32_ = f32[] convert(f8e5m2fnuz[] to_f32) - f64_ = f64[] convert(f8e5m2fnuz[] to_f64) - ROOT tuple = (f16[], f32[], f64[]) tuple(f16_, f32_, f64_) - } - )"); -} - -XLA_TEST_F(ElementalIrEmitterExecutionTest, ConvertF8E5FNUZToSigned) { - RunTypeConversionTest(R"( - HloModule convertFromF8E5FNUZ - ENTRY ConvertFromF8E5FNUZ(to_s8 f8e5m2fnuz[], to_s16 f8e5m2fnuz[], to_s32 f8e5m2fnuz[], - to_s64 f8e5m2fnuz[]) -> (s8[], s16[], s32[], s64[]) { - to_s8 = f8e5m2fnuz[] parameter(0) - to_s16 = f8e5m2fnuz[] parameter(1) - to_s32 = f8e5m2fnuz[] parameter(2) - to_s64 = f8e5m2fnuz[] parameter(3) - s8_ = s8[] convert(f8e5m2fnuz[] to_s8) - s16_ = s16[] convert(f8e5m2fnuz[] to_s16) - s32_ = s32[] convert(f8e5m2fnuz[] to_s32) - s64_ = s64[] convert(f8e5m2fnuz[] to_s64) - ROOT tuple = (s8[], s16[], s32[], s64[]) tuple(s8_, s16_, s32_, s64_) - } - )"); -} - -XLA_TEST_F(ElementalIrEmitterExecutionTest, ConvertF8E5FNUZToUnsigned) { - RunTypeConversionTest(R"( - HloModule convertFromF8E5FNUZ - ENTRY ConvertFromF8E5FNUZ(to_u8 f8e5m2fnuz[], to_u16 f8e5m2fnuz[], to_u32 f8e5m2fnuz[], - to_u64 f8e5m2fnuz[]) -> (u8[], u16[], u32[], u64[]) { - to_u8 = f8e5m2fnuz[] parameter(0) - to_u16 = f8e5m2fnuz[] parameter(1) - to_u32 = f8e5m2fnuz[] parameter(2) - to_u64 = f8e5m2fnuz[] parameter(3) - u8_ = u8[] convert(f8e5m2fnuz[] to_u8) - u16_ = u16[] convert(f8e5m2fnuz[] to_u16) - u32_ = u32[] convert(f8e5m2fnuz[] to_u32) - u64_ = u64[] convert(f8e5m2fnuz[] to_u64) - ROOT tuple = (u8[], u16[], u32[], u64[]) tuple(u8_, u16_, u32_, u64_) - } - )"); -} - -XLA_TEST_F(ElementalIrEmitterExecutionTest, ConvertF8E5FNUZToComplex) { - RunTypeConversionTest(R"( - HloModule convertFromF8E5FNUZ - ENTRY ConvertFromF8E5FNUZ - (to_c64 f8e5m2fnuz[], to_c128 f8e5m2fnuz[]) -> (c64[], c128[]) { - to_c64 = f8e5m2fnuz[] parameter(0) - to_c128 = f8e5m2fnuz[] parameter(1) - c64_ = c64[] convert(f8e5m2fnuz[] to_c64) - c128_ = c128[] convert(f8e5m2fnuz[] to_c128) - ROOT tuple = (c64[], c128[]) tuple(c64_, c128_) - } - )"); -} - -XLA_TEST_F(ElementalIrEmitterExecutionTest, CompareF8E5FNUZ) { - constexpr char hlo_text[] = R"( - HloModule compareF8E5FNUZ - ENTRY main { - p0 = f8e5m2fnuz[4] parameter(0) - p1 = f8e5m2fnuz[4] parameter(1) - ROOT cmp = pred[4] compare(p0, p1), direction=LT -})"; - - Literal lhs = LiteralUtil::CreateR1({1, 2, 3, 4}); - Literal rhs = LiteralUtil::CreateR1({4, 3, 2, 1}); - lhs = LiteralUtil::ConvertF32ToF8E5M2FNUZ(lhs); - rhs = LiteralUtil::ConvertF32ToF8E5M2FNUZ(rhs); - RunTest(hlo_text, {&lhs, &rhs}); -} +TYPED_TEST(ElementalIrEmitterExecutionTypedTest, BatchDotFloat) { + auto tname = this->TypeName(); + const auto hlo_text = absl::StrReplaceAll(R"( + HloModule matmul -XLA_TEST_F(ElementalIrEmitterExecutionTest, IotaF8E5FNUZ) { - constexpr char hlo_text[] = R"( - HloModule IotaF8E5FNUZ ENTRY main { - ROOT iota_ = f8e5m2fnuz[4] iota(), iota_dimension=0 + x = ${tname}[8,16] parameter(0) + y = ${tname}[8,16,32] parameter(1) + ROOT dot = ${tname}[8,32] dot(x, y), lhs_batch_dims={0}, + rhs_batch_dims={0}, lhs_contracting_dims={1}, rhs_contracting_dims={1} } - )"; + )", + {{"${tname}", tname}}); + HloModuleConfig config; + DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest(); + config.set_debug_options(debug_options); - RunTest(hlo_text, {}); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + HloTestBase::ParseAndReturnVerifiedModule(hlo_text, config)); + EXPECT_TRUE( + HloTestBase::RunAndCompare(std::move(module), ErrorSpec{1e-5, 1e-5})); } XLA_TEST_F(ElementalIrEmitterExecutionTestWithoutFastMinMax, From a773e97308d847200076b39ba42da682f09c1603 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 20 Sep 2024 01:23:12 -0700 Subject: [PATCH 050/483] Automated Code Change PiperOrigin-RevId: 676742825 --- third_party/xla/xla/service/gpu/fusions/mlir/BUILD | 5 +++++ .../xla/service/gpu/fusions/mlir/computation_partitioner.cc | 6 +++++- .../xla/service/gpu/fusions/mlir/computation_partitioner.h | 2 ++ .../xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.cc | 1 + .../xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h | 2 ++ .../service/gpu/fusions/mlir/elemental_hlo_to_mlir_test.cc | 5 ++--- .../xla/xla/service/gpu/fusions/mlir/mlir_fusion_emitter.cc | 1 - .../service/gpu/fusions/mlir/mlir_fusion_emitter_test.cc | 1 + 8 files changed, 18 insertions(+), 5 deletions(-) diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/BUILD b/third_party/xla/xla/service/gpu/fusions/mlir/BUILD index f6890b24806648..4bf1a10b2f7025 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/BUILD +++ b/third_party/xla/xla/service/gpu/fusions/mlir/BUILD @@ -21,6 +21,7 @@ cc_library( ":type_util", "//xla:shape_util", "//xla:union_find", + "//xla:util", "//xla/hlo/ir:hlo", "//xla/service/gpu:hlo_fusion_analysis", "//xla/service/gpu/fusions:fusion_emitter", @@ -39,6 +40,7 @@ cc_library( "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:Support", ], ) @@ -68,6 +70,7 @@ cc_library( "//xla:comparison_util", "//xla:shape_util", "//xla:status_macros", + "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/mlir/utils:type_util", "//xla/mlir_hlo", @@ -130,6 +133,7 @@ xla_cc_test( "@com_google_googletest//:gtest", "@llvm-project//llvm:Support", "@llvm-project//mlir:AffineDialect", + "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:AsmParser", "@llvm-project//mlir:DLTIDialect", "@llvm-project//mlir:FuncDialect", @@ -240,6 +244,7 @@ xla_cc_test( "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest", "@llvm-project//llvm:Support", diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/computation_partitioner.cc b/third_party/xla/xla/service/gpu/fusions/mlir/computation_partitioner.cc index 53d8678e953074..9d3b772b7f2bac 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/computation_partitioner.cc +++ b/third_party/xla/xla/service/gpu/fusions/mlir/computation_partitioner.cc @@ -32,9 +32,10 @@ limitations under the License. #include "absl/strings/str_join.h" #include "absl/types/span.h" #include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMAttrs.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" @@ -43,15 +44,18 @@ limitations under the License. #include "mlir/IR/Types.h" #include "mlir/IR/Value.h" #include "mlir/Interfaces/DataLayoutInterfaces.h" +#include "mlir/Support/LLVM.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/service/gpu/fusions/fusion_emitter.h" #include "xla/service/gpu/fusions/mlir/type_util.h" #include "xla/service/gpu/hlo_fusion_analysis.h" #include "xla/service/gpu/model/indexing_analysis.h" #include "xla/service/gpu/model/indexing_map.h" #include "xla/service/llvm_ir/llvm_util.h" #include "xla/shape.h" +#include "xla/shape_util.h" namespace xla { namespace gpu { diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/computation_partitioner.h b/third_party/xla/xla/service/gpu/fusions/mlir/computation_partitioner.h index 5d5c78c4cd64aa..5271cd1afa1754 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/computation_partitioner.h +++ b/third_party/xla/xla/service/gpu/fusions/mlir/computation_partitioner.h @@ -20,6 +20,7 @@ limitations under the License. #include #include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/types/span.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/Builders.h" @@ -31,6 +32,7 @@ limitations under the License. #include "xla/service/gpu/fusions/fusion_emitter.h" #include "xla/service/gpu/hlo_fusion_analysis.h" #include "xla/service/gpu/model/indexing_map.h" +#include "xla/util.h" namespace xla { namespace gpu { diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.cc b/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.cc index d94c6d5a038461..68b3ab29359214 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.cc +++ b/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.cc @@ -78,6 +78,7 @@ limitations under the License. #include "xla/shape_util.h" #include "xla/status_macros.h" #include "xla/translate/hlo_to_mhlo/hlo_utils.h" +#include "xla/xla_data.pb.h" #include "tsl/platform/statusor.h" namespace xla { diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h b/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h index 8f109aa3f452fe..c92f2445709e89 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h +++ b/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h @@ -22,8 +22,10 @@ limitations under the License. #include "llvm/ADT/SmallVector.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/AffineExpr.h" +#include "mlir/IR/Builders.h" #include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/TypeRange.h" +#include "mlir/IR/Types.h" #include "mlir/IR/Value.h" #include "mlir/IR/ValueRange.h" #include "mlir/Support/LLVM.h" diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir_test.cc b/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir_test.cc index 33046ea54085aa..e683e199ed03ca 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir_test.cc @@ -23,6 +23,7 @@ limitations under the License. #include "llvm/Support/raw_ostream.h" #include "mlir/AsmParser/AsmParser.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/DLTI/DLTI.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" @@ -31,6 +32,7 @@ limitations under the License. #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/ImplicitLocOpBuilder.h" +#include "mlir/IR/Location.h" #include "mlir/IR/MLIRContext.h" #include "mlir/Pass/PassManager.h" #include "mlir/Transforms/Passes.h" @@ -38,13 +40,10 @@ limitations under the License. #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/service/gpu/fusions/ir/xla_gpu_ops.h" #include "xla/service/gpu/fusions/mlir/computation_partitioner.h" -#include "xla/service/gpu/launch_dimensions.h" -#include "xla/service/gpu/model/indexing_analysis.h" #include "xla/service/gpu/model/indexing_map.h" #include "xla/service/hlo_parser.h" #include "xla/service/llvm_ir/llvm_util.h" #include "xla/status_macros.h" -#include "xla/stream_executor/launch_dim.h" #include "xla/tests/filecheck.h" #include "xla/tests/hlo_test_base.h" #include "xla/tsl/lib/core/status_test_util.h" diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/mlir_fusion_emitter.cc b/third_party/xla/xla/service/gpu/fusions/mlir/mlir_fusion_emitter.cc index 45d66ebca0108f..448d3050c30bd4 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/mlir_fusion_emitter.cc +++ b/third_party/xla/xla/service/gpu/fusions/mlir/mlir_fusion_emitter.cc @@ -96,7 +96,6 @@ limitations under the License. #include "xla/service/gpu/kernel_arguments.h" #include "xla/service/gpu/kernel_reuse_cache.h" #include "xla/service/gpu/launch_dimensions.h" -#include "xla/service/gpu/model/indexing_map.h" #include "xla/service/gpu/runtime/kernel_thunk.h" #include "xla/service/gpu/target_util.h" #include "xla/service/llvm_ir/llvm_util.h" diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/mlir_fusion_emitter_test.cc b/third_party/xla/xla/service/gpu/fusions/mlir/mlir_fusion_emitter_test.cc index f896d5e6b37475..35e14a98200c66 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/mlir_fusion_emitter_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/mlir/mlir_fusion_emitter_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include "absl/status/status.h" +#include "absl/strings/str_replace.h" #include "absl/strings/string_view.h" #include "llvm/IR/LLVMContext.h" #include "llvm/Support/raw_ostream.h" From 092dcc7103556869bb25970d15836eadf452951d Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 20 Sep 2024 02:02:18 -0700 Subject: [PATCH 051/483] Update GraphDef version to 1991. PiperOrigin-RevId: 676754707 --- tensorflow/core/public/version.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index e7fe3c0842ed14..b427efa92f342f 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -108,7 +108,7 @@ limitations under the License. #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0 -#define TF_GRAPH_DEF_VERSION 1990 // Updated: 2024/9/19 +#define TF_GRAPH_DEF_VERSION 1991 // Updated: 2024/9/20 // Checkpoint compatibility versions (the versions field in SavedSliceMeta). // From 598ebd50a4ba4a6b06b062fdb987b0740e6a073f Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 20 Sep 2024 02:02:18 -0700 Subject: [PATCH 052/483] compat: Update forward compatibility horizon to 2024-09-20 PiperOrigin-RevId: 676754708 --- tensorflow/python/compat/compat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index 6f2b1be39c6ebd..9561bd3a5f1fde 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -29,7 +29,7 @@ # This value changes every day with an automatic CL. It can be modified in code # via `forward_compatibility_horizon()` or with the environment variable # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date. -_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2024, 9, 19) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2024, 9, 20) _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS" _FORWARD_COMPATIBILITY_DATE_NUMBER = None From 37c5bd1983aa5974ac3da93375d9c466426ebde8 Mon Sep 17 00:00:00 2001 From: Alan Kelly Date: Fri, 20 Sep 2024 02:35:57 -0700 Subject: [PATCH 053/483] Pad dynamically allocated tensors with XNN_EXTRA_BYTES Since XNNPack can now be applied to all models, these buffers may be passed to XNNPack. PiperOrigin-RevId: 676764594 --- tensorflow/lite/core/c/common.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tensorflow/lite/core/c/common.cc b/tensorflow/lite/core/c/common.cc index d2cb82199bfce0..09e71d578f4c25 100644 --- a/tensorflow/lite/core/c/common.cc +++ b/tensorflow/lite/core/c/common.cc @@ -295,7 +295,8 @@ TfLiteStatus TfLiteTensorResizeMaybeCopy(size_t num_bytes, TfLiteTensor* tensor, #ifdef TF_LITE_TENSORFLOW_PROFILER tflite::PauseHeapMonitoring(/*pause=*/true); #endif - size_t alloc_bytes = num_bytes; + // This buffer may be consumed by XNNPack. + size_t alloc_bytes = num_bytes + /*XNN_EXTRA_BYTES=*/16; // TODO(b/145340303): Tensor data should be aligned. if (!tensor->data.data) { tensor->data.data = (char*)malloc(alloc_bytes); From 8d06f3055e9556d421afb8abae5668d6edcaedce Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 20 Sep 2024 03:00:43 -0700 Subject: [PATCH 054/483] Automated Code Change PiperOrigin-RevId: 676771234 --- tensorflow/lite/minimal_logging.cc | 2 ++ tensorflow/lite/model_flex_test.cc | 1 - tensorflow/lite/model_xnnpack_test.cc | 1 - tensorflow/lite/mutable_op_resolver_test.cc | 1 - 4 files changed, 2 insertions(+), 3 deletions(-) diff --git a/tensorflow/lite/minimal_logging.cc b/tensorflow/lite/minimal_logging.cc index bdcec47e779359..7b5e4f6245a567 100644 --- a/tensorflow/lite/minimal_logging.cc +++ b/tensorflow/lite/minimal_logging.cc @@ -17,6 +17,8 @@ limitations under the License. #include +#include "tensorflow/lite/logger.h" + namespace tflite { namespace logging_internal { diff --git a/tensorflow/lite/model_flex_test.cc b/tensorflow/lite/model_flex_test.cc index 987dcc4f234eac..c2257a6e393b83 100644 --- a/tensorflow/lite/model_flex_test.cc +++ b/tensorflow/lite/model_flex_test.cc @@ -19,7 +19,6 @@ limitations under the License. #include "tensorflow/lite/core/interpreter_builder.h" #include "tensorflow/lite/core/kernels/register.h" #include "tensorflow/lite/core/model_builder.h" -#include "tensorflow/lite/testing/util.h" namespace tflite { diff --git a/tensorflow/lite/model_xnnpack_test.cc b/tensorflow/lite/model_xnnpack_test.cc index 64e8104cb9874d..740518dc05cf54 100644 --- a/tensorflow/lite/model_xnnpack_test.cc +++ b/tensorflow/lite/model_xnnpack_test.cc @@ -23,7 +23,6 @@ limitations under the License. #include "tensorflow/lite/core/kernels/register.h" #include "tensorflow/lite/core/macros.h" #include "tensorflow/lite/core/model_builder.h" -#include "tensorflow/lite/string_type.h" #include "tensorflow/lite/util.h" namespace tflite { diff --git a/tensorflow/lite/mutable_op_resolver_test.cc b/tensorflow/lite/mutable_op_resolver_test.cc index 8622579a3c8aa3..6a76f09575e0bf 100644 --- a/tensorflow/lite/mutable_op_resolver_test.cc +++ b/tensorflow/lite/mutable_op_resolver_test.cc @@ -20,7 +20,6 @@ limitations under the License. #include #include "tensorflow/lite/core/c/common.h" #include "tensorflow/lite/schema/schema_generated.h" -#include "tensorflow/lite/testing/util.h" namespace tflite { namespace { From b3cf7de169949d1eaf6f47ddaa1f0dcfc6af0bc4 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 20 Sep 2024 03:23:13 -0700 Subject: [PATCH 055/483] Automated Code Change PiperOrigin-RevId: 676777271 --- tensorflow/lite/core/c/BUILD | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/tensorflow/lite/core/c/BUILD b/tensorflow/lite/core/c/BUILD index 6e0066185483ff..00a1a27ec6d819 100644 --- a/tensorflow/lite/core/c/BUILD +++ b/tensorflow/lite/core/c/BUILD @@ -335,10 +335,7 @@ tflite_cc_library_with_c_headers_test( ], compatible_with = get_compatible_with_portable(), copts = tflite_copts(), - visibility = [ - "//tensorflow/lite:__subpackages__", - "@org_tensorflow_lite_support//tensorflow_lite_support/custom_ops:__subpackages__", - ] + common_header_visibility_allowlist(), + visibility = ["//tensorflow/lite:__subpackages__"] + common_header_visibility_allowlist(), deps = [ ":c_api_types", "//tensorflow/lite:tflite_kernel_use_xnnpack_optional", From c772d4766c0ac1e5565016bee7264ab7958220d8 Mon Sep 17 00:00:00 2001 From: Ilia Sergachev Date: Fri, 20 Sep 2024 03:37:49 -0700 Subject: [PATCH 056/483] PR #17369: [GPU] Guard NCCL headers with if_nccl. Imported from GitHub PR https://github.com/openxla/xla/pull/17369 Copybara import of the project: -- b29838ffcb2be1b7cfbf0145514ff7c04ac97f52 by Ilia Sergachev : [GPU] Guard NCCL headers with if_nccl. Merging this change closes #17369 PiperOrigin-RevId: 676781148 --- third_party/xla/xla/tsl/cuda/BUILD.bazel | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/xla/tsl/cuda/BUILD.bazel b/third_party/xla/xla/tsl/cuda/BUILD.bazel index 393cc5d03648cb..002fe484ae6990 100644 --- a/third_party/xla/xla/tsl/cuda/BUILD.bazel +++ b/third_party/xla/xla/tsl/cuda/BUILD.bazel @@ -13,6 +13,7 @@ load( load( "//xla/tsl:tsl.bzl", "if_cuda_libs", + "if_nccl", ) load("//xla/tsl/cuda:stub.bzl", "cuda_stub") @@ -348,11 +349,10 @@ cc_library( deps = if_cuda_is_configured([ "@com_google_absl//absl/container:flat_hash_set", "@local_config_cuda//cuda:cuda_headers", - "@local_config_nccl//:nccl_headers", "@local_tsl//tsl/platform:dso_loader", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:load_library", - ]), + ]) + if_nccl(["@local_config_nccl//:nccl"]), ) alias( From fc052337811e951087e057955a3cf40f1d665729 Mon Sep 17 00:00:00 2001 From: Shraiysh Date: Fri, 20 Sep 2024 03:45:25 -0700 Subject: [PATCH 057/483] PR #17257: Fix the command buffer scheduling pass return value Imported from GitHub PR https://github.com/openxla/xla/pull/17257 Fixes #17216. Returns true when either parameters are moved, or command buffer is created. Copybara import of the project: -- 410974f533c1846386f001ed97d93a8291e77b1a by Shraiysh Vaishay : Fix the command buffer scheduling pass return value Fixes #17216. Returns true when either parameters are moved, or command buffer is created. Merging this change closes #17257 PiperOrigin-RevId: 676782971 --- .../transforms/command_buffer_scheduling.cc | 19 ++++++-- .../transforms/command_buffer_scheduling.h | 4 +- .../command_buffer_scheduling_test.cc | 47 +++++++++++++++++++ 3 files changed, 64 insertions(+), 6 deletions(-) diff --git a/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling.cc b/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling.cc index 3d8f11dd0dc7b6..641a37a9659d30 100644 --- a/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling.cc +++ b/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling.cc @@ -438,7 +438,9 @@ CommandBufferScheduling::CollectCommandBufferSequences( // the beginning of the computation. This simplifies the construction of command // buffer computations because we don't need to deal with parameters and // constants that have users outside of a command buffer. -absl::Status CommandBufferScheduling::MoveParametersAndConstantsToFront( +// Returns true if there is a change in the order of instructions, false +// otherwise. +absl::StatusOr CommandBufferScheduling::MoveParametersAndConstantsToFront( HloComputation* computation) { HloInstructionSequence new_sequence; HloSchedule& schedule = computation->parent()->schedule(); @@ -468,7 +470,11 @@ absl::Status CommandBufferScheduling::MoveParametersAndConstantsToFront( } schedule.set_sequence(computation, new_sequence); - return absl::OkStatus(); + for (auto [old_i, new_i] : + llvm::zip(sequence.instructions(), new_sequence.instructions())) { + if (old_i != new_i) return true; + } + return false; } //===----------------------------------------------------------------------===// @@ -767,7 +773,7 @@ absl::StatusOr CommandBufferScheduling::Run( if (std::min(device_description_.runtime_version(), device_description_.driver_version()) < se::SemanticVersion{12, 3, 0}) { - erase(kRequireTracing); // cuStreamBeginCaptureToGraph + erase(kRequireTracing); // cuStreamBeginCaptureToGraph } if (std::min(device_description_.runtime_version(), device_description_.driver_version()) < @@ -787,6 +793,7 @@ absl::StatusOr CommandBufferScheduling::Run( std::reverse(order.begin(), order.end()); absl::flat_hash_set processed_command_buffers; + auto changed = false; for (HloComputation* comp : order) { // Skip special computations that do not have lowering to thunks. if (comp->IsFusionComputation() || comp->IsAsyncComputation() || @@ -796,7 +803,8 @@ absl::StatusOr CommandBufferScheduling::Run( // Skip computations that already part of command buffers. if (processed_command_buffers.contains(comp)) continue; - TF_RETURN_IF_ERROR(MoveParametersAndConstantsToFront(comp)); + TF_ASSIGN_OR_RETURN(bool changed_, MoveParametersAndConstantsToFront(comp)); + changed |= changed_; std::vector sequences = CollectCommandBufferSequences( @@ -809,6 +817,7 @@ absl::StatusOr CommandBufferScheduling::Run( TF_ASSIGN_OR_RETURN( HloComputation * command_buffer_computation, RewriteCommandBuffer(comp, seq, std::move(command_buffer))); + changed = true; // All computations reachable from a command buffer computation are nested // command buffers (i.e. body computations attached to a while operation). @@ -820,7 +829,7 @@ absl::StatusOr CommandBufferScheduling::Run( } TF_RETURN_IF_ERROR(module->schedule().Update()); - return true; + return changed; } } // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling.h b/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling.h index 15f0b2dd4d4da9..71d5b421c1ee56 100644 --- a/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling.h +++ b/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling.h @@ -99,7 +99,9 @@ class CommandBufferScheduling : public HloModulePass { // the beginning of the computation. This simplifies the construction of // command buffer computations because we don't need to deal with parameters // and constants that have users outside of a command buffer. - static absl::Status MoveParametersAndConstantsToFront( + // Returns true if there is a change in the order of instructions, false + // otherwise. + static absl::StatusOr MoveParametersAndConstantsToFront( HloComputation* computation); struct CommandBuffer { diff --git a/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling_test.cc b/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling_test.cc index 6c79316a75a519..be29a09897b54b 100644 --- a/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling_test.cc @@ -1228,5 +1228,52 @@ TEST_F(CommandBufferSchedulingTest, DynamicSliceFusionStaticSlicing) { false, true, std::nullopt)); } +TEST_F(CommandBufferSchedulingTest, ReturnFalseWhenNoChange) { + const char* hlo = R"( + HloModule module, is_scheduled=true + ENTRY main { + a = s32[8,8] parameter(0) + b = s32[8,8] parameter(1) + ROOT call = s32[8,8] custom-call(a,b), custom_call_target="__cublas$gemm" + } + )"; + + HloModuleConfig config; + DebugOptions options = GetDebugOptionsForTest(); + options.clear_xla_gpu_enable_command_buffer(); + options.add_xla_gpu_enable_command_buffer(DebugOptions::COLLECTIVES); + config.set_debug_options(options); + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo, config)); + RunAndFilecheckHloRewrite(hlo, CommandBufferScheduling(device_desc()), + std::nullopt); +} + +TEST_F(CommandBufferSchedulingTest, ReturnTrueWhenOnlyParamMoved) { + const char* hlo = R"( + HloModule module, is_scheduled=true + ENTRY main { + a = s32[8,8] parameter(0) + b = s32[8,8] parameter(1) + call = s32[8,8] custom-call(a,b), custom_call_target="__cublas$gemm" + c = s32[8,8] parameter(2) + ROOT call2 = s32[8,8] custom-call(call, c), custom_call_target="__cublas$gemm" + } + )"; + + HloModuleConfig config; + DebugOptions options = GetDebugOptionsForTest(); + options.clear_xla_gpu_enable_command_buffer(); + options.add_xla_gpu_enable_command_buffer(DebugOptions::COLLECTIVES); + config.set_debug_options(options); + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo, config)); + RunAndFilecheckHloRewrite(hlo, CommandBufferScheduling(device_desc()), R"( + // CHECK: %{{.+}} = {{.+}} parameter(0) + // CHECK: %{{.+}} = {{.+}} parameter(1) + // CHECK: %{{.+}} = {{.+}} parameter(2) + // CHECK: %{{.+}} = {{.+}} custom-call + // CHECK: %{{.+}} = {{.+}} custom-call + )"); +} + } // namespace } // namespace xla::gpu From 7566d2e736f7e52885763419cce0450468348ddb Mon Sep 17 00:00:00 2001 From: Penporn Koanantakool Date: Fri, 20 Sep 2024 04:10:54 -0700 Subject: [PATCH 058/483] [xla:cpu] Specialize SortThunk kernels for up to 25 inputs. Fixes https://github.com/jax-ml/jax/issues/23727 This is a temporary fix. We will add a fallback sort kernel soon. PiperOrigin-RevId: 676789960 --- .../xla/backends/cpu/runtime/sort_thunk.cc | 21 +++++++++++++++++++ third_party/xla/xla/tests/sort_test.cc | 19 ++++++++++++----- 2 files changed, 35 insertions(+), 5 deletions(-) diff --git a/third_party/xla/xla/backends/cpu/runtime/sort_thunk.cc b/third_party/xla/xla/backends/cpu/runtime/sort_thunk.cc index fa892d2df54134..30c5e1a1b34897 100644 --- a/third_party/xla/xla/backends/cpu/runtime/sort_thunk.cc +++ b/third_party/xla/xla/backends/cpu/runtime/sort_thunk.cc @@ -471,6 +471,27 @@ static absl::Status SortInplace(absl::Span data, case 17: sort(std::integral_constant{}); break; + case 18: + sort(std::integral_constant{}); + break; + case 19: + sort(std::integral_constant{}); + break; + case 20: + sort(std::integral_constant{}); + break; + case 21: + sort(std::integral_constant{}); + break; + case 22: + sort(std::integral_constant{}); + break; + case 23: + sort(std::integral_constant{}); + break; + case 24: + sort(std::integral_constant{}); + break; case 25: sort(std::integral_constant{}); break; diff --git a/third_party/xla/xla/tests/sort_test.cc b/third_party/xla/xla/tests/sort_test.cc index 3acef7e48170b7..3e25d8ba0039c4 100644 --- a/third_party/xla/xla/tests/sort_test.cc +++ b/third_party/xla/xla/tests/sort_test.cc @@ -91,10 +91,17 @@ XLA_TEST_F(SortTest, SortTwiceWithSameComparator) { EXPECT_TRUE(RunAndCompare(hlo_text_module, ErrorSpec{0.0, 0.0})); } -// TODO(penporn): Parameterize `num_inputs` and test several numbers when we -// have a generic fallback sort kernel. -XLA_TEST_F(SortTest, SortManyInputs) { - constexpr int num_inputs = 17; +class SortManyInputsTest : public SortTest, + public ::testing::WithParamInterface { + public: + static std::string Name(const ::testing::TestParamInfo& info) { + auto num_inputs = info.param; + return absl::StrFormat("Sort%dInputs", num_inputs); + } +}; + +XLA_TEST_P(SortManyInputsTest, SortManyInputs) { + int num_inputs = GetParam(); std::string_view hlo_text_module_template = R"( HloModule sort @@ -133,9 +140,11 @@ XLA_TEST_F(SortTest, SortManyInputs) { {"${SORT_SHAPE}", sort_shape}, {"${SORT_PARAMS}", sort_params}, {"${COMPARE_DECLARATIONS}", compare_decls}}); - EXPECT_TRUE(RunAndCompare(hlo_text_module, ErrorSpec{0.0, 0.0})); } +INSTANTIATE_TEST_SUITE_P(ManyInputs, SortManyInputsTest, + ::testing::Values(17, 20), SortManyInputsTest::Name); + } // namespace } // namespace xla From 40252ea9c5885638c2ab35c04b4e483b0cef983d Mon Sep 17 00:00:00 2001 From: Johannes Reifferscheid Date: Fri, 20 Sep 2024 05:01:51 -0700 Subject: [PATCH 059/483] Rewrite column reductions to reduce-transpose-reduce. PiperOrigin-RevId: 676801883 --- .../fusions/transforms/rewrite_reductions.cc | 227 +++++++++++++++--- .../transforms/tests/rewrite_reductions.mlir | 54 ++++- 2 files changed, 234 insertions(+), 47 deletions(-) diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/rewrite_reductions.cc b/third_party/xla/xla/service/gpu/fusions/transforms/rewrite_reductions.cc index 620fa8e51d30a8..50969b8bd6bbd8 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/rewrite_reductions.cc +++ b/third_party/xla/xla/service/gpu/fusions/transforms/rewrite_reductions.cc @@ -12,14 +12,18 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include #include #include #include +#include "llvm/ADT/SmallBitVector.h" #include "llvm/ADT/SmallVector.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/AffineExpr.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/IR/Operation.h" #include "mlir/IR/PatternMatch.h" @@ -30,6 +34,7 @@ limitations under the License. #include "xla/service/gpu/fusions/ir/xla_gpu_ops.h" #include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/gpu/model/indexing_analysis.h" +#include "xla/service/gpu/model/indexing_map.h" #include "xla/util.h" namespace xla { @@ -62,23 +67,55 @@ int GetNumThreads(mlir::Operation* op) { return Product(grid.getThreadCounts()); } -std::pair GetNumAndSizeOfMinorReducedDimensions(ReduceOp op) { +struct DimensionGroup { + int64_t size; + int64_t stride; + int first_dimension; + int num_dimensions; +}; + +DimensionGroup GetMinorMostReduction(ReduceOp op) { llvm::ArrayRef dims = op.getDimensions(); + auto input_ty = GetInputType(op); - int64_t cumulative_size = 1; - for (int i = 0; i < dims.size(); ++i) { - // The expected next reduction dimension if it is contiguous with the - // previously reduced dimensions. - int expected_dim = input_ty.getRank() - 1 - i; - // If the next reduced dimension is not the expected one, it is not - // contiguous (i.e., it's not part of the minor reduced dimensions, there is - // a kept dimension in between). - if (dims[dims.size() - 1 - i] != expected_dim) { - return {i, cumulative_size}; + DimensionGroup result{1, 1, static_cast(input_ty.getRank()), 0}; + llvm::SmallBitVector reduced_dims(input_ty.getRank()); + for (int64_t dim : dims) { + reduced_dims.set(dim); + } + + // Look for the first group of consecutive reduced dimensions and compute the + // stride and size of the group. + bool in_reduction = false; + for (int dim = input_ty.getRank() - 1; + dim >= 0 && (!in_reduction || reduced_dims[dim]); --dim) { + assert(input_ty.getDimSize(dim) > 1 && + "degenerate dimensions are not allowed"); + --result.first_dimension; + if (reduced_dims[dim]) { + in_reduction = true; + result.size *= input_ty.getDimSize(dim); + ++result.num_dimensions; + } else { + result.stride *= input_ty.getDimSize(dim); } - cumulative_size *= input_ty.getDimSize(input_ty.getRank() - 1 - i); } - return {dims.size(), cumulative_size}; + + return result; +} + +llvm::SmallVector ReindexTensors( + mlir::OpBuilder& b, mlir::ValueRange tensors, mlir::ValueRange defaults, + llvm::ArrayRef new_shape, const IndexingMap& map) { + llvm::SmallVector reindexed; + reindexed.reserve(tensors.size()); + for (auto [tensor, def] : llvm::zip(tensors, defaults)) { + auto new_ty = + mlir::cast(tensor.getType()).clone(new_shape); + reindexed.push_back( + b.create(tensor.getLoc(), new_ty, tensor, def, map)); + } + return reindexed; } // Rewrites large row reductions to three reductions: @@ -94,13 +131,12 @@ struct RewriteRowReduction : mlir::OpRewritePattern { ReduceOp op, mlir::PatternRewriter& rewriter) const override { auto* ctx = op.getContext(); - auto [num_minor_dims, reduced_size] = - GetNumAndSizeOfMinorReducedDimensions(op); - if (num_minor_dims == 0) { + auto minor_reduction = GetMinorMostReduction(op); + if (minor_reduction.stride > 1) { return rewriter.notifyMatchFailure(op, "not a row reduction"); } - if (reduced_size <= WarpSize()) { + if (minor_reduction.size <= WarpSize()) { return rewriter.notifyMatchFailure(op, "small minor dimension"); } @@ -108,9 +144,9 @@ struct RewriteRowReduction : mlir::OpRewritePattern { assert(num_threads % WarpSize() == 0); llvm::ArrayRef input_shape = GetInputType(op).getShape(); - llvm::SmallVector projected_input_shape{ - input_shape.begin(), input_shape.end() - num_minor_dims}; - projected_input_shape.push_back(reduced_size); + auto projected_input_shape = llvm::to_vector( + input_shape.take_front(minor_reduction.first_dimension)); + projected_input_shape.push_back(minor_reduction.size); // Collapse the minor dimensions into one. // [..., 123, 456] -> [..., 123 * 456] @@ -120,14 +156,14 @@ struct RewriteRowReduction : mlir::OpRewritePattern { // Pad the new minor dimension to a multiple of the number of threads. For // example, for 128 threads, 123 * 456 = 56088 is padded to 56192. auto padded_projected_input_shape = projected_input_shape; - int64_t padded_size = RoundUpTo(reduced_size, num_threads); + int64_t padded_size = RoundUpTo(minor_reduction.size, num_threads); padded_projected_input_shape.back() = padded_size; // Reshape the padded minor dimension so that we can reduce it per thread // and then per warp. // [..., 56192] -> [..., 439, 4, 32] - llvm::SmallVector per_thread_reduction_input_shape( - input_shape.begin(), input_shape.end() - num_minor_dims); + auto per_thread_reduction_input_shape = llvm::to_vector( + input_shape.take_front(minor_reduction.first_dimension)); per_thread_reduction_input_shape.push_back(padded_size / num_threads); per_thread_reduction_input_shape.push_back(num_threads / WarpSize()); per_thread_reduction_input_shape.push_back(WarpSize()); @@ -141,24 +177,18 @@ struct RewriteRowReduction : mlir::OpRewritePattern { mlir::getAffineDimExpr(per_thread_input_rank - 1, ctx) + mlir::getAffineDimExpr(per_thread_input_rank - 2, ctx) * num_threads, - {0, reduced_size - 1}); - - // Reshape the inputs. - llvm::SmallVector new_operands; - new_operands.reserve(op.getOperands().size()); - for (auto [operand, init] : llvm::zip(op.getInputs(), op.getInits())) { - auto new_input_ty = mlir::cast(operand.getType()) - .clone(per_thread_reduction_input_shape); - new_operands.push_back(rewriter.create( - operand.getLoc(), new_input_ty, operand, init, reindex_map)); - } + {0, minor_reduction.size - 1}); + + auto new_inputs = + ReindexTensors(rewriter, op.getInputs(), op.getInits(), + per_thread_reduction_input_shape, reindex_map); // Reduce the non-minor dimensions and the third to last dimension. - auto dims_for_first_reduction = - llvm::to_vector(op.getDimensions().drop_back(num_minor_dims)); + auto dims_for_first_reduction = llvm::to_vector( + op.getDimensions().drop_back(minor_reduction.num_dimensions)); dims_for_first_reduction.push_back(per_thread_input_rank - 3); auto first_reduction = - rewriter.create(op.getLoc(), new_operands, op.getInits(), + rewriter.create(op.getLoc(), new_inputs, op.getInits(), dims_for_first_reduction, op.getCombiner()); // Reduce the last and the second-to-last dimensions. First to produce one @@ -175,9 +205,130 @@ struct RewriteRowReduction : mlir::OpRewritePattern { } }; +// Rewrites column reductions to a reduce-transpose-reduce. +struct RewriteColumnReduction : mlir::OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult matchAndRewrite( + ReduceOp op, mlir::PatternRewriter& rewriter) const override { + auto* ctx = op.getContext(); + + auto minor_reduction = GetMinorMostReduction(op); + + if (minor_reduction.stride == 1) { + return rewriter.notifyMatchFailure(op, "not a column reduction"); + } + + int64_t num_threads = GetNumThreads(op); + + // If the stride is larger than the number of threads, we can efficiently + // emit this reduction as a simple loop, assuming there's no excessive + // padding. + // TODO(jreiffers): Is there anything we can do if the number of threads + // doesn't divide the stride? + if (minor_reduction.stride >= num_threads) { + return rewriter.notifyMatchFailure(op, "efficient loop reduction"); + } + + // A column reduction reduces [a, b] to [b]. We do this in four steps: + // 1. reshape [a, b] to [a ceildiv c, c, b] + // 2. reduce [a ceildiv c, c, b] to [c, b] via a loop + // 3. transpose [c, b] to [b, c] + // 4. emit a row reduction on [b, c]. + // + // We are constrained in our choice for `c`: + // + // - we need one element of shared memory (or a register) for each element + // of the intermediate results, so a larger c needs more shared memory. + // - we can have at most WarpSize intermediate results per final result, + // so c can be at most 32. + // - c must be a power of two so we can use a warp shuffle. + // - c * b should be less than the number of threads (but as close to it + // as possible, so we don't have excessive padding). + // + // All of this assumes no vectorization. + // TODO(jreiffers): Handle vectorization here. + + // Emitters always choose `c = 32` if `b` is not a small power of two. + // Also, reductions are tiled so `b = 32`. The number of threads is always + // 1024. This satisfies all the constraints above. + // Reduce the size of the reduction dimension. The maximum size we can + // handle is the warp size. + + assert(num_threads > minor_reduction.stride); + int64_t c = std::min(WarpSize(), num_threads / minor_reduction.stride); + + llvm::ArrayRef input_shape = GetInputType(op).getShape(); + auto projected_input_shape = llvm::to_vector( + input_shape.take_front(minor_reduction.first_dimension)); + projected_input_shape.push_back(minor_reduction.size); + projected_input_shape.push_back(minor_reduction.stride); + auto projection_map = + GetBitcastMap(projected_input_shape, input_shape, ctx); + int64_t projected_rank = projected_input_shape.size(); + + // Pad the new minor dimension to a multiple of c. + auto padded_projected_input_shape = projected_input_shape; + int64_t padded_size = RoundUpTo(minor_reduction.size, c); + padded_projected_input_shape[projected_rank - 2] = padded_size; + + // Reshape the input to [..., a ceildiv c, c, b] + auto reshaped_input_shape = llvm::to_vector( + input_shape.take_front(minor_reduction.first_dimension)); + reshaped_input_shape.push_back(padded_size / c); + reshaped_input_shape.push_back(c); + reshaped_input_shape.push_back(minor_reduction.stride); + int64_t reshaped_rank = reshaped_input_shape.size(); + + auto reindex_map = + GetBitcastMap(reshaped_input_shape, padded_projected_input_shape, ctx) * + projection_map; + reindex_map.AddConstraint( + mlir::getAffineDimExpr(reshaped_rank - 2, ctx) + + mlir::getAffineDimExpr(reshaped_rank - 3, ctx) * c, + {0, minor_reduction.size - 1}); + + auto new_inputs = ReindexTensors(rewriter, op.getInputs(), op.getInits(), + reshaped_input_shape, reindex_map); + + // Reduce the non-minor dimensions and the third to last dimension. + // [..., a ceildiv c, c, b] -> [..., c, b] + auto dims_for_first_reduction = llvm::to_vector( + op.getDimensions().drop_back(minor_reduction.num_dimensions)); + dims_for_first_reduction.push_back(reshaped_rank - 3); + auto first_reduction = + rewriter.create(op.getLoc(), new_inputs, op.getInits(), + dims_for_first_reduction, op.getCombiner()); + + // Transpose [..., c, b] to [..., b, c] + auto shape = GetOutputType(first_reduction).getShape(); + int64_t first_reduction_rank = shape.size(); + llvm::SmallVector permutation(first_reduction_rank); + std::iota(permutation.begin(), permutation.end(), 0); + std::swap(permutation[first_reduction_rank - 1], + permutation[first_reduction_rank - 2]); + + auto transposed_shape = llvm::to_vector(shape); + std::swap(transposed_shape[first_reduction_rank - 1], + transposed_shape[first_reduction_rank - 2]); + IndexingMap transpose_map( + mlir::AffineMap::getPermutationMap(permutation, ctx), + DimVarsFromTensorSizes(transposed_shape), {}, {}); + + auto transposed = + ReindexTensors(rewriter, first_reduction.getResults(), op.getInits(), + transposed_shape, transpose_map); + + rewriter.replaceOpWithNewOp( + op, transposed, op.getInits(), + llvm::ArrayRef{first_reduction_rank - 1}, op.getCombiner()); + return mlir::success(); + } +}; + void RewriteReductionsPass::runOnOperation() { mlir::RewritePatternSet patterns(&getContext()); - patterns.add(&getContext()); + patterns.add(&getContext()); if (mlir::failed(mlir::applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) { signalPassFailure(); diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/tests/rewrite_reductions.mlir b/third_party/xla/xla/service/gpu/fusions/transforms/tests/rewrite_reductions.mlir index fe04e7f95463e0..94c6cddd4a8a40 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/tests/rewrite_reductions.mlir +++ b/third_party/xla/xla/service/gpu/fusions/transforms/tests/rewrite_reductions.mlir @@ -37,8 +37,8 @@ func.func @add(%a: f32, %b: f32) -> f32 { return %0 : f32 } -func.func @row_reduction_with_major_reduced_dim(%arg0: tensor<1x42x128x32x8xf32>) - -> tensor<1x128xf32> attributes { +func.func @row_reduction_with_major_reduced_dim(%arg0: tensor<2x42x128x32x8xf32>) + -> tensor<2x128xf32> attributes { xla_gpu.launch_grid = #xla_gpu.launch_grid< block_counts = [42, 1, 1], thread_counts = [128, 1, 1] @@ -46,12 +46,48 @@ func.func @row_reduction_with_major_reduced_dim(%arg0: tensor<1x42x128x32x8xf32> } { %c0 = arith.constant 0.0 : f32 %0 = xla_gpu.reduce (%arg0) inits(%c0) dimensions=[1, 3, 4] combiner=@add - : tensor<1x42x128x32x8xf32> to tensor<1x128xf32> - return %0 : tensor<1x128xf32> + : tensor<2x42x128x32x8xf32> to tensor<2x128xf32> + return %0 : tensor<2x128xf32> } -// CHECK: %[[REINDEXED:.*]] = xla_gpu.reindex -// CHECK-SAME: : tensor<1x42x128x32x8xf32> -> tensor<1x42x128x2x4x32xf32> -// CHECK: xla_gpu.reduce(%[[REINDEXED]]) -// CHECK-SAME: dimensions=[1, 3] -// CHECK-SAME: : tensor<1x42x128x2x4x32xf32> +// CHECK-LABEL: @row_reduction_with_major_reduced_dim +// CHECK: %[[REINDEXED:.*]] = xla_gpu.reindex +// CHECK-SAME: : tensor<2x42x128x32x8xf32> -> tensor<2x42x128x2x4x32xf32> +// CHECK: xla_gpu.reduce(%[[REINDEXED]]) +// CHECK-SAME: dimensions=[1, 3] +// CHECK-SAME: : tensor<2x42x128x2x4x32xf32> + +// ----- + +func.func @add(%a: f32, %b: f32) -> f32 { + %0 = arith.addf %a, %b : f32 + return %0 : f32 +} + +func.func @column(%arg0: tensor<2x32x32xf32>) + -> tensor<2x32xf32> attributes { + xla_gpu.launch_grid = #xla_gpu.launch_grid< + block_counts = [42, 1, 1], + thread_counts = [128, 1, 1] + > + } { + %c0 = arith.constant 0.0 : f32 + %0 = xla_gpu.reduce (%arg0) inits(%c0) dimensions=[1] combiner=@add + : tensor<2x32x32xf32> to tensor<2x32xf32> + return %0 : tensor<2x32xf32> +} + +// CHECK: #[[$RESHAPE:.*]] = #xla_gpu.indexing_map<(d0, d1, d2, d3) -> (d0, d1 * 4 + d2, d3) +// CHECK-SAME: d1 * 4 + d2 in [0, 31] +// CHECK: #[[$TRANSPOSE:.*]] = #xla_gpu.indexing_map<(d0, d1, d2) -> (d0, d2, d1) +// CHECK-LABEL: @column +// CHECK-SAME: %[[IN:.*]]: tensor<2x32x32xf32> +// CHECK: %[[C0:.*]] = arith.constant 0.00 +// CHECK: %[[REINDEXED:.*]] = xla_gpu.reindex %[[IN]] at #[[$RESHAPE]] default %[[C0]] +// CHECK-SAME: -> tensor<2x8x4x32xf32> +// CHECK: %[[R1:.*]] = xla_gpu.reduce(%[[REINDEXED]]) inits(%[[C0]]) dimensions=[1] +// CHECK-SAME: to tensor<2x4x32xf32> +// CHECK: %[[TRANSPOSED:.*]] = xla_gpu.reindex %[[R1]] at #[[$TRANSPOSE]] +// CHECK-SAME: -> tensor<2x32x4xf32> +// CHECK: %[[R2:.*]] = xla_gpu.reduce(%[[TRANSPOSED]]) inits(%[[C0]]) dimensions=[2] +// CHECK: return %[[R2]] : tensor<2x32xf32> From 5de25e012376ecd81ec0cc23694057edc217420a Mon Sep 17 00:00:00 2001 From: Henning Becker Date: Fri, 20 Sep 2024 05:11:20 -0700 Subject: [PATCH 060/483] Fix race condition in dumping logic in xla::Executable This is adding a mutex to the lazy initialization that happens in a `const` member function. It also adds a test which hopefully ensures that this getter stays thread compatible. PiperOrigin-RevId: 676804217 --- third_party/xla/xla/service/BUILD | 19 ++++ third_party/xla/xla/service/executable.h | 14 ++- .../xla/xla/service/executable_test.cc | 86 +++++++++++++++++++ 3 files changed, 118 insertions(+), 1 deletion(-) create mode 100644 third_party/xla/xla/service/executable_test.cc diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD index 808f7e3e00440e..2f770516b8d9d1 100644 --- a/third_party/xla/xla/service/BUILD +++ b/third_party/xla/xla/service/BUILD @@ -1712,11 +1712,13 @@ cc_library( "//xla/stream_executor:device_description", "//xla/stream_executor:device_memory_allocator", "//xla/tsl/lib/strings:proto_serialization", + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/log:check", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", "@com_google_absl//absl/types:variant", "@local_tsl//tsl/platform:env", @@ -1727,6 +1729,23 @@ cc_library( ] + internal_hlo_deps(), ) +xla_cc_test( + name = "executable_test", + srcs = ["executable_test.cc"], + deps = [ + ":executable", + ":hlo_execution_profile", + "//xla/hlo/ir:hlo", + "//xla/tests:hlo_test_base", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test", + ], +) + cc_library( name = "compiler", srcs = ["compiler.cc"], diff --git a/third_party/xla/xla/service/executable.h b/third_party/xla/xla/service/executable.h index f1a08cd570b88b..6ff23a5ef2c8ee 100644 --- a/third_party/xla/xla/service/executable.h +++ b/third_party/xla/xla/service/executable.h @@ -18,11 +18,14 @@ limitations under the License. #include #include +#include #include #include +#include "absl/base/thread_annotations.h" #include "absl/log/check.h" #include "absl/status/statusor.h" +#include "absl/synchronization/mutex.h" #include "absl/types/span.h" #include "absl/types/variant.h" #include "xla/debug_options_flags.h" @@ -376,6 +379,12 @@ class Executable { // Dumping helpers. void set_hlo_proto(std::unique_ptr hlo_proto) { + // Despite the mutex lock, this function is NOT thread-safe. + // The mutex is needed for the lazy HLO module loading in `hlo_proto()`. + // Since both `hlo_proto()` and `buffer_assignment_proto()` return a + // pointer to hlo_proto_, having the mutex is not enough to make this + // function thread-safe. + absl::MutexLock lock(&hlo_proto_mutex_); hlo_proto_ = std::move(hlo_proto); } bool dumping_snapshot() const { @@ -385,6 +394,7 @@ class Executable { } HloProto const* hlo_proto() const { + absl::MutexLock lock(&hlo_proto_mutex_); if (hlo_proto_ != nullptr && !hlo_proto_->has_hlo_module()) { *hlo_proto_->mutable_hlo_module() = module().ToProto(); } @@ -392,6 +402,7 @@ class Executable { } const BufferAssignmentProto* buffer_assignment_proto() const { + absl::MutexLock lock(&hlo_proto_mutex_); return hlo_proto_ != nullptr && hlo_proto_->has_buffer_assignment() ? &hlo_proto_->buffer_assignment() : nullptr; @@ -441,7 +452,8 @@ class Executable { // hlo_proto_->buffer_assignment is set and hlo_proto_->hlo_module isn't, the // hlo_module proto will be computed on the fly when requested with // hlo_proto(). This avoids wasting CPU and memory if the proto isn't needed. - std::unique_ptr hlo_proto_; + std::unique_ptr hlo_proto_ ABSL_GUARDED_BY(hlo_proto_mutex_); + mutable absl::Mutex hlo_proto_mutex_; }; } // namespace xla diff --git a/third_party/xla/xla/service/executable_test.cc b/third_party/xla/xla/service/executable_test.cc new file mode 100644 index 00000000000000..388b7be1bd44a7 --- /dev/null +++ b/third_party/xla/xla/service/executable_test.cc @@ -0,0 +1,86 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/executable.h" + +#include +#include +#include +#include +#include + +#include +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/service/hlo_execution_profile.h" +#include "xla/service/service_executable_run_options.h" +#include "xla/tests/hlo_test_base.h" +#include "tsl/platform/env.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/test.h" +#include "tsl/platform/threadpool.h" + +namespace xla { +namespace { + +class TestExecutable : public Executable { + public: + explicit TestExecutable(std::shared_ptr module) + : Executable{std::move(module)} {} + + absl::StatusOr ExecuteAsyncOnStream( + const ServiceExecutableRunOptions* run_options, + std::vector arguments, + HloExecutionProfile* hlo_execution_profile) override { + return absl::UnimplementedError("Not needed for this test."); + } +}; + +class ExecutableTest : public HloTestBase {}; + +TEST_F(ExecutableTest, HloProtoGetterIsThreadCompatible) { + // Executable::hlo_proto() is doing some lazy initialization of a + // part of `hlo_proto_`. This test ensures that this is done in a + // thread-compatible way. + // Note that this test needs to run with --config=tsan to reliably + // detect any potential data races. + constexpr std::string_view kHloModule = R"( + HloModule module + + ENTRY main { + ROOT c = s32[] constant(1) + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr module, + ParseAndReturnVerifiedModule(kHloModule)); + + TestExecutable executable(module); + + auto proto = std::make_unique(); + executable.set_hlo_proto(std::move(proto)); + + { + tsl::thread::ThreadPool pool(tsl::Env::Default(), "test", + /*num_threads=*/2); + for (int i = 0; i < 2; ++i) { + pool.Schedule([&] { executable.hlo_proto()->SerializeAsString(); }); + } + } +} + +} // namespace +} // namespace xla From dc899b5e8ad468e13357dda6205a46c12c6aaca1 Mon Sep 17 00:00:00 2001 From: Hyeontaek Lim Date: Fri, 20 Sep 2024 05:50:15 -0700 Subject: [PATCH 061/483] [PjRt-IFRT] Migrate the include file for IFRT/XLA DType conversion functions This change updates the header include file from `pjrt_array.h` to `pjrt_dtype.h` for IFRT/XLA DType conversion functions. PiperOrigin-RevId: 676813378 --- third_party/xla/xla/python/pjrt_ifrt/pjrt_array.cc | 1 + third_party/xla/xla/python/pjrt_ifrt/pjrt_client.cc | 1 + third_party/xla/xla/python/pjrt_ifrt/pjrt_executable.cc | 1 + 3 files changed, 3 insertions(+) diff --git a/third_party/xla/xla/python/pjrt_ifrt/pjrt_array.cc b/third_party/xla/xla/python/pjrt_ifrt/pjrt_array.cc index 88a0d858d33b97..5c1abd46ba6e73 100644 --- a/third_party/xla/xla/python/pjrt_ifrt/pjrt_array.cc +++ b/third_party/xla/xla/python/pjrt_ifrt/pjrt_array.cc @@ -45,6 +45,7 @@ limitations under the License. #include "xla/python/ifrt/sharding.h" #include "xla/python/pjrt_ifrt/pjrt_client.h" #include "xla/python/pjrt_ifrt/pjrt_device.h" +#include "xla/python/pjrt_ifrt/pjrt_dtype.h" #include "xla/python/pjrt_ifrt/pjrt_memory.h" #include "xla/shape_util.h" #include "xla/status_macros.h" diff --git a/third_party/xla/xla/python/pjrt_ifrt/pjrt_client.cc b/third_party/xla/xla/python/pjrt_ifrt/pjrt_client.cc index b6aed371cca9dc..9400e282fa07fc 100644 --- a/third_party/xla/xla/python/pjrt_ifrt/pjrt_client.cc +++ b/third_party/xla/xla/python/pjrt_ifrt/pjrt_client.cc @@ -69,6 +69,7 @@ limitations under the License. #include "xla/python/pjrt_ifrt/pjrt_array.h" #include "xla/python/pjrt_ifrt/pjrt_attribute_map_util.h" #include "xla/python/pjrt_ifrt/pjrt_device.h" +#include "xla/python/pjrt_ifrt/pjrt_dtype.h" #include "xla/python/pjrt_ifrt/pjrt_memory.h" #include "xla/python/pjrt_ifrt/pjrt_remap.h" #include "xla/python/pjrt_ifrt/pjrt_topology.h" diff --git a/third_party/xla/xla/python/pjrt_ifrt/pjrt_executable.cc b/third_party/xla/xla/python/pjrt_ifrt/pjrt_executable.cc index c665cbb0cdd68e..19c441cdea448d 100644 --- a/third_party/xla/xla/python/pjrt_ifrt/pjrt_executable.cc +++ b/third_party/xla/xla/python/pjrt_ifrt/pjrt_executable.cc @@ -51,6 +51,7 @@ limitations under the License. #include "xla/python/pjrt_ifrt/pjrt_array.h" #include "xla/python/pjrt_ifrt/pjrt_client.h" #include "xla/python/pjrt_ifrt/pjrt_device.h" +#include "xla/python/pjrt_ifrt/pjrt_dtype.h" #include "xla/python/pjrt_ifrt/pjrt_host_callback.h" #include "xla/python/pjrt_ifrt/pjrt_memory.h" #include "xla/python/pjrt_ifrt/xla_compiler.h" From c33548124c4be6a57680f253a240061adaf916a3 Mon Sep 17 00:00:00 2001 From: Shraiysh Date: Fri, 20 Sep 2024 05:52:40 -0700 Subject: [PATCH 062/483] PR #16879: General offset computation for dynamic slice fusion Imported from GitHub PR https://github.com/openxla/xla/pull/16879 This patch adds logic for computing general offset while creating dynamic-slice-fusion. Copybara import of the project: -- 636988837ad87f0bb0b4906ff5653e6982a5cad0 by Shraiysh Vaishay : General offset computation for dynamic slice fusion This patch adds logic for computing general offset while creating dynamic-slice-fusion. Merging this change closes #16879 PiperOrigin-RevId: 676813991 --- .../gpu/fusions/dynamic_slice_fusion_test.cc | 96 +++ .../xla/xla/service/gpu/transforms/BUILD | 5 + .../dynamic_slice_fusion_rewriter.cc | 547 +++++++++++++++--- .../dynamic_slice_fusion_rewriter_test.cc | 306 +++++++--- 4 files changed, 790 insertions(+), 164 deletions(-) diff --git a/third_party/xla/xla/service/gpu/fusions/dynamic_slice_fusion_test.cc b/third_party/xla/xla/service/gpu/fusions/dynamic_slice_fusion_test.cc index 212e6b51e5445d..415ed7da7f2bf5 100644 --- a/third_party/xla/xla/service/gpu/fusions/dynamic_slice_fusion_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/dynamic_slice_fusion_test.cc @@ -27,6 +27,7 @@ limitations under the License. #include "xla/ffi/ffi.h" #include "xla/ffi/ffi_api.h" #include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/service/custom_call_target_registry.h" #include "xla/service/gpu/backend_configs.pb.h" @@ -3650,6 +3651,101 @@ TEST_F(DynamicSliceFusionTest, ReduceScatterDegenerateSlice) { false, true, error)); } +TEST_F(DynamicSliceFusionTest, TestWithRewriter) { + const char* hlo = R"( + HloModule test_module, replica_count=2 + + add { + a = s32[] parameter(0) + b = s32[] parameter(1) + ROOT add = s32[] add(a, b) + } + + Body { + param = (s32[], s32[16, 32], s32[8, 32]) parameter(0) + i = s32[] get-tuple-element(param), index=0 + dest = s32[16,32] get-tuple-element(param), index=1 + src = s32[8,32] get-tuple-element(param), index=2 + eight = s32[] constant(8) + zero = s32[] constant(0) + thirty_two = s32[] constant(32) + add = s32[] add(eight, i) + add.2 = s32[] subtract(add, thirty_two) + compare = pred[] compare(add, thirty_two), direction=LT + offset = s32[] select(compare, add, add.2) + rs = s32[4,32] reduce-scatter(src), channel_id=0, replica_groups={{0,1}}, use_global_device_ids=true, dimensions={0}, to_apply=add + fusion = s32[16,32] dynamic-update-slice(dest, rs, offset, zero) + one = s32[] constant(1) + i_plus_one = s32[] add(i, one) + ROOT tuple = tuple(i_plus_one, fusion, src) + } + + Cond { + param = (s32[], s32[16,32], s32[8,32]) parameter(0) + loop_iter = s32[] get-tuple-element(param), index=0 + c32 = s32[] constant(32) + ROOT compare = pred[] compare(loop_iter, c32), direction=LT + } + + ENTRY main { + zero = s32[] constant(0) + dest = s32[16,32] parameter(0) + src = s32[8,32] parameter(1) + tuple = tuple(zero, dest, src) + ROOT while = while(tuple), body=Body, condition=Cond + } + )"; + + HloModuleConfig config; + DebugOptions dboptions; + dboptions.set_xla_gpu_enable_dynamic_slice_fusion(false); + config.set_debug_options(dboptions); + TF_ASSERT_OK_AND_ASSIGN(auto module0, + ParseAndReturnVerifiedModule(hlo, config)); + + TF_ASSERT_OK_AND_ASSIGN(auto module_without_fusion, + GetOptimizedModule(std::move(module0))); + dboptions.set_xla_gpu_enable_dynamic_slice_fusion(true); + config.set_debug_options(dboptions); + TF_ASSERT_OK_AND_ASSIGN(auto module1, + ParseAndReturnVerifiedModule(hlo, config)); + TF_ASSERT_OK_AND_ASSIGN(auto module_with_fusion, + GetOptimizedModule(std::move(module1))); + + ASSERT_EQ(GetDynamicSliceFusions(*module_without_fusion).size(), 0); + auto fusions = GetDynamicSliceFusions(*module_with_fusion); + ASSERT_EQ(fusions.size(), 1); + HloPrintOptions options; + options.set_print_large_constants(true) + .set_print_result_shape(false) + .set_print_operand_shape(false); + TF_ASSERT_OK_AND_ASSIGN(auto filecheck_fusion, + RunFileCheck(fusions[0]->ToString(options), + R"( + // CHECK-DAG: %[[rs:.+]] = reduce-scatter({{.+}}) + // CHECK-DAG: %[[offset_vals:.+]] = constant({8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 0, 1, 2, 3, 4, 5, 6, 7}) + // CHECK-DAG: %[[offset_as_arr:.+]] = dynamic-slice(%[[offset_vals]], {{.+}}), dynamic_slice_sizes={1} + // CHECK-DAG: %[[offset:.+]] = reshape(%[[offset_as_arr]]) + // CHECK-DAG: ROOT %{{.+}} = dynamic-update-slice({{.+}}, %[[rs]], %[[offset]], {{.+}}) + )")); + EXPECT_TRUE(filecheck_fusion); + TF_ASSERT_OK_AND_ASSIGN( + auto filecheck_while_loop, + RunFileCheck(fusions[0]->FusionInstruction()->parent()->ToString(options), + R"( + // CHECK-DAG: %[[p:.+]] = parameter(0) + // CHECK-DAG: %[[loop_counter:.+]] = get-tuple-element(%[[p]]), index=3 + // CHECK-DAG: %[[address_computation:.+]] = fusion({{.+}}, %[[loop_counter]]), kind=kCustom + // CHECK-DAG: %[[updated_loop_counter:.+]] = add(%[[loop_counter]], {{.+}}) + // CHECK-DAG: ROOT {{.+}} = tuple({{.+}}, %[[address_computation]], {{.+}}, %[[updated_loop_counter]]) + )")); + EXPECT_TRUE(filecheck_while_loop); + ErrorSpec error_spec{/*aabs=*/1e-3, /*arel=*/1e-3}; + EXPECT_TRUE(RunAndCompareTwoModulesReplicated( + std::move(module_without_fusion), std::move(module_with_fusion), false, + true, error_spec)); +} + } // namespace } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/transforms/BUILD b/third_party/xla/xla/service/gpu/transforms/BUILD index 96cdb279211540..6ee8cf82fda2e6 100644 --- a/third_party/xla/xla/service/gpu/transforms/BUILD +++ b/third_party/xla/xla/service/gpu/transforms/BUILD @@ -1423,19 +1423,23 @@ cc_library( hdrs = ["dynamic_slice_fusion_rewriter.h"], tags = ["gpu"], deps = [ + "//xla:literal_util", "//xla:shape_util", "//xla:util", "//xla/ffi:ffi_api", + "//xla/hlo/evaluator:hlo_evaluator", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/service:custom_call_target_registry", "//xla/service:pattern_matcher", + "//xla/service:while_loop_analysis", "//xla/service/gpu:backend_configs_cc", "//xla/service/gpu:cublas_cudnn", "//xla/service/gpu:gpu_constants", "//xla/service/gpu:hlo_traversal", "//xla/service/gpu:ir_emission_utils", "//xla/service/gpu/kernels:custom_fusion_library", + "//xla/tools:hlo_extractor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", @@ -1473,6 +1477,7 @@ xla_cc_test( "//xla/service/gpu:gpu_device_info_for_tests", "//xla/stream_executor", "//xla/stream_executor/gpu:gpu_types_header", + "//xla/tests:filecheck", "//xla/tests:hlo_test_base", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/status", diff --git a/third_party/xla/xla/service/gpu/transforms/dynamic_slice_fusion_rewriter.cc b/third_party/xla/xla/service/gpu/transforms/dynamic_slice_fusion_rewriter.cc index 8ea3bc3801062e..a58bed9d75697b 100644 --- a/third_party/xla/xla/service/gpu/transforms/dynamic_slice_fusion_rewriter.cc +++ b/third_party/xla/xla/service/gpu/transforms/dynamic_slice_fusion_rewriter.cc @@ -33,11 +33,14 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/ffi/ffi_api.h" +#include "xla/hlo/evaluator/hlo_evaluator.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_schedule.h" +#include "xla/literal_util.h" +#include "xla/primitive_util.h" #include "xla/service/custom_call_target_registry.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/cublas_cudnn.h" @@ -45,8 +48,10 @@ limitations under the License. #include "xla/service/gpu/hlo_traversal.h" #include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/pattern_matcher.h" +#include "xla/service/while_loop_analysis.h" #include "xla/shape.h" #include "xla/shape_util.h" +#include "xla/tools/hlo_extractor.h" #include "xla/util.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" @@ -77,6 +82,9 @@ using DataflowPathsView = absl::Span; using InstructionSet = absl::flat_hash_set; +using OffsetValueMap = + absl::flat_hash_map>; + bool IsNoOp(const HloInstruction* hlo) { return HloPredicateIsOp(hlo); @@ -152,99 +160,424 @@ bool IsAlignedSlice(const HloInstruction* slice) { return true; } -// Pattern matches the following IR (generated by `jax.lax.scan`) to check if -// the offset is a loop iteration number: - -// clang-format off -// param = (s32[], s32[], s32[16]{0}, s32[16]{0}) parameter(0) -// // the index in `gte` has to be the loop iteration index -// gte = s32[] get-tuple-element(param), index=0 -// c0 = s32[] constant(0) compare = pred[] compare(gte, c0), direction=LT -// c_trip_count = s32[] constant(16) -// add = s32[] add(gte, c_trip_count) select = s32[] select(compare, add, gte) -// clang-format on - -bool IsLoopIterationNumber(const HloInstruction& offset) { - const HloComputation* parent = offset.parent(); - if (!parent->IsWhileBodyComputation()) return false; - - // Scan loops trip count must be known at compile time as it iterates over the - // leading dimension of the statically shaped input. - const HloInstruction* while_instr = parent->WhileCallInstruction(); - auto config = while_instr->backend_config(); - if (!config.ok() || !config->has_known_trip_count()) return false; - int32_t trip_count = config->known_trip_count().n(); - - // First lets check the offset computation pattern - if (!Match(&offset, m::Select(m::Lt(m::GetTupleElement(m::Parameter(0)), - m::ConstantScalar(0)), - m::Add(m::GetTupleElement(m::Parameter(0)), - m::ConstantScalar(trip_count)), - m::GetTupleElement(m::Parameter())))) { - return false; +// Function looks for while backend config. If this config is present, it +// returns the value of trip count, otherwise it runs the while loop analysis to +// compute trip count. `whileop` must be a while operaton. Returns +// `std::nullopt` if it cannot figure out the trip count. +std::optional GetWhileLoopTripCount(HloInstruction* whileop) { + CHECK(whileop->opcode() == HloOpcode::kWhile); + auto backend_config = whileop->backend_config(); + if (!backend_config.ok() || !backend_config.value().has_known_trip_count()) { + VLOG(4) << "Backend config not ok. Computing while loop trip count for " + << whileop->name(); + return ComputeWhileLoopTripCount(whileop); } + int trip_count = backend_config.value().known_trip_count().n(); + VLOG(4) << "Found trip count in backend config for " << whileop->name() + << ": " << trip_count; + return trip_count; +} - // Next, we check that the parameter used in offset computation is the loop - // induction variable - int64_t param_idx = offset.operand(2)->tuple_index(); - const HloInstruction* root = offset.parent()->root_instruction(); - if (root->opcode() != HloOpcode::kTuple) { - return false; +// Given an HLO operation `idx`, which is wrapped by while operation, this +// function tries to find the values of the variable in all the iterations as an +// array of literals. This is done by repeatedly executing the loop update +// operation(s) and the operation(s) to calculate the value of `idx` at each +// iteration. If this is successful, then the vector of literals is returned. If +// for some reason this is not successful then `std::nullopt` is returned. +std::optional> GetValues(const HloInstruction* idx) { + VLOG(3) << "Getting values for " << idx->name(); + const HloComputation* computation = idx->parent(); + if (!computation->IsWhileBodyComputation()) { + VLOG(3) << "While calculating offset values for " << idx->name() + << ", the parent computation(" << computation->name() + << ") is not a while computation"; + return std::nullopt; } - // Check the update operation - const HloInstruction* updated_var = - offset.parent()->root_instruction()->operand(param_idx); - if (!Match(updated_var, m::Add(m::GetTupleElement(m::Parameter(0), param_idx), - m::ConstantScalar(1)))) { - return false; + HloInstruction* whileop = computation->WhileCallInstruction(); + std::optional trip_count = GetWhileLoopTripCount(whileop); + if (trip_count == std::nullopt) { + VLOG(3) << "Unable to get trip count for " << whileop->name(); + return std::nullopt; } - // Check that the condition considers this. - const HloInstruction* condition_root = - while_instr->while_condition()->root_instruction(); - if (!Match(condition_root, - m::Lt(m::GetTupleElement(m::Parameter(0), param_idx), - m::ConstantScalar(trip_count)))) { - return false; + auto root_tuple = computation->root_instruction(); + if (root_tuple->opcode() != HloOpcode::kTuple) { + VLOG(3) << "Root operation " << root_tuple->name() << " of computation " + << computation->name() + << " expected to be a tuple because it is a while body. Found: " + << root_tuple->opcode(); + return std::nullopt; } - // Check init - const HloInstruction* init_loop_iter = - while_instr->operand(0)->operand(param_idx); - if (!Match(init_loop_iter, m::ConstantScalar(0))) { - return false; + std::optional loop_indvar_tuple_idx = + GetLoopInductionVarTupleIdx(whileop); + if (loop_indvar_tuple_idx == std::nullopt) { + VLOG(3) << "Unable to find tuple index for loop induction variable"; + return std::nullopt; + } + auto update_operation = + computation->root_instruction()->operand(*loop_indvar_tuple_idx); + HloInstruction* loop_indvar = nullptr; + for (auto instr : computation->instructions()) { + if (instr->opcode() == HloOpcode::kGetTupleElement && + instr->operand(0) == computation->parameter_instruction(0) && + instr->tuple_index() == *loop_indvar_tuple_idx) { + loop_indvar = instr; + } + } + if (loop_indvar == nullptr) { + VLOG(3) << "Unable to find get-tuple-element(" + << computation->parameter_instruction(0)->name() + << "), index=" << *loop_indvar_tuple_idx << " in " + << computation->name(); + return std::nullopt; } - return true; + // Extract the offset and update modules and verify that they only take the + // loop iteration counter as parameter. + // The operation we are extracting (update and offset) are from `computation`. + // In the `extract_selector`, we stop at the parameter (tuple) for this + // `computation` or at the loop induction variable and convert that to a + // parameter. If the operation depends on the tuple parameter, then the + // argument to the extracted module will have the shape of a tuple. So, if the + // extracted module has only one parameter and the shape of that parameter is + // same as the loop induction variable, then the operation only depends on the + // loop induction variable. We also have to ensure there are no `partition-id` + // or `replica-id` operations in the extracted module. + auto IsValidModule = + [loop_indvar](std::unique_ptr& module) -> bool { + if (module == nullptr || module->entry_computation()->num_parameters() != 1) + return false; + const HloInstruction* p0 = + module->entry_computation()->parameter_instruction(0); + if (p0->shape() != loop_indvar->shape()) { + VLOG(4) << "Extracted module must depend only on the loop induction " + "variable."; + return false; + }; + return llvm::all_of(module->entry_computation()->instructions(), + [](const HloInstruction* instr) { + return instr->opcode() != HloOpcode::kPartitionId && + instr->opcode() != HloOpcode::kReplicaId; + }); + }; + auto params = computation->parameter_instructions(); + if (params.size() != 1 || !params[0]->shape().IsTuple()) { + VLOG(3) << "While loop parameter is expected to be a tuple."; + return std::nullopt; + } + std::unique_ptr offset_module = ExtractModule( + /*instruction=*/ + idx, /*height=*/-1, + /*extract_selector=*/ + [loop_indvar, params](const HloInstruction* inst) -> bool { + return inst != loop_indvar && llvm::find(params, inst) == params.end(); + }, + /*replace_type_selector=*/ + [](const HloInstruction* inst) -> ReplaceType { + return ReplaceType::kReplaceParam; + }); + std::unique_ptr update_module = ExtractModule( + /*instruction=*/ + update_operation, /*height=*/-1, + /*extract_selector=*/ + [loop_indvar, params](const HloInstruction* inst) -> bool { + return inst != loop_indvar && llvm::find(params, inst) == params.end(); + }, + /*replace_type_selector=*/ + [](const HloInstruction* inst) -> ReplaceType { + return ReplaceType::kReplaceParam; + }); + if (!IsValidModule(offset_module) || !IsValidModule(update_module)) { + return std::nullopt; + } + VLOG(3) << "Successfully generated offset and update modules"; + + std::vector offset_values; + absl::Status status = [&]() -> absl::Status { + HloEvaluator evaluator; + const Literal& init = + whileop->operand(0)->operand(*loop_indvar_tuple_idx)->literal(); + std::unique_ptr updated_value = nullptr; + for (int64_t i = 0; i < *trip_count; i++) { + if (i == 0) { + evaluator.ResetVisitStates(); + TF_ASSIGN_OR_RETURN(offset_values.emplace_back(), + evaluator.Evaluate(*offset_module, {&init})); + CHECK(offset_values.back().shape() == idx->shape()); + evaluator.ResetVisitStates(); + TF_ASSIGN_OR_RETURN(Literal next_update_value, + evaluator.Evaluate(*update_module, {&init})); + updated_value = next_update_value.CloneToUnique(); + } else { + evaluator.ResetVisitStates(); + TF_ASSIGN_OR_RETURN( + offset_values.emplace_back(), + evaluator.Evaluate(*offset_module, {updated_value.get()})); + CHECK(offset_values.back().shape() == idx->shape()); + evaluator.ResetVisitStates(); + TF_ASSIGN_OR_RETURN( + Literal next_update_value, + evaluator.Evaluate(*update_module, {updated_value.get()})); + updated_value = next_update_value.CloneToUnique(); + } + } + VLOG(3) << "Offset values for " << idx->name() << ": " + << absl::StrJoin(offset_values, ",", + [](std::string* out, const Literal& l) { + out->append(l.ToString()); + }); + return absl::OkStatus(); + }(); + if (status.ok()) return offset_values; + return std::nullopt; } -// This returns true for the constants that are handled in the dynamic slice -// fusion runtime. These constants do not force a D2H copy and hence preserve -// the cuda graph. -bool IsHandledConstantForDynamicSliceFusion(const HloInstruction& offset) { - if (auto* cst = DynCast(&offset)) { - switch (cst->shape().element_type()) { - case PrimitiveType::S32: - case PrimitiveType::S64: - case PrimitiveType::U32: - case PrimitiveType::U64: - return true; - default: +// This function takes a while operation and adds a loop iteration counter +// variable as the last parameter in the loop. This is useful, especially +// because the loop induction variable might not be 0,1,2,3... and we need a +// variable of this form to access the array literal for offset. +absl::StatusOr AddLoopIterationParam(HloInstruction* whileop) { + CHECK(whileop->opcode() == HloOpcode::kWhile); + HloComputation* while_body = whileop->while_body(); + HloComputation* while_cond = whileop->while_condition(); + const HloInstruction* while_init = whileop->operand(0); + + // First handle the initial values. + CHECK(while_init->opcode() == HloOpcode::kTuple); + std::vector new_init_operands(while_init->operands().begin(), + while_init->operands().end()); + PrimitiveType indvar_type = + whileop->while_init() + ->operand(*GetLoopInductionVarTupleIdx(whileop)) + ->shape() + .element_type(); + new_init_operands.push_back(whileop->parent()->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0( + whileop->while_init() + ->operand(*GetLoopInductionVarTupleIdx(whileop)) + ->shape() + .element_type(), + 0)), + "zero")); + HloInstruction* new_while_init = whileop->parent()->AddInstruction( + HloInstruction::CreateTuple(new_init_operands)); + HloInstruction* new_whileop = whileop->parent()->AddInstruction( + whileop->CloneWithNewOperands(new_while_init->shape(), {new_while_init})); + if (whileop->IsRoot()) { + absl::InlinedVector tuple_entries; + tuple_entries.reserve(while_init->shape().tuple_shapes_size()); + for (auto i = 0; i < while_init->shape().tuple_shapes_size(); i++) { + tuple_entries.push_back(whileop->parent()->AddInstruction( + HloInstruction::CreateGetTupleElement(new_whileop, i))); + } + HloInstruction* new_whileop_result = whileop->parent()->AddInstruction( + HloInstruction::CreateTuple(tuple_entries)); + TF_RETURN_IF_ERROR( + whileop->parent()->ReplaceInstruction(whileop, new_whileop_result)); + } else { + TF_RETURN_IF_ERROR(whileop->parent()->ReplaceInstructionWithDifferentShape( + whileop, new_whileop)); + } + + // Next, lets handle the condition + while_cond->ReplaceParameter(0, HloInstruction::CreateParameter( + 0, new_while_init->shape(), "new_param")); + + // Next, lets handle the body + HloInstruction* new_body_param = while_body->ReplaceParameter( + 0, + HloInstruction::CreateParameter(0, new_while_init->shape(), "new_param")); + + // Next, update the value of the param inside while op + HloInstruction* gte = while_body->AddInstruction( + HloInstruction::CreateGetTupleElement( + new_body_param, new_while_init->shape().tuple_shapes_size() - 1), + "loop_iteration_count"); + HloInstruction* c1 = while_body->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(indvar_type, 1)), + "one"); + HloInstruction* add = while_body->AddInstruction( + HloInstruction::CreateBinary(gte->shape(), HloOpcode::kAdd, gte, c1), + "updated_loop_iteration_count"); + absl::InlinedVector old_return_tuple_operands = + while_body->root_instruction()->operands(); + std::vector new_return_tuple_operands( + old_return_tuple_operands.begin(), old_return_tuple_operands.end()); + new_return_tuple_operands.push_back(add); + HloInstruction* new_return_tuple = while_body->AddInstruction( + HloInstruction::CreateTuple(new_return_tuple_operands)); + while_body->set_root_instruction(new_return_tuple, true); + return gte; +} + +// This function takes an array literal and gives a constant instruction with +// that literal. +std::unique_ptr GetAsConstantInstruction( + const std::vector& offset_values) { + if (offset_values.empty()) return nullptr; + std::unique_ptr value = + primitive_util::PrimitiveTypeSwitch>( + [&offset_values]( + auto primitive_type_constant) -> std::unique_ptr { + if constexpr (primitive_util::IsIntegralType( + primitive_type_constant)) { + using NativeT = typename primitive_util::PrimitiveTypeToNative< + primitive_type_constant>::type; + + Array constantLiterals({(int64_t)offset_values.size()}); + std::vector valuesAsTy; + valuesAsTy.reserve(offset_values.size()); + for (auto& i : offset_values) { + valuesAsTy.push_back( + static_cast(i.data()[0])); + } + constantLiterals.SetValues(valuesAsTy); + return HloInstruction::CreateConstant( + LiteralUtil::CreateFromArray(constantLiterals)); + } + return nullptr; + }, + offset_values[0].shape().element_type()); + return value; +} + +// This function takes an operation, and a reference to a map of +// {operation: array literals containing their values}. If the operation is a +// dynamic slicing operation, we populate the value map with the values of the +// offsets. This only returns true if it can successfully find values +// corresponding to all the offsets in the `matched_instrs`. If there is a +// single offset for which we cannot find the values, then we do not add +// anything to the value map, and return false. +bool PopulateOffsetValueMap(const HloInstruction* matched_instr, + OffsetValueMap& value_map) { + OffsetValueMap local_value_map; + if (auto dyn_idx_op = DynCast(matched_instr); + dyn_idx_op) { + for (auto indexop : dyn_idx_op->index_operands()) { + if (indexop->IsConstant()) continue; + if (local_value_map.contains(indexop) || value_map.contains(indexop)) + continue; + std::optional> values = GetValues(indexop); + if (values == std::nullopt) return false; + if (values->empty() || !primitive_util::IsIntegralType( + values->at(0).shape().element_type())) { return false; - }; + } + std::transform(values->begin(), values->end(), + std::back_inserter(local_value_map[indexop]), + [](Literal& l) { return std::move(l); }); + } + } + for (auto& [op, values] : local_value_map) { + std::transform(values.begin(), values.end(), + std::back_inserter(value_map[op]), + [](Literal& l) { return std::move(l); }); } - return false; + VLOG(2) << "Received " << local_value_map.size() << " new offsets."; + return true; } -// This checks whether a dynamic index operation has all offsets that are either -// constant or loop iteration offsets. -bool HasConstantOrLoopIterationOffsets( - const HloDynamicIndexInstruction& instr) { - return llvm::all_of(instr.index_operands(), [](const HloInstruction* offset) { - return IsLoopIterationNumber(*offset) || - IsHandledConstantForDynamicSliceFusion(*offset); - }); +// This function takes a list of fusion instructions, and a value map +// {operation: array literal containing its values across iterations}. These +// fusions take the value of offset as a input. So, the value of this offset is +// calculated outside the fusion. This function changes these fusions so that +// the fusion instead only takes the loop iteration number and the offset is +// read from a constant array. This constant array comes from the value map. On +// a high level, the transform looks like: +// +// clang-format off +// +// input-fusion(p0, p1, p2, offset, c0) { +// ds = dynamic-slice(p0, offset, c0, c0) +// gemm = custom-call(ds, p1) +// ROOT dus = dynamic-update-slice(p2, gemm, offset, c0, c0) +// } +// +// changes to +// +// output-fusion(p0, p1, p2, loop_counter, c0) { +// offset_values = constant({2,4,6,8,10}) +// offset_array = dynamic-slice(offset_values, loop_counter), slice_size={1} +// offset = reshape(offset_array) +// ds = dynamic-slice(p0, offset, c0, c0) +// gemm = custom-call(ds, p1) +// ROOT dus = dynamic-update-slice(p2, gemm, offset, c0, c0) +// } +// +// clang-format on +absl::Status ReplaceOffsetCalculationWithArrayAccess( + PtrVec fusions, OffsetValueMap& value_map) { + absl::flat_hash_map loop_iteration_param; + for (auto& [instr, _] : value_map) { + VLOG(2) << "Handling " << instr->name(); + if (!instr->parent()->IsWhileBodyComputation()) { + VLOG(2) << "It is not a while body computation"; + return absl::InternalError( + absl::StrFormat("%s is expected to be a while computation.", + instr->parent()->name())); + } + if (loop_iteration_param.find(instr->parent()) != + loop_iteration_param.end()) { + VLOG(2) << "This was already handled"; + continue; + } + VLOG(2) << "Adding loop iteration param for " << instr->parent()->name(); + TF_ASSIGN_OR_RETURN( + loop_iteration_param[instr->parent()], + AddLoopIterationParam(instr->parent()->WhileCallInstruction())); + } + for (auto fusion_instr : fusions) { + // Check that this fusion operation has something we need to replace: + for (auto maybe_offset : fusion_instr->operands()) { + if (value_map.find(maybe_offset) == value_map.end()) continue; + HloInstruction* loop_counter = + loop_iteration_param[fusion_instr->parent()]; + HloComputation* fusion = fusion_instr->fused_instructions_computation(); + loop_iteration_param[fusion] = + fusion_instr->AddFusionOperand(loop_counter); + break; + } + } + for (auto fusion_instr : fusions) { + absl::flat_hash_map param_replacement_map; + absl::InlinedVector parameters; + HloComputation* fusion_comp = + fusion_instr->fused_instructions_computation(); + for (auto [idx, maybe_offset] : llvm::enumerate(fusion_instr->operands())) { + HloInstruction* offset_param = + fusion_instr->fused_instructions_computation()->parameter_instruction( + idx); + if (value_map.find(maybe_offset) == value_map.end() || + param_replacement_map.contains(offset_param)) + continue; + std::vector& values = value_map.at(maybe_offset); + std::unique_ptr values_as_const_instruction = + GetAsConstantInstruction(values); + if (values_as_const_instruction == nullptr) { + return absl::InternalError( + "Unable to convert offsets into constant array."); + } + HloInstruction* array = fusion_comp->AddInstruction( + std::move(values_as_const_instruction), "offset_values"); + HloInstruction* ds = + fusion_comp->AddInstruction(HloInstruction::CreateDynamicSlice( + ShapeUtil::MakeShape(offset_param->shape().element_type(), {1}), + array, {loop_iteration_param[fusion_comp]}, {1})); + HloInstruction* offset = fusion_comp->AddInstruction( + HloInstruction::CreateReshape(offset_param->shape(), ds), "offset"); + param_replacement_map[offset_param] = offset; + parameters.push_back(offset_param); + } + for (auto param = parameters.rbegin(); param != parameters.rend(); + param++) { + auto offset = param_replacement_map[*param]; + TF_RETURN_IF_ERROR(fusion_comp->ReplaceInstruction(*param, offset)); + } + } + return absl::OkStatus(); } -UseDefDataflowPaths GetSlicedOperandPaths(const HloInstruction* instr) { +UseDefDataflowPaths GetSlicedOperandPaths(const HloInstruction* instr, + OffsetValueMap& value_map) { UseDefDataflowPaths sliced_operand_paths; // This set is used to avoid duplicates in the matched results. It contains @@ -292,14 +625,9 @@ UseDefDataflowPaths GetSlicedOperandPaths(const HloInstruction* instr) { }); if (maybe_slice_instr == std::nullopt) continue; - auto dynamic_index_operation = - DynCast(maybe_slice_instr.value()); - bool valid_slice_found = - slice_found && - ((dynamic_index_operation && - HasConstantOrLoopIterationOffsets(*dynamic_index_operation)) || - (*maybe_slice_instr)->opcode() == HloOpcode::kSlice); - if (valid_slice_found || + bool valid_slice_status = + PopulateOffsetValueMap(*maybe_slice_instr, value_map); + if ((valid_slice_status && slice_found) || processed_instrs.contains(maybe_slice_instr.value())) { // Even in the case of stopping at a match that has been processed, we // still need to add instructions encountered in the sliced operand path @@ -320,7 +648,8 @@ UseDefDataflowPaths GetSlicedOperandPaths(const HloInstruction* instr) { // vector. // Each entry contains the sliced paths for that user, i.e. the sequence of ops // following the dataflow from the user itself to the DUS (included). -DefUseDataflowPaths GetSlicedUserPaths(const HloInstruction* instr) { +DefUseDataflowPaths GetSlicedUserPaths(const HloInstruction* instr, + OffsetValueMap& value_map) { DefUseDataflowPaths sliced_user_paths; // This set is used to avoid duplicates in the matched results. It contains // the matched instructions that we have seen so far. @@ -347,12 +676,10 @@ DefUseDataflowPaths GetSlicedUserPaths(const HloInstruction* instr) { }, /*visit_operands=*/false); if (maybe_dus_instr == std::nullopt) return; - auto dynamic_index_operation = - DynCast(maybe_dus_instr.value()); - bool valid_dus_found = - dus_found && dynamic_index_operation && - HasConstantOrLoopIterationOffsets(*dynamic_index_operation); - if (valid_dus_found || processed_instrs.contains(maybe_dus_instr.value())) { + bool valid_slice_status = + PopulateOffsetValueMap(*maybe_dus_instr, value_map); + if ((valid_slice_status && dus_found) || + processed_instrs.contains(maybe_dus_instr.value())) { // Even in the case of stopping at a match that has been processed, we // still need to add instructions encountered in the sliced user path // during the latest traversal. @@ -519,6 +846,8 @@ absl::StatusOr DynamicSliceFusionRewriter::Run( matches_kv; std::vector matches; + OffsetValueMap value_map; + // Collect all potential custom call matches in the non-fusion computations. for (HloComputation* computation : module->computations()) { if (computation->IsFusionComputation()) continue; @@ -526,9 +855,30 @@ absl::StatusOr DynamicSliceFusionRewriter::Run( if ((instr->opcode() == HloOpcode::kReduceScatter && instr->shape().IsArray()) || IsLegacyCublasMatmul(*instr) || IsCustomCall(instr, platform_name_)) { - UseDefDataflowPaths sliced_operand_paths = GetSlicedOperandPaths(instr); + UseDefDataflowPaths sliced_operand_paths = + GetSlicedOperandPaths(instr, value_map); + VLOG(1) << "For operation: " << instr->name() << ", operands: " + << absl::StrJoin( + sliced_operand_paths, ",", + [](std::string* out, const HloInstruction* inst) { + out->append(inst->name()); + }); bool has_sliced_operand_paths = sliced_operand_paths.size() > 1; - DefUseDataflowPaths sliced_user_paths = GetSlicedUserPaths(instr); + DefUseDataflowPaths sliced_user_paths = + GetSlicedUserPaths(instr, value_map); + VLOG(1) << "For operation: " << instr->name() << ", users: " + << absl::StrJoin( + sliced_user_paths, ",", + [](std::string* out, const DefUseDataflowPath& path) { + out->append( + "{" + + absl::StrJoin(path, ",", + [](std::string* out, + const HloInstruction* inst) { + out->append(inst->name()); + }) + + "}"); + }); bool has_sliced_user_paths = absl::c_any_of( sliced_user_paths, [&](auto& sliced_user_path) { return !sliced_user_path.empty(); }); @@ -552,6 +902,8 @@ absl::StatusOr DynamicSliceFusionRewriter::Run( if (matches.empty()) return false; + PtrVec fusions; + for (HloInstruction* hero : matches) { auto& paths = matches_kv[hero]; auto& [sliced_operand_paths, sliced_user_paths] = paths; @@ -580,7 +932,7 @@ absl::StatusOr DynamicSliceFusionRewriter::Run( HloInstruction * fusion, CreateFusionInstruction(module, hero, captures, fusion_body, has_dynamic_slices)); - + fusions.push_back(fusion); HloComputation* parent = hero->parent(); if (fusion->shape().IsTuple()) { TF_RETURN_IF_ERROR(parent->ReplaceInstructionWithDifferentShape( @@ -624,6 +976,9 @@ absl::StatusOr DynamicSliceFusionRewriter::Run( } } + TF_RETURN_IF_ERROR( + ReplaceOffsetCalculationWithArrayAccess(fusions, value_map)); + return true; } diff --git a/third_party/xla/xla/service/gpu/transforms/dynamic_slice_fusion_rewriter_test.cc b/third_party/xla/xla/service/gpu/transforms/dynamic_slice_fusion_rewriter_test.cc index 9a71c9930adc78..9fb49afb1847b2 100644 --- a/third_party/xla/xla/service/gpu/transforms/dynamic_slice_fusion_rewriter_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/dynamic_slice_fusion_rewriter_test.cc @@ -33,6 +33,7 @@ limitations under the License. #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/gpu/gpu_types.h" #include "xla/stream_executor/stream.h" +#include "xla/tests/filecheck.h" #include "xla/tests/hlo_test_base.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" @@ -1857,12 +1858,15 @@ TEST_F(DynamicSliceFusionRewriterTest, ReduceScatterDUSLoopIterationOffset) { })"; const char* expected = R"( // CHECK: %dynamic-slice-fusion{{.*}}{ - // CHECK: {{.+}} = {{.*}}reduce-scatter({{.+}}) - // CHECK: {{.+}} = {{.*}}dynamic-update-slice({{.+}}) + // CHECK: {{.+}} = {{.*}} reduce-scatter({{.+}}) + // CHECK: {{.+}} = s32[128]{0} constant({{.+}}) + // CHECK: {{.+}} = {{.+}} dynamic-slice({{.+}}) + // CHECK: {{.+}} = {{.+}} reshape({{.+}}) + // CHECK: {{.+}} = {{.*}} dynamic-update-slice({{.+}}) // CHECK: } // CHECK: Body{{.+}}{ - // CHECK-NOT: {{.+}} = {{.*}}reduce-scatter({{.+}}) - // CHECK: {{.+}} = {{.+}}fusion({{.+}}), kind=kCustom, calls=%dynamic-slice-fusion{{.*}}"name":"dynamic_address_computation" + // CHECK-NOT: {{.+}} = {{.*}} reduce-scatter({{.+}}) + // CHECK: {{.+}} = {{.+}} fusion({{.+}}), kind=kCustom, calls=%dynamic-slice-fusion{{.*}}"name":"dynamic_address_computation" // CHECK: } )"; RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), expected); @@ -1881,26 +1885,10 @@ TEST_F(DynamicSliceFusionRewriterTest, DUSSimpleGemmLoopIteration) { bitcast.41 = f16[8,8]{1,0} bitcast(p0) bitcast.42 = f16[8,8]{1,0} bitcast(p1) - custom-call.1 = f16[8,8]{1,0} custom-call(bitcast.41, bitcast.42), custom_call_target="__cublas$gemm", backend_config={"gemm_backend_config":{ - "alpha_real":1, - "beta":0, - "dot_dimension_numbers":{ - "lhs_contracting_dimensions":["1"], - "rhs_contracting_dimensions":["0"], - "lhs_batch_dimensions":[], - "rhs_batch_dimensions":[] - }, - "alpha_imag":0, - "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]}, - "epilogue":"DEFAULT", - "lhs_stride":"64", - "rhs_stride":"64", - "grad_x":false, - "grad_y":false - }} + custom-call.1 = f16[8,8]{1,0} custom-call(bitcast.41, bitcast.42), custom_call_target="__cublas$gemm" bitcast.43 = f16[1,8,8]{2,1,0} bitcast(custom-call.1) c0 = u32[] constant(0) - c_trip_count = u32[] constant(11) + c_trip_count = u32[] constant(8) compare = pred[] compare(loop_iter, c0), direction=LT add = u32[] add(loop_iter, c_trip_count) offset = u32[] select(compare, add, loop_iter) @@ -1913,7 +1901,7 @@ TEST_F(DynamicSliceFusionRewriterTest, DUSSimpleGemmLoopIteration) { %Cond { %param.1 = (f16[1,8,8]{2,1,0}, f16[1,8,8]{2,1,0}, f16[4,8,8]{2,1,0}, u32[]) parameter(0) %i.1 = u32[] get-tuple-element(%param.1), index=3 - %trip_count = u32[] constant(11) + %trip_count = u32[] constant(8) ROOT %done = pred[] compare(u32[] %i.1, u32[] %trip_count), direction=LT } @@ -1923,19 +1911,32 @@ TEST_F(DynamicSliceFusionRewriterTest, DUSSimpleGemmLoopIteration) { %p2.1 = f16[4,8,8]{2,1,0} parameter(2) %c0.1 = u32[] constant(0) %initial_tuple = tuple(%p0.1, %p1.1, %p2.1, u32[] %c0.1) - ROOT %while = while(%initial_tuple), condition=%Cond, body=%Body, backend_config={"known_trip_count":{"n":"11"}} + ROOT %while = while(%initial_tuple), condition=%Cond, body=%Body, backend_config={"known_trip_count":{"n":"8"}} })"; const char* expected = R"( + // CHECK: %dynamic-slice-fusion{{.*}} { + // CHECK-DAG: %[[p0:.+]] = f16[8,8]{1,0} parameter(0) + // CHECK-DAG: %[[p1:.+]] = f16[8,8]{1,0} parameter(1) + // CHECK-DAG: %[[p2:.+]] = f16[4,8,8]{2,1,0} parameter(2) + // CHECK-DAG: %[[gemm:.+]] = f16[8,8]{1,0} custom-call(%[[p0]], %[[p1]]), custom_call_target="__cublas$gemm" + // CHECK-DAG: %[[bc_gemm:.+]] = f16[1,8,8]{2,1,0} bitcast(%[[gemm]]) + // CHECK-DAG: %[[offset_values:.+]] = u32[8]{0} constant({0, 1, 2, 3, 4, 5, 6, 7}) + // CHECK-DAG: %[[p4:.+]] = u32[] parameter(4) + // CHECK-DAG: %[[offset_as_array:.+]] = u32[1]{0} dynamic-slice(%[[offset_values]], %[[p4]]), dynamic_slice_sizes={1} + // CHECK-DAG: %[[offset:.+]] = u32[] reshape(%[[offset_as_array]]) + // CHECK-DAG: %[[p3:.+]] = u32[] parameter(3) + // CHECK-DAG: ROOT %{{.+}} = f16[4,8,8]{2,1,0} dynamic-update-slice(%[[p2]], %[[bc_gemm]], %[[offset]], %[[p3]], %[[p3]]) + // CHECK: } // CHECK: %Body{{.+}}{ // CHECK: %[[PARAM:.+]] = {{.+}} parameter(0) - // CHECK: %[[LOOP_ITER:.+]] = u32[] get-tuple-element(%[[PARAM]]), index=3 - // CHECK: %[[OFFSET:.+]] = u32[] select({{.+}}) - // CHECK: %[[ADDRESS_COMPUTATION:.+]] = {{.+}} fusion({{.+}}, {{.+}}, {{.+}}, %[[OFFSET]], %{{.+}}), kind=kCustom, calls=%dynamic-slice-fusion, {{.+}}"name":"dynamic_address_computation" - // CHECK: ROOT %tuple = {{.+}} tuple(%{{.+}}, %{{.+}}, %[[ADDRESS_COMPUTATION]], %{{.+}}) + // CHECK: %[[LOOP_ITER:.+]] = u32[] get-tuple-element(%[[PARAM]]), index=4 + // CHECK: %[[ADDRESS_COMPUTATION:.+]] = {{.+}} fusion(%{{.+}}, %loop_iteration_count), kind=kCustom, calls=%dynamic-slice-fusion + // CHECK: ROOT %{{.+}} = {{.+}} tuple(%{{.+}}, %{{.+}}, %[[ADDRESS_COMPUTATION]], %{{.+}}) // CHECK: } // CHECK: ENTRY %test{{.+}}{ - // CHECK: ROOT %{{.+}} = {{.+}} while(%{{.+}}), condition=%{{.+}}, body=%Body{{.*}}, backend_config={"known_trip_count":{"n":"11"}} + // CHECK: %{{.+}} = {{.+}} while(%{{.+}}), condition=%{{.+}}, body=%Body{{.*}}, backend_config={"known_trip_count":{"n":"8"}} + // CHECK: ROOT {{.+}} = {{.+}} tuple({{.+}}) } )"; @@ -1984,6 +1985,86 @@ TEST_F(DynamicSliceFusionRewriterTest, DUSSimpleGemmParameterOffset) { std::nullopt); } +TEST_F(DynamicSliceFusionRewriterTest, DUSOffsetAsFunctionOfLoopIteration) { + const char* hlo = R"( + HloModule test_module, replica_count=2 + + add { + a = s64[] parameter(0) + b = s64[] parameter(1) + ROOT add = s64[] add(a, b) + } + + Body { + param = (s64[], s64[16, 32], s64[8, 32]) parameter(0) + i = s64[] get-tuple-element(param), index=0 + dest = s64[16,32] get-tuple-element(param), index=1 + src = s64[8,32] get-tuple-element(param), index=2 + eight = s64[] constant(8) + zero = s64[] constant(0) + thirty_two = s64[] constant(32) + add = s64[] add(eight, i) + add.2 = s64[] subtract(add, thirty_two) + compare = pred[] compare(add, thirty_two), direction=LT + offset = s64[] select(compare, add, add.2) + rs = s64[4,32] reduce-scatter(src), channel_id=1, replica_groups={{0,1}}, use_global_device_ids=true, dimensions={0}, to_apply=add + fusion = s64[16,32] dynamic-update-slice(dest, rs, offset, zero) + one = s64[] constant(1) + i_plus_one = s64[] add(i, one) + ROOT tuple = tuple(i_plus_one, fusion, src) + } + + Cond { + param = (s64[], s64[16,32], s64[8,32]) parameter(0) + loop_iter = s64[] get-tuple-element(param), index=0 + c16 = s64[] constant(16) + ROOT compare = pred[] compare(loop_iter, c16), direction=LT + } + + ENTRY main { + zero = s64[] constant(0) + dest = s64[16,32] parameter(0) + src = s64[8,32] parameter(1) + tuple = tuple(zero, dest, src) + ROOT while = while(tuple), body=Body, condition=Cond + } + )"; + + const char* expected = R"( + // CHECK: %dynamic-slice-fusion{{.*}} { + // CHECK-DAG: %[[p1:.*]] = s64[16,32]{1,0} parameter(1) + // CHECK-DAG: %[[p0:.*]] = s64[8,32]{1,0} parameter(0) + // CHECK-DAG: %[[rs:.+]] = s64[4,32]{1,0} reduce-scatter(s64[8,32]{1,0} %[[p0]]), channel_id=1 + // CHECK-DAG: %[[offset_values:.+]] = s64[16]{0} constant({8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}) + // CHECK-DAG: %[[p3:.+]] = s64[] parameter(3) + // CHECK-DAG: %[[ds:.+]] = s64[1]{0} dynamic-slice(s64[16]{0} %[[offset_values]], s64[] %[[p3]]), dynamic_slice_sizes={1} + // CHECK-DAG: %[[offset:.+]] = s64[] reshape(s64[1]{0} %[[ds]]) + // CHECK-DAG: %[[p2:.+]] = s64[] parameter(2) + // CHECK-DAG: ROOT %{{.+}} = s64[16,32]{1,0} dynamic-update-slice(s64[16,32]{1,0} %[[p1:.*]], s64[4,32]{1,0} %[[rs]], s64[] %[[offset]], s64[] %[[p2]]) + // CHECK: } + )"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, GetOptimizedModule(hlo)); + TF_ASSERT_OK_AND_ASSIGN( + auto changed, + RunHloPass(DynamicSliceFusionRewriter("gpu"), module.get())); + EXPECT_TRUE(changed); + std::vector fusions; + for (auto computation : module->computations()) { + if (computation->IsFusionComputation()) { + fusions.push_back(computation); + } + } + ASSERT_EQ(fusions.size(), 1); + const HloComputation* dynamic_slice_fusion = fusions[0]; + TF_ASSERT_OK_AND_ASSIGN( + auto filecheck_match, + RunFileCheck(dynamic_slice_fusion->ToString( + HloPrintOptions{}.set_print_large_constants(true)), + expected)); + EXPECT_TRUE(filecheck_match); +} + TEST_F(DynamicSliceFusionRewriterTest, DUSSimpleGemmLaxScan) { const char* hlo = R"( HloModule lax_scan @@ -1995,68 +2076,71 @@ TEST_F(DynamicSliceFusionRewriterTest, DUSSimpleGemmLaxScan) { // ans = jax.lax.scan(lambda carry, x : (init, x@carry), init, inp) Body { - arg_tuple.15 = (s32[], f32[128,128]{1,0}, f32[128,128,128]{2,1,0}, f32[128,128,128]{2,1,0}, f32[128,128]{1,0}) parameter(0) + arg_tuple.15 = (s32[], f32[8,8]{1,0}, f32[8,8,8]{2,1,0}, f32[8,8,8]{2,1,0}, f32[8,8]{1,0}) parameter(0) get-tuple-element.16 = s32[] get-tuple-element(arg_tuple.15), index=0 constant.21 = s32[] constant(1) add.2 = s32[] add(get-tuple-element.16, constant.21) - get-tuple-element.30 = f32[128,128]{1,0} get-tuple-element(arg_tuple.15), index=4 - get-tuple-element.18 = f32[128,128,128]{2,1,0} get-tuple-element(arg_tuple.15), index=2 - get-tuple-element.19 = f32[128,128,128]{2,1,0} get-tuple-element(arg_tuple.15), index=3 + get-tuple-element.30 = get-tuple-element(arg_tuple.15), index=4 + get-tuple-element.18 = get-tuple-element(arg_tuple.15), index=2 + get-tuple-element.19 = get-tuple-element(arg_tuple.15), index=3 constant.23 = s32[] constant(0) compare.2 = pred[] compare(get-tuple-element.16, constant.23), direction=LT - constant.22 = s32[] constant(128) + constant.22 = s32[] constant(8) add.3 = s32[] add(get-tuple-element.16, constant.22) select.1 = s32[] select(compare.2, add.3, get-tuple-element.16) - dynamic-slice.1 = f32[1,128,128]{2,1,0} dynamic-slice(get-tuple-element.19, select.1, constant.23, constant.23), dynamic_slice_sizes={1,128,128} - bitcast.72 = f32[128,128]{1,0} bitcast(dynamic-slice.1) - get-tuple-element.17 = f32[128,128]{1,0} get-tuple-element(arg_tuple.15), index=1 - custom-call.1 = (f32[128,128]{1,0}, s8[131072]{0}) custom-call(bitcast.72, get-tuple-element.17), custom_call_target="__cublas$gemm" - get-tuple-element = f32[128,128]{1,0} get-tuple-element(custom-call.1), index=0 - bitcast.77 = f32[1,128,128]{2,1,0} bitcast(get-tuple-element) - dynamic-update-slice.1 = f32[128,128,128]{2,1,0} dynamic-update-slice(get-tuple-element.18, bitcast.77, select.1, constant.23, constant.23) + dynamic-slice.1 = f32[1,8,8]{2,1,0} dynamic-slice(get-tuple-element.19, select.1, constant.23, constant.23), dynamic_slice_sizes={1,8,8} + bitcast.72 = f32[8,8]{1,0} bitcast(dynamic-slice.1) + get-tuple-element.17 = f32[8,8]{1,0} get-tuple-element(arg_tuple.15), index=1 + custom-call.1 = (f32[8,8]{1,0}, s8[131072]{0}) custom-call(bitcast.72, get-tuple-element.17), custom_call_target="__cublas$gemm" + get-tuple-element = f32[8,8]{1,0} get-tuple-element(custom-call.1), index=0 + bitcast.77 = f32[1,8,8]{2,1,0} bitcast(get-tuple-element) + dynamic-update-slice.1 = f32[8,8,8]{2,1,0} dynamic-update-slice(get-tuple-element.18, bitcast.77, select.1, constant.23, constant.23) ROOT tuple.38 = tuple(add.2, get-tuple-element.30, dynamic-update-slice.1, get-tuple-element.19, get-tuple-element.30) } // Body Cond { - arg_tuple.40 = (s32[], f32[128,128]{1,0}, f32[128,128,128]{2,1,0}, f32[128,128,128]{2,1,0}, f32[128,128]{1,0}) parameter(0) + arg_tuple.40 = (s32[], f32[8,8]{1,0}, f32[8,8,8]{2,1,0}, f32[8,8,8]{2,1,0}, f32[8,8]{1,0}) parameter(0) get-tuple-element.41 = s32[] get-tuple-element(arg_tuple.40), index=0 - constant.46 = s32[] constant(128) + constant.46 = s32[] constant(8) ROOT compare.3 = pred[] compare(get-tuple-element.41, constant.46), direction=LT } ENTRY main { constant.4 = s32[] constant(0) - Arg_1.2 = f32[128,128]{1,0} parameter(1) + Arg_1.2 = f32[8,8]{1,0} parameter(1) constant.5 = f32[] constant(0) - broadcast.1 = f32[128,128,128]{2,1,0} broadcast(constant.5), dimensions={} - Arg_2.3 = f32[128,128,128]{2,1,0} parameter(2) - Arg_0.1 = f32[128,128]{1,0} parameter(0) + broadcast.1 = f32[8,8,8]{2,1,0} broadcast(constant.5), dimensions={} + Arg_2.3 = f32[8,8,8]{2,1,0} parameter(2) + Arg_0.1 = f32[8,8]{1,0} parameter(0) tuple.7 = tuple(constant.4, Arg_1.2, broadcast.1, Arg_2.3, Arg_0.1) - while.48 = while(tuple.7), condition=Cond, body=Body, backend_config={"known_trip_count":{"n":"128"}} - get-tuple-element.50 = f32[128,128]{1,0} get-tuple-element(while.48), index=1 - get-tuple-element.51 = f32[128,128,128]{2,1,0} get-tuple-element(while.48), index=2 - ROOT tuple.54 = (f32[128,128]{1,0}, f32[128,128,128]{2,1,0}) tuple(get-tuple-element.50, get-tuple-element.51) + while.48 = while(tuple.7), condition=Cond, body=Body, backend_config={"known_trip_count":{"n":"8"}} + get-tuple-element.50 = get-tuple-element(while.48), index=1 + get-tuple-element.51 = get-tuple-element(while.48), index=2 + ROOT tuple.54 = tuple(get-tuple-element.50, get-tuple-element.51) } // main.55 )"; - auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo(); const char* expected = R"( - // CHECK: %dynamic-slice-fusion{{.*}} {{.+}} { - // CHECK: {{.+}} = {{.+}}dynamic-slice - // CHECK: {{.+}} = {{.+}}custom-call - // CHECK: {{.+}} = {{.+}}dynamic-update-slice - // CHECK: } - // CHECK: %Body{{.+}}{ - // CHECK: %[[PARAM:.+]] = {{.+}} parameter(0) - // CHECK: %[[LOOP_ITER:.+]] = s32[] get-tuple-element(%[[PARAM]]), index=0 - // CHECK: %[[OFFSET:.+]] = s32[] select({{.+}}) - // CHECK: %[[ADDRESS_COMPUTATION:.+]] = {{.+}} fusion({{.+}}, %[[OFFSET]], %{{.+}}), kind=kCustom, calls=%dynamic-slice-fusion{{.+}}"name":"dynamic_address_computation" - // CHECK: %[[GTE:.+]] = {{.+}} get-tuple-element(%[[ADDRESS_COMPUTATION]]), index=0 - // CHECK: ROOT %{{.+}} = {{.+}} tuple(%{{.+}}, %[[GTE]], %{{.+}}) - // CHECK: } - // CHECK: ENTRY %main{{.+}}{ - // CHECK: %{{.+}} = {{.+}} while(%{{.+}}), condition=%{{.+}}, body=%Body{{.*}}, backend_config={"known_trip_count":{"n":"128"}} - // CHECK: } + // CHECK: %dynamic-slice-fusion{{.*}} {{.+}} { + // CHECK-DAG: %[[ITER:.+]] = s32[] parameter(4) + // CHECK-DAG: %[[OFFSET_VALUES:.+]] = s32[8]{0} constant({0, 1, 2, 3, 4, 5, 6, 7}) + // CHECK-DAG: %[[OFFSET_ARR:.+]] = s32[1]{0} dynamic-slice(%[[OFFSET_VALUES]], %[[ITER]]), dynamic_slice_sizes={1} + // CHECK-DAG: %[[OFFSET:.+]] = s32[] reshape(%[[OFFSET_ARR]]) + // CHECK-DAG: %[[DS:.+]] = f32[1,8,8]{2,1,0} dynamic-slice({{.+}}, %[[OFFSET]], {{.+}}), dynamic_slice_sizes={1,8,8} + // CHECK-DAG: %[[BITCAST:.+]] = {{.+}} bitcast(%[[DS]]) + // CHECK-DAG: %[[GEMM:.+]] = {{.+}} custom-call(%[[BITCAST]], {{.+}}), custom_call_target="__cublas$gemm" + // CHECK-DAG: %[[DUS:.+]] = {{.+}} dynamic-update-slice({{.+}}, %[[OFFSET]], {{.+}}) + // CHECK: } + // CHECK: %Body{{.+}}{ + // CHECK-DAG: %[[PARAM:.+]] = {{.+}} parameter(0) + // CHECK-DAG: %[[LOOP_ITER:.+]] = s32[] get-tuple-element(%[[PARAM]]), index=5 + // CHECK-DAG: %[[ADDRESS_COMPUTATION:.+]] = {{.+}} fusion({{.+}}, %{{.+}}, %[[LOOP_ITER]]), kind=kCustom, calls=%dynamic-slice-fusion{{.+}}"name":"dynamic_address_computation" + // CHECK-DAG: %[[GTE:.+]] = {{.+}} get-tuple-element(%[[ADDRESS_COMPUTATION]]), index=0 + // CHECK-DAG: ROOT %{{.+}} = {{.+}} tuple(%{{.+}}, %[[GTE]], %{{.+}}) + // CHECK: } + // CHECK: ENTRY %main{{.+}}{ + // CHECK: %{{.+}} = {{.+}} while(%{{.+}}), condition=%{{.+}}, body=%Body{{.*}}, backend_config={"known_trip_count":{"n":"8"}} + // CHECK: } )"; RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), expected); } @@ -2144,4 +2228,90 @@ TEST_F(DynamicSliceFusionRewriterTest, ReduceScatterDynamicSlice) { RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), expected); } +// This is not a requirement from the DynamicSliceFusionRewriter, but this tests +// the current behavior so that the removal of this is intentional. +TEST_F(DynamicSliceFusionRewriterTest, ReplicaIdAndPartitionIdAsOffset) { + const char* hlo = R"( + HloModule test_module, replica_count=2, num_partitions=2 + ENTRY main { + p0 = s32[32,32] parameter(0) + p1 = s32[32,32] parameter(1) + p2 = s32[64,32] parameter(2) + c10 = u32[] constant(10) + c0 = u32[] constant(0) + + // This should get fused. + call1 = s32[32,32] custom-call(p0, p1), custom_call_target="__cublas$gemm" + dus1 = s32[64,32] dynamic-update-slice(p2, call1, c10, c0) + + // This should not get fused. + replica = u32[] replica-id() + call2 = s32[32,32] custom-call(p0, p1), custom_call_target="__cublas$gemm" + dus2 = s32[64,32] dynamic-update-slice(p2, call2, replica, c0) + + // This should not get fused. + partition = u32[] partition-id() + call3 = s32[32,32] custom-call(p0, p1), custom_call_target="__cublas$gemm" + dus3 = s32[64,32] dynamic-update-slice(p2, call3, partition, c0) + ROOT tuple = tuple(dus1, dus2, dus3) + } + )"; + + const char* expected = R"( + // CHECK: dynamic-slice-fusion{{.*}} { + // CHECK: custom-call + // CHECK: dynamic-update-slice + // CHECK: } + // CHECK: ENTRY {{.+}} { + // CHECK-DAG: %{{.+}} = {{.+}} fusion({{.+}}) + // CHECK-DAG: %[[call2:.+]] = {{.+}} custom-call({{.+}}) + // CHECK-DAG: %[[replica:.+]] = u32[] replica-id() + // CHECK-DAG: %{{.+}} = {{.+}} dynamic-update-slice({{.+}} %[[call2]], %[[replica]], {{.+}}) + // CHECK-DAG: %[[partition:.+]] = u32[] partition-id() + // CHECK-DAG: %[[call3:.+]] = {{.+}} custom-call({{.+}}) + // CHECK-DAG: %{{.+}} = {{.+}} dynamic-update-slice({{.+}} %[[call3]], %[[partition]], {{.+}}) + // CHECK-DAG: ROOT {{.+}} = {{.+}} tuple({{.+}}) + // CHECK: } + )"; + + RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), expected); +} + +TEST_F(DynamicSliceFusionRewriterTest, ParameterOffsetThroughWhileLoop) { + const char* hlo = R"( + HloModule test + Body { + p = (s32[], s32[32,32], s32[32,32], s32[64,32], s32[]) parameter(0) + i = get-tuple-element(p), index=0 + p0 = get-tuple-element(p), index=1 + p1 = get-tuple-element(p), index=2 + p2 = s32[64,32] get-tuple-element(p), index=3 + offset = s32[] get-tuple-element(p), index=4 + c0 = s32[] constant(0) + call = s32[32,32] custom-call(p0, p1), custom_call_target="__cublas$gemm" + dus = s32[64,32] dynamic-update-slice(p2, call, offset, c0) + c1 = s32[] constant(1) + i_plus_one = add(i, c1) + ROOT tuple = tuple(i_plus_one, p1, p0, dus, offset) + } + Cond { + p = (s32[], s32[32,32], s32[32,32], s32[64,32], s32[]) parameter(0) + i = get-tuple-element(p), index=0 + c4 = s32[] constant(4) + ROOT compare = compare(i, c4), direction=LT + } + ENTRY main { + offset = s32[] parameter(0) + p0 = s32[32,32] parameter(1) + p1 = s32[32,32] parameter(2) + p2 = s32[64,32] parameter(3) + c0 = s32[] constant(0) + tuple = tuple(c0, p0, p1, p2, offset) + ROOT while = while(tuple), body=Body, condition=Cond + } + )"; + RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), + std::nullopt); +} + } // namespace xla::gpu From e901de87c68a9c9374d90117ee6954870d662657 Mon Sep 17 00:00:00 2001 From: Ilya Tikhonovskiy Date: Fri, 20 Sep 2024 05:53:21 -0700 Subject: [PATCH 063/483] [XLA:GPU] "NOOP" refactoring of the FusionDecision API Let's introduce FusionDecision::Allow and FusionDecision::Forbid static methods and use them everywhere. As a result of that we will have a bit more code but it will be more readable and controllable. PiperOrigin-RevId: 676814132 --- .../xla/service/cpu/cpu_instruction_fusion.cc | 33 ++++--- .../gpu/fusions/triton/triton_support.cc | 35 ++++---- .../fusions/triton/triton_support_legacy.cc | 47 +++++----- .../xla/xla/service/gpu/gpu_fusible.cc | 80 +++++++++-------- .../model/gpu_indexing_performance_model.cc | 2 +- .../gpu/model/symbolic_tile_analysis.cc | 32 ++++--- .../service/gpu/transforms/fusion_merger.cc | 18 ++-- .../xla/service/gpu/transforms/gemm_fusion.cc | 16 ++-- .../gpu/transforms/instruction_fusion.cc | 20 +++-- .../gpu/transforms/multi_output_fusion.cc | 36 ++++---- .../service/gpu/transforms/priority_fusion.cc | 48 +++++----- .../gpu/transforms/softmax_rewriter_triton.cc | 54 ++++++----- .../service/gpu/triton_tiling_propagation.cc | 89 ++++++++++--------- .../xla/xla/service/instruction_fusion.cc | 64 +++++++------ .../xla/xla/service/instruction_fusion.h | 38 ++++---- .../xla/service/instruction_fusion_test.cc | 16 ++-- 16 files changed, 343 insertions(+), 285 deletions(-) diff --git a/third_party/xla/xla/service/cpu/cpu_instruction_fusion.cc b/third_party/xla/xla/service/cpu/cpu_instruction_fusion.cc index 1743cad5c2c5ec..957bcfb22d851c 100644 --- a/third_party/xla/xla/service/cpu/cpu_instruction_fusion.cc +++ b/third_party/xla/xla/service/cpu/cpu_instruction_fusion.cc @@ -79,22 +79,23 @@ FusionDecision CpuInstructionFusion::ShouldFuse(HloInstruction* consumer, if (CanBeOutputFused(producer, consumer)) { VLOG(2) << "Fusion OK: Can create output fusion."; - return {}; + return FusionDecision::Allow(); } if (CanBeOutputFusedIntoSomeOperand(producer)) { - return "Bailing because producer can be output-fused into some operand."; + return FusionDecision::Forbid( + "Bailing because producer can be output-fused into some operand."); } if (!CanBeLoopFused(*producer)) { - return "Producer is not loop-fusible."; + return FusionDecision::Forbid("Producer is not loop-fusible."); } // Cost condition: not fuse (simple, expensive producers) and (consumers who // reuse operand elements). if (producer->opcode() != HloOpcode::kFusion && is_expensive(*producer) && ReusesOperandElements(consumer, operand_index)) { - return "Fusion is not profitable."; + return FusionDecision::Forbid("Fusion is not profitable."); } RETURN_IF_NOT_FUSIBLE(InstructionFusion::ShouldFuse(consumer, operand_index)); @@ -103,12 +104,14 @@ FusionDecision CpuInstructionFusion::ShouldFuse(HloInstruction* consumer, // just a constant and another node. if (producer->opcode() == HloOpcode::kConstant && consumer->opcode() != HloOpcode::kFusion) { - return "Not fusing: insufficient non-constant nodes."; + return FusionDecision::Forbid( + "Not fusing: insufficient non-constant nodes."); } // Output fusion is not currently supported on CPUs. if (producer->opcode() == HloOpcode::kFusion) { - return "Not fusing: producer is itself a fusion node."; + return FusionDecision::Forbid( + "Not fusing: producer is itself a fusion node."); } // Don't fuse if fusing would cause too much code duplication because of @@ -126,7 +129,7 @@ FusionDecision CpuInstructionFusion::ShouldFuse(HloInstruction* consumer, } if (fusion_node_evaluations_.at(consumer).CodeDuplicationTooHigh( producer)) { - return "Code duplication too high"; + return FusionDecision::Forbid("Code duplication too high"); } } @@ -149,13 +152,13 @@ FusionDecision CpuInstructionFusion::ShouldFuse(HloInstruction* consumer, ShapeUtil::ByteSizeOfElements(consumer->operand(0)->shape()) < kFusionThresholdBytes) { VLOG(2) << "Fusing small matrix-vector product."; - return {}; + return FusionDecision::Allow(); } else if (consumer->operand(1)->shape().rank() == 1 && operand_index == 0 && ShapeUtil::ByteSizeOfElements(consumer->operand(1)->shape()) < kFusionThresholdBytes) { VLOG(2) << "Fusing small matrix-vector product."; - return {}; + return FusionDecision::Allow(); } } } @@ -166,26 +169,28 @@ FusionDecision CpuInstructionFusion::ShouldFuse(HloInstruction* consumer, !absl::c_linear_search( consumer->dimensions(), LayoutUtil::Minor(consumer->operand(0)->shape().layout(), 0))) { - return "Not fusing reductions over major dimensions"; + return FusionDecision::Forbid( + "Not fusing reductions over major dimensions"); } if (producer->opcode() == HloOpcode::kReduce && !absl::c_linear_search( producer->dimensions(), LayoutUtil::Minor(producer->operand(0)->shape().layout(), 0))) { - return "Not fusing reductions over major dimensions"; + return FusionDecision::Forbid( + "Not fusing reductions over major dimensions"); } if (consumer->IsLoopFusion()) { VLOG(2) << "Fusing: consumer is a fusion node."; - return {}; + return FusionDecision::Allow(); } if (CanBeLoopFused(*consumer)) { VLOG(2) << "Fusing: consumer is elementwise or fusible."; - return {}; + return FusionDecision::Allow(); } - return "Not fusing: not found a fusible case"; + return FusionDecision::Forbid("Not fusing: not found a fusible case"); } HloInstruction::FusionKind CpuInstructionFusion::ChooseKind( diff --git a/third_party/xla/xla/service/gpu/fusions/triton/triton_support.cc b/third_party/xla/xla/service/gpu/fusions/triton/triton_support.cc index 69eeb7461bba6f..0a4d5c0e9c18e5 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/triton_support.cc +++ b/third_party/xla/xla/service/gpu/fusions/triton/triton_support.cc @@ -117,7 +117,7 @@ CodegenDecision IsTritonSupportedConversion( }; auto error_message = [&]() { - return CodegenDecision( + return CodegenDecision::Forbid( absl::StrCat("Unsupported conversion in Triton: ", primitive_util::LowercasePrimitiveTypeName(input), " to ", primitive_util::LowercasePrimitiveTypeName(output))); @@ -139,7 +139,7 @@ CodegenDecision IsTritonSupportedConversion( if (IsTritonSupportedDataType(input, gpu_version) && (IsTritonSupportedDataType(output, gpu_version) || output == PrimitiveType::S4)) { - return CodegenDecision{}; + return CodegenDecision::Allow(); } return error_message(); @@ -224,7 +224,8 @@ CodegenDecision CanTritonHandleReduce( const se::GpuComputeCapability& gpu_version) { if (reduce.shape().element_type() == PrimitiveType::F8E4M3FN || reduce.shape().element_type() == PrimitiveType::F8E5M2) { - return "F8E4M3FN and F8E5M2 are not supported for reductions."; + return CodegenDecision::Forbid( + "F8E4M3FN and F8E5M2 are not supported for reductions."); } bool is_triton_supported_reduction_computation = absl::c_all_of( @@ -232,19 +233,21 @@ CodegenDecision CanTritonHandleReduce( return IsTritonSupportedInstructionImpl(*instr, gpu_version).CanFuse(); }); if (!is_triton_supported_reduction_computation) { - return "Unsupported reduction computation by Triton."; + return CodegenDecision::Forbid( + "Unsupported reduction computation by Triton."); } if (reduce.dimensions().size() == 1 && reduce.operand_count() == 2) { - return CodegenDecision{}; + return CodegenDecision::Allow(); } - return "Reduction is not a row-reduction of a single operand."; + return CodegenDecision::Forbid( + "Reduction is not a row-reduction of a single operand."); } CodegenDecision IsTritonSupportedInstructionImpl( const HloInstruction& instr, const se::GpuComputeCapability& gpu_version) { if (internal::IsTritonUnsupportedOpcode(instr.opcode())) { - return "Unsupported opcode."; + return CodegenDecision::Forbid("Unsupported opcode."); } // Special handling for the kConvert instruction, which has a non-standard @@ -259,7 +262,7 @@ CodegenDecision IsTritonSupportedInstructionImpl( bool output_type_is_supported = IsTritonSupportedDataType(type, gpu_version); if (!output_type_is_supported) { - return "Unsupported output data type."; + return CodegenDecision::Forbid("Unsupported output data type."); } bool input_types_are_supported = @@ -269,16 +272,16 @@ CodegenDecision IsTritonSupportedInstructionImpl( }); if (!input_types_are_supported) { - return "Unsupported input data type."; + return CodegenDecision::Forbid("Unsupported input data type."); } // Const is technically an elementwise op, so this check must be before the // elementwise check. if (instr.opcode() == HloOpcode::kConstant) { return ShapeUtil::IsScalar(instr.shape()) - ? CodegenDecision{} - : CodegenDecision{ - "Only scalar constants are supported in Triton."}; + ? CodegenDecision::Allow() + : CodegenDecision::Forbid( + "Only scalar constants are supported in Triton."); } if (instr.IsElementwise()) { @@ -289,9 +292,9 @@ CodegenDecision IsTritonSupportedInstructionImpl( // operand. instr.operand(instr.operand_count() - 1)->shape().element_type(), gpu_version)) { - return "Unsupported elementwise operation."; + return CodegenDecision::Forbid("Unsupported elementwise operation."); } - return CodegenDecision{}; + return CodegenDecision::Allow(); } // TODO(bchetioui): support kDot, kPad, and kDynamicSlice. @@ -306,12 +309,12 @@ CodegenDecision IsTritonSupportedInstructionImpl( case HloOpcode::kBroadcast: case HloOpcode::kBitcast: case HloOpcode::kReshape: - return CodegenDecision{}; + return CodegenDecision::Allow(); default: VLOG(2) << "Unsupported instruction: " << instr.ToString(); break; } - return "Unsupported opcode."; + return CodegenDecision::Forbid("Unsupported opcode."); } } // namespace diff --git a/third_party/xla/xla/service/gpu/fusions/triton/triton_support_legacy.cc b/third_party/xla/xla/service/gpu/fusions/triton/triton_support_legacy.cc index b07630b7cb7734..8280accb99e10f 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/triton_support_legacy.cc +++ b/third_party/xla/xla/service/gpu/fusions/triton/triton_support_legacy.cc @@ -13,8 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/fusions/triton/triton_support.h" - #include #include #include @@ -28,6 +26,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_opcode.h" #include "xla/layout.h" #include "xla/service/gpu/backend_configs.pb.h" +#include "xla/service/gpu/fusions/triton/triton_support.h" #include "xla/service/gpu/variant_visitor.h" #include "xla/stream_executor/device_description.h" #include "xla/xla_data.pb.h" @@ -115,7 +114,7 @@ bool IsTritonSupportedDataType(PrimitiveType type, CodegenDecision IsInstructionSupportsDataTypes( const HloInstruction& instr, const se::GpuComputeCapability& gpu_version) { if (!IsTritonSupportedDataType(instr.shape().element_type(), gpu_version)) { - return "Unsupported output data type."; + return CodegenDecision::Forbid("Unsupported output data type."); } for (const HloInstruction* operand : instr.operands()) { @@ -133,11 +132,11 @@ CodegenDecision IsInstructionSupportsDataTypes( [[fallthrough]]; default: if (!IsTritonSupportedDataType(operand_type, gpu_version)) { - return "Unsupported input data type."; + return CodegenDecision::Forbid("Unsupported input data type."); } } } - return CodegenDecision{}; + return CodegenDecision::Allow(); } std::vector TritonSupportedUnaryElementwiseUpToFloatNormalization( @@ -211,12 +210,12 @@ CodegenDecision CanTritonHandleElementwise( return decision; } if (instr.opcode() == HloOpcode::kConstant) { - return CodegenDecision{}; + return CodegenDecision::Allow(); } else if (!IsTritonSupportedElementwiseUpToFloatNormalization( instr.opcode(), instr.operand(0)->shape().element_type())) { - return "Unsupported elementwise operation."; + return CodegenDecision::Forbid("Unsupported elementwise operation."); } - return CodegenDecision{}; + return CodegenDecision::Allow(); } bool IsDotAlgorithmSupportedByTriton( @@ -268,37 +267,39 @@ CodegenDecision CanTritonHandleGEMM( if (!tsl::tensor_float_32_execution_enabled() || absl::c_any_of(dot.precision_config().operand_precision(), [](int x) { return x != PrecisionConfig::DEFAULT; })) { - return "Having non-default operand precisions or TensorFloat-32 disabled " - "for Dot op with unset algorithm."; + return CodegenDecision::Forbid( + "Having non-default operand precisions or TensorFloat-32 disabled " + "for Dot op with unset algorithm."); } } else { if (!IsDotAlgorithmSupportedByTriton(dot.precision_config().algorithm(), gpu_version)) { - return "Unsupported algorithm on the current device(s)."; + return CodegenDecision::Forbid( + "Unsupported algorithm on the current device(s)."); } } // TODO(b/266862493): Support more output types. if (!IsTritonSupportedDotOutputType(dot.shape().element_type(), gpu_version)) { - return "Unsupported output data type for Dot op."; + return CodegenDecision::Forbid("Unsupported output data type for Dot op."); } if (!IsTritonSupportedDataType(dot.operand(0)->shape().element_type(), gpu_version) || !IsTritonSupportedDataType(dot.operand(1)->shape().element_type(), gpu_version)) { - return "Unsupported input data type for Dot op."; + return CodegenDecision::Forbid("Unsupported input data type for Dot op."); } const DotDimensionNumbers& dim_numbers = dot.dot_dimension_numbers(); // TODO(b/269580541): support multiple batch dimensions. if (dim_numbers.lhs_batch_dimensions().size() > 1) { - return "Multiple batch dimensions."; + return CodegenDecision::Forbid("Multiple batch dimensions."); } - return CodegenDecision{}; + return CodegenDecision::Allow(); } bool NoNonContractingDimension(const HloDotInstruction& dot) { @@ -323,7 +324,7 @@ CodegenDecision IsTritonSupportedDynamicSlice( case S32: break; // supported default: - return CodegenDecision( + return CodegenDecision::Forbid( "Dynamic slice is only supported with S8, S16, or S32 indices."); } } @@ -341,14 +342,14 @@ CodegenDecision IsTritonSupportedDynamicSlice( if (i == majormost_dim_id) { continue; } else if (input->shape().dimensions(i) != instr.slice_sizes(i)) { - return CodegenDecision( + return CodegenDecision::Forbid( "Unsupported dynamic slice on non-major-most dimension."); } } // TODO(b/343143854): Check the subtleties of which dynamic slices are // supported, for example that a fragmented dimension cannot be sliced. - return CodegenDecision{}; + return CodegenDecision::Allow(); } CodegenDecision IsTritonSupportedInstruction( @@ -362,15 +363,15 @@ CodegenDecision IsTritonSupportedInstruction( auto* dot = Cast(&instr); // Cases where lhs or rhs have no non-contracting dims are not handled. if (NoNonContractingDimension(*dot)) { - return "No non-contracting dimensions."; + return CodegenDecision::Forbid("No non-contracting dimensions."); } return CanTritonHandleGEMM(*dot, gpu_version); } case HloOpcode::kTuple: { if (instr.IsRoot()) { - return CodegenDecision{}; + return CodegenDecision::Allow(); } - return "Only supports root tuples."; + return CodegenDecision::Forbid("Only supports root tuples."); } case HloOpcode::kDynamicSlice: { return IsTritonSupportedDynamicSlice( @@ -384,11 +385,11 @@ CodegenDecision IsTritonSupportedInstruction( case HloOpcode::kConcatenate: case HloOpcode::kParameter: case HloOpcode::kBroadcast: - return CodegenDecision{}; + return CodegenDecision::Allow(); default: break; } - return "Unsupported opcode."; + return CodegenDecision::Forbid("Unsupported opcode."); } } // namespace legacy_triton diff --git a/third_party/xla/xla/service/gpu/gpu_fusible.cc b/third_party/xla/xla/service/gpu/gpu_fusible.cc index bae2e880a9f199..94e67e43c1adb6 100644 --- a/third_party/xla/xla/service/gpu/gpu_fusible.cc +++ b/third_party/xla/xla/service/gpu/gpu_fusible.cc @@ -266,15 +266,15 @@ FusionDecision FusionHeroesAreCompatible(const HloInstruction* hero1, if (hero1_is_unnested_reduce && hero2_is_unnested_reduce && !AreReductionsMultiOutputFusionCompatible(hero2, hero1)) { - return "tiled reductions with different shapes"; + return FusionDecision::Forbid("tiled reductions with different shapes"); } else if (hero1_is_unnested_transpose && hero2_is_unnested_transpose && // After normalization to rank 3, the transposes should have the // same shape and permute the same dimensions. !tiled_transpose_hero1->IsEquivalent(*tiled_transpose_hero2)) { - return "tiled transposes with different shapes"; + return FusionDecision::Forbid("tiled transposes with different shapes"); } else if ((hero1_is_unnested_transpose && hero2_is_unnested_reduce) || (hero1_is_unnested_reduce && hero2_is_unnested_transpose)) { - return "MOF-fusion of a transpose and a reduction"; + return FusionDecision::Forbid("MOF-fusion of a transpose and a reduction"); } // If we are dealing with unnested transpose, make sure that we can still // treat them as unnested transpose after the sibling fusion. @@ -303,18 +303,18 @@ FusionDecision FusionHeroesAreCompatible(const HloInstruction* hero1, int64_t operand_idx = fusion2->operand_index(fusion1); auto hlo = fusion2->fused_parameter(operand_idx); if (!check_path_of_intermediate_ops(hlo)) { - return "tiled transpose would become untiled"; + return FusionDecision::Forbid("tiled transpose would become untiled"); } } else if (hero2_is_unnested_transpose && fusion1->IsUserOf(fusion2)) { int64_t operand_idx = fusion1->operand_index(fusion2); auto hlo = fusion1->fused_parameter(operand_idx); if (!check_path_of_intermediate_ops(hlo)) { - return "tiled transpose would become untiled"; + return FusionDecision::Forbid("tiled transpose would become untiled"); } } } } - return {}; + return FusionDecision::Allow(); } FusionDecision ShapesCompatibleForMultiOutputFusion( @@ -356,9 +356,9 @@ FusionDecision ShapesCompatibleForMultiOutputFusion( (!accept_unequal_shape || !ShapeUtil::IsReshapeOrTransposeBitcast(l1, l2, /*ignore_element_type=*/true))) { - return "different loop shapes"; + return FusionDecision::Forbid("different loop shapes"); } - return {}; + return FusionDecision::Allow(); } bool IsInputFusibleScatter(const HloInstruction& instr) { @@ -469,10 +469,10 @@ static bool AllSatisfy(const HloInstruction& instr, FusionDecision CanEmitInputFusedScatter(const HloInstruction& producer, const HloInstruction& consumer) { if (IsInputFusibleScatter(producer)) { - return "do not fuse into the output of scatter"; + return FusionDecision::Forbid("do not fuse into the output of scatter"); } if (!IsInputFusibleScatter(consumer)) { - return {}; + return FusionDecision::Allow(); } const HloInstruction* inplace_operand; @@ -485,19 +485,21 @@ FusionDecision CanEmitInputFusedScatter(const HloInstruction& producer, inplace_operand = consumer.operand(0); } if (inplace_operand == &producer) { - return "do not fuse into the in-place operand of scatter"; + return FusionDecision::Forbid( + "do not fuse into the in-place operand of scatter"); } if (absl::c_linear_search(producer.operands(), inplace_operand)) { - return "Producer uses the in-place operand of a scatter"; + return FusionDecision::Forbid( + "Producer uses the in-place operand of a scatter"); } - return {}; + return FusionDecision::Allow(); } FusionDecision IsProducerConsumerFusible(const HloInstruction& producer, const HloInstruction& consumer) { if (!IsLoopFusibleAsProducer(producer) && !IsInputFusibleTranspose(producer)) { - return "the producer is not loop-fusible"; + return FusionDecision::Forbid("the producer is not loop-fusible"); } if (IsInputFusibleReduction(producer)) { @@ -505,7 +507,8 @@ FusionDecision IsProducerConsumerFusible(const HloInstruction& producer, ->config() .debug_options() .xla_gpu_enable_reduction_epilogue_fusion()) { - return "Reduction epilogue fusion is not enabled."; + return FusionDecision::Forbid( + "Reduction epilogue fusion is not enabled."); } const HloInstruction& reduce_hero = producer.opcode() == HloOpcode::kFusion @@ -514,16 +517,19 @@ FusionDecision IsProducerConsumerFusible(const HloInstruction& producer, if (!ReductionIsRaceFree( reduce_hero.GetModule()->config(), GetReductionKindAndContiguousComponents(reduce_hero))) { - return "Reduction output fusion only works for race free reductions"; + return FusionDecision::Forbid( + "Reduction output fusion only works for race free reductions"); } if (!AllSatisfy(consumer, [](const HloInstruction* hlo) { return IsIntermediate(hlo, /*allowed_operand_count=*/1); })) { - return "Reductions from/to continuous dims epilogue not fusible"; + return FusionDecision::Forbid( + "Reductions from/to continuous dims epilogue not fusible"); } if (producer.user_count() > 1) { - return "reduction output fusion only works for single user"; + return FusionDecision::Forbid( + "reduction output fusion only works for single user"); } } @@ -532,12 +538,14 @@ FusionDecision IsProducerConsumerFusible(const HloInstruction& producer, } if (!IsInputFusible(consumer) && !IsLoopFusibleAsConsumer(consumer)) { - return "the consumer is not input-fusible and not loop-fusible"; + return FusionDecision::Forbid( + "the consumer is not input-fusible and not loop-fusible"); } // Skip multiple output fusion. It's not yet supported. if (producer.IsMultiOutputFusion()) { - return "the producer is not fusible as it is a multi-output fusion"; + return FusionDecision::Forbid( + "the producer is not fusible as it is a multi-output fusion"); } // Fuse scalar constants into loop fusion nodes. This reduces the number of @@ -551,7 +559,7 @@ FusionDecision IsProducerConsumerFusible(const HloInstruction& producer, if (producer.opcode() == HloOpcode::kConstant && (!ShapeUtil::IsEffectiveScalar(producer.shape()) || consumer.opcode() != HloOpcode::kFusion)) { - return "not fusing constant"; + return FusionDecision::Forbid("not fusing constant"); } // Make sure the new fusion obeys the in-place semantics. @@ -561,7 +569,7 @@ FusionDecision IsProducerConsumerFusible(const HloInstruction& producer, FusionDecision IsProducerMultiOutputFusible(const HloInstruction& producer) { // Skip multiple output fusion. It's not yet supported. if (producer.IsMultiOutputFusion()) { - return "Producer is a multi-output fusion"; + return FusionDecision::Forbid("Producer is a multi-output fusion"); } // Allowing multi-output fusions that contain in-place operations makes code @@ -589,18 +597,18 @@ FusionDecision IsProducerMultiOutputFusible(const HloInstruction& producer) { // contract that describes what multi-output fusion scenarios are supported by // codegen and then changing this check to allow exactly those fusions). if (!HloDataflowAnalysis::GetInPlaceInputOutputPairs(&producer).empty()) { - return "In-place operations are present"; + return FusionDecision::Forbid("In-place operations are present"); } if (!IsLoopFusibleAsProducer(producer)) { - return "producer is not loop-fusible"; + return FusionDecision::Forbid("producer is not loop-fusible"); } if (IsPhysicallyTransposing(producer)) { - return "producer is physically transposing"; + return FusionDecision::Forbid("producer is physically transposing"); } - return {}; + return FusionDecision::Allow(); } // Returns an estimate of the shared memory usage for a given instruction in @@ -751,16 +759,17 @@ FusionDecision FusionFitsInBudget(const HloInstruction& instr1, FusionInfoCache* cache /*=nullptr*/) { if (SharedMemoryUsage(instr1, cache) + SharedMemoryUsage(instr2, cache) > device_info.shared_memory_per_block()) { - return FusionDecision{} - << "shared memory usage would be over the budget of " + return FusionDecision::Forbid( + "shared memory usage would be over the budget of ") << device_info.shared_memory_per_block() << "B"; } if (NumUnnestedReductions(instr1, cache) + NumUnnestedReductions(instr2, cache) > kMaxUnnestedReductionOutputsPerFusion) { - return FusionDecision{} << "over " << kMaxUnnestedReductionOutputsPerFusion - << " unnested reductions in fusion"; + return FusionDecision::Forbid("over ") + << kMaxUnnestedReductionOutputsPerFusion + << " unnested reductions in fusion"; } // Compute the number of outputs of the (possibly multi-output) fusion node @@ -791,7 +800,7 @@ FusionDecision FusionFitsInBudget(const HloInstruction& instr1, if (instr1.operand_count() + instr2.operand_count() - 1 + num_output_buffers <= MaxOperandsAndOutputsPerFusion()) { - return {}; + return FusionDecision::Allow(); } else { VLOG(5) << "Operand count of " << "(" << instr1.ToString() << " ) = " << instr1.operand_count() << " and ( " @@ -816,15 +825,16 @@ FusionDecision FusionFitsInBudget(const HloInstruction& instr1, // consumer numbers of output. So no need to check it. if (is_consumer_producer_fusion && operands.size() <= instr1.operands().size()) { - return {}; + return FusionDecision::Allow(); } // Does the new fusion have more operands and outputs than the max? if (operands.size() + num_output_buffers > MaxOperandsAndOutputsPerFusion()) { - return "Number of operands and output buffers is larger than allowed " - "budget per fusion"; + return FusionDecision::Forbid( + "Number of operands and output buffers is larger than allowed budget " + "per fusion"); } - return {}; + return FusionDecision::Allow(); } bool CreatesHeavyComputation(const HloInstruction& producer, diff --git a/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.cc b/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.cc index 4ec9f347dff90d..520d0448c0c78f 100644 --- a/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.cc +++ b/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.cc @@ -543,7 +543,7 @@ GpuPerformanceModelWithIndexingAnalysis::TryFindBestTilingForFusion( } if (!best_tiled_run_time_data.has_value()) { - return FusionDecision("No valid tilings found."); + return FusionDecision::Forbid("No valid tilings found."); } return *best_tiled_run_time_data; } diff --git a/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis.cc b/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis.cc index 1ab971ff2a4a94..9b9e637904210a 100644 --- a/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis.cc +++ b/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis.cc @@ -230,7 +230,7 @@ FusionDecision ShouldProceedWithSymbolicTileDerivation( // issues to be resolved in the current implementation. if (hlo->opcode() == HloOpcode::kDot || hlo->opcode() == HloOpcode::kConcatenate) { - return FusionDecision{} << "Bailing out on " << hlo->ToString(); + return FusionDecision::Forbid("Bailing out on ") << hlo->ToString(); } // Due to the issue highlighted in b/365727080, and the related workaround @@ -254,13 +254,13 @@ FusionDecision ShouldProceedWithSymbolicTileDerivation( SymbolicTile::FromIndexingMap(reshape_indexing_map); if (!reshape_symbolic_tile.has_value()) { - return FusionDecision{} << "Bailing out on reshape " << hlo->ToString() - << " with indexing map " - << reshape_indexing_map.ToString(); + return FusionDecision::Forbid("Bailing out on reshape ") + << hlo->ToString() << " with indexing map " + << reshape_indexing_map.ToString(); } } - return {}; + return FusionDecision::Allow(); } // Sets a SymbolicTile for each tiled hlo instruction and computes their @@ -292,16 +292,14 @@ SetSymbolicTilesAndComputeConstraints( auto symbolic_tile = SymbolicTile::FromIndexingMap(indexing_map); if (!symbolic_tile.has_value()) { - return FusionDecision{} << "Failed to compute symbolic tile for " - << indexing_map.ToString() << " for HLO " - << hlo->ToString(); + return FusionDecision::Forbid("Failed to compute symbolic tile for ") + << indexing_map.ToString() << " for HLO " << hlo->ToString(); } if (!symbolic_tile->is_satisfiable()) { - return FusionDecision{} << "Symbolic tile " << symbolic_tile->ToString() - << " is not satisfiable for " - << indexing_map.ToString() << " for HLO " - << hlo->ToString(); + return FusionDecision::Forbid("Symbolic tile ") + << symbolic_tile->ToString() << " is not satisfiable for " + << indexing_map.ToString() << " for HLO " << hlo->ToString(); } constraints = ConstraintExpression::And(std::move(constraints), @@ -309,7 +307,7 @@ SetSymbolicTilesAndComputeConstraints( constraints.Simplify(); if (!constraints.is_satisfiable()) { - return FusionDecision{} << "Fusion has unsatisfiable constraints"; + return FusionDecision::Forbid("Fusion has unsatisfiable constraints"); } tiled_hlo_instruction->set_symbolic_tile(*std::move(symbolic_tile)); @@ -365,8 +363,8 @@ void SortTiledHloInstructionsInPostOrder( auto roots = fusion.GetRoots(); if (roots.size() > 1) { - return FusionDecision{} << "Multi-output fusions are not supported. " - << fusion.ToString(); + return FusionDecision::Forbid("Multi-output fusions are not supported. ") + << fusion.ToString(); } auto& root = roots[0]; @@ -399,8 +397,8 @@ void SortTiledHloInstructionsInPostOrder( ComposeIndexingMaps(tiled_hlo_instruction->indexing_map(), *operand_indexing_map_set.begin()); if (operand_indexing_map.IsUndefined()) { - return FusionDecision{} - << "Couldn't derive indexing map for instruction " + return FusionDecision::Forbid( + "Couldn't derive indexing map for instruction ") << tiled_hlo_instruction->hlo()->ToString() << " and operand " << operand.instruction().ToString(); } diff --git a/third_party/xla/xla/service/gpu/transforms/fusion_merger.cc b/third_party/xla/xla/service/gpu/transforms/fusion_merger.cc index 37986219faae16..d132cf6f3ae682 100644 --- a/third_party/xla/xla/service/gpu/transforms/fusion_merger.cc +++ b/third_party/xla/xla/service/gpu/transforms/fusion_merger.cc @@ -211,7 +211,7 @@ FusionDecision FusionInstructionMerger::ShouldFuse(HloInstruction* producer) { // merge. if (producer->users().empty()) { ++num_fail_no_users_; - return "fusion has no users"; + return FusionDecision::Forbid("fusion has no users"); } // Skip 'producer' instruction if it is not a loop fusion. Library fusion @@ -220,7 +220,7 @@ FusionDecision FusionInstructionMerger::ShouldFuse(HloInstruction* producer) { // kReduce), so they shouldn't be further fused either. if (!producer->IsLoopFusion()) { ++num_fail_not_loop_fusion_; - return "not a loop fusion"; + return FusionDecision::Forbid("not a loop fusion"); } auto producer_hero = GetRealHeroForMultiOutputFusion(*producer); @@ -229,11 +229,11 @@ FusionDecision FusionInstructionMerger::ShouldFuse(HloInstruction* producer) { for (const HloInstruction* user : producer->users()) { if (user->opcode() == HloOpcode::kBitcast) { ++num_fail_merge_all_users_; - return "not fusing bitcast ops"; + return FusionDecision::Forbid("not fusing bitcast ops"); } if (user->IsCustomFusion()) { ++num_fail_merge_all_users_; - return "not fusing custom fusions"; + return FusionDecision::Forbid("not fusing custom fusions"); } auto consumer_hero = GetRealHeroForMultiOutputFusion(*user); if (auto compatible = @@ -256,7 +256,7 @@ FusionDecision FusionInstructionMerger::ShouldFuse(HloInstruction* producer) { // it to a producer which transposes most data. if (has_reduction_user && TransposesMostData(*producer)) { ++num_fail_uncoalesced_read_; - return "would read mostly uncoalesced"; + return FusionDecision::Forbid("would read mostly uncoalesced"); } for (const HloInstruction* user : producer->users()) { @@ -285,8 +285,8 @@ FusionDecision FusionInstructionMerger::ShouldFuse(HloInstruction* producer) { for (const HloInstruction* user : producer->users()) { if (cost_analysis_->ProducerConsumerMergedTooLarge(*producer, *user)) { ++num_fail_inefficient_fusion_emitter_; - return FusionDecision{} << "if merged with " << user->name() - << " will generate huge IR"; + return FusionDecision::Forbid("if merged with ") + << user->name() << " will generate huge IR"; } } @@ -295,10 +295,10 @@ FusionDecision FusionInstructionMerger::ShouldFuse(HloInstruction* producer) { GpuPerformanceModelOptions::Default(), producer->users()); if (t.time_fused > t.time_unfused) { ++num_fail_slower_if_fused_; - return "will execute slower if fused"; + return FusionDecision::Forbid("will execute slower if fused"); } - return {}; + return FusionDecision::Allow(); } absl::StatusOr FusionMerger::Run( diff --git a/third_party/xla/xla/service/gpu/transforms/gemm_fusion.cc b/third_party/xla/xla/service/gpu/transforms/gemm_fusion.cc index e1d21f1827c6f6..1f6d41698aeaa4 100644 --- a/third_party/xla/xla/service/gpu/transforms/gemm_fusion.cc +++ b/third_party/xla/xla/service/gpu/transforms/gemm_fusion.cc @@ -635,14 +635,14 @@ class Decision { // Returns true if it's profitable to fuse. bool WantToFuse() const { return fusing_decision_.CanFuse(); } - static Decision Accept() { return {FusionDecision(), true}; }; + static Decision Allow() { return {FusionDecision::Allow(), true}; }; - static Decision Decline(std::string_view value) { - return {FusionDecision(value), false}; + static Decision Deny(std::string_view value) { + return {FusionDecision::Forbid(value), false}; } static Decision NotProfitable(std::string_view value) { - return {FusionDecision(value), true}; + return {FusionDecision::Forbid(value), true}; } private: @@ -670,7 +670,7 @@ absl::StatusOr CreateDotFusion( legacy_triton::IsTritonSupportedInstruction(dot, gpu_version); !is_supported) { VLOG(3) << is_supported.Explain(); - return Decision::Decline(is_supported.Explain()); + return Decision::Deny(is_supported.Explain()); } // Verify sparse dot constraints. @@ -729,7 +729,7 @@ absl::StatusOr CreateDotFusion( dot, TritonFusionAnalysis::Scope::LHS) || !analysis.IsBatchDimMinorForInt4Parameter( dot, TritonFusionAnalysis::Scope::RHS)) { - return Decision::Decline( + return Decision::Deny( "Fusion is not possible because the parameter with the type S4 has " "minor batch dimension."); } @@ -742,7 +742,7 @@ absl::StatusOr CreateDotFusion( algorithm == PrecisionConfig::ALG_DOT_BF16_BF16_F32_X3 || dot.GetModule()->config().debug_options().xla_gpu_triton_gemm_any() || dot.sparse_operands()) { - return Decision::Accept(); + return Decision::Allow(); } bool is_pure_matmul = true; @@ -760,7 +760,7 @@ absl::StatusOr CreateDotFusion( if (is_pure_matmul) { return Decision::NotProfitable("Pure Matmul"); } - return Decision::Accept(); + return Decision::Allow(); } // Extracts into fused computations parts of HLO graph including dot() diff --git a/third_party/xla/xla/service/gpu/transforms/instruction_fusion.cc b/third_party/xla/xla/service/gpu/transforms/instruction_fusion.cc index 5e32f2ec0c2ee1..bfd8c5bbb6b0a9 100644 --- a/third_party/xla/xla/service/gpu/transforms/instruction_fusion.cc +++ b/third_party/xla/xla/service/gpu/transforms/instruction_fusion.cc @@ -91,31 +91,34 @@ FusionDecision GpuInstructionFusion::ShouldFuseInexpensiveChecks( // Output fusions are not currently supported on GPUs. if (producer->opcode() == HloOpcode::kFusion) { - return "the producer is a fusion"; + return FusionDecision::Forbid("the producer is a fusion"); } if (consumer->IsCustomFusion()) { - return "the consumer is a custom fusion"; + return FusionDecision::Forbid("the consumer is a custom fusion"); } // Cost condition: not fuse (simple, expensive producers) and (consumers who // reuse operand elements). if (is_expensive(*producer) && ReusesOperandElements(consumer, operand_index)) { - return "the producer is expensive, and the consumer reuses inputs"; + return FusionDecision::Forbid( + "the producer is expensive, and the consumer reuses inputs"); } // Do not fuse into fusions if the resulting kernel would suffer from // uncoalesced reads due to a transposed memory access pattern. if (IsInputFusibleReduction(*consumer) && IsPhysicallyTransposing(*producer)) { - return "fusing the producer would break read coalescing"; + return FusionDecision::Forbid( + "fusing the producer would break read coalescing"); } RETURN_IF_NOT_FUSIBLE(IsProducerConsumerFusible(*producer, *consumer)); if (CreatesHeavyComputation(*producer, *consumer)) { - return "the fusion would create a heavy computation"; + return FusionDecision::Forbid( + "the fusion would create a heavy computation"); } return InstructionFusion::ShouldFuse(consumer, operand_index); @@ -133,7 +136,7 @@ FusionDecision GpuInstructionFusion::ShouldFuse(HloInstruction* consumer, /*is_consumer_producer_fusion=*/true)); if (consumer->opcode() != HloOpcode::kFusion) { - return {}; + return FusionDecision::Allow(); } // Also check that our emitter can handle the fusion node. We currently can @@ -149,9 +152,10 @@ FusionDecision GpuInstructionFusion::ShouldFuse(HloInstruction* consumer, FusionNodeIndexingEvaluation(consumer)); } if (fusion_node_evaluations_.at(consumer).CodeDuplicationTooHigh(producer)) { - return "the fusion would result in an overly large code duplication"; + return FusionDecision::Forbid( + "the fusion would result in an overly large code duplication"); } - return {}; + return FusionDecision::Allow(); } HloInstruction::FusionKind GpuInstructionFusion::ChooseKind( diff --git a/third_party/xla/xla/service/gpu/transforms/multi_output_fusion.cc b/third_party/xla/xla/service/gpu/transforms/multi_output_fusion.cc index 04456d8131ac76..9ab9729b3b9202 100644 --- a/third_party/xla/xla/service/gpu/transforms/multi_output_fusion.cc +++ b/third_party/xla/xla/service/gpu/transforms/multi_output_fusion.cc @@ -86,14 +86,16 @@ const HloSliceInstruction* FindUniqueSlice(const HloInstruction* parent, FusionDecision ParameterSlicesAreNonOverlapping(const HloInstruction& instr1, const HloInstruction& instr2, const HloInstruction* parent) { - if (parent->shape().IsTuple()) return {}; + if (parent->shape().IsTuple()) return FusionDecision::Allow(); // Allow MOF if the parameter is small, even if there's no overlap. 1024 bytes // were arbitrarily chosen as the threshold. - if (ShapeUtil::ByteSizeOfElements(parent->shape()) < 1024) return {}; + if (ShapeUtil::ByteSizeOfElements(parent->shape()) < 1024) { + return FusionDecision::Allow(); + } const HloSliceInstruction* slice1 = FindUniqueSlice(parent, &instr1); const HloSliceInstruction* slice2 = FindUniqueSlice(parent, &instr2); - if (!slice1 || !slice2) return {}; + if (!slice1 || !slice2) return FusionDecision::Allow(); // TODO(jreiffers): Check strides as well. auto& starts1 = slice1->slice_starts(); @@ -104,10 +106,10 @@ FusionDecision ParameterSlicesAreNonOverlapping(const HloInstruction& instr1, for (int64_t dim = 0; dim < parent->shape().rank(); ++dim) { bool overlap = starts1[dim] < limits2[dim] && starts2[dim] < limits1[dim]; if (!overlap) { - return "slices are non-overlapping"; + return FusionDecision::Forbid("slices are non-overlapping"); } } - return {}; + return FusionDecision::Allow(); } FusionDecision LegalToFuse(const HloInstruction& instr1, @@ -125,7 +127,7 @@ FusionDecision LegalToFuse(const HloInstruction& instr1, (instr2.opcode() == HloOpcode::kFusion && instr2.fused_expression_root()->opcode() == HloOpcode::kDynamicUpdateSlice)) { - return "can't fuse multiple DUSs"; + return FusionDecision::Forbid("can't fuse multiple DUSs"); } // Do this check last, as it may be expensive. @@ -175,11 +177,11 @@ FusionDecision OperandReachableFromProducer( << "Reachability map is incomplete. This should never " "happen."; if (&producer != operand && reachability.IsReachable(&producer, operand)) { - return { - absl::StrCat(producer.name(), " would introduce a cycle when fused")}; + return FusionDecision::Forbid( + absl::StrCat(producer.name(), " would introduce a cycle when fused")); } } - return {}; + return FusionDecision::Allow(); } FusionDecision ProducerCandidateIsFusible( @@ -188,7 +190,8 @@ FusionDecision ProducerCandidateIsFusible( const se::DeviceDescription& device_info, GpuHloCostAnalysis* cost_analysis) { if (!IsFusibleAsMultiOutputFusionRoot(consumer)) { - return "consumer not eligible as multi-output fusion root."; + return FusionDecision::Forbid( + "consumer not eligible as multi-output fusion root."); } RETURN_IF_NOT_FUSIBLE( @@ -202,7 +205,7 @@ FusionDecision ProducerCandidateIsFusible( /*is_consumer_producer_fusion=*/false, fusion_info_cache)); if (cost_analysis->ProducerConsumerMergedTooLarge(producer, consumer)) { - return "will generate too large IR"; + return FusionDecision::Forbid("will generate too large IR"); } GpuPerformanceModel::RunTimes t = GpuPerformanceModel::EstimateRunTimes( @@ -211,10 +214,10 @@ FusionDecision ProducerCandidateIsFusible( /*fused_consumers=*/{&consumer}, /*multi_output=*/true); if (t.time_fused > t.time_unfused) { - return "will execute slower if fused"; + return FusionDecision::Forbid("will execute slower if fused"); } - return {}; + return FusionDecision::Allow(); } std::vector GetProducerConsumerMultiOutputFusionCandidates( @@ -283,8 +286,9 @@ FusionDecision CanFuseSiblings(const HloInstruction& sibling_consumer_1, FusionInfoCache* fusion_info_cache, const se::DeviceDescription& device_info) { if (reachability.IsConnected(&sibling_consumer_1, &sibling_consumer_2)) { - return {absl::StrCat(sibling_consumer_1.name(), " and ", - sibling_consumer_2.name(), " are connected")}; + return FusionDecision::Forbid( + absl::StrCat(sibling_consumer_1.name(), " and ", + sibling_consumer_2.name(), " are connected")); } RETURN_IF_NOT_FUSIBLE(ShapesCompatibleForMultiOutputFusion( @@ -302,7 +306,7 @@ FusionDecision CanFuseSiblings(const HloInstruction& sibling_consumer_1, // This check should be last, as it may be expensive. RETURN_IF_NOT_FUSIBLE(LegalToFuse(sibling_consumer_1, sibling_consumer_2, device_info, fusion_info_cache)); - return {}; + return FusionDecision::Allow(); } } // namespace diff --git a/third_party/xla/xla/service/gpu/transforms/priority_fusion.cc b/third_party/xla/xla/service/gpu/transforms/priority_fusion.cc index f887161d869fd9..c0e86818d36fe8 100644 --- a/third_party/xla/xla/service/gpu/transforms/priority_fusion.cc +++ b/third_party/xla/xla/service/gpu/transforms/priority_fusion.cc @@ -555,7 +555,7 @@ class PriorityFusionQueue { } } - return {}; + return FusionDecision::Allow(); } TiledRunTimeDataOrError GetTiledRunTimeDataCached( @@ -587,17 +587,17 @@ class PriorityFusionQueue { if (result_or_status.ok()) { return *result_or_status; } else { - return FusionDecision{ + return FusionDecision::Forbid( absl::StrCat("TiledRunTimeDataOrError return status: ", - result_or_status.status().message())}; + result_or_status.status().message())); } }(); if (const auto* fusion_decision = std::get_if(&tiled_run_time_data_or_error)) { - tiled_run_time_data_or_error = FusionDecision{ + tiled_run_time_data_or_error = FusionDecision::Forbid( absl::StrCat("Fusion can not be tiled with SymbolicTileAnalysis: ", - fusion_decision->Explain())}; + fusion_decision->Explain())); } absl::MutexLock lock(&tiled_run_time_data_cache_mutex_); @@ -608,12 +608,12 @@ class PriorityFusionQueue { FusionDecision CanFuseTriton(HloInstruction* producer, HloInstruction* consumer) { if (!triton_softmax_priority_fusion_enabled_) { - return "triton softmax fusion is not enabled"; + return FusionDecision::Forbid("triton softmax fusion is not enabled"); } if (IsGenericTritonFusion(*producer)) { if (!IsFusible(*consumer)) { - return "the consumer is not fusible"; + return FusionDecision::Forbid("the consumer is not fusible"); } if (auto fusion_decision = IsTritonSupported(*consumer); @@ -622,7 +622,7 @@ class PriorityFusionQueue { } } else { if (!IsFusible(*producer)) { - return "the producer is not fusible"; + return FusionDecision::Forbid("the producer is not fusible"); } if (auto fusion_decision = IsTritonSupported(*producer); @@ -651,7 +651,7 @@ class PriorityFusionQueue { tiled_run_time_data.block_level_parameters; } - return {}; + return FusionDecision::Allow(); } FusionDecision CanFuse(HloInstruction* producer, HloInstruction* consumer) { @@ -660,15 +660,16 @@ class PriorityFusionQueue { } if (!IsFusible(*producer)) { - return "the producer is not fusible"; + return FusionDecision::Forbid("the producer is not fusible"); } if (!IsFusible(*consumer)) { - return "the consumer is not fusible"; + return FusionDecision::Forbid("the consumer is not fusible"); } if (consumer->opcode() == HloOpcode::kBitcast) { - return "not fusing into a single bitcast as consumer"; + return FusionDecision::Forbid( + "not fusing into a single bitcast as consumer"); } // Scatter is special as it has no elemental version but is still input @@ -698,7 +699,8 @@ class PriorityFusionQueue { }; if (contains_significant_reduce(producer) && contains_significant_reduce(consumer)) { - return "both the producer and the consumer contain a reduce"; + return FusionDecision::Forbid( + "both the producer and the consumer contain a reduce"); } // Avoid doing fusions into the output of an "input" fusion when it would @@ -712,8 +714,8 @@ class PriorityFusionQueue { fusion_analysis_cache_.Get(*producer, *consumer); if (analysis_fused.GetEmitterFusionKind() == HloFusionAnalysis::EmitterFusionKind::kLoop) { - return "fusion into output of a reduce fusion would create a loop " - "fusion"; + return FusionDecision::Forbid( + "fusion into output of a reduce fusion would create a loop fusion"); } } @@ -731,7 +733,8 @@ class PriorityFusionQueue { // kernels, in which case we don't want to fuse. // TODO(b/119692968): Remove this once we have fixed our fusion emitter. if (cost_analysis_.ProducerConsumerMergedTooLarge(*producer, *consumer)) { - return "the fusion would result in an overly large code duplication"; + return FusionDecision::Forbid( + "the fusion would result in an overly large code duplication"); } // Don't fuse across a root instruction. There are situation when a root @@ -739,7 +742,8 @@ class PriorityFusionQueue { // root are not necessary dead. They can be inputs to instructions with side // effects, like outfeed. if (producer == producer->parent()->root_instruction()) { - return "not fusing into the output of the root instruction"; + return FusionDecision::Forbid( + "not fusing into the output of the root instruction"); } return InstructionFusion::ShouldFuseInPlaceOp(producer, consumer); @@ -764,7 +768,7 @@ class PriorityFusionQueue { // override any value. { absl::MutexLock lock(&can_fuse_cache_mutex_); - can_fuse_cache_[producer][consumer] = fusion_decision; + can_fuse_cache_[producer].insert_or_assign(consumer, fusion_decision); } return fusion_decision; @@ -772,10 +776,9 @@ class PriorityFusionQueue { FusionDecision CanFuseWithAllNonBitcastUsers(HloInstruction* producer) { if (producer->users().empty()) { - return "No users to fuse"; + return FusionDecision::Forbid("No users to fuse"); } - FusionDecision result; bool has_non_bitcast_user = false; for (const auto& user : producer->users()) { if (user->opcode() == HloOpcode::kBitcast) { @@ -790,9 +793,10 @@ class PriorityFusionQueue { } } if (!has_non_bitcast_user) { - return "not fusing because there are only bitcast users"; + return FusionDecision::Forbid( + "not fusing because there are only bitcast users"); } - return {}; + return FusionDecision::Allow(); } // Store computation for cost analysis. diff --git a/third_party/xla/xla/service/gpu/transforms/softmax_rewriter_triton.cc b/third_party/xla/xla/service/gpu/transforms/softmax_rewriter_triton.cc index 00f00e7bba4f17..4b2e12c1ce36b8 100644 --- a/third_party/xla/xla/service/gpu/transforms/softmax_rewriter_triton.cc +++ b/third_party/xla/xla/service/gpu/transforms/softmax_rewriter_triton.cc @@ -519,8 +519,8 @@ DecideIfShouldFuseAndMaybeSetBlockLevelParameters( if (const auto* fusion_decision = std::get_if(&tiled_runtime_data_or)) { - return FusionDecision{absl::StrCat("SymbolicTileAnalysis failed: ", - fusion_decision->Explain())}; + return FusionDecision::Forbid(absl::StrCat("SymbolicTileAnalysis failed: ", + fusion_decision->Explain())); } TiledRunTimeData tiled_runtime_data = @@ -539,8 +539,9 @@ DecideIfShouldFuseAndMaybeSetBlockLevelParameters( if (run_time_without_softmax_rewriter < tiled_runtime_data.runtime_data.exec_time) { - return "Run time estimate for without applying the custom normalization " - "rewrite is faster."; + return FusionDecision::Forbid( + "Run time estimate for without applying the custom normalization " + "rewrite is faster."); } } @@ -552,7 +553,7 @@ DecideIfShouldFuseAndMaybeSetBlockLevelParameters( TF_RETURN_IF_ERROR(softmax_fusion->set_backend_config(backend_config)); VLOG(5) << "Fusing with backend config: " << backend_config.DebugString(); - return FusionDecision{}; + return FusionDecision::Allow(); } absl::StatusOr MaybeFuseDiamondChainImpl( @@ -620,12 +621,12 @@ FusionDecision ShouldFuseReduction(const HloInstruction& reduce, const se::GpuComputeCapability& cc) { if (CodegenDecision is_supported = IsTritonSupportedInstruction(reduce, cc); !is_supported) { - return FusionDecision(is_supported.Explain()); + return FusionDecision::Forbid(is_supported.Explain()); } if (reduce.dimensions().size() != 1 || reduce.dimensions(0) != reduce.operand(0)->shape().rank() - 1) { - return FusionDecision( + return FusionDecision::Forbid( "The reductions in the diamond must reduce 1 dimension and that " "dimension must be the last dimension of the operand."); } @@ -639,21 +640,23 @@ FusionDecision ShouldFuseReduction(const HloInstruction& reduce, identity->operand(0)->opcode() == HloOpcode::kConstant && IsTritonSupportedInstruction(*identity, cc)); if (!should_fuse_identity) { - return "Reduction identity is not a constant or a supported convert of a " - "constant."; + return FusionDecision::Forbid( + "Reduction identity is not a constant or a supported convert of a " + "constant."); } - return {}; + return FusionDecision::Allow(); } DiamondMatchingDecision MatchesTritonCompatibleClosedReductionDiamondImpl( HloInstruction* instr, const se::GpuComputeCapability& cc) { if (!instr->IsElementwiseBinary()) { - return "Root is not elementwise binary."; + return FusionDecision::Forbid("Root is not elementwise binary."); } if (!IsTritonSupportedInstruction(*instr, cc)) { - return "Root is not supported for Triton instruction."; + return FusionDecision::Forbid( + "Root is not supported for Triton instruction."); } HloInstruction* producer; @@ -662,18 +665,21 @@ DiamondMatchingDecision MatchesTritonCompatibleClosedReductionDiamondImpl( if (!TrivialEdge(&broadcast, instr->mutable_operand(1), HloOpcode::kBroadcast, cc)) { - return "Could not find a trivial connection from root to a broadcast."; + return FusionDecision::Forbid( + "Could not find a trivial connection from root to a broadcast."); } if (!TrivialEdge(&reduce, broadcast->mutable_operand(0), HloOpcode::kReduce, cc)) { - return "Could not find a trivial connection from matched broadcast to a " - "reduction."; + return FusionDecision::Forbid( + "Could not find a trivial connection from matched broadcast to a " + "reduction."); } if (!(HasDefaultLayout(broadcast->shape()) && HasDefaultLayout(reduce->shape()))) { - return "Broadcast or reduce have non-default layouts."; + return FusionDecision::Forbid( + "Broadcast or reduce have non-default layouts."); } if (FusionDecision should_fuse_reduction = ShouldFuseReduction(*reduce, cc); @@ -691,19 +697,21 @@ DiamondMatchingDecision MatchesTritonCompatibleClosedReductionDiamondImpl( identity->operand(0)->opcode() == HloOpcode::kConstant && IsTritonSupportedInstruction(*identity, cc)); if (!should_fuse_identity) { - return "Reduction identity is not a constant or a supported convert of a " - "constant."; + return FusionDecision::Forbid( + "Reduction identity is not a constant or a supported convert of a " + "constant."); } if (!HasOneUse(broadcast) || !HasOneUse(reduce)) { - return "More than one use of broadcast or reduce."; + return FusionDecision::Forbid("More than one use of broadcast or reduce."); } producer = reduce->mutable_operand(0); if (absl::c_linear_search(broadcast->dimensions(), broadcast->shape().rank() - 1)) { - return "Broadcast is not along the reduction dimension."; + return FusionDecision::Forbid( + "Broadcast is not along the reduction dimension."); } while (IsTriviallyFusible(producer, cc)) { @@ -711,16 +719,16 @@ DiamondMatchingDecision MatchesTritonCompatibleClosedReductionDiamondImpl( } if (!HasDefaultLayout(producer->shape())) { - return "Producer has non-default layout."; + return FusionDecision::Forbid("Producer has non-default layout."); } if (!IsTriviallyConnectedProducerOf(producer, instr->mutable_operand(0), cc)) { - return "Producer is not trivially connected."; + return FusionDecision::Forbid("Producer is not trivially connected."); } if (producer != instr->operand(0) && instr->operand(0)->user_count() != 1) { - return "Unsupported root-producer connection."; + return FusionDecision::Forbid("Unsupported root-producer connection."); } VLOG(5) << "Matched Softmax diamond with: "; diff --git a/third_party/xla/xla/service/gpu/triton_tiling_propagation.cc b/third_party/xla/xla/service/gpu/triton_tiling_propagation.cc index 1943af5343b427..1d5230ee66855e 100644 --- a/third_party/xla/xla/service/gpu/triton_tiling_propagation.cc +++ b/third_party/xla/xla/service/gpu/triton_tiling_propagation.cc @@ -284,7 +284,7 @@ Int64OrError CombineSplitDimMajorPartSizeReqs(int64_t a, int64_t b) { if (a == kNoSplitRequirement) { return b; } - return FusionDecision("Conflicting splits of splittable dimension"); + return FusionDecision::Forbid("Conflicting splits of splittable dimension"); } } // namespace @@ -318,7 +318,7 @@ DotRequirementsOrError GetRequirementsIfSupportedOrder( CHECK(!dim_fragments.empty()); for (int i = 0; i < dim_fragments.size() - 1; ++i) { if (tensor_dim_fragments[dim_fragments[i]].is_sliced()) { - return "Sliced non-major-most fragment."; + return FusionDecision::Forbid("Sliced non-major-most fragment."); } } int group_counter = 0; @@ -342,7 +342,7 @@ DotRequirementsOrError GetRequirementsIfSupportedOrder( } if (last_seen_group_last_fragment_index > *fragment_it) { - return "Transpose within a dimension."; + return FusionDecision::Forbid("Transpose within a dimension."); } ++group_counter; @@ -356,14 +356,16 @@ DotRequirementsOrError GetRequirementsIfSupportedOrder( if (group_counter == 2) { if (split_dim_major_part != kNoSplitRequirement && split_dim_major_part != grouped_size) { - return "Conflicting splits of splittable dimension"; + return FusionDecision::Forbid( + "Conflicting splits of splittable dimension"); } split_dim_major_part = grouped_size; } else if (group_counter > 2) { - return "2nd split of a splittable dimension."; + return FusionDecision::Forbid( + "2nd split of a splittable dimension."); } } else { - return "Unsupported split of a dimension."; + return FusionDecision::Forbid("Unsupported split of a dimension."); } } @@ -479,7 +481,7 @@ DimOrderMapOrError GetPropagatedDimOrdersForBitcast( }; if (dst_remaining_size >= src_dim->full_count()) { if (dst_remaining_size % src_dim->full_count()) { - return "Unsupported bitcast"; + return FusionDecision::Forbid("Unsupported bitcast"); } // Source dimension fragment completely fits into the destination one: // just copy it as is. @@ -497,7 +499,7 @@ DimOrderMapOrError GetPropagatedDimOrdersForBitcast( // If there is a remaining fragment of a previous destination dimension // assign it first. if (src_remaining_size % dst_remaining_size || (src_dim->is_sliced())) { - return "Unsupported bitcast"; + return FusionDecision::Forbid("Unsupported bitcast"); } add_new_fragment( Fragment{src_dim->dst_dim_number(), dst_remaining_size}); @@ -515,13 +517,13 @@ DimOrderMapOrError GetPropagatedDimOrdersForBitcast( // size assign the remainder of the source and carry over the // remainder of the destination. if (dst_dim_size % src_remaining_size) { - return "Unsupported bitcast"; + return FusionDecision::Forbid("Unsupported bitcast"); } dst_remaining_size = dst_dim_size / src_remaining_size; new_fragment_size = src_remaining_size; } if (src_dim->is_sliced()) { - return "Unsupported bitcast"; + return FusionDecision::Forbid("Unsupported bitcast"); } add_new_fragment( Fragment{src_dim->dst_dim_number(), new_fragment_size}); @@ -537,7 +539,7 @@ DimOrderMapOrError GetPropagatedDimOrdersForBitcast( // give up. while (dst_dim_it != dst_dim_end) { if (dst_shape.dimensions(*dst_dim_it) != 1) { - return "Unsupported bitcast"; + return FusionDecision::Forbid("Unsupported bitcast"); } if (!dst_fragments_order.empty()) { dst_fragments_order.push_back( @@ -582,7 +584,7 @@ DimOrderMapOrError GetPropagatedDimOrdersForDimAlteringOp( Fragments src_fragments_order = src_dim_order.TensorFragmentsOrder(); if (hlo.opcode() == HloOpcode::kSlice && ShapeUtil::IsEffectiveScalar(hlo.shape())) { - return FusionDecision("Slice to scalar is not implemented yet."); + return FusionDecision::Forbid("Slice to scalar is not implemented yet."); } // Every HLO dimension can correspond to a group of subdimensions in // dim_order_. For the easier handling of permutations: group dim_order_ by @@ -595,7 +597,8 @@ DimOrderMapOrError GetPropagatedDimOrdersForDimAlteringOp( // It's not supported currently to further propagate dimensions after // reaching a trivial sized tensor. We could probably support it, but now we // just prevent crashing here. - return FusionDecision("Cannot propagate further from trivial sized tensor"); + return FusionDecision::Forbid( + "Cannot propagate further from trivial sized tensor"); } auto src_fragment_it = src_fragments_order.begin(); for (int64_t dim_index : src.shape().layout().minor_to_major()) { @@ -652,17 +655,17 @@ DimOrderMapOrError GetPropagatedDimOrdersForDimAlteringOp( dst_logical.resize(src_logical.size() + reduce->dimensions().size()); if (reduce->dimensions().size() != 1) { - return FusionDecision("Unsupported reduction."); + return FusionDecision::Forbid("Unsupported reduction."); } else if (reduce->dimensions().front() != reduce->operand(0)->shape().rank() - 1) { - return FusionDecision("Only row reductions are supported."); + return FusionDecision::Forbid("Only row reductions are supported."); } } else if (hlo.opcode() == HloOpcode::kConcatenate) { dst_logical.resize(src_logical.size()); for (int i = 0; i < src_logical.size(); ++i) { if (i == hlo.concatenate_dimension()) { if (src_logical[i].size() != 1 || src_logical[i][0]->is_sliced()) { - return FusionDecision("Unsupported concatenation."); + return FusionDecision::Forbid("Unsupported concatenation."); } const Fragment& src_fragment = *src_logical[i][0]; Fragment& dst_fragment = new_fragments.emplace_back( @@ -733,7 +736,7 @@ DimOrderMapOrError GetPropagatedDimOrdersForDimAlteringOp( if (slice->slice_limits(dim) - slice->slice_starts(dim) != dst->shape().dimensions(dim)) { if (dst_logical[dim].size() > 1) { - return FusionDecision("Slicing of fragmented dimension."); + return FusionDecision::Forbid("Slicing of fragmented dimension."); } auto fragment = dst_logical[dim].front(); fragment->set_count(dst->shape().dimensions(dim)); @@ -755,7 +758,7 @@ DimOrderMapOrError GetPropagatedDimOrdersForDimAlteringOp( dst_logical[dim] = src_logical[dim]; if (dynamic_slice->slice_sizes(dim) != dst->shape().dimensions(dim)) { if (dst_logical[dim].size() > 1) { - return FusionDecision("Slicing of fragmented dimension."); + return FusionDecision::Forbid("Slicing of fragmented dimension."); } auto fragment = dst_logical[dim].front(); fragment->set_count(dst->shape().dimensions(dim)); @@ -767,7 +770,7 @@ DimOrderMapOrError GetPropagatedDimOrdersForDimAlteringOp( } } } else { - return FusionDecision("Function called on a wrong instruction."); + return FusionDecision::Forbid("Function called on a wrong instruction."); } // Destination logical -> destination physical and ungroup subdimensions. // Map original fragments to the resulting ones to derive their new @@ -794,7 +797,7 @@ DimOrderMapOrError GetPropagatedDimOrdersForDimAlteringOp( if (hlo.opcode() == HloOpcode::kBroadcast && src_fragments_order[fragment_number].full_count() > 1 && dim_numbers_present_in_dst.contains(dim_index)) { - return FusionDecision("Unsupported broadcast"); + return FusionDecision::Forbid("Unsupported broadcast"); } continue; } @@ -818,7 +821,8 @@ DimOrderMapOrError GetPropagatedDimOrders(const HloInstruction& hlo, return (user->opcode() == HloOpcode::kConcatenate || user->opcode() == HloOpcode::kDynamicSlice); })) { - return "No fusion into concatenations or dynamic slice."; + return FusionDecision::Forbid( + "No fusion into concatenations or dynamic slice."); } if (hlo.opcode() == HloOpcode::kParameter || hlo_query::IsScalarConstant(&hlo)) { @@ -830,7 +834,7 @@ DimOrderMapOrError GetPropagatedDimOrders(const HloInstruction& hlo, properties); } else if (hlo.opcode() == HloOpcode::kBroadcast) { if (direction != TransformDirection::kOutputToInput) { - return "Unsupported broadcast direction."; + return FusionDecision::Forbid("Unsupported broadcast direction."); } return GetPropagatedDimOrdersForDimAlteringOp(hlo, direction, src_dim_order, properties); @@ -838,7 +842,7 @@ DimOrderMapOrError GetPropagatedDimOrders(const HloInstruction& hlo, // Pad ops are only supported when they are generated as part of the split-k // transform of dot fusions. if (direction != TransformDirection::kOutputToInput) { - return "Unsupported pad direction."; + return FusionDecision::Forbid("Unsupported pad direction."); } return GetPropagatedDimOrdersForDimAlteringOp(hlo, direction, src_dim_order, properties); @@ -852,7 +856,7 @@ DimOrderMapOrError GetPropagatedDimOrders(const HloInstruction& hlo, } else if (hlo.opcode() == HloOpcode::kSlice) { // TODO(b/316637896) Add support for slices in softmax. if (direction != TransformDirection::kOutputToInput) { - return "Unsupported slice direction."; + return FusionDecision::Forbid("Unsupported slice direction."); } return GetPropagatedDimOrdersForDimAlteringOp(hlo, direction, src_dim_order, @@ -870,7 +874,7 @@ DimOrderMapOrError GetPropagatedDimOrders(const HloInstruction& hlo, properties); } else if (hlo.opcode() == HloOpcode::kReshape) { if (!ShapeUtil::ReshapeIsBitcast(hlo.operand(0)->shape(), hlo.shape())) { - return "Non-bitcast reshape."; + return FusionDecision::Forbid("Non-bitcast reshape."); } return GetPropagatedDimOrdersForBitcast(hlo, direction, src_dim_order, properties); @@ -885,15 +889,16 @@ DimOrderMapOrError GetPropagatedDimOrders(const HloInstruction& hlo, if (noncontracting_dim_fragment_order_it != src_dim_fragments_orders.end()) { if (noncontracting_dim_fragment_order_it->second.size() > 1) { - return "Concatenations on split non-contracting dimensions are " - "unsupported."; + return FusionDecision::Forbid( + "Concatenations on split non-contracting dimensions are " + "unsupported."); } } auto dim = LogicalIndexOfLabeledDimension(hlo.shape(), src_dim_order, noncontracting_dim_label); if (!dim.has_value() || dim.value() != hlo.concatenate_dimension()) { - return "Unsupported concatenation."; + return FusionDecision::Forbid("Unsupported concatenation."); } if (absl::c_any_of(hlo.operands(), [&hlo](const HloInstruction* operand) { // In the current simple implementation of concatenation the size of @@ -907,13 +912,13 @@ DimOrderMapOrError GetPropagatedDimOrders(const HloInstruction& hlo, kMinConcatFragmentSize != 0; })) { - return FusionDecision( + return FusionDecision::Forbid( "At least one operand of concatenation can not be perfectly tiled."); } return GetPropagatedDimOrdersForDimAlteringOp(hlo, direction, src_dim_order, properties); } - return "Unimplemented instruction."; + return FusionDecision::Forbid("Unimplemented instruction."); } // Difference of input and output data volumes of an instruction. @@ -966,9 +971,9 @@ FusionDecision IsConversionWorthFusing(const HloInstruction& input, // output fusion - then it should be fused here anyway! if (ShapeUtil::ByteSizeOf(input.operand(0)->shape()) > ShapeUtil::ByteSizeOf(input.shape())) { - return "Narrowing conversion."; + return FusionDecision::Forbid("Narrowing conversion."); } - return FusionDecision{}; + return FusionDecision::Allow(); } } // namespace @@ -1004,16 +1009,16 @@ GetPropagatedDimOrdersAndRequirementsIfProfitablyFusible( if (hlo.opcode() == HloOpcode::kTuple || hlo.opcode() == HloOpcode::kGetTupleElement) { - return "Unsupported instruction."; + return FusionDecision::Forbid("Unsupported instruction."); } if (hlo.opcode() == HloOpcode::kReduce || hlo.opcode() == HloOpcode::kAllReduce || hlo.opcode() == HloOpcode::kAllReduceStart || hlo.opcode() == HloOpcode::kAllReduceDone) { - return "Reductions are not fused yet."; + return FusionDecision::Forbid("Reductions are not fused yet."); } if (hlo.opcode() == HloOpcode::kPad) { - return "Pads are not fused yet."; + return FusionDecision::Forbid("Pads are not fused yet."); } if (auto decision = legacy_triton::IsTritonSupportedInstruction(hlo, gpu_version); @@ -1042,7 +1047,7 @@ GetPropagatedDimOrdersAndRequirementsIfProfitablyFusible( return decision; } } else if (hlo.IsElementwise() && hlo.opcode() != HloOpcode::kCopy) { - return "Ignored elementwise operation"; + return FusionDecision::Forbid("Ignored elementwise operation"); } } else { // Exception for binary elementwise operations: in most cases these are @@ -1068,12 +1073,14 @@ GetPropagatedDimOrdersAndRequirementsIfProfitablyFusible( } } if (!accepted && !IsInputWorthFusing(hlo)) { - return "Not obviously profitable to fuse as input."; + return FusionDecision::Forbid( + "Not obviously profitable to fuse as input."); } } } else { if (fusion_level < 2) { - return "Skipping fusing outputs at low fusion levels."; + return FusionDecision::Forbid( + "Skipping fusing outputs at low fusion levels."); } for (int i = 0; i < hlo.operand_count(); ++i) { const HloInstruction* operand = hlo.operand(i); @@ -1088,10 +1095,12 @@ GetPropagatedDimOrdersAndRequirementsIfProfitablyFusible( operand->opcode() == HloOpcode::kParameter) { continue; } - return "Has multiple inputs - not properly analyzed yet."; + return FusionDecision::Forbid( + "Has multiple inputs - not properly analyzed yet."); } if (!IsOutputWorthFusing(hlo)) { - return "Not obviously profitable to fuse as output."; + return FusionDecision::Forbid( + "Not obviously profitable to fuse as output."); } } return dim_orders_and_requirements; diff --git a/third_party/xla/xla/service/instruction_fusion.cc b/third_party/xla/xla/service/instruction_fusion.cc index e96d811f4ba460..9fb259061e5de6 100644 --- a/third_party/xla/xla/service/instruction_fusion.cc +++ b/third_party/xla/xla/service/instruction_fusion.cc @@ -614,7 +614,7 @@ absl::StatusOr InstructionFusion::Run( << use_regular_fusion.Explain(); } - FusionDecision use_mof; + FusionDecision use_mof = FusionDecision::Allow(); if (!use_regular_fusion) { use_mof = ShouldFuseIntoMultiOutput(instruction, i); if (use_mof) { @@ -911,8 +911,9 @@ bool IsSafeToFuseSliceIntoDusFusion(const HloInstruction* producer, for (int i = 0; i < consumer->operand_count(); ++i) { if (i != operand_number && consumer->operand(operand_number) == consumer->operand(i)) { - return "The consumer is an in-place operation that has an additional " - "operand that has the same value as the in-place buffer"; + return FusionDecision::Forbid( + "The consumer is an in-place operation that has an additional " + "operand that has the same value as the in-place buffer"); } } if (consumer->operand(operand_number) == producer || @@ -949,12 +950,13 @@ bool IsSafeToFuseSliceIntoDusFusion(const HloInstruction* producer, return is_nonelementwise_op(inst); }); if (producer_nonelementwise_ops.size() > 1) { - return "Producer fusion has multiple non-elementwise ops, bailing."; + return FusionDecision::Forbid( + "Producer fusion has multiple non-elementwise ops, bailing."); } // If the producer has only elementwise ops or bitcasts, we can fuse. if (producer_nonelementwise_ops.empty()) { if (consumer->opcode() != HloOpcode::kFusion) { - return {}; + return FusionDecision::Allow(); } // If the consumer fusion has both elementwise and non-elementwise ops, // and ops of the two groups access the same buffer of the producer, we @@ -980,9 +982,10 @@ bool IsSafeToFuseSliceIntoDusFusion(const HloInstruction* producer, instr); }); return inplace_conflict_after_fusion - ? "Non-elementwise ops in consumer lead to inplace conflict " - "after fusion." - : FusionDecision(); + ? FusionDecision::Forbid( + "Non-elementwise ops in consumer lead to inplace " + "conflict after fusion.") + : FusionDecision::Allow(); } auto dus_ops = ExtractInstructions(consumer, HloOpcode::kDynamicUpdateSlice); @@ -991,17 +994,18 @@ bool IsSafeToFuseSliceIntoDusFusion(const HloInstruction* producer, // TODO(akuegel): Are there other ops than dynamic update slice where we // have a special emitter if it can be done in-place? if (dus_ops.empty()) { - return {}; + return FusionDecision::Allow(); } if (dus_ops.size() > 1) { - return "multiple dus ops, bailing."; + return FusionDecision::Forbid("multiple dus ops, bailing."); } auto dus = dus_ops[0]; auto producer_nonelementwise = producer_nonelementwise_ops[0]; if (producer_nonelementwise->opcode() == HloOpcode::kSlice) { if (producer_nonelementwise->shape() != dus->operand(1)->shape()) { - return "Slice op has a different shape than the update shape of the " - "dus op, bailing."; + return FusionDecision::Forbid( + "Slice op has a different shape than the update shape of the " + "dus op, bailing."); } for (int i = 0; i < dus->shape().rank(); ++i) { const HloInstruction* dus_operand = @@ -1010,21 +1014,23 @@ bool IsSafeToFuseSliceIntoDusFusion(const HloInstruction* producer, if (!constant_operand || *constant_operand != producer_nonelementwise->slice_starts(i) || producer_nonelementwise->slice_strides(i) != 1) { - return "DUS and slice index mismatch"; + return FusionDecision::Forbid("DUS and slice index mismatch"); } } VLOG(4) << "DUS and slice index match"; if (consumer->opcode() == HloOpcode::kFusion && !IsSafeToFuseSliceIntoDusFusion(producer, consumer, dus)) { - return "Fusing slice into DUS will also fuse another non-elementwise " - "op with shared operand as DUS."; + return FusionDecision::Forbid( + "Fusing slice into DUS will also fuse another non-elementwise " + "op with shared operand as DUS."); } - return {}; + return FusionDecision::Allow(); } if (producer_nonelementwise->opcode() == HloOpcode::kDynamicSlice) { if (producer_nonelementwise->shape() != dus->operand(1)->shape()) { - return "Dynamic slice op has a different shape than the update shape " - "of the dus op, bailing."; + return FusionDecision::Forbid( + "Dynamic slice op has a different shape than the update shape " + "of the dus op, bailing."); } for (int i = 0; i < dus->shape().rank(); ++i) { const HloInstruction* ds_operand = get_real_operand( @@ -1035,21 +1041,23 @@ bool IsSafeToFuseSliceIntoDusFusion(const HloInstruction* producer, auto constant_dus_operand = get_constant_operand(dus_operand); if (constant_ds_operand != constant_dus_operand || (!constant_ds_operand && ds_operand != dus_operand)) { - return "DUS and DS index mismatch"; + return FusionDecision::Forbid("DUS and DS index mismatch"); } } VLOG(4) << "DUS and DS index match"; if (consumer->opcode() == HloOpcode::kFusion && !IsSafeToFuseSliceIntoDusFusion(producer, consumer, dus)) { - return "Fusing DS into DUS will also fuse another non-elementwise op " - "with shared operand as DUS."; + return FusionDecision::Forbid( + "Fusing DS into DUS will also fuse another non-elementwise op " + "with shared operand as DUS."); } - return {}; + return FusionDecision::Allow(); } - return "unrecognized inplace update non-elementwise output pair"; + return FusionDecision::Forbid( + "unrecognized inplace update non-elementwise output pair"); } } - return {}; + return FusionDecision::Allow(); } FusionDecision InstructionFusion::ShouldFuse(HloInstruction* consumer, @@ -1065,15 +1073,17 @@ FusionDecision InstructionFusion::ShouldFuse( // Don't fuse across a root instruction. if (producer == producer->parent()->root_instruction()) { - return "not fusing into the output of the root instruction"; + return FusionDecision::Forbid( + "not fusing into the output of the root instruction"); } // Cost condition: don't duplicate expensive instructions. if (FusionWouldDuplicate(*producer, *consumer) && (!may_duplicate_ || is_expensive_(*producer)) && !IsAlwaysDuplicable(*producer)) { - return may_duplicate_ ? "expensive producer would be duplicated" - : "fusion pass cannot duplicate"; + return FusionDecision::Forbid(may_duplicate_ + ? "expensive producer would be duplicated" + : "fusion pass cannot duplicate"); } return inplace_op_fusion_decider(producer, consumer); } diff --git a/third_party/xla/xla/service/instruction_fusion.h b/third_party/xla/xla/service/instruction_fusion.h index c4952aea15c2ae..facd01b4dc2b89 100644 --- a/third_party/xla/xla/service/instruction_fusion.h +++ b/third_party/xla/xla/service/instruction_fusion.h @@ -46,16 +46,11 @@ namespace xla { // explain the reason. class FusionDecision { public: - // Can not be fused: explain why. Implicit conversion due to optional-like - // semantics: waiver granted in cl/419938611. - FusionDecision(absl::string_view explanation) // NOLINT - : explanation_(explanation) {} - - // Same constructor as string_view, to allow implicit string conversion (can't - // implicitly convert both char* to string_view and string_view to - // FusionDecision). - FusionDecision(const char* explanation) // NOLINT - : explanation_(explanation) {} + static FusionDecision Allow() { return FusionDecision(); } + static FusionDecision Forbid(absl::string_view explanation) { + return FusionDecision(explanation); + } + FusionDecision(const FusionDecision& decision) = default; // If condition is `true` means that we CAN fuse. In that case, explanation is // discarded. @@ -74,9 +69,6 @@ class FusionDecision { absl::SourceLocation source_location = absl::SourceLocation::current()); #endif // PLATFORM_GOOGLE - // Can be fused. - FusionDecision() = default; - // Returns whether it can be fused. explicit operator bool() const { return CanFuse(); } @@ -88,9 +80,10 @@ class FusionDecision { // them is false to show why fusion wasn't performed. FusionDecision Or(const FusionDecision& decision) const { if (CanFuse() || decision.CanFuse()) { - return {}; + return Allow(); } - return {absl::StrCat(explanation_.value_or(""), " ; ", decision.Explain())}; + return Forbid( + absl::StrCat(explanation_.value_or(""), " ; ", decision.Explain())); } // Connects two fusion decision with a conjunction. Unlike disjunction, @@ -109,12 +102,12 @@ class FusionDecision { // Appends to explanation, or turns the decision negative. FusionDecision operator<<(absl::string_view explanation) const { - return {absl::StrCat(explanation_.value_or(""), explanation)}; + return Forbid(absl::StrCat(explanation_.value_or(""), explanation)); } // Appends to explanation, or turns the decision negative. FusionDecision operator<<(int64_t explanation) const { - return {absl::StrCat(explanation_.value_or(""), explanation)}; + return Forbid(absl::StrCat(explanation_.value_or(""), explanation)); } // Explains why the fusion could not be performed. @@ -123,6 +116,14 @@ class FusionDecision { private: // Empty IFF fusion is possible (explanation provided for negative cases). std::optional explanation_; + + FusionDecision() = default; + + explicit FusionDecision(absl::string_view explanation) + : explanation_(explanation) {} + + explicit FusionDecision(const char* explanation) + : explanation_(explanation) {} }; #define RETURN_IF_NOT_FUSIBLE(...) \ @@ -213,7 +214,8 @@ class InstructionFusion : public HloModulePass { // duplicated by multi-output fusion. virtual FusionDecision ShouldFuseIntoMultiOutput(HloInstruction* consumer, int64_t operand_index) { - return "multi-output fusion not supported by this pass"; + return FusionDecision::Forbid( + "multi-output fusion not supported by this pass"); } // Chooses a fusion kind for `producer` and `consumer`. diff --git a/third_party/xla/xla/service/instruction_fusion_test.cc b/third_party/xla/xla/service/instruction_fusion_test.cc index db6c3244c3932f..0fe695f8a11caa 100644 --- a/third_party/xla/xla/service/instruction_fusion_test.cc +++ b/third_party/xla/xla/service/instruction_fusion_test.cc @@ -791,20 +791,20 @@ TEST_F(InstructionFusionTest, DontFuseProducerIfInplaceConflict) { class FusionDecisionTest : public HloTestBase {}; TEST_F(FusionDecisionTest, NotFusionPossibleDisjunction) { - FusionDecision a = {}; - FusionDecision b = "not possible"; + FusionDecision a = FusionDecision::Allow(); + FusionDecision b = FusionDecision::Forbid("not possible"); EXPECT_TRUE(!a || !b); - a = "not possible"; - b = {}; + a = FusionDecision::Forbid("not possible"); + b = FusionDecision::Allow(); EXPECT_TRUE(!a || !b); - a = "impossible"; - b = "very impossible"; + a = FusionDecision::Forbid("impossible"); + b = FusionDecision::Forbid("very impossible"); EXPECT_TRUE(!a || !b); - a = {}; - b = {}; + a = FusionDecision::Allow(); + b = FusionDecision::Allow(); EXPECT_FALSE(!a || !b); } From 1cb72f045d69dd6259c9036ee92eb94d9fbbb83c Mon Sep 17 00:00:00 2001 From: TJ Xu Date: Fri, 20 Sep 2024 05:59:03 -0700 Subject: [PATCH 064/483] PR #16779: [NVIDIA] Added an option in hlo verifier options to skip checking duplicate channel ids Imported from GitHub PR https://github.com/openxla/xla/pull/16779 This is to address this issue(https://github.com/openxla/xla/issues/14600). In gpu pipeline, the uniqueness of channel id is not used anywhere, only the presence of it is used for determining replica/partition properties and for device assignment. This pr introduces an option in hlo verifier to skip checking for duplicate channel ids. The duplication check is disabled in all gpu pipelines. Copybara import of the project: -- 8c0ca67bfab296de11d20be638d563a9c035b855 by TJ Xu : Added an option in hlo verifier options to skip checking duplicate channel ids -- 731e663f362b4521b459b77a24007bfc1d0ed23a by TJ Xu : rename to verify_unique_channel_ids -- df66fd75685974b15795f747cd8d098ab4670f9c by TJ Xu : introduce a debug flag to control verifier check -- 93aac1f7f28c8a1572b079d8e488b5d0bcabdc07 by TJ Xu : Changed default to false Merging this change closes #16779 PiperOrigin-RevId: 676815577 --- third_party/xla/xla/debug_options_flags.cc | 6 +++ .../xla/xla/service/gpu/fusion_pipeline.cc | 11 +++--- .../xla/xla/service/gpu/gpu_compiler.cc | 37 ++++++++++++------- .../prepare_hlo_for_ir_emitting_pipeline.cc | 11 +++--- third_party/xla/xla/service/hlo_verifier.cc | 22 +++++++---- third_party/xla/xla/service/hlo_verifier.h | 3 ++ .../xla/xla/service/hlo_verifier_test.cc | 20 ++++++++++ .../xla/xla/tests/collective_ops_e2e_test.cc | 32 ++++++++++++++++ third_party/xla/xla/xla.proto | 6 ++- 9 files changed, 115 insertions(+), 33 deletions(-) diff --git a/third_party/xla/xla/debug_options_flags.cc b/third_party/xla/xla/debug_options_flags.cc index 3af22151ea6d6c..281f936fd772bb 100644 --- a/third_party/xla/xla/debug_options_flags.cc +++ b/third_party/xla/xla/debug_options_flags.cc @@ -292,6 +292,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_gpu_executable_warn_stuck_timeout_seconds(10); opts.set_xla_gpu_executable_terminate_timeout_seconds(30); opts.set_xla_gpu_experimental_disable_binary_libraries(false); + opts.set_xla_experimental_ignore_channel_id(false); return opts; } @@ -1945,6 +1946,11 @@ void MakeDebugOptionsFlags(std::vector* flag_list, debug_options->xla_gpu_experimental_disable_binary_libraries(), "Disable XLA GPU passes that depend on non-open source binary " "libraries")); + flag_list->push_back(tsl::Flag( + "xla_experimental_ignore_channel_id", + bool_setter_for(&DebugOptions::set_xla_experimental_ignore_channel_id), + debug_options->xla_experimental_ignore_channel_id(), + "Experimental: ignore channel ids for collective operations.")); } // NOLINT(readability/fn_size) // Allocates flag_values and flag_objects; this function must not be called more diff --git a/third_party/xla/xla/service/gpu/fusion_pipeline.cc b/third_party/xla/xla/service/gpu/fusion_pipeline.cc index 7bc4f170980ea1..e27865c06c63d1 100644 --- a/third_party/xla/xla/service/gpu/fusion_pipeline.cc +++ b/third_party/xla/xla/service/gpu/fusion_pipeline.cc @@ -50,12 +50,13 @@ HloPassPipeline FusionPipeline( // We try to split variadic ops with many parameters into several such ops // to avoid exceeding the parameter space. fusion.AddPass(); + HloVerifierOpts opts = + HloVerifierOpts().MakeLayoutSensitive().WithInstructionCanChangeLayout( + LayoutAssignment::InstructionCanChangeLayout); + opts.verify_unique_channel_ids = + !debug_options.xla_experimental_ignore_channel_id(); fusion.AddInvariantCheckerDebug( - std::make_unique( - HloVerifierOpts() - .MakeLayoutSensitive() - .WithInstructionCanChangeLayout( - LayoutAssignment::InstructionCanChangeLayout)), + std::make_unique(std::move(opts)), "hlo verifier (debug)"); if (debug_options.xla_gpu_enable_priority_fusion()) { diff --git a/third_party/xla/xla/service/gpu/gpu_compiler.cc b/third_party/xla/xla/service/gpu/gpu_compiler.cc index 4dee1d3f4fd57a..0db21cdf7173d2 100644 --- a/third_party/xla/xla/service/gpu/gpu_compiler.cc +++ b/third_party/xla/xla/service/gpu/gpu_compiler.cc @@ -473,8 +473,10 @@ GpuCompiler::GpuCompiler(se::Platform::Id platform_id, namespace { // Adds the HloVerifier for GPU to the given pipeline. -void AddHloVerifier(HloPassPipeline* pipeline, HloVerifierOpts&& opts = {}, - bool debug_only = false) { +void AddHloVerifier(HloPassPipeline* pipeline, + bool verify_unique_channel_ids = false, + HloVerifierOpts&& opts = {}, bool debug_only = false) { + opts.verify_unique_channel_ids = verify_unique_channel_ids; std::unique_ptr verifier_metadata = std::make_unique(std::move(opts)); if (debug_only) { @@ -647,7 +649,8 @@ absl::Status RunOptimizationPasses( const DebugOptions& debug_options = hlo_module->config().debug_options(); HloPassPipeline pipeline("optimization"); - AddHloVerifier(&pipeline); + AddHloVerifier(&pipeline, + !debug_options.xla_experimental_ignore_channel_id()); if (debug_options.xla_gpu_multi_streamed_windowed_einsum()) { pipeline.AddPass(); } @@ -767,7 +770,9 @@ absl::Status RunOptimizationPasses( // point. [&, &pipeline = pipeline.AddPass>("simplification")] { - AddHloVerifier(&pipeline, HloVerifierOpts{}, /*debug_only=*/true); + AddHloVerifier(&pipeline, + !debug_options.xla_experimental_ignore_channel_id(), + HloVerifierOpts{}, /*debug_only=*/true); // BatchNormExpander can create zero-sized ops, so zero-sized HLO // elimination has to come after that pass. @@ -1538,7 +1543,7 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment( } HloPassPipeline pipeline("post-layout_assignment"); - AddHloVerifier(&pipeline, + AddHloVerifier(&pipeline, !debug_options.xla_experimental_ignore_channel_id(), HloVerifierOpts{} .MakeLayoutSensitive() .WithInstructionCanChangeLayout( @@ -1603,14 +1608,16 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment( #ifdef NDEBUG // Verify the module in non-debug builds. For debug builds, the verifier // already runs after every pass. + HloVerifierOpts opts = HloVerifierOpts{} + .MakeLayoutSensitive() + .WithInstructionCanChangeLayout( + LayoutAssignment::InstructionCanChangeLayout) + .VerifyBroadcastDimensionsOrder() + .VerifyReshapeIsBitcast(); + opts.verify_unique_channel_ids = + !debug_options.xla_experimental_ignore_channel_id(); pipeline.AddPass( - std::make_unique( - HloVerifierOpts{} - .MakeLayoutSensitive() - .WithInstructionCanChangeLayout( - LayoutAssignment::InstructionCanChangeLayout) - .VerifyBroadcastDimensionsOrder() - .VerifyReshapeIsBitcast()), + std::make_unique(std::move(opts)), "end-of-post-layout_assignment"); #endif // NDEBUG @@ -2574,8 +2581,10 @@ absl::Status GpuCompiler::RunPostSchedulingPipelines( } if (module->config().debug_options().xla_gpu_enable_pgle_accuracy_checker()) { - AddHloVerifier(&main_pipeline, - HloVerifierOpts{}.VerifyInstructionNameUnchanged()); + AddHloVerifier( + &main_pipeline, + module->config().debug_options().xla_experimental_ignore_channel_id(), + HloVerifierOpts{}.VerifyInstructionNameUnchanged()); } return main_pipeline.Run(module).status(); } diff --git a/third_party/xla/xla/service/gpu/prepare_hlo_for_ir_emitting_pipeline.cc b/third_party/xla/xla/service/gpu/prepare_hlo_for_ir_emitting_pipeline.cc index 8704737b711e25..50c627981ce7c5 100644 --- a/third_party/xla/xla/service/gpu/prepare_hlo_for_ir_emitting_pipeline.cc +++ b/third_party/xla/xla/service/gpu/prepare_hlo_for_ir_emitting_pipeline.cc @@ -48,12 +48,13 @@ HloPassPipeline PrepareHloModuleForIrEmittingPipeline( // (b/27180329). Therefore, in that case, we set the output to be a copy of // the parameter. HloPassPipeline pipeline("GPU-ir-emit-prepare"); + HloVerifierOpts opts = + HloVerifierOpts{}.MakeLayoutSensitive().WithInstructionCanChangeLayout( + LayoutAssignment::InstructionCanChangeLayout); + opts.verify_unique_channel_ids = + !debug_options.xla_experimental_ignore_channel_id(); std::unique_ptr verifier_metadata = - std::make_unique( - HloVerifierOpts{} - .MakeLayoutSensitive() - .WithInstructionCanChangeLayout( - LayoutAssignment::InstructionCanChangeLayout)); + std::make_unique(std::move(opts)); pipeline.AddInvariantCheckerDebug(std::move(verifier_metadata), "hlo verifier (debug)"); diff --git a/third_party/xla/xla/service/hlo_verifier.cc b/third_party/xla/xla/service/hlo_verifier.cc index 22592935ec498f..8c71ad693bfb33 100644 --- a/third_party/xla/xla/service/hlo_verifier.cc +++ b/third_party/xla/xla/service/hlo_verifier.cc @@ -2414,7 +2414,8 @@ absl::Status VerifyLayoutConstrainedAllReduce(const HloModule& module) { // Checks various invariants of channel instructions (send/recv and // collectives). -absl::Status VerifyChannels(const HloModule& module) { +absl::Status VerifyChannels(const HloModule& module, + const HloVerifierOpts& opts) { absl::flat_hash_map> channel_instructions; @@ -2533,9 +2534,11 @@ absl::Status VerifyChannels(const HloModule& module) { } else { opcode_to_count[instr->opcode()] = 1; } - TF_RET_CHECK(DynCast(instr) != nullptr) - << "channel " << pair.first - << " is used for different types of channel instructions"; + if (opts.verify_unique_channel_ids) { + TF_RET_CHECK(DynCast(instr) != nullptr) + << "channel " << pair.first + << " is used for different types of channel instructions"; + } } int count = opcode_to_count.begin()->second; @@ -2571,9 +2574,11 @@ absl::Status VerifyChannels(const HloModule& module) { } } else { for (const HloInstruction* instr : instructions) { - TF_RET_CHECK(first->opcode() == instr->opcode()) - << "channel " << pair.first - << " is used for different types of channel instructions"; + if (opts.verify_unique_channel_ids) { + TF_RET_CHECK(first->opcode() == instr->opcode()) + << "channel " << pair.first + << " is used for different types of channel instructions"; + } } } } @@ -3082,7 +3087,8 @@ absl::StatusOr HloVerifier::Run( TF_RETURN_IF_ERROR(VerifyHloStructure(module)); TF_RETURN_IF_ERROR(VerifyAsynchronousInstructionPairs(*module)); - TF_RETURN_IF_ERROR(VerifyChannels(*module)); + TF_RETURN_IF_ERROR( + VerifyChannels(*module, target_metadata_->GetVerifierOpts())); TF_RETURN_IF_ERROR(VerifyInstructionNameUnchanged( *module, target_metadata_->GetVerifierOpts())); diff --git a/third_party/xla/xla/service/hlo_verifier.h b/third_party/xla/xla/service/hlo_verifier.h index 83f5e32a6def05..588b24ec7d1bb9 100644 --- a/third_party/xla/xla/service/hlo_verifier.h +++ b/third_party/xla/xla/service/hlo_verifier.h @@ -149,6 +149,9 @@ struct HloVerifierOpts { // cloned (".clone" suffix) or rematted (".remat"); bool verify_instruction_name_unchanged = false; + // Check if channel instructions all have unique channel ids. + bool verify_unique_channel_ids = true; + HloPredicate instruction_can_change_layout; // Returns a target-specific shape size. diff --git a/third_party/xla/xla/service/hlo_verifier_test.cc b/third_party/xla/xla/service/hlo_verifier_test.cc index 07c47980640b71..6c649a8c0ff004 100644 --- a/third_party/xla/xla/service/hlo_verifier_test.cc +++ b/third_party/xla/xla/service/hlo_verifier_test.cc @@ -3408,5 +3408,25 @@ TEST_F(HloVerifierTestLayoutSensitive, HasSubstr("Instruction has mismatched minor-to-major size and " "dimension size: ")); } + +TEST_F(HloVerifierTest, NoErrorOnDuplicateChannelId) { + const char* const hlo_string = R"( + HloModule m + + ENTRY main { + data_param = f32[2048,2048]{1,0} parameter(0) + cp1 = f32[2048,2048]{1,0} collective-permute(data_param), source_target_pairs={{0,1},{1,2},{2,3}}, channel_id=1 + cp2 = f32[2048,2048]{1,0} collective-permute(data_param), source_target_pairs={{0,1}}, channel_id=1 + + ROOT tuple = (f32[2048,2048]{1,0}, f32[2048,2048]{1,0}) tuple(cp1, cp2) + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnUnverifiedModule(hlo_string)); + HloVerifierOpts opts{}; + opts.verify_unique_channel_ids = false; + HloVerifier verifier(std::move(opts)); + ASSERT_IS_OK(verifier.Run(module.get()).status()); +} + } // namespace } // namespace xla diff --git a/third_party/xla/xla/tests/collective_ops_e2e_test.cc b/third_party/xla/xla/tests/collective_ops_e2e_test.cc index 7f30b975ebfe02..4fcfdfd3391745 100644 --- a/third_party/xla/xla/tests/collective_ops_e2e_test.cc +++ b/third_party/xla/xla/tests/collective_ops_e2e_test.cc @@ -1279,5 +1279,37 @@ ENTRY entry { LiteralTestUtil::ExpectR1Equal({8., 8.}, results[1]); } +TEST_F(CollectiveOpsTestE2E, NoErrorOnDuplicateChannelId) { + absl::string_view kModuleReplicatedStr = R"( +HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(f32[4,32,128]{2,1,0})->(f32[4,32,128]{2,1,0}, f32[4,32,128]{2,1,0})}, num_partitions=4 +ENTRY entry { + param = f32[4,32,128]{2,1,0} parameter(0) + all-to-all = f32[4,32,128]{2,1,0} all-to-all(param), channel_id=1, replica_groups={{0,1,2,3}}, dimensions={1} + all-to-all.1 = f32[4,32,128]{2,1,0} all-to-all(param), channel_id=1, replica_groups={{0,1,2,3}}, dimensions={0} + ROOT tuple = (f32[4,32,128]{2,1,0}, f32[4,32,128]{2,1,0}) tuple(all-to-all, all-to-all.1) +} +)"; + + const int64_t kNumReplicas = 1; + SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas); + const int64_t kNumPartitions = 4; + + HloModuleConfig config = + GetModuleConfigForTest(/*replica_count=*/kNumReplicas); + + auto opts = GetDebugOptionsForTest(); + opts.set_xla_experimental_ignore_channel_id(true); + config.set_debug_options(opts); + + config.set_num_partitions(kNumPartitions); + TF_ASSERT_OK_AND_ASSIGN( + auto module, ParseAndReturnVerifiedModule(kModuleReplicatedStr, config)); + + TF_ASSERT_OK_AND_ASSIGN(auto executable, + CreateExecutable(std::move(module), + /*run_hlo_passes=*/true)); + EXPECT_TRUE(executable->has_module()); +} + } // namespace } // namespace xla diff --git a/third_party/xla/xla/xla.proto b/third_party/xla/xla/xla.proto index d5384f52d52e8d..74eaf1166e459e 100644 --- a/third_party/xla/xla/xla.proto +++ b/third_party/xla/xla/xla.proto @@ -974,7 +974,11 @@ message DebugOptions { int32 xla_gpu_executable_warn_stuck_timeout_seconds = 327; int32 xla_gpu_executable_terminate_timeout_seconds = 328; - // Next id: 330 + // Whether to ignore channel ids(including verifier channel id checks) + // for collectives in the given HLO. + bool xla_experimental_ignore_channel_id = 330; + + // Next id: 331 // Extra options to pass to the compilation backend (e.g. LLVM); specific // interpretation of these values is left to the backend. From dc3b02a9df6310e12b77266e75f07d3adf1812ef Mon Sep 17 00:00:00 2001 From: Adrian Kuegel Date: Fri, 20 Sep 2024 06:17:38 -0700 Subject: [PATCH 065/483] Also handle xla::gpu::LoopOp when computing range annotations. PiperOrigin-RevId: 676820437 --- .../gpu/fusions/transforms/simplify_affine.cc | 11 +++++ .../transforms/tests/simplify_arith.mlir | 44 +++++++++++++++++++ 2 files changed, 55 insertions(+) diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/simplify_affine.cc b/third_party/xla/xla/service/gpu/fusions/transforms/simplify_affine.cc index acbd9d3735ea46..9df80ffa1922db 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/simplify_affine.cc +++ b/third_party/xla/xla/service/gpu/fusions/transforms/simplify_affine.cc @@ -357,6 +357,17 @@ std::optional GetIVRange(mlir::Value iv) { return {{lb.getSExtValue(), ub.getSExtValue() - 1}}; } } + if (auto loop_op = mlir::dyn_cast(parent)) { + const auto& indexing_map = loop_op.getIndexingMap(); + if (bbarg.getArgNumber() >= loop_op.getNumInductionVars() && + bbarg.getArgNumber() < + loop_op.getNumInductionVars() + indexing_map.GetNumResults()) { + RangeEvaluator range_evaluator = indexing_map.GetRangeEvaluator(); + return range_evaluator.ComputeExpressionRange( + indexing_map.GetAffineMap().getResult(bbarg.getArgNumber() - + loop_op.getNumInductionVars())); + } + } return std::nullopt; } diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/tests/simplify_arith.mlir b/third_party/xla/xla/service/gpu/fusions/transforms/tests/simplify_arith.mlir index 5f776d0f338862..aaeb665815dcc5 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/tests/simplify_arith.mlir +++ b/third_party/xla/xla/service/gpu/fusions/transforms/tests/simplify_arith.mlir @@ -291,3 +291,47 @@ func.func @refine_constraints_for_symbol(%arg0: tensor<2400000x9xf32>, } // CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0, d1, d2, d3) -> (d2 * 32768 + (d0 * 4 + d1 * 512 + d3) floordiv 9), // CHECK-LABEL: func.func @refine_constraints_for_symbol + +// ----- + +#map = #xla_gpu.indexing_map<(d0, d1, d2, d3, d4, d5)[s0] -> ((d0 * 4 + s0) floordiv 6, (d0 * 4 + s0) mod 6), domain: d0 in [0, 29], d1 in [0, 0], d2 in [0, 0], d3 in [0, 0], d4 in [0, 0], d5 in [0, 0], s0 in [0, 3], d0 * 4 + s0 in [0, 29], is_simplified: false> +func.func @dus(%arg0: tensor<20x30xf32>, %arg1: tensor<5x6xf32>, %arg2: i32, %arg3: i32, %arg4: tensor<20x30xf32>) -> tensor<20x30xf32> { + %c24 = arith.constant 24 : index + %c15 = arith.constant 15 : index + %c0 = arith.constant 0 : index + %thread_id_x = gpu.thread_id x + %thread_id_y = gpu.thread_id y + %thread_id_z = gpu.thread_id z + %block_id_x = gpu.block_id x + %block_id_y = gpu.block_id y + %block_id_z = gpu.block_id z + %0 = arith.index_cast %arg2 : i32 to index + %1 = arith.minsi %0, %c15 : index + %2 = arith.maxsi %1, %c0 : index + %3 = arith.index_cast %arg3 : i32 to index + %4 = arith.minsi %3, %c24 : index + %5 = arith.maxsi %4, %c0 : index + %xla_loop = xla_gpu.loop (%thread_id_x, %thread_id_y, %thread_id_z, %block_id_x, %block_id_y, %block_id_z)[%i] -> (%ra, %rb) in #map iter_args(%iter = %arg4) -> (tensor<20x30xf32>) { + %6 = arith.addi %2, %ra : index + %7 = arith.addi %5, %rb : index + %extracted = tensor.extract %arg1[%ra, %rb] : tensor<5x6xf32> + %inserted = tensor.insert %extracted into %iter[%6, %7] : tensor<20x30xf32> + xla_gpu.yield %inserted : tensor<20x30xf32> + } + return %xla_loop : tensor<20x30xf32> +} + +// CHECK-LABEL: func.func @dus +// CHECK: arith.minsi +// CHECK-SAME: xla.range = [-9223372036854775808 : index, 15 : index] +// CHECK: arith.maxsi +// CHECK-SAME: xla.range = [0 : index, 15 : index] +// CHECK: arith.minsi +// CHECK-SAME: xla.range = [-9223372036854775808 : index, 24 : index] +// CHECK: arith.maxsi +// CHECK-SAME: xla.range = [0 : index, 24 : index] +// CHECK: xla_gpu.loop +// CHECK: arith.addi +// CHECK-SAME: xla.range = [0 : index, 19 : index] +// CHECK: arith.addi +// CHECK-SAME: xla.range = [0 : index, 29 : index] From e38d5b891db15f5b1414c61279e86f7a78a63244 Mon Sep 17 00:00:00 2001 From: Alan Kelly Date: Fri, 20 Sep 2024 06:21:24 -0700 Subject: [PATCH 066/483] Remove calls to XNNPack from maximum_minimum kernel PiperOrigin-RevId: 676821213 --- tensorflow/lite/kernels/BUILD | 11 ---- tensorflow/lite/kernels/maximum_minimum.cc | 61 ---------------------- 2 files changed, 72 deletions(-) diff --git a/tensorflow/lite/kernels/BUILD b/tensorflow/lite/kernels/BUILD index 003cd5c5b9e968..4c08175c3abe93 100644 --- a/tensorflow/lite/kernels/BUILD +++ b/tensorflow/lite/kernels/BUILD @@ -852,17 +852,6 @@ BUILTIN_KERNEL_DEPS = [ ":eigen_support", "//tensorflow/lite/kernels/internal:optimized_eigen", ], -}) + select({ - "//tensorflow/lite:tflite_with_xnnpack_explicit_false": [], - "//conditions:default": [ - "@pthreadpool", - ], -}) + select({ - "//tensorflow/lite:tflite_with_xnnpack_explicit_false": [], - "//tensorflow/lite:tflite_kernel_use_xnnpack_false": [], - "//conditions:default": [ - "@XNNPACK", - ], }) + select({ # This select must match the similar select in `copts` "//tensorflow:linux_ppc64le": [], diff --git a/tensorflow/lite/kernels/maximum_minimum.cc b/tensorflow/lite/kernels/maximum_minimum.cc index 3a56d171474645..08e6d991d6cd2c 100644 --- a/tensorflow/lite/kernels/maximum_minimum.cc +++ b/tensorflow/lite/kernels/maximum_minimum.cc @@ -26,16 +26,6 @@ limitations under the License. #include "tensorflow/lite/kernels/internal/types.h" #include "tensorflow/lite/kernels/kernel_util.h" -#ifdef TFLITE_KERNEL_USE_XNNPACK -#include -#include -#include - -#include "xnnpack.h" // from @XNNPACK -#include "tensorflow/lite/kernels/cpu_backend_context.h" -#include "tensorflow/lite/minimal_logging.h" -#endif // TFLITE_KERNEL_USE_XNNPACK - namespace tflite { namespace ops { namespace builtin { @@ -175,57 +165,6 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { switch (op_context.output->type) { case kTfLiteFloat32: { -#ifdef TFLITE_KERNEL_USE_XNNPACK - size_t num_input1_dims = static_cast( - GetTensorShape(op_context.input1).DimensionsCount()); - size_t num_input2_dims = static_cast( - GetTensorShape(op_context.input2).DimensionsCount()); - if (std::max(num_input1_dims, num_input2_dims) < XNN_MAX_TENSOR_DIMS) { - std::array input1_shape; - std::array input2_shape; - for (size_t i = 0; i < num_input1_dims; ++i) { - input1_shape[i] = GetTensorShape(op_context.input1).Dims(i); - } - for (size_t i = 0; i < num_input2_dims; ++i) { - input2_shape[i] = GetTensorShape(op_context.input2).Dims(i); - } - CpuBackendContext* cpu_backend_context = - CpuBackendContext::GetFromContext(context); - pthreadpool_t threadpool = - cpu_backend_context->get_xnnpack_threadpool(); - enum xnn_status status = xnn_status_invalid_parameter; - if (std::is_same::value) { - status = xnn_run_maximum_nd_f32( - num_input1_dims, input1_shape.data(), num_input2_dims, - input2_shape.data(), GetTensorData(op_context.input1), - GetTensorData(op_context.input2), - GetTensorData(op_context.output), - /*flags=*/XNN_FLAG_YIELD_WORKERS, threadpool); - if (status != xnn_status_success) { - TFLITE_LOG(TFLITE_LOG_INFO, - "Failed to run xnn_run_maximum_nd_f32. Error code: %d", - status); - TFLiteOperation(context, node, - op_context); - } - } else if (std::is_same::value) { - status = xnn_run_minimum_nd_f32( - num_input1_dims, input1_shape.data(), num_input2_dims, - input2_shape.data(), GetTensorData(op_context.input1), - GetTensorData(op_context.input2), - GetTensorData(op_context.output), - /*flags=*/XNN_FLAG_YIELD_WORKERS, threadpool); - if (status != xnn_status_success) { - TFLITE_LOG(TFLITE_LOG_INFO, - "Failed to run xnn_run_minimum_nd_f32. Error code: %d", - status); - TFLiteOperation(context, node, - op_context); - } - } - break; - } -#endif TFLiteOperation(context, node, op_context); break; } From 900423a998f329b0dcce4467cabcd2dd45043912 Mon Sep 17 00:00:00 2001 From: Sandeep Dasgupta Date: Fri, 20 Sep 2024 08:21:23 -0700 Subject: [PATCH 067/483] [HLO Componentization] Create hlo/translate sub-component (Phase II). This CL takes care of 1. Migrating external projects dependencies from xla/translate --> xla/hlo/translate Phase I takes care of 1. Migrating xla/translate --> xla/hlo/translate 2. Setting up build aliases in xla/translate ensuring external dependencies are still satisfied. PiperOrigin-RevId: 676851420 --- tensorflow/compiler/mlir/BUILD | 4 ++-- tensorflow/compiler/mlir/lite/BUILD | 2 +- tensorflow/compiler/mlir/lite/python/BUILD | 2 +- tensorflow/compiler/mlir/lite/python/jax_to_tfl_flatbuffer.cc | 2 +- tensorflow/compiler/mlir/lite/tf_tfl_translate.cc | 2 +- tensorflow/compiler/mlir/tensorflow/BUILD | 2 +- .../compiler/mlir/tensorflow/utils/tf_xla_mlir_translate.cc | 2 +- tensorflow/compiler/mlir/tf2xla/internal/utils/BUILD | 2 +- .../mlir/tf2xla/internal/utils/test_metadata_config.cc | 2 +- tensorflow/compiler/mlir/tfrt/transforms/ifrt/BUILD | 2 +- tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo.cc | 2 +- tensorflow/compiler/tf2xla/BUILD | 4 ++-- tensorflow/compiler/tf2xla/xla_helpers.h | 2 +- third_party/xla/xla/mlir/framework/tests/BUILD | 2 +- third_party/xla/xla/python/BUILD | 2 +- third_party/xla/xla/python/mlir.cc | 2 +- third_party/xla/xla/service/cpu/BUILD | 4 ++-- third_party/xla/xla/service/cpu/cpu_compiler.cc | 2 +- third_party/xla/xla/service/gpu/BUILD | 4 ++-- third_party/xla/xla/service/gpu/fusions/triton/BUILD | 2 +- .../xla/service/gpu/fusions/triton/triton_fusion_emitter.cc | 2 +- third_party/xla/xla/tools/BUILD | 4 ++-- third_party/xla/xla/tools/run_hlo_module_main.cc | 4 ++-- 23 files changed, 29 insertions(+), 29 deletions(-) diff --git a/tensorflow/compiler/mlir/BUILD b/tensorflow/compiler/mlir/BUILD index eabb87e1e89913..ef068f28a999e6 100644 --- a/tensorflow/compiler/mlir/BUILD +++ b/tensorflow/compiler/mlir/BUILD @@ -240,8 +240,8 @@ tf_cc_binary( "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", "@llvm-project//mlir:TranslateLib", - "@local_xla//xla/translate/hlo_to_mhlo:translate_registration", - "@local_xla//xla/translate/mhlo_to_hlo:translate_registration", + "@local_xla//xla/hlo/translate/hlo_to_mhlo:translate_registration", + "@local_xla//xla/hlo/translate/mhlo_to_hlo:translate_registration", ], ) diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD index 88c47a36ff938f..5cbf2c1db15911 100644 --- a/tensorflow/compiler/mlir/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/BUILD @@ -1567,8 +1567,8 @@ tf_cc_binary( "@llvm-project//mlir:Parser", "@llvm-project//mlir:Pass", "@llvm-project//mlir:Support", + "@local_xla//xla/hlo/translate/hlo_to_mhlo:translate", "@local_xla//xla/mlir_hlo", - "@local_xla//xla/translate/hlo_to_mhlo:translate", "@stablehlo//:stablehlo_ops", ], ) diff --git a/tensorflow/compiler/mlir/lite/python/BUILD b/tensorflow/compiler/mlir/lite/python/BUILD index edddef0e7e992f..9ce36a7d8fdb1f 100644 --- a/tensorflow/compiler/mlir/lite/python/BUILD +++ b/tensorflow/compiler/mlir/lite/python/BUILD @@ -123,9 +123,9 @@ cc_library( "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", + "@local_xla//xla/hlo/translate/hlo_to_mhlo:hlo_to_mlir_hlo", "@local_xla//xla/service:hlo_parser", "@local_xla//xla/service:hlo_proto_cc", - "@local_xla//xla/translate/hlo_to_mhlo:hlo_to_mlir_hlo", ], ) diff --git a/tensorflow/compiler/mlir/lite/python/jax_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/python/jax_to_tfl_flatbuffer.cc index b9558bad138bd5..eba1ebc48ce445 100644 --- a/tensorflow/compiler/mlir/lite/python/jax_to_tfl_flatbuffer.cc +++ b/tensorflow/compiler/mlir/lite/python/jax_to_tfl_flatbuffer.cc @@ -40,9 +40,9 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/transforms/passes.h" #include "tensorflow/compiler/mlir/lite/types.pb.h" #include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_config.h" +#include "xla/hlo/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h" #include "xla/service/hlo.pb.h" #include "xla/service/hlo_parser.h" -#include "xla/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/graph_debug_info.pb.h" #include "tensorflow/core/framework/types.pb.h" diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc b/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc index d53c01d45a0e3c..54fdcd08e75ddf 100644 --- a/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc +++ b/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc @@ -53,8 +53,8 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" #include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_cl.h" +#include "xla/hlo/translate/hlo_to_mhlo/translate.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" -#include "xla/translate/hlo_to_mhlo/translate.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/platform/errors.h" diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD index 1b1674c42d5a58..5b3255fcd53c8f 100644 --- a/tensorflow/compiler/mlir/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/BUILD @@ -1036,9 +1036,9 @@ cc_library( "@llvm-project//mlir:Support", "@llvm-project//mlir:TranslateLib", "@local_xla//xla/hlo/ir:hlo", + "@local_xla//xla/hlo/translate/mhlo_to_hlo:type_to_shape", "@local_xla//xla/service:hlo_module_config", "@local_xla//xla/service:hlo_proto_cc", - "@local_xla//xla/translate/mhlo_to_hlo:type_to_shape", "@stablehlo//:stablehlo_ops", ], alwayslink = 1, diff --git a/tensorflow/compiler/mlir/tensorflow/utils/tf_xla_mlir_translate.cc b/tensorflow/compiler/mlir/tensorflow/utils/tf_xla_mlir_translate.cc index 5a29bae67afe01..c6fb08b6813c92 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/tf_xla_mlir_translate.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/tf_xla_mlir_translate.cc @@ -51,9 +51,9 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_argument.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/translate/mhlo_to_hlo/type_to_shape.h" #include "xla/service/hlo.pb.h" #include "xla/service/hlo_module_config.h" -#include "xla/translate/mhlo_to_hlo/type_to_shape.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/framework/types.pb.h" diff --git a/tensorflow/compiler/mlir/tf2xla/internal/utils/BUILD b/tensorflow/compiler/mlir/tf2xla/internal/utils/BUILD index 70f8d206840047..eeb10c1e5a0854 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/utils/BUILD +++ b/tensorflow/compiler/mlir/tf2xla/internal/utils/BUILD @@ -73,7 +73,7 @@ cc_library( "@llvm-project//mlir:IR", "@local_tsl//tsl/platform:errors", "@local_xla//xla:shape_util", + "@local_xla//xla/hlo/translate/mhlo_to_hlo:type_to_shape", "@local_xla//xla/mlir_hlo:hlo_dialect_registration", - "@local_xla//xla/translate/mhlo_to_hlo:type_to_shape", ], ) diff --git a/tensorflow/compiler/mlir/tf2xla/internal/utils/test_metadata_config.cc b/tensorflow/compiler/mlir/tf2xla/internal/utils/test_metadata_config.cc index fc11c2dab477cc..79928fd362664a 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/utils/test_metadata_config.cc +++ b/tensorflow/compiler/mlir/tf2xla/internal/utils/test_metadata_config.cc @@ -28,9 +28,9 @@ limitations under the License. #include "tensorflow/compiler/mlir/register_common_dialects.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" #include "tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h" +#include "xla/hlo/translate/mhlo_to_hlo/type_to_shape.h" #include "xla/mlir_hlo/mhlo/IR/register.h" #include "xla/shape.h" -#include "xla/translate/mhlo_to_hlo/type_to_shape.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h" #include "tsl/platform/errors.h" diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/BUILD b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/BUILD index d3896c65d63f21..90f6d19218f51b 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/BUILD +++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/BUILD @@ -156,11 +156,11 @@ cc_library( "@local_xla//xla:shape_util", "@local_xla//xla:xla_data_proto_cc", "@local_xla//xla/client:client_library", + "@local_xla//xla/hlo/translate/hlo_to_mhlo:hlo_to_mlir_hlo", "@local_xla//xla/python/ifrt", "@local_xla//xla/service:computation_placer_hdr", "@local_xla//xla/service/llvm_ir:llvm_util", "@local_xla//xla/stream_executor", - "@local_xla//xla/translate/hlo_to_mhlo:hlo_to_mlir_hlo", ], ) diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo.cc b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo.cc index 9156e41928acc7..16e5243d6c495b 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo.cc @@ -45,12 +45,12 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "xla/client/client_library.h" +#include "xla/hlo/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h" #include "xla/python/ifrt/client.h" #include "xla/service/computation_placer.h" #include "xla/service/llvm_ir/llvm_util.h" #include "xla/shape.h" #include "xla/stream_executor/platform_manager.h" -#include "xla/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h" #include "xla/xla_data.pb.h" #include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/framework/tensor_shape.h" diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index 91350d7d1b7184..da120ead9d7d62 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -540,9 +540,9 @@ cc_library( "@local_xla//xla/client:xla_builder", "@local_xla//xla/client:xla_computation", "@local_xla//xla/hlo/ir:hlo", + "@local_xla//xla/hlo/translate/mhlo_to_hlo:layout_util", "@local_xla//xla/service:computation_placer_hdr", "@local_xla//xla/service:hlo_proto_cc", - "@local_xla//xla/translate/mhlo_to_hlo:layout_util", ] + if_libtpu([ ":xla_tpu_backend_registration", ]) + if_static(["@local_tsl//tsl/platform:tensor_float_32_utils"]), @@ -715,11 +715,11 @@ cc_library( "@local_xla//xla/client/lib:arithmetic", "@local_xla//xla/client/lib:constants", "@local_xla//xla/hlo/ir:hlo", + "@local_xla//xla/hlo/translate/mhlo_to_hlo:layout_util", "@local_xla//xla/service:computation_placer_hdr", "@local_xla//xla/service/gpu:gpu_executable_run_options", "@local_xla//xla/service/gpu/runtime:nccl_clique_key", "@local_xla//xla/stream_executor", - "@local_xla//xla/translate/mhlo_to_hlo:layout_util", ], alwayslink = 1, ) diff --git a/tensorflow/compiler/tf2xla/xla_helpers.h b/tensorflow/compiler/tf2xla/xla_helpers.h index 9765ac8bc84db2..9e543f5e2e9d74 100644 --- a/tensorflow/compiler/tf2xla/xla_helpers.h +++ b/tensorflow/compiler/tf2xla/xla_helpers.h @@ -26,8 +26,8 @@ limitations under the License. #include "xla/client/xla_builder.h" #include "xla/executable_run_options.h" #include "xla/hlo/ir/hlo_sharding.h" +#include "xla/hlo/translate/mhlo_to_hlo/layout_util.h" #include "xla/service/computation_placer.h" -#include "xla/translate/mhlo_to_hlo/layout_util.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" diff --git a/third_party/xla/xla/mlir/framework/tests/BUILD b/third_party/xla/xla/mlir/framework/tests/BUILD index e0311ea4ac362d..f7278a1241eb99 100644 --- a/third_party/xla/xla/mlir/framework/tests/BUILD +++ b/third_party/xla/xla/mlir/framework/tests/BUILD @@ -17,7 +17,7 @@ lit_test_suite( ), cfg = "//xla:lit.cfg.py", tools = [ - "//xla/translate:xla-translate-opt", + "//xla/hlo/translate:xla-translate-opt", "@llvm-project//llvm:FileCheck", ], ) diff --git a/third_party/xla/xla/python/BUILD b/third_party/xla/xla/python/BUILD index 6f60b9242747b4..0de43b96eb19a9 100644 --- a/third_party/xla/xla/python/BUILD +++ b/third_party/xla/xla/python/BUILD @@ -957,6 +957,7 @@ cc_library( "@stablehlo//:stablehlo_ops", "@stablehlo//:stablehlo_serialization", "//xla/client:xla_computation", + "//xla/hlo/translate/hlo_to_mhlo:hlo_to_mlir_hlo", "//xla/mlir/utils:error_util", "//xla/mlir_hlo", "//xla/mlir_hlo:all_passes", @@ -964,7 +965,6 @@ cc_library( "//xla/pjrt:status_casters", "//xla/service/llvm_ir:llvm_util", "//xla/service/spmd/shardy/sdy_round_trip:pipelines", - "//xla/translate/hlo_to_mhlo:hlo_to_mlir_hlo", "//xla/tsl/framework/mlir:status_scoped_diagnostic_handler", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", diff --git a/third_party/xla/xla/python/mlir.cc b/third_party/xla/xla/python/mlir.cc index e4aab79317e9db..b14cd236ff18a8 100644 --- a/third_party/xla/xla/python/mlir.cc +++ b/third_party/xla/xla/python/mlir.cc @@ -41,6 +41,7 @@ limitations under the License. #include "stablehlo/dialect/Serialization.h" #include "stablehlo/dialect/StablehloOps.h" #include "xla/client/xla_computation.h" +#include "xla/hlo/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h" #include "xla/mlir/utils/error_util.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/mlir_hlo/mhlo/transforms/passes.h" @@ -49,7 +50,6 @@ limitations under the License. #include "xla/python/refine_polymorphic_shapes.h" #include "xla/service/llvm_ir/llvm_util.h" #include "xla/service/spmd/shardy/sdy_round_trip/pipelines.h" -#include "xla/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h" #include "xla/tsl/framework/mlir/status_scoped_diagnostic_handler.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" diff --git a/third_party/xla/xla/service/cpu/BUILD b/third_party/xla/xla/service/cpu/BUILD index d56ecda8559fd5..72f4d1cbb613ac 100644 --- a/third_party/xla/xla/service/cpu/BUILD +++ b/third_party/xla/xla/service/cpu/BUILD @@ -244,6 +244,8 @@ cc_library( "//xla/hlo/ir:hlo_module_group", "//xla/hlo/pass:hlo_pass", "//xla/hlo/pass:hlo_pass_pipeline", + "//xla/hlo/translate/hlo_to_mhlo:hlo_to_mlir_hlo", + "//xla/hlo/translate/hlo_to_mhlo:hlo_utils", "//xla/mlir_hlo", "//xla/mlir_hlo:all_passes", "//xla/mlir_hlo:mhlo_passes", @@ -336,8 +338,6 @@ cc_library( "//xla/service/spmd/shardy:shardy_xla_pass", "//xla/stream_executor", "//xla/stream_executor/host:host_platform_id", - "//xla/translate/hlo_to_mhlo:hlo_to_mlir_hlo", - "//xla/translate/hlo_to_mhlo:hlo_utils", "//xla/tsl/concurrency:async_value", "@com_google_absl//absl/base", "@com_google_absl//absl/base:core_headers", diff --git a/third_party/xla/xla/service/cpu/cpu_compiler.cc b/third_party/xla/xla/service/cpu/cpu_compiler.cc index ada4b7ba8dfae8..ee97d2afa02402 100644 --- a/third_party/xla/xla/service/cpu/cpu_compiler.cc +++ b/third_party/xla/xla/service/cpu/cpu_compiler.cc @@ -89,6 +89,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_schedule.h" #include "xla/hlo/pass/hlo_pass_fix.h" #include "xla/hlo/pass/hlo_pass_pipeline.h" +#include "xla/hlo/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h" #include "xla/literal.h" #include "xla/map_util.h" #include "xla/mlir_hlo/transforms/passes.h" @@ -197,7 +198,6 @@ limitations under the License. #include "xla/stream_executor/host/host_platform_id.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/stream_executor.h" -#include "xla/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h" #include "xla/tsl/concurrency/async_value.h" #include "xla/tsl/concurrency/async_value_ref.h" #include "xla/tsl/concurrency/chain.h" diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index 22beda5486c796..90dcee55408629 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -1409,6 +1409,8 @@ cc_library( "//xla/hlo/ir:hlo_module_group", "//xla/hlo/pass:hlo_pass", "//xla/hlo/pass:hlo_pass_pipeline", + "//xla/hlo/translate/hlo_to_mhlo:hlo_utils", + "//xla/hlo/translate/mhlo_to_hlo:location_exporter", "//xla/pjrt/distributed:key_value_store_interface", "//xla/service/gpu/autotuning:autotuner_util", "//xla/service/gpu/autotuning:custom_kernel_fusion_autotuner", @@ -1568,8 +1570,6 @@ cc_library( "//xla/stream_executor:dnn", "//xla/stream_executor:platform_manager", "//xla/stream_executor:semantic_version", - "//xla/translate/hlo_to_mhlo:hlo_utils", - "//xla/translate/mhlo_to_hlo:location_exporter", "//xla/tsl/lib/monitoring:counter", "//xla:autotune_results_proto_cc", "//xla:debug_options_flags", diff --git a/third_party/xla/xla/service/gpu/fusions/triton/BUILD b/third_party/xla/xla/service/gpu/fusions/triton/BUILD index 4f1d2f65e82360..029032b3c76268 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/BUILD +++ b/third_party/xla/xla/service/gpu/fusions/triton/BUILD @@ -46,6 +46,7 @@ cc_library( "//xla:xla_data_proto_cc", "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", + "//xla/hlo/translate/hlo_to_mhlo:hlo_function_importer", "//xla/hlo/utils:hlo_query", "//xla/mlir_hlo", "//xla/mlir_hlo:map_mhlo_to_scalar_op", @@ -75,7 +76,6 @@ cc_library( "//xla/stream_executor:device_description", "//xla/stream_executor:launch_dim", "//xla/tools:hlo_decomposer_lib", - "//xla/translate/hlo_to_mhlo:hlo_function_importer", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", diff --git a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter.cc b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter.cc index a6e6ac3e2181cc..e9974d3ce1584f 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter.cc +++ b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter.cc @@ -100,6 +100,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/translate/hlo_to_mhlo/hlo_function_importer.h" #include "xla/hlo/utils/hlo_query.h" #include "xla/layout_util.h" #include "xla/literal.h" @@ -136,7 +137,6 @@ limitations under the License. #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/launch_dim.h" #include "xla/tools/hlo_decomposer.h" -#include "xla/translate/hlo_to_mhlo/hlo_function_importer.h" #include "xla/util.h" #include "xla/xla.pb.h" #include "xla/xla_data.pb.h" diff --git a/third_party/xla/xla/tools/BUILD b/third_party/xla/xla/tools/BUILD index 3935af05d34a4c..ddfc7c005d1d51 100644 --- a/third_party/xla/xla/tools/BUILD +++ b/third_party/xla/xla/tools/BUILD @@ -622,13 +622,13 @@ xla_cc_binary( deps = [ ":run_hlo_module_lib", "//xla:debug_options_flags", + "//xla/hlo/translate/mhlo_to_hlo:translate", + "//xla/hlo/translate/stablehlo_to_hlo:translate", "//xla/service:cpu_plugin", "//xla/service:hlo_module_config", "//xla/service:hlo_runner", "//xla/service:interpreter_plugin", "//xla/service:platform_util", - "//xla/translate/mhlo_to_hlo:translate", - "//xla/translate/stablehlo_to_hlo:translate", "//xla/tsl/util:command_line_flags", "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", diff --git a/third_party/xla/xla/tools/run_hlo_module_main.cc b/third_party/xla/xla/tools/run_hlo_module_main.cc index 92e2e23efad6e0..17795415652bbb 100644 --- a/third_party/xla/xla/tools/run_hlo_module_main.cc +++ b/third_party/xla/xla/tools/run_hlo_module_main.cc @@ -30,12 +30,12 @@ limitations under the License. #include "llvm/Support/MemoryBuffer.h" #include "llvm/Support/ToolOutputFile.h" #include "xla/debug_options_flags.h" +#include "xla/hlo/translate/mhlo_to_hlo/translate.h" +#include "xla/hlo/translate/stablehlo_to_hlo/translate.h" #include "xla/service/hlo_module_config.h" #include "xla/service/hlo_runner.h" #include "xla/service/platform_util.h" #include "xla/tools/run_hlo_module.h" -#include "xla/translate/mhlo_to_hlo/translate.h" -#include "xla/translate/stablehlo_to_hlo/translate.h" #include "xla/tsl/util/command_line_flags.h" #include "tsl/platform/init_main.h" #include "tsl/platform/logging.h" From d540fd62d027dbb86e9434f8c43802997c8ab29c Mon Sep 17 00:00:00 2001 From: Benjamin Kramer Date: Fri, 20 Sep 2024 08:43:41 -0700 Subject: [PATCH 068/483] Integrate LLVM at llvm/llvm-project@42b696d7b994 Updates LLVM usage to match [42b696d7b994](https://github.com/llvm/llvm-project/commit/42b696d7b994) PiperOrigin-RevId: 676857806 --- third_party/llvm/workspace.bzl | 4 ++-- third_party/shardy/temporary.patch | 10 +++++----- third_party/shardy/workspace.bzl | 4 ++-- third_party/xla/third_party/shardy/temporary.patch | 10 +++++----- third_party/xla/third_party/shardy/workspace.bzl | 4 ++-- 5 files changed, 16 insertions(+), 16 deletions(-) diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl index c011aabc014eda..55290bf02ec4fd 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "94c024adedcb53059c29d7c2d62982053b60e86a" - LLVM_SHA256 = "204cedeaab86f065ef64cb3889dd2e92ddd4a8f5d5b6bc1cb4b276694fb6a798" + LLVM_COMMIT = "42b696d7b9942fdf07d65267da40ab178464adaa" + LLVM_SHA256 = "4f0d2053b381d3f074c64b2e460792cab11a02333f1c88bbc22b01686cf2fcb0" tf_http_archive( name = name, diff --git a/third_party/shardy/temporary.patch b/third_party/shardy/temporary.patch index c5aa30af88f875..613660a484cee2 100644 --- a/third_party/shardy/temporary.patch +++ b/third_party/shardy/temporary.patch @@ -1,15 +1,15 @@ diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl -index cd6a8b6..c011aab 100644 +index c011aab..55290bf 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" -- LLVM_COMMIT = "104f3c180644c8872eaad0b3fcf6a6b948d92a71" -- LLVM_SHA256 = "5caf03c6e40c87e7593ce50bfe53ec52a08677c221f4f611f30b3f40397505b8" -+ LLVM_COMMIT = "94c024adedcb53059c29d7c2d62982053b60e86a" -+ LLVM_SHA256 = "204cedeaab86f065ef64cb3889dd2e92ddd4a8f5d5b6bc1cb4b276694fb6a798" +- LLVM_COMMIT = "94c024adedcb53059c29d7c2d62982053b60e86a" +- LLVM_SHA256 = "204cedeaab86f065ef64cb3889dd2e92ddd4a8f5d5b6bc1cb4b276694fb6a798" ++ LLVM_COMMIT = "42b696d7b9942fdf07d65267da40ab178464adaa" ++ LLVM_SHA256 = "4f0d2053b381d3f074c64b2e460792cab11a02333f1c88bbc22b01686cf2fcb0" tf_http_archive( name = name, diff --git a/third_party/shardy/workspace.bzl b/third_party/shardy/workspace.bzl index f2425a6d6c98fe..b3863418126e1b 100644 --- a/third_party/shardy/workspace.bzl +++ b/third_party/shardy/workspace.bzl @@ -3,8 +3,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): - SHARDY_COMMIT = "a66667eefd65f73d50fab04298f477fc123b6740" - SHARDY_SHA256 = "543407a5fb203959d1189813275402dc5b8af6076203700ddea96a1dd8d981e1" + SHARDY_COMMIT = "f1ed2d559c09f969d34bc870a03b882d5a4ac813" + SHARDY_SHA256 = "73858df3e06c3afad362308baca03e5bf319c31f3227c7b990046389320303c0" tf_http_archive( name = "shardy", diff --git a/third_party/xla/third_party/shardy/temporary.patch b/third_party/xla/third_party/shardy/temporary.patch index c5aa30af88f875..613660a484cee2 100644 --- a/third_party/xla/third_party/shardy/temporary.patch +++ b/third_party/xla/third_party/shardy/temporary.patch @@ -1,15 +1,15 @@ diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl -index cd6a8b6..c011aab 100644 +index c011aab..55290bf 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" -- LLVM_COMMIT = "104f3c180644c8872eaad0b3fcf6a6b948d92a71" -- LLVM_SHA256 = "5caf03c6e40c87e7593ce50bfe53ec52a08677c221f4f611f30b3f40397505b8" -+ LLVM_COMMIT = "94c024adedcb53059c29d7c2d62982053b60e86a" -+ LLVM_SHA256 = "204cedeaab86f065ef64cb3889dd2e92ddd4a8f5d5b6bc1cb4b276694fb6a798" +- LLVM_COMMIT = "94c024adedcb53059c29d7c2d62982053b60e86a" +- LLVM_SHA256 = "204cedeaab86f065ef64cb3889dd2e92ddd4a8f5d5b6bc1cb4b276694fb6a798" ++ LLVM_COMMIT = "42b696d7b9942fdf07d65267da40ab178464adaa" ++ LLVM_SHA256 = "4f0d2053b381d3f074c64b2e460792cab11a02333f1c88bbc22b01686cf2fcb0" tf_http_archive( name = name, diff --git a/third_party/xla/third_party/shardy/workspace.bzl b/third_party/xla/third_party/shardy/workspace.bzl index f2425a6d6c98fe..b3863418126e1b 100644 --- a/third_party/xla/third_party/shardy/workspace.bzl +++ b/third_party/xla/third_party/shardy/workspace.bzl @@ -3,8 +3,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): - SHARDY_COMMIT = "a66667eefd65f73d50fab04298f477fc123b6740" - SHARDY_SHA256 = "543407a5fb203959d1189813275402dc5b8af6076203700ddea96a1dd8d981e1" + SHARDY_COMMIT = "f1ed2d559c09f969d34bc870a03b882d5a4ac813" + SHARDY_SHA256 = "73858df3e06c3afad362308baca03e5bf319c31f3227c7b990046389320303c0" tf_http_archive( name = "shardy", From c678702aab4849ff243f278a723b476e4a568d0a Mon Sep 17 00:00:00 2001 From: Hyeontaek Lim Date: Fri, 20 Sep 2024 09:28:55 -0700 Subject: [PATCH 069/483] [PjRt-IFRT] Migrate the include file for IFRT/XLA DType conversion functions This change updates the header include file from `pjrt_array.h` to `pjrt_dtype.h` for IFRT/XLA DType conversion functions. PiperOrigin-RevId: 676871039 --- third_party/xla/xla/python/BUILD | 3 ++- third_party/xla/xla/python/py_array.cc | 1 + third_party/xla/xla/python/py_compile_only_client.cc | 1 + third_party/xla/xla/python/py_values.cc | 2 +- third_party/xla/xla/python/types.cc | 2 +- 5 files changed, 6 insertions(+), 3 deletions(-) diff --git a/third_party/xla/xla/python/BUILD b/third_party/xla/xla/python/BUILD index 0de43b96eb19a9..1b73a0f2185af0 100644 --- a/third_party/xla/xla/python/BUILD +++ b/third_party/xla/xla/python/BUILD @@ -207,7 +207,7 @@ cc_library( "//xla:xla_data_proto_cc", "//xla/pjrt:exceptions", "//xla/python/ifrt", - "//xla/python/pjrt_ifrt", + "//xla/python/pjrt_ifrt:pjrt_dtype", "//xla/tsl/python/lib/core:numpy", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:inlined_vector", @@ -402,6 +402,7 @@ cc_library( "//xla/python/ifrt/hlo:hlo_program", "//xla/python/pjrt_ifrt", "//xla/python/pjrt_ifrt:pjrt_attribute_map_util", + "//xla/python/pjrt_ifrt:pjrt_dtype", "//xla/python/pjrt_ifrt:xla_host_callback_proto_cc", "//xla/python/pjrt_ifrt:xla_ifrt", "//xla/service:computation_placer_hdr", diff --git a/third_party/xla/xla/python/py_array.cc b/third_party/xla/xla/python/py_array.cc index 8a622f031d1adf..5e948abd9efcab 100644 --- a/third_party/xla/xla/python/py_array.cc +++ b/third_party/xla/xla/python/py_array.cc @@ -75,6 +75,7 @@ limitations under the License. #include "xla/python/pjrt_ifrt/pjrt_array.h" #include "xla/python/pjrt_ifrt/pjrt_client.h" #include "xla/python/pjrt_ifrt/pjrt_device.h" +#include "xla/python/pjrt_ifrt/pjrt_dtype.h" #include "xla/python/py_client.h" #include "xla/python/py_device.h" #include "xla/python/py_values.h" diff --git a/third_party/xla/xla/python/py_compile_only_client.cc b/third_party/xla/xla/python/py_compile_only_client.cc index cccb29b1ecd8d3..fa32529000f6de 100644 --- a/third_party/xla/xla/python/py_compile_only_client.cc +++ b/third_party/xla/xla/python/py_compile_only_client.cc @@ -65,6 +65,7 @@ limitations under the License. #include "xla/python/nb_class_ptr.h" #include "xla/python/pjrt_ifrt/pjrt_array.h" #include "xla/python/pjrt_ifrt/pjrt_attribute_map_util.h" +#include "xla/python/pjrt_ifrt/pjrt_dtype.h" #include "xla/python/pjrt_ifrt/pjrt_executable.h" #include "xla/python/pjrt_ifrt/pjrt_topology.h" #include "xla/python/pjrt_ifrt/xla_compiler.h" diff --git a/third_party/xla/xla/python/py_values.cc b/third_party/xla/xla/python/py_values.cc index e5d37d0ebdc838..db7a38d0a50363 100644 --- a/third_party/xla/xla/python/py_values.cc +++ b/third_party/xla/xla/python/py_values.cc @@ -46,7 +46,7 @@ limitations under the License. #include "xla/python/ifrt/sharding.h" #include "xla/python/nb_helpers.h" #include "xla/python/nb_numpy.h" -#include "xla/python/pjrt_ifrt/pjrt_array.h" +#include "xla/python/pjrt_ifrt/pjrt_dtype.h" #include "xla/python/py_array.h" #include "xla/python/python_ref_manager.h" #include "xla/python/sharding.h" diff --git a/third_party/xla/xla/python/types.cc b/third_party/xla/xla/python/types.cc index f3b8db6fdae018..eaad6db5f16667 100644 --- a/third_party/xla/xla/python/types.cc +++ b/third_party/xla/xla/python/types.cc @@ -41,7 +41,7 @@ limitations under the License. #include "xla/python/ifrt/dtype.h" #include "xla/python/nb_helpers.h" #include "xla/python/nb_numpy.h" -#include "xla/python/pjrt_ifrt/pjrt_array.h" +#include "xla/python/pjrt_ifrt/pjrt_dtype.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/status_macros.h" From ffe3f06206602b273fc7c374704df966578ccc66 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 20 Sep 2024 09:33:57 -0700 Subject: [PATCH 070/483] Reverts c772d4766c0ac1e5565016bee7264ab7958220d8 PiperOrigin-RevId: 676872588 --- third_party/xla/xla/tsl/cuda/BUILD.bazel | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/xla/tsl/cuda/BUILD.bazel b/third_party/xla/xla/tsl/cuda/BUILD.bazel index 002fe484ae6990..393cc5d03648cb 100644 --- a/third_party/xla/xla/tsl/cuda/BUILD.bazel +++ b/third_party/xla/xla/tsl/cuda/BUILD.bazel @@ -13,7 +13,6 @@ load( load( "//xla/tsl:tsl.bzl", "if_cuda_libs", - "if_nccl", ) load("//xla/tsl/cuda:stub.bzl", "cuda_stub") @@ -349,10 +348,11 @@ cc_library( deps = if_cuda_is_configured([ "@com_google_absl//absl/container:flat_hash_set", "@local_config_cuda//cuda:cuda_headers", + "@local_config_nccl//:nccl_headers", "@local_tsl//tsl/platform:dso_loader", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:load_library", - ]) + if_nccl(["@local_config_nccl//:nccl"]), + ]), ) alias( From da7422968abcc981a143d86ebe815c36bb25d817 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 20 Sep 2024 09:34:38 -0700 Subject: [PATCH 071/483] Include the graph ID in the host CUDA graph launch event, add cuda graph id stat in the derived Hlo Op event PiperOrigin-RevId: 676872797 --- tensorflow/core/profiler/utils/BUILD | 1 + .../core/profiler/utils/derived_timeline.cc | 57 +++++++++++++++--- .../core/profiler/utils/derived_timeline.h | 25 ++++++++ .../profiler/utils/derived_timeline_test.cc | 60 +++++++++++++++++++ .../core/profiler/utils/gpu_event_stats.cc | 3 + .../core/profiler/utils/gpu_event_stats.h | 3 +- .../backends/profiler/gpu/cupti_collector.cc | 5 ++ 7 files changed, 144 insertions(+), 10 deletions(-) diff --git a/tensorflow/core/profiler/utils/BUILD b/tensorflow/core/profiler/utils/BUILD index fcd5c26b7e7104..3a0164bee093c6 100644 --- a/tensorflow/core/profiler/utils/BUILD +++ b/tensorflow/core/profiler/utils/BUILD @@ -282,6 +282,7 @@ tf_cc_test( "//tensorflow/core:test_main", "//tensorflow/core/profiler/protobuf:xplane_proto_cc", "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest_main", "@local_tsl//tsl/profiler/utils:group_events", "@local_tsl//tsl/profiler/utils:tf_xplane_visitor", "@local_tsl//tsl/profiler/utils:xplane_schema", diff --git a/tensorflow/core/profiler/utils/derived_timeline.cc b/tensorflow/core/profiler/utils/derived_timeline.cc index a895af6a00b259..5cf4b280ec2961 100644 --- a/tensorflow/core/profiler/utils/derived_timeline.cc +++ b/tensorflow/core/profiler/utils/derived_timeline.cc @@ -140,6 +140,10 @@ DerivedXLineBuilder::DerivedXLineBuilder( int64_t timestamp_ns, std::vector dependent_lines) : group_id_stat_metadata_( plane->GetOrCreateStatMetadata(GetStatTypeStr(StatType::kGroupId))), + correlation_id_metadata_(plane->GetOrCreateStatMetadata( + GetStatTypeStr(StatType::kCorrelationId))), + cuda_graph_id_metadata_(plane->GetOrCreateStatMetadata( + GetStatTypeStr(StatType::kCudaGraphId))), line_(plane->GetOrCreateLine(line_id)), dependent_lines_(std::move(dependent_lines)) { line_.SetName(name); @@ -185,13 +189,31 @@ void DerivedXLineBuilder::ExpandOrAddLevelEvent( } } +void DerivedXLineBuilder::AddStatToLevelEvent(int level, + const XStatMetadata& metadata, + int64_t value) { + if (auto it = last_event_by_level_.find(level); + it != last_event_by_level_.end() && it->second.has_value()) { + it->second->SetOrAddStatValue(metadata, value); + } +} + +void DerivedXLineBuilder::AddStatToLevelEvent(int level, + const XStatMetadata& metadata, + uint64_t value) { + if (auto it = last_event_by_level_.find(level); + it != last_event_by_level_.end() && it->second.has_value()) { + it->second->SetOrAddStatValue(metadata, value); + } +} + // When deriving a bunch of events with the same timespan, there could be // indeterministic behavior of how trace viewer stacking these events. // This function will shrink the stack of events with the same timespan when -// necessary. Event at top of stack might shrink more than event at the bottom. -// Because the time unit in trace viewer is nanosecond, therefore the minimum -// difference is 1ns. However to prevent shrink induced inconsitency, we can -// not shrink more than the duration of event at the top of the stack. +// necessary. Event at top of stack might shrink more than event at the +// bottom. Because the time unit in trace viewer is nanosecond, therefore the +// minimum difference is 1ns. However to prevent shrink induced inconsitency, +// we can not shrink more than the duration of event at the top of the stack. void DerivedXLineBuilder::AdjustDurationForTraceViewer(int level) { if (level >= last_event_by_level_.size() || !last_event_by_level_[level]) return; @@ -286,8 +308,8 @@ void DeriveEventsFromAnnotations(const SymbolResolver& symbol_resolver, GetSortedEvents(plane_visitor)) { GpuEventStats stats(&event); // For HLO/TF op lines, only use kernel events (i.e. excluding memcpy or - // allocation events). Also CudaGraph executions are also treated as kernel - // events. + // allocation events). Also CudaGraph executions are also treated as + // kernel events. if (!stats.IsKernel() && !stats.IsCudaGraphExecution()) continue; tsl::profiler::Timespan event_span = event.GetTimespan(); @@ -300,9 +322,26 @@ void DeriveEventsFromAnnotations(const SymbolResolver& symbol_resolver, if (stats.IsXlaOp()) { auto symbol = symbol_resolver(stats.program_id, stats.hlo_module_name, stats.hlo_op_names.back()); - hlo_ops.ExpandOrAddEvents( - GetOrCreateHloOpEventsMetadata(plane_builder, stats, symbol), - event_span, stats.group_id); + auto hlo_events_metadata = + GetOrCreateHloOpEventsMetadata(plane_builder, stats, symbol); + hlo_ops.ExpandOrAddEvents(hlo_events_metadata, event_span, + stats.group_id); + // If the kernel event is nodes of a CudaGraph or a whole cuda graph + // exec, try to mark extra stats to to corresponding XLA op event here. + if (stats.cuda_graph_id_for_inner_node.has_value() && + *stats.cuda_graph_id_for_inner_node != 0) { + int level = static_cast(hlo_events_metadata.size()) - 1; + if (level >= 0) { + hlo_ops.AddStatToLevelEvent(level, *hlo_ops.GetCudaGraphIdMetadata(), + *stats.cuda_graph_id_for_inner_node); + if (stats.correlation_id.has_value()) { + hlo_ops.AddStatToLevelEvent(level, + *hlo_ops.GetCorrelationIdMetadata(), + *stats.correlation_id); + } + } + } + if (!symbol.tf_op_name.empty()) { ProcessTfOpEvent(symbol.tf_op_name, event_span, stats.group_id, plane_builder, diff --git a/tensorflow/core/profiler/utils/derived_timeline.h b/tensorflow/core/profiler/utils/derived_timeline.h index 72583c2a79d772..535f5041269a21 100644 --- a/tensorflow/core/profiler/utils/derived_timeline.h +++ b/tensorflow/core/profiler/utils/derived_timeline.h @@ -26,6 +26,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "tensorflow/core/profiler/protobuf/xplane.pb.h" #include "tensorflow/core/profiler/utils/xplane_builder.h" +#include "tsl/profiler/protobuf/xplane.pb.h" #include "tsl/profiler/utils/group_events.h" #include "tsl/profiler/utils/timespan.h" @@ -46,6 +47,11 @@ class DerivedXEventBuilder { event_.SetTimespan(event_span); } + template + void SetOrAddStatValue(const XStatMetadata& metadata, ValueT&& value) { + event_.SetOrAddStatValue(metadata, std::forward(value)); + } + private: XEventBuilder event_; std::optional group_id_; @@ -79,6 +85,22 @@ class DerivedXLineBuilder { // Reset the last events lower than or equal to the given level. void ResetLastEvents(int level = 0); + // To avoid using templates while need hide its implementation in .cc file, + // use two functions to set stat value for int64_t and uint64_t here. + void AddStatToLevelEvent(int level, const XStatMetadata& metadata, + int64_t value); + + void AddStatToLevelEvent(int level, const XStatMetadata& metadata, + uint64_t value); + + const XStatMetadata* GetCorrelationIdMetadata() const { + return correlation_id_metadata_; + } + + const XStatMetadata* GetCudaGraphIdMetadata() const { + return cuda_graph_id_metadata_; + } + private: // If the last event of the given level has the same metadata, expands it to // include the time until the given event's end time. @@ -92,6 +114,9 @@ class DerivedXLineBuilder { void AdjustDurationForTraceViewer(int level); const XStatMetadata* group_id_stat_metadata_ = nullptr; + const XStatMetadata* correlation_id_metadata_ = nullptr; + const XStatMetadata* cuda_graph_id_metadata_ = nullptr; + XLineBuilder line_; absl::flat_hash_map> last_event_by_level_; diff --git a/tensorflow/core/profiler/utils/derived_timeline_test.cc b/tensorflow/core/profiler/utils/derived_timeline_test.cc index f1f0daca282358..3bc01244a372ee 100644 --- a/tensorflow/core/profiler/utils/derived_timeline_test.cc +++ b/tensorflow/core/profiler/utils/derived_timeline_test.cc @@ -17,7 +17,9 @@ limitations under the License. #include #include +#include +#include #include "absl/strings/string_view.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -291,6 +293,64 @@ TEST(DerivedTimelineTest, TfOpNameScopeShrinkTest) { } } +// Checks that XLA Ops mapping to CudaGraph launch has extra stats. +TEST(DerivedTimelineTest, XloOpHasCudaGraphStats) { + constexpr absl::string_view kModuleName = "module"; + constexpr absl::string_view kHloOpName = "op_level_2"; + constexpr absl::string_view kKernelDetails = "kernel_details"; + constexpr int64_t kGroupIdValue = 1; + constexpr int64_t kCorrelationIdValue = 10000; + const uint64_t kCudaGraphIdValue = 20; + XSpace space; + tsl::profiler::GroupMetadataMap group_metadata_map; + + // Build Input Plane/Line/Events and derive events from them. + XPlane& plane = *GetOrCreateGpuXPlane(&space, /*device_ordinal=*/0); + XPlaneBuilder plane_builder(&plane); + auto line_builder = plane_builder.GetOrCreateLine(0); + CreateXEvent(&plane_builder, &line_builder, "op1", 0, 100, + {{StatType::kKernelDetails, kKernelDetails}, + {StatType::kGroupId, kGroupIdValue}, + {StatType::kHloModule, kModuleName}, + {StatType::kHloOp, kHloOpName}, + {StatType::kCorrelationId, kCorrelationIdValue}, + {StatType::kCudaGraphId, kCudaGraphIdValue}}); + CreateXEvent(&plane_builder, &line_builder, "op2", 200, 300, + {{StatType::kKernelDetails, kKernelDetails}, + {StatType::kGroupId, kGroupIdValue}, + {StatType::kHloModule, kModuleName}, + {StatType::kHloOp, kHloOpName}, + {StatType::kCorrelationId, kCorrelationIdValue}, + {StatType::kCudaGraphId, kCudaGraphIdValue}}); + GenerateDerivedTimeLines(group_metadata_map, &space); + + // Check that the HLO op line is added and has the extra stats for the first + // derived event. + size_t num_hlo_op_line = 0; + size_t num_events = 0; + std::optional correlation_id; + std::optional cuda_graph_id; + XPlaneVisitor plane_visitor = tsl::profiler::CreateTfXPlaneVisitor(&plane); + plane_visitor.ForEachLine([&](const XLineVisitor& line_visitor) { + if (line_visitor.Id() == kThreadIdHloOp) { + num_hlo_op_line++; + if (num_hlo_op_line == 1) { + num_events = line_visitor.NumEvents(); + line_visitor.ForEachEvent([&](const XEventVisitor& event_visitor) { + correlation_id = event_visitor.GetStat(StatType::kCorrelationId); + cuda_graph_id = event_visitor.GetStat(StatType::kCudaGraphId); + }); + } + } + }); + EXPECT_EQ(num_hlo_op_line, 1); + EXPECT_EQ(num_events, 1); + ASSERT_TRUE(correlation_id.has_value()); + EXPECT_EQ(correlation_id->IntValue(), kCorrelationIdValue); + ASSERT_TRUE(cuda_graph_id.has_value()); + EXPECT_EQ(cuda_graph_id->UintValue(), kCudaGraphIdValue); +} + TEST(DerivedTimelineTest, DeriveLinesForXlaCpuOps) { XPlane xplane; XPlaneBuilder plane_builder(&xplane); diff --git a/tensorflow/core/profiler/utils/gpu_event_stats.cc b/tensorflow/core/profiler/utils/gpu_event_stats.cc index c5b880c4498fe2..cd81aea0842dd8 100644 --- a/tensorflow/core/profiler/utils/gpu_event_stats.cc +++ b/tensorflow/core/profiler/utils/gpu_event_stats.cc @@ -70,6 +70,9 @@ GpuEventStats::GpuEventStats(const XEventVisitor* event) { case StatType::kCudaGraphExecId: cuda_graph_exec_id = stat.UintValue(); break; + case StatType::kCudaGraphId: + cuda_graph_id_for_inner_node = stat.UintValue(); + break; default: break; } diff --git a/tensorflow/core/profiler/utils/gpu_event_stats.h b/tensorflow/core/profiler/utils/gpu_event_stats.h index 8b9ac5ae75c62d..7740e41ce5ac9c 100644 --- a/tensorflow/core/profiler/utils/gpu_event_stats.h +++ b/tensorflow/core/profiler/utils/gpu_event_stats.h @@ -56,7 +56,8 @@ struct GpuEventStats { // Stats derived by grouping. std::optional group_id; bool is_eager = false; - std::optional cuda_graph_exec_id; + std::optional cuda_graph_exec_id; + std::optional cuda_graph_id_for_inner_node; }; // Stats for a host-side GPU launch XEvent. diff --git a/third_party/xla/xla/backends/profiler/gpu/cupti_collector.cc b/third_party/xla/xla/backends/profiler/gpu/cupti_collector.cc index 6191849b0d0944..937b85f9a6c8fa 100644 --- a/third_party/xla/xla/backends/profiler/gpu/cupti_collector.cc +++ b/third_party/xla/xla/backends/profiler/gpu/cupti_collector.cc @@ -157,6 +157,11 @@ class PerDeviceCollector { if (kernel_name.empty()) { kernel_name = GetTraceEventTypeName(event.type); } + // For CPU events like cuGraphLaunch(), add the graph id to the name. + if (event.graph_id != 0 && event.type == CuptiTracerEventType::CudaGraph && + event.source == CuptiTracerEventSource::DriverCallback) { + absl::StrAppend(&kernel_name, " (CudaGraph:", event.graph_id, ")"); + } XEventMetadata* event_metadata = plane->GetOrCreateEventMetadata(std::move(kernel_name)); XEventBuilder xevent = line->AddEvent(*event_metadata); From 55ca3b150c1913161aa9ca63d4e8fcae172444b3 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 20 Sep 2024 10:34:10 -0700 Subject: [PATCH 072/483] Add TF wheel API test. This test verifies whether the API v2 packages can be imported from the current build. It utilizes the `_api/v2/api_packages.txt` list of packages from the local wheel file specified in the `requirements_lock_.txt`. The test should be executed after the TF wheel was built and put into `dist` dir inside Tensorflow repository. PiperOrigin-RevId: 676893008 --- .bazelrc | 10 +- .gitignore | 1 + WORKSPACE | 5 + .../requirements_updater/requirements.in | 5 + ci/official/utilities/code_check_full.bats | 2 +- ci/official/wheel.sh | 1 + ci/official/wheel_test/BUILD | 5 - ci/official/wheel_test/README.md | 94 ------------------- ci/official/wheel_test/WORKSPACE | 65 ------------- ci/official/wheel_test/update_requirements.sh | 53 ----------- requirements_lock_3_10.txt | 39 ++++++++ requirements_lock_3_11.txt | 39 ++++++++ requirements_lock_3_12.txt | 39 ++++++++ requirements_lock_3_9.txt | 39 ++++++++ tensorflow/opensource_only.files | 1 - tensorflow/tensorflow.bzl | 7 ++ tensorflow/tools/pip_package/BUILD | 34 ++++++- tensorflow/tools/pip_package/MANIFEST.in | 2 +- .../pip_package/import_api_packages_test.py | 45 ++++++--- third_party/py/python_repo.bzl | 19 +--- third_party/xla/.bazelrc | 10 +- .../xla/third_party/py/python_repo.bzl | 19 +--- third_party/xla/third_party/tsl/.bazelrc | 10 +- .../tsl/third_party/py/python_repo.bzl | 19 +--- third_party/xla/xla/tsl/BUILD | 14 +++ 25 files changed, 275 insertions(+), 302 deletions(-) delete mode 100644 ci/official/wheel_test/BUILD delete mode 100644 ci/official/wheel_test/README.md delete mode 100644 ci/official/wheel_test/WORKSPACE delete mode 100644 ci/official/wheel_test/update_requirements.sh rename ci/official/wheel_test/test_import_api_packages.py => tensorflow/tools/pip_package/import_api_packages_test.py (66%) diff --git a/.bazelrc b/.bazelrc index 8201ce4582a00f..f5b06fd21d937e 100644 --- a/.bazelrc +++ b/.bazelrc @@ -740,27 +740,27 @@ build:linux_libtensorflow_build --config=cuda_wheel -- //tensorflow/tools/lib_pa test:linux_cpu_wheel_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_cpu_wheel_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_cpu_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium -test:linux_cpu_wheel_test --config=linux_cpu_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... +test:linux_cpu_wheel_test --@local_tsl//third_party/py:wheel_dependency=true --config=linux_cpu_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... # CUDA WHEEL test:linux_cuda_wheel_test_filters --test_tag_filters=gpu,requires-gpu,-no_gpu,-no_oss,-oss_excluded,-oss_serial,-benchmark-test,-no_cuda11,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_cuda_wheel_test_filters --build_tag_filters=gpu,requires-gpu,-no_gpu,-no_oss,-oss_excluded,-oss_serial,-benchmark-test,-no_cuda11,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_cuda_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium -test:linux_cuda_wheel_test --config=linux_cuda_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... +test:linux_cuda_wheel_test --@local_tsl//third_party/py:wheel_dependency=true --config=linux_cuda_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... # ARM64 WHEEL test:linux_arm64_wheel_test_filters --test_tag_filters=-no_oss,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_arm64_wheel_test_filters --build_tag_filters=-no_oss,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_arm64_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium -test:linux_arm64_wheel_test --config=linux_arm64_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/core/grappler/optimizers:auto_mixed_precision_test_cpu -//tensorflow/core/grappler/optimizers:remapper_test_cpu -//tensorflow/core/kernels/image:resize_bicubic_op_test +test:linux_arm64_wheel_test --@local_tsl//third_party/py:wheel_dependency=true --config=linux_arm64_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/core/grappler/optimizers:auto_mixed_precision_test_cpu -//tensorflow/core/grappler/optimizers:remapper_test_cpu -//tensorflow/core/kernels/image:resize_bicubic_op_test # MACOS ARM64 WHEEL test:macos_arm64_wheel_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 test:macos_arm64_wheel_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 test:macos_arm64_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium -test:macos_arm64_wheel_test --config=macos_arm64_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/compiler/aot/... +test:macos_arm64_wheel_test --@local_tsl//third_party/py:wheel_dependency=true --config=macos_arm64_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/compiler/aot/... # MACOS X86 WHEEL test:macos_x86_wheel_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test test:macos_x86_wheel_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test test:macos_x86_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium -test:macos_x86_wheel_test --config=macos_x86_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/compiler/aot/... +test:macos_x86_wheel_test --@local_tsl//third_party/py:wheel_dependency=true --config=macos_x86_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/compiler/aot/... # PYCPP TESTS run a suite of Python and C++ tests to verify general correctness over # the whole TF code base. These are usually run continuously or upon presubmit. diff --git a/.gitignore b/.gitignore index 614cde3446a16f..643ffca1c45c99 100644 --- a/.gitignore +++ b/.gitignore @@ -28,6 +28,7 @@ tensorflow/contrib/cmake/_build/ /api_init_files_list.txt /estimator_api_init_files_list.txt *.whl +dist # Android .gradle diff --git a/WORKSPACE b/WORKSPACE index 32ffd0433108c7..269256eadd32c1 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -32,6 +32,11 @@ load("@local_xla//third_party/py:python_init_repositories.bzl", "python_init_rep python_init_repositories( default_python_version = "system", + local_wheel_dist_folder = "dist", + local_wheel_inclusion_list = [ + "tensorflow*", + ], + local_wheel_workspaces = ["//:WORKSPACE"], requirements = { "3.9": "//:requirements_lock_3_9.txt", "3.10": "//:requirements_lock_3_10.txt", diff --git a/ci/official/requirements_updater/requirements.in b/ci/official/requirements_updater/requirements.in index 749d097c6997b1..305eaee3dce946 100644 --- a/ci/official/requirements_updater/requirements.in +++ b/ci/official/requirements_updater/requirements.in @@ -26,3 +26,8 @@ requests >= 2.31.0 packaging==23.2 setuptools==70.0.0 jax==0.4.7 +# The dependencies below are needed for TF wheel testing. +tensorflow-io-gcs-filesystem==0.37.1 +libclang >= 13.0.0 +google_pasta ~= 0.2 +flatbuffers ~= 24.3.25 diff --git a/ci/official/utilities/code_check_full.bats b/ci/official/utilities/code_check_full.bats index d414a88ecfad36..753b4955b6c3a4 100644 --- a/ci/official/utilities/code_check_full.bats +++ b/ci/official/utilities/code_check_full.bats @@ -316,7 +316,7 @@ EOF # See b/279852433 (internal). # TODO(b/279852433) Replace deps(//tensorflow/...) with deps(//...) @test "Verify that it's possible to query every TensorFlow target without BUILD errors" { - bazel query "deps(//tensorflow/...)" > /dev/null + bazel query "deps(//tensorflow/... -//tensorflow/tools/pip_package:import_api_packages_test_cpu -//tensorflow/tools/pip_package:import_api_packages_test_gpu)" > /dev/null } teardown_file() { diff --git a/ci/official/wheel.sh b/ci/official/wheel.sh index ec05db2716ce63..c2cdf8f2e5e95f 100755 --- a/ci/official/wheel.sh +++ b/ci/official/wheel.sh @@ -29,6 +29,7 @@ fi tfrun bazel build $TFCI_BAZEL_COMMON_ARGS --config=cuda_wheel //tensorflow/tools/pip_package:wheel $TFCI_BUILD_PIP_PACKAGE_ARGS tfrun find ./bazel-bin/tensorflow/tools/pip_package -iname "*.whl" -exec cp {} $TFCI_OUTPUT_DIR \; +tfrun ln -s $TFCI_OUTPUT_DIR dist tfrun ./ci/official/utilities/rename_and_verify_wheels.sh if [[ "$TFCI_ARTIFACT_STAGING_GCS_ENABLE" == 1 ]]; then diff --git a/ci/official/wheel_test/BUILD b/ci/official/wheel_test/BUILD deleted file mode 100644 index 3cca20e70a545b..00000000000000 --- a/ci/official/wheel_test/BUILD +++ /dev/null @@ -1,5 +0,0 @@ -py_test( - name = "test_import_api_packages", - srcs = ["test_import_api_packages.py"], - deps = ["@pypi_tensorflow//:pkg"], -) diff --git a/ci/official/wheel_test/README.md b/ci/official/wheel_test/README.md deleted file mode 100644 index cef4131e63f579..00000000000000 --- a/ci/official/wheel_test/README.md +++ /dev/null @@ -1,94 +0,0 @@ -## Wheel Test - -This directory is dedicated to tests that require a built TensorFlow wheel -file for testing, such as: - -* Ensuring the entire API is importable -* Testing downstream projects against the wheel - -Ensure you have Bazel installed and accessible from your command line. - -These tests use hermetic Python. They also require a built TensorFlow wheel file -and a requirements_lock file. The requirements_lock file is generated by the -[requirements_updater](https://github.com/tensorflow/tensorflow/tree/master/ci/official/requirements_updater) -tool using the path to this wheel file. - -### Hermetic Python - -For details about hermetic Python and setting its toolchain version, see -[requirements updater readme](https://github.com/tensorflow/tensorflow/blob/master/ci/official/requirements_updater/README.md) - -### Prerequisites for Local Testing - -To run tests locally, follow these steps: - -1. Navigate to the relevant directory: - ``` - cd ci/official/wheel_test - ``` -2. Run a script for creating requirements file: - ``` - bash update_requirements.sh - e.g.: - bash update_requirements.sh /tmp/tensorflow-2.14.0-cp311-cp311-linux_x86_64.whl 3_11 - ``` - -#### Requirements Updater Script -This script automates the process of updating TensorFlow requirements for a -specific Python version. - -##### Parameters -`path_to_tensorflow_wheel`: The local path to the TensorFlow wheel file. -Example: `/tmp/tensorflow-2.14.0-cp311-cp311-linux_x86_64.whl` - -`python_version`: The target Python version, replacing `.` with `_`. -Example: For Python 3.11, use `3_11` - -The script performs the following steps: - -1. Navigates to the `../requirements_updater` directory. -2. Creates a `requirements_wheel_test.in` file and specifies the path -to the actual TensorFlow wheel. -3. Creates a `requirements_lock_.txt` file. -4. Updates the `requirements_lock_.txt` file using -a Bazel command. -5. Moves the updated `requirements_lock_.txt` file -to the `../wheel_test/` directory. - - -### How it Works in the Presubmit Job - -`_requirements_lock` files will be generated by the presubmit job. A detailed -description will be provided once it's integrated into presubmit. - -### test_import_api_packages - -This Python test verifies whether the API v2 packages can be imported from the -current build. It utilizes the `_api/v2/api_packages.txt` list of packages from -the local wheel file specified in the `requirements_lock_.txt`. - -Packages are imported one by one in alphabetical order during runtime. - -The test doesn't identify package's order-dependent issues; for instance, -importing "tf.foo" followed by "tf.bar" won't reveal that "tf.bar" depends on -"tf.foo" being imported first. - -The `_api/v2/api_packages.txt` file is generated during the TensorFlow API v2 -init files creation process and is subsequently stored in the wheel file after -the build. It also contains a few paths that cannot be directly imported. These -paths point to attributes or sub-modules within a module's namespace, but they -don't correspond to an actual file or directory on the filesystem. The list of -such paths is stored in the packages_for_skip variable and will be skipped -during the test. - -##### How to Build - -``` -bazel build //:test_import_api_packages -``` - -##### How to Run - -``` -bazel test //:test_import_api_packages --test_output=all -``` diff --git a/ci/official/wheel_test/WORKSPACE b/ci/official/wheel_test/WORKSPACE deleted file mode 100644 index db46144dadbbb1..00000000000000 --- a/ci/official/wheel_test/WORKSPACE +++ /dev/null @@ -1,65 +0,0 @@ -# buildifier: disable=load-on-top - -workspace(name = "wheel_test") - -# buildifier: disable=load-on-top - -load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") - -http_archive( - name = "bazel_skylib", - sha256 = "74d544d96f4a5bb630d465ca8bbcfe231e3594e5aae57e1edbf17a6eb3ca2506", - urls = [ - "https://storage.googleapis.com/mirror.tensorflow.org/github.com/bazelbuild/bazel-skylib/releases/download/1.3.0/bazel-skylib-1.3.0.tar.gz", - "https://github.com/bazelbuild/bazel-skylib/releases/download/1.3.0/bazel-skylib-1.3.0.tar.gz", - ], -) - -http_archive( - name = "rules_python", - sha256 = "9d04041ac92a0985e344235f5d946f71ac543f1b1565f2cdbc9a2aaee8adf55b", - strip_prefix = "rules_python-0.26.0", - url = "https://github.com/bazelbuild/rules_python/releases/download/0.26.0/rules_python-0.26.0.tar.gz", -) - -# buildifier: disable=same-origin-load -load("@rules_python//python:repositories.bzl", "py_repositories") - -py_repositories() - -## Load HERMETIC_PYTHON_VERSION variable -local_repository( - name = "local_tensorflow", - path = "../../..", -) - -load( - "@local_tensorflow//tensorflow/tools/toolchains/python:python_repo.bzl", - "python_repository", -) - -python_repository(name = "python_version_repo") - -load("@python_version_repo//:py_version.bzl", "TF_PYTHON_VERSION") - -# Register multi toolchains -load("@rules_python//python:repositories.bzl", "python_register_toolchains") # buildifier: disable=same-origin-load - -python_register_toolchains( - name = "python", - ignore_root_user_error = True, - python_version = TF_PYTHON_VERSION, -) - -load("@python//:defs.bzl", "interpreter") -load("@rules_python//python:pip.bzl", "pip_parse") - -pip_parse( - name = "pypi", - python_interpreter_target = interpreter, - requirements = "//:requirements_lock_" + TF_PYTHON_VERSION.replace(".", "_") + ".txt", -) - -load("@pypi//:requirements.bzl", "install_deps") - -install_deps() diff --git a/ci/official/wheel_test/update_requirements.sh b/ci/official/wheel_test/update_requirements.sh deleted file mode 100644 index bed56273b48952..00000000000000 --- a/ci/official/wheel_test/update_requirements.sh +++ /dev/null @@ -1,53 +0,0 @@ -#!/bin/bash -# Copyright 2023 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -# script to run pip-compile for keras, tensorboard, estimator deps. -# if there is a change in requirements.in then all lock files will be updated -# accordingly. - -# -e: abort script if one command fails -# -u: error if undefined variable used -# -o pipefail: entire command fails if pipe fails. watch out for yes | ... -# -o history: record shell history -set -euo pipefail -o history - -# Check for required arguments -if [ -z "$1" ]; then - echo "Usage: $0 " - exit 1 -fi - -TENSORFLOW_WHEEL_PATH="$1" -PYTHON_VERSION="$2" - -# All commands run relative to this directory -cd "$(dirname "${BASH_SOURCE[0]}")" -cd ../requirements_updater || exit 1 - -# Create the requirements_wheel_test.in file -echo "tensorflow @ file://localhost/$TENSORFLOW_WHEEL_PATH" > requirements_wheel_test.in - -# Create the requirements_lock file -REQUIREMENTS_LOCK_FILE="requirements_lock_${PYTHON_VERSION}.txt" -touch "$REQUIREMENTS_LOCK_FILE" - -### Update the requirements_lock file -bazel run --experimental_convenience_symlinks=ignore --repo_env=REQUIREMENTS_FILE_NAME=requirements_wheel_test.in //:requirements_${PYTHON_VERSION}.update - -# Move the updated file to the appropriate directory -mv "$REQUIREMENTS_LOCK_FILE" ../wheel_test/ - -echo "All tasks completed successfully." diff --git a/requirements_lock_3_10.txt b/requirements_lock_3_10.txt index e058fe135f93f6..3530cd725b6d78 100644 --- a/requirements_lock_3_10.txt +++ b/requirements_lock_3_10.txt @@ -166,10 +166,19 @@ dm-tree==0.1.8 \ --hash=sha256:f7ac31b9aecccb2c6e1ab29706f6ded3eba0c2c69c770322c9c685929c3d6afb \ --hash=sha256:fa42a605d099ee7d41ba2b5fb75e21423951fd26e5d50583a00471238fb3021d # via keras-nightly +flatbuffers==24.3.25 \ + --hash=sha256:8dbdec58f935f3765e4f7f3cf635ac3a77f83568138d6a2311f524ec96364812 \ + --hash=sha256:de2ec5b203f21441716617f38443e0a8ebf3d25bf0d9c0bb0ce68fa00ad546a4 + # via -r ci/official/requirements_updater/requirements.in gast==0.4.0 \ --hash=sha256:40feb7b8b8434785585ab224d1568b857edb18297e5a3047f1ba012bc83b42c1 \ --hash=sha256:b7adcdd5adbebf1adf17378da5ba3f543684dbec47b1cda1f3997e573cd542c4 # via -r ci/official/requirements_updater/requirements.in +google-pasta==0.2.0 \ + --hash=sha256:4612951da876b1a10fe3960d7226f0c7682cf901e16ac06e473b267a5afa8954 \ + --hash=sha256:b32482794a366b5366a32c92a9a9201b107821889935a02b3e51f6b432ea84ed \ + --hash=sha256:c9f2c8dfc8f96d0d5808299920721be30c9eec37f2389f28904f454565c8a16e + # via -r ci/official/requirements_updater/requirements.in grpcio==1.64.1 \ --hash=sha256:03b43d0ccf99c557ec671c7dede64f023c7da9bb632ac65dbc57f166e4970040 \ --hash=sha256:0a12ddb1678ebc6a84ec6b0487feac020ee2b1659cbe69b80f06dbffdb249122 \ @@ -256,6 +265,18 @@ keras-nightly==3.0.4.dev2024021403 \ --hash=sha256:24ce69d29d582771685bf4235f59663723405b5a5b16f3eaff2657e52e74663a \ --hash=sha256:9f416e66b820ef833779d219d255b346b8b90a72fdbd0b2f1e90a43ad142a03d # via -r ci/official/requirements_updater/requirements.in +libclang==18.1.1 \ + --hash=sha256:0b2e143f0fac830156feb56f9231ff8338c20aecfe72b4ffe96f19e5a1dbb69a \ + --hash=sha256:3f0e1f49f04d3cd198985fea0511576b0aee16f9ff0e0f0cad7f9c57ec3c20e8 \ + --hash=sha256:4dd2d3b82fab35e2bf9ca717d7b63ac990a3519c7e312f19fa8e86dcc712f7fb \ + --hash=sha256:54dda940a4a0491a9d1532bf071ea3ef26e6dbaf03b5000ed94dd7174e8f9592 \ + --hash=sha256:69f8eb8f65c279e765ffd28aaa7e9e364c776c17618af8bff22a8df58677ff4f \ + --hash=sha256:6f14c3f194704e5d09769108f03185fce7acaf1d1ae4bbb2f30a72c2400cb7c5 \ + --hash=sha256:83ce5045d101b669ac38e6da8e58765f12da2d3aafb3b9b98d88b286a60964d8 \ + --hash=sha256:a1214966d08d73d971287fc3ead8dfaf82eb07fb197680d8b3859dbbbbf78250 \ + --hash=sha256:c533091d8a3bbf7460a00cb6c1a71da93bffe148f172c7d03b1c31fbf8aa2a0b \ + --hash=sha256:cf4a99b05376513717ab5d82a0db832c56ccea4fd61a69dbb7bccf2dfb207dbe + # via -r ci/official/requirements_updater/requirements.in lit==17.0.6 \ --hash=sha256:dfa9af9b55fc4509a56be7bf2346f079d7f4a242d583b9f2e0b078fd0abae31b # via -r ci/official/requirements_updater/requirements.in @@ -518,6 +539,24 @@ tensorboard-data-server==0.7.2 \ --hash=sha256:9fe5d24221b29625dbc7328b0436ca7fc1c23de4acf4d272f1180856e32f9f60 \ --hash=sha256:ef687163c24185ae9754ed5650eb5bc4d84ff257aabdc33f0cc6f74d8ba54530 # via tb-nightly +tensorflow-io-gcs-filesystem==0.37.1 \ + --hash=sha256:0df00891669390078a003cedbdd3b8e645c718b111917535fa1d7725e95cdb95 \ + --hash=sha256:249c12b830165841411ba71e08215d0e94277a49c551e6dd5d72aab54fe5491b \ + --hash=sha256:257aab23470a0796978efc9c2bcf8b0bc80f22e6298612a4c0a50d3f4e88060c \ + --hash=sha256:286389a203a5aee1a4fa2e53718c661091aa5fea797ff4fa6715ab8436b02e6c \ + --hash=sha256:32c50ab4e29a23c1f91cd0f9ab8c381a0ab10f45ef5c5252e94965916041737c \ + --hash=sha256:426de1173cb81fbd62becec2012fc00322a295326d90eb6c737fab636f182aed \ + --hash=sha256:6e1f2796b57e799a8ca1b75bf47c2aaa437c968408cc1a402a9862929e104cda \ + --hash=sha256:8943036bbf84e7a2be3705cb56f9c9df7c48c9e614bb941f0936c58e3ca89d6f \ + --hash=sha256:8febbfcc67c61e542a5ac1a98c7c20a91a5e1afc2e14b1ef0cb7c28bc3b6aa70 \ + --hash=sha256:9679b36e3a80921876f31685ab6f7270f3411a4cc51bc2847e80d0e4b5291e27 \ + --hash=sha256:b02f9c5f94fd62773954a04f69b68c4d576d076fd0db4ca25d5479f0fbfcdbad \ + --hash=sha256:ee5da49019670ed364f3e5fb86b46420841a6c3cb52a300553c63841671b3e6d \ + --hash=sha256:ee7c8ee5fe2fd8cb6392669ef16e71841133041fee8a330eff519ad9b36e4556 \ + --hash=sha256:fbb33f1745f218464a59cecd9a18e32ca927b0f4d77abd8f8671b645cc1a182f \ + --hash=sha256:fe8dcc6d222258a080ac3dfcaaaa347325ce36a7a046277f6b3e19abc1efb3c5 \ + --hash=sha256:ffebb6666a7bfc28005f4fbbb111a455b5e7d6cd3b12752b7050863ecb27d5cc + # via -r ci/official/requirements_updater/requirements.in termcolor==2.3.0 \ --hash=sha256:3afb05607b89aed0ffe25202399ee0867ad4d3cb4180d98aaf8eefa6a5f7d475 \ --hash=sha256:b5b08f68937f138fe92f6c089b99f1e2da0ae56c52b78bf7075fd95420fd9a5a diff --git a/requirements_lock_3_11.txt b/requirements_lock_3_11.txt index d57e2df7c8abf1..dced21fd467a60 100644 --- a/requirements_lock_3_11.txt +++ b/requirements_lock_3_11.txt @@ -166,10 +166,19 @@ dm-tree==0.1.8 \ --hash=sha256:f7ac31b9aecccb2c6e1ab29706f6ded3eba0c2c69c770322c9c685929c3d6afb \ --hash=sha256:fa42a605d099ee7d41ba2b5fb75e21423951fd26e5d50583a00471238fb3021d # via keras-nightly +flatbuffers==24.3.25 \ + --hash=sha256:8dbdec58f935f3765e4f7f3cf635ac3a77f83568138d6a2311f524ec96364812 \ + --hash=sha256:de2ec5b203f21441716617f38443e0a8ebf3d25bf0d9c0bb0ce68fa00ad546a4 + # via -r ci/official/requirements_updater/requirements.in gast==0.4.0 \ --hash=sha256:40feb7b8b8434785585ab224d1568b857edb18297e5a3047f1ba012bc83b42c1 \ --hash=sha256:b7adcdd5adbebf1adf17378da5ba3f543684dbec47b1cda1f3997e573cd542c4 # via -r ci/official/requirements_updater/requirements.in +google-pasta==0.2.0 \ + --hash=sha256:4612951da876b1a10fe3960d7226f0c7682cf901e16ac06e473b267a5afa8954 \ + --hash=sha256:b32482794a366b5366a32c92a9a9201b107821889935a02b3e51f6b432ea84ed \ + --hash=sha256:c9f2c8dfc8f96d0d5808299920721be30c9eec37f2389f28904f454565c8a16e + # via -r ci/official/requirements_updater/requirements.in grpcio==1.64.1 \ --hash=sha256:03b43d0ccf99c557ec671c7dede64f023c7da9bb632ac65dbc57f166e4970040 \ --hash=sha256:0a12ddb1678ebc6a84ec6b0487feac020ee2b1659cbe69b80f06dbffdb249122 \ @@ -256,6 +265,18 @@ keras-nightly==3.0.4.dev2024021403 \ --hash=sha256:24ce69d29d582771685bf4235f59663723405b5a5b16f3eaff2657e52e74663a \ --hash=sha256:9f416e66b820ef833779d219d255b346b8b90a72fdbd0b2f1e90a43ad142a03d # via -r ci/official/requirements_updater/requirements.in +libclang==18.1.1 \ + --hash=sha256:0b2e143f0fac830156feb56f9231ff8338c20aecfe72b4ffe96f19e5a1dbb69a \ + --hash=sha256:3f0e1f49f04d3cd198985fea0511576b0aee16f9ff0e0f0cad7f9c57ec3c20e8 \ + --hash=sha256:4dd2d3b82fab35e2bf9ca717d7b63ac990a3519c7e312f19fa8e86dcc712f7fb \ + --hash=sha256:54dda940a4a0491a9d1532bf071ea3ef26e6dbaf03b5000ed94dd7174e8f9592 \ + --hash=sha256:69f8eb8f65c279e765ffd28aaa7e9e364c776c17618af8bff22a8df58677ff4f \ + --hash=sha256:6f14c3f194704e5d09769108f03185fce7acaf1d1ae4bbb2f30a72c2400cb7c5 \ + --hash=sha256:83ce5045d101b669ac38e6da8e58765f12da2d3aafb3b9b98d88b286a60964d8 \ + --hash=sha256:a1214966d08d73d971287fc3ead8dfaf82eb07fb197680d8b3859dbbbbf78250 \ + --hash=sha256:c533091d8a3bbf7460a00cb6c1a71da93bffe148f172c7d03b1c31fbf8aa2a0b \ + --hash=sha256:cf4a99b05376513717ab5d82a0db832c56ccea4fd61a69dbb7bccf2dfb207dbe + # via -r ci/official/requirements_updater/requirements.in lit==17.0.6 \ --hash=sha256:dfa9af9b55fc4509a56be7bf2346f079d7f4a242d583b9f2e0b078fd0abae31b # via -r ci/official/requirements_updater/requirements.in @@ -518,6 +539,24 @@ tensorboard-data-server==0.7.2 \ --hash=sha256:9fe5d24221b29625dbc7328b0436ca7fc1c23de4acf4d272f1180856e32f9f60 \ --hash=sha256:ef687163c24185ae9754ed5650eb5bc4d84ff257aabdc33f0cc6f74d8ba54530 # via tb-nightly +tensorflow-io-gcs-filesystem==0.37.1 \ + --hash=sha256:0df00891669390078a003cedbdd3b8e645c718b111917535fa1d7725e95cdb95 \ + --hash=sha256:249c12b830165841411ba71e08215d0e94277a49c551e6dd5d72aab54fe5491b \ + --hash=sha256:257aab23470a0796978efc9c2bcf8b0bc80f22e6298612a4c0a50d3f4e88060c \ + --hash=sha256:286389a203a5aee1a4fa2e53718c661091aa5fea797ff4fa6715ab8436b02e6c \ + --hash=sha256:32c50ab4e29a23c1f91cd0f9ab8c381a0ab10f45ef5c5252e94965916041737c \ + --hash=sha256:426de1173cb81fbd62becec2012fc00322a295326d90eb6c737fab636f182aed \ + --hash=sha256:6e1f2796b57e799a8ca1b75bf47c2aaa437c968408cc1a402a9862929e104cda \ + --hash=sha256:8943036bbf84e7a2be3705cb56f9c9df7c48c9e614bb941f0936c58e3ca89d6f \ + --hash=sha256:8febbfcc67c61e542a5ac1a98c7c20a91a5e1afc2e14b1ef0cb7c28bc3b6aa70 \ + --hash=sha256:9679b36e3a80921876f31685ab6f7270f3411a4cc51bc2847e80d0e4b5291e27 \ + --hash=sha256:b02f9c5f94fd62773954a04f69b68c4d576d076fd0db4ca25d5479f0fbfcdbad \ + --hash=sha256:ee5da49019670ed364f3e5fb86b46420841a6c3cb52a300553c63841671b3e6d \ + --hash=sha256:ee7c8ee5fe2fd8cb6392669ef16e71841133041fee8a330eff519ad9b36e4556 \ + --hash=sha256:fbb33f1745f218464a59cecd9a18e32ca927b0f4d77abd8f8671b645cc1a182f \ + --hash=sha256:fe8dcc6d222258a080ac3dfcaaaa347325ce36a7a046277f6b3e19abc1efb3c5 \ + --hash=sha256:ffebb6666a7bfc28005f4fbbb111a455b5e7d6cd3b12752b7050863ecb27d5cc + # via -r ci/official/requirements_updater/requirements.in termcolor==2.3.0 \ --hash=sha256:3afb05607b89aed0ffe25202399ee0867ad4d3cb4180d98aaf8eefa6a5f7d475 \ --hash=sha256:b5b08f68937f138fe92f6c089b99f1e2da0ae56c52b78bf7075fd95420fd9a5a diff --git a/requirements_lock_3_12.txt b/requirements_lock_3_12.txt index 46778af8ee1b4b..581778cdc49d64 100644 --- a/requirements_lock_3_12.txt +++ b/requirements_lock_3_12.txt @@ -166,10 +166,19 @@ dm-tree==0.1.8 \ --hash=sha256:f7ac31b9aecccb2c6e1ab29706f6ded3eba0c2c69c770322c9c685929c3d6afb \ --hash=sha256:fa42a605d099ee7d41ba2b5fb75e21423951fd26e5d50583a00471238fb3021d # via keras-nightly +flatbuffers==24.3.25 \ + --hash=sha256:8dbdec58f935f3765e4f7f3cf635ac3a77f83568138d6a2311f524ec96364812 \ + --hash=sha256:de2ec5b203f21441716617f38443e0a8ebf3d25bf0d9c0bb0ce68fa00ad546a4 + # via -r ci/official/requirements_updater/requirements.in gast==0.4.0 \ --hash=sha256:40feb7b8b8434785585ab224d1568b857edb18297e5a3047f1ba012bc83b42c1 \ --hash=sha256:b7adcdd5adbebf1adf17378da5ba3f543684dbec47b1cda1f3997e573cd542c4 # via -r ci/official/requirements_updater/requirements.in +google-pasta==0.2.0 \ + --hash=sha256:4612951da876b1a10fe3960d7226f0c7682cf901e16ac06e473b267a5afa8954 \ + --hash=sha256:b32482794a366b5366a32c92a9a9201b107821889935a02b3e51f6b432ea84ed \ + --hash=sha256:c9f2c8dfc8f96d0d5808299920721be30c9eec37f2389f28904f454565c8a16e + # via -r ci/official/requirements_updater/requirements.in grpcio==1.64.1 \ --hash=sha256:03b43d0ccf99c557ec671c7dede64f023c7da9bb632ac65dbc57f166e4970040 \ --hash=sha256:0a12ddb1678ebc6a84ec6b0487feac020ee2b1659cbe69b80f06dbffdb249122 \ @@ -256,6 +265,18 @@ keras-nightly==3.0.4.dev2024021403 \ --hash=sha256:24ce69d29d582771685bf4235f59663723405b5a5b16f3eaff2657e52e74663a \ --hash=sha256:9f416e66b820ef833779d219d255b346b8b90a72fdbd0b2f1e90a43ad142a03d # via -r ci/official/requirements_updater/requirements.in +libclang==18.1.1 \ + --hash=sha256:0b2e143f0fac830156feb56f9231ff8338c20aecfe72b4ffe96f19e5a1dbb69a \ + --hash=sha256:3f0e1f49f04d3cd198985fea0511576b0aee16f9ff0e0f0cad7f9c57ec3c20e8 \ + --hash=sha256:4dd2d3b82fab35e2bf9ca717d7b63ac990a3519c7e312f19fa8e86dcc712f7fb \ + --hash=sha256:54dda940a4a0491a9d1532bf071ea3ef26e6dbaf03b5000ed94dd7174e8f9592 \ + --hash=sha256:69f8eb8f65c279e765ffd28aaa7e9e364c776c17618af8bff22a8df58677ff4f \ + --hash=sha256:6f14c3f194704e5d09769108f03185fce7acaf1d1ae4bbb2f30a72c2400cb7c5 \ + --hash=sha256:83ce5045d101b669ac38e6da8e58765f12da2d3aafb3b9b98d88b286a60964d8 \ + --hash=sha256:a1214966d08d73d971287fc3ead8dfaf82eb07fb197680d8b3859dbbbbf78250 \ + --hash=sha256:c533091d8a3bbf7460a00cb6c1a71da93bffe148f172c7d03b1c31fbf8aa2a0b \ + --hash=sha256:cf4a99b05376513717ab5d82a0db832c56ccea4fd61a69dbb7bccf2dfb207dbe + # via -r ci/official/requirements_updater/requirements.in lit==17.0.6 \ --hash=sha256:dfa9af9b55fc4509a56be7bf2346f079d7f4a242d583b9f2e0b078fd0abae31b # via -r ci/official/requirements_updater/requirements.in @@ -518,6 +539,24 @@ tensorboard-data-server==0.7.2 \ --hash=sha256:9fe5d24221b29625dbc7328b0436ca7fc1c23de4acf4d272f1180856e32f9f60 \ --hash=sha256:ef687163c24185ae9754ed5650eb5bc4d84ff257aabdc33f0cc6f74d8ba54530 # via tb-nightly +tensorflow-io-gcs-filesystem==0.37.1 \ + --hash=sha256:0df00891669390078a003cedbdd3b8e645c718b111917535fa1d7725e95cdb95 \ + --hash=sha256:249c12b830165841411ba71e08215d0e94277a49c551e6dd5d72aab54fe5491b \ + --hash=sha256:257aab23470a0796978efc9c2bcf8b0bc80f22e6298612a4c0a50d3f4e88060c \ + --hash=sha256:286389a203a5aee1a4fa2e53718c661091aa5fea797ff4fa6715ab8436b02e6c \ + --hash=sha256:32c50ab4e29a23c1f91cd0f9ab8c381a0ab10f45ef5c5252e94965916041737c \ + --hash=sha256:426de1173cb81fbd62becec2012fc00322a295326d90eb6c737fab636f182aed \ + --hash=sha256:6e1f2796b57e799a8ca1b75bf47c2aaa437c968408cc1a402a9862929e104cda \ + --hash=sha256:8943036bbf84e7a2be3705cb56f9c9df7c48c9e614bb941f0936c58e3ca89d6f \ + --hash=sha256:8febbfcc67c61e542a5ac1a98c7c20a91a5e1afc2e14b1ef0cb7c28bc3b6aa70 \ + --hash=sha256:9679b36e3a80921876f31685ab6f7270f3411a4cc51bc2847e80d0e4b5291e27 \ + --hash=sha256:b02f9c5f94fd62773954a04f69b68c4d576d076fd0db4ca25d5479f0fbfcdbad \ + --hash=sha256:ee5da49019670ed364f3e5fb86b46420841a6c3cb52a300553c63841671b3e6d \ + --hash=sha256:ee7c8ee5fe2fd8cb6392669ef16e71841133041fee8a330eff519ad9b36e4556 \ + --hash=sha256:fbb33f1745f218464a59cecd9a18e32ca927b0f4d77abd8f8671b645cc1a182f \ + --hash=sha256:fe8dcc6d222258a080ac3dfcaaaa347325ce36a7a046277f6b3e19abc1efb3c5 \ + --hash=sha256:ffebb6666a7bfc28005f4fbbb111a455b5e7d6cd3b12752b7050863ecb27d5cc + # via -r ci/official/requirements_updater/requirements.in termcolor==2.3.0 \ --hash=sha256:3afb05607b89aed0ffe25202399ee0867ad4d3cb4180d98aaf8eefa6a5f7d475 \ --hash=sha256:b5b08f68937f138fe92f6c089b99f1e2da0ae56c52b78bf7075fd95420fd9a5a diff --git a/requirements_lock_3_9.txt b/requirements_lock_3_9.txt index 87287e12978f31..5b0533315ddd82 100644 --- a/requirements_lock_3_9.txt +++ b/requirements_lock_3_9.txt @@ -166,10 +166,19 @@ dm-tree==0.1.8 \ --hash=sha256:f7ac31b9aecccb2c6e1ab29706f6ded3eba0c2c69c770322c9c685929c3d6afb \ --hash=sha256:fa42a605d099ee7d41ba2b5fb75e21423951fd26e5d50583a00471238fb3021d # via keras-nightly +flatbuffers==24.3.25 \ + --hash=sha256:8dbdec58f935f3765e4f7f3cf635ac3a77f83568138d6a2311f524ec96364812 \ + --hash=sha256:de2ec5b203f21441716617f38443e0a8ebf3d25bf0d9c0bb0ce68fa00ad546a4 + # via -r ci/official/requirements_updater/requirements.in gast==0.4.0 \ --hash=sha256:40feb7b8b8434785585ab224d1568b857edb18297e5a3047f1ba012bc83b42c1 \ --hash=sha256:b7adcdd5adbebf1adf17378da5ba3f543684dbec47b1cda1f3997e573cd542c4 # via -r ci/official/requirements_updater/requirements.in +google-pasta==0.2.0 \ + --hash=sha256:4612951da876b1a10fe3960d7226f0c7682cf901e16ac06e473b267a5afa8954 \ + --hash=sha256:b32482794a366b5366a32c92a9a9201b107821889935a02b3e51f6b432ea84ed \ + --hash=sha256:c9f2c8dfc8f96d0d5808299920721be30c9eec37f2389f28904f454565c8a16e + # via -r ci/official/requirements_updater/requirements.in grpcio==1.64.1 \ --hash=sha256:03b43d0ccf99c557ec671c7dede64f023c7da9bb632ac65dbc57f166e4970040 \ --hash=sha256:0a12ddb1678ebc6a84ec6b0487feac020ee2b1659cbe69b80f06dbffdb249122 \ @@ -260,6 +269,18 @@ keras-nightly==3.0.4.dev2024021403 \ --hash=sha256:24ce69d29d582771685bf4235f59663723405b5a5b16f3eaff2657e52e74663a \ --hash=sha256:9f416e66b820ef833779d219d255b346b8b90a72fdbd0b2f1e90a43ad142a03d # via -r ci/official/requirements_updater/requirements.in +libclang==18.1.1 \ + --hash=sha256:0b2e143f0fac830156feb56f9231ff8338c20aecfe72b4ffe96f19e5a1dbb69a \ + --hash=sha256:3f0e1f49f04d3cd198985fea0511576b0aee16f9ff0e0f0cad7f9c57ec3c20e8 \ + --hash=sha256:4dd2d3b82fab35e2bf9ca717d7b63ac990a3519c7e312f19fa8e86dcc712f7fb \ + --hash=sha256:54dda940a4a0491a9d1532bf071ea3ef26e6dbaf03b5000ed94dd7174e8f9592 \ + --hash=sha256:69f8eb8f65c279e765ffd28aaa7e9e364c776c17618af8bff22a8df58677ff4f \ + --hash=sha256:6f14c3f194704e5d09769108f03185fce7acaf1d1ae4bbb2f30a72c2400cb7c5 \ + --hash=sha256:83ce5045d101b669ac38e6da8e58765f12da2d3aafb3b9b98d88b286a60964d8 \ + --hash=sha256:a1214966d08d73d971287fc3ead8dfaf82eb07fb197680d8b3859dbbbbf78250 \ + --hash=sha256:c533091d8a3bbf7460a00cb6c1a71da93bffe148f172c7d03b1c31fbf8aa2a0b \ + --hash=sha256:cf4a99b05376513717ab5d82a0db832c56ccea4fd61a69dbb7bccf2dfb207dbe + # via -r ci/official/requirements_updater/requirements.in lit==17.0.6 \ --hash=sha256:dfa9af9b55fc4509a56be7bf2346f079d7f4a242d583b9f2e0b078fd0abae31b # via -r ci/official/requirements_updater/requirements.in @@ -522,6 +543,24 @@ tensorboard-data-server==0.7.2 \ --hash=sha256:9fe5d24221b29625dbc7328b0436ca7fc1c23de4acf4d272f1180856e32f9f60 \ --hash=sha256:ef687163c24185ae9754ed5650eb5bc4d84ff257aabdc33f0cc6f74d8ba54530 # via tb-nightly +tensorflow-io-gcs-filesystem==0.37.1 \ + --hash=sha256:0df00891669390078a003cedbdd3b8e645c718b111917535fa1d7725e95cdb95 \ + --hash=sha256:249c12b830165841411ba71e08215d0e94277a49c551e6dd5d72aab54fe5491b \ + --hash=sha256:257aab23470a0796978efc9c2bcf8b0bc80f22e6298612a4c0a50d3f4e88060c \ + --hash=sha256:286389a203a5aee1a4fa2e53718c661091aa5fea797ff4fa6715ab8436b02e6c \ + --hash=sha256:32c50ab4e29a23c1f91cd0f9ab8c381a0ab10f45ef5c5252e94965916041737c \ + --hash=sha256:426de1173cb81fbd62becec2012fc00322a295326d90eb6c737fab636f182aed \ + --hash=sha256:6e1f2796b57e799a8ca1b75bf47c2aaa437c968408cc1a402a9862929e104cda \ + --hash=sha256:8943036bbf84e7a2be3705cb56f9c9df7c48c9e614bb941f0936c58e3ca89d6f \ + --hash=sha256:8febbfcc67c61e542a5ac1a98c7c20a91a5e1afc2e14b1ef0cb7c28bc3b6aa70 \ + --hash=sha256:9679b36e3a80921876f31685ab6f7270f3411a4cc51bc2847e80d0e4b5291e27 \ + --hash=sha256:b02f9c5f94fd62773954a04f69b68c4d576d076fd0db4ca25d5479f0fbfcdbad \ + --hash=sha256:ee5da49019670ed364f3e5fb86b46420841a6c3cb52a300553c63841671b3e6d \ + --hash=sha256:ee7c8ee5fe2fd8cb6392669ef16e71841133041fee8a330eff519ad9b36e4556 \ + --hash=sha256:fbb33f1745f218464a59cecd9a18e32ca927b0f4d77abd8f8671b645cc1a182f \ + --hash=sha256:fe8dcc6d222258a080ac3dfcaaaa347325ce36a7a046277f6b3e19abc1efb3c5 \ + --hash=sha256:ffebb6666a7bfc28005f4fbbb111a455b5e7d6cd3b12752b7050863ecb27d5cc + # via -r ci/official/requirements_updater/requirements.in termcolor==2.3.0 \ --hash=sha256:3afb05607b89aed0ffe25202399ee0867ad4d3cb4180d98aaf8eefa6a5f7d475 \ --hash=sha256:b5b08f68937f138fe92f6c089b99f1e2da0ae56c52b78bf7075fd95420fd9a5a diff --git a/tensorflow/opensource_only.files b/tensorflow/opensource_only.files index c08da1f62201fc..8344a9b4dc0a32 100644 --- a/tensorflow/opensource_only.files +++ b/tensorflow/opensource_only.files @@ -1,5 +1,4 @@ tf_staging/BUILD: -tf_staging/ci/official/wheel_test/BUILD: tf_staging/tensorflow/__init__:.py tf_staging/tensorflow/api_template.__init__:.py tf_staging/tensorflow/api_template_v1.__init__:.py diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl index 7fb0387e29a065..f76b13bef0aecd 100644 --- a/tensorflow/tensorflow.bzl +++ b/tensorflow/tensorflow.bzl @@ -3580,3 +3580,10 @@ def tf_python_framework_friends(): def if_cuda_tools(if_true, if_false = []): return _if_cuda_tools(if_true, if_false) + +# The config is used to determine if we need dependency on pre-built wheels. +def if_wheel_dependency(if_true, if_false = []): + return select({ + "@local_xla//xla/tsl:enable_wheel_dependency": if_true, + "//conditions:default": if_false, + }) diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD index c6ba8762df3b0f..5c1007c272ddfe 100644 --- a/tensorflow/tools/pip_package/BUILD +++ b/tensorflow/tools/pip_package/BUILD @@ -4,7 +4,7 @@ load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") load("@local_config_syslibs//:build_defs.bzl", "if_not_system_lib") load("@local_xla//xla/tsl/mkl:build_defs.bzl", "if_enable_mkl", "if_mkl", "if_mkl_ml") -load("//tensorflow:tensorflow.bzl", "if_with_tpu_support", "transitive_hdrs") +load("//tensorflow:tensorflow.bzl", "if_wheel_dependency", "if_with_tpu_support", "transitive_hdrs") load("//tensorflow/core/platform:build_config_root.bzl", "tf_additional_license_deps") load("//tensorflow/tools/pip_package/utils:data_deps.bzl", "collect_data_files") load("//tensorflow/tools/pip_package/utils:py_deps.bzl", "transitive_py_deps") @@ -292,3 +292,35 @@ tf_wheel( ":xla_cmake", ], ) + +genrule( + name = "empty_test", + outs = ["empty_test.py"], + cmd = "echo '' > $@", +) + +py_test( + name = "import_api_packages_test_cpu", + srcs = if_wheel_dependency( + ["import_api_packages_test.py"], + [":empty_test"], + ), + main = if_wheel_dependency("import_api_packages_test.py", "empty_test.py"), + tags = [ + "cpu", + ], + deps = if_wheel_dependency(["@pypi_tensorflow//:pkg"]), +) + +py_test( + name = "import_api_packages_test_gpu", + srcs = if_wheel_dependency( + ["import_api_packages_test.py"], + [":empty_test"], + ), + main = if_wheel_dependency("import_api_packages_test.py", "empty_test.py"), + tags = [ + "gpu", + ], + deps = if_wheel_dependency(["@pypi_tensorflow//:pkg"]), +) diff --git a/tensorflow/tools/pip_package/MANIFEST.in b/tensorflow/tools/pip_package/MANIFEST.in index dafc500d7f4106..bd8ada77aef5f9 100644 --- a/tensorflow/tools/pip_package/MANIFEST.in +++ b/tensorflow/tools/pip_package/MANIFEST.in @@ -12,4 +12,4 @@ recursive-include * *.csv recursive-include tensorflow * recursive-exclude tensorflow *.md recursive-exclude tensorflow/_api/ * -include tensorflow/_api/api_packages.txt \ No newline at end of file +include tensorflow/_api/v2/api_packages.txt \ No newline at end of file diff --git a/ci/official/wheel_test/test_import_api_packages.py b/tensorflow/tools/pip_package/import_api_packages_test.py similarity index 66% rename from ci/official/wheel_test/test_import_api_packages.py rename to tensorflow/tools/pip_package/import_api_packages_test.py index 1c9fb5365500b0..ca8849fef03978 100644 --- a/ci/official/wheel_test/test_import_api_packages.py +++ b/tensorflow/tools/pip_package/import_api_packages_test.py @@ -15,19 +15,33 @@ """Import API packages test. -This is a Python test that verifies whether API v2 packages can be imported -from the current build or not. - -It uses the `_api/v2/api_packages.txt` file from the local wheel file. -The `_api/v2/api_packages.txt` file is created during the process of generating -TensorFlow API v2 init files and is stored in the wheel file after the build. - -See README.md file for "how to run" instruction. +This Python test verifies whether the API v2 packages can be imported from the +current build. It utilizes the `_api/v2/api_packages.txt` list of packages from +the local wheel file specified in the `requirements_lock_.txt`. + +Packages are imported one by one in alphabetical order during runtime. + +The test doesn't identify package's order-dependent issues; for instance, +importing "tf.foo" followed by "tf.bar" won't reveal that "tf.bar" depends on +"tf.foo" being imported first. + +The `_api/v2/api_packages.txt` file is generated during the TensorFlow API v2 +init files creation process and is subsequently stored in the wheel file after +the build. It also contains a few paths that cannot be directly imported. These +paths point to attributes or sub-modules within a module's namespace, but they +don't correspond to an actual file or directory on the filesystem. The list of +such paths is stored in the packages_for_skip variable and will be skipped +during the test. """ import logging +import os import unittest -import pkg_resources + +try: + import importlib.resources as pkg_resources # pylint: disable=g-import-not-at-top +except ImportError: + import importlib_resources as pkg_resources # pylint: disable=g-import-not-at-top logging.basicConfig(level=logging.INFO) @@ -37,13 +51,14 @@ class ImportApiPackagesTest(unittest.TestCase): def setUp(self): def _get_api_packages_v2(): - api_packages_path = pkg_resources.resource_filename( - "tensorflow", "_api/v2/api_packages.txt" - ) - + api_packages_path = os.path.join("_api", "v2", "api_packages.txt") logging.info("Load api packages file: %s", api_packages_path) - with open(api_packages_path) as file: - return set(file.read().splitlines()) + return set( + pkg_resources.files("tensorflow") + .joinpath(api_packages_path) + .read_text() + .splitlines() + ) super().setUp() self.api_packages_v2 = _get_api_packages_v2() diff --git a/third_party/py/python_repo.bzl b/third_party/py/python_repo.bzl index 13aed2b687129f..83778b744e0784 100644 --- a/third_party/py/python_repo.bzl +++ b/third_party/py/python_repo.bzl @@ -34,13 +34,11 @@ Please check python_init_repositories() in your WORKSPACE file. requirements_with_local_wheels = str(requirements) - local_wheels_dir = ctx.os.environ.get("LOCAL_WHEELS_DIR", "") - if ctx.attr.local_wheel_workspaces or local_wheels_dir: + if ctx.attr.local_wheel_workspaces: local_wheel_requirements = _get_injected_local_wheels( ctx, version, ctx.attr.local_wheel_workspaces, - local_wheels_dir, ) requirements_content = [ctx.read(requirements)] + local_wheel_requirements merged_requirements_content = "\n".join(requirements_content) @@ -118,8 +116,7 @@ def _parse_python_version(version_str): def _get_injected_local_wheels( ctx, py_version, - local_wheel_workspaces, - local_wheels_dir): + local_wheel_workspaces): local_wheel_requirements = [] py_ver_marker = "-cp%s-" % py_version.replace(".", "") py_major_ver_marker = "-py%s-" % py_version.split(".")[0] @@ -140,18 +137,6 @@ def _get_injected_local_wheels( ctx.attr.local_wheel_inclusion_list, ctx.attr.local_wheel_exclusion_list, ) - if local_wheels_dir: - dist_folder_path = ctx.path(local_wheels_dir) - if dist_folder_path.exists: - dist_wheels = dist_folder_path.readdir() - _process_dist_wheels( - dist_wheels, - wheels, - py_ver_marker, - py_major_ver_marker, - ctx.attr.local_wheel_inclusion_list, - ctx.attr.local_wheel_exclusion_list, - ) for wheel_name, wheel_path in wheels.items(): local_wheel_requirements.append( diff --git a/third_party/xla/.bazelrc b/third_party/xla/.bazelrc index 8201ce4582a00f..f5b06fd21d937e 100644 --- a/third_party/xla/.bazelrc +++ b/third_party/xla/.bazelrc @@ -740,27 +740,27 @@ build:linux_libtensorflow_build --config=cuda_wheel -- //tensorflow/tools/lib_pa test:linux_cpu_wheel_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_cpu_wheel_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_cpu_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium -test:linux_cpu_wheel_test --config=linux_cpu_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... +test:linux_cpu_wheel_test --@local_tsl//third_party/py:wheel_dependency=true --config=linux_cpu_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... # CUDA WHEEL test:linux_cuda_wheel_test_filters --test_tag_filters=gpu,requires-gpu,-no_gpu,-no_oss,-oss_excluded,-oss_serial,-benchmark-test,-no_cuda11,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_cuda_wheel_test_filters --build_tag_filters=gpu,requires-gpu,-no_gpu,-no_oss,-oss_excluded,-oss_serial,-benchmark-test,-no_cuda11,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_cuda_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium -test:linux_cuda_wheel_test --config=linux_cuda_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... +test:linux_cuda_wheel_test --@local_tsl//third_party/py:wheel_dependency=true --config=linux_cuda_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... # ARM64 WHEEL test:linux_arm64_wheel_test_filters --test_tag_filters=-no_oss,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_arm64_wheel_test_filters --build_tag_filters=-no_oss,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_arm64_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium -test:linux_arm64_wheel_test --config=linux_arm64_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/core/grappler/optimizers:auto_mixed_precision_test_cpu -//tensorflow/core/grappler/optimizers:remapper_test_cpu -//tensorflow/core/kernels/image:resize_bicubic_op_test +test:linux_arm64_wheel_test --@local_tsl//third_party/py:wheel_dependency=true --config=linux_arm64_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/core/grappler/optimizers:auto_mixed_precision_test_cpu -//tensorflow/core/grappler/optimizers:remapper_test_cpu -//tensorflow/core/kernels/image:resize_bicubic_op_test # MACOS ARM64 WHEEL test:macos_arm64_wheel_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 test:macos_arm64_wheel_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 test:macos_arm64_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium -test:macos_arm64_wheel_test --config=macos_arm64_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/compiler/aot/... +test:macos_arm64_wheel_test --@local_tsl//third_party/py:wheel_dependency=true --config=macos_arm64_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/compiler/aot/... # MACOS X86 WHEEL test:macos_x86_wheel_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test test:macos_x86_wheel_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test test:macos_x86_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium -test:macos_x86_wheel_test --config=macos_x86_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/compiler/aot/... +test:macos_x86_wheel_test --@local_tsl//third_party/py:wheel_dependency=true --config=macos_x86_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/compiler/aot/... # PYCPP TESTS run a suite of Python and C++ tests to verify general correctness over # the whole TF code base. These are usually run continuously or upon presubmit. diff --git a/third_party/xla/third_party/py/python_repo.bzl b/third_party/xla/third_party/py/python_repo.bzl index 13aed2b687129f..83778b744e0784 100644 --- a/third_party/xla/third_party/py/python_repo.bzl +++ b/third_party/xla/third_party/py/python_repo.bzl @@ -34,13 +34,11 @@ Please check python_init_repositories() in your WORKSPACE file. requirements_with_local_wheels = str(requirements) - local_wheels_dir = ctx.os.environ.get("LOCAL_WHEELS_DIR", "") - if ctx.attr.local_wheel_workspaces or local_wheels_dir: + if ctx.attr.local_wheel_workspaces: local_wheel_requirements = _get_injected_local_wheels( ctx, version, ctx.attr.local_wheel_workspaces, - local_wheels_dir, ) requirements_content = [ctx.read(requirements)] + local_wheel_requirements merged_requirements_content = "\n".join(requirements_content) @@ -118,8 +116,7 @@ def _parse_python_version(version_str): def _get_injected_local_wheels( ctx, py_version, - local_wheel_workspaces, - local_wheels_dir): + local_wheel_workspaces): local_wheel_requirements = [] py_ver_marker = "-cp%s-" % py_version.replace(".", "") py_major_ver_marker = "-py%s-" % py_version.split(".")[0] @@ -140,18 +137,6 @@ def _get_injected_local_wheels( ctx.attr.local_wheel_inclusion_list, ctx.attr.local_wheel_exclusion_list, ) - if local_wheels_dir: - dist_folder_path = ctx.path(local_wheels_dir) - if dist_folder_path.exists: - dist_wheels = dist_folder_path.readdir() - _process_dist_wheels( - dist_wheels, - wheels, - py_ver_marker, - py_major_ver_marker, - ctx.attr.local_wheel_inclusion_list, - ctx.attr.local_wheel_exclusion_list, - ) for wheel_name, wheel_path in wheels.items(): local_wheel_requirements.append( diff --git a/third_party/xla/third_party/tsl/.bazelrc b/third_party/xla/third_party/tsl/.bazelrc index 8201ce4582a00f..f5b06fd21d937e 100644 --- a/third_party/xla/third_party/tsl/.bazelrc +++ b/third_party/xla/third_party/tsl/.bazelrc @@ -740,27 +740,27 @@ build:linux_libtensorflow_build --config=cuda_wheel -- //tensorflow/tools/lib_pa test:linux_cpu_wheel_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_cpu_wheel_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_cpu_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium -test:linux_cpu_wheel_test --config=linux_cpu_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... +test:linux_cpu_wheel_test --@local_tsl//third_party/py:wheel_dependency=true --config=linux_cpu_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... # CUDA WHEEL test:linux_cuda_wheel_test_filters --test_tag_filters=gpu,requires-gpu,-no_gpu,-no_oss,-oss_excluded,-oss_serial,-benchmark-test,-no_cuda11,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_cuda_wheel_test_filters --build_tag_filters=gpu,requires-gpu,-no_gpu,-no_oss,-oss_excluded,-oss_serial,-benchmark-test,-no_cuda11,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_cuda_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium -test:linux_cuda_wheel_test --config=linux_cuda_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... +test:linux_cuda_wheel_test --@local_tsl//third_party/py:wheel_dependency=true --config=linux_cuda_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... # ARM64 WHEEL test:linux_arm64_wheel_test_filters --test_tag_filters=-no_oss,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_arm64_wheel_test_filters --build_tag_filters=-no_oss,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_arm64_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium -test:linux_arm64_wheel_test --config=linux_arm64_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/core/grappler/optimizers:auto_mixed_precision_test_cpu -//tensorflow/core/grappler/optimizers:remapper_test_cpu -//tensorflow/core/kernels/image:resize_bicubic_op_test +test:linux_arm64_wheel_test --@local_tsl//third_party/py:wheel_dependency=true --config=linux_arm64_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/core/grappler/optimizers:auto_mixed_precision_test_cpu -//tensorflow/core/grappler/optimizers:remapper_test_cpu -//tensorflow/core/kernels/image:resize_bicubic_op_test # MACOS ARM64 WHEEL test:macos_arm64_wheel_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 test:macos_arm64_wheel_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 test:macos_arm64_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium -test:macos_arm64_wheel_test --config=macos_arm64_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/compiler/aot/... +test:macos_arm64_wheel_test --@local_tsl//third_party/py:wheel_dependency=true --config=macos_arm64_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/compiler/aot/... # MACOS X86 WHEEL test:macos_x86_wheel_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test test:macos_x86_wheel_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test test:macos_x86_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium -test:macos_x86_wheel_test --config=macos_x86_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/compiler/aot/... +test:macos_x86_wheel_test --@local_tsl//third_party/py:wheel_dependency=true --config=macos_x86_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/compiler/aot/... # PYCPP TESTS run a suite of Python and C++ tests to verify general correctness over # the whole TF code base. These are usually run continuously or upon presubmit. diff --git a/third_party/xla/third_party/tsl/third_party/py/python_repo.bzl b/third_party/xla/third_party/tsl/third_party/py/python_repo.bzl index 13aed2b687129f..83778b744e0784 100644 --- a/third_party/xla/third_party/tsl/third_party/py/python_repo.bzl +++ b/third_party/xla/third_party/tsl/third_party/py/python_repo.bzl @@ -34,13 +34,11 @@ Please check python_init_repositories() in your WORKSPACE file. requirements_with_local_wheels = str(requirements) - local_wheels_dir = ctx.os.environ.get("LOCAL_WHEELS_DIR", "") - if ctx.attr.local_wheel_workspaces or local_wheels_dir: + if ctx.attr.local_wheel_workspaces: local_wheel_requirements = _get_injected_local_wheels( ctx, version, ctx.attr.local_wheel_workspaces, - local_wheels_dir, ) requirements_content = [ctx.read(requirements)] + local_wheel_requirements merged_requirements_content = "\n".join(requirements_content) @@ -118,8 +116,7 @@ def _parse_python_version(version_str): def _get_injected_local_wheels( ctx, py_version, - local_wheel_workspaces, - local_wheels_dir): + local_wheel_workspaces): local_wheel_requirements = [] py_ver_marker = "-cp%s-" % py_version.replace(".", "") py_major_ver_marker = "-py%s-" % py_version.split(".")[0] @@ -140,18 +137,6 @@ def _get_injected_local_wheels( ctx.attr.local_wheel_inclusion_list, ctx.attr.local_wheel_exclusion_list, ) - if local_wheels_dir: - dist_folder_path = ctx.path(local_wheels_dir) - if dist_folder_path.exists: - dist_wheels = dist_folder_path.readdir() - _process_dist_wheels( - dist_wheels, - wheels, - py_ver_marker, - py_major_ver_marker, - ctx.attr.local_wheel_inclusion_list, - ctx.attr.local_wheel_exclusion_list, - ) for wheel_name, wheel_path in wheels.items(): local_wheel_requirements.append( diff --git a/third_party/xla/xla/tsl/BUILD b/third_party/xla/xla/tsl/BUILD index 48719c0966cd16..2f93565b1b8680 100644 --- a/third_party/xla/xla/tsl/BUILD +++ b/third_party/xla/xla/tsl/BUILD @@ -554,3 +554,17 @@ alias( actual = ":empty", visibility = ["//visibility:public"], ) + +# Flag indicating if the target requires pre-built wheel. +# TODO(ybaturina): move to tsl repository. +bool_flag( + name = "wheel_dependency", + build_setting_default = False, +) + +config_setting( + name = "enable_wheel_dependency", + flag_values = { + ":wheel_dependency": "True", + }, +) From cec77f0e68d3997addeb5c29419f0b6943b7b309 Mon Sep 17 00:00:00 2001 From: Quentin Khan Date: Fri, 20 Sep 2024 11:04:53 -0700 Subject: [PATCH 073/483] Add feature to build the XNNPack delegate weight cache in several steps. PiperOrigin-RevId: 676903767 --- tensorflow/lite/delegates/xnnpack/BUILD | 1 + .../lite/delegates/xnnpack/weight_cache.cc | 330 +++++++++++----- .../lite/delegates/xnnpack/weight_cache.h | 124 ++++-- .../delegates/xnnpack/weight_cache_test.cc | 368 +++++++++++++++--- .../delegates/xnnpack/xnnpack_delegate.cc | 29 +- 5 files changed, 667 insertions(+), 185 deletions(-) diff --git a/tensorflow/lite/delegates/xnnpack/BUILD b/tensorflow/lite/delegates/xnnpack/BUILD index 02f60403dbc347..ad4fc172112086 100644 --- a/tensorflow/lite/delegates/xnnpack/BUILD +++ b/tensorflow/lite/delegates/xnnpack/BUILD @@ -2956,6 +2956,7 @@ cc_test( name = "weight_cache_test", srcs = ["weight_cache_test.cc"], deps = [ + ":file_util", ":test_main", ":weight_cache", ":weight_cache_schema", diff --git a/tensorflow/lite/delegates/xnnpack/weight_cache.cc b/tensorflow/lite/delegates/xnnpack/weight_cache.cc index ab70bc4b63a504..da220a3a31b241 100644 --- a/tensorflow/lite/delegates/xnnpack/weight_cache.cc +++ b/tensorflow/lite/delegates/xnnpack/weight_cache.cc @@ -66,7 +66,7 @@ limitations under the License. namespace tflite::xnnpack { namespace { -constexpr size_t kMinAlignment = 64; +constexpr size_t kMinAlignment = 128; // Checks if the given path is a special value to use an in-memory cache. bool IsInMemoryCachePath(const char* path) { @@ -132,6 +132,8 @@ bool FileExists(const char* path) { void swap(MMapHandle& a, MMapHandle& b) { using std::swap; swap(a.size_, b.size_); + swap(a.offset_, b.offset_); + swap(a.offset_page_adjustment_, b.offset_page_adjustment_); swap(a.data_, b.data_); } @@ -144,11 +146,12 @@ MMapHandle& MMapHandle::operator=(MMapHandle&& other) { return *this; } -bool MMapHandle::Map(const char* path) { - return this->Map(FileDescriptor::Open(path, O_RDONLY), path); +bool MMapHandle::Map(const char* path, const size_t offset) { + return this->Map(FileDescriptor::Open(path, O_RDONLY), offset, path); } -bool MMapHandle::Map(const FileDescriptor& fd, const char* const path) { +bool MMapHandle::Map(const FileDescriptor& fd, const size_t offset, + const char* const path) { this->UnMap(); XNNPACK_RETURN_CHECK(fd.IsValid(), @@ -162,15 +165,19 @@ bool MMapHandle::Map(const FileDescriptor& fd, const char* const path) { // This will reset data_ and size_ on return until is is deactivated. ScopeGuard unmap_on_error([this] { UnMap(); }); - size_ = file_stats.st_size; + size_ = file_stats.st_size - offset; + offset_ = offset; #if defined(_MSC_VER) // This allocation is freed in UnMap and in the desctructor. data_ = new uint8_t[size_]; + fd.SetPos(offset); XNNPACK_RETURN_CHECK(fd.Read(data_, size_), "could not read file ('%s'): %s.", path, strerror(errno)); #else - data_ = static_cast(mmap(/*addr=*/nullptr, size_, PROT_READ, - MAP_SHARED, fd.Value(), /*offset=*/0)); + offset_page_adjustment_ = offset_ % getpagesize(); + data_ = static_cast( + mmap(/*addr=*/nullptr, size_ + offset_page_adjustment_, PROT_READ, + MAP_SHARED, fd.Value(), offset_ - offset_page_adjustment_)); XNNPACK_RETURN_CHECK(data_ != MAP_FAILED, "could not mmap file (%s): %s.", path, strerror(errno)); #endif @@ -178,6 +185,25 @@ bool MMapHandle::Map(const FileDescriptor& fd, const char* const path) { return true; } +bool MMapHandle::Resize(size_t new_size) { +#if defined(__linux__) || defined(__ANDROID__) + void* const remapped_data = + mremap(data_, size_ + offset_page_adjustment_, + new_size + offset_page_adjustment_, /*flags=*/0); + if (remapped_data == MAP_FAILED) { + XNNPACK_RETURN_CHECK(errno == ENOMEM, "remap failed: %s", strerror(errno)); + return false; + } + size_ = new_size; + return true; +#else + // The current implementation uses new/delete which doesn't provide a way to + // modify an allocation size. Changing to malloc/realloc/free doesn't ensure + // that a memory allocation will not be moved when reallocating + return false; +#endif +} + void MMapHandle::UnMap() { if (data_) { #if defined(_MSC_VER) @@ -187,39 +213,46 @@ void MMapHandle::UnMap() { #endif } data_ = nullptr; + offset_ = 0; + offset_page_adjustment_ = 0; size_ = 0; } -void swap(WeightCacheBuilder& a, WeightCacheBuilder& b) { - using std::swap; - swap(a.schema_, b.schema_); - swap(a.data_, b.data_); - swap(a.capacity_, b.capacity_); - swap(a.fd_, b.fd_); - swap(a.file_path_, b.file_path_); -} - -WeightCacheBuilder::WeightCacheBuilder(WeightCacheBuilder&& other) { - swap(*this, other); -} +#define XNN_MOVE_CONSTRUCT_MEMBER(x) x(std::move(other.x)) +WeightCacheBuilder::WeightCacheBuilder(WeightCacheBuilder&& other) + : XNN_MOVE_CONSTRUCT_MEMBER(data_), + XNN_MOVE_CONSTRUCT_MEMBER(schema_), + XNN_MOVE_CONSTRUCT_MEMBER(capacity_), + XNN_MOVE_CONSTRUCT_MEMBER(build_segment_size_), + XNN_MOVE_CONSTRUCT_MEMBER(build_segment_start_), + XNN_MOVE_CONSTRUCT_MEMBER(first_write_done_), + XNN_MOVE_CONSTRUCT_MEMBER(fd_), + XNN_MOVE_CONSTRUCT_MEMBER(file_path_) {} +#undef XNN_MOVE_CONSTRUCT_MEMBER WeightCacheBuilder& WeightCacheBuilder::operator=(WeightCacheBuilder&& other) { - Reset(); - swap(*this, other); +#define XNN_MOVE_MEMBER(x) x = std::move(other.x) + XNN_MOVE_MEMBER(data_); + XNN_MOVE_MEMBER(schema_); + XNN_MOVE_MEMBER(capacity_); + XNN_MOVE_MEMBER(build_segment_size_); + XNN_MOVE_MEMBER(build_segment_start_); + XNN_MOVE_MEMBER(first_write_done_); + XNN_MOVE_MEMBER(fd_); + XNN_MOVE_MEMBER(file_path_); +#undef XNN_MOVE_MEMBER return *this; } -WeightCacheBuilder::~WeightCacheBuilder() { Reset(); } - bool WeightCacheBuilder::Start(const char* path) { - Reset(); - ScopeGuard reset_on_error([this] { Reset(); }); - + XNNPACK_RETURN_CHECK(!IsStarted()); file_path_ = path; + if (IsInMemoryCachePath(file_path_)) { fd_ = CreateInMemoryFileDescriptor("XNNPack in-memory weight cache"); } else { - fd_.Reset(open(file_path_.c_str(), O_CREAT | O_TRUNC | O_WRONLY, 0644)); + fd_ = FileDescriptor::Open(file_path_.c_str(), O_CREAT | O_TRUNC | O_RDWR, + 0644); } XNNPACK_RETURN_CHECK(fd_.IsValid(), "could not open file ('%s'): %s.", file_path_.c_str(), strerror(errno)); @@ -227,44 +260,63 @@ bool WeightCacheBuilder::Start(const char* path) { // Write data in the header, this will be overwritten in the `Finalize` call. // We explicitly set the header as invalid. If any error happens during // the build, reloading the cache file will fail. - const XNNPackCacheHeader header{XNNPackCacheHeader::kInvalidHeader}; + XNNPackCacheHeader header{XNNPackCacheHeader::kInvalidHeader}; + header.buffer_list_offset = sizeof(header); XNNPACK_RETURN_CHECK(fd_.Write(&header, sizeof(header)), - "could not write padding for flatbuffer offset in %s.", + "could not write initial cache header in %s.", file_path_.c_str()); schema_.base_offset = Align(sizeof(header), kMinAlignment); - - reset_on_error.Deactivate(); return true; } -void WeightCacheBuilder::Reset() { - fd_.Close(); - data_.reset(nullptr); - capacity_ = 0; - schema_ = cache::schema::BufferListT(); +bool WeightCacheBuilder::StartBuildStep() { + XNNPACK_RETURN_CHECK(IsStarted()); + + // Reload flatbuffer data. + XNNPackCacheHeader header; + fd_.SetPos(0); + XNNPACK_RETURN_CHECK(fd_.Read(&header, sizeof(header)), + "could not read cache file header."); + if (header.buffer_list_size) { + MMapHandle buffer_list_data; + XNNPACK_RETURN_CHECK(buffer_list_data.Map(fd_, header.buffer_list_offset), + "could not map buffer list mapping"); + cache::schema::GetBufferList(buffer_list_data.data())->UnPackTo(&schema_); + } + + // Move cursor to end of existing data. + build_segment_size_ = 0; + build_segment_start_ = fd_.SetPos(header.buffer_list_offset); + XNNPACK_RETURN_CHECK(build_segment_start_ != -1); + + is_build_step_ = true; + return true; } +void WeightCacheBuilder::Reset() { *this = WeightCacheBuilder(); } + void* WeightCacheBuilder::Reserve(size_t size) { if (size > capacity_) { // We don't care about the data when we are reserving space. We save memory // by deleting the existing buffer first. data_.reset(nullptr); - data_ = std::make_unique(size); + data_ = std::make_unique(size + kMinAlignment); capacity_ = size; } - return data_.get(); + return reinterpret_cast( + Align(reinterpret_cast(data_.get()), kMinAlignment)); } BufferLocation WeightCacheBuilder::Append(PackIdentifier pack_id, const void* data, uint64_t size) { - XNNPACK_ABORT_CHECK(IsStarted(), + XNNPACK_ABORT_CHECK(is_build_step_, "cannot append data to an unstarted builder."); // Add some padding so that the cache file can be mmaped and the buffer // stays aligned correctly. const size_t offset = Align(fd_.GetPos(), kMinAlignment); - if (fd_.SetPos(offset) != offset) { + if (fd_.SetPos(offset) == -1) { return BufferLocation::Invalid(); } @@ -278,20 +330,24 @@ BufferLocation WeightCacheBuilder::Append(PackIdentifier pack_id, schema_.buffers.push_back(std::make_unique(buffer)); if (!fd_.Write(data, size)) { - TFLITE_LOG_PROD(tflite::TFLITE_LOG_ERROR, file_path_.c_str(), + TFLITE_LOG_PROD(tflite::TFLITE_LOG_ERROR, "XNNPack weight cache: cannot append buffer to cache file"); return BufferLocation::Invalid(); } return loc; } -bool WeightCacheBuilder::ShouldFinalize() const { return fd_.IsValid(); } - -bool WeightCacheBuilder::Finalize() { +bool WeightCacheBuilder::StopBuildStep() { XNNPACK_RETURN_CHECK(fd_.IsValid(), "cache file ('%s') is not open for writing: %s.", file_path_.c_str(), strerror(errno)); + is_build_step_ = false; + if (fd_.GetPos() == build_segment_start_ && first_write_done_) { + // Nothing was written to the file, we can exit early. + return true; + } + flatbuffers::FlatBufferBuilder builder; // Add a fake size and the base offset to mutate them afterwards. Otherwise // space for it won't be added to the flatbuffer. @@ -321,16 +377,19 @@ bool WeightCacheBuilder::Finalize() { XNNPACK_RETURN_CHECK(fd_.Write(builder.GetBufferPointer(), builder.GetSize()), "cannot write buffer list to '%s'.", file_path_.c_str()); + // Save the segment size for that it can be individually mapped. + build_segment_size_ = fd_.GetPos() - build_segment_start_; + // Write the header at the beginning of the file. XNNPACK_RETURN_CHECK(fd_.SetPos(0) != -1, "could not move in the file to write header to %s", strerror(errno)); - XNNPACK_ABORT_CHECK(fd_.Write(&header, sizeof(header)), - "cannot write cache header to %s.", file_path_.c_str()); + XNNPACK_RETURN_CHECK(fd_.Write(&header, sizeof(header)), + "cannot write cache header to %s.", file_path_.c_str()); TFLITE_LOG_PROD(tflite::TFLITE_LOG_VERBOSE, "XNNPack weight cache: written to '%s'.", file_path_.c_str()); - Reset(); + first_write_done_ = true; return true; } @@ -349,7 +408,7 @@ MMapWeightCacheProvider& MMapWeightCacheProvider::operator=( swap(file_path_, other.file_path_); swap(buffer_address_to_identifier_, other.buffer_address_to_identifier_); swap(cache_key_to_offset_, other.cache_key_to_offset_); - swap(mmap_handle_, other.mmap_handle_); + swap(mmap_handles_, other.mmap_handles_); swap(mmap_buffer_base_offset_, other.mmap_buffer_base_offset_); swap(builder_, other.builder_); return *this; @@ -357,7 +416,7 @@ MMapWeightCacheProvider& MMapWeightCacheProvider::operator=( void MMapWeightCacheProvider::SetFilePath(const char* path) { XNNPACK_ABORT_CHECK( - !IsFinalized(), + !IsBuilding(), "Cannot change the path of a cache that has already been loaded."); // We try to keep file_path_'s data as stable as possible. Don't overwrite // if the path hasn't changed. @@ -374,7 +433,6 @@ bool MMapWeightCacheProvider::LoadOrStartBuild(const char* path) { } else if (StartBuild(path)) { TFLITE_LOG_PROD(tflite::TFLITE_LOG_VERBOSE, "XNNPack weight cache build for '%s' started.", path); - return true; } return false; @@ -382,7 +440,13 @@ bool MMapWeightCacheProvider::LoadOrStartBuild(const char* path) { bool MMapWeightCacheProvider::StartBuild(const char* path) { SetFilePath(path); - return builder_.Start(path); + building_run_ = builder_.Start(path); + if (IsInMemoryCachePath(file_path_)) { + // Duplicate the file descriptor to avoid loosing the temporary file when + // the builder is reset. + temporary_file_descriptor_ = builder_.GetFileDescriptor().Duplicate(); + } + return building_run_; } bool MMapWeightCacheProvider::Load(const std::string& path) { @@ -393,10 +457,13 @@ bool MMapWeightCacheProvider::Load(const std::string& path) { bool MMapWeightCacheProvider::Load() { mmap_buffer_base_offset_ = 0; cache_key_to_offset_.clear(); + mmap_handles_.resize(1); + MMapHandle& mmap_handle = mmap_handles_.front(); + ScopeGuard unmap_on_fail([this] { mmap_handles_.clear(); }); if (temporary_file_descriptor_.IsValid()) { - XNNPACK_RETURN_CHECK( - mmap_handle_.Map(temporary_file_descriptor_, file_path_.c_str())); + XNNPACK_RETURN_CHECK(mmap_handle.Map(temporary_file_descriptor_, + /*offset=*/0, file_path_.c_str())); } else { XNNPACK_ABORT_CHECK(!file_path_.empty(), "Path wasn't provided to weight cache provider."); @@ -406,24 +473,22 @@ bool MMapWeightCacheProvider::Load() { file_path_.c_str(), strerror(errno)); return false; } - - XNNPACK_RETURN_CHECK(mmap_handle_.Map(file_path_.c_str())); + XNNPACK_RETURN_CHECK(mmap_handle.Map(file_path_.c_str())); } - ScopeGuard unmap_on_fail([this] { mmap_handle_.UnMap(); }); - - XNNPACK_RETURN_CHECK(mmap_handle_.size() >= sizeof(XNNPackCacheHeader), + XNNPACK_RETURN_CHECK(mmap_handle.size() >= sizeof(XNNPackCacheHeader), "invalid cache file size."); - const XNNPackCacheHeader header = [this] { + const XNNPackCacheHeader header = [&mmap_handle] { XNNPackCacheHeader header; - memcpy(&header, mmap_handle_.data(), sizeof(header)); + memcpy(&header, mmap_handle.data(), sizeof(header)); return header; }(); - XNNPACK_RETURN_CHECK( - header.version == XNNPackCacheHeader::kVersion, - "incompatible header version. Cache needs to be built again."); + XNNPACK_RETURN_CHECK(header.version == XNNPackCacheHeader::kVersion, + "incompatible header version. Got %zd, expected %zd. " + "Cache needs to be built again.", + header.version, XNNPackCacheHeader::kVersion); XNNPACK_RETURN_CHECK(xnn_experimental_check_build_identifier( header.xnnpack_build_identifier, @@ -431,22 +496,22 @@ bool MMapWeightCacheProvider::Load() { "XNNPack weight cache: incompatible XNNPack version. " "Cache needs to be built again."); - XNNPACK_RETURN_CHECK(header.buffer_list_offset < mmap_handle_.size(), + XNNPACK_RETURN_CHECK(header.buffer_list_offset < mmap_handle.size(), "invalid offset for buffer list descriptor."); - XNNPACK_RETURN_CHECK(header.buffer_list_size == - mmap_handle_.size() - header.buffer_list_offset, - "invalid size for buffer list descriptor."); + XNNPACK_RETURN_CHECK( + header.buffer_list_size == mmap_handle.size() - header.buffer_list_offset, + "invalid size for buffer list descriptor."); // Verifiy the flabuffer part of the file. - flatbuffers::Verifier verifier( - mmap_handle_.data() + header.buffer_list_offset, header.buffer_list_size); + flatbuffers::Verifier verifier(mmap_handle.data() + header.buffer_list_offset, + header.buffer_list_size); XNNPACK_RETURN_CHECK(cache::schema::VerifyBufferListBuffer(verifier), "buffer list validation failed."); // Load flatbuffer. const cache::schema::BufferList* buffer_list = cache::schema::GetBufferList( - mmap_handle_.data() + header.buffer_list_offset); + mmap_handle.data() + header.buffer_list_offset); XNNPACK_RETURN_CHECK(buffer_list, "could not get packed weights from flatbuffer."); @@ -459,6 +524,9 @@ bool MMapWeightCacheProvider::Load() { /*weights_id=*/buffer->weights_id(), /*bias_id=*/buffer->bias_id()}, BufferLocation{/*offset=*/buffer->offset(), /*size=*/buffer->size()}); + offset_to_addr_.insert( + {buffer->offset(), + mmap_handle.data() + mmap_buffer_base_offset_ + buffer->offset()}); } } @@ -466,6 +534,87 @@ bool MMapWeightCacheProvider::Load() { return true; } +bool MMapWeightCacheProvider::LoadLastBuildStep() { + if (mmap_handles_.empty()) { + return Load(); + } + + if (builder_.LastBuildStepSize() == 0) { + return true; + } + + const XNNPackCacheHeader header = [this] { + XNNPackCacheHeader header; + memcpy(&header, mmap_handles_.front().data(), sizeof(header)); + return header; + }(); + + // Map last data segment: + // - either resize the last mmap handle; + // - or add a new mapping handle. + { + MMapHandle& last_mmap_handle = mmap_handles_.back(); + const int last_mmap_size = last_mmap_handle.size(); + if (!last_mmap_handle.Resize(last_mmap_size + + builder_.LastBuildStepSize())) { + mmap_handles_.emplace_back(); + if (temporary_file_descriptor_.IsValid()) { + XNNPACK_RETURN_CHECK( + mmap_handles_.back().Map(temporary_file_descriptor_, + /*offset=*/builder_.LastBuildStepStart()), + "could not map last build step"); + } else { + XNNPACK_RETURN_CHECK( + mmap_handles_.back().Map(file_path_.c_str(), + /*offset=*/builder_.LastBuildStepStart()), + "could not map last build step"); + } + } + } + // Read the updated buffer list. + MMapHandle& segment_mmap_handle = mmap_handles_.back(); + const size_t buffer_list_offset = + header.buffer_list_offset - segment_mmap_handle.offset(); + + flatbuffers::Verifier verifier( + segment_mmap_handle.data() + buffer_list_offset, header.buffer_list_size); + XNNPACK_RETURN_CHECK(cache::schema::VerifyBufferListBuffer(verifier), + "buffer list validation failed."); + + const cache::schema::BufferList* buffer_list = cache::schema::GetBufferList( + segment_mmap_handle.data() + buffer_list_offset); + XNNPACK_RETURN_CHECK(buffer_list, + "could not get packed weights from flatbuffer."); + + // Update offset_to_addr_ with new offsets + const ptrdiff_t offset_modifier = + buffer_list->base_offset() - segment_mmap_handle.offset(); + for (const auto* buffer : *(buffer_list->buffers())) { + const size_t offset = buffer->offset(); + if (!offset_to_addr_.count(offset)) { + offset_to_addr_.insert( + {offset, segment_mmap_handle.data() + offset + offset_modifier}); + } + } + return true; +} + +bool MMapWeightCacheProvider::StartBuildStep() { + XNNPACK_RETURN_CHECK(CanStartBuildStep(), + "cannot append data to an existing cache file."); + if (IsBuilding()) { + return true; + } + is_build_step_ = builder_.StartBuildStep(); + return is_build_step_; +} + +bool MMapWeightCacheProvider::StopBuildStep() { + XNNPACK_RETURN_CHECK(builder_.StopBuildStep()); + is_build_step_ = false; + return LoadLastBuildStep(); +} + void MMapWeightCacheProvider::MapTensorIdentifiers( const TfLiteTensor* tensors, const size_t size, const std::unordered_map& tensor_index_to_identifier) { @@ -497,8 +646,8 @@ size_t MMapWeightCacheProvider::LookUp( } void* MMapWeightCacheProvider::ReserveSpace(size_t size) { - XNNPACK_ABORT_CHECK(!IsFinalized(), - "Cannot reserve space in a finalized cache."); + XNNPACK_ABORT_CHECK(IsBuilding(), + "Cannot reserve space in a cache that isn't building."); return builder_.Reserve(size); } @@ -512,8 +661,8 @@ size_t MMapWeightCacheProvider::LookUpOrInsert( return offset_it->second.offset; } - XNNPACK_ABORT_CHECK(!IsFinalized(), - "Cannot insert a buffer in a finalized cache."); + XNNPACK_ABORT_CHECK( + IsBuilding(), "Cannot insert a buffer in a cache that is not building."); const BufferLocation location = builder_.Append(pack_id, ptr, size); XNNPACK_ABORT_CHECK(!location.IsInvalid(), @@ -526,42 +675,19 @@ void* MMapWeightCacheProvider::OffsetToAddr(const size_t offset) { // While the cache is being built, the buffer could grow and need to be // reallocated so we cannot ensure pointer stability. XNNPACK_ABORT_CHECK( - IsFinalized(), - "Cannot get the address of a buffer in a non finalized cache."); - return mmap_handle_.data() + mmap_buffer_base_offset_ + offset; + !IsBuilding(), + "Cannot get the address of a buffer in a cache during a building step."); + return offset_to_addr_[offset]; } void MMapWeightCacheProvider::Release() { buffer_address_to_identifier_.clear(); cache_key_to_offset_.clear(); - mmap_handle_ = MMapHandle(); + mmap_handles_.clear(); mmap_buffer_base_offset_ = 0; builder_ = WeightCacheBuilder(); } -bool MMapWeightCacheProvider::Finalize() { - if (IsFinalized()) { - return true; - } - XNNPACK_RETURN_CHECK(!file_path_.empty(), - "file path wasn't set. Cannot finalize the cache."); - if (IsInMemoryCachePath(file_path_)) { - // Duplicate the file descriptor to avoid loosing the temporary file when - // the builder is reset. - temporary_file_descriptor_ = builder_.GetFileDescriptor().Duplicate(); - } - if (!builder_.Finalize()) { - return false; - } - builder_ = WeightCacheBuilder(); - - return Load(); -} - -bool MMapWeightCacheProvider::IsFinalized() const { - return mmap_handle_.IsMapped(); -} - size_t MMapWeightCacheProvider::look_up( void* context, const xnn_weights_cache_look_up_key* cache_key) { return reinterpret_cast(context)->LookUp(cache_key); @@ -579,7 +705,7 @@ size_t MMapWeightCacheProvider::look_up_or_insert( } bool MMapWeightCacheProvider::is_finalized(void* context) { - return reinterpret_cast(context)->IsFinalized(); + return reinterpret_cast(context)->IsActive(); } void* MMapWeightCacheProvider::offset_to_addr(void* context, size_t offset) { diff --git a/tensorflow/lite/delegates/xnnpack/weight_cache.h b/tensorflow/lite/delegates/xnnpack/weight_cache.h index afdd4d02f068fd..3e2efed46a6c45 100644 --- a/tensorflow/lite/delegates/xnnpack/weight_cache.h +++ b/tensorflow/lite/delegates/xnnpack/weight_cache.h @@ -18,9 +18,11 @@ limitations under the License. #include #include #include +#include #include #include #include +#include #include "xnnpack.h" // from @XNNPACK #include "tensorflow/lite/c/common.h" @@ -111,13 +113,22 @@ class MMapHandle { // Maps the file at the given path. [[nodiscard /*Mapping a file can fail.*/]] - bool Map(const char* path); + bool Map(const char* path, size_t offset = 0); // Maps the fd associated to the file descriptor. // // The debug_path is printed along the error messages. [[nodiscard /*Mapping a file can fail.*/]] - bool Map(const FileDescriptor& fd, const char* debug_path = "unspecified"); + bool Map(const FileDescriptor& fd, size_t offset = 0, + const char* debug_path = "unspecified"); + + // Tries to resize the current mapping. + // + // Only succeeds if the mapping could be resized without being moved. + // + // WARNING: expects `IsMapped()` to be true. + [[nodiscard /*Resizing a file can fail.*/]] + bool Resize(size_t new_size); // Unmaps an existing mapping. void UnMap(); @@ -126,14 +137,16 @@ class MMapHandle { bool IsMapped() const { return data_ != nullptr; } // Returns the mapping buffer. - uint8_t* data() { return data_; } + uint8_t* data() { return data_ + offset_page_adjustment_; } // Returns the mapping buffer. - const uint8_t* data() const { return data_; } + const uint8_t* data() const { return data_ + offset_page_adjustment_; } // Returns the mapping size in bytes. size_t size() const { return size_; } + size_t offset() const { return offset_; } + uint8_t* begin() { return data(); } const uint8_t* begin() const { return data(); } @@ -146,6 +159,8 @@ class MMapHandle { private: size_t size_ = 0; + size_t offset_ = 0; + size_t offset_page_adjustment_ = 0; uint8_t* data_ = nullptr; }; @@ -156,7 +171,7 @@ class MMapHandle { class WeightCacheBuilder { public: WeightCacheBuilder() = default; - ~WeightCacheBuilder(); + ~WeightCacheBuilder() = default; // Non-copyable. WeightCacheBuilder(const WeightCacheBuilder&) = delete; @@ -174,6 +189,12 @@ class WeightCacheBuilder { return fd_.IsValid(); } + // Reopens the given file to add data to it. + // + // This should be only called from the weight cache provider. + [[nodiscard /*Starting a build step may fail.*/]] + bool StartBuildStep(); + // Resets the builder, discarding any data that hasn't been written. void Reset(); @@ -194,12 +215,25 @@ class WeightCacheBuilder { BufferLocation Append(PackIdentifier pack_id, const void* data, uint64_t size); - // Checks whether this builder has data that needs to be written to disk. - bool ShouldFinalize() const; - // Writes the flatbuffer to disk. [[nodiscard /*Writing the weight cache can fail.*/]] - bool Finalize(); + bool StopBuildStep(); + + // Get the offset in the cache file of the data written during the last step. + // + // This includes the buffers that were appended and the whole buffer mapping. + [[nodiscard]] + size_t LastBuildStepStart() const { + return build_segment_start_; + } + + // Get the size of the data written during the last step. + // + // This includes the buffers that were appended and the whole buffer mapping. + [[nodiscard]] + size_t LastBuildStepSize() const { + return build_segment_size_; + } // Returns the file descriptor. const FileDescriptor& GetFileDescriptor() const { return fd_; } @@ -218,15 +252,23 @@ class WeightCacheBuilder { // may be removed at any time. uint8_t* data() const { return data_.get(); } - friend void swap(WeightCacheBuilder& a, WeightCacheBuilder& b); - private: std::unique_ptr data_ = nullptr; cache::schema::BufferListT schema_; size_t capacity_ = 0; + // Size of the data written between StartBuildStep and StopBuildStep. + size_t build_segment_size_ = 0; + // Offset in the cache file when StartBuildStep was called. + size_t build_segment_start_ = 0; + // The call to StopBuildStep may short circuit when nothing was written to the + // cache. To ensure a smooth reloading, we need to ensure that the file header + // is correct. This flag lets us know if that has happened. + bool first_write_done_ = false; // Temporary file descriptor to write the weights to disk immediately. FileDescriptor fd_; std::string file_path_; + + bool is_build_step_ = false; }; // Allows XNNPack to directly load packed weights from disk instead of having to @@ -269,10 +311,25 @@ class MMapWeightCacheProvider { [[nodiscard /*Loading a cache file may fail.*/]] bool Load(const std::string& path); - // Loads the weight cache previouslt set with `SetFilePath`. + // Loads the weight cache previously set with `SetFilePath`. [[nodiscard /*Loading cache data may fail.*/]] bool Load(); + // Checks if the cache is currently being built or if it was loaded from a + // file. + [[nodiscard]] + bool CanStartBuildStep() const { + return building_run_; + }; + + // Prepares to add new data to the cache. + [[nodiscard /*Updating cache data may fail.*/]] + bool StartBuildStep(); + + // Prepares to use data that was added to the cache during a build step. + [[nodiscard /*Updating cache data may fail.*/]] + bool StopBuildStep(); + // Creates the tensor map. void MapTensorIdentifiers( const TfLiteTensor* tensors, size_t size, @@ -315,21 +372,17 @@ class MMapWeightCacheProvider { // Releases the weight cache's memory. void Release(); - // Ensures that the cache is ready. - // - // If the cache file already exists, this is a no-op. Otherwise, this writes - // the file to disk and reloads it. - [[nodiscard /*Writing the cache file may fail.*/]] - bool Finalize(); - - // Checks whether the cache is ready to be used. - bool IsFinalized() const; - // Returns true if any weights have been added to the underlying builder. - bool IsBuilding() const { return !IsFinalized() && !file_path_.empty(); }; + [[nodiscard]] + bool IsBuilding() const { + return is_build_step_; + }; // Returns true if a file is mapped or a file path is set. - bool IsActive() const { return IsFinalized() || !file_path_.empty(); }; + [[nodiscard]] + bool IsActive() const { + return !mmap_handles_.empty() || builder_.IsStarted(); + }; // Returns the cache provider expected by XNNPack. xnn_weights_cache_provider& GetCacheProvider() { return cache_provider_; } @@ -359,6 +412,10 @@ class MMapWeightCacheProvider { // Hashes a cache key to lookup in `cache_key_to_identifier_`. PackIdentifier BuildPackIdentifier(const xnn_weights_cache_look_up_key& key); + // Loads the data written by the last call to `builder_.BuildStepStop()`. + [[nodiscard /*Loading cache data may fail.*/]] + bool LoadLastBuildStep(); + // Cache provider implementation for XNNPack. xnn_weights_cache_provider cache_provider_{ /*context=*/this, @@ -382,7 +439,7 @@ class MMapWeightCacheProvider { cache_key_to_offset_; // MMap allocation handler. - MMapHandle mmap_handle_; + std::vector mmap_handles_; // The offset to the first buffer data in the MMap allocation. size_t mmap_buffer_base_offset_; @@ -393,6 +450,23 @@ class MMapWeightCacheProvider { // Used to build the cache. WeightCacheBuilder builder_; + + // True if the current run is the one building the cache file. + // + // We cannot distinguish between a wrong/outdated cache and one that is not + // fully done. To detect misuse, we still want to raise an error when XNNPack + // tries to append data to an existing file (i.e. when this is `false`). + bool building_run_ = false; + + // True between StartBuildStep and StopBuildStep. + // + // This is used to check whether the builder is active, which means that some + // of the buffers are not available/can't be retrieved. + bool is_build_step_ = false; + + // Stores the loaded buffer addresses corresponding to the given offset in the + // cache file. + std::map offset_to_addr_; }; } // namespace xnnpack diff --git a/tensorflow/lite/delegates/xnnpack/weight_cache_test.cc b/tensorflow/lite/delegates/xnnpack/weight_cache_test.cc index ecbc04dbe40073..ea3ab354fb3a59 100644 --- a/tensorflow/lite/delegates/xnnpack/weight_cache_test.cc +++ b/tensorflow/lite/delegates/xnnpack/weight_cache_test.cc @@ -19,12 +19,14 @@ limitations under the License. #include #include #include +#include #include #include #include #include #include #include +#include #include #include #include @@ -36,6 +38,7 @@ limitations under the License. #include "xnnpack.h" // from @XNNPACK #include "flatbuffers/verifier.h" // from @flatbuffers #include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/delegates/xnnpack/file_util.h" #include "tensorflow/lite/delegates/xnnpack/weight_cache_schema_generated.h" #include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" @@ -52,6 +55,47 @@ namespace { using testing::ElementsAreArray; using testing::Ge; +std::string GenerateRandomString(const size_t size) { + constexpr char chars[] = + "ABCDEFGHIJKLMNOPQRSTUVWXYZ abcdefghijklmnopqrstuvwxyz"; + std::mt19937 rg{std::random_device{}()}; + std::uniform_int_distribution pick(0, + sizeof(chars) - 1); + std::string str(size, 'a'); + std::generate(begin(str), end(str), [&] { return pick(rg); }); + return str; +}; + +template +class LightSpan { + public: + using value_type = T; + + LightSpan(const void* data, const size_t size) + : ptr_(reinterpret_cast(data)), size_(size) {} + + size_t size() const { return size(); } + const T* begin() const { return ptr_; } + const T* end() const { return ptr_ + size_; } + + friend std::ostream& operator<<(std::ostream& os, const LightSpan& s) { + os << '['; + auto it = s.begin(); + if (it != s.end()) { + os << +*it; + } + ++it; + for (; it != s.end(); ++it) { + os << ", " << +*it; + } + return os << ']'; + } + + private: + T* ptr_; + size_t size_; +}; + // Wraps a call to `mkstemp` to create temporary files. class TempFileDesc { public: @@ -184,6 +228,82 @@ TEST(MMapHandleTest, MoveConstructs) { EXPECT_THAT(handle2, ElementsAreArray(payload)); } +TEST(MMapHandleTest, Resize) { + const std::string payload = "This is some data in the file."; + + TempFileDesc tmp_file; + ASSERT_TRUE(tmp_file.IsOpen()); + ASSERT_EQ(write(tmp_file.GetFd(), payload.c_str(), size(payload)), + size(payload)); + tmp_file.Close(); + + MMapHandle handle; + ASSERT_TRUE(handle.Map(tmp_file.GetCPath())); + +#if defined(__linux__) || defined(__ANDROID__) + const size_t kMaxResizeTestCount = 20; + bool was_resized = true; + for (size_t i = 0; i < kMaxResizeTestCount && was_resized; ++i) { + was_resized = handle.Resize(payload.size() * 2); + EXPECT_TRUE(was_resized || errno == ENOMEM); + } +#else + EXPECT_FALSE(handle.Resize(payload.size())); +#endif +} + +TEST(MMapHandleTest, MapWithOffset) { + const std::string payload = "This is some data in the file."; + const std::string payload2 = "Some other data appended to the the offset."; + + TempFileDesc tmp_file; + ASSERT_TRUE(tmp_file.IsOpen()); + ASSERT_EQ(write(tmp_file.GetFd(), payload.c_str(), size(payload)), + size(payload)); + ASSERT_EQ(write(tmp_file.GetFd(), payload2.c_str(), size(payload2)), + size(payload2)); + tmp_file.Close(); + + MMapHandle handle; + ASSERT_TRUE(handle.Map(tmp_file.GetCPath(), /*offset=*/size(payload))); + EXPECT_EQ(handle.size(), size(payload2)); + EXPECT_THAT(std::string((const char*)handle.data(), handle.size()), + testing::StrEq(payload2)); +} + +TEST(MMapHandleTest, ResizeMapWithOffset) { + const std::string payload = "This is some data in the file."; + const std::string payload2 = "Some other data appended to the the offset."; + const std::string payload3 = + "Yet some other data written after the initial mapping."; + + TempFileDesc tmp_file; + ASSERT_TRUE(tmp_file.IsOpen()); + ASSERT_EQ(write(tmp_file.GetFd(), payload.c_str(), size(payload)), + size(payload)); + ASSERT_EQ(write(tmp_file.GetFd(), payload2.c_str(), size(payload2)), + size(payload2)); + + MMapHandle handle; + ASSERT_TRUE(handle.Map(tmp_file.GetCPath(), /*offset=*/size(payload))); + + ASSERT_EQ(write(tmp_file.GetFd(), payload3.c_str(), size(payload3)), + size(payload3)); + tmp_file.Close(); +#if defined(__linux__) || defined(__ANDROID__) + bool was_resized = handle.Resize(payload2.size() + payload3.size()); + if (was_resized) { + EXPECT_THAT(std::string((const char*)handle.data(), handle.size()), + testing::StrEq(payload2 + payload3)); + } else { + GTEST_SKIP() + << "This run did not end up in a resize of the mmaped interval."; + } +#else + GTEST_SKIP() << "Resize is not supported for this build."; +#endif +} + TEST(WeightCacheBuilderTest, ReserveAppendWriteWorks) { using std::size; @@ -193,6 +313,7 @@ TEST(WeightCacheBuilderTest, ReserveAppendWriteWorks) { WeightCacheBuilder builder; const std::string cache_path = testing::TempDir() + "/cache"; ASSERT_TRUE(builder.Start(cache_path.c_str())); + ASSERT_TRUE(builder.StartBuildStep()); const size_t payload_size = size(payload); void* buffer = builder.Reserve(payload_size); @@ -201,9 +322,8 @@ TEST(WeightCacheBuilderTest, ReserveAppendWriteWorks) { EXPECT_EQ(loc.size, payload_size); EXPECT_GE(builder.capacity(), payload_size); - EXPECT_TRUE(builder.ShouldFinalize()); - ASSERT_TRUE(builder.Finalize()); + ASSERT_TRUE(builder.StopBuildStep()); MMapHandle handle; ASSERT_TRUE(handle.Map(cache_path.c_str())); @@ -258,14 +378,14 @@ TEST(WeightCacheBuilderTest, AppendWithoutReserveWriteWorks) { const std::string cache_path = testing::TempDir() + "/cache"; WeightCacheBuilder builder; ASSERT_TRUE(builder.Start(cache_path.c_str())); + ASSERT_TRUE(builder.StartBuildStep()); const size_t payload_size = size(payload); auto loc = builder.Append(dummy_id, payload.c_str(), payload_size); EXPECT_EQ(loc.size, payload_size); - EXPECT_TRUE(builder.ShouldFinalize()); - ASSERT_TRUE(builder.Finalize()); + ASSERT_TRUE(builder.StopBuildStep()); MMapHandle handle; ASSERT_TRUE(handle.Map(cache_path.c_str())); @@ -341,6 +461,127 @@ TEST(WeightCacheBuilderTest, InMemoryCacheTriggeredByCorrectPrefix) { } } +TEST(WeightCacheBuilderTest, MultipleStepBuild) { + using std::size; + + const std::string payload1 = "This is some data in the file."; + const PackIdentifier dummy_id1{1, 2, 3}; + const std::string payload2 = "Other data in the file."; + const PackIdentifier dummy_id2{2, 3, 4}; + const std::string payload3 = + GenerateRandomString(/*10 MiB*/ 10 * 1024 * 1024); + const PackIdentifier dummy_id3{3, 4, 5}; + + TempFileDesc tmp_file{TempFileDesc::kAutoClose}; + + WeightCacheBuilder builder; + ASSERT_TRUE(builder.Start(tmp_file.GetCPath())); + ASSERT_TRUE(builder.StartBuildStep()); + + { + const size_t payload_size = size(payload1); + void* buffer = builder.Reserve(payload_size); + std::memcpy(buffer, payload1.c_str(), payload_size); + const auto loc = builder.Append(dummy_id1, buffer, payload_size); + EXPECT_EQ(loc.size, payload_size); + EXPECT_GE(builder.capacity(), payload_size); + } + { + const size_t payload_size = size(payload3); + void* buffer = builder.Reserve(payload_size); + std::memcpy(buffer, payload3.c_str(), payload_size); + const auto loc = builder.Append(dummy_id3, buffer, payload_size); + (void)loc; + } + + ASSERT_TRUE(builder.StopBuildStep()); + + MMapHandle handle; + ASSERT_TRUE(handle.Map(tmp_file.GetCPath())); + + ASSERT_TRUE(builder.StartBuildStep()); + { + const size_t payload_size = size(payload2); + void* buffer = builder.Reserve(payload_size); + std::memcpy(buffer, payload2.c_str(), payload_size); + const auto loc = builder.Append(dummy_id2, buffer, payload_size); + EXPECT_EQ(loc.size, payload_size); + EXPECT_GE(builder.capacity(), payload_size); + } + + ASSERT_TRUE(builder.StopBuildStep()); + + ASSERT_TRUE(handle.Map(tmp_file.GetCPath())); + + const XNNPackCacheHeader& header = + *reinterpret_cast(handle.data()); + + ASSERT_EQ(header.version, XNNPackCacheHeader::kVersion); + ASSERT_NE(header.buffer_list_offset, 0); + ASSERT_NE(header.buffer_list_size, 0); + ASSERT_LE(header.buffer_list_offset + header.buffer_list_size, handle.size()); + + const cache::schema::BufferList* const packed_weights = + cache::schema::GetBufferList(handle.data() + header.buffer_list_offset); + + ASSERT_NE(packed_weights, nullptr); + ASSERT_NE(packed_weights->buffers(), nullptr); + ASSERT_EQ(packed_weights->buffers()->size(), 3); + // Payload 1. + const auto* buffer1 = packed_weights->buffers()->Get(0); + ASSERT_NE(buffer1, nullptr); + ASSERT_EQ(buffer1->size(), size(payload1)); + ASSERT_EQ(buffer1->packing_algorithm_id(), dummy_id1.pack_algorithm_id); + ASSERT_EQ(buffer1->weights_id(), dummy_id1.weights_id); + ASSERT_EQ(buffer1->bias_id(), dummy_id1.bias_id); + + // Payload 3. + const auto* buffer3 = packed_weights->buffers()->Get(1); + ASSERT_NE(buffer3, nullptr); + ASSERT_EQ(buffer3->size(), size(payload3)); + ASSERT_EQ(buffer3->packing_algorithm_id(), dummy_id3.pack_algorithm_id); + ASSERT_EQ(buffer3->weights_id(), dummy_id3.weights_id); + ASSERT_EQ(buffer3->bias_id(), dummy_id3.bias_id); + + // Payload 2. + const auto* buffer2 = packed_weights->buffers()->Get(2); + ASSERT_NE(buffer2, nullptr); + ASSERT_EQ(buffer2->size(), size(payload2)); + ASSERT_EQ(buffer2->packing_algorithm_id(), dummy_id2.pack_algorithm_id); + ASSERT_EQ(buffer2->weights_id(), dummy_id2.weights_id); + ASSERT_EQ(buffer2->bias_id(), dummy_id2.bias_id); + + flatbuffers::Verifier verifier(handle.data() + header.buffer_list_offset, + header.buffer_list_size); + EXPECT_TRUE(cache::schema::VerifyBufferListBuffer(verifier)); + + // Payload 1. + ASSERT_LE(packed_weights->base_offset() + buffer1->offset(), size(handle)); + ASSERT_LE(packed_weights->base_offset() + buffer1->offset() + buffer1->size(), + size(handle)); + + // Payload 2. + ASSERT_LE(packed_weights->base_offset() + buffer2->offset(), size(handle)); + ASSERT_LE(packed_weights->base_offset() + buffer2->offset() + buffer2->size(), + size(handle)); + + // Payload 3. + ASSERT_LE(packed_weights->base_offset() + buffer3->offset(), size(handle)); + ASSERT_LE(packed_weights->base_offset() + buffer3->offset() + buffer3->size(), + size(handle)); + + auto GetBufferData = [&handle, &packed_weights](const auto* buffer) { + return std::tuple( + reinterpret_cast( + handle.data() + packed_weights->base_offset() + buffer->offset()), + buffer->size()); + }; + + EXPECT_THAT(GetBufferData(buffer1), ElementsAreArray(payload1)); + EXPECT_THAT(GetBufferData(buffer2), ElementsAreArray(payload2)); + EXPECT_THAT(GetBufferData(buffer3), ElementsAreArray(payload3)); +} + struct FakeContext { // Adds a new tensor and it's backing buffer to the context. // @@ -447,12 +688,12 @@ struct BuildMMapWeightCacheProviderTest : testing::Test { ctx.FinalizeTensors(); cache_provider.MapTensorIdentifiers(ctx.tensors.data(), ctx.tensors.size(), ctx.tensor_buffer_identifiers); - const std::string cache_path = testing::TempDir() + "/cache"; - ASSERT_TRUE(cache_provider.StartBuild(cache_path.c_str())); + ASSERT_TRUE(cache_provider.StartBuild(tmp_file.GetCPath())); } FakeContext ctx; MMapWeightCacheProvider cache_provider; + TempFileDesc tmp_file{TempFileDesc::kAutoClose}; }; TEST_F(BuildMMapWeightCacheProviderTest, LookUpFailsIfKeyDoesntMatch) { @@ -462,8 +703,10 @@ TEST_F(BuildMMapWeightCacheProviderTest, LookUpFailsIfKeyDoesntMatch) { TEST_F(BuildMMapWeightCacheProviderTest, LookUpSucceeds) { enum { kWeightIndex, kBiasIndex }; + ASSERT_TRUE(cache_provider.StartBuildStep()); const auto pack_id = ctx.PackTensors(&cache_provider.GetCacheProvider(), kAlgoSeed1, kWeightIndex, kBiasIndex); + EXPECT_TRUE(cache_provider.StopBuildStep()); const xnn_weights_cache_look_up_key look_up_key = ctx.LookUpKey(kAlgoSeed1, kWeightIndex, kBiasIndex); @@ -474,10 +717,12 @@ TEST_F(BuildMMapWeightCacheProviderTest, LookUpSucceeds) { TEST_F(BuildMMapWeightCacheProviderTest, DifferentAlgoSeedsSameTensorsDontConflict) { enum { kWeightIndex, kBiasIndex }; + ASSERT_TRUE(cache_provider.StartBuildStep()); const auto pack_id_1 = ctx.PackTensors(&cache_provider.GetCacheProvider(), kAlgoSeed1, kWeightIndex, kBiasIndex); const auto pack_id_2 = ctx.PackTensors(&cache_provider.GetCacheProvider(), kAlgoSeed2, kWeightIndex, kBiasIndex); + EXPECT_TRUE(cache_provider.StopBuildStep()); const xnn_weights_cache_look_up_key look_up_key_1 = ctx.LookUpKey(kAlgoSeed1, kWeightIndex, kBiasIndex); @@ -495,6 +740,7 @@ TEST_F(BuildMMapWeightCacheProviderTest, TEST_F(BuildMMapWeightCacheProviderTest, SameAlgoSeedDifferentTensorsDontConflict) { enum { kWeightIndex1, kWeightIndex2, kBiasIndex1, kBiasIndex2 }; + ASSERT_TRUE(cache_provider.StartBuildStep()); const auto pack_id_1 = ctx.PackTensors(&cache_provider.GetCacheProvider(), kAlgoSeed1, kWeightIndex1, kBiasIndex1); @@ -507,6 +753,7 @@ TEST_F(BuildMMapWeightCacheProviderTest, const auto pack_id_4 = ctx.PackTensors(&cache_provider.GetCacheProvider(), kAlgoSeed1, kWeightIndex2, kBiasIndex2); + EXPECT_TRUE(cache_provider.StopBuildStep()); const xnn_weights_cache_look_up_key look_up_key_1 = ctx.LookUpKey(kAlgoSeed1, kWeightIndex1, kBiasIndex1); @@ -540,10 +787,9 @@ TEST_F(BuildMMapWeightCacheProviderTest, cache_provider.LookUp(&look_up_key_4)); } -TEST_F(BuildMMapWeightCacheProviderTest, FinalizeWorks) { +TEST_F(BuildMMapWeightCacheProviderTest, BuildStepSequenceWorks) { enum { kWeightIndex1, kBiasIndex, kWeightIndex2 }; - TempFileDesc tmp_file(TempFileDesc::kAutoClose); - ASSERT_TRUE(cache_provider.StartBuild(tmp_file.GetCPath())); + ASSERT_TRUE(cache_provider.StartBuildStep()); ctx.PackTensors(&cache_provider.GetCacheProvider(), kAlgoSeed1, kWeightIndex1, kBiasIndex); @@ -552,9 +798,10 @@ TEST_F(BuildMMapWeightCacheProviderTest, FinalizeWorks) { EXPECT_TRUE(cache_provider.IsActive()); EXPECT_TRUE(cache_provider.IsBuilding()); - ASSERT_TRUE(cache_provider.Finalize()); + ASSERT_TRUE(cache_provider.StopBuildStep()); - ASSERT_TRUE(cache_provider.IsFinalized()); + ASSERT_TRUE(cache_provider.IsActive()); + EXPECT_FALSE(cache_provider.IsBuilding()); } struct LoadMMapWeightCacheProviderTest : BuildMMapWeightCacheProviderTest { @@ -562,15 +809,14 @@ struct LoadMMapWeightCacheProviderTest : BuildMMapWeightCacheProviderTest { void SetUp() override { BuildMMapWeightCacheProviderTest::SetUp(); - ASSERT_TRUE(cache_provider.StartBuild(tmp_file.GetCPath())); + ASSERT_TRUE(cache_provider.StartBuildStep()); pack_id_1 = ctx.PackTensors(&cache_provider.GetCacheProvider(), kAlgoSeed1, kWeightIndex1, kBiasIndex); pack_id_2 = ctx.PackTensors(&cache_provider.GetCacheProvider(), kAlgoSeed2, kWeightIndex2); - ASSERT_TRUE(cache_provider.Finalize()); - ASSERT_TRUE(cache_provider.IsFinalized()); + ASSERT_TRUE(cache_provider.StopBuildStep()); } xnn_weights_cache_look_up_key LookUpKey1() const { @@ -581,7 +827,6 @@ struct LoadMMapWeightCacheProviderTest : BuildMMapWeightCacheProviderTest { return ctx.LookUpKey(kAlgoSeed2, kWeightIndex2); } - TempFileDesc tmp_file; PackIdentifier pack_id_1; PackIdentifier pack_id_2; }; @@ -591,36 +836,6 @@ TEST_F(LoadMMapWeightCacheProviderTest, LookUpFailsIfKeyDoesntMatch) { EXPECT_EQ(cache_provider.LookUp(&look_up_key), SIZE_MAX); } -template -class LightSpan { - public: - using value_type = T; - - LightSpan(const void* data, const size_t size) - : ptr_(reinterpret_cast(data)), size_(size) {} - - size_t size() const { return size(); } - const T* begin() const { return ptr_; } - const T* end() const { return ptr_ + size_; } - - friend std::ostream& operator<<(std::ostream& os, const LightSpan& s) { - os << '['; - auto it = s.begin(); - if (it != s.end()) { - os << +*it; - } - ++it; - for (; it != s.end(); ++it) { - os << ", " << +*it; - } - return os << ']'; - } - - private: - T* ptr_; - size_t size_; -}; - TEST_F(LoadMMapWeightCacheProviderTest, LookUpSucceeds) { const auto& reference_1 = ctx.packed_buffers.find(pack_id_1)->second; const auto& reference_2 = ctx.packed_buffers.find(pack_id_2)->second; @@ -652,6 +867,8 @@ TEST(MMapWeightCacheProviderTest, XnnpackCApiJourney) { const int32_t fake_packing_algo_seed = 0xBA0BAB; const char packed_data_ref_1[] = "abcdefghij"; const char packed_data_ref_2[] = "klmnopqr"; + const std::string packed_data_ref_3 = + GenerateRandomString(/*10 MiB*/ 10 * 1024 * 1024); auto bytes = [](const auto& array) { return size(array) * sizeof(array[0]); }; constexpr int kBufferCount = 10; @@ -660,6 +877,8 @@ TEST(MMapWeightCacheProviderTest, XnnpackCApiJourney) { char fake_buffer_pointer[kBufferCount] = {0}; { // Build and reload scenario. + // This isn't factored between the two scenarios. When reloading the cache + // in another process, the buffer addresses will have changed. TfLiteTensor tensors[kBufferCount]; std::unordered_map tensor_buffer_identifiers; for (int i = 0; i < kBufferCount; ++i) { @@ -669,6 +888,8 @@ TEST(MMapWeightCacheProviderTest, XnnpackCApiJourney) { MMapWeightCacheProvider cache_provider; ASSERT_TRUE(cache_provider.StartBuild(temp_fd.GetCPath())); + // 1st build step. + ASSERT_TRUE(cache_provider.StartBuildStep()); xnn_weights_cache_t cache = &cache_provider.GetCacheProvider(); cache_provider.MapTensorIdentifiers(tensors, size(tensors), @@ -679,6 +900,11 @@ TEST(MMapWeightCacheProviderTest, XnnpackCApiJourney) { .kernel = tensors[0].data.data, .bias = tensors[1].data.data}; + const xnn_weights_cache_look_up_key look_up_key_3{ + .seed = fake_packing_algo_seed, + .kernel = tensors[3].data.data, + .bias = tensors[4].data.data}; + // Lookup non-packed tensor. ASSERT_EQ(cache->look_up(cache, &look_up_key_1), SIZE_MAX); // Reserve space, write data and add packed data. @@ -689,25 +915,50 @@ TEST(MMapWeightCacheProviderTest, XnnpackCApiJourney) { const size_t build_offset_1 = cache->look_up_or_insert( cache, &look_up_key_1, reserved_ptr, bytes(packed_data_ref_1)); - // Check that a second insertion with the same key returns the same offset. + // Check that a second insertion with the same key returns the same + // offset. const size_t build_offset_redundant = cache->look_up_or_insert( cache, &look_up_key_1, reserved_ptr, bytes(packed_data_ref_1)); EXPECT_EQ(build_offset_1, build_offset_redundant); + // Lookup and insert other tensor + ASSERT_EQ(cache->look_up(cache, &look_up_key_3), SIZE_MAX); + void* const reserved_ptr_3 = + cache->reserve_space(cache, bytes(packed_data_ref_3)); + ASSERT_NE(reserved_ptr_3, nullptr); + std::memcpy(reserved_ptr_3, packed_data_ref_3.data(), + bytes(packed_data_ref_3)); + const size_t build_offset_3 = cache->look_up_or_insert( + cache, &look_up_key_3, reserved_ptr_3, bytes(packed_data_ref_3)); + + ASSERT_TRUE(cache_provider.StopBuildStep()); + // Lookup newly packed tensor. ASSERT_EQ(cache->look_up(cache, &look_up_key_1), build_offset_1); + ASSERT_EQ(cache->look_up(cache, &look_up_key_3), build_offset_3); + + // 2nd build step. + ASSERT_TRUE(cache_provider.StartBuildStep()); // Add a tensor without reserving before. const xnn_weights_cache_look_up_key look_up_key_2{ .seed = fake_packing_algo_seed, .kernel = tensors[2].data.data, .bias = tensors[3].data.data}; + const size_t build_offset_2 = cache->look_up_or_insert( cache, &look_up_key_2, (void*)packed_data_ref_2, bytes(packed_data_ref_2)); + // Buffer inserted during build step 1 can be looked up. + EXPECT_EQ(cache->look_up(cache, &look_up_key_3), build_offset_3); + // Reinsert buffer inserted during build step 1 should be a no-op. + EXPECT_EQ(cache->look_up_or_insert(cache, &look_up_key_3, reserved_ptr_3, + bytes(packed_data_ref_3)), + build_offset_3); + // Save the cache to disk and reload. - ASSERT_TRUE(cache_provider.Finalize()); + ASSERT_TRUE(cache_provider.StopBuildStep()); ASSERT_TRUE(cache->is_finalized(cache)); @@ -730,6 +981,16 @@ TEST(MMapWeightCacheProviderTest, XnnpackCApiJourney) { EXPECT_THAT( LightSpan(loaded_packed_data_2, size(packed_data_ref_2)), ElementsAreArray(packed_data_ref_2)); + + const size_t reload_offset_3 = cache->look_up(cache, &look_up_key_3); + ASSERT_EQ(reload_offset_3, build_offset_3); + + const void* const loaded_packed_data_3 = + cache->offset_to_addr(cache, reload_offset_3); + ASSERT_NE(loaded_packed_data_3, nullptr); + EXPECT_THAT( + LightSpan(loaded_packed_data_3, size(packed_data_ref_3)), + ElementsAreArray(packed_data_ref_3)); } { // Load existing cache scenario. @@ -757,6 +1018,11 @@ TEST(MMapWeightCacheProviderTest, XnnpackCApiJourney) { .kernel = tensors[2].data.data, .bias = tensors[3].data.data}; + const xnn_weights_cache_look_up_key look_up_key_3{ + .seed = fake_packing_algo_seed, + .kernel = tensors[3].data.data, + .bias = tensors[4].data.data}; + ASSERT_TRUE(cache->is_finalized(cache)); const size_t offset_1 = cache->look_up(cache, &look_up_key_1); @@ -775,6 +1041,14 @@ TEST(MMapWeightCacheProviderTest, XnnpackCApiJourney) { EXPECT_THAT( LightSpan(loaded_packed_data_2, size(packed_data_ref_2)), ElementsAreArray(packed_data_ref_2)); + + const size_t offset_3 = cache->look_up(cache, &look_up_key_3); + const void* const loaded_packed_data_3 = + cache->offset_to_addr(cache, offset_3); + ASSERT_NE(loaded_packed_data_3, nullptr); + EXPECT_THAT( + LightSpan(loaded_packed_data_3, size(packed_data_ref_3)), + ElementsAreArray(packed_data_ref_3)); } } diff --git a/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc b/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc index 99e2605779f8bd..376d4a96ca8ad0 100644 --- a/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc +++ b/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc @@ -1149,9 +1149,26 @@ class Subgraph { if (context->profiler) { flags |= XNN_FLAG_BASIC_PROFILING; } + + if (delegate.weight_cache_provider_.IsActive() && + delegate.weight_cache_provider_.CanStartBuildStep()) { + if (!delegate.weight_cache_provider_.StartBuildStep()) { + TF_LITE_KERNEL_LOG( + context, "XNNPack delegate failed to start cache build step."); + return nullptr; + } + } status = xnn_create_runtime_v4(subgraph.get(), delegate.weights_cache(), delegate.workspace(), delegate.threadpool(), flags, &runtime_ptr); + if (delegate.weight_cache_provider_.IsActive() && + delegate.weight_cache_provider_.CanStartBuildStep()) { + if (!delegate.weight_cache_provider_.StopBuildStep()) { + TF_LITE_KERNEL_LOG(context, + "XNNPack delegate failed to stop cache build step."); + return nullptr; + } + } if (status != xnn_status_success) { TF_LITE_KERNEL_LOG(context, "failed to create XNNPACK runtime"); return nullptr; @@ -1165,17 +1182,6 @@ class Subgraph { bool enable_subgraph_reshaping, Delegate* delegate) { std::lock_guard lock(delegate->workspace_mutex_); - // The weights cache needs to be finalized only once. Prepare will be called - // for each partition after all the partitions have been created (therefore - // all the weights are known and have been packed). - if (delegate->weight_cache_provider_.IsActive()) { - if (!delegate->weight_cache_provider_.Finalize()) { - TF_LITE_KERNEL_LOG(context, - "XNNPack delegate failed to finalize cache."); - return kTfLiteError; - } - } - if (enable_subgraph_reshaping) { xnn_status status = xnn_status_invalid_state; for (int i = 0; i < inputs_.size(); ++i) { @@ -1232,6 +1238,7 @@ class Subgraph { TfLiteStatus Invoke(TfLiteContext* context, bool enable_subgraph_reshaping, Delegate* delegate) { std::lock_guard lock(delegate->workspace_mutex_); + bool any_pointers_changed = false; for (std::pair io_info : externals_) { const TfLiteTensor& tensor = context->tensors[io_info.first]; From c3dc420707d1afa9e19db9aa8e9659babe3c7bc6 Mon Sep 17 00:00:00 2001 From: Hyeontaek Lim Date: Fri, 20 Sep 2024 11:33:56 -0700 Subject: [PATCH 074/483] [PjRt-IFRT] Migrate the include file for IFRT/XLA DType conversion functions This change updates the header include file from `pjrt_array.h` to `pjrt_dtype.h` for IFRT/XLA DType conversion functions. PiperOrigin-RevId: 676914070 --- tensorflow/core/tfrt/ifrt/BUILD | 3 +-- tensorflow/core/tfrt/ifrt/ifrt_tensor_utils.cc | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/tensorflow/core/tfrt/ifrt/BUILD b/tensorflow/core/tfrt/ifrt/BUILD index 69e129b4353c07..58ff8e28b3a114 100644 --- a/tensorflow/core/tfrt/ifrt/BUILD +++ b/tensorflow/core/tfrt/ifrt/BUILD @@ -276,10 +276,9 @@ cc_library( "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:statusor", - "@local_xla//xla:shape_util", "@local_xla//xla:xla_data_proto_cc", "@local_xla//xla/python/ifrt", - "@local_xla//xla/python/pjrt_ifrt", + "@local_xla//xla/python/pjrt_ifrt:pjrt_dtype", ], ) diff --git a/tensorflow/core/tfrt/ifrt/ifrt_tensor_utils.cc b/tensorflow/core/tfrt/ifrt/ifrt_tensor_utils.cc index b920efe932bd5c..f3ef48f0a2e3d6 100644 --- a/tensorflow/core/tfrt/ifrt/ifrt_tensor_utils.cc +++ b/tensorflow/core/tfrt/ifrt/ifrt_tensor_utils.cc @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/type_util.h" #include "xla/python/ifrt/dtype.h" #include "xla/python/ifrt/shape.h" -#include "xla/python/pjrt_ifrt/pjrt_array.h" +#include "xla/python/pjrt_ifrt/pjrt_dtype.h" #include "xla/xla_data.pb.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_shape.pb.h" From 9e806cc3a117af59133b82cab337c15f243123ab Mon Sep 17 00:00:00 2001 From: David Dunleavy Date: Fri, 20 Sep 2024 12:08:23 -0700 Subject: [PATCH 075/483] Move `tsl/lib` to `xla/tsl/lib` PiperOrigin-RevId: 676926239 --- tensorflow/compiler/mlir/lite/debug/BUILD | 2 +- tensorflow/compiler/mlir/lite/debug/debug.cc | 2 +- tensorflow/compiler/mlir/tensorflow/BUILD | 4 +- .../mlir/tensorflow/utils/dump_mlir_util.cc | 2 +- .../tensorflow/utils/xla_sharding_util.cc | 2 +- tensorflow/core/BUILD | 18 +- .../gpu/gpu_bfc_allocator_test.cc | 4 +- .../gpu/gpu_debug_allocator_test.cc | 2 +- .../snapshot/distributed_snapshot_test.cc | 2 +- .../data/service/snapshot/file_utils_test.cc | 2 +- .../snapshot/parallel_tfrecord_writer_test.cc | 2 +- .../snapshot/prefetched_split_provider.cc | 2 +- .../prefetched_split_provider_test.cc | 2 +- .../data/service/snapshot/snapshot_manager.cc | 2 +- .../snapshot/snapshot_split_provider_test.cc | 2 +- .../snapshot_stream_writer_checkpoint_test.cc | 2 +- .../snapshot/snapshot_stream_writer_test.cc | 2 +- tensorflow/core/data/snapshot_utils.cc | 4 +- tensorflow/core/data/snapshot_utils.h | 2 +- tensorflow/core/framework/BUILD | 2 +- .../data/experimental/distributed_save_op.cc | 2 +- .../core/kernels/stochastic_cast_op_test.cc | 2 +- tensorflow/core/lib/gtl/BUILD | 30 +- tensorflow/core/lib/gtl/compactptrset.h | 2 +- tensorflow/core/lib/gtl/flatmap.h | 2 +- tensorflow/core/lib/gtl/flatrep.h | 2 +- tensorflow/core/lib/gtl/flatset.h | 2 +- tensorflow/core/lib/gtl/inlined_vector.h | 2 +- tensorflow/core/lib/gtl/int_type.h | 2 +- tensorflow/core/lib/gtl/iterator_range.h | 2 +- tensorflow/core/lib/gtl/map_util.h | 2 +- tensorflow/core/lib/gtl/subtle/BUILD | 2 +- tensorflow/core/lib/gtl/subtle/map_traits.h | 2 +- tensorflow/core/lib/hash/BUILD | 6 +- tensorflow/core/lib/hash/crc32c.h | 2 +- tensorflow/core/lib/io/BUILD | 42 +-- tensorflow/core/lib/io/block.h | 2 +- tensorflow/core/lib/io/block_builder.h | 2 +- tensorflow/core/lib/io/buffered_inputstream.h | 2 +- tensorflow/core/lib/io/cache.h | 2 +- tensorflow/core/lib/io/compression.h | 2 +- tensorflow/core/lib/io/format.h | 2 +- tensorflow/core/lib/io/inputbuffer.h | 2 +- .../core/lib/io/inputstream_interface.h | 2 +- tensorflow/core/lib/io/iterator.h | 2 +- tensorflow/core/lib/io/proto_encode_helper.h | 2 +- tensorflow/core/lib/io/random_inputstream.h | 2 +- tensorflow/core/lib/io/record_reader.h | 2 +- tensorflow/core/lib/io/record_writer.h | 2 +- tensorflow/core/lib/io/table.h | 2 +- tensorflow/core/lib/io/table_builder.h | 2 +- tensorflow/core/lib/io/table_options.h | 2 +- tensorflow/core/lib/io/two_level_iterator.h | 2 +- .../core/lib/io/zlib_compression_options.h | 2 +- tensorflow/core/lib/io/zlib_inputstream.h | 2 +- tensorflow/core/lib/io/zlib_outputbuffer.h | 2 +- tensorflow/core/lib/math/BUILD | 2 +- tensorflow/core/lib/math/math_util.h | 2 +- tensorflow/core/lib/random/BUILD | 18 +- .../core/lib/random/distribution_sampler.h | 2 +- .../core/lib/random/exact_uniform_int.h | 2 +- tensorflow/core/lib/random/philox_random.h | 2 +- .../core/lib/random/random_distributions.h | 2 +- .../lib/random/random_distributions_utils.h | 2 +- tensorflow/core/lib/random/simple_philox.h | 2 +- tensorflow/core/lib/random/weighted_picker.h | 2 +- .../core/profiler/convert/trace_viewer/BUILD | 2 +- .../convert/trace_viewer/trace_events.cc | 8 +- .../convert/trace_viewer/trace_events.h | 2 +- tensorflow/core/util/tensor_bundle/BUILD | 2 +- .../core/util/tensor_bundle/tensor_bundle.cc | 2 +- .../core/util/tensor_bundle/tensor_bundle.h | 2 +- tensorflow/lite/kernels/BUILD | 4 +- tensorflow/lite/kernels/random_ops.cc | 4 +- .../third_party/tsl/tsl/platform/cloud/BUILD | 2 +- .../tsl/platform/cloud/curl_http_request.cc | 2 +- .../third_party/tsl/tsl/profiler/utils/BUILD | 4 +- .../tsl/tsl/profiler/utils/group_events.cc | 2 +- .../tsl/tsl/profiler/utils/xplane_schema.cc | 2 +- third_party/xla/xla/BUILD | 12 +- third_party/xla/xla/client/BUILD | 2 +- third_party/xla/xla/client/padding.cc | 2 +- third_party/xla/xla/ffi/BUILD | 2 +- third_party/xla/xla/ffi/type_id_registry.h | 2 +- third_party/xla/xla/hlo/ir/BUILD | 4 +- third_party/xla/xla/hlo/ir/hlo_computation.cc | 2 +- third_party/xla/xla/hlo/ir/hlo_computation.h | 2 +- third_party/xla/xla/hlo/ir/hlo_instruction.cc | 4 +- third_party/xla/xla/hlo/ir/hlo_instruction.h | 2 +- .../xla/xla/hlo/ir/hlo_instructions.cc | 2 +- third_party/xla/xla/hlo/ir/hlo_instructions.h | 2 +- third_party/xla/xla/hlo/ir/hlo_module.cc | 2 +- third_party/xla/xla/hlo/ir/hlo_module.h | 2 +- third_party/xla/xla/hlo/ir/hlo_schedule.cc | 2 +- third_party/xla/xla/iterator_util.h | 2 +- third_party/xla/xla/pjrt/BUILD | 2 +- third_party/xla/xla/pjrt/pjrt_common.h | 2 +- third_party/xla/xla/python/ifrt/BUILD | 2 +- third_party/xla/xla/python/ifrt/device.h | 2 +- third_party/xla/xla/python/ifrt/device_list.h | 1 + third_party/xla/xla/reference_util.cc | 2 +- third_party/xla/xla/service/BUILD | 24 +- .../xla/xla/service/buffer_value_containers.h | 2 +- third_party/xla/xla/service/cpu/BUILD | 2 +- third_party/xla/xla/service/cpu/ir_emitter.cc | 2 +- third_party/xla/xla/service/dump.cc | 4 +- .../xla/xla/service/global_device_id.h | 2 +- third_party/xla/xla/service/gpu/kernels/BUILD | 2 +- .../xla/service/gpu/kernels/topk_kernel.cu.h | 2 +- third_party/xla/xla/service/gpu/model/BUILD | 2 +- .../service/gpu/model/tiled_hlo_computation.h | 2 +- third_party/xla/xla/service/gpu/runtime/BUILD | 4 +- .../xla/service/gpu/runtime/nccl_clique_key.h | 2 +- .../xla/xla/service/gpu/runtime/thunk.h | 2 +- .../xla/xla/service/hlo_cost_analysis.cc | 2 +- .../xla/xla/service/hlo_graph_dumper.cc | 6 +- third_party/xla/xla/service/hlo_parser.cc | 2 +- third_party/xla/xla/service/logical_buffer.h | 2 +- .../xla/service/tuple_points_to_analysis.h | 2 +- third_party/xla/xla/shape_tree.h | 2 +- third_party/xla/xla/stream_executor/BUILD | 6 +- .../xla/xla/stream_executor/command_buffer.h | 2 +- .../xla/stream_executor/device_description.cc | 2 +- third_party/xla/xla/stream_executor/gpu/BUILD | 2 +- .../stream_executor/gpu/redzone_allocator.cc | 2 +- third_party/xla/xla/tests/BUILD | 2 +- .../xla/xla/tests/batch_normalization_test.cc | 2 +- third_party/xla/xla/text_literal_reader.cc | 4 +- third_party/xla/xla/tools/BUILD | 4 +- .../xla/tools/hex_floats_to_packed_literal.cc | 4 +- .../xla/xla/tsl/distributed_runtime/rpc/BUILD | 2 +- .../distributed_runtime/rpc/grpc_channel.cc | 2 +- third_party/xla/xla/tsl/framework/BUILD | 12 +- .../xla/xla/tsl/framework/cancellation.h | 2 +- third_party/xla/xla/tsl/framework/device_id.h | 2 +- .../xla/tsl/framework/tracking_allocator.h | 2 +- .../xla/xla/tsl/lib/core/bitmap_test.cc | 2 +- .../tsl => xla}/tsl/lib/gtl/BUILD | 64 ++-- .../tsl => xla}/tsl/lib/gtl/compactptrset.h | 8 +- .../tsl/lib/gtl/compactptrset_test.cc | 2 +- .../tsl => xla}/tsl/lib/gtl/flatmap.h | 8 +- .../tsl => xla}/tsl/lib/gtl/flatmap_test.cc | 2 +- .../tsl => xla}/tsl/lib/gtl/flatrep.h | 6 +- .../tsl => xla}/tsl/lib/gtl/flatset.h | 8 +- .../tsl => xla}/tsl/lib/gtl/flatset_test.cc | 2 +- .../tsl => xla}/tsl/lib/gtl/inlined_vector.h | 6 +- .../tsl => xla}/tsl/lib/gtl/int_type.h | 6 +- .../tsl => xla}/tsl/lib/gtl/int_type_test.cc | 2 +- .../tsl => xla}/tsl/lib/gtl/iterator_range.h | 6 +- .../tsl/lib/gtl/iterator_range_test.cc | 2 +- .../tsl => xla}/tsl/lib/gtl/map_util.h | 8 +- .../tsl => xla}/tsl/lib/gtl/map_util_test.cc | 2 +- .../tsl => xla}/tsl/lib/gtl/subtle/BUILD | 6 +- .../tsl/lib/gtl/subtle/map_traits.h | 6 +- .../tsl => xla}/tsl/lib/hash/BUILD | 32 +- .../tsl => xla}/tsl/lib/hash/crc32c.cc | 2 +- .../tsl => xla}/tsl/lib/hash/crc32c.h | 6 +- .../tsl => xla}/tsl/lib/hash/crc32c_test.cc | 2 +- .../{third_party/tsl => xla}/tsl/lib/io/BUILD | 318 +++++++++--------- .../tsl => xla}/tsl/lib/io/block.cc | 4 +- .../tsl => xla}/tsl/lib/io/block.h | 8 +- .../tsl => xla}/tsl/lib/io/block_builder.cc | 4 +- .../tsl => xla}/tsl/lib/io/block_builder.h | 6 +- .../tsl => xla}/tsl/lib/io/buffered_file.h | 8 +- .../tsl/lib/io/buffered_file_test.cc | 2 +- .../tsl/lib/io/buffered_inputstream.cc | 4 +- .../tsl/lib/io/buffered_inputstream.h | 8 +- .../tsl/lib/io/buffered_inputstream_test.cc | 4 +- .../tsl => xla}/tsl/lib/io/cache.cc | 2 +- .../tsl => xla}/tsl/lib/io/cache.h | 6 +- .../tsl => xla}/tsl/lib/io/cache_test.cc | 2 +- .../tsl => xla}/tsl/lib/io/compression.cc | 2 +- .../tsl => xla}/tsl/lib/io/compression.h | 6 +- .../tsl => xla}/tsl/lib/io/format.cc | 6 +- .../tsl => xla}/tsl/lib/io/format.h | 8 +- .../tsl => xla}/tsl/lib/io/inputbuffer.cc | 2 +- .../tsl => xla}/tsl/lib/io/inputbuffer.h | 6 +- .../tsl/lib/io/inputbuffer_test.cc | 2 +- .../tsl/lib/io/inputstream_interface.cc | 2 +- .../tsl/lib/io/inputstream_interface.h | 6 +- .../tsl/lib/io/inputstream_interface_test.cc | 2 +- .../tsl => xla}/tsl/lib/io/iterator.cc | 2 +- .../tsl => xla}/tsl/lib/io/iterator.h | 6 +- .../tsl/lib/io/proto_encode_helper.h | 6 +- .../tsl/lib/io/random_inputstream.cc | 2 +- .../tsl/lib/io/random_inputstream.h | 8 +- .../tsl/lib/io/random_inputstream_test.cc | 2 +- .../tsl => xla}/tsl/lib/io/record_reader.cc | 10 +- .../tsl => xla}/tsl/lib/io/record_reader.h | 16 +- .../tsl/lib/io/record_reader_writer_test.cc | 4 +- .../tsl => xla}/tsl/lib/io/record_writer.cc | 6 +- .../tsl => xla}/tsl/lib/io/record_writer.h | 16 +- .../tsl => xla}/tsl/lib/io/recordio_test.cc | 8 +- .../tsl => xla}/tsl/lib/io/snappy/BUILD | 52 +-- .../io/snappy/snappy_compression_options.h | 6 +- .../tsl/lib/io/snappy/snappy_inputbuffer.cc | 2 +- .../tsl/lib/io/snappy/snappy_inputbuffer.h | 8 +- .../tsl/lib/io/snappy/snappy_inputstream.cc | 2 +- .../tsl/lib/io/snappy/snappy_inputstream.h | 8 +- .../tsl/lib/io/snappy/snappy_outputbuffer.cc | 2 +- .../tsl/lib/io/snappy/snappy_outputbuffer.h | 6 +- .../tsl/lib/io/snappy/snappy_test.cc | 10 +- .../tsl => xla}/tsl/lib/io/table.cc | 12 +- .../tsl => xla}/tsl/lib/io/table.h | 8 +- .../tsl => xla}/tsl/lib/io/table_builder.cc | 10 +- .../tsl => xla}/tsl/lib/io/table_builder.h | 8 +- .../tsl => xla}/tsl/lib/io/table_format.txt | 0 .../tsl => xla}/tsl/lib/io/table_options.h | 6 +- .../tsl => xla}/tsl/lib/io/table_test.cc | 14 +- .../tsl/lib/io/two_level_iterator.cc | 10 +- .../tsl/lib/io/two_level_iterator.h | 8 +- .../tsl/lib/io/zlib_buffers_test.cc | 8 +- .../tsl/lib/io/zlib_compression_options.cc | 2 +- .../tsl/lib/io/zlib_compression_options.h | 6 +- .../tsl/lib/io/zlib_inputstream.cc | 2 +- .../tsl => xla}/tsl/lib/io/zlib_inputstream.h | 10 +- .../tsl/lib/io/zlib_outputbuffer.cc | 2 +- .../tsl/lib/io/zlib_outputbuffer.h | 8 +- .../tsl => xla}/tsl/lib/math/BUILD | 16 +- .../tsl => xla}/tsl/lib/math/math_util.h | 6 +- .../tsl/lib/math/math_util_test.cc | 2 +- .../tsl => xla}/tsl/lib/random/BUILD | 74 ++-- .../tsl/lib/random/distribution_sampler.cc | 2 +- .../tsl/lib/random/distribution_sampler.h | 8 +- .../lib/random/distribution_sampler_test.cc | 4 +- .../tsl/lib/random/exact_uniform_int.h | 8 +- .../tsl/lib/random/philox_random.h | 6 +- .../tsl/lib/random/philox_random_test.cc | 6 +- .../tsl/lib/random/philox_random_test_utils.h | 8 +- .../tsl/lib/random/random_distributions.cc | 4 +- .../tsl/lib/random/random_distributions.h | 12 +- .../lib/random/random_distributions_test.cc | 8 +- .../lib/random/random_distributions_utils.h | 8 +- .../tsl/lib/random/simple_philox.cc | 4 +- .../tsl/lib/random/simple_philox.h | 10 +- .../tsl/lib/random/simple_philox_test.cc | 2 +- .../tsl/lib/random/weighted_picker.cc | 4 +- .../tsl/lib/random/weighted_picker.h | 6 +- .../tsl/lib/random/weighted_picker_test.cc | 4 +- third_party/xla/xla/tsl/lib/strings/BUILD | 2 +- .../tsl/lib/strings/proto_serialization.cc | 2 +- .../xla/xla/tsl/profiler/rpc/client/BUILD | 4 +- .../tsl/profiler/rpc/client/save_profile.cc | 4 +- third_party/xla/xla/util.h | 2 +- 244 files changed, 788 insertions(+), 787 deletions(-) rename third_party/xla/{third_party/tsl => xla}/tsl/lib/gtl/BUILD (77%) rename third_party/xla/{third_party/tsl => xla}/tsl/lib/gtl/compactptrset.h (96%) rename third_party/xla/{third_party/tsl => xla}/tsl/lib/gtl/compactptrset_test.cc (98%) rename third_party/xla/{third_party/tsl => xla}/tsl/lib/gtl/flatmap.h (98%) rename third_party/xla/{third_party/tsl => xla}/tsl/lib/gtl/flatmap_test.cc (99%) rename third_party/xla/{third_party/tsl => xla}/tsl/lib/gtl/flatrep.h (98%) rename third_party/xla/{third_party/tsl => xla}/tsl/lib/gtl/flatset.h (98%) rename third_party/xla/{third_party/tsl => xla}/tsl/lib/gtl/flatset_test.cc (99%) rename third_party/xla/{third_party/tsl => xla}/tsl/lib/gtl/inlined_vector.h (89%) rename third_party/xla/{third_party/tsl => xla}/tsl/lib/gtl/int_type.h (99%) rename third_party/xla/{third_party/tsl => xla}/tsl/lib/gtl/int_type_test.cc (99%) rename third_party/xla/{third_party/tsl => xla}/tsl/lib/gtl/iterator_range.h (93%) rename third_party/xla/{third_party/tsl => xla}/tsl/lib/gtl/iterator_range_test.cc (97%) rename third_party/xla/{third_party/tsl => xla}/tsl/lib/gtl/map_util.h (97%) rename third_party/xla/{third_party/tsl => xla}/tsl/lib/gtl/map_util_test.cc (97%) rename third_party/xla/{third_party/tsl => xla}/tsl/lib/gtl/subtle/BUILD (69%) rename third_party/xla/{third_party/tsl => xla}/tsl/lib/gtl/subtle/map_traits.h (93%) rename third_party/xla/{third_party/tsl => xla}/tsl/lib/hash/BUILD (71%) rename third_party/xla/{third_party/tsl => xla}/tsl/lib/hash/crc32c.cc (96%) rename third_party/xla/{third_party/tsl => xla}/tsl/lib/hash/crc32c.h (94%) rename third_party/xla/{third_party/tsl => xla}/tsl/lib/hash/crc32c_test.cc (98%) rename third_party/xla/{third_party/tsl => xla}/tsl/lib/io/BUILD (55%) rename third_party/xla/{third_party/tsl => xla}/tsl/lib/io/block.cc (99%) rename third_party/xla/{third_party/tsl => xla}/tsl/lib/io/block.h (89%) rename third_party/xla/{third_party/tsl => xla}/tsl/lib/io/block_builder.cc (98%) rename third_party/xla/{third_party/tsl => xla}/tsl/lib/io/block_builder.h (93%) rename third_party/xla/{third_party/tsl => xla}/tsl/lib/io/buffered_file.h (95%) rename third_party/xla/{third_party/tsl => xla}/tsl/lib/io/buffered_file_test.cc (97%) rename third_party/xla/{third_party/tsl => xla}/tsl/lib/io/buffered_inputstream.cc (98%) rename third_party/xla/{third_party/tsl => xla}/tsl/lib/io/buffered_inputstream.h (95%) rename third_party/xla/{third_party/tsl => xla}/tsl/lib/io/buffered_inputstream_test.cc (99%) rename third_party/xla/{third_party/tsl => xla}/tsl/lib/io/cache.cc (99%) rename third_party/xla/{third_party/tsl => xla}/tsl/lib/io/cache.h (97%) rename third_party/xla/{third_party/tsl => xla}/tsl/lib/io/cache_test.cc (99%) rename third_party/xla/{third_party/tsl => xla}/tsl/lib/io/compression.cc (95%) rename third_party/xla/{third_party/tsl => xla}/tsl/lib/io/compression.h (86%) rename third_party/xla/{third_party/tsl => xla}/tsl/lib/io/format.cc (98%) rename third_party/xla/{third_party/tsl => xla}/tsl/lib/io/format.h (95%) rename third_party/xla/{third_party/tsl => xla}/tsl/lib/io/inputbuffer.cc (99%) rename third_party/xla/{third_party/tsl => xla}/tsl/lib/io/inputbuffer.h (97%) rename third_party/xla/{third_party/tsl => xla}/tsl/lib/io/inputbuffer_test.cc (99%) rename third_party/xla/{third_party/tsl => xla}/tsl/lib/io/inputstream_interface.cc (96%) rename third_party/xla/{third_party/tsl => xla}/tsl/lib/io/inputstream_interface.h (93%) rename third_party/xla/{third_party/tsl => xla}/tsl/lib/io/inputstream_interface_test.cc (97%) rename third_party/xla/{third_party/tsl => xla}/tsl/lib/io/iterator.cc (98%) rename third_party/xla/{third_party/tsl => xla}/tsl/lib/io/iterator.h (96%) rename third_party/xla/{third_party/tsl => xla}/tsl/lib/io/proto_encode_helper.h (94%) rename third_party/xla/{third_party/tsl => xla}/tsl/lib/io/random_inputstream.cc (98%) rename third_party/xla/{third_party/tsl => xla}/tsl/lib/io/random_inputstream.h (89%) rename third_party/xla/{third_party/tsl => xla}/tsl/lib/io/random_inputstream_test.cc (98%) rename third_party/xla/{third_party/tsl => xla}/tsl/lib/io/record_reader.cc (97%) rename third_party/xla/{third_party/tsl => xla}/tsl/lib/io/record_reader.h (93%) rename third_party/xla/{third_party/tsl => xla}/tsl/lib/io/record_reader_writer_test.cc (99%) rename third_party/xla/{third_party/tsl => xla}/tsl/lib/io/record_writer.cc (97%) rename third_party/xla/{third_party/tsl => xla}/tsl/lib/io/record_writer.h (92%) rename third_party/xla/{third_party/tsl => xla}/tsl/lib/io/recordio_test.cc (98%) rename third_party/xla/{third_party/tsl => xla}/tsl/lib/io/snappy/BUILD (56%) rename third_party/xla/{third_party/tsl => xla}/tsl/lib/io/snappy/snappy_compression_options.h (84%) rename third_party/xla/{third_party/tsl => xla}/tsl/lib/io/snappy/snappy_inputbuffer.cc (99%) rename third_party/xla/{third_party/tsl => xla}/tsl/lib/io/snappy/snappy_inputbuffer.h (95%) rename third_party/xla/{third_party/tsl => xla}/tsl/lib/io/snappy/snappy_inputstream.cc (98%) rename third_party/xla/{third_party/tsl => xla}/tsl/lib/io/snappy/snappy_inputstream.h (92%) rename third_party/xla/{third_party/tsl => xla}/tsl/lib/io/snappy/snappy_outputbuffer.cc (99%) rename third_party/xla/{third_party/tsl => xla}/tsl/lib/io/snappy/snappy_outputbuffer.h (96%) rename third_party/xla/{third_party/tsl => xla}/tsl/lib/io/snappy/snappy_test.cc (98%) rename third_party/xla/{third_party/tsl => xla}/tsl/lib/io/table.cc (96%) rename third_party/xla/{third_party/tsl => xla}/tsl/lib/io/table.h (95%) rename third_party/xla/{third_party/tsl => xla}/tsl/lib/io/table_builder.cc (97%) rename third_party/xla/{third_party/tsl => xla}/tsl/lib/io/table_builder.h (94%) rename third_party/xla/{third_party/tsl => xla}/tsl/lib/io/table_format.txt (100%) rename third_party/xla/{third_party/tsl => xla}/tsl/lib/io/table_options.h (95%) rename third_party/xla/{third_party/tsl => xla}/tsl/lib/io/table_test.cc (98%) rename third_party/xla/{third_party/tsl => xla}/tsl/lib/io/two_level_iterator.cc (95%) rename third_party/xla/{third_party/tsl => xla}/tsl/lib/io/two_level_iterator.h (88%) rename third_party/xla/{third_party/tsl => xla}/tsl/lib/io/zlib_buffers_test.cc (98%) rename third_party/xla/{third_party/tsl => xla}/tsl/lib/io/zlib_compression_options.cc (94%) rename third_party/xla/{third_party/tsl => xla}/tsl/lib/io/zlib_compression_options.h (97%) rename third_party/xla/{third_party/tsl => xla}/tsl/lib/io/zlib_inputstream.cc (99%) rename third_party/xla/{third_party/tsl => xla}/tsl/lib/io/zlib_inputstream.h (95%) rename third_party/xla/{third_party/tsl => xla}/tsl/lib/io/zlib_outputbuffer.cc (99%) rename third_party/xla/{third_party/tsl => xla}/tsl/lib/io/zlib_outputbuffer.h (96%) rename third_party/xla/{third_party/tsl => xla}/tsl/lib/math/BUILD (74%) rename third_party/xla/{third_party/tsl => xla}/tsl/lib/math/math_util.h (97%) rename third_party/xla/{third_party/tsl => xla}/tsl/lib/math/math_util_test.cc (99%) rename third_party/xla/{third_party/tsl => xla}/tsl/lib/random/BUILD (73%) rename third_party/xla/{third_party/tsl => xla}/tsl/lib/random/distribution_sampler.cc (97%) rename third_party/xla/{third_party/tsl => xla}/tsl/lib/random/distribution_sampler.h (92%) rename third_party/xla/{third_party/tsl => xla}/tsl/lib/random/distribution_sampler_test.cc (97%) rename third_party/xla/{third_party/tsl => xla}/tsl/lib/random/exact_uniform_int.h (91%) rename third_party/xla/{third_party/tsl => xla}/tsl/lib/random/philox_random.h (98%) rename third_party/xla/{third_party/tsl => xla}/tsl/lib/random/philox_random_test.cc (93%) rename third_party/xla/{third_party/tsl => xla}/tsl/lib/random/philox_random_test_utils.h (86%) rename third_party/xla/{third_party/tsl => xla}/tsl/lib/random/random_distributions.cc (90%) rename third_party/xla/{third_party/tsl => xla}/tsl/lib/random/random_distributions.h (98%) rename third_party/xla/{third_party/tsl => xla}/tsl/lib/random/random_distributions_test.cc (98%) rename third_party/xla/{third_party/tsl => xla}/tsl/lib/random/random_distributions_utils.h (93%) rename third_party/xla/{third_party/tsl => xla}/tsl/lib/random/simple_philox.cc (92%) rename third_party/xla/{third_party/tsl => xla}/tsl/lib/random/simple_philox.h (89%) rename third_party/xla/{third_party/tsl => xla}/tsl/lib/random/simple_philox_test.cc (98%) rename third_party/xla/{third_party/tsl => xla}/tsl/lib/random/weighted_picker.cc (98%) rename third_party/xla/{third_party/tsl => xla}/tsl/lib/random/weighted_picker.h (96%) rename third_party/xla/{third_party/tsl => xla}/tsl/lib/random/weighted_picker_test.cc (98%) diff --git a/tensorflow/compiler/mlir/lite/debug/BUILD b/tensorflow/compiler/mlir/lite/debug/BUILD index ea0af7970aed14..4cd1b6e7eddb35 100644 --- a/tensorflow/compiler/mlir/lite/debug/BUILD +++ b/tensorflow/compiler/mlir/lite/debug/BUILD @@ -34,10 +34,10 @@ cc_library( "@llvm-project//mlir:Pass", "@llvm-project//mlir:Support", "@llvm-project//mlir:Transforms", - "@local_tsl//tsl/lib/io:buffered_file", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:path", "@local_tsl//tsl/platform:stringpiece", + "@local_xla//xla/tsl/lib/io:buffered_file", ], ) diff --git a/tensorflow/compiler/mlir/lite/debug/debug.cc b/tensorflow/compiler/mlir/lite/debug/debug.cc index d0b85019cfe200..8b4b611a18108e 100644 --- a/tensorflow/compiler/mlir/lite/debug/debug.cc +++ b/tensorflow/compiler/mlir/lite/debug/debug.cc @@ -47,8 +47,8 @@ limitations under the License. #include "re2/re2.h" // IWYU pragma: keep #include "tensorflow/compiler/mlir/lite/debug/debug_options.pb.h" #include "tensorflow/compiler/mlir/lite/metrics/error_collector_inst.h" +#include "xla/tsl/lib/io/buffered_file.h" #include "tensorflow/core/platform/logging.h" -#include "tsl/lib/io/buffered_file.h" #include "tsl/platform/env.h" #include "tsl/platform/file_system.h" #include "tsl/platform/path.h" diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD index 5b3255fcd53c8f..d5f01c3d68e544 100644 --- a/tensorflow/compiler/mlir/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/BUILD @@ -1217,7 +1217,7 @@ cc_library( "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", - "@local_tsl//tsl/lib/io:buffered_file", + "@local_xla//xla/tsl/lib/io:buffered_file", ], ) @@ -1395,9 +1395,9 @@ cc_library( "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", - "@local_tsl//tsl/lib/math:math_util", "@local_xla//xla:xla_data_proto_cc", "@local_xla//xla/service:hlo_parser", + "@local_xla//xla/tsl/lib/math:math_util", ], ) diff --git a/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc index c49971a8a8c0c7..1270865e551d52 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc @@ -28,11 +28,11 @@ limitations under the License. #include "llvm/Support/raw_ostream.h" #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "xla/tsl/lib/io/buffered_file.h" #include "tensorflow/core/platform/crash_analysis.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/path.h" -#include "tsl/lib/io/buffered_file.h" using llvm::raw_ostream; diff --git a/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.cc index addf76366984a0..d10b908e02d3c3 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.cc @@ -51,9 +51,9 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "xla/service/hlo_parser.h" +#include "xla/tsl/lib/math/math_util.h" #include "xla/xla_data.pb.h" #include "tensorflow/core/framework/tensor_shape.h" -#include "tsl/lib/math/math_util.h" namespace tensorflow { namespace { diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index b3054a1a951f0b..1bc6f7ba2ece5f 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -273,13 +273,13 @@ cc_library( "//tensorflow/core/lib/histogram:legacy_lib_histogram_all_headers", "//tensorflow/core/lib/io:legacy_lib_io_headers", "//tensorflow/core/lib/math:math_util.h", - "@local_tsl//tsl/lib/math:math_util.h", + "@local_xla//xla/tsl/lib/math:math_util.h", "//tensorflow/core/lib/monitoring:legacy_lib_monitoring_lib_headers", "//tensorflow/core/lib/random:legacy_lib_random_headers", "//tensorflow/core/lib/strings:legacy_lib_string_headers", "//tensorflow/core/platform:lib_hdrs", "//tensorflow/core/util:lib_hdrs", - "@local_tsl//tsl/lib/io:legacy_lib_io_headers", + "@local_xla//xla/tsl/lib/io:legacy_lib_io_headers", "@local_tsl//tsl/platform:lib_hdrs", ], visibility = ["//visibility:public"], @@ -844,7 +844,7 @@ filegroup( "//tensorflow/core/graph:mobile_srcs_only_runtime", "//tensorflow/core/kernels:mobile_srcs", "//tensorflow/core/lib/io:mobile_srcs_only_runtime", - "@local_tsl//tsl/lib/io:mobile_srcs_only_runtime", + "@local_xla//xla/tsl/lib/io:mobile_srcs_only_runtime", "//tensorflow/core/nccl:mobile_srcs", "//tensorflow/core/profiler:mobile_srcs_only_runtime", "//tensorflow/core/public:mobile_srcs_only_runtime", @@ -859,7 +859,7 @@ filegroup( "//tensorflow/core/lib/hash:mobile_srcs_only_runtime", "//tensorflow/core/lib/histogram:mobile_srcs_only_runtime", "//tensorflow/core/lib/math:mobile_srcs_only_runtime", - "@local_tsl//tsl/lib/math:mobile_srcs_only_runtime", + "@local_xla//xla/tsl/lib/math:mobile_srcs_only_runtime", "//tensorflow/core/lib/monitoring:mobile_srcs_only_runtime", "//tensorflow/core/lib/random:mobile_srcs_only_runtime", "//tensorflow/core/lib/strings:mobile_srcs_only_runtime", @@ -1203,11 +1203,11 @@ filegroup( "//tensorflow/core/lib/strings:legacy_lib_strings_all_headers", "//tensorflow/core/platform:legacy_lib_internal_headers", "//tensorflow/core/platform:lib_internal_private_hdrs", - "@local_tsl//tsl/lib/io:legacy_lib_io_all_headers", - "@local_tsl//tsl/lib/math:math_util.h", "@local_tsl//tsl/platform:legacy_lib_internal_headers", "@local_tsl//tsl/platform:lib_internal_private_hdrs", "@local_xla//xla/tsl/lib/core:legacy_lib_core_all_headers", + "@local_xla//xla/tsl/lib/io:legacy_lib_io_all_headers", + "@local_xla//xla/tsl/lib/math:math_util.h", ] + glob( [ "lib/**/*.h", @@ -1236,9 +1236,9 @@ filegroup( "//tensorflow/core/platform:legacy_platform_lib_hdrs", "//tensorflow/core/platform:lib_internal_public_hdrs", "//tensorflow/core/util:lib_internal_public_hdrs", - "@local_tsl//tsl/lib/io:legacy_lib_internal_public_headers", "@local_tsl//tsl/platform:lib_internal_public_hdrs", "@local_xla//xla/stream_executor/integrations:device_mem_allocator_headers", + "@local_xla//xla/tsl/lib/io:legacy_lib_internal_public_headers", ], visibility = ["//visibility:private"], ) @@ -1441,7 +1441,7 @@ cc_library( "@com_google_protobuf//:protobuf", "@double_conversion//:double-conversion", "@eigen_archive//:eigen3", - "@local_tsl//tsl/lib/math:math_util", + "@local_xla//xla/tsl/lib/math:math_util", "@ml_dtypes//:float8", "@ml_dtypes//:intn", "@snappy", @@ -1849,7 +1849,7 @@ cc_library( hdrs = [ "//tensorflow/core/lib/gtl:legacy_lib_test_internal_headers", "//tensorflow/core/lib/io:legacy_lib_test_internal_headers", - "@local_tsl//tsl/lib/io:legacy_lib_test_internal_headers", + "@local_xla//xla/tsl/lib/io:legacy_lib_test_internal_headers", ], deps = [ ":lib", diff --git a/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator_test.cc b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator_test.cc index b3dc824bcdd4f9..83c18087ac73bf 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator_test.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator_test.cc @@ -26,12 +26,12 @@ limitations under the License. #include "xla/stream_executor/gpu/gpu_init.h" #include "xla/stream_executor/stream_executor.h" #include "xla/tsl/framework/device_id.h" +#include "xla/tsl/lib/gtl/inlined_vector.h" +#include "xla/tsl/lib/random/simple_philox.h" #include "tensorflow/core/common_runtime/device/device_mem_allocator.h" #include "tensorflow/core/framework/typed_allocator.h" #include "tensorflow/core/protobuf/bfc_memory_map.pb.h" #include "tensorflow/core/protobuf/config.pb.h" -#include "tsl/lib/gtl/inlined_vector.h" -#include "tsl/lib/random/simple_philox.h" #include "tsl/platform/logging.h" #include "tsl/platform/strcat.h" #include "tsl/platform/test.h" diff --git a/tensorflow/core/common_runtime/gpu/gpu_debug_allocator_test.cc b/tensorflow/core/common_runtime/gpu/gpu_debug_allocator_test.cc index bb932a5af9e0c0..de65df20e2dad4 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_debug_allocator_test.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_debug_allocator_test.cc @@ -25,10 +25,10 @@ limitations under the License. #include "xla/stream_executor/platform.h" #include "xla/stream_executor/stream_executor.h" #include "xla/tsl/framework/device_id.h" +#include "xla/tsl/lib/gtl/inlined_vector.h" #include "tensorflow/core/common_runtime/device/device_mem_allocator.h" #include "tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h" #include "tensorflow/core/framework/typed_allocator.h" -#include "tsl/lib/gtl/inlined_vector.h" #include "tsl/platform/logging.h" #include "tsl/platform/test.h" #include "tsl/platform/types.h" diff --git a/tensorflow/core/data/service/snapshot/distributed_snapshot_test.cc b/tensorflow/core/data/service/snapshot/distributed_snapshot_test.cc index 8974964c9b3a81..f95fafb9343669 100644 --- a/tensorflow/core/data/service/snapshot/distributed_snapshot_test.cc +++ b/tensorflow/core/data/service/snapshot/distributed_snapshot_test.cc @@ -21,6 +21,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/time/time.h" #include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/io/compression.h" #include "tensorflow/core/data/service/dispatcher_client.h" #include "tensorflow/core/data/service/snapshot/path_utils.h" #include "tensorflow/core/data/service/snapshot/test_utils.h" @@ -29,7 +30,6 @@ limitations under the License. #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/protobuf/snapshot.pb.h" -#include "tsl/lib/io/compression.h" #include "tsl/platform/env.h" #include "tsl/platform/path.h" #include "tsl/platform/status_matchers.h" diff --git a/tensorflow/core/data/service/snapshot/file_utils_test.cc b/tensorflow/core/data/service/snapshot/file_utils_test.cc index 9582cab18bc143..4404719e95c38d 100644 --- a/tensorflow/core/data/service/snapshot/file_utils_test.cc +++ b/tensorflow/core/data/service/snapshot/file_utils_test.cc @@ -19,13 +19,13 @@ limitations under the License. #include #include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/io/compression.h" #include "tensorflow/core/data/dataset_test_base.h" #include "tensorflow/core/data/service/test_util.h" #include "tensorflow/core/data/snapshot_utils.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/types.pb.h" -#include "tsl/lib/io/compression.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" #include "tsl/platform/path.h" diff --git a/tensorflow/core/data/service/snapshot/parallel_tfrecord_writer_test.cc b/tensorflow/core/data/service/snapshot/parallel_tfrecord_writer_test.cc index 43944c6a41b8f1..1623ac904c5484 100644 --- a/tensorflow/core/data/service/snapshot/parallel_tfrecord_writer_test.cc +++ b/tensorflow/core/data/service/snapshot/parallel_tfrecord_writer_test.cc @@ -31,10 +31,10 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/io/compression.h" #include "tensorflow/core/data/service/byte_size.h" #include "tensorflow/core/data/snapshot_utils.h" #include "tensorflow/core/framework/tensor.h" -#include "tsl/lib/io/compression.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" #include "tsl/platform/status_matchers.h" diff --git a/tensorflow/core/data/service/snapshot/prefetched_split_provider.cc b/tensorflow/core/data/service/snapshot/prefetched_split_provider.cc index d28285966ff251..db6b7bd6733818 100644 --- a/tensorflow/core/data/service/snapshot/prefetched_split_provider.cc +++ b/tensorflow/core/data/service/snapshot/prefetched_split_provider.cc @@ -26,10 +26,10 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/synchronization/mutex.h" +#include "xla/tsl/lib/io/compression.h" #include "tensorflow/core/data/service/snapshot/file_utils.h" #include "tensorflow/core/framework/dataset.h" #include "tensorflow/core/framework/tensor.h" -#include "tsl/lib/io/compression.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" #include "tsl/platform/path.h" diff --git a/tensorflow/core/data/service/snapshot/prefetched_split_provider_test.cc b/tensorflow/core/data/service/snapshot/prefetched_split_provider_test.cc index 1e019a1742651e..8e7c473840c004 100644 --- a/tensorflow/core/data/service/snapshot/prefetched_split_provider_test.cc +++ b/tensorflow/core/data/service/snapshot/prefetched_split_provider_test.cc @@ -32,13 +32,13 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/synchronization/mutex.h" #include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/io/compression.h" #include "tensorflow/core/data/service/common.pb.h" #include "tensorflow/core/data/service/split_provider.h" #include "tensorflow/core/data/service/test_util.h" #include "tensorflow/core/data/snapshot_utils.h" #include "tensorflow/core/framework/dataset.h" #include "tensorflow/core/framework/tensor.h" -#include "tsl/lib/io/compression.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" #include "tsl/platform/path.h" diff --git a/tensorflow/core/data/service/snapshot/snapshot_manager.cc b/tensorflow/core/data/service/snapshot/snapshot_manager.cc index 8ea22871945af2..a3845361762bd8 100644 --- a/tensorflow/core/data/service/snapshot/snapshot_manager.cc +++ b/tensorflow/core/data/service/snapshot/snapshot_manager.cc @@ -34,6 +34,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/time/time.h" +#include "xla/tsl/lib/io/compression.h" #include "tensorflow/core/data/service/common.pb.h" #include "tensorflow/core/data/service/dispatcher.pb.h" #include "tensorflow/core/data/service/snapshot/file_utils.h" @@ -43,7 +44,6 @@ limitations under the License. #include "tensorflow/core/framework/dataset.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/platform/status.h" -#include "tsl/lib/io/compression.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" #include "tsl/platform/mutex.h" diff --git a/tensorflow/core/data/service/snapshot/snapshot_split_provider_test.cc b/tensorflow/core/data/service/snapshot/snapshot_split_provider_test.cc index b9b9f3d3d4d8ed..0f41ba42404b62 100644 --- a/tensorflow/core/data/service/snapshot/snapshot_split_provider_test.cc +++ b/tensorflow/core/data/service/snapshot/snapshot_split_provider_test.cc @@ -24,6 +24,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/time/time.h" #include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/io/compression.h" #include "tensorflow/core/data/serialization_utils.h" #include "tensorflow/core/data/service/common.pb.h" #include "tensorflow/core/data/service/dispatcher_client.h" @@ -35,7 +36,6 @@ limitations under the License. #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/framework/variant_tensor_data.h" #include "tensorflow/core/protobuf/snapshot.pb.h" -#include "tsl/lib/io/compression.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" #include "tsl/platform/path.h" diff --git a/tensorflow/core/data/service/snapshot/snapshot_stream_writer_checkpoint_test.cc b/tensorflow/core/data/service/snapshot/snapshot_stream_writer_checkpoint_test.cc index 071c7e1f1c72a1..dea3f3dd5785d8 100644 --- a/tensorflow/core/data/service/snapshot/snapshot_stream_writer_checkpoint_test.cc +++ b/tensorflow/core/data/service/snapshot/snapshot_stream_writer_checkpoint_test.cc @@ -22,6 +22,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/time/time.h" #include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/io/compression.h" #include "tensorflow/core/data/service/byte_size.h" #include "tensorflow/core/data/service/common.pb.h" #include "tensorflow/core/data/service/snapshot/path_utils.h" @@ -29,7 +30,6 @@ limitations under the License. #include "tensorflow/core/data/service/snapshot/test_utils.h" #include "tensorflow/core/data/service/task_runner.h" #include "tensorflow/core/data/service/test_util.h" -#include "tsl/lib/io/compression.h" #include "tsl/platform/env.h" #include "tsl/platform/random.h" #include "tsl/platform/status_matchers.h" diff --git a/tensorflow/core/data/service/snapshot/snapshot_stream_writer_test.cc b/tensorflow/core/data/service/snapshot/snapshot_stream_writer_test.cc index 00dbdd947eefac..2748ec214f1293 100644 --- a/tensorflow/core/data/service/snapshot/snapshot_stream_writer_test.cc +++ b/tensorflow/core/data/service/snapshot/snapshot_stream_writer_test.cc @@ -26,6 +26,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/io/compression.h" #include "xla/tsl/lib/monitoring/cell_reader.h" #include "tensorflow/core/data/service/byte_size.h" #include "tensorflow/core/data/service/common.pb.h" @@ -37,7 +38,6 @@ limitations under the License. #include "tensorflow/core/data/snapshot_utils.h" #include "tensorflow/core/data/standalone.h" #include "tensorflow/core/framework/tensor.h" -#include "tsl/lib/io/compression.h" #include "tsl/platform/env.h" #include "tsl/platform/path.h" #include "tsl/platform/status_matchers.h" diff --git a/tensorflow/core/data/snapshot_utils.cc b/tensorflow/core/data/snapshot_utils.cc index 8874484c835af0..b36970825ca050 100644 --- a/tensorflow/core/data/snapshot_utils.cc +++ b/tensorflow/core/data/snapshot_utils.cc @@ -27,6 +27,8 @@ limitations under the License. #include "absl/memory/memory.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" +#include "xla/tsl/lib/io/snappy/snappy_inputbuffer.h" +#include "xla/tsl/lib/io/snappy/snappy_outputbuffer.h" #include "tensorflow/core/common_runtime/dma_helper.h" #include "tensorflow/core/data/name_utils.h" #include "tensorflow/core/framework/dataset.h" @@ -47,8 +49,6 @@ limitations under the License. #include "tensorflow/core/platform/stringprintf.h" #include "tensorflow/core/profiler/lib/traceme.h" #include "tensorflow/core/protobuf/snapshot.pb.h" -#include "tsl/lib/io/snappy/snappy_inputbuffer.h" -#include "tsl/lib/io/snappy/snappy_outputbuffer.h" #include "tsl/platform/errors.h" #include "tsl/platform/status.h" #include "tsl/platform/statusor.h" diff --git a/tensorflow/core/data/snapshot_utils.h b/tensorflow/core/data/snapshot_utils.h index 7e3b897bdddb0d..c421e783706622 100644 --- a/tensorflow/core/data/snapshot_utils.h +++ b/tensorflow/core/data/snapshot_utils.h @@ -251,7 +251,7 @@ class TFRecordReaderImpl { // Constructs a `TFRecordReaderImpl`. // `filename` is the file to read from. // `compression_type` is the compression method, as defined in - // tensorflow/tsl/lib/io/compression.h. + // tensorflow/compiler/xla/tsl/lib/io/compression.h. // `output_buffer_size` specifies the buffer size required by Snappy/Zlib // compression algorithms. Ignored if compression is not enabled. TFRecordReaderImpl(const std::string& filename, const string& compression, diff --git a/tensorflow/core/framework/BUILD b/tensorflow/core/framework/BUILD index 5b20f183865e58..ee5080c90dd9c8 100644 --- a/tensorflow/core/framework/BUILD +++ b/tensorflow/core/framework/BUILD @@ -31,7 +31,7 @@ default_visibility = [ "//tensorflow/core:__subpackages__", "//tensorflow/security/fuzzing:__subpackages__", # TODO(pedaveeraiah): to be removed when summary.proto.h deps moves to TSL - "@local_tsl//tsl/lib:__subpackages__", + "@local_xla//xla/tsl/lib:__subpackages__", # copybara:uncomment "//learning/brain/tfrt/aot:__subpackages__", # copybara:uncomment "//platforms/xla/megascale/tensorflow:__subpackages__", # copybara:uncomment "//learning/brain/experimental/tfrt/native_lowering/graph_executor:__subpackages__", diff --git a/tensorflow/core/kernels/data/experimental/distributed_save_op.cc b/tensorflow/core/kernels/data/experimental/distributed_save_op.cc index 7e737026c8d4a4..82908eaf8ffb84 100644 --- a/tensorflow/core/kernels/data/experimental/distributed_save_op.cc +++ b/tensorflow/core/kernels/data/experimental/distributed_save_op.cc @@ -20,6 +20,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/time/time.h" +#include "xla/tsl/lib/io/compression.h" #include "tensorflow/core/data/serialization_utils.h" #include "tensorflow/core/data/service/common.pb.h" #include "tensorflow/core/data/service/dispatcher_client.h" @@ -28,7 +29,6 @@ limitations under the License. #include "tensorflow/core/framework/dataset.h" #include "tensorflow/core/framework/metrics.h" #include "tensorflow/core/protobuf/snapshot.pb.h" -#include "tsl/lib/io/compression.h" namespace tensorflow { namespace data { diff --git a/tensorflow/core/kernels/stochastic_cast_op_test.cc b/tensorflow/core/kernels/stochastic_cast_op_test.cc index 10d9eae13249ff..9543afda020307 100644 --- a/tensorflow/core/kernels/stochastic_cast_op_test.cc +++ b/tensorflow/core/kernels/stochastic_cast_op_test.cc @@ -22,6 +22,7 @@ limitations under the License. #include #include "Eigen/Core" // from @eigen_archive #include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/random/philox_random.h" #include "tensorflow/core/framework/fake_input.h" #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/framework/tensor_shape.h" @@ -31,7 +32,6 @@ limitations under the License. #include "tensorflow/core/lib/random/random_distributions.h" #include "tensorflow/core/platform/bfloat16.h" #include "tensorflow/core/platform/logging.h" -#include "tsl/lib/random/philox_random.h" namespace Eigen { namespace internal { diff --git a/tensorflow/core/lib/gtl/BUILD b/tensorflow/core/lib/gtl/BUILD index 868d05f0912fc8..801bf59f7dcf79 100644 --- a/tensorflow/core/lib/gtl/BUILD +++ b/tensorflow/core/lib/gtl/BUILD @@ -48,7 +48,7 @@ cc_library( name = "compactptrset", hdrs = ["compactptrset.h"], deps = [ - "@local_tsl//tsl/lib/gtl:compactptrset", + "@local_xla//xla/tsl/lib/gtl:compactptrset", ], ) @@ -75,7 +75,7 @@ cc_library( "//tensorflow/core/lib/hash", "//tensorflow/core/platform:logging", "//tensorflow/core/platform:types", - "@local_tsl//tsl/lib/gtl:flatmap", + "@local_xla//xla/tsl/lib/gtl:flatmap", ], ) @@ -83,7 +83,7 @@ cc_library( name = "flatrep", hdrs = ["flatrep.h"], deps = [ - "@local_tsl//tsl/lib/gtl:flatrep", + "@local_xla//xla/tsl/lib/gtl:flatrep", ], ) @@ -91,7 +91,7 @@ cc_library( name = "flatset", hdrs = ["flatset.h"], deps = [ - "@local_tsl//tsl/lib/gtl:flatset", + "@local_xla//xla/tsl/lib/gtl:flatset", ], ) @@ -102,7 +102,7 @@ cc_library( "//tensorflow/core/platform:macros", "//tensorflow/core/platform:types", "@com_google_absl//absl/container:inlined_vector", - "@local_tsl//tsl/lib/gtl:inlined_vector", + "@local_xla//xla/tsl/lib/gtl:inlined_vector", ], ) @@ -110,7 +110,7 @@ cc_library( name = "int_type", hdrs = ["int_type.h"], deps = [ - "@local_tsl//tsl/lib/gtl:int_type", + "@local_xla//xla/tsl/lib/gtl:int_type", ], ) @@ -118,7 +118,7 @@ cc_library( name = "iterator_range", hdrs = ["iterator_range.h"], deps = [ - "@local_tsl//tsl/lib/gtl:iterator_range", + "@local_xla//xla/tsl/lib/gtl:iterator_range", ], ) @@ -140,7 +140,7 @@ cc_library( hdrs = ["map_util.h"], deps = [ "//tensorflow/core/platform:hash", # TODO(dduneavy) examples/custom_ops_doc transitively depends on this - "@local_tsl//tsl/lib/gtl:map_util", + "@local_xla//xla/tsl/lib/gtl:map_util", ], ) @@ -167,7 +167,7 @@ filegroup( "inlined_vector.h", "iterator_range.h", "priority_queue_util.h", - "@local_tsl//tsl/lib/gtl:legacy_lib_gtl_headers", + "@local_xla//xla/tsl/lib/gtl:legacy_lib_gtl_headers", ], visibility = ["//tensorflow/core:__pkg__"], ) @@ -180,7 +180,7 @@ filegroup( "manual_constructor.h", "map_util.h", "top_n.h", - "@local_tsl//tsl/lib/gtl:legacy_lib_internal_public_gtl_headers", + "@local_xla//xla/tsl/lib/gtl:legacy_lib_internal_public_gtl_headers", ], visibility = ["//tensorflow/core:__pkg__"], ) @@ -189,7 +189,7 @@ filegroup( name = "legacy_lib_test_internal_headers", srcs = [ "manual_constructor.h", - "@local_tsl//tsl/lib/gtl:legacy_lib_test_internal_headers", + "@local_xla//xla/tsl/lib/gtl:legacy_lib_test_internal_headers", ], visibility = ["//tensorflow/core:__pkg__"], ) @@ -198,7 +198,7 @@ filegroup( name = "legacy_android_gif_internal_headers", srcs = [ "cleanup.h", - "@local_tsl//tsl/lib/gtl:legacy_android_gif_internal_headers", + "@local_xla//xla/tsl/lib/gtl:legacy_android_gif_internal_headers", ], visibility = [ "//tensorflow/core:__pkg__", @@ -215,7 +215,7 @@ filegroup( "flatrep.h", "inlined_vector.h", "top_n.h", - "@local_tsl//tsl/lib/gtl:mobile_srcs_no_runtime", + "@local_xla//xla/tsl/lib/gtl:mobile_srcs_no_runtime", ], visibility = ["//tensorflow/core:__pkg__"], ) @@ -232,7 +232,7 @@ filegroup( "map_util.h", "priority_queue_util.h", "//tensorflow/core/lib/gtl/subtle:map_traits", - "@local_tsl//tsl/lib/gtl:mobile_srcs_only_runtime", + "@local_xla//xla/tsl/lib/gtl:mobile_srcs_only_runtime", ], visibility = ["//tensorflow/core:__pkg__"], ) @@ -255,7 +255,7 @@ filegroup( "priority_queue_util.h", "top_n.h", "//tensorflow/core/lib/gtl/subtle:map_traits", - "@local_tsl//tsl/lib/gtl:legacy_lib_gtl_all_headers", + "@local_xla//xla/tsl/lib/gtl:legacy_lib_gtl_all_headers", ], visibility = ["//tensorflow/core:__pkg__"], ) diff --git a/tensorflow/core/lib/gtl/compactptrset.h b/tensorflow/core/lib/gtl/compactptrset.h index 326aca55d34f0e..6655ac92d99ec7 100644 --- a/tensorflow/core/lib/gtl/compactptrset.h +++ b/tensorflow/core/lib/gtl/compactptrset.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_CORE_LIB_GTL_COMPACTPTRSET_H_ #define TENSORFLOW_CORE_LIB_GTL_COMPACTPTRSET_H_ -#include "tsl/lib/gtl/compactptrset.h" +#include "xla/tsl/lib/gtl/compactptrset.h" namespace tensorflow { namespace gtl { diff --git a/tensorflow/core/lib/gtl/flatmap.h b/tensorflow/core/lib/gtl/flatmap.h index 15c02b381daa19..3b112a714cb883 100644 --- a/tensorflow/core/lib/gtl/flatmap.h +++ b/tensorflow/core/lib/gtl/flatmap.h @@ -16,11 +16,11 @@ limitations under the License. #ifndef TENSORFLOW_CORE_LIB_GTL_FLATMAP_H_ #define TENSORFLOW_CORE_LIB_GTL_FLATMAP_H_ +#include "xla/tsl/lib/gtl/flatmap.h" #include "tensorflow/core/lib/gtl/flatrep.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" -#include "tsl/lib/gtl/flatmap.h" namespace tensorflow { namespace gtl { diff --git a/tensorflow/core/lib/gtl/flatrep.h b/tensorflow/core/lib/gtl/flatrep.h index dfc8af5ef01dfa..59caa4b086708a 100644 --- a/tensorflow/core/lib/gtl/flatrep.h +++ b/tensorflow/core/lib/gtl/flatrep.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_CORE_LIB_GTL_FLATREP_H_ #define TENSORFLOW_CORE_LIB_GTL_FLATREP_H_ -#include "tsl/lib/gtl/flatrep.h" +#include "xla/tsl/lib/gtl/flatrep.h" namespace tensorflow { namespace gtl { diff --git a/tensorflow/core/lib/gtl/flatset.h b/tensorflow/core/lib/gtl/flatset.h index 659bf98f74f9b1..fcb7ed96b9a166 100644 --- a/tensorflow/core/lib/gtl/flatset.h +++ b/tensorflow/core/lib/gtl/flatset.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_CORE_LIB_GTL_FLATSET_H_ #define TENSORFLOW_CORE_LIB_GTL_FLATSET_H_ -#include "tsl/lib/gtl/flatset.h" +#include "xla/tsl/lib/gtl/flatset.h" namespace tensorflow { namespace gtl { diff --git a/tensorflow/core/lib/gtl/inlined_vector.h b/tensorflow/core/lib/gtl/inlined_vector.h index b94af28b839346..df9d1a245dbf9a 100644 --- a/tensorflow/core/lib/gtl/inlined_vector.h +++ b/tensorflow/core/lib/gtl/inlined_vector.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_CORE_LIB_GTL_INLINED_VECTOR_H_ #define TENSORFLOW_CORE_LIB_GTL_INLINED_VECTOR_H_ -#include "tsl/lib/gtl/inlined_vector.h" // IWYU pragma: export +#include "xla/tsl/lib/gtl/inlined_vector.h" // IWYU pragma: export // TODO(kramerb): This is kept only because lots of targets transitively depend // on it. Remove all targets' dependencies. #include "tensorflow/core/platform/macros.h" diff --git a/tensorflow/core/lib/gtl/int_type.h b/tensorflow/core/lib/gtl/int_type.h index 2259cb8bc3b84d..c161ee917e82cc 100644 --- a/tensorflow/core/lib/gtl/int_type.h +++ b/tensorflow/core/lib/gtl/int_type.h @@ -17,7 +17,7 @@ limitations under the License. #ifndef TENSORFLOW_CORE_LIB_GTL_INT_TYPE_H_ #define TENSORFLOW_CORE_LIB_GTL_INT_TYPE_H_ -#include "tsl/lib/gtl/int_type.h" +#include "xla/tsl/lib/gtl/int_type.h" namespace tensorflow { namespace gtl { diff --git a/tensorflow/core/lib/gtl/iterator_range.h b/tensorflow/core/lib/gtl/iterator_range.h index 4748761d8da0a8..ca980fd536b2d8 100644 --- a/tensorflow/core/lib/gtl/iterator_range.h +++ b/tensorflow/core/lib/gtl/iterator_range.h @@ -25,7 +25,7 @@ limitations under the License. #ifndef TENSORFLOW_CORE_LIB_GTL_ITERATOR_RANGE_H_ #define TENSORFLOW_CORE_LIB_GTL_ITERATOR_RANGE_H_ -#include "tsl/lib/gtl/iterator_range.h" +#include "xla/tsl/lib/gtl/iterator_range.h" namespace tensorflow { namespace gtl { diff --git a/tensorflow/core/lib/gtl/map_util.h b/tensorflow/core/lib/gtl/map_util.h index 3b2d767548809d..47d28e7dd23e1b 100644 --- a/tensorflow/core/lib/gtl/map_util.h +++ b/tensorflow/core/lib/gtl/map_util.h @@ -20,7 +20,7 @@ limitations under the License. #ifndef TENSORFLOW_CORE_LIB_GTL_MAP_UTIL_H_ #define TENSORFLOW_CORE_LIB_GTL_MAP_UTIL_H_ -#include "tsl/lib/gtl/map_util.h" +#include "xla/tsl/lib/gtl/map_util.h" namespace tensorflow { namespace gtl { diff --git a/tensorflow/core/lib/gtl/subtle/BUILD b/tensorflow/core/lib/gtl/subtle/BUILD index 2b79160c01ea0e..f74d6f7604eec5 100644 --- a/tensorflow/core/lib/gtl/subtle/BUILD +++ b/tensorflow/core/lib/gtl/subtle/BUILD @@ -12,7 +12,7 @@ filegroup( name = "map_traits", srcs = [ "map_traits.h", - "@local_tsl//tsl/lib/gtl/subtle:map_traits", + "@local_xla//xla/tsl/lib/gtl/subtle:map_traits", ], visibility = ["//tensorflow/core/lib/gtl:__pkg__"], ) diff --git a/tensorflow/core/lib/gtl/subtle/map_traits.h b/tensorflow/core/lib/gtl/subtle/map_traits.h index a5296b8b93a010..c4cca1fb644640 100644 --- a/tensorflow/core/lib/gtl/subtle/map_traits.h +++ b/tensorflow/core/lib/gtl/subtle/map_traits.h @@ -23,7 +23,7 @@ limitations under the License. #ifndef TENSORFLOW_CORE_LIB_GTL_SUBTLE_MAP_TRAITS_H_ #define TENSORFLOW_CORE_LIB_GTL_SUBTLE_MAP_TRAITS_H_ -#include "tsl/lib/gtl/subtle/map_traits.h" +#include "xla/tsl/lib/gtl/subtle/map_traits.h" namespace tensorflow { namespace gtl { diff --git a/tensorflow/core/lib/hash/BUILD b/tensorflow/core/lib/hash/BUILD index c2b6018d034e64..8c1e8cd471776d 100644 --- a/tensorflow/core/lib/hash/BUILD +++ b/tensorflow/core/lib/hash/BUILD @@ -26,7 +26,7 @@ cc_library( "//tensorflow/core/platform", "//tensorflow/core/platform:cord", "//tensorflow/core/platform:types", - "@local_tsl//tsl/lib/hash:crc32c", + "@local_xla//xla/tsl/lib/hash:crc32c", ], ) @@ -51,7 +51,7 @@ filegroup( name = "mobile_srcs_only_runtime", srcs = [ "crc32c.h", - "@local_tsl//tsl/lib/hash:mobile_srcs_only_runtime", + "@local_xla//xla/tsl/lib/hash:mobile_srcs_only_runtime", ], visibility = ["//tensorflow/core:__pkg__"], ) @@ -61,7 +61,7 @@ filegroup( srcs = [ "crc32c.h", "hash.h", - "@local_tsl//tsl/lib/hash:legacy_lib_hash_all_headers", + "@local_xla//xla/tsl/lib/hash:legacy_lib_hash_all_headers", ], visibility = ["//tensorflow/core:__pkg__"], ) diff --git a/tensorflow/core/lib/hash/crc32c.h b/tensorflow/core/lib/hash/crc32c.h index 07945b10a92b0e..7e8c8307af2e92 100644 --- a/tensorflow/core/lib/hash/crc32c.h +++ b/tensorflow/core/lib/hash/crc32c.h @@ -18,10 +18,10 @@ limitations under the License. #include +#include "xla/tsl/lib/hash/crc32c.h" #include "tensorflow/core/platform/cord.h" #include "tensorflow/core/platform/platform.h" #include "tensorflow/core/platform/types.h" -#include "tsl/lib/hash/crc32c.h" namespace tensorflow { namespace crc32c { diff --git a/tensorflow/core/lib/io/BUILD b/tensorflow/core/lib/io/BUILD index a525a92c43ba65..2a4a4aa5e00c17 100644 --- a/tensorflow/core/lib/io/BUILD +++ b/tensorflow/core/lib/io/BUILD @@ -9,7 +9,7 @@ package( default_visibility = [ "//tensorflow/c/experimental/filesystem:__pkg__", "//tensorflow/c/experimental/filesystem/plugins/posix:__pkg__", - "@local_tsl//tsl/lib/io/snappy:__pkg__", + "@local_xla//xla/tsl/lib/io/snappy:__pkg__", "//third_party/py/tensorflow_io:__subpackages__", # tensorflow/core:lib effectively exposes all targets under tensorflow/core/lib/** "//tensorflow/core:__pkg__", @@ -31,7 +31,7 @@ cc_library( "//tensorflow/core/platform:status", "//tensorflow/core/platform:stringpiece", "//tensorflow/core/platform:types", - "@local_tsl//tsl/lib/io:block", + "@local_xla//xla/tsl/lib/io:block", ], ) @@ -41,14 +41,14 @@ cc_library( deps = [ ":inputstream_interface", "//tensorflow/core/platform:env", - "@local_tsl//tsl/lib/io:buffered_inputstream", + "@local_xla//xla/tsl/lib/io:buffered_inputstream", ], ) cc_library( name = "compression", hdrs = ["compression.h"], - deps = ["@local_tsl//tsl/lib/io:compression"], + deps = ["@local_xla//xla/tsl/lib/io:compression"], ) cc_library( @@ -60,7 +60,7 @@ cc_library( "//tensorflow/core/platform:macros", "//tensorflow/core/platform:status", "//tensorflow/core/platform:types", - "@local_tsl//tsl/lib/io:inputbuffer", + "@local_xla//xla/tsl/lib/io:inputbuffer", ], ) @@ -72,7 +72,7 @@ cc_library( "//tensorflow/core/platform:errors", "//tensorflow/core/platform:status", "//tensorflow/core/platform:types", - "@local_tsl//tsl/lib/io:inputstream_interface", + "@local_xla//xla/tsl/lib/io:inputstream_interface", ], ) @@ -82,7 +82,7 @@ cc_library( deps = [ "//tensorflow/core/platform:status", "//tensorflow/core/platform:stringpiece", - "@local_tsl//tsl/lib/io:iterator", + "@local_xla//xla/tsl/lib/io:iterator", ], ) @@ -100,7 +100,7 @@ cc_library( "//tensorflow/core/platform:logging", "//tensorflow/core/platform:protobuf", "//tensorflow/core/platform:stringpiece", - "@local_tsl//tsl/lib/io:proto_encode_helper", + "@local_xla//xla/tsl/lib/io:proto_encode_helper", ], ) @@ -111,7 +111,7 @@ cc_library( ":inputstream_interface", "//tensorflow/core/platform:cord", "//tensorflow/core/platform:env", - "@local_tsl//tsl/lib/io:random_inputstream", + "@local_xla//xla/tsl/lib/io:random_inputstream", ], ) @@ -127,7 +127,7 @@ cc_library( "//tensorflow/core/platform:macros", "//tensorflow/core/platform:stringpiece", "//tensorflow/core/platform:types", - "@local_tsl//tsl/lib/io:record_reader", + "@local_xla//xla/tsl/lib/io:record_reader", ], ) @@ -145,28 +145,28 @@ cc_library( "//tensorflow/core/platform:status", "//tensorflow/core/platform:stringpiece", "//tensorflow/core/platform:types", - "@local_tsl//tsl/lib/io:record_writer", + "@local_xla//xla/tsl/lib/io:record_writer", ], ) alias( name = "snappy_inputbuffer", - actual = "@local_tsl//tsl/lib/io/snappy:snappy_inputbuffer", + actual = "@local_xla//xla/tsl/lib/io/snappy:snappy_inputbuffer", ) alias( name = "snappy_inputstream", - actual = "@local_tsl//tsl/lib/io/snappy:snappy_inputstream", + actual = "@local_xla//xla/tsl/lib/io/snappy:snappy_inputstream", ) alias( name = "snappy_outputbuffer", - actual = "@local_tsl//tsl/lib/io/snappy:snappy_outputbuffer", + actual = "@local_xla//xla/tsl/lib/io/snappy:snappy_outputbuffer", ) alias( name = "snappy_compression_options", - actual = "@local_tsl//tsl/lib/io/snappy:snappy_compression_options", + actual = "@local_xla//xla/tsl/lib/io/snappy:snappy_compression_options", ) cc_library( @@ -174,7 +174,7 @@ cc_library( hdrs = ["cache.h"], deps = [ "//tensorflow/core/platform:stringpiece", - "@local_tsl//tsl/lib/io:cache", + "@local_xla//xla/tsl/lib/io:cache", ], ) @@ -186,14 +186,14 @@ cc_library( ], deps = [ ":iterator", - "@local_tsl//tsl/lib/io:table", + "@local_xla//xla/tsl/lib/io:table", ], ) cc_library( name = "table_options", hdrs = ["table_options.h"], - deps = ["@local_tsl//tsl/lib/io:table_options"], + deps = ["@local_xla//xla/tsl/lib/io:table_options"], ) cc_library( @@ -201,7 +201,7 @@ cc_library( hdrs = ["zlib_compression_options.h"], deps = [ "//tensorflow/core/platform:types", - "@local_tsl//tsl/lib/io:zlib_compression_options", + "@local_xla//xla/tsl/lib/io:zlib_compression_options", ], ) @@ -215,7 +215,7 @@ cc_library( "//tensorflow/core/platform:macros", "//tensorflow/core/platform:status", "//tensorflow/core/platform:types", - "@local_tsl//tsl/lib/io:zlib_inputstream", + "@local_xla//xla/tsl/lib/io:zlib_inputstream", ], ) @@ -229,7 +229,7 @@ cc_library( "//tensorflow/core/platform:status", "//tensorflow/core/platform:stringpiece", "//tensorflow/core/platform:types", - "@local_tsl//tsl/lib/io:zlib_outputbuffer", + "@local_xla//xla/tsl/lib/io:zlib_outputbuffer", ], ) diff --git a/tensorflow/core/lib/io/block.h b/tensorflow/core/lib/io/block.h index e6417881718060..d3cfb88f97e46f 100644 --- a/tensorflow/core/lib/io/block.h +++ b/tensorflow/core/lib/io/block.h @@ -16,8 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_CORE_LIB_IO_BLOCK_H_ #define TENSORFLOW_CORE_LIB_IO_BLOCK_H_ +#include "xla/tsl/lib/io/block.h" #include "tensorflow/core/lib/io/iterator.h" -#include "tsl/lib/io/block.h" namespace tensorflow { namespace table { diff --git a/tensorflow/core/lib/io/block_builder.h b/tensorflow/core/lib/io/block_builder.h index b83db6dbfa5726..b47278cba40e30 100644 --- a/tensorflow/core/lib/io/block_builder.h +++ b/tensorflow/core/lib/io/block_builder.h @@ -16,9 +16,9 @@ limitations under the License. #ifndef TENSORFLOW_CORE_LIB_IO_BLOCK_BUILDER_H_ #define TENSORFLOW_CORE_LIB_IO_BLOCK_BUILDER_H_ +#include "xla/tsl/lib/io/block_builder.h" #include "tensorflow/core/platform/stringpiece.h" #include "tensorflow/core/platform/types.h" -#include "tsl/lib/io/block_builder.h" namespace tensorflow { namespace table { diff --git a/tensorflow/core/lib/io/buffered_inputstream.h b/tensorflow/core/lib/io/buffered_inputstream.h index b211dc05efefc1..15023e6aa5d5b0 100644 --- a/tensorflow/core/lib/io/buffered_inputstream.h +++ b/tensorflow/core/lib/io/buffered_inputstream.h @@ -16,9 +16,9 @@ limitations under the License. #ifndef TENSORFLOW_CORE_LIB_IO_BUFFERED_INPUTSTREAM_H_ #define TENSORFLOW_CORE_LIB_IO_BUFFERED_INPUTSTREAM_H_ +#include "xla/tsl/lib/io/buffered_inputstream.h" #include "tensorflow/core/lib/io/inputstream_interface.h" #include "tensorflow/core/platform/file_system.h" -#include "tsl/lib/io/buffered_inputstream.h" namespace tensorflow { namespace io { diff --git a/tensorflow/core/lib/io/cache.h b/tensorflow/core/lib/io/cache.h index 7c647d80090fdf..3afd011fdf79e5 100644 --- a/tensorflow/core/lib/io/cache.h +++ b/tensorflow/core/lib/io/cache.h @@ -16,8 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_CORE_LIB_IO_CACHE_H_ #define TENSORFLOW_CORE_LIB_IO_CACHE_H_ +#include "xla/tsl/lib/io/cache.h" #include "tensorflow/core/platform/stringpiece.h" -#include "tsl/lib/io/cache.h" namespace tensorflow { using tsl::Slice; // NOLINT(misc-unused-using-decls) diff --git a/tensorflow/core/lib/io/compression.h b/tensorflow/core/lib/io/compression.h index 326e1a17e5144d..628de3751edb04 100644 --- a/tensorflow/core/lib/io/compression.h +++ b/tensorflow/core/lib/io/compression.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_CORE_LIB_IO_COMPRESSION_H_ #define TENSORFLOW_CORE_LIB_IO_COMPRESSION_H_ -#include "tsl/lib/io/compression.h" +#include "xla/tsl/lib/io/compression.h" namespace tensorflow { namespace io { diff --git a/tensorflow/core/lib/io/format.h b/tensorflow/core/lib/io/format.h index 0ea77614bc4a5c..49f96d1929c658 100644 --- a/tensorflow/core/lib/io/format.h +++ b/tensorflow/core/lib/io/format.h @@ -16,10 +16,10 @@ limitations under the License. #ifndef TENSORFLOW_CORE_LIB_IO_FORMAT_H_ #define TENSORFLOW_CORE_LIB_IO_FORMAT_H_ +#include "xla/tsl/lib/io/format.h" #include "tensorflow/core/lib/io/table_builder.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/stringpiece.h" -#include "tsl/lib/io/format.h" namespace tensorflow { namespace table { diff --git a/tensorflow/core/lib/io/inputbuffer.h b/tensorflow/core/lib/io/inputbuffer.h index 0cc33cb1aac895..2573a81657c056 100644 --- a/tensorflow/core/lib/io/inputbuffer.h +++ b/tensorflow/core/lib/io/inputbuffer.h @@ -16,12 +16,12 @@ limitations under the License. #ifndef TENSORFLOW_CORE_LIB_IO_INPUTBUFFER_H_ #define TENSORFLOW_CORE_LIB_IO_INPUTBUFFER_H_ +#include "xla/tsl/lib/io/inputbuffer.h" #include "tensorflow/core/platform/coding.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/types.h" -#include "tsl/lib/io/inputbuffer.h" namespace tensorflow { namespace io { diff --git a/tensorflow/core/lib/io/inputstream_interface.h b/tensorflow/core/lib/io/inputstream_interface.h index 135043a4f356ba..f38489d55c6d86 100644 --- a/tensorflow/core/lib/io/inputstream_interface.h +++ b/tensorflow/core/lib/io/inputstream_interface.h @@ -16,11 +16,11 @@ limitations under the License. #ifndef TENSORFLOW_CORE_LIB_IO_INPUTSTREAM_INTERFACE_H_ #define TENSORFLOW_CORE_LIB_IO_INPUTSTREAM_INTERFACE_H_ +#include "xla/tsl/lib/io/inputstream_interface.h" #include "tensorflow/core/platform/cord.h" #include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/types.h" -#include "tsl/lib/io/inputstream_interface.h" namespace tensorflow { namespace io { diff --git a/tensorflow/core/lib/io/iterator.h b/tensorflow/core/lib/io/iterator.h index a758cf51ef10e9..4f3c096086a4a5 100644 --- a/tensorflow/core/lib/io/iterator.h +++ b/tensorflow/core/lib/io/iterator.h @@ -26,9 +26,9 @@ limitations under the License. #ifndef TENSORFLOW_CORE_LIB_IO_ITERATOR_H_ #define TENSORFLOW_CORE_LIB_IO_ITERATOR_H_ +#include "xla/tsl/lib/io/iterator.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/stringpiece.h" -#include "tsl/lib/io/iterator.h" namespace tensorflow { namespace table { diff --git a/tensorflow/core/lib/io/proto_encode_helper.h b/tensorflow/core/lib/io/proto_encode_helper.h index 97b98bac26630b..8ca1d5beb300da 100644 --- a/tensorflow/core/lib/io/proto_encode_helper.h +++ b/tensorflow/core/lib/io/proto_encode_helper.h @@ -16,11 +16,11 @@ limitations under the License. #ifndef TENSORFLOW_CORE_LIB_IO_PROTO_ENCODE_HELPER_H_ #define TENSORFLOW_CORE_LIB_IO_PROTO_ENCODE_HELPER_H_ +#include "xla/tsl/lib/io/proto_encode_helper.h" #include "tensorflow/core/platform/coding.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/stringpiece.h" -#include "tsl/lib/io/proto_encode_helper.h" namespace tensorflow { namespace io { diff --git a/tensorflow/core/lib/io/random_inputstream.h b/tensorflow/core/lib/io/random_inputstream.h index cb3c5ed6f98326..70651bc67f3d5c 100644 --- a/tensorflow/core/lib/io/random_inputstream.h +++ b/tensorflow/core/lib/io/random_inputstream.h @@ -16,10 +16,10 @@ limitations under the License. #ifndef TENSORFLOW_CORE_LIB_IO_RANDOM_INPUTSTREAM_H_ #define TENSORFLOW_CORE_LIB_IO_RANDOM_INPUTSTREAM_H_ +#include "xla/tsl/lib/io/random_inputstream.h" #include "tensorflow/core/lib/io/inputstream_interface.h" #include "tensorflow/core/platform/cord.h" #include "tensorflow/core/platform/file_system.h" -#include "tsl/lib/io/random_inputstream.h" namespace tensorflow { namespace io { diff --git a/tensorflow/core/lib/io/record_reader.h b/tensorflow/core/lib/io/record_reader.h index 51a332b6fbd0db..c2a06c6b666908 100644 --- a/tensorflow/core/lib/io/record_reader.h +++ b/tensorflow/core/lib/io/record_reader.h @@ -23,9 +23,9 @@ limitations under the License. #include "tensorflow/core/lib/io/zlib_compression_options.h" #include "tensorflow/core/lib/io/zlib_inputstream.h" #endif // IS_SLIM_BUILD +#include "xla/tsl/lib/io/record_reader.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" -#include "tsl/lib/io/record_reader.h" namespace tensorflow { namespace io { diff --git a/tensorflow/core/lib/io/record_writer.h b/tensorflow/core/lib/io/record_writer.h index 63dd44427a2ea1..602de00ed872d5 100644 --- a/tensorflow/core/lib/io/record_writer.h +++ b/tensorflow/core/lib/io/record_writer.h @@ -24,10 +24,10 @@ limitations under the License. #include "tensorflow/core/lib/io/zlib_compression_options.h" #include "tensorflow/core/lib/io/zlib_outputbuffer.h" #endif // IS_SLIM_BUILD +#include "xla/tsl/lib/io/record_writer.h" #include "tensorflow/core/platform/cord.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" -#include "tsl/lib/io/record_writer.h" namespace tensorflow { namespace io { diff --git a/tensorflow/core/lib/io/table.h b/tensorflow/core/lib/io/table.h index 93466ddaa6e315..0045829a1af5c1 100644 --- a/tensorflow/core/lib/io/table.h +++ b/tensorflow/core/lib/io/table.h @@ -16,8 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_CORE_LIB_IO_TABLE_H_ #define TENSORFLOW_CORE_LIB_IO_TABLE_H_ +#include "xla/tsl/lib/io/table.h" #include "tensorflow/core/lib/io/iterator.h" -#include "tsl/lib/io/table.h" namespace tensorflow { namespace table { diff --git a/tensorflow/core/lib/io/table_builder.h b/tensorflow/core/lib/io/table_builder.h index a895387a2c3fbe..52e27e9af9ef94 100644 --- a/tensorflow/core/lib/io/table_builder.h +++ b/tensorflow/core/lib/io/table_builder.h @@ -24,10 +24,10 @@ limitations under the License. #ifndef TENSORFLOW_CORE_LIB_IO_TABLE_BUILDER_H_ #define TENSORFLOW_CORE_LIB_IO_TABLE_BUILDER_H_ +#include "xla/tsl/lib/io/table_builder.h" #include "tensorflow/core/lib/io/table_options.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/stringpiece.h" -#include "tsl/lib/io/table_builder.h" namespace tensorflow { namespace table { diff --git a/tensorflow/core/lib/io/table_options.h b/tensorflow/core/lib/io/table_options.h index b751aefdbeaef3..c16d4aca7e30b6 100644 --- a/tensorflow/core/lib/io/table_options.h +++ b/tensorflow/core/lib/io/table_options.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_CORE_LIB_IO_TABLE_OPTIONS_H_ #define TENSORFLOW_CORE_LIB_IO_TABLE_OPTIONS_H_ -#include "tsl/lib/io/table_options.h" +#include "xla/tsl/lib/io/table_options.h" namespace tensorflow { namespace table { diff --git a/tensorflow/core/lib/io/two_level_iterator.h b/tensorflow/core/lib/io/two_level_iterator.h index 357efd1d5993b0..c2b94de7f26439 100644 --- a/tensorflow/core/lib/io/two_level_iterator.h +++ b/tensorflow/core/lib/io/two_level_iterator.h @@ -16,8 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_CORE_LIB_IO_TWO_LEVEL_ITERATOR_H_ #define TENSORFLOW_CORE_LIB_IO_TWO_LEVEL_ITERATOR_H_ +#include "xla/tsl/lib/io/two_level_iterator.h" #include "tensorflow/core/lib/io/iterator.h" -#include "tsl/lib/io/two_level_iterator.h" namespace tensorflow { namespace table { diff --git a/tensorflow/core/lib/io/zlib_compression_options.h b/tensorflow/core/lib/io/zlib_compression_options.h index 643c041ec6efc0..a0d433782b69cb 100644 --- a/tensorflow/core/lib/io/zlib_compression_options.h +++ b/tensorflow/core/lib/io/zlib_compression_options.h @@ -16,8 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_CORE_LIB_IO_ZLIB_COMPRESSION_OPTIONS_H_ #define TENSORFLOW_CORE_LIB_IO_ZLIB_COMPRESSION_OPTIONS_H_ +#include "xla/tsl/lib/io/zlib_compression_options.h" #include "tensorflow/core/platform/types.h" -#include "tsl/lib/io/zlib_compression_options.h" namespace tensorflow { namespace io { diff --git a/tensorflow/core/lib/io/zlib_inputstream.h b/tensorflow/core/lib/io/zlib_inputstream.h index 75bef87ca38c2e..086493e31face5 100644 --- a/tensorflow/core/lib/io/zlib_inputstream.h +++ b/tensorflow/core/lib/io/zlib_inputstream.h @@ -16,13 +16,13 @@ limitations under the License. #ifndef TENSORFLOW_CORE_LIB_IO_ZLIB_INPUTSTREAM_H_ #define TENSORFLOW_CORE_LIB_IO_ZLIB_INPUTSTREAM_H_ +#include "xla/tsl/lib/io/zlib_inputstream.h" #include "tensorflow/core/lib/io/inputstream_interface.h" #include "tensorflow/core/lib/io/zlib_compression_options.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/types.h" -#include "tsl/lib/io/zlib_inputstream.h" namespace tensorflow { namespace io { diff --git a/tensorflow/core/lib/io/zlib_outputbuffer.h b/tensorflow/core/lib/io/zlib_outputbuffer.h index f68a594bda8551..7d3950f633abbe 100644 --- a/tensorflow/core/lib/io/zlib_outputbuffer.h +++ b/tensorflow/core/lib/io/zlib_outputbuffer.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_CORE_LIB_IO_ZLIB_OUTPUTBUFFER_H_ #define TENSORFLOW_CORE_LIB_IO_ZLIB_OUTPUTBUFFER_H_ +#include "xla/tsl/lib/io/zlib_outputbuffer.h" #include "tensorflow/core/lib/io/zlib_compression_options.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/file_system.h" @@ -23,7 +24,6 @@ limitations under the License. #include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/stringpiece.h" #include "tensorflow/core/platform/types.h" -#include "tsl/lib/io/zlib_outputbuffer.h" namespace tensorflow { namespace io { diff --git a/tensorflow/core/lib/math/BUILD b/tensorflow/core/lib/math/BUILD index 751af8e9d026ac..7c14708ed116ec 100644 --- a/tensorflow/core/lib/math/BUILD +++ b/tensorflow/core/lib/math/BUILD @@ -20,7 +20,7 @@ cc_library( deps = [ "//tensorflow/core/platform:logging", "//tensorflow/core/platform:types", - "@local_tsl//tsl/lib/math:math_util", + "@local_xla//xla/tsl/lib/math:math_util", ], ) diff --git a/tensorflow/core/lib/math/math_util.h b/tensorflow/core/lib/math/math_util.h index b92e421b6e2821..39bae7f4308a48 100644 --- a/tensorflow/core/lib/math/math_util.h +++ b/tensorflow/core/lib/math/math_util.h @@ -16,9 +16,9 @@ limitations under the License. #ifndef TENSORFLOW_CORE_LIB_MATH_MATH_UTIL_H_ #define TENSORFLOW_CORE_LIB_MATH_MATH_UTIL_H_ +#include "xla/tsl/lib/math/math_util.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" -#include "tsl/lib/math/math_util.h" namespace tensorflow { // NOLINTBEGIN(misc-unused-using-decls) diff --git a/tensorflow/core/lib/random/BUILD b/tensorflow/core/lib/random/BUILD index db2c962671c3f0..ef6262c5876cef 100644 --- a/tensorflow/core/lib/random/BUILD +++ b/tensorflow/core/lib/random/BUILD @@ -13,7 +13,7 @@ package( cc_library( name = "exact_uniform_int", hdrs = ["exact_uniform_int.h"], - deps = ["@local_tsl//tsl/lib/random:exact_uniform_int"], + deps = ["@local_xla//xla/tsl/lib/random:exact_uniform_int"], ) cc_library( @@ -32,7 +32,7 @@ cc_library( "//tensorflow/core/platform:macros", "//tensorflow/core/platform:types", "@eigen_archive//:eigen3", - "@local_tsl//tsl/lib/random:philox", + "@local_xla//xla/tsl/lib/random:philox", ], ) @@ -43,7 +43,7 @@ cc_library( visibility = ["//visibility:private"], deps = [ ":philox_random", - "@local_tsl//tsl/lib/random:random_distributions_utils", + "@local_xla//xla/tsl/lib/random:random_distributions_utils", ], ) @@ -51,7 +51,7 @@ cc_library( name = "philox_random", hdrs = ["philox_random.h"], compatible_with = get_compatible_with_portable(), - deps = ["@local_tsl//tsl/lib/random:philox_random"], + deps = ["@local_xla//xla/tsl/lib/random:philox_random"], ) cc_library( @@ -74,7 +74,7 @@ cc_library( "//tensorflow/core/platform:logging", "//tensorflow/core/platform:macros", "//tensorflow/core/platform:types", - "@local_tsl//tsl/lib/random:weighted_picker", + "@local_xla//xla/tsl/lib/random:weighted_picker", ], ) @@ -90,7 +90,7 @@ filegroup( "random_distributions_utils.h", "simple_philox.h", "weighted_picker.h", - "@local_tsl//tsl/lib/random:mobile_srcs_only_runtime", + "@local_xla//xla/tsl/lib/random:mobile_srcs_only_runtime", ], ) @@ -102,7 +102,7 @@ filegroup( "random_distributions.h", "random_distributions_utils.h", "simple_philox.h", - "@local_tsl//tsl/lib/random:legacy_lib_random_headers", + "@local_xla//xla/tsl/lib/random:legacy_lib_random_headers", ], visibility = ["//tensorflow/core:__pkg__"], ) @@ -114,7 +114,7 @@ filegroup( "random_distributions.h", "random_distributions_utils.h", "weighted_picker.h", - "@local_tsl//tsl/lib/random:legacy_lib_internal_public_random_headers", + "@local_xla//xla/tsl/lib/random:legacy_lib_internal_public_random_headers", ], visibility = ["//tensorflow/core:__pkg__"], ) @@ -130,7 +130,7 @@ filegroup( "random_distributions_utils.h", "simple_philox.h", "weighted_picker.h", - "@local_tsl//tsl/lib/random:legacy_lib_random_all_headers", + "@local_xla//xla/tsl/lib/random:legacy_lib_random_all_headers", ], visibility = ["//tensorflow/core:__pkg__"], ) diff --git a/tensorflow/core/lib/random/distribution_sampler.h b/tensorflow/core/lib/random/distribution_sampler.h index 585def64e91b73..6218d8998fa1ab 100644 --- a/tensorflow/core/lib/random/distribution_sampler.h +++ b/tensorflow/core/lib/random/distribution_sampler.h @@ -31,12 +31,12 @@ limitations under the License. #ifndef TENSORFLOW_CORE_LIB_RANDOM_DISTRIBUTION_SAMPLER_H_ #define TENSORFLOW_CORE_LIB_RANDOM_DISTRIBUTION_SAMPLER_H_ +#include "xla/tsl/lib/random/distribution_sampler.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/random/simple_philox.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" -#include "tsl/lib/random/distribution_sampler.h" namespace tensorflow { namespace random { diff --git a/tensorflow/core/lib/random/exact_uniform_int.h b/tensorflow/core/lib/random/exact_uniform_int.h index 5f02c664b74f0b..cd511d43f55510 100644 --- a/tensorflow/core/lib/random/exact_uniform_int.h +++ b/tensorflow/core/lib/random/exact_uniform_int.h @@ -18,7 +18,7 @@ limitations under the License. #ifndef TENSORFLOW_CORE_LIB_RANDOM_EXACT_UNIFORM_INT_H_ #define TENSORFLOW_CORE_LIB_RANDOM_EXACT_UNIFORM_INT_H_ -#include "tsl/lib/random/exact_uniform_int.h" +#include "xla/tsl/lib/random/exact_uniform_int.h" namespace tensorflow { namespace random { diff --git a/tensorflow/core/lib/random/philox_random.h b/tensorflow/core/lib/random/philox_random.h index c2d44ecfd541de..2fe4120f9674b3 100644 --- a/tensorflow/core/lib/random/philox_random.h +++ b/tensorflow/core/lib/random/philox_random.h @@ -20,7 +20,7 @@ limitations under the License. #ifndef TENSORFLOW_CORE_LIB_RANDOM_PHILOX_RANDOM_H_ #define TENSORFLOW_CORE_LIB_RANDOM_PHILOX_RANDOM_H_ -#include "tsl/lib/random/philox_random.h" +#include "xla/tsl/lib/random/philox_random.h" namespace tensorflow { namespace random { diff --git a/tensorflow/core/lib/random/random_distributions.h b/tensorflow/core/lib/random/random_distributions.h index 0a9e4f94f3d72d..57ce99a07333b0 100644 --- a/tensorflow/core/lib/random/random_distributions.h +++ b/tensorflow/core/lib/random/random_distributions.h @@ -17,10 +17,10 @@ limitations under the License. #define TENSORFLOW_CORE_LIB_RANDOM_RANDOM_DISTRIBUTIONS_H_ #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "xla/tsl/lib/random/random_distributions.h" #include "tensorflow/core/lib/random/philox_random.h" #include "tensorflow/core/lib/random/random_distributions_utils.h" #include "tensorflow/core/platform/types.h" -#include "tsl/lib/random/random_distributions.h" namespace tensorflow { namespace random { diff --git a/tensorflow/core/lib/random/random_distributions_utils.h b/tensorflow/core/lib/random/random_distributions_utils.h index 4b7267031bdd86..4c2680493bceae 100644 --- a/tensorflow/core/lib/random/random_distributions_utils.h +++ b/tensorflow/core/lib/random/random_distributions_utils.h @@ -20,8 +20,8 @@ limitations under the License. #include +#include "xla/tsl/lib/random/random_distributions_utils.h" #include "tensorflow/core/lib/random/philox_random.h" -#include "tsl/lib/random/random_distributions_utils.h" namespace tensorflow { namespace random { diff --git a/tensorflow/core/lib/random/simple_philox.h b/tensorflow/core/lib/random/simple_philox.h index fa7f49ebecfaaf..7c94ca21414459 100644 --- a/tensorflow/core/lib/random/simple_philox.h +++ b/tensorflow/core/lib/random/simple_philox.h @@ -16,9 +16,9 @@ limitations under the License. #ifndef TENSORFLOW_CORE_LIB_RANDOM_SIMPLE_PHILOX_H_ #define TENSORFLOW_CORE_LIB_RANDOM_SIMPLE_PHILOX_H_ +#include "xla/tsl/lib/random/simple_philox.h" #include "tensorflow/core/lib/random/philox_random.h" #include "tensorflow/core/lib/random/random_distributions.h" -#include "tsl/lib/random/simple_philox.h" namespace tensorflow { namespace random { diff --git a/tensorflow/core/lib/random/weighted_picker.h b/tensorflow/core/lib/random/weighted_picker.h index 58d4198ae97d4a..ae404814960096 100644 --- a/tensorflow/core/lib/random/weighted_picker.h +++ b/tensorflow/core/lib/random/weighted_picker.h @@ -29,10 +29,10 @@ limitations under the License. #include +#include "xla/tsl/lib/random/weighted_picker.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" -#include "tsl/lib/random/weighted_picker.h" namespace tensorflow { namespace random { diff --git a/tensorflow/core/profiler/convert/trace_viewer/BUILD b/tensorflow/core/profiler/convert/trace_viewer/BUILD index 85ee0784aada65..eef42433fbdadd 100644 --- a/tensorflow/core/profiler/convert/trace_viewer/BUILD +++ b/tensorflow/core/profiler/convert/trace_viewer/BUILD @@ -125,9 +125,9 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", - "@local_tsl//tsl/lib/io:iterator", "@local_tsl//tsl/platform:status", "@local_tsl//tsl/profiler/lib:context_types_hdrs", "@local_tsl//tsl/profiler/utils:timespan", + "@local_xla//xla/tsl/lib/io:iterator", ], ) diff --git a/tensorflow/core/profiler/convert/trace_viewer/trace_events.cc b/tensorflow/core/profiler/convert/trace_viewer/trace_events.cc index cb0de415dba4f8..828550de9a4594 100644 --- a/tensorflow/core/profiler/convert/trace_viewer/trace_events.cc +++ b/tensorflow/core/profiler/convert/trace_viewer/trace_events.cc @@ -31,6 +31,10 @@ limitations under the License. #include "absl/log/log.h" #include "absl/status/status.h" #include "absl/strings/string_view.h" +#include "xla/tsl/lib/io/iterator.h" +#include "xla/tsl/lib/io/table.h" +#include "xla/tsl/lib/io/table_builder.h" +#include "xla/tsl/lib/io/table_options.h" #include "tensorflow/core/platform/file_system.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" @@ -39,10 +43,6 @@ limitations under the License. #include "tensorflow/core/profiler/convert/trace_viewer/trace_viewer_visibility.h" #include "tensorflow/core/profiler/protobuf/trace_events.pb.h" #include "tensorflow/core/profiler/protobuf/trace_events_raw.pb.h" -#include "tsl/lib/io/iterator.h" -#include "tsl/lib/io/table.h" -#include "tsl/lib/io/table_builder.h" -#include "tsl/lib/io/table_options.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" #include "tsl/platform/file_system.h" diff --git a/tensorflow/core/profiler/convert/trace_viewer/trace_events.h b/tensorflow/core/profiler/convert/trace_viewer/trace_events.h index 3b627417e6d706..e5a76838a6fa6b 100644 --- a/tensorflow/core/profiler/convert/trace_viewer/trace_events.h +++ b/tensorflow/core/profiler/convert/trace_viewer/trace_events.h @@ -34,13 +34,13 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" +#include "xla/tsl/lib/io/table.h" #include "tensorflow/core/profiler/convert/trace_viewer/trace_events_filter_interface.h" #include "tensorflow/core/profiler/convert/trace_viewer/trace_events_util.h" #include "tensorflow/core/profiler/convert/trace_viewer/trace_viewer_visibility.h" #include "tensorflow/core/profiler/lib/context_types.h" #include "tensorflow/core/profiler/protobuf/task.pb.h" #include "tensorflow/core/profiler/protobuf/trace_events.pb.h" -#include "tsl/lib/io/table.h" #include "tsl/platform/errors.h" #include "tsl/platform/file_system.h" #include "tsl/platform/status.h" diff --git a/tensorflow/core/util/tensor_bundle/BUILD b/tensorflow/core/util/tensor_bundle/BUILD index 4ca9b222fb114b..9aad78e3890057 100644 --- a/tensorflow/core/util/tensor_bundle/BUILD +++ b/tensorflow/core/util/tensor_bundle/BUILD @@ -61,7 +61,7 @@ cc_library( "@com_google_absl//absl/functional:function_ref", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", - "@local_tsl//tsl/lib/io:buffered_file", + "@local_xla//xla/tsl/lib/io:buffered_file", "@local_xla//xla/tsl/util:byte_swap_array", ], ) diff --git a/tensorflow/core/util/tensor_bundle/tensor_bundle.cc b/tensorflow/core/util/tensor_bundle/tensor_bundle.cc index 87038432ae9428..c97356202bcd93 100644 --- a/tensorflow/core/util/tensor_bundle/tensor_bundle.cc +++ b/tensorflow/core/util/tensor_bundle/tensor_bundle.cc @@ -22,6 +22,7 @@ limitations under the License. #include "absl/base/call_once.h" #include "absl/synchronization/mutex.h" +#include "xla/tsl/lib/io/buffered_file.h" #include "xla/tsl/util/byte_swap_array.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.pb.h" @@ -54,7 +55,6 @@ limitations under the License. #include "tensorflow/core/util/tensor_bundle/byte_swap_tensor.h" #include "tensorflow/core/util/tensor_bundle/naming.h" #include "tensorflow/core/util/tensor_slice_util.h" -#include "tsl/lib/io/buffered_file.h" #ifdef PLATFORM_WINDOWS #undef DeleteFile diff --git a/tensorflow/core/util/tensor_bundle/tensor_bundle.h b/tensorflow/core/util/tensor_bundle/tensor_bundle.h index ba1a4f7053aac6..e3d8bb590ce411 100644 --- a/tensorflow/core/util/tensor_bundle/tensor_bundle.h +++ b/tensorflow/core/util/tensor_bundle/tensor_bundle.h @@ -72,6 +72,7 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/functional/function_ref.h" #include "absl/strings/string_view.h" +#include "xla/tsl/lib/io/buffered_file.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_slice.h" @@ -87,7 +88,6 @@ limitations under the License. #include "tensorflow/core/platform/tstring.h" #include "tensorflow/core/protobuf/tensor_bundle.pb.h" #include "tensorflow/core/util/tensor_slice_set.h" -#include "tsl/lib/io/buffered_file.h" #include "tsl/platform/errors.h" namespace tensorflow { diff --git a/tensorflow/lite/kernels/BUILD b/tensorflow/lite/kernels/BUILD index 4c08175c3abe93..bc7d65f517ca9e 100644 --- a/tensorflow/lite/kernels/BUILD +++ b/tensorflow/lite/kernels/BUILD @@ -885,8 +885,8 @@ cc_library( "//tensorflow/lite:array", "//tensorflow/lite:builtin_ops", "//tensorflow/lite:cc_api_stable", - "@local_tsl//tsl/lib/random:philox_random", - "@local_tsl//tsl/lib/random:random_distributions_utils", + "@local_xla//xla/tsl/lib/random:philox_random", + "@local_xla//xla/tsl/lib/random:random_distributions_utils", "//tensorflow/lite/core/c:c_api_types", # TODO(b/179298174): Move out from the experimental directory. "//tensorflow/lite/experimental/resource", diff --git a/tensorflow/lite/kernels/random_ops.cc b/tensorflow/lite/kernels/random_ops.cc index 70665061f39cb4..28f0e3f80ccf2b 100644 --- a/tensorflow/lite/kernels/random_ops.cc +++ b/tensorflow/lite/kernels/random_ops.cc @@ -17,11 +17,11 @@ limitations under the License. #include #include +#include "xla/tsl/lib/random/philox_random.h" +#include "xla/tsl/lib/random/random_distributions_utils.h" #include "tensorflow/lite/core/c/builtin_op_data.h" #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" #include "tensorflow/lite/kernels/kernel_util.h" -#include "tsl/lib/random/philox_random.h" -#include "tsl/lib/random/random_distributions_utils.h" namespace tflite { namespace ops { diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/BUILD b/third_party/xla/third_party/tsl/tsl/platform/cloud/BUILD index 77a79f56e1ae44..46bca19c70940d 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/cloud/BUILD +++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/BUILD @@ -207,7 +207,6 @@ cc_library( copts = tsl_copts(), deps = [ ":http_request", - "//tsl/lib/gtl:map_util", "//tsl/platform:env", "//tsl/platform:errors", "//tsl/platform:macros", @@ -218,6 +217,7 @@ cc_library( "//tsl/platform:stringpiece", "//tsl/platform:types", "@curl", + "@local_xla//xla/tsl/lib/gtl:map_util", "@local_xla//xla/tsl/util:env_var", ], ) diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/curl_http_request.cc b/third_party/xla/third_party/tsl/tsl/platform/cloud/curl_http_request.cc index 44eeab7f511fb9..fde422c2d04919 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/cloud/curl_http_request.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/curl_http_request.cc @@ -17,8 +17,8 @@ limitations under the License. #include +#include "xla/tsl/lib/gtl/map_util.h" #include "xla/tsl/util/env_var.h" -#include "tsl/lib/gtl/map_util.h" #include "tsl/platform/errors.h" #include "tsl/platform/macros.h" #include "tsl/platform/scanner.h" diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/BUILD b/third_party/xla/third_party/tsl/tsl/profiler/utils/BUILD index 4c0eda1496c5b7..6539e90ea0157c 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/BUILD +++ b/third_party/xla/third_party/tsl/tsl/profiler/utils/BUILD @@ -119,7 +119,6 @@ cc_library( visibility = internal_visibility([":friends"]), deps = [ ":tf_op_utils", - "//tsl/lib/gtl:map_util", "//tsl/platform:logging", "//tsl/platform:macros", "//tsl/platform:types", @@ -128,6 +127,7 @@ cc_library( "@com_google_absl//absl/hash", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", + "@local_xla//xla/tsl/lib/gtl:map_util", ], ) @@ -289,7 +289,6 @@ cc_library( ":xplane_schema", ":xplane_utils", ":xplane_visitor", - "//tsl/lib/gtl:map_util", "//tsl/platform:dso_loader", "//tsl/platform:env", "//tsl/platform:logging", @@ -301,6 +300,7 @@ cc_library( "@com_google_absl//absl/functional:bind_front", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", + "@local_xla//xla/tsl/lib/gtl:map_util", ], ) diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/group_events.cc b/third_party/xla/third_party/tsl/tsl/profiler/utils/group_events.cc index 2811232b43bc5e..d8f3d4ad94c12e 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/group_events.cc +++ b/third_party/xla/third_party/tsl/tsl/profiler/utils/group_events.cc @@ -31,7 +31,7 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/functional/bind_front.h" #include "absl/strings/str_cat.h" -#include "tsl/lib/gtl/map_util.h" +#include "xla/tsl/lib/gtl/map_util.h" #include "tsl/platform/dso_loader.h" #include "tsl/platform/env.h" #include "tsl/platform/types.h" diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_schema.cc b/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_schema.cc index da16a8704187be..deed680f1d8bb5 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_schema.cc +++ b/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_schema.cc @@ -22,7 +22,7 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" -#include "tsl/lib/gtl/map_util.h" +#include "xla/tsl/lib/gtl/map_util.h" #include "tsl/profiler/utils/tf_op_utils.h" namespace tsl { diff --git a/third_party/xla/xla/BUILD b/third_party/xla/xla/BUILD index c9316d9c546119..0d9ad2f953a560 100644 --- a/third_party/xla/xla/BUILD +++ b/third_party/xla/xla/BUILD @@ -336,6 +336,8 @@ cc_library( ":status_macros", ":types", ":xla_data_proto_cc", + "//xla/tsl/lib/gtl:iterator_range", + "//xla/tsl/lib/math:math_util", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base", "@com_google_absl//absl/base:core_headers", @@ -353,8 +355,6 @@ cc_library( "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", "@eigen_archive//:eigen3", - "@local_tsl//tsl/lib/gtl:iterator_range", - "@local_tsl//tsl/lib/math:math_util", "@local_tsl//tsl/platform:bfloat16", "@local_tsl//tsl/platform:casts", "@local_tsl//tsl/platform:env", @@ -911,10 +911,10 @@ cc_library( ":util", ":xla_data_proto_cc", "//xla/service:hlo_parser", + "//xla/tsl/lib/io:buffered_inputstream", + "//xla/tsl/lib/io:random_inputstream", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@local_tsl//tsl/lib/io:buffered_inputstream", - "@local_tsl//tsl/lib/io:random_inputstream", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:protobuf", ], @@ -976,12 +976,12 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":shape_util", + "//xla/tsl/lib/gtl:iterator_range", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/functional:function_ref", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/lib/gtl:iterator_range", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:statusor", @@ -1063,11 +1063,11 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/service:hlo_module_config", "//xla/service:shape_inference", + "//xla/tsl/lib/math:math_util", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/functional:function_ref", "@com_google_absl//absl/log:check", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/lib/math:math_util", "@local_tsl//tsl/platform:logging", ], ) diff --git a/third_party/xla/xla/client/BUILD b/third_party/xla/xla/client/BUILD index 1e19e92cbc844c..115e56f66fbcfb 100644 --- a/third_party/xla/xla/client/BUILD +++ b/third_party/xla/xla/client/BUILD @@ -48,10 +48,10 @@ cc_library( deps = [ "//xla:types", "//xla:util", + "//xla/tsl/lib/math:math_util", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/lib/math:math_util", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:status", ], diff --git a/third_party/xla/xla/client/padding.cc b/third_party/xla/xla/client/padding.cc index daf26d5467ac7b..8f4a536c0805e4 100644 --- a/third_party/xla/xla/client/padding.cc +++ b/third_party/xla/xla/client/padding.cc @@ -21,8 +21,8 @@ limitations under the License. #include "absl/status/status.h" #include "absl/types/span.h" +#include "xla/tsl/lib/math/math_util.h" #include "xla/util.h" -#include "tsl/lib/math/math_util.h" #include "tsl/platform/logging.h" #include "tsl/platform/status.h" diff --git a/third_party/xla/xla/ffi/BUILD b/third_party/xla/xla/ffi/BUILD index d4ff3dff4c8215..f8105e33fffe54 100644 --- a/third_party/xla/xla/ffi/BUILD +++ b/third_party/xla/xla/ffi/BUILD @@ -228,11 +228,11 @@ cc_library( hdrs = ["type_id_registry.h"], deps = [ "//xla:util", + "//xla/tsl/lib/gtl:int_type", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/synchronization", - "@local_tsl//tsl/lib/gtl:int_type", ], ) diff --git a/third_party/xla/xla/ffi/type_id_registry.h b/third_party/xla/xla/ffi/type_id_registry.h index 116142b3de0f23..5672ac691e253b 100644 --- a/third_party/xla/xla/ffi/type_id_registry.h +++ b/third_party/xla/xla/ffi/type_id_registry.h @@ -20,7 +20,7 @@ limitations under the License. #include #include "absl/status/statusor.h" -#include "tsl/lib/gtl/int_type.h" +#include "xla/tsl/lib/gtl/int_type.h" namespace xla::ffi { diff --git a/third_party/xla/xla/hlo/ir/BUILD b/third_party/xla/xla/hlo/ir/BUILD index e65c48d982d89d..b39bc59b5a801d 100644 --- a/third_party/xla/xla/hlo/ir/BUILD +++ b/third_party/xla/xla/hlo/ir/BUILD @@ -85,6 +85,8 @@ cc_library( "//xla/service:hlo_proto_cc", "//xla/service:mapped_ptr_container_sorter", "//xla/service:name_uniquer", + "//xla/tsl/lib/gtl:iterator_range", + "//xla/tsl/lib/gtl:map_util", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:btree", @@ -102,8 +104,6 @@ cc_library( "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/lib/gtl:iterator_range", - "@local_tsl//tsl/lib/gtl:map_util", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:fingerprint", diff --git a/third_party/xla/xla/hlo/ir/hlo_computation.cc b/third_party/xla/xla/hlo/ir/hlo_computation.cc index 34ea42a0536f9e..13255a919211e7 100644 --- a/third_party/xla/xla/hlo/ir/hlo_computation.cc +++ b/third_party/xla/xla/hlo/ir/hlo_computation.cc @@ -51,9 +51,9 @@ limitations under the License. #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/status_macros.h" +#include "xla/tsl/lib/gtl/iterator_range.h" #include "xla/util.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/gtl/iterator_range.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" #include "tsl/platform/status.h" diff --git a/third_party/xla/xla/hlo/ir/hlo_computation.h b/third_party/xla/xla/hlo/ir/hlo_computation.h index 3e73a68762e74f..e6463f774cf513 100644 --- a/third_party/xla/xla/hlo/ir/hlo_computation.h +++ b/third_party/xla/xla/hlo/ir/hlo_computation.h @@ -48,9 +48,9 @@ limitations under the License. #include "xla/shape_tree.h" #include "xla/shape_util.h" #include "xla/status_macros.h" +#include "xla/tsl/lib/gtl/iterator_range.h" #include "xla/util.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/gtl/iterator_range.h" #include "tsl/platform/errors.h" namespace xla { diff --git a/third_party/xla/xla/hlo/ir/hlo_instruction.cc b/third_party/xla/xla/hlo/ir/hlo_instruction.cc index 2860693f88b2c0..ad23f22909db5e 100644 --- a/third_party/xla/xla/hlo/ir/hlo_instruction.cc +++ b/third_party/xla/xla/hlo/ir/hlo_instruction.cc @@ -76,10 +76,10 @@ limitations under the License. #include "xla/shape_util.h" #include "xla/sort_json.h" #include "xla/status_macros.h" +#include "xla/tsl/lib/gtl/iterator_range.h" +#include "xla/tsl/lib/gtl/map_util.h" #include "xla/util.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/gtl/iterator_range.h" -#include "tsl/lib/gtl/map_util.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" // IWYU pragma: keep #include "tsl/platform/status.h" diff --git a/third_party/xla/xla/hlo/ir/hlo_instruction.h b/third_party/xla/xla/hlo/ir/hlo_instruction.h index 3dcf016acd6cd0..42729daec64df3 100644 --- a/third_party/xla/xla/hlo/ir/hlo_instruction.h +++ b/third_party/xla/xla/hlo/ir/hlo_instruction.h @@ -64,8 +64,8 @@ limitations under the License. #include "xla/service/name_uniquer.h" #include "xla/shape.h" #include "xla/shape_util.h" +#include "xla/tsl/lib/gtl/iterator_range.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/gtl/iterator_range.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" // IWYU pragma: keep #include "tsl/platform/protobuf.h" diff --git a/third_party/xla/xla/hlo/ir/hlo_instructions.cc b/third_party/xla/xla/hlo/ir/hlo_instructions.cc index 0f88a38a0448e2..3909ced6f00f92 100644 --- a/third_party/xla/xla/hlo/ir/hlo_instructions.cc +++ b/third_party/xla/xla/hlo/ir/hlo_instructions.cc @@ -59,10 +59,10 @@ limitations under the License. #include "xla/service/hlo.pb.h" #include "xla/shape.h" #include "xla/shape_util.h" +#include "xla/tsl/lib/gtl/iterator_range.h" #include "xla/util.h" #include "xla/window_util.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/gtl/iterator_range.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" // IWYU pragma: keep #include "tsl/platform/protobuf.h" diff --git a/third_party/xla/xla/hlo/ir/hlo_instructions.h b/third_party/xla/xla/hlo/ir/hlo_instructions.h index c0f03248dbf772..8a0fca0bade44d 100644 --- a/third_party/xla/xla/hlo/ir/hlo_instructions.h +++ b/third_party/xla/xla/hlo/ir/hlo_instructions.h @@ -44,8 +44,8 @@ limitations under the License. #include "xla/service/hlo.pb.h" #include "xla/shape.h" #include "xla/shape_util.h" +#include "xla/tsl/lib/gtl/iterator_range.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/gtl/iterator_range.h" #include "tsl/platform/logging.h" // IWYU pragma: keep #include "tsl/platform/status.h" diff --git a/third_party/xla/xla/hlo/ir/hlo_module.cc b/third_party/xla/xla/hlo/ir/hlo_module.cc index 0e98622801e97f..fc210509351a22 100644 --- a/third_party/xla/xla/hlo/ir/hlo_module.cc +++ b/third_party/xla/xla/hlo/ir/hlo_module.cc @@ -56,9 +56,9 @@ limitations under the License. #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/status_macros.h" +#include "xla/tsl/lib/gtl/map_util.h" #include "xla/util.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/gtl/map_util.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" #include "tsl/platform/fingerprint.h" diff --git a/third_party/xla/xla/hlo/ir/hlo_module.h b/third_party/xla/xla/hlo/ir/hlo_module.h index 6dea7d5234fc8b..65004c386e8bfa 100644 --- a/third_party/xla/xla/hlo/ir/hlo_module.h +++ b/third_party/xla/xla/hlo/ir/hlo_module.h @@ -46,8 +46,8 @@ limitations under the License. #include "xla/service/hlo.pb.h" #include "xla/service/hlo_module_config.h" #include "xla/service/name_uniquer.h" +#include "xla/tsl/lib/gtl/iterator_range.h" #include "xla/xla.pb.h" -#include "tsl/lib/gtl/iterator_range.h" #include "tsl/platform/logging.h" namespace xla { diff --git a/third_party/xla/xla/hlo/ir/hlo_schedule.cc b/third_party/xla/xla/hlo/ir/hlo_schedule.cc index 18747b3ae50f9e..6a472b55904023 100644 --- a/third_party/xla/xla/hlo/ir/hlo_schedule.cc +++ b/third_party/xla/xla/hlo/ir/hlo_schedule.cc @@ -33,8 +33,8 @@ limitations under the License. #include "xla/hlo/ir/hlo_module.h" #include "xla/map_util.h" #include "xla/status_macros.h" +#include "xla/tsl/lib/gtl/map_util.h" #include "xla/util.h" -#include "tsl/lib/gtl/map_util.h" namespace xla { diff --git a/third_party/xla/xla/iterator_util.h b/third_party/xla/xla/iterator_util.h index 80001e4a9b2996..2457348e2d3f8e 100644 --- a/third_party/xla/xla/iterator_util.h +++ b/third_party/xla/xla/iterator_util.h @@ -20,7 +20,7 @@ limitations under the License. #include #include -#include "tsl/lib/gtl/iterator_range.h" +#include "xla/tsl/lib/gtl/iterator_range.h" namespace xla { diff --git a/third_party/xla/xla/pjrt/BUILD b/third_party/xla/xla/pjrt/BUILD index a242a71adc4747..b98825d15f9b97 100644 --- a/third_party/xla/xla/pjrt/BUILD +++ b/third_party/xla/xla/pjrt/BUILD @@ -364,7 +364,7 @@ cc_library( hdrs = ["pjrt_common.h"], visibility = internal_visibility([":friends"]), deps = [ - "@local_tsl//tsl/lib/gtl:int_type", + "//xla/tsl/lib/gtl:int_type", ], ) diff --git a/third_party/xla/xla/pjrt/pjrt_common.h b/third_party/xla/xla/pjrt/pjrt_common.h index 042d28acd12a05..8d11cdae79b3c9 100644 --- a/third_party/xla/xla/pjrt/pjrt_common.h +++ b/third_party/xla/xla/pjrt/pjrt_common.h @@ -21,7 +21,7 @@ limitations under the License. #include #include -#include "tsl/lib/gtl/int_type.h" +#include "xla/tsl/lib/gtl/int_type.h" namespace xla { diff --git a/third_party/xla/xla/python/ifrt/BUILD b/third_party/xla/xla/python/ifrt/BUILD index 6531018329ad3c..9380b8c811db8e 100644 --- a/third_party/xla/xla/python/ifrt/BUILD +++ b/third_party/xla/xla/python/ifrt/BUILD @@ -105,6 +105,7 @@ cc_library( "//xla/python/ifrt/ir:sharding_param", "//xla/service:computation_placer_hdr", "//xla/tsl/concurrency:ref_count", + "//xla/tsl/lib/gtl:int_type", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:inlined_vector", @@ -120,7 +121,6 @@ cc_library( "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", - "@local_tsl//tsl/lib/gtl:int_type", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:statusor", diff --git a/third_party/xla/xla/python/ifrt/device.h b/third_party/xla/xla/python/ifrt/device.h index 3da8fe57c5a772..8c5db780ff41c3 100644 --- a/third_party/xla/xla/python/ifrt/device.h +++ b/third_party/xla/xla/python/ifrt/device.h @@ -24,7 +24,7 @@ limitations under the License. #include "llvm/Support/ExtensibleRTTI.h" #include "xla/python/ifrt/attribute_map.h" #include "xla/python/ifrt/device.pb.h" -#include "tsl/lib/gtl/int_type.h" +#include "xla/tsl/lib/gtl/int_type.h" namespace xla { namespace ifrt { diff --git a/third_party/xla/xla/python/ifrt/device_list.h b/third_party/xla/xla/python/ifrt/device_list.h index b34522f9b75686..b2ddbe221abe3f 100644 --- a/third_party/xla/xla/python/ifrt/device_list.h +++ b/third_party/xla/xla/python/ifrt/device_list.h @@ -34,6 +34,7 @@ limitations under the License. #include "xla/python/ifrt/device.h" #include "xla/python/ifrt/device.pb.h" #include "xla/tsl/concurrency/ref_count.h" +#include "xla/tsl/lib/gtl/int_type.h" namespace xla { namespace ifrt { diff --git a/third_party/xla/xla/reference_util.cc b/third_party/xla/xla/reference_util.cc index d7461cadcfe4a4..6322783b4ad940 100644 --- a/third_party/xla/xla/reference_util.cc +++ b/third_party/xla/xla/reference_util.cc @@ -38,9 +38,9 @@ limitations under the License. #include "xla/service/hlo_module_config.h" #include "xla/service/shape_inference.h" #include "xla/shape.h" +#include "xla/tsl/lib/math/math_util.h" #include "xla/window_util.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/math/math_util.h" #include "tsl/platform/logging.h" namespace xla { diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD index 2f770516b8d9d1..c307f128ca1342 100644 --- a/third_party/xla/xla/service/BUILD +++ b/third_party/xla/xla/service/BUILD @@ -654,6 +654,8 @@ cc_library( "//xla:util", "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", + "//xla/tsl/lib/io:zlib_compression_options", + "//xla/tsl/lib/io:zlib_outputbuffer", "//xla/tsl/lib/strings:proto_serialization", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", @@ -669,8 +671,6 @@ cc_library( "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", "@llvm-project//mlir:Transforms", - "@local_tsl//tsl/lib/io:zlib_compression_options", - "@local_tsl//tsl/lib/io:zlib_outputbuffer", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:path", @@ -2180,13 +2180,13 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/service/heap_simulator", + "//xla/tsl/lib/gtl:map_util", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", - "@local_tsl//tsl/lib/gtl:map_util", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:numbers", @@ -4483,13 +4483,13 @@ cc_library( "//xla:window_util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", + "//xla/tsl/lib/gtl:map_util", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", - "@local_tsl//tsl/lib/gtl:map_util", "@local_tsl//tsl/platform:errors", ], ) @@ -4635,8 +4635,8 @@ cc_library( deps = [ ":buffer_value", ":logical_buffer", + "//xla/tsl/lib/gtl:compactptrset", "@com_google_absl//absl/container:flat_hash_set", - "@local_tsl//tsl/lib/gtl:compactptrset", ], ) @@ -4651,9 +4651,9 @@ cc_library( "//xla:types", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", + "//xla/tsl/lib/gtl:int_type", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/lib/gtl:int_type", ], ) @@ -4982,6 +4982,7 @@ cc_library( "//xla:util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", + "//xla/tsl/lib/gtl:compactptrset", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", @@ -4990,7 +4991,6 @@ cc_library( "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/lib/gtl:compactptrset", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:status", @@ -5925,6 +5925,9 @@ cc_library( "//xla/service/gpu:backend_configs_cc", "//xla/service/gpu:cublas_cudnn", "//xla/stream_executor:dnn", + "//xla/tsl/lib/gtl:map_util", + "//xla/tsl/lib/io:zlib_compression_options", + "//xla/tsl/lib/io:zlib_outputbuffer", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", @@ -5937,9 +5940,6 @@ cc_library( "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", - "@local_tsl//tsl/lib/gtl:map_util", - "@local_tsl//tsl/lib/io:zlib_compression_options", - "@local_tsl//tsl/lib/io:zlib_outputbuffer", "@local_tsl//tsl/platform:base64", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:errors", @@ -6993,6 +6993,7 @@ cc_library( "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/ir:tile_assignment", + "//xla/tsl/lib/gtl:map_util", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base", "@com_google_absl//absl/container:flat_hash_map", @@ -7005,7 +7006,6 @@ cc_library( "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", "@eigen_archive//:eigen3", - "@local_tsl//tsl/lib/gtl:map_util", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:status", @@ -7633,9 +7633,9 @@ cc_library( hdrs = ["global_device_id.h"], deps = [ "//xla:types", + "//xla/tsl/lib/gtl:int_type", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/lib/gtl:int_type", ], ) diff --git a/third_party/xla/xla/service/buffer_value_containers.h b/third_party/xla/xla/service/buffer_value_containers.h index 2e02dd8df7dec3..9b2cfaffee730b 100644 --- a/third_party/xla/xla/service/buffer_value_containers.h +++ b/third_party/xla/xla/service/buffer_value_containers.h @@ -19,7 +19,7 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "xla/service/buffer_value.h" #include "xla/service/logical_buffer.h" -#include "tsl/lib/gtl/compactptrset.h" +#include "xla/tsl/lib/gtl/compactptrset.h" namespace xla { diff --git a/third_party/xla/xla/service/cpu/BUILD b/third_party/xla/xla/service/cpu/BUILD index 72f4d1cbb613ac..ee615cd9ef75d2 100644 --- a/third_party/xla/xla/service/cpu/BUILD +++ b/third_party/xla/xla/service/cpu/BUILD @@ -768,6 +768,7 @@ cc_library( "//xla/service/llvm_ir:llvm_util", "//xla/service/llvm_ir:loop_emitter", "//xla/service/llvm_ir:tuple_ops", + "//xla/tsl/lib/math:math_util", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/cleanup", "@com_google_absl//absl/container:flat_hash_map", @@ -782,7 +783,6 @@ cc_library( "@llvm-project//llvm:Core", "@llvm-project//llvm:TargetParser", "@llvm-project//mlir:IR", - "@local_tsl//tsl/lib/math:math_util", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:statusor", diff --git a/third_party/xla/xla/service/cpu/ir_emitter.cc b/third_party/xla/xla/service/cpu/ir_emitter.cc index e043b5c2e13bec..9d566236406b78 100644 --- a/third_party/xla/xla/service/cpu/ir_emitter.cc +++ b/third_party/xla/xla/service/cpu/ir_emitter.cc @@ -87,10 +87,10 @@ limitations under the License. #include "xla/service/llvm_ir/tuple_ops.h" #include "xla/shape_util.h" #include "xla/status_macros.h" +#include "xla/tsl/lib/math/math_util.h" #include "xla/util.h" #include "xla/window_util.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/math/math_util.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" #include "tsl/platform/statusor.h" diff --git a/third_party/xla/xla/service/dump.cc b/third_party/xla/xla/service/dump.cc index 3aa3a8862011a3..fcfbc22159d4ac 100644 --- a/third_party/xla/xla/service/dump.cc +++ b/third_party/xla/xla/service/dump.cc @@ -50,10 +50,10 @@ limitations under the License. #include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/hlo_graph_dumper.h" #include "xla/service/hlo_proto_util.h" +#include "xla/tsl/lib/io/zlib_compression_options.h" +#include "xla/tsl/lib/io/zlib_outputbuffer.h" #include "xla/tsl/lib/strings/proto_serialization.h" #include "xla/util.h" -#include "tsl/lib/io/zlib_compression_options.h" -#include "tsl/lib/io/zlib_outputbuffer.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" #include "tsl/platform/file_system.h" diff --git a/third_party/xla/xla/service/global_device_id.h b/third_party/xla/xla/service/global_device_id.h index 78f4c0a3dc914a..92f30b9f1c11cb 100644 --- a/third_party/xla/xla/service/global_device_id.h +++ b/third_party/xla/xla/service/global_device_id.h @@ -20,7 +20,7 @@ limitations under the License. #include #include "absl/types/span.h" -#include "tsl/lib/gtl/int_type.h" +#include "xla/tsl/lib/gtl/int_type.h" namespace xla { diff --git a/third_party/xla/xla/service/gpu/kernels/BUILD b/third_party/xla/xla/service/gpu/kernels/BUILD index db3a909357eda6..6676cafcc20dd1 100644 --- a/third_party/xla/xla/service/gpu/kernels/BUILD +++ b/third_party/xla/xla/service/gpu/kernels/BUILD @@ -178,7 +178,7 @@ gpu_kernel_library( deps = [ "//xla:types", "//xla/stream_executor/gpu:gpu_types_header", - "@local_tsl//tsl/lib/math:math_util", + "//xla/tsl/lib/math:math_util", ], ) diff --git a/third_party/xla/xla/service/gpu/kernels/topk_kernel.cu.h b/third_party/xla/xla/service/gpu/kernels/topk_kernel.cu.h index ee3d71f4a36423..c5390649ac9945 100644 --- a/third_party/xla/xla/service/gpu/kernels/topk_kernel.cu.h +++ b/third_party/xla/xla/service/gpu/kernels/topk_kernel.cu.h @@ -26,7 +26,7 @@ limitations under the License. #include "xla/service/gpu/kernels/topk_kernel_common.h" #include "xla/stream_executor/gpu/gpu_types.h" -#include "tsl/lib/math/math_util.h" +#include "xla/tsl/lib/math/math_util.h" #if GOOGLE_CUDA diff --git a/third_party/xla/xla/service/gpu/model/BUILD b/third_party/xla/xla/service/gpu/model/BUILD index af58af0a37647b..051e81048e96a4 100644 --- a/third_party/xla/xla/service/gpu/model/BUILD +++ b/third_party/xla/xla/service/gpu/model/BUILD @@ -650,6 +650,7 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/service:name_uniquer", "//xla/service/gpu:backend_configs_cc", + "//xla/tsl/lib/gtl:iterator_range", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/log:check", @@ -660,7 +661,6 @@ cc_library( "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", - "@local_tsl//tsl/lib/gtl:iterator_range", "@local_tsl//tsl/platform:errors", ], ) diff --git a/third_party/xla/xla/service/gpu/model/tiled_hlo_computation.h b/third_party/xla/xla/service/gpu/model/tiled_hlo_computation.h index 5708a4c3401c36..13d7456dfccaec 100644 --- a/third_party/xla/xla/service/gpu/model/tiled_hlo_computation.h +++ b/third_party/xla/xla/service/gpu/model/tiled_hlo_computation.h @@ -27,8 +27,8 @@ limitations under the License. #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/model/indexing_map.h" #include "xla/service/gpu/model/tiled_hlo_instruction.h" +#include "xla/tsl/lib/gtl/iterator_range.h" #include "xla/util.h" -#include "tsl/lib/gtl/iterator_range.h" namespace xla { namespace gpu { diff --git a/third_party/xla/xla/service/gpu/runtime/BUILD b/third_party/xla/xla/service/gpu/runtime/BUILD index e3c9a4bb7f2de9..773106baf16ac5 100644 --- a/third_party/xla/xla/service/gpu/runtime/BUILD +++ b/third_party/xla/xla/service/gpu/runtime/BUILD @@ -291,6 +291,7 @@ cc_library( compatible_with = get_compatible_with_portable(), deps = [ "//xla/service:global_device_id", + "//xla/tsl/lib/gtl:int_type", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/crc:crc32c", "@com_google_absl//absl/status", @@ -299,7 +300,6 @@ cc_library( "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/lib/gtl:int_type", ], ) @@ -1118,6 +1118,7 @@ cc_library( "//xla/service/gpu:ir_emission_utils", "//xla/stream_executor", "//xla/translate/mhlo_to_hlo:location_exporter", + "//xla/tsl/lib/gtl:int_type", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:inlined_vector", @@ -1127,7 +1128,6 @@ cc_library( "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", "@llvm-project//mlir:IR", - "@local_tsl//tsl/lib/gtl:int_type", "@local_tsl//tsl/platform:statusor", ], ) diff --git a/third_party/xla/xla/service/gpu/runtime/nccl_clique_key.h b/third_party/xla/xla/service/gpu/runtime/nccl_clique_key.h index 0946ce62ef7275..8bb3f2740320e0 100644 --- a/third_party/xla/xla/service/gpu/runtime/nccl_clique_key.h +++ b/third_party/xla/xla/service/gpu/runtime/nccl_clique_key.h @@ -29,7 +29,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/service/global_device_id.h" -#include "tsl/lib/gtl/int_type.h" +#include "xla/tsl/lib/gtl/int_type.h" namespace xla::gpu { diff --git a/third_party/xla/xla/service/gpu/runtime/thunk.h b/third_party/xla/xla/service/gpu/runtime/thunk.h index db0e49c355f102..4b6d345bbc2228 100644 --- a/third_party/xla/xla/service/gpu/runtime/thunk.h +++ b/third_party/xla/xla/service/gpu/runtime/thunk.h @@ -46,7 +46,7 @@ limitations under the License. #include "xla/service/service_executable_run_options.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" -#include "tsl/lib/gtl/int_type.h" +#include "xla/tsl/lib/gtl/int_type.h" namespace xla { namespace gpu { diff --git a/third_party/xla/xla/service/hlo_cost_analysis.cc b/third_party/xla/xla/service/hlo_cost_analysis.cc index c49476e8f927e3..ef7e2de0bdaf95 100644 --- a/third_party/xla/xla/service/hlo_cost_analysis.cc +++ b/third_party/xla/xla/service/hlo_cost_analysis.cc @@ -35,9 +35,9 @@ limitations under the License. #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/status_macros.h" +#include "xla/tsl/lib/gtl/map_util.h" #include "xla/util.h" #include "xla/window_util.h" -#include "tsl/lib/gtl/map_util.h" #include "tsl/platform/errors.h" namespace xla { diff --git a/third_party/xla/xla/service/hlo_graph_dumper.cc b/third_party/xla/xla/service/hlo_graph_dumper.cc index de91298040015d..6e78d70b9638f9 100644 --- a/third_party/xla/xla/service/hlo_graph_dumper.cc +++ b/third_party/xla/xla/service/hlo_graph_dumper.cc @@ -73,12 +73,12 @@ limitations under the License. #include "xla/service/pattern_matcher.h" #include "xla/shape_util.h" #include "xla/stream_executor/dnn.h" +#include "xla/tsl/lib/gtl/map_util.h" +#include "xla/tsl/lib/io/zlib_compression_options.h" +#include "xla/tsl/lib/io/zlib_outputbuffer.h" #include "xla/types.h" #include "xla/util.h" #include "xla/window_util.h" -#include "tsl/lib/gtl/map_util.h" -#include "tsl/lib/io/zlib_compression_options.h" -#include "tsl/lib/io/zlib_outputbuffer.h" #include "tsl/platform/base64.h" #include "tsl/platform/env.h" #include "tsl/platform/numbers.h" diff --git a/third_party/xla/xla/service/hlo_parser.cc b/third_party/xla/xla/service/hlo_parser.cc index 7977e011fda528..9befbac1b48fcb 100644 --- a/third_party/xla/xla/service/hlo_parser.cc +++ b/third_party/xla/xla/service/hlo_parser.cc @@ -74,10 +74,10 @@ limitations under the License. #include "xla/shape.h" #include "xla/shape_layout.h" #include "xla/shape_util.h" +#include "xla/tsl/lib/gtl/map_util.h" #include "xla/types.h" #include "xla/util.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/gtl/map_util.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" #include "tsl/platform/status.h" diff --git a/third_party/xla/xla/service/logical_buffer.h b/third_party/xla/xla/service/logical_buffer.h index f951baea8c26f8..350bbdfcd31c46 100644 --- a/third_party/xla/xla/service/logical_buffer.h +++ b/third_party/xla/xla/service/logical_buffer.h @@ -24,9 +24,9 @@ limitations under the License. #include "xla/service/hlo.pb.h" #include "xla/shape.h" #include "xla/shape_util.h" +#include "xla/tsl/lib/gtl/int_type.h" #include "xla/types.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/gtl/int_type.h" namespace xla { diff --git a/third_party/xla/xla/service/tuple_points_to_analysis.h b/third_party/xla/xla/service/tuple_points_to_analysis.h index 0b9710d3075810..2832f2e1f2193a 100644 --- a/third_party/xla/xla/service/tuple_points_to_analysis.h +++ b/third_party/xla/xla/service/tuple_points_to_analysis.h @@ -34,9 +34,9 @@ limitations under the License. #include "xla/service/logical_buffer.h" #include "xla/service/logical_buffer_analysis.h" #include "xla/shape_tree.h" +#include "xla/tsl/lib/gtl/compactptrset.h" #include "xla/types.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/gtl/compactptrset.h" #include "tsl/platform/status.h" namespace xla { diff --git a/third_party/xla/xla/shape_tree.h b/third_party/xla/xla/shape_tree.h index ba4e13560fd2c3..fd4448e0265089 100644 --- a/third_party/xla/xla/shape_tree.h +++ b/third_party/xla/xla/shape_tree.h @@ -31,7 +31,7 @@ limitations under the License. #include "absl/types/span.h" #include "xla/shape.h" #include "xla/shape_util.h" -#include "tsl/lib/gtl/iterator_range.h" +#include "xla/tsl/lib/gtl/iterator_range.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" // IWYU pragma: keep #include "tsl/platform/statusor.h" diff --git a/third_party/xla/xla/stream_executor/BUILD b/third_party/xla/xla/stream_executor/BUILD index 510c2b6db1799f..4bc6ee63aef070 100644 --- a/third_party/xla/xla/stream_executor/BUILD +++ b/third_party/xla/xla/stream_executor/BUILD @@ -107,6 +107,7 @@ cc_library( "//xla/stream_executor/platform", "//xla/tsl/framework:device_id", "//xla/tsl/framework:device_type", + "//xla/tsl/lib/gtl:int_type", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:inlined_vector", @@ -121,7 +122,6 @@ cc_library( "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/lib/gtl:int_type", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", @@ -157,9 +157,9 @@ cc_library( ":device_description_proto_cc", ":launch_dim", ":semantic_version", + "//xla/tsl/lib/math:math_util", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/strings", - "@local_tsl//tsl/lib/math:math_util", "@local_tsl//tsl/platform:logging", ], ) @@ -592,10 +592,10 @@ cc_library( ":kernel", ":launch_dim", ":platform", + "//xla/tsl/lib/gtl:int_type", "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/status", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/lib/gtl:int_type", "@local_tsl//tsl/platform:errors", ], ) diff --git a/third_party/xla/xla/stream_executor/command_buffer.h b/third_party/xla/xla/stream_executor/command_buffer.h index 2b92b504f2059a..a5e7ac61ccd0e4 100644 --- a/third_party/xla/xla/stream_executor/command_buffer.h +++ b/third_party/xla/xla/stream_executor/command_buffer.h @@ -28,7 +28,7 @@ limitations under the License. #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/kernel.h" #include "xla/stream_executor/launch_dim.h" -#include "tsl/lib/gtl/int_type.h" +#include "xla/tsl/lib/gtl/int_type.h" #include "tsl/platform/errors.h" namespace stream_executor { diff --git a/third_party/xla/xla/stream_executor/device_description.cc b/third_party/xla/xla/stream_executor/device_description.cc index ca19e68ffdc382..7486bda1002d72 100644 --- a/third_party/xla/xla/stream_executor/device_description.cc +++ b/third_party/xla/xla/stream_executor/device_description.cc @@ -21,7 +21,7 @@ limitations under the License. #include "xla/stream_executor/device_description.pb.h" #include "xla/stream_executor/launch_dim.h" -#include "tsl/lib/math/math_util.h" +#include "xla/tsl/lib/math/math_util.h" #include "tsl/platform/logging.h" namespace stream_executor { diff --git a/third_party/xla/xla/stream_executor/gpu/BUILD b/third_party/xla/xla/stream_executor/gpu/BUILD index 1812e7e269cf17..c5c4a3fbb1ba75 100644 --- a/third_party/xla/xla/stream_executor/gpu/BUILD +++ b/third_party/xla/xla/stream_executor/gpu/BUILD @@ -579,6 +579,7 @@ gpu_only_cc_library( "//xla/stream_executor:device_memory_handle", "//xla/stream_executor:scratch_allocator", "//xla/tsl/framework:allocator", + "//xla/tsl/lib/math:math_util", "@com_google_absl//absl/container:fixed_array", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", @@ -586,7 +587,6 @@ gpu_only_cc_library( "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/strings:string_view", - "@local_tsl//tsl/lib/math:math_util", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:statusor", ] + if_rocm_is_configured([ diff --git a/third_party/xla/xla/stream_executor/gpu/redzone_allocator.cc b/third_party/xla/xla/stream_executor/gpu/redzone_allocator.cc index f15b543167d7d1..57aea9ac3fa644 100644 --- a/third_party/xla/xla/stream_executor/gpu/redzone_allocator.cc +++ b/third_party/xla/xla/stream_executor/gpu/redzone_allocator.cc @@ -39,7 +39,7 @@ limitations under the License. #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" #include "xla/tsl/framework/allocator.h" -#include "tsl/lib/math/math_util.h" +#include "xla/tsl/lib/math/math_util.h" #include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" diff --git a/third_party/xla/xla/tests/BUILD b/third_party/xla/xla/tests/BUILD index 43e24cfd8d86f3..4c61d86e168f46 100644 --- a/third_party/xla/xla/tests/BUILD +++ b/third_party/xla/xla/tests/BUILD @@ -1418,9 +1418,9 @@ xla_test( "//xla/client/lib:arithmetic", "//xla/client/lib:math", "//xla/hlo/ir:hlo", + "//xla/tsl/lib/math:math_util", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@local_tsl//tsl/lib/math:math_util", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:test", ], diff --git a/third_party/xla/xla/tests/batch_normalization_test.cc b/third_party/xla/xla/tests/batch_normalization_test.cc index c388c0fd72b89c..7fd723d82f1f97 100644 --- a/third_party/xla/xla/tests/batch_normalization_test.cc +++ b/third_party/xla/xla/tests/batch_normalization_test.cc @@ -39,10 +39,10 @@ limitations under the License. #include "xla/tests/literal_test_util.h" #include "xla/tests/test_macros.h" #include "xla/tests/test_utils.h" +#include "xla/tsl/lib/math/math_util.h" #include "xla/types.h" #include "xla/util.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/math/math_util.h" #include "tsl/platform/logging.h" #include "tsl/platform/test.h" diff --git a/third_party/xla/xla/text_literal_reader.cc b/third_party/xla/xla/text_literal_reader.cc index 209790bc33eafa..3aaa23a8f958a0 100644 --- a/third_party/xla/xla/text_literal_reader.cc +++ b/third_party/xla/xla/text_literal_reader.cc @@ -30,11 +30,11 @@ limitations under the License. #include "xla/service/hlo_parser.h" #include "xla/shape_util.h" #include "xla/status_macros.h" +#include "xla/tsl/lib/io/buffered_inputstream.h" +#include "xla/tsl/lib/io/random_inputstream.h" #include "xla/types.h" #include "xla/util.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/io/buffered_inputstream.h" -#include "tsl/lib/io/random_inputstream.h" #include "tsl/platform/protobuf.h" namespace xla { diff --git a/third_party/xla/xla/tools/BUILD b/third_party/xla/xla/tools/BUILD index ddfc7c005d1d51..acc2b000972189 100644 --- a/third_party/xla/xla/tools/BUILD +++ b/third_party/xla/xla/tools/BUILD @@ -56,11 +56,11 @@ xla_cc_binary( name = "hex_floats_to_packed_literal", srcs = ["hex_floats_to_packed_literal.cc"], deps = [ + "//xla/tsl/lib/io:buffered_inputstream", + "//xla/tsl/lib/io:random_inputstream", "//xla/tsl/util:command_line_flags", "@com_google_absl//absl/base", "@com_google_absl//absl/strings", - "@local_tsl//tsl/lib/io:buffered_inputstream", - "@local_tsl//tsl/lib/io:random_inputstream", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:platform_port", diff --git a/third_party/xla/xla/tools/hex_floats_to_packed_literal.cc b/third_party/xla/xla/tools/hex_floats_to_packed_literal.cc index 6388a8fb84d71c..659e4cde814b5d 100644 --- a/third_party/xla/xla/tools/hex_floats_to_packed_literal.cc +++ b/third_party/xla/xla/tools/hex_floats_to_packed_literal.cc @@ -20,9 +20,9 @@ limitations under the License. #include "absl/base/casts.h" #include "absl/strings/string_view.h" +#include "xla/tsl/lib/io/buffered_inputstream.h" +#include "xla/tsl/lib/io/random_inputstream.h" #include "xla/tsl/util/command_line_flags.h" -#include "tsl/lib/io/buffered_inputstream.h" -#include "tsl/lib/io/random_inputstream.h" #include "tsl/platform/env.h" #include "tsl/platform/file_system.h" #include "tsl/platform/init_main.h" diff --git a/third_party/xla/xla/tsl/distributed_runtime/rpc/BUILD b/third_party/xla/xla/tsl/distributed_runtime/rpc/BUILD index 6fe7b4064235f8..5b7462f31a7ffd 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/rpc/BUILD +++ b/third_party/xla/xla/tsl/distributed_runtime/rpc/BUILD @@ -83,9 +83,9 @@ cc_library( deps = [ ":grpc_channel_common", ":grpc_util", + "//xla/tsl/lib/gtl:map_util", "//xla/tsl/util:device_name_utils", "@com_google_absl//absl/strings", - "@local_tsl//tsl/lib/gtl:map_util", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:macros", diff --git a/third_party/xla/xla/tsl/distributed_runtime/rpc/grpc_channel.cc b/third_party/xla/xla/tsl/distributed_runtime/rpc/grpc_channel.cc index 3e80a68c3bab02..bd925a64bf0335 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/rpc/grpc_channel.cc +++ b/third_party/xla/xla/tsl/distributed_runtime/rpc/grpc_channel.cc @@ -26,8 +26,8 @@ limitations under the License. #include "absl/strings/str_split.h" #include "grpcpp/create_channel.h" #include "xla/tsl/distributed_runtime/rpc/grpc_channel_common.h" +#include "xla/tsl/lib/gtl/map_util.h" #include "xla/tsl/util/device_name_utils.h" -#include "tsl/lib/gtl/map_util.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" #include "tsl/platform/macros.h" diff --git a/third_party/xla/xla/tsl/framework/BUILD b/third_party/xla/xla/tsl/framework/BUILD index 27bfe40de0ea69..24172266463a92 100644 --- a/third_party/xla/xla/tsl/framework/BUILD +++ b/third_party/xla/xla/tsl/framework/BUILD @@ -118,7 +118,7 @@ cc_library( ] + if_static( extra_deps = [ ":allocator_registry_impl", - "@local_tsl//tsl/lib/gtl:inlined_vector", + "//xla/tsl/lib/gtl:inlined_vector", "@local_tsl//tsl/platform:strcat", "@local_tsl//tsl/platform:stringprintf", "@local_tsl//tsl/platform:env", @@ -131,7 +131,7 @@ cc_library( "@local_tsl//tsl/platform:types", ], otherwise = [ - "@local_tsl//tsl/lib/gtl:inlined_vector", + "//xla/tsl/lib/gtl:inlined_vector", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:platform_port", "@local_tsl//tsl/platform:strcat", @@ -162,10 +162,10 @@ cc_library( deps = [ ":numeric_types", ":type_traits", + "//xla/tsl/lib/gtl:inlined_vector", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", - "@local_tsl//tsl/lib/gtl:inlined_vector", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:macros", "@local_tsl//tsl/platform:mutex", @@ -233,7 +233,7 @@ cc_library( "device_id_manager.h", ], deps = [ - "@local_tsl//tsl/lib/gtl:int_type", + "//xla/tsl/lib/gtl:int_type", ] + if_static([ ":device_id_impl", ]), @@ -248,9 +248,9 @@ cc_library( ], deps = [ ":device_type", + "//xla/tsl/lib/gtl:int_type", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", - "@local_tsl//tsl/lib/gtl:int_type", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:macros", @@ -365,8 +365,8 @@ cc_library( ), visibility = ["//visibility:public"], deps = [ + "//xla/tsl/lib/gtl:flatmap", "@com_google_absl//absl/memory", - "@local_tsl//tsl/lib/gtl:flatmap", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:hash", "@local_tsl//tsl/platform:logging", diff --git a/third_party/xla/xla/tsl/framework/cancellation.h b/third_party/xla/xla/tsl/framework/cancellation.h index 38f7ebf60a63b2..6dd04e269ff5d3 100644 --- a/third_party/xla/xla/tsl/framework/cancellation.h +++ b/third_party/xla/xla/tsl/framework/cancellation.h @@ -19,7 +19,7 @@ limitations under the License. #include #include -#include "tsl/lib/gtl/flatmap.h" +#include "xla/tsl/lib/gtl/flatmap.h" #include "tsl/platform/hash.h" #include "tsl/platform/mutex.h" #include "tsl/platform/notification.h" diff --git a/third_party/xla/xla/tsl/framework/device_id.h b/third_party/xla/xla/tsl/framework/device_id.h index b56c9ecbc64ec1..e80d84298195fe 100644 --- a/third_party/xla/xla/tsl/framework/device_id.h +++ b/third_party/xla/xla/tsl/framework/device_id.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef XLA_TSL_FRAMEWORK_DEVICE_ID_H_ #define XLA_TSL_FRAMEWORK_DEVICE_ID_H_ -#include "tsl/lib/gtl/int_type.h" +#include "xla/tsl/lib/gtl/int_type.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/framework/tracking_allocator.h b/third_party/xla/xla/tsl/framework/tracking_allocator.h index 32d0026db63464..b0e4288fc99617 100644 --- a/third_party/xla/xla/tsl/framework/tracking_allocator.h +++ b/third_party/xla/xla/tsl/framework/tracking_allocator.h @@ -19,7 +19,7 @@ limitations under the License. #include #include "xla/tsl/framework/allocator.h" -#include "tsl/lib/gtl/inlined_vector.h" +#include "xla/tsl/lib/gtl/inlined_vector.h" #include "tsl/platform/mutex.h" #include "tsl/platform/thread_annotations.h" #include "tsl/platform/types.h" diff --git a/third_party/xla/xla/tsl/lib/core/bitmap_test.cc b/third_party/xla/xla/tsl/lib/core/bitmap_test.cc index b748249d7b79f1..bab7f7e4bc9bf5 100644 --- a/third_party/xla/xla/tsl/lib/core/bitmap_test.cc +++ b/third_party/xla/xla/tsl/lib/core/bitmap_test.cc @@ -15,7 +15,7 @@ limitations under the License. #include "xla/tsl/lib/core/bitmap.h" -#include "tsl/lib/random/simple_philox.h" +#include "xla/tsl/lib/random/simple_philox.h" #include "tsl/platform/macros.h" #include "tsl/platform/test.h" diff --git a/third_party/xla/third_party/tsl/tsl/lib/gtl/BUILD b/third_party/xla/xla/tsl/lib/gtl/BUILD similarity index 77% rename from third_party/xla/third_party/tsl/tsl/lib/gtl/BUILD rename to third_party/xla/xla/tsl/lib/gtl/BUILD index f601fb129e1521..ceac5767c2cf08 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/gtl/BUILD +++ b/third_party/xla/xla/tsl/lib/gtl/BUILD @@ -1,13 +1,13 @@ load( - "@local_tsl//tsl/platform:rules_cc.bzl", - "cc_library", + "@local_tsl//tsl/platform:build_config.bzl", + "tsl_cc_test", ) -load("@local_xla//xla/tsl:tsl.bzl", "internal_visibility") -load("@local_xla//xla/tsl:tsl.default.bzl", "filegroup") load( - "//tsl/platform:build_config.bzl", - "tsl_cc_test", + "@local_tsl//tsl/platform:rules_cc.bzl", + "cc_library", ) +load("//xla/tsl:tsl.bzl", "internal_visibility") +load("//xla/tsl:tsl.default.bzl", "filegroup") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -16,21 +16,21 @@ package( "//tensorflow/core:__pkg__", # tensorflow/core/lib/strings:proto_serialization uses on gtl:inlined_vector "//tensorflow/core/lib/strings:__pkg__", - "//tsl/lib/strings:__pkg__", + "//xla/tsl/lib/strings:__pkg__", # tensorflow/core/framework uses map_util, and flatmap "//tensorflow/core/framework:__pkg__", - "@local_xla//xla/tsl/framework:__pkg__", - "//tsl/platform/cloud:__pkg__", + "//xla/tsl/framework:__pkg__", + "@local_tsl//tsl/platform/cloud:__pkg__", # tensorflow/core/util uses inlined_vector "//tensorflow/core/util:__pkg__", # tensorflow/core/tfrt/utils uses inlined_vector "//tensorflow/core/tfrt/utils:__pkg__", # tensorflow/examples/custom_ops_doc/simple_hash_table uses map_util "//tensorflow/examples/custom_ops_doc/simple_hash_table:__pkg__", - "@local_xla//xla:__subpackages__", + "//xla:__subpackages__", "//tensorflow/core/lib/gtl:__subpackages__", - "@local_xla//xla/tsl/distributed_runtime/rpc:__pkg__", - "//tsl/profiler/utils:__pkg__", + "//xla/tsl/distributed_runtime/rpc:__pkg__", + "@local_tsl//tsl/profiler/utils:__pkg__", ]), licenses = ["notice"], ) @@ -48,9 +48,9 @@ cc_library( hdrs = ["flatmap.h"], deps = [ ":flatrep", - "//tsl/platform:hash", - "//tsl/platform:logging", - "//tsl/platform:types", + "@local_tsl//tsl/platform:hash", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:types", ], ) @@ -58,8 +58,8 @@ cc_library( name = "flatrep", hdrs = ["flatrep.h"], deps = [ - "//tsl/platform:types", "@com_google_absl//absl/base:prefetch", + "@local_tsl//tsl/platform:types", ], ) @@ -68,9 +68,9 @@ cc_library( hdrs = ["flatset.h"], deps = [ ":flatrep", - "//tsl/platform:hash", - "//tsl/platform:logging", - "//tsl/platform:types", + "@local_tsl//tsl/platform:hash", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:types", ], ) @@ -78,10 +78,10 @@ cc_library( name = "inlined_vector", hdrs = ["inlined_vector.h"], deps = [ - "//tsl/platform:macros", - "//tsl/platform:types", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:inlined_vector", + "@local_tsl//tsl/platform:macros", + "@local_tsl//tsl/platform:types", ], ) @@ -89,8 +89,8 @@ cc_library( name = "int_type", hdrs = ["int_type.h"], deps = [ - "//tsl/platform:macros", - "//tsl/platform:types", + "@local_tsl//tsl/platform:macros", + "@local_tsl//tsl/platform:types", ], ) @@ -103,7 +103,7 @@ cc_library( name = "map_util", srcs = [ "map_util.h", - "//tsl/lib/gtl/subtle:map_traits", + "//xla/tsl/lib/gtl/subtle:map_traits", ], hdrs = ["map_util.h"], ) @@ -166,7 +166,7 @@ filegroup( visibility = internal_visibility([ "//tensorflow/core:__pkg__", "//tensorflow/core/lib/gtl:__pkg__", - "//tsl:__subpackages__", + "@local_tsl//tsl:__subpackages__", ]), ) @@ -177,7 +177,7 @@ filegroup( "int_type.h", "iterator_range.h", "map_util.h", - "//tsl/lib/gtl/subtle:map_traits", + "//xla/tsl/lib/gtl/subtle:map_traits", ], visibility = internal_visibility([ "//tensorflow/core:__pkg__", @@ -196,7 +196,7 @@ filegroup( "int_type.h", "iterator_range.h", "map_util.h", - "//tsl/lib/gtl/subtle:map_traits", + "//xla/tsl/lib/gtl/subtle:map_traits", ], visibility = internal_visibility([ "//tensorflow/core:__pkg__", @@ -221,10 +221,10 @@ tsl_cc_test( ":int_type", ":iterator_range", ":map_util", - "//tsl/platform:hash", - "//tsl/platform:macros", - "//tsl/platform:test", - "//tsl/platform:test_main", - "//tsl/platform:types", + "@local_tsl//tsl/platform:hash", + "@local_tsl//tsl/platform:macros", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", + "@local_tsl//tsl/platform:types", ], ) diff --git a/third_party/xla/third_party/tsl/tsl/lib/gtl/compactptrset.h b/third_party/xla/xla/tsl/lib/gtl/compactptrset.h similarity index 96% rename from third_party/xla/third_party/tsl/tsl/lib/gtl/compactptrset.h rename to third_party/xla/xla/tsl/lib/gtl/compactptrset.h index 8fbb7a8560dd1b..3848430e76fb92 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/gtl/compactptrset.h +++ b/third_party/xla/xla/tsl/lib/gtl/compactptrset.h @@ -13,12 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_LIB_GTL_COMPACTPTRSET_H_ -#define TENSORFLOW_TSL_LIB_GTL_COMPACTPTRSET_H_ +#ifndef XLA_TSL_LIB_GTL_COMPACTPTRSET_H_ +#define XLA_TSL_LIB_GTL_COMPACTPTRSET_H_ #include -#include "tsl/lib/gtl/flatset.h" +#include "xla/tsl/lib/gtl/flatset.h" namespace tsl { namespace gtl { @@ -206,4 +206,4 @@ class CompactPointerSet { } // namespace gtl } // namespace tsl -#endif // TENSORFLOW_TSL_LIB_GTL_COMPACTPTRSET_H_ +#endif // XLA_TSL_LIB_GTL_COMPACTPTRSET_H_ diff --git a/third_party/xla/third_party/tsl/tsl/lib/gtl/compactptrset_test.cc b/third_party/xla/xla/tsl/lib/gtl/compactptrset_test.cc similarity index 98% rename from third_party/xla/third_party/tsl/tsl/lib/gtl/compactptrset_test.cc rename to third_party/xla/xla/tsl/lib/gtl/compactptrset_test.cc index 9dc146c2e52b79..6f5e52dc085047 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/gtl/compactptrset_test.cc +++ b/third_party/xla/xla/tsl/lib/gtl/compactptrset_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/lib/gtl/compactptrset.h" +#include "xla/tsl/lib/gtl/compactptrset.h" #include "tsl/platform/hash.h" #include "tsl/platform/test.h" diff --git a/third_party/xla/third_party/tsl/tsl/lib/gtl/flatmap.h b/third_party/xla/xla/tsl/lib/gtl/flatmap.h similarity index 98% rename from third_party/xla/third_party/tsl/tsl/lib/gtl/flatmap.h rename to third_party/xla/xla/tsl/lib/gtl/flatmap.h index 8d5cf7912e9d78..e74dbd46531d9a 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/gtl/flatmap.h +++ b/third_party/xla/xla/tsl/lib/gtl/flatmap.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_LIB_GTL_FLATMAP_H_ -#define TENSORFLOW_TSL_LIB_GTL_FLATMAP_H_ +#ifndef XLA_TSL_LIB_GTL_FLATMAP_H_ +#define XLA_TSL_LIB_GTL_FLATMAP_H_ #include @@ -23,7 +23,7 @@ limitations under the License. #include #include -#include "tsl/lib/gtl/flatrep.h" +#include "xla/tsl/lib/gtl/flatrep.h" #include "tsl/platform/hash.h" #include "tsl/platform/logging.h" #include "tsl/platform/types.h" @@ -393,4 +393,4 @@ class FlatMap { } // namespace gtl } // namespace tsl -#endif // TENSORFLOW_TSL_LIB_GTL_FLATMAP_H_ +#endif // XLA_TSL_LIB_GTL_FLATMAP_H_ diff --git a/third_party/xla/third_party/tsl/tsl/lib/gtl/flatmap_test.cc b/third_party/xla/xla/tsl/lib/gtl/flatmap_test.cc similarity index 99% rename from third_party/xla/third_party/tsl/tsl/lib/gtl/flatmap_test.cc rename to third_party/xla/xla/tsl/lib/gtl/flatmap_test.cc index a2b4fd11df3dbf..231970ccbe45ac 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/gtl/flatmap_test.cc +++ b/third_party/xla/xla/tsl/lib/gtl/flatmap_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/lib/gtl/flatmap.h" +#include "xla/tsl/lib/gtl/flatmap.h" #include #include diff --git a/third_party/xla/third_party/tsl/tsl/lib/gtl/flatrep.h b/third_party/xla/xla/tsl/lib/gtl/flatrep.h similarity index 98% rename from third_party/xla/third_party/tsl/tsl/lib/gtl/flatrep.h rename to third_party/xla/xla/tsl/lib/gtl/flatrep.h index d6c77e7de363ea..74ae18fc37c0f8 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/gtl/flatrep.h +++ b/third_party/xla/xla/tsl/lib/gtl/flatrep.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_LIB_GTL_FLATREP_H_ -#define TENSORFLOW_TSL_LIB_GTL_FLATREP_H_ +#ifndef XLA_TSL_LIB_GTL_FLATREP_H_ +#define XLA_TSL_LIB_GTL_FLATREP_H_ #include @@ -350,4 +350,4 @@ class FlatRep { } // namespace gtl } // namespace tsl -#endif // TENSORFLOW_TSL_LIB_GTL_FLATREP_H_ +#endif // XLA_TSL_LIB_GTL_FLATREP_H_ diff --git a/third_party/xla/third_party/tsl/tsl/lib/gtl/flatset.h b/third_party/xla/xla/tsl/lib/gtl/flatset.h similarity index 98% rename from third_party/xla/third_party/tsl/tsl/lib/gtl/flatset.h rename to third_party/xla/xla/tsl/lib/gtl/flatset.h index b3178225647fe1..f272ad1fa7bd1d 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/gtl/flatset.h +++ b/third_party/xla/xla/tsl/lib/gtl/flatset.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_LIB_GTL_FLATSET_H_ -#define TENSORFLOW_TSL_LIB_GTL_FLATSET_H_ +#ifndef XLA_TSL_LIB_GTL_FLATSET_H_ +#define XLA_TSL_LIB_GTL_FLATSET_H_ #include @@ -23,7 +23,7 @@ limitations under the License. #include #include -#include "tsl/lib/gtl/flatrep.h" +#include "xla/tsl/lib/gtl/flatrep.h" #include "tsl/platform/hash.h" #include "tsl/platform/logging.h" #include "tsl/platform/types.h" @@ -293,4 +293,4 @@ class FlatSet { } // namespace gtl } // namespace tsl -#endif // TENSORFLOW_TSL_LIB_GTL_FLATSET_H_ +#endif // XLA_TSL_LIB_GTL_FLATSET_H_ diff --git a/third_party/xla/third_party/tsl/tsl/lib/gtl/flatset_test.cc b/third_party/xla/xla/tsl/lib/gtl/flatset_test.cc similarity index 99% rename from third_party/xla/third_party/tsl/tsl/lib/gtl/flatset_test.cc rename to third_party/xla/xla/tsl/lib/gtl/flatset_test.cc index abf7892f2d8798..8adb9133a76ecb 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/gtl/flatset_test.cc +++ b/third_party/xla/xla/tsl/lib/gtl/flatset_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/lib/gtl/flatset.h" +#include "xla/tsl/lib/gtl/flatset.h" #include #include diff --git a/third_party/xla/third_party/tsl/tsl/lib/gtl/inlined_vector.h b/third_party/xla/xla/tsl/lib/gtl/inlined_vector.h similarity index 89% rename from third_party/xla/third_party/tsl/tsl/lib/gtl/inlined_vector.h rename to third_party/xla/xla/tsl/lib/gtl/inlined_vector.h index fc8533b02937ab..6072f87ff6931d 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/gtl/inlined_vector.h +++ b/third_party/xla/xla/tsl/lib/gtl/inlined_vector.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_LIB_GTL_INLINED_VECTOR_H_ -#define TENSORFLOW_TSL_LIB_GTL_INLINED_VECTOR_H_ +#ifndef XLA_TSL_LIB_GTL_INLINED_VECTOR_H_ +#define XLA_TSL_LIB_GTL_INLINED_VECTOR_H_ #include @@ -39,4 +39,4 @@ using InlinedVector ABSL_DEPRECATE_AND_INLINE() = absl::InlinedVector; } // namespace gtl } // namespace tsl -#endif // TENSORFLOW_TSL_LIB_GTL_INLINED_VECTOR_H_ +#endif // XLA_TSL_LIB_GTL_INLINED_VECTOR_H_ diff --git a/third_party/xla/third_party/tsl/tsl/lib/gtl/int_type.h b/third_party/xla/xla/tsl/lib/gtl/int_type.h similarity index 99% rename from third_party/xla/third_party/tsl/tsl/lib/gtl/int_type.h rename to third_party/xla/xla/tsl/lib/gtl/int_type.h index 7a5d7935782884..2a54fc58fada8f 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/gtl/int_type.h +++ b/third_party/xla/xla/tsl/lib/gtl/int_type.h @@ -149,8 +149,8 @@ limitations under the License. // void GetGlobalDoc(int64 global) { ... // GetGlobalDoc(local.value()); <-- Compiles fine. -#ifndef TENSORFLOW_TSL_LIB_GTL_INT_TYPE_H_ -#define TENSORFLOW_TSL_LIB_GTL_INT_TYPE_H_ +#ifndef XLA_TSL_LIB_GTL_INT_TYPE_H_ +#define XLA_TSL_LIB_GTL_INT_TYPE_H_ #include @@ -361,4 +361,4 @@ INT_TYPE_COMPARISON_OP(>=); // NOLINT } // namespace gtl } // namespace tsl -#endif // TENSORFLOW_TSL_LIB_GTL_INT_TYPE_H_ +#endif // XLA_TSL_LIB_GTL_INT_TYPE_H_ diff --git a/third_party/xla/third_party/tsl/tsl/lib/gtl/int_type_test.cc b/third_party/xla/xla/tsl/lib/gtl/int_type_test.cc similarity index 99% rename from third_party/xla/third_party/tsl/tsl/lib/gtl/int_type_test.cc rename to third_party/xla/xla/tsl/lib/gtl/int_type_test.cc index 2716eb139fa0a9..6ab47039fe1653 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/gtl/int_type_test.cc +++ b/third_party/xla/xla/tsl/lib/gtl/int_type_test.cc @@ -15,7 +15,7 @@ limitations under the License. // Unit test cases for IntType. -#include "tsl/lib/gtl/int_type.h" +#include "xla/tsl/lib/gtl/int_type.h" #include #include diff --git a/third_party/xla/third_party/tsl/tsl/lib/gtl/iterator_range.h b/third_party/xla/xla/tsl/lib/gtl/iterator_range.h similarity index 93% rename from third_party/xla/third_party/tsl/tsl/lib/gtl/iterator_range.h rename to third_party/xla/xla/tsl/lib/gtl/iterator_range.h index 6e420c940142cc..2914dce38c7f9e 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/gtl/iterator_range.h +++ b/third_party/xla/xla/tsl/lib/gtl/iterator_range.h @@ -22,8 +22,8 @@ limitations under the License. // // Converted from chandlerc@'s code to Google style by joshl@. -#ifndef TENSORFLOW_TSL_LIB_GTL_ITERATOR_RANGE_H_ -#define TENSORFLOW_TSL_LIB_GTL_ITERATOR_RANGE_H_ +#ifndef XLA_TSL_LIB_GTL_ITERATOR_RANGE_H_ +#define XLA_TSL_LIB_GTL_ITERATOR_RANGE_H_ #include @@ -65,4 +65,4 @@ iterator_range make_range(T x, T y) { } // namespace gtl } // namespace tsl -#endif // TENSORFLOW_TSL_LIB_GTL_ITERATOR_RANGE_H_ +#endif // XLA_TSL_LIB_GTL_ITERATOR_RANGE_H_ diff --git a/third_party/xla/third_party/tsl/tsl/lib/gtl/iterator_range_test.cc b/third_party/xla/xla/tsl/lib/gtl/iterator_range_test.cc similarity index 97% rename from third_party/xla/third_party/tsl/tsl/lib/gtl/iterator_range_test.cc rename to third_party/xla/xla/tsl/lib/gtl/iterator_range_test.cc index 35d1fe5854d8b8..08028094552ff1 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/gtl/iterator_range_test.cc +++ b/third_party/xla/xla/tsl/lib/gtl/iterator_range_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/lib/gtl/iterator_range.h" +#include "xla/tsl/lib/gtl/iterator_range.h" #include diff --git a/third_party/xla/third_party/tsl/tsl/lib/gtl/map_util.h b/third_party/xla/xla/tsl/lib/gtl/map_util.h similarity index 97% rename from third_party/xla/third_party/tsl/tsl/lib/gtl/map_util.h rename to third_party/xla/xla/tsl/lib/gtl/map_util.h index 63a966228481bc..d04ba3644a09a0 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/gtl/map_util.h +++ b/third_party/xla/xla/tsl/lib/gtl/map_util.h @@ -17,8 +17,8 @@ limitations under the License. // structures, such as std::map and hash_map. Some functions will also work with // sets, such as ContainsKey(). -#ifndef TENSORFLOW_TSL_LIB_GTL_MAP_UTIL_H_ -#define TENSORFLOW_TSL_LIB_GTL_MAP_UTIL_H_ +#ifndef XLA_TSL_LIB_GTL_MAP_UTIL_H_ +#define XLA_TSL_LIB_GTL_MAP_UTIL_H_ #include @@ -27,7 +27,7 @@ limitations under the License. #include #include -#include "tsl/lib/gtl/subtle/map_traits.h" +#include "xla/tsl/lib/gtl/subtle/map_traits.h" namespace tsl { namespace gtl { @@ -212,4 +212,4 @@ typename Collection::value_type::second_type EraseKeyReturnValuePtr( } // namespace gtl } // namespace tsl -#endif // TENSORFLOW_TSL_LIB_GTL_MAP_UTIL_H_ +#endif // XLA_TSL_LIB_GTL_MAP_UTIL_H_ diff --git a/third_party/xla/third_party/tsl/tsl/lib/gtl/map_util_test.cc b/third_party/xla/xla/tsl/lib/gtl/map_util_test.cc similarity index 97% rename from third_party/xla/third_party/tsl/tsl/lib/gtl/map_util_test.cc rename to third_party/xla/xla/tsl/lib/gtl/map_util_test.cc index 7ecf4a4b394251..ce2a13c9e394e9 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/gtl/map_util_test.cc +++ b/third_party/xla/xla/tsl/lib/gtl/map_util_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/lib/gtl/map_util.h" +#include "xla/tsl/lib/gtl/map_util.h" #include #include diff --git a/third_party/xla/third_party/tsl/tsl/lib/gtl/subtle/BUILD b/third_party/xla/xla/tsl/lib/gtl/subtle/BUILD similarity index 69% rename from third_party/xla/third_party/tsl/tsl/lib/gtl/subtle/BUILD rename to third_party/xla/xla/tsl/lib/gtl/subtle/BUILD index e2f9763dfd2b58..dfedc36f004cb1 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/gtl/subtle/BUILD +++ b/third_party/xla/xla/tsl/lib/gtl/subtle/BUILD @@ -1,8 +1,8 @@ # Description: # gtl subtle packages. -load("@local_xla//xla/tsl:tsl.bzl", "internal_visibility") -load("@local_xla//xla/tsl:tsl.default.bzl", "filegroup") +load("//xla/tsl:tsl.bzl", "internal_visibility") +load("//xla/tsl:tsl.default.bzl", "filegroup") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -16,6 +16,6 @@ filegroup( ], visibility = internal_visibility([ "//tensorflow/core/lib/gtl/subtle:__pkg__", - "//tsl/lib/gtl:__pkg__", + "//xla/tsl/lib/gtl:__pkg__", ]), ) diff --git a/third_party/xla/third_party/tsl/tsl/lib/gtl/subtle/map_traits.h b/third_party/xla/xla/tsl/lib/gtl/subtle/map_traits.h similarity index 93% rename from third_party/xla/third_party/tsl/tsl/lib/gtl/subtle/map_traits.h rename to third_party/xla/xla/tsl/lib/gtl/subtle/map_traits.h index 535db74402ba91..961dc550747bd2 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/gtl/subtle/map_traits.h +++ b/third_party/xla/xla/tsl/lib/gtl/subtle/map_traits.h @@ -17,8 +17,8 @@ limitations under the License. // 1. If T has a `first` or `second` field, use them. // 2. Otherwise if it has `key()` or `value()` methods, use them. // 3. Otherwise the program is ill-formed. -#ifndef TENSORFLOW_TSL_LIB_GTL_SUBTLE_MAP_TRAITS_H_ -#define TENSORFLOW_TSL_LIB_GTL_SUBTLE_MAP_TRAITS_H_ +#ifndef XLA_TSL_LIB_GTL_SUBTLE_MAP_TRAITS_H_ +#define XLA_TSL_LIB_GTL_SUBTLE_MAP_TRAITS_H_ #include namespace tsl { namespace gtl { @@ -62,4 +62,4 @@ auto GetMapped(V&& v) } // namespace subtle } // namespace gtl } // namespace tsl -#endif // TENSORFLOW_TSL_LIB_GTL_SUBTLE_MAP_TRAITS_H_ +#endif // XLA_TSL_LIB_GTL_SUBTLE_MAP_TRAITS_H_ diff --git a/third_party/xla/third_party/tsl/tsl/lib/hash/BUILD b/third_party/xla/xla/tsl/lib/hash/BUILD similarity index 71% rename from third_party/xla/third_party/tsl/tsl/lib/hash/BUILD rename to third_party/xla/xla/tsl/lib/hash/BUILD index c497abfe17ac47..a25dc7d9cda14b 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/hash/BUILD +++ b/third_party/xla/xla/tsl/lib/hash/BUILD @@ -1,24 +1,24 @@ +load( + "@local_tsl//tsl/platform:build_config.bzl", + "tsl_cc_test", +) load( "@local_tsl//tsl/platform:rules_cc.bzl", "cc_library", ) load( - "@local_xla//xla/tsl:tsl.bzl", + "//xla/tsl:tsl.bzl", "if_linux_x86_64", "internal_visibility", "tsl_copts", ) -load("@local_xla//xla/tsl:tsl.default.bzl", "filegroup") -load( - "//tsl/platform:build_config.bzl", - "tsl_cc_test", -) +load("//xla/tsl:tsl.default.bzl", "filegroup") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = internal_visibility([ - # tensorflow/tsl/lib/io/table_builder.cc uses crc functionality - "//tsl/lib/io:__pkg__", + # tensorflow/compiler/xla/tsl/lib/io/table_builder.cc uses crc functionality + "//xla/tsl/lib/io:__pkg__", # tensorflow/core/lib/hash aliases hash for now "//tensorflow/core/lib/hash:__pkg__", ]), @@ -34,12 +34,12 @@ cc_library( # -msse4.2 enables the use of crc32c compiler builtins. copts = tsl_copts() + if_linux_x86_64(["-msse4.2"]), deps = [ - "//tsl/platform", - "//tsl/platform:cord", - "//tsl/platform:types", "@com_google_absl//absl/crc:crc32c", "@com_google_absl//absl/strings:cord", "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform", + "@local_tsl//tsl/platform:cord", + "@local_tsl//tsl/platform:types", ], ) @@ -67,11 +67,11 @@ tsl_cc_test( srcs = ["crc32c_test.cc"], deps = [ ":crc32c", - "//tsl/platform:logging", - "//tsl/platform:test", - "//tsl/platform:test_benchmark", - "//tsl/platform:test_main", - "//tsl/platform:types", "@com_google_absl//absl/strings:cord", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_benchmark", + "@local_tsl//tsl/platform:test_main", + "@local_tsl//tsl/platform:types", ], ) diff --git a/third_party/xla/third_party/tsl/tsl/lib/hash/crc32c.cc b/third_party/xla/xla/tsl/lib/hash/crc32c.cc similarity index 96% rename from third_party/xla/third_party/tsl/tsl/lib/hash/crc32c.cc rename to third_party/xla/xla/tsl/lib/hash/crc32c.cc index 1bd005b6b05297..8ad835fb1d80f8 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/hash/crc32c.cc +++ b/third_party/xla/xla/tsl/lib/hash/crc32c.cc @@ -16,7 +16,7 @@ limitations under the License. // A portable implementation of crc32c, optimized to handle // four bytes at a time. -#include "tsl/lib/hash/crc32c.h" +#include "xla/tsl/lib/hash/crc32c.h" #include diff --git a/third_party/xla/third_party/tsl/tsl/lib/hash/crc32c.h b/third_party/xla/xla/tsl/lib/hash/crc32c.h similarity index 94% rename from third_party/xla/third_party/tsl/tsl/lib/hash/crc32c.h rename to third_party/xla/xla/tsl/lib/hash/crc32c.h index 10c4ea13e864d5..29c71eed3f0a99 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/hash/crc32c.h +++ b/third_party/xla/xla/tsl/lib/hash/crc32c.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_LIB_HASH_CRC32C_H_ -#define TENSORFLOW_TSL_LIB_HASH_CRC32C_H_ +#ifndef XLA_TSL_LIB_HASH_CRC32C_H_ +#define XLA_TSL_LIB_HASH_CRC32C_H_ #include @@ -67,4 +67,4 @@ inline uint32 Unmask(uint32 masked_crc) { } // namespace crc32c } // namespace tsl -#endif // TENSORFLOW_TSL_LIB_HASH_CRC32C_H_ +#endif // XLA_TSL_LIB_HASH_CRC32C_H_ diff --git a/third_party/xla/third_party/tsl/tsl/lib/hash/crc32c_test.cc b/third_party/xla/xla/tsl/lib/hash/crc32c_test.cc similarity index 98% rename from third_party/xla/third_party/tsl/tsl/lib/hash/crc32c_test.cc rename to third_party/xla/xla/tsl/lib/hash/crc32c_test.cc index 9ba2e6e8108cf7..291121d5043f6f 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/hash/crc32c_test.cc +++ b/third_party/xla/xla/tsl/lib/hash/crc32c_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/lib/hash/crc32c.h" +#include "xla/tsl/lib/hash/crc32c.h" #include diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/BUILD b/third_party/xla/xla/tsl/lib/io/BUILD similarity index 55% rename from third_party/xla/third_party/tsl/tsl/lib/io/BUILD rename to third_party/xla/xla/tsl/lib/io/BUILD index cd527743282c01..43152cb1ea1444 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/BUILD +++ b/third_party/xla/xla/tsl/lib/io/BUILD @@ -1,24 +1,24 @@ +load("@local_tsl//tsl/platform:build_config.bzl", "tsl_cc_test") load( "@local_tsl//tsl/platform:rules_cc.bzl", "cc_library", ) -load("@local_xla//xla/tsl:tsl.bzl", "internal_visibility") -load("@local_xla//xla/tsl:tsl.default.bzl", "filegroup") -load("//tsl/platform:build_config.bzl", "tsl_cc_test") +load("//xla/tsl:tsl.bzl", "internal_visibility") +load("//xla/tsl:tsl.default.bzl", "filegroup") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = internal_visibility([ "//tensorflow/c/experimental/filesystem:__pkg__", "//tensorflow/c/experimental/filesystem/plugins/posix:__pkg__", - "//tsl/lib/io/snappy:__pkg__", - "@local_xla//xla:__subpackages__", + "//xla/tsl/lib/io/snappy:__pkg__", + "//xla:__subpackages__", # tensorflow/core:lib effectively exposes all targets under tensorflow/core/lib/** "//tensorflow/core/util:__subpackages__", "//tensorflow/core:__pkg__", "//tensorflow/core/lib/io:__subpackages__", - "@local_xla//xla/tsl/profiler:__subpackages__", - "//tsl/profiler:__subpackages__", + "//xla/tsl/profiler:__subpackages__", + "@local_tsl//tsl/profiler:__subpackages__", "//tensorflow/core/profiler:__subpackages__", ]), licenses = ["notice"], @@ -41,16 +41,16 @@ cc_library( deps = [ ":iterator", ":table_options", - "//tsl/lib/hash:crc32c", - "//tsl/platform:coding", - "//tsl/platform:env", - "//tsl/platform:errors", - "//tsl/platform:logging", - "//tsl/platform:platform_port", - "//tsl/platform:raw_coding", - "//tsl/platform:status", - "//tsl/platform:stringpiece", - "//tsl/platform:types", + "//xla/tsl/lib/hash:crc32c", + "@local_tsl//tsl/platform:coding", + "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:platform_port", + "@local_tsl//tsl/platform:raw_coding", + "@local_tsl//tsl/platform:status", + "@local_tsl//tsl/platform:stringpiece", + "@local_tsl//tsl/platform:types", ], alwayslink = True, ) @@ -62,8 +62,8 @@ cc_library( deps = [ ":inputstream_interface", ":random_inputstream", - "//tsl/platform:env", "@com_google_absl//absl/status", + "@local_tsl//tsl/platform:env", ], alwayslink = True, ) @@ -80,13 +80,13 @@ cc_library( srcs = ["inputbuffer.cc"], hdrs = ["inputbuffer.h"], deps = [ - "//tsl/platform:coding", - "//tsl/platform:env", - "//tsl/platform:errors", - "//tsl/platform:logging", - "//tsl/platform:macros", - "//tsl/platform:status", - "//tsl/platform:types", + "@local_tsl//tsl/platform:coding", + "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:macros", + "@local_tsl//tsl/platform:status", + "@local_tsl//tsl/platform:types", ], alwayslink = True, ) @@ -96,10 +96,10 @@ cc_library( srcs = ["inputstream_interface.cc"], hdrs = ["inputstream_interface.h"], deps = [ - "//tsl/platform:cord", - "//tsl/platform:errors", - "//tsl/platform:status", - "//tsl/platform:types", + "@local_tsl//tsl/platform:cord", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:status", + "@local_tsl//tsl/platform:types", ], alwayslink = True, ) @@ -109,8 +109,8 @@ cc_library( srcs = ["iterator.cc"], hdrs = ["iterator.h"], deps = [ - "//tsl/platform:status", - "//tsl/platform:stringpiece", + "@local_tsl//tsl/platform:status", + "@local_tsl//tsl/platform:stringpiece", ], alwayslink = True, ) @@ -119,10 +119,10 @@ cc_library( name = "proto_encode_helper", hdrs = ["proto_encode_helper.h"], deps = [ - "//tsl/platform:coding", - "//tsl/platform:logging", - "//tsl/platform:protobuf", - "//tsl/platform:stringpiece", + "@local_tsl//tsl/platform:coding", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:protobuf", + "@local_tsl//tsl/platform:stringpiece", ], ) @@ -132,8 +132,8 @@ cc_library( hdrs = ["random_inputstream.h"], deps = [ ":inputstream_interface", - "//tsl/platform:cord", - "//tsl/platform:env", + "@local_tsl//tsl/platform:cord", + "@local_tsl//tsl/platform:env", ], alwayslink = True, ) @@ -151,13 +151,13 @@ cc_library( ":snappy_inputstream", ":zlib_compression_options", ":zlib_inputstream", - "//tsl/lib/hash:crc32c", - "//tsl/platform:env", - "//tsl/platform:errors", - "//tsl/platform:macros", - "//tsl/platform:raw_coding", - "//tsl/platform:stringpiece", - "//tsl/platform:types", + "//xla/tsl/lib/hash:crc32c", + "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:macros", + "@local_tsl//tsl/platform:raw_coding", + "@local_tsl//tsl/platform:stringpiece", + "@local_tsl//tsl/platform:types", ], alwayslink = True, ) @@ -172,36 +172,36 @@ cc_library( ":snappy_outputbuffer", ":zlib_compression_options", ":zlib_outputbuffer", - "//tsl/lib/hash:crc32c", - "//tsl/platform:coding", - "//tsl/platform:cord", - "//tsl/platform:env", - "//tsl/platform:macros", - "//tsl/platform:status", - "//tsl/platform:stringpiece", - "//tsl/platform:types", + "//xla/tsl/lib/hash:crc32c", + "@local_tsl//tsl/platform:coding", + "@local_tsl//tsl/platform:cord", + "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:macros", + "@local_tsl//tsl/platform:status", + "@local_tsl//tsl/platform:stringpiece", + "@local_tsl//tsl/platform:types", ], alwayslink = True, ) alias( name = "snappy_inputbuffer", - actual = "//tsl/lib/io/snappy:snappy_inputbuffer", + actual = "//xla/tsl/lib/io/snappy:snappy_inputbuffer", ) alias( name = "snappy_inputstream", - actual = "//tsl/lib/io/snappy:snappy_inputstream", + actual = "//xla/tsl/lib/io/snappy:snappy_inputstream", ) alias( name = "snappy_outputbuffer", - actual = "//tsl/lib/io/snappy:snappy_outputbuffer", + actual = "//xla/tsl/lib/io/snappy:snappy_outputbuffer", ) alias( name = "snappy_compression_options", - actual = "//tsl/lib/io/snappy:snappy_compression_options", + actual = "//xla/tsl/lib/io/snappy:snappy_compression_options", ) cc_library( @@ -213,9 +213,9 @@ cc_library( "cache.h", ], deps = [ - "//tsl/platform:mutex", - "//tsl/platform:raw_coding", - "//tsl/platform:stringpiece", + "@local_tsl//tsl/platform:mutex", + "@local_tsl//tsl/platform:raw_coding", + "@local_tsl//tsl/platform:stringpiece", ], ) @@ -234,9 +234,9 @@ cc_library( ":cache", ":iterator", ":table_options", - "//tsl/platform:coding", - "//tsl/platform:env", - "//tsl/platform:errors", + "@local_tsl//tsl/platform:coding", + "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:errors", ], alwayslink = True, ) @@ -251,10 +251,10 @@ cc_library( hdrs = ["buffered_file.h"], visibility = ["//visibility:public"], deps = [ - "//tsl/lib/hash:crc32c", - "//tsl/platform:cord", - "//tsl/platform:env", - "//tsl/platform:status", + "//xla/tsl/lib/hash:crc32c", + "@local_tsl//tsl/platform:cord", + "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:status", ], ) @@ -264,12 +264,12 @@ tsl_cc_test( srcs = ["buffered_file_test.cc"], deps = [ ":buffered_file", - "//tsl/platform:env", - "//tsl/platform:env_impl", - "//tsl/platform:test", - "//tsl/platform:test_benchmark", - "//tsl/platform:test_main", - "@local_xla//xla/tsl/lib/core:status_test_util", + "//xla/tsl/lib/core:status_test_util", + "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:env_impl", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_benchmark", + "@local_tsl//tsl/platform:test_main", ], ) @@ -278,7 +278,7 @@ cc_library( srcs = ["zlib_compression_options.cc"], hdrs = ["zlib_compression_options.h"], deps = [ - "//tsl/platform:types", + "@local_tsl//tsl/platform:types", "@zlib", ], alwayslink = True, @@ -291,12 +291,12 @@ cc_library( deps = [ ":inputstream_interface", ":zlib_compression_options", - "//tsl/platform:env", - "//tsl/platform:logging", - "//tsl/platform:macros", - "//tsl/platform:status", - "//tsl/platform:strcat", - "//tsl/platform:types", + "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:macros", + "@local_tsl//tsl/platform:status", + "@local_tsl//tsl/platform:strcat", + "@local_tsl//tsl/platform:types", "@zlib", ], alwayslink = True, @@ -308,12 +308,12 @@ cc_library( hdrs = ["zlib_outputbuffer.h"], deps = [ ":zlib_compression_options", - "//tsl/platform:env", - "//tsl/platform:errors", - "//tsl/platform:macros", - "//tsl/platform:status", - "//tsl/platform:stringpiece", - "//tsl/platform:types", + "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:macros", + "@local_tsl//tsl/platform:status", + "@local_tsl//tsl/platform:stringpiece", + "@local_tsl//tsl/platform:types", "@zlib", ], alwayslink = True, @@ -357,9 +357,9 @@ filegroup( "zlib_compression_options.h", "zlib_inputstream.cc", "zlib_inputstream.h", - "//tsl/lib/io/snappy:snappy_compression_options.h", - "//tsl/lib/io/snappy:snappy_inputstream.cc", - "//tsl/lib/io/snappy:snappy_inputstream.h", + "//xla/tsl/lib/io/snappy:snappy_compression_options.h", + "//xla/tsl/lib/io/snappy:snappy_inputstream.cc", + "//xla/tsl/lib/io/snappy:snappy_inputstream.h", ], ) @@ -385,10 +385,10 @@ filegroup( "zlib_compression_options.h", "zlib_inputstream.h", "zlib_outputbuffer.h", - "//tsl/lib/io/snappy:snappy_compression_options.h", - "//tsl/lib/io/snappy:snappy_inputbuffer.h", - "//tsl/lib/io/snappy:snappy_inputstream.h", - "//tsl/lib/io/snappy:snappy_outputbuffer.h", + "//xla/tsl/lib/io/snappy:snappy_compression_options.h", + "//xla/tsl/lib/io/snappy:snappy_inputbuffer.h", + "//xla/tsl/lib/io/snappy:snappy_inputstream.h", + "//xla/tsl/lib/io/snappy:snappy_outputbuffer.h", ], visibility = internal_visibility(["//tensorflow/core:__pkg__"]), ) @@ -419,10 +419,10 @@ filegroup( "zlib_compression_options.h", "zlib_inputstream.h", "zlib_outputbuffer.h", - "//tsl/lib/io/snappy:snappy_compression_options.h", - "//tsl/lib/io/snappy:snappy_inputbuffer.h", - "//tsl/lib/io/snappy:snappy_inputstream.h", - "//tsl/lib/io/snappy:snappy_outputbuffer.h", + "//xla/tsl/lib/io/snappy:snappy_compression_options.h", + "//xla/tsl/lib/io/snappy:snappy_inputbuffer.h", + "//xla/tsl/lib/io/snappy:snappy_inputstream.h", + "//xla/tsl/lib/io/snappy:snappy_outputbuffer.h", ], visibility = internal_visibility(["//tensorflow/core:__pkg__"]), ) @@ -444,12 +444,12 @@ tsl_cc_test( deps = [ ":buffered_inputstream", ":random_inputstream", - "//tsl/platform:env", - "//tsl/platform:env_impl", - "//tsl/platform:test", - "//tsl/platform:test_benchmark", - "//tsl/platform:test_main", - "@local_xla//xla/tsl/lib/core:status_test_util", + "//xla/tsl/lib/core:status_test_util", + "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:env_impl", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_benchmark", + "@local_tsl//tsl/platform:test_main", ], ) @@ -459,10 +459,10 @@ tsl_cc_test( srcs = ["cache_test.cc"], deps = [ ":cache", - "//tsl/platform:coding", - "//tsl/platform:raw_coding", - "//tsl/platform:test", - "//tsl/platform:test_main", + "@local_tsl//tsl/platform:coding", + "@local_tsl//tsl/platform:raw_coding", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", ], ) @@ -472,17 +472,17 @@ tsl_cc_test( srcs = ["inputbuffer_test.cc"], deps = [ ":inputbuffer", - "//tsl/platform:coding", - "//tsl/platform:env", - "//tsl/platform:env_impl", - "//tsl/platform:errors", - "//tsl/platform:logging", - "//tsl/platform:status", - "//tsl/platform:str_util", - "//tsl/platform:strcat", - "//tsl/platform:test", - "//tsl/platform:test_main", - "@local_xla//xla/tsl/lib/core:status_test_util", + "//xla/tsl/lib/core:status_test_util", + "@local_tsl//tsl/platform:coding", + "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:env_impl", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:status", + "@local_tsl//tsl/platform:str_util", + "@local_tsl//tsl/platform:strcat", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", ], ) @@ -492,10 +492,10 @@ tsl_cc_test( srcs = ["inputstream_interface_test.cc"], deps = [ ":inputstream_interface", - "//tsl/platform:errors", - "//tsl/platform:test", - "//tsl/platform:test_main", - "@local_xla//xla/tsl/lib/core:status_test_util", + "//xla/tsl/lib/core:status_test_util", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", ], ) @@ -505,11 +505,11 @@ tsl_cc_test( srcs = ["random_inputstream_test.cc"], deps = [ ":random_inputstream", - "//tsl/platform:env", - "//tsl/platform:env_impl", - "//tsl/platform:test", - "//tsl/platform:test_main", - "@local_xla//xla/tsl/lib/core:status_test_util", + "//xla/tsl/lib/core:status_test_util", + "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:env_impl", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", ], ) @@ -520,15 +520,15 @@ tsl_cc_test( deps = [ ":record_reader", ":record_writer", - "//tsl/platform:env", - "//tsl/platform:env_impl", - "//tsl/platform:errors", - "//tsl/platform:logging", - "//tsl/platform:status", - "//tsl/platform:strcat", - "//tsl/platform:test", - "//tsl/platform:test_main", - "@local_xla//xla/tsl/lib/core:status_test_util", + "//xla/tsl/lib/core:status_test_util", + "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:env_impl", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:status", + "@local_tsl//tsl/platform:strcat", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", "@zlib", ], ) @@ -540,16 +540,16 @@ tsl_cc_test( deps = [ ":record_reader", ":record_writer", - "//tsl/lib/hash:crc32c", - "//tsl/lib/random:philox", - "//tsl/platform:coding", - "//tsl/platform:env", - "//tsl/platform:env_impl", - "//tsl/platform:errors", - "//tsl/platform:str_util", - "//tsl/platform:test", - "//tsl/platform:test_main", - "@local_xla//xla/tsl/lib/core:status_test_util", + "//xla/tsl/lib/core:status_test_util", + "//xla/tsl/lib/hash:crc32c", + "//xla/tsl/lib/random:philox", + "@local_tsl//tsl/platform:coding", + "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:env_impl", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:str_util", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", ], ) @@ -561,14 +561,14 @@ tsl_cc_test( ":block", ":iterator", ":table", - "//tsl/lib/random:philox", - "//tsl/platform:env", - "//tsl/platform:env_impl", - "//tsl/platform:errors", - "//tsl/platform:platform_port", - "//tsl/platform:test", - "//tsl/platform:test_main", + "//xla/tsl/lib/random:philox", "@com_google_absl//absl/strings", + "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:env_impl", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:platform_port", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", ], ) @@ -581,12 +581,12 @@ tsl_cc_test( ":zlib_compression_options", ":zlib_inputstream", ":zlib_outputbuffer", - "//tsl/platform:env", - "//tsl/platform:env_impl", - "//tsl/platform:errors", - "//tsl/platform:strcat", - "//tsl/platform:test", - "//tsl/platform:test_main", - "@local_xla//xla/tsl/lib/core:status_test_util", + "//xla/tsl/lib/core:status_test_util", + "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:env_impl", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:strcat", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", ], ) diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/block.cc b/third_party/xla/xla/tsl/lib/io/block.cc similarity index 99% rename from third_party/xla/third_party/tsl/tsl/lib/io/block.cc rename to third_party/xla/xla/tsl/lib/io/block.cc index afae26cf20caec..eed80e59cf9243 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/block.cc +++ b/third_party/xla/xla/tsl/lib/io/block.cc @@ -15,11 +15,11 @@ limitations under the License. // Decodes the blocks generated by block_builder.cc. -#include "tsl/lib/io/block.h" +#include "xla/tsl/lib/io/block.h" #include -#include "tsl/lib/io/format.h" +#include "xla/tsl/lib/io/format.h" #include "tsl/platform/coding.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/block.h b/third_party/xla/xla/tsl/lib/io/block.h similarity index 89% rename from third_party/xla/third_party/tsl/tsl/lib/io/block.h rename to third_party/xla/xla/tsl/lib/io/block.h index b31808627157c2..c97a0f9830d48f 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/block.h +++ b/third_party/xla/xla/tsl/lib/io/block.h @@ -13,13 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_LIB_IO_BLOCK_H_ -#define TENSORFLOW_TSL_LIB_IO_BLOCK_H_ +#ifndef XLA_TSL_LIB_IO_BLOCK_H_ +#define XLA_TSL_LIB_IO_BLOCK_H_ #include #include -#include "tsl/lib/io/iterator.h" +#include "xla/tsl/lib/io/iterator.h" namespace tsl { namespace table { @@ -54,4 +54,4 @@ class Block { } // namespace table } // namespace tsl -#endif // TENSORFLOW_TSL_LIB_IO_BLOCK_H_ +#endif // XLA_TSL_LIB_IO_BLOCK_H_ diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/block_builder.cc b/third_party/xla/xla/tsl/lib/io/block_builder.cc similarity index 98% rename from third_party/xla/third_party/tsl/tsl/lib/io/block_builder.cc rename to third_party/xla/xla/tsl/lib/io/block_builder.cc index d28d718a24f6d7..e471852a7bfda4 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/block_builder.cc +++ b/third_party/xla/xla/tsl/lib/io/block_builder.cc @@ -37,13 +37,13 @@ limitations under the License. // num_restarts: uint32 // restarts[i] contains the offset within the block of the ith restart point. -#include "tsl/lib/io/block_builder.h" +#include "xla/tsl/lib/io/block_builder.h" #include #include -#include "tsl/lib/io/table_builder.h" +#include "xla/tsl/lib/io/table_builder.h" #include "tsl/platform/coding.h" namespace tsl { diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/block_builder.h b/third_party/xla/xla/tsl/lib/io/block_builder.h similarity index 93% rename from third_party/xla/third_party/tsl/tsl/lib/io/block_builder.h rename to third_party/xla/xla/tsl/lib/io/block_builder.h index 578d8bab57e854..0defea6d866e0f 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/block_builder.h +++ b/third_party/xla/xla/tsl/lib/io/block_builder.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_LIB_IO_BLOCK_BUILDER_H_ -#define TENSORFLOW_TSL_LIB_IO_BLOCK_BUILDER_H_ +#ifndef XLA_TSL_LIB_IO_BLOCK_BUILDER_H_ +#define XLA_TSL_LIB_IO_BLOCK_BUILDER_H_ #include @@ -67,4 +67,4 @@ class BlockBuilder { } // namespace table } // namespace tsl -#endif // TENSORFLOW_TSL_LIB_IO_BLOCK_BUILDER_H_ +#endif // XLA_TSL_LIB_IO_BLOCK_BUILDER_H_ diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/buffered_file.h b/third_party/xla/xla/tsl/lib/io/buffered_file.h similarity index 95% rename from third_party/xla/third_party/tsl/tsl/lib/io/buffered_file.h rename to third_party/xla/xla/tsl/lib/io/buffered_file.h index e6abe32c465fff..6d173c83d12530 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/buffered_file.h +++ b/third_party/xla/xla/tsl/lib/io/buffered_file.h @@ -13,15 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_LIB_IO_BUFFERED_FILE_H_ -#define TENSORFLOW_TSL_LIB_IO_BUFFERED_FILE_H_ +#ifndef XLA_TSL_LIB_IO_BUFFERED_FILE_H_ +#define XLA_TSL_LIB_IO_BUFFERED_FILE_H_ #include #include #include #include -#include "tsl/lib/hash/crc32c.h" +#include "xla/tsl/lib/hash/crc32c.h" #include "tsl/platform/cord.h" #include "tsl/platform/file_system.h" #include "tsl/platform/status.h" @@ -113,4 +113,4 @@ class BufferedWritableFile : public WritableFile { }; } // namespace tsl -#endif // TENSORFLOW_TSL_LIB_IO_BUFFERED_FILE_H_ +#endif // XLA_TSL_LIB_IO_BUFFERED_FILE_H_ diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/buffered_file_test.cc b/third_party/xla/xla/tsl/lib/io/buffered_file_test.cc similarity index 97% rename from third_party/xla/third_party/tsl/tsl/lib/io/buffered_file_test.cc rename to third_party/xla/xla/tsl/lib/io/buffered_file_test.cc index f9fa67dd1572f5..2c3fc0fe5070ca 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/buffered_file_test.cc +++ b/third_party/xla/xla/tsl/lib/io/buffered_file_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/lib/io/buffered_file.h" +#include "xla/tsl/lib/io/buffered_file.h" #include #include diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/buffered_inputstream.cc b/third_party/xla/xla/tsl/lib/io/buffered_inputstream.cc similarity index 98% rename from third_party/xla/third_party/tsl/tsl/lib/io/buffered_inputstream.cc rename to third_party/xla/xla/tsl/lib/io/buffered_inputstream.cc index 89ed20757cf093..244c15882ab502 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/buffered_inputstream.cc +++ b/third_party/xla/xla/tsl/lib/io/buffered_inputstream.cc @@ -13,10 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/lib/io/buffered_inputstream.h" +#include "xla/tsl/lib/io/buffered_inputstream.h" #include "absl/status/status.h" -#include "tsl/lib/io/random_inputstream.h" +#include "xla/tsl/lib/io/random_inputstream.h" namespace tsl { namespace io { diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/buffered_inputstream.h b/third_party/xla/xla/tsl/lib/io/buffered_inputstream.h similarity index 95% rename from third_party/xla/third_party/tsl/tsl/lib/io/buffered_inputstream.h rename to third_party/xla/xla/tsl/lib/io/buffered_inputstream.h index 6681f1bbfbed32..1a187012766ab1 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/buffered_inputstream.h +++ b/third_party/xla/xla/tsl/lib/io/buffered_inputstream.h @@ -13,12 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_LIB_IO_BUFFERED_INPUTSTREAM_H_ -#define TENSORFLOW_TSL_LIB_IO_BUFFERED_INPUTSTREAM_H_ +#ifndef XLA_TSL_LIB_IO_BUFFERED_INPUTSTREAM_H_ +#define XLA_TSL_LIB_IO_BUFFERED_INPUTSTREAM_H_ #include -#include "tsl/lib/io/inputstream_interface.h" +#include "xla/tsl/lib/io/inputstream_interface.h" #include "tsl/platform/file_system.h" namespace tsl { @@ -124,4 +124,4 @@ extern template Status BufferedInputStream::ReadAll(tstring* result); } // namespace io } // namespace tsl -#endif // TENSORFLOW_TSL_LIB_IO_BUFFERED_INPUTSTREAM_H_ +#endif // XLA_TSL_LIB_IO_BUFFERED_INPUTSTREAM_H_ diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/buffered_inputstream_test.cc b/third_party/xla/xla/tsl/lib/io/buffered_inputstream_test.cc similarity index 99% rename from third_party/xla/third_party/tsl/tsl/lib/io/buffered_inputstream_test.cc rename to third_party/xla/xla/tsl/lib/io/buffered_inputstream_test.cc index 83e5776d6602d2..1ad2476eb5af6e 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/buffered_inputstream_test.cc +++ b/third_party/xla/xla/tsl/lib/io/buffered_inputstream_test.cc @@ -13,10 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/lib/io/buffered_inputstream.h" +#include "xla/tsl/lib/io/buffered_inputstream.h" #include "xla/tsl/lib/core/status_test_util.h" -#include "tsl/lib/io/random_inputstream.h" +#include "xla/tsl/lib/io/random_inputstream.h" #include "tsl/platform/env.h" #include "tsl/platform/test.h" #include "tsl/platform/test_benchmark.h" diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/cache.cc b/third_party/xla/xla/tsl/lib/io/cache.cc similarity index 99% rename from third_party/xla/third_party/tsl/tsl/lib/io/cache.cc rename to third_party/xla/xla/tsl/lib/io/cache.cc index dee0871e6b8539..6515783f5c99e2 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/cache.cc +++ b/third_party/xla/xla/tsl/lib/io/cache.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/lib/io/cache.h" +#include "xla/tsl/lib/io/cache.h" #include #include diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/cache.h b/third_party/xla/xla/tsl/lib/io/cache.h similarity index 97% rename from third_party/xla/third_party/tsl/tsl/lib/io/cache.h rename to third_party/xla/xla/tsl/lib/io/cache.h index 831288b56abd75..9cd5502cb2e715 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/cache.h +++ b/third_party/xla/xla/tsl/lib/io/cache.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_LIB_IO_CACHE_H_ -#define TENSORFLOW_TSL_LIB_IO_CACHE_H_ +#ifndef XLA_TSL_LIB_IO_CACHE_H_ +#define XLA_TSL_LIB_IO_CACHE_H_ #include @@ -124,4 +124,4 @@ class Cache { } // namespace tsl -#endif // TENSORFLOW_TSL_LIB_IO_CACHE_H_ +#endif // XLA_TSL_LIB_IO_CACHE_H_ diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/cache_test.cc b/third_party/xla/xla/tsl/lib/io/cache_test.cc similarity index 99% rename from third_party/xla/third_party/tsl/tsl/lib/io/cache_test.cc rename to third_party/xla/xla/tsl/lib/io/cache_test.cc index 62a53601fd3c73..3c54c82a11ac25 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/cache_test.cc +++ b/third_party/xla/xla/tsl/lib/io/cache_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/lib/io/cache.h" +#include "xla/tsl/lib/io/cache.h" #include #include diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/compression.cc b/third_party/xla/xla/tsl/lib/io/compression.cc similarity index 95% rename from third_party/xla/third_party/tsl/tsl/lib/io/compression.cc rename to third_party/xla/xla/tsl/lib/io/compression.cc index 18f821bc805efe..450962fde73d98 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/compression.cc +++ b/third_party/xla/xla/tsl/lib/io/compression.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/lib/io/compression.h" +#include "xla/tsl/lib/io/compression.h" namespace tsl { namespace io { diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/compression.h b/third_party/xla/xla/tsl/lib/io/compression.h similarity index 86% rename from third_party/xla/third_party/tsl/tsl/lib/io/compression.h rename to third_party/xla/xla/tsl/lib/io/compression.h index bed94981eca9c4..ce3b7fb4ca3e4c 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/compression.h +++ b/third_party/xla/xla/tsl/lib/io/compression.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_LIB_IO_COMPRESSION_H_ -#define TENSORFLOW_TSL_LIB_IO_COMPRESSION_H_ +#ifndef XLA_TSL_LIB_IO_COMPRESSION_H_ +#define XLA_TSL_LIB_IO_COMPRESSION_H_ namespace tsl { namespace io { @@ -29,4 +29,4 @@ extern const char kZlib[]; } // namespace io } // namespace tsl -#endif // TENSORFLOW_TSL_LIB_IO_COMPRESSION_H_ +#endif // XLA_TSL_LIB_IO_COMPRESSION_H_ diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/format.cc b/third_party/xla/xla/tsl/lib/io/format.cc similarity index 98% rename from third_party/xla/third_party/tsl/tsl/lib/io/format.cc rename to third_party/xla/xla/tsl/lib/io/format.cc index aa26afd84b6677..e02451c08d7e0e 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/format.cc +++ b/third_party/xla/xla/tsl/lib/io/format.cc @@ -13,12 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/lib/io/format.h" +#include "xla/tsl/lib/io/format.h" #include -#include "tsl/lib/hash/crc32c.h" -#include "tsl/lib/io/block.h" +#include "xla/tsl/lib/hash/crc32c.h" +#include "xla/tsl/lib/io/block.h" #include "tsl/platform/coding.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/format.h b/third_party/xla/xla/tsl/lib/io/format.h similarity index 95% rename from third_party/xla/third_party/tsl/tsl/lib/io/format.h rename to third_party/xla/xla/tsl/lib/io/format.h index 2f704c9ca9d200..3cf5d6312a5f02 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/format.h +++ b/third_party/xla/xla/tsl/lib/io/format.h @@ -13,14 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_LIB_IO_FORMAT_H_ -#define TENSORFLOW_TSL_LIB_IO_FORMAT_H_ +#ifndef XLA_TSL_LIB_IO_FORMAT_H_ +#define XLA_TSL_LIB_IO_FORMAT_H_ #include #include -#include "tsl/lib/io/table_builder.h" +#include "xla/tsl/lib/io/table_builder.h" #include "tsl/platform/status.h" #include "tsl/platform/stringpiece.h" @@ -110,4 +110,4 @@ inline BlockHandle::BlockHandle() } // namespace table } // namespace tsl -#endif // TENSORFLOW_TSL_LIB_IO_FORMAT_H_ +#endif // XLA_TSL_LIB_IO_FORMAT_H_ diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/inputbuffer.cc b/third_party/xla/xla/tsl/lib/io/inputbuffer.cc similarity index 99% rename from third_party/xla/third_party/tsl/tsl/lib/io/inputbuffer.cc rename to third_party/xla/xla/tsl/lib/io/inputbuffer.cc index f5a46ae7e87c1d..5fdff4943331ed 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/inputbuffer.cc +++ b/third_party/xla/xla/tsl/lib/io/inputbuffer.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/lib/io/inputbuffer.h" +#include "xla/tsl/lib/io/inputbuffer.h" #include diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/inputbuffer.h b/third_party/xla/xla/tsl/lib/io/inputbuffer.h similarity index 97% rename from third_party/xla/third_party/tsl/tsl/lib/io/inputbuffer.h rename to third_party/xla/xla/tsl/lib/io/inputbuffer.h index 57a4a983c11e75..bec656ecd00ef6 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/inputbuffer.h +++ b/third_party/xla/xla/tsl/lib/io/inputbuffer.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_LIB_IO_INPUTBUFFER_H_ -#define TENSORFLOW_TSL_LIB_IO_INPUTBUFFER_H_ +#ifndef XLA_TSL_LIB_IO_INPUTBUFFER_H_ +#define XLA_TSL_LIB_IO_INPUTBUFFER_H_ #include @@ -149,4 +149,4 @@ inline absl::Status InputBuffer::ReadVarint64(uint64* result) { } // namespace io } // namespace tsl -#endif // TENSORFLOW_TSL_LIB_IO_INPUTBUFFER_H_ +#endif // XLA_TSL_LIB_IO_INPUTBUFFER_H_ diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/inputbuffer_test.cc b/third_party/xla/xla/tsl/lib/io/inputbuffer_test.cc similarity index 99% rename from third_party/xla/third_party/tsl/tsl/lib/io/inputbuffer_test.cc rename to third_party/xla/xla/tsl/lib/io/inputbuffer_test.cc index ae99467be0ea2a..a4d170101ea675 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/inputbuffer_test.cc +++ b/third_party/xla/xla/tsl/lib/io/inputbuffer_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/lib/io/inputbuffer.h" +#include "xla/tsl/lib/io/inputbuffer.h" #include diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/inputstream_interface.cc b/third_party/xla/xla/tsl/lib/io/inputstream_interface.cc similarity index 96% rename from third_party/xla/third_party/tsl/tsl/lib/io/inputstream_interface.cc rename to third_party/xla/xla/tsl/lib/io/inputstream_interface.cc index 6425ff0656b658..7bf261f6757609 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/inputstream_interface.cc +++ b/third_party/xla/xla/tsl/lib/io/inputstream_interface.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/lib/io/inputstream_interface.h" +#include "xla/tsl/lib/io/inputstream_interface.h" #include "tsl/platform/errors.h" diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/inputstream_interface.h b/third_party/xla/xla/tsl/lib/io/inputstream_interface.h similarity index 93% rename from third_party/xla/third_party/tsl/tsl/lib/io/inputstream_interface.h rename to third_party/xla/xla/tsl/lib/io/inputstream_interface.h index 8eb7f2ad868965..3ecb5b5af9e8df 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/inputstream_interface.h +++ b/third_party/xla/xla/tsl/lib/io/inputstream_interface.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_LIB_IO_INPUTSTREAM_INTERFACE_H_ -#define TENSORFLOW_TSL_LIB_IO_INPUTSTREAM_INTERFACE_H_ +#ifndef XLA_TSL_LIB_IO_INPUTSTREAM_INTERFACE_H_ +#define XLA_TSL_LIB_IO_INPUTSTREAM_INTERFACE_H_ #include @@ -67,4 +67,4 @@ class InputStreamInterface { } // namespace io } // namespace tsl -#endif // TENSORFLOW_TSL_LIB_IO_INPUTSTREAM_INTERFACE_H_ +#endif // XLA_TSL_LIB_IO_INPUTSTREAM_INTERFACE_H_ diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/inputstream_interface_test.cc b/third_party/xla/xla/tsl/lib/io/inputstream_interface_test.cc similarity index 97% rename from third_party/xla/third_party/tsl/tsl/lib/io/inputstream_interface_test.cc rename to third_party/xla/xla/tsl/lib/io/inputstream_interface_test.cc index c9c34dba55364e..9021440b6e1d84 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/inputstream_interface_test.cc +++ b/third_party/xla/xla/tsl/lib/io/inputstream_interface_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/lib/io/inputstream_interface.h" +#include "xla/tsl/lib/io/inputstream_interface.h" #include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/errors.h" diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/iterator.cc b/third_party/xla/xla/tsl/lib/io/iterator.cc similarity index 98% rename from third_party/xla/third_party/tsl/tsl/lib/io/iterator.cc rename to third_party/xla/xla/tsl/lib/io/iterator.cc index b7e69f7081aa92..2db370d7478d21 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/iterator.cc +++ b/third_party/xla/xla/tsl/lib/io/iterator.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/lib/io/iterator.h" +#include "xla/tsl/lib/io/iterator.h" namespace tsl { namespace table { diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/iterator.h b/third_party/xla/xla/tsl/lib/io/iterator.h similarity index 96% rename from third_party/xla/third_party/tsl/tsl/lib/io/iterator.h rename to third_party/xla/xla/tsl/lib/io/iterator.h index 7fe51bfd785bc1..ba0b1dbc4b76de 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/iterator.h +++ b/third_party/xla/xla/tsl/lib/io/iterator.h @@ -23,8 +23,8 @@ limitations under the License. // non-const method, all threads accessing the same Iterator must use // external synchronization. -#ifndef TENSORFLOW_TSL_LIB_IO_ITERATOR_H_ -#define TENSORFLOW_TSL_LIB_IO_ITERATOR_H_ +#ifndef XLA_TSL_LIB_IO_ITERATOR_H_ +#define XLA_TSL_LIB_IO_ITERATOR_H_ #include "tsl/platform/status.h" #include "tsl/platform/stringpiece.h" @@ -101,4 +101,4 @@ extern Iterator* NewErrorIterator(const absl::Status& status); } // namespace table } // namespace tsl -#endif // TENSORFLOW_TSL_LIB_IO_ITERATOR_H_ +#endif // XLA_TSL_LIB_IO_ITERATOR_H_ diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/proto_encode_helper.h b/third_party/xla/xla/tsl/lib/io/proto_encode_helper.h similarity index 94% rename from third_party/xla/third_party/tsl/tsl/lib/io/proto_encode_helper.h rename to third_party/xla/xla/tsl/lib/io/proto_encode_helper.h index c5bf8262f9df91..33c2411cbc3ca3 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/proto_encode_helper.h +++ b/third_party/xla/xla/tsl/lib/io/proto_encode_helper.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_LIB_IO_PROTO_ENCODE_HELPER_H_ -#define TENSORFLOW_TSL_LIB_IO_PROTO_ENCODE_HELPER_H_ +#ifndef XLA_TSL_LIB_IO_PROTO_ENCODE_HELPER_H_ +#define XLA_TSL_LIB_IO_PROTO_ENCODE_HELPER_H_ #include "tsl/platform/coding.h" #include "tsl/platform/logging.h" @@ -96,4 +96,4 @@ class ProtoEncodeHelper { } // namespace io } // namespace tsl -#endif // TENSORFLOW_TSL_LIB_IO_PROTO_ENCODE_HELPER_H_ +#endif // XLA_TSL_LIB_IO_PROTO_ENCODE_HELPER_H_ diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/random_inputstream.cc b/third_party/xla/xla/tsl/lib/io/random_inputstream.cc similarity index 98% rename from third_party/xla/third_party/tsl/tsl/lib/io/random_inputstream.cc rename to third_party/xla/xla/tsl/lib/io/random_inputstream.cc index 6802707c3387fe..26e138c0e231c2 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/random_inputstream.cc +++ b/third_party/xla/xla/tsl/lib/io/random_inputstream.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/lib/io/random_inputstream.h" +#include "xla/tsl/lib/io/random_inputstream.h" #include diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/random_inputstream.h b/third_party/xla/xla/tsl/lib/io/random_inputstream.h similarity index 89% rename from third_party/xla/third_party/tsl/tsl/lib/io/random_inputstream.h rename to third_party/xla/xla/tsl/lib/io/random_inputstream.h index 4d48db62c2b03f..99685ab055ac6a 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/random_inputstream.h +++ b/third_party/xla/xla/tsl/lib/io/random_inputstream.h @@ -13,10 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_LIB_IO_RANDOM_INPUTSTREAM_H_ -#define TENSORFLOW_TSL_LIB_IO_RANDOM_INPUTSTREAM_H_ +#ifndef XLA_TSL_LIB_IO_RANDOM_INPUTSTREAM_H_ +#define XLA_TSL_LIB_IO_RANDOM_INPUTSTREAM_H_ -#include "tsl/lib/io/inputstream_interface.h" +#include "xla/tsl/lib/io/inputstream_interface.h" #include "tsl/platform/cord.h" #include "tsl/platform/file_system.h" @@ -59,4 +59,4 @@ class RandomAccessInputStream : public InputStreamInterface { } // namespace io } // namespace tsl -#endif // TENSORFLOW_TSL_LIB_IO_RANDOM_INPUTSTREAM_H_ +#endif // XLA_TSL_LIB_IO_RANDOM_INPUTSTREAM_H_ diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/random_inputstream_test.cc b/third_party/xla/xla/tsl/lib/io/random_inputstream_test.cc similarity index 98% rename from third_party/xla/third_party/tsl/tsl/lib/io/random_inputstream_test.cc rename to third_party/xla/xla/tsl/lib/io/random_inputstream_test.cc index dfa4ec80e20a17..e2fc82374e47bb 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/random_inputstream_test.cc +++ b/third_party/xla/xla/tsl/lib/io/random_inputstream_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/lib/io/random_inputstream.h" +#include "xla/tsl/lib/io/random_inputstream.h" #include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/env.h" diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/record_reader.cc b/third_party/xla/xla/tsl/lib/io/record_reader.cc similarity index 97% rename from third_party/xla/third_party/tsl/tsl/lib/io/record_reader.cc rename to third_party/xla/xla/tsl/lib/io/record_reader.cc index 8d17c610b09f71..8332debff876c2 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/record_reader.cc +++ b/third_party/xla/xla/tsl/lib/io/record_reader.cc @@ -13,14 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/lib/io/record_reader.h" +#include "xla/tsl/lib/io/record_reader.h" #include -#include "tsl/lib/hash/crc32c.h" -#include "tsl/lib/io/buffered_inputstream.h" -#include "tsl/lib/io/compression.h" -#include "tsl/lib/io/random_inputstream.h" +#include "xla/tsl/lib/hash/crc32c.h" +#include "xla/tsl/lib/io/buffered_inputstream.h" +#include "xla/tsl/lib/io/compression.h" +#include "xla/tsl/lib/io/random_inputstream.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" #include "tsl/platform/raw_coding.h" diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/record_reader.h b/third_party/xla/xla/tsl/lib/io/record_reader.h similarity index 93% rename from third_party/xla/third_party/tsl/tsl/lib/io/record_reader.h rename to third_party/xla/xla/tsl/lib/io/record_reader.h index 61540a657324c8..3c18992ec86279 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/record_reader.h +++ b/third_party/xla/xla/tsl/lib/io/record_reader.h @@ -13,17 +13,17 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_LIB_IO_RECORD_READER_H_ -#define TENSORFLOW_TSL_LIB_IO_RECORD_READER_H_ +#ifndef XLA_TSL_LIB_IO_RECORD_READER_H_ +#define XLA_TSL_LIB_IO_RECORD_READER_H_ -#include "tsl/lib/io/inputstream_interface.h" +#include "xla/tsl/lib/io/inputstream_interface.h" #include "tsl/platform/errors.h" #include "tsl/platform/stringpiece.h" #if !defined(IS_SLIM_BUILD) -#include "tsl/lib/io/snappy/snappy_compression_options.h" -#include "tsl/lib/io/snappy/snappy_inputstream.h" -#include "tsl/lib/io/zlib_compression_options.h" -#include "tsl/lib/io/zlib_inputstream.h" +#include "xla/tsl/lib/io/snappy/snappy_compression_options.h" +#include "xla/tsl/lib/io/snappy/snappy_inputstream.h" +#include "xla/tsl/lib/io/zlib_compression_options.h" +#include "xla/tsl/lib/io/zlib_inputstream.h" #endif // IS_SLIM_BUILD #include "tsl/platform/macros.h" #include "tsl/platform/types.h" @@ -177,4 +177,4 @@ class SequentialRecordReader { } // namespace io } // namespace tsl -#endif // TENSORFLOW_TSL_LIB_IO_RECORD_READER_H_ +#endif // XLA_TSL_LIB_IO_RECORD_READER_H_ diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/record_reader_writer_test.cc b/third_party/xla/xla/tsl/lib/io/record_reader_writer_test.cc similarity index 99% rename from third_party/xla/third_party/tsl/tsl/lib/io/record_reader_writer_test.cc rename to third_party/xla/xla/tsl/lib/io/record_reader_writer_test.cc index 45934c9f355576..e91f1ecaed1b99 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/record_reader_writer_test.cc +++ b/third_party/xla/xla/tsl/lib/io/record_reader_writer_test.cc @@ -14,8 +14,8 @@ limitations under the License. ==============================================================================*/ // clang-format off -#include "tsl/lib/io/record_reader.h" -#include "tsl/lib/io/record_writer.h" +#include "xla/tsl/lib/io/record_reader.h" +#include "xla/tsl/lib/io/record_writer.h" // clang-format on #include diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/record_writer.cc b/third_party/xla/xla/tsl/lib/io/record_writer.cc similarity index 97% rename from third_party/xla/third_party/tsl/tsl/lib/io/record_writer.cc rename to third_party/xla/xla/tsl/lib/io/record_writer.cc index b6e829206e5f2e..2e47e9d0686eb0 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/record_writer.cc +++ b/third_party/xla/xla/tsl/lib/io/record_writer.cc @@ -13,10 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/lib/io/record_writer.h" +#include "xla/tsl/lib/io/record_writer.h" -#include "tsl/lib/hash/crc32c.h" -#include "tsl/lib/io/compression.h" +#include "xla/tsl/lib/hash/crc32c.h" +#include "xla/tsl/lib/io/compression.h" #include "tsl/platform/coding.h" #include "tsl/platform/env.h" diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/record_writer.h b/third_party/xla/xla/tsl/lib/io/record_writer.h similarity index 92% rename from third_party/xla/third_party/tsl/tsl/lib/io/record_writer.h rename to third_party/xla/xla/tsl/lib/io/record_writer.h index 94c7ca576403df..5cb160790b9f1c 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/record_writer.h +++ b/third_party/xla/xla/tsl/lib/io/record_writer.h @@ -13,18 +13,18 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_LIB_IO_RECORD_WRITER_H_ -#define TENSORFLOW_TSL_LIB_IO_RECORD_WRITER_H_ +#ifndef XLA_TSL_LIB_IO_RECORD_WRITER_H_ +#define XLA_TSL_LIB_IO_RECORD_WRITER_H_ -#include "tsl/lib/hash/crc32c.h" +#include "xla/tsl/lib/hash/crc32c.h" #include "tsl/platform/coding.h" #include "tsl/platform/status.h" #include "tsl/platform/stringpiece.h" #if !defined(IS_SLIM_BUILD) -#include "tsl/lib/io/snappy/snappy_compression_options.h" -#include "tsl/lib/io/snappy/snappy_outputbuffer.h" -#include "tsl/lib/io/zlib_compression_options.h" -#include "tsl/lib/io/zlib_outputbuffer.h" +#include "xla/tsl/lib/io/snappy/snappy_compression_options.h" +#include "xla/tsl/lib/io/snappy/snappy_outputbuffer.h" +#include "xla/tsl/lib/io/zlib_compression_options.h" +#include "xla/tsl/lib/io/zlib_outputbuffer.h" #endif // IS_SLIM_BUILD #include "tsl/platform/cord.h" #include "tsl/platform/macros.h" @@ -153,4 +153,4 @@ void RecordWriter::PopulateFooter(char* footer, const absl::Cord& data) { } // namespace io } // namespace tsl -#endif // TENSORFLOW_TSL_LIB_IO_RECORD_WRITER_H_ +#endif // XLA_TSL_LIB_IO_RECORD_WRITER_H_ diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/recordio_test.cc b/third_party/xla/xla/tsl/lib/io/recordio_test.cc similarity index 98% rename from third_party/xla/third_party/tsl/tsl/lib/io/recordio_test.cc rename to third_party/xla/xla/tsl/lib/io/recordio_test.cc index c07d26a37e698f..02d22ec4931218 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/recordio_test.cc +++ b/third_party/xla/xla/tsl/lib/io/recordio_test.cc @@ -14,10 +14,10 @@ limitations under the License. ==============================================================================*/ #include "xla/tsl/lib/core/status_test_util.h" -#include "tsl/lib/hash/crc32c.h" -#include "tsl/lib/io/record_reader.h" -#include "tsl/lib/io/record_writer.h" -#include "tsl/lib/random/simple_philox.h" +#include "xla/tsl/lib/hash/crc32c.h" +#include "xla/tsl/lib/io/record_reader.h" +#include "xla/tsl/lib/io/record_writer.h" +#include "xla/tsl/lib/random/simple_philox.h" #include "tsl/platform/coding.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/snappy/BUILD b/third_party/xla/xla/tsl/lib/io/snappy/BUILD similarity index 56% rename from third_party/xla/third_party/tsl/tsl/lib/io/snappy/BUILD rename to third_party/xla/xla/tsl/lib/io/snappy/BUILD index 0adc5e2fa467aa..6246244fa740e7 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/snappy/BUILD +++ b/third_party/xla/xla/tsl/lib/io/snappy/BUILD @@ -1,8 +1,8 @@ -load("@local_xla//xla/tsl:tsl.bzl", "internal_visibility") load( - "//tsl/platform:build_config.bzl", + "@local_tsl//tsl/platform:build_config.bzl", "tsl_cc_test", ) +load("//xla/tsl:tsl.bzl", "internal_visibility") # Snappy targets. @@ -15,7 +15,7 @@ package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = internal_visibility([ "//tensorflow/core/lib/io:__pkg__", - "//tsl/lib/io:__pkg__", + "//xla/tsl/lib/io:__pkg__", ]), licenses = ["notice"], ) @@ -34,13 +34,13 @@ cc_library( srcs = ["snappy_inputbuffer.cc"], hdrs = ["snappy_inputbuffer.h"], deps = [ - "//tsl/lib/io:inputstream_interface", - "//tsl/platform:env", - "//tsl/platform:macros", - "//tsl/platform:platform_port", - "//tsl/platform:status", - "//tsl/platform:types", + "//xla/tsl/lib/io:inputstream_interface", "@com_google_absl//absl/status", + "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:macros", + "@local_tsl//tsl/platform:platform_port", + "@local_tsl//tsl/platform:status", + "@local_tsl//tsl/platform:types", ], alwayslink = True, ) @@ -50,12 +50,12 @@ cc_library( srcs = ["snappy_outputbuffer.cc"], hdrs = ["snappy_outputbuffer.h"], deps = [ - "//tsl/platform", - "//tsl/platform:env", - "//tsl/platform:macros", - "//tsl/platform:platform_port", - "//tsl/platform:status", - "//tsl/platform:types", + "@local_tsl//tsl/platform", + "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:macros", + "@local_tsl//tsl/platform:platform_port", + "@local_tsl//tsl/platform:status", + "@local_tsl//tsl/platform:types", ], alwayslink = True, ) @@ -65,10 +65,10 @@ cc_library( srcs = ["snappy_inputstream.cc"], hdrs = ["snappy_inputstream.h"], deps = [ - "//tsl/lib/io:inputstream_interface", - "//tsl/platform:errors", - "//tsl/platform:platform_port", + "//xla/tsl/lib/io:inputstream_interface", "@com_google_absl//absl/memory", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:platform_port", ], alwayslink = True, ) @@ -77,7 +77,7 @@ cc_library( name = "snappy_compression_options", hdrs = ["snappy_compression_options.h"], deps = [ - "//tsl/platform:types", + "@local_tsl//tsl/platform:types", ], alwayslink = True, ) @@ -90,12 +90,12 @@ tsl_cc_test( ":snappy_inputbuffer", ":snappy_inputstream", ":snappy_outputbuffer", - "//tsl/lib/io:inputbuffer", - "//tsl/lib/io:random_inputstream", - "//tsl/platform:env", - "//tsl/platform:env_impl", - "//tsl/platform:test", - "//tsl/platform:test_main", - "@local_xla//xla/tsl/lib/core:status_test_util", + "//xla/tsl/lib/core:status_test_util", + "//xla/tsl/lib/io:inputbuffer", + "//xla/tsl/lib/io:random_inputstream", + "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:env_impl", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", ], ) diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/snappy/snappy_compression_options.h b/third_party/xla/xla/tsl/lib/io/snappy/snappy_compression_options.h similarity index 84% rename from third_party/xla/third_party/tsl/tsl/lib/io/snappy/snappy_compression_options.h rename to third_party/xla/xla/tsl/lib/io/snappy/snappy_compression_options.h index 4d1ba01e3c15d7..3772a415056cf9 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/snappy/snappy_compression_options.h +++ b/third_party/xla/xla/tsl/lib/io/snappy/snappy_compression_options.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_LIB_IO_SNAPPY_SNAPPY_COMPRESSION_OPTIONS_H_ -#define TENSORFLOW_TSL_LIB_IO_SNAPPY_SNAPPY_COMPRESSION_OPTIONS_H_ +#ifndef XLA_TSL_LIB_IO_SNAPPY_SNAPPY_COMPRESSION_OPTIONS_H_ +#define XLA_TSL_LIB_IO_SNAPPY_SNAPPY_COMPRESSION_OPTIONS_H_ #include "tsl/platform/types.h" @@ -33,4 +33,4 @@ struct SnappyCompressionOptions { } // namespace io } // namespace tsl -#endif // TENSORFLOW_TSL_LIB_IO_SNAPPY_SNAPPY_COMPRESSION_OPTIONS_H_ +#endif // XLA_TSL_LIB_IO_SNAPPY_SNAPPY_COMPRESSION_OPTIONS_H_ diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/snappy/snappy_inputbuffer.cc b/third_party/xla/xla/tsl/lib/io/snappy/snappy_inputbuffer.cc similarity index 99% rename from third_party/xla/third_party/tsl/tsl/lib/io/snappy/snappy_inputbuffer.cc rename to third_party/xla/xla/tsl/lib/io/snappy/snappy_inputbuffer.cc index 7844b8993fd98d..09c5e482ef51fa 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/snappy/snappy_inputbuffer.cc +++ b/third_party/xla/xla/tsl/lib/io/snappy/snappy_inputbuffer.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/lib/io/snappy/snappy_inputbuffer.h" +#include "xla/tsl/lib/io/snappy/snappy_inputbuffer.h" #include diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/snappy/snappy_inputbuffer.h b/third_party/xla/xla/tsl/lib/io/snappy/snappy_inputbuffer.h similarity index 95% rename from third_party/xla/third_party/tsl/tsl/lib/io/snappy/snappy_inputbuffer.h rename to third_party/xla/xla/tsl/lib/io/snappy/snappy_inputbuffer.h index 4d7fd3fe2e010d..969c1e00c2bfe3 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/snappy/snappy_inputbuffer.h +++ b/third_party/xla/xla/tsl/lib/io/snappy/snappy_inputbuffer.h @@ -13,13 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_LIB_IO_SNAPPY_SNAPPY_INPUTBUFFER_H_ -#define TENSORFLOW_TSL_LIB_IO_SNAPPY_SNAPPY_INPUTBUFFER_H_ +#ifndef XLA_TSL_LIB_IO_SNAPPY_SNAPPY_INPUTBUFFER_H_ +#define XLA_TSL_LIB_IO_SNAPPY_SNAPPY_INPUTBUFFER_H_ #include #include -#include "tsl/lib/io/inputstream_interface.h" +#include "xla/tsl/lib/io/inputstream_interface.h" #include "tsl/platform/env.h" #include "tsl/platform/macros.h" #include "tsl/platform/snappy.h" @@ -131,4 +131,4 @@ class SnappyInputBuffer : public InputStreamInterface { } // namespace io } // namespace tsl -#endif // TENSORFLOW_TSL_LIB_IO_SNAPPY_SNAPPY_INPUTBUFFER_H_ +#endif // XLA_TSL_LIB_IO_SNAPPY_SNAPPY_INPUTBUFFER_H_ diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/snappy/snappy_inputstream.cc b/third_party/xla/xla/tsl/lib/io/snappy/snappy_inputstream.cc similarity index 98% rename from third_party/xla/third_party/tsl/tsl/lib/io/snappy/snappy_inputstream.cc rename to third_party/xla/xla/tsl/lib/io/snappy/snappy_inputstream.cc index 264f524fcef48f..bcbe96e21139e7 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/snappy/snappy_inputstream.cc +++ b/third_party/xla/xla/tsl/lib/io/snappy/snappy_inputstream.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/lib/io/snappy/snappy_inputstream.h" +#include "xla/tsl/lib/io/snappy/snappy_inputstream.h" #include diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/snappy/snappy_inputstream.h b/third_party/xla/xla/tsl/lib/io/snappy/snappy_inputstream.h similarity index 92% rename from third_party/xla/third_party/tsl/tsl/lib/io/snappy/snappy_inputstream.h rename to third_party/xla/xla/tsl/lib/io/snappy/snappy_inputstream.h index 6240aa53feb7fa..44535fe65d8763 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/snappy/snappy_inputstream.h +++ b/third_party/xla/xla/tsl/lib/io/snappy/snappy_inputstream.h @@ -13,12 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_LIB_IO_SNAPPY_SNAPPY_INPUTSTREAM_H_ -#define TENSORFLOW_TSL_LIB_IO_SNAPPY_SNAPPY_INPUTSTREAM_H_ +#ifndef XLA_TSL_LIB_IO_SNAPPY_SNAPPY_INPUTSTREAM_H_ +#define XLA_TSL_LIB_IO_SNAPPY_SNAPPY_INPUTSTREAM_H_ #include -#include "tsl/lib/io/inputstream_interface.h" +#include "xla/tsl/lib/io/inputstream_interface.h" namespace tsl { namespace io { @@ -89,4 +89,4 @@ class SnappyInputStream : public InputStreamInterface { } // namespace io } // namespace tsl -#endif // TENSORFLOW_TSL_LIB_IO_SNAPPY_SNAPPY_INPUTSTREAM_H_ +#endif // XLA_TSL_LIB_IO_SNAPPY_SNAPPY_INPUTSTREAM_H_ diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/snappy/snappy_outputbuffer.cc b/third_party/xla/xla/tsl/lib/io/snappy/snappy_outputbuffer.cc similarity index 99% rename from third_party/xla/third_party/tsl/tsl/lib/io/snappy/snappy_outputbuffer.cc rename to third_party/xla/xla/tsl/lib/io/snappy/snappy_outputbuffer.cc index e851f58f1b9cda..7241d24c46b155 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/snappy/snappy_outputbuffer.cc +++ b/third_party/xla/xla/tsl/lib/io/snappy/snappy_outputbuffer.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/lib/io/snappy/snappy_outputbuffer.h" +#include "xla/tsl/lib/io/snappy/snappy_outputbuffer.h" #include diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/snappy/snappy_outputbuffer.h b/third_party/xla/xla/tsl/lib/io/snappy/snappy_outputbuffer.h similarity index 96% rename from third_party/xla/third_party/tsl/tsl/lib/io/snappy/snappy_outputbuffer.h rename to third_party/xla/xla/tsl/lib/io/snappy/snappy_outputbuffer.h index a3bd44748c152f..631014c3b6e189 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/snappy/snappy_outputbuffer.h +++ b/third_party/xla/xla/tsl/lib/io/snappy/snappy_outputbuffer.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_LIB_IO_SNAPPY_SNAPPY_OUTPUTBUFFER_H_ -#define TENSORFLOW_TSL_LIB_IO_SNAPPY_SNAPPY_OUTPUTBUFFER_H_ +#ifndef XLA_TSL_LIB_IO_SNAPPY_SNAPPY_OUTPUTBUFFER_H_ +#define XLA_TSL_LIB_IO_SNAPPY_SNAPPY_OUTPUTBUFFER_H_ #include #include @@ -155,4 +155,4 @@ class SnappyOutputBuffer : public WritableFile { } // namespace io } // namespace tsl -#endif // TENSORFLOW_TSL_LIB_IO_SNAPPY_SNAPPY_OUTPUTBUFFER_H_ +#endif // XLA_TSL_LIB_IO_SNAPPY_SNAPPY_OUTPUTBUFFER_H_ diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/snappy/snappy_test.cc b/third_party/xla/xla/tsl/lib/io/snappy/snappy_test.cc similarity index 98% rename from third_party/xla/third_party/tsl/tsl/lib/io/snappy/snappy_test.cc rename to third_party/xla/xla/tsl/lib/io/snappy/snappy_test.cc index 78eecf360d9489..f3504e9268a76e 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/snappy/snappy_test.cc +++ b/third_party/xla/xla/tsl/lib/io/snappy/snappy_test.cc @@ -16,11 +16,11 @@ limitations under the License. #include #include "xla/tsl/lib/core/status_test_util.h" -#include "tsl/lib/io/inputbuffer.h" -#include "tsl/lib/io/random_inputstream.h" -#include "tsl/lib/io/snappy/snappy_inputbuffer.h" -#include "tsl/lib/io/snappy/snappy_inputstream.h" -#include "tsl/lib/io/snappy/snappy_outputbuffer.h" +#include "xla/tsl/lib/io/inputbuffer.h" +#include "xla/tsl/lib/io/random_inputstream.h" +#include "xla/tsl/lib/io/snappy/snappy_inputbuffer.h" +#include "xla/tsl/lib/io/snappy/snappy_inputstream.h" +#include "xla/tsl/lib/io/snappy/snappy_outputbuffer.h" #include "tsl/platform/env.h" #include "tsl/platform/test.h" diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/table.cc b/third_party/xla/xla/tsl/lib/io/table.cc similarity index 96% rename from third_party/xla/third_party/tsl/tsl/lib/io/table.cc rename to third_party/xla/xla/tsl/lib/io/table.cc index 05f8cd1d71e1c5..5c36b4649859b8 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/table.cc +++ b/third_party/xla/xla/tsl/lib/io/table.cc @@ -13,13 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/lib/io/table.h" +#include "xla/tsl/lib/io/table.h" -#include "tsl/lib/io/block.h" -#include "tsl/lib/io/cache.h" -#include "tsl/lib/io/format.h" -#include "tsl/lib/io/table_options.h" -#include "tsl/lib/io/two_level_iterator.h" +#include "xla/tsl/lib/io/block.h" +#include "xla/tsl/lib/io/cache.h" +#include "xla/tsl/lib/io/format.h" +#include "xla/tsl/lib/io/table_options.h" +#include "xla/tsl/lib/io/two_level_iterator.h" #include "tsl/platform/coding.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/table.h b/third_party/xla/xla/tsl/lib/io/table.h similarity index 95% rename from third_party/xla/third_party/tsl/tsl/lib/io/table.h rename to third_party/xla/xla/tsl/lib/io/table.h index 4a6855c661f6b8..3afdb0c461ea10 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/table.h +++ b/third_party/xla/xla/tsl/lib/io/table.h @@ -13,12 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_LIB_IO_TABLE_H_ -#define TENSORFLOW_TSL_LIB_IO_TABLE_H_ +#ifndef XLA_TSL_LIB_IO_TABLE_H_ +#define XLA_TSL_LIB_IO_TABLE_H_ #include -#include "tsl/lib/io/iterator.h" +#include "xla/tsl/lib/io/iterator.h" namespace tsl { @@ -84,4 +84,4 @@ class Table { } // namespace table } // namespace tsl -#endif // TENSORFLOW_TSL_LIB_IO_TABLE_H_ +#endif // XLA_TSL_LIB_IO_TABLE_H_ diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/table_builder.cc b/third_party/xla/xla/tsl/lib/io/table_builder.cc similarity index 97% rename from third_party/xla/third_party/tsl/tsl/lib/io/table_builder.cc rename to third_party/xla/xla/tsl/lib/io/table_builder.cc index c07227b934a2b7..b5fcb0c9ed47dc 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/table_builder.cc +++ b/third_party/xla/xla/tsl/lib/io/table_builder.cc @@ -13,14 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/lib/io/table_builder.h" +#include "xla/tsl/lib/io/table_builder.h" #include -#include "tsl/lib/hash/crc32c.h" -#include "tsl/lib/io/block_builder.h" -#include "tsl/lib/io/format.h" -#include "tsl/lib/io/table_options.h" +#include "xla/tsl/lib/hash/crc32c.h" +#include "xla/tsl/lib/io/block_builder.h" +#include "xla/tsl/lib/io/format.h" +#include "xla/tsl/lib/io/table_options.h" #include "tsl/platform/coding.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/table_builder.h b/third_party/xla/xla/tsl/lib/io/table_builder.h similarity index 94% rename from third_party/xla/third_party/tsl/tsl/lib/io/table_builder.h rename to third_party/xla/xla/tsl/lib/io/table_builder.h index d4e88e989d47bf..059f9ab60546c1 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/table_builder.h +++ b/third_party/xla/xla/tsl/lib/io/table_builder.h @@ -21,12 +21,12 @@ limitations under the License. // non-const method, all threads accessing the same TableBuilder must use // external synchronization. -#ifndef TENSORFLOW_TSL_LIB_IO_TABLE_BUILDER_H_ -#define TENSORFLOW_TSL_LIB_IO_TABLE_BUILDER_H_ +#ifndef XLA_TSL_LIB_IO_TABLE_BUILDER_H_ +#define XLA_TSL_LIB_IO_TABLE_BUILDER_H_ #include -#include "tsl/lib/io/table_options.h" +#include "xla/tsl/lib/io/table_options.h" #include "tsl/platform/status.h" #include "tsl/platform/stringpiece.h" @@ -98,4 +98,4 @@ class TableBuilder { } // namespace table } // namespace tsl -#endif // TENSORFLOW_TSL_LIB_IO_TABLE_BUILDER_H_ +#endif // XLA_TSL_LIB_IO_TABLE_BUILDER_H_ diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/table_format.txt b/third_party/xla/xla/tsl/lib/io/table_format.txt similarity index 100% rename from third_party/xla/third_party/tsl/tsl/lib/io/table_format.txt rename to third_party/xla/xla/tsl/lib/io/table_format.txt diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/table_options.h b/third_party/xla/xla/tsl/lib/io/table_options.h similarity index 95% rename from third_party/xla/third_party/tsl/tsl/lib/io/table_options.h rename to third_party/xla/xla/tsl/lib/io/table_options.h index c3ca3e1b2fe96d..7784d225c371fe 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/table_options.h +++ b/third_party/xla/xla/tsl/lib/io/table_options.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_LIB_IO_TABLE_OPTIONS_H_ -#define TENSORFLOW_TSL_LIB_IO_TABLE_OPTIONS_H_ +#ifndef XLA_TSL_LIB_IO_TABLE_OPTIONS_H_ +#define XLA_TSL_LIB_IO_TABLE_OPTIONS_H_ #include @@ -73,4 +73,4 @@ struct Options { } // namespace table } // namespace tsl -#endif // TENSORFLOW_TSL_LIB_IO_TABLE_OPTIONS_H_ +#endif // XLA_TSL_LIB_IO_TABLE_OPTIONS_H_ diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/table_test.cc b/third_party/xla/xla/tsl/lib/io/table_test.cc similarity index 98% rename from third_party/xla/third_party/tsl/tsl/lib/io/table_test.cc rename to third_party/xla/xla/tsl/lib/io/table_test.cc index 567881d1dce92d..6671bc816abc17 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/table_test.cc +++ b/third_party/xla/xla/tsl/lib/io/table_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/lib/io/table.h" +#include "xla/tsl/lib/io/table.h" #include #include @@ -21,12 +21,12 @@ limitations under the License. #include #include "absl/strings/escaping.h" -#include "tsl/lib/io/block.h" -#include "tsl/lib/io/block_builder.h" -#include "tsl/lib/io/format.h" -#include "tsl/lib/io/iterator.h" -#include "tsl/lib/io/table_builder.h" -#include "tsl/lib/random/simple_philox.h" +#include "xla/tsl/lib/io/block.h" +#include "xla/tsl/lib/io/block_builder.h" +#include "xla/tsl/lib/io/format.h" +#include "xla/tsl/lib/io/iterator.h" +#include "xla/tsl/lib/io/table_builder.h" +#include "xla/tsl/lib/random/simple_philox.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" #include "tsl/platform/snappy.h" diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/two_level_iterator.cc b/third_party/xla/xla/tsl/lib/io/two_level_iterator.cc similarity index 95% rename from third_party/xla/third_party/tsl/tsl/lib/io/two_level_iterator.cc rename to third_party/xla/xla/tsl/lib/io/two_level_iterator.cc index f3ee26b3cc71c8..853ea9c037cf49 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/two_level_iterator.cc +++ b/third_party/xla/xla/tsl/lib/io/two_level_iterator.cc @@ -13,12 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/lib/io/two_level_iterator.h" +#include "xla/tsl/lib/io/two_level_iterator.h" -#include "tsl/lib/io/block.h" -#include "tsl/lib/io/format.h" -#include "tsl/lib/io/iterator.h" -#include "tsl/lib/io/table.h" +#include "xla/tsl/lib/io/block.h" +#include "xla/tsl/lib/io/format.h" +#include "xla/tsl/lib/io/iterator.h" +#include "xla/tsl/lib/io/table.h" namespace tsl { namespace table { diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/two_level_iterator.h b/third_party/xla/xla/tsl/lib/io/two_level_iterator.h similarity index 88% rename from third_party/xla/third_party/tsl/tsl/lib/io/two_level_iterator.h rename to third_party/xla/xla/tsl/lib/io/two_level_iterator.h index 1ae98da5af9695..87f2e1545aa344 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/two_level_iterator.h +++ b/third_party/xla/xla/tsl/lib/io/two_level_iterator.h @@ -13,10 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_LIB_IO_TWO_LEVEL_ITERATOR_H_ -#define TENSORFLOW_TSL_LIB_IO_TWO_LEVEL_ITERATOR_H_ +#ifndef XLA_TSL_LIB_IO_TWO_LEVEL_ITERATOR_H_ +#define XLA_TSL_LIB_IO_TWO_LEVEL_ITERATOR_H_ -#include "tsl/lib/io/iterator.h" +#include "xla/tsl/lib/io/iterator.h" namespace tsl { namespace table { @@ -39,4 +39,4 @@ extern Iterator* NewTwoLevelIterator( } // namespace table } // namespace tsl -#endif // TENSORFLOW_TSL_LIB_IO_TWO_LEVEL_ITERATOR_H_ +#endif // XLA_TSL_LIB_IO_TWO_LEVEL_ITERATOR_H_ diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/zlib_buffers_test.cc b/third_party/xla/xla/tsl/lib/io/zlib_buffers_test.cc similarity index 98% rename from third_party/xla/third_party/tsl/tsl/lib/io/zlib_buffers_test.cc rename to third_party/xla/xla/tsl/lib/io/zlib_buffers_test.cc index 75554fa9bca17e..c66d9229e480c9 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/zlib_buffers_test.cc +++ b/third_party/xla/xla/tsl/lib/io/zlib_buffers_test.cc @@ -14,10 +14,10 @@ limitations under the License. ==============================================================================*/ #include "xla/tsl/lib/core/status_test_util.h" -#include "tsl/lib/io/random_inputstream.h" -#include "tsl/lib/io/zlib_compression_options.h" -#include "tsl/lib/io/zlib_inputstream.h" -#include "tsl/lib/io/zlib_outputbuffer.h" +#include "xla/tsl/lib/io/random_inputstream.h" +#include "xla/tsl/lib/io/zlib_compression_options.h" +#include "xla/tsl/lib/io/zlib_inputstream.h" +#include "xla/tsl/lib/io/zlib_outputbuffer.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" #include "tsl/platform/strcat.h" diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/zlib_compression_options.cc b/third_party/xla/xla/tsl/lib/io/zlib_compression_options.cc similarity index 94% rename from third_party/xla/third_party/tsl/tsl/lib/io/zlib_compression_options.cc rename to third_party/xla/xla/tsl/lib/io/zlib_compression_options.cc index 4f30c5252c9d9a..724eec1478ccbd 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/zlib_compression_options.cc +++ b/third_party/xla/xla/tsl/lib/io/zlib_compression_options.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/lib/io/zlib_compression_options.h" +#include "xla/tsl/lib/io/zlib_compression_options.h" #include diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/zlib_compression_options.h b/third_party/xla/xla/tsl/lib/io/zlib_compression_options.h similarity index 97% rename from third_party/xla/third_party/tsl/tsl/lib/io/zlib_compression_options.h rename to third_party/xla/xla/tsl/lib/io/zlib_compression_options.h index 612f32c507148d..0cae3a2ef54128 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/zlib_compression_options.h +++ b/third_party/xla/xla/tsl/lib/io/zlib_compression_options.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_LIB_IO_ZLIB_COMPRESSION_OPTIONS_H_ -#define TENSORFLOW_TSL_LIB_IO_ZLIB_COMPRESSION_OPTIONS_H_ +#ifndef XLA_TSL_LIB_IO_ZLIB_COMPRESSION_OPTIONS_H_ +#define XLA_TSL_LIB_IO_ZLIB_COMPRESSION_OPTIONS_H_ #include "tsl/platform/types.h" @@ -138,4 +138,4 @@ inline ZlibCompressionOptions ZlibCompressionOptions::GZIP() { } // namespace io } // namespace tsl -#endif // TENSORFLOW_TSL_LIB_IO_ZLIB_COMPRESSION_OPTIONS_H_ +#endif // XLA_TSL_LIB_IO_ZLIB_COMPRESSION_OPTIONS_H_ diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/zlib_inputstream.cc b/third_party/xla/xla/tsl/lib/io/zlib_inputstream.cc similarity index 99% rename from third_party/xla/third_party/tsl/tsl/lib/io/zlib_inputstream.cc rename to third_party/xla/xla/tsl/lib/io/zlib_inputstream.cc index 3407805e62ddff..fda83637279579 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/zlib_inputstream.cc +++ b/third_party/xla/xla/tsl/lib/io/zlib_inputstream.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/lib/io/zlib_inputstream.h" +#include "xla/tsl/lib/io/zlib_inputstream.h" #include diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/zlib_inputstream.h b/third_party/xla/xla/tsl/lib/io/zlib_inputstream.h similarity index 95% rename from third_party/xla/third_party/tsl/tsl/lib/io/zlib_inputstream.h rename to third_party/xla/xla/tsl/lib/io/zlib_inputstream.h index 7a61fda7c9de71..16df9508636019 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/zlib_inputstream.h +++ b/third_party/xla/xla/tsl/lib/io/zlib_inputstream.h @@ -13,13 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_LIB_IO_ZLIB_INPUTSTREAM_H_ -#define TENSORFLOW_TSL_LIB_IO_ZLIB_INPUTSTREAM_H_ +#ifndef XLA_TSL_LIB_IO_ZLIB_INPUTSTREAM_H_ +#define XLA_TSL_LIB_IO_ZLIB_INPUTSTREAM_H_ #include -#include "tsl/lib/io/inputstream_interface.h" -#include "tsl/lib/io/zlib_compression_options.h" +#include "xla/tsl/lib/io/inputstream_interface.h" +#include "xla/tsl/lib/io/zlib_compression_options.h" #include "tsl/platform/env.h" #include "tsl/platform/macros.h" #include "tsl/platform/status.h" @@ -139,4 +139,4 @@ class ZlibInputStream : public InputStreamInterface { } // namespace io } // namespace tsl -#endif // TENSORFLOW_TSL_LIB_IO_ZLIB_INPUTSTREAM_H_ +#endif // XLA_TSL_LIB_IO_ZLIB_INPUTSTREAM_H_ diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/zlib_outputbuffer.cc b/third_party/xla/xla/tsl/lib/io/zlib_outputbuffer.cc similarity index 99% rename from third_party/xla/third_party/tsl/tsl/lib/io/zlib_outputbuffer.cc rename to third_party/xla/xla/tsl/lib/io/zlib_outputbuffer.cc index afcf5a46752074..646e4397898841 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/zlib_outputbuffer.cc +++ b/third_party/xla/xla/tsl/lib/io/zlib_outputbuffer.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/lib/io/zlib_outputbuffer.h" +#include "xla/tsl/lib/io/zlib_outputbuffer.h" #include "tsl/platform/errors.h" diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/zlib_outputbuffer.h b/third_party/xla/xla/tsl/lib/io/zlib_outputbuffer.h similarity index 96% rename from third_party/xla/third_party/tsl/tsl/lib/io/zlib_outputbuffer.h rename to third_party/xla/xla/tsl/lib/io/zlib_outputbuffer.h index a255524ff78c04..96b1d1bb9704da 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/zlib_outputbuffer.h +++ b/third_party/xla/xla/tsl/lib/io/zlib_outputbuffer.h @@ -13,14 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_LIB_IO_ZLIB_OUTPUTBUFFER_H_ -#define TENSORFLOW_TSL_LIB_IO_ZLIB_OUTPUTBUFFER_H_ +#ifndef XLA_TSL_LIB_IO_ZLIB_OUTPUTBUFFER_H_ +#define XLA_TSL_LIB_IO_ZLIB_OUTPUTBUFFER_H_ #include #include -#include "tsl/lib/io/zlib_compression_options.h" +#include "xla/tsl/lib/io/zlib_compression_options.h" #include "tsl/platform/env.h" #include "tsl/platform/file_system.h" #include "tsl/platform/macros.h" @@ -156,4 +156,4 @@ class ZlibOutputBuffer : public WritableFile { } // namespace io } // namespace tsl -#endif // TENSORFLOW_TSL_LIB_IO_ZLIB_OUTPUTBUFFER_H_ +#endif // XLA_TSL_LIB_IO_ZLIB_OUTPUTBUFFER_H_ diff --git a/third_party/xla/third_party/tsl/tsl/lib/math/BUILD b/third_party/xla/xla/tsl/lib/math/BUILD similarity index 74% rename from third_party/xla/third_party/tsl/tsl/lib/math/BUILD rename to third_party/xla/xla/tsl/lib/math/BUILD index a78947f3c38ffa..8bb5fd993079fb 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/math/BUILD +++ b/third_party/xla/xla/tsl/lib/math/BUILD @@ -1,9 +1,9 @@ -load("@local_xla//xla/tsl:tsl.bzl", "internal_visibility") -load("@local_xla//xla/tsl:tsl.default.bzl", "get_compatible_with_portable") load( - "//tsl/platform:build_config.bzl", + "@local_tsl//tsl/platform:build_config.bzl", "tsl_cc_test", ) +load("//xla/tsl:tsl.bzl", "internal_visibility") +load("//xla/tsl:tsl.default.bzl", "get_compatible_with_portable") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -33,11 +33,11 @@ tsl_cc_test( ], deps = [ ":math_util", - "//tsl/platform:logging", - "//tsl/platform:test", - "//tsl/platform:test_benchmark", - "//tsl/platform:test_main", - "//tsl/platform:types", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_benchmark", + "@local_tsl//tsl/platform:test_main", + "@local_tsl//tsl/platform:types", ], ) diff --git a/third_party/xla/third_party/tsl/tsl/lib/math/math_util.h b/third_party/xla/xla/tsl/lib/math/math_util.h similarity index 97% rename from third_party/xla/third_party/tsl/tsl/lib/math/math_util.h rename to third_party/xla/xla/tsl/lib/math/math_util.h index 26dc0093982740..a2622d48976726 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/math/math_util.h +++ b/third_party/xla/xla/tsl/lib/math/math_util.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_LIB_MATH_MATH_UTIL_H_ -#define TENSORFLOW_TSL_LIB_MATH_MATH_UTIL_H_ +#ifndef XLA_TSL_LIB_MATH_MATH_UTIL_H_ +#define XLA_TSL_LIB_MATH_MATH_UTIL_H_ #include @@ -158,4 +158,4 @@ constexpr T MathUtil::IPow(T base, int exp) { } // namespace tsl -#endif // TENSORFLOW_TSL_LIB_MATH_MATH_UTIL_H_ +#endif // XLA_TSL_LIB_MATH_MATH_UTIL_H_ diff --git a/third_party/xla/third_party/tsl/tsl/lib/math/math_util_test.cc b/third_party/xla/xla/tsl/lib/math/math_util_test.cc similarity index 99% rename from third_party/xla/third_party/tsl/tsl/lib/math/math_util_test.cc rename to third_party/xla/xla/tsl/lib/math/math_util_test.cc index ccceabf5cc6da7..c60f9796695ceb 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/math/math_util_test.cc +++ b/third_party/xla/xla/tsl/lib/math/math_util_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/lib/math/math_util.h" +#include "xla/tsl/lib/math/math_util.h" #include #include diff --git a/third_party/xla/third_party/tsl/tsl/lib/random/BUILD b/third_party/xla/xla/tsl/lib/random/BUILD similarity index 73% rename from third_party/xla/third_party/tsl/tsl/lib/random/BUILD rename to third_party/xla/xla/tsl/lib/random/BUILD index c64a1332e76ff8..4fa352c07886eb 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/random/BUILD +++ b/third_party/xla/xla/tsl/lib/random/BUILD @@ -1,13 +1,13 @@ -load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") -load("@local_xla//xla/tsl:tsl.bzl", "internal_visibility") -load("@local_xla//xla/tsl:tsl.default.bzl", "filegroup", "get_compatible_with_portable") load( - "//tsl/platform:build_config.bzl", + "@local_tsl//tsl/platform:build_config.bzl", "tsl_cc_test", ) +load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") +load("//xla/tsl:tsl.bzl", "internal_visibility") +load("//xla/tsl:tsl.default.bzl", "filegroup", "get_compatible_with_portable") default_visibility = [ - "//tsl/lib/io:__pkg__", + "//xla/tsl/lib/io:__pkg__", # tensorflow/core/platform/random aliases this package "//tensorflow/core/lib/random:__pkg__", ] @@ -40,11 +40,11 @@ cc_library( ":exact_uniform_int", ":philox_random", ":random_distributions_utils", - "//tsl/platform:logging", - "//tsl/platform:macros", - "//tsl/platform:types", "@com_google_absl//absl/types:span", "@eigen_archive//:eigen3", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:macros", + "@local_tsl//tsl/platform:types", ], alwayslink = 1, ) @@ -73,8 +73,8 @@ cc_library( hdrs = ["philox_random_test_utils.h"], deps = [ ":philox_random", - "//tsl/platform:logging", - "//tsl/platform:random", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:random", ], ) @@ -84,9 +84,9 @@ cc_library( hdrs = ["weighted_picker.h"], deps = [ ":philox", - "//tsl/platform:logging", - "//tsl/platform:macros", - "//tsl/platform:types", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:macros", + "@local_tsl//tsl/platform:types", ], alwayslink = 1, ) @@ -159,11 +159,11 @@ tsl_cc_test( srcs = ["distribution_sampler_test.cc"], deps = [ ":philox", - "//tsl/platform:macros", - "//tsl/platform:test", - "//tsl/platform:test_benchmark", - "//tsl/platform:test_main", - "//tsl/platform:types", + "@local_tsl//tsl/platform:macros", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_benchmark", + "@local_tsl//tsl/platform:test_main", + "@local_tsl//tsl/platform:types", ], ) @@ -175,10 +175,10 @@ tsl_cc_test( ":philox", ":philox_random", ":philox_random_test_utils", - "//tsl/platform:logging", - "//tsl/platform:random", - "//tsl/platform:test", - "//tsl/platform:test_main", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:random", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", ], ) @@ -190,11 +190,11 @@ tsl_cc_test( ":philox", ":philox_random", ":philox_random_test_utils", - "//tsl/lib/math:math_util", - "//tsl/platform:logging", - "//tsl/platform:random", - "//tsl/platform:test", - "//tsl/platform:test_main", + "//xla/tsl/lib/math:math_util", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:random", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", ], ) @@ -204,10 +204,10 @@ tsl_cc_test( srcs = ["simple_philox_test.cc"], deps = [ ":philox", - "//tsl/platform:logging", - "//tsl/platform:test", - "//tsl/platform:test_main", - "//tsl/platform:types", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", + "@local_tsl//tsl/platform:types", ], ) @@ -218,11 +218,11 @@ tsl_cc_test( deps = [ ":philox", ":weighted_picker", - "//tsl/platform:logging", - "//tsl/platform:macros", - "//tsl/platform:test", - "//tsl/platform:test_benchmark", - "//tsl/platform:test_main", - "//tsl/platform:types", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:macros", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_benchmark", + "@local_tsl//tsl/platform:test_main", + "@local_tsl//tsl/platform:types", ], ) diff --git a/third_party/xla/third_party/tsl/tsl/lib/random/distribution_sampler.cc b/third_party/xla/xla/tsl/lib/random/distribution_sampler.cc similarity index 97% rename from third_party/xla/third_party/tsl/tsl/lib/random/distribution_sampler.cc rename to third_party/xla/xla/tsl/lib/random/distribution_sampler.cc index 9e597ffea0a390..384dd50fc34e74 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/random/distribution_sampler.cc +++ b/third_party/xla/xla/tsl/lib/random/distribution_sampler.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/lib/random/distribution_sampler.h" +#include "xla/tsl/lib/random/distribution_sampler.h" #include #include diff --git a/third_party/xla/third_party/tsl/tsl/lib/random/distribution_sampler.h b/third_party/xla/xla/tsl/lib/random/distribution_sampler.h similarity index 92% rename from third_party/xla/third_party/tsl/tsl/lib/random/distribution_sampler.h rename to third_party/xla/xla/tsl/lib/random/distribution_sampler.h index 877660c8532baa..ababcc6bf23a31 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/random/distribution_sampler.h +++ b/third_party/xla/xla/tsl/lib/random/distribution_sampler.h @@ -28,14 +28,14 @@ limitations under the License. // // The algorithm used is Walker's Aliasing algorithm, described in Knuth, Vol 2. -#ifndef TENSORFLOW_TSL_LIB_RANDOM_DISTRIBUTION_SAMPLER_H_ -#define TENSORFLOW_TSL_LIB_RANDOM_DISTRIBUTION_SAMPLER_H_ +#ifndef XLA_TSL_LIB_RANDOM_DISTRIBUTION_SAMPLER_H_ +#define XLA_TSL_LIB_RANDOM_DISTRIBUTION_SAMPLER_H_ #include #include #include "absl/types/span.h" -#include "tsl/lib/random/simple_philox.h" +#include "xla/tsl/lib/random/simple_philox.h" #include "tsl/platform/logging.h" #include "tsl/platform/macros.h" #include "tsl/platform/types.h" @@ -92,4 +92,4 @@ class DistributionSampler { } // namespace random } // namespace tsl -#endif // TENSORFLOW_TSL_LIB_RANDOM_DISTRIBUTION_SAMPLER_H_ +#endif // XLA_TSL_LIB_RANDOM_DISTRIBUTION_SAMPLER_H_ diff --git a/third_party/xla/third_party/tsl/tsl/lib/random/distribution_sampler_test.cc b/third_party/xla/xla/tsl/lib/random/distribution_sampler_test.cc similarity index 97% rename from third_party/xla/third_party/tsl/tsl/lib/random/distribution_sampler_test.cc rename to third_party/xla/xla/tsl/lib/random/distribution_sampler_test.cc index 142c01a77023eb..16107ec61c26c0 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/random/distribution_sampler_test.cc +++ b/third_party/xla/xla/tsl/lib/random/distribution_sampler_test.cc @@ -13,14 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/lib/random/distribution_sampler.h" +#include "xla/tsl/lib/random/distribution_sampler.h" #include #include #include -#include "tsl/lib/random/simple_philox.h" +#include "xla/tsl/lib/random/simple_philox.h" #include "tsl/platform/macros.h" #include "tsl/platform/test.h" #include "tsl/platform/test_benchmark.h" diff --git a/third_party/xla/third_party/tsl/tsl/lib/random/exact_uniform_int.h b/third_party/xla/xla/tsl/lib/random/exact_uniform_int.h similarity index 91% rename from third_party/xla/third_party/tsl/tsl/lib/random/exact_uniform_int.h rename to third_party/xla/xla/tsl/lib/random/exact_uniform_int.h index 392d1aa2835110..25d05cb69eefcc 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/random/exact_uniform_int.h +++ b/third_party/xla/xla/tsl/lib/random/exact_uniform_int.h @@ -15,8 +15,8 @@ limitations under the License. // Exact uniform integers using rejection sampling -#ifndef TENSORFLOW_TSL_LIB_RANDOM_EXACT_UNIFORM_INT_H_ -#define TENSORFLOW_TSL_LIB_RANDOM_EXACT_UNIFORM_INT_H_ +#ifndef XLA_TSL_LIB_RANDOM_EXACT_UNIFORM_INT_H_ +#define XLA_TSL_LIB_RANDOM_EXACT_UNIFORM_INT_H_ #include @@ -31,7 +31,7 @@ UintType ExactUniformInt(const UintType n, const RandomBits& random) { "random() should return UintType"); if (n == 0) { // Consume a value anyway - // TODO(irving): Assert n != 0, since this case makes no sense. + // TODO(geoffreyi): Assert n != 0, since this case makes no sense. return random() * n; } else if (0 == (n & (n - 1))) { // N is a power of two, so just mask off the lower bits. @@ -80,4 +80,4 @@ UintType ExactUniformInt(const UintType n, const RandomBits& random) { } // namespace random } // namespace tsl -#endif // TENSORFLOW_TSL_LIB_RANDOM_EXACT_UNIFORM_INT_H_ +#endif // XLA_TSL_LIB_RANDOM_EXACT_UNIFORM_INT_H_ diff --git a/third_party/xla/third_party/tsl/tsl/lib/random/philox_random.h b/third_party/xla/xla/tsl/lib/random/philox_random.h similarity index 98% rename from third_party/xla/third_party/tsl/tsl/lib/random/philox_random.h rename to third_party/xla/xla/tsl/lib/random/philox_random.h index 03b54aae3ec48e..f3b57794f737bc 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/random/philox_random.h +++ b/third_party/xla/xla/tsl/lib/random/philox_random.h @@ -17,8 +17,8 @@ limitations under the License. // Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3. // http://www.thesalmons.org/john/random123/papers/random123sc11.pdf -#ifndef TENSORFLOW_TSL_LIB_RANDOM_PHILOX_RANDOM_H_ -#define TENSORFLOW_TSL_LIB_RANDOM_PHILOX_RANDOM_H_ +#ifndef XLA_TSL_LIB_RANDOM_PHILOX_RANDOM_H_ +#define XLA_TSL_LIB_RANDOM_PHILOX_RANDOM_H_ #include @@ -255,4 +255,4 @@ class PhiloxRandom { } // namespace random } // namespace tsl -#endif // TENSORFLOW_TSL_LIB_RANDOM_PHILOX_RANDOM_H_ +#endif // XLA_TSL_LIB_RANDOM_PHILOX_RANDOM_H_ diff --git a/third_party/xla/third_party/tsl/tsl/lib/random/philox_random_test.cc b/third_party/xla/xla/tsl/lib/random/philox_random_test.cc similarity index 93% rename from third_party/xla/third_party/tsl/tsl/lib/random/philox_random_test.cc rename to third_party/xla/xla/tsl/lib/random/philox_random_test.cc index 714c510aa4ddd5..7af1f9485754fd 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/random/philox_random_test.cc +++ b/third_party/xla/xla/tsl/lib/random/philox_random_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/lib/random/philox_random.h" +#include "xla/tsl/lib/random/philox_random.h" #include @@ -22,8 +22,8 @@ limitations under the License. #include #include -#include "tsl/lib/random/philox_random_test_utils.h" -#include "tsl/lib/random/random_distributions.h" +#include "xla/tsl/lib/random/philox_random_test_utils.h" +#include "xla/tsl/lib/random/random_distributions.h" #include "tsl/platform/logging.h" #include "tsl/platform/random.h" #include "tsl/platform/test.h" diff --git a/third_party/xla/third_party/tsl/tsl/lib/random/philox_random_test_utils.h b/third_party/xla/xla/tsl/lib/random/philox_random_test_utils.h similarity index 86% rename from third_party/xla/third_party/tsl/tsl/lib/random/philox_random_test_utils.h rename to third_party/xla/xla/tsl/lib/random/philox_random_test_utils.h index 4e217d6362cce1..6bbb1c89596b80 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/random/philox_random_test_utils.h +++ b/third_party/xla/xla/tsl/lib/random/philox_random_test_utils.h @@ -13,12 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_LIB_RANDOM_PHILOX_RANDOM_TEST_UTILS_H_ -#define TENSORFLOW_TSL_LIB_RANDOM_PHILOX_RANDOM_TEST_UTILS_H_ +#ifndef XLA_TSL_LIB_RANDOM_PHILOX_RANDOM_TEST_UTILS_H_ +#define XLA_TSL_LIB_RANDOM_PHILOX_RANDOM_TEST_UTILS_H_ #include -#include "tsl/lib/random/philox_random.h" +#include "xla/tsl/lib/random/philox_random.h" #include "tsl/platform/logging.h" #include "tsl/platform/random.h" @@ -48,4 +48,4 @@ void FillRandoms(PhiloxRandom gen, typename Distribution::ResultElementType* p, } // namespace random } // namespace tsl -#endif // TENSORFLOW_TSL_LIB_RANDOM_PHILOX_RANDOM_TEST_UTILS_H_ +#endif // XLA_TSL_LIB_RANDOM_PHILOX_RANDOM_TEST_UTILS_H_ diff --git a/third_party/xla/third_party/tsl/tsl/lib/random/random_distributions.cc b/third_party/xla/xla/tsl/lib/random/random_distributions.cc similarity index 90% rename from third_party/xla/third_party/tsl/tsl/lib/random/random_distributions.cc rename to third_party/xla/xla/tsl/lib/random/random_distributions.cc index 12a806f80acbe5..ab8930008f8c8b 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/random/random_distributions.cc +++ b/third_party/xla/xla/tsl/lib/random/random_distributions.cc @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/lib/random/distribution_sampler.h" -#include "tsl/lib/random/philox_random.h" +#include "xla/tsl/lib/random/distribution_sampler.h" +#include "xla/tsl/lib/random/philox_random.h" namespace tsl { namespace random { diff --git a/third_party/xla/third_party/tsl/tsl/lib/random/random_distributions.h b/third_party/xla/xla/tsl/lib/random/random_distributions.h similarity index 98% rename from third_party/xla/third_party/tsl/tsl/lib/random/random_distributions.h rename to third_party/xla/xla/tsl/lib/random/random_distributions.h index 70a78cf86082ab..ce231f9f652c27 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/random/random_distributions.h +++ b/third_party/xla/xla/tsl/lib/random/random_distributions.h @@ -13,16 +13,16 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_LIB_RANDOM_RANDOM_DISTRIBUTIONS_H_ -#define TENSORFLOW_TSL_LIB_RANDOM_RANDOM_DISTRIBUTIONS_H_ +#ifndef XLA_TSL_LIB_RANDOM_RANDOM_DISTRIBUTIONS_H_ +#define XLA_TSL_LIB_RANDOM_RANDOM_DISTRIBUTIONS_H_ #include #include #include -#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive -#include "tsl/lib/random/philox_random.h" -#include "tsl/lib/random/random_distributions_utils.h" +#include "unsupported/Eigen/CXX11/Tensor" +#include "xla/tsl/lib/random/philox_random.h" +#include "xla/tsl/lib/random/random_distributions_utils.h" #include "tsl/platform/types.h" namespace tsl { @@ -753,4 +753,4 @@ PHILOX_DEVICE_INLINE bfloat16 Uint16ToGfloat16(uint16 x) { } // namespace random } // namespace tsl -#endif // TENSORFLOW_TSL_LIB_RANDOM_RANDOM_DISTRIBUTIONS_H_ +#endif // XLA_TSL_LIB_RANDOM_RANDOM_DISTRIBUTIONS_H_ diff --git a/third_party/xla/third_party/tsl/tsl/lib/random/random_distributions_test.cc b/third_party/xla/xla/tsl/lib/random/random_distributions_test.cc similarity index 98% rename from third_party/xla/third_party/tsl/tsl/lib/random/random_distributions_test.cc rename to third_party/xla/xla/tsl/lib/random/random_distributions_test.cc index ccb595aa0dae4f..b1dab4cd81d6d8 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/random/random_distributions_test.cc +++ b/third_party/xla/xla/tsl/lib/random/random_distributions_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/lib/random/random_distributions.h" +#include "xla/tsl/lib/random/random_distributions.h" #include #include @@ -22,9 +22,9 @@ limitations under the License. #include #include -#include "tsl/lib/math/math_util.h" -#include "tsl/lib/random/philox_random.h" -#include "tsl/lib/random/philox_random_test_utils.h" +#include "xla/tsl/lib/math/math_util.h" +#include "xla/tsl/lib/random/philox_random.h" +#include "xla/tsl/lib/random/philox_random_test_utils.h" #include "tsl/platform/logging.h" #include "tsl/platform/random.h" #include "tsl/platform/test.h" diff --git a/third_party/xla/third_party/tsl/tsl/lib/random/random_distributions_utils.h b/third_party/xla/xla/tsl/lib/random/random_distributions_utils.h similarity index 93% rename from third_party/xla/third_party/tsl/tsl/lib/random/random_distributions_utils.h rename to third_party/xla/xla/tsl/lib/random/random_distributions_utils.h index 38f2d792f58a2d..8da345b83e5c97 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/random/random_distributions_utils.h +++ b/third_party/xla/xla/tsl/lib/random/random_distributions_utils.h @@ -13,14 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_LIB_RANDOM_RANDOM_DISTRIBUTIONS_UTILS_H_ -#define TENSORFLOW_TSL_LIB_RANDOM_RANDOM_DISTRIBUTIONS_UTILS_H_ +#ifndef XLA_TSL_LIB_RANDOM_RANDOM_DISTRIBUTIONS_UTILS_H_ +#define XLA_TSL_LIB_RANDOM_RANDOM_DISTRIBUTIONS_UTILS_H_ #include #include -#include "tsl/lib/random/philox_random.h" +#include "xla/tsl/lib/random/philox_random.h" #ifndef M_PI #define M_PI (3.14159265358979323846) @@ -94,4 +94,4 @@ void BoxMullerFloat(uint32_t x0, uint32_t x1, float* f0, float* f1) { } // namespace random } // namespace tsl -#endif // TENSORFLOW_TSL_LIB_RANDOM_RANDOM_DISTRIBUTIONS_UTILS_H_ +#endif // XLA_TSL_LIB_RANDOM_RANDOM_DISTRIBUTIONS_UTILS_H_ diff --git a/third_party/xla/third_party/tsl/tsl/lib/random/simple_philox.cc b/third_party/xla/xla/tsl/lib/random/simple_philox.cc similarity index 92% rename from third_party/xla/third_party/tsl/tsl/lib/random/simple_philox.cc rename to third_party/xla/xla/tsl/lib/random/simple_philox.cc index 1ae957ed0a9c29..f2c2bbe5820863 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/random/simple_philox.cc +++ b/third_party/xla/xla/tsl/lib/random/simple_philox.cc @@ -13,9 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/lib/random/simple_philox.h" +#include "xla/tsl/lib/random/simple_philox.h" -#include "tsl/lib/random/exact_uniform_int.h" +#include "xla/tsl/lib/random/exact_uniform_int.h" #include "tsl/platform/logging.h" namespace tsl { diff --git a/third_party/xla/third_party/tsl/tsl/lib/random/simple_philox.h b/third_party/xla/xla/tsl/lib/random/simple_philox.h similarity index 89% rename from third_party/xla/third_party/tsl/tsl/lib/random/simple_philox.h rename to third_party/xla/xla/tsl/lib/random/simple_philox.h index 631656519478d3..736bec4d84d238 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/random/simple_philox.h +++ b/third_party/xla/xla/tsl/lib/random/simple_philox.h @@ -13,16 +13,16 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_LIB_RANDOM_SIMPLE_PHILOX_H_ -#define TENSORFLOW_TSL_LIB_RANDOM_SIMPLE_PHILOX_H_ +#ifndef XLA_TSL_LIB_RANDOM_SIMPLE_PHILOX_H_ +#define XLA_TSL_LIB_RANDOM_SIMPLE_PHILOX_H_ #include #include #include -#include "tsl/lib/random/philox_random.h" -#include "tsl/lib/random/random_distributions.h" +#include "xla/tsl/lib/random/philox_random.h" +#include "xla/tsl/lib/random/random_distributions.h" namespace tsl { namespace random { @@ -74,4 +74,4 @@ class SimplePhilox { } // namespace random } // namespace tsl -#endif // TENSORFLOW_TSL_LIB_RANDOM_SIMPLE_PHILOX_H_ +#endif // XLA_TSL_LIB_RANDOM_SIMPLE_PHILOX_H_ diff --git a/third_party/xla/third_party/tsl/tsl/lib/random/simple_philox_test.cc b/third_party/xla/xla/tsl/lib/random/simple_philox_test.cc similarity index 98% rename from third_party/xla/third_party/tsl/tsl/lib/random/simple_philox_test.cc rename to third_party/xla/xla/tsl/lib/random/simple_philox_test.cc index 657d4cf64758ca..3eded84eb0ee33 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/random/simple_philox_test.cc +++ b/third_party/xla/xla/tsl/lib/random/simple_philox_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/lib/random/simple_philox.h" +#include "xla/tsl/lib/random/simple_philox.h" #include #include diff --git a/third_party/xla/third_party/tsl/tsl/lib/random/weighted_picker.cc b/third_party/xla/xla/tsl/lib/random/weighted_picker.cc similarity index 98% rename from third_party/xla/third_party/tsl/tsl/lib/random/weighted_picker.cc rename to third_party/xla/xla/tsl/lib/random/weighted_picker.cc index 06e0df581fbd46..911f0f4d300616 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/random/weighted_picker.cc +++ b/third_party/xla/xla/tsl/lib/random/weighted_picker.cc @@ -13,13 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/lib/random/weighted_picker.h" +#include "xla/tsl/lib/random/weighted_picker.h" #include #include -#include "tsl/lib/random/simple_philox.h" +#include "xla/tsl/lib/random/simple_philox.h" namespace tsl { namespace random { diff --git a/third_party/xla/third_party/tsl/tsl/lib/random/weighted_picker.h b/third_party/xla/xla/tsl/lib/random/weighted_picker.h similarity index 96% rename from third_party/xla/third_party/tsl/tsl/lib/random/weighted_picker.h rename to third_party/xla/xla/tsl/lib/random/weighted_picker.h index 05fabea852b5f7..27903077df2a73 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/random/weighted_picker.h +++ b/third_party/xla/xla/tsl/lib/random/weighted_picker.h @@ -24,8 +24,8 @@ limitations under the License. // Alternative: distribution-sampler.h allows O(1) time picking, but no weight // adjustment after construction. -#ifndef TENSORFLOW_TSL_LIB_RANDOM_WEIGHTED_PICKER_H_ -#define TENSORFLOW_TSL_LIB_RANDOM_WEIGHTED_PICKER_H_ +#ifndef XLA_TSL_LIB_RANDOM_WEIGHTED_PICKER_H_ +#define XLA_TSL_LIB_RANDOM_WEIGHTED_PICKER_H_ #include @@ -131,4 +131,4 @@ inline int WeightedPicker::num_elements() const { return N_; } } // namespace random } // namespace tsl -#endif // TENSORFLOW_TSL_LIB_RANDOM_WEIGHTED_PICKER_H_ +#endif // XLA_TSL_LIB_RANDOM_WEIGHTED_PICKER_H_ diff --git a/third_party/xla/third_party/tsl/tsl/lib/random/weighted_picker_test.cc b/third_party/xla/xla/tsl/lib/random/weighted_picker_test.cc similarity index 98% rename from third_party/xla/third_party/tsl/tsl/lib/random/weighted_picker_test.cc rename to third_party/xla/xla/tsl/lib/random/weighted_picker_test.cc index a81b2ad99d5e6b..64e40c05c432a8 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/random/weighted_picker_test.cc +++ b/third_party/xla/xla/tsl/lib/random/weighted_picker_test.cc @@ -13,13 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/lib/random/weighted_picker.h" +#include "xla/tsl/lib/random/weighted_picker.h" #include #include -#include "tsl/lib/random/simple_philox.h" +#include "xla/tsl/lib/random/simple_philox.h" #include "tsl/platform/logging.h" #include "tsl/platform/macros.h" #include "tsl/platform/test.h" diff --git a/third_party/xla/xla/tsl/lib/strings/BUILD b/third_party/xla/xla/tsl/lib/strings/BUILD index 03f82a366f78c6..1e36a8d78f0b5e 100644 --- a/third_party/xla/xla/tsl/lib/strings/BUILD +++ b/third_party/xla/xla/tsl/lib/strings/BUILD @@ -13,9 +13,9 @@ cc_library( hdrs = ["proto_serialization.h"], visibility = ["//visibility:public"], deps = [ + "//xla/tsl/lib/gtl:inlined_vector", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", - "@local_tsl//tsl/lib/gtl:inlined_vector", "@local_tsl//tsl/platform:hash", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:macros", diff --git a/third_party/xla/xla/tsl/lib/strings/proto_serialization.cc b/third_party/xla/xla/tsl/lib/strings/proto_serialization.cc index 06ef0747ee553d..fef78bd1835a00 100644 --- a/third_party/xla/xla/tsl/lib/strings/proto_serialization.cc +++ b/third_party/xla/xla/tsl/lib/strings/proto_serialization.cc @@ -19,7 +19,7 @@ limitations under the License. #include "absl/memory/memory.h" #include "absl/strings/string_view.h" -#include "tsl/lib/gtl/inlined_vector.h" +#include "xla/tsl/lib/gtl/inlined_vector.h" #include "tsl/platform/hash.h" #include "tsl/platform/logging.h" #include "tsl/platform/macros.h" diff --git a/third_party/xla/xla/tsl/profiler/rpc/client/BUILD b/third_party/xla/xla/tsl/profiler/rpc/client/BUILD index 8e02d961b59f47..9a891967712b0e 100644 --- a/third_party/xla/xla/tsl/profiler/rpc/client/BUILD +++ b/third_party/xla/xla/tsl/profiler/rpc/client/BUILD @@ -64,10 +64,10 @@ cc_library( "//learning/pathways/data_parallel:__pkg__", ]), deps = [ + "//xla/tsl/lib/io:zlib_compression_options", + "//xla/tsl/lib/io:zlib_outputbuffer", "@com_google_absl//absl/strings", "@com_google_absl//absl/time", - "@local_tsl//tsl/lib/io:zlib_compression_options", - "@local_tsl//tsl/lib/io:zlib_outputbuffer", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", diff --git a/third_party/xla/xla/tsl/profiler/rpc/client/save_profile.cc b/third_party/xla/xla/tsl/profiler/rpc/client/save_profile.cc index cf98c56d1944e4..acdceeaa6e0c10 100644 --- a/third_party/xla/xla/tsl/profiler/rpc/client/save_profile.cc +++ b/third_party/xla/xla/tsl/profiler/rpc/client/save_profile.cc @@ -28,8 +28,8 @@ limitations under the License. #include "absl/strings/strip.h" #include "absl/time/clock.h" #include "absl/time/time.h" -#include "tsl/lib/io/zlib_compression_options.h" -#include "tsl/lib/io/zlib_outputbuffer.h" +#include "xla/tsl/lib/io/zlib_compression_options.h" +#include "xla/tsl/lib/io/zlib_outputbuffer.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" #include "tsl/platform/file_system.h" diff --git a/third_party/xla/xla/util.h b/third_party/xla/xla/util.h index 325080f0f08201..a6e74601a809f7 100644 --- a/third_party/xla/xla/util.h +++ b/third_party/xla/xla/util.h @@ -47,9 +47,9 @@ limitations under the License. #include "absl/types/span.h" #include "Eigen/Core" #include "xla/status_macros.h" +#include "xla/tsl/lib/math/math_util.h" #include "xla/types.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/math/math_util.h" #include "tsl/platform/bfloat16.h" #include "tsl/platform/casts.h" #include "tsl/platform/errors.h" // IWYU pragma: keep From c1ee8548f5c25fa50768e3e18eccba763394595b Mon Sep 17 00:00:00 2001 From: pizzud Date: Fri, 20 Sep 2024 12:27:20 -0700 Subject: [PATCH 076/483] [NFC] Shard some large tests so they pass in Bazel's fastbuild mode. Runtime is improved 8-40x on internal test infrastructure. PiperOrigin-RevId: 676932226 --- third_party/xla/xla/service/gpu/transforms/BUILD | 1 + third_party/xla/xla/service/spmd/BUILD | 1 + 2 files changed, 2 insertions(+) diff --git a/third_party/xla/xla/service/gpu/transforms/BUILD b/third_party/xla/xla/service/gpu/transforms/BUILD index 6ee8cf82fda2e6..d59bd94e726f52 100644 --- a/third_party/xla/xla/service/gpu/transforms/BUILD +++ b/third_party/xla/xla/service/gpu/transforms/BUILD @@ -1707,6 +1707,7 @@ xla_test( name = "gemm_rewriter_test", srcs = ["gemm_rewriter_test.cc"], backends = ["gpu"], + shard_count = 5, deps = [ ":gemm_rewriter", "//xla:error_spec", diff --git a/third_party/xla/xla/service/spmd/BUILD b/third_party/xla/xla/service/spmd/BUILD index 2d3a5d993d0c80..b444e594a89f6a 100644 --- a/third_party/xla/xla/service/spmd/BUILD +++ b/third_party/xla/xla/service/spmd/BUILD @@ -93,6 +93,7 @@ cc_library( xla_cc_test( name = "spmd_partitioner_test", srcs = ["spmd_partitioner_test.cc"], + shard_count = 10, deps = [ ":spmd_partitioner", ":spmd_prepare", From 0546d214f928bd6986f6c8d01b966691d1bf500a Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 20 Sep 2024 12:50:03 -0700 Subject: [PATCH 077/483] Fix wheel_test configurations. PiperOrigin-RevId: 676939344 --- .bazelrc | 10 +++++----- third_party/xla/.bazelrc | 10 +++++----- third_party/xla/third_party/tsl/.bazelrc | 10 +++++----- 3 files changed, 15 insertions(+), 15 deletions(-) diff --git a/.bazelrc b/.bazelrc index f5b06fd21d937e..ec9e9cc552831b 100644 --- a/.bazelrc +++ b/.bazelrc @@ -740,27 +740,27 @@ build:linux_libtensorflow_build --config=cuda_wheel -- //tensorflow/tools/lib_pa test:linux_cpu_wheel_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_cpu_wheel_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_cpu_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium -test:linux_cpu_wheel_test --@local_tsl//third_party/py:wheel_dependency=true --config=linux_cpu_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... +test:linux_cpu_wheel_test --@local_xla//xla/tsl:wheel_dependency=true --config=linux_cpu_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... # CUDA WHEEL test:linux_cuda_wheel_test_filters --test_tag_filters=gpu,requires-gpu,-no_gpu,-no_oss,-oss_excluded,-oss_serial,-benchmark-test,-no_cuda11,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_cuda_wheel_test_filters --build_tag_filters=gpu,requires-gpu,-no_gpu,-no_oss,-oss_excluded,-oss_serial,-benchmark-test,-no_cuda11,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_cuda_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium -test:linux_cuda_wheel_test --@local_tsl//third_party/py:wheel_dependency=true --config=linux_cuda_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... +test:linux_cuda_wheel_test --@local_xla//xla/tsl:wheel_dependency=true --config=linux_cuda_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... # ARM64 WHEEL test:linux_arm64_wheel_test_filters --test_tag_filters=-no_oss,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_arm64_wheel_test_filters --build_tag_filters=-no_oss,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_arm64_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium -test:linux_arm64_wheel_test --@local_tsl//third_party/py:wheel_dependency=true --config=linux_arm64_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/core/grappler/optimizers:auto_mixed_precision_test_cpu -//tensorflow/core/grappler/optimizers:remapper_test_cpu -//tensorflow/core/kernels/image:resize_bicubic_op_test +test:linux_arm64_wheel_test --@local_xla//xla/tsl:wheel_dependency=true --config=linux_arm64_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/core/grappler/optimizers:auto_mixed_precision_test_cpu -//tensorflow/core/grappler/optimizers:remapper_test_cpu -//tensorflow/core/kernels/image:resize_bicubic_op_test # MACOS ARM64 WHEEL test:macos_arm64_wheel_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 test:macos_arm64_wheel_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 test:macos_arm64_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium -test:macos_arm64_wheel_test --@local_tsl//third_party/py:wheel_dependency=true --config=macos_arm64_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/compiler/aot/... +test:macos_arm64_wheel_test --@local_xla//xla/tsl:wheel_dependency=true --config=macos_arm64_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/compiler/aot/... # MACOS X86 WHEEL test:macos_x86_wheel_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test test:macos_x86_wheel_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test test:macos_x86_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium -test:macos_x86_wheel_test --@local_tsl//third_party/py:wheel_dependency=true --config=macos_x86_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/compiler/aot/... +test:macos_x86_wheel_test --@local_xla//xla/tsl:wheel_dependency=true --config=macos_x86_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/compiler/aot/... # PYCPP TESTS run a suite of Python and C++ tests to verify general correctness over # the whole TF code base. These are usually run continuously or upon presubmit. diff --git a/third_party/xla/.bazelrc b/third_party/xla/.bazelrc index f5b06fd21d937e..ec9e9cc552831b 100644 --- a/third_party/xla/.bazelrc +++ b/third_party/xla/.bazelrc @@ -740,27 +740,27 @@ build:linux_libtensorflow_build --config=cuda_wheel -- //tensorflow/tools/lib_pa test:linux_cpu_wheel_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_cpu_wheel_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_cpu_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium -test:linux_cpu_wheel_test --@local_tsl//third_party/py:wheel_dependency=true --config=linux_cpu_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... +test:linux_cpu_wheel_test --@local_xla//xla/tsl:wheel_dependency=true --config=linux_cpu_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... # CUDA WHEEL test:linux_cuda_wheel_test_filters --test_tag_filters=gpu,requires-gpu,-no_gpu,-no_oss,-oss_excluded,-oss_serial,-benchmark-test,-no_cuda11,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_cuda_wheel_test_filters --build_tag_filters=gpu,requires-gpu,-no_gpu,-no_oss,-oss_excluded,-oss_serial,-benchmark-test,-no_cuda11,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_cuda_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium -test:linux_cuda_wheel_test --@local_tsl//third_party/py:wheel_dependency=true --config=linux_cuda_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... +test:linux_cuda_wheel_test --@local_xla//xla/tsl:wheel_dependency=true --config=linux_cuda_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... # ARM64 WHEEL test:linux_arm64_wheel_test_filters --test_tag_filters=-no_oss,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_arm64_wheel_test_filters --build_tag_filters=-no_oss,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_arm64_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium -test:linux_arm64_wheel_test --@local_tsl//third_party/py:wheel_dependency=true --config=linux_arm64_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/core/grappler/optimizers:auto_mixed_precision_test_cpu -//tensorflow/core/grappler/optimizers:remapper_test_cpu -//tensorflow/core/kernels/image:resize_bicubic_op_test +test:linux_arm64_wheel_test --@local_xla//xla/tsl:wheel_dependency=true --config=linux_arm64_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/core/grappler/optimizers:auto_mixed_precision_test_cpu -//tensorflow/core/grappler/optimizers:remapper_test_cpu -//tensorflow/core/kernels/image:resize_bicubic_op_test # MACOS ARM64 WHEEL test:macos_arm64_wheel_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 test:macos_arm64_wheel_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 test:macos_arm64_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium -test:macos_arm64_wheel_test --@local_tsl//third_party/py:wheel_dependency=true --config=macos_arm64_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/compiler/aot/... +test:macos_arm64_wheel_test --@local_xla//xla/tsl:wheel_dependency=true --config=macos_arm64_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/compiler/aot/... # MACOS X86 WHEEL test:macos_x86_wheel_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test test:macos_x86_wheel_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test test:macos_x86_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium -test:macos_x86_wheel_test --@local_tsl//third_party/py:wheel_dependency=true --config=macos_x86_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/compiler/aot/... +test:macos_x86_wheel_test --@local_xla//xla/tsl:wheel_dependency=true --config=macos_x86_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/compiler/aot/... # PYCPP TESTS run a suite of Python and C++ tests to verify general correctness over # the whole TF code base. These are usually run continuously or upon presubmit. diff --git a/third_party/xla/third_party/tsl/.bazelrc b/third_party/xla/third_party/tsl/.bazelrc index f5b06fd21d937e..ec9e9cc552831b 100644 --- a/third_party/xla/third_party/tsl/.bazelrc +++ b/third_party/xla/third_party/tsl/.bazelrc @@ -740,27 +740,27 @@ build:linux_libtensorflow_build --config=cuda_wheel -- //tensorflow/tools/lib_pa test:linux_cpu_wheel_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_cpu_wheel_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_cpu_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium -test:linux_cpu_wheel_test --@local_tsl//third_party/py:wheel_dependency=true --config=linux_cpu_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... +test:linux_cpu_wheel_test --@local_xla//xla/tsl:wheel_dependency=true --config=linux_cpu_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... # CUDA WHEEL test:linux_cuda_wheel_test_filters --test_tag_filters=gpu,requires-gpu,-no_gpu,-no_oss,-oss_excluded,-oss_serial,-benchmark-test,-no_cuda11,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_cuda_wheel_test_filters --build_tag_filters=gpu,requires-gpu,-no_gpu,-no_oss,-oss_excluded,-oss_serial,-benchmark-test,-no_cuda11,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_cuda_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium -test:linux_cuda_wheel_test --@local_tsl//third_party/py:wheel_dependency=true --config=linux_cuda_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... +test:linux_cuda_wheel_test --@local_xla//xla/tsl:wheel_dependency=true --config=linux_cuda_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... # ARM64 WHEEL test:linux_arm64_wheel_test_filters --test_tag_filters=-no_oss,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_arm64_wheel_test_filters --build_tag_filters=-no_oss,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_arm64_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium -test:linux_arm64_wheel_test --@local_tsl//third_party/py:wheel_dependency=true --config=linux_arm64_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/core/grappler/optimizers:auto_mixed_precision_test_cpu -//tensorflow/core/grappler/optimizers:remapper_test_cpu -//tensorflow/core/kernels/image:resize_bicubic_op_test +test:linux_arm64_wheel_test --@local_xla//xla/tsl:wheel_dependency=true --config=linux_arm64_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/core/grappler/optimizers:auto_mixed_precision_test_cpu -//tensorflow/core/grappler/optimizers:remapper_test_cpu -//tensorflow/core/kernels/image:resize_bicubic_op_test # MACOS ARM64 WHEEL test:macos_arm64_wheel_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 test:macos_arm64_wheel_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 test:macos_arm64_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium -test:macos_arm64_wheel_test --@local_tsl//third_party/py:wheel_dependency=true --config=macos_arm64_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/compiler/aot/... +test:macos_arm64_wheel_test --@local_xla//xla/tsl:wheel_dependency=true --config=macos_arm64_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/compiler/aot/... # MACOS X86 WHEEL test:macos_x86_wheel_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test test:macos_x86_wheel_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test test:macos_x86_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium -test:macos_x86_wheel_test --@local_tsl//third_party/py:wheel_dependency=true --config=macos_x86_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/compiler/aot/... +test:macos_x86_wheel_test --@local_xla//xla/tsl:wheel_dependency=true --config=macos_x86_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/compiler/aot/... # PYCPP TESTS run a suite of Python and C++ tests to verify general correctness over # the whole TF code base. These are usually run continuously or upon presubmit. From 168aa9f750ba8a9b296aa3fc60f543a16a14fb9f Mon Sep 17 00:00:00 2001 From: Kanglan Tang Date: Fri, 20 Sep 2024 13:46:27 -0700 Subject: [PATCH 078/483] Fix test failures with NumPy 2.1 upgrade Create NumPy 1.x and 2.x compatible versions of np.where and np.reshape np.where: When only condition is provided, np.where(condition) is a shorthand for np.asarray(condition).nonzero(). NumPy 2.1.0rc0 disallows 0D input arrays in nonzero, so np.atleast_1d is used here to remain compatible with NumPy 1.x. See https://github.com/numpy/numpy/pull/26268. np.reshape: NumPy 2.1.0rc1 added shape and copy arguments to numpy.reshape. Both newshape and shape keywords are supported (use shape as newshape will be deprecated). Besides, shape cannot be None now. To remain behavior with NumPy 1.x, we now use asarray to create an ndarray. See https://github.com/numpy/numpy/pull/26292. PiperOrigin-RevId: 676957264 --- tensorflow/python/framework/BUILD | 1 + tensorflow/python/framework/test_util.py | 23 +++---- .../python/kernel_tests/array_ops/BUILD | 1 + .../kernel_tests/array_ops/init_ops_test.py | 3 +- tensorflow/python/ops/numpy_ops/tests/BUILD | 1 + .../python/ops/numpy_ops/tests/np_test.py | 5 +- tensorflow/python/ops/ragged/BUILD | 2 + .../python/ops/ragged/ragged_factory_ops.py | 5 +- tensorflow/python/util/numpy_compat.py | 61 +++++++++++++++++++ 9 files changed, 86 insertions(+), 16 deletions(-) diff --git a/tensorflow/python/framework/BUILD b/tensorflow/python/framework/BUILD index 0caa7f4788e572..f08a3e1347ec99 100644 --- a/tensorflow/python/framework/BUILD +++ b/tensorflow/python/framework/BUILD @@ -2153,6 +2153,7 @@ pytype_strict_library( "//tensorflow/python/util:compat", "//tensorflow/python/util:deprecation", "//tensorflow/python/util:nest", + "//tensorflow/python/util:numpy_compat", "//tensorflow/python/util:tf_decorator_py", "//tensorflow/python/util:tf_export", "//tensorflow/python/util:tf_inspect", diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py index c2b769a99d6f8e..46f981df64b6c6 100644 --- a/tensorflow/python/framework/test_util.py +++ b/tensorflow/python/framework/test_util.py @@ -100,6 +100,7 @@ from tensorflow.python.util import tf_inspect from tensorflow.python.util import traceback_utils from tensorflow.python.util.compat import collections_abc +from tensorflow.python.util.numpy_compat import np_where from tensorflow.python.util.protobuf import compare from tensorflow.python.util.tf_export import tf_export @@ -3248,11 +3249,11 @@ def _assertArrayLikeAllClose(self, a, b, rtol=1e-6, atol=1e-6, msg=None): np.abs(a - b) > atol + rtol * np.abs(b), np.isnan(a) != np.isnan(b)) if a.ndim: - x = a[np.where(cond)] - y = b[np.where(cond)] - msgs.append("not close where = {}".format(np.where(cond))) + x = a[np_where(cond)] + y = b[np_where(cond)] + msgs.append("not close where = {}".format(np_where(cond))) else: - # np.where is broken for scalars + # np_where is broken for scalars x, y = a, b msgs.append("not close lhs = {}".format(x)) msgs.append("not close rhs = {}".format(y)) @@ -3479,11 +3480,11 @@ def assertAllEqual(self, a, b, msg=None): # Adds more details to np.testing.assert_array_equal. diff = np.logical_not(same) if a.ndim: - x = a[np.where(diff)] - y = b[np.where(diff)] - msgs.append("not equal where = {}".format(np.where(diff))) + x = a[np_where(diff)] + y = b[np_where(diff)] + msgs.append("not equal where = {}".format(np_where(diff))) else: - # np.where is broken for scalars + # np_where is broken for scalars x, y = a, b msgs.append("not equal lhs = %r" % x) msgs.append("not equal rhs = %r" % y) @@ -3583,7 +3584,7 @@ def _format_subscripts(self, subscripts, value, limit=10, indent=2): Args: subscripts: The tensor (np.ndarray) subscripts, of the same format as - np.where()'s return value, i.e., a tuple of arrays with each array + np_where()'s return value, i.e., a tuple of arrays with each array corresponding to a dimension. E.g., (array([1, 1]), array([0, 1])). value: (np.ndarray) value of the tensor. limit: (int) The maximum number of indices to print. @@ -3639,7 +3640,7 @@ def assertAllInRange(self, "The value of %s does not have an ordered numeric type, instead it " "has type: %s" % (target, target.dtype)) - nan_subscripts = np.where(np.isnan(target)) + nan_subscripts = np_where(np.isnan(target)) if np.size(nan_subscripts): raise AssertionError( "%d of the %d element(s) are NaN. " @@ -3657,7 +3658,7 @@ def assertAllInRange(self, violations, np.greater_equal(target, upper_bound) if open_upper_bound else np.greater(target, upper_bound)) - violation_subscripts = np.where(violations) + violation_subscripts = np_where(violations) if np.size(violation_subscripts): raise AssertionError( "%d of the %d element(s) are outside the range %s. " % diff --git a/tensorflow/python/kernel_tests/array_ops/BUILD b/tensorflow/python/kernel_tests/array_ops/BUILD index 1a8d26ef226f7c..e89826ef5db4a3 100644 --- a/tensorflow/python/kernel_tests/array_ops/BUILD +++ b/tensorflow/python/kernel_tests/array_ops/BUILD @@ -430,6 +430,7 @@ cuda_py_strict_test( "//tensorflow/python/ops:variable_scope", "//tensorflow/python/ops:variables", "//tensorflow/python/platform:client_testlib", + "//tensorflow/python/util:numpy_compat", "//third_party/py/numpy", ], ) diff --git a/tensorflow/python/kernel_tests/array_ops/init_ops_test.py b/tensorflow/python/kernel_tests/array_ops/init_ops_test.py index 460b8f8e064e2c..9c35c61a605cc0 100644 --- a/tensorflow/python/kernel_tests/array_ops/init_ops_test.py +++ b/tensorflow/python/kernel_tests/array_ops/init_ops_test.py @@ -30,6 +30,7 @@ from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.platform import test +from tensorflow.python.util.numpy_compat import np_where # Returns true iff the two initializers produce the same tensor to @@ -714,7 +715,7 @@ def _baseNDArrayCompareToNumpy(self, axis): self.assert_close(actual, expected) def assert_close(self, actual, expected): - wrong_indices = np.where(~np.allclose(actual, expected)) + wrong_indices = np_where(~np.allclose(actual, expected)) mess = "Wrong float answer. Wrong indices: {}".format(wrong_indices) self.assertTrue(np.allclose(actual, expected), mess) diff --git a/tensorflow/python/ops/numpy_ops/tests/BUILD b/tensorflow/python/ops/numpy_ops/tests/BUILD index b21e1bc1a20582..b1f6ceabe56dc6 100644 --- a/tensorflow/python/ops/numpy_ops/tests/BUILD +++ b/tensorflow/python/ops/numpy_ops/tests/BUILD @@ -221,6 +221,7 @@ py_strict_test( "//tensorflow/python/ops/numpy_ops:np_config", "//tensorflow/python/ops/numpy_ops:numpy", "//tensorflow/python/util:nest", + "//tensorflow/python/util:numpy_compat", "@absl_py//absl/testing:absltest", "@absl_py//absl/testing:parameterized", ], diff --git a/tensorflow/python/ops/numpy_ops/tests/np_test.py b/tensorflow/python/ops/numpy_ops/tests/np_test.py index 6e2499960f4849..37c869db58fc48 100644 --- a/tensorflow/python/ops/numpy_ops/tests/np_test.py +++ b/tensorflow/python/ops/numpy_ops/tests/np_test.py @@ -35,6 +35,7 @@ import tensorflow.python.ops.numpy_ops.tests.np_wrapper as tnp import tensorflow.python.ops.numpy_ops.tests.test_util as jtu from tensorflow.python.util import nest +from tensorflow.python.util.numpy_compat import np_where config.parse_flags_with_absl() @@ -683,7 +684,7 @@ def testCountNonzero(self, shape, dtype, axis): for shape in all_shapes for dtype in all_dtypes)) def testNonzero(self, shape, dtype): rng = jtu.rand_some_zero() - onp_fun = lambda x: onp.nonzero(x) # pylint: disable=unnecessary-lambda + onp_fun = lambda x: onp.nonzero(onp.atleast_1d(x)) # pylint: disable=unnecessary-lambda lnp_fun = lambda x: tnp.nonzero(x) # pylint: disable=unnecessary-lambda args_maker = lambda: [rng(shape, dtype)] self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=False) @@ -2338,7 +2339,7 @@ def onp_fun(*args): for shape in all_shapes for dtype in all_dtypes)) def testWhereOneArgument(self, shape, dtype): rng = jtu.rand_some_zero() - onp_fun = lambda x: onp.where(x) + onp_fun = lambda x: np_where(x) lnp_fun = lambda x: tnp.where(x) args_maker = lambda: [rng(shape, dtype)] self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=False) diff --git a/tensorflow/python/ops/ragged/BUILD b/tensorflow/python/ops/ragged/BUILD index 5b66904cc4b6cd..3e47c991e0247b 100644 --- a/tensorflow/python/ops/ragged/BUILD +++ b/tensorflow/python/ops/ragged/BUILD @@ -237,6 +237,7 @@ py_strict_library( "//tensorflow/python/framework:tensor_shape", "//tensorflow/python/ops:array_ops", "//tensorflow/python/util:dispatch", + "//tensorflow/python/util:numpy_compat", "//tensorflow/python/util:tf_export", "//third_party/py/numpy", ], @@ -547,6 +548,7 @@ py_strict_library( "//tensorflow/python/util:deprecation", "//tensorflow/python/util:dispatch", "//tensorflow/python/util:nest", + "//tensorflow/python/util:numpy_compat", "//tensorflow/python/util:tf_decorator_py", "//tensorflow/python/util:tf_export", "//tensorflow/python/util:tf_inspect", diff --git a/tensorflow/python/ops/ragged/ragged_factory_ops.py b/tensorflow/python/ops/ragged/ragged_factory_ops.py index 215304c867507c..55505df533d447 100644 --- a/tensorflow/python/ops/ragged/ragged_factory_ops.py +++ b/tensorflow/python/ops/ragged/ragged_factory_ops.py @@ -26,6 +26,7 @@ from tensorflow.python.ops.ragged import ragged_tensor from tensorflow.python.ops.ragged import ragged_tensor_value from tensorflow.python.util import dispatch +from tensorflow.python.util.numpy_compat import np_reshape from tensorflow.python.util.tf_export import tf_export @@ -151,9 +152,9 @@ def _ragged_factory(values, row_splits): def _inner_factory(pylist, dtype, shape, name=None): # pylint: disable=unused-argument if dtype is object or dtype is None: - return np.reshape(np.array(pylist, dtype=dtype), shape) + return np_reshape(np.array(pylist, dtype=dtype), shape) else: - return np.reshape(np.array(pylist).astype(dtype), shape) + return np_reshape(np.array(pylist).astype(dtype), shape) return _constant_value( _ragged_factory, _inner_factory, pylist, dtype, ragged_rank, inner_shape diff --git a/tensorflow/python/util/numpy_compat.py b/tensorflow/python/util/numpy_compat.py index e9fa23ae637e09..ce2cfe8220e582 100644 --- a/tensorflow/python/util/numpy_compat.py +++ b/tensorflow/python/util/numpy_compat.py @@ -80,3 +80,64 @@ def np_asarray(values, dtype=None, order=None, copy=None): return np.asarray(values, dtype=dtype, order=order, copy=copy) else: return np.asarray(values, dtype=dtype, order=order) + + +def np_where(condition, x=None, y=None): + """Return elements chosen from x or y depending on condition. + + When only condition is provided, np.where(condition) is a shorthand for + np.asarray(condition).nonzero(). See + https://numpy.org/doc/stable/reference/generated/numpy.where.html. NumPy + 2.1.0rc0 disallows 0D input arrays in nonzero, so np.atleast_1d is used here + to remain compatible with NumPy 1.x. See + https://github.com/numpy/numpy/pull/26268. + + Args: + condition: Array_like, bool. Where True, yield x, otherwise yield y. + x: Array_like. Values from which to choose. x, y and condition need to be + broadcastable to some shape. + y: Array_like. Values from which to choose. x, y and condition need to be + broadcastable to some shape. + + Returns: + An array with elements from x where condition is True, and elements from y + elsewhere. Or the indices of the elements that are non-zero. + """ + if x is None and y is None: + if np.lib.NumpyVersion(np.__version__) >= '2.1.0.rc0': + return np.atleast_1d(np.asarray(condition)).nonzero() + return np.where(condition) + return np.where(condition, x, y) + + +def np_reshape(a, /, shape=None, *, newshape=None, order='C', copy=None): + """Reshapes an array without changing its data. + + NumPy 2.1.0rc1 added shape and copy arguments to numpy.reshape. See + https://github.com/numpy/numpy/pull/26292. Both newshape and shape keywords + are supported, but newshape is going to be deprecated. Use `shape` instead. + + Besides, shape cannot be None now. See + https://github.com/numpy/numpy/blob/v2.1.0rc1/numpy/_core/fromnumeric.py#L309. + Previously, np.reshape with newshape=None returned a copy. To maintain this + behavior, we now use asarray to create an ndarray. + + Args: + a: Array_like. Array to be reshaped. + shape: The new shape of the array. + newshape: The new shape of the array (deprecated). + order: {‘C’, ‘F’, ‘K’}. + copy: bool. If True, then the array data is copied. If None, a copy will + only be made if it’s required by order. For False it raises a ValueError if + a copy cannot be avoided. + + Returns: + This will be a new view object if possible; otherwise, it will be a copy. + """ + if shape is None: + shape = newshape + if np.lib.NumpyVersion(np.__version__) >= '2.1.0.rc0': + if shape is None and newshape is None: + return np.asarray(a, order=order, copy=copy) + return np.reshape(a, shape, order=order, copy=copy) + return np.reshape(a, shape, order=order) From cd8973732f665fec4101e2b0992181418a0daf4e Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 20 Sep 2024 14:12:00 -0700 Subject: [PATCH 079/483] Add "tf_nightly*" to local wheels inclusion list. PiperOrigin-RevId: 676966274 --- WORKSPACE | 1 + 1 file changed, 1 insertion(+) diff --git a/WORKSPACE b/WORKSPACE index 269256eadd32c1..73cb4ee89c93fc 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -35,6 +35,7 @@ python_init_repositories( local_wheel_dist_folder = "dist", local_wheel_inclusion_list = [ "tensorflow*", + "tf_nightly*", ], local_wheel_workspaces = ["//:WORKSPACE"], requirements = { From 01012ffef235efb2f941f68424fd502443bed79e Mon Sep 17 00:00:00 2001 From: Raman Sarokin Date: Fri, 20 Sep 2024 14:32:27 -0700 Subject: [PATCH 080/483] Clarified CanUseSubBufferForImage2d for PowerVR. PiperOrigin-RevId: 676973249 --- tensorflow/lite/delegates/gpu/cl/BUILD | 2 -- tensorflow/lite/delegates/gpu/cl/environment.cc | 17 ++++++++++++----- 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/tensorflow/lite/delegates/gpu/cl/BUILD b/tensorflow/lite/delegates/gpu/cl/BUILD index b84cb9a71a46f0..ecc22a1fef7cc4 100644 --- a/tensorflow/lite/delegates/gpu/cl/BUILD +++ b/tensorflow/lite/delegates/gpu/cl/BUILD @@ -351,11 +351,9 @@ cc_library( ":cl_context", ":cl_device", ":program_cache", - ":util", "//tensorflow/lite/delegates/gpu/common:data_type", "//tensorflow/lite/delegates/gpu/common:gpu_info", "//tensorflow/lite/delegates/gpu/common:precision", - "//tensorflow/lite/delegates/gpu/common:shape", "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common:tensor", "//tensorflow/lite/delegates/gpu/common/task:tensor_desc", diff --git a/tensorflow/lite/delegates/gpu/cl/environment.cc b/tensorflow/lite/delegates/gpu/cl/environment.cc index ed5b895e4a8164..2ec9c243027896 100644 --- a/tensorflow/lite/delegates/gpu/cl/environment.cc +++ b/tensorflow/lite/delegates/gpu/cl/environment.cc @@ -15,12 +15,16 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/cl/environment.h" -#include #include #include -#include "tensorflow/lite/delegates/gpu/cl/util.h" -#include "tensorflow/lite/delegates/gpu/common/shape.h" +#include "tensorflow/lite/delegates/gpu/cl/cl_command_queue.h" +#include "tensorflow/lite/delegates/gpu/cl/cl_context.h" +#include "tensorflow/lite/delegates/gpu/cl/cl_device.h" +#include "tensorflow/lite/delegates/gpu/common/gpu_info.h" +#include "tensorflow/lite/delegates/gpu/common/precision.h" +#include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/task/tensor_desc.h" namespace tflite { namespace gpu { @@ -252,8 +256,11 @@ bool CanUseSubBufferForImage2d(const GpuInfo& gpu_info) { if (!gpu_info.IsCL11OrHigher()) { return false; } - if (gpu_info.IsPowerVR()) { - // driver issue + if (gpu_info.IsPowerVR() && + gpu_info.powervr_info.driver_version.branch_main <= 23) { + // 24.2@6603887 - works. + // 1.15@6133110 - doesn't work. + // Segfaults, wrong results at model level. return false; } if (gpu_info.IsNvidia()) { From 8bab57b320f8b395c82151c20f90d06b831b5280 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 20 Sep 2024 15:51:53 -0700 Subject: [PATCH 081/483] Add support for float8_e4m3fn and float8_e5m2 matmuls in HLO evalulator and XLA CPU PiperOrigin-RevId: 677000460 --- .../xla/xla/hlo/evaluator/hlo_evaluator.cc | 78 +++++++++++++++++++ .../xla/xla/hlo/evaluator/hlo_evaluator.h | 7 ++ third_party/xla/xla/service/cpu/BUILD | 4 + .../xla/xla/service/cpu/cpu_runtime.cc | 4 + third_party/xla/xla/service/cpu/cpu_runtime.h | 2 + .../cpu/runtime_single_threaded_matmul.h | 13 ++++ .../cpu/runtime_single_threaded_matmul_f8.cc | 39 ++++++++++ .../xla/xla/service/cpu/simple_orc_jit.cc | 2 + third_party/xla/xla/tests/BUILD | 8 ++ .../xla/xla/tests/dot_operation_test.cc | 25 ++++++ 10 files changed, 182 insertions(+) create mode 100644 third_party/xla/xla/service/cpu/runtime_single_threaded_matmul_f8.cc diff --git a/third_party/xla/xla/hlo/evaluator/hlo_evaluator.cc b/third_party/xla/xla/hlo/evaluator/hlo_evaluator.cc index b75903563ea57c..13fe2f9f4d290e 100644 --- a/third_party/xla/xla/hlo/evaluator/hlo_evaluator.cc +++ b/third_party/xla/xla/hlo/evaluator/hlo_evaluator.cc @@ -4769,4 +4769,82 @@ std::unique_ptr> HloEvaluator::MatmulArray2D( lhs, rhs, __xla_cpu_runtime_EigenSingleThreadedMatMulU8); } +/* static */ std::unique_ptr> Array2DF8E5M2ToF32( + const Array2D& input) { + auto result = std::make_unique>(input.height(), input.width()); + for (int64_t rowno = 0; rowno < input.height(); ++rowno) { + for (int64_t colno = 0; colno < input.width(); ++colno) { + (*result)(rowno, colno) = static_cast(input(rowno, colno)); + } + } + return result; +} + +/* static */ std::unique_ptr> Array2DF8E4M3FNToF32( + const Array2D& input) { + auto result = std::make_unique>(input.height(), input.width()); + for (int64_t rowno = 0; rowno < input.height(); ++rowno) { + for (int64_t colno = 0; colno < input.width(); ++colno) { + (*result)(rowno, colno) = static_cast(input(rowno, colno)); + } + } + return result; +} + +/* static */ std::unique_ptr> Array2DF32ToF8E5M2( + const Array2D& input) { + auto result = std::make_unique>(input.height(), + input.width()); + for (int64_t rowno = 0; rowno < input.height(); ++rowno) { + for (int64_t colno = 0; colno < input.width(); ++colno) { + (*result)(rowno, colno) = + static_cast(input(rowno, colno)); + } + } + return result; +} + +/* static */ std::unique_ptr> Array2DF32ToF8E4M3FN( + const Array2D& input) { + auto result = std::make_unique>(input.height(), + input.width()); + for (int64_t rowno = 0; rowno < input.height(); ++rowno) { + for (int64_t colno = 0; colno < input.width(); ++colno) { + (*result)(rowno, colno) = + static_cast(input(rowno, colno)); + } + } + return result; +} + +static bool promote_f8_to_f32 = true; + +std::unique_ptr> HloEvaluator::MatmulArray2D( + const Array2D& lhs, + const Array2D& rhs) { + if (promote_f8_to_f32) { + auto lhs_float = Array2DF8E5M2ToF32(lhs); + auto rhs_float = Array2DF8E5M2ToF32(rhs); + auto result = MatmulArray2D(*lhs_float, *rhs_float); + return Array2DF32ToF8E5M2(*result); + } else { + return MatmulArray2DImpl( + lhs, rhs, __xla_cpu_runtime_EigenSingleThreadedMatMulF8E5M2); + } +} + +std::unique_ptr> HloEvaluator::MatmulArray2D( + const Array2D& lhs, + const Array2D& rhs) { + if (promote_f8_to_f32) { + auto lhs_float = Array2DF8E4M3FNToF32(lhs); + auto rhs_float = Array2DF8E4M3FNToF32(rhs); + auto result = MatmulArray2D(*lhs_float, *rhs_float); + return Array2DF32ToF8E4M3FN(*result); + } else { + return MatmulArray2DImpl( + lhs, rhs, __xla_cpu_runtime_EigenSingleThreadedMatMulF8E4M3FN); + } +} + } // namespace xla diff --git a/third_party/xla/xla/hlo/evaluator/hlo_evaluator.h b/third_party/xla/xla/hlo/evaluator/hlo_evaluator.h index f26b633b20a3ae..113b4c05f17706 100644 --- a/third_party/xla/xla/hlo/evaluator/hlo_evaluator.h +++ b/third_party/xla/xla/hlo/evaluator/hlo_evaluator.h @@ -52,6 +52,7 @@ limitations under the License. #include "xla/shape_util.h" #include "xla/util.h" #include "xla/xla_data.pb.h" +#include "tsl/platform/ml_dtypes.h" namespace xla { @@ -238,6 +239,12 @@ class HloEvaluator : public ConstDfsHloVisitorWithDefault { const Array2D>& rhs); static std::unique_ptr> MatmulArray2D( const Array2D& lhs, const Array2D& rhs); + static std::unique_ptr> MatmulArray2D( + const Array2D& lhs, + const Array2D& rhs); + static std::unique_ptr> MatmulArray2D( + const Array2D& lhs, + const Array2D& rhs); static std::unique_ptr> MatmulArray2D( const Array2D& lhs, const Array2D& rhs); diff --git a/third_party/xla/xla/service/cpu/BUILD b/third_party/xla/xla/service/cpu/BUILD index ee615cd9ef75d2..08d8685feb374e 100644 --- a/third_party/xla/xla/service/cpu/BUILD +++ b/third_party/xla/xla/service/cpu/BUILD @@ -81,6 +81,7 @@ filegroup( "runtime_single_threaded_matmul_c128.cc", "runtime_single_threaded_matmul_c64.cc", "runtime_single_threaded_matmul_common.h", + "runtime_single_threaded_matmul_f8.cc", "runtime_single_threaded_matmul_f16.cc", "runtime_single_threaded_matmul_f32.cc", "runtime_single_threaded_matmul_f64.cc", @@ -1281,6 +1282,7 @@ cc_library( "runtime_single_threaded_matmul_f16.cc", "runtime_single_threaded_matmul_f32.cc", "runtime_single_threaded_matmul_f64.cc", + "runtime_single_threaded_matmul_f8.cc", "runtime_single_threaded_matmul_s32.cc", "runtime_single_threaded_matmul_u8.cc", ], @@ -1293,6 +1295,7 @@ cc_library( "//xla/tsl/framework/contraction:eigen_contraction_kernel_no_mkl", "@com_google_absl//absl/base:core_headers", "@eigen_archive//:eigen3", + "@local_tsl//tsl/platform:ml_dtypes", ], ) @@ -1305,6 +1308,7 @@ cc_library( deps = [ ":runtime_single_threaded_matmul_impl", "@eigen_archive//:eigen3", + "@local_tsl//tsl/platform:ml_dtypes", ], ) diff --git a/third_party/xla/xla/service/cpu/cpu_runtime.cc b/third_party/xla/xla/service/cpu/cpu_runtime.cc index 4e209e61f283c6..d4f9119043a610 100644 --- a/third_party/xla/xla/service/cpu/cpu_runtime.cc +++ b/third_party/xla/xla/service/cpu/cpu_runtime.cc @@ -117,6 +117,10 @@ extern const char* const kEigenConv3DF32SymbolName = extern const char* const kDuccFftSymbolName = "__xla_cpu_runtime_DuccFft"; extern const char* const kDuccSingleThreadedFftSymbolName = "__xla_cpu_runtime_DuccSingleThreadedFft"; +extern const char* const kEigenSingleThreadedMatMulF8E4M3FNSymbolName = + "__xla_cpu_runtime_EigenSingleThreadedMatMulF8E4M3FN"; +extern const char* const kEigenSingleThreadedMatMulF8E5M2SymbolName = + "__xla_cpu_runtime_EigenSingleThreadedMatMulF8E5M2"; extern const char* const kEigenSingleThreadedMatMulF16SymbolName = "__xla_cpu_runtime_EigenSingleThreadedMatMulF16"; extern const char* const kEigenSingleThreadedMatMulF32SymbolName = diff --git a/third_party/xla/xla/service/cpu/cpu_runtime.h b/third_party/xla/xla/service/cpu/cpu_runtime.h index 92beff43a3c0ea..5ac8e39101c844 100644 --- a/third_party/xla/xla/service/cpu/cpu_runtime.h +++ b/third_party/xla/xla/service/cpu/cpu_runtime.h @@ -62,6 +62,8 @@ extern const char* const kDuccSingleThreadedFftSymbolName; extern const char* const kEigenSingleThreadedMatMulF16SymbolName; extern const char* const kEigenSingleThreadedMatMulF32SymbolName; extern const char* const kEigenSingleThreadedMatMulF64SymbolName; +extern const char* const kEigenSingleThreadedMatMulF8E4M3FNSymbolName; +extern const char* const kEigenSingleThreadedMatMulF8E5M2SymbolName; extern const char* const kEigenSingleThreadedMatMulC64SymbolName; extern const char* const kEigenSingleThreadedMatMulC128SymbolName; extern const char* const kEigenSingleThreadedMatMulS32SymbolName; diff --git a/third_party/xla/xla/service/cpu/runtime_single_threaded_matmul.h b/third_party/xla/xla/service/cpu/runtime_single_threaded_matmul.h index 7f99e89ed54523..f23291b5510671 100644 --- a/third_party/xla/xla/service/cpu/runtime_single_threaded_matmul.h +++ b/third_party/xla/xla/service/cpu/runtime_single_threaded_matmul.h @@ -20,6 +20,7 @@ limitations under the License. #include #include "Eigen/Core" +#include "tsl/platform/ml_dtypes.h" extern "C" { @@ -65,6 +66,18 @@ extern void __xla_cpu_runtime_EigenSingleThreadedMatMulU8( uint8_t* lhs, uint8_t* rhs, int64_t m, int64_t n, int64_t k, int32_t transpose_lhs, int32_t transpose_rhs); +extern void __xla_cpu_runtime_EigenSingleThreadedMatMulF8E5M2( + const void* /* xla::ExecutableRunOptions* */ run_options_ptr, + tsl::float8_e5m2* out, tsl::float8_e5m2* lhs, tsl::float8_e5m2* rhs, + int64_t m, int64_t n, int64_t k, int32_t transpose_lhs, + int32_t transpose_rhs); + +extern void __xla_cpu_runtime_EigenSingleThreadedMatMulF8E4M3FN( + const void* /* xla::ExecutableRunOptions* */ run_options_ptr, + tsl::float8_e4m3fn* out, tsl::float8_e4m3fn* lhs, tsl::float8_e4m3fn* rhs, + int64_t m, int64_t n, int64_t k, int32_t transpose_lhs, + int32_t transpose_rhs); + } // extern "C" #endif // XLA_SERVICE_CPU_RUNTIME_SINGLE_THREADED_MATMUL_H_ diff --git a/third_party/xla/xla/service/cpu/runtime_single_threaded_matmul_f8.cc b/third_party/xla/xla/service/cpu/runtime_single_threaded_matmul_f8.cc new file mode 100644 index 00000000000000..d29015456a5f3e --- /dev/null +++ b/third_party/xla/xla/service/cpu/runtime_single_threaded_matmul_f8.cc @@ -0,0 +1,39 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "absl/base/attributes.h" +#include "xla/service/cpu/runtime_single_threaded_matmul.h" +#include "xla/service/cpu/runtime_single_threaded_matmul_common.h" +#include "tsl/platform/ml_dtypes.h" + +ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void +__xla_cpu_runtime_EigenSingleThreadedMatMulF8E5M2( + const void* run_options_ptr, tsl::float8_e5m2* out, tsl::float8_e5m2* lhs, + tsl::float8_e5m2* rhs, int64_t m, int64_t n, int64_t k, + int32_t transpose_lhs, int32_t transpose_rhs) { + xla::SingleThreadedMatMulDispatch( + run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs); +} + +ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void +__xla_cpu_runtime_EigenSingleThreadedMatMulF8E4M3FN( + const void* run_options_ptr, tsl::float8_e4m3fn* out, + tsl::float8_e4m3fn* lhs, tsl::float8_e4m3fn* rhs, int64_t m, int64_t n, + int64_t k, int32_t transpose_lhs, int32_t transpose_rhs) { + xla::SingleThreadedMatMulDispatch( + run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs); +} diff --git a/third_party/xla/xla/service/cpu/simple_orc_jit.cc b/third_party/xla/xla/service/cpu/simple_orc_jit.cc index 44b9f109bf0f6d..0e9a2feb186092 100644 --- a/third_party/xla/xla/service/cpu/simple_orc_jit.cc +++ b/third_party/xla/xla/service/cpu/simple_orc_jit.cc @@ -561,6 +561,8 @@ bool RegisterKnownJITSymbols() { REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedConv3DF16); REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedConv3DF32); REGISTER_CPU_RUNTIME_SYMBOL(DuccSingleThreadedFft); + REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF8E4M3FN); + REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF8E5M2); REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF16); REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF32); REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF64); diff --git a/third_party/xla/xla/tests/BUILD b/third_party/xla/xla/tests/BUILD index 4c61d86e168f46..18d23e96d1cc29 100644 --- a/third_party/xla/xla/tests/BUILD +++ b/third_party/xla/xla/tests/BUILD @@ -927,6 +927,7 @@ xla_test( ":xla_internal_test_main", "//xla:array2d", "//xla:array3d", + "//xla:error_spec", "//xla:reference_util", "//xla:shape_util", "//xla/client:local_client", @@ -936,6 +937,7 @@ xla_test( "//xla/service:hlo_parser", "//xla/stream_executor:stream_executor_memory_allocator", "@com_google_absl//absl/strings", + "@local_tsl//tsl/platform:ml_dtypes", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_benchmark", ] + if_rocm_is_configured([ @@ -966,6 +968,7 @@ xla_test( ":xla_internal_test_main", "//xla:array2d", "//xla:array3d", + "//xla:error_spec", "//xla:reference_util", "//xla:shape_util", "//xla/client:local_client", @@ -975,6 +978,7 @@ xla_test( "//xla/service:hlo_parser", "//xla/stream_executor:stream_executor_memory_allocator", "@com_google_absl//absl/strings", + "@local_tsl//tsl/platform:ml_dtypes", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_benchmark", ] + if_rocm_is_configured([ @@ -1010,6 +1014,7 @@ xla_test( ":xla_internal_test_main", "//xla:array2d", "//xla:array3d", + "//xla:error_spec", "//xla:reference_util", "//xla:shape_util", "//xla/client:local_client", @@ -1019,6 +1024,7 @@ xla_test( "//xla/service:hlo_parser", "//xla/stream_executor:stream_executor_memory_allocator", "@com_google_absl//absl/strings", + "@local_tsl//tsl/platform:ml_dtypes", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_benchmark", ] + if_rocm_is_configured([ @@ -1093,6 +1099,7 @@ xla_test( ":xla_internal_test_main", "//xla:array2d", "//xla:array3d", + "//xla:error_spec", "//xla:reference_util", "//xla:shape_util", "//xla/client:local_client", @@ -1102,6 +1109,7 @@ xla_test( "//xla/service:hlo_parser", "//xla/stream_executor:stream_executor_memory_allocator", "@com_google_absl//absl/strings", + "@local_tsl//tsl/platform:ml_dtypes", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_benchmark", ] + if_rocm_is_configured([ diff --git a/third_party/xla/xla/tests/dot_operation_test.cc b/third_party/xla/xla/tests/dot_operation_test.cc index 4c3c728f6e1fda..526a13b62d5db3 100644 --- a/third_party/xla/xla/tests/dot_operation_test.cc +++ b/third_party/xla/xla/tests/dot_operation_test.cc @@ -24,6 +24,7 @@ limitations under the License. #include "xla/client/lib/matrix.h" #include "xla/client/local_client.h" #include "xla/client/xla_builder.h" +#include "xla/error_spec.h" #include "xla/primitive_util.h" #include "xla/reference_util.h" #include "xla/service/hlo_parser.h" @@ -33,6 +34,7 @@ limitations under the License. #include "xla/tests/hlo_test_base.h" #include "xla/tests/test_macros.h" #include "xla/tests/test_utils.h" +#include "tsl/platform/ml_dtypes.h" #include "tsl/platform/test.h" #include "tsl/platform/test_benchmark.h" @@ -364,6 +366,27 @@ void ParametricDotTest::ComputeAndCompareR2WithError( ComputeAndCompareR2(builder, expected, arguments); } +template <> +void ParametricDotTest::ComputeAndCompareR2WithError( + XlaBuilder* builder, const Array2D& expected, + absl::Span arguments) { + ErrorSpec error_spec(0.3, 3e-3); + error_spec.low_precision_fp_error_spec.type = + primitive_util::NativeToPrimitiveType(); + error_spec.low_precision_fp_error_spec.within_n_values = 1; + ComputeAndCompareR2(builder, expected, arguments, error_spec); +} + +template <> +void ParametricDotTest::ComputeAndCompareR2WithError( + XlaBuilder* builder, const Array2D& expected, + absl::Span arguments) { + ErrorSpec error_spec(0.3, 3e-3); + error_spec.low_precision_fp_error_spec.type = + primitive_util::NativeToPrimitiveType(); + error_spec.low_precision_fp_error_spec.within_n_values = 1; + ComputeAndCompareR2(builder, expected, arguments, error_spec); +} template void ParametricDotTest::TestImpl() { DotTestParam param = GetParam(); @@ -486,6 +509,8 @@ XLA_TEST_P(ParametricDotTest, TestC64) { TestImpl>(); } XLA_TEST_P(ParametricDotTest, TestC128) { TestImpl>(); } #endif XLA_TEST_P(ParametricDotTest, TestS32) { TestImpl(); } +XLA_TEST_P(ParametricDotTest, TestF8E5M2) { TestImpl(); } +XLA_TEST_P(ParametricDotTest, TestF8E4M3FN) { TestImpl(); } XLA_TEST_P(ParametricDotTest, TestU8) { TestImpl(); } From 343058c782ebe85b30c2b46e70d471febcbca8cf Mon Sep 17 00:00:00 2001 From: Subhankar Shah Date: Fri, 20 Sep 2024 16:46:48 -0700 Subject: [PATCH 082/483] [XLA:TPU] Remove the use of LoopOptimizerBestFitHeap in MemoryBoundLoopOptimizer. PiperOrigin-RevId: 677015326 --- .../xla/service/memory_space_assignment/BUILD | 2 - .../memory_space_assignment/algorithm.cc | 29 +- .../memory_space_assignment/algorithm.h | 3 + .../memory_bound_loop_optimizer.cc | 446 ++++++------------ .../memory_bound_loop_optimizer.h | 84 ++-- .../memory_bound_loop_optimizer_test.cc | 104 +--- 6 files changed, 221 insertions(+), 447 deletions(-) diff --git a/third_party/xla/xla/service/memory_space_assignment/BUILD b/third_party/xla/xla/service/memory_space_assignment/BUILD index 9a16ff01017cf6..f3f989d083f8ea 100644 --- a/third_party/xla/xla/service/memory_space_assignment/BUILD +++ b/third_party/xla/xla/service/memory_space_assignment/BUILD @@ -437,7 +437,6 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:errors", ], @@ -519,7 +518,6 @@ cc_library( "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/functional:any_invocable", - "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", diff --git a/third_party/xla/xla/service/memory_space_assignment/algorithm.cc b/third_party/xla/xla/service/memory_space_assignment/algorithm.cc index 7a5ff4073692c5..0be461746b71d1 100644 --- a/third_party/xla/xla/service/memory_space_assignment/algorithm.cc +++ b/third_party/xla/xla/service/memory_space_assignment/algorithm.cc @@ -39,7 +39,6 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/functional/any_invocable.h" #include "absl/log/check.h" -#include "absl/log/log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" @@ -954,12 +953,23 @@ absl::Status MsaAlgorithm::OptimizeMemoryBoundLoop(int loop_start_idx, const int iteration_start_idx = loop_start_idx + loop_size; const int iteration_end_idx = iteration_start_idx + loop_size; - TF_ASSIGN_OR_RETURN(std::unique_ptr optimizer, - MemoryBoundLoopOptimizer::Create( - iteration_start_idx, iteration_end_idx, - hlo_live_range_, alias_analysis_, options_)); + TF_ASSIGN_OR_RETURN( + std::unique_ptr optimizer, + MemoryBoundLoopOptimizer::Create( + iteration_start_idx, iteration_end_idx, options_.max_size_in_bytes, + options_.memory_bound_loop_optimizer_options, hlo_live_range_, + alias_analysis_, *options_.cost_analysis, options_.size_fn, + options_.reserved_scoped_memory_fn)); optimizer->Optimize(); + const int loop_optimized_allocations_original_size = + loop_optimized_allocations_.size(); + for (MemoryBoundLoopOptimizer::LoopValue& value : optimizer->loop_values()) { + if (!value.allocations.empty() && value.IsAllocationTypeSupported()) { + loop_optimized_allocations_.push_back(std::move(value.allocations)); + } + } + // Check if this unrolled loop is in a while loop. const auto& instruction_sequence = hlo_live_range_.flattened_instruction_sequence().instructions(); @@ -970,12 +980,9 @@ absl::Status MsaAlgorithm::OptimizeMemoryBoundLoop(int loop_start_idx, // Update the loop_optimized_allocations_map_ with the output of the // optimizer. - for (MemoryBoundLoopOptimizer::LoopValue& value : optimizer->loop_values()) { - if (value.allocations.empty() || !value.IsAllocationTypeSupported()) { - continue; - } - loop_optimized_allocations_.push_back(std::move(value.allocations)); - const AllocationSequence& sequence = loop_optimized_allocations_.back(); + for (int i = loop_optimized_allocations_original_size; + i < loop_optimized_allocations_.size(); ++i) { + const AllocationSequence& sequence = loop_optimized_allocations_.at(i); CHECK(!sequence.empty()); VLOG(3) << " alloc: " << sequence.back()->ToString(); for (const auto& allocation : sequence) { diff --git a/third_party/xla/xla/service/memory_space_assignment/algorithm.h b/third_party/xla/xla/service/memory_space_assignment/algorithm.h index 1cfcf1f6094938..52d0f0ee563747 100644 --- a/third_party/xla/xla/service/memory_space_assignment/algorithm.h +++ b/third_party/xla/xla/service/memory_space_assignment/algorithm.h @@ -25,6 +25,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -35,11 +36,13 @@ limitations under the License. #endif #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/functional/any_invocable.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/types/span.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/utils/hlo_live_range.h" +#include "xla/service/buffer_value.h" #include "xla/service/call_graph.h" #include "xla/service/heap_simulator/allocation_block.h" #include "xla/service/heap_simulator/heap_simulator.h" diff --git a/third_party/xla/xla/service/memory_space_assignment/memory_bound_loop_optimizer.cc b/third_party/xla/xla/service/memory_space_assignment/memory_bound_loop_optimizer.cc index f5b6c810cb6597..9c4b1a2e8bd39b 100644 --- a/third_party/xla/xla/service/memory_space_assignment/memory_bound_loop_optimizer.cc +++ b/third_party/xla/xla/service/memory_space_assignment/memory_bound_loop_optimizer.cc @@ -36,7 +36,6 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" -#include "absl/strings/str_format.h" #include "absl/types/span.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -63,21 +62,6 @@ namespace xla { namespace memory_space_assignment { namespace { -struct LoopOptimizerChunkInterval { - int64_t begin_idx_in_loop; - int64_t end_idx_in_loop; - EvenOddChunkPair chunks; - - std::string ToString() const { - CHECK(chunks.HasValues()); - return absl::StrFormat( - "begin_idx_in_loop: %d, end_idx_in_loop: %d, even chunk: %s, odd " - "chunk: %s", - begin_idx_in_loop, end_idx_in_loop, chunks.even_chunk->ToString(), - chunks.odd_chunk->ToString()); - } -}; - std::optional GetInstructionIndex( const HloInstruction* instruction, const absl::flat_hash_map& @@ -153,7 +137,7 @@ void LoopOptimizerBestFitHeap::RemoveEvenOddChunkPair( EvenOddChunkPair& chunks) { CheckAllocationIntervalValid(begin_idx_in_loop, end_idx_in_loop); ShiftAllocationIntervalIfRequired(begin_idx_in_loop, end_idx_in_loop); - auto& [even_chunk, odd_chunk] = chunks; + auto [even_chunk, odd_chunk] = chunks; RemoveEvenChunks(begin_idx_in_loop, end_idx_in_loop, even_chunk); RemoveOddChunks(begin_idx_in_loop, end_idx_in_loop, odd_chunk); } @@ -341,17 +325,18 @@ int64_t LoopOptimizerBestFitHeap::LastMemoryOffsetOccupied() const { } /*static*/ absl::StatusOr> -MemoryBoundLoopOptimizer::Create(int loop_start, int loop_end, - const HloLiveRange& hlo_live_range, - const HloAliasAnalysis& alias_analysis, - const Options& options) { - CHECK(options.cost_analysis != nullptr); +MemoryBoundLoopOptimizer::Create( + int loop_start, int loop_end, uint64_t alternate_memory_size, + const MemoryBoundLoopOptimizerOptions& options, + const HloLiveRange& hlo_live_range, const HloAliasAnalysis& alias_analysis, + const CostAnalysis& cost_analysis, + const BufferValue::SizeFunction& size_function, + const ReservedScopedMemoryFunction& reserved_scoped_memory_fn) { std::unique_ptr optimizer = absl::WrapUnique(new MemoryBoundLoopOptimizer( - loop_start, loop_end, options.max_size_in_bytes, - options.memory_bound_loop_optimizer_options, hlo_live_range, - alias_analysis, *options.cost_analysis, options.size_fn, - options.reserved_scoped_memory_fn, options.alignment_in_bytes)); + loop_start, loop_end, alternate_memory_size, options, hlo_live_range, + alias_analysis, cost_analysis, size_function, + reserved_scoped_memory_fn)); TF_RETURN_IF_ERROR(optimizer->Initialize()); return std::move(optimizer); } @@ -362,8 +347,7 @@ MemoryBoundLoopOptimizer::MemoryBoundLoopOptimizer( const HloLiveRange& hlo_live_range, const HloAliasAnalysis& alias_analysis, const CostAnalysis& cost_analysis, const BufferValue::SizeFunction& size_function, - const ReservedScopedMemoryFunction& reserved_scoped_memory_fn, - int64_t alignment_in_bytes) + const ReservedScopedMemoryFunction& reserved_scoped_memory_fn) : loop_start_(loop_start), loop_end_(loop_end), loop_size_(loop_end - loop_start), @@ -373,17 +357,13 @@ MemoryBoundLoopOptimizer::MemoryBoundLoopOptimizer( alias_analysis_(alias_analysis), cost_analysis_(cost_analysis), size_function_(size_function), - reserved_scoped_memory_fn_(reserved_scoped_memory_fn), - heap_(LoopOptimizerBestFitHeap(alternate_memory_size, - /*loop_size=*/loop_end - loop_start, - alignment_in_bytes)) {} + reserved_scoped_memory_fn_(reserved_scoped_memory_fn) {} absl::Status MemoryBoundLoopOptimizer::Initialize() { const auto& instruction_sequence = hlo_live_range_.flattened_instruction_sequence().instructions(); VLOG(3) << "MemoryBoundLoopOptimizer::Initialize, loop start: " << loop_start_ - << ", loop end: " << loop_end_ << ", loop size: " << loop_size_ - << ", alternate memory size: " << alternate_memory_size_; + << ", loop end: " << loop_end_ << ", loop size: " << loop_size_; const HloComputation* loop_computation = nullptr; // Initialize the remaining memory array with the size of the alternate // memory. Also populate instructions_in_loop_ and @@ -407,20 +387,11 @@ absl::Status MemoryBoundLoopOptimizer::Initialize() { } else { TF_RET_CHECK(loop_computation == loop_inst->parent()); } - int64_t reserved_memory = + remaining_memory_.push_back( + alternate_memory_size_ - reserved_scoped_memory_fn_(loop_inst, /*operands_in_alternate_memory=*/{}, - /*outputs_in_alternate_memory=*/{}); - if (reserved_memory == 0) { - continue; - } - // Chunks for reserved scoped memory should always be found at offset 0. - EvenOddChunkPair chunks = heap_.AllocateEvenAndOddBetween( - i, i, reserved_memory, /*preferred_offsets=*/{0, 0}); - CHECK(chunks.HasValues()); - CHECK(chunks.even_chunk->size == reserved_memory); - VLOG(3) << "Reserved chunk: " << chunks.even_chunk->ToString() - << " loop index: " << i; + /*outputs_in_alternate_memory=*/{})); } // Create a tree set to keep track of all the values that the loop @@ -838,22 +809,11 @@ std::string MemoryBoundLoopOptimizer::LoopValue::ToString() const { for (const auto& allocation : allocations) { absl::StrAppend(&allocations_str, "\n - ", allocation->ToString()); } - std::string chunk_str; - if (chunks.HasValues()) { - absl::StrAppend(&chunk_str, "\n", - "even chunk: ", chunks.even_chunk->ToString()); - absl::StrAppend(&chunk_str, "\n", - "odd chunk: ", chunks.odd_chunk->ToString()); - absl::StrAppend(&chunk_str, "\n", "alternate memory begin idx in loop: ", - alternate_memory_begin_idx_in_loop.value()); - absl::StrAppend(&chunk_str, "\n", "alternate memory end idx in loop: ", - alternate_memory_end_idx_in_loop.value()); - } return absl::StrCat( "Size: ", size, " savings: ", savings, " savings per byte: ", savings_per_byte, - " allocation type: ", AllocationTypeToString(allocation_type), chunk_str, - "\n", values_str, "\n", allocations_str); + " allocation type: ", AllocationTypeToString(allocation_type), "\n", + values_str, "\n", allocations_str); } bool MemoryBoundLoopOptimizer::LoopValue::IsAllocationTypeSupported() const { @@ -862,14 +822,6 @@ bool MemoryBoundLoopOptimizer::LoopValue::IsAllocationTypeSupported() const { allocation_type == AllocationType::kPrefetch; } -void MemoryBoundLoopOptimizer::LoopValue::SetChunkPairAndInterval( - EvenOddChunkPair chunk_pair, int64_t begin_idx_in_loop, - int64_t end_idx_in_loop) { - chunks = chunk_pair; - alternate_memory_begin_idx_in_loop = begin_idx_in_loop; - alternate_memory_end_idx_in_loop = end_idx_in_loop; -} - void MemoryBoundLoopOptimizer::SortLoopValues() { absl::c_stable_sort(loop_values_, [](const LoopValue& a, const LoopValue& b) { return a.savings_per_byte > b.savings_per_byte; @@ -898,13 +850,9 @@ void MemoryBoundLoopOptimizer::AllocateLoopValues() { VLOG(1) << "Unsupported allocation: " << value.ToString(); } } - VLOG(3) << "Heap after allocating temporaries:\n" - << heap_.MemoryUsageToAsciiArt(); VLOG(3) << "Execution time after allocating temporaries: " << CalculateExecutionTime(); AllocatePrefetches(absl::MakeSpan(prefetch_values)); - VLOG(3) << "Heap after allocating prefetches:\n" - << heap_.MemoryUsageToAsciiArt(); VLOG(3) << "Execution time after allocating prefetches: " << CalculateExecutionTime(); } @@ -949,10 +897,26 @@ void MemoryBoundLoopOptimizer::PostProcess() { value.allocations.back()->AddUse(use); } } - VLOG(3) << "LoopValue: " << value.ToString(); } } +bool MemoryBoundLoopOptimizer::AllocateBetween(int64_t begin_idx, + int64_t end_idx, int64_t size) { + int64_t end_idx_sentinel = end_idx; + if (end_idx < begin_idx) { + end_idx_sentinel += loop_size_; + } + for (int64_t i = begin_idx; i <= end_idx_sentinel; ++i) { + if (remaining_memory_[i % loop_size_] < size) { + return false; + } + } + for (int64_t i = begin_idx; i <= end_idx_sentinel; ++i) { + remaining_memory_[i % loop_size_] -= size; + } + return true; +} + bool MemoryBoundLoopOptimizer::AllocateTemporary(LoopValue& value) { VLOG(3) << "AllocateTemporary: " << value.ToString(); if (value.hlo_values.size() > 1) { @@ -961,59 +925,37 @@ bool MemoryBoundLoopOptimizer::AllocateTemporary(LoopValue& value) { } int64_t definition_idx = value.loop_positions.front().first; int64_t max_use_idx; - int64_t begin_idx_in_loop = definition_idx; - int64_t end_idx_in_loop; if (!value.next_iteration_uses.empty()) { max_use_idx = value.next_iteration_uses.back().first; // If max_use_idx >= definition_idx, then this is a loop carried dependence // and we should not have called this function. CHECK_LT(max_use_idx, definition_idx); - end_idx_in_loop = max_use_idx + loop_size_; } else { max_use_idx = value.loop_uses.back().first; - end_idx_in_loop = max_use_idx; } - EvenOddChunkPair chunks = heap_.AllocateSameEvenAndOddBetween( - begin_idx_in_loop, end_idx_in_loop, value.size); - if (!chunks.HasValues()) { - VLOG(3) << "Could not find Allocation for temporary value: " - << value.ToString(); - return false; + bool success = AllocateBetween(definition_idx, max_use_idx, value.size); + if (success) { + VLOG(3) << "Pos: " << value.loop_positions[0].second; + value.allocations.push_back(std::make_unique( + value.loop_positions[0].second, MemorySpace::kAlternate, std::nullopt, + definition_idx, max_use_idx, + /*is_scoped_allocation=*/false)); + AddAllLoopPositionsAndUses(value, /*allocate_next_iteration_uses=*/true); } - value.SetChunkPairAndInterval(chunks, begin_idx_in_loop, end_idx_in_loop); - VLOG(3) << "Pos: " << value.loop_positions[0].second; - VLOG(3) << "Allocation found for temporary value: " << value.ToString(); - VLOG(3) << "Heap after allocating temporary value: " - << heap_.MemoryUsageToAsciiArt(); - value.allocations.push_back(std::make_unique( - value.loop_positions[0].second, MemorySpace::kAlternate, std::nullopt, - definition_idx, max_use_idx, - /*is_scoped_allocation=*/false)); - AddAllLoopPositionsAndUses(value, /*allocate_next_iteration_uses=*/true); - return true; + return success; } bool MemoryBoundLoopOptimizer::AllocatePinned(LoopValue& value) { - int64_t begin_idx_in_loop = 0; - int64_t end_idx_in_loop = loop_size_ - 1; - EvenOddChunkPair chunks = heap_.AllocateSameEvenAndOddBetween( - begin_idx_in_loop, end_idx_in_loop, value.size); - if (!chunks.HasValues()) { - VLOG(3) << "Could not find Allocation for pinned value: " - << value.ToString(); - return false; + bool success = AllocateBetween(0, loop_size_ - 1, value.size); + if (success) { + CHECK(value.header_position); + value.allocations.push_back(std::make_unique( + *value.header_position, MemorySpace::kAlternate, std::nullopt, 0, + loop_size_, + /*is_scoped_allocation=*/false)); + AddAllLoopPositionsAndUses(value, /*allocate_next_iteration_uses=*/false); } - value.SetChunkPairAndInterval(chunks, begin_idx_in_loop, end_idx_in_loop); - CHECK(value.header_position); - VLOG(3) << "Allocation found for pinned value: " << value.ToString(); - VLOG(3) << "Heap after allocating pinned value: " - << heap_.MemoryUsageToAsciiArt(); - value.allocations.push_back(std::make_unique( - *value.header_position, MemorySpace::kAlternate, std::nullopt, 0, - loop_size_, - /*is_scoped_allocation=*/false)); - AddAllLoopPositionsAndUses(value, /*allocate_next_iteration_uses=*/false); - return true; + return success; } bool MemoryBoundLoopOptimizer::AllocatePrefetches( @@ -1063,6 +1005,8 @@ bool MemoryBoundLoopOptimizer::AllocatePrefetches( << *context.bandwidth_idle_times.rbegin(); } + context.additional_memory_used.resize(loop_size_, 0); + // Allocate prefetches by traversing the loop values in reverse order of // the first uses. for (int value_index : context.value_indices) { @@ -1070,6 +1014,10 @@ bool MemoryBoundLoopOptimizer::AllocatePrefetches( } for (int i = 0; i < loop_size_; ++i) { + remaining_memory_[i] -= context.additional_memory_used[i]; + VLOG(3) << "Additional memory [" << i + << "]: " << context.additional_memory_used[i]; + VLOG(3) << "Remaining memory [" << i << "]: " << remaining_memory_[i]; VLOG(3) << "Remaining bandwidth [" << i << "] : " << context.bandwidth_idle_times[i]; } @@ -1078,7 +1026,7 @@ bool MemoryBoundLoopOptimizer::AllocatePrefetches( bool MemoryBoundLoopOptimizer::AllocatePrefetch( int value_index, AllocatePrefetchesContext& context) { - LoopValue* value = context.values[value_index]; + LoopValue* value = context.values.at(value_index); VLOG(3) << "Allocating value: " << value->ToString(); int first_use_idx = value->loop_uses.front().first; int last_use_idx = value->loop_uses.back().first; @@ -1088,22 +1036,24 @@ bool MemoryBoundLoopOptimizer::AllocatePrefetch( last_use_idx_sentinel = last_use_idx + loop_size_; CHECK_LT(last_use_idx, first_use_idx); } + bool out_of_memory = false; + for (int i = first_use_idx; i <= last_use_idx_sentinel; ++i) { + int loop_idx = i % loop_size_; + if (context.additional_memory_used[loop_idx] + value->size > + remaining_memory_[loop_idx]) { + VLOG(3) << "Ran out of memory allocating for uses."; + out_of_memory = true; + } + } + if (out_of_memory) { + return false; + } float copy_resource = cost_analysis_.GetAsyncCopyElapsed(value->hlo_values.front()->shape()); VLOG(3) << "First use: " << value->loop_uses.begin()->second << " use idx: " << first_use_idx << " copy resource: " << copy_resource; - const auto& [even_chunk, odd_chunk] = heap_.FindEvenAndOddAllocationBetween( - first_use_idx, last_use_idx_sentinel, value->size); - if (!even_chunk.has_value() || !odd_chunk.has_value()) { - // Not enough memory to even fit the value in the alternate memory for the - // duration of its live range. - VLOG(3) << "Could not find Allocation for prefetch value: " - << value->ToString(); - return false; - } - - std::optional copy_start_loop_idx; + std::optional copy_start_time; // The general allocation algorithm for prefetches is to first calculate the // default-memory bandwidth idle times at each point (assuming all prefetches // succeeded). We show this pictorially below. We also show the previous @@ -1210,31 +1160,23 @@ bool MemoryBoundLoopOptimizer::AllocatePrefetch( float accumulated_copy_resource = 0; std::vector early_forced_prefetch_value_indices; int early_forced_prefetch_value_search_index = 0; - VLOG(3) << "Memory usage before allocating prefetch value: " - << value->ToString() << "\n" - << heap_.MemoryUsageToAsciiArt(); - // NOTE: We can, in practice, run the following loop for loop_size - // iterations(one full loop), till first_use_idx - loop_size, as opposed to - // limiting it till last_use_idx_sentinel - loop_size. This will allow a - // prefetch to use all the idle bandwidth available during one full loop - // iteration. - for (int current_idx = first_use_idx - 1; - current_idx >= last_use_idx_sentinel - loop_size_; --current_idx) { - int loop_idx = (current_idx + loop_size_) % loop_size_; + float early_forced_prefetch_additional_memory = 0; + for (int i = first_use_idx - 1; i >= last_use_idx_sentinel - loop_size_; + --i) { + int loop_idx = (i + loop_size_) % loop_size_; // Check if this prefetch rolls over to the previous iteration, check if any // already-scheduled prefetches would violate the FIFO order, and if so, // "early-force" them to be co-scheduled with this prefetch to maintain the // FIFO order. This of course increases the required memory, so also keep // track of additional memory that would be consumed. - if (current_idx < 0) { + if (i < 0) { for (; context.value_indices[early_forced_prefetch_value_search_index] != value_index; ++early_forced_prefetch_value_search_index) { VLOG(3) << "Searching for early forced: " << early_forced_prefetch_value_search_index; - LoopValue* early_forced_value = - context.values[context.value_indices - [early_forced_prefetch_value_search_index]]; + LoopValue* early_forced_value = context.values.at( + context.value_indices[early_forced_prefetch_value_search_index]); if (early_forced_value->allocations.empty()) { continue; } @@ -1257,85 +1199,31 @@ bool MemoryBoundLoopOptimizer::AllocatePrefetch( } early_forced_prefetch_value_indices.push_back( early_forced_prefetch_value_search_index); - VLOG(3) - << "Memory usage before removing prefetch value for early force: " - << early_forced_value->ToString() << "\n" - << heap_.MemoryUsageToAsciiArt(); - // Remove the original chunk from the heap. - heap_.RemoveEvenOddChunkPair( - early_forced_value->alternate_memory_begin_idx_in_loop.value(), - early_forced_value->alternate_memory_end_idx_in_loop.value(), - early_forced_value->chunks); - } - } - - VLOG(3) << "Loop idx:" << loop_idx << " Early force prefetch values: " - << early_forced_prefetch_value_indices.size(); - VLOG(3) << "Memory usage before adding pending chunks: \n" - << heap_.MemoryUsageToAsciiArt(); - std::vector pending_chunk_intervals; - for (int early_forced_prefetch_value_index : - early_forced_prefetch_value_indices) { - LoopValue* early_forced_value = - context - .values[context.value_indices[early_forced_prefetch_value_index]]; - int64_t begin_idx_in_loop = loop_idx; - int64_t end_idx_in_loop = - early_forced_value->alternate_memory_end_idx_in_loop.value(); - EvenOddChunkPair chunks = heap_.AllocateEvenAndOddBetween( - begin_idx_in_loop, end_idx_in_loop, early_forced_value->size); - if (!chunks.HasValues()) { - VLOG(3) << "Could not allocate between " << begin_idx_in_loop << " and " - << end_idx_in_loop << " for early forced value: " + early_forced_prefetch_additional_memory += early_forced_value->size; + VLOG(3) << "Found early-forced prefetch value: " << early_forced_value->ToString(); - VLOG(3) << "Memory usage after failed allocation: \n" - << heap_.MemoryUsageToAsciiArt(); - break; - } - pending_chunk_intervals.push_back( - {begin_idx_in_loop, end_idx_in_loop, chunks}); - VLOG(3) << "Added pending chunk: " - << pending_chunk_intervals.back().ToString() - << " for value: " << early_forced_value->ToString(); - } - - if (pending_chunk_intervals.size() == - early_forced_prefetch_value_indices.size()) { - int64_t begin_idx_in_loop = current_idx; - int64_t end_idx_in_loop = last_use_idx_sentinel; - EvenOddChunkPair chunks = heap_.AllocateEvenAndOddBetween( - begin_idx_in_loop, end_idx_in_loop, value->size); - if (chunks.HasValues()) { - pending_chunk_intervals.push_back( - {begin_idx_in_loop, end_idx_in_loop, chunks}); - VLOG(3) << "Added pending chunk: " - << pending_chunk_intervals.back().ToString() - << " for current value: " << value->ToString(); - } else { - VLOG(3) << "Could not allocate between " << begin_idx_in_loop << " and " - << end_idx_in_loop << " for value: " << value->ToString(); - VLOG(3) << "Memory usage after failed allocation: \n" - << heap_.MemoryUsageToAsciiArt(); + VLOG(3) << "Early forced prefetch additional memory: " + << early_forced_prefetch_additional_memory; } } - bool out_of_memory = pending_chunk_intervals.size() < - early_forced_prefetch_value_indices.size() + 1; - - // Remove the pending chunks from the heap. - for (auto& pending_chunk_interval : pending_chunk_intervals) { - VLOG(3) << "Removing pending chunk: " - << pending_chunk_interval.ToString(); - heap_.RemoveEvenOddChunkPair(pending_chunk_interval.begin_idx_in_loop, - pending_chunk_interval.end_idx_in_loop, - pending_chunk_interval.chunks); + // Overlap memory overhead only happens if the copy start overlaps with the + // first use (i.e. fully pipelined), so we'd need to account for 2X the + // buffer at this time. + int64_t overlap_memory_overhead = 0; + if (loop_idx == last_use_idx) { + overlap_memory_overhead = value->size; + VLOG(3) << "Loop idx == last use idx (" << loop_idx + << "), overlap memory overhead = " << overlap_memory_overhead; } - VLOG(3) << "Memory usage after removing pending chunks: " - << heap_.MemoryUsageToAsciiArt(); - - if (out_of_memory) { - VLOG(3) << "Ran out of memory for value: " << value->ToString(); + // OOM; give up prefetch. + if (context.additional_memory_used[loop_idx] + value->size + + overlap_memory_overhead + early_forced_prefetch_additional_memory > + remaining_memory_[loop_idx]) { + VLOG(3) << "Ran out of memory. Accumulated copy resource " + << accumulated_copy_resource << " out of " << copy_resource + << " at " << loop_idx; break; } @@ -1355,16 +1243,16 @@ bool MemoryBoundLoopOptimizer::AllocatePrefetch( (copy_resource - accumulated_copy_resource)); if (bandwidth_idle_time >= copy_resource - accumulated_copy_resource) { accumulated_copy_resource = copy_resource; - copy_start_loop_idx = current_idx; + copy_start_time = loop_idx; VLOG(3) << "Found the complete copy ratio and updated accumulated copy " "resource: " << accumulated_copy_resource; break; - } else if (!copy_start_loop_idx.has_value() && + } else if (!copy_start_time && accumulated_copy_resource + bandwidth_idle_time >= copy_resource * options_.desired_copy_ratio()) { accumulated_copy_resource += bandwidth_idle_time; - copy_start_loop_idx = current_idx; + copy_start_time = loop_idx; VLOG(3) << "Found the desired copy ratio and updated accumulated copy " "resource: " << accumulated_copy_resource; @@ -1373,7 +1261,7 @@ bool MemoryBoundLoopOptimizer::AllocatePrefetch( // Even if desired resource isn't reached, and if the options allow it, // allow a fully pipelined prefetch. accumulated_copy_resource += bandwidth_idle_time; - copy_start_loop_idx = current_idx; + copy_start_time = loop_idx; VLOG(3) << "Could not reach the desired copy ratio but scheduling " "fully pipelined prefetch anyway: " << accumulated_copy_resource; @@ -1386,44 +1274,26 @@ bool MemoryBoundLoopOptimizer::AllocatePrefetch( } // Could not find a suitable copy start time. - if (!copy_start_loop_idx.has_value()) { - // Restore original heap state as is. - VLOG(3) << "Could not find a suitable copy start time for value: " - << value->ToString(); - VLOG(3) << "Memory usage before restoring original state: " - << heap_.MemoryUsageToAsciiArt(); - for (int early_forced_prefetch_value_index : - early_forced_prefetch_value_indices) { - LoopValue* early_forced_value = - context - .values[context.value_indices[early_forced_prefetch_value_index]]; - // Allocate a chunk in at the same offset as the original prefetch. - EvenOddChunkPair chunks = heap_.AllocateEvenAndOddBetween( - early_forced_value->alternate_memory_begin_idx_in_loop.value(), - early_forced_value->alternate_memory_end_idx_in_loop.value(), - early_forced_value->size, - {early_forced_value->chunks.even_chunk->offset, - early_forced_value->chunks.odd_chunk->offset}); - // The chunk should always be present as we are allocating at the same - // offset. - CHECK(chunks.HasValues()); - CHECK_EQ(chunks.even_chunk->offset, - early_forced_value->chunks.even_chunk->offset); - CHECK_EQ(chunks.odd_chunk->offset, - early_forced_value->chunks.odd_chunk->offset); - } - VLOG(3) << "Memory usage after restoring original state: " - << heap_.MemoryUsageToAsciiArt(); + if (!copy_start_time) { return false; } - VLOG(3) << "Success: copy_start_loop_idx: " << copy_start_loop_idx.value() + VLOG(3) << "Success: copy_start_time: " << *copy_start_time << " leftover copy resource: " << (copy_resource - accumulated_copy_resource); - // We are early forcing the prefetches of the previous iteration. This is the - // corresponding copy start index in the previous iteration. - int early_prefetch_copy_start_loop_idx = - (copy_start_loop_idx.value() + loop_size_) % loop_size_; + auto update_additional_memory_used = [&](int loop_idx, int64_t addition) { + VLOG(4) << "Updating additional memory used at " << loop_idx << ". " + << context.additional_memory_used[loop_idx] << " + " << addition + << " => " << (context.additional_memory_used[loop_idx] + addition) + << " (remaining: " << remaining_memory_[loop_idx] << ")"; + context.additional_memory_used[loop_idx] += addition; + CHECK_LE(context.additional_memory_used[loop_idx], + remaining_memory_[loop_idx]); + }; + for (int i = first_use_idx; i <= last_use_idx_sentinel; ++i) { + int loop_idx = i % loop_size_; + update_additional_memory_used(loop_idx, value->size); + } // We reset accumulated copy resource and then reuse it to accumulate copy // resource time in order to replay the previous for loop. It is important // that we use the same arithmetic operations (as opposed to subtracting from @@ -1433,78 +1303,58 @@ bool MemoryBoundLoopOptimizer::AllocatePrefetch( --i) { int loop_idx = (i + loop_size_) % loop_size_; float& bandwidth_idle_time = context.bandwidth_idle_times[loop_idx]; + // Overlap memory overhead only happens if the copy start overlaps with the + // first use (i.e. fully pipelined), so we'd need to account for 2X the + // buffer at this time. + int64_t overlap_memory_overhead = 0; + update_additional_memory_used(loop_idx, + value->size + overlap_memory_overhead); if (bandwidth_idle_time < copy_resource - accumulated_copy_resource) { accumulated_copy_resource += bandwidth_idle_time; bandwidth_idle_time = 0; - if (loop_idx == early_prefetch_copy_start_loop_idx) { + if (loop_idx == *copy_start_time) { VLOG(3) << "Remaining copy resource: " << (copy_resource - accumulated_copy_resource); break; } } else { bandwidth_idle_time -= copy_resource - accumulated_copy_resource; - CHECK_EQ(loop_idx, early_prefetch_copy_start_loop_idx); + CHECK_EQ(loop_idx, *copy_start_time); break; } } + // Create the Allocation objects that correspond to the scheduled prefetch. + CHECK(value->header_position); + value->allocations.push_back(std::make_unique( + *value->header_position, MemorySpace::kDefault, std::nullopt, 0, + loop_size_, /*is_scoped_allocation=*/false)); + value->allocations.push_back(std::make_unique( + *value->allocations.back(), MemorySpace::kAlternate, std::nullopt, + ((*copy_start_time - 1) + loop_size_) % loop_size_, first_use_idx, + last_use_idx_sentinel)); + AddAllLoopPositionsAndUses(*value, /*allocate_next_iteration_uses=*/true); + // Account for the additional memory used by early forcing the already // scheduled prefetches. Also modify the start times of these to this // prefetch's copy start time. - // Allocate the force-early prefetches first, and allocate them in the same - // order as we did to check for out-of-memory, so we can reproduce the same - // allocation pattern. - // TODO(subhankarshah): Instead of depending on the order of allocation, store - // the offsets of the early forced prefetches and use that to allocate them. for (int early_forced_prefetch_value_index : early_forced_prefetch_value_indices) { - LoopValue* early_forced_value = - context - .values[context.value_indices[early_forced_prefetch_value_index]]; + LoopValue* early_forced_value = context.values.at( + context.value_indices[early_forced_prefetch_value_index]); CHECK(!early_forced_value->allocations.empty()); CopyAllocation* early_forced_prefetch = static_cast( early_forced_value->allocations.back().get()); - int64_t begin_idx_in_loop = early_prefetch_copy_start_loop_idx; - int64_t end_idx_in_loop = - early_forced_value->alternate_memory_end_idx_in_loop.value(); - EvenOddChunkPair chunks = heap_.AllocateEvenAndOddBetween( - begin_idx_in_loop, end_idx_in_loop, early_forced_value->size); - // The chunk should always be present as we reproducing the same allocation - // pattern as the out-of-memory check. - CHECK(chunks.HasValues()); - CHECK_LT(begin_idx_in_loop, - early_forced_value->alternate_memory_begin_idx_in_loop.value()); - early_forced_value->SetChunkPairAndInterval(chunks, begin_idx_in_loop, - end_idx_in_loop); + for (int index = early_forced_prefetch->copy_start_schedule_after(); + index >= *copy_start_time; --index) { + update_additional_memory_used(index, early_forced_value->size); + VLOG(3) << "Additional memory used: " << index << " " + << context.additional_memory_used[index]; + } early_forced_prefetch->set_copy_start_schedule_after( - ((early_prefetch_copy_start_loop_idx - 1) + loop_size_) % loop_size_); - VLOG(3) << "Early forced prefetch: " << early_forced_value->ToString(); - VLOG(3) << "Memory usage after allocating early forced prefetch: " - << heap_.MemoryUsageToAsciiArt(); + ((*copy_start_time - 1) + loop_size_) % loop_size_); + VLOG(3) << "Updated prefetch: " << early_forced_prefetch->ToString(); } - - // Create the Allocation objects that correspond to the scheduled prefetch. - CHECK(value->header_position); - value->allocations.push_back(std::make_unique( - *value->header_position, MemorySpace::kDefault, std::nullopt, 0, - loop_size_, /*is_scoped_allocation=*/false)); - int64_t begin_idx_in_loop = copy_start_loop_idx.value(); - int64_t end_idx_in_loop = last_use_idx_sentinel; - // The chunk should always be present as we reproducing the same allocation - // pattern as the out-of-memory check. - EvenOddChunkPair chunks = heap_.AllocateEvenAndOddBetween( - begin_idx_in_loop, end_idx_in_loop, value->size); - CHECK(chunks.HasValues()); - value->SetChunkPairAndInterval(chunks, begin_idx_in_loop, end_idx_in_loop); - value->allocations.push_back(std::make_unique( - *value->allocations.back(), MemorySpace::kAlternate, std::nullopt, - ((early_prefetch_copy_start_loop_idx - 1) + loop_size_) % loop_size_, - first_use_idx, last_use_idx_sentinel)); - VLOG(3) << "Allocation found for prefetch: " << value->ToString(); - VLOG(3) << "Memory usage after allocating prefetch: " << value->ToString() - << "\n" - << heap_.MemoryUsageToAsciiArt(); - AddAllLoopPositionsAndUses(*value, /*allocate_next_iteration_uses=*/true); return true; } diff --git a/third_party/xla/xla/service/memory_space_assignment/memory_bound_loop_optimizer.h b/third_party/xla/xla/service/memory_space_assignment/memory_bound_loop_optimizer.h index d87975c6384b65..5af196b4323af7 100644 --- a/third_party/xla/xla/service/memory_space_assignment/memory_bound_loop_optimizer.h +++ b/third_party/xla/xla/service/memory_space_assignment/memory_bound_loop_optimizer.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_MEMORY_BOUND_LOOP_OPTIMIZER_H_ #define XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_MEMORY_BOUND_LOOP_OPTIMIZER_H_ +#include #include #include #include @@ -25,7 +26,6 @@ limitations under the License. #include #include "absl/container/flat_hash_map.h" -#include "absl/log/check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/types/span.h" @@ -49,14 +49,8 @@ namespace xla { namespace memory_space_assignment { // Pair of chunks for even and odd loop iterations. -struct EvenOddChunkPair { - std::optional even_chunk; - std::optional odd_chunk; - - bool HasValues() const { - return even_chunk.has_value() && odd_chunk.has_value(); - } -}; +using EvenOddChunkPair = std::pair, + std::optional>; // LoopOptimizerBestFitHeap extends GlobalDecreasingSizeBestFitHeap to track // allocated buffers and their live intervals for the MemoryBoundLoopOptimizer. @@ -140,10 +134,10 @@ class LoopOptimizerBestFitHeap private: // REQUIRES: - // * begin_idx_in_loop <= end_idx_in_loop - // * begin_idx_in_loop is within [-loop_size loop_size) - // * end_idx_in_loop is within [0, 2 * loop_size) - // * end_idx_in_loop - begin_idx_in_loop + 1 <= 2 * loop_size (allocation + // - begin_idx_in_loop <= end_idx_in_loop + // - begin_idx_in_loop is within [-loop_size loop_size) + // - end_idx_in_loop is within [0, 2 * loop_size) + // - end_idx_in_loop - begin_idx_in_loop + 1 <= 2 * loop_size (allocation // colocated in even (or odd) iterations cannot span more than 2 loop // iterations) void CheckAllocationIntervalValid(int64_t begin_idx_in_loop, @@ -260,7 +254,6 @@ class MemoryBoundLoopOptimizer { // We represent each tensor used in the current iteration as a LoopValue, // wrapping the relevant information such as its HLO value, indices and // pointers to its use and position sites in different iterations. - // TODO(b/364621066): Make LoopValue a class. struct LoopValue { // An enum that encodes the allocation type that is suitable for this // LoopValue. See the comment above on what each of these mean. @@ -280,12 +273,6 @@ class MemoryBoundLoopOptimizer { // of a loop value. bool IsAllocationTypeSupported() const; - // Sets the data members `chunks`, `alternate_memory_begin_idx_in_loop`, and - // `alternate_memory_end_idx_in_loop`. - void SetChunkPairAndInterval(EvenOddChunkPair chunk_pair, - int64_t begin_idx_in_loop, - int64_t end_idx_in_loop); - // The HloValues that correspond to this LoopValue. std::vector hlo_values; // The position in the header, if any. @@ -312,25 +299,17 @@ class MemoryBoundLoopOptimizer { float savings_per_byte; // The optimized AllocationSequence. AllocationSequence allocations; - // Chunks for even and odd iterations. If a loop value is double buffered - // then it must have different chunks for even and odd iterations. - EvenOddChunkPair chunks; - // Begin index of loop value in alternate memory. - // REQUIRES: - // * (-loop_size) <= alternate_memory_begin_idx_in_loop - // * alternate_memory_begin_idx_in_loop < loop_size - std::optional alternate_memory_begin_idx_in_loop = std::nullopt; - // End index of loop value in alternate memory. - // REQUIRES: - // * 0 <= alternate_memory_end_idx_in_loop - // * alternate_memory_end_idx_in_loop < 2*loop_size - std::optional alternate_memory_end_idx_in_loop = std::nullopt; }; // Factory method to create and initialize a MemoryBoundLoopOptimizer. static absl::StatusOr> Create( - int loop_start, int loop_end, const HloLiveRange& hlo_live_range, - const HloAliasAnalysis& alias_analysis, const Options& options); + int loop_start, int loop_end, uint64_t alternate_memory_size, + const MemoryBoundLoopOptimizerOptions& options, + const HloLiveRange& hlo_live_range, + const HloAliasAnalysis& alias_analysis_, + const CostAnalysis& cost_analysis, + const BufferValue::SizeFunction& size_function, + const ReservedScopedMemoryFunction& reserved_scoped_memory_fn); // Optimize the loop. Initialize must be called first. void Optimize(); @@ -345,16 +324,13 @@ class MemoryBoundLoopOptimizer { // Return the remaining memory vector for each point in time in the loop using // the allocation decisions so far. - std::vector RemainingMemory() const { - return heap_.RemainingMemoryByTime(); + const std::vector& remaining_memory() const { + return remaining_memory_; } int64_t MaxAlternateMemoryUsed() const { - return heap_.LastMemoryOffsetOccupied(); - } - - std::string MemoryUsageToAsciiArt() const { - return heap_.MemoryUsageToAsciiArt(); + return alternate_memory_size_ - *std::min_element(remaining_memory_.begin(), + remaining_memory_.end()); } // The loop start, end, and size accessors. @@ -368,12 +344,15 @@ class MemoryBoundLoopOptimizer { // The values that are requested to be prefetched. absl::Span values; - // A list of indices into values array, sorted by the (descending) start - // time of the first use. + // A list of indices into values array, sorted by the start time of the + // first use. std::vector value_indices; // Default memory remaining bandwidths assuming all prefetches succeeded. std::vector bandwidth_idle_times; + + // Additional memory used while performing prefetching. + std::vector additional_memory_used; }; MemoryBoundLoopOptimizer( @@ -383,8 +362,7 @@ class MemoryBoundLoopOptimizer { const HloAliasAnalysis& alias_analysis_, const CostAnalysis& cost_analysis, const BufferValue::SizeFunction& size_function, - const ReservedScopedMemoryFunction& reserved_scoped_memory_fn, - int64_t alignment_in_bytes); + const ReservedScopedMemoryFunction& reserved_scoped_memory_fn); // Initializes the data structures used by the optimizer. absl::Status Initialize(); @@ -406,6 +384,9 @@ class MemoryBoundLoopOptimizer { // Allocate LoopValues by dispatching to the correct Allocate method. void AllocateLoopValues(); + // Allocate and reserve memory between the given indices. + bool AllocateBetween(int64_t begin_idx, int64_t end_idx, int64_t size); + // Perform allocation type kTemporary. Return true if successful. bool AllocateTemporary(LoopValue& value); @@ -459,22 +440,13 @@ class MemoryBoundLoopOptimizer { absl::flat_hash_map instructions_in_next_iteration_; std::vector loop_values_; + std::vector remaining_memory_; absl::flat_hash_map>> uses_in_alternate_mem_; absl::flat_hash_map> positions_in_alternate_mem_; const ReservedScopedMemoryFunction& reserved_scoped_memory_fn_; - - // The heap used to allocate loop values. Since some loop values can be double - // buffered, between successive iterations, they must have different chunks - // for even and odd iterations. We model 4 iterations of the loop to allocate - // the loop values to alternate memory so we can model the buffers that cross - // one or two loop boundaries. The allocations in the 2nd and 3rd iterations - // represent the actual memory view. The 0th and 1st iteration serve to - // account for allocations, whose buffers cross one or two loop boundaries, - // into the 2nd and 3rd iterations. - LoopOptimizerBestFitHeap heap_; }; } // namespace memory_space_assignment diff --git a/third_party/xla/xla/service/memory_space_assignment/memory_bound_loop_optimizer_test.cc b/third_party/xla/xla/service/memory_space_assignment/memory_bound_loop_optimizer_test.cc index b01874003fb22a..f241269bb6fa77 100644 --- a/third_party/xla/xla/service/memory_space_assignment/memory_bound_loop_optimizer_test.cc +++ b/third_party/xla/xla/service/memory_space_assignment/memory_bound_loop_optimizer_test.cc @@ -97,7 +97,7 @@ class LoopOptimizerBestFitHeapTest : public ::testing::Test { int64_t size) { EvenOddChunkPair chunks = heap_.AllocateSameEvenAndOddBetween( begin_idx_in_loop, end_idx_in_loop, size); - return chunks.HasValues(); + return chunks.first.has_value() && chunks.second.has_value(); } bool CanFindSameEvenAndOddAllocationBetween(int64_t begin_idx_in_loop, @@ -105,7 +105,7 @@ class LoopOptimizerBestFitHeapTest : public ::testing::Test { int64_t size) { EvenOddChunkPair chunks = heap_.FindSameEvenAndOddAllocationBetween( begin_idx_in_loop, end_idx_in_loop, size); - return chunks.HasValues(); + return chunks.first.has_value() && chunks.second.has_value(); } bool IsAllocateEvenAndOddBetweenSuccessful(int64_t begin_idx_in_loop, @@ -113,7 +113,7 @@ class LoopOptimizerBestFitHeapTest : public ::testing::Test { int64_t size) { EvenOddChunkPair chunks = heap_.AllocateEvenAndOddBetween( begin_idx_in_loop, end_idx_in_loop, size); - return chunks.HasValues(); + return chunks.first.has_value() && chunks.second.has_value(); } bool CanFindEvenAndOddAllocationBetween(int64_t begin_idx_in_loop, @@ -121,7 +121,7 @@ class LoopOptimizerBestFitHeapTest : public ::testing::Test { int64_t size) { EvenOddChunkPair chunks = heap_.FindEvenAndOddAllocationBetween( begin_idx_in_loop, end_idx_in_loop, size); - return chunks.HasValues(); + return chunks.first.has_value() && chunks.second.has_value(); } std::string GetMemoryUsageAsciiArt() { return heap_.MemoryUsageToAsciiArt(); } @@ -193,9 +193,10 @@ TEST_F(LoopOptimizerBestFitHeapTest, TestAllocateEvenAndOddBetween) { TEST_F(LoopOptimizerBestFitHeapTest, TestRemoveChunk) { EvenOddChunkPair chunks = heap_.AllocateEvenAndOddBetween(3, 11, 16); - EXPECT_TRUE(chunks.HasValues()); + EXPECT_TRUE(chunks.first.has_value() && chunks.second.has_value()); EvenOddChunkPair second_chunks = heap_.AllocateEvenAndOddBetween(-3, 8, 16); - EXPECT_TRUE(second_chunks.HasValues()); + EXPECT_TRUE(second_chunks.first.has_value() && + second_chunks.second.has_value()); EXPECT_THAT(heap_.RemainingMemoryByTime(), ContainerEq(std::vector{16, 16, 16, 0, 0, 0})); EXPECT_EQ(heap_.LastMemoryOffsetOccupied(), 64); @@ -313,17 +314,12 @@ class MemoryBoundLoopOptimizerTest : public HloTestBase { optimizer_options.set_enabled(true); optimizer_options.set_desired_copy_ratio(0.7); optimizer_options.set_allow_unsatisfied_fully_pipelined_prefetch(false); - Options options; - options.max_size_in_bytes = alternate_memory_size; - options.alignment_in_bytes = 8; - options.alternate_memory_space = kAlternateMemorySpace; - options.cost_analysis = cost_analysis_.get(); - options.size_fn = SizeFunction; - options.reserved_scoped_memory_fn = reserved_scoped_memory_fn; - options.memory_bound_loop_optimizer_options = optimizer_options; - TF_ASSIGN_OR_RETURN(optimizer_, MemoryBoundLoopOptimizer::Create( - loop_start, loop_end, *live_range_, - *alias_analysis_, options)); + TF_ASSIGN_OR_RETURN( + optimizer_, + MemoryBoundLoopOptimizer::Create( + loop_start, loop_end, alternate_memory_size, optimizer_options, + *live_range_, *alias_analysis_, *cost_analysis_, SizeFunction, + reserved_scoped_memory_fn)); return optimizer_.get(); } @@ -706,56 +702,6 @@ TEST_F(MemoryBoundLoopOptimizerTest, SimplePrefetch) { )"; int loop_start_idx; MemoryBoundLoopOptimizer* optimizer; - // Although alternate_memory_size=64 is minimum memory needed to fit the copy - // of param0 with desired copy ratio. alternate_memory_size=80 memory will - // ensure complete copy of param0 to alternate memory. - int64_t alternate_memory_size = 80; - TF_ASSERT_OK_AND_ASSIGN( - auto module, ParseAndCreateOptimizer(hlo_loop_str, alternate_memory_size, - loop_start_idx, &optimizer)); - - optimizer->Optimize(); - absl::flat_hash_set seen_uses; - for (const MemoryBoundLoopOptimizer::LoopValue& loop_value : - optimizer->loop_values()) { - LOG(INFO) << loop_value.ToString(); - if (loop_value.hlo_values.front() - ->defining_position() - .instruction->name() == "param0") { - EXPECT_TRUE(loop_value.allocations.back()->is_copy_allocation()); - } - for (const auto& allocation : loop_value.allocations) { - for (const HloUse& use : allocation->uses()) { - EXPECT_FALSE(seen_uses.contains(use)) << use.ToString(); - seen_uses.insert(use); - } - } - } - - // Ensure all of the uses in the loop have an associated use. - for (absl::string_view inst_name : {"op0", "op1", "op2", "op3", "op4"}) { - HloInstruction* inst = - module->entry_computation()->GetInstructionWithName(inst_name); - EXPECT_TRUE(seen_uses.contains(HloUse{inst, 0})) << inst_name; - EXPECT_TRUE(seen_uses.contains(HloUse{inst, 1})) << inst_name; - } - EXPECT_EQ(optimizer->CalculateExecutionTime(), 1.875); - EXPECT_EQ(optimizer->MaxAlternateMemoryUsed(), alternate_memory_size); -} - -TEST_F(MemoryBoundLoopOptimizerTest, SimplePrefetch2) { - absl::string_view hlo_loop_str = R"( - $op0 = f32[1,4] add(f32[1,4] $prev_op3, f32[1,4] $prev_op4) - $op1 = f32[1,4] add(f32[1,4] $prev_op4, f32[1,4] $op0) - $op2 = f32[1,4] add(f32[1,4] $op0, f32[1,4] $op1) - $op3 = f32[1,4] add(f32[1,4] $op1, f32[1,4] $op2) - $op4 = f32[1,4] add(f32[1,4] $param0, f32[1,4] $op3) - ROOT $root = tuple($op4, $param0) - )"; - int loop_start_idx; - MemoryBoundLoopOptimizer* optimizer; - // alternate_memory_size=64 is minimum memory needed to fit the copy of param0 - // with desired copy ratio. int64_t alternate_memory_size = 64; TF_ASSERT_OK_AND_ASSIGN( auto module, ParseAndCreateOptimizer(hlo_loop_str, alternate_memory_size, @@ -786,9 +732,7 @@ TEST_F(MemoryBoundLoopOptimizerTest, SimplePrefetch2) { EXPECT_TRUE(seen_uses.contains(HloUse{inst, 0})) << inst_name; EXPECT_TRUE(seen_uses.contains(HloUse{inst, 1})) << inst_name; } - // Check that execution time has increased to 2 since we will wait on copy - // done for param0. - EXPECT_EQ(optimizer->CalculateExecutionTime(), 2); + EXPECT_EQ(optimizer->CalculateExecutionTime(), 1.875); EXPECT_EQ(optimizer->MaxAlternateMemoryUsed(), alternate_memory_size); } @@ -829,10 +773,10 @@ TEST_F(MemoryBoundLoopOptimizerTest, ReservedScopedMemory) { // Check that a spurious GetTupleElement instruction in a later iteration of a // loop does not cause MSA to CHECK fail, when identifying loops. Prior to the -// change introduced with this test, IdentifyAndOptimizeMemoryBoundLoops() +// change instroduced with this test, IdentifyAndOptimizeMemoryBoundLoops() // would recognize 4 iterations to the loop thinking that gte is a repeat of // op2. Doing so triggers the CHECKs introduced by the change that added this -// test to fail. So, the point of this test is to verify that we do not check +// test to fail. So, the point of this test is to verfiy that we do not check // fail. TEST_F(MemoryBoundLoopOptimizerTest, GetTupleElement) { absl::string_view hlo_string = R"( @@ -965,7 +909,7 @@ TEST_F(MemoryBoundLoopOptimizerTest, PrefetchFifoOrderWithOverlap) { int loop_start_idx; MemoryBoundLoopOptimizer* optimizer; - int64_t alternate_memory_size = 464; + int64_t alternate_memory_size = 432; TF_ASSERT_OK_AND_ASSIGN( auto module, ParseAndCreateOptimizer(hlo_loop_str, alternate_memory_size, loop_start_idx, &optimizer)); @@ -1041,7 +985,7 @@ TEST_F(MemoryBoundLoopOptimizerTest, PrefetchFifoOrderWithOverlap) { EXPECT_EQ(optimizer->CalculateExecutionTime(), 12.5); // Check the memory used at each point of the loop. - std::vector remaining_memory = optimizer->RemainingMemory(); + const std::vector& remaining_memory = optimizer->remaining_memory(); // Time 0: 3 temporaries (16 B) + param0 (128 B) + param1 (128 B) EXPECT_EQ(remaining_memory.at(0), alternate_memory_size - (3 * 16 + 128 + 128)); @@ -1105,7 +1049,7 @@ TEST_F(MemoryBoundLoopOptimizerTest, PrefetchFifoOrderWithoutOverlap) { int loop_start_idx; MemoryBoundLoopOptimizer* optimizer; - int64_t alternate_memory_size = 208; + int64_t alternate_memory_size = 192; TF_ASSERT_OK_AND_ASSIGN( auto module, ParseAndCreateOptimizer(hlo_loop_str, alternate_memory_size, loop_start_idx, &optimizer)); @@ -1189,7 +1133,7 @@ TEST_F(MemoryBoundLoopOptimizerTest, PrefetchFifoOrderWithOverlap2) { int loop_start_idx; MemoryBoundLoopOptimizer* optimizer; - int64_t alternate_memory_size = 464; + int64_t alternate_memory_size = 432; TF_ASSERT_OK_AND_ASSIGN( auto module, ParseAndCreateOptimizer(hlo_loop_str, alternate_memory_size, loop_start_idx, &optimizer)); @@ -1358,13 +1302,13 @@ TEST_F(MemoryBoundLoopOptimizerTest, TempAndPinnedAllocations) { } )"; TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_str)); - int64_t alternate_memory_size = 80; + int64_t alternate_memory_size = 64; TF_ASSERT_OK_AND_ASSIGN( auto optimizer, CreateOptimizer(19, 24, module.get(), alternate_memory_size)); optimizer->Optimize(); - std::vector remaining_memory = optimizer->RemainingMemory(); + const std::vector& remaining_memory = optimizer->remaining_memory(); // Time 0: 3 temporaries (16 B) + 1 pinned (16 B) EXPECT_EQ(remaining_memory.at(0), alternate_memory_size - (3 * 16 + 16)); // Time 1: 3 temporaries (16 B) + 1 pinned (16 B) @@ -1429,12 +1373,12 @@ TEST_F(MemoryBoundLoopOptimizerTest, NegativeSavingNotPinned) { } )"; TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_str)); - int64_t alternate_memory_size = 72; + int64_t alternate_memory_size = 52; TF_ASSERT_OK_AND_ASSIGN( auto optimizer, CreateOptimizer(21, 27, module.get(), alternate_memory_size)); optimizer->Optimize(); - std::vector remaining_memory = optimizer->RemainingMemory(); + const std::vector& remaining_memory = optimizer->remaining_memory(); // We expect that pinned_prev_param0 would not get pinned due to negative // savings: 32(uses) - 28 * 16(size) = -416 Time 0: 3 temporaries (16 B) + 1 // pinned (4 B) From 55b572cbdb12dd609db06f7665cbb7f085c1f11b Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 20 Sep 2024 16:55:19 -0700 Subject: [PATCH 083/483] Remove `--@xla//xla/tsl:wheel_dependency=true` to fix nightlies CI. PiperOrigin-RevId: 677017563 --- .bazelrc | 10 +++++----- third_party/xla/.bazelrc | 10 +++++----- third_party/xla/third_party/tsl/.bazelrc | 10 +++++----- 3 files changed, 15 insertions(+), 15 deletions(-) diff --git a/.bazelrc b/.bazelrc index ec9e9cc552831b..8201ce4582a00f 100644 --- a/.bazelrc +++ b/.bazelrc @@ -740,27 +740,27 @@ build:linux_libtensorflow_build --config=cuda_wheel -- //tensorflow/tools/lib_pa test:linux_cpu_wheel_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_cpu_wheel_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_cpu_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium -test:linux_cpu_wheel_test --@local_xla//xla/tsl:wheel_dependency=true --config=linux_cpu_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... +test:linux_cpu_wheel_test --config=linux_cpu_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... # CUDA WHEEL test:linux_cuda_wheel_test_filters --test_tag_filters=gpu,requires-gpu,-no_gpu,-no_oss,-oss_excluded,-oss_serial,-benchmark-test,-no_cuda11,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_cuda_wheel_test_filters --build_tag_filters=gpu,requires-gpu,-no_gpu,-no_oss,-oss_excluded,-oss_serial,-benchmark-test,-no_cuda11,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_cuda_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium -test:linux_cuda_wheel_test --@local_xla//xla/tsl:wheel_dependency=true --config=linux_cuda_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... +test:linux_cuda_wheel_test --config=linux_cuda_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... # ARM64 WHEEL test:linux_arm64_wheel_test_filters --test_tag_filters=-no_oss,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_arm64_wheel_test_filters --build_tag_filters=-no_oss,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_arm64_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium -test:linux_arm64_wheel_test --@local_xla//xla/tsl:wheel_dependency=true --config=linux_arm64_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/core/grappler/optimizers:auto_mixed_precision_test_cpu -//tensorflow/core/grappler/optimizers:remapper_test_cpu -//tensorflow/core/kernels/image:resize_bicubic_op_test +test:linux_arm64_wheel_test --config=linux_arm64_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/core/grappler/optimizers:auto_mixed_precision_test_cpu -//tensorflow/core/grappler/optimizers:remapper_test_cpu -//tensorflow/core/kernels/image:resize_bicubic_op_test # MACOS ARM64 WHEEL test:macos_arm64_wheel_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 test:macos_arm64_wheel_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 test:macos_arm64_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium -test:macos_arm64_wheel_test --@local_xla//xla/tsl:wheel_dependency=true --config=macos_arm64_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/compiler/aot/... +test:macos_arm64_wheel_test --config=macos_arm64_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/compiler/aot/... # MACOS X86 WHEEL test:macos_x86_wheel_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test test:macos_x86_wheel_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test test:macos_x86_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium -test:macos_x86_wheel_test --@local_xla//xla/tsl:wheel_dependency=true --config=macos_x86_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/compiler/aot/... +test:macos_x86_wheel_test --config=macos_x86_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/compiler/aot/... # PYCPP TESTS run a suite of Python and C++ tests to verify general correctness over # the whole TF code base. These are usually run continuously or upon presubmit. diff --git a/third_party/xla/.bazelrc b/third_party/xla/.bazelrc index ec9e9cc552831b..8201ce4582a00f 100644 --- a/third_party/xla/.bazelrc +++ b/third_party/xla/.bazelrc @@ -740,27 +740,27 @@ build:linux_libtensorflow_build --config=cuda_wheel -- //tensorflow/tools/lib_pa test:linux_cpu_wheel_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_cpu_wheel_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_cpu_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium -test:linux_cpu_wheel_test --@local_xla//xla/tsl:wheel_dependency=true --config=linux_cpu_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... +test:linux_cpu_wheel_test --config=linux_cpu_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... # CUDA WHEEL test:linux_cuda_wheel_test_filters --test_tag_filters=gpu,requires-gpu,-no_gpu,-no_oss,-oss_excluded,-oss_serial,-benchmark-test,-no_cuda11,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_cuda_wheel_test_filters --build_tag_filters=gpu,requires-gpu,-no_gpu,-no_oss,-oss_excluded,-oss_serial,-benchmark-test,-no_cuda11,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_cuda_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium -test:linux_cuda_wheel_test --@local_xla//xla/tsl:wheel_dependency=true --config=linux_cuda_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... +test:linux_cuda_wheel_test --config=linux_cuda_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... # ARM64 WHEEL test:linux_arm64_wheel_test_filters --test_tag_filters=-no_oss,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_arm64_wheel_test_filters --build_tag_filters=-no_oss,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_arm64_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium -test:linux_arm64_wheel_test --@local_xla//xla/tsl:wheel_dependency=true --config=linux_arm64_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/core/grappler/optimizers:auto_mixed_precision_test_cpu -//tensorflow/core/grappler/optimizers:remapper_test_cpu -//tensorflow/core/kernels/image:resize_bicubic_op_test +test:linux_arm64_wheel_test --config=linux_arm64_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/core/grappler/optimizers:auto_mixed_precision_test_cpu -//tensorflow/core/grappler/optimizers:remapper_test_cpu -//tensorflow/core/kernels/image:resize_bicubic_op_test # MACOS ARM64 WHEEL test:macos_arm64_wheel_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 test:macos_arm64_wheel_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 test:macos_arm64_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium -test:macos_arm64_wheel_test --@local_xla//xla/tsl:wheel_dependency=true --config=macos_arm64_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/compiler/aot/... +test:macos_arm64_wheel_test --config=macos_arm64_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/compiler/aot/... # MACOS X86 WHEEL test:macos_x86_wheel_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test test:macos_x86_wheel_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test test:macos_x86_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium -test:macos_x86_wheel_test --@local_xla//xla/tsl:wheel_dependency=true --config=macos_x86_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/compiler/aot/... +test:macos_x86_wheel_test --config=macos_x86_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/compiler/aot/... # PYCPP TESTS run a suite of Python and C++ tests to verify general correctness over # the whole TF code base. These are usually run continuously or upon presubmit. diff --git a/third_party/xla/third_party/tsl/.bazelrc b/third_party/xla/third_party/tsl/.bazelrc index ec9e9cc552831b..8201ce4582a00f 100644 --- a/third_party/xla/third_party/tsl/.bazelrc +++ b/third_party/xla/third_party/tsl/.bazelrc @@ -740,27 +740,27 @@ build:linux_libtensorflow_build --config=cuda_wheel -- //tensorflow/tools/lib_pa test:linux_cpu_wheel_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_cpu_wheel_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_cpu_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium -test:linux_cpu_wheel_test --@local_xla//xla/tsl:wheel_dependency=true --config=linux_cpu_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... +test:linux_cpu_wheel_test --config=linux_cpu_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... # CUDA WHEEL test:linux_cuda_wheel_test_filters --test_tag_filters=gpu,requires-gpu,-no_gpu,-no_oss,-oss_excluded,-oss_serial,-benchmark-test,-no_cuda11,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_cuda_wheel_test_filters --build_tag_filters=gpu,requires-gpu,-no_gpu,-no_oss,-oss_excluded,-oss_serial,-benchmark-test,-no_cuda11,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_cuda_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium -test:linux_cuda_wheel_test --@local_xla//xla/tsl:wheel_dependency=true --config=linux_cuda_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... +test:linux_cuda_wheel_test --config=linux_cuda_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... # ARM64 WHEEL test:linux_arm64_wheel_test_filters --test_tag_filters=-no_oss,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_arm64_wheel_test_filters --build_tag_filters=-no_oss,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_arm64_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium -test:linux_arm64_wheel_test --@local_xla//xla/tsl:wheel_dependency=true --config=linux_arm64_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/core/grappler/optimizers:auto_mixed_precision_test_cpu -//tensorflow/core/grappler/optimizers:remapper_test_cpu -//tensorflow/core/kernels/image:resize_bicubic_op_test +test:linux_arm64_wheel_test --config=linux_arm64_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/core/grappler/optimizers:auto_mixed_precision_test_cpu -//tensorflow/core/grappler/optimizers:remapper_test_cpu -//tensorflow/core/kernels/image:resize_bicubic_op_test # MACOS ARM64 WHEEL test:macos_arm64_wheel_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 test:macos_arm64_wheel_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 test:macos_arm64_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium -test:macos_arm64_wheel_test --@local_xla//xla/tsl:wheel_dependency=true --config=macos_arm64_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/compiler/aot/... +test:macos_arm64_wheel_test --config=macos_arm64_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/compiler/aot/... # MACOS X86 WHEEL test:macos_x86_wheel_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test test:macos_x86_wheel_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test test:macos_x86_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium -test:macos_x86_wheel_test --@local_xla//xla/tsl:wheel_dependency=true --config=macos_x86_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/compiler/aot/... +test:macos_x86_wheel_test --config=macos_x86_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/compiler/aot/... # PYCPP TESTS run a suite of Python and C++ tests to verify general correctness over # the whole TF code base. These are usually run continuously or upon presubmit. From e5d72c3732735d70e42ff075b0f0c486f1eafb68 Mon Sep 17 00:00:00 2001 From: Sandeep Dasgupta Date: Fri, 20 Sep 2024 16:58:43 -0700 Subject: [PATCH 084/483] Integrate StableHLO at openxla/stablehlo@9bb28f84 PiperOrigin-RevId: 677018398 --- third_party/stablehlo/temporary.patch | 352 ------------------ third_party/stablehlo/workspace.bzl | 4 +- .../xla/third_party/stablehlo/temporary.patch | 352 ------------------ .../xla/third_party/stablehlo/workspace.bzl | 4 +- 4 files changed, 4 insertions(+), 708 deletions(-) diff --git a/third_party/stablehlo/temporary.patch b/third_party/stablehlo/temporary.patch index 3e0b0e66bc8a4f..8b137891791fe9 100755 --- a/third_party/stablehlo/temporary.patch +++ b/third_party/stablehlo/temporary.patch @@ -1,353 +1 @@ -diff --ruN a/stablehlo/stablehlo/tests/transforms/stablehlo_create_compatibility_expander.mlir b/stablehlo/stablehlo/tests/transforms/stablehlo_create_compatibility_expander.mlir ---- stablehlo/stablehlo/tests/transforms/stablehlo_create_compatibility_expander.mlir -+++ stablehlo/stablehlo/tests/transforms/stablehlo_create_compatibility_expander.mlir -@@ -41,3 +41,170 @@ - %3 = stablehlo.imag %1 : (tensor<4xcomplex>) -> tensor<4xf64> - func.return %2, %3 : tensor<4xf64>, tensor<4xf64> - } -+ -+// ----- -+ -+// CHECK-LABEL: @gather_with_batching_dims -+// CHECK-NEXT: %[[iota_dim1:.*]] = stablehlo.iota dim = 1 : tensor<4x3x5x1xi32> -+// CHECK-NEXT: %[[iota_dim0:.*]] = stablehlo.iota dim = 0 : tensor<4x3x5x1xi32> -+// CHECK-NEXT: %[[concat:.*]] = stablehlo.concatenate %[[iota_dim1]], %[[iota_dim0]], %arg1, dim = 3 : (tensor<4x3x5x1xi32>, tensor<4x3x5x1xi32>, tensor<4x3x5x2xi32>) -> tensor<4x3x5x4xi32> -+// CHECK-NEXT: %[[gather:.*]] = "stablehlo.gather"(%arg0, %[[concat]]) <{ -+// CHECK-SAME: dimension_numbers = #stablehlo.gather< -+// CHECK-SAME: offset_dims = [3], collapsed_slice_dims = [0, 1, 2, 3], -+// CHECK-SAME: start_index_map = [0, 2, 1, 3], index_vector_dim = 3>, -+// CHECK-SAME: indices_are_sorted = false, -+// CHECK-SAME: slice_sizes = array -+// CHECK-SAME: }> : (tensor<3x2x4x7x9xi32>, tensor<4x3x5x4xi32>) -> tensor<4x3x5x8xi32> -+// CHECK-NEXT: return %[[gather]] : tensor<4x3x5x8xi32> -+func.func @gather_with_batching_dims(%arg0: tensor<3x2x4x7x9xi32>, %arg1: tensor<4x3x5x2xi32>) -> tensor<4x3x5x8xi32> { -+ // CHECK-NO-DOWNGRADE: operand_batching_dims = [0, 2] -+ // CHECK-NO-DOWNGRADE: start_indices_batching_dims = [1, 0] -+ %0 = "stablehlo.gather"(%arg0, %arg1) { -+ dimension_numbers = #stablehlo.gather< -+ offset_dims = [3], -+ collapsed_slice_dims = [1, 3], -+ operand_batching_dims = [0, 2], -+ start_indices_batching_dims = [1, 0], -+ start_index_map = [1, 3], -+ index_vector_dim = 3 -+ >, -+ slice_sizes = array, -+ indices_are_sorted = true -+ } : (tensor<3x2x4x7x9xi32>, tensor<4x3x5x2xi32>) -> tensor<4x3x5x8xi32> -+ func.return %0 : tensor<4x3x5x8xi32> -+} -+ -+// ----- -+ -+// CHECK-LABEL: @gather_with_batching_no_index_vector_dim -+// CHECK-NEXT: %[[iota_dim1:.*]] = stablehlo.iota dim = 1 : tensor<4x3x5x1xi32> -+// CHECK-NEXT: %[[iota_dim0:.*]] = stablehlo.iota dim = 0 : tensor<4x3x5x1xi32> -+// CHECK-NEXT: %[[reshape:.*]] = stablehlo.reshape %arg1 : (tensor<4x3x5xi32>) -> tensor<4x3x5x1xi32> -+// CHECK-NEXT: %[[concat:.*]] = stablehlo.concatenate %[[iota_dim1]], %[[iota_dim0]], %[[reshape]], dim = 3 : (tensor<4x3x5x1xi32>, tensor<4x3x5x1xi32>, tensor<4x3x5x1xi32>) -> tensor<4x3x5x3xi32> -+// CHECK-NEXT: %[[gather:.*]] = "stablehlo.gather"(%arg0, %[[concat]]) <{ -+// CHECK-SAME: dimension_numbers = #stablehlo.gather< -+// CHECK-SAME: offset_dims = [3], collapsed_slice_dims = [0, 1, 2], -+// CHECK-SAME: start_index_map = [0, 2, 1], index_vector_dim = 3>, -+// CHECK-SAME: indices_are_sorted = false, -+// CHECK-SAME: slice_sizes = array -+// CHECK-SAME: }> : (tensor<3x2x4x9xi32>, tensor<4x3x5x3xi32>) -> tensor<4x3x5x8xi32> -+// CHECK-NEXT: return %[[gather]] : tensor<4x3x5x8xi32> -+func.func @gather_with_batching_no_index_vector_dim(%arg0: tensor<3x2x4x9xi32>, %arg1: tensor<4x3x5xi32>) -> tensor<4x3x5x8xi32> { -+ // CHECK-NO-DOWNGRADE: operand_batching_dims = [0, 2] -+ // CHECK-NO-DOWNGRADE: start_indices_batching_dims = [1, 0] -+ %0 = "stablehlo.gather"(%arg0, %arg1) <{ -+ dimension_numbers = #stablehlo.gather< -+ offset_dims = [3], -+ collapsed_slice_dims = [1], -+ operand_batching_dims = [0, 2], -+ start_indices_batching_dims = [1, 0], -+ start_index_map = [1], -+ index_vector_dim = 3 -+ >, -+ slice_sizes = array, -+ indices_are_sorted = true -+ }> : (tensor<3x2x4x9xi32>, tensor<4x3x5xi32>) -> tensor<4x3x5x8xi32> -+ func.return %0 : tensor<4x3x5x8xi32> -+} -+ -+// ----- -+ -+// CHECK-LABEL: @gather_with_batching_dim_size_zero -+// CHECK-NEXT: %[[iota:.*]] = stablehlo.iota dim = 0 : tensor<0x3x5x1xi32> -+// CHECK-NEXT: %[[concat:.*]] = stablehlo.concatenate %[[iota]], %arg1, dim = 3 : (tensor<0x3x5x1xi32>, tensor<0x3x5x1xi32>) -> tensor<0x3x5x2xi32> -+// CHECK-NEXT: %[[gather:.*]] = "stablehlo.gather"(%arg0, %[[concat]]) <{ -+// CHECK-SAME: dimension_numbers = #stablehlo.gather< -+// CHECK-SAME: offset_dims = [3], collapsed_slice_dims = [0, 1], -+// CHECK-SAME: start_index_map = [0, 1], index_vector_dim = 3>, -+// CHECK-SAME: indices_are_sorted = false, -+// CHECK-SAME: slice_sizes = array -+// CHECK-SAME: }> : (tensor<0x2x9xi32>, tensor<0x3x5x2xi32>) -> tensor<0x3x5x8xi32> -+// CHECK-NEXT: return %[[gather]] : tensor<0x3x5x8xi32> -+func.func @gather_with_batching_dim_size_zero(%arg0: tensor<0x2x9xi32>, %arg1: tensor<0x3x5x1xi32>) -> tensor<0x3x5x8xi32> { -+ // CHECK-NO-DOWNGRADE: operand_batching_dims = [0] -+ // CHECK-NO-DOWNGRADE: start_indices_batching_dims = [0] -+ %0 = "stablehlo.gather"(%arg0, %arg1) <{ -+ dimension_numbers = #stablehlo.gather< -+ offset_dims = [3], -+ collapsed_slice_dims = [1], -+ operand_batching_dims = [0], -+ start_indices_batching_dims = [0], -+ start_index_map = [1], -+ index_vector_dim = 3 -+ >, -+ slice_sizes = array, -+ indices_are_sorted = true -+ }> : (tensor<0x2x9xi32>, tensor<0x3x5x1xi32>) -> tensor<0x3x5x8xi32> -+ func.return %0 : tensor<0x3x5x8xi32> -+} -+ -+// ----- -+ -+// CHECK-LABEL: @scatter_with_batching_dims -+// CHECK-NEXT: %[[iota_dim1:.*]] = stablehlo.iota dim = 1 : tensor<4x3x5x1xi32> -+// CHECK-NEXT: %[[iota_dim0:.*]] = stablehlo.iota dim = 0 : tensor<4x3x5x1xi32> -+// CHECK-NEXT: %[[concat:.*]] = stablehlo.concatenate %[[iota_dim1]], %[[iota_dim0]], %arg1, dim = 3 : (tensor<4x3x5x1xi32>, tensor<4x3x5x1xi32>, tensor<4x3x5x2xi32>) -> tensor<4x3x5x4xi32> -+// CHECK-NEXT: %[[scatter:.*]] = "stablehlo.scatter"(%arg0, %[[concat]], %arg2) <{ -+// CHECK-SAME: indices_are_sorted = false, -+// CHECK-SAME: dimension_numbers = #stablehlo.scatter< -+// CHECK-SAME: update_window_dims = [3], inserted_window_dims = [0, 1, 2, 3], -+// CHECK-SAME: scatter_dims_to_operand_dims = [0, 2, 1, 3], index_vector_dim = 3>, -+// CHECK-SAME: unique_indices = false}> -+// CHECK: (tensor<3x2x4x7x9xi32>, tensor<4x3x5x4xi32>, tensor<4x3x5x8xi32>) -> tensor<3x2x4x7x9xi32> -+// CHECK-NEXT: return %[[scatter]] : tensor<3x2x4x7x9xi32> -+func.func @scatter_with_batching_dims(%arg0: tensor<3x2x4x7x9xi32>, %arg1: tensor<4x3x5x2xi32>, %arg2: tensor<4x3x5x8xi32>) -> tensor<3x2x4x7x9xi32> { -+ // CHECK-NO-DOWNGRADE: input_batching_dims = [0, 2] -+ // CHECK-NO-DOWNGRADE: scatter_indices_batching_dims = [1, 0] -+ %0 = "stablehlo.scatter"(%arg0, %arg1, %arg2) <{ -+ indices_are_sorted = true, -+ scatter_dimension_numbers = #stablehlo.scatter< -+ update_window_dims = [3], -+ inserted_window_dims = [1, 3], -+ input_batching_dims = [0, 2], -+ scatter_indices_batching_dims = [1, 0], -+ scatter_dims_to_operand_dims = [1, 3], -+ index_vector_dim = 3 -+ >, -+ unique_indices = false -+ }> ({ -+ ^bb0(%arg3: tensor, %arg4: tensor): -+ stablehlo.return %arg4 : tensor -+ }) : (tensor<3x2x4x7x9xi32>, tensor<4x3x5x2xi32>, tensor<4x3x5x8xi32>) -> tensor<3x2x4x7x9xi32> -+ func.return %0 : tensor<3x2x4x7x9xi32> -+} -+ -+// ----- -+ -+// CHECK-LABEL: @scatter_with_batching_no_index_vector_dim -+// CHECK-NEXT: %[[iota_dim1:.*]] = stablehlo.iota dim = 1 : tensor<4x3x5x1xi32> -+// CHECK-NEXT: %[[iota_dim0:.*]] = stablehlo.iota dim = 0 : tensor<4x3x5x1xi32> -+// CHECK-NEXT: %[[reshape:.*]] = stablehlo.reshape %arg1 : (tensor<4x3x5xi32>) -> tensor<4x3x5x1xi32> -+// CHECK-NEXT: %[[concat:.*]] = stablehlo.concatenate %[[iota_dim1]], %[[iota_dim0]], %[[reshape]], dim = 3 : (tensor<4x3x5x1xi32>, tensor<4x3x5x1xi32>, tensor<4x3x5x1xi32>) -> tensor<4x3x5x3xi32> -+// CHECK-NEXT: %[[scatter:.*]] = "stablehlo.scatter"(%arg0, %[[concat]], %arg2) <{ -+// CHECK-SAME: indices_are_sorted = false, -+// CHECK-SAME: dimension_numbers = #stablehlo.scatter< -+// CHECK-SAME: update_window_dims = [3], inserted_window_dims = [0, 1, 2], -+// CHECK-SAME: scatter_dims_to_operand_dims = [0, 2, 1], index_vector_dim = 3>, -+// CHECK-SAME: unique_indices = true}> -+// CHECK: (tensor<3x2x4x9xi32>, tensor<4x3x5x3xi32>, tensor<4x3x5x8xi32>) -> tensor<3x2x4x9xi32> -+// CHECK-NEXT: return %[[scatter]] : tensor<3x2x4x9xi32> -+func.func @scatter_with_batching_no_index_vector_dim(%arg0: tensor<3x2x4x9xi32>, %arg1: tensor<4x3x5xi32>, %arg2: tensor<4x3x5x8xi32>) -> tensor<3x2x4x9xi32> { -+ // CHECK-NO-DOWNGRADE: input_batching_dims = [0, 2] -+ // CHECK-NO-DOWNGRADE: scatter_indices_batching_dims = [1, 0] -+ %0 = "stablehlo.scatter"(%arg0, %arg1, %arg2) <{ -+ indices_are_sorted = true, -+ scatter_dimension_numbers = #stablehlo.scatter< -+ update_window_dims = [3], -+ inserted_window_dims = [1], -+ input_batching_dims = [0, 2], -+ scatter_indices_batching_dims = [1, 0], -+ scatter_dims_to_operand_dims = [1], -+ index_vector_dim = 3 -+ >, -+ unique_indices = true -+ }> ({ -+ ^bb0(%arg3: tensor, %arg4: tensor): -+ stablehlo.return %arg4 : tensor -+ }) : (tensor<3x2x4x9xi32>, tensor<4x3x5xi32>, tensor<4x3x5x8xi32>) -> tensor<3x2x4x9xi32> -+ func.return %0 : tensor<3x2x4x9xi32> -+} -diff --ruN a/stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpander.cpp b/stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpander.cpp ---- stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpander.cpp -+++ stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpander.cpp -@@ -12,14 +12,22 @@ - - #include - -+#include - #include -+#include -+#include -+#include - - #include "llvm/ADT/APFloat.h" -+#include "llvm/ADT/STLExtras.h" -+#include "llvm/ADT/SmallVector.h" - #include "llvm/Support/ErrorHandling.h" - #include "mlir/Dialect/Func/IR/FuncOps.h" - #include "mlir/IR/BuiltinAttributes.h" - #include "mlir/IR/BuiltinTypes.h" -+#include "mlir/IR/Diagnostics.h" - #include "mlir/IR/PatternMatch.h" -+#include "mlir/Rewrite/FrozenRewritePatternSet.h" - #include "mlir/Support/LLVM.h" - #include "mlir/Transforms/DialectConversion.h" - #include "mlir/Transforms/GreedyPatternRewriteDriver.h" -@@ -58,6 +66,132 @@ - return targetVersion; - } - -+SmallVector mergeSortedDims(ArrayRef dims1, -+ ArrayRef dims2) { -+ SmallVector result; -+ result.reserve(dims1.size() + dims2.size()); -+ std::merge(dims1.begin(), dims1.end(), dims2.begin(), dims2.end(), -+ std::back_inserter(result)); -+ return result; -+} -+ -+// Returns an updated indices tensor such that an `IotaOp` is prepended for each -+// dim in `indicesBatchingDims` with a `ConcatenateOp`. -+// -+// If `indexVectorDim` is equal to the rank of `indices`, it is reshaped to have -+// a trailing dimension of size 1 so it can be concatenated with the `IotaOp`s. -+Value createConcatIndices(Value indices, int64_t indexVectorDim, -+ ArrayRef indicesBatchingDims, -+ PatternRewriter &rewriter) { -+ Location loc = indices.getLoc(); -+ auto indicesType = cast(indices.getType()); -+ bool indexVectorDimOnLastDim = indexVectorDim == indicesType.getRank(); -+ -+ SmallVector iotaShape(indicesType.getShape()); -+ if (indexVectorDimOnLastDim) { -+ iotaShape.push_back(1); -+ } else { -+ iotaShape[indexVectorDim] = 1; -+ } -+ auto iotaType = -+ RankedTensorType::get(iotaShape, indicesType.getElementType()); -+ -+ SmallVector indicesToConcat; -+ indicesToConcat.reserve(indicesBatchingDims.size() + 1); -+ for (int64_t batchingDim : indicesBatchingDims) { -+ indicesToConcat.push_back( -+ rewriter.create(loc, iotaType, batchingDim)); -+ } -+ if (indexVectorDimOnLastDim) { -+ indicesToConcat.push_back( -+ rewriter.create(loc, iotaType, indices)); -+ } else { -+ indicesToConcat.push_back(indices); -+ } -+ return rewriter.create(loc, indicesToConcat, indexVectorDim); -+} -+ -+//===----------------------------------------------------------------------===// -+// Patterns (non DRR) -+//===----------------------------------------------------------------------===// -+ -+// Converts a `GatherOp` with batching dims to a `GatherOp` without batching -+// dims, such that each batching dim becomes a collapsed slice dim with a -+// corresponding `IotaOp` concatenated to the start indices. -+class GatherWithBatchingDimsExpander : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ -+ LogicalResult matchAndRewrite(GatherOp op, -+ PatternRewriter &rewriter) const override { -+ GatherDimensionNumbersAttr dimNumbers = op.getDimensionNumbers(); -+ ArrayRef operandBatchingDims = dimNumbers.getOperandBatchingDims(); -+ if (operandBatchingDims.empty()) { -+ return rewriter.notifyMatchFailure(op, [](Diagnostic &diag) { -+ diag << "gather op has no batching dims"; -+ }); -+ } -+ -+ SmallVector newCollapsedSliceDims = mergeSortedDims( -+ operandBatchingDims, dimNumbers.getCollapsedSliceDims()); -+ SmallVector newStartIndexMap = -+ llvm::to_vector(llvm::concat( -+ operandBatchingDims, dimNumbers.getStartIndexMap())); -+ Value newIndices = createConcatIndices( -+ op.getStartIndices(), dimNumbers.getIndexVectorDim(), -+ dimNumbers.getStartIndicesBatchingDims(), rewriter); -+ rewriter.replaceOpWithNewOp( -+ op, op.getOperand(), newIndices, -+ GatherDimensionNumbersAttr::get( -+ op.getContext(), dimNumbers.getOffsetDims(), newCollapsedSliceDims, -+ /*operandBatchingDims=*/{}, /*startIndicesBatchingDims=*/{}, -+ newStartIndexMap, dimNumbers.getIndexVectorDim()), -+ op.getSliceSizes(), /*indicesAreSorted=*/false); -+ -+ return success(); -+ } -+}; -+ -+// Converts a `ScatterOp` with batching dims to a `ScatterOp` without batching -+// dims, such that each batching dim becomes an inserted window dim with a -+// corresponding `IotaOp` concatenated to the scatter indices. -+class ScatterWithBatchingDimsExpander : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ -+ LogicalResult matchAndRewrite(ScatterOp op, -+ PatternRewriter &rewriter) const override { -+ ScatterDimensionNumbersAttr dimNumbers = op.getScatterDimensionNumbers(); -+ ArrayRef inputBatchingDims = dimNumbers.getInputBatchingDims(); -+ if (inputBatchingDims.empty()) { -+ return rewriter.notifyMatchFailure(op, [](Diagnostic &diag) { -+ diag << "scatter op has no batching dims"; -+ }); -+ } -+ -+ SmallVector newInsertedWindowDims = -+ mergeSortedDims(inputBatchingDims, dimNumbers.getInsertedWindowDims()); -+ SmallVector newScatterDimsToOperandDims = -+ llvm::to_vector(llvm::concat( -+ inputBatchingDims, dimNumbers.getScatterDimsToOperandDims())); -+ Value newIndices = createConcatIndices( -+ op.getScatterIndices(), dimNumbers.getIndexVectorDim(), -+ dimNumbers.getScatterIndicesBatchingDims(), rewriter); -+ auto newScatterOp = rewriter.create( -+ op.getLoc(), op->getResultTypes(), op.getInputs(), newIndices, -+ op.getUpdates(), -+ ScatterDimensionNumbersAttr::get( -+ op.getContext(), dimNumbers.getUpdateWindowDims(), -+ newInsertedWindowDims, -+ /*inputBatchingDims=*/{}, /*scatterIndicesBatchingDims=*/{}, -+ newScatterDimsToOperandDims, dimNumbers.getIndexVectorDim()), -+ /*indicesAreSorted=*/false, op.getUniqueIndices()); -+ -+ newScatterOp.getUpdateComputation().takeBody(op.getUpdateComputation()); -+ rewriter.replaceOp(op, newScatterOp.getResults()); -+ -+ return success(); -+ } -+}; -+ - //===----------------------------------------------------------------------===// - // Pass - //===----------------------------------------------------------------------===// -@@ -107,10 +241,16 @@ - void populateStablehloCreateCompatibilityExpanderPatterns( - RewritePatternSet *patterns, MLIRContext *context, - vhlo::Version targetVersion) { -+ // StableHLO GatherOp/ScatterOp with batching dims is introduced in v1.1.0. -+ if (targetVersion < vhlo::Version(1, 1, 0)) { -+ patterns -+ ->add( -+ context); -+ } - // StableHLO TanOp is introduced in v1.4.0. - if (targetVersion < vhlo::Version(1, 4, 0)) { -- patterns->add(context); -- patterns->add(context); -+ patterns->add(context); - } - } - diff --git a/third_party/stablehlo/workspace.bzl b/third_party/stablehlo/workspace.bzl index 97fd0b990fc1c7..61608344c772f9 100644 --- a/third_party/stablehlo/workspace.bzl +++ b/third_party/stablehlo/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): # LINT.IfChange - STABLEHLO_COMMIT = "78c753ad13ad8205cacc5fcc12418c1ac97276c7" - STABLEHLO_SHA256 = "b7fef892020eb465a6d1ed921160f5229398ba10acff36b6345171b9867ccc7c" + STABLEHLO_COMMIT = "9bb28f84c281795783639364b727e4398dcec570" + STABLEHLO_SHA256 = "44f90a3b6e8c7fba454644a7457b71327f643196ee1b1f69a3b210604ea8b5a9" # LINT.ThenChange(Google-internal path) tf_http_archive( diff --git a/third_party/xla/third_party/stablehlo/temporary.patch b/third_party/xla/third_party/stablehlo/temporary.patch index 3e0b0e66bc8a4f..8b137891791fe9 100755 --- a/third_party/xla/third_party/stablehlo/temporary.patch +++ b/third_party/xla/third_party/stablehlo/temporary.patch @@ -1,353 +1 @@ -diff --ruN a/stablehlo/stablehlo/tests/transforms/stablehlo_create_compatibility_expander.mlir b/stablehlo/stablehlo/tests/transforms/stablehlo_create_compatibility_expander.mlir ---- stablehlo/stablehlo/tests/transforms/stablehlo_create_compatibility_expander.mlir -+++ stablehlo/stablehlo/tests/transforms/stablehlo_create_compatibility_expander.mlir -@@ -41,3 +41,170 @@ - %3 = stablehlo.imag %1 : (tensor<4xcomplex>) -> tensor<4xf64> - func.return %2, %3 : tensor<4xf64>, tensor<4xf64> - } -+ -+// ----- -+ -+// CHECK-LABEL: @gather_with_batching_dims -+// CHECK-NEXT: %[[iota_dim1:.*]] = stablehlo.iota dim = 1 : tensor<4x3x5x1xi32> -+// CHECK-NEXT: %[[iota_dim0:.*]] = stablehlo.iota dim = 0 : tensor<4x3x5x1xi32> -+// CHECK-NEXT: %[[concat:.*]] = stablehlo.concatenate %[[iota_dim1]], %[[iota_dim0]], %arg1, dim = 3 : (tensor<4x3x5x1xi32>, tensor<4x3x5x1xi32>, tensor<4x3x5x2xi32>) -> tensor<4x3x5x4xi32> -+// CHECK-NEXT: %[[gather:.*]] = "stablehlo.gather"(%arg0, %[[concat]]) <{ -+// CHECK-SAME: dimension_numbers = #stablehlo.gather< -+// CHECK-SAME: offset_dims = [3], collapsed_slice_dims = [0, 1, 2, 3], -+// CHECK-SAME: start_index_map = [0, 2, 1, 3], index_vector_dim = 3>, -+// CHECK-SAME: indices_are_sorted = false, -+// CHECK-SAME: slice_sizes = array -+// CHECK-SAME: }> : (tensor<3x2x4x7x9xi32>, tensor<4x3x5x4xi32>) -> tensor<4x3x5x8xi32> -+// CHECK-NEXT: return %[[gather]] : tensor<4x3x5x8xi32> -+func.func @gather_with_batching_dims(%arg0: tensor<3x2x4x7x9xi32>, %arg1: tensor<4x3x5x2xi32>) -> tensor<4x3x5x8xi32> { -+ // CHECK-NO-DOWNGRADE: operand_batching_dims = [0, 2] -+ // CHECK-NO-DOWNGRADE: start_indices_batching_dims = [1, 0] -+ %0 = "stablehlo.gather"(%arg0, %arg1) { -+ dimension_numbers = #stablehlo.gather< -+ offset_dims = [3], -+ collapsed_slice_dims = [1, 3], -+ operand_batching_dims = [0, 2], -+ start_indices_batching_dims = [1, 0], -+ start_index_map = [1, 3], -+ index_vector_dim = 3 -+ >, -+ slice_sizes = array, -+ indices_are_sorted = true -+ } : (tensor<3x2x4x7x9xi32>, tensor<4x3x5x2xi32>) -> tensor<4x3x5x8xi32> -+ func.return %0 : tensor<4x3x5x8xi32> -+} -+ -+// ----- -+ -+// CHECK-LABEL: @gather_with_batching_no_index_vector_dim -+// CHECK-NEXT: %[[iota_dim1:.*]] = stablehlo.iota dim = 1 : tensor<4x3x5x1xi32> -+// CHECK-NEXT: %[[iota_dim0:.*]] = stablehlo.iota dim = 0 : tensor<4x3x5x1xi32> -+// CHECK-NEXT: %[[reshape:.*]] = stablehlo.reshape %arg1 : (tensor<4x3x5xi32>) -> tensor<4x3x5x1xi32> -+// CHECK-NEXT: %[[concat:.*]] = stablehlo.concatenate %[[iota_dim1]], %[[iota_dim0]], %[[reshape]], dim = 3 : (tensor<4x3x5x1xi32>, tensor<4x3x5x1xi32>, tensor<4x3x5x1xi32>) -> tensor<4x3x5x3xi32> -+// CHECK-NEXT: %[[gather:.*]] = "stablehlo.gather"(%arg0, %[[concat]]) <{ -+// CHECK-SAME: dimension_numbers = #stablehlo.gather< -+// CHECK-SAME: offset_dims = [3], collapsed_slice_dims = [0, 1, 2], -+// CHECK-SAME: start_index_map = [0, 2, 1], index_vector_dim = 3>, -+// CHECK-SAME: indices_are_sorted = false, -+// CHECK-SAME: slice_sizes = array -+// CHECK-SAME: }> : (tensor<3x2x4x9xi32>, tensor<4x3x5x3xi32>) -> tensor<4x3x5x8xi32> -+// CHECK-NEXT: return %[[gather]] : tensor<4x3x5x8xi32> -+func.func @gather_with_batching_no_index_vector_dim(%arg0: tensor<3x2x4x9xi32>, %arg1: tensor<4x3x5xi32>) -> tensor<4x3x5x8xi32> { -+ // CHECK-NO-DOWNGRADE: operand_batching_dims = [0, 2] -+ // CHECK-NO-DOWNGRADE: start_indices_batching_dims = [1, 0] -+ %0 = "stablehlo.gather"(%arg0, %arg1) <{ -+ dimension_numbers = #stablehlo.gather< -+ offset_dims = [3], -+ collapsed_slice_dims = [1], -+ operand_batching_dims = [0, 2], -+ start_indices_batching_dims = [1, 0], -+ start_index_map = [1], -+ index_vector_dim = 3 -+ >, -+ slice_sizes = array, -+ indices_are_sorted = true -+ }> : (tensor<3x2x4x9xi32>, tensor<4x3x5xi32>) -> tensor<4x3x5x8xi32> -+ func.return %0 : tensor<4x3x5x8xi32> -+} -+ -+// ----- -+ -+// CHECK-LABEL: @gather_with_batching_dim_size_zero -+// CHECK-NEXT: %[[iota:.*]] = stablehlo.iota dim = 0 : tensor<0x3x5x1xi32> -+// CHECK-NEXT: %[[concat:.*]] = stablehlo.concatenate %[[iota]], %arg1, dim = 3 : (tensor<0x3x5x1xi32>, tensor<0x3x5x1xi32>) -> tensor<0x3x5x2xi32> -+// CHECK-NEXT: %[[gather:.*]] = "stablehlo.gather"(%arg0, %[[concat]]) <{ -+// CHECK-SAME: dimension_numbers = #stablehlo.gather< -+// CHECK-SAME: offset_dims = [3], collapsed_slice_dims = [0, 1], -+// CHECK-SAME: start_index_map = [0, 1], index_vector_dim = 3>, -+// CHECK-SAME: indices_are_sorted = false, -+// CHECK-SAME: slice_sizes = array -+// CHECK-SAME: }> : (tensor<0x2x9xi32>, tensor<0x3x5x2xi32>) -> tensor<0x3x5x8xi32> -+// CHECK-NEXT: return %[[gather]] : tensor<0x3x5x8xi32> -+func.func @gather_with_batching_dim_size_zero(%arg0: tensor<0x2x9xi32>, %arg1: tensor<0x3x5x1xi32>) -> tensor<0x3x5x8xi32> { -+ // CHECK-NO-DOWNGRADE: operand_batching_dims = [0] -+ // CHECK-NO-DOWNGRADE: start_indices_batching_dims = [0] -+ %0 = "stablehlo.gather"(%arg0, %arg1) <{ -+ dimension_numbers = #stablehlo.gather< -+ offset_dims = [3], -+ collapsed_slice_dims = [1], -+ operand_batching_dims = [0], -+ start_indices_batching_dims = [0], -+ start_index_map = [1], -+ index_vector_dim = 3 -+ >, -+ slice_sizes = array, -+ indices_are_sorted = true -+ }> : (tensor<0x2x9xi32>, tensor<0x3x5x1xi32>) -> tensor<0x3x5x8xi32> -+ func.return %0 : tensor<0x3x5x8xi32> -+} -+ -+// ----- -+ -+// CHECK-LABEL: @scatter_with_batching_dims -+// CHECK-NEXT: %[[iota_dim1:.*]] = stablehlo.iota dim = 1 : tensor<4x3x5x1xi32> -+// CHECK-NEXT: %[[iota_dim0:.*]] = stablehlo.iota dim = 0 : tensor<4x3x5x1xi32> -+// CHECK-NEXT: %[[concat:.*]] = stablehlo.concatenate %[[iota_dim1]], %[[iota_dim0]], %arg1, dim = 3 : (tensor<4x3x5x1xi32>, tensor<4x3x5x1xi32>, tensor<4x3x5x2xi32>) -> tensor<4x3x5x4xi32> -+// CHECK-NEXT: %[[scatter:.*]] = "stablehlo.scatter"(%arg0, %[[concat]], %arg2) <{ -+// CHECK-SAME: indices_are_sorted = false, -+// CHECK-SAME: dimension_numbers = #stablehlo.scatter< -+// CHECK-SAME: update_window_dims = [3], inserted_window_dims = [0, 1, 2, 3], -+// CHECK-SAME: scatter_dims_to_operand_dims = [0, 2, 1, 3], index_vector_dim = 3>, -+// CHECK-SAME: unique_indices = false}> -+// CHECK: (tensor<3x2x4x7x9xi32>, tensor<4x3x5x4xi32>, tensor<4x3x5x8xi32>) -> tensor<3x2x4x7x9xi32> -+// CHECK-NEXT: return %[[scatter]] : tensor<3x2x4x7x9xi32> -+func.func @scatter_with_batching_dims(%arg0: tensor<3x2x4x7x9xi32>, %arg1: tensor<4x3x5x2xi32>, %arg2: tensor<4x3x5x8xi32>) -> tensor<3x2x4x7x9xi32> { -+ // CHECK-NO-DOWNGRADE: input_batching_dims = [0, 2] -+ // CHECK-NO-DOWNGRADE: scatter_indices_batching_dims = [1, 0] -+ %0 = "stablehlo.scatter"(%arg0, %arg1, %arg2) <{ -+ indices_are_sorted = true, -+ scatter_dimension_numbers = #stablehlo.scatter< -+ update_window_dims = [3], -+ inserted_window_dims = [1, 3], -+ input_batching_dims = [0, 2], -+ scatter_indices_batching_dims = [1, 0], -+ scatter_dims_to_operand_dims = [1, 3], -+ index_vector_dim = 3 -+ >, -+ unique_indices = false -+ }> ({ -+ ^bb0(%arg3: tensor, %arg4: tensor): -+ stablehlo.return %arg4 : tensor -+ }) : (tensor<3x2x4x7x9xi32>, tensor<4x3x5x2xi32>, tensor<4x3x5x8xi32>) -> tensor<3x2x4x7x9xi32> -+ func.return %0 : tensor<3x2x4x7x9xi32> -+} -+ -+// ----- -+ -+// CHECK-LABEL: @scatter_with_batching_no_index_vector_dim -+// CHECK-NEXT: %[[iota_dim1:.*]] = stablehlo.iota dim = 1 : tensor<4x3x5x1xi32> -+// CHECK-NEXT: %[[iota_dim0:.*]] = stablehlo.iota dim = 0 : tensor<4x3x5x1xi32> -+// CHECK-NEXT: %[[reshape:.*]] = stablehlo.reshape %arg1 : (tensor<4x3x5xi32>) -> tensor<4x3x5x1xi32> -+// CHECK-NEXT: %[[concat:.*]] = stablehlo.concatenate %[[iota_dim1]], %[[iota_dim0]], %[[reshape]], dim = 3 : (tensor<4x3x5x1xi32>, tensor<4x3x5x1xi32>, tensor<4x3x5x1xi32>) -> tensor<4x3x5x3xi32> -+// CHECK-NEXT: %[[scatter:.*]] = "stablehlo.scatter"(%arg0, %[[concat]], %arg2) <{ -+// CHECK-SAME: indices_are_sorted = false, -+// CHECK-SAME: dimension_numbers = #stablehlo.scatter< -+// CHECK-SAME: update_window_dims = [3], inserted_window_dims = [0, 1, 2], -+// CHECK-SAME: scatter_dims_to_operand_dims = [0, 2, 1], index_vector_dim = 3>, -+// CHECK-SAME: unique_indices = true}> -+// CHECK: (tensor<3x2x4x9xi32>, tensor<4x3x5x3xi32>, tensor<4x3x5x8xi32>) -> tensor<3x2x4x9xi32> -+// CHECK-NEXT: return %[[scatter]] : tensor<3x2x4x9xi32> -+func.func @scatter_with_batching_no_index_vector_dim(%arg0: tensor<3x2x4x9xi32>, %arg1: tensor<4x3x5xi32>, %arg2: tensor<4x3x5x8xi32>) -> tensor<3x2x4x9xi32> { -+ // CHECK-NO-DOWNGRADE: input_batching_dims = [0, 2] -+ // CHECK-NO-DOWNGRADE: scatter_indices_batching_dims = [1, 0] -+ %0 = "stablehlo.scatter"(%arg0, %arg1, %arg2) <{ -+ indices_are_sorted = true, -+ scatter_dimension_numbers = #stablehlo.scatter< -+ update_window_dims = [3], -+ inserted_window_dims = [1], -+ input_batching_dims = [0, 2], -+ scatter_indices_batching_dims = [1, 0], -+ scatter_dims_to_operand_dims = [1], -+ index_vector_dim = 3 -+ >, -+ unique_indices = true -+ }> ({ -+ ^bb0(%arg3: tensor, %arg4: tensor): -+ stablehlo.return %arg4 : tensor -+ }) : (tensor<3x2x4x9xi32>, tensor<4x3x5xi32>, tensor<4x3x5x8xi32>) -> tensor<3x2x4x9xi32> -+ func.return %0 : tensor<3x2x4x9xi32> -+} -diff --ruN a/stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpander.cpp b/stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpander.cpp ---- stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpander.cpp -+++ stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpander.cpp -@@ -12,14 +12,22 @@ - - #include - -+#include - #include -+#include -+#include -+#include - - #include "llvm/ADT/APFloat.h" -+#include "llvm/ADT/STLExtras.h" -+#include "llvm/ADT/SmallVector.h" - #include "llvm/Support/ErrorHandling.h" - #include "mlir/Dialect/Func/IR/FuncOps.h" - #include "mlir/IR/BuiltinAttributes.h" - #include "mlir/IR/BuiltinTypes.h" -+#include "mlir/IR/Diagnostics.h" - #include "mlir/IR/PatternMatch.h" -+#include "mlir/Rewrite/FrozenRewritePatternSet.h" - #include "mlir/Support/LLVM.h" - #include "mlir/Transforms/DialectConversion.h" - #include "mlir/Transforms/GreedyPatternRewriteDriver.h" -@@ -58,6 +66,132 @@ - return targetVersion; - } - -+SmallVector mergeSortedDims(ArrayRef dims1, -+ ArrayRef dims2) { -+ SmallVector result; -+ result.reserve(dims1.size() + dims2.size()); -+ std::merge(dims1.begin(), dims1.end(), dims2.begin(), dims2.end(), -+ std::back_inserter(result)); -+ return result; -+} -+ -+// Returns an updated indices tensor such that an `IotaOp` is prepended for each -+// dim in `indicesBatchingDims` with a `ConcatenateOp`. -+// -+// If `indexVectorDim` is equal to the rank of `indices`, it is reshaped to have -+// a trailing dimension of size 1 so it can be concatenated with the `IotaOp`s. -+Value createConcatIndices(Value indices, int64_t indexVectorDim, -+ ArrayRef indicesBatchingDims, -+ PatternRewriter &rewriter) { -+ Location loc = indices.getLoc(); -+ auto indicesType = cast(indices.getType()); -+ bool indexVectorDimOnLastDim = indexVectorDim == indicesType.getRank(); -+ -+ SmallVector iotaShape(indicesType.getShape()); -+ if (indexVectorDimOnLastDim) { -+ iotaShape.push_back(1); -+ } else { -+ iotaShape[indexVectorDim] = 1; -+ } -+ auto iotaType = -+ RankedTensorType::get(iotaShape, indicesType.getElementType()); -+ -+ SmallVector indicesToConcat; -+ indicesToConcat.reserve(indicesBatchingDims.size() + 1); -+ for (int64_t batchingDim : indicesBatchingDims) { -+ indicesToConcat.push_back( -+ rewriter.create(loc, iotaType, batchingDim)); -+ } -+ if (indexVectorDimOnLastDim) { -+ indicesToConcat.push_back( -+ rewriter.create(loc, iotaType, indices)); -+ } else { -+ indicesToConcat.push_back(indices); -+ } -+ return rewriter.create(loc, indicesToConcat, indexVectorDim); -+} -+ -+//===----------------------------------------------------------------------===// -+// Patterns (non DRR) -+//===----------------------------------------------------------------------===// -+ -+// Converts a `GatherOp` with batching dims to a `GatherOp` without batching -+// dims, such that each batching dim becomes a collapsed slice dim with a -+// corresponding `IotaOp` concatenated to the start indices. -+class GatherWithBatchingDimsExpander : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ -+ LogicalResult matchAndRewrite(GatherOp op, -+ PatternRewriter &rewriter) const override { -+ GatherDimensionNumbersAttr dimNumbers = op.getDimensionNumbers(); -+ ArrayRef operandBatchingDims = dimNumbers.getOperandBatchingDims(); -+ if (operandBatchingDims.empty()) { -+ return rewriter.notifyMatchFailure(op, [](Diagnostic &diag) { -+ diag << "gather op has no batching dims"; -+ }); -+ } -+ -+ SmallVector newCollapsedSliceDims = mergeSortedDims( -+ operandBatchingDims, dimNumbers.getCollapsedSliceDims()); -+ SmallVector newStartIndexMap = -+ llvm::to_vector(llvm::concat( -+ operandBatchingDims, dimNumbers.getStartIndexMap())); -+ Value newIndices = createConcatIndices( -+ op.getStartIndices(), dimNumbers.getIndexVectorDim(), -+ dimNumbers.getStartIndicesBatchingDims(), rewriter); -+ rewriter.replaceOpWithNewOp( -+ op, op.getOperand(), newIndices, -+ GatherDimensionNumbersAttr::get( -+ op.getContext(), dimNumbers.getOffsetDims(), newCollapsedSliceDims, -+ /*operandBatchingDims=*/{}, /*startIndicesBatchingDims=*/{}, -+ newStartIndexMap, dimNumbers.getIndexVectorDim()), -+ op.getSliceSizes(), /*indicesAreSorted=*/false); -+ -+ return success(); -+ } -+}; -+ -+// Converts a `ScatterOp` with batching dims to a `ScatterOp` without batching -+// dims, such that each batching dim becomes an inserted window dim with a -+// corresponding `IotaOp` concatenated to the scatter indices. -+class ScatterWithBatchingDimsExpander : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ -+ LogicalResult matchAndRewrite(ScatterOp op, -+ PatternRewriter &rewriter) const override { -+ ScatterDimensionNumbersAttr dimNumbers = op.getScatterDimensionNumbers(); -+ ArrayRef inputBatchingDims = dimNumbers.getInputBatchingDims(); -+ if (inputBatchingDims.empty()) { -+ return rewriter.notifyMatchFailure(op, [](Diagnostic &diag) { -+ diag << "scatter op has no batching dims"; -+ }); -+ } -+ -+ SmallVector newInsertedWindowDims = -+ mergeSortedDims(inputBatchingDims, dimNumbers.getInsertedWindowDims()); -+ SmallVector newScatterDimsToOperandDims = -+ llvm::to_vector(llvm::concat( -+ inputBatchingDims, dimNumbers.getScatterDimsToOperandDims())); -+ Value newIndices = createConcatIndices( -+ op.getScatterIndices(), dimNumbers.getIndexVectorDim(), -+ dimNumbers.getScatterIndicesBatchingDims(), rewriter); -+ auto newScatterOp = rewriter.create( -+ op.getLoc(), op->getResultTypes(), op.getInputs(), newIndices, -+ op.getUpdates(), -+ ScatterDimensionNumbersAttr::get( -+ op.getContext(), dimNumbers.getUpdateWindowDims(), -+ newInsertedWindowDims, -+ /*inputBatchingDims=*/{}, /*scatterIndicesBatchingDims=*/{}, -+ newScatterDimsToOperandDims, dimNumbers.getIndexVectorDim()), -+ /*indicesAreSorted=*/false, op.getUniqueIndices()); -+ -+ newScatterOp.getUpdateComputation().takeBody(op.getUpdateComputation()); -+ rewriter.replaceOp(op, newScatterOp.getResults()); -+ -+ return success(); -+ } -+}; -+ - //===----------------------------------------------------------------------===// - // Pass - //===----------------------------------------------------------------------===// -@@ -107,10 +241,16 @@ - void populateStablehloCreateCompatibilityExpanderPatterns( - RewritePatternSet *patterns, MLIRContext *context, - vhlo::Version targetVersion) { -+ // StableHLO GatherOp/ScatterOp with batching dims is introduced in v1.1.0. -+ if (targetVersion < vhlo::Version(1, 1, 0)) { -+ patterns -+ ->add( -+ context); -+ } - // StableHLO TanOp is introduced in v1.4.0. - if (targetVersion < vhlo::Version(1, 4, 0)) { -- patterns->add(context); -- patterns->add(context); -+ patterns->add(context); - } - } - diff --git a/third_party/xla/third_party/stablehlo/workspace.bzl b/third_party/xla/third_party/stablehlo/workspace.bzl index 97fd0b990fc1c7..61608344c772f9 100644 --- a/third_party/xla/third_party/stablehlo/workspace.bzl +++ b/third_party/xla/third_party/stablehlo/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): # LINT.IfChange - STABLEHLO_COMMIT = "78c753ad13ad8205cacc5fcc12418c1ac97276c7" - STABLEHLO_SHA256 = "b7fef892020eb465a6d1ed921160f5229398ba10acff36b6345171b9867ccc7c" + STABLEHLO_COMMIT = "9bb28f84c281795783639364b727e4398dcec570" + STABLEHLO_SHA256 = "44f90a3b6e8c7fba454644a7457b71327f643196ee1b1f69a3b210604ea8b5a9" # LINT.ThenChange(Google-internal path) tf_http_archive( From 2150a5924b91415e7f7d5729e18051f62801573a Mon Sep 17 00:00:00 2001 From: Sandeep Dasgupta Date: Fri, 20 Sep 2024 17:24:23 -0700 Subject: [PATCH 085/483] [HLO Componentization] Remove spurious dependencies from hlo parser. PiperOrigin-RevId: 677024981 --- third_party/xla/xla/service/BUILD | 3 ++- third_party/xla/xla/service/hlo_lexer.cc | 7 ++++++- third_party/xla/xla/service/hlo_lexer.h | 3 +-- 3 files changed, 9 insertions(+), 4 deletions(-) diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD index c307f128ca1342..8216d818fb8085 100644 --- a/third_party/xla/xla/service/BUILD +++ b/third_party/xla/xla/service/BUILD @@ -7051,10 +7051,11 @@ cc_library( ], deps = [ "//xla:shape_util", - "//xla:types", "//xla:util", "//xla:xla_data_proto_cc", "@com_google_absl//absl/base", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@local_tsl//tsl/platform:logging", diff --git a/third_party/xla/xla/service/hlo_lexer.cc b/third_party/xla/xla/service/hlo_lexer.cc index 18ca3fcb775277..546e4989a380fe 100644 --- a/third_party/xla/xla/service/hlo_lexer.cc +++ b/third_party/xla/xla/service/hlo_lexer.cc @@ -15,17 +15,22 @@ limitations under the License. #include "xla/service/hlo_lexer.h" +#include #include -#include #include #include #include #include "absl/base/casts.h" +#include "absl/log/check.h" +#include "absl/log/log.h" #include "absl/status/statusor.h" #include "absl/strings/ascii.h" #include "absl/strings/escaping.h" +#include "absl/strings/match.h" #include "absl/strings/numbers.h" +#include "absl/strings/string_view.h" +#include "xla/primitive_util.h" #include "xla/util.h" #include "tsl/platform/numbers.h" diff --git a/third_party/xla/xla/service/hlo_lexer.h b/third_party/xla/xla/service/hlo_lexer.h index 8a7547ff679834..7f6346da558046 100644 --- a/third_party/xla/xla/service/hlo_lexer.h +++ b/third_party/xla/xla/service/hlo_lexer.h @@ -16,13 +16,12 @@ limitations under the License. #ifndef XLA_SERVICE_HLO_LEXER_H_ #define XLA_SERVICE_HLO_LEXER_H_ +#include #include #include #include #include "absl/strings/string_view.h" -#include "xla/shape.h" -#include "xla/types.h" #include "xla/xla_data.pb.h" #include "tsl/platform/logging.h" #include "tsl/platform/regexp.h" From 33054e3206755c54dbe1aacb89d2fb39b29389c4 Mon Sep 17 00:00:00 2001 From: Ilia Sergachev Date: Fri, 20 Sep 2024 18:39:50 -0700 Subject: [PATCH 086/483] [GPU][NFC] Add CUDA 12.6.1 redistribution. This closes https://github.com/openxla/xla/pull/17003 PiperOrigin-RevId: 677042671 --- third_party/gpus/cuda/hermetic/cuda_redist_versions.bzl | 5 +++++ .../third_party/gpus/cuda/hermetic/cuda_redist_versions.bzl | 5 +++++ 2 files changed, 10 insertions(+) diff --git a/third_party/gpus/cuda/hermetic/cuda_redist_versions.bzl b/third_party/gpus/cuda/hermetic/cuda_redist_versions.bzl index 6934b75b47852d..417a237053c8e7 100644 --- a/third_party/gpus/cuda/hermetic/cuda_redist_versions.bzl +++ b/third_party/gpus/cuda/hermetic/cuda_redist_versions.bzl @@ -58,6 +58,10 @@ CUDA_REDIST_JSON_DICT = { "https://developer.download.nvidia.com/compute/cuda/redist/redistrib_12.6.0.json", "87740b01676b3d18982982ab96ec7fa1a626d03a96df070a6b0f258d01ff5fab", ], + "12.6.1": [ + "https://developer.download.nvidia.com/compute/cuda/redist/redistrib_12.6.1.json", + "22ddfeb81a6f9cee4a708a2e3b4db1c36c7db0a1daa1f33f9c7f2f12a1e790de", + ], } CUDNN_REDIST_JSON_DICT = { @@ -134,6 +138,7 @@ CUDA_NCCL_WHEELS = { "12.5.0": CUDA_12_NCCL_WHEEL_DICT, "12.5.1": CUDA_12_NCCL_WHEEL_DICT, "12.6.0": CUDA_12_NCCL_WHEEL_DICT, + "12.6.1": CUDA_12_NCCL_WHEEL_DICT, } REDIST_VERSIONS_TO_BUILD_TEMPLATES = { diff --git a/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_redist_versions.bzl b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_redist_versions.bzl index 6934b75b47852d..417a237053c8e7 100644 --- a/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_redist_versions.bzl +++ b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_redist_versions.bzl @@ -58,6 +58,10 @@ CUDA_REDIST_JSON_DICT = { "https://developer.download.nvidia.com/compute/cuda/redist/redistrib_12.6.0.json", "87740b01676b3d18982982ab96ec7fa1a626d03a96df070a6b0f258d01ff5fab", ], + "12.6.1": [ + "https://developer.download.nvidia.com/compute/cuda/redist/redistrib_12.6.1.json", + "22ddfeb81a6f9cee4a708a2e3b4db1c36c7db0a1daa1f33f9c7f2f12a1e790de", + ], } CUDNN_REDIST_JSON_DICT = { @@ -134,6 +138,7 @@ CUDA_NCCL_WHEELS = { "12.5.0": CUDA_12_NCCL_WHEEL_DICT, "12.5.1": CUDA_12_NCCL_WHEEL_DICT, "12.6.0": CUDA_12_NCCL_WHEEL_DICT, + "12.6.1": CUDA_12_NCCL_WHEEL_DICT, } REDIST_VERSIONS_TO_BUILD_TEMPLATES = { From d039edee6513d7520f3ce8de59776c3da31e71c4 Mon Sep 17 00:00:00 2001 From: Zixuan Jiang Date: Fri, 20 Sep 2024 19:12:08 -0700 Subject: [PATCH 087/483] [xla:SpmdPartitioner] Support partitioning along the explicit batch dimensions in gather instructions. Explicit batch dimensions are added for gather instructions in https://github.com/openxla/stablehlo/pull/2084. This cl allows us partitioning along these explicit batch dimensions. Before this cl, we already have `PartitionGatherIndexParallelDimensions`, where the index parallel dimensions are implicit batch dimensions (the indices are usually concatenation of iota). We reuse most of the code in this function and implement `PartitionGatherExplicitBatchDimensions`. PiperOrigin-RevId: 677049477 --- .../xla/xla/hlo/utils/hlo_sharding_util.cc | 155 ++++----- .../service/spmd/gather_scatter_handler.cc | 317 +++++++++++------- .../xla/xla/service/spmd/spmd_partitioner.h | 1 + .../xla/service/spmd/spmd_partitioner_test.cc | 67 +++- 4 files changed, 317 insertions(+), 223 deletions(-) diff --git a/third_party/xla/xla/hlo/utils/hlo_sharding_util.cc b/third_party/xla/xla/hlo/utils/hlo_sharding_util.cc index 881744646f0baa..aa7a485cbc685d 100644 --- a/third_party/xla/xla/hlo/utils/hlo_sharding_util.cc +++ b/third_party/xla/xla/hlo/utils/hlo_sharding_util.cc @@ -1369,22 +1369,24 @@ namespace { absl::InlinedVector GetGatherScatterOperandPassthroughOperandDims( const Shape& operand_shape, absl::Span collapsed_or_inserted_dims, + absl::Span operand_batching_dims, absl::Span index_map, absl::Span offset_or_window_dims, absl::Span slice_size) { absl::InlinedVector passthrough_dims; - int64_t collapsed = 0; - for (int64_t i = 0; i != operand_shape.rank(); ++i) { - if (absl::c_linear_search(collapsed_or_inserted_dims, i)) { - collapsed++; + int64_t collapsed_or_batching = 0; + for (int64_t i = 0; i < operand_shape.rank(); ++i) { + if (absl::c_linear_search(collapsed_or_inserted_dims, i) || + absl::c_linear_search(operand_batching_dims, i)) { + collapsed_or_batching++; continue; } if (slice_size[i] != operand_shape.dimensions(i)) { continue; } - int64_t offset_dim = offset_or_window_dims[i - collapsed]; - if (i - collapsed > 0 && - offset_dim < offset_or_window_dims[i - collapsed - 1]) { + if (i - collapsed_or_batching > 0 && + offset_or_window_dims[i - collapsed_or_batching] < + offset_or_window_dims[i - collapsed_or_batching - 1]) { // Output offsets are transposed, we do not support this case. continue; } @@ -1397,22 +1399,25 @@ absl::InlinedVector GetGatherScatterOperandPassthroughOutputOrUpdateDims( const int64_t output_or_update_rank, const Shape& operand_shape, absl::Span collapsed_or_inserted_dims, + absl::Span operand_batching_dims, absl::Span index_map, absl::Span offset_or_window_dims, absl::Span slice_size) { auto operand_passthrough_dims = GetGatherScatterOperandPassthroughOperandDims( - operand_shape, collapsed_or_inserted_dims, index_map, - offset_or_window_dims, slice_size); + operand_shape, collapsed_or_inserted_dims, operand_batching_dims, + index_map, offset_or_window_dims, slice_size); absl::InlinedVector passthrough_dims; - int64_t collapsed = 0; + int64_t collapsed_or_batching = 0; for (int64_t i = 0; i < operand_shape.rank(); ++i) { - if (absl::c_linear_search(collapsed_or_inserted_dims, i)) { - collapsed++; + if (absl::c_linear_search(collapsed_or_inserted_dims, i) || + absl::c_linear_search(operand_batching_dims, i)) { + collapsed_or_batching++; + continue; } if (!absl::c_linear_search(operand_passthrough_dims, i)) { continue; } - int64_t offset_dim = offset_or_window_dims[i - collapsed]; + int64_t offset_dim = offset_or_window_dims[i - collapsed_or_batching]; passthrough_dims.push_back(offset_dim); } return passthrough_dims; @@ -1426,6 +1431,7 @@ std::optional PassthroughOperandToGatherOutputOrScatterUpdate( const Shape& operand_shape, const HloSharding& operand_sharding, const int64_t output_or_update_rank, absl::Span collapsed_or_inserted_dims, + absl::Span operand_batching_dims, absl::Span index_map, absl::Span offset_or_window_dims, absl::Span slice_size, const int64_t index_vector_dim) { @@ -1433,18 +1439,20 @@ std::optional PassthroughOperandToGatherOutputOrScatterUpdate( return std::nullopt; } auto operand_passthrough_dims = GetGatherScatterOperandPassthroughOperandDims( - operand_shape, collapsed_or_inserted_dims, index_map, - offset_or_window_dims, slice_size); + operand_shape, collapsed_or_inserted_dims, operand_batching_dims, + index_map, offset_or_window_dims, slice_size); DimensionVector passthrough_tile(output_or_update_rank, 1); - int64_t collapsed = 0; + int64_t collapsed_or_batching = 0; for (int64_t i = 0; i < operand_shape.rank(); ++i) { - if (absl::c_linear_search(collapsed_or_inserted_dims, i)) { - collapsed++; + if (absl::c_linear_search(collapsed_or_inserted_dims, i) || + absl::c_linear_search(operand_batching_dims, i)) { + collapsed_or_batching++; + continue; } if (!absl::c_linear_search(operand_passthrough_dims, i)) { continue; } - int64_t offset_dim = offset_or_window_dims[i - collapsed]; + int64_t offset_dim = offset_or_window_dims[i - collapsed_or_batching]; passthrough_tile[offset_dim] = operand_sharding.tile_assignment().dim(i); } HloSharding replicate_non_passthrough_dims = @@ -1475,6 +1483,7 @@ std::optional PassthroughOperandToGatherOutputOrScatterUpdate( std::optional PassthroughGatherOutputOrScatterUpdateToOperand( const Shape& operand_shape, const HloSharding& output_or_update_sharding, absl::Span collapsed_or_inserted_dims, + absl::Span operand_batching_dims, absl::Span index_map, absl::Span offset_or_window_dims, absl::Span slice_size) { @@ -1483,20 +1492,22 @@ std::optional PassthroughGatherOutputOrScatterUpdateToOperand( return output_or_update_sharding; } auto operand_passthrough_dims = GetGatherScatterOperandPassthroughOperandDims( - operand_shape, collapsed_or_inserted_dims, index_map, - offset_or_window_dims, slice_size); + operand_shape, collapsed_or_inserted_dims, operand_batching_dims, + index_map, offset_or_window_dims, slice_size); DimensionVector passthrough_tile(operand_shape.rank(), 1); - int64_t collapsed = 0; + int64_t collapsed_or_batching = 0; // Relevant dims have shardings passed to the operand. DimensionVector relevant_output_or_update_dims; for (int64_t i = 0; i < operand_shape.rank(); ++i) { - if (absl::c_linear_search(collapsed_or_inserted_dims, i)) { - collapsed++; + if (absl::c_linear_search(collapsed_or_inserted_dims, i) || + absl::c_linear_search(operand_batching_dims, i)) { + collapsed_or_batching++; + continue; } if (!absl::c_linear_search(operand_passthrough_dims, i)) { continue; } - int64_t offset_dim = offset_or_window_dims[i - collapsed]; + int64_t offset_dim = offset_or_window_dims[i - collapsed_or_batching]; passthrough_tile[i] = output_or_update_sharding.tile_assignment().dim(offset_dim); relevant_output_or_update_dims.push_back(offset_dim); @@ -1609,27 +1620,17 @@ GatherOutputShardingFromOperandOperandPassthroughDimensions( const Shape& operand_shape, const HloSharding& operand_sharding, const HloInstruction& hlo, absl::Span slice_sizes) { const auto& dnums = hlo.gather_dimension_numbers(); - std::vector collapsed_slice_dims( - dnums.collapsed_slice_dims().begin(), dnums.collapsed_slice_dims().end()); - std::vector start_index_map(dnums.start_index_map().begin(), - dnums.start_index_map().end()); - std::vector offset_dims(dnums.offset_dims().begin(), - dnums.offset_dims().end()); return PassthroughOperandToGatherOutputOrScatterUpdate( - operand_shape, operand_sharding, hlo.shape().rank(), collapsed_slice_dims, - start_index_map, offset_dims, slice_sizes, dnums.index_vector_dim()); + operand_shape, operand_sharding, hlo.shape().rank(), + dnums.collapsed_slice_dims(), dnums.operand_batching_dims(), + dnums.start_index_map(), dnums.offset_dims(), slice_sizes, + dnums.index_vector_dim()); } std::optional GatherOperandShardingFromOutput( const HloSharding& output_sharding, const HloInstruction& hlo, const CallGraph& call_graph) { const auto& dnums = hlo.gather_dimension_numbers(); - std::vector collapsed_slice_dims( - dnums.collapsed_slice_dims().begin(), dnums.collapsed_slice_dims().end()); - std::vector start_index_map(dnums.start_index_map().begin(), - dnums.start_index_map().end()); - std::vector offset_dims(dnums.offset_dims().begin(), - dnums.offset_dims().end()); // Prioritize parallel sharding first as this is how it is in // spmd_partitioner. std::optional parallel_sharding = @@ -1637,8 +1638,10 @@ std::optional GatherOperandShardingFromOutput( call_graph); std::optional passthrough_sharding = PassthroughGatherOutputOrScatterUpdateToOperand( - hlo.operand(0)->shape(), output_sharding, collapsed_slice_dims, - start_index_map, offset_dims, hlo.gather_slice_sizes()); + hlo.operand(0)->shape(), output_sharding, + dnums.collapsed_slice_dims(), dnums.operand_batching_dims(), + dnums.start_index_map(), dnums.offset_dims(), + hlo.gather_slice_sizes()); // Try to merge the two shardings or return the one that is present if only // one of the two is. if (!passthrough_sharding) { @@ -1677,19 +1680,13 @@ std::vector GetScatterSliceSize(const Shape& operand_shape, std::optional ScatterOutputShardingFromUpdate( const HloSharding& update_sharding, const HloScatterInstruction& scatter) { const auto& dnums = scatter.scatter_dimension_numbers(); - std::vector inserted_window_dims( - dnums.inserted_window_dims().begin(), dnums.inserted_window_dims().end()); - std::vector scatter_dims_to_operand_dims( - dnums.scatter_dims_to_operand_dims().begin(), - dnums.scatter_dims_to_operand_dims().end()); - std::vector update_window_dims(dnums.update_window_dims().begin(), - dnums.update_window_dims().end()); std::vector slice_size = GetScatterSliceSize(scatter.scatter_operands()[0]->shape(), scatter.scatter_updates()[0]->shape(), dnums); return PassthroughGatherOutputOrScatterUpdateToOperand( scatter.scatter_operands()[0]->shape(), update_sharding, - inserted_window_dims, scatter_dims_to_operand_dims, update_window_dims, + dnums.inserted_window_dims(), dnums.input_batching_dims(), + dnums.scatter_dims_to_operand_dims(), dnums.update_window_dims(), slice_size); } @@ -1744,18 +1741,12 @@ ScatterUpdateShardingFromOutputOperandPassthroughDimensions( const HloScatterInstruction* scatter = DynCast(&hlo); CHECK(scatter); const auto& dnums = scatter->scatter_dimension_numbers(); - std::vector inserted_window_dims( - dnums.inserted_window_dims().begin(), dnums.inserted_window_dims().end()); - std::vector scatter_dims_to_operand_dims( - dnums.scatter_dims_to_operand_dims().begin(), - dnums.scatter_dims_to_operand_dims().end()); - std::vector update_window_dims(dnums.update_window_dims().begin(), - dnums.update_window_dims().end()); return PassthroughOperandToGatherOutputOrScatterUpdate( output_shape, output_sharding, - scatter->scatter_updates()[0]->shape().rank(), inserted_window_dims, - scatter_dims_to_operand_dims, update_window_dims, slice_sizes, - dnums.index_vector_dim()); + scatter->scatter_updates()[0]->shape().rank(), + dnums.inserted_window_dims(), dnums.input_batching_dims(), + dnums.scatter_dims_to_operand_dims(), dnums.update_window_dims(), + slice_sizes, dnums.index_vector_dim()); } std::optional ScatterUpdateShardingFromOutputParallelDimensions( @@ -2404,46 +2395,30 @@ absl::InlinedVector GetGatherOperandPassthroughOperandDims( const Shape& operand_shape, const HloInstruction& hlo, absl::Span slice_sizes) { const auto& dnums = hlo.gather_dimension_numbers(); - std::vector collapsed_slice_dims( - dnums.collapsed_slice_dims().begin(), dnums.collapsed_slice_dims().end()); - std::vector start_index_map(dnums.start_index_map().begin(), - dnums.start_index_map().end()); - std::vector offset_dims(dnums.offset_dims().begin(), - dnums.offset_dims().end()); return GetGatherScatterOperandPassthroughOperandDims( - operand_shape, collapsed_slice_dims, start_index_map, offset_dims, - slice_sizes); + operand_shape, dnums.collapsed_slice_dims(), + dnums.operand_batching_dims(), dnums.start_index_map(), + dnums.offset_dims(), slice_sizes); } absl::InlinedVector GetScatterOperandPassthroughOperandDims( const Shape& operand_shape, const HloSharding& operand_sharding, const HloInstruction& hlo, absl::Span slice_sizes) { const auto& dnums = hlo.scatter_dimension_numbers(); - std::vector inserted_window_dims( - dnums.inserted_window_dims().begin(), dnums.inserted_window_dims().end()); - std::vector scatter_dims_to_operand_dims( - dnums.scatter_dims_to_operand_dims().begin(), - dnums.scatter_dims_to_operand_dims().end()); - std::vector update_window_dims(dnums.update_window_dims().begin(), - dnums.update_window_dims().end()); return GetGatherScatterOperandPassthroughOperandDims( - operand_shape, inserted_window_dims, scatter_dims_to_operand_dims, - update_window_dims, slice_sizes); + operand_shape, dnums.inserted_window_dims(), dnums.input_batching_dims(), + dnums.scatter_dims_to_operand_dims(), dnums.update_window_dims(), + slice_sizes); } absl::InlinedVector GetGatherOperandPassthroughOutputDims( const Shape& output_shape, const Shape& operand_shape, const HloInstruction& hlo, absl::Span slice_sizes) { const auto& dnums = hlo.gather_dimension_numbers(); - std::vector collapsed_slice_dims( - dnums.collapsed_slice_dims().begin(), dnums.collapsed_slice_dims().end()); - std::vector start_index_map(dnums.start_index_map().begin(), - dnums.start_index_map().end()); - std::vector offset_dims(dnums.offset_dims().begin(), - dnums.offset_dims().end()); return GetGatherScatterOperandPassthroughOutputOrUpdateDims( - output_shape.rank(), operand_shape, collapsed_slice_dims, start_index_map, - offset_dims, slice_sizes); + output_shape.rank(), operand_shape, dnums.collapsed_slice_dims(), + dnums.operand_batching_dims(), dnums.start_index_map(), + dnums.offset_dims(), slice_sizes); } absl::InlinedVector GetScatterOperandPassthroughUpdateDims( @@ -2451,16 +2426,10 @@ absl::InlinedVector GetScatterOperandPassthroughUpdateDims( const HloSharding& operand_sharding, const HloInstruction& hlo, absl::Span slice_sizes) { const auto& dnums = hlo.scatter_dimension_numbers(); - std::vector inserted_window_dims( - dnums.inserted_window_dims().begin(), dnums.inserted_window_dims().end()); - std::vector scatter_dims_to_operand_dims( - dnums.scatter_dims_to_operand_dims().begin(), - dnums.scatter_dims_to_operand_dims().end()); - std::vector update_window_dims(dnums.update_window_dims().begin(), - dnums.update_window_dims().end()); return GetGatherScatterOperandPassthroughOutputOrUpdateDims( - update_shape.rank(), operand_shape, inserted_window_dims, - scatter_dims_to_operand_dims, update_window_dims, slice_sizes); + update_shape.rank(), operand_shape, dnums.inserted_window_dims(), + dnums.input_batching_dims(), dnums.scatter_dims_to_operand_dims(), + dnums.update_window_dims(), slice_sizes); } absl::InlinedVector GetGatherScatterIndexPassthroughIndexDims( diff --git a/third_party/xla/xla/service/spmd/gather_scatter_handler.cc b/third_party/xla/xla/service/spmd/gather_scatter_handler.cc index fd45e649f18cd2..7bac1768aebf30 100644 --- a/third_party/xla/xla/service/spmd/gather_scatter_handler.cc +++ b/third_party/xla/xla/service/spmd/gather_scatter_handler.cc @@ -13,10 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include #include #include #include +#include +#include #include #include @@ -266,6 +267,10 @@ absl::StatusOr PartitionGatherIndexPassthroughDimensions( const HloSharding& output_sharding, absl::Span batch_dims, absl::Span slice_sizes, SpmdPartitioningVisitor* visitor, bool allow_recursive) { + if (indices.sharding().IsTileMaximal()) { + return nullptr; + } + // Perform clean up actions upon exiting function scope. absl::InlinedVector, 3> clean_ups; absl::Cleanup cleaner = [&clean_ups] { @@ -274,7 +279,7 @@ absl::StatusOr PartitionGatherIndexPassthroughDimensions( } }; - GatherDimensionNumbers dnums = gather->gather_dimension_numbers(); + const GatherDimensionNumbers& dnums = gather->gather_dimension_numbers(); SpmdBuilder* b = visitor->builder(); absl::InlinedVector index_group_dims = hlo_sharding_util::GetGatherScatterIndexPassthroughIndexDims( @@ -300,10 +305,6 @@ absl::StatusOr PartitionGatherIndexPassthroughDimensions( AlignGroupsWith(hlo_sharding_util::GroupShardingOnDims(indices.sharding(), index_group_dims), output_grouped); - - if (indices.sharding().IsTileMaximal()) { - return nullptr; - } // See if we can group partially replicated dimensions from the operand // otherwise replicate it. const GroupedSharding operand_grouped = AlignGroupsWith( @@ -430,7 +431,7 @@ absl::StatusOr PartitionGatherTrivialSlicedOperandDimensions( }; SpmdBuilder* b = visitor->builder(); - GatherDimensionNumbers dnums = gather->gather_dimension_numbers(); + const GatherDimensionNumbers& dnums = gather->gather_dimension_numbers(); std::vector start_index_map(dnums.start_index_map().begin(), dnums.start_index_map().end()); if (std::optional> trivial_slice_dims = @@ -579,16 +580,14 @@ absl::StatusOr PartitionGatherTrivialSlicedOperandDimensions( return nullptr; } -// Partition a gather over a indices dimensions that are cosidered parallel -// (which means that the indices access the operand in a monotonically -// increasing way across the respective operand dimension referenced by the -// index). -absl::StatusOr PartitionGatherIndexParallelDimensions( +absl::StatusOr PartitionGatherBatchDimensions( const HloGatherInstruction* gather, PartitionedHlo operand, PartitionedHlo indices, const Shape& output_shape, const HloSharding& output_sharding, absl::Span batch_dims, absl::Span slice_sizes, SpmdPartitioningVisitor* visitor, - bool allow_recursive) { + bool allow_recursive, + const hlo_sharding_util::GatherScatterParallelDims& parallel_dims, + bool need_offset) { // Perform clean up actions upon exiting function scope. absl::InlinedVector, 3> clean_ups; absl::Cleanup cleaner = [&clean_ups] { @@ -597,126 +596,184 @@ absl::StatusOr PartitionGatherIndexParallelDimensions( } }; + auto gather_sharding = GatherScatterOperandsShardedAcrossParallelDims( + *operand.hlo(), *indices.hlo(), parallel_dims); + if (!gather_sharding.has_value()) { + return nullptr; + } + SpmdBuilder* b = visitor->builder(); - GatherDimensionNumbers dnums = gather->gather_dimension_numbers(); + const GatherDimensionNumbers& dnums = gather->gather_dimension_numbers(); const int64_t index_dim = dnums.index_vector_dim(); - // Handle the case where operand is tile maximal. In this case we check if - // the index is not TileMaximal and in this case we use the index sharding - // to drive the output sharding. - if (std::optional - parallel_dims = hlo_sharding_util::GetGatherParallelBatchDims( - *gather, visitor->call_graph())) { - if (auto gather_sharding = GatherScatterOperandsShardedAcrossParallelDims( - *operand.hlo(), *indices.hlo(), *parallel_dims)) { - const auto indices_parallel_dims = parallel_dims->indices_parallel_dims; - const auto operand_parallel_dims = parallel_dims->operand_parallel_dims; - const auto output_parallel_dims = - hlo_sharding_util::GetGatherParallelOutputDims(*gather, - *parallel_dims); - operand = operand.Reshard(gather_sharding->operand_sharding); - indices = indices.Reshard(gather_sharding->indices_sharding); - HloSharding gather_output_sharding = hlo_sharding_util:: - GatherOutputOrScatterUpdateShardingFromIndicesParallelDimensions( - indices.sharding(), output_shape.rank(), indices_parallel_dims, - output_parallel_dims); - // Refine output sharding from the operand. it should be inferred from - // operand sharding, so that the partitioned gather can be either 1) - // directly created on the partitioned operand, or 2) recursively created - // without aligning the groups. - if (auto maybe_passthrough = hlo_sharding_util:: - GatherOutputShardingFromOperandOperandPassthroughDimensions( - operand.base_shape(), - hlo_sharding_util::PartiallyReplicateTiledShardingOnDims( - operand.sharding(), operand_parallel_dims), - *gather, slice_sizes)) { - hlo_sharding_util::MergeShardingIfCompatible(*maybe_passthrough, - &gather_output_sharding); - } - // Construct the offsets for the operand sharding to be used to adjust - // the indices. Because we know the only dimensions partitioned are the - // parallel ones and because the partitioning is the same across indices - // and operands we can apply the offsets on the operands on the indices. - std::vector operand_offsets = MakePartitionOffsets( - operand.base_shape(), operand.sharding(), - operand.state().partition_id, b, operand_parallel_dims); - absl::InlinedVector index_offsets; - for (int start_idx = 0; start_idx < dnums.start_index_map_size(); - ++start_idx) { - HloInstruction* index_offset = - indices.rank() > index_dim - ? b->AddInstruction(HloInstruction::CreateReshape( - ShapeUtil::MakeShape(S32, {1}), - operand_offsets[dnums.start_index_map(start_idx)])) - : operand_offsets[dnums.start_index_map(start_idx)]; - index_offsets.push_back(index_offset); - } - HloInstruction* adjusted_indices = nullptr; - if (indices.rank() > index_dim) { - // Concatenate the offsets for the parallel dimensions to subtract. - adjusted_indices = b->AddInstruction(HloInstruction::CreateConcatenate( - ShapeUtil::MakeShape(S32, - {indices.base_shape().dimensions(index_dim)}), - index_offsets, 0)); - } else { - CHECK_EQ(index_offsets.size(), 1); - adjusted_indices = index_offsets[0]; - } - if (indices.hlo()->shape().element_type() != PrimitiveType::S32) { - adjusted_indices = b->AddInstruction(HloInstruction::CreateConvert( - ShapeUtil::ChangeElementType(adjusted_indices->shape(), - indices.hlo()->shape().element_type()), - adjusted_indices)); - } - if (adjusted_indices->shape().rank() == 0) { - adjusted_indices = b->AddInstruction(HloInstruction::CreateBroadcast( - indices.hlo()->shape(), adjusted_indices, {})); - } else { - adjusted_indices = b->AddInstruction(HloInstruction::CreateBroadcast( - indices.hlo()->shape(), adjusted_indices, {index_dim})); - } - // Adjust indices by subtracting the offsets based on the partition id. - adjusted_indices = b->AddInstruction(HloInstruction::CreateBinary( - indices.hlo()->shape(), HloOpcode::kSubtract, indices.hlo(), + const auto& indices_parallel_dims = parallel_dims.indices_parallel_dims; + const auto& operand_parallel_dims = parallel_dims.operand_parallel_dims; + const auto output_parallel_dims = + hlo_sharding_util::GetGatherParallelOutputDims(*gather, parallel_dims); + operand = operand.Reshard(gather_sharding->operand_sharding); + indices = indices.Reshard(gather_sharding->indices_sharding); + HloSharding gather_output_sharding = hlo_sharding_util:: + GatherOutputOrScatterUpdateShardingFromIndicesParallelDimensions( + indices.sharding(), output_shape.rank(), indices_parallel_dims, + output_parallel_dims); + if (!need_offset) { + hlo_sharding_util::MergeShardingIfCompatible( + hlo_sharding_util:: + GatherOutputShardingFromIndexIndexPassthroughDimensions( + indices.sharding(), gather), + &gather_output_sharding); + } + + // Refine output sharding from the operand. it should be inferred from + // operand sharding, so that the partitioned gather can be either 1) + // directly created on the partitioned operand, or 2) recursively created + // without aligning the groups. + if (auto maybe_passthrough = hlo_sharding_util:: + GatherOutputShardingFromOperandOperandPassthroughDimensions( + operand.base_shape(), + hlo_sharding_util::PartiallyReplicateTiledShardingOnDims( + operand.sharding(), operand_parallel_dims), + *gather, slice_sizes)) { + hlo_sharding_util::MergeShardingIfCompatible(*maybe_passthrough, + &gather_output_sharding); + } + + // Construct the offsets for the operand sharding to be used to adjust + // the indices. Because we know the only dimensions partitioned are the + // parallel ones and because the partitioning is the same across indices + // and operands we can apply the offsets on the operands on the indices. + PartitionedHlo new_indices = indices; + if (need_offset) { + std::vector operand_offsets = MakePartitionOffsets( + operand.base_shape(), operand.sharding(), operand.state().partition_id, + b, operand_parallel_dims); + absl::InlinedVector index_offsets; + for (int start_idx = 0; start_idx < dnums.start_index_map_size(); + ++start_idx) { + HloInstruction* index_offset = + indices.rank() > index_dim + ? b->AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(S32, {1}), + operand_offsets[dnums.start_index_map(start_idx)])) + : operand_offsets[dnums.start_index_map(start_idx)]; + index_offsets.push_back(index_offset); + } + HloInstruction* adjusted_indices = nullptr; + if (indices.rank() > index_dim) { + // Concatenate the offsets for the parallel dimensions to subtract. + adjusted_indices = b->AddInstruction(HloInstruction::CreateConcatenate( + ShapeUtil::MakeShape(S32, + {indices.base_shape().dimensions(index_dim)}), + index_offsets, 0)); + } else { + CHECK_EQ(index_offsets.size(), 1); + adjusted_indices = index_offsets[0]; + } + if (indices.hlo()->shape().element_type() != PrimitiveType::S32) { + adjusted_indices = b->AddInstruction(HloInstruction::CreateConvert( + ShapeUtil::ChangeElementType(adjusted_indices->shape(), + indices.hlo()->shape().element_type()), adjusted_indices)); - PartitionedHlo new_indices = indices.CloneWithNewHlo(adjusted_indices); - const GroupedSharding new_indices_grouped = - hlo_sharding_util::GroupShardingOnDims(new_indices.sharding(), - indices_parallel_dims); - const GroupedSharding operand_grouped = - AlignGroupsWith(hlo_sharding_util::GroupShardingOnDims( - operand.sharding(), operand_parallel_dims), - new_indices_grouped); - const GroupedSharding output_grouped = - AlignGroupsWith(hlo_sharding_util::GroupShardingOnDims( - gather_output_sharding, output_parallel_dims), - new_indices_grouped); - PartitionedHlo per_group_operand = - PerGroupPartitionedHlo(operand, operand_grouped, b, clean_ups); - PartitionedHlo per_group_new_indices = PerGroupPartitionedHlo( - new_indices, new_indices_grouped, b, clean_ups); - const Shape pshape = GetPerGroupBaseShape(output_grouped, output_shape); - TF_ASSIGN_OR_RETURN( - HloInstruction * pgather, - PartitionGather(gather, per_group_operand, per_group_new_indices, - pshape, output_grouped.sharding, batch_dims, - slice_sizes, visitor, allow_recursive)); - if (allow_recursive) { - VLOG(5) << "[Gather partitioning]: Partitioned as parallel batch_dim"; - } - pgather->set_sharding(hlo_sharding_util::UngroupSharding(output_grouped)); - return PartitionedHlo(pgather, output_shape, operand.state()) - .Reshard(output_sharding) - .hlo(); } + if (adjusted_indices->shape().rank() == 0) { + adjusted_indices = b->AddInstruction(HloInstruction::CreateBroadcast( + indices.hlo()->shape(), adjusted_indices, {})); + } else { + adjusted_indices = b->AddInstruction(HloInstruction::CreateBroadcast( + indices.hlo()->shape(), adjusted_indices, {index_dim})); + } + // Adjust indices by subtracting the offsets based on the partition id. + adjusted_indices = b->AddInstruction(HloInstruction::CreateBinary( + indices.hlo()->shape(), HloOpcode::kSubtract, indices.hlo(), + adjusted_indices)); + new_indices = indices.CloneWithNewHlo(adjusted_indices); } - return nullptr; + + const GroupedSharding new_indices_grouped = + hlo_sharding_util::GroupShardingOnDims(new_indices.sharding(), + indices_parallel_dims); + const GroupedSharding operand_grouped = + AlignGroupsWith(hlo_sharding_util::GroupShardingOnDims( + operand.sharding(), operand_parallel_dims), + new_indices_grouped); + const GroupedSharding output_grouped = + AlignGroupsWith(hlo_sharding_util::GroupShardingOnDims( + gather_output_sharding, output_parallel_dims), + new_indices_grouped); + PartitionedHlo per_group_operand = + PerGroupPartitionedHlo(operand, operand_grouped, b, clean_ups); + PartitionedHlo per_group_new_indices = + PerGroupPartitionedHlo(new_indices, new_indices_grouped, b, clean_ups); + const Shape pshape = GetPerGroupBaseShape(output_grouped, output_shape); + TF_ASSIGN_OR_RETURN( + HloInstruction * pgather, + PartitionGather(gather, per_group_operand, per_group_new_indices, pshape, + output_grouped.sharding, batch_dims, slice_sizes, visitor, + allow_recursive)); + if (allow_recursive) { + VLOG(5) << "[Gather partitioning]: Partitioned as parallel batch_dim"; + } + pgather->set_sharding(hlo_sharding_util::UngroupSharding(output_grouped)); + return PartitionedHlo(pgather, output_shape, operand.state()) + .Reshard(output_sharding) + .hlo(); +} + +// Partition a gather over indices dimensions that are considered parallel +// (which means that the indices access the operand in a monotonically +// increasing way across the respective operand dimension referenced by the +// index). +absl::StatusOr PartitionGatherIndexParallelDimensions( + const HloGatherInstruction* gather, PartitionedHlo operand, + PartitionedHlo indices, const Shape& output_shape, + const HloSharding& output_sharding, absl::Span batch_dims, + absl::Span slice_sizes, SpmdPartitioningVisitor* visitor, + bool allow_recursive) { + std::optional parallel_dims = + hlo_sharding_util::GetGatherParallelBatchDims(*gather, + visitor->call_graph()); + if (!parallel_dims.has_value()) { + return nullptr; + } + return PartitionGatherBatchDimensions(gather, operand, indices, output_shape, + output_sharding, batch_dims, + slice_sizes, visitor, allow_recursive, + *parallel_dims, /*need_offset=*/true); +} + +// Partition a gather over explicit batch dimensions defined in +// operand_batching_dims and start_indices_batching_dims. +absl::StatusOr PartitionGatherExplicitBatchDimensions( + const HloGatherInstruction* gather, PartitionedHlo operand, + PartitionedHlo indices, const Shape& output_shape, + const HloSharding& output_sharding, absl::Span batch_dims, + absl::Span slice_sizes, SpmdPartitioningVisitor* visitor, + bool allow_recursive) { + const GatherDimensionNumbers& dnums = gather->gather_dimension_numbers(); + if (dnums.operand_batching_dims().empty()) { + return nullptr; + } + + hlo_sharding_util::GatherScatterParallelDims parallel_dims; + parallel_dims.operand_parallel_dims.assign( + dnums.operand_batching_dims().begin(), + dnums.operand_batching_dims().end()); + parallel_dims.indices_parallel_dims.assign( + dnums.start_indices_batching_dims().begin(), + dnums.start_indices_batching_dims().end()); + + return PartitionGatherBatchDimensions(gather, operand, indices, output_shape, + output_sharding, batch_dims, + slice_sizes, visitor, allow_recursive, + parallel_dims, /*need_offset=*/false); } // Returns a full list of partitioning methods used for gather. std::vector> GatherPartitionMethods() { - return {{PartitionGatherIndexParallelDimensions, + return {{PartitionGatherExplicitBatchDimensions, + "PartitionGatherExplicitBatchDimensions"}, + {PartitionGatherIndexParallelDimensions, "PartitionGatherIndexParallelDimensions"}, {PartitionGatherOperandPassthroughDimensions, "PartitionGatherOperandPassthroughDimensions"}, @@ -729,6 +786,8 @@ GatherPartitionMethods() { // Helper function to get the gather partitioning method. decltype(PartitionGather)* GetGatherPartitionMethod(PartitioningMethod method) { switch (method) { + case PartitioningMethod::kExplicitBatch: + return PartitionGatherExplicitBatchDimensions; case PartitioningMethod::kIndexParallel: return PartitionGatherIndexParallelDimensions; case PartitioningMethod::kOperandPassthrough: @@ -738,7 +797,7 @@ decltype(PartitionGather)* GetGatherPartitionMethod(PartitioningMethod method) { case PartitioningMethod::kIndexPassthrough: return PartitionGatherIndexPassthroughDimensions; default: - return PartitionGatherIndexParallelDimensions; + return PartitionGatherExplicitBatchDimensions; } } @@ -1038,7 +1097,7 @@ absl::StatusOr PartitionScatterIndexParallelDimensions( }; SpmdBuilder* b = visitor->builder(); - auto dnums = scatter->scatter_dimension_numbers(); + const auto& dnums = scatter->scatter_dimension_numbers(); const int64_t index_dim = dnums.index_vector_dim(); // Handle the case where operand is tile maximal. In this case we check if // the index is not TileMaximal and in this case we use the index sharding @@ -1276,7 +1335,7 @@ absl::StatusOr PartitionScatterIndexPassthroughDimensions( }; SpmdBuilder* b = visitor->builder(); - auto dnums = scatter->scatter_dimension_numbers(); + const auto& dnums = scatter->scatter_dimension_numbers(); // Parse non-variadic computation only. Vardiadic case will be replicated. const HloSharding original_indices_sharding = indices.sharding(); absl::InlinedVector index_group_dims = @@ -1410,7 +1469,7 @@ absl::StatusOr PartitionScatterTrivialSlicedOperandDimensions( }; SpmdBuilder* b = visitor->builder(); - auto dnums = scatter->scatter_dimension_numbers(); + const auto& dnums = scatter->scatter_dimension_numbers(); if (std::optional> trivial_slice_dims = GatherScatterOperandPartitionedOnTrivialSliceDims( operands[0], dnums.scatter_dims_to_operand_dims(), slice_sizes)) { @@ -1657,8 +1716,7 @@ absl::Status SpmdPartitioningVisitor::HandleScatter(HloInstruction* hlo) { if (hlo->sharding().HasUniqueDevice()) { return DefaultAction(hlo); } - auto scatter = Cast(hlo); - auto dnums = scatter->scatter_dimension_numbers(); + const auto scatter = Cast(hlo); // Check all operands have the same shapes and shardings, and all updates have // the same shapes and shardings, and live with this assumption during scatter // partitioning. @@ -1724,7 +1782,8 @@ absl::Status SpmdPartitioningVisitor::HandleScatter(HloInstruction* hlo) { } scatter_partition_method = options().scatter_partition_method; std::vector slice_sizes = hlo_sharding_util::GetScatterSliceSize( - operands[0].base_shape(), updates[0].base_shape(), dnums); + operands[0].base_shape(), updates[0].base_shape(), + scatter->scatter_dimension_numbers()); TF_ASSIGN_OR_RETURN( HloInstruction * pscatter, diff --git a/third_party/xla/xla/service/spmd/spmd_partitioner.h b/third_party/xla/xla/service/spmd/spmd_partitioner.h index 26ae71f44d21f4..fc19332c7ce3dc 100644 --- a/third_party/xla/xla/service/spmd/spmd_partitioner.h +++ b/third_party/xla/xla/service/spmd/spmd_partitioner.h @@ -54,6 +54,7 @@ namespace spmd { // Enum representing the partitioning methods for gather and scatter. enum class PartitioningMethod { + kExplicitBatch, kIndexParallel, kOperandPassthrough, kTrivialSlicedOperand, diff --git a/third_party/xla/xla/service/spmd/spmd_partitioner_test.cc b/third_party/xla/xla/service/spmd/spmd_partitioner_test.cc index 3ed5b7bb5f0116..29adc287913b38 100644 --- a/third_party/xla/xla/service/spmd/spmd_partitioner_test.cc +++ b/third_party/xla/xla/service/spmd/spmd_partitioner_test.cc @@ -18,7 +18,6 @@ limitations under the License. #include #include #include -#include #include #include #include @@ -7730,6 +7729,72 @@ ENTRY entry { EXPECT_THAT(root, op::CollectivePermute(gather)); } +TEST_P(SpmdPartitioningTest, GatherExplicitBatchDims) { + absl::string_view hlo_string = R"( +HloModule module + +ENTRY entry { + %input = f32[10,3,14,4] parameter(0), sharding={devices=[2,1,2,1]<=[2,2]T(1,0)} + %indices = s32[14,10,6,2] parameter(1), sharding={devices=[2,2,1,1]<=[4]} + ROOT %gather = f32[14,10,6,2] gather(%input, %indices), offset_dims={3}, + collapsed_slice_dims={1}, operand_batching_dims={0,2}, + start_indices_batching_dims={1,0}, start_index_map={1,3}, + index_vector_dim=3, slice_sizes={1,1,1,2}, sharding={devices=[2,2,1,1]<=[4]} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + + auto input = AllOf(op::Shape("f32[5,3,7,4]"), op::Parameter(0)); + auto indices = AllOf(op::Shape("s32[7,5,6,2]"), op::Parameter(1)); + auto gather = AllOf(op::Shape("f32[7,5,6,2]"), op::Gather(input, indices)); + EXPECT_THAT(module->entry_computation()->root_instruction(), gather); +} + +TEST_P(SpmdPartitioningTest, GatherExplicitBatchAndOperandPassthroughDims) { + absl::string_view hlo_string = R"( +HloModule module + +ENTRY entry { + %input = f32[10,3,14,4] parameter(0), sharding={devices=[2,1,1,2]<=[4]} + %indices = s32[14,10,6,2] parameter(1), sharding={devices=[1,2,1,1,2]<=[4] last_tile_dim_replicate} + ROOT %gather = f32[14,10,6,4] gather(%input, %indices), offset_dims={3}, + collapsed_slice_dims={1}, operand_batching_dims={0,2}, + start_indices_batching_dims={1,0}, start_index_map={1,3}, + index_vector_dim=3, slice_sizes={1,1,1,4}, sharding={devices=[1,2,1,2]<=[4]} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + + auto input = AllOf(op::Shape("f32[5,3,14,2]"), op::Parameter(0)); + auto indices = AllOf(op::Shape("s32[14,5,6,2]"), op::Parameter(1)); + auto gather = AllOf(op::Shape("f32[14,5,6,2]"), op::Gather(input, indices)); + EXPECT_THAT(module->entry_computation()->root_instruction(), gather); +} + +TEST_P(SpmdPartitioningTest, GatherExplicitBatchAndIndexPassthroughDims) { + absl::string_view hlo_string = R"( +HloModule module + +ENTRY entry { + %input = f32[10,3,14,4] parameter(0), sharding={devices=[1,1,2,1,2]<=[4] last_tile_dim_replicate} + %indices = s32[14,10,6,2] parameter(1), sharding={devices=[2,1,2,1]<=[4]} + ROOT %gather = f32[14,10,6,2] gather(%input, %indices), offset_dims={3}, + collapsed_slice_dims={1}, operand_batching_dims={0,2}, + start_indices_batching_dims={1,0}, start_index_map={1,3}, + index_vector_dim=3, slice_sizes={1,1,1,2}, sharding={devices=[2,1,2,1]<=[4]} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + + auto input = AllOf(op::Shape("f32[10,3,7,4]"), op::Parameter(0)); + auto indices = AllOf(op::Shape("s32[7,10,3,2]"), op::Parameter(1)); + auto gather = AllOf(op::Shape("f32[7,10,3,2]"), op::Gather(input, indices)); + EXPECT_THAT(module->entry_computation()->root_instruction(), gather); +} + TEST_P(SpmdPartitioningTest, GatherPartitionedOnTrivialSliceDims) { absl::string_view hlo_string = R"( HloModule module From 3984cffa3e8660dfeba232d5f3c2d84556790b05 Mon Sep 17 00:00:00 2001 From: Isha Arkatkar Date: Fri, 20 Sep 2024 19:26:35 -0700 Subject: [PATCH 088/483] Introduce Pad Op before tf.Split and Slice op after tf.concat ops to handle not-divisible sharding for SPMD. This is an alternate approach to using XLA ND Split/Concat ops. tf.Split and tf.Concat ops operate on a single dimension at a time. So the padding and slice ops are introduced accordingly. PiperOrigin-RevId: 677052836 --- .../mlir/tensorflow/tests/tpu_rewrite.mlir | 178 +++++++++++++++- .../tensorflow/utils/xla_sharding_util.cc | 199 +++++++++++++----- .../utils/xla_sharding_util_test.cc | 25 ++- 3 files changed, 344 insertions(+), 58 deletions(-) diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir index 07e3eea2d4ea69..e7bd2191b344f5 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir @@ -1891,19 +1891,19 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc // Tests tile sharding of inputs with number of splits that does not evenly divide -// the input results in an error. +// the input results in an error, when shapes are not fully known. module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0", "/job:localhost/replica:0/task:1/device:CPU:0", "/job:localhost/replica:0/task:1/device:TPU:0", "/job:localhost/replica:0/task:1/device:TPU:1", "/job:localhost/replica:0/task:1/device:TPU_SYSTEM:0"]} { - func.func @uneven_input_sharding_disallowed(%arg0: tensor<128x10xf32>, %arg1: tensor<128x10xf32>, %arg2: tensor<*xi32>, %arg3: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) { - %0:2, %1:2 = tf_device.replicate([%arg0, %arg1] as %ri_1: tensor<128x10xf32>, [%arg2, %arg3] as %ri_2: tensor<*xi32>) {n = 2 : i32} { + func.func @uneven_input_sharding_disallowed(%arg0: tensor, %arg1: tensor, %arg2: tensor<*xi32>, %arg3: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) { + %0:2, %1:2 = tf_device.replicate([%arg0, %arg1] as %ri_1: tensor, [%arg2, %arg3] as %ri_2: tensor<*xi32>) {n = 2 : i32} { // expected-error@+1 {{incorrect input sharding configuration received. 1-th dimension of the input must be evenly divisible by 4}} - %1, %2 = "tf_device.cluster_func"(%ri_1, %ri_2) {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @tpu0_func, num_cores_per_replica = 2, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", topology = "\0A\04\01\02\01\02\10\02\18\02\22\10\00\00\00\00\00\00\00\01\00\01\00\00\00\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0], input_sharding_configuration = ["\08\03\12\12\10\0b\1a\02\01\04\2a\06\0a\02\01\00\20\01\32\02\00\00\1a\02\01\04\22\04\00\01\02\03", "\08\01\1A\01\01\22\01\01"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00", ""], use_spmd_for_xla_partitioning = false} : (tensor<128x10xf32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) + %1, %2 = "tf_device.cluster_func"(%ri_1, %ri_2) {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @tpu0_func, num_cores_per_replica = 2, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", topology = "\0A\04\01\02\01\02\10\02\18\02\22\10\00\00\00\00\00\00\00\01\00\01\00\00\00\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0], input_sharding_configuration = ["\08\03\12\12\10\0b\1a\02\01\04\2a\06\0a\02\01\00\20\01\32\02\00\00\1a\02\01\04\22\04\00\01\02\03", "\08\01\1A\01\01\22\01\01"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00", ""], use_spmd_for_xla_partitioning = false} : (tensor, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) tf_device.return %1, %2 : tensor<*xi32>, tensor<*xi1> } func.return %0#0, %1#0 : tensor<*xi32>, tensor<*xi1> } - func.func @tpu0_func(%arg0: tensor<128x10xf32>, %arg1: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) { - %1, %2 = "tf.A"(%arg0) : (tensor<128x10xf32>) -> (tensor<*xi32>, tensor<*xi1>) + func.func @tpu0_func(%arg0: tensor, %arg1: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) { + %1, %2 = "tf.A"(%arg0) : (tensor) -> (tensor<*xi32>, tensor<*xi1>) %4 = "tf.B"(%1, %arg1) : (tensor<*xi32>, tensor<*xi32>) -> (tensor<*xi32>) %3 = "tf.XlaSharding"(%2) { _XlaSharding = "", sharding = "" } : (tensor<*xi1>) -> tensor<*xi1> func.return %4, %3 : tensor<*xi32>, tensor<*xi1> @@ -2839,3 +2839,169 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc func.return %4, %3 : tensor<*xi32>, tensor<*xi1> } } + +// ----- + +// Tests that outputs are correctly merged and fed from TPU computation for +// tiled output sharding with padding for concat ops. + +// The following OpSharding is used for TPU computation outputs in below test: +// Proto debug string: +// output 0 +// type: OTHER +// tile_assignment_dimensions: 1 +// tile_assignment_dimensions: 2 +// tile_assignment_devices: 0 +// tile_assignment_devices: 1 +// Serialized string: +// "\08\03\1A\02\01\02\22\02\00\01" +// +// output 1 +// type: MAXIMAL +// tile_assignment_dimensions: 1 +// tile_assignment_devices: 0 +// Serialized string: +// "\08\01\1A\01\01\22\01\01" + + +module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0", "/job:localhost/replica:0/task:1/device:CPU:0", "/job:localhost/replica:0/task:1/device:TPU:0", "/job:localhost/replica:0/task:1/device:TPU:1", "/job:localhost/replica:0/task:1/device:TPU_SYSTEM:0"]} { + // CHECK-LABEL: func @parallel_execute_with_tiled_output + func.func @parallel_execute_with_tiled_output(%arg0: tensor<128x10xf32>, %arg1: tensor<128x10xf32>, %arg2: tensor<128x10xi32>, %arg3: tensor<128x10xi32>) -> (tensor<128x5xi32>, tensor<10x5xi1>) { + // CHECK: tf_device.replicate + // CHECK-SAME: [%[[ARG_0]], %[[ARG_1]]] as %[[RI_0:[a-z0-9]*]]: tensor<128x10xf32> + // CHECK-SAME: [%[[ARG_2]], %[[ARG_3]]] as %[[RI_1:[a-z0-9]*]]: tensor<128x10xi32> + // CHECK-SAME: devices = + // CHECK-SAME: TPU_REPLICATED_CORE_0 = ["/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:1/device:TPU:1"] + // CHECK-SAME: TPU_REPLICATED_CORE_1 = ["/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:1/device:TPU:0"] + %0:2, %1:2 = tf_device.replicate([%arg0, %arg1] as %ri_1: tensor<128x10xf32>, [%arg2, %arg3] as %ri_2: tensor<128x10xi32>) {n = 2 : i32} { + // CHECK: %[[COMPILE:[a-z0-9]+]]:3 = "tf_device.launch"() <{device = "/job:localhost/replica:0/task:0/device:CPU:0"}> + // CHECK-NEXT: "tf._TPUCompileMlir" + // CHECK: "tf_device.launch"() <{device = "/job:localhost/replica:0/task:0/device:CPU:0"}> + // CHECK-NEXT: "tf.TPUCompileSucceededAssert"(%[[COMPILE]]#0) + // + // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]]:3 = "tf_device.parallel_execute" + // CHECK-NEXT: %[[LAUNCH_0_OUTPUT:[0-9]*]]:2 = "tf_device.launch"() <{device = "TPU_REPLICATED_CORE_0"}> + // CHECK-NEXT: %[[EXECUTE_0_OUTPUT:[0-9]*]]:2 = "tf.TPUExecute" + // CHECK-NEXT: tf_device.return %[[EXECUTE_0_OUTPUT]] + // CHECK: %[[LAUNCH_1_OUTPUT:[0-9]*]] = "tf_device.launch"() <{device = "TPU_REPLICATED_CORE_1"}> + // CHECK-NEXT: %[[EXECUTE_1_OUTPUT:[0-9]*]] = "tf.TPUExecute" + // CHECK-NEXT: tf_device.return %[[EXECUTE_1_OUTPUT]] + // + // CHECK: %[[CONST_CONCAT3_DIM:.*]] = "tf.Const"() + // CHECK: %[[CONCAT3_OUTPUT:[0-9]*]] = "tf.Concat"(%[[CONST_CONCAT3_DIM]], %[[PARALLEL_EXECUTE_OUTPUT]]#0, %[[PARALLEL_EXECUTE_OUTPUT]]#2) + // CHECK: %[[CONST_SLICE_BEGIN:.*]] = "tf.Const"() + // dense<0> + // tensor<2xi64>}> : () -> tensor<2xi64> + // CHECK: %[[CONST_SLICE_SIZE:.*]] = "tf.Const"() + // dense<[128, 5]> : tensor<2xi64>}> : () -> tensor<2xi64> + // CHECK: "tf.Slice"(%[[CONCAT3_OUTPUT]], %[[CONST_SLICE_BEGIN]], %[[CONST_SLICE_SIZE]]) + // : (tensor<128x6xi32>, tensor<2xi64>, tensor<2xi64>) -> tensor<128x5xi32> + %1, %2 = "tf_device.cluster_func"(%ri_1, %ri_2) {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @tpu0_func, num_cores_per_replica = 2, step_marker_location = "", topology = "\0A\04\01\02\01\02\10\02\18\02\22\10\00\00\00\00\00\00\00\01\00\01\00\00\00\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0], input_sharding_configuration = ["\08\03\1A\02\01\02\22\02\00\01", "\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\03\1A\02\01\02\22\02\00\01", "\08\01\1A\01\01\22\01\00"], use_spmd_for_xla_partitioning = false} : (tensor<128x10xf32>, tensor<128x10xi32>) -> (tensor<128x5xi32>, tensor<10x5xi1>) + tf_device.return %1, %2 : tensor<128x5xi32>, tensor<10x5xi1> + } + func.return %0#0, %1#0 : tensor<128x5xi32>, tensor<10x5xi1> + } + func.func @tpu0_func(%arg0: tensor<128x10xf32>, %arg1: tensor<128x10xi32>) -> (tensor<128x10xi32>, tensor<10x5xi1>) { + %1, %2 = "tf.A"(%arg0) : (tensor<128x10xf32>) -> (tensor<128x10xi32>, tensor<10x5xi1>) + %4 = "tf.B"(%1, %arg1) : (tensor<128x10xi32>, tensor<128x10xi32>) -> (tensor<128x10xi32>) + %3 = "tf.XlaSharding"(%2) { _XlaSharding = "", sharding = "" } : (tensor<10x5xi1>) -> tensor<10x5xi1> + func.return %4, %3 : tensor<128x10xi32>, tensor<10x5xi1> + } +} + +// ----- + +// Tests inputs are correctly split and fed into TPU computation for tiled input +// sharding with padding. + +// The following OpSharding is used for TPU computation inputs in the below +// test: +// Proto debug string: +// input 0 +// type: OTHER +// tile_assignment_dimensions: 1 +// tile_assignment_dimensions: 2 +// tile_assignment_devices: 0 +// tile_assignment_devices: 1 +// Serialized string: +// "\08\03\1A\02\01\02\22\02\00\01" +// +// input 1 +// type: MAXIMAL +// tile_assignment_dimensions: 1 +// tile_assignment_devices: 1 +// Serialized string: +// "\08\01\1A\01\01\22\01\01" + + +module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0", "/job:localhost/replica:0/task:1/device:CPU:0", "/job:localhost/replica:0/task:1/device:TPU:0", "/job:localhost/replica:0/task:1/device:TPU:1", "/job:localhost/replica:0/task:1/device:TPU_SYSTEM:0"]} { + // CHECK-LABEL: func @parallel_execute_with_tiled_input + // CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor<128x9xf32>, %[[ARG_1:[a-z0-9]*]]: tensor<128x9xf32>, %[[ARG_2:[a-z0-9]*]]: tensor<128x10xi32>, %[[ARG_3:[a-z0-9]*]]: tensor<128x10xi32>) + func.func @parallel_execute_with_tiled_input(%arg0: tensor<128x9xf32>, %arg1: tensor<128x9xf32>, %arg2: tensor<128x10xi32>, %arg3: tensor<128x10xi32>) -> (tensor<128x10xi32>, tensor<10x5xi1>) { + // CHECK: tf_device.replicate + // CHECK-SAME: [%[[ARG_0]], %[[ARG_1]]] as %[[RI_0:[a-z0-9]*]]: tensor<128x9xf32> + // CHECK-SAME: [%[[ARG_2]], %[[ARG_3]]] as %[[RI_1:[a-z0-9]*]]: tensor<128x10xi32> + // CHECK-SAME: devices = + // CHECK-SAME: TPU_REPLICATED_CORE_0 = ["/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:1/device:TPU:1"] + // CHECK-SAME: TPU_REPLICATED_CORE_1 = ["/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:1/device:TPU:0"] + %0:2, %1:2 = tf_device.replicate([%arg0, %arg1] as %ri_1: tensor<128x9xf32>, [%arg2, %arg3] as %ri_2: tensor<128x10xi32>) {n = 2 : i32} { + // CHECK: %[[DEVICE_LAUNCH_OUT:[a-z0-9]+]] = "tf_device.launch"() <{device = "TPU_REPLICATED_HOST_0"}> + // CHECK: %[[COMPILE:[a-z0-9]+]]:3 = "tf_device.launch"() <{device = "/job:localhost/replica:0/task:0/device:CPU:0"}> + // CHECK-NEXT: "tf._TPUCompileMlir" + // CHECK: "tf_device.launch"() <{device = "/job:localhost/replica:0/task:0/device:CPU:0"}> + // CHECK-NEXT: "tf.TPUCompileSucceededAssert"(%[[COMPILE]]#0) + // + // CHECK: %[[PAD_SHAPE:[a-z0-9]+]] = "tf.Const"() + // CHECK: [0, 0], [0, 1] + // CHECK: : tensor<2x2xi64>}> : () -> tensor<2x2xi64> + // CHECK: %[[PAD_OUT:[a-z0-9]+]] = "tf.Pad"(%[[DEVICE_LAUNCH_OUT]], %[[PAD_SHAPE]]) {ici_weight_distribution_mlir_bridge_marker = true} : (tensor<128x9xf32>, tensor<2x2xi64>) -> tensor<128x10xf32> + // CHECK: %[[CONST_SPLIT_DIM:.*]] = "tf.Const"() <{value = dense<1> : tensor}> {ici_weight_distribution_mlir_bridge_marker = true} : () -> tensor + // CHECK: %[[SPLIT_OUT:[a-z0-9]+]]:2 = "tf.Split"(%[[CONST_SPLIT_DIM]], %[[PAD_OUT]]) {ici_weight_distribution_mlir_bridge_marker = true, num_split = 2 : i32} : (tensor, tensor<128x10xf32>) -> (tensor<128x5xf32>, tensor<128x5xf32>) + // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]]:3 = "tf_device.parallel_execute" + // CHECK-NEXT: %[[LAUNCH_0_OUTPUT:[0-9]*]]:2 = "tf_device.launch"() <{device = "TPU_REPLICATED_CORE_0"}> + // + // CHECK-NEXT: %[[EXECUTE_0_OUTPUT:[0-9]*]]:2 = "tf.TPUExecute"(%[[SPLIT_OUT]]#0, %[[COMPILE]]#1) + // CHECK-NEXT: tf_device.return %[[EXECUTE_0_OUTPUT]] + // CHECK: %[[LAUNCH_1_OUTPUT:[0-9]*]] = "tf_device.launch"() <{device = "TPU_REPLICATED_CORE_1"}> + // CHECK-NEXT: %[[EXECUTE_1_OUTPUT:[0-9]*]] = "tf.TPUExecute"(%[[SPLIT_OUT]]#1, %[[RI_1]], %[[COMPILE]]#2) + // CHECK-NEXT: tf_device.return %[[EXECUTE_1_OUTPUT]] + %1 = "tf_device.launch"() <{device = "TPU_REPLICATED_HOST_0"}> ({ + %identity = "tf.Identity"(%ri_1) {ici_weight_distribution_mlir_bridge_marker = true} : (tensor<128x9xf32>) -> tensor<128x9xf32> + tf_device.return %identity : tensor<128x9xf32> + }) {ici_weight_distribution_mlir_bridge_marker = true} : () -> tensor<128x9xf32> + %2, %3 = "tf_device.cluster_func"(%1, %ri_2) {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @tpu0_func, num_cores_per_replica = 2, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", topology = "\0A\04\01\02\01\02\10\02\18\02\22\10\00\00\00\00\00\00\00\01\00\01\00\00\00\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0], input_sharding_configuration = ["\08\03\1A\02\01\02\22\02\00\01", "\08\01\1A\01\01\22\01\01"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00", ""], use_spmd_for_xla_partitioning = false} : (tensor<128x9xf32>, tensor<128x10xi32>) -> (tensor<128x10xi32>, tensor<10x5xi1>) + tf_device.return %2, %3 : tensor<128x10xi32>, tensor<10x5xi1> + } + func.return %0#0, %1#0 : tensor<128x10xi32>, tensor<10x5xi1> + } + func.func @tpu0_func(%arg0: tensor<128x9xf32>, %arg1: tensor<128x10xi32>) -> (tensor<128x10xi32>, tensor<10x5xi1>) { + %1, %2 = "tf.A"(%arg0) : (tensor<128x9xf32>) -> (tensor<128x10xi32>, tensor<10x5xi1>) + %4 = "tf.B"(%1, %arg1) : (tensor<128x10xi32>, tensor<128x10xi32>) -> (tensor<128x10xi32>) + %3 = "tf.XlaSharding"(%2) { _XlaSharding = "", sharding = "" } : (tensor<10x5xi1>) -> tensor<10x5xi1> + func.return %4, %3 : tensor<128x10xi32>, tensor<10x5xi1> + } +} + +// ----- + +// CHECK: "tf.Split" +// : (tensor<128x1024xf32>) -> (tensor<32x1024xf32>, tensor<32x1024xf32>, tensor<32x1024xf32>, tensor<32x1024xf32>) +module attributes {tf.devices = {"/job:tpu_host_worker/replica:0/task:0/device:CPU:0", "/job:tpu_host_worker/replica:0/task:0/device:TPU:0", "/job:tpu_host_worker/replica:0/task:0/device:TPU:1", "/job:tpu_host_worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:tpu_host_worker/replica:0/task:1/device:CPU:0", "/job:tpu_host_worker/replica:0/task:1/device:TPU:0", "/job:tpu_host_worker/replica:0/task:1/device:TPU:1", "/job:tpu_host_worker/replica:0/task:1/device:TPU_SYSTEM:0", "/job:tpu_host_worker/replica:0/task:2/device:CPU:0", "/job:tpu_host_worker/replica:0/task:2/device:TPU:0", "/job:tpu_host_worker/replica:0/task:2/device:TPU:1", "/job:tpu_host_worker/replica:0/task:2/device:TPU_SYSTEM:0", "/job:tpu_host_worker/replica:0/task:3/device:CPU:0", "/job:tpu_host_worker/replica:0/task:3/device:TPU:0", "/job:tpu_host_worker/replica:0/task:3/device:TPU:1", "/job:tpu_host_worker/replica:0/task:3/device:TPU_SYSTEM:0"}, tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 1857 : i32}} { + func.func @main(%arg0: tensor {tf._user_specified_name = "steps", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}, %arg1: tensor<*x!tf_type.resource>> {tf._user_specified_name = "899", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}, %arg2: tensor<*x!tf_type.resource>> {tf._user_specified_name = "901", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}, %arg3: tensor<*x!tf_type.resource>> {tf._user_specified_name = "903", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}, %arg4: tensor<*x!tf_type.resource>> {tf._user_specified_name = "905", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}, %arg5: tensor<*x!tf_type.resource>> {tf._user_specified_name = "907", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}, %arg6: tensor<*x!tf_type.resource>> {tf._user_specified_name = "909", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}, %arg7: tensor<*x!tf_type.resource>> {tf._user_specified_name = "911", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}, %arg8: tensor<*x!tf_type.resource>> {tf._user_specified_name = "913", tf.device = "/job:tpu_host_worker/replica:0/task:1/device:CPU:0"}, %arg9: tensor<*x!tf_type.resource>> {tf._user_specified_name = "915", tf.device = "/job:tpu_host_worker/replica:0/task:2/device:CPU:0"}, %arg10: tensor<*x!tf_type.resource>> {tf._user_specified_name = "917", tf.device = "/job:tpu_host_worker/replica:0/task:3/device:CPU:0"}, %arg11: tensor<*x!tf_type.resource>> {tf._user_specified_name = "919", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}, %arg12: tensor<*x!tf_type.resource>> {tf._user_specified_name = "921", tf.device = "/job:tpu_host_worker/replica:0/task:1/device:CPU:0"}, %arg13: tensor<*x!tf_type.resource>> {tf._user_specified_name = "923", tf.device = "/job:tpu_host_worker/replica:0/task:2/device:CPU:0"}, %arg14: tensor<*x!tf_type.resource>> {tf._user_specified_name = "925", tf.device = "/job:tpu_host_worker/replica:0/task:3/device:CPU:0"}, %arg15: tensor<*x!tf_type.resource>> {tf._user_specified_name = "927", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}, %arg16: tensor<*x!tf_type.resource>> {tf._user_specified_name = "929", tf.device = "/job:tpu_host_worker/replica:0/task:1/device:CPU:0"}, %arg17: tensor<*x!tf_type.resource>> {tf._user_specified_name = "931", tf.device = "/job:tpu_host_worker/replica:0/task:2/device:CPU:0"}, %arg18: tensor<*x!tf_type.resource>> {tf._user_specified_name = "933", tf.device = "/job:tpu_host_worker/replica:0/task:3/device:CPU:0"}, %arg19: tensor<*x!tf_type.resource>> {tf._user_specified_name = "935", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}, %arg20: tensor<*x!tf_type.resource>> {tf._user_specified_name = "937", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}, %arg21: tensor<*x!tf_type.resource>> {tf._user_specified_name = "939", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}) -> tensor attributes {allow_soft_placement = false, tf.entry_function = {control_outputs = "", inputs = "steps,unknown,unknown_0,unknown_1,unknown_2,unknown_3,unknown_4,unknown_5,unknown_6,unknown_7,unknown_8,unknown_9,unknown_10,unknown_11,unknown_12,unknown_13,unknown_14,unknown_15,unknown_16,unknown_17,unknown_18,unknown_19", outputs = "statefulpartitionedcall_RetVal"}} { + %0 = "tf.TPUCompilationResult"() {_tpu_compilation_status = "cluster__train_helper", device = ""} : () -> tensor + %1 = "tf.ReadVariableOp"(%arg3) : (tensor<*x!tf_type.resource>>) -> tensor<128x1024xf32> + %2 = "tf.ReadVariableOp"(%arg4) : (tensor<*x!tf_type.resource>>) -> tensor<1024xf32> + %3:2 = tf_device.replicate {n = 2 : i32} { + %6 = "tf_device.cluster_func"(%1, %2) <{func = @_func}> {_dynamic_arg_index = [], _has_manual_control_dependencies = true, _replication_info = "cluster__train_helper", _xla_compile_device_type = "TPU", allow_soft_placement = false, computation_shape = [], device = "", device_assignment = [0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 1, 0, 1, 0, 1, 1, 1, 0], host_compute_core = [], input_sharding_configuration = ["\08\03\1A\01\04\22\04\00\01\02\03", ""], num_cores_per_replica = 4 : i64, output_sharding_configuration = [""], padding_map = [], step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", topology = "\0A\04\02\02\02\01\10\04\18\02\22 \00\00\00\00\00\01\00\00\01\00\00\00\01\01\00\00\00\00\01\00\00\01\01\00\01\00\01\00\01\01\01\00*\02\08\01", tpu_compile_options_proto = "", use_spmd_for_xla_partitioning = true, use_tpu = true} : (tensor<128x1024xf32>, tensor<1024xf32>) -> tensor<*xf32> + tf_device.return %6 : tensor<*xf32> + } + %4 = "tf.ReadVariableOp"(%arg2) {device = ""} : (tensor<*x!tf_type.resource>>) -> tensor + %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor + return %5 : tensor + } + func.func private @_func(%arg0: tensor<128x1024xf32> {mhlo.is_same_data_across_replicas = true, mhlo.sharding = "\08\03\1A\01\04\22\04\00\01\02\03"}, %arg1: tensor<1024xf32> {mhlo.is_same_data_across_replicas = true, mhlo.sharding = ""}) -> (tensor<*xf32> {mhlo.sharding = ""}) { + %0 = "tf.XlaSharding"(%arg0) <{_XlaSharding = "\08\03\1A\01\04\22\04\00\01\02\03", sharding = "\08\03\1A\01\04\22\04\00\01\02\03"}> {unspecified_dims = []} : (tensor<128x1024xf32>) -> tensor<128x1024xf32> + %1 = "tf.MatMul"(%0, %arg1) : (tensor<128x1024xf32>, tensor<1024xf32>) -> tensor<*xf32> + return %1 : tensor<*xf32> + } +} diff --git a/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.cc index d10b908e02d3c3..8913a1812b9c99 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.cc @@ -44,6 +44,7 @@ limitations under the License. #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Diagnostics.h" // from @llvm-project #include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project @@ -73,15 +74,93 @@ int64_t GetPadding(const int split_dim, const int num_splits, return total_padding; } +mlir::TF::SliceOp CreateSliceOp(mlir::OpBuilder* builder, + const mlir::Location& location, + mlir::Value input, + const PartialTensorShape& shape) { + mlir::SmallVector slice_start_position; + for (int i = 0; i < shape.dims(); ++i) { + slice_start_position.push_back(0); + } + mlir::SmallVector slice_size; + for (int i = 0; i < shape.dims(); ++i) { + slice_size.push_back(shape.dim_size(i)); + } + + auto start_position_type = + mlir::RankedTensorType::get(shape.dims(), builder->getIntegerType(64)); + + auto start_position_op = builder->create( + input.getLoc(), mlir::DenseIntElementsAttr::get(start_position_type, + slice_start_position)); + + auto slice_size_op = builder->create( + input.getLoc(), mlir::DenseIntElementsAttr::get( + mlir::RankedTensorType::get( + shape.dims(), builder->getIntegerType(64)), + slice_size)); + + auto slice_result_type = + mlir::RankedTensorType::get(slice_size, getElementTypeOrSelf(input)); + + return builder->create(input.getLoc(), slice_result_type, + input, start_position_op, + slice_size_op); +} + +mlir::TF::PadOp CreatePadOp(mlir::OpBuilder* builder, + const mlir::Location& location, int64_t num_dims, + int64_t split_dim, mlir::Value src_input, + int64_t padding) { + auto input_type = mlir::cast(src_input.getType()); + llvm::SmallVector padding_values; + std::vector padded_shape; + for (int i = 0; i < num_dims; ++i) { + // 0 padding in the beginning. + padding_values.push_back(0); + if (i == split_dim) { + // pad the split dimension to make the total size of the input equal to + // the total size of the split dimension. + padding_values.push_back(padding); + padded_shape.push_back(input_type.getShape()[i] + padding); + } else { + padding_values.push_back(0); + padded_shape.push_back(input_type.getShape()[i]); + } + } + auto padding_type = + mlir::RankedTensorType::get({num_dims, 2}, builder->getIntegerType(64)); + auto paddings = mlir::DenseIntElementsAttr::get(padding_type, padding_values); + auto paddings_value = builder->create(location, paddings); + mlir::SmallVector expand_shape(padded_shape.begin(), + padded_shape.end()); + + auto expand_result_type = + mlir::RankedTensorType::get(expand_shape, input_type.getElementType()); + + return builder->create(location, expand_result_type, + src_input, paddings_value); +} + // Creates a tf::SplitOp that splits 'src_input' into 'num_splits' ways // in 'split_dimension' dimension and returns the split values. -mlir::LogicalResult CreateSplitOp(const int num_split, - const int split_dimension, - const mlir::Location& location, - mlir::Value src_input, - mlir::OpBuilder* builder, - mlir::TF::SplitOp* split_op, - bool is_ici_weight_dist_spmd) { +mlir::LogicalResult CreateSplitOp( + const int num_split, const int split_dimension, const int64_t padding, + const mlir::Location& location, mlir::Value src_input, + mlir::OpBuilder* builder, mlir::TF::SplitOp* split_op, + bool is_ici_weight_dist_spmd) { + if (padding > 0) { + int64_t num_dims = + mlir::cast(src_input.getType()).getRank(); + auto pad_op = CreatePadOp(builder, location, num_dims, split_dimension, + src_input, padding); + if (is_ici_weight_dist_spmd) { + pad_op->setAttr(kICIWeightDistributionMlirBridgeMarker, + builder->getBoolAttr(true)); + } + src_input = pad_op.getResult(); + } + // Creates a const op to hold split dimension value. auto split_dim_type = mlir::RankedTensorType::get({}, builder->getIntegerType(32)); @@ -139,6 +218,7 @@ mlir::LogicalResult CreateSplitOp(const int num_split, // Creates a tf::ConcatOp that merges `input` values in `concat_dimension`. mlir::TF::ConcatOp CreateConcatOp(const int concat_dimension, const mlir::Location& location, + const int64_t padding, mlir::ArrayRef inputs, mlir::OpBuilder* builder) { // Creates a const op to hold concat dimension value. @@ -265,6 +345,22 @@ mlir::LogicalResult CreateXlaSplitNDOp(const mlir::Location& location, return mlir::success(); } +bool IsShapeKnown(mlir::TensorType type) { + if (!type.hasRank()) return false; + + bool shape_known = false; + for (int i = 0; i < type.getRank(); ++i) { + if (type.getShape()[i] == mlir::ShapedType::kDynamic) { + shape_known = false; + break; + } else { + shape_known = true; + } + } + + return shape_known; +} + mlir::LogicalResult HandleTileShardedInputsUsingXlaSplitOps( const mlir::Location& location, const xla::OpSharding& input_sharding, const mlir::Value& original_source, mlir::OpBuilder* builder, @@ -335,17 +431,27 @@ mlir::LogicalResult HandleTileShardedInputsUsingTfSplitOps( LOG(ERROR) << dimension_to_splits_map.status(); return mlir::failure(); } - + PartialTensorShape shape; + const auto input_type = + mlir::cast(original_source.getType()); + bool input_shape_known = IsShapeKnown(input_type); + if (input_shape_known) { + shape = PartialTensorShape(input_type.getShape()); + } for (const auto& dimension_and_num_splits : *dimension_to_splits_map) { const int dimension = dimension_and_num_splits.first; const int num_splits = dimension_and_num_splits.second; + int padding = input_shape_known + ? GetPadding(dimension, num_splits, + PartialTensorShape(input_type.getShape())) + : 0; // Creates root split op. if (split_ops_for_tiled_input.empty()) { mlir::TF::SplitOp root_split_op; - auto result = - CreateSplitOp(num_splits, dimension, location, original_source, - builder, &root_split_op, is_ici_weight_dist_spmd); + auto result = CreateSplitOp(num_splits, dimension, padding, location, + original_source, builder, &root_split_op, + is_ici_weight_dist_spmd); if (mlir::failed(result)) return mlir::failure(); split_ops_for_tiled_input.emplace_back(root_split_op); @@ -358,7 +464,7 @@ mlir::LogicalResult HandleTileShardedInputsUsingTfSplitOps( for (auto split_op : split_ops_for_tiled_input) { for (auto parent_split_output_value : split_op.getResults()) { mlir::TF::SplitOp child_split_op; - auto result = CreateSplitOp(num_splits, dimension, location, + auto result = CreateSplitOp(num_splits, dimension, padding, location, parent_split_output_value, builder, &child_split_op, is_ici_weight_dist_spmd); if (mlir::failed(result)) return mlir::failure(); @@ -827,7 +933,15 @@ mlir::LogicalResult HandleTileShardedOutputsUsingTfConcatOps( LOG(ERROR) << dimension_to_splits_map.status(); return mlir::failure(); } - + auto output_type = + mlir::cast(cluster_func_output.getType()); + PartialTensorShape shape; + bool output_shape_known = IsShapeKnown(output_type); + if (output_shape_known) { + shape = PartialTensorShape(output_type.getShape()); + } + bool has_paddings = false; + std::vector paddings; for (auto it = dimension_to_splits_map->rbegin(); it != dimension_to_splits_map->rend(); ++it) { int concat_dimension = it->first; @@ -837,12 +951,21 @@ mlir::LogicalResult HandleTileShardedOutputsUsingTfConcatOps( new_outputs.reserve(num_splits); for (int i = 0, end = outputs_to_merge.size(); i < end; i = i + num_splits) { + int64_t padding; + if (output_shape_known) { + padding = GetPadding(concat_dimension, num_splits, shape); + } else { + padding = 0; + } mlir::TF::ConcatOp concat_op = - CreateConcatOp(concat_dimension, location, + CreateConcatOp(concat_dimension, location, padding, llvm::ArrayRef{ outputs_to_merge.begin() + i, outputs_to_merge.begin() + i + num_splits}, builder); + + paddings.push_back(padding); + has_paddings |= padding > 0; new_outputs.emplace_back(concat_op.getResult()); } @@ -850,6 +973,12 @@ mlir::LogicalResult HandleTileShardedOutputsUsingTfConcatOps( } assert(outputs_to_merge.size() == 1); + if (has_paddings) { + // Add slice op to remove paddings. + mlir::TF::SliceOp slice_op = + CreateSliceOp(builder, location, outputs_to_merge[0], shape); + cluster_func_output.replaceAllUsesWith(slice_op.getResult()); + } cluster_func_output.replaceAllUsesWith(outputs_to_merge[0]); return mlir::success(); } @@ -876,26 +1005,13 @@ mlir::LogicalResult ValidateAndGetTiledExecuteOutputShape( *tiled_logical_computation_type = cluster_func_output_type; break; } - if (use_xla_nd_ops) { - if (output_shape[dimension] % output_splits == 0) { - new_output_shape[dimension] = output_shape[dimension] / output_splits; - } else { - // Input will be padded to be divisible by output_splits, thus add 1 to - // the output shape. - new_output_shape[dimension] = - (output_shape[dimension] / output_splits) + 1; - } - } else { - if (output_shape[dimension] % output_splits != 0) { - mlir::emitError( - location, - llvm::formatv("incorrect output sharding received. " - "{0}-th dimension of the output must be " - "evenly divisible by {1}, got dimension " - "shape {2}", - dimension, output_splits, output_shape[dimension])); - } + if (output_shape[dimension] % output_splits == 0) { new_output_shape[dimension] = output_shape[dimension] / output_splits; + } else { + // Input will be padded to be divisible by output_splits, thus add 1 to + // the output shape. + new_output_shape[dimension] = + (output_shape[dimension] / output_splits) + 1; } } @@ -904,23 +1020,6 @@ mlir::LogicalResult ValidateAndGetTiledExecuteOutputShape( return mlir::success(); } - -bool IsShapeKnown(mlir::TensorType type) { - if (!type.hasRank()) return false; - - bool shape_known = false; - for (int i = 0; i < type.getRank(); ++i) { - if (type.getShape()[i] == mlir::ShapedType::kDynamic) { - shape_known = false; - break; - } else { - shape_known = true; - } - } - - return shape_known; -} - } // namespace bool AreInputOutputShapesStaticallyKnownForSplitSharding( diff --git a/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util_test.cc index 84d5697c9a6c2b..a168ad9984041e 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util_test.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util_test.cc @@ -139,7 +139,6 @@ TEST(XLAShardingUtilTest, NotDivisibleShardingSplitOpTest) { module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0", "/job:localhost/replica:0/task:1/device:CPU:0", "/job:localhost/replica:0/task:1/device:TPU:0", "/job:localhost/replica:0/task:1/device:TPU:1", "/job:localhost/replica:0/task:1/device:TPU_SYSTEM:0"]} { func.func @uneven_input_sharding_disallowed(%arg0: tensor<128x10xf32>, %arg1: tensor<128x10xf32>, %arg2: tensor<*xi32>, %arg3: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) { %0:2, %1:2 = tf_device.replicate([%arg0, %arg1] as %ri_1: tensor<128x10xf32>, [%arg2, %arg3] as %ri_2: tensor<*xi32>) {n = 2 : i32} { - // expected-error@+1 {{incorrect input sharding configuration received. 1-th dimension of the input must be evenly divisible by 4}} %1, %2 = "tf_device.cluster_func"(%ri_1, %ri_2) {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @tpu0_func, num_cores_per_replica = 2, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", topology = "\0A\04\01\02\01\02\10\02\18\02\22\10\00\00\00\00\00\00\00\01\00\01\00\00\00\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0], input_sharding_configuration = ["\08\03\12\12\10\0b\1a\02\01\04\2a\06\0a\02\01\00\20\01\32\02\00\00\1a\02\01\04\22\04\00\01\02\03", "\08\01\1A\01\01\22\01\01"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00", ""], use_spmd_for_xla_partitioning = false} : (tensor<128x10xf32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) tf_device.return %1, %2 : tensor<*xi32>, tensor<*xi1> } @@ -165,6 +164,7 @@ TEST(XLAShardingUtilTest, NotDivisibleShardingSplitOpTest) { int num_cores_per_replica = 4; mlir::OpBuilder builder(&context); bool use_xla_nd_ops = true; + llvm::SmallVector, 4> input_list; auto result = tensorflow::ExtractInputsForLogicalDevices( num_cores_per_replica, cluster_func_op, &builder, use_xla_nd_ops, @@ -194,9 +194,30 @@ TEST(XLAShardingUtilTest, NotDivisibleShardingSplitOpTest) { // will appropriately add the values to the block. op->destroy(); + input_list.clear(); + // Expect error when use_xla_nd_ops is false. result = tensorflow::ExtractInputsForLogicalDevices( num_cores_per_replica, cluster_func_op, &builder, false, &input_list); - ASSERT_TRUE(failed(result)); + ASSERT_TRUE(succeeded(result)); + auto* split_op = input_list.front().front().getDefiningOp(); + ASSERT_TRUE(mlir::isa(split_op)); + + llvm::SmallVector split_inputs(split_op->getOperands()); + // Constant op for the split dimension + auto* const_op = split_inputs[0].getDefiningOp(); + ASSERT_TRUE(mlir::isa(const_op)); + // Pad op for the padding value to make it divisible by num_splits. + auto* pad_op = split_inputs[1].getDefiningOp(); + ASSERT_TRUE(mlir::isa(pad_op)); + llvm::SmallVector pad_inputs(pad_op->getOperands()); + auto* const_pad_value = pad_inputs[1].getDefiningOp(); + ASSERT_TRUE(mlir::isa(const_pad_value)); + // Destroy the ops to avoid error during block deletion (Same as above): + // use_empty() && "Cannot destroy a value that still has uses!" + split_op->destroy(); + const_op->destroy(); + pad_op->destroy(); + const_pad_value->destroy(); } } // namespace From d56766cb89e15c9ee5ca385bc9bf863aa5aa5047 Mon Sep 17 00:00:00 2001 From: Derek Murray Date: Fri, 20 Sep 2024 19:32:19 -0700 Subject: [PATCH 089/483] Add `SparseTensor` input validation to SparseCore conversion op. PiperOrigin-RevId: 677054398 --- .../tpu/kernels/sparse_core_preprocess_ops.cc | 31 ++++++++++++------- .../tpu/kernels/sparse_core_preprocess_ops.h | 2 +- 2 files changed, 21 insertions(+), 12 deletions(-) diff --git a/tensorflow/core/tpu/kernels/sparse_core_preprocess_ops.cc b/tensorflow/core/tpu/kernels/sparse_core_preprocess_ops.cc index 8548a92efe0495..e25889827a49f3 100644 --- a/tensorflow/core/tpu/kernels/sparse_core_preprocess_ops.cc +++ b/tensorflow/core/tpu/kernels/sparse_core_preprocess_ops.cc @@ -92,7 +92,9 @@ Status ValidateInputs(const Tensor& indices_or_row_splits, const Tensor& values, // tensor. } else if (indices_or_row_splits.dims() == 2 && indices_or_row_splits.NumElements() >= 0) { - // TODO(pineapplejuice233): Add checking logic for sparse tensor input. + // NOTE(mrry): Checking logic for SparseTensor inputs is in + // `ComputeRowIdsBeforePadding()`, to avoid an extra traversal of the + // indices matrix. } else if (indices_or_row_splits.dims() == 1 && indices_or_row_splits.NumElements() > 0) { // Ragged tensor. @@ -114,6 +116,7 @@ Status ValidateInputs(const Tensor& indices_or_row_splits, const Tensor& values, Status ComputeRowIdsBeforePadding(const Tensor& indices_or_row_splits, const int32 total_id_count, + const int32 sample_count, int32* row_ids_before_padding) { // The only difference between dense tensor, sparse tensor and ragged tensor // is the row ids output. @@ -140,7 +143,14 @@ Status ComputeRowIdsBeforePadding(const Tensor& indices_or_row_splits, if (current_row_id < previous_row_id) { return absl::InvalidArgumentError( "Invalid indices_or_row_splits input, indices of SparseTensor need " - "to be sorted in ascending order."); + "to be sorted in ascending (non-decreasing) order."); + } + if (current_row_id >= sample_count) { + return absl::InvalidArgumentError(absl::StrCat( + "Invalid indices_or_row_splits input, indices of SparseTensor " + "contained a row_id ", + current_row_id, " that was >= the sample count (", sample_count, + ").")); } *(row_ids_before_padding + i) = current_row_id; previous_row_id = current_row_id; @@ -309,9 +319,9 @@ class ConvertToCooTensorOp : public OpKernel { auto row_ids_before_dedup = std::make_unique(total_id_count); - OP_REQUIRES_OK( - ctx, ComputeRowIdsBeforePadding(*indices_or_row_splits, total_id_count, - row_ids_before_dedup.get())); + OP_REQUIRES_OK(ctx, ComputeRowIdsBeforePadding( + *indices_or_row_splits, total_id_count, + sample_count_, row_ids_before_dedup.get())); // Compute the rescaled gains for non-sum combiners. std::optional> gains_rescale = @@ -520,9 +530,8 @@ void GetMinibatchesInCsrWithPhysicalReplicaOp::Compute(OpKernelContext* ctx) { "The number of minibatches per sparse core is ", num_minibatch_per_sc, ". But the max minibatches per sparse core is set to be ", max_minibatches_per_sc_, " which is smaller."))); - VLOG(2) << "GetMinibatchesInCsrWithPhysicalReplicaOp: " - << "program_key = '" << program_key << "'" - << ", table_name = '" << table_name_ << "'" + VLOG(2) << "GetMinibatchesInCsrWithPhysicalReplicaOp: " << "program_key = '" + << program_key << "'" << ", table_name = '" << table_name_ << "'" << ", max_ids = " << max_ids_per_partition << ", max_uniques = " << max_unique_ids_per_partition << ", num_minibatch_per_sc = " << num_minibatch_per_sc; @@ -1213,9 +1222,9 @@ void ConvertToListOfSparseCoreCooTensorsOp::Compute(OpKernelContext* ctx) { auto row_ids_before_dedup = std::unique_ptr( new std::remove_extent_t[total_id_count]); - OP_REQUIRES_OK( - ctx, ComputeRowIdsBeforePadding(*indices_or_row_splits, total_id_count, - row_ids_before_dedup.get())); + OP_REQUIRES_OK(ctx, ComputeRowIdsBeforePadding(*indices_or_row_splits, + total_id_count, sample_count_, + row_ids_before_dedup.get())); // Compute the rescaled gains for non-sum combiners. std::optional> gains_rescale = diff --git a/tensorflow/core/tpu/kernels/sparse_core_preprocess_ops.h b/tensorflow/core/tpu/kernels/sparse_core_preprocess_ops.h index ce43521cbc5147..d3651d04de2d6e 100644 --- a/tensorflow/core/tpu/kernels/sparse_core_preprocess_ops.h +++ b/tensorflow/core/tpu/kernels/sparse_core_preprocess_ops.h @@ -55,7 +55,7 @@ Status ValidateInputs(const Tensor& indices_or_row_splits, const Tensor& values, // Compute the row id list before padding. Status ComputeRowIdsBeforePadding(const Tensor& indices_or_row_splits, - int32 total_id_count, + int32 total_id_count, int32 sample_count, int32* row_ids_before_padding); class GetMinibatchesInCsrWithPhysicalReplicaOp : public OpKernel { From bd3835e17de301a2fae97f4e4bed8f9c883384bd Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 20 Sep 2024 20:50:15 -0700 Subject: [PATCH 090/483] Automated Code Change PiperOrigin-RevId: 677075333 --- tensorflow/lite/delegates/gpu/gl/compiler/BUILD | 9 +++++++++ .../lite/delegates/gpu/gl/compiler/compiled_node.cc | 1 + .../lite/delegates/gpu/gl/compiler/fuse_auto_input.cc | 4 ++-- .../delegates/gpu/gl/compiler/fuse_auto_input_test.cc | 3 +++ tensorflow/lite/delegates/gpu/gl/compiler/fuse_inline.cc | 4 ++-- .../lite/delegates/gpu/gl/compiler/fuse_inplace.cc | 1 + .../lite/delegates/gpu/gl/compiler/object_accessor.cc | 6 +++++- .../lite/delegates/gpu/gl/compiler/object_accessor.h | 1 + .../delegates/gpu/gl/compiler/object_accessor_test.cc | 2 ++ .../lite/delegates/gpu/gl/compiler/preprocessor.cc | 1 + .../lite/delegates/gpu/gl/compiler/preprocessor_test.cc | 1 + tensorflow/lite/delegates/gpu/gl/compiler/rename.cc | 3 +++ .../lite/delegates/gpu/gl/compiler/variable_accessor.cc | 3 +++ .../lite/delegates/gpu/gl/compiler/variable_accessor.h | 1 + .../delegates/gpu/gl/compiler/variable_accessor_test.cc | 2 +- 15 files changed, 36 insertions(+), 6 deletions(-) diff --git a/tensorflow/lite/delegates/gpu/gl/compiler/BUILD b/tensorflow/lite/delegates/gpu/gl/compiler/BUILD index 8b13e0ff92cdbb..795dc219f9037d 100644 --- a/tensorflow/lite/delegates/gpu/gl/compiler/BUILD +++ b/tensorflow/lite/delegates/gpu/gl/compiler/BUILD @@ -25,6 +25,7 @@ cc_test( ], deps = [ ":preprocessor", + "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest_main", ], ) @@ -36,6 +37,7 @@ cc_library( deps = [ ":preprocessor", ":variable_accessor", + "//tensorflow/lite/delegates/gpu/common:access_type", "//tensorflow/lite/delegates/gpu/common:data_type", "//tensorflow/lite/delegates/gpu/common:types", "//tensorflow/lite/delegates/gpu/gl:object", @@ -53,8 +55,10 @@ cc_test( ], deps = [ ":object_accessor", + ":preprocessor", ":variable_accessor", "//tensorflow/lite/delegates/gpu/common:types", + "//tensorflow/lite/delegates/gpu/gl:object", "//tensorflow/lite/delegates/gpu/gl:variable", "@com_google_absl//absl/types:variant", "@com_google_googletest//:gtest_main", @@ -178,6 +182,7 @@ cc_library( "//tensorflow/lite/delegates/gpu/common:model_transformer", "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:types", + "//tensorflow/lite/delegates/gpu/gl:node_shader", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:any", @@ -196,6 +201,9 @@ cc_test( deps = [ ":compiled_node", ":fuse_auto_input", + "//tensorflow/lite/delegates/gpu/common:model", + "//tensorflow/lite/delegates/gpu/common:model_transformer", + "//tensorflow/lite/delegates/gpu/gl:node_shader", "@com_google_absl//absl/types:any", "@com_google_googletest//:gtest", "@com_google_googletest//:gtest_main", @@ -225,6 +233,7 @@ cc_test( "tflite_not_portable_ios", ], deps = [ + ":preprocessor", ":variable_accessor", "//tensorflow/lite/delegates/gpu/common:types", "@com_google_googletest//:gtest_main", diff --git a/tensorflow/lite/delegates/gpu/gl/compiler/compiled_node.cc b/tensorflow/lite/delegates/gpu/gl/compiler/compiled_node.cc index 75928dae5f204c..58cf0af1967136 100644 --- a/tensorflow/lite/delegates/gpu/gl/compiler/compiled_node.cc +++ b/tensorflow/lite/delegates/gpu/gl/compiler/compiled_node.cc @@ -20,6 +20,7 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "tensorflow/lite/delegates/gpu/common/status.h" #include "tensorflow/lite/delegates/gpu/gl/compiler/rename.h" diff --git a/tensorflow/lite/delegates/gpu/gl/compiler/fuse_auto_input.cc b/tensorflow/lite/delegates/gpu/gl/compiler/fuse_auto_input.cc index 761fb8b4602246..985da96ebff678 100644 --- a/tensorflow/lite/delegates/gpu/gl/compiler/fuse_auto_input.cc +++ b/tensorflow/lite/delegates/gpu/gl/compiler/fuse_auto_input.cc @@ -23,11 +23,11 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/str_replace.h" #include "absl/types/any.h" -#include "absl/types/variant.h" #include "tensorflow/lite/delegates/gpu/common/model.h" -#include "tensorflow/lite/delegates/gpu/common/operations.h" +#include "tensorflow/lite/delegates/gpu/common/model_transformer.h" #include "tensorflow/lite/delegates/gpu/common/types.h" #include "tensorflow/lite/delegates/gpu/gl/compiler/compiled_node.h" +#include "tensorflow/lite/delegates/gpu/gl/node_shader.h" namespace tflite { namespace gpu { diff --git a/tensorflow/lite/delegates/gpu/gl/compiler/fuse_auto_input_test.cc b/tensorflow/lite/delegates/gpu/gl/compiler/fuse_auto_input_test.cc index 403617366912ee..61c3114a3a0d88 100644 --- a/tensorflow/lite/delegates/gpu/gl/compiler/fuse_auto_input_test.cc +++ b/tensorflow/lite/delegates/gpu/gl/compiler/fuse_auto_input_test.cc @@ -20,7 +20,10 @@ limitations under the License. #include #include #include "absl/types/any.h" +#include "tensorflow/lite/delegates/gpu/common/model.h" +#include "tensorflow/lite/delegates/gpu/common/model_transformer.h" #include "tensorflow/lite/delegates/gpu/gl/compiler/compiled_node.h" +#include "tensorflow/lite/delegates/gpu/gl/node_shader.h" namespace tflite { namespace gpu { diff --git a/tensorflow/lite/delegates/gpu/gl/compiler/fuse_inline.cc b/tensorflow/lite/delegates/gpu/gl/compiler/fuse_inline.cc index f227ab2147847a..6fae121e02cf36 100644 --- a/tensorflow/lite/delegates/gpu/gl/compiler/fuse_inline.cc +++ b/tensorflow/lite/delegates/gpu/gl/compiler/fuse_inline.cc @@ -23,12 +23,12 @@ limitations under the License. #include "absl/strings/match.h" #include "absl/strings/str_cat.h" -#include "absl/strings/string_view.h" #include "absl/types/any.h" +#include "tensorflow/lite/delegates/gpu/common/model.h" +#include "tensorflow/lite/delegates/gpu/common/model_transformer.h" #include "tensorflow/lite/delegates/gpu/common/status.h" #include "tensorflow/lite/delegates/gpu/common/types.h" #include "tensorflow/lite/delegates/gpu/gl/compiler/compiled_node.h" -#include "tensorflow/lite/delegates/gpu/gl/compiler/shader_code.h" #include "tensorflow/lite/delegates/gpu/gl/node_shader.h" namespace tflite { diff --git a/tensorflow/lite/delegates/gpu/gl/compiler/fuse_inplace.cc b/tensorflow/lite/delegates/gpu/gl/compiler/fuse_inplace.cc index 1e27404b741bde..19e520be166f04 100644 --- a/tensorflow/lite/delegates/gpu/gl/compiler/fuse_inplace.cc +++ b/tensorflow/lite/delegates/gpu/gl/compiler/fuse_inplace.cc @@ -24,6 +24,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/types/any.h" #include "tensorflow/lite/delegates/gpu/common/model.h" +#include "tensorflow/lite/delegates/gpu/common/model_transformer.h" #include "tensorflow/lite/delegates/gpu/common/types.h" #include "tensorflow/lite/delegates/gpu/gl/compiler/compiled_node.h" #include "tensorflow/lite/delegates/gpu/gl/compiler/preprocessor.h" diff --git a/tensorflow/lite/delegates/gpu/gl/compiler/object_accessor.cc b/tensorflow/lite/delegates/gpu/gl/compiler/object_accessor.cc index 11228fd2efe58b..43e9fa83e4c9e1 100644 --- a/tensorflow/lite/delegates/gpu/gl/compiler/object_accessor.cc +++ b/tensorflow/lite/delegates/gpu/gl/compiler/object_accessor.cc @@ -21,12 +21,16 @@ limitations under the License. #include "absl/strings/ascii.h" #include "absl/strings/str_cat.h" -#include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" #include "absl/types/variant.h" +#include "tensorflow/lite/delegates/gpu/common/access_type.h" #include "tensorflow/lite/delegates/gpu/common/data_type.h" #include "tensorflow/lite/delegates/gpu/common/types.h" +#include "tensorflow/lite/delegates/gpu/gl/compiler/preprocessor.h" +#include "tensorflow/lite/delegates/gpu/gl/compiler/variable_accessor.h" +#include "tensorflow/lite/delegates/gpu/gl/object.h" namespace tflite { namespace gpu { diff --git a/tensorflow/lite/delegates/gpu/gl/compiler/object_accessor.h b/tensorflow/lite/delegates/gpu/gl/compiler/object_accessor.h index 318709fe7ff235..74273a6864193e 100644 --- a/tensorflow/lite/delegates/gpu/gl/compiler/object_accessor.h +++ b/tensorflow/lite/delegates/gpu/gl/compiler/object_accessor.h @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "absl/strings/string_view.h" #include "tensorflow/lite/delegates/gpu/gl/compiler/preprocessor.h" #include "tensorflow/lite/delegates/gpu/gl/compiler/variable_accessor.h" #include "tensorflow/lite/delegates/gpu/gl/object.h" diff --git a/tensorflow/lite/delegates/gpu/gl/compiler/object_accessor_test.cc b/tensorflow/lite/delegates/gpu/gl/compiler/object_accessor_test.cc index 4bf9482436506a..fbca570d892f2f 100644 --- a/tensorflow/lite/delegates/gpu/gl/compiler/object_accessor_test.cc +++ b/tensorflow/lite/delegates/gpu/gl/compiler/object_accessor_test.cc @@ -23,7 +23,9 @@ limitations under the License. #include #include "absl/types/variant.h" #include "tensorflow/lite/delegates/gpu/common/types.h" +#include "tensorflow/lite/delegates/gpu/gl/compiler/preprocessor.h" #include "tensorflow/lite/delegates/gpu/gl/compiler/variable_accessor.h" +#include "tensorflow/lite/delegates/gpu/gl/object.h" #include "tensorflow/lite/delegates/gpu/gl/variable.h" namespace tflite { diff --git a/tensorflow/lite/delegates/gpu/gl/compiler/preprocessor.cc b/tensorflow/lite/delegates/gpu/gl/compiler/preprocessor.cc index 16db8945dece21..173c281e331fcd 100644 --- a/tensorflow/lite/delegates/gpu/gl/compiler/preprocessor.cc +++ b/tensorflow/lite/delegates/gpu/gl/compiler/preprocessor.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "tensorflow/lite/delegates/gpu/common/status.h" namespace tflite { diff --git a/tensorflow/lite/delegates/gpu/gl/compiler/preprocessor_test.cc b/tensorflow/lite/delegates/gpu/gl/compiler/preprocessor_test.cc index 95fcf6244606f4..d4b7cf4157916c 100644 --- a/tensorflow/lite/delegates/gpu/gl/compiler/preprocessor_test.cc +++ b/tensorflow/lite/delegates/gpu/gl/compiler/preprocessor_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "absl/strings/string_view.h" namespace tflite { namespace gpu { diff --git a/tensorflow/lite/delegates/gpu/gl/compiler/rename.cc b/tensorflow/lite/delegates/gpu/gl/compiler/rename.cc index fd418770104444..1a05bfa2d87050 100644 --- a/tensorflow/lite/delegates/gpu/gl/compiler/rename.cc +++ b/tensorflow/lite/delegates/gpu/gl/compiler/rename.cc @@ -21,13 +21,16 @@ limitations under the License. #include #include "absl/container/flat_hash_map.h" +#include "absl/strings/ascii.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" #include "tensorflow/lite/delegates/gpu/common/status.h" #include "tensorflow/lite/delegates/gpu/gl/compiler/object_accessor.h" #include "tensorflow/lite/delegates/gpu/gl/compiler/preprocessor.h" #include "tensorflow/lite/delegates/gpu/gl/compiler/variable_accessor.h" +#include "tensorflow/lite/delegates/gpu/gl/node_shader.h" #include "tensorflow/lite/delegates/gpu/gl/object.h" #include "tensorflow/lite/delegates/gpu/gl/variable.h" diff --git a/tensorflow/lite/delegates/gpu/gl/compiler/variable_accessor.cc b/tensorflow/lite/delegates/gpu/gl/compiler/variable_accessor.cc index b55c480654146f..d1a7fd78e1a87b 100644 --- a/tensorflow/lite/delegates/gpu/gl/compiler/variable_accessor.cc +++ b/tensorflow/lite/delegates/gpu/gl/compiler/variable_accessor.cc @@ -22,8 +22,11 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" #include "absl/types/variant.h" #include "tensorflow/lite/delegates/gpu/common/types.h" +#include "tensorflow/lite/delegates/gpu/gl/compiler/preprocessor.h" +#include "tensorflow/lite/delegates/gpu/gl/variable.h" namespace tflite { namespace gpu { diff --git a/tensorflow/lite/delegates/gpu/gl/compiler/variable_accessor.h b/tensorflow/lite/delegates/gpu/gl/compiler/variable_accessor.h index f6d5344d3b345e..0eb01c0ea284f5 100644 --- a/tensorflow/lite/delegates/gpu/gl/compiler/variable_accessor.h +++ b/tensorflow/lite/delegates/gpu/gl/compiler/variable_accessor.h @@ -21,6 +21,7 @@ limitations under the License. #include #include "absl/container/flat_hash_map.h" +#include "absl/strings/string_view.h" #include "tensorflow/lite/delegates/gpu/gl/compiler/preprocessor.h" #include "tensorflow/lite/delegates/gpu/gl/variable.h" diff --git a/tensorflow/lite/delegates/gpu/gl/compiler/variable_accessor_test.cc b/tensorflow/lite/delegates/gpu/gl/compiler/variable_accessor_test.cc index 0e8be2a577ba75..20ac0368c66644 100644 --- a/tensorflow/lite/delegates/gpu/gl/compiler/variable_accessor_test.cc +++ b/tensorflow/lite/delegates/gpu/gl/compiler/variable_accessor_test.cc @@ -18,9 +18,9 @@ limitations under the License. #include #include -#include #include #include "tensorflow/lite/delegates/gpu/common/types.h" +#include "tensorflow/lite/delegates/gpu/gl/compiler/preprocessor.h" namespace tflite { namespace gpu { From 8c76d8df64fc68bb9d865e20ce0f7472cedf8886 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 20 Sep 2024 21:37:25 -0700 Subject: [PATCH 091/483] Automated Code Change PiperOrigin-RevId: 677085654 --- tensorflow/compiler/mlir/quantization/tensorflow/python/BUILD | 1 + .../mlir/quantization/tensorflow/python/pywrap_function_lib.cc | 1 - .../mlir/quantization/tensorflow/python/quantize_model.cc | 2 -- .../mlir/quantization/tensorflow/python/quantize_model.h | 1 + .../mlir/quantization/tensorflow/python/unfreeze_constants.cc | 1 - 5 files changed, 2 insertions(+), 4 deletions(-) diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/BUILD b/tensorflow/compiler/mlir/quantization/tensorflow/python/BUILD index c0a472ca8f2e26..41e2f04651aa20 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/BUILD +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/BUILD @@ -93,6 +93,7 @@ cc_library( ":py_function_lib", "//tensorflow/compiler/mlir/quantization/tensorflow:exported_model_proto_cc", "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc", + "//tensorflow/core/protobuf:for_core_protos_cc", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_function_lib.cc b/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_function_lib.cc index fc181edb8a75f5..499a496c572153 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_function_lib.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_function_lib.cc @@ -17,7 +17,6 @@ limitations under the License. #include #include "absl/container/flat_hash_map.h" -#include "absl/status/status.h" #include "absl/strings/string_view.h" #include "pybind11/cast.h" // from @pybind11 #include "pybind11/detail/common.h" // from @pybind11 diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.cc b/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.cc index e38310879184ef..016ba8dd41ce6f 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.cc @@ -49,11 +49,9 @@ limitations under the License. #include "tensorflow/compiler/mlir/quantization/stablehlo/cc/types.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/cc/weight_only_ptq.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" -#include "tensorflow/compiler/mlir/quantization/tensorflow/cc/convert_asset_args.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/cc/run_passes.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/exported_model.pb.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h" -#include "tensorflow/compiler/mlir/quantization/tensorflow/python/unfreeze_constants.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/quantize_passes.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/quantize_preprocess.h" diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.h b/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.h index a54e988c043aa3..9e36ce52f74cbc 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.h +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.h @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/quantization/tensorflow/exported_model.pb.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" +#include "tensorflow/core/protobuf/meta_graph.pb.h" namespace tensorflow { namespace quantization { diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/unfreeze_constants.cc b/tensorflow/compiler/mlir/quantization/tensorflow/python/unfreeze_constants.cc index c3f5c32bdd9720..e7086c57ddc2c2 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/unfreeze_constants.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/unfreeze_constants.cc @@ -26,7 +26,6 @@ limitations under the License. #include "tensorflow/compiler/mlir/quantization/tensorflow/passes/passes.h" #include "tensorflow/core/platform/env.h" #include "tsl/platform/errors.h" -#include "tsl/platform/status.h" #include "tsl/platform/statusor.h" namespace tensorflow { From 8402beda0ad3e6a44feb4dff6d5e959f23ac9306 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 20 Sep 2024 21:51:34 -0700 Subject: [PATCH 092/483] Automated Code Change PiperOrigin-RevId: 677088039 --- tensorflow/core/tfrt/kernels/ifrt_program_ops_test.cc | 6 ------ 1 file changed, 6 deletions(-) diff --git a/tensorflow/core/tfrt/kernels/ifrt_program_ops_test.cc b/tensorflow/core/tfrt/kernels/ifrt_program_ops_test.cc index cd29511d4d982f..11c7157925dad3 100644 --- a/tensorflow/core/tfrt/kernels/ifrt_program_ops_test.cc +++ b/tensorflow/core/tfrt/kernels/ifrt_program_ops_test.cc @@ -20,16 +20,10 @@ limitations under the License. #include #include -#include "absl/log/check.h" -#include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" -#include "xla/pjrt/cpu/cpu_client.h" -#include "xla/python/ifrt/array.h" -#include "xla/python/ifrt/client.h" #include "xla/python/ifrt/test_util.h" -#include "xla/python/pjrt_ifrt/pjrt_client.h" #include "xla/tsl/framework/serving_device_selector.h" #include "xla/tsl/framework/test_util/mock_serving_device_selector.h" #include "xla/tsl/lib/core/status_test_util.h" From 344937159598602e91dce183f9a22fcb6792a863 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 20 Sep 2024 22:22:10 -0700 Subject: [PATCH 093/483] Automated Code Change PiperOrigin-RevId: 677093692 --- tensorflow/compiler/mlir/tfrt/BUILD | 11 +++++++++++ .../compiler/mlir/tfrt/transforms/optimize.cc | 16 ++++++++++++++++ .../optimize_tf_control_flow_side_effect.cc | 11 ++++++++++- .../compiler/mlir/tfrt/transforms/passes.cc | 6 ++++++ .../compiler/mlir/tfrt/transforms/passes.h | 4 ++++ .../tfrt/transforms/remove_device_attribute.cc | 10 ++++++---- .../tfrt/transforms/remove_tf_if_const_args.cc | 17 ++++++++++++++++- .../mlir/tfrt/transforms/reorder_assert.cc | 10 +++++++++- .../set_shape_invariant_in_while_ops.cc | 4 +++- .../tfrt/transforms/sink_in_invariant_ops.cc | 11 ++++------- .../compiler/mlir/tfrt/transforms/tf_to_tfrt.cc | 17 +++++++++++++++++ .../transforms/update_op_cost_in_tfrt_mlir.cc | 3 +++ .../compiler/mlir/tfrt/transforms/utils.cc | 9 +++++---- .../mlir/tfrt/transforms/xla_rewrite_pass.cc | 11 +++++++++++ 14 files changed, 121 insertions(+), 19 deletions(-) diff --git a/tensorflow/compiler/mlir/tfrt/BUILD b/tensorflow/compiler/mlir/tfrt/BUILD index 5c1519e43c2bcb..ff635198f2071d 100644 --- a/tensorflow/compiler/mlir/tfrt/BUILD +++ b/tensorflow/compiler/mlir/tfrt/BUILD @@ -228,6 +228,7 @@ cc_library( ":cost_analysis", ":fallback_converter", ":tensor_array_side_effect_analysis", + ":tfrt_compile_options", ":tfrt_pipeline_options", ":tpu_passes", ":transform_utils", @@ -248,8 +249,12 @@ cc_library( "//tensorflow/compiler/mlir/tfrt/ir:tfrt_gpu_opdefs", "//tensorflow/compiler/tf2xla:tf2xla_defs", "//tensorflow/core:framework", + "//tensorflow/core/ir/types:Dialect", + "//tensorflow/core/platform:errors", "//tensorflow/core/platform:status", + "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:string_view", "@llvm-project//llvm:Support", @@ -257,10 +262,12 @@ cc_library( "@llvm-project//mlir:FuncExtensions", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Rewrite", "@llvm-project//mlir:SideEffectInterfaces", "@llvm-project//mlir:Support", "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", + "@local_tsl//tsl/platform:errors", "@tf_runtime//:basic_kernels_opdefs", "@tf_runtime//:core_runtime_opdefs", "@tf_runtime//:stream_analysis", @@ -504,7 +511,9 @@ cc_library( deps = [ "//tensorflow/compiler/mlir/tensorflow", "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", "@llvm-project//mlir:Transforms", ], ) @@ -658,7 +667,9 @@ cc_library( deps = [ "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops", + "//tensorflow/compiler/mlir/tensorflow/ir/host_runtime:tensorflow_tfrt_ops", "@com_google_absl//absl/strings", + "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", diff --git a/tensorflow/compiler/mlir/tfrt/transforms/optimize.cc b/tensorflow/compiler/mlir/tfrt/transforms/optimize.cc index 0e47fad312c7cc..202aa9c8d2f9ec 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/optimize.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/optimize.cc @@ -12,11 +12,27 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/DenseMapInfo.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/Casting.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Block.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/OperationSupport.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Rewrite/FrozenRewritePatternSet.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tfrt/transforms/passes.h" +#include "tensorflow/core/util/device_name_utils.h" namespace tensorflow { namespace tfrt_compiler { diff --git a/tensorflow/compiler/mlir/tfrt/transforms/optimize_tf_control_flow_side_effect.cc b/tensorflow/compiler/mlir/tfrt/transforms/optimize_tf_control_flow_side_effect.cc index 1c227426ce99e1..c9e3c79d6b7b79 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/optimize_tf_control_flow_side_effect.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/optimize_tf_control_flow_side_effect.cc @@ -13,9 +13,18 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Casting.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project -#include "mlir/Transforms/Passes.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tfrt/analysis/tensor_array_side_effect_analysis.h" #include "tensorflow/compiler/mlir/tfrt/transforms/passes.h" diff --git a/tensorflow/compiler/mlir/tfrt/transforms/passes.cc b/tensorflow/compiler/mlir/tfrt/transforms/passes.cc index 2b97ec6a9536ac..cda19dc2157651 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/passes.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/passes.cc @@ -21,14 +21,20 @@ limitations under the License. #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Pass/PassOptions.h" #include "mlir/Transforms/Passes.h" +#include "absl/log/log.h" +#include "absl/status/status.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_asset_sinking_pass.h" #include "tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.h" #include "tensorflow/compiler/mlir/tfrt/transforms/set_shape_invariant_in_while_ops.h" +#include "tensorflow/compiler/mlir/tfrt/transforms/tfrt_pipeline_options.h" #include "tensorflow/core/framework/types.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/status.h" #include "tensorflow/core/util/device_name_utils.h" +#include "tsl/platform/errors.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/mlir/tfrt/transforms/passes.h b/tensorflow/compiler/mlir/tfrt/transforms/passes.h index a883db5e479268..7b1f322712fd47 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/passes.h +++ b/tensorflow/compiler/mlir/tfrt/transforms/passes.h @@ -21,7 +21,11 @@ limitations under the License. #include "llvm/Support/CommandLine.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassOptions.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h" #include "tensorflow/compiler/mlir/tfrt/transforms/tfrt_pipeline_options.h" diff --git a/tensorflow/compiler/mlir/tfrt/transforms/remove_device_attribute.cc b/tensorflow/compiler/mlir/tfrt/transforms/remove_device_attribute.cc index c765e08742ae24..6905ede4f2dbca 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/remove_device_attribute.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/remove_device_attribute.cc @@ -15,13 +15,15 @@ limitations under the License. // This pass removes the device attribute from every corert.executeop. -#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project -#include "mlir/IR/Types.h" // from @llvm-project -#include "mlir/Pass/PassManager.h" // from @llvm-project -#include "mlir/Transforms/Passes.h" // from @llvm-project +#include "mlir/IR/Visitors.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project #include "tfrt/core_runtime/opdefs/core_runtime.h" // from @tf_runtime namespace tensorflow { diff --git a/tensorflow/compiler/mlir/tfrt/transforms/remove_tf_if_const_args.cc b/tensorflow/compiler/mlir/tfrt/transforms/remove_tf_if_const_args.cc index d855fa41344905..d60a09b14e5656 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/remove_tf_if_const_args.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/remove_tf_if_const_args.cc @@ -13,8 +13,23 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/BitVector.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Casting.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Transforms/Passes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tfrt/transforms/passes.h" diff --git a/tensorflow/compiler/mlir/tfrt/transforms/reorder_assert.cc b/tensorflow/compiler/mlir/tfrt/transforms/reorder_assert.cc index eb6eb9dbdaee77..69fb2a858b092d 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/reorder_assert.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/reorder_assert.cc @@ -12,8 +12,16 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Casting.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Transforms/Passes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tfrt/transforms/passes.h" diff --git a/tensorflow/compiler/mlir/tfrt/transforms/set_shape_invariant_in_while_ops.cc b/tensorflow/compiler/mlir/tfrt/transforms/set_shape_invariant_in_while_ops.cc index a80d1ba7e180ef..f8343c034e0ce3 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/set_shape_invariant_in_while_ops.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/set_shape_invariant_in_while_ops.cc @@ -15,7 +15,9 @@ limitations under the License. #include "tensorflow/compiler/mlir/tfrt/transforms/set_shape_invariant_in_while_ops.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Transforms/Passes.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" namespace tensorflow { diff --git a/tensorflow/compiler/mlir/tfrt/transforms/sink_in_invariant_ops.cc b/tensorflow/compiler/mlir/tfrt/transforms/sink_in_invariant_ops.cc index 8bdb39c913bf75..5645bdf16c11fe 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/sink_in_invariant_ops.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/sink_in_invariant_ops.cc @@ -22,8 +22,6 @@ limitations under the License. #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseSet.h" -#include "llvm/ADT/MapVector.h" -#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/Casting.h" @@ -31,17 +29,16 @@ limitations under the License. #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project -#include "mlir/IR/DialectRegistry.h" // from @llvm-project -#include "mlir/IR/Location.h" // from @llvm-project -#include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/SymbolTable.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project #include "mlir/IR/Visitors.h" // from @llvm-project -#include "mlir/Interfaces/CallInterfaces.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" -#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/mlir/tfrt/transforms/tf_to_tfrt.cc b/tensorflow/compiler/mlir/tfrt/transforms/tf_to_tfrt.cc index f090745e0ae1c4..0f71991b2f8f82 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/tf_to_tfrt.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/tf_to_tfrt.cc @@ -29,24 +29,36 @@ limitations under the License. #include "mlir/Pass/PassOptions.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/Passes.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/Sequence.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/iterator_range.h" #include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" +#include "llvm/Support/ErrorHandling.h" #include "mlir/Dialect/Func/Extensions/AllExtensions.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project +#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/host_runtime/tfrt_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h" #include "tensorflow/compiler/mlir/tfrt/analysis/cost_analysis.h" #include "tensorflow/compiler/mlir/tfrt/analysis/tensor_array_side_effect_analysis.h" @@ -58,10 +70,15 @@ limitations under the License. #include "tensorflow/compiler/mlir/tfrt/transforms/fallback_converter.h" #include "tensorflow/compiler/mlir/tfrt/transforms/gpu_passes.h" #include "tensorflow/compiler/mlir/tfrt/transforms/passes.h" +#include "tensorflow/compiler/mlir/tfrt/transforms/tfrt_pipeline_options.h" +#include "tensorflow/compiler/mlir/tfrt/transforms/tpu_passes.h" #include "tensorflow/compiler/mlir/tfrt/transforms/utils.h" +#include "tensorflow/compiler/mlir/tfrt/translate/tfrt_compile_options.h" #include "tensorflow/compiler/tf2xla/tf2xla_defs.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/types.h" +#include "tensorflow/core/platform/status.h" +#include "tsl/platform/errors.h" #include "tfrt/basic_kernels/opdefs/basic_kernels.h" // from @tf_runtime #include "tfrt/basic_kernels/opdefs/tfrt_base.h" // from @tf_runtime #include "tfrt/basic_kernels/opdefs/types.h" // from @tf_runtime diff --git a/tensorflow/compiler/mlir/tfrt/transforms/update_op_cost_in_tfrt_mlir.cc b/tensorflow/compiler/mlir/tfrt/transforms/update_op_cost_in_tfrt_mlir.cc index 24dcb1904fa4ed..bebd279e28fadf 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/update_op_cost_in_tfrt_mlir.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/update_op_cost_in_tfrt_mlir.cc @@ -15,7 +15,10 @@ limitations under the License. #include "tensorflow/compiler/mlir/tfrt/transforms/update_op_cost_in_tfrt_mlir.h" #include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project #include "tensorflow/compiler/mlir/tfrt/analysis/cost_analysis.h" +#include "tensorflow/core/tfrt/fallback/cost_recorder.h" namespace tensorflow { namespace tfrt_compiler { diff --git a/tensorflow/compiler/mlir/tfrt/transforms/utils.cc b/tensorflow/compiler/mlir/tfrt/transforms/utils.cc index 711438f21d13f9..cf65e50af55abb 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/utils.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/utils.cc @@ -18,19 +18,20 @@ limitations under the License. #include #include +#include "absl/strings/string_view.h" +#include "llvm/Support/Casting.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/OpDefinition.h" // from @llvm-project #include "mlir/IR/OperationSupport.h" // from @llvm-project -#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/host_runtime/tfrt_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h" -#include "tfrt/basic_kernels/opdefs/tfrt_base.h" // from @tf_runtime -#include "tfrt/basic_kernels/opdefs/types.h" // from @tf_runtime -#include "tfrt/core_runtime/opdefs/types.h" // from @tf_runtime namespace tensorflow { diff --git a/tensorflow/compiler/mlir/tfrt/transforms/xla_rewrite_pass.cc b/tensorflow/compiler/mlir/tfrt/transforms/xla_rewrite_pass.cc index cc28f332cb96aa..57b07c69bf2b55 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/xla_rewrite_pass.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/xla_rewrite_pass.cc @@ -16,11 +16,22 @@ limitations under the License. #include #include +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Casting.h" +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tfrt/transforms/passes.h" +#include "tensorflow/core/ir/types/dialect.h" namespace tensorflow { namespace tfrt_compiler { From 7b9d17358665bfd12d5d8520af214ad7e4522d16 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 20 Sep 2024 23:22:56 -0700 Subject: [PATCH 094/483] IFRT proxy: Some refactoring for an upcoming change. This CL does the following: 1. Remove some unused fields in RequestMetadata. 2. Make op_id in the RequestMetadata be generated within client_session. In addition to making things easier for the upcoming CL, this change makes the entirety of the RequestMetadata and the ResponseMetadata be managed within client_session, instead of having that logic spread across. 3. Separate out convenience classes from the current code that will be reused by the upcoming CL: a `Queue` implementation in test_utils.h and a `XFlowHelper` in rpc_helper.cc. PiperOrigin-RevId: 677105612 --- .../xla/xla/python/ifrt_proxy/client/BUILD | 1 + .../ifrt_proxy/client/grpc_client_session.cc | 10 +- .../ifrt_proxy/client/grpc_client_session.h | 4 + .../client/grpc_client_session_test.cc | 91 ++++--------- .../python/ifrt_proxy/client/rpc_helper.cc | 125 +++++++++++------- .../xla/python/ifrt_proxy/client/rpc_helper.h | 4 - .../xla/xla/python/ifrt_proxy/common/BUILD | 11 ++ .../ifrt_proxy/common/ifrt_service.proto | 23 +--- .../xla/python/ifrt_proxy/common/test_utils.h | 96 ++++++++++++++ 9 files changed, 221 insertions(+), 144 deletions(-) create mode 100644 third_party/xla/xla/python/ifrt_proxy/common/test_utils.h diff --git a/third_party/xla/xla/python/ifrt_proxy/client/BUILD b/third_party/xla/xla/python/ifrt_proxy/client/BUILD index 8e947a4e68beac..1ca0e1527fbb26 100644 --- a/third_party/xla/xla/python/ifrt_proxy/client/BUILD +++ b/third_party/xla/xla/python/ifrt_proxy/client/BUILD @@ -64,6 +64,7 @@ ifrt_proxy_cc_test( "//xla/python/ifrt_proxy/common:grpc_ifrt_service_cc_grpc_proto", "//xla/python/ifrt_proxy/common:grpc_ifrt_service_proto_cc", "//xla/python/ifrt_proxy/common:ifrt_service_proto_cc", + "//xla/python/ifrt_proxy/common:test_utils", "@com_github_grpc_grpc//:gpr", "@com_github_grpc_grpc//:grpc++", "@com_google_absl//absl/base:core_headers", diff --git a/third_party/xla/xla/python/ifrt_proxy/client/grpc_client_session.cc b/third_party/xla/xla/python/ifrt_proxy/client/grpc_client_session.cc index d673a5f7561829..b555bc62d29398 100644 --- a/third_party/xla/xla/python/ifrt_proxy/client/grpc_client_session.cc +++ b/third_party/xla/xla/python/ifrt_proxy/client/grpc_client_session.cc @@ -40,7 +40,6 @@ #include "grpcpp/support/channel_arguments.h" #include "xla/pjrt/distributed/util.h" #include "xla/python/ifrt/future.h" -#include "xla/python/ifrt_proxy/client/client_session.h" #include "xla/python/ifrt_proxy/common/grpc_credentials.h" #include "xla/python/ifrt_proxy/common/grpc_ifrt_service.grpc.pb.h" #include "xla/python/ifrt_proxy/common/ifrt_service.pb.h" @@ -54,8 +53,6 @@ namespace xla { namespace ifrt { namespace proxy { -using OpId = int64_t; - // Logically equivalent to a map, but thread-safe and // with various convenience functions. class GrpcClientSession::ResponseCallbackTable { @@ -146,9 +143,9 @@ Future> GrpcClientSession::Enqueue( absl::Status GrpcClientSession::Enqueue(std::unique_ptr req, ResponseCallback callback) { - const OpId op_id = req->request_metadata().op_id(); - absl::MutexLock l(&writer_mu_); + const OpId op_id = writer_next_op_id_++; + if (writes_stopped_) { return absl::FailedPreconditionError( "GrpcClientSession: writes no longer allowed."); @@ -156,6 +153,9 @@ absl::Status GrpcClientSession::Enqueue(std::unique_ptr req, TF_RETURN_IF_ERROR(response_callbacks_->Add(op_id, std::move(callback))); + CHECK_EQ(req->mutable_request_metadata()->op_id(), 0); + req->mutable_request_metadata()->set_op_id(op_id); + if (!stream_->Write(*req)) { CHECK(response_callbacks_->Pop(op_id).has_value()); return absl::UnknownError("GrpcClientSession: writing to stream failed."); diff --git a/third_party/xla/xla/python/ifrt_proxy/client/grpc_client_session.h b/third_party/xla/xla/python/ifrt_proxy/client/grpc_client_session.h index 9e80b9ad850858..3187098bb6dd0a 100644 --- a/third_party/xla/xla/python/ifrt_proxy/client/grpc_client_session.h +++ b/third_party/xla/xla/python/ifrt_proxy/client/grpc_client_session.h @@ -17,6 +17,7 @@ #ifndef XLA_PYTHON_IFRT_PROXY_CLIENT_GRPC_CLIENT_SESSION_H_ #define XLA_PYTHON_IFRT_PROXY_CLIENT_GRPC_CLIENT_SESSION_H_ +#include #include #include @@ -116,6 +117,9 @@ class GrpcClientSession : public ClientSession { // only one thread is allowed to write to the gRPC stream at a time. absl::Mutex writer_mu_; + using OpId = uint64_t; + OpId writer_next_op_id_ ABSL_GUARDED_BY(writer_mu_) = 1; + // Ensures logic inside `Finish()` is internally called only once. absl::once_flag finish_once_; diff --git a/third_party/xla/xla/python/ifrt_proxy/client/grpc_client_session_test.cc b/third_party/xla/xla/python/ifrt_proxy/client/grpc_client_session_test.cc index 882f7b271841d0..61039d9e93b116 100644 --- a/third_party/xla/xla/python/ifrt_proxy/client/grpc_client_session_test.cc +++ b/third_party/xla/xla/python/ifrt_proxy/client/grpc_client_session_test.cc @@ -15,7 +15,6 @@ #include "xla/python/ifrt_proxy/client/grpc_client_session.h" #include -#include #include #include #include @@ -49,6 +48,7 @@ #include "xla/python/ifrt_proxy/common/grpc_ifrt_service.grpc.pb.h" #include "xla/python/ifrt_proxy/common/grpc_ifrt_service.pb.h" #include "xla/python/ifrt_proxy/common/ifrt_service.pb.h" +#include "xla/python/ifrt_proxy/common/test_utils.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" #include "tsl/platform/status_matchers.h" @@ -64,9 +64,6 @@ namespace { using ::testing::Not; using ::tsl::testing::IsOk; -constexpr int kOp1 = 1; -constexpr int kOp2 = 2; - // Sufficient time for all processing (that are not explicitly waiting for // further input) to have finished. constexpr absl::Duration kSufficientTime = absl::Seconds(5); @@ -79,49 +76,8 @@ GrpcIfrtSessionMetadata Metadata() { absl::Status TestError() { return absl::UnknownError("test error"); } -// A thread-safe queue of `absl::Status` values. -class Queue { - public: - void Push(absl::Status t) { - absl::MutexLock l(&mu_); - queue_.push_back(std::move(t)); - } - - std::optional PopOrTimeout( - absl::Duration timeout = kSufficientTime) { - absl::MutexLock l(&mu_); - auto cond = [this]() ABSL_SHARED_LOCKS_REQUIRED(mu_) -> bool { - return !queue_.empty(); - }; - mu_.AwaitWithTimeout(absl::Condition(&cond), timeout); - if (queue_.empty()) { - return std::nullopt; - } - absl::Status result = std::move(queue_.front()); - queue_.pop_front(); - return result; - } - - absl::Status Pop(absl::Duration timeout = kSufficientTime) { - auto result = PopOrTimeout(timeout); - CHECK(result.has_value()) << "Timeout!"; - return *result; - } - - void PopAllDuringDestruction() { - absl::MutexLock l(&mu_); - allow_non_empty_destruction_ = true; - } - - ~Queue() { - absl::MutexLock l(&mu_); - if (!allow_non_empty_destruction_) CHECK(queue_.empty()) << " " << this; - } - - private: - absl::Mutex mu_; - std::deque queue_ ABSL_GUARDED_BY(mu_); - bool allow_non_empty_destruction_ ABSL_GUARDED_BY(mu_) = false; +struct Queue : public TestQueue { + Queue() : TestQueue(kSufficientTime) {} }; // Checks that the input is a list of zero-or-more OK statuses followed by @@ -252,7 +208,7 @@ class ClientAndServer { client_finished_notification_.Notify(); }); - client_finished_q_.PopAllDuringDestruction(); + client_finished_q_.AllowNonEmptyDestruction(/*allow=*/true); } void StopServer() { @@ -273,12 +229,11 @@ class ClientAndServer { Queue* client_finished_q() { return &client_finished_q_; } - absl::StatusOr SendSimpleRequest(int op_id) { + absl::StatusOr SendSimpleRequest() { owned_queues_.push_back(std::make_unique()); Queue* q = owned_queues_.back().get(); auto req = std::make_unique(); - req->mutable_request_metadata()->set_op_id(op_id); TF_RETURN_IF_ERROR(client_session_->Enqueue( std::move(req), [q](absl::StatusOr resp) { q->Push(resp.status()); @@ -300,7 +255,7 @@ class ClientAndServer { TEST(GrpcClientSessionTest, HappyCaseOneRequestWithServerTermination) { ClientAndServer cs; - TF_ASSERT_OK_AND_ASSIGN(Queue * response_q, cs.SendSimpleRequest(kOp1)); + TF_ASSERT_OK_AND_ASSIGN(Queue * response_q, cs.SendSimpleRequest()); EXPECT_THAT(response_q->Pop(), IsOk()); @@ -313,8 +268,8 @@ TEST(GrpcClientSessionTest, HappyCaseOneRequestWithServerTermination) { TEST(GrpcClientSessionTest, HappyCaseTwoRequestsWithClientFinish) { ClientAndServer cs; - TF_ASSERT_OK_AND_ASSIGN(Queue * response_q_1, cs.SendSimpleRequest(kOp1)); - TF_ASSERT_OK_AND_ASSIGN(Queue * response_q_2, cs.SendSimpleRequest(kOp2)); + TF_ASSERT_OK_AND_ASSIGN(Queue * response_q_1, cs.SendSimpleRequest()); + TF_ASSERT_OK_AND_ASSIGN(Queue * response_q_2, cs.SendSimpleRequest()); EXPECT_THAT(response_q_1->Pop(), IsOk()); EXPECT_THAT(response_q_2->Pop(), IsOk()); @@ -329,10 +284,10 @@ TEST(GrpcClientSessionTest, ServerFinishesDuringFirstRead) { ClientAndServer cs( /*on_req_received=*/[](auto, auto) { return kStopSession; }); - TF_ASSERT_OK_AND_ASSIGN(Queue * response_q_1, cs.SendSimpleRequest(kOp1)); + TF_ASSERT_OK_AND_ASSIGN(Queue * response_q_1, cs.SendSimpleRequest()); EXPECT_THAT(response_q_1->Pop(), Not(IsOk())); - absl::StatusOr response_q_2 = cs.SendSimpleRequest(kOp2); + absl::StatusOr response_q_2 = cs.SendSimpleRequest(); EXPECT_THAT(response_q_2.status(), Not(IsOk())); EXPECT_THAT(cs.client_finished_q()->Pop(), Not(IsOk())); @@ -342,8 +297,8 @@ TEST(GrpcClientSessionTest, ServerFinishesDuringConstruction) { ClientAndServer cs(/*on_req_received=*/nullptr, /*on_session_start=*/[]() { return kStopSession; }); - absl::StatusOr response_q_1 = cs.SendSimpleRequest(kOp1); - absl::StatusOr response_q_2 = cs.SendSimpleRequest(kOp2); + absl::StatusOr response_q_1 = cs.SendSimpleRequest(); + absl::StatusOr response_q_2 = cs.SendSimpleRequest(); ExpectHeadAndTail({response_q_1, response_q_2}); if (response_q_1.ok()) EXPECT_THAT(response_q_1.value()->Pop(), Not(IsOk())); @@ -361,10 +316,10 @@ TEST(GrpcClientSessionTest, ClientFinishesAfterServerConsumesFirstRequest) { }); session_ptr.store(cs.client_session()); - TF_ASSERT_OK_AND_ASSIGN(Queue * response_q_1, cs.SendSimpleRequest(kOp1)); + TF_ASSERT_OK_AND_ASSIGN(Queue * response_q_1, cs.SendSimpleRequest()); EXPECT_THAT(response_q_1->Pop(), Not(IsOk())); - absl::StatusOr response_q_2 = cs.SendSimpleRequest(kOp2); + absl::StatusOr response_q_2 = cs.SendSimpleRequest(); EXPECT_THAT(response_q_2.status(), Not(IsOk())); EXPECT_THAT(cs.client_finished_q()->Pop(), Not(IsOk())); @@ -384,8 +339,8 @@ TEST(GrpcClientSessionTest, ClientFinishesAfterServerWritesFirstResponse) { }); session_ptr.store(cs.client_session()); - TF_ASSERT_OK_AND_ASSIGN(Queue * response_q_1, cs.SendSimpleRequest(kOp1)); - absl::StatusOr response_q_2 = cs.SendSimpleRequest(kOp2); + TF_ASSERT_OK_AND_ASSIGN(Queue * response_q_1, cs.SendSimpleRequest()); + absl::StatusOr response_q_2 = cs.SendSimpleRequest(); // The client may or may not terminate before the first response arrives. response_q_1->Pop().IgnoreError(); @@ -413,8 +368,8 @@ TEST(GrpcClientSessionTest, ClientFinishesDuringServerConstruction) { session_ptr.store(cs.client_session()); init_done.Notify(); - absl::StatusOr response_q_1 = cs.SendSimpleRequest(kOp1); - absl::StatusOr response_q_2 = cs.SendSimpleRequest(kOp2); + absl::StatusOr response_q_1 = cs.SendSimpleRequest(); + absl::StatusOr response_q_2 = cs.SendSimpleRequest(); if (response_q_1.ok()) { EXPECT_THAT(response_q_1.value()->Pop(), Not(IsOk())); @@ -431,19 +386,19 @@ TEST(GrpcClientSessionTest, ClientFinishesDuringServerConstruction) { TEST(GrpcClientSessionTest, MethodsAfterFinishReturnError) { ClientAndServer cs; - TF_ASSERT_OK_AND_ASSIGN(Queue * response_q_1, cs.SendSimpleRequest(kOp1)); + TF_ASSERT_OK_AND_ASSIGN(Queue * response_q_1, cs.SendSimpleRequest()); cs.client_session()->Finish(TestError()); - EXPECT_THAT(cs.SendSimpleRequest(kOp2), Not(IsOk())); + EXPECT_THAT(cs.SendSimpleRequest(), Not(IsOk())); - response_q_1->PopAllDuringDestruction(); + response_q_1->AllowNonEmptyDestruction(/*allow=*/true); } TEST(GrpcClientSessionTest, ReceivingBadIfrtResponseDoesNotCrash) { ClientAndServer cs( /*on_req_received=*/[](const IfrtRequest& r, ServerStream* s) mutable { IfrtResponse resp; - resp.mutable_response_metadata()->set_op_id(kOp2); + resp.mutable_response_metadata()->set_op_id(2000); s->Write(resp); resp.mutable_response_metadata()->set_op_id( r.request_metadata().op_id()); @@ -451,7 +406,7 @@ TEST(GrpcClientSessionTest, ReceivingBadIfrtResponseDoesNotCrash) { return kContinueSession; }); - TF_ASSERT_OK_AND_ASSIGN(Queue * response_q, cs.SendSimpleRequest(kOp1)); + TF_ASSERT_OK_AND_ASSIGN(Queue * response_q, cs.SendSimpleRequest()); EXPECT_THAT(response_q->Pop(), IsOk()); } diff --git a/third_party/xla/xla/python/ifrt_proxy/client/rpc_helper.cc b/third_party/xla/xla/python/ifrt_proxy/client/rpc_helper.cc index ff116a759e31a6..409978966a6f4e 100644 --- a/third_party/xla/xla/python/ifrt_proxy/client/rpc_helper.cc +++ b/third_party/xla/xla/python/ifrt_proxy/client/rpc_helper.cc @@ -16,6 +16,7 @@ #include #include +#include #include #include "absl/log/check.h" @@ -24,7 +25,6 @@ #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" -#include "absl/synchronization/mutex.h" #include "xla/python/ifrt/future.h" #include "xla/python/ifrt_proxy/client/client_session.h" #include "xla/python/ifrt_proxy/common/ifrt_service.pb.h" @@ -38,38 +38,88 @@ namespace xla { namespace ifrt { namespace proxy { +namespace { + using ::tsl::profiler::XFlow; +// XFlowHelper makes it easier to create trace spans with a flow between them. +// Typical usage: +// +// XFlowHelper flow("my_request"); +// ... +// +// auto response_handler = [flow](ResponseMsg msg) { +// flow.InstantActivity(); +// LOG(INFO) << "Received response: " << msg; +// } +// +// { +// auto request_span = flow.Span(); +// auto request_protobuf = CreateRequestProtobuf(); +// transport.Send(request_protobuf, response_handler); +// } +// +// +class XFlowHelper { + public: + explicit XFlowHelper(absl::string_view name) + : xflow_id_(tsl::random::New64() >> 8 /*XFlow IDs are 56 bits*/), + name_(name) {} + + typedef enum { kSend, kRecv, kRecvSend } Direction; + + template + tsl::profiler::TraceMe Span() const { + return tsl::profiler::TraceMe([xflow_id = xflow_id_, name = name_] { + return Encode(xflow_id, name); + }); + } + + template + void InstantActivity() const { + return tsl::profiler::TraceMe::InstantActivity( + [xflow_id = xflow_id_, name = name_] { + return Encode(xflow_id, name); + }); + } + + private: + template + static std::string Encode(uint64_t xflow_id, absl::string_view name) { + static constexpr absl::string_view flow_dir_str = + D == kSend ? "send" : (D == kRecv ? "recv" : "recv_send"); + const XFlow flow(xflow_id, D == kRecvSend ? XFlow::kFlowInOut + : (D == kRecv ? XFlow::kFlowIn + : XFlow::kFlowOut)); + return tsl::profiler::TraceMeEncode( + name, {{"dir", flow_dir_str}, {"flow", flow.ToStatValue()}}); + }; + + const uint64_t xflow_id_; + const absl::string_view name_; +}; + +} // namespace + // DoRpc is a templated function that implements the logic of all RPC-wrapping // functions of `RpcHelper`, such as `RpcHelper::MakeArrayFromHostBuffer()`. template Future> DoRpc(ClientSession* session, - RequestMetadata metadata, void (IfrtRequest::*set_req)(Req*), Resp* (IfrtResponse::*get_resp)(), bool (IfrtResponse::*has_resp)() const, std::unique_ptr req, - absl::string_view profiling_send_name, - absl::string_view profiling_recv_name) { + absl::string_view profiling_name) { auto ifrt_req = std::make_unique(); - *ifrt_req->mutable_request_metadata() = metadata; (ifrt_req.get()->*set_req)(req.release()); - const uint64_t xflow_id = tsl::random::New64() >> 8; // XFlow IDs are 56 bits - tsl::profiler::TraceMe traceme([xflow_id, profiling_send_name]() { - const XFlow flow(xflow_id, XFlow::FlowDirection::kFlowOut); - return tsl::profiler::TraceMeEncode(profiling_send_name, - {{"flow", flow.ToStatValue()}}); - }); + XFlowHelper x_flow_helper(profiling_name); + auto traceme = x_flow_helper.Span(); auto promise = Future>::CreatePromise(); - auto on_ready = [promise, has_resp, get_resp, xflow_id, profiling_recv_name]( + auto on_ready = [promise, has_resp, get_resp, x_flow_helper]( absl::StatusOr> r) mutable { - tsl::profiler::TraceMe traceme([xflow_id, profiling_recv_name]() { - const XFlow flow(xflow_id, XFlow::FlowDirection::kFlowIn); - return tsl::profiler::TraceMeEncode(profiling_recv_name, - {{"flow", flow.ToStatValue()}}); - }); + auto traceme = x_flow_helper.Span(); if (!r.ok()) { LOG_EVERY_N_SEC(ERROR, 10) << "Connection to IFRT proxy server was terminated: " << r.status(); @@ -123,36 +173,13 @@ Future> DoRpc(ClientSession* session, return Future>(promise); } -RequestMetadata RpcHelper::ManufactureRequestMetadata() { - RequestMetadata result; - { - absl::MutexLock l(&mu_); - result.set_op_id(next_op_id_++); - } - int prev_op_id = result.op_id() - 1; - if (prev_op_id != 0) { - // TODO(b/266635130): Depend only on necessary prior operations. - result.add_dependencies(prev_op_id); - } - // TODO(b/282757875): Add a ClearOps RPC for old dependencies. - return result; -} - -void RpcHelper::Disconnect() { - session_->Finish(absl::CancelledError("Disconnected by client")); -} - -// TODO(b/266635130): Remove this preprocessor macro. Preprocessor macros -// go against the style guide, but are convenient as we are introducing more -// RPCs and are making changes to the exact signature of the DoRpc function. -#define RPC(METHOD, PROPERTY) \ - RpcHelper::ResponseFuture RpcHelper::METHOD( \ - std::unique_ptr req) { \ - return DoRpc(session_.get(), ManufactureRequestMetadata(), \ - &IfrtRequest::set_allocated_##PROPERTY##_request, \ - &IfrtResponse::mutable_##PROPERTY##_response, \ - &IfrtResponse::has_##PROPERTY##_response, std::move(req), \ - "" #PROPERTY "_send", "" #PROPERTY "_recv"); \ +#define RPC(METHOD, PROPERTY) \ + RpcHelper::ResponseFuture RpcHelper::METHOD( \ + std::unique_ptr req) { \ + return DoRpc( \ + session_.get(), &IfrtRequest::set_allocated_##PROPERTY##_request, \ + &IfrtResponse::mutable_##PROPERTY##_response, \ + &IfrtResponse::has_##PROPERTY##_response, std::move(req), #PROPERTY); \ } RPC(Init, init); @@ -193,6 +220,10 @@ Future<> RpcHelper::CheckFuture(uint64_t handle) { return Future<>(std::move(promise)); } +void RpcHelper::Disconnect() { + session_->Finish(absl::CancelledError("Disconnected by client")); +} + } // namespace proxy } // namespace ifrt } // namespace xla diff --git a/third_party/xla/xla/python/ifrt_proxy/client/rpc_helper.h b/third_party/xla/xla/python/ifrt_proxy/client/rpc_helper.h index 3ed2a3eeb58d2b..d6eb5b1fcd2c58 100644 --- a/third_party/xla/xla/python/ifrt_proxy/client/rpc_helper.h +++ b/third_party/xla/xla/python/ifrt_proxy/client/rpc_helper.h @@ -23,8 +23,6 @@ #include "absl/base/thread_annotations.h" #include "absl/log/check.h" -#include "absl/status/status.h" -#include "absl/status/statusor.h" #include "absl/synchronization/mutex.h" #include "xla/python/ifrt/future.h" #include "xla/python/ifrt_proxy/client/client_session.h" @@ -137,8 +135,6 @@ class RpcHelper { Future<> CheckFuture(uint64_t handle); private: - RequestMetadata ManufactureRequestMetadata() ABSL_LOCKS_EXCLUDED(mu_); - const IfrtProxyVersion version_; const std::shared_ptr session_; std::shared_ptr host_buffer_store_; diff --git a/third_party/xla/xla/python/ifrt_proxy/common/BUILD b/third_party/xla/xla/python/ifrt_proxy/common/BUILD index 134a8505419f8d..7e30c171af6249 100644 --- a/third_party/xla/xla/python/ifrt_proxy/common/BUILD +++ b/third_party/xla/xla/python/ifrt_proxy/common/BUILD @@ -180,6 +180,17 @@ cc_library( ], ) +cc_library( + name = "test_utils", + hdrs = ["test_utils.h"], + deps = [ + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", + ], +) + # copybara:uncomment_begin # bzl_library( # name = "ifrt_proxy_bzl", diff --git a/third_party/xla/xla/python/ifrt_proxy/common/ifrt_service.proto b/third_party/xla/xla/python/ifrt_proxy/common/ifrt_service.proto index 278f9156f4eb89..0d73b688c0ee3d 100644 --- a/third_party/xla/xla/python/ifrt_proxy/common/ifrt_service.proto +++ b/third_party/xla/xla/python/ifrt_proxy/common/ifrt_service.proto @@ -141,27 +141,10 @@ message RequestMetadata { // resync after transient connectivity failures. fixed64 op_id = 1; - // List of one or more prior ops this current op is "dependent" - // upon. Currently this allows the client to define the order in which the - // server starts the execution of requests. Future versions may add other - // types of dependencies. For instance, a separate list of dependencies that - // must *complete* executing before the current one can start to execute. - // - // An op_id that has not yet been seen by the server is treated as an error - // that fails the op. - repeated fixed64 dependencies = 2; - - // UserContext is a basic provenance mechanism that allows the server-side - // actions and artifacts (say, allocating a buffer) to be associated with the - // corresponding client-side context that triggered those actions. - // - // The optional UserContextId is generated by the client and are used as an - // opaque label by the server and the run-time systems behind it. - // TODO(b/282757875): Add a pointer to Usercontext bugs/design doc. - fixed64 user_context_id = 3; - - // Additional implementation-specific payloads. + // Implementation-specific payloads. repeated google.protobuf.Any payloads = 4; + + reserved 2, 3; } // Metadata of an IFRT Response. diff --git a/third_party/xla/xla/python/ifrt_proxy/common/test_utils.h b/third_party/xla/xla/python/ifrt_proxy/common/test_utils.h new file mode 100644 index 00000000000000..8ecae77206529b --- /dev/null +++ b/third_party/xla/xla/python/ifrt_proxy/common/test_utils.h @@ -0,0 +1,96 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_PYTHON_IFRT_PROXY_COMMON_TEST_UTILS_H_ +#define XLA_PYTHON_IFRT_PROXY_COMMON_TEST_UTILS_H_ + +#include +#include +#include + +#include "absl/base/thread_annotations.h" +#include "absl/log/check.h" +#include "absl/synchronization/mutex.h" +#include "absl/time/time.h" + +namespace xla { +namespace ifrt { +namespace proxy { + +// TestQueue implements a thread-safe queue that manages values of type T. +template +class TestQueue { + public: + explicit TestQueue(absl::Duration pop_timeout) + : pop_timeout_(std::move(pop_timeout)) {} + + // Pushes `t` into the queue. + void Push(T t) { + absl::MutexLock l(&mu_); + queue_.push_back(std::move(t)); + } + + // Pops the first element in the queue if a element is already available or + // appears within `pop_timeout` (because `Push` is called). Otherwise returns + // std::nullopt. + std::optional PopOrTimeout() { + absl::MutexLock l(&mu_); + auto cond = [this]() ABSL_SHARED_LOCKS_REQUIRED(mu_) -> bool { + return !queue_.empty(); + }; + mu_.AwaitWithTimeout(absl::Condition(&cond), pop_timeout_); + if (queue_.empty()) { + return std::nullopt; + } + T result = std::move(queue_.front()); + queue_.pop_front(); + return result; + } + + // Pops the first element in the queue if a element is already available or + // appears within `pop_timeout`, and fails otherwise. + T Pop() { + std::optional result = PopOrTimeout(); + CHECK(result.has_value()) << "Timeout!"; + return std::move(*result); + } + + // Sets whether the queue is allowed to be destructed while it contains + // unpopped elements. + void AllowNonEmptyDestruction(bool allow) { + absl::MutexLock l(&mu_); + allow_non_empty_destruction_ = allow; + } + + // Checks that the queue is either empty, or `AllowNonEmptyDestruction(true)` + // has been called. + ~TestQueue() { + absl::MutexLock l(&mu_); + if (!allow_non_empty_destruction_) CHECK(queue_.empty()) << " " << this; + } + + private: + const absl::Duration pop_timeout_; + + absl::Mutex mu_; + std::deque queue_ ABSL_GUARDED_BY(mu_); + bool allow_non_empty_destruction_ ABSL_GUARDED_BY(mu_) = false; +}; + +} // namespace proxy +} // namespace ifrt +} // namespace xla + +#endif // XLA_PYTHON_IFRT_PROXY_COMMON_TEST_UTILS_H_ From 3acaa94fae6bd53b8376cc687674c1da708239fc Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 20 Sep 2024 23:34:12 -0700 Subject: [PATCH 095/483] Automated Code Change PiperOrigin-RevId: 677108593 --- tensorflow/compiler/tf2xla/rearrange_function_argument.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tensorflow/compiler/tf2xla/rearrange_function_argument.cc b/tensorflow/compiler/tf2xla/rearrange_function_argument.cc index a081fa18891ba2..479d8a230644b3 100644 --- a/tensorflow/compiler/tf2xla/rearrange_function_argument.cc +++ b/tensorflow/compiler/tf2xla/rearrange_function_argument.cc @@ -160,7 +160,7 @@ Status ReorderOutputEdges(Graph* g, Node* n, int input_count, // Given mapping between original input index and rearranged input index, change // "index" attribute for _Arg nodes. void RearrangeArgNodes( - const gtl::InlinedVector* arg_nodes, // non-absl ok + const absl::InlinedVector* arg_nodes, // non-absl ok const std::vector& index_mapping) { for (int i = 0; i < arg_nodes->size(); i++) { Node* n = (*arg_nodes)[i]; @@ -177,7 +177,7 @@ void RearrangeArgNodes( // hold mapping from DT_RESOURCE _Retval index to its input _Arg index. Here we // assume that all DT_RESOURCE _Retval nodes come from _Arg nodes directly. Status CalculateRetvalRearrange( - const gtl::InlinedVector& ret_nodes, // non-absl ok + const absl::InlinedVector& ret_nodes, // non-absl ok std::map* retval_index_mapping, std::map* resource_retval_to_arg) { for (int i = 0, end = ret_nodes.size(); i < end; i++) { @@ -259,7 +259,7 @@ Status RearrangeOutputEdges(Node* n, Graph* g, // change "index" attribute for _Retval nodes. Notice that DT_RESOURCE _Retval // nodes will be removed. void RearrangeRetvalNodes( - const gtl::InlinedVector& ret_nodes, // non-absl ok + const absl::InlinedVector& ret_nodes, // non-absl ok Graph* g, const std::map& retval_index_mapping) { for (int i = 0, end = ret_nodes.size(); i < end; i++) { Node* n = ret_nodes[i]; From 3c92a64b99ff9972dc341504faa747bbecc5e908 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 20 Sep 2024 23:52:17 -0700 Subject: [PATCH 096/483] Automated Code Change PiperOrigin-RevId: 677112101 --- third_party/xla/xla/stream_executor/host/host_platform.h | 1 + .../xla/xla/stream_executor/host/jit_host_kernel_function.cc | 1 + 2 files changed, 2 insertions(+) diff --git a/third_party/xla/xla/stream_executor/host/host_platform.h b/third_party/xla/xla/stream_executor/host/host_platform.h index b8ce8f4340d6c4..3d6f09d3fb49b5 100644 --- a/third_party/xla/xla/stream_executor/host/host_platform.h +++ b/third_party/xla/xla/stream_executor/host/host_platform.h @@ -23,6 +23,7 @@ limitations under the License. #include #include "absl/status/statusor.h" +#include "xla/stream_executor/device_description.h" #include "xla/stream_executor/executor_cache.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/stream_executor.h" diff --git a/third_party/xla/xla/stream_executor/host/jit_host_kernel_function.cc b/third_party/xla/xla/stream_executor/host/jit_host_kernel_function.cc index 76b9d1a42cb9da..1034abca258ee6 100644 --- a/third_party/xla/xla/stream_executor/host/jit_host_kernel_function.cc +++ b/third_party/xla/xla/stream_executor/host/jit_host_kernel_function.cc @@ -38,6 +38,7 @@ limitations under the License. #include "llvm/ExecutionEngine/ObjectCache.h" #include "llvm/ExecutionEngine/Orc/CompileUtils.h" #include "llvm/ExecutionEngine/Orc/Core.h" +#include "llvm/ExecutionEngine/Orc/ExecutorProcessControl.h" #include "llvm/ExecutionEngine/Orc/IRCompileLayer.h" #include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h" #include "llvm/ExecutionEngine/Orc/LLJIT.h" From 845e7e7985792f61c18d45ef22e8763c1967c6e9 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sat, 21 Sep 2024 02:02:18 -0700 Subject: [PATCH 097/483] compat: Update forward compatibility horizon to 2024-09-21 PiperOrigin-RevId: 677145124 --- tensorflow/python/compat/compat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index 9561bd3a5f1fde..905b9f61683c96 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -29,7 +29,7 @@ # This value changes every day with an automatic CL. It can be modified in code # via `forward_compatibility_horizon()` or with the environment variable # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date. -_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2024, 9, 20) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2024, 9, 21) _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS" _FORWARD_COMPATIBILITY_DATE_NUMBER = None From 92c4d48c2432ced97def2c186ef46f1a94bda57e Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sat, 21 Sep 2024 02:02:18 -0700 Subject: [PATCH 098/483] Update GraphDef version to 1992. PiperOrigin-RevId: 677145125 --- tensorflow/core/public/version.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index b427efa92f342f..8d0996c93db008 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -108,7 +108,7 @@ limitations under the License. #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0 -#define TF_GRAPH_DEF_VERSION 1991 // Updated: 2024/9/20 +#define TF_GRAPH_DEF_VERSION 1992 // Updated: 2024/9/21 // Checkpoint compatibility versions (the versions field in SavedSliceMeta). // From d0777fa9a04ae3a19119c06c25d154b3d3e6d213 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sat, 21 Sep 2024 09:01:46 -0700 Subject: [PATCH 099/483] Automated Code Change PiperOrigin-RevId: 677219977 --- .../core/common_runtime/pluggable_device/pluggable_device.h | 2 +- .../pluggable_device/pluggable_device_context.h | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tensorflow/core/common_runtime/pluggable_device/pluggable_device.h b/tensorflow/core/common_runtime/pluggable_device/pluggable_device.h index 74ad5893921ecc..67aa658a3fd9d8 100644 --- a/tensorflow/core/common_runtime/pluggable_device/pluggable_device.h +++ b/tensorflow/core/common_runtime/pluggable_device/pluggable_device.h @@ -84,7 +84,7 @@ class PluggableDevice : public LocalDevice { se::Stream* compute = nullptr; se::Stream* host_to_device = nullptr; se::Stream* device_to_host = nullptr; - gtl::InlinedVector device_to_device; + absl::InlinedVector device_to_device; }; class StreamGroupFactory; diff --git a/tensorflow/core/common_runtime/pluggable_device/pluggable_device_context.h b/tensorflow/core/common_runtime/pluggable_device/pluggable_device_context.h index 5798be1c13bc74..3baef39192487d 100644 --- a/tensorflow/core/common_runtime/pluggable_device/pluggable_device_context.h +++ b/tensorflow/core/common_runtime/pluggable_device/pluggable_device_context.h @@ -34,7 +34,7 @@ class PluggableDeviceContext : public DeviceContext { PluggableDeviceContext( int stream_id, se::Stream* stream, se::Stream* host_to_device_stream, se::Stream* device_to_host_stream, - gtl::InlinedVector device_to_device_stream) + absl::InlinedVector device_to_device_stream) : stream_id_(stream_id), stream_(stream), host_to_device_stream_(host_to_device_stream), @@ -81,7 +81,7 @@ class PluggableDeviceContext : public DeviceContext { // The stream to use for copying data from PluggableDevice to host. se::Stream* device_to_host_stream_; // Streams to use for copying data between PluggableDevices. - gtl::InlinedVector device_to_device_stream_; + absl::InlinedVector device_to_device_stream_; }; } // namespace tensorflow From 1fcc51367416e7dabe3fa5ed6df1756e594b8902 Mon Sep 17 00:00:00 2001 From: Qingtong Guo Date: Sat, 21 Sep 2024 13:12:31 -0700 Subject: [PATCH 100/483] Reverts 784164302f709dc2369aa3159d4c8a8666094a6a PiperOrigin-RevId: 677262779 --- third_party/xla/xla/autotuning.proto | 7 +- .../xla/xla/service/gpu/autotuning/BUILD | 11 +- .../gpu/autotuning/gemm_fusion_autotuner.cc | 293 +++++------------- .../gpu/autotuning/gemm_fusion_autotuner.h | 21 +- .../autotuning/gemm_fusion_autotuner_test.cc | 181 +---------- 5 files changed, 90 insertions(+), 423 deletions(-) diff --git a/third_party/xla/xla/autotuning.proto b/third_party/xla/xla/autotuning.proto index 4cadf6dbb250eb..a7ffcbb57ae6ef 100644 --- a/third_party/xla/xla/autotuning.proto +++ b/third_party/xla/xla/autotuning.proto @@ -83,10 +83,6 @@ message AutotuneResult { int64 num_ctas = 7; } - message CustomKernelFusionKey { - int64 kernel_index = 1; - } - int64 scratch_bytes = 8; google.protobuf.Duration run_time = 9; @@ -97,11 +93,10 @@ message AutotuneResult { GemmKey gemm = 6; TritonGemmKey triton = 17; CudaConvPlanKey cuda_conv_plan = 15; - CustomKernelFusionKey custom_kernel_fusion = 18; stream_executor.dnn.AlgorithmProto algorithm = 16; } - // Next ID: 19 + // Next ID: 17 } message AutotuningLog { diff --git a/third_party/xla/xla/service/gpu/autotuning/BUILD b/third_party/xla/xla/service/gpu/autotuning/BUILD index d162b1e8f0aded..be63f3888442af 100644 --- a/third_party/xla/xla/service/gpu/autotuning/BUILD +++ b/third_party/xla/xla/service/gpu/autotuning/BUILD @@ -45,11 +45,9 @@ cc_library( "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", - "//xla/hlo/pass:hlo_pass_pipeline", "//xla/hlo/utils:hlo_query", "//xla/pjrt/distributed:key_value_store_interface", "//xla/service:algorithm_util", - "//xla/service:call_inliner", "//xla/service:dump", "//xla/service:executable", "//xla/service:float_normalization", @@ -60,15 +58,12 @@ cc_library( "//xla/service/gpu:backend_configs_cc", "//xla/service/gpu:buffer_comparator", "//xla/service/gpu:gpu_float_support", + "//xla/service/gpu:hlo_traversal", "//xla/service/gpu:ir_emission_utils", "//xla/service/gpu:matmul_utils", "//xla/service/gpu:split_k_gemm_rewriter", "//xla/service/gpu:stream_executor_util", - "//xla/service/gpu/kernels:custom_kernel", - "//xla/service/gpu/kernels:custom_kernel_fusion", - "//xla/service/gpu/kernels:custom_kernel_fusion_pattern", "//xla/service/gpu/transforms:cudnn_fusion_compiler", - "//xla/service/gpu/transforms:custom_kernel_fusion_rewriter", "//xla/service/gpu/transforms:fusion_wrapper", "//xla/service/gpu/transforms:gemm_rewriter", "//xla/service/gpu/transforms:priority_fusion", @@ -77,9 +72,11 @@ cc_library( "//xla/stream_executor:device_memory", "//xla/stream_executor:semantic_version", "//xla/stream_executor:stream_executor_memory_allocator", + "//xla/stream_executor/gpu:redzone_allocator", "//xla/tools:hlo_decomposer_lib", "//xla/tsl/lib/core:bits", "//xla/tsl/util/proto:proto_utils", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", @@ -140,8 +137,6 @@ xla_test( "//xla/stream_executor:device_description", "//xla/stream_executor:device_description_proto_cc", "//xla/stream_executor:semantic_version", - "//xla/stream_executor:stream_executor_h", - "//xla/stream_executor/gpu:gpu_executor_header", "//xla/tests:filecheck", "//xla/tests:hlo_test_base", "//xla/tests:test_utils", diff --git a/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc index 8f041d6e8d27f2..79524924584c97 100644 --- a/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc +++ b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -26,6 +27,7 @@ limitations under the License. #include #include +#include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/log/check.h" @@ -49,28 +51,24 @@ limitations under the License. #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" -#include "xla/hlo/pass/hlo_pass_pipeline.h" #include "xla/hlo/utils/hlo_query.h" #include "xla/pjrt/distributed/key_value_store_interface.h" #include "xla/primitive_util.h" #include "xla/service/algorithm_util.h" -#include "xla/service/call_inliner.h" #include "xla/service/dump.h" +#include "xla/service/executable.h" #include "xla/service/float_normalization.h" #include "xla/service/gpu/autotuning/autotuner_compile_util.h" #include "xla/service/gpu/autotuning/autotuner_util.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/buffer_comparator.h" #include "xla/service/gpu/gpu_float_support.h" +#include "xla/service/gpu/hlo_traversal.h" #include "xla/service/gpu/ir_emission_utils.h" -#include "xla/service/gpu/kernels/custom_kernel.h" -#include "xla/service/gpu/kernels/custom_kernel_fusion.h" -#include "xla/service/gpu/kernels/custom_kernel_fusion_pattern.h" #include "xla/service/gpu/matmul_utils.h" #include "xla/service/gpu/split_k_gemm_rewriter.h" #include "xla/service/gpu/stream_executor_util.h" #include "xla/service/gpu/transforms/cudnn_fusion_compiler.h" -#include "xla/service/gpu/transforms/custom_kernel_fusion_rewriter.h" #include "xla/service/gpu/transforms/fusion_wrapper.h" #include "xla/service/gpu/transforms/gemm_rewriter.h" #include "xla/service/gpu/transforms/priority_fusion.h" @@ -84,6 +82,7 @@ limitations under the License. #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/device_memory_allocator.h" +#include "xla/stream_executor/gpu/redzone_allocator.h" #include "xla/stream_executor/semantic_version.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor_memory_allocator.h" @@ -141,6 +140,76 @@ constexpr std::array kNumCtas = {1, 2, 4, 8, 16}; using AutoTuneCacheKeyCount = absl::flat_hash_map; +class GemmFusionAutotunerVisitor : public DfsHloRewriteVisitor { + public: + explicit GemmFusionAutotunerVisitor(const AutotuneConfig& config) + : config_(config) {} + + absl::Status HandleFusion(HloInstruction* hlo) override { + TF_ASSIGN_OR_RETURN(auto gpu_config, + hlo->backend_config()); + FusionBackendConfig& backend_config = + *gpu_config.mutable_fusion_backend_config(); + if (backend_config.kind() != kTritonGemmFusionKind && + backend_config.kind() != kCuDnnFusionKind) { + return absl::OkStatus(); + } + + VLOG(4) << "Processing " << hlo->ToString(); + if (!backend_config.has_triton_gemm_config() && + !backend_config.has_cudnn_fusion_config()) { + TF_ASSIGN_OR_RETURN( + AutotuneResult autotune_result, + AutotunerUtil::Autotune( + hlo, config_, [&]() -> absl::StatusOr { + if (config_.IsDeviceless()) { + return absl::InternalError(absl::StrCat( + "Expect autotune result cache hit for deviceless " + "compilation (HLO: ", + hlo->ToString(), ")")); + } + return absl::InternalError("Expect autotune result cache hit."); + })); + VLOG(4) << "Result: " << autotune_result.ShortDebugString(); + + if (autotune_result.has_triton()) { + *backend_config.mutable_triton_gemm_config() = autotune_result.triton(); + TF_RETURN_IF_ERROR(hlo->set_backend_config(gpu_config)); + } else if (autotune_result.has_gemm()) { + // Falling back to cuBLAS: Converting the fusion to a Call, so that it + // can be inlined back again. + HloComputation* const computation = hlo->parent(); + HloInstruction* const call = computation->AddInstruction( + HloInstruction::CreateCall(hlo->shape(), hlo->operands(), + hlo->fused_instructions_computation())); + TF_RETURN_IF_ERROR(computation->ReplaceInstruction(hlo, call)); + hlo = call; + } else { + CHECK(autotune_result.has_algorithm()); + backend_config.set_kind(std::string(kCuDnnFusionKind)); + backend_config.mutable_cudnn_fusion_config()->set_plan_id( + autotune_result.algorithm().algo_id()); + TF_RETURN_IF_ERROR(hlo->set_backend_config(gpu_config)); + } + } + + if (backend_config.has_triton_gemm_config()) { + TF_ASSIGN_OR_RETURN( + const TritonGemmConfig config, + TritonGemmConfig::FromProto(backend_config.triton_gemm_config())); + if (config.split_k > 1) { + TF_RETURN_IF_ERROR(MakeDotSplitKBatch(hlo, config)); + } + } + + MarkAsChanged(); + return absl::OkStatus(); + } + + private: + AutotuneConfig config_; +}; + class GemmConfigSetCollector : public ConstDfsHloVisitorWithDefault { public: explicit GemmConfigSetCollector(GemmFusionAutotunerImpl* impl) @@ -190,9 +259,7 @@ class GemmConfigSetCollector : public ConstDfsHloVisitorWithDefault { bool missing_config = (backend_config.kind() == kTritonGemmFusionKind && !backend_config.has_triton_gemm_config()) || (backend_config.kind() == kCuDnnFusionKind && - !backend_config.has_cudnn_fusion_config()) || - (backend_config.kind() == kCustomFusionKind && - !backend_config.has_custom_fusion_config()); + !backend_config.has_cudnn_fusion_config()); if (missing_config) { if (error_out_on_cache_miss_) { return absl::NotFoundError(absl::StrCat( @@ -360,46 +427,6 @@ absl::StatusOr> CublasGemmAutotuneExtractor( return new_module; } -absl::Status UpdateFusionInstructionKernelIndex( - HloInstruction* fusion_instruction, int kernel_index) { - GpuBackendConfig gpu_config = - fusion_instruction->backend_config().value(); - gpu_config.mutable_fusion_backend_config() - ->mutable_custom_fusion_config() - ->set_kernel_index(kernel_index); - TF_RETURN_IF_ERROR(fusion_instruction->set_backend_config(gpu_config)); - - return absl::OkStatus(); -} - -absl::StatusOr> CustomFusionKernelAutotuneExtractor( - const GemmFusionAutotunerImpl::CustomKernelFusionConfig& cutlass_config, - const AutotuneConfig& config, const se::SemanticVersion& toolkit_version, - const HloFusionInstruction* fusion, const DebugOptions& debug_opts) { - const HloComputation* fusion_computation = fusion->called_computation(); - std::unique_ptr new_module = - ExtractComputationIntoNewModule(*fusion_computation); - new_module->mutable_config().set_debug_options(debug_opts); - - CustomKernelFusionRewriter rewriter( - &config.GetExecutor()->GetDeviceDescription()); - PriorityFusion fusion_pass( - /*thread_pool=*/nullptr, config.GetExecutor()->GetDeviceDescription(), - PriorityFusionOptions()); - TF_RETURN_IF_ERROR(rewriter.Run(new_module.get()).status()); - TF_RETURN_IF_ERROR(fusion_pass.Run(new_module.get()).status()); - - // Select custom kernel fusion kernel. - HloInstruction* custom_kernel_fusion = - hlo_query::GetFirstInstructionWithOpcode(*new_module->entry_computation(), - HloOpcode::kFusion); - int64_t kernel_index = cutlass_config.kernel_index; - TF_RETURN_IF_ERROR( - UpdateFusionInstructionKernelIndex(custom_kernel_fusion, kernel_index)); - - return new_module; -} - absl::StatusOr> FusionExtractor( const HloFusionInstruction& fusion, const DebugOptions& debug_opts) { std::unique_ptr module = ExtractInstructionIntoNewModule(fusion); @@ -448,11 +475,6 @@ AutotuneResult FromConfig(const BackendConfig& config) { AutotuneResult res; if (std::holds_alternative(config)) { res.mutable_gemm()->set_algorithm(CUBLAS_GEMM_DEFAULT); - } else if (std::holds_alternative< - GemmFusionAutotunerImpl::CustomKernelFusionConfig>(config)) { - res.mutable_custom_kernel_fusion()->set_kernel_index( - std::get(config) - .kernel_index); } else if (std::holds_alternative( config)) { res.mutable_algorithm()->set_algo_id( @@ -552,98 +574,6 @@ std::string Serialize(const BackendConfig& config) { } // anonymous namespace -absl::Status RewriteGemmFusionToCall(HloInstruction* fusion_instr) { - // Falling back to cuBLAS: Converting the fusion to a Call, so that it - // can be inlined back again. - HloComputation* const computation = fusion_instr->parent(); - HloInstruction* const call = - computation->AddInstruction(HloInstruction::CreateCall( - fusion_instr->shape(), fusion_instr->operands(), - fusion_instr->fused_instructions_computation())); - return computation->ReplaceInstruction(fusion_instr, call); -} - -absl::Status RewriteGemmFusionToCustomKernelFusion( - HloInstruction* fusion_instr, se::DeviceDescription device_description, - int64_t kernel_index) { - // Rewrites gemm fusion to custom kernel fusion. - // First convert the fusion to a call. Then inlines the call. Then - // rewrites to custom kernel fusion. - HloComputation* const computation = fusion_instr->parent(); - HloInstruction* const call = - computation->AddInstruction(HloInstruction::CreateCall( - fusion_instr->shape(), fusion_instr->operands(), - fusion_instr->fused_instructions_computation())); - TF_RETURN_IF_ERROR(computation->ReplaceInstruction(fusion_instr, call)); - HloPassPipeline pipeline("autotuner_custom_kernel_fusion_rewriter"); - pipeline.AddPass(); - pipeline.AddPass(&device_description, - kernel_index); - HloModule* hlo_module = call->GetModule(); - return pipeline.Run(hlo_module).status(); -} - -absl::Status GemmFusionAutotunerRewriterVisitor::HandleFusion( - HloInstruction* fusion_instr) { - TF_ASSIGN_OR_RETURN(auto gpu_config, - fusion_instr->backend_config()); - FusionBackendConfig& backend_config = - *gpu_config.mutable_fusion_backend_config(); - if (backend_config.kind() != kTritonGemmFusionKind && - backend_config.kind() != kCuDnnFusionKind && - backend_config.kind() != kCustomFusionKind) { - return absl::OkStatus(); - } - - VLOG(4) << "Processing " << fusion_instr->ToString(); - if (!backend_config.has_triton_gemm_config() && - !backend_config.has_cudnn_fusion_config() && - !backend_config.has_custom_fusion_config()) { - TF_ASSIGN_OR_RETURN( - AutotuneResult autotune_result, - AutotunerUtil::Autotune( - fusion_instr, config_, [&]() -> absl::StatusOr { - if (config_.IsDeviceless()) { - return absl::InternalError(absl::StrCat( - "Expect autotune result cache hit for deviceless " - "compilation (HLO: ", - fusion_instr->ToString(), ")")); - } - return absl::InternalError("Expect autotune result cache hit."); - })); - VLOG(4) << "Result: " << autotune_result.ShortDebugString(); - - if (autotune_result.has_triton()) { - *backend_config.mutable_triton_gemm_config() = autotune_result.triton(); - TF_RETURN_IF_ERROR(fusion_instr->set_backend_config(gpu_config)); - } else if (autotune_result.has_gemm()) { - TF_RETURN_IF_ERROR(RewriteGemmFusionToCall(fusion_instr)); - } else if (autotune_result.has_custom_kernel_fusion()) { - TF_RETURN_IF_ERROR(RewriteGemmFusionToCustomKernelFusion( - fusion_instr, config_.GetExecutor()->GetDeviceDescription(), - autotune_result.custom_kernel_fusion().kernel_index())); - } else { - CHECK(autotune_result.has_algorithm()); - backend_config.set_kind(std::string(kCuDnnFusionKind)); - backend_config.mutable_cudnn_fusion_config()->set_plan_id( - autotune_result.algorithm().algo_id()); - TF_RETURN_IF_ERROR(fusion_instr->set_backend_config(gpu_config)); - } - } - - if (backend_config.has_triton_gemm_config()) { - TF_ASSIGN_OR_RETURN( - const TritonGemmConfig config, - TritonGemmConfig::FromProto(backend_config.triton_gemm_config())); - if (config.split_k > 1) { - TF_RETURN_IF_ERROR(MakeDotSplitKBatch(fusion_instr, config)); - } - } - - MarkAsChanged(); - return absl::OkStatus(); -} - // Methods required for sorting the configs. bool GemmFusionAutotunerImpl::CuBlasConfig::operator<( const CuBlasConfig& other) const { @@ -653,10 +583,6 @@ bool GemmFusionAutotunerImpl::CuDnnConfig::operator<( const CuDnnConfig& other) const { return plan_id < other.plan_id; } -bool GemmFusionAutotunerImpl::CustomKernelFusionConfig::operator<( - const CustomKernelFusionConfig& other) const { - return false; -} bool GemmFusionAutotunerImpl::IsAutotuningEnabled() const { return debug_options_.xla_gpu_autotune_level() > 0 && @@ -677,48 +603,6 @@ bool GemmFusionAutotunerImpl::IsAutotuningEnabled() const { } } -std::vector GenerateCustomKernelFusionConfigs( - const HloFusionInstruction& fusion, - se::DeviceDescription device_description) { - std::vector configs; - const CustomKernelFusionPatternRegistry* patterns = - CustomKernelFusionPatternRegistry::Default(); - HloComputation* computation = fusion.called_computation(); - // Get the first dot instruction in the fusion body. - HloInstruction* dot_instruction = - hlo_query::GetFirstInstructionWithOpcode(*computation, HloOpcode::kDot); - std::vector match = - patterns->Match(device_description, dot_instruction); - - // For Cutlass we expect only one match for a gemm fusion. - if (match.size() == 1) { - CustomKernelFusionRegistry* registry = - CustomKernelFusionRegistry::Default(); - auto* custom_kernel_fusion = registry->Lookup(match[0].config().name()); - - // If custom fusion is not found it means that some of the build targets - // might not be statically linked into the binary. - if (custom_kernel_fusion != nullptr) { - // Load custom kernels that can implement a fusion computation. - absl::StatusOr> kernels = - custom_kernel_fusion->LoadKernels( - device_description, fusion.fused_instructions_computation()); - if (!kernels.ok()) { - VLOG(2) << "Skip custom kernel config. Failed to load custom kernels: " - << kernels.status(); - } else { - for (int i = 0; i < kernels.value().size(); ++i) { - GemmFusionAutotunerImpl::CustomKernelFusionConfig config{ - /*kernel_index=*/i}; - configs.push_back(config); - } - } - } - } - - return configs; -} - absl::StatusOr> GemmFusionAutotunerImpl::GenerateConfigs(const HloFusionInstruction& fusion) { const HloDotInstruction* dot = @@ -758,19 +642,6 @@ GemmFusionAutotunerImpl::GenerateConfigs(const HloFusionInstruction& fusion) { } } - // Add CustomKernelFusion (Cutlass) configs, if available. - // Go through all the instructions in the fusion body try to match them to - // a custom kernel fusion pattern. - if ((IsFusionKind(fusion, kCustomFusionKind) || - IsFusionKind(fusion, kTritonGemmFusionKind)) && - IsAutotuningEnabled() && !config_.IsDeviceless()) { - std::vector custom_kernel_fusion_configs = - GenerateCustomKernelFusionConfigs( - fusion, config_.GetExecutor()->GetDeviceDescription()); - configs.insert(configs.end(), custom_kernel_fusion_configs.begin(), - custom_kernel_fusion_configs.end()); - } - // Add triton configs. TF_ASSIGN_OR_RETURN(std::vector triton_configs, GenerateTritonConfigs(*dot)); @@ -934,14 +805,6 @@ GemmFusionAutotunerImpl::CompileAll(AutotunerCompileUtil& compile_util, config_, config_.GetExecutor()->GetDeviceDescription(), toolkit_version_, fusion, opts); })); - } else if (std::holds_alternative(config)) { - TF_ASSIGN_OR_RETURN(executable, - compile_util.Compile([&](const DebugOptions& opts) { - return CustomFusionKernelAutotuneExtractor( - std::get(config), - config_, toolkit_version_, fusion, opts); - })); - } else { LOG(FATAL) << "Unsupported config type: " << config.index(); } @@ -1442,8 +1305,8 @@ absl::StatusOr GemmFusionAutotuner::Run( } } - return GemmFusionAutotunerRewriterVisitor(config_).RunOnModule( - module, execution_threads); + return GemmFusionAutotunerVisitor(config_).RunOnModule(module, + execution_threads); } } // namespace gpu diff --git a/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.h b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.h index 17272607532c20..7c262ffc8c613b 100644 --- a/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.h +++ b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.h @@ -29,9 +29,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/autotuning.pb.h" -#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" #include "xla/hlo/ir/hlo_computation.h" -#include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/pass/hlo_pass_interface.h" @@ -48,18 +46,6 @@ limitations under the License. namespace xla { namespace gpu { -// Uses profile results to rewrite a gemm fusion to use the best backend. -class GemmFusionAutotunerRewriterVisitor : public DfsHloRewriteVisitor { - public: - explicit GemmFusionAutotunerRewriterVisitor(const AutotuneConfig& config) - : config_(config) {} - - absl::Status HandleFusion(HloInstruction* fusion_instr) override; - - private: - AutotuneConfig config_; -}; - // Takes a gemm fusion and chooses between cuBLAS, cuDNN, and Triton backends. // In the case of Triton, it also chooses the best tiling configuration. // @@ -113,13 +99,8 @@ class GemmFusionAutotunerImpl { int64_t plan_id; bool operator<(const CuDnnConfig& other) const; }; - struct CustomKernelFusionConfig { - int64_t kernel_index; - bool operator<(const CustomKernelFusionConfig& other) const; - }; using BackendConfig = - std::variant; + std::variant; using BackendConfigs = std::vector< std::pair>>; diff --git a/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner_test.cc b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner_test.cc index cb6309ffe9b8ea..f47003ecea4256 100644 --- a/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner_test.cc +++ b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner_test.cc @@ -50,9 +50,7 @@ limitations under the License. #include "xla/service/pattern_matcher_gmock.h" #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/device_description.pb.h" -#include "xla/stream_executor/gpu/gpu_executor.h" #include "xla/stream_executor/semantic_version.h" -#include "xla/stream_executor/stream_executor.h" #include "xla/tests/filecheck.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/test_utils.h" @@ -197,25 +195,6 @@ class GemmFusionAutotunerTest : public StatelessAutotunerTest { .cuda_compute_capability(); } - absl::StatusOr> - GetPossibleMatmulAutotuneConfigs( - const HloFusionInstruction& fusion, - const se::CudaComputeCapability& compute_capability, - const se::SemanticVersion& toolkit_version, - const DebugOptions& debug_options) { - se::GpuDeviceInfoProto deviceless_proto; - auto ccc = deviceless_proto.mutable_cuda_compute_capability(); - ccc->set_major(compute_capability.major); - ccc->set_minor(compute_capability.minor); - - DeviceConfig test_config{backend().default_stream_executor(), - backend().memory_allocator()}; - AutotuneConfig autotune_config{test_config, debug_options}; - GemmFusionAutotunerImpl autotuner(autotune_config, toolkit_version, - debug_options, nullptr); - return autotuner.GenerateConfigs(fusion); - } - void CheckTritonAutotuning(absl::string_view hlo, absl::string_view expected) { HloPassPipeline pipeline("gemm_rewrite"); @@ -268,8 +247,7 @@ class GemmFusionAutotunerTestWithMorePreciseReduction } }; -absl::StatusOr> -GetPossibleMatmulAutotuneTritonConfigs( +absl::StatusOr> GetPossibleMatmulAutotuneConfigs( const HloDotInstruction& dot, const se::CudaComputeCapability& compute_capability, const se::SemanticVersion& toolkit_version, @@ -298,7 +276,7 @@ ENTRY e { se::CudaComputeCapability::AMPERE, /*minor=*/0}; TF_ASSERT_OK_AND_ASSIGN( const std::vector configs, - GetPossibleMatmulAutotuneTritonConfigs( + GetPossibleMatmulAutotuneConfigs( *Cast( module->entry_computation()->root_instruction()), compute_capability, GetToolkitVersion(), GetDebugOptionsForTest())); @@ -320,7 +298,7 @@ ENTRY e { se::CudaComputeCapability::AMPERE, /*minor=*/0}; TF_ASSERT_OK_AND_ASSIGN( const std::vector configs, - GetPossibleMatmulAutotuneTritonConfigs( + GetPossibleMatmulAutotuneConfigs( *Cast( module->entry_computation()->root_instruction()), compute_capability, GetToolkitVersion(), GetDebugOptionsForTest())); @@ -342,7 +320,7 @@ ENTRY e { se::CudaComputeCapability::AMPERE, /*minor=*/0}; TF_ASSERT_OK_AND_ASSIGN( const std::vector configs, - GetPossibleMatmulAutotuneTritonConfigs( + GetPossibleMatmulAutotuneConfigs( *Cast( module->entry_computation()->root_instruction()), compute_capability, GetToolkitVersion(), GetDebugOptionsForTest())); @@ -897,7 +875,7 @@ ENTRY e { se::CudaComputeCapability::AMPERE, /*minor=*/0}; TF_ASSERT_OK_AND_ASSIGN( const std::vector configs, - GetPossibleMatmulAutotuneTritonConfigs( + GetPossibleMatmulAutotuneConfigs( *Cast( module->entry_computation()->root_instruction()), compute_capability, GetToolkitVersion(), GetDebugOptionsForTest())); @@ -929,7 +907,7 @@ ENTRY e { se::CudaComputeCapability::AMPERE, /*minor=*/0}; TF_ASSERT_OK_AND_ASSIGN( const std::vector configs, - GetPossibleMatmulAutotuneTritonConfigs( + GetPossibleMatmulAutotuneConfigs( *Cast( module->entry_computation()->root_instruction()), compute_capability, GetToolkitVersion(), GetDebugOptionsForTest())); @@ -960,7 +938,7 @@ ENTRY wais { TF_ASSERT_OK_AND_ASSIGN( const std::vector configs, - GetPossibleMatmulAutotuneTritonConfigs( + GetPossibleMatmulAutotuneConfigs( *Cast( module->entry_computation()->root_instruction()), compute_capability, GetToolkitVersion(), debug_options)); @@ -1024,151 +1002,6 @@ ENTRY entry { CHECK_OK(autotuner.CompileAll(*compile_util, configs)); } -TEST_F(GemmFusionAutotunerTest, CreatesCustomKernelFusionConfigs) { - const std::string kHlo = R"( - HloModule module, entry_computation_layout={(bf16[1024,1024]{1,0}, bf16[1024,1024]{1,0})->f32[1024,1024]{1,0}} - - %gemm_fusion_r_computation { - %parameter_0 = bf16[1024,1024]{1,0} parameter(0) - %convert.2 = f32[1024,1024]{1,0} convert(%parameter_0) - %parameter_1 = bf16[1024,1024]{1,0} parameter(1) - %convert.3 = f32[1024,1024]{1,0} convert(%parameter_1) - ROOT %r.1 = f32[1024,1024]{1,0} dot(%convert.2, %convert.3), lhs_contracting_dims={1}, rhs_contracting_dims={0} - } - - ENTRY main { - %p0 = bf16[1024,1024]{1,0} parameter(0) - %p1 = bf16[1024,1024]{1,0} parameter(1) - ROOT %gemm_fusion_r = f32[1024,1024]{1,0} fusion(%p0, %p1), kind=kCustom, calls=gemm_fusion_r_computation, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"fusion_backend_config":{"kind":"__triton_gemm"},"force_earliest_schedule":false} - })"; - - std::unique_ptr module = - ParseAndReturnVerifiedModule(kHlo).value(); - const se::CudaComputeCapability compute_capability{ - se::CudaComputeCapability::AMPERE, /*minor=*/0}; - - TF_ASSERT_OK_AND_ASSIGN( - const std::vector configs, - GetPossibleMatmulAutotuneConfigs( - *Cast( - module->entry_computation()->root_instruction()), - compute_capability, GetToolkitVersion(), GetDebugOptionsForTest())); - EXPECT_TRUE(std::any_of( - configs.begin(), configs.end(), - [](const GemmFusionAutotunerImpl::BackendConfig& config) { - return std::holds_alternative< - GemmFusionAutotunerImpl::CustomKernelFusionConfig>(config); - })); -} - -TEST_F(GemmFusionAutotunerTest, - IgnoreCustomKernelFusionConfigIfKernelNotFound) { - // There are cases where the custom kernel fusion pattern is matched, but - // the kernel is not found. Make sure that the autotuner ignores this case. - const std::string kHlo = R"( - HloModule module - - %gemm_fusion_r_computation (parameter_0.1: f32[1,256,4,4096], parameter_1.1: bf16[1,4,4096,4096]) -> bf16[1048576] { - %parameter_0.1 = f32[1,256,4,4096]{3,2,1,0} parameter(0) - %bitcast.60 = f32[256,16384]{1,0} bitcast(f32[1,256,4,4096]{3,2,1,0} %parameter_0.1) - %parameter_1.1 = bf16[1,4,4096,4096]{3,2,1,0} parameter(1) - %bitcast.61 = bf16[16384,4096]{1,0} bitcast(bf16[1,4,4096,4096]{3,2,1,0} %parameter_1.1) - %convert.22 = f32[16384,4096]{1,0} convert(bf16[16384,4096]{1,0} %bitcast.61) - %dot.5 = f32[256,4096]{1,0} dot(f32[256,16384]{1,0} %bitcast.60, f32[16384,4096]{1,0} %convert.22), lhs_contracting_dims={1}, rhs_contracting_dims={0}, metadata={op_name="jit(_einsum)/jit(main)/dot_general" source_file="gdm/jax/tokamax/xla/utils.py" source_line=33} - %convert.23 = bf16[256,4096]{1,0} convert(f32[256,4096]{1,0} %dot.5), metadata={op_name="jit(_einsum)/jit(main)/convert_element_type" source_file="gdm/jax/tokamax/xla/utils.py" source_line=33} - %bitcast.62 = bf16[1,256,4096]{2,1,0} bitcast(bf16[256,4096]{1,0} %convert.23) - %transpose.18 = bf16[1,4096,256]{2,1,0} transpose(bf16[1,256,4096]{2,1,0} %bitcast.62), dimensions={0,2,1}, metadata={op_name="jit(_einsum)/jit(main)/convert_element_type" source_file="gdm/jax/tokamax/xla/utils.py" source_line=33} - ROOT %bitcast.63 = bf16[1048576]{0} bitcast(bf16[1,4096,256]{2,1,0} %transpose.18) - } - - ENTRY main { - %p0 = f32[1,256,4,4096] parameter(0) - %p1 = bf16[1,4,4096,4096] parameter(1) - ROOT %gemm_fusion_r = bf16[1048576] fusion(%p0, %p1), kind=kCustom, - calls=gemm_fusion_r_computation, - backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"fusion_backend_config":{"kind":"__triton_gemm"},"force_earliest_schedule":false} - } -)"; - - std::unique_ptr module = - ParseAndReturnVerifiedModule(kHlo).value(); - const se::CudaComputeCapability compute_capability{ - se::CudaComputeCapability::AMPERE, /*minor=*/0}; - - TF_ASSERT_OK_AND_ASSIGN( - const std::vector configs, - GetPossibleMatmulAutotuneConfigs( - *Cast( - module->entry_computation()->root_instruction()), - compute_capability, GetToolkitVersion(), GetDebugOptionsForTest())); - EXPECT_TRUE(std::none_of( - configs.begin(), configs.end(), - [](const GemmFusionAutotunerImpl::BackendConfig& config) { - return std::holds_alternative< - GemmFusionAutotunerImpl::CustomKernelFusionConfig>(config); - })); -} - -TEST_F(GemmFusionAutotunerTest, RewritesTritonFusionToCustomKernelFusion) { - const std::string kHlo = R"( - HloModule module, entry_computation_layout={(bf16[1024,1024]{1,0}, bf16[1024,1024]{1,0})->f32[1024,1024]{1,0}} - - %gemm_fusion_r_computation { - %parameter_0 = bf16[1024,1024]{1,0} parameter(0) - %convert.2 = f32[1024,1024]{1,0} convert(%parameter_0) - %parameter_1 = bf16[1024,1024]{1,0} parameter(1) - %convert.3 = f32[1024,1024]{1,0} convert(%parameter_1) - ROOT %r.1 = f32[1024,1024]{1,0} dot(%convert.2, %convert.3), lhs_contracting_dims={1}, rhs_contracting_dims={0} - } - - ENTRY main { - %p0 = bf16[1024,1024]{1,0} parameter(0) - %p1 = bf16[1024,1024]{1,0} parameter(1) - ROOT %gemm_fusion_r = f32[1024,1024]{1,0} fusion(%p0, %p1), kind=kCustom, calls=gemm_fusion_r_computation, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"fusion_backend_config":{"kind":"__triton_gemm"},"force_earliest_schedule":false} - } -)"; - - std::unique_ptr module = - ParseAndReturnVerifiedModule(kHlo).value(); - - DebugOptions opts; - AutotuneConfig autotune_config{ - DeviceConfig{backend().default_stream_executor(), - backend().memory_allocator()}, - opts}; - AutotuneCacheKey cache_key(autotune_config.GetModelStr(), - *module->entry_computation()->root_instruction()); - TF_ASSERT_OK_AND_ASSIGN(AutotuneResults autotune_results_override, - ParseTextProto(R"pb( - version: 3 - results { - device: "..." - hlo: "..." - result { - custom_kernel_fusion { kernel_index: 1 } - run_time { nanos: 14 } - } - })pb")); - autotune_results_override.mutable_results(0)->set_device( - std::string(cache_key.GetModelStr())); - autotune_results_override.mutable_results(0)->set_hlo( - std::string(cache_key.GetHlo())); - - GemmFusionAutotunerRewriterVisitor visitor(autotune_config); - - CHECK_OK(AutotunerUtil::LoadAutotuneResults(autotune_results_override)); - visitor.RunOnModule(module.get(), {}).value(); - std::string pattern = R"( - CHECK: ROOT %cutlass_gemm_with_upcast - CHECK-SAME: fusion - CHECK-SAME: kind=kCustom - CHECK-SAME: "kernel_index":1 - )"; - TF_ASSERT_OK_AND_ASSIGN(bool file_check_matches, - RunFileCheck(module->ToString(), pattern)); - EXPECT_TRUE(file_check_matches); -} - } // namespace } // namespace gpu } // namespace xla From 98e3c52657bc24fe032b9f75c66dbeee94f7e9aa Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sat, 21 Sep 2024 15:47:59 -0700 Subject: [PATCH 101/483] Automated Code Change PiperOrigin-RevId: 677289193 --- third_party/xla/xla/service/gpu/triton_fusion_analysis.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/third_party/xla/xla/service/gpu/triton_fusion_analysis.cc b/third_party/xla/xla/service/gpu/triton_fusion_analysis.cc index d3b003c7add4bd..c317045189bcd7 100644 --- a/third_party/xla/xla/service/gpu/triton_fusion_analysis.cc +++ b/third_party/xla/xla/service/gpu/triton_fusion_analysis.cc @@ -44,6 +44,7 @@ limitations under the License. #include "xla/status_macros.h" #include "xla/tools/hlo_decomposer.h" #include "xla/util.h" +#include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" From fcd7efe0233ee5adb99067b15a73008b541f7797 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sun, 22 Sep 2024 02:04:32 -0700 Subject: [PATCH 102/483] compat: Update forward compatibility horizon to 2024-09-22 PiperOrigin-RevId: 677409161 --- tensorflow/python/compat/compat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index 905b9f61683c96..880fbfc269a41a 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -29,7 +29,7 @@ # This value changes every day with an automatic CL. It can be modified in code # via `forward_compatibility_horizon()` or with the environment variable # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date. -_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2024, 9, 21) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2024, 9, 22) _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS" _FORWARD_COMPATIBILITY_DATE_NUMBER = None From da7e9c1353ead503548127941f535faef01a39f3 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sun, 22 Sep 2024 02:09:56 -0700 Subject: [PATCH 103/483] Update GraphDef version to 1993. PiperOrigin-RevId: 677410390 --- tensorflow/core/public/version.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index 8d0996c93db008..0630c0ec562e53 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -108,7 +108,7 @@ limitations under the License. #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0 -#define TF_GRAPH_DEF_VERSION 1992 // Updated: 2024/9/21 +#define TF_GRAPH_DEF_VERSION 1993 // Updated: 2024/9/22 // Checkpoint compatibility versions (the versions field in SavedSliceMeta). // From 91bddd27d3cf2ea25725d60e6eefe692c5cecd54 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sun, 22 Sep 2024 07:16:41 -0700 Subject: [PATCH 104/483] Automated Code Change PiperOrigin-RevId: 677463213 --- tensorflow/python/lib/core/py_seq_tensor.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tensorflow/python/lib/core/py_seq_tensor.cc b/tensorflow/python/lib/core/py_seq_tensor.cc index 1c81b35e48cc5e..2b6c2e289918a9 100644 --- a/tensorflow/python/lib/core/py_seq_tensor.cc +++ b/tensorflow/python/lib/core/py_seq_tensor.cc @@ -617,9 +617,9 @@ tstring PyRepr(PyObject* obj) { bool IsPyDimension(PyObject* obj) { const char* tp_name = obj->ob_type->tp_name; if (strcmp(tp_name, "Dimension") != 0) return false; - bool ret = str_util::EndsWith( - PyRepr(PyType(obj)), - "tensorflow.python.framework.tensor_shape.Dimension'>"); + bool ret = + absl::EndsWith(PyRepr(PyType(obj)), + "tensorflow.python.framework.tensor_shape.Dimension'>"); return ret; } From ed5b6d146495cd45c6cc9bff7fe921688eacbf5b Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sun, 22 Sep 2024 07:24:56 -0700 Subject: [PATCH 105/483] Automated Code Change PiperOrigin-RevId: 677464307 --- .../common_runtime/process_function_library_runtime.cc | 2 +- .../process_function_library_runtime_test.cc | 6 +++--- tensorflow/core/common_runtime/propagator_state.h | 8 ++++---- tensorflow/core/common_runtime/simple_propagator_state.h | 4 ++-- .../core/common_runtime/single_threaded_executor.cc | 4 ++-- tensorflow/core/common_runtime/step_stats_collector.cc | 2 +- tensorflow/core/common_runtime/step_stats_collector.h | 2 +- 7 files changed, 14 insertions(+), 14 deletions(-) diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.cc b/tensorflow/core/common_runtime/process_function_library_runtime.cc index 52f8c3c8df00b4..69a2afdad67a4f 100644 --- a/tensorflow/core/common_runtime/process_function_library_runtime.cc +++ b/tensorflow/core/common_runtime/process_function_library_runtime.cc @@ -637,7 +637,7 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice( &data_lib_def, absl::StrCat(function_name, "_partitioned_", random::New64())); const int num_subgraphs = subgraphs->size(); - gtl::InlinedVector instantiate_status(num_subgraphs); + absl::InlinedVector instantiate_status(num_subgraphs); // Before instantiating component functions, determine synchronous execution. data->enable_sync_execution = false; diff --git a/tensorflow/core/common_runtime/process_function_library_runtime_test.cc b/tensorflow/core/common_runtime/process_function_library_runtime_test.cc index 344d22efc58b28..4e38964b5b3abe 100644 --- a/tensorflow/core/common_runtime/process_function_library_runtime_test.cc +++ b/tensorflow/core/common_runtime/process_function_library_runtime_test.cc @@ -925,7 +925,7 @@ FunctionDef AddVarAcrossDevices() { class TestFunctionPackedArgs : public FunctionArgsInterface { public: TestFunctionPackedArgs(const int index, - gtl::InlinedVector&& tensor_args) { + absl::InlinedVector&& tensor_args) { packed_args_.emplace(index, std::move(tensor_args)); } @@ -942,7 +942,7 @@ class TestFunctionPackedArgs : public FunctionArgsInterface { std::vector GetLocalTensors() const override { return {}; } private: - absl::flat_hash_map> packed_args_; + absl::flat_hash_map> packed_args_; }; TEST_F(ProcessFunctionLibraryRuntimeTest, MultiDevice_CompositeDevice) { @@ -985,7 +985,7 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, MultiDevice_CompositeDevice) { // Packed TensorHandle { - gtl::InlinedVector handles; + absl::InlinedVector handles; handles.push_back(TensorValue(&resource_handle0)); handles.push_back(TensorValue(&resource_handle1)); TestFunctionPackedArgs args(0, std::move(handles)); diff --git a/tensorflow/core/common_runtime/propagator_state.h b/tensorflow/core/common_runtime/propagator_state.h index 680cb13ef3ecb4..e5f4fd6bfec0ac 100644 --- a/tensorflow/core/common_runtime/propagator_state.h +++ b/tensorflow/core/common_runtime/propagator_state.h @@ -34,7 +34,7 @@ limitations under the License. namespace tensorflow { -typedef gtl::InlinedVector AllocatorAttributeVec; +typedef absl::InlinedVector AllocatorAttributeVec; // Represents the ephemeral "edge state" associated with one invocation of // `Executor::Run()`. @@ -115,12 +115,12 @@ class PropagatorState { // TODO(b/152925936): Re-evaluate these constants with current usage // patterns. static constexpr int kSpillThreshold = 16384; - gtl::InlinedVector ready_; + absl::InlinedVector ready_; int front_index_; }; // TODO(b/152925936): Re-evaluate this constant with current usage patterns. - typedef gtl::InlinedVector TaggedNodeSeq; + typedef absl::InlinedVector TaggedNodeSeq; private: // The state of an iteration in a particular frame. @@ -283,7 +283,7 @@ class PropagatorState { private: // The active iteration states of this frame. - gtl::InlinedVector iterations; + absl::InlinedVector iterations; IterationState** const iterations_raw TF_GUARDED_BY(mu); IterationState* iterations_first TF_GUARDED_BY(mu); diff --git a/tensorflow/core/common_runtime/simple_propagator_state.h b/tensorflow/core/common_runtime/simple_propagator_state.h index 1a08b8b4d67fb1..9f465ef1e91d3d 100644 --- a/tensorflow/core/common_runtime/simple_propagator_state.h +++ b/tensorflow/core/common_runtime/simple_propagator_state.h @@ -99,12 +99,12 @@ class SimplePropagatorState { // TODO(b/152925936): Re-evaluate these constants with current usage // patterns. static constexpr int kSpillThreshold = 16384; - gtl::InlinedVector ready_; + absl::InlinedVector ready_; int front_index_; }; // TODO(b/152925936): Re-evaluate this constant with current usage patterns. - typedef gtl::InlinedVector TaggedNodeSeq; + typedef absl::InlinedVector TaggedNodeSeq; // Creates and adds a `TaggedNode` for each node in `roots` to `*ready`. void ActivateRoots(gtl::ArraySlice roots, diff --git a/tensorflow/core/common_runtime/single_threaded_executor.cc b/tensorflow/core/common_runtime/single_threaded_executor.cc index 19b6a831382151..e9a7fbc5af3612 100644 --- a/tensorflow/core/common_runtime/single_threaded_executor.cc +++ b/tensorflow/core/common_runtime/single_threaded_executor.cc @@ -62,8 +62,8 @@ Status ValidateOpIsSafeForSyncExecution( namespace { -typedef gtl::InlinedVector TensorValueVec; -typedef gtl::InlinedVector AllocatorAttributeVec; +typedef absl::InlinedVector TensorValueVec; +typedef absl::InlinedVector AllocatorAttributeVec; static const string& kSingleThreadedExecutor = *new string("SINGLE_THREADED_EXECUTOR"); diff --git a/tensorflow/core/common_runtime/step_stats_collector.cc b/tensorflow/core/common_runtime/step_stats_collector.cc index 3121a938bd0152..695b7d55217094 100644 --- a/tensorflow/core/common_runtime/step_stats_collector.cc +++ b/tensorflow/core/common_runtime/step_stats_collector.cc @@ -462,7 +462,7 @@ string StepStatsCollector::ReportAllocsOnResourceExhausted( std::make_pair(dev_stat.first, alloc.first->allocator_name()); AllocStats& dev_allocs_stats = allocs_map[dev_allocator]; TrackingAllocator* tracking_alloc = alloc.second; - gtl::InlinedVector cur_records = + absl::InlinedVector cur_records = tracking_alloc->GetCurrentRecords(); int64_t cur_bytes = 0; for (const auto& r : cur_records) { diff --git a/tensorflow/core/common_runtime/step_stats_collector.h b/tensorflow/core/common_runtime/step_stats_collector.h index df1e579f6d8932..277630cd40f9de 100644 --- a/tensorflow/core/common_runtime/step_stats_collector.h +++ b/tensorflow/core/common_runtime/step_stats_collector.h @@ -124,7 +124,7 @@ class NodeExecStatsWrapper : public NodeExecStatsInterface { void AddAllocation(Allocator* allocator, TrackingAllocator* tracking_allocator); - gtl::InlinedVector, 2> + absl::InlinedVector, 2UL> allocations_; std::unique_ptr stats_; const NodeDef* const node_; // Not owned. From e81dff8a49adeaeec435afdfbc1234be28c88117 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sun, 22 Sep 2024 11:08:47 -0700 Subject: [PATCH 106/483] Automated Code Change PiperOrigin-RevId: 677502161 --- tensorflow/compiler/mlir/quantization/stablehlo/cc/BUILD | 2 -- 1 file changed, 2 deletions(-) diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/BUILD b/tensorflow/compiler/mlir/quantization/stablehlo/cc/BUILD index f850b7f4e775ff..dcf16782319d12 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/BUILD +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/BUILD @@ -278,10 +278,8 @@ cc_library( hdrs = ["pre_calibration.h"], compatible_with = get_compatible_with_portable(), visibility = [ - "//tensorflow:__pkg__", "//tensorflow/compiler/mlir/quantization/stablehlo:__subpackages__", "//tensorflow/compiler/mlir/quantization/tensorflow:__subpackages__", - "//tensorflow/python:__pkg__", ], deps = [ ":component", From 59f28efd772fb8352b62834d55db272e926ef920 Mon Sep 17 00:00:00 2001 From: Matt Bahr Date: Sun, 22 Sep 2024 17:12:46 -0400 Subject: [PATCH 107/483] Declare size as uint64_t --- tensorflow/core/kernels/ragged_range_op.cc | 4 ++-- tensorflow/core/kernels/sequence_ops.cc | 8 ++++---- tensorflow/core/ops/math_ops.cc | 6 +++--- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/tensorflow/core/kernels/ragged_range_op.cc b/tensorflow/core/kernels/ragged_range_op.cc index f2c24534f3fd6e..f36b409c5137d1 100644 --- a/tensorflow/core/kernels/ragged_range_op.cc +++ b/tensorflow/core/kernels/ragged_range_op.cc @@ -81,7 +81,7 @@ class RaggedRangeOp : public OpKernel { T limit = broadcast_limits ? limits(0) : limits(row); T delta = broadcast_deltas ? deltas(0) : deltas(row); OP_REQUIRES(context, delta != 0, InvalidArgument("Requires delta != 0")); - SPLITS_TYPE size; // The number of elements in the specified range. + uint64_t size; // The number of elements in the specified range. if (((delta > 0) && (limit < start)) || ((delta < 0) && (limit > start))) { size = 0; @@ -105,7 +105,7 @@ class RaggedRangeOp : public OpKernel { context, size_auto <= std::numeric_limits::max(), errors::InvalidArgument("Requires ((limit - start) / delta) <= ", std::numeric_limits::max())); - size = static_cast(size_auto); + size = static_cast(size_auto); } OP_REQUIRES(context, size >= 0, InvalidArgument("Requires size >= 0")); OP_REQUIRES( diff --git a/tensorflow/core/kernels/sequence_ops.cc b/tensorflow/core/kernels/sequence_ops.cc index 1fa09a0e09d54a..1bc4d35c20c737 100644 --- a/tensorflow/core/kernels/sequence_ops.cc +++ b/tensorflow/core/kernels/sequence_ops.cc @@ -36,10 +36,10 @@ namespace functor { template struct RangeFunctor { - void operator()(OpKernelContext* context, int64_t size, T start, T delta, + void operator()(OpKernelContext* context, uint64_t size, T start, T delta, typename TTypes::Flat output) const { (void)context; - for (int64_t i = 0; i < size; ++i) { + for (uint64_t i = 0; i < size; ++i) { output(i) = start + static_cast(i) * delta; } } @@ -91,7 +91,7 @@ class RangeOp : public OpKernel { errors::InvalidArgument( "Requires start >= limit when delta < 0: ", start, "/", limit)); } - int64_t size; + uint64_t size; if constexpr (std::is_integral::value) { uint64_t range; if ((limit > 0 && start < 0) || (limit < 0 && start > 0)) { @@ -110,7 +110,7 @@ class RangeOp : public OpKernel { context, size_auto <= std::numeric_limits::max(), errors::InvalidArgument("Requires ((limit - start) / delta) <= ", std::numeric_limits::max())); - size = static_cast(size_auto); + size = static_cast(size_auto); } TensorShape shape; diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc index 65a98857cb0310..6b5de9fda33191 100644 --- a/tensorflow/core/ops/math_ops.cc +++ b/tensorflow/core/ops/math_ops.cc @@ -1511,7 +1511,7 @@ Status RangeSize(const Tensor* start_t, const Tensor* limit_t, return errors::InvalidArgument("Requires delta != 0"); } - int64_t size; + uint64_t size; if (std::is_integral::value) { uint64_t range; if ((limit > 0 && start < 0) || (limit < 0 && start > 0)) { @@ -1530,10 +1530,10 @@ Status RangeSize(const Tensor* start_t, const Tensor* limit_t, return errors::InvalidArgument("Requires ((limit - start) / delta) <= ", std::numeric_limits::max()); } - size = static_cast(size_auto); + size = static_cast(size_auto); } - c->set_output(0, c->Vector(static_cast(size))); + c->set_output(0, c->Vector(static_cast(size))); return absl::OkStatus(); } From 15789eb274de77d60c0cea1d0ad4c1e5a62786c5 Mon Sep 17 00:00:00 2001 From: Eric Salo Date: Sun, 22 Sep 2024 16:09:01 -0700 Subject: [PATCH 108/483] cleanup: remove api_version from BUILD files PiperOrigin-RevId: 677554107 --- tensorflow/core/function/polymorphism/BUILD | 1 - tensorflow/core/function/trace_type/BUILD | 3 --- tensorflow/core/grappler/costs/BUILD | 1 - tensorflow/core/grappler/optimizers/inference/BUILD | 1 - third_party/xla/xla/BUILD | 2 -- third_party/xla/xla/pjrt/BUILD | 1 - 6 files changed, 9 deletions(-) diff --git a/tensorflow/core/function/polymorphism/BUILD b/tensorflow/core/function/polymorphism/BUILD index 0a333bb9d73ffd..3289406a109e26 100644 --- a/tensorflow/core/function/polymorphism/BUILD +++ b/tensorflow/core/function/polymorphism/BUILD @@ -100,7 +100,6 @@ tf_proto_library( # copybara:uncomment_begin(google-only) # py_proto_library( # name = "function_type_py_pb2", -# api_version = 2, # visibility = ["//visibility:private"], # deps = [":function_type_proto"], # ) diff --git a/tensorflow/core/function/trace_type/BUILD b/tensorflow/core/function/trace_type/BUILD index 788f4278a187b6..cf8502fbe77d1c 100644 --- a/tensorflow/core/function/trace_type/BUILD +++ b/tensorflow/core/function/trace_type/BUILD @@ -186,21 +186,18 @@ tf_proto_library( # copybara:uncomment_begin(google-only) # py_proto_library( # name = "serialization_py_pb2", -# api_version = 2, # visibility = ["//tensorflow:internal"], # deps = [":serialization_proto"], # ) # # py_proto_library( # name = "serialization_test_py_pb2", -# api_version = 2, # visibility = ["//tensorflow:internal"], # deps = [":serialization_test_proto"], # ) # # py_proto_library( # name = "default_types_py_pb2", -# api_version = 2, # visibility = ["//tensorflow:internal"], # deps = [":default_types_proto"], # ) diff --git a/tensorflow/core/grappler/costs/BUILD b/tensorflow/core/grappler/costs/BUILD index 59f08fedafae7d..9576fa5c2dc693 100644 --- a/tensorflow/core/grappler/costs/BUILD +++ b/tensorflow/core/grappler/costs/BUILD @@ -416,7 +416,6 @@ tf_cc_test( # py_proto_library( # name = "op_performance_data_py_pb2", # has_services = 0, -# api_version = 2, # visibility = ["//visibility:public"], # deps = [":op_performance_data"], # ) diff --git a/tensorflow/core/grappler/optimizers/inference/BUILD b/tensorflow/core/grappler/optimizers/inference/BUILD index 3b6e92e6a0434c..41f6c0728cc7de 100644 --- a/tensorflow/core/grappler/optimizers/inference/BUILD +++ b/tensorflow/core/grappler/optimizers/inference/BUILD @@ -25,7 +25,6 @@ tf_proto_library( # copybara:uncomment_begin(google-only) # py_proto_library( # name = "batch_op_rewriter_proto_py_pb2", -# api_version = 2, # visibility = ["//visibility:public"], # deps = [":batch_op_rewriter_proto"], # ) diff --git a/third_party/xla/xla/BUILD b/third_party/xla/xla/BUILD index 0d9ad2f953a560..039f05333174db 100644 --- a/third_party/xla/xla/BUILD +++ b/third_party/xla/xla/BUILD @@ -1301,7 +1301,6 @@ cc_library( # copybara:uncomment_begin(google-only) # py_proto_library( # name = "xla_data_proto_py_pb2", -# api_version = 2, # visibility = internal_visibility([":friends"]), # deps = [":xla_data_proto"], # ) @@ -1309,7 +1308,6 @@ cc_library( # py_proto_library( # name = "xla_py_pb2", # testonly = 0, -# api_version = 2, # compatible_with = get_compatible_with_portable(), # visibility = internal_visibility([":friends"]), # deps = [":xla_proto"], diff --git a/third_party/xla/xla/pjrt/BUILD b/third_party/xla/xla/pjrt/BUILD index b98825d15f9b97..48439f796f1a7e 100644 --- a/third_party/xla/xla/pjrt/BUILD +++ b/third_party/xla/xla/pjrt/BUILD @@ -942,7 +942,6 @@ tf_proto_library( # copybara:uncomment_begin(google-only) # py_proto_library( # name = "compile_options_py_pb2", -# api_version = 2, # visibility = ["//visibility:public"], # deps = [":compile_options_proto"], # ) From a50403494fbffb10c136e78c4cfe58ed7a3ccadf Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sun, 22 Sep 2024 17:22:41 -0700 Subject: [PATCH 109/483] Automated Code Change PiperOrigin-RevId: 677566674 --- tensorflow/compiler/mlir/tfr/resources/test_ops.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/tensorflow/compiler/mlir/tfr/resources/test_ops.cc b/tensorflow/compiler/mlir/tfr/resources/test_ops.cc index 3aaa0850805030..c9dcfd26104e86 100644 --- a/tensorflow/compiler/mlir/tfr/resources/test_ops.cc +++ b/tensorflow/compiler/mlir/tfr/resources/test_ops.cc @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/framework/op.h" namespace tensorflow { From 19faa81daa36027232a7858e46cfd45d36da99be Mon Sep 17 00:00:00 2001 From: Vlad Sytchenko Date: Sun, 22 Sep 2024 17:50:10 -0700 Subject: [PATCH 110/483] [XLA] Introduce infeed token propagation During computation inlining, specifically loop unrolling, it is posibble for infeeds (and outfeeds) to get reordered in a way that breaks the original scheduling constraints set by the computation boundaries. This is a result of Tensorflow not exposing tokens for these ops to the user, so the input and output tokens end up dangling. Loop unrolling in XLA can be thought of applying the same function repeatedly to itself, e.g. transforming f(x) into f(f(x)). By pushing the tokens outside the loop body, we can guarantee that the output token of the first infeed will become the input token of the next infeed, thus creating a data dependency chain and preserving the original ordering. PiperOrigin-RevId: 677572109 --- third_party/xla/xla/hlo/ir/hlo_instruction.cc | 18 +- third_party/xla/xla/hlo/ir/hlo_instruction.h | 5 +- third_party/xla/xla/service/BUILD | 35 + .../xla/service/infeed_token_propagation.cc | 459 +++++++++++++ .../xla/service/infeed_token_propagation.h | 45 ++ .../service/infeed_token_propagation_test.cc | 601 ++++++++++++++++++ .../while_loop_invariant_code_motion.cc | 1 + 7 files changed, 1159 insertions(+), 5 deletions(-) create mode 100644 third_party/xla/xla/service/infeed_token_propagation.cc create mode 100644 third_party/xla/xla/service/infeed_token_propagation.h create mode 100644 third_party/xla/xla/service/infeed_token_propagation_test.cc diff --git a/third_party/xla/xla/hlo/ir/hlo_instruction.cc b/third_party/xla/xla/hlo/ir/hlo_instruction.cc index ad23f22909db5e..8822767a05704b 100644 --- a/third_party/xla/xla/hlo/ir/hlo_instruction.cc +++ b/third_party/xla/xla/hlo/ir/hlo_instruction.cc @@ -3399,18 +3399,30 @@ const PtrVec& HloInstruction::branch_computations() const { return called_computations(); } -int HloInstruction::branch_count() const { +int32_t HloInstruction::branch_count() const { CHECK(HloOpcode::kConditional == opcode_); return called_computations().size(); } -HloComputation* HloInstruction::branch_computation(int b) const { - CHECK(HloOpcode::kConditional == opcode_); +HloComputation* HloInstruction::branch_computation(int32_t b) const { + CHECK_EQ(HloOpcode::kConditional, opcode_); CHECK_GE(b, 0); CHECK_LT(b, called_computations().size()); return called_computations()[b]; } +int32_t HloInstruction::branch_index(HloComputation* computation) const { + CHECK_EQ(HloOpcode::kConditional, opcode_); + CHECK_NE(computation, nullptr); + for (int32_t idx = 0; idx < branch_count(); idx++) { + if (branch_computation(idx) == computation) { + return idx; + } + } + LOG(FATAL) << absl::StrFormat("Conditional %s does not contain branch %s", + name(), computation->name()); +} + void HloInstruction::set_branch_computation(int b, HloComputation* computation) { CHECK_EQ(HloOpcode::kConditional, opcode_); diff --git a/third_party/xla/xla/hlo/ir/hlo_instruction.h b/third_party/xla/xla/hlo/ir/hlo_instruction.h index 42729daec64df3..b6ba1bd8d37571 100644 --- a/third_party/xla/xla/hlo/ir/hlo_instruction.h +++ b/third_party/xla/xla/hlo/ir/hlo_instruction.h @@ -1808,8 +1808,9 @@ class HloInstruction { // // Precondition: The instruction is a Conditional instruction. const PtrVec& branch_computations() const; - int branch_count() const; - HloComputation* branch_computation(int b) const; + int32_t branch_count() const; + HloComputation* branch_computation(int32_t b) const; + int32_t branch_index(HloComputation* computation) const; // Sets a branch HloComputation for Conditional. // The setter should only be called by HloModule or HloComputation methods. // diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD index 8216d818fb8085..1165e6b19c081c 100644 --- a/third_party/xla/xla/service/BUILD +++ b/third_party/xla/xla/service/BUILD @@ -8543,4 +8543,39 @@ xla_cc_test( ], ) +cc_library( + name = "infeed_token_propagation", + srcs = ["infeed_token_propagation.cc"], + hdrs = ["infeed_token_propagation.h"], + deps = [ + ":hlo_dce", + ":tuple_simplifier", + "//xla:shape_util", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "infeed_token_propagation_test", + srcs = ["infeed_token_propagation_test.cc"], + deps = [ + ":infeed_token_propagation", + "//xla/hlo/ir:hlo", + "//xla/hlo/utils:hlo_matchers", + "//xla/tests:hlo_test_base", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:statusor", + ], +) + exports_files(["xla_aot_compile_test_gpu_target_config.prototxt"]) diff --git a/third_party/xla/xla/service/infeed_token_propagation.cc b/third_party/xla/xla/service/infeed_token_propagation.cc new file mode 100644 index 00000000000000..11c4d6bb7d0c5c --- /dev/null +++ b/third_party/xla/xla/service/infeed_token_propagation.cc @@ -0,0 +1,459 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/infeed_token_propagation.h" + +#include +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/types/span.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/service/hlo_dce.h" +#include "xla/service/tuple_simplifier.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/util.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace { +bool IsDanglingInfeed(HloInstruction* infeed) { + CHECK(infeed->opcode() == HloOpcode::kInfeed); + if (infeed->has_sharding()) { + // TODO: b/368327832 - Skip handling sharding until it is removed. + return false; + } + + // Check for dangling input token. + if (const HloInstruction* after_all = infeed->operand(0); + after_all->opcode() != HloOpcode::kAfterAll || + after_all->operand_count() != 0) { + return false; + } + + // Check for dangling output token. + for (const HloInstruction* user : infeed->users()) { + if (user->opcode() == HloOpcode::kGetTupleElement && + user->tuple_index() == 1) { + return false; + } + } + + return true; +} + +bool IsDanglingOutfeed(HloInstruction* outfeed) { + CHECK(outfeed->opcode() == HloOpcode::kOutfeed); + if (outfeed->has_sharding()) { + // TODO: b/368327832 - Skip handling sharding until it is removed. + return false; + } + + // Check for dangling input token. + if (const HloInstruction* after_all = outfeed->operand(1); + after_all->opcode() != HloOpcode::kAfterAll || + after_all->operand_count() != 0) { + return false; + } + + // Check for dangling output token. + if (outfeed->user_count() != 0) { + return false; + } + + return true; +} + +HloInstruction* ReconstructTuple(HloInstruction* tuple) { + CHECK(tuple->shape().IsTuple()); + HloComputation* computation = tuple->parent(); + + std::vector gtes; + gtes.resize(tuple->shape().tuple_shapes_size()); + for (int64_t idx = 0; idx < gtes.size(); ++idx) { + gtes[idx] = computation->AddInstruction( + HloInstruction::CreateGetTupleElement(tuple, idx)); + } + + return computation->AddInstruction(HloInstruction::CreateTuple(gtes)); +} + +absl::StatusOr InsertTokenIntoTuple(HloInstruction* tuple, + bool add_token_operand) { + CHECK(tuple->shape().IsTuple()); + HloComputation* computation = tuple->parent(); + + // Recreate the original tuple, we'll need to pass this to all the users. + std::vector original_users = tuple->users(); + HloInstruction* original_tuple = ReconstructTuple(tuple); + for (HloInstruction* original_user : original_users) { + int64_t idx = original_user->operand_index(tuple); + TF_RETURN_IF_ERROR(original_user->ReplaceOperandWith(idx, original_tuple)); + } + + // Append the token to the parameter tuple. + *tuple->mutable_shape()->add_tuple_shapes() = ShapeUtil::MakeTokenShape(); + if (add_token_operand) { + tuple->AppendOperand( + computation->AddInstruction(HloInstruction::CreateToken())); + } + + HloInstruction* input_token_gte = + computation->AddInstruction(HloInstruction::CreateGetTupleElement( + tuple, tuple->shape().tuple_shapes_size() - 1)); + return input_token_gte; +} + +absl::Status CanonicalizeConditionalBranch(HloComputation* branch) { + CHECK(branch->IsConditionalBranchComputation()); + CHECK_EQ(branch->num_parameters(), 1); + + // Tuplify the branch parameter if needed. + HloInstruction* parameter = branch->parameter_instruction(0); + if (!parameter->shape().IsTuple()) { + *parameter->mutable_shape() = + ShapeUtil::MakeTupleShape({parameter->shape()}); + HloInstruction* original = branch->AddInstruction( + HloInstruction::CreateGetTupleElement(parameter, 0)); + TF_RETURN_IF_ERROR(parameter->ReplaceAllUsesWithDifferentShape(original)); + } + + // Tuplify the branch tuple if needed. + HloInstruction* conditional = branch->ConditionalCallInstruction(); + int64_t branch_operand_idx = conditional->branch_index(branch) + 1; + HloInstruction* branch_tuple = + conditional->mutable_operand(branch_operand_idx); + if (!branch_tuple->shape().IsTuple()) { + branch_tuple = conditional->parent()->AddInstruction( + HloInstruction::CreateTuple({branch_tuple})); + TF_RETURN_IF_ERROR(conditional->ReplaceOperandWithDifferentShape( + branch_operand_idx, branch_tuple)); + } + + // Explicitly disjoin computation parameters from branch inputs, so we can + // insert tokens into the input tuple. + if (branch_tuple->opcode() == HloOpcode::kParameter) { + branch_tuple = ReconstructTuple(branch_tuple); + TF_RETURN_IF_ERROR( + conditional->ReplaceOperandWith(branch_operand_idx, branch_tuple)); + } + + // If the computation root is a also a computation parameter, explicitly split + // them, as the input and output tokens cannot be part of the same + // instruction. + HloInstruction* root = branch->root_instruction(); + if (root->opcode() == HloOpcode::kParameter) { + root = ReconstructTuple(root); + branch->set_root_instruction(root); + } + + // ConditionalCanonicalizer should have already turned the conditional output + // to be a tuple. + CHECK(conditional->shape().IsTuple()); + return absl::OkStatus(); +} + +absl::Status CanonicalizeWhileBody(HloComputation* body) { + CHECK(body->IsWhileBodyComputation()); + CHECK_EQ(body->num_parameters(), 1); + + // Tuplify the body parameter if needed. + HloInstruction* parameter = body->parameter_instruction(0); + if (!parameter->shape().IsTuple()) { + *parameter->mutable_shape() = + ShapeUtil::MakeTupleShape({parameter->shape()}); + HloInstruction* original = body->AddInstruction( + HloInstruction::CreateGetTupleElement(parameter, 0)); + TF_RETURN_IF_ERROR(parameter->ReplaceAllUsesWithDifferentShape(original)); + } + + // Tuplify the body root if needed. + HloInstruction* root = body->root_instruction(); + if (!root->shape().IsTuple()) { + root = body->AddInstruction(HloInstruction::CreateTuple({root})); + body->set_root_instruction(root, /*accept_different_shape=*/true); + } + + // Tuplify the condition parameter if needed. + HloInstruction* loop = body->WhileCallInstruction(); + HloComputation* cond = loop->while_condition(); + HloInstruction* cond_parameter = cond->parameter_instruction(0); + if (!cond_parameter->shape().IsTuple()) { + *cond_parameter->mutable_shape() = + ShapeUtil::MakeTupleShape({cond_parameter->shape()}); + HloInstruction* original = cond->AddInstruction( + HloInstruction::CreateGetTupleElement(cond_parameter, 0)); + TF_RETURN_IF_ERROR( + cond_parameter->ReplaceAllUsesWithDifferentShape(original)); + } + + // Tuplify the while instruction if needed. + if (!loop->shape().IsTuple()) { + *loop->mutable_shape() = ShapeUtil::MakeTupleShape({loop->shape()}); + HloInstruction* original = loop->parent()->AddInstruction( + HloInstruction::CreateGetTupleElement(loop, 0)); + TF_RETURN_IF_ERROR(loop->ReplaceAllUsesWithDifferentShape(original)); + } + + // Tuplify the while tuple if needed. + HloInstruction* loop_tuple = loop->mutable_operand(0); + if (!loop_tuple->shape().IsTuple()) { + loop_tuple = loop->parent()->AddInstruction( + HloInstruction::CreateTuple({loop_tuple})); + TF_RETURN_IF_ERROR(loop->ReplaceOperandWithDifferentShape(0, loop_tuple)); + } + + // Explicitly disjoin computation parameters from loop inputs, so we can + // insert tokens into the input tuple. + if (loop_tuple->opcode() == HloOpcode::kParameter) { + loop_tuple = ReconstructTuple(loop_tuple); + TF_RETURN_IF_ERROR(loop->ReplaceOperandWith(0, loop_tuple)); + } + + // If the computation root is a also a computation parameter, explicitly + // split them, as the input and output tokens cannot be part of the same + // instruction. + if (root->opcode() == HloOpcode::kParameter) { + root = ReconstructTuple(root); + body->set_root_instruction(root); + } + + return absl::OkStatus(); +} + +absl::StatusOr> +PropagateTokenThroughConditionalBranch(HloInstruction* instruction, + HloInstruction* input_token, + HloInstruction* output_token) { + // Conditional branches can diverge in inputs, but must converge on outputs. + + // Fixup the branch. + HloComputation* comp = instruction->parent(); + TF_RETURN_IF_ERROR(CanonicalizeConditionalBranch(comp)); + HloInstruction* next_instruction = comp->ConditionalCallInstruction(); + + // Insert the output token into each branch. + for (HloComputation* branch : next_instruction->branch_computations()) { + HloInstruction* root = branch->root_instruction(); + if (branch == comp) { + TF_RETURN_IF_ERROR( + InsertTokenIntoTuple(root, /*add_token_operand=*/false).status()); + root->AppendOperand(output_token); + } else { + TF_RETURN_IF_ERROR( + InsertTokenIntoTuple(root, /*add_token_operand=*/true).status()); + } + } + + // Insert the input token into the branch parameter. + HloInstruction* parameter = comp->parameter_instruction(0); + TF_ASSIGN_OR_RETURN( + HloInstruction * input_token_gte, + InsertTokenIntoTuple(parameter, /*add_token_operand=*/false)); + TF_RETURN_IF_ERROR(input_token->ReplaceAllUsesWith(input_token_gte)); + + // Insert the input token into the branch tuple. + int64_t branch_operand_idx = next_instruction->branch_index(comp) + 1; + HloInstruction* branch_tuple = + next_instruction->mutable_operand(branch_operand_idx); + TF_ASSIGN_OR_RETURN( + HloInstruction * next_input_token_gte, + InsertTokenIntoTuple(branch_tuple, /*add_token_operand=*/true)); + TF_RETURN_IF_ERROR(next_instruction->ReplaceOperandWithDifferentShape( + branch_operand_idx, branch_tuple)); + HloInstruction* next_input_token = + branch_tuple->mutable_operand(next_input_token_gte->tuple_index()); + + // Insert the output token into conditional instruction. + TF_ASSIGN_OR_RETURN( + HloInstruction * next_output_token, + InsertTokenIntoTuple(next_instruction, /*add_token_operand=*/false)); + + return std::make_tuple(next_instruction, next_input_token, next_output_token); +} + +absl::StatusOr> +PropagateTokenThroughWhileBody(HloInstruction* instruction, + HloInstruction* input_token, + HloInstruction* output_token) { + // While loops need to converge on input and output. + + // Fixup the while body. + HloComputation* comp = instruction->parent(); + TF_RETURN_IF_ERROR(CanonicalizeWhileBody(comp)); + HloInstruction* next_instruction = comp->WhileCallInstruction(); + + // Insert the output token into the body root. + HloInstruction* root = comp->root_instruction(); + TF_RETURN_IF_ERROR( + InsertTokenIntoTuple(root, /*add_token_operand=*/false).status()); + root->AppendOperand(output_token); + + // Insert the input token into the body parameter. + HloInstruction* body_parameter = comp->parameter_instruction(0); + TF_ASSIGN_OR_RETURN( + HloInstruction * input_token_gte, + InsertTokenIntoTuple(body_parameter, /*add_token_operand=*/false)); + TF_RETURN_IF_ERROR(input_token->ReplaceAllUsesWith(input_token_gte)); + + // Insert the input token into the condition parameter. + HloComputation* cond = next_instruction->while_condition(); + HloInstruction* cond_parameter = cond->parameter_instruction(0); + TF_RETURN_IF_ERROR( + InsertTokenIntoTuple(cond_parameter, /*add_token_operand=*/false) + .status()); + + // Insert the input token into the while tuple. + HloInstruction* while_tuple = next_instruction->mutable_operand(0); + TF_ASSIGN_OR_RETURN( + HloInstruction * next_input_token, + InsertTokenIntoTuple(while_tuple, /*add_token_operand=*/true)); + TF_RETURN_IF_ERROR( + next_instruction->ReplaceOperandWithDifferentShape(0, while_tuple)); + + // Insert the input token into the while instruction. + TF_ASSIGN_OR_RETURN( + HloInstruction * next_output_token, + InsertTokenIntoTuple(next_instruction, /*add_token_operand=*/false)); + + return std::make_tuple(next_instruction, next_input_token, next_output_token); +} + +absl::Status PropagateToken(HloInstruction* instruction, + HloInstruction* input_token, + HloInstruction* output_token) { + HloComputation* comp = instruction->parent(); + if (comp->IsEntryComputation()) { + // If we propagate through the root instruction, reconstruct the original + // tuple and set that to be root. + if (instruction->IsRoot() && + (instruction->opcode() == HloOpcode::kWhile || + instruction->opcode() == HloOpcode::kConditional)) { + std::vector gtes; + int64_t output_token_idx = output_token->tuple_index(); + for (int64_t idx = 0; idx < instruction->shape().tuple_shapes_size(); + idx++) { + if (idx != output_token_idx) { + gtes.push_back(comp->AddInstruction( + HloInstruction::CreateGetTupleElement(instruction, idx))); + } + } + HloInstruction* original_tuple = + comp->AddInstruction(HloInstruction::CreateTuple(gtes)); + comp->set_root_instruction(original_tuple, + /*accept_different_shape=*/true); + } + return absl::OkStatus(); + } + + HloInstruction* next_instruction = nullptr; + HloInstruction* next_input_token = nullptr; + HloInstruction* next_output_token = nullptr; + if (comp->IsConditionalBranchComputation()) { + // TODO: b/368327832 - Skip handling sharding until it is removed. + if (comp->ConditionalCallInstruction()->has_sharding()) { + return absl::OkStatus(); + } + TF_ASSIGN_OR_RETURN( + std::tie(next_instruction, next_input_token, next_output_token), + PropagateTokenThroughConditionalBranch(instruction, input_token, + output_token)); + } else if (comp->IsWhileBodyComputation()) { + // TODO: b/368327832 - Skip handling sharding until it is removed. + if (comp->WhileCallInstruction()->has_sharding()) { + return absl::OkStatus(); + } + TF_ASSIGN_OR_RETURN( + std::tie(next_instruction, next_input_token, next_output_token), + PropagateTokenThroughWhileBody(instruction, input_token, output_token)); + } else { + // We only expect to encounter computations behind while and conditional + // instructions. In the case of it being behind a while condition, there is + // no way to propagate the output token, as the root only returns a + // predicate. All other computations that could possibly contain infeed + // or outfeed ops should have already been inlined. + VLOG(2) << "Unhandled computation: " << comp->name(); + return absl::OkStatus(); + } + CHECK_NE(next_instruction, nullptr); + CHECK_NE(next_input_token, nullptr); + CHECK_NE(next_output_token, nullptr); + + return PropagateToken(next_instruction, next_input_token, next_output_token); +} +} // namespace + +absl::StatusOr InfeedTokenPropagation::Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) { + VLOG(5) << "Before InfeedTokenPropagation:"; + XLA_VLOG_LINES(5, module->ToString()); + + std::vector dangling_infeeds; + std::vector dangling_outfeeds; + for (HloComputation* computation : + module->MakeNonfusionComputations(execution_threads)) { + if (!computation->IsEntryComputation()) { + for (HloInstruction* instruction : computation->instructions()) { + if (instruction->opcode() == HloOpcode::kInfeed && + IsDanglingInfeed(instruction)) { + VLOG(1) << "Found dangling infeed: " << instruction->ToString(); + dangling_infeeds.push_back(instruction); + } else if (instruction->opcode() == HloOpcode::kOutfeed && + IsDanglingOutfeed(instruction)) { + VLOG(1) << "Found dangling outfeed: " << instruction->ToString(); + dangling_outfeeds.push_back(instruction); + } + } + } + } + + for (HloInstruction* dangling_infeed : dangling_infeeds) { + HloInstruction* input_token = dangling_infeed->mutable_operand(0); + HloInstruction* output_token = dangling_infeed->AddInstruction( + HloInstruction::CreateGetTupleElement(dangling_infeed, 1)); + TF_RETURN_IF_ERROR( + PropagateToken(dangling_infeed, input_token, output_token)); + } + for (HloInstruction* dangling_outfeed : dangling_outfeeds) { + HloInstruction* input_token = dangling_outfeed->mutable_operand(1); + HloInstruction* output_token = dangling_outfeed; + TF_RETURN_IF_ERROR( + PropagateToken(dangling_outfeed, input_token, output_token)); + } + + bool changed = !dangling_infeeds.empty() || !dangling_outfeeds.empty(); + if (changed) { + TF_RETURN_IF_ERROR( + TupleSimplifier().Run(module, execution_threads).status()); + TF_RETURN_IF_ERROR(HloDCE().Run(module, execution_threads).status()); + } + + VLOG(5) << "After InfeedTokenPropagation:"; + XLA_VLOG_LINES(5, module->ToString()); + return changed; +} +} // namespace xla diff --git a/third_party/xla/xla/service/infeed_token_propagation.h b/third_party/xla/xla/service/infeed_token_propagation.h new file mode 100644 index 00000000000000..cc6994a62a98a9 --- /dev/null +++ b/third_party/xla/xla/service/infeed_token_propagation.h @@ -0,0 +1,45 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_INFEED_TOKEN_PROPAGATION_H_ +#define XLA_SERVICE_INFEED_TOKEN_PROPAGATION_H_ + +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/pass/hlo_pass_interface.h" + +namespace xla { +// Finds dangling infeed/outfeed tokens inside nested computations and bubbles +// them up through callers until they reach the entry computation. This is +// needed to prepare these computations to be inlined, otherwise the previous +// computation boundaries won't be there to stop infeeds/outfeeds from being +// reordered during scheduling. +// +// This pass assumes the HLO graph is flattened. +class InfeedTokenPropagation : public HloModulePass { + public: + std::string_view name() const override { return "infeed-token-propagation"; } + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; +}; +} // namespace xla + +#endif // XLA_SERVICE_INFEED_TOKEN_PROPAGATION_H_ diff --git a/third_party/xla/xla/service/infeed_token_propagation_test.cc b/third_party/xla/xla/service/infeed_token_propagation_test.cc new file mode 100644 index 00000000000000..8c1024253868d6 --- /dev/null +++ b/third_party/xla/xla/service/infeed_token_propagation_test.cc @@ -0,0 +1,601 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/infeed_token_propagation.h" + +#include +#include + +#include +#include +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/utils/hlo_matchers.h" +#include "xla/tests/hlo_test_base.h" +#include "tsl/platform/statusor.h" + +namespace op = xla::testing::opcode_matchers; + +namespace xla { +namespace { + +class InfeedTokenPropagationTest : public HloTestBase { + protected: + InfeedTokenPropagationTest() = default; +}; + +TEST_F(InfeedTokenPropagationTest, EntryComputationInfeed) { + constexpr std::string_view hlo = R"( +HloModule main + +ENTRY main { + token.0 = after-all() + infeed.0 = (s32[], token[]) infeed(token.0) + ROOT gte.0 = get-tuple-element(infeed.0), index=0 +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo)); + InfeedTokenPropagation itp; + TF_ASSERT_OK_AND_ASSIGN(bool changed, itp.Run(module.get())); + EXPECT_FALSE(changed); +} + +TEST_F(InfeedTokenPropagationTest, EntryComputationOutfeed) { + constexpr std::string_view hlo = R"( +HloModule main + +ENTRY main { + arg.0 = s32[] parameter(0) + tuple.0 = tuple(arg.0) + token.0 = after-all() + outfeed.0 = token[] outfeed(tuple.0, token.0), outfeed_shape=(s32[]) + ROOT tuple.1 = tuple() +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo)); + InfeedTokenPropagation itp; + TF_ASSERT_OK_AND_ASSIGN(bool changed, itp.Run(module.get())); + EXPECT_FALSE(changed); +} + +TEST_F(InfeedTokenPropagationTest, ConditionalInfeed) { + constexpr std::string_view hlo = R"( +HloModule main + +true_comp { + arg.0 = () parameter(0) + token.0 = after-all() + infeed.0 = (s32[], token[]) infeed(token.0) + ROOT tuple.0 = tuple() +} + +false_comp { + arg.0 = () parameter(0) + ROOT tuple.0 = tuple() +} + +ENTRY main { + pred.0 = pred[] constant(true) + true_tuple.0 = tuple() + false_tuple.0 = tuple() + ROOT cond.0 = () conditional(pred.0, true_tuple.0, false_tuple.0), true_computation=true_comp, false_computation=false_comp +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo)); + InfeedTokenPropagation itp; + TF_ASSERT_OK_AND_ASSIGN(bool changed, itp.Run(module.get())); + EXPECT_TRUE(changed); + + // The infeed output token should have propagated through the conditional. + HloInstruction* cond = FindInstruction(module.get(), "cond.0"); + EXPECT_EQ(cond->shape().tuple_shapes_size(), 1); + EXPECT_TRUE(cond->shape().tuple_shapes()[0].IsToken()); + + // The infeed input token should have propagated through the true tuple. + HloInstruction* true_tuple = FindInstruction(module.get(), "true_tuple.0"); + EXPECT_EQ(true_tuple->shape().tuple_shapes_size(), 1); + EXPECT_TRUE(true_tuple->shape().tuple_shapes()[0].IsToken()); + + // The infeed input token should not have propagated through the false tuple. + HloInstruction* false_tuple = FindInstruction(module.get(), "false_tuple.0"); + EXPECT_EQ(false_tuple->shape().tuple_shapes_size(), 0); + + // The infeed output token should have propagated through the true + // computation's root. + HloComputation* true_comp = FindComputation(module.get(), "true_comp"); + EXPECT_THAT(true_comp->root_instruction(), + op::Tuple(op::GetTupleElement(op::Infeed(), 1))); + + // The infeed output token should have propagated to the false computation's + // root. + HloComputation* false_comp = FindComputation(module.get(), "false_comp"); + EXPECT_THAT(false_comp->root_instruction(), op::Tuple(op::AfterAll())); +} + +TEST_F(InfeedTokenPropagationTest, ConditionalOutfeed) { + constexpr std::string_view hlo = R"( +HloModule main + +true_comp { + arg.0 = (s32[]) parameter(0) + token.0 = after-all() + outfeed.0 = token[] outfeed(arg.0, token.0), outfeed_shape=(s32[]) + ROOT tuple.0 = tuple() +} + +false_comp { + arg.0 = () parameter(0) + ROOT tuple.0 = tuple() +} + +ENTRY main { + arg.0 = s32[] parameter(0) + pred.0 = pred[] constant(true) + true_tuple.0 = tuple(arg.0) + false_tuple.0 = tuple() + ROOT cond.0 = () conditional(pred.0, true_tuple.0, false_tuple.0), true_computation=true_comp, false_computation=false_comp +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo)); + InfeedTokenPropagation itp; + TF_ASSERT_OK_AND_ASSIGN(bool changed, itp.Run(module.get())); + EXPECT_TRUE(changed); + + // The outfeed output token should have propagated through the conditional. + HloInstruction* cond = FindInstruction(module.get(), "cond.0"); + EXPECT_EQ(cond->shape().tuple_shapes_size(), 1); + EXPECT_TRUE(cond->shape().tuple_shapes()[0].IsToken()); + + // The outfeed input token should have propagated through the true tuple. + HloInstruction* true_tuple = FindInstruction(module.get(), "true_tuple.0"); + EXPECT_EQ(true_tuple->shape().tuple_shapes_size(), 2); + EXPECT_TRUE(true_tuple->shape().tuple_shapes()[1].IsToken()); + + // The outfeed input token should not have propagated through the false tuple. + HloInstruction* false_tuple = FindInstruction(module.get(), "false_tuple.0"); + EXPECT_EQ(false_tuple->shape().tuple_shapes_size(), 0); + + // The outfeed output token should have propagated through the true + // computation's root. + HloComputation* true_comp = FindComputation(module.get(), "true_comp"); + EXPECT_THAT(true_comp->root_instruction(), op::Tuple(op::Outfeed())); + + // The outfeed output token should have propagated to the false computation's + // root. + HloComputation* false_comp = FindComputation(module.get(), "false_comp"); + EXPECT_THAT(false_comp->root_instruction(), op::Tuple(op::AfterAll())); +} + +TEST_F(InfeedTokenPropagationTest, NonTupleConditional) { + constexpr std::string_view hlo = R"( +HloModule main + +true_comp { + arg.0 = s32[] parameter(0) + outfeed_tuple.0 = tuple(arg.0) + token.0 = after-all() + outfeed.0 = token[] outfeed(outfeed_tuple.0, token.0), outfeed_shape=(s32[]) + ROOT tuple.0 = tuple() +} + +false_comp { + arg.0 = () parameter(0) + ROOT tuple.0 = tuple() +} + +ENTRY main { + arg.0 = s32[] parameter(0) + pred.0 = pred[] constant(true) + false_tuple.0 = tuple() + ROOT cond.0 = () conditional(pred.0, arg.0, false_tuple.0), true_computation=true_comp, false_computation=false_comp +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo)); + InfeedTokenPropagation itp; + TF_ASSERT_OK_AND_ASSIGN(bool changed, itp.Run(module.get())); + EXPECT_TRUE(changed); + + // The outfeed output token should have propagated through the conditional. + HloInstruction* cond = FindInstruction(module.get(), "cond.0"); + EXPECT_EQ(cond->shape().tuple_shapes_size(), 1); + EXPECT_TRUE(cond->shape().tuple_shapes()[0].IsToken()); + + // The outfeed input token should have propagated through the true tuple. + HloInstruction* true_tuple = cond->mutable_operand(1); + EXPECT_TRUE(true_tuple->shape().IsTuple()); + EXPECT_EQ(true_tuple->shape().tuple_shapes_size(), 2); + EXPECT_TRUE(true_tuple->shape().tuple_shapes()[1].IsToken()); + + // The outfeed input token should not have propagated through the false tuple. + HloInstruction* false_tuple = FindInstruction(module.get(), "false_tuple.0"); + EXPECT_EQ(false_tuple->shape().tuple_shapes_size(), 0); + + // The outfeed output token should have propagated through the true + // computation's root. + HloComputation* true_comp = FindComputation(module.get(), "true_comp"); + EXPECT_THAT(true_comp->root_instruction(), op::Tuple(op::Outfeed())); + + // The outfeed output token should have propagated to the false computation's + // root. + HloComputation* false_comp = FindComputation(module.get(), "false_comp"); + EXPECT_THAT(false_comp->root_instruction(), op::Tuple(op::AfterAll())); +} + +TEST_F(InfeedTokenPropagationTest, DisjointConditionalOutfeed) { + constexpr std::string_view hlo = R"( +HloModule main + +true_comp { + ROOT arg.0 = () parameter(0) + one.0 = s32[] constant(1) + outfeed_tuple.0 = tuple(one.0) + token.0 = after-all() + outfeed.0 = token[] outfeed(outfeed_tuple.0, token.0), outfeed_shape=(s32[]) +} + +false_comp { + arg.0 = () parameter(0) + ROOT tuple.0 = tuple() +} + +ENTRY main { + pred.0 = pred[] constant(true) + true_tuple.0 = tuple() + false_tuple.0 = tuple() + ROOT cond.0 = () conditional(pred.0, true_tuple.0, false_tuple.0), true_computation=true_comp, false_computation=false_comp +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo)); + InfeedTokenPropagation itp; + TF_ASSERT_OK_AND_ASSIGN(bool changed, itp.Run(module.get())); + EXPECT_TRUE(changed); + + // The outfeed output token should have propagated through the conditional. + HloInstruction* cond = FindInstruction(module.get(), "cond.0"); + EXPECT_EQ(cond->shape().tuple_shapes_size(), 1); + EXPECT_TRUE(cond->shape().tuple_shapes()[0].IsToken()); + + // The outfeed input token should have propagated through the true tuple. + HloInstruction* true_tuple = FindInstruction(module.get(), "true_tuple.0"); + EXPECT_EQ(true_tuple->shape().tuple_shapes_size(), 1); + EXPECT_TRUE(true_tuple->shape().tuple_shapes()[0].IsToken()); + + // The outfeed input token should not have propagated through the false tuple. + HloInstruction* false_tuple = FindInstruction(module.get(), "false_tuple.0"); + EXPECT_EQ(false_tuple->shape().tuple_shapes_size(), 0); + + // The outfeed output token should have propagated through the true + // computation's root. + HloComputation* true_comp = FindComputation(module.get(), "true_comp"); + EXPECT_THAT(true_comp->root_instruction(), op::Tuple(op::Outfeed())); + + // The outfeed output token should have propagated to the false computation's + // root. + HloComputation* false_comp = FindComputation(module.get(), "false_comp"); + EXPECT_THAT(false_comp->root_instruction(), op::Tuple(op::AfterAll())); +} + +TEST_F(InfeedTokenPropagationTest, WhileInfeed) { + constexpr std::string_view hlo = R"( +HloModule main + +comp { + arg.0 = () parameter(0) + token.0 = after-all() + infeed.0 = (s32[], token[]) infeed(token.0) + ROOT tuple.0 = tuple() +} + +cond { + arg.0 = () parameter(0) + ROOT true.0 = pred[] constant(true) +} + +ENTRY main { + while_tuple.0 = tuple() + ROOT while.0 = () while(while_tuple.0), condition=cond, body=comp +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo)); + InfeedTokenPropagation itp; + TF_ASSERT_OK_AND_ASSIGN(bool changed, itp.Run(module.get())); + EXPECT_TRUE(changed); + + // The infeed output token should have propagated through the loop. + HloInstruction* loop = FindInstruction(module.get(), "while.0"); + EXPECT_EQ(loop->shape().tuple_shapes_size(), 1); + EXPECT_TRUE(loop->shape().tuple_shapes()[0].IsToken()); + + // The infeed input token should have propagated through the loop tuple. + HloInstruction* loop_tuple = FindInstruction(module.get(), "while_tuple.0"); + EXPECT_EQ(loop_tuple->shape().tuple_shapes_size(), 1); + EXPECT_TRUE(loop_tuple->shape().tuple_shapes()[0].IsToken()); + + // The infeed output token should have propagated through the while body root. + HloComputation* body_comp = FindComputation(module.get(), "comp"); + EXPECT_THAT(body_comp->root_instruction(), + op::Tuple(op::GetTupleElement(op::Infeed(), 1))); + + // The infeed input token should have propagated through the body parameter. + HloInstruction* body_param = body_comp->parameter_instruction(0); + EXPECT_EQ(body_param->shape().tuple_shapes_size(), 1); + EXPECT_TRUE(body_param->shape().tuple_shapes()[0].IsToken()); + + // The infeed input token should have propagated through the condition + // parameter. + HloComputation* cond_comp = FindComputation(module.get(), "cond"); + HloInstruction* cond_param = cond_comp->parameter_instruction(0); + EXPECT_EQ(cond_param->shape().tuple_shapes_size(), 1); + EXPECT_TRUE(cond_param->shape().tuple_shapes()[0].IsToken()); +} + +TEST_F(InfeedTokenPropagationTest, WhileOutfeed) { + constexpr std::string_view hlo = R"( +HloModule main + +comp { + arg.0 = (s32[]) parameter(0) + token.0 = after-all() + outfeed.0 = token[] outfeed(arg.0, token.0), outfeed_shape=(s32[]) + gte.0 = get-tuple-element(arg.0), index=0 + ROOT tuple.0 = tuple(gte.0) +} + +cond { + arg.0 = (s32[]) parameter(0) + ROOT true.0 = pred[] constant(true) +} + +ENTRY main { + arg.0 = s32[] parameter(0) + while_tuple.0 = tuple(arg.0) + ROOT while.0 = (s32[]) while(while_tuple.0), condition=cond, body=comp +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo)); + InfeedTokenPropagation itp; + TF_ASSERT_OK_AND_ASSIGN(bool changed, itp.Run(module.get())); + EXPECT_TRUE(changed); + + // The outfeed output token should have propagated through the loop. + HloInstruction* loop = FindInstruction(module.get(), "while.0"); + EXPECT_EQ(loop->shape().tuple_shapes_size(), 2); + EXPECT_TRUE(loop->shape().tuple_shapes()[1].IsToken()); + + // The outfeed input token should have propagated through the loop tuple. + HloInstruction* loop_tuple = FindInstruction(module.get(), "while_tuple.0"); + EXPECT_EQ(loop_tuple->shape().tuple_shapes_size(), 2); + EXPECT_TRUE(loop_tuple->shape().tuple_shapes()[1].IsToken()); + + // The outfeed output token should have propagated through the while body + // root. + HloComputation* body_comp = FindComputation(module.get(), "comp"); + EXPECT_THAT(body_comp->root_instruction(), + op::Tuple(op::GetTupleElement(), op::Outfeed())); + + // The outfeed output token should have propagated through the body parameter. + HloInstruction* body_param = body_comp->parameter_instruction(0); + EXPECT_EQ(body_param->shape().tuple_shapes_size(), 2); + EXPECT_TRUE(body_param->shape().tuple_shapes()[1].IsToken()); + + // The outfeed output token should have propagated through the condition + // parameter. + HloComputation* cond_comp = FindComputation(module.get(), "cond"); + HloInstruction* cond_param = cond_comp->parameter_instruction(0); + EXPECT_EQ(cond_param->shape().tuple_shapes_size(), 2); + EXPECT_TRUE(cond_param->shape().tuple_shapes()[1].IsToken()); +} + +TEST_F(InfeedTokenPropagationTest, DisjointWhileOutfeed) { + constexpr std::string_view hlo = R"( +HloModule main + +comp { + ROOT arg.0 = () parameter(0) + one.0 = s32[] constant(1) + outfeed_tuple.0 = tuple(one.0) + token.0 = after-all() + outfeed.0 = token[] outfeed(outfeed_tuple.0, token.0), outfeed_shape=(s32[]) +} + +cond { + arg.0 = () parameter(0) + ROOT true.0 = pred[] constant(true) +} + +ENTRY main { + while_tuple.0 = tuple() + ROOT while.0 = () while(while_tuple.0), condition=cond, body=comp +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo)); + InfeedTokenPropagation itp; + TF_ASSERT_OK_AND_ASSIGN(bool changed, itp.Run(module.get())); + EXPECT_TRUE(changed); + + // The outfeed output token should have propagated through the loop. + HloInstruction* loop = FindInstruction(module.get(), "while.0"); + EXPECT_EQ(loop->shape().tuple_shapes_size(), 1); + EXPECT_TRUE(loop->shape().tuple_shapes()[0].IsToken()); + + // The outfeed input token should have propagated through the loop tuple. + HloInstruction* loop_tuple = FindInstruction(module.get(), "while_tuple.0"); + EXPECT_EQ(loop_tuple->shape().tuple_shapes_size(), 1); + EXPECT_TRUE(loop_tuple->shape().tuple_shapes()[0].IsToken()); + + // The outfeed output token should have propagated through the while body + // root. + HloComputation* body_comp = FindComputation(module.get(), "comp"); + EXPECT_THAT(body_comp->root_instruction(), op::Tuple(op::Outfeed())); + + // The outfeed output token should have propagated through the body parameter. + HloInstruction* body_param = body_comp->parameter_instruction(0); + EXPECT_EQ(body_param->shape().tuple_shapes_size(), 1); + EXPECT_TRUE(body_param->shape().tuple_shapes()[0].IsToken()); + + // The outfeed output token should have propagated through the condition + // parameter. + HloComputation* cond_comp = FindComputation(module.get(), "cond"); + HloInstruction* cond_param = cond_comp->parameter_instruction(0); + EXPECT_EQ(cond_param->shape().tuple_shapes_size(), 1); + EXPECT_TRUE(cond_param->shape().tuple_shapes()[0].IsToken()); +} + +TEST_F(InfeedTokenPropagationTest, NonTupleWhile) { + constexpr std::string_view hlo = R"( +HloModule main + +comp { + ROOT arg.0 = s32[] parameter(0) + tuple.0 = tuple(arg.0) + token.0 = after-all() + outfeed.0 = token[] outfeed(tuple.0, token.0), outfeed_shape=(s32[]) +} + +cond { + arg.0 = s32[] parameter(0) + ROOT true.0 = pred[] constant(true) +} + +ENTRY main { + arg.0 = s32[] parameter(0) + ROOT while.0 = s32[] while(arg.0), condition=cond, body=comp +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo)); + InfeedTokenPropagation itp; + TF_ASSERT_OK_AND_ASSIGN(bool changed, itp.Run(module.get())); + EXPECT_TRUE(changed); + + // The outfeed output token should have propagated through the loop. + HloInstruction* loop = FindInstruction(module.get(), "while.0"); + EXPECT_TRUE(loop->shape().IsTuple()); + EXPECT_EQ(loop->shape().tuple_shapes_size(), 2); + EXPECT_TRUE(loop->shape().tuple_shapes()[1].IsToken()); + + // The outfeed input token should have propagated through the loop tuple. + EXPECT_THAT(loop->operand(0), op::Tuple(op::Parameter(), op::AfterAll())); + + // The outfeed output token should have propagated through the while body + // root. + HloComputation* body_comp = FindComputation(module.get(), "comp"); + EXPECT_THAT(body_comp->root_instruction(), + op::Tuple(op::GetTupleElement(), op::Outfeed())); + + // The outfeed output token should have propagated through the body parameter. + HloInstruction* body_param = body_comp->parameter_instruction(0); + EXPECT_EQ(body_param->shape().tuple_shapes_size(), 2); + EXPECT_TRUE(body_param->shape().tuple_shapes()[1].IsToken()); + + // The outfeed output token should have propagated through the condition + // parameter. + HloComputation* cond_comp = FindComputation(module.get(), "cond"); + HloInstruction* cond_param = cond_comp->parameter_instruction(0); + EXPECT_EQ(cond_param->shape().tuple_shapes_size(), 2); + EXPECT_TRUE(cond_param->shape().tuple_shapes()[1].IsToken()); +} + +TEST_F(InfeedTokenPropagationTest, NestedInfeedOutfeed) { + constexpr std::string_view hlo = R"( +HloModule main + +true_comp { + arg.0 = (s32[]) parameter(0) + token.0 = after-all() + outfeed.0 = token[] outfeed(arg.0, token.0), outfeed_shape=(s32[]) + ROOT tuple.0 = tuple() +} + +false_comp { + arg.0 = () parameter(0) + ROOT tuple.0 = tuple() +} + +comp { + arg.0 = () parameter(0) + token.0 = after-all() + infeed.0 = (s32[], token[]) infeed(token.0) + gte.0 = get-tuple-element(infeed.0), index=0 + pred.0 = pred[] constant(true) + true_tuple.0 = tuple(gte.0) + false_tuple.0 = tuple() + cond.0 = () conditional(pred.0, true_tuple.0, false_tuple.0), true_computation=true_comp, false_computation=false_comp + ROOT tuple.0 = tuple() +} + +cond { + arg.0 = () parameter(0) + ROOT true.0 = pred[] constant(true) +} + +ENTRY main { + while_tuple.0 = tuple() + ROOT while.0 = () while(while_tuple.0), condition=cond, body=comp +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo)); + InfeedTokenPropagation itp; + TF_ASSERT_OK_AND_ASSIGN(bool changed, itp.Run(module.get())); + EXPECT_TRUE(changed); + + // The infeed and outfeed output tokens should have propagated through the + // loop. + HloInstruction* loop = FindInstruction(module.get(), "while.0"); + EXPECT_EQ(loop->shape().tuple_shapes_size(), 2); + EXPECT_TRUE(loop->shape().tuple_shapes()[0].IsToken()); + EXPECT_TRUE(loop->shape().tuple_shapes()[1].IsToken()); + + // The infeed and outfeed input tokens should have propagated through the loop + // tuple. + HloInstruction* loop_tuple = FindInstruction(module.get(), "while_tuple.0"); + EXPECT_EQ(loop_tuple->shape().tuple_shapes_size(), 2); + EXPECT_TRUE(loop_tuple->shape().tuple_shapes()[0].IsToken()); + EXPECT_TRUE(loop_tuple->shape().tuple_shapes()[1].IsToken()); + + // The infeed and outfeed output tokens should have propagated through the + // while body root. + HloComputation* body_comp = FindComputation(module.get(), "comp"); + EXPECT_THAT(body_comp->root_instruction(), + op::Tuple(op::GetTupleElement(op::Infeed(), 1), + op::GetTupleElement(op::Conditional(), 0))); + + // The outfeed output token should have propagated through the conditional. + HloInstruction* cond = FindInstruction(module.get(), "cond.0"); + EXPECT_EQ(cond->shape().tuple_shapes_size(), 1); + EXPECT_TRUE(cond->shape().tuple_shapes()[0].IsToken()); + + // The outfeed input token should have propagated through the true tuple. + HloInstruction* true_tuple = FindInstruction(module.get(), "true_tuple.0"); + EXPECT_EQ(true_tuple->shape().tuple_shapes_size(), 2); + EXPECT_TRUE(true_tuple->shape().tuple_shapes()[1].IsToken()); + + // The outfeed input token should not have propagated through the false tuple. + HloInstruction* false_tuple = FindInstruction(module.get(), "false_tuple.0"); + EXPECT_EQ(false_tuple->shape().tuple_shapes_size(), 0); + + // The outfeed output token should have propagated through the true + // computation's root. + HloComputation* true_comp = FindComputation(module.get(), "true_comp"); + EXPECT_THAT(true_comp->root_instruction(), op::Tuple(op::Outfeed())); + + // The outfeed output token should have propagated to the false computation's + // root. + HloComputation* false_comp = FindComputation(module.get(), "false_comp"); + EXPECT_THAT(false_comp->root_instruction(), op::Tuple(op::AfterAll())); +} +} // namespace +} // namespace xla diff --git a/third_party/xla/xla/service/while_loop_invariant_code_motion.cc b/third_party/xla/xla/service/while_loop_invariant_code_motion.cc index ed44547af3fca4..b1aae51df132e9 100644 --- a/third_party/xla/xla/service/while_loop_invariant_code_motion.cc +++ b/third_party/xla/xla/service/while_loop_invariant_code_motion.cc @@ -232,6 +232,7 @@ WhileLoopInvariantCodeMotion::TryHoistingInvariantInstructionsFromWhileBody( } if (instruction->HasSideEffect() || + instruction->opcode() == HloOpcode::kAfterAll || instruction->opcode() == HloOpcode::kParameter || !instruction->control_predecessors().empty() || !instruction->control_successors().empty()) { From 660d0a908f89309559a782367709e83356a6f8ab Mon Sep 17 00:00:00 2001 From: Zixuan Jiang Date: Sun, 22 Sep 2024 20:01:15 -0700 Subject: [PATCH 111/483] Automated g4 rollback of changelist 676140549. *** Reason for rollback *** Test breakage Reverts 3528f87b5cbda4774d6098caa67a2591d8a35764 PiperOrigin-RevId: 677597455 --- .../auto_sharding_dot_handler.cc | 8 +- .../xla/xla/service/sharding_propagation.cc | 153 +++++++++++------- .../xla/xla/service/sharding_propagation.h | 9 +- .../xla/service/sharding_propagation_test.cc | 108 ++----------- 4 files changed, 117 insertions(+), 161 deletions(-) diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc index 1916172aa70901..d5d008ee19442d 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc @@ -438,13 +438,13 @@ std::optional HandlerBase::GetShardingFromUser( CHECK_OK(ins_clone->ReplaceOperandWith(1, rhs_clone.get())); if (ins_->opcode() == HloOpcode::kConvolution) { xla::InferConvolutionShardingFromOperands( - ins_clone.get(), /* aggressiveness */ 10, - /* may_combine_partial_sharding */ true); + ins_clone.get(), call_graph_, 10, + /* may_combine_partial_sharding */ true, /* is_spmd */ true); } else { xla::InferDotShardingFromOperands( - ins_clone.get(), + ins_clone.get(), call_graph_, dot_as_convolution_util::ParseDotGeneralFromDot(ins_clone.get()), - /* aggressiveness */ 10, /* may_combine_partial_sharding */ true); + /* may_combine_partial_sharding/ */ true, /* is_spmd */ true); } if (!ins_clone->has_sharding()) { return std::nullopt; diff --git a/third_party/xla/xla/service/sharding_propagation.cc b/third_party/xla/xla/service/sharding_propagation.cc index 886a3d63a90aed..3ffd85687793e5 100644 --- a/third_party/xla/xla/service/sharding_propagation.cc +++ b/third_party/xla/xla/service/sharding_propagation.cc @@ -16,10 +16,10 @@ limitations under the License. #include "xla/service/sharding_propagation.h" #include -#include #include #include #include +#include #include #include #include @@ -36,12 +36,10 @@ limitations under the License. #include "absl/log/check.h" #include "absl/status/status.h" #include "absl/strings/str_join.h" -#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/array.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" -#include "xla/hlo/ir/hlo_domain_metadata.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" @@ -49,7 +47,6 @@ limitations under the License. #include "xla/hlo/ir/hlo_sharding_metadata.h" #include "xla/hlo/utils/hlo_sharding_util.h" #include "xla/protobuf_util.h" -#include "xla/service/call_graph.h" #include "xla/service/dot_as_convolution_util.h" #include "xla/service/host_memory_offload_annotations.h" #include "xla/service/spmd/shard_barrier_partitioner.h" @@ -419,6 +416,55 @@ bool SupportSpatialPartitioning( } } +// Helper to lookahead sharding of user of an instruction to be used as guidance +// for ambiguous cases. +std::optional LookaheadUserSharding(HloInstruction* instr, + bool is_spmd, + const CallGraph& call_graph) { + if (instr->user_count() != 1) { + return std::nullopt; + } + HloInstruction* current_user = instr->users()[0]; + std::optional sharding; + std::vector users_chain = {instr, current_user}; + // Collect single user instructions along the way. + while (!current_user->has_sharding()) { + // Only consider single user chains. + if (current_user->users().size() != 1) { + users_chain.clear(); + break; + } + current_user = current_user->users()[0]; + users_chain.push_back(current_user); + } + // Early exit for unsupported cases. + if (users_chain.empty()) { + return std::nullopt; + } + for (int i = users_chain.size() - 1; i >= 1; --i) { + HloInstruction* user = users_chain[i]; + HloInstruction* current = users_chain[i - 1]; + CHECK(user->has_sharding()); + sharding = ShardingPropagation::GetShardingFromUser( + *current, *user, INT64_MAX, is_spmd, call_graph, + /*sharding_helper=*/nullptr); + // We need to set the sharding to the instruction, because + // GetShardingFromUser() interface uses sharding from the instruction + // itself. It will be cleared out later. + if (sharding.has_value() && i != 1) { + current->set_sharding(*sharding); + continue; + } + break; + } + // Clear the sharding of the middle instructions we set the sharding of + // because they were unsharded. + for (int i = 1; i < users_chain.size() - 1; ++i) { + users_chain[i]->clear_sharding(); + } + return sharding; +} + // Infer output sharding on index parallel dimensions for gather from operand // and indices. bool InferGatherParallelShardingFromOperands( @@ -1023,9 +1069,9 @@ bool IsCSEPreventionSharding(const HloSharding& sharding) { } // namespace bool InferDotShardingFromOperands( - HloInstruction* instruction, + HloInstruction* instruction, const CallGraph& call_graph, const dot_as_convolution_util::DotConvolutionDimsInfo& dnums, - int64_t aggressiveness, bool may_combine_partial_sharding) { + bool may_combine_partial_sharding, bool is_spmd) { auto from_operand = [&](int64_t operand_index) { auto operand = instruction->operand(operand_index); const HloSharding& operand_sharding = operand->sharding(); @@ -1080,66 +1126,55 @@ bool InferDotShardingFromOperands( from_operand(1), instruction, may_combine_partial_sharding, /*allow_aggressive_resharding=*/false); } - - // Four cases based on if improved_operand_0 and improved_operand_1 are - // available. - // Case 0. Both operands have no improved sharding. + // If not improved sharding found then do not set any sharding. if (!improved_operand_0.has_value() && !improved_operand_1.has_value()) { return false; } - // Case 1. Sharding found from operand 0 but not operand 1. Set sharding from - // operand 0. + // Sharding found from operand 0 but not operand 1. Set sharding from operand + // 0 if (improved_operand_0.has_value() && !improved_operand_1.has_value()) { instruction->set_sharding(*improved_operand_0); return true; } - // Case 2. Sharding found from operand 1 but not operand 0. Set sharding from - // operand 1. + // Sharding found from operand 1 but not operand 0. Set sharding from operand + // 1 if (!improved_operand_0.has_value() && improved_operand_1.has_value()) { instruction->set_sharding(*improved_operand_1); return true; } - // Case 3. Both operands have improved shardings. CHECK(improved_operand_0.has_value() && improved_operand_1.has_value()); - - // If one of the improved shardings is a sub-tiling or equal to the other, use - // the better sharding with more tiles. - if (hlo_sharding_util::IsSubTilingOrEqualSharding( - instruction->shape(), *improved_operand_0, *improved_operand_1)) { - instruction->set_sharding(*improved_operand_0); - return true; - } - if (hlo_sharding_util::IsSubTilingOrEqualSharding( - instruction->shape(), *improved_operand_1, *improved_operand_0)) { - instruction->set_sharding(*improved_operand_1); - return true; - } - - // If the two improved shardings are mergeable, there is no conflict. - if (std::optional improved_sharding = - hlo_sharding_util::ReturnImprovedShardingImpl( - *improved_operand_0, &improved_operand_1.value(), - instruction->shape(), may_combine_partial_sharding, - /*allow_aggressive_resharding=*/false)) { - instruction->set_sharding(*improved_sharding); - return true; - } - - if (aggressiveness < 3) { - // We can improve the dot with different shardings. Pause the propagation - // and wait for the winner between the two operands. - return false; - } - - // The two improved sharding are different and we are at the highest - // aggressiveness. Prioritize the operand with larger size. + std::optional lookahead_sharding = + LookaheadUserSharding(instruction, is_spmd, call_graph); std::array sharding_priority = {*improved_operand_0, *improved_operand_1}; - if (ShapeUtil::ByteSizeOf(instruction->operand(0)->shape()) < - ShapeUtil::ByteSizeOf(instruction->operand(1)->shape())) { + bool priority_defined_with_lookahead = false; + // Found sharding from lookahead. + if (lookahead_sharding.has_value()) { + const bool operand_0_is_lookahead_subtiling = + hlo_sharding_util::IsSubTilingOrEqualSharding( + instruction->shape(), *lookahead_sharding, *improved_operand_0); + const bool operand_1_is_lookahead_subtiling = + hlo_sharding_util::IsSubTilingOrEqualSharding( + instruction->shape(), *lookahead_sharding, *improved_operand_1); + // If the sharding from operand 0 is a subtiling of the user, but not the + // one from operand 1 prioritize that sharding. + if (operand_0_is_lookahead_subtiling && !operand_1_is_lookahead_subtiling) { + priority_defined_with_lookahead = true; + } + // If the sharding from operand 1 is a subtiling of the user, but not the + // one from operand 0 prioritize that sharding. + if (!operand_0_is_lookahead_subtiling && operand_1_is_lookahead_subtiling) { + instruction->set_sharding(*improved_operand_1); + std::swap(sharding_priority[0], sharding_priority[1]); + priority_defined_with_lookahead = true; + } + } + // If lookahead didn't define a priority then use size. + if (!priority_defined_with_lookahead && + ShapeUtil::ByteSizeOf(instruction->operand(0)->shape()) < + ShapeUtil::ByteSizeOf(instruction->operand(1)->shape())) { std::swap(sharding_priority[0], sharding_priority[1]); } - // Set primary sharding to the instruction and then try to improve it with // the secondary sharding. instruction->set_sharding(sharding_priority[0]); @@ -1150,8 +1185,10 @@ bool InferDotShardingFromOperands( // Convolution handling for InferShardingFromOperands(). bool InferConvolutionShardingFromOperands(HloInstruction* instruction, + const CallGraph& call_graph, int64_t aggressiveness, - bool may_combine_partial_sharding) { + bool may_combine_partial_sharding, + bool is_spmd) { auto get_partitions_for_dims = [&](const HloInstruction* inst, absl::Span< @@ -1186,8 +1223,8 @@ bool InferConvolutionShardingFromOperands(HloInstruction* instruction, (lhs_conv_spatial_partitions == 1 && rhs_conv_spatial_partitions == 1 && instruction->batch_group_count() == 1 && instruction->feature_group_count() == 1)) { - return InferDotShardingFromOperands(instruction, dot_dims, aggressiveness, - may_combine_partial_sharding); + return InferDotShardingFromOperands(instruction, call_graph, dot_dims, + may_combine_partial_sharding, is_spmd); } const auto& dnums = instruction->convolution_dimension_numbers(); const HloInstruction* lhs = instruction->operand(0); @@ -2290,8 +2327,9 @@ bool ShardingPropagation::InferShardingFromOperands( 1); } case HloOpcode::kConvolution: - return InferConvolutionShardingFromOperands(instruction, aggressiveness, - may_combine_partial_sharding); + return InferConvolutionShardingFromOperands( + instruction, call_graph, aggressiveness, may_combine_partial_sharding, + is_spmd_); case HloOpcode::kTranspose: { const HloInstruction* input = instruction->operand(0); if (!hlo_sharding_util::IsSpatiallyPartitioned(input)) { @@ -2380,8 +2418,9 @@ bool ShardingPropagation::InferShardingFromOperands( case HloOpcode::kDot: { const auto& dnums = dot_as_convolution_util::ParseDotGeneralFromDot(instruction); - return InferDotShardingFromOperands(instruction, dnums, aggressiveness, - may_combine_partial_sharding); + return InferDotShardingFromOperands(instruction, call_graph, dnums, + may_combine_partial_sharding, + is_spmd_); } case HloOpcode::kParameter: { auto parent_it = computation_map.find(instruction->parent()); diff --git a/third_party/xla/xla/service/sharding_propagation.h b/third_party/xla/xla/service/sharding_propagation.h index 2654a1fd7d335b..27cef820977436 100644 --- a/third_party/xla/xla/service/sharding_propagation.h +++ b/third_party/xla/xla/service/sharding_propagation.h @@ -16,7 +16,6 @@ limitations under the License. #ifndef XLA_SERVICE_SHARDING_PROPAGATION_H_ #define XLA_SERVICE_SHARDING_PROPAGATION_H_ -#include #include #include #include @@ -36,15 +35,17 @@ namespace xla { // Infers the shardings for a dot HLO op from the shardings on its operands, // which are expected to have sharding annotations. bool InferDotShardingFromOperands( - HloInstruction* instruction, + HloInstruction* instruction, const CallGraph& call_graph, const dot_as_convolution_util::DotConvolutionDimsInfo& dnums, - int64_t aggressiveness, bool may_combine_partial_sharding); + bool may_combine_partial_sharding, bool is_spmd); // Infers the shardings for a convolution HLO op from the shardings on its // operands, which are expected to have sharding annotations. bool InferConvolutionShardingFromOperands(HloInstruction* instruction, + const CallGraph& call_graph, int64_t aggressiveness, - bool may_combine_partial_sharding); + bool may_combine_partial_sharding, + bool is_spmd); // Remove Sharding custom-call instruction by folding the sharding attribute // to its operand. If the operand already has a different sharding, insert a diff --git a/third_party/xla/xla/service/sharding_propagation_test.cc b/third_party/xla/xla/service/sharding_propagation_test.cc index 565314d9150e33..5ca4b47d8ea15c 100644 --- a/third_party/xla/xla/service/sharding_propagation_test.cc +++ b/third_party/xla/xla/service/sharding_propagation_test.cc @@ -3324,7 +3324,7 @@ ENTRY %conv { EXPECT_THAT(instruction, op::Sharding("{devices=[2,2,2]0,1,2,3,4,5,6,7}")); if (GetParam().propagate_metadata && !GetParam().clear_metadata) { EXPECT_THAT(instruction->sharding(), - ShardingMetadata({CreateMetadata("a"), CreateMetadata("b")})); + ShardingMetadata({CreateMetadata("b"), CreateMetadata("a")})); } else { EXPECT_THAT(instruction->sharding(), ShardingMetadata({})); } @@ -3396,7 +3396,7 @@ ENTRY %conv { EXPECT_THAT(instruction, op::Sharding("{devices=[2,4]0,2,3,1,4,6,7,5}")); if (GetParam().propagate_metadata && !GetParam().clear_metadata) { EXPECT_THAT(instruction->sharding(), - ShardingMetadata({CreateMetadata("a"), CreateMetadata("b")})); + ShardingMetadata({CreateMetadata("b"), CreateMetadata("a")})); } else { EXPECT_THAT(instruction->sharding(), ShardingMetadata({})); } @@ -11863,7 +11863,7 @@ ENTRY main.9 { op::Sharding("{{devices=[4]<=[4]}, {devices=[4]<=[4]}}")); } -TEST_F(ShardingPropagationTest, InferDotShardingFromOperands1) { +TEST_F(ShardingPropagationTest, LookaheadUsersOfDot) { const char* const hlo_string = R"( HloModule module @@ -11880,108 +11880,24 @@ ENTRY %entry { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_string)); TF_ASSERT_OK_AND_ASSIGN( - bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get())); + bool changed, + ShardingPropagation( + /*is_spmd=*/true, /*propagate_metadata=*/true, + /*allow_spmd_sharding_propagation_to_output=*/{true}, + /*allow_spmd_sharding_propagation_to_parameters=*/{true}) + .Run(module.get())); EXPECT_TRUE(changed); XLA_VLOG_LINES(1, module->ToString()); + // Check dangling sharding custom-call can be removed by DCE after + // propagation. auto* instruction = FindInstruction(module.get(), "dot.1"); + // Check sharding is correctly propagated. EXPECT_THAT(instruction, op::Sharding( "{devices=[4,4,1,4]<=[4,16]T(1,0) last_tile_dim_replicate}")); } -TEST_F(ShardingPropagationTest, InferDotShardingFromOperands2) { - const char* const hlo_string = R"( -HloModule module - -ENTRY %entry { - p0 = bf16[16,32] parameter(0), sharding={devices=[16,1]<=[16]} - p1 = bf16[32,64] parameter(1), sharding={devices=[1,16]<=[16]} - dot = bf16[16,64] dot(p0, p1), lhs_contracting_dims={1}, rhs_contracting_dims={0} - ROOT copy = bf16[16,64] copy(dot), sharding={replicated} -})"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); - TF_ASSERT_OK_AND_ASSIGN( - bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get())); - EXPECT_TRUE(changed); - - XLA_VLOG_LINES(1, module->ToString()); - auto* instruction = FindInstruction(module.get(), "dot"); - EXPECT_THAT(instruction, op::Sharding("{devices=[1,16]<=[16]}")); -} - -TEST_F(ShardingPropagationTest, InferDotShardingFromOperands3) { - const char* const hlo_string = R"( -HloModule module - -ENTRY %entry { - p0 = bf16[4,16,32] parameter(0), sharding={devices=[2,4,2]<=[16]} - p1 = bf16[4,32,64] parameter(1), sharding={devices=[2,8,1]<=[16]} - dot = bf16[4,16,64] dot(p0, p1), lhs_batch_dims={0}, rhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_contracting_dims={1} - ROOT copy = bf16[4,16,64] copy(dot) -})"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); - TF_ASSERT_OK_AND_ASSIGN( - bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get())); - EXPECT_TRUE(changed); - - XLA_VLOG_LINES(1, module->ToString()); - auto* instruction = FindInstruction(module.get(), "dot"); - EXPECT_THAT( - instruction, - op::Sharding("{devices=[2,4,1,2]<=[16] last_tile_dim_replicate}")); -} - -TEST_F(ShardingPropagationTest, InferDotShardingFromOperands4) { - const char* const hlo_string = R"( -HloModule module - -ENTRY %entry { - p0 = bf16[4,16,32] parameter(0), sharding={devices=[2,1,8]<=[16]} - p1 = bf16[4,32,64] parameter(1), sharding={devices=[4,1,4]<=[16]} - dot = bf16[4,16,64] dot(p0, p1), lhs_batch_dims={0}, rhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_contracting_dims={1} - ROOT copy = bf16[4,16,64] copy(dot) -})"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); - TF_ASSERT_OK_AND_ASSIGN( - bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get())); - EXPECT_TRUE(changed); - - XLA_VLOG_LINES(1, module->ToString()); - auto* instruction = FindInstruction(module.get(), "dot"); - EXPECT_THAT(instruction, op::Sharding("{devices=[4,1,4]<=[16]}")); -} - -TEST_F(ShardingPropagationTest, InferDotShardingFromOperands5) { - const char* const hlo_string = R"( -HloModule module - -ENTRY %entry { - p0 = bf16[16,16] parameter(0), sharding={devices=[4,4]<=[4,4]T(1,0)} - p1 = bf16[16,16] parameter(1), sharding={devices=[4,4]<=[4,4]T(1,0)} - dot.0 = bf16[16,16] dot(p0, p1), lhs_contracting_dims={1}, rhs_contracting_dims={1} - p2 = bf16[16,16] parameter(2), sharding={devices=[4,4]<=[16]} - p3 = bf16[16,16] parameter(3), sharding={devices=[4,4]<=[16]} - dot.1 = bf16[16,16] dot(p2, p3), lhs_contracting_dims={1}, rhs_contracting_dims={0} - add = bf16[16,16] add(dot.0, dot.1) - ROOT copy = bf16[16,16] copy(add) -})"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); - TF_ASSERT_OK_AND_ASSIGN( - bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get())); - EXPECT_TRUE(changed); - - XLA_VLOG_LINES(1, module->ToString()); - for (absl::string_view name : {"dot.0", "dot.1", "add"}) { - auto* instruction = FindInstruction(module.get(), name); - EXPECT_THAT(instruction, op::Sharding("{devices=[4,4]<=[16]}")); - } -} - TEST_F(ShardingPropagationTest, AsyncInstructionManualShardingArray) { const char* const hlo_string = R"( HloModule module From 92917da034c025930a0b20fb9deefd0c303cb2bc Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sun, 22 Sep 2024 20:52:21 -0700 Subject: [PATCH 112/483] Automated Code Change PiperOrigin-RevId: 677609443 --- .../core/distributed_runtime/rpc/grpc_tensor_coding.cc | 5 +++-- .../core/distributed_runtime/rpc/grpc_tensor_coding_test.cc | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_tensor_coding.cc b/tensorflow/core/distributed_runtime/rpc/grpc_tensor_coding.cc index e7c5d68bc1ca65..33f40b9d39fa63 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_tensor_coding.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_tensor_coding.cc @@ -164,7 +164,8 @@ void EncodeTensorToByteBuffer(bool is_dead, const Tensor& val, bool require_ack, } else { // skeleton is the encoded TensorProto contents (dtype and shape), but // not the actual data - gtl::InlinedVector skeleton(SkeletonEncodingSizeUpperBound(val)); + absl::InlinedVector skeleton( + SkeletonEncodingSizeUpperBound(val)); io::ProtoEncodeHelper e_skeleton(skeleton.data(), skeleton.size()); EncodeSkeleton(val, &e_skeleton); @@ -196,7 +197,7 @@ void EncodeTensorToByteBuffer(bool is_dead, const Tensor& val, bool require_ack, // Encode all but the actual "tdata", but including the tag and // varlength header for the "tdata" - gtl::InlinedVector space(encoder_size); + absl::InlinedVector space(encoder_size); io::ProtoEncodeHelper e(space.data(), space.size()); // (A) e.WriteRawBytes(header); diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_tensor_coding_test.cc b/tensorflow/core/distributed_runtime/rpc/grpc_tensor_coding_test.cc index ef28ad6667291b..f4b36334237a09 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_tensor_coding_test.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_tensor_coding_test.cc @@ -65,7 +65,7 @@ class GrpcTensorCodingTest : public ::testing::Test { } } void DoTestForStrings(DataType dt) { - gtl::InlinedVector v; + absl::InlinedVector v; for (int elems = 0; elems <= 10000; elems++) { if (elems < 100 || (elems % 1000 == 0)) { Tensor a(dt, TensorShape({1, static_cast(v.size())})); From bc33f6b3ae0f8e778daa2e1daf11ca93d2f93d6b Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sun, 22 Sep 2024 22:16:03 -0700 Subject: [PATCH 113/483] Automated Code Change PiperOrigin-RevId: 677627360 --- tensorflow/lite/core/BUILD | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/tensorflow/lite/core/BUILD b/tensorflow/lite/core/BUILD index 0f23c4b599e407..48b296170f00df 100644 --- a/tensorflow/lite/core/BUILD +++ b/tensorflow/lite/core/BUILD @@ -164,10 +164,7 @@ cc_library( "signature_runner.h", ], compatible_with = get_compatible_with_portable(), - visibility = [ - "//research/drishti/benchmarking/async:__subpackages__", - "//tensorflow/lite:__subpackages__", - ], + visibility = ["//tensorflow/lite:__subpackages__"], deps = [ ":model_builder", ":signature_runner", @@ -306,7 +303,6 @@ cc_library( compatible_with = get_compatible_with_portable(), visibility = [ "//tensorflow/lite:__pkg__", - "//tensorflow/lite/c:__subpackages__", "//tensorflow/lite/core:__subpackages__", ], deps = [ From fbb8ede2c95f8f7e749bb2fcf381424bebd1350f Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sun, 22 Sep 2024 22:56:20 -0700 Subject: [PATCH 114/483] Remove now-obsolete comments and reformat. Remove old dependencies. PiperOrigin-RevId: 677638748 --- tensorflow/compiler/tf2xla/BUILD | 3 --- .../compiler/tf2xla/rearrange_function_argument.cc | 13 ++++++------- 2 files changed, 6 insertions(+), 10 deletions(-) diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index da120ead9d7d62..257720495c651e 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -1263,10 +1263,7 @@ cc_library( "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", - "//tensorflow/core:graph", "//tensorflow/core:lib", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/types:optional", "@local_xla//xla:status_macros", ], ) diff --git a/tensorflow/compiler/tf2xla/rearrange_function_argument.cc b/tensorflow/compiler/tf2xla/rearrange_function_argument.cc index 479d8a230644b3..1783833b0b2bd2 100644 --- a/tensorflow/compiler/tf2xla/rearrange_function_argument.cc +++ b/tensorflow/compiler/tf2xla/rearrange_function_argument.cc @@ -159,9 +159,8 @@ Status ReorderOutputEdges(Graph* g, Node* n, int input_count, // Given mapping between original input index and rearranged input index, change // "index" attribute for _Arg nodes. -void RearrangeArgNodes( - const absl::InlinedVector* arg_nodes, // non-absl ok - const std::vector& index_mapping) { +void RearrangeArgNodes(const absl::InlinedVector* arg_nodes, + const std::vector& index_mapping) { for (int i = 0; i < arg_nodes->size(); i++) { Node* n = (*arg_nodes)[i]; int new_index = index_mapping.at(i); @@ -177,7 +176,7 @@ void RearrangeArgNodes( // hold mapping from DT_RESOURCE _Retval index to its input _Arg index. Here we // assume that all DT_RESOURCE _Retval nodes come from _Arg nodes directly. Status CalculateRetvalRearrange( - const absl::InlinedVector& ret_nodes, // non-absl ok + const absl::InlinedVector& ret_nodes, std::map* retval_index_mapping, std::map* resource_retval_to_arg) { for (int i = 0, end = ret_nodes.size(); i < end; i++) { @@ -258,9 +257,9 @@ Status RearrangeOutputEdges(Node* n, Graph* g, // Given mapping between original output index and rearranged output index, // change "index" attribute for _Retval nodes. Notice that DT_RESOURCE _Retval // nodes will be removed. -void RearrangeRetvalNodes( - const absl::InlinedVector& ret_nodes, // non-absl ok - Graph* g, const std::map& retval_index_mapping) { +void RearrangeRetvalNodes(const absl::InlinedVector& ret_nodes, + Graph* g, + const std::map& retval_index_mapping) { for (int i = 0, end = ret_nodes.size(); i < end; i++) { Node* n = ret_nodes[i]; auto iter = retval_index_mapping.find(i); From 6e56796db11fb3c0fc87946b36d277135d96139b Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sun, 22 Sep 2024 23:02:43 -0700 Subject: [PATCH 115/483] Automated Code Change PiperOrigin-RevId: 677640988 --- third_party/xla/xla/service/gpu/outfeed_manager.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/third_party/xla/xla/service/gpu/outfeed_manager.cc b/third_party/xla/xla/service/gpu/outfeed_manager.cc index 000531bde3e020..2d76f26ffd3348 100644 --- a/third_party/xla/xla/service/gpu/outfeed_manager.cc +++ b/third_party/xla/xla/service/gpu/outfeed_manager.cc @@ -22,6 +22,7 @@ limitations under the License. #include "xla/shape.h" #include "xla/shape_tree.h" #include "xla/shape_util.h" +#include "xla/stream_executor/stream_executor.h" #include "tsl/platform/logging.h" namespace xla { From 6378e671e0158084e46239557217ce43392c9a8e Mon Sep 17 00:00:00 2001 From: LakshmiKalaKadali <149650845+LakshmiKalaKadali@users.noreply.github.com> Date: Mon, 23 Sep 2024 12:16:05 +0530 Subject: [PATCH 116/483] Updated tf.compat.as_str API with Args, Returns and Raises in compat.py Hi, Please review this PR. The API `tf.compat.as_str` acts as an alias for the `tf.compat.as_text` function. Updated the same in addition to Args, Returns and Raises. Thank You --- tensorflow/python/util/compat.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/tensorflow/python/util/compat.py b/tensorflow/python/util/compat.py index 7a4659e0f62251..6e4d3d8db902cf 100644 --- a/tensorflow/python/util/compat.py +++ b/tensorflow/python/util/compat.py @@ -109,7 +109,20 @@ def as_text(bytes_or_text, encoding='utf-8'): def as_str(bytes_or_text, encoding='utf-8'): - return as_text(bytes_or_text, encoding) + """Acts as an alias for the `as_text` function.. + + Args: + bytes_or_text: The input value to be converted. A bytes or unicode object. + encoding: Optional string. The encoding to use if bytes_or_text is a bytes object. Defaults to 'utf-8'. + + Returns: + A unicode string. + + Raises: + TypeError: If bytes_or_text is not a bytes or unicode object. + UnicodeDecodeError: If bytes_or_text is a bytes object and cannot be decoded using the specified encoding. + """ +return as_text(bytes_or_text, encoding) tf_export('compat.as_text')(as_text) tf_export('compat.as_bytes')(as_bytes) From 6237c86de0e5b3b32c9f793a4c62154216b680df Mon Sep 17 00:00:00 2001 From: LakshmiKalaKadali <149650845+LakshmiKalaKadali@users.noreply.github.com> Date: Mon, 23 Sep 2024 12:17:59 +0530 Subject: [PATCH 117/483] Update compat.py --- tensorflow/python/util/compat.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/util/compat.py b/tensorflow/python/util/compat.py index 6e4d3d8db902cf..161585fcaf5c53 100644 --- a/tensorflow/python/util/compat.py +++ b/tensorflow/python/util/compat.py @@ -121,8 +121,8 @@ def as_str(bytes_or_text, encoding='utf-8'): Raises: TypeError: If bytes_or_text is not a bytes or unicode object. UnicodeDecodeError: If bytes_or_text is a bytes object and cannot be decoded using the specified encoding. - """ -return as_text(bytes_or_text, encoding) + """ + return as_text(bytes_or_text, encoding) tf_export('compat.as_text')(as_text) tf_export('compat.as_bytes')(as_bytes) From f8d0a13a85d29629a442d5d493c6d0e27b12470c Mon Sep 17 00:00:00 2001 From: Sandeep Dasgupta Date: Sun, 22 Sep 2024 23:48:02 -0700 Subject: [PATCH 118/483] [HLO Componentization] Create hlo/translate sub-component (Phase II). This CL takes care of 1. Migrating external projects dependencies from xla/translate --> xla/hlo/translate Phase I takes care of 1. Migrating xla/translate --> xla/hlo/translate 2. Setting up build aliases in xla/translate ensuring external dependencies are still satisfied. PiperOrigin-RevId: 677653485 --- tensorflow/compiler/mlir/tensorflow/transforms/BUILD | 6 +++--- .../tensorflow/transforms/set_tpu_infeed_layout.cc | 2 +- .../mlir/tensorflow/transforms/shape_inference.cc | 4 ++-- tensorflow/compiler/mlir/tf2xla/api/v1/BUILD | 8 ++++---- .../compiler/mlir/tf2xla/api/v1/compile_mlir_util.cc | 6 +++--- .../mlir/tf2xla/api/v1/compile_tf_graph_test.cc | 2 +- tensorflow/compiler/mlir/tf2xla/transforms/BUILD | 12 ++++++------ .../transforms/infeed_ops_xla_adjust_layout.cc | 2 +- .../compiler/mlir/tf2xla/transforms/legalize_tf.cc | 2 +- .../tf2xla/transforms/legalize_tf_communication.cc | 2 +- .../mlir/tf2xla/transforms/tf2xla_rewriter.cc | 6 +++--- tensorflow/compiler/tf2xla/kernels/BUILD | 6 +++--- .../tf2xla/kernels/xla_call_module_loader.cc | 2 +- .../compiler/tf2xla/kernels/xla_call_module_op.cc | 4 ++-- third_party/xla/xla/pjrt/BUILD | 4 ++-- third_party/xla/xla/pjrt/mlir_to_hlo.cc | 2 +- third_party/xla/xla/pjrt/pjrt_c_api_client.cc | 2 +- third_party/xla/xla/service/gpu/fusions/mlir/BUILD | 6 +++--- .../gpu/fusions/mlir/elemental_hlo_to_mlir.cc | 2 +- .../xla/xla/service/gpu/fusions/mlir/type_util.cc | 2 +- third_party/xla/xla/service/gpu/runtime/BUILD | 4 ++-- .../xla/service/gpu/runtime/nccl_collective_thunk.h | 2 +- third_party/xla/xla/service/gpu/runtime/thunk.cc | 2 +- 23 files changed, 45 insertions(+), 45 deletions(-) diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/BUILD b/tensorflow/compiler/mlir/tensorflow/transforms/BUILD index fd0668055a3dc5..8b2dc042611872 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/transforms/BUILD @@ -826,9 +826,9 @@ cc_library( "@local_xla//xla:shape_util", "@local_xla//xla:window_util", "@local_xla//xla:xla_data_proto_cc", + "@local_xla//xla/hlo/translate/hlo_to_mhlo:hlo_utils", + "@local_xla//xla/hlo/translate/mhlo_to_hlo:type_to_shape", "@local_xla//xla/service:shape_inference", - "@local_xla//xla/translate/hlo_to_mhlo:hlo_utils", - "@local_xla//xla/translate/mhlo_to_hlo:type_to_shape", "@local_xla//xla/tsl/util:env_var", ], ) @@ -1069,10 +1069,10 @@ cc_library( "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", "@local_xla//xla:shape_util", + "@local_xla//xla/hlo/translate/mhlo_to_hlo:type_to_shape", "@local_xla//xla/mlir_hlo", "@local_xla//xla/stream_executor/tpu:c_api_conversions", "@local_xla//xla/stream_executor/tpu:tpu_api", - "@local_xla//xla/translate/mhlo_to_hlo:type_to_shape", ], ) diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/set_tpu_infeed_layout.cc b/tensorflow/compiler/mlir/tensorflow/transforms/set_tpu_infeed_layout.cc index 0eb552208194e1..6fa61e1bde3d93 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/set_tpu_infeed_layout.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/set_tpu_infeed_layout.cc @@ -27,12 +27,12 @@ limitations under the License. #include "mlir/IR/Types.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "xla/hlo/translate/mhlo_to_hlo/type_to_shape.h" #include "xla/layout.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/shape.h" #include "xla/stream_executor/tpu/c_api_conversions.h" #include "xla/stream_executor/tpu/tpu_api.h" -#include "xla/translate/mhlo_to_hlo/type_to_shape.h" namespace mlir { diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc index d3630226ed1f32..2cee935dc96f23 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc @@ -83,10 +83,10 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/utils/shape_inference_utils.h" #include "tensorflow/compiler/mlir/tensorflow/utils/translate_utils.h" #include "tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.h" +#include "xla/hlo/translate/hlo_to_mhlo/hlo_utils.h" +#include "xla/hlo/translate/mhlo_to_hlo/type_to_shape.h" #include "xla/service/shape_inference.h" #include "xla/shape.h" -#include "xla/translate/hlo_to_mhlo/hlo_utils.h" -#include "xla/translate/mhlo_to_hlo/type_to_shape.h" #include "xla/tsl/util/env_var.h" #include "xla/window_util.h" #include "xla/xla_data.pb.h" diff --git a/tensorflow/compiler/mlir/tf2xla/api/v1/BUILD b/tensorflow/compiler/mlir/tf2xla/api/v1/BUILD index daf7311870fe10..c063123616b489 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v1/BUILD +++ b/tensorflow/compiler/mlir/tf2xla/api/v1/BUILD @@ -63,12 +63,12 @@ cc_library( "@local_xla//xla:xla_data_proto_cc", "@local_xla//xla/client:xla_computation", "@local_xla//xla/hlo/ir:hlo", + "@local_xla//xla/hlo/translate/mhlo_to_hlo:layout_util", + "@local_xla//xla/hlo/translate/mhlo_to_hlo:mlir_hlo_to_hlo", + "@local_xla//xla/hlo/translate/mhlo_to_hlo:type_to_shape", "@local_xla//xla/mlir_hlo", "@local_xla//xla/mlir_hlo:hlo_dialect_registration", "@local_xla//xla/mlir_hlo:mhlo_passes", - "@local_xla//xla/translate/mhlo_to_hlo:layout_util", - "@local_xla//xla/translate/mhlo_to_hlo:mlir_hlo_to_hlo", - "@local_xla//xla/translate/mhlo_to_hlo:type_to_shape", "@stablehlo//:register", ], ) @@ -183,8 +183,8 @@ tf_cc_test( "@local_tsl//tsl/platform:statusor", "@local_xla//xla:shape_util", "@local_xla//xla/client:client_library", + "@local_xla//xla/hlo/translate/mhlo_to_hlo:type_to_shape", "@local_xla//xla/stream_executor:platform_manager", - "@local_xla//xla/translate/mhlo_to_hlo:type_to_shape", "@local_xla//xla/tsl/lib/core:status_test_util", "@local_xla//xla/tsl/lib/monitoring:test_utils", ], diff --git a/tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util.cc b/tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util.cc index 8ed7d1ea727867..15a749d75f2016 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util.cc +++ b/tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util.cc @@ -72,13 +72,13 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "xla/client/xla_computation.h" #include "xla/hlo/ir/hlo_sharding.h" +#include "xla/hlo/translate/mhlo_to_hlo/layout_util.h" +#include "xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h" +#include "xla/hlo/translate/mhlo_to_hlo/type_to_shape.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/mlir_hlo/mhlo/IR/register.h" #include "xla/mlir_hlo/mhlo/transforms/passes.h" #include "xla/shape.h" -#include "xla/translate/mhlo_to_hlo/layout_util.h" -#include "xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h" -#include "xla/translate/mhlo_to_hlo/type_to_shape.h" #include "xla/xla_data.pb.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/platform/error_payloads.h" diff --git a/tensorflow/compiler/mlir/tf2xla/api/v1/compile_tf_graph_test.cc b/tensorflow/compiler/mlir/tf2xla/api/v1/compile_tf_graph_test.cc index c1af969caef8cc..52ead7efb325c1 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v1/compile_tf_graph_test.cc +++ b/tensorflow/compiler/mlir/tf2xla/api/v1/compile_tf_graph_test.cc @@ -26,9 +26,9 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/layout_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "xla/client/client_library.h" +#include "xla/hlo/translate/mhlo_to_hlo/type_to_shape.h" #include "xla/shape.h" #include "xla/stream_executor/platform_manager.h" -#include "xla/translate/mhlo_to_hlo/type_to_shape.h" #include "xla/tsl/lib/core/status_test_util.h" #include "xla/tsl/lib/monitoring/test_utils.h" #include "tensorflow/core/framework/tensor_shape.h" diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/BUILD b/tensorflow/compiler/mlir/tf2xla/transforms/BUILD index 4092a86dd38e5c..425909163fa04f 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/BUILD +++ b/tensorflow/compiler/mlir/tf2xla/transforms/BUILD @@ -172,9 +172,9 @@ cc_library( "@local_xla//xla/client:padding", "@local_xla//xla/client:sharding_builder", "@local_xla//xla/client/lib:conv_grad_size_util", + "@local_xla//xla/hlo/translate/hlo_to_mhlo:attribute_importer", "@local_xla//xla/mlir_hlo", "@local_xla//xla/mlir_hlo:convert_op_folder", - "@local_xla//xla/translate/hlo_to_mhlo:attribute_importer", "@stablehlo//:chlo_ops", ] + if_static(["@local_tsl//tsl/platform:tensor_float_32_utils"]), ) @@ -295,14 +295,14 @@ cc_library( "@local_xla//xla:xla_data_proto_cc", "@local_xla//xla/client:padding", "@local_xla//xla/client:sharding_builder", + "@local_xla//xla/hlo/translate/hlo_to_mhlo:attribute_importer", + "@local_xla//xla/hlo/translate/mhlo_to_hlo:type_to_shape", "@local_xla//xla/mlir_hlo", "@local_xla//xla/mlir_hlo:convert_op_folder", "@local_xla//xla/mlir_hlo:mhlo_passes", "@local_xla//xla/mlir_hlo:type_conversion", "@local_xla//xla/stream_executor/tpu:c_api_conversions", "@local_xla//xla/stream_executor/tpu:tpu_api", - "@local_xla//xla/translate/hlo_to_mhlo:attribute_importer", - "@local_xla//xla/translate/mhlo_to_hlo:type_to_shape", "@stablehlo//:chlo_ops", "@stablehlo//:stablehlo_ops", ], @@ -357,11 +357,11 @@ cc_library( "@local_xla//xla/client:xla_builder", "@local_xla//xla/client:xla_computation", "@local_xla//xla/hlo/ir:hlo", + "@local_xla//xla/hlo/translate/hlo_to_mhlo:hlo_function_importer", + "@local_xla//xla/hlo/translate/hlo_to_mhlo:hlo_to_mlir_hlo", + "@local_xla//xla/hlo/translate/mhlo_to_hlo:type_to_shape", "@local_xla//xla/mlir_hlo", "@local_xla//xla/service:hlo_proto_cc", - "@local_xla//xla/translate/hlo_to_mhlo:hlo_function_importer", - "@local_xla//xla/translate/hlo_to_mhlo:hlo_to_mlir_hlo", - "@local_xla//xla/translate/mhlo_to_hlo:type_to_shape", ], ) diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/infeed_ops_xla_adjust_layout.cc b/tensorflow/compiler/mlir/tf2xla/transforms/infeed_ops_xla_adjust_layout.cc index b2ce3f56ef9960..f1e843b81f5476 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/infeed_ops_xla_adjust_layout.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/infeed_ops_xla_adjust_layout.cc @@ -32,12 +32,12 @@ limitations under the License. #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/transforms/set_tpu_infeed_layout.h" +#include "xla/hlo/translate/mhlo_to_hlo/type_to_shape.h" #include "xla/layout.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/shape.h" #include "xla/stream_executor/tpu/c_api_conversions.h" #include "xla/stream_executor/tpu/tpu_api.h" -#include "xla/translate/mhlo_to_hlo/type_to_shape.h" namespace mlir { namespace mhlo { diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf.cc index 9f45164ba4dfe3..5764e08f71463e 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf.cc @@ -62,10 +62,10 @@ limitations under the License. #include "xla/client/lib/conv_grad_size_util.h" #include "xla/client/padding.h" #include "xla/client/sharding_builder.h" +#include "xla/hlo/translate/hlo_to_mhlo/attribute_importer.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/mlir_hlo/utils/convert_op_folder.h" #include "xla/mlir_hlo/utils/hlo_utils.h" -#include "xla/translate/hlo_to_mhlo/attribute_importer.h" #include "xla/xla_data.pb.h" #include "tensorflow/core/framework/kernel_shape_util.h" #include "tensorflow/core/framework/rng_alg.h" diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_communication.cc b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_communication.cc index 68c412f79ff393..64de9790e0ddb0 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_communication.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_communication.cc @@ -43,10 +43,10 @@ limitations under the License. #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "xla/client/sharding_builder.h" +#include "xla/hlo/translate/mhlo_to_hlo/type_to_shape.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/primitive_util.h" #include "xla/side_effect_util.h" -#include "xla/translate/mhlo_to_hlo/type_to_shape.h" namespace mlir { diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter.cc b/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter.cc index 228df862f82d3c..2142533e37f98e 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter.cc @@ -65,11 +65,11 @@ limitations under the License. #include "xla/client/xla_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/translate/hlo_to_mhlo/hlo_function_importer.h" +#include "xla/hlo/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h" +#include "xla/hlo/translate/mhlo_to_hlo/type_to_shape.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/service/hlo.pb.h" -#include "xla/translate/hlo_to_mhlo/hlo_function_importer.h" -#include "xla/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h" -#include "xla/translate/mhlo_to_hlo/type_to_shape.h" #include "xla/xla_data.pb.h" #include "tensorflow/core/common_runtime/process_function_library_runtime.h" #include "tensorflow/core/framework/attr_value.pb.h" diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index 156ddbb3581222..6200574f76fffc 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -361,12 +361,12 @@ cc_library( "@local_tsl//tsl/platform:statusor", "@local_xla//xla:shape_util", "@local_xla//xla/client:xla_computation", + "@local_xla//xla/hlo/translate/mhlo_to_hlo:mlir_hlo_to_hlo", "@local_xla//xla/mlir/utils:type_util", "@local_xla//xla/mlir_hlo", "@local_xla//xla/mlir_hlo:mhlo_passes", "@local_xla//xla/python:refine_polymorphic_shapes", "@local_xla//xla/service:hlo_proto_cc", - "@local_xla//xla/translate/mhlo_to_hlo:mlir_hlo_to_hlo", "@stablehlo//:chlo_ops", "@stablehlo//:stablehlo_ops", "@stablehlo//:stablehlo_serialization", @@ -407,11 +407,11 @@ tf_kernel_library( "@local_xla//xla/client:xla_builder", "@local_xla//xla/client:xla_computation", "@local_xla//xla/hlo/ir:hlo", + "@local_xla//xla/hlo/translate/hlo_to_mhlo:hlo_to_mlir_hlo", + "@local_xla//xla/hlo/translate/mhlo_to_hlo:type_to_shape", "@local_xla//xla/mlir_hlo", "@local_xla//xla/mlir_hlo:mhlo_passes", "@local_xla//xla/service:hlo_module_config", - "@local_xla//xla/translate/hlo_to_mhlo:hlo_to_mlir_hlo", - "@local_xla//xla/translate/mhlo_to_hlo:type_to_shape", ], ) diff --git a/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.cc b/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.cc index 252d746727f443..e35ad96962589b 100644 --- a/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.cc +++ b/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.cc @@ -60,13 +60,13 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h" #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" #include "xla/client/xla_computation.h" +#include "xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h" #include "xla/mlir/utils/type_util.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/mlir_hlo/mhlo/transforms/passes.h" #include "xla/python/refine_polymorphic_shapes.h" #include "xla/service/hlo.pb.h" #include "xla/shape.h" -#include "xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h" #include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" diff --git a/tensorflow/compiler/tf2xla/kernels/xla_call_module_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_call_module_op.cc index 0438e648f4fe1c..251798c838c6bc 100644 --- a/tensorflow/compiler/tf2xla/kernels/xla_call_module_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/xla_call_module_op.cc @@ -58,13 +58,13 @@ limitations under the License. #include "xla/client/xla_computation.h" #include "xla/debug_options_flags.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h" +#include "xla/hlo/translate/mhlo_to_hlo/type_to_shape.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/mlir_hlo/mhlo/transforms/passes.h" #include "xla/service/hlo_module_config.h" #include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h" -#include "xla/translate/mhlo_to_hlo/type_to_shape.h" #include "xla/util.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/op_kernel.h" diff --git a/third_party/xla/xla/pjrt/BUILD b/third_party/xla/xla/pjrt/BUILD index 48439f796f1a7e..978ce97ce210ba 100644 --- a/third_party/xla/xla/pjrt/BUILD +++ b/third_party/xla/xla/pjrt/BUILD @@ -596,12 +596,12 @@ cc_library( "//xla:debug_options_flags", "//xla:util", "//xla/client:xla_computation", + "//xla/hlo/translate/mhlo_to_hlo:mlir_hlo_to_hlo", "//xla/mlir/utils:error_util", "//xla/mlir_hlo:hlo_dialect_registration", "//xla/mlir_hlo:mhlo_passes", "//xla/service/spmd/shardy:constants", "//xla/service/spmd/shardy:utils", - "//xla/translate/mhlo_to_hlo:mlir_hlo_to_hlo", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -788,6 +788,7 @@ cc_library( "//xla:xla_proto_cc", "//xla/client:xla_computation", "//xla/hlo/ir:hlo", + "//xla/hlo/translate/mhlo_to_hlo:mlir_hlo_to_hlo", "//xla/mlir_hlo:mhlo_passes", "//xla/pjrt/c:pjrt_c_api_hdrs", "//xla/pjrt/c:pjrt_c_api_helpers", @@ -799,7 +800,6 @@ cc_library( "//xla/service:hlo_cost_analysis", "//xla/service:hlo_module_config", "//xla/service:hlo_proto_cc", - "//xla/translate/mhlo_to_hlo:mlir_hlo_to_hlo", "//xla/tsl/framework:allocator", "@com_google_absl//absl/cleanup", "@com_google_absl//absl/container:flat_hash_map", diff --git a/third_party/xla/xla/pjrt/mlir_to_hlo.cc b/third_party/xla/xla/pjrt/mlir_to_hlo.cc index 90e904a533c98c..d5fbdf47ca0619 100644 --- a/third_party/xla/xla/pjrt/mlir_to_hlo.cc +++ b/third_party/xla/xla/pjrt/mlir_to_hlo.cc @@ -58,12 +58,12 @@ limitations under the License. #include "stablehlo/dialect/Version.h" #include "stablehlo/transforms/Passes.h" #include "xla/debug_options_flags.h" +#include "xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h" #include "xla/mlir/utils/error_util.h" #include "xla/mlir_hlo/mhlo/IR/register.h" #include "xla/mlir_hlo/mhlo/transforms/passes.h" #include "xla/service/spmd/shardy/constants.h" #include "xla/service/spmd/shardy/utils.h" -#include "xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h" #include "xla/util.h" #include "tsl/platform/statusor.h" diff --git a/third_party/xla/xla/pjrt/pjrt_c_api_client.cc b/third_party/xla/xla/pjrt/pjrt_c_api_client.cc index 7319ea3942145c..062ac13011aedf 100644 --- a/third_party/xla/xla/pjrt/pjrt_c_api_client.cc +++ b/third_party/xla/xla/pjrt/pjrt_c_api_client.cc @@ -43,6 +43,7 @@ limitations under the License. #include "mlir/Support/LogicalResult.h" #include "xla/client/xla_computation.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h" #include "xla/layout.h" #include "xla/literal.h" #include "xla/mlir_hlo/mhlo/transforms/passes.h" @@ -67,7 +68,6 @@ limitations under the License. #include "xla/service/hlo_module_config.h" #include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h" #include "xla/tsl/framework/allocator.h" #include "xla/util.h" #include "xla/xla.pb.h" diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/BUILD b/third_party/xla/xla/service/gpu/fusions/mlir/BUILD index 4bf1a10b2f7025..ba07c39a2bd718 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/BUILD +++ b/third_party/xla/xla/service/gpu/fusions/mlir/BUILD @@ -23,11 +23,11 @@ cc_library( "//xla:union_find", "//xla:util", "//xla/hlo/ir:hlo", + "//xla/hlo/translate/hlo_to_mhlo:hlo_utils", "//xla/service/gpu:hlo_fusion_analysis", "//xla/service/gpu/fusions:fusion_emitter", "//xla/service/gpu/model:indexing_analysis", "//xla/service/llvm_ir:llvm_util", - "//xla/translate/hlo_to_mhlo:hlo_utils", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", @@ -72,6 +72,7 @@ cc_library( "//xla:status_macros", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", + "//xla/hlo/translate/hlo_to_mhlo:hlo_utils", "//xla/mlir/utils:type_util", "//xla/mlir_hlo", "//xla/mlir_hlo:map_mhlo_to_scalar_op", @@ -82,7 +83,6 @@ cc_library( "//xla/service/gpu/model:indexing_analysis", "//xla/service/llvm_ir:llvm_util", "//xla/stream_executor:device_description", - "//xla/translate/hlo_to_mhlo:hlo_utils", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", @@ -278,8 +278,8 @@ cc_library( deps = [ "//xla:shape_util", "//xla:xla_data_proto_cc", + "//xla/hlo/translate/hlo_to_mhlo:hlo_utils", "//xla/mlir/utils:type_util", - "//xla/translate/hlo_to_mhlo:hlo_utils", "@com_google_absl//absl/log:check", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.cc b/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.cc index 68b3ab29359214..292659c1916898 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.cc +++ b/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.cc @@ -66,6 +66,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/translate/hlo_to_mhlo/hlo_utils.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/mlir_hlo/mhlo/transforms/map_mhlo_to_scalar_op.h" #include "xla/primitive_util.h" @@ -77,7 +78,6 @@ limitations under the License. #include "xla/service/gpu/model/indexing_map.h" #include "xla/shape_util.h" #include "xla/status_macros.h" -#include "xla/translate/hlo_to_mhlo/hlo_utils.h" #include "xla/xla_data.pb.h" #include "tsl/platform/statusor.h" diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/type_util.cc b/third_party/xla/xla/service/gpu/fusions/mlir/type_util.cc index 81e568f956c3b2..76d4b284ebc331 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/type_util.cc +++ b/third_party/xla/xla/service/gpu/fusions/mlir/type_util.cc @@ -20,11 +20,11 @@ limitations under the License. #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Types.h" +#include "xla/hlo/translate/hlo_to_mhlo/hlo_utils.h" #include "xla/layout_util.h" #include "xla/mlir/utils/type_util.h" #include "xla/primitive_util.h" #include "xla/shape.h" -#include "xla/translate/hlo_to_mhlo/hlo_utils.h" #include "xla/xla_data.pb.h" namespace xla { diff --git a/third_party/xla/xla/service/gpu/runtime/BUILD b/third_party/xla/xla/service/gpu/runtime/BUILD index 773106baf16ac5..4d6e321f26b0bc 100644 --- a/third_party/xla/xla/service/gpu/runtime/BUILD +++ b/third_party/xla/xla/service/gpu/runtime/BUILD @@ -893,6 +893,7 @@ cc_library( "//xla:util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", + "//xla/hlo/translate/mhlo_to_hlo:attribute_exporter", "//xla/service:buffer_assignment", "//xla/service:collective_ops_utils", "//xla/service:computation_placer", @@ -906,7 +907,6 @@ cc_library( "//xla/stream_executor/gpu:gpu_driver_header", "//xla/stream_executor/gpu:gpu_stream", "//xla/stream_executor/gpu:gpu_types_header", - "//xla/translate/mhlo_to_hlo:attribute_exporter", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", @@ -1109,6 +1109,7 @@ cc_library( "//xla:executable_run_options", "//xla/ffi:execution_context", "//xla/hlo/ir:hlo", + "//xla/hlo/translate/mhlo_to_hlo:location_exporter", "//xla/service:buffer_assignment", "//xla/service:executable", "//xla/service:global_device_id", @@ -1117,7 +1118,6 @@ cc_library( "//xla/service/gpu:gpu_executable_run_options", "//xla/service/gpu:ir_emission_utils", "//xla/stream_executor", - "//xla/translate/mhlo_to_hlo:location_exporter", "//xla/tsl/lib/gtl:int_type", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", diff --git a/third_party/xla/xla/service/gpu/runtime/nccl_collective_thunk.h b/third_party/xla/xla/service/gpu/runtime/nccl_collective_thunk.h index 2a549cdd81f520..d4c2eab9e9e5d8 100644 --- a/third_party/xla/xla/service/gpu/runtime/nccl_collective_thunk.h +++ b/third_party/xla/xla/service/gpu/runtime/nccl_collective_thunk.h @@ -34,6 +34,7 @@ limitations under the License. #include "mlir/IR/Value.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/translate/mhlo_to_hlo/attribute_exporter.h" #include "xla/service/buffer_assignment.h" #include "xla/service/collective_ops_utils.h" #include "xla/service/global_device_id.h" @@ -48,7 +49,6 @@ limitations under the License. #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/event.h" #include "xla/stream_executor/stream.h" -#include "xla/translate/mhlo_to_hlo/attribute_exporter.h" #include "xla/xla_data.pb.h" namespace xla { diff --git a/third_party/xla/xla/service/gpu/runtime/thunk.cc b/third_party/xla/xla/service/gpu/runtime/thunk.cc index d2ca3ae1184b52..90319ff9ee1651 100644 --- a/third_party/xla/xla/service/gpu/runtime/thunk.cc +++ b/third_party/xla/xla/service/gpu/runtime/thunk.cc @@ -34,6 +34,7 @@ limitations under the License. #include "xla/executable_run_options.h" #include "xla/ffi/execution_context.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/translate/mhlo_to_hlo/location_exporter.h" #include "xla/service/global_device_id.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/buffer_allocations.h" @@ -43,7 +44,6 @@ limitations under the License. #include "xla/service/gpu/runtime/nccl_clique_key.h" #include "xla/service/service_executable_run_options.h" #include "xla/stream_executor/stream.h" -#include "xla/translate/mhlo_to_hlo/location_exporter.h" #include "tsl/platform/statusor.h" namespace xla { From e95c41518b0b8d6f92cb83a7ecd7134f22e4dde3 Mon Sep 17 00:00:00 2001 From: Alexander Pivovarov Date: Mon, 23 Sep 2024 00:03:59 -0700 Subject: [PATCH 119/483] PR #17394: Parameterize Float tests in literal_test Imported from GitHub PR https://github.com/openxla/xla/pull/17394 `literal_test.cc` contains multiple similar tests for Float types Changes: - Parameterize the float tests in literal_test.cc Copybara import of the project: -- fa74adf95af5331517795557b855321c6fef6358 by Alexander Pivovarov : Parameterize Float tests in literal_test Merging this change closes #17394 PiperOrigin-RevId: 677657685 --- third_party/xla/xla/literal_test.cc | 234 +++++++++------------------- 1 file changed, 77 insertions(+), 157 deletions(-) diff --git a/third_party/xla/xla/literal_test.cc b/third_party/xla/xla/literal_test.cc index 8a351e2262b69e..dd9c1df6a3eb24 100644 --- a/third_party/xla/xla/literal_test.cc +++ b/third_party/xla/xla/literal_test.cc @@ -121,6 +121,16 @@ class LiteralUtilTest : public ::testing::Test { Literal literal_r4_2x2x3x3_dim0minor_; }; +template +class LiteralUtilFloatTest : public LiteralUtilTest {}; + +using FloatTypes = + ::testing::Types; + +TYPED_TEST_SUITE(LiteralUtilFloatTest, FloatTypes); + TEST_F(LiteralUtilTest, LiteralScalarToString) { auto true_lit = LiteralUtil::CreateR0(true); EXPECT_EQ("pred[] true", true_lit.ToString()); @@ -1140,34 +1150,30 @@ TEST_F(LiteralUtilTest, PopulateR2C64) { EXPECT_EQ(output, expected); } -TEST_F(LiteralUtilTest, PopulateWithValueR0BF16) { - Literal output(ShapeUtil::MakeShape(BF16, {})); - bfloat16 h(0.25f); - output.PopulateWithValue(h); - auto expected = LiteralUtil::CreateR0(h); +TYPED_TEST(LiteralUtilFloatTest, PopulateWithValueR0Float) { + Literal output(ShapeUtil::MakeShape( + primitive_util::NativeToPrimitiveType(), {})); + TypeParam h(0.25f); + output.PopulateWithValue(h); + auto expected = LiteralUtil::CreateR0(h); EXPECT_EQ(output, expected); } -TEST_F(LiteralUtilTest, PopulateWithValueR1BF16) { - Literal output(ShapeUtil::MakeShape(BF16, {3})); - bfloat16 h(0.5f); - output.PopulateWithValue(h); - auto expected = LiteralUtil::CreateR1({h, h, h}); +TYPED_TEST(LiteralUtilFloatTest, PopulateWithValueR1Float) { + Literal output(ShapeUtil::MakeShape( + primitive_util::NativeToPrimitiveType(), {3})); + TypeParam h(0.5f); + output.PopulateWithValue(h); + auto expected = LiteralUtil::CreateR1({h, h, h}); EXPECT_EQ(output, expected); } -TEST_F(LiteralUtilTest, PopulateWithValueR2BF16) { - Literal output(ShapeUtil::MakeShape(BF16, {2, 2})); - bfloat16 h(2.0f); - output.PopulateWithValue(h); - auto expected = LiteralUtil::CreateR2({{h, h}, {h, h}}); - EXPECT_EQ(output, expected); -} - -TEST_F(LiteralUtilTest, PopulateWithValueR0F32) { - Literal output(ShapeUtil::MakeShape(F32, {})); - output.PopulateWithValue(2.5f); - auto expected = LiteralUtil::CreateR0(2.5f); +TYPED_TEST(LiteralUtilFloatTest, PopulateWithValueR2Float) { + Literal output(ShapeUtil::MakeShape( + primitive_util::NativeToPrimitiveType(), {2, 2})); + TypeParam h(2.0f); + output.PopulateWithValue(h); + auto expected = LiteralUtil::CreateR2({{h, h}, {h, h}}); EXPECT_EQ(output, expected); } @@ -1201,70 +1207,6 @@ TEST_F(LiteralUtilTest, PopulateWithValueR2C128) { EXPECT_EQ(output, expected); } -TEST_F(LiteralUtilTest, PopulateWithValueR0F16) { - Literal output(ShapeUtil::MakeShape(F16, {})); - half h(0.25f); - output.PopulateWithValue(h); - auto expected = LiteralUtil::CreateR0(h); - EXPECT_EQ(output, expected); -} - -TEST_F(LiteralUtilTest, PopulateWithValueR1F16) { - Literal output(ShapeUtil::MakeShape(F16, {3})); - half h(0.5f); - output.PopulateWithValue(h); - auto expected = LiteralUtil::CreateR1({h, h, h}); - EXPECT_EQ(output, expected); -} - -TEST_F(LiteralUtilTest, PopulateWithValueR2F16) { - Literal output(ShapeUtil::MakeShape(F16, {2, 2})); - half h(2.0f); - output.PopulateWithValue(h); - auto expected = LiteralUtil::CreateR2({{h, h}, {h, h}}); - EXPECT_EQ(output, expected); -} - -TEST_F(LiteralUtilTest, PopulateWithValueR0F8e5m2) { - Literal output(ShapeUtil::MakeShape(F8E5M2, {})); - tsl::float8_e5m2 x(0.25f); - output.PopulateWithValue(x); - auto expected = LiteralUtil::CreateR0(x); - EXPECT_EQ(output, expected); -} - -TEST_F(LiteralUtilTest, PopulateWithValueR1F8e4m3) { - Literal output(ShapeUtil::MakeShape(F8E4M3FN, {3})); - tsl::float8_e4m3fn x(0.5f); - output.PopulateWithValue(x); - auto expected = LiteralUtil::CreateR1({x, x, x}); - EXPECT_EQ(output, expected); -} - -TEST_F(LiteralUtilTest, PopulateWithValueR1F8e4m3b11) { - Literal output(ShapeUtil::MakeShape(F8E4M3B11FNUZ, {3})); - tsl::float8_e4m3b11fnuz x(0.5f); - output.PopulateWithValue(x); - auto expected = LiteralUtil::CreateR1({x, x, x}); - EXPECT_EQ(output, expected); -} - -TEST_F(LiteralUtilTest, PopulateWithValueR1F8e4m3fnuz) { - Literal output(ShapeUtil::MakeShape(F8E4M3FNUZ, {3})); - tsl::float8_e4m3fnuz x(0.5f); - output.PopulateWithValue(x); - auto expected = LiteralUtil::CreateR1({x, x, x}); - EXPECT_EQ(output, expected); -} - -TEST_F(LiteralUtilTest, PopulateWithValueR1F8e5m2fnuz) { - Literal output(ShapeUtil::MakeShape(F8E5M2FNUZ, {3})); - tsl::float8_e5m2fnuz x(0.5f); - output.PopulateWithValue(x); - auto expected = LiteralUtil::CreateR1({x, x, x}); - EXPECT_EQ(output, expected); -} - TEST_F(LiteralUtilTest, ReplicateR2U32) { auto input = LiteralUtil::CreateR2( {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}); @@ -1745,92 +1687,70 @@ TEST_F(LiteralUtilTest, ConvertIfTypesMatch) { EXPECT_EQ(c128.Convert(S32).status().code(), tsl::error::UNIMPLEMENTED); } -TEST_F(LiteralUtilTest, ConvertIfTypesMatchF8) { - auto s8 = LiteralUtil::CreateR2WithLayout({{0, 1}, {2, 3}}, - layout_r2_dim0major_); - auto f32 = LiteralUtil::CreateR2WithLayout({{0., 1.}, {2., 3.}}, - layout_r2_dim0major_); - auto c128 = LiteralUtil::CreateR2WithLayout({{0., 1.}, {2., 3.}}, - layout_r2_dim0major_); - using e5 = tsl::float8_e5m2; - auto f8e5m2 = LiteralUtil::CreateR2WithLayout( - {{e5{0.}, e5{1.}}, {e5{2.}, e5{3.}}}, layout_r2_dim0major_); - using e4 = tsl::float8_e4m3fn; - auto f8e4m3 = LiteralUtil::CreateR2WithLayout( - {{e4{0.}, e4{1.}}, {e4{2.}, e4{3.}}}, layout_r2_dim0major_); - using b11 = tsl::float8_e4m3b11fnuz; - auto f8e4m3b11 = LiteralUtil::CreateR2WithLayout( - {{b11{0.}, b11{1.}}, {b11{2.}, b11{3.}}}, layout_r2_dim0major_); - using e5f = tsl::float8_e5m2fnuz; - auto f8e5m2fnuz = LiteralUtil::CreateR2WithLayout( - {{e5f{0.}, e5f{1.}}, {e5f{2.}, e5f{3.}}}, layout_r2_dim0major_); - using e4f = tsl::float8_e4m3fnuz; - auto f8e4m3fnuz = LiteralUtil::CreateR2WithLayout( - {{e4f{0.}, e4f{1.}}, {e4f{2.}, e4f{3.}}}, layout_r2_dim0major_); - Literal conv; - - conv = s8.Convert(F8E5M2).value(); - EXPECT_EQ(conv, f8e5m2); - - conv = f32.Convert(F8E5M2).value(); - EXPECT_EQ(conv, f8e5m2); - - conv = f8e4m3.Convert(F8E5M2).value(); - EXPECT_EQ(conv, f8e5m2); - - conv = s8.Convert(F8E4M3FN).value(); - EXPECT_EQ(conv, f8e4m3); - - conv = f32.Convert(F8E4M3FN).value(); - EXPECT_EQ(conv, f8e4m3); - - conv = f8e5m2.Convert(F8E4M3FN).value(); - EXPECT_EQ(conv, f8e4m3); +TYPED_TEST(LiteralUtilFloatTest, ConvertIfTypesMatchF8) { + constexpr auto ptype = primitive_util::NativeToPrimitiveType(); + if (!primitive_util::IsF8Type(ptype)) { + GTEST_SKIP() << "Skipping test for non F8 types"; + } + auto s8 = LiteralUtil::CreateR2WithLayout( + {{0, 1}, {2, 3}}, LiteralUtilTest::layout_r2_dim0major_); + auto bf16 = LiteralUtil::CreateR2WithLayout( + {{bfloat16(0.), bfloat16(1.)}, {bfloat16(2.), bfloat16(3.)}}, + LiteralUtilTest::layout_r2_dim0major_); + auto f32 = LiteralUtil::CreateR2WithLayout( + {{0., 1.}, {2., 3.}}, LiteralUtilTest::layout_r2_dim0major_); + auto c128 = LiteralUtil::CreateR2WithLayout( + {{0., 1.}, {2., 3.}}, LiteralUtilTest::layout_r2_dim0major_); + // Let's also use a couple of popular F8 types as sources for conversion + using f8e5m2_t = tsl::float8_e5m2; + auto f8e5m2 = LiteralUtil::CreateR2WithLayout( + {{f8e5m2_t{0.}, f8e5m2_t{1.}}, {f8e5m2_t{2.}, f8e5m2_t{3.}}}, + LiteralUtilTest::layout_r2_dim0major_); + using e4m3fn_t = tsl::float8_e4m3fn; + auto f8e4m3fn = LiteralUtil::CreateR2WithLayout( + {{e4m3fn_t{0.}, e4m3fn_t{1.}}, {e4m3fn_t{2.}, e4m3fn_t{3.}}}, + LiteralUtilTest::layout_r2_dim0major_); + + auto f8 = LiteralUtil::CreateR2WithLayout( + {{TypeParam{0.}, TypeParam{1.}}, {TypeParam{2.}, TypeParam{3.}}}, + LiteralUtilTest::layout_r2_dim0major_); - conv = f8e5m2.Convert(S8).value(); - EXPECT_EQ(conv, s8); + Literal conv; - conv = f8e5m2.Convert(F32).value(); - EXPECT_EQ(conv, f32); + // Convert to f8 + conv = s8.Convert(ptype).value(); + EXPECT_EQ(conv, f8); - conv = f8e5m2.Convert(C128).value(); - EXPECT_EQ(conv, c128); + conv = bf16.Convert(ptype).value(); + EXPECT_EQ(conv, f8); - conv = f8e4m3.Convert(S8).value(); - EXPECT_EQ(conv, s8); + conv = f32.Convert(ptype).value(); + EXPECT_EQ(conv, f8); - conv = f8e4m3.Convert(F32).value(); - EXPECT_EQ(conv, f32); + conv = f8e5m2.Convert(ptype).value(); + EXPECT_EQ(conv, f8); - conv = f8e4m3.Convert(C128).value(); - EXPECT_EQ(conv, c128); + conv = f8e4m3fn.Convert(ptype).value(); + EXPECT_EQ(conv, f8); - conv = f8e4m3b11.Convert(S8).value(); + // Convert from f8 + conv = f8.Convert(S8).value(); EXPECT_EQ(conv, s8); - conv = f8e4m3b11.Convert(F32).value(); - EXPECT_EQ(conv, f32); - - conv = f8e4m3b11.Convert(C128).value(); - EXPECT_EQ(conv, c128); - - conv = f8e5m2fnuz.Convert(S8).value(); - EXPECT_EQ(conv, s8); + conv = f8.Convert(BF16).value(); + EXPECT_EQ(conv, bf16); - conv = f8e5m2fnuz.Convert(F32).value(); + conv = f8.Convert(F32).value(); EXPECT_EQ(conv, f32); - conv = f8e5m2fnuz.Convert(C128).value(); + conv = f8.Convert(C128).value(); EXPECT_EQ(conv, c128); - conv = f8e4m3fnuz.Convert(S8).value(); - EXPECT_EQ(conv, s8); - - conv = f8e4m3fnuz.Convert(F32).value(); - EXPECT_EQ(conv, f32); + conv = f8.Convert(F8E5M2).value(); + EXPECT_EQ(conv, f8e5m2); - conv = f8e4m3fnuz.Convert(C128).value(); - EXPECT_EQ(conv, c128); + conv = f8.Convert(F8E4M3FN).value(); + EXPECT_EQ(conv, f8e4m3fn); } TEST_F(LiteralUtilTest, BitcastConvert) { From 471e2b5899c11540e4721005d3c9c9f3e582c910 Mon Sep 17 00:00:00 2001 From: LakshmiKalaKadali <149650845+LakshmiKalaKadali@users.noreply.github.com> Date: Mon, 23 Sep 2024 12:47:59 +0530 Subject: [PATCH 120/483] Update compat.py --- tensorflow/python/util/compat.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/util/compat.py b/tensorflow/python/util/compat.py index 161585fcaf5c53..3a721dc0a8fe89 100644 --- a/tensorflow/python/util/compat.py +++ b/tensorflow/python/util/compat.py @@ -113,14 +113,16 @@ def as_str(bytes_or_text, encoding='utf-8'): Args: bytes_or_text: The input value to be converted. A bytes or unicode object. - encoding: Optional string. The encoding to use if bytes_or_text is a bytes object. Defaults to 'utf-8'. + encoding: Optional string. The encoding to use if bytes_or_text is + a bytes object. Defaults to 'utf-8'. Returns: A unicode string. Raises: TypeError: If bytes_or_text is not a bytes or unicode object. - UnicodeDecodeError: If bytes_or_text is a bytes object and cannot be decoded using the specified encoding. + UnicodeDecodeError: If bytes_or_text is a bytes object and cannot be + decoded using the specified encoding. """ return as_text(bytes_or_text, encoding) From 1bfd3166a92eb72c0ddab481ff16adee071640f5 Mon Sep 17 00:00:00 2001 From: Henning Becker Date: Mon, 23 Sep 2024 00:05:47 -0700 Subject: [PATCH 121/483] Rename no_rocm tag to cuda-only This renames the tag `no_rocm` to `cuda-only` and prepares for the introduction of a new `rocm-only` tag. This part of a bigger effort to unify the build graph and get rid of too many compilation modes than we can't all test. PiperOrigin-RevId: 677658239 --- .../workflows/bazel_dependency_violations.yml | 6 ++-- .../xla/build_tools/configure/configure.py | 9 +++++- .../configure/testdata/cuda_clang.bazelrc | 8 +++--- .../testdata/default_cuda_clang.bazelrc | 8 +++--- .../configure/testdata/nvcc_clang.bazelrc | 8 +++--- .../configure/testdata/nvcc_gcc.bazelrc | 8 +++--- .../xla/build_tools/dependencies/aspects.bzl | 8 +++--- third_party/xla/build_tools/lint/tags.py | 2 +- third_party/xla/build_tools/rocm/run_xla.sh | 2 +- .../experiments/sm_bandwidth_benchmark/BUILD | 6 ++-- third_party/xla/xla/service/BUILD | 2 +- third_party/xla/xla/service/gpu/BUILD | 10 +++---- .../xla/xla/service/gpu/autotuning/BUILD | 9 +++--- third_party/xla/xla/service/gpu/kernels/BUILD | 28 +++++++++---------- third_party/xla/xla/service/gpu/tests/BUILD | 14 +++++----- .../xla/xla/service/gpu/transforms/BUILD | 4 +-- .../xla/xla/stream_executor/build_defs.bzl | 2 +- .../xla/xla/stream_executor/cuda/BUILD | 10 +++---- third_party/xla/xla/stream_executor/gpu/BUILD | 2 +- third_party/xla/xla/tests/BUILD | 10 +++---- third_party/xla/xla/tests/build_defs.bzl | 4 +-- third_party/xla/xla/tests/fuzz/BUILD | 2 +- third_party/xla/xla/tools/BUILD | 2 +- third_party/xla/xla/tools/hlo_opt/BUILD | 2 +- .../xla/xla/tools/multihost_hlo_runner/BUILD | 2 +- 25 files changed, 88 insertions(+), 80 deletions(-) diff --git a/third_party/xla/.github/workflows/bazel_dependency_violations.yml b/third_party/xla/.github/workflows/bazel_dependency_violations.yml index 988a84fed8a457..43c85576c2ba38 100644 --- a/third_party/xla/.github/workflows/bazel_dependency_violations.yml +++ b/third_party/xla/.github/workflows/bazel_dependency_violations.yml @@ -29,7 +29,7 @@ jobs: dependency-violations: strategy: matrix: - tag: [gpu, no_rocm] + tag: [gpu, cuda-only] name: no-${{ matrix.tag }}-targets-in-cpu-build runs-on: ubuntu-22.04 defaults: @@ -45,7 +45,9 @@ jobs: - name: "Run bazel cquery ... //xla/..." run: | set -euo pipefail - OUTPUT=$(bazelisk cquery --aspects build_tools/dependencies/aspects.bzl%validate_${{ matrix.tag }}_tag //xla/... 2>&1) + TAG_WITH_UNDERSCORES="${{ matrix.tag }}" + TAG_WITH_UNDERSCORES="${TAG_WITH_UNDERSCORES/-/_}" + OUTPUT=$(bazelisk cquery --aspects build_tools/dependencies/aspects.bzl%validate_${TAG_WITH_UNDERSCORES}_tag //xla/... 2>&1) if echo "$OUTPUT" | grep 'Violation' >/dev/null; then echo "The following dependency violations were found:" echo "$OUTPUT" | grep 'Violation' | sed -e 's/^.*\[Violation\]/ -/' diff --git a/third_party/xla/build_tools/configure/configure.py b/third_party/xla/build_tools/configure/configure.py index 43e0f234d49cfd..508307bb899953 100755 --- a/third_party/xla/build_tools/configure/configure.py +++ b/third_party/xla/build_tools/configure/configure.py @@ -299,6 +299,9 @@ def to_bazelrc_lines( build_and_test_tag_filters.append("-gpu") elif self.backend == Backend.CUDA: + build_and_test_tag_filters.append("-rocm-only") + build_and_test_tag_filters.append("-sycl-only") + compiler_pair = self.cuda_compiler, self.host_compiler if compiler_pair == (CudaCompiler.CLANG, HostCompiler.CLANG): @@ -347,8 +350,12 @@ def to_bazelrc_lines( if not self.using_nccl: rc.append("build --config nonccl") elif self.backend == Backend.ROCM: - pass + build_and_test_tag_filters.append("-cuda-only") + build_and_test_tag_filters.append("-sycl-only") elif self.backend == Backend.SYCL: + build_and_test_tag_filters.append("-cuda-only") + build_and_test_tag_filters.append("-rocm-only") + rc.append("build --config sycl") # Lines that are added for every backend diff --git a/third_party/xla/build_tools/configure/testdata/cuda_clang.bazelrc b/third_party/xla/build_tools/configure/testdata/cuda_clang.bazelrc index 502bc8541c1285..3f42ca9e563aa2 100644 --- a/third_party/xla/build_tools/configure/testdata/cuda_clang.bazelrc +++ b/third_party/xla/build_tools/configure/testdata/cuda_clang.bazelrc @@ -15,7 +15,7 @@ test --test_size_filters small,medium build --copt -Wno-sign-compare build --copt -Wno-error=unused-command-line-argument build --copt -Wno-gnu-offsetof-extensions -build --build_tag_filters -no_oss -build --test_tag_filters -no_oss -test --build_tag_filters -no_oss -test --test_tag_filters -no_oss +build --build_tag_filters -no_oss,-rocm-only,-sycl-only +build --test_tag_filters -no_oss,-rocm-only,-sycl-only +test --build_tag_filters -no_oss,-rocm-only,-sycl-only +test --test_tag_filters -no_oss,-rocm-only,-sycl-only diff --git a/third_party/xla/build_tools/configure/testdata/default_cuda_clang.bazelrc b/third_party/xla/build_tools/configure/testdata/default_cuda_clang.bazelrc index 4623f6f52073fa..04b79c87aed6ab 100644 --- a/third_party/xla/build_tools/configure/testdata/default_cuda_clang.bazelrc +++ b/third_party/xla/build_tools/configure/testdata/default_cuda_clang.bazelrc @@ -13,7 +13,7 @@ test --test_size_filters small,medium build --copt -Wno-sign-compare build --copt -Wno-error=unused-command-line-argument build --copt -Wno-gnu-offsetof-extensions -build --build_tag_filters -no_oss -build --test_tag_filters -no_oss -test --build_tag_filters -no_oss -test --test_tag_filters -no_oss +build --build_tag_filters -no_oss,-rocm-only,-sycl-only +build --test_tag_filters -no_oss,-rocm-only,-sycl-only +test --build_tag_filters -no_oss,-rocm-only,-sycl-only +test --test_tag_filters -no_oss,-rocm-only,-sycl-only diff --git a/third_party/xla/build_tools/configure/testdata/nvcc_clang.bazelrc b/third_party/xla/build_tools/configure/testdata/nvcc_clang.bazelrc index 8cd19224698311..b56236998fe166 100644 --- a/third_party/xla/build_tools/configure/testdata/nvcc_clang.bazelrc +++ b/third_party/xla/build_tools/configure/testdata/nvcc_clang.bazelrc @@ -15,7 +15,7 @@ test --test_size_filters small,medium build --copt -Wno-sign-compare build --copt -Wno-error=unused-command-line-argument build --copt -Wno-gnu-offsetof-extensions -build --build_tag_filters -no_oss -build --test_tag_filters -no_oss -test --build_tag_filters -no_oss -test --test_tag_filters -no_oss +build --build_tag_filters -no_oss,-rocm-only,-sycl-only +build --test_tag_filters -no_oss,-rocm-only,-sycl-only +test --build_tag_filters -no_oss,-rocm-only,-sycl-only +test --test_tag_filters -no_oss,-rocm-only,-sycl-only diff --git a/third_party/xla/build_tools/configure/testdata/nvcc_gcc.bazelrc b/third_party/xla/build_tools/configure/testdata/nvcc_gcc.bazelrc index be90a87545368b..f4d4f72c566e7f 100644 --- a/third_party/xla/build_tools/configure/testdata/nvcc_gcc.bazelrc +++ b/third_party/xla/build_tools/configure/testdata/nvcc_gcc.bazelrc @@ -10,7 +10,7 @@ build --python_path /usr/bin/python3 test --test_env LD_LIBRARY_PATH test --test_size_filters small,medium build --copt -Wno-sign-compare -build --build_tag_filters -no_oss -build --test_tag_filters -no_oss -test --build_tag_filters -no_oss -test --test_tag_filters -no_oss +build --build_tag_filters -no_oss,-rocm-only,-sycl-only +build --test_tag_filters -no_oss,-rocm-only,-sycl-only +test --build_tag_filters -no_oss,-rocm-only,-sycl-only +test --test_tag_filters -no_oss,-rocm-only,-sycl-only diff --git a/third_party/xla/build_tools/dependencies/aspects.bzl b/third_party/xla/build_tools/dependencies/aspects.bzl index 76b531112fdee6..40ea09beca11da 100644 --- a/third_party/xla/build_tools/dependencies/aspects.bzl +++ b/third_party/xla/build_tools/dependencies/aspects.bzl @@ -73,10 +73,10 @@ validate_gpu_tag = aspect( attr_aspects = ["deps"], ) -def _no_rocm_tag_violation_aspect_impl(target, ctx): - return _dependency_violation_aspect_impl(target, ctx, "no_rocm") +def _cuda_only_tag_violation_aspect_impl(target, ctx): + return _dependency_violation_aspect_impl(target, ctx, "cuda-only") -validate_no_rocm_tag = aspect( - implementation = _no_rocm_tag_violation_aspect_impl, +validate_cuda_only_tag = aspect( + implementation = _cuda_only_tag_violation_aspect_impl, attr_aspects = ["deps"], ) diff --git a/third_party/xla/build_tools/lint/tags.py b/third_party/xla/build_tools/lint/tags.py index 2ec82cc0113e65..aa555e7ddf63e1 100644 --- a/third_party/xla/build_tools/lint/tags.py +++ b/third_party/xla/build_tools/lint/tags.py @@ -32,7 +32,6 @@ "large": "Conventional tag for `test_suites` of large tests", # Various disable tags (currently recognized by OpenXLA CI) "no_oss": "Test is disabled on OpenXLA CI.", - "no_rocm": "Disabled on ROCm builds.", "no_mac": "Disabled on MacOS.", "no_windows": "Disabled on Windows.", "no_mac_arm64": "Disabled on ARM MacOS.", @@ -65,6 +64,7 @@ "requires-gpu-sm90-only": "Requires exactly sm90.", "gpu": "Catch-all tag for targets that should be built/tested on GPU CI", "cpu": "Catch-all tag for targets that should be built/tested on CPU CI.", + "cuda-only": "Targets that require the CUDA backend to be enabled.", # Below tags are generated by `xla_test`. "broken": "Test will be marked with other tags to disable in `xla_test`.", "xla_interpreter": "Uses interpreter backend.", diff --git a/third_party/xla/build_tools/rocm/run_xla.sh b/third_party/xla/build_tools/rocm/run_xla.sh index d7eee422ec01db..23cc801fc260f0 100755 --- a/third_party/xla/build_tools/rocm/run_xla.sh +++ b/third_party/xla/build_tools/rocm/run_xla.sh @@ -50,7 +50,7 @@ fi export PYTHON_BIN_PATH=`which python3` export TF_NEED_ROCM=1 export ROCM_PATH=$ROCM_INSTALL_DIR -TAGS_FILTER="gpu,requires-gpu-amd,-requires-gpu-nvidia,-no_oss,-oss_excluded,-oss_serial,-no_gpu,-no_rocm" +TAGS_FILTER="gpu,requires-gpu-amd,-requires-gpu-nvidia,-no_oss,-oss_excluded,-oss_serial,-no_gpu,-cuda-only" UNSUPPORTED_GPU_TAGS="$(echo -requires-gpu-sm{60,70,80,86,89,90}{,-only})" TAGS_FILTER="${TAGS_FILTER},${UNSUPPORTED_GPU_TAGS// /,}" diff --git a/third_party/xla/xla/experiments/sm_bandwidth_benchmark/BUILD b/third_party/xla/xla/experiments/sm_bandwidth_benchmark/BUILD index 8583582c5690c4..4218993bd106d5 100644 --- a/third_party/xla/xla/experiments/sm_bandwidth_benchmark/BUILD +++ b/third_party/xla/xla/experiments/sm_bandwidth_benchmark/BUILD @@ -7,8 +7,8 @@ cc_library( name = "sm_bw_utils", hdrs = ["sm_bw_utils.h"], tags = [ + "cuda-only", "gpu", - "no_rocm", ], deps = [ "@local_config_cuda//cuda:cuda_headers", @@ -20,7 +20,7 @@ cuda_library( name = "sm_bw_kernels", srcs = ["sm_bw_kernels.cu.cc"], hdrs = ["sm_bw_kernels.h"], - tags = ["no_rocm"], + tags = ["cuda-only"], deps = [ ":sm_bw_utils", ], @@ -30,8 +30,8 @@ xla_cc_test( name = "sm_bw_test", srcs = ["sm_bw_test.cc"], tags = [ + "cuda-only", "gpu", - "no_rocm", "requires-gpu-nvidia", ], deps = [ diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD index 1165e6b19c081c..53d6b29fc5bcd7 100644 --- a/third_party/xla/xla/service/BUILD +++ b/third_party/xla/xla/service/BUILD @@ -8363,9 +8363,9 @@ xla_cc_test( ":xla_aot_compile_test_gpu_executable_convolution_runtime_autotuning", ]), tags = [ + "cuda-only", "gpu", "no_oss", - "no_rocm", "nomsan", # Pulls in precompiled NVIDIA libraries which cause false positives in msan. "requires-gpu-sm60-only", ], diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index 90dcee55408629..8621ff79c54fff 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -1695,9 +1695,9 @@ cc_library( "nvptx_compiler_registration.cc", ], tags = [ + "cuda-only", "gpu", "manual", - "no_rocm", ], deps = [ ":nvptx_compiler_impl", @@ -1717,9 +1717,9 @@ cc_library( "nvptx_compiler.h", ], tags = [ + "cuda-only", "gpu", "manual", - "no_rocm", ], deps = [ ":buffer_sharing", @@ -1829,7 +1829,7 @@ xla_test( "gpu_a100", ], tags = [ - "no_rocm", + "cuda-only", "nomsan", # Pulls in precompiled NVIDIA libraries which cause false positives in msan. ], deps = [ @@ -1867,7 +1867,7 @@ xla_test( "gpu", ], tags = [ - "no_rocm", + "cuda-only", "nomsan", # Pulls in precompiled NVIDIA libraries which cause false positives in msan. ], deps = [ @@ -2261,7 +2261,7 @@ xla_cc_test( cuda_library( name = "stream_executor_util_kernel", srcs = ["stream_executor_util_kernel.cu.cc"], - tags = ["no_rocm"], + tags = ["cuda-only"], deps = ["@local_config_cuda//cuda:cuda_headers"], ) diff --git a/third_party/xla/xla/service/gpu/autotuning/BUILD b/third_party/xla/xla/service/gpu/autotuning/BUILD index be63f3888442af..88f55c49da7e2f 100644 --- a/third_party/xla/xla/service/gpu/autotuning/BUILD +++ b/third_party/xla/xla/service/gpu/autotuning/BUILD @@ -30,8 +30,8 @@ cc_library( srcs = ["gemm_fusion_autotuner.cc"], hdrs = ["gemm_fusion_autotuner.h"], tags = [ + "cuda-only", "gpu", - "no_rocm", ], deps = [ ":autotuner_compile_util", @@ -109,8 +109,8 @@ xla_test( "gpu_h100", ], tags = [ + "cuda-only", "no_mac", - "no_rocm", ], deps = [ ":autotuner_compile_util", @@ -398,7 +398,7 @@ xla_test( "gpu_amd_any", ], tags = [ - "no_rocm", + "cuda-only", "noasan", "nomsan", ], @@ -473,7 +473,7 @@ xla_test( backends = [ "gpu", ], - tags = ["no_rocm"], + tags = ["cuda-only"], deps = [ ":autotuner_util", ":custom_kernel_fusion_autotuner", @@ -509,7 +509,6 @@ xla_cc_test( ], tags = [ "gpu", - "no_rocm", ], deps = [ ":autotuner_util", diff --git a/third_party/xla/xla/service/gpu/kernels/BUILD b/third_party/xla/xla/service/gpu/kernels/BUILD index 6676cafcc20dd1..986ed6486dc821 100644 --- a/third_party/xla/xla/service/gpu/kernels/BUILD +++ b/third_party/xla/xla/service/gpu/kernels/BUILD @@ -82,8 +82,8 @@ cc_library( srcs = ["cutlass_gemm_fusion.cc"], hdrs = ["cutlass_gemm_fusion.h"], tags = [ + "cuda-only", "gpu", - "no_rocm", ], deps = [ ":custom_kernel", @@ -115,7 +115,7 @@ xla_test( backends = ["gpu"], # TODO(b/332820384): Enable when it passes on H100. disabled_backends = DEFAULT_DISABLED_BACKENDS + ["gpu_h100"], - tags = ["no_rocm"], + tags = ["cuda-only"], deps = [ ":custom_kernel_fusion_pattern", ":cutlass_gemm_custom_kernel", @@ -264,8 +264,8 @@ cc_library( srcs = ["cutlass_gemm_custom_kernel.cc"], hdrs = ["cutlass_gemm_custom_kernel.h"], tags = [ + "cuda-only", "gpu", - "no_rocm", ], deps = [ ":custom_kernel", @@ -288,7 +288,7 @@ xla_test( srcs = ["cutlass_gemm_custom_kernel_test.cc"], backends = ["gpu"], data = [":cutlass_gemm_kernel_f32xf32_to_f32.so"], - tags = ["no_rocm"], + tags = ["cuda-only"], deps = [ ":cutlass_gemm_custom_kernel", "//xla:xla_data_proto_cc", @@ -309,8 +309,8 @@ xla_cc_binary( testonly = 1, srcs = ["cutlass_gemm_custom_kernel_benchmarks.cc"], tags = [ + "cuda-only", "gpu", - "no_rocm", ], deps = [ ":cutlass_gemm_custom_kernel", @@ -346,7 +346,7 @@ cuda_library( [], ["-Wno-unknown-attributes"], ), # __grid_constant__ is not supported by clang - tags = ["no_rocm"], + tags = ["cuda-only"], deps = [ ":cutlass_gemm", "@cutlass_archive//:cutlass", @@ -355,7 +355,7 @@ cuda_library( cuda_library( name = "cutlass_gemm_epilogue", - tags = ["no_rocm"], + tags = ["cuda-only"], # TODO(ezhulenev): Update to regular hdrs after fixing CUTLASS headers. textual_hdrs = ["cutlass_gemm_epilogue.cu.h"], deps = ["@cutlass_archive//:cutlass"], @@ -371,8 +371,8 @@ cuda_library( cc_library( name = "cutlass_gemm_kernels", tags = [ + "cuda-only", "gpu", - "no_rocm", ], deps = [ ":cutlass_gemm_kernel_bf16xbf16_to_bf16", @@ -399,7 +399,7 @@ cuda_library( [], ["-Wno-unknown-attributes"], ), - tags = ["no_rocm"], + tags = ["cuda-only"], deps = [ ":cutlass_gemm_adaptor", "@cutlass_archive//:cutlass", @@ -414,7 +414,7 @@ cuda_library( [], ["-Wno-unknown-attributes"], ), - tags = ["no_rocm"], + tags = ["cuda-only"], deps = [ ":cutlass_gemm_adaptor", "@cutlass_archive//:cutlass", @@ -432,7 +432,7 @@ cuda_library( [], ["-Wno-unknown-attributes"], ), - tags = ["no_rocm"], + tags = ["cuda-only"], deps = [ ":cutlass_gemm_adaptor", "@cutlass_archive//:cutlass", @@ -450,7 +450,7 @@ cuda_library( [], ["-Wno-unknown-attributes"], ), - tags = ["no_rocm"], + tags = ["cuda-only"], deps = [ ":cutlass_gemm_adaptor", "@cutlass_archive//:cutlass", @@ -463,8 +463,8 @@ cuda_library( srcs = ["cutlass_gemm_kernel_bf16xs8_to_f32.cu.cc"], copts = ["-Wno-unknown-attributes -mllvm -unroll-threshold=100000"], tags = [ + "cuda-only", "gpu", - "no_rocm", ], deps = [ ":cutlass_gemm_adaptor", @@ -483,8 +483,8 @@ cc_binary( linkshared = True, linkstatic = False, tags = [ + "cuda-only", "gpu", - "no_rocm", ], deps = [":cutlass_gemm"], ) diff --git a/third_party/xla/xla/service/gpu/tests/BUILD b/third_party/xla/xla/service/gpu/tests/BUILD index 2f0d8b366f8c0e..c8632c377da2d7 100644 --- a/third_party/xla/xla/service/gpu/tests/BUILD +++ b/third_party/xla/xla/service/gpu/tests/BUILD @@ -206,7 +206,7 @@ xla_test( "swap_conv_operands_test.cc", ], backends = ["gpu"], - tags = ["no_rocm"], + tags = ["cuda-only"], deps = [ ":gpu_codegen_test", "//xla:error_spec", @@ -674,11 +674,11 @@ lit_test_suite( default_tags = tf_cuda_tests_tags(), hermetic_cuda_data_dir = "%S/../../../../../cuda_nvcc", tags_override = { - "element_wise_row_vectorization.hlo": ["no_rocm"], - "scatter_bf16.hlo": ["no_rocm"], - "single_instruction.hlo": ["no_rocm"], - "reduce_unnested.hlo": ["no_rocm"], - "reduction_vectorization_sm_all.hlo": ["no_rocm"], + "element_wise_row_vectorization.hlo": ["cuda-only"], + "scatter_bf16.hlo": ["cuda-only"], + "single_instruction.hlo": ["cuda-only"], + "reduce_unnested.hlo": ["cuda-only"], + "reduction_vectorization_sm_all.hlo": ["cuda-only"], }, tools = [ "//xla/tools:hlo-opt", @@ -789,7 +789,7 @@ xla_test( "gpu_a100", "gpu_h100", ], - tags = ["no_rocm"], + tags = ["cuda-only"], deps = if_cuda_is_configured( [ ":gpu_codegen_test", diff --git a/third_party/xla/xla/service/gpu/transforms/BUILD b/third_party/xla/xla/service/gpu/transforms/BUILD index d59bd94e726f52..31c99ac6984708 100644 --- a/third_party/xla/xla/service/gpu/transforms/BUILD +++ b/third_party/xla/xla/service/gpu/transforms/BUILD @@ -869,7 +869,7 @@ xla_test( "gpu", ], local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), - tags = ["no_rocm"], + tags = ["cuda-only"], deps = [ ":cudnn_fused_mha_rewriter", ":cudnn_fused_mha_transpose_fusion", @@ -1458,8 +1458,8 @@ xla_cc_test( name = "dynamic_slice_fusion_rewriter_test", srcs = ["dynamic_slice_fusion_rewriter_test.cc"], tags = [ + "cuda-only", "gpu", - "no_rocm", ], deps = [ ":dynamic_slice_fusion_rewriter", diff --git a/third_party/xla/xla/stream_executor/build_defs.bzl b/third_party/xla/xla/stream_executor/build_defs.bzl index 959de0fa172871..109872e2a0f0df 100644 --- a/third_party/xla/xla/stream_executor/build_defs.bzl +++ b/third_party/xla/xla/stream_executor/build_defs.bzl @@ -81,7 +81,7 @@ def cuda_only_cc_library(name, tags = [], **kwargs): ) cc_library( name = "%s_cuda_only" % name, - tags = tags + ["manual", "no_rocm"], + tags = tags + ["manual", "cuda-only"], **kwargs ) native.alias( diff --git a/third_party/xla/xla/stream_executor/cuda/BUILD b/third_party/xla/xla/stream_executor/cuda/BUILD index 6ff53527b930f4..0e4b26fbcb9bd1 100644 --- a/third_party/xla/xla/stream_executor/cuda/BUILD +++ b/third_party/xla/xla/stream_executor/cuda/BUILD @@ -250,7 +250,7 @@ xla_test( name = "cuda_driver_test", srcs = ["cuda_driver_test.cc"], backends = ["gpu"], - tags = ["no_rocm"], + tags = ["cuda-only"], deps = [ ":cuda_diagnostics", ":cuda_driver", @@ -492,7 +492,7 @@ cuda_library( "command_buffer_kernels.cc", "command_buffer_kernels.cu.cc", ], - tags = ["no_rocm"], + tags = ["cuda-only"], deps = [ "//xla/stream_executor:kernel_spec", "//xla/stream_executor/gpu:gpu_types_header", @@ -584,7 +584,7 @@ cc_library( cc_library( name = "ptx_compiler", hdrs = ["ptx_compiler.h"], - tags = ["no_rocm"], + tags = ["cuda-only"], deps = select({ ":libnvptxcompiler_support_enabled": [":ptx_compiler_impl"], "//conditions:default": [":ptx_compiler_stub"], @@ -599,7 +599,7 @@ xla_test( name = "cuda_platform_test", srcs = ["cuda_platform_test.cc"], backends = ["gpu"], - tags = ["no_rocm"], + tags = ["cuda-only"], deps = [ ":cuda_platform", "//xla/stream_executor:platform", @@ -616,7 +616,7 @@ xla_cc_test( name = "ptx_compiler_test", srcs = ["ptx_compiler_test.cc"], tags = [ - "no_rocm", + "cuda-only", # TODO(b/343996893): Figure out whether msan reports a false positive or not. "nomsan", ], diff --git a/third_party/xla/xla/stream_executor/gpu/BUILD b/third_party/xla/xla/stream_executor/gpu/BUILD index c5c4a3fbb1ba75..b123c2fc40ac38 100644 --- a/third_party/xla/xla/stream_executor/gpu/BUILD +++ b/third_party/xla/xla/stream_executor/gpu/BUILD @@ -645,7 +645,7 @@ xla_test( name = "gpu_cudamallocasync_allocator_test", srcs = ["gpu_cudamallocasync_allocator_test.cc"], backends = ["gpu_any"], - tags = ["no_rocm"], + tags = ["cuda-only"], deps = [ ":gpu_cudamallocasync_allocator", ":gpu_stream", diff --git a/third_party/xla/xla/tests/BUILD b/third_party/xla/xla/tests/BUILD index 18d23e96d1cc29..c7dbbee3efe6a7 100644 --- a/third_party/xla/xla/tests/BUILD +++ b/third_party/xla/xla/tests/BUILD @@ -1216,7 +1216,7 @@ xla_test( ], shard_count = 50, tags = [ - "no_rocm", + "cuda-only", "optonly", ], deps = CONVOLUTION_TEST_DEPS + [ @@ -1235,7 +1235,7 @@ xla_test( ], shard_count = 50, tags = [ - "no_rocm", + "cuda-only", "optonly", "test_xla_cpu_thunks", ], @@ -1256,7 +1256,7 @@ xla_test( backends = ["gpu"], shard_count = 40, tags = [ - "no_rocm", + "cuda-only", "optonly", ], deps = CONVOLUTION_TEST_DEPS + [ @@ -1273,7 +1273,7 @@ xla_test( backends = ["gpu"], shard_count = 40, tags = [ - "no_rocm", + "cuda-only", "optonly", ], deps = CONVOLUTION_TEST_DEPS + [ @@ -1302,7 +1302,7 @@ xla_test( backend_args = {"gpu": ["--xla_backend_extra_options=xla_gpu_experimental_conv_disable_layout_heuristic"]}, backends = ["gpu"], shard_count = 25, - tags = ["no_rocm"], + tags = ["cuda-only"], deps = CONVOLUTION_TEST_DEPS + [ "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", diff --git a/third_party/xla/xla/tests/build_defs.bzl b/third_party/xla/xla/tests/build_defs.bzl index 990daa1423aeaa..006811c7b322be 100644 --- a/third_party/xla/xla/tests/build_defs.bzl +++ b/third_party/xla/xla/tests/build_defs.bzl @@ -97,7 +97,7 @@ def prepare_nvidia_gpu_backend_data(backends, disabled_backends, backend_tags, b sm_tag += ":%d" % num_gpus new_backend_tags[gpu_backend] = [t for t in all_tags if t not in requires_gpu] new_backend_tags[gpu_backend].append(sm_tag) - new_backend_tags[gpu_backend].append("no_rocm") + new_backend_tags[gpu_backend].append("cuda-only") return new_backends, new_disabled_backends, new_backend_tags, new_backend_args @@ -130,7 +130,7 @@ def prepare_amd_gpu_backend_data(backends, disabled_backends, backend_tags, back new_backend_tags.setdefault(key, gpu_backend_tags[:]) for backend in AMD_GPU_DEFAULT_BACKENDS: - if "no_rocm" not in gpu_backend_tags: + if "cuda-only" not in gpu_backend_tags: new_backend_tags[backend].append("requires-gpu-amd") new_backend_tags[backend].append("notap") diff --git a/third_party/xla/xla/tests/fuzz/BUILD b/third_party/xla/xla/tests/fuzz/BUILD index fa5dde0ff1c3df..dbd927dc7fd3e8 100644 --- a/third_party/xla/xla/tests/fuzz/BUILD +++ b/third_party/xla/xla/tests/fuzz/BUILD @@ -16,7 +16,7 @@ cc_library( [hlo_test( name = hlo + "_test", hlo = hlo, - tags = (["no_rocm"] if hlo == "rand_000079.hlo" else []), # No int8 + tags = (["cuda-only"] if hlo == "rand_000079.hlo" else []), # No int8 ) for hlo in glob( include = ["rand_*.hlo"], exclude = [ diff --git a/third_party/xla/xla/tools/BUILD b/third_party/xla/xla/tools/BUILD index acc2b000972189..06c37dddab3401 100644 --- a/third_party/xla/xla/tools/BUILD +++ b/third_party/xla/xla/tools/BUILD @@ -879,7 +879,7 @@ xla_test( ], tags = [ "config-cuda-only", - "no_rocm", + "cuda-only", ], deps = [ ":xla_compile_lib", diff --git a/third_party/xla/xla/tools/hlo_opt/BUILD b/third_party/xla/xla/tools/hlo_opt/BUILD index d9cefa0eddad8a..79d513bae551b5 100644 --- a/third_party/xla/xla/tools/hlo_opt/BUILD +++ b/third_party/xla/xla/tools/hlo_opt/BUILD @@ -177,7 +177,7 @@ lit_test_suite( default_tags = tf_cuda_tests_tags(), hermetic_cuda_data_dir = "%S/../../../../cuda_nvcc", tags_override = { - "gpu_hlo_ptx.hlo": ["no_rocm"], + "gpu_hlo_ptx.hlo": ["cuda-only"], }, tools = [ "//xla/tools:hlo-opt", diff --git a/third_party/xla/xla/tools/multihost_hlo_runner/BUILD b/third_party/xla/xla/tools/multihost_hlo_runner/BUILD index 4578fcef45e177..6eb573004e2584 100644 --- a/third_party/xla/xla/tools/multihost_hlo_runner/BUILD +++ b/third_party/xla/xla/tools/multihost_hlo_runner/BUILD @@ -83,9 +83,9 @@ xla_cc_binary( name = "hlo_runner_main_gpu", testonly = True, tags = [ + "cuda-only", "gpu", "no_mac", - "no_rocm", ] + tf_gpu_tests_tags(), deps = [ ":hlo_runner_main_lib", From 0a1297f2dc1d7f7e41e950d39f207a877b8445fb Mon Sep 17 00:00:00 2001 From: Siqiao Wu Date: Mon, 23 Sep 2024 00:18:03 -0700 Subject: [PATCH 122/483] Internal changes only PiperOrigin-RevId: 677661687 --- .../tfrt/transforms/ifrt/ifrt_backend_compiler.cc | 8 +++++++- .../tfrt/transforms/ifrt/ifrt_backend_compiler.h | 14 ++++++++++++++ 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_backend_compiler.cc b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_backend_compiler.cc index 4c77fc0d42e4e1..514322b7cc1dd5 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_backend_compiler.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_backend_compiler.cc @@ -153,10 +153,16 @@ absl::Status IfrtBackendCompiler::CompileTensorflow( tensorflow::DumpMlirOpToFile("ifrt_tpu_bct_conversion_before", module); } + TfrtTpuCompileOptions options; + options.disable_set_default_tpu_device_and_device_assignment_attributes = + compile_options_ + .disable_set_default_tpu_device_and_device_assignment_attributes; + options.support_multi_dims_sharding = true; + if (tpu_compiler_ != nullptr) { // Run backward compat pass so that we can use bridge to do clustering. if (mlir::failed( - tpu_compiler_->RunTPUBackwardCompatConversion(module, {}))) { + tpu_compiler_->RunTPUBackwardCompatConversion(module, options))) { return diag_handler.Combine( absl::InternalError("Failed to handle legacy TPU Ops")); } diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_backend_compiler.h b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_backend_compiler.h index 085c70812feaed..0dfaa081822862 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_backend_compiler.h +++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_backend_compiler.h @@ -28,9 +28,22 @@ namespace ifrt_serving { // Implements the custom backend compiler for IFRT based serving in TFRT. class IfrtBackendCompiler : public tensorflow::BackendCompiler { public: + struct Options { + // If true, disable running TFRTSetTPUDeviceAttrPass which set the default + // `tf.device` and `device_assignment` attributes. + // This is a server-level option for now. We can consider to make it a + // per-model option in the future. + bool disable_set_default_tpu_device_and_device_assignment_attributes = true; + }; + explicit IfrtBackendCompiler(TpuCompiler* tpu_compiler = nullptr) : tpu_compiler_(tpu_compiler) {} + explicit IfrtBackendCompiler(const Options& ifrt_backend_compile_options, + TpuCompiler* tpu_compiler = nullptr) + : tpu_compiler_(tpu_compiler), + compile_options_(ifrt_backend_compile_options) {} + void GetDependentDialects(mlir::DialectRegistry& registry) const override { if (tpu_compiler_) { tpu_compiler_->RegisterTPUDialects(®istry); @@ -45,6 +58,7 @@ class IfrtBackendCompiler : public tensorflow::BackendCompiler { private: TpuCompiler* tpu_compiler_; // Not owned. + Options compile_options_; }; } // namespace ifrt_serving From 15244d58bfc59ace7e2d78f9945f911b0f23eecb Mon Sep 17 00:00:00 2001 From: Alexander Pivovarov Date: Mon, 23 Sep 2024 00:29:30 -0700 Subject: [PATCH 123/483] PR #17476: Fix chlo_legalize_to_mhlo.mlir.test by using CHECK-DAG Imported from GitHub PR https://github.com/openxla/xla/pull/17476 To handle situations where `mhlo.multiply` and `mhlo.subtract` can run in any order, we need to make the test more flexible by using `CHECK-DAG` instead of `CHECK`. ### Testing ```bash bazel test //xla/mlir_hlo/tests:Dialect/chlo/chlo_legalize_to_mhlo.mlir.test ``` Copybara import of the project: -- 18b4bcc03560c59a572eba5c3c4d11b0a169e109 by Alexander Pivovarov : Fix chlo_legalize_to_mhlo.mlir.test by using CHECK-DAG Merging this change closes #17476 PiperOrigin-RevId: 677664963 --- .../Dialect/chlo/chlo_legalize_to_mhlo.mlir | 26 +++++++++---------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/chlo/chlo_legalize_to_mhlo.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/chlo/chlo_legalize_to_mhlo.mlir index ddd8348641cb30..b47a159ff6d654 100644 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/chlo/chlo_legalize_to_mhlo.mlir +++ b/third_party/xla/xla/mlir_hlo/tests/Dialect/chlo/chlo_legalize_to_mhlo.mlir @@ -1668,8 +1668,8 @@ func.func @zeta_f16(%arg0: tensor, %arg1: tensor) -> tensor { // CHECK: %[[TMP_33:.*]] = mhlo.add %[[TMP_30]], %[[TMP_3]] // CHECK: %[[TMP_34:.*]] = mhlo.power %[[TMP_33]], %[[TMP_4]] // CHECK: %[[TMP_35:.*]] = mhlo.constant dense<1.000000e+00> - // CHECK: %[[TMP_36:.*]] = mhlo.multiply %[[TMP_34]], %[[TMP_33]] - // CHECK: %[[TMP_37:.*]] = mhlo.subtract %[[TMP_0]], %[[TMP_35]] + // CHECK-DAG: %[[TMP_36:.*]] = mhlo.multiply %[[TMP_34]], %[[TMP_33]] + // CHECK-DAG: %[[TMP_37:.*]] = mhlo.subtract %[[TMP_0]], %[[TMP_35]] // CHECK: %[[TMP_38:.*]] = mhlo.divide %[[TMP_36]], %[[TMP_37]] // CHECK: %[[TMP_39:.*]] = mhlo.multiply %[[TMP_33]], %[[TMP_33]] // CHECK: %[[TMP_40:.*]] = mhlo.divide %[[TMP_3]], %[[TMP_39]] @@ -1943,8 +1943,8 @@ func.func @polygamma_f32(%lhs : tensor, %rhs : tensor) -> tensor // CHECK: %[[TMP_121:.*]] = mhlo.add %[[TMP_118]], %[[TMP_91]] // CHECK: %[[TMP_122:.*]] = mhlo.power %[[TMP_121]], %[[TMP_92]] // CHECK: %[[TMP_123:.*]] = mhlo.constant dense<1.000000e+00> - // CHECK: %[[TMP_124:.*]] = mhlo.multiply %[[TMP_122]], %[[TMP_121]] - // CHECK: %[[TMP_125:.*]] = mhlo.subtract %[[TMP_5]], %[[TMP_123]] + // CHECK-DAG: %[[TMP_124:.*]] = mhlo.multiply %[[TMP_122]], %[[TMP_121]] + // CHECK-DAG: %[[TMP_125:.*]] = mhlo.subtract %[[TMP_5]], %[[TMP_123]] // CHECK: %[[TMP_126:.*]] = mhlo.divide %[[TMP_124]], %[[TMP_125]] // CHECK: %[[TMP_127:.*]] = mhlo.multiply %[[TMP_121]], %[[TMP_121]] // CHECK: %[[TMP_128:.*]] = mhlo.divide %[[TMP_91]], %[[TMP_127]] @@ -2330,8 +2330,8 @@ func.func @polygamma_f64(%lhs : tensor, %rhs : tensor) -> tensor // CHECK: %[[TMP_121:.*]] = mhlo.add %[[TMP_118]], %[[TMP_91]] // CHECK: %[[TMP_122:.*]] = mhlo.power %[[TMP_121]], %[[TMP_92]] // CHECK: %[[TMP_123:.*]] = mhlo.constant dense<1.000000e+00> - // CHECK: %[[TMP_124:.*]] = mhlo.multiply %[[TMP_122]], %[[TMP_121]] - // CHECK: %[[TMP_125:.*]] = mhlo.subtract %[[TMP_5]], %[[TMP_123]] + // CHECK-DAG: %[[TMP_124:.*]] = mhlo.multiply %[[TMP_122]], %[[TMP_121]] + // CHECK-DAG: %[[TMP_125:.*]] = mhlo.subtract %[[TMP_5]], %[[TMP_123]] // CHECK: %[[TMP_126:.*]] = mhlo.divide %[[TMP_124]], %[[TMP_125]] // CHECK: %[[TMP_127:.*]] = mhlo.multiply %[[TMP_121]], %[[TMP_121]] // CHECK: %[[TMP_128:.*]] = mhlo.divide %[[TMP_91]], %[[TMP_127]] @@ -2754,13 +2754,13 @@ func.func @atanh_complex_f32(%arg : tensor>) -> tensor // CHECK-NEXT: %[[ABS_IMAG:.*]] = mhlo.abs %[[IMAG]] // CHECK-NEXT: %[[CMP4:.*]] = mhlo.compare LT, %[[ABS_IMAG]], %[[SQUARE]] // CHECK-NEXT: %[[AND:.*]] = mhlo.and %[[CMP3]], %[[CMP4]] - // CHECK-NEXT: %[[SUB0:.*]] = mhlo.subtract %[[ONE]], %[[ABS_REAL]] - // CHECK-NEXT: %[[SQUARE1:.*]] = mhlo.multiply %[[SUB0]], %[[SUB0]] - // CHECK-NEXT: %[[SQUARE2:.*]] = mhlo.multiply %[[IMAG]], %[[IMAG]] - // CHECK-NEXT: %[[ADD0:.*]] = mhlo.add %[[SQUARE1]], %[[SQUARE2]] - // CHECK-NEXT: %[[DIV0:.*]] = mhlo.divide %[[ABS_REAL]], %[[ADD0]] - // CHECK-NEXT: %[[MULT0:.*]] = mhlo.multiply %[[ABS_IMAG]], %[[SELECT2]] - // CHECK-NEXT: %[[CMP5:.*]] = mhlo.compare LT, %[[MULT0]], %[[ABS_REAL]] + // CHECK-NEXT: %[[SUB0:.*]] = mhlo.subtract %[[ONE]], %[[ABS_REAL]] + // CHECK-NEXT: %[[SQUARE1:.*]] = mhlo.multiply %[[SUB0]], %[[SUB0]] + // CHECK-NEXT: %[[SQUARE2:.*]] = mhlo.multiply %[[IMAG]], %[[IMAG]] + // CHECK-NEXT: %[[ADD0:.*]] = mhlo.add %[[SQUARE1]], %[[SQUARE2]] + // CHECK-NEXT: %[[DIV0:.*]] = mhlo.divide %[[ABS_REAL]], %[[ADD0]] + // CHECK-NEXT: %[[MULT0:.*]] = mhlo.multiply %[[ABS_IMAG]], %[[SELECT2]] + // CHECK-NEXT: %[[CMP5:.*]] = mhlo.compare LT, %[[MULT0]], %[[ABS_REAL]] // CHECK-NEXT: %[[DIV1:.*]] = mhlo.divide %[[ONE]], %[[ABS_REAL]] // CHECK-NEXT: %[[ISINF_REAL:.*]] = mhlo.constant dense<0x7F800000> // CHECK-NEXT: %[[ISNINF_REAL:.*]] = mhlo.compare EQ, %[[REAL]], %[[ISINF_REAL]] From 7bd7f705fce5ecff9bc5db6dabda32db61ea5b81 Mon Sep 17 00:00:00 2001 From: Alexander Pivovarov Date: Mon, 23 Sep 2024 00:35:34 -0700 Subject: [PATCH 124/483] PR #17477: Fix XLA_FFI_REGISTER_ macros - global qualification of class name is invalid Imported from GitHub PR https://github.com/openxla/xla/pull/17477 Currently `bazel test //xla/ffi/api:ffi_test` fails with compilation error: ```bash In file included from ./xla/ffi/api/ffi.h:48, from xla/ffi/api/ffi_test.cc:16: ./xla/ffi/api/api.h:1774:38: error: global qualification of class name is invalid before '{' token 1774 | struct ::xla::ffi::AttrDecoding { \ | ^ xla/ffi/api/ffi_test.cc:71:1: note: in expansion of macro 'XLA_FFI_REGISTER_ENUM_ATTR_DECODING' 71 | XLA_FFI_REGISTER_ENUM_ATTR_DECODING(::xla::ffi::Int32BasedEnum); | ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ``` To solve "global qualification of class name is invalid" issue we can add `namespace xla::ffi { ` block to the macros and remove `::xla::ffi::` prefix in struct decls inside `XLA_FFI_REGISTER_*` macros ### Testing ``` bazel test //xla/ffi/... INFO: Build completed successfully, 47 total actions //xla/ffi:ffi_test PASSED in 0.5s //xla/ffi:call_frame_test PASSED in 0.1s //xla/ffi:execution_context_test PASSED in 0.1s //xla/ffi:execution_state_test PASSED in 0.1s //xla/ffi:type_id_registry_test PASSED in 0.1s //xla/ffi/api:ffi_test PASSED in 0.5s Executed 6 out of 6 tests: 6 tests pass. ``` ### Related links: - https://github.com/openxla/xla/pull/15747 - https://github.com/openxla/xla/commit/ef49d057bffd4b8ff14bda925d48ea7610aaa856 Copybara import of the project: -- fffa62b2d47feb915c0c6300b0af5540974911d4 by Alexander Pivovarov : Fix XLA_FFI_REGISTER_ macros Merging this change closes #17477 PiperOrigin-RevId: 677666742 --- third_party/xla/xla/ffi/api/api.h | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/third_party/xla/xla/ffi/api/api.h b/third_party/xla/xla/ffi/api/api.h index 952a31eb872388..0e142c42286e12 100644 --- a/third_party/xla/xla/ffi/api/api.h +++ b/third_party/xla/xla/ffi/api/api.h @@ -888,7 +888,7 @@ struct CtxDecoding; // XLA_FFI_Error* Encode(const XLA_FFI_Api* api, // XLA_FFI_ExecutionContext* ctx, // absl::Status status) {...} -// } +// }; // // Result encoding is execution stage specific, for example at instantiation // stage FFI handler can return an FFI handler state, while at execution stage @@ -907,7 +907,7 @@ struct CtxDecoding; // std::variant Encode( // const XLA_FFI_Api* api, XLA_FFI_ExecutionContext* ctx, // xla::ffi::Future future) {...} -// } +// }; // template struct ResultEncoding; @@ -1744,13 +1744,14 @@ auto DictionaryDecoder(Members... m) { // binding specification inference from a callable signature. // #define XLA_FFI_REGISTER_STRUCT_ATTR_DECODING(T, ...) \ + namespace xla::ffi { \ template <> \ - struct ::xla::ffi::AttrsBinding { \ + struct AttrsBinding { \ using Attrs = T; \ }; \ \ template <> \ - struct ::xla::ffi::AttrDecoding { \ + struct AttrDecoding { \ using Type = T; \ static std::optional Decode(XLA_FFI_AttrType type, void* attr, \ DiagnosticEngine& diagnostic) { \ @@ -1765,13 +1766,17 @@ auto DictionaryDecoder(Members... m) { reinterpret_cast(attr), \ internal::StructMemberNames(__VA_ARGS__), diagnostic); \ } \ - } + }; \ + } /* namespace xla::ffi */ \ + static_assert(std::is_class_v<::xla::ffi::AttrsBinding>); \ + static_assert(std::is_class_v<::xla::ffi::AttrDecoding>) // Registers decoding for a user-defined enum class type. Uses enums underlying // type to decode the attribute as a scalar value and cast it to the enum type. #define XLA_FFI_REGISTER_ENUM_ATTR_DECODING(T) \ + namespace xla::ffi { \ template <> \ - struct ::xla::ffi::AttrDecoding { \ + struct AttrDecoding { \ using Type = T; \ using U = std::underlying_type_t; \ static_assert(std::is_enum::value, "Expected enum class"); \ @@ -1784,7 +1789,8 @@ auto DictionaryDecoder(Members... m) { } \ \ auto* scalar = reinterpret_cast(attr); \ - auto expected_dtype = internal::NativeTypeToCApiDataType(); \ + auto expected_dtype = \ + ::xla::ffi::internal::NativeTypeToCApiDataType(); \ if (XLA_FFI_PREDICT_FALSE(scalar->dtype != expected_dtype)) { \ return diagnostic.Emit("Wrong scalar data type: expected ") \ << expected_dtype << " but got " << scalar->dtype; \ @@ -1793,7 +1799,9 @@ auto DictionaryDecoder(Members... m) { auto underlying = *reinterpret_cast(scalar->value); \ return static_cast(underlying); \ } \ - }; + }; \ + } /* namespace xla::ffi */ \ + static_assert(std::is_class_v<::xla::ffi::AttrDecoding>) //===----------------------------------------------------------------------===// // Helper macro for registering FFI implementations From 063f3b27324d06c931ebabc0c1ae8b687b8330b2 Mon Sep 17 00:00:00 2001 From: Luke Boyer Date: Mon, 23 Sep 2024 01:04:21 -0700 Subject: [PATCH 125/483] Add folder for lrt in runtime. PiperOrigin-RevId: 677674721 --- tensorflow/lite/experimental/lrt/BUILD | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) create mode 100644 tensorflow/lite/experimental/lrt/BUILD diff --git a/tensorflow/lite/experimental/lrt/BUILD b/tensorflow/lite/experimental/lrt/BUILD new file mode 100644 index 00000000000000..cd9efefb75dcab --- /dev/null +++ b/tensorflow/lite/experimental/lrt/BUILD @@ -0,0 +1,18 @@ +# Copyright 2024 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = ["//tensorflow/lite/experimental/lrt:__subpackages__"], +) From 5a74b87f9529d7522055aae15cd4f8a816d2bd6e Mon Sep 17 00:00:00 2001 From: Luke Boyer Date: Mon, 23 Sep 2024 01:25:45 -0700 Subject: [PATCH 126/483] Make unit tests OS friendly PiperOrigin-RevId: 677680731 --- tensorflow/compiler/mlir/lite/experimental/lrt/core/BUILD | 2 -- tensorflow/compiler/mlir/lite/experimental/lrt/examples/BUILD | 1 - 2 files changed, 3 deletions(-) diff --git a/tensorflow/compiler/mlir/lite/experimental/lrt/core/BUILD b/tensorflow/compiler/mlir/lite/experimental/lrt/core/BUILD index d584670b49730e..3601f00137de02 100644 --- a/tensorflow/compiler/mlir/lite/experimental/lrt/core/BUILD +++ b/tensorflow/compiler/mlir/lite/experimental/lrt/core/BUILD @@ -70,7 +70,6 @@ cc_library( cc_test( name = "model_test", srcs = ["model_test.cc"], - tags = ["no_oss"], deps = [ ":api_internal", ":graph_tools", @@ -98,7 +97,6 @@ cc_library( cc_test( name = "algo_test", srcs = ["algo_test.cc"], - tags = ["no_oss"], deps = [ ":algo", ":api_internal", diff --git a/tensorflow/compiler/mlir/lite/experimental/lrt/examples/BUILD b/tensorflow/compiler/mlir/lite/experimental/lrt/examples/BUILD index fbb21622ab2d30..9be77a87971794 100644 --- a/tensorflow/compiler/mlir/lite/experimental/lrt/examples/BUILD +++ b/tensorflow/compiler/mlir/lite/experimental/lrt/examples/BUILD @@ -37,7 +37,6 @@ cc_shared_library( cc_test( name = "mul_op_plugin_test", srcs = ["mul_op_plugin_test.cc"], - tags = ["no_oss"], deps = [ ":mul_op_plugin", # buildcleaner: keep "//tensorflow/compiler/mlir/lite/experimental/lrt/c:lite_rt_c_api", From 2cde48c0d9732bf82bcab77c886fc267101e0d85 Mon Sep 17 00:00:00 2001 From: Alexander Belyaev Date: Mon, 23 Sep 2024 01:37:12 -0700 Subject: [PATCH 127/483] [XLA:GPU][Emitters] Add layout attribute. PiperOrigin-RevId: 677684803 --- .../xla/xla/service/gpu/fusions/ir/BUILD | 9 +++- .../{indexing_map_attr.mlir => attrs.mlir} | 9 ++++ .../service/gpu/fusions/ir/xla_gpu_attrs.cc | 43 ++++++++++++++++--- .../service/gpu/fusions/ir/xla_gpu_attrs.td | 38 ++++++++++++++++ .../service/gpu/fusions/ir/xla_gpu_dialect.cc | 7 +++ .../xla/service/gpu/fusions/ir/xla_gpu_ops.cc | 2 - .../xla/service/gpu/fusions/ir/xla_gpu_ops.h | 1 + .../service/gpu/fusions/ir/xla_gpu_types.cc | 7 +-- 8 files changed, 100 insertions(+), 16 deletions(-) rename third_party/xla/xla/service/gpu/fusions/ir/tests/{indexing_map_attr.mlir => attrs.mlir} (94%) diff --git a/third_party/xla/xla/service/gpu/fusions/ir/BUILD b/third_party/xla/xla/service/gpu/fusions/ir/BUILD index 16fb100182636e..0ef48df7fd1741 100644 --- a/third_party/xla/xla/service/gpu/fusions/ir/BUILD +++ b/third_party/xla/xla/service/gpu/fusions/ir/BUILD @@ -68,6 +68,14 @@ gentbl_cc_library( name = "xla_gpu_attrs_inc_gen", strip_include_prefix = ".", tbl_outs = [ + ( + ["-gen-enum-decls"], + "xla_gpu_enums.h.inc", + ), + ( + ["-gen-enum-defs"], + "xla_gpu_enums.cc.inc", + ), ( [ "-gen-attrdef-decls", @@ -138,7 +146,6 @@ cc_library( "@llvm-project//mlir:InliningUtils", "@llvm-project//mlir:SideEffectInterfaces", "@llvm-project//mlir:Support", - "@stablehlo//:stablehlo_type_inference", ], ) diff --git a/third_party/xla/xla/service/gpu/fusions/ir/tests/indexing_map_attr.mlir b/third_party/xla/xla/service/gpu/fusions/ir/tests/attrs.mlir similarity index 94% rename from third_party/xla/xla/service/gpu/fusions/ir/tests/indexing_map_attr.mlir rename to third_party/xla/xla/service/gpu/fusions/ir/tests/attrs.mlir index 76a74dd7908eca..bc37a3ac56fc7c 100644 --- a/third_party/xla/xla/service/gpu/fusions/ir/tests/indexing_map_attr.mlir +++ b/third_party/xla/xla/service/gpu/fusions/ir/tests/attrs.mlir @@ -146,3 +146,12 @@ func.func private @no_symbols(!xla_gpu.indexed_vector<100xf64, #map>) func.func private @empty(!xla_gpu.indexed_vector<100xf64, #map>) // CHECK-LABEL: @empty // CHECK: !xla_gpu.indexed_vector<100xf64, #[[$INDEX_MAP]]> + +// ----- + +func.func private @tensor_layout( + %in0: tensor<42xf32, #xla_gpu.layout<"shmem", + (d0) -> (), domain: d0 in [0, 42], is_simplified: true>>) +// CHECK: #layout = #xla_gpu.layout<"shmem", (d0) -> (), +// CHECK-SAME: domain: d0 in [0, 42], is_simplified: true> +// CHECK: tensor<42xf32, #layout> \ No newline at end of file diff --git a/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_attrs.cc b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_attrs.cc index 8fec5e91c9c3a1..cb9ba368702c9f 100644 --- a/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_attrs.cc +++ b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_attrs.cc @@ -18,8 +18,8 @@ limitations under the License. #include #include "absl/strings/str_format.h" -#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringRef.h" +#include "llvm/ADT/TypeSwitch.h" // IWYU pragma: keep #include "llvm/Support/LogicalResult.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" @@ -28,12 +28,9 @@ limitations under the License. #include "mlir/IR/OpImplementation.h" #include "mlir/IR/Types.h" #include "mlir/Support/LLVM.h" +#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h" #include "xla/service/gpu/model/indexing_map.h" -#define GET_ATTRDEF_LIST -#define GET_ATTRDEF_CLASSES -#include "xla/service/gpu/fusions/ir/xla_gpu_attrs.h.inc" - namespace xla { namespace gpu { @@ -114,9 +111,9 @@ void PrintConstraints(AsmPrinter& p, } } -mlir::Attribute IndexingMapAttr::parse(mlir::AsmParser& parser, mlir::Type) { +mlir::Attribute parseIndexingMapImpl(mlir::AsmParser& parser) { mlir::AffineMap map; - if (parser.parseLess() || parser.parseAffineMap(map)) { + if (parser.parseAffineMap(map)) { return {}; } @@ -175,6 +172,13 @@ mlir::Attribute IndexingMapAttr::parse(mlir::AsmParser& parser, mlir::Type) { constraints, is_simplified); } +mlir::Attribute IndexingMapAttr::parse(mlir::AsmParser& parser, mlir::Type) { + if (parser.parseLess()) { + return {}; + } + return parseIndexingMapImpl(parser); +} + void IndexingMapAttr::print(mlir::AsmPrinter& printer) const { printer << "<" << getIndexingMap().ToString() << ">"; } @@ -215,5 +219,30 @@ int64_t IndexingMapAttr::getNumResults() const { return getMap().getNumResults(); } +mlir::Attribute LayoutAttr::parse(mlir::AsmParser& parser, mlir::Type) { + mlir::StringAttr memory_space_str; + if (parser.parseLess() || parser.parseAttribute(memory_space_str) || + parser.parseComma()) { + return {}; + } + std::optional memspace = + symbolizeMemorySpace(memory_space_str.getValue()); + if (!memspace.has_value()) { + return {}; + } + auto thread_map = mlir::cast(parseIndexingMapImpl(parser)); + if (!thread_map) { + return {}; + } + mlir::MLIRContext* context = parser.getContext(); + auto memory_space_attr = MemorySpaceAttr::get(context, *memspace); + return LayoutAttr::get(context, memory_space_attr, thread_map); +} + +void LayoutAttr::print(mlir::AsmPrinter& printer) const { + printer << "<\"" << stringifyMemorySpace(getMemorySpace().getValue()) + << "\", " << getThreadMap().getIndexingMap().ToString() << '>'; +} + } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_attrs.td b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_attrs.td index 3bcdc79e4ff119..44e8dd4353a5b6 100644 --- a/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_attrs.td +++ b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_attrs.td @@ -17,6 +17,7 @@ limitations under the License. #define XLA_SERVICE_GPU_FUSIONS_MLIR_ATTRS include "mlir/IR/AttrTypeBase.td" +include "mlir/IR/EnumAttr.td" include "xla/service/gpu/fusions/ir/xla_gpu_dialect.td" class XLAGPU_Attr traits = []> : @@ -81,4 +82,41 @@ def XLAGPU_LaunchGridAttr : XLAGPU_Attr<"LaunchGrid"> { }]; } +//===----------------------------------------------------------------------===// +// Tensor layout attribute +//===----------------------------------------------------------------------===// + +def XLAGPU_MemorySpace : I32EnumAttr<"MemorySpace", + "element-wise op type", [ + I32EnumAttrCase<"kRegisters", 0, "registers">, + I32EnumAttrCase<"kSharedMemory", 1, "shmem"> + ]> { + let cppNamespace = "::xla::gpu"; + let genSpecializedAttr = 0; +} + +def XLAGPU_MemorySpaceAttr : EnumAttr< + XlaGpuDialect, XLAGPU_MemorySpace, "memory_space"> { + let assemblyFormat = "`<` $value `>`"; +} + +def XLAGPU_LayoutAttr : XLAGPU_Attr<"Layout"> { + let mnemonic = "layout"; + let summary = "Layout consists of a thread ID indexing map + memory space."; + let description = [{ + This attribute is used as an encoding for RankedTensorType. It indicates in + which memory space the tensor is stored and the access pattern from the + warps/threads. + ```mlir + tensor<42xf32, #xla_gpu.layout<"shmem", (d0) -> (), domain: d0 in [0, 42]>> + ``` + }]; + let parameters = (ins + AttrParameter<"MemorySpaceAttr", "memory_space">:$memory_space, + AttrParameter<"IndexingMapAttr", "thread_map">:$thread_map + ); + let hasCustomAssemblyFormat = 1; +} + + #endif // XLA_SERVICE_GPU_FUSIONS_MLIR_ATTRS diff --git a/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_dialect.cc b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_dialect.cc index 5ac9c59ce773df..c46561a98d0d45 100644 --- a/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_dialect.cc +++ b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_dialect.cc @@ -18,6 +18,9 @@ limitations under the License. #include "mlir/IR/OpImplementation.h" // IWYU pragma: keep #include "mlir/Transforms/InliningUtils.h" #include "xla/service/gpu/fusions/ir/xla_gpu_ops.h" + +// The order of these includes is important. +#include "xla/service/gpu/fusions/ir/xla_gpu_enums.cc.inc" #define GET_ATTRDEF_CLASSES #include "xla/service/gpu/fusions/ir/xla_gpu_attrs.cc.inc" #define GET_TYPEDEF_CLASSES @@ -112,6 +115,10 @@ struct XlaGpuOpAsmDialectInterface : public mlir::OpAsmDialectInterface { os << "indexing_map"; return AliasResult::FinalAlias; } + if (llvm::isa(attr)) { + os << "layout"; + return AliasResult::FinalAlias; + } return AliasResult::NoAlias; } }; diff --git a/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops.cc b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops.cc index d31c5bef66ac34..967df31ba84397 100644 --- a/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops.cc +++ b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops.cc @@ -23,7 +23,6 @@ limitations under the License. #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/STLFunctionalExtras.h" #include "llvm/ADT/SmallBitVector.h" -#include "llvm/ADT/TypeSwitch.h" // IWYU pragma: keep #include "llvm/Support/Casting.h" #include "llvm/Support/LogicalResult.h" #include "mlir/Dialect/Arith/IR/Arith.h" @@ -46,7 +45,6 @@ limitations under the License. #include "mlir/IR/ValueRange.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" -#include "stablehlo/dialect/TypeInference.h" #include "xla/service/gpu/fusions/ir/xla_gpu_dialect.cc.inc" #include "xla/service/gpu/model/indexing_map.h" diff --git a/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops.h b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops.h index e025bd90b37e64..e47f2cf3e2d323 100644 --- a/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops.h +++ b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops.h @@ -30,6 +30,7 @@ limitations under the License. #include "mlir/Interfaces/InferTypeOpInterface.h" // IWYU pragma: keep #include "mlir/Interfaces/SideEffectInterfaces.h" // IWYU pragma: keep #include "xla/service/gpu/fusions/ir/xla_gpu_dialect.h.inc" +#include "xla/service/gpu/fusions/ir/xla_gpu_enums.h.inc" #include "xla/service/gpu/model/indexing_map.h" // IWYU pragma: keep #define GET_ATTRDEF_CLASSES #include "xla/service/gpu/fusions/ir/xla_gpu_attrs.h.inc" diff --git a/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_types.cc b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_types.cc index dbcc20b36f9951..86f2dffa74f4f2 100644 --- a/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_types.cc +++ b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_types.cc @@ -21,14 +21,9 @@ limitations under the License. #include "mlir/IR/OpImplementation.h" // IWYU pragma: keep #include "mlir/IR/Types.h" #include "mlir/Support/LLVM.h" -#include "xla/service/gpu/fusions/ir/xla_gpu_dialect.h.inc" +#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h" #include "xla/service/gpu/model/indexing_map.h" // IWYU pragma: keep -#define GET_ATTRDEF_CLASSES -#include "xla/service/gpu/fusions/ir/xla_gpu_attrs.h.inc" - -#define GET_TYPEDEF_CLASSES -#include "xla/service/gpu/fusions/ir/xla_gpu_types.h.inc" namespace xla { namespace gpu { From 2b9d3277249225ae87ae7fff71f666ff44c5499a Mon Sep 17 00:00:00 2001 From: Adrian Kuegel Date: Mon, 23 Sep 2024 01:44:53 -0700 Subject: [PATCH 128/483] Remove IsLoopIterationOffset() method from DynamicSliceFusion emitter. It relies on the while instruction back pointer which is known to be buggy. Also it is currently not covered by tests. PiperOrigin-RevId: 677687232 --- .../xla/xla/service/gpu/fusions/custom.cc | 49 ------------------- 1 file changed, 49 deletions(-) diff --git a/third_party/xla/xla/service/gpu/fusions/custom.cc b/third_party/xla/xla/service/gpu/fusions/custom.cc index c89cdaba4f89db..c4270462ccdc16 100644 --- a/third_party/xla/xla/service/gpu/fusions/custom.cc +++ b/third_party/xla/xla/service/gpu/fusions/custom.cc @@ -176,51 +176,6 @@ absl::StatusOr GetOperandSlice( return absl::InternalError("WTF"); } -// Returns true if `offset` is a loop iteration number. This pattern matching -// detects HLOs that generated by `jax.lax.scan` and will miss slightly -// different patterns that still compute slice offset as loop iteration number. -static bool IsLoopIterationOffset(const HloInstruction* offset) { - const HloComputation* parent = offset->parent(); - if (!parent->IsWhileBodyComputation()) return false; - - // Scan loops trip count must be known at compile time as it iterates over the - // leading dimension of the statically shaped input. - const HloInstruction* while_instr = parent->WhileCallInstruction(); - auto config = while_instr->backend_config(); - if (!config.ok() || !config->has_known_trip_count()) return false; - int32_t trip_count = config->known_trip_count().n(); - - // Check that offset is defined by a loop fusion that computes offset - // from the loop iteration number. - if (!offset->IsLoopFusion() || - !Match( - offset->fused_expression_root(), - m::Select(m::Compare(m::Parameter(0), m::ConstantScalar(0)), - m::Add(m::Parameter(0), m::ConstantScalar(trip_count)), - m::Parameter(0)))) { - return false; - } - - // Check that we get loop iteration directly from loop parameters bundle. - HloInstruction* get_loop_iteration; - if (!Match(const_cast(offset->operand(0)), - m::GetTupleElement(&get_loop_iteration, m::Parameter(0)))) { - return false; - } - int32_t loop_iter_idx = get_loop_iteration->tuple_index(); - - // Check that loop iteration counter updated with a +1 fusion. - const HloInstruction* loop_inc = - parent->root_instruction()->operand(loop_iter_idx); - if (!loop_inc->IsLoopFusion() || - !Match(loop_inc->fused_expression_root(), - m::Add(m::Parameter(0), m::ConstantScalar(1)))) { - return false; - } - - return true; -} - // Returns the constant literal, if the offset is from an offset array. Returns // `std::nullopt` otherwise. std::optional GetOffsetArray(const HloInstruction* inst) { @@ -276,10 +231,6 @@ absl::Status CollectSliceInfo( "Unsupported constant offset shape: ", cst->shape().ToString())); } - } else if (IsLoopIterationOffset(offset_value)) { - // Loop offset defined by a loop iteration number. - arg_offsets.emplace_back() = DynamicSliceThunk::LoopIter(); - } else { // Loop offset computed on device and has to be transferred to host. TF_ASSIGN_OR_RETURN(arg_offsets.emplace_back(), From caeb516b3978b940a8ddf5bb69fb923531932ec7 Mon Sep 17 00:00:00 2001 From: Ilya Tikhonovskiy Date: Mon, 23 Sep 2024 01:52:13 -0700 Subject: [PATCH 129/483] [XLA:GPU] Avoid copying Shape in HloRematerialization This CL avoids copying Shape in HloRematerialization by returning a pointer to the Shape in the internal cache. This change is expected to improve 30% the performance of HloRematerialization. PiperOrigin-RevId: 677689332 --- .../xla/xla/service/hlo_rematerialization.cc | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/third_party/xla/xla/service/hlo_rematerialization.cc b/third_party/xla/xla/service/hlo_rematerialization.cc index fed3b52c29ecbb..d7090605bbc836 100644 --- a/third_party/xla/xla/service/hlo_rematerialization.cc +++ b/third_party/xla/xla/service/hlo_rematerialization.cc @@ -724,7 +724,7 @@ class MemoryUsageTracker { // Get the compact shape of given hlo instruction. An internal cache is used // to avoid computing the shape multiple times. - absl::StatusOr GetCompactShape(const HloInstruction* hlo); + absl::StatusOr GetCompactShape(const HloInstruction* hlo); // Creates a Buffer representing the given logical buffer. The buffer is added // to buffers_ and a reference is returned. @@ -1506,17 +1506,16 @@ std::string MemoryUsageTracker::ToString() const { return output; } -absl::StatusOr MemoryUsageTracker::GetCompactShape( +absl::StatusOr MemoryUsageTracker::GetCompactShape( const HloInstruction* hlo) { auto it = compact_shape_.find(hlo); if (it != compact_shape_.end()) { - return it->second; + return &it->second; } const Shape& original_shape = hlo->shape(); TF_ASSIGN_OR_RETURN(Shape min_shape, options_.compact_shape_function(original_shape)); - compact_shape_[hlo] = min_shape; - return min_shape; + return &compact_shape_.emplace(hlo, min_shape).first->second; } bool MemoryUsageTracker::Check() const { @@ -1660,9 +1659,10 @@ std::optional MemoryUsageTracker::GetCostOfCompression( return {}; } - Shape compact_shape = GetCompactShape(candidate_item->instruction).value(); + const Shape* compact_shape = + GetCompactShape(candidate_item->instruction).value(); const int64_t memory_reduced = - MemoryReducedIfCompressed(candidate_item, compact_shape); + MemoryReducedIfCompressed(candidate_item, *compact_shape); // Since the compressed and uncompressed buffers need to be alive // while performing the compression/uncompression, only perform // the compression if the sum of the two sizes is less than the @@ -1670,7 +1670,7 @@ std::optional MemoryUsageTracker::GetCostOfCompression( const int64_t size = options_.hlo_cost_analysis.GetShapeSize( candidate_item->instruction->shape()); const int64_t reduced_size = - options_.hlo_cost_analysis.GetShapeSize(compact_shape); + options_.hlo_cost_analysis.GetShapeSize(*compact_shape); // TODO(victorstone): I don't think this size check is right. if (memory_reduced > 0 && size + reduced_size < peak_memory_bytes) { return memory_limit_bytes / memory_reduced; @@ -1907,7 +1907,7 @@ MemoryUsageTracker::PickRematerializationCandidates( // computed inside GetCostOfCompression, should we get it from there? Or // is it ok to recompute? best_strategy.compact_shape = - GetCompactShape(block[0]->instruction).value(); + *GetCompactShape(block[0]->instruction).value(); best_items = block; best_cost = *cost; } From 5ce3519f6532a6f15fb044eac7645fe279b39ff6 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 23 Sep 2024 02:03:51 -0700 Subject: [PATCH 130/483] Update GraphDef version to 1994. PiperOrigin-RevId: 677693327 --- tensorflow/core/public/version.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index 0630c0ec562e53..e89334c66dc1b5 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -108,7 +108,7 @@ limitations under the License. #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0 -#define TF_GRAPH_DEF_VERSION 1993 // Updated: 2024/9/22 +#define TF_GRAPH_DEF_VERSION 1994 // Updated: 2024/9/23 // Checkpoint compatibility versions (the versions field in SavedSliceMeta). // From 033bcafbe0717f836a6b18f0a50bf809bf6f4b50 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 23 Sep 2024 02:03:59 -0700 Subject: [PATCH 131/483] compat: Update forward compatibility horizon to 2024-09-23 PiperOrigin-RevId: 677693371 --- tensorflow/python/compat/compat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index 880fbfc269a41a..f7a548d8aa13ca 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -29,7 +29,7 @@ # This value changes every day with an automatic CL. It can be modified in code # via `forward_compatibility_horizon()` or with the environment variable # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date. -_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2024, 9, 22) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2024, 9, 23) _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS" _FORWARD_COMPATIBILITY_DATE_NUMBER = None From 2dc146aa2f7f256b5c6b13df8fef57a1e0e1fa41 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 23 Sep 2024 02:07:24 -0700 Subject: [PATCH 132/483] Automated Code Change PiperOrigin-RevId: 677694505 --- third_party/xla/xla/service/graphcycles/graphcycles.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/third_party/xla/xla/service/graphcycles/graphcycles.cc b/third_party/xla/xla/service/graphcycles/graphcycles.cc index 019087c1a98276..056cdcd74a01c7 100644 --- a/third_party/xla/xla/service/graphcycles/graphcycles.cc +++ b/third_party/xla/xla/service/graphcycles/graphcycles.cc @@ -38,7 +38,6 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/container/flat_hash_set.h" -#include "absl/container/inlined_vector.h" #include "absl/strings/str_cat.h" #include "absl/types/span.h" #include "xla/service/graphcycles/ordered_set.h" From 0cfdece9e5ced1973126f221a55524259baff9ab Mon Sep 17 00:00:00 2001 From: pemeliya <141146080+pemeliya@users.noreply.github.com> Date: Mon, 23 Sep 2024 03:26:15 -0700 Subject: [PATCH 133/483] PR #17205: [ROCM] fixing build-brake: noexcept MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Imported from GitHub PR https://github.com/openxla/xla/pull/17205 The member variable of **ExecutionOutput** with type **ScopedShapeBuffer** has a move-ctor and move operator= without noexcept keyword which causes the following compile error on ROCM side: ‘xla::ExecutionOutput::ExecutionOutput(xla::ExecutionOutput&&) noexcept’ is implicitly deleted because its exception-specification does not match the implicit exception-specification ‘’ @xla-rotation: would you please have a look ? Copybara import of the project: -- 59f59d72f53c837580cea0690974259d8089d787 by Pavel Emeliyanenko : fixing build-brake -- 3cb0308acaa4b1f0881194be7efec94c3954a52a by Pavel Emeliyanenko : added noexcept to the base class move ctor too Merging this change closes #17205 PiperOrigin-RevId: 677716783 --- third_party/xla/xla/service/shaped_buffer.cc | 9 +++++---- third_party/xla/xla/service/shaped_buffer.h | 8 ++++---- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/third_party/xla/xla/service/shaped_buffer.cc b/third_party/xla/xla/service/shaped_buffer.cc index a5155e4331f624..0576644409b427 100644 --- a/third_party/xla/xla/service/shaped_buffer.cc +++ b/third_party/xla/xla/service/shaped_buffer.cc @@ -46,7 +46,7 @@ ShapedBuffer::ShapedBuffer(Shape on_host_shape, Shape on_device_shape, int device_ordinal, int physical_device_ordinal) : ShapedBuffer(on_device_shape, device_ordinal, physical_device_ordinal) {} -ShapedBuffer::ShapedBuffer(ShapedBuffer&& s) +ShapedBuffer::ShapedBuffer(ShapedBuffer&& s) noexcept : on_host_shape_(std::move(s.on_host_shape_)), on_device_shape_(std::move(s.on_device_shape_)), device_ordinal_(s.device_ordinal_), @@ -58,7 +58,7 @@ ShapedBuffer::ShapedBuffer(ShapedBuffer&& s) buffers_.replace_shape_ptr(on_device_shape_); } -ShapedBuffer& ShapedBuffer::operator=(ShapedBuffer&& s) { +ShapedBuffer& ShapedBuffer::operator=(ShapedBuffer&& s) noexcept { on_device_shape_ = std::move(s.on_device_shape_); on_host_shape_ = std::move(s.on_host_shape_); device_ordinal_ = s.device_ordinal_; @@ -140,13 +140,14 @@ ScopedShapedBuffer::ScopedShapedBuffer(ShapedBuffer shaped_buffer, se::DeviceMemoryAllocator* allocator) : ShapedBuffer(std::move(shaped_buffer)), allocator_(allocator) {} -ScopedShapedBuffer::ScopedShapedBuffer(ScopedShapedBuffer&& s) +ScopedShapedBuffer::ScopedShapedBuffer(ScopedShapedBuffer&& s) noexcept : ShapedBuffer(static_cast(s)), allocator_(s.allocator_) { // Null out s.allocator_ so it doesn't try to free anything in its destructor. s.allocator_ = nullptr; } -ScopedShapedBuffer& ScopedShapedBuffer::operator=(ScopedShapedBuffer&& s) { +ScopedShapedBuffer& ScopedShapedBuffer::operator=( + ScopedShapedBuffer&& s) noexcept { Deallocate(); *static_cast(this) = std::move(static_cast(s)); diff --git a/third_party/xla/xla/service/shaped_buffer.h b/third_party/xla/xla/service/shaped_buffer.h index 5faf97cb64f3d7..003e142a8bd355 100644 --- a/third_party/xla/xla/service/shaped_buffer.h +++ b/third_party/xla/xla/service/shaped_buffer.h @@ -52,8 +52,8 @@ class ShapedBuffer { int physical_device_ordinal = -1); // Movable, but not copyable. - ShapedBuffer(ShapedBuffer&& s); - ShapedBuffer& operator=(ShapedBuffer&&); + ShapedBuffer(ShapedBuffer&& s) noexcept; + ShapedBuffer& operator=(ShapedBuffer&&) noexcept; ShapedBuffer(const ShapedBuffer&) = delete; ShapedBuffer& operator=(const ShapedBuffer&) = delete; @@ -170,8 +170,8 @@ class ScopedShapedBuffer : public ShapedBuffer { se::DeviceMemoryAllocator* allocator); // Movable, but not copyable. - ScopedShapedBuffer(ScopedShapedBuffer&& s); - ScopedShapedBuffer& operator=(ScopedShapedBuffer&&); + ScopedShapedBuffer(ScopedShapedBuffer&& s) noexcept; + ScopedShapedBuffer& operator=(ScopedShapedBuffer&&) noexcept; ScopedShapedBuffer(const ScopedShapedBuffer&) = delete; ScopedShapedBuffer& operator=(const ScopedShapedBuffer&) = delete; From d5bf83fa04d96c2e50c075a637b020fb8cc56bf1 Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Mon, 23 Sep 2024 03:37:00 -0700 Subject: [PATCH 134/483] [XLA:GPU] Add heuristic to detect coalesced reads in Tiled Cost Model. This adds a simple heuristic that looks at tile size, stride and layout and determines if the resulting load will be coalesced. If the read is not coalesced, a bandwidth penalty is applied to simulate the inefficient use of cache lines. This helps prevent fusions with degraded performance due to subpar data access patterns. PiperOrigin-RevId: 677719995 --- third_party/xla/xla/service/gpu/model/BUILD | 7 + .../service/gpu/model/coalescing_analysis.cc | 46 +++++- .../service/gpu/model/coalescing_analysis.h | 14 +- .../gpu/model/coalescing_analysis_test.cc | 135 +++++++++++++++++- .../model/gpu_indexing_performance_model.cc | 35 +++-- .../gpu_indexing_performance_model_test.cc | 50 ++++++- .../service/gpu/model/tiled_hlo_instruction.h | 2 + 7 files changed, 271 insertions(+), 18 deletions(-) diff --git a/third_party/xla/xla/service/gpu/model/BUILD b/third_party/xla/xla/service/gpu/model/BUILD index 051e81048e96a4..6ef51b853d38c1 100644 --- a/third_party/xla/xla/service/gpu/model/BUILD +++ b/third_party/xla/xla/service/gpu/model/BUILD @@ -805,12 +805,14 @@ cc_library( deps = [ ":affine_map_evaluator", ":indexing_analysis", + ":tiled_hlo_instruction_or_computation", "//xla:shape_util", "//xla/hlo/ir:hlo", "//xla/service/gpu:gpu_fusible", "//xla/service/gpu:hlo_fusion_analysis", "//xla/service/gpu:hlo_traversal", "//xla/service/gpu/fusions:fusion_emitter", + "//xla/stream_executor:device_description", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/log", @@ -827,6 +829,9 @@ xla_cc_test( srcs = ["coalescing_analysis_test.cc"], deps = [ ":coalescing_analysis", + ":symbolic_tile", + ":symbolic_tile_analysis", + ":tiled_hlo_instruction_or_computation", "//xla:shape_util", "//xla/hlo/ir:hlo", "//xla/service:hlo_module_config", @@ -838,8 +843,10 @@ xla_cc_test( "//xla/stream_executor:device_description", "//xla/tests:hlo_test_base", "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest_main", "@llvm-project//mlir:IR", + "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", ], ) diff --git a/third_party/xla/xla/service/gpu/model/coalescing_analysis.cc b/third_party/xla/xla/service/gpu/model/coalescing_analysis.cc index 9e7a685d590a29..9612a5566a3c69 100644 --- a/third_party/xla/xla/service/gpu/model/coalescing_analysis.cc +++ b/third_party/xla/xla/service/gpu/model/coalescing_analysis.cc @@ -25,9 +25,8 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/container/inlined_vector.h" -#include "absl/log/check.h" -#include "absl/log/log.h" #include "absl/types/span.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/Support/MathExtras.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" @@ -43,8 +42,10 @@ limitations under the License. #include "xla/service/gpu/model/affine_map_evaluator.h" #include "xla/service/gpu/model/indexing_analysis.h" #include "xla/service/gpu/model/indexing_map.h" +#include "xla/service/gpu/model/tiled_hlo_instruction.h" #include "xla/shape.h" #include "xla/shape_util.h" +#include "xla/stream_executor/device_description.h" namespace xla { namespace gpu { @@ -97,15 +98,52 @@ bool IsReadCoalescedHeuristic(HloFusionAnalysis::EmitterFusionKind fusion_kind, return true; } +bool IsTiledReadCoalescedHeuristic(const TiledHloInstruction& operand, + const se::DeviceDescription& device_info) { + const Shape& shape = operand.hlo()->shape(); + + // Compute the number of elements in the contiguous part of the tile. + int64_t contiguous_read_elements = 1; + for (const auto dim_idx : shape.layout().minor_to_major()) { + // This dimension is strided, so it's not contiguous. + if (operand.tile_stride(dim_idx) != 1) { + break; + } + + int64_t tile_size = operand.tile_size(dim_idx); + int64_t dim_size = shape.dimensions(dim_idx); + + // Make sure to ignore the mask if there is one. + contiguous_read_elements *= std::min(tile_size, dim_size); + + // This dimension is only partially captured, so more major dimensions are + // necessarily not captured contiguously. + if (tile_size < dim_size) { + break; + } + } + + // Compute the size of the contiguous part of the tile in bytes. + int64_t contiguous_bytes_accessed = + contiguous_read_elements * + ShapeUtil::ByteSizeOfPrimitiveType(operand.hlo()->shape().element_type()); + + // We consider a read coalesced if the contiguous part of the read covers the + // whole DRAM->L2 cache line. + // + // TODO(b/332714755): note that we don't check that we fully exploit all the + // cache lines we read from if we happen to read through several of them. + return contiguous_bytes_accessed >= + device_info.dram_to_l2_transaction_size_bytes(); +} + namespace { using ::mlir::AffineBinaryOpExpr; using ::mlir::AffineConstantExpr; -using ::mlir::AffineDimExpr; using ::mlir::AffineExpr; using ::mlir::AffineExprKind; using ::mlir::AffineMap; -using ::mlir::AffineSymbolExpr; using ::mlir::getAffineConstantExpr; using ::mlir::MLIRContext; diff --git a/third_party/xla/xla/service/gpu/model/coalescing_analysis.h b/third_party/xla/xla/service/gpu/model/coalescing_analysis.h index da2c6872b191e9..5e82b6455afcd7 100644 --- a/third_party/xla/xla/service/gpu/model/coalescing_analysis.h +++ b/third_party/xla/xla/service/gpu/model/coalescing_analysis.h @@ -23,7 +23,8 @@ limitations under the License. #include "xla/service/gpu/fusions/fusion_emitter.h" #include "xla/service/gpu/hlo_fusion_analysis.h" #include "xla/service/gpu/hlo_traversal.h" -#include "xla/service/gpu/model/indexing_map.h" +#include "xla/service/gpu/model/tiled_hlo_instruction.h" +#include "xla/stream_executor/device_description.h" namespace xla { namespace gpu { @@ -72,6 +73,17 @@ bool IsReadCoalescedHeuristic(HloFusionAnalysis::EmitterFusionKind fusion_kind, const HloInstruction* producer, const HloInstruction* consumer = nullptr); +// Returns true if read of this tiled hlo operand is coalesced. +// +// We consider a read coalesced if the operand tile consist of contiguous chunk +// of memory that saturate DRAM->L2 cache line. For post-V100 NVIDIA GPUs, that +// is 64 bytes by default. +// +// TODO(b/332714755): check whether we should bump up the granularity of +// memory transactions. +bool IsTiledReadCoalescedHeuristic(const TiledHloInstruction& operand, + const se::DeviceDescription& device_info); + } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/model/coalescing_analysis_test.cc b/third_party/xla/xla/service/gpu/model/coalescing_analysis_test.cc index aefe84294472a2..70cca31981174e 100644 --- a/third_party/xla/xla/service/gpu/model/coalescing_analysis_test.cc +++ b/third_party/xla/xla/service/gpu/model/coalescing_analysis_test.cc @@ -15,12 +15,14 @@ limitations under the License. #include "xla/service/gpu/model/coalescing_analysis.h" +#include #include #include #include #include #include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "mlir/IR/MLIRContext.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" @@ -29,11 +31,15 @@ limitations under the License. #include "xla/service/gpu/gpu_device_info_for_tests.h" #include "xla/service/gpu/hlo_fusion_analysis.h" #include "xla/service/gpu/hlo_traversal.h" +#include "xla/service/gpu/model/symbolic_tile_analysis.h" +#include "xla/service/gpu/model/tiled_hlo_computation.h" +#include "xla/service/gpu/model/tiled_hlo_instruction.h" #include "xla/service/hlo_module_config.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/stream_executor/device_description.h" #include "xla/tests/hlo_test_base.h" +#include "tsl/platform/statusor.h" #include "tsl/platform/test.h" namespace xla { @@ -148,7 +154,7 @@ TEST_F(CoalescingTest, OutputAndLhsTransposedLayout) { fusion { p0 = f32[100, 200]{1, 0} parameter(0) p1 = f32[100, 200]{0, 1} parameter(1) - ROOT exp = f32[100, 200]{1, 0} add(p0, p1) + ROOT add = f32[100, 200]{1, 0} add(p0, p1) } ENTRY e { p0 = f32[100, 200]{1, 0} parameter(0) @@ -510,6 +516,133 @@ TEST_F(CoalescingTest, Param) { EXPECT_THAT(IsReadCoalescedPerOperand(ir), ElementsAre(true, true, true)); } +class CoalescingForTiledHloTest : public CoalescingTest { + public: + std::vector IsTiledReadCoalescedPerOperand( + const HloInstruction* root, absl::Span tile_sizes) { + auto fusion_adaptor = HloFusionAdaptor::ForInstruction(root); + + SymbolicTileAnalysis symbolic_tile_analysis = + std::get(SymbolicTileAnalysis::AnalyzeFusion( + *fusion_adaptor, &mlir_context_)); + + TiledHloComputation tiled_hlo_computation = + *symbolic_tile_analysis.ComputeTiledHloInstructions( + tile_sizes, /*constraints_are_known_satisfied=*/true, + /*compute_all_tile_offset_indexing_maps=*/true); + + const TiledHloInstruction* tiled_hlo_root = tiled_hlo_computation.GetRoot(); + std::vector result; + for (const TiledHloInstruction* operand : tiled_hlo_root->operands()) { + result.push_back(IsTiledReadCoalescedHeuristic(*operand, device_info_)); + } + return result; + } +}; + +TEST_F(CoalescingForTiledHloTest, TiledReadCoalescedHeuristic_Transpose) { + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"( +HloModule m + +ENTRY main { + p0 = f32[2048, 48] parameter(0) + ROOT transpose = f32[48, 2048] transpose(p0), dimensions={1, 0} +})")); + + const HloInstruction* root = module->entry_computation()->root_instruction(); + + // The operand is not coalesced because the tile has stride 48. + EXPECT_THAT(IsTiledReadCoalescedPerOperand(root, {1, 2048}), + ElementsAre(false)); + + // The operand is coalesced because we read 48 contiguous elements. + EXPECT_THAT(IsTiledReadCoalescedPerOperand(root, {48, 32}), + ElementsAre(true)); +} + +TEST_F(CoalescingForTiledHloTest, + TiledReadCoalescedHeuristic_MaskingIsHandledCorrectly) { + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"( +HloModule m + +ENTRY main { + p0 = f32[2048, 12] parameter(0) + ROOT transpose = f32[12, 2048] transpose(p0), dimensions={1, 0} +})")); + + const HloInstruction* root = module->entry_computation()->root_instruction(); + + constexpr int kNumBytesPerParamRow = 12 * 4; + + // The transaction size can be configured in different ways, and the minimum + // possible value on A100 is 32 bytes---which would make this test fail. + // Ensure that the transaction size is configured to be large enough. + ASSERT_GT(device_info_.dram_to_l2_transaction_size_bytes(), + kNumBytesPerParamRow); + + // The operand is coalesced because we read 4 * 12 = 48 contiguous elements + // (though the tile contains 64 elements due to the mask). + EXPECT_THAT(IsTiledReadCoalescedPerOperand(root, {16, 4}), ElementsAre(true)); + + // The mask should be ignored when checking whether reads are coalesced. + EXPECT_THAT(IsTiledReadCoalescedPerOperand(root, {1024, 1}), + ElementsAre(false)); +} + +TEST_F(CoalescingForTiledHloTest, RhsTransposedLayout) { + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"( +HloModule m + +ENTRY main { + p0 = f32[256, 512]{1,0} parameter(0) + p1 = f32[256, 512]{0,1} parameter(1) + ROOT add = f32[256, 512]{1,0} add(p0, p1) +})")); + + const HloInstruction* root = module->entry_computation()->root_instruction(); + + constexpr int kExpectedDramToL2TransactionSize = 64; + ASSERT_EQ(device_info_.dram_to_l2_transaction_size_bytes(), + kExpectedDramToL2TransactionSize); + + EXPECT_THAT(IsTiledReadCoalescedPerOperand(root, {1, 16}), + ElementsAre(true, false)); + EXPECT_THAT(IsTiledReadCoalescedPerOperand(root, {16, 1}), + ElementsAre(false, true)); + EXPECT_THAT(IsTiledReadCoalescedPerOperand(root, {16, 16}), + ElementsAre(true, true)); + EXPECT_THAT(IsTiledReadCoalescedPerOperand(root, {8, 8}), + ElementsAre(false, false)); +} + +TEST_F(CoalescingForTiledHloTest, SmallDataTypes) { + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"( +HloModule m + +ENTRY main { + p0 = s8[256, 512] parameter(0) + p1 = s8[256, 512] parameter(1) + ROOT add = s8[256, 512] add(p0, p1) +})")); + + const HloInstruction* root = module->entry_computation()->root_instruction(); + + constexpr int kExpectedDramToL2TransactionSize = 64; + ASSERT_EQ(device_info_.dram_to_l2_transaction_size_bytes(), + kExpectedDramToL2TransactionSize); + + // To be coalesced, a contiguous chunk of memory load should be at least + // kExpectedDramToL2TransactionSize bytes long. + EXPECT_THAT(IsTiledReadCoalescedPerOperand(root, {16, 16}), + ElementsAre(false, false)); + EXPECT_THAT(IsTiledReadCoalescedPerOperand(root, {16, 32}), + ElementsAre(false, false)); + EXPECT_THAT(IsTiledReadCoalescedPerOperand(root, {16, 64}), + ElementsAre(true, true)); + EXPECT_THAT(IsTiledReadCoalescedPerOperand(root, {16, 128}), + ElementsAre(true, true)); +} + } // namespace } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.cc b/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.cc index 520d0448c0c78f..7798b80d17a681 100644 --- a/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.cc +++ b/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.cc @@ -58,6 +58,15 @@ namespace xla { namespace gpu { namespace { +// Information about an operand read. +struct OperandReadInfo { + // Total number of bytes read from the operand. + int64_t total_bytes_read = 0; + + // Whether the read is coalesced. + int64_t is_coalesced = true; +}; + // Returns the number of elements in the tile after each dimension is padded to // the next power of 2. // TODO(b/353484968): Delete this function once we have constraints to only @@ -347,7 +356,7 @@ GpuPerformanceModelWithIndexingAnalysis::EstimateRunTimeForTiledHloComputation( const HloFusionAdaptor& fusion_adaptor, const TiledHloComputation& tiled_hlo_computation, const LaunchDimensions& launch_dimensions) { - absl::flat_hash_map n_bytes_total_map; + absl::flat_hash_map n_bytes_total_map; int64_t flops = 0; int64_t bytes_read = 0; @@ -405,19 +414,27 @@ GpuPerformanceModelWithIndexingAnalysis::EstimateRunTimeForTiledHloComputation( int64_t tile_bytes_read = element_type_size * num_elements; bytes_read += tile_bytes_read; - n_bytes_total_map[hlo] += tile_bytes_read; + + bool is_coalesced = + IsTiledReadCoalescedHeuristic(*tiled_hlo, *device_info_); + + OperandReadInfo& operand_read_info = n_bytes_total_map[hlo]; + operand_read_info.total_bytes_read += tile_bytes_read; + operand_read_info.is_coalesced &= is_coalesced; } } absl::Duration read_time = absl::ZeroDuration(); - for (const auto& [hlo, n_bytes_total] : n_bytes_total_map) { + for (const auto& [hlo, operand_read_info] : n_bytes_total_map) { int64_t operand_size = shape_size_(hlo->shape()); - int64_t n_bytes_net = std::min(operand_size, n_bytes_total); - - read_time += ReadTimeWithDRAMHeuristic( - *device_info_, num_blocks, n_bytes_net, n_bytes_total, - /*element_type=*/hlo->shape().element_type(), - /*coalesced=*/true); + int64_t n_bytes_net = + std::min(operand_size, operand_read_info.total_bytes_read); + + read_time += + ReadTimeWithDRAMHeuristic(*device_info_, num_blocks, n_bytes_net, + operand_read_info.total_bytes_read, + /*element_type=*/hlo->shape().element_type(), + /*coalesced=*/operand_read_info.is_coalesced); } int64_t bytes_written = diff --git a/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model_test.cc b/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model_test.cc index 18ce88e553277f..dddf7b1d428f9d 100644 --- a/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model_test.cc +++ b/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model_test.cc @@ -381,9 +381,9 @@ ENTRY main { indexing_cost_model_.EstimateRunTimeForTiledFusion( *fusion_adaptor, launch_dimensions, /*output_tile_sizes=*/{1, 1})); - EXPECT_NEAR(absl::ToDoubleSeconds(runtime_data.read_time), 183, 1); + EXPECT_NEAR(absl::ToDoubleSeconds(runtime_data.read_time), 5863, 1); EXPECT_NEAR(absl::ToDoubleSeconds(runtime_data.compute_time), 39, 1); - EXPECT_NEAR(absl::ToDoubleSeconds(runtime_data.exec_time), 185, 1); + EXPECT_NEAR(absl::ToDoubleSeconds(runtime_data.exec_time), 5865, 1); } // TODO(b/351342921): Remove this test once there is no special filter for @@ -536,6 +536,51 @@ ENTRY main { EXPECT_EQ(res.flops, kPaddedOutputTileSize * kAddFlops); } +TEST_F(GpuIndexingPerformanceModelTest, + EstimateRunTimeForTiledFusion_UncoalescedReadsTakeMoreTime) { + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"( +HloModule m + +triton_softmax_computation { + param_0 = f32[2048,512] parameter(0) + param_1 = f32[2048,512] parameter(1) + ROOT add = f32[2048,512] add(param_0, param_1) +} + +ENTRY main { + param_0 = f32[2048,512] parameter(0) + param_1 = f32[2048,512] parameter(1) + ROOT triton_softmax = f32[2048,512] fusion(param_0, param_1), kind=kCustom, calls=triton_softmax_computation, backend_config={"fusion_backend_config": {"kind":"__triton"}} +} +)")); + auto fusion_adaptor = HloFusionAdaptor::ForInstruction( + module->entry_computation()->root_instruction()); + + TF_ASSERT_OK_AND_ASSIGN( + auto tiling_result, + indexing_cost_model_.TryFindBestTilingForFusion(*fusion_adaptor)); + + TF_ASSERT_OK_AND_ASSIGN( + auto res_coalesced, + indexing_cost_model_.EstimateRunTimeForTiledFusion( + *fusion_adaptor, /*launch_dimensions=*/{8192, 2 * WarpSize()}, + /*output_tile_sizes=*/{1, 128})); + + TF_ASSERT_OK_AND_ASSIGN( + auto res_uncoalesced, + indexing_cost_model_.EstimateRunTimeForTiledFusion( + *fusion_adaptor, /*launch_dimensions=*/{8192, 2 * WarpSize()}, + /*output_tile_sizes=*/{128, 1})); + + constexpr int64_t kParamSizeBytes = 2048 * 512 * 4; + // The number of bytes read is the same for coalesced and uncoalesced reads. + EXPECT_EQ(res_coalesced.bytes_read, 2 * kParamSizeBytes); + EXPECT_EQ(res_uncoalesced.bytes_read, 2 * kParamSizeBytes); + + EXPECT_NEAR(absl::ToDoubleMicroseconds(res_coalesced.read_time), 11, 1); + EXPECT_NEAR(absl::ToDoubleMicroseconds(res_uncoalesced.read_time), 175, 1); +} + TEST_F(GpuIndexingPerformanceModelTest, GetLaunchDimensionsForTiledFusion_IsSupported) { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"( @@ -576,7 +621,6 @@ ENTRY main { // and corresponds to 4 warps. EXPECT_EQ(launch_dimensions.num_threads_per_block(), 4 * WarpSize()); } - class FlopsPerElementTest : public GpuIndexingPerformanceModelTest { public: void CompareFlopsModels(absl::string_view hlo_module_string) { diff --git a/third_party/xla/xla/service/gpu/model/tiled_hlo_instruction.h b/third_party/xla/xla/service/gpu/model/tiled_hlo_instruction.h index 146035b0cb1e55..409bbdf6b4e62a 100644 --- a/third_party/xla/xla/service/gpu/model/tiled_hlo_instruction.h +++ b/third_party/xla/xla/service/gpu/model/tiled_hlo_instruction.h @@ -71,12 +71,14 @@ class TiledHloInstruction { // Returns the tile sizes. The number of tile sizes is equal to the rank of // the output shape. const llvm::SmallVector& tile_sizes() const { return tile_sizes_; } + int64_t tile_size(int64_t dim_idx) const { return tile_sizes_[dim_idx]; } // Returns the tile strides. The number of tile strides is equal to the rank // of the output shape. const llvm::SmallVector& tile_strides() const { return tile_strides_; } + int64_t tile_stride(int64_t dim_idx) const { return tile_strides_[dim_idx]; } // Returns the indexing map from tile multi-index to tile offsets. The map has // a form of `(d0, d1, ...) -> (tile_offset0, tile_offset1, ...)`. The number From 6683674d7d90319fb42804a98e00a1daed0ea30e Mon Sep 17 00:00:00 2001 From: Henning Becker Date: Mon, 23 Sep 2024 04:06:12 -0700 Subject: [PATCH 135/483] Remove gpu_runtime.h in favor of cuda_runtime.h and rocm_runtime.h Functions in `GpuRuntime` are only ever called from backend specific code. Therefore we won't need the backend-agnostic abstraction and can just call the backend-specific implementations directly. PiperOrigin-RevId: 677727952 --- .../xla/xla/stream_executor/cuda/BUILD | 7 +-- .../xla/stream_executor/cuda/cuda_executor.cc | 6 +-- .../xla/stream_executor/cuda/cuda_runtime.cc | 8 ++-- .../gpu_runtime.h => cuda/cuda_runtime.h} | 14 +++--- third_party/xla/xla/stream_executor/gpu/BUILD | 13 ----- .../xla/xla/stream_executor/rocm/BUILD | 9 +--- .../xla/stream_executor/rocm/rocm_executor.cc | 6 +-- .../xla/stream_executor/rocm/rocm_runtime.cc | 8 ++-- .../xla/stream_executor/rocm/rocm_runtime.h | 47 +++++++++++++++++++ 9 files changed, 72 insertions(+), 46 deletions(-) rename third_party/xla/xla/stream_executor/{gpu/gpu_runtime.h => cuda/cuda_runtime.h} (86%) create mode 100644 third_party/xla/xla/stream_executor/rocm/rocm_runtime.h diff --git a/third_party/xla/xla/stream_executor/cuda/BUILD b/third_party/xla/xla/stream_executor/cuda/BUILD index 0e4b26fbcb9bd1..651eda5a62e39b 100644 --- a/third_party/xla/xla/stream_executor/cuda/BUILD +++ b/third_party/xla/xla/stream_executor/cuda/BUILD @@ -213,9 +213,8 @@ cuda_only_cc_library( cuda_only_cc_library( name = "cuda_runtime", srcs = ["cuda_runtime.cc"], + hdrs = ["cuda_runtime.h"], deps = [ - "//xla/stream_executor/gpu:gpu_runtime_header", - "//xla/stream_executor/gpu:gpu_types_header", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/log", "@com_google_absl//absl/status", @@ -223,7 +222,6 @@ cuda_only_cc_library( "@com_google_absl//absl/strings", "@local_config_cuda//cuda:cuda_headers", "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:statusor", ], ) @@ -831,7 +829,7 @@ cuda_only_cc_library( ":cuda_event", # buildcleaner: keep ":cuda_kernel", # buildcleaner: keep ":cuda_platform_id", - ":cuda_runtime", # buildcleaner: keep + ":cuda_runtime", ":cuda_status", ":cuda_version_parser", "//xla/stream_executor", @@ -856,7 +854,6 @@ cuda_only_cc_library( "//xla/stream_executor/gpu:gpu_event_header", "//xla/stream_executor/gpu:gpu_executor_header", "//xla/stream_executor/gpu:gpu_kernel_header", - "//xla/stream_executor/gpu:gpu_runtime_header", "//xla/stream_executor/gpu:gpu_semaphore", "//xla/stream_executor/gpu:gpu_stream_header", "//xla/stream_executor/gpu:gpu_timer", diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc b/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc index 28e15f1546b6c2..27486707b5707c 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc @@ -41,6 +41,7 @@ limitations under the License. #include "xla/stream_executor/command_buffer.h" #include "xla/stream_executor/cuda/cuda_event.h" #include "xla/stream_executor/cuda/cuda_platform_id.h" +#include "xla/stream_executor/cuda/cuda_runtime.h" #include "xla/stream_executor/cuda/cuda_status.h" #include "xla/stream_executor/cuda/cuda_version_parser.h" #include "xla/stream_executor/cuda/delay_kernel.h" @@ -56,7 +57,6 @@ limitations under the License. #include "xla/stream_executor/gpu/gpu_driver.h" #include "xla/stream_executor/gpu/gpu_event.h" #include "xla/stream_executor/gpu/gpu_kernel.h" -#include "xla/stream_executor/gpu/gpu_runtime.h" #include "xla/stream_executor/gpu/gpu_semaphore.h" #include "xla/stream_executor/gpu/gpu_stream.h" #include "xla/stream_executor/gpu/gpu_timer.h" @@ -230,7 +230,7 @@ absl::StatusOr> CudaExecutor::LoadKernel( << " from symbol pointer: " << symbol; TF_ASSIGN_OR_RETURN( GpuFunctionHandle function, - GpuRuntime::GetFuncBySymbol(spec.in_process_symbol().symbol())); + CudaRuntime::GetFuncBySymbol(spec.in_process_symbol().symbol())); cuda_kernel->set_gpu_function(function); } else { @@ -678,7 +678,7 @@ GpuExecutor::CreateDeviceDescription(int device_ordinal) { ParseCudaVersion(GpuDriver::GetDriverVersion().value_or(0)) .value_or(SemanticVersion{0, 0, 0})); desc.set_runtime_version( - ParseCudaVersion(GpuRuntime::GetRuntimeVersion().value_or(0)) + ParseCudaVersion(CudaRuntime::GetRuntimeVersion().value_or(0)) .value_or(SemanticVersion{0, 0, 0})); desc.set_compile_time_toolkit_version( ParseCudaVersion(CUDA_VERSION).value_or(SemanticVersion{0, 0, 0})); diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_runtime.cc b/third_party/xla/xla/stream_executor/cuda/cuda_runtime.cc index bf355cf9b7b1da..c9ced05c4b91e0 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_runtime.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_runtime.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "xla/stream_executor/cuda/cuda_runtime.h" + #include #include "absl/base/optimization.h" @@ -23,8 +25,6 @@ limitations under the License. #include "third_party/gpus/cuda/include/cuda.h" #include "third_party/gpus/cuda/include/cuda_runtime_api.h" #include "third_party/gpus/cuda/include/driver_types.h" -#include "xla/stream_executor/gpu/gpu_runtime.h" -#include "xla/stream_executor/gpu/gpu_types.h" #include "tsl/platform/logging.h" namespace stream_executor::gpu { @@ -42,7 +42,7 @@ static const char* ToString(cudaError_t error) { } \ } while (0) -absl::StatusOr GpuRuntime::GetFuncBySymbol(void* symbol) { +absl::StatusOr CudaRuntime::GetFuncBySymbol(void* symbol) { VLOG(2) << "Get CUDA function from a symbol: " << symbol; cudaFunction_t func; RETURN_IF_CUDA_RES_ERROR(cudaGetFuncBySymbol(&func, symbol), @@ -50,7 +50,7 @@ absl::StatusOr GpuRuntime::GetFuncBySymbol(void* symbol) { return reinterpret_cast(func); } -absl::StatusOr GpuRuntime::GetRuntimeVersion() { +absl::StatusOr CudaRuntime::GetRuntimeVersion() { VLOG(2) << "Get CUDA runtime version"; int32_t version; RETURN_IF_CUDA_RES_ERROR(cudaRuntimeGetVersion(&version), diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_runtime.h b/third_party/xla/xla/stream_executor/cuda/cuda_runtime.h similarity index 86% rename from third_party/xla/xla/stream_executor/gpu/gpu_runtime.h rename to third_party/xla/xla/stream_executor/cuda/cuda_runtime.h index 6f36c7ceab1ea1..32ebd5cf8611a5 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_runtime.h +++ b/third_party/xla/xla/stream_executor/cuda/cuda_runtime.h @@ -15,13 +15,13 @@ limitations under the License. // CUDA/ROCm runtime library wrapper functionality. -#ifndef XLA_STREAM_EXECUTOR_GPU_GPU_RUNTIME_H_ -#define XLA_STREAM_EXECUTOR_GPU_GPU_RUNTIME_H_ +#ifndef XLA_STREAM_EXECUTOR_CUDA_CUDA_RUNTIME_H_ +#define XLA_STREAM_EXECUTOR_CUDA_CUDA_RUNTIME_H_ #include #include "absl/status/statusor.h" -#include "xla/stream_executor/gpu/gpu_types.h" +#include "third_party/gpus/cuda/include/cuda.h" namespace stream_executor::gpu { @@ -39,10 +39,10 @@ namespace stream_executor::gpu { // //===----------------------------------------------------------------------===// -// Gpu runtime returns types defined in the stream_executor::gpu namespace, and +// Cuda runtime returns types defined in the stream_executor::gpu namespace, and // they usually correspond to the driver types, as driver API is the primary // integration API of Gpus into StreamExecutor. -class GpuRuntime { +class CudaRuntime { public: // Get pointer to device entry function that matches entry function `symbol`. // @@ -52,7 +52,7 @@ class GpuRuntime { // current device (and create it if it doesn't exist yet). // // https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__DRIVER.html#group__CUDART__DRIVER_1gaba6f8d01e745f0c8d8776ceb18be617 - static absl::StatusOr GetFuncBySymbol(void* symbol); + static absl::StatusOr GetFuncBySymbol(void* symbol); // Returns the Gpu Runtime version. // https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART____VERSION.html#group__CUDART____VERSION_1g0e3952c7802fd730432180f1f4a6cdc6 @@ -61,4 +61,4 @@ class GpuRuntime { } // namespace stream_executor::gpu -#endif // XLA_STREAM_EXECUTOR_GPU_GPU_RUNTIME_H_ +#endif // XLA_STREAM_EXECUTOR_CUDA_CUDA_RUNTIME_H_ diff --git a/third_party/xla/xla/stream_executor/gpu/BUILD b/third_party/xla/xla/stream_executor/gpu/BUILD index b123c2fc40ac38..d98dcfe248d4d9 100644 --- a/third_party/xla/xla/stream_executor/gpu/BUILD +++ b/third_party/xla/xla/stream_executor/gpu/BUILD @@ -195,19 +195,6 @@ gpu_only_cc_library( ), ) -gpu_only_cc_library( - name = "gpu_runtime_header", - hdrs = ["gpu_runtime.h"], - visibility = internal_visibility([ - "//xla/service/gpu:__subpackages__", - "//xla/stream_executor:__subpackages__", - ]), - deps = [ - ":gpu_types_header", - "@com_google_absl//absl/status:statusor", - ], -) - gpu_only_cc_library( name = "gpu_command_buffer", srcs = ["gpu_command_buffer.cc"], diff --git a/third_party/xla/xla/stream_executor/rocm/BUILD b/third_party/xla/xla/stream_executor/rocm/BUILD index bd1610200b16f5..e6fe7a4ef11566 100644 --- a/third_party/xla/xla/stream_executor/rocm/BUILD +++ b/third_party/xla/xla/stream_executor/rocm/BUILD @@ -83,17 +83,13 @@ cc_library( cc_library( name = "rocm_runtime", srcs = if_rocm_is_configured(["rocm_runtime.cc"]), - hdrs = if_rocm_is_configured([ - "rocm_driver_wrapper.h", - "rocm_driver.h", - ]), + hdrs = if_rocm_is_configured(["rocm_runtime.h"]), deps = if_rocm_is_configured([ # keep sorted + ":rocm_driver", "//xla/stream_executor", "//xla/stream_executor/gpu:context", "//xla/stream_executor/gpu:gpu_driver_header", - "//xla/stream_executor/gpu:gpu_runtime_header", - "//xla/stream_executor/gpu:gpu_types_header", "//xla/stream_executor/platform", "//xla/stream_executor/platform:dso_loader", "@com_google_absl//absl/base:core_headers", @@ -175,7 +171,6 @@ cc_library( "//xla/stream_executor/gpu:gpu_event", "//xla/stream_executor/gpu:gpu_executor_header", "//xla/stream_executor/gpu:gpu_kernel_header", - "//xla/stream_executor/gpu:gpu_runtime_header", "//xla/stream_executor/gpu:gpu_stream", "//xla/stream_executor/gpu:gpu_timer", "//xla/stream_executor/gpu:gpu_types_header", diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_executor.cc b/third_party/xla/xla/stream_executor/rocm/rocm_executor.cc index e46408be1974f5..4b3687fe92a5f4 100644 --- a/third_party/xla/xla/stream_executor/rocm/rocm_executor.cc +++ b/third_party/xla/xla/stream_executor/rocm/rocm_executor.cc @@ -54,7 +54,6 @@ limitations under the License. #include "xla/stream_executor/gpu/gpu_event.h" #include "xla/stream_executor/gpu/gpu_executor.h" #include "xla/stream_executor/gpu/gpu_kernel.h" -#include "xla/stream_executor/gpu/gpu_runtime.h" #include "xla/stream_executor/gpu/gpu_stream.h" #include "xla/stream_executor/gpu/gpu_timer.h" #include "xla/stream_executor/gpu/gpu_types.h" @@ -74,6 +73,7 @@ limitations under the License. #include "xla/stream_executor/rocm/rocm_driver_wrapper.h" #include "xla/stream_executor/rocm/rocm_event.h" #include "xla/stream_executor/rocm/rocm_platform_id.h" +#include "xla/stream_executor/rocm/rocm_runtime.h" #include "xla/stream_executor/rocm/rocm_version_parser.h" #include "xla/stream_executor/semantic_version.h" #include "xla/stream_executor/stream.h" @@ -300,7 +300,7 @@ absl::StatusOr> RocmExecutor::LoadKernel( #if TF_ROCM_VERSION >= 60200 TF_ASSIGN_OR_RETURN( GpuFunctionHandle function, - GpuRuntime::GetFuncBySymbol(spec.in_process_symbol().symbol())); + RocmRuntime::GetFuncBySymbol(spec.in_process_symbol().symbol())); rocm_kernel->set_gpu_function(function); #else rocm_kernel->set_gpu_function( @@ -704,7 +704,7 @@ GpuExecutor::CreateDeviceDescription(int device_ordinal) { desc.set_compile_time_toolkit_version( SemanticVersion{HIP_VERSION_MAJOR, HIP_VERSION_MINOR, HIP_VERSION_PATCH}); desc.set_runtime_version( - ParseRocmVersion(GpuRuntime::GetRuntimeVersion().value_or(0)) + ParseRocmVersion(RocmRuntime::GetRuntimeVersion().value_or(0)) .value_or(SemanticVersion{0, 0, 0})); desc.set_driver_version( ParseRocmVersion(GpuDriver::GetDriverVersion().value_or(0)) diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_runtime.cc b/third_party/xla/xla/stream_executor/rocm/rocm_runtime.cc index fe8dd31c47a7ee..56bdfd18726188 100644 --- a/third_party/xla/xla/stream_executor/rocm/rocm_runtime.cc +++ b/third_party/xla/xla/stream_executor/rocm/rocm_runtime.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "xla/stream_executor/rocm/rocm_runtime.h" + #include #include "absl/base/optimization.h" @@ -20,8 +22,6 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" -#include "xla/stream_executor/gpu/gpu_runtime.h" -#include "xla/stream_executor/gpu/gpu_types.h" #include "xla/stream_executor/rocm/rocm_driver.h" #include "xla/stream_executor/rocm/rocm_driver_wrapper.h" @@ -34,7 +34,7 @@ limitations under the License. namespace stream_executor { namespace gpu { -absl::StatusOr GpuRuntime::GetFuncBySymbol(void* symbol) { +absl::StatusOr RocmRuntime::GetFuncBySymbol(void* symbol) { VLOG(2) << "Get ROCM function from a symbol: " << symbol; #if TF_ROCM_VERSION >= 60200 hipFunction_t func; @@ -46,7 +46,7 @@ absl::StatusOr GpuRuntime::GetFuncBySymbol(void* symbol) { #endif // TF_ROCM_VERSION >= 60200 } -absl::StatusOr GpuRuntime::GetRuntimeVersion() { +absl::StatusOr RocmRuntime::GetRuntimeVersion() { VLOG(2) << "Get ROCM runtime version"; int32_t version; RETURN_IF_ROCM_ERROR(wrap::hipRuntimeGetVersion(&version), diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_runtime.h b/third_party/xla/xla/stream_executor/rocm/rocm_runtime.h new file mode 100644 index 00000000000000..83148914b93977 --- /dev/null +++ b/third_party/xla/xla/stream_executor/rocm/rocm_runtime.h @@ -0,0 +1,47 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// CUDA/ROCm runtime library wrapper functionality. + +#ifndef XLA_STREAM_EXECUTOR_ROCM_ROCM_RUNTIME_H_ +#define XLA_STREAM_EXECUTOR_ROCM_ROCM_RUNTIME_H_ + +#include + +#include "absl/status/statusor.h" +#include "rocm/include/hip/hip_runtime.h" + +namespace stream_executor::gpu { + +// Rocm runtime returns types defined in the stream_executor::gpu namespace, and +// they usually correspond to the driver types, as driver API is the primary +// integration API of Gpus into StreamExecutor. +class RocmRuntime { + public: + // Get pointer to device entry function that matches entry function `symbol`. + // + // WARNING: This will load all fatbins statically registered with the + // underlying runtime into runtime modules for the current context. If no + // context is current, the runtime will use the primary context for the + // current device (and create it if it doesn't exist yet). + static absl::StatusOr GetFuncBySymbol(void* symbol); + + // Returns the Gpu Runtime version. + static absl::StatusOr GetRuntimeVersion(); +}; + +} // namespace stream_executor::gpu + +#endif // XLA_STREAM_EXECUTOR_ROCM_ROCM_RUNTIME_H_ From f40f796f335414bdcf80be8ebeec9e6f6565e0f9 Mon Sep 17 00:00:00 2001 From: Harsha H S Date: Mon, 23 Sep 2024 04:49:22 -0700 Subject: [PATCH 136/483] PR #17203: [ROCm] Fix build break on gcc with constexpr introduced in d4218841f7 Imported from GitHub PR https://github.com/openxla/xla/pull/17203 Copybara import of the project: -- 4c0f32bbe294fdea75d20f84bf90c9424521947a by Harsha HS : [ROCm] Fix build break on gcc with constexprt introduced in d4218841f7 Merging this change closes #17203 PiperOrigin-RevId: 677739691 --- .../gpu/transforms/cudnn_fused_conv_rewriter_test.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/third_party/xla/xla/service/gpu/transforms/cudnn_fused_conv_rewriter_test.cc b/third_party/xla/xla/service/gpu/transforms/cudnn_fused_conv_rewriter_test.cc index 7e9e3710569406..93608eb864c9c7 100644 --- a/third_party/xla/xla/service/gpu/transforms/cudnn_fused_conv_rewriter_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/cudnn_fused_conv_rewriter_test.cc @@ -67,9 +67,9 @@ namespace m = match; using ::testing::HasSubstr; using ::testing::Not; -constexpr std::initializer_list kf16f32f64{"f16", "f32", - "f64"}; -constexpr std::initializer_list kf16f32{"f16", "f32"}; +static const std::initializer_list kf16f32f64{"f16", "f32", + "f64"}; +static const std::initializer_list kf16f32{"f16", "f32"}; class CudnnFusedConvRewriterHloTest : public HloTestBase { public: From 4cac97a15112020c63867eb831eaed1fa6e08bc3 Mon Sep 17 00:00:00 2001 From: Sergey Kozub Date: Mon, 23 Sep 2024 05:32:19 -0700 Subject: [PATCH 137/483] PR #17437: Check all F8 dtype combinations in //xla/tests:convert_test Imported from GitHub PR https://github.com/openxla/xla/pull/17437 This PR parametrizes float8 conversion tests to make sure all combinations are covered. All 5 supported float8 types' tests now follow the same pattern in the test file. Previously, the following tests were missing: - Conversion from F8e4m3b11fn to PRED; - Denormal and negative zero tests from F16 to F8e5m2; - `ConvertF32F8e5m2Roundtrip` was missing (test conversion from F32 to F8e5m2); - The complete set of exhaustive tests (vs other types) was only present for FNUZ types; - ...probably more Also merged the "Exhaustive4" and "Exhaustive5" suffixed tests that followed a similar pattern. Removed unnecessary saving/restoring of the debug options. Copybara import of the project: -- 91131ac613c5c86d3ab578ef26a1356944c5dff3 by Sergey Kozub : Check all F8 dtype combinations in //xla/tests:convert_test Merging this change closes #17437 PiperOrigin-RevId: 677751828 --- .../xla/xla/service/elemental_ir_emitter.cc | 4 +- third_party/xla/xla/tests/convert_test.cc | 735 +++++++++--------- 2 files changed, 367 insertions(+), 372 deletions(-) diff --git a/third_party/xla/xla/service/elemental_ir_emitter.cc b/third_party/xla/xla/service/elemental_ir_emitter.cc index 9b76745911fd5d..15961cff6328bc 100644 --- a/third_party/xla/xla/service/elemental_ir_emitter.cc +++ b/third_party/xla/xla/service/elemental_ir_emitter.cc @@ -471,8 +471,8 @@ llvm::Value* EmitF16ToF8e4m3b11fnuz(llvm::Value* f16_value, auto type = f16_value->getType(); auto f16_abs_value = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {f16_value}, {type}, b); - auto f16_zero = llvm::ConstantFP::getZero(type); - auto is_zero = b->CreateFCmpOEQ(f16_abs_value, f16_zero); + auto f16_zero_or_underflow = llvm::ConstantFP::get(type, 0x1.004p-14); + auto is_zero = b->CreateFCmpOLT(f16_abs_value, f16_zero_or_underflow); auto f8_overflow_threshold = llvm::ConstantFP::get(type, 0x1.fp+4); auto no_overflow = b->CreateFCmpOLT(f16_abs_value, f8_overflow_threshold); diff --git a/third_party/xla/xla/tests/convert_test.cc b/third_party/xla/xla/tests/convert_test.cc index ef252a594e930b..1663ffe8a88619 100644 --- a/third_party/xla/xla/tests/convert_test.cc +++ b/third_party/xla/xla/tests/convert_test.cc @@ -54,11 +54,20 @@ class ConvertTestT : public ConvertTest { using ConvertTest::ConvertTest; }; using FloatingPointTypeList = - ::testing::Types; TYPED_TEST_SUITE(ConvertTestT, FloatingPointTypeList); +template +class ConvertTestF16 : public ConvertTest { + public: + using ConvertTest::ConvertTest; +}; +using F16TypeList = ::testing::Types; +TYPED_TEST_SUITE(ConvertTestF16, F16TypeList); + TEST_F(ConvertTest, ConvertR1S32ToR1S32) { XlaBuilder builder(TestName()); auto a = ConstantR1(&builder, {42, 64}); @@ -729,8 +738,21 @@ XLA_TEST_F(ConvertTest, ConvertF32BF16) { } } +XLA_TYPED_TEST(ConvertTestT, ConvertFPToPred) { + XlaBuilder builder(this->TestName()); + using FP = TypeParam; + + auto a = ConstantR1(&builder, {FP{0.0}, FP{0.25}, FP{2.0}, FP{-0.0}}); + ConvertElementType(a, PRED); + + std::array expected = {false, true, true, false}; + this->template ComputeAndCompareR1(&builder, expected, {}); +} + +// ----- F8E5M2 + XLA_TEST_F(ConvertTest, ConvertF16F8e5m2Roundtrip) { - // Convert from FP16 to FP8, then back to FP16 + // Convert from FP16 to FP8, then back to FP16. XlaBuilder builder(TestName()); float nan = std::numeric_limits::quiet_NaN(); float inf = std::numeric_limits::infinity(); @@ -741,6 +763,7 @@ XLA_TEST_F(ConvertTest, ConvertF16F8e5m2Roundtrip) { } test_cases[] = { // clang-format off {0.0, 0.0}, + {-0.0, 0.0}, {1.0, 1.0}, {-1.0, -1.0}, {nan, nan}, @@ -752,8 +775,18 @@ XLA_TEST_F(ConvertTest, ConvertF16F8e5m2Roundtrip) { {0x1.DFCp15, 0x1.Cp15}, // Largest number that doesn't overflow {0x1.Ep15, inf}, // Smallest number that overflows {0x1p16, inf}, // Overflow - {0x1p-14, 0x1p-14}, // Smallest normal - {0x1.8p-15, 0x1.8p-15}, // Denormal + {0x1p-14, 0x1p-14}, // Smallest F8 normal + {0x1.Cp-15, 0x1p-14}, // Smallest number rounding up to normal + + // Denormal tests + {0x1.0p-15, 0x1.0p-15}, // Denormal without rounding + {0x1.4p-15, 0x1.0p-15}, // Round-to-even down + {0x1.Cp-15, 0x1.0p-14}, // Round-to-even up + {0x1.3p-15, 0x1.0p-15}, // Round-to-nearest down + {0x1.5p-15, 0x1.8p-15}, // Round-to-nearest up + {0x1p-17, 0}, // Largest number that underflows + {0x1.04p-17, 0x1p-16}, // Smallest number that doesn't underflow + {0x1.BFp-15, 0x1.8p-15}, // Largest number that rounds to denormal }; std::vector inputs; @@ -762,126 +795,129 @@ XLA_TEST_F(ConvertTest, ConvertF16F8e5m2Roundtrip) { inputs.push_back(Eigen::half{test_case.input}); expected_roundtrip.push_back(Eigen::half{test_case.expected_roundtrip}); } + auto f8 = ConvertElementType(ConstantR1(&builder, inputs), F8E5M2); ConvertElementType(f8, F16); - const bool saved = - execution_options_.debug_options().xla_allow_excess_precision(); - execution_options_.mutable_debug_options()->set_xla_allow_excess_precision( - false); - // Pass in ErrorSpec, as this causes all NaNs to be treated as equal. ComputeAndCompareR1(&builder, expected_roundtrip, {}, ErrorSpec(0.)); - execution_options_.mutable_debug_options()->set_xla_allow_excess_precision( - saved); } -XLA_TEST_F(ConvertTest, ConvertF8e5m2F16RoundtripExhaustive) { - // Convert from FP8 to FP16, then back to FP8 +XLA_TEST_F(ConvertTest, DISABLED_ON_CPU(ConvertF32F8e5m2Roundtrip)) { + // Convert from FP32 to FP8, then back to FP32. XlaBuilder builder(TestName()); + float nan = std::numeric_limits::quiet_NaN(); + float inf = std::numeric_limits::infinity(); - std::vector all_f8; - for (int i = 0; i < 256; i++) { - all_f8.push_back( - Eigen::numext::bit_cast(static_cast(i))); - } - - xla::XlaOp all_f8_as_f8 = ConstantR1(&builder, all_f8); - xla::XlaOp all_f8_as_f16 = ConvertElementType(all_f8_as_f8, F16); - ConvertElementType(all_f8_as_f16, F8E5M2); - - // Pass in ErrorSpec, as this causes all NaNs to be treated as equal. - // Round-tripping a NaN will turn it into a quiet NaN and doesn't necessarily - // preserve the payload. - ComputeAndCompareR1(&builder, all_f8, {}, ErrorSpec(0.)); -} + struct TestCase { + float input; + float expected_roundtrip; + } test_cases[] = { + // clang-format off + {0.0, 0.0}, + {-0.0, -0.0}, + {1.0, 1.0}, + {-1.0, -1.0}, + {nan, nan}, + {inf, inf}, + // clang-format on + {0x1.2p0, 0x1p0}, // Round-to-even down + {0x1.6p0, 0x1.8p0}, // Round-to-even up + {0x1.Cp15, 0x1.Cp15}, // Max value + {0x1.DFFFFEp15, 0x1.Cp15}, // Largest number that doesn't overflow + {0x1.Ep15, inf}, // Smallest number that overflows + {0x1p16, inf}, // Overflow + {0x1p-14, 0x1p-14}, // Smallest F8 normal + {0x1.Cp-15, 0x1p-14}, // Smallest number rounding up to normal -XLA_TEST_F(ConvertTest, ConvertF8e5m2F32Exhaustive) { - // Convert from f8e5m2 to f32. - XlaBuilder builder(TestName()); + // Denormal tests + {0x1.0p-15, 0x1.0p-15}, // Denormal without rounding + {0x1.4p-15, 0x1.0p-15}, // Round-to-even down + {0x1.Cp-15, 0x1.0p-14}, // Round-to-even up + {0x1.3p-15, 0x1.0p-15}, // Round-to-nearest down + {0x1.5p-15, 0x1.8p-15}, // Round-to-nearest up + {0x1p-17, 0}, // Largest number that underflows + {0x1.000002p-17, 0x1p-16}, // Smallest number that doesn't underflow + {0x1.BFFFFEp-15, 0x1.8p-15}, // Largest number that rounds to denormal + }; - std::vector all_f8; - std::vector all_f32; - for (int i = 0; i < 256; i++) { - all_f8.push_back( - Eigen::numext::bit_cast(static_cast(i))); - all_f32.push_back(tsl::float8_e5m2::ConvertTo(all_f8.back())); + std::vector inputs; + std::vector expected_roundtrip; + for (auto test_case : test_cases) { + inputs.push_back(test_case.input); + expected_roundtrip.push_back(test_case.expected_roundtrip); } - xla::XlaOp all_f8_as_f8 = ConstantR1(&builder, all_f8); - ConvertElementType(all_f8_as_f8, F32); - - ComputeAndCompareR1(&builder, all_f32, {}, ErrorSpec(0.)); + auto f8 = ConvertElementType(ConstantR1(&builder, inputs), F8E5M2); + ConvertElementType(f8, F32); + ComputeAndCompareR1(&builder, expected_roundtrip, {}, ErrorSpec(0.)); } -XLA_TEST_F(ConvertTest, ConvertF8e5m2fnuzF32Exhaustive) { - // Convert from f8e5m2fnuz to f32. - XlaBuilder builder(TestName()); +XLA_TYPED_TEST(ConvertTestT, ConvertF8e5m2RoundtripExhaustive) { + // Convert from FP8 to supported floating point type, then back to FP8. + XlaBuilder builder(this->TestName()); - std::vector all_f8; - std::vector all_f32; + using From = tsl::float8_e5m2; + std::vector all_f8; for (int i = 0; i < 256; i++) { - all_f8.push_back( - Eigen::numext::bit_cast(static_cast(i))); - all_f32.push_back(tsl::float8_e5m2fnuz::ConvertTo(all_f8.back())); + all_f8.push_back(Eigen::numext::bit_cast(static_cast(i))); } - xla::XlaOp all_f8_as_f8 = ConstantR1(&builder, all_f8); - ConvertElementType(all_f8_as_f8, F32); - - ComputeAndCompareR1(&builder, all_f32, {}, ErrorSpec(0.)); + xla::XlaOp all_f8_as_fp = + ConvertElementType(ConstantR1(&builder, all_f8), + primitive_util::NativeToPrimitiveType()); + ConvertElementType(all_f8_as_fp, F8E5M2); + this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); } -XLA_TEST_F(ConvertTest, ConvertF8e4m3F32Exhaustive) { - // Convert from f8e4m3 to f32. - XlaBuilder builder(TestName()); +XLA_TYPED_TEST(ConvertTestT, ConvertF8e5m2RoundtripExhaustive2) { + // Convert from supported floating point type to FP8. + XlaBuilder builder(this->TestName()); - std::vector all_f8; - std::vector all_f32; + std::vector all_f8; for (int i = 0; i < 256; i++) { - all_f8.push_back( - Eigen::numext::bit_cast(static_cast(i))); - all_f32.push_back(tsl::float8_e4m3fn::ConvertTo(all_f8.back())); + all_f8.push_back(static_cast( + Eigen::numext::bit_cast(static_cast(i)))); } - xla::XlaOp all_f8_as_f8 = ConstantR1(&builder, all_f8); - ConvertElementType(all_f8_as_f8, F32); - - ComputeAndCompareR1(&builder, all_f32, {}, ErrorSpec(0.)); + ConvertElementType(ConstantR1(&builder, all_f8), F8E5M2); + this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); } -XLA_TEST_F(ConvertTest, ConvertF8e5m2F16RoundtripExhaustive2) { - // Convert from F16 to FP8. +XLA_TYPED_TEST(ConvertTestT, ConvertF8e5m2RoundtripExhaustive3) { + // Convert from FP8 to supported floating point type. XlaBuilder builder(this->TestName()); - std::vector inputs; - for (int i = 0; i < 65536; i++) { - inputs.push_back( - Eigen::numext::bit_cast(static_cast(i))); + using From = tsl::float8_e5m2; + std::vector all_f8; + for (int i = 0; i < 256; i++) { + all_f8.push_back(Eigen::numext::bit_cast(static_cast(i))); } - xla::XlaOp all_f16_to_f8 = ConstantR1(&builder, inputs); - ConvertElementType(all_f16_to_f8, F8E5M2); + ConvertElementType(ConstantR1(&builder, all_f8), + primitive_util::NativeToPrimitiveType()); this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); } -XLA_TEST_F(ConvertTest, ConvertF8e5m2BF16RoundtripExhaustive3) { - // Convert from BF16 to FP8. +XLA_TYPED_TEST(ConvertTestF16, ConvertF8e5m2F16RoundtripExhaustive4) { + // Convert from (B)F16 to FP8. XlaBuilder builder(this->TestName()); - std::vector inputs; + std::vector inputs; for (int i = 0; i < 65536; i++) { inputs.push_back( - Eigen::numext::bit_cast(static_cast(i))); + Eigen::numext::bit_cast(static_cast(i))); } - xla::XlaOp all_bf16_to_f8 = ConstantR1(&builder, inputs); - ConvertElementType(all_bf16_to_f8, F8E5M2); + xla::XlaOp all_f16_to_f8 = ConstantR1(&builder, inputs); + ConvertElementType(all_f16_to_f8, F8E5M2); this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); } +// ----- F8E4M3FN + XLA_TEST_F(ConvertTest, ConvertF16F8e4m3fnRoundtrip) { - // Convert from FP16 to FP8, then back to FP16 + // Convert from FP16 to FP8, then back to FP16. XlaBuilder builder(TestName()); float nan = std::numeric_limits::quiet_NaN(); float inf = std::numeric_limits::infinity(); @@ -910,8 +946,8 @@ XLA_TEST_F(ConvertTest, ConvertF16F8e4m3fnRoundtrip) { {0x1.0p-8, 0x1.0p-8}, // Denormal without rounding {0x1.4p-8, 0x1.0p-8}, // Round-to-even down {0x1.Cp-8, 0x1.0p-7}, // Round-to-even up - {0x1.5p-7, 0x1.4p-7}, // Round-to-nearest down - {0x1.3p-7, 0x1.4p-7}, // Round-to-nearest up + {0x1.3p-8, 0x1.0p-8}, // Round-to-nearest down + {0x1.5p-8, 0x1.8p-8}, // Round-to-nearest up {0x1p-10, 0}, // Largest number that underflows {0x1.004p-10, 0x1p-9}, // Smallest number that doesn't underflow {0x1.DFCp-7, 0x1.Cp-7}, // Largest number that rounds to denormal @@ -927,95 +963,124 @@ XLA_TEST_F(ConvertTest, ConvertF16F8e4m3fnRoundtrip) { auto f8 = ConvertElementType(ConstantR1(&builder, inputs), F8E4M3FN); ConvertElementType(f8, F16); - const bool saved = - execution_options_.debug_options().xla_allow_excess_precision(); - execution_options_.mutable_debug_options()->set_xla_allow_excess_precision( - false); - // Pass in ErrorSpec, as this causes all NaNs to be treated as equal. ComputeAndCompareR1(&builder, expected_roundtrip, {}, ErrorSpec(0.)); - execution_options_.mutable_debug_options()->set_xla_allow_excess_precision( - saved); } -XLA_TEST_F(ConvertTest, ConvertF8e4m3fnF16RoundtripExhaustive) { - // Convert from FP8 to FP16, then back to FP8 +XLA_TEST_F(ConvertTest, DISABLED_ON_CPU(ConvertF32F8e4m3fnRoundtrip)) { + // Convert from FP32 to FP8, then back to FP32. XlaBuilder builder(TestName()); + float nan = std::numeric_limits::quiet_NaN(); + float inf = std::numeric_limits::infinity(); - std::vector all_f8; - for (int i = 0; i < 256; i++) { - all_f8.push_back( - Eigen::numext::bit_cast(static_cast(i))); + struct TestCase { + float input; + float expected_roundtrip; + } test_cases[] = { + // clang-format off + {0.0, 0.0}, + {-0.0, -0.0}, + {1.0, 1.0}, + {-1.0, -1.0}, + {inf, nan}, + // clang-format on + {0x1.1p0, 0x1p0}, // Round-to-even down + {0x1.3p0, 0x1.4p0}, // Round-to-even up + {0x1.Cp8, 0x1.Cp8}, // Max value + {0x1.Dp8, 0x1.Cp8}, // Largest number that doesn't overflow + {0x1.D00002p8, nan}, // Smallest number that overflows + {0x1p9, nan}, // Overflow + {0x1p-6, 0x1p-6}, // Smallest F8 normal + {0x1.Ep-7, 0x1p-6}, // Smallest number rounding up to normal + + // Denormal tests + {0x1.0p-8, 0x1.0p-8}, // Denormal without rounding + {0x1.4p-8, 0x1.0p-8}, // Round-to-even down + {0x1.Cp-8, 0x1.0p-7}, // Round-to-even up + {0x1.3p-8, 0x1.0p-8}, // Round-to-nearest down + {0x1.5p-8, 0x1.8p-8}, // Round-to-nearest up + {0x1p-10, 0}, // Largest number that underflows + {0x1.000002p-10, 0x1p-9}, // Smallest number that doesn't underflow + {0x1.DFFFFEp-7, 0x1.Cp-7}, // Largest number that rounds to denormal + }; + + std::vector inputs; + std::vector expected_roundtrip; + for (auto test_case : test_cases) { + inputs.push_back(test_case.input); + expected_roundtrip.push_back(test_case.expected_roundtrip); } - xla::XlaOp all_f8_as_f8 = ConstantR1(&builder, all_f8); - xla::XlaOp all_f8_as_f16 = ConvertElementType(all_f8_as_f8, F16); - ConvertElementType(all_f8_as_f16, F8E4M3FN); - ComputeAndCompare(&builder, {}, ErrorSpec(0.)); + auto f8 = ConvertElementType(ConstantR1(&builder, inputs), F8E4M3FN); + ConvertElementType(f8, F32); + ComputeAndCompareR1(&builder, expected_roundtrip, {}, ErrorSpec(0.)); } -XLA_TEST_F(ConvertTest, ConvertF8e4m3fnF16RoundtripExhaustive2) { - // Convert from FP32 to FP8. - XlaBuilder builder(TestName()); +XLA_TYPED_TEST(ConvertTestT, ConvertF8e4m3fnRoundtripExhaustive) { + // Convert from FP8 to supported floating point type, then back to FP8. + XlaBuilder builder(this->TestName()); - std::vector all_f8; + using From = tsl::float8_e4m3fn; + std::vector all_f8; for (int i = 0; i < 256; i++) { - all_f8.push_back(static_cast( - Eigen::numext::bit_cast(static_cast(i)))); + all_f8.push_back(Eigen::numext::bit_cast(static_cast(i))); } - xla::XlaOp all_f8_as_f32 = ConstantR1(&builder, all_f8); - ConvertElementType(all_f8_as_f32, F8E4M3FN); - ComputeAndCompare(&builder, {}, ErrorSpec(0.)); + xla::XlaOp all_f8_as_fp = + ConvertElementType(ConstantR1(&builder, all_f8), + primitive_util::NativeToPrimitiveType()); + ConvertElementType(all_f8_as_fp, F8E4M3FN); + this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); } -XLA_TEST_F(ConvertTest, ConvertF8e4m3fnF16RoundtripExhaustive3) { - // Convert from FP8 to FP32. - XlaBuilder builder(TestName()); +XLA_TYPED_TEST(ConvertTestT, ConvertF8e4m3fnRoundtripExhaustive2) { + // Convert from supported floating point type to FP8. + XlaBuilder builder(this->TestName()); - std::vector all_f8; + std::vector all_f8; for (int i = 0; i < 256; i++) { - all_f8.push_back( - Eigen::numext::bit_cast(static_cast(i))); + all_f8.push_back(static_cast( + Eigen::numext::bit_cast(static_cast(i)))); } - xla::XlaOp all_f8_as_f8 = ConstantR1(&builder, all_f8); - ConvertElementType(all_f8_as_f8, F32); - ComputeAndCompare(&builder, {}, ErrorSpec(0.)); + ConvertElementType(ConstantR1(&builder, all_f8), F8E4M3FN); + this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); } -XLA_TEST_F(ConvertTest, ConvertF8e4m3fnF16RoundtripExhaustive4) { - // Convert from F16 to FP8. +XLA_TYPED_TEST(ConvertTestT, ConvertF8e4m3fnRoundtripExhaustive3) { + // Convert from FP8 to supported floating point type. XlaBuilder builder(this->TestName()); - std::vector inputs; - for (int i = 0; i < 65536; i++) { - inputs.push_back( - Eigen::numext::bit_cast(static_cast(i))); + using From = tsl::float8_e4m3fn; + std::vector all_f8; + for (int i = 0; i < 256; i++) { + all_f8.push_back(Eigen::numext::bit_cast(static_cast(i))); } - xla::XlaOp all_f16_to_f8 = ConstantR1(&builder, inputs); - ConvertElementType(all_f16_to_f8, F8E4M3FN); + ConvertElementType(ConstantR1(&builder, all_f8), + primitive_util::NativeToPrimitiveType()); this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); } -XLA_TEST_F(ConvertTest, ConvertF8e4m3fnBF16RoundtripExhaustive5) { - // Convert from BF16 to FP8. +XLA_TYPED_TEST(ConvertTestF16, ConvertF8e4m3fnF16RoundtripExhaustive4) { + // Convert from (B)F16 to FP8. XlaBuilder builder(this->TestName()); - std::vector inputs; + std::vector inputs; for (int i = 0; i < 65536; i++) { inputs.push_back( - Eigen::numext::bit_cast(static_cast(i))); + Eigen::numext::bit_cast(static_cast(i))); } - xla::XlaOp all_bf16_to_f8 = ConstantR1(&builder, inputs); - ConvertElementType(all_bf16_to_f8, F8E4M3FN); + xla::XlaOp all_f16_to_f8 = ConstantR1(&builder, inputs); + ConvertElementType(all_f16_to_f8, F8E4M3FN); this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); } +// ----- F8E4M3B11FNUZ + XLA_TEST_F(ConvertTest, ConvertF16F8e4m3b11fnuzRoundtrip) { - // Convert from FP16 to FP8, then back to FP16 + // Convert from FP16 to FP8, then back to FP16. XlaBuilder builder(TestName()); float nan = std::numeric_limits::quiet_NaN(); float inf = std::numeric_limits::infinity(); @@ -1044,8 +1109,8 @@ XLA_TEST_F(ConvertTest, ConvertF16F8e4m3b11fnuzRoundtrip) { {0x1.0p-12, 0x1.0p-12}, // Denormal without rounding {0x1.4p-12, 0x1.0p-12}, // Round-to-even down {0x1.Cp-12, 0x1.0p-11}, // Round-to-even up - {0x1.5p-11, 0x1.4p-11}, // Round-to-nearest down - {0x1.3p-11, 0x1.4p-11}, // Round-to-nearest up + {0x1.3p-12, 0x1.0p-12}, // Round-to-nearest down + {0x1.5p-12, 0x1.8p-12}, // Round-to-nearest up {0x1p-14, 0}, // Largest number that underflows {0x1.004p-14, 0x1p-13}, // Smallest number that doesn't underflow {0x1.DFCp-11, 0x1.Cp-11}, // Largest number that rounds to denormal @@ -1061,67 +1126,125 @@ XLA_TEST_F(ConvertTest, ConvertF16F8e4m3b11fnuzRoundtrip) { auto f8 = ConvertElementType(ConstantR1(&builder, inputs), F8E4M3B11FNUZ); ConvertElementType(f8, F16); - const bool saved = - execution_options_.debug_options().xla_allow_excess_precision(); - execution_options_.mutable_debug_options()->set_xla_allow_excess_precision( - false); ComputeAndCompareR1(&builder, expected_roundtrip, {}, ErrorSpec(0.)); - execution_options_.mutable_debug_options()->set_xla_allow_excess_precision( - saved); } -XLA_TEST_F(ConvertTest, ConvertF8e4m3b11fnuzF16RoundtripExhaustive) { - // Convert from FP8 to FP16, then back to FP8 +XLA_TEST_F(ConvertTest, DISABLED_ON_CPU(ConvertF32F8e4m3b11fnuzRoundtrip)) { + // Convert from FP32 to FP8, then back to FP32. XlaBuilder builder(TestName()); + float nan = std::numeric_limits::quiet_NaN(); + float inf = std::numeric_limits::infinity(); + + struct TestCase { + float input; + float expected_roundtrip; + } test_cases[] = { + // clang-format off + {0.0, 0.0}, + {-0.0, 0.0}, + {1.0, 1.0}, + {-1.0, -1.0}, + {inf, nan}, + // clang-format on + {0x1.1p0, 0x1p0}, // Round-to-even down + {0x1.3p0, 0x1.4p0}, // Round-to-even up + {0x1.Ep4, 0x1.Ep4}, // Max value + {0x1.EFFFFEp4, 0x1.Ep4}, // Largest number that doesn't overflow + {0x1.Fp4, nan}, // Smallest number that overflows + {0x1p5, nan}, // Overflow + {0x1p-10, 0x1p-10}, // Smallest F8 normal + {0x1.Ep-11, 0x1p-10}, // Smallest number rounding up to normal + + // Denormal tests + {0x1.0p-12, 0x1.0p-12}, // Denormal without rounding + {0x1.4p-12, 0x1.0p-12}, // Round-to-even down + {0x1.Cp-12, 0x1.0p-11}, // Round-to-even up + {0x1.3p-12, 0x1.0p-12}, // Round-to-nearest down + {0x1.5p-12, 0x1.8p-12}, // Round-to-nearest up + {0x1p-14, 0}, // Largest number that underflows + {0x1.000002p-14, 0x1p-13}, // Smallest number that doesn't underflow + {0x1.DFFFFEp-11, 0x1.Cp-11}, // Largest number that rounds to denormal + }; - std::vector all_f8; + std::vector inputs; + std::vector expected_roundtrip; + for (auto test_case : test_cases) { + inputs.push_back(test_case.input); + expected_roundtrip.push_back(test_case.expected_roundtrip); + } + + auto f8 = + ConvertElementType(ConstantR1(&builder, inputs), F8E4M3B11FNUZ); + ConvertElementType(f8, F32); + ComputeAndCompareR1(&builder, expected_roundtrip, {}, ErrorSpec(0.)); +} + +XLA_TYPED_TEST(ConvertTestT, ConvertF8e4m3b11fnuzRoundtripExhaustive) { + // Convert from FP8 to supported floating point type, then back to FP8. + XlaBuilder builder(this->TestName()); + + using From = tsl::float8_e4m3b11fnuz; + std::vector all_f8; for (int i = 0; i < 256; i++) { - all_f8.push_back(Eigen::numext::bit_cast( - static_cast(i))); + all_f8.push_back(Eigen::numext::bit_cast(static_cast(i))); } - xla::XlaOp all_f8_as_f8 = - ConstantR1(&builder, all_f8); - xla::XlaOp all_f8_as_f16 = ConvertElementType(all_f8_as_f8, F16); - ConvertElementType(all_f8_as_f16, F8E4M3B11FNUZ); - ComputeAndCompare(&builder, {}, ErrorSpec(0.)); + xla::XlaOp all_f8_as_fp = + ConvertElementType(ConstantR1(&builder, all_f8), + primitive_util::NativeToPrimitiveType()); + ConvertElementType(all_f8_as_fp, F8E4M3B11FNUZ); + this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); } -XLA_TEST_F(ConvertTest, ConvertF8e4m3b11fnuzF16RoundtripExhaustive2) { - // Convert from FP32 to FP8. - XlaBuilder builder(TestName()); +XLA_TYPED_TEST(ConvertTestT, ConvertF8e4m3b11fnuzRoundtripExhaustive2) { + // Convert from supported floating point type to FP8. + XlaBuilder builder(this->TestName()); - std::vector all_f8; + std::vector all_f8; for (int i = 0; i < 256; i++) { all_f8.push_back( - static_cast(Eigen::numext::bit_cast( + static_cast(Eigen::numext::bit_cast( static_cast(i)))); } - xla::XlaOp all_f8_as_f32 = ConstantR1(&builder, all_f8); - ConvertElementType(all_f8_as_f32, F8E4M3B11FNUZ); - ComputeAndCompare(&builder, {}, ErrorSpec(0.)); + ConvertElementType(ConstantR1(&builder, all_f8), F8E4M3B11FNUZ); + this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); } -XLA_TEST_F(ConvertTest, ConvertF8e4m3b11fnuzF16RoundtripExhaustive3) { - // Convert from FP8 to FP32. - XlaBuilder builder(TestName()); +XLA_TYPED_TEST(ConvertTestT, ConvertF8e4m3b11fnuzRoundtripExhaustive3) { + // Convert from FP8 to supported floating point type. + XlaBuilder builder(this->TestName()); - std::vector all_f8; + using From = tsl::float8_e4m3b11fnuz; + std::vector all_f8; for (int i = 0; i < 256; i++) { - all_f8.push_back(Eigen::numext::bit_cast( - static_cast(i))); + all_f8.push_back(Eigen::numext::bit_cast(static_cast(i))); + } + + ConvertElementType(ConstantR1(&builder, all_f8), + primitive_util::NativeToPrimitiveType()); + this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); +} + +XLA_TYPED_TEST(ConvertTestF16, ConvertF8e4m3b11fnuzF16RoundtripExhaustive4) { + // Convert from (B)F16 to FP8. + XlaBuilder builder(this->TestName()); + + std::vector all_f16; + for (int i = 0; i < 65536; i++) { + all_f16.push_back( + Eigen::numext::bit_cast(static_cast(i))); } - xla::XlaOp all_f8_as_f8 = - ConstantR1(&builder, all_f8); - ConvertElementType(all_f8_as_f8, F32); - ComputeAndCompare(&builder, {}, ErrorSpec(0.)); + ConvertElementType(ConstantR1(&builder, all_f16), F8E4M3B11FNUZ); + this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); } +// ----- F8E5M2FNUZ + XLA_TEST_F(ConvertTest, ConvertF16F8e5m2fnuzRoundtrip) { - // Convert from FP16 to FP8, then back to FP16 + // Convert from FP16 to FP8, then back to FP16. XlaBuilder builder(TestName()); float nan = std::numeric_limits::quiet_NaN(); float inf = std::numeric_limits::infinity(); @@ -1132,11 +1255,11 @@ XLA_TEST_F(ConvertTest, ConvertF16F8e5m2fnuzRoundtrip) { } test_cases[] = { // clang-format off {0.0, 0.0}, - {-0.0, 0.0}, // No signed zero in F8E5M2FNUZ + {-0.0, 0.0}, {1.0, 1.0}, {-1.0, -1.0}, {nan, nan}, - {inf, nan}, // No Inf in F8E4M3FNUZ + {inf, nan}, // clang-format on {0x1.2p0, 0x1p0}, // Round-to-even down {0x1.6p0, 0x1.8p0}, // Round-to-even up @@ -1168,18 +1291,12 @@ XLA_TEST_F(ConvertTest, ConvertF16F8e5m2fnuzRoundtrip) { auto f8 = ConvertElementType(ConstantR1(&builder, inputs), F8E5M2FNUZ); ConvertElementType(f8, F16); - const bool saved = - execution_options_.debug_options().xla_allow_excess_precision(); - execution_options_.mutable_debug_options()->set_xla_allow_excess_precision( - false); ComputeAndCompareR1(&builder, expected_roundtrip, {}, ErrorSpec(0.)); - execution_options_.mutable_debug_options()->set_xla_allow_excess_precision( - saved); } XLA_TEST_F(ConvertTest, ConvertF32F8e5m2fnuzRoundtrip) { - // Convert from FP32 to FP8, then back to FP32 + // Convert from FP32 to FP8, then back to FP32. XlaBuilder builder(TestName()); float nan = std::numeric_limits::quiet_NaN(); float inf = std::numeric_limits::infinity(); @@ -1190,11 +1307,11 @@ XLA_TEST_F(ConvertTest, ConvertF32F8e5m2fnuzRoundtrip) { } test_cases[] = { // clang-format off {0.0, 0.0}, - {-0.0, 0.0}, // No signed zero in F8E5M2FNUZ + {-0.0, 0.0}, {1.0, 1.0}, {-1.0, -1.0}, {nan, nan}, - {inf, nan}, // No Inf in F8E4M3FNUZ + {inf, nan}, // clang-format on {0x1.2p0, 0x1p0}, // Round-to-even down {0x1.6p0, 0x1.8p0}, // Round-to-even up @@ -1209,12 +1326,11 @@ XLA_TEST_F(ConvertTest, ConvertF32F8e5m2fnuzRoundtrip) { {0x1.0p-16, 0x1.0p-16}, // Denormal without rounding {0x1.4p-16, 0x1.0p-16}, // Round-to-even down {0x1.Cp-16, 0x1.0p-15}, // Round-to-even up - {0x1.3FFFFEp-16, 0x1.0p-16}, // Round-to-nearest down - {0x1.5FFFFEp-16, 0x1.8p-16}, // Round-to-nearest up + {0x1.3p-16, 0x1.0p-16}, // Round-to-nearest down + {0x1.5p-16, 0x1.8p-16}, // Round-to-nearest up {0x1p-18, 0}, // Largest number that underflows {0x1.000002p-18, 0x1p-17}, // Smallest number that doesn't underflow {0x1.BFFFFEp-16, 0x1.8p-16}, // Largest number that rounds to denormal - {0x1.FFFFFEp-50, 0}, // A very small input that should underflow }; std::vector inputs; @@ -1226,41 +1342,24 @@ XLA_TEST_F(ConvertTest, ConvertF32F8e5m2fnuzRoundtrip) { auto f8 = ConvertElementType(ConstantR1(&builder, inputs), F8E5M2FNUZ); ConvertElementType(f8, F32); - const bool saved = - execution_options_.debug_options().xla_allow_excess_precision(); - execution_options_.mutable_debug_options()->set_xla_allow_excess_precision( - false); ComputeAndCompareR1(&builder, expected_roundtrip, {}, ErrorSpec(0.)); - execution_options_.mutable_debug_options()->set_xla_allow_excess_precision( - saved); } -XLA_TEST_F(ConvertTest, ConvertF8e5m2fnuzRoundtripExhaustive) { - // Convert from FP8 to each supported floating type, then back to FP8. - XlaBuilder builder(TestName()); +XLA_TYPED_TEST(ConvertTestT, ConvertF8e5m2fnuzRoundtripExhaustive) { + // Convert from FP8 to supported floating point type, then back to FP8. + XlaBuilder builder(this->TestName()); - std::vector all_f8; + using From = tsl::float8_e5m2fnuz; + std::vector all_f8; for (int i = 0; i < 256; i++) { - all_f8.push_back(Eigen::numext::bit_cast( - static_cast(i))); + all_f8.push_back(Eigen::numext::bit_cast(static_cast(i))); } - const bool saved = - execution_options_.debug_options().xla_allow_excess_precision(); - execution_options_.mutable_debug_options()->set_xla_allow_excess_precision( - false); - - for (auto type : {F8E4M3B11FNUZ, F8E4M3FN, F8E4M3FNUZ, F8E5M2, F8E5M2FNUZ, - F16, BF16, F32, F64}) { - xla::XlaOp all_f8_as_f8 = - ConstantR1(&builder, all_f8); - xla::XlaOp all_f8_as_type = ConvertElementType(all_f8_as_f8, type); - ConvertElementType(all_f8_as_type, F8E5M2FNUZ); - ComputeAndCompare(&builder, {}, ErrorSpec(0.)); - } - - execution_options_.mutable_debug_options()->set_xla_allow_excess_precision( - saved); + xla::XlaOp all_f8_as_fp = + ConvertElementType(ConstantR1(&builder, all_f8), + primitive_util::NativeToPrimitiveType()); + ConvertElementType(all_f8_as_fp, F8E5M2FNUZ); + this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); } XLA_TYPED_TEST(ConvertTestT, ConvertF8e5m2fnuzRoundtripExhaustive2) { @@ -1274,62 +1373,43 @@ XLA_TYPED_TEST(ConvertTestT, ConvertF8e5m2fnuzRoundtripExhaustive2) { static_cast(i)))); } - xla::XlaOp all_f8_as_f32 = ConstantR1(&builder, all_f8); - ConvertElementType(all_f8_as_f32, F8E5M2FNUZ); + ConvertElementType(ConstantR1(&builder, all_f8), F8E5M2FNUZ); this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); } -XLA_TEST_F(ConvertTest, ConvertF8e5m2fnuzRoundtripExhaustive3) { - // Convert from FP8 to supported floating point types. - XlaBuilder builder(TestName()); - - std::vector all_f8; - for (int i = 0; i < 256; i++) { - all_f8.push_back( - Eigen::numext::bit_cast(static_cast(i))); - } - - for (auto type : {F8E4M3FN, F8E4M3B11FNUZ, F8E4M3FNUZ, F8E5M2, F8E5M2FNUZ, - F16, BF16, F32, F64}) { - xla::XlaOp all_f8_as_f8 = - ConstantR1(&builder, all_f8); - ConvertElementType(all_f8_as_f8, type); - ComputeAndCompare(&builder, {}, ErrorSpec(0.)); - } -} - -XLA_TEST_F(ConvertTest, ConvertF8e5m2fnuzF16RoundtripExhaustive4) { - // Convert from F16 to FP8. +XLA_TYPED_TEST(ConvertTestT, ConvertF8e5m2fnuzRoundtripExhaustive3) { + // Convert from FP8 to supported floating point type. XlaBuilder builder(this->TestName()); - std::vector inputs; - for (int i = 0; i < 65536; i++) { - inputs.push_back( - Eigen::numext::bit_cast(static_cast(i))); + using From = tsl::float8_e5m2fnuz; + std::vector all_f8; + for (int i = 0; i < 256; i++) { + all_f8.push_back(Eigen::numext::bit_cast(static_cast(i))); } - xla::XlaOp all_f16_to_f8 = ConstantR1(&builder, inputs); - ConvertElementType(all_f16_to_f8, F8E5M2FNUZ); + ConvertElementType(ConstantR1(&builder, all_f8), + primitive_util::NativeToPrimitiveType()); this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); } -XLA_TEST_F(ConvertTest, ConvertF8e5m2fnuzBF16RoundtripExhaustive5) { - // Convert from BF16 to FP8. +XLA_TYPED_TEST(ConvertTestF16, ConvertF8e5m2fnuzF16RoundtripExhaustive4) { + // Convert from (B)F16 to FP8. XlaBuilder builder(this->TestName()); - std::vector inputs; + std::vector all_f16; for (int i = 0; i < 65536; i++) { - inputs.push_back( - Eigen::numext::bit_cast(static_cast(i))); + all_f16.push_back( + Eigen::numext::bit_cast(static_cast(i))); } - xla::XlaOp all_bf16_to_f8 = ConstantR1(&builder, inputs); - ConvertElementType(all_bf16_to_f8, F8E5M2FNUZ); + ConvertElementType(ConstantR1(&builder, all_f16), F8E5M2FNUZ); this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); } +// ----- F8E4M3FNUZ + XLA_TEST_F(ConvertTest, ConvertF16F8e4m3fnuzRoundtrip) { - // Convert from FP16 to FP8, then back to FP16 + // Convert from FP16 to FP8, then back to FP16. XlaBuilder builder(TestName()); float nan = std::numeric_limits::quiet_NaN(); float inf = std::numeric_limits::infinity(); @@ -1340,10 +1420,10 @@ XLA_TEST_F(ConvertTest, ConvertF16F8e4m3fnuzRoundtrip) { } test_cases[] = { // clang-format off {0.0, 0.0}, - {-0.0, 0.0}, // No signed zero in F8E4M3FNUZ + {-0.0, 0.0}, {1.0, 1.0}, {-1.0, -1.0}, - {inf, nan}, // No Inf in F8E4M3FNUZ + {inf, nan}, // clang-format on {0x1.1p0, 0x1p0}, // Round-to-even down {0x1.3p0, 0x1.4p0}, // Round-to-even up @@ -1358,8 +1438,8 @@ XLA_TEST_F(ConvertTest, ConvertF16F8e4m3fnuzRoundtrip) { {0x1.0p-9, 0x1.0p-9}, // Denormal without rounding {0x1.4p-9, 0x1.0p-9}, // Round-to-even down {0x1.Cp-9, 0x1.0p-8}, // Round-to-even up - {0x1.5p-8, 0x1.4p-8}, // Round-to-nearest down - {0x1.3p-8, 0x1.4p-8}, // Round-to-nearest up + {0x1.3p-9, 0x1.0p-9}, // Round-to-nearest down + {0x1.5p-9, 0x1.8p-9}, // Round-to-nearest up {0x1p-11, 0}, // Largest number that underflows {0x1.004p-11, 0x1p-10}, // Smallest number that doesn't underflow {0x1.DFCp-8, 0x1.Cp-8}, // Largest number that rounds to denormal @@ -1375,18 +1455,12 @@ XLA_TEST_F(ConvertTest, ConvertF16F8e4m3fnuzRoundtrip) { auto f8 = ConvertElementType(ConstantR1(&builder, inputs), F8E4M3FNUZ); ConvertElementType(f8, F16); - const bool saved = - execution_options_.debug_options().xla_allow_excess_precision(); - execution_options_.mutable_debug_options()->set_xla_allow_excess_precision( - false); ComputeAndCompareR1(&builder, expected_roundtrip, {}, ErrorSpec(0.)); - execution_options_.mutable_debug_options()->set_xla_allow_excess_precision( - saved); } XLA_TEST_F(ConvertTest, ConvertF32F8e4m3fnuzRoundtrip) { - // Convert from FP16 to FP8, then back to FP16 + // Convert from FP32 to FP8, then back to FP32. XlaBuilder builder(TestName()); float nan = std::numeric_limits::quiet_NaN(); float inf = std::numeric_limits::infinity(); @@ -1397,10 +1471,10 @@ XLA_TEST_F(ConvertTest, ConvertF32F8e4m3fnuzRoundtrip) { } test_cases[] = { // clang-format off {0.0, 0.0}, - {-0.0, 0.0}, // No signed zero in F8E4M3FNUZ + {-0.0, 0.0}, {1.0, 1.0}, {-1.0, -1.0}, - {inf, nan}, // No Inf in F8E4M3FNUZ + {inf, nan}, // clang-format on {0x1.1p0, 0x1p0}, // Round-to-even down {0x1.3p0, 0x1.4p0}, // Round-to-even up @@ -1415,12 +1489,11 @@ XLA_TEST_F(ConvertTest, ConvertF32F8e4m3fnuzRoundtrip) { {0x1.0p-9, 0x1.0p-9}, // Denormal without rounding {0x1.4p-9, 0x1.0p-9}, // Round-to-even down {0x1.Cp-9, 0x1.0p-8}, // Round-to-even up - {0x1.5p-8, 0x1.4p-8}, // Round-to-nearest down - {0x1.3p-8, 0x1.4p-8}, // Round-to-nearest up + {0x1.3p-9, 0x1.0p-9}, // Round-to-nearest down + {0x1.5p-9, 0x1.8p-9}, // Round-to-nearest up {0x1p-11, 0}, // Largest number that underflows {0x1.000002p-11, 0x1p-10}, // Smallest number that doesn't underflow {0x1.DFFFFEp-8, 0x1.Cp-8}, // Largest number that rounds to denormal - {0x1.FFFFFEp-50, 0}, // A very small input that should underflow }; std::vector inputs; @@ -1432,45 +1505,28 @@ XLA_TEST_F(ConvertTest, ConvertF32F8e4m3fnuzRoundtrip) { auto f8 = ConvertElementType(ConstantR1(&builder, inputs), F8E4M3FNUZ); ConvertElementType(f8, F32); - const bool saved = - execution_options_.debug_options().xla_allow_excess_precision(); - execution_options_.mutable_debug_options()->set_xla_allow_excess_precision( - false); ComputeAndCompareR1(&builder, expected_roundtrip, {}, ErrorSpec(0.)); - execution_options_.mutable_debug_options()->set_xla_allow_excess_precision( - saved); } -XLA_TEST_F(ConvertTest, ConvertF8e4m3fnuzRoundtripExhaustive) { - // Convert from FP8 to each supported floating type, then back to FP8. - XlaBuilder builder(TestName()); +XLA_TYPED_TEST(ConvertTestT, ConvertF8e4m3fnuzRoundtripExhaustive) { + // Convert from FP8 to supported floating point type, then back to FP8. + XlaBuilder builder(this->TestName()); - std::vector all_f8; + using From = tsl::float8_e4m3fnuz; + std::vector all_f8; for (int i = 0; i < 256; i++) { - all_f8.push_back( - Eigen::numext::bit_cast(static_cast(i))); - } - - const bool saved = - execution_options_.debug_options().xla_allow_excess_precision(); - execution_options_.mutable_debug_options()->set_xla_allow_excess_precision( - false); - - for (auto type : {F8E4M3FN, F8E4M3B11FNUZ, F8E4M3FNUZ, F8E5M2, F8E5M2FNUZ, - F16, BF16, F32, F64}) { - xla::XlaOp all_f8_as_f8 = - ConstantR1(&builder, all_f8); - xla::XlaOp all_f8_as_type = ConvertElementType(all_f8_as_f8, type); - ConvertElementType(all_f8_as_type, F8E4M3FNUZ); - ComputeAndCompare(&builder, {}, ErrorSpec(0.)); + all_f8.push_back(Eigen::numext::bit_cast(static_cast(i))); } - execution_options_.mutable_debug_options()->set_xla_allow_excess_precision( - saved); + xla::XlaOp all_f8_as_fp = + ConvertElementType(ConstantR1(&builder, all_f8), + primitive_util::NativeToPrimitiveType()); + ConvertElementType(all_f8_as_fp, F8E4M3FNUZ); + this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); } XLA_TYPED_TEST(ConvertTestT, ConvertF8e4m3fnuzRoundtripExhaustive2) { - // Convert from support floating types to FP8. + // Convert from supported floating point type to FP8. XlaBuilder builder(this->TestName()); std::vector all_f8; @@ -1480,99 +1536,38 @@ XLA_TYPED_TEST(ConvertTestT, ConvertF8e4m3fnuzRoundtripExhaustive2) { static_cast(i)))); } - xla::XlaOp all_f8_as_f32 = ConstantR1(&builder, all_f8); - ConvertElementType(all_f8_as_f32, F8E4M3FNUZ); + ConvertElementType(ConstantR1(&builder, all_f8), F8E4M3FNUZ); this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); } -XLA_TEST_F(ConvertTest, ConvertF8e4m3fnuzRoundtripExhaustive3) { - // Convert from FP8 to supported floating point types. - XlaBuilder builder(TestName()); - - std::vector all_f8; - for (int i = 0; i < 256; i++) { - all_f8.push_back( - Eigen::numext::bit_cast(static_cast(i))); - } - - for (auto type : {F8E4M3FN, F8E4M3B11FNUZ, F8E4M3FNUZ, F8E5M2, F8E5M2FNUZ, - F16, BF16, F32, F64}) { - xla::XlaOp all_f8_as_f8 = - ConstantR1(&builder, all_f8); - ConvertElementType(all_f8_as_f8, type); - ComputeAndCompare(&builder, {}, ErrorSpec(0.)); - } -} - -XLA_TEST_F(ConvertTest, ConvertF8e4m3fnuzF16RoundtripExhaustive4) { - // Convert from F16 to FP8. +XLA_TYPED_TEST(ConvertTestT, ConvertF8e4m3fnuzRoundtripExhaustive3) { + // Convert from FP8 to supported floating point type. XlaBuilder builder(this->TestName()); - std::vector inputs; - for (int i = 0; i < 65536; i++) { - inputs.push_back( - Eigen::numext::bit_cast(static_cast(i))); + using From = tsl::float8_e4m3fnuz; + std::vector all_f8; + for (int i = 0; i < 256; i++) { + all_f8.push_back(Eigen::numext::bit_cast(static_cast(i))); } - xla::XlaOp all_f16_to_f8 = ConstantR1(&builder, inputs); - ConvertElementType(all_f16_to_f8, F8E4M3FNUZ); + ConvertElementType(ConstantR1(&builder, all_f8), + primitive_util::NativeToPrimitiveType()); this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); } -XLA_TEST_F(ConvertTest, ConvertF8e4m3fnuzBF16RoundtripExhaustive5) { - // Convert from BF16 to FP8. +XLA_TYPED_TEST(ConvertTestF16, ConvertF8e4m3fnuzF16RoundtripExhaustive4) { + // Convert from (B)F16 to FP8. XlaBuilder builder(this->TestName()); - std::vector inputs; + std::vector all_f16; for (int i = 0; i < 65536; i++) { - inputs.push_back( - Eigen::numext::bit_cast(static_cast(i))); + all_f16.push_back( + Eigen::numext::bit_cast(static_cast(i))); } - xla::XlaOp all_bf16_to_f8 = ConstantR1(&builder, inputs); - ConvertElementType(all_bf16_to_f8, F8E4M3FNUZ); + ConvertElementType(ConstantR1(&builder, all_f16), F8E4M3FNUZ); this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); } -XLA_TEST_F(ConvertTest, ConvertF8e5m2ToPred) { - XlaBuilder builder(TestName()); - using F8 = tsl::float8_e5m2; - auto a = ConstantR1(&builder, {F8{0.0}, F8{0.25}, F8{2.0}}); - ConvertElementType(a, PRED); - - std::array expected = {false, true, true}; - ComputeAndCompareR1(&builder, expected, {}); -} - -XLA_TEST_F(ConvertTest, ConvertF8e4m3fnToPred) { - XlaBuilder builder(TestName()); - using F8 = tsl::float8_e4m3fn; - auto a = ConstantR1(&builder, {F8{0.0}, F8{0.25}, F8{2.0}}); - ConvertElementType(a, PRED); - - std::array expected = {false, true, true}; - ComputeAndCompareR1(&builder, expected, {}); -} - -XLA_TEST_F(ConvertTest, ConvertF8e5m2fnuzToPred) { - XlaBuilder builder(TestName()); - using F8 = tsl::float8_e5m2fnuz; - auto a = ConstantR1(&builder, {F8{0.0}, F8{0.25}, F8{2.0}}); - ConvertElementType(a, PRED); - - std::array expected = {false, true, true}; - ComputeAndCompareR1(&builder, expected, {}); -} - -XLA_TEST_F(ConvertTest, ConvertF8e4m3fnuzToPred) { - XlaBuilder builder(TestName()); - using F8 = tsl::float8_e4m3fnuz; - auto a = ConstantR1(&builder, {F8{0.0}, F8{0.25}, F8{2.0}}); - ConvertElementType(a, PRED); - - std::array expected = {false, true, true}; - ComputeAndCompareR1(&builder, expected, {}); -} - } // namespace } // namespace xla From 05e1e449ce0324ec1e7e6797a63bc91048458c99 Mon Sep 17 00:00:00 2001 From: Mohammed Anany Date: Mon, 23 Sep 2024 06:16:24 -0700 Subject: [PATCH 138/483] [Triton] Modify back some tests that were breaking when block_k was set to 16. These no longer fail due to modifications in RemoveLayoutConversions pass that canonicalize arith.sitofp which used to have dot_operand layout when lowering to LLVM to now be in Blocked layout. This essentially means these tests no longer trigger the failures we had due to small tilings that Triton couldn't handle. PiperOrigin-RevId: 677764125 --- .../service/gpu/autotuning/gemm_fusion_autotuner_test.cc | 9 +++------ .../triton/triton_fusion_emitter_parametrized_test.cc | 4 +--- 2 files changed, 4 insertions(+), 9 deletions(-) diff --git a/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner_test.cc b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner_test.cc index f47003ecea4256..d9bec3a09906a8 100644 --- a/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner_test.cc +++ b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner_test.cc @@ -451,7 +451,6 @@ ENTRY e { EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); } -// Modify block_k back to 16 once b/337839570 is fixed. // TODO(b/344770374): Make this test not fragile. TEST_F(GemmFusionAutotunerTest, DoNotRunAutotuningKernelSpillingRegisters) { const std::string kHloText = R"( @@ -470,7 +469,7 @@ ENTRY %e { %get-tuple-element.7020 = s8[12288,1536]{1,0} parameter(0) %convert = s8[4,12288]{1,0} parameter(1) ROOT %triton = s8[4,1536]{1,0} fusion(s8[12288,1536]{1,0} %get-tuple-element.7020, s8[4,12288]{1,0} %convert), kind=kCustom, calls=%triton_gemm_dot, - backend_config={"fusion_backend_config":{"kind":"__triton_gemm","triton_gemm_config":{"block_m":"256","block_n":"256","block_k":"32","split_k":"1","num_stages":"1","num_warps":"16","num_ctas":"1"}}} + backend_config={"fusion_backend_config":{"kind":"__triton_gemm","triton_gemm_config":{"block_m":"256","block_n":"256","block_k":"16","split_k":"1","num_stages":"1","num_warps":"16","num_ctas":"1"}}} })"; auto module = ParseAndReturnVerifiedModule(kHloText).value(); @@ -494,7 +493,6 @@ ENTRY %e { ::testing::HasSubstr("Insufficient registers")))); } -// Modify block_k back to 16 once b/337839570 is fixed. // TODO(b/344770374): Make this test not fragile. TEST_F(GemmFusionAutotunerTest, DoNotFilterOutAutotuningKernelSpillingRegisters) { @@ -517,7 +515,7 @@ ENTRY %e { %get-tuple-element.7020 = s8[12288,1536]{1,0} parameter(0) %convert = s8[4,12288]{1,0} parameter(1) ROOT %triton = s8[4,1536]{1,0} fusion(s8[12288,1536]{1,0} %get-tuple-element.7020, s8[4,12288]{1,0} %convert), kind=kCustom, calls=%triton_gemm_dot, - backend_config={"fusion_backend_config":{"kind":"__triton_gemm","triton_gemm_config":{"block_m":"256","block_n":"256","block_k":"32","split_k":"1","num_stages":"1","num_warps":"16","num_ctas":"1"}}} + backend_config={"fusion_backend_config":{"kind":"__triton_gemm","triton_gemm_config":{"block_m":"256","block_n":"256","block_k":"16","split_k":"1","num_stages":"1","num_warps":"16","num_ctas":"1"}}} })"; auto module = ParseAndReturnVerifiedModule(kHloText).value(); @@ -540,7 +538,6 @@ ENTRY %e { EXPECT_NE(executable, nullptr); } -// Modify block_k back to 16 once b/337839570 is fixed. TEST_F(GemmFusionAutotunerTest, RunAutotuningKernelNotSpillingRegisters) { const std::string kHloText = R"( HloModule m @@ -556,7 +553,7 @@ ENTRY %e { %p0 = s8[12288,1536]{1,0} parameter(0) %p1 = f16[4,12288]{1,0} parameter(1) ROOT %triton_dot = f16[4,1536]{1,0} fusion(s8[12288,1536]{1,0} %p0, f16[4,12288]{1,0} %p1), kind=kCustom, calls=%triton_gemm_dot, - backend_config={"fusion_backend_config":{"kind":"__triton_gemm","triton_gemm_config":{"block_m":"16","block_n":"32","block_k":"32","split_k":"1","num_stages":"1","num_warps":"2","num_ctas":"1"}}} + backend_config={"fusion_backend_config":{"kind":"__triton_gemm","triton_gemm_config":{"block_m":"16","block_n":"32","block_k":"16","split_k":"1","num_stages":"1","num_warps":"2","num_ctas":"1"}}} })"; auto module = ParseAndReturnVerifiedModule(kHloText).value(); diff --git a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_parametrized_test.cc b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_parametrized_test.cc index 11b9ae776a3157..96a3e1dace9bcf 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_parametrized_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_parametrized_test.cc @@ -140,9 +140,7 @@ INSTANTIATE_TEST_SUITE_P(RewriteTestSuite, MixedTypeTest, // TritonRewriteTest2Params{F32, F16}, // TritonRewriteTest2Params{F32, BF16}, MixTypeParams{S8, BF16, 24, 40, 8}, - // Modify the case below to use k = 32 instead of - // 16 once b/337839570 is fixed. - MixTypeParams{S8, F16, 80, 32, 32, 1e-3, 1e-6}, + MixTypeParams{S8, F16, 80, 16, 32, 1e-3, 1e-6}, MixTypeParams{F16, F32, 127, 3, 300, 1e-2, 1e-2}, MixTypeParams{F16, BF16, 544, 96, 16, 1e-3, 1e-3}, MixTypeParams{BF16, F32, 77, 500, 333, 3e-3, 3e-3}, From 23e3fe1874d883e35a82cfb7be79b1cacab63c81 Mon Sep 17 00:00:00 2001 From: Georg Stefan Schmid Date: Mon, 23 Sep 2024 06:53:37 -0700 Subject: [PATCH 139/483] PR #17359: [ffi] Support prepare stage in custom call thunk Imported from GitHub PR https://github.com/openxla/xla/pull/17359 Augments custom calls with support for `ffi::ExecutionStage::kPrepare` handlers of XLA's FFI. cc @ezhulenev Copybara import of the project: -- 9ada84b3067b36de135c0cb4918a4640f61d9d4a by Georg Stefan Schmid : [ffi] Support prepare stage in custom call thunk Merging this change closes #17359 PiperOrigin-RevId: 677774591 --- .../xla/xla/service/gpu/custom_call_test.cc | 18 +++++-- .../service/gpu/runtime/custom_call_thunk.cc | 47 ++++++++++++------- .../service/gpu/runtime/custom_call_thunk.h | 3 +- 3 files changed, 47 insertions(+), 21 deletions(-) diff --git a/third_party/xla/xla/service/gpu/custom_call_test.cc b/third_party/xla/xla/service/gpu/custom_call_test.cc index 2de44347f8ce35..f377c0e0ae3dd8 100644 --- a/third_party/xla/xla/service/gpu/custom_call_test.cc +++ b/third_party/xla/xla/service/gpu/custom_call_test.cc @@ -715,6 +715,7 @@ TEST_F(CustomCallTest, WithCalledComputation) { struct SomeExtraContext { explicit SomeExtraContext(int32_t value) : value(value) {} int32_t value; + bool prepared = false; bool initialized = false; bool executed = false; }; @@ -723,15 +724,25 @@ template static absl::Status ExecutionContext(ffi::Result, SomeExtraContext* ctx) { if (ctx->value != 42) return absl::InternalError("Unexpected value"); - if constexpr (stage == ffi::ExecutionStage::kInitialize) { + if constexpr (stage == ffi::ExecutionStage::kPrepare) { + ctx->prepared = true; + } else if constexpr (stage == ffi::ExecutionStage::kInitialize) { ctx->initialized = true; - } else { + } else if constexpr (stage == ffi::ExecutionStage::kExecute) { ctx->executed = true; + } else { + return absl::InternalError("Unexpected stage"); } return absl::OkStatus(); } +XLA_FFI_DEFINE_HANDLER(kExecutionContextPrepare, + ExecutionContext, + ffi::Ffi::Bind() + .Ret() + .Ctx>()); + XLA_FFI_DEFINE_HANDLER(kExecutionContextInitialize, ExecutionContext, ffi::Ffi::Bind() @@ -748,7 +759,7 @@ XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "xla.gpu.ffi_execution_context", PLATFORM, { /*instantiate=*/nullptr, - /*prepare=*/nullptr, + /*prepare=*/kExecutionContextPrepare, /*initialize=*/kExecutionContextInitialize, /*execute=*/kExecutionContextExecute, }); @@ -774,6 +785,7 @@ TEST_F(CustomCallTest, FfiExecutionContext) { // Check that FFI handler was called during initialization and execution. TF_ASSERT_OK_AND_ASSIGN(auto* user_context, execution_context.Lookup()); + EXPECT_TRUE(user_context->prepared); EXPECT_TRUE(user_context->initialized); EXPECT_TRUE(user_context->executed); } diff --git a/third_party/xla/xla/service/gpu/runtime/custom_call_thunk.cc b/third_party/xla/xla/service/gpu/runtime/custom_call_thunk.cc index 7cf44c109cb267..436ba7416515d3 100644 --- a/third_party/xla/xla/service/gpu/runtime/custom_call_thunk.cc +++ b/third_party/xla/xla/service/gpu/runtime/custom_call_thunk.cc @@ -156,19 +156,27 @@ absl::Status CustomCallThunk::ExecuteCustomCall(const ExecuteParams& params) { } absl::Status CustomCallThunk::ExecuteFfiHandler( - XLA_FFI_Handler* handler, XLA_FFI_ExecutionStage stage, - int32_t device_ordinal, se::Stream* stream, - se::DeviceMemoryAllocator* allocator, + XLA_FFI_Handler* handler, XLA_FFI_ExecutionStage stage, se::Stream* stream, const ffi::ExecutionContext* execution_context, const BufferAllocations* buffer_allocations) { if (handler == nullptr) { return absl::InternalError("FFI execute handler is not set"); } + if (stage != XLA_FFI_ExecutionStage_PREPARE && + !(buffer_allocations && stream)) { + return absl::InternalError("buffer allocations and stream are required"); + } // TODO(ezhulenev): This is not the most optimal approach, as we'll be doing // a lot of extra allocation on every call. We have to keep attributes // separate from arguments, as they do not change after thunk is constructed. CallFrameBuilder builder(operands_.size(), results_.size()); + auto device_address = + [buffer_allocations]( + BufferAllocation::Slice slice) -> se::DeviceMemoryBase { + return buffer_allocations ? buffer_allocations->GetDeviceAddress(slice) + : se::DeviceMemoryBase{}; + }; for (auto& operand : operands_) { if (!operand.has_value()) { @@ -179,7 +187,7 @@ absl::Status CustomCallThunk::ExecuteFfiHandler( if (!operand->slice.allocation()) return Internal("custom call argument missing buffer allocation"); - builder.AddBufferArg(buffer_allocations->GetDeviceAddress(operand->slice), + builder.AddBufferArg(device_address(operand->slice), operand->shape.element_type(), operand->shape.dimensions()); } @@ -193,7 +201,7 @@ absl::Status CustomCallThunk::ExecuteFfiHandler( if (!result->slice.allocation()) return Internal("custom call result missing buffer allocation"); - builder.AddBufferRet(buffer_allocations->GetDeviceAddress(result->slice), + builder.AddBufferRet(device_address(result->slice), result->shape.element_type(), result->shape.dimensions()); } @@ -204,6 +212,13 @@ absl::Status CustomCallThunk::ExecuteFfiHandler( builder.AddAttributes(attrs.Build()); CallFrame call_frame = builder.Build(); + int32_t device_ordinal = -1; + se::DeviceMemoryAllocator* allocator = nullptr; + if (stage != XLA_FFI_ExecutionStage_PREPARE) { + device_ordinal = buffer_allocations->device_ordinal(); + allocator = buffer_allocations->memory_allocator(); + } + CallOptions options = { device_ordinal, CallOptions::GpuOptions{stream, allocator}, called_computation_, execution_context, execution_state_.get()}; @@ -212,10 +227,14 @@ absl::Status CustomCallThunk::ExecuteFfiHandler( absl::Status CustomCallThunk::Prepare(const PrepareParams& params, ResourceRequests& resource_requests) { - if (bundle_ && bundle_->prepare) { - return absl::InternalError("FFI prepare stage is not yet supported"); + if (!bundle_ || !bundle_->prepare) { + return absl::OkStatus(); } - return absl::OkStatus(); + + return ExecuteFfiHandler(bundle_->prepare, XLA_FFI_ExecutionStage_PREPARE, + /*stream=*/nullptr, + /*execution_context=*/nullptr, + /*buffer_allocations=*/nullptr); } absl::Status CustomCallThunk::Initialize(const InitializeParams& params) { @@ -224,19 +243,15 @@ absl::Status CustomCallThunk::Initialize(const InitializeParams& params) { } return ExecuteFfiHandler( - bundle_->initialize, XLA_FFI_ExecutionStage_INITIALIZE, - params.buffer_allocations->device_ordinal(), params.stream, - params.buffer_allocations->memory_allocator(), + bundle_->initialize, XLA_FFI_ExecutionStage_INITIALIZE, params.stream, params.ffi_execution_context, params.buffer_allocations); } absl::Status CustomCallThunk::ExecuteOnStream(const ExecuteParams& params) { if (bundle_.has_value()) { - return ExecuteFfiHandler( - bundle_->execute, XLA_FFI_ExecutionStage_EXECUTE, - params.buffer_allocations->device_ordinal(), params.stream, - params.buffer_allocations->memory_allocator(), - params.ffi_execution_context, params.buffer_allocations); + return ExecuteFfiHandler(bundle_->execute, XLA_FFI_ExecutionStage_EXECUTE, + params.stream, params.ffi_execution_context, + params.buffer_allocations); } return ExecuteCustomCall(params); } diff --git a/third_party/xla/xla/service/gpu/runtime/custom_call_thunk.h b/third_party/xla/xla/service/gpu/runtime/custom_call_thunk.h index c65676381f9c8a..e67b9e89d3a867 100644 --- a/third_party/xla/xla/service/gpu/runtime/custom_call_thunk.h +++ b/third_party/xla/xla/service/gpu/runtime/custom_call_thunk.h @@ -120,8 +120,7 @@ class CustomCallThunk : public Thunk { absl::Status ExecuteFfiHandler(XLA_FFI_Handler* handler, XLA_FFI_ExecutionStage stage, - int32_t device_ordinal, se::Stream* stream, - se::DeviceMemoryAllocator* allocator, + se::Stream* stream, const ffi::ExecutionContext* execution_context, const BufferAllocations* buffer_allocations); From 94dd5cd0eb3f08d2ce379d275b9c41064b407271 Mon Sep 17 00:00:00 2001 From: Paul Chang Date: Mon, 23 Sep 2024 07:57:15 -0700 Subject: [PATCH 140/483] Minor clarifications for Gather, fix cut-and-paste error in Scatter docs PiperOrigin-RevId: 677794138 --- third_party/xla/docs/operation_semantics.md | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/third_party/xla/docs/operation_semantics.md b/third_party/xla/docs/operation_semantics.md index 55849974726628..92943521cd0343 100644 --- a/third_party/xla/docs/operation_semantics.md +++ b/third_party/xla/docs/operation_semantics.md @@ -1500,8 +1500,9 @@ if, e.g., `offset_dims.size` is `4`, `operand.rank` is `6` and `1`→`3`, `2`→`4`, `3`→`5`}. If `indices_are_sorted` is set to true then XLA can assume that `start_indices` -are sorted (in ascending `start_index_map` order) by the user. If they are not -then the semantics is implementation defined. +are sorted (in ascending order, _after_ scattering its values according to +`start_index_map`) by the user. If they are not then the semantics are +implementation defined. ### Informal Description and Examples @@ -2493,9 +2494,10 @@ always be the current value from the `output` array and the second parameter will always be the value from the `updates` array. This is important specifically for cases when the `update_computation` is _not commutative_. -If `indices_are_sorted` is set to true then XLA can assume that `start_indices` -are sorted (in ascending `start_index_map` order) by the user. If they are not -then the semantics is implementation defined. +If `indices_are_sorted` is set to true then XLA can assume that `scatter_indices` +are sorted (in ascending order, _after_ scattering its values according to +`scatter_dims_to_operand_dims`) by the user. If they are not then the semantics +are implementation defined. If `unique_indices` is set to true then XLA can assume that all elements scattered to are unique. So XLA could use non-atomic operations. If From 9c7053b1b54c50f7ed0a6485d833dafebdd645de Mon Sep 17 00:00:00 2001 From: Bart Chrzaszcz Date: Mon, 23 Sep 2024 08:31:01 -0700 Subject: [PATCH 141/483] #sdy change `ManualComputationOp` SDY round tripping to use a `CallOp` with `CustomCalls` to change the shapes local<->global WRT to the mesh. PiperOrigin-RevId: 677805310 --- .../xla/xla/service/spmd/shardy/constants.h | 6 - .../shardy/sdy_round_trip/shard_map_export.cc | 58 ++++++--- .../shardy/sdy_round_trip/shard_map_export.h | 6 +- .../shardy/sdy_round_trip/shard_map_import.cc | 90 ++++++++++---- .../shardy/sdy_round_trip/shard_map_import.h | 7 +- .../test/sdy_round_trip_shard_map_export.mlir | 115 +++++++++++++----- .../test/sdy_round_trip_shard_map_import.mlir | 95 ++++++++++++--- ...y_round_trip_shard_map_import_failure.mlir | 16 ++- 8 files changed, 280 insertions(+), 113 deletions(-) diff --git a/third_party/xla/xla/service/spmd/shardy/constants.h b/third_party/xla/xla/service/spmd/shardy/constants.h index 6d92bbc1660ed2..020c6a6c893fb1 100644 --- a/third_party/xla/xla/service/spmd/shardy/constants.h +++ b/third_party/xla/xla/service/spmd/shardy/constants.h @@ -81,12 +81,6 @@ inline constexpr llvm::StringRef kOutShardings = "xla.sdy.out_shardings"; // Attribute name for the manual axes of a `ManualComputationOp`. inline constexpr llvm::StringRef kManualAxes = "xla.sdy.manual_axes"; -// The target name of the custom call that will store the various attrs of a -// `ManualComputationOp` and a reference to a `FuncOp` that is the body of the -// original `ManualComputationOp`. -inline constexpr llvm::StringRef kManualComputationCustomCallTargetName = - "xla.sdy.ManualComputation"; - // The function name of the of the body of a `ManualComputationOp` during Shardy // round tripping. Used inline constexpr llvm::StringRef kManualComputationBodyFuncName = diff --git a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/shard_map_export.cc b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/shard_map_export.cc index 08363df42d500a..35c88549e867ea 100644 --- a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/shard_map_export.cc +++ b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/shard_map_export.cc @@ -31,7 +31,9 @@ limitations under the License. #include "mlir/IR/Operation.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/SymbolTable.h" +#include "mlir/IR/TypeRange.h" #include "mlir/IR/Value.h" +#include "mlir/IR/ValueRange.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" #include "mlir/Pass/PassRegistry.h" @@ -52,6 +54,7 @@ namespace { using ::mlir::MLIRContext; using ::mlir::ModuleOp; using ::mlir::StringRef; +using ::mlir::func::CallOp; using ::mlir::func::FuncOp; namespace stablehlo = ::mlir::stablehlo; @@ -72,29 +75,46 @@ class SdyRoundTripShardMapExportPass auto rewriter = mlir::IRRewriter(context); moduleOp->walk([&](sdy::ManualComputationOp manualComputation) { rewriter.setInsertionPointToEnd(&moduleOp.getRegion().front()); + mlir::Location loc = manualComputation.getLoc(); + mlir::Region& manualCompBody = manualComputation.getBody(); + mlir::TypeRange manualCompBodyArgTypes = + manualCompBody.getArgumentTypes(); + mlir::TypeRange localResultTypes = + sdy::getBodyTerminatorOpOperandTypes(manualComputation); auto funcOp = rewriter.create( - manualComputation.getLoc(), kManualComputationBodyFuncName, - rewriter.getFunctionType( - manualComputation.getBody().getArgumentTypes(), - sdy::getBodyTerminatorOpOperandTypes(manualComputation))); - sdy::inlineRegionAndConvertTerminatorOp( - manualComputation.getBody(), funcOp.getBody()); + loc, kManualComputationBodyFuncName, + rewriter.getFunctionType(manualCompBodyArgTypes, localResultTypes)); mlir::StringAttr funcName = symbolTable.insert(funcOp); rewriter.setInsertionPoint(manualComputation); - auto customCallOp = rewriter.create( - manualComputation.getLoc(), manualComputation.getResultTypes(), - manualComputation->getOperands()); - customCallOp.setCallTargetName(kManualComputationCustomCallTargetName); - customCallOp.setCalledComputationsAttr( - rewriter.getArrayAttr(mlir::FlatSymbolRefAttr::get(funcName))); - addFrontendAttribute(customCallOp, kInShardings, + stablehlo::CustomCallOp fullToShard; + mlir::ValueRange operands = manualComputation->getOperands(); + if (!operands.empty()) { + fullToShard = rewriter.create( + loc, manualCompBodyArgTypes, operands); + fullToShard.setCallTargetName(kSPMDFullToShardShapeCallTargetName); + operands = fullToShard->getResults(); + } + + auto callOp = + rewriter.create(loc, localResultTypes, funcName, operands); + addFrontendAttribute(callOp, kInShardings, manualComputation.getInShardings()); - addFrontendAttribute(customCallOp, kOutShardings, + addFrontendAttribute(callOp, kOutShardings, manualComputation.getOutShardings()); - addFrontendAttribute(customCallOp, kManualAxes, + addFrontendAttribute(callOp, kManualAxes, manualComputation.getManualAxesAttr()); - rewriter.replaceOp(manualComputation, customCallOp->getResults()); + + mlir::ResultRange results = manualComputation->getResults(); + if (!results.empty()) { + auto shardToFull = rewriter.create( + loc, manualComputation.getResultTypes(), callOp->getResults()); + shardToFull.setCallTargetName(kSPMDShardToFullShapeCallTargetName); + results = shardToFull->getResults(); + } + sdy::inlineRegionAndConvertTerminatorOp( + manualCompBody, funcOp.getBody()); + rewriter.replaceOp(manualComputation, results); }); } @@ -104,9 +124,9 @@ class SdyRoundTripShardMapExportPass StringRef getDescription() const override { return "Converts the body of a ManualComputationOp to a separate function " - "with a CustomCallOp of the same name referring to it. The " - "CustomCallOp saves the in/out shardings and manual axes as " - "frontend attrs for HLO round tripping."; + "with a CallOp and a pair of CustomCallOps that change the shape of " + "the arguments/results. The CallOp saves the in/out shardings and " + "manual axes as frontend attrs."; } void getDependentDialects(mlir::DialectRegistry& registry) const final { registry.insert(); diff --git a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/shard_map_export.h b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/shard_map_export.h index 34eef5eeb13cf6..c3a7ed9b7ea3a7 100644 --- a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/shard_map_export.h +++ b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/shard_map_export.h @@ -23,8 +23,10 @@ limitations under the License. namespace xla { namespace sdy { -// Creates the pass that converts `ManualComputationOps` to a separate function -// and `CustomCallOp` for round tripping between HLO. +// Creates the pass that converts `ManualComputationOp`s to a separate function +// with a CallOp and a pair of `CustomCallOp`s that change the shape of the +// arguments/results. The CallOp saves the in/out shardings and manual axes as +// frontend attrs. std::unique_ptr createSdyRoundTripShardMapExportPass(); // Registers the xla-sdy-round-trip-shard-map-export pass. diff --git a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/shard_map_import.cc b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/shard_map_import.cc index 8975fa2142691c..dec5bd2cbc84e1 100644 --- a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/shard_map_import.cc +++ b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/shard_map_import.cc @@ -15,14 +15,15 @@ limitations under the License. #include "xla/service/spmd/shardy/sdy_round_trip/shard_map_import.h" +#include #include #include #include "absl/log/check.h" +#include "absl/strings/match.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/ErrorHandling.h" #include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Diagnostics.h" @@ -31,7 +32,9 @@ limitations under the License. #include "mlir/IR/Operation.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/SymbolTable.h" +#include "mlir/IR/TypeRange.h" #include "mlir/IR/Value.h" +#include "mlir/IR/ValueRange.h" #include "mlir/IR/Visitors.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" @@ -55,47 +58,72 @@ using ::mlir::ModuleOp; using ::mlir::OpConversionPattern; using ::mlir::StringRef; using ::mlir::SymbolTable; +using ::mlir::func::CallOp; using ::mlir::func::FuncOp; -using ::mlir::stablehlo::CustomCallOp; +namespace stablehlo = ::mlir::stablehlo; namespace sdy = ::mlir::sdy; -// Converts `CustomCallOp`s called `@local_xla.sdy.ManualComputation` with in/out -// shardings and manual axes as frontend attrs to `ManualComputationOp`s. -class ManualComputationPattern : public OpConversionPattern { +// Converts a CallOp calling a @local_xla.sdy.manual_computation_body func with in/out +// shardings and manual axes as frontend attrs, wrapped with custom calls that +// change the shape of the arguments/results to a `ManualComputationOp`. See +// `SdyRoundTripShardMapExportPass` for its counterpart. +class ManualComputationPattern : public OpConversionPattern { public: explicit ManualComputationPattern(MLIRContext* context, const SymbolTable& symbolTable) - : OpConversionPattern(context), symbolTable(symbolTable) {} + : OpConversionPattern(context), symbolTable(symbolTable) {} mlir::LogicalResult matchAndRewrite( - CustomCallOp customCallOp, OpAdaptor adaptor, + CallOp callOp, OpAdaptor adaptor, mlir::ConversionPatternRewriter& rewriter) const override { - if (customCallOp.getCallTargetName() != - kManualComputationCustomCallTargetName) { + if (!absl::StartsWith(callOp.getCallee(), kManualComputationBodyFuncName)) { return mlir::failure(); } - CHECK_EQ(customCallOp.getCalledComputations().size(), 1); - auto shmapBodyFunc = - symbolTable.lookup((*customCallOp.getCalledComputations() - .getAsRange() - .begin()) - .getValue()); + // NOTE: if the original `ManualComputationOp` had no operands (results), + // then a @FullToShard (@ShardToFull) custom call won't be present. So + // we have to take the operands/results of the newly created + // `ManualComputationOp` differently depending on whether the original had + // operands/results. + stablehlo::CustomCallOp fullToShard; + mlir::ValueRange operands = callOp->getOperands(); + if (!operands.empty()) { + fullToShard = + callOp->getOperand(0).getDefiningOp(); + operands = fullToShard->getOperands(); + CHECK(fullToShard); + CHECK(fullToShard.getCallTargetName() == + kSPMDFullToShardShapeCallTargetName); + } + mlir::TypeRange resultTypes = callOp->getResultTypes(); + stablehlo::CustomCallOp shardToFull; + if (!resultTypes.empty()) { + CHECK(callOp->getResult(0).hasOneUse()) + << "all CallOp results should be used by a single ShardToFull"; + shardToFull = mlir::cast( + *callOp->getResult(0).getUsers().begin()); + CHECK(shardToFull.getCallTargetName() == + kSPMDShardToFullShapeCallTargetName); + resultTypes = shardToFull->getResultTypes(); + } + + auto shmapBodyFunc = symbolTable.lookup(callOp.getCallee()); if (shmapBodyFunc.empty()) { - return customCallOp->emitOpError( + return callOp->emitOpError( "expected a unique FuncOp per " - "@local_xla.sdy.ManualComputation custom call. Were " + "@local_xla.sdy.manual_computation_body call. Were " "functions maybe somehow shared/de-duped between " "two ManualComputations?"); } - mlir::DictionaryAttr frontendAttrs = getFrontendAttrs(customCallOp); - CHECK(frontendAttrs); + mlir::DictionaryAttr frontendAttrs = getFrontendAttrs(callOp); + CHECK(frontendAttrs) + << "Expected in/out shardings and manual axes as frontend attrs on the " + "CallOp during round tripping."; auto manualComputationOp = rewriter.replaceOpWithNewOp( - customCallOp, customCallOp->getResultTypes(), - customCallOp->getOperands(), + callOp, resultTypes, operands, parseStringAttr(frontendAttrs, kInShardings), parseStringAttr(frontendAttrs, @@ -104,6 +132,12 @@ class ManualComputationPattern : public OpConversionPattern { sdy::inlineRegionAndConvertTerminatorOp( shmapBodyFunc.getBody(), manualComputationOp.getRegion(), rewriter); rewriter.eraseOp(shmapBodyFunc); + if (fullToShard) { + rewriter.eraseOp(fullToShard); + } + if (shardToFull) { + rewriter.replaceOp(shardToFull, manualComputationOp->getResults()); + } return mlir::success(); } @@ -124,10 +158,11 @@ class SdyRoundTripShardMapImportPass SymbolTable& symbolTable = symbolTableCollection.getSymbolTable(module); MLIRContext& context = getContext(); mlir::ConversionTarget target(context); - target.addDynamicallyLegalOp([](CustomCallOp op) { - return op.getCallTargetName() != kManualComputationCustomCallTargetName; + target.addDynamicallyLegalOp([](CallOp op) { + return !absl::StartsWith(op.getCallee(), kManualComputationBodyFuncName); }); - target.addLegalOp(); + target.addLegalOp(); mlir::RewritePatternSet patterns(&context); patterns.add(&context, symbolTable); if (mlir::failed(mlir::applyPartialConversion(module, target, @@ -141,9 +176,10 @@ class SdyRoundTripShardMapImportPass } StringRef getDescription() const override { - return "converts CustomCalls called @local_xla.sdy.manual_computation_body " - "with in/out shardings and manual axes as frontend attrs to a " - "`ManualComputationOp`"; + return "converts a CallOp calling a @local_xla.sdy.manual_computation_body func " + "with in/out shardings and manual axes as frontend attrs, wrapped " + "with a pair of `CustomCallOps` that change the shape of the " + "arguments/results, to a ManualComputationOp"; } void getDependentDialects(mlir::DialectRegistry& registry) const final { registry.insert(); diff --git a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/shard_map_import.h b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/shard_map_import.h index 1520c8baa663f7..e84304a177dce9 100644 --- a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/shard_map_import.h +++ b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/shard_map_import.h @@ -23,9 +23,10 @@ limitations under the License. namespace xla { namespace sdy { -// Creates the pass that converts a `CustomCallOp` called -// `kManualComputationBodyFuncName` with in/out shardings and manual -// axes as frontend attrs to a `ManualComputationOp`. +// Creates the pass that converts a `CallOp` calling +// `@local_xla.sdy.manual_computation_body` with in/out shardings and manual +// axes as frontend attrs, wrapped with a pair of `CustomCallOp`s that change +// the shape of the arguments/results, to a `ManualComputationOp`. std::unique_ptr createSdyRoundTripShardMapImportPass(); // Registers the xla-sdy-round-trip-shard-map-import pass. diff --git a/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_shard_map_export.mlir b/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_shard_map_export.mlir index 2376056f8f0735..aefca62fc88b6d 100644 --- a/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_shard_map_export.mlir +++ b/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_shard_map_export.mlir @@ -5,14 +5,15 @@ sdy.mesh @mesh_1 = <["a"=2, "b"=2, "c"=2, "d"=2]> // CHECK-LABEL: func @single_manual_comp func.func @single_manual_comp(%arg0: tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_0, [{"a", ?}, {"b", ?}]>}, %arg1: tensor<16x32xf32> {sdy.sharding = #sdy.sharding<@mesh_0, [{"b", ?}, {?}]>}) -> (tensor<8x32xf32> {sdy.sharding = #sdy.sharding<@mesh_0, [{"a"}, {}]>}) { - // CHECK: %[[SHMAP:.*]] = stablehlo.custom_call @local_xla.sdy.ManualComputation(%arg0, %arg1) - // CHECK-SAME: {called_computations = [@local_xla.sdy.manual_computation_body], - // CHECK-SAME: mhlo.frontend_attributes = { + // CHECK-NEXT: %[[FULL_TO_SHARD:.*]]:2 = stablehlo.custom_call @SPMDFullToShardShape(%arg0, %arg1) : (tensor<8x16xf32>, tensor<16x32xf32>) -> (tensor<2x8xf32>, tensor<8x32xf32>) + // CHECK-NEXT: %[[SHMAP:.*]] = call @local_xla.sdy.manual_computation_body(%[[FULL_TO_SHARD]]#0, %[[FULL_TO_SHARD]]#1) + // CHECK-SAME: {mhlo.frontend_attributes = { // CHECK-SAME: xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{\\\22a\\\22}, {\\\22b\\\22}]>, <@mesh_0, [{\\\22b\\\22}, {}], replicated={\\\22a\\\22}>]>", // CHECK-SAME: xla.sdy.manual_axes = "#sdy", // CHECK-SAME: xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{\\\22a\\\22}, {}], replicated={\\\22b\\\22}>]>"}} - // CHECK-SAME: : (tensor<8x16xf32>, tensor<16x32xf32>) -> tensor<8x32xf32> - // CHECK-NEXT: return %[[SHMAP]] : tensor<8x32xf32> + // CHECK-SAME: : (tensor<2x8xf32>, tensor<8x32xf32>) -> tensor<2x32xf32> + // CHECK-NEXT: %[[SHARD_TO_FULL:.*]] = stablehlo.custom_call @SPMDShardToFullShape(%[[SHMAP]]) : (tensor<2x32xf32>) -> tensor<8x32xf32> + // CHECK-NEXT: return %[[SHARD_TO_FULL]] : tensor<8x32xf32> %0 = sdy.manual_computation(%arg0, %arg1) in_shardings=[<@mesh_0, [{"a"}, {"b"}]>, <@mesh_0, [{"b"}, {}], replicated={"a"}>] out_shardings=[<@mesh_0, [{"a"}, {}], replicated={"b"}>] manual_axes={"a", "b"} (%arg2: tensor<2x8xf32>, %arg3: tensor<8x32xf32>) { %1 = stablehlo.add %arg2, %arg2 : tensor<2x8xf32> %2 = stablehlo.dot %1, %arg3 : (tensor<2x8xf32>, tensor<8x32xf32>) -> tensor<2x32xf32> @@ -31,20 +32,23 @@ func.func @single_manual_comp(%arg0: tensor<8x16xf32> {sdy.sharding = #sdy.shard // CHECK-LABEL: func @manual_comp_using_another func.func @manual_comp_using_another(%arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh_0, [{"a"}, {}]>}) -> (tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh_0, [{}, {"b"}]>}) { - // CHECK: %[[SHMAP_0:.*]] = stablehlo.custom_call @local_xla.sdy.ManualComputation(%arg0) - // CHECK-SAME: {called_computations = [@local_xla.sdy.manual_computation_body_0], - // CHECK-SAME: mhlo.frontend_attributes = { + // CHECK-NEXT: %[[FULL_TO_SHARD_0:.*]] = stablehlo.custom_call @SPMDFullToShardShape(%arg0) : (tensor<8x8xf32>) -> tensor<2x8xf32> + // CHECK-NEXT: %[[SHMAP_0:.*]] = call @local_xla.sdy.manual_computation_body_0(%[[FULL_TO_SHARD_0]]) + // CHECK-SAME: {mhlo.frontend_attributes = { // CHECK-SAME: xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{\\\22a\\\22}, {}]>]>", // CHECK-SAME: xla.sdy.manual_axes = "#sdy", // CHECK-SAME: xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{\\\22a\\\22}, {}]>]>"}} - // CHECK-SAME: : (tensor<8x8xf32>) -> tensor<8x8xf32> - // CHECK-NEXT: %[[SHMAP_1:.*]] = stablehlo.custom_call @local_xla.sdy.ManualComputation(%[[SHMAP_0]]) - // CHECK-SAME: {called_computations = [@local_xla.sdy.manual_computation_body_1], - // CHECK-SAME: mhlo.frontend_attributes = { + // CHECK-SAME: : (tensor<2x8xf32>) -> tensor<2x8xf32> + // CHECK-NEXT: %[[SHARD_TO_FULL_0:.*]] = stablehlo.custom_call @SPMDShardToFullShape(%[[SHMAP_0]]) : (tensor<2x8xf32>) -> tensor<8x8xf32> + // CHECK-NEXT: %[[FULL_TO_SHARD_1:.*]] = stablehlo.custom_call @SPMDFullToShardShape(%[[SHARD_TO_FULL_0]]) : (tensor<8x8xf32>) -> tensor<8x4xf32> + // CHECK-NEXT: %[[SHMAP_1:.*]] = call @local_xla.sdy.manual_computation_body_1(%[[FULL_TO_SHARD_1]]) + // CHECK-SAME: {mhlo.frontend_attributes = { // CHECK-SAME: xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{}, {\\\22b\\\22}]>]>", // CHECK-SAME: xla.sdy.manual_axes = "#sdy", // CHECK-SAME: xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{}, {\\\22b\\\22}]>]>"}} - // CHECK-SAME: : (tensor<8x8xf32>) -> tensor<8x8xf32> + // CHECK-SAME: : (tensor<8x4xf32>) -> tensor<8x4xf32 + // CHECK-NEXT: %[[SHARD_TO_FULL_1:.*]] = stablehlo.custom_call @SPMDShardToFullShape(%[[SHMAP_1]]) : (tensor<8x4xf32>) -> tensor<8x8xf32> + // CHECK-NEXT: return %[[SHARD_TO_FULL_1]] : tensor<8x8xf32> %0 = sdy.manual_computation(%arg0) in_shardings=[<@mesh_0, [{"a"}, {}]>] out_shardings=[<@mesh_0, [{"a"}, {}]>] manual_axes={"a"} (%arg1: tensor<2x8xf32>) { sdy.return %arg1 : tensor<2x8xf32> } : (tensor<8x8xf32>) -> tensor<8x8xf32> @@ -57,14 +61,15 @@ func.func @manual_comp_using_another(%arg0: tensor<8x8xf32> {sdy.sharding = #sdy // CHECK-LABEL: func @nested_shmaps func.func @nested_shmaps(%arg0: tensor<4x8xf32> {sdy.sharding = #sdy.sharding<@mesh_1, [{"a"}, {"b"}]>}) -> (tensor<4x8xf32> {sdy.sharding = #sdy.sharding<@mesh_1, [{"a", ?}, {?}]>}) { - // CHECK: %[[SHMAP:.*]] = stablehlo.custom_call @local_xla.sdy.ManualComputation(%arg0) - // CHECK-SAME: {called_computations = [@local_xla.sdy.manual_computation_body_3], - // CHECK-SAME: mhlo.frontend_attributes = { + // CHECK-NEXT: %[[FULL_TO_SHARD:.*]] = stablehlo.custom_call @SPMDFullToShardShape(%arg0) : (tensor<4x8xf32>) -> tensor<2x8xf32> + // CHECK-NEXT: %[[SHMAP:.*]] = call @local_xla.sdy.manual_computation_body_3(%[[FULL_TO_SHARD]]) + // CHECK-SAME: {mhlo.frontend_attributes = { // CHECK-SAME: xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{\\\22a\\\22}, {}]>]>", // CHECK-SAME: xla.sdy.manual_axes = "#sdy", // CHECK-SAME: xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{\\\22a\\\22}, {}]>]>"}} - // CHECK-SAME: : (tensor<4x8xf32>) -> tensor<4x8xf32 - // CHECK-NEXT: return %[[SHMAP]] : tensor<4x8xf32> + // CHECK-SAME: : (tensor<2x8xf32>) -> tensor<2x8xf32> + // CHECK-NEXT: %[[SHARD_TO_FULL:.*]] = stablehlo.custom_call @SPMDShardToFullShape(%[[SHMAP]]) : (tensor<2x8xf32>) -> tensor<4x8xf32> + // CHECK-NEXT: return %[[SHARD_TO_FULL]] : tensor<4x8xf32> %0 = sdy.manual_computation(%arg0) in_shardings=[<@mesh_1, [{"a"}, {}]>] out_shardings=[<@mesh_1, [{"a"}, {}]>] manual_axes={"a"} (%arg1: tensor<2x8xf32>) { %1 = sdy.manual_computation(%arg1) in_shardings=[<@mesh_1, [{}, {"b"}]>] out_shardings=[<@mesh_1, [{}, {"b"}]>] manual_axes={"b"} (%arg2: tensor<2x4xf32>) { %2 = stablehlo.multiply %arg2, %arg2 : tensor<2x4xf32> @@ -77,13 +82,15 @@ func.func @nested_shmaps(%arg0: tensor<4x8xf32> {sdy.sharding = #sdy.sharding<@m // CHECK-LABEL: func @nested_shmaps_extra_op func.func @nested_shmaps_extra_op(%arg0: tensor<4x8xf32> {sdy.sharding = #sdy.sharding<@mesh_1, [{"a"}, {"b"}]>}) -> (tensor<4x8xf32> {sdy.sharding = #sdy.sharding<@mesh_1, [{"a", ?}, {?}]>}) { - // CHECK: %[[SHMAP:.*]] = stablehlo.custom_call @local_xla.sdy.ManualComputation(%arg0) - // CHECK-SAME: {called_computations = [@local_xla.sdy.manual_computation_body_5], - // CHECK-SAME: mhlo.frontend_attributes = { + // CHECK-NEXT: %[[FULL_TO_SHARD:.*]] = stablehlo.custom_call @SPMDFullToShardShape(%arg0) : (tensor<4x8xf32>) -> tensor<2x8xf32> + // CHECK-NEXT: %[[SHMAP:.*]] = call @local_xla.sdy.manual_computation_body_5(%[[FULL_TO_SHARD]]) + // CHECK-SAME: {mhlo.frontend_attributes = { // CHECK-SAME: xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{\\\22a\\\22}, {}]>]>", // CHECK-SAME: xla.sdy.manual_axes = "#sdy", - // CHECK-SAME: xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{\\\22a\\\22}, {}]>]>"}} : (tensor<4x8xf32>) -> tensor<4x8xf32> - // CHECK-NEXT: return %[[SHMAP]] : tensor<4x8xf32> + // CHECK-SAME: xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{\\\22a\\\22}, {}]>]>"}} + // CHECK-SAME: (tensor<2x8xf32>) -> tensor<2x8xf32> + // CHECK-NEXT: %[[SHARD_TO_FULL:.*]] = stablehlo.custom_call @SPMDShardToFullShape(%[[SHMAP]]) : (tensor<2x8xf32>) -> tensor<4x8xf32> + // CHECK-NEXT: return %[[SHARD_TO_FULL]] : tensor<4x8xf32> %0 = sdy.manual_computation(%arg0) in_shardings=[<@mesh_1, [{"a"}, {}]>] out_shardings=[<@mesh_1, [{"a"}, {}]>] manual_axes={"a"} (%arg1: tensor<2x8xf32>) { %1 = sdy.manual_computation(%arg1) in_shardings=[<@mesh_1, [{}, {"b"}]>] out_shardings=[<@mesh_1, [{}, {"b"}]>] manual_axes={"b"} (%arg2: tensor<2x4xf32>) { %2 = stablehlo.multiply %arg2, %arg2 : tensor<2x4xf32> @@ -95,6 +102,40 @@ func.func @nested_shmaps_extra_op(%arg0: tensor<4x8xf32> {sdy.sharding = #sdy.sh return %0 : tensor<4x8xf32> } +// CHECK-LABEL: func @manual_computation_no_inputs +func.func @manual_computation_no_inputs() -> tensor<4xi64> { + // CHECK-NEXT: %[[SHMAP:.*]] = call @local_xla.sdy.manual_computation_body_6() + // CHECK-SAME: {mhlo.frontend_attributes = { + // CHECK-SAME: xla.sdy.in_shardings = "#sdy.sharding_per_value<[]>", + // CHECK-SAME: xla.sdy.manual_axes = "#sdy", + // CHECK-SAME: xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{\\\22b\\\22}]>]>"}} + // CHECK-SAME: () -> tensor<2xi64> + // CHECK-NEXT: %[[SHARD_TO_FULL:.*]] = stablehlo.custom_call @SPMDShardToFullShape(%[[SHMAP]]) : (tensor<2xi64>) -> tensor<4xi64> + // CHECK-NEXT: return %[[SHARD_TO_FULL]] : tensor<4xi64> + %0 = sdy.manual_computation() in_shardings=[] out_shardings=[<@mesh_0, [{"b"}]>] manual_axes={"b"} () { + %1 = stablehlo.constant dense<[2, 3]> : tensor<2xi64> + sdy.return %1 : tensor<2xi64> + } : () -> tensor<4xi64> + func.return %0 : tensor<4xi64> +} + +// CHECK-LABEL: func @manual_computation_no_outputs +func.func @manual_computation_no_outputs(%arg0: tensor<4xi64>) { + // CHECK-NEXT: %[[FULL_TO_SHARD:.*]] = stablehlo.custom_call @SPMDFullToShardShape(%arg0) : (tensor<4xi64>) -> tensor<2xi64> + // CHECK-NEXT: call @local_xla.sdy.manual_computation_body_7(%[[FULL_TO_SHARD]]) + // CHECK-SAME: {mhlo.frontend_attributes = { + // CHECK-SAME: xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{\\\22b\\\22}]>]>", + // CHECK-SAME: xla.sdy.manual_axes = "#sdy", + // CHECK-SAME: xla.sdy.out_shardings = "#sdy.sharding_per_value<[]>"}} + // CHECK-SAME: : (tensor<2xi64>) -> () + // CHECK-NEXT: return + sdy.manual_computation(%arg0) in_shardings=[<@mesh_0, [{"b"}]>] out_shardings=[] manual_axes={"b"} (%arg1: tensor<2xi64>) { + stablehlo.custom_call @sdy_testonly(%arg1) : (tensor<2xi64>) -> () + sdy.return + } : (tensor<4xi64>) -> () + func.return +} + // CHECK-LABEL: func @local_xla.sdy.manual_computation_body(%arg0: tensor<2x8xf32>, %arg1: tensor<8x32xf32>) -> tensor<2x32xf32> // CHECK-NEXT: stablehlo.add // CHECK-NEXT: stablehlo.dot @@ -110,24 +151,34 @@ func.func @nested_shmaps_extra_op(%arg0: tensor<4x8xf32> {sdy.sharding = #sdy.sh // CHECK-NEXT: stablehlo.multiply %arg0, %arg0 : tensor<2x4xf32> // CHECK-LABEL: func @local_xla.sdy.manual_computation_body_3(%arg0: tensor<2x8xf32>) -> tensor<2x8xf32 -// CHECK-NEXT: stablehlo.custom_call @local_xla.sdy.ManualComputation(%arg0) -// CHECK-SAME: {called_computations = [@local_xla.sdy.manual_computation_body_2], -// CHECK-SAME: mhlo.frontend_attributes = { +// CHECK-NEXT: %[[FULL_TO_SHARD:.*]] = stablehlo.custom_call @SPMDFullToShardShape(%arg0) : (tensor<2x8xf32>) -> tensor<2x4xf32> +// CHECK-NEXT: %[[SHMAP:.*]] = call @local_xla.sdy.manual_computation_body_2(%[[FULL_TO_SHARD]]) +// CHECK-SAME: {mhlo.frontend_attributes = { // CHECK-SAME: xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{}, {\\\22b\\\22}]>]>", // CHECK-SAME: xla.sdy.manual_axes = "#sdy", // CHECK-SAME: xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{}, {\\\22b\\\22}]>]>"}} -// CHECK-SAME: : (tensor<2x8xf32>) -> tensor<2x8xf32> +// CHECK-SAME: : (tensor<2x4xf32>) -> tensor<2x4xf32> +// CHECK-NEXT: stablehlo.custom_call @SPMDShardToFullShape(%[[SHMAP]]) : (tensor<2x4xf32>) -> tensor<2x8xf32> // CHECK-LABEL: func @local_xla.sdy.manual_computation_body_4(%arg0: tensor<2x4xf32>) -> tensor<2x4xf32> // CHECK-NEXT: stablehlo.multiply %arg0, %arg0 : tensor<2x4xf32> // CHECK-LABEL: func @local_xla.sdy.manual_computation_body_5(%arg0: tensor<2x8xf32>) -> tensor<2x8xf32> -// CHECK-NEXT: %[[SHMAP:.*]] = stablehlo.custom_call @local_xla.sdy.ManualComputation(%arg0) -// CHECK-SAME: {called_computations = [@local_xla.sdy.manual_computation_body_4], -// CHECK-SAME: mhlo.frontend_attributes = { +// CHECK-NEXT: %[[FULL_TO_SHARD:.*]] = stablehlo.custom_call @SPMDFullToShardShape(%arg0) : (tensor<2x8xf32>) -> tensor<2x4xf32 +// CHECK-NEXT: %[[SHMAP:.*]] = call @local_xla.sdy.manual_computation_body_4(%[[FULL_TO_SHARD]]) +// CHECK-SAME: {mhlo.frontend_attributes = { // CHECK-SAME: xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{}, {\\\22b\\\22}]>]>", // CHECK-SAME: xla.sdy.manual_axes = "#sdy", // CHECK-SAME: xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{}, {\\\22b\\\22}]>]>"}} -// CHECK-SAME: : (tensor<2x8xf32>) -> tensor<2x8xf32> -// CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %[[SHMAP]], %[[SHMAP]] : tensor<2x8xf32> +// CHECK-SAME: : (tensor<2x4xf32>) -> tensor<2x4xf32> +// CHECK-NEXT: %[[SHARD_TO_FULL:.*]] = stablehlo.custom_call @SPMDShardToFullShape(%[[SHMAP]]) : (tensor<2x4xf32>) -> tensor<2x8xf32> +// CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %[[SHARD_TO_FULL]], %[[SHARD_TO_FULL]] : tensor<2x8xf32> // CHECK-NEXT: return %[[ADD]] : tensor<2x8xf32> + +// CHECK-LABEL: func @local_xla.sdy.manual_computation_body_6() -> tensor<2xi64> { +// CHECK-NEXT: %[[C:.*]] = stablehlo.constant dense<[2, 3]> : tensor<2xi64> +// CHECK-NEXT: return %[[C]] : tensor<2xi64> + +// CHECK-LABEL: func @local_xla.sdy.manual_computation_body_7(%arg0: tensor<2xi64>) { +// CHECK-NEXT: stablehlo.custom_call @sdy_testonly(%arg0) : (tensor<2xi64>) -> () +// CHECK-NEXT: return diff --git a/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_shard_map_import.mlir b/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_shard_map_import.mlir index 1f49e3858cd899..4211756a8dcf02 100644 --- a/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_shard_map_import.mlir +++ b/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_shard_map_import.mlir @@ -4,7 +4,7 @@ sdy.mesh @mesh_0 = <["a"=4, "b"=2]> sdy.mesh @mesh_1 = <["a"=2, "b"=2, "c"=2, "d"=2]> func.func @single_manual_comp(%arg0: tensor<8x16xf32>, %arg1: tensor<16x32xf32>) -> (tensor<8x32xf32>) { - // CHECK-NOT: xla.sdy.ManualComputation + // CHECK-NOT: call @local_xla.sdy.manual_computation_body // CHECK: %[[MAN_COMP:.*]] = sdy.manual_computation(%arg0, %arg1) // CHECK-SAME{LITERAL}: in_shardings=[<@mesh_0, [{"a"}, {"b"}]>, <@mesh_0, [{"b"}, {}], replicated={"a"}>] // CHECK-SAME{LITERAL}: out_shardings=[<@mesh_0, [{"a"}, {}], replicated={"b"}>] @@ -20,12 +20,14 @@ func.func @single_manual_comp(%arg0: tensor<8x16xf32>, %arg1: tensor<16x32xf32>) // CHECK-NEXT: sdy.return %[[REDUCE]] : tensor<2x32xf32> // CHECK-NEXT: } : (tensor<8x16xf32>, tensor<16x32xf32>) -> tensor<8x32xf32> // CHECK-NEXT: return %[[MAN_COMP]] : tensor<8x32xf32> - %0 = stablehlo.custom_call @local_xla.sdy.ManualComputation(%arg0, %arg1) {called_computations = [@local_xla.sdy.manual_computation_body], mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{\\\22a\\\22}, {\\\22b\\\22}]>, <@mesh_0, [{\\\22b\\\22}, {}], replicated={\\\22a\\\22}>]>", xla.sdy.manual_axes = "#sdy", xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{\\\22a\\\22}, {}], replicated={\\\22b\\\22}>]>"}} : (tensor<8x16xf32>, tensor<16x32xf32>) -> tensor<8x32xf32> - return %0 : tensor<8x32xf32> + %0:2 = stablehlo.custom_call @SPMDFullToShardShape(%arg0, %arg1) : (tensor<8x16xf32>, tensor<16x32xf32>) -> (tensor<2x8xf32>, tensor<8x32xf32>) + %1 = call @local_xla.sdy.manual_computation_body(%0#0, %0#1) {mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{\\\22a\\\22}, {\\\22b\\\22}]>, <@mesh_0, [{\\\22b\\\22}, {}], replicated={\\\22a\\\22}>]>", xla.sdy.manual_axes = "#sdy", xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{\\\22a\\\22}, {}], replicated={\\\22b\\\22}>]>"}} : (tensor<2x8xf32>, tensor<8x32xf32>) -> tensor<2x32xf32> + %2 = stablehlo.custom_call @SPMDShardToFullShape(%1) : (tensor<2x32xf32>) -> tensor<8x32xf32> + return %2 : tensor<8x32xf32> } func.func @manual_comp_using_another(%arg0: tensor<8x8xf32>) -> tensor<8x8xf32> { - // CHECK-NOT: xla.sdy.ManualComputation + // CHECK-NOT: call @local_xla.sdy.manual_computation_body_0 // CHECK: %[[MAN_COMP_0:.*]] = sdy.manual_computation(%arg0) // CHECK-SAME{LITERAL}: in_shardings=[<@mesh_0, [{"a"}, {}]>] // CHECK-SAME{LITERAL}: out_shardings=[<@mesh_0, [{"a"}, {}]>] @@ -33,6 +35,7 @@ func.func @manual_comp_using_another(%arg0: tensor<8x8xf32>) -> tensor<8x8xf32> // CHECK-SAME: (%arg1: tensor<2x8xf32>) { // CHECK-NEXT: sdy.return %arg1 : tensor<2x8xf32> // CHECK-NEXT: } : (tensor<8x8xf32>) -> tensor<8x8xf32> + // CHECK-NOT: call @local_xla.sdy.manual_computation_body_1 // CHECK-NEXT: %[[MAN_COMP_1:.*]] = sdy.manual_computation(%[[MAN_COMP_0]]) // CHECK-SAME{LITERAL}: in_shardings=[<@mesh_0, [{}, {"b"}]>] // CHECK-SAME{LITERAL}: out_shardings=[<@mesh_0, [{}, {"b"}]>] @@ -41,15 +44,21 @@ func.func @manual_comp_using_another(%arg0: tensor<8x8xf32>) -> tensor<8x8xf32> // CHECK-NEXT: sdy.return %arg1 : tensor<8x4xf32> // CHECK-NEXT: } : (tensor<8x8xf32>) -> tensor<8x8xf32> // CHECK-NEXT: return %[[MAN_COMP_1]] : tensor<8x8xf32> - %0 = stablehlo.custom_call @local_xla.sdy.ManualComputation(%arg0) {called_computations = [@local_xla.sdy.manual_computation_body_0], mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{\\\22a\\\22}, {}]>]>", xla.sdy.manual_axes = "#sdy", xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{\\\22a\\\22}, {}]>]>"}} : (tensor<8x8xf32>) -> tensor<8x8xf32> - %1 = stablehlo.custom_call @local_xla.sdy.ManualComputation(%0) {called_computations = [@local_xla.sdy.manual_computation_body_1], mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{}, {\\\22b\\\22}]>]>", xla.sdy.manual_axes = "#sdy", xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{}, {\\\22b\\\22}]>]>"}} : (tensor<8x8xf32>) -> tensor<8x8xf32> - return %1 : tensor<8x8xf32> + %0 = stablehlo.custom_call @SPMDFullToShardShape(%arg0) : (tensor<8x8xf32>) -> tensor<2x8xf32> + %1 = call @local_xla.sdy.manual_computation_body_0(%0) {mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{\\\22a\\\22}, {}]>]>", xla.sdy.manual_axes = "#sdy", xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{\\\22a\\\22}, {}]>]>"}} : (tensor<2x8xf32>) -> tensor<2x8xf32> + %2 = stablehlo.custom_call @SPMDShardToFullShape(%1) : (tensor<2x8xf32>) -> tensor<8x8xf32> + %3 = stablehlo.custom_call @SPMDFullToShardShape(%2) : (tensor<8x8xf32>) -> tensor<8x4xf32> + %4 = call @local_xla.sdy.manual_computation_body_1(%3) {mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{}, {\\\22b\\\22}]>]>", xla.sdy.manual_axes = "#sdy", xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{}, {\\\22b\\\22}]>]>"}} : (tensor<8x4xf32>) -> tensor<8x4xf32> + %5 = stablehlo.custom_call @SPMDShardToFullShape(%4) : (tensor<8x4xf32>) -> tensor<8x8xf32> + return %5 : tensor<8x8xf32> } // CHECK-NOT: func @local_xla.sdy.manual_computation_body_3( func.func @local_xla.sdy.manual_computation_body_3(%arg0: tensor<2x8xf32>) -> tensor<2x8xf32> { - %0 = stablehlo.custom_call @local_xla.sdy.ManualComputation(%arg0) {called_computations = [@local_xla.sdy.manual_computation_body_2], mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{}, {\\\22b\\\22}]>]>", xla.sdy.manual_axes = "#sdy", xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{}, {\\\22b\\\22}]>]>"}} : (tensor<2x8xf32>) -> tensor<2x8xf32> - return %0 : tensor<2x8xf32> + %0 = stablehlo.custom_call @SPMDFullToShardShape(%arg0) : (tensor<2x8xf32>) -> tensor<2x4xf32> + %1 = call @local_xla.sdy.manual_computation_body_2(%0) {mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{}, {\\\22b\\\22}]>]>", xla.sdy.manual_axes = "#sdy", xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{}, {\\\22b\\\22}]>]>"}} : (tensor<2x4xf32>) -> tensor<2x4xf32> + %2 = stablehlo.custom_call @SPMDShardToFullShape(%1) : (tensor<2x4xf32>) -> tensor<2x8xf32> + return %2 : tensor<2x8xf32> } // CHECK-NOT: func @local_xla.sdy.manual_computation_body_2( @@ -59,7 +68,7 @@ func.func @local_xla.sdy.manual_computation_body_2(%arg0: tensor<2x4xf32>) -> te } func.func @nested_shmaps(%arg0: tensor<4x8xf32>) -> tensor<4x8xf32> { - // CHECK-NOT: xla.sdy.ManualComputation + // CHECK-NOT: call @local_xla.sdy.manual_computation_body_3 // CHECK: %[[MAN_COMP_0:.*]] = sdy.manual_computation(%arg0) // CHECK-SAME{LITERAL}: in_shardings=[<@mesh_1, [{"a"}, {}]>] // CHECK-SAME{LITERAL}: out_shardings=[<@mesh_1, [{"a"}, {}]>] @@ -76,12 +85,14 @@ func.func @nested_shmaps(%arg0: tensor<4x8xf32>) -> tensor<4x8xf32> { // CHECK-NEXT: sdy.return %[[MAN_COMP_1]] : tensor<2x8xf32> // CHECK-NEXT: } : (tensor<4x8xf32>) -> tensor<4x8xf32> // CHECK-NEXT: return %[[MAN_COMP_0]] : tensor<4x8xf32> - %0 = stablehlo.custom_call @local_xla.sdy.ManualComputation(%arg0) {called_computations = [@local_xla.sdy.manual_computation_body_3], mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{\\\22a\\\22}, {}]>]>", xla.sdy.manual_axes = "#sdy", xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{\\\22a\\\22}, {}]>]>"}} : (tensor<4x8xf32>) -> tensor<4x8xf32> - return %0 : tensor<4x8xf32> + %0 = stablehlo.custom_call @SPMDFullToShardShape(%arg0) : (tensor<4x8xf32>) -> tensor<2x8xf32> + %1 = call @local_xla.sdy.manual_computation_body_3(%0) {mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{\\\22a\\\22}, {}]>]>", xla.sdy.manual_axes = "#sdy", xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{\\\22a\\\22}, {}]>]>"}} : (tensor<2x8xf32>) -> tensor<2x8xf32> + %2 = stablehlo.custom_call @SPMDShardToFullShape(%1) : (tensor<2x8xf32>) -> tensor<4x8xf32> + return %2 : tensor<4x8xf32> } func.func @nested_shmaps_extra_op(%arg0: tensor<4x8xf32>) -> tensor<4x8xf32> { - // CHECK-NOT: xla.sdy.ManualComputation + // CHECK-NOT: call @local_xla.sdy.manual_computation_body_5 // CHECK: %[[MAN_COMP_0:.*]] = sdy.manual_computation(%arg0) // CHECK-SAME{LITERAL}: in_shardings=[<@mesh_1, [{"a"}, {}]>] // CHECK-SAME{LITERAL}: out_shardings=[<@mesh_1, [{"a"}, {}]>] @@ -99,8 +110,42 @@ func.func @nested_shmaps_extra_op(%arg0: tensor<4x8xf32>) -> tensor<4x8xf32> { // CHECK-NEXT: sdy.return %[[ADD]] : tensor<2x8xf32> // CHECK-NEXT: } : (tensor<4x8xf32>) -> tensor<4x8xf32> // CHECK-NEXT: return %[[MAN_COMP_0]] : tensor<4x8xf32> - %0 = stablehlo.custom_call @local_xla.sdy.ManualComputation(%arg0) {called_computations = [@local_xla.sdy.manual_computation_body_5], mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{\\\22a\\\22}, {}]>]>", xla.sdy.manual_axes = "#sdy", xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{\\\22a\\\22}, {}]>]>"}} : (tensor<4x8xf32>) -> tensor<4x8xf32> - return %0 : tensor<4x8xf32> + %0 = stablehlo.custom_call @SPMDFullToShardShape(%arg0) : (tensor<4x8xf32>) -> tensor<2x8xf32> + %1 = call @local_xla.sdy.manual_computation_body_5(%0) {mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{\\\22a\\\22}, {}]>]>", xla.sdy.manual_axes = "#sdy", xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{\\\22a\\\22}, {}]>]>"}} : (tensor<2x8xf32>) -> tensor<2x8xf32> + %2 = stablehlo.custom_call @SPMDShardToFullShape(%1) : (tensor<2x8xf32>) -> tensor<4x8xf32> + return %2 : tensor<4x8xf32> +} + +func.func @manual_computation_no_inputs() -> tensor<4xi64> { + // CHECK-NOT: call @local_xla.sdy.manual_computation_body_6 + // CHECK: %[[SHMAP:.*]] = sdy.manual_computation() + // CHECK-SAME{LITERAL}: in_shardings=[] + // CHECK-SAME{LITERAL}: out_shardings=[<@mesh_0, [{"b"}]>] + // CHECK-SAME{LITERAL}: manual_axes={"b"} + // CHECK-SAME{LITERAL}: () { + // CHECK-NEXT: %[[C:.*]] = stablehlo.constant dense<[2, 3]> : tensor<2xi64> + // CHECK-NEXT: sdy.return %[[C]] : tensor<2xi64> + // CHECK-NEXT: } : () -> tensor<4xi64> + // CHECK-NEXT: return %[[SHMAP]] : tensor<4xi64> + %0 = call @local_xla.sdy.manual_computation_body_6() {mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[]>", xla.sdy.manual_axes = "#sdy", xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{\\\22b\\\22}]>]>"}} : () -> tensor<2xi64> + %1 = stablehlo.custom_call @SPMDShardToFullShape(%0) : (tensor<2xi64>) -> tensor<4xi64> + return %1 : tensor<4xi64> +} + +func.func @manual_computation_no_outputs(%arg0: tensor<4xi64>) { + // CHECK-NOT: call @local_xla.sdy.manual_computation_body_7 + // CHECK: sdy.manual_computation(%arg0) + // CHECK-SAME{LITERAL}: in_shardings=[<@mesh_0, [{"b"}]>] + // CHECK-SAME{LITERAL}: out_shardings=[] + // CHECK-SAME{LITERAL}: manual_axes={"b"} + // CHECK-SAME{LITERAL}: (%arg1: tensor<2xi64>) { + // CHECK-NEXT: stablehlo.custom_call @sdy_testonly(%arg1) : (tensor<2xi64>) -> () + // CHECK-NEXT: sdy.return + // CHECK-NEXT: } : (tensor<4xi64>) -> () + // CHECK-NEXT: return + %0 = stablehlo.custom_call @SPMDFullToShardShape(%arg0) : (tensor<4xi64>) -> tensor<2xi64> + call @local_xla.sdy.manual_computation_body_7(%0) {mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{\\\22b\\\22}]>]>", xla.sdy.manual_axes = "#sdy", xla.sdy.out_shardings = "#sdy.sharding_per_value<[]>"}} : (tensor<2xi64>) -> () + return } // CHECK-NOT: func @local_xla.sdy.manual_computation_body( @@ -133,7 +178,21 @@ func.func @local_xla.sdy.manual_computation_body_4(%arg0: tensor<2x4xf32>) -> te // CHECK-NOT: func @local_xla.sdy.manual_computation_body_5( func.func @local_xla.sdy.manual_computation_body_5(%arg0: tensor<2x8xf32>) -> tensor<2x8xf32> { - %0 = stablehlo.custom_call @local_xla.sdy.ManualComputation(%arg0) {called_computations = [@local_xla.sdy.manual_computation_body_4], mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{}, {\\\22b\\\22}]>]>", xla.sdy.manual_axes = "#sdy", xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{}, {\\\22b\\\22}]>]>"}} : (tensor<2x8xf32>) -> tensor<2x8xf32> - %1 = stablehlo.add %0, %0 : tensor<2x8xf32> - return %1 : tensor<2x8xf32> + %0 = stablehlo.custom_call @SPMDFullToShardShape(%arg0) : (tensor<2x8xf32>) -> tensor<2x4xf32> + %1 = call @local_xla.sdy.manual_computation_body_4(%0) {mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{}, {\\\22b\\\22}]>]>", xla.sdy.manual_axes = "#sdy", xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{}, {\\\22b\\\22}]>]>"}} : (tensor<2x4xf32>) -> tensor<2x4xf32> + %2 = stablehlo.custom_call @SPMDShardToFullShape(%1) : (tensor<2x4xf32>) -> tensor<2x8xf32> + %3 = stablehlo.add %2, %2 : tensor<2x8xf32> + return %3 : tensor<2x8xf32> +} + +// CHECK-NOT: func @local_xla.sdy.manual_computation_body_6( +func.func @local_xla.sdy.manual_computation_body_6() -> tensor<2xi64> { + %c = stablehlo.constant dense<[2, 3]> : tensor<2xi64> + return %c : tensor<2xi64> +} + +// CHECK-NOT: func @local_xla.sdy.manual_computation_body_7( +func.func @local_xla.sdy.manual_computation_body_7(%arg0: tensor<2xi64>) { + stablehlo.custom_call @sdy_testonly(%arg0) : (tensor<2xi64>) -> () + return } diff --git a/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_shard_map_import_failure.mlir b/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_shard_map_import_failure.mlir index d1516899e450d6..7effeae63ae5be 100644 --- a/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_shard_map_import_failure.mlir +++ b/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_shard_map_import_failure.mlir @@ -3,13 +3,17 @@ sdy.mesh @mesh = <["a"=2]> func.func @using_same_body_func(%arg0: tensor<8x8xf32>) -> tensor<8x8xf32> { - %0 = stablehlo.custom_call @local_xla.sdy.ManualComputation(%arg0) {called_computations = [@local_xla.sdy.manual_computation_body_0], mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh, [{\\\22a\\\22}, {}]>]>", xla.sdy.manual_axes = "#sdy", xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh, [{\\\22a\\\22}, {}]>]>"}} : (tensor<8x8xf32>) -> tensor<8x8xf32> - // expected-error @+2 {{'stablehlo.custom_call' op expected a unique FuncOp per @local_xla.sdy.ManualComputation custom call}} - // expected-error @+1 {{failed to legalize operation 'stablehlo.custom_call'}} - %1 = stablehlo.custom_call @local_xla.sdy.ManualComputation(%0) {called_computations = [@local_xla.sdy.manual_computation_body_0], mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh, [{\\\22a\\\22}, {}]>]>", xla.sdy.manual_axes = "#sdy", xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh, [{\\\22a\\\22}, {}]>]>"}} : (tensor<8x8xf32>) -> tensor<8x8xf32> - return %1 : tensor<8x8xf32> + %0 = stablehlo.custom_call @SPMDFullToShardShape(%arg0) : (tensor<8x8xf32>) -> (tensor<2x8xf32>) + %1 = call @local_xla.sdy.manual_computation_body(%0) {mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{\\\22a\\\22}, {\\\22b\\\22}]>]>", xla.sdy.manual_axes = "#sdy", xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{\\\22a\\\22}, {}], replicated={\\\22b\\\22}>]>"}} : (tensor<2x8xf32>) -> (tensor<2x8xf32>) + %2 = stablehlo.custom_call @SPMDShardToFullShape(%1) : (tensor<2x8xf32>) -> (tensor<8x8xf32>) + %3 = stablehlo.custom_call @SPMDFullToShardShape(%2) : (tensor<8x8xf32>) -> (tensor<2x8xf32>) + // expected-error @+2 {{'func.call' op expected a unique FuncOp per @local_xla.sdy.manual_computation_body call}} + // expected-error @+1 {{failed to legalize operation 'func.call'}} + %4 = call @local_xla.sdy.manual_computation_body(%3) {mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{\\\22a\\\22}, {\\\22b\\\22}]>]>", xla.sdy.manual_axes = "#sdy", xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{\\\22a\\\22}, {}], replicated={\\\22b\\\22}>]>"}} : (tensor<2x8xf32>) -> (tensor<2x8xf32>) + %5 = stablehlo.custom_call @SPMDShardToFullShape(%4) : (tensor<2x8xf32>) -> (tensor<8x8xf32>) + return %5 : tensor<8x8xf32> } -func.func @local_xla.sdy.manual_computation_body_0(%arg0: tensor<2x8xf32>) -> tensor<2x8xf32> { +func.func @local_xla.sdy.manual_computation_body(%arg0: tensor<2x8xf32>) -> tensor<2x8xf32> { return %arg0 : tensor<2x8xf32> } From 3aa946e395785980a755e91c5a0113b8d1f1a890 Mon Sep 17 00:00:00 2001 From: Yifan Jiang Date: Mon, 23 Sep 2024 10:13:59 -0700 Subject: [PATCH 142/483] Wait for events in a different thread if they are not defined yet. PiperOrigin-RevId: 677843165 --- third_party/xla/xla/pjrt/BUILD | 1 + .../xla/pjrt/gpu/se_gpu_pjrt_client_test.cc | 46 +++++++ .../xla/xla/pjrt/local_device_state.cc | 13 ++ third_party/xla/xla/pjrt/local_device_state.h | 14 +++ .../xla/pjrt/pjrt_stream_executor_client.cc | 114 +++++++++++------- .../xla/xla/pjrt/tracked_device_buffer.cc | 11 +- 6 files changed, 152 insertions(+), 47 deletions(-) diff --git a/third_party/xla/xla/pjrt/BUILD b/third_party/xla/xla/pjrt/BUILD index 978ce97ce210ba..7eda0b9e119b8d 100644 --- a/third_party/xla/xla/pjrt/BUILD +++ b/third_party/xla/xla/pjrt/BUILD @@ -151,6 +151,7 @@ cc_library( "//xla:util", "//xla/client:local_client", "//xla/stream_executor", + "//xla/tsl/util:env_var", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/strings:str_format", diff --git a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc index b306194ed4eefd..cebfa6879fffcd 100644 --- a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc +++ b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc @@ -674,6 +674,52 @@ TEST(StreamExecutorGpuClientTest, FromHostAsyncPinnedHostChunked) { EXPECT_THAT(lit->data(), ElementsAreArray(data)); } +TEST(StreamExecutorGpuClientTest, DeleteBufferThenFulfillBufferNoDeadLock) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr client, + GetStreamExecutorGpuClient(GpuClientOptions())); + ASSERT_THAT(client->addressable_devices(), SizeIs(Gt(0))); + TF_ASSERT_OK_AND_ASSIGN( + PjRtMemorySpace * memspace, + client->addressable_devices()[0]->memory_space_by_kind( + PinnedHostMemorySpace::kKind)); + std::vector data{1, 3, 5, 7, 11, 13, 17, 19}; + Shape shape = ShapeUtil::MakeShape(F32, {static_cast(data.size())}); + std::vector> + txms; + for (int i = 0; i < 10000; ++i) { + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr txm, + client->CreateBuffersForAsyncHostToDevice({shape}, memspace)); + std::unique_ptr buf = txm->RetrieveBuffer(0); + ASSERT_THAT(buf->GetReadyFuture().IsReady(), Eq(false)); + txms.push_back(std::move(txm)); + // Delete the buffer + } + + // At this point, we have 10000 buffers pending deallocation. + + absl::string_view raw_view(reinterpret_cast(data.data()), + data.size() * sizeof(data[0])); + for (auto& txm : txms) { + int offset = 0; + while (true) { + int end = offset + 3; // unaligned chunk size + if (end > raw_view.size()) { + end = raw_view.size(); + } + int sz = end - offset; + bool reaches_end = end == raw_view.size(); + TF_ASSERT_OK(txm->TransferRawDataToSubBuffer( + /*buffer_index=*/0, raw_view.data() + offset, offset, sz, reaches_end, + /*on_done=*/[]() {})); + if (reaches_end) { + break; + } + offset = end; + } + } +} + TEST(StreamExecutorGpuClientTest, CopyRawToHostFullBuffer) { TF_ASSERT_OK_AND_ASSIGN(auto client, GetStreamExecutorGpuClient(GpuClientOptions())); diff --git a/third_party/xla/xla/pjrt/local_device_state.cc b/third_party/xla/xla/pjrt/local_device_state.cc index 62eba2b6238098..dd037deade6a0c 100644 --- a/third_party/xla/xla/pjrt/local_device_state.cc +++ b/third_party/xla/xla/pjrt/local_device_state.cc @@ -27,6 +27,7 @@ limitations under the License. #include "absl/strings/str_format.h" #include "absl/synchronization/mutex.h" #include "xla/stream_executor/stream.h" +#include "xla/tsl/util/env_var.h" #include "xla/util.h" #include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" @@ -51,6 +52,16 @@ LocalDeviceState::LocalDeviceState(se::StreamExecutor* executor, prng_seed_generator_(prng_seed_device_()), prng_seed_distribution_(std::numeric_limits::min(), std::numeric_limits::max()) { + // Setting XLA_PJRT_GPU_ALLOW_DELETE_BEFORE_FULFILL to false will: + // 1. disallow the host to schedule `create buffer -> use -> delete -> + // fulfill`, which is a use case unit tested in + // StreamExecutorGpuClientTest.DeleteBufferThenFulfillBufferNoDeadLock. + // 2. potentially reduce spikes in HBM usage because the host will wait for + // buffer fulfillment to be scheduled before destructing it. + absl::Status status = + tsl::ReadBoolFromEnvVar("XLA_PJRT_GPU_ALLOW_DELETE_BEFORE_FULFILL", true, + &allow_delete_before_fulfill_); + local_hardware_id_ = executor_->device_ordinal(); local_device_id_ = device_ordinal != -1 ? device_ordinal : executor_->device_ordinal(); @@ -103,6 +114,8 @@ LocalDeviceState::LocalDeviceState(se::StreamExecutor* executor, std::make_unique(tsl::Env::Default(), "py_xla_execute"); callback_thread_ = std::make_unique(tsl::Env::Default(), "py_xla_callback"); + cleanup_thread_ = + std::make_unique(tsl::Env::Default(), "py_xla_cleanup"); } LocalDeviceState::~LocalDeviceState() { diff --git a/third_party/xla/xla/pjrt/local_device_state.h b/third_party/xla/xla/pjrt/local_device_state.h index 1ce1f1ea7d5401..a7ed7addd84499 100644 --- a/third_party/xla/xla/pjrt/local_device_state.h +++ b/third_party/xla/xla/pjrt/local_device_state.h @@ -170,6 +170,8 @@ class LocalDeviceState { WorkerThread* execute_thread() const { return execute_thread_.get(); } + WorkerThread* cleanup_thread() const { return cleanup_thread_.get(); } + // Enqueues a host callback on 'stream'. `stream` may, but need not, wait for // `callback` to complete. It is safe to call runtime methods from the // callback. @@ -199,6 +201,12 @@ class LocalDeviceState { // Returns a fresh, PRNG-generated random seed for an XLA computation. int GetNewPrngSeed(); + // Whether to allow deleting a buffer before the operation fulfilling the + // buffer is scheduled by the host. + bool allow_delete_before_fulfill() const { + return allow_delete_before_fulfill_; + } + private: absl::Status SynchronizeAllActivity(); @@ -255,6 +263,12 @@ class LocalDeviceState { // semaphore during calls to Execute but release it from a callback and if // they are the same thread we might deadlock. std::unique_ptr callback_thread_; + + // One thread dedicated to cleaning up buffers. Scheduled work on this thread + // may wait for other threads to schedule writes to buffers. + std::unique_ptr cleanup_thread_; + + bool allow_delete_before_fulfill_ = true; }; } // namespace xla diff --git a/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc b/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc index c9e1c61cd56a3b..cb92e535e10f60 100644 --- a/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc +++ b/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc @@ -425,6 +425,27 @@ absl::Status AddDestinationBufferSynchronization( return absl::OkStatus(); } +// We wait for events that the compute stream didn't already wait for. Based on +// our heuristics, for usage events, this rare case should only occur when a +// buffer was copied to a device and then never used there. In that case we get +// a new stream and use it to hold onto a reference to the buffer until the +// events are complete. +void MaybeWaitForEventOnStream(BufferSequencingEvent* event, + LocalDeviceState* local_device_state, + se::Stream*& stream) { + if (!event->IsPredeterminedErrorOrDefinedOn( + local_device_state->compute_stream()) && + !event->IsComplete()) { + if (stream == nullptr) { + stream = local_device_state->GetFixedSizePoolUsageStream(); + } + VLOG(2) << "Waiting for event: " << event + << "; is_predetermined_error: " << event->IsPredeterminedError() + << "; on stream: " << stream; + event->WaitForEventOnStream(stream); + } +} + } // namespace absl::StatusOr> @@ -1492,6 +1513,24 @@ PjRtStreamExecutorBuffer::Release(bool wait_for_operations_to_complete) { if (local_device_state->allocation_model() == LocalDeviceState::kComputeSynchronized) { se::Stream* block_stream = nullptr; + // If an event is not defined yet, we wait for it to be defined in a new + // thread in the thread pool. + // This allows the host to schedule: + // create buffer -> use -> delete -> fulfill + absl::InlinedVector, 5> + events_to_wait_for_in_a_different_thread; + auto maybe_wait_for_event_on_block_stream_or_add_to_events_to_wait = + [&events_to_wait_for_in_a_different_thread, local_device_state, + &block_stream](const std::shared_ptr& event) { + if (local_device_state->allow_delete_before_fulfill() && + !event->IsDefined()) { + // Wait for the event to be defined in a different thread. + events_to_wait_for_in_a_different_thread.push_back(event); + } else { + MaybeWaitForEventOnStream(event.get(), local_device_state, + block_stream); + } + }; for (const auto& stream_and_event : events) { VLOG(2) << "Checking whether need to wait for stream_and_event: stream: " @@ -1501,25 +1540,11 @@ PjRtStreamExecutorBuffer::Release(bool wait_for_operations_to_complete) { << "; is_predetermined_error: " << stream_and_event.event->IsPredeterminedError(); // We only need to do something for events that didn't already acquire a - // reference to the buffer, and also which the compute stream didn't - // already wait for. Based on our heuristics this rare case should only - // occur when a buffer was copied to a device and then never used there. - // In that case we get a new stream and use it to hold onto a reference - // to the buffer until the events are complete. - if (!stream_and_event.reference_held && - !stream_and_event.event->IsPredeterminedErrorOrDefinedOn( - local_device_state->compute_stream()) && - !stream_and_event.event->IsComplete()) { - if (block_stream == nullptr) { - block_stream = local_device_state->GetFixedSizePoolUsageStream(); - } - VLOG(2) << "Waiting for stream_and_event: stream: " - << stream_and_event.stream - << "; event: " << stream_and_event.event.get() - << "; reference_held: " << stream_and_event.reference_held - << "; is_predetermined_error: " - << stream_and_event.event->IsPredeterminedError(); - stream_and_event.event->WaitForEventOnStream(block_stream); + // reference to the buffer and for other situations described in the + // comment of MaybeWaitForEventOnStream() + if (!stream_and_event.reference_held) { + maybe_wait_for_event_on_block_stream_or_add_to_events_to_wait( + stream_and_event.event); } } for (const auto& definition_event : device_buffer->definition_events()) { @@ -1527,31 +1552,34 @@ PjRtStreamExecutorBuffer::Release(bool wait_for_operations_to_complete) { << definition_event.get() << "; is_predetermined_error: " << definition_event->IsPredeterminedError(); // Here we wait for the definition events to complete on block_stream as - // well, if they are not on the compute stream and not also recorded as - // usage events. - // - // Since it's possible that definition_event.SetSequencingEvent() - // is called on a different host thread than this host thread, when in - // future more conditions are added to this check, we should be careful - // about whether we put them before the IsPredeterminedErrorOrDefinedOn - // check or after it. For example, we shouldn't add an IsDefined() check - // before the IsPredeterminedErrorOrDefinedOn() check here because that - // could potentially cause a shortcut where we don't wait for - // definition_event.SetSequencingEvent() on the other thread and - // eventually cause memory corruption. - if (!definition_event->IsPredeterminedErrorOrDefinedOn( - local_device_state->compute_stream()) && - !definition_event->IsComplete()) { - if (block_stream == nullptr) { - block_stream = local_device_state->GetFixedSizePoolUsageStream(); - } - VLOG(2) << "Waiting for definition_event: " << definition_event.get() - << "; is_predetermined_error: " - << definition_event->IsPredeterminedError(); - definition_event->WaitForEventOnStream(block_stream); - } + // well, in case they are not also usage events. + maybe_wait_for_event_on_block_stream_or_add_to_events_to_wait( + definition_event); } - if (block_stream != nullptr) { + if (!events_to_wait_for_in_a_different_thread.empty()) { + VLOG(1) << "Going to wait for " + << events_to_wait_for_in_a_different_thread.size() + << " events in a different thread."; + // We always use the cleanup_thread instead of using the + // client->thread_pool() here to avoid exhausting the client thread + // pool. + local_device_state->cleanup_thread()->Schedule( + [events_to_wait_for_in_a_different_thread = + std::move(events_to_wait_for_in_a_different_thread), + local_device_state, device_buffer, block_stream]() mutable { + for (const auto& event : + events_to_wait_for_in_a_different_thread) { + MaybeWaitForEventOnStream(event.get(), local_device_state, + block_stream); + } + if (block_stream != nullptr) { + TF_CHECK_OK(local_device_state->ThenExecuteCallback( + block_stream, [device_buffer]() { + // Drops device_buffer shared pointer. + })); + } + }); + } else if (block_stream != nullptr) { TF_RETURN_IF_ERROR(local_device_state->ThenExecuteCallback( block_stream, [device_buffer]() { // Drops device_buffer shared pointer. diff --git a/third_party/xla/xla/pjrt/tracked_device_buffer.cc b/third_party/xla/xla/pjrt/tracked_device_buffer.cc index 2fe88dcacdca59..16ab8a669da699 100644 --- a/third_party/xla/xla/pjrt/tracked_device_buffer.cc +++ b/third_party/xla/xla/pjrt/tracked_device_buffer.cc @@ -159,10 +159,13 @@ void BufferSequencingEvent::ExecuteOrAddToFutureTasks( void BufferSequencingEvent::ExecuteFutureTasks() { absl::MutexLock lock(&mu_); - for (auto& [task_name, task_callback] : on_ready_tasks_callback_) { - thread_pool_->Schedule(std::move(task_callback)); - } - on_ready_tasks_callback_.clear(); + auto call_all_task_callbacks = [on_ready_tasks_callback = + std::move(on_ready_tasks_callback_)]() { + for (auto& [task_name, task_callback] : on_ready_tasks_callback) { + task_callback(); + } + }; + thread_pool_->Schedule(std::move(call_all_task_callbacks)); } /* static */ std::shared_ptr From 4bb72be9e5ba7a29fd37de85ebaf7e2d03156c37 Mon Sep 17 00:00:00 2001 From: Frederik Gossen Date: Mon, 23 Sep 2024 10:17:51 -0700 Subject: [PATCH 143/483] [XLA:GPU] Verify async instruction pairs for send/recv Send/recv must feed into the corresponding *-done operation or through control flow. PiperOrigin-RevId: 677844838 --- third_party/xla/xla/service/hlo_verifier.cc | 20 +++++++++++ .../xla/xla/service/hlo_verifier_test.cc | 34 +++++++++++++++++++ 2 files changed, 54 insertions(+) diff --git a/third_party/xla/xla/service/hlo_verifier.cc b/third_party/xla/xla/service/hlo_verifier.cc index 8c71ad693bfb33..0b7e4bd9f916b1 100644 --- a/third_party/xla/xla/service/hlo_verifier.cc +++ b/third_party/xla/xla/service/hlo_verifier.cc @@ -2368,6 +2368,26 @@ absl::Status VerifyAsynchronousInstructionPairs(const HloModule& module) { instruction, {HloOpcode::kCollectivePermuteStart})); break; } + case HloOpcode::kSend: { + TF_RETURN_IF_ERROR(VerifySingleUser( + instruction, {HloOpcode::kSendDone, HloOpcode::kTuple})); + break; + } + case HloOpcode::kSendDone: { + TF_RETURN_IF_ERROR(VerifySingleOperand( + instruction, {HloOpcode::kSend, HloOpcode::kGetTupleElement})); + break; + } + case HloOpcode::kRecv: { + TF_RETURN_IF_ERROR(VerifySingleUser( + instruction, {HloOpcode::kRecvDone, HloOpcode::kTuple})); + break; + } + case HloOpcode::kRecvDone: { + TF_RETURN_IF_ERROR(VerifySingleOperand( + instruction, {HloOpcode::kRecv, HloOpcode::kGetTupleElement})); + break; + } default: break; } diff --git a/third_party/xla/xla/service/hlo_verifier_test.cc b/third_party/xla/xla/service/hlo_verifier_test.cc index 6c649a8c0ff004..877c445c6f5aa8 100644 --- a/third_party/xla/xla/service/hlo_verifier_test.cc +++ b/third_party/xla/xla/service/hlo_verifier_test.cc @@ -3428,5 +3428,39 @@ TEST_F(HloVerifierTest, NoErrorOnDuplicateChannelId) { ASSERT_IS_OK(verifier.Run(module.get()).status()); } +TEST_F(HloVerifierTest, ChannelVerifierAsyncSend) { + const char* const kModuleStr = R"( + HloModule test + + ENTRY main_spmd { + data = f32[16] parameter(0) + after_all = token[] after-all() + send = (f32[16], u32[], token[]) send(after_all, data), channel_id=1, + frontend_attributes={ + _xla_send_send_source_target_pairs={{0,1},{1,2},{2,3}}} + ROOT send_done = (f32[16], token[]) send-done(send), channel_id=1 + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnUnverifiedModule(kModuleStr)); + TF_ASSERT_OK(verifier().Run(module.get())); +} + +TEST_F(HloVerifierTest, ChannelVerifierAsyncRecv) { + const char* const kModuleStr = R"( + HloModule test + + ENTRY main_spmd { + after_all = token[] after-all() + recv = (f32[16], u32[], token[]) recv(after_all), channel_id=1, + frontend_attributes={ + _xla_send_send_source_target_pairs={{0,1},{1,2},{2,3}}} + recv_done = (f32[16], token[]) recv-done(recv), channel_id=1 + ROOT result = f32[16] get-tuple-element(recv_done), index=0 + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnUnverifiedModule(kModuleStr)); + TF_ASSERT_OK(verifier().Run(module.get())); +} + } // namespace } // namespace xla From ab30e0fc889d62c9fab1a078d772c7999218dd16 Mon Sep 17 00:00:00 2001 From: Eric Salo Date: Mon, 23 Sep 2024 10:38:03 -0700 Subject: [PATCH 144/483] cleanup: remove api_version from BUILD files PiperOrigin-RevId: 677852856 --- tensorflow/compiler/mlir/quantization/stablehlo/BUILD | 2 -- 1 file changed, 2 deletions(-) diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/BUILD b/tensorflow/compiler/mlir/quantization/stablehlo/BUILD index b95ec31e959b47..47510e04c27abe 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/BUILD +++ b/tensorflow/compiler/mlir/quantization/stablehlo/BUILD @@ -726,7 +726,6 @@ tf_proto_library( # copybara:uncomment_begin(google-only) # py_proto_library( # name = "quantization_options_py_pb2", -# api_version = 2, # visibility = [":internal_visibility_allowlist_package"], # deps = [":quantization_options_proto"], # ) @@ -746,7 +745,6 @@ tf_proto_library( # copybara:uncomment_begin(google-only) # py_proto_library( # name = "quantization_config_py_pb2", -# api_version = 2, # visibility = [ # ":internal_visibility_allowlist_package", # ], From db1384fe4adc60ecd8bc17f783c5d48fd29f4f3e Mon Sep 17 00:00:00 2001 From: Vamsi Manchala Date: Mon, 23 Sep 2024 11:27:36 -0700 Subject: [PATCH 145/483] Migrate dense_to_sparse pass to new TFL::Pass mechanism and. remove the .td definition. PiperOrigin-RevId: 677873271 --- tensorflow/compiler/mlir/lite/BUILD | 12 +-- tensorflow/compiler/mlir/lite/sparsity/BUILD | 1 + .../mlir/lite/sparsity/sparsify_model.cc | 5 +- ...e_to_sparse.cc => dense_to_sparse_pass.cc} | 15 +--- .../lite/transforms/dense_to_sparse_pass.h | 80 +++++++++++++++++++ .../compiler/mlir/lite/transforms/passes.h | 5 -- .../compiler/mlir/lite/transforms/passes.td | 29 ------- 7 files changed, 92 insertions(+), 55 deletions(-) rename tensorflow/compiler/mlir/lite/transforms/{dense_to_sparse.cc => dense_to_sparse_pass.cc} (97%) create mode 100644 tensorflow/compiler/mlir/lite/transforms/dense_to_sparse_pass.h diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD index 5cbf2c1db15911..0f2f7f700450a2 100644 --- a/tensorflow/compiler/mlir/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/BUILD @@ -1127,19 +1127,19 @@ cc_library( cc_library( name = "tensorflow_lite_d2s", srcs = [ - "transforms/dense_to_sparse.cc", + "transforms/dense_to_sparse_pass.cc", ], hdrs = [ - "transforms/passes.h", + "transforms/dense_to_sparse_pass.h", ], deps = [ - ":tensorflow_lite", - ":tensorflow_lite_passes_inc_gen", + ":pass", + ":pass_options", + ":tensorflow_lite_ops", "//tensorflow/compiler/mlir/lite/kernels/internal/utils:sparsity_format_converter", - "//tensorflow/compiler/mlir/quantization/common/quantization_lib:quantization_config", - "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", "@eigen_archive//:eigen3", + "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", diff --git a/tensorflow/compiler/mlir/lite/sparsity/BUILD b/tensorflow/compiler/mlir/lite/sparsity/BUILD index 79a355dcb73e3d..3b9825f72bde45 100644 --- a/tensorflow/compiler/mlir/lite/sparsity/BUILD +++ b/tensorflow/compiler/mlir/lite/sparsity/BUILD @@ -29,6 +29,7 @@ cc_library( ], deps = [ "//tensorflow/compiler/mlir/lite:flatbuffer_translate_lib", + "//tensorflow/compiler/mlir/lite:pass_registry_utils", "//tensorflow/compiler/mlir/lite:tensorflow_lite_d2s", "//tensorflow/compiler/mlir/lite/schema:schema_fbs", "//tensorflow/compiler/mlir/lite/tools/optimize:reduced_precision_metadata", diff --git a/tensorflow/compiler/mlir/lite/sparsity/sparsify_model.cc b/tensorflow/compiler/mlir/lite/sparsity/sparsify_model.cc index e180d3d46a9d8b..76a64df1f6e7bd 100644 --- a/tensorflow/compiler/mlir/lite/sparsity/sparsify_model.cc +++ b/tensorflow/compiler/mlir/lite/sparsity/sparsify_model.cc @@ -36,7 +36,8 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/flatbuffer_import.h" #include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" #include "tensorflow/compiler/mlir/lite/tools/optimize/reduced_precision_metadata.h" -#include "tensorflow/compiler/mlir/lite/transforms/passes.h" +#include "tensorflow/compiler/mlir/lite/transforms/dense_to_sparse_pass.h" +#include "tensorflow/compiler/mlir/lite/transforms/pass_registry_utils.h" #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" #include "tensorflow/core/framework/types.pb.h" @@ -67,7 +68,7 @@ absl::Status SparsifyModel(const tflite::ModelT& input_model, } PassManager pm((*module)->getName(), OpPassManager::Nesting::Implicit); - pm.addPass(TFL::CreateDenseToSparsePass()); + pm.addPass(TFL::Create()); if (failed(pm.run(module.get()))) { LOG(ERROR) << "Failed to sparsify: " diff --git a/tensorflow/compiler/mlir/lite/transforms/dense_to_sparse.cc b/tensorflow/compiler/mlir/lite/transforms/dense_to_sparse_pass.cc similarity index 97% rename from tensorflow/compiler/mlir/lite/transforms/dense_to_sparse.cc rename to tensorflow/compiler/mlir/lite/transforms/dense_to_sparse_pass.cc index 73d102a0502f1f..7668a8af959a60 100644 --- a/tensorflow/compiler/mlir/lite/transforms/dense_to_sparse.cc +++ b/tensorflow/compiler/mlir/lite/transforms/dense_to_sparse_pass.cc @@ -14,18 +14,17 @@ limitations under the License. ==============================================================================*/ // This transformation pass convert dense tensor to sparse format. +#include "tensorflow/compiler/mlir/lite/transforms/dense_to_sparse_pass.h" #include "absl/memory/memory.h" #include "Eigen/Core" // from @eigen_archive #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/kernels/internal/utils/sparsity_format_converter.h" -#include "tensorflow/compiler/mlir/lite/transforms/passes.h" //===----------------------------------------------------------------------===// // The DenseToSparse Pass. @@ -35,9 +34,6 @@ namespace TFL { namespace { -#define GEN_PASS_DEF_DENSETOSPARSEPASS -#include "tensorflow/compiler/mlir/lite/transforms/passes.h.inc" - // If sparsity level is below this threshold, keep the tensor in dense format. constexpr float kMinSparsityLevel = 0.3; // Heuristic to check if a block configuration is correct for float constants. @@ -277,13 +273,7 @@ std::vector BuildSparsityParameterAttribute( return compressed_data; } - -struct DenseToSparsePass - : public impl::DenseToSparsePassBase { - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DenseToSparsePass) - - void runOnOperation() override; -}; +} // namespace void DenseToSparsePass::runOnOperation() { func::FuncOp func = getOperation(); @@ -418,7 +408,6 @@ void DenseToSparsePass::runOnOperation() { }); } -} // namespace // Creates an instance of the TensorFlow Lite dialect DenseToSparse pass. std::unique_ptr> CreateDenseToSparsePass() { diff --git a/tensorflow/compiler/mlir/lite/transforms/dense_to_sparse_pass.h b/tensorflow/compiler/mlir/lite/transforms/dense_to_sparse_pass.h new file mode 100644 index 00000000000000..fa39e09c8d0aad --- /dev/null +++ b/tensorflow/compiler/mlir/lite/transforms/dense_to_sparse_pass.h @@ -0,0 +1,80 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This transformation pass convert dense tensor to sparse format. +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_DENSE_TO_SPARSE_PASS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_DENSE_TO_SPARSE_PASS_H_ + +#include "llvm/ADT/APFloat.h" +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/StringRef.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#include "tensorflow/compiler/mlir/lite/transforms/pass.h" +#include "tensorflow/compiler/mlir/lite/transforms/pass_options.h" + +namespace mlir { +namespace TFL { + +// This pass encodes sparse weights in the model in the proper format, and adds +// Densify() op if necessary. The general algorithm is: +// 1. Get list of operands (weights) of an op that can be sparse. +// 2. Get list of supported block configurations of the op. +// 3. Calculate random sparsity of the weight. +// 3.1. If sparsity level is below the encoding threshold, keep in dense. +// 3.2. If sparsity level is above the encoding threshold, go to 4. +// 4. Try to encode the weight with supported block configurations. If the +// weight was pruned with the same block config, the blocked sparsity level +// should match the random sparsity. +// 4.1. Return the matching block config if found. +// 4.2. If no matching block config is found, encode the weight with random +// sparsity, and add Densify() op to fall back to dense execution. + +class DenseToSparsePass + : public Pass { + public: + DenseToSparsePass() = default; + DenseToSparsePass(const DenseToSparsePass &other) {} + + void runOnOperation() final; + + /// Returns the command-line argument attached to this pass. + static llvm::StringRef GetArgument() { return "tfl-dense-to-sparse"; } + + static llvm::StringRef GetDescription() { + return "Convert dense tensor to sparse format."; + } + + /// Returns the derived pass name. + static llvm::StringRef GetName() { return "DenseToSparsePass"; } + + /// Return the dialect that must be loaded in the context before this pass. + void getDependentDialects(::mlir::DialectRegistry ®istry) const override { + registry.insert(); + } + + /// Explicitly declare the TypeID for this class. We declare an explicit + /// private instantiation because Pass classes should only be visible by the + /// current library. + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DenseToSparsePass) +}; + +} // namespace TFL +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_DENSE_TO_SPARSE_PASS_H_ diff --git a/tensorflow/compiler/mlir/lite/transforms/passes.h b/tensorflow/compiler/mlir/lite/transforms/passes.h index 5dc38cc8317003..152fe75f2ff309 100644 --- a/tensorflow/compiler/mlir/lite/transforms/passes.h +++ b/tensorflow/compiler/mlir/lite/transforms/passes.h @@ -181,10 +181,6 @@ std::unique_ptr> CreateDefaultQuantParamsPass(); // Creates an instance of the IdentifyDilatedConvPass. std::unique_ptr> CreateIdentifyDilatedConvPass(); -// Creates an instance of the TensorFlow Lite dialect pass to convert dense -// tensor to sparse format. -std::unique_ptr> CreateDenseToSparsePass(); - // Creates function pass to legalize TF While to TFL While. std::unique_ptr> CreateLegalizeTFWhilePass(); @@ -267,7 +263,6 @@ std::unique_ptr> CreatePartitionedTopologicalSortPass(); #define GEN_PASS_DECL_DEFAULTQUANTPARAMSPASS -#define GEN_PASS_DECL_DENSETOSPARSEPASS #define GEN_PASS_DECL_LEGALIZETFPASS #define GEN_PASS_DECL_LOWERSTATICTENSORLISTPASS #define GEN_PASS_DECL_MODIFYIONODESPASS diff --git a/tensorflow/compiler/mlir/lite/transforms/passes.td b/tensorflow/compiler/mlir/lite/transforms/passes.td index 09ba813055a333..692f4e4d1ded1e 100644 --- a/tensorflow/compiler/mlir/lite/transforms/passes.td +++ b/tensorflow/compiler/mlir/lite/transforms/passes.td @@ -52,35 +52,6 @@ def DecomposeHybridQuantizationPass : Pass<"tfl-decompose-hybrid-quantization", let dependentDialects = ["TFL::TensorFlowLiteDialect"]; } -def DenseToSparsePass : Pass<"tfl-dense-to-sparse", "mlir::func::FuncOp"> { - let summary = "Convert dense tensor to sparse format."; - let description = [{ - This pass encodes sparse weights in the model in the proper format, and adds - Densify() op if necessary. The general algorithm is: - 1. Get list of operands (weights) of an op that can be sparse. - 2. Get list of supported block configurations of the op. - 3. Calculate random sparsity of the weight. - 3.1. If sparsity level is below the encoding threshold, keep in dense. - 3.2. If sparsity level is above the encoding threshold, go to 4. - 4. Try to encode the weight with supported block configurations. If the - weight was pruned with the same block config, the blocked sparsity level - should match the random sparsity. - 4.1. Return the matching block config if found. - 4.2. If no matching block config is found, encode the weight with random - sparsity, and add Densify() op to fall back to dense execution. - }]; - let constructor = "CreateDenseToSparsePass()"; - let dependentDialects = ["TFL::TensorFlowLiteDialect"]; - let options = [ - Option<"default_min_", "default-min", "double", "-1.0", - "Default minimum value for TFLite quantization">, - Option<"default_max_", "default-max", "double", "1.0", - "Default maximum value for TFLite quantization">, - Option<"is_signed_", "is-signed", "bool", "false", - "Is the corresponding integer signed">, - ]; -} - def IdentifyDilatedConvPass : Pass<"tfl-identify-dilated-conv", "mlir::func::FuncOp"> { let summary = "Convert dense tensor to sparse format."; let constructor = "CreateIdentifyDilatedConvPass()"; From b0b0fb84fc873330fc34703fae0a79cccb265a22 Mon Sep 17 00:00:00 2001 From: Rohan Jain Date: Mon, 23 Sep 2024 11:56:39 -0700 Subject: [PATCH 146/483] Fixes layout for int4 while loading weights on XLA PiperOrigin-RevId: 677884790 --- .../xla/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc | 10 ++++++++++ .../xla/xla/service/generic_transfer_manager.cc | 12 ++++++++++++ .../xla/xla/service/generic_transfer_manager.h | 3 +++ .../xla/xla/service/generic_transfer_manager_test.cc | 8 ++++++++ 4 files changed, 33 insertions(+) diff --git a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc index cebfa6879fffcd..359b621349a9e6 100644 --- a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc +++ b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc @@ -1589,5 +1589,15 @@ TEST(StreamExecutorGpuClientTest, nullptr); } +TEST(StreamExecutorGpuClientTest, GetDefaultLayout) { + TF_ASSERT_OK_AND_ASSIGN(auto client, + GetStreamExecutorGpuClient(GpuClientOptions())); + auto shape = ShapeUtil::MakeShape(S4, {2, 2}); + TF_ASSERT_OK_AND_ASSIGN( + auto layout, + client->GetDefaultLayout(shape.element_type(), shape.dimensions())); + EXPECT_EQ(layout.element_size_in_bits(), 4); +} + } // namespace } // namespace xla diff --git a/third_party/xla/xla/service/generic_transfer_manager.cc b/third_party/xla/xla/service/generic_transfer_manager.cc index 436bf144e6a7cf..85168eb9a969c9 100644 --- a/third_party/xla/xla/service/generic_transfer_manager.cc +++ b/third_party/xla/xla/service/generic_transfer_manager.cc @@ -26,6 +26,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/str_format.h" #include "absl/types/span.h" +#include "xla/layout_util.h" #include "xla/literal.h" #include "xla/primitive_util.h" #include "xla/service/shaped_buffer.h" @@ -300,4 +301,15 @@ Shape GenericTransferManager::HostShapeToDeviceShape( return device_shape; } +absl::StatusOr GenericTransferManager::ChooseCompactLayoutForShape( + const Shape& host_shape) const { + Shape compact_shape = LayoutUtil::GetWithDefaultLayout(host_shape); + if (PackSubbyteTypes() && + primitive_util::IsSubByteNonPredType(compact_shape.element_type())) { + compact_shape.mutable_layout()->set_element_size_in_bits( + primitive_util::BitWidth(compact_shape.element_type())); + } + return compact_shape; +} + } // namespace xla diff --git a/third_party/xla/xla/service/generic_transfer_manager.h b/third_party/xla/xla/service/generic_transfer_manager.h index 3503cff66b7dc0..22a2178792b16e 100644 --- a/third_party/xla/xla/service/generic_transfer_manager.h +++ b/third_party/xla/xla/service/generic_transfer_manager.h @@ -83,6 +83,9 @@ class GenericTransferManager : public TransferManager { Shape HostShapeToDeviceShape(const Shape& host_shape) const override; + absl::StatusOr ChooseCompactLayoutForShape( + const Shape& host_shape) const override; + private: // Transfer a memory block of the given size from the device source into the // 'destination' buffer. diff --git a/third_party/xla/xla/service/generic_transfer_manager_test.cc b/third_party/xla/xla/service/generic_transfer_manager_test.cc index eb8cb7afa85004..8153090dbcbf55 100644 --- a/third_party/xla/xla/service/generic_transfer_manager_test.cc +++ b/third_party/xla/xla/service/generic_transfer_manager_test.cc @@ -181,5 +181,13 @@ TEST_F(GenericTransferManagerTest, TransferLiteralFromDeviceInt4) { } } +TEST_F(GenericTransferManagerTest, ChooseCompactLayoutForShape) { + auto shape = ShapeUtil::MakeShape(S4, {2, 2}); + TF_ASSERT_OK_AND_ASSIGN(auto compact_shape, + transfer_manager_.ChooseCompactLayoutForShape(shape)); + EXPECT_TRUE(Shape::Equal().IgnoreLayout()(compact_shape, shape)); + EXPECT_EQ(compact_shape.layout().element_size_in_bits(), 4); +} + } // namespace } // namespace xla From fa1bee9310d69136e05f2c0d517ecb6c63d0bb2b Mon Sep 17 00:00:00 2001 From: Henning Becker Date: Mon, 23 Sep 2024 12:04:27 -0700 Subject: [PATCH 147/483] Remove GpuCollectives backend-agnostic API header - Rename `GpuCollectives` to `CudaCollectvies` and move header into `stream_executor/cuda` - Fix up the only user of the `GpuCollectives` APIs (`CudaExecutor`) - Remove ROCm implementation since it's an unused stub - Remove unused dependencies from SYCL implementation - Add basic test for `CudaCollectives` PiperOrigin-RevId: 677888252 --- .../xla/xla/service/gpu/runtime/nccl_api.cc | 2 + .../xla/xla/service/gpu/runtime/nccl_api.h | 4 ++ .../xla/service/gpu/runtime/nccl_api_stub.cc | 2 + .../xla/xla/stream_executor/cuda/BUILD | 62 +++++++++++++++++-- .../stream_executor/cuda/cuda_collectives.cc | 20 ++---- .../cuda_collectives.h} | 8 +-- .../cuda_collectives_stub.cc} | 19 +++--- .../cuda/cuda_collectives_test.cc | 61 ++++++++++++++++++ .../xla/stream_executor/cuda/cuda_executor.cc | 7 ++- .../xla/stream_executor/cuda/cuda_executor.h | 7 ++- third_party/xla/xla/stream_executor/gpu/BUILD | 10 --- .../xla/xla/stream_executor/rocm/BUILD | 16 ----- .../xla/stream_executor/rocm/rocm_executor.cc | 1 - .../xla/xla/stream_executor/sycl/BUILD | 1 - 14 files changed, 151 insertions(+), 69 deletions(-) rename third_party/xla/xla/stream_executor/{gpu/gpu_collectives.h => cuda/cuda_collectives.h} (89%) rename third_party/xla/xla/stream_executor/{rocm/rocm_collectives.cc => cuda/cuda_collectives_stub.cc} (56%) create mode 100644 third_party/xla/xla/stream_executor/cuda/cuda_collectives_test.cc diff --git a/third_party/xla/xla/service/gpu/runtime/nccl_api.cc b/third_party/xla/xla/service/gpu/runtime/nccl_api.cc index 15949ac9cae999..8b6bd101752e66 100644 --- a/third_party/xla/xla/service/gpu/runtime/nccl_api.cc +++ b/third_party/xla/xla/service/gpu/runtime/nccl_api.cc @@ -349,6 +349,8 @@ NcclApi* NcclApi::Default() { return nccl_api; } +bool NcclApi::HasNcclSupport() { return true; } + static_assert(NCCL_UNIQUE_ID_BYTES == NcclCliqueId::kSize, "size of nccl unique id must match the clique id size"); diff --git a/third_party/xla/xla/service/gpu/runtime/nccl_api.h b/third_party/xla/xla/service/gpu/runtime/nccl_api.h index 76747b64f703c3..813a940052a36b 100644 --- a/third_party/xla/xla/service/gpu/runtime/nccl_api.h +++ b/third_party/xla/xla/service/gpu/runtime/nccl_api.h @@ -60,6 +60,10 @@ class NcclApi { // NCCL or a stub if XLA compiled without NCCL or CUDA support. static NcclApi* Default(); + // Returns true if XLA is compiled with NCCL support, otherwise returns false. + // If false, Default() will return a stub implementation. + static bool HasNcclSupport(); + // Forward declarations of opaque structs corresponding to underlying platform // types (also defined as opaque structs). struct NcclComm; diff --git a/third_party/xla/xla/service/gpu/runtime/nccl_api_stub.cc b/third_party/xla/xla/service/gpu/runtime/nccl_api_stub.cc index b0cfad8fc23dfe..c3934e02814d76 100644 --- a/third_party/xla/xla/service/gpu/runtime/nccl_api_stub.cc +++ b/third_party/xla/xla/service/gpu/runtime/nccl_api_stub.cc @@ -170,4 +170,6 @@ NcclApi* NcclApi::Default() { return nccl_api; } +bool NcclApi::HasNcclSupport() { return false; } + } // namespace xla::gpu diff --git a/third_party/xla/xla/stream_executor/cuda/BUILD b/third_party/xla/xla/stream_executor/cuda/BUILD index 651eda5a62e39b..9b900305327f4f 100644 --- a/third_party/xla/xla/stream_executor/cuda/BUILD +++ b/third_party/xla/xla/stream_executor/cuda/BUILD @@ -227,21 +227,72 @@ cuda_only_cc_library( cuda_only_cc_library( name = "cuda_collectives", - srcs = ["cuda_collectives.cc"], - defines = if_nccl(["STREAM_EXECUTOR_GPU_ENABLE_XCCL"]), + hdrs = ["cuda_collectives.h"], + deps = if_nccl( + [":cuda_collectives_impl"], + [":cuda_collectives_stub"], + ) + [ + "//xla/stream_executor/gpu:context", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + ], +) + +cc_library( + name = "cuda_collectives_impl", + srcs = [ + "cuda_collectives.cc", + "cuda_collectives.h", + ], + tags = [ + "gpu", + "manual", + ], deps = [ ":cuda_driver", "//xla/stream_executor/gpu:context", - "//xla/stream_executor/gpu:gpu_collectives_header", "//xla/stream_executor/gpu:scoped_activate_context", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", + "@local_config_nccl//:nccl", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:numbers", "@local_tsl//tsl/platform:statusor", - ] + if_nccl(["@local_config_nccl//:nccl"]), + ], +) + +cc_library( + name = "cuda_collectives_stub", + srcs = [ + "cuda_collectives.h", + "cuda_collectives_stub.cc", + ], + deps = [ + "//xla/stream_executor/gpu:context", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + ], +) + +xla_test( + name = "cuda_collectives_test", + srcs = ["cuda_collectives_test.cc"], + backends = ["gpu_any"], + deps = [ + ":cuda_collectives", + "//xla/service/gpu/runtime:nccl_api", + "//xla/stream_executor:platform", + "//xla/stream_executor:platform_manager", + "//xla/stream_executor:stream_executor_h", + "//xla/stream_executor/gpu:gpu_executor_header", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:status_matchers", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", + ], ) xla_test( @@ -823,7 +874,7 @@ cuda_only_cc_library( "cuda_executor.h", ], deps = [ - ":cuda_collectives", # buildcleaner: keep + ":cuda_collectives", ":cuda_diagnostics", ":cuda_driver", ":cuda_event", # buildcleaner: keep @@ -847,7 +898,6 @@ cuda_only_cc_library( "//xla/stream_executor:semantic_version", "//xla/stream_executor:stream_executor_h", "//xla/stream_executor/gpu:context", - "//xla/stream_executor/gpu:gpu_collectives_header", "//xla/stream_executor/gpu:gpu_command_buffer", "//xla/stream_executor/gpu:gpu_diagnostics_header", "//xla/stream_executor/gpu:gpu_driver_header", diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_collectives.cc b/third_party/xla/xla/stream_executor/cuda/cuda_collectives.cc index 281707de8a11ec..829f1a1eff474f 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_collectives.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_collectives.cc @@ -13,30 +13,27 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "xla/stream_executor/cuda/cuda_collectives.h" + #include #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_format.h" +#include "third_party/nccl/nccl.h" #include "xla/stream_executor/gpu/context.h" -#include "xla/stream_executor/gpu/gpu_collectives.h" #include "xla/stream_executor/gpu/scoped_activate_context.h" #include "tsl/platform/logging.h" #include "tsl/platform/numbers.h" -#ifdef STREAM_EXECUTOR_GPU_ENABLE_XCCL -#include "third_party/nccl/nccl.h" -#endif // STREAM_EXECUTOR_GPU_ENABLE_XCCL - namespace stream_executor::gpu { -/* static */ absl::StatusOr GpuCollectives::CollectiveMemoryAllocate( +/* static */ absl::StatusOr CudaCollectives::CollectiveMemoryAllocate( Context* context, uint64_t bytes) { if (bytes == 0) return nullptr; ScopedActivateContext activated(context); -#ifdef STREAM_EXECUTOR_GPU_ENABLE_XCCL void* ptr = nullptr; ncclResult_t res = ncclMemAlloc(&ptr, bytes); if (res != ncclSuccess) { @@ -49,16 +46,12 @@ namespace stream_executor::gpu { VLOG(2) << "Allocated collective memory " << ptr << " for context " << context << " of " << bytes << " bytes"; return ptr; -#else - return absl::FailedPreconditionError("XLA was compiled without NCCL support"); -#endif } -/* static */ absl::Status GpuCollectives::CollectiveMemoryDeallocate( +/* static */ absl::Status CudaCollectives::CollectiveMemoryDeallocate( Context* context, void* location) { ScopedActivateContext activation(context); -#ifdef STREAM_EXECUTOR_GPU_ENABLE_XCCL ncclResult_t res = ncclMemFree(location); if (res != ncclSuccess) { return absl::InternalError(absl::StrFormat( @@ -70,9 +63,6 @@ namespace stream_executor::gpu { VLOG(2) << "Deallocated collective memory " << location << " for context " << context; return absl::OkStatus(); -#else - return absl::FailedPreconditionError("XLA was compiled without NCCL support"); -#endif } } // namespace stream_executor::gpu diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_collectives.h b/third_party/xla/xla/stream_executor/cuda/cuda_collectives.h similarity index 89% rename from third_party/xla/xla/stream_executor/gpu/gpu_collectives.h rename to third_party/xla/xla/stream_executor/cuda/cuda_collectives.h index 188931312fbe38..4574a59871aa37 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_collectives.h +++ b/third_party/xla/xla/stream_executor/cuda/cuda_collectives.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_STREAM_EXECUTOR_GPU_GPU_COLLECTIVES_H_ -#define XLA_STREAM_EXECUTOR_GPU_GPU_COLLECTIVES_H_ +#ifndef XLA_STREAM_EXECUTOR_CUDA_CUDA_COLLECTIVES_H_ +#define XLA_STREAM_EXECUTOR_CUDA_CUDA_COLLECTIVES_H_ #include @@ -24,7 +24,7 @@ limitations under the License. namespace stream_executor::gpu { -struct GpuCollectives { +struct CudaCollectives { // Allocates a collective device memory space of size bytes associated with // the given context. // @@ -42,4 +42,4 @@ struct GpuCollectives { } // namespace stream_executor::gpu -#endif // XLA_STREAM_EXECUTOR_GPU_GPU_COLLECTIVES_H_ +#endif // XLA_STREAM_EXECUTOR_CUDA_CUDA_COLLECTIVES_H_ diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_collectives.cc b/third_party/xla/xla/stream_executor/cuda/cuda_collectives_stub.cc similarity index 56% rename from third_party/xla/xla/stream_executor/rocm/rocm_collectives.cc rename to third_party/xla/xla/stream_executor/cuda/cuda_collectives_stub.cc index 2b993bb25295e0..f4a403ca1d1894 100644 --- a/third_party/xla/xla/stream_executor/rocm/rocm_collectives.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_collectives_stub.cc @@ -17,21 +17,20 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" -#include "xla/stream_executor/gpu/gpu_collectives.h" -#include "xla/stream_executor/gpu/gpu_driver.h" +#include "xla/stream_executor/cuda/cuda_collectives.h" +#include "xla/stream_executor/gpu/context.h" namespace stream_executor::gpu { -absl::StatusOr GpuCollectives::CollectiveMemoryAllocate(Context* context, - uint64_t bytes) { - return absl::UnimplementedError( - "Feature not supported on ROCm platform (CollectiveMemoryAllocate)"); +/* static */ absl::StatusOr CudaCollectives::CollectiveMemoryAllocate( + Context* context, uint64_t bytes) { + if (bytes == 0) return nullptr; + return absl::FailedPreconditionError("XLA was compiled without NCCL support"); } -absl::Status GpuCollectives::CollectiveMemoryDeallocate(Context* context, - void* location) { - return absl::UnimplementedError( - "Feature not supported on ROCm platform (CollectiveMemoryDeallocate)"); +/* static */ absl::Status CudaCollectives::CollectiveMemoryDeallocate( + Context* context, void* location) { + return absl::FailedPreconditionError("XLA was compiled without NCCL support"); } } // namespace stream_executor::gpu diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_collectives_test.cc b/third_party/xla/xla/stream_executor/cuda/cuda_collectives_test.cc new file mode 100644 index 00000000000000..e467f37aadebb8 --- /dev/null +++ b/third_party/xla/xla/stream_executor/cuda/cuda_collectives_test.cc @@ -0,0 +1,61 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/stream_executor/cuda/cuda_collectives.h" + +#include + +#include +#include +#include "xla/service/gpu/runtime/nccl_api.h" +#include "xla/stream_executor/gpu/gpu_executor.h" +#include "xla/stream_executor/platform.h" +#include "xla/stream_executor/platform_manager.h" +#include "xla/stream_executor/stream_executor.h" +#include "tsl/platform/status_matchers.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/test.h" + +namespace stream_executor::gpu { +namespace { +using ::tsl::testing::IsOk; +using ::tsl::testing::IsOkAndHolds; + +TEST(CudaCollectivesTest, CollectiveMemoryAllocation) { + if (!xla::gpu::NcclApi::HasNcclSupport()) { + GTEST_SKIP() << "Compiled without NCCL support"; + } + + TF_ASSERT_OK_AND_ASSIGN(Platform * platform, + PlatformManager::PlatformWithName("CUDA")); + TF_ASSERT_OK_AND_ASSIGN(StreamExecutor * executor, + platform->ExecutorForDevice(0)); + GpuExecutor* gpu_executor = ExtractGpuExecutor(executor); + + constexpr size_t kAllocateSize = 1024; + TF_ASSERT_OK_AND_ASSIGN(void* memory, + CudaCollectives::CollectiveMemoryAllocate( + gpu_executor->gpu_context(), kAllocateSize)); + + EXPECT_THAT(gpu_executor->GetPointerMemorySpace(memory), + IsOkAndHolds(MemoryType::kDevice)); + + EXPECT_THAT(CudaCollectives::CollectiveMemoryDeallocate( + gpu_executor->gpu_context(), memory), + IsOk()); +} + +} // namespace +} // namespace stream_executor::gpu diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc b/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc index 27486707b5707c..475899fa6b1bb2 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc @@ -39,6 +39,7 @@ limitations under the License. #include "third_party/gpus/cuda/include/cuda.h" #include "xla/stream_executor/blas.h" #include "xla/stream_executor/command_buffer.h" +#include "xla/stream_executor/cuda/cuda_collectives.h" #include "xla/stream_executor/cuda/cuda_event.h" #include "xla/stream_executor/cuda/cuda_platform_id.h" #include "xla/stream_executor/cuda/cuda_runtime.h" @@ -52,7 +53,6 @@ limitations under the License. #include "xla/stream_executor/event_based_timer.h" #include "xla/stream_executor/fft.h" #include "xla/stream_executor/gpu/context.h" -#include "xla/stream_executor/gpu/gpu_collectives.h" #include "xla/stream_executor/gpu/gpu_command_buffer.h" #include "xla/stream_executor/gpu/gpu_driver.h" #include "xla/stream_executor/gpu/gpu_event.h" @@ -441,11 +441,12 @@ absl::Status CudaExecutor::GetKernelMetadata(GpuKernel* cuda_kernel, DeviceMemoryBase CudaExecutor::Allocate(uint64_t size, int64_t memory_space) { if (memory_space == 1) { - auto result = GpuCollectives::CollectiveMemoryAllocate(gpu_context(), size); + auto result = + CudaCollectives::CollectiveMemoryAllocate(gpu_context(), size); if (!result.ok()) { LOG(ERROR) << result.status(); } - return DeviceMemoryBase(*result, size); + return DeviceMemoryBase(nullptr, 0); } else if (memory_space == static_cast(stream_executor::MemoryType::kHost)) { return DeviceMemoryBase(GpuDriver::HostAllocate(gpu_context(), size), size); diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_executor.h b/third_party/xla/xla/stream_executor/cuda/cuda_executor.h index 60ecd23d05d11d..e467c2a3d432be 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_executor.h +++ b/third_party/xla/xla/stream_executor/cuda/cuda_executor.h @@ -22,6 +22,7 @@ limitations under the License. #include #include #include +#include #include #include "absl/base/thread_annotations.h" @@ -34,13 +35,13 @@ limitations under the License. #include "absl/types/span.h" #include "xla/stream_executor/blas.h" #include "xla/stream_executor/command_buffer.h" +#include "xla/stream_executor/cuda/cuda_collectives.h" #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/dnn.h" #include "xla/stream_executor/event.h" #include "xla/stream_executor/event_based_timer.h" #include "xla/stream_executor/fft.h" -#include "xla/stream_executor/gpu/gpu_collectives.h" #include "xla/stream_executor/gpu/gpu_driver.h" #include "xla/stream_executor/gpu/gpu_event.h" #include "xla/stream_executor/gpu/gpu_executor.h" @@ -66,11 +67,11 @@ class CudaExecutor : public GpuExecutor { bool SynchronizeAllActivity() override; absl::StatusOr CollectiveMemoryAllocate(uint64_t size) override { - return GpuCollectives::CollectiveMemoryAllocate(gpu_context(), size); + return CudaCollectives::CollectiveMemoryAllocate(gpu_context(), size); } absl::Status CollectiveMemoryDeallocate(void* location) override { - return GpuCollectives::CollectiveMemoryDeallocate(gpu_context(), location); + return CudaCollectives::CollectiveMemoryDeallocate(gpu_context(), location); } absl::StatusOr> CreateEventBasedTimer( diff --git a/third_party/xla/xla/stream_executor/gpu/BUILD b/third_party/xla/xla/stream_executor/gpu/BUILD index d98dcfe248d4d9..711104c6ffc673 100644 --- a/third_party/xla/xla/stream_executor/gpu/BUILD +++ b/third_party/xla/xla/stream_executor/gpu/BUILD @@ -160,16 +160,6 @@ gpu_only_cc_library( deps = ["@com_google_absl//absl/status:statusor"], ) -gpu_only_cc_library( - name = "gpu_collectives_header", - hdrs = ["gpu_collectives.h"], - deps = [ - ":context", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - ], -) - gpu_only_cc_library( name = "gpu_driver_header", hdrs = ["gpu_driver.h"], diff --git a/third_party/xla/xla/stream_executor/rocm/BUILD b/third_party/xla/xla/stream_executor/rocm/BUILD index e6fe7a4ef11566..491e558b1a4e4e 100644 --- a/third_party/xla/xla/stream_executor/rocm/BUILD +++ b/third_party/xla/xla/stream_executor/rocm/BUILD @@ -107,19 +107,6 @@ cc_library( ]), ) -cc_library( - name = "rocm_collectives", - srcs = if_rocm_is_configured(["rocm_collectives.cc"]), - deps = if_rocm_is_configured([ - # keep sorted - "//xla/stream_executor/gpu:gpu_collectives_header", - "//xla/stream_executor/gpu:gpu_driver_header", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - ]), -) - cc_library( name = "rocm_event", srcs = if_rocm_is_configured(["rocm_event.cc"]), @@ -141,7 +128,6 @@ cc_library( hdrs = if_rocm_is_configured(["rocm_executor.h"]), deps = if_rocm_is_configured([ # keep sorted - ":rocm_collectives", ":rocm_diagnostics", ":rocm_driver", ":rocm_event", @@ -164,7 +150,6 @@ cc_library( "//xla/stream_executor:plugin_registry", "//xla/stream_executor:semantic_version", "//xla/stream_executor/gpu:context", - "//xla/stream_executor/gpu:gpu_collectives_header", "//xla/stream_executor/gpu:gpu_command_buffer", "//xla/stream_executor/gpu:gpu_diagnostics_header", "//xla/stream_executor/gpu:gpu_driver_header", @@ -227,7 +212,6 @@ cc_library( visibility = ["//visibility:public"], deps = if_rocm_is_configured([ # keep sorted - ":rocm_collectives", ":rocm_driver", ":rocm_executor", ":rocm_platform_id", diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_executor.cc b/third_party/xla/xla/stream_executor/rocm/rocm_executor.cc index 4b3687fe92a5f4..c7bca9c3851c3c 100644 --- a/third_party/xla/xla/stream_executor/rocm/rocm_executor.cc +++ b/third_party/xla/xla/stream_executor/rocm/rocm_executor.cc @@ -47,7 +47,6 @@ limitations under the License. #include "xla/stream_executor/event_based_timer.h" #include "xla/stream_executor/fft.h" #include "xla/stream_executor/gpu/context.h" -#include "xla/stream_executor/gpu/gpu_collectives.h" #include "xla/stream_executor/gpu/gpu_command_buffer.h" #include "xla/stream_executor/gpu/gpu_diagnostics.h" #include "xla/stream_executor/gpu/gpu_driver.h" diff --git a/third_party/xla/xla/stream_executor/sycl/BUILD b/third_party/xla/xla/stream_executor/sycl/BUILD index 86c00a09d67028..b185209d8bab26 100644 --- a/third_party/xla/xla/stream_executor/sycl/BUILD +++ b/third_party/xla/xla/stream_executor/sycl/BUILD @@ -47,7 +47,6 @@ cc_library( "//xla/stream_executor/gpu:gpu_types_header", "//xla/stream_executor/gpu:gpu_driver_header", "//xla/stream_executor/gpu:gpu_executor_header", - "//xla/stream_executor/gpu:gpu_collectives_header", "@local_tsl//tsl/platform:errors", ]), alwayslink = True, # Registers itself with the PlatformManager. From 9e1e4354fe2283c39e5d79a3391d0f1d19958f6b Mon Sep 17 00:00:00 2001 From: Vlad Sytchenko Date: Mon, 23 Sep 2024 12:13:51 -0700 Subject: [PATCH 148/483] Automated g4 rollback Reverts 19faa81daa36027232a7858e46cfd45d36da99be PiperOrigin-RevId: 677891613 --- third_party/xla/xla/hlo/ir/hlo_instruction.cc | 18 +- third_party/xla/xla/hlo/ir/hlo_instruction.h | 5 +- third_party/xla/xla/service/BUILD | 35 - .../xla/service/infeed_token_propagation.cc | 459 ------------- .../xla/service/infeed_token_propagation.h | 45 -- .../service/infeed_token_propagation_test.cc | 601 ------------------ .../while_loop_invariant_code_motion.cc | 1 - 7 files changed, 5 insertions(+), 1159 deletions(-) delete mode 100644 third_party/xla/xla/service/infeed_token_propagation.cc delete mode 100644 third_party/xla/xla/service/infeed_token_propagation.h delete mode 100644 third_party/xla/xla/service/infeed_token_propagation_test.cc diff --git a/third_party/xla/xla/hlo/ir/hlo_instruction.cc b/third_party/xla/xla/hlo/ir/hlo_instruction.cc index 8822767a05704b..ad23f22909db5e 100644 --- a/third_party/xla/xla/hlo/ir/hlo_instruction.cc +++ b/third_party/xla/xla/hlo/ir/hlo_instruction.cc @@ -3399,30 +3399,18 @@ const PtrVec& HloInstruction::branch_computations() const { return called_computations(); } -int32_t HloInstruction::branch_count() const { +int HloInstruction::branch_count() const { CHECK(HloOpcode::kConditional == opcode_); return called_computations().size(); } -HloComputation* HloInstruction::branch_computation(int32_t b) const { - CHECK_EQ(HloOpcode::kConditional, opcode_); +HloComputation* HloInstruction::branch_computation(int b) const { + CHECK(HloOpcode::kConditional == opcode_); CHECK_GE(b, 0); CHECK_LT(b, called_computations().size()); return called_computations()[b]; } -int32_t HloInstruction::branch_index(HloComputation* computation) const { - CHECK_EQ(HloOpcode::kConditional, opcode_); - CHECK_NE(computation, nullptr); - for (int32_t idx = 0; idx < branch_count(); idx++) { - if (branch_computation(idx) == computation) { - return idx; - } - } - LOG(FATAL) << absl::StrFormat("Conditional %s does not contain branch %s", - name(), computation->name()); -} - void HloInstruction::set_branch_computation(int b, HloComputation* computation) { CHECK_EQ(HloOpcode::kConditional, opcode_); diff --git a/third_party/xla/xla/hlo/ir/hlo_instruction.h b/third_party/xla/xla/hlo/ir/hlo_instruction.h index b6ba1bd8d37571..42729daec64df3 100644 --- a/third_party/xla/xla/hlo/ir/hlo_instruction.h +++ b/third_party/xla/xla/hlo/ir/hlo_instruction.h @@ -1808,9 +1808,8 @@ class HloInstruction { // // Precondition: The instruction is a Conditional instruction. const PtrVec& branch_computations() const; - int32_t branch_count() const; - HloComputation* branch_computation(int32_t b) const; - int32_t branch_index(HloComputation* computation) const; + int branch_count() const; + HloComputation* branch_computation(int b) const; // Sets a branch HloComputation for Conditional. // The setter should only be called by HloModule or HloComputation methods. // diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD index 53d6b29fc5bcd7..bda32e96ee0fac 100644 --- a/third_party/xla/xla/service/BUILD +++ b/third_party/xla/xla/service/BUILD @@ -8543,39 +8543,4 @@ xla_cc_test( ], ) -cc_library( - name = "infeed_token_propagation", - srcs = ["infeed_token_propagation.cc"], - hdrs = ["infeed_token_propagation.h"], - deps = [ - ":hlo_dce", - ":tuple_simplifier", - "//xla:shape_util", - "//xla:util", - "//xla/hlo/ir:hlo", - "//xla/hlo/pass:hlo_pass", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_cc_test( - name = "infeed_token_propagation_test", - srcs = ["infeed_token_propagation_test.cc"], - deps = [ - ":infeed_token_propagation", - "//xla/hlo/ir:hlo", - "//xla/hlo/utils:hlo_matchers", - "//xla/tests:hlo_test_base", - "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/platform:statusor", - ], -) - exports_files(["xla_aot_compile_test_gpu_target_config.prototxt"]) diff --git a/third_party/xla/xla/service/infeed_token_propagation.cc b/third_party/xla/xla/service/infeed_token_propagation.cc deleted file mode 100644 index 11c4d6bb7d0c5c..00000000000000 --- a/third_party/xla/xla/service/infeed_token_propagation.cc +++ /dev/null @@ -1,459 +0,0 @@ -/* Copyright 2024 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/infeed_token_propagation.h" - -#include -#include -#include -#include - -#include "absl/container/flat_hash_set.h" -#include "absl/log/check.h" -#include "absl/log/log.h" -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "absl/types/span.h" -#include "xla/hlo/ir/hlo_computation.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_module.h" -#include "xla/hlo/ir/hlo_opcode.h" -#include "xla/service/hlo_dce.h" -#include "xla/service/tuple_simplifier.h" -#include "xla/shape.h" -#include "xla/shape_util.h" -#include "xla/util.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/statusor.h" - -namespace xla { -namespace { -bool IsDanglingInfeed(HloInstruction* infeed) { - CHECK(infeed->opcode() == HloOpcode::kInfeed); - if (infeed->has_sharding()) { - // TODO: b/368327832 - Skip handling sharding until it is removed. - return false; - } - - // Check for dangling input token. - if (const HloInstruction* after_all = infeed->operand(0); - after_all->opcode() != HloOpcode::kAfterAll || - after_all->operand_count() != 0) { - return false; - } - - // Check for dangling output token. - for (const HloInstruction* user : infeed->users()) { - if (user->opcode() == HloOpcode::kGetTupleElement && - user->tuple_index() == 1) { - return false; - } - } - - return true; -} - -bool IsDanglingOutfeed(HloInstruction* outfeed) { - CHECK(outfeed->opcode() == HloOpcode::kOutfeed); - if (outfeed->has_sharding()) { - // TODO: b/368327832 - Skip handling sharding until it is removed. - return false; - } - - // Check for dangling input token. - if (const HloInstruction* after_all = outfeed->operand(1); - after_all->opcode() != HloOpcode::kAfterAll || - after_all->operand_count() != 0) { - return false; - } - - // Check for dangling output token. - if (outfeed->user_count() != 0) { - return false; - } - - return true; -} - -HloInstruction* ReconstructTuple(HloInstruction* tuple) { - CHECK(tuple->shape().IsTuple()); - HloComputation* computation = tuple->parent(); - - std::vector gtes; - gtes.resize(tuple->shape().tuple_shapes_size()); - for (int64_t idx = 0; idx < gtes.size(); ++idx) { - gtes[idx] = computation->AddInstruction( - HloInstruction::CreateGetTupleElement(tuple, idx)); - } - - return computation->AddInstruction(HloInstruction::CreateTuple(gtes)); -} - -absl::StatusOr InsertTokenIntoTuple(HloInstruction* tuple, - bool add_token_operand) { - CHECK(tuple->shape().IsTuple()); - HloComputation* computation = tuple->parent(); - - // Recreate the original tuple, we'll need to pass this to all the users. - std::vector original_users = tuple->users(); - HloInstruction* original_tuple = ReconstructTuple(tuple); - for (HloInstruction* original_user : original_users) { - int64_t idx = original_user->operand_index(tuple); - TF_RETURN_IF_ERROR(original_user->ReplaceOperandWith(idx, original_tuple)); - } - - // Append the token to the parameter tuple. - *tuple->mutable_shape()->add_tuple_shapes() = ShapeUtil::MakeTokenShape(); - if (add_token_operand) { - tuple->AppendOperand( - computation->AddInstruction(HloInstruction::CreateToken())); - } - - HloInstruction* input_token_gte = - computation->AddInstruction(HloInstruction::CreateGetTupleElement( - tuple, tuple->shape().tuple_shapes_size() - 1)); - return input_token_gte; -} - -absl::Status CanonicalizeConditionalBranch(HloComputation* branch) { - CHECK(branch->IsConditionalBranchComputation()); - CHECK_EQ(branch->num_parameters(), 1); - - // Tuplify the branch parameter if needed. - HloInstruction* parameter = branch->parameter_instruction(0); - if (!parameter->shape().IsTuple()) { - *parameter->mutable_shape() = - ShapeUtil::MakeTupleShape({parameter->shape()}); - HloInstruction* original = branch->AddInstruction( - HloInstruction::CreateGetTupleElement(parameter, 0)); - TF_RETURN_IF_ERROR(parameter->ReplaceAllUsesWithDifferentShape(original)); - } - - // Tuplify the branch tuple if needed. - HloInstruction* conditional = branch->ConditionalCallInstruction(); - int64_t branch_operand_idx = conditional->branch_index(branch) + 1; - HloInstruction* branch_tuple = - conditional->mutable_operand(branch_operand_idx); - if (!branch_tuple->shape().IsTuple()) { - branch_tuple = conditional->parent()->AddInstruction( - HloInstruction::CreateTuple({branch_tuple})); - TF_RETURN_IF_ERROR(conditional->ReplaceOperandWithDifferentShape( - branch_operand_idx, branch_tuple)); - } - - // Explicitly disjoin computation parameters from branch inputs, so we can - // insert tokens into the input tuple. - if (branch_tuple->opcode() == HloOpcode::kParameter) { - branch_tuple = ReconstructTuple(branch_tuple); - TF_RETURN_IF_ERROR( - conditional->ReplaceOperandWith(branch_operand_idx, branch_tuple)); - } - - // If the computation root is a also a computation parameter, explicitly split - // them, as the input and output tokens cannot be part of the same - // instruction. - HloInstruction* root = branch->root_instruction(); - if (root->opcode() == HloOpcode::kParameter) { - root = ReconstructTuple(root); - branch->set_root_instruction(root); - } - - // ConditionalCanonicalizer should have already turned the conditional output - // to be a tuple. - CHECK(conditional->shape().IsTuple()); - return absl::OkStatus(); -} - -absl::Status CanonicalizeWhileBody(HloComputation* body) { - CHECK(body->IsWhileBodyComputation()); - CHECK_EQ(body->num_parameters(), 1); - - // Tuplify the body parameter if needed. - HloInstruction* parameter = body->parameter_instruction(0); - if (!parameter->shape().IsTuple()) { - *parameter->mutable_shape() = - ShapeUtil::MakeTupleShape({parameter->shape()}); - HloInstruction* original = body->AddInstruction( - HloInstruction::CreateGetTupleElement(parameter, 0)); - TF_RETURN_IF_ERROR(parameter->ReplaceAllUsesWithDifferentShape(original)); - } - - // Tuplify the body root if needed. - HloInstruction* root = body->root_instruction(); - if (!root->shape().IsTuple()) { - root = body->AddInstruction(HloInstruction::CreateTuple({root})); - body->set_root_instruction(root, /*accept_different_shape=*/true); - } - - // Tuplify the condition parameter if needed. - HloInstruction* loop = body->WhileCallInstruction(); - HloComputation* cond = loop->while_condition(); - HloInstruction* cond_parameter = cond->parameter_instruction(0); - if (!cond_parameter->shape().IsTuple()) { - *cond_parameter->mutable_shape() = - ShapeUtil::MakeTupleShape({cond_parameter->shape()}); - HloInstruction* original = cond->AddInstruction( - HloInstruction::CreateGetTupleElement(cond_parameter, 0)); - TF_RETURN_IF_ERROR( - cond_parameter->ReplaceAllUsesWithDifferentShape(original)); - } - - // Tuplify the while instruction if needed. - if (!loop->shape().IsTuple()) { - *loop->mutable_shape() = ShapeUtil::MakeTupleShape({loop->shape()}); - HloInstruction* original = loop->parent()->AddInstruction( - HloInstruction::CreateGetTupleElement(loop, 0)); - TF_RETURN_IF_ERROR(loop->ReplaceAllUsesWithDifferentShape(original)); - } - - // Tuplify the while tuple if needed. - HloInstruction* loop_tuple = loop->mutable_operand(0); - if (!loop_tuple->shape().IsTuple()) { - loop_tuple = loop->parent()->AddInstruction( - HloInstruction::CreateTuple({loop_tuple})); - TF_RETURN_IF_ERROR(loop->ReplaceOperandWithDifferentShape(0, loop_tuple)); - } - - // Explicitly disjoin computation parameters from loop inputs, so we can - // insert tokens into the input tuple. - if (loop_tuple->opcode() == HloOpcode::kParameter) { - loop_tuple = ReconstructTuple(loop_tuple); - TF_RETURN_IF_ERROR(loop->ReplaceOperandWith(0, loop_tuple)); - } - - // If the computation root is a also a computation parameter, explicitly - // split them, as the input and output tokens cannot be part of the same - // instruction. - if (root->opcode() == HloOpcode::kParameter) { - root = ReconstructTuple(root); - body->set_root_instruction(root); - } - - return absl::OkStatus(); -} - -absl::StatusOr> -PropagateTokenThroughConditionalBranch(HloInstruction* instruction, - HloInstruction* input_token, - HloInstruction* output_token) { - // Conditional branches can diverge in inputs, but must converge on outputs. - - // Fixup the branch. - HloComputation* comp = instruction->parent(); - TF_RETURN_IF_ERROR(CanonicalizeConditionalBranch(comp)); - HloInstruction* next_instruction = comp->ConditionalCallInstruction(); - - // Insert the output token into each branch. - for (HloComputation* branch : next_instruction->branch_computations()) { - HloInstruction* root = branch->root_instruction(); - if (branch == comp) { - TF_RETURN_IF_ERROR( - InsertTokenIntoTuple(root, /*add_token_operand=*/false).status()); - root->AppendOperand(output_token); - } else { - TF_RETURN_IF_ERROR( - InsertTokenIntoTuple(root, /*add_token_operand=*/true).status()); - } - } - - // Insert the input token into the branch parameter. - HloInstruction* parameter = comp->parameter_instruction(0); - TF_ASSIGN_OR_RETURN( - HloInstruction * input_token_gte, - InsertTokenIntoTuple(parameter, /*add_token_operand=*/false)); - TF_RETURN_IF_ERROR(input_token->ReplaceAllUsesWith(input_token_gte)); - - // Insert the input token into the branch tuple. - int64_t branch_operand_idx = next_instruction->branch_index(comp) + 1; - HloInstruction* branch_tuple = - next_instruction->mutable_operand(branch_operand_idx); - TF_ASSIGN_OR_RETURN( - HloInstruction * next_input_token_gte, - InsertTokenIntoTuple(branch_tuple, /*add_token_operand=*/true)); - TF_RETURN_IF_ERROR(next_instruction->ReplaceOperandWithDifferentShape( - branch_operand_idx, branch_tuple)); - HloInstruction* next_input_token = - branch_tuple->mutable_operand(next_input_token_gte->tuple_index()); - - // Insert the output token into conditional instruction. - TF_ASSIGN_OR_RETURN( - HloInstruction * next_output_token, - InsertTokenIntoTuple(next_instruction, /*add_token_operand=*/false)); - - return std::make_tuple(next_instruction, next_input_token, next_output_token); -} - -absl::StatusOr> -PropagateTokenThroughWhileBody(HloInstruction* instruction, - HloInstruction* input_token, - HloInstruction* output_token) { - // While loops need to converge on input and output. - - // Fixup the while body. - HloComputation* comp = instruction->parent(); - TF_RETURN_IF_ERROR(CanonicalizeWhileBody(comp)); - HloInstruction* next_instruction = comp->WhileCallInstruction(); - - // Insert the output token into the body root. - HloInstruction* root = comp->root_instruction(); - TF_RETURN_IF_ERROR( - InsertTokenIntoTuple(root, /*add_token_operand=*/false).status()); - root->AppendOperand(output_token); - - // Insert the input token into the body parameter. - HloInstruction* body_parameter = comp->parameter_instruction(0); - TF_ASSIGN_OR_RETURN( - HloInstruction * input_token_gte, - InsertTokenIntoTuple(body_parameter, /*add_token_operand=*/false)); - TF_RETURN_IF_ERROR(input_token->ReplaceAllUsesWith(input_token_gte)); - - // Insert the input token into the condition parameter. - HloComputation* cond = next_instruction->while_condition(); - HloInstruction* cond_parameter = cond->parameter_instruction(0); - TF_RETURN_IF_ERROR( - InsertTokenIntoTuple(cond_parameter, /*add_token_operand=*/false) - .status()); - - // Insert the input token into the while tuple. - HloInstruction* while_tuple = next_instruction->mutable_operand(0); - TF_ASSIGN_OR_RETURN( - HloInstruction * next_input_token, - InsertTokenIntoTuple(while_tuple, /*add_token_operand=*/true)); - TF_RETURN_IF_ERROR( - next_instruction->ReplaceOperandWithDifferentShape(0, while_tuple)); - - // Insert the input token into the while instruction. - TF_ASSIGN_OR_RETURN( - HloInstruction * next_output_token, - InsertTokenIntoTuple(next_instruction, /*add_token_operand=*/false)); - - return std::make_tuple(next_instruction, next_input_token, next_output_token); -} - -absl::Status PropagateToken(HloInstruction* instruction, - HloInstruction* input_token, - HloInstruction* output_token) { - HloComputation* comp = instruction->parent(); - if (comp->IsEntryComputation()) { - // If we propagate through the root instruction, reconstruct the original - // tuple and set that to be root. - if (instruction->IsRoot() && - (instruction->opcode() == HloOpcode::kWhile || - instruction->opcode() == HloOpcode::kConditional)) { - std::vector gtes; - int64_t output_token_idx = output_token->tuple_index(); - for (int64_t idx = 0; idx < instruction->shape().tuple_shapes_size(); - idx++) { - if (idx != output_token_idx) { - gtes.push_back(comp->AddInstruction( - HloInstruction::CreateGetTupleElement(instruction, idx))); - } - } - HloInstruction* original_tuple = - comp->AddInstruction(HloInstruction::CreateTuple(gtes)); - comp->set_root_instruction(original_tuple, - /*accept_different_shape=*/true); - } - return absl::OkStatus(); - } - - HloInstruction* next_instruction = nullptr; - HloInstruction* next_input_token = nullptr; - HloInstruction* next_output_token = nullptr; - if (comp->IsConditionalBranchComputation()) { - // TODO: b/368327832 - Skip handling sharding until it is removed. - if (comp->ConditionalCallInstruction()->has_sharding()) { - return absl::OkStatus(); - } - TF_ASSIGN_OR_RETURN( - std::tie(next_instruction, next_input_token, next_output_token), - PropagateTokenThroughConditionalBranch(instruction, input_token, - output_token)); - } else if (comp->IsWhileBodyComputation()) { - // TODO: b/368327832 - Skip handling sharding until it is removed. - if (comp->WhileCallInstruction()->has_sharding()) { - return absl::OkStatus(); - } - TF_ASSIGN_OR_RETURN( - std::tie(next_instruction, next_input_token, next_output_token), - PropagateTokenThroughWhileBody(instruction, input_token, output_token)); - } else { - // We only expect to encounter computations behind while and conditional - // instructions. In the case of it being behind a while condition, there is - // no way to propagate the output token, as the root only returns a - // predicate. All other computations that could possibly contain infeed - // or outfeed ops should have already been inlined. - VLOG(2) << "Unhandled computation: " << comp->name(); - return absl::OkStatus(); - } - CHECK_NE(next_instruction, nullptr); - CHECK_NE(next_input_token, nullptr); - CHECK_NE(next_output_token, nullptr); - - return PropagateToken(next_instruction, next_input_token, next_output_token); -} -} // namespace - -absl::StatusOr InfeedTokenPropagation::Run( - HloModule* module, - const absl::flat_hash_set& execution_threads) { - VLOG(5) << "Before InfeedTokenPropagation:"; - XLA_VLOG_LINES(5, module->ToString()); - - std::vector dangling_infeeds; - std::vector dangling_outfeeds; - for (HloComputation* computation : - module->MakeNonfusionComputations(execution_threads)) { - if (!computation->IsEntryComputation()) { - for (HloInstruction* instruction : computation->instructions()) { - if (instruction->opcode() == HloOpcode::kInfeed && - IsDanglingInfeed(instruction)) { - VLOG(1) << "Found dangling infeed: " << instruction->ToString(); - dangling_infeeds.push_back(instruction); - } else if (instruction->opcode() == HloOpcode::kOutfeed && - IsDanglingOutfeed(instruction)) { - VLOG(1) << "Found dangling outfeed: " << instruction->ToString(); - dangling_outfeeds.push_back(instruction); - } - } - } - } - - for (HloInstruction* dangling_infeed : dangling_infeeds) { - HloInstruction* input_token = dangling_infeed->mutable_operand(0); - HloInstruction* output_token = dangling_infeed->AddInstruction( - HloInstruction::CreateGetTupleElement(dangling_infeed, 1)); - TF_RETURN_IF_ERROR( - PropagateToken(dangling_infeed, input_token, output_token)); - } - for (HloInstruction* dangling_outfeed : dangling_outfeeds) { - HloInstruction* input_token = dangling_outfeed->mutable_operand(1); - HloInstruction* output_token = dangling_outfeed; - TF_RETURN_IF_ERROR( - PropagateToken(dangling_outfeed, input_token, output_token)); - } - - bool changed = !dangling_infeeds.empty() || !dangling_outfeeds.empty(); - if (changed) { - TF_RETURN_IF_ERROR( - TupleSimplifier().Run(module, execution_threads).status()); - TF_RETURN_IF_ERROR(HloDCE().Run(module, execution_threads).status()); - } - - VLOG(5) << "After InfeedTokenPropagation:"; - XLA_VLOG_LINES(5, module->ToString()); - return changed; -} -} // namespace xla diff --git a/third_party/xla/xla/service/infeed_token_propagation.h b/third_party/xla/xla/service/infeed_token_propagation.h deleted file mode 100644 index cc6994a62a98a9..00000000000000 --- a/third_party/xla/xla/service/infeed_token_propagation.h +++ /dev/null @@ -1,45 +0,0 @@ -/* Copyright 2024 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_INFEED_TOKEN_PROPAGATION_H_ -#define XLA_SERVICE_INFEED_TOKEN_PROPAGATION_H_ - -#include - -#include "absl/container/flat_hash_set.h" -#include "absl/status/statusor.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_module.h" -#include "xla/hlo/pass/hlo_pass_interface.h" - -namespace xla { -// Finds dangling infeed/outfeed tokens inside nested computations and bubbles -// them up through callers until they reach the entry computation. This is -// needed to prepare these computations to be inlined, otherwise the previous -// computation boundaries won't be there to stop infeeds/outfeeds from being -// reordered during scheduling. -// -// This pass assumes the HLO graph is flattened. -class InfeedTokenPropagation : public HloModulePass { - public: - std::string_view name() const override { return "infeed-token-propagation"; } - using HloPassInterface::Run; - absl::StatusOr Run( - HloModule* module, - const absl::flat_hash_set& execution_threads) override; -}; -} // namespace xla - -#endif // XLA_SERVICE_INFEED_TOKEN_PROPAGATION_H_ diff --git a/third_party/xla/xla/service/infeed_token_propagation_test.cc b/third_party/xla/xla/service/infeed_token_propagation_test.cc deleted file mode 100644 index 8c1024253868d6..00000000000000 --- a/third_party/xla/xla/service/infeed_token_propagation_test.cc +++ /dev/null @@ -1,601 +0,0 @@ -/* Copyright 2024 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/infeed_token_propagation.h" - -#include -#include - -#include -#include -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/utils/hlo_matchers.h" -#include "xla/tests/hlo_test_base.h" -#include "tsl/platform/statusor.h" - -namespace op = xla::testing::opcode_matchers; - -namespace xla { -namespace { - -class InfeedTokenPropagationTest : public HloTestBase { - protected: - InfeedTokenPropagationTest() = default; -}; - -TEST_F(InfeedTokenPropagationTest, EntryComputationInfeed) { - constexpr std::string_view hlo = R"( -HloModule main - -ENTRY main { - token.0 = after-all() - infeed.0 = (s32[], token[]) infeed(token.0) - ROOT gte.0 = get-tuple-element(infeed.0), index=0 -} -)"; - TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo)); - InfeedTokenPropagation itp; - TF_ASSERT_OK_AND_ASSIGN(bool changed, itp.Run(module.get())); - EXPECT_FALSE(changed); -} - -TEST_F(InfeedTokenPropagationTest, EntryComputationOutfeed) { - constexpr std::string_view hlo = R"( -HloModule main - -ENTRY main { - arg.0 = s32[] parameter(0) - tuple.0 = tuple(arg.0) - token.0 = after-all() - outfeed.0 = token[] outfeed(tuple.0, token.0), outfeed_shape=(s32[]) - ROOT tuple.1 = tuple() -} -)"; - TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo)); - InfeedTokenPropagation itp; - TF_ASSERT_OK_AND_ASSIGN(bool changed, itp.Run(module.get())); - EXPECT_FALSE(changed); -} - -TEST_F(InfeedTokenPropagationTest, ConditionalInfeed) { - constexpr std::string_view hlo = R"( -HloModule main - -true_comp { - arg.0 = () parameter(0) - token.0 = after-all() - infeed.0 = (s32[], token[]) infeed(token.0) - ROOT tuple.0 = tuple() -} - -false_comp { - arg.0 = () parameter(0) - ROOT tuple.0 = tuple() -} - -ENTRY main { - pred.0 = pred[] constant(true) - true_tuple.0 = tuple() - false_tuple.0 = tuple() - ROOT cond.0 = () conditional(pred.0, true_tuple.0, false_tuple.0), true_computation=true_comp, false_computation=false_comp -} -)"; - TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo)); - InfeedTokenPropagation itp; - TF_ASSERT_OK_AND_ASSIGN(bool changed, itp.Run(module.get())); - EXPECT_TRUE(changed); - - // The infeed output token should have propagated through the conditional. - HloInstruction* cond = FindInstruction(module.get(), "cond.0"); - EXPECT_EQ(cond->shape().tuple_shapes_size(), 1); - EXPECT_TRUE(cond->shape().tuple_shapes()[0].IsToken()); - - // The infeed input token should have propagated through the true tuple. - HloInstruction* true_tuple = FindInstruction(module.get(), "true_tuple.0"); - EXPECT_EQ(true_tuple->shape().tuple_shapes_size(), 1); - EXPECT_TRUE(true_tuple->shape().tuple_shapes()[0].IsToken()); - - // The infeed input token should not have propagated through the false tuple. - HloInstruction* false_tuple = FindInstruction(module.get(), "false_tuple.0"); - EXPECT_EQ(false_tuple->shape().tuple_shapes_size(), 0); - - // The infeed output token should have propagated through the true - // computation's root. - HloComputation* true_comp = FindComputation(module.get(), "true_comp"); - EXPECT_THAT(true_comp->root_instruction(), - op::Tuple(op::GetTupleElement(op::Infeed(), 1))); - - // The infeed output token should have propagated to the false computation's - // root. - HloComputation* false_comp = FindComputation(module.get(), "false_comp"); - EXPECT_THAT(false_comp->root_instruction(), op::Tuple(op::AfterAll())); -} - -TEST_F(InfeedTokenPropagationTest, ConditionalOutfeed) { - constexpr std::string_view hlo = R"( -HloModule main - -true_comp { - arg.0 = (s32[]) parameter(0) - token.0 = after-all() - outfeed.0 = token[] outfeed(arg.0, token.0), outfeed_shape=(s32[]) - ROOT tuple.0 = tuple() -} - -false_comp { - arg.0 = () parameter(0) - ROOT tuple.0 = tuple() -} - -ENTRY main { - arg.0 = s32[] parameter(0) - pred.0 = pred[] constant(true) - true_tuple.0 = tuple(arg.0) - false_tuple.0 = tuple() - ROOT cond.0 = () conditional(pred.0, true_tuple.0, false_tuple.0), true_computation=true_comp, false_computation=false_comp -} -)"; - TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo)); - InfeedTokenPropagation itp; - TF_ASSERT_OK_AND_ASSIGN(bool changed, itp.Run(module.get())); - EXPECT_TRUE(changed); - - // The outfeed output token should have propagated through the conditional. - HloInstruction* cond = FindInstruction(module.get(), "cond.0"); - EXPECT_EQ(cond->shape().tuple_shapes_size(), 1); - EXPECT_TRUE(cond->shape().tuple_shapes()[0].IsToken()); - - // The outfeed input token should have propagated through the true tuple. - HloInstruction* true_tuple = FindInstruction(module.get(), "true_tuple.0"); - EXPECT_EQ(true_tuple->shape().tuple_shapes_size(), 2); - EXPECT_TRUE(true_tuple->shape().tuple_shapes()[1].IsToken()); - - // The outfeed input token should not have propagated through the false tuple. - HloInstruction* false_tuple = FindInstruction(module.get(), "false_tuple.0"); - EXPECT_EQ(false_tuple->shape().tuple_shapes_size(), 0); - - // The outfeed output token should have propagated through the true - // computation's root. - HloComputation* true_comp = FindComputation(module.get(), "true_comp"); - EXPECT_THAT(true_comp->root_instruction(), op::Tuple(op::Outfeed())); - - // The outfeed output token should have propagated to the false computation's - // root. - HloComputation* false_comp = FindComputation(module.get(), "false_comp"); - EXPECT_THAT(false_comp->root_instruction(), op::Tuple(op::AfterAll())); -} - -TEST_F(InfeedTokenPropagationTest, NonTupleConditional) { - constexpr std::string_view hlo = R"( -HloModule main - -true_comp { - arg.0 = s32[] parameter(0) - outfeed_tuple.0 = tuple(arg.0) - token.0 = after-all() - outfeed.0 = token[] outfeed(outfeed_tuple.0, token.0), outfeed_shape=(s32[]) - ROOT tuple.0 = tuple() -} - -false_comp { - arg.0 = () parameter(0) - ROOT tuple.0 = tuple() -} - -ENTRY main { - arg.0 = s32[] parameter(0) - pred.0 = pred[] constant(true) - false_tuple.0 = tuple() - ROOT cond.0 = () conditional(pred.0, arg.0, false_tuple.0), true_computation=true_comp, false_computation=false_comp -} -)"; - TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo)); - InfeedTokenPropagation itp; - TF_ASSERT_OK_AND_ASSIGN(bool changed, itp.Run(module.get())); - EXPECT_TRUE(changed); - - // The outfeed output token should have propagated through the conditional. - HloInstruction* cond = FindInstruction(module.get(), "cond.0"); - EXPECT_EQ(cond->shape().tuple_shapes_size(), 1); - EXPECT_TRUE(cond->shape().tuple_shapes()[0].IsToken()); - - // The outfeed input token should have propagated through the true tuple. - HloInstruction* true_tuple = cond->mutable_operand(1); - EXPECT_TRUE(true_tuple->shape().IsTuple()); - EXPECT_EQ(true_tuple->shape().tuple_shapes_size(), 2); - EXPECT_TRUE(true_tuple->shape().tuple_shapes()[1].IsToken()); - - // The outfeed input token should not have propagated through the false tuple. - HloInstruction* false_tuple = FindInstruction(module.get(), "false_tuple.0"); - EXPECT_EQ(false_tuple->shape().tuple_shapes_size(), 0); - - // The outfeed output token should have propagated through the true - // computation's root. - HloComputation* true_comp = FindComputation(module.get(), "true_comp"); - EXPECT_THAT(true_comp->root_instruction(), op::Tuple(op::Outfeed())); - - // The outfeed output token should have propagated to the false computation's - // root. - HloComputation* false_comp = FindComputation(module.get(), "false_comp"); - EXPECT_THAT(false_comp->root_instruction(), op::Tuple(op::AfterAll())); -} - -TEST_F(InfeedTokenPropagationTest, DisjointConditionalOutfeed) { - constexpr std::string_view hlo = R"( -HloModule main - -true_comp { - ROOT arg.0 = () parameter(0) - one.0 = s32[] constant(1) - outfeed_tuple.0 = tuple(one.0) - token.0 = after-all() - outfeed.0 = token[] outfeed(outfeed_tuple.0, token.0), outfeed_shape=(s32[]) -} - -false_comp { - arg.0 = () parameter(0) - ROOT tuple.0 = tuple() -} - -ENTRY main { - pred.0 = pred[] constant(true) - true_tuple.0 = tuple() - false_tuple.0 = tuple() - ROOT cond.0 = () conditional(pred.0, true_tuple.0, false_tuple.0), true_computation=true_comp, false_computation=false_comp -} -)"; - TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo)); - InfeedTokenPropagation itp; - TF_ASSERT_OK_AND_ASSIGN(bool changed, itp.Run(module.get())); - EXPECT_TRUE(changed); - - // The outfeed output token should have propagated through the conditional. - HloInstruction* cond = FindInstruction(module.get(), "cond.0"); - EXPECT_EQ(cond->shape().tuple_shapes_size(), 1); - EXPECT_TRUE(cond->shape().tuple_shapes()[0].IsToken()); - - // The outfeed input token should have propagated through the true tuple. - HloInstruction* true_tuple = FindInstruction(module.get(), "true_tuple.0"); - EXPECT_EQ(true_tuple->shape().tuple_shapes_size(), 1); - EXPECT_TRUE(true_tuple->shape().tuple_shapes()[0].IsToken()); - - // The outfeed input token should not have propagated through the false tuple. - HloInstruction* false_tuple = FindInstruction(module.get(), "false_tuple.0"); - EXPECT_EQ(false_tuple->shape().tuple_shapes_size(), 0); - - // The outfeed output token should have propagated through the true - // computation's root. - HloComputation* true_comp = FindComputation(module.get(), "true_comp"); - EXPECT_THAT(true_comp->root_instruction(), op::Tuple(op::Outfeed())); - - // The outfeed output token should have propagated to the false computation's - // root. - HloComputation* false_comp = FindComputation(module.get(), "false_comp"); - EXPECT_THAT(false_comp->root_instruction(), op::Tuple(op::AfterAll())); -} - -TEST_F(InfeedTokenPropagationTest, WhileInfeed) { - constexpr std::string_view hlo = R"( -HloModule main - -comp { - arg.0 = () parameter(0) - token.0 = after-all() - infeed.0 = (s32[], token[]) infeed(token.0) - ROOT tuple.0 = tuple() -} - -cond { - arg.0 = () parameter(0) - ROOT true.0 = pred[] constant(true) -} - -ENTRY main { - while_tuple.0 = tuple() - ROOT while.0 = () while(while_tuple.0), condition=cond, body=comp -} -)"; - TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo)); - InfeedTokenPropagation itp; - TF_ASSERT_OK_AND_ASSIGN(bool changed, itp.Run(module.get())); - EXPECT_TRUE(changed); - - // The infeed output token should have propagated through the loop. - HloInstruction* loop = FindInstruction(module.get(), "while.0"); - EXPECT_EQ(loop->shape().tuple_shapes_size(), 1); - EXPECT_TRUE(loop->shape().tuple_shapes()[0].IsToken()); - - // The infeed input token should have propagated through the loop tuple. - HloInstruction* loop_tuple = FindInstruction(module.get(), "while_tuple.0"); - EXPECT_EQ(loop_tuple->shape().tuple_shapes_size(), 1); - EXPECT_TRUE(loop_tuple->shape().tuple_shapes()[0].IsToken()); - - // The infeed output token should have propagated through the while body root. - HloComputation* body_comp = FindComputation(module.get(), "comp"); - EXPECT_THAT(body_comp->root_instruction(), - op::Tuple(op::GetTupleElement(op::Infeed(), 1))); - - // The infeed input token should have propagated through the body parameter. - HloInstruction* body_param = body_comp->parameter_instruction(0); - EXPECT_EQ(body_param->shape().tuple_shapes_size(), 1); - EXPECT_TRUE(body_param->shape().tuple_shapes()[0].IsToken()); - - // The infeed input token should have propagated through the condition - // parameter. - HloComputation* cond_comp = FindComputation(module.get(), "cond"); - HloInstruction* cond_param = cond_comp->parameter_instruction(0); - EXPECT_EQ(cond_param->shape().tuple_shapes_size(), 1); - EXPECT_TRUE(cond_param->shape().tuple_shapes()[0].IsToken()); -} - -TEST_F(InfeedTokenPropagationTest, WhileOutfeed) { - constexpr std::string_view hlo = R"( -HloModule main - -comp { - arg.0 = (s32[]) parameter(0) - token.0 = after-all() - outfeed.0 = token[] outfeed(arg.0, token.0), outfeed_shape=(s32[]) - gte.0 = get-tuple-element(arg.0), index=0 - ROOT tuple.0 = tuple(gte.0) -} - -cond { - arg.0 = (s32[]) parameter(0) - ROOT true.0 = pred[] constant(true) -} - -ENTRY main { - arg.0 = s32[] parameter(0) - while_tuple.0 = tuple(arg.0) - ROOT while.0 = (s32[]) while(while_tuple.0), condition=cond, body=comp -} -)"; - - TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo)); - InfeedTokenPropagation itp; - TF_ASSERT_OK_AND_ASSIGN(bool changed, itp.Run(module.get())); - EXPECT_TRUE(changed); - - // The outfeed output token should have propagated through the loop. - HloInstruction* loop = FindInstruction(module.get(), "while.0"); - EXPECT_EQ(loop->shape().tuple_shapes_size(), 2); - EXPECT_TRUE(loop->shape().tuple_shapes()[1].IsToken()); - - // The outfeed input token should have propagated through the loop tuple. - HloInstruction* loop_tuple = FindInstruction(module.get(), "while_tuple.0"); - EXPECT_EQ(loop_tuple->shape().tuple_shapes_size(), 2); - EXPECT_TRUE(loop_tuple->shape().tuple_shapes()[1].IsToken()); - - // The outfeed output token should have propagated through the while body - // root. - HloComputation* body_comp = FindComputation(module.get(), "comp"); - EXPECT_THAT(body_comp->root_instruction(), - op::Tuple(op::GetTupleElement(), op::Outfeed())); - - // The outfeed output token should have propagated through the body parameter. - HloInstruction* body_param = body_comp->parameter_instruction(0); - EXPECT_EQ(body_param->shape().tuple_shapes_size(), 2); - EXPECT_TRUE(body_param->shape().tuple_shapes()[1].IsToken()); - - // The outfeed output token should have propagated through the condition - // parameter. - HloComputation* cond_comp = FindComputation(module.get(), "cond"); - HloInstruction* cond_param = cond_comp->parameter_instruction(0); - EXPECT_EQ(cond_param->shape().tuple_shapes_size(), 2); - EXPECT_TRUE(cond_param->shape().tuple_shapes()[1].IsToken()); -} - -TEST_F(InfeedTokenPropagationTest, DisjointWhileOutfeed) { - constexpr std::string_view hlo = R"( -HloModule main - -comp { - ROOT arg.0 = () parameter(0) - one.0 = s32[] constant(1) - outfeed_tuple.0 = tuple(one.0) - token.0 = after-all() - outfeed.0 = token[] outfeed(outfeed_tuple.0, token.0), outfeed_shape=(s32[]) -} - -cond { - arg.0 = () parameter(0) - ROOT true.0 = pred[] constant(true) -} - -ENTRY main { - while_tuple.0 = tuple() - ROOT while.0 = () while(while_tuple.0), condition=cond, body=comp -} -)"; - - TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo)); - InfeedTokenPropagation itp; - TF_ASSERT_OK_AND_ASSIGN(bool changed, itp.Run(module.get())); - EXPECT_TRUE(changed); - - // The outfeed output token should have propagated through the loop. - HloInstruction* loop = FindInstruction(module.get(), "while.0"); - EXPECT_EQ(loop->shape().tuple_shapes_size(), 1); - EXPECT_TRUE(loop->shape().tuple_shapes()[0].IsToken()); - - // The outfeed input token should have propagated through the loop tuple. - HloInstruction* loop_tuple = FindInstruction(module.get(), "while_tuple.0"); - EXPECT_EQ(loop_tuple->shape().tuple_shapes_size(), 1); - EXPECT_TRUE(loop_tuple->shape().tuple_shapes()[0].IsToken()); - - // The outfeed output token should have propagated through the while body - // root. - HloComputation* body_comp = FindComputation(module.get(), "comp"); - EXPECT_THAT(body_comp->root_instruction(), op::Tuple(op::Outfeed())); - - // The outfeed output token should have propagated through the body parameter. - HloInstruction* body_param = body_comp->parameter_instruction(0); - EXPECT_EQ(body_param->shape().tuple_shapes_size(), 1); - EXPECT_TRUE(body_param->shape().tuple_shapes()[0].IsToken()); - - // The outfeed output token should have propagated through the condition - // parameter. - HloComputation* cond_comp = FindComputation(module.get(), "cond"); - HloInstruction* cond_param = cond_comp->parameter_instruction(0); - EXPECT_EQ(cond_param->shape().tuple_shapes_size(), 1); - EXPECT_TRUE(cond_param->shape().tuple_shapes()[0].IsToken()); -} - -TEST_F(InfeedTokenPropagationTest, NonTupleWhile) { - constexpr std::string_view hlo = R"( -HloModule main - -comp { - ROOT arg.0 = s32[] parameter(0) - tuple.0 = tuple(arg.0) - token.0 = after-all() - outfeed.0 = token[] outfeed(tuple.0, token.0), outfeed_shape=(s32[]) -} - -cond { - arg.0 = s32[] parameter(0) - ROOT true.0 = pred[] constant(true) -} - -ENTRY main { - arg.0 = s32[] parameter(0) - ROOT while.0 = s32[] while(arg.0), condition=cond, body=comp -} -)"; - - TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo)); - InfeedTokenPropagation itp; - TF_ASSERT_OK_AND_ASSIGN(bool changed, itp.Run(module.get())); - EXPECT_TRUE(changed); - - // The outfeed output token should have propagated through the loop. - HloInstruction* loop = FindInstruction(module.get(), "while.0"); - EXPECT_TRUE(loop->shape().IsTuple()); - EXPECT_EQ(loop->shape().tuple_shapes_size(), 2); - EXPECT_TRUE(loop->shape().tuple_shapes()[1].IsToken()); - - // The outfeed input token should have propagated through the loop tuple. - EXPECT_THAT(loop->operand(0), op::Tuple(op::Parameter(), op::AfterAll())); - - // The outfeed output token should have propagated through the while body - // root. - HloComputation* body_comp = FindComputation(module.get(), "comp"); - EXPECT_THAT(body_comp->root_instruction(), - op::Tuple(op::GetTupleElement(), op::Outfeed())); - - // The outfeed output token should have propagated through the body parameter. - HloInstruction* body_param = body_comp->parameter_instruction(0); - EXPECT_EQ(body_param->shape().tuple_shapes_size(), 2); - EXPECT_TRUE(body_param->shape().tuple_shapes()[1].IsToken()); - - // The outfeed output token should have propagated through the condition - // parameter. - HloComputation* cond_comp = FindComputation(module.get(), "cond"); - HloInstruction* cond_param = cond_comp->parameter_instruction(0); - EXPECT_EQ(cond_param->shape().tuple_shapes_size(), 2); - EXPECT_TRUE(cond_param->shape().tuple_shapes()[1].IsToken()); -} - -TEST_F(InfeedTokenPropagationTest, NestedInfeedOutfeed) { - constexpr std::string_view hlo = R"( -HloModule main - -true_comp { - arg.0 = (s32[]) parameter(0) - token.0 = after-all() - outfeed.0 = token[] outfeed(arg.0, token.0), outfeed_shape=(s32[]) - ROOT tuple.0 = tuple() -} - -false_comp { - arg.0 = () parameter(0) - ROOT tuple.0 = tuple() -} - -comp { - arg.0 = () parameter(0) - token.0 = after-all() - infeed.0 = (s32[], token[]) infeed(token.0) - gte.0 = get-tuple-element(infeed.0), index=0 - pred.0 = pred[] constant(true) - true_tuple.0 = tuple(gte.0) - false_tuple.0 = tuple() - cond.0 = () conditional(pred.0, true_tuple.0, false_tuple.0), true_computation=true_comp, false_computation=false_comp - ROOT tuple.0 = tuple() -} - -cond { - arg.0 = () parameter(0) - ROOT true.0 = pred[] constant(true) -} - -ENTRY main { - while_tuple.0 = tuple() - ROOT while.0 = () while(while_tuple.0), condition=cond, body=comp -} -)"; - TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo)); - InfeedTokenPropagation itp; - TF_ASSERT_OK_AND_ASSIGN(bool changed, itp.Run(module.get())); - EXPECT_TRUE(changed); - - // The infeed and outfeed output tokens should have propagated through the - // loop. - HloInstruction* loop = FindInstruction(module.get(), "while.0"); - EXPECT_EQ(loop->shape().tuple_shapes_size(), 2); - EXPECT_TRUE(loop->shape().tuple_shapes()[0].IsToken()); - EXPECT_TRUE(loop->shape().tuple_shapes()[1].IsToken()); - - // The infeed and outfeed input tokens should have propagated through the loop - // tuple. - HloInstruction* loop_tuple = FindInstruction(module.get(), "while_tuple.0"); - EXPECT_EQ(loop_tuple->shape().tuple_shapes_size(), 2); - EXPECT_TRUE(loop_tuple->shape().tuple_shapes()[0].IsToken()); - EXPECT_TRUE(loop_tuple->shape().tuple_shapes()[1].IsToken()); - - // The infeed and outfeed output tokens should have propagated through the - // while body root. - HloComputation* body_comp = FindComputation(module.get(), "comp"); - EXPECT_THAT(body_comp->root_instruction(), - op::Tuple(op::GetTupleElement(op::Infeed(), 1), - op::GetTupleElement(op::Conditional(), 0))); - - // The outfeed output token should have propagated through the conditional. - HloInstruction* cond = FindInstruction(module.get(), "cond.0"); - EXPECT_EQ(cond->shape().tuple_shapes_size(), 1); - EXPECT_TRUE(cond->shape().tuple_shapes()[0].IsToken()); - - // The outfeed input token should have propagated through the true tuple. - HloInstruction* true_tuple = FindInstruction(module.get(), "true_tuple.0"); - EXPECT_EQ(true_tuple->shape().tuple_shapes_size(), 2); - EXPECT_TRUE(true_tuple->shape().tuple_shapes()[1].IsToken()); - - // The outfeed input token should not have propagated through the false tuple. - HloInstruction* false_tuple = FindInstruction(module.get(), "false_tuple.0"); - EXPECT_EQ(false_tuple->shape().tuple_shapes_size(), 0); - - // The outfeed output token should have propagated through the true - // computation's root. - HloComputation* true_comp = FindComputation(module.get(), "true_comp"); - EXPECT_THAT(true_comp->root_instruction(), op::Tuple(op::Outfeed())); - - // The outfeed output token should have propagated to the false computation's - // root. - HloComputation* false_comp = FindComputation(module.get(), "false_comp"); - EXPECT_THAT(false_comp->root_instruction(), op::Tuple(op::AfterAll())); -} -} // namespace -} // namespace xla diff --git a/third_party/xla/xla/service/while_loop_invariant_code_motion.cc b/third_party/xla/xla/service/while_loop_invariant_code_motion.cc index b1aae51df132e9..ed44547af3fca4 100644 --- a/third_party/xla/xla/service/while_loop_invariant_code_motion.cc +++ b/third_party/xla/xla/service/while_loop_invariant_code_motion.cc @@ -232,7 +232,6 @@ WhileLoopInvariantCodeMotion::TryHoistingInvariantInstructionsFromWhileBody( } if (instruction->HasSideEffect() || - instruction->opcode() == HloOpcode::kAfterAll || instruction->opcode() == HloOpcode::kParameter || !instruction->control_predecessors().empty() || !instruction->control_successors().empty()) { From be4d1cb3579e1e685a1acb5484b3280c0bfe0d5e Mon Sep 17 00:00:00 2001 From: Alexander Pivovarov Date: Mon, 23 Sep 2024 13:00:56 -0700 Subject: [PATCH 149/483] PR #17457: Parameterize FloatConversion tests Imported from GitHub PR https://github.com/openxla/xla/pull/17457 gpu float_conversions_test.cc contains multiple similar tests for Float types Changes: - Parameterize FloatConversion tests Copybara import of the project: -- da47ec8ae24398ef0fa5083e0550b1e640c1b77d by Alexander Pivovarov : Parameterize FloatConversion tests Merging this change closes #17457 PiperOrigin-RevId: 677907718 --- .../gpu/tests/float_conversions_test.cc | 141 ++++-------------- 1 file changed, 31 insertions(+), 110 deletions(-) diff --git a/third_party/xla/xla/service/gpu/tests/float_conversions_test.cc b/third_party/xla/xla/service/gpu/tests/float_conversions_test.cc index 34b5c703798c23..b5d571e4c7be3f 100644 --- a/third_party/xla/xla/service/gpu/tests/float_conversions_test.cc +++ b/third_party/xla/xla/service/gpu/tests/float_conversions_test.cc @@ -23,131 +23,52 @@ namespace gpu { class FloatConversionTest : public GpuCodegenTest {}; -TEST_F(FloatConversionTest, F8E5M2ToF16) { - EXPECT_TRUE(RunAndCompare(R"(ENTRY m { - %p = f8e5m2[] parameter(0) - ROOT %c = f16[] convert(%p) - })", - ErrorSpec{1e-5, 1e-5})); -} - -TEST_F(FloatConversionTest, F8E4M3FNToF16) { - EXPECT_TRUE(RunAndCompare(R"(ENTRY m { - %p = f8e4m3fn[] parameter(0) - ROOT %c = f16[] convert(%p) - })", - ErrorSpec{1e-5, 1e-5})); -} - -TEST_F(FloatConversionTest, F8E4M3B11FNUZToF16) { - EXPECT_TRUE(RunAndCompare(R"(ENTRY m { - %p = f8e4m3b11fnuz[] parameter(0) - ROOT %c = f16[] convert(%p) - })", - ErrorSpec{1e-5, 1e-5})); -} - -TEST_F(FloatConversionTest, F8E5M2FNUZToF16) { - EXPECT_TRUE(RunAndCompare(R"(ENTRY m { - %p = f8e5m2fnuz[] parameter(0) - ROOT %c = f16[] convert(%p) - })", - ErrorSpec{1e-5, 1e-5})); -} - -TEST_F(FloatConversionTest, F8E4M3FNUZToF16) { - EXPECT_TRUE(RunAndCompare(R"(ENTRY m { - %p = f8e4m3fnuz[] parameter(0) - ROOT %c = f16[] convert(%p) - })", - ErrorSpec{1e-5, 1e-5})); -} - -TEST_F(FloatConversionTest, BF16ToF32) { - EXPECT_TRUE(RunAndCompare(R"(ENTRY m { - %p = bf16[] parameter(0) - ROOT %c = f32[] convert(%p) - })", - ErrorSpec{1e-5, 1e-5})); -} - -TEST_F(FloatConversionTest, F16ToF32) { - EXPECT_TRUE(RunAndCompare(R"(ENTRY m { - %p = f16[] parameter(0) - ROOT %c = f32[] convert(%p) - })", - ErrorSpec{1e-5, 1e-5})); -} - -TEST_F(FloatConversionTest, F64ToF32) { - EXPECT_TRUE(RunAndCompare(R"(ENTRY m { - %p = f64[] parameter(0) - ROOT %c = f32[] convert(%p) - })", - ErrorSpec{1e-5, 1e-5})); -} +class FloatConversionParamTest + : public GpuCodegenTest, + public ::testing::WithParamInterface {}; -TEST_F(FloatConversionTest, F16ToF8E5M2) { - EXPECT_TRUE(RunAndCompare(R"(ENTRY m { - %p = f16[] parameter(0) - ROOT %c = f8e5m2[] convert(%p) - })", - ErrorSpec{1e-5, 1e-5})); -} +INSTANTIATE_TEST_SUITE_P(FloatConversionParamSuite, FloatConversionParamTest, + ::testing::Values("f64", "f32", "f16", "bf16", + "f8e5m2", "f8e5m2fnuz", "f8e4m3fn", + "f8e4m3fnuz", "f8e4m3b11fnuz")); -TEST_F(FloatConversionTest, F16ToF8E4M3FN) { - EXPECT_TRUE(RunAndCompare(R"(ENTRY m { - %p = f16[] parameter(0) - ROOT %c = f8e4m3fn[] convert(%p) +TEST_P(FloatConversionParamTest, FloatToF16) { + auto type_name = GetParam(); + EXPECT_TRUE(RunAndCompare(absl::StrFormat(R"(ENTRY m { + p0 = %s[] parameter(0) + ROOT c1 = f16[] convert(p0) })", + type_name), ErrorSpec{1e-5, 1e-5})); } -TEST_F(FloatConversionTest, F16ToF8E4M3B11FNUZ) { - EXPECT_TRUE(RunAndCompare(R"(ENTRY m { - %p = f16[] parameter(0) - ROOT %c = f8e4m3b11fnuz[] convert(%p) +TEST_P(FloatConversionParamTest, F16ToFloat) { + auto type_name = GetParam(); + EXPECT_TRUE(RunAndCompare(absl::StrFormat(R"(ENTRY m { + p0 = f16[] parameter(0) + ROOT c1 = %s[] convert(p0) })", + type_name), ErrorSpec{1e-5, 1e-5})); } -TEST_F(FloatConversionTest, F16ToF8E5M2FNUZ) { - EXPECT_TRUE(RunAndCompare(R"(ENTRY m { - %p = f16[] parameter(0) - ROOT %c = f8e5m2fnuz[] convert(%p) +TEST_P(FloatConversionParamTest, FloatToF32) { + auto type_name = GetParam(); + EXPECT_TRUE(RunAndCompare(absl::StrFormat(R"(ENTRY m { + p0 = %s[] parameter(0) + ROOT c1 = f32[] convert(p0) })", + type_name), ErrorSpec{1e-5, 1e-5})); } -TEST_F(FloatConversionTest, F16ToF8E4M3FNUZ) { - EXPECT_TRUE(RunAndCompare(R"(ENTRY m { - %p = f16[] parameter(0) - ROOT %c = f8e4m3fnuz[] convert(%p) - })", - ErrorSpec{1e-5, 1e-5})); -} - -TEST_F(FloatConversionTest, F32ToBF16) { - EXPECT_TRUE(RunAndCompare(R"(ENTRY m { - %p = f32[] parameter(0) - ROOT %c = bf16[] convert(%p) - })", - ErrorSpec{1e-5, 1e-5})); -} - -TEST_F(FloatConversionTest, F32ToF16) { - EXPECT_TRUE(RunAndCompare(R"(ENTRY m { - %p = f32[] parameter(0) - ROOT %c = f16[] convert(%p) - })", - ErrorSpec{1e-5, 1e-5})); -} - -TEST_F(FloatConversionTest, F32ToF64) { - EXPECT_TRUE(RunAndCompare(R"(ENTRY m { - %p = f32[] parameter(0) - ROOT %c = f64[] convert(%p) +TEST_P(FloatConversionParamTest, F32ToFloat) { + auto type_name = GetParam(); + EXPECT_TRUE(RunAndCompare(absl::StrFormat(R"(ENTRY m { + p0 = f32[] parameter(0) + ROOT c1 = %s[] convert(p0) })", + type_name), ErrorSpec{1e-5, 1e-5})); } From fc0356f9dd3d23dc02a3d88d80661ec22163a5b3 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 23 Sep 2024 13:04:30 -0700 Subject: [PATCH 150/483] [easy] [XLA] Factor our removal of unused outputs of multi-fusions PiperOrigin-RevId: 677909092 --- third_party/xla/xla/service/hlo_dce.cc | 189 +++++++++++++------------ 1 file changed, 99 insertions(+), 90 deletions(-) diff --git a/third_party/xla/xla/service/hlo_dce.cc b/third_party/xla/xla/service/hlo_dce.cc index 5617190d36b854..36054a950d6b8d 100644 --- a/third_party/xla/xla/service/hlo_dce.cc +++ b/third_party/xla/xla/service/hlo_dce.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include #include +#include #include #include @@ -67,104 +68,112 @@ bool IsRemovableWhile(HloInstruction* instruction, return true; } -} // namespace +// Returns true if it found and removed unused outputs. +absl::StatusOr RemoveMultiOutputFusionsUnusedOutputs( + HloComputation* computation) { + HloInstruction* fusion_instruction = computation->FusionInstruction(); + if (!fusion_instruction) { + return false; + } -/*static*/ absl::StatusOr HloDCE::RunOnComputation( - HloComputation* computation, bool remove_cross_partition_collective_ops) { - bool changed = false; - // Cleanup unused tuple elements in multi-output fusion roots. We do this - // first, because it may create dead roots which we can clean up next. - if (auto* fusion_instruction = computation->FusionInstruction(); - fusion_instruction != nullptr && - computation->root_instruction()->opcode() == HloOpcode::kTuple && - !computation->root_instruction()->has_sharding() && - fusion_instruction->output_operand_aliasing().empty() && - !fusion_instruction->HasControlDependencies() && - !fusion_instruction->IsCustomFusion()) { - // The order of the used outputs is relevant for the algorithm below. - std::set used_tuple_elements; - // We only support this cleanup if all users of the fusion instruction are - // GetTupleElement ops, and there is at least one user of - // 'fusion_instruction'. - bool supported = fusion_instruction->user_count() > 0; - for (HloInstruction* gte : fusion_instruction->users()) { - if (gte->opcode() != HloOpcode::kGetTupleElement) { - supported = false; - break; - } - used_tuple_elements.insert(gte->tuple_index()); - } + if (computation->root_instruction()->opcode() != HloOpcode::kTuple || + computation->root_instruction()->has_sharding() || + !fusion_instruction->output_operand_aliasing().empty() || + fusion_instruction->HasControlDependencies() || + fusion_instruction->IsCustomFusion()) { + return false; + } + + // The order of the used outputs is relevant for the algorithm below. + std::set used_tuple_elements; + + // We only support this cleanup if all users of the fusion instruction are + // GetTupleElement ops, and there is at least one user of + // 'fusion_instruction'. + if (fusion_instruction->users().empty()) { + return false; + } - // If all outputs are used, nothing to clean up. - if (used_tuple_elements.size() == - computation->root_instruction()->operand_count()) { - supported = false; + for (HloInstruction* gte : fusion_instruction->users()) { + if (gte->opcode() != HloOpcode::kGetTupleElement) { + return false; } + used_tuple_elements.insert(gte->tuple_index()); + } - if (supported) { - std::vector tuple_shapes; - tuple_shapes.reserve(used_tuple_elements.size()); - for (int64_t tuple_index : used_tuple_elements) { - tuple_shapes.push_back( - fusion_instruction->shape().tuple_shapes(tuple_index)); - } - Shape new_shape = tuple_shapes.size() == 1 - ? tuple_shapes[0] - : ShapeUtil::MakeTupleShape(tuple_shapes); - *fusion_instruction->mutable_shape() = std::move(new_shape); - - // Update the users of the old fusion instruction. - if (tuple_shapes.size() > 1) { - for (HloInstruction* gte : fusion_instruction->users()) { - auto it = - std::lower_bound(used_tuple_elements.begin(), - used_tuple_elements.end(), gte->tuple_index()); - int64_t new_tuple_index = - std::distance(used_tuple_elements.begin(), it); - gte->set_tuple_index(new_tuple_index); - } - } else { - // Since we iterate over users while removing them .. make a local copy - // first. - std::vector users(fusion_instruction->users()); - for (HloInstruction* gte : users) { - // Replace and change control successors to be dependent on the fusion - // instruction itself. - TF_ASSIGN_OR_RETURN(bool replaced, - gte->parent()->ReplaceInstruction( - gte, fusion_instruction, - /*preserve_sharding=*/true, - /*relay_control_dependency=*/true)); - if (replaced) { - changed |= replaced; - } - } - } + // If all outputs are used, nothing to clean up. + if (used_tuple_elements.size() == + computation->root_instruction()->operand_count()) { + return false; + } - // Update the root of the fusion computation. - if (tuple_shapes.size() > 1) { - std::vector new_operands; - new_operands.reserve(used_tuple_elements.size()); - for (int64_t tuple_index : used_tuple_elements) { - new_operands.push_back( - computation->root_instruction()->mutable_operand(tuple_index)); - } - auto new_tuple = computation->AddInstruction( - HloInstruction::CreateTuple(new_operands)); - TF_RETURN_IF_ERROR(computation->ReplaceInstructionWithDifferentShape( - computation->root_instruction(), new_tuple)); - } else { - TF_RETURN_IF_ERROR( - computation->root_instruction()->ReplaceAllUsesWithDifferentShape( - computation->root_instruction()->mutable_operand( - *used_tuple_elements.begin()))); - } + std::vector tuple_shapes; + tuple_shapes.reserve(used_tuple_elements.size()); + for (int64_t tuple_index : used_tuple_elements) { + tuple_shapes.push_back( + fusion_instruction->shape().tuple_shapes(tuple_index)); + } + Shape new_shape = tuple_shapes.size() == 1 + ? tuple_shapes[0] + : ShapeUtil::MakeTupleShape(tuple_shapes); + *fusion_instruction->mutable_shape() = std::move(new_shape); + + // Update the users of the old fusion instruction. + if (tuple_shapes.size() > 1) { + for (HloInstruction* gte : fusion_instruction->users()) { + auto it = used_tuple_elements.lower_bound(gte->tuple_index()); + int64_t new_tuple_index = std::distance(used_tuple_elements.begin(), it); + gte->set_tuple_index(new_tuple_index); + } + } else { + // Since we iterate over users while removing them .. make a local copy + // first. + std::vector users(fusion_instruction->users()); + for (HloInstruction* gte : users) { + // Replace and change control successors to be dependent on the fusion + // instruction itself. + TF_ASSIGN_OR_RETURN(std::ignore, gte->parent()->ReplaceInstruction( + gte, fusion_instruction, + /*preserve_sharding=*/true, + /*relay_control_dependency=*/true)); + } + } + + // Update the root of the fusion computation. + if (tuple_shapes.size() > 1) { + std::vector new_operands; + new_operands.reserve(used_tuple_elements.size()); + for (int64_t tuple_index : used_tuple_elements) { + new_operands.push_back( + computation->root_instruction()->mutable_operand(tuple_index)); } + auto new_tuple = + computation->AddInstruction(HloInstruction::CreateTuple(new_operands)); + TF_RETURN_IF_ERROR(computation->ReplaceInstructionWithDifferentShape( + computation->root_instruction(), new_tuple)); + } else { + TF_RETURN_IF_ERROR( + computation->root_instruction()->ReplaceAllUsesWithDifferentShape( + computation->root_instruction()->mutable_operand( + *used_tuple_elements.begin()))); } - // Remove any dead roots and their dead transitive operands. Collect them - // into a separate list first to avoid problems with iterating through the - // computation's instruction while simultaneously removing instructions. + // We always updated the fusion if we got here. + return true; +} + +} // namespace + +/*static*/ absl::StatusOr HloDCE::RunOnComputation( + HloComputation* computation, bool remove_cross_partition_collective_ops) { + // We do this first, because it may create dead roots which we can clean up + // next. + TF_ASSIGN_OR_RETURN(bool changed, + RemoveMultiOutputFusionsUnusedOutputs(computation)); + + // Remove any dead roots and their dead transitive operands. Collect + // them into a separate list first to avoid problems with iterating through + // the computation's instruction while simultaneously removing instructions. std::vector dead_roots; for (auto* instruction : computation->instructions()) { auto maybe_collective_op = DynCast(instruction); From eaeaa91ef425704a562fe6a52b4a1d76dfdbfe5b Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 23 Sep 2024 13:28:55 -0700 Subject: [PATCH 151/483] Add check for max size when encoding string list PiperOrigin-RevId: 677918083 --- tensorflow/core/platform/BUILD | 1 + tensorflow/core/platform/tensor_coding.cc | 10 ++++++++++ 2 files changed, 11 insertions(+) diff --git a/tensorflow/core/platform/BUILD b/tensorflow/core/platform/BUILD index 0d23a089d6dbf9..51333b01ab8f61 100644 --- a/tensorflow/core/platform/BUILD +++ b/tensorflow/core/platform/BUILD @@ -850,6 +850,7 @@ cc_library( hdrs = ["tensor_coding.h"], deps = [ ":coding", + ":logging", ":platform", ":protobuf", ":refcount", diff --git a/tensorflow/core/platform/tensor_coding.cc b/tensorflow/core/platform/tensor_coding.cc index dd91086efaf4dc..38f1d26508722f 100644 --- a/tensorflow/core/platform/tensor_coding.cc +++ b/tensorflow/core/platform/tensor_coding.cc @@ -15,9 +15,12 @@ limitations under the License. #include "tensorflow/core/platform/tensor_coding.h" +#include +#include #include #include "tensorflow/core/platform/coding.h" +#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/strcat.h" #include "tensorflow/core/platform/stringpiece.h" @@ -34,6 +37,13 @@ void AssignRefCounted(StringPiece src, core::RefCounted* obj, string* out) { } void EncodeStringList(const tstring* strings, int64_t n, string* out) { + int64_t tot = n * sizeof(size_t); + for (int i = 0; i < n; ++i) { + tot += strings[i].size(); + } + if (tot > INT_MAX) { + LOG(FATAL) << "EncodeStringList size too large: " << tot; // Crash OK + } out->clear(); for (int i = 0; i < n; ++i) { core::PutVarint32(out, strings[i].size()); From 29df6c3d845b6f5f02ccf0eee53faa5e76cd93a4 Mon Sep 17 00:00:00 2001 From: Augie Fackler Date: Mon, 23 Sep 2024 13:40:00 -0700 Subject: [PATCH 152/483] Integrate LLVM at llvm/llvm-project@8b4b7d28f7c3 Updates LLVM usage to match [8b4b7d28f7c3](https://github.com/llvm/llvm-project/commit/8b4b7d28f7c3) PiperOrigin-RevId: 677922152 --- third_party/llvm/workspace.bzl | 4 ++-- third_party/shardy/temporary.patch | 10 +++++----- third_party/shardy/workspace.bzl | 4 ++-- third_party/xla/third_party/shardy/temporary.patch | 10 +++++----- third_party/xla/third_party/shardy/workspace.bzl | 4 ++-- 5 files changed, 16 insertions(+), 16 deletions(-) diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl index 55290bf02ec4fd..726a367bee5547 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "42b696d7b9942fdf07d65267da40ab178464adaa" - LLVM_SHA256 = "4f0d2053b381d3f074c64b2e460792cab11a02333f1c88bbc22b01686cf2fcb0" + LLVM_COMMIT = "8b4b7d28f7c344c728a9812aa99d9ad24edb40a2" + LLVM_SHA256 = "f585b8955f66849929bbe0b657ea7ff5fe8f49880066a58b2a744065ddd4a521" tf_http_archive( name = name, diff --git a/third_party/shardy/temporary.patch b/third_party/shardy/temporary.patch index 613660a484cee2..f6938677141184 100644 --- a/third_party/shardy/temporary.patch +++ b/third_party/shardy/temporary.patch @@ -1,15 +1,15 @@ diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl -index c011aab..55290bf 100644 +index 55290bf..726a367 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" -- LLVM_COMMIT = "94c024adedcb53059c29d7c2d62982053b60e86a" -- LLVM_SHA256 = "204cedeaab86f065ef64cb3889dd2e92ddd4a8f5d5b6bc1cb4b276694fb6a798" -+ LLVM_COMMIT = "42b696d7b9942fdf07d65267da40ab178464adaa" -+ LLVM_SHA256 = "4f0d2053b381d3f074c64b2e460792cab11a02333f1c88bbc22b01686cf2fcb0" +- LLVM_COMMIT = "42b696d7b9942fdf07d65267da40ab178464adaa" +- LLVM_SHA256 = "4f0d2053b381d3f074c64b2e460792cab11a02333f1c88bbc22b01686cf2fcb0" ++ LLVM_COMMIT = "8b4b7d28f7c344c728a9812aa99d9ad24edb40a2" ++ LLVM_SHA256 = "f585b8955f66849929bbe0b657ea7ff5fe8f49880066a58b2a744065ddd4a521" tf_http_archive( name = name, diff --git a/third_party/shardy/workspace.bzl b/third_party/shardy/workspace.bzl index b3863418126e1b..cb2dc13e58bb06 100644 --- a/third_party/shardy/workspace.bzl +++ b/third_party/shardy/workspace.bzl @@ -3,8 +3,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): - SHARDY_COMMIT = "f1ed2d559c09f969d34bc870a03b882d5a4ac813" - SHARDY_SHA256 = "73858df3e06c3afad362308baca03e5bf319c31f3227c7b990046389320303c0" + SHARDY_COMMIT = "5013981d546b3bb99d9193841edcd5318cce3ce2" + SHARDY_SHA256 = "800366d7604691e63939e2cb3ecb4acfa349253f2f0ed8b1b84e6783fad55a01" tf_http_archive( name = "shardy", diff --git a/third_party/xla/third_party/shardy/temporary.patch b/third_party/xla/third_party/shardy/temporary.patch index 613660a484cee2..f6938677141184 100644 --- a/third_party/xla/third_party/shardy/temporary.patch +++ b/third_party/xla/third_party/shardy/temporary.patch @@ -1,15 +1,15 @@ diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl -index c011aab..55290bf 100644 +index 55290bf..726a367 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" -- LLVM_COMMIT = "94c024adedcb53059c29d7c2d62982053b60e86a" -- LLVM_SHA256 = "204cedeaab86f065ef64cb3889dd2e92ddd4a8f5d5b6bc1cb4b276694fb6a798" -+ LLVM_COMMIT = "42b696d7b9942fdf07d65267da40ab178464adaa" -+ LLVM_SHA256 = "4f0d2053b381d3f074c64b2e460792cab11a02333f1c88bbc22b01686cf2fcb0" +- LLVM_COMMIT = "42b696d7b9942fdf07d65267da40ab178464adaa" +- LLVM_SHA256 = "4f0d2053b381d3f074c64b2e460792cab11a02333f1c88bbc22b01686cf2fcb0" ++ LLVM_COMMIT = "8b4b7d28f7c344c728a9812aa99d9ad24edb40a2" ++ LLVM_SHA256 = "f585b8955f66849929bbe0b657ea7ff5fe8f49880066a58b2a744065ddd4a521" tf_http_archive( name = name, diff --git a/third_party/xla/third_party/shardy/workspace.bzl b/third_party/xla/third_party/shardy/workspace.bzl index b3863418126e1b..cb2dc13e58bb06 100644 --- a/third_party/xla/third_party/shardy/workspace.bzl +++ b/third_party/xla/third_party/shardy/workspace.bzl @@ -3,8 +3,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): - SHARDY_COMMIT = "f1ed2d559c09f969d34bc870a03b882d5a4ac813" - SHARDY_SHA256 = "73858df3e06c3afad362308baca03e5bf319c31f3227c7b990046389320303c0" + SHARDY_COMMIT = "5013981d546b3bb99d9193841edcd5318cce3ce2" + SHARDY_SHA256 = "800366d7604691e63939e2cb3ecb4acfa349253f2f0ed8b1b84e6783fad55a01" tf_http_archive( name = "shardy", From 0e638463cc81f6902049cec5dd471fdd17fdaa0f Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 23 Sep 2024 13:46:22 -0700 Subject: [PATCH 153/483] [PJRT] Relax visibility of Bazel targets used by JAX. Change in preparation for enabling Bazel visibility in JAX builds. PiperOrigin-RevId: 677924487 --- third_party/xla/xla/pjrt/BUILD | 2 +- third_party/xla/xla/pjrt/c/BUILD | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/third_party/xla/xla/pjrt/BUILD b/third_party/xla/xla/pjrt/BUILD index 7eda0b9e119b8d..d7c70c22f9454f 100644 --- a/third_party/xla/xla/pjrt/BUILD +++ b/third_party/xla/xla/pjrt/BUILD @@ -986,7 +986,7 @@ cc_library( "-fno-strict-aliasing", ], features = ["-use_header_modules"], - visibility = [":friends"], + visibility = internal_visibility([":friends"]), deps = [ ":exceptions", "@com_google_absl//absl/status", diff --git a/third_party/xla/xla/pjrt/c/BUILD b/third_party/xla/xla/pjrt/c/BUILD index 0556ee9f61887a..023de87343880c 100644 --- a/third_party/xla/xla/pjrt/c/BUILD +++ b/third_party/xla/xla/pjrt/c/BUILD @@ -17,9 +17,13 @@ load( "//xla/tsl:tsl.bzl", "if_google", "if_macos", + "internal_visibility", ) -# copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"]) +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = internal_visibility(["//xla:internal"]), +) cc_library( name = "pjrt_c_api_hdrs", From c44b812c6325ce4a45ab2f42ac6e9f95e0d44235 Mon Sep 17 00:00:00 2001 From: Hyeontaek Lim Date: Mon, 23 Sep 2024 14:24:28 -0700 Subject: [PATCH 154/483] [IFRT] Spell out xla::ifrt::Layout alias as xla::PjRtLayout in preparation of introducing non-aliased xla::ifrt::Layout IFRT will introduce its `Layout` interface that follows IFRT conventions (LLVM RTTI and IFRT SerDes). PjRt-IFRT will have a wrapper around existing `xla::PjRtLayout`. To incrementally migrate the code, we first spell out `xla::ifrt::Layout` as `xla::PjRtLayout` before IFRT defines a separate `xla::ifrt::Layout` class. They will eventually use `xla::ifrt::Layout` again, but we will have a mix of `xla::PjRtLayout` and `xla::ifrt::Layout` in the IFRT API during the transition. PiperOrigin-RevId: 677939218 --- third_party/xla/xla/python/ifrt/client.h | 5 +++-- third_party/xla/xla/python/ifrt/executable.h | 9 +++++---- third_party/xla/xla/python/ifrt/mock.h | 8 ++++---- .../xla/python/ifrt_proxy/client/executable.cc | 8 ++++---- .../xla/python/ifrt_proxy/client/executable.h | 8 ++++---- .../ifrt_proxy/server/ifrt_backend_test.cc | 4 ++-- .../xla/xla/python/pjrt_ifrt/pjrt_executable.h | 16 ++++++++-------- 7 files changed, 30 insertions(+), 28 deletions(-) diff --git a/third_party/xla/xla/python/ifrt/client.h b/third_party/xla/xla/python/ifrt/client.h index 5b48bedd713a6a..bf03b857c8b254 100644 --- a/third_party/xla/xla/python/ifrt/client.h +++ b/third_party/xla/xla/python/ifrt/client.h @@ -225,8 +225,9 @@ class Client : public llvm::RTTIExtends { // single-shard dimensions `dims`. // TODO(hyeontaek): Change the API to take `Shape` and `Sharding` instead of // single-shard dimensions and device. - virtual absl::StatusOr> GetDefaultLayoutForDevice( - DType dtype, absl::Span dims, Device* device) const = 0; + virtual absl::StatusOr> + GetDefaultLayoutForDevice(DType dtype, absl::Span dims, + Device* device) const = 0; static char ID; // NOLINT }; diff --git a/third_party/xla/xla/python/ifrt/executable.h b/third_party/xla/xla/python/ifrt/executable.h index 08fa0de003ddae..6b642bd5d178d6 100644 --- a/third_party/xla/xla/python/ifrt/executable.h +++ b/third_party/xla/xla/python/ifrt/executable.h @@ -28,6 +28,7 @@ limitations under the License. #include "llvm/Support/ExtensibleRTTI.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/pjrt/pjrt_executable.h" +#include "xla/pjrt/pjrt_layout.h" #include "xla/python/ifrt/array.h" #include "xla/python/ifrt/attribute_map.h" #include "xla/python/ifrt/device.h" @@ -74,10 +75,10 @@ class Executable : public llvm::RTTIExtends { // Returns a list of output `OpSharding`. virtual std::optional> GetOutputShardings() const = 0; // Returns a list of parameter layouts. - virtual absl::StatusOr>> + virtual absl::StatusOr>> GetParameterLayouts() const = 0; // Returns a list of output/result layouts. - virtual absl::StatusOr>> + virtual absl::StatusOr>> GetOutputLayouts() const = 0; // Returns an `HloModule` (optimized) per partition. virtual absl::StatusOr>> @@ -153,10 +154,10 @@ class LoadedExecutable // Returns a list of output OpSharding. virtual std::optional> GetOutputShardings() const = 0; // Returns a list of parameter layouts. - virtual absl::StatusOr>> + virtual absl::StatusOr>> GetParameterLayouts() const = 0; // Returns a list of output/result layouts. - virtual absl::StatusOr>> + virtual absl::StatusOr>> GetOutputLayouts() const = 0; // Return an HloModule (optimized) per partition. virtual absl::StatusOr>> diff --git a/third_party/xla/xla/python/ifrt/mock.h b/third_party/xla/xla/python/ifrt/mock.h index 125c95e58a3fb5..c437268b2f7f43 100644 --- a/third_party/xla/xla/python/ifrt/mock.h +++ b/third_party/xla/xla/python/ifrt/mock.h @@ -244,9 +244,9 @@ class MockExecutable : public llvm::RTTIExtends { (const, final)); MOCK_METHOD(std::optional>, GetOutputShardings, (), (const, final)); - MOCK_METHOD(absl::StatusOr>>, + MOCK_METHOD(absl::StatusOr>>, GetParameterLayouts, (), (const, final)); - MOCK_METHOD(absl::StatusOr>>, + MOCK_METHOD(absl::StatusOr>>, GetOutputLayouts, (), (const, final)); MOCK_METHOD(absl::StatusOr>>, GetHloModules, (), (const, final)); @@ -273,9 +273,9 @@ class MockLoadedExecutable (const, final)); MOCK_METHOD(std::optional>, GetOutputShardings, (), (const, final)); - MOCK_METHOD(absl::StatusOr>>, + MOCK_METHOD(absl::StatusOr>>, GetParameterLayouts, (), (const, final)); - MOCK_METHOD(absl::StatusOr>>, + MOCK_METHOD(absl::StatusOr>>, GetOutputLayouts, (), (const, final)); MOCK_METHOD(absl::StatusOr>>, GetOutputMemoryKinds, (), (const, final)); diff --git a/third_party/xla/xla/python/ifrt_proxy/client/executable.cc b/third_party/xla/xla/python/ifrt_proxy/client/executable.cc index 37c7d4795e0509..d04c612f529e12 100644 --- a/third_party/xla/xla/python/ifrt_proxy/client/executable.cc +++ b/third_party/xla/xla/python/ifrt_proxy/client/executable.cc @@ -355,12 +355,12 @@ std::optional> LoadedExecutable::GetOutputShardings() return (*info)->output_shardings; } -absl::StatusOr>> +absl::StatusOr>> LoadedExecutable::GetParameterLayouts() const { TF_ASSIGN_OR_RETURN(auto info, metadata_future_.Await()); TF_RETURN_IF_ERROR(info->parameter_layouts.status()); - std::vector> result; + std::vector> result; result.reserve(info->parameter_layouts->size()); for (const xla::Layout& layout : *info->parameter_layouts) { result.push_back(std::make_unique(layout)); @@ -368,12 +368,12 @@ LoadedExecutable::GetParameterLayouts() const { return result; } -absl::StatusOr>> +absl::StatusOr>> LoadedExecutable::GetOutputLayouts() const { TF_ASSIGN_OR_RETURN(auto info, metadata_future_.Await()); TF_RETURN_IF_ERROR(info->output_layouts.status()); - std::vector> result; + std::vector> result; result.reserve(info->output_layouts->size()); for (const xla::Layout& layout : *info->output_layouts) { result.push_back(std::make_unique(layout)); diff --git a/third_party/xla/xla/python/ifrt_proxy/client/executable.h b/third_party/xla/xla/python/ifrt_proxy/client/executable.h index 2df7d17a8ffaae..9afc875cf9bf11 100644 --- a/third_party/xla/xla/python/ifrt_proxy/client/executable.h +++ b/third_party/xla/xla/python/ifrt_proxy/client/executable.h @@ -75,10 +75,10 @@ class LoadedExecutable final std::optional> GetParameterShardings() const override; std::optional> GetOutputShardings() const override; - absl::StatusOr>> GetParameterLayouts() - const override; - absl::StatusOr>> GetOutputLayouts() - const override; + absl::StatusOr>> + GetParameterLayouts() const override; + absl::StatusOr>> + GetOutputLayouts() const override; absl::StatusOr>> GetOutputMemoryKinds() const override; absl::StatusOr>> GetHloModules() diff --git a/third_party/xla/xla/python/ifrt_proxy/server/ifrt_backend_test.cc b/third_party/xla/xla/python/ifrt_proxy/server/ifrt_backend_test.cc index 987708935f3da2..17ae98c13cd76b 100644 --- a/third_party/xla/xla/python/ifrt_proxy/server/ifrt_backend_test.cc +++ b/third_party/xla/xla/python/ifrt_proxy/server/ifrt_backend_test.cc @@ -1063,7 +1063,7 @@ TEST_P(IfrtBackendHandlerTest, LoadedExecutableMetadata) { EXPECT_CALL(*executable, GetOutputShardings()) .WillOnce(Return(std::vector{op_sharding1})); - std::vector> parameter_layouts; + std::vector> parameter_layouts; parameter_layouts.push_back(std::make_unique( xla::LayoutUtil::MakeDescendingLayout(/*rank=*/1))); parameter_layouts.push_back(std::make_unique( @@ -1071,7 +1071,7 @@ TEST_P(IfrtBackendHandlerTest, LoadedExecutableMetadata) { EXPECT_CALL(*executable, GetParameterLayouts()) .WillOnce(Return(std::move(parameter_layouts))); - std::vector> output_layouts; + std::vector> output_layouts; output_layouts.push_back(std::make_unique( xla::LayoutUtil::MakeDescendingLayout(/*rank=*/2))); EXPECT_CALL(*executable, GetOutputLayouts()) diff --git a/third_party/xla/xla/python/pjrt_ifrt/pjrt_executable.h b/third_party/xla/xla/python/pjrt_ifrt/pjrt_executable.h index 6ebd1f9e903481..ce83ee0da24de1 100644 --- a/third_party/xla/xla/python/pjrt_ifrt/pjrt_executable.h +++ b/third_party/xla/xla/python/pjrt_ifrt/pjrt_executable.h @@ -116,14 +116,14 @@ class PjRtExecutable final return pjrt_executable_->GetOutputShardings(); } - absl::StatusOr>> GetParameterLayouts() - const override { + absl::StatusOr>> + GetParameterLayouts() const override { DCHECK(this); return pjrt_executable_->GetParameterLayouts(); } - absl::StatusOr>> GetOutputLayouts() - const override { + absl::StatusOr>> + GetOutputLayouts() const override { DCHECK(this); return pjrt_executable_->GetOutputLayouts(); } @@ -242,14 +242,14 @@ class PjRtLoadedExecutable final return pjrt_loaded_executable_->GetOutputShardings(); } - absl::StatusOr>> GetParameterLayouts() - const override { + absl::StatusOr>> + GetParameterLayouts() const override { DCHECK(this); return pjrt_loaded_executable_->GetParameterLayouts(); } - absl::StatusOr>> GetOutputLayouts() - const override { + absl::StatusOr>> + GetOutputLayouts() const override { DCHECK(this); return pjrt_loaded_executable_->GetOutputLayouts(); } From 318907e681c9b95ebedbc839240eb99f06a9fc8c Mon Sep 17 00:00:00 2001 From: Luke Boyer Date: Mon, 23 Sep 2024 14:48:19 -0700 Subject: [PATCH 155/483] Reverts 5a74b87f9529d7522055aae15cd4f8a816d2bd6e PiperOrigin-RevId: 677947649 --- tensorflow/compiler/mlir/lite/experimental/lrt/core/BUILD | 2 ++ tensorflow/compiler/mlir/lite/experimental/lrt/examples/BUILD | 1 + 2 files changed, 3 insertions(+) diff --git a/tensorflow/compiler/mlir/lite/experimental/lrt/core/BUILD b/tensorflow/compiler/mlir/lite/experimental/lrt/core/BUILD index 3601f00137de02..d584670b49730e 100644 --- a/tensorflow/compiler/mlir/lite/experimental/lrt/core/BUILD +++ b/tensorflow/compiler/mlir/lite/experimental/lrt/core/BUILD @@ -70,6 +70,7 @@ cc_library( cc_test( name = "model_test", srcs = ["model_test.cc"], + tags = ["no_oss"], deps = [ ":api_internal", ":graph_tools", @@ -97,6 +98,7 @@ cc_library( cc_test( name = "algo_test", srcs = ["algo_test.cc"], + tags = ["no_oss"], deps = [ ":algo", ":api_internal", diff --git a/tensorflow/compiler/mlir/lite/experimental/lrt/examples/BUILD b/tensorflow/compiler/mlir/lite/experimental/lrt/examples/BUILD index 9be77a87971794..fbb21622ab2d30 100644 --- a/tensorflow/compiler/mlir/lite/experimental/lrt/examples/BUILD +++ b/tensorflow/compiler/mlir/lite/experimental/lrt/examples/BUILD @@ -37,6 +37,7 @@ cc_shared_library( cc_test( name = "mul_op_plugin_test", srcs = ["mul_op_plugin_test.cc"], + tags = ["no_oss"], deps = [ ":mul_op_plugin", # buildcleaner: keep "//tensorflow/compiler/mlir/lite/experimental/lrt/c:lite_rt_c_api", From 275162ddbea32d55e17bc227be0acbcfc185d3a8 Mon Sep 17 00:00:00 2001 From: Kris Tonthat Date: Mon, 23 Sep 2024 15:20:00 -0700 Subject: [PATCH 156/483] Update lite.py Update documentation link for tf.lite.Optimize --- tensorflow/lite/python/lite.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/lite/python/lite.py b/tensorflow/lite/python/lite.py index 87cb2bdf8969f7..6a77e4eeea7336 100644 --- a/tensorflow/lite/python/lite.py +++ b/tensorflow/lite/python/lite.py @@ -114,7 +114,7 @@ class Optimize(enum.Enum): The default optimization strategy that enables post-training quantization. The type of post-training quantization that will be used is dependent on the other converter options supplied. Refer to the - [documentation](/lite/performance/post_training_quantization) for further + [documentation](https://ai.google.dev/edge/litert/models/post_training_quantization) for further information on the types available and how to use them. OPTIMIZE_FOR_SIZE From 2653b8eb5c90668a3223897eea75339a7d6fcebd Mon Sep 17 00:00:00 2001 From: Kris Tonthat Date: Mon, 23 Sep 2024 16:22:56 -0700 Subject: [PATCH 157/483] Update lite.py --- tensorflow/lite/python/lite.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/lite/python/lite.py b/tensorflow/lite/python/lite.py index 6a77e4eeea7336..315e0c64848c62 100644 --- a/tensorflow/lite/python/lite.py +++ b/tensorflow/lite/python/lite.py @@ -114,8 +114,8 @@ class Optimize(enum.Enum): The default optimization strategy that enables post-training quantization. The type of post-training quantization that will be used is dependent on the other converter options supplied. Refer to the - [documentation](https://ai.google.dev/edge/litert/models/post_training_quantization) for further - information on the types available and how to use them. + [documentation](https://ai.google.dev/edge/litert/models/post_training_quantization) + for further information on the types available and how to use them. OPTIMIZE_FOR_SIZE Deprecated. Does the same as DEFAULT. From 5c18da76c1097b2492b9d71e32fc1fb187f248e0 Mon Sep 17 00:00:00 2001 From: Jian Cai Date: Mon, 23 Sep 2024 16:17:24 -0700 Subject: [PATCH 158/483] [xla][tpu] Adds support for HLO value tracking in logging Add support for original_value in TPU logging. PiperOrigin-RevId: 677978849 --- third_party/xla/xla/hlo/ir/hlo_instruction.cc | 15 +-------------- .../xla/xla/hlo/ir/hlo_original_value.cc | 18 ++++++++++++++++++ .../xla/xla/hlo/ir/hlo_original_value.h | 2 ++ 3 files changed, 21 insertions(+), 14 deletions(-) diff --git a/third_party/xla/xla/hlo/ir/hlo_instruction.cc b/third_party/xla/xla/hlo/ir/hlo_instruction.cc index ad23f22909db5e..7ca026f7c94448 100644 --- a/third_party/xla/xla/hlo/ir/hlo_instruction.cc +++ b/third_party/xla/xla/hlo/ir/hlo_instruction.cc @@ -4094,20 +4094,7 @@ HloInstructionProto HloInstruction::ToProto() const { *proto.mutable_statistics_viz() = statistics_viz(); if (original_value_) { - xla::OriginalValueProto* original_value_proto = - proto.mutable_original_value(); - for (const auto& leaf : original_value_->leaves()) { - OriginalArrayProto* original_array_proto = - original_value_proto->add_leaves(); - for (const auto& index : leaf.first) { - original_array_proto->add_leaf_shape_index(index); - } - *original_array_proto->mutable_instruction_name() = - leaf.second->instruction_name; - for (const auto& index : leaf.second->shape_index) { - original_array_proto->add_shape_index(index); - } - } + *proto.mutable_original_value() = OriginalValueToProto(*original_value_); } return proto; diff --git a/third_party/xla/xla/hlo/ir/hlo_original_value.cc b/third_party/xla/xla/hlo/ir/hlo_original_value.cc index 789978d74cbf39..111297ebc879ca 100644 --- a/third_party/xla/xla/hlo/ir/hlo_original_value.cc +++ b/third_party/xla/xla/hlo/ir/hlo_original_value.cc @@ -65,4 +65,22 @@ std::string OriginalValueToString(const OriginalValue& original_value) { return OriginalValueToStringHelper(original_value, original_value.shape(), shape_index); } + +OriginalValueProto OriginalValueToProto(const OriginalValue& original_value) { + OriginalValueProto original_value_proto; + for (const auto& leaf : original_value.leaves()) { + OriginalArrayProto* original_array_proto = + original_value_proto.add_leaves(); + for (const auto& index : leaf.first) { + original_array_proto->add_leaf_shape_index(index); + } + *original_array_proto->mutable_instruction_name() = + leaf.second->instruction_name; + for (const auto& index : leaf.second->shape_index) { + original_array_proto->add_shape_index(index); + } + } + return original_value_proto; +} + } // namespace xla diff --git a/third_party/xla/xla/hlo/ir/hlo_original_value.h b/third_party/xla/xla/hlo/ir/hlo_original_value.h index a77bc8a13460c7..70d46a8734ac52 100644 --- a/third_party/xla/xla/hlo/ir/hlo_original_value.h +++ b/third_party/xla/xla/hlo/ir/hlo_original_value.h @@ -32,6 +32,8 @@ struct OriginalArray { using OriginalValue = ShapeTree>; std::string OriginalValueToString(const OriginalValue& original_value); + +OriginalValueProto OriginalValueToProto(const OriginalValue& original_value); } // namespace xla #endif // XLA_HLO_IR_HLO_ORIGINAL_VALUE_H_ From 4ba6eb0d98bfe405e2254077955fefb793a2fc4b Mon Sep 17 00:00:00 2001 From: Ionel Gog Date: Mon, 23 Sep 2024 16:38:19 -0700 Subject: [PATCH 159/483] Rename compiler.h containing IfrtIrProgram to ifrt_ir_program.h PiperOrigin-RevId: 677986233 --- third_party/xla/xla/python/ifrt/ir/BUILD | 7 ++++--- .../ir/{compiler.cc => ifrt_ir_program.cc} | 9 +++++++-- .../ifrt/ir/{compiler.h => ifrt_ir_program.h} | 18 +++++++++--------- third_party/xla/xla/python/ifrt/ir/tests/BUILD | 2 +- .../ifrt/ir/tests/executable_impl_test_lib.cc | 2 +- 5 files changed, 22 insertions(+), 16 deletions(-) rename third_party/xla/xla/python/ifrt/ir/{compiler.cc => ifrt_ir_program.cc} (83%) rename third_party/xla/xla/python/ifrt/ir/{compiler.h => ifrt_ir_program.h} (84%) diff --git a/third_party/xla/xla/python/ifrt/ir/BUILD b/third_party/xla/xla/python/ifrt/ir/BUILD index 01233d18276583..22d1979f9e3ed0 100644 --- a/third_party/xla/xla/python/ifrt/ir/BUILD +++ b/third_party/xla/xla/python/ifrt/ir/BUILD @@ -170,14 +170,15 @@ cc_library( ) cc_library( - name = "compiler", - srcs = ["compiler.cc"], - hdrs = ["compiler.h"], + name = "ifrt_ir_program", + srcs = ["ifrt_ir_program.cc"], + hdrs = ["ifrt_ir_program.h"], compatible_with = get_compatible_with_portable(), visibility = ["//xla/python/ifrt:friends"], deps = [ "//xla/python/ifrt", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", diff --git a/third_party/xla/xla/python/ifrt/ir/compiler.cc b/third_party/xla/xla/python/ifrt/ir/ifrt_ir_program.cc similarity index 83% rename from third_party/xla/xla/python/ifrt/ir/compiler.cc rename to third_party/xla/xla/python/ifrt/ir/ifrt_ir_program.cc index 8922d23ec30f22..12f2b07fc9ac67 100644 --- a/third_party/xla/xla/python/ifrt/ir/compiler.cc +++ b/third_party/xla/xla/python/ifrt/ir/ifrt_ir_program.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The OpenXLA Authors. +/* Copyright 2024 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,10 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/python/ifrt/ir/compiler.h" +#include "xla/python/ifrt/ir/ifrt_ir_program.h" #include +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "llvm/Support/Casting.h" +#include "xla/python/ifrt/compiler.h" + namespace xla { namespace ifrt { diff --git a/third_party/xla/xla/python/ifrt/ir/compiler.h b/third_party/xla/xla/python/ifrt/ir/ifrt_ir_program.h similarity index 84% rename from third_party/xla/xla/python/ifrt/ir/compiler.h rename to third_party/xla/xla/python/ifrt/ir/ifrt_ir_program.h index 15db4a974121ec..809afb7819322f 100644 --- a/third_party/xla/xla/python/ifrt/ir/compiler.h +++ b/third_party/xla/xla/python/ifrt/ir/ifrt_ir_program.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The OpenXLA Authors. +/* Copyright 2024 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_PYTHON_IFRT_IR_COMPILER_H_ -#define XLA_PYTHON_IFRT_IR_COMPILER_H_ +#ifndef XLA_PYTHON_IFRT_IR_IFRT_IR_PROGRAM_H_ +#define XLA_PYTHON_IFRT_IR_IFRT_IR_PROGRAM_H_ #include #include @@ -58,13 +58,13 @@ struct IfrtIRCompileOptions loaded_exec_binding(std::move(loaded_exec_binding)), compile_options_overrides(std::move(compile_options_overrides)) {} - // Map from logical device ids in MLIR module to runtime device ids obtained - // from IFRT client. + // Mapping from logical device ids in IFRT IR MLIR module to runtime device + // ids obtained from IFRT client. std::vector device_assignments; - // Map from `getSymName()` of declared LoadedExecutableOp in the `mlir_module` - // to pre-compiled LoadedExecutable instance. The LoadedExecutables must - // outlive the LoadedExecutable to be compiled. + // Map from symbol names of LoadedExecutableOp in the IFRT IR MLIR module + // to pre-compiled `LoadedExecutable` instance. The `LoadedExecutable`s must + // outlive the `LoadedExecutable` of the IFRT IR program. absl::flat_hash_map> loaded_exec_binding; @@ -85,4 +85,4 @@ absl::StatusOr> GetIfrtIRCompileOptions( } // namespace ifrt } // namespace xla -#endif // XLA_PYTHON_IFRT_IR_COMPILER_H_ +#endif // XLA_PYTHON_IFRT_IR_IFRT_IR_PROGRAM_H_ diff --git a/third_party/xla/xla/python/ifrt/ir/tests/BUILD b/third_party/xla/xla/python/ifrt/ir/tests/BUILD index 01ca1bff5c92e8..ab7e1250422b5a 100644 --- a/third_party/xla/xla/python/ifrt/ir/tests/BUILD +++ b/third_party/xla/xla/python/ifrt/ir/tests/BUILD @@ -93,7 +93,7 @@ cc_library( "//xla/python/ifrt", "//xla/python/ifrt:test_util", "//xla/python/ifrt/hlo:hlo_program", - "//xla/python/ifrt/ir:compiler", + "//xla/python/ifrt/ir:ifrt_ir_program", "//xla/python/ifrt/ir:sharding_param", "//xla/python/pjrt_ifrt:xla_ifrt", "//xla/service:computation_placer_hdr", diff --git a/third_party/xla/xla/python/ifrt/ir/tests/executable_impl_test_lib.cc b/third_party/xla/xla/python/ifrt/ir/tests/executable_impl_test_lib.cc index cef4bdf654c52d..a9b5741a87d043 100644 --- a/third_party/xla/xla/python/ifrt/ir/tests/executable_impl_test_lib.cc +++ b/third_party/xla/xla/python/ifrt/ir/tests/executable_impl_test_lib.cc @@ -29,7 +29,7 @@ limitations under the License. #include "xla/python/ifrt/dtype.h" #include "xla/python/ifrt/executable.h" #include "xla/python/ifrt/hlo/hlo_program.h" -#include "xla/python/ifrt/ir/compiler.h" +#include "xla/python/ifrt/ir/ifrt_ir_program.h" #include "xla/python/ifrt/ir/sharding_param.h" #include "xla/python/ifrt/ir/tests/executable_impl_test_base.h" #include "xla/python/ifrt/shape.h" From e1b495dbc59289b7bdb9eeedb8fb8580118fc448 Mon Sep 17 00:00:00 2001 From: Tom Natan Date: Mon, 23 Sep 2024 16:47:41 -0700 Subject: [PATCH 160/483] [XLA:GatherScatter] Fix updated indices_are_sorted and handle case of batching dim size overflowing indices integer type. PiperOrigin-RevId: 677989389 --- .../batched_gather_scatter_normalizer.cc | 76 +++-- .../batched_gather_scatter_normalizer_test.cc | 272 ++++++++++++++++++ 2 files changed, 330 insertions(+), 18 deletions(-) diff --git a/third_party/xla/xla/service/batched_gather_scatter_normalizer.cc b/third_party/xla/xla/service/batched_gather_scatter_normalizer.cc index 441c3b69f3da28..c3b6c6d96a250a 100644 --- a/third_party/xla/xla/service/batched_gather_scatter_normalizer.cc +++ b/third_party/xla/xla/service/batched_gather_scatter_normalizer.cc @@ -29,6 +29,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/primitive_util.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/util.h" @@ -46,42 +47,75 @@ bool IsBatchScatter(const HloScatterInstruction* scatter) { return !dims.input_batching_dims().empty(); } +// If `type` is an integral type in which `size` doesn't fit, promote it to S32 +// or S64 (depending on `size`). +PrimitiveType PromoteTypeForSize(PrimitiveType type, int64_t size) { + // Gather/Scatter should have an integral type, but we check just in case. + if (!primitive_util::IsIntegralType(type) || + primitive_util::FitsInIntegralType(size, type)) { + return type; + } + if (primitive_util::FitsInIntegralType(size, PrimitiveType::S32)) { + return PrimitiveType::S32; + } + return PrimitiveType::S64; +} + +// If `indices_batching_dims` and `updated_index_map` are both sorted, then the +// `indices_are_sorted` property is preserved. +// +// This is because each concatenated iota is monotonically increasing, sorted +// indices batching dims mean their order corresponds to the order of batching +// dims in the operand, and a sorted updated start index map means the order of +// the index vector dim corresponds to the order of operand dims. +bool GetUpdatedIndicesAreSorted(bool indices_are_sorted, + absl::Span indices_batching_dims, + absl::Span updated_index_map) { + return indices_are_sorted && absl::c_is_sorted(indices_batching_dims) && + absl::c_is_sorted(updated_index_map); +} + // Update gather/scater indices by adding fake batching iota dimensions. HloInstruction* CreateConcatIndices( HloInstruction* inst, HloInstruction* indices, int64_t index_vector_dim, absl::Span indices_batching_dims, BatchedGatherScatterNormalizer* normalizer) { - const bool index_vector_dim_on_last_dim = - index_vector_dim == indices->shape().rank(); + // The batching dim sizes might not fit in the existing element type, + // in which case we need to promote it. + PrimitiveType element_type = indices->shape().element_type(); + for (int64_t indices_batching_dim : indices_batching_dims) { + element_type = PromoteTypeForSize( + element_type, indices->shape().dimensions(indices_batching_dim)); + } + if (element_type != indices->shape().element_type()) { + Shape indices_shape = indices->shape(); + indices_shape.set_element_type(element_type); + indices = inst->parent()->AddInstruction( + HloInstruction::CreateConvert(indices_shape, indices)); + } Shape iota_shape = indices->shape(); + const bool index_vector_dim_on_last_dim = + index_vector_dim == iota_shape.rank(); if (index_vector_dim_on_last_dim) { std::vector dimensions(iota_shape.dimensions().begin(), iota_shape.dimensions().end()); dimensions.push_back(1); - iota_shape = ShapeUtil::MakeShape(iota_shape.element_type(), dimensions); + iota_shape = ShapeUtil::MakeShape(element_type, dimensions); + indices = inst->AddInstruction( + HloInstruction::CreateReshape(iota_shape, indices)); } iota_shape.set_dimensions(index_vector_dim, 1); normalizer->UpdateLayout(&iota_shape); std::vector indices_to_concat; + indices_to_concat.reserve(indices_batching_dims.size() + 1); for (int64_t indices_batching_dim : indices_batching_dims) { indices_to_concat.push_back(inst->parent()->AddInstruction( HloInstruction::CreateIota(iota_shape, indices_batching_dim))); } - if (index_vector_dim_on_last_dim) { - std::vector dimensions(indices->shape().dimensions().begin(), - indices->shape().dimensions().end()); - dimensions.push_back(1); - Shape reshape_shape = - ShapeUtil::MakeShape(indices->shape().element_type(), dimensions); - normalizer->UpdateLayout(&reshape_shape); - HloInstruction* reshaped_indices = inst->AddInstruction( - HloInstruction::CreateReshape(reshape_shape, indices)); - indices_to_concat.push_back(reshaped_indices); - } else { - indices_to_concat.push_back(indices); - } + indices_to_concat.push_back(indices); + Shape concat_shape = iota_shape; concat_shape.set_dimensions( index_vector_dim, @@ -121,7 +155,10 @@ absl::StatusOr NormalizeBatchGather( dims.index_vector_dim()); return gather->AddInstruction(HloInstruction::CreateGather( gather->shape(), gather_operand, gather_indices, updated_dims, - gather->gather_slice_sizes(), gather->indices_are_sorted())); + gather->gather_slice_sizes(), + GetUpdatedIndicesAreSorted(gather->indices_are_sorted(), + dims.start_indices_batching_dims(), + start_index_map))); } absl::StatusOr NormalizeBatchScatter( @@ -154,7 +191,10 @@ absl::StatusOr NormalizeBatchScatter( scatter_dims_to_operand_dims, dims.index_vector_dim()); return scatter->AddInstruction(HloInstruction::CreateScatter( scatter->shape(), scatter_operands, scatter_indices, scatter_updates, - scatter->to_apply(), updated_dims, scatter->indices_are_sorted(), + scatter->to_apply(), updated_dims, + GetUpdatedIndicesAreSorted(scatter->indices_are_sorted(), + dims.scatter_indices_batching_dims(), + scatter_dims_to_operand_dims), scatter->unique_indices())); } diff --git a/third_party/xla/xla/service/batched_gather_scatter_normalizer_test.cc b/third_party/xla/xla/service/batched_gather_scatter_normalizer_test.cc index 22bbdea6fb9be0..ea6995651389d3 100644 --- a/third_party/xla/xla/service/batched_gather_scatter_normalizer_test.cc +++ b/third_party/xla/xla/service/batched_gather_scatter_normalizer_test.cc @@ -79,6 +79,126 @@ ENTRY %Gather (input_tensor: f32[50,49,48,47,46,512,1024,100], start_indices: s6 )"); } +TEST_F(BatchedGatherScatterNormalizerTest, + NormalizeBatchGatherIndicesBecomeUnsorted) { + constexpr absl::string_view kModuleStr = R"( +HloModule StringifyGather, entry_computation_layout={(f32[2,3,4,512]{3,2,1,0}, s64[3,4,1]{2,1,0})->f32[3,4,5]{2,1,0}} + +ENTRY %Gather (input_tensor: f32[2,3,4,512], start_indices: s64[3,4,1]) -> f32[3,4,5] { + %input_tensor = f32[2,3,4,512]{3,2,1,0} parameter(0) + %start_indices = s64[3,4,1]{2,1,0} parameter(1) + ROOT %gather = f32[3,4,5]{2,1,0} + gather(f32[2,3,4,512]{3,2,1,0} %input_tensor, s64[3,4,1]{2,1,0} %start_indices), + offset_dims={2}, collapsed_slice_dims={1}, start_index_map={1}, operand_batching_dims={0,2}, + start_indices_batching_dims={0,1}, index_vector_dim=2, slice_sizes={1,1,1,5}, + indices_are_sorted=true +})"; + + RunAndFilecheckHloRewrite(kModuleStr, BatchedGatherScatterNormalizer(), R"( + CHECK: %[[IOTA1:.*]] = s64[3,4,1]{{.*}} iota(), iota_dimension=0 + CHECK: %[[IOTA2:.*]] = s64[3,4,1]{{.*}} iota(), iota_dimension=1 + CHECK: %[[INDICES_CONCAT:.*]] = s64[3,4,3]{{.*}} concatenate(%[[IOTA1]], %[[IOTA2]], %start_indices) + CHECK: ROOT %[[GATHER:.*]] = f32[3,4,5]{{.*}} gather( + CHECK-SAME: %input_tensor, %[[INDICES_CONCAT]]), + CHECK-SAME: offset_dims={2}, + CHECK-SAME: collapsed_slice_dims={0,1,2}, + CHECK-SAME: start_index_map={0,2,1}, + CHECK-SAME: index_vector_dim=2, + CHECK-SAME: slice_sizes={1,1,1,5} + CHECK-NOT: indices_are_sorted + )"); +} + +TEST_F(BatchedGatherScatterNormalizerTest, + NormalizeBatchGatherIndicesBecomeUnsorted2) { + constexpr absl::string_view kModuleStr = R"( +HloModule StringifyGather, entry_computation_layout={(f32[2,3,4,512]{3,2,1,0}, s64[3,2,1]{2,1,0})->f32[3,2,5]{2,1,0}} + +ENTRY %Gather (input_tensor: f32[2,3,4,512], start_indices: s64[3,2,1]) -> f32[3,2,5] { + %input_tensor = f32[2,3,4,512]{3,2,1,0} parameter(0) + %start_indices = s64[3,2,1]{2,1,0} parameter(1) + ROOT %gather = f32[3,2,5]{2,1,0} + gather(f32[2,3,4,512]{3,2,1,0} %input_tensor, s64[3,2,1]{2,1,0} %start_indices), + offset_dims={2}, collapsed_slice_dims={2}, start_index_map={2}, operand_batching_dims={0,1}, + start_indices_batching_dims={1,0}, index_vector_dim=2, slice_sizes={1,1,1,5}, + indices_are_sorted=true +})"; + + RunAndFilecheckHloRewrite(kModuleStr, BatchedGatherScatterNormalizer(), R"( + CHECK: %[[IOTA1:.*]] = s64[3,2,1]{{.*}} iota(), iota_dimension=1 + CHECK: %[[IOTA2:.*]] = s64[3,2,1]{{.*}} iota(), iota_dimension=0 + CHECK: %[[INDICES_CONCAT:.*]] = s64[3,2,3]{{.*}} concatenate(%[[IOTA1]], %[[IOTA2]], %start_indices) + CHECK: ROOT %[[GATHER:.*]] = f32[3,2,5]{{.*}} gather( + CHECK-SAME: %input_tensor, %[[INDICES_CONCAT]]), + CHECK-SAME: offset_dims={2}, + CHECK-SAME: collapsed_slice_dims={0,1,2}, + CHECK-SAME: start_index_map={0,1,2}, + CHECK-SAME: index_vector_dim=2, + CHECK-SAME: slice_sizes={1,1,1,5} + CHECK-NOT: indices_are_sorted + )"); +} + +TEST_F(BatchedGatherScatterNormalizerTest, + NormalizeBatchGatherIndicesRemainSorted) { + constexpr absl::string_view kModuleStr = R"( +HloModule StringifyGather, entry_computation_layout={(f32[2,3,4,512]{3,2,1,0}, s64[2,3,1]{2,1,0})->f32[2,3,5]{2,1,0}} + +ENTRY %Gather (input_tensor: f32[2,3,4,512], start_indices: s64[2,3,1]) -> f32[2,3,5] { + %input_tensor = f32[2,3,4,512]{3,2,1,0} parameter(0) + %start_indices = s64[2,3,1]{2,1,0} parameter(1) + ROOT %gather = f32[2,3,5]{2,1,0} + gather(f32[2,3,4,512]{3,2,1,0} %input_tensor, s64[2,3,1]{2,1,0} %start_indices), + offset_dims={2}, collapsed_slice_dims={2}, start_index_map={2}, operand_batching_dims={0,1}, + start_indices_batching_dims={0,1}, index_vector_dim=2, slice_sizes={1,1,1,5}, + indices_are_sorted=true +})"; + + RunAndFilecheckHloRewrite(kModuleStr, BatchedGatherScatterNormalizer(), R"( + CHECK: %[[IOTA1:.*]] = s64[2,3,1]{{.*}} iota(), iota_dimension=0 + CHECK: %[[IOTA2:.*]] = s64[2,3,1]{{.*}} iota(), iota_dimension=1 + CHECK: %[[INDICES_CONCAT:.*]] = s64[2,3,3]{{.*}} concatenate(%[[IOTA1]], %[[IOTA2]], %start_indices) + CHECK: ROOT %[[GATHER:.*]] = f32[2,3,5]{{.*}} gather( + CHECK-SAME: %input_tensor, %[[INDICES_CONCAT]]), + CHECK-SAME: offset_dims={2}, + CHECK-SAME: collapsed_slice_dims={0,1,2}, + CHECK-SAME: start_index_map={0,1,2}, + CHECK-SAME: index_vector_dim=2, + CHECK-SAME: slice_sizes={1,1,1,5} + CHECK-SAME: indices_are_sorted=true + )"); +} + +TEST_F(BatchedGatherScatterNormalizerTest, + NormalizeBatchGatherIndicesRemainUnsorted) { + constexpr absl::string_view kModuleStr = R"( +HloModule StringifyGather, entry_computation_layout={(f32[2,3,4,512]{3,2,1,0}, s64[2,3,1]{2,1,0})->f32[2,3,5]{2,1,0}} + +ENTRY %Gather (input_tensor: f32[2,3,4,512], start_indices: s64[2,3,1]) -> f32[2,3,5] { + %input_tensor = f32[2,3,4,512]{3,2,1,0} parameter(0) + %start_indices = s64[2,3,1]{2,1,0} parameter(1) + ROOT %gather = f32[2,3,5]{2,1,0} + gather(f32[2,3,4,512]{3,2,1,0} %input_tensor, s64[2,3,1]{2,1,0} %start_indices), + offset_dims={2}, collapsed_slice_dims={2}, start_index_map={2}, operand_batching_dims={0,1}, + start_indices_batching_dims={0,1}, index_vector_dim=2, slice_sizes={1,1,1,5}, + indices_are_sorted=false +})"; + + RunAndFilecheckHloRewrite(kModuleStr, BatchedGatherScatterNormalizer(), R"( + CHECK: %[[IOTA1:.*]] = s64[2,3,1]{{.*}} iota(), iota_dimension=0 + CHECK: %[[IOTA2:.*]] = s64[2,3,1]{{.*}} iota(), iota_dimension=1 + CHECK: %[[INDICES_CONCAT:.*]] = s64[2,3,3]{{.*}} concatenate(%[[IOTA1]], %[[IOTA2]], %start_indices) + CHECK: ROOT %[[GATHER:.*]] = f32[2,3,5]{{.*}} gather( + CHECK-SAME: %input_tensor, %[[INDICES_CONCAT]]), + CHECK-SAME: offset_dims={2}, + CHECK-SAME: collapsed_slice_dims={0,1,2}, + CHECK-SAME: start_index_map={0,1,2}, + CHECK-SAME: index_vector_dim=2, + CHECK-SAME: slice_sizes={1,1,1,5} + CHECK-NOT: indices_are_sorted + )"); +} + TEST_F(BatchedGatherScatterNormalizerTest, NormalizeBatchGatherDimSizeZero) { constexpr absl::string_view kModuleStr = R"( HloModule StringifyGather, entry_computation_layout={(f32[50,49,48,47,46,0]{5,4,3,2,1,0}, s64[10,9,8,7,5,0]{5,4,3,2,1,0})->f32[10,9,8,7,30,29,28,27,26,0]{9,8,7,6,5,4,3,2,1,0}} @@ -180,6 +300,42 @@ ENTRY %Scatter (input_tensor: f32[50,49,48,47,46,512,1024,100], scatter_indices: )"); } +TEST_F(BatchedGatherScatterNormalizerTest, + NormalizeBatchScatterIndicesRemainSorted) { + constexpr absl::string_view kModuleStr = R"( +HloModule StringifyScatter, entry_computation_layout={(f32[2,3,4,512]{3,2,1,0}, s64[2,3,1]{2,1,0}, f32[2,3,5]{2,1,0})->f32[2,3,4,512]{3,2,1,0}} + +%add_F32.v3 (lhs: f32[], rhs: f32[]) -> f32[] { + %lhs = f32[] parameter(0) + %rhs = f32[] parameter(1) + ROOT %add = f32[] add(f32[] %lhs, f32[] %rhs) +} + +ENTRY %Scatter (input_tensor: f32[2,3,4,512], scatter_indices: s64[2,3,1], updates: f32[2,3,5]) -> f32[2,3,4,512] { + %input_tensor = f32[2,3,4,512]{3,2,1,0} parameter(0) + %scatter_indices = s64[2,3,1]{2,1,0} parameter(1) + %updates = f32[2,3,5]{2,1,0} parameter(2) + ROOT %scatter = f32[2,3,4,512]{3,2,1,0} + scatter(f32[2,3,4,512]{3,2,1,0} %input_tensor, s64[2,3,1]{2,1,0} %scatter_indices, f32[2,3,5]{2,1,0} %updates), + update_window_dims={2}, inserted_window_dims={2}, scatter_dims_to_operand_dims={2}, input_batching_dims={0,1}, + scatter_indices_batching_dims={0,1}, index_vector_dim=2, indices_are_sorted=true, to_apply=%add_F32.v3 +})"; + + RunAndFilecheckHloRewrite(kModuleStr, BatchedGatherScatterNormalizer(), R"( + CHECK: %[[IOTA1:.*]] = s64[2,3,1]{{.*}} iota(), iota_dimension=0 + CHECK: %[[IOTA2:.*]] = s64[2,3,1]{{.*}} iota(), iota_dimension=1 + CHECK: %[[INDICES_CONCAT:.*]] = s64[2,3,3]{{.*}} concatenate(%[[IOTA1]], %[[IOTA2]], %scatter_indices) + CHECK: ROOT %[[SCATTER:.*]] = f32[2,3,4,512]{{.*}} scatter( + CHECK-SAME: %input_tensor, %[[INDICES_CONCAT]], %updates), + CHECK-SAME: update_window_dims={2}, + CHECK-SAME: inserted_window_dims={0,1,2}, + CHECK-SAME: scatter_dims_to_operand_dims={0,1,2}, + CHECK-SAME: index_vector_dim=2, + CHECK-SAME: indices_are_sorted=true + CHECK-SAME: to_apply=%add_F32.v3 + )"); +} + TEST_F(BatchedGatherScatterNormalizerTest, NormalizeBatchScatterDimSizeZero) { constexpr absl::string_view kModuleStr = R"( @@ -245,5 +401,121 @@ ENTRY %Gather (input_tensor: f32[50,512,1024], start_indices: s64[10,9,8,7,6,512 )"); } +TEST_F(BatchedGatherScatterNormalizerTest, + BatchingDimSizeDoesNotOverflowIndicesType) { + constexpr absl::string_view kModuleStr = R"( +HloModule StringifyGather, entry_computation_layout={(f32[2,127,512]{2,1,0}, s8[2,127,1]{2,1,0})->f32[2,127,5]{2,1,0}} + +ENTRY %Gather (input_tensor: f32[2,127,512], start_indices: s8[2,127,1]) -> f32[2,127,5] { + %input_tensor = f32[2,127,512]{2,1,0} parameter(0) + %start_indices = s8[2,127,1]{2,1,0} parameter(1) + ROOT %gather = f32[2,127,5]{2,1,0} + gather(f32[2,127,512]{2,1,0} %input_tensor, s8[2,127,1]{2,1,0} %start_indices), + offset_dims={2}, collapsed_slice_dims={}, start_index_map={2}, operand_batching_dims={0,1}, + start_indices_batching_dims={0,1}, index_vector_dim=2, slice_sizes={1,1,5} +})"; + + RunAndFilecheckHloRewrite(kModuleStr, BatchedGatherScatterNormalizer(), R"( + CHECK: %[[IOTA1:.*]] = s8[2,127,1]{{.*}} iota(), iota_dimension=0 + CHECK: %[[IOTA2:.*]] = s8[2,127,1]{{.*}} iota(), iota_dimension=1 + CHECK: %[[INDICES_CONCAT:.*]] = s8[2,127,3]{{.*}} concatenate(%[[IOTA1]], %[[IOTA2]], %start_indices) + CHECK: ROOT %[[GATHER:.*]] = f32[2,127,5]{{.*}} gather( + CHECK-SAME: %input_tensor, %[[INDICES_CONCAT]]), + CHECK-SAME: offset_dims={2}, + CHECK-SAME: collapsed_slice_dims={0,1}, + CHECK-SAME: start_index_map={0,1,2}, + CHECK-SAME: index_vector_dim=2, + CHECK-SAME: slice_sizes={1,1,5} + )"); +} + +TEST_F(BatchedGatherScatterNormalizerTest, + BatchingDimSizeOverflowsIndicesType) { + constexpr absl::string_view kModuleStr = R"( +HloModule StringifyGather, entry_computation_layout={(f32[2,128,512]{2,1,0}, s8[2,128,1]{2,1,0})->f32[2,128,5]{2,1,0}} + +ENTRY %Gather (input_tensor: f32[2,128,512], start_indices: s8[2,128,1]) -> f32[2,128,5] { + %input_tensor = f32[2,128,512]{2,1,0} parameter(0) + %start_indices = s8[2,128,1]{2,1,0} parameter(1) + ROOT %gather = f32[2,128,5]{2,1,0} + gather(f32[2,128,512]{2,1,0} %input_tensor, s8[2,128,1]{2,1,0} %start_indices), + offset_dims={2}, collapsed_slice_dims={}, start_index_map={2}, operand_batching_dims={0,1}, + start_indices_batching_dims={0,1}, index_vector_dim=2, slice_sizes={1,1,5} +})"; + + RunAndFilecheckHloRewrite(kModuleStr, BatchedGatherScatterNormalizer(), R"( + CHECK: %[[IOTA1:.*]] = s32[2,128,1]{{.*}} iota(), iota_dimension=0 + CHECK: %[[IOTA2:.*]] = s32[2,128,1]{{.*}} iota(), iota_dimension=1 + CHECK: %[[CONVERT:.*]] = s32[2,128,1]{{.*}} convert(%start_indices) + CHECK: %[[INDICES_CONCAT:.*]] = s32[2,128,3]{{.*}} concatenate(%[[IOTA1]], %[[IOTA2]], %[[CONVERT]]) + CHECK: ROOT %[[GATHER:.*]] = f32[2,128,5]{{.*}} gather( + CHECK-SAME: %input_tensor, %[[INDICES_CONCAT]]), + CHECK-SAME: offset_dims={2}, + CHECK-SAME: collapsed_slice_dims={0,1}, + CHECK-SAME: start_index_map={0,1,2}, + CHECK-SAME: index_vector_dim=2, + CHECK-SAME: slice_sizes={1,1,5} + )"); +} + +TEST_F(BatchedGatherScatterNormalizerTest, + BatchingDimSizeOverflowsIndicesTypeAndS32) { + constexpr absl::string_view kModuleStr = R"( +HloModule StringifyGather, entry_computation_layout={(f32[2147483648,2,512]{2,1,0}, s8[2147483648,2,1]{2,1,0})->f32[2147483648,2,5]{2,1,0}} + +ENTRY %Gather (input_tensor: f32[2147483648,2,512], start_indices: s8[2147483648,2,1]) -> f32[2147483648,2,5] { + %input_tensor = f32[2147483648,2,512]{2,1,0} parameter(0) + %start_indices = s8[2147483648,2,1]{2,1,0} parameter(1) + ROOT %gather = f32[2147483648,2,5]{2,1,0} + gather(f32[2147483648,2,512]{2,1,0} %input_tensor, s8[2147483648,2,1]{2,1,0} %start_indices), + offset_dims={2}, collapsed_slice_dims={}, start_index_map={2}, operand_batching_dims={0,1}, + start_indices_batching_dims={0,1}, index_vector_dim=2, slice_sizes={1,1,5} +})"; + + RunAndFilecheckHloRewrite(kModuleStr, BatchedGatherScatterNormalizer(), R"( + CHECK: %[[IOTA1:.*]] = s64[2147483648,2,1]{{.*}} iota(), iota_dimension=0 + CHECK: %[[IOTA2:.*]] = s64[2147483648,2,1]{{.*}} iota(), iota_dimension=1 + CHECK: %[[CONVERT:.*]] = s64[2147483648,2,1]{{.*}} convert(%start_indices) + CHECK: %[[INDICES_CONCAT:.*]] = s64[2147483648,2,3]{{.*}} concatenate(%[[IOTA1]], %[[IOTA2]], %[[CONVERT]]) + CHECK: ROOT %[[GATHER:.*]] = f32[2147483648,2,5]{{.*}} gather( + CHECK-SAME: %input_tensor, %[[INDICES_CONCAT]]), + CHECK-SAME: offset_dims={2}, + CHECK-SAME: collapsed_slice_dims={0,1}, + CHECK-SAME: start_index_map={0,1,2}, + CHECK-SAME: index_vector_dim=2, + CHECK-SAME: slice_sizes={1,1,5} + )"); +} + +TEST_F(BatchedGatherScatterNormalizerTest, + BatchingDimSizeOverflowsAndIndexVectorDimOnLastDim) { + constexpr absl::string_view kModuleStr = R"( +HloModule StringifyGather, entry_computation_layout={(f32[2,128,512]{2,1,0}, s8[2,128]{1,0})->f32[2,128,5]{2,1,0}} + +ENTRY %Gather (input_tensor: f32[2,128,512], start_indices: s8[2,128]) -> f32[2,128,5] { + %input_tensor = f32[2,128,512]{2,1,0} parameter(0) + %start_indices = s8[2,128]{1,0} parameter(1) + ROOT %gather = f32[2,128,5]{2,1,0} + gather(f32[2,128,512]{2,1,0} %input_tensor, s8[2,128]{1,0} %start_indices), + offset_dims={2}, collapsed_slice_dims={}, start_index_map={2}, operand_batching_dims={0,1}, + start_indices_batching_dims={0,1}, index_vector_dim=2, slice_sizes={1,1,5} +})"; + + RunAndFilecheckHloRewrite(kModuleStr, BatchedGatherScatterNormalizer(), R"( + CHECK: %[[IOTA1:.*]] = s32[2,128,1]{{.*}} iota(), iota_dimension=0 + CHECK: %[[IOTA2:.*]] = s32[2,128,1]{{.*}} iota(), iota_dimension=1 + CHECK: %[[CONVERT:.*]] = s32[2,128]{{.*}} convert(%start_indices) + CHECK: %[[RESHAPE:.*]] = s32[2,128,1]{{.*}} reshape(%[[CONVERT]]) + CHECK: %[[INDICES_CONCAT:.*]] = s32[2,128,3]{{.*}} concatenate(%[[IOTA1]], %[[IOTA2]], %[[RESHAPE]]) + CHECK: ROOT %[[GATHER:.*]] = f32[2,128,5]{{.*}} gather( + CHECK-SAME: %input_tensor, %[[INDICES_CONCAT]]), + CHECK-SAME: offset_dims={2}, + CHECK-SAME: collapsed_slice_dims={0,1}, + CHECK-SAME: start_index_map={0,1,2}, + CHECK-SAME: index_vector_dim=2, + CHECK-SAME: slice_sizes={1,1,5} + )"); +} + } // namespace } // namespace xla From 8dc568e66c82d68ab9cd610895e27e522e2f2b6b Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 23 Sep 2024 19:00:07 -0700 Subject: [PATCH 161/483] When checking for the shapes, we should take care of the dynamic shapes. PiperOrigin-RevId: 678026956 --- .../tests/tpu_sharding_identification.mlir | 44 +++++++++++++++++++ .../tpu_sharding_identification_pass.cc | 27 ++++++++++-- 2 files changed, 68 insertions(+), 3 deletions(-) diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_sharding_identification.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_sharding_identification.mlir index abb68b92146ba7..1aa574d12bf1ba 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_sharding_identification.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_sharding_identification.mlir @@ -957,3 +957,47 @@ func.func @func(%arg0: tensor) -> tensor<4xf32> { func.return %1 : tensor<4xf32> } + +// ----- +// CHECK-LABEL: func @check_AddV2_variant_shape_with_input_sharding_propagation +func.func @check_AddV2_variant_shape_with_input_sharding_propagation(%arg0: tensor, %arg1: tensor<12x384xbf16>) { + // CHECK: tf_device.cluster_func + // CHECK-SAME: input_sharding_configuration = ["sharding_info_1", "sharding_info_1"] + // CHECK-SAME: output_sharding_configuration = ["sharding_info_1"] + "tf_device.cluster_func"(%arg0, %arg1) { + func = @func, + use_spmd_for_xla_partitioning = true, num_cores_per_replica = 1 : i64 + } : (tensor, tensor<12x384xbf16>) -> tensor + func.return +} + +// CHECK-LABEL: func @func +// CHECK: {{.*}}mhlo.sharding = "sharding_info_1"{{.*}}mhlo.sharding = "sharding_info_1"{{.*}}->{{.*}}mhlo.sharding = "sharding_info_1" +func.func @func(%arg0: tensor, %arg1: tensor<12x384xbf16>) -> tensor { + %add = "tf.AddV2"(%arg0, %arg1) : (tensor, tensor<12x384xbf16>) -> tensor + %0 = "tf.XlaSharding"(%add) { _XlaSharding = "sharding_info_1"} : (tensor) -> tensor + func.return %0 : tensor +} + + + +// ----- +// CHECK-LABEL: func @check_BatchMatMul_variant_shape_without_input_sharding_propagation +func.func @check_BatchMatMul_variant_shape_without_input_sharding_propagation(%arg0: tensor, %arg1: tensor<256x384xbf16>) { + // CHECK: tf_device.cluster_func + // CHECK-SAME: input_sharding_configuration = ["", ""] + // CHECK-SAME: output_sharding_configuration = ["sharding_info_1"] + "tf_device.cluster_func"(%arg0, %arg1) { + func = @func, + use_spmd_for_xla_partitioning = true, num_cores_per_replica = 1 : i64 + } : (tensor, tensor<256x384xbf16>) -> tensor + func.return +} + +// CHECK-LABEL: func @func +// CHECK: {{.*}}mhlo.sharding = ""{{.*}}mhlo.sharding = ""{{.*}}->{{.*}}mhlo.sharding = "sharding_info_1" +func.func @func(%arg0: tensor, %arg1: tensor<256x384xbf16>) -> tensor { + %mul = "tf.BatchMatMul"(%arg0, %arg1) : (tensor, tensor<256x384xbf16>) -> tensor + %0 = "tf.XlaSharding"(%mul) { _XlaSharding = "sharding_info_1"} : (tensor) -> tensor + func.return %0 : tensor +} diff --git a/tensorflow/compiler/mlir/tf2xla/internal/passes/tpu_sharding_identification_pass.cc b/tensorflow/compiler/mlir/tf2xla/internal/passes/tpu_sharding_identification_pass.cc index 0d2475c5be5433..2d3bb7a5a3fc3b 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/passes/tpu_sharding_identification_pass.cc +++ b/tensorflow/compiler/mlir/tf2xla/internal/passes/tpu_sharding_identification_pass.cc @@ -129,13 +129,33 @@ bool BinaryOpHasTraitsForSharding(Operation* op) { return false; } -bool DoTypesHaveSameShape(Value value_0, Value value_1) { +bool DoTypesHavePartialSameShape(Value value_0, Value value_1) { auto shape_0 = mlir::dyn_cast_or_null(value_0.getType()); auto shape_1 = mlir::dyn_cast_or_null(value_1.getType()); if (shape_0 && shape_1) { - return shape_0.getShape() == shape_1.getShape(); + if (shape_0.hasStaticShape() && shape_1.hasStaticShape()) + return shape_0.getShape() == shape_1.getShape(); + int i = 0, j = 0; + while (i < shape_0.getShape().size() && j < shape_1.getShape().size()) { + if (shape_0.getShape()[i] != shape_1.getShape()[j] && + !shape_0.isDynamicDim(i) && !shape_1.isDynamicDim(j)) { + return false; + } + if (shape_0.getShape()[i] == shape_1.getShape()[j]) { + i++; + j++; + } else { + if (shape_0.isDynamicDim(i)) { + i++; + } + if (shape_1.isDynamicDim(j)) { + j++; + } + } + } + return i == shape_0.getShape().size() && j == shape_1.getShape().size(); } return false; } @@ -337,7 +357,8 @@ std::optional GetXlaShardingFromArg( } if (BinaryOpHasTraitsForSharding(owner)) { - if (DoTypesHaveSameShape(value_to_visit, owner->getResult(0))) { + if (DoTypesHavePartialSameShape(value_to_visit, + owner->getResult(0))) { next_values_to_visit.push_back(use.getOwner()->getResult(0)); continue; } From 783ae3c623428f8e7857160ab0a34f6d0a55cb91 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 23 Sep 2024 20:13:04 -0700 Subject: [PATCH 162/483] [xla][cleanup] remove commented line from EmitComplexRsqrt PiperOrigin-RevId: 678045022 --- third_party/xla/xla/service/elemental_ir_emitter.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/third_party/xla/xla/service/elemental_ir_emitter.cc b/third_party/xla/xla/service/elemental_ir_emitter.cc index 15961cff6328bc..aff56f92b15601 100644 --- a/third_party/xla/xla/service/elemental_ir_emitter.cc +++ b/third_party/xla/xla/service/elemental_ir_emitter.cc @@ -1779,7 +1779,6 @@ absl::StatusOr ElementalIrEmitter::EmitComplexRsqrt( llvm::Value* neg_one = llvm::ConstantFP::get(type, -1); llvm::Value* inf = llvm::ConstantFP::getInfinity(type); llvm::Value* nan = llvm::ConstantFP::getNaN(type); - // llvm::Value* neg_inf = llvm::ConstantFP::getInfinity(type, true); llvm::Value* a_signed_zero = llvm_ir::EmitCallToIntrinsic( llvm::Intrinsic::copysign, {zero, a}, {a->getType()}, b_); llvm::Value* b_signed_zero = llvm_ir::EmitCallToIntrinsic( From db8129f1c33e63576fc00f06301fb063e5c912d7 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 23 Sep 2024 21:02:30 -0700 Subject: [PATCH 163/483] Automated Code Change PiperOrigin-RevId: 678057081 --- tensorflow/c/experimental/filesystem/BUILD | 8 ++++++++ .../c/experimental/filesystem/modular_filesystem.cc | 12 ++++++++++++ .../c/experimental/filesystem/modular_filesystem.h | 5 +++++ .../filesystem/modular_filesystem_registration.cc | 7 +++++++ 4 files changed, 32 insertions(+) diff --git a/tensorflow/c/experimental/filesystem/BUILD b/tensorflow/c/experimental/filesystem/BUILD index d25e6e9314f088..a8df18adf63470 100644 --- a/tensorflow/c/experimental/filesystem/BUILD +++ b/tensorflow/c/experimental/filesystem/BUILD @@ -36,11 +36,19 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":filesystem_interface", + "//tensorflow/c:tf_file_statistics", "//tensorflow/c:tf_status_helper", "//tensorflow/c:tf_status_internal", + "//tensorflow/core:portable_gif_internal", "//tensorflow/core/platform:env", "//tensorflow/core/platform:errors", + "//tensorflow/core/platform:file_statistics", "//tensorflow/core/platform:status", + "//tensorflow/core/platform:strcat", + "//tensorflow/core/platform:stringpiece", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:errors", ], ) diff --git a/tensorflow/c/experimental/filesystem/modular_filesystem.cc b/tensorflow/c/experimental/filesystem/modular_filesystem.cc index d030948787acdd..7fede4ff7dc801 100644 --- a/tensorflow/c/experimental/filesystem/modular_filesystem.cc +++ b/tensorflow/c/experimental/filesystem/modular_filesystem.cc @@ -18,11 +18,23 @@ limitations under the License. #include #include +#include "absl/log/check.h" +#include "tensorflow/c/experimental/filesystem/filesystem_interface.h" #include "tensorflow/c/experimental/filesystem/modular_filesystem_registration.h" +#include "tensorflow/c/tf_file_statistics.h" +#include "tensorflow/c/tf_status.h" #include "tensorflow/c/tf_status_helper.h" #include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/file_statistics.h" +#include "tensorflow/core/platform/file_system.h" #include "tensorflow/core/platform/file_system_helper.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/strcat.h" +#include "tensorflow/core/platform/stringpiece.h" +#include "tensorflow/core/platform/types.h" #include "tsl/platform/errors.h" +#include "tsl/platform/file_system.h" // TODO(b/139060984): After all filesystems are converted, all calls to // methods from `FileSystem` will have to be replaced to calls to private diff --git a/tensorflow/c/experimental/filesystem/modular_filesystem.h b/tensorflow/c/experimental/filesystem/modular_filesystem.h index 091b84529668a5..dc2096aafb2d66 100644 --- a/tensorflow/c/experimental/filesystem/modular_filesystem.h +++ b/tensorflow/c/experimental/filesystem/modular_filesystem.h @@ -18,7 +18,12 @@ limitations under the License. #include #include "tensorflow/c/experimental/filesystem/filesystem_interface.h" +#include "tensorflow/core/platform/file_statistics.h" #include "tensorflow/core/platform/file_system.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/stringpiece.h" +#include "tensorflow/core/platform/types.h" +#include "tsl/platform/file_system.h" /// This file builds classes needed to hold a filesystem implementation in the /// modular world. Once all TensorFlow filesystems are converted to use the diff --git a/tensorflow/c/experimental/filesystem/modular_filesystem_registration.cc b/tensorflow/c/experimental/filesystem/modular_filesystem_registration.cc index ce5a4282e61091..58112a3fbe2296 100644 --- a/tensorflow/c/experimental/filesystem/modular_filesystem_registration.cc +++ b/tensorflow/c/experimental/filesystem/modular_filesystem_registration.cc @@ -16,10 +16,17 @@ limitations under the License. #include +#include "absl/log/log.h" +#include "tensorflow/c/experimental/filesystem/filesystem_interface.h" #include "tensorflow/c/experimental/filesystem/modular_filesystem.h" +#include "tensorflow/c/tf_status.h" #include "tensorflow/c/tf_status_internal.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/strcat.h" +#include "tensorflow/core/platform/stringpiece.h" +#include "tsl/platform/errors.h" namespace tensorflow { From 4cfe8de06f05e316ead960592f48c4ebc9ad9c11 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 23 Sep 2024 21:59:11 -0700 Subject: [PATCH 164/483] Automated Code Change PiperOrigin-RevId: 678070812 --- tensorflow/core/tfrt/common/BUILD | 1 + tensorflow/core/tfrt/common/pjrt_state_test.cc | 1 + 2 files changed, 2 insertions(+) diff --git a/tensorflow/core/tfrt/common/BUILD b/tensorflow/core/tfrt/common/BUILD index eb9724f707ec5f..9a2f50d932b5a7 100644 --- a/tensorflow/core/tfrt/common/BUILD +++ b/tensorflow/core/tfrt/common/BUILD @@ -176,6 +176,7 @@ tf_cc_test( "@com_google_googletest//:gtest_main", "@local_tsl//tsl/platform:status_matchers", "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/protobuf:error_codes_proto_impl_cc", "@local_xla//xla/pjrt:pjrt_client", "@local_xla//xla/pjrt/cpu:cpu_client", ], diff --git a/tensorflow/core/tfrt/common/pjrt_state_test.cc b/tensorflow/core/tfrt/common/pjrt_state_test.cc index 03dcdb7c8b9c23..c56b85eb91cd24 100644 --- a/tensorflow/core/tfrt/common/pjrt_state_test.cc +++ b/tensorflow/core/tfrt/common/pjrt_state_test.cc @@ -27,6 +27,7 @@ limitations under the License. #include "tensorflow/core/protobuf/error_codes.pb.h" #include "tsl/platform/status_matchers.h" #include "tsl/platform/statusor.h" +#include "tsl/protobuf/error_codes.pb.h" namespace { From 55fd78d335e1e1cbc2540ea8eba8a7696f0d105f Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 23 Sep 2024 22:23:12 -0700 Subject: [PATCH 165/483] Automated Code Change PiperOrigin-RevId: 678077737 --- .../core/common_runtime/base_collective_executor.cc | 9 +++++---- tensorflow/core/common_runtime/direct_session.cc | 10 +++++----- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/tensorflow/core/common_runtime/base_collective_executor.cc b/tensorflow/core/common_runtime/base_collective_executor.cc index 03d3ad3ef34d69..cef1f1a7e2b57b 100644 --- a/tensorflow/core/common_runtime/base_collective_executor.cc +++ b/tensorflow/core/common_runtime/base_collective_executor.cc @@ -331,11 +331,12 @@ void BaseCollectiveExecutor::ExecuteAsync(OpKernelContext* ctx, // Run on an unbounded work queue that can handle blocking work so as to not // starve executor threads. col_impl->Ref(); - profiler::TraceMeProducer producer("BaseCollectiveExecutor::ExecuteAsync"); + tsl::profiler::TraceMeProducer producer( + "BaseCollectiveExecutor::ExecuteAsync"); RunClosure([col_impl, col_ctx, done_safe, ctx, context_id = producer.GetContextId()]() { core::ScopedUnref unref(col_impl); - profiler::TraceMeConsumer consumer( + tsl::profiler::TraceMeConsumer consumer( [ctx, col_ctx] { string op = profiler::TraceMeOp(ctx->op_kernel().name_view(), ctx->op_kernel().type_string_view()); @@ -367,7 +368,7 @@ void BaseCollectiveExecutor::CompleteParamsAsync( // timeout callback executes, done_safe will become a no-op and the timeout // callback is responsible for invoking done() at the end. const auto is_callback_called = std::make_shared>(false); - int64_t trace_id = profiler::TraceMe::ActivityStart([cp]() { + int64_t trace_id = tsl::profiler::TraceMe::ActivityStart([cp]() { return profiler::TraceMeEncode("CollectiveExecutor::CompleteParams", {{"group_key", cp->group.group_key}, {"group_size", cp->group.group_size}}); @@ -375,7 +376,7 @@ void BaseCollectiveExecutor::CompleteParamsAsync( auto done_safe = [this, is_callback_called, cancel_mgr, trace_id, done](const Status& s) { - profiler::TraceMe::ActivityEnd(trace_id); + tsl::profiler::TraceMe::ActivityEnd(trace_id); bool called = is_callback_called->exchange(true); if (!called) { if (!s.ok() && !IsCancelled(cancel_mgr)) { diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc index 1cd597c0c21ebb..ecef91df59b923 100644 --- a/tensorflow/core/common_runtime/direct_session.cc +++ b/tensorflow/core/common_runtime/direct_session.cc @@ -517,7 +517,7 @@ Status DirectSession::RunInternal( RunState run_state(step_id, &devices_); const size_t num_executors = executors_and_keys->items.size(); - profiler::TraceMeProducer activity( + tsl::profiler::TraceMeProducer activity( // To TraceMeConsumers in ExecutorState::Process/Finish. [&] { if (options_.config.experimental().has_session_metadata()) { @@ -535,7 +535,7 @@ Status DirectSession::RunInternal( } }, tsl::profiler::ContextType::kTfExecutor, step_id, - profiler::TraceMeLevel::kInfo); + tsl::profiler::TraceMeLevel::kInfo); std::unique_ptr debugger_state; if (!run_options.debug_options().debug_tensor_watch_opts().empty()) { @@ -1488,9 +1488,9 @@ Status DirectSession::CreateExecutors( } Status DirectSession::GetOrCreateExecutors( - gtl::ArraySlice inputs, gtl::ArraySlice outputs, - gtl::ArraySlice target_nodes, ExecutorsAndKeys** executors_and_keys, - RunStateArgs* run_state_args) { + absl::Span inputs, absl::Span outputs, + absl::Span target_nodes, + ExecutorsAndKeys** executors_and_keys, RunStateArgs* run_state_args) { int64_t handle_name_counter_value = -1; if (LogMemory::IsEnabled() || run_state_args->is_partial_run) { handle_name_counter_value = handle_name_counter_.fetch_add(1); From 5d60f7c58fde119b717b1f4d7b23420f21182d15 Mon Sep 17 00:00:00 2001 From: Chun-nien Chan Date: Mon, 23 Sep 2024 22:25:14 -0700 Subject: [PATCH 166/483] Support dynamic and other kinds of broadcasting ops in fold_broadcast_to_pass PiperOrigin-RevId: 678078296 --- tensorflow/compiler/mlir/lite/BUILD | 2 +- tensorflow/compiler/mlir/lite/stablehlo/BUILD | 6 +- ...cast_to.mlir => fold_broadcasting_op.mlir} | 46 ++++++++++++- ...o_pass.cc => fold_broadcasting_op_pass.cc} | 66 +++++++++++++------ .../mlir/lite/stablehlo/transforms/passes.h | 6 +- .../compiler/mlir/lite/tf_tfl_passes.cc | 4 +- 6 files changed, 101 insertions(+), 29 deletions(-) rename tensorflow/compiler/mlir/lite/stablehlo/tests/{fold_broadcast_to.mlir => fold_broadcasting_op.mlir} (50%) rename tensorflow/compiler/mlir/lite/stablehlo/transforms/{fold_broadcast_to_pass.cc => fold_broadcasting_op_pass.cc} (78%) diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD index 0f2f7f700450a2..56c88a001d0f86 100644 --- a/tensorflow/compiler/mlir/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/BUILD @@ -1598,7 +1598,7 @@ cc_library( "//tensorflow/compiler/mlir/lite/stablehlo:build_stablehlo_composite", "//tensorflow/compiler/mlir/lite/stablehlo:compose_uniform_quantized_type_pass", "//tensorflow/compiler/mlir/lite/stablehlo:composite_lowering", - "//tensorflow/compiler/mlir/lite/stablehlo:fold_broadcast_to_pass", # buildcleaner: keep + "//tensorflow/compiler/mlir/lite/stablehlo:fold_broadcasting_op_pass", # buildcleaner: keep "//tensorflow/compiler/mlir/lite/stablehlo:legalize_stablehlo_composite_to_tfl_custom", "//tensorflow/compiler/mlir/lite/stablehlo:legalize_tf_xla_call_module_to_stablehlo_pass", "//tensorflow/compiler/mlir/lite/stablehlo:lift_callsite_loc_caller", diff --git a/tensorflow/compiler/mlir/lite/stablehlo/BUILD b/tensorflow/compiler/mlir/lite/stablehlo/BUILD index a0c3febeead92f..3d34c83ab65274 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/BUILD +++ b/tensorflow/compiler/mlir/lite/stablehlo/BUILD @@ -354,9 +354,9 @@ cc_library( ) cc_library( - name = "fold_broadcast_to_pass", + name = "fold_broadcasting_op_pass", srcs = [ - "transforms/fold_broadcast_to_pass.cc", + "transforms/fold_broadcasting_op_pass.cc", ], hdrs = [ "transforms/passes.h", @@ -1000,7 +1000,7 @@ tf_cc_binary( deps = [ ":compose_uniform_quantized_type_pass", ":fold_broadcast_pass", - ":fold_broadcast_to_pass", + ":fold_broadcasting_op_pass", ":fuse_convolution_pass", ":legalize_stablehlo_composite_to_tfl_custom", ":legalize_stablehlo_custom_call_to_composite", diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/fold_broadcast_to.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/fold_broadcasting_op.mlir similarity index 50% rename from tensorflow/compiler/mlir/lite/stablehlo/tests/fold_broadcast_to.mlir rename to tensorflow/compiler/mlir/lite/stablehlo/tests/fold_broadcasting_op.mlir index 9bbfb3dee9313f..f34d239a25ab20 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/fold_broadcast_to.mlir +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/fold_broadcasting_op.mlir @@ -1,4 +1,4 @@ -// RUN: odml-to-stablehlo-opt %s -fold-broadcast-to-pass -cse -verify-diagnostics | FileCheck %s +// RUN: odml-to-stablehlo-opt %s -fold-broadcasting-op-pass -cse -verify-diagnostics | FileCheck %s // CHECK-LABEL: @broadcast_mul0 func.func @broadcast_mul0(%arg0: tensor<5x7xf32>, %arg1: tensor<7xf32>) -> tensor<5x7xf32> { @@ -45,3 +45,47 @@ func.func @broadcast_batchmatmul(%arg0: tensor<5x30x1024xf32>) -> tensor<5x30x81 return %1 : tensor<5x30x8192xf32> // CHECK: %0 = "tfl.batch_matmul"(%arg0, %cst) <{adj_x = false, adj_y = false}> : (tensor<5x30x1024xf32>, tensor<1024x8192xf32>) -> tensor<5x30x8192xf32> } + +// CHECK-LABEL: @dym_broadcast_mul0 +func.func @dym_broadcast_mul0(%arg0: tensor, %arg1: tensor<7xf32>) -> tensor { + %0 = "tfl.shape"(%arg0): (tensor) -> tensor<2xi32> + %1 = "tfl.broadcast_to"(%arg1, %0) : (tensor<7xf32>, tensor<2xi32>) -> tensor + %2 = "tfl.mul"(%arg0, %1) {fused_activation_function = "NONE"} : (tensor, tensor) -> tensor + func.return %2 : tensor + // CHECK: %0 = tfl.mul(%arg0, %arg1) <{fused_activation_function = "NONE"}> : (tensor, tensor<7xf32>) -> tensor +} + +// CHECK-LABEL: @expanding_reshape_mul +func.func @expanding_reshape_mul(%arg0: tensor<5x7xf32>, %arg1: tensor<7xf32>) -> tensor<5x7xf32> { + %cst = mhlo.constant dense<[1, 7]> : tensor<2xi32> + %0 = "tfl.reshape"(%arg1, %cst) : (tensor<7xf32>, tensor<2xi32>) -> tensor<1x7xf32> + %1 = "tfl.mul"(%arg0, %0) {fused_activation_function = "NONE"} : (tensor<5x7xf32>, tensor<1x7xf32>) -> tensor<5x7xf32> + func.return %1 : tensor<5x7xf32> + // CHECK: %0 = tfl.mul(%arg0, %arg1) <{fused_activation_function = "NONE"}> : (tensor<5x7xf32>, tensor<7xf32>) -> tensor<5x7xf32> +} + +// CHECK-LABEL: @squeezing_reshape_mul +func.func @squeezing_reshape_mul(%arg0: tensor<5x7xf32>, %arg1: tensor<1x7xf32>) -> tensor<5x7xf32> { + %cst = mhlo.constant dense<[7]> : tensor<1xi32> + %0 = "tfl.reshape"(%arg1, %cst) : (tensor<1x7xf32>, tensor<1xi32>) -> tensor<7xf32> + %1 = "tfl.mul"(%arg0, %0) {fused_activation_function = "NONE"} : (tensor<5x7xf32>, tensor<7xf32>) -> tensor<5x7xf32> + func.return %1 : tensor<5x7xf32> + // CHECK: %0 = tfl.mul(%arg0, %arg1) <{fused_activation_function = "NONE"}> : (tensor<5x7xf32>, tensor<1x7xf32>) -> tensor<5x7xf32> +} + +// CHECK-LABEL: @expanddims_mul +func.func @expanddims_mul(%arg0: tensor<5x7xf32>, %arg1: tensor<7xf32>) -> tensor<5x7xf32> { + %cst = mhlo.constant dense<1> : tensor + %0 = "tfl.expand_dims"(%arg1, %cst) : (tensor<7xf32>, tensor) -> tensor<1x7xf32> + %1 = "tfl.mul"(%arg0, %0) {fused_activation_function = "NONE"} : (tensor<5x7xf32>, tensor<1x7xf32>) -> tensor<5x7xf32> + func.return %1 : tensor<5x7xf32> + // CHECK: %0 = tfl.mul(%arg0, %arg1) <{fused_activation_function = "NONE"}> : (tensor<5x7xf32>, tensor<7xf32>) -> tensor<5x7xf32> +} + +// CHECK-LABEL: @squeeze_mul +func.func @squeeze_mul(%arg0: tensor<5x7xf32>, %arg1: tensor<1x7xf32>) -> tensor<5x7xf32> { + %0 = "tfl.squeeze"(%arg1) {squeeze_dims = [0]} : (tensor<1x7xf32>) -> tensor<7xf32> + %1 = "tfl.mul"(%arg0, %0) {fused_activation_function = "NONE"} : (tensor<5x7xf32>, tensor<7xf32>) -> tensor<5x7xf32> + func.return %1 : tensor<5x7xf32> + // CHECK: %0 = tfl.mul(%arg0, %arg1) <{fused_activation_function = "NONE"}> : (tensor<5x7xf32>, tensor<1x7xf32>) -> tensor<5x7xf32> +} \ No newline at end of file diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/fold_broadcast_to_pass.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/fold_broadcasting_op_pass.cc similarity index 78% rename from tensorflow/compiler/mlir/lite/stablehlo/transforms/fold_broadcast_to_pass.cc rename to tensorflow/compiler/mlir/lite/stablehlo/transforms/fold_broadcasting_op_pass.cc index 204102c9e080a5..f7153e8ecd9564 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/fold_broadcast_to_pass.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/fold_broadcasting_op_pass.cc @@ -59,7 +59,7 @@ class ConvertResultsBroadcastableShapeOp : public RewritePattern { // Determine op with shapes is valid. TODO: @lukeboyer - Move the // `TFL_OperandsHaveSameShapesOrBroadcastableShape` runtime verification trait // into a standard (not runtime verification) trait and change this function to -// use only that interface. Curently there is no way to query derived runtime +// use only that interface. Currently there is no way to query derived runtime // verification traits. bool IsRankSupported(Operation* op) { // These ops have no rank constraints. @@ -75,6 +75,14 @@ bool IsRankSupported(Operation* op) { return llvm::cast(op->getResultTypes()[0]).getRank() <= 4; } +// Returns true when the op may be a broadcasting op. Broadcasting op is not +// limited to TFL::BroadcastToOp, but also other ops that may change the shape +// of a tensor to match the shape of another operand. +bool MayBeBroadcastingOp(Operation* op) { + return op && llvm::isa(op); +} + LogicalResult ConvertResultsBroadcastableShapeOp::matchAndRewrite( Operation* op, PatternRewriter& rewriter) const { if (op->hasTrait()) @@ -93,7 +101,7 @@ LogicalResult ConvertResultsBroadcastableShapeOp::RewriteOp( // Check that the result shape is fully defined. auto result_type = mlir::dyn_cast_or_null(op->getResultTypes().front()); - if (!result_type || !result_type.hasStaticShape()) return failure(); + if (!result_type) return failure(); if (!IsRankSupported(op)) { return failure(); @@ -102,19 +110,39 @@ LogicalResult ConvertResultsBroadcastableShapeOp::RewriteOp( bool changed = false; for (uint64_t i = 0, e = op->getNumOperands(); i < e; ++i) { // Check that the i'th operand is a broadcast. - auto broadcast = llvm::dyn_cast_or_null( - op->getOpOperand(i).get().getDefiningOp()); - if (!broadcast) continue; + auto broadcast = op->getOpOperand(i).get().getDefiningOp(); + if (!broadcast || !MayBeBroadcastingOp(broadcast)) { + continue; + } + + auto broadcast_input = broadcast->getOperand(0); + if (!broadcast_input) { + continue; + } // Check that the operand of the broadcast has fully defined shape. - auto broadcast_arg_type = mlir::dyn_cast_or_null( - broadcast.getInput().getType()); - if (!broadcast_arg_type || !broadcast_arg_type.hasStaticShape()) continue; + // Fusing dynamic broadcasting op (non static broadcast_arg_type shape) + // is experimental and theoretically unsafe, because checking equality on + // unknown dimensions in broadcasted shape is not reliable. + // TODO: Full dynamism support with symbolic shape comparisons. + auto broadcast_arg_type = + mlir::cast(broadcast_input.getType()); + if (!broadcast_arg_type) { + continue; + } // Check that the other argument has fully defined shape. - auto argument_type = mlir::dyn_cast_or_null( - op->getOpOperand(1 - i).get().getType()); - if (!argument_type || !argument_type.hasStaticShape()) continue; + auto argument = op->getOperand(1 - i); + auto argument_type = mlir::cast(argument.getType()); + // When two operands are both dynamic broadcasting op, it has high chance + // that the model is doing explicitly broadcasting. In this case, removing + // either broadcasting op may result in incorrect output shape in the + // runtime. + // TODO: Full dynamism support with symbolic shape comparisons. + if (!argument_type || (!broadcast_arg_type.hasStaticShape() && + MayBeBroadcastingOp(argument.getDefiningOp()))) { + continue; + } // Get the unbroadcasted shapes in the operand order. std::array, 2> operand_shapes; @@ -134,7 +162,7 @@ LogicalResult ConvertResultsBroadcastableShapeOp::RewriteOp( // Update the operand of the op to be the operand of the broadcast. rewriter.modifyOpInPlace( - op, [&]() { op->getOpOperand(i).set(broadcast.getInput()); }); + op, [&]() { op->getOpOperand(i).set(broadcast->getOperand(0)); }); changed = true; } return success(changed); @@ -210,12 +238,12 @@ LogicalResult ConvertResultsBroadcastableBatchMatMulShapeOp::RewriteOp( } // namespace -class FoldBroadcastToPass - : public PassWrapper> { +class FoldBroadcastingOpPass + : public PassWrapper> { public: - StringRef getArgument() const final { return "fold-broadcast-to-pass"; } + StringRef getArgument() const final { return "fold-broadcasting-op-pass"; } StringRef getDescription() const final { - return "Folds tfl.BroadcastTo nodes with subsequent ops"; + return "Folds TFL broadcasting/shape changing nodes with subsequent ops"; } void runOnOperation() override { @@ -233,11 +261,11 @@ class FoldBroadcastToPass }; // TODO(weiyiw): Consider having this as canonicalization? -std::unique_ptr> CreateFoldBroadcastToPass() { - return std::make_unique(); +std::unique_ptr> CreateFoldBroadcastingOpPass() { + return std::make_unique(); } -static PassRegistration pass; +static PassRegistration pass; } // namespace odml } // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h b/tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h index 331505e2445e87..defaf17af16a97 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h @@ -70,9 +70,9 @@ std::unique_ptr> CreateLegalizeChloToTflPass(); // Rewrites MHLO in preparation for tflite legalization. std::unique_ptr> CreatePrepareHloPass(); -// Folds tfl.BroadcastTo nodes with subsequent ops that supports implicit -// broadcasting. -std::unique_ptr> CreateFoldBroadcastToPass(); +// Folds TFL broadcasting/shape changing nodes with subsequent ops that +// supports implicit broadcasting. +std::unique_ptr> CreateFoldBroadcastingOpPass(); // Adds the HLO to TF rewrite patterns to the specified pattern list. void PopulateLegalizeHloToTfPatterns(RewritePatternSet* patterns, diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc b/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc index 3c46ef3532aecf..d415c9a63d8473 100644 --- a/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc +++ b/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc @@ -287,12 +287,12 @@ void AddPostQuantizationStableHloToTfPasses( // TODO: b/354280588 - Rewrite this pass into a pattern in PrepareHloPass. pass_manager.addPass(mlir::odml::CreateUnfoldSplatConstantPass()); pass_manager.addPass(mlir::odml::CreateLegalizeHloToTfLitePass()); - // Folds tfl.BroadcastTo ops with subsequent ops if they have built in + // Folds TFL broadcasting ops with subsequent ops if they have built in // broadcasting support. This needs to be run immediately after HLO->TFL // legalization, otherwise the newly generated TFL broadcast ops can fold // and materialize the weights. pass_manager.addNestedPass( - mlir::odml::CreateFoldBroadcastToPass()); + mlir::odml::CreateFoldBroadcastingOpPass()); } // folds tf.BroadcastTo ops with subsequent ops if they have built in // broadcasting support. This needs to be run immediately after HLO->TF From 01823faa0611beee7ef3e61bad5da54ac8ce082a Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 23 Sep 2024 23:06:26 -0700 Subject: [PATCH 167/483] Automated Code Change PiperOrigin-RevId: 678089465 --- tensorflow/core/lib/strings/proto_text_util.h | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tensorflow/core/lib/strings/proto_text_util.h b/tensorflow/core/lib/strings/proto_text_util.h index 44559578da02e9..af288e0738011f 100644 --- a/tensorflow/core/lib/strings/proto_text_util.h +++ b/tensorflow/core/lib/strings/proto_text_util.h @@ -85,8 +85,7 @@ class ProtoTextOutput { // Appends a string value, like my_field: "abc123". void AppendString(const char field_name[], const string& value) { - AppendFieldAndValue( - field_name, StrCat("\"", ::tensorflow::str_util::CEscape(value), "\"")); + AppendFieldAndValue(field_name, StrCat("\"", absl::CEscape(value), "\"")); } // Appends a string value, like my_field: "abc123", but only if value is not From 5d7b872381133d64949dd783af4a38c3052ee956 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 23 Sep 2024 23:08:57 -0700 Subject: [PATCH 168/483] Automated Code Change PiperOrigin-RevId: 678090109 --- tensorflow/core/grappler/optimizers/data/fusion_utils.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/grappler/optimizers/data/fusion_utils.h b/tensorflow/core/grappler/optimizers/data/fusion_utils.h index 19b7002dcd8562..f7da097d4b1b09 100644 --- a/tensorflow/core/grappler/optimizers/data/fusion_utils.h +++ b/tensorflow/core/grappler/optimizers/data/fusion_utils.h @@ -34,7 +34,7 @@ using SetFunctionSignatureFn = std::function; -using StringCollection = gtl::InlinedVector; +using StringCollection = absl::InlinedVector; // These functions are invoked with nodes from second function that were // previously taking arguments as input. The `arg_num` tells which From 20e44bf316ce26129cb0107da225928be4663c83 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 23 Sep 2024 23:19:38 -0700 Subject: [PATCH 169/483] Automated Code Change PiperOrigin-RevId: 678092989 --- tensorflow/core/profiler/BUILD | 5 +++++ tensorflow/core/profiler/tfprof_options.cc | 6 +++++- tensorflow/core/profiler/tfprof_options.h | 1 + 3 files changed, 11 insertions(+), 1 deletion(-) diff --git a/tensorflow/core/profiler/BUILD b/tensorflow/core/profiler/BUILD index 8d2d6f030b670a..c5477810bf0fde 100644 --- a/tensorflow/core/profiler/BUILD +++ b/tensorflow/core/profiler/BUILD @@ -147,6 +147,11 @@ cc_library( deps = [ ":protos_all_cc", "//tensorflow/core:framework_headers_lib", + "//tensorflow/core:framework_types_hdr", + "//tensorflow/core:portable_gif_internal", + "//tensorflow/core/lib/core:status", + "//tensorflow/core/platform:status", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", ], diff --git a/tensorflow/core/profiler/tfprof_options.cc b/tensorflow/core/profiler/tfprof_options.cc index 0f4ec58c540236..595d4190997baa 100644 --- a/tensorflow/core/profiler/tfprof_options.cc +++ b/tensorflow/core/profiler/tfprof_options.cc @@ -15,9 +15,13 @@ limitations under the License. #include "tensorflow/core/profiler/tfprof_options.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" #include "absl/strings/str_split.h" -#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/types.h" #include "tensorflow/core/profiler/tfprof_options.pb.h" namespace tensorflow { diff --git a/tensorflow/core/profiler/tfprof_options.h b/tensorflow/core/profiler/tfprof_options.h index d8704dd736bab7..c1f13ebf355b27 100644 --- a/tensorflow/core/profiler/tfprof_options.h +++ b/tensorflow/core/profiler/tfprof_options.h @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/core/framework/types.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { namespace tfprof { From 565d222a34b3df7d91a225389cdebac0d8d77959 Mon Sep 17 00:00:00 2001 From: Harsha H S Date: Mon, 23 Sep 2024 23:33:47 -0700 Subject: [PATCH 170/483] PR #17507: [ROCm] Fix build break due to 1c21b0bba Imported from GitHub PR https://github.com/openxla/xla/pull/17507 Copybara import of the project: -- f3aca25a7f7a8e10e7534501c76914a95cae50e7 by Harsha HS : [ROCm] Fix build break due to 1c21b0bba Merging this change closes #17507 PiperOrigin-RevId: 678097290 --- third_party/xla/xla/stream_executor/rocm/rocm_runtime.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_runtime.h b/third_party/xla/xla/stream_executor/rocm/rocm_runtime.h index 83148914b93977..b1a197fe0643bd 100644 --- a/third_party/xla/xla/stream_executor/rocm/rocm_runtime.h +++ b/third_party/xla/xla/stream_executor/rocm/rocm_runtime.h @@ -21,7 +21,7 @@ limitations under the License. #include #include "absl/status/statusor.h" -#include "rocm/include/hip/hip_runtime.h" +#include "xla/stream_executor/rocm/rocm_driver_wrapper.h" namespace stream_executor::gpu { From 6de155e82f63ef3cb66a36015ca67d352a229789 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 23 Sep 2024 23:37:15 -0700 Subject: [PATCH 171/483] Automated Code Change PiperOrigin-RevId: 678098311 --- .../stablehlo/transforms/legalize_hlo_conversions/conv_util.h | 1 + .../stablehlo/transforms/legalize_hlo_conversions/custom_call.h | 1 + .../stablehlo/transforms/legalize_hlo_conversions/dot_general.cc | 1 - .../stablehlo/transforms/legalize_hlo_conversions/dot_general.h | 1 + .../transforms/legalize_hlo_conversions/get_dimension_size.h | 1 + .../lite/stablehlo/transforms/legalize_hlo_conversions/iota.h | 1 + .../transforms/legalize_hlo_conversions/op_util_common.h | 1 + .../lite/stablehlo/transforms/legalize_hlo_conversions/reduce.cc | 1 + .../lite/stablehlo/transforms/legalize_hlo_conversions/reduce.h | 1 + .../lite/stablehlo/transforms/legalize_hlo_conversions/sort.h | 1 + .../lite/stablehlo/transforms/legalize_hlo_conversions/util.h | 1 + .../lite/stablehlo/transforms/legalize_hlo_conversions/while.h | 1 + 12 files changed, 11 insertions(+), 1 deletion(-) diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/conv_util.h b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/conv_util.h index ed8b06e036d816..fe9664c13cdccb 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/conv_util.h +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/conv_util.h @@ -24,6 +24,7 @@ limitations under the License. #include "llvm/ADT/Sequence.h" #include "llvm/ADT/SmallVector.h" #include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/op_util_common.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/custom_call.h b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/custom_call.h index 679737b9a25fbf..c7c3bddeacfb21 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/custom_call.h +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/custom_call.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_CUSTOM_CALL_H_ #define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_CUSTOM_CALL_H_ +#include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project namespace mlir { diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/dot_general.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/dot_general.cc index 3c139e7dbbcdd3..940c75256b9e75 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/dot_general.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/dot_general.cc @@ -35,7 +35,6 @@ limitations under the License. #include "mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project #include "mlir/IR/Location.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project -#include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/IR/ValueRange.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/dot_general.h b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/dot_general.h index 2ea7c96dfbae08..91df1b63e76a7c 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/dot_general.h +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/dot_general.h @@ -19,6 +19,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_DOT_GENERAL_H_ #include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/get_dimension_size.h b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/get_dimension_size.h index 74081e1e04716e..6cd637303b3a7e 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/get_dimension_size.h +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/get_dimension_size.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_GET_DIMENSION_SIZE_H_ #define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_GET_DIMENSION_SIZE_H_ +#include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project namespace mlir::odml { diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/iota.h b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/iota.h index a53bdeda2a2097..7d4f76bd3f8b6b 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/iota.h +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/iota.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_IOTA_H_ #define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_IOTA_H_ +#include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project namespace mlir::odml { diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/op_util_common.h b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/op_util_common.h index 3c2c8ae5ced600..16a5c293b0e989 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/op_util_common.h +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/op_util_common.h @@ -21,6 +21,7 @@ limitations under the License. #include "llvm/ADT/SmallVector.h" #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project namespace mlir::odml { diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/reduce.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/reduce.cc index ae784414e4bac0..f237a7168e5660 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/reduce.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/reduce.cc @@ -34,6 +34,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/hlo_matchers.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/util.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" namespace mlir { diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/reduce.h b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/reduce.h index 82c4f88937d061..3bf03aec97dcfe 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/reduce.h +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/reduce.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_REDUCE_H_ #define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_REDUCE_H_ +#include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" // IWYU pragma: keep #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" // IWYU pragma: keep diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/sort.h b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/sort.h index 9bbb1f3fde06ab..c293bad98cf4ef 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/sort.h +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/sort.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_SORT_H_ #define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_SORT_H_ +#include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project namespace mlir::odml { diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/util.h b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/util.h index 01c619cbbf6178..28661d299e03df 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/util.h +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/util.h @@ -29,6 +29,7 @@ limitations under the License. #include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/IR/Region.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/while.h b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/while.h index 129b19388821c9..3b3022153b2d43 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/while.h +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/while.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_WHILE_H_ #define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_WHILE_H_ +#include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" From 66fa6375247e2b416ae9d24f6a2f4f2b27a4c3d7 Mon Sep 17 00:00:00 2001 From: Henning Becker Date: Mon, 23 Sep 2024 23:45:24 -0700 Subject: [PATCH 172/483] Remove unused gpu_types dependency from topk_kernel_gpu target This change also uncovered that `:topk_kernel_gpu` was relying on `:gpu_types` for its dependency on CUDA headers. `:topk_kernel_gpu` depends on CUDA headers because clang automatically includes `cuda_runtime.h` in all CUDA compilation units. To fix that properly I'm adding the dependency to `:cuda_headers` inside the `cuda_library` macro. PiperOrigin-RevId: 678100189 --- third_party/gpus/cuda/build_defs.bzl.tpl | 5 ++++- third_party/gpus/cuda/hermetic/BUILD.tpl | 10 ++++++++++ .../tsl/third_party/gpus/cuda/build_defs.bzl.tpl | 5 ++++- .../tsl/third_party/gpus/cuda/hermetic/BUILD.tpl | 10 ++++++++++ third_party/xla/xla/service/gpu/kernels/BUILD | 1 - .../xla/xla/service/gpu/kernels/topk_kernel.cu.h | 1 - 6 files changed, 28 insertions(+), 4 deletions(-) diff --git a/third_party/gpus/cuda/build_defs.bzl.tpl b/third_party/gpus/cuda/build_defs.bzl.tpl index 2faabefe081f4b..6c1b68ffb77bcf 100644 --- a/third_party/gpus/cuda/build_defs.bzl.tpl +++ b/third_party/gpus/cuda/build_defs.bzl.tpl @@ -149,11 +149,14 @@ def cuda_header_library( **kwargs ) -def cuda_library(copts = [], tags = [],**kwargs): +def cuda_library(copts = [], tags = [], deps = [], **kwargs): """Wrapper over cc_library which adds default CUDA options.""" native.cc_library( copts = cuda_default_copts() + copts, tags = tags + ["gpu"], + deps = deps + if_cuda_is_configured([ + "@local_config_cuda//cuda:implicit_cuda_headers_dependency", + ]), **kwargs ) diff --git a/third_party/gpus/cuda/hermetic/BUILD.tpl b/third_party/gpus/cuda/hermetic/BUILD.tpl index 5d9a9da3c967d8..da34a336c5a2d4 100644 --- a/third_party/gpus/cuda/hermetic/BUILD.tpl +++ b/third_party/gpus/cuda/hermetic/BUILD.tpl @@ -69,6 +69,16 @@ cc_library( ":nvjitlink_headers"], ) +# This target is needed by the `cuda_library` rule. We can't implicitly +# depend on `:cuda_headers` directly since the user may explicit depend +# on `:cuda_headers` and duplicated dependencies are not allowed in Bazel. +# There is also no good way to deduplicate dependencies, but an alias works +# just fine. +alias( + name = "implicit_cuda_headers_dependency", + actual = ":cuda_headers", +) + cc_library( name = "cudart_static", srcs = ["@cuda_cudart//:static"], diff --git a/third_party/xla/third_party/tsl/third_party/gpus/cuda/build_defs.bzl.tpl b/third_party/xla/third_party/tsl/third_party/gpus/cuda/build_defs.bzl.tpl index 2faabefe081f4b..6c1b68ffb77bcf 100644 --- a/third_party/xla/third_party/tsl/third_party/gpus/cuda/build_defs.bzl.tpl +++ b/third_party/xla/third_party/tsl/third_party/gpus/cuda/build_defs.bzl.tpl @@ -149,11 +149,14 @@ def cuda_header_library( **kwargs ) -def cuda_library(copts = [], tags = [],**kwargs): +def cuda_library(copts = [], tags = [], deps = [], **kwargs): """Wrapper over cc_library which adds default CUDA options.""" native.cc_library( copts = cuda_default_copts() + copts, tags = tags + ["gpu"], + deps = deps + if_cuda_is_configured([ + "@local_config_cuda//cuda:implicit_cuda_headers_dependency", + ]), **kwargs ) diff --git a/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/BUILD.tpl b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/BUILD.tpl index 5d9a9da3c967d8..da34a336c5a2d4 100644 --- a/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/BUILD.tpl +++ b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/BUILD.tpl @@ -69,6 +69,16 @@ cc_library( ":nvjitlink_headers"], ) +# This target is needed by the `cuda_library` rule. We can't implicitly +# depend on `:cuda_headers` directly since the user may explicit depend +# on `:cuda_headers` and duplicated dependencies are not allowed in Bazel. +# There is also no good way to deduplicate dependencies, but an alias works +# just fine. +alias( + name = "implicit_cuda_headers_dependency", + actual = ":cuda_headers", +) + cc_library( name = "cudart_static", srcs = ["@cuda_cudart//:static"], diff --git a/third_party/xla/xla/service/gpu/kernels/BUILD b/third_party/xla/xla/service/gpu/kernels/BUILD index 986ed6486dc821..5c517476917932 100644 --- a/third_party/xla/xla/service/gpu/kernels/BUILD +++ b/third_party/xla/xla/service/gpu/kernels/BUILD @@ -177,7 +177,6 @@ gpu_kernel_library( compatible_with = [], deps = [ "//xla:types", - "//xla/stream_executor/gpu:gpu_types_header", "//xla/tsl/lib/math:math_util", ], ) diff --git a/third_party/xla/xla/service/gpu/kernels/topk_kernel.cu.h b/third_party/xla/xla/service/gpu/kernels/topk_kernel.cu.h index c5390649ac9945..5a68b56efd7ef7 100644 --- a/third_party/xla/xla/service/gpu/kernels/topk_kernel.cu.h +++ b/third_party/xla/xla/service/gpu/kernels/topk_kernel.cu.h @@ -25,7 +25,6 @@ limitations under the License. #include #include "xla/service/gpu/kernels/topk_kernel_common.h" -#include "xla/stream_executor/gpu/gpu_types.h" #include "xla/tsl/lib/math/math_util.h" #if GOOGLE_CUDA From e8f9972febd4b60c9d158b8af7aeb700dd550e94 Mon Sep 17 00:00:00 2001 From: Henning Becker Date: Mon, 23 Sep 2024 23:47:30 -0700 Subject: [PATCH 173/483] Add missing tag to cuda_collectives build target The missing tag was making the dependency violation presubmit check fail. Fixes #17508 PiperOrigin-RevId: 678100814 --- third_party/xla/xla/stream_executor/cuda/BUILD | 1 + 1 file changed, 1 insertion(+) diff --git a/third_party/xla/xla/stream_executor/cuda/BUILD b/third_party/xla/xla/stream_executor/cuda/BUILD index 9b900305327f4f..68c13a736099fb 100644 --- a/third_party/xla/xla/stream_executor/cuda/BUILD +++ b/third_party/xla/xla/stream_executor/cuda/BUILD @@ -228,6 +228,7 @@ cuda_only_cc_library( cuda_only_cc_library( name = "cuda_collectives", hdrs = ["cuda_collectives.h"], + tags = ["gpu"], deps = if_nccl( [":cuda_collectives_impl"], [":cuda_collectives_stub"], From 4bc72f621fc61ee96ff53d7dab0d554505ec1d12 Mon Sep 17 00:00:00 2001 From: Farzin Houshmand Date: Tue, 24 Sep 2024 00:40:01 -0700 Subject: [PATCH 174/483] [XLA:UNSTACKER] Fix a bug in HloUnstacker that causes it to unstack while loops that are not unstackable. The bug is that the HloUnstacker does not check if the root operand of a while loop is gte of a non-loop instruction. We restrict unstackable operands to those that are either gte of while parameter or gte of another nested loop. PiperOrigin-RevId: 678116048 --- third_party/xla/xla/service/hlo_unstacker.cc | 31 +++++++---- .../xla/xla/service/hlo_unstacker_test.cc | 55 +++++++++++++++++++ 2 files changed, 76 insertions(+), 10 deletions(-) diff --git a/third_party/xla/xla/service/hlo_unstacker.cc b/third_party/xla/xla/service/hlo_unstacker.cc index 03c6f5c2eb6356..2ca37c0193861c 100644 --- a/third_party/xla/xla/service/hlo_unstacker.cc +++ b/third_party/xla/xla/service/hlo_unstacker.cc @@ -530,19 +530,30 @@ bool CanUnstackWhileOperand(const HloInstruction* while_instr, return false; } - const HloInstruction* root_operand = - while_instr->while_body()->root_instruction()->operand(index); + HloInstruction* root_operand = + while_instr->while_body()->root_instruction()->mutable_operand(index); if (root_operand == nullptr) { return false; } - if (Match(root_operand, match::GetTupleElement(match::While()))) { - VLOG(3) << "Faced a gte originating from loop: " - << root_operand->ToString(); - bool loop_feeding_root_changes_collected = CanUnstackWhileOperand( - root_operand->operand(0), unstacker, root_operand->tuple_index()); - if (!loop_feeding_root_changes_collected) { - VLOG(3) << "Failed: loop " << root_operand->operand(0)->name() - << " output at " << index << " is not unstackable"; + + HloInstruction* gte_operand = nullptr; + // Currently, we only support unstacking of while operands that either: + // 1. Are parameters of the while_body. + // 2. Are get-tuple-elements of another while instruction. + if (Match(root_operand, match::GetTupleElement(match::Op(>e_operand)))) { + if (Match(gte_operand, match::While())) { + VLOG(3) << "Faced a gte originating from loop: " + << root_operand->ToString(); + bool loop_feeding_root_changes_collected = CanUnstackWhileOperand( + root_operand->operand(0), unstacker, root_operand->tuple_index()); + if (!loop_feeding_root_changes_collected) { + VLOG(3) << "Failed: loop " << root_operand->operand(0)->name() + << " output at " << index << " is not unstackable"; + return false; + } + } else if (!Match(gte_operand, match::Parameter().WithParameterNum(0))) { + VLOG(3) << "Failed: root operand of while_body at " << index + << " is not a parameter"; return false; } } diff --git a/third_party/xla/xla/service/hlo_unstacker_test.cc b/third_party/xla/xla/service/hlo_unstacker_test.cc index 3b00f9236a1ae7..590597a99a92c8 100644 --- a/third_party/xla/xla/service/hlo_unstacker_test.cc +++ b/third_party/xla/xla/service/hlo_unstacker_test.cc @@ -99,6 +99,61 @@ TEST_F(UnstackerTest, UnstackDSFusionPattern) { std::nullopt, false)); } +TEST_F(UnstackerTest, NotUnstackDSFusionPattern) { + std::string hlo_string = R"( + HloModule SimpleLoop + %fused_computation.slice (param_0.51117: s8[3,128,128], p1: s32[]) -> s8[128,128] { + %param_0.51117 = s8[3,128,128] parameter(0) + p1 = s32[] parameter(1) + %constant.85694 = s32[] constant(0) + %dynamic-slice.22040 = s8[1,128,128] dynamic-slice(s8[3,128,128] %param_0.51117, p1, s32[] %constant.85694, s32[] %constant.85694), dynamic_slice_sizes={1,128,128} + ROOT %bitcast.31250 = s8[128,128] bitcast(s8[1,128,128] %dynamic-slice.22040) + } + + %fused_computation.tuple { + %param_0.51117 = s8[3,128,128] parameter(0) + mult = multiply(param_0.51117, param_0.51117) + ROOT out = tuple(param_0.51117, mult) + } + + %while.body (wide_param: (s32[], bf16[8,128], s8[3,128,128])) -> (s32[], bf16[8,128], s8[3,128,128]) { + wide_p = (s32[], bf16[8,128], s8[3,128,128]) parameter(0) + i = s32[] get-tuple-element(wide_p), index=0 + p0 = bf16[8,128] get-tuple-element(wide_p), index=1 + p1 = s8[3,128,128] get-tuple-element(wide_p), index=2 + one = s32[] constant(1) + inc = s32[] add(i, one) + %fusion.67830 = s8[128,128] fusion(s8[3,128,128] p1, i), kind=kLoop, calls=%fused_computation.slice + conv = bf16[8,128] convolution(bf16[8,128] p0, s8[128,128] %fusion.67830), dim_labels=bf_io->bf + fusion_mult = (s8[3,128,128], s8[3,128,128]) fusion(s8[3,128,128] p1), kind=kLoop, calls=%fused_computation.tuple + mult = s8[3,128,128] get-tuple-element(fusion_mult), index=1 + ROOT out = (s32[], bf16[8,128], s8[3,128,128]) tuple(inc, conv, mult) + } + + %while.cond (wide_param: (s32[], bf16[8,128], s8[3,128,128])) -> pred[] { + wide_p = (s32[], bf16[8,128], s8[3,128,128]) parameter(0) + i = s32[] get-tuple-element(wide_p), index=0 + %constant.12857 = s32[] constant(3) + ROOT %compare.1921 = pred[]{:T(512)} compare(s32[] i, s32[] %constant.12857), direction=LT + } + + ENTRY main { + p0 = s8[3,128,128] parameter(0) + p1 = bf16[8,128] parameter(1) + init = s32[] constant(0) + while.input = (s32[], bf16[8,128], s8[3,128,128]) tuple(init, p1, p0) + while.out = (s32[], bf16[8,128], s8[3,128,128]) while(while.input), condition=%while.cond , body=%while.body + while_use = s8[3,128,128] get-tuple-element(while.out), index=2 + ROOT out = bf16[8,128] get-tuple-element(while.out), index=1 + } + )"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + auto original = module->Clone(); + TF_ASSERT_OK_AND_ASSIGN(bool unstacked, HloUnstacker().Run(module.get())); + EXPECT_FALSE(unstacked); +} + TEST_F(UnstackerTest, UnstackReduceFusionPattern) { std::string hlo_string = R"( HloModule SimpleLoop From 8bd9ae5d963bcd5559fb45f51f7de22181dfe855 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 24 Sep 2024 00:50:24 -0700 Subject: [PATCH 175/483] Automated Code Change PiperOrigin-RevId: 678118952 --- tensorflow/lite/kernels/internal/BUILD | 4 ++++ tensorflow/lite/kernels/internal/reference/comparisons.cc | 4 ++++ tensorflow/lite/kernels/internal/reference/comparisons.h | 2 ++ 3 files changed, 10 insertions(+) diff --git a/tensorflow/lite/kernels/internal/BUILD b/tensorflow/lite/kernels/internal/BUILD index 57001a800c5287..2a1f510c131b4a 100644 --- a/tensorflow/lite/kernels/internal/BUILD +++ b/tensorflow/lite/kernels/internal/BUILD @@ -669,7 +669,9 @@ cc_library( copts = tflite_copts(), deps = [ ":common", + ":compatibility", ":types", + "//tensorflow/lite:macros", "//tensorflow/lite/core/c:common", ], ) @@ -784,6 +786,7 @@ cc_library( ":tensor", ":tensor_utils", ":types", + "//tensorflow/lite:macros", "//tensorflow/lite:string_util", "//tensorflow/lite/core/c:c_api_types", "//tensorflow/lite/core/c:common", @@ -881,6 +884,7 @@ cc_library( ":tensor", ":tensor_utils", ":types", + "//tensorflow/lite:macros", "//tensorflow/lite:string_util", "//tensorflow/lite/core/c:c_api_types", "//tensorflow/lite/core/c:common", diff --git a/tensorflow/lite/kernels/internal/reference/comparisons.cc b/tensorflow/lite/kernels/internal/reference/comparisons.cc index 86b4a6af0c0f2e..36ce951ec17536 100644 --- a/tensorflow/lite/kernels/internal/reference/comparisons.cc +++ b/tensorflow/lite/kernels/internal/reference/comparisons.cc @@ -15,6 +15,10 @@ limitations under the License. #include "tensorflow/lite/kernels/internal/reference/comparisons.h" +#include "tensorflow/lite/kernels/internal/common.h" +#include "tensorflow/lite/kernels/internal/compatibility.h" +#include "tensorflow/lite/kernels/internal/runtime_shape.h" + namespace tflite { namespace reference_ops { diff --git a/tensorflow/lite/kernels/internal/reference/comparisons.h b/tensorflow/lite/kernels/internal/reference/comparisons.h index 366b378c825266..a9f1e42c0a6c94 100644 --- a/tensorflow/lite/kernels/internal/reference/comparisons.h +++ b/tensorflow/lite/kernels/internal/reference/comparisons.h @@ -16,7 +16,9 @@ limitations under the License. #define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_COMPARISONS_H_ #include "tensorflow/lite/core/c/common.h" +#include "tensorflow/lite/core/macros.h" #include "tensorflow/lite/kernels/internal/common.h" +#include "tensorflow/lite/kernels/internal/runtime_shape.h" #include "tensorflow/lite/kernels/internal/types.h" namespace tflite { From 2cb88a608d96171d7b6182aa834fbf58202e402e Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 24 Sep 2024 01:11:29 -0700 Subject: [PATCH 176/483] IFRT Proxy: Batch array deletes and destructs. With this CL, the proxy-client batches array deletes and destructions before sending them over to the server. For `Array::Delete()`, the caller always gets returned an OK status. Any errors returned by the server are printed as a warning at the client but otherwise ignored. PiperOrigin-RevId: 678125382 --- .../xla/xla/python/ifrt_proxy/client/BUILD | 31 +++ .../xla/xla/python/ifrt_proxy/client/array.cc | 14 +- .../python/ifrt_proxy/client/array_test.cc | 17 +- .../python/ifrt_proxy/client/rpc_helper.cc | 205 +++++++++++++++++- .../xla/python/ifrt_proxy/client/rpc_helper.h | 18 +- .../ifrt_proxy/client/rpc_helper_test.cc | 155 +++++++++++++ .../xla/python/ifrt_proxy/client/version.h | 2 +- .../xla/xla/python/ifrt_proxy/common/BUILD | 14 ++ .../xla/python/ifrt_proxy/common/VERSION.md | 6 + .../ifrt_proxy/common/ifrt_service.proto | 10 +- .../python/ifrt_proxy/common/test_utils.cc | 83 +++++++ .../xla/python/ifrt_proxy/common/test_utils.h | 18 ++ .../integration_tests/mock_array_test.cc | 43 ---- .../xla/xla/python/ifrt_proxy/server/BUILD | 1 - .../python/ifrt_proxy/server/ifrt_backend.cc | 55 ++++- .../ifrt_proxy/server/ifrt_backend_test.cc | 65 ++++-- .../xla/python/ifrt_proxy/server/version.h | 2 +- 17 files changed, 650 insertions(+), 89 deletions(-) create mode 100644 third_party/xla/xla/python/ifrt_proxy/client/rpc_helper_test.cc create mode 100644 third_party/xla/xla/python/ifrt_proxy/common/test_utils.cc diff --git a/third_party/xla/xla/python/ifrt_proxy/client/BUILD b/third_party/xla/xla/python/ifrt_proxy/client/BUILD index 1ca0e1527fbb26..8cf33e364e988f 100644 --- a/third_party/xla/xla/python/ifrt_proxy/client/BUILD +++ b/third_party/xla/xla/python/ifrt_proxy/client/BUILD @@ -97,13 +97,20 @@ cc_library( ":host_buffer", "//xla/python/ifrt", "//xla/python/ifrt_proxy/common:ifrt_service_proto_cc", + "//xla/python/ifrt_proxy/common:test_utils", + "//xla/python/ifrt_proxy/common:types", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/functional:bind_front", + "@com_google_absl//absl/functional:function_ref", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", + "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:random", "@local_tsl//tsl/platform:status_to_from_proto", "@local_tsl//tsl/profiler/lib:traceme", @@ -112,6 +119,28 @@ cc_library( ] + if_google(["@com_google_absl//absl/types:source_location"]), ) +ifrt_proxy_cc_test( + name = "rpc_helper_test", + srcs = ["rpc_helper_test.cc"], + deps = [ + ":client_session", + ":mock_client_session", + ":rpc_helper", + ":version", + "//xla/python/ifrt", + "//xla/python/ifrt_proxy/common:ifrt_service_proto_cc", + "//xla/python/ifrt_proxy/common:test_utils", + "//xla/python/ifrt_proxy/common:types", + "//xla/python/ifrt_proxy/common:types_proto_cc", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:test", + ], +) + cc_library( name = "client", srcs = ["client.cc"], @@ -240,6 +269,8 @@ ifrt_proxy_cc_test( "//xla/python/ifrt_proxy/common:types_proto_cc", "//xla/tsl/concurrency:ref_count", "@com_google_absl//absl/status", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest_main", "@local_tsl//tsl/platform:protobuf", diff --git a/third_party/xla/xla/python/ifrt_proxy/client/array.cc b/third_party/xla/xla/python/ifrt_proxy/client/array.cc index b34d45299f1be4..ca531be4d9b901 100644 --- a/third_party/xla/xla/python/ifrt_proxy/client/array.cc +++ b/third_party/xla/xla/python/ifrt_proxy/client/array.cc @@ -100,8 +100,13 @@ Array::MakeArrayFromHostBuffer( } void Array::Destruct(RpcHelper* rpc_helper, ArrayHandle handle) { + if (rpc_helper->version().protocol_version() >= 5) { + rpc_helper->Batch(RpcHelper::kDestructArray, handle); + return; + } + auto req = std::make_unique(); - req->set_array_handle(handle.handle); + req->set_array_handle_deprecated(handle.handle); rpc_helper->DestructArray(std::move(req)) .OnReady( [](absl::StatusOr> response) { @@ -126,8 +131,13 @@ Future<> Array::GetReadyFuture() const { } Future<> Array::Delete() { + if (rpc_helper_->version().protocol_version() >= 5) { + rpc_helper_->Batch(RpcHelper::kDeleteArray, handle_); + return Future<>(absl::OkStatus()); + } + auto req = std::make_unique(); - req->set_array_handle(handle_.handle); + req->set_array_handle_deprecated(handle_.handle); absl::StatusOr> response = rpc_helper_->DeleteArray(std::move(req)).Await(); diff --git a/third_party/xla/xla/python/ifrt_proxy/client/array_test.cc b/third_party/xla/xla/python/ifrt_proxy/client/array_test.cc index f069dd959f662e..140c74b0311138 100644 --- a/third_party/xla/xla/python/ifrt_proxy/client/array_test.cc +++ b/third_party/xla/xla/python/ifrt_proxy/client/array_test.cc @@ -20,6 +20,8 @@ #include #include #include "absl/status/status.h" +#include "absl/synchronization/notification.h" +#include "absl/time/time.h" #include "absl/types/span.h" #include "xla/python/ifrt/array.h" #include "xla/python/ifrt/dtype.h" @@ -60,7 +62,7 @@ namespace { IfrtProxyVersion Version() { IfrtProxyVersion version; - version.set_protocol_version(kClientMinVersion); + version.set_protocol_version(kClientMaxVersion); return version; } @@ -88,17 +90,26 @@ class ArrayTest : public ::testing::Test { // TODO(b/315809436): Test needs rewrite because protobuf matchers are not OSS #if defined(PLATFORM_GOOGLE) TEST_F(ArrayTest, Destruction) { - IfrtResponse response; + // Destruction may not happen immediately because of batching at the + // client-side. This test waits until destruction happens. + absl::Notification destructed; EXPECT_CALL( *session_, Enqueue(Pointee(Partially(EquivToProto(R"pb(destruct_array_request { array_handle: 1234 })pb"))))) - .WillOnce(MockClientSessionReturnResponse(response)); + .WillOnce([&](std::unique_ptr request) + -> Future { + destructed.Notify(); + auto result = std::make_shared(); + return Future(result); + }); MockClient client; tsl::MakeRef(&client, rpc_helper_, DType(DType::Kind::kBF16), Shape({}), /*sharding=*/nullptr, ArrayHandle{1234}); + + ASSERT_TRUE(destructed.WaitForNotificationWithTimeout(absl::Seconds(10))); } #endif diff --git a/third_party/xla/xla/python/ifrt_proxy/client/rpc_helper.cc b/third_party/xla/xla/python/ifrt_proxy/client/rpc_helper.cc index 409978966a6f4e..b2631019a6ac5e 100644 --- a/third_party/xla/xla/python/ifrt_proxy/client/rpc_helper.cc +++ b/third_party/xla/xla/python/ifrt_proxy/client/rpc_helper.cc @@ -14,22 +14,34 @@ #include "xla/python/ifrt_proxy/client/rpc_helper.h" +#include #include #include +#include #include #include +#include +#include "absl/base/thread_annotations.h" +#include "absl/functional/bind_front.h" #include "absl/log/check.h" #include "absl/log/log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" +#include "absl/time/clock.h" +#include "absl/time/time.h" #include "xla/python/ifrt/future.h" #include "xla/python/ifrt_proxy/client/client_session.h" #include "xla/python/ifrt_proxy/common/ifrt_service.pb.h" +#include "xla/python/ifrt_proxy/common/test_utils.h" +#include "xla/python/ifrt_proxy/common/types.h" +#include "tsl/platform/env.h" #include "tsl/platform/random.h" #include "tsl/platform/status_to_from_proto.h" +#include "tsl/platform/threadpool.h" #include "tsl/profiler/lib/traceme.h" #include "tsl/profiler/lib/traceme_encode.h" #include "tsl/profiler/utils/xplane_schema.h" @@ -42,6 +54,8 @@ namespace { using ::tsl::profiler::XFlow; +constexpr absl::Duration kPeriodicFlushInterval = absl::Microseconds(50); + // XFlowHelper makes it easier to create trace spans with a flow between them. // Typical usage: // @@ -99,12 +113,184 @@ class XFlowHelper { const absl::string_view name_; }; +// Thread-safe data structure for holding batched operations. +class BatchedOps { + public: + using BatchOperation = RpcHelper::BatchOperation; + + void Add(BatchOperation op, ArrayHandle handle) { + absl::MutexLock l(&mu_); + batched_[op].push_back(handle); + } + + struct IfrtRequests { + std::unique_ptr delete_req; + std::unique_ptr destruct_req; + }; + + IfrtRequests Consume() { + IfrtRequests result; + absl::MutexLock l(&mu_); + if (!batched_[BatchOperation::kDeleteArray].empty()) { + result.delete_req = std::make_unique(); + for (const auto& arr_handle : batched_[BatchOperation::kDeleteArray]) { + result.delete_req->mutable_delete_array_request()->add_array_handle( + arr_handle.handle); + } + batched_[BatchOperation::kDeleteArray].clear(); + } + if (!batched_[BatchOperation::kDestructArray].empty()) { + result.destruct_req = std::make_unique(); + for (const auto& arr_handle : batched_[BatchOperation::kDestructArray]) { + result.destruct_req->mutable_destruct_array_request()->add_array_handle( + arr_handle.handle); + } + batched_[BatchOperation::kDestructArray].clear(); + } + return result; + } + + private: + absl::Mutex mu_; + std::array, BatchOperation::kSentinelDoNotUse> + batched_ ABSL_GUARDED_BY(mu_); +}; + } // namespace +// Batches any requested operations and flushes them periodically in the +// background, and allows sending other requested operations immediately. +// Immediate operations are guaranteed to be sent after all previously enqueued +// batched operations. +class RpcHelper::Batcher { + public: + explicit Batcher(std::shared_ptr session) + : session_(std::move(session)) { + thread_pool_.emplace(tsl::Env::Default(), "IfrtProxyRpcHelperBatcher", + /*num_threads=*/1); + thread_pool_->Schedule(absl::bind_front(&Batcher::PeriodicFlusher, this)); + } + + // Sends the given request immediately after sending any batched operations + // that have been previously enqueued. + Future Immediate( + std::unique_ptr request) { + absl::MutexLock l(&mu_); + if (finished_) { + LOG(WARNING) << "After RpcHelper::Finish(): " << request->DebugString(); + return Future( + absl::FailedPreconditionError("RpcHelper::Finish() already called.")); + } + Flush(); + return session_->Enqueue(std::move(request)); + } + + // Enqueues an operation to be sent later. Guaranteed to not be blocked by the + // underlying transport. + void Batch(BatchOperation op, ArrayHandle handle) { + batched_.Add(op, handle); + } + + // Asks the underlying transport to terminate. + void Finish(absl::Status s) { + { + absl::MutexLock l(&mu_); + finished_ = true; + auto remaining = batched_.Consume(); + if (remaining.delete_req != nullptr) { + LOG(WARNING) << "RpcHelper::Batch: Finish() called while there are " + "still batched delete operations"; + } + if (remaining.destruct_req != nullptr) { + LOG(WARNING) << "RpcHelper::Batch: Finish() called while there are " + "still batched destruct operations"; + } + } + thread_pool_.reset(); + session_->Finish(s); + } + + private: + void PeriodicFlusher() { + while (true) { + absl::SleepFor(kPeriodicFlushInterval); + absl::MutexLock l(&mu_); + if (finished_) { + return; + } + { + bool periodic_flush_paused = false; + TestHookCall(TestHookName::kRpcBatcherPausePeriodicFlush, + &periodic_flush_paused); + if (periodic_flush_paused) { + continue; + } + } + tsl::profiler::TraceMe traceme("proxy_periodic_flush"); + Flush(); + } + } + + // Sends all enqueued batched operations. + void Flush() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) { + auto reqs = batched_.Consume(); + if (reqs.delete_req != nullptr) { + XFlowHelper x_flow_helper("batch_delete"); + auto traceme = x_flow_helper.Span(); + session_->Enqueue(std::move(reqs.delete_req)) + .OnReady( + absl::bind_front(HandleBatchResponse, session_, x_flow_helper)); + } + if (reqs.destruct_req != nullptr) { + XFlowHelper x_flow_helper("batch_destruct"); + auto traceme = x_flow_helper.Span(); + session_->Enqueue(std::move(reqs.destruct_req)) + .OnReady( + absl::bind_front(HandleBatchResponse, session_, x_flow_helper)); + } + } + + // Handles a response from the server of a previous batched operation; + // bad responses are logged but otherwise ignored. The method is static since + // it can be called in the background after RpcHelper::Batcher is destroyed. + static void HandleBatchResponse( + std::shared_ptr session, XFlowHelper x_flow_helper, + absl::StatusOr> r) { + if (!r.ok()) { + x_flow_helper.InstantActivity(); + LOG(WARNING) << "Batched response from ifrt proxy server: " << r.status(); + return; + } + if (r.value()->has_delete_array_response()) { + auto traceme = x_flow_helper.Span(); + auto ifrt_req = std::make_unique(); + ifrt_req->mutable_check_future_request()->set_future_handle( + r.value()->delete_array_response().deletion_future_handle()); + session->Enqueue(std::move(ifrt_req)) + .OnReady( + absl::bind_front(HandleBatchResponse, session, x_flow_helper)); + } else if (r.value()->has_destruct_array_response() || + r.value()->has_check_future_response()) { + x_flow_helper.InstantActivity(); + } else { + LOG(ERROR) << "Unrecognized response from server for batched request: " + << (*r)->DebugString(); + } + } + + const std::shared_ptr session_; + + BatchedOps batched_; + + absl::Mutex mu_; + bool finished_ ABSL_GUARDED_BY(mu_) = false; + std::optional thread_pool_; +}; + // DoRpc is a templated function that implements the logic of all RPC-wrapping // functions of `RpcHelper`, such as `RpcHelper::MakeArrayFromHostBuffer()`. template -Future> DoRpc(ClientSession* session, +Future> DoRpc(RpcHelper::Batcher* batcher, void (IfrtRequest::*set_req)(Req*), Resp* (IfrtResponse::*get_resp)(), bool (IfrtResponse::*has_resp)() const, @@ -168,7 +354,7 @@ Future> DoRpc(ClientSession* session, std::make_shared(*std::move((response.get()->*get_resp)()))); } }; - session->Enqueue(std::move(ifrt_req)).OnReady(on_ready); + batcher->Immediate(std::move(ifrt_req)).OnReady(on_ready); return Future>(promise); } @@ -177,7 +363,7 @@ Future> DoRpc(ClientSession* session, RpcHelper::ResponseFuture RpcHelper::METHOD( \ std::unique_ptr req) { \ return DoRpc( \ - session_.get(), &IfrtRequest::set_allocated_##PROPERTY##_request, \ + batcher_.get(), &IfrtRequest::set_allocated_##PROPERTY##_request, \ &IfrtResponse::mutable_##PROPERTY##_response, \ &IfrtResponse::has_##PROPERTY##_response, std::move(req), #PROPERTY); \ } @@ -220,8 +406,19 @@ Future<> RpcHelper::CheckFuture(uint64_t handle) { return Future<>(std::move(promise)); } +RpcHelper::RpcHelper(IfrtProxyVersion version, + std::shared_ptr session) + : batcher_(std::make_unique(std::move(session))), + version_(std::move(version)) {} + +RpcHelper::~RpcHelper() { Disconnect(); } + +void RpcHelper::Batch(BatchOperation op, ArrayHandle handle) { + return batcher_->Batch(op, handle); +} + void RpcHelper::Disconnect() { - session_->Finish(absl::CancelledError("Disconnected by client")); + batcher_->Finish(absl::CancelledError("Disconnected by client")); } } // namespace proxy diff --git a/third_party/xla/xla/python/ifrt_proxy/client/rpc_helper.h b/third_party/xla/xla/python/ifrt_proxy/client/rpc_helper.h index d6eb5b1fcd2c58..fc88c22756502d 100644 --- a/third_party/xla/xla/python/ifrt_proxy/client/rpc_helper.h +++ b/third_party/xla/xla/python/ifrt_proxy/client/rpc_helper.h @@ -28,6 +28,7 @@ #include "xla/python/ifrt_proxy/client/client_session.h" #include "xla/python/ifrt_proxy/client/host_buffer.h" #include "xla/python/ifrt_proxy/common/ifrt_service.pb.h" +#include "xla/python/ifrt_proxy/common/types.h" namespace xla { namespace ifrt { @@ -41,14 +42,13 @@ namespace proxy { // specify the necessary dependency. class RpcHelper { public: - RpcHelper(IfrtProxyVersion version, std::shared_ptr session) - : version_(std::move(version)), session_(std::move(session)) {} + RpcHelper(IfrtProxyVersion version, std::shared_ptr session); void Disconnect(); RpcHelper(const RpcHelper&) = delete; RpcHelper& operator=(const RpcHelper&) = delete; - ~RpcHelper() { Disconnect(); } + ~RpcHelper(); // IFRT Proxy version negotiated between the client and the server. const IfrtProxyVersion& version() const { return version_; } @@ -69,6 +69,15 @@ class RpcHelper { template using ResponseFuture = Future>; + class Batcher; + enum BatchOperation { kDeleteArray, kDestructArray, kSentinelDoNotUse }; + + // Adds the given operation to an impending batch of operations and returns + // immediately. The batch of operation is sent later (as a single logical + // RPC). The RPC is guaranteed to be sent before any unbatched RPCs resulting + // from the wrapper functions below. + void Batch(BatchOperation op, ArrayHandle handle); + // Wrapper function for various logical RPCs defined in ifrt_service.proto. // Whenever the RPC finishes, `on_done` will be called with the result or the // return status. `on_done` can be called with various locks held and should @@ -135,8 +144,9 @@ class RpcHelper { Future<> CheckFuture(uint64_t handle); private: + const std::unique_ptr batcher_; + const IfrtProxyVersion version_; - const std::shared_ptr session_; std::shared_ptr host_buffer_store_; absl::Mutex mu_; diff --git a/third_party/xla/xla/python/ifrt_proxy/client/rpc_helper_test.cc b/third_party/xla/xla/python/ifrt_proxy/client/rpc_helper_test.cc new file mode 100644 index 00000000000000..36adbabd3fafb4 --- /dev/null +++ b/third_party/xla/xla/python/ifrt_proxy/client/rpc_helper_test.cc @@ -0,0 +1,155 @@ +// Copyright 2023 The OpenXLA Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "xla/python/ifrt_proxy/client/rpc_helper.h" + +#include +#include + +#include +#include +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/synchronization/mutex.h" +#include "absl/time/time.h" +#include "xla/python/ifrt/future.h" +#include "xla/python/ifrt_proxy/client/client_session.h" +#include "xla/python/ifrt_proxy/client/mock_client_session.h" +#include "xla/python/ifrt_proxy/client/version.h" +#include "xla/python/ifrt_proxy/common/ifrt_service.pb.h" +#include "xla/python/ifrt_proxy/common/test_utils.h" +#include "xla/python/ifrt_proxy/common/types.h" +#include "xla/python/ifrt_proxy/common/types.pb.h" +#include "tsl/platform/test.h" + +using ::testing::_; +using ::testing::UnorderedElementsAre; + +namespace xla { +namespace ifrt { +namespace proxy { +namespace { + +constexpr absl::Duration kMaxFlushTimeout = absl::Seconds(10); + +void PausePeriodicFlushes() { + // We want to (a) return 'paused=true' whenever the flusher thread tries to + // find out whether flushing has been paused, and (b) wait for any ongoing + // background flushes to complete. To achieve (b), we wait until the flusher + // thread asks for the value of `paused` at least once. + struct AtomicBool { + absl::Mutex mu; + bool b = false; + }; + + auto called_at_least_once = std::make_shared(); + auto periodic_flusher_pause_hook = [called_at_least_once](bool* paused) { + *paused = true; + absl::MutexLock l(&called_at_least_once->mu); + called_at_least_once->b = true; + }; + TestHookSet(TestHookName::kRpcBatcherPausePeriodicFlush, + std::move(periodic_flusher_pause_hook)); + + absl::MutexLock l(&called_at_least_once->mu); + CHECK(called_at_least_once->mu.AwaitWithTimeout( + absl::Condition(&called_at_least_once->b), kMaxFlushTimeout)); +} + +void ResumePeriodicFlushes() { + TestHookClear(TestHookName::kRpcBatcherPausePeriodicFlush); +} + +class RpcHelperTest : public ::testing::Test { + public: + RpcHelperTest() : requests_(kMaxFlushTimeout) { + session_ = std::make_shared(); + IfrtProxyVersion version; + version.set_protocol_version(kClientMaxVersion); + rpc_helper_ = std::make_shared(version, session_); + EXPECT_CALL(*session_, Finish(_)).Times(1); + ON_CALL(*session_, Enqueue) + .WillByDefault([this](std::unique_ptr req) { + requests_.Push(std::move(req)); + return Future( + absl::InternalError("Fake error response")); + }); + } + + std::shared_ptr session_; + std::shared_ptr rpc_helper_; + TestQueue> requests_; +}; + +TEST_F(RpcHelperTest, BatchedPeriodicFlush) { + PausePeriodicFlushes(); + rpc_helper_->Batch(RpcHelper::kDestructArray, ArrayHandle{1}); + rpc_helper_->Batch(RpcHelper::kDeleteArray, ArrayHandle{2}); + rpc_helper_->Batch(RpcHelper::kDestructArray, ArrayHandle{3}); + rpc_helper_->Batch(RpcHelper::kDeleteArray, ArrayHandle{4}); + rpc_helper_->Batch(RpcHelper::kDestructArray, ArrayHandle{9}); + rpc_helper_->Batch(RpcHelper::kDeleteArray, ArrayHandle{8}); + rpc_helper_->Batch(RpcHelper::kDestructArray, ArrayHandle{7}); + rpc_helper_->Batch(RpcHelper::kDeleteArray, ArrayHandle{6}); + ResumePeriodicFlushes(); + + auto delete_req = requests_.Pop(); + auto destruct_req = requests_.Pop(); + + if (destruct_req->has_delete_array_request()) { + destruct_req.swap(delete_req); + } + + EXPECT_THAT(destruct_req->destruct_array_request().array_handle(), + UnorderedElementsAre(1, 3, 9, 7)); + EXPECT_THAT(delete_req->delete_array_request().array_handle(), + UnorderedElementsAre(2, 4, 8, 6)); +} + +TEST_F(RpcHelperTest, BatchedNoPeriodicFlush) { + PausePeriodicFlushes(); + rpc_helper_->Batch(RpcHelper::kDestructArray, ArrayHandle{1}); + rpc_helper_->Batch(RpcHelper::kDeleteArray, ArrayHandle{2}); + rpc_helper_->Batch(RpcHelper::kDestructArray, ArrayHandle{3}); + rpc_helper_->Batch(RpcHelper::kDeleteArray, ArrayHandle{4}); + rpc_helper_->Batch(RpcHelper::kDestructArray, ArrayHandle{9}); + rpc_helper_->Batch(RpcHelper::kDeleteArray, ArrayHandle{8}); + rpc_helper_->Batch(RpcHelper::kDestructArray, ArrayHandle{7}); + rpc_helper_->Batch(RpcHelper::kDeleteArray, ArrayHandle{6}); + + // Send some non-batched request, which should flush all the batched requests. + { + auto dummy_request = std::make_unique(); + dummy_request->set_future_handle(1); + rpc_helper_->CheckFuture(std::move(dummy_request)); + requests_.AllowNonEmptyDestruction(/*allow=*/true); + } + + auto delete_req = requests_.Pop(); + auto destruct_req = requests_.Pop(); + + if (destruct_req->has_delete_array_request()) { + destruct_req.swap(delete_req); + } + + EXPECT_THAT(destruct_req->destruct_array_request().array_handle(), + UnorderedElementsAre(1, 3, 9, 7)); + EXPECT_THAT(delete_req->delete_array_request().array_handle(), + UnorderedElementsAre(2, 4, 8, 6)); +} + +} // namespace +} // namespace proxy +} // namespace ifrt +} // namespace xla diff --git a/third_party/xla/xla/python/ifrt_proxy/client/version.h b/third_party/xla/xla/python/ifrt_proxy/client/version.h index 13c753ee9c5d61..e590ed6fff6431 100644 --- a/third_party/xla/xla/python/ifrt_proxy/client/version.h +++ b/third_party/xla/xla/python/ifrt_proxy/client/version.h @@ -24,7 +24,7 @@ namespace proxy { // LINT.IfChange // TODO(b/296144873): Document the version upgrade policy. inline constexpr int kClientMinVersion = 3; -inline constexpr int kClientMaxVersion = 4; +inline constexpr int kClientMaxVersion = 5; // LINT.ThenChange(//tensorflow/compiler/xla/python/ifrt_proxy/common/VERSION.md) } // namespace proxy diff --git a/third_party/xla/xla/python/ifrt_proxy/common/BUILD b/third_party/xla/xla/python/ifrt_proxy/common/BUILD index 7e30c171af6249..724e22bb0659c4 100644 --- a/third_party/xla/xla/python/ifrt_proxy/common/BUILD +++ b/third_party/xla/xla/python/ifrt_proxy/common/BUILD @@ -182,9 +182,12 @@ cc_library( cc_library( name = "test_utils", + srcs = ["test_utils.cc"], hdrs = ["test_utils.h"], deps = [ "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/debugging:leak_check", "@com_google_absl//absl/log:check", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/time", @@ -198,4 +201,15 @@ cc_library( # parse_tests = False, # visibility = ["//visibility:private"], # ) +# +# bzl_library( +# name = "ifrt_proxy_google_bzl", +# srcs = ["ifrt_proxy.google.bzl"], +# parse_tests = False, +# visibility = ["//visibility:private"], +# deps = [ +# "//devtools/build_cleaner/skylark:build_defs_lib", +# "//xla:xla_bzl", +# ], +# ) # copybara:uncomment_end diff --git a/third_party/xla/xla/python/ifrt_proxy/common/VERSION.md b/third_party/xla/xla/python/ifrt_proxy/common/VERSION.md index 4166a27daf9ca7..cce8633c0d2318 100644 --- a/third_party/xla/xla/python/ifrt_proxy/common/VERSION.md +++ b/third_party/xla/xla/python/ifrt_proxy/common/VERSION.md @@ -24,3 +24,9 @@ * Changes: * Changed the serialization of client and device attributes to use `xla.ifrt.AttributeMapProto` instead of `map`. +## Version 4 + +* Added date: 2024-09-20. +* Changes: + * Batch array deletions and destruction on client before sending to server. + diff --git a/third_party/xla/xla/python/ifrt_proxy/common/ifrt_service.proto b/third_party/xla/xla/python/ifrt_proxy/common/ifrt_service.proto index 0d73b688c0ee3d..3a047542402488 100644 --- a/third_party/xla/xla/python/ifrt_proxy/common/ifrt_service.proto +++ b/third_party/xla/xla/python/ifrt_proxy/common/ifrt_service.proto @@ -331,7 +331,10 @@ message FullyReplicatedShardResponse { // Deletes the given Array. Response contains the handle for a Future that // becomes ready when the deletion completes. message DeleteArrayRequest { - fixed64 array_handle = 1; + // TODO(b/296144873): Remove after compatibility window. + optional fixed64 array_handle_deprecated = 1 [deprecated = true]; + + repeated fixed64 array_handle = 2; } message DeleteArrayResponse { fixed64 deletion_future_handle = 1; @@ -345,7 +348,10 @@ message IsArrayDeletedResponse { } message DestructArrayRequest { - fixed64 array_handle = 1; + // TODO(b/296144873): Remove after compatibility window. + optional fixed64 array_handle_deprecated = 1 [deprecated = true]; + + repeated fixed64 array_handle = 2; } message DestructArrayResponse {} diff --git a/third_party/xla/xla/python/ifrt_proxy/common/test_utils.cc b/third_party/xla/xla/python/ifrt_proxy/common/test_utils.cc new file mode 100644 index 00000000000000..eed9fcea24e76e --- /dev/null +++ b/third_party/xla/xla/python/ifrt_proxy/common/test_utils.cc @@ -0,0 +1,83 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/python/ifrt_proxy/common/test_utils.h" + +#include +#include + +#include "absl/base/thread_annotations.h" +#include "absl/container/flat_hash_map.h" +#include "absl/debugging/leak_check.h" +#include "absl/synchronization/mutex.h" + +namespace xla { +namespace ifrt { +namespace proxy { + +namespace { + +class Overrides { + public: + void Set(TestHookName h, std::function fn) { + absl::MutexLock l(&mu_); + overrides_[h] = std::move(fn); + } + + void Clear(TestHookName h) { + absl::MutexLock l(&mu_); + overrides_.erase(h); + } + + void Call(TestHookName h, bool* param1) { + absl::MutexLock l(&mu_); + const auto it = overrides_.find(h); + if (it != overrides_.end()) { + it->second(param1); + } + } + + private: + absl::Mutex mu_; + absl::flat_hash_map> overrides_ + ABSL_GUARDED_BY(mu_); +}; + +Overrides* overrides() { + // Declaring a global absl::NoDestructor is easier, but as of Sep + // 2024, NoDestructor<> was not yet available in the version of absl linked + // into TSL. + static Overrides* result = []() { + auto* result = new Overrides; + absl::IgnoreLeak(result); + return result; + }(); + return result; +} + +}; // namespace + +void TestHookSet(TestHookName h, std::function fn) { + overrides()->Set(h, std::move(fn)); +} +void TestHookClear(TestHookName h) { overrides()->Clear(h); } + +void TestHookCall(TestHookName h, bool* param1) { + overrides()->Call(h, param1); +} + +} // namespace proxy +} // namespace ifrt +} // namespace xla diff --git a/third_party/xla/xla/python/ifrt_proxy/common/test_utils.h b/third_party/xla/xla/python/ifrt_proxy/common/test_utils.h index 8ecae77206529b..002394fc6074fb 100644 --- a/third_party/xla/xla/python/ifrt_proxy/common/test_utils.h +++ b/third_party/xla/xla/python/ifrt_proxy/common/test_utils.h @@ -17,6 +17,7 @@ limitations under the License. #define XLA_PYTHON_IFRT_PROXY_COMMON_TEST_UTILS_H_ #include +#include #include #include @@ -89,6 +90,23 @@ class TestQueue { bool allow_non_empty_destruction_ ABSL_GUARDED_BY(mu_) = false; }; +// TestHook provides a lightweight mechanism to modify the behavior of +// production code from tests. +// TODO(b/266635130): Extend for more hook types (as of Sep 2023, only allows +// `void(bool*)`) and make more lightweight. +enum class TestHookName { + kRpcBatcherPausePeriodicFlush, +}; + +// Allows test code to override the default noop behavior for hook `h`. +void TestHookSet(TestHookName h, std::function fn); + +// Resets hook `h` to the default noop behavior. +void TestHookClear(TestHookName h); + +// Calls hook `h` if it has been overridden by test setup; noop otherwise. +void TestHookCall(TestHookName h, bool* param1); + } // namespace proxy } // namespace ifrt } // namespace xla diff --git a/third_party/xla/xla/python/ifrt_proxy/integration_tests/mock_array_test.cc b/third_party/xla/xla/python/ifrt_proxy/integration_tests/mock_array_test.cc index ae8c86855662b8..6545ebbab3e8e7 100644 --- a/third_party/xla/xla/python/ifrt_proxy/integration_tests/mock_array_test.cc +++ b/third_party/xla/xla/python/ifrt_proxy/integration_tests/mock_array_test.cc @@ -196,49 +196,6 @@ TEST_F(MockArrayTest, ReadyFuturePropagatesError) { StatusIs(kInternal)); } -TEST_F(MockArrayTest, DeletionFutureWaitsUntilDeleted) { - TF_ASSERT_OK_AND_ASSIGN(ArrayPair arr, NewArray()); - - tsl::thread::ThreadPool threads(tsl::Env::Default(), "t", /*num_threads=*/1); - absl::Notification wait_ready; - - EXPECT_CALL(*arr.backend_array, Delete).WillOnce([&] { - // TODO(b/266635130): Write a version of this testcase where the Delete() - // call of the MockArray blocks on `wait_ready`, instead of the Future it - // returns being blocked on `wait_ready`. That version of the testcase does - // not currently work since both the client and the server synchronously - // block until the MockArray's Delete() returns. - auto promise = Future<>::CreatePromise(); - threads.Schedule([&, promise]() mutable { - wait_ready.WaitForNotification(); - promise.Set(arr.backend_array->delegated()->Delete().Await()); - }); - return Future<>(promise); - }); - - EXPECT_FALSE(arr.proxy_client_array->IsDeleted()); - auto deleted_future = arr.proxy_client_array->Delete(); - - absl::SleepFor(kSomeTime); - EXPECT_FALSE(deleted_future.IsReady()); - EXPECT_FALSE(arr.proxy_client_array->IsDeleted()); - - wait_ready.Notify(); - EXPECT_THAT(deleted_future.Await(), IsOk()); - EXPECT_TRUE(arr.proxy_client_array->IsDeleted()); -} - -TEST_F(MockArrayTest, DeletionPropagatesError) { - TF_ASSERT_OK_AND_ASSIGN(ArrayPair arr, NewArray()); - - EXPECT_CALL(*arr.backend_array, Delete).WillOnce([&] { - return Future<>(absl::InternalError("testing")); - }); - - EXPECT_FALSE(arr.proxy_client_array->IsDeleted()); - EXPECT_THAT(arr.proxy_client_array->Delete().Await(), StatusIs(kInternal)); -} - TEST_F(MockArrayTest, CopyToHostFutureWaitsUntilCopied) { TF_ASSERT_OK_AND_ASSIGN(ArrayPair arr, NewArray()); diff --git a/third_party/xla/xla/python/ifrt_proxy/server/BUILD b/third_party/xla/xla/python/ifrt_proxy/server/BUILD index 8484fee04b7e10..d426fd428fe62c 100644 --- a/third_party/xla/xla/python/ifrt_proxy/server/BUILD +++ b/third_party/xla/xla/python/ifrt_proxy/server/BUILD @@ -178,7 +178,6 @@ ifrt_proxy_cc_test( "//xla/python/ifrt:serdes", "//xla/python/ifrt:sharding_serdes", "//xla/python/ifrt_proxy/common:ifrt_service_proto_cc", - "//xla/python/ifrt_proxy/common:types", "//xla/python/ifrt_proxy/common:types_proto_cc", "//xla/python/pjrt_ifrt:xla_ifrt", "//xla/service:computation_placer_hdr", diff --git a/third_party/xla/xla/python/ifrt_proxy/server/ifrt_backend.cc b/third_party/xla/xla/python/ifrt_proxy/server/ifrt_backend.cc index 040c0444cbd20f..b7f551e2716110 100644 --- a/third_party/xla/xla/python/ifrt_proxy/server/ifrt_backend.cc +++ b/third_party/xla/xla/python/ifrt_proxy/server/ifrt_backend.cc @@ -35,6 +35,7 @@ #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" #include "absl/types/span.h" @@ -206,6 +207,8 @@ Future IfrtBackend::Process( return Future( HandleGetDefaultDeviceAssignmentRequest(std::move(request))); default: + LOG(ERROR) << "Got unimplemented request type: " + << request->DebugString(); return Future(absl::UnimplementedError(absl::StrCat( "Got unimplemented request type: ", request->request_case()))); } @@ -764,14 +767,32 @@ IfrtBackend::HandleFullyReplicatedShardRequest( absl::StatusOr IfrtBackend::HandleDeleteArrayRequest(std::unique_ptr request) { - TF_ASSIGN_OR_RETURN(auto array, - GetArray(request->delete_array_request().array_handle())); + std::vector bad_handles; + std::vector> deletion_futures; + + auto delete_handle = [&](uint64_t handle) { + auto array = GetArray(handle); + if (array.ok()) { + deletion_futures.push_back(array.value()->Delete()); + } else { + deletion_futures.push_back(Future<>(array.status())); + } + }; + + if (request->delete_array_request().has_array_handle_deprecated()) { + // TODO(b/296144873): After removing array_handle_deprecated(), move + // delete_handle's definition to the single place it is used. + delete_handle(request->delete_array_request().array_handle_deprecated()); + } + + for (auto array_handle : request->delete_array_request().array_handle()) { + delete_handle(array_handle); + } - auto deletion_future = array->Delete(); uint64_t future_handle = handle_generator_.New(); { absl::MutexLock lock(&futures_mutex_); - futures_.insert({future_handle, std::move(deletion_future)}); + futures_.insert({future_handle, JoinFutures(deletion_futures)}); } auto ifrt_resp = NewIfrtResponse(request->request_metadata().op_id()); @@ -793,16 +814,30 @@ IfrtBackend::HandleIsArrayDeletedRequest(std::unique_ptr request) { absl::StatusOr IfrtBackend::HandleDestructArrayRequest(std::unique_ptr request) { + std::vector bad_handles; { absl::MutexLock lock(&arrays_mutex_); - bool deleted = - arrays_.erase(request->destruct_array_request().array_handle()); - if (!deleted) { - return absl::NotFoundError( - absl::StrCat("Unknown array handle: ", - request->destruct_array_request().array_handle())); + for (const uint64_t array_handle : + request->destruct_array_request().array_handle()) { + if (!arrays_.erase(array_handle)) { + bad_handles.push_back(array_handle); + } + } + + if (request->destruct_array_request().has_array_handle_deprecated()) { + const uint64_t array_handle = + request->destruct_array_request().array_handle_deprecated(); + if (!arrays_.erase(array_handle)) { + bad_handles.push_back(array_handle); + } } } + + if (!bad_handles.empty()) { + return absl::NotFoundError(absl::StrCat("Unknown array handle(s): ", + absl::StrJoin(bad_handles, ","))); + } + auto ifrt_resp = NewIfrtResponse(request->request_metadata().op_id()); // Currently DestructArrayResponse is an empty message, but proxy clients may diff --git a/third_party/xla/xla/python/ifrt_proxy/server/ifrt_backend_test.cc b/third_party/xla/xla/python/ifrt_proxy/server/ifrt_backend_test.cc index 17ae98c13cd76b..160d6fe2885a61 100644 --- a/third_party/xla/xla/python/ifrt_proxy/server/ifrt_backend_test.cc +++ b/third_party/xla/xla/python/ifrt_proxy/server/ifrt_backend_test.cc @@ -345,6 +345,9 @@ class IfrtBackendHandlerTest : public IfrtBackendTest { } absl::Status CheckFuture(uint64_t handle) { + if (handle == 0) { + return absl::InternalError("Test error, future handle is 0"); + } auto request = NewIfrtRequest(NewOpId()); request->mutable_check_future_request()->set_future_handle(handle); TF_ASSIGN_OR_RETURN(std::shared_ptr response, @@ -913,26 +916,46 @@ TEST_P(IfrtBackendHandlerTest, } TEST_P(IfrtBackendHandlerTest, DeleteArraySuccess) { - tsl::RCReference mock_array = - tsl::MakeRef(); - EXPECT_CALL(*mock_array, Delete()) + auto mock_array1 = tsl::MakeRef(); + EXPECT_CALL(*mock_array1, Delete()) + .WillOnce(Return(Future<>(absl::OkStatus()))); + auto mock_array2 = tsl::MakeRef(); + EXPECT_CALL(*mock_array2, Delete()) .WillOnce(Return(Future<>(absl::OkStatus()))); - TF_ASSERT_OK_AND_ASSIGN(auto array_handle, - MakeTestArray(std::move(mock_array))); + + TF_ASSERT_OK_AND_ASSIGN(auto array_handle1, + MakeTestArray(std::move(mock_array1))); + TF_ASSERT_OK_AND_ASSIGN(auto array_handle2, + MakeTestArray(std::move(mock_array2))); uint64_t op_id = NewOpId(); auto ifrt_request = NewIfrtRequest(op_id); - ifrt_request->mutable_delete_array_request()->set_array_handle(array_handle); + ifrt_request->mutable_delete_array_request()->add_array_handle(array_handle1); + ifrt_request->mutable_delete_array_request()->add_array_handle(array_handle2); TF_ASSERT_OK_AND_ASSIGN(auto resp, CallBackend(std::move(ifrt_request))); EXPECT_THAT(tsl::StatusFromProto(resp->response_metadata().status()), IsOk()); - EXPECT_NE(resp->delete_array_response().deletion_future_handle(), 0); + TF_EXPECT_OK( + CheckFuture(resp->delete_array_response().deletion_future_handle())); } -TEST_P(IfrtBackendHandlerTest, DeleteArrayFailsWithNonExistentArrayHandle) { +TEST_P(IfrtBackendHandlerTest, + DeleteArrayReturnsFutureWithNonExistentArrayHandle) { + // Create one existing array. + auto mock_array1 = tsl::MakeRef(); + EXPECT_CALL(*mock_array1, Delete()) + .WillOnce(Return(Future<>(absl::OkStatus()))); + TF_ASSERT_OK_AND_ASSIGN(auto real_handle, + MakeTestArray(std::move(mock_array1))); + + constexpr int kBadHandle = 400; auto ifrt_request = NewIfrtRequest(NewOpId()); - ifrt_request->mutable_delete_array_request()->set_array_handle(0); - EXPECT_THAT(CallBackend(std::move(ifrt_request)), - StatusIs(absl::StatusCode::kNotFound)); + ifrt_request->mutable_delete_array_request()->add_array_handle(real_handle); + ifrt_request->mutable_delete_array_request()->add_array_handle(kBadHandle); + TF_ASSERT_OK_AND_ASSIGN(auto resp, CallBackend(std::move(ifrt_request))); + + EXPECT_THAT( + CheckFuture(resp->delete_array_response().deletion_future_handle()), + StatusIs(absl::StatusCode::kNotFound)); } TEST_P(IfrtBackendHandlerTest, @@ -968,14 +991,20 @@ TEST_P(IfrtBackendHandlerTest, IsDeleteFailsForNonExistentArrays) { } TEST_P(IfrtBackendHandlerTest, DestructArrayTest) { - tsl::RCReference mock_array = + tsl::RCReference mock_array1 = tsl::MakeRef(); - TF_ASSERT_OK_AND_ASSIGN(auto array_handle, - MakeTestArray(std::move(mock_array))); + TF_ASSERT_OK_AND_ASSIGN(auto array_handle1, + MakeTestArray(std::move(mock_array1))); + tsl::RCReference mock_array2 = + tsl::MakeRef(); + TF_ASSERT_OK_AND_ASSIGN(auto array_handle2, + MakeTestArray(std::move(mock_array2))); auto ifrt_request = NewIfrtRequest(NewOpId()); - ifrt_request->mutable_destruct_array_request()->set_array_handle( - array_handle); + ifrt_request->mutable_destruct_array_request()->add_array_handle( + array_handle1); + ifrt_request->mutable_destruct_array_request()->add_array_handle( + array_handle2); TF_ASSERT_OK_AND_ASSIGN(auto ifrt_resp, CallBackend(std::move(ifrt_request))); EXPECT_TRUE(ifrt_resp->has_destruct_array_response()); @@ -983,8 +1012,8 @@ TEST_P(IfrtBackendHandlerTest, DestructArrayTest) { // handle no longer exists on the server, (2) DestructArray fails for // non-existent arrays and (3) DestructArray is not idempotent. ifrt_request = NewIfrtRequest(NewOpId()); - ifrt_request->mutable_destruct_array_request()->set_array_handle( - array_handle); + ifrt_request->mutable_destruct_array_request()->add_array_handle( + array_handle1); EXPECT_THAT(CallBackend(std::move(ifrt_request)), StatusIs(absl::StatusCode::kNotFound)); } diff --git a/third_party/xla/xla/python/ifrt_proxy/server/version.h b/third_party/xla/xla/python/ifrt_proxy/server/version.h index 686fe78993bfd2..d3707d1748175b 100644 --- a/third_party/xla/xla/python/ifrt_proxy/server/version.h +++ b/third_party/xla/xla/python/ifrt_proxy/server/version.h @@ -26,7 +26,7 @@ namespace proxy { // LINT.IfChange // TODO(b/296144873): Document the version upgrade policy. inline constexpr int kServerMinVersion = 1; -inline constexpr int kServerMaxVersion = 4; +inline constexpr int kServerMaxVersion = 5; // LINT.ThenChange(//tensorflow/compiler/xla/python/ifrt_proxy/common/VERSION.md) // Returns a version that both the client and the server support, or an error if From 3ab5eb4a19ba3ba3933c42591f29c7857499e74f Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 24 Sep 2024 02:02:11 -0700 Subject: [PATCH 177/483] compat: Update forward compatibility horizon to 2024-09-24 PiperOrigin-RevId: 678142346 --- tensorflow/python/compat/compat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index f7a548d8aa13ca..df65c9e37bbc68 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -29,7 +29,7 @@ # This value changes every day with an automatic CL. It can be modified in code # via `forward_compatibility_horizon()` or with the environment variable # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date. -_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2024, 9, 23) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2024, 9, 24) _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS" _FORWARD_COMPATIBILITY_DATE_NUMBER = None From c3a4b5e7a09864d4a89c6c6fd39afdb0d63eefdd Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 24 Sep 2024 02:02:21 -0700 Subject: [PATCH 178/483] Update GraphDef version to 1995. PiperOrigin-RevId: 678142416 --- tensorflow/core/public/version.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index e89334c66dc1b5..71f6b4932cbd33 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -108,7 +108,7 @@ limitations under the License. #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0 -#define TF_GRAPH_DEF_VERSION 1994 // Updated: 2024/9/23 +#define TF_GRAPH_DEF_VERSION 1995 // Updated: 2024/9/24 // Checkpoint compatibility versions (the versions field in SavedSliceMeta). // From 6c64150f614290a8c52c992c3c088a077861417e Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Tue, 24 Sep 2024 02:16:37 -0700 Subject: [PATCH 179/483] [XLA:GPU][NFC] Clean up `TritonSoftmaxTest.CanFuseAndEmitDiamondWithInputNumberOfElementsLargerThanInt32Max`. The test doesn't need to exercise the optimization pipeline. PiperOrigin-RevId: 678147277 --- .../triton_fusion_emitter_large_test.cc | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_large_test.cc b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_large_test.cc index 97634100ad3aa6..3e52ae45936b67 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_large_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_large_test.cc @@ -144,7 +144,7 @@ class TritonSoftmaxTest : public GpuCodegenTest { }; TEST_F(TritonSoftmaxTest, - CanFuseAndEmitDiamondWithInputNumberOfElementsLargerThanInt32Max) { + CanEmitDiamondWithInputNumberOfElementsLargerThanInt32Max) { const std::string hlo_text = R"( HloModule softmax @@ -154,26 +154,26 @@ max_computation { ROOT maximum = f16[] maximum(arg_0, arg_1) } -ENTRY main { +triton_fusion_computation { param_0 = f16[65538,32768]{1,0} parameter(0) constant_neg_inf = f16[] constant(-inf) reduce = f16[65538]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation broadcast = f16[65538,32768]{1,0} broadcast(reduce), dimensions={0} ROOT subtract = f16[65538,32768]{1,0} subtract(param_0, broadcast) } -)"; - MatchOptimizedHlo(hlo_text, R"( -; CHECK: ENTRY -; CHECK: %[[P0:.*]] = f16[65538,32768]{1,0} parameter(0) -; CHECK: ROOT -; CHECK-SAME: fusion(%[[P0]]) -; CHECK-SAME: kind=kCustom -; CHECK-SAME: __triton -)"); +ENTRY main { + param_0 = f16[65538,32768]{1,0} parameter(0) + ROOT fusion = f16[65538,32768]{1,0} fusion(param_0), kind=kCustom, + calls=triton_fusion_computation, + backend_config={"fusion_backend_config": + {"kind":"__triton", + "block_level_fusion_config":{"output_tile_sizes":["1","32768"], + "num_warps":"1"}}} +})"; // Checking that this does not crash should be enough. - EXPECT_TRUE(Run(hlo_text)); + EXPECT_TRUE(Run(hlo_text, /*run_hlo_passes=*/false)); } } // namespace From 69b0d4d7887942618ce89aa2068ec173a230474f Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 24 Sep 2024 02:23:09 -0700 Subject: [PATCH 180/483] Automated Code Change PiperOrigin-RevId: 678149465 --- tensorflow/cc/training/queue_runner.cc | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tensorflow/cc/training/queue_runner.cc b/tensorflow/cc/training/queue_runner.cc index 508be19fd50a0f..fe3e8ba881059e 100644 --- a/tensorflow/cc/training/queue_runner.cc +++ b/tensorflow/cc/training/queue_runner.cc @@ -205,8 +205,7 @@ void QueueRunner::Run(Session* sess, const string& enqueue_op) { UpdateStatus(RealRun(sess, close_op_name_, false)); } } else if (!status.ok()) { - LOG(ERROR) << "Queue runner thread got a failure status: " - << status.ToString(); + LOG(ERROR) << "Queue runner thread got a failure status: " << status; UpdateStatus(status); if (coord_) { coord_->RequestStop().IgnoreError(); From e87146e2fc7f073562130fdd24a9b6c9d8a98174 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 24 Sep 2024 03:52:16 -0700 Subject: [PATCH 181/483] Automated Code Change PiperOrigin-RevId: 678177083 --- .../delegate_performance/android/src/main/native/BUILD | 3 +++ .../android/src/main/native/accuracy_benchmark.cc | 3 +++ .../android/src/main/native/accuracy_benchmark.h | 2 ++ .../android/src/main/native/latency_benchmark.cc | 3 +++ 4 files changed, 11 insertions(+) diff --git a/tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/src/main/native/BUILD b/tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/src/main/native/BUILD index d2d1f807604ab5..9114d2314f6327 100644 --- a/tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/src/main/native/BUILD +++ b/tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/src/main/native/BUILD @@ -37,7 +37,10 @@ cc_library( "//tensorflow/core/util:stats_calculator_portable", "//tensorflow/lite:minimal_logging", "//tensorflow/lite/acceleration/configuration:configuration_fbs", + "//tensorflow/lite/c:c_api_types", "//tensorflow/lite/profiling:memory_info", + "//tensorflow/lite/tools/benchmark:benchmark_model_lib", + "//tensorflow/lite/tools/benchmark:benchmark_params", "//tensorflow/lite/tools/benchmark:benchmark_tflite_model_lib", "//tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/proto:delegate_performance_cc_proto", "@com_google_absl//absl/strings", diff --git a/tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/src/main/native/accuracy_benchmark.cc b/tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/src/main/native/accuracy_benchmark.cc index 906b109f86da66..b49c45e12ec62b 100644 --- a/tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/src/main/native/accuracy_benchmark.cc +++ b/tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/src/main/native/accuracy_benchmark.cc @@ -23,8 +23,11 @@ limitations under the License. #include #include +#include "flatbuffers/base.h" // from @flatbuffers #include "flatbuffers/buffer.h" // from @flatbuffers #include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers +#include "flatbuffers/flatbuffers.h" // from @flatbuffers +#include "flatbuffers/vector.h" // from @flatbuffers #include "tensorflow/lite/acceleration/configuration/configuration_generated.h" #include "tensorflow/lite/experimental/acceleration/mini_benchmark/c/c_api.h" #include "tensorflow/lite/kernels/internal/compatibility.h" diff --git a/tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/src/main/native/accuracy_benchmark.h b/tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/src/main/native/accuracy_benchmark.h index 63e85cf70028b1..c2fcb0ca5df72c 100644 --- a/tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/src/main/native/accuracy_benchmark.h +++ b/tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/src/main/native/accuracy_benchmark.h @@ -20,6 +20,8 @@ limitations under the License. #include #include +#include "flatbuffers/buffer.h" // from @flatbuffers +#include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers #include "tensorflow/lite/acceleration/configuration/configuration_generated.h" namespace tflite { diff --git a/tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/src/main/native/latency_benchmark.cc b/tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/src/main/native/latency_benchmark.cc index 4f739ccc838a47..03ca95276e0700 100644 --- a/tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/src/main/native/latency_benchmark.cc +++ b/tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/src/main/native/latency_benchmark.cc @@ -29,9 +29,12 @@ limitations under the License. #include "absl/strings/str_format.h" #include "tensorflow/core/util/stats_calculator.h" #include "tensorflow/lite/acceleration/configuration/configuration_generated.h" +#include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/logger.h" #include "tensorflow/lite/minimal_logging.h" #include "tensorflow/lite/profiling/memory_info.h" +#include "tensorflow/lite/tools/benchmark/benchmark_model.h" +#include "tensorflow/lite/tools/benchmark/benchmark_params.h" #include "tensorflow/lite/tools/benchmark/benchmark_tflite_model.h" #include "tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/proto/delegate_performance.pb.h" From fb48aab30a2b25e6a6f78ee5b335c2eecc7a64e6 Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Tue, 24 Sep 2024 03:56:52 -0700 Subject: [PATCH 182/483] [XLA:GPU] Tighten the heuristic that determines if a tile is too big. The previous heuristic allowed tiles with sizes equal to the number of available registers---which is almost guaranteed to spill. Reducing this number to 50% is better, but the heuristic still easily leads to spills for non-trivial kernels. We empirically determined 40% to be a reasonable value by looking at normalization-based kernels. PiperOrigin-RevId: 678178051 --- .../model/gpu_indexing_performance_model.cc | 21 +++++++++- .../gpu_indexing_performance_model_test.cc | 39 +++++++++---------- 2 files changed, 39 insertions(+), 21 deletions(-) diff --git a/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.cc b/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.cc index 7798b80d17a681..13a4a0ac6e4d54 100644 --- a/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.cc +++ b/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.cc @@ -86,6 +86,24 @@ int64_t GetPaddedTileSize(absl::Span tile_sizes) { // heuristic tries to be safe and increase recall at the cost of precision. bool DoesTileFitsInRegisters(int64_t tile_size, const se::DeviceDescription& device_info) { + // This is a conservative estimate to make sure that we don't get a tile that + // is too big and results in register spills. + // + // We had the following reasoning for the value of this constant: + // * Whenever a block needs to use a tile more than once, it needs to + // either (1) load the tile from HBM several times, or (2) store the tile + // in registers at the same time as some of the results. That is the case + // for normalization diamonds for instance, where the input tile is used + // twice. + // * We expect kernels without reuse to benefit from smaller tile sizes + // anyway. + // * We use around 20% of the registers as working memory for indexing + // computations and expensive instructions like exponential or cosine. + // + // This value was empirically determined in September 2024 and may change in + // the future. + constexpr double kFractionOfRegistersAvailableToStoreTile = 0.4; + // Register allocation happens at PTX->SASS level, so we can't know the exact // number of registers used by a kernel. We make a few assumptions about the // kernel we will generate (this may not hold in the future): @@ -104,7 +122,8 @@ bool DoesTileFitsInRegisters(int64_t tile_size, // data type. `registers_per_block_limit()` returns the number of 32-bit // registers. Check if 64-bit types need twice as many registers. Check if // smaller types can fit into one register. - return tile_size <= device_info.registers_per_block_limit(); + return tile_size <= kFractionOfRegistersAvailableToStoreTile * + device_info.registers_per_block_limit(); } // Returns the number of warps to use based on the tile size. The numbers were diff --git a/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model_test.cc b/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model_test.cc index dddf7b1d428f9d..f9f6b05702e79e 100644 --- a/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model_test.cc +++ b/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model_test.cc @@ -360,30 +360,29 @@ max_computation { } softmax { - param_0 = f16[65538,32768]{1,0} parameter(0) + param_0 = f16[131076,16384]{1,0} parameter(0) constant_neg_inf = f16[] constant(-inf) - reduce = f16[65538]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation - broadcast = f16[65538,32768]{1,0} broadcast(reduce), dimensions={0} - ROOT subtract = f16[65538,32768]{1,0} subtract(param_0, broadcast) + reduce = f16[131076]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation + broadcast = f16[131076,16384]{1,0} broadcast(reduce), dimensions={0} + ROOT subtract = f16[131076,16384]{1,0} subtract(param_0, broadcast) } ENTRY main { - param_0 = f16[65538,32768]{1,0} parameter(0) - ROOT fusion = f16[65538,32768]{1,0} fusion(param_0), kind=kCustom, calls=softmax -} -)")); + param_0 = f16[131076,16384]{1,0} parameter(0) + ROOT fusion = f16[131076,16384]{1,0} fusion(param_0), kind=kCustom, calls=softmax +})")); auto fusion_adaptor = HloFusionAdaptor::ForInstruction( module->entry_computation()->root_instruction()); - LaunchDimensions launch_dimensions{65538LL * 32768LL, 32}; + LaunchDimensions launch_dimensions{131076LL * 16384LL, 32}; TF_ASSERT_OK_AND_ASSIGN( auto runtime_data, indexing_cost_model_.EstimateRunTimeForTiledFusion( *fusion_adaptor, launch_dimensions, /*output_tile_sizes=*/{1, 1})); - EXPECT_NEAR(absl::ToDoubleSeconds(runtime_data.read_time), 5863, 1); - EXPECT_NEAR(absl::ToDoubleSeconds(runtime_data.compute_time), 39, 1); - EXPECT_NEAR(absl::ToDoubleSeconds(runtime_data.exec_time), 5865, 1); + EXPECT_NEAR(absl::ToDoubleSeconds(runtime_data.read_time), 2931, 1); + EXPECT_NEAR(absl::ToDoubleSeconds(runtime_data.compute_time), 19, 1); + EXPECT_NEAR(absl::ToDoubleSeconds(runtime_data.exec_time), 2932, 1); } // TODO(b/351342921): Remove this test once there is no special filter for @@ -463,16 +462,16 @@ add { } triton_softmax_computation { - param_0 = f32[16,40000] parameter(0) + param_0 = f32[16,16000] parameter(0) constant_0 = f32[] constant(0) reduce_0 = f32[16] reduce(param_0, constant_0), dimensions={1}, to_apply=add - broadcast = f32[16,40000] broadcast(reduce_0), dimensions={0} - ROOT multiply = f32[16,40000] multiply(param_0, broadcast) + broadcast = f32[16,16000] broadcast(reduce_0), dimensions={0} + ROOT multiply = f32[16,16000] multiply(param_0, broadcast) } ENTRY main { - param_0 = f32[16,40000] parameter(0) - ROOT triton_softmax = f32[16,40000] fusion(param_0), kind=kCustom, calls=triton_softmax_computation, backend_config={"fusion_backend_config": {"kind":"__triton"}} + param_0 = f32[16,16000] parameter(0) + ROOT triton_softmax = f32[16,16000] fusion(param_0), kind=kCustom, calls=triton_softmax_computation, backend_config={"fusion_backend_config": {"kind":"__triton"}} } )")); auto fusion_adaptor = HloFusionAdaptor::ForInstruction( @@ -485,13 +484,13 @@ ENTRY main { TF_ASSERT_OK_AND_ASSIGN(auto res1, indexing_cost_model_.EstimateRunTimeForTiledFusion( *fusion_adaptor, /*launch_dimensions=*/{16, 32}, - /*output_tile_sizes=*/{1, 40000})); - EXPECT_NEAR(absl::ToDoubleMicroseconds(res1.exec_time), 7, 1); + /*output_tile_sizes=*/{1, 16000})); + EXPECT_NEAR(absl::ToDoubleMicroseconds(res1.exec_time), 3, 1); TF_ASSERT_OK_AND_ASSIGN(auto res2, indexing_cost_model_.EstimateRunTimeForTiledFusion( *fusion_adaptor, /*launch_dimensions=*/{8, 32}, - /*output_tile_sizes=*/{2, 40000})); + /*output_tile_sizes=*/{2, 16000})); EXPECT_TRUE(res2.IsInfinite()); } From df5cc128af4b28fc15adcf1f2514f86fec9ffbac Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 24 Sep 2024 04:52:25 -0700 Subject: [PATCH 183/483] Integrate LLVM at llvm/llvm-project@df0864e76110 Updates LLVM usage to match [df0864e76110](https://github.com/llvm/llvm-project/commit/df0864e76110) PiperOrigin-RevId: 678193538 --- third_party/llvm/workspace.bzl | 4 ++-- third_party/shardy/temporary.patch | 10 +++++----- third_party/shardy/workspace.bzl | 4 ++-- third_party/xla/third_party/shardy/temporary.patch | 10 +++++----- third_party/xla/third_party/shardy/workspace.bzl | 4 ++-- 5 files changed, 16 insertions(+), 16 deletions(-) diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl index 726a367bee5547..abe15efc5e7204 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "8b4b7d28f7c344c728a9812aa99d9ad24edb40a2" - LLVM_SHA256 = "f585b8955f66849929bbe0b657ea7ff5fe8f49880066a58b2a744065ddd4a521" + LLVM_COMMIT = "df0864e761107b07e38f5503e0cbee0cebb4c5e8" + LLVM_SHA256 = "5bfcb7306d9d40f420862ace1f7ad3f01979facfb16ffd1fc80b6d91e92019fa" tf_http_archive( name = name, diff --git a/third_party/shardy/temporary.patch b/third_party/shardy/temporary.patch index f6938677141184..5fd5f295cd7dfc 100644 --- a/third_party/shardy/temporary.patch +++ b/third_party/shardy/temporary.patch @@ -1,15 +1,15 @@ diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl -index 55290bf..726a367 100644 +index 726a367..abe15ef 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" -- LLVM_COMMIT = "42b696d7b9942fdf07d65267da40ab178464adaa" -- LLVM_SHA256 = "4f0d2053b381d3f074c64b2e460792cab11a02333f1c88bbc22b01686cf2fcb0" -+ LLVM_COMMIT = "8b4b7d28f7c344c728a9812aa99d9ad24edb40a2" -+ LLVM_SHA256 = "f585b8955f66849929bbe0b657ea7ff5fe8f49880066a58b2a744065ddd4a521" +- LLVM_COMMIT = "8b4b7d28f7c344c728a9812aa99d9ad24edb40a2" +- LLVM_SHA256 = "f585b8955f66849929bbe0b657ea7ff5fe8f49880066a58b2a744065ddd4a521" ++ LLVM_COMMIT = "df0864e761107b07e38f5503e0cbee0cebb4c5e8" ++ LLVM_SHA256 = "5bfcb7306d9d40f420862ace1f7ad3f01979facfb16ffd1fc80b6d91e92019fa" tf_http_archive( name = name, diff --git a/third_party/shardy/workspace.bzl b/third_party/shardy/workspace.bzl index cb2dc13e58bb06..b62c918736eb7e 100644 --- a/third_party/shardy/workspace.bzl +++ b/third_party/shardy/workspace.bzl @@ -3,8 +3,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): - SHARDY_COMMIT = "5013981d546b3bb99d9193841edcd5318cce3ce2" - SHARDY_SHA256 = "800366d7604691e63939e2cb3ecb4acfa349253f2f0ed8b1b84e6783fad55a01" + SHARDY_COMMIT = "f9efe2966f00f8e7da8f7af3f8c8b3255cc158b8" + SHARDY_SHA256 = "6ca4c5f2de2102eca2a78ab64a443b2d327fd7b0ceb8c633a67cd1a2a316a2db" tf_http_archive( name = "shardy", diff --git a/third_party/xla/third_party/shardy/temporary.patch b/third_party/xla/third_party/shardy/temporary.patch index f6938677141184..5fd5f295cd7dfc 100644 --- a/third_party/xla/third_party/shardy/temporary.patch +++ b/third_party/xla/third_party/shardy/temporary.patch @@ -1,15 +1,15 @@ diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl -index 55290bf..726a367 100644 +index 726a367..abe15ef 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" -- LLVM_COMMIT = "42b696d7b9942fdf07d65267da40ab178464adaa" -- LLVM_SHA256 = "4f0d2053b381d3f074c64b2e460792cab11a02333f1c88bbc22b01686cf2fcb0" -+ LLVM_COMMIT = "8b4b7d28f7c344c728a9812aa99d9ad24edb40a2" -+ LLVM_SHA256 = "f585b8955f66849929bbe0b657ea7ff5fe8f49880066a58b2a744065ddd4a521" +- LLVM_COMMIT = "8b4b7d28f7c344c728a9812aa99d9ad24edb40a2" +- LLVM_SHA256 = "f585b8955f66849929bbe0b657ea7ff5fe8f49880066a58b2a744065ddd4a521" ++ LLVM_COMMIT = "df0864e761107b07e38f5503e0cbee0cebb4c5e8" ++ LLVM_SHA256 = "5bfcb7306d9d40f420862ace1f7ad3f01979facfb16ffd1fc80b6d91e92019fa" tf_http_archive( name = name, diff --git a/third_party/xla/third_party/shardy/workspace.bzl b/third_party/xla/third_party/shardy/workspace.bzl index cb2dc13e58bb06..b62c918736eb7e 100644 --- a/third_party/xla/third_party/shardy/workspace.bzl +++ b/third_party/xla/third_party/shardy/workspace.bzl @@ -3,8 +3,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): - SHARDY_COMMIT = "5013981d546b3bb99d9193841edcd5318cce3ce2" - SHARDY_SHA256 = "800366d7604691e63939e2cb3ecb4acfa349253f2f0ed8b1b84e6783fad55a01" + SHARDY_COMMIT = "f9efe2966f00f8e7da8f7af3f8c8b3255cc158b8" + SHARDY_SHA256 = "6ca4c5f2de2102eca2a78ab64a443b2d327fd7b0ceb8c633a67cd1a2a316a2db" tf_http_archive( name = "shardy", From 485f5d0e7316d30ef511254242e109708d25276e Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 24 Sep 2024 04:53:10 -0700 Subject: [PATCH 184/483] Automated Code Change PiperOrigin-RevId: 678193775 --- tensorflow/core/kernels/bincount_op.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/kernels/bincount_op.cc b/tensorflow/core/kernels/bincount_op.cc index d6f8d3dbad9ed0..69905e46be8d0c 100644 --- a/tensorflow/core/kernels/bincount_op.cc +++ b/tensorflow/core/kernels/bincount_op.cc @@ -325,7 +325,7 @@ class DenseBincountOp : public OpKernel { const int64_t num_rows = data.dim_size(0); auto weight_matrix = (weights.NumElements() == 0) - ? weights.shaped(gtl::InlinedVector(2, 0)) + ? weights.shaped(absl::InlinedVector(2, 0)) : weights.matrix(); OP_REQUIRES_OK( ctx, ctx->allocate_output(0, TensorShape({num_rows, size}), &out_t)); From 8563573b93414625defe46d42cd826324d8aa19d Mon Sep 17 00:00:00 2001 From: Sergey Kozub Date: Tue, 24 Sep 2024 05:13:58 -0700 Subject: [PATCH 185/483] PR #17527: Improve readability for denormal numbers in FP8 conversion tests Imported from GitHub PR https://github.com/openxla/xla/pull/17527 As suggested in post-submit comments in https://github.com/openxla/xla/pull/17437 Limit the exponent of the test literals to the minimum representable value. Copybara import of the project: -- 39c77ff7b40f6e88603bae626158144108ce0a11 by Sergey Kozub : Improve readability for denormal numbers in E5M2 conversion tests Merging this change closes #17527 PiperOrigin-RevId: 678199529 --- third_party/xla/xla/tests/convert_test.cc | 148 +++++++++++----------- 1 file changed, 74 insertions(+), 74 deletions(-) diff --git a/third_party/xla/xla/tests/convert_test.cc b/third_party/xla/xla/tests/convert_test.cc index 1663ffe8a88619..10d9d29274e037 100644 --- a/third_party/xla/xla/tests/convert_test.cc +++ b/third_party/xla/xla/tests/convert_test.cc @@ -779,14 +779,14 @@ XLA_TEST_F(ConvertTest, ConvertF16F8e5m2Roundtrip) { {0x1.Cp-15, 0x1p-14}, // Smallest number rounding up to normal // Denormal tests - {0x1.0p-15, 0x1.0p-15}, // Denormal without rounding - {0x1.4p-15, 0x1.0p-15}, // Round-to-even down - {0x1.Cp-15, 0x1.0p-14}, // Round-to-even up - {0x1.3p-15, 0x1.0p-15}, // Round-to-nearest down - {0x1.5p-15, 0x1.8p-15}, // Round-to-nearest up - {0x1p-17, 0}, // Largest number that underflows - {0x1.04p-17, 0x1p-16}, // Smallest number that doesn't underflow - {0x1.BFp-15, 0x1.8p-15}, // Largest number that rounds to denormal + {0x0.8p-14, 0x0.8p-14}, // Denormal without rounding + {0x0.Ap-14, 0x0.8p-14}, // Round-to-even down + {0x0.Ep-14, 0x1.0p-14}, // Round-to-even up + {0x0.98p-14, 0x0.8p-14}, // Round-to-nearest down + {0x0.A8p-14, 0x0.Cp-14}, // Round-to-nearest up + {0x0.2p-14, 0}, // Largest number that underflows + {0x0.204p-14, 0x0.4p-14}, // Smallest number that doesn't underflow + {0x0.DFCp-14, 0x0.Cp-14}, // Largest number that rounds to denormal }; std::vector inputs; @@ -831,14 +831,14 @@ XLA_TEST_F(ConvertTest, DISABLED_ON_CPU(ConvertF32F8e5m2Roundtrip)) { {0x1.Cp-15, 0x1p-14}, // Smallest number rounding up to normal // Denormal tests - {0x1.0p-15, 0x1.0p-15}, // Denormal without rounding - {0x1.4p-15, 0x1.0p-15}, // Round-to-even down + {0x1.0p-15, 0x0.8p-14}, // Denormal without rounding + {0x1.4p-15, 0x0.8p-14}, // Round-to-even down {0x1.Cp-15, 0x1.0p-14}, // Round-to-even up - {0x1.3p-15, 0x1.0p-15}, // Round-to-nearest down - {0x1.5p-15, 0x1.8p-15}, // Round-to-nearest up + {0x1.3p-15, 0x0.8p-14}, // Round-to-nearest down + {0x1.5p-15, 0x0.Cp-14}, // Round-to-nearest up {0x1p-17, 0}, // Largest number that underflows - {0x1.000002p-17, 0x1p-16}, // Smallest number that doesn't underflow - {0x1.BFFFFEp-15, 0x1.8p-15}, // Largest number that rounds to denormal + {0x1.000002p-17, 0x0.4p-14}, // Smallest number that doesn't underflow + {0x1.BFFFFEp-15, 0x0.Cp-14}, // Largest number that rounds to denormal }; std::vector inputs; @@ -943,14 +943,14 @@ XLA_TEST_F(ConvertTest, ConvertF16F8e4m3fnRoundtrip) { {0x1.Ep-7, 0x1p-6}, // Smallest number rounding up to normal // Denormal tests - {0x1.0p-8, 0x1.0p-8}, // Denormal without rounding - {0x1.4p-8, 0x1.0p-8}, // Round-to-even down - {0x1.Cp-8, 0x1.0p-7}, // Round-to-even up - {0x1.3p-8, 0x1.0p-8}, // Round-to-nearest down - {0x1.5p-8, 0x1.8p-8}, // Round-to-nearest up - {0x1p-10, 0}, // Largest number that underflows - {0x1.004p-10, 0x1p-9}, // Smallest number that doesn't underflow - {0x1.DFCp-7, 0x1.Cp-7}, // Largest number that rounds to denormal + {0x1.0p-8, 0x0.4p-6}, // Denormal without rounding + {0x1.4p-8, 0x0.4p-6}, // Round-to-even down + {0x1.Cp-8, 0x0.8p-6}, // Round-to-even up + {0x1.3p-8, 0x0.4p-6}, // Round-to-nearest down + {0x1.5p-8, 0x0.6p-6}, // Round-to-nearest up + {0x1p-10, 0}, // Largest number that underflows + {0x1.004p-10, 0x0.2p-6}, // Smallest number that doesn't underflow + {0x1.DFCp-7, 0x0.Ep-6}, // Largest number that rounds to denormal }; std::vector inputs; @@ -994,14 +994,14 @@ XLA_TEST_F(ConvertTest, DISABLED_ON_CPU(ConvertF32F8e4m3fnRoundtrip)) { {0x1.Ep-7, 0x1p-6}, // Smallest number rounding up to normal // Denormal tests - {0x1.0p-8, 0x1.0p-8}, // Denormal without rounding - {0x1.4p-8, 0x1.0p-8}, // Round-to-even down - {0x1.Cp-8, 0x1.0p-7}, // Round-to-even up - {0x1.3p-8, 0x1.0p-8}, // Round-to-nearest down - {0x1.5p-8, 0x1.8p-8}, // Round-to-nearest up - {0x1p-10, 0}, // Largest number that underflows - {0x1.000002p-10, 0x1p-9}, // Smallest number that doesn't underflow - {0x1.DFFFFEp-7, 0x1.Cp-7}, // Largest number that rounds to denormal + {0x1.0p-8, 0x0.4p-6}, // Denormal without rounding + {0x1.4p-8, 0x0.4p-6}, // Round-to-even down + {0x1.Cp-8, 0x0.8p-6}, // Round-to-even up + {0x1.3p-8, 0x0.4p-6}, // Round-to-nearest down + {0x1.5p-8, 0x0.6p-6}, // Round-to-nearest up + {0x1p-10, 0}, // Largest number that underflows + {0x1.000002p-10, 0x0.2p-6}, // Smallest number that doesn't underflow + {0x1.DFFFFEp-7, 0x0.Ep-6}, // Largest number that rounds to denormal }; std::vector inputs; @@ -1106,14 +1106,14 @@ XLA_TEST_F(ConvertTest, ConvertF16F8e4m3b11fnuzRoundtrip) { {0x1.Ep-11, 0x1p-10}, // Smallest number rounding up to normal // Denormal tests - {0x1.0p-12, 0x1.0p-12}, // Denormal without rounding - {0x1.4p-12, 0x1.0p-12}, // Round-to-even down - {0x1.Cp-12, 0x1.0p-11}, // Round-to-even up - {0x1.3p-12, 0x1.0p-12}, // Round-to-nearest down - {0x1.5p-12, 0x1.8p-12}, // Round-to-nearest up + {0x1.0p-12, 0x0.4p-10}, // Denormal without rounding + {0x1.4p-12, 0x0.4p-10}, // Round-to-even down + {0x1.Cp-12, 0x0.8p-10}, // Round-to-even up + {0x1.3p-12, 0x0.4p-10}, // Round-to-nearest down + {0x1.5p-12, 0x0.6p-10}, // Round-to-nearest up {0x1p-14, 0}, // Largest number that underflows - {0x1.004p-14, 0x1p-13}, // Smallest number that doesn't underflow - {0x1.DFCp-11, 0x1.Cp-11}, // Largest number that rounds to denormal + {0x1.004p-14, 0x0.2p-10}, // Smallest number that doesn't underflow + {0x1.DFCp-11, 0x0.Ep-10}, // Largest number that rounds to denormal }; std::vector inputs; @@ -1157,14 +1157,14 @@ XLA_TEST_F(ConvertTest, DISABLED_ON_CPU(ConvertF32F8e4m3b11fnuzRoundtrip)) { {0x1.Ep-11, 0x1p-10}, // Smallest number rounding up to normal // Denormal tests - {0x1.0p-12, 0x1.0p-12}, // Denormal without rounding - {0x1.4p-12, 0x1.0p-12}, // Round-to-even down - {0x1.Cp-12, 0x1.0p-11}, // Round-to-even up - {0x1.3p-12, 0x1.0p-12}, // Round-to-nearest down - {0x1.5p-12, 0x1.8p-12}, // Round-to-nearest up + {0x1.0p-12, 0x0.4p-10}, // Denormal without rounding + {0x1.4p-12, 0x0.4p-10}, // Round-to-even down + {0x1.Cp-12, 0x0.8p-10}, // Round-to-even up + {0x1.3p-12, 0x0.4p-10}, // Round-to-nearest down + {0x1.5p-12, 0x0.6p-10}, // Round-to-nearest up {0x1p-14, 0}, // Largest number that underflows - {0x1.000002p-14, 0x1p-13}, // Smallest number that doesn't underflow - {0x1.DFFFFEp-11, 0x1.Cp-11}, // Largest number that rounds to denormal + {0x1.000002p-14, 0x0.2p-10}, // Smallest number that doesn't underflow + {0x1.DFFFFEp-11, 0x0.Ep-10}, // Largest number that rounds to denormal }; std::vector inputs; @@ -1271,14 +1271,14 @@ XLA_TEST_F(ConvertTest, ConvertF16F8e5m2fnuzRoundtrip) { {0x1.Cp-16, 0x1p-15}, // Smallest number rounding up to normal // Denormal tests - {0x1.0p-16, 0x1.0p-16}, // Denormal without rounding - {0x1.4p-16, 0x1.0p-16}, // Round-to-even down - {0x1.Cp-16, 0x1.0p-15}, // Round-to-even up - {0x1.3p-16, 0x1.0p-16}, // Round-to-nearest down - {0x1.5p-16, 0x1.8p-16}, // Round-to-nearest up - {0x1p-18, 0}, // Largest number that underflows - {0x1.04p-18, 0x1p-17}, // Smallest number that doesn't underflow - {0x1.BFp-16, 0x1.8p-16}, // Largest number that rounds to denormal + {0x0.4p-14, 0x0.8p-15}, // Denormal without rounding + {0x0.5p-14, 0x0.8p-15}, // Round-to-even down + {0x0.7p-14, 0x1.0p-15}, // Round-to-even up + {0x0.4Cp-14, 0x0.8p-15}, // Round-to-nearest down + {0x0.54p-14, 0x0.Cp-15}, // Round-to-nearest up + {0x0.1p-14, 0}, // Largest number that underflows + {0x0.104p-14, 0x0.4p-15}, // Smallest number that doesn't underflow + {0x0.6FCp-14, 0x0.Cp-15}, // Largest number that rounds to denormal }; std::vector inputs; @@ -1323,14 +1323,14 @@ XLA_TEST_F(ConvertTest, ConvertF32F8e5m2fnuzRoundtrip) { {0x1.Cp-16, 0x1p-15}, // Smallest number rounding up to normal // Denormal tests - {0x1.0p-16, 0x1.0p-16}, // Denormal without rounding - {0x1.4p-16, 0x1.0p-16}, // Round-to-even down + {0x1.0p-16, 0x0.8p-15}, // Denormal without rounding + {0x1.4p-16, 0x0.8p-15}, // Round-to-even down {0x1.Cp-16, 0x1.0p-15}, // Round-to-even up - {0x1.3p-16, 0x1.0p-16}, // Round-to-nearest down - {0x1.5p-16, 0x1.8p-16}, // Round-to-nearest up + {0x1.3p-16, 0x0.8p-15}, // Round-to-nearest down + {0x1.5p-16, 0x0.Cp-15}, // Round-to-nearest up {0x1p-18, 0}, // Largest number that underflows - {0x1.000002p-18, 0x1p-17}, // Smallest number that doesn't underflow - {0x1.BFFFFEp-16, 0x1.8p-16}, // Largest number that rounds to denormal + {0x1.000002p-18, 0x0.4p-15}, // Smallest number that doesn't underflow + {0x1.BFFFFEp-16, 0x0.Cp-15}, // Largest number that rounds to denormal }; std::vector inputs; @@ -1435,14 +1435,14 @@ XLA_TEST_F(ConvertTest, ConvertF16F8e4m3fnuzRoundtrip) { {0x1.Ep-8, 0x1p-7}, // Smallest number rounding up to normal // Denormal tests - {0x1.0p-9, 0x1.0p-9}, // Denormal without rounding - {0x1.4p-9, 0x1.0p-9}, // Round-to-even down - {0x1.Cp-9, 0x1.0p-8}, // Round-to-even up - {0x1.3p-9, 0x1.0p-9}, // Round-to-nearest down - {0x1.5p-9, 0x1.8p-9}, // Round-to-nearest up - {0x1p-11, 0}, // Largest number that underflows - {0x1.004p-11, 0x1p-10}, // Smallest number that doesn't underflow - {0x1.DFCp-8, 0x1.Cp-8}, // Largest number that rounds to denormal + {0x1.0p-9, 0x0.4p-7}, // Denormal without rounding + {0x1.4p-9, 0x0.4p-7}, // Round-to-even down + {0x1.Cp-9, 0x0.8p-7}, // Round-to-even up + {0x1.3p-9, 0x0.4p-7}, // Round-to-nearest down + {0x1.5p-9, 0x0.6p-7}, // Round-to-nearest up + {0x1p-11, 0}, // Largest number that underflows + {0x1.004p-11, 0x0.2p-7}, // Smallest number that doesn't underflow + {0x1.DFCp-8, 0x0.Ep-7}, // Largest number that rounds to denormal }; std::vector inputs; @@ -1486,14 +1486,14 @@ XLA_TEST_F(ConvertTest, ConvertF32F8e4m3fnuzRoundtrip) { {0x1.Ep-8, 0x1p-7}, // Smallest number rounding up to normal // Denormal tests - {0x1.0p-9, 0x1.0p-9}, // Denormal without rounding - {0x1.4p-9, 0x1.0p-9}, // Round-to-even down - {0x1.Cp-9, 0x1.0p-8}, // Round-to-even up - {0x1.3p-9, 0x1.0p-9}, // Round-to-nearest down - {0x1.5p-9, 0x1.8p-9}, // Round-to-nearest up - {0x1p-11, 0}, // Largest number that underflows - {0x1.000002p-11, 0x1p-10}, // Smallest number that doesn't underflow - {0x1.DFFFFEp-8, 0x1.Cp-8}, // Largest number that rounds to denormal + {0x1.0p-9, 0x0.4p-7}, // Denormal without rounding + {0x1.4p-9, 0x0.4p-7}, // Round-to-even down + {0x1.Cp-9, 0x0.8p-7}, // Round-to-even up + {0x1.3p-9, 0x0.4p-7}, // Round-to-nearest down + {0x1.5p-9, 0x0.6p-7}, // Round-to-nearest up + {0x1p-11, 0}, // Largest number that underflows + {0x1.000002p-11, 0x0.2p-7}, // Smallest number that doesn't underflow + {0x1.DFFFFEp-8, 0x0.Ep-7}, // Largest number that rounds to denormal }; std::vector inputs; From 153d49ec17be3cbc5b4672e24961b640669a749d Mon Sep 17 00:00:00 2001 From: Alexander Belyaev Date: Tue, 24 Sep 2024 05:33:16 -0700 Subject: [PATCH 186/483] [XLA:GPU][IndexAnalysis] Add a parser for indexing maps. PiperOrigin-RevId: 678205016 --- third_party/xla/xla/service/gpu/model/BUILD | 29 ++ .../gpu/model/indexing_map_serialization.cc | 493 ++++++++++++++++++ .../gpu/model/indexing_map_serialization.h | 35 ++ .../model/indexing_map_serialization_test.cc | 126 +++++ 4 files changed, 683 insertions(+) create mode 100644 third_party/xla/xla/service/gpu/model/indexing_map_serialization.cc create mode 100644 third_party/xla/xla/service/gpu/model/indexing_map_serialization.h create mode 100644 third_party/xla/xla/service/gpu/model/indexing_map_serialization_test.cc diff --git a/third_party/xla/xla/service/gpu/model/BUILD b/third_party/xla/xla/service/gpu/model/BUILD index 6ef51b853d38c1..4e4df85e58f5aa 100644 --- a/third_party/xla/xla/service/gpu/model/BUILD +++ b/third_party/xla/xla/service/gpu/model/BUILD @@ -521,6 +521,35 @@ xla_cc_test( ], ) +cc_library( + name = "indexing_map_serialization", + srcs = ["indexing_map_serialization.cc"], + hdrs = ["indexing_map_serialization.h"], + deps = [ + ":indexing_analysis", + "@com_google_absl//absl/strings", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:AsmParser", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + ], +) + +xla_cc_test( + name = "indexing_map_serialization_test", + srcs = ["indexing_map_serialization_test.cc"], + deps = [ + ":indexing_map_serialization", + ":indexing_test_utils", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest", + "@llvm-project//mlir:IR", + "@local_tsl//tsl/platform:test", + ], +) + cc_library( name = "indexing_test_utils", testonly = True, diff --git a/third_party/xla/xla/service/gpu/model/indexing_map_serialization.cc b/third_party/xla/xla/service/gpu/model/indexing_map_serialization.cc new file mode 100644 index 00000000000000..cbaea73fb879ed --- /dev/null +++ b/third_party/xla/xla/service/gpu/model/indexing_map_serialization.cc @@ -0,0 +1,493 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/model/indexing_map_serialization.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/strings/str_join.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/STLFunctionalExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/AsmParser/AsmParser.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/Support/LLVM.h" +#include "xla/service/gpu/model/indexing_map.h" + +namespace xla { +namespace gpu { +namespace { + +using llvm::SmallVector; +using llvm::SmallVectorImpl; +using llvm::StringRef; +using mlir::AffineExpr; +using mlir::AffineMap; +using mlir::ArrayRef; +using mlir::MLIRContext; + +enum class Delimeter { kParen, kBracket }; + +struct Token { + enum class Kind { + // Variable name, e.g. "d0", "s1". + kVarName, + // Integer literal. + kIntLiteral, + kBoolLiteral, + // Keywords + kKeywordDomain, + kKeywordIn, + kKeywordIsSimplified, + // Arithmetic operation, e.g. "+", "-", "*", "floorDiv", "mod". + kPlus, + kMinus, + kTimes, + kFloorDiv, + kMod, + // Punctuation. + kArrow, + kLParen, + kRParen, + kLBracket, + kRBracket, + kComma, + kColon, + // Status. + kError, + kEOF + }; + StringRef spelling; + Token::Kind kind; +}; + +Token::Kind GetSingleCharTokenType(char c) { + switch (c) { + case '(': + return Token::Kind::kLParen; + case ')': + return Token::Kind::kRParen; + case '[': + return Token::Kind::kLBracket; + case ']': + return Token::Kind::kRBracket; + case ',': + return Token::Kind::kComma; + case ':': + return Token::Kind::kColon; + case '+': + return Token::Kind::kPlus; + case '-': + return Token::Kind::kMinus; + case '*': + return Token::Kind::kTimes; + default: + return Token::Kind::kError; + } +} + +bool IsPartOfAffineExpr(Token token) { + return token.kind == Token::Kind::kVarName || + token.kind == Token::Kind::kIntLiteral || + token.kind == Token::Kind::kPlus || + token.kind == Token::Kind::kMinus || + token.kind == Token::Kind::kTimes || + token.kind == Token::Kind::kFloorDiv || + token.kind == Token::Kind::kMod; +} + +class Parser { + public: + explicit Parser(llvm::StringRef input) : input_(input), it_(input.begin()) { + // Set the parser to the first token. + Advance(); + } + + const Token& GetCurrentToken() const { return current_token_; }; + void Advance() { current_token_ = GetNextTokenImpl(); } + Token GetNextToken() { + Advance(); + return current_token_; + } + + bool ConsumeToken(Token::Kind kind); + bool ParseVarName(std::string* var_name); + bool ParseInt(int64_t* value); + bool ParseBool(bool* boolean); + bool ParseInterval(Interval* interval); + bool ParseAffineExprString(std::string* affine_expr_str); + bool ParseCommaSeparatedVarList( + Delimeter delimeter, + llvm::function_ref parse_element_fn); + + private: + void ConsumeWhitespace() { + while (it_ != input_.end() && std::isspace(*it_)) ++it_; + } + + // Parses the next token from the input and sets the iterator to the position + // right after it. + Token GetNextTokenImpl(); + + llvm::StringRef input_; + llvm::StringRef::iterator it_; + Token current_token_; +}; + +bool Parser::ParseVarName(std::string* var_name) { + if (current_token_.kind != Token::Kind::kVarName) { + llvm::errs() << "Expected var name, got: " << current_token_.spelling + << "\n"; + return false; + } + *var_name = current_token_.spelling.str(); + Advance(); + return true; +} + +bool Parser::ParseInt(int64_t* value) { + int val; + if (current_token_.kind != Token::Kind::kIntLiteral || + current_token_.spelling.getAsInteger(/*radix=*/0, val)) { + llvm::errs() << "Expected int literal, got: " << current_token_.spelling + << "\n"; + return false; + } + *value = static_cast(val); + Advance(); + return true; +} + +bool Parser::ParseBool(bool* boolean) { + if (current_token_.kind != Token::Kind::kBoolLiteral) { + llvm::errs() << "Expected bool literal, got: " << current_token_.spelling + << "\n"; + return false; + } + *boolean = current_token_.spelling.compare("true") == 0; + Advance(); + return true; +} + +bool Parser::ParseInterval(Interval* interval) { + if (!ConsumeToken(Token::Kind::kLBracket) || !ParseInt(&interval->lower) || + !ConsumeToken(Token::Kind::kComma) || !ParseInt(&interval->upper) || + !ConsumeToken(Token::Kind::kRBracket)) { + return false; + } + return interval; +} + +bool Parser::ParseAffineExprString(std::string* affine_expr_str) { + unsigned num_unmatched_parens = 0; + while (true) { + if (!IsPartOfAffineExpr(current_token_)) { + if (ConsumeToken(Token::Kind::kLParen)) { + ++num_unmatched_parens; + } else if (current_token_.kind == Token::Kind::kRParen && + num_unmatched_parens > 0) { + --num_unmatched_parens; + Advance(); + } else { + break; + } + } + affine_expr_str->append(current_token_.spelling); + affine_expr_str->push_back(' '); + Advance(); + } + return current_token_.kind != Token::Kind::kError; +} + +bool Parser::ParseCommaSeparatedVarList( + Delimeter delimeter, + llvm::function_ref parse_element_fn) { + auto left_delimiter = delimeter == Delimeter::kParen ? Token::Kind::kLParen + : Token::Kind::kLBracket; + auto right_delimiter = delimeter == Delimeter::kParen + ? Token::Kind::kRParen + : Token::Kind::kRBracket; + if (!ConsumeToken(left_delimiter)) { + return false; + } + if (ConsumeToken(right_delimiter)) { + return true; + } + std::string element; + while (parse_element_fn(*this)) { + if (ConsumeToken(Token::Kind::kComma)) continue; + return ConsumeToken(right_delimiter); + } + return false; +} + +bool Parser::ConsumeToken(Token::Kind kind) { + Token token = GetCurrentToken(); + if (token.kind != kind) { + return false; + } + GetNextToken(); + return true; +} + +Token Parser::GetNextTokenImpl() { + if (current_token_.kind == Token::Kind::kError || + current_token_.kind == Token::Kind::kEOF) { + return current_token_; + } + ConsumeWhitespace(); + if (it_ == input_.end()) { + return Token{"", Token::Kind::kEOF}; + } + auto start = it_; + if (std::isalpha(*it_)) { + // Variable name. + while (it_ != input_.end() && + (std::isalpha(*it_) || std::isdigit(*it_) || *it_ == '_')) { + ++it_; + } + StringRef spelling = input_.substr(start - input_.data(), it_ - start); + if (spelling == "true" || spelling == "false") { + return Token{spelling, Token::Kind::kBoolLiteral}; + } + if (spelling == "domain") { + return Token{spelling, Token::Kind::kKeywordDomain}; + } + if (spelling == "is_simplified") { + return Token{spelling, Token::Kind::kKeywordIsSimplified}; + } + if (spelling == "in") { + return Token{spelling, Token::Kind::kKeywordIn}; + } + if (spelling == "mod") { + return Token{spelling, Token::Kind::kMod}; + } + if (spelling == "floorDiv") { + return Token{spelling, Token::Kind::kFloorDiv}; + } + return Token{spelling, Token::Kind::kVarName}; + } + if (std::isdigit(*it_)) { + auto start = it_; + while (it_ != input_.end() && std::isdigit(*it_)) { + ++it_; + } + + StringRef spelling = input_.substr(start - input_.data(), it_ - start); + return Token{spelling, Token::Kind::kIntLiteral}; + } + if (*it_ == '-') { + ++it_; + if (it_ != input_.end() && *it_ == '>') { + ++it_; + return Token{"->", Token::Kind::kArrow}; + } else { + return Token{"-", Token::Kind::kMinus}; + } + } + StringRef spelling = input_.substr(start - input_.data(), 1); + return Token{spelling, GetSingleCharTokenType(*(it_++))}; +} + +// Parses a comma separated list of variable names. It is used to parse the +// lists of dimension and symbol variables. +bool ParseVarNames(Parser& parser, Delimeter delimeter, + SmallVectorImpl& var_names) { + auto parse_var_name_fn = [&](Parser& parser) { + std::string var_name; + if (!parser.ParseVarName(&var_name)) { + return false; + } + var_names.push_back(var_name); + return true; + }; + return parser.ParseCommaSeparatedVarList(delimeter, parse_var_name_fn); +} + +// Parses a comma separated list of affine expressions. It is used to parse +// the list of affine map results. +bool ParseAffineMapResults(Parser& parser, + SmallVectorImpl& affine_expr_strs) { + auto parse_var_name_fn = [&](Parser& parser) { + std::string affine_expr_str; + if (!parser.ParseAffineExprString(&affine_expr_str)) { + return false; + } + affine_expr_strs.push_back(affine_expr_str); + return true; + }; + return parser.ParseCommaSeparatedVarList(Delimeter::kParen, + parse_var_name_fn); +} + +// Assembles an affine map from the given dimension and symbol names and the +// affine expressions for the results. +bool ParseAffineExprsWithMLIR(ArrayRef dim_var_names, + ArrayRef symbol_var_names, + ArrayRef affine_expr_strings, + MLIRContext* context, + SmallVectorImpl& affine_exprs) { + std::stringstream ss; + ss << "affine_map<(" << absl::StrJoin(dim_var_names, ", ") << ") "; + if (!symbol_var_names.empty()) { + ss << '[' << absl::StrJoin(symbol_var_names, ", ") << "] "; + } + ss << " -> (" << absl::StrJoin(affine_expr_strings, ", ") << ")>"; + auto affine_map_attr = mlir::parseAttribute(ss.str(), context); + if (!affine_map_attr) { + llvm::errs() << "Failed to parse affine map: " << ss.str() << "\n"; + return false; + } + mlir::AffineMap affine_map = + mlir::cast(affine_map_attr).getValue(); + affine_exprs = llvm::to_vector(affine_map.getResults()); + return true; +} + +} // namespace + +std::optional ParseIndexingMap(llvm::StringRef input, + mlir::MLIRContext* context) { + Parser parser(input); + + // Parse variable names. + SmallVector dim_var_names; + SmallVector symbol_var_names; + if (!ParseVarNames(parser, Delimeter::kParen, dim_var_names) || + (parser.GetCurrentToken().kind == Token::Kind::kLBracket && + !ParseVarNames(parser, Delimeter::kBracket, symbol_var_names))) { + llvm::errs() << "Failed to parse variable names\n"; + return std::nullopt; + } + + // Parse affine map results. + SmallVector affine_expr_strs; + if (!parser.ConsumeToken(Token::Kind::kArrow) || + !ParseAffineMapResults(parser, affine_expr_strs)) { + llvm::errs() << "Failed to parse affine map results\n"; + return std::nullopt; + } + int num_affine_map_results = affine_expr_strs.size(); + + // Special case: no domain is printed for the empty map. + if (dim_var_names.empty() && symbol_var_names.empty()) { + if (num_affine_map_results != 0 || + parser.GetCurrentToken().kind != Token::Kind::kEOF) { + llvm::errs() << "Expected an empty indexing map\n"; + return std::nullopt; + } + return IndexingMap{AffineMap::get(context), /*dimensions=*/{}, + /*range_vars=*/{}, /*rt_vars=*/{}}; + } + + if (!parser.ConsumeToken(Token::Kind::kComma) || + !parser.ConsumeToken(Token::Kind::kKeywordDomain) || + !parser.ConsumeToken(Token::Kind::kColon)) { + return std::nullopt; + } + // Parse dimension variables. + std::vector dim_vars; + for (auto& dim_name : dim_var_names) { + std::string var_name; + Interval interval; + if (!parser.ParseVarName(&var_name) || + !parser.ConsumeToken(Token::Kind::kKeywordIn) || + !parser.ParseInterval(&interval) || + !parser.ConsumeToken(Token::Kind::kComma)) { + return std::nullopt; + } + if (var_name != dim_name) { + return std::nullopt; + } + dim_vars.push_back(DimVar{interval}); + } + // Parse range variables. + std::vector range_vars; + for (auto& symbol_var : symbol_var_names) { + std::string var_name; + Interval interval; + if (!parser.ParseVarName(&var_name) || + !parser.ConsumeToken(Token::Kind::kKeywordIn) || + !parser.ParseInterval(&interval) || + !parser.ConsumeToken(Token::Kind::kComma)) { + return std::nullopt; + } + if (var_name != symbol_var) { + return std::nullopt; + } + range_vars.push_back(RangeVar{interval}); + } + // Parse constraints. + SmallVector constraint_bounds; + while (!parser.ConsumeToken(Token::Kind::kKeywordIsSimplified)) { + std::string affine_expr_str; + Interval interval; + if (!parser.ParseAffineExprString(&affine_expr_str) || + !parser.ConsumeToken(Token::Kind::kKeywordIn) || + !parser.ParseInterval(&interval) || + !parser.ConsumeToken(Token::Kind::kComma)) { + return std::nullopt; + } + affine_expr_strs.push_back(affine_expr_str); + constraint_bounds.push_back(interval); + } + // Parse is_simplified. + bool is_simplified; + if (!parser.ConsumeToken(Token::Kind::kColon) || + !parser.ParseBool(&is_simplified)) { + return std::nullopt; + } + // Check that the input is consumed. + if (!parser.ConsumeToken(Token::Kind::kEOF)) { + return std::nullopt; + } + + // Parse affine expressions. + SmallVector affine_exprs; + if (!ParseAffineExprsWithMLIR(dim_var_names, symbol_var_names, + affine_expr_strs, context, affine_exprs)) { + return std::nullopt; + } + ArrayRef affine_map_results = + ArrayRef(affine_exprs).take_front(num_affine_map_results); + ArrayRef constraint_exprs = + ArrayRef(affine_exprs).drop_front(num_affine_map_results); + + // Populate constraints. + SmallVector> constraints; + constraints.reserve(constraint_exprs.size()); + for (const auto& [expr, bounds] : + llvm::zip(constraint_exprs, constraint_bounds)) { + constraints.push_back(std::make_pair(expr, bounds)); + } + auto map = AffineMap::get(dim_vars.size(), range_vars.size(), + affine_map_results, context); + return IndexingMap{ + map, std::move(dim_vars), std::move(range_vars), /*rt_vars=*/{}, + constraints, is_simplified}; +} + +} // namespace gpu +} // namespace xla diff --git a/third_party/xla/xla/service/gpu/model/indexing_map_serialization.h b/third_party/xla/xla/service/gpu/model/indexing_map_serialization.h new file mode 100644 index 00000000000000..ce09a90e65a36d --- /dev/null +++ b/third_party/xla/xla/service/gpu/model/indexing_map_serialization.h @@ -0,0 +1,35 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_MODEL_INDEXING_MAP_SERIALIZATION_H_ +#define XLA_SERVICE_GPU_MODEL_INDEXING_MAP_SERIALIZATION_H_ + +#include + +#include "llvm/ADT/StringRef.h" +#include "mlir/IR/MLIRContext.h" +#include "xla/service/gpu/model/indexing_map.h" + +namespace xla { +namespace gpu { + +// Parses the given string into an IndexingMap. +std::optional ParseIndexingMap(llvm::StringRef input, + mlir::MLIRContext* context); + +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_MODEL_INDEXING_MAP_SERIALIZATION_H_ diff --git a/third_party/xla/xla/service/gpu/model/indexing_map_serialization_test.cc b/third_party/xla/xla/service/gpu/model/indexing_map_serialization_test.cc new file mode 100644 index 00000000000000..7efd04b5442804 --- /dev/null +++ b/third_party/xla/xla/service/gpu/model/indexing_map_serialization_test.cc @@ -0,0 +1,126 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/model/indexing_map_serialization.h" + +#include +#include +#include "absl/strings/string_view.h" +#include "mlir/IR/MLIRContext.h" +#include "xla/service/gpu/model/indexing_test_utils.h" +#include "xla/tests/hlo_test_base.h" +#include "tsl/platform/test.h" + +namespace xla { +namespace gpu { +namespace { + +class IndexingMapSerializationTest : public HloTestBase { + public: + mlir::MLIRContext mlir_context_; + void ParseAndCheck(absl::string_view indexing_map_str) { + auto indexing_map = ParseIndexingMap(indexing_map_str, &mlir_context_); + ASSERT_TRUE(indexing_map.has_value()); + EXPECT_THAT(indexing_map->ToString(), + MatchIndexingString(indexing_map_str)); + } +}; + +TEST_F(IndexingMapSerializationTest, EmptyMap) { ParseAndCheck("() -> ()"); } + +TEST_F(IndexingMapSerializationTest, DimsOnly) { + ParseAndCheck(R"( + (d0, d1) -> (d0 mod 2 + d1), + domain: + d0 in [0, 3], + d1 in [0, 4], + is_simplified: true + )"); +} + +TEST_F(IndexingMapSerializationTest, SymbolsOnly) { + ParseAndCheck(R"( + ()[s0, s1] -> (s0 floordiv s1), + domain: + s0 in [0, 3], + s1 in [0, 4], + is_simplified: true + )"); +} + +TEST_F(IndexingMapSerializationTest, DimsAndSymbolsNoConstraints) { + ParseAndCheck(R"( + (d0, d1)[s0, s1, s2] -> (s2, d0 + d1, s1, s0), + domain: + d0 in [0, 3], + d1 in [0, 4], + s0 in [0, 1], + s1 in [0, 1], + s2 in [0, 3], + is_simplified: false + )"); +} + +TEST_F(IndexingMapSerializationTest, DimsAndSymbolsAndConstraints) { + ParseAndCheck(R"( + (d0, d1)[s0, s1, s2] -> (s2, d0 + d1, s1, s0), + domain: + d0 in [0, 3], + d1 in [0, 4], + s0 in [0, 1], + s1 in [0, 1], + s2 in [0, 3], + d0 mod 4 in [0, 0], + d1 + s0 in [0, 45], + is_simplified: false + )"); +} + +// This test will be updated when the printing uses types of variables. +TEST_F(IndexingMapSerializationTest, CustomNames) { + auto indexing_map_str = R"( + (th_x, bl_x)[vector_elem, reduced_dim, contracted_dim] + -> (contracted_dim, th_x + bl_x, reduced_dim, vector_elem), + domain: + th_x in [0, 3], + bl_x in [0, 4], + vector_elem in [0, 1], + reduced_dim in [0, 1], + contracted_dim in [0, 3], + th_x mod 4 in [0, 0], + bl_x + vector_elem in [0, 45], + is_simplified: false + )"; + auto indexing_map_golden = R"( + (d0, d1)[s0, s1, s2] -> (s2, d0 + d1, s1, s0), + domain: + d0 in [0, 3], + d1 in [0, 4], + s0 in [0, 1], + s1 in [0, 1], + s2 in [0, 3], + d0 mod 4 in [0, 0], + d1 + s0 in [0, 45], + is_simplified: false + )"; + auto indexing_map = ParseIndexingMap(indexing_map_str, &mlir_context_); + ASSERT_TRUE(indexing_map.has_value()); + EXPECT_THAT(indexing_map->ToString(), + MatchIndexingString(indexing_map_golden)); +} + +} // namespace +} // namespace gpu +} // namespace xla From 2c5391cdedea3e970856ce6b5cf52b239620bdbf Mon Sep 17 00:00:00 2001 From: "Dimitar (Mitko) Asenov" Date: Tue, 24 Sep 2024 06:39:25 -0700 Subject: [PATCH 187/483] [XLA:GPU] Add a test that ensures that certain passes are ordered as expected. This is just an initial CL. The goal is to add all such known constraints to this test. I will add more by looking at comments in a follow up. PiperOrigin-RevId: 678225547 --- third_party/xla/xla/service/gpu/BUILD | 2 + .../xla/xla/service/gpu/gpu_compiler_test.cc | 48 +++++++++++++++++++ 2 files changed, 50 insertions(+) diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index 8621ff79c54fff..2f2173c0295ae7 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -1627,8 +1627,10 @@ xla_test( "//xla/tests:filecheck", "//xla/tests:hlo_test_base", "//xla/tests:literal_test_util", + "//xla/tests:verified_hlo_module", "//xla/tests:xla_internal_test_main", "//xla/tsl/lib/core:status_test_util", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status:statusor", diff --git a/third_party/xla/xla/service/gpu/gpu_compiler_test.cc b/third_party/xla/xla/service/gpu/gpu_compiler_test.cc index b37415c9d9382e..b938ee13b1f55c 100644 --- a/third_party/xla/xla/service/gpu/gpu_compiler_test.cc +++ b/third_party/xla/xla/service/gpu/gpu_compiler_test.cc @@ -15,8 +15,10 @@ limitations under the License. #include "xla/service/gpu/gpu_compiler.h" +#include #include #include +#include #include #include #include @@ -25,6 +27,7 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" #include "absl/log/check.h" #include "absl/log/log.h" #include "absl/status/statusor.h" @@ -55,6 +58,7 @@ limitations under the License. #include "xla/tests/filecheck.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/literal_test_util.h" +#include "xla/tests/verified_hlo_module.h" #include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla_data.pb.h" #include "tsl/platform/casts.h" @@ -1051,6 +1055,50 @@ ENTRY main { expect_custom_kernel_fusion_rewriter_has_run); } +struct PassRunIndex { + int first_run = std::numeric_limits::max(); + int last_run = std::numeric_limits::min(); +}; + +// Checks that both passes have actually run and that the first run of the +// `after` pass is after the last run of the `before` pass. +void VerifyPassOrder( + const absl::flat_hash_map& passes, + absl::string_view before, absl::string_view after) { + EXPECT_TRUE(passes.contains(before)) + << "Expected pass did not run: " << before; + EXPECT_TRUE(passes.contains(after)) << "Expected pass did not run: " << after; + EXPECT_LT(passes.at(before).last_run, passes.at(after).first_run) + << "Pass " << before << " ran after " << after; +} + +TEST_F(GpuCompilerPassTest, PassesAreRunInCorrectOrder) { + constexpr absl::string_view constant_module = R"( +ENTRY main { + ROOT constant = f32[] constant(0) +})"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(constant_module)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr optimized_module, + GetOptimizedModule(std::move(module))); + + // Maps a pass name to its first and last index. + absl::flat_hash_map passes; + int run_index = 0; + for (const HloPassMetadata& pass_metadata : + optimized_module->metadata()->proto().pass_metadata()) { + auto& pass = passes[pass_metadata.pass_name()]; + pass.first_run = std::min(pass.first_run, run_index); + pass.last_run = std::max(pass.last_run, run_index); + ++run_index; + } + + // This test captures known dependencies between passes. + VerifyPassOrder(passes, "layout-assignment", "priority-fusion"); + VerifyPassOrder(passes, "layout-assignment", "layout_normalization"); +} + } // namespace } // namespace gpu } // namespace xla From 4cb0fd4b378e64b1bcf5d74c711a6b15bf1aa670 Mon Sep 17 00:00:00 2001 From: Dirk Hornung Date: Tue, 24 Sep 2024 09:30:13 -0700 Subject: [PATCH 188/483] Add custom kernel fusion to gemm fusion autotuner. The GemmFusionAutotuner currently takes a fusion and compares its runtime on different backends (Triton, CuBLAS and CuDNN). We add CustomKernelFusions (mostly Cutlass kernels) to the autotuner. Reverts 1fcc51367416e7dabe3fa5ed6df1756e594b8902 PiperOrigin-RevId: 678282242 --- third_party/xla/xla/autotuning.proto | 7 +- .../xla/xla/service/gpu/autotuning/BUILD | 11 +- .../gpu/autotuning/gemm_fusion_autotuner.cc | 317 +++++++++++++----- .../gpu/autotuning/gemm_fusion_autotuner.h | 21 +- .../autotuning/gemm_fusion_autotuner_test.cc | 220 +++++++++++- 5 files changed, 486 insertions(+), 90 deletions(-) diff --git a/third_party/xla/xla/autotuning.proto b/third_party/xla/xla/autotuning.proto index a7ffcbb57ae6ef..4cadf6dbb250eb 100644 --- a/third_party/xla/xla/autotuning.proto +++ b/third_party/xla/xla/autotuning.proto @@ -83,6 +83,10 @@ message AutotuneResult { int64 num_ctas = 7; } + message CustomKernelFusionKey { + int64 kernel_index = 1; + } + int64 scratch_bytes = 8; google.protobuf.Duration run_time = 9; @@ -93,10 +97,11 @@ message AutotuneResult { GemmKey gemm = 6; TritonGemmKey triton = 17; CudaConvPlanKey cuda_conv_plan = 15; + CustomKernelFusionKey custom_kernel_fusion = 18; stream_executor.dnn.AlgorithmProto algorithm = 16; } - // Next ID: 17 + // Next ID: 19 } message AutotuningLog { diff --git a/third_party/xla/xla/service/gpu/autotuning/BUILD b/third_party/xla/xla/service/gpu/autotuning/BUILD index 88f55c49da7e2f..42d9d85bf528a9 100644 --- a/third_party/xla/xla/service/gpu/autotuning/BUILD +++ b/third_party/xla/xla/service/gpu/autotuning/BUILD @@ -45,9 +45,11 @@ cc_library( "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", + "//xla/hlo/pass:hlo_pass_pipeline", "//xla/hlo/utils:hlo_query", "//xla/pjrt/distributed:key_value_store_interface", "//xla/service:algorithm_util", + "//xla/service:call_inliner", "//xla/service:dump", "//xla/service:executable", "//xla/service:float_normalization", @@ -58,12 +60,15 @@ cc_library( "//xla/service/gpu:backend_configs_cc", "//xla/service/gpu:buffer_comparator", "//xla/service/gpu:gpu_float_support", - "//xla/service/gpu:hlo_traversal", "//xla/service/gpu:ir_emission_utils", "//xla/service/gpu:matmul_utils", "//xla/service/gpu:split_k_gemm_rewriter", "//xla/service/gpu:stream_executor_util", + "//xla/service/gpu/kernels:custom_kernel", + "//xla/service/gpu/kernels:custom_kernel_fusion", + "//xla/service/gpu/kernels:custom_kernel_fusion_pattern", "//xla/service/gpu/transforms:cudnn_fusion_compiler", + "//xla/service/gpu/transforms:custom_kernel_fusion_rewriter", "//xla/service/gpu/transforms:fusion_wrapper", "//xla/service/gpu/transforms:gemm_rewriter", "//xla/service/gpu/transforms:priority_fusion", @@ -72,11 +77,9 @@ cc_library( "//xla/stream_executor:device_memory", "//xla/stream_executor:semantic_version", "//xla/stream_executor:stream_executor_memory_allocator", - "//xla/stream_executor/gpu:redzone_allocator", "//xla/tools:hlo_decomposer_lib", "//xla/tsl/lib/core:bits", "//xla/tsl/util/proto:proto_utils", - "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", @@ -137,6 +140,8 @@ xla_test( "//xla/stream_executor:device_description", "//xla/stream_executor:device_description_proto_cc", "//xla/stream_executor:semantic_version", + "//xla/stream_executor:stream_executor_h", + "//xla/stream_executor/gpu:gpu_executor_header", "//xla/tests:filecheck", "//xla/tests:hlo_test_base", "//xla/tests:test_utils", diff --git a/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc index 79524924584c97..f18850be8ab3ee 100644 --- a/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc +++ b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc @@ -19,7 +19,6 @@ limitations under the License. #include #include #include -#include #include #include #include @@ -27,7 +26,6 @@ limitations under the License. #include #include -#include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/log/check.h" @@ -51,24 +49,28 @@ limitations under the License. #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/pass/hlo_pass_pipeline.h" #include "xla/hlo/utils/hlo_query.h" #include "xla/pjrt/distributed/key_value_store_interface.h" #include "xla/primitive_util.h" #include "xla/service/algorithm_util.h" +#include "xla/service/call_inliner.h" #include "xla/service/dump.h" -#include "xla/service/executable.h" #include "xla/service/float_normalization.h" #include "xla/service/gpu/autotuning/autotuner_compile_util.h" #include "xla/service/gpu/autotuning/autotuner_util.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/buffer_comparator.h" #include "xla/service/gpu/gpu_float_support.h" -#include "xla/service/gpu/hlo_traversal.h" #include "xla/service/gpu/ir_emission_utils.h" +#include "xla/service/gpu/kernels/custom_kernel.h" +#include "xla/service/gpu/kernels/custom_kernel_fusion.h" +#include "xla/service/gpu/kernels/custom_kernel_fusion_pattern.h" #include "xla/service/gpu/matmul_utils.h" #include "xla/service/gpu/split_k_gemm_rewriter.h" #include "xla/service/gpu/stream_executor_util.h" #include "xla/service/gpu/transforms/cudnn_fusion_compiler.h" +#include "xla/service/gpu/transforms/custom_kernel_fusion_rewriter.h" #include "xla/service/gpu/transforms/fusion_wrapper.h" #include "xla/service/gpu/transforms/gemm_rewriter.h" #include "xla/service/gpu/transforms/priority_fusion.h" @@ -82,7 +84,6 @@ limitations under the License. #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/device_memory_allocator.h" -#include "xla/stream_executor/gpu/redzone_allocator.h" #include "xla/stream_executor/semantic_version.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor_memory_allocator.h" @@ -140,76 +141,6 @@ constexpr std::array kNumCtas = {1, 2, 4, 8, 16}; using AutoTuneCacheKeyCount = absl::flat_hash_map; -class GemmFusionAutotunerVisitor : public DfsHloRewriteVisitor { - public: - explicit GemmFusionAutotunerVisitor(const AutotuneConfig& config) - : config_(config) {} - - absl::Status HandleFusion(HloInstruction* hlo) override { - TF_ASSIGN_OR_RETURN(auto gpu_config, - hlo->backend_config()); - FusionBackendConfig& backend_config = - *gpu_config.mutable_fusion_backend_config(); - if (backend_config.kind() != kTritonGemmFusionKind && - backend_config.kind() != kCuDnnFusionKind) { - return absl::OkStatus(); - } - - VLOG(4) << "Processing " << hlo->ToString(); - if (!backend_config.has_triton_gemm_config() && - !backend_config.has_cudnn_fusion_config()) { - TF_ASSIGN_OR_RETURN( - AutotuneResult autotune_result, - AutotunerUtil::Autotune( - hlo, config_, [&]() -> absl::StatusOr { - if (config_.IsDeviceless()) { - return absl::InternalError(absl::StrCat( - "Expect autotune result cache hit for deviceless " - "compilation (HLO: ", - hlo->ToString(), ")")); - } - return absl::InternalError("Expect autotune result cache hit."); - })); - VLOG(4) << "Result: " << autotune_result.ShortDebugString(); - - if (autotune_result.has_triton()) { - *backend_config.mutable_triton_gemm_config() = autotune_result.triton(); - TF_RETURN_IF_ERROR(hlo->set_backend_config(gpu_config)); - } else if (autotune_result.has_gemm()) { - // Falling back to cuBLAS: Converting the fusion to a Call, so that it - // can be inlined back again. - HloComputation* const computation = hlo->parent(); - HloInstruction* const call = computation->AddInstruction( - HloInstruction::CreateCall(hlo->shape(), hlo->operands(), - hlo->fused_instructions_computation())); - TF_RETURN_IF_ERROR(computation->ReplaceInstruction(hlo, call)); - hlo = call; - } else { - CHECK(autotune_result.has_algorithm()); - backend_config.set_kind(std::string(kCuDnnFusionKind)); - backend_config.mutable_cudnn_fusion_config()->set_plan_id( - autotune_result.algorithm().algo_id()); - TF_RETURN_IF_ERROR(hlo->set_backend_config(gpu_config)); - } - } - - if (backend_config.has_triton_gemm_config()) { - TF_ASSIGN_OR_RETURN( - const TritonGemmConfig config, - TritonGemmConfig::FromProto(backend_config.triton_gemm_config())); - if (config.split_k > 1) { - TF_RETURN_IF_ERROR(MakeDotSplitKBatch(hlo, config)); - } - } - - MarkAsChanged(); - return absl::OkStatus(); - } - - private: - AutotuneConfig config_; -}; - class GemmConfigSetCollector : public ConstDfsHloVisitorWithDefault { public: explicit GemmConfigSetCollector(GemmFusionAutotunerImpl* impl) @@ -259,7 +190,9 @@ class GemmConfigSetCollector : public ConstDfsHloVisitorWithDefault { bool missing_config = (backend_config.kind() == kTritonGemmFusionKind && !backend_config.has_triton_gemm_config()) || (backend_config.kind() == kCuDnnFusionKind && - !backend_config.has_cudnn_fusion_config()); + !backend_config.has_cudnn_fusion_config()) || + (backend_config.kind() == kCustomFusionKind && + !backend_config.has_custom_fusion_config()); if (missing_config) { if (error_out_on_cache_miss_) { return absl::NotFoundError(absl::StrCat( @@ -427,6 +360,46 @@ absl::StatusOr> CublasGemmAutotuneExtractor( return new_module; } +absl::Status UpdateFusionInstructionKernelIndex( + HloInstruction* fusion_instruction, int kernel_index) { + GpuBackendConfig gpu_config = + fusion_instruction->backend_config().value(); + gpu_config.mutable_fusion_backend_config() + ->mutable_custom_fusion_config() + ->set_kernel_index(kernel_index); + TF_RETURN_IF_ERROR(fusion_instruction->set_backend_config(gpu_config)); + + return absl::OkStatus(); +} + +absl::StatusOr> CustomFusionKernelAutotuneExtractor( + const GemmFusionAutotunerImpl::CustomKernelFusionConfig& cutlass_config, + const AutotuneConfig& config, const se::SemanticVersion& toolkit_version, + const HloFusionInstruction* fusion, const DebugOptions& debug_opts) { + const HloComputation* fusion_computation = fusion->called_computation(); + std::unique_ptr new_module = + ExtractComputationIntoNewModule(*fusion_computation); + new_module->mutable_config().set_debug_options(debug_opts); + + CustomKernelFusionRewriter rewriter( + &config.GetExecutor()->GetDeviceDescription()); + PriorityFusion fusion_pass( + /*thread_pool=*/nullptr, config.GetExecutor()->GetDeviceDescription(), + PriorityFusionOptions()); + TF_RETURN_IF_ERROR(rewriter.Run(new_module.get()).status()); + TF_RETURN_IF_ERROR(fusion_pass.Run(new_module.get()).status()); + + // Select custom kernel fusion kernel. + HloInstruction* custom_kernel_fusion = + hlo_query::GetFirstInstructionWithOpcode(*new_module->entry_computation(), + HloOpcode::kFusion); + int64_t kernel_index = cutlass_config.kernel_index; + TF_RETURN_IF_ERROR( + UpdateFusionInstructionKernelIndex(custom_kernel_fusion, kernel_index)); + + return new_module; +} + absl::StatusOr> FusionExtractor( const HloFusionInstruction& fusion, const DebugOptions& debug_opts) { std::unique_ptr module = ExtractInstructionIntoNewModule(fusion); @@ -475,6 +448,11 @@ AutotuneResult FromConfig(const BackendConfig& config) { AutotuneResult res; if (std::holds_alternative(config)) { res.mutable_gemm()->set_algorithm(CUBLAS_GEMM_DEFAULT); + } else if (std::holds_alternative< + GemmFusionAutotunerImpl::CustomKernelFusionConfig>(config)) { + res.mutable_custom_kernel_fusion()->set_kernel_index( + std::get(config) + .kernel_index); } else if (std::holds_alternative( config)) { res.mutable_algorithm()->set_algo_id( @@ -574,6 +552,98 @@ std::string Serialize(const BackendConfig& config) { } // anonymous namespace +absl::Status RewriteGemmFusionToCall(HloInstruction* fusion_instr) { + // Falling back to cuBLAS: Converting the fusion to a Call, so that it + // can be inlined back again. + HloComputation* const computation = fusion_instr->parent(); + HloInstruction* const call = + computation->AddInstruction(HloInstruction::CreateCall( + fusion_instr->shape(), fusion_instr->operands(), + fusion_instr->fused_instructions_computation())); + return computation->ReplaceInstruction(fusion_instr, call); +} + +absl::Status RewriteGemmFusionToCustomKernelFusion( + HloInstruction* fusion_instr, se::DeviceDescription device_description, + int64_t kernel_index) { + // Rewrites gemm fusion to custom kernel fusion. + // First convert the fusion to a call. Then inlines the call. Then + // rewrites to custom kernel fusion. + HloComputation* const computation = fusion_instr->parent(); + HloInstruction* const call = + computation->AddInstruction(HloInstruction::CreateCall( + fusion_instr->shape(), fusion_instr->operands(), + fusion_instr->fused_instructions_computation())); + TF_RETURN_IF_ERROR(computation->ReplaceInstruction(fusion_instr, call)); + HloPassPipeline pipeline("autotuner_custom_kernel_fusion_rewriter"); + pipeline.AddPass(); + pipeline.AddPass(&device_description, + kernel_index); + HloModule* hlo_module = call->GetModule(); + return pipeline.Run(hlo_module).status(); +} + +absl::Status GemmFusionAutotunerRewriterVisitor::HandleFusion( + HloInstruction* fusion_instr) { + TF_ASSIGN_OR_RETURN(auto gpu_config, + fusion_instr->backend_config()); + FusionBackendConfig& backend_config = + *gpu_config.mutable_fusion_backend_config(); + if (backend_config.kind() != kTritonGemmFusionKind && + backend_config.kind() != kCuDnnFusionKind && + backend_config.kind() != kCustomFusionKind) { + return absl::OkStatus(); + } + + VLOG(4) << "Processing " << fusion_instr->ToString(); + if (!backend_config.has_triton_gemm_config() && + !backend_config.has_cudnn_fusion_config() && + !backend_config.has_custom_fusion_config()) { + TF_ASSIGN_OR_RETURN( + AutotuneResult autotune_result, + AutotunerUtil::Autotune( + fusion_instr, config_, [&]() -> absl::StatusOr { + if (config_.IsDeviceless()) { + return absl::InternalError(absl::StrCat( + "Expect autotune result cache hit for deviceless " + "compilation (HLO: ", + fusion_instr->ToString(), ")")); + } + return absl::InternalError("Expect autotune result cache hit."); + })); + VLOG(4) << "Result: " << autotune_result.ShortDebugString(); + + if (autotune_result.has_triton()) { + *backend_config.mutable_triton_gemm_config() = autotune_result.triton(); + TF_RETURN_IF_ERROR(fusion_instr->set_backend_config(gpu_config)); + } else if (autotune_result.has_gemm()) { + TF_RETURN_IF_ERROR(RewriteGemmFusionToCall(fusion_instr)); + } else if (autotune_result.has_custom_kernel_fusion()) { + TF_RETURN_IF_ERROR(RewriteGemmFusionToCustomKernelFusion( + fusion_instr, config_.GetExecutor()->GetDeviceDescription(), + autotune_result.custom_kernel_fusion().kernel_index())); + } else { + CHECK(autotune_result.has_algorithm()); + backend_config.set_kind(std::string(kCuDnnFusionKind)); + backend_config.mutable_cudnn_fusion_config()->set_plan_id( + autotune_result.algorithm().algo_id()); + TF_RETURN_IF_ERROR(fusion_instr->set_backend_config(gpu_config)); + } + } + + if (backend_config.has_triton_gemm_config()) { + TF_ASSIGN_OR_RETURN( + const TritonGemmConfig config, + TritonGemmConfig::FromProto(backend_config.triton_gemm_config())); + if (config.split_k > 1) { + TF_RETURN_IF_ERROR(MakeDotSplitKBatch(fusion_instr, config)); + } + } + + MarkAsChanged(); + return absl::OkStatus(); +} + // Methods required for sorting the configs. bool GemmFusionAutotunerImpl::CuBlasConfig::operator<( const CuBlasConfig& other) const { @@ -583,6 +653,10 @@ bool GemmFusionAutotunerImpl::CuDnnConfig::operator<( const CuDnnConfig& other) const { return plan_id < other.plan_id; } +bool GemmFusionAutotunerImpl::CustomKernelFusionConfig::operator<( + const CustomKernelFusionConfig& other) const { + return false; +} bool GemmFusionAutotunerImpl::IsAutotuningEnabled() const { return debug_options_.xla_gpu_autotune_level() > 0 && @@ -603,6 +677,72 @@ bool GemmFusionAutotunerImpl::IsAutotuningEnabled() const { } } +std::vector GenerateCustomKernelFusionConfigs( + const HloFusionInstruction& fusion, + se::DeviceDescription device_description) { + std::vector configs; + const CustomKernelFusionPatternRegistry* patterns = + CustomKernelFusionPatternRegistry::Default(); + HloComputation* computation = fusion.called_computation(); + // Get the first dot instruction in the fusion body. + HloInstruction* dot_instruction = + hlo_query::GetFirstInstructionWithOpcode(*computation, HloOpcode::kDot); + std::vector match = + patterns->Match(device_description, dot_instruction); + + // For Cutlass we expect only one match for a gemm fusion. + if (match.size() == 1) { + CustomKernelFusionRegistry* registry = + CustomKernelFusionRegistry::Default(); + auto* custom_kernel_fusion = registry->Lookup(match[0].config().name()); + + // If custom fusion is not found it means that some of the build targets + // might not be statically linked into the binary. + if (custom_kernel_fusion != nullptr) { + // There can be multiple kernels for a single fusion pattern, which are + // selected by the kernel_index. + // To get the number of kernels we can rewrite the fusion to custom kernel + // fusion and count the number of loaded kernels. + const HloComputation* fusion_computation = fusion.called_computation(); + std::unique_ptr new_module = + ExtractComputationIntoNewModule(*fusion_computation); + CustomKernelFusionRewriter rewriter(&device_description); + absl::StatusOr changed = rewriter.Run(new_module.get()); + if (!changed.ok() || !changed.value()) { + VLOG(2) << "Skip custom kernel config. Failed to rewrite custom kernel " + "fusion: " + << changed.status(); + return configs; + } + + HloInstruction* custom_kernel_fusion_instr = + hlo_query::GetFirstInstructionWithOpcode( + *new_module->entry_computation(), HloOpcode::kFusion); + if (custom_kernel_fusion_instr == nullptr) { + VLOG(2) << "Skip custom kernel config. Failed to find custom kernel " + "fusion instruction in the rewritten module."; + return configs; + } + absl::StatusOr> kernels = + custom_kernel_fusion->LoadKernels( + device_description, + custom_kernel_fusion_instr->fused_instructions_computation()); + if (!kernels.ok()) { + VLOG(2) << "Skip custom kernel config. Failed to load custom kernels: " + << kernels.status(); + } else { + for (int i = 0; i < kernels.value().size(); ++i) { + GemmFusionAutotunerImpl::CustomKernelFusionConfig config{ + /*kernel_index=*/i}; + configs.push_back(config); + } + } + } + } + + return configs; +} + absl::StatusOr> GemmFusionAutotunerImpl::GenerateConfigs(const HloFusionInstruction& fusion) { const HloDotInstruction* dot = @@ -642,6 +782,19 @@ GemmFusionAutotunerImpl::GenerateConfigs(const HloFusionInstruction& fusion) { } } + // Add CustomKernelFusion (Cutlass) configs, if available. + // Go through all the instructions in the fusion body try to match them to + // a custom kernel fusion pattern. + if ((IsFusionKind(fusion, kCustomFusionKind) || + IsFusionKind(fusion, kTritonGemmFusionKind)) && + IsAutotuningEnabled() && !config_.IsDeviceless()) { + std::vector custom_kernel_fusion_configs = + GenerateCustomKernelFusionConfigs( + fusion, config_.GetExecutor()->GetDeviceDescription()); + configs.insert(configs.end(), custom_kernel_fusion_configs.begin(), + custom_kernel_fusion_configs.end()); + } + // Add triton configs. TF_ASSIGN_OR_RETURN(std::vector triton_configs, GenerateTritonConfigs(*dot)); @@ -805,6 +958,14 @@ GemmFusionAutotunerImpl::CompileAll(AutotunerCompileUtil& compile_util, config_, config_.GetExecutor()->GetDeviceDescription(), toolkit_version_, fusion, opts); })); + } else if (std::holds_alternative(config)) { + TF_ASSIGN_OR_RETURN(executable, + compile_util.Compile([&](const DebugOptions& opts) { + return CustomFusionKernelAutotuneExtractor( + std::get(config), + config_, toolkit_version_, fusion, opts); + })); + } else { LOG(FATAL) << "Unsupported config type: " << config.index(); } @@ -1305,8 +1466,8 @@ absl::StatusOr GemmFusionAutotuner::Run( } } - return GemmFusionAutotunerVisitor(config_).RunOnModule(module, - execution_threads); + return GemmFusionAutotunerRewriterVisitor(config_).RunOnModule( + module, execution_threads); } } // namespace gpu diff --git a/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.h b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.h index 7c262ffc8c613b..17272607532c20 100644 --- a/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.h +++ b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.h @@ -29,7 +29,9 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/autotuning.pb.h" +#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" #include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/pass/hlo_pass_interface.h" @@ -46,6 +48,18 @@ limitations under the License. namespace xla { namespace gpu { +// Uses profile results to rewrite a gemm fusion to use the best backend. +class GemmFusionAutotunerRewriterVisitor : public DfsHloRewriteVisitor { + public: + explicit GemmFusionAutotunerRewriterVisitor(const AutotuneConfig& config) + : config_(config) {} + + absl::Status HandleFusion(HloInstruction* fusion_instr) override; + + private: + AutotuneConfig config_; +}; + // Takes a gemm fusion and chooses between cuBLAS, cuDNN, and Triton backends. // In the case of Triton, it also chooses the best tiling configuration. // @@ -99,8 +113,13 @@ class GemmFusionAutotunerImpl { int64_t plan_id; bool operator<(const CuDnnConfig& other) const; }; + struct CustomKernelFusionConfig { + int64_t kernel_index; + bool operator<(const CustomKernelFusionConfig& other) const; + }; using BackendConfig = - std::variant; + std::variant; using BackendConfigs = std::vector< std::pair>>; diff --git a/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner_test.cc b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner_test.cc index d9bec3a09906a8..bb70c963a9c450 100644 --- a/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner_test.cc +++ b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner_test.cc @@ -50,7 +50,9 @@ limitations under the License. #include "xla/service/pattern_matcher_gmock.h" #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/device_description.pb.h" +#include "xla/stream_executor/gpu/gpu_executor.h" #include "xla/stream_executor/semantic_version.h" +#include "xla/stream_executor/stream_executor.h" #include "xla/tests/filecheck.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/test_utils.h" @@ -195,6 +197,25 @@ class GemmFusionAutotunerTest : public StatelessAutotunerTest { .cuda_compute_capability(); } + absl::StatusOr> + GetPossibleMatmulAutotuneConfigs( + const HloFusionInstruction& fusion, + const se::CudaComputeCapability& compute_capability, + const se::SemanticVersion& toolkit_version, + const DebugOptions& debug_options) { + se::GpuDeviceInfoProto deviceless_proto; + auto ccc = deviceless_proto.mutable_cuda_compute_capability(); + ccc->set_major(compute_capability.major); + ccc->set_minor(compute_capability.minor); + + DeviceConfig test_config{backend().default_stream_executor(), + backend().memory_allocator()}; + AutotuneConfig autotune_config{test_config, debug_options}; + GemmFusionAutotunerImpl autotuner(autotune_config, toolkit_version, + debug_options, nullptr); + return autotuner.GenerateConfigs(fusion); + } + void CheckTritonAutotuning(absl::string_view hlo, absl::string_view expected) { HloPassPipeline pipeline("gemm_rewrite"); @@ -247,7 +268,8 @@ class GemmFusionAutotunerTestWithMorePreciseReduction } }; -absl::StatusOr> GetPossibleMatmulAutotuneConfigs( +absl::StatusOr> +GetPossibleMatmulAutotuneTritonConfigs( const HloDotInstruction& dot, const se::CudaComputeCapability& compute_capability, const se::SemanticVersion& toolkit_version, @@ -276,7 +298,7 @@ ENTRY e { se::CudaComputeCapability::AMPERE, /*minor=*/0}; TF_ASSERT_OK_AND_ASSIGN( const std::vector configs, - GetPossibleMatmulAutotuneConfigs( + GetPossibleMatmulAutotuneTritonConfigs( *Cast( module->entry_computation()->root_instruction()), compute_capability, GetToolkitVersion(), GetDebugOptionsForTest())); @@ -298,7 +320,7 @@ ENTRY e { se::CudaComputeCapability::AMPERE, /*minor=*/0}; TF_ASSERT_OK_AND_ASSIGN( const std::vector configs, - GetPossibleMatmulAutotuneConfigs( + GetPossibleMatmulAutotuneTritonConfigs( *Cast( module->entry_computation()->root_instruction()), compute_capability, GetToolkitVersion(), GetDebugOptionsForTest())); @@ -320,7 +342,7 @@ ENTRY e { se::CudaComputeCapability::AMPERE, /*minor=*/0}; TF_ASSERT_OK_AND_ASSIGN( const std::vector configs, - GetPossibleMatmulAutotuneConfigs( + GetPossibleMatmulAutotuneTritonConfigs( *Cast( module->entry_computation()->root_instruction()), compute_capability, GetToolkitVersion(), GetDebugOptionsForTest())); @@ -872,7 +894,7 @@ ENTRY e { se::CudaComputeCapability::AMPERE, /*minor=*/0}; TF_ASSERT_OK_AND_ASSIGN( const std::vector configs, - GetPossibleMatmulAutotuneConfigs( + GetPossibleMatmulAutotuneTritonConfigs( *Cast( module->entry_computation()->root_instruction()), compute_capability, GetToolkitVersion(), GetDebugOptionsForTest())); @@ -904,7 +926,7 @@ ENTRY e { se::CudaComputeCapability::AMPERE, /*minor=*/0}; TF_ASSERT_OK_AND_ASSIGN( const std::vector configs, - GetPossibleMatmulAutotuneConfigs( + GetPossibleMatmulAutotuneTritonConfigs( *Cast( module->entry_computation()->root_instruction()), compute_capability, GetToolkitVersion(), GetDebugOptionsForTest())); @@ -935,7 +957,7 @@ ENTRY wais { TF_ASSERT_OK_AND_ASSIGN( const std::vector configs, - GetPossibleMatmulAutotuneConfigs( + GetPossibleMatmulAutotuneTritonConfigs( *Cast( module->entry_computation()->root_instruction()), compute_capability, GetToolkitVersion(), debug_options)); @@ -999,6 +1021,190 @@ ENTRY entry { CHECK_OK(autotuner.CompileAll(*compile_util, configs)); } +TEST_F(GemmFusionAutotunerTest, CreatesCustomKernelFusionConfigs) { + const std::string kHlo = R"( + HloModule module, entry_computation_layout={(bf16[1024,1024]{1,0}, bf16[1024,1024]{1,0})->f32[1024,1024]{1,0}} + + %gemm_fusion_r_computation { + %parameter_0 = bf16[1024,1024]{1,0} parameter(0) + %convert.2 = f32[1024,1024]{1,0} convert(%parameter_0) + %parameter_1 = bf16[1024,1024]{1,0} parameter(1) + %convert.3 = f32[1024,1024]{1,0} convert(%parameter_1) + ROOT %r.1 = f32[1024,1024]{1,0} dot(%convert.2, %convert.3), lhs_contracting_dims={1}, rhs_contracting_dims={0} + } + + ENTRY main { + %p0 = bf16[1024,1024]{1,0} parameter(0) + %p1 = bf16[1024,1024]{1,0} parameter(1) + ROOT %gemm_fusion_r = f32[1024,1024]{1,0} fusion(%p0, %p1), kind=kCustom, calls=gemm_fusion_r_computation, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"fusion_backend_config":{"kind":"__triton_gemm"},"force_earliest_schedule":false} + })"; + + std::unique_ptr module = + ParseAndReturnVerifiedModule(kHlo).value(); + const se::CudaComputeCapability compute_capability{ + se::CudaComputeCapability::AMPERE, /*minor=*/0}; + + TF_ASSERT_OK_AND_ASSIGN( + const std::vector configs, + GetPossibleMatmulAutotuneConfigs( + *Cast( + module->entry_computation()->root_instruction()), + compute_capability, GetToolkitVersion(), GetDebugOptionsForTest())); + EXPECT_TRUE(std::any_of( + configs.begin(), configs.end(), + [](const GemmFusionAutotunerImpl::BackendConfig& config) { + return std::holds_alternative< + GemmFusionAutotunerImpl::CustomKernelFusionConfig>(config); + })); +} + +TEST_F(GemmFusionAutotunerTest, GeneratesConfigForUpcastGemmWithPrologue) { + const std::string kHlo = R"( + HloModule module + + %gemm_fusion_r_computation (parameter_0.1: f32[1,256,4,4096], parameter_1.1: bf16[1,4,4096,4096]) -> f32[256,4096] { + %parameter_0.1 = f32[1,256,4,4096]{3,2,1,0} parameter(0) + %bitcast.60 = f32[256,16384]{1,0} bitcast(f32[1,256,4,4096]{3,2,1,0} %parameter_0.1) + %parameter_1.1 = bf16[1,4,4096,4096]{3,2,1,0} parameter(1) + %bitcast.61 = bf16[16384,4096]{1,0} bitcast(bf16[1,4,4096,4096]{3,2,1,0} %parameter_1.1) + %convert.22 = f32[16384,4096]{1,0} convert(bf16[16384,4096]{1,0} %bitcast.61) + ROOT r = f32[256,4096]{1,0} dot(f32[256,16384]{1,0} %bitcast.60, f32[16384,4096]{1,0} %convert.22), lhs_contracting_dims={1}, rhs_contracting_dims={0} + } + + ENTRY main { + %p0 = f32[1,256,4,4096] parameter(0) + %p1 = bf16[1,4,4096,4096] parameter(1) + ROOT %gemm_fusion_r = f32[256,4096] fusion(%p0, %p1), kind=kCustom, + calls=gemm_fusion_r_computation, + backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"fusion_backend_config":{"kind":"__triton_gemm"},"force_earliest_schedule":false} + } +)"; + + std::unique_ptr module = + ParseAndReturnVerifiedModule(kHlo).value(); + const se::CudaComputeCapability compute_capability{ + se::CudaComputeCapability::AMPERE, /*minor=*/0}; + + TF_ASSERT_OK_AND_ASSIGN( + const std::vector configs, + GetPossibleMatmulAutotuneConfigs( + *Cast( + module->entry_computation()->root_instruction()), + compute_capability, GetToolkitVersion(), GetDebugOptionsForTest())); + EXPECT_TRUE(std::any_of( + configs.begin(), configs.end(), + [](const GemmFusionAutotunerImpl::BackendConfig& config) { + return std::holds_alternative< + GemmFusionAutotunerImpl::CustomKernelFusionConfig>(config); + })); +} + +TEST_F(GemmFusionAutotunerTest, + GeneratesConfigForUpcastGemmWithPrologueAndEpilogue) { + const std::string kHlo = R"( + HloModule module + + %gemm_fusion_r_computation (parameter_0.1: f32[1,256,4,4096], parameter_1.1: bf16[1,4,4096,4096]) -> bf16[1048576] { + %parameter_0.1 = f32[1,256,4,4096]{3,2,1,0} parameter(0) + %bitcast.60 = f32[256,16384]{1,0} bitcast(f32[1,256,4,4096]{3,2,1,0} %parameter_0.1) + %parameter_1.1 = bf16[1,4,4096,4096]{3,2,1,0} parameter(1) + %bitcast.61 = bf16[16384,4096]{1,0} bitcast(bf16[1,4,4096,4096]{3,2,1,0} %parameter_1.1) + %convert.22 = f32[16384,4096]{1,0} convert(bf16[16384,4096]{1,0} %bitcast.61) + %dot.5 = f32[256,4096]{1,0} dot(f32[256,16384]{1,0} %bitcast.60, f32[16384,4096]{1,0} %convert.22), lhs_contracting_dims={1}, rhs_contracting_dims={0} + %convert.23 = bf16[256,4096]{1,0} convert(f32[256,4096]{1,0} %dot.5) + %bitcast.62 = bf16[1,256,4096]{2,1,0} bitcast(bf16[256,4096]{1,0} %convert.23) + %transpose.18 = bf16[1,4096,256]{2,1,0} transpose(bf16[1,256,4096]{2,1,0} %bitcast.62), dimensions={0,2,1} + ROOT %bitcast.63 = bf16[1048576]{0} bitcast(bf16[1,4096,256]{2,1,0} %transpose.18) + } + + ENTRY main { + %p0 = f32[1,256,4,4096] parameter(0) + %p1 = bf16[1,4,4096,4096] parameter(1) + ROOT %gemm_fusion_r = bf16[1048576] fusion(%p0, %p1), kind=kCustom, + calls=gemm_fusion_r_computation, + backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"fusion_backend_config":{"kind":"__triton_gemm"},"force_earliest_schedule":false} + } +)"; + + std::unique_ptr module = + ParseAndReturnVerifiedModule(kHlo).value(); + const se::CudaComputeCapability compute_capability{ + se::CudaComputeCapability::AMPERE, /*minor=*/0}; + + TF_ASSERT_OK_AND_ASSIGN( + const std::vector configs, + GetPossibleMatmulAutotuneConfigs( + *Cast( + module->entry_computation()->root_instruction()), + compute_capability, GetToolkitVersion(), GetDebugOptionsForTest())); + EXPECT_TRUE(std::any_of( + configs.begin(), configs.end(), + [](const GemmFusionAutotunerImpl::BackendConfig& config) { + return std::holds_alternative< + GemmFusionAutotunerImpl::CustomKernelFusionConfig>(config); + })); +} + +TEST_F(GemmFusionAutotunerTest, RewritesGemmFusionToCustomKernelFusion) { + const std::string kHlo = R"( + HloModule module, entry_computation_layout={(bf16[1024,1024]{1,0}, bf16[1024,1024]{1,0})->f32[1024,1024]{1,0}} + + %gemm_fusion_r_computation { + %parameter_0 = bf16[1024,1024]{1,0} parameter(0) + %convert.2 = f32[1024,1024]{1,0} convert(%parameter_0) + %parameter_1 = bf16[1024,1024]{1,0} parameter(1) + %convert.3 = f32[1024,1024]{1,0} convert(%parameter_1) + ROOT %r.1 = f32[1024,1024]{1,0} dot(%convert.2, %convert.3), lhs_contracting_dims={1}, rhs_contracting_dims={0} + } + + ENTRY main { + %p0 = bf16[1024,1024]{1,0} parameter(0) + %p1 = bf16[1024,1024]{1,0} parameter(1) + ROOT %gemm_fusion_r = f32[1024,1024]{1,0} fusion(%p0, %p1), kind=kCustom, calls=gemm_fusion_r_computation, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"fusion_backend_config":{"kind":"__triton_gemm"},"force_earliest_schedule":false} + } +)"; + + std::unique_ptr module = + ParseAndReturnVerifiedModule(kHlo).value(); + + DebugOptions opts; + AutotuneConfig autotune_config{ + DeviceConfig{backend().default_stream_executor(), + backend().memory_allocator()}, + opts}; + AutotuneCacheKey cache_key(autotune_config.GetModelStr(), + *module->entry_computation()->root_instruction()); + TF_ASSERT_OK_AND_ASSIGN(AutotuneResults autotune_results_override, + ParseTextProto(R"pb( + version: 3 + results { + device: "..." + hlo: "..." + result { + custom_kernel_fusion { kernel_index: 1 } + run_time { nanos: 14 } + } + })pb")); + autotune_results_override.mutable_results(0)->set_device( + std::string(cache_key.GetModelStr())); + autotune_results_override.mutable_results(0)->set_hlo( + std::string(cache_key.GetHlo())); + + GemmFusionAutotunerRewriterVisitor visitor(autotune_config); + + CHECK_OK(AutotunerUtil::LoadAutotuneResults(autotune_results_override)); + visitor.RunOnModule(module.get(), {}).value(); + std::string pattern = R"( + CHECK: ROOT %cutlass_gemm_with_upcast + CHECK-SAME: fusion + CHECK-SAME: kind=kCustom + CHECK-SAME: "kernel_index":1 + )"; + TF_ASSERT_OK_AND_ASSIGN(bool file_check_matches, + RunFileCheck(module->ToString(), pattern)); + EXPECT_TRUE(file_check_matches); +} + } // namespace } // namespace gpu } // namespace xla From 8dfe83ae768774a9622b42aa69b3b15a3e492f71 Mon Sep 17 00:00:00 2001 From: Bixia Zheng Date: Tue, 24 Sep 2024 09:33:04 -0700 Subject: [PATCH 189/483] #sdy Support OpShardingRule in SDY round trip export. PiperOrigin-RevId: 678283299 --- .../xla/xla/service/spmd/shardy/constants.h | 13 +++++++++++-- .../shardy/sdy_round_trip/export_shardings.cc | 17 +++++++++++------ .../shardy/sdy_round_trip/export_shardings.h | 9 ++++++--- .../test/sdy_round_trip_export_pipeline.mlir | 19 +++++++++++++++++++ 4 files changed, 47 insertions(+), 11 deletions(-) diff --git a/third_party/xla/xla/service/spmd/shardy/constants.h b/third_party/xla/xla/service/spmd/shardy/constants.h index 020c6a6c893fb1..f9bbb8a0da789e 100644 --- a/third_party/xla/xla/service/spmd/shardy/constants.h +++ b/third_party/xla/xla/service/spmd/shardy/constants.h @@ -38,14 +38,23 @@ inline constexpr llvm::StringRef kSPMDShardToFullShapeCallTargetName = // The attribute name for backend config. inline constexpr llvm::StringRef kXlaBackendConfigAttr = "backend_config"; -// Attribute name for temporarily storing the Shardonnay sharding during HLO -// round-trip. It cannot match the name kShardingAttr ("sdy.sharding"), as +// Attribute name for temporarily storing the Shardy sharding during HLO +// round-trip. It cannot match the name `kShardingAttr` ("sdy.sharding"), as // during round-trip, going from HLO to MHLO, the code removes attributes // in the `frontend_attributes` field, making them top level. And Shardonnay // verification expects `kShardingAttr` to be of type // TensorShardingAttr/TensorShardingPerValueAttr - not a StringAttr. inline constexpr llvm::StringRef kShardingRoundTripAttr = "xla.sdy.sharding"; +// Attribute name for temporarily storing the Shardy sharding rule during HLO +// round-trip. It cannot match the name `kShardingRuleAttr` +// ("sdy.sharding_rule"), as during round-trip, going from HLO to MHLO, the code +// removes attributes in the `frontend_attributes` field, making them top level. +// And Shardy verification expects `kShardingRuleAttr` to be of type +// OpShardingRuleAttr - not a StringAttr. +inline constexpr llvm::StringRef kShardingRuleRoundTripAttr = + "xla.sdy.sharding_rule"; + // Attribute name for temporarily storing the Shardonnay meshes during HLO // round-trip. inline constexpr llvm::StringRef kMeshesRoundTripAttr = "xla.sdy.meshes"; diff --git a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/export_shardings.cc b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/export_shardings.cc index a7177f334c077d..9edef72c4126ee 100644 --- a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/export_shardings.cc +++ b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/export_shardings.cc @@ -70,15 +70,16 @@ using ::mlir::func::FuncOp; using ::mlir::mhlo::CustomCallOp; using ::mlir::sdy::kShardingAttr; +using ::mlir::sdy::kShardingRuleAttr; using ::mlir::sdy::MeshOp; +using ::mlir::sdy::OpShardingRuleAttr; using ::mlir::sdy::TensorShardingAttr; using ::mlir::sdy::TensorShardingPerValueAttr; // Saves `shardingPerValueAttr` including any existing `frontendAttributes` on // the `op`. -void saveOpShardingPerValueAttr(Operation* op, - TensorShardingPerValueAttr shardingPerValueAttr, - OpBuilder& builder) { +void saveOpShardingPerValueAttr( + Operation* op, TensorShardingPerValueAttr shardingPerValueAttr) { addFrontendAttribute(op, kShardingRoundTripAttr, shardingPerValueAttr); } @@ -115,8 +116,7 @@ LogicalResult exportFunc(FuncOp funcOp, OpBuilder& builder) { customCallOp.setHasSideEffect(true); saveOpShardingPerValueAttr( customCallOp, - TensorShardingPerValueAttr::get(customCallOp.getContext(), sharding), - builder); + TensorShardingPerValueAttr::get(customCallOp.getContext(), sharding)); returnOperand.set(customCallOp.getResult(0)); } } @@ -124,7 +124,12 @@ LogicalResult exportFunc(FuncOp funcOp, OpBuilder& builder) { funcOp.front().walk([&](Operation* op) { if (auto oldShardingPerValue = op->getAttrOfType(kShardingAttr)) { - saveOpShardingPerValueAttr(op, oldShardingPerValue, builder); + saveOpShardingPerValueAttr(op, oldShardingPerValue); + } + if (auto oldShardingRule = + op->getAttrOfType(kShardingRuleAttr)) { + addFrontendAttribute(op, kShardingRuleRoundTripAttr, oldShardingRule); + op->removeAttr(kShardingRuleAttr); } }); diff --git a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/export_shardings.h b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/export_shardings.h index 4b8ce6ab737419..ac0a4d2a80f225 100644 --- a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/export_shardings.h +++ b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/export_shardings.h @@ -26,9 +26,12 @@ namespace sdy { // Registers the xla-sdy-round-trip-export-shardings pass. void registerSdyRoundTripExportShardingsPass(); -// Creates the pass that converts the shardings from `kShardingAttr` to -// `kShardingRoundTripAttr` in the HLO frontend attributes and saves the -// mesh symbols as `kMeshesRoundTripAttr` in the module frontend attributes. +// Creates the pass to convert SDY attributes to frontend attributes: +// +// . Converts shardings from `kShardingAttr` to `kShardingRoundTripAttr` +// . Converts sharding rules from `kShardingRuleAttr` to +// `kShardingRuleRoundTripAttr` +// . Saves the mesh symbols as `kMeshesRoundTripAttr` // // NOTE: The `kShardingAttr`s are not removed from the ops. They are kept around // because part of the `SdyRoundTripExportPipeline` it also converts the diff --git a/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_export_pipeline.mlir b/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_export_pipeline.mlir index 959cbb9a4d4f4c..6c92341ce5eae5 100644 --- a/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_export_pipeline.mlir +++ b/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_export_pipeline.mlir @@ -96,3 +96,22 @@ func.func @constant() -> tensor { %0 = sdy.constant dense<0> : tensor return %0 : tensor } + +// CHECK-LABEL: func @op_sharding_rule +func.func @op_sharding_rule(%arg0: tensor<8x2xf32>, %arg1: tensor<8x2xf32>) -> tensor<8x2xf64> { + // CHECK: stablehlo.custom_call @foo(%arg0, %arg1) {mhlo.frontend_attributes = {xla.sdy.sharding_rule = "#sdy.op_sharding_rule<([i, j], [i, j])->([i, j]) {i=8, j=2}>"}} + %0 = stablehlo.custom_call @foo(%arg0, %arg1) {sdy.sharding_rule = #sdy.op_sharding_rule<([i, j], [i, j])->([i, j]) {i=8, j=2}>} : (tensor<8x2xf32>, tensor<8x2xf32>) -> tensor<8x2xf64> + return %0 : tensor<8x2xf64> +} + +// CHECK-LABEL: func @sharding_and_op_sharding_rule +func.func @sharding_and_op_sharding_rule(%arg0: tensor<8x2xf32>, %arg1: tensor<8x2xf32>) -> tensor<8x2xf64> { + // CHECK: stablehlo.custom_call @foo(%arg0, %arg1) {mhlo.frontend_attributes = + // CHECK-SAME: {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh_2, [{\\\22x\\\22}, {}]>]>" + // CHECK-SAME: xla.sdy.sharding_rule = "#sdy.op_sharding_rule<([i, j], [i, j])->([i, j]) {i=8, j=2}>"} + %0 = stablehlo.custom_call @foo(%arg0, %arg1) + {sdy.sharding_rule = #sdy.op_sharding_rule<([i, j], [i, j])->([i, j]) {i=8, j=2}>, + sdy.sharding = #sdy.sharding_per_value<[<@mesh_2, [{"x"}, {}]>]>} + : (tensor<8x2xf32>, tensor<8x2xf32>) -> tensor<8x2xf64> + return %0 : tensor<8x2xf64> +} From 570f19a25a235e597602a7cc30e68e4d0941056e Mon Sep 17 00:00:00 2001 From: Ilya Tikhonovskiy Date: Tue, 24 Sep 2024 09:34:32 -0700 Subject: [PATCH 190/483] [XLA:GPU] Add support for the explicit algorithm=BF16_BF16_F32 in Triton when the input is F32. It is the case that was not covered when BF16_BF16_F32_X3 was introduced. We enable F32 input in algorithm_util.cc. But the default behavior led to F32_F32_F32 triton that was slower than the cuBLAS with ~21ms. I.e. it was not faster despite lower precision and at the same time the fusion was forbidden due to "Pure matmul". With the explicit truncation the F32 input to BF16 in the triton emitter we could reach the latency ~4ms which is way better than F32_F32_F32 (~21ms), and BF16_BF16_F32_X3 (~13ms), and BF16_BF16_F32_X6 (~18ms), but it is still slower that the clear dot for BF16 arguments (1.53ms). PiperOrigin-RevId: 678283878 --- third_party/xla/xla/service/algorithm_util.cc | 12 ++++-- .../service/gpu/dot_algorithm_support_test.cc | 19 ++++++--- .../xla/xla/service/gpu/fusions/triton/BUILD | 1 + .../fusions/triton/triton_fusion_emitter.cc | 12 ++++++ ...riton_fusion_emitter_device_legacy_test.cc | 40 +++++++++++++++++++ .../fusions/triton/triton_support_legacy.cc | 6 ++- .../xla/service/gpu/transforms/gemm_fusion.cc | 7 ++-- 7 files changed, 83 insertions(+), 14 deletions(-) diff --git a/third_party/xla/xla/service/algorithm_util.cc b/third_party/xla/xla/service/algorithm_util.cc index 4e037056730909..8eec061e84e13a 100644 --- a/third_party/xla/xla/service/algorithm_util.cc +++ b/third_party/xla/xla/service/algorithm_util.cc @@ -174,9 +174,15 @@ bool IsSupportedDotAlgorithmOnGpu( return input_storage_type == F16 && (output_storage_type == F16 || output_storage_type == F32); case PrecisionConfig::ALG_DOT_BF16_BF16_F32: - return (is_cuda_ge_ampere || is_rocm_mi100_and_above) && - input_storage_type == BF16 && - (output_storage_type == BF16 || output_storage_type == F32); + if (!is_cuda_ge_ampere && !is_rocm_mi100_and_above) return false; + switch (input_storage_type) { + case BF16: + return output_storage_type == BF16 || output_storage_type == F32; + case F32: + return output_storage_type == F32; + default: + return false; + } case PrecisionConfig::ALG_DOT_BF16_BF16_F32_X3: case PrecisionConfig::ALG_DOT_BF16_BF16_F32_X6: return (is_cuda_ge_ampere || is_rocm_mi100_and_above) && diff --git a/third_party/xla/xla/service/gpu/dot_algorithm_support_test.cc b/third_party/xla/xla/service/gpu/dot_algorithm_support_test.cc index bb17f1a6ed58f1..f731049a8f6f6f 100644 --- a/third_party/xla/xla/service/gpu/dot_algorithm_support_test.cc +++ b/third_party/xla/xla/service/gpu/dot_algorithm_support_test.cc @@ -173,10 +173,10 @@ TEST_P(DotAlgorithmSupportTest, AlgorithmIsSupportedFromCudaCapability) { if (params.backend_restriction == BackendRestriction::kTritonOnly) { MatchOptimizedHlo(hlo_text, R"( - ;CHECK: ENTRY - ;CHECK: ROOT - ;CHECK-SAME: kCustom - ;CHECK-SAME: "triton_gemm_config" + ;CHECK: ENTRY + ;CHECK: ROOT + ;CHECK-SAME: kCustom + ;CHECK-SAME: "triton_gemm_config" )"); } } else { @@ -215,7 +215,7 @@ INSTANTIATE_TEST_SUITE_P(DotF16F16F32Tests, DotAlgorithmSupportTest, Values(Sizes{32, 32}, Sizes{16, 2})), TestParamsToString); -INSTANTIATE_TEST_SUITE_P(DotBf16Bf16F32Tests, DotAlgorithmSupportTest, +INSTANTIATE_TEST_SUITE_P(DotBF16ForBf16Bf16F32Tests, DotAlgorithmSupportTest, Combine(Values(PC::ALG_DOT_BF16_BF16_F32), Values(BF16), Values(BF16, F32), Values(CC(8, 0)), @@ -224,8 +224,15 @@ INSTANTIATE_TEST_SUITE_P(DotBf16Bf16F32Tests, DotAlgorithmSupportTest, Values(Sizes{32, 32}, Sizes{16, 2})), TestParamsToString); -INSTANTIATE_TEST_SUITE_P(DotBf16Bf16F32XnTests, DotAlgorithmSupportTest, +INSTANTIATE_TEST_SUITE_P(DotF32ForBf16Bf16F32Tests, DotAlgorithmSupportTest, + Combine(Values(PC::ALG_DOT_BF16_BF16_F32), Values(F32), + Values(F32), Values(CC(8, 0)), + Values(SemanticVersion{6, 0, 0}), + Values(BackendRestriction::kTritonOnly), + Values(Sizes{32, 32}, Sizes{16, 2})), + TestParamsToString); +INSTANTIATE_TEST_SUITE_P(DotBf16Bf16F32XnTests, DotAlgorithmSupportTest, Combine(Values(PC::ALG_DOT_BF16_BF16_F32_X3, PC::ALG_DOT_BF16_BF16_F32_X6), Values(F32), Values(F32), Values(CC(8, 0)), diff --git a/third_party/xla/xla/service/gpu/fusions/triton/BUILD b/third_party/xla/xla/service/gpu/fusions/triton/BUILD index 029032b3c76268..7532b1a4e1c7b9 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/BUILD +++ b/third_party/xla/xla/service/gpu/fusions/triton/BUILD @@ -433,6 +433,7 @@ cc_library( "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "@local_tsl//tsl/platform:tensor_float_32_utils", ], ) diff --git a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter.cc b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter.cc index e9974d3ce1584f..4a4245c1590670 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter.cc +++ b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter.cc @@ -2658,6 +2658,18 @@ absl::Status EmitMatMul(mlir::OpBuilder builder, IsTf32Allowed(dot_instr) && !is_unsupported_bitwidth ? mt::InputPrecision::TF32 : mt::InputPrecision::IEEE; + + // Cast F32 inputs to BF16 if the algorithm is BF16_BF16_F32. + if (dot_instr->precision_config().algorithm() == + PrecisionConfig::ALG_DOT_BF16_BF16_F32) { + if (dot_instr->operand(0)->shape().element_type() == F32) { + dot_input_lhs = Cast(b, dot_input_lhs, b.getBF16Type()); + } + if (dot_instr->operand(1)->shape().element_type() == F32) { + dot_input_rhs = Cast(b, dot_input_rhs, b.getBF16Type()); + } + } + // For fp8 matmuls, disable accumulator promotion, as it's what cublas // does. It may make sense to enable frequent accumulator promotion at // higher matmul precisions set in the config. diff --git a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_legacy_test.cc b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_legacy_test.cc index 7ac574504294a1..0ea4702d7445af 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_legacy_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_legacy_test.cc @@ -5134,6 +5134,46 @@ CHECK-NOT: mma.sync.aligned.{{.*}}.row.col.f32.tf32.tf32.f32 /*arel=*/1e-5})); } +class TritonBF16BF16F32GemmTest : public TritonTest { + public: + DebugOptions GetDebugOptionsForTest() override { + DebugOptions debug_options = TritonTest::GetDebugOptionsForTest(); + // Do not fall back to cuBLAS, we are testing Triton. + debug_options.set_xla_gpu_cublas_fallback(false); + // Do not autotune split-k by default, since this prevents deterministically + // matching the optimized HLO. + debug_options.set_xla_gpu_enable_split_k_autotuning(false); + return debug_options; + } + + protected: + void SetUp() override { + if (!SupportsBF16(GpuComputeComp())) { + GTEST_SKIP() << "BF16 not supported."; + } + } +}; + +TEST_F(TritonBF16BF16F32GemmTest, WorkWithF32InputAndAlgorithm_BF16_BF16_F32) { + const std::string kHloText = R"( + HloModule t + + ENTRY main { + lhs = f32[32,64]{1,0} parameter(0) + rhs = f32[64,16]{1,0} parameter(1) + ROOT dot = f32[32,16]{1,0} dot(lhs, rhs), + algorithm=dot_bf16_bf16_f32, + lhs_contracting_dims={1}, + rhs_contracting_dims={0} + } + )"; + const std::string pattern = + R"(CHECK: "kind":"__triton_gemm","triton_gemm_config")"; + TF_ASSERT_OK_AND_ASSIGN(auto module, GetOptimizedModule(kHloText)); + TF_ASSERT_OK_AND_ASSIGN(auto ok, RunFileCheck(module->ToString(), pattern)); + EXPECT_TRUE(ok); +} + // This test could be modified to allow TF32 once this bug is fixed. // TODO(b/320659359) Allow TF32 for 8-bit or less types with F32. TEST_F(TritonTest, NoTF32For8BitOrLessWithF32) { diff --git a/third_party/xla/xla/service/gpu/fusions/triton/triton_support_legacy.cc b/third_party/xla/xla/service/gpu/fusions/triton/triton_support_legacy.cc index 8280accb99e10f..802fed51f4d200 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/triton_support_legacy.cc +++ b/third_party/xla/xla/service/gpu/fusions/triton/triton_support_legacy.cc @@ -20,6 +20,7 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/log/check.h" +#include "absl/strings/str_format.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" @@ -274,8 +275,9 @@ CodegenDecision CanTritonHandleGEMM( } else { if (!IsDotAlgorithmSupportedByTriton(dot.precision_config().algorithm(), gpu_version)) { - return CodegenDecision::Forbid( - "Unsupported algorithm on the current device(s)."); + return CodegenDecision::Forbid(absl::StrFormat( + "Unsupported algorithm on the current device(s): %s", + PrecisionConfig::Algorithm_Name(dot.precision_config().algorithm()))); } } diff --git a/third_party/xla/xla/service/gpu/transforms/gemm_fusion.cc b/third_party/xla/xla/service/gpu/transforms/gemm_fusion.cc index 1f6d41698aeaa4..1934f59c48e24f 100644 --- a/third_party/xla/xla/service/gpu/transforms/gemm_fusion.cc +++ b/third_party/xla/xla/service/gpu/transforms/gemm_fusion.cc @@ -740,6 +740,7 @@ absl::StatusOr CreateDotFusion( dot.precision_config().algorithm(); if (algorithm == PrecisionConfig::ALG_DOT_BF16_BF16_F32_X6 || algorithm == PrecisionConfig::ALG_DOT_BF16_BF16_F32_X3 || + algorithm == PrecisionConfig::ALG_DOT_BF16_BF16_F32 || dot.GetModule()->config().debug_options().xla_gpu_triton_gemm_any() || dot.sparse_operands()) { return Decision::Allow(); @@ -757,9 +758,9 @@ absl::StatusOr CreateDotFusion( } return absl::OkStatus(); }); - if (is_pure_matmul) { - return Decision::NotProfitable("Pure Matmul"); - } + + if (is_pure_matmul) return Decision::NotProfitable("Pure Matmul"); + return Decision::Allow(); } From 37e09cbb6b5fd3b142282fdb6e294350738b4206 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 24 Sep 2024 09:47:48 -0700 Subject: [PATCH 191/483] Pass `exec_properties` flag from `tflite_jni_binary` to `native.cc_binary` PiperOrigin-RevId: 678288567 --- tensorflow/lite/build_def.bzl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tensorflow/lite/build_def.bzl b/tensorflow/lite/build_def.bzl index 8ebe89096a8128..eb3ee2d65ebce5 100644 --- a/tensorflow/lite/build_def.bzl +++ b/tensorflow/lite/build_def.bzl @@ -211,7 +211,8 @@ def tflite_jni_binary( tags = [], srcs = [], visibility = None, # 'None' means use the default visibility. - local_defines = []): + local_defines = [], + exec_properties = {}): """Builds a jni binary for TFLite.""" linkopts = linkopts + select({ clean_dep("//tensorflow:macos"): [ @@ -239,6 +240,7 @@ def tflite_jni_binary( testonly = testonly, visibility = visibility, local_defines = local_defines, + exec_properties = exec_properties, ) def tflite_cc_shared_object( From cd5b7b79a457c488d83a10a5be8446dc5fba3398 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 24 Sep 2024 10:37:57 -0700 Subject: [PATCH 192/483] Change use_bfloat16_ to test_type_ in client_library_test_base. This is in preparation for adding support for multiple primitive types in the client library tests. PiperOrigin-RevId: 678308803 --- .../xla/xla/tests/client_library_test_base.cc | 20 +++---- .../xla/xla/tests/client_library_test_base.h | 58 +++++++++---------- 2 files changed, 39 insertions(+), 39 deletions(-) diff --git a/third_party/xla/xla/tests/client_library_test_base.cc b/third_party/xla/xla/tests/client_library_test_base.cc index 743db05f93f73b..71b6f9bc175a80 100644 --- a/third_party/xla/xla/tests/client_library_test_base.cc +++ b/third_party/xla/xla/tests/client_library_test_base.cc @@ -331,7 +331,7 @@ absl::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( const Literal* expected_ptr = &expected; Literal converted_expected; Shape layout_shape; - if (use_bfloat16_) { + if (use_bfloat16()) { converted_expected = LiteralUtil::ConvertF32ToBF16(expected); expected_ptr = &converted_expected; if (shape_with_layout != nullptr) { @@ -389,7 +389,7 @@ absl::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( const Literal* expected_ptr = &expected; Literal converted_expected; Shape layout_shape; - if (use_bfloat16_) { + if (use_bfloat16()) { converted_expected = LiteralUtil::ConvertF32ToBF16(expected); expected_ptr = &converted_expected; if (shape_with_layout != nullptr) { @@ -537,9 +537,9 @@ ClientLibraryTestBase::ComputeValueAndReference( XlaComputation ClientLibraryTestBase::CreateScalarRelu() { XlaBuilder builder("relu"); - auto shape = ShapeUtil::MakeShape(use_bfloat16_ ? BF16 : F32, {}); + auto shape = ShapeUtil::MakeShape(use_bfloat16() ? BF16 : F32, {}); auto z_value = Parameter(&builder, 0, shape, "z_value"); - auto zero = use_bfloat16_ + auto zero = use_bfloat16() ? ConstantR0(&builder, static_cast(0.0f)) : ConstantR0(&builder, 0.0f); Max(z_value, zero); @@ -550,7 +550,7 @@ XlaComputation ClientLibraryTestBase::CreateScalarRelu() { XlaComputation ClientLibraryTestBase::CreateScalarMax() { XlaBuilder builder("max"); - auto shape = ShapeUtil::MakeShape(use_bfloat16_ ? BF16 : F32, {}); + auto shape = ShapeUtil::MakeShape(use_bfloat16() ? BF16 : F32, {}); auto x = Parameter(&builder, 0, shape, "x"); auto y = Parameter(&builder, 1, shape, "y"); Max(x, y); @@ -561,10 +561,10 @@ XlaComputation ClientLibraryTestBase::CreateScalarMax() { XlaComputation ClientLibraryTestBase::CreateScalarReluSensitivity() { XlaBuilder builder("relu_sensitivity"); - auto shape = ShapeUtil::MakeShape(use_bfloat16_ ? BF16 : F32, {}); + auto shape = ShapeUtil::MakeShape(use_bfloat16() ? BF16 : F32, {}); auto activation = Parameter(&builder, 0, shape, "activation"); auto backprop = Parameter(&builder, 1, shape, "backprop"); - auto zero = use_bfloat16_ + auto zero = use_bfloat16() ? ConstantR0(&builder, static_cast(0.0f)) : ConstantR0(&builder, 0.0f); auto activation_gtz = Gt(activation, zero); @@ -610,7 +610,7 @@ XlaOp ClientLibraryTestBase::AddParam(const Literal& argument, XlaOp ClientLibraryTestBase::CreateConstantFromLiteral(const Literal& literal, XlaBuilder* builder) { - return ConstantLiteral(builder, use_bfloat16_ + return ConstantLiteral(builder, use_bfloat16() ? LiteralUtil::ConvertF32ToBF16(literal) : LiteralSlice(literal)); } @@ -624,7 +624,7 @@ ClientLibraryTestBase::CreateParameterAndTransferLiteral( } Shape ClientLibraryTestBase::MaybeConvertShapeToBfloat16(const Shape& shape) { - if (!use_bfloat16_) { + if (!use_bfloat16()) { return shape; } Shape new_shape = shape; @@ -639,7 +639,7 @@ Shape ClientLibraryTestBase::MaybeConvertShapeToBfloat16(const Shape& shape) { Literal ClientLibraryTestBase::MaybeConvertLiteralToBfloat16( const Literal& literal) { - if (use_bfloat16_) { + if (use_bfloat16()) { return LiteralUtil::ConvertF32ToBF16(literal); } return literal.Clone(); diff --git a/third_party/xla/xla/tests/client_library_test_base.h b/third_party/xla/xla/tests/client_library_test_base.h index 9185a31d7f6211..8610dd6e5ae3cb 100644 --- a/third_party/xla/xla/tests/client_library_test_base.h +++ b/third_party/xla/xla/tests/client_library_test_base.h @@ -276,8 +276,8 @@ class ClientLibraryTestBase : public ::testing::Test { // Creates a parameter instruction, transfers the literal for the parameter to // server, then stores into "data_handle" the global handle for that - // parameter. When the use_bfloat16 flag is set but the literal has F32 - // elements, the literal will be converted to BF16 before being transferred. + // parameter. When the test_type is bfloat16 but the literal has F32 elements, + // the literal will be converted to BF16 before being transferred. absl::StatusOr> CreateParameterAndTransferLiteral( int64_t parameter_number, const Literal& literal, const std::string& name, XlaBuilder* builder, XlaOp* data_handle); @@ -302,15 +302,13 @@ class ClientLibraryTestBase : public ::testing::Test { return AddParam(LiteralUtil::CreateFromArray(argument), builder); } - // Creates a constant instruction with the given literal. When the - // use_bfloat16 flag is set but the literal has F32 elements, the elements - // will be converted to BF16s. + // Creates a constant instruction with the given literal. When the test_type + // is bfloat16 but the literal has F32 elements, the literal will be converted + // to BF16 before being transferred. XlaOp CreateConstantFromLiteral(const Literal& literal, XlaBuilder* builder); - // Creates a constant instruction with the given array. When the use_bfloat16 - // flag is set but the array has float elements, the elements will be - // converted to bfloat16s. - + // Creates a constant instruction with the given array. When the test_type is + // bfloat16, the elements will be converted to bfloat16s. template XlaOp CreateConstantFromArray(const Array& array, XlaBuilder* builder) { @@ -331,7 +329,7 @@ class ClientLibraryTestBase : public ::testing::Test { // "parameter_number" is the parameter number. // "name" is the name of the parameter instruction. // - // When the use_bfloat16 flag is set but NativeT is float, the data will be + // When the test_type is bfloat16 but NativeT is float, the data will be // converted to bfloat16. template std::unique_ptr CreateR0Parameter(NativeT value, @@ -346,7 +344,7 @@ class ClientLibraryTestBase : public ::testing::Test { // "parameter_number" is the parameter number. // "name" is the name of the parameter instruction. // - // When the use_bfloat16 flag is set but NativeT is float, the data will be + // When the test_type is bfloat16 but NativeT is float, the data will be // converted to bfloat16. template std::unique_ptr CreateR1Parameter( @@ -360,7 +358,7 @@ class ClientLibraryTestBase : public ::testing::Test { // "parameter_number" is the parameter number. // "name" is the name of the parameter instruction. // - // When the use_bfloat16 flag is set but NativeT is float, the data will be + // When the test_type is bfloat16 but NativeT is float, the data will be // converted to bfloat16. template std::unique_ptr CreateR2Parameter( @@ -374,7 +372,7 @@ class ClientLibraryTestBase : public ::testing::Test { // "parameter_number" is the parameter number. // "name" is the name of the parameter instruction. // - // When the use_bfloat16 flag is set but NativeT is float, the data will be + // When the test_type is bfloat16 but NativeT is float, the data will be // converted to bfloat16. template std::unique_ptr CreateR3Parameter( @@ -388,7 +386,7 @@ class ClientLibraryTestBase : public ::testing::Test { // "parameter_number" is the parameter number. // "name" is the name of the parameter instruction. // - // When the use_bfloat16 flag is set but NativeT is float, the data will be + // When the test_type is bfloat16 but NativeT is float, the data will be // converted to bfloat16. template std::unique_ptr CreateR4Parameter( @@ -402,13 +400,16 @@ class ClientLibraryTestBase : public ::testing::Test { XlaBuilder* builder, XlaOp* data_handle); - // Getter and setter for the use_bfloat16 flag, which indicates whether to run + // TODO(ralphnathan): These will eventually be removed. Please have new tests + // support multiple primitive types, not just BF16. + // Getter and setter for the test_type flag, which indicates whether to run // tests with all float-type input/output converted to bfloat16. - bool use_bfloat16() const { return use_bfloat16_; } - void set_use_bfloat16(bool value) { use_bfloat16_ = value; } + bool use_bfloat16() const { return test_type_ == BF16; } + void set_use_bfloat16(bool value) { test_type_ = value ? BF16 : F32; } - // The float type used in this test, BF16 or F32 according to use_bfloat16. - PrimitiveType FloatType() const { return use_bfloat16_ ? BF16 : F32; } + // The float type used in this test. + PrimitiveType FloatType() const { return test_type_; } + void set_float_type(PrimitiveType type) { test_type_ = type; } // Executes the computation and calculates the expected reference value using // the reference client. Returns two literals in the order of (expected, @@ -416,7 +417,7 @@ class ClientLibraryTestBase : public ::testing::Test { absl::StatusOr> ComputeValueAndReference( XlaBuilder* builder, absl::Span arguments); - // Converts an f32 literal to bf16 if use_bfloat16_ is true. + // Converts an f32 literal to bf16 if test_type is BF16. Literal MaybeConvertLiteralToBfloat16(const Literal& literal); LocalClient* client_; @@ -441,9 +442,8 @@ class ClientLibraryTestBase : public ::testing::Test { // Converts an f32 shape to bf16 if use_bfloat16_ is true. Shape MaybeConvertShapeToBfloat16(const Shape& shape); - // Whether to run tests with all float-type input/output converted to - // bfloat16. - bool use_bfloat16_ = false; + // Type to use when running tests. + PrimitiveType test_type_ = F32; // Arguments to be passed to the computation when it runs. std::vector arguments_; @@ -584,7 +584,7 @@ std::unique_ptr ClientLibraryTestBase::CreateR0Parameter( NativeT value, int64_t parameter_number, const std::string& name, XlaBuilder* builder, XlaOp* data_handle) { Literal literal = LiteralUtil::CreateR0(value); - if (use_bfloat16_ && literal.shape().element_type() == F32) { + if (use_bfloat16() && literal.shape().element_type() == F32) { literal = LiteralUtil::ConvertF32ToBF16(literal); } std::unique_ptr data = client_->TransferToServer(literal).value(); @@ -597,7 +597,7 @@ std::unique_ptr ClientLibraryTestBase::CreateR1Parameter( absl::Span values, int64_t parameter_number, const std::string& name, XlaBuilder* builder, XlaOp* data_handle) { Literal literal = LiteralUtil::CreateR1(values); - if (use_bfloat16_ && literal.shape().element_type() == F32) { + if (use_bfloat16() && literal.shape().element_type() == F32) { literal = LiteralUtil::ConvertF32ToBF16(literal); } std::unique_ptr data = client_->TransferToServer(literal).value(); @@ -610,7 +610,7 @@ std::unique_ptr ClientLibraryTestBase::CreateR2Parameter( const Array2D& array_2d, int64_t parameter_number, const std::string& name, XlaBuilder* builder, XlaOp* data_handle) { Literal literal = LiteralUtil::CreateR2FromArray2D(array_2d); - if (use_bfloat16_ && literal.shape().element_type() == F32) { + if (use_bfloat16() && literal.shape().element_type() == F32) { literal = LiteralUtil::ConvertF32ToBF16(literal); } std::unique_ptr data = client_->TransferToServer(literal).value(); @@ -623,7 +623,7 @@ std::unique_ptr ClientLibraryTestBase::CreateR3Parameter( const Array3D& array_3d, int64_t parameter_number, const std::string& name, XlaBuilder* builder, XlaOp* data_handle) { Literal literal = LiteralUtil::CreateR3FromArray3D(array_3d); - if (use_bfloat16_ && literal.shape().element_type() == F32) { + if (use_bfloat16() && literal.shape().element_type() == F32) { literal = LiteralUtil::ConvertF32ToBF16(literal); } std::unique_ptr data = client_->TransferToServer(literal).value(); @@ -636,7 +636,7 @@ std::unique_ptr ClientLibraryTestBase::CreateR4Parameter( const Array4D& array_4d, int64_t parameter_number, const std::string& name, XlaBuilder* builder, XlaOp* data_handle) { Literal literal = LiteralUtil::CreateR4FromArray4D(array_4d); - if (use_bfloat16_ && literal.shape().element_type() == F32) { + if (use_bfloat16() && literal.shape().element_type() == F32) { literal = LiteralUtil::ConvertF32ToBF16(literal); } std::unique_ptr data = client_->TransferToServer(literal).value(); @@ -649,7 +649,7 @@ std::unique_ptr ClientLibraryTestBase::CreateParameter( const Array& array, int64_t parameter_number, const std::string& name, XlaBuilder* builder, XlaOp* data_handle) { Literal literal = LiteralUtil::CreateFromArray(array); - if (use_bfloat16_ && literal.shape().element_type() == F32) { + if (use_bfloat16() && literal.shape().element_type() == F32) { literal = LiteralUtil::ConvertF32ToBF16(literal); } std::unique_ptr data = client_->TransferToServer(literal).value(); From 0b493868d09a7b659debb7e4750e8c68458d55e6 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 24 Sep 2024 10:38:31 -0700 Subject: [PATCH 193/483] [NFC] Replace all expect_true statements using absl::Is with status code matchers. This improves debugging by showing the actual status codes in test failures. PiperOrigin-RevId: 678309006 --- .../coordination/coordination_service_test.cc | 162 ++++++++++-------- 1 file changed, 88 insertions(+), 74 deletions(-) diff --git a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_test.cc b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_test.cc index 2fa500109acc17..5a0093c9d6c972 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_test.cc +++ b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_test.cc @@ -306,14 +306,16 @@ TEST_F(CoordinateTwoTasksTest, TestStandaloneService) { ASSERT_OK(coord_service_->RecordHeartbeat(task_0_, incarnation_0_)); ASSERT_OK(coord_service_->RecordHeartbeat(task_1_, incarnation_1_)); - EXPECT_TRUE( - absl::IsInvalidArgument(coord_service_->RecordHeartbeat(task_2, 0))); + EXPECT_THAT(coord_service_->RecordHeartbeat(task_2, 0), + StatusIs(absl::StatusCode::kInvalidArgument)); // Sending heartbeat with incarnation mismatch leads to Aborted error. - EXPECT_TRUE(absl::IsAborted(coord_service_->RecordHeartbeat(task_1_, 0))); - EXPECT_TRUE(absl::IsAborted(coord_service_->RecordHeartbeat(task_1_, 0))); + EXPECT_THAT(coord_service_->RecordHeartbeat(task_1_, 0), + StatusIs(absl::StatusCode::kAborted)); + EXPECT_THAT(coord_service_->RecordHeartbeat(task_1_, 0), + StatusIs(absl::StatusCode::kAborted)); // Error is propagated to other tasks. - EXPECT_TRUE(absl::IsAborted(client_0_.GetStatus())); + EXPECT_THAT(client_0_.GetStatus(), StatusIs(absl::StatusCode::kAborted)); } TEST(CoordinationServiceTest, TestCoordinatedJobs) { @@ -379,8 +381,8 @@ TEST(CoordinationServiceTest, TestCoordinatedJobs) { // Registering the evaluator task is unexpected absl::Status status = coord_service->RegisterTask(evaluator, /*incarnation=*/0); - EXPECT_TRUE(absl::IsInvalidArgument(status)) << status; - EXPECT_TRUE(!status.message().empty()); + + EXPECT_THAT(status, StatusIs(absl::StatusCode::kInvalidArgument)); } // RegisterTask() may succeed in the service, but the agent response times out. @@ -426,8 +428,7 @@ TEST(CoordinationServiceTest, const absl::Status status = coord_service->RegisterTask(task_0, /*incarnation=*/1); - EXPECT_TRUE(absl::IsAborted(status)) << status; - EXPECT_TRUE(!status.message().empty()); + EXPECT_THAT(status, StatusIs(absl::StatusCode::kAborted)); } // TODO(b/195990880): Remove this test once server-client connection is removed. @@ -452,8 +453,7 @@ TEST(CoordinationServiceTest, RegisterTask_AlreadyInError_Fails) { const absl::Status status = coord_service->RegisterTask(task_0, /*incarnation=*/0); - EXPECT_TRUE(absl::IsAborted(status)) << status; - EXPECT_TRUE(!status.message().empty()); + EXPECT_THAT(status, StatusIs(absl::StatusCode::kAborted)); } TEST_F(CoordinateTwoTasksTest, TestTaskHeartbeatTimeout) { @@ -464,10 +464,10 @@ TEST_F(CoordinateTwoTasksTest, TestTaskHeartbeatTimeout) { // No heartbeat for a while, leader considers the task as stale. Env::Default()->SleepForMicroseconds( absl::ToInt64Microseconds(2 * kHeartbeatTimeout)); - EXPECT_TRUE(absl::IsUnavailable( - coord_service_->RecordHeartbeat(task_0_, incarnation_0_))); - EXPECT_TRUE(absl::IsUnavailable( - coord_service_->RecordHeartbeat(task_1_, incarnation_1_))); + EXPECT_THAT(coord_service_->RecordHeartbeat(task_0_, incarnation_0_), + StatusIs(absl::StatusCode::kUnavailable)); + EXPECT_THAT(coord_service_->RecordHeartbeat(task_1_, incarnation_1_), + StatusIs(absl::StatusCode::kUnavailable)); } TEST_F(CoordinateTwoTasksTest, @@ -505,10 +505,10 @@ TEST_F(CoordinateTwoTasksTest, absl::ToInt64Microseconds(2 * kHeartbeatTimeout)); // Unexpected heartbeat from unregistered tasks since service state has been // reset. - EXPECT_TRUE(absl::IsInternal( - coord_service_->RecordHeartbeat(task_0_, incarnation_0_))); - EXPECT_TRUE(absl::IsInternal( - coord_service_->RecordHeartbeat(task_1_, incarnation_1_))); + EXPECT_THAT(coord_service_->RecordHeartbeat(task_0_, incarnation_0_), + StatusIs(absl::StatusCode::kInternal)); + EXPECT_THAT(coord_service_->RecordHeartbeat(task_1_, incarnation_1_), + StatusIs(absl::StatusCode::kInternal)); } TEST_F(CoordinateTwoTasksTest, @@ -518,15 +518,14 @@ TEST_F(CoordinateTwoTasksTest, ASSERT_OK(coord_service_->RegisterTask(task_1_, incarnation_1_)); // Use notifications to guarantee the ordering of operations across threads. absl::Notification n0, n1; + absl::Status s0, s1; - // The heartbeat error below should be propagated to all tasks. - absl::StatusCode expected_error_code = absl::StatusCode::kUnavailable; coord_service_->PollForErrorAsync(task_0_, [&](const absl::Status& status) { - EXPECT_THAT(status, StatusIs(expected_error_code)); + s0 = status; n0.Notify(); }); coord_service_->PollForErrorAsync(task_1_, [&](const absl::Status& status) { - EXPECT_THAT(status, StatusIs(expected_error_code)); + s1 = status; n1.Notify(); }); @@ -534,10 +533,13 @@ TEST_F(CoordinateTwoTasksTest, // the error to the tasks. Env::Default()->SleepForMicroseconds( absl::ToInt64Microseconds(2 * kHeartbeatTimeout)); - // Make sure the StatusCallbacks are called. n0.WaitForNotification(); n1.WaitForNotification(); + + // Heartbeat errors are propagated to everyone. + EXPECT_THAT(s0, StatusIs(absl::StatusCode::kUnavailable)); + EXPECT_THAT(s1, StatusIs(absl::StatusCode::kUnavailable)); } TEST_F(CoordinateTwoTasksTest, @@ -546,16 +548,15 @@ TEST_F(CoordinateTwoTasksTest, ASSERT_OK(coord_service_->RegisterTask(task_0_, incarnation_0_)); ASSERT_OK(coord_service_->RegisterTask(task_1_, incarnation_1_)); // Use notifications to guarantee the ordering of operations across threads. + absl::Status s0, s1; absl::Notification n0, n1; - // The heartbeat error from `task_1_` below should be propagated to all tasks. - absl::StatusCode expected_error_code = absl::StatusCode::kUnavailable; coord_service_->PollForErrorAsync(task_0_, [&](const absl::Status& status) { - EXPECT_THAT(status, StatusIs(expected_error_code, HasSubstr("task:1"))); + s0 = status; n0.Notify(); }); coord_service_->PollForErrorAsync(task_1_, [&](const absl::Status& status) { - EXPECT_THAT(status, StatusIs(expected_error_code, HasSubstr("task:1"))); + s1 = status; n1.Notify(); }); @@ -569,10 +570,15 @@ TEST_F(CoordinateTwoTasksTest, Env::Default()->SleepForMicroseconds(sleeping_time); TF_EXPECT_OK(coord_service_->RecordHeartbeat(task_0_, incarnation_0_)); Env::Default()->SleepForMicroseconds(sleeping_time); - // Make sure the StatusCallbacks are called. n0.WaitForNotification(); n1.WaitForNotification(); + + // The heartbeat error from `task_1_` below should be propagated to all tasks. + EXPECT_THAT(s0, + StatusIs(absl::StatusCode::kUnavailable, HasSubstr("task:1"))); + EXPECT_THAT(s1, + StatusIs(absl::StatusCode::kUnavailable, HasSubstr("task:1"))); } TEST_F(CoordinateTwoTasksTest, ReportedErrorCanPropagateThroughErrorPolling) { @@ -601,16 +607,21 @@ TEST_F(CoordinateTwoTasksTest, TestTaskRestart) { // Simulate task restart scenario: trying to register to cluster again. absl::Status s = coord_service_->RegisterTask(task_1_, /*incarnation=*/random::New64()); - EXPECT_TRUE(absl::IsAborted(s)) << s; + + EXPECT_THAT(s, StatusIs(absl::StatusCode::kAborted)); // Aborted error is also propagated to other tasks in cluster. - EXPECT_TRUE(absl::IsAborted(client_0_.GetStatus())) << client_0_.GetStatus(); + EXPECT_THAT(client_0_.GetStatus(), StatusIs(absl::StatusCode::kAborted)); } TEST_F(CoordinateTwoTasksTest, InsertKeyValue_Duplicate_Fail) { EnableCoordinationService(); ASSERT_OK(coord_service_->InsertKeyValue("key0", "original_value")); - EXPECT_TRUE(absl::IsAlreadyExists( - coord_service_->InsertKeyValue("key0", "never_added"))); + + // Inserting the same key again should fail. + EXPECT_THAT(coord_service_->InsertKeyValue("key0", "never_added"), + StatusIs(absl::StatusCode::kAlreadyExists)); + + // The original value should still be set. auto result = coord_service_->TryGetKeyValue("key0"); TF_EXPECT_OK(result.status()); EXPECT_EQ(result.value(), "original_value"); @@ -701,7 +712,7 @@ TEST(CoordinationServiceTest, TryGetKeyValue) { // Try to get nonexistent key. absl::StatusOr result = coord_service->TryGetKeyValue("test_key"); - EXPECT_TRUE(absl::IsNotFound(result.status())); + EXPECT_THAT(result.status(), StatusIs(absl::StatusCode::kNotFound)); // Insert key value. ASSERT_OK(coord_service->InsertKeyValue("test_key", "test_value")); @@ -711,7 +722,7 @@ TEST(CoordinationServiceTest, TryGetKeyValue) { // Delete Key, and try to get the key again. ASSERT_OK(coord_service->DeleteKeyValue("test_key")); result = coord_service->TryGetKeyValue("test_key"); - EXPECT_TRUE(absl::IsNotFound(result.status())); + EXPECT_THAT(result.status(), StatusIs(absl::StatusCode::kNotFound)); } TEST_F(CoordinateTwoTasksTest, GetKeyValueDir_SingleValueInDirectory) { @@ -1074,8 +1085,8 @@ TEST_F(CoordinationBarrierTest, BarrierWithMismatchedTasks) { /*participating_tasks=*/{GetTask(1), GetTask(2)}, [&barrier_status_1](absl::Status s) { barrier_status_1 = s; }); - EXPECT_TRUE(absl::IsInvalidArgument(barrier_status_0)); - EXPECT_TRUE(absl::IsInvalidArgument(barrier_status_1)); + EXPECT_THAT(barrier_status_0, StatusIs(absl::StatusCode::kInvalidArgument)); + EXPECT_THAT(barrier_status_1, StatusIs(absl::StatusCode::kInvalidArgument)); } TEST_F(CoordinationBarrierTest, BarrierByNonParticipatingTask) { @@ -1097,8 +1108,8 @@ TEST_F(CoordinationBarrierTest, BarrierByNonParticipatingTask) { [&barrier_status_1](absl::Status s) { barrier_status_1 = s; }); // Barrier should fail for all tasks with the unexpected call. - EXPECT_TRUE(absl::IsInvalidArgument(barrier_status_0)); - EXPECT_TRUE(absl::IsInvalidArgument(barrier_status_1)); + EXPECT_THAT(barrier_status_0, StatusIs(absl::StatusCode::kInvalidArgument)); + EXPECT_THAT(barrier_status_1, StatusIs(absl::StatusCode::kInvalidArgument)); } TEST_F(CoordinationBarrierTest, BarrierByNonParticipatingTaskThreeTasks) { @@ -1139,7 +1150,7 @@ TEST_F(CoordinationBarrierTest, BarrierByNonParticipatingTaskThreeTasks) { [&barrier_status_2](absl::Status s) { barrier_status_2 = s; }); // Barrier should fail for task 2 which is not participating in the barrier. - EXPECT_TRUE(absl::IsInvalidArgument(barrier_status_2)); + EXPECT_THAT(barrier_status_2, StatusIs(absl::StatusCode::kInvalidArgument)); // Other clients would need to check the barrier key to detect the error. } @@ -1163,7 +1174,7 @@ TEST_F(CoordinationBarrierTest, BarrierByNonClusterTask) { n_0.WaitForNotification(); // Barrier should fail with the unexpected participating task argument. - EXPECT_TRUE(absl::IsInvalidArgument(barrier_status_0)); + EXPECT_THAT(barrier_status_0, StatusIs(absl::StatusCode::kInvalidArgument)); } TEST_F(CoordinationBarrierTest, BarrierTimeout) { @@ -1191,7 +1202,7 @@ TEST_F(CoordinationBarrierTest, BarrierTimeout) { // All barrier calls should fail with the same error. EXPECT_EQ(barrier_status_0, barrier_status_1); - EXPECT_TRUE(absl::IsDeadlineExceeded(barrier_status_0)); + EXPECT_THAT(barrier_status_0, StatusIs(absl::StatusCode::kDeadlineExceeded)); EXPECT_FALSE( absl::StrContains(barrier_status_0.message(), GetTaskName(GetTask(0)))); EXPECT_TRUE( @@ -1227,8 +1238,8 @@ TEST_F(CoordinationBarrierTest, BarrierReturnsPreviousError) { /*participating_tasks=*/{}, [&barrier_status_1](absl::Status s) { barrier_status_1 = s; }); - EXPECT_TRUE(absl::IsInternal(barrier_status_0)); - EXPECT_TRUE(absl::IsInternal(barrier_status_1)); + EXPECT_THAT(barrier_status_0, StatusIs(absl::StatusCode::kInternal)); + EXPECT_THAT(barrier_status_1, StatusIs(absl::StatusCode::kInternal)); } TEST_F(CoordinationBarrierTest, BarrierCancelled) { @@ -1243,7 +1254,7 @@ TEST_F(CoordinationBarrierTest, BarrierCancelled) { absl::Status cancelled_status = GetCoordinationService()->CancelBarrier(barrier_id, GetTask(0)); - EXPECT_TRUE(absl::IsCancelled(barrier_status)); + EXPECT_THAT(barrier_status, StatusIs(absl::StatusCode::kCancelled)); TF_EXPECT_OK(cancelled_status); } @@ -1260,7 +1271,7 @@ TEST_F(CoordinationBarrierTest, CancelNonExistentBarrier_FutureBarrierFails) { /*participating_tasks=*/{}, [&barrier_status](absl::Status s) { barrier_status = s; }); - EXPECT_TRUE(absl::IsCancelled(barrier_status)) << barrier_status; + EXPECT_THAT(barrier_status, StatusIs(absl::StatusCode::kCancelled)); } TEST_F(CoordinationBarrierTest, CancelAfterBarrierHasPassed) { @@ -1286,7 +1297,8 @@ TEST_F(CoordinationBarrierTest, CancelAfterBarrierHasPassed) { absl::Status cancelled_status = GetCoordinationService()->CancelBarrier(barrier_id, GetTask(0)); - EXPECT_TRUE(absl::IsFailedPrecondition(cancelled_status)); + EXPECT_THAT(cancelled_status, + StatusIs(absl::StatusCode::kFailedPrecondition)); TF_EXPECT_OK(barrier_status_0); TF_EXPECT_OK(barrier_status_1); TF_EXPECT_OK(barrier_status_2); @@ -1354,7 +1366,7 @@ TEST_F(CoordinationBarrierTest, BarrierFailsIfTaskIsAlreadyInError) { /*participating_tasks=*/{}, [&barrier_status](absl::Status s) { barrier_status = s; }); - EXPECT_TRUE(absl::IsInternal(barrier_status)); + EXPECT_THAT(barrier_status, StatusIs(absl::StatusCode::kInternal)); } TEST_F(CoordinationBarrierTest, BarrierFailsUponTaskError) { @@ -1373,7 +1385,7 @@ TEST_F(CoordinationBarrierTest, BarrierFailsUponTaskError) { GetTask(0), absl::InternalError("test_error"))); n0.WaitForNotification(); - EXPECT_TRUE(absl::IsInternal(barrier_status)); + EXPECT_THAT(barrier_status, StatusIs(absl::StatusCode::kInternal)); } TEST_F(CoordinationBarrierTest, @@ -1440,8 +1452,8 @@ TEST_F(CoordinateTwoTasksTest, Reset_HeartbeatsAreAcceptedForAGracePeriod) { // period. Env::Default()->SleepForMicroseconds( absl::ToInt64Microseconds(3 * kHeartbeatTimeout)); - EXPECT_TRUE(absl::IsInvalidArgument( - coord_service_->RecordHeartbeat(task_0_, incarnation_0_))); + EXPECT_THAT(coord_service_->RecordHeartbeat(task_0_, incarnation_0_), + StatusIs(absl::StatusCode::kInvalidArgument)); } TEST_F(CoordinateTwoTasksTest, Reset_FailsOngoingBarrier) { @@ -1462,7 +1474,7 @@ TEST_F(CoordinateTwoTasksTest, Reset_FailsOngoingBarrier) { // Ongoing barrier should fail with error after shutdown. EXPECT_TRUE(barrier_n.HasBeenNotified()); - EXPECT_TRUE(absl::IsInternal(barrier_status)) << barrier_status; + EXPECT_THAT(barrier_status, StatusIs(absl::StatusCode::kInternal)); } TEST_F(CoordinateTwoTasksTest, Shutdown_HeartbeatsAreAcceptedForAGracePeriod) { @@ -1484,8 +1496,8 @@ TEST_F(CoordinateTwoTasksTest, Shutdown_HeartbeatsAreAcceptedForAGracePeriod) { // period. Env::Default()->SleepForMicroseconds( absl::ToInt64Microseconds(3 * kHeartbeatTimeout)); - EXPECT_TRUE(absl::IsInvalidArgument( - coord_service_->RecordHeartbeat(task_0_, incarnation_0_))); + EXPECT_THAT(coord_service_->RecordHeartbeat(task_0_, incarnation_0_), + StatusIs(absl::StatusCode::kInvalidArgument)); } TEST_F(CoordinateTwoTasksTest, Shutdown_FailsOngoingBarrier) { @@ -1511,7 +1523,7 @@ TEST_F(CoordinateTwoTasksTest, Shutdown_FailsOngoingBarrier) { // Ongoing barrier should fail with error after shutdown. EXPECT_TRUE(barrier_n.HasBeenNotified()); - EXPECT_TRUE(absl::IsInternal(barrier_status)) << barrier_status; + EXPECT_THAT(barrier_status, StatusIs(absl::StatusCode::kInternal)); } TEST_F(CoordinateTwoTasksTest, ShutdownWithBarrier_BarrierSucceeds) { @@ -1554,7 +1566,7 @@ TEST_F(CoordinateTwoTasksTest, // Block until barrier times out. n.WaitForNotification(); - EXPECT_TRUE(absl::IsDeadlineExceeded(barrier_status)) << barrier_status; + EXPECT_THAT(barrier_status, StatusIs(absl::StatusCode::kDeadlineExceeded)); // Confirm that task_0_ has disconnected. // Note: this should not happen in prod where RegisterTask() is called after // Shutdown(), which is prevented by agent-side logic. @@ -1562,7 +1574,7 @@ TEST_F(CoordinateTwoTasksTest, // Other task is alerted that shutdown has been initiated without it. absl::Status other_task_status = client_1_.GetStatus(); - EXPECT_TRUE(absl::IsInternal(other_task_status)) << other_task_status; + EXPECT_THAT(other_task_status, StatusIs(absl::StatusCode::kInternal)); } TEST_F(CoordinateTwoTasksTest, @@ -1585,7 +1597,7 @@ TEST_F(CoordinateTwoTasksTest, Env::Default()->SleepForMicroseconds( absl::ToInt64Microseconds(absl::Seconds(1))); - EXPECT_TRUE(absl::IsDeadlineExceeded(barrier_status)) << barrier_status; + EXPECT_THAT(barrier_status, StatusIs(absl::StatusCode::kDeadlineExceeded)); // Service stops because no service-to-client connection is available for // error propagation. @@ -1593,7 +1605,7 @@ TEST_F(CoordinateTwoTasksTest, // service has stopped yet, which should fail. absl::Status s = coord_service_->RecordHeartbeat(task_1_, incarnation_1_); - EXPECT_TRUE(absl::IsInternal(s)) << s; + EXPECT_THAT(s, StatusIs(absl::StatusCode::kInternal)); } TEST_F(CoordinateTwoTasksTest, BarrierFailsIfServiceHasStopped) { @@ -1615,7 +1627,7 @@ TEST_F(CoordinateTwoTasksTest, BarrierFailsIfServiceHasStopped) { }); n0.WaitForNotification(); - EXPECT_TRUE(absl::IsInternal(barrier_status)) << barrier_status; + EXPECT_THAT(barrier_status, StatusIs(absl::StatusCode::kInternal)); } TEST_F(CoordinateTwoTasksTest, BarrierFailsAfterErrorPollingResponse) { @@ -1624,15 +1636,14 @@ TEST_F(CoordinateTwoTasksTest, BarrierFailsAfterErrorPollingResponse) { ASSERT_OK(coord_service_->RegisterTask(task_1_, incarnation_1_)); // Use notifications to guarantee the ordering of operations across threads. absl::Notification n0, n1; + absl::Status s0, s1; - // The heartbeat error below should be propagated to all tasks. - absl::StatusCode expected_error_code = absl::StatusCode::kUnavailable; coord_service_->PollForErrorAsync(task_0_, [&](const absl::Status& status) { - EXPECT_THAT(status, StatusIs(expected_error_code)); + s0 = status; n0.Notify(); }); coord_service_->PollForErrorAsync(task_1_, [&](const absl::Status& status) { - EXPECT_THAT(status, StatusIs(expected_error_code)); + s1 = status; n1.Notify(); }); @@ -1644,6 +1655,9 @@ TEST_F(CoordinateTwoTasksTest, BarrierFailsAfterErrorPollingResponse) { // Make sure the StatusCallbacks are called before the barrier is called. n0.WaitForNotification(); n1.WaitForNotification(); + // The heartbeat error should be propagated to all tasks. + EXPECT_THAT(s0, StatusIs(absl::StatusCode::kUnavailable)); + EXPECT_THAT(s1, StatusIs(absl::StatusCode::kUnavailable)); absl::Notification n_barrier; absl::Status barrier_status; @@ -1655,7 +1669,7 @@ TEST_F(CoordinateTwoTasksTest, BarrierFailsAfterErrorPollingResponse) { }); n_barrier.WaitForNotification(); - EXPECT_TRUE(absl::IsInternal(barrier_status)) << barrier_status; + EXPECT_THAT(barrier_status, StatusIs(absl::StatusCode::kInternal)); } TEST_F(CoordinateTwoTasksTest, BarrierWithSubsetFailsIfServiceHasStopped) { @@ -1680,7 +1694,7 @@ TEST_F(CoordinateTwoTasksTest, BarrierWithSubsetFailsIfServiceHasStopped) { }); n0.WaitForNotification(); - EXPECT_TRUE(absl::IsInternal(barrier_status)) << barrier_status; + EXPECT_THAT(barrier_status, StatusIs(absl::StatusCode::kInternal)); } TEST_F(CoordinateTwoTasksTest, @@ -1708,7 +1722,7 @@ TEST_F(CoordinateTwoTasksTest, }); n0.WaitForNotification(); - EXPECT_TRUE(absl::IsInternal(barrier_status)) << barrier_status; + EXPECT_THAT(barrier_status, StatusIs(absl::StatusCode::kInternal)); } TEST_F(CoordinateTwoTasksTest, UnrecoverableTaskPropagatesError) { @@ -1722,10 +1736,10 @@ TEST_F(CoordinateTwoTasksTest, UnrecoverableTaskPropagatesError) { ASSERT_OK(coord_service_->ReportTaskError(task_0_, absl::InternalError("test_error"))); - EXPECT_TRUE(absl::IsInternal( - coord_service_->RecordHeartbeat(task_0_, incarnation_0_))); + EXPECT_THAT(coord_service_->RecordHeartbeat(task_0_, incarnation_0_), + StatusIs(absl::StatusCode::kInternal)); // For unrecoverable task, error propagates to all connected tasks. - EXPECT_TRUE(absl::IsInternal(client_1_.GetStatus())); + EXPECT_THAT(client_1_.GetStatus(), StatusIs(absl::StatusCode::kInternal)); } TEST_F(CoordinateTwoTasksTest, RecoverableTaskWillNotPropagateError) { @@ -1739,8 +1753,8 @@ TEST_F(CoordinateTwoTasksTest, RecoverableTaskWillNotPropagateError) { ASSERT_OK(coord_service_->ReportTaskError(task_0_, absl::InternalError("test_error"))); - EXPECT_TRUE(absl::IsInternal( - coord_service_->RecordHeartbeat(task_0_, incarnation_0_))); + EXPECT_THAT(coord_service_->RecordHeartbeat(task_0_, incarnation_0_), + StatusIs(absl::StatusCode::kInternal)); // Since no error propagation for recoverable tasks, other tasks should work // as normal. TF_EXPECT_OK(client_1_.GetStatus()); @@ -1758,8 +1772,8 @@ TEST_F(CoordinateTwoTasksTest, ASSERT_OK(coord_service_->ReportTaskError(task_0_, absl::InternalError("test_error"))); - EXPECT_TRUE(absl::IsInternal( - coord_service_->RecordHeartbeat(task_0_, incarnation_0_))); + EXPECT_THAT(coord_service_->RecordHeartbeat(task_0_, incarnation_0_), + StatusIs(absl::StatusCode::kInternal)); // Since no error propagation for recoverable tasks, other tasks should work // as normal. TF_EXPECT_OK(client_1_.GetStatus()); From b7f3881e9a31bc7a6a5897d146ddcd53630bab67 Mon Sep 17 00:00:00 2001 From: Georg Stefan Schmid Date: Tue, 24 Sep 2024 11:11:08 -0700 Subject: [PATCH 194/483] PR #17422: [ffi] Support handler bundles in GPU plugin extension Imported from GitHub PR https://github.com/openxla/xla/pull/17422 Currently only the execute handler is supported. This PR allows all of them to be set, bringing the GPU plugin extensions in sync with https://github.com/openxla/xla/blob/b9fcb2422f5e38ed8aecaec6b604d2fc86755c4d/xla/python/xla_compiler.cc#L306-L308 . See also the JAX part of this change: https://github.com/jax-ml/jax/pull/23806 Copybara import of the project: -- 349d26e45cfa4d4f4cbc1be4773ad2ac6f4034d5 by Georg Stefan Schmid : [ffi] Support handler bundles in GPU plugin extension Merging this change closes #17422 PiperOrigin-RevId: 678322017 --- third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_extension.h | 10 ++++++---- third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_internal.cc | 9 ++++++--- third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_test.cc | 10 ++++++++-- 3 files changed, 20 insertions(+), 9 deletions(-) diff --git a/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_extension.h b/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_extension.h index 3ecdaeafb32749..28b17e5434f2ea 100644 --- a/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_extension.h +++ b/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_extension.h @@ -24,17 +24,19 @@ limitations under the License. extern "C" { #endif -#define PJRT_API_GPU_EXTENSION_VERSION 1 +#define PJRT_API_GPU_EXTENSION_VERSION 2 struct PJRT_Gpu_Register_Custom_Call_Args { size_t struct_size; const char* function_name; size_t function_name_size; int api_version; // 0 for an untyped call, 1 -- for typed - void* custom_call_function; + void* handler_instantiate; + void* handler_prepare; + void* handler_initialize; + void* handler_execute; }; -PJRT_DEFINE_STRUCT_TRAITS(PJRT_Gpu_Register_Custom_Call_Args, - custom_call_function); +PJRT_DEFINE_STRUCT_TRAITS(PJRT_Gpu_Register_Custom_Call_Args, handler_execute); // Registers a custom call. typedef PJRT_Error* PJRT_Gpu_Register_Custom_Call( diff --git a/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_internal.cc b/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_internal.cc index 2d593290087719..5bddd6f2660e8e 100644 --- a/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_internal.cc +++ b/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_internal.cc @@ -293,14 +293,17 @@ PJRT_Error* PJRT_Gpu_Register_Custom_Call( switch (args->api_version) { case 0: xla::CustomCallTargetRegistry::Global()->Register( - function_name, args->custom_call_function, - PJRT_GPU_PLUGIN_PLATFORM_NAME); + function_name, args->handler_execute, PJRT_GPU_PLUGIN_PLATFORM_NAME); return nullptr; case 1: xla::ffi::Ffi::RegisterStaticHandler( xla::ffi::GetXlaFfiApi(), function_name, PJRT_GPU_PLUGIN_PLATFORM_NAME, - reinterpret_cast(args->custom_call_function)); + XLA_FFI_Handler_Bundle{ + reinterpret_cast(args->handler_instantiate), + reinterpret_cast(args->handler_prepare), + reinterpret_cast(args->handler_initialize), + reinterpret_cast(args->handler_execute)}); return nullptr; default: return new PJRT_Error{absl::UnimplementedError( diff --git a/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_test.cc b/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_test.cc index c288315db8b96c..17d9c9d72228f3 100644 --- a/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_test.cc +++ b/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_test.cc @@ -461,7 +461,10 @@ TEST(PjrtCApiGpuExtensionTest, CustomCallUntyped) { args.function_name = function_name.c_str(); args.function_name_size = function_name.size(); args.api_version = 0; - args.custom_call_function = reinterpret_cast(&TestCustomCallV2); + args.handler_instantiate = nullptr; + args.handler_prepare = nullptr; + args.handler_initialize = nullptr; + args.handler_execute = reinterpret_cast(&TestCustomCallV2); auto api = GetPjrtApi(); const PJRT_Extension_Base* next = reinterpret_cast(api->extension_start); @@ -491,7 +494,10 @@ TEST(PjrtCApiGpuExtensionTest, CustomCallTyped) { args.function_name = function_name.c_str(); args.function_name_size = function_name.size(); args.api_version = 1; - args.custom_call_function = reinterpret_cast(kNoop); + args.handler_instantiate = nullptr; + args.handler_prepare = nullptr; + args.handler_initialize = nullptr; + args.handler_execute = reinterpret_cast(kNoop); auto api = GetPjrtApi(); const PJRT_Extension_Base* next = reinterpret_cast(api->extension_start); From 9dc73b23974c5a9f7149ecfd11267da786c9ba06 Mon Sep 17 00:00:00 2001 From: Andrew Audibert Date: Tue, 24 Sep 2024 11:12:30 -0700 Subject: [PATCH 195/483] Disable msan for failing test. The test runs out of memory under msan. PiperOrigin-RevId: 678322559 --- .../python/data/experimental/kernel_tests/optimization/BUILD | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/BUILD b/tensorflow/python/data/experimental/kernel_tests/optimization/BUILD index 34c1098ccae8fc..ac6a8752f8db7c 100644 --- a/tensorflow/python/data/experimental/kernel_tests/optimization/BUILD +++ b/tensorflow/python/data/experimental/kernel_tests/optimization/BUILD @@ -196,6 +196,7 @@ tf_py_strict_test( size = "medium", srcs = ["optimization_test.py"], shard_count = 2, + tags = ["nomsan"], # Runs out of memory. deps = [ "//tensorflow/python/data/experimental/ops:batching", "//tensorflow/python/data/experimental/ops:grouping", From 1ddc4d5388f7fc29372ff6b1964378c63673868e Mon Sep 17 00:00:00 2001 From: Matt Bahr Date: Tue, 24 Sep 2024 14:29:27 -0400 Subject: [PATCH 196/483] fix dependency in ragged range op test --- tensorflow/python/ops/ragged/BUILD | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorflow/python/ops/ragged/BUILD b/tensorflow/python/ops/ragged/BUILD index 3e47c991e0247b..20699c83978df8 100644 --- a/tensorflow/python/ops/ragged/BUILD +++ b/tensorflow/python/ops/ragged/BUILD @@ -904,6 +904,7 @@ py_strict_test( "//tensorflow/python/framework:test_lib", "//tensorflow/python/ops:ragged_math_ops_gen", "//tensorflow/python/platform:test", + "//third_party/py/numpy", ], ) From 2bfeaf3cdf8a8a8bb5ab0d26a6f53caa78ec53f6 Mon Sep 17 00:00:00 2001 From: Raviteja Gorijala Date: Tue, 24 Sep 2024 11:14:50 -0700 Subject: [PATCH 197/483] Update release notes at HEAD PiperOrigin-RevId: 678323612 --- RELEASE.md | 101 +++++++++++++++++++++++++---------------------------- 1 file changed, 48 insertions(+), 53 deletions(-) diff --git a/RELEASE.md b/RELEASE.md index 10a4a9ce528043..c146e2215fb339 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -1,4 +1,4 @@ -# Release 2.18.0 +# Release 2.19.0 ## TensorFlow @@ -9,26 +9,6 @@ * * -* `tf.lite` - * C API: - * An optional, fourth parameter was added `TfLiteOperatorCreate` as a step - forward towards a cleaner API for `TfLiteOperator`. Function - `TfLiteOperatorCreate` was added recently, in TensorFlow Lite version 2.17.0, - released on 7/11/2024, and we do not expect there will be much code using this - function yet. Any code breakages can be easily resolved by passing nullptr as - the new, 4th parameter. - * SignatureRunner is now supported for models with no signatures. - -* TensorRT support is disabled in CUDA builds for code health improvement. - -* Hermetic CUDA support is added. - - Hermetic CUDA uses a specific downloadable version of CUDA instead of the - user’s locally installed CUDA. Bazel will download CUDA, CUDNN and NCCL - distributions, and then use CUDA libraries and tools as dependencies in - various Bazel targets. This enables more reproducible builds for Google ML - projects and supported CUDA versions. - ### Known Caveats * @@ -40,44 +20,12 @@ * * -* `tf.lite`: - * The LiteRT [repo](https://github.com/google-ai-edge/LiteRT) is - live (see [announcement](https://developers.googleblog.com/en/tensorflow-lite-is-now-litert/)), which means that in the coming months there will be changes to the development experience - for TFLite. The TF Lite Runtime source will be moved later this year, - and sometime after that we will start accepting contributions through that repo. - ### Bug Fixes and Other Changes * * * -* `tf.data` - * Add optional `synchronous` argument to `map`, to specify that the `map` - should run synchronously, as opposed to be parallelizable when - `options.experimental_optimization.map_parallelization=True`. This saves - memory compared to setting `num_parallel_calls=1`. - * Add optional `use_unbounded_threadpool` argument to `map`, to specify that - the `map` should use an unbounded threadpool instead of the default pool - that is based on the number of cores on the machine. This can improve - throughput for map functions which perform IO or otherwise release the - CPU. - * Add [`tf.data.experimental.get_model_proto`](https://www.tensorflow.org/api_docs/python/tf/data/experimental/get_model_proto) - to allow users to peek into the analytical model inside of a dataset - iterator. - -* `tf.lite` - * `Dequantize` op supports `TensorType_INT4`. - * This change includes per-channel dequantization. - * Add support for `stablehlo.composite`. - * `EmbeddingLookup` op supports per-channel - quantization and `TensorType_INT4` values. - * `FullyConnected` op supports `TensorType_INT16` activation and - `TensorType_Int4` weight per-channel quantization. - -* `tf.tensor_scatter_update`, `tf.tensor_scatter_add` and of other reduce types. - * Support `bad_indices_policy`. - ## Keras @@ -110,6 +58,53 @@ This release contains contributions from many people at Google, as well as: , , , , , +# Release 2.18.0 + +## TensorFlow + +### Breaking Changes + +* `tf.lite` + * C API: + * An optional, fourth parameter was added `TfLiteOperatorCreate` as a step forward towards a cleaner API for `TfLiteOperator`. Function `TfLiteOperatorCreate` was added recently, in TensorFlow Lite version 2.17.0, released on 7/11/2024, and we do not expect there will be much code using this function yet. Any code breakages can be easily resolved by passing nullptr as the new, 4th parameter. + * SignatureRunner is now supported for models with no signatures. + +* TensorRT support is disabled in CUDA builds for code health improvement. + +* Hermetic CUDA support is added. + + Hermetic CUDA uses a specific downloadable version of CUDA instead of the user’s locally installed CUDA. Bazel will download CUDA, CUDNN and NCCL distributions, and then use CUDA libraries and tools as dependencies in various Bazel targets. This enables more reproducible builds for Google ML projects and supported CUDA versions. + +### Known Caveats + +### Major Features and Improvements + +* `tf.lite`: + * The LiteRT [repo](https://github.com/google-ai-edge/LiteRT) is live (see [announcement](https://developers.googleblog.com/en/tensorflow-lite-is-now-litert/)), which means that in the coming months there will be changes to the development experience for TFLite. The TF Lite Runtime source will be moved later this year, and sometime after that we will start accepting contributions through that repo. + +### Bug Fixes and Other Changes + +* `tf.data` + * Add optional `synchronous` argument to `map`, to specify that the `map` should run synchronously, as opposed to be parallelizable when `options.experimental_optimization.map_parallelization=True`. This saves memory compared to setting `num_parallel_calls=1`. + * Add optional `use_unbounded_threadpool` argument to `map`, to specify that the `map` should use an unbounded threadpool instead of the default pool that is based on the number of cores on the machine. This can improve throughput for map functions which perform IO or otherwise release the CPU. + * Add [`tf.data.experimental.get_model_proto`](https://www.tensorflow.org/api_docs/python/tf/data/experimental/get_model_proto) to allow users to peek into the analytical model inside of a dataset iterator. + +* `tf.lite` + * `Dequantize` op supports `TensorType_INT4`. + * This change includes per-channel dequantization. + * Add support for `stablehlo.composite`. + * `EmbeddingLookup` op supports per-channel quantization and `TensorType_INT4` values. + * `FullyConnected` op supports `TensorType_INT16` activation and `TensorType_Int4` weight per-channel quantization. + +* `tf.tensor_scatter_update`, `tf.tensor_scatter_add` and of other reduce types. + * Support `bad_indices_policy`. + +## Thanks to our Contributors + +This release contains contributions from many people at Google, as well as: + +Akhil Goel, akhilgoe, Alexander Pivovarov, Amir Samani, Andrew Goodbody, Andrey Portnoy, Anthony Platanios, bernardoArcari, Brett Taylor, buptzyb, Chao, Christian Clauss, Cocoa, Daniil Kutz, Darya Parygina, dependabot[bot], Dimitris Vardoulakis, Dragan Mladjenovic, Elfie Guo, eukub, Faijul Amin, flyingcat, Frédéric Bastien, ganyu.08, Georg Stefan Schmid, Grigory Reznikov, Harsha H S, Harshit Monish, Heiner, Ilia Sergachev, Jan, Jane Liu, Jaroslav Sevcik, Kaixi Hou, Kanvi Khanna, Kristof Maar, Kristóf Maár, LakshmiKalaKadali, Lbertho-Gpsw, lingzhi98, MarcoFalke, Masahiro Hiramori, Mmakevic-Amd, mraunak, Nobuo Tsukamoto, Notheisz57, Olli Lupton, Pearu Peterson, pemeliya, Peyara Nando, Philipp Hack, Phuong Nguyen, Pol Dellaiera, Rahul Batra, Ruturaj Vaidya, sachinmuradi, Sergey Kozub, Shanbin Ke, Sheng Yang, shengyu, Shraiysh, Shu Wang, Surya, sushreebarsa, Swatheesh-Mcw, syzygial, Tai Ly, terryysun, tilakrayal, Tj Xu, Trevor Morris, Tzung-Han Juang, wenchenvincent, wondertx, Xuefei Jiang, Ye Huang, Yimei Sun, Yunlong Liu, Zahid Iqbal, Zhan Lu, Zoranjovanovic-Ns, Zuri Obozuwa + # Release 2.17.0 ## TensorFlow From 4aba9ebb1d0c2bb14415dd7a0ef2b02e30a19b15 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 24 Sep 2024 11:25:53 -0700 Subject: [PATCH 198/483] Copy and rework flatbuffer_conversions.h into compiler PiperOrigin-RevId: 678328142 --- tensorflow/compiler/mlir/lite/core/api/BUILD | 32 + .../lite/core/api/flatbuffer_conversions.cc | 3186 +++++++++++++++++ .../lite/core/api/flatbuffer_conversions.h | 440 +++ .../core/api/flatbuffer_conversions_test.cc | 873 +++++ tensorflow/compiler/mlir/lite/core/c/BUILD | 13 +- .../mlir/lite/core/c/builtin_op_data.h | 624 +++- .../mlir/lite/core/c/dimension_type.h | 38 - .../compiler/mlir/lite/core/c/tflite_types.h | 70 + .../utils/sparsity_format_converter.cc | 2 +- .../utils/sparsity_format_converter.h | 2 +- tensorflow/lite/c/BUILD | 1 + tensorflow/lite/core/c/BUILD | 4 + tensorflow/lite/core/c/builtin_op_data.h | 639 +--- tensorflow/lite/core/c/c_api_types.h | 40 +- tensorflow/lite/core/c/common.h | 6 - tensorflow/lite/java/BUILD | 2 + 16 files changed, 5239 insertions(+), 733 deletions(-) create mode 100644 tensorflow/compiler/mlir/lite/core/api/flatbuffer_conversions.cc create mode 100644 tensorflow/compiler/mlir/lite/core/api/flatbuffer_conversions.h create mode 100644 tensorflow/compiler/mlir/lite/core/api/flatbuffer_conversions_test.cc delete mode 100644 tensorflow/compiler/mlir/lite/core/c/dimension_type.h create mode 100644 tensorflow/compiler/mlir/lite/core/c/tflite_types.h diff --git a/tensorflow/compiler/mlir/lite/core/api/BUILD b/tensorflow/compiler/mlir/lite/core/api/BUILD index 0aaca3928420d6..cd16efca57aad8 100644 --- a/tensorflow/compiler/mlir/lite/core/api/BUILD +++ b/tensorflow/compiler/mlir/lite/core/api/BUILD @@ -52,3 +52,35 @@ tf_cc_test( "@com_google_googletest//:gtest_main", ], ) + +cc_library( + name = "flatbuffer_conversions", + srcs = ["flatbuffer_conversions.cc"], + hdrs = [ + "flatbuffer_conversions.h", + ], + compatible_with = get_compatible_with_portable(), + copts = tflite_copts(), + deps = [ + "//tensorflow/compiler/mlir/lite/core/c:tflite_common", + "//tensorflow/compiler/mlir/lite/kernels/internal:compatibility_macros", + "//tensorflow/compiler/mlir/lite/schema:schema_fbs", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:str_format", + "@flatbuffers//:runtime_cc", + ], +) + +tf_cc_test( + name = "flatbuffer_conversions_test", + size = "small", + srcs = ["flatbuffer_conversions_test.cc"], + deps = [ + ":flatbuffer_conversions", + "//tensorflow/compiler/mlir/lite/core/c:tflite_common", + "//tensorflow/compiler/mlir/lite/schema:schema_fbs", + "@com_google_googletest//:gtest_main", + "@flatbuffers//:runtime_cc", + ], +) diff --git a/tensorflow/compiler/mlir/lite/core/api/flatbuffer_conversions.cc b/tensorflow/compiler/mlir/lite/core/api/flatbuffer_conversions.cc new file mode 100644 index 00000000000000..60db7412bd199f --- /dev/null +++ b/tensorflow/compiler/mlir/lite/core/api/flatbuffer_conversions.cc @@ -0,0 +1,3186 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/lite/core/api/flatbuffer_conversions.h" + +#include +#include +#include +#include + +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/strings/str_format.h" +#include "flatbuffers/vector.h" // from @flatbuffers +#include "tensorflow/compiler/mlir/lite/core/c/builtin_op_data.h" +#include "tensorflow/compiler/mlir/lite/core/c/tflite_types.h" +#include "tensorflow/compiler/mlir/lite/kernels/internal/compatibility_macros.h" +#include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" + +/// Check whether the value `a` is true, and if not return +/// absl::InvalidArgumentError from the current function, while also +/// reporting the location of the error. +#define TFL_MIGRATION_ENSURE(a) \ + do { \ + if (!(a)) { \ + auto error_message = \ + absl::StrFormat("%s:%d %s was not true.", __FILE__, __LINE__, #a); \ + LOG(ERROR) << error_message; \ + return absl::InvalidArgumentError(error_message); \ + } \ + } while (0) + +#define TFL_MIGRATION_ENSURE_STATUS(a) \ + do { \ + const absl::Status s = (a); \ + if (!s.ok()) { \ + return s; \ + } \ + } while (0) + +namespace tflite_migration { +using absl::OkStatus; +using tflite::ActivationFunctionType; +using tflite::ActivationFunctionType_NONE; +using tflite::ActivationFunctionType_RELU; +using tflite::ActivationFunctionType_RELU6; +using tflite::ActivationFunctionType_RELU_N1_TO_1; +using tflite::ActivationFunctionType_SIGN_BIT; +using tflite::ActivationFunctionType_TANH; +using tflite::BuiltinOperator; +using tflite::BuiltinOperator_ABS; +using tflite::BuiltinOperator_ADD; +using tflite::BuiltinOperator_ADD_N; +using tflite::BuiltinOperator_ARG_MAX; +using tflite::BuiltinOperator_ARG_MIN; +using tflite::BuiltinOperator_ASSIGN_VARIABLE; +using tflite::BuiltinOperator_AVERAGE_POOL_2D; +using tflite::BuiltinOperator_BATCH_MATMUL; +using tflite::BuiltinOperator_BATCH_TO_SPACE_ND; +using tflite::BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM; +using tflite::BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN; +using tflite::BuiltinOperator_BITWISE_XOR; +using tflite::BuiltinOperator_BROADCAST_ARGS; +using tflite::BuiltinOperator_BROADCAST_TO; +using tflite::BuiltinOperator_BUCKETIZE; +using tflite::BuiltinOperator_CALL; +using tflite::BuiltinOperator_CALL_ONCE; +using tflite::BuiltinOperator_CAST; +using tflite::BuiltinOperator_CEIL; +using tflite::BuiltinOperator_COMPLEX_ABS; +using tflite::BuiltinOperator_CONCAT_EMBEDDINGS; +using tflite::BuiltinOperator_CONCATENATION; +using tflite::BuiltinOperator_CONV_2D; +using tflite::BuiltinOperator_CONV_3D; +using tflite::BuiltinOperator_CONV_3D_TRANSPOSE; +using tflite::BuiltinOperator_COS; +using tflite::BuiltinOperator_CUMSUM; +using tflite::BuiltinOperator_CUSTOM; +using tflite::BuiltinOperator_DELEGATE; +using tflite::BuiltinOperator_DENSIFY; +using tflite::BuiltinOperator_DEPTH_TO_SPACE; +using tflite::BuiltinOperator_DEPTHWISE_CONV_2D; +using tflite::BuiltinOperator_DEQUANTIZE; +using tflite::BuiltinOperator_DIV; +using tflite::BuiltinOperator_DYNAMIC_UPDATE_SLICE; +using tflite::BuiltinOperator_ELU; +using tflite::BuiltinOperator_EMBEDDING_LOOKUP; +using tflite::BuiltinOperator_EMBEDDING_LOOKUP_SPARSE; +using tflite::BuiltinOperator_EQUAL; +using tflite::BuiltinOperator_EXP; +using tflite::BuiltinOperator_EXPAND_DIMS; +using tflite::BuiltinOperator_FAKE_QUANT; +using tflite::BuiltinOperator_FILL; +using tflite::BuiltinOperator_FLOOR; +using tflite::BuiltinOperator_FLOOR_DIV; +using tflite::BuiltinOperator_FLOOR_MOD; +using tflite::BuiltinOperator_FULLY_CONNECTED; +using tflite::BuiltinOperator_GATHER; +using tflite::BuiltinOperator_GATHER_ND; +using tflite::BuiltinOperator_GELU; +using tflite::BuiltinOperator_GREATER; +using tflite::BuiltinOperator_GREATER_EQUAL; +using tflite::BuiltinOperator_HARD_SWISH; +using tflite::BuiltinOperator_HASHTABLE; +using tflite::BuiltinOperator_HASHTABLE_FIND; +using tflite::BuiltinOperator_HASHTABLE_IMPORT; +using tflite::BuiltinOperator_HASHTABLE_LOOKUP; +using tflite::BuiltinOperator_HASHTABLE_SIZE; +using tflite::BuiltinOperator_IF; +using tflite::BuiltinOperator_IMAG; +using tflite::BuiltinOperator_L2_NORMALIZATION; +using tflite::BuiltinOperator_L2_POOL_2D; +using tflite::BuiltinOperator_LEAKY_RELU; +using tflite::BuiltinOperator_LESS; +using tflite::BuiltinOperator_LESS_EQUAL; +using tflite::BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION; +using tflite::BuiltinOperator_LOG; +using tflite::BuiltinOperator_LOG_SOFTMAX; +using tflite::BuiltinOperator_LOGICAL_AND; +using tflite::BuiltinOperator_LOGICAL_NOT; +using tflite::BuiltinOperator_LOGICAL_OR; +using tflite::BuiltinOperator_LOGISTIC; +using tflite::BuiltinOperator_LSH_PROJECTION; +using tflite::BuiltinOperator_LSTM; +using tflite::BuiltinOperator_MATRIX_DIAG; +using tflite::BuiltinOperator_MATRIX_SET_DIAG; +using tflite::BuiltinOperator_MAX_POOL_2D; +using tflite::BuiltinOperator_MAXIMUM; +using tflite::BuiltinOperator_MEAN; +using tflite::BuiltinOperator_MINIMUM; +using tflite::BuiltinOperator_MIRROR_PAD; +using tflite::BuiltinOperator_MUL; +using tflite::BuiltinOperator_MULTINOMIAL; +using tflite::BuiltinOperator_NEG; +using tflite::BuiltinOperator_NON_MAX_SUPPRESSION_V4; +using tflite::BuiltinOperator_NON_MAX_SUPPRESSION_V5; +using tflite::BuiltinOperator_NOT_EQUAL; +using tflite::BuiltinOperator_ONE_HOT; +using tflite::BuiltinOperator_PACK; +using tflite::BuiltinOperator_PAD; +using tflite::BuiltinOperator_PADV2; +using tflite::BuiltinOperator_POW; +using tflite::BuiltinOperator_PRELU; +using tflite::BuiltinOperator_QUANTIZE; +using tflite::BuiltinOperator_RANDOM_STANDARD_NORMAL; +using tflite::BuiltinOperator_RANDOM_UNIFORM; +using tflite::BuiltinOperator_RANGE; +using tflite::BuiltinOperator_RANK; +using tflite::BuiltinOperator_READ_VARIABLE; +using tflite::BuiltinOperator_REAL; +using tflite::BuiltinOperator_REDUCE_ALL; +using tflite::BuiltinOperator_REDUCE_ANY; +using tflite::BuiltinOperator_REDUCE_MAX; +using tflite::BuiltinOperator_REDUCE_MIN; +using tflite::BuiltinOperator_REDUCE_PROD; +using tflite::BuiltinOperator_REDUCE_WINDOW; +using tflite::BuiltinOperator_RELU; +using tflite::BuiltinOperator_RELU6; +using tflite::BuiltinOperator_RELU_0_TO_1; +using tflite::BuiltinOperator_RELU_N1_TO_1; +using tflite::BuiltinOperator_RESHAPE; +using tflite::BuiltinOperator_RESIZE_BILINEAR; +using tflite::BuiltinOperator_RESIZE_NEAREST_NEIGHBOR; +using tflite::BuiltinOperator_REVERSE_SEQUENCE; +using tflite::BuiltinOperator_REVERSE_V2; +using tflite::BuiltinOperator_RFFT2D; +using tflite::BuiltinOperator_RIGHT_SHIFT; +using tflite::BuiltinOperator_RNN; +using tflite::BuiltinOperator_ROUND; +using tflite::BuiltinOperator_RSQRT; +using tflite::BuiltinOperator_SCATTER_ND; +using tflite::BuiltinOperator_SEGMENT_SUM; +using tflite::BuiltinOperator_SELECT; +using tflite::BuiltinOperator_SELECT_V2; +using tflite::BuiltinOperator_SHAPE; +using tflite::BuiltinOperator_SIN; +using tflite::BuiltinOperator_SKIP_GRAM; +using tflite::BuiltinOperator_SLICE; +using tflite::BuiltinOperator_SOFTMAX; +using tflite::BuiltinOperator_SPACE_TO_BATCH_ND; +using tflite::BuiltinOperator_SPACE_TO_DEPTH; +using tflite::BuiltinOperator_SPARSE_TO_DENSE; +using tflite::BuiltinOperator_SPLIT; +using tflite::BuiltinOperator_SPLIT_V; +using tflite::BuiltinOperator_SQRT; +using tflite::BuiltinOperator_SQUARE; +using tflite::BuiltinOperator_SQUARED_DIFFERENCE; +using tflite::BuiltinOperator_SQUEEZE; +using tflite::BuiltinOperator_STABLEHLO_ABS; +using tflite::BuiltinOperator_STABLEHLO_ADD; +using tflite::BuiltinOperator_STABLEHLO_AND; +using tflite::BuiltinOperator_STABLEHLO_BROADCAST_IN_DIM; +using tflite::BuiltinOperator_STABLEHLO_CLAMP; +using tflite::BuiltinOperator_STABLEHLO_COMPARE; +using tflite::BuiltinOperator_STABLEHLO_COMPOSITE; +using tflite::BuiltinOperator_STABLEHLO_CONCATENATE; +using tflite::BuiltinOperator_STABLEHLO_CONVERT; +using tflite::BuiltinOperator_STABLEHLO_CONVOLUTION; +using tflite::BuiltinOperator_STABLEHLO_COSINE; +using tflite::BuiltinOperator_STABLEHLO_CUSTOM_CALL; +using tflite::BuiltinOperator_STABLEHLO_DIVIDE; +using tflite::BuiltinOperator_STABLEHLO_DOT_GENERAL; +using tflite::BuiltinOperator_STABLEHLO_DYNAMIC_SLICE; +using tflite::BuiltinOperator_STABLEHLO_DYNAMIC_UPDATE_SLICE; +using tflite::BuiltinOperator_STABLEHLO_EXPONENTIAL; +using tflite::BuiltinOperator_STABLEHLO_FLOOR; +using tflite::BuiltinOperator_STABLEHLO_GATHER; +using tflite::BuiltinOperator_STABLEHLO_IOTA; +using tflite::BuiltinOperator_STABLEHLO_LOG; +using tflite::BuiltinOperator_STABLEHLO_LOGISTIC; +using tflite::BuiltinOperator_STABLEHLO_MAXIMUM; +using tflite::BuiltinOperator_STABLEHLO_MINIMUM; +using tflite::BuiltinOperator_STABLEHLO_MULTIPLY; +using tflite::BuiltinOperator_STABLEHLO_NEGATE; +using tflite::BuiltinOperator_STABLEHLO_OR; +using tflite::BuiltinOperator_STABLEHLO_PAD; +using tflite::BuiltinOperator_STABLEHLO_POWER; +using tflite::BuiltinOperator_STABLEHLO_REDUCE; +using tflite::BuiltinOperator_STABLEHLO_REDUCE_WINDOW; +using tflite::BuiltinOperator_STABLEHLO_REMAINDER; +using tflite::BuiltinOperator_STABLEHLO_RESHAPE; +using tflite::BuiltinOperator_STABLEHLO_RNG_BIT_GENERATOR; +using tflite::BuiltinOperator_STABLEHLO_RSQRT; +using tflite::BuiltinOperator_STABLEHLO_SCATTER; +using tflite::BuiltinOperator_STABLEHLO_SELECT; +using tflite::BuiltinOperator_STABLEHLO_SLICE; +using tflite::BuiltinOperator_STABLEHLO_SORT; +using tflite::BuiltinOperator_STABLEHLO_SUBTRACT; +using tflite::BuiltinOperator_STABLEHLO_TANH; +using tflite::BuiltinOperator_STABLEHLO_TRANSPOSE; +using tflite::BuiltinOperator_STABLEHLO_WHILE; +using tflite::BuiltinOperator_STRIDED_SLICE; +using tflite::BuiltinOperator_SUB; +using tflite::BuiltinOperator_SUM; +using tflite::BuiltinOperator_SVDF; +using tflite::BuiltinOperator_TANH; +using tflite::BuiltinOperator_TILE; +using tflite::BuiltinOperator_TOPK_V2; +using tflite::BuiltinOperator_TRANSPOSE; +using tflite::BuiltinOperator_TRANSPOSE_CONV; +using tflite::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM; +using tflite::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN; +using tflite::BuiltinOperator_UNIQUE; +using tflite::BuiltinOperator_UNPACK; +using tflite::BuiltinOperator_UNSORTED_SEGMENT_MAX; +using tflite::BuiltinOperator_UNSORTED_SEGMENT_MIN; +using tflite::BuiltinOperator_UNSORTED_SEGMENT_PROD; +using tflite::BuiltinOperator_VAR_HANDLE; +using tflite::BuiltinOperator_WHERE; +using tflite::BuiltinOperator_WHILE; +using tflite::BuiltinOperator_ZEROS_LIKE; +using tflite::CombinerType; +using tflite::CombinerType_MEAN; +using tflite::CombinerType_SQRTN; +using tflite::CombinerType_SUM; +using tflite::LSHProjectionType; +using tflite::LSHProjectionType_DENSE; +using tflite::LSHProjectionType_SPARSE; +using tflite::MirrorPadMode; +using tflite::MirrorPadMode_REFLECT; +using tflite::MirrorPadMode_SYMMETRIC; +using tflite::Operator; +using tflite::Padding; +using tflite::Padding_SAME; +using tflite::Padding_VALID; +using tflite::ReduceWindowFunction_ADD; +using tflite::ReduceWindowFunction_ALL; +using tflite::ReduceWindowFunction_ANY; +using tflite::ReduceWindowFunction_MAXIMUM; +using tflite::ReduceWindowFunction_MINIMUM; +using tflite::ReduceWindowFunction_MUL; +using tflite::ReduceWindowFunction_UNSUPPORTED; +using tflite::RngAlgorithm; +using tflite::RngAlgorithm_DEFAULT; +using tflite::RngAlgorithm_PHILOX; +using tflite::RngAlgorithm_THREEFRY; +using tflite::TensorType; +using tflite::TensorType_BFLOAT16; +using tflite::TensorType_BOOL; +using tflite::TensorType_COMPLEX128; +using tflite::TensorType_COMPLEX64; +using tflite::TensorType_FLOAT16; +using tflite::TensorType_FLOAT32; +using tflite::TensorType_FLOAT64; +using tflite::TensorType_INT16; +using tflite::TensorType_INT32; +using tflite::TensorType_INT4; +using tflite::TensorType_INT64; +using tflite::TensorType_INT8; +using tflite::TensorType_RESOURCE; +using tflite::TensorType_STRING; +using tflite::TensorType_UINT16; +using tflite::TensorType_UINT32; +using tflite::TensorType_UINT64; +using tflite::TensorType_UINT8; +using tflite::TensorType_VARIANT; +; +using tflite::AddOptions; +using tflite::ArgMaxOptions; +using tflite::ArgMinOptions; +using tflite::BuiltinOperator_ATAN2; +using tflite::BuiltinOperator_BITCAST; +using tflite::BuiltinOperator_CONCAT_EMBEDDINGS; +using tflite::BuiltinOperator_DILATE; +using tflite::BuiltinOperator_PLACEHOLDER_FOR_GREATER_OP_CODES; +using tflite::BuiltinOperator_SIGN; +using tflite::BuiltinOperator_STABLEHLO_CBRT; +using tflite::BuiltinOperator_STABLEHLO_SHIFT_LEFT; +using tflite::BuiltinOperator_UNSORTED_SEGMENT_SUM; +using tflite::BuiltinOperator_WHERE; +using tflite::CallOnceOptions; +using tflite::ConcatenationOptions; +using tflite::Conv2DOptions; +using tflite::DepthwiseConv2DOptions; +using tflite::FullyConnectedOptions; +using tflite::FullyConnectedOptionsWeightsFormat_DEFAULT; +using tflite::FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8; +using tflite::IfOptions; +using tflite::L2NormOptions; +using tflite::LSTMKernelType_BASIC; +using tflite::LSTMKernelType_FULL; +using tflite::MirrorPadOptions; +using tflite::MulOptions; +using tflite::PackOptions; +using tflite::Pool2DOptions; +using tflite::ReducerOptions; +using tflite::ReshapeOptions; +using tflite::ResizeBilinearOptions; +using tflite::ResizeNearestNeighborOptions; +using tflite::ShapeOptions; +using tflite::SoftmaxOptions; +using tflite::SplitOptions; +using tflite::SplitVOptions; +using tflite::SqueezeOptions; +using tflite::StableHLOCompositeOptions; +using tflite::StablehloGatherOptions; +using tflite::StablehloPadOptions; +using tflite::StablehloReduceWindowOptions; +using tflite::StablehloRngBitGeneratorOptions; +using tflite::StablehloScatterOptions; +using tflite::StridedSliceOptions; +using tflite::SubOptions; +using tflite::SVDFOptions; +using tflite::TransposeConvOptions; +using tflite::UnpackOptions; +using tflite::VarHandleOptions; +using tflite::WhileOptions; + +namespace { + +// Utility class for safely allocating POD data. This is useful for avoiding +// leaks in cases where op params are allocated but fail to propagate to the +// parsed op data (e.g., when model parameters are invalid). +class SafeBuiltinDataAllocator { + public: + class BuiltinDataDeleter { + public: + explicit BuiltinDataDeleter(BuiltinDataAllocator* allocator) + : allocator_(allocator) {} + + void operator()(void* data) { allocator_->Deallocate(data); } + + private: + BuiltinDataAllocator* allocator_; + }; + + template + using BuiltinDataPtr = std::unique_ptr; + + explicit SafeBuiltinDataAllocator(BuiltinDataAllocator* allocator) + : allocator_(allocator) {} + + template + BuiltinDataPtr Allocate() { + return BuiltinDataPtr(allocator_->AllocatePOD(), + BuiltinDataDeleter(allocator_)); + } + + private: + BuiltinDataAllocator* allocator_; +}; + +// All the Parse functions take some pointers as params and this function has +// the common DCHECKs to catch if any of those are nullptr. +void CheckParsePointerParams(const Operator* op, + BuiltinDataAllocator* allocator, + void** builtin_data) { + TFLITE_DCHECK(op != nullptr); + TFLITE_DCHECK(allocator != nullptr); + TFLITE_DCHECK(builtin_data != nullptr); +} + +// Copies the contents from the flatbuffer int vector `flatbuffer` into the +// int array `buffer`. `flat_vector` and `buffer` represent the same +// configuration operation for a given operation. +template +static absl::Status FlatBufferIntVectorToArray( + int max_size_of_buffer, const flatbuffers::Vector* flat_vector, + DataType* buffer, const char* op_name) { + if (!flat_vector) { + auto error_message = absl::StrFormat( + "Input array not provided for operation '%s'.\n", op_name); + LOG(ERROR) << error_message; + return absl::InvalidArgumentError(error_message); + } else { + size_t num_dimensions = flat_vector->size(); + if (num_dimensions > max_size_of_buffer / sizeof(DataType)) { + auto error_message = absl::StrFormat( + "Found too many dimensions in the input array of operation '%s'.\n", + op_name); + LOG(ERROR) << error_message; + return absl::InvalidArgumentError(error_message); + } else { + for (size_t i = 0; i < num_dimensions; ++i) { + buffer[i] = flat_vector->Get(i); + } + } + } + return OkStatus(); +} + +// Converts the flatbuffer activation to what is used at runtime. +TfLiteFusedActivation ConvertActivation(ActivationFunctionType activation) { + switch (activation) { + case ActivationFunctionType_NONE: + return kTfLiteActNone; + case ActivationFunctionType_RELU: + return kTfLiteActRelu; + case ActivationFunctionType_RELU_N1_TO_1: + return kTfLiteActReluN1To1; + case ActivationFunctionType_RELU6: + return kTfLiteActRelu6; + case ActivationFunctionType_TANH: + return kTfLiteActTanh; + case ActivationFunctionType_SIGN_BIT: + return kTfLiteActSignBit; + } + return kTfLiteActNone; +} + +TfLitePadding ConvertPadding(Padding padding) { + switch (padding) { + case Padding_SAME: + return kTfLitePaddingSame; + case Padding_VALID: + return kTfLitePaddingValid; + } + return kTfLitePaddingUnknown; +} + +// Converts the flatbuffer mirror padding enum to what is used at runtime. +TfLiteMirrorPaddingMode ConvertMirrorPadding(MirrorPadMode padding) { + switch (padding) { + case MirrorPadMode_REFLECT: + return kTfLiteMirrorPaddingReflect; + case MirrorPadMode_SYMMETRIC: + return kTfLiteMirrorPaddingSymmetric; + } + return kTfLiteMirrorPaddingUnknown; +} + +TfLiteRngAlgorithm ConvertRngAlgorithm(RngAlgorithm algorithm) { + switch (algorithm) { + case RngAlgorithm_THREEFRY: + return kTfLiteRngAlgorithmThreefry; + case RngAlgorithm_PHILOX: + return kTfLiteRngAlgorithmPhilox; + case RngAlgorithm_DEFAULT: + return kTfLiteRngAlgorithmDefault; + } + return kTfLiteRngAlgorithmUnknown; +} + +#ifndef TF_LITE_STATIC_MEMORY +absl::Status ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type, + BuiltinDataAllocator* allocator, + void** builtin_data) { + auto parseLSHProjectionType = [](LSHProjectionType type) { + switch (type) { + case LSHProjectionType_SPARSE: + return kTfLiteLshProjectionSparse; + case LSHProjectionType_DENSE: + return kTfLiteLshProjectionDense; + default: + return kTfLiteLshProjectionUnknown; + } + }; + auto parseCombinerType = [](CombinerType type) { + switch (type) { + case CombinerType_MEAN: + return kTfLiteCombinerTypeMean; + case CombinerType_SQRTN: + return kTfLiteCombinerTypeSqrtn; + case CombinerType_SUM: + default: + return kTfLiteCombinerTypeSum; + } + }; + + SafeBuiltinDataAllocator safe_allocator(allocator); + *builtin_data = nullptr; + switch (op_type) { + case BuiltinOperator_ABS: { + return ParseAbs(op, allocator, builtin_data); + } + + case BuiltinOperator_ADD: { + return ParseAdd(op, allocator, builtin_data); + } + + case BuiltinOperator_ADD_N: { + return ParseAddN(op, allocator, builtin_data); + } + + case BuiltinOperator_ARG_MAX: { + return ParseArgMax(op, allocator, builtin_data); + } + + case BuiltinOperator_ARG_MIN: { + return ParseArgMin(op, allocator, builtin_data); + } + + case BuiltinOperator_ASSIGN_VARIABLE: { + return ParseAssignVariable(op, allocator, builtin_data); + } + + case BuiltinOperator_AVERAGE_POOL_2D: { + return ParsePool(op, allocator, builtin_data); + } + + case BuiltinOperator_BATCH_MATMUL: { + return ParseBatchMatMul(op, allocator, builtin_data); + } + + case BuiltinOperator_BATCH_TO_SPACE_ND: { + return ParseBatchToSpaceNd(op, allocator, builtin_data); + } + + case BuiltinOperator_BROADCAST_ARGS: { + return ParseBroadcastArgs(op, allocator, builtin_data); + } + + case BuiltinOperator_BROADCAST_TO: { + return ParseBroadcastTo(op, allocator, builtin_data); + } + + case BuiltinOperator_CALL_ONCE: { + return ParseCallOnce(op, allocator, builtin_data); + } + + case BuiltinOperator_CEIL: { + return ParseCeil(op, allocator, builtin_data); + } + + case BuiltinOperator_CONCATENATION: { + return ParseConcatenation(op, allocator, builtin_data); + } + + case BuiltinOperator_CONV_2D: { + return ParseConv2D(op, allocator, builtin_data); + } + + case BuiltinOperator_CUMSUM: { + return ParseCumsum(op, allocator, builtin_data); + } + + case BuiltinOperator_DEPTH_TO_SPACE: { + return ParseDepthToSpace(op, allocator, builtin_data); + } + + case BuiltinOperator_DEPTHWISE_CONV_2D: { + return ParseDepthwiseConv2D(op, allocator, builtin_data); + } + + case BuiltinOperator_DEQUANTIZE: { + return ParseDequantize(op, allocator, builtin_data); + } + + case BuiltinOperator_DIV: { + return ParseDiv(op, allocator, builtin_data); + } + + case BuiltinOperator_ELU: { + return ParseElu(op, allocator, builtin_data); + } + + case BuiltinOperator_EMBEDDING_LOOKUP: { + return ParseEmbeddingLookup(op, allocator, builtin_data); + } + + case BuiltinOperator_EXP: { + return ParseExp(op, allocator, builtin_data); + } + + case BuiltinOperator_EXPAND_DIMS: { + return ParseExpandDims(op, allocator, builtin_data); + } + + case BuiltinOperator_FILL: { + return ParseFill(op, allocator, builtin_data); + } + + case BuiltinOperator_FLOOR: { + return ParseFloor(op, allocator, builtin_data); + } + + case BuiltinOperator_FLOOR_DIV: { + return ParseFloorDiv(op, allocator, builtin_data); + } + + case BuiltinOperator_FLOOR_MOD: { + return ParseFloorMod(op, allocator, builtin_data); + } + + case BuiltinOperator_FULLY_CONNECTED: { + return ParseFullyConnected(op, allocator, builtin_data); + } + + case BuiltinOperator_GATHER_ND: { + return ParseGatherNd(op, allocator, builtin_data); + } + + case BuiltinOperator_GREATER: { + return ParseGreater(op, allocator, builtin_data); + } + + case BuiltinOperator_GREATER_EQUAL: { + return ParseGreaterEqual(op, allocator, builtin_data); + } + + case BuiltinOperator_HARD_SWISH: { + return ParseHardSwish(op, allocator, builtin_data); + } + + case BuiltinOperator_L2_NORMALIZATION: { + return ParseL2Normalization(op, allocator, builtin_data); + } + + case BuiltinOperator_L2_POOL_2D: { + return ParsePool(op, allocator, builtin_data); + } + + case BuiltinOperator_LEAKY_RELU: { + return ParseLeakyRelu(op, allocator, builtin_data); + } + + case BuiltinOperator_LESS: { + return ParseLess(op, allocator, builtin_data); + } + + case BuiltinOperator_LESS_EQUAL: { + return ParseLessEqual(op, allocator, builtin_data); + } + + case BuiltinOperator_LOG: { + return ParseLog(op, allocator, builtin_data); + } + + case BuiltinOperator_LOGICAL_AND: { + return ParseLogicalAnd(op, allocator, builtin_data); + } + + case BuiltinOperator_LOGICAL_NOT: { + return ParseLogicalNot(op, allocator, builtin_data); + } + + case BuiltinOperator_LOGICAL_OR: { + return ParseLogicalOr(op, allocator, builtin_data); + } + + case BuiltinOperator_LOGISTIC: { + return ParseLogistic(op, allocator, builtin_data); + } + + case BuiltinOperator_LOG_SOFTMAX: { + return ParseLogSoftmax(op, allocator, builtin_data); + } + + case BuiltinOperator_LSTM: { + return ParseLSTM(op, allocator, builtin_data); + } + + case BuiltinOperator_MAXIMUM: { + return ParseMaximum(op, allocator, builtin_data); + } + + case BuiltinOperator_MAX_POOL_2D: { + return ParsePool(op, allocator, builtin_data); + } + + case BuiltinOperator_MIRROR_PAD: { + return ParseMirrorPad(op, allocator, builtin_data); + } + + case BuiltinOperator_MEAN: { + return ParseReducer(op, allocator, builtin_data); + } + + case BuiltinOperator_MINIMUM: { + return ParseMinimum(op, allocator, builtin_data); + } + + case BuiltinOperator_MUL: { + return ParseMul(op, allocator, builtin_data); + } + + case BuiltinOperator_NEG: { + return ParseNeg(op, allocator, builtin_data); + } + + case BuiltinOperator_NOT_EQUAL: { + return ParseNotEqual(op, allocator, builtin_data); + } + + case BuiltinOperator_PACK: { + return ParsePack(op, allocator, builtin_data); + } + + case BuiltinOperator_PAD: { + return ParsePad(op, allocator, builtin_data); + } + + case BuiltinOperator_PADV2: { + return ParsePadV2(op, allocator, builtin_data); + } + + case BuiltinOperator_POW: { + return ParsePow(op, allocator, builtin_data); + } + + case BuiltinOperator_PRELU: { + return ParsePrelu(op, allocator, builtin_data); + } + + case BuiltinOperator_QUANTIZE: { + return ParseQuantize(op, allocator, builtin_data); + } + + case BuiltinOperator_READ_VARIABLE: { + return ParseReadVariable(op, allocator, builtin_data); + } + + case BuiltinOperator_REDUCE_ANY: { + return ParseReducer(op, allocator, builtin_data); + } + + case BuiltinOperator_REDUCE_ALL: { + return ParseReducer(op, allocator, builtin_data); + } + + case BuiltinOperator_REDUCE_MAX: { + return ParseReducer(op, allocator, builtin_data); + } + + case BuiltinOperator_REDUCE_MIN: { + return ParseReducer(op, allocator, builtin_data); + } + + case BuiltinOperator_REDUCE_PROD: { + return ParseReducer(op, allocator, builtin_data); + } + + case BuiltinOperator_RELU: { + return ParseRelu(op, allocator, builtin_data); + } + + case BuiltinOperator_RELU6: { + return ParseRelu6(op, allocator, builtin_data); + } + + case BuiltinOperator_RESHAPE: { + return ParseReshape(op, allocator, builtin_data); + } + + case BuiltinOperator_RESIZE_BILINEAR: { + return ParseResizeBilinear(op, allocator, builtin_data); + } + + case BuiltinOperator_RESIZE_NEAREST_NEIGHBOR: { + return ParseResizeNearestNeighbor(op, allocator, builtin_data); + } + + case BuiltinOperator_ROUND: { + return ParseRound(op, allocator, builtin_data); + } + + case BuiltinOperator_RSQRT: { + return ParseRsqrt(op, allocator, builtin_data); + } + + case BuiltinOperator_SELECT_V2: { + return ParseSelectV2(op, allocator, builtin_data); + } + + case BuiltinOperator_SHAPE: { + return ParseShape(op, allocator, builtin_data); + } + + case BuiltinOperator_SIN: { + return ParseSin(op, allocator, builtin_data); + } + + case BuiltinOperator_SOFTMAX: { + return ParseSoftmax(op, allocator, builtin_data); + } + + case BuiltinOperator_SPACE_TO_BATCH_ND: { + return ParseSpaceToBatchNd(op, allocator, builtin_data); + } + + case BuiltinOperator_SPACE_TO_DEPTH: { + return ParseSpaceToDepth(op, allocator, builtin_data); + } + + case BuiltinOperator_SPLIT: { + return ParseSplit(op, allocator, builtin_data); + } + + case BuiltinOperator_SPLIT_V: { + return ParseSplitV(op, allocator, builtin_data); + } + + case BuiltinOperator_SQRT: { + return ParseSqrt(op, allocator, builtin_data); + } + + case BuiltinOperator_SQUARE: { + return ParseSquare(op, allocator, builtin_data); + } + + case BuiltinOperator_SQUARED_DIFFERENCE: { + return ParseSquaredDifference(op, allocator, builtin_data); + } + + case BuiltinOperator_SQUEEZE: { + return ParseSqueeze(op, allocator, builtin_data); + } + + case BuiltinOperator_STRIDED_SLICE: { + return ParseStridedSlice(op, allocator, builtin_data); + } + + case BuiltinOperator_SUB: { + return ParseSub(op, allocator, builtin_data); + } + + case BuiltinOperator_SUM: { + return ParseReducer(op, allocator, builtin_data); + } + + case BuiltinOperator_SVDF: { + return ParseSvdf(op, allocator, builtin_data); + } + + case BuiltinOperator_TANH: { + return ParseTanh(op, allocator, builtin_data); + } + + case BuiltinOperator_TRANSPOSE_CONV: { + return ParseTransposeConv(op, allocator, builtin_data); + } + + case BuiltinOperator_UNPACK: { + return ParseUnpack(op, allocator, builtin_data); + } + + case BuiltinOperator_VAR_HANDLE: { + return ParseVarHandle(op, allocator, builtin_data); + } + + case BuiltinOperator_ZEROS_LIKE: { + return ParseZerosLike(op, allocator, builtin_data); + } + + case BuiltinOperator_BITWISE_XOR: { + return ParseBitwiseXor(op, allocator, builtin_data); + } + + case BuiltinOperator_RIGHT_SHIFT: { + return ParseRightShift(op, allocator, builtin_data); + } + + case BuiltinOperator_CAST: { + return ParseCast(op, allocator, builtin_data); + } + case BuiltinOperator_LSH_PROJECTION: { + auto params = safe_allocator.Allocate(); + TFL_MIGRATION_ENSURE(params != nullptr); + if (const auto* lshParams = + op->builtin_options_as_LSHProjectionOptions()) { + params->type = parseLSHProjectionType(lshParams->type()); + } + *builtin_data = params.release(); + return OkStatus(); + } + case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN: { + auto params = safe_allocator.Allocate(); + TFL_MIGRATION_ENSURE(params != nullptr); + if (const auto* sequence_rnn_params = + op->builtin_options_as_SequenceRNNOptions()) { + params->activation = + ConvertActivation(sequence_rnn_params->fused_activation_function()); + params->time_major = sequence_rnn_params->time_major(); + params->asymmetric_quantize_inputs = + sequence_rnn_params->asymmetric_quantize_inputs(); + } + *builtin_data = params.release(); + return OkStatus(); + } + case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN: { + auto params = + safe_allocator.Allocate(); + TFL_MIGRATION_ENSURE(params != nullptr); + if (const auto* bidi_sequence_rnn_params = + op->builtin_options_as_BidirectionalSequenceRNNOptions()) { + params->activation = ConvertActivation( + bidi_sequence_rnn_params->fused_activation_function()); + params->time_major = bidi_sequence_rnn_params->time_major(); + params->merge_outputs = bidi_sequence_rnn_params->merge_outputs(); + params->asymmetric_quantize_inputs = + bidi_sequence_rnn_params->asymmetric_quantize_inputs(); + } + *builtin_data = params.release(); + return OkStatus(); + } + case BuiltinOperator_RNN: { + auto params = safe_allocator.Allocate(); + TFL_MIGRATION_ENSURE(params != nullptr); + if (const auto* rnn_params = op->builtin_options_as_RNNOptions()) { + params->activation = + ConvertActivation(rnn_params->fused_activation_function()); + params->asymmetric_quantize_inputs = + rnn_params->asymmetric_quantize_inputs(); + } + *builtin_data = params.release(); + return OkStatus(); + } + case BuiltinOperator_EMBEDDING_LOOKUP_SPARSE: { + auto params = + safe_allocator.Allocate(); + TFL_MIGRATION_ENSURE(params != nullptr); + if (const auto* embedding_params = + op->builtin_options_as_EmbeddingLookupSparseOptions()) { + params->combiner = parseCombinerType(embedding_params->combiner()); + } + *builtin_data = params.release(); + return OkStatus(); + } + + case BuiltinOperator_HASHTABLE_LOOKUP: + // no-op. + return OkStatus(); + + case BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION: { + auto params = safe_allocator.Allocate(); + TFL_MIGRATION_ENSURE(params != nullptr); + if (const auto* schema_params = + op->builtin_options_as_LocalResponseNormalizationOptions()) { + params->radius = schema_params->radius(); + params->bias = schema_params->bias(); + params->alpha = schema_params->alpha(); + params->beta = schema_params->beta(); + } + *builtin_data = params.release(); + return OkStatus(); + } + case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM: { + return ParseUnidirectionalSequenceLSTM(op, allocator, builtin_data); + } + case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM: { + auto params = + safe_allocator.Allocate(); + TFL_MIGRATION_ENSURE(params != nullptr); + if (const auto* bidi_lstm_params = + op->builtin_options_as_BidirectionalSequenceLSTMOptions()) { + params->activation = + ConvertActivation(bidi_lstm_params->fused_activation_function()); + params->cell_clip = bidi_lstm_params->cell_clip(); + params->proj_clip = bidi_lstm_params->proj_clip(); + params->merge_outputs = bidi_lstm_params->merge_outputs(); + params->time_major = bidi_lstm_params->time_major(); + params->asymmetric_quantize_inputs = + bidi_lstm_params->asymmetric_quantize_inputs(); + } + *builtin_data = params.release(); + return OkStatus(); + } + case BuiltinOperator_SKIP_GRAM: { + auto params = safe_allocator.Allocate(); + TFL_MIGRATION_ENSURE(params != nullptr); + if (const auto* skip_gram_params = + op->builtin_options_as_SkipGramOptions()) { + params->ngram_size = skip_gram_params->ngram_size(); + params->max_skip_size = skip_gram_params->max_skip_size(); + params->include_all_ngrams = skip_gram_params->include_all_ngrams(); + } + *builtin_data = params.release(); + return OkStatus(); + } + + case BuiltinOperator_GATHER: { + return ParseGather(op, allocator, builtin_data); + } + case BuiltinOperator_SPARSE_TO_DENSE: { + auto params = safe_allocator.Allocate(); + TFL_MIGRATION_ENSURE(params != nullptr); + if (const auto* sparse_to_dense_params = + op->builtin_options_as_SparseToDenseOptions()) { + params->validate_indices = sparse_to_dense_params->validate_indices(); + } + *builtin_data = params.release(); + return OkStatus(); + } + case BuiltinOperator_DELEGATE: { + auto error_msg = "DELEGATE op shouldn't exist in model."; + LOG(ERROR) << error_msg; + return absl::InvalidArgumentError(error_msg); + } + case BuiltinOperator_FAKE_QUANT: { + auto params = safe_allocator.Allocate(); + TFL_MIGRATION_ENSURE(params != nullptr); + if (const auto* schema_params = + op->builtin_options_as_FakeQuantOptions()) { + params->min = schema_params->min(); + params->max = schema_params->max(); + params->num_bits = schema_params->num_bits(); + params->narrow_range = schema_params->narrow_range(); + } + *builtin_data = params.release(); + return OkStatus(); + } + case BuiltinOperator_ONE_HOT: { + auto params = safe_allocator.Allocate(); + TFL_MIGRATION_ENSURE(params != nullptr); + if (const auto* schema_params = op->builtin_options_as_OneHotOptions()) { + params->axis = schema_params->axis(); + } + *builtin_data = params.release(); + return OkStatus(); + } + case BuiltinOperator_UNIQUE: { + auto params = safe_allocator.Allocate(); + TFL_MIGRATION_ENSURE(params != nullptr); + const auto* unique_params = op->builtin_options_as_UniqueOptions(); + if (unique_params != nullptr) { + params->index_out_type = + unique_params->idx_out_type() == tflite::TensorType_INT64 + ? TfLiteType::kTfLiteInt64 + : TfLiteType::kTfLiteInt32; + } + *builtin_data = params.release(); + return OkStatus(); + } + case BuiltinOperator_REVERSE_SEQUENCE: { + auto params = safe_allocator.Allocate(); + TFL_MIGRATION_ENSURE(params != nullptr); + if (const auto* reverse_seq_params = + op->builtin_options_as_ReverseSequenceOptions()) { + params->seq_dim = reverse_seq_params->seq_dim(); + params->batch_dim = reverse_seq_params->batch_dim(); + } + *builtin_data = params.release(); + return OkStatus(); + } + case BuiltinOperator_IF: { + auto params = safe_allocator.Allocate(); + TFL_MIGRATION_ENSURE(params != nullptr); + if (const auto* if_params = op->builtin_options_as_IfOptions()) { + params->then_subgraph_index = if_params->then_subgraph_index(); + params->else_subgraph_index = if_params->else_subgraph_index(); + } + *builtin_data = params.release(); + return OkStatus(); + } + case BuiltinOperator_WHILE: { + auto params = safe_allocator.Allocate(); + TFL_MIGRATION_ENSURE(params != nullptr); + if (const auto* while_params = op->builtin_options_as_WhileOptions()) { + params->cond_subgraph_index = while_params->cond_subgraph_index(); + params->body_subgraph_index = while_params->body_subgraph_index(); + } + *builtin_data = params.release(); + return OkStatus(); + } + case BuiltinOperator_CONV_3D: + case BuiltinOperator_CONV_3D_TRANSPOSE: { + auto params = safe_allocator.Allocate(); + TFL_MIGRATION_ENSURE(params != nullptr); + if (const auto* conv3d_params = op->builtin_options_as_Conv3DOptions()) { + params->padding = ConvertPadding(conv3d_params->padding()); + params->activation = + ConvertActivation(conv3d_params->fused_activation_function()); + params->stride_depth = conv3d_params->stride_d(); + params->stride_height = conv3d_params->stride_h(); + params->stride_width = conv3d_params->stride_w(); + params->dilation_depth_factor = conv3d_params->dilation_d_factor(); + params->dilation_height_factor = conv3d_params->dilation_h_factor(); + params->dilation_width_factor = conv3d_params->dilation_w_factor(); + } + *builtin_data = params.release(); + return OkStatus(); + } + case BuiltinOperator_HASHTABLE: { + auto params = safe_allocator.Allocate(); + TFL_MIGRATION_ENSURE(params != nullptr); + if (const auto* hashtable_params = + op->builtin_options_as_HashtableOptions()) { + params->table_id = hashtable_params->table_id(); + TFL_MIGRATION_ENSURE_STATUS(ConvertTensorType( + hashtable_params->key_dtype(), ¶ms->key_dtype)); + TFL_MIGRATION_ENSURE_STATUS(ConvertTensorType( + hashtable_params->value_dtype(), ¶ms->value_dtype)); + } + *builtin_data = params.release(); + return OkStatus(); + } + case BuiltinOperator_MULTINOMIAL: { + auto params = safe_allocator.Allocate(); + TFL_MIGRATION_ENSURE(params != nullptr); + if (const auto* multinomial_params = + op->builtin_options_as_RandomOptions()) { + params->seed = multinomial_params->seed(); + params->seed2 = multinomial_params->seed2(); + } + *builtin_data = params.release(); + return OkStatus(); + } + case BuiltinOperator_RANDOM_STANDARD_NORMAL: { + auto params = safe_allocator.Allocate(); + TFL_MIGRATION_ENSURE(params != nullptr); + if (const auto* random_std_normal_params = + op->builtin_options_as_RandomOptions()) { + params->seed = random_std_normal_params->seed(); + params->seed2 = random_std_normal_params->seed2(); + } + *builtin_data = params.release(); + return OkStatus(); + } + case BuiltinOperator_BUCKETIZE: { + auto params = safe_allocator.Allocate(); + TFL_MIGRATION_ENSURE(params != nullptr); + if (const auto* bucketize_params = + op->builtin_options_as_BucketizeOptions()) { + const flatbuffers::Vector* boundaries = + bucketize_params->boundaries(); + if (boundaries == nullptr) { + auto error_message = + "boundaries array not provided for operation 'bucketize'.\n"; + LOG(ERROR) << error_message; + return absl::InvalidArgumentError(error_message); + } + params->num_boundaries = boundaries->size(); + if (boundaries->data() == nullptr) { + auto error_message = + "boundaries.data() returned nullptr for " + "operation 'bucketize'.\n"; + LOG(ERROR) << error_message; + return absl::InvalidArgumentError(error_message); + } + params->boundaries = boundaries->data(); + } + *builtin_data = params.release(); + return OkStatus(); + } + case BuiltinOperator_RANDOM_UNIFORM: { + auto params = safe_allocator.Allocate(); + TFL_MIGRATION_ENSURE(params != nullptr); + if (const auto* random_uniform_params = + op->builtin_options_as_RandomOptions()) { + params->seed = random_uniform_params->seed(); + params->seed2 = random_uniform_params->seed2(); + } + *builtin_data = params.release(); + return OkStatus(); + } + case BuiltinOperator_GELU: { + auto params = safe_allocator.Allocate(); + TFL_MIGRATION_ENSURE(params != nullptr); + if (const auto* gelu_params = op->builtin_options_as_GeluOptions()) { + params->approximate = gelu_params->approximate(); + } + *builtin_data = params.release(); + return OkStatus(); + } + case BuiltinOperator_STABLEHLO_SCATTER: { + return ParseStablehloScatter(op, allocator, builtin_data); + } + case BuiltinOperator_STABLEHLO_RNG_BIT_GENERATOR: { + return ParseStablehloRngBitGenerator(op, allocator, builtin_data); + } + case BuiltinOperator_STABLEHLO_GATHER: { + return ParseStablehloGather(op, allocator, builtin_data); + } + case BuiltinOperator_STABLEHLO_REDUCE_WINDOW: { + return ParseStablehloReduceWindow(op, allocator, builtin_data); + } + case BuiltinOperator_REDUCE_WINDOW: { + auto params = safe_allocator.Allocate(); + TFL_MIGRATION_ENSURE(params != nullptr); + if (const auto* reduce_params = + op->builtin_options_2_as_ReduceWindowOptions()) { + switch (reduce_params->reduce_function()) { + case ReduceWindowFunction_ADD: + params->reduce_function = TfLiteReduceWindowFunctionAdd; + break; + case ReduceWindowFunction_MUL: + params->reduce_function = TfLiteReduceWindowFunctionMul; + break; + case ReduceWindowFunction_MINIMUM: + params->reduce_function = TfLiteReduceWindowFunctionMin; + break; + case ReduceWindowFunction_MAXIMUM: + params->reduce_function = TfLiteReduceWindowFunctionMax; + break; + case ReduceWindowFunction_ALL: + params->reduce_function = TfLiteReduceWindowFunctionAll; + break; + case ReduceWindowFunction_ANY: + params->reduce_function = TfLiteReduceWindowFunctionAny; + break; + case ReduceWindowFunction_UNSUPPORTED: + default: + return absl::InvalidArgumentError("Unsupported reduce function"); + } + } + *builtin_data = params.release(); + return OkStatus(); + } + case BuiltinOperator_STABLEHLO_PAD: { + return ParseStablehloPad(op, allocator, builtin_data); + } + case BuiltinOperator_STABLEHLO_COMPOSITE: { + return ParseStablehloComposite(op, allocator, builtin_data); + } + case BuiltinOperator_STABLEHLO_SHIFT_LEFT: { + return ParseStablehloShiftLeft(op, allocator, builtin_data); + } + // TODO: skip param parsing for now since ops below don't have kernels + case BuiltinOperator_STABLEHLO_SLICE: + case BuiltinOperator_STABLEHLO_BROADCAST_IN_DIM: + case BuiltinOperator_STABLEHLO_CONVOLUTION: + case BuiltinOperator_STABLEHLO_LOGISTIC: + case BuiltinOperator_STABLEHLO_ADD: + case BuiltinOperator_STABLEHLO_DIVIDE: + case BuiltinOperator_STABLEHLO_MULTIPLY: + case BuiltinOperator_STABLEHLO_MAXIMUM: + case BuiltinOperator_STABLEHLO_RESHAPE: + case BuiltinOperator_STABLEHLO_CLAMP: + case BuiltinOperator_STABLEHLO_CONCATENATE: + case BuiltinOperator_STABLEHLO_CUSTOM_CALL: + case BuiltinOperator_STABLEHLO_REDUCE: + case BuiltinOperator_STABLEHLO_ABS: + case BuiltinOperator_STABLEHLO_AND: + case BuiltinOperator_STABLEHLO_COSINE: + case BuiltinOperator_STABLEHLO_EXPONENTIAL: + case BuiltinOperator_STABLEHLO_FLOOR: + case BuiltinOperator_STABLEHLO_LOG: + case BuiltinOperator_STABLEHLO_MINIMUM: + case BuiltinOperator_STABLEHLO_NEGATE: + case BuiltinOperator_STABLEHLO_OR: + case BuiltinOperator_STABLEHLO_POWER: + case BuiltinOperator_STABLEHLO_REMAINDER: + case BuiltinOperator_STABLEHLO_RSQRT: + case BuiltinOperator_STABLEHLO_SELECT: + case BuiltinOperator_STABLEHLO_SUBTRACT: + case BuiltinOperator_STABLEHLO_TANH: + case BuiltinOperator_STABLEHLO_DYNAMIC_SLICE: + case BuiltinOperator_STABLEHLO_DYNAMIC_UPDATE_SLICE: + case BuiltinOperator_STABLEHLO_IOTA: + case BuiltinOperator_STABLEHLO_COMPARE: + case BuiltinOperator_STABLEHLO_CONVERT: + case BuiltinOperator_STABLEHLO_DOT_GENERAL: + case BuiltinOperator_STABLEHLO_SORT: + case BuiltinOperator_STABLEHLO_WHILE: + case BuiltinOperator_STABLEHLO_TRANSPOSE: + case BuiltinOperator_STABLEHLO_CBRT: + + // Below are the ops with no builtin_data structure. + // TODO(aselle): Implement call in BuiltinOptions, but nullptrs are + // ok for now, since there is no call implementation either. + case BuiltinOperator_CALL: + case BuiltinOperator_COMPLEX_ABS: + case BuiltinOperator_CONCAT_EMBEDDINGS: + case BuiltinOperator_COS: + case BuiltinOperator_CUSTOM: + case BuiltinOperator_DENSIFY: + case BuiltinOperator_DYNAMIC_UPDATE_SLICE: + case BuiltinOperator_EQUAL: + case BuiltinOperator_HASHTABLE_FIND: + case BuiltinOperator_HASHTABLE_IMPORT: + case BuiltinOperator_HASHTABLE_SIZE: + case BuiltinOperator_IMAG: + case BuiltinOperator_MATRIX_DIAG: + case BuiltinOperator_MATRIX_SET_DIAG: + case BuiltinOperator_NON_MAX_SUPPRESSION_V4: + case BuiltinOperator_NON_MAX_SUPPRESSION_V5: + case BuiltinOperator_RELU_N1_TO_1: + case BuiltinOperator_RELU_0_TO_1: + case BuiltinOperator_SCATTER_ND: + case BuiltinOperator_SELECT: + case BuiltinOperator_SLICE: + case BuiltinOperator_TILE: + case BuiltinOperator_TOPK_V2: + case BuiltinOperator_TRANSPOSE: + case BuiltinOperator_RANGE: + case BuiltinOperator_RANK: + case BuiltinOperator_REAL: + case BuiltinOperator_RFFT2D: + case BuiltinOperator_SEGMENT_SUM: + case BuiltinOperator_REVERSE_V2: + case BuiltinOperator_UNSORTED_SEGMENT_MAX: + case BuiltinOperator_UNSORTED_SEGMENT_MIN: + case BuiltinOperator_UNSORTED_SEGMENT_PROD: + case BuiltinOperator_UNSORTED_SEGMENT_SUM: + case BuiltinOperator_ATAN2: + case BuiltinOperator_SIGN: + case BuiltinOperator_BITCAST: + case BuiltinOperator_WHERE: + case BuiltinOperator_DILATE: + return OkStatus(); + case BuiltinOperator_PLACEHOLDER_FOR_GREATER_OP_CODES: + return absl::UnimplementedError("Unsupported op"); + } + return absl::UnimplementedError("Unsupported op"); +} // NOLINT[readability/fn_size] +#endif // !defined(TF_LITE_STATIC_MEMORY) +} // namespace + +absl::Status ConvertTensorType(TensorType tensor_type, TfLiteType* type) { + switch (tensor_type) { + case TensorType_FLOAT16: + *type = kTfLiteFloat16; + return OkStatus(); + case TensorType_BFLOAT16: + *type = kTfLiteBFloat16; + return OkStatus(); + case TensorType_FLOAT32: + *type = kTfLiteFloat32; + return OkStatus(); + case TensorType_FLOAT64: + *type = kTfLiteFloat64; + return OkStatus(); + case TensorType_INT16: + *type = kTfLiteInt16; + return OkStatus(); + case TensorType_UINT16: + *type = kTfLiteUInt16; + return OkStatus(); + case TensorType_INT32: + *type = kTfLiteInt32; + return OkStatus(); + case TensorType_UINT32: + *type = kTfLiteUInt32; + return OkStatus(); + case TensorType_UINT8: + *type = kTfLiteUInt8; + return OkStatus(); + case TensorType_INT8: + *type = kTfLiteInt8; + return OkStatus(); + case TensorType_INT64: + *type = kTfLiteInt64; + return OkStatus(); + case TensorType_UINT64: + *type = kTfLiteUInt64; + return OkStatus(); + case TensorType_STRING: + *type = kTfLiteString; + return OkStatus(); + case TensorType_BOOL: + *type = kTfLiteBool; + return OkStatus(); + case TensorType_COMPLEX64: + *type = kTfLiteComplex64; + return OkStatus(); + case TensorType_COMPLEX128: + *type = kTfLiteComplex128; + return OkStatus(); + case TensorType_RESOURCE: + *type = kTfLiteResource; + return OkStatus(); + case TensorType_VARIANT: + *type = kTfLiteVariant; + return OkStatus(); + case TensorType_INT4: + *type = kTfLiteInt4; + return OkStatus(); + default: + *type = kTfLiteNoType; + auto error_message = + absl::StrFormat("Unsupported data type %d in tensor", tensor_type); + LOG(ERROR) << error_message; + return absl::InvalidArgumentError(error_message); + } +} + +// We have this parse function instead of directly returning OkStatus() from the +// switch-case in ParseOpData because this function is used as part of the +// selective registration for the OpResolver implementation in micro. +absl::Status ParseAbs(const Operator*, BuiltinDataAllocator*, void**) { + return OkStatus(); +} + +absl::Status ParseAdd(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data) { + CheckParsePointerParams(op, allocator, builtin_data); + + SafeBuiltinDataAllocator safe_allocator(allocator); + std::unique_ptr + params = safe_allocator.Allocate(); + TFL_MIGRATION_ENSURE(params != nullptr); + + const AddOptions* schema_params = op->builtin_options_as_AddOptions(); + + if (schema_params != nullptr) { + params->activation = + ConvertActivation(schema_params->fused_activation_function()); + params->pot_scale_int16 = schema_params->pot_scale_int16(); + } else { + // TODO(b/157480169): We should either return kTfLiteError or fill in some + // reasonable defaults in the params struct. We are not doing so until we + // better understand the ramifications of changing the legacy behavior. + } + + *builtin_data = params.release(); + return OkStatus(); +} + +absl::Status ParseAddN(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data) { + return OkStatus(); +} + +absl::Status ParseArgMax(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data) { + CheckParsePointerParams(op, allocator, builtin_data); + + SafeBuiltinDataAllocator safe_allocator(allocator); + std::unique_ptr + params = safe_allocator.Allocate(); + TFL_MIGRATION_ENSURE(params != nullptr); + + const ArgMaxOptions* schema_params = op->builtin_options_as_ArgMaxOptions(); + + if (schema_params != nullptr) { + TFL_MIGRATION_ENSURE_STATUS( + ConvertTensorType(schema_params->output_type(), ¶ms->output_type)); + } else { + // TODO(b/157480169): We should either return kTfLiteError or fill in some + // reasonable defaults in the params struct. We are not doing so until we + // better understand the ramifications of changing the legacy behavior. + } + + *builtin_data = params.release(); + return OkStatus(); +} + +absl::Status ParseArgMin(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data) { + CheckParsePointerParams(op, allocator, builtin_data); + + SafeBuiltinDataAllocator safe_allocator(allocator); + std::unique_ptr + params = safe_allocator.Allocate(); + TFL_MIGRATION_ENSURE(params != nullptr); + + const ArgMinOptions* schema_params = op->builtin_options_as_ArgMinOptions(); + + if (schema_params != nullptr) { + TFL_MIGRATION_ENSURE_STATUS( + ConvertTensorType(schema_params->output_type(), ¶ms->output_type)); + } else { + // TODO(b/157480169): We should either return kTfLiteError or fill in some + // reasonable defaults in the params struct. We are not doing so until we + // better understand the ramifications of changing the legacy behavior. + } + + *builtin_data = params.release(); + return OkStatus(); +} + +// We have this parse function instead of directly returning OkStatus() from the +// switch-case in ParseOpData because this function is used as part of the +// selective registration for the OpResolver implementation in micro. +absl::Status ParseAssignVariable(const Operator*, BuiltinDataAllocator*, + void**) { + return OkStatus(); +} + +// We have this parse function instead of directly returning OkStatus() from the +// switch-case in ParseOpData because this function is used as part of the +// selective registration for the OpResolver implementation in micro. +absl::Status ParseBatchMatMul(const Operator* op, + BuiltinDataAllocator* allocator, + void** builtin_data) { + CheckParsePointerParams(op, allocator, builtin_data); + + SafeBuiltinDataAllocator safe_allocator(allocator); + auto params = safe_allocator.Allocate(); + TFL_MIGRATION_ENSURE(params != nullptr); + if (const auto* bmm_params = op->builtin_options_as_BatchMatMulOptions()) { + params->adj_x = bmm_params->adj_x(); + params->adj_y = bmm_params->adj_y(); + params->asymmetric_quantize_inputs = + bmm_params->asymmetric_quantize_inputs(); + } + *builtin_data = params.release(); + return OkStatus(); +} + +// We have this parse function instead of directly returning OkStatus() from the +// switch-case in ParseOpData because this function is used as part of the +// selective registration for the OpResolver implementation in micro. +absl::Status ParseBatchToSpaceNd(const Operator*, BuiltinDataAllocator*, + void**) { + return OkStatus(); +} + +// We have this parse function instead of directly returning OkStatus() from the +// switch-case in ParseOpData because this function is used as part of the +// selective registration for the OpResolver implementation in micro. +absl::Status ParseBroadcastArgs(const Operator*, BuiltinDataAllocator*, + void**) { + return OkStatus(); +} + +// We have this parse function instead of directly returning OkStatus() from the +// switch-case in ParseOpData because this function is used as part of the +// selective registration for the OpResolver implementation in micro. +absl::Status ParseBroadcastTo(const Operator*, BuiltinDataAllocator*, void**) { + return OkStatus(); +} + +absl::Status ParseCallOnce(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data) { + CheckParsePointerParams(op, allocator, builtin_data); + + SafeBuiltinDataAllocator safe_allocator(allocator); + std::unique_ptr + params = safe_allocator.Allocate(); + TFL_MIGRATION_ENSURE(params != nullptr); + + const CallOnceOptions* schema_params = + op->builtin_options_as_CallOnceOptions(); + + if (schema_params != nullptr) { + params->init_subgraph_index = schema_params->init_subgraph_index(); + + } else { + // TODO(b/157480169): We should either return kTfLiteError or fill in some + // reasonable defaults in the params struct. We are not doing so until we + // better understand the ramifications of changing the legacy behavior. + } + + *builtin_data = params.release(); + return OkStatus(); +} + +// We have this parse function instead of directly returning OkStatus() from the +// switch-case in ParseOpData because this function is used as part of the +// selective registration for the OpResolver implementation in micro. +absl::Status ParseCast(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data) { + CheckParsePointerParams(op, allocator, builtin_data); + + SafeBuiltinDataAllocator safe_allocator(allocator); + auto params = safe_allocator.Allocate(); + TFL_MIGRATION_ENSURE(params != nullptr); + if (const auto* schema_params = op->builtin_options_as_CastOptions()) { + TFL_MIGRATION_ENSURE_STATUS(ConvertTensorType(schema_params->in_data_type(), + ¶ms->in_data_type)); + TFL_MIGRATION_ENSURE_STATUS(ConvertTensorType( + schema_params->out_data_type(), ¶ms->out_data_type)); + } + *builtin_data = params.release(); + return OkStatus(); +} + +// We have this parse function instead of directly returning OkStatus() from the +// switch-case in ParseOpData because this function is used as part of the +// selective registration for the OpResolver implementation in micro. +absl::Status ParseCeil(const Operator*, BuiltinDataAllocator*, void**) { + return OkStatus(); +} + +absl::Status ParseConcatenation(const Operator* op, + BuiltinDataAllocator* allocator, + void** builtin_data) { + CheckParsePointerParams(op, allocator, builtin_data); + + SafeBuiltinDataAllocator safe_allocator(allocator); + std::unique_ptr + params = safe_allocator.Allocate(); + TFL_MIGRATION_ENSURE(params != nullptr); + + const ConcatenationOptions* schema_params = + op->builtin_options_as_ConcatenationOptions(); + + if (schema_params != nullptr) { + params->activation = + ConvertActivation(schema_params->fused_activation_function()); + params->axis = schema_params->axis(); + } else { + // TODO(b/157480169): We should either return kTfLiteError or fill in some + // reasonable defaults in the params struct. We are not doing so until we + // better understand the ramifications of changing the legacy behavior. + } + + *builtin_data = params.release(); + return OkStatus(); +} + +absl::Status ParseConv2D(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data) { + CheckParsePointerParams(op, allocator, builtin_data); + + SafeBuiltinDataAllocator safe_allocator(allocator); + std::unique_ptr + params = safe_allocator.Allocate(); + TFL_MIGRATION_ENSURE(params != nullptr); + + const Conv2DOptions* schema_params = op->builtin_options_as_Conv2DOptions(); + + if (schema_params != nullptr) { + params->padding = ConvertPadding(schema_params->padding()); + params->stride_width = schema_params->stride_w(); + params->stride_height = schema_params->stride_h(); + params->activation = + ConvertActivation(schema_params->fused_activation_function()); + + params->dilation_width_factor = schema_params->dilation_w_factor(); + params->dilation_height_factor = schema_params->dilation_h_factor(); + TFL_MIGRATION_ENSURE_STATUS(ConvertTensorType( + schema_params->quantized_bias_type(), ¶ms->quantized_bias_type)); + } else { + // TODO(b/157480169): We should either return kTfLiteError or fill in some + // reasonable defaults in the params struct. We are not doing so until we + // better understand the ramifications of changing the legacy behavior. + } + + *builtin_data = params.release(); + return OkStatus(); +} + +// We have this parse function instead of directly returning OkStatus() from the +// switch-case in ParseOpData because this function is used as part of the +// selective registration for the OpResolver implementation in micro. +absl::Status ParseCumsum(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data) { + CheckParsePointerParams(op, allocator, builtin_data); + + SafeBuiltinDataAllocator safe_allocator(allocator); + auto params = safe_allocator.Allocate(); + TFL_MIGRATION_ENSURE(params != nullptr); + if (const auto* cumsum_params = op->builtin_options_as_CumsumOptions()) { + params->exclusive = cumsum_params->exclusive(); + params->reverse = cumsum_params->reverse(); + } + *builtin_data = params.release(); + return OkStatus(); +} + +// We have this parse function instead of directly returning OkStatus() from the +// switch-case in ParseOpData because this function is used as part of the +// selective registration for the OpResolver implementation in micro. +absl::Status ParseCos(const Operator*, BuiltinDataAllocator*, void**) { + return OkStatus(); +} + +absl::Status ParseDepthToSpace(const Operator* op, + BuiltinDataAllocator* allocator, + void** builtin_data) { + CheckParsePointerParams(op, allocator, builtin_data); + + SafeBuiltinDataAllocator safe_allocator(allocator); + std::unique_ptr + params = safe_allocator.Allocate(); + TFL_MIGRATION_ENSURE(params != nullptr); + + const auto* schema_params = op->builtin_options_as_DepthToSpaceOptions(); + if (schema_params != nullptr) { + params->block_size = schema_params->block_size(); + } else { + // TODO(b/157480169): We should either return kTfLiteError or fill in some + // reasonable defaults in the params struct. We are not doing so until we + // better understand the ramifications of changing the legacy behavior. + } + + *builtin_data = params.release(); + return OkStatus(); +} + +absl::Status ParseDepthwiseConv2D(const Operator* op, + BuiltinDataAllocator* allocator, + void** builtin_data) { + CheckParsePointerParams(op, allocator, builtin_data); + + SafeBuiltinDataAllocator safe_allocator(allocator); + + std::unique_ptr + params = safe_allocator.Allocate(); + TFL_MIGRATION_ENSURE(params != nullptr); + + const DepthwiseConv2DOptions* schema_params = + op->builtin_options_as_DepthwiseConv2DOptions(); + + if (schema_params != nullptr) { + params->padding = ConvertPadding(schema_params->padding()); + params->stride_width = schema_params->stride_w(); + params->stride_height = schema_params->stride_h(); + params->depth_multiplier = schema_params->depth_multiplier(); + params->activation = + ConvertActivation(schema_params->fused_activation_function()); + + params->dilation_width_factor = schema_params->dilation_w_factor(); + params->dilation_height_factor = schema_params->dilation_h_factor(); + } else { + // TODO(b/157480169): We should either return kTfLiteError or fill in some + // reasonable defaults in the params struct. We are not doing so until we + // better understand the ramifications of changing the legacy behavior. + } + + *builtin_data = params.release(); + return OkStatus(); +} + +// We have this parse function instead of directly returning OkStatus() from the +// switch-case in ParseOpData because this function is used as part of the +// selective registration for the OpResolver implementation in micro. +absl::Status ParseDequantize(const Operator*, BuiltinDataAllocator*, void**) { + return OkStatus(); +} + +absl::Status ParseDiv(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data) { + CheckParsePointerParams(op, allocator, builtin_data); + + SafeBuiltinDataAllocator safe_allocator(allocator); + auto params = safe_allocator.Allocate(); + TFL_MIGRATION_ENSURE(params != nullptr); + if (const auto* schema_params = op->builtin_options_as_DivOptions()) { + params->activation = + ConvertActivation(schema_params->fused_activation_function()); + } + *builtin_data = params.release(); + return OkStatus(); +} + +// We have this parse function instead of directly returning OkStatus() from the +// switch-case in ParseOpData because this function is used as part of the +// selective registration for the OpResolver implementation in micro. +absl::Status ParseElu(const Operator*, BuiltinDataAllocator*, void**) { + return OkStatus(); +} + +// We have this parse function instead of directly returning OkStatus() from the +// switch-case in ParseOpData because this function is used as part of the +// selective registration for the OpResolver implementation in micro. +absl::Status ParseEmbeddingLookup(const Operator*, BuiltinDataAllocator*, + void**) { + return OkStatus(); +} + +// We have this parse function instead of directly returning OkStatus() from the +// switch-case in ParseOpData because this function is used as part of the +// selective registration for the OpResolver implementation in micro. +absl::Status ParseEqual(const Operator*, BuiltinDataAllocator*, void**) { + return OkStatus(); +} + +// We have this parse function instead of directly returning OkStatus() from the +// switch-case in ParseOpData because this function is used as part of the +// selective registration for the OpResolver implementation in micro. +absl::Status ParseExp(const Operator*, BuiltinDataAllocator*, void**) { + return OkStatus(); +} + +// We have this parse function instead of directly returning OkStatus() from the +// switch-case in ParseOpData because this function is used as part of the +// selective registration for the OpResolver implementation in micro. +absl::Status ParseExpandDims(const Operator*, BuiltinDataAllocator*, void**) { + return OkStatus(); +} + +// We have this parse function instead of directly returning OkStatus() from the +// switch-case in ParseOpData because this function is used as part of the +// selective registration for the OpResolver implementation in micro. +absl::Status ParseFill(const Operator*, BuiltinDataAllocator*, void**) { + return OkStatus(); +} + +// We have this parse function instead of directly returning OkStatus() from the +// switch-case in ParseOpData because this function is used as part of the +// selective registration for the OpResolver implementation in micro. +absl::Status ParseFloor(const Operator*, BuiltinDataAllocator*, void**) { + return OkStatus(); +} + +// We have this parse function instead of directly returning OkStatus() from the +// switch-case in ParseOpData because this function is used as part of the +// selective registration for the OpResolver implementation in micro. +absl::Status ParseFloorDiv(const Operator*, BuiltinDataAllocator*, void**) { + return OkStatus(); +} + +// We have this parse function instead of directly returning OkStatus() from the +// switch-case in ParseOpData because this function is used as part of the +// selective registration for the OpResolver implementation in micro. +absl::Status ParseFloorMod(const Operator*, BuiltinDataAllocator*, void**) { + return OkStatus(); +} + +absl::Status ParseFullyConnected(const Operator* op, + BuiltinDataAllocator* allocator, + void** builtin_data) { + CheckParsePointerParams(op, allocator, builtin_data); + + SafeBuiltinDataAllocator safe_allocator(allocator); + + std::unique_ptr + params = safe_allocator.Allocate(); + TFL_MIGRATION_ENSURE(params != nullptr); + + const FullyConnectedOptions* schema_params = + op->builtin_options_as_FullyConnectedOptions(); + + if (schema_params != nullptr) { + params->activation = + ConvertActivation(schema_params->fused_activation_function()); + params->keep_num_dims = schema_params->keep_num_dims(); + params->asymmetric_quantize_inputs = + schema_params->asymmetric_quantize_inputs(); + TFL_MIGRATION_ENSURE_STATUS(ConvertTensorType( + schema_params->quantized_bias_type(), ¶ms->quantized_bias_type)); + switch (schema_params->weights_format()) { + case FullyConnectedOptionsWeightsFormat_DEFAULT: + params->weights_format = kTfLiteFullyConnectedWeightsFormatDefault; + break; + case FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8: + params->weights_format = + kTfLiteFullyConnectedWeightsFormatShuffled4x16Int8; + break; + default: + auto error_message = "Unhandled fully-connected weights format."; + LOG(ERROR) << error_message; + return absl::InvalidArgumentError(error_message); + } + } else { + // TODO(b/157480169): We should either return kTfLiteError or fill in some + // reasonable defaults in the params struct. We are not doing so until we + // better understand the ramifications of changing the legacy behavior. + } + + *builtin_data = params.release(); + return OkStatus(); +} + +// We have this parse function instead of directly returning OkStatus() from the +// switch-case in ParseOpData because this function is used as part of the +// selective registration for the OpResolver implementation in micro. +absl::Status ParseGather(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data) { + CheckParsePointerParams(op, allocator, builtin_data); + + SafeBuiltinDataAllocator safe_allocator(allocator); + auto params = safe_allocator.Allocate(); + TFL_MIGRATION_ENSURE(params != nullptr); + params->axis = 0; + params->batch_dims = 0; + if (const auto* gather_params = op->builtin_options_as_GatherOptions()) { + params->axis = gather_params->axis(); + params->batch_dims = gather_params->batch_dims(); + } + + *builtin_data = params.release(); + return OkStatus(); +} + +// We have this parse function instead of directly returning OkStatus() from the +// switch-case in ParseOpData because this function is used as part of the +// selective registration for the OpResolver implementation in micro. +absl::Status ParseGatherNd(const Operator*, BuiltinDataAllocator*, void**) { + return OkStatus(); +} + +// We have this parse function instead of directly returning OkStatus() from the +// switch-case in ParseOpData because this function is used as part of the +// selective registration for the OpResolver implementation in micro. +absl::Status ParseGreater(const Operator*, BuiltinDataAllocator*, void**) { + return OkStatus(); +} + +// We have this parse function instead of directly returning OkStatus() from the +// switch-case in ParseOpData because this function is used as part of the +// selective registration for the OpResolver implementation in micro. +absl::Status ParseGreaterEqual(const Operator*, BuiltinDataAllocator*, void**) { + return OkStatus(); +} + +// We have this parse function instead of directly returning OkStatus() from the +// switch-case in ParseOpData because this function is used as part of the +// selective registration for the OpResolver implementation in micro. +absl::Status ParseHardSwish(const Operator*, BuiltinDataAllocator*, void**) { + return OkStatus(); +} + +absl::Status ParseIf(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data) { + CheckParsePointerParams(op, allocator, builtin_data); + + SafeBuiltinDataAllocator safe_allocator(allocator); + std::unique_ptr + params = safe_allocator.Allocate(); + TFL_MIGRATION_ENSURE(params != nullptr); + + const IfOptions* schema_params = op->builtin_options_as_IfOptions(); + + if (schema_params != nullptr) { + params->then_subgraph_index = schema_params->then_subgraph_index(); + params->else_subgraph_index = schema_params->else_subgraph_index(); + } else { + // TODO(b/157480169): We should either return kTfLiteError or fill in some + // reasonable defaults in the params struct. We are not doing so until we + // better understand the ramifications of changing the legacy behavior. + } + + *builtin_data = params.release(); + return OkStatus(); +} + +absl::Status ParseL2Normalization(const Operator* op, + BuiltinDataAllocator* allocator, + void** builtin_data) { + CheckParsePointerParams(op, allocator, builtin_data); + + SafeBuiltinDataAllocator safe_allocator(allocator); + std::unique_ptr + params = safe_allocator.Allocate(); + TFL_MIGRATION_ENSURE(params != nullptr); + + const L2NormOptions* schema_params = op->builtin_options_as_L2NormOptions(); + + if (schema_params != nullptr) { + params->activation = + ConvertActivation(schema_params->fused_activation_function()); + } else { + // TODO(b/157480169): We should either return kTfLiteError or fill in some + // reasonable defaults in the params struct. We are not doing so until we + // better understand the ramifications of changing the legacy behavior. + } + + *builtin_data = params.release(); + return OkStatus(); +} + +absl::Status ParseLeakyRelu(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data) { + CheckParsePointerParams(op, allocator, builtin_data); + + SafeBuiltinDataAllocator safe_allocator(allocator); + auto params = safe_allocator.Allocate(); + TFL_MIGRATION_ENSURE(params != nullptr); + if (const auto* leaky_relu_params = + op->builtin_options_as_LeakyReluOptions()) { + params->alpha = leaky_relu_params->alpha(); + } + *builtin_data = params.release(); + return OkStatus(); +} + +// We have this parse function instead of directly returning OkStatus() from the +// switch-case in ParseOpData because this function is used as part of the +// selective registration for the OpResolver implementation in micro. +absl::Status ParseLess(const Operator*, BuiltinDataAllocator*, void**) { + return OkStatus(); +} + +// We have this parse function instead of directly returning OkStatus() from the +// switch-case in ParseOpData because this function is used as part of the +// selective registration for the OpResolver implementation in micro. +absl::Status ParseLessEqual(const Operator*, BuiltinDataAllocator*, void**) { + return OkStatus(); +} + +// We have this parse function instead of directly returning OkStatus() from the +// switch-case in ParseOpData because this function is used as part of the +// selective registration for the OpResolver implementation in micro. +absl::Status ParseLog(const Operator*, BuiltinDataAllocator*, void**) { + return OkStatus(); +} + +// We have this parse function instead of directly returning OkStatus() from the +// switch-case in ParseOpData because this function is used as part of the +// selective registration for the OpResolver implementation in micro. +absl::Status ParseLogicalAnd(const Operator*, BuiltinDataAllocator*, void**) { + return OkStatus(); +} + +// We have this parse function instead of directly returning OkStatus() from the +// switch-case in ParseOpData because this function is used as part of the +// selective registration for the OpResolver implementation in micro. +absl::Status ParseLogicalNot(const Operator*, BuiltinDataAllocator*, void**) { + return OkStatus(); +} + +// We have this parse function instead of directly returning OkStatus() from the +// switch-case in ParseOpData because this function is used as part of the +// selective registration for the OpResolver implementation in micro. +absl::Status ParseLogicalOr(const Operator*, BuiltinDataAllocator*, void**) { + return OkStatus(); +} + +// We have this parse function instead of directly returning OkStatus() from the +// switch-case in ParseOpData because this function is used as part of the +// selective registration for the OpResolver implementation in micro. +absl::Status ParseLogistic(const Operator*, BuiltinDataAllocator*, void**) { + return OkStatus(); +} + +// We have this parse function instead of directly returning OkStatus() from the +// switch-case in ParseOpData because this function is used as part of the +// selective registration for the OpResolver implementation in micro. +absl::Status ParseLogSoftmax(const Operator*, BuiltinDataAllocator*, void**) { + return OkStatus(); +} + +absl::Status ParseLSTM(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data) { + CheckParsePointerParams(op, allocator, builtin_data); + + SafeBuiltinDataAllocator safe_allocator(allocator); + auto params = safe_allocator.Allocate(); + TFL_MIGRATION_ENSURE(params != nullptr); + if (const auto* lstm_params = op->builtin_options_as_LSTMOptions()) { + params->activation = + ConvertActivation(lstm_params->fused_activation_function()); + params->cell_clip = lstm_params->cell_clip(); + params->proj_clip = lstm_params->proj_clip(); + switch (lstm_params->kernel_type()) { + case LSTMKernelType_FULL: + params->kernel_type = kTfLiteLSTMFullKernel; + break; + case LSTMKernelType_BASIC: + params->kernel_type = kTfLiteLSTMBasicKernel; + break; + default: + auto error_message = absl::StrFormat("Unhandled LSTM kernel type: %d", + lstm_params->kernel_type()); + LOG(ERROR) << error_message; + return absl::InvalidArgumentError(error_message); + } + params->asymmetric_quantize_inputs = + lstm_params->asymmetric_quantize_inputs(); + } else { + auto error_message = "No valid LSTM builtin options exist"; + LOG(ERROR) << error_message; + return absl::InvalidArgumentError(error_message); + } + *builtin_data = params.release(); + return OkStatus(); +} + +// We have this parse function instead of directly returning OkStatus() from the +// switch-case in ParseOpData because this function is used as part of the +// selective registration for the OpResolver implementation in micro. +absl::Status ParseMaximum(const Operator*, BuiltinDataAllocator*, void**) { + return OkStatus(); +} + +// We have this parse function instead of directly returning OkStatus() from the +// switch-case in ParseOpData because this function is used as part of the +// selective registration for the OpResolver implementation in micro. +absl::Status ParseMinimum(const Operator*, BuiltinDataAllocator*, void**) { + return OkStatus(); +} + +absl::Status ParseMirrorPad(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data) { + CheckParsePointerParams(op, allocator, builtin_data); + + SafeBuiltinDataAllocator safe_allocator(allocator); + std::unique_ptr + params = safe_allocator.Allocate(); + TFL_MIGRATION_ENSURE(params != nullptr); + + const MirrorPadOptions* schema_params = + op->builtin_options_as_MirrorPadOptions(); + + if (schema_params != nullptr) { + params->mode = ConvertMirrorPadding(schema_params->mode()); + } else { + // TODO(b/157480169): We should either return kTfLiteError or fill in some + // reasonable defaults in the params struct. We are not doing so until we + // better understand the ramifications of changing the legacy behavior. + } + + *builtin_data = params.release(); + return OkStatus(); +} + +absl::Status ParseMul(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data) { + CheckParsePointerParams(op, allocator, builtin_data); + + SafeBuiltinDataAllocator safe_allocator(allocator); + std::unique_ptr + params = safe_allocator.Allocate(); + TFL_MIGRATION_ENSURE(params != nullptr); + + const MulOptions* schema_params = op->builtin_options_as_MulOptions(); + + if (schema_params != nullptr) { + params->activation = + ConvertActivation(schema_params->fused_activation_function()); + } else { + // TODO(b/157480169): We should either return kTfLiteError or fill in some + // reasonable defaults in the params struct. We are not doing so until we + // better understand the ramifications of changing the legacy behavior. + } + + *builtin_data = params.release(); + return OkStatus(); +} + +// We have this parse function instead of directly returning OkStatus() from the +// switch-case in ParseOpData because this function is used as part of the +// selective registration for the OpResolver implementation in micro. +absl::Status ParseNeg(const Operator*, BuiltinDataAllocator*, void**) { + return OkStatus(); +} + +// We have this parse function instead of directly returning OkStatus() from the +// switch-case in ParseOpData because this function is used as part of the +// selective registration for the OpResolver implementation in micro. +absl::Status ParseNotEqual(const Operator*, BuiltinDataAllocator*, void**) { + return OkStatus(); +} + +absl::Status ParsePack(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data) { + CheckParsePointerParams(op, allocator, builtin_data); + + SafeBuiltinDataAllocator safe_allocator(allocator); + std::unique_ptr + params = safe_allocator.Allocate(); + TFL_MIGRATION_ENSURE(params != nullptr); + + const PackOptions* schema_params = op->builtin_options_as_PackOptions(); + + if (schema_params != nullptr) { + params->values_count = schema_params->values_count(); + params->axis = schema_params->axis(); + } else { + // TODO(b/157480169): We should either return kTfLiteError or fill in some + // reasonable defaults in the params struct. We are not doing so until we + // better understand the ramifications of changing the legacy behavior. + } + + *builtin_data = params.release(); + return OkStatus(); +} + +// We have this parse function instead of directly returning OkStatus() from the +// switch-case in ParseOpData because this function is used as part of the +// selective registration for the OpResolver implementation in micro. +absl::Status ParsePad(const Operator*, BuiltinDataAllocator*, void**) { + return OkStatus(); +} + +// We have this parse function instead of directly returning OkStatus() from the +// switch-case in ParseOpData because this function is used as part of the +// selective registration for the OpResolver implementation in micro. +absl::Status ParsePadV2(const Operator*, BuiltinDataAllocator*, void**) { + return OkStatus(); +} + +absl::Status ParsePool(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data) { + CheckParsePointerParams(op, allocator, builtin_data); + + SafeBuiltinDataAllocator safe_allocator(allocator); + std::unique_ptr + params = safe_allocator.Allocate(); + TFL_MIGRATION_ENSURE(params != nullptr); + + const Pool2DOptions* schema_params = op->builtin_options_as_Pool2DOptions(); + + if (schema_params != nullptr) { + params->padding = ConvertPadding(schema_params->padding()); + params->stride_width = schema_params->stride_w(); + params->stride_height = schema_params->stride_h(); + params->filter_width = schema_params->filter_width(); + params->filter_height = schema_params->filter_height(); + params->activation = + ConvertActivation(schema_params->fused_activation_function()); + } else { + // TODO(b/157480169): We should either return kTfLiteError or fill in some + // reasonable defaults in the params struct. We are not doing so until we + // better understand the ramifications of changing the legacy behavior. + } + + *builtin_data = params.release(); + return OkStatus(); +} + +// We have this parse function instead of directly returning OkStatus() from the +// switch-case in ParseOpData because this function is used as part of the +// selective registration for the OpResolver implementation in micro. +absl::Status ParsePow(const Operator*, BuiltinDataAllocator*, void**) { + return OkStatus(); +} + +// We have this parse function instead of directly returning OkStatus() from the +// switch-case in ParseOpData because this function is used as part of the +// selective registration for the OpResolver implementation in micro. +absl::Status ParsePrelu(const Operator*, BuiltinDataAllocator*, void**) { + return OkStatus(); +} + +// We have this parse function instead of directly returning OkStatus() from the +// switch-case in ParseOpData because this function is used as part of the +// selective registration for the OpResolver implementation in micro. +absl::Status ParseQuantize(const Operator*, BuiltinDataAllocator*, void**) { + return OkStatus(); +} + +// We have this parse function instead of directly returning OkStatus() from the +// switch-case in ParseOpData because this function is used as part of the +// selective registration for the OpResolver implementation in micro. +absl::Status ParseReadVariable(const Operator*, BuiltinDataAllocator*, void**) { + return OkStatus(); +} + +absl::Status ParseReducer(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data) { + CheckParsePointerParams(op, allocator, builtin_data); + + SafeBuiltinDataAllocator safe_allocator(allocator); + + std::unique_ptr + params = safe_allocator.Allocate(); + TFL_MIGRATION_ENSURE(params != nullptr); + + const ReducerOptions* schema_params = op->builtin_options_as_ReducerOptions(); + + if (schema_params != nullptr) { + params->keep_dims = schema_params->keep_dims(); + } else { + // TODO(b/157480169): We should either return kTfLiteError or fill in some + // reasonable defaults in the params struct. We are not doing so until we + // better understand the ramifications of changing the legacy behavior. + } + + *builtin_data = params.release(); + return OkStatus(); +} + +// We have this parse function instead of directly returning OkStatus() from the +// switch-case in ParseOpData because this function is used as part of the +// selective registration for the OpResolver implementation in micro. +absl::Status ParseRelu(const Operator*, BuiltinDataAllocator*, void**) { + return OkStatus(); +} + +// We have this parse function instead of directly returning OkStatus() from the +// switch-case in ParseOpData because this function is used as part of the +// selective registration for the OpResolver implementation in micro. +absl::Status ParseRelu6(const Operator*, BuiltinDataAllocator*, void**) { + return OkStatus(); +} + +absl::Status ParseReshape(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data) { + CheckParsePointerParams(op, allocator, builtin_data); + + SafeBuiltinDataAllocator safe_allocator(allocator); + + std::unique_ptr + params = safe_allocator.Allocate(); + TFL_MIGRATION_ENSURE(params != nullptr); + + const ReshapeOptions* schema_params = op->builtin_options_as_ReshapeOptions(); + + if (schema_params != nullptr) { + const flatbuffers::Vector* new_shape = schema_params->new_shape(); + if (new_shape != nullptr) { + TFL_MIGRATION_ENSURE_STATUS(FlatBufferIntVectorToArray( + sizeof(params->shape), new_shape, params->shape, "reshape")); + params->num_dimensions = new_shape->size(); + } else { + // TODO(b/157480169) TODO(b/147203660): We should either return + // kTfLiteError or fill in some reasonable defaults in the params struct. + // We are not doing so until we better undertand the ramifications of + // changing the legacy behavior. + } + } else { + // TODO(b/157480169): We should either return kTfLiteError or fill in some + // reasonable defaults in the params struct. We are not doing so until we + // better understand the ramifications of changing the legacy behavior. + } + + *builtin_data = params.release(); + return OkStatus(); +} + +absl::Status ParseResizeBilinear(const Operator* op, + BuiltinDataAllocator* allocator, + void** builtin_data) { + CheckParsePointerParams(op, allocator, builtin_data); + + SafeBuiltinDataAllocator safe_allocator(allocator); + std::unique_ptr + params = safe_allocator.Allocate(); + TFL_MIGRATION_ENSURE(params != nullptr); + + const ResizeBilinearOptions* schema_params = + op->builtin_options_as_ResizeBilinearOptions(); + + if (schema_params != nullptr) { + params->align_corners = schema_params->align_corners(); + params->half_pixel_centers = schema_params->half_pixel_centers(); + } else { + params->align_corners = false; + params->half_pixel_centers = false; + } + + *builtin_data = params.release(); + return OkStatus(); +} + +absl::Status ParseResizeNearestNeighbor(const Operator* op, + BuiltinDataAllocator* allocator, + void** builtin_data) { + CheckParsePointerParams(op, allocator, builtin_data); + + SafeBuiltinDataAllocator safe_allocator(allocator); + std::unique_ptr + params = safe_allocator.Allocate(); + TFL_MIGRATION_ENSURE(params != nullptr); + + const ResizeNearestNeighborOptions* schema_params = + op->builtin_options_as_ResizeNearestNeighborOptions(); + + if (schema_params != nullptr) { + params->align_corners = schema_params->align_corners(); + params->half_pixel_centers = schema_params->half_pixel_centers(); + } else { + params->align_corners = false; + params->half_pixel_centers = false; + } + + *builtin_data = params.release(); + return OkStatus(); +} + +absl::Status ParseStablehloReduceWindow(const Operator* op, + BuiltinDataAllocator* allocator, + void** builtin_data) { + CheckParsePointerParams(op, allocator, builtin_data); + + SafeBuiltinDataAllocator safe_allocator(allocator); + auto params = safe_allocator.Allocate(); + + const StablehloReduceWindowOptions* schema_params = + op->builtin_options_2_as_StablehloReduceWindowOptions(); + if (schema_params) { + if (!schema_params->window_dimensions() || + schema_params->window_dimensions()->size() == 0) { + auto error_message = + "'window_dimensions' attribute is not optional for " + "'stablehlo.reduce_window' and cannot be empty."; + LOG(ERROR) << error_message; + return absl::InvalidArgumentError(error_message); + } + + const size_t rank = schema_params->window_dimensions()->size(); + + auto LoadAttr = [](int64_t* params_array, size_t params_array_size_bytes, + const flatbuffers::Vector* flatbuffer_vector, + const char* attr_name, const size_t expected_size, + const int64_t fill_value) -> absl::Status { + if (flatbuffer_vector && flatbuffer_vector->size()) { + if (expected_size != 0 && flatbuffer_vector->size() != expected_size) { + auto error_message = absl::StrFormat( + "'%s' attribute of 'stablehlo.reduce_window' does not have the " + "expected size (%llu != %llu).", + attr_name, flatbuffer_vector->size(), expected_size); + LOG(ERROR) << error_message; + return absl::InvalidArgumentError(error_message); + } + absl::Status status = FlatBufferIntVectorToArray( + params_array_size_bytes, flatbuffer_vector, params_array, + "stablehlo.reduce_window"); + if (!status.ok()) { + auto error_message = absl::StrFormat("%s Check the '%s' attribute.", + status.message(), attr_name); + LOG(ERROR) << error_message; + return absl::InvalidArgumentError(error_message); + } + } else { + std::fill_n(params_array, params_array_size_bytes / sizeof(int64_t), + fill_value); + } + return OkStatus(); + }; + + TFL_MIGRATION_ENSURE_STATUS( + LoadAttr(params->window_dimensions, sizeof(params->window_dimensions), + schema_params->window_dimensions(), "window_dimensions", + /*expected_size=*/rank, /*fill_value=*/1)); + TFL_MIGRATION_ENSURE_STATUS( + LoadAttr(params->window_strides, sizeof(params->window_strides), + schema_params->window_strides(), "window_strides", + /*expected_size=*/rank, /*fill_value=*/1)); + TFL_MIGRATION_ENSURE_STATUS( + LoadAttr(params->base_dilations, sizeof(params->base_dilations), + schema_params->base_dilations(), "base_dilations", + /*expected_size=*/rank, /*fill_value=*/1)); + TFL_MIGRATION_ENSURE_STATUS( + LoadAttr(params->window_dilations, sizeof(params->window_dilations), + schema_params->window_dilations(), "window_dilations", + /*expected_size=*/rank, /*fill_value=*/1)); + TFL_MIGRATION_ENSURE_STATUS(LoadAttr(params->padding, + sizeof(params->padding), + schema_params->padding(), "padding", + /*expected_size=*/2 * rank, + /*fill_value=*/0)); + + params->body_subgraph_index = schema_params->body_subgraph_index(); + *builtin_data = params.release(); + return OkStatus(); + } + auto error_message = + "Could not get 'stablehlo.reduce_window' operation parameters."; + LOG(ERROR) << error_message; + return absl::InvalidArgumentError(error_message); +} + +absl::Status ParseStablehloScatter(const Operator* op, + BuiltinDataAllocator* allocator, + void** builtin_data) { + CheckParsePointerParams(op, allocator, builtin_data); + + SafeBuiltinDataAllocator safe_allocator(allocator); + std::unique_ptr + params = safe_allocator.Allocate(); + TFL_MIGRATION_ENSURE(params != nullptr); + + const StablehloScatterOptions* schema_params = + op->builtin_options_2_as_StablehloScatterOptions(); + if (schema_params) { + params->indices_are_sorted = schema_params->indices_are_sorted(); + + if (schema_params->update_window_dims()) { + TFL_MIGRATION_ENSURE_STATUS(FlatBufferIntVectorToArray( + schema_params->update_window_dims()->size() * sizeof(int64_t), + schema_params->update_window_dims(), params->update_window_dims, + "stablehlo_scatter")); + params->num_update_window_dims = + schema_params->update_window_dims()->size(); + } + + if (schema_params->inserted_window_dims()) { + TFL_MIGRATION_ENSURE_STATUS(FlatBufferIntVectorToArray( + schema_params->inserted_window_dims()->size() * sizeof(int64_t), + schema_params->inserted_window_dims(), params->inserted_window_dims, + "stablehlo_scatter")); + params->num_inserted_window_dims = + schema_params->inserted_window_dims()->size(); + } + + if (schema_params->scatter_dims_to_operand_dims()) { + TFL_MIGRATION_ENSURE_STATUS(FlatBufferIntVectorToArray( + schema_params->scatter_dims_to_operand_dims()->size() * + sizeof(int64_t), + schema_params->scatter_dims_to_operand_dims(), + params->scatter_dims_to_operand_dims, "stablehlo_scatter")); + params->num_scatter_dims_to_operand_dims = + schema_params->scatter_dims_to_operand_dims()->size(); + } + + params->index_vector_dim = schema_params->index_vector_dim(); + params->unique_indices = schema_params->unique_indices(); + params->update_computation_subgraph_index = + schema_params->update_computation_subgraph_index(); + } else { + // TODO(b/157480169): We should either return kTfLiteError or fill in some + // reasonable defaults in the params struct. We are not doing so until we + // better understand the ramifications of changing the legacy behavior. + } + *builtin_data = params.release(); + return OkStatus(); +} + +absl::Status ParseStablehloRngBitGenerator(const Operator* op, + BuiltinDataAllocator* allocator, + void** builtin_data) { + CheckParsePointerParams(op, allocator, builtin_data); + + SafeBuiltinDataAllocator safe_allocator(allocator); + std::unique_ptr + params = safe_allocator.Allocate(); + TFL_MIGRATION_ENSURE(params != nullptr); + + const StablehloRngBitGeneratorOptions* schema_params = + op->builtin_options_2_as_StablehloRngBitGeneratorOptions(); + if (schema_params != nullptr) { + params->algorithm = ConvertRngAlgorithm(schema_params->algorithm()); + } else { + // TODO(b/157480169): We should either return kTfLiteError or fill in some + // reasonable defaults in the params struct. We are not doing so until we + // better understand the ramifications of changing the legacy behavior. + } + + *builtin_data = params.release(); + return OkStatus(); +} + +absl::Status ParseStablehloGather(const Operator* op, + BuiltinDataAllocator* allocator, + void** builtin_data) { + CheckParsePointerParams(op, allocator, builtin_data); + + SafeBuiltinDataAllocator safe_allocator(allocator); + std::unique_ptr + params = safe_allocator.Allocate(); + TFL_MIGRATION_ENSURE(params != nullptr); + + const StablehloGatherOptions* schema_params = + op->builtin_options_2_as_StablehloGatherOptions(); + + if (schema_params != nullptr) { + TFL_MIGRATION_ENSURE_STATUS(FlatBufferIntVectorToArray( + /*max_size_of_buffer=*/schema_params->offset_dims()->size() * + sizeof(int64_t), + /*flat_vector=*/schema_params->offset_dims(), + /*buffer=*/params->offset_dims, + /*op_name=*/"stablehlo_gather")); + params->num_offset_dims = schema_params->offset_dims()->size(); + + TFL_MIGRATION_ENSURE_STATUS(FlatBufferIntVectorToArray( + schema_params->collapsed_slice_dims()->size() * sizeof(int64_t), + schema_params->collapsed_slice_dims(), params->collapsed_slice_dims, + "stablehlo_gather")); + params->num_collapsed_slice_dims = + schema_params->collapsed_slice_dims()->size(); + + TFL_MIGRATION_ENSURE_STATUS(FlatBufferIntVectorToArray( + schema_params->start_index_map()->size() * sizeof(int64_t), + schema_params->start_index_map(), params->start_index_map, + "stablehlo_gather")); + params->num_start_index_map = schema_params->start_index_map()->size(); + + params->index_vector_dim = schema_params->index_vector_dim(); + + TFL_MIGRATION_ENSURE_STATUS(FlatBufferIntVectorToArray( + schema_params->slice_sizes()->size() * sizeof(int64_t), + schema_params->slice_sizes(), params->slice_sizes, "stablehlo_gather")); + params->num_slice_sizes = schema_params->slice_sizes()->size(); + + params->indices_are_sorted = schema_params->indices_are_sorted(); + } else { + // TODO(b/157480169): We should either return kTfLiteError or fill in some + // reasonable defaults in the params struct. We are not doing so until we + // better understand the ramifications of changing the legacy behavior. + } + + *builtin_data = params.release(); + return OkStatus(); +} + +absl::Status ParseStablehloPad(const Operator* op, + BuiltinDataAllocator* allocator, + void** builtin_data) { + CheckParsePointerParams(op, allocator, builtin_data); + + SafeBuiltinDataAllocator safe_allocator(allocator); + auto params = safe_allocator.Allocate(); + const StablehloPadOptions* schema_params = + op->builtin_options_2_as_StablehloPadOptions(); + + if (schema_params) { + auto LoadAttr = + [](int64_t* params_array, const size_t params_array_size_bytes, + const flatbuffers::Vector* const flatbuffer_vector, + const char* const attr_name) -> absl::Status { + absl::Status status = + FlatBufferIntVectorToArray(params_array_size_bytes, flatbuffer_vector, + params_array, "stablehlo.pad"); + if (!status.ok()) { + auto error_message = absl::StrFormat("%s Check the '%s' attribute.", + status.message(), attr_name); + LOG(ERROR) << error_message; + return absl::InvalidArgumentError(error_message); + } + return status; + }; + + TFL_MIGRATION_ENSURE_STATUS( + LoadAttr(params->edge_padding_low, sizeof(params->edge_padding_low), + schema_params->edge_padding_low(), "edge_padding_low")); + TFL_MIGRATION_ENSURE_STATUS( + LoadAttr(params->edge_padding_high, sizeof(params->edge_padding_high), + schema_params->edge_padding_high(), "edge_padding_high")); + TFL_MIGRATION_ENSURE_STATUS( + LoadAttr(params->interior_padding, sizeof(params->interior_padding), + schema_params->interior_padding(), "interior_padding")); + if (schema_params->edge_padding_low()->size() != + schema_params->edge_padding_high()->size() || + schema_params->edge_padding_low()->size() != + schema_params->interior_padding()->size()) { + auto error_message = + "'stablehlo.pad' operation parameter array sizes are not consistent."; + LOG(ERROR) << error_message; + return absl::InvalidArgumentError(error_message); + } + *builtin_data = params.release(); + return OkStatus(); + } + auto error_message = "Could not get 'stablehlo.pad' operation parameters."; + LOG(ERROR) << error_message; + return absl::InvalidArgumentError(error_message); +} + +absl::Status ParseStablehloComposite(const Operator* op, + BuiltinDataAllocator* allocator, + void** builtin_data) { + CheckParsePointerParams(op, allocator, builtin_data); + + SafeBuiltinDataAllocator safe_allocator(allocator); + auto params = safe_allocator.Allocate(); + const StableHLOCompositeOptions* schema_params = + op->builtin_options_2_as_StableHLOCompositeOptions(); + if (schema_params) { + params->name = schema_params->name()->c_str(); + params->version = schema_params->version(); + params->subgraph_index = schema_params->decomposition_subgraph_index(); + params->attributes = schema_params->composite_attributes()->data(); + params->attributes_size = schema_params->composite_attributes()->size(); + *builtin_data = params.release(); + return OkStatus(); + } + auto error_message = + "Could not get 'stablehlo.composite' operation parameters."; + LOG(ERROR) << error_message; + return absl::InvalidArgumentError(error_message); +} + +absl::Status ParseStablehloShiftLeft(const Operator* op, + BuiltinDataAllocator* allocator, + void** builtin_data) { + return OkStatus(); +} + +// We have this parse function instead of directly returning OkStatus() from the +// switch-case in ParseOpData because this function is used as part of the +// selective registration for the OpResolver implementation in micro. +absl::Status ParseRound(const Operator*, BuiltinDataAllocator*, void**) { + return OkStatus(); +} + +// We have this parse function instead of directly returning OkStatus() from the +// switch-case in ParseOpData because this function is used as part of the +// selective registration for the OpResolver implementation in micro. +absl::Status ParseRsqrt(const Operator*, BuiltinDataAllocator*, void**) { + return OkStatus(); +} + +// We have this parse function instead of directly returning OkStatus() from the +// switch-case in ParseOpData because this function is used as part of the +// selective registration for the OpResolver implementation in micro. +absl::Status ParseSelectV2(const Operator*, BuiltinDataAllocator*, void**) { + return OkStatus(); +} + +absl::Status ParseShape(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data) { + SafeBuiltinDataAllocator safe_allocator(allocator); + std::unique_ptr + params = safe_allocator.Allocate(); + TFL_MIGRATION_ENSURE(params != nullptr); + + const ShapeOptions* schema_params = op->builtin_options_as_ShapeOptions(); + + if (schema_params != nullptr) { + TFL_MIGRATION_ENSURE_STATUS( + ConvertTensorType(schema_params->out_type(), ¶ms->out_type)); + } else { + // TODO(b/157480169): We should either return kTfLiteError or fill in some + // reasonable defaults in the params struct. We are not doing so until we + // better understand the ramifications of changing the legacy behavior. + } + + *builtin_data = params.release(); + return OkStatus(); +} + +// We have this parse function instead of directly returning OkStatus() from the +// switch-case in ParseOpData because this function is used as part of the +// selective registration for the OpResolver implementation in micro. +absl::Status ParseSin(const Operator*, BuiltinDataAllocator*, void**) { + return OkStatus(); +} + +// We have this parse function instead of directly returning OkStatus() from the +// switch-case in ParseOpData because this function is used as part of the +// selective registration for the OpResolver implementation in micro. +absl::Status ParseSlice(const Operator*, BuiltinDataAllocator*, void**) { + return OkStatus(); +} + +absl::Status ParseSoftmax(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data) { + CheckParsePointerParams(op, allocator, builtin_data); + + SafeBuiltinDataAllocator safe_allocator(allocator); + std::unique_ptr + params = safe_allocator.Allocate(); + TFL_MIGRATION_ENSURE(params != nullptr); + + const SoftmaxOptions* schema_params = op->builtin_options_as_SoftmaxOptions(); + + if (schema_params != nullptr) { + params->beta = schema_params->beta(); + } else { + // TODO(b/157480169): We should either return kTfLiteError or fill in some + // reasonable defaults in the params struct. We are not doing so until we + // better understand the ramifications of changing the legacy behavior. + } + + *builtin_data = params.release(); + return OkStatus(); +} + +// We have this parse function instead of directly returning OkStatus() from the +// switch-case in ParseOpData because this function is used as part of the +// selective registration for the OpResolver implementation in micro. +absl::Status ParseSpaceToBatchNd(const Operator*, BuiltinDataAllocator*, + void**) { + return OkStatus(); +} + +absl::Status ParseSpaceToDepth(const Operator* op, + BuiltinDataAllocator* allocator, + void** builtin_data) { + CheckParsePointerParams(op, allocator, builtin_data); + + SafeBuiltinDataAllocator safe_allocator(allocator); + std::unique_ptr + params = safe_allocator.Allocate(); + TFL_MIGRATION_ENSURE(params != nullptr); + + const auto* schema_params = op->builtin_options_as_SpaceToDepthOptions(); + if (schema_params != nullptr) { + params->block_size = schema_params->block_size(); + } else { + // TODO(b/157480169): We should either return kTfLiteError or fill in some + // reasonable defaults in the params struct. We are not doing so until we + // better understand the ramifications of changing the legacy behavior. + } + + *builtin_data = params.release(); + return OkStatus(); +} + +absl::Status ParseSplit(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data) { + CheckParsePointerParams(op, allocator, builtin_data); + + SafeBuiltinDataAllocator safe_allocator(allocator); + std::unique_ptr + params = safe_allocator.Allocate(); + TFL_MIGRATION_ENSURE(params != nullptr); + + const SplitOptions* schema_params = op->builtin_options_as_SplitOptions(); + + if (schema_params != nullptr) { + params->num_splits = schema_params->num_splits(); + } else { + // TODO(b/157480169): We should either return kTfLiteError or fill in some + // reasonable defaults in the params struct. We are not doing so until we + // better understand the ramifications of changing the legacy behavior. + } + + *builtin_data = params.release(); + return OkStatus(); +} + +absl::Status ParseSplitV(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data) { + CheckParsePointerParams(op, allocator, builtin_data); + SafeBuiltinDataAllocator safe_allocator(allocator); + + std::unique_ptr + params = safe_allocator.Allocate(); + TFL_MIGRATION_ENSURE(params != nullptr); + + const SplitVOptions* schema_params = op->builtin_options_as_SplitVOptions(); + + if (schema_params != nullptr) { + params->num_splits = schema_params->num_splits(); + } else { + // TODO(b/157480169): We should either return kTfLiteError or fill in some + // reasonable defaults in the params struct. We are not doing so until we + // better understand the ramifications of changing the legacy behavior. + } + + *builtin_data = params.release(); + return OkStatus(); +} + +absl::Status ParseUnidirectionalSequenceLSTM(const Operator* op, + BuiltinDataAllocator* allocator, + void** builtin_data) { + CheckParsePointerParams(op, allocator, builtin_data); + SafeBuiltinDataAllocator safe_allocator(allocator); + auto params = + safe_allocator.Allocate(); + TFL_MIGRATION_ENSURE(params != nullptr); + if (const auto* seq_lstm_params = + op->builtin_options_as_UnidirectionalSequenceLSTMOptions()) { + params->activation = + ConvertActivation(seq_lstm_params->fused_activation_function()); + params->cell_clip = seq_lstm_params->cell_clip(); + params->proj_clip = seq_lstm_params->proj_clip(); + params->time_major = seq_lstm_params->time_major(); + params->asymmetric_quantize_inputs = + seq_lstm_params->asymmetric_quantize_inputs(); + params->diagonal_recurrent_tensors = + seq_lstm_params->diagonal_recurrent_tensors(); + } + *builtin_data = params.release(); + return OkStatus(); +} + +absl::Status ParseSqueeze(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data) { + CheckParsePointerParams(op, allocator, builtin_data); + SafeBuiltinDataAllocator safe_allocator(allocator); + + std::unique_ptr + params = safe_allocator.Allocate(); + TFL_MIGRATION_ENSURE(params != nullptr); + + const SqueezeOptions* schema_params = op->builtin_options_as_SqueezeOptions(); + + if (schema_params != nullptr) { + const auto* squeeze_dims = schema_params->squeeze_dims(); + if (squeeze_dims != nullptr) { + TFL_MIGRATION_ENSURE_STATUS( + FlatBufferIntVectorToArray(sizeof(params->squeeze_dims), squeeze_dims, + params->squeeze_dims, "squeeze")); + params->num_squeeze_dims = squeeze_dims->size(); + } else { + params->num_squeeze_dims = 0; + } + } else { + // TODO(b/157480169): We should either return kTfLiteError or fill in some + // reasonable defaults in the params struct. We are not doing so until we + // better understand the ramifications of changing the legacy behavior. + } + + *builtin_data = params.release(); + return OkStatus(); +} + +// We have this parse function instead of directly returning OkStatus() from the +// switch-case in ParseOpData because this function is used as part of the +// selective registration for the OpResolver implementation in micro. +absl::Status ParseSqrt(const Operator*, BuiltinDataAllocator*, void**) { + return OkStatus(); +} + +// We have this parse function instead of directly returning OkStatus() from the +// switch-case in ParseOpData because this function is used as part of the +// selective registration for the OpResolver implementation in micro. +absl::Status ParseSquare(const Operator*, BuiltinDataAllocator*, void**) { + return OkStatus(); +} + +// We have this parse function instead of directly returning OkStatus() from the +// switch-case in ParseOpData because this function is used as part of the +// selective registration for the OpResolver implementation in micro. +absl::Status ParseSquaredDifference(const Operator*, BuiltinDataAllocator*, + void**) { + return OkStatus(); +} + +absl::Status ParseStridedSlice(const Operator* op, + BuiltinDataAllocator* allocator, + void** builtin_data) { + CheckParsePointerParams(op, allocator, builtin_data); + + SafeBuiltinDataAllocator safe_allocator(allocator); + std::unique_ptr + params = safe_allocator.Allocate(); + TFL_MIGRATION_ENSURE(params != nullptr); + + const StridedSliceOptions* schema_params = + op->builtin_options_as_StridedSliceOptions(); + + if (schema_params != nullptr) { + params->begin_mask = schema_params->begin_mask(); + params->end_mask = schema_params->end_mask(); + params->ellipsis_mask = schema_params->ellipsis_mask(); + params->new_axis_mask = schema_params->new_axis_mask(); + params->shrink_axis_mask = schema_params->shrink_axis_mask(); + params->offset = schema_params->offset(); + } else { + // TODO(b/157480169): We should either return kTfLiteError or fill in some + // reasonable defaults in the params struct. We are not doing so until we + // better understand the ramifications of changing the legacy behavior. + } + + *builtin_data = params.release(); + return OkStatus(); +} + +absl::Status ParseSub(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data) { + CheckParsePointerParams(op, allocator, builtin_data); + + SafeBuiltinDataAllocator safe_allocator(allocator); + std::unique_ptr + params = safe_allocator.Allocate(); + TFL_MIGRATION_ENSURE(params != nullptr); + + const SubOptions* schema_params = op->builtin_options_as_SubOptions(); + + if (schema_params != nullptr) { + params->activation = + ConvertActivation(schema_params->fused_activation_function()); + params->pot_scale_int16 = schema_params->pot_scale_int16(); + } else { + // TODO(b/157480169): We should either return kTfLiteError or fill in some + // reasonable defaults in the params struct. We are not doing so until we + // better understand the ramifications of changing the legacy behavior. + } + + *builtin_data = params.release(); + return OkStatus(); +} + +absl::Status ParseSvdf(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data) { + CheckParsePointerParams(op, allocator, builtin_data); + + SafeBuiltinDataAllocator safe_allocator(allocator); + std::unique_ptr + params = safe_allocator.Allocate(); + TFL_MIGRATION_ENSURE(params != nullptr); + + const SVDFOptions* schema_params = op->builtin_options_as_SVDFOptions(); + if (schema_params != nullptr) { + params->rank = schema_params->rank(); + params->activation = + ConvertActivation(schema_params->fused_activation_function()); + params->asymmetric_quantize_inputs = + schema_params->asymmetric_quantize_inputs(); + } else { + // TODO(b/157480169): We should either return kTfLiteError or fill in some + // reasonable defaults in the params struct. We are not doing so until we + // better understand the ramifications of changing the legacy behavior. + } + + *builtin_data = params.release(); + return OkStatus(); +} + +// We have this parse function instead of directly returning OkStatus() from the +// switch-case in ParseOpData because this function is used as part of the +// selective registration for the OpResolver implementation in micro. +absl::Status ParseTanh(const Operator*, BuiltinDataAllocator*, void**) { + return OkStatus(); +} +// +// We have this parse function instead of directly returning OkStatus() from the +// switch-case in ParseOpData because this function is used as part of the +// selective registration for the OpResolver implementation in micro. +absl::Status ParseTranspose(const Operator*, BuiltinDataAllocator*, void**) { + return OkStatus(); +} + +absl::Status ParseTransposeConv(const Operator* op, + BuiltinDataAllocator* allocator, + void** builtin_data) { + CheckParsePointerParams(op, allocator, builtin_data); + + SafeBuiltinDataAllocator safe_allocator(allocator); + std::unique_ptr + params = safe_allocator.Allocate(); + TFL_MIGRATION_ENSURE(params != nullptr); + const TransposeConvOptions* transpose_conv_params = + op->builtin_options_as_TransposeConvOptions(); + if (transpose_conv_params != nullptr) { + params->padding = ConvertPadding(transpose_conv_params->padding()); + params->stride_width = transpose_conv_params->stride_w(); + params->stride_height = transpose_conv_params->stride_h(); + + params->activation = + ConvertActivation(transpose_conv_params->fused_activation_function()); + TFL_MIGRATION_ENSURE_STATUS( + ConvertTensorType(transpose_conv_params->quantized_bias_type(), + ¶ms->quantized_bias_type)); + } else { + // TODO(b/157480169): We should either return kTfLiteError or fill in some + // reasonable defaults in the params struct. We are not doing so until we + // better understand the ramifications of changing the legacy behavior. + } + *builtin_data = params.release(); + return OkStatus(); +} + +absl::Status ParseUnpack(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data) { + CheckParsePointerParams(op, allocator, builtin_data); + + SafeBuiltinDataAllocator safe_allocator(allocator); + std::unique_ptr + params = safe_allocator.Allocate(); + TFL_MIGRATION_ENSURE(params != nullptr); + + const UnpackOptions* schema_params = op->builtin_options_as_UnpackOptions(); + + if (schema_params != nullptr) { + params->num = schema_params->num(); + params->axis = schema_params->axis(); + } else { + // TODO(b/157480169): We should either return kTfLiteError or fill in some + // reasonable defaults in the params struct. We are not doing so until we + // better understand the ramifications of changing the legacy behavior. + } + + *builtin_data = params.release(); + return OkStatus(); +} + +absl::Status ParseVarHandle(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data) { + CheckParsePointerParams(op, allocator, builtin_data); + + SafeBuiltinDataAllocator safe_allocator(allocator); + std::unique_ptr + params = safe_allocator.Allocate(); + TFL_MIGRATION_ENSURE(params != nullptr); + + const VarHandleOptions* schema_params = + op->builtin_options_as_VarHandleOptions(); + + if (schema_params != nullptr) { + if (schema_params->container()) { + params->container = schema_params->container()->c_str(); + } + if (schema_params->shared_name()) { + params->shared_name = schema_params->shared_name()->c_str(); + } + } else { + // TODO(b/157480169): We should either return kTfLiteError or fill in some + // reasonable defaults in the params struct. We are not doing so until we + // better understand the ramifications of changing the legacy behavior. + } + + *builtin_data = params.release(); + return OkStatus(); +} + +absl::Status ParseWhile(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data) { + CheckParsePointerParams(op, allocator, builtin_data); + + SafeBuiltinDataAllocator safe_allocator(allocator); + std::unique_ptr + params = safe_allocator.Allocate(); + TFL_MIGRATION_ENSURE(params != nullptr); + + const WhileOptions* schema_params = op->builtin_options_as_WhileOptions(); + + if (schema_params != nullptr) { + params->cond_subgraph_index = schema_params->cond_subgraph_index(); + params->body_subgraph_index = schema_params->body_subgraph_index(); + } else { + // TODO(b/157480169): We should either return kTfLiteError or fill in some + // reasonable defaults in the params struct. We are not doing so until we + // better understand the ramifications of changing the legacy behavior. + } + + *builtin_data = params.release(); + return OkStatus(); +} + +// We have this parse function instead of directly returning OkStatus() from the +// switch-case in ParseOpData because this function is used as part of the +// selective registration for the OpResolver implementation in micro. +absl::Status ParseZerosLike(const Operator*, BuiltinDataAllocator*, void**) { + return OkStatus(); +} + +// We have this parse function instead of directly returning OkStatus() from the +// switch-case in ParseOpData because this function is used as part of the +// selective registration for the OpResolver implementation in micro. +absl::Status ParseBitwiseXor(const Operator*, BuiltinDataAllocator*, void**) { + return OkStatus(); +} + +// We have this parse function instead of directly returning OkStatus() from the +// switch-case in ParseOpData because this function is used as part of the +// selective registration for the OpResolver implementation in micro. +absl::Status ParseRightShift(const Operator*, BuiltinDataAllocator*, void**) { + return OkStatus(); +} + +absl::Status ParseOpData(const Operator* op, BuiltinOperator op_type, + BuiltinDataAllocator* allocator, void** builtin_data) { +// TODO(b/145762662): It would be preferable to have the build graph for TF Lite +// Micro not have the ParseOpData function at all. This would require splitting +// the current file into two separate files, one of which defines the +// ParseOpData function and the other that defines the operator specific parse +// functions (e.g. ParseAdd). +// +// Such a split was attempted but was not worth the effort at the time because +// of the following reasons: +// * We could either duplicate the functions and the SafeBuiltinDataAllocator +// class in the anonymous namespace of this file, or attempt to make a common +// library with these helper functions and class. +// * Making a common library with a separate build target was not feasible as +// it introduced circular dependencies due to the ErrorReporter and a common +// .cc and .h within the same api build target the also cause circular +// dependencies due to the BuiltinDataAllocator class. +// * If all the builtin operators were to have their own parse functions, or we +// were ok with some amount of code duplication, then this split of the .cc +// files would be a lot more feasible. +#ifdef TF_LITE_STATIC_MEMORY + auto error_message = + "ParseOpData is unsupported on TfLiteMicro, please use the operator " + "specific parse functions (e.g. ParseAdd etc.).\n"; + LOG(ERROR) << error_message; + return absl::UnimplementedError(error_message); +#else + return ParseOpDataTfLite(op, op_type, allocator, builtin_data); +#endif +} + +} // namespace tflite_migration diff --git a/tensorflow/compiler/mlir/lite/core/api/flatbuffer_conversions.h b/tensorflow/compiler/mlir/lite/core/api/flatbuffer_conversions.h new file mode 100644 index 00000000000000..5a6aa526e0971b --- /dev/null +++ b/tensorflow/compiler/mlir/lite/core/api/flatbuffer_conversions.h @@ -0,0 +1,440 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_CORE_API_FLATBUFFER_CONVERSIONS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_CORE_API_FLATBUFFER_CONVERSIONS_H_ + +#include +#include + +#include "absl/status/status.h" +#include "tensorflow/compiler/mlir/lite/core/c/tflite_types.h" +#include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" + +namespace tflite_migration { + +using tflite::Operator; + +// Interface class for builtin data allocations. +class BuiltinDataAllocator { + public: + virtual void* Allocate(size_t size, size_t alignment_hint) = 0; + virtual void Deallocate(void* data) = 0; + + // Allocate a structure, but make sure it is a POD structure that doesn't + // require constructors to run. The reason we do this, is that Interpreter's C + // extension part will take ownership so destructors will not be run during + // deallocation. + template + T* AllocatePOD() { + // TODO(b/154346074): Change this to is_trivially_destructible when all + // platform targets support that properly. + static_assert(std::is_pod::value, "Builtin data structure must be POD."); + void* allocated_memory = this->Allocate(sizeof(T), alignof(T)); + return new (allocated_memory) T(); + } + + virtual ~BuiltinDataAllocator() = default; +}; + +// Parse the appropriate data out of the op. +// +// This handles builtin data explicitly as there are flatbuffer schemas. +// If it returns kTfLiteOk, it passes the data out with `builtin_data`. The +// calling function has to pass in an allocator object, and this allocator +// will be called to reserve space for the output data. If the calling +// function's allocator reserves memory on the heap, then it's the calling +// function's responsibility to free it. +// If it returns kTfLiteError, `builtin_data` will be `nullptr`. +absl::Status ParseOpData(const tflite::Operator* op, + tflite::BuiltinOperator op_type, + BuiltinDataAllocator* allocator, void** builtin_data); + +// Converts the tensor data type used in the flat buffer to the representation +// used by the runtime. +absl::Status ConvertTensorType(tflite::TensorType tensor_type, + TfLiteType* type); + +absl::Status ParseAbs(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseAdd(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseAddN(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseArgMax(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseArgMin(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseAssignVariable(const Operator* op, + + BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseBatchMatMul(const Operator* op, + BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseBatchToSpaceNd(const Operator* op, + + BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseBroadcastArgs(const Operator* op, + + BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseBroadcastTo(const Operator* op, + BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseCallOnce(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseCeil(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseCast(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseConcatenation(const Operator* op, + + BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseConv2D(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseCos(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseCumsum(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseDepthToSpace(const Operator* op, + + BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseDepthwiseConv2D(const Operator* op, + + BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseDequantize(const Operator* op, + BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseDiv(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseElu(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseEmbeddingLookup(const Operator* op, + + BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseEqual(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseExp(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseExpandDims(const Operator* op, + BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseFill(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseFloor(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseFloorDiv(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseFloorMod(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseFullyConnected(const Operator* op, + + BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseGather(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseGatherNd(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseGreater(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseGreaterEqual(const Operator* op, + + BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseHardSwish(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseIf(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseL2Normalization(const Operator* op, + + BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseLeakyRelu(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseLess(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseLessEqual(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseLog(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseLogicalAnd(const Operator* op, + BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseLogicalNot(const Operator* op, + BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseLogicalOr(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseLogistic(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseLogSoftmax(const Operator* op, + BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseLSTM(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseMaximum(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseMinimum(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseMirrorPad(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseMul(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseNeg(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseNotEqual(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParsePack(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParsePad(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParsePadV2(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParsePool(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParsePow(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParsePrelu(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseQuantize(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseReadVariable(const Operator* op, + + BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseReducer(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseRelu(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseRelu6(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseReshape(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseResizeBilinear(const Operator* op, + + BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseResizeNearestNeighbor(const Operator* op, + + BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseRound(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseRsqrt(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseSelectV2(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseShape(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseSin(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseSlice(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseSoftmax(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseSpaceToBatchNd(const Operator* op, + + BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseSpaceToDepth(const Operator* op, + + BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseSplit(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseSplitV(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseSqueeze(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseSqrt(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseSquare(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseSquaredDifference(const Operator* op, + + BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseStridedSlice(const Operator* op, + + BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseSub(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseSvdf(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseTanh(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseTranspose(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseTransposeConv(const Operator* op, + + BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseUnpack(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseUnidirectionalSequenceLSTM(const Operator* op, + + BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseVarHandle(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseWhile(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseZerosLike(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseBitwiseXor(const Operator* op, + BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseRightShift(const Operator* op, + BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseStablehloScatter(const Operator* op, + + BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseStablehloRngBitGenerator(const Operator* op, + + BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseStablehloGather(const Operator* op, + + BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseStablehloReduceWindow(const Operator* op, + + BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseStablehloPad(const Operator* op, + + BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseStablehloComposite(const Operator* op, + + BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseStablehloShiftLeft(const Operator* op, + BuiltinDataAllocator* allocator, + void** builtin_data); + +} // namespace tflite_migration + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_CORE_API_FLATBUFFER_CONVERSIONS_H_ diff --git a/tensorflow/compiler/mlir/lite/core/api/flatbuffer_conversions_test.cc b/tensorflow/compiler/mlir/lite/core/api/flatbuffer_conversions_test.cc new file mode 100644 index 00000000000000..ac6ba0243eaa7f --- /dev/null +++ b/tensorflow/compiler/mlir/lite/core/api/flatbuffer_conversions_test.cc @@ -0,0 +1,873 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/lite/core/api/flatbuffer_conversions.h" + +#include +#include +#include +#include +#include +#include + +#include +#include +#include "flatbuffers/buffer.h" // from @flatbuffers +#include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers +#include "tensorflow/compiler/mlir/lite/core/c/builtin_op_data.h" +#include "tensorflow/compiler/mlir/lite/core/c/tflite_types.h" +#include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" + +using testing::AllOf; +using testing::Each; +using testing::ElementsAre; +using testing::Eq; +using testing::HasSubstr; +using testing::StrEq; +using tflite::BuiltinOptions; +using tflite::BuiltinOptions2; +using tflite::BuiltinOptions_SqueezeOptions; +using tflite::CustomOptionsFormat_FLEXBUFFERS; + +namespace tflite_migration { +using tflite::ActivationFunctionType_RELU; +using tflite::BuiltinOperator_CONV_2D; +using tflite::BuiltinOperator_CUSTOM; +using tflite::BuiltinOperator_FULLY_CONNECTED; +using tflite::BuiltinOperator_RESHAPE; +using tflite::BuiltinOperator_SQUEEZE; +using tflite::BuiltinOperator_STABLEHLO_PAD; +using tflite::BuiltinOperator_STABLEHLO_REDUCE_WINDOW; +using tflite::BuiltinOptions2_StablehloPadOptions; +using tflite::BuiltinOptions2_StablehloReduceWindowOptions; +using tflite::BuiltinOptions_Conv2DOptions; +using tflite::BuiltinOptions_FullyConnectedOptions; +using tflite::BuiltinOptions_NONE; +using tflite::BuiltinOptions_ReshapeOptions; +using tflite::CreateReshapeOptions; +using tflite::CreateSqueezeOptions; +using tflite::CreateStablehloPadOptions; +using tflite::CreateStablehloReduceWindowOptions; +using tflite::FullyConnectedOptionsWeightsFormat; +using tflite::Padding_SAME; +using tflite::TensorType_BFLOAT16; +using tflite::TensorType_FLOAT16; +using tflite::TensorType_FLOAT32; +using tflite::TensorType_INT4; + +namespace { + +using std::string; + +// Used to determine how the op data parsing function creates its working space. +class MockDataAllocator : public BuiltinDataAllocator { + public: + MockDataAllocator() : is_allocated_(false) {} + void* Allocate(size_t size, size_t alignment_hint) override { + EXPECT_FALSE(is_allocated_); + const int max_size = kBufferSize; + EXPECT_LE(size, max_size); + is_allocated_ = true; + return buffer_; + } + void Deallocate(void* data) override { is_allocated_ = false; } + + private: + static constexpr int kBufferSize = 1024; + char buffer_[kBufferSize]; + bool is_allocated_; +}; + +} // namespace + +class FlatbufferConversionsTest : public ::testing::Test { + public: + const Operator* BuildTestOperator(BuiltinOptions op_type, + flatbuffers::Offset options) { + flatbuffers::Offset offset = + CreateOperatorDirect(builder_, 0, nullptr, nullptr, op_type, options, + nullptr, CustomOptionsFormat_FLEXBUFFERS, nullptr); + builder_.Finish(offset); + void* pointer = builder_.GetBufferPointer(); + return flatbuffers::GetRoot(pointer); + } + + const Operator* BuildTestOperator(BuiltinOptions2 op_type, + flatbuffers::Offset options) { + flatbuffers::Offset offset = CreateOperatorDirect( + builder_, /*opcode_index=*/0, /*inputs=*/nullptr, /*outputs=*/nullptr, + /*builtin_options_type=*/tflite::BuiltinOptions_NONE, + /*builtin_options=*/0, /*custom_options=*/nullptr, + /*custom_options_format=*/tflite::CustomOptionsFormat_FLEXBUFFERS, + /*mutating_variable_inputs=*/nullptr, /*intermediates=*/nullptr, + /*large_custom_options_offset=*/0, /*large_custom_options_size=*/0, + /*builtin_options_2_type=*/op_type, + /*builtin_options_2=*/options); + builder_.Finish(offset); + void* pointer = builder_.GetBufferPointer(); + return flatbuffers::GetRoot(pointer); + } + + protected: + MockDataAllocator mock_allocator_; + flatbuffers::FlatBufferBuilder builder_; +}; + +TEST_F(FlatbufferConversionsTest, ParseSqueezeAll) { + const Operator* op = BuildTestOperator( + BuiltinOptions_SqueezeOptions, CreateSqueezeOptions(builder_).Union()); + void* output_data = nullptr; + EXPECT_TRUE( + ParseOpData(op, BuiltinOperator_SQUEEZE, &mock_allocator_, &output_data) + .ok()); +} + +TEST_F(FlatbufferConversionsTest, ParseDynamicReshape) { + const Operator* op = BuildTestOperator( + BuiltinOptions_ReshapeOptions, CreateReshapeOptions(builder_).Union()); + void* output_data = nullptr; + EXPECT_TRUE( + ParseOpData(op, BuiltinOperator_RESHAPE, &mock_allocator_, &output_data) + .ok()); +} + +TEST_F(FlatbufferConversionsTest, TestParseOpDataConv) { + const Operator* conv_op = + BuildTestOperator(BuiltinOptions_Conv2DOptions, + CreateConv2DOptions(builder_, Padding_SAME, 1, 2, + ActivationFunctionType_RELU, 3, 4) + .Union()); + void* output_data = nullptr; + EXPECT_TRUE(ParseOpData(conv_op, BuiltinOperator_CONV_2D, &mock_allocator_, + &output_data) + .ok()); + EXPECT_NE(nullptr, output_data); + TfLiteConvParams* params = reinterpret_cast(output_data); + EXPECT_EQ(kTfLitePaddingSame, params->padding); + EXPECT_EQ(1, params->stride_width); + EXPECT_EQ(2, params->stride_height); + EXPECT_EQ(kTfLiteActRelu, params->activation); + EXPECT_EQ(3, params->dilation_width_factor); + EXPECT_EQ(4, params->dilation_height_factor); +} + +TEST_F(FlatbufferConversionsTest, ParseBadFullyConnected) { + const Operator* conv_op = BuildTestOperator( + BuiltinOptions_FullyConnectedOptions, + CreateFullyConnectedOptions( + builder_, ActivationFunctionType_RELU, + static_cast(-1), true) + .Union()); + void* output_data = nullptr; + EXPECT_FALSE(ParseOpData(conv_op, BuiltinOperator_FULLY_CONNECTED, + &mock_allocator_, &output_data) + .ok()); +} + +TEST_F(FlatbufferConversionsTest, TestParseOpDataCustom) { + const Operator* custom_op = + BuildTestOperator(BuiltinOptions_NONE, flatbuffers::Offset()); + void* output_data = nullptr; + EXPECT_TRUE(ParseOpData(custom_op, BuiltinOperator_CUSTOM, &mock_allocator_, + &output_data) + .ok()); + EXPECT_EQ(nullptr, output_data); +} + +TEST_F(FlatbufferConversionsTest, TestConvertTensorType) { + TfLiteType type; + EXPECT_TRUE(ConvertTensorType(TensorType_FLOAT32, &type).ok()); + EXPECT_EQ(kTfLiteFloat32, type); +} + +TEST_F(FlatbufferConversionsTest, TestConvertTensorTypeFloat16) { + TfLiteType type; + EXPECT_TRUE(ConvertTensorType(TensorType_FLOAT16, &type).ok()); + EXPECT_EQ(kTfLiteFloat16, type); +} + +TEST_F(FlatbufferConversionsTest, TestConvertTensorTypeBFloat16) { + TfLiteType type; + EXPECT_TRUE(ConvertTensorType(TensorType_BFLOAT16, &type).ok()); + EXPECT_EQ(kTfLiteBFloat16, type); +} + +TEST_F(FlatbufferConversionsTest, TestConvertTensorTypeInt4) { + TfLiteType type; + EXPECT_TRUE(ConvertTensorType(TensorType_INT4, &type).ok()); + EXPECT_EQ(kTfLiteInt4, type); +} + +class StablehloReduceWindowFlatbufferConversionsTest + : public FlatbufferConversionsTest { + public: + static constexpr int kMaxDims = + TFLITE_STABLEHLO_REDUCE_WINDOW_PARAMS_MAX_DIMENSION_COUNT; + static constexpr int64_t kValidValue = 5; + + auto ValidAttr() { + return builder_.CreateVector(std::vector(kMaxDims, kValidValue)); + } + + auto InvalidAttr() { + return builder_.CreateVector( + std::vector(kMaxDims + 1, kValidValue)); + } + + auto ValidPaddingAttr() { + return builder_.CreateVector( + std::vector(2 * kMaxDims, kValidValue)); + } + + auto InvalidPaddingAttr() { + return builder_.CreateVector( + std::vector(2 * kMaxDims + 1, kValidValue)); + } + + auto EmptyAttr() { return builder_.CreateVector({}); } +}; + +TEST_F(StablehloReduceWindowFlatbufferConversionsTest, Succeeds) { + const Operator* stablehlo_reduce_window_op = BuildTestOperator( + BuiltinOptions2_StablehloReduceWindowOptions, + CreateStablehloReduceWindowOptions( + builder_, + /*window_dimensions=*/builder_.CreateVector({1, 2}), + /*window_strides=*/builder_.CreateVector({3, 4}), + /*base_dilations=*/builder_.CreateVector({5, 6}), + /*window_dilations=*/builder_.CreateVector({7, 8}), + /*padding=*/builder_.CreateVector({9, 10, 11, 12}), + /*body_subgraph_index=*/13) + .Union()); + TfLiteStablehloReduceWindowParams* output_data = nullptr; + EXPECT_TRUE(ParseOpData(stablehlo_reduce_window_op, + BuiltinOperator_STABLEHLO_REDUCE_WINDOW, + &mock_allocator_, (void**)&output_data) + .ok()); + + EXPECT_THAT(std::make_tuple(output_data->window_dimensions, 2), + ElementsAre(1, 2)); + EXPECT_THAT(std::make_tuple(output_data->window_strides, 2), + ElementsAre(3, 4)); + EXPECT_THAT(std::make_tuple(output_data->base_dilations, 2), + ElementsAre(5, 6)); + EXPECT_THAT(std::make_tuple(output_data->window_dilations, 2), + ElementsAre(7, 8)); + EXPECT_THAT(std::make_tuple(output_data->padding, 4), + ElementsAre(9, 10, 11, 12)); + EXPECT_THAT(output_data->body_subgraph_index, Eq(13)); +} + +TEST_F(StablehloReduceWindowFlatbufferConversionsTest, + FailsWithNoWindowDimensions) { + TfLiteStablehloReduceWindowParams* output_data = nullptr; + auto status = ParseOpData( + BuildTestOperator( + BuiltinOptions2_StablehloReduceWindowOptions, + CreateStablehloReduceWindowOptions(builder_, + /*window_dimensions=*/0, + /*window_strides=*/ValidAttr(), + /*base_dilations=*/ValidAttr(), + /*window_dilations=*/ValidAttr(), + /*padding=*/ValidPaddingAttr(), + /*body_subgraph_index=*/13) + .Union()), + BuiltinOperator_STABLEHLO_REDUCE_WINDOW, &mock_allocator_, + (void**)&output_data); + EXPECT_FALSE(status.ok()); + EXPECT_THAT(status.message(), + HasSubstr("'window_dimensions' attribute is not optional for " + "'stablehlo.reduce_window' and cannot be empty.")); +} + +TEST_F(StablehloReduceWindowFlatbufferConversionsTest, + SucceedsWithNoWindowStrides) { + TfLiteStablehloReduceWindowParams* output_data = nullptr; + auto status = ParseOpData( + BuildTestOperator( + BuiltinOptions2_StablehloReduceWindowOptions, + CreateStablehloReduceWindowOptions(builder_, + /*window_dimensions=*/ValidAttr(), + /*window_strides=*/0, + /*base_dilations=*/ValidAttr(), + /*window_dilations=*/ValidAttr(), + /*padding=*/ValidPaddingAttr(), + /*body_subgraph_index=*/13) + .Union()), + BuiltinOperator_STABLEHLO_REDUCE_WINDOW, &mock_allocator_, + (void**)&output_data); + EXPECT_TRUE(status.ok()); + EXPECT_THAT(status.message(), StrEq("")); + EXPECT_THAT(std::make_tuple(output_data->window_dimensions, kMaxDims), + Each(kValidValue)); + EXPECT_THAT(std::make_tuple(output_data->window_strides, kMaxDims), Each(1)); + EXPECT_THAT(std::make_tuple(output_data->base_dilations, kMaxDims), + Each(kValidValue)); + EXPECT_THAT(std::make_tuple(output_data->window_dilations, kMaxDims), + Each(kValidValue)); + EXPECT_THAT(std::make_tuple(output_data->padding, 2 * kMaxDims), + Each(kValidValue)); + EXPECT_THAT(output_data->body_subgraph_index, Eq(13)); +} + +TEST_F(StablehloReduceWindowFlatbufferConversionsTest, + SucceedsWithNoBaseDilations) { + TfLiteStablehloReduceWindowParams* output_data = nullptr; + auto status = ParseOpData( + BuildTestOperator( + BuiltinOptions2_StablehloReduceWindowOptions, + CreateStablehloReduceWindowOptions(builder_, + /*window_dimensions=*/ValidAttr(), + /*window_strides=*/ValidAttr(), + /*base_dilations=*/0, + /*window_dilations=*/ValidAttr(), + /*padding=*/ValidPaddingAttr(), + /*body_subgraph_index=*/13) + .Union()), + BuiltinOperator_STABLEHLO_REDUCE_WINDOW, &mock_allocator_, + (void**)&output_data); + EXPECT_TRUE(status.ok()); + EXPECT_THAT(status.message(), StrEq("")); + EXPECT_THAT(std::make_tuple(output_data->window_dimensions, kMaxDims), + Each(kValidValue)); + EXPECT_THAT(std::make_tuple(output_data->window_strides, kMaxDims), + Each(kValidValue)); + EXPECT_THAT(std::make_tuple(output_data->base_dilations, kMaxDims), Each(1)); + EXPECT_THAT(std::make_tuple(output_data->window_dilations, kMaxDims), + Each(kValidValue)); + EXPECT_THAT(std::make_tuple(output_data->padding, 2 * kMaxDims), + Each(kValidValue)); + EXPECT_THAT(output_data->body_subgraph_index, Eq(13)); +} + +TEST_F(StablehloReduceWindowFlatbufferConversionsTest, + SucceedsWithNoWindowDilations) { + TfLiteStablehloReduceWindowParams* output_data = nullptr; + auto status = ParseOpData( + BuildTestOperator( + BuiltinOptions2_StablehloReduceWindowOptions, + CreateStablehloReduceWindowOptions(builder_, + /*window_dimensions=*/ValidAttr(), + /*window_strides=*/ValidAttr(), + /*base_dilations=*/ValidAttr(), + /*window_dilations=*/0, + /*padding=*/ValidPaddingAttr(), + /*body_subgraph_index=*/13) + .Union()), + BuiltinOperator_STABLEHLO_REDUCE_WINDOW, &mock_allocator_, + (void**)&output_data); + EXPECT_TRUE(status.ok()); + EXPECT_THAT(status.message(), StrEq("")); + EXPECT_THAT(std::make_tuple(output_data->window_dimensions, kMaxDims), + Each(kValidValue)); + EXPECT_THAT(std::make_tuple(output_data->window_strides, kMaxDims), + Each(kValidValue)); + EXPECT_THAT(std::make_tuple(output_data->base_dilations, kMaxDims), + Each(kValidValue)); + EXPECT_THAT(std::make_tuple(output_data->window_dilations, kMaxDims), + Each(1)); + EXPECT_THAT(std::make_tuple(output_data->padding, 2 * kMaxDims), + Each(kValidValue)); + EXPECT_THAT(output_data->body_subgraph_index, Eq(13)); +} + +TEST_F(StablehloReduceWindowFlatbufferConversionsTest, SucceedsWithNoPadding) { + TfLiteStablehloReduceWindowParams* output_data = nullptr; + auto status = ParseOpData( + BuildTestOperator( + BuiltinOptions2_StablehloReduceWindowOptions, + CreateStablehloReduceWindowOptions(builder_, + /*window_dimensions=*/ValidAttr(), + /*window_strides=*/ValidAttr(), + /*base_dilations=*/ValidAttr(), + /*window_dilations=*/ValidAttr(), + /*padding=*/0, + /*body_subgraph_index=*/13) + .Union()), + BuiltinOperator_STABLEHLO_REDUCE_WINDOW, &mock_allocator_, + (void**)&output_data); + EXPECT_TRUE(status.ok()); + EXPECT_THAT(status.message(), StrEq("")); + EXPECT_THAT(std::make_tuple(output_data->window_dimensions, kMaxDims), + Each(kValidValue)); + EXPECT_THAT(std::make_tuple(output_data->window_strides, kMaxDims), + Each(kValidValue)); + EXPECT_THAT(std::make_tuple(output_data->base_dilations, kMaxDims), + Each(kValidValue)); + EXPECT_THAT(std::make_tuple(output_data->window_dilations, kMaxDims), + Each(kValidValue)); + EXPECT_THAT(std::make_tuple(output_data->padding, 2 * kMaxDims), Each(0)); + EXPECT_THAT(output_data->body_subgraph_index, Eq(13)); +} + +TEST_F(StablehloReduceWindowFlatbufferConversionsTest, + FailsWithEmptyWindowDimensions) { + TfLiteStablehloReduceWindowParams* output_data = nullptr; + auto status = ParseOpData( + BuildTestOperator( + BuiltinOptions2_StablehloReduceWindowOptions, + CreateStablehloReduceWindowOptions(builder_, + /*window_dimensions=*/EmptyAttr(), + /*window_strides=*/ValidAttr(), + /*base_dilations=*/ValidAttr(), + /*window_dilations=*/ValidAttr(), + /*padding=*/ValidPaddingAttr(), + /*body_subgraph_index=*/13) + .Union()), + BuiltinOperator_STABLEHLO_REDUCE_WINDOW, &mock_allocator_, + (void**)&output_data); + EXPECT_FALSE(status.ok()); + EXPECT_THAT(status.message(), + HasSubstr("'window_dimensions' attribute is not optional for " + "'stablehlo.reduce_window' and cannot be empty.")); +} + +TEST_F(StablehloReduceWindowFlatbufferConversionsTest, + SucceedsWithEmptyWindowStrides) { + TfLiteStablehloReduceWindowParams* output_data = nullptr; + auto status = ParseOpData( + BuildTestOperator( + BuiltinOptions2_StablehloReduceWindowOptions, + CreateStablehloReduceWindowOptions(builder_, + /*window_dimensions=*/ValidAttr(), + /*window_strides=*/EmptyAttr(), + /*base_dilations=*/ValidAttr(), + /*window_dilations=*/ValidAttr(), + /*padding=*/ValidPaddingAttr(), + /*body_subgraph_index=*/13) + .Union()), + BuiltinOperator_STABLEHLO_REDUCE_WINDOW, &mock_allocator_, + (void**)&output_data); + EXPECT_TRUE(status.ok()); + EXPECT_THAT(status.message(), StrEq("")); + EXPECT_THAT(std::make_tuple(output_data->window_dimensions, kMaxDims), + Each(kValidValue)); + EXPECT_THAT(std::make_tuple(output_data->window_strides, kMaxDims), Each(1)); + EXPECT_THAT(std::make_tuple(output_data->base_dilations, kMaxDims), + Each(kValidValue)); + EXPECT_THAT(std::make_tuple(output_data->window_dilations, kMaxDims), + Each(kValidValue)); + EXPECT_THAT(std::make_tuple(output_data->padding, 2 * kMaxDims), + Each(kValidValue)); + EXPECT_THAT(output_data->body_subgraph_index, Eq(13)); +} + +TEST_F(StablehloReduceWindowFlatbufferConversionsTest, + SucceedsWithEmptyBaseDilations) { + TfLiteStablehloReduceWindowParams* output_data = nullptr; + auto status = ParseOpData( + BuildTestOperator( + BuiltinOptions2_StablehloReduceWindowOptions, + CreateStablehloReduceWindowOptions(builder_, + /*window_dimensions=*/ValidAttr(), + /*window_strides=*/ValidAttr(), + /*base_dilations=*/EmptyAttr(), + /*window_dilations=*/ValidAttr(), + /*padding=*/ValidPaddingAttr(), + /*body_subgraph_index=*/13) + .Union()), + BuiltinOperator_STABLEHLO_REDUCE_WINDOW, &mock_allocator_, + (void**)&output_data); + EXPECT_TRUE(status.ok()); + EXPECT_THAT(status.message(), StrEq("")); + EXPECT_THAT(std::make_tuple(output_data->window_dimensions, kMaxDims), + Each(kValidValue)); + EXPECT_THAT(std::make_tuple(output_data->window_strides, kMaxDims), + Each(kValidValue)); + EXPECT_THAT(std::make_tuple(output_data->base_dilations, kMaxDims), Each(1)); + EXPECT_THAT(std::make_tuple(output_data->window_dilations, kMaxDims), + Each(kValidValue)); + EXPECT_THAT(std::make_tuple(output_data->padding, 2 * kMaxDims), + Each(kValidValue)); + EXPECT_THAT(output_data->body_subgraph_index, Eq(13)); +} + +TEST_F(StablehloReduceWindowFlatbufferConversionsTest, + SucceedsWithEmptyWindowDilations) { + TfLiteStablehloReduceWindowParams* output_data = nullptr; + auto status = ParseOpData( + BuildTestOperator( + BuiltinOptions2_StablehloReduceWindowOptions, + CreateStablehloReduceWindowOptions(builder_, + /*window_dimensions=*/ValidAttr(), + /*window_strides=*/ValidAttr(), + /*base_dilations=*/ValidAttr(), + /*window_dilations=*/EmptyAttr(), + /*padding=*/ValidPaddingAttr(), + /*body_subgraph_index=*/13) + .Union()), + BuiltinOperator_STABLEHLO_REDUCE_WINDOW, &mock_allocator_, + (void**)&output_data); + EXPECT_TRUE(status.ok()); + EXPECT_THAT(status.message(), StrEq("")); + EXPECT_THAT(std::make_tuple(output_data->window_dimensions, kMaxDims), + Each(kValidValue)); + EXPECT_THAT(std::make_tuple(output_data->window_strides, kMaxDims), + Each(kValidValue)); + EXPECT_THAT(std::make_tuple(output_data->base_dilations, kMaxDims), + Each(kValidValue)); + EXPECT_THAT(std::make_tuple(output_data->window_dilations, kMaxDims), + Each(1)); + EXPECT_THAT(std::make_tuple(output_data->padding, 2 * kMaxDims), + Each(kValidValue)); + EXPECT_THAT(output_data->body_subgraph_index, Eq(13)); +} + +TEST_F(StablehloReduceWindowFlatbufferConversionsTest, + SucceedsWithEmptyPadding) { + TfLiteStablehloReduceWindowParams* output_data = nullptr; + auto status = ParseOpData( + BuildTestOperator( + BuiltinOptions2_StablehloReduceWindowOptions, + CreateStablehloReduceWindowOptions(builder_, + /*window_dimensions=*/ValidAttr(), + /*window_strides=*/ValidAttr(), + /*base_dilations=*/ValidAttr(), + /*window_dilations=*/ValidAttr(), + /*padding=*/EmptyAttr(), + /*body_subgraph_index=*/13) + .Union()), + BuiltinOperator_STABLEHLO_REDUCE_WINDOW, &mock_allocator_, + (void**)&output_data); + EXPECT_TRUE(status.ok()); + EXPECT_THAT(status.message(), StrEq("")); + EXPECT_THAT(std::make_tuple(output_data->window_dimensions, kMaxDims), + Each(kValidValue)); + EXPECT_THAT(std::make_tuple(output_data->window_strides, kMaxDims), + Each(kValidValue)); + EXPECT_THAT(std::make_tuple(output_data->base_dilations, kMaxDims), + Each(kValidValue)); + EXPECT_THAT(std::make_tuple(output_data->window_dilations, kMaxDims), + Each(kValidValue)); + EXPECT_THAT(std::make_tuple(output_data->padding, 2 * kMaxDims), Each(0)); + EXPECT_THAT(output_data->body_subgraph_index, Eq(13)); +} + +TEST_F(StablehloReduceWindowFlatbufferConversionsTest, + SucceedsWithParamsAtMaxDims) { + TfLiteStablehloReduceWindowParams* output_data = nullptr; + auto status = ParseOpData( + BuildTestOperator( + BuiltinOptions2_StablehloReduceWindowOptions, + CreateStablehloReduceWindowOptions(builder_, + /*window_dimensions=*/ValidAttr(), + /*window_strides=*/ValidAttr(), + /*base_dilations=*/ValidAttr(), + /*window_dilations=*/ValidAttr(), + /*padding=*/ValidPaddingAttr(), + /*body_subgraph_index=*/13) + .Union()), + BuiltinOperator_STABLEHLO_REDUCE_WINDOW, &mock_allocator_, + (void**)&output_data); + EXPECT_TRUE(status.ok()); + EXPECT_THAT(status.message(), StrEq("")); +} + +TEST_F(StablehloReduceWindowFlatbufferConversionsTest, + FailsWhenWindowDimensionsHasMoreThanMaxDims) { + TfLiteStablehloReduceWindowParams* output_data = nullptr; + auto status = ParseOpData( + BuildTestOperator(BuiltinOptions2_StablehloReduceWindowOptions, + CreateStablehloReduceWindowOptions( + builder_, + /*window_dimensions=*/InvalidAttr(), + /*window_strides=*/ValidAttr(), + /*base_dilations=*/ValidAttr(), + /*window_dilations=*/ValidAttr(), + /*padding=*/ValidPaddingAttr(), + /*body_subgraph_index=*/13) + .Union()), + BuiltinOperator_STABLEHLO_REDUCE_WINDOW, &mock_allocator_, + (void**)&output_data); + EXPECT_FALSE(status.ok()); + EXPECT_THAT(status.message(), + AllOf(HasSubstr("Found too many dimensions in the input array of " + "operation 'stablehlo.reduce_window'."), + HasSubstr("Check the 'window_dimensions' attribute."))); +} + +TEST_F(StablehloReduceWindowFlatbufferConversionsTest, + FailsWhenWindowStridesHasWrongDimCount) { + TfLiteStablehloReduceWindowParams* output_data = nullptr; + auto status = ParseOpData( + BuildTestOperator( + BuiltinOptions2_StablehloReduceWindowOptions, + CreateStablehloReduceWindowOptions(builder_, + /*window_dimensions=*/ValidAttr(), + /*window_strides=*/InvalidAttr(), + /*base_dilations=*/ValidAttr(), + /*window_dilations=*/ValidAttr(), + /*padding=*/ValidPaddingAttr(), + /*body_subgraph_index=*/13) + .Union()), + BuiltinOperator_STABLEHLO_REDUCE_WINDOW, &mock_allocator_, + (void**)&output_data); + EXPECT_FALSE(status.ok()); + EXPECT_THAT( + status.message(), + HasSubstr("'window_strides' attribute of 'stablehlo.reduce_window' does " + "not have the expected size")); +} + +TEST_F(StablehloReduceWindowFlatbufferConversionsTest, + FailsWhenBaseDilationsHasWrongDimCount) { + TfLiteStablehloReduceWindowParams* output_data = nullptr; + auto status = ParseOpData( + BuildTestOperator( + BuiltinOptions2_StablehloReduceWindowOptions, + CreateStablehloReduceWindowOptions(builder_, + /*window_dimensions=*/ValidAttr(), + /*window_strides=*/ValidAttr(), + /*base_dilations=*/InvalidAttr(), + /*window_dilations=*/ValidAttr(), + /*padding=*/ValidPaddingAttr(), + /*body_subgraph_index=*/13) + .Union()), + BuiltinOperator_STABLEHLO_REDUCE_WINDOW, &mock_allocator_, + (void**)&output_data); + EXPECT_FALSE(status.ok()); + EXPECT_THAT( + status.message(), + HasSubstr("'base_dilations' attribute of 'stablehlo.reduce_window' does " + "not have the expected size")); +} + +TEST_F(StablehloReduceWindowFlatbufferConversionsTest, + FailsWhenWindowDilationsHasWrongDimCount) { + TfLiteStablehloReduceWindowParams* output_data = nullptr; + auto status = ParseOpData( + BuildTestOperator( + BuiltinOptions2_StablehloReduceWindowOptions, + CreateStablehloReduceWindowOptions(builder_, + /*window_dimensions=*/ValidAttr(), + /*window_strides=*/ValidAttr(), + /*base_dilations=*/ValidAttr(), + /*window_dilations=*/InvalidAttr(), + /*padding=*/ValidPaddingAttr(), + /*body_subgraph_index=*/13) + .Union()), + BuiltinOperator_STABLEHLO_REDUCE_WINDOW, &mock_allocator_, + (void**)&output_data); + EXPECT_FALSE(status.ok()); + EXPECT_THAT( + status.message(), + HasSubstr( + "'window_dilations' attribute of 'stablehlo.reduce_window' does " + "not have the expected size")); +} + +TEST_F(StablehloReduceWindowFlatbufferConversionsTest, + FailsWhenPaddingHasWrongDimCount) { + TfLiteStablehloReduceWindowParams* output_data = nullptr; + auto status = ParseOpData( + BuildTestOperator( + BuiltinOptions2_StablehloReduceWindowOptions, + CreateStablehloReduceWindowOptions(builder_, + /*window_dimensions=*/ValidAttr(), + /*window_strides=*/ValidAttr(), + /*base_dilations=*/ValidAttr(), + /*window_dilations=*/ValidAttr(), + /*padding=*/InvalidPaddingAttr(), + /*body_subgraph_index=*/13) + .Union()), + BuiltinOperator_STABLEHLO_REDUCE_WINDOW, &mock_allocator_, + (void**)&output_data); + EXPECT_FALSE(status.ok()); + EXPECT_THAT(status.message(), + HasSubstr("'padding' attribute of 'stablehlo.reduce_window' does " + "not have the expected size")); +} + +TEST_F(StablehloReduceWindowFlatbufferConversionsTest, FailsWithWrongOptions) { + const Operator* stablehlo_reduce_window_op = + BuildTestOperator(BuiltinOptions2_StablehloReduceWindowOptions, 0); + TfLiteStablehloReduceWindowParams* output_data = nullptr; + auto status = ParseOpData(stablehlo_reduce_window_op, + BuiltinOperator_STABLEHLO_REDUCE_WINDOW, + &mock_allocator_, (void**)&output_data); + EXPECT_FALSE(status.ok()); + EXPECT_THAT( + status.message(), + HasSubstr( + "Could not get 'stablehlo.reduce_window' operation parameters.")); +} + +TEST_F(StablehloReduceWindowFlatbufferConversionsTest, DeathTests) { + const Operator* stablehlo_reduce_window_op = BuildTestOperator( + BuiltinOptions2_StablehloReduceWindowOptions, + CreateStablehloReduceWindowOptions( + builder_, /*window_dimensions=*/ValidAttr(), + /*window_strides=*/ValidAttr(), + /*base_dilations=*/ValidAttr(), + /*window_dilations=*/ValidAttr(), + /*padding=*/ValidPaddingAttr(), /*body_subgraph_index=*/13) + .Union()); + TfLiteStablehloReduceWindowParams* output_data = nullptr; +#ifdef NDEBUG + GTEST_SKIP(); +#endif + EXPECT_DEATH(ParseOpData(nullptr, BuiltinOperator_STABLEHLO_REDUCE_WINDOW, + &mock_allocator_, (void**)&output_data) + .IgnoreError(), + ""); + EXPECT_DEATH(ParseOpData(stablehlo_reduce_window_op, + BuiltinOperator_STABLEHLO_REDUCE_WINDOW, nullptr, + (void**)&output_data) + .IgnoreError(), + ""); + EXPECT_DEATH(ParseOpData(stablehlo_reduce_window_op, + BuiltinOperator_STABLEHLO_REDUCE_WINDOW, + &mock_allocator_, nullptr) + .IgnoreError(), + ""); +} + +class StablehloPadFlatbufferConversionsTest : public FlatbufferConversionsTest { + public: + static constexpr int kMaxDims = + TFLITE_STABLEHLO_PAD_PARAMS_MAX_DIMENSION_COUNT; + static constexpr int64_t kValidValue = 5; +}; + +TEST_F(StablehloPadFlatbufferConversionsTest, Succeeds) { + const Operator* stablehlo_pad_op = BuildTestOperator( + BuiltinOptions2_StablehloPadOptions, + CreateStablehloPadOptions( + builder_, + /*edge_padding_low=*/builder_.CreateVector({1, 0, -1}), + /*edge_padding_high=*/builder_.CreateVector({2, 0, -2}), + /*interior_padding=*/builder_.CreateVector({3, 0, 3})) + .Union()); + TfLiteStablehloPadParams* output_data = nullptr; + EXPECT_TRUE(ParseOpData(stablehlo_pad_op, BuiltinOperator_STABLEHLO_PAD, + &mock_allocator_, (void**)&output_data) + .ok()); + EXPECT_THAT(std::make_tuple(output_data->edge_padding_low, 3), + ElementsAre(1, 0, -1)); + EXPECT_THAT(std::make_tuple(output_data->edge_padding_high, 3), + ElementsAre(2, 0, -2)); + EXPECT_THAT(std::make_tuple(output_data->interior_padding, 3), + ElementsAre(3, 0, 3)); +} + +TEST_F(StablehloPadFlatbufferConversionsTest, FailsWithMissingLowPadding) { + const Operator* stablehlo_pad_op = BuildTestOperator( + BuiltinOptions2_StablehloPadOptions, + CreateStablehloPadOptions( + builder_, + /*edge_padding_low=*/0, + /*edge_padding_high=*/builder_.CreateVector({2, 0, -2}), + /*interior_padding=*/builder_.CreateVector({3, 0, 3})) + .Union()); + TfLiteStablehloPadParams* output_data = nullptr; + auto status = ParseOpData(stablehlo_pad_op, BuiltinOperator_STABLEHLO_PAD, + &mock_allocator_, (void**)&output_data); + EXPECT_FALSE(status.ok()); + EXPECT_THAT( + status.message(), + AllOf( + HasSubstr("Input array not provided for operation 'stablehlo.pad'."), + HasSubstr("Check the 'edge_padding_low' attribute."))); +} + +TEST_F(StablehloPadFlatbufferConversionsTest, FailsWithMissingHighPadding) { + const Operator* stablehlo_pad_op = BuildTestOperator( + BuiltinOptions2_StablehloPadOptions, + CreateStablehloPadOptions( + builder_, + /*edge_padding_low=*/builder_.CreateVector({1, 0, -1}), + /*edge_padding_high=*/0, + /*interior_padding=*/builder_.CreateVector({3, 0, 3})) + .Union()); + TfLiteStablehloPadParams* output_data = nullptr; + auto status = ParseOpData(stablehlo_pad_op, BuiltinOperator_STABLEHLO_PAD, + &mock_allocator_, (void**)&output_data); + EXPECT_FALSE(status.ok()); + EXPECT_THAT( + status.message(), + AllOf( + HasSubstr("Input array not provided for operation 'stablehlo.pad'."), + HasSubstr("Check the 'edge_padding_high' attribute."))); +} + +TEST_F(StablehloPadFlatbufferConversionsTest, FailsWithMissingInteriorPadding) { + const Operator* stablehlo_pad_op = BuildTestOperator( + BuiltinOptions2_StablehloPadOptions, + CreateStablehloPadOptions( + builder_, + /*edge_padding_low=*/builder_.CreateVector({1, 0, -1}), + /*edge_padding_high=*/builder_.CreateVector({2, 0, -2}), + /*interior_padding=*/0) + .Union()); + TfLiteStablehloPadParams* output_data = nullptr; + auto status = ParseOpData(stablehlo_pad_op, BuiltinOperator_STABLEHLO_PAD, + &mock_allocator_, (void**)&output_data); + EXPECT_FALSE(status.ok()); + EXPECT_THAT( + status.message(), + AllOf( + HasSubstr("Input array not provided for operation 'stablehlo.pad'."), + HasSubstr("Check the 'interior_padding' attribute."))); +} + +TEST_F(StablehloPadFlatbufferConversionsTest, FailsInconsistentSizes) { + const Operator* stablehlo_pad_op = BuildTestOperator( + BuiltinOptions2_StablehloPadOptions, + CreateStablehloPadOptions( + builder_, + /*edge_padding_low=*/builder_.CreateVector({1, 0, -1}), + /*edge_padding_high=*/builder_.CreateVector({2, 0, -2}), + /*interior_padding=*/builder_.CreateVector({3, 0, -3, 5})) + .Union()); + TfLiteStablehloPadParams* output_data = nullptr; + auto status = ParseOpData(stablehlo_pad_op, BuiltinOperator_STABLEHLO_PAD, + &mock_allocator_, (void**)&output_data); + EXPECT_FALSE(status.ok()); + EXPECT_THAT(status.message(), + HasSubstr("'stablehlo.pad' operation parameter array sizes are " + "not consistent.")); +} + +TEST_F(StablehloPadFlatbufferConversionsTest, FailsWithWrongOptions) { + const Operator* stablehlo_pad_op = BuildTestOperator(BuiltinOptions_NONE, 0); + TfLiteStablehloPadParams* output_data = nullptr; + auto status = ParseOpData(stablehlo_pad_op, BuiltinOperator_STABLEHLO_PAD, + &mock_allocator_, (void**)&output_data); + EXPECT_FALSE(status.ok()); + EXPECT_THAT(status.message(), + HasSubstr("Could not get 'stablehlo.pad' operation parameters.")); +} + +TEST_F(StablehloPadFlatbufferConversionsTest, DeathTests) { + const Operator* stablehlo_pad_op = BuildTestOperator(BuiltinOptions_NONE, 0); + TfLiteStablehloPadParams* output_data = nullptr; +#ifdef NDEBUG + GTEST_SKIP(); +#endif + EXPECT_DEATH(ParseOpData(nullptr, BuiltinOperator_STABLEHLO_PAD, + &mock_allocator_, (void**)&output_data) + .IgnoreError(), + ""); + EXPECT_DEATH(ParseOpData(stablehlo_pad_op, BuiltinOperator_STABLEHLO_PAD, + nullptr, (void**)&output_data) + .IgnoreError(), + ""); + EXPECT_DEATH(ParseOpData(stablehlo_pad_op, BuiltinOperator_STABLEHLO_PAD, + &mock_allocator_, nullptr) + .IgnoreError(), + ""); +} + +} // namespace tflite_migration diff --git a/tensorflow/compiler/mlir/lite/core/c/BUILD b/tensorflow/compiler/mlir/lite/core/c/BUILD index 3338e5b8940fca..2e07c72e817602 100644 --- a/tensorflow/compiler/mlir/lite/core/c/BUILD +++ b/tensorflow/compiler/mlir/lite/core/c/BUILD @@ -9,19 +9,30 @@ package( licenses = ["notice"], ) +exports_files( + srcs = [ + "builtin_op_data.h", + "tflite_types.h", + ], + visibility = [ + "//tensorflow/lite:__subpackages__", + ], +) + # LINT.IfChange(common) cc_library( name = "tflite_common", srcs = [], hdrs = [ "builtin_op_data.h", - "dimension_type.h", + "tflite_types.h", ], compatible_with = get_compatible_with_portable(), copts = tflite_copts(), visibility = [ "//tensorflow/compiler/mlir/lite:__subpackages__", "//tensorflow/compiler/mlir/quantization/tensorflow/utils:__pkg__", + "//tensorflow/lite/core/c:__subpackages__", ], alwayslink = 1, # Why?? TODO(b/161243354): eliminate this. ) diff --git a/tensorflow/compiler/mlir/lite/core/c/builtin_op_data.h b/tensorflow/compiler/mlir/lite/core/c/builtin_op_data.h index 7a67c630fe1ebd..836d80ab59eabf 100644 --- a/tensorflow/compiler/mlir/lite/core/c/builtin_op_data.h +++ b/tensorflow/compiler/mlir/lite/core/c/builtin_op_data.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -12,16 +12,67 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +/// WARNING: Users of TensorFlow Lite should not include this file directly, +/// but should instead include +/// "third_party/tensorflow/lite/c/builtin_op_data.h". +/// Only the TensorFlow Lite implementation itself should include this +/// file directly. #ifndef TENSORFLOW_COMPILER_MLIR_LITE_CORE_C_BUILTIN_OP_DATA_H_ #define TENSORFLOW_COMPILER_MLIR_LITE_CORE_C_BUILTIN_OP_DATA_H_ -// LINT.IfChange(enum) +#include // IWYU pragma: keep +#include +#include + +#include "tensorflow/compiler/mlir/lite/core/c/tflite_types.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +// TfLiteReshapeParams can't have dynamic data so we fix the maximum possible +// number of dimensions. +#define TFLITE_RESHAPE_PARAMS_MAX_DIMENSION_COUNT 8 +#define TFLITE_STABLEHLO_SCATTER_PARAMS_MAX_DIMENSION_COUNT 8 +#define TFLITE_STABLEHLO_GATHER_PARAMS_MAX_DIMENSION_COUNT 8 +#define TFLITE_STABLEHLO_REDUCE_WINDOW_PARAMS_MAX_DIMENSION_COUNT 8 +#define TFLITE_STABLEHLO_PAD_PARAMS_MAX_DIMENSION_COUNT 8 + +// TODO(aselle): Consider using "if this then that" for testing. + +// Useful placeholder to put in otherwise empty structs to avoid size warnings. +typedef struct { + char dummy; +} EmptyStructPlaceholder; + +// IMPORTANT: All new members of structs must be added at the end to ensure +// backwards compatibility. + +// Possible padding types (for convolutions) typedef enum { kTfLitePaddingUnknown = 0, kTfLitePaddingSame, kTfLitePaddingValid, } TfLitePadding; +typedef enum { + kTfLiteMirrorPaddingUnknown = 0, + kTfLiteMirrorPaddingReflect, + kTfLiteMirrorPaddingSymmetric, +} TfLiteMirrorPaddingMode; + +// TODO(b/130259536): We should move this out of builtin_op_data. +typedef struct { + int width; + int height; + int width_offset; + int height_offset; +} TfLitePaddingValues; + +typedef struct { + TfLiteMirrorPaddingMode mode; +} TfLiteMirrorPaddingParams; + // Possible fused activation functions. typedef enum { kTfLiteActNone = 0, @@ -32,16 +83,36 @@ typedef enum { kTfLiteActSignBit, kTfLiteActSigmoid, } TfLiteFusedActivation; -// LINT.ThenChange(//tensorflow/lite/core/c/builtin_op_data.h) -// LINT.IfChange(struct) -// TODO(b/130259536): We should move this out of builtin_op_data. typedef struct { - int width; - int height; - int width_offset; - int height_offset; -} TfLitePaddingValues; + // Parameters for CONV_2D version 1. + TfLitePadding padding; + int stride_width; + int stride_height; + TfLiteFusedActivation activation; + + // Parameters for CONV_2D version 2. + // Note: Version 2 supports dilation values not equal to 1. + int dilation_width_factor; + int dilation_height_factor; + + // Parameters for CONV_2D version 7 or above. + // Used to determine the default value for the quantized bias. + TfLiteType quantized_bias_type; +} TfLiteConvParams; + +typedef struct { + TfLitePadding padding; + int stride_width; + int stride_height; + int stride_depth; + int dilation_width_factor; + int dilation_height_factor; + int dilation_depth_factor; + TfLiteFusedActivation activation; +} TfLiteConv3DParams; + +typedef TfLiteConv3DParams TfLiteConv3DTransposeParams; typedef struct { TfLitePadding padding; @@ -54,6 +125,537 @@ typedef struct { TfLitePaddingValues padding; } computed; } TfLitePoolParams; -// LINT.ThenChange(//tensorflow/lite/core/c/builtin_op_data.h) + +typedef struct { + // Parameters for DepthwiseConv version 1 or above. + TfLitePadding padding; + int stride_width; + int stride_height; + // `depth_multiplier` is redundant. It's used by CPU kernels in + // TensorFlow 2.0 or below, but ignored in versions above. + // + // The information can be deduced from the shape of input and the shape of + // weights. Since the TFLiteConverter toolchain doesn't support partially + // specified shapes, relying on `depth_multiplier` stops us from supporting + // graphs with dynamic shape tensors. + // + // Note: Some of the delegates (e.g. NNAPI, GPU) are still relying on this + // field. + int depth_multiplier; + TfLiteFusedActivation activation; + // Parameters for DepthwiseConv version 2 or above. + int dilation_width_factor; + int dilation_height_factor; +} TfLiteDepthwiseConvParams; + +typedef struct { + int rank; + TfLiteFusedActivation activation; + + // Parameter for SVDF version 4. + bool asymmetric_quantize_inputs; +} TfLiteSVDFParams; + +typedef struct { + TfLiteFusedActivation activation; + + // Parameter for RNN version 3. + bool asymmetric_quantize_inputs; +} TfLiteRNNParams; + +typedef struct { + bool time_major; + TfLiteFusedActivation activation; + + // Parameter for Sequence RNN version 3. + bool asymmetric_quantize_inputs; +} TfLiteSequenceRNNParams; + +typedef struct { + bool time_major; + TfLiteFusedActivation activation; + bool merge_outputs; + + // Parameter for Bidirectional RNN version 3. + bool asymmetric_quantize_inputs; +} TfLiteBidirectionalSequenceRNNParams; + +typedef enum { + kTfLiteFullyConnectedWeightsFormatDefault = 0, + kTfLiteFullyConnectedWeightsFormatShuffled4x16Int8 = 1, +} TfLiteFullyConnectedWeightsFormat; + +typedef struct { + // Parameters for FullyConnected version 1 or above. + TfLiteFusedActivation activation; + + // Parameters for FullyConnected version 2 or above. + TfLiteFullyConnectedWeightsFormat weights_format; + + // Parameters for FullyConnected version 5 or above. + // If set to true, then the number of dimensions in the input and the output + // tensors are the same. Furthermore, all but the last dimension of the input + // and output shapes will be equal. + bool keep_num_dims; + + // Parameters for FullyConnected version 7 or above. + // If set to true and the weights are quantized, then non constant inputs + // are quantized at evaluation time with asymmetric quantization. + bool asymmetric_quantize_inputs; + + // Parameters for FullyConnected version 10 or above. + // Used to determine the default value for the quantized bias. + TfLiteType quantized_bias_type; +} TfLiteFullyConnectedParams; + +typedef enum { + kTfLiteLshProjectionUnknown = 0, + kTfLiteLshProjectionSparse = 1, + kTfLiteLshProjectionDense = 2, +} TfLiteLSHProjectionType; + +typedef struct { + TfLiteLSHProjectionType type; +} TfLiteLSHProjectionParams; + +typedef struct { + float beta; +} TfLiteSoftmaxParams; + +typedef struct { + int axis; + TfLiteFusedActivation activation; +} TfLiteConcatenationParams; + +typedef struct { + TfLiteFusedActivation activation; + // Parameter added for the version 4. + bool pot_scale_int16; +} TfLiteAddParams; + +typedef struct { + EmptyStructPlaceholder placeholder; +} TfLiteSpaceToBatchNDParams; + +typedef struct { + EmptyStructPlaceholder placeholder; +} TfLiteBatchToSpaceNDParams; + +typedef struct { + bool adj_x; + bool adj_y; + // Parameters for BatchMatMul version 4 or above. + // If set to true and the weights are quantized, then non constant inputs + // are quantized at evaluation time with asymmetric quantization. + bool asymmetric_quantize_inputs; +} TfLiteBatchMatMulParams; + +typedef struct { + TfLiteFusedActivation activation; +} TfLiteMulParams; + +typedef struct { + TfLiteFusedActivation activation; + // Parameter added for the version 5. + bool pot_scale_int16; +} TfLiteSubParams; + +typedef struct { + TfLiteFusedActivation activation; +} TfLiteDivParams; + +typedef struct { + TfLiteFusedActivation activation; +} TfLiteL2NormParams; + +typedef struct { + int radius; + float bias; + float alpha; + float beta; +} TfLiteLocalResponseNormParams; + +typedef enum { + kTfLiteLSTMFullKernel = 0, + kTfLiteLSTMBasicKernel +} TfLiteLSTMKernelType; + +typedef struct { + // Parameters for LSTM version 1. + TfLiteFusedActivation activation; + float cell_clip; + float proj_clip; + + // Parameters for LSTM version 2. + // kTfLiteLSTMBasicKernel is only supported in version 2 or above. + TfLiteLSTMKernelType kernel_type; + + // Parameters for LSTM version 4. + bool asymmetric_quantize_inputs; +} TfLiteLSTMParams; + +typedef struct { + // Parameters needed for the underlying LSTM. + TfLiteFusedActivation activation; + float cell_clip; + float proj_clip; + + // If set to true then the first dimension is time, otherwise batch. + bool time_major; + + // Parameter for unidirectional sequence RNN version 3. + bool asymmetric_quantize_inputs; + + // Parameter for unidirectional sequence RNN version 4. + bool diagonal_recurrent_tensors; +} TfLiteUnidirectionalSequenceLSTMParams; + +typedef struct { + // Parameters supported by version 1: + // Parameters inherited for the LSTM kernel. + TfLiteFusedActivation activation; + float cell_clip; + float proj_clip; + + // If true, store the outputs of both directions in the first output. + bool merge_outputs; + + // Parameters supported by version 2: + // If set to true then the first dimension is time, otherwise batch. + bool time_major; + + // Parameters supported by version 3: + // If set to true, then hybrid ops use asymmetric quantization for inputs. + bool asymmetric_quantize_inputs; +} TfLiteBidirectionalSequenceLSTMParams; + +typedef struct { + bool align_corners; + // half_pixel_centers assumes pixels are of half the actual dimensions, and + // yields more accurate resizes. Corresponds to the same argument for the + // original TensorFlow op in TF2.0. + bool half_pixel_centers; +} TfLiteResizeBilinearParams; + +typedef struct { + bool align_corners; + bool half_pixel_centers; +} TfLiteResizeNearestNeighborParams; + +typedef struct { + EmptyStructPlaceholder placeholder; +} TfLitePadParams; + +typedef struct { + EmptyStructPlaceholder placeholder; +} TfLitePadV2Params; + +typedef struct { + // These fields are only used in old models for backward compatibility. + // In the current implementation, we use the 2nd input of the op as the shape, + // and these fields are unused. + int32_t shape[TFLITE_RESHAPE_PARAMS_MAX_DIMENSION_COUNT]; + int num_dimensions; +} TfLiteReshapeParams; + +typedef struct { + int ngram_size; + int max_skip_size; + bool include_all_ngrams; +} TfLiteSkipGramParams; + +typedef struct { + int block_size; +} TfLiteSpaceToDepthParams; + +typedef struct { + int block_size; +} TfLiteDepthToSpaceParams; + +typedef struct { + TfLiteType in_data_type; + TfLiteType out_data_type; +} TfLiteCastParams; + +typedef enum { + kTfLiteCombinerTypeSum = 0, + kTfLiteCombinerTypeMean = 1, + kTfLiteCombinerTypeSqrtn = 2, +} TfLiteCombinerType; + +typedef struct { + TfLiteCombinerType combiner; +} TfLiteEmbeddingLookupSparseParams; + +typedef struct { + int axis; + int batch_dims; +} TfLiteGatherParams; + +typedef struct { + EmptyStructPlaceholder placeholder; +} TfLiteTransposeParams; + +typedef struct { + bool keep_dims; +} TfLiteReducerParams; + +typedef struct { + int num_splits; +} TfLiteSplitParams; + +typedef struct { + int num_splits; +} TfLiteSplitVParams; + +typedef struct { + // TODO(ahentz): We can't have dynamic data in this struct, at least not yet. + // For now we will fix the maximum possible number of dimensions. + int32_t squeeze_dims[8]; + int num_squeeze_dims; +} TfLiteSqueezeParams; + +typedef struct { + int begin_mask; + int end_mask; + int ellipsis_mask; + int new_axis_mask; + int shrink_axis_mask; + + // Parameters supported by version 8: + // If true, then the end tensor is an offset of the begin tensor. + bool offset; +} TfLiteStridedSliceParams; + +typedef struct { + TfLiteType output_type; +} TfLiteArgMaxParams; + +typedef struct { + TfLiteType output_type; +} TfLiteArgMinParams; + +typedef struct { + // Parameters supported by version 1: + TfLitePadding padding; + int stride_width; + int stride_height; + + // Parameters supported by version 4: + TfLiteFusedActivation activation; + + // Parameters for TransposeConv version 5 or above. + // Used to determine the default value for the quantized bias. + TfLiteType quantized_bias_type; +} TfLiteTransposeConvParams; + +typedef struct { + bool validate_indices; +} TfLiteSparseToDenseParams; + +typedef struct { + TfLiteType out_type; +} TfLiteShapeParams; + +typedef struct { + EmptyStructPlaceholder placeholder; +} TfLiteRankParams; + +typedef struct { + // Parameters supported by version 1: + float min; + float max; + int num_bits; + + // Parameters supported by version 2: + bool narrow_range; +} TfLiteFakeQuantParams; + +typedef struct { + int values_count; + int axis; +} TfLitePackParams; + +typedef struct { + int axis; +} TfLiteOneHotParams; + +typedef struct { + int num; + int axis; +} TfLiteUnpackParams; + +typedef struct { + float alpha; +} TfLiteLeakyReluParams; + +typedef struct { + TfLiteType index_out_type; +} TfLiteUniqueParams; + +typedef struct { + int seq_dim; + int batch_dim; +} TfLiteReverseSequenceParams; + +typedef struct { + EmptyStructPlaceholder placeholder; +} TfLiteMatrixDiagParams; + +typedef struct { + EmptyStructPlaceholder placeholder; +} TfLiteMatrixSetDiagParams; + +typedef struct { + int then_subgraph_index; + int else_subgraph_index; +} TfLiteIfParams; + +typedef struct { + int cond_subgraph_index; + int body_subgraph_index; +} TfLiteWhileParams; + +typedef struct { + bool exclusive; + bool reverse; +} TfLiteCumsumParams; + +typedef struct { + int init_subgraph_index; +} TfLiteCallOnceParams; + +typedef struct { + int table_id; + TfLiteType key_dtype; + TfLiteType value_dtype; +} TfLiteHashtableParams; + +typedef struct { + const char* container; + const char* shared_name; +} TfLiteVarHandleParams; + +typedef struct { + int seed; + int seed2; +} TfLiteRandomParams; + +typedef struct { + int num_boundaries; + // This points to the memory stored in the model (flatbuffer), + // and is not owned. + const float* boundaries; +} TfLiteBucketizeParams; + +typedef struct { + bool approximate; +} TfLiteGeluParams; + +typedef struct { + int64_t dimension; +} TfLiteStablehloConcatenateParams; + +typedef struct { + // See the stablehlo spec for the explanation of the attributes: + // https://github.com/openxla/stablehlo/blob/main/docs/spec.md#scatter + bool indices_are_sorted; + int64_t + update_window_dims[TFLITE_STABLEHLO_SCATTER_PARAMS_MAX_DIMENSION_COUNT]; + int num_update_window_dims; + int64_t + inserted_window_dims[TFLITE_STABLEHLO_SCATTER_PARAMS_MAX_DIMENSION_COUNT]; + int num_inserted_window_dims; + int64_t scatter_dims_to_operand_dims + [TFLITE_STABLEHLO_SCATTER_PARAMS_MAX_DIMENSION_COUNT]; + int num_scatter_dims_to_operand_dims; + int64_t index_vector_dim; + bool unique_indices; + int update_computation_subgraph_index; +} TfLiteStablehloScatterParams; + +typedef enum { + kTfLiteRngAlgorithmUnknown = 0, + // An algorithm auto-selected by the system according to device type. + kTfLiteRngAlgorithmDefault, + // The Philox algorithm, as described in paper + // ['Parallel Random Numbers: As Easy as 1, 2, 3'] + // (https://www.thesalmons.org/john/random123/papers/random123sc11.pdf) + kTfLiteRngAlgorithmPhilox, + // The ThreeFry algorithm, as described in paper + // ['Parallel Random Numbers: As Easy as 1, 2, 3'] + // (https://www.thesalmons.org/john/random123/papers/random123sc11.pdf) + kTfLiteRngAlgorithmThreefry, +} TfLiteRngAlgorithm; + +typedef struct { + TfLiteRngAlgorithm algorithm; +} TfLiteStablehloRngBitGeneratorParams; + +typedef struct { + // See the stablehlo spec for the explanation of the attributes: + // https://github.com/openxla/stablehlo/blob/main/docs/spec.md#gather + int64_t offset_dims[TFLITE_STABLEHLO_GATHER_PARAMS_MAX_DIMENSION_COUNT]; + int num_offset_dims; + int64_t + collapsed_slice_dims[TFLITE_STABLEHLO_GATHER_PARAMS_MAX_DIMENSION_COUNT]; + int num_collapsed_slice_dims; + int64_t start_index_map[TFLITE_STABLEHLO_GATHER_PARAMS_MAX_DIMENSION_COUNT]; + int num_start_index_map; + int64_t index_vector_dim; + int64_t slice_sizes[TFLITE_STABLEHLO_GATHER_PARAMS_MAX_DIMENSION_COUNT]; + int num_slice_sizes; + bool indices_are_sorted; +} TfLiteStablehloGatherParams; + +typedef struct { + // See the stablehlo spec for the explanation of the attributes: + // https://github.com/openxla/stablehlo/blob/main/docs/spec.md#reduce_window + int64_t window_dimensions + [TFLITE_STABLEHLO_REDUCE_WINDOW_PARAMS_MAX_DIMENSION_COUNT]; + int64_t + window_strides[TFLITE_STABLEHLO_REDUCE_WINDOW_PARAMS_MAX_DIMENSION_COUNT]; + int64_t + base_dilations[TFLITE_STABLEHLO_REDUCE_WINDOW_PARAMS_MAX_DIMENSION_COUNT]; + int64_t window_dilations + [TFLITE_STABLEHLO_REDUCE_WINDOW_PARAMS_MAX_DIMENSION_COUNT]; + int64_t + padding[2 * TFLITE_STABLEHLO_REDUCE_WINDOW_PARAMS_MAX_DIMENSION_COUNT]; + int body_subgraph_index; +} TfLiteStablehloReduceWindowParams; + +enum TfLiteReduceWindowFunction { + TfLiteReduceWindowFunctionUnsupported, + TfLiteReduceWindowFunctionAdd, + TfLiteReduceWindowFunctionMul, + TfLiteReduceWindowFunctionMin, + TfLiteReduceWindowFunctionMax, + TfLiteReduceWindowFunctionAll, + TfLiteReduceWindowFunctionAny +}; + +typedef struct { + enum TfLiteReduceWindowFunction reduce_function; +} TfLiteReduceWindowParams; + +typedef struct { + // See the stablehlo spec for the explanation of the attributes: + // https://github.com/openxla/stablehlo/blob/main/docs/spec.md#pad + int64_t edge_padding_low[TFLITE_STABLEHLO_PAD_PARAMS_MAX_DIMENSION_COUNT]; + int64_t edge_padding_high[TFLITE_STABLEHLO_PAD_PARAMS_MAX_DIMENSION_COUNT]; + int64_t interior_padding[TFLITE_STABLEHLO_PAD_PARAMS_MAX_DIMENSION_COUNT]; +} TfLiteStablehloPadParams; + +typedef struct { + const char* name; + int32_t subgraph_index; + int32_t version; + const uint8_t* attributes; + size_t attributes_size; +} TfLiteStablehloCompositeParams; + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus #endif // TENSORFLOW_COMPILER_MLIR_LITE_CORE_C_BUILTIN_OP_DATA_H_ diff --git a/tensorflow/compiler/mlir/lite/core/c/dimension_type.h b/tensorflow/compiler/mlir/lite/core/c/dimension_type.h deleted file mode 100644 index fd2c6122897065..00000000000000 --- a/tensorflow/compiler/mlir/lite/core/c/dimension_type.h +++ /dev/null @@ -1,38 +0,0 @@ -/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_MLIR_LITE_CORE_C_DIMENSION_TYPE_H_ -#define TENSORFLOW_COMPILER_MLIR_LITE_CORE_C_DIMENSION_TYPE_H_ - -// LINT.IfChange - -#ifdef __cplusplus -extern "C" { -#endif // __cplusplus - - -/// Storage format of each dimension in a sparse tensor. -typedef enum TfLiteDimensionType { - kTfLiteDimDense = 0, - kTfLiteDimSparseCSR, -} TfLiteDimensionType; - -#ifdef __cplusplus -} // extern "C" - -#endif // __cplusplus -#endif // TENSORFLOW_COMPILER_MLIR_LITE_CORE_C_DIMENSION_TYPE_H_ - -// LINT.ThenChange(//tensorflow/lite/core/c/common.h) diff --git a/tensorflow/compiler/mlir/lite/core/c/tflite_types.h b/tensorflow/compiler/mlir/lite/core/c/tflite_types.h new file mode 100644 index 00000000000000..6006b2d3c2ee5d --- /dev/null +++ b/tensorflow/compiler/mlir/lite/core/c/tflite_types.h @@ -0,0 +1,70 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_CORE_C_TFLITE_TYPES_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_CORE_C_TFLITE_TYPES_H_ + +#include + +#ifdef __cplusplus +extern "C" { +#endif + +/// Types supported by tensor +// LINT.IfChange +typedef enum { + kTfLiteNoType = 0, + kTfLiteFloat32 = 1, + kTfLiteInt32 = 2, + kTfLiteUInt8 = 3, + kTfLiteInt64 = 4, + kTfLiteString = 5, + kTfLiteBool = 6, + kTfLiteInt16 = 7, + kTfLiteComplex64 = 8, + kTfLiteInt8 = 9, + kTfLiteFloat16 = 10, + kTfLiteFloat64 = 11, + kTfLiteComplex128 = 12, + kTfLiteUInt64 = 13, + kTfLiteResource = 14, + kTfLiteVariant = 15, + kTfLiteUInt32 = 16, + kTfLiteUInt16 = 17, + kTfLiteInt4 = 18, + kTfLiteBFloat16 = 19, +} TfLiteType; +// LINT.ThenChange(//tensorflow/lite/profiling/proto/model_runtime_info.proto:EdgeDataType) + +/// Legacy. Will be deprecated in favor of `TfLiteAffineQuantization`. +/// If per-layer quantization is specified this field will still be populated in +/// addition to `TfLiteAffineQuantization`. +/// Parameters for asymmetric quantization. Quantized values can be converted +/// back to float using: `real_value = scale * (quantized_value - zero_point)` +typedef struct TfLiteQuantizationParams { + float scale; + int32_t zero_point; +} TfLiteQuantizationParams; + +/// Storage format of each dimension in a sparse tensor. +typedef enum TfLiteDimensionType { + kTfLiteDimDense = 0, + kTfLiteDimSparseCSR, +} TfLiteDimensionType; + +#ifdef __cplusplus +} // extern C +#endif + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_CORE_C_TFLITE_TYPES_H_ diff --git a/tensorflow/compiler/mlir/lite/kernels/internal/utils/sparsity_format_converter.cc b/tensorflow/compiler/mlir/lite/kernels/internal/utils/sparsity_format_converter.cc index 4a28c1474e9be8..e5db23a8831872 100644 --- a/tensorflow/compiler/mlir/lite/kernels/internal/utils/sparsity_format_converter.cc +++ b/tensorflow/compiler/mlir/lite/kernels/internal/utils/sparsity_format_converter.cc @@ -20,7 +20,7 @@ limitations under the License. #include #include "Eigen/Core" // from @eigen_archive -#include "tensorflow/compiler/mlir/lite/core/c/dimension_type.h" +#include "tensorflow/compiler/mlir/lite/core/c/tflite_types.h" namespace tflite_migration { namespace internal { diff --git a/tensorflow/compiler/mlir/lite/kernels/internal/utils/sparsity_format_converter.h b/tensorflow/compiler/mlir/lite/kernels/internal/utils/sparsity_format_converter.h index 12b54502b46369..56ba7181098c79 100644 --- a/tensorflow/compiler/mlir/lite/kernels/internal/utils/sparsity_format_converter.h +++ b/tensorflow/compiler/mlir/lite/kernels/internal/utils/sparsity_format_converter.h @@ -18,7 +18,7 @@ limitations under the License. #include #include "Eigen/Core" // from @eigen_archive -#include "tensorflow/compiler/mlir/lite/core/c/dimension_type.h" +#include "tensorflow/compiler/mlir/lite/core/c/tflite_types.h" namespace tflite_migration { namespace internal { diff --git a/tensorflow/lite/c/BUILD b/tensorflow/lite/c/BUILD index f1664849f36e50..f7032741c62060 100644 --- a/tensorflow/lite/c/BUILD +++ b/tensorflow/lite/c/BUILD @@ -416,6 +416,7 @@ filegroup( "c_api.h", "c_api_types.h", "common.h", + "//tensorflow/compiler/mlir/lite/core/c:tflite_types.h", ], ) diff --git a/tensorflow/lite/core/c/BUILD b/tensorflow/lite/core/c/BUILD index 00a1a27ec6d819..efbdc19755744f 100644 --- a/tensorflow/lite/core/c/BUILD +++ b/tensorflow/lite/core/c/BUILD @@ -284,6 +284,9 @@ tflite_cc_library_with_c_headers_test( "//tensorflow/compiler/mlir/lite/experimental/lrt:__subpackages__", "//tensorflow/lite:__subpackages__", ] + c_api_visibility_allowlist(), + deps = [ + "//tensorflow/compiler/mlir/lite/core/c:tflite_common", + ], ) # Test the C extension API code. @@ -338,6 +341,7 @@ tflite_cc_library_with_c_headers_test( visibility = ["//tensorflow/lite:__subpackages__"] + common_header_visibility_allowlist(), deps = [ ":c_api_types", + "//tensorflow/compiler/mlir/lite/core/c:tflite_common", "//tensorflow/lite:tflite_kernel_use_xnnpack_optional", ] + select({ "//tensorflow/lite:tensorflow_profiler_config": [ diff --git a/tensorflow/lite/core/c/builtin_op_data.h b/tensorflow/lite/core/c/builtin_op_data.h index e1428e72307134..cfe3d825a7fa2a 100644 --- a/tensorflow/lite/core/c/builtin_op_data.h +++ b/tensorflow/lite/core/c/builtin_op_data.h @@ -20,642 +20,7 @@ limitations under the License. #ifndef TENSORFLOW_LITE_CORE_C_BUILTIN_OP_DATA_H_ #define TENSORFLOW_LITE_CORE_C_BUILTIN_OP_DATA_H_ -#include -#include -#include - -#include "tensorflow/lite/core/c/common.h" - -#ifdef __cplusplus -extern "C" { -#endif // __cplusplus - -// TfLiteReshapeParams can't have dynamic data so we fix the maximum possible -// number of dimensions. -#define TFLITE_RESHAPE_PARAMS_MAX_DIMENSION_COUNT 8 -#define TFLITE_STABLEHLO_SCATTER_PARAMS_MAX_DIMENSION_COUNT 8 -#define TFLITE_STABLEHLO_GATHER_PARAMS_MAX_DIMENSION_COUNT 8 -#define TFLITE_STABLEHLO_REDUCE_WINDOW_PARAMS_MAX_DIMENSION_COUNT 8 -#define TFLITE_STABLEHLO_PAD_PARAMS_MAX_DIMENSION_COUNT 8 - -// TODO(aselle): Consider using "if this then that" for testing. - -// Useful placeholder to put in otherwise empty structs to avoid size warnings. -typedef struct { - char dummy; -} EmptyStructPlaceholder; - -// IMPORTANT: All new members of structs must be added at the end to ensure -// backwards compatibility. - -// Possible padding types (for convolutions) -typedef enum { - kTfLitePaddingUnknown = 0, - kTfLitePaddingSame, - kTfLitePaddingValid, -} TfLitePadding; - -typedef enum { - kTfLiteMirrorPaddingUnknown = 0, - kTfLiteMirrorPaddingReflect, - kTfLiteMirrorPaddingSymmetric, -} TfLiteMirrorPaddingMode; - -// TODO(b/130259536): We should move this out of builtin_op_data. -typedef struct { - int width; - int height; - int width_offset; - int height_offset; -} TfLitePaddingValues; - -typedef struct { - TfLiteMirrorPaddingMode mode; -} TfLiteMirrorPaddingParams; - -// Possible fused activation functions. -typedef enum { - kTfLiteActNone = 0, - kTfLiteActRelu, - kTfLiteActReluN1To1, // min(max(-1, x), 1) - kTfLiteActRelu6, // min(max(0, x), 6) - kTfLiteActTanh, - kTfLiteActSignBit, - kTfLiteActSigmoid, -} TfLiteFusedActivation; - -typedef struct { - // Parameters for CONV_2D version 1. - TfLitePadding padding; - int stride_width; - int stride_height; - TfLiteFusedActivation activation; - - // Parameters for CONV_2D version 2. - // Note: Version 2 supports dilation values not equal to 1. - int dilation_width_factor; - int dilation_height_factor; - - // Parameters for CONV_2D version 7 or above. - // Used to determine the default value for the quantized bias. - TfLiteType quantized_bias_type; -} TfLiteConvParams; - -typedef struct { - TfLitePadding padding; - int stride_width; - int stride_height; - int stride_depth; - int dilation_width_factor; - int dilation_height_factor; - int dilation_depth_factor; - TfLiteFusedActivation activation; -} TfLiteConv3DParams; - -typedef TfLiteConv3DParams TfLiteConv3DTransposeParams; - -typedef struct { - TfLitePadding padding; - int stride_width; - int stride_height; - int filter_width; - int filter_height; - TfLiteFusedActivation activation; - struct { - TfLitePaddingValues padding; - } computed; -} TfLitePoolParams; - -typedef struct { - // Parameters for DepthwiseConv version 1 or above. - TfLitePadding padding; - int stride_width; - int stride_height; - // `depth_multiplier` is redundant. It's used by CPU kernels in - // TensorFlow 2.0 or below, but ignored in versions above. - // - // The information can be deduced from the shape of input and the shape of - // weights. Since the TFLiteConverter toolchain doesn't support partially - // specified shapes, relying on `depth_multiplier` stops us from supporting - // graphs with dynamic shape tensors. - // - // Note: Some of the delegates (e.g. NNAPI, GPU) are still relying on this - // field. - int depth_multiplier; - TfLiteFusedActivation activation; - // Parameters for DepthwiseConv version 2 or above. - int dilation_width_factor; - int dilation_height_factor; -} TfLiteDepthwiseConvParams; - -typedef struct { - int rank; - TfLiteFusedActivation activation; - - // Parameter for SVDF version 4. - bool asymmetric_quantize_inputs; -} TfLiteSVDFParams; - -typedef struct { - TfLiteFusedActivation activation; - - // Parameter for RNN version 3. - bool asymmetric_quantize_inputs; -} TfLiteRNNParams; - -typedef struct { - bool time_major; - TfLiteFusedActivation activation; - - // Parameter for Sequence RNN version 3. - bool asymmetric_quantize_inputs; -} TfLiteSequenceRNNParams; - -typedef struct { - bool time_major; - TfLiteFusedActivation activation; - bool merge_outputs; - - // Parameter for Bidirectional RNN version 3. - bool asymmetric_quantize_inputs; -} TfLiteBidirectionalSequenceRNNParams; - -typedef enum { - kTfLiteFullyConnectedWeightsFormatDefault = 0, - kTfLiteFullyConnectedWeightsFormatShuffled4x16Int8 = 1, -} TfLiteFullyConnectedWeightsFormat; - -typedef struct { - // Parameters for FullyConnected version 1 or above. - TfLiteFusedActivation activation; - - // Parameters for FullyConnected version 2 or above. - TfLiteFullyConnectedWeightsFormat weights_format; - - // Parameters for FullyConnected version 5 or above. - // If set to true, then the number of dimensions in the input and the output - // tensors are the same. Furthermore, all but the last dimension of the input - // and output shapes will be equal. - bool keep_num_dims; - - // Parameters for FullyConnected version 7 or above. - // If set to true and the weights are quantized, then non constant inputs - // are quantized at evaluation time with asymmetric quantization. - bool asymmetric_quantize_inputs; - - // Parameters for FullyConnected version 10 or above. - // Used to determine the default value for the quantized bias. - TfLiteType quantized_bias_type; -} TfLiteFullyConnectedParams; - -typedef enum { - kTfLiteLshProjectionUnknown = 0, - kTfLiteLshProjectionSparse = 1, - kTfLiteLshProjectionDense = 2, -} TfLiteLSHProjectionType; - -typedef struct { - TfLiteLSHProjectionType type; -} TfLiteLSHProjectionParams; - -typedef struct { - float beta; -} TfLiteSoftmaxParams; - -typedef struct { - int axis; - TfLiteFusedActivation activation; -} TfLiteConcatenationParams; - -typedef struct { - TfLiteFusedActivation activation; - // Parameter added for the version 4. - bool pot_scale_int16; -} TfLiteAddParams; - -typedef struct { - EmptyStructPlaceholder placeholder; -} TfLiteSpaceToBatchNDParams; - -typedef struct { - EmptyStructPlaceholder placeholder; -} TfLiteBatchToSpaceNDParams; - -typedef struct { - bool adj_x; - bool adj_y; - // Parameters for BatchMatMul version 4 or above. - // If set to true and the weights are quantized, then non constant inputs - // are quantized at evaluation time with asymmetric quantization. - bool asymmetric_quantize_inputs; -} TfLiteBatchMatMulParams; - -typedef struct { - TfLiteFusedActivation activation; -} TfLiteMulParams; - -typedef struct { - TfLiteFusedActivation activation; - // Parameter added for the version 5. - bool pot_scale_int16; -} TfLiteSubParams; - -typedef struct { - TfLiteFusedActivation activation; -} TfLiteDivParams; - -typedef struct { - TfLiteFusedActivation activation; -} TfLiteL2NormParams; - -typedef struct { - int radius; - float bias; - float alpha; - float beta; -} TfLiteLocalResponseNormParams; - -typedef enum { - kTfLiteLSTMFullKernel = 0, - kTfLiteLSTMBasicKernel -} TfLiteLSTMKernelType; - -typedef struct { - // Parameters for LSTM version 1. - TfLiteFusedActivation activation; - float cell_clip; - float proj_clip; - - // Parameters for LSTM version 2. - // kTfLiteLSTMBasicKernel is only supported in version 2 or above. - TfLiteLSTMKernelType kernel_type; - - // Parameters for LSTM version 4. - bool asymmetric_quantize_inputs; -} TfLiteLSTMParams; - -typedef struct { - // Parameters needed for the underlying LSTM. - TfLiteFusedActivation activation; - float cell_clip; - float proj_clip; - - // If set to true then the first dimension is time, otherwise batch. - bool time_major; - - // Parameter for unidirectional sequence RNN version 3. - bool asymmetric_quantize_inputs; - - // Parameter for unidirectional sequence RNN version 4. - bool diagonal_recurrent_tensors; -} TfLiteUnidirectionalSequenceLSTMParams; - -typedef struct { - // Parameters supported by version 1: - // Parameters inherited for the LSTM kernel. - TfLiteFusedActivation activation; - float cell_clip; - float proj_clip; - - // If true, store the outputs of both directions in the first output. - bool merge_outputs; - - // Parameters supported by version 2: - // If set to true then the first dimension is time, otherwise batch. - bool time_major; - - // Parameters supported by version 3: - // If set to true, then hybrid ops use asymmetric quantization for inputs. - bool asymmetric_quantize_inputs; -} TfLiteBidirectionalSequenceLSTMParams; - -typedef struct { - bool align_corners; - // half_pixel_centers assumes pixels are of half the actual dimensions, and - // yields more accurate resizes. Corresponds to the same argument for the - // original TensorFlow op in TF2.0. - bool half_pixel_centers; -} TfLiteResizeBilinearParams; - -typedef struct { - bool align_corners; - bool half_pixel_centers; -} TfLiteResizeNearestNeighborParams; - -typedef struct { - EmptyStructPlaceholder placeholder; -} TfLitePadParams; - -typedef struct { - EmptyStructPlaceholder placeholder; -} TfLitePadV2Params; - -typedef struct { - // These fields are only used in old models for backward compatibility. - // In the current implementation, we use the 2nd input of the op as the shape, - // and these fields are unused. - int32_t shape[TFLITE_RESHAPE_PARAMS_MAX_DIMENSION_COUNT]; - int num_dimensions; -} TfLiteReshapeParams; - -typedef struct { - int ngram_size; - int max_skip_size; - bool include_all_ngrams; -} TfLiteSkipGramParams; - -typedef struct { - int block_size; -} TfLiteSpaceToDepthParams; - -typedef struct { - int block_size; -} TfLiteDepthToSpaceParams; - -typedef struct { - TfLiteType in_data_type; - TfLiteType out_data_type; -} TfLiteCastParams; - -typedef enum { - kTfLiteCombinerTypeSum = 0, - kTfLiteCombinerTypeMean = 1, - kTfLiteCombinerTypeSqrtn = 2, -} TfLiteCombinerType; - -typedef struct { - TfLiteCombinerType combiner; -} TfLiteEmbeddingLookupSparseParams; - -typedef struct { - int axis; - int batch_dims; -} TfLiteGatherParams; - -typedef struct { - EmptyStructPlaceholder placeholder; -} TfLiteTransposeParams; - -typedef struct { - bool keep_dims; -} TfLiteReducerParams; - -typedef struct { - int num_splits; -} TfLiteSplitParams; - -typedef struct { - int num_splits; -} TfLiteSplitVParams; - -typedef struct { - // TODO(ahentz): We can't have dynamic data in this struct, at least not yet. - // For now we will fix the maximum possible number of dimensions. - int32_t squeeze_dims[8]; - int num_squeeze_dims; -} TfLiteSqueezeParams; - -typedef struct { - int begin_mask; - int end_mask; - int ellipsis_mask; - int new_axis_mask; - int shrink_axis_mask; - - // Parameters supported by version 8: - // If true, then the end tensor is an offset of the begin tensor. - bool offset; -} TfLiteStridedSliceParams; - -typedef struct { - TfLiteType output_type; -} TfLiteArgMaxParams; - -typedef struct { - TfLiteType output_type; -} TfLiteArgMinParams; - -typedef struct { - // Parameters supported by version 1: - TfLitePadding padding; - int stride_width; - int stride_height; - - // Parameters supported by version 4: - TfLiteFusedActivation activation; - - // Parameters for TransposeConv version 5 or above. - // Used to determine the default value for the quantized bias. - TfLiteType quantized_bias_type; -} TfLiteTransposeConvParams; - -typedef struct { - bool validate_indices; -} TfLiteSparseToDenseParams; - -typedef struct { - TfLiteType out_type; -} TfLiteShapeParams; - -typedef struct { - EmptyStructPlaceholder placeholder; -} TfLiteRankParams; - -typedef struct { - // Parameters supported by version 1: - float min; - float max; - int num_bits; - - // Parameters supported by version 2: - bool narrow_range; -} TfLiteFakeQuantParams; - -typedef struct { - int values_count; - int axis; -} TfLitePackParams; - -typedef struct { - int axis; -} TfLiteOneHotParams; - -typedef struct { - int num; - int axis; -} TfLiteUnpackParams; - -typedef struct { - float alpha; -} TfLiteLeakyReluParams; - -typedef struct { - TfLiteType index_out_type; -} TfLiteUniqueParams; - -typedef struct { - int seq_dim; - int batch_dim; -} TfLiteReverseSequenceParams; - -typedef struct { - EmptyStructPlaceholder placeholder; -} TfLiteMatrixDiagParams; - -typedef struct { - EmptyStructPlaceholder placeholder; -} TfLiteMatrixSetDiagParams; - -typedef struct { - int then_subgraph_index; - int else_subgraph_index; -} TfLiteIfParams; - -typedef struct { - int cond_subgraph_index; - int body_subgraph_index; -} TfLiteWhileParams; - -typedef struct { - bool exclusive; - bool reverse; -} TfLiteCumsumParams; - -typedef struct { - int init_subgraph_index; -} TfLiteCallOnceParams; - -typedef struct { - int table_id; - TfLiteType key_dtype; - TfLiteType value_dtype; -} TfLiteHashtableParams; - -typedef struct { - const char* container; - const char* shared_name; -} TfLiteVarHandleParams; - -typedef struct { - int seed; - int seed2; -} TfLiteRandomParams; - -typedef struct { - int num_boundaries; - // This points to the memory stored in the model (flatbuffer), - // and is not owned. - const float* boundaries; -} TfLiteBucketizeParams; - -typedef struct { - bool approximate; -} TfLiteGeluParams; - -typedef struct { - int64_t dimension; -} TfLiteStablehloConcatenateParams; - -typedef struct { - // See the stablehlo spec for the explanation of the attributes: - // https://github.com/openxla/stablehlo/blob/main/docs/spec.md#scatter - bool indices_are_sorted; - int64_t - update_window_dims[TFLITE_STABLEHLO_SCATTER_PARAMS_MAX_DIMENSION_COUNT]; - int num_update_window_dims; - int64_t - inserted_window_dims[TFLITE_STABLEHLO_SCATTER_PARAMS_MAX_DIMENSION_COUNT]; - int num_inserted_window_dims; - int64_t scatter_dims_to_operand_dims - [TFLITE_STABLEHLO_SCATTER_PARAMS_MAX_DIMENSION_COUNT]; - int num_scatter_dims_to_operand_dims; - int64_t index_vector_dim; - bool unique_indices; - int update_computation_subgraph_index; -} TfLiteStablehloScatterParams; - -typedef enum { - kTfLiteRngAlgorithmUnknown = 0, - // An algorithm auto-selected by the system according to device type. - kTfLiteRngAlgorithmDefault, - // The Philox algorithm, as described in paper - // ['Parallel Random Numbers: As Easy as 1, 2, 3'] - // (https://www.thesalmons.org/john/random123/papers/random123sc11.pdf) - kTfLiteRngAlgorithmPhilox, - // The ThreeFry algorithm, as described in paper - // ['Parallel Random Numbers: As Easy as 1, 2, 3'] - // (https://www.thesalmons.org/john/random123/papers/random123sc11.pdf) - kTfLiteRngAlgorithmThreefry, -} TfLiteRngAlgorithm; - -typedef struct { - TfLiteRngAlgorithm algorithm; -} TfLiteStablehloRngBitGeneratorParams; - -typedef struct { - // See the stablehlo spec for the explanation of the attributes: - // https://github.com/openxla/stablehlo/blob/main/docs/spec.md#gather - int64_t offset_dims[TFLITE_STABLEHLO_GATHER_PARAMS_MAX_DIMENSION_COUNT]; - int num_offset_dims; - int64_t - collapsed_slice_dims[TFLITE_STABLEHLO_GATHER_PARAMS_MAX_DIMENSION_COUNT]; - int num_collapsed_slice_dims; - int64_t start_index_map[TFLITE_STABLEHLO_GATHER_PARAMS_MAX_DIMENSION_COUNT]; - int num_start_index_map; - int64_t index_vector_dim; - int64_t slice_sizes[TFLITE_STABLEHLO_GATHER_PARAMS_MAX_DIMENSION_COUNT]; - int num_slice_sizes; - bool indices_are_sorted; -} TfLiteStablehloGatherParams; - -typedef struct { - // See the stablehlo spec for the explanation of the attributes: - // https://github.com/openxla/stablehlo/blob/main/docs/spec.md#reduce_window - int64_t window_dimensions - [TFLITE_STABLEHLO_REDUCE_WINDOW_PARAMS_MAX_DIMENSION_COUNT]; - int64_t - window_strides[TFLITE_STABLEHLO_REDUCE_WINDOW_PARAMS_MAX_DIMENSION_COUNT]; - int64_t - base_dilations[TFLITE_STABLEHLO_REDUCE_WINDOW_PARAMS_MAX_DIMENSION_COUNT]; - int64_t window_dilations - [TFLITE_STABLEHLO_REDUCE_WINDOW_PARAMS_MAX_DIMENSION_COUNT]; - int64_t - padding[2 * TFLITE_STABLEHLO_REDUCE_WINDOW_PARAMS_MAX_DIMENSION_COUNT]; - int body_subgraph_index; -} TfLiteStablehloReduceWindowParams; - -enum TfLiteReduceWindowFunction { - TfLiteReduceWindowFunctionUnsupported, - TfLiteReduceWindowFunctionAdd, - TfLiteReduceWindowFunctionMul, - TfLiteReduceWindowFunctionMin, - TfLiteReduceWindowFunctionMax, - TfLiteReduceWindowFunctionAll, - TfLiteReduceWindowFunctionAny -}; - -typedef struct { - enum TfLiteReduceWindowFunction reduce_function; -} TfLiteReduceWindowParams; - -typedef struct { - // See the stablehlo spec for the explanation of the attributes: - // https://github.com/openxla/stablehlo/blob/main/docs/spec.md#pad - int64_t edge_padding_low[TFLITE_STABLEHLO_PAD_PARAMS_MAX_DIMENSION_COUNT]; - int64_t edge_padding_high[TFLITE_STABLEHLO_PAD_PARAMS_MAX_DIMENSION_COUNT]; - int64_t interior_padding[TFLITE_STABLEHLO_PAD_PARAMS_MAX_DIMENSION_COUNT]; -} TfLiteStablehloPadParams; - -typedef struct { - const char* name; - int32_t subgraph_index; - int32_t version; - const uint8_t* attributes; - size_t attributes_size; -} TfLiteStablehloCompositeParams; - -#ifdef __cplusplus -} // extern "C" -#endif // __cplusplus +#include "tensorflow/compiler/mlir/lite/core/c/builtin_op_data.h" // IWYU pragma: export +#include "tensorflow/lite/core/c/common.h" // IWYU pragma: export #endif // TENSORFLOW_LITE_CORE_C_BUILTIN_OP_DATA_H_ diff --git a/tensorflow/lite/core/c/c_api_types.h b/tensorflow/lite/core/c/c_api_types.h index f0b76bde0258cb..dc2601bf127169 100644 --- a/tensorflow/lite/core/c/c_api_types.h +++ b/tensorflow/lite/core/c/c_api_types.h @@ -36,12 +36,12 @@ limitations under the License. #ifndef TENSORFLOW_LITE_CORE_C_C_API_TYPES_H_ #define TENSORFLOW_LITE_CORE_C_C_API_TYPES_H_ -#include - #ifdef __cplusplus extern "C" { #endif +#include "tensorflow/compiler/mlir/lite/core/c/tflite_types.h" // IWYU pragma: export + // clang-format off // NOLINTBEGIN(whitespace/line_length) /** \defgroup c_api_types lite/c/c_api_types.h @@ -112,42 +112,6 @@ typedef enum TfLiteStatus { kTfLiteCancelled = 8, } TfLiteStatus; -/// Types supported by tensor -// LINT.IfChange -typedef enum { - kTfLiteNoType = 0, - kTfLiteFloat32 = 1, - kTfLiteInt32 = 2, - kTfLiteUInt8 = 3, - kTfLiteInt64 = 4, - kTfLiteString = 5, - kTfLiteBool = 6, - kTfLiteInt16 = 7, - kTfLiteComplex64 = 8, - kTfLiteInt8 = 9, - kTfLiteFloat16 = 10, - kTfLiteFloat64 = 11, - kTfLiteComplex128 = 12, - kTfLiteUInt64 = 13, - kTfLiteResource = 14, - kTfLiteVariant = 15, - kTfLiteUInt32 = 16, - kTfLiteUInt16 = 17, - kTfLiteInt4 = 18, - kTfLiteBFloat16 = 19, -} TfLiteType; -// LINT.ThenChange(//tensorflow/lite/profiling/proto/model_runtime_info.proto:EdgeDataType) - -/// Legacy. Will be deprecated in favor of `TfLiteAffineQuantization`. -/// If per-layer quantization is specified this field will still be populated in -/// addition to `TfLiteAffineQuantization`. -/// Parameters for asymmetric quantization. Quantized values can be converted -/// back to float using: `real_value = scale * (quantized_value - zero_point)` -typedef struct TfLiteQuantizationParams { - float scale; - int32_t zero_point; -} TfLiteQuantizationParams; - // -------------------------------------------------------------------------- // Opaque types used by c_api.h, c_api_opaque.h and common.h. diff --git a/tensorflow/lite/core/c/common.h b/tensorflow/lite/core/c/common.h index 5d3100816492ae..29cf64aa9d351c 100644 --- a/tensorflow/lite/core/c/common.h +++ b/tensorflow/lite/core/c/common.h @@ -442,12 +442,6 @@ enum { kTfLiteNullBufferHandle = -1, }; -/// Storage format of each dimension in a sparse tensor. -typedef enum TfLiteDimensionType { - kTfLiteDimDense = 0, - kTfLiteDimSparseCSR, -} TfLiteDimensionType; - /// Metadata to encode each dimension in a sparse tensor. typedef struct TfLiteDimensionMetadata { TfLiteDimensionType format; diff --git a/tensorflow/lite/java/BUILD b/tensorflow/lite/java/BUILD index e4c1432e6c8dda..c11994f7c86c24 100644 --- a/tensorflow/lite/java/BUILD +++ b/tensorflow/lite/java/BUILD @@ -142,7 +142,9 @@ TFLITE_HEADERS = [ "//tensorflow/lite/core/c:c_api.h", "//tensorflow/lite/core/c:c_api_opaque.h", "//tensorflow/lite/core/c:c_api_types.h", + "//tensorflow/compiler/mlir/lite/core/c:tflite_types.h", "//tensorflow/lite/core/c:builtin_op_data.h", + "//tensorflow/compiler/mlir/lite/core/c:builtin_op_data.h", "//tensorflow/lite/core/c:c_api_experimental.h", "//tensorflow/lite/core/c:common.h", "//tensorflow/lite/core/c:operator.h", From f32d99afa974591d7774ea3912ce2c24aa70c6d0 Mon Sep 17 00:00:00 2001 From: Bixia Zheng Date: Tue, 24 Sep 2024 11:37:20 -0700 Subject: [PATCH 199/483] [xla] Avoid repeatedly traversing computations in a module by processing the computations in post-order. PiperOrigin-RevId: 678332958 --- .../while_loop_all_reduce_code_motion.cc | 163 +++++++++--------- 1 file changed, 81 insertions(+), 82 deletions(-) diff --git a/third_party/xla/xla/service/while_loop_all_reduce_code_motion.cc b/third_party/xla/xla/service/while_loop_all_reduce_code_motion.cc index c9b34c702efc32..c67a34628cc401 100644 --- a/third_party/xla/xla/service/while_loop_all_reduce_code_motion.cc +++ b/third_party/xla/xla/service/while_loop_all_reduce_code_motion.cc @@ -936,7 +936,7 @@ absl::StatusOr WhileLoopAllReduceCodeMotion::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { bool is_changed = false; - bool run_next_pass = true; + // In case of MPMD, all-reduces might be cross-module and should preserve // their channel ID. Do not move all-reduces in this case since the channel // ID might be changed. @@ -965,96 +965,95 @@ absl::StatusOr WhileLoopAllReduceCodeMotion::Run( // loop. We recursively sink the all-reduce through nested while loops if // applicable by repeating this process. uint32_t count_all_reduce = 0, count_reduce_scatter = 0; - while (run_next_pass) { - run_next_pass = false; - std::unique_ptr call_graph = CallGraph::Build(module); + std::unique_ptr call_graph = CallGraph::Build(module); + // We process all callees of a computation before processing the computation, + // so that when we process a computation, the all-reduce instructions that + // need to be hoisted to the computation from its callees have been hoisted. + for (HloComputation* computation : + module->MakeComputationPostOrder(execution_threads)) { // A computation could be the while body of multiple while instructions, // so we start from the computation and find all of its callers that is a // kWhile if there is any. - for (HloComputation* computation : - module->computations(execution_threads)) { - std::vector computation_callers = - call_graph->GetComputationCallers(computation); - std::vector while_caller_instructions; - for (HloInstruction* caller_instruction : computation_callers) { - // For simplicity, we only support while instructions whose shape is - // tuple. - if (caller_instruction->opcode() == HloOpcode::kWhile && - caller_instruction->shape().IsTuple() && - caller_instruction->while_body() == computation) { - while_caller_instructions.push_back(caller_instruction); - } - } - // Skip to next computation if this computation is not the while body of - // any while instruction. - if (while_caller_instructions.empty()) { - continue; + std::vector computation_callers = + call_graph->GetComputationCallers(computation); + std::vector while_caller_instructions; + for (HloInstruction* caller_instruction : computation_callers) { + // For simplicity, we only support while instructions whose shape is + // tuple. + if (caller_instruction->opcode() == HloOpcode::kWhile && + caller_instruction->shape().IsTuple() && + caller_instruction->while_body() == computation) { + while_caller_instructions.push_back(caller_instruction); } - std::vector while_body_all_reduces; - for (HloInstruction* while_body_instruction : - computation->MakeInstructionPostOrder()) { - HloOpcode op = while_body_instruction->opcode(); - const bool is_candidate = - (op == HloOpcode::kAllReduce) || - (enable_reduce_scatter_ && op == HloOpcode::kReduceScatter); - if (!is_candidate) { - continue; - } - auto* all_reduce_instruction = - Cast(while_body_instruction); - if (all_reduce_instruction->constrain_layout()) { - return false; - } else { - while_body_all_reduces.push_back(all_reduce_instruction); - } - } - HloInstructionMap> - all_reduce_to_accumulations; - for (HloAllReduceInstructionBase* all_reduce : while_body_all_reduces) { - auto movable_all_reduce_context = IsAllReduceMovable( - all_reduce, computation, cross_replica_replication_analysis, - cross_partition_replication_analysis); - if (movable_all_reduce_context.is_movable) { - all_reduce_to_accumulations[all_reduce] = - std::move(movable_all_reduce_context.accumulation_contexts); - } - VLOG(3) << "WhileLoopAllReduceCodeMotion, all-reduce: " - << all_reduce->ToString() - << " is_movable: " << movable_all_reduce_context.is_movable - << " while loop: " << while_caller_instructions.front()->name() - << " num_accumulations: " - << (movable_all_reduce_context.is_movable - ? all_reduce_to_accumulations[all_reduce].size() - : 0); - } - if (all_reduce_to_accumulations.empty()) { + } + // Skip to next computation if this computation is not the while body of + // any while instruction. + if (while_caller_instructions.empty()) { + continue; + } + std::vector while_body_all_reduces; + for (HloInstruction* while_body_instruction : + computation->MakeInstructionPostOrder()) { + HloOpcode op = while_body_instruction->opcode(); + const bool is_candidate = + (op == HloOpcode::kAllReduce) || + (enable_reduce_scatter_ && op == HloOpcode::kReduceScatter); + if (!is_candidate) { continue; } - // For each while instruction calling this computation, create the - // corresponding all-reduces after the while loop. - for (HloInstruction* while_instruction : while_caller_instructions) { - TF_RETURN_IF_ERROR(AddSinkedAllReducesAndReplaceWhile( - while_instruction, all_reduce_to_accumulations)); - is_changed = true; - run_next_pass = true; + auto* all_reduce_instruction = + Cast(while_body_instruction); + if (all_reduce_instruction->constrain_layout()) { + return false; + } else { + while_body_all_reduces.push_back(all_reduce_instruction); } - // At last, remove the old all-reduce instructions in the while body. - for (const auto& all_reduce_accumulations_pair : - all_reduce_to_accumulations) { - HloInstruction* all_reduce = all_reduce_accumulations_pair.first; - if (all_reduce->opcode() == HloOpcode::kAllReduce) { - count_all_reduce++; - } else { - count_reduce_scatter++; - } - TF_RETURN_IF_ERROR(computation->ReplaceInstructionWithDifferentShape( - all_reduce, all_reduce->mutable_operand(0))); + } + HloInstructionMap> + all_reduce_to_accumulations; + for (HloAllReduceInstructionBase* all_reduce : while_body_all_reduces) { + auto movable_all_reduce_context = IsAllReduceMovable( + all_reduce, computation, cross_replica_replication_analysis, + cross_partition_replication_analysis); + if (movable_all_reduce_context.is_movable) { + all_reduce_to_accumulations[all_reduce] = + std::move(movable_all_reduce_context.accumulation_contexts); } - // Needs to rebuild the call graph or we could access removed - // instructions. - if (run_next_pass) { - break; + VLOG(3) << "WhileLoopAllReduceCodeMotion, all-reduce: " + << all_reduce->ToString() + << " is_movable: " << movable_all_reduce_context.is_movable + << " while loop: " << while_caller_instructions.front()->name() + << " num_accumulations: " + << (movable_all_reduce_context.is_movable + ? all_reduce_to_accumulations[all_reduce].size() + : 0); + } + if (all_reduce_to_accumulations.empty()) { + continue; + } + // For each while instruction calling this computation, create the + // corresponding all-reduces after the while loop. + for (HloInstruction* while_instruction : while_caller_instructions) { + TF_RETURN_IF_ERROR(AddSinkedAllReducesAndReplaceWhile( + while_instruction, all_reduce_to_accumulations)); + is_changed = true; + } + // At last, remove the old all-reduce instructions in the while body. + for (const auto& all_reduce_accumulations_pair : + all_reduce_to_accumulations) { + HloInstruction* all_reduce = all_reduce_accumulations_pair.first; + if (all_reduce->opcode() == HloOpcode::kAllReduce) { + count_all_reduce++; + } else { + count_reduce_scatter++; } + TF_RETURN_IF_ERROR(computation->ReplaceInstructionWithDifferentShape( + all_reduce, all_reduce->mutable_operand(0))); + } + // Needs to rebuild the call graph after we remove instructions to avoid + // accessing removed instructions. + if (!all_reduce_to_accumulations.empty()) { + call_graph = CallGraph::Build(module); } } VLOG(2) << "Hoisted " << count_all_reduce << " all-reduce and " From 13ea77ab753dc0e3bae6b92777cd835e464b26e7 Mon Sep 17 00:00:00 2001 From: Ilia Sergachev Date: Tue, 24 Sep 2024 12:00:11 -0700 Subject: [PATCH 200/483] [GPU] Fix compilation with NVIDIA driver 560. This closes https://github.com/openxla/xla/pull/17549 PiperOrigin-RevId: 678341835 --- third_party/gpus/cuda/hermetic/cuda_redist_versions.bzl | 1 + .../tsl/third_party/gpus/cuda/hermetic/cuda_redist_versions.bzl | 1 + 2 files changed, 2 insertions(+) diff --git a/third_party/gpus/cuda/hermetic/cuda_redist_versions.bzl b/third_party/gpus/cuda/hermetic/cuda_redist_versions.bzl index 417a237053c8e7..140879d7e271ed 100644 --- a/third_party/gpus/cuda/hermetic/cuda_redist_versions.bzl +++ b/third_party/gpus/cuda/hermetic/cuda_redist_versions.bzl @@ -145,6 +145,7 @@ REDIST_VERSIONS_TO_BUILD_TEMPLATES = { "nvidia_driver": { "repo_name": "cuda_driver", "version_to_template": { + "560": "//third_party/gpus/cuda/hermetic:cuda_driver.BUILD.tpl", "555": "//third_party/gpus/cuda/hermetic:cuda_driver.BUILD.tpl", "550": "//third_party/gpus/cuda/hermetic:cuda_driver.BUILD.tpl", "545": "//third_party/gpus/cuda/hermetic:cuda_driver.BUILD.tpl", diff --git a/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_redist_versions.bzl b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_redist_versions.bzl index 417a237053c8e7..140879d7e271ed 100644 --- a/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_redist_versions.bzl +++ b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_redist_versions.bzl @@ -145,6 +145,7 @@ REDIST_VERSIONS_TO_BUILD_TEMPLATES = { "nvidia_driver": { "repo_name": "cuda_driver", "version_to_template": { + "560": "//third_party/gpus/cuda/hermetic:cuda_driver.BUILD.tpl", "555": "//third_party/gpus/cuda/hermetic:cuda_driver.BUILD.tpl", "550": "//third_party/gpus/cuda/hermetic:cuda_driver.BUILD.tpl", "545": "//third_party/gpus/cuda/hermetic:cuda_driver.BUILD.tpl", From 816488e6079b98a43f37444ad011b8e0f5b2d841 Mon Sep 17 00:00:00 2001 From: Ilya Tikhonovskiy Date: Tue, 24 Sep 2024 13:11:42 -0700 Subject: [PATCH 201/483] [XLA:GPU] Pure cleanup. Use `constexpr std::string_view kHloText` instead of `const std::string kHloText` in Triton tests. PiperOrigin-RevId: 678368341 --- ...riton_fusion_emitter_device_legacy_test.cc | 185 +++++++++--------- .../triton_fusion_emitter_device_test.cc | 62 +++--- .../triton_fusion_emitter_large_test.cc | 5 +- 3 files changed, 126 insertions(+), 126 deletions(-) diff --git a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_legacy_test.cc b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_legacy_test.cc index 0ea4702d7445af..e76faaba5c93b6 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_legacy_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_legacy_test.cc @@ -57,7 +57,6 @@ limitations under the License. #include "tsl/platform/env.h" #include "tsl/platform/errors.h" #include "tsl/platform/path.h" -#include "tsl/platform/status.h" #include "tsl/platform/status_matchers.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" @@ -153,7 +152,7 @@ class TritonGemmTestWithoutTritonGemmAny : public TritonGemmTest { }; TEST_F(TritonGemmTest, RejectTritonFusionForInt4WithMinorBatchDim) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule t ENTRY main { @@ -177,7 +176,7 @@ TEST_F(TritonGemmTest, RejectTritonFusionForInt4WithMinorBatchDim) { TEST_F(TritonGemmTest, LHSInt4WithMinorDimEqualTo1) { // We prove that triton can handle int4 dot with non contracting dim size // equal to 1. - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule t triton_computation { @@ -206,7 +205,7 @@ TEST_F(TritonGemmTest, LHSInt4WithMinorDimEqualTo1) { TEST_F(TritonGemmTest, RHSInt4WithMinorDimEqualTo1) { // We prove that triton can handle int4 dot with non contracting dim size // equal to 1. - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule t triton_computation { @@ -236,7 +235,7 @@ TEST_F(TritonGemmTest, RHSInt4WithMinorDimEqualTo1) { TEST_F(TritonGemmTest, LHSInt4NonMinorContractingDim) { // We prove that triton can handle int4 dot with non minor // lhs_contracting_dim. - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule t triton_computation { @@ -264,7 +263,7 @@ TEST_F(TritonGemmTest, LHSInt4NonMinorContractingDim) { TEST_F(TritonGemmTest, LHSInt4NonMinorContractingDimWithBatchDim0) { // We prove that triton can handle int4 dot with non minor // lhs_contracting_dim. - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule t triton_computation { @@ -292,7 +291,7 @@ TEST_F(TritonGemmTest, LHSInt4NonMinorContractingDimWithBatchDim0) { TEST_F(TritonGemmTest, LHSInt4MinorContractingDim) { // We prove that triton can handle int4 dot with minor lhs_contracting_dim. - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule t triton_computation { @@ -316,7 +315,7 @@ TEST_F(TritonGemmTest, LHSInt4MinorContractingDim) { } TEST_F(TritonGemmTest, Int4ConvertPlusNegate) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule t triton_computation { @@ -342,7 +341,7 @@ TEST_F(TritonGemmTest, Int4ConvertPlusNegate) { TEST_F(TritonGemmTest, LHSInt4MinorContractingDimWithBatchDim0) { // We prove that triton can handle int4 dot with minor lhs_contracting_dim. - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule t triton_computation { @@ -369,7 +368,7 @@ TEST_F(TritonGemmTest, LHSInt4MinorContractingDimWithBatchDim0) { } TEST_F(TritonGemmTest, RHSInt4TestWithMinorContractingDim) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule t triton_computation { @@ -394,7 +393,7 @@ TEST_F(TritonGemmTest, RHSInt4TestWithMinorContractingDim) { } TEST_F(TritonGemmTest, RHSInt4TestWithNotMinorContractingDim) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule t triton_computation { @@ -419,7 +418,7 @@ TEST_F(TritonGemmTest, RHSInt4TestWithNotMinorContractingDim) { } TEST_F(TritonGemmTest, RHSInt4TestWithMinorContractingDimWithBatchDim) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule t triton_computation { @@ -446,7 +445,7 @@ TEST_F(TritonGemmTest, RHSInt4TestWithMinorContractingDimWithBatchDim) { } TEST_F(TritonGemmTest, RHSInt4TestWithNotMinorContractingDimWithBatchDim0) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule t triton_computation { @@ -473,7 +472,7 @@ TEST_F(TritonGemmTest, RHSInt4TestWithNotMinorContractingDimWithBatchDim0) { } TEST_F(TritonTest, TestGemm) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule t, is_scheduled=true triton_gemm_r { @@ -565,7 +564,7 @@ CHECK: } } TEST_F(TritonTest, TestGemmWithTrivialNonContractingDimension) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule t, is_scheduled=true triton_dot { @@ -655,7 +654,7 @@ CHECK: } } TEST_F(TritonTest, PredParametersAreTruncatedToI1) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule m triton_gemm_computation { @@ -696,7 +695,7 @@ CHECK: %{{.*}} = arith.andi %[[TRUNCI]], %{{.*}} : tensor<16x16xi1> } TEST_F(TritonTest, CodegenBatchedDotWithConcatenationWithCorrectBatchStride) { - constexpr absl::string_view kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule t, is_scheduled=true triton_gemm { @@ -739,7 +738,7 @@ CHECK: %[[BLOCK_BASE_PTR:.*]] = tt.addptr %[[ARG_PTR]], %[[OFFSET]] TEST_F(TritonTest, CodegenDynamicSliceWithCorrectOffsets) { // The start index(es) for the non-majormost dimension(s) are constant zero(s) // because we don't support dynamic slice on those dimensions. - constexpr absl::string_view kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule t triton_gemm { @@ -789,7 +788,7 @@ CHECK-DAG: tt.make_tensor_ptr %[[DYNAMIC_SLICE_INPUT]], [%[[C2_i64]], %[[ROW_L } TEST_F(TritonTest, SparseDot) { - const char* kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule t triton_dot { @@ -820,7 +819,7 @@ CHECK: triton_gpu.sparse_dot %[[LHS]], %[[RHS]], %{{[^:]+}}, %[[META]] : } TEST_F(TritonTest, SparseDotWithMasking) { - const char* kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule t triton_dot { @@ -857,7 +856,7 @@ CHECK: triton_gpu.sparse_dot %[[LHS_MASKED]], %[[RHS_MASKED]], %{{[^:]+}}, %[[ME } TEST_F(TritonTest, SparseDotBroadcastMetadata) { - const char* kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule t triton_dot { @@ -894,7 +893,7 @@ CHECK: triton_gpu.sparse_dot %[[LHS]], %[[RHS]], %{{[^:]+}}, %[[META]] : } TEST_F(TritonGemmTest, DoNotUseTensorCoresWithNonDefaultPrecision) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( triton_gemm_r { parameter_0 = s8[80,15]{1,0} parameter(0) convert.3 = f32[80,15]{1,0} convert(parameter_0) @@ -924,7 +923,7 @@ CHECK-NOT: mma } TEST_F(TritonGemmTest, DebugOptionsArePropagated) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( ENTRY e { p0 = f16[30,30] parameter(0) p1 = s8[30,30] parameter(1) @@ -976,7 +975,7 @@ ENTRY main { } TEST_F(TritonGemmTest, UseTensorCoresForF32OnAmpere) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( triton_gemm_r { parameter_0 = f16[80,15]{1,0} parameter(0) convert.3 = f32[80,15]{1,0} convert(parameter_0) @@ -1008,7 +1007,7 @@ TEST_F(TritonGemmTest, FailIfTooMuchShmem) { if (std::holds_alternative(GpuComputeComp())) { GTEST_SKIP() << "GEMM padding requirements for ROCM not included yet."; } - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule module, is_scheduled=true triton_gemm_dot { @@ -1085,7 +1084,7 @@ TEST_F(TritonGemmTestWithSplitK, // The condition mentioned in the test name is fulfilled by // GemmKey(16, 64, 256, 8, 1, 4), which was part of the default configs for // Ampere at the time of the addition of this test case. - constexpr absl::string_view kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule extracted ENTRY e { @@ -1239,7 +1238,7 @@ ENTRY e { } TEST_F(TritonGemmTest, SplitAndTransposeLhsExecutesCorrectly) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule m ENTRY e { @@ -1269,7 +1268,7 @@ TEST_F(TritonGemmTest, NondefaultOperandLayoutIsSupported) { #ifndef NDEBUG GTEST_SKIP() << "This test times out when -UNDEBUG is set."; #endif - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( ENTRY r { p1 = f16[9,140,128]{2,1,0} parameter(1) cp = f16[9,140,128]{2,0,1} copy(p1) @@ -1442,7 +1441,7 @@ ENTRY e { } TEST_F(TritonGemmTest, MultipleBatchRequireSeparateTranspose) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule m ENTRY e { @@ -1465,7 +1464,7 @@ ENTRY e { } TEST_F(TritonGemmTest, CanCodegenNonBatchedDotWithConcatenationCorrectly) { - constexpr absl::string_view kHloText = R"( + constexpr std::string_view kHloText = R"( ENTRY e { parameter_0 = f32[3,10]{1,0} parameter(0) parameter_1 = f32[10,128]{1,0} parameter(1) @@ -1489,7 +1488,7 @@ ENTRY e { } TEST_F(TritonGemmTest, CanCodegenBatchedDotWithConcatenationCorrectly) { - constexpr absl::string_view kHloText = R"( + constexpr std::string_view kHloText = R"( ENTRY e { parameter_0 = f32[2,3,10]{2,1,0} parameter(0) parameter_1 = f32[2,10,128]{2,1,0} parameter(1) @@ -1534,7 +1533,7 @@ ENTRY e { } TEST_F(TritonTest, FloatToSignedIntConversion) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule t, is_scheduled=true triton_gemm_r { @@ -1595,7 +1594,7 @@ ENTRY e { // This tests the complexity heuristics in TritonWrapper. TEST_F(TritonGemmTest, FailForTooComplexTiling) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule module, is_scheduled=true triton_gemm_dot { @@ -1836,7 +1835,7 @@ TEST_F(TritonGemmTest, DynamicSliceIsSupportedInLhsEndToEnd) { // is not strictly needed, because we also support clamping the indices. // The start index(es) for the non-majormost dimension(s) are constant zero(s) // because we don't support dynamic slice on those dimensions. - constexpr absl::string_view kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule m ENTRY e { @@ -1867,7 +1866,7 @@ ENTRY e { TEST_F(TritonGemmTest, DynamicSliceIsSupportedInRhs) { // The start index(es) for the non-majormost dimension(s) are constant zero(s) // because we don't support dynamic slice on those dimensions. - constexpr absl::string_view kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule m triton_gemm { @@ -1900,7 +1899,7 @@ ENTRY e { } TEST_F(TritonGemmTest, MultiplePathsToSameOperandWorks) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( triton_computation { p0 = bf16[8192,512]{1,0} parameter(0) p1 = bf16[512,512]{1,0} parameter(1) @@ -1983,7 +1982,7 @@ TEST_F(TritonGemmTest, DynamicSliceOfMajormostContractingDimIsSupported) { // dimension is contracted. // The start index(es) for the non-majormost dimension(s) are constant zero(s) // because we don't support dynamic slice on those dimensions. - constexpr absl::string_view kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule m triton_gemm { @@ -2020,7 +2019,7 @@ TEST_F(TritonGemmTest, DynamicSliceOfMajormostBatchDimIsSupported) { // dimension is a batch. // The start index(es) for the non-majormost dimension(s) are constant zero(s) // because we don't support dynamic slice on those dimensions. - constexpr absl::string_view kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule m triton_gemm { @@ -2059,7 +2058,7 @@ TEST_F(TritonGemmTest, DynamicSliceSingleDimensionIntoReshapeIsSupported) { // layer weights and extracting them with dynamic slice. // The start index(es) for the non-majormost dimension(s) are constant zero(s) // because we don't support dynamic slice on those dimensions. - constexpr absl::string_view kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule m triton_gemm { @@ -2126,7 +2125,7 @@ ENTRY e { } TEST_F(TritonGemmTest, BroadcastOfScalarWorksCorrectly) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( fusion { p0 = f16[2,18] parameter(0) p1 = f16[256,2] parameter(1) @@ -2196,7 +2195,7 @@ class TritonGemmLevel2TestAny : public TritonGemmLevel2Test { }; TEST_F(TritonGemmLevel2Test, BinaryOperationWithSmallInputsIsFused) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule m ENTRY e { @@ -2222,7 +2221,7 @@ ENTRY e { } TEST_F(TritonGemmLevel2Test, BinaryOperationWithLargeInputsIsNotFused) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule m ENTRY e { @@ -2253,7 +2252,7 @@ ENTRY e { TEST_F(TritonGemmLevel2Test, ParametersWithDifferentLayoutsAreSupportedInOneScope) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( ENTRY e { p0 = s8[5,3] parameter(0) p0c = f16[5,3] convert(p0) @@ -2276,7 +2275,7 @@ ENTRY e { } TEST_F(TritonGemmLevel2Test, BinaryOperationOnLargeParametersIsFused) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule m ENTRY e { @@ -2301,7 +2300,7 @@ ENTRY e { } TEST_F(TritonGemmLevel2Test, LinkingLibdeviceTwiceWorks) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( ENTRY e { p0 = s8[7,3] parameter(0) c0 = f32[7,3] convert(p0) @@ -2332,7 +2331,7 @@ ENTRY e { } TEST_F(TritonGemmLevel2Test, BroadcastOfScalarParameterIsFused) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( ENTRY e { p0 = f16[64,256] parameter(0) p0c = f32[64,256] convert(p0) @@ -2353,7 +2352,7 @@ ENTRY e { } TEST_F(TritonGemmLevel2Test, BroadcastOfScalarConstantIsFused) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule m ENTRY e { @@ -2379,7 +2378,7 @@ TEST_F(TritonGemmLevel2Test, DoubleBroadcastOfScalarConstantIsHandled) { if (!SupportsBF16(GpuComputeComp())) { GTEST_SKIP() << "BF16 not supported."; } - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( ENTRY e { c = s32[] constant(1) bc1 = s32[21]{0} broadcast(c), dimensions={} @@ -2403,7 +2402,7 @@ ENTRY e { } TEST_F(TritonGemmLevel2Test, BroadcastOfVectorConstantIsFused) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule m ENTRY e { @@ -2427,7 +2426,7 @@ TEST_F(TritonGemmLevel2Test, AlwaysFuseScalarConstantAtBroadcastInput) { if (!SupportsBF16(GpuComputeComp())) { GTEST_SKIP() << "BF16 not supported."; } - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( ENTRY e { p0 = bf16[2,3,3]{2,1,0} parameter(0) p1 = bf16[3,2,3]{2,1,0} parameter(1) @@ -2454,7 +2453,7 @@ ENTRY e { } TEST_F(TritonGemmLevel2Test, BroadcastOfVectorParameterIsFused) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( triton_dot { p0 = f16[75] parameter(0) bc0 = f16[75,67] broadcast(p0), dimensions={0} @@ -2483,7 +2482,7 @@ TEST_F(TritonGemmLevel2Test, FuseConcatenation) { if (!SupportsBF16(GpuComputeComp())) { GTEST_SKIP() << "BF16 not supported."; } - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( e { p0 = s8[153,1536] parameter(0) p1 = s8[153,128] parameter(1) @@ -2509,7 +2508,7 @@ e { } TEST_F(TritonGemmLevel2TestAny, MinimumHandlesNaNsOnTheLeft) { - constexpr absl::string_view kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule t ENTRY e { @@ -2532,7 +2531,7 @@ ENTRY e { } TEST_F(TritonGemmLevel2TestAny, MinimumHandlesNaNsOnTheRight) { - constexpr absl::string_view kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule t ENTRY e { @@ -2555,7 +2554,7 @@ ENTRY e { } TEST_F(TritonGemmLevel2TestAny, MaximumHandlesNaNsOnTheLeft) { - constexpr absl::string_view kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule t ENTRY e { @@ -2578,7 +2577,7 @@ ENTRY e { } TEST_F(TritonGemmLevel2TestAny, MaximumHandlesNaNsOnTheRight) { - constexpr absl::string_view kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule t ENTRY e { @@ -2601,7 +2600,7 @@ ENTRY e { } TEST_F(TritonGemmLevel2TestAny, MinimumReturnsLHS) { - constexpr absl::string_view kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule t ENTRY e { @@ -2626,7 +2625,7 @@ ENTRY e { } TEST_F(TritonGemmLevel2TestAny, MinimumReturnsRHS) { - constexpr absl::string_view kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule t ENTRY e { @@ -2651,7 +2650,7 @@ ENTRY e { } TEST_F(TritonGemmLevel2TestAny, MaximumReturnsLHS) { - constexpr absl::string_view kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule t ENTRY e { @@ -2676,7 +2675,7 @@ ENTRY e { } TEST_F(TritonGemmLevel2TestAny, MaximumReturnsRHS) { - constexpr absl::string_view kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule t ENTRY e { @@ -2701,7 +2700,7 @@ ENTRY e { } TEST_F(TritonGemmTest, SineOutputIsNotFused) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule m ENTRY e { @@ -2724,7 +2723,7 @@ ENTRY e { } TEST_F(TritonGemmTest, SliceInputIsFused) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( ENTRY e { p0 = f16[97,121] parameter(0) s0 = f16[7,101] slice(p0), slice={[3:10], [10:111]} @@ -2745,7 +2744,7 @@ ENTRY e { } TEST_F(TritonGemmTest, SliceInputWithReshapeIsFused) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( ENTRY e { p0 = f32[363,1536] parameter(0) p1 = f32[4,1536,611] parameter(1) @@ -2767,7 +2766,7 @@ ENTRY e { } TEST_F(TritonGemmLevel2Test, NestedSlicingWorks) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( ENTRY e { p1 = f32[6,24] parameter(1) slice1 = f32[5,20] slice(p1), slice={[1:6], [3:23]} @@ -2789,7 +2788,7 @@ ENTRY e { } TEST_F(TritonGemmTest, SlicedBatchDimensionIsSupported) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( ENTRY e { p0 = f16[3,3,256] parameter(0) s0 = f16[3,3,128] slice(p0), slice={[0:3], [0:3], [123:251]} @@ -2814,7 +2813,7 @@ ENTRY e { TEST_F(TritonGemmTestWithSplitK, SplitKDoesNotBreakSlicedFragmentedContractingDimension) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( ENTRY e { p0 = f16[16,8,128]{2,1,0} parameter(0) s0 = f16[16,4,128]{2,1,0} slice(p0), @@ -2838,7 +2837,7 @@ ENTRY e { } TEST_F(TritonGemmTestWithSplitK, SplitKWithTrivialDimension) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( ENTRY entry_computation { p0 = f16[1001,1]{1,0} parameter(0) convert = f32[1001,1]{1,0} convert(p0) @@ -2851,7 +2850,7 @@ ENTRY entry_computation { } TEST_F(TritonGemmLevel2Test, NarrowingConvertOutputIsFused) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule m ENTRY e { @@ -2877,7 +2876,7 @@ TEST_F(TritonGemmLevel2Test, ParameterAfterDotIsFused) { if (!SupportsBF16(GpuComputeComp())) { GTEST_SKIP() << "BF16 not supported."; } - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule m ENTRY e { @@ -2909,7 +2908,7 @@ TEST_F(TritonGemmLevel2Test, OutputFusionExecutesCorrectly) { if (!SupportsBF16(GpuComputeComp())) { GTEST_SKIP() << "BF16 not supported."; } - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule m ENTRY e { @@ -2945,7 +2944,7 @@ TEST_F(TritonGemmLevel2Test, SplitLHSOutputTransposeAloneIsNotFused) { if (!SupportsBF16(GpuComputeComp())) { GTEST_SKIP() << "BF16 not supported."; } - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule m ENTRY e { @@ -2978,7 +2977,7 @@ TEST_F(TritonGemmLevel2Test, SplitLHSInputOutputIsFused) { GTEST_SKIP() << "Skipped until corresponding issue on ROCm is fixed."; } - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( ENTRY e { p0t = (s8[5,18,20,150]) parameter(0) p0 = s8[5,18,20,150] get-tuple-element(p0t), index=0 @@ -3003,7 +3002,7 @@ ENTRY e { } TEST_F(TritonGemmLevel2Test, SupportPredParametersUsedInExpressions) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( ENTRY e { p = pred[2,2]{1,0} parameter(0) a = f32[2,2]{1,0} parameter(1) @@ -4364,7 +4363,7 @@ TEST_F(TritonGemmContractionDims, TritonDotForceContractionDims_1_0) { if (!SupportsBF16(GpuComputeComp())) { GTEST_SKIP() << "BF16 not supported."; } - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule m ENTRY e { @@ -4389,7 +4388,7 @@ TEST_F(TritonGemmContractionDims, TritonDotForceContractionDims_1_2_1_2) { if (!SupportsBF16(GpuComputeComp())) { GTEST_SKIP() << "BF16 not supported."; } - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule m ENTRY e { @@ -4414,7 +4413,7 @@ TEST_F(TritonGemmContractionDims, TritonDotForceContractionDims_1_2_0_1) { if (!SupportsBF16(GpuComputeComp())) { GTEST_SKIP() << "BF16 not supported."; } - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule m ENTRY e { @@ -4440,7 +4439,7 @@ TEST_F(TritonGemmContractionDims, TritonDotForceContractionDims_1_1) { if (!SupportsBF16(GpuComputeComp())) { GTEST_SKIP() << "BF16 not supported."; } - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule m ENTRY e { @@ -4510,7 +4509,7 @@ class Triton6xBF16GemmTestWithFlag : public TritonTest { }; TEST_F(Triton6xBF16GemmTest, Emit6xBF16GemmWhenBothInputsAreF32) { - const char* kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule t triton_dot { @@ -4552,7 +4551,7 @@ CHECK: %[[ACC:.*]] = arith.addf %[[DOT_LAST]], %[[C0]] : tensor<32x32xf } TEST_F(Triton6xBF16GemmTestWithFlag, Emit6xBF16GemmWhenBothInputsAreF32) { - const char* kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule t triton_dot { @@ -4593,7 +4592,7 @@ CHECK: %[[ACC:.*]] = arith.addf %[[DOT_LAST]], %[[C0]] : tensor<32x32xf } TEST_F(Triton6xBF16GemmTest, Triton6xBF16GemmWorksForLongContractingDimension) { - const char* kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule t triton_dot { @@ -4622,7 +4621,7 @@ CHECK-COUNT-6: %{{.*}} = tt.dot %{{.*}}, %{{.*}}, %{{.*}} : tensor<64x32xbf16> } TEST_F(Triton6xBF16GemmTest, Triton6xBF16GemmCanHandleInfinity) { - const char* kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule t triton_dot { @@ -4665,7 +4664,7 @@ CHECK-COUNT-6: %{{.*}} = tt.dot %{{.*}}, %{{.*}}, %{{.*}} : tensor<32x32xbf16> } TEST_F(Triton6xBF16GemmTest, Triton6xBF16GemmCanHandleNaN) { - const char* kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule t triton_dot { @@ -4720,7 +4719,7 @@ CHECK-COUNT-6: %{{.*}} = tt.dot %{{.*}}, %{{.*}}, %{{.*}} : tensor<32x32xbf16> // x_lo: 5.17201445e+33 // The result of x*x would be NaN instead of positive infinity. TEST_F(Triton6xBF16GemmTest, Triton6xBF16GemmWorksForInputsWithLargeExponent) { - const char* kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule t triton_dot { @@ -4766,7 +4765,7 @@ TEST_F(Triton6xBF16GemmTest, Emit6xBF16GemmEndToEnd) { if (std::holds_alternative(GpuComputeComp())) { GTEST_SKIP() << "ALG_DOT_BF16_BF16_F32_X6 not supported on ROCM."; } - const char* kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule t ENTRY e { @@ -4838,7 +4837,7 @@ class Triton3xBF16GemmTestWithFlag : public TritonTest { }; TEST_F(Triton3xBF16GemmTest, Emit3xBF16GemmWhenBothInputsAreF32) { - const char* kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule t triton_dot { @@ -4880,7 +4879,7 @@ CHECK: %[[ACC:.*]] = arith.addf %[[DOT_LAST]], %[[C0]] : tensor<32x32xf } TEST_F(Triton3xBF16GemmTestWithFlag, Emit3xBF16GemmWhenBothInputsAreF32) { - const char* kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule t triton_dot { @@ -4921,7 +4920,7 @@ CHECK: %[[ACC:.*]] = arith.addf %[[DOT_LAST]], %[[C0]] : tensor<32x32xf } TEST_F(Triton3xBF16GemmTestWithFlag, NoEmit3xBF16GemmWhenBothInputsAreNotF32) { - const char* kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule t triton_dot { @@ -4949,7 +4948,7 @@ CHECK-NOT: tt.dot } TEST_F(Triton3xBF16GemmTest, Triton3xBF16GemmWorksForLongContractingDimension) { - const char* kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule t triton_dot { @@ -4978,7 +4977,7 @@ CHECK-COUNT-3: %{{.*}} = tt.dot %{{.*}}, %{{.*}}, %{{.*}} : tensor<64x32xbf16> } TEST_F(Triton3xBF16GemmTest, Triton3xBF16GemmCanHandleInfinity) { - const char* kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule t triton_dot { @@ -5021,7 +5020,7 @@ CHECK-COUNT-3: %{{.*}} = tt.dot %{{.*}}, %{{.*}}, %{{.*}} : tensor<32x32xbf16> } TEST_F(Triton3xBF16GemmTest, Triton3xBF16GemmCanHandleNaN) { - const char* kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule t triton_dot { @@ -5066,7 +5065,7 @@ CHECK-COUNT-3: %{{.*}} = tt.dot %{{.*}}, %{{.*}}, %{{.*}} : tensor<32x32xbf16> } TEST_F(Triton3xBF16GemmTest, Triton3xBF16GemmWorksForInputsWithLargeExponent) { - const char* kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule t triton_dot { @@ -5112,7 +5111,7 @@ TEST_F(Triton3xBF16GemmTest, Emit3xBF16GemmEndToEnd) { if (std::holds_alternative(GpuComputeComp())) { GTEST_SKIP() << "ALG_DOT_BF16_BF16_F32_X3 not supported on ROCM."; } - const char* kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule t ENTRY e { @@ -5353,7 +5352,7 @@ TEST_F(TritonGemmTest, TestNoAutotuner) { if (std::holds_alternative(GpuComputeComp())) { GTEST_SKIP() << "Autotuner is always in pipeline on Cuda."; } - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( ENTRY e { p0 = f16[30,30] parameter(0) p1 = s8[30,30] parameter(1) diff --git a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_test.cc b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_test.cc index 3cd5da3117c673..a2de97c39bfe04 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_test.cc @@ -59,7 +59,7 @@ class TritonEmitterTest : public GpuCodegenTest { }; TEST_F(TritonEmitterTest, ReductionOnMinormostAxisIsEmittedCorrectly) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule t maximum { Arg_0 = f32[] parameter(0) @@ -87,7 +87,7 @@ CHECK: "tt.reduce"(%[[LOAD:.*]]) <{axis = 1 : i32}> } TEST_F(TritonEmitterTest, ReductionOnMajormostAxisIsEmittedCorrectly) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule t maximum { Arg_0 = f32[] parameter(0) @@ -115,7 +115,7 @@ CHECK: "tt.reduce"(%[[LOAD:.*]]) <{axis = 0 : i32}> } TEST_F(TritonEmitterTest, ReductionOnIntermediateAxisIsEmittedCorrectly) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule t maximum { Arg_0 = f32[] parameter(0) @@ -145,7 +145,7 @@ CHECK: "tt.reduce"(%[[SELECT:.*]]) <{axis = 2 : i32}> } TEST_F(TritonEmitterTest, TestReductionWithTileSizeLargerThanSourceTensor) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule t maximum { Arg_0 = f32[] parameter(0) @@ -186,7 +186,7 @@ CHECK: }) // TODO(b/353484968): Tests that don't run RunAndCompareNoHloPasses should be // moved to deviceless test file. TEST_F(TritonEmitterTest, TestGenericEmitterWithSoftMaxSingleParameter) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule t add { Arg_0 = f32[] parameter(0) @@ -247,7 +247,7 @@ CHECK: } // TODO(b/353484968): Tests that don't run RunAndCompareNoHloPasses should be // moved to deviceless test file. TEST_F(TritonEmitterTest, TestGenericEmitterWithMultipleParameters) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule t add { @@ -312,7 +312,7 @@ CHECK-DAG: tt.store {{.*}} : !tt.ptr> } TEST_F(TritonEmitterTest, TestGenericEmitterWithMultipleTiledDimensions) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule t max { @@ -395,7 +395,7 @@ CHECK-NEXT: tt.store {{.*}} : !tt.ptr> TEST_F( TritonEmitterTest, DiamondWithAdditionalDiamondParameterBroadcastedAlongReductionDimProducesAccurateResults) { // NOLINT(whitespace/line_length) - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule h1 max_computation { @@ -433,7 +433,7 @@ TEST_F(TritonEmitterTest, NestedReducerFusionGetsCodegenedCorrectly) { GTEST_SKIP() << "BF16 not supported."; } - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule softmax fused_convert { @@ -472,7 +472,7 @@ ENTRY main { TEST_F( TritonEmitterTest, DiamondWithAdditionalDiamondParameterBroadcastedAlongBatchDimProducesAccurateResults) { // NOLINT(whitespace/line_length) - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule h1 max_computation { @@ -505,7 +505,7 @@ ENTRY main { TEST_F( TritonEmitterTest, DiamondWithAdditionalSplatDiamondScalarParameterProducesAccurateResults) { // NOLINT(whitespace/line_length) - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule h1 max_computation { @@ -560,7 +560,7 @@ ENTRY main { TEST_F( TritonEmitterTest, DiamondWithAdditionalBroadcastOf1DParameterAlongNonReductionDimensionsProducesAccurateResults) { // NOLINT(whitespace/line_length) - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule h1 max_computation { @@ -594,7 +594,7 @@ ENTRY main { // TODO(b/353484968): Tests that don't run RunAndCompareNoHloPasses should be // moved to deviceless test file. TEST_F(TritonEmitterTest, EmitterFailsIfComputeCapabilityIsBelowAmpere) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( triton_computation { p0 = f32[10,10] parameter(0) p1 = f32[10,10] parameter(1) @@ -694,7 +694,7 @@ ENTRY entry_computation { // TODO(b/353484968): Tests that don't run RunAndCompareNoHloPasses should b // moved to deviceless test file. TEST_F(TritonEmitterTest, TestGenericEmitterReductionFusion) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule t add { Arg_0 = f32[] parameter(0) @@ -736,7 +736,7 @@ CHECK: tt.store {{.*}} : !tt.ptr> TEST_F(TritonEmitterTest, TestGenericEmitterWithReductonAndMultidimensionalTile) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule t max { Arg_0 = f32[] parameter(0) @@ -764,7 +764,7 @@ ENTRY main { } TEST_F(TritonEmitterTest, TestSoftMaxWithTileElementsNotAllContiguous) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule m region { @@ -793,7 +793,7 @@ ENTRY entry_computation { } TEST_F(TritonEmitterTest, TestSliceWithTileThatNeedsMasking) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule m fused_computation { @@ -812,7 +812,7 @@ ENTRY entry_computation { } TEST_F(TritonEmitterTest, TestSliceWithTileElementsNotAllContiguous) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule m fused_computation { @@ -831,7 +831,7 @@ ENTRY entry_computation { } TEST_F(TritonEmitterTest, TestSliceWithTileElementsNotAllContiguousUnaligned) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule m fused_computation { @@ -854,7 +854,7 @@ ENTRY entry_computation { } TEST_F(TritonEmitterTest, ReshapeIntoBroadcastIsLoweredCorrectly) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( triton_computation { param_0 = f32[128,256]{1,0} parameter(0) reshape = f32[64,2,256]{2,1,0} reshape(param_0) @@ -880,7 +880,7 @@ CHECK: tt.reshape } TEST_F(TritonEmitterTest, BitcastIntoBroadcastIsLoweredCorrectly) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( triton_computation { param_0 = f32[128,256]{1,0} parameter(0) bitcast = f32[64,2,256]{2,1,0} bitcast(param_0) @@ -906,7 +906,7 @@ CHECK: tt.reshape } TEST_F(TritonEmitterTest, BitcastNormalizedLayoutsIsLoweredCorrectly) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( triton_computation { p = s8[5,42] parameter(0) ROOT bitcast = s8[5,6,7] bitcast(p) @@ -934,7 +934,7 @@ CHECK: tt.store } TEST_F(TritonEmitterTest, BitcastNonNormalizedInputLayoutIsLoweredCorrectly) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( triton_computation { p = s8[42,5]{0,1} parameter(0) ROOT bitcast = s8[5,6,7] bitcast(p) @@ -962,7 +962,7 @@ CHECK: tt.store } TEST_F(TritonEmitterTest, BitcastNonNormalizedOutputLayoutIsLoweredCorrectly) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( triton_computation { p = s8[5,42] parameter(0) ROOT bitcast = s8[5,6,7]{1,2,0} bitcast(p) @@ -991,7 +991,7 @@ CHECK: tt.store TEST_F(TritonEmitterTest, BitcastNonNormalizedInputOutputLayoutIsLoweredCorrectly) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( triton_computation { p = s8[42,5]{0,1} parameter(0) ROOT bitcast = s8[5,6,7]{1,2,0} bitcast(p) @@ -1019,7 +1019,7 @@ CHECK: tt.store } TEST_F(TritonEmitterTest, BitcastTransposeOnlyIsLoweredCorrectly) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( triton_computation { p = s8[42,5]{0,1} parameter(0) ROOT bitcast = s8[5,42] bitcast(p) @@ -1048,7 +1048,7 @@ CHECK: tt.store // TODO(b/353484968): move this test to a deviceless file. TEST_F(TritonEmitterTest, GenericEmitterLowersBroadcastFrom0dOperandCorrectly) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( triton_computation { // TODO(b/348565795): make this a 0D scalar directly once this is known to be // supported. @@ -1076,7 +1076,7 @@ CHECK-SAME: tensor<1x1xf32> -> tensor<8x4xf32> TEST_F(TritonEmitterTest, PredOutputIsStoredCorrectly) { // The 'pred' element type in XLA is unpacked and uses i8 for storage. This // is the only sub-byte type to have this behavior. - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule m triton_computation { @@ -1109,7 +1109,7 @@ CHECK: tt.store {{.*}} %[[CASTED_OUT]] TEST_F(TritonEmitterTest, PredInputIsLoadedCorrectly) { // The 'pred' element type in XLA is unpacked and uses i8 for storage. This // is the only sub-byte type to have this behavior. - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule m triton_computation { @@ -1145,7 +1145,7 @@ CHECK: arith.trunci %[[I8_PARAM]] : tensor<4xi8> to tensor<4xi1> } TEST_F(TritonEmitterTest, Transpose3D) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule m triton_computation { @@ -1175,7 +1175,7 @@ CHECK: tt.trans %[[TILE]] {order = array} : tensor<8x4x1xf32> // TODO(b/353484968): Delete this test once we have constraints to only // propagate tile sizes that are a power of 2. TEST_F(TritonEmitterTest, Transpose3D_TileFullDimThatIsNotPowerOf2) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule m triton_computation { diff --git a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_large_test.cc b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_large_test.cc index 3e52ae45936b67..3cb8c7ce9f451f 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_large_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_large_test.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include +#include #include #include "absl/log/check.h" @@ -87,7 +88,7 @@ ENTRY e { } TEST_F(TritonGemmTest, LargeNonContractingProductWorks) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule m ENTRY e { @@ -111,7 +112,7 @@ ENTRY e { } TEST_F(TritonGemmTest, LargeBatchWorks) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule m ENTRY e { From 0de80ea346b9b1b9a18ca792c00b1a74eb948f33 Mon Sep 17 00:00:00 2001 From: Vlad Sytchenko Date: Tue, 24 Sep 2024 13:28:24 -0700 Subject: [PATCH 202/483] Reland #17228 A couple of changes from the original change: 1. Don't use HloInstruction::operand_index() - this only returns the *first* occurence of an instruction in the operand sequence, thus if the same instruction is used in place of multiple orepards, we'll miss the subsequent ones. 2. Handle propagating throught root instructions better. We originally only fixed up entry computation roots but we need should the same for any while/conditional root, otherwise inserting tokens in these types of roots is non-trivial. Simplify things by explicitly disjoining these instructions from being roots during canonicalization. Reverts 9e1e4354fe2283c39e5d79a3391d0f1d19958f6b PiperOrigin-RevId: 678374890 --- third_party/xla/xla/hlo/ir/hlo_instruction.cc | 32 +- third_party/xla/xla/hlo/ir/hlo_instruction.h | 11 +- third_party/xla/xla/service/BUILD | 35 + .../xla/service/infeed_token_propagation.cc | 457 ++++++++++++ .../xla/service/infeed_token_propagation.h | 45 ++ .../service/infeed_token_propagation_test.cc | 653 ++++++++++++++++++ .../while_loop_invariant_code_motion.cc | 1 + 7 files changed, 1228 insertions(+), 6 deletions(-) create mode 100644 third_party/xla/xla/service/infeed_token_propagation.cc create mode 100644 third_party/xla/xla/service/infeed_token_propagation.h create mode 100644 third_party/xla/xla/service/infeed_token_propagation_test.cc diff --git a/third_party/xla/xla/hlo/ir/hlo_instruction.cc b/third_party/xla/xla/hlo/ir/hlo_instruction.cc index 7ca026f7c94448..ffcfafbcc663d7 100644 --- a/third_party/xla/xla/hlo/ir/hlo_instruction.cc +++ b/third_party/xla/xla/hlo/ir/hlo_instruction.cc @@ -2759,6 +2759,20 @@ int64_t HloInstruction::operand_index(const HloInstruction* target) const { LOG(FATAL) << "target was not an operand: " << target->ToString(); } +std::vector HloInstruction::operand_indices( + const HloInstruction* target) const { + std::vector indices; + for (int64_t i = 0; i < operand_count(); ++i) { + if (target == operand(i)) { + indices.push_back(i); + } + } + if (indices.empty()) { + LOG(FATAL) << "target was not an operand: " << target->ToString(); + } + return indices; +} + HloInstruction::InstructionVector HloInstruction::unique_operands() const { InstructionVector unique; absl::flat_hash_set seen; @@ -3399,18 +3413,30 @@ const PtrVec& HloInstruction::branch_computations() const { return called_computations(); } -int HloInstruction::branch_count() const { +int32_t HloInstruction::branch_count() const { CHECK(HloOpcode::kConditional == opcode_); return called_computations().size(); } -HloComputation* HloInstruction::branch_computation(int b) const { - CHECK(HloOpcode::kConditional == opcode_); +HloComputation* HloInstruction::branch_computation(int32_t b) const { + CHECK_EQ(HloOpcode::kConditional, opcode_); CHECK_GE(b, 0); CHECK_LT(b, called_computations().size()); return called_computations()[b]; } +int32_t HloInstruction::branch_index(HloComputation* computation) const { + CHECK_EQ(HloOpcode::kConditional, opcode_); + CHECK_NE(computation, nullptr); + for (int32_t idx = 0; idx < branch_count(); idx++) { + if (branch_computation(idx) == computation) { + return idx; + } + } + LOG(FATAL) << absl::StrFormat("Conditional %s does not contain branch %s", + name(), computation->name()); +} + void HloInstruction::set_branch_computation(int b, HloComputation* computation) { CHECK_EQ(HloOpcode::kConditional, opcode_); diff --git a/third_party/xla/xla/hlo/ir/hlo_instruction.h b/third_party/xla/xla/hlo/ir/hlo_instruction.h index 42729daec64df3..c1b64c06c7633c 100644 --- a/third_party/xla/xla/hlo/ir/hlo_instruction.h +++ b/third_party/xla/xla/hlo/ir/hlo_instruction.h @@ -1493,10 +1493,14 @@ class HloInstruction { // within the operand vector. InstructionVector unique_operands() const; - // Returns the index of 'target' in the operands sequence. + // Returns the first index of 'target' that occurs in the operands sequence. // Precondition: target must be an operand (or a fatal error will occur). int64_t operand_index(const HloInstruction* target) const; + // Returns all indices of 'target' that occur in the operands sequence. + // Precondition: target must be an operand (or a fatal error will occur). + std::vector operand_indices(const HloInstruction* target) const; + // Returns the number of users of this instruction. int64_t user_count() const { return users_.size(); } @@ -1808,8 +1812,9 @@ class HloInstruction { // // Precondition: The instruction is a Conditional instruction. const PtrVec& branch_computations() const; - int branch_count() const; - HloComputation* branch_computation(int b) const; + int32_t branch_count() const; + HloComputation* branch_computation(int32_t b) const; + int32_t branch_index(HloComputation* computation) const; // Sets a branch HloComputation for Conditional. // The setter should only be called by HloModule or HloComputation methods. // diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD index bda32e96ee0fac..53d6b29fc5bcd7 100644 --- a/third_party/xla/xla/service/BUILD +++ b/third_party/xla/xla/service/BUILD @@ -8543,4 +8543,39 @@ xla_cc_test( ], ) +cc_library( + name = "infeed_token_propagation", + srcs = ["infeed_token_propagation.cc"], + hdrs = ["infeed_token_propagation.h"], + deps = [ + ":hlo_dce", + ":tuple_simplifier", + "//xla:shape_util", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "infeed_token_propagation_test", + srcs = ["infeed_token_propagation_test.cc"], + deps = [ + ":infeed_token_propagation", + "//xla/hlo/ir:hlo", + "//xla/hlo/utils:hlo_matchers", + "//xla/tests:hlo_test_base", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:statusor", + ], +) + exports_files(["xla_aot_compile_test_gpu_target_config.prototxt"]) diff --git a/third_party/xla/xla/service/infeed_token_propagation.cc b/third_party/xla/xla/service/infeed_token_propagation.cc new file mode 100644 index 00000000000000..c14fe7e1824086 --- /dev/null +++ b/third_party/xla/xla/service/infeed_token_propagation.cc @@ -0,0 +1,457 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/infeed_token_propagation.h" + +#include +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/types/span.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/service/hlo_dce.h" +#include "xla/service/tuple_simplifier.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/util.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace { +bool IsDanglingInfeed(HloInstruction* infeed) { + CHECK(infeed->opcode() == HloOpcode::kInfeed); + if (infeed->has_sharding()) { + // TODO: b/368327832 - Skip handling sharding until it is removed. + return false; + } + + // Check for dangling input token. + if (const HloInstruction* after_all = infeed->operand(0); + after_all->opcode() != HloOpcode::kAfterAll || + after_all->operand_count() != 0) { + return false; + } + + // Check for dangling output token. + for (const HloInstruction* user : infeed->users()) { + if (user->opcode() == HloOpcode::kGetTupleElement && + user->tuple_index() == 1) { + return false; + } + } + + return true; +} + +bool IsDanglingOutfeed(HloInstruction* outfeed) { + CHECK(outfeed->opcode() == HloOpcode::kOutfeed); + if (outfeed->has_sharding()) { + // TODO: b/368327832 - Skip handling sharding until it is removed. + return false; + } + + // Check for dangling input token. + if (const HloInstruction* after_all = outfeed->operand(1); + after_all->opcode() != HloOpcode::kAfterAll || + after_all->operand_count() != 0) { + return false; + } + + // Check for dangling output token. + if (outfeed->user_count() != 0) { + return false; + } + + return true; +} + +HloInstruction* ReconstructTuple(HloInstruction* tuple) { + CHECK(tuple->shape().IsTuple()); + HloComputation* computation = tuple->parent(); + + std::vector gtes; + gtes.resize(tuple->shape().tuple_shapes_size()); + for (int64_t idx = 0; idx < gtes.size(); ++idx) { + gtes[idx] = computation->AddInstruction( + HloInstruction::CreateGetTupleElement(tuple, idx)); + } + + return computation->AddInstruction(HloInstruction::CreateTuple(gtes)); +} + +absl::StatusOr InsertTokenIntoTuple(HloInstruction* tuple, + bool add_token_operand) { + CHECK(tuple->shape().IsTuple()); + HloComputation* computation = tuple->parent(); + + // Recreate the original tuple, we'll need to pass this to all the users. + // Trying to use tuple->ReplaceAllUsesWith(original_tuple) cause a cycle. + std::vector original_users = tuple->users(); + HloInstruction* original_tuple = ReconstructTuple(tuple); + for (HloInstruction* original_user : original_users) { + for (int64_t idx : original_user->operand_indices(tuple)) { + TF_RETURN_IF_ERROR( + original_user->ReplaceOperandWith(idx, original_tuple)); + } + } + + // Append the token to the parameter tuple. + *tuple->mutable_shape()->add_tuple_shapes() = ShapeUtil::MakeTokenShape(); + if (add_token_operand) { + tuple->AppendOperand( + computation->AddInstruction(HloInstruction::CreateToken())); + } + + HloInstruction* input_token_gte = + computation->AddInstruction(HloInstruction::CreateGetTupleElement( + tuple, tuple->shape().tuple_shapes_size() - 1)); + return input_token_gte; +} + +absl::Status CanonicalizeConditionalBranch(HloComputation* branch) { + CHECK(branch->IsConditionalBranchComputation()); + CHECK_EQ(branch->num_parameters(), 1); + + // Tuplify the branch parameter if needed. + HloInstruction* parameter = branch->parameter_instruction(0); + if (!parameter->shape().IsTuple()) { + *parameter->mutable_shape() = + ShapeUtil::MakeTupleShape({parameter->shape()}); + HloInstruction* original = branch->AddInstruction( + HloInstruction::CreateGetTupleElement(parameter, 0)); + TF_RETURN_IF_ERROR(parameter->ReplaceAllUsesWithDifferentShape(original)); + } + + // Tuplify the branch tuple if needed. + HloInstruction* conditional = branch->ConditionalCallInstruction(); + int64_t branch_operand_idx = conditional->branch_index(branch) + 1; + HloInstruction* branch_tuple = + conditional->mutable_operand(branch_operand_idx); + if (!branch_tuple->shape().IsTuple()) { + branch_tuple = conditional->parent()->AddInstruction( + HloInstruction::CreateTuple({branch_tuple})); + TF_RETURN_IF_ERROR(conditional->ReplaceOperandWithDifferentShape( + branch_operand_idx, branch_tuple)); + } + + // Explicitly disjoin computation parameters from branch inputs, so we can + // insert tokens into the input tuple. + if (branch_tuple->opcode() == HloOpcode::kParameter) { + branch_tuple = ReconstructTuple(branch_tuple); + TF_RETURN_IF_ERROR( + conditional->ReplaceOperandWith(branch_operand_idx, branch_tuple)); + } + + // Explicitly make the root of the branch a tuple. + HloInstruction* root = branch->root_instruction(); + if (root->opcode() != HloOpcode::kTuple) { + root = ReconstructTuple(root); + branch->set_root_instruction(root); + } + + // ConditionalCanonicalizer should have already turned the conditional output + // to be a tuple. + CHECK(conditional->shape().IsTuple()); + + // Explicitly disjoin the conditional from being a computation root, so that + // we can insert tokens into, while preserving the original computation shape. + if (conditional->IsRoot()) { + HloInstruction* new_root = ReconstructTuple(conditional); + conditional->parent()->set_root_instruction(new_root); + } + + return absl::OkStatus(); +} + +absl::Status CanonicalizeWhileBody(HloComputation* body) { + CHECK(body->IsWhileBodyComputation()); + CHECK_EQ(body->num_parameters(), 1); + + // Tuplify the body parameter if needed. + HloInstruction* parameter = body->parameter_instruction(0); + if (!parameter->shape().IsTuple()) { + *parameter->mutable_shape() = + ShapeUtil::MakeTupleShape({parameter->shape()}); + HloInstruction* original = body->AddInstruction( + HloInstruction::CreateGetTupleElement(parameter, 0)); + TF_RETURN_IF_ERROR(parameter->ReplaceAllUsesWithDifferentShape(original)); + } + + // Tuplify the body root if needed. + HloInstruction* root = body->root_instruction(); + if (!root->shape().IsTuple()) { + root = body->AddInstruction(HloInstruction::CreateTuple({root})); + body->set_root_instruction(root, /*accept_different_shape=*/true); + } + + // Tuplify the condition parameter if needed. + HloInstruction* loop = body->WhileCallInstruction(); + HloComputation* cond = loop->while_condition(); + HloInstruction* cond_parameter = cond->parameter_instruction(0); + if (!cond_parameter->shape().IsTuple()) { + *cond_parameter->mutable_shape() = + ShapeUtil::MakeTupleShape({cond_parameter->shape()}); + HloInstruction* original = cond->AddInstruction( + HloInstruction::CreateGetTupleElement(cond_parameter, 0)); + TF_RETURN_IF_ERROR( + cond_parameter->ReplaceAllUsesWithDifferentShape(original)); + } + + // Tuplify the while instruction if needed. + if (!loop->shape().IsTuple()) { + *loop->mutable_shape() = ShapeUtil::MakeTupleShape({loop->shape()}); + HloInstruction* original = loop->parent()->AddInstruction( + HloInstruction::CreateGetTupleElement(loop, 0)); + TF_RETURN_IF_ERROR(loop->ReplaceAllUsesWithDifferentShape(original)); + } + + // Tuplify the while tuple if needed. + HloInstruction* loop_tuple = loop->mutable_operand(0); + if (!loop_tuple->shape().IsTuple()) { + loop_tuple = loop->parent()->AddInstruction( + HloInstruction::CreateTuple({loop_tuple})); + TF_RETURN_IF_ERROR(loop->ReplaceOperandWithDifferentShape(0, loop_tuple)); + } + + // Explicitly disjoin computation parameters from loop inputs, so we can + // insert tokens into the input tuple. + if (loop_tuple->opcode() == HloOpcode::kParameter) { + loop_tuple = ReconstructTuple(loop_tuple); + TF_RETURN_IF_ERROR(loop->ReplaceOperandWith(0, loop_tuple)); + } + + // Explicitly make the root of the body a tuple. + if (root->opcode() != HloOpcode::kTuple) { + root = ReconstructTuple(root); + body->set_root_instruction(root); + } + + // Explicitly disjoin the loop from being a computation root, so that + // we can insert tokens into, while preserving the original computation shape. + if (loop->IsRoot()) { + HloInstruction* new_root = ReconstructTuple(loop); + loop->parent()->set_root_instruction(new_root); + } + + return absl::OkStatus(); +} + +absl::StatusOr> +PropagateTokenThroughConditionalBranch(HloInstruction* instruction, + HloInstruction* input_token, + HloInstruction* output_token) { + // Conditional branches can diverge in inputs, but must converge on outputs. + + // Fixup every branch of the conditional, since we have to insert a token + // into each branches root. + HloComputation* comp = instruction->parent(); + HloInstruction* next_instruction = comp->ConditionalCallInstruction(); + for (HloComputation* branch : next_instruction->branch_computations()) { + TF_RETURN_IF_ERROR(CanonicalizeConditionalBranch(branch)); + } + + // Insert the output token into each branch. + for (HloComputation* branch : next_instruction->branch_computations()) { + HloInstruction* root = branch->root_instruction(); + if (branch == comp) { + TF_RETURN_IF_ERROR( + InsertTokenIntoTuple(root, /*add_token_operand=*/false).status()); + root->AppendOperand(output_token); + } else { + TF_RETURN_IF_ERROR( + InsertTokenIntoTuple(root, /*add_token_operand=*/true).status()); + } + } + + // Insert the input token into the branch parameter. + HloInstruction* parameter = comp->parameter_instruction(0); + TF_ASSIGN_OR_RETURN( + HloInstruction * input_token_gte, + InsertTokenIntoTuple(parameter, /*add_token_operand=*/false)); + TF_RETURN_IF_ERROR(input_token->ReplaceAllUsesWith(input_token_gte)); + + // Insert the input token into the branch tuple. + int64_t branch_operand_idx = next_instruction->branch_index(comp) + 1; + HloInstruction* branch_tuple = + next_instruction->mutable_operand(branch_operand_idx); + TF_ASSIGN_OR_RETURN( + HloInstruction * next_input_token_gte, + InsertTokenIntoTuple(branch_tuple, /*add_token_operand=*/true)); + TF_RETURN_IF_ERROR(next_instruction->ReplaceOperandWithDifferentShape( + branch_operand_idx, branch_tuple)); + HloInstruction* next_input_token = + branch_tuple->mutable_operand(next_input_token_gte->tuple_index()); + + // Insert the output token into conditional instruction. + TF_ASSIGN_OR_RETURN( + HloInstruction * next_output_token, + InsertTokenIntoTuple(next_instruction, /*add_token_operand=*/false)); + + return std::make_tuple(next_instruction, next_input_token, next_output_token); +} + +absl::StatusOr> +PropagateTokenThroughWhileBody(HloInstruction* instruction, + HloInstruction* input_token, + HloInstruction* output_token) { + // While loops need to converge on input and output. + + // Fixup the while body. + HloComputation* comp = instruction->parent(); + TF_RETURN_IF_ERROR(CanonicalizeWhileBody(comp)); + HloInstruction* next_instruction = comp->WhileCallInstruction(); + + // Insert the output token into the body root. + HloInstruction* root = comp->root_instruction(); + TF_RETURN_IF_ERROR( + InsertTokenIntoTuple(root, /*add_token_operand=*/false).status()); + root->AppendOperand(output_token); + + // Insert the input token into the body parameter. + HloInstruction* body_parameter = comp->parameter_instruction(0); + TF_ASSIGN_OR_RETURN( + HloInstruction * input_token_gte, + InsertTokenIntoTuple(body_parameter, /*add_token_operand=*/false)); + TF_RETURN_IF_ERROR(input_token->ReplaceAllUsesWith(input_token_gte)); + + // Insert the input token into the condition parameter. + HloComputation* cond = next_instruction->while_condition(); + HloInstruction* cond_parameter = cond->parameter_instruction(0); + TF_RETURN_IF_ERROR( + InsertTokenIntoTuple(cond_parameter, /*add_token_operand=*/false) + .status()); + + // Insert the input token into the while tuple. + HloInstruction* while_tuple = next_instruction->mutable_operand(0); + TF_ASSIGN_OR_RETURN( + HloInstruction * next_input_token, + InsertTokenIntoTuple(while_tuple, /*add_token_operand=*/true)); + TF_RETURN_IF_ERROR( + next_instruction->ReplaceOperandWithDifferentShape(0, while_tuple)); + + // Insert the input token into the while instruction. + TF_ASSIGN_OR_RETURN( + HloInstruction * next_output_token, + InsertTokenIntoTuple(next_instruction, /*add_token_operand=*/false)); + + return std::make_tuple(next_instruction, next_input_token, next_output_token); +} + +absl::Status PropagateToken(HloInstruction* instruction, + HloInstruction* input_token, + HloInstruction* output_token) { + HloComputation* comp = instruction->parent(); + if (comp->IsEntryComputation()) { + return absl::OkStatus(); + } + + HloInstruction* next_instruction = nullptr; + HloInstruction* next_input_token = nullptr; + HloInstruction* next_output_token = nullptr; + if (comp->IsConditionalBranchComputation()) { + // TODO: b/368327832 - Skip handling sharding until it is removed. + if (comp->ConditionalCallInstruction()->has_sharding()) { + return absl::OkStatus(); + } + TF_ASSIGN_OR_RETURN( + std::tie(next_instruction, next_input_token, next_output_token), + PropagateTokenThroughConditionalBranch(instruction, input_token, + output_token)); + } else if (comp->IsWhileBodyComputation()) { + // TODO: b/368327832 - Skip handling sharding until it is removed. + if (comp->WhileCallInstruction()->has_sharding()) { + return absl::OkStatus(); + } + TF_ASSIGN_OR_RETURN( + std::tie(next_instruction, next_input_token, next_output_token), + PropagateTokenThroughWhileBody(instruction, input_token, output_token)); + } else { + // We only expect to encounter computations behind while and conditional + // instructions. In the case of it being behind a while condition, there is + // no way to propagate the output token, as the root only returns a + // predicate. All other computations that could possibly contain infeed + // or outfeed ops should have already been inlined. + VLOG(2) << "Unhandled computation: " << comp->name(); + return absl::OkStatus(); + } + CHECK_NE(next_instruction, nullptr); + CHECK_NE(next_input_token, nullptr); + CHECK_NE(next_output_token, nullptr); + + return PropagateToken(next_instruction, next_input_token, next_output_token); +} +} // namespace + +absl::StatusOr InfeedTokenPropagation::Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) { + VLOG(5) << "Before InfeedTokenPropagation:"; + XLA_VLOG_LINES(5, module->ToString()); + + std::vector dangling_infeeds; + std::vector dangling_outfeeds; + for (HloComputation* computation : + module->MakeNonfusionComputations(execution_threads)) { + if (!computation->IsEntryComputation()) { + for (HloInstruction* instruction : computation->instructions()) { + if (instruction->opcode() == HloOpcode::kInfeed && + IsDanglingInfeed(instruction)) { + VLOG(1) << "Found dangling infeed: " << instruction->ToString(); + dangling_infeeds.push_back(instruction); + } else if (instruction->opcode() == HloOpcode::kOutfeed && + IsDanglingOutfeed(instruction)) { + VLOG(1) << "Found dangling outfeed: " << instruction->ToString(); + dangling_outfeeds.push_back(instruction); + } + } + } + } + + for (HloInstruction* dangling_infeed : dangling_infeeds) { + HloInstruction* input_token = dangling_infeed->mutable_operand(0); + HloInstruction* output_token = dangling_infeed->AddInstruction( + HloInstruction::CreateGetTupleElement(dangling_infeed, 1)); + TF_RETURN_IF_ERROR( + PropagateToken(dangling_infeed, input_token, output_token)); + } + for (HloInstruction* dangling_outfeed : dangling_outfeeds) { + HloInstruction* input_token = dangling_outfeed->mutable_operand(1); + HloInstruction* output_token = dangling_outfeed; + TF_RETURN_IF_ERROR( + PropagateToken(dangling_outfeed, input_token, output_token)); + } + + bool changed = !dangling_infeeds.empty() || !dangling_outfeeds.empty(); + if (changed) { + TF_RETURN_IF_ERROR( + TupleSimplifier().Run(module, execution_threads).status()); + TF_RETURN_IF_ERROR(HloDCE().Run(module, execution_threads).status()); + } + + VLOG(5) << "After InfeedTokenPropagation:"; + XLA_VLOG_LINES(5, module->ToString()); + return changed; +} +} // namespace xla diff --git a/third_party/xla/xla/service/infeed_token_propagation.h b/third_party/xla/xla/service/infeed_token_propagation.h new file mode 100644 index 00000000000000..cc6994a62a98a9 --- /dev/null +++ b/third_party/xla/xla/service/infeed_token_propagation.h @@ -0,0 +1,45 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_INFEED_TOKEN_PROPAGATION_H_ +#define XLA_SERVICE_INFEED_TOKEN_PROPAGATION_H_ + +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/pass/hlo_pass_interface.h" + +namespace xla { +// Finds dangling infeed/outfeed tokens inside nested computations and bubbles +// them up through callers until they reach the entry computation. This is +// needed to prepare these computations to be inlined, otherwise the previous +// computation boundaries won't be there to stop infeeds/outfeeds from being +// reordered during scheduling. +// +// This pass assumes the HLO graph is flattened. +class InfeedTokenPropagation : public HloModulePass { + public: + std::string_view name() const override { return "infeed-token-propagation"; } + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; +}; +} // namespace xla + +#endif // XLA_SERVICE_INFEED_TOKEN_PROPAGATION_H_ diff --git a/third_party/xla/xla/service/infeed_token_propagation_test.cc b/third_party/xla/xla/service/infeed_token_propagation_test.cc new file mode 100644 index 00000000000000..59df2cc631954e --- /dev/null +++ b/third_party/xla/xla/service/infeed_token_propagation_test.cc @@ -0,0 +1,653 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/infeed_token_propagation.h" + +#include +#include + +#include +#include +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/utils/hlo_matchers.h" +#include "xla/tests/hlo_test_base.h" +#include "tsl/platform/statusor.h" + +namespace op = xla::testing::opcode_matchers; + +namespace xla { +namespace { + +class InfeedTokenPropagationTest : public HloTestBase { + protected: + InfeedTokenPropagationTest() = default; +}; + +TEST_F(InfeedTokenPropagationTest, EntryComputationInfeed) { + constexpr std::string_view hlo = R"( +HloModule main + +ENTRY main { + token.0 = after-all() + infeed.0 = (s32[], token[]) infeed(token.0) + ROOT gte.0 = get-tuple-element(infeed.0), index=0 +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo)); + InfeedTokenPropagation itp; + TF_ASSERT_OK_AND_ASSIGN(bool changed, itp.Run(module.get())); + EXPECT_FALSE(changed); +} + +TEST_F(InfeedTokenPropagationTest, EntryComputationOutfeed) { + constexpr std::string_view hlo = R"( +HloModule main + +ENTRY main { + arg.0 = s32[] parameter(0) + tuple.0 = tuple(arg.0) + token.0 = after-all() + outfeed.0 = token[] outfeed(tuple.0, token.0), outfeed_shape=(s32[]) + ROOT tuple.1 = tuple() +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo)); + InfeedTokenPropagation itp; + TF_ASSERT_OK_AND_ASSIGN(bool changed, itp.Run(module.get())); + EXPECT_FALSE(changed); +} + +TEST_F(InfeedTokenPropagationTest, ConditionalInfeed) { + constexpr std::string_view hlo = R"( +HloModule main + +true_comp { + arg.0 = () parameter(0) + token.0 = after-all() + infeed.0 = (s32[], token[]) infeed(token.0) + ROOT tuple.0 = tuple() +} + +false_comp { + arg.0 = () parameter(0) + ROOT tuple.0 = tuple() +} + +ENTRY main { + pred.0 = pred[] constant(true) + true_tuple.0 = tuple() + false_tuple.0 = tuple() + ROOT cond.0 = () conditional(pred.0, true_tuple.0, false_tuple.0), true_computation=true_comp, false_computation=false_comp +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo)); + InfeedTokenPropagation itp; + TF_ASSERT_OK_AND_ASSIGN(bool changed, itp.Run(module.get())); + EXPECT_TRUE(changed); + + // The infeed output token should have propagated through the conditional. + HloInstruction* cond = FindInstruction(module.get(), "cond.0"); + EXPECT_EQ(cond->shape().tuple_shapes_size(), 1); + EXPECT_TRUE(cond->shape().tuple_shapes()[0].IsToken()); + + // The infeed input token should have propagated through the true tuple. + HloInstruction* true_tuple = FindInstruction(module.get(), "true_tuple.0"); + EXPECT_EQ(true_tuple->shape().tuple_shapes_size(), 1); + EXPECT_TRUE(true_tuple->shape().tuple_shapes()[0].IsToken()); + + // The infeed input token should not have propagated through the false tuple. + HloInstruction* false_tuple = FindInstruction(module.get(), "false_tuple.0"); + EXPECT_EQ(false_tuple->shape().tuple_shapes_size(), 0); + + // The infeed output token should have propagated through the true + // computation's root. + HloComputation* true_comp = FindComputation(module.get(), "true_comp"); + EXPECT_THAT(true_comp->root_instruction(), + op::Tuple(op::GetTupleElement(op::Infeed(), 1))); + + // The infeed output token should have propagated to the false computation's + // root. + HloComputation* false_comp = FindComputation(module.get(), "false_comp"); + EXPECT_THAT(false_comp->root_instruction(), op::Tuple(op::AfterAll())); +} + +TEST_F(InfeedTokenPropagationTest, ConditionalOutfeed) { + constexpr std::string_view hlo = R"( +HloModule main + +true_comp { + arg.0 = (s32[]) parameter(0) + token.0 = after-all() + outfeed.0 = token[] outfeed(arg.0, token.0), outfeed_shape=(s32[]) + ROOT tuple.0 = tuple() +} + +false_comp { + arg.0 = () parameter(0) + ROOT tuple.0 = tuple() +} + +ENTRY main { + arg.0 = s32[] parameter(0) + pred.0 = pred[] constant(true) + true_tuple.0 = tuple(arg.0) + false_tuple.0 = tuple() + ROOT cond.0 = () conditional(pred.0, true_tuple.0, false_tuple.0), true_computation=true_comp, false_computation=false_comp +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo)); + InfeedTokenPropagation itp; + TF_ASSERT_OK_AND_ASSIGN(bool changed, itp.Run(module.get())); + EXPECT_TRUE(changed); + + // The outfeed output token should have propagated through the conditional. + HloInstruction* cond = FindInstruction(module.get(), "cond.0"); + EXPECT_EQ(cond->shape().tuple_shapes_size(), 1); + EXPECT_TRUE(cond->shape().tuple_shapes()[0].IsToken()); + + // The outfeed input token should have propagated through the true tuple. + HloInstruction* true_tuple = FindInstruction(module.get(), "true_tuple.0"); + EXPECT_EQ(true_tuple->shape().tuple_shapes_size(), 2); + EXPECT_TRUE(true_tuple->shape().tuple_shapes()[1].IsToken()); + + // The outfeed input token should not have propagated through the false tuple. + HloInstruction* false_tuple = FindInstruction(module.get(), "false_tuple.0"); + EXPECT_EQ(false_tuple->shape().tuple_shapes_size(), 0); + + // The outfeed output token should have propagated through the true + // computation's root. + HloComputation* true_comp = FindComputation(module.get(), "true_comp"); + EXPECT_THAT(true_comp->root_instruction(), op::Tuple(op::Outfeed())); + + // The outfeed output token should have propagated to the false computation's + // root. + HloComputation* false_comp = FindComputation(module.get(), "false_comp"); + EXPECT_THAT(false_comp->root_instruction(), op::Tuple(op::AfterAll())); +} + +TEST_F(InfeedTokenPropagationTest, ConditionalDuplicateOperand) { + constexpr std::string_view hlo = R"( +HloModule main + +true_comp { + arg.0 = () parameter(0) + token.0 = after-all() + infeed.0 = (s32[], token[]) infeed(token.0) + ROOT tuple.0 = tuple() +} + +false_comp { + arg.0 = () parameter(0) + ROOT tuple.0 = tuple() +} + +ENTRY main { + pred.0 = pred[] constant(true) + tuple.0 = tuple() + ROOT cond.0 = () conditional(pred.0, tuple.0, tuple.0), true_computation=true_comp, false_computation=false_comp +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo)); + InfeedTokenPropagation itp; + TF_ASSERT_OK_AND_ASSIGN(bool changed, itp.Run(module.get())); + EXPECT_TRUE(changed); + + // The infeed output token should have propagated through the conditional. + HloInstruction* cond = FindInstruction(module.get(), "cond.0"); + EXPECT_EQ(cond->shape().tuple_shapes_size(), 1); + EXPECT_TRUE(cond->shape().tuple_shapes()[0].IsToken()); + + // The infeed input token should have propagated through the true tuple. + const HloInstruction* true_tuple = cond->operand(1); + EXPECT_EQ(true_tuple->shape().tuple_shapes_size(), 1); + EXPECT_TRUE(true_tuple->shape().tuple_shapes()[0].IsToken()); + + // The infeed input token should not have propagated through the false tuple. + const HloInstruction* false_tuple = cond->operand(2); + EXPECT_EQ(false_tuple->shape().tuple_shapes_size(), 0); + + // The infeed output token should have propagated through the true + // computation's root. + HloComputation* true_comp = FindComputation(module.get(), "true_comp"); + EXPECT_THAT(true_comp->root_instruction(), + op::Tuple(op::GetTupleElement(op::Infeed(), 1))); + + // The infeed output token should have propagated to the false computation's + // root. + HloComputation* false_comp = FindComputation(module.get(), "false_comp"); + EXPECT_THAT(false_comp->root_instruction(), op::Tuple(op::AfterAll())); +} + +TEST_F(InfeedTokenPropagationTest, NonTupleConditional) { + constexpr std::string_view hlo = R"( +HloModule main + +true_comp { + arg.0 = s32[] parameter(0) + outfeed_tuple.0 = tuple(arg.0) + token.0 = after-all() + outfeed.0 = token[] outfeed(outfeed_tuple.0, token.0), outfeed_shape=(s32[]) + ROOT tuple.0 = tuple() +} + +false_comp { + arg.0 = () parameter(0) + ROOT tuple.0 = tuple() +} + +ENTRY main { + arg.0 = s32[] parameter(0) + pred.0 = pred[] constant(true) + false_tuple.0 = tuple() + ROOT cond.0 = () conditional(pred.0, arg.0, false_tuple.0), true_computation=true_comp, false_computation=false_comp +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo)); + InfeedTokenPropagation itp; + TF_ASSERT_OK_AND_ASSIGN(bool changed, itp.Run(module.get())); + EXPECT_TRUE(changed); + + // The outfeed output token should have propagated through the conditional. + HloInstruction* cond = FindInstruction(module.get(), "cond.0"); + EXPECT_EQ(cond->shape().tuple_shapes_size(), 1); + EXPECT_TRUE(cond->shape().tuple_shapes()[0].IsToken()); + + // The outfeed input token should have propagated through the true tuple. + HloInstruction* true_tuple = cond->mutable_operand(1); + EXPECT_TRUE(true_tuple->shape().IsTuple()); + EXPECT_EQ(true_tuple->shape().tuple_shapes_size(), 2); + EXPECT_TRUE(true_tuple->shape().tuple_shapes()[1].IsToken()); + + // The outfeed input token should not have propagated through the false tuple. + HloInstruction* false_tuple = FindInstruction(module.get(), "false_tuple.0"); + EXPECT_EQ(false_tuple->shape().tuple_shapes_size(), 0); + + // The outfeed output token should have propagated through the true + // computation's root. + HloComputation* true_comp = FindComputation(module.get(), "true_comp"); + EXPECT_THAT(true_comp->root_instruction(), op::Tuple(op::Outfeed())); + + // The outfeed output token should have propagated to the false computation's + // root. + HloComputation* false_comp = FindComputation(module.get(), "false_comp"); + EXPECT_THAT(false_comp->root_instruction(), op::Tuple(op::AfterAll())); +} + +TEST_F(InfeedTokenPropagationTest, DisjointConditionalOutfeed) { + constexpr std::string_view hlo = R"( +HloModule main + +true_comp { + ROOT arg.0 = () parameter(0) + one.0 = s32[] constant(1) + outfeed_tuple.0 = tuple(one.0) + token.0 = after-all() + outfeed.0 = token[] outfeed(outfeed_tuple.0, token.0), outfeed_shape=(s32[]) +} + +false_comp { + arg.0 = () parameter(0) + ROOT tuple.0 = tuple() +} + +ENTRY main { + pred.0 = pred[] constant(true) + true_tuple.0 = tuple() + false_tuple.0 = tuple() + ROOT cond.0 = () conditional(pred.0, true_tuple.0, false_tuple.0), true_computation=true_comp, false_computation=false_comp +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo)); + InfeedTokenPropagation itp; + TF_ASSERT_OK_AND_ASSIGN(bool changed, itp.Run(module.get())); + EXPECT_TRUE(changed); + + // The outfeed output token should have propagated through the conditional. + HloInstruction* cond = FindInstruction(module.get(), "cond.0"); + EXPECT_EQ(cond->shape().tuple_shapes_size(), 1); + EXPECT_TRUE(cond->shape().tuple_shapes()[0].IsToken()); + + // The outfeed input token should have propagated through the true tuple. + HloInstruction* true_tuple = FindInstruction(module.get(), "true_tuple.0"); + EXPECT_EQ(true_tuple->shape().tuple_shapes_size(), 1); + EXPECT_TRUE(true_tuple->shape().tuple_shapes()[0].IsToken()); + + // The outfeed input token should not have propagated through the false tuple. + HloInstruction* false_tuple = FindInstruction(module.get(), "false_tuple.0"); + EXPECT_EQ(false_tuple->shape().tuple_shapes_size(), 0); + + // The outfeed output token should have propagated through the true + // computation's root. + HloComputation* true_comp = FindComputation(module.get(), "true_comp"); + EXPECT_THAT(true_comp->root_instruction(), op::Tuple(op::Outfeed())); + + // The outfeed output token should have propagated to the false computation's + // root. + HloComputation* false_comp = FindComputation(module.get(), "false_comp"); + EXPECT_THAT(false_comp->root_instruction(), op::Tuple(op::AfterAll())); +} + +TEST_F(InfeedTokenPropagationTest, WhileInfeed) { + constexpr std::string_view hlo = R"( +HloModule main + +comp { + arg.0 = () parameter(0) + token.0 = after-all() + infeed.0 = (s32[], token[]) infeed(token.0) + ROOT tuple.0 = tuple() +} + +cond { + arg.0 = () parameter(0) + ROOT true.0 = pred[] constant(true) +} + +ENTRY main { + while_tuple.0 = tuple() + ROOT while.0 = () while(while_tuple.0), condition=cond, body=comp +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo)); + InfeedTokenPropagation itp; + TF_ASSERT_OK_AND_ASSIGN(bool changed, itp.Run(module.get())); + EXPECT_TRUE(changed); + + // The infeed output token should have propagated through the loop. + HloInstruction* loop = FindInstruction(module.get(), "while.0"); + EXPECT_EQ(loop->shape().tuple_shapes_size(), 1); + EXPECT_TRUE(loop->shape().tuple_shapes()[0].IsToken()); + + // The infeed input token should have propagated through the loop tuple. + HloInstruction* loop_tuple = FindInstruction(module.get(), "while_tuple.0"); + EXPECT_EQ(loop_tuple->shape().tuple_shapes_size(), 1); + EXPECT_TRUE(loop_tuple->shape().tuple_shapes()[0].IsToken()); + + // The infeed output token should have propagated through the while body root. + HloComputation* body_comp = FindComputation(module.get(), "comp"); + EXPECT_THAT(body_comp->root_instruction(), + op::Tuple(op::GetTupleElement(op::Infeed(), 1))); + + // The infeed input token should have propagated through the body parameter. + HloInstruction* body_param = body_comp->parameter_instruction(0); + EXPECT_EQ(body_param->shape().tuple_shapes_size(), 1); + EXPECT_TRUE(body_param->shape().tuple_shapes()[0].IsToken()); + + // The infeed input token should have propagated through the condition + // parameter. + HloComputation* cond_comp = FindComputation(module.get(), "cond"); + HloInstruction* cond_param = cond_comp->parameter_instruction(0); + EXPECT_EQ(cond_param->shape().tuple_shapes_size(), 1); + EXPECT_TRUE(cond_param->shape().tuple_shapes()[0].IsToken()); +} + +TEST_F(InfeedTokenPropagationTest, WhileOutfeed) { + constexpr std::string_view hlo = R"( +HloModule main + +comp { + arg.0 = (s32[]) parameter(0) + token.0 = after-all() + outfeed.0 = token[] outfeed(arg.0, token.0), outfeed_shape=(s32[]) + gte.0 = get-tuple-element(arg.0), index=0 + ROOT tuple.0 = tuple(gte.0) +} + +cond { + arg.0 = (s32[]) parameter(0) + ROOT true.0 = pred[] constant(true) +} + +ENTRY main { + arg.0 = s32[] parameter(0) + while_tuple.0 = tuple(arg.0) + ROOT while.0 = (s32[]) while(while_tuple.0), condition=cond, body=comp +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo)); + InfeedTokenPropagation itp; + TF_ASSERT_OK_AND_ASSIGN(bool changed, itp.Run(module.get())); + EXPECT_TRUE(changed); + + // The outfeed output token should have propagated through the loop. + HloInstruction* loop = FindInstruction(module.get(), "while.0"); + EXPECT_EQ(loop->shape().tuple_shapes_size(), 2); + EXPECT_TRUE(loop->shape().tuple_shapes()[1].IsToken()); + + // The outfeed input token should have propagated through the loop tuple. + HloInstruction* loop_tuple = FindInstruction(module.get(), "while_tuple.0"); + EXPECT_EQ(loop_tuple->shape().tuple_shapes_size(), 2); + EXPECT_TRUE(loop_tuple->shape().tuple_shapes()[1].IsToken()); + + // The outfeed output token should have propagated through the while body + // root. + HloComputation* body_comp = FindComputation(module.get(), "comp"); + EXPECT_THAT(body_comp->root_instruction(), + op::Tuple(op::GetTupleElement(), op::Outfeed())); + + // The outfeed output token should have propagated through the body parameter. + HloInstruction* body_param = body_comp->parameter_instruction(0); + EXPECT_EQ(body_param->shape().tuple_shapes_size(), 2); + EXPECT_TRUE(body_param->shape().tuple_shapes()[1].IsToken()); + + // The outfeed output token should have propagated through the condition + // parameter. + HloComputation* cond_comp = FindComputation(module.get(), "cond"); + HloInstruction* cond_param = cond_comp->parameter_instruction(0); + EXPECT_EQ(cond_param->shape().tuple_shapes_size(), 2); + EXPECT_TRUE(cond_param->shape().tuple_shapes()[1].IsToken()); +} + +TEST_F(InfeedTokenPropagationTest, DisjointWhileOutfeed) { + constexpr std::string_view hlo = R"( +HloModule main + +comp { + ROOT arg.0 = () parameter(0) + one.0 = s32[] constant(1) + outfeed_tuple.0 = tuple(one.0) + token.0 = after-all() + outfeed.0 = token[] outfeed(outfeed_tuple.0, token.0), outfeed_shape=(s32[]) +} + +cond { + arg.0 = () parameter(0) + ROOT true.0 = pred[] constant(true) +} + +ENTRY main { + while_tuple.0 = tuple() + ROOT while.0 = () while(while_tuple.0), condition=cond, body=comp +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo)); + InfeedTokenPropagation itp; + TF_ASSERT_OK_AND_ASSIGN(bool changed, itp.Run(module.get())); + EXPECT_TRUE(changed); + + // The outfeed output token should have propagated through the loop. + HloInstruction* loop = FindInstruction(module.get(), "while.0"); + EXPECT_EQ(loop->shape().tuple_shapes_size(), 1); + EXPECT_TRUE(loop->shape().tuple_shapes()[0].IsToken()); + + // The outfeed input token should have propagated through the loop tuple. + HloInstruction* loop_tuple = FindInstruction(module.get(), "while_tuple.0"); + EXPECT_EQ(loop_tuple->shape().tuple_shapes_size(), 1); + EXPECT_TRUE(loop_tuple->shape().tuple_shapes()[0].IsToken()); + + // The outfeed output token should have propagated through the while body + // root. + HloComputation* body_comp = FindComputation(module.get(), "comp"); + EXPECT_THAT(body_comp->root_instruction(), op::Tuple(op::Outfeed())); + + // The outfeed output token should have propagated through the body parameter. + HloInstruction* body_param = body_comp->parameter_instruction(0); + EXPECT_EQ(body_param->shape().tuple_shapes_size(), 1); + EXPECT_TRUE(body_param->shape().tuple_shapes()[0].IsToken()); + + // The outfeed output token should have propagated through the condition + // parameter. + HloComputation* cond_comp = FindComputation(module.get(), "cond"); + HloInstruction* cond_param = cond_comp->parameter_instruction(0); + EXPECT_EQ(cond_param->shape().tuple_shapes_size(), 1); + EXPECT_TRUE(cond_param->shape().tuple_shapes()[0].IsToken()); +} + +TEST_F(InfeedTokenPropagationTest, NonTupleWhile) { + constexpr std::string_view hlo = R"( +HloModule main + +comp { + ROOT arg.0 = s32[] parameter(0) + tuple.0 = tuple(arg.0) + token.0 = after-all() + outfeed.0 = token[] outfeed(tuple.0, token.0), outfeed_shape=(s32[]) +} + +cond { + arg.0 = s32[] parameter(0) + ROOT true.0 = pred[] constant(true) +} + +ENTRY main { + arg.0 = s32[] parameter(0) + ROOT while.0 = s32[] while(arg.0), condition=cond, body=comp +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo)); + InfeedTokenPropagation itp; + TF_ASSERT_OK_AND_ASSIGN(bool changed, itp.Run(module.get())); + EXPECT_TRUE(changed); + + // The outfeed output token should have propagated through the loop. + HloInstruction* loop = FindInstruction(module.get(), "while.0"); + EXPECT_TRUE(loop->shape().IsTuple()); + EXPECT_EQ(loop->shape().tuple_shapes_size(), 2); + EXPECT_TRUE(loop->shape().tuple_shapes()[1].IsToken()); + + // The outfeed input token should have propagated through the loop tuple. + EXPECT_THAT(loop->operand(0), op::Tuple(op::Parameter(), op::AfterAll())); + + // The outfeed output token should have propagated through the while body + // root. + HloComputation* body_comp = FindComputation(module.get(), "comp"); + EXPECT_THAT(body_comp->root_instruction(), + op::Tuple(op::GetTupleElement(), op::Outfeed())); + + // The outfeed output token should have propagated through the body parameter. + HloInstruction* body_param = body_comp->parameter_instruction(0); + EXPECT_EQ(body_param->shape().tuple_shapes_size(), 2); + EXPECT_TRUE(body_param->shape().tuple_shapes()[1].IsToken()); + + // The outfeed output token should have propagated through the condition + // parameter. + HloComputation* cond_comp = FindComputation(module.get(), "cond"); + HloInstruction* cond_param = cond_comp->parameter_instruction(0); + EXPECT_EQ(cond_param->shape().tuple_shapes_size(), 2); + EXPECT_TRUE(cond_param->shape().tuple_shapes()[1].IsToken()); +} + +TEST_F(InfeedTokenPropagationTest, NestedInfeedOutfeed) { + constexpr std::string_view hlo = R"( +HloModule main + +true_comp { + arg.0 = (s32[]) parameter(0) + token.0 = after-all() + outfeed.0 = token[] outfeed(arg.0, token.0), outfeed_shape=(s32[]) + ROOT tuple.0 = tuple() +} + +false_comp { + arg.0 = () parameter(0) + ROOT tuple.0 = tuple() +} + +comp { + arg.0 = () parameter(0) + token.0 = after-all() + infeed.0 = (s32[], token[]) infeed(token.0) + gte.0 = get-tuple-element(infeed.0), index=0 + pred.0 = pred[] constant(true) + true_tuple.0 = tuple(gte.0) + false_tuple.0 = tuple() + ROOT cond.0 = () conditional(pred.0, true_tuple.0, false_tuple.0), true_computation=true_comp, false_computation=false_comp +} + +cond { + arg.0 = () parameter(0) + ROOT true.0 = pred[] constant(true) +} + +ENTRY main { + while_tuple.0 = tuple() + ROOT while.0 = () while(while_tuple.0), condition=cond, body=comp +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo)); + InfeedTokenPropagation itp; + TF_ASSERT_OK_AND_ASSIGN(bool changed, itp.Run(module.get())); + EXPECT_TRUE(changed); + + // The infeed and outfeed output tokens should have propagated through the + // loop. + HloInstruction* loop = FindInstruction(module.get(), "while.0"); + EXPECT_EQ(loop->shape().tuple_shapes_size(), 2); + EXPECT_TRUE(loop->shape().tuple_shapes()[0].IsToken()); + EXPECT_TRUE(loop->shape().tuple_shapes()[1].IsToken()); + + // The infeed and outfeed input tokens should have propagated through the loop + // tuple. + HloInstruction* loop_tuple = FindInstruction(module.get(), "while_tuple.0"); + EXPECT_EQ(loop_tuple->shape().tuple_shapes_size(), 2); + EXPECT_TRUE(loop_tuple->shape().tuple_shapes()[0].IsToken()); + EXPECT_TRUE(loop_tuple->shape().tuple_shapes()[1].IsToken()); + + // The infeed and outfeed output tokens should have propagated through the + // while body root. + HloComputation* body_comp = FindComputation(module.get(), "comp"); + EXPECT_THAT(body_comp->root_instruction(), + op::Tuple(op::GetTupleElement(op::Infeed(), 1), + op::GetTupleElement(op::Conditional(), 0))); + + // The outfeed output token should have propagated through the conditional. + HloInstruction* cond = FindInstruction(module.get(), "cond.0"); + EXPECT_EQ(cond->shape().tuple_shapes_size(), 1); + EXPECT_TRUE(cond->shape().tuple_shapes()[0].IsToken()); + + // The outfeed input token should have propagated through the true tuple. + HloInstruction* true_tuple = FindInstruction(module.get(), "true_tuple.0"); + EXPECT_EQ(true_tuple->shape().tuple_shapes_size(), 2); + EXPECT_TRUE(true_tuple->shape().tuple_shapes()[1].IsToken()); + + // The outfeed input token should not have propagated through the false tuple. + HloInstruction* false_tuple = FindInstruction(module.get(), "false_tuple.0"); + EXPECT_EQ(false_tuple->shape().tuple_shapes_size(), 0); + + // The outfeed output token should have propagated through the true + // computation's root. + HloComputation* true_comp = FindComputation(module.get(), "true_comp"); + EXPECT_THAT(true_comp->root_instruction(), op::Tuple(op::Outfeed())); + + // The outfeed output token should have propagated to the false computation's + // root. + HloComputation* false_comp = FindComputation(module.get(), "false_comp"); + EXPECT_THAT(false_comp->root_instruction(), op::Tuple(op::AfterAll())); +} +} // namespace +} // namespace xla diff --git a/third_party/xla/xla/service/while_loop_invariant_code_motion.cc b/third_party/xla/xla/service/while_loop_invariant_code_motion.cc index ed44547af3fca4..b1aae51df132e9 100644 --- a/third_party/xla/xla/service/while_loop_invariant_code_motion.cc +++ b/third_party/xla/xla/service/while_loop_invariant_code_motion.cc @@ -232,6 +232,7 @@ WhileLoopInvariantCodeMotion::TryHoistingInvariantInstructionsFromWhileBody( } if (instruction->HasSideEffect() || + instruction->opcode() == HloOpcode::kAfterAll || instruction->opcode() == HloOpcode::kParameter || !instruction->control_predecessors().empty() || !instruction->control_successors().empty()) { From d2933c4aed48277607c3e266e06b9f78db1970cf Mon Sep 17 00:00:00 2001 From: Quoc Truong Date: Tue, 24 Sep 2024 13:29:06 -0700 Subject: [PATCH 203/483] Add new Docker Container image now that we have hermetic CUDA and hermetic Python. The new Docker Container image is still a WIP and will be based on the existing one in tensorflow-sigs. The image size is reduced from 5.8GB to 1.8GB. PiperOrigin-RevId: 678375199 --- ci/official/containers/ml_build/Dockerfile | 44 ++++ ci/official/containers/ml_build/README.md | 8 + .../builder.devtoolset/build_devtoolset.sh | 198 ++++++++++++++++++ .../ml_build/builder.devtoolset/fixlinks.sh | 28 +++ .../builder.devtoolset/glibc2.17-inline.patch | 11 + .../ml_build/builder.devtoolset/rpm-patch.sh | 28 +++ .../containers/ml_build/builder.packages.txt | 31 +++ .../containers/ml_build/setup.packages.sh | 28 +++ .../containers/ml_build/setup.sources.sh | 35 ++++ 9 files changed, 411 insertions(+) create mode 100644 ci/official/containers/ml_build/Dockerfile create mode 100644 ci/official/containers/ml_build/README.md create mode 100755 ci/official/containers/ml_build/builder.devtoolset/build_devtoolset.sh create mode 100755 ci/official/containers/ml_build/builder.devtoolset/fixlinks.sh create mode 100644 ci/official/containers/ml_build/builder.devtoolset/glibc2.17-inline.patch create mode 100755 ci/official/containers/ml_build/builder.devtoolset/rpm-patch.sh create mode 100644 ci/official/containers/ml_build/builder.packages.txt create mode 100755 ci/official/containers/ml_build/setup.packages.sh create mode 100644 ci/official/containers/ml_build/setup.sources.sh diff --git a/ci/official/containers/ml_build/Dockerfile b/ci/official/containers/ml_build/Dockerfile new file mode 100644 index 00000000000000..60682de037a3ee --- /dev/null +++ b/ci/official/containers/ml_build/Dockerfile @@ -0,0 +1,44 @@ +################################################################################ +FROM ubuntu:22.04@sha256:58b87898e82351c6cf9cf5b9f3c20257bb9e2dcf33af051e12ce532d7f94e3fe AS devel +################################################################################ + +# Install devtoolset build dependencies +COPY setup.sources.sh /setup.sources.sh +COPY setup.packages.sh /setup.packages.sh +COPY builder.packages.txt /builder.packages.txt + +RUN /setup.sources.sh && /setup.packages.sh /builder.packages.txt + +# Install devtoolset-9 in /dt9 with glibc 2.17 and libstdc++ 4.8, for building +# manylinux2014-compatible packages. +COPY builder.devtoolset/fixlinks.sh /fixlinks.sh +COPY builder.devtoolset/rpm-patch.sh /rpm-patch.sh +COPY builder.devtoolset/build_devtoolset.sh /build_devtoolset.sh +COPY builder.devtoolset/glibc2.17-inline.patch /glibc2.17-inline.patch +RUN /build_devtoolset.sh devtoolset-9 /dt9 + +# Make sure clang is on the path +RUN ln -s /usr/lib/llvm-18/bin/clang /usr/bin/clang + +# Install various tools. +# - bats: bash unit testing framework +# - bazelisk: always use the correct bazel version +# - buildifier: clean bazel build deps +# - buildozer: clean bazel build deps +# - gcloud SDK: communicate with Google Cloud Platform (GCP) for RBE, CI +# - patchelf: Utility tool to modify existing ELF executables and libraries +RUN git clone --branch v1.11.0 https://github.com/bats-core/bats-core.git && bats-core/install.sh /usr/local && rm -rf bats-core +RUN wget https://github.com/bazelbuild/bazelisk/releases/download/v1.21.0/bazelisk-linux-amd64 -O /usr/local/bin/bazel && chmod +x /usr/local/bin/bazel +RUN wget https://github.com/bazelbuild/buildtools/releases/download/v7.3.1/buildifier-linux-amd64 -O /usr/local/bin/buildifier && chmod +x /usr/local/bin/buildifier +RUN wget https://github.com/bazelbuild/buildtools/releases/download/v7.3.1/buildozer-linux-amd64 -O /usr/local/bin/buildozer && chmod +x /usr/local/bin/buildozer +RUN curl -sSL https://sdk.cloud.google.com > /tmp/gcloud && bash /tmp/gcloud --install-dir=~/usr/local/bin --disable-prompts +# Download and install patchelf v0.18.0 from GitHub. The default Ubuntu focal +# packages only provide the "0.10-2build1" version. We use patchelf to manipulate +# certain shared libraries during the wheel building process (https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/pip_package/build_pip_package.sh#L255-L262). +# When we use Patchelf versions <0.12, those shared libraries end up with a +# corrupted PT_NOTE program header. This was fixed in v0.12, see https://github.com/NixOS/patchelf/commit/43a33482b501b0f5ee9da312aabfca3806570cc9. +RUN wget https://github.com/NixOS/patchelf/releases/download/0.18.0/patchelf-0.18.0-x86_64.tar.gz && tar -zxvf patchelf-0.18.0-x86_64.tar.gz -C /usr && rm -rf patchelf-0.18.0-x86_64.tar.gz + +# Don't use the bazel cache when a new docker image is created. +RUN echo build --action_env=DOCKER_CACHEBUSTER=$(date +%s%N)$RANDOM >> /etc/bazel.bazelrc +RUN echo build --host_action_env=DOCKER_HOST_CACHEBUSTER=$(date +%s%N)$RANDOM >> /etc/bazel.bazelrc diff --git a/ci/official/containers/ml_build/README.md b/ci/official/containers/ml_build/README.md new file mode 100644 index 00000000000000..53c01f529b300b --- /dev/null +++ b/ci/official/containers/ml_build/README.md @@ -0,0 +1,8 @@ +WIP ML Build Docker container for ML repositories (Tensorflow, JAX and XLA). + +This container branches off from +/tensorflow/tools/tf_sig_build_dockerfiles/. However, since +hermetic CUDA and hermetic Python is now available for Tensorflow, a lot of the +requirements installed on the original container can be removed to reduce the +footprint of the container and make it more reusable across different ML +repositories. diff --git a/ci/official/containers/ml_build/builder.devtoolset/build_devtoolset.sh b/ci/official/containers/ml_build/builder.devtoolset/build_devtoolset.sh new file mode 100755 index 00000000000000..b4c63677d7ae76 --- /dev/null +++ b/ci/official/containers/ml_build/builder.devtoolset/build_devtoolset.sh @@ -0,0 +1,198 @@ +#!/bin/bash -eu +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# +# Builds a devtoolset cross-compiler targeting manylinux 2010 (glibc 2.12 / +# libstdc++ 4.4) or manylinux2014 (glibc 2.17 / libstdc++ 4.8). + +VERSION="$1" +TARGET="$2" + +case "${VERSION}" in +devtoolset-7) + LIBSTDCXX_VERSION="6.0.24" + LIBSTDCXX_ABI="gcc4-compatible" + ;; +devtoolset-9) + LIBSTDCXX_VERSION="6.0.28" + LIBSTDCXX_ABI="new" + ;; +*) + echo "Usage: $0 {devtoolset-7|devtoolset-9} " + echo "Use 'devtoolset-7' to build a manylinux2010 compatible toolchain or 'devtoolset-9' to build a manylinux2014 compatible toolchain" + exit 1 + ;; +esac + +mkdir -p "${TARGET}" + +# Download glibc's shared and development libraries based on the value of the +# `VERSION` parameter. +# Note: 'Templatizing' this and the other conditional branches would require +# defining several variables (version, os, path) making it difficult to maintain +# and extend for future modifications. +case "${VERSION}" in +devtoolset-7) + # Download binary glibc 2.12 shared library release. + wget "http://old-releases.ubuntu.com/ubuntu/pool/main/e/eglibc/libc6_2.12.1-0ubuntu6_amd64.deb" && \ + unar "libc6_2.12.1-0ubuntu6_amd64.deb" && \ + tar -C "${TARGET}" -xvzf "libc6_2.12.1-0ubuntu6_amd64/data.tar.gz" && \ + rm -rf "libc6_2.12.1-0ubuntu6_amd64.deb" "libc6_2.12.1-0ubuntu6_amd64" + # Download binary glibc 2.12 development library release. + wget "http://old-releases.ubuntu.com/ubuntu/pool/main/e/eglibc/libc6-dev_2.12.1-0ubuntu6_amd64.deb" && \ + unar "libc6-dev_2.12.1-0ubuntu6_amd64.deb" && \ + tar -C "${TARGET}" -xvzf "libc6-dev_2.12.1-0ubuntu6_amd64/data.tar.gz" && \ + rm -rf "libc6-dev_2.12.1-0ubuntu6_amd64.deb" "libc6-dev_2.12.1-0ubuntu6_amd64" + ;; +devtoolset-9) + # Download binary glibc 2.17 shared library release. + wget "http://old-releases.ubuntu.com/ubuntu/pool/main/e/eglibc/libc6_2.17-0ubuntu5.1_amd64.deb" && \ + unar "libc6_2.17-0ubuntu5.1_amd64.deb" && \ + tar -C "${TARGET}" -xvzf "libc6_2.17-0ubuntu5.1_amd64/data.tar.gz" && \ + rm -rf "libc6_2.17-0ubuntu5.1_amd64.deb" "libc6_2.17-0ubuntu5.1_amd64" + # Download binary glibc 2.17 development library release. + wget "http://old-releases.ubuntu.com/ubuntu/pool/main/e/eglibc/libc6-dev_2.17-0ubuntu5.1_amd64.deb" && \ + unar "libc6-dev_2.17-0ubuntu5.1_amd64.deb" && \ + tar -C "${TARGET}" -xvzf "libc6-dev_2.17-0ubuntu5.1_amd64/data.tar.gz" && \ + rm -rf "libc6-dev_2.17-0ubuntu5.1_amd64.deb" "libc6-dev_2.17-0ubuntu5.1_amd64" + ;; +esac + +# Put the current kernel headers from ubuntu in place. +ln -s "/usr/include/linux" "/${TARGET}/usr/include/linux" +ln -s "/usr/include/asm-generic" "/${TARGET}/usr/include/asm-generic" +ln -s "/usr/include/x86_64-linux-gnu/asm" "/${TARGET}/usr/include/asm" + +# Symlinks in the binary distribution are set up for installation in /usr, we +# need to fix up all the links to stay within /${TARGET}. +/fixlinks.sh "/${TARGET}" + +# Patch to allow non-glibc 2.12 compatible builds to work. +sed -i '54i#define TCP_USER_TIMEOUT 18' "/${TARGET}/usr/include/netinet/tcp.h" + +# Download specific version of libstdc++ shared library based on the value of +# the `VERSION` parameter +case "${VERSION}" in +devtoolset-7) + # Download binary libstdc++ 4.4 release we are going to link against. + # We only need the shared library, as we're going to develop against the + # libstdc++ provided by devtoolset. + wget "http://old-releases.ubuntu.com/ubuntu/pool/main/g/gcc-4.4/libstdc++6_4.4.3-4ubuntu5_amd64.deb" && \ + unar "libstdc++6_4.4.3-4ubuntu5_amd64.deb" && \ + tar -C "/${TARGET}" -xvzf "libstdc++6_4.4.3-4ubuntu5_amd64/data.tar.gz" "./usr/lib/libstdc++.so.6.0.13" && \ + rm -rf "libstdc++6_4.4.3-4ubuntu5_amd64.deb" "libstdc++6_4.4.3-4ubuntu5_amd64" + ;; +devtoolset-9) + # Download binary libstdc++ 4.8 shared library release + wget "http://old-releases.ubuntu.com/ubuntu/pool/main/g/gcc-4.8/libstdc++6_4.8.1-10ubuntu8_amd64.deb" && \ + unar "libstdc++6_4.8.1-10ubuntu8_amd64.deb" && \ + tar -C "/${TARGET}" -xvzf "libstdc++6_4.8.1-10ubuntu8_amd64/data.tar.gz" "./usr/lib/x86_64-linux-gnu/libstdc++.so.6.0.18" && \ + rm -rf "libstdc++6_4.8.1-10ubuntu8_amd64.deb" "libstdc++6_4.8.1-10ubuntu8_amd64" + ;; +esac + +mkdir -p "${TARGET}-src" +cd "${TARGET}-src" + +# Build a devtoolset cross-compiler based on our glibc 2.12/glibc 2.17 sysroot setup. +case "${VERSION}" in +devtoolset-7) + wget "http://vault.centos.org/centos/6/sclo/Source/rh/devtoolset-7/devtoolset-7-gcc-7.3.1-5.15.el6.src.rpm" + rpm2cpio "devtoolset-7-gcc-7.3.1-5.15.el6.src.rpm" |cpio -idmv + tar -xvjf "gcc-7.3.1-20180303.tar.bz2" --strip 1 + ;; +devtoolset-9) + wget "https://vault.centos.org/centos/7/sclo/Source/rh/devtoolset-9-gcc-9.3.1-2.2.el7.src.rpm" + rpm2cpio "devtoolset-9-gcc-9.3.1-2.2.el7.src.rpm" |cpio -idmv + tar -xvf "gcc-9.3.1-20200408.tar.xz" --strip 1 + ;; +esac + +# Apply the devtoolset patches to gcc. +/rpm-patch.sh "gcc.spec" + +./contrib/download_prerequisites + +mkdir -p "${TARGET}-build" +cd "${TARGET}-build" + +"${TARGET}-src/configure" \ + --prefix=/"${TARGET}/usr" \ + --with-sysroot="/${TARGET}" \ + --disable-bootstrap \ + --disable-libmpx \ + --disable-libsanitizer \ + --disable-libunwind-exceptions \ + --disable-libunwind-exceptions \ + --disable-lto \ + --disable-multilib \ + --enable-__cxa_atexit \ + --enable-gnu-indirect-function \ + --enable-gnu-unique-object \ + --enable-initfini-array \ + --enable-languages="c,c++" \ + --enable-linker-build-id \ + --enable-plugin \ + --enable-shared \ + --enable-threads=posix \ + --with-default-libstdcxx-abi=${LIBSTDCXX_ABI} \ + --with-gcc-major-version-only \ + --with-linker-hash-style="gnu" \ + --with-tune="generic" \ + && \ + make -j 42 && \ + make install + + +# Create the devtoolset libstdc++ linkerscript that links dynamically against +# the system libstdc++ 4.4 and provides all other symbols statically. +case "${VERSION}" in +devtoolset-7) +mv "/${TARGET}/usr/lib/libstdc++.so.${LIBSTDCXX_VERSION}" \ + "/${TARGET}/usr/lib/libstdc++.so.${LIBSTDCXX_VERSION}.backup" +echo -e "OUTPUT_FORMAT(elf64-x86-64)\nINPUT ( libstdc++.so.6.0.13 -lstdc++_nonshared44 )" \ + > "/${TARGET}/usr/lib/libstdc++.so.${LIBSTDCXX_VERSION}" +cp "./x86_64-pc-linux-gnu/libstdc++-v3/src/.libs/libstdc++_nonshared44.a" \ + "/${TARGET}/usr/lib" + ;; +devtoolset-9) +# Note that the installation path for libstdc++ here is /${TARGET}/usr/lib64/ +mv "/${TARGET}/usr/lib64/libstdc++.so.${LIBSTDCXX_VERSION}" \ + "/${TARGET}/usr/lib64/libstdc++.so.${LIBSTDCXX_VERSION}.backup" +echo -e "OUTPUT_FORMAT(elf64-x86-64)\nINPUT ( libstdc++.so.6.0.18 -lstdc++_nonshared44 )" \ + > "/${TARGET}/usr/lib64/libstdc++.so.${LIBSTDCXX_VERSION}" +cp "./x86_64-pc-linux-gnu/libstdc++-v3/src/.libs/libstdc++_nonshared44.a" \ + "/${TARGET}/usr/lib64" +;; +esac + +# Link in architecture specific includes from the system; note that we cannot +# link in the whole x86_64-linux-gnu folder, as otherwise we're overlaying +# system gcc paths that we do not want to find. +# TODO(klimek): Automate linking in all non-gcc / non-kernel include +# directories. +mkdir -p "/${TARGET}/usr/include/x86_64-linux-gnu" +PYTHON_VERSIONS=("python3.9" "python3.10" "python3.11" "python3.12") +for v in "${PYTHON_VERSIONS[@]}"; do + ln -s "/usr/local/include/${v}" "/${TARGET}/usr/include/x86_64-linux-gnu/${v}" +done + +# Patch glibc to be compatable with modern clang +case "${VERSION}" in +devtoolset-9) + cd / + patch -p0 < /glibc2.17-inline.patch +;; +esac diff --git a/ci/official/containers/ml_build/builder.devtoolset/fixlinks.sh b/ci/official/containers/ml_build/builder.devtoolset/fixlinks.sh new file mode 100755 index 00000000000000..86856d80d9ceb1 --- /dev/null +++ b/ci/official/containers/ml_build/builder.devtoolset/fixlinks.sh @@ -0,0 +1,28 @@ +#!/bin/bash +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# +# Re-direct all links in $1 that point to /lib... to point to $1/lib... instead. + +BASE="$1" +find "${BASE}" -type l | \ + while read l ; do + if [[ "$(readlink "$l")" == /lib* ]]; then + ORIG="$(readlink "$l")"; + rm "$l"; + ln -s "${BASE}${ORIG}" "$l" + fi + done + diff --git a/ci/official/containers/ml_build/builder.devtoolset/glibc2.17-inline.patch b/ci/official/containers/ml_build/builder.devtoolset/glibc2.17-inline.patch new file mode 100644 index 00000000000000..db8c3423a38298 --- /dev/null +++ b/ci/official/containers/ml_build/builder.devtoolset/glibc2.17-inline.patch @@ -0,0 +1,11 @@ +--- /dt9/usr/include/x86_64-linux-gnu/sys/cdefs.h 2013-09-30 13:58:17.000000000 +0000 ++++ /dt9/usr/include/x86_64-linux-gnu/sys/cdefs.new.h 2022-11-04 17:17:31.727061220 +0000 +@@ -320,7 +320,7 @@ + + /* GCC 4.3 and above with -std=c99 or -std=gnu99 implements ISO C99 + inline semantics, unless -fgnu89-inline is used. */ +-#if (!defined __cplusplus || __GNUC_PREREQ (4,3)) && defined __GNUC__ ++#if (!defined __cplusplus || __GNUC_PREREQ (4,3) || defined __clang__) && defined __GNUC__ + # if defined __GNUC_STDC_INLINE__ || defined __cplusplus + # define __extern_inline extern __inline __attribute__ ((__gnu_inline__)) + # define __extern_always_inline \ \ No newline at end of file diff --git a/ci/official/containers/ml_build/builder.devtoolset/rpm-patch.sh b/ci/official/containers/ml_build/builder.devtoolset/rpm-patch.sh new file mode 100755 index 00000000000000..892ae2af86a3fa --- /dev/null +++ b/ci/official/containers/ml_build/builder.devtoolset/rpm-patch.sh @@ -0,0 +1,28 @@ +#!/bin/bash -eu +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# +# Given an RPM spec file $1, apply its patches. + +SPEC="$1" +grep '%patch' "${SPEC}" |while read cmd ; do + N=$(echo "${cmd}" |sed 's,%patch\([0-9]\+\).*,\1,') + file=$(grep "Patch$N:" "${SPEC}" |sed 's,.*: ,,') + parg=$(echo "${cmd}" |sed 's,.*\(-p[0-9]\).*,\1,') + if [[ ! "${file}" =~ doxygen && "${cmd}" != \#* ]]; then + echo "patch ${parg} -s < ${file}" + patch ${parg} -s < "${file}" + fi +done diff --git a/ci/official/containers/ml_build/builder.packages.txt b/ci/official/containers/ml_build/builder.packages.txt new file mode 100644 index 00000000000000..043ee6e1a7fa54 --- /dev/null +++ b/ci/official/containers/ml_build/builder.packages.txt @@ -0,0 +1,31 @@ +# Packages to be installed for the new Docker image. + +# Packages needed to build devtoolset +file +flex +g++ +make +patch +rpm2cpio +unar +wget +xz-utils +cpio + +# Other build-related tools +apt-transport-https +autoconf +automake +build-essential +ca-certificates +llvm-18 +clang-18 +clang-tidy-18 +lld-18 +clang-format-12 +curl +git +sudo +swig +unzip +zip diff --git a/ci/official/containers/ml_build/setup.packages.sh b/ci/official/containers/ml_build/setup.packages.sh new file mode 100755 index 00000000000000..f808cf7d22a7ce --- /dev/null +++ b/ci/official/containers/ml_build/setup.packages.sh @@ -0,0 +1,28 @@ +#!/usr/bin/env bash +# +# Copyright 2022 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# +# setup.packages.sh: Given a list of Ubuntu packages, install them and clean up. +# Usage: setup.packages.sh +set -e + +# Prevent apt install tzinfo from asking our location (assumes UTC) +export DEBIAN_FRONTEND=noninteractive + +apt-get update +# Remove commented lines and blank lines +apt-get install -y --no-install-recommends $(sed -e '/^\s*#.*$/d' -e '/^\s*$/d' "$1" | sort -u) +rm -rf /var/lib/apt/lists/* diff --git a/ci/official/containers/ml_build/setup.sources.sh b/ci/official/containers/ml_build/setup.sources.sh new file mode 100644 index 00000000000000..2039a4697927ab --- /dev/null +++ b/ci/official/containers/ml_build/setup.sources.sh @@ -0,0 +1,35 @@ +#!/usr/bin/env bash +# +# Copyright 2024 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# +# Sets up custom apt sources for our TF images. + +# Prevent apt install tzinfo from asking our location (assumes UTC) +export DEBIAN_FRONTEND=noninteractive + +# Set up shared custom sources +apt-get update +apt-get install -y gnupg ca-certificates + +# LLVM/Clang: https://apt.llvm.org/ +apt-key adv --fetch-keys https://apt.llvm.org/llvm-snapshot.gpg.key + +# Set up custom sources +cat >/etc/apt/sources.list.d/custom.list < Date: Tue, 24 Sep 2024 13:44:41 -0700 Subject: [PATCH 204/483] Add `-stablehlo-create-compatibility-expander` pass to `AddPreQuantizationStableHloToTfPasses` with `tflite_supported_stablehlo_version`. PiperOrigin-RevId: 678381401 --- tensorflow/compiler/mlir/lite/BUILD | 1 + .../batched_gather_round_trip.mlir | 32 +++++++++++++++++ .../batched_scatter_round_trip.mlir | 34 ++++++++++++++++++ .../compiler/mlir/lite/tf_tfl_passes.cc | 8 +++++ third_party/xla/xla/mlir_hlo/BUILD | 1 + .../stablehlo_ext/transforms/passes.h | 8 +++++ ...tablehlo_create_compatibility_expander.cpp | 35 +++++++++++++++++++ 7 files changed, 119 insertions(+) create mode 100644 tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/batched_gather_round_trip.mlir create mode 100644 tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/batched_scatter_round_trip.mlir create mode 100644 third_party/xla/xla/mlir_hlo/stablehlo_ext/transforms/stablehlo_create_compatibility_expander.cpp diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD index 56c88a001d0f86..3da89f496218a3 100644 --- a/tensorflow/compiler/mlir/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/BUILD @@ -1593,6 +1593,7 @@ cc_library( ":tensorflow_lite_push_transpose_through_ewise", # buildcleaner: keep ":tensorflow_lite_quantize", # buildcleaner: keep ":tensorflow_lite_tf_unfreeze_global_tensors", + "//tensorflow/compiler/mlir/lite/core:macros", "//tensorflow/compiler/mlir/lite/quantization:quantization_passes", "//tensorflow/compiler/mlir/lite/quantization/tensorflow:tf_quantization_passes", "//tensorflow/compiler/mlir/lite/stablehlo:build_stablehlo_composite", diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/batched_gather_round_trip.mlir b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/batched_gather_round_trip.mlir new file mode 100644 index 00000000000000..12de9da5939573 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/batched_gather_round_trip.mlir @@ -0,0 +1,32 @@ +// RUN: tf_tfl_translate --enable-hlo-to-tf-conversion --input-mlir %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck %s --check-prefix=CHECK-ROUNDTRIP + + +module { + // CHECK-LABEL: func.func public @main + func.func public @main(%arg0: tensor<3x2x4x7x9xi32>, %arg1: tensor<4x3x5x2xi32>) -> tensor<4x3x5x8xi32> { + // CHECK-ROUNDTRIP: %[[iota_1:.*]] = "tfl.pseudo_const"() <{{.*}}> : () -> tensor<4x3x5x1xi32 + // CHECK-ROUNDTRIP: %[[iota_2:.*]] = "tfl.pseudo_const"() <{{.*}}> : () -> tensor<4x3x5x1xi32> + // CHECK-ROUNDTRIP: %[[concat:.*]] = "tfl.concatenation"(%[[iota_1]], %[[iota_2]], %arg1) <{axis = 3 : i32, fused_activation_function = "NONE"}> : + // CHECK-ROUNDTRIP-SAME: (tensor<4x3x5x1xi32>, tensor<4x3x5x1xi32>, tensor<4x3x5x2xi32>) -> tensor<4x3x5x4xi32> + // CHECK-ROUNDTRIP: %[[gather:.*]] = "stablehlo.gather"(%arg0, %2) <{ + // CHECK-ROUNDTRIP-SAME: dimension_numbers = #stablehlo.gather< + // CHECK-ROUNDTRIP-SAME: offset_dims = [3], collapsed_slice_dims = [0, 1, 2, 3], + // CHECK-ROUNDTRIP-SAME: start_index_map = [0, 2, 1, 3], index_vector_dim = 3>, + // CHECK-ROUNDTRIP-SAME: slice_sizes = array}> : + // CHECK-ROUNDTRIP-SAME: (tensor<3x2x4x7x9xi32>, tensor<4x3x5x4xi32>) -> tensor<4x3x5x8xi32> + // CHECK-ROUNDTRIP: return %[[gather]] + %0 = "stablehlo.gather"(%arg0, %arg1) { + dimension_numbers = #stablehlo.gather< + offset_dims = [3], + collapsed_slice_dims = [1, 3], + operand_batching_dims = [0, 2], + start_indices_batching_dims = [1, 0], + start_index_map = [1, 3], + index_vector_dim = 3 + >, + slice_sizes = array, + indices_are_sorted = false + } : (tensor<3x2x4x7x9xi32>, tensor<4x3x5x2xi32>) -> tensor<4x3x5x8xi32> + return %0 : tensor<4x3x5x8xi32> + } +} \ No newline at end of file diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/batched_scatter_round_trip.mlir b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/batched_scatter_round_trip.mlir new file mode 100644 index 00000000000000..44d1bb7dd8b72f --- /dev/null +++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/batched_scatter_round_trip.mlir @@ -0,0 +1,34 @@ +// RUN: tf_tfl_translate --enable-hlo-to-tf-conversion --input-mlir %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck %s --check-prefix=CHECK-ROUNDTRIP + + +module { + // CHECK-LABEL: func.func public @main + func.func public @main(%arg0: tensor<3x2x4x7x9xi32>, %arg1: tensor<4x3x5x2xi32>, %arg2: tensor<4x3x5x8xi32>) -> tensor<3x2x4x7x9xi32> { + // CHECK-ROUNDTRIP: %[[iota_1:.*]] = "tfl.pseudo_const"() <{{.*}}> : () -> tensor<4x3x5x1xi32 + // CHECK-ROUNDTRIP: %[[iota_2:.*]] = "tfl.pseudo_const"() <{{.*}}> : () -> tensor<4x3x5x1xi32> + // CHECK-ROUNDTRIP: %[[concat:.*]] = "tfl.concatenation"(%[[iota_1]], %[[iota_2]], %arg1) <{axis = 3 : i32, fused_activation_function = "NONE"}> : + // CHECK-ROUNDTRIP-SAME: (tensor<4x3x5x1xi32>, tensor<4x3x5x1xi32>, tensor<4x3x5x2xi32>) -> tensor<4x3x5x4xi32> + // CHECK-ROUNDTRIP: %[[scatter:.*]] = "stablehlo.scatter"(%arg0, %2, %arg2) <{ + // CHECK-ROUNDTRIP-SAME: scatter_dimension_numbers = #stablehlo.scatter + // CHECK-ROUNDTRIP-SAME: update_window_dims = [3], inserted_window_dims = [0, 1, 2, 3], + // CHECK-ROUNDTRIP-SAME: scatter_dims_to_operand_dims = [0, 2, 1, 3], index_vector_dim = 3>}> + // CHECK-ROUNDTRIP: (tensor<3x2x4x7x9xi32>, tensor<4x3x5x4xi32>, tensor<4x3x5x8xi32>) -> tensor<3x2x4x7x9xi32> + // CHECK-ROUNDTRIP: return %[[scatter]] + %0 = "stablehlo.scatter"(%arg0, %arg1, %arg2) <{ + indices_are_sorted = false, + scatter_dimension_numbers = #stablehlo.scatter< + update_window_dims = [3], + inserted_window_dims = [1, 3], + input_batching_dims = [0, 2], + scatter_indices_batching_dims = [1, 0], + scatter_dims_to_operand_dims = [1, 3], + index_vector_dim = 3 + >, + unique_indices = false + }> ({ + ^bb0(%arg3: tensor, %arg4: tensor): + stablehlo.return %arg4 : tensor + }) : (tensor<3x2x4x7x9xi32>, tensor<4x3x5x2xi32>, tensor<4x3x5x8xi32>) -> tensor<3x2x4x7x9xi32> + return %0 : tensor<3x2x4x7x9xi32> + } +} \ No newline at end of file diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc b/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc index d415c9a63d8473..c5a4a766bbd23a 100644 --- a/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc +++ b/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc @@ -28,6 +28,7 @@ limitations under the License. #include "mlir/Transforms/Passes.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h" #include "tensorflow/compiler/mlir/lite/converter_flags.pb.h" +#include "tensorflow/compiler/mlir/lite/core/macros.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_passes.h" #include "tensorflow/compiler/mlir/lite/quantization/tensorflow/passes.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf_xla_call_module_to_stablehlo_pass.h" @@ -194,6 +195,13 @@ void AddPreQuantizationStableHloToTfPasses( // to be consistent with other entrypoints. pass_manager.addPass(mlir::mhlo::createHloLegalizeToStablehloPass()); + // Expand backward compatibility with the given StableHLO version by + // decomposing newer StableHLO operations into equivalent operations supported + // by that older version. + pass_manager.addNestedPass( + mlir::stablehlo_ext::createStablehloCreateCompatibilityExpanderPass( + tflite_supported_stablehlo_version)); + // Decompose CHLO into StableHLO ops pass_manager.addNestedPass( mlir::odml::CreateLegalizeChloToTflPass()); diff --git a/third_party/xla/xla/mlir_hlo/BUILD b/third_party/xla/xla/mlir_hlo/BUILD index 11764a0594c9f4..2602068d4b5d22 100644 --- a/third_party/xla/xla/mlir_hlo/BUILD +++ b/third_party/xla/xla/mlir_hlo/BUILD @@ -1130,6 +1130,7 @@ cc_library( srcs = [ "stablehlo_ext/transforms/chlo_recompose_ops.cpp", "stablehlo_ext/transforms/stablehlo_canonicalize_dynamism.cpp", + "stablehlo_ext/transforms/stablehlo_create_compatibility_expander.cpp", "stablehlo_ext/transforms/stablehlo_refine_shapes.cpp", ], hdrs = [ diff --git a/third_party/xla/xla/mlir_hlo/stablehlo_ext/transforms/passes.h b/third_party/xla/xla/mlir_hlo/stablehlo_ext/transforms/passes.h index 5c54390de66f4f..c72a92f112b23d 100644 --- a/third_party/xla/xla/mlir_hlo/stablehlo_ext/transforms/passes.h +++ b/third_party/xla/xla/mlir_hlo/stablehlo_ext/transforms/passes.h @@ -17,8 +17,10 @@ limitations under the License. #define STABLEHLO_EXT_TRANSFORMS_PASSES_H #include +#include #include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassOptions.h" #include "mlir/Transforms/DialectConversion.h" namespace mlir { @@ -30,6 +32,12 @@ namespace stablehlo_ext { void createChloLegalizeToStablehloPipeline(OpPassManager &pm); +// Expand backward compatibility with the given StableHLO version by decomposing +// newer StableHLO operations into equivalent operations supported by that older +// version. +std::unique_ptr createStablehloCreateCompatibilityExpanderPass( + std::string targetVersionOption); + } // namespace stablehlo_ext } // namespace mlir diff --git a/third_party/xla/xla/mlir_hlo/stablehlo_ext/transforms/stablehlo_create_compatibility_expander.cpp b/third_party/xla/xla/mlir_hlo/stablehlo_ext/transforms/stablehlo_create_compatibility_expander.cpp new file mode 100644 index 00000000000000..0db5fd4780b67d --- /dev/null +++ b/third_party/xla/xla/mlir_hlo/stablehlo_ext/transforms/stablehlo_create_compatibility_expander.cpp @@ -0,0 +1,35 @@ +/* Copyright 2024 The StableHLO Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include + +#include "mlir/Pass/Pass.h" +#include "stablehlo/transforms/Passes.h" +#include "stablehlo_ext/transforms/passes.h" + +namespace mlir { +namespace stablehlo_ext { + +// TODO(b/369406385): remove this method (and file) once issue is resolved. + +std::unique_ptr<::mlir::Pass> createStablehloCreateCompatibilityExpanderPass( + std::string targetVersionOption) { + return mlir::stablehlo::createStablehloCreateCompatibilityExpanderPass( + {std::move(targetVersionOption)}); +} + +} // namespace stablehlo_ext +} // namespace mlir From 7402f35fa7310d7300f620d7bad42ec83c67def2 Mon Sep 17 00:00:00 2001 From: Zixuan Jiang Date: Tue, 24 Sep 2024 14:08:56 -0700 Subject: [PATCH 205/483] [xla:SpmdPartitioner] Support partitioning along the explicit batch dimensions in scatter instructions. This cl is similar to cl/677049477. Explicit batch dimensions are added for scatter instructions in https://github.com/openxla/stablehlo/pull/2084. This cl allows us partitioning along these explicit batch dimensions. Before this cl, we already have `PartitionScatterIndexParallelDimensions`, where the index parallel dimensions are implicit batch dimensions. We reuse most of the code in this function and implement `PartitionScatterExplicitBatchDimensions`. PiperOrigin-RevId: 678390590 --- .../xla/xla/hlo/utils/hlo_sharding_util.cc | 3 +- .../service/spmd/gather_scatter_handler.cc | 344 ++++++++++-------- .../xla/service/spmd/spmd_partitioner_test.cc | 94 +++++ 3 files changed, 296 insertions(+), 145 deletions(-) diff --git a/third_party/xla/xla/hlo/utils/hlo_sharding_util.cc b/third_party/xla/xla/hlo/utils/hlo_sharding_util.cc index aa7a485cbc685d..6d01b8fe6424fa 100644 --- a/third_party/xla/xla/hlo/utils/hlo_sharding_util.cc +++ b/third_party/xla/xla/hlo/utils/hlo_sharding_util.cc @@ -1667,7 +1667,8 @@ std::vector GetScatterSliceSize(const Shape& operand_shape, std::vector slice_size(operand_shape.rank(), 1); int64_t num_update_window_dims = 0; for (int64_t i = 0; i < operand_shape.rank(); ++i) { - if (absl::c_linear_search(dnums.inserted_window_dims(), i)) { + if (absl::c_linear_search(dnums.inserted_window_dims(), i) || + absl::c_linear_search(dnums.input_batching_dims(), i)) { continue; } slice_size[i] = update_shape.dimensions( diff --git a/third_party/xla/xla/service/spmd/gather_scatter_handler.cc b/third_party/xla/xla/service/spmd/gather_scatter_handler.cc index 7bac1768aebf30..4b9a700fe927ad 100644 --- a/third_party/xla/xla/service/spmd/gather_scatter_handler.cc +++ b/third_party/xla/xla/service/spmd/gather_scatter_handler.cc @@ -580,7 +580,7 @@ absl::StatusOr PartitionGatherTrivialSlicedOperandDimensions( return nullptr; } -absl::StatusOr PartitionGatherBatchDimensions( +absl::StatusOr PartitionGatherParallelDimensions( const HloGatherInstruction* gather, PartitionedHlo operand, PartitionedHlo indices, const Shape& output_shape, const HloSharding& output_sharding, absl::Span batch_dims, @@ -588,6 +588,12 @@ absl::StatusOr PartitionGatherBatchDimensions( bool allow_recursive, const hlo_sharding_util::GatherScatterParallelDims& parallel_dims, bool need_offset) { + auto gather_sharding = GatherScatterOperandsShardedAcrossParallelDims( + *operand.hlo(), *indices.hlo(), parallel_dims); + if (!gather_sharding.has_value()) { + return nullptr; + } + // Perform clean up actions upon exiting function scope. absl::InlinedVector, 3> clean_ups; absl::Cleanup cleaner = [&clean_ups] { @@ -596,12 +602,6 @@ absl::StatusOr PartitionGatherBatchDimensions( } }; - auto gather_sharding = GatherScatterOperandsShardedAcrossParallelDims( - *operand.hlo(), *indices.hlo(), parallel_dims); - if (!gather_sharding.has_value()) { - return nullptr; - } - SpmdBuilder* b = visitor->builder(); const GatherDimensionNumbers& dnums = gather->gather_dimension_numbers(); const int64_t index_dim = dnums.index_vector_dim(); @@ -735,10 +735,10 @@ absl::StatusOr PartitionGatherIndexParallelDimensions( if (!parallel_dims.has_value()) { return nullptr; } - return PartitionGatherBatchDimensions(gather, operand, indices, output_shape, - output_sharding, batch_dims, - slice_sizes, visitor, allow_recursive, - *parallel_dims, /*need_offset=*/true); + return PartitionGatherParallelDimensions( + gather, operand, indices, output_shape, output_sharding, batch_dims, + slice_sizes, visitor, allow_recursive, *parallel_dims, + /*need_offset=*/true); } // Partition a gather over explicit batch dimensions defined in @@ -762,10 +762,10 @@ absl::StatusOr PartitionGatherExplicitBatchDimensions( dnums.start_indices_batching_dims().begin(), dnums.start_indices_batching_dims().end()); - return PartitionGatherBatchDimensions(gather, operand, indices, output_shape, - output_sharding, batch_dims, - slice_sizes, visitor, allow_recursive, - parallel_dims, /*need_offset=*/false); + return PartitionGatherParallelDimensions( + gather, operand, indices, output_shape, output_sharding, batch_dims, + slice_sizes, visitor, allow_recursive, parallel_dims, + /*need_offset=*/false); } // Returns a full list of partitioning methods used for gather. @@ -1078,16 +1078,20 @@ absl::StatusOr PartitionScatter( absl::Span slice_sizes, SpmdPartitioningVisitor* visitor, bool allow_recursive = true); -// Partition a scatter over a indices dimensions that are cosidered parallel -// (which means that the indices access the operand in a monotonically -// increasing way across the respective operand dimension referenced by the -// index). -absl::StatusOr PartitionScatterIndexParallelDimensions( +absl::StatusOr PartitionScatterParallelDimensions( const HloScatterInstruction* scatter, std::vector operands, PartitionedHlo indices, std::vector updates, const Shape& output_shape, const HloSharding& output_sharding, absl::Span slice_sizes, SpmdPartitioningVisitor* visitor, - bool allow_recursive) { + bool allow_recursive, + const hlo_sharding_util::GatherScatterParallelDims& parallel_dims, + bool need_offset) { + auto scatter_sharding = GatherScatterOperandsShardedAcrossParallelDims( + *operands[0].hlo(), *indices.hlo(), parallel_dims); + if (!scatter_sharding) { + return nullptr; + } + // Perform clean up actions upon exiting function scope. absl::InlinedVector, 3> clean_ups; absl::Cleanup cleaner = [&clean_ups] { @@ -1099,134 +1103,182 @@ absl::StatusOr PartitionScatterIndexParallelDimensions( SpmdBuilder* b = visitor->builder(); const auto& dnums = scatter->scatter_dimension_numbers(); const int64_t index_dim = dnums.index_vector_dim(); - // Handle the case where operand is tile maximal. In this case we check if - // the index is not TileMaximal and in this case we use the index sharding - // to drive the output sharding. - if (std::optional - parallel_dims = hlo_sharding_util::GetScatterParallelBatchDims( - *scatter, visitor->call_graph())) { - if (auto scatter_sharding = GatherScatterOperandsShardedAcrossParallelDims( - *operands[0].hlo(), *indices.hlo(), *parallel_dims)) { - const auto operand_parallel_dims = parallel_dims->operand_parallel_dims; - const auto indices_parallel_dims = parallel_dims->indices_parallel_dims; - const auto update_parallel_dims = - hlo_sharding_util::GetScatterParallelUpdateDims(*scatter, - *parallel_dims); - for (auto& operand : operands) { - operand = operand.Reshard(scatter_sharding->operand_sharding); - } - indices = indices.Reshard(scatter_sharding->indices_sharding); - HloSharding update_sharding = hlo_sharding_util:: - GatherOutputOrScatterUpdateShardingFromIndicesParallelDimensions( - indices.sharding(), updates[0].rank(), indices_parallel_dims, - update_parallel_dims); - // Refine update sharding from the operand. it should be inferred from - // operand sharding, so that the partitioned scatter can be either 1) - // directly created on the partitioned operand, or 2) recursively created - // without aligning the groups. - if (auto maybe_passthrough = hlo_sharding_util:: - ScatterUpdateShardingFromOutputOperandPassthroughDimensions( - operands[0].base_shape(), - hlo_sharding_util::PartiallyReplicateTiledShardingOnDims( - operands[0].sharding(), operand_parallel_dims), - *scatter, slice_sizes)) { - hlo_sharding_util::MergeShardingIfCompatible(*maybe_passthrough, - &update_sharding); - } - for (auto& update : updates) { - update = update.Reshard(update_sharding); - } + const auto operand_parallel_dims = parallel_dims.operand_parallel_dims; + const auto indices_parallel_dims = parallel_dims.indices_parallel_dims; + const auto update_parallel_dims = + hlo_sharding_util::GetScatterParallelUpdateDims(*scatter, parallel_dims); + for (auto& operand : operands) { + operand = operand.Reshard(scatter_sharding->operand_sharding); + } + indices = indices.Reshard(scatter_sharding->indices_sharding); + HloSharding update_sharding = hlo_sharding_util:: + GatherOutputOrScatterUpdateShardingFromIndicesParallelDimensions( + indices.sharding(), updates[0].rank(), indices_parallel_dims, + update_parallel_dims); + if (!need_offset) { + hlo_sharding_util::MergeShardingIfCompatible( + hlo_sharding_util:: + ScatterUpdateShardingFromIndexIndexPassthroughDimensions( + indices.sharding(), scatter), + &update_sharding); + } - // Construct the offsets for the operand sharding to be used to adjust - // the indices. Because we know the only dimensions partitioned are the - // parallel ones and because the partitioning is the same across indices - // and operands we can apply the offsets on the operands on the indices. - std::vector operand_offsets = MakePartitionOffsets( - operands[0].base_shape(), operands[0].sharding(), - operands[0].state().partition_id, b, operand_parallel_dims); - absl::InlinedVector index_offsets; - for (int start_idx = 0; - start_idx < dnums.scatter_dims_to_operand_dims_size(); ++start_idx) { - HloInstruction* index_offset = - indices.base_shape().dimensions_size() > index_dim - ? b->AddInstruction(HloInstruction::CreateReshape( - ShapeUtil::MakeShape(S32, {1}), - operand_offsets[dnums.scatter_dims_to_operand_dims( - start_idx)])) - : operand_offsets[dnums.scatter_dims_to_operand_dims( - start_idx)]; - index_offsets.push_back(index_offset); - } - HloInstruction* adjusted_indices = nullptr; - if (indices.base_shape().dimensions_size() > index_dim) { - // Concatenate the offsets for the parallel dimensions to subtract. - adjusted_indices = b->AddInstruction(HloInstruction::CreateConcatenate( - ShapeUtil::MakeShape(S32, - {indices.base_shape().dimensions(index_dim)}), - index_offsets, 0)); - } else { - CHECK_EQ(index_offsets.size(), 1); - adjusted_indices = index_offsets[0]; - } - if (indices.hlo()->shape().element_type() != PrimitiveType::S32) { - adjusted_indices = b->AddInstruction(HloInstruction::CreateConvert( - ShapeUtil::ChangeElementType(adjusted_indices->shape(), - indices.hlo()->shape().element_type()), - adjusted_indices)); - } - if (adjusted_indices->shape().rank() == 0) { - adjusted_indices = b->AddInstruction(HloInstruction::CreateBroadcast( - indices.hlo()->shape(), adjusted_indices, {})); - } else { - adjusted_indices = b->AddInstruction(HloInstruction::CreateBroadcast( - indices.hlo()->shape(), adjusted_indices, {index_dim})); - } - // Adjust indices by subtracting the offsets based on the partition id. - adjusted_indices = b->AddInstruction(HloInstruction::CreateBinary( - indices.hlo()->shape(), HloOpcode::kSubtract, indices.hlo(), + // Refine update sharding from the operand. it should be inferred from + // operand sharding, so that the partitioned scatter can be either 1) + // directly created on the partitioned operand, or 2) recursively created + // without aligning the groups. + if (auto maybe_passthrough = hlo_sharding_util:: + ScatterUpdateShardingFromOutputOperandPassthroughDimensions( + operands[0].base_shape(), + hlo_sharding_util::PartiallyReplicateTiledShardingOnDims( + operands[0].sharding(), operand_parallel_dims), + *scatter, slice_sizes)) { + hlo_sharding_util::MergeShardingIfCompatible(*maybe_passthrough, + &update_sharding); + } + + for (auto& update : updates) { + update = update.Reshard(update_sharding); + } + + // Construct the offsets for the operand sharding to be used to adjust + // the indices. Because we know the only dimensions partitioned are the + // parallel ones and because the partitioning is the same across indices + // and operands we can apply the offsets on the operands on the indices. + PartitionedHlo new_indices = indices; + if (need_offset) { + std::vector operand_offsets = MakePartitionOffsets( + operands[0].base_shape(), operands[0].sharding(), + operands[0].state().partition_id, b, operand_parallel_dims); + absl::InlinedVector index_offsets; + for (int start_idx = 0; + start_idx < dnums.scatter_dims_to_operand_dims_size(); ++start_idx) { + HloInstruction* index_offset = + indices.base_shape().dimensions_size() > index_dim + ? b->AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(S32, {1}), + operand_offsets[dnums.scatter_dims_to_operand_dims( + start_idx)])) + : operand_offsets[dnums.scatter_dims_to_operand_dims(start_idx)]; + index_offsets.push_back(index_offset); + } + HloInstruction* adjusted_indices = nullptr; + if (indices.base_shape().dimensions_size() > index_dim) { + // Concatenate the offsets for the parallel dimensions to subtract. + adjusted_indices = b->AddInstruction(HloInstruction::CreateConcatenate( + ShapeUtil::MakeShape(S32, + {indices.base_shape().dimensions(index_dim)}), + index_offsets, 0)); + } else { + CHECK_EQ(index_offsets.size(), 1); + adjusted_indices = index_offsets[0]; + } + if (indices.hlo()->shape().element_type() != PrimitiveType::S32) { + adjusted_indices = b->AddInstruction(HloInstruction::CreateConvert( + ShapeUtil::ChangeElementType(adjusted_indices->shape(), + indices.hlo()->shape().element_type()), adjusted_indices)); - PartitionedHlo new_indices = indices.CloneWithNewHlo(adjusted_indices); - const GroupedSharding new_indices_grouped = - hlo_sharding_util::GroupShardingOnDims(new_indices.sharding(), - indices_parallel_dims); - const GroupedSharding operand_grouped = - AlignGroupsWith(hlo_sharding_util::GroupShardingOnDims( - operands[0].sharding(), operand_parallel_dims), - new_indices_grouped); - const GroupedSharding update_grouped = - AlignGroupsWith(hlo_sharding_util::GroupShardingOnDims( - updates[0].sharding(), update_parallel_dims), - new_indices_grouped); - const GroupedSharding& output_grouped = operand_grouped; - std::vector per_group_operands = - PerGroupPartitionedHlos(operands, operand_grouped, b, clean_ups); - std::vector per_group_updates = - PerGroupPartitionedHlos(updates, update_grouped, b, clean_ups); - PartitionedHlo per_group_new_indices = PerGroupPartitionedHlo( - new_indices, new_indices_grouped, b, clean_ups); - auto pshape = - MaybeGetTuplePerGroupBaseShape(output_grouped, output_shape); - TF_ASSIGN_OR_RETURN( - HloInstruction * pscatter, - PartitionScatter( - scatter, per_group_operands, per_group_new_indices, - per_group_updates, pshape, - HloSharding::Single(scatter->shape(), output_grouped.sharding), - slice_sizes, visitor, allow_recursive)); - pscatter->set_sharding(HloSharding::Single( - pscatter->shape(), - hlo_sharding_util::UngroupSharding(output_grouped))); - if (allow_recursive) { - VLOG(5) << "[Scatter partitioning]: Partitioned as index parallel"; - } - return PartitionedHlo(pscatter, output_shape, operands[0].state()) - .Reshard(output_sharding) - .hlo(); } + if (adjusted_indices->shape().rank() == 0) { + adjusted_indices = b->AddInstruction(HloInstruction::CreateBroadcast( + indices.hlo()->shape(), adjusted_indices, {})); + } else { + adjusted_indices = b->AddInstruction(HloInstruction::CreateBroadcast( + indices.hlo()->shape(), adjusted_indices, {index_dim})); + } + // Adjust indices by subtracting the offsets based on the partition id. + adjusted_indices = b->AddInstruction(HloInstruction::CreateBinary( + indices.hlo()->shape(), HloOpcode::kSubtract, indices.hlo(), + adjusted_indices)); + new_indices = indices.CloneWithNewHlo(adjusted_indices); } - return nullptr; + + const GroupedSharding new_indices_grouped = + hlo_sharding_util::GroupShardingOnDims(new_indices.sharding(), + indices_parallel_dims); + const GroupedSharding operand_grouped = + AlignGroupsWith(hlo_sharding_util::GroupShardingOnDims( + operands[0].sharding(), operand_parallel_dims), + new_indices_grouped); + const GroupedSharding update_grouped = + AlignGroupsWith(hlo_sharding_util::GroupShardingOnDims( + updates[0].sharding(), update_parallel_dims), + new_indices_grouped); + const GroupedSharding& output_grouped = operand_grouped; + std::vector per_group_operands = + PerGroupPartitionedHlos(operands, operand_grouped, b, clean_ups); + std::vector per_group_updates = + PerGroupPartitionedHlos(updates, update_grouped, b, clean_ups); + PartitionedHlo per_group_new_indices = + PerGroupPartitionedHlo(new_indices, new_indices_grouped, b, clean_ups); + auto pshape = MaybeGetTuplePerGroupBaseShape(output_grouped, output_shape); + TF_ASSIGN_OR_RETURN( + HloInstruction * pscatter, + PartitionScatter( + scatter, per_group_operands, per_group_new_indices, per_group_updates, + pshape, + HloSharding::Single(scatter->shape(), output_grouped.sharding), + slice_sizes, visitor, allow_recursive)); + pscatter->set_sharding(HloSharding::Single( + pscatter->shape(), hlo_sharding_util::UngroupSharding(output_grouped))); + if (allow_recursive) { + VLOG(5) << "[Scatter partitioning]: Partitioned as index parallel"; + } + return PartitionedHlo(pscatter, output_shape, operands[0].state()) + .Reshard(output_sharding) + .hlo(); } + +// Partition a scatter over a indices dimensions that are cosidered parallel +// (which means that the indices access the operand in a monotonically +// increasing way across the respective operand dimension referenced by the +// index). +absl::StatusOr PartitionScatterIndexParallelDimensions( + const HloScatterInstruction* scatter, std::vector operands, + PartitionedHlo indices, std::vector updates, + const Shape& output_shape, const HloSharding& output_sharding, + absl::Span slice_sizes, SpmdPartitioningVisitor* visitor, + bool allow_recursive) { + std::optional parallel_dims = + hlo_sharding_util::GetScatterParallelBatchDims(*scatter, + visitor->call_graph()); + if (!parallel_dims) { + return nullptr; + } + + return PartitionScatterParallelDimensions( + scatter, operands, indices, updates, output_shape, output_sharding, + slice_sizes, visitor, allow_recursive, *parallel_dims, true); +} + +// Partition a scatter over a indices dimensions that are cosidered parallel +// (which means that the indices access the operand in a monotonically +// increasing way across the respective operand dimension referenced by the +// index). +absl::StatusOr PartitionScatterExplicitBatchDimensions( + const HloScatterInstruction* scatter, std::vector operands, + PartitionedHlo indices, std::vector updates, + const Shape& output_shape, const HloSharding& output_sharding, + absl::Span slice_sizes, SpmdPartitioningVisitor* visitor, + bool allow_recursive) { + const ScatterDimensionNumbers& dnums = scatter->scatter_dimension_numbers(); + if (dnums.input_batching_dims().empty()) { + return nullptr; + } + + hlo_sharding_util::GatherScatterParallelDims parallel_dims; + parallel_dims.operand_parallel_dims.assign( + dnums.input_batching_dims().begin(), dnums.input_batching_dims().end()); + parallel_dims.indices_parallel_dims.assign( + dnums.scatter_indices_batching_dims().begin(), + dnums.scatter_indices_batching_dims().end()); + + return PartitionScatterParallelDimensions( + scatter, operands, indices, updates, output_shape, output_sharding, + slice_sizes, visitor, allow_recursive, parallel_dims, false); +} + // Perform partitioning of Scatter when the operand is split in a update window // dimension that is passed through (slice size is the same size of the operand // dimension). @@ -1569,7 +1621,9 @@ absl::StatusOr PartitionScatterTrivialSlicedOperandDimensions( // Returns a full list of partitioning methods used for scatter. std::vector> ScatterPartitionMethods() { - return {{PartitionScatterIndexParallelDimensions, + return {{PartitionScatterExplicitBatchDimensions, + "PartitionScatterExplicitBatchDimensions"}, + {PartitionScatterIndexParallelDimensions, "PartitionScatterIndexParallelDimensions"}, {PartitionScatterOperandPassthroughDimensions, "PartitionScatterOperandPassthroughDimensions"}, @@ -1583,6 +1637,8 @@ ScatterPartitionMethods() { decltype(PartitionScatter)* GetScatterPartitionMethod( PartitioningMethod method) { switch (method) { + case PartitioningMethod::kExplicitBatch: + return PartitionScatterExplicitBatchDimensions; case PartitioningMethod::kIndexParallel: return PartitionScatterIndexParallelDimensions; case PartitioningMethod::kOperandPassthrough: diff --git a/third_party/xla/xla/service/spmd/spmd_partitioner_test.cc b/third_party/xla/xla/service/spmd/spmd_partitioner_test.cc index 29adc287913b38..2d1d352351784d 100644 --- a/third_party/xla/xla/service/spmd/spmd_partitioner_test.cc +++ b/third_party/xla/xla/service/spmd/spmd_partitioner_test.cc @@ -8256,6 +8256,100 @@ ENTRY entry { op::Shape("f32[2,9,8]"))); } +TEST_P(SpmdPartitioningTest, ScatterExplicitBatchDims) { + absl::string_view hlo_string = R"( +HloModule module + +min (lhs: f32[], rhs: f32[]) -> f32[] { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT min = f32[] minimum(lhs, rhs) +} + +ENTRY entry { + %input = f32[10,6,14,4] parameter(0), sharding={devices=[2,1,2,1]<=[4]} + %indices = s32[14,10,6,2] parameter(1), sharding={devices=[2,2,1,1]<=[2,2]T(1,0)} + %updates = f32[14,10,6,2] parameter(2), sharding={devices=[2,2,1,1]<=[2,2]T(1,0)} + ROOT %scatter = f32[10,6,14,4] scatter(%input, %indices, %updates), + to_apply=min, update_window_dims={3}, inserted_window_dims={1}, + scatter_dims_to_operand_dims={1,3}, input_batching_dims={0,2}, + scatter_indices_batching_dims={1,0}, index_vector_dim=3, sharding={devices=[2,1,2,1]<=[4]} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + + auto input = AllOf(op::Shape("f32[5,6,7,4]"), op::Parameter(0)); + auto indices = AllOf(op::Shape("s32[7,5,6,2]"), op::Parameter(1)); + auto updates = AllOf(op::Shape("f32[7,5,6,2]"), op::Parameter(2)); + auto scatter = + AllOf(op::Shape("f32[5,6,7,4]"), op::Scatter(input, indices, updates)); + EXPECT_THAT(module->entry_computation()->root_instruction(), scatter); +} + +TEST_P(SpmdPartitioningTest, ScatterExplicitBatchAndOperandPassthroughDims) { + absl::string_view hlo_string = R"( +HloModule module + +min (lhs: f32[], rhs: f32[]) -> f32[] { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT min = f32[] minimum(lhs, rhs) +} + +ENTRY entry { + %input = f32[10,6,14,4] parameter(0), sharding={devices=[1,1,2,2]<=[4]} + %indices = s32[14,10,6,2] parameter(1), sharding={devices=[2,1,1,1,2]<=[4] last_tile_dim_replicate} + %updates = f32[14,10,6,4] parameter(2), sharding={devices=[2,1,1,2]<=[4]} + ROOT %scatter = f32[10,6,14,4] scatter(%input, %indices, %updates), + to_apply=min, update_window_dims={3}, inserted_window_dims={1}, + scatter_dims_to_operand_dims={1,3}, input_batching_dims={0,2}, + scatter_indices_batching_dims={1,0}, index_vector_dim=3, sharding={devices=[1,1,2,2]<=[4]} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + + auto input = AllOf(op::Shape("f32[10,6,7,2]"), op::Parameter(0)); + auto indices = AllOf(op::Shape("s32[7,10,6,2]"), op::Parameter(1)); + auto updates = AllOf(op::Shape("f32[7,10,6,2]"), op::Parameter(2)); + auto scatter = + AllOf(op::Shape("f32[10,6,7,2]"), op::Scatter(input, indices, updates)); + EXPECT_THAT(module->entry_computation()->root_instruction(), scatter); +} + +TEST_P(SpmdPartitioningTest, ScatterExplicitBatchAndIndexPassthroughDims) { + absl::string_view hlo_string = R"( +HloModule module + +min (lhs: f32[], rhs: f32[]) -> f32[] { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT min = f32[] minimum(lhs, rhs) +} + +ENTRY entry { + %input = f32[10,6,14,4] parameter(0), sharding={devices=[1,1,2,1,2]<=[4] last_tile_dim_replicate} + %indices = s32[14,10,6,2] parameter(1), sharding={devices=[2,1,2,1]<=[4]} + %updates = f32[14,10,6,2] parameter(2), sharding={devices=[2,1,2,1]<=[4]} + ROOT %scatter = f32[10,6,14,4] scatter(%input, %indices, %updates), + to_apply=min, update_window_dims={3}, inserted_window_dims={1}, + scatter_dims_to_operand_dims={1,3}, input_batching_dims={0,2}, + scatter_indices_batching_dims={1,0}, index_vector_dim=3, sharding={devices=[1,1,2,1,2]<=[4] last_tile_dim_replicate} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + + auto input = + AllOf(op::Shape("f32[10,6,7,4]"), op::Select(_, _, op::Parameter(0))); + auto indices = AllOf(op::Shape("s32[7,10,3,2]"), op::Parameter(1)); + auto updates = AllOf(op::Shape("f32[7,10,3,2]"), op::Parameter(2)); + auto scatter = AllOf(op::Shape("f32[10,6,7,4]"), + op::AllReduce(op::Scatter(input, indices, updates))); + EXPECT_THAT(module->entry_computation()->root_instruction(), scatter); +} + TEST_P(SpmdPartitioningTest, ScatterPartitionedOnTrivialSliceDims) { absl::string_view hlo_string = R"( HloModule module From 504ccb5925dc976992b4ecf843de76ca12df8b61 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 24 Sep 2024 14:09:25 -0700 Subject: [PATCH 206/483] Fix expected curl error message in curl_http_request_test.cc. PiperOrigin-RevId: 678390747 --- .../tsl/tsl/platform/cloud/curl_http_request_test.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/curl_http_request_test.cc b/third_party/xla/third_party/tsl/tsl/platform/cloud/curl_http_request_test.cc index 31cde679f4978b..9cc1d8e075e0b8 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/cloud/curl_http_request_test.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/curl_http_request_test.cc @@ -496,7 +496,7 @@ TEST(CurlHttpRequestTest, GetRequest_CouldntResolveHost) { EXPECT_EQ(error::FAILED_PRECONDITION, status.code()); EXPECT_EQ( "Error executing an HTTP request: libcurl code 6 meaning " - "'Couldn't resolve host name', error details: Could not resolve host " + "'Could not resolve hostname', error details: Could not resolve host " "'metadata'", status.message()); EXPECT_EQ(0, http_request.GetResponseCode()); From 4d9084373a341cdb6b06e58578d690ac06ac98a3 Mon Sep 17 00:00:00 2001 From: Kanglan Tang Date: Tue, 24 Sep 2024 14:09:56 -0700 Subject: [PATCH 207/483] Add TF NumPy 2.0 support to 2.18 release note PiperOrigin-RevId: 678390928 --- RELEASE.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/RELEASE.md b/RELEASE.md index c146e2215fb339..a156363a53d523 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -20,6 +20,10 @@ * * +* TensorFlow now supports and is compiled with NumPy 2.0 by default. + Compatibility with NumPy 1.26 will be maintained until 2025, aligning with + community standard deprecation timeline [here](https://scientific-python.org/specs/spec-0000/). + ### Bug Fixes and Other Changes * From c46888a69557569754b577bfd52c95f12b88c4a0 Mon Sep 17 00:00:00 2001 From: Kanglan Tang Date: Tue, 24 Sep 2024 14:12:08 -0700 Subject: [PATCH 208/483] Update support for NumPy 2.1 Note that NumPy 2.1 does not support Python 3.9 https://numpy.org/doc/stable/release/2.1.0-notes.html#numpy-2-1-0-release-notes PiperOrigin-RevId: 678391800 --- .../requirements_updater/requirements.in | 5 +- requirements_lock_3_10.txt | 101 ++++++++++-------- requirements_lock_3_11.txt | 101 ++++++++++-------- requirements_lock_3_12.txt | 101 ++++++++++-------- requirements_lock_3_9.txt | 1 + tensorflow/tools/pip_package/setup.py | 4 +- 6 files changed, 170 insertions(+), 143 deletions(-) diff --git a/ci/official/requirements_updater/requirements.in b/ci/official/requirements_updater/requirements.in index 305eaee3dce946..884d0e8b6f04fd 100644 --- a/ci/official/requirements_updater/requirements.in +++ b/ci/official/requirements_updater/requirements.in @@ -1,4 +1,5 @@ -numpy ~= 2.0.0 +# Note that numpy 2.1.0 does not support python 3.9 +numpy >= 2.0.0, < 2.2.0 wheel ~= 0.41.2 h5py >= 3.11.0 lit ~= 17.0.2 @@ -21,7 +22,7 @@ tb-nightly ~= 2.18.0.a # Test dependencies grpcio >= 1.24.3, < 2.0 portpicker == 1.6.0 -scipy ~= 1.13.0 +scipy >= 1.13.0 requests >= 2.31.0 packaging==23.2 setuptools==70.0.0 diff --git a/requirements_lock_3_10.txt b/requirements_lock_3_10.txt index 3530cd725b6d78..895a7b568ccaeb 100644 --- a/requirements_lock_3_10.txt +++ b/requirements_lock_3_10.txt @@ -380,52 +380,60 @@ namex==0.0.8 \ --hash=sha256:32a50f6c565c0bb10aa76298c959507abdc0e850efe085dc38f3440fcb3aa90b \ --hash=sha256:7ddb6c2bb0e753a311b7590f84f6da659dd0c05e65cb89d519d54c0a250c0487 # via keras-nightly -numpy==2.0.2 \ - --hash=sha256:0123ffdaa88fa4ab64835dcbde75dcdf89c453c922f18dced6e27c90d1d0ec5a \ - --hash=sha256:11a76c372d1d37437857280aa142086476136a8c0f373b2e648ab2c8f18fb195 \ - --hash=sha256:13e689d772146140a252c3a28501da66dfecd77490b498b168b501835041f951 \ - --hash=sha256:1e795a8be3ddbac43274f18588329c72939870a16cae810c2b73461c40718ab1 \ - --hash=sha256:26df23238872200f63518dd2aa984cfca675d82469535dc7162dc2ee52d9dd5c \ - --hash=sha256:286cd40ce2b7d652a6f22efdfc6d1edf879440e53e76a75955bc0c826c7e64dc \ - --hash=sha256:2b2955fa6f11907cf7a70dab0d0755159bca87755e831e47932367fc8f2f2d0b \ - --hash=sha256:2da5960c3cf0df7eafefd806d4e612c5e19358de82cb3c343631188991566ccd \ - --hash=sha256:312950fdd060354350ed123c0e25a71327d3711584beaef30cdaa93320c392d4 \ - --hash=sha256:423e89b23490805d2a5a96fe40ec507407b8ee786d66f7328be214f9679df6dd \ - --hash=sha256:496f71341824ed9f3d2fd36cf3ac57ae2e0165c143b55c3a035ee219413f3318 \ - --hash=sha256:49ca4decb342d66018b01932139c0961a8f9ddc7589611158cb3c27cbcf76448 \ - --hash=sha256:51129a29dbe56f9ca83438b706e2e69a39892b5eda6cedcb6b0c9fdc9b0d3ece \ - --hash=sha256:5fec9451a7789926bcf7c2b8d187292c9f93ea30284802a0ab3f5be8ab36865d \ - --hash=sha256:671bec6496f83202ed2d3c8fdc486a8fc86942f2e69ff0e986140339a63bcbe5 \ - --hash=sha256:7f0a0c6f12e07fa94133c8a67404322845220c06a9e80e85999afe727f7438b8 \ - --hash=sha256:807ec44583fd708a21d4a11d94aedf2f4f3c3719035c76a2bbe1fe8e217bdc57 \ - --hash=sha256:883c987dee1880e2a864ab0dc9892292582510604156762362d9326444636e78 \ - --hash=sha256:8c5713284ce4e282544c68d1c3b2c7161d38c256d2eefc93c1d683cf47683e66 \ - --hash=sha256:8cafab480740e22f8d833acefed5cc87ce276f4ece12fdaa2e8903db2f82897a \ - --hash=sha256:8df823f570d9adf0978347d1f926b2a867d5608f434a7cff7f7908c6570dcf5e \ - --hash=sha256:9059e10581ce4093f735ed23f3b9d283b9d517ff46009ddd485f1747eb22653c \ - --hash=sha256:905d16e0c60200656500c95b6b8dca5d109e23cb24abc701d41c02d74c6b3afa \ - --hash=sha256:9189427407d88ff25ecf8f12469d4d39d35bee1db5d39fc5c168c6f088a6956d \ - --hash=sha256:96a55f64139912d61de9137f11bf39a55ec8faec288c75a54f93dfd39f7eb40c \ - --hash=sha256:97032a27bd9d8988b9a97a8c4d2c9f2c15a81f61e2f21404d7e8ef00cb5be729 \ - --hash=sha256:984d96121c9f9616cd33fbd0618b7f08e0cfc9600a7ee1d6fd9b239186d19d97 \ - --hash=sha256:9a92ae5c14811e390f3767053ff54eaee3bf84576d99a2456391401323f4ec2c \ - --hash=sha256:9ea91dfb7c3d1c56a0e55657c0afb38cf1eeae4544c208dc465c3c9f3a7c09f9 \ - --hash=sha256:a15f476a45e6e5a3a79d8a14e62161d27ad897381fecfa4a09ed5322f2085669 \ - --hash=sha256:a392a68bd329eafac5817e5aefeb39038c48b671afd242710b451e76090e81f4 \ - --hash=sha256:a3f4ab0caa7f053f6797fcd4e1e25caee367db3112ef2b6ef82d749530768c73 \ - --hash=sha256:a46288ec55ebbd58947d31d72be2c63cbf839f0a63b49cb755022310792a3385 \ - --hash=sha256:a61ec659f68ae254e4d237816e33171497e978140353c0c2038d46e63282d0c8 \ - --hash=sha256:a842d573724391493a97a62ebbb8e731f8a5dcc5d285dfc99141ca15a3302d0c \ - --hash=sha256:becfae3ddd30736fe1889a37f1f580e245ba79a5855bff5f2a29cb3ccc22dd7b \ - --hash=sha256:c05e238064fc0610c840d1cf6a13bf63d7e391717d247f1bf0318172e759e692 \ - --hash=sha256:c1c9307701fec8f3f7a1e6711f9089c06e6284b3afbbcd259f7791282d660a15 \ - --hash=sha256:c7b0be4ef08607dd04da4092faee0b86607f111d5ae68036f16cc787e250a131 \ - --hash=sha256:cfd41e13fdc257aa5778496b8caa5e856dc4896d4ccf01841daee1d96465467a \ - --hash=sha256:d731a1c6116ba289c1e9ee714b08a8ff882944d4ad631fd411106a30f083c326 \ - --hash=sha256:df55d490dea7934f330006d0f81e8551ba6010a5bf035a249ef61a94f21c500b \ - --hash=sha256:ec9852fb39354b5a45a80bdab5ac02dd02b15f44b3804e9f00c556bf24b4bded \ - --hash=sha256:f15975dfec0cf2239224d80e32c3170b1d168335eaedee69da84fbe9f1f9cd04 \ - --hash=sha256:f26b258c385842546006213344c50655ff1555a9338e2e5e02a0756dc3e803dd +numpy==2.1.1 \ + --hash=sha256:046356b19d7ad1890c751b99acad5e82dc4a02232013bd9a9a712fddf8eb60f5 \ + --hash=sha256:0b8cc2715a84b7c3b161f9ebbd942740aaed913584cae9cdc7f8ad5ad41943d0 \ + --hash=sha256:0d07841fd284718feffe7dd17a63a2e6c78679b2d386d3e82f44f0108c905550 \ + --hash=sha256:13cc11c00000848702322af4de0147ced365c81d66053a67c2e962a485b3717c \ + --hash=sha256:13ce49a34c44b6de5241f0b38b07e44c1b2dcacd9e36c30f9c2fcb1bb5135db7 \ + --hash=sha256:24c2ad697bd8593887b019817ddd9974a7f429c14a5469d7fad413f28340a6d2 \ + --hash=sha256:251105b7c42abe40e3a689881e1793370cc9724ad50d64b30b358bbb3a97553b \ + --hash=sha256:2ca4b53e1e0b279142113b8c5eb7d7a877e967c306edc34f3b58e9be12fda8df \ + --hash=sha256:3269c9eb8745e8d975980b3a7411a98976824e1fdef11f0aacf76147f662b15f \ + --hash=sha256:397bc5ce62d3fb73f304bec332171535c187e0643e176a6e9421a6e3eacef06d \ + --hash=sha256:3fc5eabfc720db95d68e6646e88f8b399bfedd235994016351b1d9e062c4b270 \ + --hash=sha256:50a95ca3560a6058d6ea91d4629a83a897ee27c00630aed9d933dff191f170cd \ + --hash=sha256:52ac2e48f5ad847cd43c4755520a2317f3380213493b9d8a4c5e37f3b87df504 \ + --hash=sha256:53e27293b3a2b661c03f79aa51c3987492bd4641ef933e366e0f9f6c9bf257ec \ + --hash=sha256:57eb525e7c2a8fdee02d731f647146ff54ea8c973364f3b850069ffb42799647 \ + --hash=sha256:5889dd24f03ca5a5b1e8a90a33b5a0846d8977565e4ae003a63d22ecddf6782f \ + --hash=sha256:59ca673ad11d4b84ceb385290ed0ebe60266e356641428c845b39cd9df6713ab \ + --hash=sha256:6435c48250c12f001920f0751fe50c0348f5f240852cfddc5e2f97e007544cbe \ + --hash=sha256:6e5a9cb2be39350ae6c8f79410744e80154df658d5bea06e06e0ac5bb75480d5 \ + --hash=sha256:7be6a07520b88214ea85d8ac8b7d6d8a1839b0b5cb87412ac9f49fa934eb15d5 \ + --hash=sha256:7c803b7934a7f59563db459292e6aa078bb38b7ab1446ca38dd138646a38203e \ + --hash=sha256:7dd86dfaf7c900c0bbdcb8b16e2f6ddf1eb1fe39c6c8cca6e94844ed3152a8fd \ + --hash=sha256:8661c94e3aad18e1ea17a11f60f843a4933ccaf1a25a7c6a9182af70610b2313 \ + --hash=sha256:8ae0fd135e0b157365ac7cc31fff27f07a5572bdfc38f9c2d43b2aff416cc8b0 \ + --hash=sha256:910b47a6d0635ec1bd53b88f86120a52bf56dcc27b51f18c7b4a2e2224c29f0f \ + --hash=sha256:913cc1d311060b1d409e609947fa1b9753701dac96e6581b58afc36b7ee35af6 \ + --hash=sha256:920b0911bb2e4414c50e55bd658baeb78281a47feeb064ab40c2b66ecba85553 \ + --hash=sha256:950802d17a33c07cba7fd7c3dcfa7d64705509206be1606f196d179e539111ed \ + --hash=sha256:981707f6b31b59c0c24bcda52e5605f9701cb46da4b86c2e8023656ad3e833cb \ + --hash=sha256:98ce7fb5b8063cfdd86596b9c762bf2b5e35a2cdd7e967494ab78a1fa7f8b86e \ + --hash=sha256:99f4a9ee60eed1385a86e82288971a51e71df052ed0b2900ed30bc840c0f2e39 \ + --hash=sha256:9a8e06c7a980869ea67bbf551283bbed2856915f0a792dc32dd0f9dd2fb56728 \ + --hash=sha256:ae8ce252404cdd4de56dcfce8b11eac3c594a9c16c231d081fb705cf23bd4d9e \ + --hash=sha256:afd9c680df4de71cd58582b51e88a61feed4abcc7530bcd3d48483f20fc76f2a \ + --hash=sha256:b49742cdb85f1f81e4dc1b39dcf328244f4d8d1ded95dea725b316bd2cf18c95 \ + --hash=sha256:b5613cfeb1adfe791e8e681128f5f49f22f3fcaa942255a6124d58ca59d9528f \ + --hash=sha256:bab7c09454460a487e631ffc0c42057e3d8f2a9ddccd1e60c7bb8ed774992480 \ + --hash=sha256:c8a0e34993b510fc19b9a2ce7f31cb8e94ecf6e924a40c0c9dd4f62d0aac47d9 \ + --hash=sha256:caf5d284ddea7462c32b8d4a6b8af030b6c9fd5332afb70e7414d7fdded4bfd0 \ + --hash=sha256:cea427d1350f3fd0d2818ce7350095c1a2ee33e30961d2f0fef48576ddbbe90f \ + --hash=sha256:d0cf7d55b1051387807405b3898efafa862997b4cba8aa5dbe657be794afeafd \ + --hash=sha256:d10c39947a2d351d6d466b4ae83dad4c37cd6c3cdd6d5d0fa797da56f710a6ae \ + --hash=sha256:d2b9cd92c8f8e7b313b80e93cedc12c0112088541dcedd9197b5dee3738c1201 \ + --hash=sha256:d4c57b68c8ef5e1ebf47238e99bf27657511ec3f071c465f6b1bccbef12d4136 \ + --hash=sha256:d51fc141ddbe3f919e91a096ec739f49d686df8af254b2053ba21a910ae518bf \ + --hash=sha256:e097507396c0be4e547ff15b13dc3866f45f3680f789c1a1301b07dadd3fbc78 \ + --hash=sha256:e30356d530528a42eeba51420ae8bf6c6c09559051887196599d96ee5f536468 \ + --hash=sha256:e8d5f8a8e3bc87334f025194c6193e408903d21ebaeb10952264943a985066ca \ + --hash=sha256:e8dfa9e94fc127c40979c3eacbae1e61fda4fe71d84869cc129e2721973231ef \ + --hash=sha256:f212d4f46b67ff604d11fff7cc62d36b3e8714edf68e44e9760e19be38c03eb0 \ + --hash=sha256:f7506387e191fe8cdb267f912469a3cccc538ab108471291636a96a54e599556 \ + --hash=sha256:fac6e277a41163d27dfab5f4ec1f7a83fac94e170665a4a50191b545721c6521 \ + --hash=sha256:fcd8f556cdc8cfe35e70efb92463082b7f43dd7e547eb071ffc36abc0ca4699b # via # -r ci/official/requirements_updater/requirements.in # h5py @@ -526,6 +534,7 @@ six==1.16.0 \ --hash=sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254 # via # astunparse + # google-pasta # tb-nightly tb-nightly==2.18.0a20240611 \ --hash=sha256:c299eb7dc3de22c7164a1b0c0091b784f2214d65b9a8b967eeeba9818314016d diff --git a/requirements_lock_3_11.txt b/requirements_lock_3_11.txt index dced21fd467a60..2b4a27416bc97c 100644 --- a/requirements_lock_3_11.txt +++ b/requirements_lock_3_11.txt @@ -380,52 +380,60 @@ namex==0.0.8 \ --hash=sha256:32a50f6c565c0bb10aa76298c959507abdc0e850efe085dc38f3440fcb3aa90b \ --hash=sha256:7ddb6c2bb0e753a311b7590f84f6da659dd0c05e65cb89d519d54c0a250c0487 # via keras-nightly -numpy==2.0.2 \ - --hash=sha256:0123ffdaa88fa4ab64835dcbde75dcdf89c453c922f18dced6e27c90d1d0ec5a \ - --hash=sha256:11a76c372d1d37437857280aa142086476136a8c0f373b2e648ab2c8f18fb195 \ - --hash=sha256:13e689d772146140a252c3a28501da66dfecd77490b498b168b501835041f951 \ - --hash=sha256:1e795a8be3ddbac43274f18588329c72939870a16cae810c2b73461c40718ab1 \ - --hash=sha256:26df23238872200f63518dd2aa984cfca675d82469535dc7162dc2ee52d9dd5c \ - --hash=sha256:286cd40ce2b7d652a6f22efdfc6d1edf879440e53e76a75955bc0c826c7e64dc \ - --hash=sha256:2b2955fa6f11907cf7a70dab0d0755159bca87755e831e47932367fc8f2f2d0b \ - --hash=sha256:2da5960c3cf0df7eafefd806d4e612c5e19358de82cb3c343631188991566ccd \ - --hash=sha256:312950fdd060354350ed123c0e25a71327d3711584beaef30cdaa93320c392d4 \ - --hash=sha256:423e89b23490805d2a5a96fe40ec507407b8ee786d66f7328be214f9679df6dd \ - --hash=sha256:496f71341824ed9f3d2fd36cf3ac57ae2e0165c143b55c3a035ee219413f3318 \ - --hash=sha256:49ca4decb342d66018b01932139c0961a8f9ddc7589611158cb3c27cbcf76448 \ - --hash=sha256:51129a29dbe56f9ca83438b706e2e69a39892b5eda6cedcb6b0c9fdc9b0d3ece \ - --hash=sha256:5fec9451a7789926bcf7c2b8d187292c9f93ea30284802a0ab3f5be8ab36865d \ - --hash=sha256:671bec6496f83202ed2d3c8fdc486a8fc86942f2e69ff0e986140339a63bcbe5 \ - --hash=sha256:7f0a0c6f12e07fa94133c8a67404322845220c06a9e80e85999afe727f7438b8 \ - --hash=sha256:807ec44583fd708a21d4a11d94aedf2f4f3c3719035c76a2bbe1fe8e217bdc57 \ - --hash=sha256:883c987dee1880e2a864ab0dc9892292582510604156762362d9326444636e78 \ - --hash=sha256:8c5713284ce4e282544c68d1c3b2c7161d38c256d2eefc93c1d683cf47683e66 \ - --hash=sha256:8cafab480740e22f8d833acefed5cc87ce276f4ece12fdaa2e8903db2f82897a \ - --hash=sha256:8df823f570d9adf0978347d1f926b2a867d5608f434a7cff7f7908c6570dcf5e \ - --hash=sha256:9059e10581ce4093f735ed23f3b9d283b9d517ff46009ddd485f1747eb22653c \ - --hash=sha256:905d16e0c60200656500c95b6b8dca5d109e23cb24abc701d41c02d74c6b3afa \ - --hash=sha256:9189427407d88ff25ecf8f12469d4d39d35bee1db5d39fc5c168c6f088a6956d \ - --hash=sha256:96a55f64139912d61de9137f11bf39a55ec8faec288c75a54f93dfd39f7eb40c \ - --hash=sha256:97032a27bd9d8988b9a97a8c4d2c9f2c15a81f61e2f21404d7e8ef00cb5be729 \ - --hash=sha256:984d96121c9f9616cd33fbd0618b7f08e0cfc9600a7ee1d6fd9b239186d19d97 \ - --hash=sha256:9a92ae5c14811e390f3767053ff54eaee3bf84576d99a2456391401323f4ec2c \ - --hash=sha256:9ea91dfb7c3d1c56a0e55657c0afb38cf1eeae4544c208dc465c3c9f3a7c09f9 \ - --hash=sha256:a15f476a45e6e5a3a79d8a14e62161d27ad897381fecfa4a09ed5322f2085669 \ - --hash=sha256:a392a68bd329eafac5817e5aefeb39038c48b671afd242710b451e76090e81f4 \ - --hash=sha256:a3f4ab0caa7f053f6797fcd4e1e25caee367db3112ef2b6ef82d749530768c73 \ - --hash=sha256:a46288ec55ebbd58947d31d72be2c63cbf839f0a63b49cb755022310792a3385 \ - --hash=sha256:a61ec659f68ae254e4d237816e33171497e978140353c0c2038d46e63282d0c8 \ - --hash=sha256:a842d573724391493a97a62ebbb8e731f8a5dcc5d285dfc99141ca15a3302d0c \ - --hash=sha256:becfae3ddd30736fe1889a37f1f580e245ba79a5855bff5f2a29cb3ccc22dd7b \ - --hash=sha256:c05e238064fc0610c840d1cf6a13bf63d7e391717d247f1bf0318172e759e692 \ - --hash=sha256:c1c9307701fec8f3f7a1e6711f9089c06e6284b3afbbcd259f7791282d660a15 \ - --hash=sha256:c7b0be4ef08607dd04da4092faee0b86607f111d5ae68036f16cc787e250a131 \ - --hash=sha256:cfd41e13fdc257aa5778496b8caa5e856dc4896d4ccf01841daee1d96465467a \ - --hash=sha256:d731a1c6116ba289c1e9ee714b08a8ff882944d4ad631fd411106a30f083c326 \ - --hash=sha256:df55d490dea7934f330006d0f81e8551ba6010a5bf035a249ef61a94f21c500b \ - --hash=sha256:ec9852fb39354b5a45a80bdab5ac02dd02b15f44b3804e9f00c556bf24b4bded \ - --hash=sha256:f15975dfec0cf2239224d80e32c3170b1d168335eaedee69da84fbe9f1f9cd04 \ - --hash=sha256:f26b258c385842546006213344c50655ff1555a9338e2e5e02a0756dc3e803dd +numpy==2.1.1 \ + --hash=sha256:046356b19d7ad1890c751b99acad5e82dc4a02232013bd9a9a712fddf8eb60f5 \ + --hash=sha256:0b8cc2715a84b7c3b161f9ebbd942740aaed913584cae9cdc7f8ad5ad41943d0 \ + --hash=sha256:0d07841fd284718feffe7dd17a63a2e6c78679b2d386d3e82f44f0108c905550 \ + --hash=sha256:13cc11c00000848702322af4de0147ced365c81d66053a67c2e962a485b3717c \ + --hash=sha256:13ce49a34c44b6de5241f0b38b07e44c1b2dcacd9e36c30f9c2fcb1bb5135db7 \ + --hash=sha256:24c2ad697bd8593887b019817ddd9974a7f429c14a5469d7fad413f28340a6d2 \ + --hash=sha256:251105b7c42abe40e3a689881e1793370cc9724ad50d64b30b358bbb3a97553b \ + --hash=sha256:2ca4b53e1e0b279142113b8c5eb7d7a877e967c306edc34f3b58e9be12fda8df \ + --hash=sha256:3269c9eb8745e8d975980b3a7411a98976824e1fdef11f0aacf76147f662b15f \ + --hash=sha256:397bc5ce62d3fb73f304bec332171535c187e0643e176a6e9421a6e3eacef06d \ + --hash=sha256:3fc5eabfc720db95d68e6646e88f8b399bfedd235994016351b1d9e062c4b270 \ + --hash=sha256:50a95ca3560a6058d6ea91d4629a83a897ee27c00630aed9d933dff191f170cd \ + --hash=sha256:52ac2e48f5ad847cd43c4755520a2317f3380213493b9d8a4c5e37f3b87df504 \ + --hash=sha256:53e27293b3a2b661c03f79aa51c3987492bd4641ef933e366e0f9f6c9bf257ec \ + --hash=sha256:57eb525e7c2a8fdee02d731f647146ff54ea8c973364f3b850069ffb42799647 \ + --hash=sha256:5889dd24f03ca5a5b1e8a90a33b5a0846d8977565e4ae003a63d22ecddf6782f \ + --hash=sha256:59ca673ad11d4b84ceb385290ed0ebe60266e356641428c845b39cd9df6713ab \ + --hash=sha256:6435c48250c12f001920f0751fe50c0348f5f240852cfddc5e2f97e007544cbe \ + --hash=sha256:6e5a9cb2be39350ae6c8f79410744e80154df658d5bea06e06e0ac5bb75480d5 \ + --hash=sha256:7be6a07520b88214ea85d8ac8b7d6d8a1839b0b5cb87412ac9f49fa934eb15d5 \ + --hash=sha256:7c803b7934a7f59563db459292e6aa078bb38b7ab1446ca38dd138646a38203e \ + --hash=sha256:7dd86dfaf7c900c0bbdcb8b16e2f6ddf1eb1fe39c6c8cca6e94844ed3152a8fd \ + --hash=sha256:8661c94e3aad18e1ea17a11f60f843a4933ccaf1a25a7c6a9182af70610b2313 \ + --hash=sha256:8ae0fd135e0b157365ac7cc31fff27f07a5572bdfc38f9c2d43b2aff416cc8b0 \ + --hash=sha256:910b47a6d0635ec1bd53b88f86120a52bf56dcc27b51f18c7b4a2e2224c29f0f \ + --hash=sha256:913cc1d311060b1d409e609947fa1b9753701dac96e6581b58afc36b7ee35af6 \ + --hash=sha256:920b0911bb2e4414c50e55bd658baeb78281a47feeb064ab40c2b66ecba85553 \ + --hash=sha256:950802d17a33c07cba7fd7c3dcfa7d64705509206be1606f196d179e539111ed \ + --hash=sha256:981707f6b31b59c0c24bcda52e5605f9701cb46da4b86c2e8023656ad3e833cb \ + --hash=sha256:98ce7fb5b8063cfdd86596b9c762bf2b5e35a2cdd7e967494ab78a1fa7f8b86e \ + --hash=sha256:99f4a9ee60eed1385a86e82288971a51e71df052ed0b2900ed30bc840c0f2e39 \ + --hash=sha256:9a8e06c7a980869ea67bbf551283bbed2856915f0a792dc32dd0f9dd2fb56728 \ + --hash=sha256:ae8ce252404cdd4de56dcfce8b11eac3c594a9c16c231d081fb705cf23bd4d9e \ + --hash=sha256:afd9c680df4de71cd58582b51e88a61feed4abcc7530bcd3d48483f20fc76f2a \ + --hash=sha256:b49742cdb85f1f81e4dc1b39dcf328244f4d8d1ded95dea725b316bd2cf18c95 \ + --hash=sha256:b5613cfeb1adfe791e8e681128f5f49f22f3fcaa942255a6124d58ca59d9528f \ + --hash=sha256:bab7c09454460a487e631ffc0c42057e3d8f2a9ddccd1e60c7bb8ed774992480 \ + --hash=sha256:c8a0e34993b510fc19b9a2ce7f31cb8e94ecf6e924a40c0c9dd4f62d0aac47d9 \ + --hash=sha256:caf5d284ddea7462c32b8d4a6b8af030b6c9fd5332afb70e7414d7fdded4bfd0 \ + --hash=sha256:cea427d1350f3fd0d2818ce7350095c1a2ee33e30961d2f0fef48576ddbbe90f \ + --hash=sha256:d0cf7d55b1051387807405b3898efafa862997b4cba8aa5dbe657be794afeafd \ + --hash=sha256:d10c39947a2d351d6d466b4ae83dad4c37cd6c3cdd6d5d0fa797da56f710a6ae \ + --hash=sha256:d2b9cd92c8f8e7b313b80e93cedc12c0112088541dcedd9197b5dee3738c1201 \ + --hash=sha256:d4c57b68c8ef5e1ebf47238e99bf27657511ec3f071c465f6b1bccbef12d4136 \ + --hash=sha256:d51fc141ddbe3f919e91a096ec739f49d686df8af254b2053ba21a910ae518bf \ + --hash=sha256:e097507396c0be4e547ff15b13dc3866f45f3680f789c1a1301b07dadd3fbc78 \ + --hash=sha256:e30356d530528a42eeba51420ae8bf6c6c09559051887196599d96ee5f536468 \ + --hash=sha256:e8d5f8a8e3bc87334f025194c6193e408903d21ebaeb10952264943a985066ca \ + --hash=sha256:e8dfa9e94fc127c40979c3eacbae1e61fda4fe71d84869cc129e2721973231ef \ + --hash=sha256:f212d4f46b67ff604d11fff7cc62d36b3e8714edf68e44e9760e19be38c03eb0 \ + --hash=sha256:f7506387e191fe8cdb267f912469a3cccc538ab108471291636a96a54e599556 \ + --hash=sha256:fac6e277a41163d27dfab5f4ec1f7a83fac94e170665a4a50191b545721c6521 \ + --hash=sha256:fcd8f556cdc8cfe35e70efb92463082b7f43dd7e547eb071ffc36abc0ca4699b # via # -r ci/official/requirements_updater/requirements.in # h5py @@ -526,6 +534,7 @@ six==1.16.0 \ --hash=sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254 # via # astunparse + # google-pasta # tb-nightly tb-nightly==2.18.0a20240611 \ --hash=sha256:c299eb7dc3de22c7164a1b0c0091b784f2214d65b9a8b967eeeba9818314016d diff --git a/requirements_lock_3_12.txt b/requirements_lock_3_12.txt index 581778cdc49d64..3b3099447ab18b 100644 --- a/requirements_lock_3_12.txt +++ b/requirements_lock_3_12.txt @@ -380,52 +380,60 @@ namex==0.0.8 \ --hash=sha256:32a50f6c565c0bb10aa76298c959507abdc0e850efe085dc38f3440fcb3aa90b \ --hash=sha256:7ddb6c2bb0e753a311b7590f84f6da659dd0c05e65cb89d519d54c0a250c0487 # via keras-nightly -numpy==2.0.2 \ - --hash=sha256:0123ffdaa88fa4ab64835dcbde75dcdf89c453c922f18dced6e27c90d1d0ec5a \ - --hash=sha256:11a76c372d1d37437857280aa142086476136a8c0f373b2e648ab2c8f18fb195 \ - --hash=sha256:13e689d772146140a252c3a28501da66dfecd77490b498b168b501835041f951 \ - --hash=sha256:1e795a8be3ddbac43274f18588329c72939870a16cae810c2b73461c40718ab1 \ - --hash=sha256:26df23238872200f63518dd2aa984cfca675d82469535dc7162dc2ee52d9dd5c \ - --hash=sha256:286cd40ce2b7d652a6f22efdfc6d1edf879440e53e76a75955bc0c826c7e64dc \ - --hash=sha256:2b2955fa6f11907cf7a70dab0d0755159bca87755e831e47932367fc8f2f2d0b \ - --hash=sha256:2da5960c3cf0df7eafefd806d4e612c5e19358de82cb3c343631188991566ccd \ - --hash=sha256:312950fdd060354350ed123c0e25a71327d3711584beaef30cdaa93320c392d4 \ - --hash=sha256:423e89b23490805d2a5a96fe40ec507407b8ee786d66f7328be214f9679df6dd \ - --hash=sha256:496f71341824ed9f3d2fd36cf3ac57ae2e0165c143b55c3a035ee219413f3318 \ - --hash=sha256:49ca4decb342d66018b01932139c0961a8f9ddc7589611158cb3c27cbcf76448 \ - --hash=sha256:51129a29dbe56f9ca83438b706e2e69a39892b5eda6cedcb6b0c9fdc9b0d3ece \ - --hash=sha256:5fec9451a7789926bcf7c2b8d187292c9f93ea30284802a0ab3f5be8ab36865d \ - --hash=sha256:671bec6496f83202ed2d3c8fdc486a8fc86942f2e69ff0e986140339a63bcbe5 \ - --hash=sha256:7f0a0c6f12e07fa94133c8a67404322845220c06a9e80e85999afe727f7438b8 \ - --hash=sha256:807ec44583fd708a21d4a11d94aedf2f4f3c3719035c76a2bbe1fe8e217bdc57 \ - --hash=sha256:883c987dee1880e2a864ab0dc9892292582510604156762362d9326444636e78 \ - --hash=sha256:8c5713284ce4e282544c68d1c3b2c7161d38c256d2eefc93c1d683cf47683e66 \ - --hash=sha256:8cafab480740e22f8d833acefed5cc87ce276f4ece12fdaa2e8903db2f82897a \ - --hash=sha256:8df823f570d9adf0978347d1f926b2a867d5608f434a7cff7f7908c6570dcf5e \ - --hash=sha256:9059e10581ce4093f735ed23f3b9d283b9d517ff46009ddd485f1747eb22653c \ - --hash=sha256:905d16e0c60200656500c95b6b8dca5d109e23cb24abc701d41c02d74c6b3afa \ - --hash=sha256:9189427407d88ff25ecf8f12469d4d39d35bee1db5d39fc5c168c6f088a6956d \ - --hash=sha256:96a55f64139912d61de9137f11bf39a55ec8faec288c75a54f93dfd39f7eb40c \ - --hash=sha256:97032a27bd9d8988b9a97a8c4d2c9f2c15a81f61e2f21404d7e8ef00cb5be729 \ - --hash=sha256:984d96121c9f9616cd33fbd0618b7f08e0cfc9600a7ee1d6fd9b239186d19d97 \ - --hash=sha256:9a92ae5c14811e390f3767053ff54eaee3bf84576d99a2456391401323f4ec2c \ - --hash=sha256:9ea91dfb7c3d1c56a0e55657c0afb38cf1eeae4544c208dc465c3c9f3a7c09f9 \ - --hash=sha256:a15f476a45e6e5a3a79d8a14e62161d27ad897381fecfa4a09ed5322f2085669 \ - --hash=sha256:a392a68bd329eafac5817e5aefeb39038c48b671afd242710b451e76090e81f4 \ - --hash=sha256:a3f4ab0caa7f053f6797fcd4e1e25caee367db3112ef2b6ef82d749530768c73 \ - --hash=sha256:a46288ec55ebbd58947d31d72be2c63cbf839f0a63b49cb755022310792a3385 \ - --hash=sha256:a61ec659f68ae254e4d237816e33171497e978140353c0c2038d46e63282d0c8 \ - --hash=sha256:a842d573724391493a97a62ebbb8e731f8a5dcc5d285dfc99141ca15a3302d0c \ - --hash=sha256:becfae3ddd30736fe1889a37f1f580e245ba79a5855bff5f2a29cb3ccc22dd7b \ - --hash=sha256:c05e238064fc0610c840d1cf6a13bf63d7e391717d247f1bf0318172e759e692 \ - --hash=sha256:c1c9307701fec8f3f7a1e6711f9089c06e6284b3afbbcd259f7791282d660a15 \ - --hash=sha256:c7b0be4ef08607dd04da4092faee0b86607f111d5ae68036f16cc787e250a131 \ - --hash=sha256:cfd41e13fdc257aa5778496b8caa5e856dc4896d4ccf01841daee1d96465467a \ - --hash=sha256:d731a1c6116ba289c1e9ee714b08a8ff882944d4ad631fd411106a30f083c326 \ - --hash=sha256:df55d490dea7934f330006d0f81e8551ba6010a5bf035a249ef61a94f21c500b \ - --hash=sha256:ec9852fb39354b5a45a80bdab5ac02dd02b15f44b3804e9f00c556bf24b4bded \ - --hash=sha256:f15975dfec0cf2239224d80e32c3170b1d168335eaedee69da84fbe9f1f9cd04 \ - --hash=sha256:f26b258c385842546006213344c50655ff1555a9338e2e5e02a0756dc3e803dd +numpy==2.1.1 \ + --hash=sha256:046356b19d7ad1890c751b99acad5e82dc4a02232013bd9a9a712fddf8eb60f5 \ + --hash=sha256:0b8cc2715a84b7c3b161f9ebbd942740aaed913584cae9cdc7f8ad5ad41943d0 \ + --hash=sha256:0d07841fd284718feffe7dd17a63a2e6c78679b2d386d3e82f44f0108c905550 \ + --hash=sha256:13cc11c00000848702322af4de0147ced365c81d66053a67c2e962a485b3717c \ + --hash=sha256:13ce49a34c44b6de5241f0b38b07e44c1b2dcacd9e36c30f9c2fcb1bb5135db7 \ + --hash=sha256:24c2ad697bd8593887b019817ddd9974a7f429c14a5469d7fad413f28340a6d2 \ + --hash=sha256:251105b7c42abe40e3a689881e1793370cc9724ad50d64b30b358bbb3a97553b \ + --hash=sha256:2ca4b53e1e0b279142113b8c5eb7d7a877e967c306edc34f3b58e9be12fda8df \ + --hash=sha256:3269c9eb8745e8d975980b3a7411a98976824e1fdef11f0aacf76147f662b15f \ + --hash=sha256:397bc5ce62d3fb73f304bec332171535c187e0643e176a6e9421a6e3eacef06d \ + --hash=sha256:3fc5eabfc720db95d68e6646e88f8b399bfedd235994016351b1d9e062c4b270 \ + --hash=sha256:50a95ca3560a6058d6ea91d4629a83a897ee27c00630aed9d933dff191f170cd \ + --hash=sha256:52ac2e48f5ad847cd43c4755520a2317f3380213493b9d8a4c5e37f3b87df504 \ + --hash=sha256:53e27293b3a2b661c03f79aa51c3987492bd4641ef933e366e0f9f6c9bf257ec \ + --hash=sha256:57eb525e7c2a8fdee02d731f647146ff54ea8c973364f3b850069ffb42799647 \ + --hash=sha256:5889dd24f03ca5a5b1e8a90a33b5a0846d8977565e4ae003a63d22ecddf6782f \ + --hash=sha256:59ca673ad11d4b84ceb385290ed0ebe60266e356641428c845b39cd9df6713ab \ + --hash=sha256:6435c48250c12f001920f0751fe50c0348f5f240852cfddc5e2f97e007544cbe \ + --hash=sha256:6e5a9cb2be39350ae6c8f79410744e80154df658d5bea06e06e0ac5bb75480d5 \ + --hash=sha256:7be6a07520b88214ea85d8ac8b7d6d8a1839b0b5cb87412ac9f49fa934eb15d5 \ + --hash=sha256:7c803b7934a7f59563db459292e6aa078bb38b7ab1446ca38dd138646a38203e \ + --hash=sha256:7dd86dfaf7c900c0bbdcb8b16e2f6ddf1eb1fe39c6c8cca6e94844ed3152a8fd \ + --hash=sha256:8661c94e3aad18e1ea17a11f60f843a4933ccaf1a25a7c6a9182af70610b2313 \ + --hash=sha256:8ae0fd135e0b157365ac7cc31fff27f07a5572bdfc38f9c2d43b2aff416cc8b0 \ + --hash=sha256:910b47a6d0635ec1bd53b88f86120a52bf56dcc27b51f18c7b4a2e2224c29f0f \ + --hash=sha256:913cc1d311060b1d409e609947fa1b9753701dac96e6581b58afc36b7ee35af6 \ + --hash=sha256:920b0911bb2e4414c50e55bd658baeb78281a47feeb064ab40c2b66ecba85553 \ + --hash=sha256:950802d17a33c07cba7fd7c3dcfa7d64705509206be1606f196d179e539111ed \ + --hash=sha256:981707f6b31b59c0c24bcda52e5605f9701cb46da4b86c2e8023656ad3e833cb \ + --hash=sha256:98ce7fb5b8063cfdd86596b9c762bf2b5e35a2cdd7e967494ab78a1fa7f8b86e \ + --hash=sha256:99f4a9ee60eed1385a86e82288971a51e71df052ed0b2900ed30bc840c0f2e39 \ + --hash=sha256:9a8e06c7a980869ea67bbf551283bbed2856915f0a792dc32dd0f9dd2fb56728 \ + --hash=sha256:ae8ce252404cdd4de56dcfce8b11eac3c594a9c16c231d081fb705cf23bd4d9e \ + --hash=sha256:afd9c680df4de71cd58582b51e88a61feed4abcc7530bcd3d48483f20fc76f2a \ + --hash=sha256:b49742cdb85f1f81e4dc1b39dcf328244f4d8d1ded95dea725b316bd2cf18c95 \ + --hash=sha256:b5613cfeb1adfe791e8e681128f5f49f22f3fcaa942255a6124d58ca59d9528f \ + --hash=sha256:bab7c09454460a487e631ffc0c42057e3d8f2a9ddccd1e60c7bb8ed774992480 \ + --hash=sha256:c8a0e34993b510fc19b9a2ce7f31cb8e94ecf6e924a40c0c9dd4f62d0aac47d9 \ + --hash=sha256:caf5d284ddea7462c32b8d4a6b8af030b6c9fd5332afb70e7414d7fdded4bfd0 \ + --hash=sha256:cea427d1350f3fd0d2818ce7350095c1a2ee33e30961d2f0fef48576ddbbe90f \ + --hash=sha256:d0cf7d55b1051387807405b3898efafa862997b4cba8aa5dbe657be794afeafd \ + --hash=sha256:d10c39947a2d351d6d466b4ae83dad4c37cd6c3cdd6d5d0fa797da56f710a6ae \ + --hash=sha256:d2b9cd92c8f8e7b313b80e93cedc12c0112088541dcedd9197b5dee3738c1201 \ + --hash=sha256:d4c57b68c8ef5e1ebf47238e99bf27657511ec3f071c465f6b1bccbef12d4136 \ + --hash=sha256:d51fc141ddbe3f919e91a096ec739f49d686df8af254b2053ba21a910ae518bf \ + --hash=sha256:e097507396c0be4e547ff15b13dc3866f45f3680f789c1a1301b07dadd3fbc78 \ + --hash=sha256:e30356d530528a42eeba51420ae8bf6c6c09559051887196599d96ee5f536468 \ + --hash=sha256:e8d5f8a8e3bc87334f025194c6193e408903d21ebaeb10952264943a985066ca \ + --hash=sha256:e8dfa9e94fc127c40979c3eacbae1e61fda4fe71d84869cc129e2721973231ef \ + --hash=sha256:f212d4f46b67ff604d11fff7cc62d36b3e8714edf68e44e9760e19be38c03eb0 \ + --hash=sha256:f7506387e191fe8cdb267f912469a3cccc538ab108471291636a96a54e599556 \ + --hash=sha256:fac6e277a41163d27dfab5f4ec1f7a83fac94e170665a4a50191b545721c6521 \ + --hash=sha256:fcd8f556cdc8cfe35e70efb92463082b7f43dd7e547eb071ffc36abc0ca4699b # via # -r ci/official/requirements_updater/requirements.in # h5py @@ -526,6 +534,7 @@ six==1.16.0 \ --hash=sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254 # via # astunparse + # google-pasta # tb-nightly tb-nightly==2.18.0a20240611 \ --hash=sha256:c299eb7dc3de22c7164a1b0c0091b784f2214d65b9a8b967eeeba9818314016d diff --git a/requirements_lock_3_9.txt b/requirements_lock_3_9.txt index 5b0533315ddd82..a0b508c012fdd1 100644 --- a/requirements_lock_3_9.txt +++ b/requirements_lock_3_9.txt @@ -530,6 +530,7 @@ six==1.16.0 \ --hash=sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254 # via # astunparse + # google-pasta # tb-nightly tb-nightly==2.18.0a20240611 \ --hash=sha256:c299eb7dc3de22c7164a1b0c0091b784f2214d65b9a8b967eeeba9818314016d diff --git a/tensorflow/tools/pip_package/setup.py b/tensorflow/tools/pip_package/setup.py index 125c3cfd13a57c..1145b2fd83ca17 100644 --- a/tensorflow/tools/pip_package/setup.py +++ b/tensorflow/tools/pip_package/setup.py @@ -111,9 +111,7 @@ def standard_or_nightly(standard, nightly): # 'keras >= 2.14.0rc0, < 2.15' on the release branch after the branch cut. 'tb-nightly ~= 2.18.0.a', 'keras-nightly >= 3.2.0.dev', - # TODO(b/367877753): Update the upper bound to <2.2.0 once the compatibility - # issues with numpy 2.1.0 is fixed. - 'numpy >= 1.26.0, < 2.1.0', + 'numpy >= 1.26.0, < 2.2.0', 'h5py >= 3.11.0', 'ml_dtypes >= 0.4.0, < 0.5.0', ] From 34f0dcfa27d90b61a26ee540b9514ffe8fe025e5 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 24 Sep 2024 14:25:10 -0700 Subject: [PATCH 209/483] Add ConvertS8ToF32 function to support literal dtype conversion from int8 to f32. PiperOrigin-RevId: 678396439 --- third_party/xla/xla/literal_util.cc | 5 +++++ third_party/xla/xla/literal_util.h | 1 + 2 files changed, 6 insertions(+) diff --git a/third_party/xla/xla/literal_util.cc b/third_party/xla/xla/literal_util.cc index 745194cdc24b39..2330aca215483b 100644 --- a/third_party/xla/xla/literal_util.cc +++ b/third_party/xla/xla/literal_util.cc @@ -229,6 +229,11 @@ void SetScalarAtIndexImpl(MutableLiteralBase& literal, ShapeUtil::MakeShape(primitive_type, dimensions)); } +/* static */ Literal LiteralUtil::ConvertS8ToF32( + const LiteralSlice& s8_literal) { + return ConvertType(s8_literal); +} + /* static */ Literal LiteralUtil::ConvertBF16ToF32( const LiteralSlice& bf16_literal) { return ConvertType(bf16_literal); diff --git a/third_party/xla/xla/literal_util.h b/third_party/xla/xla/literal_util.h index a19ed6fb1e529e..1048682e2d5f4e 100644 --- a/third_party/xla/xla/literal_util.h +++ b/third_party/xla/xla/literal_util.h @@ -239,6 +239,7 @@ class LiteralUtil { // If the given literal's data type is , converts it to a // literal; otherwise, returns a copy of it. If the literal is a tuple, // recursively converts its elements. + static Literal ConvertS8ToF32(const LiteralSlice& s8_literal); static Literal ConvertBF16ToF32(const LiteralSlice& bf16_literal); static Literal ConvertBF16ToF64(const LiteralSlice& bf16_literal); static Literal ConvertF32ToF8E4M3FNUZ(const LiteralSlice& f32_literal); From 48e9f5b978762176dc027745d8407e1c17646ca8 Mon Sep 17 00:00:00 2001 From: Frederik Gossen Date: Tue, 24 Sep 2024 15:16:52 -0700 Subject: [PATCH 210/483] Relax verifier to allow for partially pipelined async collectives PiperOrigin-RevId: 678414993 --- third_party/xla/xla/service/hlo_verifier.cc | 140 ++-------- .../xla/xla/service/hlo_verifier_test.cc | 264 ++++++++++-------- 2 files changed, 167 insertions(+), 237 deletions(-) diff --git a/third_party/xla/xla/service/hlo_verifier.cc b/third_party/xla/xla/service/hlo_verifier.cc index 0b7e4bd9f916b1..6578b4ff765ec9 100644 --- a/third_party/xla/xla/service/hlo_verifier.cc +++ b/third_party/xla/xla/service/hlo_verifier.cc @@ -2439,40 +2439,8 @@ absl::Status VerifyChannels(const HloModule& module, absl::flat_hash_map> channel_instructions; - // For Async operations, we need to make sure: - // (1) AsyncStart and AsyncDone are used in pairs - // (2) AsynStart and Asyndone are connected, that is, an AsynDone has an - // AsyncStart as its only operand, and an AsynStart has an AsyncDone as - // its only user - // (3) the channel ID used by a pair of Async operations is unique - // - // Send and SendDone, Recv and RecvDone are such pairs of Async operations. - // Different from other Async operations, a channel ID can be used by one - // Send-SendDone pair and one Recv-RecvDone pair. As such, we verify the - // above three invariants for Send/Recv related instructions with adjustment - // to (3): - // (3*) the channel ID used by a pair of Send-SendDone can be shared by at - // most one pair of Recv-RecvDone. - // - // Currently, the GPU compiler can decomposed collective-permute into a group - // of instructions with a pair of Send-SendDone and a pair of Recv-RecvDone - // that use the same channel ID. When a while-body contains such instructions, - // the GPU compiler can also peel off Send and Recv, and statically order - // SendDone/RecvDone inside the while-body before Send/Recv. This breaks - // invariants (2) and (3*) for the pipelined Send/Recv case. We verify the - // following for a group of instructions using the same channel ID but don't - // satisfy invariants (1)(2)(3*): - // (4) All instructions in the group are annotated with frontend attributes. - // We avoid verifying the content of such a frontend attribute to avoid - // making the general HLO instruction verifier depend on the compiler pass - // that performs the transformation. - // (5) the group should contain equal number uses of each Send/Recv related - // instructions. - // - // Comparing the verification of unpipelined Send/Recv with the verification - // of pipelined, what we missing verifying is that the direct connection - // between Send/Recv and SendDone/RecvDone through operands. - // + // Send/recv instruction must have a unique user. If it is the corresponding + // send-done/recv-done operation, channel IDs must match. for (const HloComputation* computation : module.computations()) { for (const HloInstruction* instruction : computation->instructions()) { auto channel_instr = DynCast(instruction); @@ -2483,55 +2451,27 @@ absl::Status VerifyChannels(const HloModule& module, switch (instruction->opcode()) { case HloOpcode::kSend: { - bool pipelined = true; - if (instruction->users().size() == 1) { - const HloInstruction* send_user = instruction->users().front(); - if (send_user->opcode() == HloOpcode::kSendDone) { - TF_RETURN_IF_ERROR(CheckSameChannel(instruction, send_user)); - TF_RETURN_IF_ERROR( - CheckSameIsHostTransfer(instruction, send_user)); - pipelined = false; - } + TF_RET_CHECK(instruction->users().size() == 1); + const HloInstruction* send_done = instruction->users().front(); + if (send_done->opcode() == HloOpcode::kSendDone) { + TF_RETURN_IF_ERROR(CheckSameChannel(instruction, send_done)); + TF_RETURN_IF_ERROR(CheckSameIsHostTransfer(instruction, send_done)); } - // Pipelined Send should be annotated with frontend attributes. - TF_RET_CHECK(pipelined == false || - !instruction->frontend_attributes().map().empty()); break; } case HloOpcode::kRecv: { - bool pipelined = true; - if (instruction->users().size() == 1) { - const HloInstruction* recv_user = instruction->users().front(); - if (recv_user->opcode() == HloOpcode::kRecvDone) { - TF_RETURN_IF_ERROR(CheckSameChannel(instruction, recv_user)); - TF_RETURN_IF_ERROR( - CheckSameIsHostTransfer(instruction, recv_user)); - pipelined = false; - } + TF_RET_CHECK(instruction->users().size() == 1); + const HloInstruction* recv_done = instruction->users().front(); + if (recv_done->opcode() == HloOpcode::kRecvDone) { + TF_RETURN_IF_ERROR(CheckSameChannel(instruction, recv_done)); + TF_RETURN_IF_ERROR(CheckSameIsHostTransfer(instruction, recv_done)); } - // Pipelined Recv should be annotated with frontend attributes. - TF_RET_CHECK(pipelined == false || - !instruction->frontend_attributes().map().empty()); break; } - case HloOpcode::kSendDone: { + case HloOpcode::kSendDone: + case HloOpcode::kRecvDone: TF_RET_CHECK(instruction->operands().size() == 1); - const HloInstruction* send_done_operand = instruction->operand(0); - // If the operand is not a Send, the Send-done is pipelined and should - // have frontend attributes. - TF_RET_CHECK(send_done_operand->opcode() == HloOpcode::kSend || - !instruction->frontend_attributes().map().empty()); break; - } - case HloOpcode::kRecvDone: { - TF_RET_CHECK(instruction->operands().size() == 1); - const HloInstruction* recv_done_operand = instruction->operand(0); - // If the operand is not a Recv, the Recv-done is pipelined and should - // have frontend attributes. - TF_RET_CHECK(recv_done_operand->opcode() == HloOpcode::kRecv || - !instruction->frontend_attributes().map().empty()); - break; - } default: break; } @@ -2542,55 +2482,19 @@ absl::Status VerifyChannels(const HloModule& module, for (auto& pair : channel_instructions) { auto& instructions = pair.second; const HloInstruction* first = instructions[0]; - auto sendrecv = DynCast(first); - if (sendrecv) { - // Check that all instructions are Send/Recv related and count the - // appearance of each opcode in the group. - absl::flat_hash_map opcode_to_count; + if (const auto* sendrecv = DynCast(first)) { + absl::flat_hash_set opcodes; for (const HloInstruction* instr : instructions) { - auto it = opcode_to_count.find(instr->opcode()); - if (it != opcode_to_count.end()) { - it->second++; - } else { - opcode_to_count[instr->opcode()] = 1; - } - if (opts.verify_unique_channel_ids) { - TF_RET_CHECK(DynCast(instr) != nullptr) - << "channel " << pair.first - << " is used for different types of channel instructions"; - } + opcodes.insert(instr->opcode()); + auto cast = DynCast(instr); + TF_RET_CHECK(cast != nullptr) + << "channel " << pair.first + << " is used for different types of channel instructions"; } - - int count = opcode_to_count.begin()->second; - bool consistent_count = - absl::c_all_of(opcode_to_count, [count](const auto& opcode_count) { - return opcode_count.second == count; - }); - // A pipelined group of Send/Recv should all have frontend attributes. - bool maybe_pipelined = - absl::c_all_of(instructions, [](const HloInstruction* inst) { - return !inst->frontend_attributes().map().empty(); - }); - if (sendrecv->is_host_transfer()) { - TF_RET_CHECK(consistent_count && count == 1 && instructions.size() == 2) + TF_RET_CHECK(instructions.size() == 2) << "channel " << pair.first << " is used for multiple host send/recv instructions"; - } else { - if (consistent_count && count == 1) { - TF_RET_CHECK(instructions.size() == opcode_to_count.size()) - << "channel " << pair.first - << " is used for multiple send/recv instructions"; - } else { - TF_RET_CHECK(maybe_pipelined) << "channel " << pair.first - << " is used for multiple send/recv " - "instructions but not pipelined"; - TF_RET_CHECK(consistent_count && opcode_to_count.size() % 2 == 0) - << "channel " << pair.first - << " is pipelined. Not all Send/Recv related instructions are" - " used the same number of times or channel is used for other " - "instructions"; - } } } else { for (const HloInstruction* instr : instructions) { diff --git a/third_party/xla/xla/service/hlo_verifier_test.cc b/third_party/xla/xla/service/hlo_verifier_test.cc index 877c445c6f5aa8..1737bea0eca27b 100644 --- a/third_party/xla/xla/service/hlo_verifier_test.cc +++ b/third_party/xla/xla/service/hlo_verifier_test.cc @@ -997,7 +997,7 @@ TEST_F(HloVerifierTest, CopyStartMultipleCopyDone) { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(hlo_string)); - auto status = verifier().Run(module.get()).status(); + absl::Status status = verifier().Run(module.get()).status(); ASSERT_FALSE(status.ok()); EXPECT_THAT( status.message(), @@ -2263,93 +2263,153 @@ TEST_F(HloVerifierTest, ChannelVerifier) { HasSubstr("used for different types of channel instructions")); } -TEST_F(HloVerifierTest, ChannelVerifierPipelinedMissingDones) { +TEST_F(HloVerifierTest, ChannelVerifierPartiallyPipelinedAsyncRecv) { const char* const kModuleStr = R"( - HloModule test - cond { - param = (u32[], (u32[2], u32[], token[]), (u32[2], u32[], token[])) parameter(0) - count = get-tuple-element(%param), index=0 - ub = u32[] constant(1) - ROOT result = pred[] compare(count, ub), direction=LT - } - - body { - param = (u32[], (u32[2], u32[], token[]), (u32[2], u32[], token[])) parameter(0) - count = get-tuple-element(%param), index=0 - - recv.0 = (u32[2], u32[], token[]) get-tuple-element(param), index=1 - recv-done.0 = (u32[2], token[]) recv-done(recv.0), channel_id=1, - frontend_attributes={ - _xla_send_recv_pipeline="0" - } - recv-data.0 = u32[2] get-tuple-element(recv-done.0), index=0 - - c1 = u32[] constant(1) - new_count = u32[] add(count, c1) - - send.0 = (u32[2], u32[], token[]) get-tuple-element(param), index=2 - send-done.0 = (u32[2], token[]) recv-done(send.0), channel_id=1, - frontend_attributes={ - _xla_send_recv_pipeline="0" - } - - after-all.0.n = token[] after-all() - recv.0.n = (u32[2], u32[], token[]) recv(after-all.0.n), channel_id=1, - frontend_attributes={ - _xla_send_recv_source_target_pairs="{{1,0}}", - _xla_send_recv_pipeline="0" - } - - - after-all.1.n = token[] after-all() - send.0.n = (u32[2], u32[], token[]) send(recv-data.0, after-all.1.n), - channel_id=1, - frontend_attributes={ - _xla_send_recv_source_target_pairs="{{1,0}}", - _xla_send_recv_pipeline="0" - } - - ROOT result = (u32[], (u32[2], u32[], token[]), (u32[2], u32[], token[])) - tuple(new_count, recv.0.n, send.0.n) - } - - ENTRY test_computation { - c0 = u32[] constant(0) - init = u32[2] broadcast(c0), dimensions={} - after-all.0.p = token[] after-all() - recv.0.p = (u32[2], u32[], token[]) recv(after-all.0.p), channel_id=1, - frontend_attributes={ - _xla_send_recv_source_target_pairs="{{1,0}}", - _xla_send_recv_pipeline="0" - } - - after-all.1.p = token[] after-all() - send.0.p = (u32[2], u32[], token[]) send(init, after-all.1.p), - channel_id=1, - frontend_attributes={ - _xla_send_recv_source_target_pairs="{{1,0}}", - _xla_send_recv_pipeline="0" - } - - while_init = (u32[], (u32[2], u32[], token[]), (u32[2], u32[], token[])) - tuple(c0, recv.0.p, send.0.p) - while_result = (u32[], (u32[2], u32[], token[]), (u32[2], u32[], token[])) - while(while_init), body=body, condition=cond - - recv.0.q = (u32[2], u32[], token[]) get-tuple-element(while_result), index=1 - recv-done.0.q = (u32[2], token[]) recv-done(recv.0.q), channel_id=1, - frontend_attributes={ - _xla_send_recv_pipeline="0" - } - - ROOT recv-data.0.q = u32[2] get-tuple-element(recv-done.0.q), index=0 - })"; + HloModule test + + while_body { + param = ((f32[16], u32[], token[])) parameter(0) + prev_recv = (f32[16], u32[], token[]) get-tuple-element(param), index=0 + recv_done = (f32[16], token[]) recv-done(prev_recv), channel_id=1 + after_all = token[] after-all() + recv = (f32[16], u32[], token[]) recv(after_all), channel_id=1, + frontend_attributes={ + _xla_send_recv_source_target_pairs={{0,1},{1,2},{2,3}}} + ROOT tuple = ((f32[16], u32[], token[])) tuple(recv) + } + + // Infinite loop to keep IR small. + while_condition { + param = ((f32[16], u32[], token[])) parameter(0) + ROOT infinite_loop = pred[] constant(true) + } + + ENTRY main_spmd { + after_all = token[] after-all() + recv = (f32[16], u32[], token[]) recv(after_all), channel_id=1, + frontend_attributes={ + _xla_send_recv_source_target_pairs={{0,1},{1,2},{2,3}}} + init = ((f32[16], u32[], token[])) tuple(recv) + while = ((f32[16], u32[], token[])) while(init), + condition=while_condition, body=while_body + recv_ctx = (f32[16], u32[], token[]) get-tuple-element(while), index=0 + recv_done = (f32[16], token[]) recv-done(recv_ctx), channel_id=1 + ROOT result = f32[16] get-tuple-element(recv_done), index=0 + })"; TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kModuleStr)); - EXPECT_THAT( - verifier().Run(module.get()).status().message(), - HasSubstr("is pipelined. Not all Send/Recv related instructions are used" - " the same number of times")); + TF_ASSERT_OK(verifier().Run(module.get())); +} + +TEST_F(HloVerifierTest, ChannelVerifierPartiallyPipelinedAsyncSend) { + const char* const kModuleStr = R"( + HloModule test + + while_body { + param = ((f32[16], u32[], token[]), f32[16]) parameter(0) + prev_send = (f32[16], u32[], token[]) get-tuple-element(param), index=0 + data = f32[16] get-tuple-element(param), index=1 + send_done = (f32[16], token[]) send-done(prev_send), channel_id=1 + after_all = token[] after-all() + send = (f32[16], u32[], token[]) send(data, after_all), channel_id=1, + frontend_attributes={ + _xla_send_send_source_target_pairs={{0,1},{1,2},{2,3}}} + ROOT tuple = ((f32[16], u32[], token[]), f32[16]) tuple(send, data) + } + + // Infinite loop to keep IR small. + while_condition { + param = ((f32[16], u32[], token[]), f32[16]) parameter(0) + ROOT infinite_loop = pred[] constant(true) + } + + ENTRY main_spmd { + data = f32[16] parameter(0) + after_all = token[] after-all() + send = (f32[16], u32[], token[]) send(data, after_all), channel_id=1, + frontend_attributes={ + _xla_send_send_source_target_pairs={{0,1},{1,2},{2,3}}} + init = ((f32[16], u32[], token[]), f32[16]) tuple(send, data) + while = ((f32[16], u32[], token[]), f32[16]) while(init), + condition=while_condition, body=while_body + send_ctx = (f32[16], u32[], token[]) get-tuple-element(while), index=0 + ROOT send_done = (f32[16], token[]) send-done(send_ctx), channel_id=1 + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnUnverifiedModule(kModuleStr)); + TF_ASSERT_OK(verifier().Run(module.get())); +} + +TEST_F(HloVerifierTest, ChannelVerifierAsyncSend) { + const char* const kModuleStr = R"( + HloModule test + + ENTRY main_spmd { + data = f32[16] parameter(0) + after_all = token[] after-all() + send = (f32[16], u32[], token[]) send(after_all, data), channel_id=1, + frontend_attributes={ + _xla_send_send_source_target_pairs={{0,1},{1,2},{2,3}}} + ROOT send_done = (f32[16], token[]) send-done(send), channel_id=1 + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnUnverifiedModule(kModuleStr)); + TF_ASSERT_OK(verifier().Run(module.get())); +} + +TEST_F(HloVerifierTest, ChannelVerifierAsyncRecv) { + const char* const kModuleStr = R"( + HloModule test + + ENTRY main_spmd { + after_all = token[] after-all() + recv = (f32[16], u32[], token[]) recv(after_all), channel_id=1, + frontend_attributes={ + _xla_send_send_source_target_pairs={{0,1},{1,2},{2,3}}} + recv_done = (f32[16], token[]) recv-done(recv), channel_id=1 + ROOT result = f32[16] get-tuple-element(recv_done), index=0 + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnUnverifiedModule(kModuleStr)); + TF_ASSERT_OK(verifier().Run(module.get())); +} + +TEST_F(HloVerifierTest, ChannelVerifierMultipleSendUsers) { + const char* const kModuleStr = R"( + HloModule test + + ENTRY main_spmd { + data = f32[16] parameter(0) + after_all = token[] after-all() + send = (f32[16], u32[], token[]) send(data, after_all), channel_id=1, + frontend_attributes={ + _xla_send_send_source_target_pairs={{0,1},{1,2},{2,3}}} + send_done = (f32[16], token[]) send-done(send), channel_id=1 + ROOT result = ((f32[16], u32[], token[]), f32[16]) tuple(send, send_done) + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnUnverifiedModule(kModuleStr)); + EXPECT_THAT(verifier().Run(module.get()).status().message(), + HasSubstr("send instruction requires one consumer, found 2")); +} + +TEST_F(HloVerifierTest, ChannelVerifierMultipleRecvUsers) { + const char* const kModuleStr = R"( + HloModule test + + ENTRY main_spmd { + after_all = token[] after-all() + recv = (f32[16], u32[], token[]) recv(after_all), channel_id=1, + frontend_attributes={ + _xla_send_send_source_target_pairs={{0,1},{1,2},{2,3}}} + recv_done = (f32[16], token[]) recv-done(recv), channel_id=1 + ROOT result = (((f32[16], u32[], token[])), f32[16]) + tuple(recv, recv_done) + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnUnverifiedModule(kModuleStr)); + EXPECT_THAT(verifier().Run(module.get()).status().message(), + HasSubstr("recv instruction requires one consumer, found 2")); } TEST_F(HloVerifierTest, CollectiveChannelVerifier) { @@ -3428,39 +3488,5 @@ TEST_F(HloVerifierTest, NoErrorOnDuplicateChannelId) { ASSERT_IS_OK(verifier.Run(module.get()).status()); } -TEST_F(HloVerifierTest, ChannelVerifierAsyncSend) { - const char* const kModuleStr = R"( - HloModule test - - ENTRY main_spmd { - data = f32[16] parameter(0) - after_all = token[] after-all() - send = (f32[16], u32[], token[]) send(after_all, data), channel_id=1, - frontend_attributes={ - _xla_send_send_source_target_pairs={{0,1},{1,2},{2,3}}} - ROOT send_done = (f32[16], token[]) send-done(send), channel_id=1 - })"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnUnverifiedModule(kModuleStr)); - TF_ASSERT_OK(verifier().Run(module.get())); -} - -TEST_F(HloVerifierTest, ChannelVerifierAsyncRecv) { - const char* const kModuleStr = R"( - HloModule test - - ENTRY main_spmd { - after_all = token[] after-all() - recv = (f32[16], u32[], token[]) recv(after_all), channel_id=1, - frontend_attributes={ - _xla_send_send_source_target_pairs={{0,1},{1,2},{2,3}}} - recv_done = (f32[16], token[]) recv-done(recv), channel_id=1 - ROOT result = f32[16] get-tuple-element(recv_done), index=0 - })"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnUnverifiedModule(kModuleStr)); - TF_ASSERT_OK(verifier().Run(module.get())); -} - } // namespace } // namespace xla From 2bd80ea311a6f745bce72a375b5538d6413f66e2 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 24 Sep 2024 15:18:50 -0700 Subject: [PATCH 211/483] Copy versioning tools to compiler PiperOrigin-RevId: 678415782 --- tensorflow/compiler/mlir/lite/BUILD | 2 +- .../compiler/mlir/lite/flatbuffer_export.cc | 4 +- .../compiler/mlir/lite/tools/versioning/BUILD | 87 + .../lite/tools/versioning/op_signature.cc | 264 +++ .../mlir/lite/tools/versioning/op_signature.h | 96 ++ .../tools/versioning/op_signature_test.cc | 99 ++ .../mlir/lite/tools/versioning/op_version.cc | 1134 +++++++++++++ .../mlir/lite/tools/versioning/op_version.h | 33 + .../lite/tools/versioning/op_version_test.cc | 1435 +++++++++++++++++ .../lite/tools/versioning/runtime_version.cc | 510 ++++++ .../lite/tools/versioning/runtime_version.h | 40 + tensorflow/lite/CMakeLists.txt | 2 + tensorflow/lite/kernels/CMakeLists.txt | 1 + tensorflow/lite/toco/tflite/BUILD | 6 +- .../lite/toco/tflite/builtin_operator.h | 1 + tensorflow/lite/toco/tflite/export.cc | 2 +- tensorflow/lite/toco/tflite/operator.cc | 4 +- tensorflow/lite/toco/tflite/operator.h | 3 +- tensorflow/lite/toco/tflite/simple_operator.h | 1 + .../lite/tools/benchmark/CMakeLists.txt | 1 + tensorflow/lite/tools/versioning/BUILD | 1 + .../lite/tools/versioning/op_signature.cc | 241 +-- .../lite/tools/versioning/op_signature.h | 74 +- 23 files changed, 3721 insertions(+), 320 deletions(-) create mode 100644 tensorflow/compiler/mlir/lite/tools/versioning/BUILD create mode 100644 tensorflow/compiler/mlir/lite/tools/versioning/op_signature.cc create mode 100644 tensorflow/compiler/mlir/lite/tools/versioning/op_signature.h create mode 100644 tensorflow/compiler/mlir/lite/tools/versioning/op_signature_test.cc create mode 100644 tensorflow/compiler/mlir/lite/tools/versioning/op_version.cc create mode 100644 tensorflow/compiler/mlir/lite/tools/versioning/op_version.h create mode 100644 tensorflow/compiler/mlir/lite/tools/versioning/op_version_test.cc create mode 100644 tensorflow/compiler/mlir/lite/tools/versioning/runtime_version.cc create mode 100644 tensorflow/compiler/mlir/lite/tools/versioning/runtime_version.h diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD index 3da89f496218a3..b938f25519809b 100644 --- a/tensorflow/compiler/mlir/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/BUILD @@ -1329,6 +1329,7 @@ cc_library( "//tensorflow/compiler/mlir/lite/schema:schema_conversion_utils", "//tensorflow/compiler/mlir/lite/schema:schema_fbs", "//tensorflow/compiler/mlir/lite/schema:schema_fbs_with_mutable", + "//tensorflow/compiler/mlir/lite/tools/versioning", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:convert_tensor", "//tensorflow/compiler/mlir/tensorflow:dynamic_shape_utils", @@ -1337,7 +1338,6 @@ cc_library( "//tensorflow/core:framework", "//tensorflow/core:portable_gif_internal", "//tensorflow/core:protos_all_cc", - "//tensorflow/lite/tools/versioning", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc index fad3e5c1409372..869173bb2f8d99 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc @@ -96,6 +96,8 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/schema/mutable/schema_generated.h" #include "tensorflow/compiler/mlir/lite/schema/schema_conversion_utils.h" #include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" +#include "tensorflow/compiler/mlir/lite/tools/versioning/op_version.h" +#include "tensorflow/compiler/mlir/lite/tools/versioning/runtime_version.h" #include "tensorflow/compiler/mlir/lite/utils/control_edges.h" #include "tensorflow/compiler/mlir/lite/utils/convert_type.h" #include "tensorflow/compiler/mlir/lite/utils/low_bit_utils.h" @@ -117,8 +119,6 @@ limitations under the License. #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/platform/tstring.h" -#include "tensorflow/lite/tools/versioning/op_version.h" -#include "tensorflow/lite/tools/versioning/runtime_version.h" #include "tsl/platform/fingerprint.h" #include "tsl/platform/status.h" #include "tsl/platform/tstring.h" diff --git a/tensorflow/compiler/mlir/lite/tools/versioning/BUILD b/tensorflow/compiler/mlir/lite/tools/versioning/BUILD new file mode 100644 index 00000000000000..8cb1f84debe84d --- /dev/null +++ b/tensorflow/compiler/mlir/lite/tools/versioning/BUILD @@ -0,0 +1,87 @@ +load( + "//tensorflow:tensorflow.bzl", + "tf_cc_test", +) +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = ["//visibility:public"], + licenses = ["notice"], +) + +cc_library( + name = "versioning", + srcs = [ + "op_version.cc", + "runtime_version.cc", + ], + hdrs = [ + "op_version.h", + "runtime_version.h", + ], + compatible_with = get_compatible_with_portable(), + deps = [ + ":op_signature", + "//tensorflow/compiler/mlir/lite/core/c:tflite_common", + "//tensorflow/compiler/mlir/lite/kernels/internal:compatibility_macros", + "//tensorflow/compiler/mlir/lite/schema:schema_fbs", + "//tensorflow/compiler/mlir/lite/schema:schema_fbs_with_mutable", + "//tensorflow/compiler/mlir/lite/schema:schema_utils", + "@com_google_absl//absl/log", + "@com_google_absl//absl/strings", + "@flatbuffers", + ], +) + +tf_cc_test( + name = "versioning_test", + srcs = [ + "op_version_test.cc", + ], + deps = [ + ":op_signature", + ":versioning", + "//tensorflow/compiler/mlir/lite/core/c:tflite_common", + "//tensorflow/compiler/mlir/lite/schema:schema_fbs", + "//tensorflow/compiler/mlir/lite/schema:schema_fbs_with_mutable", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "op_signature", + srcs = [ + "op_signature.cc", + ], + hdrs = [ + "op_signature.h", + ], + compatible_with = get_compatible_with_portable(), + deps = [ + "//tensorflow/compiler/mlir/lite/core/api:flatbuffer_conversions", + "//tensorflow/compiler/mlir/lite/core/c:tflite_common", + "//tensorflow/compiler/mlir/lite/schema:schema_fbs", + "//tensorflow/compiler/mlir/lite/schema:schema_utils", + "@flatbuffers//:runtime_cc", + ], +) + +tf_cc_test( + name = "op_signature_test", + srcs = [ + "op_signature_test.cc", + ], + data = [ + "//tensorflow/lite:testdata/add.bin", + "//tensorflow/lite:testdata/multi_signatures.bin", + ], + deps = [ + ":op_signature", + "//tensorflow/compiler/mlir/lite/core:absl_error_model_builder", + "//tensorflow/compiler/mlir/lite/core/c:tflite_common", + "//tensorflow/compiler/mlir/lite/schema:schema_fbs", + "//tensorflow/core/platform:resource_loader", + "@com_google_googletest//:gtest_main", + ], +) diff --git a/tensorflow/compiler/mlir/lite/tools/versioning/op_signature.cc b/tensorflow/compiler/mlir/lite/tools/versioning/op_signature.cc new file mode 100644 index 00000000000000..13a71443f6f0b5 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/tools/versioning/op_signature.cc @@ -0,0 +1,264 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/compiler/mlir/lite/tools/versioning/op_signature.h" + +#include +#include +#include +#include + +#include "flatbuffers/vector.h" // from @flatbuffers +#include "tensorflow/compiler/mlir/lite/core/api/flatbuffer_conversions.h" +#include "tensorflow/compiler/mlir/lite/core/c/tflite_types.h" +#include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" +#include "tensorflow/compiler/mlir/lite/schema/schema_utils.h" + +namespace tflite { +namespace { + +// A BuiltinDataAllocator which just uses malloc()/free(). +class MallocDataAllocator : public tflite_migration::BuiltinDataAllocator { + public: + void* Allocate(size_t size, size_t alignment_hint) override { + return malloc(size); + } + void Deallocate(void* data) override { free(data); } +}; + +// Get the number of dimensions of a tensor with idx of an operator op. +inline int GetNumDims(const SubGraph* subgraph, const Operator* op, int idx) { + const flatbuffers::Vector* ret = + subgraph->tensors()->Get(op->inputs()->Get(idx))->shape(); + if (ret) { + return ret->size(); + } else { + return 0; + } +} + +std::vector GetOpSignatureTensorSpecs( + const flatbuffers::Vector* tensors, const SubGraph* subgraph, + const Model* model) { + std::vector tensor_specs; + if (!tensors) { + return tensor_specs; + } + + for (int32_t i = 0; i < tensors->size(); ++i) { + int32_t tensor_no = tensors->Get(i); + + OpSignatureTensorSpec tensor_spec = {kTfLiteNoType}; + if (tensor_no >= 0) { + if (subgraph->tensors() && tensor_no < subgraph->tensors()->size()) { + auto* fb_tensor = subgraph->tensors()->Get(tensor_no); + tflite_migration::ConvertTensorType(fb_tensor->type(), + &tensor_spec.type) + .IgnoreError(); + auto buffer_idx = fb_tensor->buffer(); + // Check if the tensor is a constant tensor. + if (buffer_idx != 0 && buffer_idx < model->buffers()->size()) { + auto* buffer = model->buffers()->Get(buffer_idx); + if (buffer->data() && buffer->data()->size() != 0) { + tensor_spec.is_const = true; + } + } + const flatbuffers::Vector* shape_vec = fb_tensor->shape(); + if (shape_vec) { + for (int32_t j = 0; j < shape_vec->size(); ++j) { + tensor_spec.dims.push_back(shape_vec->Get(j)); + } + } + const flatbuffers::Vector* shape_signature_vec = + fb_tensor->shape_signature(); + tensor_spec.is_shape_dynamic = false; + if (shape_signature_vec) { + for (int32_t j = 0; j < shape_signature_vec->size(); ++j) { + if (shape_signature_vec->Get(j) == -1) { + tensor_spec.is_shape_dynamic = true; + break; + } + } + } + } + } + tensor_specs.push_back(tensor_spec); + } + return tensor_specs; +} + +} // namespace + +OpSignature GetOpSignature(const OperatorCode* op_code, const Operator* op, + const SubGraph* subgraph, const Model* model) { + auto builtin_code = GetBuiltinCode(op_code); + OpSignature op_sig = {builtin_code}; + std::memset(&op_sig.ext_options, 0, sizeof(op_sig.ext_options)); + + if (builtin_code != BuiltinOperator_CUSTOM) { + MallocDataAllocator allocator; + tflite_migration::ParseOpData(op, builtin_code, &allocator, + &op_sig.builtin_data) + .IgnoreError(); + } else { + op_sig.custom_name = op_code->custom_code()->str(); + } + + switch (builtin_code) { + case BuiltinOperator_DEPTHWISE_CONV_2D: { + const Tensor* filter_tensor = + subgraph->tensors()->Get(op->inputs()->Get(1)); + const QuantizationParameters* filter_quant = + filter_tensor->quantization(); + int num_channels = filter_tensor->shape()->Get(3); + if (filter_quant && filter_quant->scale() && + filter_quant->scale()->size() && + filter_quant->scale()->size() == num_channels) { + op_sig.ext_options.depthwise_conv_2d.is_per_channel_quantized = true; + } + } break; + + case BuiltinOperator_FULLY_CONNECTED: { + const Tensor* weight_tensor = + subgraph->tensors()->Get(op->inputs()->Get(1)); + op_sig.ext_options.fully_connected.sparse_weight = + (weight_tensor->sparsity() != nullptr); + const QuantizationParameters* weight_quant = + weight_tensor->quantization(); + if (weight_quant && weight_quant->scale() && + weight_quant->scale()->size() && weight_tensor->shape() && + weight_tensor->shape()->size()) { + op_sig.ext_options.fully_connected.is_per_channel_quantized = + weight_quant->scale()->size() > 1 && + weight_quant->scale()->size() == weight_tensor->shape()->Get(0); + } + } break; + + case BuiltinOperator_MUL: { + if (op->inputs()->size() < 2 || op->outputs()->size() < 1) { + break; + } + const Tensor* input1_tensor = + subgraph->tensors()->Get(op->inputs()->Get(0)); + const Tensor* input2_tensor = + subgraph->tensors()->Get(op->inputs()->Get(1)); + const Tensor* output_tensor = + subgraph->tensors()->Get(op->outputs()->Get(0)); + const QuantizationParameters* input1_quant = + input1_tensor->quantization(); + const QuantizationParameters* input2_qunt = input2_tensor->quantization(); + const QuantizationParameters* output_quant = + output_tensor->quantization(); + if (input1_quant && input1_quant->scale() && + input1_quant->scale()->size() && input2_qunt && + input2_qunt->scale() && input2_qunt->scale()->size() && + output_quant && output_quant->scale() && + output_quant->scale()->size()) { + op_sig.ext_options.mul.input1_scale = input1_quant->scale()->Get(0); + op_sig.ext_options.mul.input2_scale = input2_qunt->scale()->Get(0); + op_sig.ext_options.mul.output_scale = output_quant->scale()->Get(0); + } + if (input1_quant || input2_qunt) { + op_sig.ext_options.mul.input_quantized = true; + } + } break; + + case BuiltinOperator_CONV_2D: { + const Tensor* input_tensor = + subgraph->tensors()->Get(op->inputs()->Get(0)); + const Tensor* filter_tensor = + subgraph->tensors()->Get(op->inputs()->Get(1)); + const QuantizationParameters* filter_quant = + filter_tensor->quantization(); + int num_filters = filter_tensor->shape()->Get(0); + if (filter_quant && filter_quant->scale() && + filter_quant->scale()->size() && + filter_quant->scale()->size() == num_filters) { + op_sig.ext_options.conv_2d.is_per_channel_quantized = true; + } + if (input_tensor->shape() && input_tensor->shape()->size()) { + int num_input_channels = input_tensor->shape()->Get(3); + int num_filter_input_channels = filter_tensor->shape()->Get(3); + op_sig.ext_options.conv_2d.is_grouped_convolution = + num_input_channels != num_filter_input_channels; + } else { + op_sig.ext_options.conv_2d.is_grouped_convolution = false; + } + } break; + + case BuiltinOperator_STRIDED_SLICE: { + op_sig.ext_options.strided_slice.num_dims = GetNumDims(subgraph, op, 0); + } break; + + case BuiltinOperator_ABS: { + if (subgraph->tensors()->Get(op->inputs()->Get(0))->quantization()) { + op_sig.ext_options.abs.input_quantized = true; + } + } break; + + case BuiltinOperator_DEQUANTIZE: { + const Tensor* input_tensor = + subgraph->tensors()->Get(op->inputs()->Get(0)); + const QuantizationParameters* input_quant = input_tensor->quantization(); + if (input_quant && input_quant->scale() && + input_quant->scale()->size() > 1 && + input_quant->scale()->size() == + input_tensor->shape()->Get(input_quant->quantized_dimension())) { + op_sig.ext_options.dequantize.is_per_channel_quantized = true; + } + } break; + + case BuiltinOperator_QUANTIZE: { + const Tensor* output_tensor = + subgraph->tensors()->Get(op->outputs()->Get(0)); + const QuantizationParameters* output_quant = + output_tensor->quantization(); + if (output_quant && output_quant->scale() && + output_quant->scale()->size() > 1 && + output_quant->scale()->size() == + output_tensor->shape()->Get( + output_quant->quantized_dimension())) { + op_sig.ext_options.quantize.is_per_channel_quantized = true; + } + } break; + + case BuiltinOperator_ADD: { + if (subgraph->tensors()->Get(op->inputs()->Get(0))->quantization()) { + op_sig.ext_options.add.input_quantized = true; + } + } break; + + case BuiltinOperator_EMBEDDING_LOOKUP: { + const Tensor* table_tensor = + subgraph->tensors()->Get(op->inputs()->Get(1)); + const QuantizationParameters* table_quant = table_tensor->quantization(); + if (table_quant && table_quant->scale() && table_quant->scale()->size() && + table_tensor->shape() && table_tensor->shape()->size()) { + op_sig.ext_options.embedding_lookup.is_per_channel_quantized = + table_quant->scale()->size() > 1 && + table_quant->scale()->size() == table_tensor->shape()->Get(0); + } + } break; + + default: + break; + } + + op_sig.inputs = GetOpSignatureTensorSpecs(op->inputs(), subgraph, model); + op_sig.outputs = GetOpSignatureTensorSpecs(op->outputs(), subgraph, model); + op_sig.version = op_code->version(); + return op_sig; +} + +} // namespace tflite diff --git a/tensorflow/compiler/mlir/lite/tools/versioning/op_signature.h b/tensorflow/compiler/mlir/lite/tools/versioning/op_signature.h new file mode 100644 index 00000000000000..5799194f8770e7 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/tools/versioning/op_signature.h @@ -0,0 +1,96 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_TOOLS_VERSIONING_OP_SIGNATURE_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_TOOLS_VERSIONING_OP_SIGNATURE_H_ + +#include +#include +#include + +#include "tensorflow/compiler/mlir/lite/core/c/tflite_types.h" +#include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" + +namespace tflite { + +// OpSignature contains operator parameters for version functions. +typedef struct { + TfLiteType type; + std::vector dims; + bool is_const; + bool is_shape_dynamic; +} OpSignatureTensorSpec; + +typedef struct { + BuiltinOperator op; + std::vector inputs; + std::vector outputs; + void* builtin_data; + int version; + const void* custom_initial_data; + std::string custom_name; + union { + struct { + bool is_per_channel_quantized; + bool is_grouped_convolution; + } conv_2d; + struct { + bool is_per_channel_quantized; + } depthwise_conv_2d; + struct { + // TODO(b/156530611): Make this global when more ops support sparse + // computation. + bool sparse_weight; + bool is_per_channel_quantized; + } fully_connected; + struct { + float input1_scale; + float input2_scale; + float output_scale; + bool input_quantized; + } mul; + struct { + int32_t num_dims; + } strided_slice; + struct { + bool input_quantized; + } abs; + struct { + bool is_per_channel_quantized; + } dequantize; + struct { + bool is_per_channel_quantized; + } quantize; + struct { + bool input_quantized; + } add; + struct { + bool is_per_channel_quantized; + } embedding_lookup; + } ext_options; +} OpSignature; + +// Generate OpSignature with the given OperatorCode, Operator and Tensors (from +// SubGraph). The OpSignature will be used by GetBuiltinOperatorVersion() and +// mostly input and output tensor types are enough to figure out op version. +// But some ops (DEPTHWISE_CONV_2D, FULLY_CONNECTED, ...) require to pass their +// options to decide op version. +// +// WARNING: The caller is responsible to free the allocated +// OpSignature.builtin_data memory. +OpSignature GetOpSignature(const OperatorCode* op_code, const Operator* op, + const SubGraph* subgraph, const Model* model); + +} // namespace tflite +#endif // TENSORFLOW_COMPILER_MLIR_LITE_TOOLS_VERSIONING_OP_SIGNATURE_H_ diff --git a/tensorflow/compiler/mlir/lite/tools/versioning/op_signature_test.cc b/tensorflow/compiler/mlir/lite/tools/versioning/op_signature_test.cc new file mode 100644 index 00000000000000..e3db7b8da8ca49 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/tools/versioning/op_signature_test.cc @@ -0,0 +1,99 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/compiler/mlir/lite/tools/versioning/op_signature.h" + +#include +#include +#include +#include + +#include +#include "tensorflow/compiler/mlir/lite/core/absl_error_model_builder.h" +#include "tensorflow/compiler/mlir/lite/core/c/tflite_types.h" +#include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" +#include "tensorflow/core/platform/resource_loader.h" + +namespace tflite { + +TEST(GetOpSignature, FlatBufferModel) { + const std::string& full_path = + tensorflow::GetDataDependencyFilepath("tensorflow/lite/testdata/add.bin"); + auto fb_model = + mlir::TFL::FlatBufferModelAbslError::BuildFromFile(full_path.data()); + ASSERT_TRUE(fb_model); + auto model = fb_model->GetModel(); + auto subgraphs = model->subgraphs(); + const SubGraph* subgraph = subgraphs->Get(0); + const Operator* op1 = subgraph->operators()->Get(0); + const OperatorCode* op_code1 = + model->operator_codes()->Get(op1->opcode_index()); + OpSignature op_sig = GetOpSignature(op_code1, op1, subgraph, model); + EXPECT_EQ(op_sig.op, BuiltinOperator_ADD); + EXPECT_EQ(op_sig.inputs[0].type, kTfLiteFloat32); + EXPECT_EQ(op_sig.inputs[0].dims.size(), 4); + EXPECT_FALSE(op_sig.inputs[0].is_const); + EXPECT_FALSE(op_sig.inputs[0].is_shape_dynamic); + EXPECT_EQ(op_sig.outputs[0].type, kTfLiteFloat32); + EXPECT_FALSE(op_sig.outputs[0].is_const); + EXPECT_EQ(op_sig.outputs[0].dims.size(), 4); + EXPECT_FALSE(op_sig.outputs[0].is_shape_dynamic); + EXPECT_NE(op_sig.builtin_data, nullptr); + EXPECT_EQ(op_sig.version, 1); + free(op_sig.builtin_data); + + const Operator* op2 = subgraph->operators()->Get(1); + const OperatorCode* op_code2 = + model->operator_codes()->Get(op2->opcode_index()); + op_sig = GetOpSignature(op_code2, op2, subgraph, model); + EXPECT_EQ(op_sig.op, BuiltinOperator_ADD); + EXPECT_EQ(op_sig.inputs[0].type, kTfLiteFloat32); + EXPECT_EQ(op_sig.inputs[0].dims.size(), 4); + EXPECT_FALSE(op_sig.inputs[0].is_const); + EXPECT_FALSE(op_sig.inputs[0].is_shape_dynamic); + EXPECT_EQ(op_sig.outputs[0].type, kTfLiteFloat32); + EXPECT_FALSE(op_sig.outputs[0].is_const); + EXPECT_EQ(op_sig.outputs[0].dims.size(), 4); + EXPECT_FALSE(op_sig.outputs[0].is_shape_dynamic); + EXPECT_NE(op_sig.builtin_data, nullptr); + EXPECT_EQ(op_sig.version, 1); + free(op_sig.builtin_data); + + const std::string& full_path3 = tensorflow::GetDataDependencyFilepath( + "tensorflow/lite/testdata/multi_signatures.bin"); + auto fb_model3 = + mlir::TFL::FlatBufferModelAbslError::BuildFromFile(full_path3.data()); + ASSERT_TRUE(fb_model3); + auto model3 = fb_model3->GetModel(); + auto subgraphs3 = model3->subgraphs(); + const SubGraph* subgraph3 = subgraphs3->Get(0); + const Operator* op3 = subgraph3->operators()->Get(0); + const OperatorCode* op_code3 = + model3->operator_codes()->Get(op3->opcode_index()); + op_sig = GetOpSignature(op_code3, op3, subgraph3, model3); + EXPECT_EQ(op_sig.op, BuiltinOperator_ADD); + EXPECT_EQ(op_sig.inputs[0].type, kTfLiteFloat32); + EXPECT_EQ(op_sig.inputs[0].dims.size(), 1); + EXPECT_FALSE(op_sig.inputs[0].is_const); + EXPECT_TRUE(op_sig.inputs[0].is_shape_dynamic); + EXPECT_EQ(op_sig.outputs[0].type, kTfLiteFloat32); + EXPECT_FALSE(op_sig.outputs[0].is_const); + EXPECT_EQ(op_sig.outputs[0].dims.size(), 1); + EXPECT_TRUE(op_sig.outputs[0].is_shape_dynamic); + EXPECT_NE(op_sig.builtin_data, nullptr); + EXPECT_EQ(op_sig.version, 1); + free(op_sig.builtin_data); +} + +} // namespace tflite diff --git a/tensorflow/compiler/mlir/lite/tools/versioning/op_version.cc b/tensorflow/compiler/mlir/lite/tools/versioning/op_version.cc new file mode 100644 index 00000000000000..c0c1da3a158761 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/tools/versioning/op_version.cc @@ -0,0 +1,1134 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/compiler/mlir/lite/tools/versioning/op_version.h" + +#include +#include +#include + +#include "absl/log/log.h" +#include "tensorflow/compiler/mlir/lite/core/c/builtin_op_data.h" +#include "tensorflow/compiler/mlir/lite/core/c/tflite_types.h" +#include "tensorflow/compiler/mlir/lite/kernels/internal/compatibility_macros.h" +#include "tensorflow/compiler/mlir/lite/schema/mutable/schema_generated.h" +#include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" +#include "tensorflow/compiler/mlir/lite/schema/schema_utils.h" +#include "tensorflow/compiler/mlir/lite/tools/versioning/op_signature.h" + +namespace tflite { +namespace { + +bool NeedBroadcastForBinaryInputs(const OpSignature& op_sig) { + if (op_sig.inputs.size() < 2) { + return false; + } + return (op_sig.inputs.at(0).dims != op_sig.inputs.at(1).dims); +} + +int GetInputMaxDims(const OpSignature& op_sig) { + int max_dims = 0; + for (auto& input : op_sig.inputs) { + if (input.dims.size() > max_dims) { + max_dims = input.dims.size(); + } + } + return max_dims; +} + +} // namespace + +int GetBuiltinOperatorVersion(const OpSignature& op_sig) { + switch (op_sig.op) { + case BuiltinOperator_CONV_2D: { + if (op_sig.inputs.at(0).type == kTfLiteInt16 && + op_sig.inputs.at(1).type == kTfLiteInt8 && + op_sig.outputs.at(0).type == kTfLiteInt16) { + // `quantized_bias_type` is supported at version 8. + auto conv_params = + reinterpret_cast(op_sig.builtin_data); + TFLITE_DCHECK(conv_params != nullptr); + if (conv_params->quantized_bias_type) { + return 8; + } + } + + if (op_sig.ext_options.conv_2d.is_grouped_convolution) { + return 6; + } + // If the op has signed int16 op_sig.inputs and op_sig.outputs, its + // version 4. + if (op_sig.inputs.at(0).type == kTfLiteInt16 && + op_sig.inputs.at(1).type == kTfLiteInt16 && + op_sig.outputs.at(1).type == kTfLiteInt16) { + return 4; + } + + // If the op has signed int8 op_sig.inputs and op_sig.outputs, its + // version 3. + if (op_sig.inputs.at(0).type == kTfLiteInt8 && + op_sig.inputs.at(1).type == kTfLiteInt8 && + op_sig.outputs.at(0).type == kTfLiteInt8) { + return 3; + } + // If the op has signed int8 and int4 op_sig.inputs and op_sig.outputs, + // its version 7. + if (op_sig.inputs.at(0).type == kTfLiteInt8 && + op_sig.inputs.at(1).type == kTfLiteInt4 && + op_sig.outputs.at(0).type == kTfLiteInt8) { + return 7; + } + // If the op is a signed int8 hybrid operation, we need to return + // version 2 or 5 if per channel. + if (op_sig.inputs.at(0).type == kTfLiteFloat32 && + op_sig.inputs.at(1).type == kTfLiteInt8 && + op_sig.outputs.at(0).type == kTfLiteFloat32) { + if (op_sig.ext_options.conv_2d.is_per_channel_quantized) { + return 5; + } + return 2; + } + return 1; + } + case BuiltinOperator_DEPTHWISE_CONV_2D: { + // If the op accepts int16, we return version 5. + if (op_sig.inputs.at(0).type == kTfLiteInt16 && + op_sig.inputs.at(1).type == kTfLiteInt16 && + op_sig.outputs.at(1).type == kTfLiteInt16) { + return 5; + } + + // If the op is a signed int8 hybrid operation, we need to return + // version 4 or 6 if per-channel. + if (op_sig.inputs.at(0).type == kTfLiteFloat32 && + op_sig.inputs.at(1).type == kTfLiteInt8 && + op_sig.outputs.at(0).type == kTfLiteFloat32) { + if (op_sig.ext_options.depthwise_conv_2d.is_per_channel_quantized) { + return 6; + } + return 4; + } + // If the op has signed int8 op_sig.inputs and op_sig.outputs, its + // version 3. + if (op_sig.inputs.at(0).type == kTfLiteInt8 && + op_sig.inputs.at(1).type == kTfLiteInt8 && + op_sig.outputs.at(0).type == kTfLiteInt8) { + return 3; + } + + // If the op has signed int8 and int4 op_sig.inputs and op_sig.outputs, + // its version 7. + if (op_sig.inputs.at(0).type == kTfLiteInt8 && + op_sig.inputs.at(1).type == kTfLiteInt4 && + op_sig.outputs.at(0).type == kTfLiteInt8) { + return 7; + } + + auto depthwise_conv_params = + reinterpret_cast(op_sig.builtin_data); + TFLITE_DCHECK(depthwise_conv_params != nullptr); + if (depthwise_conv_params->dilation_width_factor != 1 || + depthwise_conv_params->dilation_height_factor != 1) { + return 2; + } + return 1; + } + + case BuiltinOperator_EMBEDDING_LOOKUP: { + if (op_sig.inputs.at(1).type == kTfLiteInt4 || + op_sig.ext_options.embedding_lookup.is_per_channel_quantized) { + return 4; + } + return 1; + } + + case BuiltinOperator_FAKE_QUANT: { + auto fake_quant_params = + reinterpret_cast(op_sig.builtin_data); + TFLITE_DCHECK(fake_quant_params != nullptr); + if (fake_quant_params->narrow_range) { + return 2; + } + return 1; + } + + case BuiltinOperator_FULLY_CONNECTED: { + // +-----------------+--------------------+--------------------------+ + // | | Weight::Default | Weight::Shuffled4x16Int8 | + // +-----------------+--------------------+--------------------------+ + // | Float | 1 | 2 | + // | Quantized Uint8 | 1 | 2 | + // | Hybrid | 3 | 3 | + // | Quantized Int8 | 4 | 4 | + // +-----------------+--------------------+--------------------------+ + + auto fully_connected_params = + reinterpret_cast(op_sig.builtin_data); + TFLITE_DCHECK(fully_connected_params != nullptr); + + if (op_sig.inputs.at(0).type == kTfLiteInt16 && + op_sig.inputs.at(1).type == kTfLiteInt4 && + op_sig.outputs.at(0).type == kTfLiteInt16) { + return 13; + } + + if (op_sig.inputs.at(0).type == kTfLiteFloat32 && + op_sig.inputs.at(1).type == kTfLiteInt8 && + op_sig.outputs.at(0).type == kTfLiteFloat32 && + op_sig.ext_options.fully_connected.is_per_channel_quantized) { + return 12; + } + + if (op_sig.inputs.at(0).type == kTfLiteInt16 && + op_sig.inputs.at(1).type == kTfLiteInt8 && + op_sig.outputs.at(0).type == kTfLiteInt16) { + // `quantized_bias_type` is supported at version 11. + if (fully_connected_params->quantized_bias_type) { + return 11; + } + } + + // FullyConnected with sparse weight is supported at version 8. + if (op_sig.ext_options.fully_connected.sparse_weight) { + return 8; + } + + // Int16 fully fixed point kernel is at version 7. + if (op_sig.inputs.at(0).type == kTfLiteInt16 && + op_sig.inputs.at(1).type == kTfLiteInt16 && + op_sig.outputs.at(0).type == kTfLiteInt16) { + return 7; + } + + // 2 op_sig.inputs (no bias) use case is supported starting from + // version 6. + if (op_sig.inputs.size() == 2) { + return 6; + } + // `keep_num_dims` is supported at version 5. + if (fully_connected_params->keep_num_dims) { + return 5; + } + // Int8 fully fixed point kernel is at version 4. + if (op_sig.inputs.at(0).type == kTfLiteInt8 && + op_sig.inputs.at(1).type == kTfLiteInt8 && + op_sig.outputs.at(0).type == kTfLiteInt8) { + return 4; + } + + // If the op has signed int8 and int4 op_sig.inputs and op_sig.outputs, + // its version 7. + if (op_sig.inputs.at(0).type == kTfLiteInt8 && + op_sig.inputs.at(1).type == kTfLiteInt4 && + op_sig.outputs.at(0).type == kTfLiteInt8) { + return 10; + } + + // If the op is a signed int8 hybrid operation, we need to return + // version 3. + if (op_sig.inputs.at(0).type == kTfLiteFloat32 && + op_sig.inputs.at(1).type == kTfLiteInt8 && + op_sig.outputs.at(0).type == kTfLiteFloat32) { + if (fully_connected_params->asymmetric_quantize_inputs) { + // This is to use the updated quantization scheme. + return 9; + } + return 3; + } + // For float and uint8 fixed point kernels, if the weight is + // Shuffled4x16Int8, it is version 2. + if (fully_connected_params->weights_format == + kTfLiteFullyConnectedWeightsFormatShuffled4x16Int8) { + return 2; + } + // Otherwise (weight is default), the version is 1. + return 1; + } + + case BuiltinOperator_GATHER: { + if (op_sig.inputs.at(0).type == kTfLiteInt4) { + return 7; + } + if (op_sig.inputs.at(1).type == kTfLiteInt16) { + return 6; + } + auto gather_params = + reinterpret_cast(op_sig.builtin_data); + if (gather_params && gather_params->batch_dims != 0) { + return 5; + } + if (op_sig.inputs.at(0).type == kTfLiteInt16) { + return 4; + } + // If the op takes bool input, it is version 3. + if (op_sig.inputs.at(0).type == kTfLiteBool) { + return 3; + } + if (op_sig.inputs.at(0).type == kTfLiteInt8) { + return 2; + } + return 1; + } + + case BuiltinOperator_SVDF: { + // Fully integer SVDF has int8 as input and is of version 3. + if (op_sig.inputs.at(0).type == kTfLiteInt8) { + return 3; + } + // If the op is a signed int8 hybrid operation, we need to return + // version 2. + if (op_sig.inputs.at(0).type == kTfLiteFloat32 && + op_sig.inputs.at(1).type == kTfLiteInt8 && + op_sig.outputs.at(0).type == kTfLiteFloat32) { + auto svdf_params = + reinterpret_cast(op_sig.builtin_data); + // This is to use the updated quantization scheme + if (svdf_params && svdf_params->asymmetric_quantize_inputs) { + return 4; + } + return 2; + } + return 1; + } + + case BuiltinOperator_SIGN: + // Version 2 supports int32 inputs + if (op_sig.inputs.at(0).type == kTfLiteInt32) { + return 2; + } + return 1; + + case BuiltinOperator_MUL: + // Version 7 supports int16 and uint32 inputs + if ((op_sig.inputs.at(0).type == kTfLiteInt16 && + !op_sig.ext_options.mul.input_quantized) || + op_sig.inputs.at(0).type == kTfLiteUInt32) { + return 7; + } + // Version 6 supports complex32 inputs + if (op_sig.inputs.at(0).type == kTfLiteComplex64) { + return 6; + } + // Version 5 supports int64 inputs + if (op_sig.inputs.at(0).type == kTfLiteInt64) { + return 5; + } + // Version 4 supports int16 inputs + if (op_sig.inputs.at(0).type == kTfLiteInt16) { + return 4; + } + // Version 3 supports have a rescale value greater than or equal to 1. + if (op_sig.ext_options.mul.input1_scale != 0 && + op_sig.ext_options.mul.input2_scale != 0 && + op_sig.ext_options.mul.output_scale != 0 && + (op_sig.ext_options.mul.input1_scale * + op_sig.ext_options.mul.input2_scale / + op_sig.ext_options.mul.output_scale) >= 1.0) { + return 3; + } + if (op_sig.inputs.at(0).type == kTfLiteInt8) { + return 2; + } + return 1; + + case BuiltinOperator_MAX_POOL_2D: + case BuiltinOperator_AVERAGE_POOL_2D: + if (op_sig.inputs.at(0).type == kTfLiteInt16 && + op_sig.outputs.at(0).type == kTfLiteInt16) { + return 3; + } + + if (op_sig.inputs.at(0).type == kTfLiteInt8) { + return 2; + } + return 1; + + case BuiltinOperator_TRANSPOSE: + if (op_sig.inputs.at(0).dims.size() > 5) { + return 6; + } + if (op_sig.inputs.at(0).type == kTfLiteInt16) { + return 5; + } + if (op_sig.inputs.at(0).dims.size() > 4) { + return 4; + } + // If the op takes bool input, it is version 3. + if (op_sig.inputs.at(0).type == kTfLiteBool) { + return 3; + } + if (op_sig.inputs.at(0).type == kTfLiteInt8) { + return 2; + } + return 1; + + case BuiltinOperator_TRANSPOSE_CONV: { + auto transpose_conv_params = + reinterpret_cast(op_sig.builtin_data); + if (op_sig.inputs.at(0).type == kTfLiteInt16 && + op_sig.inputs.at(1).type == kTfLiteInt8 && + op_sig.outputs.at(0).type == kTfLiteInt16) { + // `quantized_bias_type` is supported at version 5. + TFLITE_DCHECK(transpose_conv_params != nullptr); + if (transpose_conv_params->quantized_bias_type) { + return 5; + } + } + + // TransposeConvOp has fused activation function from version 4. + if (transpose_conv_params != nullptr && + transpose_conv_params->activation) { + return 4; + } + + if (op_sig.inputs.size() == 4 && + op_sig.inputs.at(3).type != kTfLiteNoType) { + return 3; + } + // If the op takes int8 input, it is version 2. + if (op_sig.inputs.at(1).type == kTfLiteInt8) { + return 2; + } + return 1; + } + + case BuiltinOperator_LSTM: { + auto lstm_params = + reinterpret_cast(op_sig.builtin_data); + // If the input activation and output tensor are int16 and a weight is + // int8, this is a version 5. + if (lstm_params->kernel_type == kTfLiteLSTMFullKernel && + op_sig.inputs.at(0).type == kTfLiteInt16 && + op_sig.inputs.at(2).type == kTfLiteInt8 && + op_sig.outputs.at(0).type == kTfLiteInt16) { + return 5; + } + // If the input tensor is float and a weight is int8, this is a version + // 3 hybrid operation. + TFLITE_DCHECK(lstm_params != nullptr); + if (lstm_params->kernel_type == kTfLiteLSTMFullKernel && + op_sig.inputs.at(0).type == kTfLiteFloat32 && + op_sig.inputs.at(2).type == kTfLiteInt8 && + op_sig.outputs.at(0).type == kTfLiteFloat32) { + if (lstm_params->asymmetric_quantize_inputs) { + return 4; + } + return 3; + } + // KERNEL_BASIC was added in version 2. + if (lstm_params->kernel_type == kTfLiteLSTMBasicKernel) { + return 2; + } + return 1; + } + + case BuiltinOperator_SPLIT: + // If the op take in16 input, it is version 4. + if (op_sig.inputs.at(1).type == kTfLiteInt16) { + return 4; + } + // If the op take int8 input, it is version 2, for int32 it's version 3. + // The input tensor is at index 1 not 0, 0 is the axis. + if (op_sig.inputs.at(1).type == kTfLiteInt32) { + return 3; + } + if (op_sig.inputs.at(1).type == kTfLiteInt8) { + return 2; + } + return 1; + + case BuiltinOperator_SPARSE_TO_DENSE: + // Version 3 supports Int8 and Uint8 type. + if (op_sig.inputs.at(2).type == kTfLiteInt8 || + op_sig.inputs.at(2).type == kTfLiteUInt8) { + return 3; + } + // Version 2 supports Int64 value type. + if (op_sig.inputs.at(2).type == kTfLiteInt64) { + return 2; + } + return 1; + + case BuiltinOperator_SLICE: + if (op_sig.inputs.at(0).type == kTfLiteUInt32) { + return 6; + } + if (op_sig.inputs.at(0).dims.size() > 4) { + return 5; + } + if (op_sig.inputs.at(0).type == kTfLiteInt16) { + return 4; + } + // Version 3 supports string input types. + if (op_sig.inputs.at(0).type == kTfLiteString) { + return 3; + } + if (op_sig.inputs.at(0).type == kTfLiteInt8) { + return 2; + } + return 1; + + case BuiltinOperator_UNPACK: + // If the op take int8/uint8 input, it is version 2. + if (op_sig.inputs.at(0).type == kTfLiteInt8 || + op_sig.inputs.at(0).type == kTfLiteUInt8) { + return 2; + } + // If the op take bool input, it is version 3. + if (op_sig.inputs.at(0).type == kTfLiteBool) { + return 3; + } + if (op_sig.inputs.at(0).type == kTfLiteInt16 && + op_sig.outputs.at(0).type == kTfLiteInt16) { + return 4; + } + return 1; + + case BuiltinOperator_DEQUANTIZE: + if (op_sig.inputs.at(0).type == kTfLiteInt4) { + return 6; + } + // Version 3 supports signed int16 input types. + if (op_sig.inputs.at(0).type == kTfLiteInt16 || + op_sig.inputs.at(0).type == kTfLiteFloat16) { + return 3; + } + if (op_sig.inputs.at(0).type == kTfLiteInt8) { + if (op_sig.ext_options.dequantize.is_per_channel_quantized) { + return 5; + } + return 2; + } + return 1; + + case BuiltinOperator_QUANTIZE: + if (op_sig.inputs.at(0).type == kTfLiteInt4 || + op_sig.outputs.at(0).type == kTfLiteInt4) { + return 4; + } + if (op_sig.ext_options.quantize.is_per_channel_quantized) { + return 3; + } + if (op_sig.outputs.at(0).type == kTfLiteInt16) { + return 2; + } + return 1; + + case BuiltinOperator_FLOOR_DIV: + if (op_sig.inputs.at(0).type == kTfLiteInt16 || + op_sig.inputs.at(0).type == kTfLiteInt8) { + return 3; + } + if (op_sig.inputs.at(0).type == kTfLiteFloat32) { + return 2; + } + return 1; + + case BuiltinOperator_FLOOR_MOD: + if (op_sig.inputs.at(0).type == kTfLiteInt16 || + op_sig.inputs.at(0).type == kTfLiteInt8) { + return 2; + } + return 1; + + case BuiltinOperator_L2_NORMALIZATION: + if (op_sig.outputs.at(0).type == kTfLiteInt8) { + return 2; + } + return 1; + + case BuiltinOperator_ABS: + // Version 5 supports int32 + if (op_sig.inputs.at(0).type == kTfLiteInt32) { + return 5; + } + if (op_sig.inputs.at(0).type == kTfLiteInt16) { + return op_sig.ext_options.abs.input_quantized ? 3 : 4; + } + if (op_sig.inputs.at(0).type == kTfLiteInt8 || + op_sig.inputs.at(0).type == kTfLiteUInt8) { + return 2; + } + return 1; + case BuiltinOperator_RELU: + if (op_sig.inputs.at(0).type == kTfLiteInt16) { + return 3; + } + if (op_sig.inputs.at(0).type == kTfLiteInt8 || + op_sig.inputs.at(0).type == kTfLiteUInt8) { + return 2; + } + return 1; + + case BuiltinOperator_STRIDED_SLICE: { + auto strided_slice_params = + reinterpret_cast(op_sig.builtin_data); + TFLITE_DCHECK(strided_slice_params != nullptr); + if (strided_slice_params->offset == true) { + return 8; + } + if (op_sig.inputs.at(0).type == kTfLiteUInt32) { + return 7; + } + if (strided_slice_params->ellipsis_mask != 0 || + strided_slice_params->new_axis_mask != 0) { + return 6; + } + if (op_sig.inputs.at(0).type == kTfLiteString) { + return 5; + } + if (op_sig.ext_options.strided_slice.num_dims > 4) { + return 4; + } + // If the op takes bool input, it is version 3. + if (op_sig.inputs.at(0).type == kTfLiteBool) { + return 3; + } + if (op_sig.inputs.at(0).type == kTfLiteInt8) { + return 2; + } + return 1; + } + case BuiltinOperator_REVERSE_V2: + if (op_sig.inputs.at(0).type == kTfLiteInt8) { + return 3; + } + if (op_sig.inputs.at(0).type == kTfLiteBool) { + return 2; + } + return 1; + case BuiltinOperator_RESIZE_BILINEAR: { + if (op_sig.inputs.at(0).type == kTfLiteInt16) { + return 4; + } + auto resize_bilinear_params = + reinterpret_cast(op_sig.builtin_data); + TFLITE_DCHECK(resize_bilinear_params != nullptr); + if (resize_bilinear_params->half_pixel_centers) { + return 3; + } else if (op_sig.inputs.at(0).type == kTfLiteInt8) { + return 2; + } + return 1; + } + case BuiltinOperator_RESIZE_NEAREST_NEIGHBOR: { + if (op_sig.inputs.at(0).type == kTfLiteInt16) { + return 4; + } + auto resize_nearest_neighbor_params = + reinterpret_cast( + op_sig.builtin_data); + TFLITE_DCHECK(resize_nearest_neighbor_params != nullptr); + if (resize_nearest_neighbor_params->half_pixel_centers || + resize_nearest_neighbor_params->align_corners) { + return 3; + } else if (op_sig.inputs.at(0).type == kTfLiteInt8) { + return 2; + } + return 1; + } + + case BuiltinOperator_MAXIMUM: + case BuiltinOperator_MINIMUM: + if (op_sig.inputs.at(0).type == kTfLiteInt16 && + op_sig.outputs.at(0).type == kTfLiteInt16) { + return 4; + } + if (NeedBroadcastForBinaryInputs(op_sig) && GetInputMaxDims(op_sig) > 4) { + return 3; + } + if (op_sig.inputs.at(0).type == kTfLiteInt8) { + return 2; + } + return 1; + + case BuiltinOperator_PACK: + if (op_sig.inputs.at(0).type == kTfLiteInt8) { + return 2; + } + if (op_sig.inputs.at(0).type == kTfLiteInt16 && + op_sig.outputs.at(0).type == kTfLiteInt16) { + return 3; + } + if (op_sig.inputs.at(0).type == kTfLiteUInt32) { + return 4; + } + return 1; + + case BuiltinOperator_TILE: + if (op_sig.inputs.at(0).type == kTfLiteInt8) { + return 3; + } + if (op_sig.inputs.at(0).type == kTfLiteString) { + return 2; + } + return 1; + + case BuiltinOperator_SQUEEZE: + if (op_sig.inputs.at(0).type == kTfLiteString) { + return 2; + } + return 1; + + case BuiltinOperator_SPACE_TO_BATCH_ND: + case BuiltinOperator_BATCH_TO_SPACE_ND: + if (op_sig.inputs.at(0).type == kTfLiteInt16) { + return 4; + } + if (op_sig.inputs.at(0).dims.size() != 4) { + return 3; + } + if (op_sig.inputs.at(0).type == kTfLiteInt8) { + return 2; + } + return 1; + + case BuiltinOperator_ADD: { + if (!op_sig.inputs.empty() && op_sig.inputs.at(0).type == kTfLiteInt16 && + !op_sig.ext_options.add.input_quantized) { + return 5; + } + if (!op_sig.inputs.empty() && op_sig.inputs.at(0).type == kTfLiteInt64) { + return 4; + } + if (op_sig.inputs.at(0).type == kTfLiteInt16 && + op_sig.outputs.at(0).type == kTfLiteInt16) { + auto add_params = + reinterpret_cast(op_sig.builtin_data); + if (add_params && !add_params->pot_scale_int16) { + return 3; + } + } + if (op_sig.inputs.at(0).type == kTfLiteInt8) { + return 2; + } + return 1; + } + + case BuiltinOperator_SUB: { + if (op_sig.inputs.at(0).type == kTfLiteInt16 && + op_sig.outputs.at(0).type == kTfLiteInt16) { + auto sub_params = + reinterpret_cast(op_sig.builtin_data); + if (sub_params && !sub_params->pot_scale_int16) { + return 5; + } + } + if (!op_sig.inputs.empty() && op_sig.inputs.at(0).type == kTfLiteInt64) { + return 4; + } + if (NeedBroadcastForBinaryInputs(op_sig) && GetInputMaxDims(op_sig) > 4) { + return 3; + } + if (op_sig.inputs.at(0).type == kTfLiteInt8) { + return 2; + } + return 1; + } + + case BuiltinOperator_GATHER_ND: + if (op_sig.inputs.at(0).type == kTfLiteBool) { + return 5; + } + if (op_sig.inputs.at(1).type == kTfLiteInt16) { + return 4; + } + if (!op_sig.inputs.empty() && + (op_sig.inputs.at(0).type == kTfLiteInt16)) { + return 3; + } + if (!op_sig.inputs.empty() && op_sig.inputs.at(0).type == kTfLiteString) { + return 2; + } + return 1; + + case BuiltinOperator_DIV: + if (NeedBroadcastForBinaryInputs(op_sig) && GetInputMaxDims(op_sig) > 4) { + return 2; + } + return 1; + case BuiltinOperator_TANH: + case BuiltinOperator_LOGISTIC: + if (op_sig.inputs.at(0).type == kTfLiteInt16 && + op_sig.outputs.at(0).type == kTfLiteInt16) { + return 3; + } + + if (op_sig.inputs.at(0).type == kTfLiteInt8) { + return 2; + } + return 1; + + case BuiltinOperator_FILL: + if (op_sig.inputs.size() >= 2) { + if (op_sig.inputs.at(1).type == kTfLiteFloat16) return 4; + if (op_sig.inputs.at(1).type == kTfLiteInt8 || + op_sig.inputs.at(1).type == kTfLiteInt16) { + return 3; + } else if ((op_sig.inputs.at(1).type == kTfLiteBool || + op_sig.inputs.at(1).type == kTfLiteString)) { + return 2; + } + } + return 1; + + case BuiltinOperator_EQUAL: + if (!op_sig.inputs.empty()) { + if (op_sig.inputs.at(0).type == kTfLiteInt16) { + return 4; + } + if (op_sig.inputs.at(0).type == kTfLiteString) { + return 3; + } + if (op_sig.inputs.at(0).type == kTfLiteInt8) { + return 2; + } + } + return 1; + case BuiltinOperator_NOT_EQUAL: + if (!op_sig.inputs.empty()) { + if (op_sig.inputs.at(0).type == kTfLiteString) { + return 3; + } + if (op_sig.inputs.at(0).type == kTfLiteInt8) { + return 2; + } + } + return 1; + + case BuiltinOperator_LEAKY_RELU: + if (op_sig.inputs.at(0).type == kTfLiteInt16) { + return 2; + } + return 1; + + case BuiltinOperator_RANGE: + if (op_sig.inputs.at(0).type == kTfLiteInt64) { + return 2; + } + return 1; + + case BuiltinOperator_BATCH_MATMUL: { + // In case of int16 inputs, the version is 3. + if (op_sig.inputs.at(0).type == kTfLiteInt16) { + return 3; + } + if (op_sig.inputs.at(0).type == kTfLiteInt8) { + return 2; + } + if (op_sig.inputs.at(0).type == kTfLiteFloat32 && + op_sig.inputs.at(1).type == kTfLiteInt8 && + op_sig.outputs.at(0).type == kTfLiteFloat32) { + auto batch_mat_mul_params = + reinterpret_cast(op_sig.builtin_data); + if (batch_mat_mul_params && + batch_mat_mul_params->asymmetric_quantize_inputs) { + // This is to use the updated quantization scheme. + return 4; + } + } + return 1; + } + + case BuiltinOperator_PAD: + case BuiltinOperator_PADV2: + if (op_sig.inputs.at(0).dims.size() > 4) { + return 4; + } + if (op_sig.inputs.at(0).type == kTfLiteInt16) { + return 3; + } + if (op_sig.inputs.at(0).type == kTfLiteInt8) { + return 2; + } + return 1; + + case BuiltinOperator_CONCATENATION: + if (op_sig.inputs.at(0).type == kTfLiteUInt32) { + return 4; + } + // In case of int16 inputs, the version is 3. + if (op_sig.inputs.at(0).type == kTfLiteInt16) { + return 3; + } + if (op_sig.inputs.at(0).type == kTfLiteInt8) { + return 2; + } + return 1; + + case BuiltinOperator_SOFTMAX: + case BuiltinOperator_MEAN: + case BuiltinOperator_MIRROR_PAD: + case BuiltinOperator_REDUCE_MAX: + case BuiltinOperator_REDUCE_MIN: + case BuiltinOperator_RELU6: + case BuiltinOperator_RSQRT: + // In case of int16 inputs, the version is 3. + if (op_sig.inputs.at(0).type == kTfLiteInt16) { + return 3; + } + if (op_sig.inputs.at(0).type == kTfLiteInt8) { + return 2; + } + return 1; + + case BuiltinOperator_RNN: { + if (op_sig.inputs.at(1).type == kTfLiteInt8 && + op_sig.outputs.at(0).type == kTfLiteFloat32) { + auto rnn_params = + reinterpret_cast(op_sig.builtin_data); + if (rnn_params && rnn_params->asymmetric_quantize_inputs) { + return 3; + } else { + return 2; + } + } + return 1; + } + + case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN: { + if (op_sig.inputs.at(1).type == kTfLiteInt8 && + op_sig.outputs.at(0).type == kTfLiteFloat32) { + auto sequence_rnn_params = + reinterpret_cast(op_sig.builtin_data); + if (sequence_rnn_params && + sequence_rnn_params->asymmetric_quantize_inputs) { + return 3; + } else { + return 2; + } + } + return 1; + } + + case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN: { + if (op_sig.inputs.at(1).type == kTfLiteInt8 && + op_sig.outputs.at(0).type == kTfLiteFloat32) { + auto bidirectional_sequence_rnn_params = + reinterpret_cast( + op_sig.builtin_data); + if (bidirectional_sequence_rnn_params && + bidirectional_sequence_rnn_params->asymmetric_quantize_inputs) { + return 3; + } else { + return 2; + } + } + return 1; + } + + case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM: { + if (op_sig.inputs.at(1).type == kTfLiteInt8 && + op_sig.outputs.at(0).type == kTfLiteFloat32) { + auto bidirectional_sequence_lstm_params = + reinterpret_cast( + op_sig.builtin_data); + if (bidirectional_sequence_lstm_params && + bidirectional_sequence_lstm_params->asymmetric_quantize_inputs) { + return 3; + } else { + return 2; + } + } + return 1; + } + + case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM: { + auto unidirectional_sequence_lstm_params = + reinterpret_cast( + op_sig.builtin_data); + // If the input activation and output tensor are int16 and a weight is + // int8, this is a version 5. + if (op_sig.inputs.at(0).type == kTfLiteInt16 && + op_sig.inputs.at(2).type == kTfLiteInt8 && + op_sig.outputs.at(0).type == kTfLiteInt16) { + return 5; + } + if (unidirectional_sequence_lstm_params && + unidirectional_sequence_lstm_params->diagonal_recurrent_tensors) { + return 4; + } + // If the input tensor is float and a weight is int8, this is a version + // 2 hybrid operation. + if (op_sig.inputs.at(0).type == kTfLiteFloat32 && + op_sig.inputs.at(2).type == kTfLiteInt8 && + op_sig.outputs.at(0).type == kTfLiteFloat32) { + if (unidirectional_sequence_lstm_params && + unidirectional_sequence_lstm_params->asymmetric_quantize_inputs) { + return 3; + } + return 2; + } + return 1; + } + + case BuiltinOperator_ARG_MAX: + case BuiltinOperator_ARG_MIN: + if (op_sig.inputs.at(0).type == kTfLiteBool) { + return 3; + } + if (op_sig.inputs.at(0).type == kTfLiteInt8) { + return 2; + } + return 1; + + case BuiltinOperator_SELECT: { + if (op_sig.inputs.at(0).type == kTfLiteUInt32) { + return 4; + } + if (op_sig.inputs.at(0).dims.size() == 5 || + op_sig.inputs.at(1).dims.size() == 5 || + op_sig.inputs.at(2).dims.size() == 5) + return 3; + if (op_sig.inputs.at(0).type == kTfLiteInt8) { + return 2; + } + return 1; + } + case BuiltinOperator_LESS: + case BuiltinOperator_GREATER_EQUAL: { + if (op_sig.inputs.at(0).type == kTfLiteInt16) { + return 3; + } + if (op_sig.inputs.at(0).type == kTfLiteInt8) { + return 2; + } + return 1; + } + case BuiltinOperator_SELECT_V2: { + if (op_sig.inputs.at(0).type == kTfLiteUInt32) { + return 2; + } + return 1; + } + case BuiltinOperator_SPACE_TO_DEPTH: + case BuiltinOperator_SPLIT_V: + case BuiltinOperator_SUM: + case BuiltinOperator_LOG_SOFTMAX: + case BuiltinOperator_GREATER: + case BuiltinOperator_LESS_EQUAL: + case BuiltinOperator_SQUARED_DIFFERENCE: + case BuiltinOperator_DEPTH_TO_SPACE: + if (op_sig.inputs.at(0).type == kTfLiteInt8) { + return 2; + } + return 1; + case BuiltinOperator_TOPK_V2: + if (op_sig.inputs.at(0).type == kTfLiteInt16 || + op_sig.inputs.at(1).type == kTfLiteInt16 || + op_sig.outputs.at(1).type == kTfLiteInt16) { + return 3; + } + if (op_sig.inputs.at(0).type == kTfLiteInt8) { + return 2; + } + return 1; + + case BuiltinOperator_EXP: + case BuiltinOperator_LOG: + case BuiltinOperator_REDUCE_PROD: + if (op_sig.inputs.at(0).type == kTfLiteInt8 || + op_sig.inputs.at(0).type == kTfLiteInt16) { + return 2; + } + return 1; + case BuiltinOperator_DYNAMIC_UPDATE_SLICE: + if (op_sig.inputs.at(2).type == kTfLiteInt64) return 2; + return 1; + + // The version one of broadcast to op won't be not supported since the + // version one was rollbacked and the builtin op code number has been + // changed because of builtin op code shortage problem. + // Quantized broadcast_to is version 3 + case BuiltinOperator_BROADCAST_TO: + if (op_sig.inputs.at(0).type == kTfLiteInt8 || + op_sig.inputs.at(0).type == kTfLiteInt16) { + return 3; + } + return 2; + case BuiltinOperator_CAST: + if (op_sig.inputs.at(0).type == kTfLiteBFloat16 || + op_sig.outputs.at(0).type == kTfLiteBFloat16) { + return 7; + } else if (op_sig.inputs.at(0).type == kTfLiteInt4 && + op_sig.outputs.at(0).type == kTfLiteFloat32) { + return 6; + } else if (op_sig.inputs.at(0).type == kTfLiteFloat64 || + op_sig.outputs.at(0).type == kTfLiteFloat64 || + op_sig.inputs.at(0).type == kTfLiteFloat16 || + op_sig.outputs.at(0).type == kTfLiteFloat16) { + return 5; + } else if (op_sig.inputs.at(0).type == kTfLiteUInt16 || + op_sig.outputs.at(0).type == kTfLiteUInt16) { + return 4; + } else if (op_sig.inputs.at(0).type == kTfLiteInt8 || + op_sig.outputs.at(0).type == kTfLiteInt8) { + return 3; + } else if (op_sig.inputs.at(0).type == kTfLiteUInt32 || + op_sig.outputs.at(0).type == kTfLiteUInt32) { + return 2; + } + return 1; + case BuiltinOperator_WHERE: + if (op_sig.inputs.at(0).type == kTfLiteBool) return 1; + return 2; + case BuiltinOperator_GELU: + if (op_sig.inputs.at(0).type == kTfLiteInt8 || + op_sig.inputs.at(0).type == kTfLiteUInt8) { + return 2; + } + return 1; + default: + return 1; + } + // Prevent lint error about this function being too long. + // NOLINTNEXTLINE +} + +void UpdateOpVersion(uint8_t* model_buffer_pointer) { + auto model = GetMutableModel(model_buffer_pointer); + auto subgraphs = model->subgraphs(); + + for (int i = 0; i < subgraphs->Length(); ++i) { + const SubGraph* subgraph = subgraphs->Get(i); + for (int j = 0; j < subgraph->operators()->Length(); ++j) { + const Operator* op = subgraph->operators()->Get(j); + OperatorCode* op_code = + model->mutable_operator_codes()->GetMutableObject(op->opcode_index()); + + auto builtin_code = GetBuiltinCode(op_code); + if (builtin_code != BuiltinOperator_CUSTOM) { + OpSignature op_sig = GetOpSignature(op_code, op, subgraph, model); + // Update builtin operator version. + int32_t op_ver = GetBuiltinOperatorVersion(op_sig); + if (op_sig.builtin_data) { + free(op_sig.builtin_data); + } + // Skip updating op version if the current node uses lower version. + // TODO(b/184366869): Populate multiple versions of operator once MLIR + // quantizer is ready. + if (op_ver <= op_code->version()) { + continue; + } + if (!op_code->mutate_version(op_ver)) { + LOG(ERROR) << "Can't set operator " + << EnumNameBuiltinOperator(builtin_code) << " to version " + << op_ver; + } + } + } + } +} + +} // namespace tflite diff --git a/tensorflow/compiler/mlir/lite/tools/versioning/op_version.h b/tensorflow/compiler/mlir/lite/tools/versioning/op_version.h new file mode 100644 index 00000000000000..bd1f551669a94c --- /dev/null +++ b/tensorflow/compiler/mlir/lite/tools/versioning/op_version.h @@ -0,0 +1,33 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_TOOLS_VERSIONING_OP_VERSION_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_TOOLS_VERSIONING_OP_VERSION_H_ + +#include + +#include "tensorflow/compiler/mlir/lite/schema/mutable/schema_generated.h" // IWYU pragma: keep +#include "tensorflow/compiler/mlir/lite/tools/versioning/op_signature.h" + +namespace tflite { + +// Returns version of builtin ops by the given signature. +int GetBuiltinOperatorVersion(const OpSignature& op_sig); + +// Update operator's version of the given TFL flatbuffer model. +void UpdateOpVersion(uint8_t* model_buffer_pointer); + +} // namespace tflite + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_TOOLS_VERSIONING_OP_VERSION_H_ diff --git a/tensorflow/compiler/mlir/lite/tools/versioning/op_version_test.cc b/tensorflow/compiler/mlir/lite/tools/versioning/op_version_test.cc new file mode 100644 index 00000000000000..5ad70990125a90 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/tools/versioning/op_version_test.cc @@ -0,0 +1,1435 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/compiler/mlir/lite/tools/versioning/op_version.h" + +#include + +#include +#include "tensorflow/compiler/mlir/lite/core/c/builtin_op_data.h" +#include "tensorflow/compiler/mlir/lite/core/c/tflite_types.h" +#include "tensorflow/compiler/mlir/lite/schema/mutable/schema_generated.h" +#include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" +#include "tensorflow/compiler/mlir/lite/tools/versioning/op_signature.h" + +namespace tflite { +namespace { + +// Creates vector of OpSignatureTensorSpec with the given TfLiteType vector. +std::vector CreateOpSignatureTensorSpecs( + const std::vector& types) { + std::vector tensor_specs; + for (auto type : types) { + OpSignatureTensorSpec tensor_spec = {}; + tensor_spec.type = type; + tensor_specs.push_back(tensor_spec); + } + return tensor_specs; +} + +// Creates vector of OpSignatureTensorSpec with the given TfLiteType vector, +// each with rank 'rank' +std::vector CreateOpSignatureTensorSpecs( + const std::vector& types, int rank) { + std::vector tensor_specs; + for (auto type : types) { + OpSignatureTensorSpec tensor_spec = {}; + tensor_spec.type = type; + for (int i = 0; i < rank; i++) { + tensor_spec.dims.push_back(4); + } + tensor_specs.push_back(tensor_spec); + } + return tensor_specs; +} + +// Creates vector of OpSignatureTensorSpec of single tensor spec of TfLiteType. +std::vector CreateOpSignatureTensorSpecs( + const TfLiteType type) { + std::vector tensor_specs; + OpSignatureTensorSpec tensor_spec = {}; + tensor_spec.type = type; + tensor_specs.push_back(tensor_spec); + return tensor_specs; +} + +// Creates vector of OpSignatureTensorSpec of single tensor spec of TfLiteType +// with shapes. +std::vector CreateOpSignatureTensorSpecs( + const TfLiteType type, const int dim) { + std::vector tensor_specs; + OpSignatureTensorSpec tensor_spec = {}; + tensor_spec.type = type; + for (int i = 0; i < dim; i++) { + tensor_spec.dims.push_back(4); + } + tensor_specs.push_back(tensor_spec); + return tensor_specs; +} + +// Creates vector of OpSignatureTensorSpec of two tensor specs of TfLiteType +// with shapes. +std::vector CreateOpSignatureTensorSpecs( + const TfLiteType type, const int dim1, const int dim2) { + std::vector tensor_specs; + OpSignatureTensorSpec tensor_spec1 = {}; + tensor_spec1.type = type; + for (int i = 0; i < dim1; i++) { + tensor_spec1.dims.push_back(4); + } + tensor_specs.push_back(tensor_spec1); + + OpSignatureTensorSpec tensor_spec2 = {}; + tensor_spec2.type = type; + for (int i = 0; i < dim2; i++) { + tensor_spec2.dims.push_back(4); + } + tensor_specs.push_back(tensor_spec2); + return tensor_specs; +} + +} // namespace + +TEST(OpVersionTest, VersioningSpareToDense) { + OpSignature fake_op_sig = { + .op = BuiltinOperator_SPARSE_TO_DENSE, + .inputs = CreateOpSignatureTensorSpecs( + std::vector{kTfLiteInt8, kTfLiteInt8, kTfLiteInt8}), + }; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3); + + fake_op_sig = { + .op = BuiltinOperator_SPARSE_TO_DENSE, + .inputs = CreateOpSignatureTensorSpecs( + std::vector{kTfLiteUInt8, kTfLiteUInt8, kTfLiteUInt8}), + }; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3); + + fake_op_sig = { + .op = BuiltinOperator_SPARSE_TO_DENSE, + .inputs = CreateOpSignatureTensorSpecs( + std::vector{kTfLiteInt64, kTfLiteInt64, kTfLiteInt64}), + }; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); + + fake_op_sig = { + .op = BuiltinOperator_SPARSE_TO_DENSE, + .inputs = CreateOpSignatureTensorSpecs( + std::vector{kTfLiteInt32, kTfLiteInt32, kTfLiteInt32}), + }; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1); +} + +// Test version for a simple Op with 2 versions and the input type controls the +// version. +void SimpleVersioningTest(BuiltinOperator op) { + OpSignature fake_op_sig = { + .op = op, + .inputs = CreateOpSignatureTensorSpecs(kTfLiteInt8), + }; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); + + fake_op_sig = { + .op = op, + .inputs = CreateOpSignatureTensorSpecs(kTfLiteUInt8), + }; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1); +} + +// Similar to SimpleVersioningTest function, but +// op has 3 versions and the input type includes kTfLiteInt16. +void SimpleVersioningTestExtended(BuiltinOperator op) { + OpSignature fake_op_sig = { + .op = op, + .inputs = CreateOpSignatureTensorSpecs(kTfLiteInt16), + }; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3); + + SimpleVersioningTest(op); +} + +// Test version for a simple Op with 2 versions and the output type controls the +void SimpleOutputVersioningTest(BuiltinOperator op) { + OpSignature fake_op_sig = { + .op = op, + .inputs = std::vector{}, + .outputs = CreateOpSignatureTensorSpecs(kTfLiteInt8), + }; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); + + fake_op_sig = { + .op = op, + .inputs = std::vector{}, + .outputs = CreateOpSignatureTensorSpecs(kTfLiteUInt8), + }; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1); +} + +TEST(OpVersionTest, VersioningEqualTest) { + SimpleVersioningTest(BuiltinOperator_EQUAL); + OpSignature fake_op_sig = { + .op = BuiltinOperator_EQUAL, + .inputs = CreateOpSignatureTensorSpecs(kTfLiteString), + }; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3); +} + +TEST(OpVersionTest, VersioningNotEqualTest) { + SimpleVersioningTest(BuiltinOperator_NOT_EQUAL); + OpSignature fake_op_sig = { + .op = BuiltinOperator_NOT_EQUAL, + .inputs = CreateOpSignatureTensorSpecs(kTfLiteString), + }; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3); +} + +TEST(OpVersionTest, VersioningLessTest) { + SimpleVersioningTest(BuiltinOperator_LESS); +} + +TEST(OpVersionTest, VersioningLessEqualTest) { + SimpleVersioningTest(BuiltinOperator_LESS_EQUAL); +} + +TEST(OpVersionTest, VersioningGreaterTest) { + SimpleVersioningTest(BuiltinOperator_GREATER); +} + +TEST(OpVersionTest, VersioningGreaterEqualTest) { + SimpleVersioningTest(BuiltinOperator_GREATER_EQUAL); +} + +TEST(OpVersionTest, VersioningSpaceToBatchNDTest) { + OpSignature fake_op_sig = { + .op = BuiltinOperator_SPACE_TO_BATCH_ND, + }; + fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt16, 3); + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 4); + fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt16, 4); + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 4); + fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt8, 3); + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3); + fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt8, 4); + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); + fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteUInt8, 3); + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3); + fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteUInt8, 4); + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1); +} + +TEST(OpVersionTest, VersioningLogSoftmaxTest) { + SimpleVersioningTest(BuiltinOperator_LOG_SOFTMAX); +} + +TEST(OpVersionTest, VersioningPackTest) { + OpSignature fake_op_sig = {}; + fake_op_sig.op = BuiltinOperator_PACK; + fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt8); + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); + + fake_op_sig = {}; + fake_op_sig.op = BuiltinOperator_PACK; + fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt16); + fake_op_sig.outputs = CreateOpSignatureTensorSpecs(kTfLiteInt16); + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3); + + fake_op_sig = {}; + fake_op_sig.op = BuiltinOperator_PACK; + fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteUInt32); + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 4); + + fake_op_sig = {}; + fake_op_sig.op = BuiltinOperator_PACK; + fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt32); + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1); +} + +TEST(OpVersionTest, VersioningUnpackTest) { + OpSignature fake_op_sig = { + .op = BuiltinOperator_UNPACK, + .inputs = CreateOpSignatureTensorSpecs(kTfLiteInt8), + }; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); + + fake_op_sig = { + .op = BuiltinOperator_UNPACK, + .inputs = CreateOpSignatureTensorSpecs(kTfLiteUInt8), + }; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); + + fake_op_sig = { + .op = BuiltinOperator_UNPACK, + .inputs = CreateOpSignatureTensorSpecs(kTfLiteInt32), + }; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1); +} + +TEST(OpVersionTest, VersioningRangeTest) { + OpSignature fake_op_sig = {}; + fake_op_sig.op = BuiltinOperator_RANGE; + fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt64); + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); + + fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt32); + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1); +} + +TEST(OpVersionTest, VersioningReluTest) { + OpSignature fake_op_sig = { + .op = BuiltinOperator_RELU, + .inputs = CreateOpSignatureTensorSpecs(kTfLiteInt16), + }; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3); + + fake_op_sig = { + .op = BuiltinOperator_RELU, + .inputs = CreateOpSignatureTensorSpecs(kTfLiteInt8), + }; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); + + fake_op_sig = { + .op = BuiltinOperator_RELU, + .inputs = CreateOpSignatureTensorSpecs(kTfLiteUInt8), + }; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); + + fake_op_sig = { + .op = BuiltinOperator_RELU, + .inputs = CreateOpSignatureTensorSpecs(kTfLiteInt32), + }; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1); +} + +TEST(OpVersionTest, VersioningBatchToSpaceNDTest) { + OpSignature fake_op_sig = { + .op = BuiltinOperator_BATCH_TO_SPACE_ND, + }; + fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt16, 3); + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 4); + fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt16, 4); + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 4); + fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt8, 3); + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3); + fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt8, 4); + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); + fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteUInt8, 3); + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3); + fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteUInt8, 4); + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1); +} + +TEST(OpVersionTest, VersioningTanhTest) { + SimpleVersioningTest(BuiltinOperator_TANH); +} + +TEST(OpVersionTest, VersioningStridedSliceTest) { + TfLiteStridedSliceParams strided_slice_params = {}; + OpSignature fake_op_sig = {}; + fake_op_sig.op = BuiltinOperator_STRIDED_SLICE; + fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt8); + fake_op_sig.outputs = CreateOpSignatureTensorSpecs(kTfLiteInt8); + fake_op_sig.builtin_data = reinterpret_cast(&strided_slice_params); + strided_slice_params.ellipsis_mask = 0; + strided_slice_params.new_axis_mask = 2; + fake_op_sig.ext_options.strided_slice.num_dims = 5; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 6); + + strided_slice_params.new_axis_mask = 0; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 4); + + fake_op_sig.ext_options.strided_slice.num_dims = 4; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); + + fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteUInt8); + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1); + + fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteUInt32); + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 7); + + strided_slice_params.offset = true; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 8); +} + +TEST(OpVersionTest, VersioningSpaceToDepthTest) { + SimpleVersioningTest(BuiltinOperator_SPACE_TO_DEPTH); +} + +TEST(OpVersionTest, VersioningSliceTest) { + OpSignature fake_op_sig = { + .op = BuiltinOperator_SLICE, + }; + fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt16, 5); + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 5); + + fake_op_sig = { + .op = BuiltinOperator_SLICE, + .inputs = CreateOpSignatureTensorSpecs(kTfLiteInt16), + }; + fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt16, 4); + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 4); + + fake_op_sig = { + .op = BuiltinOperator_SLICE, + }; + fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteString, 4); + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3); + + fake_op_sig = { + .op = BuiltinOperator_SLICE, + .inputs = CreateOpSignatureTensorSpecs(kTfLiteInt8), + }; + fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt8, 4); + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); + + fake_op_sig = { + .op = BuiltinOperator_SLICE, + }; + fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteUInt8, 4); + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1); + + fake_op_sig = {}; + fake_op_sig.op = BuiltinOperator_SLICE; + fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteUInt32, 4); + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 6); +} + +TEST(OpVersionTest, VersioningLogisticTest) { + SimpleVersioningTest(BuiltinOperator_SPACE_TO_DEPTH); +} + +TEST(OpVersionTest, VersioningL2NormTest) { + SimpleOutputVersioningTest(BuiltinOperator_L2_NORMALIZATION); +} + +TEST(OpVersionTest, VersioningMaxTest) { + OpSignature fake_op_sig = { + .op = BuiltinOperator_MAXIMUM, + }; + + fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt8, 4, 5); + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3); + fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt8, 5, 5); + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); + fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt8, 4, 4); + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); + + fake_op_sig = { + .op = BuiltinOperator_MAXIMUM, + }; + fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteUInt8, 4, 5); + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3); + fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteUInt8, 4, 4); + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1); +} + +TEST(OpVersionTest, VersioningMinTest) { + OpSignature fake_op_sig = { + .op = BuiltinOperator_MINIMUM, + }; + + fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt8, 4, 5); + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3); + fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt8, 5, 5); + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); + fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt8, 4, 4); + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); + + fake_op_sig = { + .op = BuiltinOperator_MINIMUM, + }; + fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteUInt8, 4, 5); + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3); + fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteUInt8, 4, 4); + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1); +} + +TEST(OpVersionTest, VersioningMeanTest) { + SimpleVersioningTestExtended(BuiltinOperator_MEAN); +} + +TEST(OpVersionTest, VersioningSumTest) { + SimpleVersioningTest(BuiltinOperator_SUM); +} + +TEST(OpVersionTest, VersioningReduceMinTest) { + SimpleVersioningTestExtended(BuiltinOperator_REDUCE_MIN); +} + +TEST(OpVersionTest, VersioningReduceMaxTest) { + SimpleVersioningTestExtended(BuiltinOperator_REDUCE_MAX); +} + +TEST(OpVersionTest, VersioningMirrorPadTest) { + SimpleVersioningTestExtended(BuiltinOperator_MIRROR_PAD); +} + +TEST(OpVersionTest, VersioningReduceProdTest) { + OpSignature fake_op_sig; + fake_op_sig.op = BuiltinOperator_REDUCE_PROD; + fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt16); + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); + + fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt8); + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); + + fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteFloat32); + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1); +} + +TEST(OpVersionTest, VersioningAddTest) { + TfLiteAddParams add_params = {}; + OpSignature fake_op_sig = { + .op = BuiltinOperator_ADD, + .inputs = CreateOpSignatureTensorSpecs(kTfLiteInt16), + .outputs = CreateOpSignatureTensorSpecs(kTfLiteInt16), + .builtin_data = reinterpret_cast(&add_params)}; + add_params.pot_scale_int16 = false; + fake_op_sig.ext_options.add.input_quantized = true; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3); + + fake_op_sig.ext_options.add.input_quantized = false; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 5); + + SimpleVersioningTest(BuiltinOperator_ADD); +} + +TEST(OpVersionTest, VersioningSubTest) { + TfLiteSubParams sub_params = {}; + OpSignature fake_op_sig = { + .op = BuiltinOperator_SUB, + .inputs = CreateOpSignatureTensorSpecs(kTfLiteInt16), + .outputs = CreateOpSignatureTensorSpecs(kTfLiteInt16), + .builtin_data = reinterpret_cast(&sub_params)}; + sub_params.pot_scale_int16 = false; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 5); + + fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt64); + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 4); + + fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt8, 4, 5); + fake_op_sig.outputs = CreateOpSignatureTensorSpecs(kTfLiteInt8); + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3); + + SimpleVersioningTest(BuiltinOperator_SUB); +} + +TEST(OpVersionTest, VersioningMUL7TestInt16) { + OpSignature fake_op_sig; + fake_op_sig.op = BuiltinOperator_MUL; + fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt16); + fake_op_sig.ext_options.mul.input_quantized = false; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 7); +} + +TEST(OpVersionTest, VersioningMUL7TestUInt32) { + OpSignature fake_op_sig; + fake_op_sig.op = BuiltinOperator_MUL; + fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteUInt32); + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 7); +} + +TEST(OpVersionTest, VersioningMUL6Test) { + OpSignature fake_op_sig; + fake_op_sig.op = BuiltinOperator_MUL; + fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteComplex64); + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 6); +} + +TEST(OpVersionTest, VersioningMUL5Test) { + OpSignature fake_op_sig; + fake_op_sig.op = BuiltinOperator_MUL; + fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt64); + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 5); +} + +TEST(OpVersionTest, VersioningSub4Test) { + OpSignature fake_op_sig = { + .op = BuiltinOperator_SUB, + .inputs = CreateOpSignatureTensorSpecs(kTfLiteInt64), + }; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 4); +} + +void SimpleMulVersioningTest(TfLiteType data_type, float multiplier, + int version) { + OpSignature fake_op_sig = { + .op = BuiltinOperator_MUL, + .inputs = CreateOpSignatureTensorSpecs( + std::vector{data_type, data_type}), + .outputs = CreateOpSignatureTensorSpecs(data_type), + }; + fake_op_sig.ext_options.mul = {1.0f, 1.0f, 1.0f / multiplier}; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), version); +} + +TEST(OpVersionTest, VersioningMulTest) { + SimpleMulVersioningTest(kTfLiteUInt8, 0.5f, 1); + SimpleMulVersioningTest(kTfLiteInt8, 0.5f, 2); + SimpleMulVersioningTest(kTfLiteInt8, 2.0f, 3); +} + +TEST(OpVersionTest, VersioningPadTest) { + SimpleVersioningTest(BuiltinOperator_PAD); +} + +TEST(OpVersionTest, VersioningPadV2Test) { + SimpleVersioningTest(BuiltinOperator_PADV2); +} + +TEST(OpVersionTest, VersioningConcatenationTest) { + OpSignature fake_op_sig = {}; + fake_op_sig.op = BuiltinOperator_CONCATENATION; + fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteFloat32); + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1); + + fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt8); + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); + + fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt16); + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3); + + fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteUInt32); + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 4); +} + +TEST(OpVersionTest, VersioningSelectTest) { + OpSignature fake_op_sig = {}; + fake_op_sig.op = BuiltinOperator_SELECT; + fake_op_sig.inputs = CreateOpSignatureTensorSpecs( + std::vector{kTfLiteUInt32, kTfLiteUInt32, kTfLiteUInt32}, 5); + fake_op_sig.outputs = CreateOpSignatureTensorSpecs(kTfLiteUInt32); + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 4); + + fake_op_sig = {}; + fake_op_sig.op = BuiltinOperator_SELECT; + fake_op_sig.inputs = CreateOpSignatureTensorSpecs( + std::vector{kTfLiteUInt8, kTfLiteUInt8, kTfLiteUInt8}, 5); + fake_op_sig.outputs = CreateOpSignatureTensorSpecs(kTfLiteUInt8); + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3); + + fake_op_sig = {}; + fake_op_sig.op = BuiltinOperator_SELECT; + fake_op_sig.inputs = CreateOpSignatureTensorSpecs( + std::vector{kTfLiteInt8, kTfLiteInt8, kTfLiteInt8}, 4); + fake_op_sig.outputs = CreateOpSignatureTensorSpecs(kTfLiteInt8); + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); + + fake_op_sig = {}; + fake_op_sig.op = BuiltinOperator_SELECT; + fake_op_sig.inputs = CreateOpSignatureTensorSpecs( + std::vector{kTfLiteFloat32, kTfLiteFloat32, kTfLiteFloat32}, + 4); + fake_op_sig.outputs = CreateOpSignatureTensorSpecs(kTfLiteFloat32); + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1); +} + +TEST(OpVersionTest, VersioningSelectV2Test) { + OpSignature fake_op_sig = {}; + fake_op_sig.op = BuiltinOperator_SELECT_V2; + fake_op_sig.inputs = CreateOpSignatureTensorSpecs( + std::vector{kTfLiteUInt32, kTfLiteUInt32, kTfLiteUInt32}, 5); + fake_op_sig.outputs = CreateOpSignatureTensorSpecs(kTfLiteUInt32); + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); + fake_op_sig = {}; + fake_op_sig.op = BuiltinOperator_SELECT_V2; + fake_op_sig.inputs = CreateOpSignatureTensorSpecs( + std::vector{kTfLiteInt32, kTfLiteInt32, kTfLiteInt32}, 5); + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1); +} + +TEST(OpVersionTest, VersioningRelu6Test) { + SimpleVersioningTestExtended(BuiltinOperator_RELU6); +} + +TEST(OpVersionTest, VersioningFullyConnectedTest) { + TfLiteFullyConnectedParams fully_connected_params = {}; + OpSignature fake_op_sig = { + .op = BuiltinOperator_FULLY_CONNECTED, + .inputs = CreateOpSignatureTensorSpecs( + std::vector{kTfLiteUInt8, kTfLiteUInt8}), + .outputs = CreateOpSignatureTensorSpecs(kTfLiteUInt8), + .builtin_data = reinterpret_cast(&fully_connected_params), + }; + fully_connected_params.weights_format = + kTfLiteFullyConnectedWeightsFormatShuffled4x16Int8; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 6); + + fake_op_sig = { + .op = BuiltinOperator_FULLY_CONNECTED, + .inputs = CreateOpSignatureTensorSpecs( + std::vector{kTfLiteInt8, kTfLiteInt8}), + .outputs = CreateOpSignatureTensorSpecs(kTfLiteInt8), + .builtin_data = reinterpret_cast(&fully_connected_params), + }; + fully_connected_params.weights_format = + kTfLiteFullyConnectedWeightsFormatShuffled4x16Int8; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 6); + + fake_op_sig = { + .op = BuiltinOperator_FULLY_CONNECTED, + .inputs = CreateOpSignatureTensorSpecs( + std::vector{kTfLiteInt8, kTfLiteInt8}), + .outputs = CreateOpSignatureTensorSpecs(kTfLiteInt8), + .builtin_data = reinterpret_cast(&fully_connected_params), + }; + fully_connected_params.weights_format = + kTfLiteFullyConnectedWeightsFormatDefault; + fake_op_sig.ext_options.fully_connected.sparse_weight = true; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 8); + + fake_op_sig = { + .op = BuiltinOperator_FULLY_CONNECTED, + .inputs = CreateOpSignatureTensorSpecs( + std::vector{kTfLiteFloat32, kTfLiteInt8, kTfLiteFloat32}), + .outputs = CreateOpSignatureTensorSpecs(kTfLiteFloat32), + .builtin_data = reinterpret_cast(&fully_connected_params), + }; + fully_connected_params.asymmetric_quantize_inputs = false; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3); + fully_connected_params.asymmetric_quantize_inputs = true; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 9); + + fake_op_sig = { + .op = BuiltinOperator_FULLY_CONNECTED, + .inputs = CreateOpSignatureTensorSpecs( + std::vector{kTfLiteInt16, kTfLiteInt8}), + .outputs = CreateOpSignatureTensorSpecs(kTfLiteInt16), + .builtin_data = reinterpret_cast(&fully_connected_params), + }; + fully_connected_params.quantized_bias_type = kTfLiteInt32; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 11); + + fake_op_sig = { + .op = BuiltinOperator_FULLY_CONNECTED, + .inputs = CreateOpSignatureTensorSpecs( + std::vector{kTfLiteFloat32, kTfLiteInt8}), + .outputs = CreateOpSignatureTensorSpecs(kTfLiteFloat32), + .builtin_data = reinterpret_cast(&fully_connected_params), + }; + fake_op_sig.ext_options.fully_connected.is_per_channel_quantized = true; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 12); +} + +TEST(OpVersionTest, VersioningDequantizeTest) { + OpSignature fake_op_sig = { + .op = BuiltinOperator_DEQUANTIZE, + .inputs = CreateOpSignatureTensorSpecs(kTfLiteInt16), + }; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3); + + fake_op_sig = { + .op = BuiltinOperator_DEQUANTIZE, + .inputs = CreateOpSignatureTensorSpecs(kTfLiteFloat16), + }; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3); + + fake_op_sig = { + .op = BuiltinOperator_DEQUANTIZE, + .inputs = CreateOpSignatureTensorSpecs(kTfLiteInt8), + }; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); + + fake_op_sig.ext_options.dequantize.is_per_channel_quantized = true; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 5); + + fake_op_sig = { + .op = BuiltinOperator_DEQUANTIZE, + .inputs = CreateOpSignatureTensorSpecs(kTfLiteFloat32), + }; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1); +} + +TEST(OpVersionTest, VersioningQuantizeTest) { + OpSignature fake_op_sig; + fake_op_sig.op = BuiltinOperator_QUANTIZE; + fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteFloat32); + fake_op_sig.ext_options.quantize.is_per_channel_quantized = false; + + fake_op_sig.outputs = CreateOpSignatureTensorSpecs(kTfLiteInt8); + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1); + + fake_op_sig.outputs = CreateOpSignatureTensorSpecs(kTfLiteUInt8); + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1); + + fake_op_sig.outputs = CreateOpSignatureTensorSpecs(kTfLiteInt16); + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); + + fake_op_sig.ext_options.quantize.is_per_channel_quantized = true; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3); +} + +TEST(OpVersionTest, VersioningConv2DTest) { + OpSignature fake_op_sig = { + .op = BuiltinOperator_CONV_2D, + .inputs = CreateOpSignatureTensorSpecs( + std::vector{kTfLiteUInt8, kTfLiteUInt8}), + .outputs = CreateOpSignatureTensorSpecs(kTfLiteUInt8), + }; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1); + + fake_op_sig = { + .op = BuiltinOperator_CONV_2D, + .inputs = CreateOpSignatureTensorSpecs( + std::vector{kTfLiteInt8, kTfLiteInt8}), + .outputs = CreateOpSignatureTensorSpecs(kTfLiteInt8), + }; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3); + + fake_op_sig = { + .op = BuiltinOperator_CONV_2D, + .inputs = CreateOpSignatureTensorSpecs( + std::vector{kTfLiteFloat32, kTfLiteInt8}), + .outputs = CreateOpSignatureTensorSpecs(kTfLiteFloat32), + }; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); + + fake_op_sig = { + .op = BuiltinOperator_CONV_2D, + .inputs = CreateOpSignatureTensorSpecs( + std::vector{kTfLiteFloat32, kTfLiteInt8}), + .outputs = CreateOpSignatureTensorSpecs(kTfLiteFloat32), + }; + fake_op_sig.ext_options.conv_2d.is_per_channel_quantized = true; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 5); + + fake_op_sig.op = BuiltinOperator_CONV_2D; + fake_op_sig.inputs = CreateOpSignatureTensorSpecs( + std::vector{kTfLiteFloat32, kTfLiteInt8}); + fake_op_sig.outputs = CreateOpSignatureTensorSpecs(kTfLiteFloat32); + fake_op_sig.ext_options.conv_2d.is_grouped_convolution = true; + + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 6); + + TfLiteConvParams conv_params = {}; + fake_op_sig = { + .op = BuiltinOperator_CONV_2D, + .inputs = CreateOpSignatureTensorSpecs( + std::vector{kTfLiteInt16, kTfLiteInt8}), + .outputs = CreateOpSignatureTensorSpecs(kTfLiteInt16), + .builtin_data = reinterpret_cast(&conv_params), + }; + conv_params.quantized_bias_type = kTfLiteInt32; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 8); +} + +TEST(OpVersionTest, VersioningFloorDivOperatorTest) { + OpSignature fake_op_sig = { + .op = BuiltinOperator_FLOOR_DIV, + .inputs = CreateOpSignatureTensorSpecs(kTfLiteInt32), + }; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1); + + fake_op_sig = { + .op = BuiltinOperator_FLOOR_DIV, + .inputs = CreateOpSignatureTensorSpecs(kTfLiteFloat32), + }; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); + + fake_op_sig = { + .op = BuiltinOperator_FLOOR_DIV, + .inputs = CreateOpSignatureTensorSpecs(kTfLiteInt16), + }; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3); +} + +TEST(OpVersionTest, VersioningFloorModOperatorTest) { + OpSignature fake_op_sig = { + .op = BuiltinOperator_FLOOR_MOD, + .inputs = CreateOpSignatureTensorSpecs(kTfLiteInt32), + }; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1); + + fake_op_sig = { + .op = BuiltinOperator_FLOOR_MOD, + .inputs = CreateOpSignatureTensorSpecs(kTfLiteInt16), + }; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); +} + +TEST(OpVersionTest, VersioningTransposeConvOperatorTest) { + OpSignature fake_op_sig = { + .op = BuiltinOperator_TRANSPOSE_CONV, + .inputs = CreateOpSignatureTensorSpecs( + std::vector{kTfLiteFloat32, kTfLiteUInt8}), + }; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1); + + fake_op_sig = { + .op = BuiltinOperator_TRANSPOSE_CONV, + .inputs = CreateOpSignatureTensorSpecs( + std::vector{kTfLiteInt32, kTfLiteInt8, kTfLiteInt8}), + }; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); + + fake_op_sig = { + .op = BuiltinOperator_TRANSPOSE_CONV, + .inputs = CreateOpSignatureTensorSpecs(std::vector{ + kTfLiteInt32, kTfLiteInt8, kTfLiteInt8, kTfLiteInt32}), + }; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3); + + const auto none_type = kTfLiteNoType; + fake_op_sig = { + .op = BuiltinOperator_TRANSPOSE_CONV, + .inputs = CreateOpSignatureTensorSpecs(std::vector{ + kTfLiteInt32, kTfLiteInt8, kTfLiteInt8, none_type}), + }; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); + + TfLiteTransposeConvParams transpose_conv_params = {}; + transpose_conv_params.activation = kTfLiteActRelu; + fake_op_sig = { + .op = BuiltinOperator_TRANSPOSE_CONV, + .inputs = CreateOpSignatureTensorSpecs(std::vector{ + kTfLiteInt32, kTfLiteInt8, kTfLiteInt8, none_type}), + .builtin_data = reinterpret_cast(&transpose_conv_params), + }; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 4); + + transpose_conv_params = {}; + fake_op_sig = { + .op = BuiltinOperator_TRANSPOSE_CONV, + .inputs = CreateOpSignatureTensorSpecs( + std::vector{kTfLiteInt16, kTfLiteInt8}), + .outputs = CreateOpSignatureTensorSpecs(kTfLiteInt16), + .builtin_data = reinterpret_cast(&transpose_conv_params), + }; + transpose_conv_params.quantized_bias_type = kTfLiteInt32; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 5); +} + +TEST(OpVersionTest, VersioningSVDFOperatorTest) { + TfLiteSVDFParams svdf_params = {}; + OpSignature fake_op_sig = { + .op = BuiltinOperator_SVDF, + .inputs = CreateOpSignatureTensorSpecs(std::vector{ + kTfLiteFloat32, kTfLiteFloat32, kTfLiteFloat32, kTfLiteFloat32, + kTfLiteFloat32}), + .outputs = CreateOpSignatureTensorSpecs(kTfLiteFloat32), + .builtin_data = reinterpret_cast(&svdf_params), + }; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1); + + fake_op_sig = { + .op = BuiltinOperator_SVDF, + .inputs = CreateOpSignatureTensorSpecs( + std::vector{kTfLiteFloat32, kTfLiteInt8, kTfLiteFloat32, + kTfLiteFloat32, kTfLiteFloat32}), + .outputs = CreateOpSignatureTensorSpecs(kTfLiteFloat32), + .builtin_data = reinterpret_cast(&svdf_params), + }; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); + svdf_params.asymmetric_quantize_inputs = true; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 4); + + svdf_params = {}; + fake_op_sig = { + .op = BuiltinOperator_SVDF, + .inputs = CreateOpSignatureTensorSpecs(std::vector{ + kTfLiteInt8, kTfLiteInt8, kTfLiteInt32, kTfLiteInt32, kTfLiteInt16}), + .outputs = CreateOpSignatureTensorSpecs(kTfLiteInt8), + .builtin_data = reinterpret_cast(&svdf_params), + }; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3); +} + +TEST(OpVersionTest, VersioningDepthwiseConv2DTest) { + TfLiteDepthwiseConvParams depthwise_conv_params = {}; + OpSignature fake_op_sig = { + .op = BuiltinOperator_DEPTHWISE_CONV_2D, + .inputs = CreateOpSignatureTensorSpecs( + std::vector{kTfLiteFloat32, kTfLiteInt8}), + .outputs = CreateOpSignatureTensorSpecs(kTfLiteFloat32), + .builtin_data = reinterpret_cast(&depthwise_conv_params), + }; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 4); + fake_op_sig.ext_options.depthwise_conv_2d.is_per_channel_quantized = true; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 6); + + depthwise_conv_params = {}; + fake_op_sig = { + .op = BuiltinOperator_DEPTHWISE_CONV_2D, + .inputs = CreateOpSignatureTensorSpecs( + std::vector{kTfLiteInt8, kTfLiteInt8}), + .outputs = CreateOpSignatureTensorSpecs(kTfLiteInt8), + .builtin_data = reinterpret_cast(&depthwise_conv_params), + }; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3); + + fake_op_sig = { + .op = BuiltinOperator_DEPTHWISE_CONV_2D, + .inputs = CreateOpSignatureTensorSpecs( + std::vector{kTfLiteFloat32, kTfLiteFloat32}), + .outputs = CreateOpSignatureTensorSpecs(kTfLiteFloat32), + .builtin_data = reinterpret_cast(&depthwise_conv_params), + }; + depthwise_conv_params.dilation_width_factor = 2; + depthwise_conv_params.dilation_height_factor = 2; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); + + fake_op_sig = { + .op = BuiltinOperator_DEPTHWISE_CONV_2D, + .inputs = CreateOpSignatureTensorSpecs( + std::vector{kTfLiteFloat32, kTfLiteFloat32}), + .outputs = CreateOpSignatureTensorSpecs(kTfLiteFloat32), + .builtin_data = reinterpret_cast(&depthwise_conv_params), + }; + depthwise_conv_params.dilation_width_factor = 1; + depthwise_conv_params.dilation_height_factor = 1; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1); +} +TEST(OpVersionTest, VersioningTileOperatorTest) { + OpSignature fake_op_sig = { + .op = BuiltinOperator_TILE, + .inputs = CreateOpSignatureTensorSpecs(kTfLiteInt32), + }; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1); + + fake_op_sig = { + .op = BuiltinOperator_TILE, + .inputs = CreateOpSignatureTensorSpecs(kTfLiteString), + }; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); +} +TEST(OpVersionTest, VersioningTransposeTest) { + OpSignature fake_op_sig = { + .op = BuiltinOperator_TRANSPOSE, + .inputs = CreateOpSignatureTensorSpecs(kTfLiteInt16), + }; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 5); + + fake_op_sig = { + .op = BuiltinOperator_TRANSPOSE, + }; + fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteBool, 5); + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 4); + fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteBool, 4); + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3); + + fake_op_sig = { + .op = BuiltinOperator_TRANSPOSE, + .inputs = CreateOpSignatureTensorSpecs(kTfLiteInt8), + }; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); + + fake_op_sig = { + .op = BuiltinOperator_TRANSPOSE, + .inputs = CreateOpSignatureTensorSpecs(kTfLiteUInt8), + }; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1); +} +TEST(OpVersionTest, VersioningGatherNdOperatorTest) { + OpSignature fake_op_sig = { + .op = BuiltinOperator_GATHER_ND, + .inputs = CreateOpSignatureTensorSpecs( + std::vector{kTfLiteInt32, kTfLiteInt32}), + }; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1); + + fake_op_sig = { + .op = BuiltinOperator_GATHER_ND, + .inputs = CreateOpSignatureTensorSpecs( + std::vector{kTfLiteString, kTfLiteInt32}), + }; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); + + fake_op_sig = { + .op = BuiltinOperator_GATHER_ND, + .inputs = CreateOpSignatureTensorSpecs( + std::vector{kTfLiteInt16, kTfLiteInt32}), + }; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3); + + fake_op_sig = { + .op = BuiltinOperator_GATHER_ND, + .inputs = CreateOpSignatureTensorSpecs( + std::vector{kTfLiteInt32, kTfLiteInt16}), + }; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 4); + + fake_op_sig = { + .op = BuiltinOperator_GATHER_ND, + .inputs = CreateOpSignatureTensorSpecs( + std::vector{kTfLiteBool, kTfLiteInt16}), + }; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 5); +} +TEST(OpVersionTest, VersioningDivTest) { + OpSignature fake_op_sig = { + .op = BuiltinOperator_DIV, + }; + fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteUInt8, 5, 4); + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); + fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteUInt8, 5, 5); + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1); + fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteUInt8, 4, 4); + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1); +} +TEST(OpVersionTEst, VersioningFillTest) { + OpSignature fake_op_sig = {BuiltinOperator_FILL}; + fake_op_sig.inputs = CreateOpSignatureTensorSpecs( + std::vector{kTfLiteInt32, kTfLiteFloat16}); + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 4); + fake_op_sig.inputs = CreateOpSignatureTensorSpecs( + std::vector{kTfLiteInt64, kTfLiteFloat16}); + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 4); + fake_op_sig.inputs = CreateOpSignatureTensorSpecs( + std::vector{kTfLiteInt32, kTfLiteInt8}); + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3); + fake_op_sig.inputs = CreateOpSignatureTensorSpecs( + std::vector{kTfLiteInt64, kTfLiteInt16}); + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3); + fake_op_sig.inputs = CreateOpSignatureTensorSpecs( + std::vector{kTfLiteInt32, kTfLiteBool}); + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); + fake_op_sig.inputs = CreateOpSignatureTensorSpecs( + std::vector{kTfLiteInt32, kTfLiteString}); + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); + fake_op_sig.inputs = CreateOpSignatureTensorSpecs( + std::vector{kTfLiteInt32, kTfLiteInt32}); + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1); +} +TEST(OpVersionTest, VersioningResizeBilinearTest) { + // Default. + TfLiteResizeBilinearParams resize_bilinear_params = {}; + OpSignature fake_op_sig = { + .op = BuiltinOperator_RESIZE_BILINEAR, + .inputs = CreateOpSignatureTensorSpecs( + std::vector{kTfLiteFloat32, kTfLiteInt32}), + .outputs = CreateOpSignatureTensorSpecs(kTfLiteFloat32), + .builtin_data = reinterpret_cast(&resize_bilinear_params), + }; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1); + + // align_corners=true is still version 1. + resize_bilinear_params.align_corners = true; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1); + + // half_pixel_centers=true must be version 3. + resize_bilinear_params.align_corners = false; + resize_bilinear_params.half_pixel_centers = true; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3); + + // int8 input is version 2. + resize_bilinear_params = {}; + fake_op_sig = { + .op = BuiltinOperator_RESIZE_BILINEAR, + .inputs = CreateOpSignatureTensorSpecs( + std::vector{kTfLiteInt8, kTfLiteInt32}), + .outputs = CreateOpSignatureTensorSpecs(kTfLiteInt8), + .builtin_data = reinterpret_cast(&resize_bilinear_params), + }; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); + + resize_bilinear_params.half_pixel_centers = true; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3); + + // int16 input is version 4. + resize_bilinear_params = {}; + fake_op_sig = { + .op = BuiltinOperator_RESIZE_BILINEAR, + .inputs = CreateOpSignatureTensorSpecs( + std::vector{kTfLiteInt16, kTfLiteInt32}), + .outputs = CreateOpSignatureTensorSpecs(kTfLiteInt16), + .builtin_data = reinterpret_cast(&resize_bilinear_params), + }; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 4); +} +TEST(OpVersionTest, VersioningResizeNearestNeighborTest) { + // Default. + TfLiteResizeNearestNeighborParams resize_nearest_neighbor_params = {}; + OpSignature fake_op_sig = { + .op = BuiltinOperator_RESIZE_NEAREST_NEIGHBOR, + .inputs = CreateOpSignatureTensorSpecs( + std::vector{kTfLiteFloat32, kTfLiteInt32}), + .outputs = CreateOpSignatureTensorSpecs(kTfLiteFloat32), + .builtin_data = reinterpret_cast(&resize_nearest_neighbor_params), + }; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1); + + // align_corners=true is version 3. + resize_nearest_neighbor_params.align_corners = true; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3); + + // half_pixel_centers=true must be version 3. + resize_nearest_neighbor_params.align_corners = false; + resize_nearest_neighbor_params.half_pixel_centers = true; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3); + + // int8 input is version 2. + resize_nearest_neighbor_params = {}; + fake_op_sig = { + .op = BuiltinOperator_RESIZE_NEAREST_NEIGHBOR, + .inputs = CreateOpSignatureTensorSpecs( + std::vector{kTfLiteInt8, kTfLiteInt32}), + .outputs = CreateOpSignatureTensorSpecs(kTfLiteInt8), + .builtin_data = reinterpret_cast(&resize_nearest_neighbor_params), + }; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); + + resize_nearest_neighbor_params.align_corners = true; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3); + + // int16 input is version 4. + resize_nearest_neighbor_params = {}; + fake_op_sig = { + .op = BuiltinOperator_RESIZE_NEAREST_NEIGHBOR, + .inputs = CreateOpSignatureTensorSpecs( + std::vector{kTfLiteInt16, kTfLiteInt32}), + .outputs = CreateOpSignatureTensorSpecs(kTfLiteInt16), + .builtin_data = reinterpret_cast(&resize_nearest_neighbor_params), + }; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 4); +} +TEST(OpVersionTest, VersioningAbsTest) { + // Default. + OpSignature fake_op_sig = { + .op = BuiltinOperator_ABS, + .inputs = CreateOpSignatureTensorSpecs(kTfLiteFloat32), + .outputs = CreateOpSignatureTensorSpecs(kTfLiteFloat32), + }; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1); + + // int8 input is version 2. + fake_op_sig = { + .op = BuiltinOperator_ABS, + .inputs = CreateOpSignatureTensorSpecs(kTfLiteInt8), + .outputs = CreateOpSignatureTensorSpecs(kTfLiteInt8), + }; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); + + // int16 quantized input is version 3. + fake_op_sig = { + .op = BuiltinOperator_ABS, + .inputs = CreateOpSignatureTensorSpecs(kTfLiteInt16), + .outputs = CreateOpSignatureTensorSpecs(kTfLiteInt16), + }; + fake_op_sig.ext_options.abs.input_quantized = true; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3); + // int16 non-quantized input is version 4. + fake_op_sig = { + .op = BuiltinOperator_ABS, + .inputs = CreateOpSignatureTensorSpecs(kTfLiteInt16), + .outputs = CreateOpSignatureTensorSpecs(kTfLiteInt16), + }; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 4); + fake_op_sig = {}; + fake_op_sig.op = BuiltinOperator_ABS; + fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt32); + fake_op_sig.outputs = CreateOpSignatureTensorSpecs(kTfLiteInt32); + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 5); +} +TEST(OpVersionTest, VersioningSignTest) { + // Default. + OpSignature fake_op_sig; + fake_op_sig.op = BuiltinOperator_SIGN; + fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteFloat32); + fake_op_sig.outputs = CreateOpSignatureTensorSpecs(kTfLiteFloat32); + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1); + + // int32 input is version 2. + fake_op_sig = {}; + fake_op_sig.op = BuiltinOperator_SIGN; + fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt32); + fake_op_sig.outputs = CreateOpSignatureTensorSpecs(kTfLiteInt32); + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); +} +TEST(OpVersionTest, VersioningBatchMatMulTest) { + // Default. + TfLiteBatchMatMulParams batch_mat_mul_params = {}; + OpSignature fake_op_sig = { + .op = BuiltinOperator_BATCH_MATMUL, + .inputs = CreateOpSignatureTensorSpecs( + std::vector{kTfLiteFloat32, kTfLiteFloat32}), + .outputs = CreateOpSignatureTensorSpecs(kTfLiteFloat32), + .builtin_data = reinterpret_cast(&batch_mat_mul_params), + }; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1); + + // int8 input is version 2. + batch_mat_mul_params = {}; + fake_op_sig = { + .op = BuiltinOperator_BATCH_MATMUL, + .inputs = CreateOpSignatureTensorSpecs( + std::vector{kTfLiteInt8, kTfLiteInt8}), + .outputs = CreateOpSignatureTensorSpecs(kTfLiteInt8), + .builtin_data = reinterpret_cast(&batch_mat_mul_params), + }; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); + + // int16 input is version 3. + fake_op_sig = { + .op = BuiltinOperator_BATCH_MATMUL, + .inputs = CreateOpSignatureTensorSpecs( + std::vector{kTfLiteInt16, kTfLiteInt8}), + .outputs = CreateOpSignatureTensorSpecs(kTfLiteInt16), + .builtin_data = reinterpret_cast(&batch_mat_mul_params), + }; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3); + + // Symmetric hybrid quantized input is version 1. + fake_op_sig = { + .op = BuiltinOperator_BATCH_MATMUL, + .inputs = CreateOpSignatureTensorSpecs( + std::vector{kTfLiteFloat32, kTfLiteInt8}), + .outputs = CreateOpSignatureTensorSpecs(kTfLiteFloat32), + .builtin_data = reinterpret_cast(&batch_mat_mul_params), + }; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1); + + // Asymmetric hybrid quantized input is version 4. + fake_op_sig = { + .op = BuiltinOperator_BATCH_MATMUL, + .inputs = CreateOpSignatureTensorSpecs( + std::vector{kTfLiteFloat32, kTfLiteInt8}), + .outputs = CreateOpSignatureTensorSpecs(kTfLiteFloat32), + .builtin_data = reinterpret_cast(&batch_mat_mul_params), + }; + batch_mat_mul_params.asymmetric_quantize_inputs = true; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 4); +} +TEST(OpVersionTest, VersioningSquaredDifferenceTest) { + // Default. + OpSignature fake_op_sig = { + .op = BuiltinOperator_SQUARED_DIFFERENCE, + .inputs = CreateOpSignatureTensorSpecs( + std::vector{kTfLiteFloat32, kTfLiteFloat32}), + .outputs = CreateOpSignatureTensorSpecs(kTfLiteFloat32), + }; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1); + + // int8 input is version 2. + fake_op_sig = { + .op = BuiltinOperator_SQUARED_DIFFERENCE, + .inputs = CreateOpSignatureTensorSpecs( + std::vector{kTfLiteInt8, kTfLiteInt8}), + .outputs = CreateOpSignatureTensorSpecs(kTfLiteInt8), + }; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); +} +TEST(OpVersionTest, VersioningRsqrtTest) { + OpSignature fake_op_sig = {}; + fake_op_sig.op = BuiltinOperator_RSQRT; + fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteFloat32); + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1); + + fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt8); + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); + + fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt16); + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3); +} +TEST(OpVersionTest, VersioningBroadcastToTest) { + OpSignature fake_op_sig = { + .op = BuiltinOperator_BROADCAST_TO, + .inputs = CreateOpSignatureTensorSpecs(kTfLiteFloat32), + .outputs = CreateOpSignatureTensorSpecs(kTfLiteFloat32), + }; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); + + // Quantized broadcast_to op is version 3. + fake_op_sig = { + .op = BuiltinOperator_BROADCAST_TO, + .inputs = CreateOpSignatureTensorSpecs(kTfLiteInt8), + .outputs = CreateOpSignatureTensorSpecs(kTfLiteInt8), + }; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3); + + fake_op_sig = { + .op = BuiltinOperator_BROADCAST_TO, + .inputs = CreateOpSignatureTensorSpecs(kTfLiteInt16), + .outputs = CreateOpSignatureTensorSpecs(kTfLiteInt16), + }; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3); +} + +TEST(OpVersionTest, VersioningGeluTest) { + OpSignature fake_op_sig; + fake_op_sig.op = BuiltinOperator_GELU; + fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteFloat32); + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1); + + fake_op_sig.op = BuiltinOperator_GELU; + fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt8); + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); + + fake_op_sig.op = BuiltinOperator_GELU; + fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteUInt8); + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); +} + +TEST(OpVersionTest, VersioningUnidirectionalLstmTest) { + TfLiteUnidirectionalSequenceLSTMParams params = {}; + OpSignature fake_op_sig = {}; + fake_op_sig.op = BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM; + fake_op_sig.inputs = CreateOpSignatureTensorSpecs( + std::vector{kTfLiteFloat32, kTfLiteFloat32, kTfLiteFloat32}); + fake_op_sig.outputs = CreateOpSignatureTensorSpecs(kTfLiteFloat32); + fake_op_sig.builtin_data = reinterpret_cast(¶ms); + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1); + + fake_op_sig.inputs = CreateOpSignatureTensorSpecs( + std::vector{kTfLiteFloat32, kTfLiteFloat32, kTfLiteInt8}); + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); + + params.asymmetric_quantize_inputs = true; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3); + + params.diagonal_recurrent_tensors = true; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 4); +} + +TEST(OpVersionTest, VersioningExpTest) { + OpSignature fake_op_sig = { + .op = BuiltinOperator_EXP, + .inputs = CreateOpSignatureTensorSpecs(kTfLiteFloat32), + }; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1); + fake_op_sig = { + .op = BuiltinOperator_EXP, + .inputs = CreateOpSignatureTensorSpecs(kTfLiteInt8), + }; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); + fake_op_sig = { + .op = BuiltinOperator_EXP, + .inputs = CreateOpSignatureTensorSpecs(kTfLiteInt16), + }; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); +} + +TEST(OpVersionTest, VersioningLogTest) { + OpSignature fake_op_sig = {}; + fake_op_sig.op = BuiltinOperator_LOG; + fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteFloat32); + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1); + + fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt8); + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); + + fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt16); + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); +} + +TEST(OpVersionTest, VersioningDynamicUpdateSliceTest) { + OpSignature fake_op_sig = {}; + fake_op_sig.op = BuiltinOperator_DYNAMIC_UPDATE_SLICE; + fake_op_sig.inputs = CreateOpSignatureTensorSpecs( + std::vector{kTfLiteFloat32, kTfLiteFloat32, kTfLiteInt32}); + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1); + + fake_op_sig.inputs = CreateOpSignatureTensorSpecs( + std::vector{kTfLiteFloat32, kTfLiteFloat32, kTfLiteInt64}); + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); +} +} // namespace tflite diff --git a/tensorflow/compiler/mlir/lite/tools/versioning/runtime_version.cc b/tensorflow/compiler/mlir/lite/tools/versioning/runtime_version.cc new file mode 100644 index 00000000000000..b3bbd7f3be3faa --- /dev/null +++ b/tensorflow/compiler/mlir/lite/tools/versioning/runtime_version.cc @@ -0,0 +1,510 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/compiler/mlir/lite/tools/versioning/runtime_version.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/log/log.h" +#include "absl/strings/numbers.h" +#include "absl/strings/str_split.h" +#include "tensorflow/compiler/mlir/lite/schema/mutable/schema_generated.h" +#include "tensorflow/compiler/mlir/lite/schema/schema_utils.h" + +namespace tflite { + +bool CompareRuntimeVersion(const std::string& v1, const std::string& v2) { + const std::vector vec1 = absl::StrSplit(v1, '.'); + const std::vector vec2 = absl::StrSplit(v2, '.'); + int i = 0; + while (i < vec1.size() && i < vec2.size()) { + int v1_val, v2_val; + if (absl::SimpleAtoi(vec1[i], &v1_val) && + absl::SimpleAtoi(vec2[i], &v2_val)) { + if (v1_val != v2_val) return v1_val < v2_val; + } + ++i; + } + // If there are remaining items in v2 not being compared, then v1 should + // precede v2. + return i < vec2.size(); +} + +std::string FindMinimumRuntimeVersionForOp(tflite::BuiltinOperator op_code, + int op_version) { + // A map from the version key of an op to its minimum runtime version. + // For example, {{kAveragePool, 1}, "1.5.0"}, means the 1st version of + // AveragePool requires a minimum TF Lite runtime version '1.5.0`. + // NOTE: When adding a new op version pair, associate it with the current + // runtime version defined in tensorflow/core/public/version.h. + static const std::map, + std::string>* op_version_map = + new std::map, std::string>( + {{{BuiltinOperator_AVERAGE_POOL_2D, 1}, "1.5.0"}, + {{BuiltinOperator_AVERAGE_POOL_2D, 2}, "1.14.0"}, + {{BuiltinOperator_AVERAGE_POOL_2D, 3}, "2.3.0"}, + {{BuiltinOperator_BATCH_MATMUL, 1}, "2.3.0"}, + {{BuiltinOperator_BATCH_MATMUL, 2}, "2.3.0"}, + {{BuiltinOperator_BATCH_MATMUL, 3}, "2.4.0"}, + {{BuiltinOperator_BATCH_MATMUL, 4}, "2.5.0"}, + // The version one of broadcast to op won't be not supported since + // the version one was rollbacked and the builtin op code number + // has been changed because of builtin op code shortage problem. + {{BuiltinOperator_BROADCAST_TO, 2}, "2.5.0"}, + {{BuiltinOperator_BROADCAST_TO, 3}, "2.5.0"}, + {{BuiltinOperator_CONV_2D, 1}, "1.5.0"}, + {{BuiltinOperator_CONV_2D, 2}, "1.14.0"}, + {{BuiltinOperator_CONV_2D, 3}, "1.14.0"}, + {{BuiltinOperator_CONV_2D, 4}, "2.3.0"}, + {{BuiltinOperator_CONV_2D, 5}, "2.4.0"}, + {{BuiltinOperator_CONV_2D, 6}, "2.9.0"}, + {{BuiltinOperator_CONV_2D, 7}, "2.11.0"}, + {{BuiltinOperator_CONV_2D, 8}, "2.15.0"}, + {{BuiltinOperator_DEPTHWISE_CONV_2D, 1}, "1.5.0"}, + {{BuiltinOperator_DEPTHWISE_CONV_2D, 2}, "1.12.0"}, + {{BuiltinOperator_DEPTHWISE_CONV_2D, 3}, "1.14.0"}, + {{BuiltinOperator_DEPTHWISE_CONV_2D, 4}, "2.2.0"}, + {{BuiltinOperator_DEPTHWISE_CONV_2D, 5}, "2.3.0"}, + {{BuiltinOperator_DEPTHWISE_CONV_2D, 6}, "2.3.0"}, + {{BuiltinOperator_DEPTHWISE_CONV_2D, 7}, "2.11.0"}, + {{BuiltinOperator_ADD, 1}, "1.5.0"}, + {{BuiltinOperator_ADD, 2}, "1.14.0"}, + {{BuiltinOperator_ADD, 3}, "2.4.0"}, + {{BuiltinOperator_ADD, 4}, "2.6.0"}, + {{BuiltinOperator_ADD, 5}, "2.13.0"}, + {{BuiltinOperator_ADD_N, 1}, "1.14.0"}, + {{BuiltinOperator_SPACE_TO_BATCH_ND, 1}, "1.6.0"}, + {{BuiltinOperator_SPACE_TO_BATCH_ND, 2}, "1.14.0"}, + {{BuiltinOperator_SPACE_TO_BATCH_ND, 3}, "2.3.0"}, + {{BuiltinOperator_SPACE_TO_BATCH_ND, 4}, "2.12.0"}, + {{BuiltinOperator_SUB, 1}, "1.6.0"}, + {{BuiltinOperator_SUB, 2}, "1.14.0"}, + {{BuiltinOperator_SUB, 3}, "2.3.0"}, + {{BuiltinOperator_SUB, 4}, "2.4.0"}, + {{BuiltinOperator_SUB, 5}, "2.4.0"}, + {{BuiltinOperator_DENSIFY, 1}, "2.2.0"}, + {{BuiltinOperator_DIV, 1}, "1.6.0"}, + {{BuiltinOperator_DIV, 2}, "2.3.0"}, + {{BuiltinOperator_BATCH_TO_SPACE_ND, 1}, "1.6.0"}, + {{BuiltinOperator_BATCH_TO_SPACE_ND, 2}, "1.14.0"}, + {{BuiltinOperator_BATCH_TO_SPACE_ND, 3}, "2.3.0"}, + {{BuiltinOperator_BATCH_TO_SPACE_ND, 4}, "2.12.0"}, + {{BuiltinOperator_CAST, 1}, "1.5.0"}, + {{BuiltinOperator_CAST, 2}, "2.7.0"}, + {{BuiltinOperator_CAST, 3}, "2.8.0"}, + {{BuiltinOperator_CAST, 4}, "2.9.0"}, + {{BuiltinOperator_CAST, 5}, "2.12.0"}, + {{BuiltinOperator_CAST, 6}, "2.15.0"}, + {{BuiltinOperator_CONCATENATION, 1}, "1.5.0"}, + {{BuiltinOperator_CONCATENATION, 2}, "1.14.0"}, + {{BuiltinOperator_CONCATENATION, 3}, "2.3.0"}, + {{BuiltinOperator_CONCATENATION, 4}, "2.14.0"}, + {{BuiltinOperator_DEPTH_TO_SPACE, 1}, "2.1.0"}, + {{BuiltinOperator_DEPTH_TO_SPACE, 2}, "2.5.0"}, + {{BuiltinOperator_EMBEDDING_LOOKUP, 1}, "1.13.0"}, + {{BuiltinOperator_EMBEDDING_LOOKUP, 2}, "1.14.0"}, + {{BuiltinOperator_EMBEDDING_LOOKUP, 3}, "1.14.0"}, + {{BuiltinOperator_EMBEDDING_LOOKUP, 4}, "2.18.0"}, + {{BuiltinOperator_EMBEDDING_LOOKUP_SPARSE, 1}, "1.5.0"}, + {{BuiltinOperator_FAKE_QUANT, 1}, "1.5.0"}, + {{BuiltinOperator_FAKE_QUANT, 2}, "1.10.0"}, + {{BuiltinOperator_FULLY_CONNECTED, 1}, "1.5.0"}, + {{BuiltinOperator_FULLY_CONNECTED, 2}, "1.10.0"}, + {{BuiltinOperator_FULLY_CONNECTED, 3}, "1.14.0"}, + {{BuiltinOperator_FULLY_CONNECTED, 4}, "1.14.0"}, + {{BuiltinOperator_FULLY_CONNECTED, 5}, "2.0.0"}, + {{BuiltinOperator_FULLY_CONNECTED, 6}, "2.1.0"}, + {{BuiltinOperator_FULLY_CONNECTED, 7}, "2.3.0"}, + {{BuiltinOperator_FULLY_CONNECTED, 8}, "2.3.0"}, + {{BuiltinOperator_FULLY_CONNECTED, 9}, "2.3.0"}, + {{BuiltinOperator_FULLY_CONNECTED, 10}, "2.11.0"}, + {{BuiltinOperator_FULLY_CONNECTED, 11}, "2.15.0"}, + {{BuiltinOperator_FULLY_CONNECTED, 12}, "2.17.0"}, + {{BuiltinOperator_FULLY_CONNECTED, 13}, "2.18.0"}, + {{BuiltinOperator_GATHER, 1}, "1.6.0"}, + {{BuiltinOperator_GATHER, 2}, "1.14.0"}, + {{BuiltinOperator_GATHER, 3}, "1.15.0"}, + {{BuiltinOperator_GATHER, 4}, "2.4.0"}, + {{BuiltinOperator_GATHER, 5}, "2.5.0"}, + {{BuiltinOperator_GATHER, 6}, "2.13.0"}, + {{BuiltinOperator_GATHER, 7}, "2.15.0"}, + {{BuiltinOperator_GATHER_ND, 1}, "1.14.0"}, + {{BuiltinOperator_GATHER_ND, 2}, "2.3.0"}, + {{BuiltinOperator_GATHER_ND, 3}, "2.5.0"}, + {{BuiltinOperator_GATHER_ND, 4}, "2.13.0"}, + {{BuiltinOperator_GATHER_ND, 5}, "2.16.0"}, + {{BuiltinOperator_HASHTABLE_LOOKUP, 1}, "1.5.0"}, + {{BuiltinOperator_SVDF, 1}, "1.5.0"}, + {{BuiltinOperator_SVDF, 2}, "1.14.0"}, + {{BuiltinOperator_SVDF, 3}, "2.2.0"}, + {{BuiltinOperator_SVDF, 4}, "2.3.0"}, + {{BuiltinOperator_L2_NORMALIZATION, 1}, "1.5.0"}, + {{BuiltinOperator_L2_NORMALIZATION, 2}, "1.14.0"}, + {{BuiltinOperator_L2_POOL_2D, 1}, "1.5.0"}, + {{BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION, 1}, "1.5.0"}, + {{BuiltinOperator_MAX_POOL_2D, 1}, "1.5.0"}, + {{BuiltinOperator_MAX_POOL_2D, 2}, "1.14.0"}, + {{BuiltinOperator_MAX_POOL_2D, 3}, "2.3.0"}, + {{BuiltinOperator_MAXIMUM, 1}, "1.14.0"}, + {{BuiltinOperator_MAXIMUM, 2}, "1.14.0"}, + {{BuiltinOperator_MAXIMUM, 3}, "2.3.0"}, + {{BuiltinOperator_MAXIMUM, 4}, "2.3.0"}, + {{BuiltinOperator_MINIMUM, 1}, "1.14.0"}, + {{BuiltinOperator_MINIMUM, 2}, "1.14.0"}, + {{BuiltinOperator_MINIMUM, 3}, "2.3.0"}, + {{BuiltinOperator_MINIMUM, 4}, "2.3.0"}, + {{BuiltinOperator_MUL, 1}, "1.5.0"}, + {{BuiltinOperator_MUL, 2}, "1.14.0"}, + {{BuiltinOperator_MUL, 3}, "1.15.0"}, + {{BuiltinOperator_MUL, 4}, "2.3.0"}, + {{BuiltinOperator_MUL, 5}, "2.6.0"}, + {{BuiltinOperator_MUL, 6}, "2.11.0"}, + {{BuiltinOperator_MUL, 7}, "2.13.0"}, + {{BuiltinOperator_NON_MAX_SUPPRESSION_V4, 1}, "2.1.0"}, + {{BuiltinOperator_NON_MAX_SUPPRESSION_V5, 1}, "2.1.0"}, + {{BuiltinOperator_PAD, 1}, "1.5.0"}, + {{BuiltinOperator_PAD, 2}, "1.14.0"}, + {{BuiltinOperator_PAD, 3}, "2.4.0"}, + {{BuiltinOperator_PAD, 4}, "2.6.0"}, + {{BuiltinOperator_TILE, 1}, "1.10.1"}, + {{BuiltinOperator_TILE, 2}, "2.2.0"}, + {{BuiltinOperator_TILE, 3}, "2.8.0"}, + {{BuiltinOperator_PADV2, 1}, "1.9.0"}, + {{BuiltinOperator_PADV2, 2}, "1.14.0"}, + {{BuiltinOperator_PADV2, 3}, "2.4.0"}, + {{BuiltinOperator_PADV2, 4}, "2.6.0"}, + {{BuiltinOperator_RESHAPE, 1}, "1.5.0"}, + {{BuiltinOperator_SOFTMAX, 1}, "1.5.0"}, + {{BuiltinOperator_SOFTMAX, 2}, "1.14.0"}, + {{BuiltinOperator_SOFTMAX, 3}, "2.3.0"}, + {{BuiltinOperator_SPACE_TO_DEPTH, 1}, "1.5.0"}, + {{BuiltinOperator_SPACE_TO_DEPTH, 2}, "1.14.0"}, + {{BuiltinOperator_TRANSPOSE, 1}, "1.6.0"}, + {{BuiltinOperator_TRANSPOSE, 2}, "1.14.0"}, + {{BuiltinOperator_TRANSPOSE, 3}, "1.15.0"}, + {{BuiltinOperator_TRANSPOSE, 4}, "2.3.0"}, + {{BuiltinOperator_TRANSPOSE, 5}, "2.4.0"}, + {{BuiltinOperator_TRANSPOSE, 6}, "2.12.0"}, + {{BuiltinOperator_LSTM, 1}, "1.7.0"}, + {{BuiltinOperator_LSTM, 2}, "1.10.0"}, + {{BuiltinOperator_LSTM, 3}, "1.14.0"}, + {{BuiltinOperator_LSTM, 4}, "2.3.0"}, + {{BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM, 1}, "1.13.1"}, + {{BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM, 2}, "1.14.0"}, + {{BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM, 3}, "2.3.0"}, + {{BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM, 4}, "2.12.0"}, + {{BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM, 1}, "1.14.0"}, + {{BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM, 2}, "1.14.0"}, + {{BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM, 3}, "1.14.0"}, + {{BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN, 1}, "1.14.0"}, + {{BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN, 2}, "1.14.0"}, + {{BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN, 3}, "2.3.0"}, + {{BuiltinOperator_MEAN, 1}, "1.6.0"}, + {{BuiltinOperator_MEAN, 2}, "1.14.0"}, + {{BuiltinOperator_MEAN, 3}, "2.4.0"}, + {{BuiltinOperator_SUM, 1}, "1.10.0"}, + {{BuiltinOperator_SUM, 2}, "1.15.0"}, + {{BuiltinOperator_REDUCE_MAX, 1}, "1.11.0"}, + {{BuiltinOperator_REDUCE_MAX, 2}, "1.14.0"}, + {{BuiltinOperator_REDUCE_MAX, 3}, "2.5.0"}, + {{BuiltinOperator_REDUCE_MIN, 1}, "1.11.0"}, + {{BuiltinOperator_REDUCE_MIN, 2}, "1.14.0"}, + {{BuiltinOperator_REDUCE_MIN, 3}, "2.5.0"}, + {{BuiltinOperator_REDUCE_PROD, 1}, "1.11.0"}, + {{BuiltinOperator_REDUCE_PROD, 2}, "2.6.0"}, + {{BuiltinOperator_REDUCE_ANY, 1}, "1.11.0"}, + {{BuiltinOperator_RELU6, 1}, "1.5.0"}, + {{BuiltinOperator_RELU6, 2}, "1.14.0"}, + {{BuiltinOperator_RELU6, 3}, "2.5.0"}, + {{BuiltinOperator_RESIZE_BILINEAR, 1}, "1.7.0"}, + {{BuiltinOperator_RESIZE_BILINEAR, 2}, "1.14.0"}, + {{BuiltinOperator_RESIZE_BILINEAR, 3}, "2.2.0"}, + {{BuiltinOperator_RESIZE_BILINEAR, 4}, "2.5.0"}, + {{BuiltinOperator_RESIZE_NEAREST_NEIGHBOR, 1}, "1.13.1"}, + {{BuiltinOperator_RESIZE_NEAREST_NEIGHBOR, 2}, "1.14.0"}, + {{BuiltinOperator_RESIZE_NEAREST_NEIGHBOR, 3}, "2.3.0"}, + {{BuiltinOperator_RESIZE_NEAREST_NEIGHBOR, 4}, "2.4.0"}, + {{BuiltinOperator_RNN, 1}, "1.5.0"}, + {{BuiltinOperator_RNN, 2}, "1.14.0"}, + {{BuiltinOperator_RNN, 3}, "2.3.0"}, + {{BuiltinOperator_SKIP_GRAM, 1}, "1.5.0"}, + {{BuiltinOperator_SQUEEZE, 1}, "1.6.0"}, + {{BuiltinOperator_SQUEEZE, 2}, "2.5.0"}, + {{BuiltinOperator_SPLIT, 1}, "1.5.0"}, + {{BuiltinOperator_SPLIT, 2}, "1.14.0"}, + {{BuiltinOperator_SPLIT, 3}, "1.14.0"}, + {{BuiltinOperator_SPLIT, 4}, "2.3.0"}, + {{BuiltinOperator_SPLIT_V, 1}, "1.13.1"}, + {{BuiltinOperator_SPLIT_V, 2}, "2.3.0"}, + {{BuiltinOperator_STRIDED_SLICE, 1}, "1.6.0"}, + {{BuiltinOperator_STRIDED_SLICE, 2}, "1.14.0"}, + {{BuiltinOperator_STRIDED_SLICE, 3}, "2.1.0"}, + {{BuiltinOperator_STRIDED_SLICE, 4}, "2.2.0"}, + {{BuiltinOperator_STRIDED_SLICE, 5}, "2.5.0"}, + {{BuiltinOperator_STRIDED_SLICE, 6}, "2.6.0"}, + {{BuiltinOperator_STRIDED_SLICE, 7}, "2.14.0"}, + {{BuiltinOperator_STRIDED_SLICE, 8}, "2.14.0"}, + {{BuiltinOperator_TOPK_V2, 1}, "1.7.0"}, + {{BuiltinOperator_TOPK_V2, 2}, "1.14.0"}, + {{BuiltinOperator_TOPK_V2, 3}, "2.13.0"}, + {{BuiltinOperator_ARG_MAX, 1}, "1.9.0"}, + {{BuiltinOperator_ARG_MAX, 2}, "1.14.0"}, + {{BuiltinOperator_ARG_MAX, 3}, "2.9.0"}, + {{BuiltinOperator_ARG_MIN, 1}, "1.9.0"}, + {{BuiltinOperator_ARG_MIN, 2}, "1.14.0"}, + {{BuiltinOperator_ARG_MIN, 3}, "2.9.0"}, + {{BuiltinOperator_TRANSPOSE_CONV, 1}, "1.9.0"}, + {{BuiltinOperator_TRANSPOSE_CONV, 2}, "2.2.0"}, + {{BuiltinOperator_TRANSPOSE_CONV, 3}, "2.3.0"}, + {{BuiltinOperator_TRANSPOSE_CONV, 4}, "2.13.0"}, + {{BuiltinOperator_TRANSPOSE_CONV, 5}, "2.15.0"}, + {{BuiltinOperator_SPARSE_TO_DENSE, 1}, "1.9.0"}, + {{BuiltinOperator_SPARSE_TO_DENSE, 2}, "1.14.0"}, + {{BuiltinOperator_SPARSE_TO_DENSE, 3}, "1.15.0"}, + {{BuiltinOperator_EXPAND_DIMS, 1}, "1.10.0"}, + {{BuiltinOperator_PACK, 1}, "1.11.0"}, + {{BuiltinOperator_PACK, 2}, "1.14.0"}, + {{BuiltinOperator_PACK, 3}, "2.3.0"}, + {{BuiltinOperator_PACK, 4}, "2.13.0"}, + {{BuiltinOperator_SHAPE, 1}, "1.10.0"}, + {{BuiltinOperator_SLICE, 1}, "1.14.0"}, + {{BuiltinOperator_SLICE, 2}, "1.14.0"}, + {{BuiltinOperator_SLICE, 3}, "1.14.0"}, + {{BuiltinOperator_SLICE, 4}, "2.4.0"}, + {{BuiltinOperator_SLICE, 5}, "2.5.0"}, + {{BuiltinOperator_SLICE, 6}, "2.14.0"}, + {{BuiltinOperator_TANH, 1}, "1.14.0"}, + {{BuiltinOperator_TANH, 2}, "1.14.0"}, + {{BuiltinOperator_TANH, 3}, "2.3.0"}, + {{BuiltinOperator_ONE_HOT, 1}, "1.11.0"}, + {{BuiltinOperator_UNPACK, 1}, "1.11.0"}, + {{BuiltinOperator_UNPACK, 2}, "1.14.0"}, + {{BuiltinOperator_UNPACK, 3}, "2.2.0"}, + {{BuiltinOperator_UNPACK, 4}, "2.3.0"}, + {{BuiltinOperator_LEAKY_RELU, 1}, "1.13.1"}, + {{BuiltinOperator_LEAKY_RELU, 2}, "2.3.0"}, + {{BuiltinOperator_LOGISTIC, 1}, "1.14.0"}, + {{BuiltinOperator_LOGISTIC, 2}, "1.14.0"}, + {{BuiltinOperator_LOGISTIC, 3}, "2.3.0"}, + {{BuiltinOperator_LOG_SOFTMAX, 1}, "1.14.0"}, + {{BuiltinOperator_LOG_SOFTMAX, 2}, "1.14.0"}, + {{BuiltinOperator_LSH_PROJECTION, 1}, "1.5.0"}, + {{BuiltinOperator_SQUARED_DIFFERENCE, 1}, "1.13.1"}, + {{BuiltinOperator_SQUARED_DIFFERENCE, 2}, "2.5.0"}, + {{BuiltinOperator_MIRROR_PAD, 1}, "1.13.1"}, + {{BuiltinOperator_MIRROR_PAD, 2}, "2.3.0"}, + {{BuiltinOperator_MIRROR_PAD, 3}, "2.12.0"}, + {{BuiltinOperator_UNIQUE, 1}, "1.14.0"}, + {{BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN, 1}, "1.14.0"}, + {{BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN, 2}, "1.14.0"}, + {{BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN, 3}, "2.3.0"}, + {{BuiltinOperator_WHERE, 1}, "1.14.0"}, + {{BuiltinOperator_DEQUANTIZE, 1}, "1.13.1"}, + {{BuiltinOperator_DEQUANTIZE, 2}, "1.14.0"}, + {{BuiltinOperator_DEQUANTIZE, 3}, "1.15.0"}, + {{BuiltinOperator_DEQUANTIZE, 4}, "2.2.0"}, + {{BuiltinOperator_DEQUANTIZE, 5}, "2.7.0"}, + {{BuiltinOperator_DEQUANTIZE, 6}, "2.18.0"}, + {{BuiltinOperator_REVERSE_SEQUENCE, 1}, "1.14.0"}, + {{BuiltinOperator_EQUAL, 1}, "1.14.0"}, + {{BuiltinOperator_EQUAL, 2}, "1.14.0"}, + {{BuiltinOperator_EQUAL, 3}, "2.3.0"}, + {{BuiltinOperator_EQUAL, 4}, "2.13.0"}, + {{BuiltinOperator_NOT_EQUAL, 1}, "1.14.0"}, + {{BuiltinOperator_NOT_EQUAL, 2}, "1.14.0"}, + {{BuiltinOperator_NOT_EQUAL, 3}, "2.3.0"}, + {{BuiltinOperator_GREATER, 1}, "1.14.0"}, + {{BuiltinOperator_GREATER, 2}, "1.14.0"}, + {{BuiltinOperator_GREATER_EQUAL, 1}, "1.14.0"}, + {{BuiltinOperator_GREATER_EQUAL, 2}, "1.14.0"}, + {{BuiltinOperator_GREATER_EQUAL, 3}, "2.13.0"}, + {{BuiltinOperator_LESS, 1}, "1.14.0"}, + {{BuiltinOperator_LESS, 2}, "1.14.0"}, + {{BuiltinOperator_LESS, 3}, "2.13.0"}, + {{BuiltinOperator_LESS_EQUAL, 1}, "1.14.0"}, + {{BuiltinOperator_LESS_EQUAL, 2}, "1.14.0"}, + {{BuiltinOperator_SCATTER_ND, 1}, "2.1.0"}, + {{BuiltinOperator_SEGMENT_SUM, 1}, "2.2.0"}, + {{BuiltinOperator_SELECT, 1}, "1.14.0"}, + {{BuiltinOperator_SELECT, 2}, "1.14.0"}, + {{BuiltinOperator_SELECT, 3}, "2.12.0"}, + {{BuiltinOperator_SELECT, 4}, "2.12.0"}, + {{BuiltinOperator_SELECT_V2, 1}, "2.2.0"}, + {{BuiltinOperator_SELECT_V2, 2}, "2.12.0"}, + {{BuiltinOperator_IF, 1}, "1.15.0"}, + {{BuiltinOperator_FLOOR_DIV, 1}, "1.14.0"}, + {{BuiltinOperator_FLOOR_DIV, 2}, "1.14.0"}, + {{BuiltinOperator_FLOOR_DIV, 3}, "2.13.0"}, + {{BuiltinOperator_FLOOR, 1}, "1.9.0"}, + {{BuiltinOperator_CEIL, 1}, "1.14.0"}, + {{BuiltinOperator_MATRIX_DIAG, 1}, "1.14.0"}, + {{BuiltinOperator_MATRIX_SET_DIAG, 1}, "1.14.0"}, + {{BuiltinOperator_ELU, 1}, "1.14.0"}, + {{BuiltinOperator_QUANTIZE, 1}, "1.14.0"}, + {{BuiltinOperator_QUANTIZE, 2}, "1.15.0"}, + {{BuiltinOperator_QUANTIZE, 3}, "2.7.0"}, + {{BuiltinOperator_ROUND, 1}, "1.14.0"}, + {{BuiltinOperator_RELU, 1}, "1.5.0"}, + {{BuiltinOperator_RELU, 2}, "2.1.0"}, + {{BuiltinOperator_RELU, 3}, "2.5.0"}, + {{BuiltinOperator_RELU_N1_TO_1, 1}, "1.5.0"}, + {{BuiltinOperator_RELU_0_TO_1, 1}, "2.10.0"}, + {{BuiltinOperator_PRELU, 1}, "1.8.0"}, + {{BuiltinOperator_EXP, 1}, "1.7.0"}, + {{BuiltinOperator_EXP, 2}, "2.12.0"}, + {{BuiltinOperator_COS, 1}, "1.14.0"}, + {{BuiltinOperator_NEG, 1}, "1.9.0"}, + {{BuiltinOperator_POW, 1}, "1.10.0"}, + {{BuiltinOperator_LOGICAL_OR, 1}, "1.11.0"}, + {{BuiltinOperator_LOGICAL_AND, 1}, "1.11.0"}, + {{BuiltinOperator_LOGICAL_NOT, 1}, "1.11.0"}, + {{BuiltinOperator_FLOOR_MOD, 1}, "1.13.0"}, + {{BuiltinOperator_FLOOR_MOD, 2}, "2.13.0"}, + {{BuiltinOperator_RANGE, 1}, "1.13.0"}, + {{BuiltinOperator_RANGE, 2}, "2.14.0"}, + {{BuiltinOperator_SIN, 1}, "1.9.0"}, + {{BuiltinOperator_LOG, 1}, "1.14.0"}, + {{BuiltinOperator_LOG, 2}, "2.15.0"}, + {{BuiltinOperator_SQRT, 1}, "1.10.0"}, + {{BuiltinOperator_RSQRT, 1}, "1.10.0"}, + {{BuiltinOperator_RSQRT, 2}, "2.5.0"}, + {{BuiltinOperator_RSQRT, 3}, "2.15.0"}, + {{BuiltinOperator_SQUARE, 1}, "1.12.0"}, + {{BuiltinOperator_ZEROS_LIKE, 1}, "1.12.0"}, + {{BuiltinOperator_ABS, 1}, "1.13.0"}, + {{BuiltinOperator_ABS, 2}, "2.4.0"}, + {{BuiltinOperator_ABS, 3}, "2.5.0"}, + {{BuiltinOperator_ABS, 4}, "2.6.0"}, + {{BuiltinOperator_ABS, 5}, "2.12.0"}, + {{BuiltinOperator_HARD_SWISH, 1}, "1.15.0"}, + {{BuiltinOperator_FILL, 1}, "1.13.0"}, + {{BuiltinOperator_FILL, 2}, "2.3.0"}, + {{BuiltinOperator_FILL, 3}, "2.5.0"}, + {{BuiltinOperator_FILL, 4}, "2.12.0"}, + {{BuiltinOperator_REVERSE_V2, 1}, "1.14.0"}, + {{BuiltinOperator_REVERSE_V2, 2}, "2.2.0"}, + {{BuiltinOperator_REVERSE_V2, 3}, "2.5.0"}, + {{BuiltinOperator_RANK, 1}, "1.14.0"}, + {{BuiltinOperator_WHILE, 1}, "1.15.0"}, + {{BuiltinOperator_CUMSUM, 1}, "2.4.0"}, + {{BuiltinOperator_CALL_ONCE, 1}, "2.5.0"}, + {{BuiltinOperator_RFFT2D, 1}, "2.5.0"}, + {{BuiltinOperator_CONV_3D, 1}, "2.5.0"}, + {{BuiltinOperator_IMAG, 1}, "2.5.0"}, + {{BuiltinOperator_REAL, 1}, "2.5.0"}, + {{BuiltinOperator_COMPLEX_ABS, 1}, "2.5.0"}, + {{BuiltinOperator_HASHTABLE, 1}, "2.5.0"}, + {{BuiltinOperator_HASHTABLE_FIND, 1}, "2.5.0"}, + {{BuiltinOperator_HASHTABLE_IMPORT, 1}, "2.5.0"}, + {{BuiltinOperator_HASHTABLE_SIZE, 1}, "2.5.0"}, + {{BuiltinOperator_REDUCE_ALL, 1}, "2.6.0"}, + {{BuiltinOperator_CONV_3D_TRANSPOSE, 1}, "2.6.0"}, + {{BuiltinOperator_VAR_HANDLE, 1}, "2.6.0"}, + {{BuiltinOperator_READ_VARIABLE, 1}, "2.6.0"}, + {{BuiltinOperator_ASSIGN_VARIABLE, 1}, "2.6.0"}, + {{BuiltinOperator_BROADCAST_ARGS, 1}, "2.6.0"}, + {{BuiltinOperator_RANDOM_STANDARD_NORMAL, 1}, "2.8.0"}, + {{BuiltinOperator_BUCKETIZE, 1}, "2.8.0"}, + {{BuiltinOperator_WHERE, 2}, "2.8.0"}, + {{BuiltinOperator_RANDOM_UNIFORM, 1}, "2.8.0"}, + {{BuiltinOperator_MULTINOMIAL, 1}, "2.8.0"}, + {{BuiltinOperator_GELU, 1}, "2.9.0"}, + {{BuiltinOperator_GELU, 2}, "2.9.0"}, + {{BuiltinOperator_DYNAMIC_UPDATE_SLICE, 1}, "2.9.0"}, + {{BuiltinOperator_DYNAMIC_UPDATE_SLICE, 2}, "2.17.0"}, + {{BuiltinOperator_UNSORTED_SEGMENT_PROD, 1}, "2.10.0"}, + {{BuiltinOperator_UNSORTED_SEGMENT_MAX, 1}, "2.10.0"}, + {{BuiltinOperator_UNSORTED_SEGMENT_MIN, 1}, "2.11.0"}, + {{BuiltinOperator_UNSORTED_SEGMENT_SUM, 1}, "2.10.0"}, + {{BuiltinOperator_ATAN2, 1}, "2.10.0"}, + {{BuiltinOperator_SIGN, 1}, "2.11.0"}, + {{BuiltinOperator_SIGN, 2}, "2.12.0"}, + {{BuiltinOperator_BITCAST, 1}, "2.13.0"}, + {{BuiltinOperator_BITWISE_XOR, 1}, "2.13.0"}, + {{BuiltinOperator_RIGHT_SHIFT, 1}, "2.13.0"}, + {{BuiltinOperator_STABLEHLO_SCATTER, 1}, "2.15.0"}, + {{BuiltinOperator_DILATE, 1}, "2.15.0"}, + {{BuiltinOperator_STABLEHLO_RNG_BIT_GENERATOR, 1}, "2.15.0"}, + {{BuiltinOperator_REDUCE_WINDOW, 1}, "2.15.0"}, + {{BuiltinOperator_STABLEHLO_GATHER, 1}, "2.16.0"}, + {{BuiltinOperator_STABLEHLO_ADD, 1}, "2.16.0"}, + {{BuiltinOperator_STABLEHLO_MULTIPLY, 1}, "2.16.0"}, + {{BuiltinOperator_STABLEHLO_REDUCE_WINDOW, 1}, "2.16.0"}, + {{BuiltinOperator_STABLEHLO_MAXIMUM, 1}, "2.16.0"}, + {{BuiltinOperator_STABLEHLO_MINIMUM, 1}, "2.16.0"}, + {{BuiltinOperator_STABLEHLO_PAD, 1}, "2.16.0"}, + {{BuiltinOperator_STABLEHLO_COMPOSITE, 1}, "2.17.0"}, + {{BuiltinOperator_STABLEHLO_AND, 1}, "2.17.0"}, + {{BuiltinOperator_STABLEHLO_SHIFT_LEFT, 1}, "2.17.0"}, + {{BuiltinOperator_STABLEHLO_CBRT, 1}, "2.17.0"}}); + + std::pair version_key = {op_code, op_version}; + auto it = op_version_map->find(version_key); + if (it == op_version_map->end()) { + return std::string(); + } + return it->second; +} + +void UpdateMinimumRuntimeVersionForModel(uint8_t* model_buffer_pointer) { + auto model = GetMutableModel(model_buffer_pointer); + std::string model_min_version; + auto subgraphs = model->subgraphs(); + for (int i = 0; i < subgraphs->size(); ++i) { + const SubGraph* subgraph = subgraphs->Get(i); + for (int j = 0; j < subgraph->operators()->size(); ++j) { + const Operator* op = subgraph->operators()->Get(j); + const OperatorCode* op_code = + model->operator_codes()->Get(op->opcode_index()); + std::string runtime_version = FindMinimumRuntimeVersionForOp( + GetBuiltinCode(op_code), op_code->version()); + // If we didn't find the current op version in the map, skip comparison. + if (runtime_version.empty()) { + continue; + } + if (CompareRuntimeVersion(model_min_version, runtime_version)) { + // Current min model runtime version should be bumped if we see a + // higher op version. + model_min_version = runtime_version; + } + } + } + // The size of the `min_runtime_version` metadata buffer is 16 bytes. If the + // generated `model_min_version` is equal or longer than 16 bytes, print a + // warning message and return. + if (model_min_version.size() >= 16) { + LOG(WARNING) << "Skip writing minimum runtime version string since it's " + << "longer than 16 bytes."; + return; + } + // Copy over the bytes from `model_min_version` into the buffer. + for (int i = 0; i < model->metadata()->size(); ++i) { + if (model->metadata()->Get(i)->name()->str() == "min_runtime_version") { + auto buffer = model->metadata()->Get(i)->buffer(); + auto buffer_data = + model->mutable_buffers()->GetMutableObject(buffer)->mutable_data(); + memset(buffer_data->data(), 0, buffer_data->size()); + memcpy(buffer_data->data(), model_min_version.data(), + model_min_version.size()); + break; + } + } +} + +} // namespace tflite diff --git a/tensorflow/compiler/mlir/lite/tools/versioning/runtime_version.h b/tensorflow/compiler/mlir/lite/tools/versioning/runtime_version.h new file mode 100644 index 00000000000000..7d586df5ab4c00 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/tools/versioning/runtime_version.h @@ -0,0 +1,40 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_TOOLS_VERSIONING_RUNTIME_VERSION_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_TOOLS_VERSIONING_RUNTIME_VERSION_H_ + +#include +#include + +#include "flatbuffers/flatbuffers.h" // from @flatbuffers // IWYU pragma: keep +#include "tensorflow/compiler/mlir/lite/schema/mutable/schema_generated.h" + +namespace tflite { +// Update minimum runtime version of the given TFL flatbuffer model. +void UpdateMinimumRuntimeVersionForModel(uint8_t* model_buffer_pointer); + +// Find the minimum runtime version of a given op version. Return an empty +// string the version is not registered. +std::string FindMinimumRuntimeVersionForOp(tflite::BuiltinOperator op_code, + int op_version); + +// Returns true if the first version string precedes the second. +// For example, '1.9' should precede '1.14', also '1.14' should precede +// '1.14.1'. If two version string is equal, then false will be returned. +bool CompareRuntimeVersion(const std::string&, const std::string&); + +} // namespace tflite + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_TOOLS_VERSIONING_RUNTIME_VERSION_H_ diff --git a/tensorflow/lite/CMakeLists.txt b/tensorflow/lite/CMakeLists.txt index bce9627fbd3381..3a66e9f0874a7a 100644 --- a/tensorflow/lite/CMakeLists.txt +++ b/tensorflow/lite/CMakeLists.txt @@ -386,6 +386,7 @@ if(TFLITE_ENABLE_GPU) ${TFLITE_DELEGATES_GPU_COMMON_TRANSFORMATIONS_SRCS} ${TFLITE_SOURCE_DIR}/tools/versioning/gpu_compatibility.cc ${TFLITE_SOURCE_DIR}/tools/versioning/op_signature.cc + ${TF_SOURCE_DIR}/compiler/mlir/lite/tools/versioning/op_signature.cc ) include_directories( AFTER @@ -684,6 +685,7 @@ set(_ALL_TFLITE_SRCS ${TF_SOURCE_DIR}/compiler/mlir/lite/core/model_builder_base.cc ${TF_SOURCE_DIR}/compiler/mlir/lite/core/api/error_reporter.h ${TF_SOURCE_DIR}/compiler/mlir/lite/core/api/error_reporter.cc + ${TF_SOURCE_DIR}/compiler/mlir/lite/core/api/flatbuffer_conversions.cc ${TF_SOURCE_DIR}/compiler/mlir/lite/core/api/verifier.h ${TF_SOURCE_DIR}/compiler/mlir/lite/allocation.h ${TF_SOURCE_DIR}/compiler/mlir/lite/allocation.cc diff --git a/tensorflow/lite/kernels/CMakeLists.txt b/tensorflow/lite/kernels/CMakeLists.txt index 946b56353e6c15..b63f7f5b6cc8d6 100644 --- a/tensorflow/lite/kernels/CMakeLists.txt +++ b/tensorflow/lite/kernels/CMakeLists.txt @@ -154,6 +154,7 @@ target_link_libraries(tensorflow-lite-test-external-main tensorflow-lite-test-base -Wl,--no-whole-archive gtest + absl::log ) macro(add_kernel_test TEST_SRC TEST_LIB) diff --git a/tensorflow/lite/toco/tflite/BUILD b/tensorflow/lite/toco/tflite/BUILD index 7377ec00d6b666..cfcb613719baca 100644 --- a/tensorflow/lite/toco/tflite/BUILD +++ b/tensorflow/lite/toco/tflite/BUILD @@ -27,6 +27,8 @@ cc_library( deps = [ ":types", "//tensorflow/compiler/mlir/lite/delegates/flex:allowlisted_flex_ops_lib", + "//tensorflow/compiler/mlir/lite/tools/versioning", + "//tensorflow/compiler/mlir/lite/tools/versioning:op_signature", "//tensorflow/core:framework", "//tensorflow/core:protos_all_cc", "//tensorflow/lite/c:c_api_types", @@ -36,8 +38,6 @@ cc_library( "//tensorflow/lite/toco:model", "//tensorflow/lite/toco:runtime", "//tensorflow/lite/toco:toco_port", - "//tensorflow/lite/tools/versioning", - "//tensorflow/lite/tools/versioning:op_signature", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/memory", @@ -114,6 +114,7 @@ cc_library( ":types", "//tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy:quantize_weights", "//tensorflow/compiler/mlir/lite/schema:schema_conversion_utils", + "//tensorflow/compiler/mlir/lite/tools/versioning", "//tensorflow/core:lib_proto_parsing", "//tensorflow/core/platform:status", "//tensorflow/lite:schema_fbs_version", @@ -123,7 +124,6 @@ cc_library( "//tensorflow/lite/toco:model", "//tensorflow/lite/toco:toco_port", "//tensorflow/lite/toco:tooling_util", - "//tensorflow/lite/tools/versioning", "@com_google_absl//absl/log", "@com_google_absl//absl/strings", "@flatbuffers", diff --git a/tensorflow/lite/toco/tflite/builtin_operator.h b/tensorflow/lite/toco/tflite/builtin_operator.h index 69a1d4e5970b0b..fa4c1df9b9c50c 100644 --- a/tensorflow/lite/toco/tflite/builtin_operator.h +++ b/tensorflow/lite/toco/tflite/builtin_operator.h @@ -18,6 +18,7 @@ limitations under the License. #include #include "absl/memory/memory.h" +#include "tensorflow/compiler/mlir/lite/tools/versioning/op_version.h" #include "tensorflow/lite/toco/tflite/operator.h" namespace toco { diff --git a/tensorflow/lite/toco/tflite/export.cc b/tensorflow/lite/toco/tflite/export.cc index 44223eac63c130..a1ca936464e95a 100644 --- a/tensorflow/lite/toco/tflite/export.cc +++ b/tensorflow/lite/toco/tflite/export.cc @@ -25,6 +25,7 @@ limitations under the License. #include "flatbuffers/string.h" // from @flatbuffers #include "tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/quantize_weights.h" #include "tensorflow/compiler/mlir/lite/schema/schema_conversion_utils.h" +#include "tensorflow/compiler/mlir/lite/tools/versioning/runtime_version.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/lite/schema/schema_generated.h" @@ -33,7 +34,6 @@ limitations under the License. #include "tensorflow/lite/toco/tflite/types.h" #include "tensorflow/lite/toco/toco_types.h" #include "tensorflow/lite/toco/tooling_util.h" -#include "tensorflow/lite/tools/versioning/runtime_version.h" #include "tensorflow/lite/util.h" #include "tensorflow/lite/version.h" diff --git a/tensorflow/lite/toco/tflite/operator.cc b/tensorflow/lite/toco/tflite/operator.cc index 428afa0a9076e2..c73e30781faf09 100644 --- a/tensorflow/lite/toco/tflite/operator.cc +++ b/tensorflow/lite/toco/tflite/operator.cc @@ -31,6 +31,8 @@ limitations under the License. // graph_transformation module. #include "tensorflow/compiler/mlir/lite/delegates/flex/allowlisted_flex_ops.h" +#include "tensorflow/compiler/mlir/lite/tools/versioning/op_signature.h" +#include "tensorflow/compiler/mlir/lite/tools/versioning/op_version.h" #include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/core/c/builtin_op_data.h" #include "tensorflow/lite/schema/schema_generated.h" @@ -42,8 +44,6 @@ limitations under the License. #include "tensorflow/lite/toco/tflite/simple_operator.h" #include "tensorflow/lite/toco/tflite/types.h" #include "tensorflow/lite/toco/toco_types.h" -#include "tensorflow/lite/tools/versioning/op_signature.h" -#include "tensorflow/lite/tools/versioning/op_version.h" namespace toco { diff --git a/tensorflow/lite/toco/tflite/operator.h b/tensorflow/lite/toco/tflite/operator.h index 7b8b6b64e21e83..836c287674e084 100644 --- a/tensorflow/lite/toco/tflite/operator.h +++ b/tensorflow/lite/toco/tflite/operator.h @@ -22,10 +22,9 @@ limitations under the License. #include "flatbuffers/buffer.h" // from @flatbuffers #include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers #include "flatbuffers/vector.h" // from @flatbuffers +#include "tensorflow/compiler/mlir/lite/tools/versioning/op_signature.h" #include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/toco/model.h" -#include "tensorflow/lite/tools/versioning/op_signature.h" -#include "tensorflow/lite/tools/versioning/op_version.h" namespace toco { diff --git a/tensorflow/lite/toco/tflite/simple_operator.h b/tensorflow/lite/toco/tflite/simple_operator.h index 150b0d0721706e..7f26ee2eaac339 100644 --- a/tensorflow/lite/toco/tflite/simple_operator.h +++ b/tensorflow/lite/toco/tflite/simple_operator.h @@ -15,6 +15,7 @@ limitations under the License. #ifndef TENSORFLOW_LITE_TOCO_TFLITE_SIMPLE_OPERATOR_H_ #define TENSORFLOW_LITE_TOCO_TFLITE_SIMPLE_OPERATOR_H_ +#include "tensorflow/compiler/mlir/lite/tools/versioning/op_version.h" #include "tensorflow/lite/toco/tflite/operator.h" namespace toco { diff --git a/tensorflow/lite/tools/benchmark/CMakeLists.txt b/tensorflow/lite/tools/benchmark/CMakeLists.txt index 477254a42ecb9e..79986f6c0ecd29 100644 --- a/tensorflow/lite/tools/benchmark/CMakeLists.txt +++ b/tensorflow/lite/tools/benchmark/CMakeLists.txt @@ -52,6 +52,7 @@ list(APPEND TFLITE_BENCHMARK_LIBS example_proto model_runtime_info_proto protobuf::libprotobuf + absl::log ) # TODO(b/171007016): Enable performance options on Windows. diff --git a/tensorflow/lite/tools/versioning/BUILD b/tensorflow/lite/tools/versioning/BUILD index f173ce2c89734b..c0945a5c7c0fa6 100644 --- a/tensorflow/lite/tools/versioning/BUILD +++ b/tensorflow/lite/tools/versioning/BUILD @@ -66,6 +66,7 @@ cc_library( ], compatible_with = get_compatible_with_portable(), deps = [ + "//tensorflow/compiler/mlir/lite/tools/versioning:op_signature", "//tensorflow/lite:stderr_reporter", "//tensorflow/lite/core/api", "//tensorflow/lite/core/c:c_api_types", diff --git a/tensorflow/lite/tools/versioning/op_signature.cc b/tensorflow/lite/tools/versioning/op_signature.cc index 64b97924b47352..19373155956f29 100644 --- a/tensorflow/lite/tools/versioning/op_signature.cc +++ b/tensorflow/lite/tools/versioning/op_signature.cc @@ -14,88 +14,18 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/lite/tools/versioning/op_signature.h" -#include +#include +#include +#include -#include "tensorflow/lite/core/api/flatbuffer_conversions.h" -#include "tensorflow/lite/kernels/internal/compatibility.h" +#include "tensorflow/compiler/mlir/lite/tools/versioning/op_signature.h" +#include "tensorflow/lite/core/c/common.h" #include "tensorflow/lite/kernels/kernel_util.h" #include "tensorflow/lite/schema/schema_generated.h" -#include "tensorflow/lite/schema/schema_utils.h" -#include "tensorflow/lite/stderr_reporter.h" namespace tflite { namespace { -// A BuiltinDataAllocator which just uses malloc()/free(). -class MallocDataAllocator : public BuiltinDataAllocator { - public: - void* Allocate(size_t size, size_t alignment_hint) override { - return malloc(size); - } - void Deallocate(void* data) override { free(data); } -}; - -// Get the number of dimensions of a tensor with idx of an operator op. -inline int GetNumDims(const SubGraph* subgraph, const Operator* op, int idx) { - const flatbuffers::Vector* ret = - subgraph->tensors()->Get(op->inputs()->Get(idx))->shape(); - if (ret) { - return ret->size(); - } else { - return 0; - } -} - -std::vector GetOpSignatureTensorSpecs( - const flatbuffers::Vector* tensors, const SubGraph* subgraph, - const Model* model) { - std::vector tensor_specs; - if (!tensors) { - return tensor_specs; - } - StderrReporter error_reporter; - - for (int32_t i = 0; i < tensors->Length(); ++i) { - int32_t tensor_no = tensors->Get(i); - - OpSignatureTensorSpec tensor_spec = {kTfLiteNoType}; - if (tensor_no >= 0) { - if (subgraph->tensors() && tensor_no < subgraph->tensors()->Length()) { - auto* fb_tensor = subgraph->tensors()->Get(tensor_no); - ConvertTensorType(fb_tensor->type(), &tensor_spec.type, - &error_reporter); - auto buffer_idx = fb_tensor->buffer(); - // Check if the tensor is a constant tensor. - if (buffer_idx != 0 && buffer_idx < model->buffers()->Length()) { - auto* buffer = model->buffers()->Get(buffer_idx); - if (buffer->data() && buffer->data()->size() != 0) { - tensor_spec.is_const = true; - } - } - const flatbuffers::Vector* shape_vec = fb_tensor->shape(); - if (shape_vec) { - for (int32_t j = 0; j < shape_vec->Length(); ++j) { - tensor_spec.dims.push_back(shape_vec->Get(j)); - } - } - const flatbuffers::Vector* shape_signature_vec = - fb_tensor->shape_signature(); - tensor_spec.is_shape_dynamic = false; - if (shape_signature_vec) { - for (int32_t j = 0; j < shape_signature_vec->Length(); ++j) { - if (shape_signature_vec->Get(j) == -1) { - tensor_spec.is_shape_dynamic = true; - break; - } - } - } - } - } - tensor_specs.push_back(tensor_spec); - } - return tensor_specs; -} - std::vector GetOpSignatureTensorSpecs( TfLiteIntArray* tensors, const TfLiteContext* context, const TfLiteNode* tflite_node) { @@ -130,167 +60,6 @@ std::vector GetOpSignatureTensorSpecs( } // namespace -OpSignature GetOpSignature(const OperatorCode* op_code, const Operator* op, - const SubGraph* subgraph, const Model* model) { - auto builtin_code = GetBuiltinCode(op_code); - OpSignature op_sig = {builtin_code}; - std::memset(&op_sig.ext_options, 0, sizeof(op_sig.ext_options)); - - if (builtin_code != BuiltinOperator_CUSTOM) { - StderrReporter error_reporter; - MallocDataAllocator allocator; - ParseOpData(op, builtin_code, &error_reporter, &allocator, - &op_sig.builtin_data); - } else { - op_sig.custom_name = op_code->custom_code()->str(); - } - - switch (builtin_code) { - case BuiltinOperator_DEPTHWISE_CONV_2D: { - const Tensor* filter_tensor = - subgraph->tensors()->Get(op->inputs()->Get(1)); - const QuantizationParameters* filter_quant = - filter_tensor->quantization(); - int num_channels = filter_tensor->shape()->Get(3); - if (filter_quant && filter_quant->scale() && - filter_quant->scale()->Length() && - filter_quant->scale()->Length() == num_channels) { - op_sig.ext_options.depthwise_conv_2d.is_per_channel_quantized = true; - } - } break; - - case BuiltinOperator_FULLY_CONNECTED: { - const Tensor* weight_tensor = - subgraph->tensors()->Get(op->inputs()->Get(1)); - op_sig.ext_options.fully_connected.sparse_weight = - (weight_tensor->sparsity() != nullptr); - const QuantizationParameters* weight_quant = - weight_tensor->quantization(); - if (weight_quant && weight_quant->scale() && - weight_quant->scale()->size() && weight_tensor->shape() && - weight_tensor->shape()->size()) { - op_sig.ext_options.fully_connected.is_per_channel_quantized = - weight_quant->scale()->size() > 1 && - weight_quant->scale()->size() == weight_tensor->shape()->Get(0); - } - } break; - - case BuiltinOperator_MUL: { - if (op->inputs()->Length() < 2 || op->outputs()->Length() < 1) { - break; - } - const Tensor* input1_tensor = - subgraph->tensors()->Get(op->inputs()->Get(0)); - const Tensor* input2_tensor = - subgraph->tensors()->Get(op->inputs()->Get(1)); - const Tensor* output_tensor = - subgraph->tensors()->Get(op->outputs()->Get(0)); - const QuantizationParameters* input1_quant = - input1_tensor->quantization(); - const QuantizationParameters* input2_qunt = input2_tensor->quantization(); - const QuantizationParameters* output_quant = - output_tensor->quantization(); - if (input1_quant && input1_quant->scale() && - input1_quant->scale()->Length() && input2_qunt && - input2_qunt->scale() && input2_qunt->scale()->Length() && - output_quant && output_quant->scale() && - output_quant->scale()->Length()) { - op_sig.ext_options.mul.input1_scale = input1_quant->scale()->Get(0); - op_sig.ext_options.mul.input2_scale = input2_qunt->scale()->Get(0); - op_sig.ext_options.mul.output_scale = output_quant->scale()->Get(0); - } - if (input1_quant || input2_qunt) { - op_sig.ext_options.mul.input_quantized = true; - } - } break; - - case BuiltinOperator_CONV_2D: { - const Tensor* input_tensor = - subgraph->tensors()->Get(op->inputs()->Get(0)); - const Tensor* filter_tensor = - subgraph->tensors()->Get(op->inputs()->Get(1)); - const QuantizationParameters* filter_quant = - filter_tensor->quantization(); - int num_filters = filter_tensor->shape()->Get(0); - if (filter_quant && filter_quant->scale() && - filter_quant->scale()->Length() && - filter_quant->scale()->Length() == num_filters) { - op_sig.ext_options.conv_2d.is_per_channel_quantized = true; - } - if (input_tensor->shape() && input_tensor->shape()->size()) { - int num_input_channels = input_tensor->shape()->Get(3); - int num_filter_input_channels = filter_tensor->shape()->Get(3); - op_sig.ext_options.conv_2d.is_grouped_convolution = - num_input_channels != num_filter_input_channels; - } else { - op_sig.ext_options.conv_2d.is_grouped_convolution = false; - } - } break; - - case BuiltinOperator_STRIDED_SLICE: { - op_sig.ext_options.strided_slice.num_dims = GetNumDims(subgraph, op, 0); - } break; - - case BuiltinOperator_ABS: { - if (subgraph->tensors()->Get(op->inputs()->Get(0))->quantization()) { - op_sig.ext_options.abs.input_quantized = true; - } - } break; - - case BuiltinOperator_DEQUANTIZE: { - const Tensor* input_tensor = - subgraph->tensors()->Get(op->inputs()->Get(0)); - const QuantizationParameters* input_quant = input_tensor->quantization(); - if (input_quant && input_quant->scale() && - input_quant->scale()->Length() > 1 && - input_quant->scale()->Length() == - input_tensor->shape()->Get(input_quant->quantized_dimension())) { - op_sig.ext_options.dequantize.is_per_channel_quantized = true; - } - } break; - - case BuiltinOperator_QUANTIZE: { - const Tensor* output_tensor = - subgraph->tensors()->Get(op->outputs()->Get(0)); - const QuantizationParameters* output_quant = - output_tensor->quantization(); - if (output_quant && output_quant->scale() && - output_quant->scale()->Length() > 1 && - output_quant->scale()->Length() == - output_tensor->shape()->Get( - output_quant->quantized_dimension())) { - op_sig.ext_options.quantize.is_per_channel_quantized = true; - } - } break; - - case BuiltinOperator_ADD: { - if (subgraph->tensors()->Get(op->inputs()->Get(0))->quantization()) { - op_sig.ext_options.add.input_quantized = true; - } - } break; - - case BuiltinOperator_EMBEDDING_LOOKUP: { - const Tensor* table_tensor = - subgraph->tensors()->Get(op->inputs()->Get(1)); - const QuantizationParameters* table_quant = table_tensor->quantization(); - if (table_quant && table_quant->scale() && table_quant->scale()->size() && - table_tensor->shape() && table_tensor->shape()->size()) { - op_sig.ext_options.embedding_lookup.is_per_channel_quantized = - table_quant->scale()->size() > 1 && - table_quant->scale()->size() == table_tensor->shape()->Get(0); - } - } break; - - default: - break; - } - - op_sig.inputs = GetOpSignatureTensorSpecs(op->inputs(), subgraph, model); - op_sig.outputs = GetOpSignatureTensorSpecs(op->outputs(), subgraph, model); - op_sig.version = op_code->version(); - return op_sig; -} - OpSignature GetOpSignature(const TfLiteContext* context, const TfLiteNode* node, const TfLiteRegistration* registration) { OpSignature op_sig = { diff --git a/tensorflow/lite/tools/versioning/op_signature.h b/tensorflow/lite/tools/versioning/op_signature.h index 6f83d119d5938f..b2dd3086c2d0d4 100644 --- a/tensorflow/lite/tools/versioning/op_signature.h +++ b/tensorflow/lite/tools/versioning/op_signature.h @@ -15,83 +15,11 @@ limitations under the License. #ifndef TENSORFLOW_LITE_TOOLS_VERSIONING_OP_SIGNATURE_H_ #define TENSORFLOW_LITE_TOOLS_VERSIONING_OP_SIGNATURE_H_ -#include -#include - -#include "tensorflow/lite/core/c/c_api_types.h" +#include "tensorflow/compiler/mlir/lite/tools/versioning/op_signature.h" // iwyu pragma: export #include "tensorflow/lite/core/c/common.h" -#include "tensorflow/lite/schema/schema_generated.h" namespace tflite { -// OpSignature contains operator parameters for version functions. -typedef struct { - TfLiteType type; - std::vector dims; - bool is_const; - bool is_shape_dynamic; -} OpSignatureTensorSpec; - -typedef struct { - BuiltinOperator op; - std::vector inputs; - std::vector outputs; - void* builtin_data; - int version; - const void* custom_initial_data; - std::string custom_name; - union { - struct { - bool is_per_channel_quantized; - bool is_grouped_convolution; - } conv_2d; - struct { - bool is_per_channel_quantized; - } depthwise_conv_2d; - struct { - // TODO(b/156530611): Make this global when more ops support sparse - // computation. - bool sparse_weight; - bool is_per_channel_quantized; - } fully_connected; - struct { - float input1_scale; - float input2_scale; - float output_scale; - bool input_quantized; - } mul; - struct { - int32_t num_dims; - } strided_slice; - struct { - bool input_quantized; - } abs; - struct { - bool is_per_channel_quantized; - } dequantize; - struct { - bool is_per_channel_quantized; - } quantize; - struct { - bool input_quantized; - } add; - struct { - bool is_per_channel_quantized; - } embedding_lookup; - } ext_options; -} OpSignature; - -// Generate OpSignature with the given OperatorCode, Operator and Tensors (from -// SubGraph). The OpSignature will be used by GetBuiltinOperatorVersion() and -// mostly input and output tensor types are enough to figure out op version. -// But some ops (DEPTHWISE_CONV_2D, FULLY_CONNECTED, ...) require to pass their -// options to decide op version. -// -// WARNING: The caller is responsible to free the allocated -// OpSignature.builtin_data memory. -OpSignature GetOpSignature(const OperatorCode* op_code, const Operator* op, - const SubGraph* subgraph, const Model* model); - // Generate OpSignature with the given TfLiteContext, TfLiteNode and // TfLiteRegistration. // The function can be used by a compatibility checker of a delegate such as From 5be2c060060db56be97a251775ae156cadff6ecc Mon Sep 17 00:00:00 2001 From: Fergus Henderson Date: Tue, 24 Sep 2024 16:07:57 -0700 Subject: [PATCH 212/483] Ensure TFL_CAPI_EXPORT is always placed at the start of the line. Also remove use of TFL_CAPI_EXPORT for a struct type; it should only be applied to functions and variables/constants, not to types. PiperOrigin-RevId: 678432197 --- tensorflow/lite/delegates/nnapi/nnapi_delegate_c_api.h | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tensorflow/lite/delegates/nnapi/nnapi_delegate_c_api.h b/tensorflow/lite/delegates/nnapi/nnapi_delegate_c_api.h index 62dfdbf9345d83..1ece2fbb74da4d 100644 --- a/tensorflow/lite/delegates/nnapi/nnapi_delegate_c_api.h +++ b/tensorflow/lite/delegates/nnapi/nnapi_delegate_c_api.h @@ -23,7 +23,7 @@ extern "C" { // Use TfLiteNnapiDelegateOptionsDefault() for Default options. // WARNING: This is an experimental API and subject to change. -typedef struct TFL_CAPI_EXPORT TfLiteNnapiDelegateOptions { +typedef struct TfLiteNnapiDelegateOptions { // Preferred Power/perf trade-off. For more details please see // ANeuralNetworksCompilation_setPreference documentation in : // https://developer.android.com/ndk/reference/group/neural-networks.html @@ -84,8 +84,8 @@ typedef struct TFL_CAPI_EXPORT TfLiteNnapiDelegateOptions { // Returns a delegate that uses NNAPI for ops execution. // Must outlive the interpreter. // WARNING: This is an experimental API and subject to change. -TfLiteDelegate* TFL_CAPI_EXPORT -TfLiteNnapiDelegateCreate(const TfLiteNnapiDelegateOptions* options); +TFL_CAPI_EXPORT TfLiteDelegate* TfLiteNnapiDelegateCreate( + const TfLiteNnapiDelegateOptions* options); // Returns TfLiteNnapiDelegateOptions populated with default values. // WARNING: This is an experimental API and subject to change. @@ -93,7 +93,7 @@ TFL_CAPI_EXPORT TfLiteNnapiDelegateOptions TfLiteNnapiDelegateOptionsDefault(); // Does any needed cleanup and deletes 'delegate'. // WARNING: This is an experimental API and subject to change. -void TFL_CAPI_EXPORT TfLiteNnapiDelegateDelete(TfLiteDelegate* delegate); +TFL_CAPI_EXPORT void TfLiteNnapiDelegateDelete(TfLiteDelegate* delegate); #ifdef __cplusplus } From a4d33071d50c53937f0bd01e45022c67f3bf80ab Mon Sep 17 00:00:00 2001 From: Quoc Truong Date: Tue, 24 Sep 2024 17:01:06 -0700 Subject: [PATCH 213/483] Grant executable permission to bash script file. PiperOrigin-RevId: 678448479 --- ci/official/containers/ml_build/setup.sources.sh | 0 1 file changed, 0 insertions(+), 0 deletions(-) mode change 100644 => 100755 ci/official/containers/ml_build/setup.sources.sh diff --git a/ci/official/containers/ml_build/setup.sources.sh b/ci/official/containers/ml_build/setup.sources.sh old mode 100644 new mode 100755 From 9d41410a2d6fda234a7fb79a1916709b60df4690 Mon Sep 17 00:00:00 2001 From: Arturo Schmidt Date: Tue, 24 Sep 2024 17:13:01 -0700 Subject: [PATCH 214/483] Add base test to ConvertMlirToGraph. PiperOrigin-RevId: 678451938 --- tensorflow/compiler/mlir/tf2xla/api/v2/BUILD | 31 +++++ .../api/v2/testdata/valid_executor.mlir | 10 ++ .../tf2xla/api/v2/testdata/valid_graph.txt | 44 +++++++ .../api/v2/tf_executor_to_graph_test.cc | 110 ++++++++++++++++++ 4 files changed, 195 insertions(+) create mode 100644 tensorflow/compiler/mlir/tf2xla/api/v2/testdata/valid_executor.mlir create mode 100644 tensorflow/compiler/mlir/tf2xla/api/v2/testdata/valid_graph.txt create mode 100644 tensorflow/compiler/mlir/tf2xla/api/v2/tf_executor_to_graph_test.cc diff --git a/tensorflow/compiler/mlir/tf2xla/api/v2/BUILD b/tensorflow/compiler/mlir/tf2xla/api/v2/BUILD index 31188f4456f711..1d264445bac1d9 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v2/BUILD +++ b/tensorflow/compiler/mlir/tf2xla/api/v2/BUILD @@ -272,3 +272,34 @@ cc_library( "@local_xla//xla:status_macros", ], ) + +tf_cc_test( + name = "tf_executor_to_graph_test", + srcs = ["tf_executor_to_graph_test.cc"], + data = [ + "testdata/valid_executor.mlir", + "testdata/valid_graph.txt", + ], + deps = [ + ":tf_executor_to_graph", + "//tensorflow/compiler/jit", + "//tensorflow/compiler/mlir:register_common_dialects", + "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags", + "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/tf2xla:xla_helpers", + "//tensorflow/core:core_cpu_base", + "//tensorflow/core:framework", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test_main", + "//tensorflow/core/platform:resource_loader", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status:statusor", + "@com_google_googletest//:gtest_main", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Parser", + "@local_tsl//tsl/platform:protobuf", + "@local_xla//xla/tsl/lib/core:status_test_util", + "@riegeli//riegeli/bytes:fd_reader", + "@riegeli//riegeli/bytes:read_all", + ], +) diff --git a/tensorflow/compiler/mlir/tf2xla/api/v2/testdata/valid_executor.mlir b/tensorflow/compiler/mlir/tf2xla/api/v2/testdata/valid_executor.mlir new file mode 100644 index 00000000000000..3db375e788a033 --- /dev/null +++ b/tensorflow/compiler/mlir/tf2xla/api/v2/testdata/valid_executor.mlir @@ -0,0 +1,10 @@ + +module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 268 : i32}} { + func.func @main() { + tf_executor.graph { + %0:2 = tf_executor.island wraps "tf.Const"() {device = "/job:localhost/replica:0/task:0/device:TPU:0", dtype = "tfdtype$DT_INT32", value = #tf_type : tensor<2xi32>} : () -> tensor<2xi32> loc("Empty/shape") + tf_executor.fetch + } + func.return + } +} \ No newline at end of file diff --git a/tensorflow/compiler/mlir/tf2xla/api/v2/testdata/valid_graph.txt b/tensorflow/compiler/mlir/tf2xla/api/v2/testdata/valid_graph.txt new file mode 100644 index 00000000000000..4eed21fedff195 --- /dev/null +++ b/tensorflow/compiler/mlir/tf2xla/api/v2/testdata/valid_graph.txt @@ -0,0 +1,44 @@ + node { + name: "Empty/shape" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:TPU:0" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\200\000\000\000\200\000\000\000" + } + } + } + experimental_debug_info { + } +} +library { +} +versions { + producer: 268 +} \ No newline at end of file diff --git a/tensorflow/compiler/mlir/tf2xla/api/v2/tf_executor_to_graph_test.cc b/tensorflow/compiler/mlir/tf2xla/api/v2/tf_executor_to_graph_test.cc new file mode 100644 index 00000000000000..9a53e51e4c71e1 --- /dev/null +++ b/tensorflow/compiler/mlir/tf2xla/api/v2/tf_executor_to_graph_test.cc @@ -0,0 +1,110 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/tf2xla/api/v2/tf_executor_to_graph.h" + +#include + +#include +#include + +#include +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/DialectRegistry.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "mlir/Parser/Parser.h" // from @llvm-project +#include "riegeli/bytes/fd_reader.h" // from @riegeli +#include "riegeli/bytes/read_all.h" // from @riegeli +#include "tensorflow/compiler/mlir/register_common_dialects.h" +#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" +#include "xla/tsl/lib/core/status_test_util.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/platform/resource_loader.h" +#include "tsl/platform/protobuf.h" + +namespace tensorflow { +namespace tf2xla { +namespace v2 { +namespace { + +using mlir::DialectRegistry; +using mlir::MLIRContext; +using mlir::ModuleOp; +using mlir::OwningOpRef; + +std::string TestDataPath() { + return tensorflow::GetDataDependencyFilepath( + "tensorflow/compiler/mlir/tf2xla/api/v2/testdata/"); +} + +class TfExecutorToGraphTest : public ::testing::Test { + public: + TfExecutorToGraphTest() { + mlir::RegisterCommonToolingDialects(registry_); + context_.appendDialectRegistry(registry_); + context_.loadAllAvailableDialects(); + } + + absl::StatusOr> CreateMlirModule( + std::string mlir_module_filename) { + std::string mlir_module_path = TestDataPath() + mlir_module_filename; + return mlir::parseSourceFile(mlir_module_path, &context_); + } + + GraphDef CreateGraphDef(std::string graphdef_filename) { + std::string file_path = TestDataPath() + graphdef_filename; + std::string contents; + GraphDef graph_def; + auto status = riegeli::ReadAll(riegeli::FdReader(file_path), contents); + if (!status.ok()) { + return graph_def; + } + tsl::protobuf::TextFormat::ParseFromString(contents, &graph_def); + return graph_def; + } + + DialectRegistry registry_; + MLIRContext context_; + OwningOpRef mlir_module_; +}; + +TEST_F(TfExecutorToGraphTest, ConvertMlirToGraphSucceeds) { + auto valid_executor_module = CreateMlirModule("valid_executor.mlir"); + GraphExportConfig confs; + absl::flat_hash_set control_ret_nodes; + FunctionLibraryDefinition flib_def(OpRegistry::Global(), + FunctionDefLibrary()); + auto result_graph = std::make_unique(flib_def); + + TF_ASSERT_OK(ConvertTfExecutorToGraph(valid_executor_module.value().get(), + confs, &result_graph, &flib_def, + &control_ret_nodes)); + + GraphDef result_graphdef; + result_graph->ToGraphDef(&result_graphdef); + GraphDef expected_graphdef = CreateGraphDef("valid_graph.txt"); + EXPECT_EQ(result_graphdef.DebugString(), expected_graphdef.DebugString()); +} + +} // namespace +} // namespace v2 +} // namespace tf2xla +} // namespace tensorflow From 40c94adf57b614b1d98faad3b9281682d0de16c7 Mon Sep 17 00:00:00 2001 From: Fabien Hertschuh Date: Tue, 24 Sep 2024 17:21:57 -0700 Subject: [PATCH 215/483] Call `super().build(input_shape)` instead of `self.built = True` in all Keras layers. Within `build()`, some Keras layers where calling `super().build(input_shape)` while some were calling `self.built = True`. This would result in a different config when serializing whereby layers doing `self.built = True` would not have a `build_config`. This change makes it consistent between all the layers as well as consistent with Keras 3. Note that some layers need to call `Layer.build(self, input_shape)` directly to bypass some class' `build()` but still populate the information for the `build_config`. PiperOrigin-RevId: 678454186 --- .../api/golden/v1/tensorflow.nn.rnn_cell.-device-wrapper.pbtxt | 2 +- .../api/golden/v1/tensorflow.nn.rnn_cell.-dropout-wrapper.pbtxt | 2 +- .../golden/v1/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt | 2 +- .../api/golden/v1/tensorflow.nn.rnn_cell.-r-n-n-cell.pbtxt | 2 +- .../golden/v1/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-device-wrapper.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-device-wrapper.pbtxt index d1f11277be17f4..2128a8e80f3739 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-device-wrapper.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-device-wrapper.pbtxt @@ -177,7 +177,7 @@ tf_class { } member_method { name: "build" - argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "build_from_config" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-dropout-wrapper.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-dropout-wrapper.pbtxt index 35b7e32aecce43..404d0400e36c5c 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-dropout-wrapper.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-dropout-wrapper.pbtxt @@ -181,7 +181,7 @@ tf_class { } member_method { name: "build" - argspec: "args=[\'self\', \'inputs_shape\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "build_from_config" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt index e42c2702cec45c..8fb9c75673d610 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt @@ -176,7 +176,7 @@ tf_class { } member_method { name: "build" - argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "build_from_config" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-r-n-n-cell.pbtxt index 4e0851595da6d7..071db3a8abf79f 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-r-n-n-cell.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-r-n-n-cell.pbtxt @@ -175,7 +175,7 @@ tf_class { } member_method { name: "build" - argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "build_from_config" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt index f7b61918ff1cc0..1dd1ee3ea574e5 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt @@ -177,7 +177,7 @@ tf_class { } member_method { name: "build" - argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "build_from_config" From 05abf38d0a8506d355b5012536193ab233180ad2 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 24 Sep 2024 17:47:11 -0700 Subject: [PATCH 216/483] Reverts 2bd80ea311a6f745bce72a375b5538d6413f66e2 PiperOrigin-RevId: 678460954 --- tensorflow/compiler/mlir/lite/BUILD | 2 +- .../compiler/mlir/lite/flatbuffer_export.cc | 4 +- .../compiler/mlir/lite/tools/versioning/BUILD | 87 - .../lite/tools/versioning/op_signature.cc | 264 --- .../mlir/lite/tools/versioning/op_signature.h | 96 -- .../tools/versioning/op_signature_test.cc | 99 -- .../mlir/lite/tools/versioning/op_version.cc | 1134 ------------- .../mlir/lite/tools/versioning/op_version.h | 33 - .../lite/tools/versioning/op_version_test.cc | 1435 ----------------- .../lite/tools/versioning/runtime_version.cc | 510 ------ .../lite/tools/versioning/runtime_version.h | 40 - tensorflow/lite/CMakeLists.txt | 2 - tensorflow/lite/kernels/CMakeLists.txt | 1 - tensorflow/lite/toco/tflite/BUILD | 6 +- .../lite/toco/tflite/builtin_operator.h | 1 - tensorflow/lite/toco/tflite/export.cc | 2 +- tensorflow/lite/toco/tflite/operator.cc | 4 +- tensorflow/lite/toco/tflite/operator.h | 3 +- tensorflow/lite/toco/tflite/simple_operator.h | 1 - .../lite/tools/benchmark/CMakeLists.txt | 1 - tensorflow/lite/tools/versioning/BUILD | 1 - .../lite/tools/versioning/op_signature.cc | 241 ++- .../lite/tools/versioning/op_signature.h | 74 +- 23 files changed, 320 insertions(+), 3721 deletions(-) delete mode 100644 tensorflow/compiler/mlir/lite/tools/versioning/BUILD delete mode 100644 tensorflow/compiler/mlir/lite/tools/versioning/op_signature.cc delete mode 100644 tensorflow/compiler/mlir/lite/tools/versioning/op_signature.h delete mode 100644 tensorflow/compiler/mlir/lite/tools/versioning/op_signature_test.cc delete mode 100644 tensorflow/compiler/mlir/lite/tools/versioning/op_version.cc delete mode 100644 tensorflow/compiler/mlir/lite/tools/versioning/op_version.h delete mode 100644 tensorflow/compiler/mlir/lite/tools/versioning/op_version_test.cc delete mode 100644 tensorflow/compiler/mlir/lite/tools/versioning/runtime_version.cc delete mode 100644 tensorflow/compiler/mlir/lite/tools/versioning/runtime_version.h diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD index b938f25519809b..3da89f496218a3 100644 --- a/tensorflow/compiler/mlir/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/BUILD @@ -1329,7 +1329,6 @@ cc_library( "//tensorflow/compiler/mlir/lite/schema:schema_conversion_utils", "//tensorflow/compiler/mlir/lite/schema:schema_fbs", "//tensorflow/compiler/mlir/lite/schema:schema_fbs_with_mutable", - "//tensorflow/compiler/mlir/lite/tools/versioning", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:convert_tensor", "//tensorflow/compiler/mlir/tensorflow:dynamic_shape_utils", @@ -1338,6 +1337,7 @@ cc_library( "//tensorflow/core:framework", "//tensorflow/core:portable_gif_internal", "//tensorflow/core:protos_all_cc", + "//tensorflow/lite/tools/versioning", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc index 869173bb2f8d99..fad3e5c1409372 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc @@ -96,8 +96,6 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/schema/mutable/schema_generated.h" #include "tensorflow/compiler/mlir/lite/schema/schema_conversion_utils.h" #include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" -#include "tensorflow/compiler/mlir/lite/tools/versioning/op_version.h" -#include "tensorflow/compiler/mlir/lite/tools/versioning/runtime_version.h" #include "tensorflow/compiler/mlir/lite/utils/control_edges.h" #include "tensorflow/compiler/mlir/lite/utils/convert_type.h" #include "tensorflow/compiler/mlir/lite/utils/low_bit_utils.h" @@ -119,6 +117,8 @@ limitations under the License. #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/platform/tstring.h" +#include "tensorflow/lite/tools/versioning/op_version.h" +#include "tensorflow/lite/tools/versioning/runtime_version.h" #include "tsl/platform/fingerprint.h" #include "tsl/platform/status.h" #include "tsl/platform/tstring.h" diff --git a/tensorflow/compiler/mlir/lite/tools/versioning/BUILD b/tensorflow/compiler/mlir/lite/tools/versioning/BUILD deleted file mode 100644 index 8cb1f84debe84d..00000000000000 --- a/tensorflow/compiler/mlir/lite/tools/versioning/BUILD +++ /dev/null @@ -1,87 +0,0 @@ -load( - "//tensorflow:tensorflow.bzl", - "tf_cc_test", -) -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = ["//visibility:public"], - licenses = ["notice"], -) - -cc_library( - name = "versioning", - srcs = [ - "op_version.cc", - "runtime_version.cc", - ], - hdrs = [ - "op_version.h", - "runtime_version.h", - ], - compatible_with = get_compatible_with_portable(), - deps = [ - ":op_signature", - "//tensorflow/compiler/mlir/lite/core/c:tflite_common", - "//tensorflow/compiler/mlir/lite/kernels/internal:compatibility_macros", - "//tensorflow/compiler/mlir/lite/schema:schema_fbs", - "//tensorflow/compiler/mlir/lite/schema:schema_fbs_with_mutable", - "//tensorflow/compiler/mlir/lite/schema:schema_utils", - "@com_google_absl//absl/log", - "@com_google_absl//absl/strings", - "@flatbuffers", - ], -) - -tf_cc_test( - name = "versioning_test", - srcs = [ - "op_version_test.cc", - ], - deps = [ - ":op_signature", - ":versioning", - "//tensorflow/compiler/mlir/lite/core/c:tflite_common", - "//tensorflow/compiler/mlir/lite/schema:schema_fbs", - "//tensorflow/compiler/mlir/lite/schema:schema_fbs_with_mutable", - "@com_google_googletest//:gtest_main", - ], -) - -cc_library( - name = "op_signature", - srcs = [ - "op_signature.cc", - ], - hdrs = [ - "op_signature.h", - ], - compatible_with = get_compatible_with_portable(), - deps = [ - "//tensorflow/compiler/mlir/lite/core/api:flatbuffer_conversions", - "//tensorflow/compiler/mlir/lite/core/c:tflite_common", - "//tensorflow/compiler/mlir/lite/schema:schema_fbs", - "//tensorflow/compiler/mlir/lite/schema:schema_utils", - "@flatbuffers//:runtime_cc", - ], -) - -tf_cc_test( - name = "op_signature_test", - srcs = [ - "op_signature_test.cc", - ], - data = [ - "//tensorflow/lite:testdata/add.bin", - "//tensorflow/lite:testdata/multi_signatures.bin", - ], - deps = [ - ":op_signature", - "//tensorflow/compiler/mlir/lite/core:absl_error_model_builder", - "//tensorflow/compiler/mlir/lite/core/c:tflite_common", - "//tensorflow/compiler/mlir/lite/schema:schema_fbs", - "//tensorflow/core/platform:resource_loader", - "@com_google_googletest//:gtest_main", - ], -) diff --git a/tensorflow/compiler/mlir/lite/tools/versioning/op_signature.cc b/tensorflow/compiler/mlir/lite/tools/versioning/op_signature.cc deleted file mode 100644 index 13a71443f6f0b5..00000000000000 --- a/tensorflow/compiler/mlir/lite/tools/versioning/op_signature.cc +++ /dev/null @@ -1,264 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "tensorflow/compiler/mlir/lite/tools/versioning/op_signature.h" - -#include -#include -#include -#include - -#include "flatbuffers/vector.h" // from @flatbuffers -#include "tensorflow/compiler/mlir/lite/core/api/flatbuffer_conversions.h" -#include "tensorflow/compiler/mlir/lite/core/c/tflite_types.h" -#include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" -#include "tensorflow/compiler/mlir/lite/schema/schema_utils.h" - -namespace tflite { -namespace { - -// A BuiltinDataAllocator which just uses malloc()/free(). -class MallocDataAllocator : public tflite_migration::BuiltinDataAllocator { - public: - void* Allocate(size_t size, size_t alignment_hint) override { - return malloc(size); - } - void Deallocate(void* data) override { free(data); } -}; - -// Get the number of dimensions of a tensor with idx of an operator op. -inline int GetNumDims(const SubGraph* subgraph, const Operator* op, int idx) { - const flatbuffers::Vector* ret = - subgraph->tensors()->Get(op->inputs()->Get(idx))->shape(); - if (ret) { - return ret->size(); - } else { - return 0; - } -} - -std::vector GetOpSignatureTensorSpecs( - const flatbuffers::Vector* tensors, const SubGraph* subgraph, - const Model* model) { - std::vector tensor_specs; - if (!tensors) { - return tensor_specs; - } - - for (int32_t i = 0; i < tensors->size(); ++i) { - int32_t tensor_no = tensors->Get(i); - - OpSignatureTensorSpec tensor_spec = {kTfLiteNoType}; - if (tensor_no >= 0) { - if (subgraph->tensors() && tensor_no < subgraph->tensors()->size()) { - auto* fb_tensor = subgraph->tensors()->Get(tensor_no); - tflite_migration::ConvertTensorType(fb_tensor->type(), - &tensor_spec.type) - .IgnoreError(); - auto buffer_idx = fb_tensor->buffer(); - // Check if the tensor is a constant tensor. - if (buffer_idx != 0 && buffer_idx < model->buffers()->size()) { - auto* buffer = model->buffers()->Get(buffer_idx); - if (buffer->data() && buffer->data()->size() != 0) { - tensor_spec.is_const = true; - } - } - const flatbuffers::Vector* shape_vec = fb_tensor->shape(); - if (shape_vec) { - for (int32_t j = 0; j < shape_vec->size(); ++j) { - tensor_spec.dims.push_back(shape_vec->Get(j)); - } - } - const flatbuffers::Vector* shape_signature_vec = - fb_tensor->shape_signature(); - tensor_spec.is_shape_dynamic = false; - if (shape_signature_vec) { - for (int32_t j = 0; j < shape_signature_vec->size(); ++j) { - if (shape_signature_vec->Get(j) == -1) { - tensor_spec.is_shape_dynamic = true; - break; - } - } - } - } - } - tensor_specs.push_back(tensor_spec); - } - return tensor_specs; -} - -} // namespace - -OpSignature GetOpSignature(const OperatorCode* op_code, const Operator* op, - const SubGraph* subgraph, const Model* model) { - auto builtin_code = GetBuiltinCode(op_code); - OpSignature op_sig = {builtin_code}; - std::memset(&op_sig.ext_options, 0, sizeof(op_sig.ext_options)); - - if (builtin_code != BuiltinOperator_CUSTOM) { - MallocDataAllocator allocator; - tflite_migration::ParseOpData(op, builtin_code, &allocator, - &op_sig.builtin_data) - .IgnoreError(); - } else { - op_sig.custom_name = op_code->custom_code()->str(); - } - - switch (builtin_code) { - case BuiltinOperator_DEPTHWISE_CONV_2D: { - const Tensor* filter_tensor = - subgraph->tensors()->Get(op->inputs()->Get(1)); - const QuantizationParameters* filter_quant = - filter_tensor->quantization(); - int num_channels = filter_tensor->shape()->Get(3); - if (filter_quant && filter_quant->scale() && - filter_quant->scale()->size() && - filter_quant->scale()->size() == num_channels) { - op_sig.ext_options.depthwise_conv_2d.is_per_channel_quantized = true; - } - } break; - - case BuiltinOperator_FULLY_CONNECTED: { - const Tensor* weight_tensor = - subgraph->tensors()->Get(op->inputs()->Get(1)); - op_sig.ext_options.fully_connected.sparse_weight = - (weight_tensor->sparsity() != nullptr); - const QuantizationParameters* weight_quant = - weight_tensor->quantization(); - if (weight_quant && weight_quant->scale() && - weight_quant->scale()->size() && weight_tensor->shape() && - weight_tensor->shape()->size()) { - op_sig.ext_options.fully_connected.is_per_channel_quantized = - weight_quant->scale()->size() > 1 && - weight_quant->scale()->size() == weight_tensor->shape()->Get(0); - } - } break; - - case BuiltinOperator_MUL: { - if (op->inputs()->size() < 2 || op->outputs()->size() < 1) { - break; - } - const Tensor* input1_tensor = - subgraph->tensors()->Get(op->inputs()->Get(0)); - const Tensor* input2_tensor = - subgraph->tensors()->Get(op->inputs()->Get(1)); - const Tensor* output_tensor = - subgraph->tensors()->Get(op->outputs()->Get(0)); - const QuantizationParameters* input1_quant = - input1_tensor->quantization(); - const QuantizationParameters* input2_qunt = input2_tensor->quantization(); - const QuantizationParameters* output_quant = - output_tensor->quantization(); - if (input1_quant && input1_quant->scale() && - input1_quant->scale()->size() && input2_qunt && - input2_qunt->scale() && input2_qunt->scale()->size() && - output_quant && output_quant->scale() && - output_quant->scale()->size()) { - op_sig.ext_options.mul.input1_scale = input1_quant->scale()->Get(0); - op_sig.ext_options.mul.input2_scale = input2_qunt->scale()->Get(0); - op_sig.ext_options.mul.output_scale = output_quant->scale()->Get(0); - } - if (input1_quant || input2_qunt) { - op_sig.ext_options.mul.input_quantized = true; - } - } break; - - case BuiltinOperator_CONV_2D: { - const Tensor* input_tensor = - subgraph->tensors()->Get(op->inputs()->Get(0)); - const Tensor* filter_tensor = - subgraph->tensors()->Get(op->inputs()->Get(1)); - const QuantizationParameters* filter_quant = - filter_tensor->quantization(); - int num_filters = filter_tensor->shape()->Get(0); - if (filter_quant && filter_quant->scale() && - filter_quant->scale()->size() && - filter_quant->scale()->size() == num_filters) { - op_sig.ext_options.conv_2d.is_per_channel_quantized = true; - } - if (input_tensor->shape() && input_tensor->shape()->size()) { - int num_input_channels = input_tensor->shape()->Get(3); - int num_filter_input_channels = filter_tensor->shape()->Get(3); - op_sig.ext_options.conv_2d.is_grouped_convolution = - num_input_channels != num_filter_input_channels; - } else { - op_sig.ext_options.conv_2d.is_grouped_convolution = false; - } - } break; - - case BuiltinOperator_STRIDED_SLICE: { - op_sig.ext_options.strided_slice.num_dims = GetNumDims(subgraph, op, 0); - } break; - - case BuiltinOperator_ABS: { - if (subgraph->tensors()->Get(op->inputs()->Get(0))->quantization()) { - op_sig.ext_options.abs.input_quantized = true; - } - } break; - - case BuiltinOperator_DEQUANTIZE: { - const Tensor* input_tensor = - subgraph->tensors()->Get(op->inputs()->Get(0)); - const QuantizationParameters* input_quant = input_tensor->quantization(); - if (input_quant && input_quant->scale() && - input_quant->scale()->size() > 1 && - input_quant->scale()->size() == - input_tensor->shape()->Get(input_quant->quantized_dimension())) { - op_sig.ext_options.dequantize.is_per_channel_quantized = true; - } - } break; - - case BuiltinOperator_QUANTIZE: { - const Tensor* output_tensor = - subgraph->tensors()->Get(op->outputs()->Get(0)); - const QuantizationParameters* output_quant = - output_tensor->quantization(); - if (output_quant && output_quant->scale() && - output_quant->scale()->size() > 1 && - output_quant->scale()->size() == - output_tensor->shape()->Get( - output_quant->quantized_dimension())) { - op_sig.ext_options.quantize.is_per_channel_quantized = true; - } - } break; - - case BuiltinOperator_ADD: { - if (subgraph->tensors()->Get(op->inputs()->Get(0))->quantization()) { - op_sig.ext_options.add.input_quantized = true; - } - } break; - - case BuiltinOperator_EMBEDDING_LOOKUP: { - const Tensor* table_tensor = - subgraph->tensors()->Get(op->inputs()->Get(1)); - const QuantizationParameters* table_quant = table_tensor->quantization(); - if (table_quant && table_quant->scale() && table_quant->scale()->size() && - table_tensor->shape() && table_tensor->shape()->size()) { - op_sig.ext_options.embedding_lookup.is_per_channel_quantized = - table_quant->scale()->size() > 1 && - table_quant->scale()->size() == table_tensor->shape()->Get(0); - } - } break; - - default: - break; - } - - op_sig.inputs = GetOpSignatureTensorSpecs(op->inputs(), subgraph, model); - op_sig.outputs = GetOpSignatureTensorSpecs(op->outputs(), subgraph, model); - op_sig.version = op_code->version(); - return op_sig; -} - -} // namespace tflite diff --git a/tensorflow/compiler/mlir/lite/tools/versioning/op_signature.h b/tensorflow/compiler/mlir/lite/tools/versioning/op_signature.h deleted file mode 100644 index 5799194f8770e7..00000000000000 --- a/tensorflow/compiler/mlir/lite/tools/versioning/op_signature.h +++ /dev/null @@ -1,96 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_MLIR_LITE_TOOLS_VERSIONING_OP_SIGNATURE_H_ -#define TENSORFLOW_COMPILER_MLIR_LITE_TOOLS_VERSIONING_OP_SIGNATURE_H_ - -#include -#include -#include - -#include "tensorflow/compiler/mlir/lite/core/c/tflite_types.h" -#include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" - -namespace tflite { - -// OpSignature contains operator parameters for version functions. -typedef struct { - TfLiteType type; - std::vector dims; - bool is_const; - bool is_shape_dynamic; -} OpSignatureTensorSpec; - -typedef struct { - BuiltinOperator op; - std::vector inputs; - std::vector outputs; - void* builtin_data; - int version; - const void* custom_initial_data; - std::string custom_name; - union { - struct { - bool is_per_channel_quantized; - bool is_grouped_convolution; - } conv_2d; - struct { - bool is_per_channel_quantized; - } depthwise_conv_2d; - struct { - // TODO(b/156530611): Make this global when more ops support sparse - // computation. - bool sparse_weight; - bool is_per_channel_quantized; - } fully_connected; - struct { - float input1_scale; - float input2_scale; - float output_scale; - bool input_quantized; - } mul; - struct { - int32_t num_dims; - } strided_slice; - struct { - bool input_quantized; - } abs; - struct { - bool is_per_channel_quantized; - } dequantize; - struct { - bool is_per_channel_quantized; - } quantize; - struct { - bool input_quantized; - } add; - struct { - bool is_per_channel_quantized; - } embedding_lookup; - } ext_options; -} OpSignature; - -// Generate OpSignature with the given OperatorCode, Operator and Tensors (from -// SubGraph). The OpSignature will be used by GetBuiltinOperatorVersion() and -// mostly input and output tensor types are enough to figure out op version. -// But some ops (DEPTHWISE_CONV_2D, FULLY_CONNECTED, ...) require to pass their -// options to decide op version. -// -// WARNING: The caller is responsible to free the allocated -// OpSignature.builtin_data memory. -OpSignature GetOpSignature(const OperatorCode* op_code, const Operator* op, - const SubGraph* subgraph, const Model* model); - -} // namespace tflite -#endif // TENSORFLOW_COMPILER_MLIR_LITE_TOOLS_VERSIONING_OP_SIGNATURE_H_ diff --git a/tensorflow/compiler/mlir/lite/tools/versioning/op_signature_test.cc b/tensorflow/compiler/mlir/lite/tools/versioning/op_signature_test.cc deleted file mode 100644 index e3db7b8da8ca49..00000000000000 --- a/tensorflow/compiler/mlir/lite/tools/versioning/op_signature_test.cc +++ /dev/null @@ -1,99 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "tensorflow/compiler/mlir/lite/tools/versioning/op_signature.h" - -#include -#include -#include -#include - -#include -#include "tensorflow/compiler/mlir/lite/core/absl_error_model_builder.h" -#include "tensorflow/compiler/mlir/lite/core/c/tflite_types.h" -#include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" -#include "tensorflow/core/platform/resource_loader.h" - -namespace tflite { - -TEST(GetOpSignature, FlatBufferModel) { - const std::string& full_path = - tensorflow::GetDataDependencyFilepath("tensorflow/lite/testdata/add.bin"); - auto fb_model = - mlir::TFL::FlatBufferModelAbslError::BuildFromFile(full_path.data()); - ASSERT_TRUE(fb_model); - auto model = fb_model->GetModel(); - auto subgraphs = model->subgraphs(); - const SubGraph* subgraph = subgraphs->Get(0); - const Operator* op1 = subgraph->operators()->Get(0); - const OperatorCode* op_code1 = - model->operator_codes()->Get(op1->opcode_index()); - OpSignature op_sig = GetOpSignature(op_code1, op1, subgraph, model); - EXPECT_EQ(op_sig.op, BuiltinOperator_ADD); - EXPECT_EQ(op_sig.inputs[0].type, kTfLiteFloat32); - EXPECT_EQ(op_sig.inputs[0].dims.size(), 4); - EXPECT_FALSE(op_sig.inputs[0].is_const); - EXPECT_FALSE(op_sig.inputs[0].is_shape_dynamic); - EXPECT_EQ(op_sig.outputs[0].type, kTfLiteFloat32); - EXPECT_FALSE(op_sig.outputs[0].is_const); - EXPECT_EQ(op_sig.outputs[0].dims.size(), 4); - EXPECT_FALSE(op_sig.outputs[0].is_shape_dynamic); - EXPECT_NE(op_sig.builtin_data, nullptr); - EXPECT_EQ(op_sig.version, 1); - free(op_sig.builtin_data); - - const Operator* op2 = subgraph->operators()->Get(1); - const OperatorCode* op_code2 = - model->operator_codes()->Get(op2->opcode_index()); - op_sig = GetOpSignature(op_code2, op2, subgraph, model); - EXPECT_EQ(op_sig.op, BuiltinOperator_ADD); - EXPECT_EQ(op_sig.inputs[0].type, kTfLiteFloat32); - EXPECT_EQ(op_sig.inputs[0].dims.size(), 4); - EXPECT_FALSE(op_sig.inputs[0].is_const); - EXPECT_FALSE(op_sig.inputs[0].is_shape_dynamic); - EXPECT_EQ(op_sig.outputs[0].type, kTfLiteFloat32); - EXPECT_FALSE(op_sig.outputs[0].is_const); - EXPECT_EQ(op_sig.outputs[0].dims.size(), 4); - EXPECT_FALSE(op_sig.outputs[0].is_shape_dynamic); - EXPECT_NE(op_sig.builtin_data, nullptr); - EXPECT_EQ(op_sig.version, 1); - free(op_sig.builtin_data); - - const std::string& full_path3 = tensorflow::GetDataDependencyFilepath( - "tensorflow/lite/testdata/multi_signatures.bin"); - auto fb_model3 = - mlir::TFL::FlatBufferModelAbslError::BuildFromFile(full_path3.data()); - ASSERT_TRUE(fb_model3); - auto model3 = fb_model3->GetModel(); - auto subgraphs3 = model3->subgraphs(); - const SubGraph* subgraph3 = subgraphs3->Get(0); - const Operator* op3 = subgraph3->operators()->Get(0); - const OperatorCode* op_code3 = - model3->operator_codes()->Get(op3->opcode_index()); - op_sig = GetOpSignature(op_code3, op3, subgraph3, model3); - EXPECT_EQ(op_sig.op, BuiltinOperator_ADD); - EXPECT_EQ(op_sig.inputs[0].type, kTfLiteFloat32); - EXPECT_EQ(op_sig.inputs[0].dims.size(), 1); - EXPECT_FALSE(op_sig.inputs[0].is_const); - EXPECT_TRUE(op_sig.inputs[0].is_shape_dynamic); - EXPECT_EQ(op_sig.outputs[0].type, kTfLiteFloat32); - EXPECT_FALSE(op_sig.outputs[0].is_const); - EXPECT_EQ(op_sig.outputs[0].dims.size(), 1); - EXPECT_TRUE(op_sig.outputs[0].is_shape_dynamic); - EXPECT_NE(op_sig.builtin_data, nullptr); - EXPECT_EQ(op_sig.version, 1); - free(op_sig.builtin_data); -} - -} // namespace tflite diff --git a/tensorflow/compiler/mlir/lite/tools/versioning/op_version.cc b/tensorflow/compiler/mlir/lite/tools/versioning/op_version.cc deleted file mode 100644 index c0c1da3a158761..00000000000000 --- a/tensorflow/compiler/mlir/lite/tools/versioning/op_version.cc +++ /dev/null @@ -1,1134 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "tensorflow/compiler/mlir/lite/tools/versioning/op_version.h" - -#include -#include -#include - -#include "absl/log/log.h" -#include "tensorflow/compiler/mlir/lite/core/c/builtin_op_data.h" -#include "tensorflow/compiler/mlir/lite/core/c/tflite_types.h" -#include "tensorflow/compiler/mlir/lite/kernels/internal/compatibility_macros.h" -#include "tensorflow/compiler/mlir/lite/schema/mutable/schema_generated.h" -#include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" -#include "tensorflow/compiler/mlir/lite/schema/schema_utils.h" -#include "tensorflow/compiler/mlir/lite/tools/versioning/op_signature.h" - -namespace tflite { -namespace { - -bool NeedBroadcastForBinaryInputs(const OpSignature& op_sig) { - if (op_sig.inputs.size() < 2) { - return false; - } - return (op_sig.inputs.at(0).dims != op_sig.inputs.at(1).dims); -} - -int GetInputMaxDims(const OpSignature& op_sig) { - int max_dims = 0; - for (auto& input : op_sig.inputs) { - if (input.dims.size() > max_dims) { - max_dims = input.dims.size(); - } - } - return max_dims; -} - -} // namespace - -int GetBuiltinOperatorVersion(const OpSignature& op_sig) { - switch (op_sig.op) { - case BuiltinOperator_CONV_2D: { - if (op_sig.inputs.at(0).type == kTfLiteInt16 && - op_sig.inputs.at(1).type == kTfLiteInt8 && - op_sig.outputs.at(0).type == kTfLiteInt16) { - // `quantized_bias_type` is supported at version 8. - auto conv_params = - reinterpret_cast(op_sig.builtin_data); - TFLITE_DCHECK(conv_params != nullptr); - if (conv_params->quantized_bias_type) { - return 8; - } - } - - if (op_sig.ext_options.conv_2d.is_grouped_convolution) { - return 6; - } - // If the op has signed int16 op_sig.inputs and op_sig.outputs, its - // version 4. - if (op_sig.inputs.at(0).type == kTfLiteInt16 && - op_sig.inputs.at(1).type == kTfLiteInt16 && - op_sig.outputs.at(1).type == kTfLiteInt16) { - return 4; - } - - // If the op has signed int8 op_sig.inputs and op_sig.outputs, its - // version 3. - if (op_sig.inputs.at(0).type == kTfLiteInt8 && - op_sig.inputs.at(1).type == kTfLiteInt8 && - op_sig.outputs.at(0).type == kTfLiteInt8) { - return 3; - } - // If the op has signed int8 and int4 op_sig.inputs and op_sig.outputs, - // its version 7. - if (op_sig.inputs.at(0).type == kTfLiteInt8 && - op_sig.inputs.at(1).type == kTfLiteInt4 && - op_sig.outputs.at(0).type == kTfLiteInt8) { - return 7; - } - // If the op is a signed int8 hybrid operation, we need to return - // version 2 or 5 if per channel. - if (op_sig.inputs.at(0).type == kTfLiteFloat32 && - op_sig.inputs.at(1).type == kTfLiteInt8 && - op_sig.outputs.at(0).type == kTfLiteFloat32) { - if (op_sig.ext_options.conv_2d.is_per_channel_quantized) { - return 5; - } - return 2; - } - return 1; - } - case BuiltinOperator_DEPTHWISE_CONV_2D: { - // If the op accepts int16, we return version 5. - if (op_sig.inputs.at(0).type == kTfLiteInt16 && - op_sig.inputs.at(1).type == kTfLiteInt16 && - op_sig.outputs.at(1).type == kTfLiteInt16) { - return 5; - } - - // If the op is a signed int8 hybrid operation, we need to return - // version 4 or 6 if per-channel. - if (op_sig.inputs.at(0).type == kTfLiteFloat32 && - op_sig.inputs.at(1).type == kTfLiteInt8 && - op_sig.outputs.at(0).type == kTfLiteFloat32) { - if (op_sig.ext_options.depthwise_conv_2d.is_per_channel_quantized) { - return 6; - } - return 4; - } - // If the op has signed int8 op_sig.inputs and op_sig.outputs, its - // version 3. - if (op_sig.inputs.at(0).type == kTfLiteInt8 && - op_sig.inputs.at(1).type == kTfLiteInt8 && - op_sig.outputs.at(0).type == kTfLiteInt8) { - return 3; - } - - // If the op has signed int8 and int4 op_sig.inputs and op_sig.outputs, - // its version 7. - if (op_sig.inputs.at(0).type == kTfLiteInt8 && - op_sig.inputs.at(1).type == kTfLiteInt4 && - op_sig.outputs.at(0).type == kTfLiteInt8) { - return 7; - } - - auto depthwise_conv_params = - reinterpret_cast(op_sig.builtin_data); - TFLITE_DCHECK(depthwise_conv_params != nullptr); - if (depthwise_conv_params->dilation_width_factor != 1 || - depthwise_conv_params->dilation_height_factor != 1) { - return 2; - } - return 1; - } - - case BuiltinOperator_EMBEDDING_LOOKUP: { - if (op_sig.inputs.at(1).type == kTfLiteInt4 || - op_sig.ext_options.embedding_lookup.is_per_channel_quantized) { - return 4; - } - return 1; - } - - case BuiltinOperator_FAKE_QUANT: { - auto fake_quant_params = - reinterpret_cast(op_sig.builtin_data); - TFLITE_DCHECK(fake_quant_params != nullptr); - if (fake_quant_params->narrow_range) { - return 2; - } - return 1; - } - - case BuiltinOperator_FULLY_CONNECTED: { - // +-----------------+--------------------+--------------------------+ - // | | Weight::Default | Weight::Shuffled4x16Int8 | - // +-----------------+--------------------+--------------------------+ - // | Float | 1 | 2 | - // | Quantized Uint8 | 1 | 2 | - // | Hybrid | 3 | 3 | - // | Quantized Int8 | 4 | 4 | - // +-----------------+--------------------+--------------------------+ - - auto fully_connected_params = - reinterpret_cast(op_sig.builtin_data); - TFLITE_DCHECK(fully_connected_params != nullptr); - - if (op_sig.inputs.at(0).type == kTfLiteInt16 && - op_sig.inputs.at(1).type == kTfLiteInt4 && - op_sig.outputs.at(0).type == kTfLiteInt16) { - return 13; - } - - if (op_sig.inputs.at(0).type == kTfLiteFloat32 && - op_sig.inputs.at(1).type == kTfLiteInt8 && - op_sig.outputs.at(0).type == kTfLiteFloat32 && - op_sig.ext_options.fully_connected.is_per_channel_quantized) { - return 12; - } - - if (op_sig.inputs.at(0).type == kTfLiteInt16 && - op_sig.inputs.at(1).type == kTfLiteInt8 && - op_sig.outputs.at(0).type == kTfLiteInt16) { - // `quantized_bias_type` is supported at version 11. - if (fully_connected_params->quantized_bias_type) { - return 11; - } - } - - // FullyConnected with sparse weight is supported at version 8. - if (op_sig.ext_options.fully_connected.sparse_weight) { - return 8; - } - - // Int16 fully fixed point kernel is at version 7. - if (op_sig.inputs.at(0).type == kTfLiteInt16 && - op_sig.inputs.at(1).type == kTfLiteInt16 && - op_sig.outputs.at(0).type == kTfLiteInt16) { - return 7; - } - - // 2 op_sig.inputs (no bias) use case is supported starting from - // version 6. - if (op_sig.inputs.size() == 2) { - return 6; - } - // `keep_num_dims` is supported at version 5. - if (fully_connected_params->keep_num_dims) { - return 5; - } - // Int8 fully fixed point kernel is at version 4. - if (op_sig.inputs.at(0).type == kTfLiteInt8 && - op_sig.inputs.at(1).type == kTfLiteInt8 && - op_sig.outputs.at(0).type == kTfLiteInt8) { - return 4; - } - - // If the op has signed int8 and int4 op_sig.inputs and op_sig.outputs, - // its version 7. - if (op_sig.inputs.at(0).type == kTfLiteInt8 && - op_sig.inputs.at(1).type == kTfLiteInt4 && - op_sig.outputs.at(0).type == kTfLiteInt8) { - return 10; - } - - // If the op is a signed int8 hybrid operation, we need to return - // version 3. - if (op_sig.inputs.at(0).type == kTfLiteFloat32 && - op_sig.inputs.at(1).type == kTfLiteInt8 && - op_sig.outputs.at(0).type == kTfLiteFloat32) { - if (fully_connected_params->asymmetric_quantize_inputs) { - // This is to use the updated quantization scheme. - return 9; - } - return 3; - } - // For float and uint8 fixed point kernels, if the weight is - // Shuffled4x16Int8, it is version 2. - if (fully_connected_params->weights_format == - kTfLiteFullyConnectedWeightsFormatShuffled4x16Int8) { - return 2; - } - // Otherwise (weight is default), the version is 1. - return 1; - } - - case BuiltinOperator_GATHER: { - if (op_sig.inputs.at(0).type == kTfLiteInt4) { - return 7; - } - if (op_sig.inputs.at(1).type == kTfLiteInt16) { - return 6; - } - auto gather_params = - reinterpret_cast(op_sig.builtin_data); - if (gather_params && gather_params->batch_dims != 0) { - return 5; - } - if (op_sig.inputs.at(0).type == kTfLiteInt16) { - return 4; - } - // If the op takes bool input, it is version 3. - if (op_sig.inputs.at(0).type == kTfLiteBool) { - return 3; - } - if (op_sig.inputs.at(0).type == kTfLiteInt8) { - return 2; - } - return 1; - } - - case BuiltinOperator_SVDF: { - // Fully integer SVDF has int8 as input and is of version 3. - if (op_sig.inputs.at(0).type == kTfLiteInt8) { - return 3; - } - // If the op is a signed int8 hybrid operation, we need to return - // version 2. - if (op_sig.inputs.at(0).type == kTfLiteFloat32 && - op_sig.inputs.at(1).type == kTfLiteInt8 && - op_sig.outputs.at(0).type == kTfLiteFloat32) { - auto svdf_params = - reinterpret_cast(op_sig.builtin_data); - // This is to use the updated quantization scheme - if (svdf_params && svdf_params->asymmetric_quantize_inputs) { - return 4; - } - return 2; - } - return 1; - } - - case BuiltinOperator_SIGN: - // Version 2 supports int32 inputs - if (op_sig.inputs.at(0).type == kTfLiteInt32) { - return 2; - } - return 1; - - case BuiltinOperator_MUL: - // Version 7 supports int16 and uint32 inputs - if ((op_sig.inputs.at(0).type == kTfLiteInt16 && - !op_sig.ext_options.mul.input_quantized) || - op_sig.inputs.at(0).type == kTfLiteUInt32) { - return 7; - } - // Version 6 supports complex32 inputs - if (op_sig.inputs.at(0).type == kTfLiteComplex64) { - return 6; - } - // Version 5 supports int64 inputs - if (op_sig.inputs.at(0).type == kTfLiteInt64) { - return 5; - } - // Version 4 supports int16 inputs - if (op_sig.inputs.at(0).type == kTfLiteInt16) { - return 4; - } - // Version 3 supports have a rescale value greater than or equal to 1. - if (op_sig.ext_options.mul.input1_scale != 0 && - op_sig.ext_options.mul.input2_scale != 0 && - op_sig.ext_options.mul.output_scale != 0 && - (op_sig.ext_options.mul.input1_scale * - op_sig.ext_options.mul.input2_scale / - op_sig.ext_options.mul.output_scale) >= 1.0) { - return 3; - } - if (op_sig.inputs.at(0).type == kTfLiteInt8) { - return 2; - } - return 1; - - case BuiltinOperator_MAX_POOL_2D: - case BuiltinOperator_AVERAGE_POOL_2D: - if (op_sig.inputs.at(0).type == kTfLiteInt16 && - op_sig.outputs.at(0).type == kTfLiteInt16) { - return 3; - } - - if (op_sig.inputs.at(0).type == kTfLiteInt8) { - return 2; - } - return 1; - - case BuiltinOperator_TRANSPOSE: - if (op_sig.inputs.at(0).dims.size() > 5) { - return 6; - } - if (op_sig.inputs.at(0).type == kTfLiteInt16) { - return 5; - } - if (op_sig.inputs.at(0).dims.size() > 4) { - return 4; - } - // If the op takes bool input, it is version 3. - if (op_sig.inputs.at(0).type == kTfLiteBool) { - return 3; - } - if (op_sig.inputs.at(0).type == kTfLiteInt8) { - return 2; - } - return 1; - - case BuiltinOperator_TRANSPOSE_CONV: { - auto transpose_conv_params = - reinterpret_cast(op_sig.builtin_data); - if (op_sig.inputs.at(0).type == kTfLiteInt16 && - op_sig.inputs.at(1).type == kTfLiteInt8 && - op_sig.outputs.at(0).type == kTfLiteInt16) { - // `quantized_bias_type` is supported at version 5. - TFLITE_DCHECK(transpose_conv_params != nullptr); - if (transpose_conv_params->quantized_bias_type) { - return 5; - } - } - - // TransposeConvOp has fused activation function from version 4. - if (transpose_conv_params != nullptr && - transpose_conv_params->activation) { - return 4; - } - - if (op_sig.inputs.size() == 4 && - op_sig.inputs.at(3).type != kTfLiteNoType) { - return 3; - } - // If the op takes int8 input, it is version 2. - if (op_sig.inputs.at(1).type == kTfLiteInt8) { - return 2; - } - return 1; - } - - case BuiltinOperator_LSTM: { - auto lstm_params = - reinterpret_cast(op_sig.builtin_data); - // If the input activation and output tensor are int16 and a weight is - // int8, this is a version 5. - if (lstm_params->kernel_type == kTfLiteLSTMFullKernel && - op_sig.inputs.at(0).type == kTfLiteInt16 && - op_sig.inputs.at(2).type == kTfLiteInt8 && - op_sig.outputs.at(0).type == kTfLiteInt16) { - return 5; - } - // If the input tensor is float and a weight is int8, this is a version - // 3 hybrid operation. - TFLITE_DCHECK(lstm_params != nullptr); - if (lstm_params->kernel_type == kTfLiteLSTMFullKernel && - op_sig.inputs.at(0).type == kTfLiteFloat32 && - op_sig.inputs.at(2).type == kTfLiteInt8 && - op_sig.outputs.at(0).type == kTfLiteFloat32) { - if (lstm_params->asymmetric_quantize_inputs) { - return 4; - } - return 3; - } - // KERNEL_BASIC was added in version 2. - if (lstm_params->kernel_type == kTfLiteLSTMBasicKernel) { - return 2; - } - return 1; - } - - case BuiltinOperator_SPLIT: - // If the op take in16 input, it is version 4. - if (op_sig.inputs.at(1).type == kTfLiteInt16) { - return 4; - } - // If the op take int8 input, it is version 2, for int32 it's version 3. - // The input tensor is at index 1 not 0, 0 is the axis. - if (op_sig.inputs.at(1).type == kTfLiteInt32) { - return 3; - } - if (op_sig.inputs.at(1).type == kTfLiteInt8) { - return 2; - } - return 1; - - case BuiltinOperator_SPARSE_TO_DENSE: - // Version 3 supports Int8 and Uint8 type. - if (op_sig.inputs.at(2).type == kTfLiteInt8 || - op_sig.inputs.at(2).type == kTfLiteUInt8) { - return 3; - } - // Version 2 supports Int64 value type. - if (op_sig.inputs.at(2).type == kTfLiteInt64) { - return 2; - } - return 1; - - case BuiltinOperator_SLICE: - if (op_sig.inputs.at(0).type == kTfLiteUInt32) { - return 6; - } - if (op_sig.inputs.at(0).dims.size() > 4) { - return 5; - } - if (op_sig.inputs.at(0).type == kTfLiteInt16) { - return 4; - } - // Version 3 supports string input types. - if (op_sig.inputs.at(0).type == kTfLiteString) { - return 3; - } - if (op_sig.inputs.at(0).type == kTfLiteInt8) { - return 2; - } - return 1; - - case BuiltinOperator_UNPACK: - // If the op take int8/uint8 input, it is version 2. - if (op_sig.inputs.at(0).type == kTfLiteInt8 || - op_sig.inputs.at(0).type == kTfLiteUInt8) { - return 2; - } - // If the op take bool input, it is version 3. - if (op_sig.inputs.at(0).type == kTfLiteBool) { - return 3; - } - if (op_sig.inputs.at(0).type == kTfLiteInt16 && - op_sig.outputs.at(0).type == kTfLiteInt16) { - return 4; - } - return 1; - - case BuiltinOperator_DEQUANTIZE: - if (op_sig.inputs.at(0).type == kTfLiteInt4) { - return 6; - } - // Version 3 supports signed int16 input types. - if (op_sig.inputs.at(0).type == kTfLiteInt16 || - op_sig.inputs.at(0).type == kTfLiteFloat16) { - return 3; - } - if (op_sig.inputs.at(0).type == kTfLiteInt8) { - if (op_sig.ext_options.dequantize.is_per_channel_quantized) { - return 5; - } - return 2; - } - return 1; - - case BuiltinOperator_QUANTIZE: - if (op_sig.inputs.at(0).type == kTfLiteInt4 || - op_sig.outputs.at(0).type == kTfLiteInt4) { - return 4; - } - if (op_sig.ext_options.quantize.is_per_channel_quantized) { - return 3; - } - if (op_sig.outputs.at(0).type == kTfLiteInt16) { - return 2; - } - return 1; - - case BuiltinOperator_FLOOR_DIV: - if (op_sig.inputs.at(0).type == kTfLiteInt16 || - op_sig.inputs.at(0).type == kTfLiteInt8) { - return 3; - } - if (op_sig.inputs.at(0).type == kTfLiteFloat32) { - return 2; - } - return 1; - - case BuiltinOperator_FLOOR_MOD: - if (op_sig.inputs.at(0).type == kTfLiteInt16 || - op_sig.inputs.at(0).type == kTfLiteInt8) { - return 2; - } - return 1; - - case BuiltinOperator_L2_NORMALIZATION: - if (op_sig.outputs.at(0).type == kTfLiteInt8) { - return 2; - } - return 1; - - case BuiltinOperator_ABS: - // Version 5 supports int32 - if (op_sig.inputs.at(0).type == kTfLiteInt32) { - return 5; - } - if (op_sig.inputs.at(0).type == kTfLiteInt16) { - return op_sig.ext_options.abs.input_quantized ? 3 : 4; - } - if (op_sig.inputs.at(0).type == kTfLiteInt8 || - op_sig.inputs.at(0).type == kTfLiteUInt8) { - return 2; - } - return 1; - case BuiltinOperator_RELU: - if (op_sig.inputs.at(0).type == kTfLiteInt16) { - return 3; - } - if (op_sig.inputs.at(0).type == kTfLiteInt8 || - op_sig.inputs.at(0).type == kTfLiteUInt8) { - return 2; - } - return 1; - - case BuiltinOperator_STRIDED_SLICE: { - auto strided_slice_params = - reinterpret_cast(op_sig.builtin_data); - TFLITE_DCHECK(strided_slice_params != nullptr); - if (strided_slice_params->offset == true) { - return 8; - } - if (op_sig.inputs.at(0).type == kTfLiteUInt32) { - return 7; - } - if (strided_slice_params->ellipsis_mask != 0 || - strided_slice_params->new_axis_mask != 0) { - return 6; - } - if (op_sig.inputs.at(0).type == kTfLiteString) { - return 5; - } - if (op_sig.ext_options.strided_slice.num_dims > 4) { - return 4; - } - // If the op takes bool input, it is version 3. - if (op_sig.inputs.at(0).type == kTfLiteBool) { - return 3; - } - if (op_sig.inputs.at(0).type == kTfLiteInt8) { - return 2; - } - return 1; - } - case BuiltinOperator_REVERSE_V2: - if (op_sig.inputs.at(0).type == kTfLiteInt8) { - return 3; - } - if (op_sig.inputs.at(0).type == kTfLiteBool) { - return 2; - } - return 1; - case BuiltinOperator_RESIZE_BILINEAR: { - if (op_sig.inputs.at(0).type == kTfLiteInt16) { - return 4; - } - auto resize_bilinear_params = - reinterpret_cast(op_sig.builtin_data); - TFLITE_DCHECK(resize_bilinear_params != nullptr); - if (resize_bilinear_params->half_pixel_centers) { - return 3; - } else if (op_sig.inputs.at(0).type == kTfLiteInt8) { - return 2; - } - return 1; - } - case BuiltinOperator_RESIZE_NEAREST_NEIGHBOR: { - if (op_sig.inputs.at(0).type == kTfLiteInt16) { - return 4; - } - auto resize_nearest_neighbor_params = - reinterpret_cast( - op_sig.builtin_data); - TFLITE_DCHECK(resize_nearest_neighbor_params != nullptr); - if (resize_nearest_neighbor_params->half_pixel_centers || - resize_nearest_neighbor_params->align_corners) { - return 3; - } else if (op_sig.inputs.at(0).type == kTfLiteInt8) { - return 2; - } - return 1; - } - - case BuiltinOperator_MAXIMUM: - case BuiltinOperator_MINIMUM: - if (op_sig.inputs.at(0).type == kTfLiteInt16 && - op_sig.outputs.at(0).type == kTfLiteInt16) { - return 4; - } - if (NeedBroadcastForBinaryInputs(op_sig) && GetInputMaxDims(op_sig) > 4) { - return 3; - } - if (op_sig.inputs.at(0).type == kTfLiteInt8) { - return 2; - } - return 1; - - case BuiltinOperator_PACK: - if (op_sig.inputs.at(0).type == kTfLiteInt8) { - return 2; - } - if (op_sig.inputs.at(0).type == kTfLiteInt16 && - op_sig.outputs.at(0).type == kTfLiteInt16) { - return 3; - } - if (op_sig.inputs.at(0).type == kTfLiteUInt32) { - return 4; - } - return 1; - - case BuiltinOperator_TILE: - if (op_sig.inputs.at(0).type == kTfLiteInt8) { - return 3; - } - if (op_sig.inputs.at(0).type == kTfLiteString) { - return 2; - } - return 1; - - case BuiltinOperator_SQUEEZE: - if (op_sig.inputs.at(0).type == kTfLiteString) { - return 2; - } - return 1; - - case BuiltinOperator_SPACE_TO_BATCH_ND: - case BuiltinOperator_BATCH_TO_SPACE_ND: - if (op_sig.inputs.at(0).type == kTfLiteInt16) { - return 4; - } - if (op_sig.inputs.at(0).dims.size() != 4) { - return 3; - } - if (op_sig.inputs.at(0).type == kTfLiteInt8) { - return 2; - } - return 1; - - case BuiltinOperator_ADD: { - if (!op_sig.inputs.empty() && op_sig.inputs.at(0).type == kTfLiteInt16 && - !op_sig.ext_options.add.input_quantized) { - return 5; - } - if (!op_sig.inputs.empty() && op_sig.inputs.at(0).type == kTfLiteInt64) { - return 4; - } - if (op_sig.inputs.at(0).type == kTfLiteInt16 && - op_sig.outputs.at(0).type == kTfLiteInt16) { - auto add_params = - reinterpret_cast(op_sig.builtin_data); - if (add_params && !add_params->pot_scale_int16) { - return 3; - } - } - if (op_sig.inputs.at(0).type == kTfLiteInt8) { - return 2; - } - return 1; - } - - case BuiltinOperator_SUB: { - if (op_sig.inputs.at(0).type == kTfLiteInt16 && - op_sig.outputs.at(0).type == kTfLiteInt16) { - auto sub_params = - reinterpret_cast(op_sig.builtin_data); - if (sub_params && !sub_params->pot_scale_int16) { - return 5; - } - } - if (!op_sig.inputs.empty() && op_sig.inputs.at(0).type == kTfLiteInt64) { - return 4; - } - if (NeedBroadcastForBinaryInputs(op_sig) && GetInputMaxDims(op_sig) > 4) { - return 3; - } - if (op_sig.inputs.at(0).type == kTfLiteInt8) { - return 2; - } - return 1; - } - - case BuiltinOperator_GATHER_ND: - if (op_sig.inputs.at(0).type == kTfLiteBool) { - return 5; - } - if (op_sig.inputs.at(1).type == kTfLiteInt16) { - return 4; - } - if (!op_sig.inputs.empty() && - (op_sig.inputs.at(0).type == kTfLiteInt16)) { - return 3; - } - if (!op_sig.inputs.empty() && op_sig.inputs.at(0).type == kTfLiteString) { - return 2; - } - return 1; - - case BuiltinOperator_DIV: - if (NeedBroadcastForBinaryInputs(op_sig) && GetInputMaxDims(op_sig) > 4) { - return 2; - } - return 1; - case BuiltinOperator_TANH: - case BuiltinOperator_LOGISTIC: - if (op_sig.inputs.at(0).type == kTfLiteInt16 && - op_sig.outputs.at(0).type == kTfLiteInt16) { - return 3; - } - - if (op_sig.inputs.at(0).type == kTfLiteInt8) { - return 2; - } - return 1; - - case BuiltinOperator_FILL: - if (op_sig.inputs.size() >= 2) { - if (op_sig.inputs.at(1).type == kTfLiteFloat16) return 4; - if (op_sig.inputs.at(1).type == kTfLiteInt8 || - op_sig.inputs.at(1).type == kTfLiteInt16) { - return 3; - } else if ((op_sig.inputs.at(1).type == kTfLiteBool || - op_sig.inputs.at(1).type == kTfLiteString)) { - return 2; - } - } - return 1; - - case BuiltinOperator_EQUAL: - if (!op_sig.inputs.empty()) { - if (op_sig.inputs.at(0).type == kTfLiteInt16) { - return 4; - } - if (op_sig.inputs.at(0).type == kTfLiteString) { - return 3; - } - if (op_sig.inputs.at(0).type == kTfLiteInt8) { - return 2; - } - } - return 1; - case BuiltinOperator_NOT_EQUAL: - if (!op_sig.inputs.empty()) { - if (op_sig.inputs.at(0).type == kTfLiteString) { - return 3; - } - if (op_sig.inputs.at(0).type == kTfLiteInt8) { - return 2; - } - } - return 1; - - case BuiltinOperator_LEAKY_RELU: - if (op_sig.inputs.at(0).type == kTfLiteInt16) { - return 2; - } - return 1; - - case BuiltinOperator_RANGE: - if (op_sig.inputs.at(0).type == kTfLiteInt64) { - return 2; - } - return 1; - - case BuiltinOperator_BATCH_MATMUL: { - // In case of int16 inputs, the version is 3. - if (op_sig.inputs.at(0).type == kTfLiteInt16) { - return 3; - } - if (op_sig.inputs.at(0).type == kTfLiteInt8) { - return 2; - } - if (op_sig.inputs.at(0).type == kTfLiteFloat32 && - op_sig.inputs.at(1).type == kTfLiteInt8 && - op_sig.outputs.at(0).type == kTfLiteFloat32) { - auto batch_mat_mul_params = - reinterpret_cast(op_sig.builtin_data); - if (batch_mat_mul_params && - batch_mat_mul_params->asymmetric_quantize_inputs) { - // This is to use the updated quantization scheme. - return 4; - } - } - return 1; - } - - case BuiltinOperator_PAD: - case BuiltinOperator_PADV2: - if (op_sig.inputs.at(0).dims.size() > 4) { - return 4; - } - if (op_sig.inputs.at(0).type == kTfLiteInt16) { - return 3; - } - if (op_sig.inputs.at(0).type == kTfLiteInt8) { - return 2; - } - return 1; - - case BuiltinOperator_CONCATENATION: - if (op_sig.inputs.at(0).type == kTfLiteUInt32) { - return 4; - } - // In case of int16 inputs, the version is 3. - if (op_sig.inputs.at(0).type == kTfLiteInt16) { - return 3; - } - if (op_sig.inputs.at(0).type == kTfLiteInt8) { - return 2; - } - return 1; - - case BuiltinOperator_SOFTMAX: - case BuiltinOperator_MEAN: - case BuiltinOperator_MIRROR_PAD: - case BuiltinOperator_REDUCE_MAX: - case BuiltinOperator_REDUCE_MIN: - case BuiltinOperator_RELU6: - case BuiltinOperator_RSQRT: - // In case of int16 inputs, the version is 3. - if (op_sig.inputs.at(0).type == kTfLiteInt16) { - return 3; - } - if (op_sig.inputs.at(0).type == kTfLiteInt8) { - return 2; - } - return 1; - - case BuiltinOperator_RNN: { - if (op_sig.inputs.at(1).type == kTfLiteInt8 && - op_sig.outputs.at(0).type == kTfLiteFloat32) { - auto rnn_params = - reinterpret_cast(op_sig.builtin_data); - if (rnn_params && rnn_params->asymmetric_quantize_inputs) { - return 3; - } else { - return 2; - } - } - return 1; - } - - case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN: { - if (op_sig.inputs.at(1).type == kTfLiteInt8 && - op_sig.outputs.at(0).type == kTfLiteFloat32) { - auto sequence_rnn_params = - reinterpret_cast(op_sig.builtin_data); - if (sequence_rnn_params && - sequence_rnn_params->asymmetric_quantize_inputs) { - return 3; - } else { - return 2; - } - } - return 1; - } - - case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN: { - if (op_sig.inputs.at(1).type == kTfLiteInt8 && - op_sig.outputs.at(0).type == kTfLiteFloat32) { - auto bidirectional_sequence_rnn_params = - reinterpret_cast( - op_sig.builtin_data); - if (bidirectional_sequence_rnn_params && - bidirectional_sequence_rnn_params->asymmetric_quantize_inputs) { - return 3; - } else { - return 2; - } - } - return 1; - } - - case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM: { - if (op_sig.inputs.at(1).type == kTfLiteInt8 && - op_sig.outputs.at(0).type == kTfLiteFloat32) { - auto bidirectional_sequence_lstm_params = - reinterpret_cast( - op_sig.builtin_data); - if (bidirectional_sequence_lstm_params && - bidirectional_sequence_lstm_params->asymmetric_quantize_inputs) { - return 3; - } else { - return 2; - } - } - return 1; - } - - case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM: { - auto unidirectional_sequence_lstm_params = - reinterpret_cast( - op_sig.builtin_data); - // If the input activation and output tensor are int16 and a weight is - // int8, this is a version 5. - if (op_sig.inputs.at(0).type == kTfLiteInt16 && - op_sig.inputs.at(2).type == kTfLiteInt8 && - op_sig.outputs.at(0).type == kTfLiteInt16) { - return 5; - } - if (unidirectional_sequence_lstm_params && - unidirectional_sequence_lstm_params->diagonal_recurrent_tensors) { - return 4; - } - // If the input tensor is float and a weight is int8, this is a version - // 2 hybrid operation. - if (op_sig.inputs.at(0).type == kTfLiteFloat32 && - op_sig.inputs.at(2).type == kTfLiteInt8 && - op_sig.outputs.at(0).type == kTfLiteFloat32) { - if (unidirectional_sequence_lstm_params && - unidirectional_sequence_lstm_params->asymmetric_quantize_inputs) { - return 3; - } - return 2; - } - return 1; - } - - case BuiltinOperator_ARG_MAX: - case BuiltinOperator_ARG_MIN: - if (op_sig.inputs.at(0).type == kTfLiteBool) { - return 3; - } - if (op_sig.inputs.at(0).type == kTfLiteInt8) { - return 2; - } - return 1; - - case BuiltinOperator_SELECT: { - if (op_sig.inputs.at(0).type == kTfLiteUInt32) { - return 4; - } - if (op_sig.inputs.at(0).dims.size() == 5 || - op_sig.inputs.at(1).dims.size() == 5 || - op_sig.inputs.at(2).dims.size() == 5) - return 3; - if (op_sig.inputs.at(0).type == kTfLiteInt8) { - return 2; - } - return 1; - } - case BuiltinOperator_LESS: - case BuiltinOperator_GREATER_EQUAL: { - if (op_sig.inputs.at(0).type == kTfLiteInt16) { - return 3; - } - if (op_sig.inputs.at(0).type == kTfLiteInt8) { - return 2; - } - return 1; - } - case BuiltinOperator_SELECT_V2: { - if (op_sig.inputs.at(0).type == kTfLiteUInt32) { - return 2; - } - return 1; - } - case BuiltinOperator_SPACE_TO_DEPTH: - case BuiltinOperator_SPLIT_V: - case BuiltinOperator_SUM: - case BuiltinOperator_LOG_SOFTMAX: - case BuiltinOperator_GREATER: - case BuiltinOperator_LESS_EQUAL: - case BuiltinOperator_SQUARED_DIFFERENCE: - case BuiltinOperator_DEPTH_TO_SPACE: - if (op_sig.inputs.at(0).type == kTfLiteInt8) { - return 2; - } - return 1; - case BuiltinOperator_TOPK_V2: - if (op_sig.inputs.at(0).type == kTfLiteInt16 || - op_sig.inputs.at(1).type == kTfLiteInt16 || - op_sig.outputs.at(1).type == kTfLiteInt16) { - return 3; - } - if (op_sig.inputs.at(0).type == kTfLiteInt8) { - return 2; - } - return 1; - - case BuiltinOperator_EXP: - case BuiltinOperator_LOG: - case BuiltinOperator_REDUCE_PROD: - if (op_sig.inputs.at(0).type == kTfLiteInt8 || - op_sig.inputs.at(0).type == kTfLiteInt16) { - return 2; - } - return 1; - case BuiltinOperator_DYNAMIC_UPDATE_SLICE: - if (op_sig.inputs.at(2).type == kTfLiteInt64) return 2; - return 1; - - // The version one of broadcast to op won't be not supported since the - // version one was rollbacked and the builtin op code number has been - // changed because of builtin op code shortage problem. - // Quantized broadcast_to is version 3 - case BuiltinOperator_BROADCAST_TO: - if (op_sig.inputs.at(0).type == kTfLiteInt8 || - op_sig.inputs.at(0).type == kTfLiteInt16) { - return 3; - } - return 2; - case BuiltinOperator_CAST: - if (op_sig.inputs.at(0).type == kTfLiteBFloat16 || - op_sig.outputs.at(0).type == kTfLiteBFloat16) { - return 7; - } else if (op_sig.inputs.at(0).type == kTfLiteInt4 && - op_sig.outputs.at(0).type == kTfLiteFloat32) { - return 6; - } else if (op_sig.inputs.at(0).type == kTfLiteFloat64 || - op_sig.outputs.at(0).type == kTfLiteFloat64 || - op_sig.inputs.at(0).type == kTfLiteFloat16 || - op_sig.outputs.at(0).type == kTfLiteFloat16) { - return 5; - } else if (op_sig.inputs.at(0).type == kTfLiteUInt16 || - op_sig.outputs.at(0).type == kTfLiteUInt16) { - return 4; - } else if (op_sig.inputs.at(0).type == kTfLiteInt8 || - op_sig.outputs.at(0).type == kTfLiteInt8) { - return 3; - } else if (op_sig.inputs.at(0).type == kTfLiteUInt32 || - op_sig.outputs.at(0).type == kTfLiteUInt32) { - return 2; - } - return 1; - case BuiltinOperator_WHERE: - if (op_sig.inputs.at(0).type == kTfLiteBool) return 1; - return 2; - case BuiltinOperator_GELU: - if (op_sig.inputs.at(0).type == kTfLiteInt8 || - op_sig.inputs.at(0).type == kTfLiteUInt8) { - return 2; - } - return 1; - default: - return 1; - } - // Prevent lint error about this function being too long. - // NOLINTNEXTLINE -} - -void UpdateOpVersion(uint8_t* model_buffer_pointer) { - auto model = GetMutableModel(model_buffer_pointer); - auto subgraphs = model->subgraphs(); - - for (int i = 0; i < subgraphs->Length(); ++i) { - const SubGraph* subgraph = subgraphs->Get(i); - for (int j = 0; j < subgraph->operators()->Length(); ++j) { - const Operator* op = subgraph->operators()->Get(j); - OperatorCode* op_code = - model->mutable_operator_codes()->GetMutableObject(op->opcode_index()); - - auto builtin_code = GetBuiltinCode(op_code); - if (builtin_code != BuiltinOperator_CUSTOM) { - OpSignature op_sig = GetOpSignature(op_code, op, subgraph, model); - // Update builtin operator version. - int32_t op_ver = GetBuiltinOperatorVersion(op_sig); - if (op_sig.builtin_data) { - free(op_sig.builtin_data); - } - // Skip updating op version if the current node uses lower version. - // TODO(b/184366869): Populate multiple versions of operator once MLIR - // quantizer is ready. - if (op_ver <= op_code->version()) { - continue; - } - if (!op_code->mutate_version(op_ver)) { - LOG(ERROR) << "Can't set operator " - << EnumNameBuiltinOperator(builtin_code) << " to version " - << op_ver; - } - } - } - } -} - -} // namespace tflite diff --git a/tensorflow/compiler/mlir/lite/tools/versioning/op_version.h b/tensorflow/compiler/mlir/lite/tools/versioning/op_version.h deleted file mode 100644 index bd1f551669a94c..00000000000000 --- a/tensorflow/compiler/mlir/lite/tools/versioning/op_version.h +++ /dev/null @@ -1,33 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_MLIR_LITE_TOOLS_VERSIONING_OP_VERSION_H_ -#define TENSORFLOW_COMPILER_MLIR_LITE_TOOLS_VERSIONING_OP_VERSION_H_ - -#include - -#include "tensorflow/compiler/mlir/lite/schema/mutable/schema_generated.h" // IWYU pragma: keep -#include "tensorflow/compiler/mlir/lite/tools/versioning/op_signature.h" - -namespace tflite { - -// Returns version of builtin ops by the given signature. -int GetBuiltinOperatorVersion(const OpSignature& op_sig); - -// Update operator's version of the given TFL flatbuffer model. -void UpdateOpVersion(uint8_t* model_buffer_pointer); - -} // namespace tflite - -#endif // TENSORFLOW_COMPILER_MLIR_LITE_TOOLS_VERSIONING_OP_VERSION_H_ diff --git a/tensorflow/compiler/mlir/lite/tools/versioning/op_version_test.cc b/tensorflow/compiler/mlir/lite/tools/versioning/op_version_test.cc deleted file mode 100644 index 5ad70990125a90..00000000000000 --- a/tensorflow/compiler/mlir/lite/tools/versioning/op_version_test.cc +++ /dev/null @@ -1,1435 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "tensorflow/compiler/mlir/lite/tools/versioning/op_version.h" - -#include - -#include -#include "tensorflow/compiler/mlir/lite/core/c/builtin_op_data.h" -#include "tensorflow/compiler/mlir/lite/core/c/tflite_types.h" -#include "tensorflow/compiler/mlir/lite/schema/mutable/schema_generated.h" -#include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" -#include "tensorflow/compiler/mlir/lite/tools/versioning/op_signature.h" - -namespace tflite { -namespace { - -// Creates vector of OpSignatureTensorSpec with the given TfLiteType vector. -std::vector CreateOpSignatureTensorSpecs( - const std::vector& types) { - std::vector tensor_specs; - for (auto type : types) { - OpSignatureTensorSpec tensor_spec = {}; - tensor_spec.type = type; - tensor_specs.push_back(tensor_spec); - } - return tensor_specs; -} - -// Creates vector of OpSignatureTensorSpec with the given TfLiteType vector, -// each with rank 'rank' -std::vector CreateOpSignatureTensorSpecs( - const std::vector& types, int rank) { - std::vector tensor_specs; - for (auto type : types) { - OpSignatureTensorSpec tensor_spec = {}; - tensor_spec.type = type; - for (int i = 0; i < rank; i++) { - tensor_spec.dims.push_back(4); - } - tensor_specs.push_back(tensor_spec); - } - return tensor_specs; -} - -// Creates vector of OpSignatureTensorSpec of single tensor spec of TfLiteType. -std::vector CreateOpSignatureTensorSpecs( - const TfLiteType type) { - std::vector tensor_specs; - OpSignatureTensorSpec tensor_spec = {}; - tensor_spec.type = type; - tensor_specs.push_back(tensor_spec); - return tensor_specs; -} - -// Creates vector of OpSignatureTensorSpec of single tensor spec of TfLiteType -// with shapes. -std::vector CreateOpSignatureTensorSpecs( - const TfLiteType type, const int dim) { - std::vector tensor_specs; - OpSignatureTensorSpec tensor_spec = {}; - tensor_spec.type = type; - for (int i = 0; i < dim; i++) { - tensor_spec.dims.push_back(4); - } - tensor_specs.push_back(tensor_spec); - return tensor_specs; -} - -// Creates vector of OpSignatureTensorSpec of two tensor specs of TfLiteType -// with shapes. -std::vector CreateOpSignatureTensorSpecs( - const TfLiteType type, const int dim1, const int dim2) { - std::vector tensor_specs; - OpSignatureTensorSpec tensor_spec1 = {}; - tensor_spec1.type = type; - for (int i = 0; i < dim1; i++) { - tensor_spec1.dims.push_back(4); - } - tensor_specs.push_back(tensor_spec1); - - OpSignatureTensorSpec tensor_spec2 = {}; - tensor_spec2.type = type; - for (int i = 0; i < dim2; i++) { - tensor_spec2.dims.push_back(4); - } - tensor_specs.push_back(tensor_spec2); - return tensor_specs; -} - -} // namespace - -TEST(OpVersionTest, VersioningSpareToDense) { - OpSignature fake_op_sig = { - .op = BuiltinOperator_SPARSE_TO_DENSE, - .inputs = CreateOpSignatureTensorSpecs( - std::vector{kTfLiteInt8, kTfLiteInt8, kTfLiteInt8}), - }; - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3); - - fake_op_sig = { - .op = BuiltinOperator_SPARSE_TO_DENSE, - .inputs = CreateOpSignatureTensorSpecs( - std::vector{kTfLiteUInt8, kTfLiteUInt8, kTfLiteUInt8}), - }; - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3); - - fake_op_sig = { - .op = BuiltinOperator_SPARSE_TO_DENSE, - .inputs = CreateOpSignatureTensorSpecs( - std::vector{kTfLiteInt64, kTfLiteInt64, kTfLiteInt64}), - }; - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); - - fake_op_sig = { - .op = BuiltinOperator_SPARSE_TO_DENSE, - .inputs = CreateOpSignatureTensorSpecs( - std::vector{kTfLiteInt32, kTfLiteInt32, kTfLiteInt32}), - }; - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1); -} - -// Test version for a simple Op with 2 versions and the input type controls the -// version. -void SimpleVersioningTest(BuiltinOperator op) { - OpSignature fake_op_sig = { - .op = op, - .inputs = CreateOpSignatureTensorSpecs(kTfLiteInt8), - }; - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); - - fake_op_sig = { - .op = op, - .inputs = CreateOpSignatureTensorSpecs(kTfLiteUInt8), - }; - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1); -} - -// Similar to SimpleVersioningTest function, but -// op has 3 versions and the input type includes kTfLiteInt16. -void SimpleVersioningTestExtended(BuiltinOperator op) { - OpSignature fake_op_sig = { - .op = op, - .inputs = CreateOpSignatureTensorSpecs(kTfLiteInt16), - }; - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3); - - SimpleVersioningTest(op); -} - -// Test version for a simple Op with 2 versions and the output type controls the -void SimpleOutputVersioningTest(BuiltinOperator op) { - OpSignature fake_op_sig = { - .op = op, - .inputs = std::vector{}, - .outputs = CreateOpSignatureTensorSpecs(kTfLiteInt8), - }; - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); - - fake_op_sig = { - .op = op, - .inputs = std::vector{}, - .outputs = CreateOpSignatureTensorSpecs(kTfLiteUInt8), - }; - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1); -} - -TEST(OpVersionTest, VersioningEqualTest) { - SimpleVersioningTest(BuiltinOperator_EQUAL); - OpSignature fake_op_sig = { - .op = BuiltinOperator_EQUAL, - .inputs = CreateOpSignatureTensorSpecs(kTfLiteString), - }; - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3); -} - -TEST(OpVersionTest, VersioningNotEqualTest) { - SimpleVersioningTest(BuiltinOperator_NOT_EQUAL); - OpSignature fake_op_sig = { - .op = BuiltinOperator_NOT_EQUAL, - .inputs = CreateOpSignatureTensorSpecs(kTfLiteString), - }; - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3); -} - -TEST(OpVersionTest, VersioningLessTest) { - SimpleVersioningTest(BuiltinOperator_LESS); -} - -TEST(OpVersionTest, VersioningLessEqualTest) { - SimpleVersioningTest(BuiltinOperator_LESS_EQUAL); -} - -TEST(OpVersionTest, VersioningGreaterTest) { - SimpleVersioningTest(BuiltinOperator_GREATER); -} - -TEST(OpVersionTest, VersioningGreaterEqualTest) { - SimpleVersioningTest(BuiltinOperator_GREATER_EQUAL); -} - -TEST(OpVersionTest, VersioningSpaceToBatchNDTest) { - OpSignature fake_op_sig = { - .op = BuiltinOperator_SPACE_TO_BATCH_ND, - }; - fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt16, 3); - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 4); - fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt16, 4); - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 4); - fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt8, 3); - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3); - fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt8, 4); - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); - fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteUInt8, 3); - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3); - fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteUInt8, 4); - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1); -} - -TEST(OpVersionTest, VersioningLogSoftmaxTest) { - SimpleVersioningTest(BuiltinOperator_LOG_SOFTMAX); -} - -TEST(OpVersionTest, VersioningPackTest) { - OpSignature fake_op_sig = {}; - fake_op_sig.op = BuiltinOperator_PACK; - fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt8); - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); - - fake_op_sig = {}; - fake_op_sig.op = BuiltinOperator_PACK; - fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt16); - fake_op_sig.outputs = CreateOpSignatureTensorSpecs(kTfLiteInt16); - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3); - - fake_op_sig = {}; - fake_op_sig.op = BuiltinOperator_PACK; - fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteUInt32); - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 4); - - fake_op_sig = {}; - fake_op_sig.op = BuiltinOperator_PACK; - fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt32); - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1); -} - -TEST(OpVersionTest, VersioningUnpackTest) { - OpSignature fake_op_sig = { - .op = BuiltinOperator_UNPACK, - .inputs = CreateOpSignatureTensorSpecs(kTfLiteInt8), - }; - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); - - fake_op_sig = { - .op = BuiltinOperator_UNPACK, - .inputs = CreateOpSignatureTensorSpecs(kTfLiteUInt8), - }; - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); - - fake_op_sig = { - .op = BuiltinOperator_UNPACK, - .inputs = CreateOpSignatureTensorSpecs(kTfLiteInt32), - }; - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1); -} - -TEST(OpVersionTest, VersioningRangeTest) { - OpSignature fake_op_sig = {}; - fake_op_sig.op = BuiltinOperator_RANGE; - fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt64); - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); - - fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt32); - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1); -} - -TEST(OpVersionTest, VersioningReluTest) { - OpSignature fake_op_sig = { - .op = BuiltinOperator_RELU, - .inputs = CreateOpSignatureTensorSpecs(kTfLiteInt16), - }; - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3); - - fake_op_sig = { - .op = BuiltinOperator_RELU, - .inputs = CreateOpSignatureTensorSpecs(kTfLiteInt8), - }; - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); - - fake_op_sig = { - .op = BuiltinOperator_RELU, - .inputs = CreateOpSignatureTensorSpecs(kTfLiteUInt8), - }; - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); - - fake_op_sig = { - .op = BuiltinOperator_RELU, - .inputs = CreateOpSignatureTensorSpecs(kTfLiteInt32), - }; - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1); -} - -TEST(OpVersionTest, VersioningBatchToSpaceNDTest) { - OpSignature fake_op_sig = { - .op = BuiltinOperator_BATCH_TO_SPACE_ND, - }; - fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt16, 3); - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 4); - fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt16, 4); - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 4); - fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt8, 3); - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3); - fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt8, 4); - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); - fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteUInt8, 3); - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3); - fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteUInt8, 4); - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1); -} - -TEST(OpVersionTest, VersioningTanhTest) { - SimpleVersioningTest(BuiltinOperator_TANH); -} - -TEST(OpVersionTest, VersioningStridedSliceTest) { - TfLiteStridedSliceParams strided_slice_params = {}; - OpSignature fake_op_sig = {}; - fake_op_sig.op = BuiltinOperator_STRIDED_SLICE; - fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt8); - fake_op_sig.outputs = CreateOpSignatureTensorSpecs(kTfLiteInt8); - fake_op_sig.builtin_data = reinterpret_cast(&strided_slice_params); - strided_slice_params.ellipsis_mask = 0; - strided_slice_params.new_axis_mask = 2; - fake_op_sig.ext_options.strided_slice.num_dims = 5; - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 6); - - strided_slice_params.new_axis_mask = 0; - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 4); - - fake_op_sig.ext_options.strided_slice.num_dims = 4; - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); - - fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteUInt8); - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1); - - fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteUInt32); - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 7); - - strided_slice_params.offset = true; - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 8); -} - -TEST(OpVersionTest, VersioningSpaceToDepthTest) { - SimpleVersioningTest(BuiltinOperator_SPACE_TO_DEPTH); -} - -TEST(OpVersionTest, VersioningSliceTest) { - OpSignature fake_op_sig = { - .op = BuiltinOperator_SLICE, - }; - fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt16, 5); - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 5); - - fake_op_sig = { - .op = BuiltinOperator_SLICE, - .inputs = CreateOpSignatureTensorSpecs(kTfLiteInt16), - }; - fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt16, 4); - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 4); - - fake_op_sig = { - .op = BuiltinOperator_SLICE, - }; - fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteString, 4); - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3); - - fake_op_sig = { - .op = BuiltinOperator_SLICE, - .inputs = CreateOpSignatureTensorSpecs(kTfLiteInt8), - }; - fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt8, 4); - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); - - fake_op_sig = { - .op = BuiltinOperator_SLICE, - }; - fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteUInt8, 4); - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1); - - fake_op_sig = {}; - fake_op_sig.op = BuiltinOperator_SLICE; - fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteUInt32, 4); - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 6); -} - -TEST(OpVersionTest, VersioningLogisticTest) { - SimpleVersioningTest(BuiltinOperator_SPACE_TO_DEPTH); -} - -TEST(OpVersionTest, VersioningL2NormTest) { - SimpleOutputVersioningTest(BuiltinOperator_L2_NORMALIZATION); -} - -TEST(OpVersionTest, VersioningMaxTest) { - OpSignature fake_op_sig = { - .op = BuiltinOperator_MAXIMUM, - }; - - fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt8, 4, 5); - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3); - fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt8, 5, 5); - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); - fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt8, 4, 4); - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); - - fake_op_sig = { - .op = BuiltinOperator_MAXIMUM, - }; - fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteUInt8, 4, 5); - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3); - fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteUInt8, 4, 4); - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1); -} - -TEST(OpVersionTest, VersioningMinTest) { - OpSignature fake_op_sig = { - .op = BuiltinOperator_MINIMUM, - }; - - fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt8, 4, 5); - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3); - fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt8, 5, 5); - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); - fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt8, 4, 4); - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); - - fake_op_sig = { - .op = BuiltinOperator_MINIMUM, - }; - fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteUInt8, 4, 5); - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3); - fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteUInt8, 4, 4); - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1); -} - -TEST(OpVersionTest, VersioningMeanTest) { - SimpleVersioningTestExtended(BuiltinOperator_MEAN); -} - -TEST(OpVersionTest, VersioningSumTest) { - SimpleVersioningTest(BuiltinOperator_SUM); -} - -TEST(OpVersionTest, VersioningReduceMinTest) { - SimpleVersioningTestExtended(BuiltinOperator_REDUCE_MIN); -} - -TEST(OpVersionTest, VersioningReduceMaxTest) { - SimpleVersioningTestExtended(BuiltinOperator_REDUCE_MAX); -} - -TEST(OpVersionTest, VersioningMirrorPadTest) { - SimpleVersioningTestExtended(BuiltinOperator_MIRROR_PAD); -} - -TEST(OpVersionTest, VersioningReduceProdTest) { - OpSignature fake_op_sig; - fake_op_sig.op = BuiltinOperator_REDUCE_PROD; - fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt16); - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); - - fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt8); - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); - - fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteFloat32); - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1); -} - -TEST(OpVersionTest, VersioningAddTest) { - TfLiteAddParams add_params = {}; - OpSignature fake_op_sig = { - .op = BuiltinOperator_ADD, - .inputs = CreateOpSignatureTensorSpecs(kTfLiteInt16), - .outputs = CreateOpSignatureTensorSpecs(kTfLiteInt16), - .builtin_data = reinterpret_cast(&add_params)}; - add_params.pot_scale_int16 = false; - fake_op_sig.ext_options.add.input_quantized = true; - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3); - - fake_op_sig.ext_options.add.input_quantized = false; - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 5); - - SimpleVersioningTest(BuiltinOperator_ADD); -} - -TEST(OpVersionTest, VersioningSubTest) { - TfLiteSubParams sub_params = {}; - OpSignature fake_op_sig = { - .op = BuiltinOperator_SUB, - .inputs = CreateOpSignatureTensorSpecs(kTfLiteInt16), - .outputs = CreateOpSignatureTensorSpecs(kTfLiteInt16), - .builtin_data = reinterpret_cast(&sub_params)}; - sub_params.pot_scale_int16 = false; - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 5); - - fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt64); - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 4); - - fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt8, 4, 5); - fake_op_sig.outputs = CreateOpSignatureTensorSpecs(kTfLiteInt8); - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3); - - SimpleVersioningTest(BuiltinOperator_SUB); -} - -TEST(OpVersionTest, VersioningMUL7TestInt16) { - OpSignature fake_op_sig; - fake_op_sig.op = BuiltinOperator_MUL; - fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt16); - fake_op_sig.ext_options.mul.input_quantized = false; - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 7); -} - -TEST(OpVersionTest, VersioningMUL7TestUInt32) { - OpSignature fake_op_sig; - fake_op_sig.op = BuiltinOperator_MUL; - fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteUInt32); - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 7); -} - -TEST(OpVersionTest, VersioningMUL6Test) { - OpSignature fake_op_sig; - fake_op_sig.op = BuiltinOperator_MUL; - fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteComplex64); - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 6); -} - -TEST(OpVersionTest, VersioningMUL5Test) { - OpSignature fake_op_sig; - fake_op_sig.op = BuiltinOperator_MUL; - fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt64); - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 5); -} - -TEST(OpVersionTest, VersioningSub4Test) { - OpSignature fake_op_sig = { - .op = BuiltinOperator_SUB, - .inputs = CreateOpSignatureTensorSpecs(kTfLiteInt64), - }; - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 4); -} - -void SimpleMulVersioningTest(TfLiteType data_type, float multiplier, - int version) { - OpSignature fake_op_sig = { - .op = BuiltinOperator_MUL, - .inputs = CreateOpSignatureTensorSpecs( - std::vector{data_type, data_type}), - .outputs = CreateOpSignatureTensorSpecs(data_type), - }; - fake_op_sig.ext_options.mul = {1.0f, 1.0f, 1.0f / multiplier}; - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), version); -} - -TEST(OpVersionTest, VersioningMulTest) { - SimpleMulVersioningTest(kTfLiteUInt8, 0.5f, 1); - SimpleMulVersioningTest(kTfLiteInt8, 0.5f, 2); - SimpleMulVersioningTest(kTfLiteInt8, 2.0f, 3); -} - -TEST(OpVersionTest, VersioningPadTest) { - SimpleVersioningTest(BuiltinOperator_PAD); -} - -TEST(OpVersionTest, VersioningPadV2Test) { - SimpleVersioningTest(BuiltinOperator_PADV2); -} - -TEST(OpVersionTest, VersioningConcatenationTest) { - OpSignature fake_op_sig = {}; - fake_op_sig.op = BuiltinOperator_CONCATENATION; - fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteFloat32); - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1); - - fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt8); - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); - - fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt16); - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3); - - fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteUInt32); - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 4); -} - -TEST(OpVersionTest, VersioningSelectTest) { - OpSignature fake_op_sig = {}; - fake_op_sig.op = BuiltinOperator_SELECT; - fake_op_sig.inputs = CreateOpSignatureTensorSpecs( - std::vector{kTfLiteUInt32, kTfLiteUInt32, kTfLiteUInt32}, 5); - fake_op_sig.outputs = CreateOpSignatureTensorSpecs(kTfLiteUInt32); - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 4); - - fake_op_sig = {}; - fake_op_sig.op = BuiltinOperator_SELECT; - fake_op_sig.inputs = CreateOpSignatureTensorSpecs( - std::vector{kTfLiteUInt8, kTfLiteUInt8, kTfLiteUInt8}, 5); - fake_op_sig.outputs = CreateOpSignatureTensorSpecs(kTfLiteUInt8); - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3); - - fake_op_sig = {}; - fake_op_sig.op = BuiltinOperator_SELECT; - fake_op_sig.inputs = CreateOpSignatureTensorSpecs( - std::vector{kTfLiteInt8, kTfLiteInt8, kTfLiteInt8}, 4); - fake_op_sig.outputs = CreateOpSignatureTensorSpecs(kTfLiteInt8); - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); - - fake_op_sig = {}; - fake_op_sig.op = BuiltinOperator_SELECT; - fake_op_sig.inputs = CreateOpSignatureTensorSpecs( - std::vector{kTfLiteFloat32, kTfLiteFloat32, kTfLiteFloat32}, - 4); - fake_op_sig.outputs = CreateOpSignatureTensorSpecs(kTfLiteFloat32); - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1); -} - -TEST(OpVersionTest, VersioningSelectV2Test) { - OpSignature fake_op_sig = {}; - fake_op_sig.op = BuiltinOperator_SELECT_V2; - fake_op_sig.inputs = CreateOpSignatureTensorSpecs( - std::vector{kTfLiteUInt32, kTfLiteUInt32, kTfLiteUInt32}, 5); - fake_op_sig.outputs = CreateOpSignatureTensorSpecs(kTfLiteUInt32); - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); - fake_op_sig = {}; - fake_op_sig.op = BuiltinOperator_SELECT_V2; - fake_op_sig.inputs = CreateOpSignatureTensorSpecs( - std::vector{kTfLiteInt32, kTfLiteInt32, kTfLiteInt32}, 5); - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1); -} - -TEST(OpVersionTest, VersioningRelu6Test) { - SimpleVersioningTestExtended(BuiltinOperator_RELU6); -} - -TEST(OpVersionTest, VersioningFullyConnectedTest) { - TfLiteFullyConnectedParams fully_connected_params = {}; - OpSignature fake_op_sig = { - .op = BuiltinOperator_FULLY_CONNECTED, - .inputs = CreateOpSignatureTensorSpecs( - std::vector{kTfLiteUInt8, kTfLiteUInt8}), - .outputs = CreateOpSignatureTensorSpecs(kTfLiteUInt8), - .builtin_data = reinterpret_cast(&fully_connected_params), - }; - fully_connected_params.weights_format = - kTfLiteFullyConnectedWeightsFormatShuffled4x16Int8; - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 6); - - fake_op_sig = { - .op = BuiltinOperator_FULLY_CONNECTED, - .inputs = CreateOpSignatureTensorSpecs( - std::vector{kTfLiteInt8, kTfLiteInt8}), - .outputs = CreateOpSignatureTensorSpecs(kTfLiteInt8), - .builtin_data = reinterpret_cast(&fully_connected_params), - }; - fully_connected_params.weights_format = - kTfLiteFullyConnectedWeightsFormatShuffled4x16Int8; - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 6); - - fake_op_sig = { - .op = BuiltinOperator_FULLY_CONNECTED, - .inputs = CreateOpSignatureTensorSpecs( - std::vector{kTfLiteInt8, kTfLiteInt8}), - .outputs = CreateOpSignatureTensorSpecs(kTfLiteInt8), - .builtin_data = reinterpret_cast(&fully_connected_params), - }; - fully_connected_params.weights_format = - kTfLiteFullyConnectedWeightsFormatDefault; - fake_op_sig.ext_options.fully_connected.sparse_weight = true; - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 8); - - fake_op_sig = { - .op = BuiltinOperator_FULLY_CONNECTED, - .inputs = CreateOpSignatureTensorSpecs( - std::vector{kTfLiteFloat32, kTfLiteInt8, kTfLiteFloat32}), - .outputs = CreateOpSignatureTensorSpecs(kTfLiteFloat32), - .builtin_data = reinterpret_cast(&fully_connected_params), - }; - fully_connected_params.asymmetric_quantize_inputs = false; - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3); - fully_connected_params.asymmetric_quantize_inputs = true; - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 9); - - fake_op_sig = { - .op = BuiltinOperator_FULLY_CONNECTED, - .inputs = CreateOpSignatureTensorSpecs( - std::vector{kTfLiteInt16, kTfLiteInt8}), - .outputs = CreateOpSignatureTensorSpecs(kTfLiteInt16), - .builtin_data = reinterpret_cast(&fully_connected_params), - }; - fully_connected_params.quantized_bias_type = kTfLiteInt32; - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 11); - - fake_op_sig = { - .op = BuiltinOperator_FULLY_CONNECTED, - .inputs = CreateOpSignatureTensorSpecs( - std::vector{kTfLiteFloat32, kTfLiteInt8}), - .outputs = CreateOpSignatureTensorSpecs(kTfLiteFloat32), - .builtin_data = reinterpret_cast(&fully_connected_params), - }; - fake_op_sig.ext_options.fully_connected.is_per_channel_quantized = true; - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 12); -} - -TEST(OpVersionTest, VersioningDequantizeTest) { - OpSignature fake_op_sig = { - .op = BuiltinOperator_DEQUANTIZE, - .inputs = CreateOpSignatureTensorSpecs(kTfLiteInt16), - }; - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3); - - fake_op_sig = { - .op = BuiltinOperator_DEQUANTIZE, - .inputs = CreateOpSignatureTensorSpecs(kTfLiteFloat16), - }; - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3); - - fake_op_sig = { - .op = BuiltinOperator_DEQUANTIZE, - .inputs = CreateOpSignatureTensorSpecs(kTfLiteInt8), - }; - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); - - fake_op_sig.ext_options.dequantize.is_per_channel_quantized = true; - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 5); - - fake_op_sig = { - .op = BuiltinOperator_DEQUANTIZE, - .inputs = CreateOpSignatureTensorSpecs(kTfLiteFloat32), - }; - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1); -} - -TEST(OpVersionTest, VersioningQuantizeTest) { - OpSignature fake_op_sig; - fake_op_sig.op = BuiltinOperator_QUANTIZE; - fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteFloat32); - fake_op_sig.ext_options.quantize.is_per_channel_quantized = false; - - fake_op_sig.outputs = CreateOpSignatureTensorSpecs(kTfLiteInt8); - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1); - - fake_op_sig.outputs = CreateOpSignatureTensorSpecs(kTfLiteUInt8); - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1); - - fake_op_sig.outputs = CreateOpSignatureTensorSpecs(kTfLiteInt16); - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); - - fake_op_sig.ext_options.quantize.is_per_channel_quantized = true; - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3); -} - -TEST(OpVersionTest, VersioningConv2DTest) { - OpSignature fake_op_sig = { - .op = BuiltinOperator_CONV_2D, - .inputs = CreateOpSignatureTensorSpecs( - std::vector{kTfLiteUInt8, kTfLiteUInt8}), - .outputs = CreateOpSignatureTensorSpecs(kTfLiteUInt8), - }; - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1); - - fake_op_sig = { - .op = BuiltinOperator_CONV_2D, - .inputs = CreateOpSignatureTensorSpecs( - std::vector{kTfLiteInt8, kTfLiteInt8}), - .outputs = CreateOpSignatureTensorSpecs(kTfLiteInt8), - }; - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3); - - fake_op_sig = { - .op = BuiltinOperator_CONV_2D, - .inputs = CreateOpSignatureTensorSpecs( - std::vector{kTfLiteFloat32, kTfLiteInt8}), - .outputs = CreateOpSignatureTensorSpecs(kTfLiteFloat32), - }; - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); - - fake_op_sig = { - .op = BuiltinOperator_CONV_2D, - .inputs = CreateOpSignatureTensorSpecs( - std::vector{kTfLiteFloat32, kTfLiteInt8}), - .outputs = CreateOpSignatureTensorSpecs(kTfLiteFloat32), - }; - fake_op_sig.ext_options.conv_2d.is_per_channel_quantized = true; - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 5); - - fake_op_sig.op = BuiltinOperator_CONV_2D; - fake_op_sig.inputs = CreateOpSignatureTensorSpecs( - std::vector{kTfLiteFloat32, kTfLiteInt8}); - fake_op_sig.outputs = CreateOpSignatureTensorSpecs(kTfLiteFloat32); - fake_op_sig.ext_options.conv_2d.is_grouped_convolution = true; - - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 6); - - TfLiteConvParams conv_params = {}; - fake_op_sig = { - .op = BuiltinOperator_CONV_2D, - .inputs = CreateOpSignatureTensorSpecs( - std::vector{kTfLiteInt16, kTfLiteInt8}), - .outputs = CreateOpSignatureTensorSpecs(kTfLiteInt16), - .builtin_data = reinterpret_cast(&conv_params), - }; - conv_params.quantized_bias_type = kTfLiteInt32; - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 8); -} - -TEST(OpVersionTest, VersioningFloorDivOperatorTest) { - OpSignature fake_op_sig = { - .op = BuiltinOperator_FLOOR_DIV, - .inputs = CreateOpSignatureTensorSpecs(kTfLiteInt32), - }; - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1); - - fake_op_sig = { - .op = BuiltinOperator_FLOOR_DIV, - .inputs = CreateOpSignatureTensorSpecs(kTfLiteFloat32), - }; - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); - - fake_op_sig = { - .op = BuiltinOperator_FLOOR_DIV, - .inputs = CreateOpSignatureTensorSpecs(kTfLiteInt16), - }; - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3); -} - -TEST(OpVersionTest, VersioningFloorModOperatorTest) { - OpSignature fake_op_sig = { - .op = BuiltinOperator_FLOOR_MOD, - .inputs = CreateOpSignatureTensorSpecs(kTfLiteInt32), - }; - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1); - - fake_op_sig = { - .op = BuiltinOperator_FLOOR_MOD, - .inputs = CreateOpSignatureTensorSpecs(kTfLiteInt16), - }; - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); -} - -TEST(OpVersionTest, VersioningTransposeConvOperatorTest) { - OpSignature fake_op_sig = { - .op = BuiltinOperator_TRANSPOSE_CONV, - .inputs = CreateOpSignatureTensorSpecs( - std::vector{kTfLiteFloat32, kTfLiteUInt8}), - }; - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1); - - fake_op_sig = { - .op = BuiltinOperator_TRANSPOSE_CONV, - .inputs = CreateOpSignatureTensorSpecs( - std::vector{kTfLiteInt32, kTfLiteInt8, kTfLiteInt8}), - }; - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); - - fake_op_sig = { - .op = BuiltinOperator_TRANSPOSE_CONV, - .inputs = CreateOpSignatureTensorSpecs(std::vector{ - kTfLiteInt32, kTfLiteInt8, kTfLiteInt8, kTfLiteInt32}), - }; - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3); - - const auto none_type = kTfLiteNoType; - fake_op_sig = { - .op = BuiltinOperator_TRANSPOSE_CONV, - .inputs = CreateOpSignatureTensorSpecs(std::vector{ - kTfLiteInt32, kTfLiteInt8, kTfLiteInt8, none_type}), - }; - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); - - TfLiteTransposeConvParams transpose_conv_params = {}; - transpose_conv_params.activation = kTfLiteActRelu; - fake_op_sig = { - .op = BuiltinOperator_TRANSPOSE_CONV, - .inputs = CreateOpSignatureTensorSpecs(std::vector{ - kTfLiteInt32, kTfLiteInt8, kTfLiteInt8, none_type}), - .builtin_data = reinterpret_cast(&transpose_conv_params), - }; - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 4); - - transpose_conv_params = {}; - fake_op_sig = { - .op = BuiltinOperator_TRANSPOSE_CONV, - .inputs = CreateOpSignatureTensorSpecs( - std::vector{kTfLiteInt16, kTfLiteInt8}), - .outputs = CreateOpSignatureTensorSpecs(kTfLiteInt16), - .builtin_data = reinterpret_cast(&transpose_conv_params), - }; - transpose_conv_params.quantized_bias_type = kTfLiteInt32; - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 5); -} - -TEST(OpVersionTest, VersioningSVDFOperatorTest) { - TfLiteSVDFParams svdf_params = {}; - OpSignature fake_op_sig = { - .op = BuiltinOperator_SVDF, - .inputs = CreateOpSignatureTensorSpecs(std::vector{ - kTfLiteFloat32, kTfLiteFloat32, kTfLiteFloat32, kTfLiteFloat32, - kTfLiteFloat32}), - .outputs = CreateOpSignatureTensorSpecs(kTfLiteFloat32), - .builtin_data = reinterpret_cast(&svdf_params), - }; - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1); - - fake_op_sig = { - .op = BuiltinOperator_SVDF, - .inputs = CreateOpSignatureTensorSpecs( - std::vector{kTfLiteFloat32, kTfLiteInt8, kTfLiteFloat32, - kTfLiteFloat32, kTfLiteFloat32}), - .outputs = CreateOpSignatureTensorSpecs(kTfLiteFloat32), - .builtin_data = reinterpret_cast(&svdf_params), - }; - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); - svdf_params.asymmetric_quantize_inputs = true; - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 4); - - svdf_params = {}; - fake_op_sig = { - .op = BuiltinOperator_SVDF, - .inputs = CreateOpSignatureTensorSpecs(std::vector{ - kTfLiteInt8, kTfLiteInt8, kTfLiteInt32, kTfLiteInt32, kTfLiteInt16}), - .outputs = CreateOpSignatureTensorSpecs(kTfLiteInt8), - .builtin_data = reinterpret_cast(&svdf_params), - }; - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3); -} - -TEST(OpVersionTest, VersioningDepthwiseConv2DTest) { - TfLiteDepthwiseConvParams depthwise_conv_params = {}; - OpSignature fake_op_sig = { - .op = BuiltinOperator_DEPTHWISE_CONV_2D, - .inputs = CreateOpSignatureTensorSpecs( - std::vector{kTfLiteFloat32, kTfLiteInt8}), - .outputs = CreateOpSignatureTensorSpecs(kTfLiteFloat32), - .builtin_data = reinterpret_cast(&depthwise_conv_params), - }; - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 4); - fake_op_sig.ext_options.depthwise_conv_2d.is_per_channel_quantized = true; - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 6); - - depthwise_conv_params = {}; - fake_op_sig = { - .op = BuiltinOperator_DEPTHWISE_CONV_2D, - .inputs = CreateOpSignatureTensorSpecs( - std::vector{kTfLiteInt8, kTfLiteInt8}), - .outputs = CreateOpSignatureTensorSpecs(kTfLiteInt8), - .builtin_data = reinterpret_cast(&depthwise_conv_params), - }; - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3); - - fake_op_sig = { - .op = BuiltinOperator_DEPTHWISE_CONV_2D, - .inputs = CreateOpSignatureTensorSpecs( - std::vector{kTfLiteFloat32, kTfLiteFloat32}), - .outputs = CreateOpSignatureTensorSpecs(kTfLiteFloat32), - .builtin_data = reinterpret_cast(&depthwise_conv_params), - }; - depthwise_conv_params.dilation_width_factor = 2; - depthwise_conv_params.dilation_height_factor = 2; - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); - - fake_op_sig = { - .op = BuiltinOperator_DEPTHWISE_CONV_2D, - .inputs = CreateOpSignatureTensorSpecs( - std::vector{kTfLiteFloat32, kTfLiteFloat32}), - .outputs = CreateOpSignatureTensorSpecs(kTfLiteFloat32), - .builtin_data = reinterpret_cast(&depthwise_conv_params), - }; - depthwise_conv_params.dilation_width_factor = 1; - depthwise_conv_params.dilation_height_factor = 1; - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1); -} -TEST(OpVersionTest, VersioningTileOperatorTest) { - OpSignature fake_op_sig = { - .op = BuiltinOperator_TILE, - .inputs = CreateOpSignatureTensorSpecs(kTfLiteInt32), - }; - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1); - - fake_op_sig = { - .op = BuiltinOperator_TILE, - .inputs = CreateOpSignatureTensorSpecs(kTfLiteString), - }; - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); -} -TEST(OpVersionTest, VersioningTransposeTest) { - OpSignature fake_op_sig = { - .op = BuiltinOperator_TRANSPOSE, - .inputs = CreateOpSignatureTensorSpecs(kTfLiteInt16), - }; - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 5); - - fake_op_sig = { - .op = BuiltinOperator_TRANSPOSE, - }; - fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteBool, 5); - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 4); - fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteBool, 4); - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3); - - fake_op_sig = { - .op = BuiltinOperator_TRANSPOSE, - .inputs = CreateOpSignatureTensorSpecs(kTfLiteInt8), - }; - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); - - fake_op_sig = { - .op = BuiltinOperator_TRANSPOSE, - .inputs = CreateOpSignatureTensorSpecs(kTfLiteUInt8), - }; - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1); -} -TEST(OpVersionTest, VersioningGatherNdOperatorTest) { - OpSignature fake_op_sig = { - .op = BuiltinOperator_GATHER_ND, - .inputs = CreateOpSignatureTensorSpecs( - std::vector{kTfLiteInt32, kTfLiteInt32}), - }; - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1); - - fake_op_sig = { - .op = BuiltinOperator_GATHER_ND, - .inputs = CreateOpSignatureTensorSpecs( - std::vector{kTfLiteString, kTfLiteInt32}), - }; - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); - - fake_op_sig = { - .op = BuiltinOperator_GATHER_ND, - .inputs = CreateOpSignatureTensorSpecs( - std::vector{kTfLiteInt16, kTfLiteInt32}), - }; - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3); - - fake_op_sig = { - .op = BuiltinOperator_GATHER_ND, - .inputs = CreateOpSignatureTensorSpecs( - std::vector{kTfLiteInt32, kTfLiteInt16}), - }; - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 4); - - fake_op_sig = { - .op = BuiltinOperator_GATHER_ND, - .inputs = CreateOpSignatureTensorSpecs( - std::vector{kTfLiteBool, kTfLiteInt16}), - }; - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 5); -} -TEST(OpVersionTest, VersioningDivTest) { - OpSignature fake_op_sig = { - .op = BuiltinOperator_DIV, - }; - fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteUInt8, 5, 4); - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); - fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteUInt8, 5, 5); - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1); - fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteUInt8, 4, 4); - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1); -} -TEST(OpVersionTEst, VersioningFillTest) { - OpSignature fake_op_sig = {BuiltinOperator_FILL}; - fake_op_sig.inputs = CreateOpSignatureTensorSpecs( - std::vector{kTfLiteInt32, kTfLiteFloat16}); - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 4); - fake_op_sig.inputs = CreateOpSignatureTensorSpecs( - std::vector{kTfLiteInt64, kTfLiteFloat16}); - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 4); - fake_op_sig.inputs = CreateOpSignatureTensorSpecs( - std::vector{kTfLiteInt32, kTfLiteInt8}); - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3); - fake_op_sig.inputs = CreateOpSignatureTensorSpecs( - std::vector{kTfLiteInt64, kTfLiteInt16}); - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3); - fake_op_sig.inputs = CreateOpSignatureTensorSpecs( - std::vector{kTfLiteInt32, kTfLiteBool}); - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); - fake_op_sig.inputs = CreateOpSignatureTensorSpecs( - std::vector{kTfLiteInt32, kTfLiteString}); - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); - fake_op_sig.inputs = CreateOpSignatureTensorSpecs( - std::vector{kTfLiteInt32, kTfLiteInt32}); - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1); -} -TEST(OpVersionTest, VersioningResizeBilinearTest) { - // Default. - TfLiteResizeBilinearParams resize_bilinear_params = {}; - OpSignature fake_op_sig = { - .op = BuiltinOperator_RESIZE_BILINEAR, - .inputs = CreateOpSignatureTensorSpecs( - std::vector{kTfLiteFloat32, kTfLiteInt32}), - .outputs = CreateOpSignatureTensorSpecs(kTfLiteFloat32), - .builtin_data = reinterpret_cast(&resize_bilinear_params), - }; - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1); - - // align_corners=true is still version 1. - resize_bilinear_params.align_corners = true; - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1); - - // half_pixel_centers=true must be version 3. - resize_bilinear_params.align_corners = false; - resize_bilinear_params.half_pixel_centers = true; - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3); - - // int8 input is version 2. - resize_bilinear_params = {}; - fake_op_sig = { - .op = BuiltinOperator_RESIZE_BILINEAR, - .inputs = CreateOpSignatureTensorSpecs( - std::vector{kTfLiteInt8, kTfLiteInt32}), - .outputs = CreateOpSignatureTensorSpecs(kTfLiteInt8), - .builtin_data = reinterpret_cast(&resize_bilinear_params), - }; - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); - - resize_bilinear_params.half_pixel_centers = true; - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3); - - // int16 input is version 4. - resize_bilinear_params = {}; - fake_op_sig = { - .op = BuiltinOperator_RESIZE_BILINEAR, - .inputs = CreateOpSignatureTensorSpecs( - std::vector{kTfLiteInt16, kTfLiteInt32}), - .outputs = CreateOpSignatureTensorSpecs(kTfLiteInt16), - .builtin_data = reinterpret_cast(&resize_bilinear_params), - }; - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 4); -} -TEST(OpVersionTest, VersioningResizeNearestNeighborTest) { - // Default. - TfLiteResizeNearestNeighborParams resize_nearest_neighbor_params = {}; - OpSignature fake_op_sig = { - .op = BuiltinOperator_RESIZE_NEAREST_NEIGHBOR, - .inputs = CreateOpSignatureTensorSpecs( - std::vector{kTfLiteFloat32, kTfLiteInt32}), - .outputs = CreateOpSignatureTensorSpecs(kTfLiteFloat32), - .builtin_data = reinterpret_cast(&resize_nearest_neighbor_params), - }; - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1); - - // align_corners=true is version 3. - resize_nearest_neighbor_params.align_corners = true; - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3); - - // half_pixel_centers=true must be version 3. - resize_nearest_neighbor_params.align_corners = false; - resize_nearest_neighbor_params.half_pixel_centers = true; - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3); - - // int8 input is version 2. - resize_nearest_neighbor_params = {}; - fake_op_sig = { - .op = BuiltinOperator_RESIZE_NEAREST_NEIGHBOR, - .inputs = CreateOpSignatureTensorSpecs( - std::vector{kTfLiteInt8, kTfLiteInt32}), - .outputs = CreateOpSignatureTensorSpecs(kTfLiteInt8), - .builtin_data = reinterpret_cast(&resize_nearest_neighbor_params), - }; - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); - - resize_nearest_neighbor_params.align_corners = true; - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3); - - // int16 input is version 4. - resize_nearest_neighbor_params = {}; - fake_op_sig = { - .op = BuiltinOperator_RESIZE_NEAREST_NEIGHBOR, - .inputs = CreateOpSignatureTensorSpecs( - std::vector{kTfLiteInt16, kTfLiteInt32}), - .outputs = CreateOpSignatureTensorSpecs(kTfLiteInt16), - .builtin_data = reinterpret_cast(&resize_nearest_neighbor_params), - }; - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 4); -} -TEST(OpVersionTest, VersioningAbsTest) { - // Default. - OpSignature fake_op_sig = { - .op = BuiltinOperator_ABS, - .inputs = CreateOpSignatureTensorSpecs(kTfLiteFloat32), - .outputs = CreateOpSignatureTensorSpecs(kTfLiteFloat32), - }; - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1); - - // int8 input is version 2. - fake_op_sig = { - .op = BuiltinOperator_ABS, - .inputs = CreateOpSignatureTensorSpecs(kTfLiteInt8), - .outputs = CreateOpSignatureTensorSpecs(kTfLiteInt8), - }; - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); - - // int16 quantized input is version 3. - fake_op_sig = { - .op = BuiltinOperator_ABS, - .inputs = CreateOpSignatureTensorSpecs(kTfLiteInt16), - .outputs = CreateOpSignatureTensorSpecs(kTfLiteInt16), - }; - fake_op_sig.ext_options.abs.input_quantized = true; - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3); - // int16 non-quantized input is version 4. - fake_op_sig = { - .op = BuiltinOperator_ABS, - .inputs = CreateOpSignatureTensorSpecs(kTfLiteInt16), - .outputs = CreateOpSignatureTensorSpecs(kTfLiteInt16), - }; - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 4); - fake_op_sig = {}; - fake_op_sig.op = BuiltinOperator_ABS; - fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt32); - fake_op_sig.outputs = CreateOpSignatureTensorSpecs(kTfLiteInt32); - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 5); -} -TEST(OpVersionTest, VersioningSignTest) { - // Default. - OpSignature fake_op_sig; - fake_op_sig.op = BuiltinOperator_SIGN; - fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteFloat32); - fake_op_sig.outputs = CreateOpSignatureTensorSpecs(kTfLiteFloat32); - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1); - - // int32 input is version 2. - fake_op_sig = {}; - fake_op_sig.op = BuiltinOperator_SIGN; - fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt32); - fake_op_sig.outputs = CreateOpSignatureTensorSpecs(kTfLiteInt32); - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); -} -TEST(OpVersionTest, VersioningBatchMatMulTest) { - // Default. - TfLiteBatchMatMulParams batch_mat_mul_params = {}; - OpSignature fake_op_sig = { - .op = BuiltinOperator_BATCH_MATMUL, - .inputs = CreateOpSignatureTensorSpecs( - std::vector{kTfLiteFloat32, kTfLiteFloat32}), - .outputs = CreateOpSignatureTensorSpecs(kTfLiteFloat32), - .builtin_data = reinterpret_cast(&batch_mat_mul_params), - }; - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1); - - // int8 input is version 2. - batch_mat_mul_params = {}; - fake_op_sig = { - .op = BuiltinOperator_BATCH_MATMUL, - .inputs = CreateOpSignatureTensorSpecs( - std::vector{kTfLiteInt8, kTfLiteInt8}), - .outputs = CreateOpSignatureTensorSpecs(kTfLiteInt8), - .builtin_data = reinterpret_cast(&batch_mat_mul_params), - }; - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); - - // int16 input is version 3. - fake_op_sig = { - .op = BuiltinOperator_BATCH_MATMUL, - .inputs = CreateOpSignatureTensorSpecs( - std::vector{kTfLiteInt16, kTfLiteInt8}), - .outputs = CreateOpSignatureTensorSpecs(kTfLiteInt16), - .builtin_data = reinterpret_cast(&batch_mat_mul_params), - }; - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3); - - // Symmetric hybrid quantized input is version 1. - fake_op_sig = { - .op = BuiltinOperator_BATCH_MATMUL, - .inputs = CreateOpSignatureTensorSpecs( - std::vector{kTfLiteFloat32, kTfLiteInt8}), - .outputs = CreateOpSignatureTensorSpecs(kTfLiteFloat32), - .builtin_data = reinterpret_cast(&batch_mat_mul_params), - }; - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1); - - // Asymmetric hybrid quantized input is version 4. - fake_op_sig = { - .op = BuiltinOperator_BATCH_MATMUL, - .inputs = CreateOpSignatureTensorSpecs( - std::vector{kTfLiteFloat32, kTfLiteInt8}), - .outputs = CreateOpSignatureTensorSpecs(kTfLiteFloat32), - .builtin_data = reinterpret_cast(&batch_mat_mul_params), - }; - batch_mat_mul_params.asymmetric_quantize_inputs = true; - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 4); -} -TEST(OpVersionTest, VersioningSquaredDifferenceTest) { - // Default. - OpSignature fake_op_sig = { - .op = BuiltinOperator_SQUARED_DIFFERENCE, - .inputs = CreateOpSignatureTensorSpecs( - std::vector{kTfLiteFloat32, kTfLiteFloat32}), - .outputs = CreateOpSignatureTensorSpecs(kTfLiteFloat32), - }; - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1); - - // int8 input is version 2. - fake_op_sig = { - .op = BuiltinOperator_SQUARED_DIFFERENCE, - .inputs = CreateOpSignatureTensorSpecs( - std::vector{kTfLiteInt8, kTfLiteInt8}), - .outputs = CreateOpSignatureTensorSpecs(kTfLiteInt8), - }; - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); -} -TEST(OpVersionTest, VersioningRsqrtTest) { - OpSignature fake_op_sig = {}; - fake_op_sig.op = BuiltinOperator_RSQRT; - fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteFloat32); - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1); - - fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt8); - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); - - fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt16); - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3); -} -TEST(OpVersionTest, VersioningBroadcastToTest) { - OpSignature fake_op_sig = { - .op = BuiltinOperator_BROADCAST_TO, - .inputs = CreateOpSignatureTensorSpecs(kTfLiteFloat32), - .outputs = CreateOpSignatureTensorSpecs(kTfLiteFloat32), - }; - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); - - // Quantized broadcast_to op is version 3. - fake_op_sig = { - .op = BuiltinOperator_BROADCAST_TO, - .inputs = CreateOpSignatureTensorSpecs(kTfLiteInt8), - .outputs = CreateOpSignatureTensorSpecs(kTfLiteInt8), - }; - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3); - - fake_op_sig = { - .op = BuiltinOperator_BROADCAST_TO, - .inputs = CreateOpSignatureTensorSpecs(kTfLiteInt16), - .outputs = CreateOpSignatureTensorSpecs(kTfLiteInt16), - }; - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3); -} - -TEST(OpVersionTest, VersioningGeluTest) { - OpSignature fake_op_sig; - fake_op_sig.op = BuiltinOperator_GELU; - fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteFloat32); - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1); - - fake_op_sig.op = BuiltinOperator_GELU; - fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt8); - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); - - fake_op_sig.op = BuiltinOperator_GELU; - fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteUInt8); - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); -} - -TEST(OpVersionTest, VersioningUnidirectionalLstmTest) { - TfLiteUnidirectionalSequenceLSTMParams params = {}; - OpSignature fake_op_sig = {}; - fake_op_sig.op = BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM; - fake_op_sig.inputs = CreateOpSignatureTensorSpecs( - std::vector{kTfLiteFloat32, kTfLiteFloat32, kTfLiteFloat32}); - fake_op_sig.outputs = CreateOpSignatureTensorSpecs(kTfLiteFloat32); - fake_op_sig.builtin_data = reinterpret_cast(¶ms); - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1); - - fake_op_sig.inputs = CreateOpSignatureTensorSpecs( - std::vector{kTfLiteFloat32, kTfLiteFloat32, kTfLiteInt8}); - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); - - params.asymmetric_quantize_inputs = true; - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3); - - params.diagonal_recurrent_tensors = true; - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 4); -} - -TEST(OpVersionTest, VersioningExpTest) { - OpSignature fake_op_sig = { - .op = BuiltinOperator_EXP, - .inputs = CreateOpSignatureTensorSpecs(kTfLiteFloat32), - }; - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1); - fake_op_sig = { - .op = BuiltinOperator_EXP, - .inputs = CreateOpSignatureTensorSpecs(kTfLiteInt8), - }; - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); - fake_op_sig = { - .op = BuiltinOperator_EXP, - .inputs = CreateOpSignatureTensorSpecs(kTfLiteInt16), - }; - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); -} - -TEST(OpVersionTest, VersioningLogTest) { - OpSignature fake_op_sig = {}; - fake_op_sig.op = BuiltinOperator_LOG; - fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteFloat32); - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1); - - fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt8); - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); - - fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt16); - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); -} - -TEST(OpVersionTest, VersioningDynamicUpdateSliceTest) { - OpSignature fake_op_sig = {}; - fake_op_sig.op = BuiltinOperator_DYNAMIC_UPDATE_SLICE; - fake_op_sig.inputs = CreateOpSignatureTensorSpecs( - std::vector{kTfLiteFloat32, kTfLiteFloat32, kTfLiteInt32}); - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1); - - fake_op_sig.inputs = CreateOpSignatureTensorSpecs( - std::vector{kTfLiteFloat32, kTfLiteFloat32, kTfLiteInt64}); - EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); -} -} // namespace tflite diff --git a/tensorflow/compiler/mlir/lite/tools/versioning/runtime_version.cc b/tensorflow/compiler/mlir/lite/tools/versioning/runtime_version.cc deleted file mode 100644 index b3bbd7f3be3faa..00000000000000 --- a/tensorflow/compiler/mlir/lite/tools/versioning/runtime_version.cc +++ /dev/null @@ -1,510 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "tensorflow/compiler/mlir/lite/tools/versioning/runtime_version.h" - -#include -#include -#include -#include -#include -#include - -#include "absl/log/log.h" -#include "absl/strings/numbers.h" -#include "absl/strings/str_split.h" -#include "tensorflow/compiler/mlir/lite/schema/mutable/schema_generated.h" -#include "tensorflow/compiler/mlir/lite/schema/schema_utils.h" - -namespace tflite { - -bool CompareRuntimeVersion(const std::string& v1, const std::string& v2) { - const std::vector vec1 = absl::StrSplit(v1, '.'); - const std::vector vec2 = absl::StrSplit(v2, '.'); - int i = 0; - while (i < vec1.size() && i < vec2.size()) { - int v1_val, v2_val; - if (absl::SimpleAtoi(vec1[i], &v1_val) && - absl::SimpleAtoi(vec2[i], &v2_val)) { - if (v1_val != v2_val) return v1_val < v2_val; - } - ++i; - } - // If there are remaining items in v2 not being compared, then v1 should - // precede v2. - return i < vec2.size(); -} - -std::string FindMinimumRuntimeVersionForOp(tflite::BuiltinOperator op_code, - int op_version) { - // A map from the version key of an op to its minimum runtime version. - // For example, {{kAveragePool, 1}, "1.5.0"}, means the 1st version of - // AveragePool requires a minimum TF Lite runtime version '1.5.0`. - // NOTE: When adding a new op version pair, associate it with the current - // runtime version defined in tensorflow/core/public/version.h. - static const std::map, - std::string>* op_version_map = - new std::map, std::string>( - {{{BuiltinOperator_AVERAGE_POOL_2D, 1}, "1.5.0"}, - {{BuiltinOperator_AVERAGE_POOL_2D, 2}, "1.14.0"}, - {{BuiltinOperator_AVERAGE_POOL_2D, 3}, "2.3.0"}, - {{BuiltinOperator_BATCH_MATMUL, 1}, "2.3.0"}, - {{BuiltinOperator_BATCH_MATMUL, 2}, "2.3.0"}, - {{BuiltinOperator_BATCH_MATMUL, 3}, "2.4.0"}, - {{BuiltinOperator_BATCH_MATMUL, 4}, "2.5.0"}, - // The version one of broadcast to op won't be not supported since - // the version one was rollbacked and the builtin op code number - // has been changed because of builtin op code shortage problem. - {{BuiltinOperator_BROADCAST_TO, 2}, "2.5.0"}, - {{BuiltinOperator_BROADCAST_TO, 3}, "2.5.0"}, - {{BuiltinOperator_CONV_2D, 1}, "1.5.0"}, - {{BuiltinOperator_CONV_2D, 2}, "1.14.0"}, - {{BuiltinOperator_CONV_2D, 3}, "1.14.0"}, - {{BuiltinOperator_CONV_2D, 4}, "2.3.0"}, - {{BuiltinOperator_CONV_2D, 5}, "2.4.0"}, - {{BuiltinOperator_CONV_2D, 6}, "2.9.0"}, - {{BuiltinOperator_CONV_2D, 7}, "2.11.0"}, - {{BuiltinOperator_CONV_2D, 8}, "2.15.0"}, - {{BuiltinOperator_DEPTHWISE_CONV_2D, 1}, "1.5.0"}, - {{BuiltinOperator_DEPTHWISE_CONV_2D, 2}, "1.12.0"}, - {{BuiltinOperator_DEPTHWISE_CONV_2D, 3}, "1.14.0"}, - {{BuiltinOperator_DEPTHWISE_CONV_2D, 4}, "2.2.0"}, - {{BuiltinOperator_DEPTHWISE_CONV_2D, 5}, "2.3.0"}, - {{BuiltinOperator_DEPTHWISE_CONV_2D, 6}, "2.3.0"}, - {{BuiltinOperator_DEPTHWISE_CONV_2D, 7}, "2.11.0"}, - {{BuiltinOperator_ADD, 1}, "1.5.0"}, - {{BuiltinOperator_ADD, 2}, "1.14.0"}, - {{BuiltinOperator_ADD, 3}, "2.4.0"}, - {{BuiltinOperator_ADD, 4}, "2.6.0"}, - {{BuiltinOperator_ADD, 5}, "2.13.0"}, - {{BuiltinOperator_ADD_N, 1}, "1.14.0"}, - {{BuiltinOperator_SPACE_TO_BATCH_ND, 1}, "1.6.0"}, - {{BuiltinOperator_SPACE_TO_BATCH_ND, 2}, "1.14.0"}, - {{BuiltinOperator_SPACE_TO_BATCH_ND, 3}, "2.3.0"}, - {{BuiltinOperator_SPACE_TO_BATCH_ND, 4}, "2.12.0"}, - {{BuiltinOperator_SUB, 1}, "1.6.0"}, - {{BuiltinOperator_SUB, 2}, "1.14.0"}, - {{BuiltinOperator_SUB, 3}, "2.3.0"}, - {{BuiltinOperator_SUB, 4}, "2.4.0"}, - {{BuiltinOperator_SUB, 5}, "2.4.0"}, - {{BuiltinOperator_DENSIFY, 1}, "2.2.0"}, - {{BuiltinOperator_DIV, 1}, "1.6.0"}, - {{BuiltinOperator_DIV, 2}, "2.3.0"}, - {{BuiltinOperator_BATCH_TO_SPACE_ND, 1}, "1.6.0"}, - {{BuiltinOperator_BATCH_TO_SPACE_ND, 2}, "1.14.0"}, - {{BuiltinOperator_BATCH_TO_SPACE_ND, 3}, "2.3.0"}, - {{BuiltinOperator_BATCH_TO_SPACE_ND, 4}, "2.12.0"}, - {{BuiltinOperator_CAST, 1}, "1.5.0"}, - {{BuiltinOperator_CAST, 2}, "2.7.0"}, - {{BuiltinOperator_CAST, 3}, "2.8.0"}, - {{BuiltinOperator_CAST, 4}, "2.9.0"}, - {{BuiltinOperator_CAST, 5}, "2.12.0"}, - {{BuiltinOperator_CAST, 6}, "2.15.0"}, - {{BuiltinOperator_CONCATENATION, 1}, "1.5.0"}, - {{BuiltinOperator_CONCATENATION, 2}, "1.14.0"}, - {{BuiltinOperator_CONCATENATION, 3}, "2.3.0"}, - {{BuiltinOperator_CONCATENATION, 4}, "2.14.0"}, - {{BuiltinOperator_DEPTH_TO_SPACE, 1}, "2.1.0"}, - {{BuiltinOperator_DEPTH_TO_SPACE, 2}, "2.5.0"}, - {{BuiltinOperator_EMBEDDING_LOOKUP, 1}, "1.13.0"}, - {{BuiltinOperator_EMBEDDING_LOOKUP, 2}, "1.14.0"}, - {{BuiltinOperator_EMBEDDING_LOOKUP, 3}, "1.14.0"}, - {{BuiltinOperator_EMBEDDING_LOOKUP, 4}, "2.18.0"}, - {{BuiltinOperator_EMBEDDING_LOOKUP_SPARSE, 1}, "1.5.0"}, - {{BuiltinOperator_FAKE_QUANT, 1}, "1.5.0"}, - {{BuiltinOperator_FAKE_QUANT, 2}, "1.10.0"}, - {{BuiltinOperator_FULLY_CONNECTED, 1}, "1.5.0"}, - {{BuiltinOperator_FULLY_CONNECTED, 2}, "1.10.0"}, - {{BuiltinOperator_FULLY_CONNECTED, 3}, "1.14.0"}, - {{BuiltinOperator_FULLY_CONNECTED, 4}, "1.14.0"}, - {{BuiltinOperator_FULLY_CONNECTED, 5}, "2.0.0"}, - {{BuiltinOperator_FULLY_CONNECTED, 6}, "2.1.0"}, - {{BuiltinOperator_FULLY_CONNECTED, 7}, "2.3.0"}, - {{BuiltinOperator_FULLY_CONNECTED, 8}, "2.3.0"}, - {{BuiltinOperator_FULLY_CONNECTED, 9}, "2.3.0"}, - {{BuiltinOperator_FULLY_CONNECTED, 10}, "2.11.0"}, - {{BuiltinOperator_FULLY_CONNECTED, 11}, "2.15.0"}, - {{BuiltinOperator_FULLY_CONNECTED, 12}, "2.17.0"}, - {{BuiltinOperator_FULLY_CONNECTED, 13}, "2.18.0"}, - {{BuiltinOperator_GATHER, 1}, "1.6.0"}, - {{BuiltinOperator_GATHER, 2}, "1.14.0"}, - {{BuiltinOperator_GATHER, 3}, "1.15.0"}, - {{BuiltinOperator_GATHER, 4}, "2.4.0"}, - {{BuiltinOperator_GATHER, 5}, "2.5.0"}, - {{BuiltinOperator_GATHER, 6}, "2.13.0"}, - {{BuiltinOperator_GATHER, 7}, "2.15.0"}, - {{BuiltinOperator_GATHER_ND, 1}, "1.14.0"}, - {{BuiltinOperator_GATHER_ND, 2}, "2.3.0"}, - {{BuiltinOperator_GATHER_ND, 3}, "2.5.0"}, - {{BuiltinOperator_GATHER_ND, 4}, "2.13.0"}, - {{BuiltinOperator_GATHER_ND, 5}, "2.16.0"}, - {{BuiltinOperator_HASHTABLE_LOOKUP, 1}, "1.5.0"}, - {{BuiltinOperator_SVDF, 1}, "1.5.0"}, - {{BuiltinOperator_SVDF, 2}, "1.14.0"}, - {{BuiltinOperator_SVDF, 3}, "2.2.0"}, - {{BuiltinOperator_SVDF, 4}, "2.3.0"}, - {{BuiltinOperator_L2_NORMALIZATION, 1}, "1.5.0"}, - {{BuiltinOperator_L2_NORMALIZATION, 2}, "1.14.0"}, - {{BuiltinOperator_L2_POOL_2D, 1}, "1.5.0"}, - {{BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION, 1}, "1.5.0"}, - {{BuiltinOperator_MAX_POOL_2D, 1}, "1.5.0"}, - {{BuiltinOperator_MAX_POOL_2D, 2}, "1.14.0"}, - {{BuiltinOperator_MAX_POOL_2D, 3}, "2.3.0"}, - {{BuiltinOperator_MAXIMUM, 1}, "1.14.0"}, - {{BuiltinOperator_MAXIMUM, 2}, "1.14.0"}, - {{BuiltinOperator_MAXIMUM, 3}, "2.3.0"}, - {{BuiltinOperator_MAXIMUM, 4}, "2.3.0"}, - {{BuiltinOperator_MINIMUM, 1}, "1.14.0"}, - {{BuiltinOperator_MINIMUM, 2}, "1.14.0"}, - {{BuiltinOperator_MINIMUM, 3}, "2.3.0"}, - {{BuiltinOperator_MINIMUM, 4}, "2.3.0"}, - {{BuiltinOperator_MUL, 1}, "1.5.0"}, - {{BuiltinOperator_MUL, 2}, "1.14.0"}, - {{BuiltinOperator_MUL, 3}, "1.15.0"}, - {{BuiltinOperator_MUL, 4}, "2.3.0"}, - {{BuiltinOperator_MUL, 5}, "2.6.0"}, - {{BuiltinOperator_MUL, 6}, "2.11.0"}, - {{BuiltinOperator_MUL, 7}, "2.13.0"}, - {{BuiltinOperator_NON_MAX_SUPPRESSION_V4, 1}, "2.1.0"}, - {{BuiltinOperator_NON_MAX_SUPPRESSION_V5, 1}, "2.1.0"}, - {{BuiltinOperator_PAD, 1}, "1.5.0"}, - {{BuiltinOperator_PAD, 2}, "1.14.0"}, - {{BuiltinOperator_PAD, 3}, "2.4.0"}, - {{BuiltinOperator_PAD, 4}, "2.6.0"}, - {{BuiltinOperator_TILE, 1}, "1.10.1"}, - {{BuiltinOperator_TILE, 2}, "2.2.0"}, - {{BuiltinOperator_TILE, 3}, "2.8.0"}, - {{BuiltinOperator_PADV2, 1}, "1.9.0"}, - {{BuiltinOperator_PADV2, 2}, "1.14.0"}, - {{BuiltinOperator_PADV2, 3}, "2.4.0"}, - {{BuiltinOperator_PADV2, 4}, "2.6.0"}, - {{BuiltinOperator_RESHAPE, 1}, "1.5.0"}, - {{BuiltinOperator_SOFTMAX, 1}, "1.5.0"}, - {{BuiltinOperator_SOFTMAX, 2}, "1.14.0"}, - {{BuiltinOperator_SOFTMAX, 3}, "2.3.0"}, - {{BuiltinOperator_SPACE_TO_DEPTH, 1}, "1.5.0"}, - {{BuiltinOperator_SPACE_TO_DEPTH, 2}, "1.14.0"}, - {{BuiltinOperator_TRANSPOSE, 1}, "1.6.0"}, - {{BuiltinOperator_TRANSPOSE, 2}, "1.14.0"}, - {{BuiltinOperator_TRANSPOSE, 3}, "1.15.0"}, - {{BuiltinOperator_TRANSPOSE, 4}, "2.3.0"}, - {{BuiltinOperator_TRANSPOSE, 5}, "2.4.0"}, - {{BuiltinOperator_TRANSPOSE, 6}, "2.12.0"}, - {{BuiltinOperator_LSTM, 1}, "1.7.0"}, - {{BuiltinOperator_LSTM, 2}, "1.10.0"}, - {{BuiltinOperator_LSTM, 3}, "1.14.0"}, - {{BuiltinOperator_LSTM, 4}, "2.3.0"}, - {{BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM, 1}, "1.13.1"}, - {{BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM, 2}, "1.14.0"}, - {{BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM, 3}, "2.3.0"}, - {{BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM, 4}, "2.12.0"}, - {{BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM, 1}, "1.14.0"}, - {{BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM, 2}, "1.14.0"}, - {{BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM, 3}, "1.14.0"}, - {{BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN, 1}, "1.14.0"}, - {{BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN, 2}, "1.14.0"}, - {{BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN, 3}, "2.3.0"}, - {{BuiltinOperator_MEAN, 1}, "1.6.0"}, - {{BuiltinOperator_MEAN, 2}, "1.14.0"}, - {{BuiltinOperator_MEAN, 3}, "2.4.0"}, - {{BuiltinOperator_SUM, 1}, "1.10.0"}, - {{BuiltinOperator_SUM, 2}, "1.15.0"}, - {{BuiltinOperator_REDUCE_MAX, 1}, "1.11.0"}, - {{BuiltinOperator_REDUCE_MAX, 2}, "1.14.0"}, - {{BuiltinOperator_REDUCE_MAX, 3}, "2.5.0"}, - {{BuiltinOperator_REDUCE_MIN, 1}, "1.11.0"}, - {{BuiltinOperator_REDUCE_MIN, 2}, "1.14.0"}, - {{BuiltinOperator_REDUCE_MIN, 3}, "2.5.0"}, - {{BuiltinOperator_REDUCE_PROD, 1}, "1.11.0"}, - {{BuiltinOperator_REDUCE_PROD, 2}, "2.6.0"}, - {{BuiltinOperator_REDUCE_ANY, 1}, "1.11.0"}, - {{BuiltinOperator_RELU6, 1}, "1.5.0"}, - {{BuiltinOperator_RELU6, 2}, "1.14.0"}, - {{BuiltinOperator_RELU6, 3}, "2.5.0"}, - {{BuiltinOperator_RESIZE_BILINEAR, 1}, "1.7.0"}, - {{BuiltinOperator_RESIZE_BILINEAR, 2}, "1.14.0"}, - {{BuiltinOperator_RESIZE_BILINEAR, 3}, "2.2.0"}, - {{BuiltinOperator_RESIZE_BILINEAR, 4}, "2.5.0"}, - {{BuiltinOperator_RESIZE_NEAREST_NEIGHBOR, 1}, "1.13.1"}, - {{BuiltinOperator_RESIZE_NEAREST_NEIGHBOR, 2}, "1.14.0"}, - {{BuiltinOperator_RESIZE_NEAREST_NEIGHBOR, 3}, "2.3.0"}, - {{BuiltinOperator_RESIZE_NEAREST_NEIGHBOR, 4}, "2.4.0"}, - {{BuiltinOperator_RNN, 1}, "1.5.0"}, - {{BuiltinOperator_RNN, 2}, "1.14.0"}, - {{BuiltinOperator_RNN, 3}, "2.3.0"}, - {{BuiltinOperator_SKIP_GRAM, 1}, "1.5.0"}, - {{BuiltinOperator_SQUEEZE, 1}, "1.6.0"}, - {{BuiltinOperator_SQUEEZE, 2}, "2.5.0"}, - {{BuiltinOperator_SPLIT, 1}, "1.5.0"}, - {{BuiltinOperator_SPLIT, 2}, "1.14.0"}, - {{BuiltinOperator_SPLIT, 3}, "1.14.0"}, - {{BuiltinOperator_SPLIT, 4}, "2.3.0"}, - {{BuiltinOperator_SPLIT_V, 1}, "1.13.1"}, - {{BuiltinOperator_SPLIT_V, 2}, "2.3.0"}, - {{BuiltinOperator_STRIDED_SLICE, 1}, "1.6.0"}, - {{BuiltinOperator_STRIDED_SLICE, 2}, "1.14.0"}, - {{BuiltinOperator_STRIDED_SLICE, 3}, "2.1.0"}, - {{BuiltinOperator_STRIDED_SLICE, 4}, "2.2.0"}, - {{BuiltinOperator_STRIDED_SLICE, 5}, "2.5.0"}, - {{BuiltinOperator_STRIDED_SLICE, 6}, "2.6.0"}, - {{BuiltinOperator_STRIDED_SLICE, 7}, "2.14.0"}, - {{BuiltinOperator_STRIDED_SLICE, 8}, "2.14.0"}, - {{BuiltinOperator_TOPK_V2, 1}, "1.7.0"}, - {{BuiltinOperator_TOPK_V2, 2}, "1.14.0"}, - {{BuiltinOperator_TOPK_V2, 3}, "2.13.0"}, - {{BuiltinOperator_ARG_MAX, 1}, "1.9.0"}, - {{BuiltinOperator_ARG_MAX, 2}, "1.14.0"}, - {{BuiltinOperator_ARG_MAX, 3}, "2.9.0"}, - {{BuiltinOperator_ARG_MIN, 1}, "1.9.0"}, - {{BuiltinOperator_ARG_MIN, 2}, "1.14.0"}, - {{BuiltinOperator_ARG_MIN, 3}, "2.9.0"}, - {{BuiltinOperator_TRANSPOSE_CONV, 1}, "1.9.0"}, - {{BuiltinOperator_TRANSPOSE_CONV, 2}, "2.2.0"}, - {{BuiltinOperator_TRANSPOSE_CONV, 3}, "2.3.0"}, - {{BuiltinOperator_TRANSPOSE_CONV, 4}, "2.13.0"}, - {{BuiltinOperator_TRANSPOSE_CONV, 5}, "2.15.0"}, - {{BuiltinOperator_SPARSE_TO_DENSE, 1}, "1.9.0"}, - {{BuiltinOperator_SPARSE_TO_DENSE, 2}, "1.14.0"}, - {{BuiltinOperator_SPARSE_TO_DENSE, 3}, "1.15.0"}, - {{BuiltinOperator_EXPAND_DIMS, 1}, "1.10.0"}, - {{BuiltinOperator_PACK, 1}, "1.11.0"}, - {{BuiltinOperator_PACK, 2}, "1.14.0"}, - {{BuiltinOperator_PACK, 3}, "2.3.0"}, - {{BuiltinOperator_PACK, 4}, "2.13.0"}, - {{BuiltinOperator_SHAPE, 1}, "1.10.0"}, - {{BuiltinOperator_SLICE, 1}, "1.14.0"}, - {{BuiltinOperator_SLICE, 2}, "1.14.0"}, - {{BuiltinOperator_SLICE, 3}, "1.14.0"}, - {{BuiltinOperator_SLICE, 4}, "2.4.0"}, - {{BuiltinOperator_SLICE, 5}, "2.5.0"}, - {{BuiltinOperator_SLICE, 6}, "2.14.0"}, - {{BuiltinOperator_TANH, 1}, "1.14.0"}, - {{BuiltinOperator_TANH, 2}, "1.14.0"}, - {{BuiltinOperator_TANH, 3}, "2.3.0"}, - {{BuiltinOperator_ONE_HOT, 1}, "1.11.0"}, - {{BuiltinOperator_UNPACK, 1}, "1.11.0"}, - {{BuiltinOperator_UNPACK, 2}, "1.14.0"}, - {{BuiltinOperator_UNPACK, 3}, "2.2.0"}, - {{BuiltinOperator_UNPACK, 4}, "2.3.0"}, - {{BuiltinOperator_LEAKY_RELU, 1}, "1.13.1"}, - {{BuiltinOperator_LEAKY_RELU, 2}, "2.3.0"}, - {{BuiltinOperator_LOGISTIC, 1}, "1.14.0"}, - {{BuiltinOperator_LOGISTIC, 2}, "1.14.0"}, - {{BuiltinOperator_LOGISTIC, 3}, "2.3.0"}, - {{BuiltinOperator_LOG_SOFTMAX, 1}, "1.14.0"}, - {{BuiltinOperator_LOG_SOFTMAX, 2}, "1.14.0"}, - {{BuiltinOperator_LSH_PROJECTION, 1}, "1.5.0"}, - {{BuiltinOperator_SQUARED_DIFFERENCE, 1}, "1.13.1"}, - {{BuiltinOperator_SQUARED_DIFFERENCE, 2}, "2.5.0"}, - {{BuiltinOperator_MIRROR_PAD, 1}, "1.13.1"}, - {{BuiltinOperator_MIRROR_PAD, 2}, "2.3.0"}, - {{BuiltinOperator_MIRROR_PAD, 3}, "2.12.0"}, - {{BuiltinOperator_UNIQUE, 1}, "1.14.0"}, - {{BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN, 1}, "1.14.0"}, - {{BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN, 2}, "1.14.0"}, - {{BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN, 3}, "2.3.0"}, - {{BuiltinOperator_WHERE, 1}, "1.14.0"}, - {{BuiltinOperator_DEQUANTIZE, 1}, "1.13.1"}, - {{BuiltinOperator_DEQUANTIZE, 2}, "1.14.0"}, - {{BuiltinOperator_DEQUANTIZE, 3}, "1.15.0"}, - {{BuiltinOperator_DEQUANTIZE, 4}, "2.2.0"}, - {{BuiltinOperator_DEQUANTIZE, 5}, "2.7.0"}, - {{BuiltinOperator_DEQUANTIZE, 6}, "2.18.0"}, - {{BuiltinOperator_REVERSE_SEQUENCE, 1}, "1.14.0"}, - {{BuiltinOperator_EQUAL, 1}, "1.14.0"}, - {{BuiltinOperator_EQUAL, 2}, "1.14.0"}, - {{BuiltinOperator_EQUAL, 3}, "2.3.0"}, - {{BuiltinOperator_EQUAL, 4}, "2.13.0"}, - {{BuiltinOperator_NOT_EQUAL, 1}, "1.14.0"}, - {{BuiltinOperator_NOT_EQUAL, 2}, "1.14.0"}, - {{BuiltinOperator_NOT_EQUAL, 3}, "2.3.0"}, - {{BuiltinOperator_GREATER, 1}, "1.14.0"}, - {{BuiltinOperator_GREATER, 2}, "1.14.0"}, - {{BuiltinOperator_GREATER_EQUAL, 1}, "1.14.0"}, - {{BuiltinOperator_GREATER_EQUAL, 2}, "1.14.0"}, - {{BuiltinOperator_GREATER_EQUAL, 3}, "2.13.0"}, - {{BuiltinOperator_LESS, 1}, "1.14.0"}, - {{BuiltinOperator_LESS, 2}, "1.14.0"}, - {{BuiltinOperator_LESS, 3}, "2.13.0"}, - {{BuiltinOperator_LESS_EQUAL, 1}, "1.14.0"}, - {{BuiltinOperator_LESS_EQUAL, 2}, "1.14.0"}, - {{BuiltinOperator_SCATTER_ND, 1}, "2.1.0"}, - {{BuiltinOperator_SEGMENT_SUM, 1}, "2.2.0"}, - {{BuiltinOperator_SELECT, 1}, "1.14.0"}, - {{BuiltinOperator_SELECT, 2}, "1.14.0"}, - {{BuiltinOperator_SELECT, 3}, "2.12.0"}, - {{BuiltinOperator_SELECT, 4}, "2.12.0"}, - {{BuiltinOperator_SELECT_V2, 1}, "2.2.0"}, - {{BuiltinOperator_SELECT_V2, 2}, "2.12.0"}, - {{BuiltinOperator_IF, 1}, "1.15.0"}, - {{BuiltinOperator_FLOOR_DIV, 1}, "1.14.0"}, - {{BuiltinOperator_FLOOR_DIV, 2}, "1.14.0"}, - {{BuiltinOperator_FLOOR_DIV, 3}, "2.13.0"}, - {{BuiltinOperator_FLOOR, 1}, "1.9.0"}, - {{BuiltinOperator_CEIL, 1}, "1.14.0"}, - {{BuiltinOperator_MATRIX_DIAG, 1}, "1.14.0"}, - {{BuiltinOperator_MATRIX_SET_DIAG, 1}, "1.14.0"}, - {{BuiltinOperator_ELU, 1}, "1.14.0"}, - {{BuiltinOperator_QUANTIZE, 1}, "1.14.0"}, - {{BuiltinOperator_QUANTIZE, 2}, "1.15.0"}, - {{BuiltinOperator_QUANTIZE, 3}, "2.7.0"}, - {{BuiltinOperator_ROUND, 1}, "1.14.0"}, - {{BuiltinOperator_RELU, 1}, "1.5.0"}, - {{BuiltinOperator_RELU, 2}, "2.1.0"}, - {{BuiltinOperator_RELU, 3}, "2.5.0"}, - {{BuiltinOperator_RELU_N1_TO_1, 1}, "1.5.0"}, - {{BuiltinOperator_RELU_0_TO_1, 1}, "2.10.0"}, - {{BuiltinOperator_PRELU, 1}, "1.8.0"}, - {{BuiltinOperator_EXP, 1}, "1.7.0"}, - {{BuiltinOperator_EXP, 2}, "2.12.0"}, - {{BuiltinOperator_COS, 1}, "1.14.0"}, - {{BuiltinOperator_NEG, 1}, "1.9.0"}, - {{BuiltinOperator_POW, 1}, "1.10.0"}, - {{BuiltinOperator_LOGICAL_OR, 1}, "1.11.0"}, - {{BuiltinOperator_LOGICAL_AND, 1}, "1.11.0"}, - {{BuiltinOperator_LOGICAL_NOT, 1}, "1.11.0"}, - {{BuiltinOperator_FLOOR_MOD, 1}, "1.13.0"}, - {{BuiltinOperator_FLOOR_MOD, 2}, "2.13.0"}, - {{BuiltinOperator_RANGE, 1}, "1.13.0"}, - {{BuiltinOperator_RANGE, 2}, "2.14.0"}, - {{BuiltinOperator_SIN, 1}, "1.9.0"}, - {{BuiltinOperator_LOG, 1}, "1.14.0"}, - {{BuiltinOperator_LOG, 2}, "2.15.0"}, - {{BuiltinOperator_SQRT, 1}, "1.10.0"}, - {{BuiltinOperator_RSQRT, 1}, "1.10.0"}, - {{BuiltinOperator_RSQRT, 2}, "2.5.0"}, - {{BuiltinOperator_RSQRT, 3}, "2.15.0"}, - {{BuiltinOperator_SQUARE, 1}, "1.12.0"}, - {{BuiltinOperator_ZEROS_LIKE, 1}, "1.12.0"}, - {{BuiltinOperator_ABS, 1}, "1.13.0"}, - {{BuiltinOperator_ABS, 2}, "2.4.0"}, - {{BuiltinOperator_ABS, 3}, "2.5.0"}, - {{BuiltinOperator_ABS, 4}, "2.6.0"}, - {{BuiltinOperator_ABS, 5}, "2.12.0"}, - {{BuiltinOperator_HARD_SWISH, 1}, "1.15.0"}, - {{BuiltinOperator_FILL, 1}, "1.13.0"}, - {{BuiltinOperator_FILL, 2}, "2.3.0"}, - {{BuiltinOperator_FILL, 3}, "2.5.0"}, - {{BuiltinOperator_FILL, 4}, "2.12.0"}, - {{BuiltinOperator_REVERSE_V2, 1}, "1.14.0"}, - {{BuiltinOperator_REVERSE_V2, 2}, "2.2.0"}, - {{BuiltinOperator_REVERSE_V2, 3}, "2.5.0"}, - {{BuiltinOperator_RANK, 1}, "1.14.0"}, - {{BuiltinOperator_WHILE, 1}, "1.15.0"}, - {{BuiltinOperator_CUMSUM, 1}, "2.4.0"}, - {{BuiltinOperator_CALL_ONCE, 1}, "2.5.0"}, - {{BuiltinOperator_RFFT2D, 1}, "2.5.0"}, - {{BuiltinOperator_CONV_3D, 1}, "2.5.0"}, - {{BuiltinOperator_IMAG, 1}, "2.5.0"}, - {{BuiltinOperator_REAL, 1}, "2.5.0"}, - {{BuiltinOperator_COMPLEX_ABS, 1}, "2.5.0"}, - {{BuiltinOperator_HASHTABLE, 1}, "2.5.0"}, - {{BuiltinOperator_HASHTABLE_FIND, 1}, "2.5.0"}, - {{BuiltinOperator_HASHTABLE_IMPORT, 1}, "2.5.0"}, - {{BuiltinOperator_HASHTABLE_SIZE, 1}, "2.5.0"}, - {{BuiltinOperator_REDUCE_ALL, 1}, "2.6.0"}, - {{BuiltinOperator_CONV_3D_TRANSPOSE, 1}, "2.6.0"}, - {{BuiltinOperator_VAR_HANDLE, 1}, "2.6.0"}, - {{BuiltinOperator_READ_VARIABLE, 1}, "2.6.0"}, - {{BuiltinOperator_ASSIGN_VARIABLE, 1}, "2.6.0"}, - {{BuiltinOperator_BROADCAST_ARGS, 1}, "2.6.0"}, - {{BuiltinOperator_RANDOM_STANDARD_NORMAL, 1}, "2.8.0"}, - {{BuiltinOperator_BUCKETIZE, 1}, "2.8.0"}, - {{BuiltinOperator_WHERE, 2}, "2.8.0"}, - {{BuiltinOperator_RANDOM_UNIFORM, 1}, "2.8.0"}, - {{BuiltinOperator_MULTINOMIAL, 1}, "2.8.0"}, - {{BuiltinOperator_GELU, 1}, "2.9.0"}, - {{BuiltinOperator_GELU, 2}, "2.9.0"}, - {{BuiltinOperator_DYNAMIC_UPDATE_SLICE, 1}, "2.9.0"}, - {{BuiltinOperator_DYNAMIC_UPDATE_SLICE, 2}, "2.17.0"}, - {{BuiltinOperator_UNSORTED_SEGMENT_PROD, 1}, "2.10.0"}, - {{BuiltinOperator_UNSORTED_SEGMENT_MAX, 1}, "2.10.0"}, - {{BuiltinOperator_UNSORTED_SEGMENT_MIN, 1}, "2.11.0"}, - {{BuiltinOperator_UNSORTED_SEGMENT_SUM, 1}, "2.10.0"}, - {{BuiltinOperator_ATAN2, 1}, "2.10.0"}, - {{BuiltinOperator_SIGN, 1}, "2.11.0"}, - {{BuiltinOperator_SIGN, 2}, "2.12.0"}, - {{BuiltinOperator_BITCAST, 1}, "2.13.0"}, - {{BuiltinOperator_BITWISE_XOR, 1}, "2.13.0"}, - {{BuiltinOperator_RIGHT_SHIFT, 1}, "2.13.0"}, - {{BuiltinOperator_STABLEHLO_SCATTER, 1}, "2.15.0"}, - {{BuiltinOperator_DILATE, 1}, "2.15.0"}, - {{BuiltinOperator_STABLEHLO_RNG_BIT_GENERATOR, 1}, "2.15.0"}, - {{BuiltinOperator_REDUCE_WINDOW, 1}, "2.15.0"}, - {{BuiltinOperator_STABLEHLO_GATHER, 1}, "2.16.0"}, - {{BuiltinOperator_STABLEHLO_ADD, 1}, "2.16.0"}, - {{BuiltinOperator_STABLEHLO_MULTIPLY, 1}, "2.16.0"}, - {{BuiltinOperator_STABLEHLO_REDUCE_WINDOW, 1}, "2.16.0"}, - {{BuiltinOperator_STABLEHLO_MAXIMUM, 1}, "2.16.0"}, - {{BuiltinOperator_STABLEHLO_MINIMUM, 1}, "2.16.0"}, - {{BuiltinOperator_STABLEHLO_PAD, 1}, "2.16.0"}, - {{BuiltinOperator_STABLEHLO_COMPOSITE, 1}, "2.17.0"}, - {{BuiltinOperator_STABLEHLO_AND, 1}, "2.17.0"}, - {{BuiltinOperator_STABLEHLO_SHIFT_LEFT, 1}, "2.17.0"}, - {{BuiltinOperator_STABLEHLO_CBRT, 1}, "2.17.0"}}); - - std::pair version_key = {op_code, op_version}; - auto it = op_version_map->find(version_key); - if (it == op_version_map->end()) { - return std::string(); - } - return it->second; -} - -void UpdateMinimumRuntimeVersionForModel(uint8_t* model_buffer_pointer) { - auto model = GetMutableModel(model_buffer_pointer); - std::string model_min_version; - auto subgraphs = model->subgraphs(); - for (int i = 0; i < subgraphs->size(); ++i) { - const SubGraph* subgraph = subgraphs->Get(i); - for (int j = 0; j < subgraph->operators()->size(); ++j) { - const Operator* op = subgraph->operators()->Get(j); - const OperatorCode* op_code = - model->operator_codes()->Get(op->opcode_index()); - std::string runtime_version = FindMinimumRuntimeVersionForOp( - GetBuiltinCode(op_code), op_code->version()); - // If we didn't find the current op version in the map, skip comparison. - if (runtime_version.empty()) { - continue; - } - if (CompareRuntimeVersion(model_min_version, runtime_version)) { - // Current min model runtime version should be bumped if we see a - // higher op version. - model_min_version = runtime_version; - } - } - } - // The size of the `min_runtime_version` metadata buffer is 16 bytes. If the - // generated `model_min_version` is equal or longer than 16 bytes, print a - // warning message and return. - if (model_min_version.size() >= 16) { - LOG(WARNING) << "Skip writing minimum runtime version string since it's " - << "longer than 16 bytes."; - return; - } - // Copy over the bytes from `model_min_version` into the buffer. - for (int i = 0; i < model->metadata()->size(); ++i) { - if (model->metadata()->Get(i)->name()->str() == "min_runtime_version") { - auto buffer = model->metadata()->Get(i)->buffer(); - auto buffer_data = - model->mutable_buffers()->GetMutableObject(buffer)->mutable_data(); - memset(buffer_data->data(), 0, buffer_data->size()); - memcpy(buffer_data->data(), model_min_version.data(), - model_min_version.size()); - break; - } - } -} - -} // namespace tflite diff --git a/tensorflow/compiler/mlir/lite/tools/versioning/runtime_version.h b/tensorflow/compiler/mlir/lite/tools/versioning/runtime_version.h deleted file mode 100644 index 7d586df5ab4c00..00000000000000 --- a/tensorflow/compiler/mlir/lite/tools/versioning/runtime_version.h +++ /dev/null @@ -1,40 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_MLIR_LITE_TOOLS_VERSIONING_RUNTIME_VERSION_H_ -#define TENSORFLOW_COMPILER_MLIR_LITE_TOOLS_VERSIONING_RUNTIME_VERSION_H_ - -#include -#include - -#include "flatbuffers/flatbuffers.h" // from @flatbuffers // IWYU pragma: keep -#include "tensorflow/compiler/mlir/lite/schema/mutable/schema_generated.h" - -namespace tflite { -// Update minimum runtime version of the given TFL flatbuffer model. -void UpdateMinimumRuntimeVersionForModel(uint8_t* model_buffer_pointer); - -// Find the minimum runtime version of a given op version. Return an empty -// string the version is not registered. -std::string FindMinimumRuntimeVersionForOp(tflite::BuiltinOperator op_code, - int op_version); - -// Returns true if the first version string precedes the second. -// For example, '1.9' should precede '1.14', also '1.14' should precede -// '1.14.1'. If two version string is equal, then false will be returned. -bool CompareRuntimeVersion(const std::string&, const std::string&); - -} // namespace tflite - -#endif // TENSORFLOW_COMPILER_MLIR_LITE_TOOLS_VERSIONING_RUNTIME_VERSION_H_ diff --git a/tensorflow/lite/CMakeLists.txt b/tensorflow/lite/CMakeLists.txt index 3a66e9f0874a7a..bce9627fbd3381 100644 --- a/tensorflow/lite/CMakeLists.txt +++ b/tensorflow/lite/CMakeLists.txt @@ -386,7 +386,6 @@ if(TFLITE_ENABLE_GPU) ${TFLITE_DELEGATES_GPU_COMMON_TRANSFORMATIONS_SRCS} ${TFLITE_SOURCE_DIR}/tools/versioning/gpu_compatibility.cc ${TFLITE_SOURCE_DIR}/tools/versioning/op_signature.cc - ${TF_SOURCE_DIR}/compiler/mlir/lite/tools/versioning/op_signature.cc ) include_directories( AFTER @@ -685,7 +684,6 @@ set(_ALL_TFLITE_SRCS ${TF_SOURCE_DIR}/compiler/mlir/lite/core/model_builder_base.cc ${TF_SOURCE_DIR}/compiler/mlir/lite/core/api/error_reporter.h ${TF_SOURCE_DIR}/compiler/mlir/lite/core/api/error_reporter.cc - ${TF_SOURCE_DIR}/compiler/mlir/lite/core/api/flatbuffer_conversions.cc ${TF_SOURCE_DIR}/compiler/mlir/lite/core/api/verifier.h ${TF_SOURCE_DIR}/compiler/mlir/lite/allocation.h ${TF_SOURCE_DIR}/compiler/mlir/lite/allocation.cc diff --git a/tensorflow/lite/kernels/CMakeLists.txt b/tensorflow/lite/kernels/CMakeLists.txt index b63f7f5b6cc8d6..946b56353e6c15 100644 --- a/tensorflow/lite/kernels/CMakeLists.txt +++ b/tensorflow/lite/kernels/CMakeLists.txt @@ -154,7 +154,6 @@ target_link_libraries(tensorflow-lite-test-external-main tensorflow-lite-test-base -Wl,--no-whole-archive gtest - absl::log ) macro(add_kernel_test TEST_SRC TEST_LIB) diff --git a/tensorflow/lite/toco/tflite/BUILD b/tensorflow/lite/toco/tflite/BUILD index cfcb613719baca..7377ec00d6b666 100644 --- a/tensorflow/lite/toco/tflite/BUILD +++ b/tensorflow/lite/toco/tflite/BUILD @@ -27,8 +27,6 @@ cc_library( deps = [ ":types", "//tensorflow/compiler/mlir/lite/delegates/flex:allowlisted_flex_ops_lib", - "//tensorflow/compiler/mlir/lite/tools/versioning", - "//tensorflow/compiler/mlir/lite/tools/versioning:op_signature", "//tensorflow/core:framework", "//tensorflow/core:protos_all_cc", "//tensorflow/lite/c:c_api_types", @@ -38,6 +36,8 @@ cc_library( "//tensorflow/lite/toco:model", "//tensorflow/lite/toco:runtime", "//tensorflow/lite/toco:toco_port", + "//tensorflow/lite/tools/versioning", + "//tensorflow/lite/tools/versioning:op_signature", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/memory", @@ -114,7 +114,6 @@ cc_library( ":types", "//tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy:quantize_weights", "//tensorflow/compiler/mlir/lite/schema:schema_conversion_utils", - "//tensorflow/compiler/mlir/lite/tools/versioning", "//tensorflow/core:lib_proto_parsing", "//tensorflow/core/platform:status", "//tensorflow/lite:schema_fbs_version", @@ -124,6 +123,7 @@ cc_library( "//tensorflow/lite/toco:model", "//tensorflow/lite/toco:toco_port", "//tensorflow/lite/toco:tooling_util", + "//tensorflow/lite/tools/versioning", "@com_google_absl//absl/log", "@com_google_absl//absl/strings", "@flatbuffers", diff --git a/tensorflow/lite/toco/tflite/builtin_operator.h b/tensorflow/lite/toco/tflite/builtin_operator.h index fa4c1df9b9c50c..69a1d4e5970b0b 100644 --- a/tensorflow/lite/toco/tflite/builtin_operator.h +++ b/tensorflow/lite/toco/tflite/builtin_operator.h @@ -18,7 +18,6 @@ limitations under the License. #include #include "absl/memory/memory.h" -#include "tensorflow/compiler/mlir/lite/tools/versioning/op_version.h" #include "tensorflow/lite/toco/tflite/operator.h" namespace toco { diff --git a/tensorflow/lite/toco/tflite/export.cc b/tensorflow/lite/toco/tflite/export.cc index a1ca936464e95a..44223eac63c130 100644 --- a/tensorflow/lite/toco/tflite/export.cc +++ b/tensorflow/lite/toco/tflite/export.cc @@ -25,7 +25,6 @@ limitations under the License. #include "flatbuffers/string.h" // from @flatbuffers #include "tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/quantize_weights.h" #include "tensorflow/compiler/mlir/lite/schema/schema_conversion_utils.h" -#include "tensorflow/compiler/mlir/lite/tools/versioning/runtime_version.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/lite/schema/schema_generated.h" @@ -34,6 +33,7 @@ limitations under the License. #include "tensorflow/lite/toco/tflite/types.h" #include "tensorflow/lite/toco/toco_types.h" #include "tensorflow/lite/toco/tooling_util.h" +#include "tensorflow/lite/tools/versioning/runtime_version.h" #include "tensorflow/lite/util.h" #include "tensorflow/lite/version.h" diff --git a/tensorflow/lite/toco/tflite/operator.cc b/tensorflow/lite/toco/tflite/operator.cc index c73e30781faf09..428afa0a9076e2 100644 --- a/tensorflow/lite/toco/tflite/operator.cc +++ b/tensorflow/lite/toco/tflite/operator.cc @@ -31,8 +31,6 @@ limitations under the License. // graph_transformation module. #include "tensorflow/compiler/mlir/lite/delegates/flex/allowlisted_flex_ops.h" -#include "tensorflow/compiler/mlir/lite/tools/versioning/op_signature.h" -#include "tensorflow/compiler/mlir/lite/tools/versioning/op_version.h" #include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/core/c/builtin_op_data.h" #include "tensorflow/lite/schema/schema_generated.h" @@ -44,6 +42,8 @@ limitations under the License. #include "tensorflow/lite/toco/tflite/simple_operator.h" #include "tensorflow/lite/toco/tflite/types.h" #include "tensorflow/lite/toco/toco_types.h" +#include "tensorflow/lite/tools/versioning/op_signature.h" +#include "tensorflow/lite/tools/versioning/op_version.h" namespace toco { diff --git a/tensorflow/lite/toco/tflite/operator.h b/tensorflow/lite/toco/tflite/operator.h index 836c287674e084..7b8b6b64e21e83 100644 --- a/tensorflow/lite/toco/tflite/operator.h +++ b/tensorflow/lite/toco/tflite/operator.h @@ -22,9 +22,10 @@ limitations under the License. #include "flatbuffers/buffer.h" // from @flatbuffers #include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers #include "flatbuffers/vector.h" // from @flatbuffers -#include "tensorflow/compiler/mlir/lite/tools/versioning/op_signature.h" #include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/tools/versioning/op_signature.h" +#include "tensorflow/lite/tools/versioning/op_version.h" namespace toco { diff --git a/tensorflow/lite/toco/tflite/simple_operator.h b/tensorflow/lite/toco/tflite/simple_operator.h index 7f26ee2eaac339..150b0d0721706e 100644 --- a/tensorflow/lite/toco/tflite/simple_operator.h +++ b/tensorflow/lite/toco/tflite/simple_operator.h @@ -15,7 +15,6 @@ limitations under the License. #ifndef TENSORFLOW_LITE_TOCO_TFLITE_SIMPLE_OPERATOR_H_ #define TENSORFLOW_LITE_TOCO_TFLITE_SIMPLE_OPERATOR_H_ -#include "tensorflow/compiler/mlir/lite/tools/versioning/op_version.h" #include "tensorflow/lite/toco/tflite/operator.h" namespace toco { diff --git a/tensorflow/lite/tools/benchmark/CMakeLists.txt b/tensorflow/lite/tools/benchmark/CMakeLists.txt index 79986f6c0ecd29..477254a42ecb9e 100644 --- a/tensorflow/lite/tools/benchmark/CMakeLists.txt +++ b/tensorflow/lite/tools/benchmark/CMakeLists.txt @@ -52,7 +52,6 @@ list(APPEND TFLITE_BENCHMARK_LIBS example_proto model_runtime_info_proto protobuf::libprotobuf - absl::log ) # TODO(b/171007016): Enable performance options on Windows. diff --git a/tensorflow/lite/tools/versioning/BUILD b/tensorflow/lite/tools/versioning/BUILD index c0945a5c7c0fa6..f173ce2c89734b 100644 --- a/tensorflow/lite/tools/versioning/BUILD +++ b/tensorflow/lite/tools/versioning/BUILD @@ -66,7 +66,6 @@ cc_library( ], compatible_with = get_compatible_with_portable(), deps = [ - "//tensorflow/compiler/mlir/lite/tools/versioning:op_signature", "//tensorflow/lite:stderr_reporter", "//tensorflow/lite/core/api", "//tensorflow/lite/core/c:c_api_types", diff --git a/tensorflow/lite/tools/versioning/op_signature.cc b/tensorflow/lite/tools/versioning/op_signature.cc index 19373155956f29..64b97924b47352 100644 --- a/tensorflow/lite/tools/versioning/op_signature.cc +++ b/tensorflow/lite/tools/versioning/op_signature.cc @@ -14,18 +14,88 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/lite/tools/versioning/op_signature.h" -#include -#include -#include +#include -#include "tensorflow/compiler/mlir/lite/tools/versioning/op_signature.h" -#include "tensorflow/lite/core/c/common.h" +#include "tensorflow/lite/core/api/flatbuffer_conversions.h" +#include "tensorflow/lite/kernels/internal/compatibility.h" #include "tensorflow/lite/kernels/kernel_util.h" #include "tensorflow/lite/schema/schema_generated.h" +#include "tensorflow/lite/schema/schema_utils.h" +#include "tensorflow/lite/stderr_reporter.h" namespace tflite { namespace { +// A BuiltinDataAllocator which just uses malloc()/free(). +class MallocDataAllocator : public BuiltinDataAllocator { + public: + void* Allocate(size_t size, size_t alignment_hint) override { + return malloc(size); + } + void Deallocate(void* data) override { free(data); } +}; + +// Get the number of dimensions of a tensor with idx of an operator op. +inline int GetNumDims(const SubGraph* subgraph, const Operator* op, int idx) { + const flatbuffers::Vector* ret = + subgraph->tensors()->Get(op->inputs()->Get(idx))->shape(); + if (ret) { + return ret->size(); + } else { + return 0; + } +} + +std::vector GetOpSignatureTensorSpecs( + const flatbuffers::Vector* tensors, const SubGraph* subgraph, + const Model* model) { + std::vector tensor_specs; + if (!tensors) { + return tensor_specs; + } + StderrReporter error_reporter; + + for (int32_t i = 0; i < tensors->Length(); ++i) { + int32_t tensor_no = tensors->Get(i); + + OpSignatureTensorSpec tensor_spec = {kTfLiteNoType}; + if (tensor_no >= 0) { + if (subgraph->tensors() && tensor_no < subgraph->tensors()->Length()) { + auto* fb_tensor = subgraph->tensors()->Get(tensor_no); + ConvertTensorType(fb_tensor->type(), &tensor_spec.type, + &error_reporter); + auto buffer_idx = fb_tensor->buffer(); + // Check if the tensor is a constant tensor. + if (buffer_idx != 0 && buffer_idx < model->buffers()->Length()) { + auto* buffer = model->buffers()->Get(buffer_idx); + if (buffer->data() && buffer->data()->size() != 0) { + tensor_spec.is_const = true; + } + } + const flatbuffers::Vector* shape_vec = fb_tensor->shape(); + if (shape_vec) { + for (int32_t j = 0; j < shape_vec->Length(); ++j) { + tensor_spec.dims.push_back(shape_vec->Get(j)); + } + } + const flatbuffers::Vector* shape_signature_vec = + fb_tensor->shape_signature(); + tensor_spec.is_shape_dynamic = false; + if (shape_signature_vec) { + for (int32_t j = 0; j < shape_signature_vec->Length(); ++j) { + if (shape_signature_vec->Get(j) == -1) { + tensor_spec.is_shape_dynamic = true; + break; + } + } + } + } + } + tensor_specs.push_back(tensor_spec); + } + return tensor_specs; +} + std::vector GetOpSignatureTensorSpecs( TfLiteIntArray* tensors, const TfLiteContext* context, const TfLiteNode* tflite_node) { @@ -60,6 +130,167 @@ std::vector GetOpSignatureTensorSpecs( } // namespace +OpSignature GetOpSignature(const OperatorCode* op_code, const Operator* op, + const SubGraph* subgraph, const Model* model) { + auto builtin_code = GetBuiltinCode(op_code); + OpSignature op_sig = {builtin_code}; + std::memset(&op_sig.ext_options, 0, sizeof(op_sig.ext_options)); + + if (builtin_code != BuiltinOperator_CUSTOM) { + StderrReporter error_reporter; + MallocDataAllocator allocator; + ParseOpData(op, builtin_code, &error_reporter, &allocator, + &op_sig.builtin_data); + } else { + op_sig.custom_name = op_code->custom_code()->str(); + } + + switch (builtin_code) { + case BuiltinOperator_DEPTHWISE_CONV_2D: { + const Tensor* filter_tensor = + subgraph->tensors()->Get(op->inputs()->Get(1)); + const QuantizationParameters* filter_quant = + filter_tensor->quantization(); + int num_channels = filter_tensor->shape()->Get(3); + if (filter_quant && filter_quant->scale() && + filter_quant->scale()->Length() && + filter_quant->scale()->Length() == num_channels) { + op_sig.ext_options.depthwise_conv_2d.is_per_channel_quantized = true; + } + } break; + + case BuiltinOperator_FULLY_CONNECTED: { + const Tensor* weight_tensor = + subgraph->tensors()->Get(op->inputs()->Get(1)); + op_sig.ext_options.fully_connected.sparse_weight = + (weight_tensor->sparsity() != nullptr); + const QuantizationParameters* weight_quant = + weight_tensor->quantization(); + if (weight_quant && weight_quant->scale() && + weight_quant->scale()->size() && weight_tensor->shape() && + weight_tensor->shape()->size()) { + op_sig.ext_options.fully_connected.is_per_channel_quantized = + weight_quant->scale()->size() > 1 && + weight_quant->scale()->size() == weight_tensor->shape()->Get(0); + } + } break; + + case BuiltinOperator_MUL: { + if (op->inputs()->Length() < 2 || op->outputs()->Length() < 1) { + break; + } + const Tensor* input1_tensor = + subgraph->tensors()->Get(op->inputs()->Get(0)); + const Tensor* input2_tensor = + subgraph->tensors()->Get(op->inputs()->Get(1)); + const Tensor* output_tensor = + subgraph->tensors()->Get(op->outputs()->Get(0)); + const QuantizationParameters* input1_quant = + input1_tensor->quantization(); + const QuantizationParameters* input2_qunt = input2_tensor->quantization(); + const QuantizationParameters* output_quant = + output_tensor->quantization(); + if (input1_quant && input1_quant->scale() && + input1_quant->scale()->Length() && input2_qunt && + input2_qunt->scale() && input2_qunt->scale()->Length() && + output_quant && output_quant->scale() && + output_quant->scale()->Length()) { + op_sig.ext_options.mul.input1_scale = input1_quant->scale()->Get(0); + op_sig.ext_options.mul.input2_scale = input2_qunt->scale()->Get(0); + op_sig.ext_options.mul.output_scale = output_quant->scale()->Get(0); + } + if (input1_quant || input2_qunt) { + op_sig.ext_options.mul.input_quantized = true; + } + } break; + + case BuiltinOperator_CONV_2D: { + const Tensor* input_tensor = + subgraph->tensors()->Get(op->inputs()->Get(0)); + const Tensor* filter_tensor = + subgraph->tensors()->Get(op->inputs()->Get(1)); + const QuantizationParameters* filter_quant = + filter_tensor->quantization(); + int num_filters = filter_tensor->shape()->Get(0); + if (filter_quant && filter_quant->scale() && + filter_quant->scale()->Length() && + filter_quant->scale()->Length() == num_filters) { + op_sig.ext_options.conv_2d.is_per_channel_quantized = true; + } + if (input_tensor->shape() && input_tensor->shape()->size()) { + int num_input_channels = input_tensor->shape()->Get(3); + int num_filter_input_channels = filter_tensor->shape()->Get(3); + op_sig.ext_options.conv_2d.is_grouped_convolution = + num_input_channels != num_filter_input_channels; + } else { + op_sig.ext_options.conv_2d.is_grouped_convolution = false; + } + } break; + + case BuiltinOperator_STRIDED_SLICE: { + op_sig.ext_options.strided_slice.num_dims = GetNumDims(subgraph, op, 0); + } break; + + case BuiltinOperator_ABS: { + if (subgraph->tensors()->Get(op->inputs()->Get(0))->quantization()) { + op_sig.ext_options.abs.input_quantized = true; + } + } break; + + case BuiltinOperator_DEQUANTIZE: { + const Tensor* input_tensor = + subgraph->tensors()->Get(op->inputs()->Get(0)); + const QuantizationParameters* input_quant = input_tensor->quantization(); + if (input_quant && input_quant->scale() && + input_quant->scale()->Length() > 1 && + input_quant->scale()->Length() == + input_tensor->shape()->Get(input_quant->quantized_dimension())) { + op_sig.ext_options.dequantize.is_per_channel_quantized = true; + } + } break; + + case BuiltinOperator_QUANTIZE: { + const Tensor* output_tensor = + subgraph->tensors()->Get(op->outputs()->Get(0)); + const QuantizationParameters* output_quant = + output_tensor->quantization(); + if (output_quant && output_quant->scale() && + output_quant->scale()->Length() > 1 && + output_quant->scale()->Length() == + output_tensor->shape()->Get( + output_quant->quantized_dimension())) { + op_sig.ext_options.quantize.is_per_channel_quantized = true; + } + } break; + + case BuiltinOperator_ADD: { + if (subgraph->tensors()->Get(op->inputs()->Get(0))->quantization()) { + op_sig.ext_options.add.input_quantized = true; + } + } break; + + case BuiltinOperator_EMBEDDING_LOOKUP: { + const Tensor* table_tensor = + subgraph->tensors()->Get(op->inputs()->Get(1)); + const QuantizationParameters* table_quant = table_tensor->quantization(); + if (table_quant && table_quant->scale() && table_quant->scale()->size() && + table_tensor->shape() && table_tensor->shape()->size()) { + op_sig.ext_options.embedding_lookup.is_per_channel_quantized = + table_quant->scale()->size() > 1 && + table_quant->scale()->size() == table_tensor->shape()->Get(0); + } + } break; + + default: + break; + } + + op_sig.inputs = GetOpSignatureTensorSpecs(op->inputs(), subgraph, model); + op_sig.outputs = GetOpSignatureTensorSpecs(op->outputs(), subgraph, model); + op_sig.version = op_code->version(); + return op_sig; +} + OpSignature GetOpSignature(const TfLiteContext* context, const TfLiteNode* node, const TfLiteRegistration* registration) { OpSignature op_sig = { diff --git a/tensorflow/lite/tools/versioning/op_signature.h b/tensorflow/lite/tools/versioning/op_signature.h index b2dd3086c2d0d4..6f83d119d5938f 100644 --- a/tensorflow/lite/tools/versioning/op_signature.h +++ b/tensorflow/lite/tools/versioning/op_signature.h @@ -15,11 +15,83 @@ limitations under the License. #ifndef TENSORFLOW_LITE_TOOLS_VERSIONING_OP_SIGNATURE_H_ #define TENSORFLOW_LITE_TOOLS_VERSIONING_OP_SIGNATURE_H_ -#include "tensorflow/compiler/mlir/lite/tools/versioning/op_signature.h" // iwyu pragma: export +#include +#include + +#include "tensorflow/lite/core/c/c_api_types.h" #include "tensorflow/lite/core/c/common.h" +#include "tensorflow/lite/schema/schema_generated.h" namespace tflite { +// OpSignature contains operator parameters for version functions. +typedef struct { + TfLiteType type; + std::vector dims; + bool is_const; + bool is_shape_dynamic; +} OpSignatureTensorSpec; + +typedef struct { + BuiltinOperator op; + std::vector inputs; + std::vector outputs; + void* builtin_data; + int version; + const void* custom_initial_data; + std::string custom_name; + union { + struct { + bool is_per_channel_quantized; + bool is_grouped_convolution; + } conv_2d; + struct { + bool is_per_channel_quantized; + } depthwise_conv_2d; + struct { + // TODO(b/156530611): Make this global when more ops support sparse + // computation. + bool sparse_weight; + bool is_per_channel_quantized; + } fully_connected; + struct { + float input1_scale; + float input2_scale; + float output_scale; + bool input_quantized; + } mul; + struct { + int32_t num_dims; + } strided_slice; + struct { + bool input_quantized; + } abs; + struct { + bool is_per_channel_quantized; + } dequantize; + struct { + bool is_per_channel_quantized; + } quantize; + struct { + bool input_quantized; + } add; + struct { + bool is_per_channel_quantized; + } embedding_lookup; + } ext_options; +} OpSignature; + +// Generate OpSignature with the given OperatorCode, Operator and Tensors (from +// SubGraph). The OpSignature will be used by GetBuiltinOperatorVersion() and +// mostly input and output tensor types are enough to figure out op version. +// But some ops (DEPTHWISE_CONV_2D, FULLY_CONNECTED, ...) require to pass their +// options to decide op version. +// +// WARNING: The caller is responsible to free the allocated +// OpSignature.builtin_data memory. +OpSignature GetOpSignature(const OperatorCode* op_code, const Operator* op, + const SubGraph* subgraph, const Model* model); + // Generate OpSignature with the given TfLiteContext, TfLiteNode and // TfLiteRegistration. // The function can be used by a compatibility checker of a delegate such as From 01d4268b1d0803993e80d969c779998598ed2ccb Mon Sep 17 00:00:00 2001 From: Terry Heo Date: Tue, 24 Sep 2024 18:34:00 -0700 Subject: [PATCH 217/483] Initialize TfLiteTensor sparsity in SimpleConstTensor PiperOrigin-RevId: 678473333 --- tensorflow/lite/testing/matchers.h | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorflow/lite/testing/matchers.h b/tensorflow/lite/testing/matchers.h index e32576c941cac2..17646ffb811eb4 100644 --- a/tensorflow/lite/testing/matchers.h +++ b/tensorflow/lite/testing/matchers.h @@ -257,6 +257,7 @@ struct SimpleConstTensor : public TfLiteTensor { std::memcpy(dims->data, shape.data(), shape.size() * sizeof(int)); data = {.data = buf.data()}; bytes = buf.size() * sizeof(T); + sparsity = nullptr; } ~SimpleConstTensor() { TfLiteIntArrayFree(dims); } }; From 99bad4d415350fd0489628a85552a09e73885ad9 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 24 Sep 2024 18:44:00 -0700 Subject: [PATCH 218/483] Load the builtin Bazel java rules from @rules_java PiperOrigin-RevId: 678475673 --- tensorflow/lite/delegates/flex/test/BUILD | 1 + tensorflow/lite/java/ovic/BUILD | 1 + .../lite/java/src/testhelper/java/org/tensorflow/lite/BUILD | 1 + 3 files changed, 3 insertions(+) diff --git a/tensorflow/lite/delegates/flex/test/BUILD b/tensorflow/lite/delegates/flex/test/BUILD index 65467a9a84f903..92f81d68892b69 100644 --- a/tensorflow/lite/delegates/flex/test/BUILD +++ b/tensorflow/lite/delegates/flex/test/BUILD @@ -1,5 +1,6 @@ load("@bazel_skylib//rules:build_test.bzl", "build_test") load("@build_bazel_rules_apple//apple:ios.bzl", "ios_static_framework") +load("@rules_java//java:defs.bzl", "java_library", "java_test") load("//tensorflow/java:build_defs.bzl", "JAVACOPTS") load("//tensorflow/lite/delegates/flex:build_def.bzl", "tflite_flex_jni_library") diff --git a/tensorflow/lite/java/ovic/BUILD b/tensorflow/lite/java/ovic/BUILD index e36c77b7369780..a6ce1d4a07aeea 100644 --- a/tensorflow/lite/java/ovic/BUILD +++ b/tensorflow/lite/java/ovic/BUILD @@ -2,6 +2,7 @@ # OVIC Benchmarker Java API. load("@build_bazel_rules_android//android:rules.bzl", "android_library") +load("@rules_java//java:defs.bzl", "java_binary", "java_library", "java_test") load("//tensorflow/java:build_defs.bzl", "JAVACOPTS") package( diff --git a/tensorflow/lite/java/src/testhelper/java/org/tensorflow/lite/BUILD b/tensorflow/lite/java/src/testhelper/java/org/tensorflow/lite/BUILD index 365942d6490601..894858dbd4e022 100644 --- a/tensorflow/lite/java/src/testhelper/java/org/tensorflow/lite/BUILD +++ b/tensorflow/lite/java/src/testhelper/java/org/tensorflow/lite/BUILD @@ -2,6 +2,7 @@ # Internal helper function to test TF Lite API. load("@build_bazel_rules_android//android:rules.bzl", "android_library") +load("@rules_java//java:defs.bzl", "java_library") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:LICENSE"], From bcde12724b6e3b42ee6f57d043964221925001e5 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 24 Sep 2024 21:40:13 -0700 Subject: [PATCH 219/483] Automated Code Change PiperOrigin-RevId: 678521406 --- .../tests/fuse_binary_into_following_affine_test.cc | 1 - .../tests/fuse_binary_into_preceding_affine_test.cc | 1 - .../tests/identify_l2_normalization_test.cc | 2 -- .../toco/graph_transformations/tests/identify_l2_pool_test.cc | 2 -- .../lite/toco/graph_transformations/tests/lstm_utils_test.cc | 1 - .../tests/remove_successive_transpose_test.cc | 2 -- .../tests/resolve_constant_concatenation_test.cc | 1 - .../graph_transformations/tests/resolve_constant_unary_test.cc | 2 -- .../toco/graph_transformations/tests/unpack_quantize_test.cc | 1 - 9 files changed, 13 deletions(-) diff --git a/tensorflow/lite/toco/graph_transformations/tests/fuse_binary_into_following_affine_test.cc b/tensorflow/lite/toco/graph_transformations/tests/fuse_binary_into_following_affine_test.cc index 518a6832066a3c..a66ad270a3d347 100644 --- a/tensorflow/lite/toco/graph_transformations/tests/fuse_binary_into_following_affine_test.cc +++ b/tensorflow/lite/toco/graph_transformations/tests/fuse_binary_into_following_affine_test.cc @@ -20,7 +20,6 @@ limitations under the License. #include #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" #include "tensorflow/lite/toco/model.h" -#include "tensorflow/lite/toco/tooling_util.h" namespace toco { diff --git a/tensorflow/lite/toco/graph_transformations/tests/fuse_binary_into_preceding_affine_test.cc b/tensorflow/lite/toco/graph_transformations/tests/fuse_binary_into_preceding_affine_test.cc index 314521b6ab2711..35888667d4b3c9 100644 --- a/tensorflow/lite/toco/graph_transformations/tests/fuse_binary_into_preceding_affine_test.cc +++ b/tensorflow/lite/toco/graph_transformations/tests/fuse_binary_into_preceding_affine_test.cc @@ -20,7 +20,6 @@ limitations under the License. #include #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" #include "tensorflow/lite/toco/model.h" -#include "tensorflow/lite/toco/tooling_util.h" namespace toco { diff --git a/tensorflow/lite/toco/graph_transformations/tests/identify_l2_normalization_test.cc b/tensorflow/lite/toco/graph_transformations/tests/identify_l2_normalization_test.cc index 4c55b7d6dcbb06..c21118f4df7e2e 100644 --- a/tensorflow/lite/toco/graph_transformations/tests/identify_l2_normalization_test.cc +++ b/tensorflow/lite/toco/graph_transformations/tests/identify_l2_normalization_test.cc @@ -16,10 +16,8 @@ limitations under the License. #include #include -#include "absl/memory/memory.h" #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" #include "tensorflow/lite/toco/model.h" -#include "tensorflow/lite/toco/tooling_util.h" namespace toco { diff --git a/tensorflow/lite/toco/graph_transformations/tests/identify_l2_pool_test.cc b/tensorflow/lite/toco/graph_transformations/tests/identify_l2_pool_test.cc index bfb2acf3aa6fe8..ab487b4cf3bb28 100644 --- a/tensorflow/lite/toco/graph_transformations/tests/identify_l2_pool_test.cc +++ b/tensorflow/lite/toco/graph_transformations/tests/identify_l2_pool_test.cc @@ -16,10 +16,8 @@ limitations under the License. #include #include -#include "absl/memory/memory.h" #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" #include "tensorflow/lite/toco/model.h" -#include "tensorflow/lite/toco/tooling_util.h" namespace toco { diff --git a/tensorflow/lite/toco/graph_transformations/tests/lstm_utils_test.cc b/tensorflow/lite/toco/graph_transformations/tests/lstm_utils_test.cc index e18b2a8a486423..ae9006af978237 100644 --- a/tensorflow/lite/toco/graph_transformations/tests/lstm_utils_test.cc +++ b/tensorflow/lite/toco/graph_transformations/tests/lstm_utils_test.cc @@ -21,7 +21,6 @@ limitations under the License. #include #include #include "tensorflow/lite/toco/model.h" -#include "tensorflow/lite/toco/tooling_util.h" namespace toco { diff --git a/tensorflow/lite/toco/graph_transformations/tests/remove_successive_transpose_test.cc b/tensorflow/lite/toco/graph_transformations/tests/remove_successive_transpose_test.cc index 746a7579d41a57..561ca830fcb34b 100644 --- a/tensorflow/lite/toco/graph_transformations/tests/remove_successive_transpose_test.cc +++ b/tensorflow/lite/toco/graph_transformations/tests/remove_successive_transpose_test.cc @@ -16,11 +16,9 @@ limitations under the License. #include #include -#include #include #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" #include "tensorflow/lite/toco/model.h" -#include "tensorflow/lite/toco/tooling_util.h" namespace { diff --git a/tensorflow/lite/toco/graph_transformations/tests/resolve_constant_concatenation_test.cc b/tensorflow/lite/toco/graph_transformations/tests/resolve_constant_concatenation_test.cc index 4ecd9c992bb058..405c79b8d52c40 100644 --- a/tensorflow/lite/toco/graph_transformations/tests/resolve_constant_concatenation_test.cc +++ b/tensorflow/lite/toco/graph_transformations/tests/resolve_constant_concatenation_test.cc @@ -20,7 +20,6 @@ limitations under the License. #include #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" #include "tensorflow/lite/toco/model.h" -#include "tensorflow/lite/toco/tooling_util.h" namespace toco { diff --git a/tensorflow/lite/toco/graph_transformations/tests/resolve_constant_unary_test.cc b/tensorflow/lite/toco/graph_transformations/tests/resolve_constant_unary_test.cc index 136c14cad9b834..af26eef7ff6922 100644 --- a/tensorflow/lite/toco/graph_transformations/tests/resolve_constant_unary_test.cc +++ b/tensorflow/lite/toco/graph_transformations/tests/resolve_constant_unary_test.cc @@ -19,10 +19,8 @@ limitations under the License. #include #include -#include "absl/memory/memory.h" #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" #include "tensorflow/lite/toco/model.h" -#include "tensorflow/lite/toco/tooling_util.h" namespace toco { diff --git a/tensorflow/lite/toco/graph_transformations/tests/unpack_quantize_test.cc b/tensorflow/lite/toco/graph_transformations/tests/unpack_quantize_test.cc index b7302051043052..3a22849b949955 100755 --- a/tensorflow/lite/toco/graph_transformations/tests/unpack_quantize_test.cc +++ b/tensorflow/lite/toco/graph_transformations/tests/unpack_quantize_test.cc @@ -20,7 +20,6 @@ limitations under the License. #include #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" #include "tensorflow/lite/toco/model.h" -#include "tensorflow/lite/toco/tooling_util.h" namespace toco { From 5329ec8dd396487982ef3e743f98c0195af39a6b Mon Sep 17 00:00:00 2001 From: David Dunleavy Date: Tue, 24 Sep 2024 22:42:50 -0700 Subject: [PATCH 220/483] Fix `tsl/platform/cloud:curl_http_request_test` after breakage PiperOrigin-RevId: 678538860 --- .../xla/third_party/tsl/tsl/platform/cloud/BUILD | 2 ++ .../tsl/tsl/platform/cloud/curl_http_request_test.cc | 10 +++++++--- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/BUILD b/third_party/xla/third_party/tsl/tsl/platform/cloud/BUILD index 46bca19c70940d..42d9a7985119cd 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/cloud/BUILD +++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/BUILD @@ -437,12 +437,14 @@ tsl_cc_test( srcs = ["curl_http_request_test.cc"], deps = [ ":curl_http_request", + "//tsl/platform", "//tsl/platform:env_impl", "//tsl/platform:path", "//tsl/platform:platform_port", "//tsl/platform:test", "//tsl/platform:test_main", "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", "@local_xla//xla/tsl/lib/core:status_test_util", ], ) diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/curl_http_request_test.cc b/third_party/xla/third_party/tsl/tsl/platform/cloud/curl_http_request_test.cc index 9cc1d8e075e0b8..429006a3724bdc 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/cloud/curl_http_request_test.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/curl_http_request_test.cc @@ -19,9 +19,11 @@ limitations under the License. #include #include "absl/status/status.h" +#include "absl/strings/str_cat.h" #include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/mem.h" #include "tsl/platform/path.h" +#include "tsl/platform/platform.h" #include "tsl/platform/test.h" namespace tsl { @@ -495,9 +497,11 @@ TEST(CurlHttpRequestTest, GetRequest_CouldntResolveHost) { const auto& status = http_request.Send(); EXPECT_EQ(error::FAILED_PRECONDITION, status.code()); EXPECT_EQ( - "Error executing an HTTP request: libcurl code 6 meaning " - "'Could not resolve hostname', error details: Could not resolve host " - "'metadata'", + absl::StrCat( + "Error executing an HTTP request: libcurl code 6 meaning ", + (kIsOpenSource ? "'Couldn't resolve host name', error details: " + : "'Could not resolve hostname', error details: "), + "Could not resolve host ", "'metadata'"), status.message()); EXPECT_EQ(0, http_request.GetResponseCode()); } From f2ec1f22401086d96a6ce8a401bd5f4ed883addb Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 24 Sep 2024 22:52:14 -0700 Subject: [PATCH 221/483] Automated Code Change PiperOrigin-RevId: 678541202 --- tensorflow/python/BUILD | 4 ++++ tensorflow/python/mlir_wrapper.cc | 1 + 2 files changed, 5 insertions(+) diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 0ad8e84b889f18..4331c4a77d8a21 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -1097,8 +1097,12 @@ tf_python_pybind_extension( "_pywrap_mlir.pyi", ], deps = [ + "//tensorflow/compiler/tf2tensorrt:common_utils", + "//tensorflow/compiler/tf2tensorrt:trt_parameters", + "//tensorflow/core:framework", "//tensorflow/core:protos_all_cc", "//tensorflow/core/platform:status", + "//tensorflow/core/platform:statusor", "//tensorflow/python/lib/core:pybind11_lib", "//tensorflow/python/lib/core:pybind11_status", "//tensorflow/python/lib/core:safe_pyobject_ptr", diff --git a/tensorflow/python/mlir_wrapper.cc b/tensorflow/python/mlir_wrapper.cc index 8fe71b77ef11c9..4846a089b35ddf 100644 --- a/tensorflow/python/mlir_wrapper.cc +++ b/tensorflow/python/mlir_wrapper.cc @@ -16,6 +16,7 @@ limitations under the License. #include "pybind11/pybind11.h" // from @pybind11 #include "pybind11/pytypes.h" // from @pybind11 #include "pybind11/stl.h" // from @pybind11 +#include "tensorflow/c/eager/c_api.h" #include "tensorflow/c/safe_ptr.h" #include "tensorflow/c/tf_status.h" #include "tensorflow/compiler/mlir/python/mlir.h" From 84bbb81cb98159bdeb3895628b6e063e2bf38f9d Mon Sep 17 00:00:00 2001 From: Hyeontaek Lim Date: Tue, 24 Sep 2024 22:59:59 -0700 Subject: [PATCH 222/483] [IFRT] Add DeviceList::AddressableDeviceList() This change adds `DeviceList::AddressableDeviceList()`, which returns a `DeviceList` that contains only addressable devices from the original `DeviceList`. It returns itself if the original `DeviceList` contains only addressable devices. `BasicDeviceList` gets a tsan-friendly implementation of `AddressableDeviceList()` that lazily computes it. It also removes internal state sharing between `BasicDeviceList` objects, as `BasicDeviceList` is no longer copyable since `DeviceList` is wrapped with `tsl::RCReference<>`. PiperOrigin-RevId: 678543054 --- third_party/xla/xla/python/ifrt/BUILD | 7 +- .../xla/xla/python/ifrt/device_list.cc | 37 +++++++--- third_party/xla/xla/python/ifrt/device_list.h | 69 +++++++------------ .../{device_test.cc => device_list_test.cc} | 66 ++++++++++++++---- 4 files changed, 109 insertions(+), 70 deletions(-) rename third_party/xla/xla/python/ifrt/{device_test.cc => device_list_test.cc} (58%) diff --git a/third_party/xla/xla/python/ifrt/BUILD b/third_party/xla/xla/python/ifrt/BUILD index 9380b8c811db8e..6665d25c6cc013 100644 --- a/third_party/xla/xla/python/ifrt/BUILD +++ b/third_party/xla/xla/python/ifrt/BUILD @@ -107,6 +107,7 @@ cc_library( "//xla/tsl/concurrency:ref_count", "//xla/tsl/lib/gtl:int_type", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/container:node_hash_set", @@ -517,15 +518,15 @@ tf_proto_library( ) xla_cc_test( - name = "device_test", + name = "device_list_test", size = "small", - srcs = ["device_test.cc"], + srcs = ["device_list_test.cc"], deps = [ ":device_proto_cc", ":device_test_util", ":ifrt", "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest_main", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:platform_port", diff --git a/third_party/xla/xla/python/ifrt/device_list.cc b/third_party/xla/xla/python/ifrt/device_list.cc index fdd58588cf6516..35b37b5ec1a1dd 100644 --- a/third_party/xla/xla/python/ifrt/device_list.cc +++ b/third_party/xla/xla/python/ifrt/device_list.cc @@ -17,16 +17,17 @@ limitations under the License. #include #include -#include #include #include #include +#include "absl/base/call_once.h" #include "absl/base/optimization.h" #include "absl/hash/hash.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" +#include "absl/types/span.h" #include "xla/python/ifrt/device.h" #include "xla/python/ifrt/device.pb.h" #include "xla/tsl/concurrency/ref_count.h" @@ -64,12 +65,32 @@ tsl::RCReference BasicDeviceList::Create(Devices devices) { return tsl::MakeRef(std::move(devices)); } -BasicDeviceList::BasicDeviceList(Devices devices) : hash_(kUnsetHash) { - if (devices.size() <= kInlineDeviceSize) { - state_ = State{std::move(devices)}; - } else { - state_ = std::make_shared(State{std::move(devices)}); - } +BasicDeviceList::BasicDeviceList(Devices devices) + : devices_(std::move(devices)), hash_(kUnsetHash) {} + +DeviceList* BasicDeviceList::AddressableDeviceList() const { + absl::call_once(addressable_device_list_cache_.once_flag, [this] { + Devices addressable_devices; + for (Device* device : devices_) { + if (device->IsAddressable()) { + addressable_devices.push_back(device); + } + } + const bool already_fully_addressable = + addressable_devices.size() == devices_.size(); + if (already_fully_addressable) { + // `device_list_holder` is intentionally unset. We skip storing a + // reference-counted copy in the holder to avoid creating a self cycle. + addressable_device_list_cache_.device_list = + const_cast(this); + } else { + addressable_device_list_cache_.device_list_holder = + BasicDeviceList::Create(std::move(addressable_devices)); + addressable_device_list_cache_.device_list = + addressable_device_list_cache_.device_list_holder.get(); + } + }); + return addressable_device_list_cache_.device_list; } uint64_t BasicDeviceList::hash() const { @@ -86,7 +107,7 @@ uint64_t BasicDeviceList::hash() const { std::string BasicDeviceList::ToString() const { return absl::StrCat("BasicDeviceList([", - absl::StrJoin(state().devices, ",", + absl::StrJoin(devices_, ",", [](std::string* out, Device* device) { absl::StrAppend(out, device->DebugString()); diff --git a/third_party/xla/xla/python/ifrt/device_list.h b/third_party/xla/xla/python/ifrt/device_list.h index b2ddbe221abe3f..f4dc7eb1398129 100644 --- a/third_party/xla/xla/python/ifrt/device_list.h +++ b/third_party/xla/xla/python/ifrt/device_list.h @@ -18,12 +18,10 @@ limitations under the License. #include #include -#include #include -#include -#include #include +#include "absl/base/call_once.h" #include "absl/container/inlined_vector.h" #include "absl/functional/function_ref.h" #include "absl/log/check.h" @@ -34,7 +32,6 @@ limitations under the License. #include "xla/python/ifrt/device.h" #include "xla/python/ifrt/device.pb.h" #include "xla/tsl/concurrency/ref_count.h" -#include "xla/tsl/lib/gtl/int_type.h" namespace xla { namespace ifrt { @@ -73,6 +70,13 @@ class DeviceList : public tsl::ReferenceCounted, // Returns a list of `Devices*` represented by this `DeviceList`. virtual absl::Span devices() const = 0; + // Returns a `DeviceList*` containing only addressable devices from this + // `DeviceList`. It returns itself if all devices are addressable. It points + // to a heap-allocated object; the pointer is valid at least until this + // `DeviceList` is destroyed, and it can be persisted beyond this + // `DeviceList`'s lifetime by using `tsl::FormRef()`. + virtual DeviceList* AddressableDeviceList() const = 0; + virtual bool operator==(const DeviceList& other) const = 0; bool operator!=(const DeviceList& other) const { return !(*this == other); } @@ -125,22 +129,20 @@ class BasicDeviceList : public llvm::RTTIExtends { // Returns a `DeviceListProto` representation. DeviceListProto ToProto() const; - absl::Span devices() const override { return state().devices; } + absl::Span devices() const override { return devices_; } + + DeviceList* AddressableDeviceList() const override; bool operator==(const DeviceList& other) const override { + if (this == &other) { + return true; + } const auto* other_basic_device_list = llvm::dyn_cast(&other); if (other_basic_device_list == nullptr) { return false; } - const std::shared_ptr* lhs = - std::get_if>(&state_); - const std::shared_ptr* rhs = - std::get_if>(&other_basic_device_list->state_); - if (lhs != nullptr && rhs != nullptr && lhs->get() == rhs->get()) { - return true; - } - return devices() == other.devices(); + return devices_ == other_basic_device_list->devices_; } uint64_t hash() const override; @@ -153,40 +155,17 @@ class BasicDeviceList : public llvm::RTTIExtends { template friend tsl::RCReference tsl::MakeRef(Args&&... args); - // Internal state that may be shared across `DeviceList` instances. - struct State { - Devices devices; - }; - - State& state() { - return std::visit( - [](auto& state) -> State& { - using T = std::decay_t; - if constexpr (std::is_same_v) { - return state; - } else if constexpr (std::is_same_v>) { - return *state; - } - }, - state_); - } - - const State& state() const { - return std::visit( - [](auto& state) -> const State& { - using T = std::decay_t; - if constexpr (std::is_same_v) { - return state; - } else if constexpr (std::is_same_v>) { - return *state; - } - }, - state_); - } - std::string ToString() const override; - std::variant> state_; + Devices devices_; + + // Addressable device list is dynamically computed and cached. + struct AddressableDeviceListCache { + absl::once_flag once_flag; + DeviceList* device_list = nullptr; + tsl::RCReference device_list_holder; + }; + mutable AddressableDeviceListCache addressable_device_list_cache_; // Cached hash. 0 indicates the hash needs to be computed and cached. // May be written multiple times with the same non-zero value. diff --git a/third_party/xla/xla/python/ifrt/device_test.cc b/third_party/xla/xla/python/ifrt/device_list_test.cc similarity index 58% rename from third_party/xla/xla/python/ifrt/device_test.cc rename to third_party/xla/xla/python/ifrt/device_list_test.cc index 713b9ca3ce5ec1..961015c56a26c4 100644 --- a/third_party/xla/xla/python/ifrt/device_test.cc +++ b/third_party/xla/xla/python/ifrt/device_list_test.cc @@ -13,18 +13,20 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/device_list.h" #include #include +#include #include #include +#include #include #include "absl/status/statusor.h" -#include "absl/synchronization/blocking_counter.h" +#include "absl/types/span.h" +#include "xla/python/ifrt/device.h" #include "xla/python/ifrt/device.pb.h" -#include "xla/python/ifrt/device_list.h" #include "xla/python/ifrt/device_test_util.h" #include "tsl/platform/cpu_info.h" #include "tsl/platform/env.h" @@ -35,6 +37,8 @@ namespace xla { namespace ifrt { namespace { +using ::testing::ElementsAreArray; + class DeviceListTest : public test_util::DeviceTest {}; TEST_P(DeviceListTest, ToFromProto) { @@ -48,23 +52,55 @@ TEST_P(DeviceListTest, ToFromProto) { EXPECT_EQ(*device_list_copy, *device_list); } +TEST_P(DeviceListTest, AddressableDevices) { + auto device_list = GetDevices({0, 1}); + std::vector addressable_devices; + for (Device* device : device_list->devices()) { + if (device->IsAddressable()) { + addressable_devices.push_back(device); + } + } + EXPECT_THAT(device_list->AddressableDeviceList()->devices(), + ElementsAreArray(addressable_devices)); +} + +TEST_P(DeviceListTest, AddressableDevicesFromConcurrentCalls) { + auto device_list = GetDevices({0, 1}); + + const int num_threads = 16; + auto thread_pool = std::make_unique( + tsl::Env::Default(), tsl::ThreadOptions(), "test_pool", + std::min(num_threads, tsl::port::MaxParallelism())); + std::vector addressable_device_lists(num_threads); + for (int i = 0; i < num_threads; ++i) { + thread_pool->Schedule([&, i]() { + addressable_device_lists[i] = device_list->AddressableDeviceList(); + // Touch a device in the list so that tsan can verify access to the + // content of the addressable device list. + addressable_device_lists[i]->devices().front()->Id(); + }); + } + + thread_pool.reset(); + for (int i = 0; i < num_threads; ++i) { + EXPECT_EQ(*addressable_device_lists[i], + *device_list->AddressableDeviceList()); + } +} + TEST_P(DeviceListTest, IdenticalHashFromConcurrentCalls) { auto device_list = GetDevices({0, 1}); const int num_threads = 16; - absl::BlockingCounter counter(num_threads); - tsl::thread::ThreadPool thread_pool( + auto thread_pool = std::make_unique( tsl::Env::Default(), tsl::ThreadOptions(), "test_pool", std::min(num_threads, tsl::port::MaxParallelism())); std::vector hashes(num_threads); for (int i = 0; i < num_threads; ++i) { - thread_pool.Schedule([&, i]() { - hashes[i] = device_list->hash(); - counter.DecrementCount(); - }); + thread_pool->Schedule([&, i]() { hashes[i] = device_list->hash(); }); } - counter.Wait(); + thread_pool.reset(); for (int i = 0; i < num_threads; ++i) { EXPECT_EQ(hashes[i], device_list->hash()); } @@ -89,10 +125,12 @@ TEST_P(DeviceListTest, EqualityTest) { EXPECT_NE(*device_list1, *device_list6); } -INSTANTIATE_TEST_SUITE_P(NumDevices, DeviceListTest, - testing::Values(test_util::DeviceTestParam{ - /*num_devices=*/2, - /*num_addressable_devices=*/2})); +INSTANTIATE_TEST_SUITE_P( + NumDevices, DeviceListTest, + testing::Values(test_util::DeviceTestParam{/*num_devices=*/2, + /*num_addressable_devices=*/1}, + test_util::DeviceTestParam{/*num_devices=*/2, + /*num_addressable_devices=*/2})); } // namespace } // namespace ifrt From 4d5cd2a42eb259c7559741d05a14a12e363f0e37 Mon Sep 17 00:00:00 2001 From: Henning Becker Date: Tue, 24 Sep 2024 23:19:56 -0700 Subject: [PATCH 223/483] Introduce rocm-only tag and remove if_rocm_is_configured This replaces all `if_rocm_is_configured` guards in `stream_executor/rocm/...` with a filtering tag `rocm-only`. The CUDA build on the CI gets adjusted to skip those targets. This uncovered some additional problems that get fixed as well: - A wrong library name for the hipfft library in the Bazel CUDA configuration - A wrong test case in the RocmVersionParser test that so far has not been running anywhere. - Missing tags for the platform alias targets in `stream_executor/BUILD` PiperOrigin-RevId: 678548961 --- third_party/gpus/rocm_configure.bzl | 2 +- .../workflows/bazel_dependency_violations.yml | 2 +- third_party/xla/build_tools/ci/build.py | 4 +- .../xla/build_tools/ci/golden_commands.txt | 2 +- .../xla/build_tools/dependencies/aspects.bzl | 8 + third_party/xla/build_tools/lint/tags.py | 1 + .../tsl/third_party/gpus/rocm_configure.bzl | 2 +- .../xla/xla/backends/profiler/gpu/BUILD | 28 +- third_party/xla/xla/stream_executor/BUILD | 11 + .../xla/xla/stream_executor/rocm/BUILD | 499 ++++++++++++------ .../rocm/rocm_version_parser_test.cc | 2 +- 11 files changed, 374 insertions(+), 187 deletions(-) diff --git a/third_party/gpus/rocm_configure.bzl b/third_party/gpus/rocm_configure.bzl index fb63d4db886c1c..03d350cb6f5d87 100644 --- a/third_party/gpus/rocm_configure.bzl +++ b/third_party/gpus/rocm_configure.bzl @@ -476,7 +476,7 @@ def _create_dummy_repository(repository_ctx): "%{hipblas_lib}": _lib_name("hipblas"), "%{miopen_lib}": _lib_name("miopen"), "%{rccl_lib}": _lib_name("rccl"), - "%{hipfft_or_rocfft}": _lib_name("hipfft"), + "%{hipfft_or_rocfft}": "hipfft", "%{hipfft_or_rocfft_lib}": _lib_name("hipfft"), "%{hiprand_lib}": _lib_name("hiprand"), "%{hipsparse_lib}": _lib_name("hipsparse"), diff --git a/third_party/xla/.github/workflows/bazel_dependency_violations.yml b/third_party/xla/.github/workflows/bazel_dependency_violations.yml index 43c85576c2ba38..a0187f232b3b96 100644 --- a/third_party/xla/.github/workflows/bazel_dependency_violations.yml +++ b/third_party/xla/.github/workflows/bazel_dependency_violations.yml @@ -29,7 +29,7 @@ jobs: dependency-violations: strategy: matrix: - tag: [gpu, cuda-only] + tag: [gpu, cuda-only, rocm-only] name: no-${{ matrix.tag }}-targets-in-cpu-build runs-on: ubuntu-22.04 defaults: diff --git a/third_party/xla/build_tools/ci/build.py b/third_party/xla/build_tools/ci/build.py index 7ce80be519c240..7741cab5609096 100755 --- a/third_party/xla/build_tools/ci/build.py +++ b/third_party/xla/build_tools/ci/build.py @@ -213,9 +213,9 @@ def nvidia_gpu_build_with_compute_capability( image_url=_DEFAULT_IMAGE, target_patterns=_XLA_DEFAULT_TARGET_PATTERNS, configs=configs, - test_tag_filters=("-no_oss", "requires-gpu-nvidia", "gpu") + test_tag_filters=("-no_oss", "requires-gpu-nvidia", "gpu", "-rocm-only") + extra_gpu_tags, - build_tag_filters=("-no_oss", "requires-gpu-nvidia", "gpu"), + build_tag_filters=("-no_oss", "requires-gpu-nvidia", "gpu", "-rocm-only"), options={ "run_under": "//tools/ci_build/gpu_build:parallel_gpu_execute", "repo_env": f"TF_CUDA_COMPUTE_CAPABILITIES={compute_capability/10}", diff --git a/third_party/xla/build_tools/ci/golden_commands.txt b/third_party/xla/build_tools/ci/golden_commands.txt index 17cfb64ff950a5..16f398e7e3b261 100644 --- a/third_party/xla/build_tools/ci/golden_commands.txt +++ b/third_party/xla/build_tools/ci/golden_commands.txt @@ -19,7 +19,7 @@ $KOKORO_ARTIFACTS_DIR/github/xla/.kokoro/generate_index_html.sh index.html nvidia-smi parallel --ungroup --retries 3 --delay 15 docker pull ::: gcr.io/tensorflow-sigs/build:latest-python3.11 docker run --detach --name=xla_ci --rm --interactive --tty --volume=./github:/github --workdir=/github/xla gcr.io/tensorflow-sigs/build:latest-python3.11 bash -docker exec xla_ci bazel test --build_tag_filters=-no_oss,requires-gpu-nvidia,gpu --test_tag_filters=-no_oss,requires-gpu-nvidia,gpu,requires-gpu-sm75-only,requires-gpu-sm60,requires-gpu-sm70,-requires-gpu-sm80,-requires-gpu-sm80-only,-requires-gpu-sm90,-requires-gpu-sm90-only,-requires-gpu-amd --config=warnings --config=rbe_linux_cuda_nvcc --run_under=//tools/ci_build/gpu_build:parallel_gpu_execute --repo_env=TF_CUDA_COMPUTE_CAPABILITIES=7.5 --@cuda_driver//:enable_forward_compatibility=true --test_output=errors --verbose_failures --keep_going --nobuild_tests_only --profile=profile.json.gz --flaky_test_attempts=3 --jobs=150 --bes_upload_mode=fully_async -- //xla/... //build_tools/... @local_tsl//tsl/... +docker exec xla_ci bazel test --build_tag_filters=-no_oss,requires-gpu-nvidia,gpu,-rocm-only --test_tag_filters=-no_oss,requires-gpu-nvidia,gpu,-rocm-only,requires-gpu-sm75-only,requires-gpu-sm60,requires-gpu-sm70,-requires-gpu-sm80,-requires-gpu-sm80-only,-requires-gpu-sm90,-requires-gpu-sm90-only,-requires-gpu-amd --config=warnings --config=rbe_linux_cuda_nvcc --run_under=//tools/ci_build/gpu_build:parallel_gpu_execute --repo_env=TF_CUDA_COMPUTE_CAPABILITIES=7.5 --@cuda_driver//:enable_forward_compatibility=true --test_output=errors --verbose_failures --keep_going --nobuild_tests_only --profile=profile.json.gz --flaky_test_attempts=3 --jobs=150 --bes_upload_mode=fully_async -- //xla/... //build_tools/... @local_tsl//tsl/... docker exec xla_ci bazel analyze-profile profile.json.gz docker stop xla_ci # END BuildType.GPU diff --git a/third_party/xla/build_tools/dependencies/aspects.bzl b/third_party/xla/build_tools/dependencies/aspects.bzl index 40ea09beca11da..c5e51da52aae7b 100644 --- a/third_party/xla/build_tools/dependencies/aspects.bzl +++ b/third_party/xla/build_tools/dependencies/aspects.bzl @@ -80,3 +80,11 @@ validate_cuda_only_tag = aspect( implementation = _cuda_only_tag_violation_aspect_impl, attr_aspects = ["deps"], ) + +def _rocm_only_tag_violation_aspect_impl(target, ctx): + return _dependency_violation_aspect_impl(target, ctx, "rocm-only") + +validate_rocm_only_tag = aspect( + implementation = _rocm_only_tag_violation_aspect_impl, + attr_aspects = ["deps"], +) diff --git a/third_party/xla/build_tools/lint/tags.py b/third_party/xla/build_tools/lint/tags.py index aa555e7ddf63e1..02331b1bbe41ac 100644 --- a/third_party/xla/build_tools/lint/tags.py +++ b/third_party/xla/build_tools/lint/tags.py @@ -65,6 +65,7 @@ "gpu": "Catch-all tag for targets that should be built/tested on GPU CI", "cpu": "Catch-all tag for targets that should be built/tested on CPU CI.", "cuda-only": "Targets that require the CUDA backend to be enabled.", + "rocm-only": "Targets that require the ROCm backend to be enabled.", # Below tags are generated by `xla_test`. "broken": "Test will be marked with other tags to disable in `xla_test`.", "xla_interpreter": "Uses interpreter backend.", diff --git a/third_party/xla/third_party/tsl/third_party/gpus/rocm_configure.bzl b/third_party/xla/third_party/tsl/third_party/gpus/rocm_configure.bzl index fb63d4db886c1c..03d350cb6f5d87 100644 --- a/third_party/xla/third_party/tsl/third_party/gpus/rocm_configure.bzl +++ b/third_party/xla/third_party/tsl/third_party/gpus/rocm_configure.bzl @@ -476,7 +476,7 @@ def _create_dummy_repository(repository_ctx): "%{hipblas_lib}": _lib_name("hipblas"), "%{miopen_lib}": _lib_name("miopen"), "%{rccl_lib}": _lib_name("rccl"), - "%{hipfft_or_rocfft}": _lib_name("hipfft"), + "%{hipfft_or_rocfft}": "hipfft", "%{hipfft_or_rocfft_lib}": _lib_name("hipfft"), "%{hiprand_lib}": _lib_name("hiprand"), "%{hipsparse_lib}": _lib_name("hipsparse"), diff --git a/third_party/xla/xla/backends/profiler/gpu/BUILD b/third_party/xla/xla/backends/profiler/gpu/BUILD index c3efcc2bd5a784..36275f118c52d8 100644 --- a/third_party/xla/xla/backends/profiler/gpu/BUILD +++ b/third_party/xla/xla/backends/profiler/gpu/BUILD @@ -26,14 +26,6 @@ tsl_gpu_library( name = "device_tracer", srcs = tf_additional_device_tracer_srcs(), copts = tf_profiler_copts() + tsl_copts(), - cuda_deps = [ - ":cupti_buffer_events", - ":cupti_collector", - ":cupti_tracer", - ":cupti_wrapper", - ":rocm_collector", - ":rocm_tracer", - ], deps = [ ":cupti_utils", "//xla/tsl/util:env_var", @@ -47,7 +39,17 @@ tsl_gpu_library( "@local_tsl//tsl/profiler/lib:profiler_interface", "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", "@local_tsl//tsl/profiler/utils:time_utils", - ], + ] + if_cuda([ + # keep sorted + ":cupti_buffer_events", + ":cupti_collector", + ":cupti_tracer", + ":cupti_wrapper", + ]) + if_rocm([ + # keep sorted + ":rocm_collector", + ":rocm_tracer", + ]), alwayslink = 1, ) @@ -218,6 +220,10 @@ tsl_gpu_library( srcs = if_rocm(["rocm_collector.cc"]), hdrs = if_rocm(["rocm_collector.h"]), copts = tf_profiler_copts() + tsl_copts(), + tags = [ + "gpu", + "rocm-only", + ], visibility = ["//visibility:public"], deps = [ "//xla/stream_executor/rocm:roctracer_wrapper", @@ -253,6 +259,10 @@ tsl_gpu_library( srcs = if_rocm(["rocm_tracer.cc"]), hdrs = if_rocm(["rocm_tracer.h"]), copts = tf_profiler_copts() + tsl_copts(), + tags = [ + "gpu", + "rocm-only", + ], visibility = ["//visibility:public"], deps = [ ":rocm_collector", diff --git a/third_party/xla/xla/stream_executor/BUILD b/third_party/xla/xla/stream_executor/BUILD index 4bc6ee63aef070..295abcb3cd7740 100644 --- a/third_party/xla/xla/stream_executor/BUILD +++ b/third_party/xla/xla/stream_executor/BUILD @@ -925,14 +925,25 @@ xla_cc_test( alias( name = "cuda_platform", actual = "//xla/stream_executor/cuda:all_runtime", + tags = [ + "cuda-only", + "gpu", + ], ) alias( name = "rocm_platform", actual = "//xla/stream_executor/rocm:all_runtime", + tags = [ + "gpu", + "rocm-only", + ], ) alias( name = "sycl_platform", actual = "//xla/stream_executor/sycl:all_runtime", + tags = [ + "gpu", + ], ) diff --git a/third_party/xla/xla/stream_executor/rocm/BUILD b/third_party/xla/xla/stream_executor/rocm/BUILD index 491e558b1a4e4e..aa6a019eb25a26 100644 --- a/third_party/xla/xla/stream_executor/rocm/BUILD +++ b/third_party/xla/xla/stream_executor/rocm/BUILD @@ -1,22 +1,23 @@ # Description: # ROCm-platform specific StreamExecutor support code. -# buildifier: disable=out-of-order-load -# buildifier: disable=out-of-order-load - -load( - "//xla/stream_executor:build_defs.bzl", - "stream_executor_friends", -) load( "@local_config_rocm//rocm:build_defs.bzl", "if_rocm_hipblaslt", - "if_rocm_is_configured", "rocm_library", ) -load("//xla/tsl:tsl.bzl", "internal_visibility", "tsl_copts") load("@local_tsl//tsl/platform:build_config_root.bzl", "if_static") load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") +load( + "//xla/stream_executor:build_defs.bzl", + "stream_executor_friends", +) +load( + "//xla/tsl:tsl.bzl", + "if_google", + "internal_visibility", + "tsl_copts", +) package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -31,28 +32,41 @@ package_group( cc_library( name = "rocm_diagnostics", - srcs = if_rocm_is_configured(["rocm_diagnostics.cc"]), - hdrs = if_rocm_is_configured(["rocm_diagnostics.h"]), - deps = if_rocm_is_configured([ + srcs = ["rocm_diagnostics.cc"], + hdrs = ["rocm_diagnostics.h"], + tags = [ + "gpu", + "rocm-only", + ] + if_google([ + # TODO(b/360374983): Remove this tag once the target can be built without --config=rocm. + "manual", + ]), + deps = [ + "//xla/stream_executor/gpu:gpu_diagnostics_header", + "//xla/stream_executor/platform", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", - "//xla/stream_executor/gpu:gpu_diagnostics_header", - "//xla/stream_executor/platform", - "@local_tsl//tsl/platform:platform_port", "@local_tsl//tsl/platform:logging", - ]), + "@local_tsl//tsl/platform:platform_port", + ], ) cc_library( name = "rocm_driver", - srcs = if_rocm_is_configured(["rocm_driver.cc"]), - hdrs = if_rocm_is_configured([ - "rocm_driver_wrapper.h", + srcs = ["rocm_driver.cc"], + hdrs = [ "rocm_driver.h", + "rocm_driver_wrapper.h", + ], + tags = [ + "gpu", + "rocm-only", + ] + if_google([ + # TODO(b/360374983): Remove this tag once the target can be built without --config=rocm. + "manual", ]), - deps = if_rocm_is_configured([ - # keep sorted + deps = [ ":rocm_diagnostics", "//xla/stream_executor", "//xla/stream_executor/gpu:context", @@ -77,15 +91,21 @@ cc_library( "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:numbers", "@local_tsl//tsl/platform:stacktrace", - ]), + ], ) cc_library( name = "rocm_runtime", - srcs = if_rocm_is_configured(["rocm_runtime.cc"]), - hdrs = if_rocm_is_configured(["rocm_runtime.h"]), - deps = if_rocm_is_configured([ - # keep sorted + srcs = ["rocm_runtime.cc"], + hdrs = ["rocm_runtime.h"], + tags = [ + "gpu", + "rocm-only", + ] + if_google([ + # TODO(b/360374983): Remove this tag once the target can be built without --config=rocm. + "manual", + ]), + deps = [ ":rocm_driver", "//xla/stream_executor", "//xla/stream_executor/gpu:context", @@ -104,30 +124,42 @@ cc_library( "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:statusor", - ]), + ], ) cc_library( name = "rocm_event", - srcs = if_rocm_is_configured(["rocm_event.cc"]), - hdrs = if_rocm_is_configured(["rocm_event.h"]), - deps = if_rocm_is_configured([ - # keep sorted + srcs = ["rocm_event.cc"], + hdrs = ["rocm_event.h"], + tags = [ + "gpu", + "rocm-only", + ] + if_google([ + # TODO(b/360374983): Remove this tag once the target can be built without --config=rocm. + "manual", + ]), + deps = [ ":rocm_driver", "//xla/stream_executor", "//xla/stream_executor/gpu:context", "//xla/stream_executor/gpu:gpu_event_header", "//xla/stream_executor/gpu:gpu_stream_header", "//xla/stream_executor/gpu:scoped_activate_context", - ]), + ], ) cc_library( name = "rocm_executor", - srcs = if_rocm_is_configured(["rocm_executor.cc"]), - hdrs = if_rocm_is_configured(["rocm_executor.h"]), - deps = if_rocm_is_configured([ - # keep sorted + srcs = ["rocm_executor.cc"], + hdrs = ["rocm_executor.h"], + tags = [ + "gpu", + "rocm-only", + ] + if_google([ + # TODO(b/360374983): Remove this tag once the target can be built without --config=rocm. + "manual", + ]), + deps = [ ":rocm_diagnostics", ":rocm_driver", ":rocm_event", @@ -180,24 +212,38 @@ cc_library( "@local_tsl//tsl/platform:fingerprint", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:statusor", - ]), + ], alwayslink = True, ) cc_library( name = "rocm_kernel", - srcs = if_rocm_is_configured(["rocm_kernel.cc"]), + srcs = ["rocm_kernel.cc"], + tags = [ + "gpu", + "rocm-only", + ] + if_google([ + # TODO(b/360374983): Remove this tag once the target can be built without --config=rocm. + "manual", + ]), visibility = ["//visibility:public"], - deps = if_rocm_is_configured([ - "//xla/stream_executor/gpu:gpu_kernel_header", + deps = [ "//xla/stream_executor/gpu:gpu_driver_header", - ]), + "//xla/stream_executor/gpu:gpu_kernel_header", + ], alwayslink = True, ) cc_library( name = "command_buffer_kernels", srcs = ["command_buffer_kernels.cc"], + tags = [ + "gpu", + "rocm-only", + ] + if_google([ + # TODO(b/360374983): Remove this tag once the target can be built without --config=rocm. + "manual", + ]), deps = [ "//xla/stream_executor:kernel_spec", "@com_google_absl//absl/status", @@ -207,11 +253,17 @@ cc_library( cc_library( name = "rocm_platform", - srcs = if_rocm_is_configured(["rocm_platform.cc"]), - hdrs = if_rocm_is_configured(["rocm_platform.h"]), + srcs = ["rocm_platform.cc"], + hdrs = ["rocm_platform.h"], + tags = [ + "gpu", + "rocm-only", + ] + if_google([ + # TODO(b/360374983): Remove this tag once the target can be built without --config=rocm. + "manual", + ]), visibility = ["//visibility:public"], - deps = if_rocm_is_configured([ - # keep sorted + deps = [ ":rocm_driver", ":rocm_executor", ":rocm_platform_id", @@ -226,7 +278,7 @@ cc_library( "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", "@local_tsl//tsl/platform:errors", - ]), + ], alwayslink = True, # Registers itself with the PlatformManager. ) @@ -239,23 +291,29 @@ cc_library( cc_library( name = "rocblas_if_static", - deps = if_static([ - ":rocblas_if_rocm_configured", + tags = [ + "gpu", + "rocm-only", + ] + if_google([ + # TODO(b/360374983): Remove this tag once the target can be built without --config=rocm. + "manual", ]), -) - -cc_library( - name = "rocblas_if_rocm_configured", - deps = if_rocm_is_configured([ + deps = if_static([ "@local_config_rocm//rocm:rocblas", ]), ) cc_library( name = "rocblas_wrapper", - hdrs = if_rocm_is_configured(["rocblas_wrapper.h"]), - deps = if_rocm_is_configured([ - # keep sorted + hdrs = ["rocblas_wrapper.h"], + tags = [ + "gpu", + "rocm-only", + ] + if_google([ + # TODO(b/360374983): Remove this tag once the target can be built without --config=rocm. + "manual", + ]), + deps = [ ":rocblas_if_static", ":rocm_executor", ":rocm_platform_id", @@ -265,17 +323,23 @@ cc_library( "@local_config_rocm//rocm:rocm_headers", "@local_tsl//tsl/platform", "@local_tsl//tsl/platform:env", - ]), + ], alwayslink = True, ) cc_library( name = "rocblas_plugin", - srcs = if_rocm_is_configured(["rocm_blas.cc"]), - hdrs = if_rocm_is_configured(["rocm_blas.h"]), + srcs = ["rocm_blas.cc"], + hdrs = ["rocm_blas.h"], + tags = [ + "gpu", + "rocm-only", + ] + if_google([ + # TODO(b/360374983): Remove this tag once the target can be built without --config=rocm. + "manual", + ]), visibility = ["//visibility:public"], - deps = if_rocm_is_configured([ - # keep sorted + deps = [ ":hipblas_lt_header", ":rocblas_if_static", ":rocblas_wrapper", @@ -305,31 +369,37 @@ cc_library( "@eigen_archive//:eigen3", "@local_config_rocm//rocm:rocm_headers", "@local_tsl//tsl/platform:logging", - ]), + ], alwayslink = True, ) cc_library( name = "hipfft_if_static", - deps = if_static([ - ":hipfft_if_rocm_configured", + tags = [ + "gpu", + "rocm-only", + ] + if_google([ + # TODO(b/360374983): Remove this tag once the target can be built without --config=rocm. + "manual", ]), -) - -cc_library( - name = "hipfft_if_rocm_configured", - deps = if_rocm_is_configured([ + deps = if_static([ "@local_config_rocm//rocm:hipfft", ]), ) cc_library( name = "hipfft_plugin", - srcs = if_rocm_is_configured(["rocm_fft.cc"]), - hdrs = if_rocm_is_configured(["rocm_fft.h"]), + srcs = ["rocm_fft.cc"], + hdrs = ["rocm_fft.h"], + tags = [ + "gpu", + "rocm-only", + ] + if_google([ + # TODO(b/360374983): Remove this tag once the target can be built without --config=rocm. + "manual", + ]), visibility = ["//visibility:public"], - deps = if_rocm_is_configured([ - # keep sorted + deps = [ ":hipfft_if_static", ":rocm_complex_converters", ":rocm_platform_id", @@ -346,36 +416,42 @@ cc_library( "@local_config_rocm//rocm:rocm_headers", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:logging", - ]), + ], alwayslink = True, ) cc_library( name = "miopen_if_static", - deps = if_static([ - ":miopen_if_rocm_configured", + tags = [ + "gpu", + "rocm-only", + ] + if_google([ + # TODO(b/360374983): Remove this tag once the target can be built without --config=rocm. + "manual", ]), -) - -cc_library( - name = "miopen_if_rocm_configured", - deps = if_rocm_is_configured([ + deps = if_static([ "@local_config_rocm//rocm:miopen", ]), ) cc_library( name = "miopen_plugin", - srcs = if_rocm_is_configured(["rocm_dnn.cc"]), - hdrs = if_rocm_is_configured(["rocm_dnn.h"]), + srcs = ["rocm_dnn.cc"], + hdrs = ["rocm_dnn.h"], copts = [ # STREAM_EXECUTOR_CUDNN_WRAP would fail on Clang with the default # setting of template depth 256 "-ftemplate-depth-512", ], + tags = [ + "gpu", + "rocm-only", + ] + if_google([ + # TODO(b/360374983): Remove this tag once the target can be built without --config=rocm. + "manual", + ]), visibility = ["//visibility:public"], - deps = if_rocm_is_configured([ - # keep sorted + deps = [ ":miopen_if_static", ":rocm_diagnostics", ":rocm_driver", @@ -409,115 +485,143 @@ cc_library( "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:hash", "@local_tsl//tsl/platform:logging", - ]), + ], alwayslink = True, ) cc_library( name = "hiprand_if_static", - deps = if_static([ - ":hiprand_if_rocm_configured", + tags = [ + "gpu", + "rocm-only", + ] + if_google([ + # TODO(b/360374983): Remove this tag once the target can be built without --config=rocm. + "manual", ]), -) - -cc_library( - name = "hiprand_if_rocm_configured", - deps = if_rocm_is_configured([ + deps = if_static([ "@local_config_rocm//rocm:hiprand", ]), ) cc_library( name = "hipsparse_if_static", - deps = if_static([ - ":hipsparse_if_rocm_configured", + tags = [ + "gpu", + "rocm-only", + ] + if_google([ + # TODO(b/360374983): Remove this tag once the target can be built without --config=rocm. + "manual", ]), -) - -cc_library( - name = "hipsparse_if_rocm_configured", - deps = if_rocm_is_configured([ + deps = if_static([ "@local_config_rocm//rocm:hipsparse", ]), ) cc_library( name = "hipsparse_wrapper", - srcs = if_rocm_is_configured(["hipsparse_wrapper.h"]), - hdrs = if_rocm_is_configured(["hipsparse_wrapper.h"]), - deps = if_rocm_is_configured([ + srcs = ["hipsparse_wrapper.h"], + hdrs = ["hipsparse_wrapper.h"], + tags = [ + "gpu", + "rocm-only", + ] + if_google([ + # TODO(b/360374983): Remove this tag once the target can be built without --config=rocm. + "manual", + ]), + deps = [ ":hipsparse_if_static", ":rocm_executor", ":rocm_platform_id", - "@local_config_rocm//rocm:rocm_headers", "//xla/stream_executor/platform", "//xla/stream_executor/platform:dso_loader", + "@local_config_rocm//rocm:rocm_headers", "@local_tsl//tsl/platform:env", - ]), + ], alwayslink = True, ) cc_library( name = "rocsolver_if_static", - deps = if_static([ - ":rocsolver_if_rocm_configured", + tags = [ + "gpu", + "rocm-only", + ] + if_google([ + # TODO(b/360374983): Remove this tag once the target can be built without --config=rocm. + "manual", ]), -) - -cc_library( - name = "rocsolver_if_rocm_configured", - deps = if_rocm_is_configured([ + deps = if_static([ "@local_config_rocm//rocm:rocsolver", ]), ) cc_library( name = "rocsolver_wrapper", - srcs = if_rocm_is_configured(["rocsolver_wrapper.h"]), - hdrs = if_rocm_is_configured(["rocsolver_wrapper.h"]), - deps = if_rocm_is_configured([ + srcs = ["rocsolver_wrapper.h"], + hdrs = ["rocsolver_wrapper.h"], + tags = [ + "gpu", + "rocm-only", + ] + if_google([ + # TODO(b/360374983): Remove this tag once the target can be built without --config=rocm. + "manual", + ]), + deps = [ ":rocm_executor", ":rocm_platform_id", ":rocsolver_if_static", - "@local_config_rocm//rocm:rocm_headers", "//xla/stream_executor/platform", "//xla/stream_executor/platform:dso_loader", + "@local_config_rocm//rocm:rocm_headers", "@local_tsl//tsl/platform:env", - ]), + ], alwayslink = True, ) cc_library( name = "hipsolver_if_static", - deps = if_static([ - ":hipsolver_if_rocm_configured", + tags = [ + "gpu", + "rocm-only", + ] + if_google([ + # TODO(b/360374983): Remove this tag once the target can be built without --config=rocm. + "manual", ]), -) - -cc_library( - name = "hipsolver_if_rocm_configured", - deps = if_rocm_is_configured([ + deps = if_static([ "@local_config_rocm//rocm:hipsolver", ]), ) cc_library( name = "hipsolver_wrapper", - hdrs = if_rocm_is_configured(["hipsolver_wrapper.h"]), - deps = if_rocm_is_configured([ + hdrs = ["hipsolver_wrapper.h"], + tags = [ + "gpu", + "rocm-only", + ] + if_google([ + # TODO(b/360374983): Remove this tag once the target can be built without --config=rocm. + "manual", + ]), + deps = [ + ":hipsolver_if_static", ":rocm_executor", ":rocm_platform_id", - ":hipsolver_if_static", - "@local_config_rocm//rocm:rocm_headers", "//xla/stream_executor/platform", "//xla/stream_executor/platform:dso_loader", + "@local_config_rocm//rocm:rocm_headers", "@local_tsl//tsl/platform:env", - ]), + ], alwayslink = True, ) cc_library( name = "hipblaslt_if_static", + tags = [ + "gpu", + "rocm-only", + ] + if_google([ + # TODO(b/360374983): Remove this tag once the target can be built without --config=rocm. + "manual", + ]), deps = if_rocm_hipblaslt([ "@local_config_rocm//rocm:hipblaslt", ]), @@ -525,14 +629,20 @@ cc_library( cc_library( name = "amdhipblaslt_plugin", - srcs = if_rocm_is_configured(["hip_blas_lt.cc"]), - hdrs = if_rocm_is_configured([ + srcs = ["hip_blas_lt.cc"], + hdrs = [ "hip_blas_lt.h", - "hipblaslt_wrapper.h", "hip_blas_utils.h", + "hipblaslt_wrapper.h", + ], + tags = [ + "gpu", + "rocm-only", + ] + if_google([ + # TODO(b/360374983): Remove this tag once the target can be built without --config=rocm. + "manual", ]), - deps = if_rocm_is_configured([ - # keep sorted + deps = [ ":hip_blas_utils", ":hipblas_lt_header", ":rocblas_plugin", @@ -557,7 +667,7 @@ cc_library( "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:status", - ]) + if_static([ + ] + if_static([ ":hipblaslt_if_static", ]), alwayslink = True, @@ -565,14 +675,20 @@ cc_library( cc_library( name = "hipblas_lt_header", - hdrs = if_rocm_is_configured([ + hdrs = [ "hip_blas_lt.h", - "hipblaslt_wrapper.h", "hip_blas_utils.h", + "hipblaslt_wrapper.h", + ], + tags = [ + "gpu", + "rocm-only", + ] + if_google([ + # TODO(b/360374983): Remove this tag once the target can be built without --config=rocm. + "manual", ]), visibility = ["//visibility:public"], - deps = if_rocm_is_configured([ - # keep sorted + deps = [ "//xla:types", "//xla/stream_executor", "//xla/stream_executor:blas", @@ -585,15 +701,21 @@ cc_library( "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:status", - ]), + ], ) cc_library( name = "hip_blas_utils", - srcs = if_rocm_is_configured(["hip_blas_utils.cc"]), - hdrs = if_rocm_is_configured(["hip_blas_utils.h"]), - deps = if_rocm_is_configured([ - # keep sorted + srcs = ["hip_blas_utils.cc"], + hdrs = ["hip_blas_utils.h"], + tags = [ + "gpu", + "rocm-only", + ] + if_google([ + # TODO(b/360374983): Remove this tag once the target can be built without --config=rocm. + "manual", + ]), + deps = [ ":hipblas_lt_header", ":rocblas_plugin", "//xla/stream_executor", @@ -603,29 +725,35 @@ cc_library( "@local_config_rocm//rocm:rocm_headers", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:status", - ]), + ], ) cc_library( name = "roctracer_if_static", - deps = if_static([ - ":roctracer_if_rocm_configured", + tags = [ + "gpu", + "rocm-only", + ] + if_google([ + # TODO(b/360374983): Remove this tag once the target can be built without --config=rocm. + "manual", ]), -) - -cc_library( - name = "roctracer_if_rocm_configured", - deps = if_rocm_is_configured([ + deps = if_static([ "@local_config_rocm//rocm:roctracer", ]), ) cc_library( name = "roctracer_wrapper", - srcs = if_rocm_is_configured(["roctracer_wrapper.h"]), - hdrs = if_rocm_is_configured(["roctracer_wrapper.h"]), - deps = if_rocm_is_configured([ - # keep sorted + srcs = ["roctracer_wrapper.h"], + hdrs = ["roctracer_wrapper.h"], + tags = [ + "gpu", + "rocm-only", + ] + if_google([ + # TODO(b/360374983): Remove this tag once the target can be built without --config=rocm. + "manual", + ]), + deps = [ ":rocm_executor", ":rocm_platform_id", ":roctracer_if_static", @@ -634,56 +762,78 @@ cc_library( "@local_config_rocm//rocm:rocm_headers", "@local_tsl//tsl/platform", "@local_tsl//tsl/platform:env", - ]), + ], alwayslink = True, ) rocm_library( name = "rocm_helpers", - srcs = if_rocm_is_configured(["rocm_helpers.cu.cc"]), - deps = if_rocm_is_configured([ + srcs = ["rocm_helpers.cu.cc"], + # copybara:uncomment compatible_with = ["//buildenv/target:non_prod"], + tags = [ + "gpu", + "rocm-only", + ], + deps = [ "@local_config_rocm//rocm:rocm_headers", - ]), + ], alwayslink = True, ) cc_library( name = "rocm_complex_converters", - hdrs = if_rocm_is_configured(["rocm_complex_converters.h"]), - deps = ["@com_google_absl//absl/log:check"] + if_rocm_is_configured([ + hdrs = ["rocm_complex_converters.h"], + tags = [ + "gpu", + "rocm-only", + ], + deps = [ + "@com_google_absl//absl/log:check", "@local_config_rocm//rocm:rocm_headers", - ]), + ], ) cc_library( name = "all_runtime", copts = tsl_copts(), + tags = [ + "gpu", + "rocm-only", + ] + if_google([ + # TODO(b/360374983): Remove this tag once the target can be built without --config=rocm. + "manual", + ]), visibility = ["//visibility:public"], - deps = if_rocm_is_configured([ - ":miopen_plugin", + deps = [ + ":amdhipblaslt_plugin", ":hipfft_plugin", + ":miopen_plugin", ":rocblas_plugin", ":rocm_driver", - ":rocm_platform", ":rocm_helpers", - ":amdhipblaslt_plugin", - ]), + ":rocm_platform", + ], alwayslink = 1, ) cc_library( name = "rocm_rpath", - data = [], linkopts = select({ "//conditions:default": [ "-Wl,-rpath,../local_config_rocm/rocm/rocm/lib", ], }), - deps = [], ) cc_library( name = "stream_executor_rocm", + tags = [ + "gpu", + "rocm-only", + ] + if_google([ + # TODO(b/360374983): Remove this tag once the target can be built without --config=rocm. + "manual", + ]), deps = [ ":rocm_rpath", "//xla/stream_executor", @@ -712,7 +862,14 @@ cc_library( cc_test( name = "rocm_version_parser_test", - srcs = if_rocm_is_configured(["rocm_version_parser_test.cc"]), + srcs = ["rocm_version_parser_test.cc"], + tags = [ + "gpu", + "rocm-only", + ] + if_google([ + # TODO(b/360374983): Remove this tag once the target can be built without --config=rocm. + "manual", + ]), deps = [ ":rocm_version_parser", "//xla/stream_executor:semantic_version", diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_version_parser_test.cc b/third_party/xla/xla/stream_executor/rocm/rocm_version_parser_test.cc index 2306ae8717a110..1859aed034fbdb 100644 --- a/third_party/xla/xla/stream_executor/rocm/rocm_version_parser_test.cc +++ b/third_party/xla/xla/stream_executor/rocm/rocm_version_parser_test.cc @@ -30,7 +30,7 @@ using tsl::testing::IsOkAndHolds; using tsl::testing::StatusIs; TEST(ParseRocmVersionTest, Simple) { - EXPECT_THAT(stream_executor::ParseRocmVersion(60102), + EXPECT_THAT(stream_executor::ParseRocmVersion(60'100'002), IsOkAndHolds(SemanticVersion(6, 1, 2))); } From c6c32660983af880644464b80063f850bbd074a6 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 24 Sep 2024 23:52:43 -0700 Subject: [PATCH 224/483] Automated Code Change PiperOrigin-RevId: 678558576 --- tensorflow/compiler/mlir/lite/converter_gen.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorflow/compiler/mlir/lite/converter_gen.cc b/tensorflow/compiler/mlir/lite/converter_gen.cc index 2de28d68703b7c..ce5632680cc1bc 100644 --- a/tensorflow/compiler/mlir/lite/converter_gen.cc +++ b/tensorflow/compiler/mlir/lite/converter_gen.cc @@ -22,6 +22,7 @@ limitations under the License. #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/FormatVariadic.h" From 8342e09b5d895669b17fd320e3a1f308cd24f512 Mon Sep 17 00:00:00 2001 From: Jane Liu Date: Wed, 25 Sep 2024 00:07:31 -0700 Subject: [PATCH 225/483] PR #17500: Move HostOffloadLegalize before LayoutNormalization for GPUs Imported from GitHub PR https://github.com/openxla/xla/pull/17500 Fix ActivationOffloadingTest.test_remat_scan_layout_change_offloadable in JAX. The test in memories_test.py failed with an INVALID_ARGUMENT error: - A tensor moved to host (from "dynamic-update-slice.13") was used by an instruction ("transpose.32") not acceptable during pure memory offload. Root cause: - LayoutNormalization inserts a transpose - AlgebraicSimplifier replaces certain transposes with bitcast transposes - These transposes/bitcasts are invalid in host memory offloading segments Solution: Move HostOffloadLegalize before LayoutNormalization to prevent this issue. Copybara import of the project: -- 107d6b462084331f7366e5ae60c150dce090bf14 by Jane Liu : Move HostOffloadLegalize before LayoutNormalization for GPUs -- f0fb7347a1bb61370a29e19230d87a2161c29ef7 by Jane Liu : Add comments to explain the pass order -- 30d2b4450be58ec32ca29f98dc7b822e80ec09df by Jane Liu : Add the test to validate the pass order Merging this change closes #17500 PiperOrigin-RevId: 678563091 --- third_party/xla/xla/service/gpu/gpu_compiler.cc | 9 ++++++--- third_party/xla/xla/service/gpu/gpu_compiler_test.cc | 1 + 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/third_party/xla/xla/service/gpu/gpu_compiler.cc b/third_party/xla/xla/service/gpu/gpu_compiler.cc index 0db21cdf7173d2..3bc29969e3997c 100644 --- a/third_party/xla/xla/service/gpu/gpu_compiler.cc +++ b/third_party/xla/xla/service/gpu/gpu_compiler.cc @@ -996,6 +996,12 @@ absl::Status RunLayoutAssignmentPasses(HloModule* hlo_module, pipeline.AddPass( SubByteNormalization::SET_ELEMENT_SIZE); pipeline.AddPass(true); + // Run HostOffloadLegalize before LayoutNormalization to prevent + // the creation of invalid transpose/bitcast operations within + // host memory offloading segments. + pipeline.AddPass( + static_cast(stream_executor::MemoryType::kHost), + /* after_layout= */ true); return pipeline.Run(hlo_module).status(); } @@ -1570,9 +1576,6 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment( // Rewrite GEMMs with broadcasted inputs as strided GEMMs. pipeline.AddPass(); - pipeline.AddPass( - static_cast(stream_executor::MemoryType::kHost), - /* after_layout= */ true); pipeline.AddPass( static_cast(stream_executor::MemoryType::kHost)); diff --git a/third_party/xla/xla/service/gpu/gpu_compiler_test.cc b/third_party/xla/xla/service/gpu/gpu_compiler_test.cc index b938ee13b1f55c..8338f2b1e3b277 100644 --- a/third_party/xla/xla/service/gpu/gpu_compiler_test.cc +++ b/third_party/xla/xla/service/gpu/gpu_compiler_test.cc @@ -1097,6 +1097,7 @@ ENTRY main { // This test captures known dependencies between passes. VerifyPassOrder(passes, "layout-assignment", "priority-fusion"); VerifyPassOrder(passes, "layout-assignment", "layout_normalization"); + VerifyPassOrder(passes, "host-offload-legalize", "layout_normalization"); } } // namespace From 8caabd26da555dc1426cb12bfa0cbee6f00d4f6e Mon Sep 17 00:00:00 2001 From: Henning Becker Date: Wed, 25 Sep 2024 00:20:06 -0700 Subject: [PATCH 226/483] Add rocm-only tag to AMD GPU tests generated by xla_test This allows filtering out those tests using tags like we do for library targets as well. PiperOrigin-RevId: 678567062 --- third_party/xla/xla/tests/build_defs.bzl | 1 + 1 file changed, 1 insertion(+) diff --git a/third_party/xla/xla/tests/build_defs.bzl b/third_party/xla/xla/tests/build_defs.bzl index 006811c7b322be..94670a8e4c2ade 100644 --- a/third_party/xla/xla/tests/build_defs.bzl +++ b/third_party/xla/xla/tests/build_defs.bzl @@ -133,6 +133,7 @@ def prepare_amd_gpu_backend_data(backends, disabled_backends, backend_tags, back if "cuda-only" not in gpu_backend_tags: new_backend_tags[backend].append("requires-gpu-amd") new_backend_tags[backend].append("notap") + new_backend_tags[backend].append("rocm-only") return new_backends, new_disabled_backends, new_backend_tags, backend_args From 0627b30129ae5b5023cc24e98f2563f927e761a5 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 25 Sep 2024 00:26:04 -0700 Subject: [PATCH 227/483] Disable MSAN for failing test. PiperOrigin-RevId: 678568670 --- third_party/xla/xla/service/gpu/model/BUILD | 3 +++ 1 file changed, 3 insertions(+) diff --git a/third_party/xla/xla/service/gpu/model/BUILD b/third_party/xla/xla/service/gpu/model/BUILD index 4e4df85e58f5aa..978b6695c4353f 100644 --- a/third_party/xla/xla/service/gpu/model/BUILD +++ b/third_party/xla/xla/service/gpu/model/BUILD @@ -538,6 +538,9 @@ cc_library( xla_cc_test( name = "indexing_map_serialization_test", srcs = ["indexing_map_serialization_test.cc"], + tags = [ + "nomsan", + ], deps = [ ":indexing_map_serialization", ":indexing_test_utils", From 8b7e6dcf5bc81f216c35fda175401614eb9ed908 Mon Sep 17 00:00:00 2001 From: Jaroslav Sevcik Date: Wed, 25 Sep 2024 00:42:47 -0700 Subject: [PATCH 228/483] PR #17493: [XLA:GPU] Sort groups in NCCL clique keys Imported from GitHub PR https://github.com/openxla/xla/pull/17493 It turns out that in some cases, we get two calls to GetNcclCliqueKey that only differ by the order of the groups (for instance, one call with groups={{0,1},{2, 3}}, another call with groups={{2,3},{0,1}}. This leads to unnecessary creation of equivalent NCCL communicators. This patch sorts the groups in NCCL clique key so that we create the same key for the same set of groups (up to permutation). Copybara import of the project: -- 291bd38a94c3e64490d5f6919fe7b9417de92c0e by Jaroslav Sevcik : Sort groups in NCCL clique keys -- 6be7f5b7f83ad977ab4bb1a1351f7a2809d13cc9 by Jaroslav Sevcik : Address reviewer comments Merging this change closes #17493 PiperOrigin-RevId: 678573431 --- third_party/xla/xla/service/gpu/runtime/BUILD | 1 + .../service/gpu/runtime/nccl_clique_key.cc | 15 ++++++++++++- .../gpu/runtime/nccl_clique_key_test.cc | 21 +++++++++++++++++++ 3 files changed, 36 insertions(+), 1 deletion(-) diff --git a/third_party/xla/xla/service/gpu/runtime/BUILD b/third_party/xla/xla/service/gpu/runtime/BUILD index 4d6e321f26b0bc..0cf0912aac171d 100644 --- a/third_party/xla/xla/service/gpu/runtime/BUILD +++ b/third_party/xla/xla/service/gpu/runtime/BUILD @@ -300,6 +300,7 @@ cc_library( "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:logging", ], ) diff --git a/third_party/xla/xla/service/gpu/runtime/nccl_clique_key.cc b/third_party/xla/xla/service/gpu/runtime/nccl_clique_key.cc index 9bbc6f4019eab1..2cdb0fd2be1705 100644 --- a/third_party/xla/xla/service/gpu/runtime/nccl_clique_key.cc +++ b/third_party/xla/xla/service/gpu/runtime/nccl_clique_key.cc @@ -30,6 +30,7 @@ limitations under the License. #include "absl/strings/str_join.h" #include "absl/types/span.h" #include "xla/service/global_device_id.h" +#include "tsl/platform/logging.h" namespace xla::gpu { @@ -44,7 +45,19 @@ NcclCliqueKey::NcclCliqueKey( : devices_(std::move(devices)), stream_id_(stream_id), stream_kind_(stream_kind), - participant_groups_(std::move(participant_groups)) {} + participant_groups_(std::move(participant_groups)) { + for (std::vector& group : participant_groups_) { + absl::c_sort(group); + } + // Compare the groups by their first element. + auto compare_groups = [](const std::vector& lhs, + const std::vector& rhs) { + CHECK(!lhs.empty()); + CHECK(!rhs.empty()); + return lhs[0] < rhs[0]; + }; + absl::c_sort(participant_groups_, compare_groups); +} absl::Span NcclCliqueKey::devices() const { return devices_; diff --git a/third_party/xla/xla/service/gpu/runtime/nccl_clique_key_test.cc b/third_party/xla/xla/service/gpu/runtime/nccl_clique_key_test.cc index e55e401ace8ee0..50f43b116145fb 100644 --- a/third_party/xla/xla/service/gpu/runtime/nccl_clique_key_test.cc +++ b/third_party/xla/xla/service/gpu/runtime/nccl_clique_key_test.cc @@ -89,6 +89,27 @@ TEST(NcclCliqueKeyTest, CompareWithParticipantGroups) { EXPECT_EQ(key0_nogroups, key1_nogroups); } +TEST(NcclCliqueKeyTest, CompareWithPermutedParticipantGroups) { + GlobalDeviceId id0 = GlobalDeviceId(0); + GlobalDeviceId id1 = GlobalDeviceId(1); + GlobalDeviceId id2 = GlobalDeviceId(2); + GlobalDeviceId id3 = GlobalDeviceId(3); + + // The keys are equal because the replica groups are same up to permutation. + NcclCliqueKey key0( + {id0, id1}, NcclStreamId(0), AsyncStreamKind::kCollective, + std::vector>{{id3, id2}, {id0, id1}}); + NcclCliqueKey key1( + {id0, id1}, NcclStreamId(0), AsyncStreamKind::kCollective, + std::vector>{{id0, id1}, {id2, id3}}); + EXPECT_EQ(key0, key1); + + NcclCliqueKey key_other( + {id0, id1}, NcclStreamId(0), AsyncStreamKind::kCollective, + std::vector>{{id0, id2}, {id1, id3}}); + EXPECT_FALSE(key0 == key_other); +} + TEST(NcclCliqueKeyTest, BtreeIterationOrder) { GlobalDeviceId id0 = GlobalDeviceId(0); GlobalDeviceId id1 = GlobalDeviceId(1); From 2676c50fd0ea0d9399d085117a0a9498afc6ab02 Mon Sep 17 00:00:00 2001 From: Alexander Belyaev Date: Wed, 25 Sep 2024 00:57:38 -0700 Subject: [PATCH 229/483] [XLA:GPU][IndexAnalysis] Fix MSAN failure: 1st token was accessed before initialization. PiperOrigin-RevId: 678577534 --- .../gpu/model/indexing_map_serialization.cc | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/third_party/xla/xla/service/gpu/model/indexing_map_serialization.cc b/third_party/xla/xla/service/gpu/model/indexing_map_serialization.cc index cbaea73fb879ed..af48bcb1f318d5 100644 --- a/third_party/xla/xla/service/gpu/model/indexing_map_serialization.cc +++ b/third_party/xla/xla/service/gpu/model/indexing_map_serialization.cc @@ -122,11 +122,17 @@ class Parser { public: explicit Parser(llvm::StringRef input) : input_(input), it_(input.begin()) { // Set the parser to the first token. - Advance(); + current_token_ = GetNextTokenImpl(); } const Token& GetCurrentToken() const { return current_token_; }; - void Advance() { current_token_ = GetNextTokenImpl(); } + void Advance() { + if (current_token_.kind == Token::Kind::kError || + current_token_.kind == Token::Kind::kEOF) { + return; + } + current_token_ = GetNextTokenImpl(); + } Token GetNextToken() { Advance(); return current_token_; @@ -253,10 +259,6 @@ bool Parser::ConsumeToken(Token::Kind kind) { } Token Parser::GetNextTokenImpl() { - if (current_token_.kind == Token::Kind::kError || - current_token_.kind == Token::Kind::kEOF) { - return current_token_; - } ConsumeWhitespace(); if (it_ == input_.end()) { return Token{"", Token::Kind::kEOF}; From 4f9818682f701a324e844042e4aeda1824bd0b81 Mon Sep 17 00:00:00 2001 From: Tori Baker Date: Wed, 25 Sep 2024 01:01:26 -0700 Subject: [PATCH 230/483] Add a pass to fuse xla_gpu.loops PiperOrigin-RevId: 678578744 --- .../gpu/fusions/mlir/mlir_fusion_emitter.cc | 1 + .../xla/service/gpu/fusions/transforms/BUILD | 1 + .../gpu/fusions/transforms/fuse_loops.cc | 229 +++++++++++++ .../service/gpu/fusions/transforms/passes.h | 1 + .../service/gpu/fusions/transforms/passes.td | 32 ++ .../fusions/transforms/tests/fuse_loops.mlir | 304 ++++++++++++++++++ 6 files changed, 568 insertions(+) create mode 100644 third_party/xla/xla/service/gpu/fusions/transforms/fuse_loops.cc create mode 100644 third_party/xla/xla/service/gpu/fusions/transforms/tests/fuse_loops.mlir diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/mlir_fusion_emitter.cc b/third_party/xla/xla/service/gpu/fusions/mlir/mlir_fusion_emitter.cc index 448d3050c30bd4..2d5e5ad701c105 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/mlir_fusion_emitter.cc +++ b/third_party/xla/xla/service/gpu/fusions/mlir/mlir_fusion_emitter.cc @@ -547,6 +547,7 @@ void AddLoopTransformationPasses(mlir::OpPassManager& pm) { })); pm.addPass(mlir::createCanonicalizerPass()); pm.addPass(mlir::createCSEPass()); + pm.addNestedPass(CreateFuseLoopsPass()); pm.addNestedPass(CreatePeelLoopsPass()); pm.addNestedPass(CreateLowerXlaGpuLoopsToScfPass()); pm.addPass(mlir::mhlo::createConvertToSignlessPass()); diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/BUILD b/third_party/xla/xla/service/gpu/fusions/transforms/BUILD index e06494acb6262e..b26e0ec1ab9cb7 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/BUILD +++ b/third_party/xla/xla/service/gpu/fusions/transforms/BUILD @@ -38,6 +38,7 @@ cc_library( "erase_dead_functions.cc", "expand_float_ops.cc", "flatten_tensors.cc", + "fuse_loops.cc", "lower_tensors.cc", "lower_to_llvm.cc", "lower_xla_gpu_to_scf.cc", diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/fuse_loops.cc b/third_party/xla/xla/service/gpu/fusions/transforms/fuse_loops.cc new file mode 100644 index 00000000000000..f68570d2c59935 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/transforms/fuse_loops.cc @@ -0,0 +1,229 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "llvm/ADT/STLExtras.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Value.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/IR/Visitors.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h" +#include "xla/service/gpu/model/indexing_map.h" + +namespace xla { +namespace gpu { +namespace { + +using mlir::MLIRContext; +using mlir::SmallVector; +using mlir::Value; +using mlir::ValueRange; +namespace mv = ::mlir::vector; + +#define GEN_PASS_DEF_FUSELOOPSPASS +#include "xla/service/gpu/fusions/transforms/passes.h.inc" + +bool LoopsHaveTheSameDomain(LoopOp& loop1, LoopOp& loop2) { + auto map1 = loop1.getIndexingMap(); + auto map2 = loop2.getIndexingMap(); + if (map1.GetDimVarsCount() != map2.GetDimVarsCount() || + map1.GetRangeVarsCount() != map2.GetRangeVarsCount() || + map1.GetConstraintsCount() != map2.GetConstraintsCount()) { + return false; + } + for (auto [d1, d2] : llvm::zip(map1.GetDimVars(), map2.GetDimVars())) { + if (d1 != d2) return false; + } + for (auto [r1, r2] : llvm::zip(map1.GetRangeVars(), map2.GetRangeVars())) { + if (r1 != r2) return false; + } + if (map1.GetConstraints() != map2.GetConstraints()) return false; + + // Check dimensions come from the same op. This is technically not a + // requirement and could be modified to handle different dim args. + for (auto [dim1, dim2] : llvm::zip(loop1.getDims(), loop2.getDims())) { + if (dim1.getDefiningOp() != dim2.getDefiningOp()) { + continue; + } + } + return true; +} + +// Check that the loops: +// 1. insert and extract from the same location within each iteration, +// 2. use all their IVs (so we don't overwrite the values in another iteration), +// 3. all indices are IVs (so they are confirmed injective). +bool IndicesAreEqualAndInjective(int64_t iv_count, mv::InsertOp insert, + mv::ExtractOp extract) { + auto insert_indices = insert.getDynamicPosition(); + auto extract_indices = extract.getDynamicPosition(); + if (insert_indices.size() != extract_indices.size()) { + return false; + } + if (insert_indices.size() != iv_count) { + return false; + } + + SmallVector matched_indices(iv_count, false); + for (auto [in, ex] : llvm::zip(insert_indices, extract_indices)) { + auto in_arg = mlir::dyn_cast(in); + auto ex_arg = mlir::dyn_cast(ex); + if (!in_arg || !ex_arg || in_arg.getArgNumber() != ex_arg.getArgNumber()) { + return false; + } + // Check #3 - all indices are IVs. + if (in_arg.getArgNumber() >= iv_count) { + return false; + } + matched_indices[in_arg.getArgNumber()] = true; + } + // If there is a loop IV that we didn't use in the insert op, then don't + // match. It's possible that we overwrite the value on a subsequent iteration + // so the loops cannot be fused. + return llvm::all_of(matched_indices, [](bool matched) { return matched; }); +} + +// Fuse insert_loop and extract_loop into a single loop, and remove the +// vector.insert and vector.extract ops. +void FuseLoops(MLIRContext* mlir_context, LoopOp insert_loop, + LoopOp extract_loop, mv::InsertOp insert, + mv::ExtractOp extract) { + mlir::IRRewriter rewriter(mlir_context); + rewriter.setInsertionPointAfter(extract_loop); + // Create a new map that has the results of both loops. + // map = (d0...dn)[s0...sn] -> + // (insert_loop_results..., extract_loop_results...) + auto insert_loop_map = insert_loop.getIndexingMap(); + auto extract_loop_map = extract_loop.getIndexingMap(); + auto map = insert_loop_map.GetAffineMap(); + for (auto res : extract_loop_map.GetAffineMap().getResults()) { + map = map.insertResult(res, map.getNumResults()); + } + IndexingMap new_map(map, insert_loop_map.GetDimVars(), + insert_loop_map.GetRangeVars(), + /*rt_vars=*/{}, insert_loop_map.GetConstraints()); + + auto new_loop = + rewriter.create(insert_loop.getLoc(), new_map, + insert_loop.getDims(), extract_loop.getInits()); + + // Make the loops independent of the vector.insert/extract & erase. + auto vector_cst = insert_loop.getInits().back(); + insert_loop->replaceAllUsesWith(ValueRange(vector_cst)); + extract_loop->replaceAllUsesWith(new_loop.getResults()); + extract.replaceAllUsesWith(insert.getSource()); + auto insert_loop_yield = + mlir::dyn_cast(insert_loop.getRegion().front().back()); + rewriter.eraseOp(insert_loop_yield); + rewriter.eraseOp(extract); + rewriter.eraseOp(insert); + + // Map old loop arguments to new loop arguments. + // new_args = [s0...sn, insert_loop_results..., extract_loop_results..., + // extract_inits...] + auto new_args = new_loop.getRegion().front().getArguments(); + auto range_vars = new_args.take_front(new_map.GetRangeVarsCount()); + new_args = new_args.drop_front(range_vars.size()); + auto in_loop_results = new_args.take_front(insert_loop_map.GetNumResults()); + new_args = new_args.drop_front(in_loop_results.size()); + auto ex_loop_results = new_args.take_front(extract_loop_map.GetNumResults()); + auto extract_inits = new_args.take_back(extract_loop.getInits().size()); + + // old_insert_args = [s0...sn, insert_loop_results..., vector_cst] + SmallVector old_insert_args; + old_insert_args.append(range_vars.begin(), range_vars.end()); + old_insert_args.append(in_loop_results.begin(), in_loop_results.end()); + old_insert_args.push_back(vector_cst); + + // old_insert_args = [s0...sn, extract_loop_results..., extract_inits...] + SmallVector old_extract_args; + old_extract_args.append(range_vars.begin(), range_vars.end()); + old_extract_args.append(ex_loop_results.begin(), ex_loop_results.end()); + old_extract_args.append(extract_inits.begin(), extract_inits.end()); + + // Merge the loops: first insert, then extract. + rewriter.mergeBlocks(&insert_loop.getRegion().front(), + &new_loop.getRegion().front(), old_insert_args); + rewriter.mergeBlocks(&extract_loop.getRegion().front(), + &new_loop.getRegion().front(), old_extract_args); + rewriter.eraseOp(insert_loop); + rewriter.eraseOp(extract_loop); +} + +struct FuseLoopsPass : public impl::FuseLoopsPassBase { + void runOnOperation() override { + SmallVector extracts; + getOperation()->walk([&](mlir::Operation* op) -> void { + if (auto extract = mlir::dyn_cast(op)) { + extracts.push_back(extract); + } + }); + + for (auto extract : extracts) { + // Check that it has the following pattern: + // %insert_loop = { %insert = vector.insert ... } + // %extract_loop = { %extract = vector.extract %insert_loop } + auto extract_loop = extract->getParentOfType(); + if (!extract_loop) continue; + if (!extract.getVector().getDefiningOp()) continue; + auto insert_loop = + mlir::dyn_cast(extract.getVector().getDefiningOp()); + if (!insert_loop) continue; + SmallVector inserts; + // If necessary, the insert_loop result size constraint may be relaxed. + if (insert_loop.getResults().size() != 1) continue; + for (auto user : insert_loop.getRegionIterArgs().back().getUsers()) { + if (auto insert = mlir::dyn_cast(user)) { + inserts.push_back(insert); + } + } + if (inserts.size() != 1) continue; + auto insert = inserts.front(); + + // Check that the vector isn't being used anywhere else so it can be + // removed entirely; we already know from above it's being used by + // extract so it should have exactly one use. + if (!insert_loop.getResult(0).hasOneUse()) continue; + + if (!LoopsHaveTheSameDomain(insert_loop, extract_loop)) continue; + // Only fuse loops if we are extracting from the same position that we are + // inserting into on each iteration. + if (!IndicesAreEqualAndInjective(insert_loop.getNumInductionVars(), + insert, extract)) { + continue; + } + + // All requirements have been met: fuse loops. + FuseLoops(&getContext(), insert_loop, extract_loop, insert, extract); + } + } +}; + +} // namespace + +std::unique_ptr CreateFuseLoopsPass() { + return std::make_unique(); +} + +} // namespace gpu +} // namespace xla diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/passes.h b/third_party/xla/xla/service/gpu/fusions/transforms/passes.h index 470a333f70ccca..99304ed9a1f8da 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/passes.h +++ b/third_party/xla/xla/service/gpu/fusions/transforms/passes.h @@ -51,6 +51,7 @@ std::unique_ptr CreateLowerXlaGpuToScfPass(); std::unique_ptr CreateLowerXlaGpuLoopsToScfPass(); std::unique_ptr CreateMergePointersToSameSlicePass(); std::unique_ptr CreateOptimizeLoopsPass(); +std::unique_ptr CreateFuseLoopsPass(); std::unique_ptr CreatePeelLoopsPass(); std::unique_ptr CreatePropagateSliceIndicesPass(); std::unique_ptr CreateRewriteReductionsPass(); diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/passes.td b/third_party/xla/xla/service/gpu/fusions/transforms/passes.td index 52a0dacbc3db8f..f19e984fd9e6ec 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/passes.td +++ b/third_party/xla/xla/service/gpu/fusions/transforms/passes.td @@ -273,6 +273,38 @@ def VectorizeLoadsAndStoresPass : let constructor = "CreateVectorizeLoadsAndStoresPass()"; } +def FuseLoopsPass : Pass<"xla-gpu-fuse-loops", "mlir::func::FuncOp"> { + let summary = "Fuse xla_gpu.loop."; + let description = [{ + This pass fuses similar xla_gpu.loops into one if the second one is + extracting the same value from a vector in which the first one inserts to. + + Before fuse-loops: + %loop0 = xla_gpu.loop (%tid, %bid) -> (%ra, %rb, %rc)[%i, %j] + in #indexing_map iter_args(%iter = %cst) -> (vector<8x1xf32>) { + %extracted = tensor.extract %arg0[%ra, %rb, %rc] + %1 = vector.insert %extracted, %iter [%i, %j] + xla_gpu.yield %1 + } + %loop1 = xla_gpu.loop (%tid, %bid) -> (%ra, %rb)[%i, %j] + in #indexing_map1 iter_args(%iter = %shmem) -> (tensor<32x33xf32>) { + %2 = vector.extract %loop0 [%i, %j] + %inserted = tensor.insert %iter[%ra, %rb] + xla_gpu.yield %extracted + } + + After fuse-loops: + %loop = xla_gpu.loop (%tid, %bid) -> (%ra, %rb, %rc, %rd, %re)[%i, %j] + in #indexing_map iter_args(%iter = %shmem) -> (tensor<32x33xf32>) { + %extracted = tensor.extract %arg0[%ra, %rb, %rc] + %inserted = tensor.insert %extracted into %iter[%rd, %re] + xla_gpu.yield %inserted + } + }]; + let dependentDialects = ["xla::gpu::XlaGpuDialect"]; + let constructor = "CreateFuseLoopsPass()"; +} + def PeelLoopsPass : Pass<"xla-gpu-peel-loops", "mlir::func::FuncOp"> { let summary = "Peels xla_gpu.loop."; let description = [{ diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/tests/fuse_loops.mlir b/third_party/xla/xla/service/gpu/fusions/transforms/tests/fuse_loops.mlir new file mode 100644 index 00000000000000..557335b6a7ff72 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/transforms/tests/fuse_loops.mlir @@ -0,0 +1,304 @@ +// RUN: mlir_fusions_opt -split-input-file %s -xla-gpu-fuse-loops \ +// RUN: | FileCheck %s + +#indexing_map = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> + (d1 floordiv 30, + ((d1 floordiv 6) mod 5) * 32 + s0 * 4 + d0 floordiv 32, + (d1 mod 6) * 32 + d0 mod 32), + domain: + d0 in [0, 127], d1 in [0, 599], + s0 in [0, 7], s1 in [0, 0], + (d1 mod 6) * 32 + d0 mod 32 in [0, 169], + is_simplified: true> +#indexing_map1 = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> + (0, + d0 mod 32, + d0 floordiv 32 + s0 * 4), + domain: + d0 in [0, 127], d1 in [0, 599], + s0 in [0, 7], s1 in [0, 0], + (d1 mod 6) * 32 + d0 mod 32 in [0, 169], + is_simplified: true> +func.func @fuse_loops(%arg0: tensor<20x160x170xf32>) -> tensor<1x32x33xf32> { + %cst = arith.constant dense<0.000000e+00> : vector<8x1xf32> + %c0 = arith.constant 0 : index + %tid = gpu.thread_id x {xla.range = [0 : index, 127 : index]} + %bid = gpu.block_id x {xla.range = [0 : index, 599 : index]} + %shmem = xla_gpu.allocate_shared : tensor<1x32x33xf32> + %xla_loop = xla_gpu.loop (%tid, %bid)[%i, %j] + -> (%ra, %rb, %rc) in #indexing_map iter_args(%iter = %cst) -> (vector<8x1xf32>) { + %extracted = tensor.extract %arg0[%ra, %rb, %rc] : tensor<20x160x170xf32> + %0 = math.exp %extracted : f32 + %1 = vector.insert %0, %iter [%i, %j] : f32 into vector<8x1xf32> + xla_gpu.yield %1 : vector<8x1xf32> + } + %xla_loop_0 = xla_gpu.loop (%tid, %bid)[%i, %j] + -> (%ra, %rb, %rc) in #indexing_map1 iter_args(%iter = %shmem) -> (tensor<1x32x33xf32>) { + %0 = vector.extract %xla_loop[%i, %j] : f32 from vector<8x1xf32> + %inserted = tensor.insert %0 into %iter[%ra, %rb, %rc] : tensor<1x32x33xf32> + xla_gpu.yield %inserted : tensor<1x32x33xf32> + } + %synced_tensor = xla_gpu.sync_threads %xla_loop_0 : tensor<1x32x33xf32> + return %synced_tensor : tensor<1x32x33xf32> +} + + +// CHECK: #[[$FUSED_MAP:.*]] = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> +// CHECK-SAME: (d1 floordiv 30, ((d1 floordiv 6) mod 5) * 32 + s0 * 4 + d0 floordiv 32, +// CHECK-SAME: (d1 mod 6) * 32 + d0 mod 32, 0, d0 mod 32, d0 floordiv 32 + s0 * 4), +// CHECK-SAME: domain: d0 in [0, 127], d1 in [0, 599], +// CHECK-SAME: s0 in [0, 7], s1 in [0, 0], (d1 mod 6) * 32 + d0 mod 32 in [0, 169] + +// CHECK: %[[FUSED_LOOP:.*]] = xla_gpu.loop {{.*}} in #[[$FUSED_MAP]] +// CHECK-NOT: vector.insert +// CHECK-NOT: vector.extract +// CHECK: %[[EXTRACTED:.*]] = tensor.extract +// CHECK: %[[EXP:.*]] = math.exp %[[EXTRACTED]] +// CHECK: tensor.insert %[[EXP]] + +// CHECK: xla_gpu.sync_threads %[[FUSED_LOOP]] + +// ----- + +#indexing_map = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> + (d1 floordiv 30, + ((d1 floordiv 6) mod 5) * 32 + s0 * 4 + d0 floordiv 32, + (d1 mod 6) * 32 + d0 mod 32), + domain: + d0 in [0, 127], d1 in [0, 599], + s0 in [0, 7], s1 in [0, 0], + (d1 mod 6) * 32 + d0 mod 32 in [0, 169], + is_simplified: true> +#indexing_map1 = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> + (0, + d0 mod 32, + d0 floordiv 32 + s0 * 4), + domain: + d0 in [0, 127], d1 in [0, 599], + s0 in [0, 7], s1 in [0, 0], + (d1 mod 6) * 32 + d0 mod 32 in [0, 169], + is_simplified: true> +func.func @do_not_fuse_index_mismatch(%arg0: tensor<20x160x170xf32>) -> tensor<1x32x33xf32> { + %cst = arith.constant dense<0.000000e+00> : vector<8x1xf32> + %c0 = arith.constant 0 : index + %tid = gpu.thread_id x {xla.range = [0 : index, 127 : index]} + %bid = gpu.block_id x {xla.range = [0 : index, 599 : index]} + %shmem = xla_gpu.allocate_shared : tensor<1x32x33xf32> + %xla_loop = xla_gpu.loop (%tid, %bid)[%i, %j] + -> (%ra, %rb, %rc) in #indexing_map iter_args(%iter = %cst) -> (vector<8x1xf32>) { + %extracted = tensor.extract %arg0[%ra, %rb, %rc] : tensor<20x160x170xf32> + %0 = math.exp %extracted : f32 + %1 = vector.insert %0, %iter [%j, %i] : f32 into vector<8x1xf32> + xla_gpu.yield %1 : vector<8x1xf32> + } + %xla_loop_0 = xla_gpu.loop (%tid, %bid)[%i, %j] + -> (%ra, %rb, %rc) in #indexing_map1 iter_args(%iter = %shmem) -> (tensor<1x32x33xf32>) { + %0 = vector.extract %xla_loop[%i, %j] : f32 from vector<8x1xf32> + %inserted = tensor.insert %0 into %iter[%ra, %rb, %rc] : tensor<1x32x33xf32> + xla_gpu.yield %inserted : tensor<1x32x33xf32> + } + return %xla_loop_0 : tensor<1x32x33xf32> +} + +// CHECK-LABEL: @do_not_fuse_index_mismatch +// CHECK: xla_gpu.loop +// CHECK: vector.insert +// CHECK: xla_gpu.loop +// CHECK: vector.extract + +// ----- + +#indexing_map = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> + (d1 floordiv 30, + ((d1 floordiv 6) mod 5) * 32 + s0 * 4 + d0 floordiv 32, + (d1 mod 6) * 32 + d0 mod 32), + domain: + d0 in [0, 127], d1 in [0, 599], + s0 in [0, 7], s1 in [0, 0], + (d1 mod 6) * 32 + d0 mod 32 in [0, 169], + is_simplified: true> +#indexing_map1 = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> + (0, + d0 mod 32, + d0 floordiv 32 + s0 * 4), + domain: + d0 in [0, 127], d1 in [0, 599], + s0 in [0, 7], s1 in [0, 0], + (d1 mod 6) * 32 + d0 mod 32 in [0, 169], + is_simplified: true> +func.func @do_not_fuse_multiple_uses(%arg0: tensor<20x160x170xf32>) -> tensor<1x32x33xf32> { + %cst = arith.constant dense<0.000000e+00> : vector<8x1xf32> + %c0 = arith.constant 0 : index + %tid = gpu.thread_id x {xla.range = [0 : index, 127 : index]} + %bid = gpu.block_id x {xla.range = [0 : index, 599 : index]} + %shmem = xla_gpu.allocate_shared : tensor<1x32x33xf32> + %xla_loop = xla_gpu.loop (%tid, %bid)[%i, %j] + -> (%ra, %rb, %rc) in #indexing_map iter_args(%iter = %cst) -> (vector<8x1xf32>) { + %extracted = tensor.extract %arg0[%ra, %rb, %rc] : tensor<20x160x170xf32> + %0 = math.exp %extracted : f32 + %1 = vector.insert %0, %iter [%i, %j] : f32 into vector<8x1xf32> + xla_gpu.yield %1 : vector<8x1xf32> + } + %xla_loop_0 = xla_gpu.loop (%tid, %bid)[%i, %j] + -> (%ra, %rb, %rc) in #indexing_map1 iter_args(%iter = %shmem) -> (tensor<1x32x33xf32>) { + %0 = vector.extract %xla_loop[%i, %j] : f32 from vector<8x1xf32> + %inserted = tensor.insert %0 into %iter[%ra, %rb, %rc] : tensor<1x32x33xf32> + xla_gpu.yield %inserted : tensor<1x32x33xf32> + } + %synced_tensor = xla_gpu.sync_threads %xla_loop_0 : tensor<1x32x33xf32> + %0 = vector.extract %xla_loop [2, 0] : f32 from vector<8x1xf32> + return %synced_tensor : tensor<1x32x33xf32> +} + +// CHECK-LABEL: @do_not_fuse_multiple_uses +// CHECK: xla_gpu.loop +// CHECK: vector.insert +// CHECK: xla_gpu.loop +// CHECK: vector.extract + +// ----- + +#indexing_map = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> + (d1 floordiv 30, + ((d1 floordiv 6) mod 5) * 32 + s0 * 4 + d0 floordiv 32, + (d1 mod 6) * 32 + d0 mod 32), + domain: + d0 in [0, 127], d1 in [0, 599], + s0 in [0, 7], s1 in [0, 0], + (d1 mod 6) * 32 + d0 mod 32 in [0, 169], + is_simplified: true> +#indexing_map1 = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> + (0, + d0 mod 32, + d0 floordiv 32 + s0 * 4), + domain: + d0 in [0, 127], d1 in [0, 599], + s0 in [0, 5], s1 in [0, 0], + (d1 mod 6) * 32 + d0 mod 32 in [0, 169], + is_simplified: true> +func.func @do_not_fuse_map_domain_mismatch(%arg0: tensor<20x160x170xf32>) -> tensor<1x32x33xf32> { + %cst = arith.constant dense<0.000000e+00> : vector<8x1xf32> + %c0 = arith.constant 0 : index + %tid = gpu.thread_id x {xla.range = [0 : index, 127 : index]} + %bid = gpu.block_id x {xla.range = [0 : index, 599 : index]} + %shmem = xla_gpu.allocate_shared : tensor<1x32x33xf32> + %xla_loop = xla_gpu.loop (%tid, %bid)[%i, %j] + -> (%ra, %rb, %rc) in #indexing_map iter_args(%iter = %cst) -> (vector<8x1xf32>) { + %extracted = tensor.extract %arg0[%ra, %rb, %rc] : tensor<20x160x170xf32> + %0 = math.exp %extracted : f32 + %1 = vector.insert %0, %iter [%i, %j] : f32 into vector<8x1xf32> + xla_gpu.yield %1 : vector<8x1xf32> + } + %xla_loop_0 = xla_gpu.loop (%tid, %bid)[%i, %j] + -> (%ra, %rb, %rc) in #indexing_map1 iter_args(%iter = %shmem) -> (tensor<1x32x33xf32>) { + %0 = vector.extract %xla_loop[%i, %j] : f32 from vector<8x1xf32> + %inserted = tensor.insert %0 into %iter[%ra, %rb, %rc] : tensor<1x32x33xf32> + xla_gpu.yield %inserted : tensor<1x32x33xf32> + } + %synced_tensor = xla_gpu.sync_threads %xla_loop_0 : tensor<1x32x33xf32> + return %synced_tensor : tensor<1x32x33xf32> +} + +// CHECK-LABEL: @do_not_fuse_map_domain_mismatch +// CHECK: xla_gpu.loop +// CHECK: vector.insert +// CHECK: xla_gpu.loop +// CHECK: vector.extract + +// ----- + +#indexing_map = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> + (d1 floordiv 30, + ((d1 floordiv 6) mod 5) * 32 + s0 * 4 + d0 floordiv 32, + (d1 mod 6) * 32 + d0 mod 32), + domain: + d0 in [0, 127], d1 in [0, 599], + s0 in [0, 7], s1 in [0, 0], + (d1 mod 6) * 32 + d0 mod 32 in [0, 169], + is_simplified: true> +#indexing_map1 = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> + (0, + d0 mod 32, + d0 floordiv 32 + s0 * 4), + domain: + d0 in [0, 127], d1 in [0, 599], + s0 in [0, 7], s1 in [0, 0], + (d1 mod 5) * 32 + d0 mod 32 in [0, 169], + is_simplified: true> +func.func @do_not_fuse_map_constraint_mismatch(%arg0: tensor<20x160x170xf32>) -> tensor<1x32x33xf32> { + %cst = arith.constant dense<0.000000e+00> : vector<8x1xf32> + %c0 = arith.constant 0 : index + %tid = gpu.thread_id x {xla.range = [0 : index, 127 : index]} + %bid = gpu.block_id x {xla.range = [0 : index, 599 : index]} + %shmem = xla_gpu.allocate_shared : tensor<1x32x33xf32> + %xla_loop = xla_gpu.loop (%tid, %bid)[%i, %j] + -> (%ra, %rb, %rc) in #indexing_map iter_args(%iter = %cst) -> (vector<8x1xf32>) { + %extracted = tensor.extract %arg0[%ra, %rb, %rc] : tensor<20x160x170xf32> + %0 = math.exp %extracted : f32 + %1 = vector.insert %0, %iter [%i, %j] : f32 into vector<8x1xf32> + xla_gpu.yield %1 : vector<8x1xf32> + } + %xla_loop_0 = xla_gpu.loop (%tid, %bid)[%i, %j] + -> (%ra, %rb, %rc) in #indexing_map1 iter_args(%iter = %shmem) -> (tensor<1x32x33xf32>) { + %0 = vector.extract %xla_loop[%i, %j] : f32 from vector<8x1xf32> + %inserted = tensor.insert %0 into %iter[%ra, %rb, %rc] : tensor<1x32x33xf32> + xla_gpu.yield %inserted : tensor<1x32x33xf32> + } + %synced_tensor = xla_gpu.sync_threads %xla_loop_0 : tensor<1x32x33xf32> + return %synced_tensor : tensor<1x32x33xf32> +} + +// CHECK-LABEL: @do_not_fuse_map_constraint_mismatch +// CHECK: xla_gpu.loop +// CHECK: vector.insert +// CHECK: xla_gpu.loop +// CHECK: vector.extract + +// ----- + +#indexing_map = #xla_gpu.indexing_map<(d0, d1)[s0, s1, s2] -> + (d1 floordiv 30, + ((d1 floordiv 6) mod 5) * 32 + s0 * 4 + d0 floordiv 32, + (d1 mod 6) * 32 + d0 mod 32), + domain: + d0 in [0, 127], d1 in [0, 599], + s0 in [0, 7], s1 in [0, 0], s2 in [0, 1], + (d1 mod 6) * 32 + d0 mod 32 in [0, 169], + is_simplified: true> +#indexing_map1 = #xla_gpu.indexing_map<(d0, d1)[s0, s1, s2] -> + (0, + d0 mod 32, + d0 floordiv 32 + s0 * 4), + domain: + d0 in [0, 127], d1 in [0, 599], + s0 in [0, 7], s1 in [0, 0], s2 in [0, 1], + (d1 mod 6) * 32 + d0 mod 32 in [0, 169], + is_simplified: true> +func.func @do_not_fuse_unused_loop_iv(%arg0: tensor<20x160x170xf32>) -> tensor<1x32x33xf32> { + %cst = arith.constant dense<0.000000e+00> : vector<8x1xf32> + %c0 = arith.constant 0 : index + %tid = gpu.thread_id x {xla.range = [0 : index, 127 : index]} + %bid = gpu.block_id x {xla.range = [0 : index, 599 : index]} + %shmem = xla_gpu.allocate_shared : tensor<1x32x33xf32> + %xla_loop = xla_gpu.loop (%tid, %bid)[%i, %j, %k] + -> (%ra, %rb, %rc) in #indexing_map iter_args(%iter = %cst) -> (vector<8x1xf32>) { + %extracted = tensor.extract %arg0[%ra, %rb, %rc] : tensor<20x160x170xf32> + %0 = math.exp %extracted : f32 + %1 = vector.insert %0, %iter [%i, %j] : f32 into vector<8x1xf32> + xla_gpu.yield %1 : vector<8x1xf32> + } + %xla_loop_0 = xla_gpu.loop (%tid, %bid)[%i, %j, %k] + -> (%ra, %rb, %rc) in #indexing_map1 iter_args(%iter = %shmem) -> (tensor<1x32x33xf32>) { + %0 = vector.extract %xla_loop[%i, %j] : f32 from vector<8x1xf32> + %inserted = tensor.insert %0 into %iter[%ra, %rb, %rc] : tensor<1x32x33xf32> + xla_gpu.yield %inserted : tensor<1x32x33xf32> + } + %synced_tensor = xla_gpu.sync_threads %xla_loop_0 : tensor<1x32x33xf32> + return %synced_tensor : tensor<1x32x33xf32> +} + +// CHECK-LABEL: @do_not_fuse_unused_loop_iv +// CHECK: xla_gpu.loop +// CHECK: vector.insert +// CHECK: xla_gpu.loop +// CHECK: vector.extract From e2b87298343aa99013401bfc247917a1aba13ba7 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 25 Sep 2024 01:02:54 -0700 Subject: [PATCH 231/483] Automated Code Change PiperOrigin-RevId: 678579286 --- tensorflow/lite/delegates/xnnpack/BUILD | 80 +++++++++++++++++++ .../quantized_binary_elementwise_tester.cc | 10 ++- .../xnnpack/quantized_conv_2d_tester.cc | 12 ++- .../quantized_depthwise_conv_2d_tester.cc | 12 ++- .../quantized_fully_connected_tester.cc | 11 ++- .../xnnpack/quantized_leaky_relu_tester.cc | 8 +- .../delegates/xnnpack/quantized_pad_tester.cc | 8 +- .../xnnpack/quantized_pool_2d_tester.cc | 10 ++- .../xnnpack/quantized_reduce_tester.cc | 8 +- .../quantized_resize_bilinear_tester.cc | 8 +- .../quantized_transpose_conv_tester.cc | 11 +-- .../quantized_unary_elementwise_tester.cc | 8 +- .../xnnpack/quantized_variable_ops_tester.cc | 5 +- .../lite/delegates/xnnpack/reduce_tester.cc | 8 +- .../lite/delegates/xnnpack/relu6_test.cc | 2 + .../delegates/xnnpack/relu_n1_to_1_test.cc | 2 + .../lite/delegates/xnnpack/relu_test.cc | 2 + .../lite/delegates/xnnpack/reshape_test.cc | 2 + .../lite/delegates/xnnpack/reshape_tester.cc | 7 +- .../delegates/xnnpack/resize_bilinear_test.cc | 1 + .../xnnpack/resize_bilinear_tester.cc | 8 +- .../lite/delegates/xnnpack/round_test.cc | 2 + .../xnnpack/signed_dequantize_test.cc | 1 + .../xnnpack/signed_quantized_add_test.cc | 2 + .../signed_quantized_concatenation_test.cc | 2 + .../xnnpack/signed_quantized_conv_2d_test.cc | 1 + .../signed_quantized_depth_to_space_test.cc | 2 + ...signed_quantized_depthwise_conv_2d_test.cc | 1 + .../xnnpack/signed_quantized_elu_test.cc | 2 + .../signed_quantized_fully_connected_test.cc | 1 + .../signed_quantized_leaky_relu_test.cc | 1 + .../xnnpack/signed_quantized_logistic_test.cc | 2 + .../signed_quantized_max_pool_2d_test.cc | 2 + .../xnnpack/signed_quantized_mean_test.cc | 2 + .../xnnpack/signed_quantized_mul_test.cc | 2 + .../xnnpack/signed_quantized_pad_test.cc | 1 + .../xnnpack/signed_quantized_reshape_test.cc | 2 + .../signed_quantized_resize_bilinear_test.cc | 1 + .../xnnpack/signed_quantized_slice_test.cc | 2 + .../signed_quantized_space_to_depth_test.cc | 2 + .../xnnpack/signed_quantized_split_test.cc | 2 + .../signed_quantized_strided_slice_test.cc | 1 + .../xnnpack/signed_quantized_sub_test.cc | 2 + .../xnnpack/signed_quantized_tanh_test.cc | 2 + .../signed_quantized_transpose_conv_test.cc | 1 + .../signed_quantized_transpose_test.cc | 2 + ...zed_variable_ops_multiple_subgraph_test.cc | 1 + .../signed_quantized_variable_ops_test.cc | 1 + .../lite/delegates/xnnpack/slice_test.cc | 2 + .../lite/delegates/xnnpack/slice_tester.cc | 6 +- .../lite/delegates/xnnpack/softmax_test.cc | 1 + 51 files changed, 229 insertions(+), 46 deletions(-) diff --git a/tensorflow/lite/delegates/xnnpack/BUILD b/tensorflow/lite/delegates/xnnpack/BUILD index ad4fc172112086..c99ea87f784e47 100644 --- a/tensorflow/lite/delegates/xnnpack/BUILD +++ b/tensorflow/lite/delegates/xnnpack/BUILD @@ -720,8 +720,10 @@ cc_library( srcs = ["quantized_binary_elementwise_tester.cc"], hdrs = ["quantized_binary_elementwise_tester.h"], deps = [ + "//tensorflow/compiler/mlir/lite/schema:schema_conversion_utils", "//tensorflow/lite:framework", "//tensorflow/lite:schema_fbs_version", + "//tensorflow/lite/c:c_api_types", "//tensorflow/lite/core:framework", "//tensorflow/lite/core/c:common", "//tensorflow/lite/core/kernels:builtin_ops", @@ -739,8 +741,10 @@ cc_library( hdrs = ["quantized_conv_2d_tester.h"], deps = [ ":xnnpack_delegate_test_mode", + "//tensorflow/compiler/mlir/lite/schema:schema_conversion_utils", "//tensorflow/lite:framework", "//tensorflow/lite:schema_fbs_version", + "//tensorflow/lite/c:c_api_types", "//tensorflow/lite/core:framework", "//tensorflow/lite/core/c:common", "//tensorflow/lite/core/kernels:builtin_ops", @@ -758,8 +762,10 @@ cc_library( hdrs = ["quantized_depthwise_conv_2d_tester.h"], deps = [ ":xnnpack_delegate_test_mode", + "//tensorflow/compiler/mlir/lite/schema:schema_conversion_utils", "//tensorflow/lite:framework", "//tensorflow/lite:schema_fbs_version", + "//tensorflow/lite/c:c_api_types", "//tensorflow/lite/core:framework", "//tensorflow/lite/core/c:common", "//tensorflow/lite/core/kernels:builtin_ops", @@ -777,8 +783,10 @@ cc_library( hdrs = ["quantized_fully_connected_tester.h"], deps = [ ":xnnpack_delegate_test_mode", + "//tensorflow/compiler/mlir/lite/schema:schema_conversion_utils", "//tensorflow/lite:framework", "//tensorflow/lite:schema_fbs_version", + "//tensorflow/lite/c:c_api_types", "//tensorflow/lite/core:framework", "//tensorflow/lite/core/c:common", "//tensorflow/lite/core/kernels:builtin_ops", @@ -795,6 +803,7 @@ cc_library( srcs = ["quantized_leaky_relu_tester.cc"], hdrs = ["quantized_leaky_relu_tester.h"], deps = [ + "//tensorflow/compiler/mlir/lite/schema:schema_conversion_utils", "//tensorflow/lite:framework", "//tensorflow/lite:schema_fbs_version", "//tensorflow/lite/core:framework", @@ -813,6 +822,7 @@ cc_library( srcs = ["quantized_pad_tester.cc"], hdrs = ["quantized_pad_tester.h"], deps = [ + "//tensorflow/compiler/mlir/lite/schema:schema_conversion_utils", "//tensorflow/lite:framework", "//tensorflow/lite:schema_fbs_version", "//tensorflow/lite/core:framework", @@ -831,8 +841,10 @@ cc_library( srcs = ["quantized_pool_2d_tester.cc"], hdrs = ["quantized_pool_2d_tester.h"], deps = [ + "//tensorflow/compiler/mlir/lite/schema:schema_conversion_utils", "//tensorflow/lite:framework", "//tensorflow/lite:schema_fbs_version", + "//tensorflow/lite/c:c_api_types", "//tensorflow/lite/core:framework", "//tensorflow/lite/core/c:common", "//tensorflow/lite/core/kernels:builtin_ops", @@ -849,6 +861,7 @@ cc_library( srcs = ["quantized_reduce_tester.cc"], hdrs = ["quantized_reduce_tester.h"], deps = [ + "//tensorflow/compiler/mlir/lite/schema:schema_conversion_utils", "//tensorflow/lite:framework", "//tensorflow/lite:schema_fbs_version", "//tensorflow/lite/core:framework", @@ -867,6 +880,7 @@ cc_library( srcs = ["quantized_resize_bilinear_tester.cc"], hdrs = ["quantized_resize_bilinear_tester.h"], deps = [ + "//tensorflow/compiler/mlir/lite/schema:schema_conversion_utils", "//tensorflow/lite:framework", "//tensorflow/lite:schema_fbs_version", "//tensorflow/lite/core:framework", @@ -885,6 +899,7 @@ cc_library( srcs = ["quantized_unary_elementwise_tester.cc"], hdrs = ["quantized_unary_elementwise_tester.h"], deps = [ + "//tensorflow/compiler/mlir/lite/schema:schema_conversion_utils", "//tensorflow/lite:framework", "//tensorflow/lite:schema_fbs_version", "//tensorflow/lite/core:framework", @@ -904,8 +919,10 @@ cc_library( hdrs = ["quantized_variable_ops_tester.h"], deps = [ ":xnnpack_delegate_test_mode", + "//tensorflow/compiler/mlir/lite/schema:schema_conversion_utils", "//tensorflow/lite:framework", "//tensorflow/lite:schema_fbs_version", + "//tensorflow/lite/core:cc_api_stable", "//tensorflow/lite/core/c:common", "//tensorflow/lite/core/kernels:builtin_ops", "//tensorflow/lite/schema:schema_conversion_utils", @@ -926,8 +943,10 @@ cc_library( hdrs = ["quantized_variable_ops_tester.h"], deps = [ ":xnnpack_delegate", + "//tensorflow/compiler/mlir/lite/schema:schema_conversion_utils", "//tensorflow/lite:framework", "//tensorflow/lite:schema_fbs_version", + "//tensorflow/lite/core:cc_api_stable", "//tensorflow/lite/core/c:common", "//tensorflow/lite/core/kernels:builtin_ops", "//tensorflow/lite/schema:schema_conversion_utils", @@ -943,6 +962,7 @@ cc_library( srcs = ["reduce_tester.cc"], hdrs = ["reduce_tester.h"], deps = [ + "//tensorflow/compiler/mlir/lite/schema:schema_conversion_utils", "//tensorflow/lite:framework", "//tensorflow/lite:schema_fbs_version", "//tensorflow/lite/core:framework", @@ -961,6 +981,7 @@ cc_library( srcs = ["reshape_tester.cc"], hdrs = ["reshape_tester.h"], deps = [ + "//tensorflow/compiler/mlir/lite/schema:schema_conversion_utils", "//tensorflow/lite:framework", "//tensorflow/lite:schema_fbs_version", "//tensorflow/lite/core:framework", @@ -979,6 +1000,7 @@ cc_library( srcs = ["resize_bilinear_tester.cc"], hdrs = ["resize_bilinear_tester.h"], deps = [ + "//tensorflow/compiler/mlir/lite/schema:schema_conversion_utils", "//tensorflow/lite:framework", "//tensorflow/lite:schema_fbs_version", "//tensorflow/lite/core:framework", @@ -997,8 +1019,10 @@ cc_library( srcs = ["slice_tester.cc"], hdrs = ["slice_tester.h"], deps = [ + "//tensorflow/compiler/mlir/lite/schema:schema_conversion_utils", "//tensorflow/lite:framework", "//tensorflow/lite:schema_fbs_version", + "//tensorflow/lite/core:cc_api_stable", "//tensorflow/lite/core/c:common", "//tensorflow/lite/core/kernels:builtin_ops", "//tensorflow/lite/schema:schema_conversion_utils", @@ -1143,6 +1167,7 @@ cc_library( hdrs = ["quantized_transpose_conv_tester.h"], deps = [ ":xnnpack_delegate_test_mode", + "//tensorflow/compiler/mlir/lite/schema:schema_conversion_utils", "//tensorflow/lite:framework", "//tensorflow/lite:schema_fbs_version", "//tensorflow/lite/core:framework", @@ -1891,6 +1916,8 @@ cc_test( ":test_main", ":unary_elementwise_tester", ":xnnpack_delegate_test_mode", + "//tensorflow/lite/c:c_api_types", + "//tensorflow/lite/schema:schema_fbs", "@com_google_googletest//:gtest", ], ) @@ -1906,6 +1933,8 @@ cc_test( ":test_main", ":unary_elementwise_tester", ":xnnpack_delegate_test_mode", + "//tensorflow/lite/c:c_api_types", + "//tensorflow/lite/schema:schema_fbs", "@com_google_googletest//:gtest", ], ) @@ -1921,6 +1950,8 @@ cc_test( ":test_main", ":unary_elementwise_tester", ":xnnpack_delegate_test_mode", + "//tensorflow/lite/c:c_api_types", + "//tensorflow/lite/schema:schema_fbs", "@com_google_googletest//:gtest", ], ) @@ -1936,6 +1967,8 @@ cc_test( ":reshape_tester", ":test_main", ":xnnpack_delegate_test_mode", + "//tensorflow/lite/c:c_api_types", + "//tensorflow/lite/schema:schema_fbs", "@com_google_googletest//:gtest", ], ) @@ -1951,6 +1984,8 @@ cc_test( ":reshape_tester", ":test_main", ":xnnpack_delegate_test_mode", + "//tensorflow/lite/c:c_api_types", + "//tensorflow/lite/schema:schema_fbs", "@com_google_googletest//:gtest", ], ) @@ -1981,6 +2016,7 @@ cc_test( ":resize_bilinear_tester", ":test_main", ":xnnpack_delegate_test_mode", + "//tensorflow/lite/c:c_api_types", "@com_google_googletest//:gtest", ], ) @@ -1996,6 +2032,8 @@ cc_test( ":test_main", ":unary_elementwise_tester", ":xnnpack_delegate_test_mode", + "//tensorflow/lite/c:c_api_types", + "//tensorflow/lite/schema:schema_fbs", "@com_google_googletest//:gtest", ], ) @@ -2028,6 +2066,7 @@ cc_test( ":dequantize_tester", ":test_main", ":xnnpack_delegate_test_mode", + "//tensorflow/lite/c:c_api_types", "@com_google_googletest//:gtest", ], ) @@ -2043,6 +2082,8 @@ cc_test( ":quantized_binary_elementwise_tester", ":test_main", ":xnnpack_delegate_test_mode", + "//tensorflow/lite/c:c_api_types", + "//tensorflow/lite/schema:schema_fbs", "@com_google_googletest//:gtest", ], ) @@ -2058,6 +2099,8 @@ cc_test( ":concatenation_tester", ":test_main", ":xnnpack_delegate_test_mode", + "//tensorflow/lite/c:c_api_types", + "//tensorflow/lite/schema:schema_fbs", "@com_google_googletest//:gtest", ], ) @@ -2073,6 +2116,7 @@ cc_test( ":quantized_conv_2d_tester", ":test_main", ":xnnpack_delegate_test_mode", + "//tensorflow/lite/c:c_api_types", "@com_google_googletest//:gtest", ], ) @@ -2088,6 +2132,7 @@ cc_test( ":quantized_depthwise_conv_2d_tester", ":test_main", ":xnnpack_delegate_test_mode", + "//tensorflow/lite/c:c_api_types", "@com_google_googletest//:gtest", ], ) @@ -2103,6 +2148,8 @@ cc_test( ":depth_to_space_tester", ":test_main", ":xnnpack_delegate_test_mode", + "//tensorflow/lite/c:c_api_types", + "//tensorflow/lite/schema:schema_fbs", "@com_google_googletest//:gtest", ], ) @@ -2118,6 +2165,8 @@ cc_test( ":quantized_unary_elementwise_tester", ":test_main", ":xnnpack_delegate_test_mode", + "//tensorflow/lite/c:c_api_types", + "//tensorflow/lite/schema:schema_fbs", "@com_google_googletest//:gtest", ], ) @@ -2133,6 +2182,7 @@ cc_test( ":quantized_fully_connected_tester", ":test_main", ":xnnpack_delegate_test_mode", + "//tensorflow/lite/c:c_api_types", "@com_google_googletest//:gtest", ], ) @@ -2148,6 +2198,7 @@ cc_test( ":quantized_leaky_relu_tester", ":test_main", ":xnnpack_delegate_test_mode", + "//tensorflow/lite/c:c_api_types", "@com_google_googletest//:gtest", ], ) @@ -2163,6 +2214,8 @@ cc_test( ":quantized_unary_elementwise_tester", ":test_main", ":xnnpack_delegate_test_mode", + "//tensorflow/lite/c:c_api_types", + "//tensorflow/lite/schema:schema_fbs", "@com_google_googletest//:gtest", ], ) @@ -2178,6 +2231,8 @@ cc_test( ":quantized_pool_2d_tester", ":test_main", ":xnnpack_delegate_test_mode", + "//tensorflow/lite/c:c_api_types", + "//tensorflow/lite/schema:schema_fbs", "@com_google_googletest//:gtest", ], ) @@ -2193,6 +2248,8 @@ cc_test( ":quantized_reduce_tester", ":test_main", ":xnnpack_delegate_test_mode", + "//tensorflow/lite/c:c_api_types", + "//tensorflow/lite/schema:schema_fbs", "@com_google_googletest//:gtest", ], ) @@ -2208,6 +2265,8 @@ cc_test( ":quantized_binary_elementwise_tester", ":test_main", ":xnnpack_delegate_test_mode", + "//tensorflow/lite/c:c_api_types", + "//tensorflow/lite/schema:schema_fbs", "@com_google_googletest//:gtest", ], ) @@ -2223,6 +2282,7 @@ cc_test( ":quantized_pad_tester", ":test_main", ":xnnpack_delegate_test_mode", + "//tensorflow/lite/c:c_api_types", "@com_google_googletest//:gtest", ], ) @@ -2238,6 +2298,7 @@ cc_test( ":quantized_resize_bilinear_tester", ":test_main", ":xnnpack_delegate_test_mode", + "//tensorflow/lite/c:c_api_types", "@com_google_googletest//:gtest", ], ) @@ -2253,6 +2314,8 @@ cc_test( ":slice_tester", ":test_main", ":xnnpack_delegate_test_mode", + "//tensorflow/lite/c:c_api_types", + "//tensorflow/lite/schema:schema_fbs", "@com_google_googletest//:gtest", ], ) @@ -2268,6 +2331,8 @@ cc_test( ":space_to_depth_tester", ":test_main", ":xnnpack_delegate_test_mode", + "//tensorflow/lite/c:c_api_types", + "//tensorflow/lite/schema:schema_fbs", "@com_google_googletest//:gtest", ], ) @@ -2283,6 +2348,8 @@ cc_test( ":split_tester", ":test_main", ":xnnpack_delegate_test_mode", + "//tensorflow/lite/c:c_api_types", + "//tensorflow/lite/schema:schema_fbs", "@com_google_googletest//:gtest", ], ) @@ -2298,6 +2365,7 @@ cc_test( ":strided_slice_tester", ":test_main", ":xnnpack_delegate_test_mode", + "//tensorflow/lite/schema:schema_fbs", "@com_google_googletest//:gtest", ], ) @@ -2313,6 +2381,8 @@ cc_test( ":quantized_binary_elementwise_tester", ":test_main", ":xnnpack_delegate_test_mode", + "//tensorflow/lite/c:c_api_types", + "//tensorflow/lite/schema:schema_fbs", "@com_google_googletest//:gtest", ], ) @@ -2328,6 +2398,8 @@ cc_test( ":quantized_unary_elementwise_tester", ":test_main", ":xnnpack_delegate_test_mode", + "//tensorflow/lite/c:c_api_types", + "//tensorflow/lite/schema:schema_fbs", "@com_google_googletest//:gtest", ], ) @@ -2343,6 +2415,8 @@ cc_test( ":test_main", ":transpose_tester", ":xnnpack_delegate_test_mode", + "//tensorflow/lite/c:c_api_types", + "//tensorflow/lite/schema:schema_fbs", "@com_google_googletest//:gtest", ], ) @@ -2358,6 +2432,7 @@ cc_test( ":quantized_transpose_conv_tester", ":test_main", ":xnnpack_delegate_test_mode", + "//tensorflow/lite/c:c_api_types", "@com_google_googletest//:gtest", ], ) @@ -2377,6 +2452,7 @@ cc_test( ":xnnpack_delegate", ":quantized_variable_ops_tester_no_test_mode", "@com_google_googletest//:gtest", + "//tensorflow/lite/c:c_api_types", ], ) @@ -2391,6 +2467,7 @@ cc_test( ":quantized_variable_ops_tester", ":test_main", ":xnnpack_delegate_test_mode", + "//tensorflow/lite/c:c_api_types", "@com_google_googletest//:gtest", ], ) @@ -2406,6 +2483,8 @@ cc_test( ":slice_tester", ":test_main", ":xnnpack_delegate_test_mode", + "//tensorflow/lite/c:c_api_types", + "//tensorflow/lite/schema:schema_fbs", "@com_google_googletest//:gtest", ], ) @@ -2421,6 +2500,7 @@ cc_test( ":softmax_tester", ":test_main", ":xnnpack_delegate_test_mode", + "//tensorflow/lite/c:c_api_types", "@com_google_googletest//:gtest", ], ) diff --git a/tensorflow/lite/delegates/xnnpack/quantized_binary_elementwise_tester.cc b/tensorflow/lite/delegates/xnnpack/quantized_binary_elementwise_tester.cc index a2665d612d3864..0109427159c729 100644 --- a/tensorflow/lite/delegates/xnnpack/quantized_binary_elementwise_tester.cc +++ b/tensorflow/lite/delegates/xnnpack/quantized_binary_elementwise_tester.cc @@ -26,10 +26,14 @@ limitations under the License. #include #include -#include "flatbuffers/flatbuffers.h" // from @flatbuffers +#include "flatbuffers/buffer.h" // from @flatbuffers +#include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers +#include "flatbuffers/string.h" // from @flatbuffers +#include "tensorflow/compiler/mlir/lite/schema/schema_conversion_utils.h" +#include "tensorflow/lite/c/c_api_types.h" +#include "tensorflow/lite/core/interpreter_builder.h" #include "tensorflow/lite/core/kernels/register.h" -#include "tensorflow/lite/core/model.h" -#include "tensorflow/lite/schema/schema_conversion_utils.h" +#include "tensorflow/lite/interpreter.h" #include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/version.h" diff --git a/tensorflow/lite/delegates/xnnpack/quantized_conv_2d_tester.cc b/tensorflow/lite/delegates/xnnpack/quantized_conv_2d_tester.cc index 266a8a22ac7b0b..3e91c73d09934c 100644 --- a/tensorflow/lite/delegates/xnnpack/quantized_conv_2d_tester.cc +++ b/tensorflow/lite/delegates/xnnpack/quantized_conv_2d_tester.cc @@ -25,10 +25,16 @@ limitations under the License. #include #include -#include "flatbuffers/flatbuffers.h" // from @flatbuffers +#include "flatbuffers/buffer.h" // from @flatbuffers +#include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers +#include "flatbuffers/string.h" // from @flatbuffers +#include "flatbuffers/vector.h" // from @flatbuffers +#include "tensorflow/compiler/mlir/lite/schema/schema_conversion_utils.h" +#include "tensorflow/lite/c/c_api_types.h" +#include "tensorflow/lite/core/interpreter_builder.h" #include "tensorflow/lite/core/kernels/register.h" -#include "tensorflow/lite/core/model.h" -#include "tensorflow/lite/schema/schema_conversion_utils.h" +#include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" +#include "tensorflow/lite/interpreter.h" #include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/version.h" diff --git a/tensorflow/lite/delegates/xnnpack/quantized_depthwise_conv_2d_tester.cc b/tensorflow/lite/delegates/xnnpack/quantized_depthwise_conv_2d_tester.cc index 017713b4761f3e..162037f9a74f68 100644 --- a/tensorflow/lite/delegates/xnnpack/quantized_depthwise_conv_2d_tester.cc +++ b/tensorflow/lite/delegates/xnnpack/quantized_depthwise_conv_2d_tester.cc @@ -25,10 +25,16 @@ limitations under the License. #include #include -#include "flatbuffers/flatbuffers.h" // from @flatbuffers +#include "flatbuffers/buffer.h" // from @flatbuffers +#include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers +#include "flatbuffers/string.h" // from @flatbuffers +#include "flatbuffers/vector.h" // from @flatbuffers +#include "tensorflow/compiler/mlir/lite/schema/schema_conversion_utils.h" +#include "tensorflow/lite/c/c_api_types.h" +#include "tensorflow/lite/core/interpreter_builder.h" #include "tensorflow/lite/core/kernels/register.h" -#include "tensorflow/lite/core/model.h" -#include "tensorflow/lite/schema/schema_conversion_utils.h" +#include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" +#include "tensorflow/lite/interpreter.h" #include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/version.h" diff --git a/tensorflow/lite/delegates/xnnpack/quantized_fully_connected_tester.cc b/tensorflow/lite/delegates/xnnpack/quantized_fully_connected_tester.cc index c42f8d78f97a21..fdcb8999565394 100644 --- a/tensorflow/lite/delegates/xnnpack/quantized_fully_connected_tester.cc +++ b/tensorflow/lite/delegates/xnnpack/quantized_fully_connected_tester.cc @@ -26,10 +26,15 @@ limitations under the License. #include #include -#include "flatbuffers/flatbuffers.h" // from @flatbuffers +#include "flatbuffers/buffer.h" // from @flatbuffers +#include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers +#include "flatbuffers/string.h" // from @flatbuffers +#include "tensorflow/compiler/mlir/lite/schema/schema_conversion_utils.h" +#include "tensorflow/lite/c/c_api_types.h" +#include "tensorflow/lite/core/interpreter_builder.h" #include "tensorflow/lite/core/kernels/register.h" -#include "tensorflow/lite/core/model.h" -#include "tensorflow/lite/schema/schema_conversion_utils.h" +#include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" +#include "tensorflow/lite/interpreter.h" #include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/version.h" diff --git a/tensorflow/lite/delegates/xnnpack/quantized_leaky_relu_tester.cc b/tensorflow/lite/delegates/xnnpack/quantized_leaky_relu_tester.cc index 88dead108e7b9b..410c1dbf21c872 100644 --- a/tensorflow/lite/delegates/xnnpack/quantized_leaky_relu_tester.cc +++ b/tensorflow/lite/delegates/xnnpack/quantized_leaky_relu_tester.cc @@ -24,11 +24,13 @@ limitations under the License. #include #include -#include "flatbuffers/flatbuffers.h" // from @flatbuffers +#include "flatbuffers/buffer.h" // from @flatbuffers +#include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers +#include "flatbuffers/string.h" // from @flatbuffers +#include "tensorflow/compiler/mlir/lite/schema/schema_conversion_utils.h" +#include "tensorflow/lite/core/interpreter_builder.h" #include "tensorflow/lite/core/kernels/register.h" -#include "tensorflow/lite/core/model.h" #include "tensorflow/lite/interpreter.h" -#include "tensorflow/lite/schema/schema_conversion_utils.h" #include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/version.h" diff --git a/tensorflow/lite/delegates/xnnpack/quantized_pad_tester.cc b/tensorflow/lite/delegates/xnnpack/quantized_pad_tester.cc index aa1f2391613684..545f5cfd761a46 100644 --- a/tensorflow/lite/delegates/xnnpack/quantized_pad_tester.cc +++ b/tensorflow/lite/delegates/xnnpack/quantized_pad_tester.cc @@ -24,11 +24,13 @@ limitations under the License. #include #include -#include "flatbuffers/flatbuffers.h" // from @flatbuffers +#include "flatbuffers/buffer.h" // from @flatbuffers +#include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers +#include "flatbuffers/string.h" // from @flatbuffers +#include "tensorflow/compiler/mlir/lite/schema/schema_conversion_utils.h" +#include "tensorflow/lite/core/interpreter_builder.h" #include "tensorflow/lite/core/kernels/register.h" -#include "tensorflow/lite/core/model.h" #include "tensorflow/lite/interpreter.h" -#include "tensorflow/lite/schema/schema_conversion_utils.h" #include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/version.h" diff --git a/tensorflow/lite/delegates/xnnpack/quantized_pool_2d_tester.cc b/tensorflow/lite/delegates/xnnpack/quantized_pool_2d_tester.cc index 4918b34aeb7d7f..f1cd0249e7d0c3 100644 --- a/tensorflow/lite/delegates/xnnpack/quantized_pool_2d_tester.cc +++ b/tensorflow/lite/delegates/xnnpack/quantized_pool_2d_tester.cc @@ -23,10 +23,14 @@ limitations under the License. #include #include -#include "flatbuffers/flatbuffers.h" // from @flatbuffers +#include "flatbuffers/buffer.h" // from @flatbuffers +#include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers +#include "flatbuffers/string.h" // from @flatbuffers +#include "tensorflow/compiler/mlir/lite/schema/schema_conversion_utils.h" +#include "tensorflow/lite/c/c_api_types.h" +#include "tensorflow/lite/core/interpreter_builder.h" #include "tensorflow/lite/core/kernels/register.h" -#include "tensorflow/lite/core/model.h" -#include "tensorflow/lite/schema/schema_conversion_utils.h" +#include "tensorflow/lite/interpreter.h" #include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/version.h" diff --git a/tensorflow/lite/delegates/xnnpack/quantized_reduce_tester.cc b/tensorflow/lite/delegates/xnnpack/quantized_reduce_tester.cc index c4a8cf1b381db5..ad055047d35e14 100644 --- a/tensorflow/lite/delegates/xnnpack/quantized_reduce_tester.cc +++ b/tensorflow/lite/delegates/xnnpack/quantized_reduce_tester.cc @@ -24,11 +24,13 @@ limitations under the License. #include #include -#include "flatbuffers/flatbuffers.h" // from @flatbuffers +#include "flatbuffers/buffer.h" // from @flatbuffers +#include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers +#include "flatbuffers/string.h" // from @flatbuffers +#include "tensorflow/compiler/mlir/lite/schema/schema_conversion_utils.h" +#include "tensorflow/lite/core/interpreter_builder.h" #include "tensorflow/lite/core/kernels/register.h" -#include "tensorflow/lite/core/model.h" #include "tensorflow/lite/interpreter.h" -#include "tensorflow/lite/schema/schema_conversion_utils.h" #include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/version.h" diff --git a/tensorflow/lite/delegates/xnnpack/quantized_resize_bilinear_tester.cc b/tensorflow/lite/delegates/xnnpack/quantized_resize_bilinear_tester.cc index d9ab3f5359547d..484b25f2c68bd1 100644 --- a/tensorflow/lite/delegates/xnnpack/quantized_resize_bilinear_tester.cc +++ b/tensorflow/lite/delegates/xnnpack/quantized_resize_bilinear_tester.cc @@ -24,11 +24,13 @@ limitations under the License. #include #include -#include "flatbuffers/flatbuffers.h" // from @flatbuffers +#include "flatbuffers/buffer.h" // from @flatbuffers +#include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers +#include "flatbuffers/string.h" // from @flatbuffers +#include "tensorflow/compiler/mlir/lite/schema/schema_conversion_utils.h" +#include "tensorflow/lite/core/interpreter_builder.h" #include "tensorflow/lite/core/kernels/register.h" -#include "tensorflow/lite/core/model.h" #include "tensorflow/lite/interpreter.h" -#include "tensorflow/lite/schema/schema_conversion_utils.h" #include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/version.h" diff --git a/tensorflow/lite/delegates/xnnpack/quantized_transpose_conv_tester.cc b/tensorflow/lite/delegates/xnnpack/quantized_transpose_conv_tester.cc index e9a9a19d856bcb..f5bc843ca9479e 100644 --- a/tensorflow/lite/delegates/xnnpack/quantized_transpose_conv_tester.cc +++ b/tensorflow/lite/delegates/xnnpack/quantized_transpose_conv_tester.cc @@ -26,13 +26,14 @@ limitations under the License. #include #include -#include "fp16.h" // from @FP16 -#include "flatbuffers/flatbuffers.h" // from @flatbuffers -#include "tensorflow/lite/core/c/builtin_op_data.h" +#include "flatbuffers/buffer.h" // from @flatbuffers +#include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers +#include "flatbuffers/string.h" // from @flatbuffers +#include "tensorflow/compiler/mlir/lite/schema/schema_conversion_utils.h" +#include "tensorflow/lite/core/interpreter_builder.h" #include "tensorflow/lite/core/kernels/register.h" -#include "tensorflow/lite/core/model.h" +#include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" #include "tensorflow/lite/interpreter.h" -#include "tensorflow/lite/schema/schema_conversion_utils.h" #include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/version.h" diff --git a/tensorflow/lite/delegates/xnnpack/quantized_unary_elementwise_tester.cc b/tensorflow/lite/delegates/xnnpack/quantized_unary_elementwise_tester.cc index b8e297ae46175a..6efcbafca015d3 100644 --- a/tensorflow/lite/delegates/xnnpack/quantized_unary_elementwise_tester.cc +++ b/tensorflow/lite/delegates/xnnpack/quantized_unary_elementwise_tester.cc @@ -24,11 +24,13 @@ limitations under the License. #include #include -#include "flatbuffers/flatbuffers.h" // from @flatbuffers +#include "flatbuffers/buffer.h" // from @flatbuffers +#include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers +#include "flatbuffers/string.h" // from @flatbuffers +#include "tensorflow/compiler/mlir/lite/schema/schema_conversion_utils.h" +#include "tensorflow/lite/core/interpreter_builder.h" #include "tensorflow/lite/core/kernels/register.h" -#include "tensorflow/lite/core/model.h" #include "tensorflow/lite/interpreter.h" -#include "tensorflow/lite/schema/schema_conversion_utils.h" #include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/version.h" diff --git a/tensorflow/lite/delegates/xnnpack/quantized_variable_ops_tester.cc b/tensorflow/lite/delegates/xnnpack/quantized_variable_ops_tester.cc index 71e7f3e7630a07..61ba2b60f2b8bb 100644 --- a/tensorflow/lite/delegates/xnnpack/quantized_variable_ops_tester.cc +++ b/tensorflow/lite/delegates/xnnpack/quantized_variable_ops_tester.cc @@ -25,11 +25,14 @@ limitations under the License. #include #include +#include "flatbuffers/buffer.h" // from @flatbuffers +#include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers +#include "tensorflow/compiler/mlir/lite/schema/schema_conversion_utils.h" #include "tensorflow/lite/core/c/common.h" +#include "tensorflow/lite/core/interpreter_builder.h" #include "tensorflow/lite/core/kernels/register.h" #include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" #include "tensorflow/lite/interpreter.h" -#include "tensorflow/lite/schema/schema_conversion_utils.h" #include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/version.h" diff --git a/tensorflow/lite/delegates/xnnpack/reduce_tester.cc b/tensorflow/lite/delegates/xnnpack/reduce_tester.cc index cc6f69066f84be..8ece825a7458a6 100644 --- a/tensorflow/lite/delegates/xnnpack/reduce_tester.cc +++ b/tensorflow/lite/delegates/xnnpack/reduce_tester.cc @@ -25,11 +25,13 @@ limitations under the License. #include #include -#include "flatbuffers/flatbuffers.h" // from @flatbuffers +#include "flatbuffers/buffer.h" // from @flatbuffers +#include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers +#include "flatbuffers/string.h" // from @flatbuffers +#include "tensorflow/compiler/mlir/lite/schema/schema_conversion_utils.h" +#include "tensorflow/lite/core/interpreter_builder.h" #include "tensorflow/lite/core/kernels/register.h" -#include "tensorflow/lite/core/model.h" #include "tensorflow/lite/interpreter.h" -#include "tensorflow/lite/schema/schema_conversion_utils.h" #include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/version.h" diff --git a/tensorflow/lite/delegates/xnnpack/relu6_test.cc b/tensorflow/lite/delegates/xnnpack/relu6_test.cc index 75f32dcfd39116..5f2de211ec4ef9 100644 --- a/tensorflow/lite/delegates/xnnpack/relu6_test.cc +++ b/tensorflow/lite/delegates/xnnpack/relu6_test.cc @@ -19,8 +19,10 @@ limitations under the License. #include #include +#include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/delegates/xnnpack/unary_elementwise_tester.h" #include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" +#include "tensorflow/lite/schema/schema_generated.h" namespace tflite { namespace xnnpack { diff --git a/tensorflow/lite/delegates/xnnpack/relu_n1_to_1_test.cc b/tensorflow/lite/delegates/xnnpack/relu_n1_to_1_test.cc index 9e799577e6ed73..07aab082aefc0a 100644 --- a/tensorflow/lite/delegates/xnnpack/relu_n1_to_1_test.cc +++ b/tensorflow/lite/delegates/xnnpack/relu_n1_to_1_test.cc @@ -19,8 +19,10 @@ limitations under the License. #include #include +#include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/delegates/xnnpack/unary_elementwise_tester.h" #include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" +#include "tensorflow/lite/schema/schema_generated.h" namespace tflite { namespace xnnpack { diff --git a/tensorflow/lite/delegates/xnnpack/relu_test.cc b/tensorflow/lite/delegates/xnnpack/relu_test.cc index 8996ff5d04b8c4..b088a2b9053e18 100644 --- a/tensorflow/lite/delegates/xnnpack/relu_test.cc +++ b/tensorflow/lite/delegates/xnnpack/relu_test.cc @@ -19,8 +19,10 @@ limitations under the License. #include #include +#include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/delegates/xnnpack/unary_elementwise_tester.h" #include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" +#include "tensorflow/lite/schema/schema_generated.h" namespace tflite { namespace xnnpack { diff --git a/tensorflow/lite/delegates/xnnpack/reshape_test.cc b/tensorflow/lite/delegates/xnnpack/reshape_test.cc index fc8d240f120ff5..56c252f461eef6 100644 --- a/tensorflow/lite/delegates/xnnpack/reshape_test.cc +++ b/tensorflow/lite/delegates/xnnpack/reshape_test.cc @@ -21,8 +21,10 @@ limitations under the License. #include #include +#include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/delegates/xnnpack/reshape_tester.h" #include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" +#include "tensorflow/lite/schema/schema_generated.h" namespace tflite { namespace xnnpack { diff --git a/tensorflow/lite/delegates/xnnpack/reshape_tester.cc b/tensorflow/lite/delegates/xnnpack/reshape_tester.cc index e2f4fe2e63e9ad..a3c5a17fd38105 100644 --- a/tensorflow/lite/delegates/xnnpack/reshape_tester.cc +++ b/tensorflow/lite/delegates/xnnpack/reshape_tester.cc @@ -25,11 +25,12 @@ limitations under the License. #include #include -#include "flatbuffers/flatbuffers.h" // from @flatbuffers +#include "flatbuffers/buffer.h" // from @flatbuffers +#include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers +#include "tensorflow/compiler/mlir/lite/schema/schema_conversion_utils.h" +#include "tensorflow/lite/core/interpreter_builder.h" #include "tensorflow/lite/core/kernels/register.h" -#include "tensorflow/lite/core/model.h" #include "tensorflow/lite/interpreter.h" -#include "tensorflow/lite/schema/schema_conversion_utils.h" #include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/version.h" diff --git a/tensorflow/lite/delegates/xnnpack/resize_bilinear_test.cc b/tensorflow/lite/delegates/xnnpack/resize_bilinear_test.cc index b1fc49ca93fbd1..c66004e3205617 100644 --- a/tensorflow/lite/delegates/xnnpack/resize_bilinear_test.cc +++ b/tensorflow/lite/delegates/xnnpack/resize_bilinear_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/delegates/xnnpack/resize_bilinear_tester.h" #include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" diff --git a/tensorflow/lite/delegates/xnnpack/resize_bilinear_tester.cc b/tensorflow/lite/delegates/xnnpack/resize_bilinear_tester.cc index e4ee08280f4260..c2832b0c64d68b 100644 --- a/tensorflow/lite/delegates/xnnpack/resize_bilinear_tester.cc +++ b/tensorflow/lite/delegates/xnnpack/resize_bilinear_tester.cc @@ -25,11 +25,13 @@ limitations under the License. #include #include -#include "flatbuffers/flatbuffers.h" // from @flatbuffers +#include "flatbuffers/buffer.h" // from @flatbuffers +#include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers +#include "flatbuffers/string.h" // from @flatbuffers +#include "tensorflow/compiler/mlir/lite/schema/schema_conversion_utils.h" +#include "tensorflow/lite/core/interpreter_builder.h" #include "tensorflow/lite/core/kernels/register.h" -#include "tensorflow/lite/core/model.h" #include "tensorflow/lite/interpreter.h" -#include "tensorflow/lite/schema/schema_conversion_utils.h" #include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/version.h" diff --git a/tensorflow/lite/delegates/xnnpack/round_test.cc b/tensorflow/lite/delegates/xnnpack/round_test.cc index 0481762ca1947b..1e4f861bf12b6a 100644 --- a/tensorflow/lite/delegates/xnnpack/round_test.cc +++ b/tensorflow/lite/delegates/xnnpack/round_test.cc @@ -19,8 +19,10 @@ limitations under the License. #include #include +#include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/delegates/xnnpack/unary_elementwise_tester.h" #include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" +#include "tensorflow/lite/schema/schema_generated.h" namespace tflite { namespace xnnpack { diff --git a/tensorflow/lite/delegates/xnnpack/signed_dequantize_test.cc b/tensorflow/lite/delegates/xnnpack/signed_dequantize_test.cc index f86379326d0bc6..927152db28af04 100644 --- a/tensorflow/lite/delegates/xnnpack/signed_dequantize_test.cc +++ b/tensorflow/lite/delegates/xnnpack/signed_dequantize_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/delegates/xnnpack/dequantize_tester.h" #include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" diff --git a/tensorflow/lite/delegates/xnnpack/signed_quantized_add_test.cc b/tensorflow/lite/delegates/xnnpack/signed_quantized_add_test.cc index 87e669c5d0cde2..ad159deb61b55c 100644 --- a/tensorflow/lite/delegates/xnnpack/signed_quantized_add_test.cc +++ b/tensorflow/lite/delegates/xnnpack/signed_quantized_add_test.cc @@ -20,8 +20,10 @@ limitations under the License. #include #include +#include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/delegates/xnnpack/quantized_binary_elementwise_tester.h" #include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" +#include "tensorflow/lite/schema/schema_generated.h" namespace tflite { namespace xnnpack { diff --git a/tensorflow/lite/delegates/xnnpack/signed_quantized_concatenation_test.cc b/tensorflow/lite/delegates/xnnpack/signed_quantized_concatenation_test.cc index 35377adfb86e26..c7590f21f8a944 100644 --- a/tensorflow/lite/delegates/xnnpack/signed_quantized_concatenation_test.cc +++ b/tensorflow/lite/delegates/xnnpack/signed_quantized_concatenation_test.cc @@ -21,8 +21,10 @@ limitations under the License. #include #include +#include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/delegates/xnnpack/concatenation_tester.h" #include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" +#include "tensorflow/lite/schema/schema_generated.h" namespace tflite { namespace xnnpack { diff --git a/tensorflow/lite/delegates/xnnpack/signed_quantized_conv_2d_test.cc b/tensorflow/lite/delegates/xnnpack/signed_quantized_conv_2d_test.cc index a43b3c42fbf40f..f67ba714b01cc8 100644 --- a/tensorflow/lite/delegates/xnnpack/signed_quantized_conv_2d_test.cc +++ b/tensorflow/lite/delegates/xnnpack/signed_quantized_conv_2d_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/delegates/xnnpack/quantized_conv_2d_tester.h" #include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" diff --git a/tensorflow/lite/delegates/xnnpack/signed_quantized_depth_to_space_test.cc b/tensorflow/lite/delegates/xnnpack/signed_quantized_depth_to_space_test.cc index f6003e99398949..d85a9cfb1ceac4 100644 --- a/tensorflow/lite/delegates/xnnpack/signed_quantized_depth_to_space_test.cc +++ b/tensorflow/lite/delegates/xnnpack/signed_quantized_depth_to_space_test.cc @@ -20,8 +20,10 @@ limitations under the License. #include #include +#include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/delegates/xnnpack/depth_to_space_tester.h" #include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" +#include "tensorflow/lite/schema/schema_generated.h" namespace tflite { namespace xnnpack { diff --git a/tensorflow/lite/delegates/xnnpack/signed_quantized_depthwise_conv_2d_test.cc b/tensorflow/lite/delegates/xnnpack/signed_quantized_depthwise_conv_2d_test.cc index 33ff96aa84594f..3acfbaaf34778e 100644 --- a/tensorflow/lite/delegates/xnnpack/signed_quantized_depthwise_conv_2d_test.cc +++ b/tensorflow/lite/delegates/xnnpack/signed_quantized_depthwise_conv_2d_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/delegates/xnnpack/quantized_depthwise_conv_2d_tester.h" #include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" diff --git a/tensorflow/lite/delegates/xnnpack/signed_quantized_elu_test.cc b/tensorflow/lite/delegates/xnnpack/signed_quantized_elu_test.cc index b66434f7454d82..676fc06bdf4fe5 100644 --- a/tensorflow/lite/delegates/xnnpack/signed_quantized_elu_test.cc +++ b/tensorflow/lite/delegates/xnnpack/signed_quantized_elu_test.cc @@ -19,8 +19,10 @@ limitations under the License. #include #include +#include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/delegates/xnnpack/quantized_unary_elementwise_tester.h" #include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" +#include "tensorflow/lite/schema/schema_generated.h" namespace tflite { namespace xnnpack { diff --git a/tensorflow/lite/delegates/xnnpack/signed_quantized_fully_connected_test.cc b/tensorflow/lite/delegates/xnnpack/signed_quantized_fully_connected_test.cc index bf341b8bd8e00f..1be48daba79655 100644 --- a/tensorflow/lite/delegates/xnnpack/signed_quantized_fully_connected_test.cc +++ b/tensorflow/lite/delegates/xnnpack/signed_quantized_fully_connected_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/delegates/xnnpack/quantized_fully_connected_tester.h" #include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" diff --git a/tensorflow/lite/delegates/xnnpack/signed_quantized_leaky_relu_test.cc b/tensorflow/lite/delegates/xnnpack/signed_quantized_leaky_relu_test.cc index c03c767eb7a64f..4aa74580b6b827 100644 --- a/tensorflow/lite/delegates/xnnpack/signed_quantized_leaky_relu_test.cc +++ b/tensorflow/lite/delegates/xnnpack/signed_quantized_leaky_relu_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/delegates/xnnpack/quantized_leaky_relu_tester.h" #include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" diff --git a/tensorflow/lite/delegates/xnnpack/signed_quantized_logistic_test.cc b/tensorflow/lite/delegates/xnnpack/signed_quantized_logistic_test.cc index 02f3383954ce24..9067ffebf02dfd 100644 --- a/tensorflow/lite/delegates/xnnpack/signed_quantized_logistic_test.cc +++ b/tensorflow/lite/delegates/xnnpack/signed_quantized_logistic_test.cc @@ -20,8 +20,10 @@ limitations under the License. #include #include +#include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/delegates/xnnpack/quantized_unary_elementwise_tester.h" #include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" +#include "tensorflow/lite/schema/schema_generated.h" namespace tflite { namespace xnnpack { diff --git a/tensorflow/lite/delegates/xnnpack/signed_quantized_max_pool_2d_test.cc b/tensorflow/lite/delegates/xnnpack/signed_quantized_max_pool_2d_test.cc index 509f2cd1e72849..4a12e817039b32 100644 --- a/tensorflow/lite/delegates/xnnpack/signed_quantized_max_pool_2d_test.cc +++ b/tensorflow/lite/delegates/xnnpack/signed_quantized_max_pool_2d_test.cc @@ -19,8 +19,10 @@ limitations under the License. #include #include +#include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/delegates/xnnpack/quantized_pool_2d_tester.h" #include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" +#include "tensorflow/lite/schema/schema_generated.h" namespace tflite { namespace xnnpack { diff --git a/tensorflow/lite/delegates/xnnpack/signed_quantized_mean_test.cc b/tensorflow/lite/delegates/xnnpack/signed_quantized_mean_test.cc index 87fcf35bb87972..1cfd2597565be3 100644 --- a/tensorflow/lite/delegates/xnnpack/signed_quantized_mean_test.cc +++ b/tensorflow/lite/delegates/xnnpack/signed_quantized_mean_test.cc @@ -19,8 +19,10 @@ limitations under the License. #include #include +#include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/delegates/xnnpack/quantized_reduce_tester.h" #include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" +#include "tensorflow/lite/schema/schema_generated.h" namespace tflite { namespace xnnpack { diff --git a/tensorflow/lite/delegates/xnnpack/signed_quantized_mul_test.cc b/tensorflow/lite/delegates/xnnpack/signed_quantized_mul_test.cc index fbab966d6229f5..b28ed665ed3542 100644 --- a/tensorflow/lite/delegates/xnnpack/signed_quantized_mul_test.cc +++ b/tensorflow/lite/delegates/xnnpack/signed_quantized_mul_test.cc @@ -20,8 +20,10 @@ limitations under the License. #include #include +#include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/delegates/xnnpack/quantized_binary_elementwise_tester.h" #include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" +#include "tensorflow/lite/schema/schema_generated.h" namespace tflite { namespace xnnpack { diff --git a/tensorflow/lite/delegates/xnnpack/signed_quantized_pad_test.cc b/tensorflow/lite/delegates/xnnpack/signed_quantized_pad_test.cc index 13c4ff2d2ade90..7ce3ad1a2b4653 100644 --- a/tensorflow/lite/delegates/xnnpack/signed_quantized_pad_test.cc +++ b/tensorflow/lite/delegates/xnnpack/signed_quantized_pad_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/delegates/xnnpack/quantized_pad_tester.h" #include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" diff --git a/tensorflow/lite/delegates/xnnpack/signed_quantized_reshape_test.cc b/tensorflow/lite/delegates/xnnpack/signed_quantized_reshape_test.cc index 1f57d47cdba326..71f0843406535a 100644 --- a/tensorflow/lite/delegates/xnnpack/signed_quantized_reshape_test.cc +++ b/tensorflow/lite/delegates/xnnpack/signed_quantized_reshape_test.cc @@ -21,8 +21,10 @@ limitations under the License. #include #include +#include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/delegates/xnnpack/reshape_tester.h" #include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" +#include "tensorflow/lite/schema/schema_generated.h" namespace tflite { namespace xnnpack { diff --git a/tensorflow/lite/delegates/xnnpack/signed_quantized_resize_bilinear_test.cc b/tensorflow/lite/delegates/xnnpack/signed_quantized_resize_bilinear_test.cc index 8c77ba185f552a..c3cf1cef9dc3af 100644 --- a/tensorflow/lite/delegates/xnnpack/signed_quantized_resize_bilinear_test.cc +++ b/tensorflow/lite/delegates/xnnpack/signed_quantized_resize_bilinear_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/delegates/xnnpack/quantized_resize_bilinear_tester.h" #include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" diff --git a/tensorflow/lite/delegates/xnnpack/signed_quantized_slice_test.cc b/tensorflow/lite/delegates/xnnpack/signed_quantized_slice_test.cc index 2b487ee4151cc0..48ca30e3adfbc5 100644 --- a/tensorflow/lite/delegates/xnnpack/signed_quantized_slice_test.cc +++ b/tensorflow/lite/delegates/xnnpack/signed_quantized_slice_test.cc @@ -22,8 +22,10 @@ limitations under the License. #include #include +#include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/delegates/xnnpack/slice_tester.h" #include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" +#include "tensorflow/lite/schema/schema_generated.h" namespace tflite { namespace xnnpack { diff --git a/tensorflow/lite/delegates/xnnpack/signed_quantized_space_to_depth_test.cc b/tensorflow/lite/delegates/xnnpack/signed_quantized_space_to_depth_test.cc index 3e23a7701f1a73..99d4ce31ea9a74 100644 --- a/tensorflow/lite/delegates/xnnpack/signed_quantized_space_to_depth_test.cc +++ b/tensorflow/lite/delegates/xnnpack/signed_quantized_space_to_depth_test.cc @@ -20,8 +20,10 @@ limitations under the License. #include #include +#include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/delegates/xnnpack/space_to_depth_tester.h" #include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" +#include "tensorflow/lite/schema/schema_generated.h" namespace tflite::xnnpack { namespace { diff --git a/tensorflow/lite/delegates/xnnpack/signed_quantized_split_test.cc b/tensorflow/lite/delegates/xnnpack/signed_quantized_split_test.cc index 3a5338da1726b1..2cf61c50ef6662 100644 --- a/tensorflow/lite/delegates/xnnpack/signed_quantized_split_test.cc +++ b/tensorflow/lite/delegates/xnnpack/signed_quantized_split_test.cc @@ -21,8 +21,10 @@ limitations under the License. #include #include +#include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/delegates/xnnpack/split_tester.h" #include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" +#include "tensorflow/lite/schema/schema_generated.h" namespace tflite { namespace xnnpack { diff --git a/tensorflow/lite/delegates/xnnpack/signed_quantized_strided_slice_test.cc b/tensorflow/lite/delegates/xnnpack/signed_quantized_strided_slice_test.cc index f934d56ad7e123..3540057a7d676a 100644 --- a/tensorflow/lite/delegates/xnnpack/signed_quantized_strided_slice_test.cc +++ b/tensorflow/lite/delegates/xnnpack/signed_quantized_strided_slice_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include "tensorflow/lite/delegates/xnnpack/strided_slice_tester.h" +#include "tensorflow/lite/schema/schema_generated.h" namespace tflite { namespace xnnpack { diff --git a/tensorflow/lite/delegates/xnnpack/signed_quantized_sub_test.cc b/tensorflow/lite/delegates/xnnpack/signed_quantized_sub_test.cc index 8a0e5f5204c851..bd5e92dcf16582 100644 --- a/tensorflow/lite/delegates/xnnpack/signed_quantized_sub_test.cc +++ b/tensorflow/lite/delegates/xnnpack/signed_quantized_sub_test.cc @@ -20,8 +20,10 @@ limitations under the License. #include #include +#include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/delegates/xnnpack/quantized_binary_elementwise_tester.h" #include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" +#include "tensorflow/lite/schema/schema_generated.h" namespace tflite { namespace xnnpack { diff --git a/tensorflow/lite/delegates/xnnpack/signed_quantized_tanh_test.cc b/tensorflow/lite/delegates/xnnpack/signed_quantized_tanh_test.cc index 939d296fca6a72..708ac12112beca 100644 --- a/tensorflow/lite/delegates/xnnpack/signed_quantized_tanh_test.cc +++ b/tensorflow/lite/delegates/xnnpack/signed_quantized_tanh_test.cc @@ -20,8 +20,10 @@ limitations under the License. #include #include +#include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/delegates/xnnpack/quantized_unary_elementwise_tester.h" #include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" +#include "tensorflow/lite/schema/schema_generated.h" namespace tflite { namespace xnnpack { diff --git a/tensorflow/lite/delegates/xnnpack/signed_quantized_transpose_conv_test.cc b/tensorflow/lite/delegates/xnnpack/signed_quantized_transpose_conv_test.cc index 095256fff2d656..7daae13ebdea16 100644 --- a/tensorflow/lite/delegates/xnnpack/signed_quantized_transpose_conv_test.cc +++ b/tensorflow/lite/delegates/xnnpack/signed_quantized_transpose_conv_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/delegates/xnnpack/quantized_transpose_conv_tester.h" #include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" diff --git a/tensorflow/lite/delegates/xnnpack/signed_quantized_transpose_test.cc b/tensorflow/lite/delegates/xnnpack/signed_quantized_transpose_test.cc index 71ddef900fdcd7..d32af38da21c61 100644 --- a/tensorflow/lite/delegates/xnnpack/signed_quantized_transpose_test.cc +++ b/tensorflow/lite/delegates/xnnpack/signed_quantized_transpose_test.cc @@ -20,8 +20,10 @@ limitations under the License. #include #include +#include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/delegates/xnnpack/transpose_tester.h" #include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" +#include "tensorflow/lite/schema/schema_generated.h" namespace tflite { namespace xnnpack { diff --git a/tensorflow/lite/delegates/xnnpack/signed_quantized_variable_ops_multiple_subgraph_test.cc b/tensorflow/lite/delegates/xnnpack/signed_quantized_variable_ops_multiple_subgraph_test.cc index 5e0df9cb63982b..22a59e76720a5d 100644 --- a/tensorflow/lite/delegates/xnnpack/signed_quantized_variable_ops_multiple_subgraph_test.cc +++ b/tensorflow/lite/delegates/xnnpack/signed_quantized_variable_ops_multiple_subgraph_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/delegates/xnnpack/quantized_variable_ops_tester.h" namespace tflite { diff --git a/tensorflow/lite/delegates/xnnpack/signed_quantized_variable_ops_test.cc b/tensorflow/lite/delegates/xnnpack/signed_quantized_variable_ops_test.cc index 2d54b207969d64..5c083a37570ce9 100644 --- a/tensorflow/lite/delegates/xnnpack/signed_quantized_variable_ops_test.cc +++ b/tensorflow/lite/delegates/xnnpack/signed_quantized_variable_ops_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/delegates/xnnpack/quantized_variable_ops_tester.h" namespace tflite { diff --git a/tensorflow/lite/delegates/xnnpack/slice_test.cc b/tensorflow/lite/delegates/xnnpack/slice_test.cc index 3a80181a6e74b7..3a1790b1143d85 100644 --- a/tensorflow/lite/delegates/xnnpack/slice_test.cc +++ b/tensorflow/lite/delegates/xnnpack/slice_test.cc @@ -22,8 +22,10 @@ limitations under the License. #include #include +#include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/delegates/xnnpack/slice_tester.h" #include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" +#include "tensorflow/lite/schema/schema_generated.h" namespace tflite { namespace xnnpack { diff --git a/tensorflow/lite/delegates/xnnpack/slice_tester.cc b/tensorflow/lite/delegates/xnnpack/slice_tester.cc index 5c1aa6c5921242..da97c89e983645 100644 --- a/tensorflow/lite/delegates/xnnpack/slice_tester.cc +++ b/tensorflow/lite/delegates/xnnpack/slice_tester.cc @@ -25,9 +25,13 @@ limitations under the License. #include #include +#include "flatbuffers/buffer.h" // from @flatbuffers +#include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers +#include "flatbuffers/string.h" // from @flatbuffers +#include "tensorflow/compiler/mlir/lite/schema/schema_conversion_utils.h" +#include "tensorflow/lite/core/interpreter_builder.h" #include "tensorflow/lite/core/kernels/register.h" #include "tensorflow/lite/interpreter.h" -#include "tensorflow/lite/schema/schema_conversion_utils.h" #include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/version.h" diff --git a/tensorflow/lite/delegates/xnnpack/softmax_test.cc b/tensorflow/lite/delegates/xnnpack/softmax_test.cc index ae33a1afad37af..f55d3c23f66019 100644 --- a/tensorflow/lite/delegates/xnnpack/softmax_test.cc +++ b/tensorflow/lite/delegates/xnnpack/softmax_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/delegates/xnnpack/softmax_tester.h" #include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" From 4c6ae1a88c8818740ae877787929728c79253d85 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 25 Sep 2024 01:20:34 -0700 Subject: [PATCH 232/483] [XLA:GPU] Increase the size limit for dot merger to infinity (behind a flag). PiperOrigin-RevId: 678585161 --- third_party/xla/xla/debug_options_flags.cc | 6 ++++++ third_party/xla/xla/service/gpu/gpu_compiler.cc | 9 ++++++--- third_party/xla/xla/xla.proto | 5 ++++- 3 files changed, 16 insertions(+), 4 deletions(-) diff --git a/third_party/xla/xla/debug_options_flags.cc b/third_party/xla/xla/debug_options_flags.cc index 281f936fd772bb..71602cb4e8d6bf 100644 --- a/third_party/xla/xla/debug_options_flags.cc +++ b/third_party/xla/xla/debug_options_flags.cc @@ -293,6 +293,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_gpu_executable_terminate_timeout_seconds(30); opts.set_xla_gpu_experimental_disable_binary_libraries(false); opts.set_xla_experimental_ignore_channel_id(false); + opts.set_xla_gpu_dot_merger_threshold_mb(32); return opts; } @@ -1951,6 +1952,11 @@ void MakeDebugOptionsFlags(std::vector* flag_list, bool_setter_for(&DebugOptions::set_xla_experimental_ignore_channel_id), debug_options->xla_experimental_ignore_channel_id(), "Experimental: ignore channel ids for collective operations.")); + flag_list->push_back(tsl::Flag( + "xla_gpu_dot_merger_threshold_mb", + int32_setter_for(&DebugOptions::set_xla_gpu_dot_merger_threshold_mb), + debug_options->xla_gpu_dot_merger_threshold_mb(), + "Dot merger pass threshold to be set in MB.")); } // NOLINT(readability/fn_size) // Allocates flag_values and flag_objects; this function must not be called more diff --git a/third_party/xla/xla/service/gpu/gpu_compiler.cc b/third_party/xla/xla/service/gpu/gpu_compiler.cc index 3bc29969e3997c..68a26a18bb5f6e 100644 --- a/third_party/xla/xla/service/gpu/gpu_compiler.cc +++ b/third_party/xla/xla/service/gpu/gpu_compiler.cc @@ -790,9 +790,12 @@ absl::Status RunOptimizationPasses( // AlgebraicSimplifier may add contracting dimensions to a dot. pipeline.AddPass(); pipeline.AddPass(); - // Only merge "smallish" dots. This threshold was not set carefully, but - // so far we know that 1mb is too small. - pipeline.AddPass(/*max_size_to_merge=*/int64_t{32} << 20); + // Only merge "smallish" dots. This threshold defaults to 32MB today, with + // a flag to override. + pipeline.AddPass( + /*max_size_to_merge=*/int64_t{ + debug_options.xla_gpu_dot_merger_threshold_mb()} + << 20); pipeline.AddPass(); pipeline.AddPass(); pipeline.AddPass(); diff --git a/third_party/xla/xla/xla.proto b/third_party/xla/xla/xla.proto index 74eaf1166e459e..018320d9c8df69 100644 --- a/third_party/xla/xla/xla.proto +++ b/third_party/xla/xla/xla.proto @@ -978,7 +978,10 @@ message DebugOptions { // for collectives in the given HLO. bool xla_experimental_ignore_channel_id = 330; - // Next id: 331 + // DotMerger pass threshold size to be used in MB. + int32 xla_gpu_dot_merger_threshold_mb = 331; + + // Next id: 332 // Extra options to pass to the compilation backend (e.g. LLVM); specific // interpretation of these values is left to the backend. From 4ec2d54f63c5fd1d621044e67828fd8d82812487 Mon Sep 17 00:00:00 2001 From: Adrian Kuegel Date: Wed, 25 Sep 2024 01:38:14 -0700 Subject: [PATCH 233/483] Allow vectorization in DynamicUpdateSlice in-place emitter. We can use the same conditions for when to allow vectorization as for loop fusion. PiperOrigin-RevId: 678591064 --- third_party/xla/xla/service/gpu/fusions/BUILD | 1 + .../fusions/in_place_dynamic_update_slice_mlir.cc | 5 +++-- .../fusions/in_place_dynamic_update_slice_mlir.h | 7 +++++-- .../vectorize_x1_too_small.hlo | 14 ++++++++++++++ .../tests/dynamic_update_slice/vectorize_x4.hlo | 14 ++++++++++++++ 5 files changed, 37 insertions(+), 4 deletions(-) create mode 100644 third_party/xla/xla/service/gpu/fusions/tests/dynamic_update_slice/vectorize_x1_too_small.hlo create mode 100644 third_party/xla/xla/service/gpu/fusions/tests/dynamic_update_slice/vectorize_x4.hlo diff --git a/third_party/xla/xla/service/gpu/fusions/BUILD b/third_party/xla/xla/service/gpu/fusions/BUILD index 290c451dfffb8b..9b813219939ef2 100644 --- a/third_party/xla/xla/service/gpu/fusions/BUILD +++ b/third_party/xla/xla/service/gpu/fusions/BUILD @@ -16,6 +16,7 @@ cc_library( "//xla:shape_util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", + "//xla/service/gpu:gpu_fusible", "//xla/service/gpu:hlo_fusion_analysis", "//xla/service/gpu:ir_emission_utils", "//xla/service/gpu:launch_dimensions", diff --git a/third_party/xla/xla/service/gpu/fusions/in_place_dynamic_update_slice_mlir.cc b/third_party/xla/xla/service/gpu/fusions/in_place_dynamic_update_slice_mlir.cc index 823b3c8f765ad6..db4605be601a29 100644 --- a/third_party/xla/xla/service/gpu/fusions/in_place_dynamic_update_slice_mlir.cc +++ b/third_party/xla/xla/service/gpu/fusions/in_place_dynamic_update_slice_mlir.cc @@ -69,7 +69,8 @@ LaunchDimensions MlirInPlaceDynamicUpdateSliceFusion::launch_dimensions() const { const auto& update_shape = dus_ops_.front().GetOperand(kDUSUpdateIndex).shape(); - return CalculateLaunchDimensions(update_shape, analysis_.device_info()); + return CalculateLaunchDimensions(update_shape, analysis_.device_info(), + config_); } std::optional @@ -84,7 +85,7 @@ MlirInPlaceDynamicUpdateSliceFusion::ComputeThreadIdToInputIndexing( // It is guaranteed that all DUS ops have the same output shape at this point. const auto& update_shape = dus_ops_.front().GetOperand(kDUSUpdateIndex).shape(); - return GetDefaultThreadIdIndexingMap(launch_dims, /*unroll_factor=*/1, + return GetDefaultThreadIdIndexingMap(launch_dims, config_.unroll_factor, update_shape, indexing_context); } diff --git a/third_party/xla/xla/service/gpu/fusions/in_place_dynamic_update_slice_mlir.h b/third_party/xla/xla/service/gpu/fusions/in_place_dynamic_update_slice_mlir.h index 2ed84a06522b16..7d94f74ced0e46 100644 --- a/third_party/xla/xla/service/gpu/fusions/in_place_dynamic_update_slice_mlir.h +++ b/third_party/xla/xla/service/gpu/fusions/in_place_dynamic_update_slice_mlir.h @@ -26,6 +26,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_instructions.h" #include "xla/service/gpu/fusions/mlir/computation_partitioner.h" #include "xla/service/gpu/fusions/mlir/mlir_fusion_emitter.h" +#include "xla/service/gpu/gpu_fusible.h" #include "xla/service/gpu/hlo_fusion_analysis.h" #include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/gpu/launch_dimensions.h" @@ -47,8 +48,9 @@ class MlirInPlaceDynamicUpdateSliceFusion : public MlirFusionEmitterBase { explicit MlirInPlaceDynamicUpdateSliceFusion( const HloFusionAnalysis& analysis) : analysis_(analysis), - dus_ops_( - GetOutputDefiningDynamicUpdateSlices(analysis.fusion_roots())) {} + dus_ops_(GetOutputDefiningDynamicUpdateSlices(analysis.fusion_roots())), + config_(ComputeLoopFusionConfig( + analysis, dus_ops_[0].instruction().operand(1)->shape())) {} LaunchDimensions launch_dimensions() const override; @@ -77,6 +79,7 @@ class MlirInPlaceDynamicUpdateSliceFusion : public MlirFusionEmitterBase { private: const HloFusionAnalysis& analysis_; std::vector dus_ops_; + LaunchDimensionsConfig config_; }; } // namespace gpu diff --git a/third_party/xla/xla/service/gpu/fusions/tests/dynamic_update_slice/vectorize_x1_too_small.hlo b/third_party/xla/xla/service/gpu/fusions/tests/dynamic_update_slice/vectorize_x1_too_small.hlo new file mode 100644 index 00000000000000..b9dbddaa26bb4d --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/tests/dynamic_update_slice/vectorize_x1_too_small.hlo @@ -0,0 +1,14 @@ +// RUN: fusion_to_mlir %s | mlir_fusions_opt -xla-gpu-test-optimize \ +// RUN: -xla-gpu-test-transform-loops | FileCheck %s +// RUN: test_correctness %s --bijection_inputs=dus:1 + +dus { + %input = f32[40,40,300] parameter(0) + %update = f32[1,1,40] parameter(1) + %idx = s32[] parameter(2) + %zero = s32[] constant(0) + ROOT dus = f32[40,40,300] dynamic-update-slice(%input, %update, %idx, %zero, %zero) +} + +// CHECK-NOT: vector.transfer_read {{.*}} vector<4xf32> +// CHECK-NOT: vector.transfer_write {{.*}} vector<4xf32> diff --git a/third_party/xla/xla/service/gpu/fusions/tests/dynamic_update_slice/vectorize_x4.hlo b/third_party/xla/xla/service/gpu/fusions/tests/dynamic_update_slice/vectorize_x4.hlo new file mode 100644 index 00000000000000..775dd248d0bb7b --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/tests/dynamic_update_slice/vectorize_x4.hlo @@ -0,0 +1,14 @@ +// RUN: fusion_to_mlir %s | mlir_fusions_opt -xla-gpu-test-optimize \ +// RUN: -xla-gpu-test-transform-loops | FileCheck %s +// RUN: test_correctness %s --bijection_inputs=dus:1 + +dus { + %input = f32[40,40,300] parameter(0) + %update = f32[20,40,300] parameter(1) + %idx = s32[] parameter(2) + %zero = s32[] constant(0) + ROOT dus = f32[40,40,300] dynamic-update-slice(%input, %update, %idx, %zero, %zero) +} + +// CHECK: vector.transfer_read {{.*}} vector<4xf32> +// CHECK: vector.transfer_write {{.*}} vector<4xf32> From 869d6ccb2ea44fb4ac7a59dc519179653ed78e5c Mon Sep 17 00:00:00 2001 From: Tori Baker Date: Wed, 25 Sep 2024 01:40:32 -0700 Subject: [PATCH 234/483] Lower transpose to materialize & insert op PiperOrigin-RevId: 678591749 --- .../gpu/fusions/tests/transpose/epilogue.hlo | 7 +- .../tests/transpose/fused_transpose_021.hlo | 9 +- .../tests/transpose/fused_transpose_102.hlo | 7 +- .../tests/transpose/fused_transpose_210.hlo | 7 +- .../xla/service/gpu/fusions/transpose_mlir.cc | 97 ++++++++++++------- .../service/gpu/model/indexing_analysis.cc | 15 +-- .../xla/service/gpu/model/indexing_analysis.h | 2 + 7 files changed, 88 insertions(+), 56 deletions(-) diff --git a/third_party/xla/xla/service/gpu/fusions/tests/transpose/epilogue.hlo b/third_party/xla/xla/service/gpu/fusions/tests/transpose/epilogue.hlo index 1ca70362596696..25695c8212f7d0 100644 --- a/third_party/xla/xla/service/gpu/fusions/tests/transpose/epilogue.hlo +++ b/third_party/xla/xla/service/gpu/fusions/tests/transpose/epilogue.hlo @@ -11,10 +11,9 @@ fusion { // CHECK-SAME: }, %[[OUT:.*]]: tensor<20x170x160xf32> // CHECK: %[[SHMEM:.*]] = xla_gpu.allocate_shared : tensor<1x32x33xf32> -// CHECK: %[[SHMEM_WITH_VALS:.*]] = xla_gpu.loop -// CHECK-SAME: iter_args(%[[SHMEM_:.*]] = %[[SHMEM]]) -// CHECK: %[[EXP:.*]] = xla_gpu.pure_call @fusion_p0 -// CHECK: tensor.insert %[[EXP]] into %[[SHMEM_]] +// CHECK: %[[MATERIALIZED:.*]] = xla_gpu.materialize @fusion_p0 +// CHECK: %[[SHMEM_WITH_VALS:.*]] = xla_gpu.insert %[[MATERIALIZED]] +// CHECK-SAME: into %[[SHMEM]] at #indexing_map // CHECK: %[[SYNC:.*]] = xla_gpu.sync_threads %[[SHMEM_WITH_VALS]] diff --git a/third_party/xla/xla/service/gpu/fusions/tests/transpose/fused_transpose_021.hlo b/third_party/xla/xla/service/gpu/fusions/tests/transpose/fused_transpose_021.hlo index 60e5cd404e1504..7c2d63c78a47ef 100644 --- a/third_party/xla/xla/service/gpu/fusions/tests/transpose/fused_transpose_021.hlo +++ b/third_party/xla/xla/service/gpu/fusions/tests/transpose/fused_transpose_021.hlo @@ -9,13 +9,14 @@ fusion { ROOT %abs = f32[20,170,160] abs(%transpose) } // CHECK-LABEL: func.func @main( +// CHECK-SAME: %[[INPUT:.*]]: tensor<20x160x170xf32> { // CHECK-SAME: }, %[[OUT:.*]]: tensor<20x170x160xf32> // CHECK: %[[SHMEM:.*]] = xla_gpu.allocate_shared : tensor<1x32x33xf32> -// CHECK: %[[SHMEM_WITH_VALS:.*]] = xla_gpu.loop -// CHECK-SAME: iter_args(%[[SHMEM_:.*]] = %[[SHMEM]]) -// CHECK: %[[EXP:.*]] = xla_gpu.pure_call @fusion_exp -// CHECK: tensor.insert %[[EXP]] into %[[SHMEM_]] +// CHECK: %[[MATERIALIZED:.*]] = xla_gpu.materialize +// CHECK-SAME: @fusion_exp(%[[INPUT]]) at #indexing_map +// CHECK: %[[SHMEM_WITH_VALS:.*]] = xla_gpu.insert %[[MATERIALIZED]] +// CHECK-SAME: into %[[SHMEM]] at #indexing_map // CHECK: %[[SYNC:.*]] = xla_gpu.sync_threads %[[SHMEM_WITH_VALS]] diff --git a/third_party/xla/xla/service/gpu/fusions/tests/transpose/fused_transpose_102.hlo b/third_party/xla/xla/service/gpu/fusions/tests/transpose/fused_transpose_102.hlo index 55c2976d32b341..2fc3855efe4c11 100644 --- a/third_party/xla/xla/service/gpu/fusions/tests/transpose/fused_transpose_102.hlo +++ b/third_party/xla/xla/service/gpu/fusions/tests/transpose/fused_transpose_102.hlo @@ -10,10 +10,9 @@ fusion { // CHECK-SAME: }, %[[OUT:.*]]: tensor<170x160x3xi8> // CHECK: %[[SHMEM:.*]] = xla_gpu.allocate_shared : tensor<32x33x3xi8> -// CHECK: %[[SHMEM_WITH_VALS:.*]] = xla_gpu.loop -// CHECK-SAME: iter_args(%[[SHMEM_:.*]] = %[[SHMEM]]) -// CHECK: %[[P0:.*]] = xla_gpu.pure_call @fusion_p0 -// CHECK: tensor.insert %[[P0]] into %[[SHMEM_]] +// CHECK: %[[MATERIALIZED:.*]] = xla_gpu.materialize @fusion_p0 +// CHECK: %[[SHMEM_WITH_VALS:.*]] = xla_gpu.insert %[[MATERIALIZED]] +// CHECK-SAME: into %[[SHMEM]] at #indexing_map // CHECK: %[[SYNC:.*]] = xla_gpu.sync_threads %[[SHMEM_WITH_VALS]] diff --git a/third_party/xla/xla/service/gpu/fusions/tests/transpose/fused_transpose_210.hlo b/third_party/xla/xla/service/gpu/fusions/tests/transpose/fused_transpose_210.hlo index 0dd4a27547514f..97e23f171b713a 100644 --- a/third_party/xla/xla/service/gpu/fusions/tests/transpose/fused_transpose_210.hlo +++ b/third_party/xla/xla/service/gpu/fusions/tests/transpose/fused_transpose_210.hlo @@ -12,10 +12,9 @@ fusion { // CHECK-SAME: }, %[[OUT:.*]]: tensor<170x160x20xf32> // // CHECK: %[[SHMEM:.*]] = xla_gpu.allocate_shared : tensor<32x1x33xf32> -// CHECK: %[[SHMEM_WITH_VALS:.*]] = xla_gpu.loop -// CHECK-SAME: iter_args(%[[SHMEM_:.*]] = %[[SHMEM]]) -// CHECK: %[[EXP:.*]] = xla_gpu.pure_call @fusion_exp -// CHECK: tensor.insert %[[EXP]] into %[[SHMEM_]] +// CHECK: %[[MATERIALIZED:.*]] = xla_gpu.materialize @fusion_exp +// CHECK: %[[SHMEM_WITH_VALS:.*]] = xla_gpu.insert %[[MATERIALIZED]] +// CHECK-SAME: into %[[SHMEM]] at #indexing_map // CHECK: %[[SYNC:.*]] = xla_gpu.sync_threads %[[SHMEM_WITH_VALS]] diff --git a/third_party/xla/xla/service/gpu/fusions/transpose_mlir.cc b/third_party/xla/xla/service/gpu/fusions/transpose_mlir.cc index fd18cef310a8fb..57e7d4fa104f5d 100644 --- a/third_party/xla/xla/service/gpu/fusions/transpose_mlir.cc +++ b/third_party/xla/xla/service/gpu/fusions/transpose_mlir.cc @@ -32,6 +32,7 @@ limitations under the License. #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/MLIRContext.h" @@ -211,6 +212,8 @@ IndexingMap MlirTransposeFusion::GetSharedMemoryIndexing( std::vector dim_var_sizes(6, 1); dim_var_sizes[KernelFusionInterface::kIndexingMapThreadIdxDims[0]] = kNumThreadsPerBlock; + dim_var_sizes[KernelFusionInterface::kIndexingMapBlockIdxDims[0]] = + Product(block_counts_); return {mlir::AffineMap::get(6, 2, thread_offsets, ctx), DimVarsFromTensorSizes(dim_var_sizes), RangeVarsFromTensorSizes({block_size_ / kNumRows, vector_size_}), @@ -233,44 +236,72 @@ MlirTransposeFusion::WriteResult MlirTransposeFusion::EmitWriteToShMemMlir( } else { ++shmem_tensor_size.back(); } + int num_inputs = fusion.fused_instructions_computation()->num_parameters(); + SmallVector callee_operands( + entry_function.getArguments().take_front(num_inputs)); + auto tids_and_bids = EmitThreadAndBlockIds(builder); + auto identity_map = + IndexingMapAttr::get(ctx, CreateIdentityMap(shmem_tensor_size, ctx)); + + // We can assume that all transpose operands have the same shape. + Shape operand_shape = shmem_transposes_.front()->operand(0)->shape(); - // Allocate shared memory. - SmallVector inits; + // Indexing for MaterializeOp to read from input. + auto indexing = GetIndexing(/*input=*/true, operand_shape, ctx); + + // Indexing for InsertOp to write into shared memory. + IndexingMap write_indexing = GetSharedMemoryIndexing(/*read=*/false, ctx); + // As we are writing the same elements that we are reading, any read + // constraints can also be constraints for the write. + for (auto constraint : indexing.GetConstraints()) { + write_indexing.AddConstraint(constraint.first, constraint.second); + } + for (auto [index, bound] : llvm::enumerate(indexing.GetSymbolBounds())) { + write_indexing.GetMutableSymbolBound(index) = bound; + } + write_indexing.Simplify(); + auto dimensions = SmallVector(operand_shape.dimensions().begin(), + operand_shape.dimensions().end()); + SmallVector shmem_tensors; for (auto* transpose : shmem_transposes_) { auto elem_type = mlir_converter::PrimitiveTypeToMlirType( transpose->shape().element_type(), builder); - inits.push_back(builder.create( - RankedTensorType::get(shmem_tensor_size, elem_type))); + auto shmem = builder.create( + RankedTensorType::get(shmem_tensor_size, elem_type)); + auto indexed_vector = + IndexedVectorType::get(ctx, shmem_tensor_size, elem_type, + IndexingMapAttr::get(ctx, write_indexing)); + auto callee = + mlir::SymbolRefAttr::get(call_target_provider(transpose->operand(0))); + + auto materialized = builder.create( + /* result_type=*/indexed_vector, + /*input=*/callee_operands, + /*indices(dimensions)=*/tids_and_bids, + /*callee=*/callee, + /*map=*/IndexingMapAttr::get(ctx, indexing)); + + auto insert = builder.create( + /*result_type=*/shmem.getType(), + /*source=*/materialized.getResult(), + /*indices(dimensions)=*/tids_and_bids, + /*dest=*/shmem, + /*map=*/identity_map); + shmem_tensors.push_back(insert.getResult()); } - // Add output arguments for side outputs. - int num_inputs = fusion.fused_instructions_computation()->num_parameters(); + // Produce all side outputs and then write them. + SmallVector side_output_inits; for (int index : side_output_root_indices_) { - inits.push_back(entry_function.getArgument(num_inputs + index)); + side_output_inits.push_back(entry_function.getArgument(num_inputs + index)); } - - IndexingMap write_indexing = GetSharedMemoryIndexing(/*read=*/false, ctx); auto body_builder = [&](ValueRange symbol_values, ValueRange map_results, ValueRange output_tensors) -> SmallVector { auto input_indices = [&](const HloInstruction* instr) { return ApplyIndexing(GetIndexing(/*input=*/true, instr->shape(), ctx), thread_and_block_ids, symbol_values, builder); }; - SmallVector result_tensors; - auto shmem_indices = ApplyIndexing(write_indexing, thread_and_block_ids, - symbol_values, builder); - for (auto [transpose, output] : - llvm::zip(shmem_transposes_, output_tensors)) { - // Emit loop that writes subgraphs of transpose operands to shmem. - auto result_scalar = mlir_converter::ProvideParameter( - root_computation, transpose, - /*operand_index=*/0, input_indices(transpose->operand(0)), - call_target_provider, entry_function, builder)[0]; - result_tensors.push_back(builder.create( - result_scalar, output, shmem_indices)); - } - // Produce all side outputs and then write them. SmallVector side_outputs; SmallVector> side_output_indices; auto* root_tuple = fusion.fused_expression_root(); @@ -283,22 +314,21 @@ MlirTransposeFusion::WriteResult MlirTransposeFusion::EmitWriteToShMemMlir( side_outputs.append(param_values.begin(), param_values.end()); } + SmallVector result_tensors; for (const auto& [value, indices, output] : - llvm::zip(side_outputs, side_output_indices, - output_tensors.take_back(side_output_roots_.size()))) { + llvm::zip(side_outputs, side_output_indices, output_tensors)) { result_tensors.push_back( builder.create(value, output, indices)); } return result_tensors; }; - - auto indexing = GetIndexing( - /*input=*/true, shmem_transposes_.front()->operand(0)->shape(), ctx); - auto written_vector = mlir_converter::EmitXlaLoopOp( - builder, thread_and_block_ids, inits, indexing, body_builder); - ValueRange written = written_vector; - auto shmem_tensors = written.take_front(shmem_transposes_.size()); + mlir::ValueRange side_output_vector; + if (!side_output_inits.empty()) { + side_output_vector = mlir_converter::EmitXlaLoopOp( + builder, thread_and_block_ids, side_output_inits, indexing, + body_builder); + } WriteResult result; result.shmem_tensors = @@ -307,8 +337,7 @@ MlirTransposeFusion::WriteResult MlirTransposeFusion::EmitWriteToShMemMlir( .getResults(); result.updated_outputs = output_args; for (auto [index, side_output_result] : - llvm::zip(side_output_root_indices_, - written.take_back(side_output_roots_.size()))) { + llvm::zip(side_output_root_indices_, side_output_vector)) { result.updated_outputs[index] = side_output_result; } return result; diff --git a/third_party/xla/xla/service/gpu/model/indexing_analysis.cc b/third_party/xla/xla/service/gpu/model/indexing_analysis.cc index 18ed0d526862a7..ad9215c921b09f 100644 --- a/third_party/xla/xla/service/gpu/model/indexing_analysis.cc +++ b/third_party/xla/xla/service/gpu/model/indexing_analysis.cc @@ -1151,18 +1151,21 @@ std::vector ToTransposeDimensions(const Layout& l) { } // namespace +IndexingMap CreateIdentityMap(absl::Span dimensions, + mlir::MLIRContext* mlir_context) { + return IndexingMap::FromTensorSizes( + AffineMap::getMultiDimIdentityMap(dimensions.size(), mlir_context), + /*dim_upper_bounds=*/dimensions, /*symbol_upper_bounds=*/{}, + /*is_simplified=*/dimensions.empty()); +} + IndexingMap CreateIdentityMap(const Shape& shape, MLIRContext* mlir_context) { if (shape.IsTuple()) { // Should happen only for variadic reduce. In that case all tuple shapes are // equal. return CreateIdentityMap(shape.tuple_shapes(0), mlir_context); } - - auto dimensions = shape.dimensions(); - IndexingMap identity_map = IndexingMap::FromTensorSizes( - AffineMap::getMultiDimIdentityMap(dimensions.size(), mlir_context), - dimensions, {}, /*is_simplified=*/dimensions.empty()); - return identity_map; + return CreateIdentityMap(shape.dimensions(), mlir_context); } llvm::SmallVector DelinearizeInBoundsIndex( diff --git a/third_party/xla/xla/service/gpu/model/indexing_analysis.h b/third_party/xla/xla/service/gpu/model/indexing_analysis.h index 965b060da30be7..d4c170aace2063 100644 --- a/third_party/xla/xla/service/gpu/model/indexing_analysis.h +++ b/third_party/xla/xla/service/gpu/model/indexing_analysis.h @@ -163,6 +163,8 @@ std::vector DelinearizeIndex(absl::Span dims, // Creates an identity indexing map corresponding to the parameter shape. IndexingMap CreateIdentityMap(const Shape& shape, mlir::MLIRContext* mlir_context); +IndexingMap CreateIdentityMap(absl::Span dimensions, + mlir::MLIRContext* mlir_context); llvm::SmallVector DelinearizeInBoundsIndex( mlir::AffineExpr linear, absl::Span sizes); From 3e7f3a5a5329b439841b8559ee0696a6f1b0c467 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 25 Sep 2024 02:03:53 -0700 Subject: [PATCH 235/483] compat: Update forward compatibility horizon to 2024-09-25 PiperOrigin-RevId: 678599408 --- tensorflow/python/compat/compat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index df65c9e37bbc68..39c9f37847b730 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -29,7 +29,7 @@ # This value changes every day with an automatic CL. It can be modified in code # via `forward_compatibility_horizon()` or with the environment variable # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date. -_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2024, 9, 24) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2024, 9, 25) _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS" _FORWARD_COMPATIBILITY_DATE_NUMBER = None From 8a9aebe99b4b68e55260beb0a5846031fc1da76f Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 25 Sep 2024 02:03:54 -0700 Subject: [PATCH 236/483] Update GraphDef version to 1996. PiperOrigin-RevId: 678599411 --- tensorflow/core/public/version.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index 71f6b4932cbd33..b7d03cb07c7474 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -108,7 +108,7 @@ limitations under the License. #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0 -#define TF_GRAPH_DEF_VERSION 1995 // Updated: 2024/9/24 +#define TF_GRAPH_DEF_VERSION 1996 // Updated: 2024/9/25 // Checkpoint compatibility versions (the versions field in SavedSliceMeta). // From 76b794b1137282782acb49203b2ff11d97bfef5d Mon Sep 17 00:00:00 2001 From: Andrey Portnoy Date: Wed, 25 Sep 2024 02:04:45 -0700 Subject: [PATCH 237/483] PR #17239: Remove erroneous bash invocation from LSP docs Imported from GitHub PR https://github.com/openxla/xla/pull/17239 Copybara import of the project: -- 4802f4dc2282b84229bfe24b0f15ca693cfc189b by Andrey Portnoy : Remove erroneous bash invocation from LSP docs Merging this change closes #17239 PiperOrigin-RevId: 678599791 --- third_party/xla/docs/lsp.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/xla/docs/lsp.md b/third_party/xla/docs/lsp.md index 12dd0f17247823..a29067866e662e 100644 --- a/third_party/xla/docs/lsp.md +++ b/third_party/xla/docs/lsp.md @@ -15,6 +15,6 @@ each file in a project. Use the [build_tools/lint/generate_compile_commands.py](https://github.com/openxla/xla/blob/main/build_tools/lint/generate_compile_commands.py) script. The following invocation from XLA repo root generates a -`compile_commands.json` file in place: `bash bazel aquery "mnemonic(CppCompile, +`compile_commands.json` file in place: `bazel aquery "mnemonic(CppCompile, //xla/...)" --output=jsonproto | \ python3 build_tools/lint/generate_compile_commands.py` From 48c6c4d849fe17bb22191339f86fcb72dee68153 Mon Sep 17 00:00:00 2001 From: Greg Olechwierowicz Date: Wed, 25 Sep 2024 02:06:11 -0700 Subject: [PATCH 238/483] [XLA:GPU] Automatically unroll a while loop by a factor of two if collectives are present in it's body. Flag flip will be done in the subsequent change. PiperOrigin-RevId: 678600322 --- third_party/xla/xla/debug_options_flags.cc | 21 ++++ .../xla/xla/service/gpu/gpu_compiler.cc | 6 + .../xla/xla/service/gpu/transforms/BUILD | 2 + .../double_buffer_loop_unrolling.cc | 19 ++++ .../transforms/double_buffer_loop_unrolling.h | 2 +- .../double_buffer_loop_unrolling_test.cc | 106 ++++++++++++++++-- third_party/xla/xla/xla.proto | 10 +- 7 files changed, 157 insertions(+), 9 deletions(-) diff --git a/third_party/xla/xla/debug_options_flags.cc b/third_party/xla/xla/debug_options_flags.cc index 71602cb4e8d6bf..8c62942f2742e9 100644 --- a/third_party/xla/xla/debug_options_flags.cc +++ b/third_party/xla/xla/debug_options_flags.cc @@ -621,6 +621,17 @@ void MakeDebugOptionsFlags(std::vector* flag_list, return absl::StrJoin(collective_ops, ", ", Formatter()); }; + // Custom parser for `xla_gpu_enable_while_loop_unrolling` flag. + auto setter_for_xla_gpu_enable_while_loop_unrolling = + [&debug_options](absl::string_view input) { + DebugOptions::WhileLoopUnrolling unroll_strategy; + bool parsed = DebugOptions::WhileLoopUnrolling_Parse( + absl::AsciiStrToUpper(input), &unroll_strategy); + if (!parsed) return false; + debug_options->set_xla_gpu_enable_while_loop_unrolling(unroll_strategy); + return true; + }; + // Custom parser for xla_gpu_disable_async_collectives. auto setter_for_xla_gpu_disable_async_collectives = [debug_options](const absl::string_view& input) { @@ -1161,6 +1172,16 @@ void MakeDebugOptionsFlags(std::vector* flag_list, " synchronous ones. By default, this is empty which indicates enabling" " async execution for all collectives. A sample usage is: " " --xla_gpu_disable_async_collectives=ALLREDUCE,REDUCESCATTER")); + flag_list->push_back(tsl::Flag( + "xla_gpu_enable_while_loop_unrolling", + setter_for_xla_gpu_enable_while_loop_unrolling, + DebugOptions::WhileLoopUnrolling_Name( + debug_options->xla_gpu_enable_while_loop_unrolling()), + "Enables while loop unrolling features. " + "`WHILE_LOOP_UNROLLING_DOUBLE_BUFFER` unrolls the loop by factor of 2, " + "`WHILE_LOOP_UNROLLING_FULL_UNROLL` will unroll the entire loop " + "`WHILE_LOOP_UNROLLING_AUTO_UNROLL` unrolls by a factor of 2, if there is" + " any collective present within a while loop.")); flag_list->push_back(tsl::Flag( "xla_gpu_all_reduce_combine_threshold_bytes", int64_setter_for( diff --git a/third_party/xla/xla/service/gpu/gpu_compiler.cc b/third_party/xla/xla/service/gpu/gpu_compiler.cc index 68a26a18bb5f6e..30577edad54fca 100644 --- a/third_party/xla/xla/service/gpu/gpu_compiler.cc +++ b/third_party/xla/xla/service/gpu/gpu_compiler.cc @@ -1095,6 +1095,12 @@ absl::Status RunPostFusionPasses( "`xla_gpu_enable_while_loop_double_buffering` flag."; unroll_strategy = DoubleBufferLoopUnrolling::UnrollStrategy::kFullUnroll; } + if (opts.xla_gpu_enable_while_loop_unrolling() == + DebugOptions::WHILE_LOOP_UNROLLING_AUTO_UNROLL && + opts.xla_gpu_enable_heuristic_pass_configuration() && + !opts.xla_gpu_enable_while_loop_double_buffering()) { + unroll_strategy = DoubleBufferLoopUnrolling::UnrollStrategy::kAuto; + } if (unroll_strategy != std::nullopt) { pipeline.AddPass(*unroll_strategy); pipeline.AddPass(); diff --git a/third_party/xla/xla/service/gpu/transforms/BUILD b/third_party/xla/xla/service/gpu/transforms/BUILD index 31c99ac6984708..e70d040a304457 100644 --- a/third_party/xla/xla/service/gpu/transforms/BUILD +++ b/third_party/xla/xla/service/gpu/transforms/BUILD @@ -1405,6 +1405,7 @@ xla_cc_test( "//xla:xla_data_proto_cc", "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass_pipeline", "//xla/hlo/utils:hlo_query", "//xla/service:tuple_simplifier", "//xla/tests:filecheck", @@ -1412,6 +1413,7 @@ xla_cc_test( "//xla/tests:xla_internal_test_main", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", + "@com_google_absl//absl/strings:string_view", "@local_tsl//tsl/platform:status_matchers", "@local_tsl//tsl/platform:statusor", ], diff --git a/third_party/xla/xla/service/gpu/transforms/double_buffer_loop_unrolling.cc b/third_party/xla/xla/service/gpu/transforms/double_buffer_loop_unrolling.cc index f1d9248ae12d94..0c25e2bf4691e4 100644 --- a/third_party/xla/xla/service/gpu/transforms/double_buffer_loop_unrolling.cc +++ b/third_party/xla/xla/service/gpu/transforms/double_buffer_loop_unrolling.cc @@ -523,6 +523,23 @@ absl::StatusOr DoubleBufferingUnroll(HloInstruction* while_instr, return true; // changed } +// Function performs double buffering unrolling strategy iff there is any +// collective operation within a body computation. +absl::StatusOr AutoUnroll(HloInstruction* while_instr, + HloModule* module) { + CHECK_EQ(while_instr->opcode(), HloOpcode::kWhile); + + bool any_collective_present = absl::c_any_of( + while_instr->while_body()->MakeInstructionPostOrder(), + [](HloInstruction* instr) { + return hlo_query::IsCollectiveCommunicationOp(instr->opcode()); + }); + if (any_collective_present) { + return DoubleBufferingUnroll(while_instr, module); + } + return false; // IR not changed. +} + } // namespace absl::StatusOr DoubleBufferLoopUnrolling::Run( @@ -555,6 +572,8 @@ absl::StatusOr DoubleBufferLoopUnrolling::Run( TF_ASSIGN_OR_RETURN(changed, FullyUnroll(while_instr, module)); } else if (unroll_strategy_ == UnrollStrategy::kDoubleBuffer) { TF_ASSIGN_OR_RETURN(changed, DoubleBufferingUnroll(while_instr, module)); + } else if (unroll_strategy_ == UnrollStrategy::kAuto) { + TF_ASSIGN_OR_RETURN(changed, AutoUnroll(while_instr, module)); } else { LOG(FATAL) << absl::StrCat("Unhandled unrolling strategy: ", unroll_strategy_); diff --git a/third_party/xla/xla/service/gpu/transforms/double_buffer_loop_unrolling.h b/third_party/xla/xla/service/gpu/transforms/double_buffer_loop_unrolling.h index c3d774fe60b176..aa4803457a1815 100644 --- a/third_party/xla/xla/service/gpu/transforms/double_buffer_loop_unrolling.h +++ b/third_party/xla/xla/service/gpu/transforms/double_buffer_loop_unrolling.h @@ -47,7 +47,7 @@ namespace gpu { // unrolled. class DoubleBufferLoopUnrolling : public HloModulePass { public: - enum class UnrollStrategy { kDoubleBuffer, kFullUnroll }; + enum class UnrollStrategy { kDoubleBuffer, kFullUnroll, kAuto }; explicit DoubleBufferLoopUnrolling( UnrollStrategy unroll_strategy = UnrollStrategy::kDoubleBuffer) diff --git a/third_party/xla/xla/service/gpu/transforms/double_buffer_loop_unrolling_test.cc b/third_party/xla/xla/service/gpu/transforms/double_buffer_loop_unrolling_test.cc index a12376e931defd..fedaf171ad8b6b 100644 --- a/third_party/xla/xla/service/gpu/transforms/double_buffer_loop_unrolling_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/double_buffer_loop_unrolling_test.cc @@ -21,10 +21,12 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/log/log.h" +#include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/pass/hlo_pass_pipeline.h" #include "xla/hlo/utils/hlo_query.h" #include "xla/service/tuple_simplifier.h" #include "xla/test.h" @@ -55,13 +57,103 @@ int64_t CountInstructions(HloModule& module, HloOpcode opcode) { return count; } -class GpuLoopDoubleBufferTransformerTest : public HloTestBase { - DebugOptions GetDebugOptionsForTest() override { - DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest(); - debug_options.set_xla_gpu_enable_while_loop_double_buffering(true); - return debug_options; - } -}; +using GpuLoopDoubleBufferTransformerTest = HloTestBase; + +TEST_F(GpuLoopDoubleBufferTransformerTest, + AutoUnrollLoopWhenCollectivesArePresent) { + absl::string_view kModuleString = R"( +HloModule m +condition { + input_tuple = (f32[], s32[]) parameter(0) + cond = s32[] get-tuple-element(input_tuple), index=1 + trip_count = s32[] constant(10) + ROOT done = pred[] compare(cond, trip_count), direction=LT +} + +ar_add { + Arg_1 = f32[] parameter(1) + Arg_0 = f32[] parameter(0) + ROOT add_ar = f32[] add(Arg_1, Arg_0) +} + +body { + input_tuple = (f32[], s32[]) parameter(0) + param_0 = f32[] get-tuple-element(input_tuple), index=0 + cond = s32[] get-tuple-element(input_tuple), index=1 + all-reduce-start = f32[] all-reduce-start(param_0), channel_id=8, replica_groups={{0}}, to_apply=ar_add, backend_config="{\"is_sync\":false}" + one = s32[] constant(1) + all-reduce-done = f32[] all-reduce-done(all-reduce-start) + cond_plus_1 = s32[] add(cond, one) + ROOT output_tuple = (f32[], s32[]) tuple(all-reduce-done, cond_plus_1) +} + +ENTRY main { + param_0 = f32[] parameter(0) + param_2 = s32[] constant(0) + tuple = (f32[], s32[]) tuple(param_0, param_2) + ROOT while = (f32[], s32[]) while(tuple), condition=condition, body=body, backend_config={"known_trip_count":{"n":"10"}} +})"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kModuleString)); + HloPassPipeline pipeline("double-buffering-pipeline"); + DoubleBufferLoopUnrolling unroller( + DoubleBufferLoopUnrolling::UnrollStrategy::kAuto); + TF_ASSERT_OK_AND_ASSIGN(bool changed, unroller.Run(module.get())); + + EXPECT_TRUE(changed); + + HloInstruction* while_instruction = hlo_query::GetFirstInstructionWithOpcode( + *module->entry_computation(), HloOpcode::kWhile); + TF_ASSERT_OK_AND_ASSIGN( + WhileLoopBackendConfig config, + while_instruction->backend_config()); + EXPECT_EQ(config.known_trip_count().n(), 5); + EXPECT_EQ(CountInstructions((*while_instruction->while_body()), + HloOpcode::kAllReduceStart), + 2); +} + +TEST_F(GpuLoopDoubleBufferTransformerTest, + DoNotAutoUnrollLoopWhenCollectivesAreNotPresent) { + absl::string_view kModuleString = R"( +HloModule m +condition { + input_tuple = (s32[]) parameter(0) + cond = s32[] get-tuple-element(input_tuple), index=0 + trip_count = s32[] constant(10) + ROOT done = pred[] compare(cond, trip_count), direction=LT +} + +body { + input_tuple = (s32[]) parameter(0) + cond = s32[] get-tuple-element(input_tuple), index=0 + one = s32[] constant(1) + cond_plus_1 = s32[] add(cond, one) + ROOT output_tuple = (s32[]) tuple(cond_plus_1) +} + +ENTRY main { + param_0 = s32[] constant(0) + tuple = (s32[]) tuple(param_0) + ROOT while = (s32[]) while(tuple), condition=condition, body=body, backend_config={"known_trip_count":{"n":"10"}} +})"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kModuleString)); + DoubleBufferLoopUnrolling unroller( + DoubleBufferLoopUnrolling::UnrollStrategy::kAuto); + TF_ASSERT_OK_AND_ASSIGN(bool changed, unroller.Run(module.get())); + + EXPECT_FALSE(changed); + + HloInstruction* while_instruction = hlo_query::GetFirstInstructionWithOpcode( + *module->entry_computation(), HloOpcode::kWhile); + TF_ASSERT_OK_AND_ASSIGN( + WhileLoopBackendConfig config, + while_instruction->backend_config()); + EXPECT_EQ(config.known_trip_count().n(), 10); +} TEST_F(GpuLoopDoubleBufferTransformerTest, FullUnrollOddTripCountTest) { const char* const kModuleString = R"( diff --git a/third_party/xla/xla/xla.proto b/third_party/xla/xla/xla.proto index 018320d9c8df69..c8900a65753b32 100644 --- a/third_party/xla/xla/xla.proto +++ b/third_party/xla/xla/xla.proto @@ -740,6 +740,9 @@ message DebugOptions { WHILE_LOOP_UNROLLING_DOUBLE_BUFFER = 1; // Enables full loop unrolling using the same strategy as `DOUBLE_BUFFER`. WHILE_LOOP_UNROLLING_FULL_UNROLL = 2; + // Enables loop unrolling when we have at least one collective within a + // while loop. + WHILE_LOOP_UNROLLING_AUTO_UNROLL = 3; } // Determine the while loop unrolling scheme. @@ -981,7 +984,12 @@ message DebugOptions { // DotMerger pass threshold size to be used in MB. int32 xla_gpu_dot_merger_threshold_mb = 331; - // Next id: 332 + // If enabled, in the absence of user provided knobs might tune pass + // configurations based on the HLO. For example it decides to unroll the while + // loop by a factor of two if a collective op is present. + bool xla_gpu_enable_heuristic_pass_configuration = 332; + + // Next id: 333 // Extra options to pass to the compilation backend (e.g. LLVM); specific // interpretation of these values is left to the backend. From 3a28e03be84d06e7f43611f124a9fbde2d31058b Mon Sep 17 00:00:00 2001 From: Alexander Belyaev Date: Wed, 25 Sep 2024 02:08:10 -0700 Subject: [PATCH 239/483] [XLA:GPU][IndexAnalysis] Use the parser in indexing_map_test. Add support for parsing of negative int literals and affine expressions with (). PiperOrigin-RevId: 678601020 --- third_party/xla/xla/service/gpu/model/BUILD | 2 + .../gpu/model/indexing_map_serialization.cc | 57 +- .../model/indexing_map_serialization_test.cc | 15 +- .../service/gpu/model/indexing_map_test.cc | 972 +++++++++++------- 4 files changed, 658 insertions(+), 388 deletions(-) diff --git a/third_party/xla/xla/service/gpu/model/BUILD b/third_party/xla/xla/service/gpu/model/BUILD index 978b6695c4353f..c793c7508c133f 100644 --- a/third_party/xla/xla/service/gpu/model/BUILD +++ b/third_party/xla/xla/service/gpu/model/BUILD @@ -506,6 +506,7 @@ xla_cc_test( deps = [ ":affine_map_printer", ":indexing_analysis", + ":indexing_map_serialization", ":indexing_test_utils", "//xla/hlo/ir:hlo", "//xla/tests:hlo_test_base", @@ -513,6 +514,7 @@ xla_cc_test( "//xla/tests:xla_internal_test_main", "@com_google_absl//absl/hash:hash_testing", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest", "@llvm-project//mlir:IR", diff --git a/third_party/xla/xla/service/gpu/model/indexing_map_serialization.cc b/third_party/xla/xla/service/gpu/model/indexing_map_serialization.cc index af48bcb1f318d5..31e7bfaa53ec21 100644 --- a/third_party/xla/xla/service/gpu/model/indexing_map_serialization.cc +++ b/third_party/xla/xla/service/gpu/model/indexing_map_serialization.cc @@ -209,20 +209,25 @@ bool Parser::ParseInterval(Interval* interval) { bool Parser::ParseAffineExprString(std::string* affine_expr_str) { unsigned num_unmatched_parens = 0; while (true) { - if (!IsPartOfAffineExpr(current_token_)) { - if (ConsumeToken(Token::Kind::kLParen)) { - ++num_unmatched_parens; - } else if (current_token_.kind == Token::Kind::kRParen && - num_unmatched_parens > 0) { - --num_unmatched_parens; - Advance(); - } else { - break; - } + if (IsPartOfAffineExpr(current_token_)) { + affine_expr_str->append(current_token_.spelling); + affine_expr_str->push_back(' '); + Advance(); + continue; } - affine_expr_str->append(current_token_.spelling); - affine_expr_str->push_back(' '); - Advance(); + if (ConsumeToken(Token::Kind::kLParen)) { + affine_expr_str->push_back('('); + ++num_unmatched_parens; + continue; + } + if (current_token_.kind == Token::Kind::kRParen && + num_unmatched_parens > 0) { + affine_expr_str->push_back(')'); + --num_unmatched_parens; + Advance(); + continue; + } + break; } return current_token_.kind != Token::Kind::kError; } @@ -302,11 +307,20 @@ Token Parser::GetNextTokenImpl() { } if (*it_ == '-') { ++it_; - if (it_ != input_.end() && *it_ == '>') { - ++it_; - return Token{"->", Token::Kind::kArrow}; - } else { - return Token{"-", Token::Kind::kMinus}; + if (it_ != input_.end()) { + if (*it_ == '>') { + ++it_; + return Token{"->", Token::Kind::kArrow}; + } else if (std::isdigit(*it_)) { + auto start = it_ - 1; + while (it_ != input_.end() && std::isdigit(*it_)) { + ++it_; + } + StringRef spelling = input_.substr(start - input_.data(), it_ - start); + return Token{spelling, Token::Kind::kIntLiteral}; + } else { + return Token{"-", Token::Kind::kMinus}; + } } } StringRef spelling = input_.substr(start - input_.data(), 1); @@ -407,6 +421,7 @@ std::optional ParseIndexingMap(llvm::StringRef input, if (!parser.ConsumeToken(Token::Kind::kComma) || !parser.ConsumeToken(Token::Kind::kKeywordDomain) || !parser.ConsumeToken(Token::Kind::kColon)) { + llvm::errs() << "Failed to parse domain keyword\n"; return std::nullopt; } // Parse dimension variables. @@ -418,9 +433,11 @@ std::optional ParseIndexingMap(llvm::StringRef input, !parser.ConsumeToken(Token::Kind::kKeywordIn) || !parser.ParseInterval(&interval) || !parser.ConsumeToken(Token::Kind::kComma)) { + llvm::errs() << "Failed to parse DimVar\n"; return std::nullopt; } if (var_name != dim_name) { + llvm::errs() << "Dimension name mismatch\n"; return std::nullopt; } dim_vars.push_back(DimVar{interval}); @@ -434,9 +451,11 @@ std::optional ParseIndexingMap(llvm::StringRef input, !parser.ConsumeToken(Token::Kind::kKeywordIn) || !parser.ParseInterval(&interval) || !parser.ConsumeToken(Token::Kind::kComma)) { + llvm::errs() << "Failed to parse RangeVar\n"; return std::nullopt; } if (var_name != symbol_var) { + llvm::errs() << "Symbol name mismatch\n"; return std::nullopt; } range_vars.push_back(RangeVar{interval}); @@ -450,6 +469,7 @@ std::optional ParseIndexingMap(llvm::StringRef input, !parser.ConsumeToken(Token::Kind::kKeywordIn) || !parser.ParseInterval(&interval) || !parser.ConsumeToken(Token::Kind::kComma)) { + llvm::errs() << "Failed to parse constraint\n"; return std::nullopt; } affine_expr_strs.push_back(affine_expr_str); @@ -459,6 +479,7 @@ std::optional ParseIndexingMap(llvm::StringRef input, bool is_simplified; if (!parser.ConsumeToken(Token::Kind::kColon) || !parser.ParseBool(&is_simplified)) { + llvm::errs() << "Failed to parse is_simplified\n"; return std::nullopt; } // Check that the input is consumed. diff --git a/third_party/xla/xla/service/gpu/model/indexing_map_serialization_test.cc b/third_party/xla/xla/service/gpu/model/indexing_map_serialization_test.cc index 7efd04b5442804..c7d39e8c8690fc 100644 --- a/third_party/xla/xla/service/gpu/model/indexing_map_serialization_test.cc +++ b/third_party/xla/xla/service/gpu/model/indexing_map_serialization_test.cc @@ -45,7 +45,7 @@ TEST_F(IndexingMapSerializationTest, DimsOnly) { (d0, d1) -> (d0 mod 2 + d1), domain: d0 in [0, 3], - d1 in [0, 4], + d1 in [-4, 4], is_simplified: true )"); } @@ -88,6 +88,19 @@ TEST_F(IndexingMapSerializationTest, DimsAndSymbolsAndConstraints) { )"); } +TEST_F(IndexingMapSerializationTest, AffineExprsWithParens) { + ParseAndCheck(R"( + (d0, d1)[s0, s1] -> ((d0 + d0 mod 3) floordiv 3 + + s0 + (s0 * 2) mod 3 + (d0 + s0) mod 3), + domain: + d0 in [0, 9], + d1 in [0, 19], + s0 in [0, 29], + s1 in [0, 39], + is_simplified: false + )"); +} + // This test will be updated when the printing uses types of variables. TEST_F(IndexingMapSerializationTest, CustomNames) { auto indexing_map_str = R"( diff --git a/third_party/xla/xla/service/gpu/model/indexing_map_test.cc b/third_party/xla/xla/service/gpu/model/indexing_map_test.cc index c7bd056b072a5b..c3b8669f1b46c7 100644 --- a/third_party/xla/xla/service/gpu/model/indexing_map_test.cc +++ b/third_party/xla/xla/service/gpu/model/indexing_map_test.cc @@ -27,12 +27,14 @@ limitations under the License. #include #include "absl/hash/hash_testing.h" #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/MLIRContext.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/service/gpu/model/affine_map_printer.h" +#include "xla/service/gpu/model/indexing_map_serialization.h" #include "xla/service/gpu/model/indexing_test_utils.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/verified_hlo_module.h" @@ -49,6 +51,12 @@ using ::testing::ElementsAre; class IndexingMapTest : public HloTestBase { public: + IndexingMap Parse(absl::string_view indexing_map_str) { + auto indexing_map = ParseIndexingMap(indexing_map_str, &mlir_context_); + EXPECT_TRUE(indexing_map.has_value()); + return *indexing_map; + } + mlir::MLIRContext mlir_context_; AffineMapPrinter printer_; }; @@ -112,10 +120,15 @@ TEST_F(IndexingMapTest, RTVar) { } TEST_F(IndexingMapTest, Evaluation) { - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap("(d0, d1)[s0, s1] -> (d1, d0, s1, s0)", &mlir_context_), - {4, 4}, {2, 2}); - + IndexingMap indexing_map = Parse(R"( + (d0, d1)[s0, s1] -> (d1, d0, s1, s0), + domain: + d0 in [0, 3], + d1 in [0, 3], + s0 in [0, 1], + s1 in [0, 1], + is_simplified: false + )"); auto results = indexing_map.Evaluate( mlir::getAffineConstantExprs({1, 2}, &mlir_context_), mlir::getAffineConstantExprs({3, 4}, &mlir_context_)); @@ -136,12 +149,23 @@ TEST_F(IndexingMapTest, Evaluation) { } TEST_F(IndexingMapTest, Composition_Permutation) { - IndexingMap producer = IndexingMap::FromTensorSizes( - ParseAffineMap("(d0, d1)[s0, s1] -> (d1, d0, s1, s0)", &mlir_context_), - {4, 4}, {2, 2}); - - IndexingMap consumer = IndexingMap::FromTensorSizes( - ParseAffineMap("(d0)[s0] -> (d0, s0)", &mlir_context_), {4}, {4}); + IndexingMap producer = Parse(R"( + (d0, d1)[s0, s1] -> (d1, d0, s1, s0), + domain: + d0 in [0, 3], + d1 in [0, 3], + s0 in [0, 1], + s1 in [0, 1], + is_simplified: false + )"); + + IndexingMap consumer = Parse(R"( + (d0)[s0] -> (d0, s0), + domain: + d0 in [0, 3], + s0 in [0, 3], + is_simplified: false + )"); auto composed = ComposeIndexingMaps(consumer, producer); EXPECT_THAT(composed, MatchIndexingMap(R"( @@ -156,12 +180,23 @@ TEST_F(IndexingMapTest, Composition_Permutation) { } TEST_F(IndexingMapTest, Composition_RestrictedInterval) { - IndexingMap producer = IndexingMap::FromTensorSizes( - ParseAffineMap("(d0, d1)[s0, s1] -> (d1, d0, s1, s0)", &mlir_context_), - {5, 6}, {7, 2}); - - IndexingMap consumer = IndexingMap::FromTensorSizes( - ParseAffineMap("(d0)[s0] -> (d0, s0)", &mlir_context_), {10}, {8}); + IndexingMap producer = Parse(R"( + (d0, d1)[s0, s1] -> (d1, d0, s1, s0), + domain: + d0 in [0, 4], + d1 in [0, 5], + s0 in [0, 6], + s1 in [0, 1], + is_simplified: false + )"); + + IndexingMap consumer = Parse(R"( + (d0)[s0] -> (d0, s0), + domain: + d0 in [0, 9], + s0 in [0, 7], + is_simplified: false + )"); auto composed = ComposeIndexingMaps(consumer, producer); EXPECT_THAT(composed, MatchIndexingMap(R"( @@ -176,20 +211,27 @@ TEST_F(IndexingMapTest, Composition_RestrictedInterval) { } TEST_F(IndexingMapTest, Composition_ProducerAndConsumerHaveConstraints) { - IndexingMap producer = IndexingMap::FromTensorSizes( - ParseAffineMap("(d0, d1)[s0, s1] -> (d1, d0, s1, s0)", &mlir_context_), - {50, 60}, {70, 20}); - producer.AddConstraint(ParseAffineExpr("d0 mod 8", &mlir_context_), - Interval{0, 0}); - producer.AddConstraint(ParseAffineExpr("s0 mod 3", &mlir_context_), - Interval{1, 1}); - - IndexingMap consumer = IndexingMap::FromTensorSizes( - ParseAffineMap("(d0)[s0] -> (d0, s0)", &mlir_context_), {10}, {8}); - consumer.AddConstraint(ParseAffineExpr("d0 + s0", &mlir_context_), - Interval{0, 20}); - consumer.AddConstraint(ParseAffineExpr("s0 mod 4", &mlir_context_), - Interval{0, 0}); + IndexingMap producer = Parse(R"( + (d0, d1)[s0, s1] -> (d1, d0, s1, s0), + domain: + d0 in [0, 49], + d1 in [0, 59], + s0 in [0, 69], + s1 in [0, 19], + d0 mod 8 in [0, 0], + s0 mod 3 in [1, 1], + is_simplified: false + )"); + + IndexingMap consumer = Parse(R"( + (d0)[s0] -> (d0, s0), + domain: + d0 in [0, 9], + s0 in [0, 7], + d0 + s0 in [0, 20], + s0 mod 4 in [0, 0], + is_simplified: false + )"); auto composed = ComposeIndexingMaps(consumer, producer); EXPECT_THAT(composed, MatchIndexingMap(R"( @@ -311,14 +353,18 @@ TEST_F(IndexingMapTest, Composition_OnlyRTVars) { } TEST_F(IndexingMapTest, RemoveUnusedVars_ConstraintUsesDim) { - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap("(d0, d1)[s0, s1] -> (d1, s0, s1)", &mlir_context_), - {50, 60}, {70, 20}); // This constraint cannot be removed, because it contains a dimension. - indexing_map.AddConstraint(ParseAffineExpr("s0 + d0", &mlir_context_), - Interval{1, 100}); - indexing_map.AddConstraint(ParseAffineExpr("s0 mod 3", &mlir_context_), - Interval{0, 0}); + auto indexing_map = Parse(R"( + (d0, d1)[s0, s1] -> (d1, s0, s1), + domain: + d0 in [0, 49], + d1 in [0, 59], + s0 in [0, 69], + s1 in [0, 19], + d0 + s0 in [1, 100], + s0 mod 3 in [0, 0], + is_simplified: false + )"); indexing_map.RemoveUnusedVars(); EXPECT_THAT(indexing_map, MatchIndexingMap(R"( (d0, d1)[s0, s1] -> (d1, s0, s1), @@ -334,12 +380,17 @@ TEST_F(IndexingMapTest, RemoveUnusedVars_ConstraintUsesDim) { } TEST_F(IndexingMapTest, RemoveUnusedVars_ConstraintUsesUnusedDim) { - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap("(d0, d1)[s0, s1] -> (s0, d1, s1)", &mlir_context_), - {50, 60}, {70, 20}); // This constraint can be removed, because it contains only the unused dim. - indexing_map.AddConstraint(ParseAffineExpr("d0 mod 3", &mlir_context_), - Interval{0, 0}); + auto indexing_map = Parse(R"( + (d0, d1)[s0, s1] -> (s0, d1, s1), + domain: + d0 in [0, 49], + d1 in [0, 59], + s0 in [0, 69], + s1 in [0, 19], + d0 mod 3 in [0, 0], + is_simplified: false + )"); indexing_map.RemoveUnusedVars(); EXPECT_THAT(indexing_map, MatchIndexingMap(R"( (d0)[s0, s1] -> (s0, d0, s1), @@ -352,12 +403,17 @@ TEST_F(IndexingMapTest, RemoveUnusedVars_ConstraintUsesUnusedDim) { } TEST_F(IndexingMapTest, RemoveUnusedSymbols_ConstraintUsesOnlyUnusedSym) { - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap("(d0, d1)[s0, s1] -> (d0, d1, s1)", &mlir_context_), - {50, 60}, {70, 20}); // This constraint can be removed, because it contains only the unused symbol. - indexing_map.AddConstraint(ParseAffineExpr("s0 mod 3", &mlir_context_), - Interval{0, 0}); + auto indexing_map = Parse(R"( + (d0, d1)[s0, s1] -> (d0, d1, s1), + domain: + d0 in [0, 49], + d1 in [0, 59], + s0 in [0, 69], + s1 in [0, 19], + s0 mod 3 in [0, 0], + is_simplified: false + )"); indexing_map.RemoveUnusedSymbols(); EXPECT_THAT(indexing_map, MatchIndexingMap(R"( (d0, d1)[s0] -> (d0, d1, s0), @@ -370,17 +426,23 @@ TEST_F(IndexingMapTest, RemoveUnusedSymbols_ConstraintUsesOnlyUnusedSym) { } TEST_F(IndexingMapTest, RemoveUnusedVars_ConstraintsWithManyDims) { - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap( - "(d0, d1, d2, d3, d4)[s0, s1, s2] -> (s0 * 4 + d1 + d3 - 42)", - &mlir_context_), - {1, 2, 3, 4, 5}, {32, 64, 96}); - indexing_map.AddConstraint( - ParseAffineExpr("s0 * 4 + d1 + d3", &mlir_context_), Interval{24, 459}); - indexing_map.AddConstraint(ParseAffineExpr("s0 + s2", &mlir_context_), - Interval{0, 512}); - auto unused_vars = indexing_map.RemoveUnusedVars(); + auto indexing_map = Parse(R"( + (d0, d1, d2, d3, d4)[s0, s1, s2] -> (s0 * 4 + d1 + d3 - 42), + domain: + d0 in [0, 0], + d1 in [0, 1], + d2 in [0, 2], + d3 in [0, 3], + d4 in [0, 4], + s0 in [0, 31], + s1 in [0, 63], + s2 in [0, 95], + s0 * 4 + d1 + d3 in [24, 459], + s0 + s2 in [0, 512], + is_simplified: false + )"); // dimensions d0, d2, d4 and symbol s1 will be removed. + auto unused_vars = indexing_map.RemoveUnusedVars(); EXPECT_THAT(indexing_map, MatchIndexingMap(R"( (d0, d1)[s0, s1] -> (d0 + s0 * 4 + d1 - 42), domain: @@ -398,14 +460,18 @@ TEST_F(IndexingMapTest, RemoveUnusedVars_ConstraintsWithManyDims) { } TEST_F(IndexingMapTest, RemoveUnusedSymbols_ConstraintUsesSymbol) { - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap("(d0, d1)[s0, s1] -> (d1, d0, s1)", &mlir_context_), - {50, 60}, {70, 20}); + auto indexing_map = Parse(R"( + (d0, d1)[s0, s1] -> (d1, d0, s1), + domain: + d0 in [0, 49], + d1 in [0, 59], + s0 in [0, 69], + s1 in [0, 19], + s0 + s1 in [1, 100], + s0 mod 3 in [0, 0], + is_simplified: false + )"); // This constraint cannot be removed, because it contains a "used symbol". - indexing_map.AddConstraint(ParseAffineExpr("s0 + s1", &mlir_context_), - Interval{1, 100}); - indexing_map.AddConstraint(ParseAffineExpr("s0 mod 3", &mlir_context_), - Interval{0, 0}); indexing_map.RemoveUnusedSymbols(); EXPECT_THAT(indexing_map, MatchIndexingMap(R"( (d0, d1)[s0, s1] -> (d1, d0, s1), @@ -421,12 +487,17 @@ TEST_F(IndexingMapTest, RemoveUnusedSymbols_ConstraintUsesSymbol) { } TEST_F(IndexingMapTest, RemoveUnusedSymbols_ConstraintUsesOnlyUnusedSymbols) { - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap("(d0, d1)[s0, s1] -> (d1, d0, s1)", &mlir_context_), - {50, 60}, {70, 20}); + auto indexing_map = Parse(R"( + (d0, d1)[s0, s1] -> (d1, d0, s1), + domain: + d0 in [0, 49], + d1 in [0, 59], + s0 in [0, 69], + s1 in [0, 19], + s0 mod 3 in [0, 0], + is_simplified: false + )"); // This constraint can be removed, because it contains only the unused symbol. - indexing_map.AddConstraint(ParseAffineExpr("s0 mod 3", &mlir_context_), - Interval{0, 0}); indexing_map.RemoveUnusedSymbols(); EXPECT_THAT(indexing_map, MatchIndexingMap(R"( (d0, d1)[s0] -> (d1, d0, s0), @@ -439,10 +510,13 @@ TEST_F(IndexingMapTest, RemoveUnusedSymbols_ConstraintUsesOnlyUnusedSymbols) { } TEST_F(IndexingMapTest, RemoveUnusedSymbols_ConstraintIsAConstantWithinRange) { - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap("(d0) -> (d0)", &mlir_context_), {50}, {}); - indexing_map.AddConstraint(ParseAffineExpr("0", &mlir_context_), - Interval{-10, 5}); + auto indexing_map = Parse(R"( + (d0) -> (d0), + domain: + d0 in [0, 49], + 0 in [-10, 5], + is_simplified: false + )"); EXPECT_THAT(indexing_map, MatchIndexingMap(R"( (d0) -> (d0), domain: @@ -452,25 +526,34 @@ TEST_F(IndexingMapTest, RemoveUnusedSymbols_ConstraintIsAConstantWithinRange) { } TEST_F(IndexingMapTest, KnownEmpty_CreatingIndexingMapWithInfeasibleRange) { - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap("(d0) -> (d0)", &mlir_context_), {-1}, {}); + auto indexing_map = Parse(R"( + (d0) -> (d0), + domain: + d0 in [0, -2], + is_simplified: false + )"); EXPECT_THAT(indexing_map, MatchIndexingMap("KNOWN EMPTY")); } TEST_F(IndexingMapTest, KnownEmpty_AddingConstraintOutOfRange) { - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap("(d0) -> (d0)", &mlir_context_), {50}, {}); + auto indexing_map = Parse(R"( + (d0) -> (d0), + domain: + d0 in [0, 49], + 0 in [10, 15], + is_simplified: false + )"); // Addition of this constraint makes the domain empty. - indexing_map.AddConstraint(ParseAffineExpr("0", &mlir_context_), - Interval{10, 15}); EXPECT_THAT(indexing_map, MatchIndexingMap("KNOWN EMPTY")); } TEST_F(IndexingMapTest, KnownEmpty_Composition) { - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap("(d0) -> (d0)", &mlir_context_), {50}, {}); - IndexingMap known_empty = IndexingMap::FromTensorSizes( - ParseAffineMap("(d0) -> (0)", &mlir_context_), {0}, {}); + auto indexing_map = Parse(R"( + (d0) -> (d0), domain: d0 in [0, 49], is_simplified: false + )"); + auto known_empty = Parse(R"( + (d0) -> (d0), domain: d0 in [0, -1], is_simplified: false + )"); EXPECT_THAT(known_empty, MatchIndexingMap("KNOWN EMPTY")); EXPECT_THAT(indexing_map * known_empty, MatchIndexingMap("KNOWN EMPTY")); EXPECT_THAT(known_empty * indexing_map, MatchIndexingMap("KNOWN EMPTY")); @@ -480,22 +563,33 @@ TEST_F(IndexingMapTest, KnownEmpty_Composition) { TEST_F(IndexingMapTest, KnownEmpty_AddingConstraintOutOfRangeAfterSimplification) { - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap("(d0, d1)[s0, s1] -> (d1, d0, s1)", &mlir_context_), - {50, 60}, {70, 20}); - indexing_map.AddConstraint(ParseAffineExpr("s1 floordiv 20", &mlir_context_), - Interval{2, 2}); + auto indexing_map = Parse(R"( + (d0, d1)[s0, s1] -> (d1, d0, s1), + domain: + d0 in [0, 49], + d1 in [0, 59], + s0 in [0, 69], + s1 in [0, 19], + s1 floordiv 20 in [2, 2], + is_simplified: false + )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(indexing_map, MatchIndexingMap("KNOWN EMPTY")); } TEST_F(IndexingMapTest, RemoveUnusedSymbols_ConstraintsWithManySymbols) { - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap("(d0)[s0, s1, s2, s3, s4] -> (d0 * 4 + s1 + s3 - 42)", - &mlir_context_), - {32}, {1, 2, 3, 4, 5}); - indexing_map.AddConstraint( - ParseAffineExpr("d0 * 4 + s1 + s3", &mlir_context_), Interval{24, 459}); + auto indexing_map = Parse(R"( + (d0)[s0, s1, s2, s3, s4] -> (d0 * 4 + s1 + s3 - 42), + domain: + d0 in [0, 31], + s0 in [0, 0], + s1 in [0, 1], + s2 in [0, 2], + s3 in [0, 3], + s4 in [0, 4], + d0 * 4 + s1 + s3 in [24, 459], + is_simplified: false + )"); indexing_map.RemoveUnusedSymbols(); // Symbols s0, s2, s4 will be removed and s1 and s3 will become s0 and s1. EXPECT_THAT(indexing_map, MatchIndexingMap(R"( @@ -562,11 +656,13 @@ TEST_F(IndexingMapTest, ConvertSymbolsToDimensions) { } TEST_F(IndexingMapTest, ConstraintIntervalSimplification_Sum) { - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap("(d0) -> (d0)", &mlir_context_), {100}, {}); - - indexing_map.AddConstraint(ParseAffineExpr("(d0 mod 8) + 5", &mlir_context_), - Interval{50, 54}); + auto indexing_map = Parse(R"( + (d0) -> (d0), + domain: + d0 in [0, 99], + d0 mod 8 + 5 in [50, 54], + is_simplified: false + )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(indexing_map.ToString(), MatchIndexingString(R"( (d0) -> (d0), @@ -579,13 +675,15 @@ TEST_F(IndexingMapTest, ConstraintIntervalSimplification_Sum) { TEST_F(IndexingMapTest, ConstraintIntervalSimplification_Sum_IndependentOfSymbol) { - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap("(d0)[s0, s1] -> (d0 * 6 + s0 * 3 + s1)", &mlir_context_), - {2000}, {2, 3}); - - indexing_map.AddConstraint( - ParseAffineExpr("d0 * 6 + s0 * 3 + s1", &mlir_context_), - Interval{0, 599}); + auto indexing_map = Parse(R"( + (d0)[s0, s1] -> (d0 * 6 + s0 * 3 + s1), + domain: + d0 in [0, 1999], + s0 in [0, 1], + s1 in [0, 2], + d0 * 6 + s0 * 3 + s1 in [0, 599], + is_simplified: false + )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(indexing_map.ToString(), MatchIndexingString(R"( (d0)[s0, s1] -> (d0 * 6 + s0 * 3 + s1), @@ -599,23 +697,27 @@ TEST_F(IndexingMapTest, TEST_F(IndexingMapTest, ConstraintIntervalSimplification_Sum_NotIndependentOfSymbol) { - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap("(d0)[s0, s1] -> (d0 * 6 + s0 * 3 + s1)", &mlir_context_), - {2000}, {2, 3}); - - indexing_map.AddConstraint( - ParseAffineExpr("d0 * 6 + s0 * 3 + s1", &mlir_context_), - Interval{0, 598}); + auto indexing_map = Parse(R"( + (d0)[s0, s1] -> (d0 * 6 + s0 * 3 + s1), + domain: + d0 in [0, 1999], + s0 in [0, 1], + s1 in [0, 2], + d0 * 6 + s0 * 3 + s1 in [0, 598], + is_simplified: false + )"); EXPECT_FALSE(indexing_map.Simplify()); } TEST_F(IndexingMapTest, ConstraintIntervalSimplification_Sum_GcdGreaterOne) { - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap("(d0)[s0] -> (d0 * 6 + s0 * 3)", &mlir_context_), {2000}, - {2}); - - indexing_map.AddConstraint(ParseAffineExpr("d0 * 6 + s0 * 3", &mlir_context_), - Interval{0, 599}); + auto indexing_map = Parse(R"( + (d0)[s0] -> (d0 * 6 + s0 * 3), + domain: + d0 in [0, 1999], + s0 in [0, 1], + d0 * 6 + s0 * 3 in [0, 599], + is_simplified: false + )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(indexing_map.ToString(), MatchIndexingString(R"( (d0)[s0] -> (d0 * 6 + s0 * 3), @@ -628,11 +730,13 @@ TEST_F(IndexingMapTest, ConstraintIntervalSimplification_Sum_GcdGreaterOne) { TEST_F(IndexingMapTest, ConstraintIntervalSimplification_FloorDivPositiveDivisorPositiveBounds) { - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap("(d0) -> (d0)", &mlir_context_), {100}, {}); - - indexing_map.AddConstraint(ParseAffineExpr("d0 floordiv 8", &mlir_context_), - Interval{5, 11}); + auto indexing_map = Parse(R"( + (d0) -> (d0), + domain: + d0 in [0, 99], + d0 floordiv 8 in [5, 11], + is_simplified: false + )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(indexing_map.ToString(), MatchIndexingString(R"( (d0) -> (d0), @@ -644,12 +748,14 @@ TEST_F(IndexingMapTest, TEST_F(IndexingMapTest, ConstraintIntervalSimplification_FloorDivPositiveDivisorNegativeBounds) { - IndexingMap indexing_map = - IndexingMap(ParseAffineMap("(d0)[s0] -> (d0)", &mlir_context_), - {DimVar{{0, 99}}}, {RangeVar{{-99, 99}}}, /*rt_vars=*/{}); - - indexing_map.AddConstraint(ParseAffineExpr("s0 floordiv 3", &mlir_context_), - Interval{-11, -5}); + auto indexing_map = Parse(R"( + (d0)[s0] -> (d0), + domain: + d0 in [0, 99], + s0 in [-99, 99], + s0 floordiv 3 in [-11, -5], + is_simplified: false + )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(indexing_map.ToString(), MatchIndexingString(R"( (d0)[s0] -> (d0), @@ -662,12 +768,14 @@ TEST_F(IndexingMapTest, TEST_F(IndexingMapTest, ConstraintIntervalSimplification_FloorDivNegativeDivisorNegativeBounds) { - IndexingMap indexing_map = - IndexingMap(ParseAffineMap("(d0)[s0] -> (d0)", &mlir_context_), - {DimVar{{0, 99}}}, {RangeVar{{-99, 99}}}, /*rt_vars=*/{}); - - indexing_map.AddConstraint(ParseAffineExpr("s0 floordiv -3", &mlir_context_), - Interval{-11, -5}); + auto indexing_map = Parse(R"( + (d0)[s0] -> (d0), + domain: + d0 in [0, 99], + s0 in [-99, 99], + s0 floordiv -3 in [-11, -5], + is_simplified: false + )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(indexing_map.ToString(), MatchIndexingString(R"( (d0)[s0] -> (d0), @@ -680,11 +788,13 @@ TEST_F(IndexingMapTest, TEST_F(IndexingMapTest, ConstraintIntervalSimplification_MulPositiveMultiplierPositiveBounds) { - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap("(d0) -> (d0)", &mlir_context_), {100}, {}); - - indexing_map.AddConstraint(ParseAffineExpr("d0 * 8", &mlir_context_), - Interval{14, 33}); + auto indexing_map = Parse(R"( + (d0) -> (d0), + domain: + d0 in [0, 99], + d0 * 8 in [14, 33], + is_simplified: false + )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(indexing_map.ToString(), MatchIndexingString(R"( (d0) -> (d0), @@ -696,12 +806,14 @@ TEST_F(IndexingMapTest, TEST_F(IndexingMapTest, ConstraintIntervalSimplification_MulPositiveMultiplierNegativeBounds) { - IndexingMap indexing_map = - IndexingMap(ParseAffineMap("(d0)[s0] -> (d0)", &mlir_context_), - {DimVar{{0, 99}}}, {RangeVar{{-99, 99}}}, /*rt_vars=*/{}); - - indexing_map.AddConstraint(ParseAffineExpr("s0 * 3", &mlir_context_), - Interval{-11, -5}); + auto indexing_map = Parse(R"( + (d0)[s0] -> (d0), + domain: + d0 in [0, 99], + s0 in [-99, 99], + s0 * 3 in [-11, -5], + is_simplified: false + )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(indexing_map.ToString(), MatchIndexingString(R"( (d0)[s0] -> (d0), @@ -714,12 +826,14 @@ TEST_F(IndexingMapTest, TEST_F(IndexingMapTest, ConstraintIntervalSimplification_MulNegativeMultiplierNegativeBounds) { - IndexingMap indexing_map = - IndexingMap(ParseAffineMap("(d0)[s0] -> (d0)", &mlir_context_), - {DimVar{{0, 99}}}, {RangeVar{{-99, 99}}}, /*rt_vars=*/{}); - - indexing_map.AddConstraint(ParseAffineExpr("s0 * -3", &mlir_context_), - Interval{-11, -5}); + auto indexing_map = Parse(R"( + (d0)[s0] -> (d0), + domain: + d0 in [0, 99], + s0 in [-99, 99], + s0 * -3 in [-11, -5], + is_simplified: false + )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(indexing_map.ToString(), MatchIndexingString(R"( (d0)[s0] -> (d0), @@ -731,20 +845,19 @@ TEST_F(IndexingMapTest, } TEST_F(IndexingMapTest, ConstraintMerge_Mod) { - IndexingMap indexing_map( - ParseAffineMap("(d0)[s0, s1] -> (d0, s1, s0)", &mlir_context_), - {DimVar{{0, 4}}}, {RangeVar{{-21, -1}}, RangeVar{{0, 10}}}, - /*rt_vars=*/{}); - indexing_map.AddConstraint(ParseAffineExpr("d0 mod 3", &mlir_context_), - Interval{0, 0}); - indexing_map.AddConstraint(ParseAffineExpr("s0 mod 2", &mlir_context_), - Interval{0, 0}); - indexing_map.AddConstraint(ParseAffineExpr("s0 mod 3", &mlir_context_), - Interval{0, 0}); - indexing_map.AddConstraint(ParseAffineExpr("s1 mod 5", &mlir_context_), - Interval{1, 1}); + auto indexing_map = Parse(R"( + (d0)[s0, s1] -> (d0, s1, s0), + domain: + d0 in [0, 3], + s0 in [-21, -2], + s1 in [0, 10], + d0 mod 3 in [0, 0], + s0 mod 2 in [0, 0], + s0 mod 3 in [0, 0], + s1 mod 5 in [1, 1], + is_simplified: false + )"); EXPECT_TRUE(indexing_map.Simplify()); - EXPECT_THAT(indexing_map.ToString(), MatchIndexingString(R"( (d0)[s0, s1] -> (d0, s1, s0), domain: @@ -759,9 +872,12 @@ TEST_F(IndexingMapTest, ConstraintMerge_Mod) { } TEST_F(IndexingMapTest, AffineMapSimplification_ConstantDims) { - IndexingMap indexing_map = - IndexingMap(ParseAffineMap("(d0) -> (d0)", &mlir_context_), - {DimVar{{5, 5}}}, /*range_vars=*/{}, /*rt_vars=*/{}); + auto indexing_map = Parse(R"( + (d0) -> (d0), + domain: + d0 in [5, 5], + is_simplified: false + )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( (d0) -> (5), @@ -774,11 +890,16 @@ TEST_F(IndexingMapTest, AffineMapSimplification_ConstantDims) { TEST_F(IndexingMapTest, AffineMapSimplification_SumOrderRegression) { // This is a regression test for a bug where we didn't canonicalize the order // of summands correctly, leading to `Simplify` not being idempotent. - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap("(d0, d1)[s0, s1] -> (((((d0 + (d0 mod 3)) floordiv 3) + " - "(s0 + ((s0 + s0) mod 3))) + (((d0 + s0) mod 3) + 0)))", - &mlir_context_), - {10, 20}, {30, 40}); + auto indexing_map = Parse(R"( + (d0, d1)[s0, s1] -> (((((d0 + (d0 mod 3)) floordiv 3) + + (s0 + ((s0 + s0) mod 3))) + (((d0 + s0) mod 3) + 0))), + domain: + d0 in [0, 9], + d1 in [0, 19], + s0 in [0, 29], + s1 in [0, 39], + is_simplified: false + )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_FALSE(indexing_map.Simplify()); } @@ -786,20 +907,25 @@ TEST_F(IndexingMapTest, AffineMapSimplification_SumOrderRegression) { TEST_F(IndexingMapTest, AffineMapSimplification_SumOrderRegression2) { // This is a regression test for a bug where we didn't simplify the affine // expression fully after a single iteration. - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap("(d0)[s0] -> ((((s0 + d0) + d0) floordiv 2))", - &mlir_context_), - {10, 20}, {30, 40}); + auto indexing_map = Parse(R"( + (d0)[s0] -> ((((s0 + d0) + d0) floordiv 2)), + domain: + d0 in [0, 9], + s0 in [0, 19], + is_simplified: false + )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_FALSE(indexing_map.Simplify()); } TEST_F(IndexingMapTest, AffineMapSimplification_FloorDivRegression) { - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap( - "(d0, d1) -> (((d0 floordiv 3) * 3 + d1 floordiv 2) floordiv 6)", - &mlir_context_), - {12, 6}, {}); + auto indexing_map = Parse(R"( + (d0, d1) -> (((d0 floordiv 3) * 3 + d1 floordiv 2) floordiv 6), + domain: + d0 in [0, 11], + d1 in [0, 5], + is_simplified: false + )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( (d0, d1) -> (d0 floordiv 6), @@ -811,9 +937,12 @@ TEST_F(IndexingMapTest, AffineMapSimplification_FloorDivRegression) { } TEST_F(IndexingMapTest, AffineMapSimplification_ModIsSub) { - IndexingMap indexing_map( - ParseAffineMap("(d0) -> (d0 mod 42)", &mlir_context_), {{53, 71}}, {}, - {}); + auto indexing_map = Parse(R"( + (d0) -> (d0 mod 42), + domain: + d0 in [53, 71], + is_simplified: false + )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( (d0) -> (d0 - 42), @@ -824,8 +953,12 @@ TEST_F(IndexingMapTest, AffineMapSimplification_ModIsSub) { } TEST_F(IndexingMapTest, AffineMapSimplification_ModIsAdd) { - IndexingMap indexing_map(ParseAffineMap("(d0) -> (d0 mod 5)", &mlir_context_), - {{-5, -1}}, {}, {}); + auto indexing_map = Parse(R"( + (d0) -> (d0 mod 5), + domain: + d0 in [-5, -1], + is_simplified: false + )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( (d0) -> (d0 + 5), @@ -836,19 +969,22 @@ TEST_F(IndexingMapTest, AffineMapSimplification_ModIsAdd) { } TEST_F(IndexingMapTest, AffineMapSimplification_ModIsNotAdd) { - IndexingMap indexing_map1( - ParseAffineMap("(d0) -> (d0 mod 5)", &mlir_context_), {{-4, 0}}, {}, {}); + auto indexing_map1 = + Parse("(d0) -> (d0 mod 5), domain: d0 in [-4, 0], is_simplified: false"); EXPECT_FALSE(indexing_map1.Simplify()); - IndexingMap indexing_map2( - ParseAffineMap("(d0) -> (d0 mod 5)", &mlir_context_), {{-6, -1}}, {}, {}); + auto indexing_map2 = + Parse("(d0) -> (d0 mod 5), domain: d0 in [-6, -1], is_simplified: false"); EXPECT_FALSE(indexing_map2.Simplify()); } TEST_F(IndexingMapTest, AffineMapSimplification_SubIsMod) { - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap("(d0)[s0] -> (d0 - (s0 floordiv 3) * 3 + s0)", - &mlir_context_), - {2}, {4}); + auto indexing_map = Parse(R"( + (d0)[s0] -> (d0 - (s0 floordiv 3) * 3 + s0), + domain: + d0 in [0, 1], + s0 in [0, 3], + is_simplified: false + )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( (d0)[s0] -> (d0 + s0 mod 3), @@ -860,10 +996,13 @@ TEST_F(IndexingMapTest, AffineMapSimplification_SubIsMod) { } TEST_F(IndexingMapTest, AffineMapSimplification_SubIsModMultiplied) { - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap("(d0)[s0] -> (d0 - (s0 floordiv 3) * 12 + s0 * 7)", - &mlir_context_), - {2}, {4}); + auto indexing_map = Parse(R"( + (d0)[s0] -> (d0 - (s0 floordiv 3) * 12 + s0 * 7), + domain: + d0 in [0, 1], + s0 in [0, 3], + is_simplified: false + )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( (d0)[s0] -> (d0 + (s0 mod 3) * 4 + s0 * 3), @@ -875,10 +1014,13 @@ TEST_F(IndexingMapTest, AffineMapSimplification_SubIsModMultiplied) { } TEST_F(IndexingMapTest, AffineMapSimplification_SubIsModSum) { - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap("(d0)[s0] -> (1 + d0 - ((s0 + 1) floordiv 3) * 3 + s0)", - &mlir_context_), - {2}, {4}); + auto indexing_map = Parse(R"( + (d0)[s0] -> (1 + d0 - ((s0 + 1) floordiv 3) * 3 + s0), + domain: + d0 in [0, 1], + s0 in [0, 3], + is_simplified: false + )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( (d0)[s0] -> (d0 + (s0 + 1) mod 3), @@ -891,9 +1033,13 @@ TEST_F(IndexingMapTest, AffineMapSimplification_SubIsModSum) { TEST_F(IndexingMapTest, AffineMapSimplification_DivsAndModsIfSmallerThanDivisor) { - auto serialized_map = "(d0, d1) -> (d0 + d1 floordiv 16, d1 mod 16)"; - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap(serialized_map, &mlir_context_), {8, 16}, {}); + auto indexing_map = Parse(R"( + (d0, d1) -> (d0 + d1 floordiv 16, d1 mod 16), + domain: + d0 in [0, 7], + d1 in [0, 15], + is_simplified: false + )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( (d0, d1) -> (d0, d1), @@ -905,15 +1051,17 @@ TEST_F(IndexingMapTest, } TEST_F(IndexingMapTest, AffineMapSimplification_DivsAndModsWithMultipliers) { - auto serialized_map = - "(d0, d1, d2) -> ((d0 * 100 + d1 * 10 + d2) floordiv 100, " - "((d0 * 100 + d1 * 10 + d2) mod 100) floordiv 10, " - "d2 mod 10)"; - - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap(serialized_map, &mlir_context_), {9, 9, 9}, {}); + auto indexing_map = Parse(R"( + (d0, d1, d2) -> ((d0 * 100 + d1 * 10 + d2) floordiv 100, + ((d0 * 100 + d1 * 10 + d2) mod 100) floordiv 10, + d2 mod 10), + domain: + d0 in [0, 8], + d1 in [0, 8], + d2 in [0, 8], + is_simplified: false + )"); EXPECT_TRUE(indexing_map.Simplify()); - EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( (d0, d1, d2) -> (d0, d1, d2), domain: @@ -926,12 +1074,15 @@ TEST_F(IndexingMapTest, AffineMapSimplification_DivsAndModsWithMultipliers) { TEST_F(IndexingMapTest, AffineMapSimplification_DivsAndModsWithDivisibleMultipliers) { - auto serialized_map = - "(d0, d1, d2) -> ((d0 * 16 + d1 * 4 + d2) floordiv 8, " - " (d0 * 16 + d1 * 4 + d2) mod 8)"; - - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap(serialized_map, &mlir_context_), {10, 10, 10}, {}); + auto indexing_map = Parse(R"( + (d0, d1, d2) -> ((d0 * 16 + d1 * 4 + d2) floordiv 8, + (d0 * 16 + d1 * 4 + d2) mod 8), + domain: + d0 in [0, 9], + d1 in [0, 9], + d2 in [0, 9], + is_simplified: false + )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( (d0, d1, d2) -> (d0 * 2 + (d1 * 4 + d2) floordiv 8, @@ -945,11 +1096,14 @@ TEST_F(IndexingMapTest, } TEST_F(IndexingMapTest, AffineMapSimplification_DivsAndModsWithReverse) { - auto serialized_map = - "(d0, d1) -> (-((d0 * -11 - d1 + 109) floordiv 11) + 9, " - "d0 * 11 + d1 + ((d0 * -11 - d1 + 109) floordiv 11) * 11 - 99)"; - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap(serialized_map, &mlir_context_), {8, 9}, {}); + auto indexing_map = Parse(R"( + (d0, d1) -> (-((d0 * -11 - d1 + 109) floordiv 11) + 9, + d0 * 11 + d1 + ((d0 * -11 - d1 + 109) floordiv 11) * 11 - 99), + domain: + d0 in [0, 7], + d1 in [0, 8], + is_simplified: false + )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( (d0, d1) -> (d0, d1), @@ -961,10 +1115,12 @@ TEST_F(IndexingMapTest, AffineMapSimplification_DivsAndModsWithReverse) { } TEST_F(IndexingMapTest, AffineMapSimplification_SimplifyReshape) { - auto serialized_map = - "()[s0] -> ((s0 * 128) mod 715 + ((s0 * 128) floordiv 715) * 715)"; - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap(serialized_map, &mlir_context_), {}, {128}); + auto indexing_map = Parse(R"( + ()[s0] -> ((s0 * 128) mod 715 + ((s0 * 128) floordiv 715) * 715), + domain: + s0 in [0, 127], + is_simplified: false + )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( ()[s0] -> (s0 * 128), @@ -975,10 +1131,14 @@ TEST_F(IndexingMapTest, AffineMapSimplification_SimplifyReshape) { } TEST_F(IndexingMapTest, AffineMapSimplification_SimplifyReshape2) { - auto serialized_map = - "(d0, d1) -> ((d0 mod 8) * 128 + d1 + (d0 floordiv 8) * 1024)"; - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap(serialized_map, &mlir_context_), {1024, 128}, {}); + auto indexing_map = Parse(R"( + (d0, d1) -> ((d0 mod 8) * 128 + d1 + (d0 floordiv 8) * 1024), + domain: + d0 in [0, 1023], + d1 in [0, 127], + is_simplified: false + )"); + ; EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( (d0, d1) -> (d0 * 128 + d1), @@ -990,11 +1150,14 @@ TEST_F(IndexingMapTest, AffineMapSimplification_SimplifyReshape2) { } TEST_F(IndexingMapTest, AffineMapSimplification_SimplifyReshape3) { - auto serialized_map = - "(d0, d1) -> (((d1 * 2 + d0 floordiv 64) mod 3) * 256 + (d0 mod 64) * 4 " - "+ ((d1 * 128 + d0) floordiv 192) * 768)"; - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap(serialized_map, &mlir_context_), {128, 3072}, {}); + auto indexing_map = Parse(R"( + (d0, d1) -> (((d1 * 2 + d0 floordiv 64) mod 3) * 256 + (d0 mod 64) * 4 + + ((d1 * 128 + d0) floordiv 192) * 768), + domain: + d0 in [0, 127], + d1 in [0, 3071], + is_simplified: false + )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( (d0, d1) -> (d0 * 4 + d1 * 512), @@ -1007,9 +1170,12 @@ TEST_F(IndexingMapTest, AffineMapSimplification_SimplifyReshape3) { TEST_F(IndexingMapTest, AffineMapSimplification_ModWithNegativeMultiplerDoesNotGetSimplified) { - auto serialized_map = "(d0) -> ((-d0) mod 2)"; - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap(serialized_map, &mlir_context_), {128}, {}); + auto indexing_map = Parse(R"( + (d0) -> ((-d0) mod 2), + domain: + d0 in [0, 127], + is_simplified: false + )"); EXPECT_FALSE(indexing_map.Simplify()); EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( (d0) -> ((-d0) mod 2), @@ -1024,12 +1190,15 @@ TEST_F(IndexingMapTest, AffineMapSimplification_SimplifyBitcastAndBack) { // `((d0 * 2 + d1 floordiv 64) floordiv 3) floordiv 1024`. // This test verifies that we can still simplify the map after the // simplification of the floordiv. - auto serialized_map = - "(d0, d1) -> ((d0 floordiv 1536) * 786432 + (((d0 * 2 + d1 floordiv " - "64) floordiv 3) mod 1024) * 768 + ((d0 * 2 + d1 floordiv 64) mod 3) * " - "256 + (d1 mod 64) * 4)"; - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap(serialized_map, &mlir_context_), {3072, 128}, {}); + auto indexing_map = Parse(R"( + (d0, d1) -> ((d0 floordiv 1536) * 786432 + + (((d0 * 2 + d1 floordiv 64) floordiv 3) mod 1024) * 768 + + ((d0 * 2 + d1 floordiv 64) mod 3) * 256 + (d1 mod 64) * 4), + domain: + d0 in [0, 3071], + d1 in [0, 127], + is_simplified: false + )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( (d0, d1) -> (d0 * 512 + d1 * 4), @@ -1042,10 +1211,12 @@ TEST_F(IndexingMapTest, AffineMapSimplification_SimplifyBitcastAndBack) { TEST_F(IndexingMapTest, AffineMapSimplification_SimplifyReshape_Regression) { // We have s0 * 128 in the mod, but s0 * 64 in the floordiv *. - auto serialized_map = - "()[s0] -> ((s0 * 128) mod 715 + ((s0 * 64) floordiv 715) * 715)"; - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap(serialized_map, &mlir_context_), {}, {128}); + auto indexing_map = Parse(R"( + ()[s0] -> ((s0 * 128) mod 715 + ((s0 * 64) floordiv 715) * 715), + domain: + s0 in [0, 127], + is_simplified: false + )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( ()[s0] -> (((s0 * 64) floordiv 715) * 715 + (s0 * 128) mod 715), @@ -1056,11 +1227,12 @@ TEST_F(IndexingMapTest, AffineMapSimplification_SimplifyReshape_Regression) { } TEST_F(IndexingMapTest, AffineMapSimplification_DivsInSequence) { - auto serialized_map = - "()[s0] -> (s0 - ((s0 floordiv 2) floordiv 7) * 14 + (s0 floordiv 14) * " - "14)"; - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap(serialized_map, &mlir_context_), {}, {1234}); + auto indexing_map = Parse(R"( + ()[s0] -> (s0 - ((s0 floordiv 2) floordiv 7) * 14 + (s0 floordiv 14) * 14), + domain: + s0 in [0, 1233], + is_simplified: false + )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( ()[s0] -> (s0), @@ -1071,9 +1243,13 @@ TEST_F(IndexingMapTest, AffineMapSimplification_DivsInSequence) { } TEST_F(IndexingMapTest, AffineMapSimplification_DivDiv) { - auto serialized_map = "()[s0, s1] -> ((s0 * 2 + s1 floordiv 64) floordiv 3)"; - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap(serialized_map, &mlir_context_), {}, {1234, 128}); + auto indexing_map = Parse(R"( + ()[s0, s1] -> ((s0 * 2 + s1 floordiv 64) floordiv 3), + domain: + s0 in [0, 1233], + s1 in [0, 127], + is_simplified: false + )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( ()[s0, s1] -> ((s0 * 128 + s1) floordiv 192), @@ -1085,9 +1261,12 @@ TEST_F(IndexingMapTest, AffineMapSimplification_DivDiv) { } TEST_F(IndexingMapTest, AffineMapSimplification_DivSumConstant) { - auto serialized_map = "()[s0] -> ((s0 * 6 + 9) floordiv 18)"; - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap(serialized_map, &mlir_context_), {}, {1234}); + auto indexing_map = Parse(R"( + ()[s0] -> ((s0 * 6 + 9) floordiv 18), + domain: + s0 in [0, 1233], + is_simplified: false + )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( ()[s0] -> ((s0 * 2 + 3) floordiv 6), @@ -1098,10 +1277,13 @@ TEST_F(IndexingMapTest, AffineMapSimplification_DivSumConstant) { } TEST_F(IndexingMapTest, AffineMapSimplification_DivSumDiv) { - auto serialized_map = - "()[s0, s1] -> ((s0 floordiv 3 + s1 floordiv 3) floordiv 6)"; - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap(serialized_map, &mlir_context_), {}, {1234, 128}); + auto indexing_map = Parse(R"( + ()[s0, s1] -> ((s0 floordiv 3 + s1 floordiv 3) floordiv 6), + domain: + s0 in [0, 1233], + s1 in [0, 127], + is_simplified: false + )"); // The rewrite tested in AffineMapSimplification_DivDiv must not trigger here. EXPECT_FALSE(indexing_map.Simplify()); } @@ -1110,18 +1292,25 @@ TEST_F(IndexingMapTest, AffineMapSimplification_NegativeDiv) { // (s0 floordiv 2) floordiv -7 is not s0 floordiv -14: // 15 // 2 // -7 = -1 // 15 // -14 = -2 - auto serialized_map = "()[s0] -> ((s0 floordiv 2) floordiv -7)"; - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap(serialized_map, &mlir_context_), {}, {1234}); + auto indexing_map = Parse(R"( + ()[s0] -> ((s0 floordiv 2) floordiv -7), + domain: + s0 in [0, 1233], + is_simplified: false + )"); EXPECT_FALSE(indexing_map.Simplify()); } TEST_F(IndexingMapTest, AffineMapSimplification_ExtractFromMod) { - auto serialized_map = - "()[s0, s1, s2, s3] -> ((s0 * 458752 + s1 + s2 * 4 + s3 * 512) mod " - "20000)"; - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap(serialized_map, &mlir_context_), {}, {872, 4, 128, 896}); + auto indexing_map = Parse(R"( + ()[s0, s1, s2, s3] -> ((s0 * 458752 + s1 + s2 * 4 + s3 * 512) mod 20000), + domain: + s0 in [0, 871], + s1 in [0, 3], + s2 in [0, 127], + s3 in [0, 895], + is_simplified: false + )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( ()[s0, s1, s2, s3] -> ( @@ -1138,11 +1327,14 @@ TEST_F(IndexingMapTest, AffineMapSimplification_ExtractFromMod) { TEST_F(IndexingMapTest, AffineMapSimplification_ExtractFromDiv_NegativeMultiplier) { - auto serialized_map = - "()[s0, s1] -> ((s0 * 16 - (s1 floordiv 4) floordiv 2 + (s1 floordiv 8) " - "* 2) floordiv 4)"; - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap(serialized_map, &mlir_context_), {}, {2, 128}); + auto indexing_map = Parse(R"( + ()[s0, s1] -> ((s0 * 16 - (s1 floordiv 4) floordiv 2 + (s1 floordiv 8) * 2) + floordiv 4), + domain: + s0 in [0, 1], + s1 in [0, 127], + is_simplified: false + )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( ()[s0, s1] -> ( @@ -1156,12 +1348,16 @@ TEST_F(IndexingMapTest, } TEST_F(IndexingMapTest, RescaleSymbols_Simple) { - auto serialized_map = "(d0)[s0, s1, s2] -> (s2, d0, s1, s0 floordiv 6)"; - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap(serialized_map, &mlir_context_), {4}, {7, 2, 6}); - indexing_map.AddConstraint(ParseAffineExpr("s0 mod 6", &mlir_context_), - Interval{0, 0}); - + auto indexing_map = Parse(R"( + (d0)[s0, s1, s2] -> (s2, d0, s1, s0 floordiv 6), + domain: + d0 in [0, 3], + s0 in [0, 6], + s1 in [0, 1], + s2 in [0, 5], + s0 mod 6 in [0, 0], + is_simplified: false + )"); EXPECT_TRUE(indexing_map.RescaleSymbols()); EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( (d0)[s0, s1, s2] -> (s2, d0, s1, s0), @@ -1175,12 +1371,16 @@ TEST_F(IndexingMapTest, RescaleSymbols_Simple) { } TEST_F(IndexingMapTest, RescaleSymbols_WithShift) { - auto serialized_map = "(d0)[s0, s1, s2] -> (s2, d0, s1, s0)"; - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap(serialized_map, &mlir_context_), {4}, {42, 2, 6}); - indexing_map.AddConstraint(ParseAffineExpr("s0 mod 6", &mlir_context_), - Interval{3, 3}); - + auto indexing_map = Parse(R"( + (d0)[s0, s1, s2] -> (s2, d0, s1, s0), + domain: + d0 in [0, 3], + s0 in [0, 41], + s1 in [0, 1], + s2 in [0, 5], + s0 mod 6 in [3, 3], + is_simplified: false + )"); // [BEFORE] Allowed values for s0: 3, 9, 15, ..., 39 = (6 * 6 + 3) // [AFTER] Allowed values for s0: 0, 1, 2, ..., 6 EXPECT_TRUE(indexing_map.RescaleSymbols()); @@ -1196,14 +1396,17 @@ TEST_F(IndexingMapTest, RescaleSymbols_WithShift) { } TEST_F(IndexingMapTest, RescaleSymbols_TwoModConstraints) { - auto serialized_map = "(d0)[s0, s1, s2] -> (s2, d0, s1, s0 floordiv 6)"; - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap(serialized_map, &mlir_context_), {4}, {7, 2, 6}); - indexing_map.AddConstraint(ParseAffineExpr("s0 mod 2", &mlir_context_), - Interval{0, 0}); - indexing_map.AddConstraint(ParseAffineExpr("s0 mod 3", &mlir_context_), - Interval{0, 0}); - + auto indexing_map = Parse(R"( + (d0)[s0, s1, s2] -> (s2, d0, s1, s0 floordiv 6), + domain: + d0 in [0, 3], + s0 in [0, 7], + s1 in [0, 1], + s2 in [0, 5], + s0 mod 2 in [0, 0], + s0 mod 3 in [0, 0], + is_simplified: false + )"); EXPECT_TRUE(indexing_map.RescaleSymbols()); EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( (d0)[s0, s1, s2] -> (s2, d0, s1, s0), @@ -1217,14 +1420,17 @@ TEST_F(IndexingMapTest, RescaleSymbols_TwoModConstraints) { } TEST_F(IndexingMapTest, RescaleSymbols_RescaledSymbolInOtherNonModConstraint) { - auto serialized_map = "(d0)[s0, s1, s2] -> (s2, d0, s1, s0)"; - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap(serialized_map, &mlir_context_), {4}, {10, 2, 6}); - indexing_map.AddConstraint(ParseAffineExpr("s0 mod 6", &mlir_context_), - Interval{3, 3}); - indexing_map.AddConstraint(ParseAffineExpr("s0 * s2", &mlir_context_), - Interval{0, 28}); - + auto indexing_map = Parse(R"( + (d0)[s0, s1, s2] -> (s2, d0, s1, s0), + domain: + d0 in [0, 3], + s0 in [0, 9], + s1 in [0, 1], + s2 in [0, 5], + s0 * s2 in [0, 28], + s0 mod 6 in [3, 3], + is_simplified: false + )"); EXPECT_TRUE(indexing_map.RescaleSymbols()); EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( (d0)[s0, s1, s2] -> (s2, d0, s1, s0 * 6 + 3), @@ -1240,14 +1446,17 @@ TEST_F(IndexingMapTest, RescaleSymbols_RescaledSymbolInOtherNonModConstraint) { TEST_F(IndexingMapTest, RescaleSymbols_TwoModConstraintsForTheSameSymbolWhichCannotBeMerged) { - auto serialized_map = "(d0)[s0, s1, s2] -> (s2, d0, s1, s0)"; - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap(serialized_map, &mlir_context_), {4}, {100, 2, 6}); - indexing_map.AddConstraint(ParseAffineExpr("s0 mod 6", &mlir_context_), - Interval{3, 3}); - indexing_map.AddConstraint(ParseAffineExpr("s0 mod 7", &mlir_context_), - Interval{5, 5}); - + auto indexing_map = Parse(R"( + (d0)[s0, s1, s2] -> (s2, d0, s1, s0), + domain: + d0 in [0, 3], + s0 in [0, 99], + s1 in [0, 1], + s2 in [0, 5], + s0 mod 6 in [3, 3], + s0 mod 7 in [5, 5], + is_simplified: false + )"); EXPECT_TRUE(indexing_map.RescaleSymbols()); const mlir::AffineExpr result3 = indexing_map.GetAffineMap().getResult(3); @@ -1274,14 +1483,17 @@ TEST_F(IndexingMapTest, } TEST_F(IndexingMapTest, RescaleSymbolsKeepsHashmapConsistent) { - auto serialized_map = "(d0)[s0, s1, s2] -> (s2, d0, s0, s0 floordiv 6)"; - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap(serialized_map, &mlir_context_), {4}, {7, 2, 6}); - indexing_map.AddConstraint(ParseAffineExpr("s0 mod 6", &mlir_context_), - Interval{0, 0}); - indexing_map.AddConstraint(ParseAffineExpr("s0 * s1", &mlir_context_), - Interval{0, 100}); - + auto indexing_map = Parse(R"( + (d0)[s0, s1, s2] -> (s2, d0, s0, s0 floordiv 6), + domain: + d0 in [0, 3], + s0 in [0, 6], + s1 in [0, 1], + s2 in [0, 5], + s0 mod 6 in [0, 0], + s0 * s1 in [0, 100], + is_simplified: false + )"); EXPECT_TRUE(indexing_map.RescaleSymbols()); for (auto& [expr, interval] : indexing_map.GetConstraints()) { @@ -1291,13 +1503,15 @@ TEST_F(IndexingMapTest, RescaleSymbolsKeepsHashmapConsistent) { } TEST_F(IndexingMapTest, RangeEvaluatorTest) { - auto serialized_map = "(d0, d1, d2, d3)[] -> (0)"; - IndexingMap indexing_map(ParseAffineMap(serialized_map, &mlir_context_), - {{Interval{0, 9}}, - {Interval{-10, -1}}, - {Interval{-1, 2}}, - {Interval{0, 0}}}, - {}, {}); + auto indexing_map = Parse(R"( + (d0, d1, d2, d3)[] -> (0), + domain: + d0 in [0, 9], + d1 in [-10, -1], + d2 in [-1, 2], + d3 in [0, 0], + is_simplified: false + )"); RangeEvaluator range_evaluator(indexing_map, &mlir_context_); mlir::AffineExpr d0, d1, d2, d3; bindDims(&mlir_context_, d0, d1, d2, d3); @@ -1936,44 +2150,64 @@ ENTRY e { TEST_F(IndexingMapTest, IndexingMapSupportsAbslHashAndEqAndNe) { auto zero_dim_map = AffineMap::get(&mlir_context_); ExpectSupportsAbslHashAndEqAndNe( - {IndexingMap::FromTensorSizes( - ParseAffineMap("(d0, d1)[s0, s1] -> (d1, d0, s1, s0)", - &mlir_context_), - {50, 60}, {70, 80}), - IndexingMap::FromTensorSizes( - ParseAffineMap("(d0, d1)[s0, s1] -> (d1 * 2, d0, s1, s0)", - &mlir_context_), - {50, 60}, {70, 80}), - IndexingMap::FromTensorSizes( - ParseAffineMap("(d0, d1)[s0, s1] -> (d1, d0, s1, s0)", - &mlir_context_), - {51, 60}, {70, 80}), - IndexingMap::FromTensorSizes( - ParseAffineMap("(d0, d1)[s0, s1] -> (d1, d0, s1, s0)", - &mlir_context_), - {50, 60}, {71, 80}), - [&] { - auto m = IndexingMap::FromTensorSizes( - ParseAffineMap("(d0, d1)[s0, s1] -> (d1, d0, s1, s0)", - &mlir_context_), - {50, 60}, {70, 80}); - m.AddConstraint(ParseAffineExpr("d0 mod 8", &mlir_context_), - Interval{0, 0}); - m.AddConstraint(ParseAffineExpr("d0 mod 16", &mlir_context_), - Interval{0, 0}); - return m; - }(), - [&] { - auto m = IndexingMap::FromTensorSizes( - ParseAffineMap("(d0, d1)[s0, s1] -> (d1, d0, s1, s0)", - &mlir_context_), - {50, 60}, {70, 80}); - m.AddConstraint(ParseAffineExpr("d0 mod 8", &mlir_context_), - Interval{0, 0}); - m.AddConstraint(ParseAffineExpr("d0 mod 32", &mlir_context_), - Interval{0, 0}); - return m; - }(), + {Parse(R"( + (d0, d1)[s0, s1] -> (d1, d0, s1, s0), + domain: + d0 in [0, 49], + d1 in [0, 59], + s0 in [0, 69], + s1 in [0, 79], + is_simplified: false + )"), + Parse(R"( + (d0, d1)[s0, s1] -> (d1 * 2, d0, s1, s0), + domain: + d0 in [0, 49], + d1 in [0, 59], + s0 in [0, 69], + s1 in [0, 79], + is_simplified: false + )"), + Parse(R"( + (d0, d1)[s0, s1] -> (d1, d0, s1, s0), + domain: + d0 in [0, 50], + d1 in [0, 59], + s0 in [0, 69], + s1 in [0, 79], + is_simplified: false + )"), + Parse(R"( + (d0, d1)[s0, s1] -> (d1, d0, s1, s0), + domain: + d0 in [0, 49], + d1 in [0, 59], + s0 in [0, 69], + s1 in [0, 79], + is_simplified: false + )"), + Parse(R"( + (d0, d1)[s0, s1] -> (d1, d0, s1, s0), + domain: + d0 in [0, 49], + d1 in [0, 59], + s0 in [0, 69], + s1 in [0, 79], + d0 mod 8 in [0, 0], + d0 mod 16 in [0, 0], + is_simplified: false + )"), + Parse(R"( + (d0, d1)[s0, s1] -> (d1, d0, s1, s0), + domain: + d0 in [0, 49], + d1 in [0, 59], + s0 in [0, 69], + s1 in [0, 79], + d0 mod 8 in [0, 0], + d0 mod 32 in [0, 0], + is_simplified: false + )"), IndexingMap( ParseAffineMap("(d0)[s0, s1, s2, s3, s4] -> (d0 * 4 + s1 + s3 - 42)", &mlir_context_), From 39ac4333985542c58baa3170609ba542e64392af Mon Sep 17 00:00:00 2001 From: Greg Olechwierowicz Date: Wed, 25 Sep 2024 02:28:30 -0700 Subject: [PATCH 240/483] [XLA:GPU] Enable auto while loop double buffering. PiperOrigin-RevId: 678606953 --- third_party/xla/xla/debug_options_flags.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/xla/xla/debug_options_flags.cc b/third_party/xla/xla/debug_options_flags.cc index 8c62942f2742e9..277e8a3f75982e 100644 --- a/third_party/xla/xla/debug_options_flags.cc +++ b/third_party/xla/xla/debug_options_flags.cc @@ -220,7 +220,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_gpu_cudnn_gemm_fusion_level(0); opts.set_xla_gpu_enable_while_loop_double_buffering(false); opts.set_xla_gpu_enable_while_loop_unrolling( - DebugOptions::WHILE_LOOP_UNROLLING_NO_UNROLL); + DebugOptions::WHILE_LOOP_UNROLLING_AUTO_UNROLL); opts.set_xla_gpu_ensure_minor_dot_contraction_dims(false); opts.set_xla_gpu_filter_kernels_spilling_registers_on_autotuning(true); opts.set_xla_gpu_llvm_verification_level(0); From 6951488f6bc6d091932a90a22f006cef9342c4a8 Mon Sep 17 00:00:00 2001 From: Greg Olechwierowicz Date: Wed, 25 Sep 2024 02:51:28 -0700 Subject: [PATCH 241/483] [XLA:GPU] Add while-loop-simplifier before while loop double buffering. It's not needed for anything in particular, but it's a nice property to work on the simplified IR in subsequent passes. PiperOrigin-RevId: 678613784 --- third_party/xla/xla/service/gpu/gpu_compiler.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/third_party/xla/xla/service/gpu/gpu_compiler.cc b/third_party/xla/xla/service/gpu/gpu_compiler.cc index 30577edad54fca..166a0f4ff44048 100644 --- a/third_party/xla/xla/service/gpu/gpu_compiler.cc +++ b/third_party/xla/xla/service/gpu/gpu_compiler.cc @@ -1102,6 +1102,7 @@ absl::Status RunPostFusionPasses( unroll_strategy = DoubleBufferLoopUnrolling::UnrollStrategy::kAuto; } if (unroll_strategy != std::nullopt) { + pipeline.AddPass(); pipeline.AddPass(*unroll_strategy); pipeline.AddPass(); pipeline.AddPass(); From fb32967ae5a0070f24295d1b0b6b1c0c64ca5a72 Mon Sep 17 00:00:00 2001 From: Alexander Lyashuk Date: Wed, 25 Sep 2024 02:58:38 -0700 Subject: [PATCH 242/483] [XLA:GPU] Factor out the newly introduced set_to_default_entry_computation_layout parser option to an options class. I'm going to introduce more options. Particularly, currently the parser also resets the layouts of all individual instructions. While it's not a big deal for most of them, for entry computation `parameter` it does matter. PiperOrigin-RevId: 678615879 --- third_party/xla/xla/service/hlo_parser.cc | 22 +++---- third_party/xla/xla/service/hlo_parser.h | 29 ++++++--- .../xla/xla/service/hlo_parser_test.cc | 62 ++++++++++++++++++- third_party/xla/xla/tools/BUILD | 2 + .../xla/xla/tools/hlo_module_loader.cc | 16 +++-- third_party/xla/xla/tools/hlo_module_loader.h | 4 +- 6 files changed, 101 insertions(+), 34 deletions(-) diff --git a/third_party/xla/xla/service/hlo_parser.cc b/third_party/xla/xla/service/hlo_parser.cc index 9befbac1b48fcb..abf31f1767d7f8 100644 --- a/third_party/xla/xla/service/hlo_parser.cc +++ b/third_party/xla/xla/service/hlo_parser.cc @@ -249,10 +249,8 @@ class HloParserImpl : public HloParser { using BoolList = absl::InlinedVector; explicit HloParserImpl(absl::string_view str, - bool set_to_default_entry_computation_layout = true) - : lexer_(str), - set_to_default_entry_computation_layout_( - set_to_default_entry_computation_layout) {} + const HloParserOptions& options = HloParserOptions()) + : lexer_(str), options_(options) {} // Runs the parser and constructs the resulting HLO in the given (empty) // HloModule. Returns the error status in case an error occurred. @@ -673,7 +671,7 @@ class HloParserImpl : public HloParser { // Used to generate names for anonymous instructions. NameUniquer name_uniquer_{/*separator=*/"."}; - const bool set_to_default_entry_computation_layout_; + const HloParserOptions options_; }; bool SplitToInt64s(absl::string_view s, char delim, std::vector* out) { @@ -917,7 +915,7 @@ bool HloParserImpl::ParseComputationLayout( } while (lexer_.GetKind() != TokKind::kRparen) { Shape param; - if (!ParseShape(¶m, set_to_default_entry_computation_layout_)) { + if (!ParseShape(¶m, options_.fill_missing_module_parameter_layouts())) { return false; } computation_layout->add_parameter_layout(ShapeLayout(param)); @@ -937,7 +935,7 @@ bool HloParserImpl::ParseComputationLayout( return false; } Shape result; - if (!ParseShape(&result, set_to_default_entry_computation_layout_)) { + if (!ParseShape(&result, options_.fill_missing_module_parameter_layouts())) { return false; } *computation_layout->mutable_result_layout() = ShapeLayout(result); @@ -6990,19 +6988,13 @@ bool HloParserImpl::ParseSingleInstruction(HloModule* module) { absl::StatusOr> ParseAndReturnUnverifiedModule( absl::string_view str, const HloModuleConfig& config, - bool set_to_default_entry_computation_layout) { + const HloParserOptions& options) { auto module = std::make_unique(/*name=*/"_", config); - HloParserImpl parser(str, set_to_default_entry_computation_layout); + HloParserImpl parser(str, options); TF_RETURN_IF_ERROR(parser.Run(module.get())); return std::move(module); } -absl::StatusOr> ParseAndReturnUnverifiedModule( - absl::string_view str, bool set_to_default_entry_computation_layout) { - return ParseAndReturnUnverifiedModule( - str, HloModuleConfig(), set_to_default_entry_computation_layout); -} - absl::StatusOr ParseSharding(absl::string_view str) { HloParserImpl parser(str); return parser.ParseShardingOnly(); diff --git a/third_party/xla/xla/service/hlo_parser.h b/third_party/xla/xla/service/hlo_parser.h index 2628c15eb00db8..c6b5f545c54cd4 100644 --- a/third_party/xla/xla/service/hlo_parser.h +++ b/third_party/xla/xla/service/hlo_parser.h @@ -24,25 +24,34 @@ limitations under the License. #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" -#include "xla/service/hlo_lexer.h" #include "xla/xla_data.pb.h" namespace xla { -// Given a string in the HloModule::ToString() format, parses the string and -// creates a HloModule with the given config. -// Note: Tests derived from HloTestBase should use -// ParseAndReturnVerifiedModule() instead! -absl::StatusOr> ParseAndReturnUnverifiedModule( - absl::string_view str, const HloModuleConfig& config, - bool set_to_default_entry_computation_layout = true); +class HloParserOptions { + public: + // If the entry computation parameter layout is not set, set the layout to be + // the default (e.g. {3,2,1,0}). + HloParserOptions& set_fill_missing_module_parameter_layouts(bool value) { + fill_missing_module_parameter_layouts_ = value; + return *this; + } + + bool fill_missing_module_parameter_layouts() const { + return fill_missing_module_parameter_layouts_; + } + + private: + bool fill_missing_module_parameter_layouts_ = true; +}; // Given a string in the HloModule::ToString() format, parses the string and -// creates a HloModule with default config. +// creates a HloModule with the given config. // Note: Tests derived from HloTestBase should use // ParseAndReturnVerifiedModule() instead! absl::StatusOr> ParseAndReturnUnverifiedModule( - absl::string_view str, bool set_to_default_entry_computation_layout = true); + absl::string_view str, const HloModuleConfig& config = HloModuleConfig(), + const HloParserOptions& options = HloParserOptions()); // Parses sharding from str. str is supposed to contain the body of the // sharding, i.e. just the rhs of the "sharding={...}" attribute string, e.g., diff --git a/third_party/xla/xla/service/hlo_parser_test.cc b/third_party/xla/xla/service/hlo_parser_test.cc index e7c52987492fe0..0035b317fb418c 100644 --- a/third_party/xla/xla/service/hlo_parser_test.cc +++ b/third_party/xla/xla/service/hlo_parser_test.cc @@ -22,6 +22,7 @@ limitations under the License. #include #include +#include #include #include "absl/log/log.h" #include "absl/status/status.h" @@ -3434,12 +3435,71 @@ ENTRY %Reduce (input: f32[8,16,256]) -> f32[8,16] { absl::StatusOr> module = ParseAndReturnUnverifiedModule( - original, /*set_to_default_entry_computation_layout=*/false); + original, {}, + HloParserOptions().set_fill_missing_module_parameter_layouts(false)); TF_ASSERT_OK(module.status()); // Do not set the default layout. EXPECT_FALSE(module.value()->entry_computation_layout().AnyLayoutSet()); } +TEST_F(HloParserTest, DoNotSetEntryComputationLayoutIfSet) { + const std::string original = R"( +HloModule layout_defined, entry_computation_layout={(f32[8,16,256]{1,2,0}) -> f32[8,16]} + +add_F32.v3 { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) +} + +ENTRY %Reduce (input: f32[8,16,256]) -> f32[8,16] { + input = f32[8,16,256]{0,1,2} parameter(0) + constant = f32[] constant(0) + ROOT reduce = f32[8,16]{0,1} reduce(input, constant), dimensions={2}, to_apply=add_F32.v3 +})"; + + absl::StatusOr> module = + ParseAndReturnUnverifiedModule( + original, {}, + HloParserOptions().set_fill_missing_module_parameter_layouts(true)); + TF_ASSERT_OK(module.status()); + EXPECT_THAT(module.value() + ->entry_computation_layout() + .parameter_layout(0) + .layout() + .minor_to_major(), + ElementsAre(1, 2, 0)); +} + +TEST_F(HloParserTest, SetEntryComputationLayoutIfNotSet) { + const std::string original = R"( +HloModule layout_defined, entry_computation_layout={(f32[8,16,256]) -> f32[8,16]} + +add_F32.v3 { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) +} + +ENTRY %Reduce (input: f32[8,16,256]) -> f32[8,16] { + input = f32[8,16,256]{0,1,2} parameter(0) + constant = f32[] constant(0) + ROOT reduce = f32[8,16]{0,1} reduce(input, constant), dimensions={2}, to_apply=add_F32.v3 +})"; + + absl::StatusOr> module = + ParseAndReturnUnverifiedModule( + original, {}, + HloParserOptions().set_fill_missing_module_parameter_layouts(true)); + TF_ASSERT_OK(module.status()); + EXPECT_THAT(module.value() + ->entry_computation_layout() + .parameter_layout(0) + .layout() + .minor_to_major(), + ElementsAre(2, 1, 0)); +} + TEST_F(HloParserTest, NoEntry) { const std::string original = R"(HloModule no_entry: c1 { diff --git a/third_party/xla/xla/tools/BUILD b/third_party/xla/xla/tools/BUILD index 06c37dddab3401..49aab5056cd884 100644 --- a/third_party/xla/xla/tools/BUILD +++ b/third_party/xla/xla/tools/BUILD @@ -477,10 +477,12 @@ cc_library( "//xla/service:hlo_parser", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_googlesource_code_re2//:re2", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:path", "@local_tsl//tsl/platform:protobuf", + "@local_tsl//tsl/platform:statusor", ], ) diff --git a/third_party/xla/xla/tools/hlo_module_loader.cc b/third_party/xla/xla/tools/hlo_module_loader.cc index f6a685435825ca..f765acaeeef8ca 100644 --- a/third_party/xla/xla/tools/hlo_module_loader.cc +++ b/third_party/xla/xla/tools/hlo_module_loader.cc @@ -26,6 +26,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "absl/strings/str_split.h" +#include "re2/re2.h" #include "xla/debug_options_flags.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -35,6 +36,7 @@ limitations under the License. #include "tsl/platform/logging.h" #include "tsl/platform/path.h" #include "tsl/platform/protobuf.h" +#include "tsl/platform/statusor.h" namespace xla { namespace { @@ -71,7 +73,7 @@ absl::StatusOr> LoadModuleFromData( const hlo_module_loader_details::Config& ovr_config, const std::function& config_modifier_hook, BufferAssignmentProto* buffer_assignment_proto, - bool set_to_default_entry_computation_layout) { + bool fill_missing_module_parameter_layouts) { DebugOptions debug_options = GetDebugOptionsFromFlags(); std::unique_ptr module; if (format == "hlo" || format == "txt") { @@ -82,9 +84,11 @@ absl::StatusOr> LoadModuleFromData( if (config_modifier_hook) { config_modifier_hook(&config); } - TF_ASSIGN_OR_RETURN(module, ParseAndReturnUnverifiedModule( - hlo_string, config, - set_to_default_entry_computation_layout)); + HloParserOptions options; + options.set_fill_missing_module_parameter_layouts( + fill_missing_module_parameter_layouts); + TF_ASSIGN_OR_RETURN( + module, ParseAndReturnUnverifiedModule(hlo_string, config, options)); } else { HloSnapshot proto; if (format == "pb") { @@ -133,7 +137,7 @@ absl::StatusOr> LoadModuleFromFile( const hlo_module_loader_details::Config& ovr_config, const std::function& config_modifier_hook, BufferAssignmentProto* buffer_assignment_proto, - bool set_to_default_entry_computation_layout) { + bool fill_missing_module_parameter_layouts) { std::string data; if (format.empty()) { format = std::string(tsl::io::Extension(path)); @@ -141,7 +145,7 @@ absl::StatusOr> LoadModuleFromFile( TF_RETURN_IF_ERROR(tsl::ReadFileToString(tsl::Env::Default(), path, &data)); return LoadModuleFromData(data, format, ovr_config, config_modifier_hook, buffer_assignment_proto, - set_to_default_entry_computation_layout); + fill_missing_module_parameter_layouts); } absl::StatusOr> diff --git a/third_party/xla/xla/tools/hlo_module_loader.h b/third_party/xla/xla/tools/hlo_module_loader.h index a841bceed512d0..9000422fd8d4f0 100644 --- a/third_party/xla/xla/tools/hlo_module_loader.h +++ b/third_party/xla/xla/tools/hlo_module_loader.h @@ -61,7 +61,7 @@ absl::StatusOr> LoadModuleFromData( hlo_module_loader_details::Config(), const std::function& config_modifier_hook = {}, BufferAssignmentProto* buffer_assignment_proto = nullptr, - bool set_to_default_entry_computation_layout = true); + bool fill_missing_module_parameter_layouts = true); // Loads an HLO module from file. // The file can be one of the followings: @@ -84,7 +84,7 @@ absl::StatusOr> LoadModuleFromFile( hlo_module_loader_details::Config(), const std::function& config_modifier_hook = {}, BufferAssignmentProto* buffer_assignment_proto = nullptr, - bool set_to_default_entry_computation_layout = true); + bool fill_missing_module_parameter_layouts = true); // Loads an HLO snapshot from a string, only for its inputs // The data format must be one of the following: From 7f1b62adc8e4ec8b5696d2140e3a6a6b7130e628 Mon Sep 17 00:00:00 2001 From: Ilya Tikhonovskiy Date: Wed, 25 Sep 2024 03:28:03 -0700 Subject: [PATCH 243/483] Enable Triton int4 support by default in XLA. It is supported for LHS, RHS with various shapes except those where the batch dim stride is 1. PiperOrigin-RevId: 678624641 --- third_party/xla/xla/debug_options_flags.cc | 2 +- ...riton_fusion_emitter_device_legacy_test.cc | 56 +++++++++++++++---- .../gpu/fusions/triton/triton_support.cc | 1 + .../gpu/transforms/gemm_fusion_test.cc | 24 +++----- 4 files changed, 55 insertions(+), 28 deletions(-) diff --git a/third_party/xla/xla/debug_options_flags.cc b/third_party/xla/xla/debug_options_flags.cc index 277e8a3f75982e..abedfc370dd83f 100644 --- a/third_party/xla/xla/debug_options_flags.cc +++ b/third_party/xla/xla/debug_options_flags.cc @@ -285,7 +285,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_gpu_cudnn_gemm_max_plans(5); - opts.set_xla_gpu_enable_triton_gemm_int4(false); + opts.set_xla_gpu_enable_triton_gemm_int4(true); opts.set_xla_gpu_enable_pgle_accuracy_checker(false); diff --git a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_legacy_test.cc b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_legacy_test.cc index e76faaba5c93b6..da369139994aaf 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_legacy_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_legacy_test.cc @@ -15,21 +15,16 @@ limitations under the License. #include #include -#include -#include #include #include #include -#include #include #include #include "absl/algorithm/container.h" -#include "absl/log/log.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/strings/substitute.h" -#include "absl/types/span.h" #include "llvm/IR/LLVMContext.h" #include "mlir/IR/MLIRContext.h" #include "mlir/Pass/PassManager.h" @@ -66,6 +61,7 @@ namespace gpu { namespace { namespace m = ::xla::match; +using tsl::testing::StatusIs; class TritonTest : public GpuCodegenTest { public: @@ -151,6 +147,44 @@ class TritonGemmTestWithoutTritonGemmAny : public TritonGemmTest { } }; +TEST_F(TritonGemmTest, RejectDotInt4HLO) { + constexpr std::string_view kHloText = R"( + HloModule t + + ENTRY main { + lhs = s4[16,32,64]{2,1,0} parameter(0) + rhs = s4[16,64,16]{2,1,0} parameter(1) + ROOT dot = s4[16,32,16]{2,1,0} dot(lhs, rhs), + lhs_contracting_dims={2}, + rhs_contracting_dims={1}, + lhs_batch_dims={0}, + rhs_batch_dims={0} + } + )"; + EXPECT_THAT(GetOptimizedModule(kHloText).status(), + StatusIs(tsl::error::INVALID_ARGUMENT)); +} + +TEST_F(TritonGemmTest, RejectInt4NegatePlusConvertHLO) { + constexpr std::string_view kHloText = R"( + HloModule t + + ENTRY main { + lhs = s4[16,32,64]{2,1,0} parameter(0) + lhs_negated = s4[16,32,64]{2,1,0} negate(lhs) + lhs_converted = bf16[16,32,64]{2,1,0} convert(lhs_negated) + rhs = bf16[16,64,16]{2,1,0} parameter(1) + ROOT dot = bf16[16,32,16]{2,1,0} dot(lhs_converted, rhs), + lhs_contracting_dims={2}, + rhs_contracting_dims={1}, + lhs_batch_dims={0}, + rhs_batch_dims={0} + } + )"; + EXPECT_THAT(GetOptimizedModule(kHloText).status(), + StatusIs(tsl::error::INVALID_ARGUMENT)); +} + TEST_F(TritonGemmTest, RejectTritonFusionForInt4WithMinorBatchDim) { constexpr std::string_view kHloText = R"( HloModule t @@ -167,7 +201,7 @@ TEST_F(TritonGemmTest, RejectTritonFusionForInt4WithMinorBatchDim) { } )"; const std::string pattern = - R"(CHECK-NOT: ""kind":"__triton_gemm","triton_gemm_config"")"; + R"(CHECK-NOT: "kind":"__triton_gemm","triton_gemm_config")"; TF_ASSERT_OK_AND_ASSIGN(auto module, GetOptimizedModule(kHloText)); TF_ASSERT_OK_AND_ASSIGN(auto ok, RunFileCheck(module->ToString(), pattern)); EXPECT_TRUE(ok); @@ -1061,9 +1095,8 @@ ENTRY entry { EXPECT_THAT( TritonWrapper("test_fn", triton_dot_fusion, CudaAmpereOrRocm(), dev_info, block_level_parameters, &llvm_module, mlir_context), - tsl::testing::StatusIs( - tsl::error::RESOURCE_EXHAUSTED, - ::testing::HasSubstr("Shared memory size limit exceeded"))); + StatusIs(tsl::error::RESOURCE_EXHAUSTED, + ::testing::HasSubstr("Shared memory size limit exceeded"))); config.set_block_m(64); config.set_block_n(128); @@ -1647,9 +1680,8 @@ ENTRY entry { EXPECT_THAT( TritonWrapper("test_fn", triton_dot_fusion, CudaAmpereOrRocm(), dev_info, block_level_parameters, &llvm_module, mlir_context), - tsl::testing::StatusIs( - tsl::error::RESOURCE_EXHAUSTED, - "Tiling complexity heuristic exceeded: 147456 > 9000")); + StatusIs(tsl::error::RESOURCE_EXHAUSTED, + "Tiling complexity heuristic exceeded: 147456 > 9000")); // Succeeds if the tiling is not too complex. config.set_block_m(32); diff --git a/third_party/xla/xla/service/gpu/fusions/triton/triton_support.cc b/third_party/xla/xla/service/gpu/fusions/triton/triton_support.cc index 0a4d5c0e9c18e5..d0a33343fa2237 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/triton_support.cc +++ b/third_party/xla/xla/service/gpu/fusions/triton/triton_support.cc @@ -44,6 +44,7 @@ bool IsTritonSupportedDataType(PrimitiveType type, const se::GpuComputeCapability& gpu_version) { switch (type) { case PRED: + case S4: case S8: case S16: case S32: diff --git a/third_party/xla/xla/service/gpu/transforms/gemm_fusion_test.cc b/third_party/xla/xla/service/gpu/transforms/gemm_fusion_test.cc index 38cf1e66bfc551..af97932ddf3a89 100644 --- a/third_party/xla/xla/service/gpu/transforms/gemm_fusion_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/gemm_fusion_test.cc @@ -1352,16 +1352,16 @@ ENTRY main { EXPECT_FALSE(result.ok()); } -constexpr auto kInt4Dot = R"( -ENTRY e { - p0 = s8[16,16] parameter(0) - p1 = s4[16,16] parameter(1) - p1c = bf16[16,16] convert(p1) - ROOT dot = bf16[16,16] dot(p0, p1c), - lhs_contracting_dims={1}, rhs_contracting_dims={0} -})"; - TEST_F(SmallDotGemmFusionTest, Int4DotIsRewritten) { + constexpr auto kInt4Dot = R"( + ENTRY e { + p0 = s8[16,16] parameter(0) + p1 = s4[16,16] parameter(1) + p1c = bf16[16,16] convert(p1) + ROOT dot = bf16[16,16] dot(p0, p1c), + lhs_contracting_dims={1}, rhs_contracting_dims={0} + } + )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(kInt4Dot)); module->mutable_config() @@ -1370,12 +1370,6 @@ TEST_F(SmallDotGemmFusionTest, Int4DotIsRewritten) { EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value()); } -TEST_F(SmallDotGemmFusionTest, Int4DotIsNotRewritten) { - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(kInt4Dot)); - EXPECT_FALSE(GemmFusion(gpu_version_).Run(module.get()).value()); -} - TEST_F(SmallDotGemmFusionTest, Int4ConcatPlusConvertIsRewritten) { const std::string kInt4Dot = R"( ENTRY main { From 8cf251b4fad0b7ff193a6d5379512e48ae252d05 Mon Sep 17 00:00:00 2001 From: Greg Olechwierowicz Date: Wed, 25 Sep 2024 03:33:09 -0700 Subject: [PATCH 244/483] [XLA:GPU][NFC] Move addition of double buffering passes to a separate function. PiperOrigin-RevId: 678626186 --- .../xla/xla/service/gpu/gpu_compiler.cc | 68 +++++++++++-------- 1 file changed, 38 insertions(+), 30 deletions(-) diff --git a/third_party/xla/xla/service/gpu/gpu_compiler.cc b/third_party/xla/xla/service/gpu/gpu_compiler.cc index 166a0f4ff44048..0cb1e5161bd363 100644 --- a/third_party/xla/xla/service/gpu/gpu_compiler.cc +++ b/third_party/xla/xla/service/gpu/gpu_compiler.cc @@ -1045,6 +1045,43 @@ absl::Status RunFusionPasses(HloModule* hlo_module, return absl::OkStatus(); } +// Adds unrolling while loop optimization. Mostly to get rid of extra D2D +// copies, but also there are some performance benefits (better comm-compute +// overlap) when collectives are present within a while loop. +void AddDoubleBufferingPasses(const DebugOptions& opts, + HloPassPipeline& pipeline) { + std::optional unroll_strategy = + std::nullopt; + // Support old flag. + if (opts.xla_gpu_enable_while_loop_double_buffering()) { + unroll_strategy = DoubleBufferLoopUnrolling::UnrollStrategy::kDoubleBuffer; + } + // Support new flag setting style, override the old one. + if (opts.xla_gpu_enable_while_loop_unrolling() == + DebugOptions::WHILE_LOOP_UNROLLING_DOUBLE_BUFFER) { + unroll_strategy = DoubleBufferLoopUnrolling::UnrollStrategy::kDoubleBuffer; + } + if (opts.xla_gpu_enable_while_loop_unrolling() == + DebugOptions::WHILE_LOOP_UNROLLING_FULL_UNROLL) { + LOG_IF(WARNING, unroll_strategy != std::nullopt) + << "Overriding double buffering set via " + "`xla_gpu_enable_while_loop_double_buffering` flag."; + unroll_strategy = DoubleBufferLoopUnrolling::UnrollStrategy::kFullUnroll; + } + if (opts.xla_gpu_enable_while_loop_unrolling() == + DebugOptions::WHILE_LOOP_UNROLLING_AUTO_UNROLL && + opts.xla_gpu_enable_heuristic_pass_configuration() && + !opts.xla_gpu_enable_while_loop_double_buffering()) { + unroll_strategy = DoubleBufferLoopUnrolling::UnrollStrategy::kAuto; + } + if (unroll_strategy != std::nullopt) { + pipeline.AddPass(); + pipeline.AddPass(*unroll_strategy); + pipeline.AddPass(); + pipeline.AddPass(); + } +} + absl::Status RunPostFusionPasses( HloModule* hlo_module, std::function @@ -1077,36 +1114,7 @@ absl::Status RunPostFusionPasses( pipeline.AddPass(blueconnect_num_devices_per_host); } - std::optional unroll_strategy = - std::nullopt; - // Support old flag. - if (opts.xla_gpu_enable_while_loop_double_buffering()) { - unroll_strategy = DoubleBufferLoopUnrolling::UnrollStrategy::kDoubleBuffer; - } - // Support new flag setting style, override the old one. - if (opts.xla_gpu_enable_while_loop_unrolling() == - DebugOptions::WHILE_LOOP_UNROLLING_DOUBLE_BUFFER) { - unroll_strategy = DoubleBufferLoopUnrolling::UnrollStrategy::kDoubleBuffer; - } - if (opts.xla_gpu_enable_while_loop_unrolling() == - DebugOptions::WHILE_LOOP_UNROLLING_FULL_UNROLL) { - LOG_IF(WARNING, unroll_strategy != std::nullopt) - << "Overriding double buffering set via " - "`xla_gpu_enable_while_loop_double_buffering` flag."; - unroll_strategy = DoubleBufferLoopUnrolling::UnrollStrategy::kFullUnroll; - } - if (opts.xla_gpu_enable_while_loop_unrolling() == - DebugOptions::WHILE_LOOP_UNROLLING_AUTO_UNROLL && - opts.xla_gpu_enable_heuristic_pass_configuration() && - !opts.xla_gpu_enable_while_loop_double_buffering()) { - unroll_strategy = DoubleBufferLoopUnrolling::UnrollStrategy::kAuto; - } - if (unroll_strategy != std::nullopt) { - pipeline.AddPass(); - pipeline.AddPass(*unroll_strategy); - pipeline.AddPass(); - pipeline.AddPass(); - } + AddDoubleBufferingPasses(opts, pipeline); return pipeline.Run(hlo_module).status(); } From b568c1f638d29fad119664a8c65c0a8ba17b1f34 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 25 Sep 2024 03:33:12 -0700 Subject: [PATCH 245/483] Automated Code Change PiperOrigin-RevId: 678626204 --- tensorflow/java/src/gen/cc/op_generator.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/java/src/gen/cc/op_generator.cc b/tensorflow/java/src/gen/cc/op_generator.cc index 35ed2d517e9241..f9c092b138fb76 100644 --- a/tensorflow/java/src/gen/cc/op_generator.cc +++ b/tensorflow/java/src/gen/cc/op_generator.cc @@ -560,7 +560,7 @@ Status OpGenerator::Run(const OpList& op_list, const string& base_package, } } } - return OkStatus(); + return absl::OkStatus(); } } // namespace java From ab7082c62ff4efe602a9ac529e9d315ab879d32f Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 25 Sep 2024 03:38:52 -0700 Subject: [PATCH 246/483] Automated Code Change PiperOrigin-RevId: 678627893 --- .../core/distributed_runtime/eager/eager_service_impl.cc | 2 +- .../distributed_runtime/eager/eager_service_impl_test.cc | 8 ++++---- .../core/distributed_runtime/eager/remote_execute_node.cc | 4 ++-- .../core/distributed_runtime/eager/remote_execute_node.h | 6 +++--- 4 files changed, 10 insertions(+), 10 deletions(-) diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc index d4a5effeb78a21..71addd352b76d5 100644 --- a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc +++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc @@ -831,7 +831,7 @@ Status EagerServiceImpl::CleanupFunction( Status EagerServiceImpl::SendTensor(const SendTensorOp& send_tensor, EagerContext* eager_context) { - tensorflow::gtl::InlinedVector tensors; + absl::InlinedVector tensors; for (const auto& tensor_proto : send_tensor.tensors()) { Tensor tensor; if (!tensor.FromProto(tensor_proto)) { diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc b/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc index 84cfe637697c70..942369e8415915 100644 --- a/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc +++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc @@ -934,7 +934,7 @@ class FunctionWithRemoteInputsTest : public EagerServiceImplTest { class TestExecuteNodeArgs : public EagerKernelArgs { public: TestExecuteNodeArgs( - gtl::InlinedVector&& tensor_args, + absl::InlinedVector&& tensor_args, std::function serialize_remote_handle) : EagerKernelArgs(std::move(tensor_args)), @@ -1095,7 +1095,7 @@ TEST_F(FunctionWithRemoteInputsTest, EagerPFLRTest) { input.set_device(local_device_); std::vector inputs = {input}; std::vector outputs; - gtl::InlinedVector tensor_args = {TensorValue()}; + absl::InlinedVector tensor_args = {TensorValue()}; TestExecuteNodeArgs args( std::move(tensor_args), [&inputs](const int i, RemoteTensorHandle* handle) -> Status { @@ -1193,7 +1193,7 @@ TEST_F(FunctionWithRemoteInputsTest, KernelAndDeviceFuncTest) { TF_ASSERT_OK(kernel->InstantiateFunc({}, node_def, nullptr, std::nullopt)); // Run MatMulFunction on remote_device. - gtl::InlinedVector input_tensors = {TensorValue()}; + absl::InlinedVector input_tensors = {TensorValue()}; RemoteTensorHandle input; input.set_op_id(1); input.set_output_num(0); @@ -1248,7 +1248,7 @@ TEST_F(FunctionWithRemoteInputsTest, KernelAndDeviceFuncAsyncTest) { TF_ASSERT_OK(kernel->InstantiateFunc({}, node_def, nullptr, std::nullopt)); // Run MatMulFunction on remote_device. - gtl::InlinedVector input_tensors = {TensorValue()}; + absl::InlinedVector input_tensors = {TensorValue()}; RemoteTensorHandle input; input.set_op_id(1); input.set_output_num(0); diff --git a/tensorflow/core/distributed_runtime/eager/remote_execute_node.cc b/tensorflow/core/distributed_runtime/eager/remote_execute_node.cc index af7e24d79b80f8..ff1cd85f34fb82 100644 --- a/tensorflow/core/distributed_runtime/eager/remote_execute_node.cc +++ b/tensorflow/core/distributed_runtime/eager/remote_execute_node.cc @@ -27,8 +27,8 @@ namespace eager { void RemoteExecuteNode::RunAsync(StatusCallback done) { auto response = std::make_shared(); - const gtl::InlinedVector& inputs = inputs_; - const gtl::InlinedVector& retvals = retvals_; + const absl::InlinedVector& inputs = inputs_; + const absl::InlinedVector& retvals = retvals_; Device* device = device_; // Filled and used only when VLOG(3) is on. diff --git a/tensorflow/core/distributed_runtime/eager/remote_execute_node.h b/tensorflow/core/distributed_runtime/eager/remote_execute_node.h index 148e58a5b008c5..e864eda87f9e6d 100644 --- a/tensorflow/core/distributed_runtime/eager/remote_execute_node.h +++ b/tensorflow/core/distributed_runtime/eager/remote_execute_node.h @@ -44,7 +44,7 @@ class RemoteExecuteNode : public AsyncRemoteExecuteNode { CancellationManager* cancellation_manager, const NodeDef& ndef, const FunctionLibraryDefinition* lib_def, - const gtl::InlinedVector& inputs, + const absl::InlinedVector& inputs, absl::Span retvals) : AsyncRemoteExecuteNode(), eager_context_(eager_context), @@ -133,8 +133,8 @@ class RemoteExecuteNode : public AsyncRemoteExecuteNode { CancellationManager* cancellation_manager_; const NodeDef ndef_; const FunctionLibraryDefinition* lib_def_; - gtl::InlinedVector inputs_; - gtl::InlinedVector retvals_; + absl::InlinedVector inputs_; + absl::InlinedVector retvals_; }; } // namespace eager From 7df30f9ce796fcccaf351e8be7a77618cc0a1c47 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 25 Sep 2024 04:08:54 -0700 Subject: [PATCH 247/483] Automated Code Change PiperOrigin-RevId: 678637161 --- tensorflow/cc/tools/freeze_saved_model.cc | 10 +++++----- tensorflow/cc/tools/freeze_saved_model_test.cc | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/tensorflow/cc/tools/freeze_saved_model.cc b/tensorflow/cc/tools/freeze_saved_model.cc index 8a15850edd839f..c23f9161a448fd 100644 --- a/tensorflow/cc/tools/freeze_saved_model.cc +++ b/tensorflow/cc/tools/freeze_saved_model.cc @@ -133,7 +133,7 @@ Status GetVariableNameToTensorMap( std::unordered_set variable_names_set, std::unordered_map* variable_name_to_value_map) { if (variable_names_set.empty()) { - return OkStatus(); + return absl::OkStatus(); } std::vector variable_names; variable_names.reserve(variable_names_set.size()); @@ -156,7 +156,7 @@ Status GetVariableNameToTensorMap( for (size_t i = 0; i < variable_names.size(); i++) { (*variable_name_to_value_map)[variable_names[i]] = outputs[i]; } - return OkStatus(); + return absl::OkStatus(); } // Converts a Variable NodeDef into a Constant NodeDef. @@ -229,7 +229,7 @@ Status FreezeGraphDef(const SavedModelBundle& saved_model_bundle, *frozen_graph_def->mutable_library() = graph_def.library(); // If the graph is empty there is nothing left to do. if (graph_def.node_size() == 0) { - return OkStatus(); + return absl::OkStatus(); } // name_to_node_map is needed to get the inputs from the NodeDef corresponding // the a string node name. These inputs are used when doing our backwards @@ -277,7 +277,7 @@ Status FreezeGraphDef(const SavedModelBundle& saved_model_bundle, // If the node isn't a variable, just copy the node as-is. *frozen_graph_def->add_node() = node; } - return OkStatus(); + return absl::OkStatus(); } } // namespace @@ -289,7 +289,7 @@ Status FreezeSavedModel(const SavedModelBundle& saved_model_bundle, GetSignatureDefsInputsAndOutputs(saved_model_bundle, inputs, outputs); TF_RETURN_IF_ERROR( FreezeGraphDef(saved_model_bundle, *outputs, frozen_graph_def)); - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/cc/tools/freeze_saved_model_test.cc b/tensorflow/cc/tools/freeze_saved_model_test.cc index eb4ef40b8927f6..a64aab9e0bb5f5 100644 --- a/tensorflow/cc/tools/freeze_saved_model_test.cc +++ b/tensorflow/cc/tools/freeze_saved_model_test.cc @@ -81,7 +81,7 @@ class FreezeTest : public ::testing::Test { return saved_model_bundle->session->Run( /* inputs */ {}, /* output_tensors */ {}, {init_node}, &outputs); } - return OkStatus(); + return absl::OkStatus(); } // Adds `graph_def` to `saved_model_bundle` and initializes a session with From 62b9f83b108d3d98672b52eab8cccf9b8918cb7c Mon Sep 17 00:00:00 2001 From: Alexander Belyaev Date: Wed, 25 Sep 2024 04:08:56 -0700 Subject: [PATCH 248/483] [XLA:GPU][IndexAnalysis] Unify parsers for IndexingMap and IndexingMapAttr. Unfortunately, MLIR does not support multiline string attributes right now, so the lit tests don't look as pretty as before. PiperOrigin-RevId: 678637171 --- .../xla/xla/service/gpu/fusions/ir/BUILD | 2 +- .../service/gpu/fusions/ir/tests/attrs.mlir | 110 ++++----- .../gpu/fusions/ir/tests/canonicalize.mlir | 116 +++++----- .../service/gpu/fusions/ir/tests/invalid.mlir | 147 +++--------- .../xla/service/gpu/fusions/ir/tests/ops.mlir | 70 +++--- .../service/gpu/fusions/ir/xla_gpu_attrs.cc | 164 +++---------- .../mlir/elemental_hlo_to_mlir_test.cc | 98 ++++---- .../fusions/tests/concatenate/concat_1d.hlo | 8 +- .../tests/loop/tuple_heterogeneous.hlo | 4 +- .../fusions/tests/scatter/unique_indices.hlo | 4 +- .../transforms/tests/flatten_tensors.mlir | 25 +- .../fusions/transforms/tests/fuse_loops.mlir | 218 +++++++++--------- .../tests/lower_xla_gpu_loops_to_scf.mlir | 18 +- .../tests/lower_xla_gpu_to_scf.mlir | 55 ++--- .../transforms/tests/optimize_loops.mlir | 16 +- .../fusions/transforms/tests/peel_loops.mlir | 42 ++-- .../transforms/tests/rewrite_reductions.mlir | 6 +- .../transforms/tests/simplify_affine.mlir | 19 +- .../transforms/tests/simplify_arith.mlir | 27 ++- .../tests/vectorize_loads_stores.mlir | 72 +++--- .../triton_fusion_emitter_device_test.cc | 14 +- 21 files changed, 511 insertions(+), 724 deletions(-) diff --git a/third_party/xla/xla/service/gpu/fusions/ir/BUILD b/third_party/xla/xla/service/gpu/fusions/ir/BUILD index 0ef48df7fd1741..1d60d912a8345f 100644 --- a/third_party/xla/xla/service/gpu/fusions/ir/BUILD +++ b/third_party/xla/xla/service/gpu/fusions/ir/BUILD @@ -135,7 +135,7 @@ cc_library( ":xla_gpu_ops_inc_gen", ":xla_gpu_types_inc_gen", "//xla/service/gpu/model:indexing_analysis", - "@com_google_absl//absl/strings:str_format", + "//xla/service/gpu/model:indexing_map_serialization", "@llvm-project//llvm:Support", "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:BytecodeOpInterface", diff --git a/third_party/xla/xla/service/gpu/fusions/ir/tests/attrs.mlir b/third_party/xla/xla/service/gpu/fusions/ir/tests/attrs.mlir index bc37a3ac56fc7c..b990103ea2cfab 100644 --- a/third_party/xla/xla/service/gpu/fusions/ir/tests/attrs.mlir +++ b/third_party/xla/xla/service/gpu/fusions/ir/tests/attrs.mlir @@ -9,17 +9,17 @@ // CHECK-SAME: s0 in [0, 32], // CHECK-SAME: d0 + s0 in [1, 10], // CHECK-SAME: d0 mod 2 in [0, 1], -// CHECK-SAME: is_simplified: true +// CHECK-SAME: is_simplified: true" // CHECK-SAME: > -#map = #xla_gpu.indexing_map<(d0, d1, d2)[s0] -> (d0), - domain: - d0 in [1, 2], - d1 in [5, 8], - d2 in [10, 12], - s0 in [0, 32], - d0 mod 2 in [0, 1], - d0 + s0 in [1, 10], - is_simplified: true +#map = #xla_gpu.indexing_map<"(d0, d1, d2)[s0] -> (d0)," + "domain:" + "d0 in [1, 2]," + "d1 in [5, 8]," + "d2 in [10, 12]," + "s0 in [0, 32]," + "d0 mod 2 in [0, 1]," + "d0 + s0 in [1, 10]," + "is_simplified: true" > func.func private @indexing_map_attr(!xla_gpu.indexed_vector<64x64x32xf64, #map>) @@ -39,20 +39,21 @@ func.func private @indexing_map_attr(!xla_gpu.indexed_vector<64x64x32xf64, #map> // CHECK-SAME: d0 + s0 in [1, 10] // CHECK-SAME: d0 mod 2 in [0, 1] // CHECK-SAME: d1 + s1 + s2 in [1, 32] -// CHECK-SAME: is_simplified: false +// CHECK-SAME: is_simplified: false" // CHECK-SAME: > -#map = #xla_gpu.indexing_map<(d0, d1)[s0, s1, s2] -> (d0 + s0, d1 + s1, d1 + s2), - domain: - d0 in [1, 2], - d1 in [5, 8], - s0 in [0, 10], - s1 in [0, 5], - s2 in [0, 32], - d0 mod 2 in [0, 1], - d0 + s0 in [1, 10], - d1 + s1 + s2 in [1, 32], - is_simplified: false - > +#map = #xla_gpu.indexing_map< + "(d0, d1)[s0, s1, s2] -> (d0 + s0, d1 + s1, d1 + s2)," + "domain:" + "d0 in [1, 2]," + "d1 in [5, 8]," + "s0 in [0, 10]," + "s1 in [0, 5]," + "s2 in [0, 32]," + "d0 mod 2 in [0, 1]," + "d0 + s0 in [1, 10]," + "d1 + s1 + s2 in [1, 32]," + "is_simplified: false" + > func.func private @more_range_vars(!xla_gpu.indexed_vector<100x32xf64, #map>) // CHECK-LABEL: @more_range_vars // CHECK: !xla_gpu.indexed_vector<100x32xf64, #[[$INDEX_MAP]]> @@ -64,13 +65,13 @@ func.func private @more_range_vars(!xla_gpu.indexed_vector<100x32xf64, #map>) // CHECK-SAME: domain: // CHECK-SAME: d0 in [0, 100] // CHECK-SAME: s0 in [-3, -1] -// CHECK-SAME: is_simplified: false +// CHECK-SAME: is_simplified: false" // CHECK-SAME: > -#map = #xla_gpu.indexing_map<(d0)[s0] -> (d0), - domain: - d0 in [0, 100], - s0 in [-3, -1], - is_simplified: false +#map = #xla_gpu.indexing_map<"(d0)[s0] -> (d0)," + "domain:" + "d0 in [0, 100]," + "s0 in [-3, -1]," + "is_simplified: false" > func.func private @indexing_map_small(!xla_gpu.indexed_vector<100xf64, #map>) // CHECK-LABEL: @indexing_map_small @@ -85,15 +86,15 @@ func.func private @indexing_map_small(!xla_gpu.indexed_vector<100xf64, #map>) // CHECK-SAME: d1 in [5, 8] // CHECK-SAME: d2 in [10, 12] // CHECK-SAME: s0 in [0, 32] -// CHECK-SAME: is_simplified: false +// CHECK-SAME: is_simplified: false" // CHECK-SAME: > -#map = #xla_gpu.indexing_map<(d0, d1, d2)[s0] -> (d0), - domain: - d0 in [1, 2], - d1 in [5, 8], - d2 in [10, 12], - s0 in [0, 32], - is_simplified: false +#map = #xla_gpu.indexing_map<"(d0, d1, d2)[s0] -> (d0)," + "domain:" + "d0 in [1, 2]," + "d1 in [5, 8]," + "d2 in [10, 12]," + "s0 in [0, 32]," + "is_simplified: false" > func.func private @no_constraints(!xla_gpu.indexed_vector<32xf64, #map>) // CHECK-LABEL: @no_constraints @@ -106,13 +107,13 @@ func.func private @no_constraints(!xla_gpu.indexed_vector<32xf64, #map>) // CHECK-SAME: domain: // CHECK-SAME: s0 in [3, 5] // CHECK-SAME: s0 mod 2 in [0, 1] -// CHECK-SAME: is_simplified: false +// CHECK-SAME: is_simplified: false" // CHECK-SAME: > -#map = #xla_gpu.indexing_map<()[s0] -> (s0), - domain: - s0 in [3, 5], - s0 mod 2 in [0, 1], - is_simplified: false +#map = #xla_gpu.indexing_map<"()[s0] -> (s0)," + "domain:" + "s0 in [3, 5]," + "s0 mod 2 in [0, 1]," + "is_simplified: false" > func.func private @no_dimensions(!xla_gpu.indexed_vector<100xf64, #map>) // CHECK-LABEL: @no_dimensions @@ -125,13 +126,13 @@ func.func private @no_dimensions(!xla_gpu.indexed_vector<100xf64, #map>) // CHECK-SAME: domain: // CHECK-SAME: d0 in [3, 5] // CHECK-SAME: d0 mod 2 in [0, 1] -// CHECK-SAME: is_simplified: false +// CHECK-SAME: is_simplified: false" // CHECK-SAME: > -#map = #xla_gpu.indexing_map<(d0) -> (d0), - domain: - d0 in [3, 5], - d0 mod 2 in [0, 1], - is_simplified: false +#map = #xla_gpu.indexing_map<"(d0) -> (d0)," + "domain:" + "d0 in [3, 5]," + "d0 mod 2 in [0, 1]," + "is_simplified: false" > func.func private @no_symbols(!xla_gpu.indexed_vector<100xf64, #map>) // CHECK-LABEL: @no_symbols @@ -142,7 +143,7 @@ func.func private @no_symbols(!xla_gpu.indexed_vector<100xf64, #map>) // CHECK: #[[$INDEX_MAP:.*]] = #xla_gpu.indexing_map< // CHECK-SAME: () -> () // CHECK-SAME: > -#map = #xla_gpu.indexing_map<() -> ()> +#map = #xla_gpu.indexing_map<"() -> ()"> func.func private @empty(!xla_gpu.indexed_vector<100xf64, #map>) // CHECK-LABEL: @empty // CHECK: !xla_gpu.indexed_vector<100xf64, #[[$INDEX_MAP]]> @@ -151,7 +152,8 @@ func.func private @empty(!xla_gpu.indexed_vector<100xf64, #map>) func.func private @tensor_layout( %in0: tensor<42xf32, #xla_gpu.layout<"shmem", - (d0) -> (), domain: d0 in [0, 42], is_simplified: true>>) -// CHECK: #layout = #xla_gpu.layout<"shmem", (d0) -> (), -// CHECK-SAME: domain: d0 in [0, 42], is_simplified: true> -// CHECK: tensor<42xf32, #layout> \ No newline at end of file + "(d0) -> ()," + "domain: d0 in [0, 42], is_simplified: true">>) +// CHECK: #layout = #xla_gpu.layout<"shmem", "(d0) -> (), +// CHECK-SAME: domain: d0 in [0, 42], is_simplified: true"> +// CHECK: tensor<42xf32, #layout> diff --git a/third_party/xla/xla/service/gpu/fusions/ir/tests/canonicalize.mlir b/third_party/xla/xla/service/gpu/fusions/ir/tests/canonicalize.mlir index 495456a5ab36d4..bfca90e5c64f53 100644 --- a/third_party/xla/xla/service/gpu/fusions/ir/tests/canonicalize.mlir +++ b/third_party/xla/xla/service/gpu/fusions/ir/tests/canonicalize.mlir @@ -1,15 +1,13 @@ // RUN: mlir_fusions_opt %s --split-input-file -canonicalize | FileCheck %s -#map0 = #xla_gpu.indexing_map<()[s0, s1] -> (1 + s0 + s1 mod 3 - s1, s0 mod 2), - domain: s0 in [-10, 10], s1 in [0, 2], - is_simplified: false> +#map0 = #xla_gpu.indexing_map<"()[s0, s1] -> (1 + s0 + s1 mod 3 - s1, s0 mod 2), domain: s0 in [-10, 10], s1 in [0, 2], is_simplified: false"> func.func @simplify_apply_indexing(%s0: index, %s1: index) -> (index, index) { %0:2 = xla_gpu.apply_indexing #map0 [%s0, %s1] func.return %0#0, %0#1 : index, index } -// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0) -> (d0 + 1, d0 mod 2), +// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<"(d0) -> (d0 + 1, d0 mod 2), // CHECK-SAME: domain: d0 in [-10, 10] -// CHECK-SAME: is_simplified: true> +// CHECK-SAME: is_simplified: true"> // CHECK-LABEL: func.func @simplify_apply_indexing // CHECK-SAME: %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index) @@ -17,14 +15,13 @@ func.func @simplify_apply_indexing(%s0: index, %s1: index) -> (index, index) { // ----- -#map0 = #xla_gpu.indexing_map<(d0, d1, d2)[s0, s1] -> (1 + s0 + s1 mod 4 - s1, s0 mod 2, d0 + d2), - domain: d0 in [0, 1], d1 in [0, 2], d2 in [0, 3], s0 in [-11, 11], s1 in [0, 3], is_simplified: false> +#map0 = #xla_gpu.indexing_map<"(d0, d1, d2)[s0, s1] -> (1 + s0 + s1 mod 4 - s1, s0 mod 2, d0 + d2), domain: d0 in [0, 1], d1 in [0, 2], d2 in [0, 3], s0 in [-11, 11], s1 in [0, 3], is_simplified: false"> func.func @simplify_apply_indexing_remove_dims(%d0: index, %d1: index, %d2: index, %s0: index, %s1: index) -> (index, index, index) { %0:3 = xla_gpu.apply_indexing #map0(%d0, %d1, %d2)[%s0, %s1] func.return %0#0, %0#1, %0#2 : index, index, index } -// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0, d1, d2) -> (d2 + 1, d2 mod 2, d0 + d1), +// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<"(d0, d1, d2) -> (d2 + 1, d2 mod 2, d0 + d1), // CHECK-SAME: domain: d0 in [0, 1], d1 in [0, 3], d2 in [-11, 11] // CHECK-LABEL: func.func @simplify_apply_indexing_remove_dims @@ -38,23 +35,22 @@ func.func @simplify_apply_indexing_remove_dims(%d0: index, %d1: index, // ----- -#map0 = #xla_gpu.indexing_map<(d0) -> (d0 mod 10), domain: d0 in [0, 9], is_simplified: true> +#map0 = #xla_gpu.indexing_map<"(d0) -> (d0 mod 10), domain: d0 in [0, 9], is_simplified: true"> func.func @do_not_simplify_if_is_simplified_is_true(%d0: index) -> (index) { %0 = xla_gpu.apply_indexing #map0(%d0) func.return %0 : index } -// CHECK: #xla_gpu.indexing_map<(d0) -> (d0 mod 10) +// CHECK: #xla_gpu.indexing_map<"(d0) -> (d0 mod 10) // ----- -#map0 = #xla_gpu.indexing_map<(d0, d1)[s0] -> (d0 + s0, 4, d1, 1, s0), - domain: d0 in [-10, 10], d1 in [0, 2], s0 in [-1, 1], is_simplified: false> +#map0 = #xla_gpu.indexing_map<"(d0, d1)[s0] -> (d0 + s0, 4, d1, 1, s0), domain: d0 in [-10, 10], d1 in [0, 2], s0 in [-1, 1], is_simplified: false"> func.func @fold_indexing_map_results(%d0: index, %d1: index, %s0: index) -> (index, index, index, index, index) { %0:5 = xla_gpu.apply_indexing #map0 (%d0, %d1)[%s0] func.return %0#0, %0#1, %0#2, %0#3, %0#4 : index, index, index, index, index } -// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0, d1) -> (d0 + d1), +// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<"(d0, d1) -> (d0 + d1), // CHECK-LABEL: func.func @fold_indexing_map_results // CHECK-SAME: %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index, %[[ARG_2:.*]]: index) @@ -67,13 +63,13 @@ func.func @fold_indexing_map_results(%d0: index, %d1: index, %s0: index) // ----- -#map0 = #xla_gpu.indexing_map<(d0, d1)[s0] -> (d0 + s0, s0 + 4, d1 mod 2, 1 + d1, s0), - domain: d0 in [-10, 10], d1 in [0, 2], s0 in [-1, 1], is_simplified: false> +#map0 = #xla_gpu.indexing_map<"(d0, d1)[s0] -> (d0 + s0, s0 + 4, d1 mod 2, 1 + d1, s0)," + "domain: d0 in [-10, 10], d1 in [0, 2], s0 in [-1, 1], is_simplified: false"> func.func @remove_unused_results(%d0: index, %d1: index, %s0: index) -> (index) { %0:5 = xla_gpu.apply_indexing #map0 (%d0, %d1)[%s0] func.return %0#2 : index } -// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0) -> (d0 mod 2), +// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<"(d0) -> (d0 mod 2), // CHECK-SAME: domain: d0 in [0, 2] // CHECK-LABEL: func.func @remove_unused_results @@ -84,8 +80,9 @@ func.func @remove_unused_results(%d0: index, %d1: index, %s0: index) -> (index) // ----- -#map0 = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> (d0 + d1 + s0 + s1 mod 3), - domain: d0 in [0, 10], d1 in [0, 5], s0 in [-10, 10], s1 in [0, 4], is_simplified: false> +#map0 = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d0 + d1 + s0 + s1 mod 3)," + "domain: d0 in [0, 10], d1 in [0, 5], s0 in [-10, 10], s1 in [0, 4]," + "is_simplified: false"> func.func @fold_operands(%d0: index) -> index { %d1 = arith.constant 1 : index %s0 = arith.constant 2 : index @@ -93,7 +90,7 @@ func.func @fold_operands(%d0: index) -> index { %0 = xla_gpu.apply_indexing #map0 (%d0, %d1)[%s0, %s1] func.return %0 : index } -// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0) -> (d0 + 3), +// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<"(d0) -> (d0 + 3), // CHECK-SAME: domain: d0 in [0, 10] // CHECK-LABEL: func.func @fold_operands @@ -104,8 +101,8 @@ func.func @fold_operands(%d0: index) -> index { func.func @fold_operands_and_results(%arg0: index, %arg1: index) -> (index, index) { - %0:2 = xla_gpu.apply_indexing #xla_gpu.indexing_map<(d0, d1) -> (0, d1), - domain: d0 in [0, 4], d1 in [0, 5], is_simplified: false>(%arg0, %arg1) + %0:2 = xla_gpu.apply_indexing #xla_gpu.indexing_map<"(d0, d1) -> (0, d1)," + "domain: d0 in [0, 4], d1 in [0, 5], is_simplified: false">(%arg0, %arg1) return %0#0, %0#1 : index, index } @@ -117,14 +114,15 @@ func.func @fold_operands_and_results(%arg0: index, %arg1: index) // ----- func.func @fold_sequence(%arg0: index, %arg1: index) -> index { - %0 = xla_gpu.apply_indexing #xla_gpu.indexing_map<(d0, d1) -> (d0 + d1), - domain: d0 in [0, 5], d1 in [0, 4], is_simplified: false>(%arg0, %arg1) - %1 = xla_gpu.apply_indexing #xla_gpu.indexing_map<(d0) -> (d0 mod 100 + 42), - domain: d0 in [0, 10000], is_simplified: false>(%0) + %0 = xla_gpu.apply_indexing #xla_gpu.indexing_map< + "(d0, d1) -> (d0 + d1), domain: d0 in [0, 5], d1 in [0, 4]," + "is_simplified: false">(%arg0, %arg1) + %1 = xla_gpu.apply_indexing #xla_gpu.indexing_map<"(d0) -> (d0 mod 100 + 42)," + "domain: d0 in [0, 10000], is_simplified: false">(%0) func.return %1 : index } -// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0, d1) -> (d0 + d1 + 42), +// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<"(d0, d1) -> (d0 + d1 + 42), // CHECK-SAME: domain: d0 in [0, 5], d1 in [0, 4] // CHECK-LABEL: func.func @fold_sequence // CHECK-SAME: %[[ARG0:.*]]: index, %[[ARG1:.*]]: index) @@ -134,14 +132,15 @@ func.func @fold_sequence(%arg0: index, %arg1: index) -> index { // ----- func.func @fold_sequence_sym(%arg0: index, %arg1: index) -> index { - %0 = xla_gpu.apply_indexing #xla_gpu.indexing_map<(d0, d1) -> (d0 + d1), - domain: d0 in [0, 5], d1 in [0, 4], is_simplified: false>(%arg0, %arg1) - %1 = xla_gpu.apply_indexing #xla_gpu.indexing_map<()[s0] -> (s0 mod 100 + 42), - domain: s0 in [0, 10000], is_simplified: false>(%0) + %0 = xla_gpu.apply_indexing #xla_gpu.indexing_map<"(d0, d1) -> (d0 + d1), " + "domain: d0 in [0, 5], d1 in [0, 4], is_simplified: false">(%arg0, %arg1) + %1 = xla_gpu.apply_indexing #xla_gpu.indexing_map< + "()[s0] -> (s0 mod 100 + 42), domain: s0 in [0, 10000]," + "is_simplified: false">(%0) func.return %1 : index } -// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0, d1) -> (d0 + d1 + 42), +// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<"(d0, d1) -> (d0 + d1 + 42), // CHECK-SAME: domain: d0 in [0, 5], d1 in [0, 4] // CHECK-LABEL: func.func @fold_sequence_sym // CHECK-SAME: %[[ARG0:.*]]: index, %[[ARG1:.*]]: index) @@ -150,12 +149,11 @@ func.func @fold_sequence_sym(%arg0: index, %arg1: index) -> index { // ----- -#indexing_map1 = #xla_gpu.indexing_map<(d0, d1) -> (d1 * 2 + d0 + 8512), - domain: d0 in [0, 1], d1 in [0, 607], is_simplified: false> -#indexing_map2 = #xla_gpu.indexing_map< - (d0, d1, d2) -> (((d1 floordiv 32 + 1) mod 3) * 64 - + (d1 mod 32) * 2 + (d0 floordiv 192) * 192 + d2), - domain: d0 in [0, 9407], d1 in [0, 607], d2 in [0, 1], is_simplified: false> +#indexing_map1 = #xla_gpu.indexing_map<"(d0, d1) -> (d1 * 2 + d0 + 8512)," + "domain: d0 in [0, 1], d1 in [0, 607], is_simplified: false"> +#indexing_map2 = #xla_gpu.indexing_map<"(d0, d1, d2) -> (" + "((d1 floordiv 32 + 1) mod 3) * 64 + (d1 mod 32) * 2 + (d0 floordiv 192) * 192 + d2)," + "domain: d0 in [0, 9407], d1 in [0, 607], d2 in [0, 1], is_simplified: false"> func.func @fold_sequence_no_simplification_needed(%i: index) -> index { %thread_id_x = gpu.thread_id x {xla.range = [0 : index, 607 : index]} @@ -168,12 +166,12 @@ func.func @fold_sequence_no_simplification_needed(%i: index) -> index { // ----- -#indexing_map1 = #xla_gpu.indexing_map<(d0) -> (3 * d0), - domain: d0 in [0, 9407], is_simplified: false> -#indexing_map2 = #xla_gpu.indexing_map<(d0, d1, d2) -> (d0 floordiv 32 + 1), - domain: d0 in [0, 9407], d1 in [0, 607], d2 in [0, 1], is_simplified: false> -#indexing_map3 = #xla_gpu.indexing_map<(d0, d1, d2) -> (d0 floordiv 32 + 2), - domain: d0 in [0, 9407], d1 in [0, 607], d2 in [0, 1], is_simplified: false> +#indexing_map1 = #xla_gpu.indexing_map< + "(d0) -> (3 * d0), domain: d0 in [0, 9407], is_simplified: false"> +#indexing_map2 = #xla_gpu.indexing_map<"(d0, d1, d2) -> (d0 floordiv 32 + 1)," + "domain: d0 in [0, 9407], d1 in [0, 607], d2 in [0, 1], is_simplified: false"> +#indexing_map3 = #xla_gpu.indexing_map<"(d0, d1, d2) -> (d0 floordiv 32 + 2)," + "domain: d0 in [0, 9407], d1 in [0, 607], d2 in [0, 1], is_simplified: false"> func.func @no_fold_when_producer_has_two_users(%i: index) -> (index, index) { %thread_id_x = gpu.thread_id x {xla.range = [0 : index, 607 : index]} @@ -187,14 +185,14 @@ func.func @no_fold_when_producer_has_two_users(%i: index) -> (index, index) { // ----- func.func @fold_sequence_shared_operands(%arg0: index, %arg1: index) -> index { - %0 = xla_gpu.apply_indexing #xla_gpu.indexing_map<(d0, d1) -> (d0 + d1), - domain: d0 in [0, 5], d1 in [0, 4], is_simplified: false>(%arg0, %arg1) - %1 = xla_gpu.apply_indexing #xla_gpu.indexing_map<(d0, d1) -> (d0 + d1), - domain: d0 in [0, 4], d1 in [0, 10000], is_simplified: false>(%arg1, %0) + %0 = xla_gpu.apply_indexing #xla_gpu.indexing_map<"(d0, d1) -> (d0 + d1)," + "domain: d0 in [0, 5], d1 in [0, 4], is_simplified: false">(%arg0, %arg1) + %1 = xla_gpu.apply_indexing #xla_gpu.indexing_map<"(d0, d1) -> (d0 + d1)," + "domain: d0 in [0, 4], d1 in [0, 10000], is_simplified: false">(%arg1, %0) func.return %1 : index } -// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0, d1) -> (d0 * 2 + d1), +// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<"(d0, d1) -> (d0 * 2 + d1), // CHECK-SAME: domain: d0 in [0, 4], d1 in [0, 5] // CHECK-LABEL: func.func @fold_sequence_shared_operands // CHECK-SAME: %[[ARG0:.*]]: index, %[[ARG1:.*]]: index) @@ -235,15 +233,15 @@ func.func @atomic_rmw_cst(%in: tensor<2x3xf32>, %i: index, %j: index) // ----- -#map0 = #xla_gpu.indexing_map<(d0)[s0] -> (2 * d0 * s0), - domain: d0 in [0, 3], s0 in [0, 2], is_simplified: false> +#map0 = #xla_gpu.indexing_map<"(d0)[s0] -> (2 * d0 * s0)," + "domain: d0 in [0, 3], s0 in [0, 2], is_simplified: false"> func.func @apply_indexing_move_syms_to_dims(%dim0: index, %sym0: index) -> index { %0 = xla_gpu.apply_indexing #map0(%dim0)[%sym0] func.return %0 : index } -// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0, d1) -> ((d0 * d1) * 2), +// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<"(d0, d1) -> ((d0 * d1) * 2), // CHECK-SAME: domain: d0 in [0, 3], d1 in [0, 2] // CHECK-LABEL: func.func @apply_indexing_move_syms_to_dims // CHECK-NEXT: xla_gpu.apply_indexing #[[$MAP]] @@ -251,8 +249,10 @@ func.func @apply_indexing_move_syms_to_dims(%dim0: index, %sym0: index) // // ----- -#map0 = #xla_gpu.indexing_map<(d0) -> (4 * d0), domain: d0 in [0, 3], is_simplified: false> -#map1 = #xla_gpu.indexing_map<(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 12], s0 in [0, 1024], s1 in [0, 32], is_simplified: false> +#map0 = #xla_gpu.indexing_map<"(d0) -> (4 * d0), domain: d0 in [0, 3]," + "is_simplified: false"> +#map1 = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1)," + "domain: d0 in [0, 12], s0 in [0, 1024], s1 in [0, 32], is_simplified: false"> func.func @loop_of_apply_indexing(%input: tensor<1024x32xf32>, %init: f32, %dim: index) -> (f32) { %idx = xla_gpu.apply_indexing #map0(%dim) %sum = xla_gpu.loop (%idx)[%i, %j] -> (%r0, %r1) in #map1 iter_args(%sum_ = %init) -> (f32) { @@ -263,7 +263,7 @@ func.func @loop_of_apply_indexing(%input: tensor<1024x32xf32>, %init: f32, %dim: func.return %sum : f32 } -// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0)[s0, s1] -> (d0 * 4 + s0, s1), +// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 * 4 + s0, s1), // CHECK-SAME: domain: d0 in [0, 3], s0 in [0, 1024], s1 in [0, 32] // CHECK-LABEL: func.func @loop_of_apply_indexing // CHECK-SAME: %[[ARG0:.*]]: tensor<1024x32xf32>, %[[ARG1:.*]]: f32, %[[ARG2:.*]]: index) @@ -272,8 +272,10 @@ func.func @loop_of_apply_indexing(%input: tensor<1024x32xf32>, %init: f32, %dim: // ----- -#map0 = #xla_gpu.indexing_map<(d0)[s0] -> (2 * d0 * s0), domain: d0 in [0, 3], s0 in [0, 2], is_simplified: false> -#map1 = #xla_gpu.indexing_map<(d0)[s0, s1] -> (d0 + s0 + s1), domain: d0 in [0, 12], s0 in [0, 1024], s1 in [0, 32], is_simplified: false> +#map0 = #xla_gpu.indexing_map<"(d0)[s0] -> (2 * d0 * s0)," + "domain: d0 in [0, 3], s0 in [0, 2], is_simplified: false"> +#map1 = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0 + s1)," + "domain: d0 in [0, 12], s0 in [0, 1024], s1 in [0, 32], is_simplified: false"> func.func @loop_of_apply_indexing_with_syms(%dim0: index, %sym0: index, %input: tensor<1024x32xf32>, %init: f32) -> (f32) { %0 = xla_gpu.apply_indexing #map0(%dim0)[%sym0] %sum = xla_gpu.loop (%0)[%i, %j] -> (%r0) in #map1 iter_args(%sum_ = %init) -> (f32) { @@ -284,7 +286,7 @@ func.func @loop_of_apply_indexing_with_syms(%dim0: index, %sym0: index, %input: func.return %sum : f32 } -// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> ((d0 * d1) * 2 + s0 + s1), +// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> ((d0 * d1) * 2 + s0 + s1), // CHECK-SAME: domain: d0 in [0, 3], d1 in [0, 2], s0 in [0, 1024], s1 in [0, 32] // CHECK-LABEL: func.func @loop_of_apply_indexing_with_syms // CHECK-SAME: %[[ARG0:.*]]: index, %[[ARG1:.*]]: index diff --git a/third_party/xla/xla/service/gpu/fusions/ir/tests/invalid.mlir b/third_party/xla/xla/service/gpu/fusions/ir/tests/invalid.mlir index 922b3f3bbfff0e..3c50b5afcd8068 100644 --- a/third_party/xla/xla/service/gpu/fusions/ir/tests/invalid.mlir +++ b/third_party/xla/xla/service/gpu/fusions/ir/tests/invalid.mlir @@ -1,13 +1,6 @@ // RUN: mlir_fusions_opt %s -split-input-file -verify-diagnostics -#map0 = #xla_gpu.indexing_map< - (d0, d1)[s0] -> (d0, d1 + s0), - domain: - d0 in [1, 2], - d1 in [5, 8], - s0 in [0, 32], - is_simplified: false -> +#map0 = #xla_gpu.indexing_map<"(d0, d1)[s0] -> (d0, d1 + s0), domain: d0 in [1, 2], d1 in [5, 8], s0 in [0, 32], is_simplified: false"> func.func @apply_indexing(%d0: index, %d1: index, %s0: index) -> (index, index) { // expected-error @+1 {{operand count must match the number of dimensions and symbols in the affine map}} %0:2 = xla_gpu.apply_indexing #map0 (%d0) @@ -16,16 +9,7 @@ func.func @apply_indexing(%d0: index, %d1: index, %s0: index) -> (index, index) // ----- -#map0 = #xla_gpu.indexing_map< - (d0, d1)[s0] -> (d0, d1 + s0), - domain: - d0 in [1, 2], - d1 in [5, 8], - s0 in [0, 32], - d0 mod 2 in [0, 1], - d0 + s0 in [1, 10], - is_simplified: false -> +#map0 = #xla_gpu.indexing_map<"(d0, d1)[s0] -> (d0, d1 + s0), domain: d0 in [1, 2], d1 in [5, 8], s0 in [0, 32], d0 mod 2 in [0, 1], d0 + s0 in [1, 10], is_simplified: false"> func.func @cannot_have_constraints(%d0: index, %d1: index, %s0: index) -> (index, index) { // expected-error @+1 {{apply indexing op cannot have any constraints}} %0:2 = xla_gpu.apply_indexing #map0 (%d0, %d1)[%s0] @@ -34,7 +18,7 @@ func.func @cannot_have_constraints(%d0: index, %d1: index, %s0: index) -> (index // ----- -#map = #xla_gpu.indexing_map<()[s0, s1] -> (s0, s1), domain: s0 in [0, 1024], s1 in [0, 32], is_simplified: false> +#map = #xla_gpu.indexing_map<"()[s0, s1] -> (s0, s1), domain: s0 in [0, 1024], s1 in [0, 32], is_simplified: false"> func.func @loop_result_num_mismatch(%input: tensor<1024x32xf32>, %init: f32) -> (f32) { // expected-error @+1 {{mismatch in number of loop-carried values and results}} @@ -52,7 +36,7 @@ func.func @loop_result_num_mismatch(%input: tensor<1024x32xf32>, // ----- -#map = #xla_gpu.indexing_map<()[s0] -> (s0, s0), domain: s0 in [0, 1024], is_simplified: false> +#map = #xla_gpu.indexing_map<"()[s0] -> (s0, s0), domain: s0 in [0, 1024], is_simplified: false"> func.func @loop_iv_num_mismatch(%input: tensor<1024x32xf32>, %init: f32) -> (f32) { // expected-error @+1 {{mismatch in number of induction variables 2 and RangeVars}} @@ -70,8 +54,7 @@ func.func @loop_iv_num_mismatch(%input: tensor<1024x32xf32>, // ----- -#map = #xla_gpu.indexing_map<()[s0, s1] -> (s0, s1), - domain: s0 in [0, 1024], s1 in [0, 32], is_simplified: false> +#map = #xla_gpu.indexing_map<"()[s0, s1] -> (s0, s1), domain: s0 in [0, 1024], s1 in [0, 32], is_simplified: false"> func.func @loop_types_mismatch(%input: tensor<1024x32xf32>, %init: f32) -> (i32) { // expected-error @+1 {{block iter arg type = 'f32', result type = 'i32' and init operand type = 'f32' should match}} @@ -89,8 +72,7 @@ func.func @loop_types_mismatch(%input: tensor<1024x32xf32>, %init: f32) -> (i32) // ----- -#map = #xla_gpu.indexing_map<(d0)[s0, s1] -> (s0, s1), - domain: d0 in [0, 3], s0 in [0, 1024], s1 in [0, 32], is_simplified: false> +#map = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (s0, s1), domain: d0 in [0, 3], s0 in [0, 1024], s1 in [0, 32], is_simplified: false"> func.func @loop_op(%input: tensor<1024x32xf32>, %init: f32, %dim: index) -> (f32) { // expected-error @+1 {{mismatch in number of dims operands 0 and DimVars in the indexing map}} @@ -105,9 +87,7 @@ func.func @loop_op(%input: tensor<1024x32xf32>, %init: f32, %dim: index) -> (f32 // ----- -#map = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> (d0 + s0, s1), - domain: d0 in [0, 32], d1 in [0, 2], s0 in [0, 1024], s1 in [0, 32], - is_simplified: false> +#map = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], d1 in [0, 2], s0 in [0, 1024], s1 in [0, 32], is_simplified: false"> func.func @indicies_mismatch(%input: tensor<32x64xf32>, %thread_id: index, %output: tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map> { @@ -119,10 +99,8 @@ func.func @indicies_mismatch(%input: tensor<32x64xf32>, %thread_id: index, // ----- -#map = #xla_gpu.indexing_map<()[s0, s1] -> (s0, s1), - domain: s0 in [0, 1024], s1 in [0, 32], is_simplified: false> -#map1 = #xla_gpu.indexing_map<(d0)[s0, s1] -> (d0 + s0, s1), - domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 32], is_simplified: false> +#map = #xla_gpu.indexing_map<"()[s0, s1] -> (s0, s1), domain: s0 in [0, 1024], s1 in [0, 32], is_simplified: false"> +#map1 = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 32], is_simplified: false"> func.func @no_thread_id_in(%input: tensor<32x64xf32>, %output: tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1> { @@ -134,10 +112,8 @@ func.func @no_thread_id_in(%input: tensor<32x64xf32>, // ----- -#map = #xla_gpu.indexing_map<(d0)[s0, s1] -> (d0 + s0, s1), - domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 32], is_simplified: false> -#map1 = #xla_gpu.indexing_map<()[s0, s1] -> (s0, s1), - domain: s0 in [0, 1024], s1 in [0, 32], is_simplified: false> +#map = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 32], is_simplified: false"> +#map1 = #xla_gpu.indexing_map<"()[s0, s1] -> (s0, s1), domain: s0 in [0, 1024], s1 in [0, 32], is_simplified: false"> func.func @no_thread_id_out(%input: tensor<32x64xf32>, %thread_id: index, %output: tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1> { @@ -149,10 +125,8 @@ func.func @no_thread_id_out(%input: tensor<32x64xf32>, %thread_id: index, // ----- -#map = #xla_gpu.indexing_map<(d0)[s0, s1] -> (d0 + s0, s1), - domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 32], is_simplified: false> -#map1 = #xla_gpu.indexing_map<(d0)[s0, s1] -> (d0 + s0, s1), - domain: d0 in [0, 64], s0 in [0, 1024], s1 in [0, 32], is_simplified: false> +#map = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 32], is_simplified: false"> +#map1 = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 64], s0 in [0, 1024], s1 in [0, 32], is_simplified: false"> func.func @thread_id_bounds_mismatch(%input: tensor<32x64xf32>, %thread_id: index, %output: tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1> { // expected-error @+1 {{thread_id dimension must have the same bounds in both indexing maps}} %0 = xla_gpu.materialize @exp(%input) at #map(%thread_id) : (tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1> @@ -161,11 +135,8 @@ func.func @thread_id_bounds_mismatch(%input: tensor<32x64xf32>, %thread_id: inde // ----- -#map = #xla_gpu.indexing_map<(d0)[s0, s1] -> (d0 + s0, s1), - domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 32], d0 + s0 in [0, 1024], - is_simplified: false> -#map1 = #xla_gpu.indexing_map<(d0)[s0, s1] -> (d0 + s0, s1), - domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 32], is_simplified: false> +#map = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 32], d0 + s0 in [0, 1024], is_simplified: false"> +#map1 = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 32], is_simplified: false"> func.func @thread_id_constraints_mismatch(%input: tensor<32x64xf32>, %thread_id: index, %output: tensor<32x64xf32>) @@ -178,10 +149,8 @@ func.func @thread_id_constraints_mismatch(%input: tensor<32x64xf32>, // ----- -#map = #xla_gpu.indexing_map<(d0)[s0] -> (d0 + s0, s0), - domain: d0 in [0, 32], s0 in [0, 1024], is_simplified: false> -#map1 = #xla_gpu.indexing_map<(d0)[s0, s1] -> (d0 + s0, s1), - domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 32], is_simplified: false> +#map = #xla_gpu.indexing_map<"(d0)[s0] -> (d0 + s0, s0), domain: d0 in [0, 32], s0 in [0, 1024], is_simplified: false"> +#map1 = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 32], is_simplified: false"> func.func @symbol_count_mismatch(%input: tensor<32x64xf32>, %thread_id: index, %output: tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1> { // expected-error @+1 {{number of symbols in both indexing_maps must match}} %0 = xla_gpu.materialize @exp(%input) at #map(%thread_id) : (tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1> @@ -190,10 +159,8 @@ func.func @symbol_count_mismatch(%input: tensor<32x64xf32>, %thread_id: index, % // ----- -#map = #xla_gpu.indexing_map<(d0)[s0, s1] -> (d0 + s0, s1), - domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64], is_simplified: false> -#map1 = #xla_gpu.indexing_map<(d0)[s0, s1] -> (d0 + s0, s1), - domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 32], is_simplified: false> +#map = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64], is_simplified: false"> +#map1 = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 32], is_simplified: false"> func.func @symbol_domain_mismatch(%input: tensor<32x64xf32>, %thread_id: index, %output: tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1> { // expected-error @+1 {{domain of symbols of indexing_maps must match}} %0 = xla_gpu.materialize @exp(%input) at #map(%thread_id) : (tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1> @@ -202,12 +169,8 @@ func.func @symbol_domain_mismatch(%input: tensor<32x64xf32>, %thread_id: index, // ----- -#map = #xla_gpu.indexing_map<(d0)[s0, s1] -> (d0 + s0, s1), - domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64], s0 + s1 in [0, 1024], - is_simplified: false> -#map1 = #xla_gpu.indexing_map<(d0)[s0, s1] -> (d0 + s0, s1), - domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64], s0 + s1 in [0, 32], - is_simplified: false> +#map = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64], s0 + s1 in [0, 1024], is_simplified: false"> +#map1 = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64], s0 + s1 in [0, 32], is_simplified: false"> func.func @symbol_constraints_mismatch(%input: tensor<32x64xf32>, %thread_id: index, %output: tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1> { @@ -219,12 +182,8 @@ func.func @symbol_constraints_mismatch(%input: tensor<32x64xf32>, // ----- -#map = #xla_gpu.indexing_map<(d0)[s0, s1] -> (d0 + s0, s1), - domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64], s0 mod 2 in [0, 0], - is_simplified: false> -#map1 = #xla_gpu.indexing_map<(d0)[s0, s1] -> (d0 + s0, s1), - domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64], s0 + s1 in [0, 32], - is_simplified: false> +#map = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64], s0 mod 2 in [0, 0], is_simplified: false"> +#map1 = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64], s0 + s1 in [0, 32], is_simplified: false"> func.func @symbol_constraint_mismatch(%input: tensor<32x64xf32>, %thread_id: index, %output: tensor<32x64xf32>) @@ -236,12 +195,8 @@ func.func @symbol_constraint_mismatch(%input: tensor<32x64xf32>, // ----- -#map = #xla_gpu.indexing_map<(d0)[s0, s1] -> (d0 + s0, s1), - domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64], s0 + s1 in [0, 1024], - is_simplified: false> -#map1 = #xla_gpu.indexing_map<(d0)[s0, s1] -> (d0 + s0, s1), - domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64], s0 + s1 in [0, 32], - is_simplified: false> +#map = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64], s0 + s1 in [0, 1024], is_simplified: false"> +#map1 = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64], s0 + s1 in [0, 32], is_simplified: false"> func.func @symbol_constraint_interval_mismatch(%input: tensor<32x64xf32>, %thread_id: index, %output: tensor<32x64xf32>) @@ -254,12 +209,8 @@ func.func @symbol_constraint_interval_mismatch(%input: tensor<32x64xf32>, // ----- -#map = #xla_gpu.indexing_map<(d0)[s0, s1] -> (d0 + s0, s1), - domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64], - is_simplified: false> -#map1 = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> (d0 + s0, d1 + s1), - domain: d0 in [0, 32], d1 in [0, 64], s0 in [0, 1024], s1 in [0, 64], - is_simplified: false> +#map = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64], is_simplified: false"> +#map1 = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d0 + s0, d1 + s1), domain: d0 in [0, 32], d1 in [0, 64], s0 in [0, 1024], s1 in [0, 64], is_simplified: false"> func.func @vector_mapping_depends_on_block_id(%input: tensor<32x64xf32>, %thread_id: index, %output: tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1> { @@ -271,13 +222,8 @@ func.func @vector_mapping_depends_on_block_id(%input: tensor<32x64xf32>, // ----- -#map = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> (d0 + s0, s1), - domain: d0 in [0, 32], d1 in [0, 64], s0 in [0, 1024], s1 in [0, 64], - d1 mod 2 in [0, 0], - is_simplified: false> -#map1 = #xla_gpu.indexing_map<(d0)[s0, s1] -> (d0 + s0, s1), - domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64], - is_simplified: false> +#map = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], d1 in [0, 64], s0 in [0, 1024], s1 in [0, 64], d1 mod 2 in [0, 0], is_simplified: false"> +#map1 = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64], is_simplified: false"> func.func @block_id_constraints_mismatch(%input: tensor<32x64xf32>, %thread_id: index, %block_id: index, %output: tensor<32x64xf32>) @@ -290,13 +236,8 @@ func.func @block_id_constraints_mismatch(%input: tensor<32x64xf32>, // ----- -#map = #xla_gpu.indexing_map<(d0)[s0, s1] -> (d0 + s0, s1), - domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64], - is_simplified: false> -#map1 = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> (d0 + s0, s1), - domain: d0 in [0, 32], d1 in [0, 64], s0 in [0, 1024], s1 in [0, 64], - d1 mod 2 in [0, 0], - is_simplified: false> +#map = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64], is_simplified: false"> +#map1 = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], d1 in [0, 64], s0 in [0, 1024], s1 in [0, 64], d1 mod 2 in [0, 0], is_simplified: false"> func.func @block_id_constraints_mismatch(%input: tensor<32x64xf32>, %thread_id: index, %block_id: index, %output: tensor<32x64xf32>) @@ -309,14 +250,8 @@ func.func @block_id_constraints_mismatch(%input: tensor<32x64xf32>, // ----- -#map = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> (d0 + s0, s1), - domain: d0 in [0, 32], d1 in [0, 64], s0 in [0, 1024], s1 in [0, 64], - d1 mod 2 in [0, 0], - is_simplified: false> -#map1 = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> (d0 + s0, s1), - domain: d0 in [0, 32], d1 in [0, 64], s0 in [0, 1024], s1 in [0, 64], - d1 mod 4 in [0, 0], - is_simplified: false> +#map = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], d1 in [0, 64], s0 in [0, 1024], s1 in [0, 64], d1 mod 2 in [0, 0], is_simplified: false"> +#map1 = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], d1 in [0, 64], s0 in [0, 1024], s1 in [0, 64], d1 mod 4 in [0, 0], is_simplified: false"> func.func @block_id_constraints_mismatch(%input: tensor<32x64xf32>, %thread_id: index, %block_id: index, %output: tensor<32x64xf32>) @@ -329,12 +264,8 @@ func.func @block_id_constraints_mismatch(%input: tensor<32x64xf32>, // ----- -#map = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> (d1*32+d0*2+s0, s1), - domain: d0 in [0, 32], d1 in [0, 8], s0 in [0, 1], s1 in [0, 1], - is_simplified: false> -#map1 = #xla_gpu.indexing_map<(d0, d1)[s0] -> (d0 mod 16 + s0, d1), - domain: d0 in [0, 32], d1 in [0, 2], s0 in [0, 1], - is_simplified: false> +#map = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d1*32+d0*2+s0, s1), domain: d0 in [0, 32], d1 in [0, 8], s0 in [0, 1], s1 in [0, 1], is_simplified: false"> +#map1 = #xla_gpu.indexing_map<"(d0, d1)[s0] -> (d0 mod 16 + s0, d1), domain: d0 in [0, 32], d1 in [0, 2], s0 in [0, 1], is_simplified: false"> func.func @insert(%input: !xla_gpu.indexed_vector<32x64xf32, #map>, %i: index, %j: index, %output: tensor<32x64xf32>) -> tensor<32x64xf32> { @@ -346,12 +277,8 @@ func.func @insert(%input: !xla_gpu.indexed_vector<32x64xf32, #map>, // ----- -#map = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> (d1*32+d0*2+s0, s1), - domain: d0 in [0, 32], d1 in [0, 8], s0 in [0, 1], s1 in [0, 1], - is_simplified: false> -#map1 = #xla_gpu.indexing_map<(d0, d1, d2) -> (d0 mod 16, d1, d2), - domain: d0 in [0, 32], d1 in [0, 2], d2 in [0, 5], - is_simplified: false> +#map = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d1*32+d0*2+s0, s1), domain: d0 in [0, 32], d1 in [0, 8], s0 in [0, 1], s1 in [0, 1], is_simplified: false"> +#map1 = #xla_gpu.indexing_map<"(d0, d1, d2) -> (d0 mod 16, d1, d2), domain: d0 in [0, 32], d1 in [0, 2], d2 in [0, 5], is_simplified: false"> func.func @insert(%input: !xla_gpu.indexed_vector<32x64xf32, #map>, %i: index, %j: index, %output: tensor<32x64xf32>) -> tensor<32x64xf32> { diff --git a/third_party/xla/xla/service/gpu/fusions/ir/tests/ops.mlir b/third_party/xla/xla/service/gpu/fusions/ir/tests/ops.mlir index 572202bf148ce2..81e08968db7590 100644 --- a/third_party/xla/xla/service/gpu/fusions/ir/tests/ops.mlir +++ b/third_party/xla/xla/service/gpu/fusions/ir/tests/ops.mlir @@ -56,19 +56,13 @@ func.func @caller(%a: f32, %b: f32) -> f32 { // ----- -#map0 = #xla_gpu.indexing_map< -(d0, d1)[s0] -> (d0, d1 + s0), - domain: - d0 in [1, 2], - d1 in [5, 8], - s0 in [0, 32], - is_simplified: false -> +#map0 = #xla_gpu.indexing_map<"(d0, d1)[s0] -> (d0, d1 + s0)," + "domain: d0 in [1, 2], d1 in [5, 8], s0 in [0, 32], is_simplified: false"> func.func @apply_indexing(%d0: index, %d1: index, %s0: index) -> (index, index) { %0:2 = xla_gpu.apply_indexing #map0 (%d0, %d1)[%s0] func.return %0#0, %0#1 : index, index } -// CHECK: #[[$MAP0:.*]] = #xla_gpu.indexing_map< +// CHECK: #[[$MAP0:.*]] = #xla_gpu.indexing_map<" // CHECK-SAME: (d0, d1)[s0] -> (d0, d1 + s0) // CHECK-SAME: domain: // CHECK-SAME: d0 in [1, 2] @@ -83,18 +77,13 @@ func.func @apply_indexing(%d0: index, %d1: index, %s0: index) -> (index, index) // ----- -#map0 = #xla_gpu.indexing_map< -(d0, d1) -> (d0, d1), - domain: - d0 in [0, 2], - d1 in [1, 3], - is_simplified: false -> +#map0 = #xla_gpu.indexing_map<"(d0, d1) -> (d0, d1)," + "domain: d0 in [0, 2], d1 in [1, 3], is_simplified: false"> func.func @apply_indexing_no_symbols(%d0: index, %d1: index) -> (index, index) { %0:2 = xla_gpu.apply_indexing #map0 (%d0, %d1) func.return %0#0, %0#1 : index, index } -// CHECK: #[[$MAP0:.*]] = #xla_gpu.indexing_map< +// CHECK: #[[$MAP0:.*]] = #xla_gpu.indexing_map<" // CHECK-SAME: (d0, d1) -> (d0, d1) // CHECK-SAME: domain: // CHECK-SAME: d0 in [0, 2] @@ -108,17 +97,13 @@ func.func @apply_indexing_no_symbols(%d0: index, %d1: index) -> (index, index) { // ----- -#map0 = #xla_gpu.indexing_map< - ()[s0] -> (s0, s0), - domain: - s0 in [2, 4], - is_simplified: false -> +#map0 = #xla_gpu.indexing_map<"()[s0] -> (s0, s0)," + "domain: s0 in [2, 4], is_simplified: false"> func.func @apply_indexing_no_dims(%s0: index) -> (index, index) { %0:2 = xla_gpu.apply_indexing #map0 [%s0] func.return %0#0, %0#1 : index, index } -// CHECK: #[[$MAP0:.*]] = #xla_gpu.indexing_map< +// CHECK: #[[$MAP0:.*]] = #xla_gpu.indexing_map<" // CHECK-SAME: ()[s0] -> (s0, s0) // CHECK-SAME: domain: // CHECK-SAME: s0 in [2, 4] @@ -130,8 +115,8 @@ func.func @apply_indexing_no_dims(%s0: index) -> (index, index) { // ----- -#map = #xla_gpu.indexing_map<(d0)[s0, s1] -> (s0, s1), - domain: d0 in [0, 3], s0 in [0, 1024], s1 in [0, 32], is_simplified: false> +#map = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (s0, s1), " + "domain: d0 in [0, 3], s0 in [0, 1024], s1 in [0, 32], is_simplified: false"> func.func @loop_op(%input: tensor<1024x32xf32>, %init: f32, %dim: index) -> (f32) { %sum = xla_gpu.loop (%dim)[%i, %j] -> (%r0, %r1) @@ -155,15 +140,12 @@ func.func @loop_op(%input: tensor<1024x32xf32>, %init: f32, func.func private @exp(%p0: tensor<32x64xf32>, %i: index, %j: index) -> f32 -#map = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> (d0 + s0, d1 + s1), - domain: d0 in [0, 32], d1 in [0, 2], s0 in [0, 1024], s1 in [0, 32], - is_simplified: false> -#map1 = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> (s0, s1), - domain: d0 in [0, 32], d1 in [0, 2], s0 in [0, 1024], s1 in [0, 32], - is_simplified: false> -#map2 = #xla_gpu.indexing_map<(d0, d1) -> (d0, d1), - domain: d0 in [0, 32], d1 in [0, 2], - is_simplified: false> +#map = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d0 + s0, d1 + s1)," + "domain: d0 in [0, 32], d1 in [0, 2], s0 in [0, 1024], s1 in [0, 32], is_simplified: false"> +#map1 = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (s0, s1)," + "domain: d0 in [0, 32], d1 in [0, 2], s0 in [0, 1024], s1 in [0, 32], is_simplified: false"> +#map2 = #xla_gpu.indexing_map<"(d0, d1) -> (d0, d1)," + "domain: d0 in [0, 32], d1 in [0, 2], is_simplified: false"> func.func @materialize_and_insert(%input: tensor<32x64xf32>, %i: index, %j: index, %output: tensor<32x64xf32>) -> tensor<32x64xf32> { @@ -174,11 +156,11 @@ func.func @materialize_and_insert(%input: tensor<32x64xf32>, %i: index, func.return %1 : tensor<32x64xf32> } -// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> (d0 + s0, d1 + s1) +// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d0 + s0, d1 + s1) // CHECK-SAME: d0 in [0, 32], d1 in [0, 2], s0 in [0, 1024], s1 in [0, 32] -// CHECK: #[[$MAP1:.*]] = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> (s0, s1) +// CHECK: #[[$MAP1:.*]] = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (s0, s1) // CHECK-SAME: d0 in [0, 32], d1 in [0, 2], s0 in [0, 1024], s1 in [0, 32] -// CHECK: #[[$MAP2:.*]] = #xla_gpu.indexing_map<(d0, d1) -> (d0, d1) +// CHECK: #[[$MAP2:.*]] = #xla_gpu.indexing_map<"(d0, d1) -> (d0, d1) // CHECK-SAME: d0 in [0, 32], d1 in [0, 2], // CHECK-LABEL: @materialize_and_insert // CHECK: %[[MATERIALIZED:.*]] = xla_gpu.materialize @exp(%{{.*}}) at @@ -233,13 +215,14 @@ func.func @reduce_middle_dim(%in: tensor<16x8x4xf32>, %init: f32) // ----- -#map = #xla_gpu.indexing_map<(d0, d1) -> (d0 * 64 + d1), domain: d0 in [0, 15], d1 in [0, 63], is_simplified: false> +#map = #xla_gpu.indexing_map<"(d0, d1) -> (d0 * 64 + d1)," + "domain: d0 in [0, 15], d1 in [0, 63], is_simplified: false"> func.func @reindex(%in0: tensor<1024xf32>) -> tensor<16x64xf32> { %0 = xla_gpu.reindex %in0 at #map : tensor<1024xf32> -> tensor<16x64xf32> func.return %0 : tensor<16x64xf32> } -// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0, d1) -> (d0 * 64 + d1) +// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<"(d0, d1) -> (d0 * 64 + d1) // CHECK-LABEL: func.func @reindex( // CHECK-SAME: %[[IN1:.*]]: tensor<1024xf32> // CHECK: xla_gpu.reindex %[[IN1]] at #[[$MAP]] : @@ -247,7 +230,8 @@ func.func @reindex(%in0: tensor<1024xf32>) -> tensor<16x64xf32> { // ----- -#map = #xla_gpu.indexing_map<(d0, d1) -> (d0 * 64 + d1), domain: d0 in [0, 15], d1 in [0, 63], is_simplified: false> +#map = #xla_gpu.indexing_map<"(d0, d1) -> (d0 * 64 + d1)," + "domain: d0 in [0, 15], d1 in [0, 63], is_simplified: false"> func.func @reindex_pad(%in0: tensor<1022xf32>) -> tensor<16x64xf32> { %c0 = arith.constant 0.0 : f32 %0 = xla_gpu.reindex %in0 at #map default %c0 @@ -255,7 +239,7 @@ func.func @reindex_pad(%in0: tensor<1022xf32>) -> tensor<16x64xf32> { func.return %0 : tensor<16x64xf32> } -// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0, d1) -> (d0 * 64 + d1) +// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<"(d0, d1) -> (d0 * 64 + d1) // CHECK-LABEL: func.func @reindex_pad( // CHECK-SAME: %[[IN1:.*]]: tensor<1022xf32> // CHECK: %[[C0:.*]] = arith.constant 0.00 @@ -278,4 +262,4 @@ func.func @shuffler(%a: f32, %b: i32) -> (f32, i32) { // CHECK: xla_gpu.shuffle_reduce(%[[IN1]], %[[IN2]]) to 4 // CHECK-SAME: combiner=@do_nothing {xla.range = [0 : index, 42 : index]} -// CHECK-SAME: : f32, i32 \ No newline at end of file +// CHECK-SAME: : f32, i32 diff --git a/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_attrs.cc b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_attrs.cc index cb9ba368702c9f..577ec1262970c6 100644 --- a/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_attrs.cc +++ b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_attrs.cc @@ -14,10 +14,10 @@ limitations under the License. ==============================================================================*/ #include +#include #include #include -#include "absl/strings/str_format.h" #include "llvm/ADT/StringRef.h" #include "llvm/ADT/TypeSwitch.h" // IWYU pragma: keep #include "llvm/Support/LogicalResult.h" @@ -30,6 +30,7 @@ limitations under the License. #include "mlir/Support/LLVM.h" #include "xla/service/gpu/fusions/ir/xla_gpu_ops.h" #include "xla/service/gpu/model/indexing_map.h" +#include "xla/service/gpu/model/indexing_map_serialization.h" namespace xla { namespace gpu { @@ -43,144 +44,36 @@ using mlir::AsmPrinter; using mlir::failure; using mlir::success; -constexpr llvm::StringRef kIsSimplifiedKeyword = "is_simplified"; - -ParseResult ParseInterval(AsmParser& parser, Interval& interval) { - // ParseResult converts to `true` if parsing failed. - return failure(parser.parseLSquare() || parser.parseInteger(interval.lower) || - parser.parseComma() || parser.parseInteger(interval.upper) || - parser.parseRSquare()); -} - -ParseResult parseBool(AsmParser& parser, bool* result) { - if (succeeded(parser.parseOptionalKeyword("true"))) { - *result = true; - return success(); +// Parses a chain of string attributes into an indexing map. +// Example: +// "()[s0, s1] -> (1 + s0 + s1 mod 3 - s1, s0 mod 2)," +// " domain: s0 in [-10, 10], s1 in [0, 2]," +// " is_simplified: false" +// will be parsed as 3 StringAttrs, concatenated into a single string, and then +// parsed into an IndexingMap. +std::optional parseChainOfStringsAsIndexingMap( + mlir::AsmParser& parser) { + mlir::StringAttr indexing_map_attr; + std::string indexing_map_str; + while (parser.parseOptionalAttribute(indexing_map_attr).has_value()) { + indexing_map_str.append(indexing_map_attr.getValue()); } - if (succeeded(parser.parseOptionalKeyword("false"))) { - *result = false; - return success(); - } - return failure(); -} - -void PrintDimVars(AsmPrinter& p, ArrayRef dim_vars) { - for (const auto [index, dim_var] : llvm::enumerate(dim_vars)) { - p << "d" << index << " in " << dim_var.bounds << ", "; - } -} - -ParseResult ParseDimVars(AsmParser& parser, ArrayRef dim_names, - SmallVector& dim_vars) { - dim_vars.reserve(dim_names.size()); - for (const auto& [index, dim_name] : llvm::enumerate(dim_names)) { - if (parser.parseKeyword(dim_name) || parser.parseKeyword("in") || - ParseInterval(parser, dim_vars.emplace_back().bounds) || - parser.parseComma()) { - return failure(); - } - } - return success(); -} - -void PrintRangeVars(AsmPrinter& p, ArrayRef range_vars) { - for (const auto [index, range_var] : llvm::enumerate(range_vars)) { - p << "s" << index << " in " << range_var.range << ", "; - } -} - -ParseResult ParseRangeVars(AsmParser& parser, - ArrayRef range_symbol_names, - SmallVector& range_vars) { - range_vars.reserve(range_symbol_names.size()); - for (const auto& [index, range_symbol_name] : - llvm::enumerate(range_symbol_names)) { - if (parser.parseKeyword(range_symbol_name) || parser.parseKeyword("in") || - ParseInterval(parser, range_vars.emplace_back().range) || - parser.parseComma()) { - return failure(); - } - } - return success(); -} - -void PrintConstraints(AsmPrinter& p, - ArrayRef> constraints) { - for (const auto& [expr, interval] : constraints) { - p << expr << " in " << interval << ", "; - } -} - -mlir::Attribute parseIndexingMapImpl(mlir::AsmParser& parser) { - mlir::AffineMap map; - if (parser.parseAffineMap(map)) { - return {}; - } - - // Store real strings to back up StringRef throughout ParseConstraints. - SmallVector dim_strings(map.getNumDims()); - SmallVector symbol_strings(map.getNumSymbols()); - SmallVector> symbolSet; - symbolSet.reserve(map.getNumDims() + map.getNumSymbols()); - for (int i = 0; i < map.getNumDims(); ++i) { - dim_strings[i] = absl::StrFormat("d%d", i); - symbolSet.push_back( - {dim_strings[i], mlir::getAffineDimExpr(i, parser.getContext())}); - } - for (int i = 0; i < map.getNumSymbols(); ++i) { - symbol_strings[i] = absl::StrFormat("s%d", i); - symbolSet.push_back( - {symbol_strings[i], mlir::getAffineSymbolExpr(i, parser.getContext())}); - } - if (map.getNumDims() + map.getNumSymbols() == 0) { - if (parser.parseGreater()) return {}; - return IndexingMapAttr::get(parser.getContext(), map, /*dim_vars=*/{}, - /*range_vars=*/{}, - /*constraints=*/{}, /*is_simplified=*/true); - } - if (parser.parseComma() || parser.parseKeyword("domain") || - parser.parseColon()) { - return {}; - } - - SmallVector dim_vars; - if (ParseDimVars(parser, dim_strings, dim_vars)) { - return {}; - } - SmallVector range_vars; - if (ParseRangeVars(parser, symbol_strings, range_vars)) { - return {}; - } - - SmallVector> constraints; - while (failed(parser.parseOptionalKeyword(kIsSimplifiedKeyword))) { - auto& constraint = constraints.emplace_back(); - if (parser.parseAffineExpr(symbolSet, constraint.first) || - parser.parseKeyword("in") || ParseInterval(parser, constraint.second) || - parser.parseComma()) { - return {}; - } - constraints.push_back(constraint); - } - - bool is_simplified = false; - if (parser.parseColon() || parseBool(parser, &is_simplified) || - parser.parseGreater()) { - return {}; - } - return IndexingMapAttr::get(parser.getContext(), map, dim_vars, range_vars, - constraints, is_simplified); + return ParseIndexingMap(indexing_map_str, parser.getContext()); } mlir::Attribute IndexingMapAttr::parse(mlir::AsmParser& parser, mlir::Type) { if (parser.parseLess()) { return {}; } - return parseIndexingMapImpl(parser); + auto indexing_map = parseChainOfStringsAsIndexingMap(parser); + if (!indexing_map.has_value() || parser.parseGreater()) { + return {}; + } + return IndexingMapAttr::get(parser.getContext(), *indexing_map); } void IndexingMapAttr::print(mlir::AsmPrinter& printer) const { - printer << "<" << getIndexingMap().ToString() << ">"; + printer << "<\"" << getIndexingMap().ToString() << "\">"; } IndexingMapAttr IndexingMapAttr::get(mlir::MLIRContext* context, @@ -230,18 +123,19 @@ mlir::Attribute LayoutAttr::parse(mlir::AsmParser& parser, mlir::Type) { if (!memspace.has_value()) { return {}; } - auto thread_map = mlir::cast(parseIndexingMapImpl(parser)); - if (!thread_map) { + std::optional indexing_map = + parseChainOfStringsAsIndexingMap(parser); + if (!indexing_map.has_value() || parser.parseGreater()) { return {}; } - mlir::MLIRContext* context = parser.getContext(); - auto memory_space_attr = MemorySpaceAttr::get(context, *memspace); - return LayoutAttr::get(context, memory_space_attr, thread_map); + auto* context = parser.getContext(); + return LayoutAttr::get(context, MemorySpaceAttr::get(context, *memspace), + IndexingMapAttr::get(context, *indexing_map)); } void LayoutAttr::print(mlir::AsmPrinter& printer) const { printer << "<\"" << stringifyMemorySpace(getMemorySpace().getValue()) - << "\", " << getThreadMap().getIndexingMap().ToString() << '>'; + << "\", \"" << getThreadMap().getIndexingMap().ToString() << "\">"; } } // namespace gpu diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir_test.cc b/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir_test.cc index e683e199ed03ca..5c87db0045dac0 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir_test.cc @@ -234,10 +234,10 @@ TEST_F(ElementalHloToMlirTest, ReduceWindow) { // CHECK: %[[INIT:.*]] = tensor.extract %[[ARG1]][] // CHECK: %[[RET:.*]] = scf.for %[[I:.*]] = %[[C0]] to %[[C7]] // CHECK-SAME: step %[[C1]] iter_args(%[[ACC:.*]] = %[[INIT]]) - // CHECK: %[[J0:.*]] = xla_gpu.apply_indexing #xla_gpu.indexing_map<(d0) -> (d0 * 4), domain: d0 in [0, 2], is_simplified: true>(%[[Y]]) + // CHECK: %[[J0:.*]] = xla_gpu.apply_indexing #xla_gpu.indexing_map<"(d0) -> (d0 * 4), domain: d0 in [0, 2], is_simplified: true">(%[[Y]]) // CHECK: %[[J1:.*]] = xla_gpu.apply_indexing - // CHECK-SAME: #xla_gpu.indexing_map<(d0, d1) -> (d0 + d1 - 3), - // CHECK-SAME: d0 in [0, 7], d1 in [0, 6], is_simplified: true>(%[[Z]], %[[I]]) + // CHECK-SAME: #xla_gpu.indexing_map<"(d0, d1) -> (d0 + d1 - 3), + // CHECK-SAME: d0 in [0, 7], d1 in [0, 6], is_simplified: true">(%[[Z]], %[[I]]) // CHECK: %[[VAL:.*]] = tensor.extract %[[ARG0]] // CHECK-SAME: [%[[X]], %[[J0]], %[[J1]]] // CHECK: %[[UPD:.*]] = func.call @add_sum(%[[ACC]], @@ -284,8 +284,8 @@ TEST_F(ElementalHloToMlirTest, ReduceWindowWithRescaling) { // If symbol rescaling wasn't working we would have a // `d1 floordiv ` in the map: // CHECK: %[[K:.*]] = xla_gpu.apply_indexing - // CHECK-SAME: #xla_gpu.indexing_map<(d0, d1) -> (d0 * 2 + d1), - // CHECK-SAME: d0 in [0, 18], d1 in [0, 3], is_simplified: true>(%[[X]], %[[I]]) + // CHECK-SAME: #xla_gpu.indexing_map<"(d0, d1) -> (d0 * 2 + d1), + // CHECK-SAME: d0 in [0, 18], d1 in [0, 3], is_simplified: true">(%[[X]], %[[I]]) // CHECK: tensor.extract %[[ARG0]][%[[K]], %[[Y]], %[[Z]]] )")); @@ -505,7 +505,7 @@ TEST_F(ElementalHloToMlirTest, Pad) { // CHECK-DAG: %[[C4:.*]] = arith.constant 4 // CHECK-DAG: %[[C7:.*]] = arith.constant 7 // CHECK: %[[CONSTRAINT_VAL:.*]] = xla_gpu.apply_indexing - // CHECK-SAME: <(d0) -> ((d0 - 1) mod 2), domain: d0 in [1, 7], is_simplified: true>(%[[X]]) + // CHECK-SAME: <"(d0) -> ((d0 - 1) mod 2), domain: d0 in [1, 7], is_simplified: true">(%[[X]]) // CHECK: %[[CONSTRAINT:.*]] = arith.cmpi eq, %[[CONSTRAINT_VAL]], %[[C0]] // CHECK-DAG: %[[X_L:.*]] = arith.cmpi sge, %[[X]], %[[C1]] // CHECK-DAG: %[[X_H:.*]] = arith.cmpi sle, %[[X]], %[[C7]] @@ -517,9 +517,9 @@ TEST_F(ElementalHloToMlirTest, Pad) { // CHECK: %[[FROM_INPUT:.*]] = arith.andi %[[X_AND_CONSTRAINT]], %[[Y_BOUNDS]] // CHECK: %[[RET:.*]] = scf.if %[[FROM_INPUT]] // CHECK: %[[IN0:.*]] = xla_gpu.apply_indexing - // CHECK-SAME: <(d0) -> ((d0 - 1) floordiv 2), domain: d0 in [1, 7], is_simplified: true>(%[[X]]) + // CHECK-SAME: <"(d0) -> ((d0 - 1) floordiv 2), domain: d0 in [1, 7], is_simplified: true">(%[[X]]) // CHECK: %[[IN1:.*]] = xla_gpu.apply_indexing - // CHECK-SAME: <(d0) -> (d0 - 4), domain: d0 in [4, 7], is_simplified: true>(%[[Y]]) + // CHECK-SAME: <"(d0) -> (d0 - 4), domain: d0 in [4, 7], is_simplified: true">(%[[Y]]) // CHECK: %[[VAL:.*]] = tensor.extract %[[ARG0]][%[[IN0]], %[[IN1]]] // CHECK: scf.yield %[[VAL]] // CHECK: } else { @@ -547,7 +547,7 @@ TEST_F(ElementalHloToMlirTest, PadUnsigned) { // CHECK-DAG: %[[C4:.*]] = arith.constant 4 // CHECK-DAG: %[[C7:.*]] = arith.constant 7 // CHECK: %[[CONSTRAINT_VAL:.*]] = xla_gpu.apply_indexing - // CHECK-SAME: <(d0) -> ((d0 - 1) mod 2), domain: d0 in [1, 7], is_simplified: true>(%[[X]]) + // CHECK-SAME: <"(d0) -> ((d0 - 1) mod 2), domain: d0 in [1, 7], is_simplified: true">(%[[X]]) // CHECK: %[[CONSTRAINT:.*]] = arith.cmpi eq, %[[CONSTRAINT_VAL]], %[[C0]] // CHECK-DAG: %[[X_L:.*]] = arith.cmpi sge, %[[X]], %[[C1]] // CHECK-DAG: %[[X_H:.*]] = arith.cmpi sle, %[[X]], %[[C7]] @@ -559,9 +559,9 @@ TEST_F(ElementalHloToMlirTest, PadUnsigned) { // CHECK: %[[FROM_INPUT:.*]] = arith.andi %[[X_AND_CONSTRAINT]], %[[Y_BOUNDS]] // CHECK: %[[RET:.*]] = scf.if %[[FROM_INPUT]] // CHECK: %[[IN0:.*]] = xla_gpu.apply_indexing - // CHECK-SAME: <(d0) -> ((d0 - 1) floordiv 2), domain: d0 in [1, 7], is_simplified: true>(%[[X]]) + // CHECK-SAME: <"(d0) -> ((d0 - 1) floordiv 2), domain: d0 in [1, 7], is_simplified: true">(%[[X]]) // CHECK: %[[IN1:.*]] = xla_gpu.apply_indexing - // CHECK-SAME: <(d0) -> (d0 - 4), domain: d0 in [4, 7], is_simplified: true>(%[[Y]]) + // CHECK-SAME: <"(d0) -> (d0 - 4), domain: d0 in [4, 7], is_simplified: true">(%[[Y]]) // CHECK: %[[VAL:.*]] = tensor.extract %[[ARG0]][%[[IN0]], %[[IN1]]] // CHECK: scf.yield %[[VAL]] // CHECK: } else { @@ -878,11 +878,11 @@ TEST_F(ElementalHloToMlirTest, ConvolutionSimple) { // CHECK-NEXT: %[[R2:.+]] = scf.for %[[I:.+]] = %[[C0]] to %[[C4]] step %[[C1]] iter_args(%[[ACC:.+]] = %[[A1]]) -> (f32) { // CHECK: %[[R3:.+]] = scf.if {{.+}} -> (f32) { // CHECK: %[[XX0:.+]] = xla_gpu.apply_indexing - // CHECK-SAME: #xla_gpu.indexing_map<(d0, d1) -> (d0 + d1), - // CHECK-SAME: d0 in [0, 5], d1 in [0, 2], is_simplified: true>(%[[W]], %[[X]]) + // CHECK-SAME: #xla_gpu.indexing_map<"(d0, d1) -> (d0 + d1), + // CHECK-SAME: d0 in [0, 5], d1 in [0, 2], is_simplified: true">(%[[W]], %[[X]]) // CHECK: %[[XX1:.+]] = xla_gpu.apply_indexing - // CHECK-SAME: #xla_gpu.indexing_map<(d0, d1) -> (d0 + d1), - // CHECK-SAME: d0 in [0, 7], d1 in [0, 4], is_simplified: true>(%[[H]], %[[Y]]) + // CHECK-SAME: #xla_gpu.indexing_map<"(d0, d1) -> (d0 + d1), + // CHECK-SAME: d0 in [0, 7], d1 in [0, 4], is_simplified: true">(%[[H]], %[[Y]]) // CHECK-DAG: %[[VL:.+]] = tensor.extract %[[LHS]][%[[B]], %[[XX0]], %[[XX1]], %[[I]]] : tensor<2x8x12x4xf32> // CHECK-DAG: %[[VR:.+]] = tensor.extract %[[RHS]][%[[I]], %[[X]], %[[Y]], %[[O]]] : tensor<4x3x5x16xf32> // CHECK: %[[MUL:.+]] = arith.mulf %[[VL]], %[[VR]] : f32 @@ -924,11 +924,11 @@ TEST_F(ElementalHloToMlirTest, ConvolutionWithWindowStrides) { // CHECK-NEXT: %[[R2:.+]] = scf.for %[[I:.+]] = %[[C0]] to %[[C4]] step %[[C1]] iter_args(%[[ACC:.+]] = %[[A1]]) -> (f32) { // CHECK: %[[R3:.+]] = scf.if {{.+}} -> (f32) { // CHECK: %[[XX0:.+]] = xla_gpu.apply_indexing - // CHECK-SAME: #xla_gpu.indexing_map<(d0, d1) -> (d0 * 2 + d1), - // CHECK-SAME: d0 in [0, 2], d1 in [0, 2], is_simplified: true>(%[[W]], %[[X]]) + // CHECK-SAME: #xla_gpu.indexing_map<"(d0, d1) -> (d0 * 2 + d1), + // CHECK-SAME: d0 in [0, 2], d1 in [0, 2], is_simplified: true">(%[[W]], %[[X]]) // CHECK: %[[XX1:.+]] = xla_gpu.apply_indexing - // CHECK-SAME: #xla_gpu.indexing_map<(d0, d1) -> (d0 * 2 + d1), - // CHECK-SAME: d0 in [0, 3], d1 in [0, 4], is_simplified: true>(%[[H]], %[[Y]]) + // CHECK-SAME: #xla_gpu.indexing_map<"(d0, d1) -> (d0 * 2 + d1), + // CHECK-SAME: d0 in [0, 3], d1 in [0, 4], is_simplified: true">(%[[H]], %[[Y]]) // CHECK-DAG: %[[VL:.+]] = tensor.extract %[[LHS]][%[[B]], %[[XX0]], %[[XX1]], %[[I]]] : tensor<2x8x12x4xf32> // CHECK-DAG: %[[VR:.+]] = tensor.extract %[[RHS]][%[[I]], %[[X]], %[[Y]], %[[O]]] : tensor<4x3x5x16xf32> // CHECK: %[[MUL:.+]] = arith.mulf %[[VL]], %[[VR]] : f32 @@ -971,21 +971,21 @@ TEST_F(ElementalHloToMlirTest, ConvolutionWithPadding) { // CHECK: %[[R0:.+]] = scf.for %[[X:.+]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[A0:.+]] = %[[INIT]]) -> (f32) { // CHECK-NEXT: %[[R1:.+]] = scf.for %[[Y:.+]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[A1:.+]] = %[[A0]]) -> (f32) { // CHECK-NEXT: %[[R2:.+]] = scf.for %[[I:.+]] = %[[C0]] to %[[C4]] step %[[C1]] iter_args(%[[ACC:.+]] = %[[A1]]) -> (f32) { - // CHECK-DAG: %[[TESTX:.+]] = xla_gpu.apply_indexing #xla_gpu.indexing_map<(d0, d1) -> (d0 + d1), domain: d0 in [0, 7], d1 in [0, 2], is_simplified: true>(%[[W]], %[[X]]) + // CHECK-DAG: %[[TESTX:.+]] = xla_gpu.apply_indexing #xla_gpu.indexing_map<"(d0, d1) -> (d0 + d1), domain: d0 in [0, 7], d1 in [0, 2], is_simplified: true">(%[[W]], %[[X]]) // CHECK-DAG: %[[TXGE:.+]] = arith.cmpi sge, %[[TESTX]], %[[C1]] : index // CHECK-DAG: %[[TXLE:.+]] = arith.cmpi sle, %[[TESTX]], %[[C8]] : index // CHECK-DAG: %[[TX:.+]] = arith.andi %[[TXGE]], %[[TXLE]] : i1 - // CHECK-DAG: %[[TESTY:.+]] = xla_gpu.apply_indexing #xla_gpu.indexing_map<(d0, d1) -> (d0 + d1), domain: d0 in [0, 11], d1 in [0, 4], is_simplified: true>(%[[H]], %[[Y]]) + // CHECK-DAG: %[[TESTY:.+]] = xla_gpu.apply_indexing #xla_gpu.indexing_map<"(d0, d1) -> (d0 + d1), domain: d0 in [0, 11], d1 in [0, 4], is_simplified: true">(%[[H]], %[[Y]]) // CHECK-DAG: %[[TYGE:.+]] = arith.cmpi sge, %[[TESTY]], %[[C2]] : index // CHECK-DAG: %[[TYLE:.+]] = arith.cmpi sle, %[[TESTY]], %[[C13]] : index // CHECK-DAG: %[[TY:.+]] = arith.andi %[[TYGE]], %[[TYLE]] : i1 // CHECK: %[[R3:.+]] = scf.if {{.+}} -> (f32) { // CHECK: %[[XX0:.+]] = xla_gpu.apply_indexing - // CHECK-SAME: #xla_gpu.indexing_map<(d0, d1) -> (d0 + d1 - 1), - // CHECK-SAME: d0 in [0, 7], d1 in [0, 2], is_simplified: true>(%[[W]], %[[X]]) + // CHECK-SAME: #xla_gpu.indexing_map<"(d0, d1) -> (d0 + d1 - 1), + // CHECK-SAME: d0 in [0, 7], d1 in [0, 2], is_simplified: true">(%[[W]], %[[X]]) // CHECK: %[[XX1:.+]] = xla_gpu.apply_indexing - // CHECK-SAME: #xla_gpu.indexing_map<(d0, d1) -> (d0 + d1 - 2), - // CHECK-SAME: d0 in [0, 11], d1 in [0, 4], is_simplified: true>(%[[H]], %[[Y]]) + // CHECK-SAME: #xla_gpu.indexing_map<"(d0, d1) -> (d0 + d1 - 2), + // CHECK-SAME: d0 in [0, 11], d1 in [0, 4], is_simplified: true">(%[[H]], %[[Y]]) // CHECK-DAG: %[[VL:.+]] = tensor.extract %[[LHS]][%[[B]], %[[XX0]], %[[XX1]], %[[I]]] : tensor<2x8x12x4xf32> // CHECK-DAG: %[[VR:.+]] = tensor.extract %[[RHS]][%[[I]], %[[X]], %[[Y]], %[[O]]] : tensor<4x3x5x16xf32> // CHECK: %[[MUL:.+]] = arith.mulf %[[VL]], %[[VR]] : f32 @@ -1025,17 +1025,17 @@ TEST_F(ElementalHloToMlirTest, ConvolutionWithLhsDilation) { // CHECK: %[[R0:.+]] = scf.for %[[X:.+]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[A0:.+]] = %[[INIT]]) -> (f32) { // CHECK-NEXT: %[[R1:.+]] = scf.for %[[Y:.+]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[A1:.+]] = %[[A0]]) -> (f32) { // CHECK-NEXT: %[[R2:.+]] = scf.for %[[I:.+]] = %[[C0]] to %[[C4]] step %[[C1]] iter_args(%[[ACC:.+]] = %[[A1]]) -> (f32) { - // CHECK-DAG: %[[TESTX:.+]] = xla_gpu.apply_indexing #xla_gpu.indexing_map<(d0, d1) -> ((d0 + d1) mod 2), domain: d0 in [0, 12], d1 in [0, 2], is_simplified: true>(%[[W]], %[[X]]) + // CHECK-DAG: %[[TESTX:.+]] = xla_gpu.apply_indexing #xla_gpu.indexing_map<"(d0, d1) -> ((d0 + d1) mod 2), domain: d0 in [0, 12], d1 in [0, 2], is_simplified: true">(%[[W]], %[[X]]) // CHECK-DAG: %[[TX:.+]] = arith.cmpi eq, %[[TESTX]], %[[C0]] : index - // CHECK-DAG: %[[TESTY:.+]] = xla_gpu.apply_indexing #xla_gpu.indexing_map<(d0, d1) -> ((d0 + d1) mod 2), domain: d0 in [0, 18], d1 in [0, 4], is_simplified: true>(%[[H]], %[[Y]]) + // CHECK-DAG: %[[TESTY:.+]] = xla_gpu.apply_indexing #xla_gpu.indexing_map<"(d0, d1) -> ((d0 + d1) mod 2), domain: d0 in [0, 18], d1 in [0, 4], is_simplified: true">(%[[H]], %[[Y]]) // CHECK-DAG: %[[TY:.+]] = arith.cmpi eq, %[[TESTY]], %[[C0]] : index // CHECK: %[[R3:.+]] = scf.if {{.+}} -> (f32) { // CHECK: %[[XX0:.+]] = xla_gpu.apply_indexing - // CHECK-SAME: #xla_gpu.indexing_map<(d0, d1) -> ((d0 + d1) floordiv 2), - // CHECK-SAME: d0 in [0, 12], d1 in [0, 2], is_simplified: true>(%[[W]], %[[X]]) + // CHECK-SAME: #xla_gpu.indexing_map<"(d0, d1) -> ((d0 + d1) floordiv 2), + // CHECK-SAME: d0 in [0, 12], d1 in [0, 2], is_simplified: true">(%[[W]], %[[X]]) // CHECK: %[[XX1:.+]] = xla_gpu.apply_indexing - // CHECK-SAME: #xla_gpu.indexing_map<(d0, d1) -> ((d0 + d1) floordiv 2), - // CHECK-SAME: d0 in [0, 18], d1 in [0, 4], is_simplified: true>(%[[H]], %[[Y]]) + // CHECK-SAME: #xla_gpu.indexing_map<"(d0, d1) -> ((d0 + d1) floordiv 2), + // CHECK-SAME: d0 in [0, 18], d1 in [0, 4], is_simplified: true">(%[[H]], %[[Y]]) // CHECK-DAG: %[[VL:.+]] = tensor.extract %[[LHS]][%[[B]], %[[XX0]], %[[XX1]], %[[I]]] : tensor<2x8x12x4xf32> // CHECK-DAG: %[[VR:.+]] = tensor.extract %[[RHS]][%[[I]], %[[X]], %[[Y]], %[[O]]] : tensor<4x3x5x16xf32> // CHECK: %[[MUL:.+]] = arith.mulf %[[VL]], %[[VR]] : f32 @@ -1077,11 +1077,11 @@ TEST_F(ElementalHloToMlirTest, ConvolutionWithRhsDilation) { // CHECK-NEXT: %[[R2:.+]] = scf.for %[[I:.+]] = %[[C0]] to %[[C4]] step %[[C1]] iter_args(%[[ACC:.+]] = %[[A1]]) -> (f32) { // CHECK: %[[R3:.+]] = scf.if {{.+}} -> (f32) { // CHECK: %[[XX0:.+]] = xla_gpu.apply_indexing - // CHECK-SAME: #xla_gpu.indexing_map<(d0, d1) -> (d1 * 2 + d0), - // CHECK-SAME: d0 in [0, 3], d1 in [0, 2], is_simplified: true>(%[[W]], %[[X]]) + // CHECK-SAME: #xla_gpu.indexing_map<"(d0, d1) -> (d1 * 2 + d0), + // CHECK-SAME: d0 in [0, 3], d1 in [0, 2], is_simplified: true">(%[[W]], %[[X]]) // CHECK: %[[XX1:.+]] = xla_gpu.apply_indexing - // CHECK-SAME: #xla_gpu.indexing_map<(d0, d1) -> (d1 * 2 + d0), - // CHECK-SAME: d0 in [0, 3], d1 in [0, 4], is_simplified: true>(%[[H]], %[[Y]]) + // CHECK-SAME: #xla_gpu.indexing_map<"(d0, d1) -> (d1 * 2 + d0), + // CHECK-SAME: d0 in [0, 3], d1 in [0, 4], is_simplified: true">(%[[H]], %[[Y]]) // CHECK-DAG: %[[VL:.+]] = tensor.extract %[[LHS]][%[[B]], %[[XX0]], %[[XX1]], %[[I]]] : tensor<2x8x12x4xf32> // CHECK-DAG: %[[VR:.+]] = tensor.extract %[[RHS]][%[[I]], %[[X]], %[[Y]], %[[O]]] : tensor<4x3x5x16xf32> // CHECK: %[[MUL:.+]] = arith.mulf %[[VL]], %[[VR]] : f32 @@ -1123,14 +1123,14 @@ TEST_F(ElementalHloToMlirTest, ConvolutionWithFeatureGroupCount) { // CHECK-NEXT: %[[R2:.+]] = scf.for %[[I:.+]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ACC:.+]] = %[[A1]]) -> (f32) { // CHECK: %[[R3:.+]] = scf.if {{.+}} -> (f32) { // CHECK: %[[XX0:.+]] = xla_gpu.apply_indexing - // CHECK-SAME: #xla_gpu.indexing_map<(d0, d1) -> (d0 + d1), - // CHECK-SAME: d0 in [0, 5], d1 in [0, 2], is_simplified: true>(%[[W]], %[[X]]) + // CHECK-SAME: #xla_gpu.indexing_map<"(d0, d1) -> (d0 + d1), + // CHECK-SAME: d0 in [0, 5], d1 in [0, 2], is_simplified: true">(%[[W]], %[[X]]) // CHECK: %[[XX1:.+]] = xla_gpu.apply_indexing - // CHECK-SAME: #xla_gpu.indexing_map<(d0, d1) -> (d0 + d1), - // CHECK-SAME: d0 in [0, 7], d1 in [0, 4], is_simplified: true>(%[[H]], %[[Y]]) + // CHECK-SAME: #xla_gpu.indexing_map<"(d0, d1) -> (d0 + d1), + // CHECK-SAME: d0 in [0, 7], d1 in [0, 4], is_simplified: true">(%[[H]], %[[Y]]) // CHECK: %[[XX2:.+]] = xla_gpu.apply_indexing - // CHECK-SAME: #xla_gpu.indexing_map<(d0, d1) -> ((d0 floordiv 8) * 2 + d1), - // CHECK-SAME: d0 in [0, 15], d1 in [0, 1], is_simplified: true>(%[[O]], %[[I]]) + // CHECK-SAME: #xla_gpu.indexing_map<"(d0, d1) -> ((d0 floordiv 8) * 2 + d1), + // CHECK-SAME: d0 in [0, 15], d1 in [0, 1], is_simplified: true">(%[[O]], %[[I]]) // CHECK-DAG: %[[VL:.+]] = tensor.extract %[[LHS]][%[[B]], %[[XX0]], %[[XX1]], %[[XX2]]] : tensor<2x8x12x4xf32> // CHECK-DAG: %[[VR:.+]] = tensor.extract %[[RHS]][%[[I]], %[[X]], %[[Y]], %[[O]]] : tensor<2x3x5x16xf32> // CHECK: %[[MUL:.+]] = arith.mulf %[[VL]], %[[VR]] : f32 @@ -1174,11 +1174,11 @@ TEST_F(ElementalHloToMlirTest, ConvolutionWithBatchGroupCount) { // CHECK-NEXT: %[[R3:.+]] = scf.for %[[G:.+]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ACC:.+]] = %[[A2]]) -> (f32) { // CHECK: %[[R4:.+]] = scf.if {{.+}} -> (f32) { // CHECK: %[[XX0:.+]] = xla_gpu.apply_indexing - // CHECK-SAME: #xla_gpu.indexing_map<(d0, d1) -> (d0 + d1), - // CHECK-SAME: d0 in [0, 5], d1 in [0, 2], is_simplified: true>(%[[W]], %[[X]]) + // CHECK-SAME: #xla_gpu.indexing_map<"(d0, d1) -> (d0 + d1), + // CHECK-SAME: d0 in [0, 5], d1 in [0, 2], is_simplified: true">(%[[W]], %[[X]]) // CHECK: %[[XX1:.+]] = xla_gpu.apply_indexing - // CHECK-SAME: #xla_gpu.indexing_map<(d0, d1) -> (d0 + d1), - // CHECK-SAME: d0 in [0, 7], d1 in [0, 4], is_simplified: true>(%[[H]], %[[Y]]) + // CHECK-SAME: #xla_gpu.indexing_map<"(d0, d1) -> (d0 + d1), + // CHECK-SAME: d0 in [0, 7], d1 in [0, 4], is_simplified: true">(%[[H]], %[[Y]]) // CHECK-DAG: %[[VL:.+]] = tensor.extract %[[LHS]][%[[G]], %[[XX0]], %[[XX1]], %[[I]]] : tensor<2x8x12x4xf32> // CHECK-DAG: %[[VR:.+]] = tensor.extract %[[RHS]][%[[I]], %[[X]], %[[Y]], %[[O]]] : tensor<4x3x5x16xf32> // CHECK: %[[MUL:.+]] = arith.mulf %[[VL]], %[[VR]] : f32 @@ -1644,8 +1644,8 @@ TEST_F(ElementalHloToMlirTest, MixedIndexingTuple) { // CHECK-SAME: %[[X:.*]]: index {{{.*}}}, %[[Y:.*]]: index {{{.*}}} // CHECK: %[[A:.*]] = tensor.extract %[[P0]][%[[X]], %[[Y]]] // CHECK: %[[IDX:.*]] = xla_gpu.apply_indexing - // CHECK-SAME: #xla_gpu.indexing_map<(d0, d1) -> (d0 * 10 + d1), - // CHECK-SAME: d0 in [0, 9], d1 in [0, 9], is_simplified: true>(%[[X]], %[[Y]]) + // CHECK-SAME: #xla_gpu.indexing_map<"(d0, d1) -> (d0 * 10 + d1), + // CHECK-SAME: d0 in [0, 9], d1 in [0, 9], is_simplified: true">(%[[X]], %[[Y]]) // CHECK: %[[B:.*]] = tensor.extract %[[P1]][%[[IDX]]] // CHECK: return %[[A]], %[[B]] )")); @@ -1668,8 +1668,8 @@ TEST_F(ElementalHloToMlirTest, NestedTuple) { // CHECK-SAME: %[[X:.*]]: index {{{.*}}}, %[[Y:.*]]: index {{{.*}}} // CHECK: %[[P0_V:.*]] = xla_gpu.pure_call @main_p0 // CHECK: %[[IDX:.*]] = - // CHECK-SAME: #xla_gpu.indexing_map<(d0, d1) -> (d0 * 10 + d1), - // CHECK-SAME: d0 in [0, 9], d1 in [0, 9], is_simplified: true>(%[[X]], %[[Y]]) + // CHECK-SAME: #xla_gpu.indexing_map<"(d0, d1) -> (d0 * 10 + d1), + // CHECK-SAME: d0 in [0, 9], d1 in [0, 9], is_simplified: true">(%[[X]], %[[Y]]) // CHECK: %[[P1_V:.*]] = xla_gpu.pure_call @main_p1 // CHECK-SAME: (%[[P0]], %[[P1]], %[[IDX]]) // CHECK: return %[[P0_V]], %[[P1_V]], %[[P1_V]], %[[P1_V]], %[[P0_V]] diff --git a/third_party/xla/xla/service/gpu/fusions/tests/concatenate/concat_1d.hlo b/third_party/xla/xla/service/gpu/fusions/tests/concatenate/concat_1d.hlo index f99ff371ef38d1..5ac91b201c6168 100644 --- a/third_party/xla/xla/service/gpu/fusions/tests/concatenate/concat_1d.hlo +++ b/third_party/xla/xla/service/gpu/fusions/tests/concatenate/concat_1d.hlo @@ -8,10 +8,10 @@ fusion { param2 = f32[300] parameter(2) ROOT concat = f32[900] concatenate(param0, param1, param2), dimensions={0} } -// CHECK-DAG: #[[MAP:.*]] = #xla_gpu.indexing_map<(d0, d1) -> (d1 * 128 + d0) -// CHECK-DAG: #[[LOOPMAP_1:.*]] = #xla_gpu.indexing_map<(d0, d1, d2, d3, d4, d5)[s0, s1] -> (d3 * 128 + d0) -// CHECK-DAG: #[[LOOPMAP_2:.*]] = #xla_gpu.indexing_map<(d0, d1, d2, d3, d4, d5)[s0, s1] -> (d3 * 128 + d0 + 200) -// CHECK-DAG: #[[LOOPMAP_3:.*]] = #xla_gpu.indexing_map<(d0, d1, d2, d3, d4, d5)[s0, s1] -> (d3 * 128 + d0 + 600) +// CHECK-DAG: #[[MAP:.*]] = #xla_gpu.indexing_map<"(d0, d1) -> (d1 * 128 + d0) +// CHECK-DAG: #[[LOOPMAP_1:.*]] = #xla_gpu.indexing_map<"(d0, d1, d2, d3, d4, d5)[s0, s1] -> (d3 * 128 + d0) +// CHECK-DAG: #[[LOOPMAP_2:.*]] = #xla_gpu.indexing_map<"(d0, d1, d2, d3, d4, d5)[s0, s1] -> (d3 * 128 + d0 + 200) +// CHECK-DAG: #[[LOOPMAP_3:.*]] = #xla_gpu.indexing_map<"(d0, d1, d2, d3, d4, d5)[s0, s1] -> (d3 * 128 + d0 + 600) // CHECK: func.func @main // CHECK-SAME: %[[ARG_0:[a-zA-Z0-9]*]]: {{[^,]*}}, diff --git a/third_party/xla/xla/service/gpu/fusions/tests/loop/tuple_heterogeneous.hlo b/third_party/xla/xla/service/gpu/fusions/tests/loop/tuple_heterogeneous.hlo index 3b5e454584137a..4f93eacbfab93d 100644 --- a/third_party/xla/xla/service/gpu/fusions/tests/loop/tuple_heterogeneous.hlo +++ b/third_party/xla/xla/service/gpu/fusions/tests/loop/tuple_heterogeneous.hlo @@ -12,8 +12,8 @@ fusion { ROOT tuple = (f64[8], f64[2,4]) tuple(minimum, bc) } -// CHECK: #[[MAJOR:.*]] = #xla_gpu.indexing_map<(d0) -> (d0 floordiv 4), -// CHECK: #[[MINOR:.*]] = #xla_gpu.indexing_map<(d0) -> (d0 mod 4), +// CHECK: #[[MAJOR:.*]] = #xla_gpu.indexing_map<"(d0) -> (d0 floordiv 4), +// CHECK: #[[MINOR:.*]] = #xla_gpu.indexing_map<"(d0) -> (d0 mod 4), // CHECK: xla_gpu.loop ({{.*}})[{{.*}}] -> (%[[RA:.*]]) in // CHECK-DAG: %[[MAJOR_IDX:.*]] = xla_gpu.apply_indexing #[[MAJOR]] diff --git a/third_party/xla/xla/service/gpu/fusions/tests/scatter/unique_indices.hlo b/third_party/xla/xla/service/gpu/fusions/tests/scatter/unique_indices.hlo index a0663dd88308fb..88043829ebc8f2 100644 --- a/third_party/xla/xla/service/gpu/fusions/tests/scatter/unique_indices.hlo +++ b/third_party/xla/xla/service/gpu/fusions/tests/scatter/unique_indices.hlo @@ -24,7 +24,7 @@ scatter { unique_indices=true, to_apply=add } -// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0) -> (d0 floordiv 2) +// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<"(d0) -> (d0 floordiv 2) // CHECK-LABEL: func.func @main( // CHECK-SAME: %[[OPERAND:[a-zA-Z0-9]*]]: tensor<10x5xf32> @@ -60,4 +60,4 @@ scatter { // CHECK: %[[COMBINED:.*]] = arith.addf %[[CURRENT]], %[[UPD_ELEM]] // CHECK: %[[UPDATED:.*]] = tensor.insert %[[COMBINED]] // CHECK-SAME: into %{{[a-z0-9]+}}[%{{.*}}, %[[RC]]] : tensor<10x5xf32> -// CHECK: xla_gpu.yield %[[UPDATED]] : tensor<10x5xf32> \ No newline at end of file +// CHECK: xla_gpu.yield %[[UPDATED]] : tensor<10x5xf32> diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/tests/flatten_tensors.mlir b/third_party/xla/xla/service/gpu/fusions/transforms/tests/flatten_tensors.mlir index 1691d3fd748c23..e88324f698d489 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/tests/flatten_tensors.mlir +++ b/third_party/xla/xla/service/gpu/fusions/transforms/tests/flatten_tensors.mlir @@ -8,7 +8,7 @@ func.func @tensor_extract( : tensor<2x3xf32, dense<[0, 1]> : tensor<2xi64>> func.return %v : f32 } -// CHECK: #[[$MAP:.+]] = #xla_gpu.indexing_map<(d0, d1) -> (d1 * 2 + d0), domain: d0 in [0, 1], d1 in [0, 2], is_simplified: true> +// CHECK: #[[$MAP:.+]] = #xla_gpu.indexing_map<"(d0, d1) -> (d1 * 2 + d0), domain: d0 in [0, 1], d1 in [0, 2], is_simplified: true"> // CHECK-LABEL: func.func @tensor_extract( // CHECK-SAME: %[[SRC:.*]]: tensor<6xf32>, @@ -67,7 +67,7 @@ func.func @atomic_rmw(%in: tensor<2x4xf32>, %i: index, %j: index) } return %ret : tensor<2x4xf32> } -// CHECK: #[[$MAP:.+]] = #xla_gpu.indexing_map<(d0, d1) -> (d0 * 4 + d1), domain: d0 in [0, 1], d1 in [0, 3], is_simplified: true> +// CHECK: #[[$MAP:.+]] = #xla_gpu.indexing_map<"(d0, d1) -> (d0 * 4 + d1), domain: d0 in [0, 1], d1 in [0, 3], is_simplified: true"> // CHECK-LABEL: func.func @atomic_rmw( // CHECK-SAME: %[[TENSOR:.*]]: tensor<8xf32>, %[[I:.*]]: index, // CHECK-SAME: %[[J:.*]]: index) -> tensor<8xf32> { @@ -93,8 +93,8 @@ func.func @for_loop(%t0: tensor<32x1024xf32>, %t1: tensor<64x8x4xf32>) } {some_attr} return %for#0, %for#1, %c0_f32 : tensor<32x1024xf32>, tensor<64x8x4xf32>, f32 } -// CHECK: #[[$MAP0:.+]] = #xla_gpu.indexing_map<(d0) -> (d0 + 1024) -// CHECK: #[[$MAP1:.+]] = #xla_gpu.indexing_map<(d0) -> (d0 * 32 + 5) +// CHECK: #[[$MAP0:.+]] = #xla_gpu.indexing_map<"(d0) -> (d0 + 1024) +// CHECK: #[[$MAP1:.+]] = #xla_gpu.indexing_map<"(d0) -> (d0 * 32 + 5) // CHECK-LABEL: func.func @for_loop( // CHECK-SAME: %[[T0:.*]]: tensor<32768xf32>, // CHECK-SAME: %[[T1:.*]]: tensor<2048xf32>) -> (tensor<32768xf32>, tensor<2048xf32>, f32) { @@ -114,12 +114,9 @@ func.func @for_loop(%t0: tensor<32x1024xf32>, %t1: tensor<64x8x4xf32>) // ----- -#map = #xla_gpu.indexing_map<(d0, d1) -> ((d1 * 128 + d0) floordiv 36), - domain: d0 in [0, 127], d1 in [0, 393749], is_simplified: true> -#map1 = #xla_gpu.indexing_map<(d0, d1) -> (((d1 * 128 + d0) floordiv 9) mod 4), - domain: d0 in [0, 127], d1 in [0, 393749], is_simplified: true> -#map2 = #xla_gpu.indexing_map<(d0, d1) -> ((d1 * 128 + d0) mod 9), - domain: d0 in [0, 127], d1 in [0, 393749], is_simplified: true> +#map = #xla_gpu.indexing_map<"(d0, d1) -> ((d1 * 128 + d0) floordiv 36), domain: d0 in [0, 127], d1 in [0, 393749], is_simplified: true"> +#map1 = #xla_gpu.indexing_map<"(d0, d1) -> (((d1 * 128 + d0) floordiv 9) mod 4), domain: d0 in [0, 127], d1 in [0, 393749], is_simplified: true"> +#map2 = #xla_gpu.indexing_map<"(d0, d1) -> ((d1 * 128 + d0) mod 9), domain: d0 in [0, 127], d1 in [0, 393749], is_simplified: true"> func.func @if_op(%arg0: tensor<4000x4x9xf32>, %arg1: tensor<1400x1xi32>, %arg2: tensor<1400x1x4x9xf32>, %arg3: tensor<4000x4x9xf32>) -> tensor<4000x4x9xf32> { @@ -225,7 +222,7 @@ func.func @vector_extract(%arg0: vector<2x3xf32>, %arg1: index) -> f32 { %v = vector.extract %arg0[%arg1, 2] : f32 from vector<2x3xf32> func.return %v : f32 } -// CHECK: #[[$MAP:.+]] = #xla_gpu.indexing_map<(d0) -> (d0 * 3 + 2), +// CHECK: #[[$MAP:.+]] = #xla_gpu.indexing_map<"(d0) -> (d0 * 3 + 2), // CHECK-SAME: domain: d0 in [0, 1] // CHECK-LABEL: func.func @vector_extract( @@ -241,7 +238,7 @@ func.func @vector_insert(%arg0: vector<10x24xf32>, %i: index) %out = vector.insert %scalar, %arg0 [1, %i] : f32 into vector<10x24xf32> func.return %out : vector<10x24xf32> } -// CHECK: #[[$MAP:.+]] = #xla_gpu.indexing_map<(d0) -> (d0 + 24), +// CHECK: #[[$MAP:.+]] = #xla_gpu.indexing_map<"(d0) -> (d0 + 24), // CHECK-SAME: domain: d0 in [0, 23] // CHECK-LABEL: func.func @vector_insert( // CHECK-SAME: %[[VECTOR:.*]]: vector<240xf32>, %[[I:.*]]: index) -> @@ -290,8 +287,8 @@ func.func @for_loop_vector(%t0: vector<32x1024xf32>, %t1: vector<64x8x4xf32>) return %for#0, %for#1, %c0_f32 : vector<32x1024xf32>, vector<64x8x4xf32>, f32 } -// CHECK: #[[$MAP0:.+]] = #xla_gpu.indexing_map<(d0) -> (d0 + 1024) -// CHECK: #[[$MAP1:.+]] = #xla_gpu.indexing_map<(d0) -> (d0 * 32 + 5) +// CHECK: #[[$MAP0:.+]] = #xla_gpu.indexing_map<"(d0) -> (d0 + 1024) +// CHECK: #[[$MAP1:.+]] = #xla_gpu.indexing_map<"(d0) -> (d0 * 32 + 5) // CHECK-LABEL: func.func @for_loop_vector( // CHECK-SAME: %[[V0:.*]]: vector<32768xf32>, // CHECK-SAME: %[[V1:.*]]: vector<2048xf32>) -> diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/tests/fuse_loops.mlir b/third_party/xla/xla/service/gpu/fusions/transforms/tests/fuse_loops.mlir index 557335b6a7ff72..594c8e1deec7d2 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/tests/fuse_loops.mlir +++ b/third_party/xla/xla/service/gpu/fusions/transforms/tests/fuse_loops.mlir @@ -1,24 +1,24 @@ // RUN: mlir_fusions_opt -split-input-file %s -xla-gpu-fuse-loops \ // RUN: | FileCheck %s -#indexing_map = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> - (d1 floordiv 30, - ((d1 floordiv 6) mod 5) * 32 + s0 * 4 + d0 floordiv 32, - (d1 mod 6) * 32 + d0 mod 32), - domain: - d0 in [0, 127], d1 in [0, 599], - s0 in [0, 7], s1 in [0, 0], - (d1 mod 6) * 32 + d0 mod 32 in [0, 169], - is_simplified: true> -#indexing_map1 = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> - (0, - d0 mod 32, - d0 floordiv 32 + s0 * 4), - domain: - d0 in [0, 127], d1 in [0, 599], - s0 in [0, 7], s1 in [0, 0], - (d1 mod 6) * 32 + d0 mod 32 in [0, 169], - is_simplified: true> +#indexing_map = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] ->" +" (d1 floordiv 30," +" ((d1 floordiv 6) mod 5) * 32 + s0 * 4 + d0 floordiv 32," +" (d1 mod 6) * 32 + d0 mod 32)," +" domain:" +" d0 in [0, 127], d1 in [0, 599]," +" s0 in [0, 7], s1 in [0, 0]," +" (d1 mod 6) * 32 + d0 mod 32 in [0, 169]," +" is_simplified: true"> +#indexing_map1 = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] ->" +" (0," +" d0 mod 32," +" d0 floordiv 32 + s0 * 4)," +" domain:" +" d0 in [0, 127], d1 in [0, 599]," +" s0 in [0, 7], s1 in [0, 0]," +" (d1 mod 6) * 32 + d0 mod 32 in [0, 169]," +" is_simplified: true"> func.func @fuse_loops(%arg0: tensor<20x160x170xf32>) -> tensor<1x32x33xf32> { %cst = arith.constant dense<0.000000e+00> : vector<8x1xf32> %c0 = arith.constant 0 : index @@ -43,7 +43,7 @@ func.func @fuse_loops(%arg0: tensor<20x160x170xf32>) -> tensor<1x32x33xf32> { } -// CHECK: #[[$FUSED_MAP:.*]] = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> +// CHECK: #[[$FUSED_MAP:.*]] = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> // CHECK-SAME: (d1 floordiv 30, ((d1 floordiv 6) mod 5) * 32 + s0 * 4 + d0 floordiv 32, // CHECK-SAME: (d1 mod 6) * 32 + d0 mod 32, 0, d0 mod 32, d0 floordiv 32 + s0 * 4), // CHECK-SAME: domain: d0 in [0, 127], d1 in [0, 599], @@ -60,24 +60,24 @@ func.func @fuse_loops(%arg0: tensor<20x160x170xf32>) -> tensor<1x32x33xf32> { // ----- -#indexing_map = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> - (d1 floordiv 30, - ((d1 floordiv 6) mod 5) * 32 + s0 * 4 + d0 floordiv 32, - (d1 mod 6) * 32 + d0 mod 32), - domain: - d0 in [0, 127], d1 in [0, 599], - s0 in [0, 7], s1 in [0, 0], - (d1 mod 6) * 32 + d0 mod 32 in [0, 169], - is_simplified: true> -#indexing_map1 = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> - (0, - d0 mod 32, - d0 floordiv 32 + s0 * 4), - domain: - d0 in [0, 127], d1 in [0, 599], - s0 in [0, 7], s1 in [0, 0], - (d1 mod 6) * 32 + d0 mod 32 in [0, 169], - is_simplified: true> +#indexing_map = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] ->" +" (d1 floordiv 30," +" ((d1 floordiv 6) mod 5) * 32 + s0 * 4 + d0 floordiv 32," +" (d1 mod 6) * 32 + d0 mod 32)," +" domain:" +" d0 in [0, 127], d1 in [0, 599]," +" s0 in [0, 7], s1 in [0, 0]," +" (d1 mod 6) * 32 + d0 mod 32 in [0, 169]," +" is_simplified: true"> +#indexing_map1 = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] ->" +" (0," +" d0 mod 32," +" d0 floordiv 32 + s0 * 4)," +" domain:" +" d0 in [0, 127], d1 in [0, 599]," +" s0 in [0, 7], s1 in [0, 0]," +" (d1 mod 6) * 32 + d0 mod 32 in [0, 169]," +" is_simplified: true"> func.func @do_not_fuse_index_mismatch(%arg0: tensor<20x160x170xf32>) -> tensor<1x32x33xf32> { %cst = arith.constant dense<0.000000e+00> : vector<8x1xf32> %c0 = arith.constant 0 : index @@ -108,24 +108,24 @@ func.func @do_not_fuse_index_mismatch(%arg0: tensor<20x160x170xf32>) -> tensor<1 // ----- -#indexing_map = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> - (d1 floordiv 30, - ((d1 floordiv 6) mod 5) * 32 + s0 * 4 + d0 floordiv 32, - (d1 mod 6) * 32 + d0 mod 32), - domain: - d0 in [0, 127], d1 in [0, 599], - s0 in [0, 7], s1 in [0, 0], - (d1 mod 6) * 32 + d0 mod 32 in [0, 169], - is_simplified: true> -#indexing_map1 = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> - (0, - d0 mod 32, - d0 floordiv 32 + s0 * 4), - domain: - d0 in [0, 127], d1 in [0, 599], - s0 in [0, 7], s1 in [0, 0], - (d1 mod 6) * 32 + d0 mod 32 in [0, 169], - is_simplified: true> +#indexing_map = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] ->" +" (d1 floordiv 30," +" ((d1 floordiv 6) mod 5) * 32 + s0 * 4 + d0 floordiv 32," +" (d1 mod 6) * 32 + d0 mod 32)," +" domain:" +" d0 in [0, 127], d1 in [0, 599]," +" s0 in [0, 7], s1 in [0, 0]," +" (d1 mod 6) * 32 + d0 mod 32 in [0, 169]," +" is_simplified: true"> +#indexing_map1 = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] ->" +" (0," +" d0 mod 32," +" d0 floordiv 32 + s0 * 4)," +" domain:" +" d0 in [0, 127], d1 in [0, 599]," +" s0 in [0, 7], s1 in [0, 0]," +" (d1 mod 6) * 32 + d0 mod 32 in [0, 169]," +" is_simplified: true"> func.func @do_not_fuse_multiple_uses(%arg0: tensor<20x160x170xf32>) -> tensor<1x32x33xf32> { %cst = arith.constant dense<0.000000e+00> : vector<8x1xf32> %c0 = arith.constant 0 : index @@ -158,24 +158,24 @@ func.func @do_not_fuse_multiple_uses(%arg0: tensor<20x160x170xf32>) -> tensor<1x // ----- -#indexing_map = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> - (d1 floordiv 30, - ((d1 floordiv 6) mod 5) * 32 + s0 * 4 + d0 floordiv 32, - (d1 mod 6) * 32 + d0 mod 32), - domain: - d0 in [0, 127], d1 in [0, 599], - s0 in [0, 7], s1 in [0, 0], - (d1 mod 6) * 32 + d0 mod 32 in [0, 169], - is_simplified: true> -#indexing_map1 = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> - (0, - d0 mod 32, - d0 floordiv 32 + s0 * 4), - domain: - d0 in [0, 127], d1 in [0, 599], - s0 in [0, 5], s1 in [0, 0], - (d1 mod 6) * 32 + d0 mod 32 in [0, 169], - is_simplified: true> +#indexing_map = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] ->" +" (d1 floordiv 30," +" ((d1 floordiv 6) mod 5) * 32 + s0 * 4 + d0 floordiv 32," +" (d1 mod 6) * 32 + d0 mod 32)," +" domain:" +" d0 in [0, 127], d1 in [0, 599]," +" s0 in [0, 7], s1 in [0, 0]," +" (d1 mod 6) * 32 + d0 mod 32 in [0, 169]," +" is_simplified: true"> +#indexing_map1 = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] ->" +" (0," +" d0 mod 32," +" d0 floordiv 32 + s0 * 4)," +" domain:" +" d0 in [0, 127], d1 in [0, 599]," +" s0 in [0, 5], s1 in [0, 0]," +" (d1 mod 6) * 32 + d0 mod 32 in [0, 169]," +" is_simplified: true"> func.func @do_not_fuse_map_domain_mismatch(%arg0: tensor<20x160x170xf32>) -> tensor<1x32x33xf32> { %cst = arith.constant dense<0.000000e+00> : vector<8x1xf32> %c0 = arith.constant 0 : index @@ -207,24 +207,24 @@ func.func @do_not_fuse_map_domain_mismatch(%arg0: tensor<20x160x170xf32>) -> ten // ----- -#indexing_map = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> - (d1 floordiv 30, - ((d1 floordiv 6) mod 5) * 32 + s0 * 4 + d0 floordiv 32, - (d1 mod 6) * 32 + d0 mod 32), - domain: - d0 in [0, 127], d1 in [0, 599], - s0 in [0, 7], s1 in [0, 0], - (d1 mod 6) * 32 + d0 mod 32 in [0, 169], - is_simplified: true> -#indexing_map1 = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> - (0, - d0 mod 32, - d0 floordiv 32 + s0 * 4), - domain: - d0 in [0, 127], d1 in [0, 599], - s0 in [0, 7], s1 in [0, 0], - (d1 mod 5) * 32 + d0 mod 32 in [0, 169], - is_simplified: true> +#indexing_map = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] ->" +" (d1 floordiv 30," +" ((d1 floordiv 6) mod 5) * 32 + s0 * 4 + d0 floordiv 32," +" (d1 mod 6) * 32 + d0 mod 32)," +" domain:" +" d0 in [0, 127], d1 in [0, 599]," +" s0 in [0, 7], s1 in [0, 0]," +" (d1 mod 6) * 32 + d0 mod 32 in [0, 169]," +" is_simplified: true"> +#indexing_map1 = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] ->" +" (0," +" d0 mod 32," +" d0 floordiv 32 + s0 * 4)," +" domain:" +" d0 in [0, 127], d1 in [0, 599]," +" s0 in [0, 7], s1 in [0, 0]," +" (d1 mod 5) * 32 + d0 mod 32 in [0, 169]," +" is_simplified: true"> func.func @do_not_fuse_map_constraint_mismatch(%arg0: tensor<20x160x170xf32>) -> tensor<1x32x33xf32> { %cst = arith.constant dense<0.000000e+00> : vector<8x1xf32> %c0 = arith.constant 0 : index @@ -256,24 +256,24 @@ func.func @do_not_fuse_map_constraint_mismatch(%arg0: tensor<20x160x170xf32>) -> // ----- -#indexing_map = #xla_gpu.indexing_map<(d0, d1)[s0, s1, s2] -> - (d1 floordiv 30, - ((d1 floordiv 6) mod 5) * 32 + s0 * 4 + d0 floordiv 32, - (d1 mod 6) * 32 + d0 mod 32), - domain: - d0 in [0, 127], d1 in [0, 599], - s0 in [0, 7], s1 in [0, 0], s2 in [0, 1], - (d1 mod 6) * 32 + d0 mod 32 in [0, 169], - is_simplified: true> -#indexing_map1 = #xla_gpu.indexing_map<(d0, d1)[s0, s1, s2] -> - (0, - d0 mod 32, - d0 floordiv 32 + s0 * 4), - domain: - d0 in [0, 127], d1 in [0, 599], - s0 in [0, 7], s1 in [0, 0], s2 in [0, 1], - (d1 mod 6) * 32 + d0 mod 32 in [0, 169], - is_simplified: true> +#indexing_map = #xla_gpu.indexing_map<"(d0, d1)[s0, s1, s2] ->" +" (d1 floordiv 30," +" ((d1 floordiv 6) mod 5) * 32 + s0 * 4 + d0 floordiv 32," +" (d1 mod 6) * 32 + d0 mod 32)," +" domain:" +" d0 in [0, 127], d1 in [0, 599]," +" s0 in [0, 7], s1 in [0, 0], s2 in [0, 1]," +" (d1 mod 6) * 32 + d0 mod 32 in [0, 169]," +" is_simplified: true"> +#indexing_map1 = #xla_gpu.indexing_map<"(d0, d1)[s0, s1, s2] ->" +" (0," +" d0 mod 32," +" d0 floordiv 32 + s0 * 4)," +" domain:" +" d0 in [0, 127], d1 in [0, 599]," +" s0 in [0, 7], s1 in [0, 0], s2 in [0, 1]," +" (d1 mod 6) * 32 + d0 mod 32 in [0, 169]," +" is_simplified: true"> func.func @do_not_fuse_unused_loop_iv(%arg0: tensor<20x160x170xf32>) -> tensor<1x32x33xf32> { %cst = arith.constant dense<0.000000e+00> : vector<8x1xf32> %c0 = arith.constant 0 : index diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/tests/lower_xla_gpu_loops_to_scf.mlir b/third_party/xla/xla/service/gpu/fusions/transforms/tests/lower_xla_gpu_loops_to_scf.mlir index f02f7012b80cf4..427e764d12b914 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/tests/lower_xla_gpu_loops_to_scf.mlir +++ b/third_party/xla/xla/service/gpu/fusions/transforms/tests/lower_xla_gpu_loops_to_scf.mlir @@ -1,9 +1,9 @@ // RUN: mlir_fusions_opt %s -xla-gpu-lower-xla-gpu-loops-to-scf \ // RUN: --split-input-file | FileCheck %s -#map = #xla_gpu.indexing_map<(d0)[s0, s1] -> (s0 + 1, s1 - 1), - domain: d0 in [0, 3], s0 in [0, 1024], s1 in [0, 32], s0 + s1 in [0, 90], - is_simplified: false> +#map = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (s0 + 1, s1 - 1)," + "domain: d0 in [0, 3], s0 in [0, 1024], s1 in [0, 32], s0 + s1 in [0, 90]," + "is_simplified: false"> func.func @loop_op(%input: tensor<1024x32xf32>, %init: f32, %dim: index) -> (f32) { %sum = xla_gpu.loop (%dim)[%i, %j] -> (%ra, %rb) @@ -15,9 +15,9 @@ func.func @loop_op(%input: tensor<1024x32xf32>, %init: f32, %dim: index) -> (f32 func.return %sum : f32 } -// CHECK-DAG: #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0)[s0, s1] -> (s0 + s1), -// CHECK-DAG: #[[$MAPA:.*]] = #xla_gpu.indexing_map<(d0)[s0, s1] -> (s0 + 1), -// CHECK-DAG: #[[$MAPB:.*]] = #xla_gpu.indexing_map<(d0)[s0, s1] -> (s1 - 1), +// CHECK-DAG: #[[$MAP:.*]] = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (s0 + s1), +// CHECK-DAG: #[[$MAPA:.*]] = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (s0 + 1), +// CHECK-DAG: #[[$MAPB:.*]] = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (s1 - 1), // CHECK-LABEL: func.func @loop_op( // CHECK-SAME: %[[IN:.*]]: tensor<1024x32xf32>, @@ -60,9 +60,9 @@ func.func @loop_op(%input: tensor<1024x32xf32>, %init: f32, %dim: index) -> (f32 // ----- -#map = #xla_gpu.indexing_map<(d0)[s0, s1] -> (s0 + 1, s1 - 1), - domain: d0 in [0, 3], s0 in [0, 1024], s1 in [0, 32], s0 + s1 in [0, 90], - is_simplified: false> +#map = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (s0 + 1, s1 - 1)," + "domain: d0 in [0, 3], s0 in [0, 1024], s1 in [0, 32], s0 + s1 in [0, 90]," + "is_simplified: false"> func.func @loop_yields_value_from_above(%input: tensor<1024x32xf32>, %init: f32, %dim: index) -> (f32) { diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/tests/lower_xla_gpu_to_scf.mlir b/third_party/xla/xla/service/gpu/fusions/transforms/tests/lower_xla_gpu_to_scf.mlir index dd15bdaafc533f..347ed9a943ef82 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/tests/lower_xla_gpu_to_scf.mlir +++ b/third_party/xla/xla/service/gpu/fusions/transforms/tests/lower_xla_gpu_to_scf.mlir @@ -124,12 +124,8 @@ func.func @predicated_extract( func.func private @exp(%p0: tensor<32x64xf32>, %i: index, %j: index) -> f32 -#map = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> (d1*32+d0*2+s0, s1), - domain: d0 in [0, 32], d1 in [0, 8], s0 in [0, 1], s1 in [0, 1], - is_simplified: false> -#map1 = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> (d0*2+s0, s1), - domain: d0 in [0, 32], d1 in [0, 2], s0 in [0, 1], s1 in [0, 1], - is_simplified: false> +#map = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d1*32+d0*2+s0, s1), domain: d0 in [0, 32], d1 in [0, 8], s0 in [0, 1], s1 in [0, 1], is_simplified: false"> +#map1 = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d0*2+s0, s1), domain: d0 in [0, 32], d1 in [0, 2], s0 in [0, 1], s1 in [0, 1], is_simplified: false"> func.func @materialize(%input: tensor<32x64xf32>, %i: index, %j: index) -> !xla_gpu.indexed_vector<32x2x2xf32, #map1> { @@ -137,8 +133,8 @@ func.func @materialize(%input: tensor<32x64xf32>, %i: index, %j: index) : (tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x2x2xf32, #map1> func.return %0 : !xla_gpu.indexed_vector<32x2x2xf32, #map1> } -// CHECK-DAG: #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> (d1 * 32 + d0 * 2 + s0, s1) -// CHECK-DAG: #[[$MAP1:.*]] = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> (d0 * 2 + s0, s1) +// CHECK-DAG: #[[$MAP:.*]] = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d1 * 32 + d0 * 2 + s0, s1) +// CHECK-DAG: #[[$MAP1:.*]] = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d0 * 2 + s0, s1) // CHECK: @materialize(%[[INPUT:.*]]: tensor<32x64xf32>, %[[INDEX1:.*]]: index, %[[INDEX2:.*]]: index) @@ -153,12 +149,8 @@ func.func @materialize(%input: tensor<32x64xf32>, %i: index, %j: index) // ----- -#map = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> (d1*32+d0*2+s0, s1), - domain: d0 in [0, 32], d1 in [0, 8], s0 in [0, 1], s1 in [0, 1], - is_simplified: false> -#map1 = #xla_gpu.indexing_map<(d0, d1) -> (d0 mod 16, d1), - domain: d0 in [0, 32], d1 in [0, 2], - is_simplified: false> +#map = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d1*32+d0*2+s0, s1), domain: d0 in [0, 32], d1 in [0, 8], s0 in [0, 1], s1 in [0, 1], is_simplified: false"> +#map1 = #xla_gpu.indexing_map<"(d0, d1) -> (d0 mod 16, d1), domain: d0 in [0, 32], d1 in [0, 2], is_simplified: false"> func.func @insert(%input: !xla_gpu.indexed_vector<32x64xf32, #map>, %i: index, %j: index, %output: tensor<32x64xf32>) -> tensor<32x64xf32> { @@ -166,8 +158,8 @@ func.func @insert(%input: !xla_gpu.indexed_vector<32x64xf32, #map>, : !xla_gpu.indexed_vector<32x64xf32, #map> -> tensor<32x64xf32> func.return %0 : tensor<32x64xf32> } -// CHECK-DAG: #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> (d1 * 32 + d0 * 2 + s0, s1) -// CHECK-DAG: #[[$MAP1:.*]] = #xla_gpu.indexing_map<(d0, d1) -> (d0 mod 16, d1) +// CHECK-DAG: #[[$MAP:.*]] = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d1 * 32 + d0 * 2 + s0, s1) +// CHECK-DAG: #[[$MAP1:.*]] = #xla_gpu.indexing_map<"(d0, d1) -> (d0 mod 16, d1) // CHECK: @insert(%[[INPUT:.*]]: !xla_gpu.indexed_vector<32x64xf32, #[[$MAP]]>, // CHECK-SAME: %[[I:.*]]: index, %[[J:.*]]: index, @@ -179,7 +171,7 @@ func.func @insert(%input: !xla_gpu.indexed_vector<32x64xf32, #map>, // CHECK: %[[SCALAR:.*]] = vector.extract %{{.*}}[%[[S0]], %[[S1]]] // CHECK-SAME: : f32 from vector<2x2xf32> -// CHECK: %[[MAP1_RESULT:.*]]:2 = xla_gpu.apply_indexing +// CHECK: %[[MAP1_RESULT:.*]]:2 = xla_gpu.apply_indexing // CHECK-SAME: #[[$MAP1]](%[[MAP_RESULT1]], %[[MAP_RESULT2]]) // CHECK: %[[NEW_TENSOR:.*]] = tensor.insert %[[SCALAR]] // CHECK-SAME: into %[[TENSOR]][%[[MAP1_RESULT]]#0, %[[MAP1_RESULT]]#1] @@ -189,15 +181,9 @@ func.func @insert(%input: !xla_gpu.indexed_vector<32x64xf32, #map>, func.func private @exp(%p0: tensor<32x64xf32>, %i: index, %j: index) -> f32 -#map = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> (d1*32+d0*2+s0, s1), - domain: d0 in [0, 32], d1 in [0, 8], s0 in [0, 1], s1 in [0, 1], - is_simplified: false> -#map1 = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> (d0*2+s0, s1), - domain: d0 in [0, 32], d1 in [0, 2], s0 in [0, 1], s1 in [0, 1], - is_simplified: false> -#map2 = #xla_gpu.indexing_map<(d0, d1) -> (d0, d1), - domain: d0 in [0, 32], d1 in [0, 2], - is_simplified: false> +#map = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d1*32+d0*2+s0, s1), domain: d0 in [0, 32], d1 in [0, 8], s0 in [0, 1], s1 in [0, 1], is_simplified: false"> +#map1 = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d0*2+s0, s1), domain: d0 in [0, 32], d1 in [0, 2], s0 in [0, 1], s1 in [0, 1], is_simplified: false"> +#map2 = #xla_gpu.indexing_map<"(d0, d1) -> (d0, d1), domain: d0 in [0, 32], d1 in [0, 2], is_simplified: false"> func.func @materialize_and_insert(%input: tensor<32x64xf32>, %i: index, %j: index, %output: tensor<32x64xf32>) -> tensor<32x64xf32> { @@ -213,12 +199,8 @@ func.func @materialize_and_insert(%input: tensor<32x64xf32>, %i: index, func.func private @exp(%p0: tensor<32x64xcomplex>, %i: index, %j: index) -> complex -#map = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> (d1*32+d0*2+s0, s1), - domain: d0 in [0, 32], d1 in [0, 8], - s0 in [0, 2], s1 in [0, 3], is_simplified: false> -#map1 = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> (d0*2+s0, s1), - domain: d0 in [0, 32], d1 in [0, 2], - s0 in [0, 2], s1 in [0, 3], is_simplified: false> +#map = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d1*32+d0*2+s0, s1), domain: d0 in [0, 32], d1 in [0, 8], s0 in [0, 2], s1 in [0, 3], is_simplified: false"> +#map1 = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d0*2+s0, s1), domain: d0 in [0, 32], d1 in [0, 2], s0 in [0, 2], s1 in [0, 3], is_simplified: false"> func.func @materialize_complex( %input: tensor<32x64xcomplex>, %output: tensor<32x64xcomplex>, @@ -245,11 +227,8 @@ func.func @materialize_complex( // ----- -#map1 = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> (d0*2+s0, s1), - domain: d0 in [0, 32], d1 in [0, 2], - s0 in [0, 2], s1 in [0, 3], is_simplified: false> -#map2 = #xla_gpu.indexing_map<(d0, d1) -> (d0, d1), - domain: d0 in [0, 32], d1 in [0, 2], is_simplified: false> +#map1 = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d0*2+s0, s1), domain: d0 in [0, 32], d1 in [0, 2], s0 in [0, 2], s1 in [0, 3], is_simplified: false"> +#map2 = #xla_gpu.indexing_map<"(d0, d1) -> (d0, d1), domain: d0 in [0, 32], d1 in [0, 2], is_simplified: false"> func.func @insert_complex( %input: !xla_gpu.indexed_vector<32x3x4xcomplex, #map1>, %output: tensor<32x64xcomplex>, @@ -274,4 +253,4 @@ func.func @insert_complex( // CHECK: %[[IMAG:.*]] = vector.extract %[[VECTOR]][%[[C1]], %[[I]], %[[J]]] // CHECK: %[[COMPLEX:.*]] = complex.create %[[REAL]], %[[IMAG]] // CHECK: %[[INSERTED:.*]] = tensor.insert %[[COMPLEX]] into %[[ITER]] -// CHECK: xla_gpu.yield %[[INSERTED]] : tensor<32x64xcomplex> \ No newline at end of file +// CHECK: xla_gpu.yield %[[INSERTED]] : tensor<32x64xcomplex> diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/tests/optimize_loops.mlir b/third_party/xla/xla/service/gpu/fusions/transforms/tests/optimize_loops.mlir index dd7d639e3273e6..17f478b2838dde 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/tests/optimize_loops.mlir +++ b/third_party/xla/xla/service/gpu/fusions/transforms/tests/optimize_loops.mlir @@ -1,11 +1,7 @@ // RUN: mlir_fusions_opt %s -split-input-file -xla-gpu-optimize-loops | FileCheck %s -#map = #xla_gpu.indexing_map<(d0) -> (d0 floordiv 8), - domain: d0 in [0, 31], is_simplified: false> -#map1 = #xla_gpu.indexing_map<(d0) -> (d0 mod 8), - domain: d0 in [0, 31], is_simplified: false> -#map2 = #xla_gpu.indexing_map<(d0, d1)[s0] -> (d1 * 2 + d0 + s0 * 512), - domain: d0 in [0, 1], d1 in [0, 255], s0 in [0, 7], is_simplified: false> +#map = #xla_gpu.indexing_map<"(d0) -> (d0 floordiv 8), domain: d0 in [0, 31], is_simplified: false"> #map1 = #xla_gpu.indexing_map<"(d0) -> (d0 mod 8), domain: d0 in [0, 31], is_simplified: false"> +#map2 = #xla_gpu.indexing_map<"(d0, d1)[s0] -> (d1 * 2 + d0 + s0 * 512), domain: d0 in [0, 1], d1 in [0, 255], s0 in [0, 7], is_simplified: false"> module { func.func @fully_unroll(%arg0: tensor<4x8x4096xf32>, %arg1: tensor<4096xbf16>, %arg2: tensor<4x8xf32>, %arg3: tensor<4096xbf16>, @@ -127,7 +123,7 @@ module { } } -// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0) -> (d0 + 1), +// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<"(d0) -> (d0 + 1), // CHECK-LABEL: @pipeline_extract // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index // CHECK-DAG: %[[C30:.*]] = arith.constant 30 : index @@ -154,7 +150,7 @@ module { %cst = arith.constant dense<[0.0, 0.0]> : vector<2xf32> %cst0 = arith.constant 0.0 : f32 %ret = scf.for %i = %c0 to %c17 step %c1 iter_args (%iter = %cst) -> (vector<2xf32>) { - %base = xla_gpu.apply_indexing #xla_gpu.indexing_map<(d0) -> (d0 * 2), domain: d0 in [0, 15], is_simplified: false>(%i) + %base = xla_gpu.apply_indexing #xla_gpu.indexing_map<"(d0) -> (d0 * 2), domain: d0 in [0, 15], is_simplified: false">(%i) %val = vector.transfer_read %arg[%base], %cst0 : tensor<34xf32>, vector<2xf32> %log = math.log %val : vector<2xf32> %add = arith.addf %log, %iter : vector<2xf32> @@ -164,8 +160,8 @@ module { } } -// CHECK-DAG: #[[$MAP0:.*]] = #xla_gpu.indexing_map<(d0) -> (d0 * 2), -// CHECK-DAG: #[[$MAP1:.*]] = #xla_gpu.indexing_map<(d0) -> (d0 + 1), +// CHECK-DAG: #[[$MAP0:.*]] = #xla_gpu.indexing_map<"(d0) -> (d0 * 2), +// CHECK-DAG: #[[$MAP1:.*]] = #xla_gpu.indexing_map<"(d0) -> (d0 + 1), // CHECK-LABEL: @pipeline_transfer // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index // CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/tests/peel_loops.mlir b/third_party/xla/xla/service/gpu/fusions/transforms/tests/peel_loops.mlir index 8959fbb826bdda..f965b069a772cc 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/tests/peel_loops.mlir +++ b/third_party/xla/xla/service/gpu/fusions/transforms/tests/peel_loops.mlir @@ -1,16 +1,9 @@ // RUN: mlir_fusions_opt -split-input-file %s -xla-gpu-peel-loops \ // RUN: | FileCheck %s -#map = #xla_gpu.indexing_map< - (d0)[s0, s1] -> (s0, s1), - domain: - d0 in [0, 3], - s0 in [0, 7], - s1 in [0, 10], - d0 + s0 in [0, 9], - d0 + s1 in [0, 12], - is_simplified: false -> +#map = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (s0, s1), domain:" + "d0 in [0, 3], s0 in [0, 7], s1 in [0, 10], d0 + s0 in [0, 9]," + "d0 + s1 in [0, 12], is_simplified: false"> func.func @peel_both_loops(%input: tensor<16x32xf32>, %init: f32, %dim: index) -> (f32) { %sum = xla_gpu.loop (%dim)[%i, %j] -> (%r0, %r1) @@ -21,9 +14,9 @@ func.func @peel_both_loops(%input: tensor<16x32xf32>, } func.return %sum : f32 } -// CHECK: #[[$PEELED_MAP:.*]] = #xla_gpu.indexing_map<(d0)[s0, s1] -> (s0, s1), domain: d0 in [0, 3], s0 in [0, 6], s1 in [0, 9], is_simplified: true> -// CHECK: #[[$TAIL_MAP0:.*]] = #xla_gpu.indexing_map<(d0)[s0, s1] -> (7, s1), domain: d0 in [0, 2], s0 in [7, 7], s1 in [0, 9], is_simplified: true> -// CHECK: #[[$TAIL_MAP1:.*]] = #xla_gpu.indexing_map<(d0)[s0, s1] -> (s0, 10), domain: d0 in [0, 2], s0 in [0, 7], s1 in [10, 10], is_simplified: true> +// CHECK: #[[$PEELED_MAP:.*]] = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (s0, s1), domain: d0 in [0, 3], s0 in [0, 6], s1 in [0, 9], is_simplified: true"> +// CHECK: #[[$TAIL_MAP0:.*]] = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (7, s1), domain: d0 in [0, 2], s0 in [7, 7], s1 in [0, 9], is_simplified: true"> +// CHECK: #[[$TAIL_MAP1:.*]] = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (s0, 10), domain: d0 in [0, 2], s0 in [0, 7], s1 in [10, 10], is_simplified: true"> // CHECK-LABEL: func.func @peel_both_loops( // CHECK-SAME: %[[INPUT:.*]]: tensor<16x32xf32>, @@ -48,13 +41,8 @@ func.func @peel_both_loops(%input: tensor<16x32xf32>, // ----- -#map = #xla_gpu.indexing_map< - (d0)[s0] -> (s0), - domain: - d0 in [0, 3], - s0 in [0, 7], - is_simplified: false -> +#map = #xla_gpu.indexing_map<"(d0)[s0] -> (s0)," + "domain: d0 in [0, 3], s0 in [0, 7], is_simplified: false"> func.func @not_constrained_symbol(%input: tensor<16xf32>, %init: f32, %dim: index) -> (f32) { %sum = xla_gpu.loop (%dim)[%i] -> (%r0) @@ -72,12 +60,12 @@ func.func @not_constrained_symbol(%input: tensor<16xf32>, %init: f32, // ----- #map = #xla_gpu.indexing_map< - (d0)[s0] -> (s0), - domain: - d0 in [0, 3], - s0 in [0, 7], - s0 mod 5 in [0, 1], - is_simplified: false +" (d0)[s0] -> (s0)," +" domain:" +" d0 in [0, 3]," +" s0 in [0, 7]," +" s0 mod 5 in [0, 1]," +" is_simplified: false" > func.func @constraint_exists_after_peeling(%input: tensor<16xf32>, %init: f32, %dim: index) -> (f32) { @@ -91,4 +79,4 @@ func.func @constraint_exists_after_peeling(%input: tensor<16xf32>, %init: f32, } // CHECK-LABEL: func.func @constraint_exists_after_peeling // CHECK: xla_gpu.loop -// CHECK-NOT: xla_gpu.loop \ No newline at end of file +// CHECK-NOT: xla_gpu.loop diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/tests/rewrite_reductions.mlir b/third_party/xla/xla/service/gpu/fusions/transforms/tests/rewrite_reductions.mlir index 94c6cddd4a8a40..5f8b9ba5413d84 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/tests/rewrite_reductions.mlir +++ b/third_party/xla/xla/service/gpu/fusions/transforms/tests/rewrite_reductions.mlir @@ -19,7 +19,7 @@ func.func @row_reduction(%arg0: tensor<128x1027xf32>) return %0 : tensor<128xf32> } -// CHECK: #[[$PAD_AND_RESHAPE:.*]] = #xla_gpu.indexing_map<(d0, d1, d2, d3) -> (d0, d1 * 128 + d2 * 32 + d3), +// CHECK: #[[$PAD_AND_RESHAPE:.*]] = #xla_gpu.indexing_map<"(d0, d1, d2, d3) -> (d0, d1 * 128 + d2 * 32 + d3), // CHECK-SAME: domain: d0 in [0, 127], d1 in [0, 8], d2 in [0, 3], d3 in [0, 31], d1 * 128 + d2 * 32 + d3 in [0, 1026] // CHECK-LABEL: @row_reduction // CHECK-SAME: %[[IN:.*]]: tensor<128x1027xf32> @@ -77,9 +77,9 @@ func.func @column(%arg0: tensor<2x32x32xf32>) return %0 : tensor<2x32xf32> } -// CHECK: #[[$RESHAPE:.*]] = #xla_gpu.indexing_map<(d0, d1, d2, d3) -> (d0, d1 * 4 + d2, d3) +// CHECK: #[[$RESHAPE:.*]] = #xla_gpu.indexing_map<"(d0, d1, d2, d3) -> (d0, d1 * 4 + d2, d3) // CHECK-SAME: d1 * 4 + d2 in [0, 31] -// CHECK: #[[$TRANSPOSE:.*]] = #xla_gpu.indexing_map<(d0, d1, d2) -> (d0, d2, d1) +// CHECK: #[[$TRANSPOSE:.*]] = #xla_gpu.indexing_map<"(d0, d1, d2) -> (d0, d2, d1) // CHECK-LABEL: @column // CHECK-SAME: %[[IN:.*]]: tensor<2x32x32xf32> // CHECK: %[[C0:.*]] = arith.constant 0.00 diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/tests/simplify_affine.mlir b/third_party/xla/xla/service/gpu/fusions/transforms/tests/simplify_affine.mlir index db78b88abd51e0..bfddbd60e2bde7 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/tests/simplify_affine.mlir +++ b/third_party/xla/xla/service/gpu/fusions/transforms/tests/simplify_affine.mlir @@ -63,8 +63,9 @@ func.func @op_and_for_ranges(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: !llvm.pt %1 = gpu.block_id x scf.for %i = %c0 to %c4 step %c1 { %2 = xla_gpu.apply_indexing - #xla_gpu.indexing_map<()[s0, s1, s2] -> (s0 * 512 + s1 * 4 + s2 + (s1 floordiv 128) + (s2 floordiv 4)), - domain: s0 in [0, 3071], s1 in [0, 127], s2 in [0, 3], is_simplified: false>[%1, %0, %i] + #xla_gpu.indexing_map< + "()[s0, s1, s2] -> (s0 * 512 + s1 * 4 + s2 + (s1 floordiv 128) + (s2 floordiv 4))," + "domain: s0 in [0, 3071], s1 in [0, 127], s2 in [0, 3], is_simplified: false">[%1, %0, %i] %3 = arith.index_castui %2 : index to i64 %4 = llvm.getelementptr %arg0[%3] : (!llvm.ptr, i64) -> !llvm.ptr, f32 %5 = llvm.load %4 invariant : !llvm.ptr -> f32 @@ -92,8 +93,9 @@ func.func @op_and_for_ranges(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: !llvm.pt func.func @arg_ranges(%arg0: index, %arg1: index) -> index { %0 = xla_gpu.apply_indexing - #xla_gpu.indexing_map<()[s0, s1] -> (s0 floordiv 100 + s1 floordiv 100), - domain: s0 in [0, 42], s1 in [0, 1000], is_simplified: false>[%arg0, %arg1] + #xla_gpu.indexing_map< + "()[s0, s1] -> (s0 floordiv 100 + s1 floordiv 100)," + "domain: s0 in [0, 42], s1 in [0, 1000], is_simplified: false">[%arg0, %arg1] return %0 : index } @@ -106,8 +108,8 @@ func.func @arg_ranges(%arg0: index, %arg1: index) -> index { func.func @cant_lower(%arg0: index, %arg1: index) -> (index, index) { %0:2 = xla_gpu.apply_indexing - #xla_gpu.indexing_map<()[s0, s1] -> (s0 floordiv 100 + s1 floordiv 100, s0 + s1), - domain: s0 in [-10, 42], s1 in [0, 1000], is_simplified: false>[%arg0, %arg1] + #xla_gpu.indexing_map<"()[s0, s1] -> (s0 floordiv 100 + s1 floordiv 100, s0 + s1)," + "domain: s0 in [-10, 42], s1 in [0, 1000], is_simplified: false">[%arg0, %arg1] return %0#0, %0#1 : index, index } @@ -124,8 +126,9 @@ func.func @order_summands(%arg1: index) { scf.for %arg2 = %c0 to %c4 step %c1 { scf.for %arg3 = %c0 to %c4 step %c1 { %0 = xla_gpu.apply_indexing - #xla_gpu.indexing_map<()[s0, s1, s2] -> ((s0 + s1) floordiv 3 + s0 * 512 + s1 * 4 + s2 * 10), - domain: s0 in [0, 3], s1 in [0, 3], s2 in [0, 3], is_simplified: false>[%arg2, %arg1, %arg3] + #xla_gpu.indexing_map< + "()[s0, s1, s2] -> ((s0 + s1) floordiv 3 + s0 * 512 + s1 * 4 + s2 * 10)," + "domain: s0 in [0, 3], s1 in [0, 3], s2 in [0, 3], is_simplified: false">[%arg2, %arg1, %arg3] "dummy.op"(%0) : (index) -> () } } diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/tests/simplify_arith.mlir b/third_party/xla/xla/service/gpu/fusions/transforms/tests/simplify_arith.mlir index aaeb665815dcc5..9524c3d32cc6c2 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/tests/simplify_arith.mlir +++ b/third_party/xla/xla/service/gpu/fusions/transforms/tests/simplify_arith.mlir @@ -248,7 +248,8 @@ func.func @refine_constraints(%tensor: tensor<100xf32>) -> tensor<100xf32> { %c42_f32 = arith.constant 42.0 : f32 %loop = scf.for %i = %c0 to %c3 step %c1 iter_args(%in_ = %tensor) -> (tensor<100xf32>) { - %0 = xla_gpu.apply_indexing #xla_gpu.indexing_map<(d0) -> (d0 mod 4), domain: d0 in [0, 9], is_simplified: false>(%i) + %0 = xla_gpu.apply_indexing #xla_gpu.indexing_map<"(d0) -> (d0 mod 4)," + "domain: d0 in [0, 9], is_simplified: false">(%i) %updated = tensor.insert %c42_f32 into %in_[%0] : tensor<100xf32> scf.yield %updated :tensor<100xf32> } @@ -262,10 +263,11 @@ func.func @refine_constraints(%tensor: tensor<100xf32>) -> tensor<100xf32> { // ----- -#map = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> (((d0 * 4 + d1 * 512 + s1) floordiv 9 + s0 * 32768) mod 2400000), - domain: d0 in [0, 127], d1 in [0, 575], s0 in [0, 73], s1 in [0, 3], is_simplified: false> -#map1 = #xla_gpu.indexing_map<(d0, d1)[s0] -> ((d0 * 4 + d1 * 512 + s0) mod 9), - domain: d0 in [0, 127], d1 in [0, 575], s0 in [0, 3], is_simplified: false> +#map = #xla_gpu.indexing_map< + "(d0, d1)[s0, s1] -> (((d0 * 4 + d1 * 512 + s1) floordiv 9 + s0 * 32768) mod 2400000)," + "domain: d0 in [0, 127], d1 in [0, 575], s0 in [0, 73], s1 in [0, 3], is_simplified: false"> +#map1 = #xla_gpu.indexing_map<"(d0, d1)[s0] -> ((d0 * 4 + d1 * 512 + s0) mod 9)," + "domain: d0 in [0, 127], d1 in [0, 575], s0 in [0, 3], is_simplified: false"> func.func @refine_constraints_for_symbol(%arg0: tensor<2400000x9xf32>, %arg1: tensor<2400000x9xf32>) -> tensor<2400000x9xf32> { %c0 = arith.constant 0 : index @@ -289,12 +291,23 @@ func.func @refine_constraints_for_symbol(%arg0: tensor<2400000x9xf32>, } return %0 : tensor<2400000x9xf32> } -// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0, d1, d2, d3) -> (d2 * 32768 + (d0 * 4 + d1 * 512 + d3) floordiv 9), +// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<"(d0, d1, d2, d3) -> (d2 * 32768 + (d0 * 4 + d1 * 512 + d3) floordiv 9), // CHECK-LABEL: func.func @refine_constraints_for_symbol // ----- -#map = #xla_gpu.indexing_map<(d0, d1, d2, d3, d4, d5)[s0] -> ((d0 * 4 + s0) floordiv 6, (d0 * 4 + s0) mod 6), domain: d0 in [0, 29], d1 in [0, 0], d2 in [0, 0], d3 in [0, 0], d4 in [0, 0], d5 in [0, 0], s0 in [0, 3], d0 * 4 + s0 in [0, 29], is_simplified: false> +#map = #xla_gpu.indexing_map< + "(d0, d1, d2, d3, d4, d5)[s0] -> ((d0 * 4 + s0) floordiv 6, (d0 * 4 + s0) mod 6)," + "domain:" + "d0 in [0, 29]," + "d1 in [0, 0]," + "d2 in [0, 0]," + "d3 in [0, 0]," + "d4 in [0, 0]," + "d5 in [0, 0]," + "s0 in [0, 3]," + "d0 * 4 + s0 in [0, 29]," + "is_simplified: false"> func.func @dus(%arg0: tensor<20x30xf32>, %arg1: tensor<5x6xf32>, %arg2: i32, %arg3: i32, %arg4: tensor<20x30xf32>) -> tensor<20x30xf32> { %c24 = arith.constant 24 : index %c15 = arith.constant 15 : index diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/tests/vectorize_loads_stores.mlir b/third_party/xla/xla/service/gpu/fusions/transforms/tests/vectorize_loads_stores.mlir index c77d035e6271b3..0c734ca19882e5 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/tests/vectorize_loads_stores.mlir +++ b/third_party/xla/xla/service/gpu/fusions/transforms/tests/vectorize_loads_stores.mlir @@ -1,8 +1,8 @@ // RUN: mlir_fusions_opt -allow-unregistered-dialect %s -split-input-file \ // RUN: -xla-gpu-vectorize-loads-stores -cse -canonicalize | FileCheck %s -#map = #xla_gpu.indexing_map<(d0)[s0] -> (d0 * 2 + s0), - domain: d0 in [0, 63], s0 in [0, 1], is_simplified: true> +#map = #xla_gpu.indexing_map<"(d0)[s0] -> (d0 * 2 + s0)," + "domain: d0 in [0, 63], s0 in [0, 1], is_simplified: true"> func.func @simple_read(%arg0: tensor<128xf32>) -> (f32) { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index @@ -20,7 +20,7 @@ func.func @simple_read(%arg0: tensor<128xf32>) -> (f32) { } return %outer : f32 } -// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0) -> (d0 * 2), domain: d0 in [0, 63], is_simplified: true> +// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<"(d0) -> (d0 * 2), domain: d0 in [0, 63], is_simplified: true"> // CHECK-LABEL: @simple_read // CHECK-SAME: (%[[ARG0:.*]]: tensor // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index @@ -36,8 +36,8 @@ func.func @simple_read(%arg0: tensor<128xf32>) -> (f32) { // ----- -#map = #xla_gpu.indexing_map<(d0)[s0] -> (d0 * 2 + s0 + 1), - domain: d0 in [0, 63], s0 in [0, 1], is_simplified: true> +#map = #xla_gpu.indexing_map<"(d0)[s0] -> (d0 * 2 + s0 + 1)," + "domain: d0 in [0, 63], s0 in [0, 1], is_simplified: true"> func.func @misaligned_indexing_map(%arg0: tensor<128xf32>) -> (f32) { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index @@ -60,8 +60,8 @@ func.func @misaligned_indexing_map(%arg0: tensor<128xf32>) -> (f32) { // ----- -#map = #xla_gpu.indexing_map<(d0)[s0] -> (d0 * 3 + s0), - domain: d0 in [0, 63], s0 in [0, 1], is_simplified: true> +#map = #xla_gpu.indexing_map<"(d0)[s0] -> (d0 * 3 + s0)," + "domain: d0 in [0, 63], s0 in [0, 1], is_simplified: true"> func.func @misaligned_indexing_map_2(%arg0: tensor<128xf32>) -> (f32) { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index @@ -84,8 +84,8 @@ func.func @misaligned_indexing_map_2(%arg0: tensor<128xf32>) -> (f32) { // ----- -#map = #xla_gpu.indexing_map<(d0)[s0] -> (3 * d0 + s0), - domain: d0 in [0, 63], s0 in [0, 1], is_simplified: true> +#map = #xla_gpu.indexing_map<"(d0)[s0] -> (3 * d0 + s0)," + "domain: d0 in [0, 63], s0 in [0, 1], is_simplified: true"> func.func @misaligned_shape(%arg0: tensor<192xf32>) -> (f32) { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index @@ -108,8 +108,8 @@ func.func @misaligned_shape(%arg0: tensor<192xf32>) -> (f32) { // ----- -#map = #xla_gpu.indexing_map<(d0)[s0] -> (d0 + s0 * 2), - domain: d0 in [0, 63], s0 in [0, 1], is_simplified: true> +#map = #xla_gpu.indexing_map<"(d0)[s0] -> (d0 + s0 * 2)," + "domain: d0 in [0, 63], s0 in [0, 1], is_simplified: true"> func.func @wrong_stride(%arg0: tensor<128xf32>) -> (f32) { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index @@ -134,8 +134,8 @@ func.func @wrong_stride(%arg0: tensor<128xf32>) -> (f32) { // We could vectorize this as a float vector load of double the size, but we // don't currently. -#map = #xla_gpu.indexing_map<(d0)[s0] -> (2 * d0 + s0), - domain: d0 in [0, 127], s0 in [0, 1], is_simplified: true> +#map = #xla_gpu.indexing_map<"(d0)[s0] -> (2 * d0 + s0)," + "domain: d0 in [0, 127], s0 in [0, 1], is_simplified: true"> func.func @simple_read_complex(%arg0: tensor<128xcomplex>, %i: index) -> (complex) { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index @@ -250,10 +250,12 @@ func.func @write_not_yielded(%arg0: tensor<64xf32>) -> tensor<64xf32> { // ----- -#map = #xla_gpu.indexing_map<(d0, d1)[s0] -> (d1 * 2 + d0 + s0 * 512), - domain: d0 in [0, 7], d1 in [0, 255], s0 in [0, 7], is_simplified: true> -#map1 = #xla_gpu.indexing_map<(d0, d1, d2)[s0] -> (d0 * 32 + d2 * 2 + d1 + s0 * 512), - domain: d0 in [0, 7], d1 in [0, 1], d2 in [0, 255], s0 in [0, 7], is_simplified: true> +#map = #xla_gpu.indexing_map<"(d0, d1)[s0] -> (d1 * 2 + d0 + s0 * 512)," + "domain: d0 in [0, 7], d1 in [0, 255], s0 in [0, 7], is_simplified: true"> +#map1 = #xla_gpu.indexing_map< + "(d0, d1, d2)[s0] -> (d0 * 32 + d2 * 2 + d1 + s0 * 512)," + "domain: d0 in [0, 7], d1 in [0, 1], d2 in [0, 255], s0 in [0, 7]," + "is_simplified: true"> func.func @multiple(%arg0: tensor<131072xf32>, %arg1: tensor<4096xbf16>, %arg2: tensor<32xf32>, %arg3: tensor<131072xf32>, %arg4: index) -> (tensor<131072xf32>, f32) { @@ -280,8 +282,8 @@ func.func @multiple(%arg0: tensor<131072xf32>, %arg1: tensor<4096xbf16>, } return %0#0, %0#1 : tensor<131072xf32>, f32 } -// CHECK-DAG: #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0, d1) -> (d0 * 2 + d1 * 512), domain: d0 in [0, 255], d1 in [0, 7], is_simplified: true> -// CHECK-DAG: #[[$MAP1:.*]] = #xla_gpu.indexing_map<(d0, d1, d2) -> (d0 * 32 + d1 * 2 + d2 * 512), domain: d0 in [0, 7], d1 in [0, 255], d2 in [0, 7], is_simplified: true> +// CHECK-DAG: #[[$MAP:.*]] = #xla_gpu.indexing_map<"(d0, d1) -> (d0 * 2 + d1 * 512), domain: d0 in [0, 255], d1 in [0, 7], is_simplified: true"> +// CHECK-DAG: #[[$MAP1:.*]] = #xla_gpu.indexing_map<"(d0, d1, d2) -> (d0 * 32 + d1 * 2 + d2 * 512), domain: d0 in [0, 7], d1 in [0, 255], d2 in [0, 7], is_simplified: true"> // CHECK-LABEL: @multiple // CHECK-SAME: (%[[ARG0:.*]]: tensor{{.*}}, %[[ARG1:.*]]: tensor{{.*}}, %[[ARG2:.*]]: tensor{{.*}}, %[[ARG3:.*]]: tensor{{.*}}, %[[ARG4:.*]]: index) // CHECK: %[[C0:.*]] = arith.constant 0 : index @@ -304,8 +306,8 @@ func.func @multiple(%arg0: tensor<131072xf32>, %arg1: tensor<4096xbf16>, // ----- -#map = #xla_gpu.indexing_map<(d0)[s0] -> ((d0 * 4) mod 64 + s0), - domain: d0 in [0, 63], s0 in [0, 1], is_simplified: true> +#map = #xla_gpu.indexing_map<"(d0)[s0] -> ((d0 * 4) mod 64 + s0)," + "domain: d0 in [0, 63], s0 in [0, 1], is_simplified: true"> func.func @remainder_with_modulo(%arg0: tensor<128xf32>) -> (f32) { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index @@ -323,7 +325,7 @@ func.func @remainder_with_modulo(%arg0: tensor<128xf32>) -> (f32) { } return %outer : f32 } -// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0) -> ((d0 mod 16) * 4), +// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<"(d0) -> ((d0 mod 16) * 4), // CHECK-LABEL: @remainder_with_modulo // CHECK: %[[C0:.*]] = arith.constant 0 : index // CHECK: scf.for %[[I:.*]] = %[[C0]] @@ -332,8 +334,8 @@ func.func @remainder_with_modulo(%arg0: tensor<128xf32>) -> (f32) { // ----- -#map = #xla_gpu.indexing_map<(d0)[s0] -> ((d0 * 4) mod 65 + s0), - domain: d0 in [0, 63], s0 in [0, 1], is_simplified: true> +#map = #xla_gpu.indexing_map<"(d0)[s0] -> ((d0 * 4) mod 65 + s0)," + "domain: d0 in [0, 63], s0 in [0, 1], is_simplified: true"> func.func @remainder_with_modulo_misaligned(%arg0: tensor<128xf32>) -> (f32) { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index @@ -356,10 +358,10 @@ func.func @remainder_with_modulo_misaligned(%arg0: tensor<128xf32>) -> (f32) { // ----- -#map0 = #xla_gpu.indexing_map<(d0) -> (d0 + 5), - domain: d0 in [0, 63], is_simplified: true> -#map1 = #xla_gpu.indexing_map<(d0)[s0] -> (d0 * 2 + s0), - domain: d0 in [0, 63], s0 in [0, 1], is_simplified: true> +#map0 = #xla_gpu.indexing_map<"(d0) -> (d0 + 5)," + "domain: d0 in [0, 63], is_simplified: true"> +#map1 = #xla_gpu.indexing_map<"(d0)[s0] -> (d0 * 2 + s0)," + "domain: d0 in [0, 63], s0 in [0, 1], is_simplified: true"> module { func.func @apply_indexing_sequence(%arg0: tensor<128xf32>) -> (f32) { %c0 = arith.constant 0 : index @@ -381,8 +383,8 @@ module { } } -// CHECK: #[[$MAP0:.*]] = #xla_gpu.indexing_map<(d0) -> (d0 * 2 + 10), -// CHECK-SAME: domain: d0 in [0, 63], is_simplified: true> +// CHECK: #[[$MAP0:.*]] = #xla_gpu.indexing_map<"(d0) -> (d0 * 2 + 10), +// CHECK-SAME: domain: d0 in [0, 63], is_simplified: true"> // CHECK-LABEL: @apply_indexing_sequence // CHECK: %[[BASE:.*]] = xla_gpu.apply_indexing #[[$MAP0]] // CHECK: vector.transfer_read {{.*}}[%[[BASE]]] @@ -390,10 +392,10 @@ module { // ----- -#map0 = #xla_gpu.indexing_map<(d0) -> (d0 + 5), - domain: d0 in [0, 63], is_simplified: true> -#map1 = #xla_gpu.indexing_map<(d0)[s0] -> (d0 * 2 + s0), - domain: d0 in [0, 63], s0 in [0, 1], is_simplified: true> +#map0 = #xla_gpu.indexing_map<"(d0) -> (d0 + 5)," + "domain: d0 in [0, 63], is_simplified: true"> +#map1 = #xla_gpu.indexing_map<"(d0)[s0] -> (d0 * 2 + s0)," + "domain: d0 in [0, 63], s0 in [0, 1], is_simplified: true"> module { func.func @apply_indexing_sequence_same_block(%arg0: tensor<128xf32>) -> (f32) { %c0 = arith.constant 0 : index @@ -418,4 +420,4 @@ module { } // CHECK-LABEL: @apply_indexing_sequence_same_block -// CHECK-NOT: vector.transfer_read \ No newline at end of file +// CHECK-NOT: vector.transfer_read diff --git a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_test.cc b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_test.cc index a2de97c39bfe04..f136f7190d1a64 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_test.cc @@ -213,7 +213,7 @@ ENTRY main { "num_warps":"1"}}}})"; TF_EXPECT_OK(CreateTritonIrAndFileCheck(this, kHloText, "triton_softmax_computation", R"( -CHECK: #indexing_map = #xla_gpu.indexing_map<(d0) -> (d0 * 127), domain: d0 in [0, 124], is_simplified: true> +CHECK: #indexing_map = #xla_gpu.indexing_map<"(d0) -> (d0 * 127), domain: d0 in [0, 124], is_simplified: true"> CHECK: tt.func @triton_fn(%[[P0:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[P1:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}) { CHECK-DAG: %[[ZERO:.*]] = arith.constant 0 : i32 CHECK-DAG: %[[C125:.*]] = arith.constant 125 : i64 @@ -278,7 +278,7 @@ ENTRY main { "num_warps":"1"}}}})"; TF_EXPECT_OK(CreateTritonIrAndFileCheck(this, kHloText, "triton_softmax_computation", R"( -CHECK: #indexing_map = #xla_gpu.indexing_map<(d0) -> (d0 * 127), domain: d0 in [0, 124], is_simplified: true> +CHECK: #indexing_map = #xla_gpu.indexing_map<"(d0) -> (d0 * 127), domain: d0 in [0, 124], is_simplified: true"> CHECK: tt.func @triton_fn( CHECK-SAME: %[[P0:[A-Za-z0-9_]*]]: !tt.ptr CHECK-SAME: %[[P1:[A-Za-z0-9_]*]]: !tt.ptr @@ -349,9 +349,9 @@ ENTRY main { TF_EXPECT_OK(CreateTritonIrAndFileCheck(this, kHloText, "triton_softmax_computation", R"( -CHECK: #[[MAP:.*]] = #xla_gpu.indexing_map<(d0) -> (d0 floordiv 125), domain: d0 in [0, 1249], is_simplified: true> -CHECK: #[[MAP1:.*]] = #xla_gpu.indexing_map<(d0) -> (d0 mod 125), domain: d0 in [0, 1249], is_simplified: true> -CHECK: #[[MAP2:.*]] = #xla_gpu.indexing_map<(d0) -> (d0 * 127), domain: d0 in [0, 1249], is_simplified: true> +CHECK: #[[MAP:.*]] = #xla_gpu.indexing_map<"(d0) -> (d0 floordiv 125), domain: d0 in [0, 1249], is_simplified: true"> +CHECK: #[[MAP1:.*]] = #xla_gpu.indexing_map<"(d0) -> (d0 mod 125), domain: d0 in [0, 1249], is_simplified: true"> +CHECK: #[[MAP2:.*]] = #xla_gpu.indexing_map<"(d0) -> (d0 * 127), domain: d0 in [0, 1249], is_simplified: true"> CHECK: tt.func @triton_fn(%[[P0:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[P1:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[P2:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[P3:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}) { CHECK-DAG: %[[ZERO:.*]] = arith.constant 0 : i32 CHECK-DAG: %[[ZERO_64:.*]] = arith.constant 0 : i64 @@ -542,8 +542,8 @@ ENTRY main { TF_ASSERT_OK(CreateTritonIrAndFileCheck(this, kHloText, "triton_softmax_computation", R"( -// CHECK: #xla_gpu.indexing_map<(d0) -> (d0 floordiv 32), domain: d0 in [0, 2047], is_simplified: true> -// CHECK: #xla_gpu.indexing_map<(d0) -> (d0 mod 32), domain: d0 in [0, 2047], is_simplified: true> +// CHECK: #xla_gpu.indexing_map<"(d0) -> (d0 floordiv 32), domain: d0 in [0, 2047], is_simplified: true"> +// CHECK: #xla_gpu.indexing_map<"(d0) -> (d0 mod 32), domain: d0 in [0, 2047], is_simplified: true"> // CHECK-LABEL: tt.func @triton_fn( // CHECK-SAME: %[[P0:[A-Za-z0-9_]*]]: !tt.ptr // CHECK-SAME: %[[P1:[A-Za-z0-9_]*]]: !tt.ptr From 8068e0521336f76f72962de7047e2fae259908b6 Mon Sep 17 00:00:00 2001 From: Ionel Gog Date: Wed, 25 Sep 2024 04:11:08 -0700 Subject: [PATCH 249/483] [IFRT] Add pass for converting ifrt.Reshard of non-resharding arrays to ifrt.CopyArrays. PiperOrigin-RevId: 678637771 --- .../xla/xla/python/ifrt/ir/tests/BUILD | 1 + .../ir/tests/ifrt_reshard_to_copy_arrays.mlir | 106 ++++++++++ .../xla/xla/python/ifrt/ir/transforms/BUILD | 3 + .../ifrt_reshard_to_copy_arrays_pass.cc | 187 ++++++++++++++++++ .../xla/python/ifrt/ir/transforms/passes.h | 3 + .../xla/python/ifrt/ir/transforms/passes.td | 39 ++++ .../xla/python/ifrt/ir/transforms/utils.cc | 10 + .../xla/xla/python/ifrt/ir/transforms/utils.h | 4 + 8 files changed, 353 insertions(+) create mode 100644 third_party/xla/xla/python/ifrt/ir/tests/ifrt_reshard_to_copy_arrays.mlir create mode 100644 third_party/xla/xla/python/ifrt/ir/transforms/ifrt_reshard_to_copy_arrays_pass.cc diff --git a/third_party/xla/xla/python/ifrt/ir/tests/BUILD b/third_party/xla/xla/python/ifrt/ir/tests/BUILD index ab7e1250422b5a..77ff4fa86c73fe 100644 --- a/third_party/xla/xla/python/ifrt/ir/tests/BUILD +++ b/third_party/xla/xla/python/ifrt/ir/tests/BUILD @@ -13,6 +13,7 @@ lit_test_suite( "ifrt_duplicated_callee_elimination.mlir", "ifrt_merge_reshards.mlir", "ifrt_outline_atom_program_to_module.mlir", + "ifrt_reshard_to_copy_arrays.mlir", "ifrt_verify_donation.mlir", "ifrt_verify_sharding_specified.mlir", "spmd_expansion.mlir", diff --git a/third_party/xla/xla/python/ifrt/ir/tests/ifrt_reshard_to_copy_arrays.mlir b/third_party/xla/xla/python/ifrt/ir/tests/ifrt_reshard_to_copy_arrays.mlir new file mode 100644 index 00000000000000..cf8231b67050a4 --- /dev/null +++ b/third_party/xla/xla/python/ifrt/ir/tests/ifrt_reshard_to_copy_arrays.mlir @@ -0,0 +1,106 @@ +// RUN: ifrt-opt %s -ifrt-reshard-to-copy-arrays -verify-diagnostics -split-input-file | FileCheck %s + +!array0 = !ifrt.array, #ifrt.sharding_param<2x1 to [0] on 2>, + [0,1]> +!array1 = !ifrt.array, #ifrt.sharding_param<2x1 to [0] on 2>, + [2,3]> +// CHECK-LABEL: @reshard_to_copy_arrays +module @reshard_to_copy_arrays { + func.func @main(%arg0: !array0) -> !array1 attributes {ifrt.function} { + // CHECK: %[[COPIED:.+]], %{{.+}} = ifrt.CopyArrays(%arg0) + %0, %ctrl_0 = ifrt.Reshard(%arg0) : (!array0) -> !array1 + // CHECK: return %[[COPIED]] + return %0 : !array1 + } +} + +// ----- + +!array0 = !ifrt.array, #ifrt.sharding_param<2x1 to [0] on 2>, + [0,1]> +!array1 = !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [2,3]> +// CHECK-LABEL: @reshard_not_converted +module @reshard_not_converted { + func.func @main(%arg0: !array0) -> !array1 attributes {ifrt.function} { + // CHECK: %[[RESHARDED:.+]], %{{.+}} = ifrt.Reshard(%arg0) + %0, %ctrl_0 = ifrt.Reshard(%arg0) : (!array0) -> !array1 + // CHECK: return %[[RESHARDED]] + return %0 : !array1 + } +} + +// ----- + +!array0 = !ifrt.array, #ifrt.sharding_param<2x1 to [0] on 2>, + [0,1]> +!array1 = !ifrt.array, #ifrt.sharding_param<2x1 to [0] on 2>, + [2,3]> +!array2 = !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [2,3]> +// CHECK-LABEL: @extract_copy_from_reshard +module @extract_copy_from_reshard { + func.func @main(%arg0: !array0, %arg1: !array1) -> (!array1, !array2) + attributes {ifrt.function} { + // CHECK: %[[RESHARDED:.+]], %{{.+}} = ifrt.Reshard(%arg1) {donated = true} + // CHECK: %[[COPIED:.+]], %{{.+}} = ifrt.CopyArrays(%arg0) {donated = true} + %0, %1, %ctrl_0 = ifrt.Reshard(%arg0, %arg1) {donated = true} + : (!array0, !array1) -> (!array1, !array2) + // CHECK: return %[[COPIED]], %[[RESHARDED]] + return %0, %1: !array1, !array2 + } +} + +// ----- + +// Verifies that an ifrt.CopyArrays is introduced for each set of devices. +!array0 = !ifrt.array, #ifrt.sharding_param<2x1 to [0] on 2>, + [0,1]> +!array1 = !ifrt.array, #ifrt.sharding_param<2x1 to [0] on 2>, + [2,3]> +!array2 = !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]> +// CHECK-LABEL: @extract_copy_per_device_set +module @extract_copy_per_device_set { + func.func @main(%arg0: !array0, %arg1: !array1, %arg2: !array1) + -> (!array1, !array2, !array0) attributes {ifrt.function} { + // CHECK: %[[RESHARDED:.+]], %{{.+}} = ifrt.Reshard(%arg1) + // CHECK-DAG: %[[COPIED_1:.+]], %{{.+}} = ifrt.CopyArrays(%arg0) + // CHECK-DAG: %[[COPIED_2:.+]], %{{.+}} = ifrt.CopyArrays(%arg2) + %0, %1, %2, %ctrl_0 = ifrt.Reshard(%arg0, %arg1, %arg2) + : (!array0, !array1, !array1) -> (!array1, !array2, !array0) + // CHECK: return %[[COPIED_1]], %[[RESHARDED]], %[[COPIED_2]] + return %0, %1, %2: !array1, !array2, !array0 + } +} + +// ----- + +// Verifies that the control inputs are passed to the CopyArrays. +!array0 = !ifrt.array, #ifrt.sharding_param<2x1 to [0] on 2>, + [0,1]> +!array1 = !ifrt.array, #ifrt.sharding_param<2x1 to [0] on 2>, + [2,3]> +!array2 = !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [2,3]> +// CHECK-LABEL: @control_inputs_added_to_copy_arrays +module @control_inputs_added_to_copy_arrays { + func.func @main(%arg0: !array0, %arg1: !array1) -> (!array1, !array2) + attributes {ifrt.function} { + // CHECK: %[[OUT:.+]], %[[CTRL:.+]] = ifrt.Call @add_one(%arg0) + %0, %ctrl_0 = ifrt.Call @add_one(%arg0) on devices [0,1] + : (!array0) -> !array0 + // CHECK: %[[RESHARDED:.+]], %{{.+}} = ifrt.Reshard(%arg1) after %[[CTRL:.+]] + // CHECK: %[[COPIED:.+]], %{{.+}} = ifrt.CopyArrays(%[[OUT:.+]]) after %[[CTRL:.+]] + %1, %2, %ctrl_1 = ifrt.Reshard(%0, %arg1) after %ctrl_0 + : (!array0, !array1) -> (!array1, !array2) + // CHECK: return %[[COPIED]], %[[RESHARDED]] + return %1, %2: !array1, !array2 + } + + func.func private @add_one(%arg0: tensor<2x2xi32>) -> tensor<2x2xi32> { + %0 = mhlo.constant dense<1> : tensor<2x2xi32> + %1 = mhlo.add %arg0, %0 : tensor<2x2xi32> + return %1 : tensor<2x2xi32> + } +} diff --git a/third_party/xla/xla/python/ifrt/ir/transforms/BUILD b/third_party/xla/xla/python/ifrt/ir/transforms/BUILD index 620362de4c1b50..882944a1602632 100644 --- a/third_party/xla/xla/python/ifrt/ir/transforms/BUILD +++ b/third_party/xla/xla/python/ifrt/ir/transforms/BUILD @@ -32,6 +32,7 @@ cc_library( "ifrt_duplicated_callee_elimination_pass.cc", "ifrt_merge_reshards_pass.cc", "ifrt_outline_atom_program_to_module_pass.cc", + "ifrt_reshard_to_copy_arrays_pass.cc", "ifrt_verify_donation_pass.cc", "ifrt_verify_sharding_specified_pass.cc", "spmd_expandable_interface_verification_pass.cc", @@ -52,6 +53,7 @@ cc_library( "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", ], ) @@ -73,6 +75,7 @@ cc_library( hdrs = ["utils.h"], compatible_with = get_compatible_with_portable(), deps = [ + "//xla/python/ifrt/ir", "@com_google_absl//absl/log:check", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", diff --git a/third_party/xla/xla/python/ifrt/ir/transforms/ifrt_reshard_to_copy_arrays_pass.cc b/third_party/xla/xla/python/ifrt/ir/transforms/ifrt_reshard_to_copy_arrays_pass.cc new file mode 100644 index 00000000000000..2ffaf6f8a63f8e --- /dev/null +++ b/third_party/xla/xla/python/ifrt/ir/transforms/ifrt_reshard_to_copy_arrays_pass.cc @@ -0,0 +1,187 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "llvm/ADT/STLExtras.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Value.h" +#include "mlir/IR/Visitors.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "xla/python/ifrt/ir/ifrt_dialect.h" +#include "xla/python/ifrt/ir/ifrt_ops.h" +#include "xla/python/ifrt/ir/transforms/utils.h" + +namespace xla { +namespace ifrt { + +namespace { + +#define GEN_PASS_DEF_IFRTRESHARDTOCOPYARRAYSPASS +#include "xla/python/ifrt/ir/transforms/passes.h.inc" + +class ReshardToCopyArraysOpPattern + : public mlir::OpRewritePattern { + public: + using mlir::OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult matchAndRewrite( + xla::ifrt::ReshardOp op, mlir::PatternRewriter& rewriter) const override { + // Map from devices attribute to indices of the input arrays that are just + // copied to those devices. + llvm::DenseMap> + copy_indices; + // Indices of the input arrays that are resharded. + llvm::SmallVector reshard_indices; + for (const auto& [idx, pair] : + llvm::enumerate(llvm::zip(op.getInputs(), op.getOutputs()))) { + auto in_array_type = + mlir::cast(std::get<0>(pair).getType()); + if (in_array_type == nullptr) { + op.emitOpError() << "requires all inputs to be `IfrtArrayType`. Input #" + << idx << ": " << std::get<0>(pair).getType(); + return mlir::failure(); + } + auto out_array_type = + mlir::cast(std::get<1>(pair).getType()); + if (out_array_type == nullptr) { + op.emitOpError() + << "requires all outputs to be `IfrtArrayType`. Output #" << idx + << ": " << std::get<1>(pair).getType(); + return mlir::failure(); + } + if (!IsReshard(in_array_type, out_array_type)) { + copy_indices[out_array_type.getDevicesAttr()].push_back(idx); + } else { + reshard_indices.push_back(idx); + } + } + + if (reshard_indices.size() == op.getInputs().size()) { + // All arrays are resharded. No need to modify the ifrt.Reshard op. + return mlir::failure(); + } + + if (!op.getControlOutput().getUses().empty()) { + // If the control output dependency of the ifrt.Reshard op is used then it + // is unclear what to do with the newly added ifrt.CopyArrays ops. The + // conservative approach would be to add these as control dependencies to + // all the ops that have a control dependency on the ifrt.Reshard op. + // However, we could also add them just to the ops that have a control + // dependency on the ifrt.Reshard op and use the same devices. For now, + // we will just throw an error as the ifrt.Reshard control dependencies + // are not used at the moment. + op.emitOpError() << " cannot extract `ifrt.CopyArrays` from " + "`ifrt.Reshard` with control dependency output"; + return mlir::failure(); + } + + llvm::SmallVector outputs; + outputs.resize(op.getOutputs().size()); + // If an ifrt.Reshard is still left, then we replace the usage of the + // current ifrt.Reshard op's control output with its control output. + // Otherwise, we replace it with the control output of the last + // ifrt.CopyArrays op. + mlir::Value control_output; + + // Replace the ifrt.Reshard with a pruned version that only takes the arrays + // that are resharded. + llvm::SmallVector reshard_input_values; + llvm::SmallVector reshard_output_types; + for (int idx : reshard_indices) { + outputs[idx] = op.getOutputs()[idx]; + reshard_input_values.push_back(op.getInputs()[idx]); + reshard_output_types.push_back(op.getOutputs()[idx].getType()); + } + if (!reshard_input_values.empty()) { + auto reshard_op = rewriter.create( + op.getLoc(), + /*outputs=*/reshard_output_types, + /*control_output=*/op.getControlOutput().getType(), + /*inputs=*/reshard_input_values, + /*donated=*/op.getDonated(), + /*control_inputs=*/op.getControlInputs()); + for (const auto& [idx, output] : + llvm::zip(reshard_indices, reshard_op.getOutputs())) { + outputs[idx] = output; + } + control_output = reshard_op.getControlOutput(); + } + + // Add an ifrt.CopyArrays op for each set of arrays that are copied to a + // set of devices. The new ifrt.CopyArrays ops will inherit *all* the input + // control dependencies of the ifrt.Reshard op. They could receive a subset + // of the control dependencies (e.g., dependencies generated by ops running + // use the same devices as the ones the arrays are coppied to), but that is + // not supported yet. + for (const auto& [devices_attr, indices] : copy_indices) { + llvm::SmallVector copy_input_values; + llvm::SmallVector copy_output_types; + for (int idx : indices) { + copy_input_values.push_back(op.getInputs()[idx]); + copy_output_types.push_back(op.getOutputs()[idx].getType()); + } + auto copy_arrays_op = rewriter.create( + op.getLoc(), + /*outputs=*/copy_output_types, + /*control_output=*/op.getControlOutput().getType(), + /*inputs=*/copy_input_values, + /*donated=*/op.getDonated(), + /*control_inputs=*/op.getControlInputs()); + for (const auto& [idx, output] : + llvm::zip(indices, copy_arrays_op.getOutputs())) { + outputs[idx] = output; + } + if (reshard_input_values.empty()) { + control_output = copy_arrays_op.getControlOutput(); + } + } + outputs.push_back(control_output); + rewriter.replaceOp(op, outputs); + return mlir::success(); + } +}; + +class IfrtReshardToCopyArraysPass + : public impl::IfrtReshardToCopyArraysPassBase< + IfrtReshardToCopyArraysPass> { + public: + void runOnOperation() override { + mlir::RewritePatternSet patterns(&getContext()); + patterns.add(&getContext()); + mlir::ModuleOp module_op = getOperation(); + if (mlir::failed(mlir::applyPatternsAndFoldGreedily(module_op, + std::move(patterns)))) { + signalPassFailure(); + } + } +}; + +} // namespace + +std::unique_ptr> +CreateIfrtReshardToCopyArraysPass() { + return std::make_unique(); +} + +} // namespace ifrt +} // namespace xla diff --git a/third_party/xla/xla/python/ifrt/ir/transforms/passes.h b/third_party/xla/xla/python/ifrt/ir/transforms/passes.h index da7ec1ab599795..b1f528a2d1e2a6 100644 --- a/third_party/xla/xla/python/ifrt/ir/transforms/passes.h +++ b/third_party/xla/xla/python/ifrt/ir/transforms/passes.h @@ -49,6 +49,9 @@ CreateIfrtVerifyDonationPass(); std::unique_ptr> CreateIfrtVerifyShardingSpecifiedPass(); +std::unique_ptr> +CreateIfrtReshardToCopyArraysPass(); + // Generated definitions. This should be placed after all Pass creations. #define GEN_PASS_REGISTRATION #include "xla/python/ifrt/ir/transforms/passes.h.inc" // IWYU pragma: export diff --git a/third_party/xla/xla/python/ifrt/ir/transforms/passes.td b/third_party/xla/xla/python/ifrt/ir/transforms/passes.td index 10215b72653e0c..a78b5729213c0a 100644 --- a/third_party/xla/xla/python/ifrt/ir/transforms/passes.td +++ b/third_party/xla/xla/python/ifrt/ir/transforms/passes.td @@ -203,5 +203,44 @@ Verify that each `!ifrt.array` has sharding attribute that is not of type let constructor = "CreateIfrtVerifyShardingSpecifiedPass()"; } +def IfrtReshardToCopyArraysPass : + Pass<"ifrt-reshard-to-copy-arrays", "mlir::ModuleOp"> { + let summary = "Replaces `ifrt.Reshard` with `ifrt.CopyArrays`"; + let description = [{ +Replaces each `ifrt.Reshard` op with an `ifrt.Reshard` op with inputs only +the arrays that are being resharded, and several `ifrt.CopyArrays` ops to copy +the arrays that are not being resharded. An `ifrt.CopyArrays` op is added for +unique output `ifrt.DevicesAttr`. + +For example, the following code +```mlir +!array0 = !ifrt.array, #ifrt.sharding_param<2x1 to [0] on 2>, + [0,1]> +!array1 = !ifrt.array, #ifrt.sharding_param<2x1 to [0] on 2>, + [2,3]> +!array2 = !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [2,3]> +func.func @main(%arg0: !array0, %arg1: !array1) -> (!array1, !array2) + attributes {ifrt.function} { + %0, %1, %ctrl_0 = ifrt.Reshard(%arg0, %arg1) + : (!array0, !array1) -> (!array1, !array2) + return %0, %1: !array1, !array2 +} +``` + +will be replaced by: + +```mlir +func.func @main(%arg0: !array0, %arg1: !array1) -> (!array1, !array2) + attributes {ifrt.function} { + %0, %ctrl_0 = ifrt.Reshard(%arg1) : (!array1) -> !array2 + %1, %ctrl_1 = ifrt.CopyArrays(%arg0) : (!array0) -> !array1 + return %0, %1: !array1, !array2 +} +``` + }]; + + let constructor = "CreateIfrtReshardToCopyArraysPass()"; +} #endif // XLA_PYTHON_IFRT_IR_TRANSFORMS_PASSES_TD_ diff --git a/third_party/xla/xla/python/ifrt/ir/transforms/utils.cc b/third_party/xla/xla/python/ifrt/ir/transforms/utils.cc index b1cb219e5e49fe..68c73b3dadb564 100644 --- a/third_party/xla/xla/python/ifrt/ir/transforms/utils.cc +++ b/third_party/xla/xla/python/ifrt/ir/transforms/utils.cc @@ -19,6 +19,7 @@ limitations under the License. #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/Support/LLVM.h" +#include "xla/python/ifrt/ir/ifrt_dialect.h" namespace xla { namespace ifrt { @@ -30,5 +31,14 @@ mlir::func::FuncOp GetMainFunction(mlir::ModuleOp module) { return func; } +bool IsReshard(xla::ifrt::IfrtArrayType from, xla::ifrt::IfrtArrayType to) { + if (from.getShape() == to.getShape() && + from.getShardingAttr() == to.getShardingAttr() && + from.getDevices().size() == to.getDevices().size()) { + return false; + } + return true; +} + } // namespace ifrt } // namespace xla diff --git a/third_party/xla/xla/python/ifrt/ir/transforms/utils.h b/third_party/xla/xla/python/ifrt/ir/transforms/utils.h index 81528e97f418ae..1ffda0f00ef2a6 100644 --- a/third_party/xla/xla/python/ifrt/ir/transforms/utils.h +++ b/third_party/xla/xla/python/ifrt/ir/transforms/utils.h @@ -18,6 +18,7 @@ limitations under the License. #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/BuiltinOps.h" +#include "xla/python/ifrt/ir/ifrt_dialect.h" namespace xla { namespace ifrt { @@ -26,6 +27,9 @@ namespace ifrt { // fails otherwise. mlir::func::FuncOp GetMainFunction(mlir::ModuleOp module); +// Returns true if transferring between from and to array requires a reshard. +bool IsReshard(xla::ifrt::IfrtArrayType from, xla::ifrt::IfrtArrayType to); + } // namespace ifrt } // namespace xla From 710774f0a2336d328090f9380a2fbe70eab675ac Mon Sep 17 00:00:00 2001 From: Ilia Sergachev Date: Wed, 25 Sep 2024 04:21:36 -0700 Subject: [PATCH 250/483] PR #17579: Algebraic simplifier: mark iota non-negative. Imported from GitHub PR https://github.com/openxla/xla/pull/17579 Copybara import of the project: -- 02c09a8dd5bb62ffd3729a23813a0e66f672a5a3 by Ilia Sergachev : Algebraic simplifier: mark iota non-negative. -- 4735edc2bac278ea1e87035f128a2f5d0f2a7a59 by Ilia Sergachev : Fix unrelated clang-format issues to make CI happy Merging this change closes #17579 PiperOrigin-RevId: 678640567 --- third_party/xla/xla/service/algebraic_simplifier.cc | 3 ++- .../xla/xla/service/algebraic_simplifier_test.cc | 12 ++++++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/third_party/xla/xla/service/algebraic_simplifier.cc b/third_party/xla/xla/service/algebraic_simplifier.cc index f54864220e2e69..ff9a2f688cc874 100644 --- a/third_party/xla/xla/service/algebraic_simplifier.cc +++ b/third_party/xla/xla/service/algebraic_simplifier.cc @@ -530,7 +530,8 @@ bool AlgebraicSimplifierVisitor::IsNonNegative( return hlo->operand(0) == hlo->operand(1); } case HloOpcode::kAbs: - case HloOpcode::kExp: { + case HloOpcode::kExp: + case HloOpcode::kIota: { return true; } case HloOpcode::kBroadcast: { diff --git a/third_party/xla/xla/service/algebraic_simplifier_test.cc b/third_party/xla/xla/service/algebraic_simplifier_test.cc index 1af8721066b8bf..ea67d07a141967 100644 --- a/third_party/xla/xla/service/algebraic_simplifier_test.cc +++ b/third_party/xla/xla/service/algebraic_simplifier_test.cc @@ -10561,6 +10561,18 @@ TEST_F(AlgebraicSimplifierTest, AbsEliminationSelMaxBcast) { m::Broadcast(m::ConstantScalar()))))); } +TEST_F(AlgebraicSimplifierTest, AbsEliminationIota) { + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(R"( + e { + i = s32[3,2] iota(), iota_dimension=0 + ROOT a = s32[3,2] abs(i) + } + )")); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).value()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::Iota())); +} + TEST_F(AlgebraicSimplifierTest, SimplifyRedundantBitcastConvert) { const char* kModuleStr = R"( HloModule m From b47eb03464a06c81ce101a1300c791f1a3eaf3e8 Mon Sep 17 00:00:00 2001 From: mayuyuace Date: Wed, 25 Sep 2024 04:29:28 -0700 Subject: [PATCH 251/483] PR #15904: [XLA:GPU]implement sycl platform id Imported from GitHub PR https://github.com/openxla/xla/pull/15904 Copybara import of the project: -- df9b82ad0c35cb3f8ad8253b20a38a74f9318d73 by mayuyuace : implement sycl platform id -- 72cf11f61eed4f729d0e5800401fb26da8693a06 by mayuyuace : remove override' of GetUncachedExecutor Merging this change closes #15904 PiperOrigin-RevId: 678642780 --- third_party/xla/xla/service/BUILD | 1 + third_party/xla/xla/service/computation_placer.cc | 3 +++ third_party/xla/xla/service/gpu/BUILD | 2 ++ third_party/xla/xla/service/gpu/gpu_executable.cc | 3 +++ .../xla/xla/service/gpu/gpu_transfer_manager.cc | 10 ++++++++++ .../xla/xla/stream_executor/sycl/sycl_platform.h | 2 +- 6 files changed, 20 insertions(+), 1 deletion(-) diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD index 53d6b29fc5bcd7..6fa72a430d8fde 100644 --- a/third_party/xla/xla/service/BUILD +++ b/third_party/xla/xla/service/BUILD @@ -4356,6 +4356,7 @@ cc_library( "//xla/stream_executor/cuda:cuda_platform_id", "//xla/stream_executor/host:host_platform_id", "//xla/stream_executor/rocm:rocm_platform_id", + "//xla/stream_executor/sycl:sycl_platform_id", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", diff --git a/third_party/xla/xla/service/computation_placer.cc b/third_party/xla/xla/service/computation_placer.cc index ee0cf2932a1e86..43f351a5489592 100644 --- a/third_party/xla/xla/service/computation_placer.cc +++ b/third_party/xla/xla/service/computation_placer.cc @@ -31,6 +31,7 @@ limitations under the License. #include "xla/stream_executor/cuda/cuda_platform_id.h" #include "xla/stream_executor/host/host_platform_id.h" #include "xla/stream_executor/rocm/rocm_platform_id.h" +#include "xla/stream_executor/sycl/sycl_platform_id.h" #include "xla/types.h" #include "xla/util.h" #include "tsl/platform/errors.h" @@ -222,6 +223,8 @@ static bool InitModule() { stream_executor::cuda::kCudaPlatformId, &CreateComputationPlacer); xla::ComputationPlacer::RegisterComputationPlacer( stream_executor::rocm::kROCmPlatformId, &CreateComputationPlacer); + xla::ComputationPlacer::RegisterComputationPlacer( + stream_executor::sycl::kSyclPlatformId, &CreateComputationPlacer); return true; } static bool module_initialized = InitModule(); diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index 2f2173c0295ae7..7045d03c85cb73 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -597,6 +597,7 @@ cc_library( "//xla/stream_executor/gpu:gpu_executor_header", "//xla/stream_executor/gpu:scoped_activate_context", "//xla/stream_executor/rocm:rocm_platform_id", + "//xla/stream_executor/sycl:sycl_platform_id", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", @@ -1211,6 +1212,7 @@ cc_library( "//xla/stream_executor:memory_allocation", "//xla/stream_executor/cuda:cuda_platform_id", "//xla/stream_executor/rocm:rocm_platform_id", + "//xla/stream_executor/sycl:sycl_platform_id", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/cleanup", "@com_google_absl//absl/container:flat_hash_map", diff --git a/third_party/xla/xla/service/gpu/gpu_executable.cc b/third_party/xla/xla/service/gpu/gpu_executable.cc index e3d939e873a228..e43835397d0956 100644 --- a/third_party/xla/xla/service/gpu/gpu_executable.cc +++ b/third_party/xla/xla/service/gpu/gpu_executable.cc @@ -81,6 +81,7 @@ limitations under the License. #include "xla/stream_executor/scoped_module_handle.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" +#include "xla/stream_executor/sycl/sycl_platform_id.h" #include "xla/util.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" @@ -177,6 +178,8 @@ absl::Status GpuExecutable::CheckCompatibilityWithServiceExecutableRunOptions( << std::get(gpu_version_).ToString() << "}, but was {" << std::get(cc).ToString() << "}"; + } else if (platform_id == stream_executor::sycl::kSyclPlatformId) { + // TODO: Add check. } else { return Internal("Unknown platform"); } diff --git a/third_party/xla/xla/service/gpu/gpu_transfer_manager.cc b/third_party/xla/xla/service/gpu/gpu_transfer_manager.cc index dc770514bdda2b..ffff4acdd1dfbe 100644 --- a/third_party/xla/xla/service/gpu/gpu_transfer_manager.cc +++ b/third_party/xla/xla/service/gpu/gpu_transfer_manager.cc @@ -48,6 +48,7 @@ limitations under the License. #include "xla/stream_executor/platform.h" #include "xla/stream_executor/rocm/rocm_platform_id.h" #include "xla/stream_executor/stream_executor.h" +#include "xla/stream_executor/sycl/sycl_platform_id.h" #include "xla/util.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" @@ -369,11 +370,20 @@ static std::unique_ptr CreateAMDGPUTransferManager() { .getPointerSize(0 /* default address space */)); } +static std::unique_ptr CreateSYCLTransferManager() { + return std::make_unique( + /*id=*/stream_executor::sycl::kSyclPlatformId, + /*pointer_size=*/llvm::DataLayout(xla::gpu::spir::DataLayout()) + .getPointerSize(0 /* default address space */)); +} + static bool InitModule() { xla::TransferManager::RegisterTransferManager( stream_executor::cuda::kCudaPlatformId, &CreateNVPTXTransferManager); xla::TransferManager::RegisterTransferManager( stream_executor::rocm::kROCmPlatformId, &CreateAMDGPUTransferManager); + xla::TransferManager::RegisterTransferManager( + stream_executor::sycl::kSyclPlatformId, &CreateSYCLTransferManager); return true; } diff --git a/third_party/xla/xla/stream_executor/sycl/sycl_platform.h b/third_party/xla/xla/stream_executor/sycl/sycl_platform.h index 61f0eb3d5372b9..7c70e5d17e0f6e 100644 --- a/third_party/xla/xla/stream_executor/sycl/sycl_platform.h +++ b/third_party/xla/xla/stream_executor/sycl/sycl_platform.h @@ -60,7 +60,7 @@ class SyclPlatform : public Platform { // looking in or storing to the Platform's executor cache. // Ownership IS transferred to the caller. absl::StatusOr> GetUncachedExecutor( - int ordinal) override; + int ordinal); // This platform's name. std::string name_; From 9609230b24a000340ad619ba06423abf4dff4b94 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 25 Sep 2024 04:34:25 -0700 Subject: [PATCH 252/483] Automated Code Change PiperOrigin-RevId: 678644037 --- tensorflow/core/graph/graph.h | 2 +- tensorflow/core/graph/graph_test.cc | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tensorflow/core/graph/graph.h b/tensorflow/core/graph/graph.h index c7a4f696bf126d..9ca28088533672 100644 --- a/tensorflow/core/graph/graph.h +++ b/tensorflow/core/graph/graph.h @@ -1027,7 +1027,7 @@ inline bool NodeIter::operator!=(const NodeIter& rhs) const { } inline void NodeIter::operator++() { - while (1) { + while (true) { DCHECK_LE(id_, graph_->num_node_ids()); ++id_; if (id_ >= graph_->num_node_ids() || graph_->FindNodeId(id_) != nullptr) { diff --git a/tensorflow/core/graph/graph_test.cc b/tensorflow/core/graph/graph_test.cc index c0582f7092f40d..f957d45803f9a3 100644 --- a/tensorflow/core/graph/graph_test.cc +++ b/tensorflow/core/graph/graph_test.cc @@ -624,12 +624,12 @@ TEST_F(GraphTest, EdgeDebugString) { EXPECT_EQ(s1, "[id=0 :0 -> :0]"); // Print edge with null src node - auto e2 = BuildEdge(2, 0, b, 1, 1); + auto e2 = BuildEdge(2, nullptr, b, 1, 1); auto s2 = e2->DebugString(); EXPECT_EQ(s2, "[id=2 :1 -> B:1]"); // Print edge with null dst node - auto e3 = BuildEdge(3, a, 0, 2, 1); + auto e3 = BuildEdge(3, a, nullptr, 2, 1); auto s3 = e3->DebugString(); EXPECT_EQ(s3, "[id=3 A:2 -> :1]"); } From ee1c9cb39b05a27edeef3750097d1750ecf9e6a5 Mon Sep 17 00:00:00 2001 From: "Dimitar (Mitko) Asenov" Date: Wed, 25 Sep 2024 04:38:55 -0700 Subject: [PATCH 253/483] [XLA:GPU] Do not fuse custom fusions in the multi-output-fusion pass. Without this change, some custom fusions would be seen as fusible, e.g. if they were certain kinds of reduction-based fusions or elementwise fusions. However there's no support for fusing custom fusions. PiperOrigin-RevId: 678645100 --- third_party/xla/xla/service/gpu/gpu_fusible.cc | 4 +++- .../xla/xla/service/gpu/gpu_fusible_test.cc | 17 ++++------------- 2 files changed, 7 insertions(+), 14 deletions(-) diff --git a/third_party/xla/xla/service/gpu/gpu_fusible.cc b/third_party/xla/xla/service/gpu/gpu_fusible.cc index 94e67e43c1adb6..4fe265fac93f5f 100644 --- a/third_party/xla/xla/service/gpu/gpu_fusible.cc +++ b/third_party/xla/xla/service/gpu/gpu_fusible.cc @@ -906,8 +906,10 @@ bool IsFusibleAsMultiOutputFusionRoot(const HloInstruction& instr) { // with any other instruction. // Note that scatter cannot be the root of a multi-output fusion because // its emitter doesn't support it. + // + // Custom fusions cannot be fused with anything. - return instr.IsFusible() && + return instr.IsFusible() && !instr.IsCustomFusion() && (IsInputFusibleReduction(instr) || IsInputFusibleTranspose(instr) || instr.IsLoopFusion() || // TODO(b/130013493): Use IsLoopFusible here. instr.IsElementwise()); diff --git a/third_party/xla/xla/service/gpu/gpu_fusible_test.cc b/third_party/xla/xla/service/gpu/gpu_fusible_test.cc index 735709cbd346f8..b86071a3484fbe 100644 --- a/third_party/xla/xla/service/gpu/gpu_fusible_test.cc +++ b/third_party/xla/xla/service/gpu/gpu_fusible_test.cc @@ -488,23 +488,14 @@ TEST_F(GpuFusibleTest, TEST_F(GpuFusibleTest, CustomFusionIsNotFusibleAsConsumer) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(R"( -HloModule m - triton_fusion { - p0 = f16[20,3]{1,0} parameter(0) - p1 = f16[3,40]{1,0} parameter(1) - dot = f16[20,40]{1,0} dot(p0, p1), - lhs_contracting_dims={1}, rhs_contracting_dims={0} - ROOT c = f16[20,40]{0,1} copy(dot) + p = s32[20,3] parameter(0) + ROOT neg = s32[20,3] negate(p) } ENTRY e { - p0 = f16[20,3]{1,0} parameter(0) - n = f16[20,3]{1,0} negate(p0) - p1 = f16[3,40]{1,0} parameter(1) - ROOT r = f16[20,40]{0,1} fusion(n, p1), - kind=kCustom, - calls=triton_fusion + p = s32[20,3] parameter(0) + ROOT r = s32[20,3] fusion(p), kind=kCustom, calls=triton_fusion })")); const HloInstruction* root = module->entry_computation()->root_instruction(); EXPECT_FALSE(IsFusibleAsMultiOutputFusionRoot(*root)); From 1062a7bb25561fab7a8d337c1bc76fd64e59f66b Mon Sep 17 00:00:00 2001 From: Alexander Lyashuk Date: Wed, 25 Sep 2024 05:16:40 -0700 Subject: [PATCH 254/483] [XLA:GPU] Don't fall back to the default layout in all cases, not just entry computation layout. Not resetting of the shapes of the entry computation's parameters has the same reasoning as entry_computation_layout. Other ops are reset by the layout normalization pass anyway. PiperOrigin-RevId: 678656058 --- third_party/xla/xla/service/hlo_parser.cc | 12 +-- third_party/xla/xla/service/hlo_parser.h | 14 ++-- .../xla/xla/service/hlo_parser_test.cc | 76 +++++++++++++++++-- .../xla/xla/tools/hlo_module_loader.cc | 12 +-- third_party/xla/xla/tools/hlo_module_loader.h | 4 +- 5 files changed, 88 insertions(+), 30 deletions(-) diff --git a/third_party/xla/xla/service/hlo_parser.cc b/third_party/xla/xla/service/hlo_parser.cc index abf31f1767d7f8..fcc3f93863ba07 100644 --- a/third_party/xla/xla/service/hlo_parser.cc +++ b/third_party/xla/xla/service/hlo_parser.cc @@ -547,7 +547,7 @@ class HloParserImpl : public HloParser { bool ParseJsonDict(std::string* result); bool ParseDimensionSizes(std::vector* dimension_sizes, std::vector* dynamic_dimensions); - bool ParseShape(Shape* result, bool set_to_default_layout = true); + bool ParseShape(Shape* result); bool ParseLayout(Layout* layout); bool ParseLayoutIntAttribute(int64_t* attr_value, absl::string_view attr_description); @@ -915,7 +915,7 @@ bool HloParserImpl::ParseComputationLayout( } while (lexer_.GetKind() != TokKind::kRparen) { Shape param; - if (!ParseShape(¶m, options_.fill_missing_module_parameter_layouts())) { + if (!ParseShape(¶m)) { return false; } computation_layout->add_parameter_layout(ShapeLayout(param)); @@ -935,7 +935,7 @@ bool HloParserImpl::ParseComputationLayout( return false; } Shape result; - if (!ParseShape(&result, options_.fill_missing_module_parameter_layouts())) { + if (!ParseShape(&result)) { return false; } *computation_layout->mutable_result_layout() = ShapeLayout(result); @@ -6097,7 +6097,7 @@ bool HloParserImpl::ParseLayout(Layout* layout) { // tuple_elements // ::= /*empty*/ // ::= shape (',' shape)* -bool HloParserImpl::ParseShape(Shape* result, bool set_to_default_layout) { +bool HloParserImpl::ParseShape(Shape* result) { if (EatIfPresent(TokKind::kLparen)) { // Tuple std::vector shapes; if (lexer_.GetKind() == TokKind::kRparen) { @@ -6106,7 +6106,7 @@ bool HloParserImpl::ParseShape(Shape* result, bool set_to_default_layout) { // shape (',' shape)* do { shapes.emplace_back(); - if (!ParseShape(&shapes.back(), set_to_default_layout)) { + if (!ParseShape(&shapes.back())) { return false; } } while (EatIfPresent(TokKind::kComma)); @@ -6132,7 +6132,7 @@ bool HloParserImpl::ParseShape(Shape* result, bool set_to_default_layout) { result->add_dimensions(dimension_sizes[i]); result->set_dynamic_dimension(i, dynamic_dimensions[i]); } - if (set_to_default_layout || ShapeUtil::IsScalar(*result)) { + if (options_.fill_missing_layouts() || ShapeUtil::IsScalar(*result)) { LayoutUtil::SetToDefaultLayout(result); } // We need to lookahead to see if a following open brace is the start of a diff --git a/third_party/xla/xla/service/hlo_parser.h b/third_party/xla/xla/service/hlo_parser.h index c6b5f545c54cd4..17012f779c01ac 100644 --- a/third_party/xla/xla/service/hlo_parser.h +++ b/third_party/xla/xla/service/hlo_parser.h @@ -30,19 +30,17 @@ namespace xla { class HloParserOptions { public: - // If the entry computation parameter layout is not set, set the layout to be - // the default (e.g. {3,2,1,0}). - HloParserOptions& set_fill_missing_module_parameter_layouts(bool value) { - fill_missing_module_parameter_layouts_ = value; + // When a shape layout is not set (e.g. in the entry computation layout or + // instruction layout), set the layout to be the default (e.g. {3,2,1,0}). + HloParserOptions& set_fill_missing_layouts(bool value) { + fill_missing_layouts_ = value; return *this; } - bool fill_missing_module_parameter_layouts() const { - return fill_missing_module_parameter_layouts_; - } + bool fill_missing_layouts() const { return fill_missing_layouts_; } private: - bool fill_missing_module_parameter_layouts_ = true; + bool fill_missing_layouts_ = true; }; // Given a string in the HloModule::ToString() format, parses the string and diff --git a/third_party/xla/xla/service/hlo_parser_test.cc b/third_party/xla/xla/service/hlo_parser_test.cc index 0035b317fb418c..40c12a92972caa 100644 --- a/third_party/xla/xla/service/hlo_parser_test.cc +++ b/third_party/xla/xla/service/hlo_parser_test.cc @@ -3435,8 +3435,7 @@ ENTRY %Reduce (input: f32[8,16,256]) -> f32[8,16] { absl::StatusOr> module = ParseAndReturnUnverifiedModule( - original, {}, - HloParserOptions().set_fill_missing_module_parameter_layouts(false)); + original, {}, HloParserOptions().set_fill_missing_layouts(false)); TF_ASSERT_OK(module.status()); // Do not set the default layout. EXPECT_FALSE(module.value()->entry_computation_layout().AnyLayoutSet()); @@ -3460,8 +3459,7 @@ ENTRY %Reduce (input: f32[8,16,256]) -> f32[8,16] { absl::StatusOr> module = ParseAndReturnUnverifiedModule( - original, {}, - HloParserOptions().set_fill_missing_module_parameter_layouts(true)); + original, {}, HloParserOptions().set_fill_missing_layouts(true)); TF_ASSERT_OK(module.status()); EXPECT_THAT(module.value() ->entry_computation_layout() @@ -3489,8 +3487,7 @@ ENTRY %Reduce (input: f32[8,16,256]) -> f32[8,16] { absl::StatusOr> module = ParseAndReturnUnverifiedModule( - original, {}, - HloParserOptions().set_fill_missing_module_parameter_layouts(true)); + original, {}, HloParserOptions().set_fill_missing_layouts(true)); TF_ASSERT_OK(module.status()); EXPECT_THAT(module.value() ->entry_computation_layout() @@ -3500,6 +3497,73 @@ ENTRY %Reduce (input: f32[8,16,256]) -> f32[8,16] { ElementsAre(2, 1, 0)); } +TEST_F(HloParserTest, DoNotFallBackToDefaultLayoutIfDisabled) { + const std::string original = R"( +HloModule t + +ENTRY main { + p0 = f16[16,32,48,64]{3,2,1,0} parameter(0) + p1 = f16[80,64,48,32]{3,2,1,0} parameter(1) + ROOT dot = f16[64,32,16,80] dot(p0, p1), lhs_contracting_dims={2}, rhs_contracting_dims={2}, lhs_batch_dims={3,1}, rhs_batch_dims={1,3} +})"; + + absl::StatusOr> module = + ParseAndReturnUnverifiedModule( + original, {}, HloParserOptions().set_fill_missing_layouts(false)); + TF_ASSERT_OK(module.status()); + EXPECT_FALSE(module.value() + ->entry_computation() + ->root_instruction() + ->shape() + .has_layout()); +} + +TEST_F(HloParserTest, FallBackToDefaultLayoutIfEnabled) { + const std::string original = R"( +HloModule t + +ENTRY main { + p0 = f16[16,32,48,64]{3,2,1,0} parameter(0) + p1 = f16[80,64,48,32]{3,2,1,0} parameter(1) + ROOT dot = f16[64,32,16,80] dot(p0, p1), lhs_contracting_dims={2}, rhs_contracting_dims={2}, lhs_batch_dims={3,1}, rhs_batch_dims={1,3} +})"; + + absl::StatusOr> module = + ParseAndReturnUnverifiedModule( + original, {}, HloParserOptions().set_fill_missing_layouts(true)); + TF_ASSERT_OK(module.status()); + EXPECT_THAT(module.value() + ->entry_computation() + ->root_instruction() + ->shape() + .layout() + .minor_to_major(), + ElementsAre(3, 2, 1, 0)); +} + +TEST_F(HloParserTest, FallBackToDefaultLayoutIfAlreadySet) { + const std::string original = R"( +HloModule t + +ENTRY main { + p0 = f16[16,32,48,64]{3,2,1,0} parameter(0) + p1 = f16[80,64,48,32]{3,2,1,0} parameter(1) + ROOT dot = f16[64,32,16,80]{1,2,0,3} dot(p0, p1), lhs_contracting_dims={2}, rhs_contracting_dims={2}, lhs_batch_dims={3,1}, rhs_batch_dims={1,3} +})"; + + absl::StatusOr> module = + ParseAndReturnUnverifiedModule( + original, {}, HloParserOptions().set_fill_missing_layouts(true)); + TF_ASSERT_OK(module.status()); + EXPECT_THAT(module.value() + ->entry_computation() + ->root_instruction() + ->shape() + .layout() + .minor_to_major(), + ElementsAre(1, 2, 0, 3)); +} + TEST_F(HloParserTest, NoEntry) { const std::string original = R"(HloModule no_entry: c1 { diff --git a/third_party/xla/xla/tools/hlo_module_loader.cc b/third_party/xla/xla/tools/hlo_module_loader.cc index f765acaeeef8ca..85c1effb10732a 100644 --- a/third_party/xla/xla/tools/hlo_module_loader.cc +++ b/third_party/xla/xla/tools/hlo_module_loader.cc @@ -72,8 +72,7 @@ absl::StatusOr> LoadModuleFromData( const std::string& data, std::string_view format, const hlo_module_loader_details::Config& ovr_config, const std::function& config_modifier_hook, - BufferAssignmentProto* buffer_assignment_proto, - bool fill_missing_module_parameter_layouts) { + BufferAssignmentProto* buffer_assignment_proto, bool fill_missing_layouts) { DebugOptions debug_options = GetDebugOptionsFromFlags(); std::unique_ptr module; if (format == "hlo" || format == "txt") { @@ -85,8 +84,7 @@ absl::StatusOr> LoadModuleFromData( config_modifier_hook(&config); } HloParserOptions options; - options.set_fill_missing_module_parameter_layouts( - fill_missing_module_parameter_layouts); + options.set_fill_missing_layouts(fill_missing_layouts); TF_ASSIGN_OR_RETURN( module, ParseAndReturnUnverifiedModule(hlo_string, config, options)); } else { @@ -136,16 +134,14 @@ absl::StatusOr> LoadModuleFromFile( const std::string& path, std::string format, const hlo_module_loader_details::Config& ovr_config, const std::function& config_modifier_hook, - BufferAssignmentProto* buffer_assignment_proto, - bool fill_missing_module_parameter_layouts) { + BufferAssignmentProto* buffer_assignment_proto, bool fill_missing_layouts) { std::string data; if (format.empty()) { format = std::string(tsl::io::Extension(path)); } TF_RETURN_IF_ERROR(tsl::ReadFileToString(tsl::Env::Default(), path, &data)); return LoadModuleFromData(data, format, ovr_config, config_modifier_hook, - buffer_assignment_proto, - fill_missing_module_parameter_layouts); + buffer_assignment_proto, fill_missing_layouts); } absl::StatusOr> diff --git a/third_party/xla/xla/tools/hlo_module_loader.h b/third_party/xla/xla/tools/hlo_module_loader.h index 9000422fd8d4f0..da44e71428332a 100644 --- a/third_party/xla/xla/tools/hlo_module_loader.h +++ b/third_party/xla/xla/tools/hlo_module_loader.h @@ -61,7 +61,7 @@ absl::StatusOr> LoadModuleFromData( hlo_module_loader_details::Config(), const std::function& config_modifier_hook = {}, BufferAssignmentProto* buffer_assignment_proto = nullptr, - bool fill_missing_module_parameter_layouts = true); + bool fill_missing_layouts = true); // Loads an HLO module from file. // The file can be one of the followings: @@ -84,7 +84,7 @@ absl::StatusOr> LoadModuleFromFile( hlo_module_loader_details::Config(), const std::function& config_modifier_hook = {}, BufferAssignmentProto* buffer_assignment_proto = nullptr, - bool fill_missing_module_parameter_layouts = true); + bool fill_missing_layouts = true); // Loads an HLO snapshot from a string, only for its inputs // The data format must be one of the following: From 0ecb2dca761ff571b9a413515d4aeeb01a44750e Mon Sep 17 00:00:00 2001 From: terryysun Date: Wed, 25 Sep 2024 06:13:37 -0700 Subject: [PATCH 255/483] PR #15144: [NVIDIA GPU] Use memcpy for intra-node all-to-all Imported from GitHub PR https://github.com/openxla/xla/pull/15144 The communications of all-to-all rely on NCCL even when it is intra-node. By leveraging memcpy for intra-node communications, all-to-all can have better performance while reducing SM consumption (right now consumed by NCCL). Copybara import of the project: -- 38720c73f5817dbbf5b6d98751140bb53f572690 by Terry Sun : memcpyp2p for local a2a -- 90018f4a3fe0ed3018767db810518faf9435bc93 by Terry Sun : use nccl to pass recv ptrs -- f9b75b0e088286ded770b27fff9d020f8e85a648 by Terry Sun : refactor and cleanup Merging this change closes #15144 PiperOrigin-RevId: 678671759 --- .../xla/service/gpu/ir_emitter_unnested.cc | 4 +- third_party/xla/xla/service/gpu/runtime/BUILD | 3 + .../gpu/runtime/nccl_all_gather_thunk.cc | 3 +- .../gpu/runtime/nccl_all_gather_thunk.h | 3 +- .../gpu/runtime/nccl_all_reduce_thunk.cc | 6 +- .../gpu/runtime/nccl_all_reduce_thunk.h | 6 +- .../gpu/runtime/nccl_all_to_all_thunk.cc | 184 +++++++++++++++++- .../gpu/runtime/nccl_all_to_all_thunk.h | 26 ++- .../xla/xla/service/gpu/runtime/nccl_api.cc | 27 +++ .../xla/xla/service/gpu/runtime/nccl_api.h | 8 + .../xla/service/gpu/runtime/nccl_api_stub.cc | 10 + .../xla/service/gpu/runtime/nccl_clique_key.h | 3 +- .../nccl_collective_broadcast_thunk.cc | 3 +- .../runtime/nccl_collective_broadcast_thunk.h | 2 +- .../xla/xla/service/gpu/runtime/thunk.h | 25 +++ .../xla/stream_executor/cuda/cuda_driver.cc | 24 +++ .../xla/stream_executor/cuda/cuda_executor.cc | 13 ++ .../xla/stream_executor/cuda/cuda_executor.h | 3 + .../xla/xla/stream_executor/gpu/gpu_driver.h | 16 ++ .../xla/xla/stream_executor/stream_executor.h | 5 + .../xla/xla/tests/collective_ops_e2e_test.cc | 40 ++++ 21 files changed, 400 insertions(+), 14 deletions(-) diff --git a/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc b/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc index bfc49082958ba4..b4406783adcb0e 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc +++ b/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc @@ -2140,7 +2140,9 @@ absl::Status IrEmitterUnnested::EmitNcclThunk( thunk_info.profile_annotation = async_start->name(); } auto thunk = std::make_unique( - thunk_info, NcclApi::Default(), inst, /*buffers=*/std::move(buffers)); + thunk_info, NcclApi::Default(), inst, + /*buffers=*/std::move(buffers), + ir_emitter_context_->debug_options().xla_gpu_use_memcpy_local_p2p()); GetCollectivesAsyncEvents().insert({async_start, thunk->async_events()}); AddThunkToThunkSequence(std::move(thunk)); return absl::OkStatus(); diff --git a/third_party/xla/xla/service/gpu/runtime/BUILD b/third_party/xla/xla/service/gpu/runtime/BUILD index 0cf0912aac171d..9427f585206df4 100644 --- a/third_party/xla/xla/service/gpu/runtime/BUILD +++ b/third_party/xla/xla/service/gpu/runtime/BUILD @@ -824,7 +824,10 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/service:collective_ops_utils", "//xla/service/gpu:ir_emission_utils", + "//xla/service/gpu/runtime:nccl_clique_key", "//xla/stream_executor", + "//xla/tsl/concurrency:async_value", + "@com_google_absl//absl/container:node_hash_map", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@llvm-project//mlir:IR", diff --git a/third_party/xla/xla/service/gpu/runtime/nccl_all_gather_thunk.cc b/third_party/xla/xla/service/gpu/runtime/nccl_all_gather_thunk.cc index 49aae84589b97a..9089339e0419b1 100644 --- a/third_party/xla/xla/service/gpu/runtime/nccl_all_gather_thunk.cc +++ b/third_party/xla/xla/service/gpu/runtime/nccl_all_gather_thunk.cc @@ -66,7 +66,8 @@ absl::Status CheckImplementableInst(const HloAllGatherInstruction* inst) { NcclAllGatherStartThunk::NcclAllGatherStartThunk( ThunkInfo thunk_info, NcclApi* nccl_api, - const HloAllGatherInstruction* inst, std::vector buffers) + const HloAllGatherInstruction* inst, std::vector buffers, + bool p2p_memcpy_enabled) : NcclCollectiveThunk(Thunk::kNcclAllGatherStart, thunk_info, nccl_api, IsSyncCollective(inst)), config_(impl::GetNcclAllGatherConfig(inst)), diff --git a/third_party/xla/xla/service/gpu/runtime/nccl_all_gather_thunk.h b/third_party/xla/xla/service/gpu/runtime/nccl_all_gather_thunk.h index 4b5d63e6639e3e..aba61ffb6c6acb 100644 --- a/third_party/xla/xla/service/gpu/runtime/nccl_all_gather_thunk.h +++ b/third_party/xla/xla/service/gpu/runtime/nccl_all_gather_thunk.h @@ -39,7 +39,8 @@ class NcclAllGatherStartThunk : public NcclCollectiveThunk { public: NcclAllGatherStartThunk(ThunkInfo thunk_info, NcclApi* nccl_api, const HloAllGatherInstruction* inst, - std::vector buffers); + std::vector buffers, + bool p2p_memcpy_enabled = false); static const char* GetHloOpName() { return "all-gather-start"; } diff --git a/third_party/xla/xla/service/gpu/runtime/nccl_all_reduce_thunk.cc b/third_party/xla/xla/service/gpu/runtime/nccl_all_reduce_thunk.cc index 532d8d03b6bb6a..884211a839d498 100644 --- a/third_party/xla/xla/service/gpu/runtime/nccl_all_reduce_thunk.cc +++ b/third_party/xla/xla/service/gpu/runtime/nccl_all_reduce_thunk.cc @@ -156,7 +156,8 @@ NcclAllReduceReduceScatterThunkBase::NcclAllReduceReduceScatterThunkBase( NcclAllReduceStartThunk::NcclAllReduceStartThunk( ThunkInfo thunk_info, NcclApi* nccl_api, - const HloAllReduceInstruction* inst, std::vector buffers) + const HloAllReduceInstruction* inst, std::vector buffers, + bool p2p_memcpy_enabled) : NcclAllReduceReduceScatterThunkBase( Thunk::kNcclAllReduceStart, thunk_info, nccl_api, impl::GetNcclAllReduceConfigInst(inst), std::move(buffers), @@ -189,7 +190,8 @@ absl::Status NcclAllReduceStartThunk::RunNcclCollective( NcclReduceScatterStartThunk::NcclReduceScatterStartThunk( ThunkInfo thunk_info, NcclApi* nccl_api, - const HloReduceScatterInstruction* inst, std::vector buffers) + const HloReduceScatterInstruction* inst, std::vector buffers, + bool p2p_memcpy_enabled) : NcclAllReduceReduceScatterThunkBase( Thunk::kNcclReduceScatterStart, thunk_info, nccl_api, impl::GetNcclAllReduceConfigInst(inst), std::move(buffers), diff --git a/third_party/xla/xla/service/gpu/runtime/nccl_all_reduce_thunk.h b/third_party/xla/xla/service/gpu/runtime/nccl_all_reduce_thunk.h index 7d70edaf2dab56..f36727c5081a31 100644 --- a/third_party/xla/xla/service/gpu/runtime/nccl_all_reduce_thunk.h +++ b/third_party/xla/xla/service/gpu/runtime/nccl_all_reduce_thunk.h @@ -63,7 +63,8 @@ class NcclAllReduceStartThunk : public NcclAllReduceReduceScatterThunkBase { public: NcclAllReduceStartThunk(ThunkInfo thunk_info, NcclApi* nccl_api, const HloAllReduceInstruction* inst, - std::vector buffers); + std::vector buffers, + bool p2p_memcpy_enabled = false); static const char* GetHloOpName() { return "all-reduce-start"; } @@ -87,7 +88,8 @@ class NcclReduceScatterStartThunk : public NcclAllReduceReduceScatterThunkBase { public: NcclReduceScatterStartThunk(ThunkInfo thunk_info, NcclApi* nccl_api, const HloReduceScatterInstruction* inst, - std::vector buffers); + std::vector buffers, + bool p2p_memcpy_enabled = false); static const char* GetHloOpName() { return "reduce-scatter-start"; } diff --git a/third_party/xla/xla/service/gpu/runtime/nccl_all_to_all_thunk.cc b/third_party/xla/xla/service/gpu/runtime/nccl_all_to_all_thunk.cc index 90bbd448c03fc2..cb280979027af4 100644 --- a/third_party/xla/xla/service/gpu/runtime/nccl_all_to_all_thunk.cc +++ b/third_party/xla/xla/service/gpu/runtime/nccl_all_to_all_thunk.cc @@ -29,6 +29,7 @@ limitations under the License. #include "xla/service/collective_ops_utils.h" #include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/gpu/runtime/nccl_api.h" +#include "xla/service/gpu/runtime/nccl_clique_key.h" #include "xla/service/gpu/runtime/nccl_collective_thunk.h" #include "xla/shape.h" #include "xla/shape_util.h" @@ -41,7 +42,6 @@ limitations under the License. namespace xla { namespace gpu { - namespace { NcclAllToAllConfig GetNcclAllToAllConfig(const HloAllToAllInstruction* instr) { @@ -58,11 +58,12 @@ NcclAllToAllConfig GetNcclAllToAllConfig(const HloAllToAllInstruction* instr) { NcclAllToAllStartThunk::NcclAllToAllStartThunk( ThunkInfo thunk_info, NcclApi* nccl_api, const HloAllToAllInstruction* instr, - std::vector buffers) + std::vector buffers, bool p2p_memcpy_enabled) : NcclCollectiveThunk(Thunk::kNcclAllToAllStart, thunk_info, nccl_api, IsSyncCollective(instr)), config_(GetNcclAllToAllConfig(instr)), - buffers_(std::move(buffers)) { + buffers_(std::move(buffers)), + p2p_memcpy_enabled_(p2p_memcpy_enabled) { CHECK_EQ(config_.config.operand_count, buffers_.size()); } @@ -92,6 +93,76 @@ NcclAllToAllStartThunk::NcclAllToAllStartThunk( return GetNcclAllToAllConfig(instr).config.group_mode; } +absl::Status NcclAllToAllStartThunk::Initialize( + const InitializeParams& params) { + TF_RETURN_IF_ERROR(NcclCollectiveThunk::Initialize(params)); + device_count_ = params.local_device_count; + CHECK_GT(device_count_, 0); + VLOG(5) << "Local device count: " << device_count_; + + if (is_local() && p2p_memcpy_enabled_) { + const NcclStreamId stream_id = nccl_stream_id(); + AsyncStreamKind stream_kind = GetAsyncStreamKind(); + TF_ASSIGN_OR_RETURN( + NcclCommHandleWrapper comm_wrapper, + GetNcclComm(*params.collective_params, *params.collective_cliques, + config().replica_groups, config().group_mode, stream_id, + stream_kind)); + TF_ASSIGN_OR_RETURN(int32_t num_participants, + nccl_api()->CommCount(comm_wrapper.comm_handle)); + + for (int i = 0; i < num_participants; ++i) { + for (int j = 0; j < num_participants; ++j) { + if (send_pointer_maps_.count(i) && send_pointer_maps_.at(i).count(j)) { + continue; + } + if (!params.stream->parent()->HostMemoryRegister( + &send_pointer_maps_[i][j], sizeof(void*))) { + VLOG(5) << "Registering host send pointer for memcpy failed."; + } + + if (!params.stream->parent()->HostMemoryRegister( + &receive_pointer_maps_[i][j], sizeof(void*))) { + VLOG(5) << "Registering host recv pointer for memcpy failed."; + } + } + } + } + + return absl::OkStatus(); +} + +absl::Status NcclAllToAllStartThunk::Cleanup(const CleanupParams& params) { + if (p2p_memcpy_enabled_) { + const NcclStreamId stream_id = nccl_stream_id(); + AsyncStreamKind stream_kind = GetAsyncStreamKind(); + TF_ASSIGN_OR_RETURN( + NcclCommHandleWrapper comm_wrapper, + GetNcclComm(*params.collective_params, *params.collective_cliques, + config().replica_groups, config().group_mode, stream_id, + stream_kind)); + TF_ASSIGN_OR_RETURN(int32_t num_participants, + nccl_api()->CommCount(comm_wrapper.comm_handle)); + + int local_id = params.executor->device_ordinal() % num_participants; + if (send_pointer_maps_.count(local_id)) { + for (auto& [id, value] : send_pointer_maps_[local_id]) { + if (!params.executor->HostMemoryUnregister((void*)value)) { + VLOG(5) << "Unregistering host send pointer for memcpy failed."; + } + } + } + if (receive_pointer_maps_.count(local_id)) { + for (auto& [id, value] : receive_pointer_maps_[local_id]) { + if (!params.executor->HostMemoryUnregister((void*)value)) { + VLOG(5) << "Unregistering host recv pointer for memcpy failed."; + } + } + } + } + return absl::OkStatus(); +} + absl::Status NcclAllToAllStartThunk::RunNcclCollective( const ExecuteParams& params, se::Stream& stream, NcclCommHandleWrapper comm_wrapper) { @@ -99,11 +170,39 @@ absl::Status NcclAllToAllStartThunk::RunNcclCollective( std::vector device_buffers, ConvertToDeviceBuffers(params, buffers_, config_.config.operand_element_type)); + TF_ASSIGN_OR_RETURN(int32_t num_participants, + nccl_api()->CommCount(comm_wrapper.comm_handle)); + + if (is_local() && p2p_memcpy_enabled_) { + int local_id = stream.parent()->device_ordinal() % num_participants; + return xla::gpu::RunMemCpyAllToAll( + nccl_api(), config_.has_split_dimension, device_buffers, stream, + comm_wrapper.comm_handle, send_pointer_maps_[local_id], + receive_pointer_maps_[local_id]); + } return xla::gpu::RunAllToAll(nccl_api(), config_.has_split_dimension, device_buffers, stream, comm_wrapper.comm_handle); } +AsyncStreamKind NcclAllToAllStartThunk::GetAsyncStreamKind() const { + return (is_local() && p2p_memcpy_enabled_) ? AsyncStreamKind::kMemCpyP2P + : AsyncStreamKind::kCollective; +} + +bool NcclAllToAllStartThunk::is_local() const { + for (const auto& replica_group : config_.config.replica_groups) { + const int64_t node_id = replica_group.replica_ids().at(0) / device_count_; + if (!absl::c_all_of(replica_group.replica_ids(), + [this, node_id](const int64_t rank) { + return rank / device_count_ == node_id; + })) { + return false; + } + } + return true; +} + absl::Status RunAllToAll(NcclApi* nccl_api, bool has_split_dimension, std::vector& buffers, se::Stream& stream, NcclApi::NcclCommHandle comm) { @@ -163,5 +262,84 @@ absl::Status RunAllToAll(NcclApi* nccl_api, bool has_split_dimension, return nccl_api->GroupEnd(); } +absl::Status RunMemCpyAllToAll( + NcclApi* nccl_api, bool has_split_dimension, + std::vector& buffers, se::Stream& stream, + NcclApi::NcclCommHandle comm, + absl::node_hash_map& send_pointer_map, + absl::node_hash_map& receive_pointer_map) { + int device_ordinal = stream.parent()->device_ordinal(); + VLOG(3) << "Performing mem-copy-all-to-all from device ordinal: " + << device_ordinal; + TF_RETURN_IF_ERROR( + MaybeRegisterBuffers(nccl_api, device_ordinal, buffers, comm)); + + TF_ASSIGN_OR_RETURN(int32_t num_participants, nccl_api->CommCount(comm)); + + // AllToAll can operate in two modes. Either it specifies a split dimension, + // in which case inputs are split and outputs concatenated in that dimension + // (here, we only support dimension 0), or it takes a list of inputs + // and produces a tuple of outputs. + if (has_split_dimension) { + for (DeviceBufferPair& buffer : buffers) { + TF_RET_CHECK(buffer.element_count % num_participants == 0) + << "Buffer was not an exact multiple of the number of participants."; + + size_t chunk_elements = buffer.element_count / num_participants; + + TF_RETURN_IF_ERROR(nccl_api->GroupStart()); + for (int peer = 0; peer < num_participants; ++peer) { + se::DeviceMemoryBase recv_slice = + NcclApi::Slice(buffer.destination_buffer, buffer.element_type, + peer * chunk_elements, chunk_elements); + send_pointer_map[peer] = (uint64_t)recv_slice.opaque(); + + TF_RETURN_IF_ERROR(nccl_api->SendPtrToPeer(&send_pointer_map[peer], + peer, comm, &stream)); + TF_RETURN_IF_ERROR(nccl_api->RecvPtrFromPeer(&receive_pointer_map[peer], + peer, comm, &stream)); + } + TF_RETURN_IF_ERROR(nccl_api->GroupEnd()); + TF_RETURN_IF_ERROR(stream.BlockHostUntilDone()); + + for (int peer = 0; peer < num_participants; ++peer) { + se::DeviceMemoryBase send_slice = + NcclApi::Slice(buffer.source_buffer, buffer.element_type, + peer * chunk_elements, chunk_elements); + se::DeviceMemoryBase dst_addr = + se::DeviceMemoryBase((void*)receive_pointer_map[peer]); + TF_RETURN_IF_ERROR( + stream.MemcpyD2D(&dst_addr, send_slice, send_slice.size())); + } + } + } else { + TF_RET_CHECK(buffers.size() == num_participants) + << "Number of inputs didn't match the number of participants."; + + TF_RETURN_IF_ERROR(nccl_api->GroupStart()); + for (int peer = 0; peer < num_participants; ++peer) { + send_pointer_map[peer] = + (uint64_t)buffers[peer].destination_buffer.opaque(); + + TF_RETURN_IF_ERROR(nccl_api->SendPtrToPeer(&send_pointer_map[peer], peer, + comm, &stream)); + TF_RETURN_IF_ERROR(nccl_api->RecvPtrFromPeer(&receive_pointer_map[peer], + peer, comm, &stream)); + } + TF_RETURN_IF_ERROR(nccl_api->GroupEnd()); + TF_RETURN_IF_ERROR(stream.BlockHostUntilDone()); + + for (int peer = 0; peer < num_participants; ++peer) { + // double buffer, exchange data with peer + se::DeviceMemoryBase dst_addr = + se::DeviceMemoryBase((void*)receive_pointer_map[peer]); + TF_RETURN_IF_ERROR(stream.MemcpyD2D(&dst_addr, + buffers[peer].source_buffer, + buffers[peer].source_buffer.size())); + } + } + return absl::OkStatus(); +} + } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/runtime/nccl_all_to_all_thunk.h b/third_party/xla/xla/service/gpu/runtime/nccl_all_to_all_thunk.h index 3bc8e2e78cb192..ed3056ec646789 100644 --- a/third_party/xla/xla/service/gpu/runtime/nccl_all_to_all_thunk.h +++ b/third_party/xla/xla/service/gpu/runtime/nccl_all_to_all_thunk.h @@ -19,12 +19,15 @@ limitations under the License. #include #include +#include "absl/container/node_hash_map.h" #include "absl/status/status.h" +#include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/service/collective_ops_utils.h" #include "xla/service/gpu/runtime/nccl_api.h" #include "xla/service/gpu/runtime/nccl_collective_thunk.h" #include "xla/stream_executor/stream.h" +#include "xla/tsl/concurrency/async_value_ref.h" namespace xla { namespace gpu { @@ -39,7 +42,7 @@ class NcclAllToAllStartThunk : public NcclCollectiveThunk { public: NcclAllToAllStartThunk(ThunkInfo thunk_info, NcclApi* nccl_api, const HloAllToAllInstruction* instr, - std::vector buffers); + std::vector buffers, bool p2p_memcpy_enabled); // Returns whether the given instruction can be lowered to a nccl all-to-all // call. @@ -47,6 +50,10 @@ class NcclAllToAllStartThunk : public NcclCollectiveThunk { int64_t replica_count, int64_t partition_count); + absl::Status Initialize(const InitializeParams& params) override; + + absl::Status Cleanup(const CleanupParams& params) override; + static const char* GetHloOpName() { return "all-to-all-start"; } static CollectiveOpGroupMode GetGroupMode( @@ -58,15 +65,32 @@ class NcclAllToAllStartThunk : public NcclCollectiveThunk { se::Stream& stream, NcclCommHandleWrapper comm_wrapper) override; + AsyncStreamKind GetAsyncStreamKind() const override; + + bool is_local() const; + private: const NcclAllToAllConfig config_; const std::vector buffers_; + int64_t device_count_ = 1; + bool p2p_memcpy_enabled_ = false; + absl::node_hash_map> + send_pointer_maps_; + absl::node_hash_map> + receive_pointer_maps_; }; absl::Status RunAllToAll(NcclApi* nccl_api, bool has_split_dimension, std::vector& buffers, se::Stream& stream, NcclApi::NcclCommHandle comm); +absl::Status RunMemCpyAllToAll( + NcclApi* nccl_api, bool has_split_dimension, + std::vector& buffers, se::Stream& stream, + NcclApi::NcclCommHandle comm, + absl::node_hash_map& send_pointer_map, + absl::node_hash_map& receive_pointer_map); + } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/runtime/nccl_api.cc b/third_party/xla/xla/service/gpu/runtime/nccl_api.cc index 8b6bd101752e66..25275524578a75 100644 --- a/third_party/xla/xla/service/gpu/runtime/nccl_api.cc +++ b/third_party/xla/xla/service/gpu/runtime/nccl_api.cc @@ -332,10 +332,14 @@ class DefaultNcclApi final : public NcclApi { absl::Status Send(se::DeviceMemoryBase send_buffer, PrimitiveType dtype, size_t count, int32_t peer, NcclCommHandle comm, se::Stream* stream) final; + absl::Status SendPtrToPeer(void* ptr, int32_t peer, NcclCommHandle comm, + se::Stream* stream) final; absl::Status Recv(se::DeviceMemoryBase recv_buffer, PrimitiveType dtype, size_t count, int32_t peer, NcclCommHandle comm, se::Stream* stream) final; + absl::Status RecvPtrFromPeer(void* ptr, int32_t peer, NcclCommHandle comm, + se::Stream* stream) final; absl::StatusOr RegisterBuffer( NcclCommHandle comm, se::DeviceMemoryBase buffer) final; @@ -613,6 +617,17 @@ absl::Status DefaultNcclApi::Send(se::DeviceMemoryBase send_buffer, peer, Cast(comm), se::gpu::AsGpuStreamValue(stream))); } +absl::Status DefaultNcclApi::SendPtrToPeer(void* ptr, int32_t peer, + NcclCommHandle comm, + se::Stream* stream) { + VLOG(3) << absl::StreamFormat( + "Launch NCCL RecvPtrFromPeer operation on device #%d; " + "peer=%d; comm=%p; stream=%p", + stream->parent()->device_ordinal(), peer, comm, stream); + return XLA_NCCL_STATUS(ncclSend(ptr, 1, ncclUint64, peer, Cast(comm), + se::gpu::AsGpuStreamValue(stream))); +} + absl::Status DefaultNcclApi::Recv(se::DeviceMemoryBase recv_buffer, PrimitiveType dtype, size_t count, int32_t peer, NcclCommHandle comm, @@ -631,6 +646,18 @@ absl::Status DefaultNcclApi::Recv(se::DeviceMemoryBase recv_buffer, peer, Cast(comm), se::gpu::AsGpuStreamValue(stream))); } +absl::Status DefaultNcclApi::RecvPtrFromPeer(void* ptr, int32_t peer, + NcclCommHandle comm, + se::Stream* stream) { + VLOG(3) << absl::StreamFormat( + "Launch NCCL RecvPtrFromPeer operation on device #%d; " + "peer=%d; comm=%p; stream=%p", + stream->parent()->device_ordinal(), peer, comm, stream); + + return XLA_NCCL_STATUS(ncclRecv(ptr, 1, ncclUint64, peer, Cast(comm), + se::gpu::AsGpuStreamValue(stream))); +} + absl::StatusOr DefaultNcclApi::RegisterBuffer(NcclCommHandle comm, se::DeviceMemoryBase buffer) { diff --git a/third_party/xla/xla/service/gpu/runtime/nccl_api.h b/third_party/xla/xla/service/gpu/runtime/nccl_api.h index 813a940052a36b..d44603f3c95838 100644 --- a/third_party/xla/xla/service/gpu/runtime/nccl_api.h +++ b/third_party/xla/xla/service/gpu/runtime/nccl_api.h @@ -251,6 +251,10 @@ class NcclApi { virtual absl::Status Send(se::DeviceMemoryBase send_buffer, PrimitiveType dtype, size_t count, int32_t peer, NcclCommHandle comm, se::Stream* stream) = 0; + // Send a pointer `ptr` to rank `peer`. + virtual absl::Status SendPtrToPeer(void* ptr, int32_t peer, + NcclCommHandle comm, + se::Stream* stream) = 0; // Receive data from rank `peer` into `recv_buff`. // @@ -258,6 +262,10 @@ class NcclApi { virtual absl::Status Recv(se::DeviceMemoryBase recv_buffer, PrimitiveType dtype, size_t count, int32_t peer, NcclCommHandle comm, se::Stream* stream) = 0; + // Receive a pointer from rank `peer` into `ptr`. + virtual absl::Status RecvPtrFromPeer(void* ptr, int32_t peer, + NcclCommHandle comm, + se::Stream* stream) = 0; // Register `buffer` with communicator `comm` for zero-copy communication. // Returned handle can be used for future unregistration. diff --git a/third_party/xla/xla/service/gpu/runtime/nccl_api_stub.cc b/third_party/xla/xla/service/gpu/runtime/nccl_api_stub.cc index c3934e02814d76..9cf030ad9fe5cb 100644 --- a/third_party/xla/xla/service/gpu/runtime/nccl_api_stub.cc +++ b/third_party/xla/xla/service/gpu/runtime/nccl_api_stub.cc @@ -149,11 +149,21 @@ class NcclApiStub final : public NcclApi { return UnimplementedError(); } + absl::Status SendPtrToPeer(void* ptr, int32_t peer, NcclCommHandle comm, + se::Stream* stream) final { + return UnimplementedError(); + } + absl::Status Recv(se::DeviceMemoryBase, PrimitiveType, size_t, int32_t, NcclCommHandle, se::Stream*) final { return UnimplementedError(); } + absl::Status RecvPtrFromPeer(void* ptr, int32_t peer, NcclCommHandle comm, + se::Stream* stream) final { + return UnimplementedError(); + } + absl::StatusOr RegisterBuffer( NcclCommHandle, se::DeviceMemoryBase) final { return UnimplementedError(); diff --git a/third_party/xla/xla/service/gpu/runtime/nccl_clique_key.h b/third_party/xla/xla/service/gpu/runtime/nccl_clique_key.h index 8bb3f2740320e0..22cd6af46359bb 100644 --- a/third_party/xla/xla/service/gpu/runtime/nccl_clique_key.h +++ b/third_party/xla/xla/service/gpu/runtime/nccl_clique_key.h @@ -56,10 +56,11 @@ enum class AsyncStreamKind : int64_t { kCollective = 0, // Stream for asynchronous collective ops. kP2P0 = 1, // One Stream for P2P Send and Recv ops. kP2P1 = 2, // Another Stream for P2P Send and Recv ops. + kMemCpyP2P = 3, // Stream for MemCpyP2P }; constexpr static int64_t kAsyncStreamTotal = - static_cast(AsyncStreamKind::kP2P1) + 1; + static_cast(AsyncStreamKind::kMemCpyP2P) + 1; // Assigns a unique ID to a stream for asynchronous or synchronous execution. // These IDs can be used, for example, to look up the NCCL communicator. diff --git a/third_party/xla/xla/service/gpu/runtime/nccl_collective_broadcast_thunk.cc b/third_party/xla/xla/service/gpu/runtime/nccl_collective_broadcast_thunk.cc index 370dd1189acc5c..dd9da283791be2 100644 --- a/third_party/xla/xla/service/gpu/runtime/nccl_collective_broadcast_thunk.cc +++ b/third_party/xla/xla/service/gpu/runtime/nccl_collective_broadcast_thunk.cc @@ -37,7 +37,8 @@ namespace xla::gpu { NcclCollectiveBroadcastStartThunk::NcclCollectiveBroadcastStartThunk( ThunkInfo thunk_info, NcclApi* nccl_api, - const HloCollectiveBroadcastInstruction* instr, std::vector buffers) + const HloCollectiveBroadcastInstruction* instr, std::vector buffers, + bool p2p_memcpy_enabled) : NcclCollectiveThunk(Thunk::kNcclCollectiveBroadcastStart, thunk_info, nccl_api, IsSyncCollective(instr)), config_(GetNcclCollectiveConfig(instr, std::nullopt)), diff --git a/third_party/xla/xla/service/gpu/runtime/nccl_collective_broadcast_thunk.h b/third_party/xla/xla/service/gpu/runtime/nccl_collective_broadcast_thunk.h index 4b19b785c025d7..14e32e1b4172cc 100644 --- a/third_party/xla/xla/service/gpu/runtime/nccl_collective_broadcast_thunk.h +++ b/third_party/xla/xla/service/gpu/runtime/nccl_collective_broadcast_thunk.h @@ -47,7 +47,7 @@ class NcclCollectiveBroadcastStartThunk : public NcclCollectiveThunk { NcclCollectiveBroadcastStartThunk( ThunkInfo thunk_info, NcclApi* nccl_api, const HloCollectiveBroadcastInstruction* instr, - std::vector buffers); + std::vector buffers, bool p2p_memcpy_enabled = false); protected: absl::Status RunNcclCollective(const ExecuteParams& params, diff --git a/third_party/xla/xla/service/gpu/runtime/thunk.h b/third_party/xla/xla/service/gpu/runtime/thunk.h index 4b6d345bbc2228..1c3e09f8fb2889 100644 --- a/third_party/xla/xla/service/gpu/runtime/thunk.h +++ b/third_party/xla/xla/service/gpu/runtime/thunk.h @@ -400,6 +400,23 @@ class Thunk { bool mock_collectives = false); }; + //===--------------------------------------------------------------------===// + // CleanupParams + //===--------------------------------------------------------------------===// + + // Parameters passed to Cleanup. Before returning from executable execution, + // thunks may need to clean up any resource allocated or registered through + // runtime APIs. + struct CleanupParams { + se::StreamExecutor* executor = nullptr; + + // Parameters for executing collective operations. + CollectiveExecuteParams* collective_params = nullptr; + + // Collective cliques acquired based on resource requests. + CollectiveCliques* collective_cliques = nullptr; + }; + //===--------------------------------------------------------------------===// // The hlo_instruction argument is meant to be the instruction this thunk was @@ -444,6 +461,14 @@ class Thunk { // Precondition: Initialize(initialize_params) has been called. virtual absl::Status ExecuteOnStream(const ExecuteParams& params) = 0; + // Cleans up any resources after thunk execution. + // + // This may be called multiple times. Its main purpose is to free up + // any resources occupied after initialization and execution. + virtual absl::Status Cleanup(const CleanupParams& params) { + return absl::OkStatus(); + } + static absl::string_view KindToString(Thunk::Kind kind); ExecutionStreamId execution_stream_id() const { return execution_stream_id_; } diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_driver.cc b/third_party/xla/xla/stream_executor/cuda/cuda_driver.cc index 1308272ff09f89..39f0d211e1611a 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_driver.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_driver.cc @@ -1264,6 +1264,30 @@ void GpuDriver::HostDeallocate(Context* context, void* location) { } } +bool GpuDriver::HostRegister(Context* context, void* location, uint64_t bytes) { + ScopedActivateContext activation(context); + // "Portable" memory is visible to all CUDA contexts. Safe for our use model. + auto status = cuda::ToStatus( + cuMemHostRegister(location, bytes, CU_MEMHOSTREGISTER_PORTABLE)); + if (!status.ok()) { + LOG(ERROR) << "error registering host memory at " << location << ": " + << status; + return false; + } + return true; +} + +bool GpuDriver::HostUnregister(Context* context, void* location) { + ScopedActivateContext activation(context); + auto status = cuda::ToStatus(cuMemHostUnregister(location)); + if (!status.ok()) { + LOG(ERROR) << "error unregistering host memory at " << location << ": " + << status; + return false; + } + return true; +} + int GpuDriver::GetGpuStreamPriority( Context* context, stream_executor::StreamPriority stream_priority) { ScopedActivateContext activation(context); diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc b/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc index 475899fa6b1bb2..98441a9bdc0549 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc @@ -473,6 +473,19 @@ bool CudaExecutor::SynchronizeAllActivity() { return GpuDriver::SynchronizeContext(gpu_context()).ok(); } +bool CudaExecutor::HostMemoryRegister(void* location, uint64_t size) { + VLOG(1) << "Called StreamExecutor::HostMemoryRegister(data=" << location + << ")"; + + return GpuDriver::HostRegister(gpu_context(), location, size); +} + +bool CudaExecutor::HostMemoryUnregister(void* location) { + VLOG(1) << "Called StreamExecutor::HostUnregister(data=" << location << ")"; + + return GpuDriver::HostUnregister(gpu_context(), location); +} + absl::Status CudaExecutor::SynchronousMemZero(DeviceMemoryBase* location, uint64_t size) { if (reinterpret_cast(location->opaque()) % 4 == 0 && diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_executor.h b/third_party/xla/xla/stream_executor/cuda/cuda_executor.h index e467c2a3d432be..18d3fff5c5d976 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_executor.h +++ b/third_party/xla/xla/stream_executor/cuda/cuda_executor.h @@ -137,6 +137,9 @@ class CudaExecutor : public GpuExecutor { return GpuDriver::HostDeallocate(gpu_context(), location); } + bool HostMemoryRegister(void* location, uint64_t size) override; + bool HostMemoryUnregister(void* location) override; + absl::StatusOr GetPointerMemorySpace(const void* ptr) override { return GpuDriver::GetPointerMemorySpace( reinterpret_cast(const_cast(ptr))); diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_driver.h b/third_party/xla/xla/stream_executor/gpu/gpu_driver.h index 447201e63796e1..2b299e544b307e 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_driver.h +++ b/third_party/xla/xla/stream_executor/gpu/gpu_driver.h @@ -131,6 +131,22 @@ class GpuDriver { // https://rocm.docs.amd.com/projects/HIPIFY/en/latest/tables/CUDA_Driver_API_functions_supported_by_HIP.html#memory-management static void HostDeallocate(Context* context, void* location); + // Registers a memory region at location of size bytes via + // cuMemHostRegister/hipHostRegister. + // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1gf0a9fe11544326dabd743b7aa6b54223 + // https://rocm.docs.amd.com/projects/HIPIFY/en/latest/tables/CUDA_Driver_API_functions_supported_by_HIP.html#memory-management + static bool HostRegister(Context* context, void* location, uint64_t bytes); + + // Unregisters a memory region that was previously registered at location via + // cuMemHostUnregister/hipHostUnregister. + // + // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1g63f450c8125359be87b7623b1c0b2a14 + // https://rocm.docs.amd.com/projects/HIPIFY/en/latest/tables/CUDA_Driver_API_functions_supported_by_HIP.html#memory-management + // + // TODO(leary) verify an error will be returned if the location wasn't + // previously registered. + static bool HostUnregister(Context* context, void* location); + // Queries the priority range and returns the corresponding integer value via // cuCtxGetStreamPriorityRange/hipDeviceGetStreamPriorityRange // diff --git a/third_party/xla/xla/stream_executor/stream_executor.h b/third_party/xla/xla/stream_executor/stream_executor.h index a0c3c48e521e30..fdc5c9dc7f303f 100644 --- a/third_party/xla/xla/stream_executor/stream_executor.h +++ b/third_party/xla/xla/stream_executor/stream_executor.h @@ -193,6 +193,11 @@ class StreamExecutor { virtual absl::Status SynchronousMemZero(DeviceMemoryBase* location, uint64_t size) = 0; + virtual bool HostMemoryUnregister(void* location) { return false; }; + virtual bool HostMemoryRegister(void* location, uint64_t size) { + return false; + }; + // Blocks the caller while "size" bytes are copied to the given location in // device memory. virtual absl::Status SynchronousMemcpy(DeviceMemoryBase* device_dst, diff --git a/third_party/xla/xla/tests/collective_ops_e2e_test.cc b/third_party/xla/xla/tests/collective_ops_e2e_test.cc index 4fcfdfd3391745..cdcb82919b67f8 100644 --- a/third_party/xla/xla/tests/collective_ops_e2e_test.cc +++ b/third_party/xla/xla/tests/collective_ops_e2e_test.cc @@ -449,6 +449,46 @@ XLA_TEST_P(AsyncCollectiveOps, AsyncAllToAllWithSplitDim) { LiteralTestUtil::ExpectR1Equal({15, 16}, results[1]); } +TEST_F(CollectiveOpsTestE2E, AsyncAllToAllMemCpy) { + const absl::string_view kModuleStr = R"( + HloModule test + ENTRY test_computation { + id = u32[] replica-id() + id2 = u32[2, 2] broadcast(id), dimensions={} + a0 = u32[2, 2] constant({{10, 15}, {20, 25}}) + a1 = u32[2, 2] add(id2, a0) + all2all = u32[2, 2] all-to-all(a1), dimensions={0} + ROOT out = u32[4] reshape(all2all) + } + )"; + const int64_t kNumReplicas = 2; + + DebugOptions debug_options = GetDebugOptionsForTest(); + debug_options.set_xla_gpu_use_memcpy_local_p2p(true); + HloModuleConfig config = + GetModuleConfigForTest(/*replica_count=*/kNumReplicas); + config.set_debug_options(debug_options); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr, config)); + + TF_ASSERT_OK_AND_ASSIGN( + auto executable, + CreateExecutable(std::move(module), /*run_hlo_passes=*/true)); + ASSERT_TRUE(executable->has_module()); + HloModule* executable_module = &executable->module(); + + // Verify that the all-to-all is not decomposed into a tuple all-to-all. + const HloInstruction* all_to_all = + FindInstruction(executable_module, HloOpcode::kAllToAll); + EXPECT_THAT(all_to_all, op::Shape("u32[2, 2]")); + + TF_ASSERT_OK_AND_ASSIGN(std::vector results, + ExecuteReplicated(executable.get(), kNumReplicas)); + ASSERT_EQ(results.size(), kNumReplicas); + LiteralTestUtil::ExpectR1Equal({10, 15, 11, 16}, results[0]); + LiteralTestUtil::ExpectR1Equal({20, 25, 21, 26}, results[1]); +} + XLA_TEST_P(AsyncCollectiveOps, AsyncAllToAllWithoutSplitDim) { const absl::string_view kModuleStr = R"( HloModule test From 5747695c1bfa941d86dccfe356e399f9c9719b43 Mon Sep 17 00:00:00 2001 From: Bart Chrzaszcz Date: Wed, 25 Sep 2024 06:31:14 -0700 Subject: [PATCH 256/483] #sdy rename custom calls during sdy round tripping of ManualComputationOp. PiperOrigin-RevId: 678676520 --- .../xla/xla/service/spmd/shardy/constants.h | 10 ++++++ .../shardy/sdy_round_trip/shard_map_export.cc | 4 +-- .../shardy/sdy_round_trip/shard_map_import.cc | 4 +-- .../test/sdy_round_trip_shard_map_export.mlir | 32 +++++++++---------- .../test/sdy_round_trip_shard_map_import.mlir | 32 +++++++++---------- ...y_round_trip_shard_map_import_failure.mlir | 8 ++--- 6 files changed, 50 insertions(+), 40 deletions(-) diff --git a/third_party/xla/xla/service/spmd/shardy/constants.h b/third_party/xla/xla/service/spmd/shardy/constants.h index f9bbb8a0da789e..0923e1e22e5429 100644 --- a/third_party/xla/xla/service/spmd/shardy/constants.h +++ b/third_party/xla/xla/service/spmd/shardy/constants.h @@ -95,6 +95,16 @@ inline constexpr llvm::StringRef kManualAxes = "xla.sdy.manual_axes"; inline constexpr llvm::StringRef kManualComputationBodyFuncName = "xla.sdy.manual_computation_body"; +// The target name of the custom call that changes operands from global to local +// shape during Shardy round tripping. +inline constexpr llvm::StringRef kGlobalToLocalShapeCallTargetName = + "xla.sdy.GlobalToLocalShape"; + +// The target name of the custom call that changes results from local to global +// shape during Shardy round tripping. +inline constexpr llvm::StringRef kLocalToGlobalShapeCallTargetName = + "xla.sdy.LocalToGlobalShape"; + // The name of the global mesh. inline constexpr llvm::StringRef kGlobalMeshName = "mesh"; diff --git a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/shard_map_export.cc b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/shard_map_export.cc index 35c88549e867ea..16d9397ed16ee7 100644 --- a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/shard_map_export.cc +++ b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/shard_map_export.cc @@ -92,7 +92,7 @@ class SdyRoundTripShardMapExportPass if (!operands.empty()) { fullToShard = rewriter.create( loc, manualCompBodyArgTypes, operands); - fullToShard.setCallTargetName(kSPMDFullToShardShapeCallTargetName); + fullToShard.setCallTargetName(kGlobalToLocalShapeCallTargetName); operands = fullToShard->getResults(); } @@ -109,7 +109,7 @@ class SdyRoundTripShardMapExportPass if (!results.empty()) { auto shardToFull = rewriter.create( loc, manualComputation.getResultTypes(), callOp->getResults()); - shardToFull.setCallTargetName(kSPMDShardToFullShapeCallTargetName); + shardToFull.setCallTargetName(kLocalToGlobalShapeCallTargetName); results = shardToFull->getResults(); } sdy::inlineRegionAndConvertTerminatorOp( diff --git a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/shard_map_import.cc b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/shard_map_import.cc index dec5bd2cbc84e1..d577a0820d0a94 100644 --- a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/shard_map_import.cc +++ b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/shard_map_import.cc @@ -94,7 +94,7 @@ class ManualComputationPattern : public OpConversionPattern { operands = fullToShard->getOperands(); CHECK(fullToShard); CHECK(fullToShard.getCallTargetName() == - kSPMDFullToShardShapeCallTargetName); + kGlobalToLocalShapeCallTargetName); } mlir::TypeRange resultTypes = callOp->getResultTypes(); stablehlo::CustomCallOp shardToFull; @@ -104,7 +104,7 @@ class ManualComputationPattern : public OpConversionPattern { shardToFull = mlir::cast( *callOp->getResult(0).getUsers().begin()); CHECK(shardToFull.getCallTargetName() == - kSPMDShardToFullShapeCallTargetName); + kLocalToGlobalShapeCallTargetName); resultTypes = shardToFull->getResultTypes(); } diff --git a/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_shard_map_export.mlir b/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_shard_map_export.mlir index aefca62fc88b6d..ffb749919eca2d 100644 --- a/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_shard_map_export.mlir +++ b/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_shard_map_export.mlir @@ -5,14 +5,14 @@ sdy.mesh @mesh_1 = <["a"=2, "b"=2, "c"=2, "d"=2]> // CHECK-LABEL: func @single_manual_comp func.func @single_manual_comp(%arg0: tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_0, [{"a", ?}, {"b", ?}]>}, %arg1: tensor<16x32xf32> {sdy.sharding = #sdy.sharding<@mesh_0, [{"b", ?}, {?}]>}) -> (tensor<8x32xf32> {sdy.sharding = #sdy.sharding<@mesh_0, [{"a"}, {}]>}) { - // CHECK-NEXT: %[[FULL_TO_SHARD:.*]]:2 = stablehlo.custom_call @SPMDFullToShardShape(%arg0, %arg1) : (tensor<8x16xf32>, tensor<16x32xf32>) -> (tensor<2x8xf32>, tensor<8x32xf32>) + // CHECK-NEXT: %[[FULL_TO_SHARD:.*]]:2 = stablehlo.custom_call @local_xla.sdy.GlobalToLocalShape(%arg0, %arg1) : (tensor<8x16xf32>, tensor<16x32xf32>) -> (tensor<2x8xf32>, tensor<8x32xf32>) // CHECK-NEXT: %[[SHMAP:.*]] = call @local_xla.sdy.manual_computation_body(%[[FULL_TO_SHARD]]#0, %[[FULL_TO_SHARD]]#1) // CHECK-SAME: {mhlo.frontend_attributes = { // CHECK-SAME: xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{\\\22a\\\22}, {\\\22b\\\22}]>, <@mesh_0, [{\\\22b\\\22}, {}], replicated={\\\22a\\\22}>]>", // CHECK-SAME: xla.sdy.manual_axes = "#sdy", // CHECK-SAME: xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{\\\22a\\\22}, {}], replicated={\\\22b\\\22}>]>"}} // CHECK-SAME: : (tensor<2x8xf32>, tensor<8x32xf32>) -> tensor<2x32xf32> - // CHECK-NEXT: %[[SHARD_TO_FULL:.*]] = stablehlo.custom_call @SPMDShardToFullShape(%[[SHMAP]]) : (tensor<2x32xf32>) -> tensor<8x32xf32> + // CHECK-NEXT: %[[SHARD_TO_FULL:.*]] = stablehlo.custom_call @local_xla.sdy.LocalToGlobalShape(%[[SHMAP]]) : (tensor<2x32xf32>) -> tensor<8x32xf32> // CHECK-NEXT: return %[[SHARD_TO_FULL]] : tensor<8x32xf32> %0 = sdy.manual_computation(%arg0, %arg1) in_shardings=[<@mesh_0, [{"a"}, {"b"}]>, <@mesh_0, [{"b"}, {}], replicated={"a"}>] out_shardings=[<@mesh_0, [{"a"}, {}], replicated={"b"}>] manual_axes={"a", "b"} (%arg2: tensor<2x8xf32>, %arg3: tensor<8x32xf32>) { %1 = stablehlo.add %arg2, %arg2 : tensor<2x8xf32> @@ -32,22 +32,22 @@ func.func @single_manual_comp(%arg0: tensor<8x16xf32> {sdy.sharding = #sdy.shard // CHECK-LABEL: func @manual_comp_using_another func.func @manual_comp_using_another(%arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh_0, [{"a"}, {}]>}) -> (tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh_0, [{}, {"b"}]>}) { - // CHECK-NEXT: %[[FULL_TO_SHARD_0:.*]] = stablehlo.custom_call @SPMDFullToShardShape(%arg0) : (tensor<8x8xf32>) -> tensor<2x8xf32> + // CHECK-NEXT: %[[FULL_TO_SHARD_0:.*]] = stablehlo.custom_call @local_xla.sdy.GlobalToLocalShape(%arg0) : (tensor<8x8xf32>) -> tensor<2x8xf32> // CHECK-NEXT: %[[SHMAP_0:.*]] = call @local_xla.sdy.manual_computation_body_0(%[[FULL_TO_SHARD_0]]) // CHECK-SAME: {mhlo.frontend_attributes = { // CHECK-SAME: xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{\\\22a\\\22}, {}]>]>", // CHECK-SAME: xla.sdy.manual_axes = "#sdy", // CHECK-SAME: xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{\\\22a\\\22}, {}]>]>"}} // CHECK-SAME: : (tensor<2x8xf32>) -> tensor<2x8xf32> - // CHECK-NEXT: %[[SHARD_TO_FULL_0:.*]] = stablehlo.custom_call @SPMDShardToFullShape(%[[SHMAP_0]]) : (tensor<2x8xf32>) -> tensor<8x8xf32> - // CHECK-NEXT: %[[FULL_TO_SHARD_1:.*]] = stablehlo.custom_call @SPMDFullToShardShape(%[[SHARD_TO_FULL_0]]) : (tensor<8x8xf32>) -> tensor<8x4xf32> + // CHECK-NEXT: %[[SHARD_TO_FULL_0:.*]] = stablehlo.custom_call @local_xla.sdy.LocalToGlobalShape(%[[SHMAP_0]]) : (tensor<2x8xf32>) -> tensor<8x8xf32> + // CHECK-NEXT: %[[FULL_TO_SHARD_1:.*]] = stablehlo.custom_call @local_xla.sdy.GlobalToLocalShape(%[[SHARD_TO_FULL_0]]) : (tensor<8x8xf32>) -> tensor<8x4xf32> // CHECK-NEXT: %[[SHMAP_1:.*]] = call @local_xla.sdy.manual_computation_body_1(%[[FULL_TO_SHARD_1]]) // CHECK-SAME: {mhlo.frontend_attributes = { // CHECK-SAME: xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{}, {\\\22b\\\22}]>]>", // CHECK-SAME: xla.sdy.manual_axes = "#sdy", // CHECK-SAME: xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{}, {\\\22b\\\22}]>]>"}} // CHECK-SAME: : (tensor<8x4xf32>) -> tensor<8x4xf32 - // CHECK-NEXT: %[[SHARD_TO_FULL_1:.*]] = stablehlo.custom_call @SPMDShardToFullShape(%[[SHMAP_1]]) : (tensor<8x4xf32>) -> tensor<8x8xf32> + // CHECK-NEXT: %[[SHARD_TO_FULL_1:.*]] = stablehlo.custom_call @local_xla.sdy.LocalToGlobalShape(%[[SHMAP_1]]) : (tensor<8x4xf32>) -> tensor<8x8xf32> // CHECK-NEXT: return %[[SHARD_TO_FULL_1]] : tensor<8x8xf32> %0 = sdy.manual_computation(%arg0) in_shardings=[<@mesh_0, [{"a"}, {}]>] out_shardings=[<@mesh_0, [{"a"}, {}]>] manual_axes={"a"} (%arg1: tensor<2x8xf32>) { sdy.return %arg1 : tensor<2x8xf32> @@ -61,14 +61,14 @@ func.func @manual_comp_using_another(%arg0: tensor<8x8xf32> {sdy.sharding = #sdy // CHECK-LABEL: func @nested_shmaps func.func @nested_shmaps(%arg0: tensor<4x8xf32> {sdy.sharding = #sdy.sharding<@mesh_1, [{"a"}, {"b"}]>}) -> (tensor<4x8xf32> {sdy.sharding = #sdy.sharding<@mesh_1, [{"a", ?}, {?}]>}) { - // CHECK-NEXT: %[[FULL_TO_SHARD:.*]] = stablehlo.custom_call @SPMDFullToShardShape(%arg0) : (tensor<4x8xf32>) -> tensor<2x8xf32> + // CHECK-NEXT: %[[FULL_TO_SHARD:.*]] = stablehlo.custom_call @local_xla.sdy.GlobalToLocalShape(%arg0) : (tensor<4x8xf32>) -> tensor<2x8xf32> // CHECK-NEXT: %[[SHMAP:.*]] = call @local_xla.sdy.manual_computation_body_3(%[[FULL_TO_SHARD]]) // CHECK-SAME: {mhlo.frontend_attributes = { // CHECK-SAME: xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{\\\22a\\\22}, {}]>]>", // CHECK-SAME: xla.sdy.manual_axes = "#sdy", // CHECK-SAME: xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{\\\22a\\\22}, {}]>]>"}} // CHECK-SAME: : (tensor<2x8xf32>) -> tensor<2x8xf32> - // CHECK-NEXT: %[[SHARD_TO_FULL:.*]] = stablehlo.custom_call @SPMDShardToFullShape(%[[SHMAP]]) : (tensor<2x8xf32>) -> tensor<4x8xf32> + // CHECK-NEXT: %[[SHARD_TO_FULL:.*]] = stablehlo.custom_call @local_xla.sdy.LocalToGlobalShape(%[[SHMAP]]) : (tensor<2x8xf32>) -> tensor<4x8xf32> // CHECK-NEXT: return %[[SHARD_TO_FULL]] : tensor<4x8xf32> %0 = sdy.manual_computation(%arg0) in_shardings=[<@mesh_1, [{"a"}, {}]>] out_shardings=[<@mesh_1, [{"a"}, {}]>] manual_axes={"a"} (%arg1: tensor<2x8xf32>) { %1 = sdy.manual_computation(%arg1) in_shardings=[<@mesh_1, [{}, {"b"}]>] out_shardings=[<@mesh_1, [{}, {"b"}]>] manual_axes={"b"} (%arg2: tensor<2x4xf32>) { @@ -82,14 +82,14 @@ func.func @nested_shmaps(%arg0: tensor<4x8xf32> {sdy.sharding = #sdy.sharding<@m // CHECK-LABEL: func @nested_shmaps_extra_op func.func @nested_shmaps_extra_op(%arg0: tensor<4x8xf32> {sdy.sharding = #sdy.sharding<@mesh_1, [{"a"}, {"b"}]>}) -> (tensor<4x8xf32> {sdy.sharding = #sdy.sharding<@mesh_1, [{"a", ?}, {?}]>}) { - // CHECK-NEXT: %[[FULL_TO_SHARD:.*]] = stablehlo.custom_call @SPMDFullToShardShape(%arg0) : (tensor<4x8xf32>) -> tensor<2x8xf32> + // CHECK-NEXT: %[[FULL_TO_SHARD:.*]] = stablehlo.custom_call @local_xla.sdy.GlobalToLocalShape(%arg0) : (tensor<4x8xf32>) -> tensor<2x8xf32> // CHECK-NEXT: %[[SHMAP:.*]] = call @local_xla.sdy.manual_computation_body_5(%[[FULL_TO_SHARD]]) // CHECK-SAME: {mhlo.frontend_attributes = { // CHECK-SAME: xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{\\\22a\\\22}, {}]>]>", // CHECK-SAME: xla.sdy.manual_axes = "#sdy", // CHECK-SAME: xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{\\\22a\\\22}, {}]>]>"}} // CHECK-SAME: (tensor<2x8xf32>) -> tensor<2x8xf32> - // CHECK-NEXT: %[[SHARD_TO_FULL:.*]] = stablehlo.custom_call @SPMDShardToFullShape(%[[SHMAP]]) : (tensor<2x8xf32>) -> tensor<4x8xf32> + // CHECK-NEXT: %[[SHARD_TO_FULL:.*]] = stablehlo.custom_call @local_xla.sdy.LocalToGlobalShape(%[[SHMAP]]) : (tensor<2x8xf32>) -> tensor<4x8xf32> // CHECK-NEXT: return %[[SHARD_TO_FULL]] : tensor<4x8xf32> %0 = sdy.manual_computation(%arg0) in_shardings=[<@mesh_1, [{"a"}, {}]>] out_shardings=[<@mesh_1, [{"a"}, {}]>] manual_axes={"a"} (%arg1: tensor<2x8xf32>) { %1 = sdy.manual_computation(%arg1) in_shardings=[<@mesh_1, [{}, {"b"}]>] out_shardings=[<@mesh_1, [{}, {"b"}]>] manual_axes={"b"} (%arg2: tensor<2x4xf32>) { @@ -110,7 +110,7 @@ func.func @manual_computation_no_inputs() -> tensor<4xi64> { // CHECK-SAME: xla.sdy.manual_axes = "#sdy", // CHECK-SAME: xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{\\\22b\\\22}]>]>"}} // CHECK-SAME: () -> tensor<2xi64> - // CHECK-NEXT: %[[SHARD_TO_FULL:.*]] = stablehlo.custom_call @SPMDShardToFullShape(%[[SHMAP]]) : (tensor<2xi64>) -> tensor<4xi64> + // CHECK-NEXT: %[[SHARD_TO_FULL:.*]] = stablehlo.custom_call @local_xla.sdy.LocalToGlobalShape(%[[SHMAP]]) : (tensor<2xi64>) -> tensor<4xi64> // CHECK-NEXT: return %[[SHARD_TO_FULL]] : tensor<4xi64> %0 = sdy.manual_computation() in_shardings=[] out_shardings=[<@mesh_0, [{"b"}]>] manual_axes={"b"} () { %1 = stablehlo.constant dense<[2, 3]> : tensor<2xi64> @@ -121,7 +121,7 @@ func.func @manual_computation_no_inputs() -> tensor<4xi64> { // CHECK-LABEL: func @manual_computation_no_outputs func.func @manual_computation_no_outputs(%arg0: tensor<4xi64>) { - // CHECK-NEXT: %[[FULL_TO_SHARD:.*]] = stablehlo.custom_call @SPMDFullToShardShape(%arg0) : (tensor<4xi64>) -> tensor<2xi64> + // CHECK-NEXT: %[[FULL_TO_SHARD:.*]] = stablehlo.custom_call @local_xla.sdy.GlobalToLocalShape(%arg0) : (tensor<4xi64>) -> tensor<2xi64> // CHECK-NEXT: call @local_xla.sdy.manual_computation_body_7(%[[FULL_TO_SHARD]]) // CHECK-SAME: {mhlo.frontend_attributes = { // CHECK-SAME: xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{\\\22b\\\22}]>]>", @@ -151,27 +151,27 @@ func.func @manual_computation_no_outputs(%arg0: tensor<4xi64>) { // CHECK-NEXT: stablehlo.multiply %arg0, %arg0 : tensor<2x4xf32> // CHECK-LABEL: func @local_xla.sdy.manual_computation_body_3(%arg0: tensor<2x8xf32>) -> tensor<2x8xf32 -// CHECK-NEXT: %[[FULL_TO_SHARD:.*]] = stablehlo.custom_call @SPMDFullToShardShape(%arg0) : (tensor<2x8xf32>) -> tensor<2x4xf32> +// CHECK-NEXT: %[[FULL_TO_SHARD:.*]] = stablehlo.custom_call @local_xla.sdy.GlobalToLocalShape(%arg0) : (tensor<2x8xf32>) -> tensor<2x4xf32> // CHECK-NEXT: %[[SHMAP:.*]] = call @local_xla.sdy.manual_computation_body_2(%[[FULL_TO_SHARD]]) // CHECK-SAME: {mhlo.frontend_attributes = { // CHECK-SAME: xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{}, {\\\22b\\\22}]>]>", // CHECK-SAME: xla.sdy.manual_axes = "#sdy", // CHECK-SAME: xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{}, {\\\22b\\\22}]>]>"}} // CHECK-SAME: : (tensor<2x4xf32>) -> tensor<2x4xf32> -// CHECK-NEXT: stablehlo.custom_call @SPMDShardToFullShape(%[[SHMAP]]) : (tensor<2x4xf32>) -> tensor<2x8xf32> +// CHECK-NEXT: stablehlo.custom_call @local_xla.sdy.LocalToGlobalShape(%[[SHMAP]]) : (tensor<2x4xf32>) -> tensor<2x8xf32> // CHECK-LABEL: func @local_xla.sdy.manual_computation_body_4(%arg0: tensor<2x4xf32>) -> tensor<2x4xf32> // CHECK-NEXT: stablehlo.multiply %arg0, %arg0 : tensor<2x4xf32> // CHECK-LABEL: func @local_xla.sdy.manual_computation_body_5(%arg0: tensor<2x8xf32>) -> tensor<2x8xf32> -// CHECK-NEXT: %[[FULL_TO_SHARD:.*]] = stablehlo.custom_call @SPMDFullToShardShape(%arg0) : (tensor<2x8xf32>) -> tensor<2x4xf32 +// CHECK-NEXT: %[[FULL_TO_SHARD:.*]] = stablehlo.custom_call @local_xla.sdy.GlobalToLocalShape(%arg0) : (tensor<2x8xf32>) -> tensor<2x4xf32 // CHECK-NEXT: %[[SHMAP:.*]] = call @local_xla.sdy.manual_computation_body_4(%[[FULL_TO_SHARD]]) // CHECK-SAME: {mhlo.frontend_attributes = { // CHECK-SAME: xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{}, {\\\22b\\\22}]>]>", // CHECK-SAME: xla.sdy.manual_axes = "#sdy", // CHECK-SAME: xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{}, {\\\22b\\\22}]>]>"}} // CHECK-SAME: : (tensor<2x4xf32>) -> tensor<2x4xf32> -// CHECK-NEXT: %[[SHARD_TO_FULL:.*]] = stablehlo.custom_call @SPMDShardToFullShape(%[[SHMAP]]) : (tensor<2x4xf32>) -> tensor<2x8xf32> +// CHECK-NEXT: %[[SHARD_TO_FULL:.*]] = stablehlo.custom_call @local_xla.sdy.LocalToGlobalShape(%[[SHMAP]]) : (tensor<2x4xf32>) -> tensor<2x8xf32> // CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %[[SHARD_TO_FULL]], %[[SHARD_TO_FULL]] : tensor<2x8xf32> // CHECK-NEXT: return %[[ADD]] : tensor<2x8xf32> diff --git a/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_shard_map_import.mlir b/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_shard_map_import.mlir index 4211756a8dcf02..c0ecce7e8b67de 100644 --- a/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_shard_map_import.mlir +++ b/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_shard_map_import.mlir @@ -20,9 +20,9 @@ func.func @single_manual_comp(%arg0: tensor<8x16xf32>, %arg1: tensor<16x32xf32>) // CHECK-NEXT: sdy.return %[[REDUCE]] : tensor<2x32xf32> // CHECK-NEXT: } : (tensor<8x16xf32>, tensor<16x32xf32>) -> tensor<8x32xf32> // CHECK-NEXT: return %[[MAN_COMP]] : tensor<8x32xf32> - %0:2 = stablehlo.custom_call @SPMDFullToShardShape(%arg0, %arg1) : (tensor<8x16xf32>, tensor<16x32xf32>) -> (tensor<2x8xf32>, tensor<8x32xf32>) + %0:2 = stablehlo.custom_call @local_xla.sdy.GlobalToLocalShape(%arg0, %arg1) : (tensor<8x16xf32>, tensor<16x32xf32>) -> (tensor<2x8xf32>, tensor<8x32xf32>) %1 = call @local_xla.sdy.manual_computation_body(%0#0, %0#1) {mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{\\\22a\\\22}, {\\\22b\\\22}]>, <@mesh_0, [{\\\22b\\\22}, {}], replicated={\\\22a\\\22}>]>", xla.sdy.manual_axes = "#sdy", xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{\\\22a\\\22}, {}], replicated={\\\22b\\\22}>]>"}} : (tensor<2x8xf32>, tensor<8x32xf32>) -> tensor<2x32xf32> - %2 = stablehlo.custom_call @SPMDShardToFullShape(%1) : (tensor<2x32xf32>) -> tensor<8x32xf32> + %2 = stablehlo.custom_call @local_xla.sdy.LocalToGlobalShape(%1) : (tensor<2x32xf32>) -> tensor<8x32xf32> return %2 : tensor<8x32xf32> } @@ -44,20 +44,20 @@ func.func @manual_comp_using_another(%arg0: tensor<8x8xf32>) -> tensor<8x8xf32> // CHECK-NEXT: sdy.return %arg1 : tensor<8x4xf32> // CHECK-NEXT: } : (tensor<8x8xf32>) -> tensor<8x8xf32> // CHECK-NEXT: return %[[MAN_COMP_1]] : tensor<8x8xf32> - %0 = stablehlo.custom_call @SPMDFullToShardShape(%arg0) : (tensor<8x8xf32>) -> tensor<2x8xf32> + %0 = stablehlo.custom_call @local_xla.sdy.GlobalToLocalShape(%arg0) : (tensor<8x8xf32>) -> tensor<2x8xf32> %1 = call @local_xla.sdy.manual_computation_body_0(%0) {mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{\\\22a\\\22}, {}]>]>", xla.sdy.manual_axes = "#sdy", xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{\\\22a\\\22}, {}]>]>"}} : (tensor<2x8xf32>) -> tensor<2x8xf32> - %2 = stablehlo.custom_call @SPMDShardToFullShape(%1) : (tensor<2x8xf32>) -> tensor<8x8xf32> - %3 = stablehlo.custom_call @SPMDFullToShardShape(%2) : (tensor<8x8xf32>) -> tensor<8x4xf32> + %2 = stablehlo.custom_call @local_xla.sdy.LocalToGlobalShape(%1) : (tensor<2x8xf32>) -> tensor<8x8xf32> + %3 = stablehlo.custom_call @local_xla.sdy.GlobalToLocalShape(%2) : (tensor<8x8xf32>) -> tensor<8x4xf32> %4 = call @local_xla.sdy.manual_computation_body_1(%3) {mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{}, {\\\22b\\\22}]>]>", xla.sdy.manual_axes = "#sdy", xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{}, {\\\22b\\\22}]>]>"}} : (tensor<8x4xf32>) -> tensor<8x4xf32> - %5 = stablehlo.custom_call @SPMDShardToFullShape(%4) : (tensor<8x4xf32>) -> tensor<8x8xf32> + %5 = stablehlo.custom_call @local_xla.sdy.LocalToGlobalShape(%4) : (tensor<8x4xf32>) -> tensor<8x8xf32> return %5 : tensor<8x8xf32> } // CHECK-NOT: func @local_xla.sdy.manual_computation_body_3( func.func @local_xla.sdy.manual_computation_body_3(%arg0: tensor<2x8xf32>) -> tensor<2x8xf32> { - %0 = stablehlo.custom_call @SPMDFullToShardShape(%arg0) : (tensor<2x8xf32>) -> tensor<2x4xf32> + %0 = stablehlo.custom_call @local_xla.sdy.GlobalToLocalShape(%arg0) : (tensor<2x8xf32>) -> tensor<2x4xf32> %1 = call @local_xla.sdy.manual_computation_body_2(%0) {mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{}, {\\\22b\\\22}]>]>", xla.sdy.manual_axes = "#sdy", xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{}, {\\\22b\\\22}]>]>"}} : (tensor<2x4xf32>) -> tensor<2x4xf32> - %2 = stablehlo.custom_call @SPMDShardToFullShape(%1) : (tensor<2x4xf32>) -> tensor<2x8xf32> + %2 = stablehlo.custom_call @local_xla.sdy.LocalToGlobalShape(%1) : (tensor<2x4xf32>) -> tensor<2x8xf32> return %2 : tensor<2x8xf32> } @@ -85,9 +85,9 @@ func.func @nested_shmaps(%arg0: tensor<4x8xf32>) -> tensor<4x8xf32> { // CHECK-NEXT: sdy.return %[[MAN_COMP_1]] : tensor<2x8xf32> // CHECK-NEXT: } : (tensor<4x8xf32>) -> tensor<4x8xf32> // CHECK-NEXT: return %[[MAN_COMP_0]] : tensor<4x8xf32> - %0 = stablehlo.custom_call @SPMDFullToShardShape(%arg0) : (tensor<4x8xf32>) -> tensor<2x8xf32> + %0 = stablehlo.custom_call @local_xla.sdy.GlobalToLocalShape(%arg0) : (tensor<4x8xf32>) -> tensor<2x8xf32> %1 = call @local_xla.sdy.manual_computation_body_3(%0) {mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{\\\22a\\\22}, {}]>]>", xla.sdy.manual_axes = "#sdy", xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{\\\22a\\\22}, {}]>]>"}} : (tensor<2x8xf32>) -> tensor<2x8xf32> - %2 = stablehlo.custom_call @SPMDShardToFullShape(%1) : (tensor<2x8xf32>) -> tensor<4x8xf32> + %2 = stablehlo.custom_call @local_xla.sdy.LocalToGlobalShape(%1) : (tensor<2x8xf32>) -> tensor<4x8xf32> return %2 : tensor<4x8xf32> } @@ -110,9 +110,9 @@ func.func @nested_shmaps_extra_op(%arg0: tensor<4x8xf32>) -> tensor<4x8xf32> { // CHECK-NEXT: sdy.return %[[ADD]] : tensor<2x8xf32> // CHECK-NEXT: } : (tensor<4x8xf32>) -> tensor<4x8xf32> // CHECK-NEXT: return %[[MAN_COMP_0]] : tensor<4x8xf32> - %0 = stablehlo.custom_call @SPMDFullToShardShape(%arg0) : (tensor<4x8xf32>) -> tensor<2x8xf32> + %0 = stablehlo.custom_call @local_xla.sdy.GlobalToLocalShape(%arg0) : (tensor<4x8xf32>) -> tensor<2x8xf32> %1 = call @local_xla.sdy.manual_computation_body_5(%0) {mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{\\\22a\\\22}, {}]>]>", xla.sdy.manual_axes = "#sdy", xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{\\\22a\\\22}, {}]>]>"}} : (tensor<2x8xf32>) -> tensor<2x8xf32> - %2 = stablehlo.custom_call @SPMDShardToFullShape(%1) : (tensor<2x8xf32>) -> tensor<4x8xf32> + %2 = stablehlo.custom_call @local_xla.sdy.LocalToGlobalShape(%1) : (tensor<2x8xf32>) -> tensor<4x8xf32> return %2 : tensor<4x8xf32> } @@ -128,7 +128,7 @@ func.func @manual_computation_no_inputs() -> tensor<4xi64> { // CHECK-NEXT: } : () -> tensor<4xi64> // CHECK-NEXT: return %[[SHMAP]] : tensor<4xi64> %0 = call @local_xla.sdy.manual_computation_body_6() {mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[]>", xla.sdy.manual_axes = "#sdy", xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{\\\22b\\\22}]>]>"}} : () -> tensor<2xi64> - %1 = stablehlo.custom_call @SPMDShardToFullShape(%0) : (tensor<2xi64>) -> tensor<4xi64> + %1 = stablehlo.custom_call @local_xla.sdy.LocalToGlobalShape(%0) : (tensor<2xi64>) -> tensor<4xi64> return %1 : tensor<4xi64> } @@ -143,7 +143,7 @@ func.func @manual_computation_no_outputs(%arg0: tensor<4xi64>) { // CHECK-NEXT: sdy.return // CHECK-NEXT: } : (tensor<4xi64>) -> () // CHECK-NEXT: return - %0 = stablehlo.custom_call @SPMDFullToShardShape(%arg0) : (tensor<4xi64>) -> tensor<2xi64> + %0 = stablehlo.custom_call @local_xla.sdy.GlobalToLocalShape(%arg0) : (tensor<4xi64>) -> tensor<2xi64> call @local_xla.sdy.manual_computation_body_7(%0) {mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{\\\22b\\\22}]>]>", xla.sdy.manual_axes = "#sdy", xla.sdy.out_shardings = "#sdy.sharding_per_value<[]>"}} : (tensor<2xi64>) -> () return } @@ -178,9 +178,9 @@ func.func @local_xla.sdy.manual_computation_body_4(%arg0: tensor<2x4xf32>) -> te // CHECK-NOT: func @local_xla.sdy.manual_computation_body_5( func.func @local_xla.sdy.manual_computation_body_5(%arg0: tensor<2x8xf32>) -> tensor<2x8xf32> { - %0 = stablehlo.custom_call @SPMDFullToShardShape(%arg0) : (tensor<2x8xf32>) -> tensor<2x4xf32> + %0 = stablehlo.custom_call @local_xla.sdy.GlobalToLocalShape(%arg0) : (tensor<2x8xf32>) -> tensor<2x4xf32> %1 = call @local_xla.sdy.manual_computation_body_4(%0) {mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{}, {\\\22b\\\22}]>]>", xla.sdy.manual_axes = "#sdy", xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{}, {\\\22b\\\22}]>]>"}} : (tensor<2x4xf32>) -> tensor<2x4xf32> - %2 = stablehlo.custom_call @SPMDShardToFullShape(%1) : (tensor<2x4xf32>) -> tensor<2x8xf32> + %2 = stablehlo.custom_call @local_xla.sdy.LocalToGlobalShape(%1) : (tensor<2x4xf32>) -> tensor<2x8xf32> %3 = stablehlo.add %2, %2 : tensor<2x8xf32> return %3 : tensor<2x8xf32> } diff --git a/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_shard_map_import_failure.mlir b/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_shard_map_import_failure.mlir index 7effeae63ae5be..ba5f28da7a7484 100644 --- a/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_shard_map_import_failure.mlir +++ b/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_shard_map_import_failure.mlir @@ -3,14 +3,14 @@ sdy.mesh @mesh = <["a"=2]> func.func @using_same_body_func(%arg0: tensor<8x8xf32>) -> tensor<8x8xf32> { - %0 = stablehlo.custom_call @SPMDFullToShardShape(%arg0) : (tensor<8x8xf32>) -> (tensor<2x8xf32>) + %0 = stablehlo.custom_call @local_xla.sdy.GlobalToLocalShape(%arg0) : (tensor<8x8xf32>) -> (tensor<2x8xf32>) %1 = call @local_xla.sdy.manual_computation_body(%0) {mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{\\\22a\\\22}, {\\\22b\\\22}]>]>", xla.sdy.manual_axes = "#sdy", xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{\\\22a\\\22}, {}], replicated={\\\22b\\\22}>]>"}} : (tensor<2x8xf32>) -> (tensor<2x8xf32>) - %2 = stablehlo.custom_call @SPMDShardToFullShape(%1) : (tensor<2x8xf32>) -> (tensor<8x8xf32>) - %3 = stablehlo.custom_call @SPMDFullToShardShape(%2) : (tensor<8x8xf32>) -> (tensor<2x8xf32>) + %2 = stablehlo.custom_call @local_xla.sdy.LocalToGlobalShape(%1) : (tensor<2x8xf32>) -> (tensor<8x8xf32>) + %3 = stablehlo.custom_call @local_xla.sdy.GlobalToLocalShape(%2) : (tensor<8x8xf32>) -> (tensor<2x8xf32>) // expected-error @+2 {{'func.call' op expected a unique FuncOp per @local_xla.sdy.manual_computation_body call}} // expected-error @+1 {{failed to legalize operation 'func.call'}} %4 = call @local_xla.sdy.manual_computation_body(%3) {mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{\\\22a\\\22}, {\\\22b\\\22}]>]>", xla.sdy.manual_axes = "#sdy", xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{\\\22a\\\22}, {}], replicated={\\\22b\\\22}>]>"}} : (tensor<2x8xf32>) -> (tensor<2x8xf32>) - %5 = stablehlo.custom_call @SPMDShardToFullShape(%4) : (tensor<2x8xf32>) -> (tensor<8x8xf32>) + %5 = stablehlo.custom_call @local_xla.sdy.LocalToGlobalShape(%4) : (tensor<2x8xf32>) -> (tensor<8x8xf32>) return %5 : tensor<8x8xf32> } From d2f2a9c45c3350a6d0a83078af3ac991c9d8ba02 Mon Sep 17 00:00:00 2001 From: Ionel Gog Date: Wed, 25 Sep 2024 06:42:30 -0700 Subject: [PATCH 257/483] [IFRT] Add pass for populating atom program metadata. PiperOrigin-RevId: 678679470 --- .../xla/xla/python/ifrt/ir/tests/BUILD | 1 + .../ifrt_populate_atom_program_metadata.mlir | 238 ++++++++++++++++ .../xla/xla/python/ifrt/ir/transforms/BUILD | 1 + ...frt_populate_atom_program_metadata_pass.cc | 266 ++++++++++++++++++ .../xla/python/ifrt/ir/transforms/passes.h | 3 + .../xla/python/ifrt/ir/transforms/passes.td | 62 ++++ 6 files changed, 571 insertions(+) create mode 100644 third_party/xla/xla/python/ifrt/ir/tests/ifrt_populate_atom_program_metadata.mlir create mode 100644 third_party/xla/xla/python/ifrt/ir/transforms/ifrt_populate_atom_program_metadata_pass.cc diff --git a/third_party/xla/xla/python/ifrt/ir/tests/BUILD b/third_party/xla/xla/python/ifrt/ir/tests/BUILD index 77ff4fa86c73fe..2aae556e8ffca9 100644 --- a/third_party/xla/xla/python/ifrt/ir/tests/BUILD +++ b/third_party/xla/xla/python/ifrt/ir/tests/BUILD @@ -13,6 +13,7 @@ lit_test_suite( "ifrt_duplicated_callee_elimination.mlir", "ifrt_merge_reshards.mlir", "ifrt_outline_atom_program_to_module.mlir", + "ifrt_populate_atom_program_metadata.mlir", "ifrt_reshard_to_copy_arrays.mlir", "ifrt_verify_donation.mlir", "ifrt_verify_sharding_specified.mlir", diff --git a/third_party/xla/xla/python/ifrt/ir/tests/ifrt_populate_atom_program_metadata.mlir b/third_party/xla/xla/python/ifrt/ir/tests/ifrt_populate_atom_program_metadata.mlir new file mode 100644 index 00000000000000..783ebb26bfdca1 --- /dev/null +++ b/third_party/xla/xla/python/ifrt/ir/tests/ifrt_populate_atom_program_metadata.mlir @@ -0,0 +1,238 @@ +// RUN: ifrt-opt %s -ifrt-populate-atom-program-metadata -ifrt-duplicated-callee-elimination -split-input-file | FileCheck %s + +!array = !ifrt.array, + #ifrt.sharding_param<2x1 to [0] on 2>, [0,1]> +// CHECK-LABEL: @populate_arg_sharding +module @populate_arg_sharding { + func.func @main(%arg0: !array) attributes {ifrt.function} { + // CHECK: ifrt.Call @[[CALLEE:.+]]::@main(%arg0) + %ctrl_0 = ifrt.Call @callee::@main(%arg0) on devices [0,1] : (!array) -> () + return + } + + // CHECK: module @[[CALLEE]] + // CHECK-SAME: attributes { + // CHECK-DAG: ifrt.num_devices = 2 + // CHECK-DAG: sym_visibility = "private" + // CHECK-SAME: } + // CHECK: func.func private @main + // CHECK-DAG: ifrt.sharding = #ifrt.sharding_param<2x1 to [0] on 2> + // CHECK-DAG: ifrt.devices = #ifrt + // CHECK-NOT: ifrt + module @callee attributes {sym_visibility = "private"} { + func.func private @main(%arg0: tensor<2x2xi32>) { + return + } + } +} + +// ----- + +// CHECK-LABEL: @populate_result_sharding +module @populate_result_sharding { + func.func @main() attributes {ifrt.function} { + // CHECK: ifrt.Call @[[CALLEE:.+]]::@main() + %0, %ctrl_0 = ifrt.Call @callee::@main() on devices [0,1] + : () -> (!ifrt.array, + #ifrt.sharding_param<2x1 to [0] on 2>, [1,0]>) + return + } + + // CHECK: module @[[CALLEE]] + // CHECK-SAME: attributes { + // CHECK-DAG: ifrt.num_devices = 2 + // CHECK-DAG: sym_visibility = "private" + // CHECK-SAME: } + // CHECK: func.func private @main + // CHECK-DAG: ifrt.sharding = #ifrt.sharding_param<2x1 to [0] on 2> + // CHECK-DAG: ifrt.devices = #ifrt + // CHECK-NOT: ifrt + module @callee attributes {sym_visibility = "private"} { + func.func private @main() -> tensor<2x2xi32> { + %0 = mhlo.constant dense<1> : tensor<2x2xi32> + return %0 : tensor<2x2xi32> + } + } +} + +// ----- + +// Verifies that a single module is populated with metadata if the input and +// output types are the same. +!array = !ifrt.array, + #ifrt.sharding_param<2x1 to [0] on 2>, [0,1]> +// CHECK-LABEL: @calls_outlined_to_single_module +module @calls_outlined_to_single_module { + func.func @main(%arg0: !array) -> !array attributes {ifrt.function} { + // CHECK: %[[OUT_0:.+]], %{{.+}} = ifrt.Call @[[CALLEE:.+]]::@main(%arg0) + %0, %ctrl_0 = ifrt.Call @add_one::@main(%arg0) on devices [0,1] + : (!array) -> !array + // CHECK: %[[OUT_1:.+]], %[[CTRL_1:.+]] = ifrt.Call @[[CALLEE]]::@main(%[[OUT_0]]) + %1, %ctrl_1 = ifrt.Call @add_one::@main(%0) on devices [0,1] + : (!array) -> !array + // CHECK: ifrt.Call @[[CALLEE]]::@main(%[[OUT_1]]) after %[[CTRL_1]] + %2, %ctrl_2 = ifrt.Call @add_one::@main(%1) after %ctrl_1 on devices [0,1] + : (!array) -> !array + return %1 : !array + } + + // CHECK: module @[[CALLEE]] + // CHECK-SAME: attributes { + // CHECK-DAG: ifrt.num_devices = 2 + // CHECK-DAG: sym_visibility = "private" + // CHECK-SAME: } + // CHECK: func.func private @main + // CHECK-SAME: %arg0: tensor<2x2xi32> + // CHECK-DAG: ifrt.sharding = #ifrt.sharding_param<2x1 to [0] on 2> + // CHECK-DAG: ifrt.devices = #ifrt + // CHECK-SAME: -> (tensor<2x2xi32> + // CHECK-DAG: ifrt.sharding = #ifrt.sharding_param<2x1 to [0] on 2> + // CHECK-DAG: ifrt.devices = #ifrt + // CHECK-NOT: ifrt + module @add_one attributes {sym_visibility = "private"} { + func.func private @main(%arg0: tensor<2x2xi32>) -> tensor<2x2xi32> { + %0 = mhlo.constant dense<1> : tensor<2x2xi32> + %1 = mhlo.add %arg0, %0 : tensor<2x2xi32> + return %1 : tensor<2x2xi32> + } + } +} + +// ----- + +// CHECK-LABEL: @call_twice_with_different_sharding +!array = !ifrt.array, + #ifrt.sharding_param<2x1 to [0] on 2>, [0,1]> +!array_unspecified = !ifrt.array, + #ifrt.sharding_unspecified, [0,1]> +module @call_twice_with_different_sharding { + func.func @main(%arg0: !array) -> !array_unspecified + attributes {ifrt.function} { + // CHECK: %[[OUTPUT:.+]], %{{.+}} = ifrt.Call @[[CALLEE_0:.+]]::@main(%arg0) + %0, %ctrl_0 = ifrt.Call @add_one::@main(%arg0) on devices [0,1]: (!array) -> !array + // CHECK: ifrt.Call @[[CALLEE_1:.+]]::@main(%[[OUTPUT]]) + %1, %ctrl_1 = ifrt.Call @add_one::@main(%0) on devices [0,1] + : (!array) -> !array_unspecified + return %1 : !array_unspecified + } + + // CHECK: module @[[CALLEE_1]] + // CHECK-SAME: attributes { + // CHECK-DAG: ifrt.num_devices = 2 + // CHECK-DAG: sym_visibility = "private" + // CHECK-SAME: } + // CHECK: func.func private @main(%arg0: tensor<2x2xi32> + // CHECK-DAG: ifrt.sharding = #ifrt.sharding_param<2x1 to [0] on 2> + // CHECK-DAG: ifrt.devices = #ifrt + // CHECK-SAME: -> (tensor<2x2xi32> + // CHECK-DAG: ifrt.sharding = #ifrt.sharding_unspecified + // CHECK-DAG: ifrt.devices = #ifrt + + // CHECK: module @[[CALLEE_0]] + // CHECK-SAME: attributes { + // CHECK-DAG: ifrt.num_devices = 2 + // CHECK-DAG: sym_visibility = "private" + // CHECK-SAME: } + // CHECK: func.func private @main(%arg0: tensor<2x2xi32> + // CHECK-DAG: ifrt.sharding = #ifrt.sharding_param<2x1 to [0] on 2> + // CHECK-DAG: ifrt.devices = #ifrt + // CHECK-SAME: -> (tensor<2x2xi32> + // CHECK-DAG: ifrt.sharding = #ifrt.sharding_param<2x1 to [0] on 2> + // CHECK-DAG: ifrt.devices = #ifrt + // CHECK-NOT: ifrt + module @add_one attributes {sym_visibility = "private"} { + func.func private @main(%arg0: tensor<2x2xi32>) -> tensor<2x2xi32> { + %0 = mhlo.constant dense<1> : tensor<2x2xi32> + %1 = mhlo.add %arg0, %0 : tensor<2x2xi32> + return %1 : tensor<2x2xi32> + } + } +} + +// ----- + +!array = !ifrt.array, + #ifrt.sharding_param<2x1 to [0] on 2>, [0,1]> +// CHECK-LABEL: @populate_io_alias +module @populate_io_alias { + func.func @main(%arg0: !array) attributes {ifrt.function} { + // CHECK: ifrt.Call @[[CALLEE_0:.+]]::@main(%arg0) + %0, %ctrl_0 = ifrt.Call @callee::@main(%arg0) on devices [0,1] + {io_aliases=[array]} : (!array) -> !array + // Verify that the module is cloned if io_aliases differ. + // CHECK: ifrt.Call @[[CALLEE_1:.+]]::@main(%arg0) + %1, %ctrl_1 = ifrt.Call @callee::@main(%arg0) on devices [0,1] + : (!array) -> !array + return + } + + // CHECK: module @[[CALLEE_1]] + // CHECK-SAME: attributes { + // CHECK-DAG: ifrt.num_devices = 2 + // CHECK-DAG: sym_visibility = "private" + // CHECK-SAME: } + // CHECK: func.func private @main(%arg0: tensor<2x2xi32> + // CHECK-SAME: { + // CHECK-DAG: ifrt.sharding = #ifrt.sharding_param<2x1 to [0] on 2> + // CHECK-DAG: ifrt.devices = #ifrt + // CHECK-SAME: } + + // CHECK: module @[[CALLEE_0]] + // CHECK-SAME: attributes { + // CHECK-DAG: ifrt.num_devices = 2 + // CHECK-DAG: sym_visibility = "private" + // CHECK-SAME: } + // CHECK: func.func private @main(%arg0: tensor<2x2xi32> + // CHECK-SAME: { + // CHECK-DAG: ifrt.sharding = #ifrt.sharding_param<2x1 to [0] on 2> + // CHECK-DAG: ifrt.devices = #ifrt + // CHECK-DAG: tf.aliasing_output = 0 : i32 + // CHECK-SAME: } + module @callee attributes {sym_visibility = "private"} { + func.func private @main(%arg0: tensor<2x2xi32>) -> tensor<2x2xi32> { + return %arg0: tensor<2x2xi32> + } + } +} + +// ----- + +!shared_array = !ifrt.array, + #ifrt.sharding_param<2x1 to [0] on 2>, [0,1]> +// CHECK-LABEL: @output_of_call_donated +module @output_of_call_donated { + func.func @main(%arg0: !shared_array) -> !shared_array + attributes {ifrt.function} { + // CHECK: %[[OUT_0:.+]], %{{.+}} = ifrt.Call @[[CALLEE_0:.+]]::@main(%arg0) on devices [0, 1] : + %0, %ctrl_0 = ifrt.Call @add_one::@main(%arg0) on devices [0,1] + : (!shared_array) -> !shared_array + // CHECK: %[[OUT_1:.+]], %{{.+}} = ifrt.Call @[[CALLEE_1:.+]]::@main(%[[OUT_0]]) on devices [0, 1] {io_aliases = [array]} : + %1, %ctrl_1 = ifrt.Call @add_one::@main(%0) on devices [0,1] + {io_aliases=[array]} : (!shared_array) -> !shared_array + return %1 : !shared_array + } + + // CHECK: module @[[CALLEE_1]] + // CHECK-SAME: attributes { + // CHECK-DAG: ifrt.num_devices = 2 + // CHECK-DAG: sym_visibility = "private" + // CHECK-SAME: } + // CHECK: func.func private @main + // CHECK-SAME: %arg0: tensor<2x2xi32> + // CHECK-SAME: tf.aliasing_output = 0 : i32 + + // CHECK: module @[[CALLEE_0]] + // CHECK-SAME: attributes { + // CHECK-DAG: ifrt.num_devices = 2 + // CHECK-DAG: sym_visibility = "private" + // CHECK-SAME: } + // CHECK: func.func private @main + // CHECK-SAME: %arg0: tensor<2x2xi32> + module @add_one attributes {sym_visibility = "private"} { + func.func private @main(%arg0: tensor<2x2xi32>) -> tensor<2x2xi32> { + %0 = mhlo.constant dense<1> : tensor<2x2xi32> + %1 = mhlo.add %arg0, %0 : tensor<2x2xi32> + return %1 : tensor<2x2xi32> + } + } +} diff --git a/third_party/xla/xla/python/ifrt/ir/transforms/BUILD b/third_party/xla/xla/python/ifrt/ir/transforms/BUILD index 882944a1602632..5db9b3bba088ab 100644 --- a/third_party/xla/xla/python/ifrt/ir/transforms/BUILD +++ b/third_party/xla/xla/python/ifrt/ir/transforms/BUILD @@ -32,6 +32,7 @@ cc_library( "ifrt_duplicated_callee_elimination_pass.cc", "ifrt_merge_reshards_pass.cc", "ifrt_outline_atom_program_to_module_pass.cc", + "ifrt_populate_atom_program_metadata_pass.cc", "ifrt_reshard_to_copy_arrays_pass.cc", "ifrt_verify_donation_pass.cc", "ifrt_verify_sharding_specified_pass.cc", diff --git a/third_party/xla/xla/python/ifrt/ir/transforms/ifrt_populate_atom_program_metadata_pass.cc b/third_party/xla/xla/python/ifrt/ir/transforms/ifrt_populate_atom_program_metadata_pass.cc new file mode 100644 index 00000000000000..5b6e8268ac0902 --- /dev/null +++ b/third_party/xla/xla/python/ifrt/ir/transforms/ifrt_populate_atom_program_metadata_pass.cc @@ -0,0 +1,266 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/Hashing.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/Casting.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/IR/Visitors.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "xla/python/ifrt/ir/constants.h" +#include "xla/python/ifrt/ir/ifrt_dialect.h" +#include "xla/python/ifrt/ir/ifrt_ops.h" +#include "xla/python/ifrt/ir/transforms/utils.h" + +namespace xla { +namespace ifrt { + +namespace { + +#define GEN_PASS_DEF_IFRTPOPULATEATOMPROGRAMMETADATAPASS +#include "xla/python/ifrt/ir/transforms/passes.h.inc" + +// Used for comparing CallOps without including control dependencies. +struct IfrtCallOpInfo : llvm::DenseMapInfo { + static unsigned getHashValue(xla::ifrt::CallOp call_op) { + llvm::hash_code hash = {}; + // Use `getInputs()/getOutputs()` instead of `getOperands()/getResults()` to + // ensure that the control dependencies are not included in the hash. + for (auto input_type : call_op.getInputs().getTypes()) { + hash = llvm::hash_combine(hash, input_type); + } + for (auto output_type : call_op.getOutputs().getTypes()) { + hash = llvm::hash_combine(hash, output_type); + } + for (mlir::NamedAttribute attr : call_op->getAttrs()) { + // Exclude `operandSegmentSizes` because its value changes depending on + // how many control dependencies a CallOp has. + if (attr.getName() == "operandSegmentSizes") { + continue; + } + hash = llvm::hash_combine(hash, attr); + } + return hash; + } + + static bool isEqual(xla::ifrt::CallOp lhs, xla::ifrt::CallOp rhs) { + if (lhs == rhs) { + return true; + } + if (lhs == getEmptyKey() || lhs == getTombstoneKey() || + rhs == getEmptyKey() || rhs == getTombstoneKey()) { + return false; + } + // Verify that the input and output types are the same. + if (lhs.getInputs().getTypes() != rhs.getInputs().getTypes()) { + return false; + } + if (lhs.getOutputs().getTypes() != rhs.getOutputs().getTypes()) { + return false; + } + mlir::NamedAttrList lattrs = lhs->getAttrDictionary(); + mlir::NamedAttrList rattrs = rhs->getAttrDictionary(); + lattrs.erase("operandSegmentSizes"); + rattrs.erase("operandSegmentSizes"); + // Verify that the attributes are the same. + return lattrs == rattrs; + } +}; + +// Populates the metadata on the atom program ModuleOp and `main` FuncOp. +mlir::LogicalResult PopulateMetadata(xla::ifrt::CallOp call_op, + mlir::ModuleOp module_op, + mlir::func::FuncOp callee_op, + mlir::OpBuilder& builder) { + module_op->setAttr(xla::ifrt::kIfrtNumDevicesAttrName, + builder.getI32IntegerAttr(call_op.getDevices().size())); + // Copy `ifrt.local_view` attribute if it exists. + if (call_op->hasAttrOfType( + xla::ifrt::kIfrtLocalViewAttrName)) { + module_op->setAttr(xla::ifrt::kIfrtLocalViewAttrName, + call_op->getAttr(xla::ifrt::kIfrtLocalViewAttrName)); + } + + // Attach sharding to inputs. + for (const auto& [i, input] : llvm::enumerate(call_op.getInputs())) { + const auto array = + mlir::dyn_cast_or_null(input.getType()); + if (array == nullptr) { + return call_op->emitOpError() + << "requires all inputs to be IfrtArrayType. Input #" << i << ": " + << input.getType(); + } + // It is faster to get all the attributes and add the new ones than + // setting the new attributes one-by-one. This is because the logic that + // sets an attribute converts the attr dict to a NamedAttrList, and then + // linearly searches for the attr. + llvm::SmallVector arg_attrs; + auto arg_attr_dict = callee_op.getArgAttrDict(i); + if (arg_attr_dict != nullptr) { + arg_attrs.append(arg_attr_dict.begin(), arg_attr_dict.end()); + } + arg_attrs.push_back( + builder.getNamedAttr(kIfrtShardingAttrName, array.getShardingAttr())); + arg_attrs.push_back( + builder.getNamedAttr(kIfrtDevicesAttrName, array.getDevicesAttr())); + callee_op.setArgAttrs(i, arg_attrs); + } + + // Attach sharding to outputs. + for (const auto& [i, output] : llvm::enumerate(call_op.getOutputs())) { + const auto array = + mlir::dyn_cast_or_null(output.getType()); + if (array == nullptr) { + return call_op->emitOpError() + << "requires all outputs to be IfrtArrayType. Input #" << i << ": " + << output.getType(); + } + llvm::SmallVector res_attrs; + auto res_attr_dict = callee_op.getResultAttrDict(i); + if (res_attr_dict != nullptr) { + res_attrs.append(res_attr_dict.begin(), res_attr_dict.end()); + } + res_attrs.push_back( + builder.getNamedAttr(kIfrtShardingAttrName, array.getShardingAttr())); + res_attrs.push_back( + builder.getNamedAttr(kIfrtDevicesAttrName, array.getDevicesAttr())); + callee_op.setResultAttrs(i, res_attrs); + } + + // Alias inputs. + for (const auto& raw_io_alias : + call_op.getIoAliases().getAsRange()) { + llvm::ArrayRef io_alias_as_array = raw_io_alias.asArrayRef(); + callee_op.setArgAttr(io_alias_as_array[0], "tf.aliasing_output", + builder.getI32IntegerAttr(io_alias_as_array[1])); + } + return mlir::success(); +} + +class IfrtPopulateAtomProgramMetadataPass + : public impl::IfrtPopulateAtomProgramMetadataPassBase< + IfrtPopulateAtomProgramMetadataPass> { + public: + void runOnOperation() override; +}; + +void IfrtPopulateAtomProgramMetadataPass::runOnOperation() { + mlir::MLIRContext& context = getContext(); + mlir::SymbolTableCollection symbol_table; + mlir::OpBuilder builder(&context); + mlir::func::FuncOp main_func = xla::ifrt::GetMainFunction(getOperation()); + + // Construct a map from callee `SymbolRefAttr` to the unique `CallOps` + // using it. This map is used to decide if a atom program module must be + // cloned before populating its metadata (i.e., used more than once). + llvm::DenseMap> + callee_call_count; + for (xla::ifrt::CallOp call_op : main_func.getOps()) { + callee_call_count[call_op.getCallee()].insert(call_op); + } + + llvm::DenseMap + visited_call_ops; + auto result = main_func.walk([&](xla::ifrt::CallOp call_op) + -> mlir::WalkResult { + mlir::func::FuncOp callee = call_op.getCalleeOp(symbol_table); + if (callee == nullptr) { + return call_op->emitOpError() + << "can't find callee `" << call_op.getCalleeAttr() << "`"; + } + auto callee_module = llvm::dyn_cast(callee->getParentOp()); + if (callee.getSymName() != xla::ifrt::kCalleeMainFuncName || + callee_module == nullptr) { + return call_op.emitOpError() + << "requires callee outlined as `" + << xla::ifrt::kCalleeMainFuncName + << "` function in a ModuleOp. Actual callee name: " + << callee.getSymName() + << ". Actual callee parent: " << callee->getParentOp()->getName(); + } + + if (auto call_op_it = visited_call_ops.find(call_op); + call_op_it != visited_call_ops.end()) { + call_op.setCalleeAttr(call_op_it->second); + } else { + callee_call_count[call_op.getCallee()].erase(call_op); + if (!callee_call_count[call_op.getCallee()].empty()) { + // Only clone the callee if it is used more than once. + mlir::ModuleOp cloned_module = callee_module.clone(); + mlir::func::FuncOp cloned_callee = + xla::ifrt::GetMainFunction(cloned_module); + cloned_callee.setPrivate(); + // Insert new cloned atom program module in the SymbolTable. + symbol_table + .getSymbolTable( + callee_module->getParentWithTrait()) + .insert(cloned_module); + mlir::SymbolRefAttr callee_attr = mlir::SymbolRefAttr::get( + cloned_module.getSymNameAttr(), + mlir::SymbolRefAttr::get(cloned_callee.getSymNameAttr())); + auto populate_result = + PopulateMetadata(call_op, cloned_module, cloned_callee, builder); + if (mlir::failed(populate_result)) { + return populate_result; + } + // Clone the CallOp because it will be modified next. + visited_call_ops[call_op.clone()] = callee_attr; + call_op.setCalleeAttr(callee_attr); + } else { + auto populate_result = PopulateMetadata( + call_op, callee_module, xla::ifrt::GetMainFunction(callee_module), + builder); + if (mlir::failed(populate_result)) { + return populate_result; + } + visited_call_ops[call_op.clone()] = call_op.getCalleeAttr(); + } + } + return mlir::WalkResult::advance(); + }); + + if (result.wasInterrupted()) { + signalPassFailure(); + } + + // Erase the cloned CallOp because they were used only as keys of the map. + for (auto& [call_op, unused] : visited_call_ops) { + call_op.erase(); + } +} + +} // namespace + +std::unique_ptr> +CreateIfrtPopulateAtomProgramMetadataPass() { + return std::make_unique(); +} + +} // namespace ifrt +} // namespace xla diff --git a/third_party/xla/xla/python/ifrt/ir/transforms/passes.h b/third_party/xla/xla/python/ifrt/ir/transforms/passes.h index b1f528a2d1e2a6..285e7e2156d42a 100644 --- a/third_party/xla/xla/python/ifrt/ir/transforms/passes.h +++ b/third_party/xla/xla/python/ifrt/ir/transforms/passes.h @@ -49,6 +49,9 @@ CreateIfrtVerifyDonationPass(); std::unique_ptr> CreateIfrtVerifyShardingSpecifiedPass(); +std::unique_ptr> +CreateIfrtPopulateAtomProgramMetadataPass(); + std::unique_ptr> CreateIfrtReshardToCopyArraysPass(); diff --git a/third_party/xla/xla/python/ifrt/ir/transforms/passes.td b/third_party/xla/xla/python/ifrt/ir/transforms/passes.td index a78b5729213c0a..a75fc059cb8e64 100644 --- a/third_party/xla/xla/python/ifrt/ir/transforms/passes.td +++ b/third_party/xla/xla/python/ifrt/ir/transforms/passes.td @@ -203,6 +203,68 @@ Verify that each `!ifrt.array` has sharding attribute that is not of type let constructor = "CreateIfrtVerifyShardingSpecifiedPass()"; } +def IfrtPopulateAtomProgramMetadataPass : + Pass<"ifrt-populate-atom-program-metadata", "mlir::ModuleOp"> { + let summary = "Populate metadata from call site to atom functions"; + let description = [{ +For every CallOp, this pass + 1. clones the callee's parent ModuleOp + 2. adds `ifrt.num_devices` attribute to the callee's parent ModuleOp + 2. attaches shardings and devices to the inputs and outputs of the callee's + main FuncOp + 3. attaches `tf.aliasing_output` attr to the callee main FuncOp's inputs + according to `io_aliases` +For CallOps with the same callee, a different clone will be created for each +CallOp, even if the populated metadata are the same. User may want to run +`ifrt-duplicated-callee-elimination` pass to dedup the clones. + +For example, the following code + +```mlir +%0, %ctrl_0 = ifrt.Call @callee::@main(%arg0) on devices [0, 1] + {io_aliases=[array]} + : (!ifrt.array, #ifrt.sharding_param<2x1 to [0] on 2>, + [0,1]>, + !ifrt.array, #ifrt.sharding_param<1x2 to [0] on 2>, + [0,1]>) + -> !ifrt.array, #ifrt.sharding_param<1x2 to [0] on 2>, + [0,1]> + +module @callee { + func.func private @main(%arg0: tensor<2x2xi32>, %arg1: tensor<4x4xi32>) + -> tensor<4x4xi32> {} +} +``` + +will be replaced by + +```mlir +%0, %ctrl_0 = ifrt.Call @new_callee::@main(%arg0) on devices [0, 1] + : (!ifrt.array, #ifrt.sharding_param<2x1 to [0] on 2>, + [0,1]>) + -> !ifrt.array, #ifrt.sharding_param<1x2 to [0] on 2>, + [0,1]> + +module @new_callee attributes {ifrt.num_devices = 2} { + func.func private @new_callee( + %arg0: tensor<2x2xi32> { + ifrt.sharding = #ifrt.sharding_param<2x1 to [0] on 2>, + ifrt.devices = #ifrt}, + %arg1: tensor<4x4xi32> { + ifrt.sharding = #ifrt.sharding_param<1x2 to [0] on 2>, + ifrt.devices = #ifrt + tf.aliasing_output = 0 : i32}) + -> (tensor<4x4xi32> { + ifrt.sharding = #ifrt.sharding_param<1x2 to [0] on 2>, + ifrt.devices = #ifrt}) + {} +} +``` + }]; + + let constructor = "CreateIfrtPopulateAtomProgramMetadataPass()"; +} + def IfrtReshardToCopyArraysPass : Pass<"ifrt-reshard-to-copy-arrays", "mlir::ModuleOp"> { let summary = "Replaces `ifrt.Reshard` with `ifrt.CopyArrays`"; From 5f8699f8829003ce927d0d43ac72143a4d0df512 Mon Sep 17 00:00:00 2001 From: Ionel Gog Date: Wed, 25 Sep 2024 07:15:15 -0700 Subject: [PATCH 258/483] [IFRT] Add IFRT IR pipeline for outlining atom programs to ModuleOps. PiperOrigin-RevId: 678688576 --- .../xla/xla/python/ifrt/ir/tests/ifrt-opt.cc | 1 + .../xla/xla/python/ifrt/ir/transforms/BUILD | 4 ++ .../xla/python/ifrt/ir/transforms/passes.cc | 61 +++++++++++++++++++ .../xla/python/ifrt/ir/transforms/passes.h | 19 ++++++ 4 files changed, 85 insertions(+) create mode 100644 third_party/xla/xla/python/ifrt/ir/transforms/passes.cc diff --git a/third_party/xla/xla/python/ifrt/ir/tests/ifrt-opt.cc b/third_party/xla/xla/python/ifrt/ir/tests/ifrt-opt.cc index 78c88c3c779195..59e25a972e1051 100644 --- a/third_party/xla/xla/python/ifrt/ir/tests/ifrt-opt.cc +++ b/third_party/xla/xla/python/ifrt/ir/tests/ifrt-opt.cc @@ -27,6 +27,7 @@ int main(int argc, char** argv) { mlir::mhlo::registerAllMhloDialects(registry); registry.insert(); xla::ifrt::registerIfrtIrPasses(); + xla::ifrt::RegisterIfrtPassesAndPipelines(); xla::ifrt::AttachBuiltInSpmdExpansions(registry); diff --git a/third_party/xla/xla/python/ifrt/ir/transforms/BUILD b/third_party/xla/xla/python/ifrt/ir/transforms/BUILD index 5db9b3bba088ab..f9001e16f90d4d 100644 --- a/third_party/xla/xla/python/ifrt/ir/transforms/BUILD +++ b/third_party/xla/xla/python/ifrt/ir/transforms/BUILD @@ -36,6 +36,7 @@ cc_library( "ifrt_reshard_to_copy_arrays_pass.cc", "ifrt_verify_donation_pass.cc", "ifrt_verify_sharding_specified_pass.cc", + "passes.cc", "spmd_expandable_interface_verification_pass.cc", "spmd_expansion_pass.cc", ], @@ -44,6 +45,7 @@ cc_library( deps = [ ":passes_inc_gen", ":utils", + "//xla/mlir_hlo", "//xla/python/ifrt/ir", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log:check", @@ -55,6 +57,8 @@ cc_library( "@llvm-project//mlir:Pass", "@llvm-project//mlir:Support", "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + "@stablehlo//:stablehlo_ops", ], ) diff --git a/third_party/xla/xla/python/ifrt/ir/transforms/passes.cc b/third_party/xla/xla/python/ifrt/ir/transforms/passes.cc new file mode 100644 index 00000000000000..8957182ff0afcc --- /dev/null +++ b/third_party/xla/xla/python/ifrt/ir/transforms/passes.cc @@ -0,0 +1,61 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/python/ifrt/ir/transforms/passes.h" + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Pass/PassRegistry.h" +#include "mlir/Transforms/Passes.h" +#include "stablehlo/dialect/StablehloOps.h" +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" + +namespace xla { +namespace ifrt { + +void CreateIfrtToOutlinedAtomProgramsPipeline( + mlir::OpPassManager& pm, + const IfrtToOutlinedAtomProgramsPipelineOptions& options) { + // Passes that verify the correctness of the module. + pm.addPass(xla::ifrt::CreateSpmdExpandableInterfaceVerificationPass( + {{mlir::mhlo::MhloDialect::getDialectNamespace().str(), + mlir::stablehlo::StablehloDialect::getDialectNamespace().str()}})); + pm.addNestedPass( + xla::ifrt::CreateIfrtVerifyDonationPass()); + + // Passes that outline atom programs to modules and set their metadata. + pm.addPass(xla::ifrt::CreateIfrtOutlineAtomProgramToModulePass()); + pm.addPass(xla::ifrt::CreateIfrtPopulateAtomProgramMetadataPass()); + pm.addPass(xla::ifrt::CreateIfrtDuplicatedCalleeEliminationPass()); + pm.addPass(mlir::createSymbolDCEPass()); + + if (!options.propagate_shardings) { + pm.addPass(xla::ifrt::CreateIfrtVerifyShardingSpecifiedPass()); + // We can split ifrt.Reshard to ifrt.CopyArrays because all the shardings + // are specified. + pm.addPass(xla::ifrt::CreateIfrtReshardToCopyArraysPass()); + } +} + +void RegisterIfrtPassesAndPipelines() { + registerIfrtIrPasses(); + mlir::PassPipelineRegistration( + "ifrt-to-outlined-atom-programs-pipeline", + "Runs passes that do not require compilation-time information", + CreateIfrtToOutlinedAtomProgramsPipeline); +} + +} // namespace ifrt +} // namespace xla diff --git a/third_party/xla/xla/python/ifrt/ir/transforms/passes.h b/third_party/xla/xla/python/ifrt/ir/transforms/passes.h index 285e7e2156d42a..ed58677e97cf63 100644 --- a/third_party/xla/xla/python/ifrt/ir/transforms/passes.h +++ b/third_party/xla/xla/python/ifrt/ir/transforms/passes.h @@ -18,9 +18,11 @@ limitations under the License. #include +#include "llvm/Support/CommandLine.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassOptions.h" namespace xla { namespace ifrt { @@ -59,6 +61,23 @@ CreateIfrtReshardToCopyArraysPass(); #define GEN_PASS_REGISTRATION #include "xla/python/ifrt/ir/transforms/passes.h.inc" // IWYU pragma: export +struct IfrtToOutlinedAtomProgramsPipelineOptions + : mlir::PassPipelineOptions { + Option propagate_shardings{ + *this, "propagate_shardings", + llvm::cl::desc("Whether to propagate shardings from executables for " + "unspecified shardings.")}; +}; + +// Creates pipeline of all the IFRT IR passes that do not require +// compilation-time information (e.g., device assignments). +void CreateIfrtToOutlinedAtomProgramsPipeline( + mlir::OpPassManager& pm, + const IfrtToOutlinedAtomProgramsPipelineOptions& options); + +// Registers passes and pipelines to ifrt-opt. +void RegisterIfrtPassesAndPipelines(); + } // namespace ifrt } // namespace xla From 913db661082f50273eeb82152907d3d0410f0c8b Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 25 Sep 2024 07:30:15 -0700 Subject: [PATCH 259/483] Reverts 0627b30129ae5b5023cc24e98f2563f927e761a5 PiperOrigin-RevId: 678692966 --- third_party/xla/xla/service/gpu/model/BUILD | 3 --- 1 file changed, 3 deletions(-) diff --git a/third_party/xla/xla/service/gpu/model/BUILD b/third_party/xla/xla/service/gpu/model/BUILD index c793c7508c133f..7a4790b1c3adaf 100644 --- a/third_party/xla/xla/service/gpu/model/BUILD +++ b/third_party/xla/xla/service/gpu/model/BUILD @@ -540,9 +540,6 @@ cc_library( xla_cc_test( name = "indexing_map_serialization_test", srcs = ["indexing_map_serialization_test.cc"], - tags = [ - "nomsan", - ], deps = [ ":indexing_map_serialization", ":indexing_test_utils", From eb3b975ad6de62744c4bb1bdfe8bffe474dc2c94 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Wed, 25 Sep 2024 09:02:31 -0700 Subject: [PATCH 260/483] XLA:GPU Fix data type for normalization of FFTs for F64 and C128 types. The current normalization "scale_factor" is stored as a float, but it should be a double for kZ2ZInverse and kZ2D. This change caches the denominator as a uint64 in the plan, and then takes the inverse in the appropriate type only when applying the normalization. Reported in https://github.com/jax-ml/jax/issues/23827, when compared to numpy FFTs. I've tested that this fixes the issue reported there, but I'm not sure where would be best to add a test in XLA. PiperOrigin-RevId: 678723450 --- .../xla/xla/service/gpu/runtime/fft_thunk.cc | 14 +++++++------- .../xla/xla/service/gpu/runtime/fft_thunk.h | 2 +- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/third_party/xla/xla/service/gpu/runtime/fft_thunk.cc b/third_party/xla/xla/service/gpu/runtime/fft_thunk.cc index 7d620522146acf..a493a20031005e 100644 --- a/third_party/xla/xla/service/gpu/runtime/fft_thunk.cc +++ b/third_party/xla/xla/service/gpu/runtime/fft_thunk.cc @@ -175,13 +175,13 @@ absl::Status RunFft(se::DeviceMemoryBase input, const Shape& input_shape, batch_size, &scratch_allocator); TF_RET_CHECK(fft_plan != nullptr) << "Failed to create cuFFT batched plan with scratch allocator"; - fft_plan_ptr->scale_factor = 1.0f / output_distance; + fft_plan_ptr->scale_factor = output_distance; } else { fft->UpdatePlanWithScratchAllocator(stream, fft_plan.get(), &scratch_allocator); } - float scale_factor = fft_plan_ptr->scale_factor; + uint64_t scale_factor = fft_plan_ptr->scale_factor; bool launch_ok; switch (fft_type) { @@ -205,7 +205,7 @@ absl::Status RunFft(se::DeviceMemoryBase input, const Shape& input_shape, TF_ASSIGN_OR_RETURN(auto blas, GetBlas(stream)); launch_ok = blas->DoBlasScal(stream, ShapeUtil::ElementsIn(output_shape), - complex64(scale_factor), &output_data, 1); + complex64(1.0f / scale_factor), &output_data, 1); } break; } @@ -217,7 +217,7 @@ absl::Status RunFft(se::DeviceMemoryBase input, const Shape& input_shape, TF_ASSIGN_OR_RETURN(auto blas, GetBlas(stream)); launch_ok = blas->DoBlasScal(stream, ShapeUtil::ElementsIn(output_shape), - complex128(scale_factor), &output_data, 1); + complex128(1.0 / scale_factor), &output_data, 1); } break; } @@ -241,7 +241,7 @@ absl::Status RunFft(se::DeviceMemoryBase input, const Shape& input_shape, TF_ASSIGN_OR_RETURN(auto blas, GetBlas(stream)); launch_ok = blas->DoBlasScal(stream, ShapeUtil::ElementsIn(output_shape), - scale_factor, &output_data, 1); + 1.0f / scale_factor, &output_data, 1); } break; } @@ -253,7 +253,7 @@ absl::Status RunFft(se::DeviceMemoryBase input, const Shape& input_shape, TF_ASSIGN_OR_RETURN(auto blas, GetBlas(stream)); launch_ok = blas->DoBlasScal(stream, ShapeUtil::ElementsIn(output_shape), - scale_factor, &output_data, 1); + 1.0 / scale_factor, &output_data, 1); } break; } @@ -264,7 +264,7 @@ absl::Status RunFft(se::DeviceMemoryBase input, const Shape& input_shape, return absl::OkStatus(); } return Internal("Unable to launch fft with type %s", - FftTypeToString(fft_type)); + FftTypeToString(fft_type)); } } // namespace gpu diff --git a/third_party/xla/xla/service/gpu/runtime/fft_thunk.h b/third_party/xla/xla/service/gpu/runtime/fft_thunk.h index ffd45ed804fda9..eedb75fb80fe6d 100644 --- a/third_party/xla/xla/service/gpu/runtime/fft_thunk.h +++ b/third_party/xla/xla/service/gpu/runtime/fft_thunk.h @@ -42,7 +42,7 @@ struct FftPlan { // protect each plan with a mutex. absl::Mutex mu; std::unique_ptr plan ABSL_GUARDED_BY(mu); - float scale_factor ABSL_GUARDED_BY(mu); + uint64_t scale_factor ABSL_GUARDED_BY(mu); }; class FftPlanCache { From d84b856d969201974a8bd1c5a93f75fbd08ea465 Mon Sep 17 00:00:00 2001 From: Fergus Henderson Date: Wed, 25 Sep 2024 09:19:50 -0700 Subject: [PATCH 261/483] Reverts 4aba9ebb1d0c2bb14415dd7a0ef2b02e30a19b15 PiperOrigin-RevId: 678729408 --- tensorflow/compiler/mlir/lite/core/api/BUILD | 32 - .../lite/core/api/flatbuffer_conversions.cc | 3186 ----------------- .../lite/core/api/flatbuffer_conversions.h | 440 --- .../core/api/flatbuffer_conversions_test.cc | 873 ----- tensorflow/compiler/mlir/lite/core/c/BUILD | 13 +- .../mlir/lite/core/c/builtin_op_data.h | 624 +--- .../mlir/lite/core/c/dimension_type.h | 38 + .../compiler/mlir/lite/core/c/tflite_types.h | 70 - .../utils/sparsity_format_converter.cc | 2 +- .../utils/sparsity_format_converter.h | 2 +- tensorflow/lite/c/BUILD | 1 - tensorflow/lite/core/c/BUILD | 4 - tensorflow/lite/core/c/builtin_op_data.h | 639 +++- tensorflow/lite/core/c/c_api_types.h | 40 +- tensorflow/lite/core/c/common.h | 6 + tensorflow/lite/java/BUILD | 2 - 16 files changed, 733 insertions(+), 5239 deletions(-) delete mode 100644 tensorflow/compiler/mlir/lite/core/api/flatbuffer_conversions.cc delete mode 100644 tensorflow/compiler/mlir/lite/core/api/flatbuffer_conversions.h delete mode 100644 tensorflow/compiler/mlir/lite/core/api/flatbuffer_conversions_test.cc create mode 100644 tensorflow/compiler/mlir/lite/core/c/dimension_type.h delete mode 100644 tensorflow/compiler/mlir/lite/core/c/tflite_types.h diff --git a/tensorflow/compiler/mlir/lite/core/api/BUILD b/tensorflow/compiler/mlir/lite/core/api/BUILD index cd16efca57aad8..0aaca3928420d6 100644 --- a/tensorflow/compiler/mlir/lite/core/api/BUILD +++ b/tensorflow/compiler/mlir/lite/core/api/BUILD @@ -52,35 +52,3 @@ tf_cc_test( "@com_google_googletest//:gtest_main", ], ) - -cc_library( - name = "flatbuffer_conversions", - srcs = ["flatbuffer_conversions.cc"], - hdrs = [ - "flatbuffer_conversions.h", - ], - compatible_with = get_compatible_with_portable(), - copts = tflite_copts(), - deps = [ - "//tensorflow/compiler/mlir/lite/core/c:tflite_common", - "//tensorflow/compiler/mlir/lite/kernels/internal:compatibility_macros", - "//tensorflow/compiler/mlir/lite/schema:schema_fbs", - "@com_google_absl//absl/log", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings:str_format", - "@flatbuffers//:runtime_cc", - ], -) - -tf_cc_test( - name = "flatbuffer_conversions_test", - size = "small", - srcs = ["flatbuffer_conversions_test.cc"], - deps = [ - ":flatbuffer_conversions", - "//tensorflow/compiler/mlir/lite/core/c:tflite_common", - "//tensorflow/compiler/mlir/lite/schema:schema_fbs", - "@com_google_googletest//:gtest_main", - "@flatbuffers//:runtime_cc", - ], -) diff --git a/tensorflow/compiler/mlir/lite/core/api/flatbuffer_conversions.cc b/tensorflow/compiler/mlir/lite/core/api/flatbuffer_conversions.cc deleted file mode 100644 index 60db7412bd199f..00000000000000 --- a/tensorflow/compiler/mlir/lite/core/api/flatbuffer_conversions.cc +++ /dev/null @@ -1,3186 +0,0 @@ -/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/mlir/lite/core/api/flatbuffer_conversions.h" - -#include -#include -#include -#include - -#include "absl/log/log.h" -#include "absl/status/status.h" -#include "absl/strings/str_format.h" -#include "flatbuffers/vector.h" // from @flatbuffers -#include "tensorflow/compiler/mlir/lite/core/c/builtin_op_data.h" -#include "tensorflow/compiler/mlir/lite/core/c/tflite_types.h" -#include "tensorflow/compiler/mlir/lite/kernels/internal/compatibility_macros.h" -#include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" - -/// Check whether the value `a` is true, and if not return -/// absl::InvalidArgumentError from the current function, while also -/// reporting the location of the error. -#define TFL_MIGRATION_ENSURE(a) \ - do { \ - if (!(a)) { \ - auto error_message = \ - absl::StrFormat("%s:%d %s was not true.", __FILE__, __LINE__, #a); \ - LOG(ERROR) << error_message; \ - return absl::InvalidArgumentError(error_message); \ - } \ - } while (0) - -#define TFL_MIGRATION_ENSURE_STATUS(a) \ - do { \ - const absl::Status s = (a); \ - if (!s.ok()) { \ - return s; \ - } \ - } while (0) - -namespace tflite_migration { -using absl::OkStatus; -using tflite::ActivationFunctionType; -using tflite::ActivationFunctionType_NONE; -using tflite::ActivationFunctionType_RELU; -using tflite::ActivationFunctionType_RELU6; -using tflite::ActivationFunctionType_RELU_N1_TO_1; -using tflite::ActivationFunctionType_SIGN_BIT; -using tflite::ActivationFunctionType_TANH; -using tflite::BuiltinOperator; -using tflite::BuiltinOperator_ABS; -using tflite::BuiltinOperator_ADD; -using tflite::BuiltinOperator_ADD_N; -using tflite::BuiltinOperator_ARG_MAX; -using tflite::BuiltinOperator_ARG_MIN; -using tflite::BuiltinOperator_ASSIGN_VARIABLE; -using tflite::BuiltinOperator_AVERAGE_POOL_2D; -using tflite::BuiltinOperator_BATCH_MATMUL; -using tflite::BuiltinOperator_BATCH_TO_SPACE_ND; -using tflite::BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM; -using tflite::BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN; -using tflite::BuiltinOperator_BITWISE_XOR; -using tflite::BuiltinOperator_BROADCAST_ARGS; -using tflite::BuiltinOperator_BROADCAST_TO; -using tflite::BuiltinOperator_BUCKETIZE; -using tflite::BuiltinOperator_CALL; -using tflite::BuiltinOperator_CALL_ONCE; -using tflite::BuiltinOperator_CAST; -using tflite::BuiltinOperator_CEIL; -using tflite::BuiltinOperator_COMPLEX_ABS; -using tflite::BuiltinOperator_CONCAT_EMBEDDINGS; -using tflite::BuiltinOperator_CONCATENATION; -using tflite::BuiltinOperator_CONV_2D; -using tflite::BuiltinOperator_CONV_3D; -using tflite::BuiltinOperator_CONV_3D_TRANSPOSE; -using tflite::BuiltinOperator_COS; -using tflite::BuiltinOperator_CUMSUM; -using tflite::BuiltinOperator_CUSTOM; -using tflite::BuiltinOperator_DELEGATE; -using tflite::BuiltinOperator_DENSIFY; -using tflite::BuiltinOperator_DEPTH_TO_SPACE; -using tflite::BuiltinOperator_DEPTHWISE_CONV_2D; -using tflite::BuiltinOperator_DEQUANTIZE; -using tflite::BuiltinOperator_DIV; -using tflite::BuiltinOperator_DYNAMIC_UPDATE_SLICE; -using tflite::BuiltinOperator_ELU; -using tflite::BuiltinOperator_EMBEDDING_LOOKUP; -using tflite::BuiltinOperator_EMBEDDING_LOOKUP_SPARSE; -using tflite::BuiltinOperator_EQUAL; -using tflite::BuiltinOperator_EXP; -using tflite::BuiltinOperator_EXPAND_DIMS; -using tflite::BuiltinOperator_FAKE_QUANT; -using tflite::BuiltinOperator_FILL; -using tflite::BuiltinOperator_FLOOR; -using tflite::BuiltinOperator_FLOOR_DIV; -using tflite::BuiltinOperator_FLOOR_MOD; -using tflite::BuiltinOperator_FULLY_CONNECTED; -using tflite::BuiltinOperator_GATHER; -using tflite::BuiltinOperator_GATHER_ND; -using tflite::BuiltinOperator_GELU; -using tflite::BuiltinOperator_GREATER; -using tflite::BuiltinOperator_GREATER_EQUAL; -using tflite::BuiltinOperator_HARD_SWISH; -using tflite::BuiltinOperator_HASHTABLE; -using tflite::BuiltinOperator_HASHTABLE_FIND; -using tflite::BuiltinOperator_HASHTABLE_IMPORT; -using tflite::BuiltinOperator_HASHTABLE_LOOKUP; -using tflite::BuiltinOperator_HASHTABLE_SIZE; -using tflite::BuiltinOperator_IF; -using tflite::BuiltinOperator_IMAG; -using tflite::BuiltinOperator_L2_NORMALIZATION; -using tflite::BuiltinOperator_L2_POOL_2D; -using tflite::BuiltinOperator_LEAKY_RELU; -using tflite::BuiltinOperator_LESS; -using tflite::BuiltinOperator_LESS_EQUAL; -using tflite::BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION; -using tflite::BuiltinOperator_LOG; -using tflite::BuiltinOperator_LOG_SOFTMAX; -using tflite::BuiltinOperator_LOGICAL_AND; -using tflite::BuiltinOperator_LOGICAL_NOT; -using tflite::BuiltinOperator_LOGICAL_OR; -using tflite::BuiltinOperator_LOGISTIC; -using tflite::BuiltinOperator_LSH_PROJECTION; -using tflite::BuiltinOperator_LSTM; -using tflite::BuiltinOperator_MATRIX_DIAG; -using tflite::BuiltinOperator_MATRIX_SET_DIAG; -using tflite::BuiltinOperator_MAX_POOL_2D; -using tflite::BuiltinOperator_MAXIMUM; -using tflite::BuiltinOperator_MEAN; -using tflite::BuiltinOperator_MINIMUM; -using tflite::BuiltinOperator_MIRROR_PAD; -using tflite::BuiltinOperator_MUL; -using tflite::BuiltinOperator_MULTINOMIAL; -using tflite::BuiltinOperator_NEG; -using tflite::BuiltinOperator_NON_MAX_SUPPRESSION_V4; -using tflite::BuiltinOperator_NON_MAX_SUPPRESSION_V5; -using tflite::BuiltinOperator_NOT_EQUAL; -using tflite::BuiltinOperator_ONE_HOT; -using tflite::BuiltinOperator_PACK; -using tflite::BuiltinOperator_PAD; -using tflite::BuiltinOperator_PADV2; -using tflite::BuiltinOperator_POW; -using tflite::BuiltinOperator_PRELU; -using tflite::BuiltinOperator_QUANTIZE; -using tflite::BuiltinOperator_RANDOM_STANDARD_NORMAL; -using tflite::BuiltinOperator_RANDOM_UNIFORM; -using tflite::BuiltinOperator_RANGE; -using tflite::BuiltinOperator_RANK; -using tflite::BuiltinOperator_READ_VARIABLE; -using tflite::BuiltinOperator_REAL; -using tflite::BuiltinOperator_REDUCE_ALL; -using tflite::BuiltinOperator_REDUCE_ANY; -using tflite::BuiltinOperator_REDUCE_MAX; -using tflite::BuiltinOperator_REDUCE_MIN; -using tflite::BuiltinOperator_REDUCE_PROD; -using tflite::BuiltinOperator_REDUCE_WINDOW; -using tflite::BuiltinOperator_RELU; -using tflite::BuiltinOperator_RELU6; -using tflite::BuiltinOperator_RELU_0_TO_1; -using tflite::BuiltinOperator_RELU_N1_TO_1; -using tflite::BuiltinOperator_RESHAPE; -using tflite::BuiltinOperator_RESIZE_BILINEAR; -using tflite::BuiltinOperator_RESIZE_NEAREST_NEIGHBOR; -using tflite::BuiltinOperator_REVERSE_SEQUENCE; -using tflite::BuiltinOperator_REVERSE_V2; -using tflite::BuiltinOperator_RFFT2D; -using tflite::BuiltinOperator_RIGHT_SHIFT; -using tflite::BuiltinOperator_RNN; -using tflite::BuiltinOperator_ROUND; -using tflite::BuiltinOperator_RSQRT; -using tflite::BuiltinOperator_SCATTER_ND; -using tflite::BuiltinOperator_SEGMENT_SUM; -using tflite::BuiltinOperator_SELECT; -using tflite::BuiltinOperator_SELECT_V2; -using tflite::BuiltinOperator_SHAPE; -using tflite::BuiltinOperator_SIN; -using tflite::BuiltinOperator_SKIP_GRAM; -using tflite::BuiltinOperator_SLICE; -using tflite::BuiltinOperator_SOFTMAX; -using tflite::BuiltinOperator_SPACE_TO_BATCH_ND; -using tflite::BuiltinOperator_SPACE_TO_DEPTH; -using tflite::BuiltinOperator_SPARSE_TO_DENSE; -using tflite::BuiltinOperator_SPLIT; -using tflite::BuiltinOperator_SPLIT_V; -using tflite::BuiltinOperator_SQRT; -using tflite::BuiltinOperator_SQUARE; -using tflite::BuiltinOperator_SQUARED_DIFFERENCE; -using tflite::BuiltinOperator_SQUEEZE; -using tflite::BuiltinOperator_STABLEHLO_ABS; -using tflite::BuiltinOperator_STABLEHLO_ADD; -using tflite::BuiltinOperator_STABLEHLO_AND; -using tflite::BuiltinOperator_STABLEHLO_BROADCAST_IN_DIM; -using tflite::BuiltinOperator_STABLEHLO_CLAMP; -using tflite::BuiltinOperator_STABLEHLO_COMPARE; -using tflite::BuiltinOperator_STABLEHLO_COMPOSITE; -using tflite::BuiltinOperator_STABLEHLO_CONCATENATE; -using tflite::BuiltinOperator_STABLEHLO_CONVERT; -using tflite::BuiltinOperator_STABLEHLO_CONVOLUTION; -using tflite::BuiltinOperator_STABLEHLO_COSINE; -using tflite::BuiltinOperator_STABLEHLO_CUSTOM_CALL; -using tflite::BuiltinOperator_STABLEHLO_DIVIDE; -using tflite::BuiltinOperator_STABLEHLO_DOT_GENERAL; -using tflite::BuiltinOperator_STABLEHLO_DYNAMIC_SLICE; -using tflite::BuiltinOperator_STABLEHLO_DYNAMIC_UPDATE_SLICE; -using tflite::BuiltinOperator_STABLEHLO_EXPONENTIAL; -using tflite::BuiltinOperator_STABLEHLO_FLOOR; -using tflite::BuiltinOperator_STABLEHLO_GATHER; -using tflite::BuiltinOperator_STABLEHLO_IOTA; -using tflite::BuiltinOperator_STABLEHLO_LOG; -using tflite::BuiltinOperator_STABLEHLO_LOGISTIC; -using tflite::BuiltinOperator_STABLEHLO_MAXIMUM; -using tflite::BuiltinOperator_STABLEHLO_MINIMUM; -using tflite::BuiltinOperator_STABLEHLO_MULTIPLY; -using tflite::BuiltinOperator_STABLEHLO_NEGATE; -using tflite::BuiltinOperator_STABLEHLO_OR; -using tflite::BuiltinOperator_STABLEHLO_PAD; -using tflite::BuiltinOperator_STABLEHLO_POWER; -using tflite::BuiltinOperator_STABLEHLO_REDUCE; -using tflite::BuiltinOperator_STABLEHLO_REDUCE_WINDOW; -using tflite::BuiltinOperator_STABLEHLO_REMAINDER; -using tflite::BuiltinOperator_STABLEHLO_RESHAPE; -using tflite::BuiltinOperator_STABLEHLO_RNG_BIT_GENERATOR; -using tflite::BuiltinOperator_STABLEHLO_RSQRT; -using tflite::BuiltinOperator_STABLEHLO_SCATTER; -using tflite::BuiltinOperator_STABLEHLO_SELECT; -using tflite::BuiltinOperator_STABLEHLO_SLICE; -using tflite::BuiltinOperator_STABLEHLO_SORT; -using tflite::BuiltinOperator_STABLEHLO_SUBTRACT; -using tflite::BuiltinOperator_STABLEHLO_TANH; -using tflite::BuiltinOperator_STABLEHLO_TRANSPOSE; -using tflite::BuiltinOperator_STABLEHLO_WHILE; -using tflite::BuiltinOperator_STRIDED_SLICE; -using tflite::BuiltinOperator_SUB; -using tflite::BuiltinOperator_SUM; -using tflite::BuiltinOperator_SVDF; -using tflite::BuiltinOperator_TANH; -using tflite::BuiltinOperator_TILE; -using tflite::BuiltinOperator_TOPK_V2; -using tflite::BuiltinOperator_TRANSPOSE; -using tflite::BuiltinOperator_TRANSPOSE_CONV; -using tflite::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM; -using tflite::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN; -using tflite::BuiltinOperator_UNIQUE; -using tflite::BuiltinOperator_UNPACK; -using tflite::BuiltinOperator_UNSORTED_SEGMENT_MAX; -using tflite::BuiltinOperator_UNSORTED_SEGMENT_MIN; -using tflite::BuiltinOperator_UNSORTED_SEGMENT_PROD; -using tflite::BuiltinOperator_VAR_HANDLE; -using tflite::BuiltinOperator_WHERE; -using tflite::BuiltinOperator_WHILE; -using tflite::BuiltinOperator_ZEROS_LIKE; -using tflite::CombinerType; -using tflite::CombinerType_MEAN; -using tflite::CombinerType_SQRTN; -using tflite::CombinerType_SUM; -using tflite::LSHProjectionType; -using tflite::LSHProjectionType_DENSE; -using tflite::LSHProjectionType_SPARSE; -using tflite::MirrorPadMode; -using tflite::MirrorPadMode_REFLECT; -using tflite::MirrorPadMode_SYMMETRIC; -using tflite::Operator; -using tflite::Padding; -using tflite::Padding_SAME; -using tflite::Padding_VALID; -using tflite::ReduceWindowFunction_ADD; -using tflite::ReduceWindowFunction_ALL; -using tflite::ReduceWindowFunction_ANY; -using tflite::ReduceWindowFunction_MAXIMUM; -using tflite::ReduceWindowFunction_MINIMUM; -using tflite::ReduceWindowFunction_MUL; -using tflite::ReduceWindowFunction_UNSUPPORTED; -using tflite::RngAlgorithm; -using tflite::RngAlgorithm_DEFAULT; -using tflite::RngAlgorithm_PHILOX; -using tflite::RngAlgorithm_THREEFRY; -using tflite::TensorType; -using tflite::TensorType_BFLOAT16; -using tflite::TensorType_BOOL; -using tflite::TensorType_COMPLEX128; -using tflite::TensorType_COMPLEX64; -using tflite::TensorType_FLOAT16; -using tflite::TensorType_FLOAT32; -using tflite::TensorType_FLOAT64; -using tflite::TensorType_INT16; -using tflite::TensorType_INT32; -using tflite::TensorType_INT4; -using tflite::TensorType_INT64; -using tflite::TensorType_INT8; -using tflite::TensorType_RESOURCE; -using tflite::TensorType_STRING; -using tflite::TensorType_UINT16; -using tflite::TensorType_UINT32; -using tflite::TensorType_UINT64; -using tflite::TensorType_UINT8; -using tflite::TensorType_VARIANT; -; -using tflite::AddOptions; -using tflite::ArgMaxOptions; -using tflite::ArgMinOptions; -using tflite::BuiltinOperator_ATAN2; -using tflite::BuiltinOperator_BITCAST; -using tflite::BuiltinOperator_CONCAT_EMBEDDINGS; -using tflite::BuiltinOperator_DILATE; -using tflite::BuiltinOperator_PLACEHOLDER_FOR_GREATER_OP_CODES; -using tflite::BuiltinOperator_SIGN; -using tflite::BuiltinOperator_STABLEHLO_CBRT; -using tflite::BuiltinOperator_STABLEHLO_SHIFT_LEFT; -using tflite::BuiltinOperator_UNSORTED_SEGMENT_SUM; -using tflite::BuiltinOperator_WHERE; -using tflite::CallOnceOptions; -using tflite::ConcatenationOptions; -using tflite::Conv2DOptions; -using tflite::DepthwiseConv2DOptions; -using tflite::FullyConnectedOptions; -using tflite::FullyConnectedOptionsWeightsFormat_DEFAULT; -using tflite::FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8; -using tflite::IfOptions; -using tflite::L2NormOptions; -using tflite::LSTMKernelType_BASIC; -using tflite::LSTMKernelType_FULL; -using tflite::MirrorPadOptions; -using tflite::MulOptions; -using tflite::PackOptions; -using tflite::Pool2DOptions; -using tflite::ReducerOptions; -using tflite::ReshapeOptions; -using tflite::ResizeBilinearOptions; -using tflite::ResizeNearestNeighborOptions; -using tflite::ShapeOptions; -using tflite::SoftmaxOptions; -using tflite::SplitOptions; -using tflite::SplitVOptions; -using tflite::SqueezeOptions; -using tflite::StableHLOCompositeOptions; -using tflite::StablehloGatherOptions; -using tflite::StablehloPadOptions; -using tflite::StablehloReduceWindowOptions; -using tflite::StablehloRngBitGeneratorOptions; -using tflite::StablehloScatterOptions; -using tflite::StridedSliceOptions; -using tflite::SubOptions; -using tflite::SVDFOptions; -using tflite::TransposeConvOptions; -using tflite::UnpackOptions; -using tflite::VarHandleOptions; -using tflite::WhileOptions; - -namespace { - -// Utility class for safely allocating POD data. This is useful for avoiding -// leaks in cases where op params are allocated but fail to propagate to the -// parsed op data (e.g., when model parameters are invalid). -class SafeBuiltinDataAllocator { - public: - class BuiltinDataDeleter { - public: - explicit BuiltinDataDeleter(BuiltinDataAllocator* allocator) - : allocator_(allocator) {} - - void operator()(void* data) { allocator_->Deallocate(data); } - - private: - BuiltinDataAllocator* allocator_; - }; - - template - using BuiltinDataPtr = std::unique_ptr; - - explicit SafeBuiltinDataAllocator(BuiltinDataAllocator* allocator) - : allocator_(allocator) {} - - template - BuiltinDataPtr Allocate() { - return BuiltinDataPtr(allocator_->AllocatePOD(), - BuiltinDataDeleter(allocator_)); - } - - private: - BuiltinDataAllocator* allocator_; -}; - -// All the Parse functions take some pointers as params and this function has -// the common DCHECKs to catch if any of those are nullptr. -void CheckParsePointerParams(const Operator* op, - BuiltinDataAllocator* allocator, - void** builtin_data) { - TFLITE_DCHECK(op != nullptr); - TFLITE_DCHECK(allocator != nullptr); - TFLITE_DCHECK(builtin_data != nullptr); -} - -// Copies the contents from the flatbuffer int vector `flatbuffer` into the -// int array `buffer`. `flat_vector` and `buffer` represent the same -// configuration operation for a given operation. -template -static absl::Status FlatBufferIntVectorToArray( - int max_size_of_buffer, const flatbuffers::Vector* flat_vector, - DataType* buffer, const char* op_name) { - if (!flat_vector) { - auto error_message = absl::StrFormat( - "Input array not provided for operation '%s'.\n", op_name); - LOG(ERROR) << error_message; - return absl::InvalidArgumentError(error_message); - } else { - size_t num_dimensions = flat_vector->size(); - if (num_dimensions > max_size_of_buffer / sizeof(DataType)) { - auto error_message = absl::StrFormat( - "Found too many dimensions in the input array of operation '%s'.\n", - op_name); - LOG(ERROR) << error_message; - return absl::InvalidArgumentError(error_message); - } else { - for (size_t i = 0; i < num_dimensions; ++i) { - buffer[i] = flat_vector->Get(i); - } - } - } - return OkStatus(); -} - -// Converts the flatbuffer activation to what is used at runtime. -TfLiteFusedActivation ConvertActivation(ActivationFunctionType activation) { - switch (activation) { - case ActivationFunctionType_NONE: - return kTfLiteActNone; - case ActivationFunctionType_RELU: - return kTfLiteActRelu; - case ActivationFunctionType_RELU_N1_TO_1: - return kTfLiteActReluN1To1; - case ActivationFunctionType_RELU6: - return kTfLiteActRelu6; - case ActivationFunctionType_TANH: - return kTfLiteActTanh; - case ActivationFunctionType_SIGN_BIT: - return kTfLiteActSignBit; - } - return kTfLiteActNone; -} - -TfLitePadding ConvertPadding(Padding padding) { - switch (padding) { - case Padding_SAME: - return kTfLitePaddingSame; - case Padding_VALID: - return kTfLitePaddingValid; - } - return kTfLitePaddingUnknown; -} - -// Converts the flatbuffer mirror padding enum to what is used at runtime. -TfLiteMirrorPaddingMode ConvertMirrorPadding(MirrorPadMode padding) { - switch (padding) { - case MirrorPadMode_REFLECT: - return kTfLiteMirrorPaddingReflect; - case MirrorPadMode_SYMMETRIC: - return kTfLiteMirrorPaddingSymmetric; - } - return kTfLiteMirrorPaddingUnknown; -} - -TfLiteRngAlgorithm ConvertRngAlgorithm(RngAlgorithm algorithm) { - switch (algorithm) { - case RngAlgorithm_THREEFRY: - return kTfLiteRngAlgorithmThreefry; - case RngAlgorithm_PHILOX: - return kTfLiteRngAlgorithmPhilox; - case RngAlgorithm_DEFAULT: - return kTfLiteRngAlgorithmDefault; - } - return kTfLiteRngAlgorithmUnknown; -} - -#ifndef TF_LITE_STATIC_MEMORY -absl::Status ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type, - BuiltinDataAllocator* allocator, - void** builtin_data) { - auto parseLSHProjectionType = [](LSHProjectionType type) { - switch (type) { - case LSHProjectionType_SPARSE: - return kTfLiteLshProjectionSparse; - case LSHProjectionType_DENSE: - return kTfLiteLshProjectionDense; - default: - return kTfLiteLshProjectionUnknown; - } - }; - auto parseCombinerType = [](CombinerType type) { - switch (type) { - case CombinerType_MEAN: - return kTfLiteCombinerTypeMean; - case CombinerType_SQRTN: - return kTfLiteCombinerTypeSqrtn; - case CombinerType_SUM: - default: - return kTfLiteCombinerTypeSum; - } - }; - - SafeBuiltinDataAllocator safe_allocator(allocator); - *builtin_data = nullptr; - switch (op_type) { - case BuiltinOperator_ABS: { - return ParseAbs(op, allocator, builtin_data); - } - - case BuiltinOperator_ADD: { - return ParseAdd(op, allocator, builtin_data); - } - - case BuiltinOperator_ADD_N: { - return ParseAddN(op, allocator, builtin_data); - } - - case BuiltinOperator_ARG_MAX: { - return ParseArgMax(op, allocator, builtin_data); - } - - case BuiltinOperator_ARG_MIN: { - return ParseArgMin(op, allocator, builtin_data); - } - - case BuiltinOperator_ASSIGN_VARIABLE: { - return ParseAssignVariable(op, allocator, builtin_data); - } - - case BuiltinOperator_AVERAGE_POOL_2D: { - return ParsePool(op, allocator, builtin_data); - } - - case BuiltinOperator_BATCH_MATMUL: { - return ParseBatchMatMul(op, allocator, builtin_data); - } - - case BuiltinOperator_BATCH_TO_SPACE_ND: { - return ParseBatchToSpaceNd(op, allocator, builtin_data); - } - - case BuiltinOperator_BROADCAST_ARGS: { - return ParseBroadcastArgs(op, allocator, builtin_data); - } - - case BuiltinOperator_BROADCAST_TO: { - return ParseBroadcastTo(op, allocator, builtin_data); - } - - case BuiltinOperator_CALL_ONCE: { - return ParseCallOnce(op, allocator, builtin_data); - } - - case BuiltinOperator_CEIL: { - return ParseCeil(op, allocator, builtin_data); - } - - case BuiltinOperator_CONCATENATION: { - return ParseConcatenation(op, allocator, builtin_data); - } - - case BuiltinOperator_CONV_2D: { - return ParseConv2D(op, allocator, builtin_data); - } - - case BuiltinOperator_CUMSUM: { - return ParseCumsum(op, allocator, builtin_data); - } - - case BuiltinOperator_DEPTH_TO_SPACE: { - return ParseDepthToSpace(op, allocator, builtin_data); - } - - case BuiltinOperator_DEPTHWISE_CONV_2D: { - return ParseDepthwiseConv2D(op, allocator, builtin_data); - } - - case BuiltinOperator_DEQUANTIZE: { - return ParseDequantize(op, allocator, builtin_data); - } - - case BuiltinOperator_DIV: { - return ParseDiv(op, allocator, builtin_data); - } - - case BuiltinOperator_ELU: { - return ParseElu(op, allocator, builtin_data); - } - - case BuiltinOperator_EMBEDDING_LOOKUP: { - return ParseEmbeddingLookup(op, allocator, builtin_data); - } - - case BuiltinOperator_EXP: { - return ParseExp(op, allocator, builtin_data); - } - - case BuiltinOperator_EXPAND_DIMS: { - return ParseExpandDims(op, allocator, builtin_data); - } - - case BuiltinOperator_FILL: { - return ParseFill(op, allocator, builtin_data); - } - - case BuiltinOperator_FLOOR: { - return ParseFloor(op, allocator, builtin_data); - } - - case BuiltinOperator_FLOOR_DIV: { - return ParseFloorDiv(op, allocator, builtin_data); - } - - case BuiltinOperator_FLOOR_MOD: { - return ParseFloorMod(op, allocator, builtin_data); - } - - case BuiltinOperator_FULLY_CONNECTED: { - return ParseFullyConnected(op, allocator, builtin_data); - } - - case BuiltinOperator_GATHER_ND: { - return ParseGatherNd(op, allocator, builtin_data); - } - - case BuiltinOperator_GREATER: { - return ParseGreater(op, allocator, builtin_data); - } - - case BuiltinOperator_GREATER_EQUAL: { - return ParseGreaterEqual(op, allocator, builtin_data); - } - - case BuiltinOperator_HARD_SWISH: { - return ParseHardSwish(op, allocator, builtin_data); - } - - case BuiltinOperator_L2_NORMALIZATION: { - return ParseL2Normalization(op, allocator, builtin_data); - } - - case BuiltinOperator_L2_POOL_2D: { - return ParsePool(op, allocator, builtin_data); - } - - case BuiltinOperator_LEAKY_RELU: { - return ParseLeakyRelu(op, allocator, builtin_data); - } - - case BuiltinOperator_LESS: { - return ParseLess(op, allocator, builtin_data); - } - - case BuiltinOperator_LESS_EQUAL: { - return ParseLessEqual(op, allocator, builtin_data); - } - - case BuiltinOperator_LOG: { - return ParseLog(op, allocator, builtin_data); - } - - case BuiltinOperator_LOGICAL_AND: { - return ParseLogicalAnd(op, allocator, builtin_data); - } - - case BuiltinOperator_LOGICAL_NOT: { - return ParseLogicalNot(op, allocator, builtin_data); - } - - case BuiltinOperator_LOGICAL_OR: { - return ParseLogicalOr(op, allocator, builtin_data); - } - - case BuiltinOperator_LOGISTIC: { - return ParseLogistic(op, allocator, builtin_data); - } - - case BuiltinOperator_LOG_SOFTMAX: { - return ParseLogSoftmax(op, allocator, builtin_data); - } - - case BuiltinOperator_LSTM: { - return ParseLSTM(op, allocator, builtin_data); - } - - case BuiltinOperator_MAXIMUM: { - return ParseMaximum(op, allocator, builtin_data); - } - - case BuiltinOperator_MAX_POOL_2D: { - return ParsePool(op, allocator, builtin_data); - } - - case BuiltinOperator_MIRROR_PAD: { - return ParseMirrorPad(op, allocator, builtin_data); - } - - case BuiltinOperator_MEAN: { - return ParseReducer(op, allocator, builtin_data); - } - - case BuiltinOperator_MINIMUM: { - return ParseMinimum(op, allocator, builtin_data); - } - - case BuiltinOperator_MUL: { - return ParseMul(op, allocator, builtin_data); - } - - case BuiltinOperator_NEG: { - return ParseNeg(op, allocator, builtin_data); - } - - case BuiltinOperator_NOT_EQUAL: { - return ParseNotEqual(op, allocator, builtin_data); - } - - case BuiltinOperator_PACK: { - return ParsePack(op, allocator, builtin_data); - } - - case BuiltinOperator_PAD: { - return ParsePad(op, allocator, builtin_data); - } - - case BuiltinOperator_PADV2: { - return ParsePadV2(op, allocator, builtin_data); - } - - case BuiltinOperator_POW: { - return ParsePow(op, allocator, builtin_data); - } - - case BuiltinOperator_PRELU: { - return ParsePrelu(op, allocator, builtin_data); - } - - case BuiltinOperator_QUANTIZE: { - return ParseQuantize(op, allocator, builtin_data); - } - - case BuiltinOperator_READ_VARIABLE: { - return ParseReadVariable(op, allocator, builtin_data); - } - - case BuiltinOperator_REDUCE_ANY: { - return ParseReducer(op, allocator, builtin_data); - } - - case BuiltinOperator_REDUCE_ALL: { - return ParseReducer(op, allocator, builtin_data); - } - - case BuiltinOperator_REDUCE_MAX: { - return ParseReducer(op, allocator, builtin_data); - } - - case BuiltinOperator_REDUCE_MIN: { - return ParseReducer(op, allocator, builtin_data); - } - - case BuiltinOperator_REDUCE_PROD: { - return ParseReducer(op, allocator, builtin_data); - } - - case BuiltinOperator_RELU: { - return ParseRelu(op, allocator, builtin_data); - } - - case BuiltinOperator_RELU6: { - return ParseRelu6(op, allocator, builtin_data); - } - - case BuiltinOperator_RESHAPE: { - return ParseReshape(op, allocator, builtin_data); - } - - case BuiltinOperator_RESIZE_BILINEAR: { - return ParseResizeBilinear(op, allocator, builtin_data); - } - - case BuiltinOperator_RESIZE_NEAREST_NEIGHBOR: { - return ParseResizeNearestNeighbor(op, allocator, builtin_data); - } - - case BuiltinOperator_ROUND: { - return ParseRound(op, allocator, builtin_data); - } - - case BuiltinOperator_RSQRT: { - return ParseRsqrt(op, allocator, builtin_data); - } - - case BuiltinOperator_SELECT_V2: { - return ParseSelectV2(op, allocator, builtin_data); - } - - case BuiltinOperator_SHAPE: { - return ParseShape(op, allocator, builtin_data); - } - - case BuiltinOperator_SIN: { - return ParseSin(op, allocator, builtin_data); - } - - case BuiltinOperator_SOFTMAX: { - return ParseSoftmax(op, allocator, builtin_data); - } - - case BuiltinOperator_SPACE_TO_BATCH_ND: { - return ParseSpaceToBatchNd(op, allocator, builtin_data); - } - - case BuiltinOperator_SPACE_TO_DEPTH: { - return ParseSpaceToDepth(op, allocator, builtin_data); - } - - case BuiltinOperator_SPLIT: { - return ParseSplit(op, allocator, builtin_data); - } - - case BuiltinOperator_SPLIT_V: { - return ParseSplitV(op, allocator, builtin_data); - } - - case BuiltinOperator_SQRT: { - return ParseSqrt(op, allocator, builtin_data); - } - - case BuiltinOperator_SQUARE: { - return ParseSquare(op, allocator, builtin_data); - } - - case BuiltinOperator_SQUARED_DIFFERENCE: { - return ParseSquaredDifference(op, allocator, builtin_data); - } - - case BuiltinOperator_SQUEEZE: { - return ParseSqueeze(op, allocator, builtin_data); - } - - case BuiltinOperator_STRIDED_SLICE: { - return ParseStridedSlice(op, allocator, builtin_data); - } - - case BuiltinOperator_SUB: { - return ParseSub(op, allocator, builtin_data); - } - - case BuiltinOperator_SUM: { - return ParseReducer(op, allocator, builtin_data); - } - - case BuiltinOperator_SVDF: { - return ParseSvdf(op, allocator, builtin_data); - } - - case BuiltinOperator_TANH: { - return ParseTanh(op, allocator, builtin_data); - } - - case BuiltinOperator_TRANSPOSE_CONV: { - return ParseTransposeConv(op, allocator, builtin_data); - } - - case BuiltinOperator_UNPACK: { - return ParseUnpack(op, allocator, builtin_data); - } - - case BuiltinOperator_VAR_HANDLE: { - return ParseVarHandle(op, allocator, builtin_data); - } - - case BuiltinOperator_ZEROS_LIKE: { - return ParseZerosLike(op, allocator, builtin_data); - } - - case BuiltinOperator_BITWISE_XOR: { - return ParseBitwiseXor(op, allocator, builtin_data); - } - - case BuiltinOperator_RIGHT_SHIFT: { - return ParseRightShift(op, allocator, builtin_data); - } - - case BuiltinOperator_CAST: { - return ParseCast(op, allocator, builtin_data); - } - case BuiltinOperator_LSH_PROJECTION: { - auto params = safe_allocator.Allocate(); - TFL_MIGRATION_ENSURE(params != nullptr); - if (const auto* lshParams = - op->builtin_options_as_LSHProjectionOptions()) { - params->type = parseLSHProjectionType(lshParams->type()); - } - *builtin_data = params.release(); - return OkStatus(); - } - case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN: { - auto params = safe_allocator.Allocate(); - TFL_MIGRATION_ENSURE(params != nullptr); - if (const auto* sequence_rnn_params = - op->builtin_options_as_SequenceRNNOptions()) { - params->activation = - ConvertActivation(sequence_rnn_params->fused_activation_function()); - params->time_major = sequence_rnn_params->time_major(); - params->asymmetric_quantize_inputs = - sequence_rnn_params->asymmetric_quantize_inputs(); - } - *builtin_data = params.release(); - return OkStatus(); - } - case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN: { - auto params = - safe_allocator.Allocate(); - TFL_MIGRATION_ENSURE(params != nullptr); - if (const auto* bidi_sequence_rnn_params = - op->builtin_options_as_BidirectionalSequenceRNNOptions()) { - params->activation = ConvertActivation( - bidi_sequence_rnn_params->fused_activation_function()); - params->time_major = bidi_sequence_rnn_params->time_major(); - params->merge_outputs = bidi_sequence_rnn_params->merge_outputs(); - params->asymmetric_quantize_inputs = - bidi_sequence_rnn_params->asymmetric_quantize_inputs(); - } - *builtin_data = params.release(); - return OkStatus(); - } - case BuiltinOperator_RNN: { - auto params = safe_allocator.Allocate(); - TFL_MIGRATION_ENSURE(params != nullptr); - if (const auto* rnn_params = op->builtin_options_as_RNNOptions()) { - params->activation = - ConvertActivation(rnn_params->fused_activation_function()); - params->asymmetric_quantize_inputs = - rnn_params->asymmetric_quantize_inputs(); - } - *builtin_data = params.release(); - return OkStatus(); - } - case BuiltinOperator_EMBEDDING_LOOKUP_SPARSE: { - auto params = - safe_allocator.Allocate(); - TFL_MIGRATION_ENSURE(params != nullptr); - if (const auto* embedding_params = - op->builtin_options_as_EmbeddingLookupSparseOptions()) { - params->combiner = parseCombinerType(embedding_params->combiner()); - } - *builtin_data = params.release(); - return OkStatus(); - } - - case BuiltinOperator_HASHTABLE_LOOKUP: - // no-op. - return OkStatus(); - - case BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION: { - auto params = safe_allocator.Allocate(); - TFL_MIGRATION_ENSURE(params != nullptr); - if (const auto* schema_params = - op->builtin_options_as_LocalResponseNormalizationOptions()) { - params->radius = schema_params->radius(); - params->bias = schema_params->bias(); - params->alpha = schema_params->alpha(); - params->beta = schema_params->beta(); - } - *builtin_data = params.release(); - return OkStatus(); - } - case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM: { - return ParseUnidirectionalSequenceLSTM(op, allocator, builtin_data); - } - case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM: { - auto params = - safe_allocator.Allocate(); - TFL_MIGRATION_ENSURE(params != nullptr); - if (const auto* bidi_lstm_params = - op->builtin_options_as_BidirectionalSequenceLSTMOptions()) { - params->activation = - ConvertActivation(bidi_lstm_params->fused_activation_function()); - params->cell_clip = bidi_lstm_params->cell_clip(); - params->proj_clip = bidi_lstm_params->proj_clip(); - params->merge_outputs = bidi_lstm_params->merge_outputs(); - params->time_major = bidi_lstm_params->time_major(); - params->asymmetric_quantize_inputs = - bidi_lstm_params->asymmetric_quantize_inputs(); - } - *builtin_data = params.release(); - return OkStatus(); - } - case BuiltinOperator_SKIP_GRAM: { - auto params = safe_allocator.Allocate(); - TFL_MIGRATION_ENSURE(params != nullptr); - if (const auto* skip_gram_params = - op->builtin_options_as_SkipGramOptions()) { - params->ngram_size = skip_gram_params->ngram_size(); - params->max_skip_size = skip_gram_params->max_skip_size(); - params->include_all_ngrams = skip_gram_params->include_all_ngrams(); - } - *builtin_data = params.release(); - return OkStatus(); - } - - case BuiltinOperator_GATHER: { - return ParseGather(op, allocator, builtin_data); - } - case BuiltinOperator_SPARSE_TO_DENSE: { - auto params = safe_allocator.Allocate(); - TFL_MIGRATION_ENSURE(params != nullptr); - if (const auto* sparse_to_dense_params = - op->builtin_options_as_SparseToDenseOptions()) { - params->validate_indices = sparse_to_dense_params->validate_indices(); - } - *builtin_data = params.release(); - return OkStatus(); - } - case BuiltinOperator_DELEGATE: { - auto error_msg = "DELEGATE op shouldn't exist in model."; - LOG(ERROR) << error_msg; - return absl::InvalidArgumentError(error_msg); - } - case BuiltinOperator_FAKE_QUANT: { - auto params = safe_allocator.Allocate(); - TFL_MIGRATION_ENSURE(params != nullptr); - if (const auto* schema_params = - op->builtin_options_as_FakeQuantOptions()) { - params->min = schema_params->min(); - params->max = schema_params->max(); - params->num_bits = schema_params->num_bits(); - params->narrow_range = schema_params->narrow_range(); - } - *builtin_data = params.release(); - return OkStatus(); - } - case BuiltinOperator_ONE_HOT: { - auto params = safe_allocator.Allocate(); - TFL_MIGRATION_ENSURE(params != nullptr); - if (const auto* schema_params = op->builtin_options_as_OneHotOptions()) { - params->axis = schema_params->axis(); - } - *builtin_data = params.release(); - return OkStatus(); - } - case BuiltinOperator_UNIQUE: { - auto params = safe_allocator.Allocate(); - TFL_MIGRATION_ENSURE(params != nullptr); - const auto* unique_params = op->builtin_options_as_UniqueOptions(); - if (unique_params != nullptr) { - params->index_out_type = - unique_params->idx_out_type() == tflite::TensorType_INT64 - ? TfLiteType::kTfLiteInt64 - : TfLiteType::kTfLiteInt32; - } - *builtin_data = params.release(); - return OkStatus(); - } - case BuiltinOperator_REVERSE_SEQUENCE: { - auto params = safe_allocator.Allocate(); - TFL_MIGRATION_ENSURE(params != nullptr); - if (const auto* reverse_seq_params = - op->builtin_options_as_ReverseSequenceOptions()) { - params->seq_dim = reverse_seq_params->seq_dim(); - params->batch_dim = reverse_seq_params->batch_dim(); - } - *builtin_data = params.release(); - return OkStatus(); - } - case BuiltinOperator_IF: { - auto params = safe_allocator.Allocate(); - TFL_MIGRATION_ENSURE(params != nullptr); - if (const auto* if_params = op->builtin_options_as_IfOptions()) { - params->then_subgraph_index = if_params->then_subgraph_index(); - params->else_subgraph_index = if_params->else_subgraph_index(); - } - *builtin_data = params.release(); - return OkStatus(); - } - case BuiltinOperator_WHILE: { - auto params = safe_allocator.Allocate(); - TFL_MIGRATION_ENSURE(params != nullptr); - if (const auto* while_params = op->builtin_options_as_WhileOptions()) { - params->cond_subgraph_index = while_params->cond_subgraph_index(); - params->body_subgraph_index = while_params->body_subgraph_index(); - } - *builtin_data = params.release(); - return OkStatus(); - } - case BuiltinOperator_CONV_3D: - case BuiltinOperator_CONV_3D_TRANSPOSE: { - auto params = safe_allocator.Allocate(); - TFL_MIGRATION_ENSURE(params != nullptr); - if (const auto* conv3d_params = op->builtin_options_as_Conv3DOptions()) { - params->padding = ConvertPadding(conv3d_params->padding()); - params->activation = - ConvertActivation(conv3d_params->fused_activation_function()); - params->stride_depth = conv3d_params->stride_d(); - params->stride_height = conv3d_params->stride_h(); - params->stride_width = conv3d_params->stride_w(); - params->dilation_depth_factor = conv3d_params->dilation_d_factor(); - params->dilation_height_factor = conv3d_params->dilation_h_factor(); - params->dilation_width_factor = conv3d_params->dilation_w_factor(); - } - *builtin_data = params.release(); - return OkStatus(); - } - case BuiltinOperator_HASHTABLE: { - auto params = safe_allocator.Allocate(); - TFL_MIGRATION_ENSURE(params != nullptr); - if (const auto* hashtable_params = - op->builtin_options_as_HashtableOptions()) { - params->table_id = hashtable_params->table_id(); - TFL_MIGRATION_ENSURE_STATUS(ConvertTensorType( - hashtable_params->key_dtype(), ¶ms->key_dtype)); - TFL_MIGRATION_ENSURE_STATUS(ConvertTensorType( - hashtable_params->value_dtype(), ¶ms->value_dtype)); - } - *builtin_data = params.release(); - return OkStatus(); - } - case BuiltinOperator_MULTINOMIAL: { - auto params = safe_allocator.Allocate(); - TFL_MIGRATION_ENSURE(params != nullptr); - if (const auto* multinomial_params = - op->builtin_options_as_RandomOptions()) { - params->seed = multinomial_params->seed(); - params->seed2 = multinomial_params->seed2(); - } - *builtin_data = params.release(); - return OkStatus(); - } - case BuiltinOperator_RANDOM_STANDARD_NORMAL: { - auto params = safe_allocator.Allocate(); - TFL_MIGRATION_ENSURE(params != nullptr); - if (const auto* random_std_normal_params = - op->builtin_options_as_RandomOptions()) { - params->seed = random_std_normal_params->seed(); - params->seed2 = random_std_normal_params->seed2(); - } - *builtin_data = params.release(); - return OkStatus(); - } - case BuiltinOperator_BUCKETIZE: { - auto params = safe_allocator.Allocate(); - TFL_MIGRATION_ENSURE(params != nullptr); - if (const auto* bucketize_params = - op->builtin_options_as_BucketizeOptions()) { - const flatbuffers::Vector* boundaries = - bucketize_params->boundaries(); - if (boundaries == nullptr) { - auto error_message = - "boundaries array not provided for operation 'bucketize'.\n"; - LOG(ERROR) << error_message; - return absl::InvalidArgumentError(error_message); - } - params->num_boundaries = boundaries->size(); - if (boundaries->data() == nullptr) { - auto error_message = - "boundaries.data() returned nullptr for " - "operation 'bucketize'.\n"; - LOG(ERROR) << error_message; - return absl::InvalidArgumentError(error_message); - } - params->boundaries = boundaries->data(); - } - *builtin_data = params.release(); - return OkStatus(); - } - case BuiltinOperator_RANDOM_UNIFORM: { - auto params = safe_allocator.Allocate(); - TFL_MIGRATION_ENSURE(params != nullptr); - if (const auto* random_uniform_params = - op->builtin_options_as_RandomOptions()) { - params->seed = random_uniform_params->seed(); - params->seed2 = random_uniform_params->seed2(); - } - *builtin_data = params.release(); - return OkStatus(); - } - case BuiltinOperator_GELU: { - auto params = safe_allocator.Allocate(); - TFL_MIGRATION_ENSURE(params != nullptr); - if (const auto* gelu_params = op->builtin_options_as_GeluOptions()) { - params->approximate = gelu_params->approximate(); - } - *builtin_data = params.release(); - return OkStatus(); - } - case BuiltinOperator_STABLEHLO_SCATTER: { - return ParseStablehloScatter(op, allocator, builtin_data); - } - case BuiltinOperator_STABLEHLO_RNG_BIT_GENERATOR: { - return ParseStablehloRngBitGenerator(op, allocator, builtin_data); - } - case BuiltinOperator_STABLEHLO_GATHER: { - return ParseStablehloGather(op, allocator, builtin_data); - } - case BuiltinOperator_STABLEHLO_REDUCE_WINDOW: { - return ParseStablehloReduceWindow(op, allocator, builtin_data); - } - case BuiltinOperator_REDUCE_WINDOW: { - auto params = safe_allocator.Allocate(); - TFL_MIGRATION_ENSURE(params != nullptr); - if (const auto* reduce_params = - op->builtin_options_2_as_ReduceWindowOptions()) { - switch (reduce_params->reduce_function()) { - case ReduceWindowFunction_ADD: - params->reduce_function = TfLiteReduceWindowFunctionAdd; - break; - case ReduceWindowFunction_MUL: - params->reduce_function = TfLiteReduceWindowFunctionMul; - break; - case ReduceWindowFunction_MINIMUM: - params->reduce_function = TfLiteReduceWindowFunctionMin; - break; - case ReduceWindowFunction_MAXIMUM: - params->reduce_function = TfLiteReduceWindowFunctionMax; - break; - case ReduceWindowFunction_ALL: - params->reduce_function = TfLiteReduceWindowFunctionAll; - break; - case ReduceWindowFunction_ANY: - params->reduce_function = TfLiteReduceWindowFunctionAny; - break; - case ReduceWindowFunction_UNSUPPORTED: - default: - return absl::InvalidArgumentError("Unsupported reduce function"); - } - } - *builtin_data = params.release(); - return OkStatus(); - } - case BuiltinOperator_STABLEHLO_PAD: { - return ParseStablehloPad(op, allocator, builtin_data); - } - case BuiltinOperator_STABLEHLO_COMPOSITE: { - return ParseStablehloComposite(op, allocator, builtin_data); - } - case BuiltinOperator_STABLEHLO_SHIFT_LEFT: { - return ParseStablehloShiftLeft(op, allocator, builtin_data); - } - // TODO: skip param parsing for now since ops below don't have kernels - case BuiltinOperator_STABLEHLO_SLICE: - case BuiltinOperator_STABLEHLO_BROADCAST_IN_DIM: - case BuiltinOperator_STABLEHLO_CONVOLUTION: - case BuiltinOperator_STABLEHLO_LOGISTIC: - case BuiltinOperator_STABLEHLO_ADD: - case BuiltinOperator_STABLEHLO_DIVIDE: - case BuiltinOperator_STABLEHLO_MULTIPLY: - case BuiltinOperator_STABLEHLO_MAXIMUM: - case BuiltinOperator_STABLEHLO_RESHAPE: - case BuiltinOperator_STABLEHLO_CLAMP: - case BuiltinOperator_STABLEHLO_CONCATENATE: - case BuiltinOperator_STABLEHLO_CUSTOM_CALL: - case BuiltinOperator_STABLEHLO_REDUCE: - case BuiltinOperator_STABLEHLO_ABS: - case BuiltinOperator_STABLEHLO_AND: - case BuiltinOperator_STABLEHLO_COSINE: - case BuiltinOperator_STABLEHLO_EXPONENTIAL: - case BuiltinOperator_STABLEHLO_FLOOR: - case BuiltinOperator_STABLEHLO_LOG: - case BuiltinOperator_STABLEHLO_MINIMUM: - case BuiltinOperator_STABLEHLO_NEGATE: - case BuiltinOperator_STABLEHLO_OR: - case BuiltinOperator_STABLEHLO_POWER: - case BuiltinOperator_STABLEHLO_REMAINDER: - case BuiltinOperator_STABLEHLO_RSQRT: - case BuiltinOperator_STABLEHLO_SELECT: - case BuiltinOperator_STABLEHLO_SUBTRACT: - case BuiltinOperator_STABLEHLO_TANH: - case BuiltinOperator_STABLEHLO_DYNAMIC_SLICE: - case BuiltinOperator_STABLEHLO_DYNAMIC_UPDATE_SLICE: - case BuiltinOperator_STABLEHLO_IOTA: - case BuiltinOperator_STABLEHLO_COMPARE: - case BuiltinOperator_STABLEHLO_CONVERT: - case BuiltinOperator_STABLEHLO_DOT_GENERAL: - case BuiltinOperator_STABLEHLO_SORT: - case BuiltinOperator_STABLEHLO_WHILE: - case BuiltinOperator_STABLEHLO_TRANSPOSE: - case BuiltinOperator_STABLEHLO_CBRT: - - // Below are the ops with no builtin_data structure. - // TODO(aselle): Implement call in BuiltinOptions, but nullptrs are - // ok for now, since there is no call implementation either. - case BuiltinOperator_CALL: - case BuiltinOperator_COMPLEX_ABS: - case BuiltinOperator_CONCAT_EMBEDDINGS: - case BuiltinOperator_COS: - case BuiltinOperator_CUSTOM: - case BuiltinOperator_DENSIFY: - case BuiltinOperator_DYNAMIC_UPDATE_SLICE: - case BuiltinOperator_EQUAL: - case BuiltinOperator_HASHTABLE_FIND: - case BuiltinOperator_HASHTABLE_IMPORT: - case BuiltinOperator_HASHTABLE_SIZE: - case BuiltinOperator_IMAG: - case BuiltinOperator_MATRIX_DIAG: - case BuiltinOperator_MATRIX_SET_DIAG: - case BuiltinOperator_NON_MAX_SUPPRESSION_V4: - case BuiltinOperator_NON_MAX_SUPPRESSION_V5: - case BuiltinOperator_RELU_N1_TO_1: - case BuiltinOperator_RELU_0_TO_1: - case BuiltinOperator_SCATTER_ND: - case BuiltinOperator_SELECT: - case BuiltinOperator_SLICE: - case BuiltinOperator_TILE: - case BuiltinOperator_TOPK_V2: - case BuiltinOperator_TRANSPOSE: - case BuiltinOperator_RANGE: - case BuiltinOperator_RANK: - case BuiltinOperator_REAL: - case BuiltinOperator_RFFT2D: - case BuiltinOperator_SEGMENT_SUM: - case BuiltinOperator_REVERSE_V2: - case BuiltinOperator_UNSORTED_SEGMENT_MAX: - case BuiltinOperator_UNSORTED_SEGMENT_MIN: - case BuiltinOperator_UNSORTED_SEGMENT_PROD: - case BuiltinOperator_UNSORTED_SEGMENT_SUM: - case BuiltinOperator_ATAN2: - case BuiltinOperator_SIGN: - case BuiltinOperator_BITCAST: - case BuiltinOperator_WHERE: - case BuiltinOperator_DILATE: - return OkStatus(); - case BuiltinOperator_PLACEHOLDER_FOR_GREATER_OP_CODES: - return absl::UnimplementedError("Unsupported op"); - } - return absl::UnimplementedError("Unsupported op"); -} // NOLINT[readability/fn_size] -#endif // !defined(TF_LITE_STATIC_MEMORY) -} // namespace - -absl::Status ConvertTensorType(TensorType tensor_type, TfLiteType* type) { - switch (tensor_type) { - case TensorType_FLOAT16: - *type = kTfLiteFloat16; - return OkStatus(); - case TensorType_BFLOAT16: - *type = kTfLiteBFloat16; - return OkStatus(); - case TensorType_FLOAT32: - *type = kTfLiteFloat32; - return OkStatus(); - case TensorType_FLOAT64: - *type = kTfLiteFloat64; - return OkStatus(); - case TensorType_INT16: - *type = kTfLiteInt16; - return OkStatus(); - case TensorType_UINT16: - *type = kTfLiteUInt16; - return OkStatus(); - case TensorType_INT32: - *type = kTfLiteInt32; - return OkStatus(); - case TensorType_UINT32: - *type = kTfLiteUInt32; - return OkStatus(); - case TensorType_UINT8: - *type = kTfLiteUInt8; - return OkStatus(); - case TensorType_INT8: - *type = kTfLiteInt8; - return OkStatus(); - case TensorType_INT64: - *type = kTfLiteInt64; - return OkStatus(); - case TensorType_UINT64: - *type = kTfLiteUInt64; - return OkStatus(); - case TensorType_STRING: - *type = kTfLiteString; - return OkStatus(); - case TensorType_BOOL: - *type = kTfLiteBool; - return OkStatus(); - case TensorType_COMPLEX64: - *type = kTfLiteComplex64; - return OkStatus(); - case TensorType_COMPLEX128: - *type = kTfLiteComplex128; - return OkStatus(); - case TensorType_RESOURCE: - *type = kTfLiteResource; - return OkStatus(); - case TensorType_VARIANT: - *type = kTfLiteVariant; - return OkStatus(); - case TensorType_INT4: - *type = kTfLiteInt4; - return OkStatus(); - default: - *type = kTfLiteNoType; - auto error_message = - absl::StrFormat("Unsupported data type %d in tensor", tensor_type); - LOG(ERROR) << error_message; - return absl::InvalidArgumentError(error_message); - } -} - -// We have this parse function instead of directly returning OkStatus() from the -// switch-case in ParseOpData because this function is used as part of the -// selective registration for the OpResolver implementation in micro. -absl::Status ParseAbs(const Operator*, BuiltinDataAllocator*, void**) { - return OkStatus(); -} - -absl::Status ParseAdd(const Operator* op, BuiltinDataAllocator* allocator, - void** builtin_data) { - CheckParsePointerParams(op, allocator, builtin_data); - - SafeBuiltinDataAllocator safe_allocator(allocator); - std::unique_ptr - params = safe_allocator.Allocate(); - TFL_MIGRATION_ENSURE(params != nullptr); - - const AddOptions* schema_params = op->builtin_options_as_AddOptions(); - - if (schema_params != nullptr) { - params->activation = - ConvertActivation(schema_params->fused_activation_function()); - params->pot_scale_int16 = schema_params->pot_scale_int16(); - } else { - // TODO(b/157480169): We should either return kTfLiteError or fill in some - // reasonable defaults in the params struct. We are not doing so until we - // better understand the ramifications of changing the legacy behavior. - } - - *builtin_data = params.release(); - return OkStatus(); -} - -absl::Status ParseAddN(const Operator* op, BuiltinDataAllocator* allocator, - void** builtin_data) { - return OkStatus(); -} - -absl::Status ParseArgMax(const Operator* op, BuiltinDataAllocator* allocator, - void** builtin_data) { - CheckParsePointerParams(op, allocator, builtin_data); - - SafeBuiltinDataAllocator safe_allocator(allocator); - std::unique_ptr - params = safe_allocator.Allocate(); - TFL_MIGRATION_ENSURE(params != nullptr); - - const ArgMaxOptions* schema_params = op->builtin_options_as_ArgMaxOptions(); - - if (schema_params != nullptr) { - TFL_MIGRATION_ENSURE_STATUS( - ConvertTensorType(schema_params->output_type(), ¶ms->output_type)); - } else { - // TODO(b/157480169): We should either return kTfLiteError or fill in some - // reasonable defaults in the params struct. We are not doing so until we - // better understand the ramifications of changing the legacy behavior. - } - - *builtin_data = params.release(); - return OkStatus(); -} - -absl::Status ParseArgMin(const Operator* op, BuiltinDataAllocator* allocator, - void** builtin_data) { - CheckParsePointerParams(op, allocator, builtin_data); - - SafeBuiltinDataAllocator safe_allocator(allocator); - std::unique_ptr - params = safe_allocator.Allocate(); - TFL_MIGRATION_ENSURE(params != nullptr); - - const ArgMinOptions* schema_params = op->builtin_options_as_ArgMinOptions(); - - if (schema_params != nullptr) { - TFL_MIGRATION_ENSURE_STATUS( - ConvertTensorType(schema_params->output_type(), ¶ms->output_type)); - } else { - // TODO(b/157480169): We should either return kTfLiteError or fill in some - // reasonable defaults in the params struct. We are not doing so until we - // better understand the ramifications of changing the legacy behavior. - } - - *builtin_data = params.release(); - return OkStatus(); -} - -// We have this parse function instead of directly returning OkStatus() from the -// switch-case in ParseOpData because this function is used as part of the -// selective registration for the OpResolver implementation in micro. -absl::Status ParseAssignVariable(const Operator*, BuiltinDataAllocator*, - void**) { - return OkStatus(); -} - -// We have this parse function instead of directly returning OkStatus() from the -// switch-case in ParseOpData because this function is used as part of the -// selective registration for the OpResolver implementation in micro. -absl::Status ParseBatchMatMul(const Operator* op, - BuiltinDataAllocator* allocator, - void** builtin_data) { - CheckParsePointerParams(op, allocator, builtin_data); - - SafeBuiltinDataAllocator safe_allocator(allocator); - auto params = safe_allocator.Allocate(); - TFL_MIGRATION_ENSURE(params != nullptr); - if (const auto* bmm_params = op->builtin_options_as_BatchMatMulOptions()) { - params->adj_x = bmm_params->adj_x(); - params->adj_y = bmm_params->adj_y(); - params->asymmetric_quantize_inputs = - bmm_params->asymmetric_quantize_inputs(); - } - *builtin_data = params.release(); - return OkStatus(); -} - -// We have this parse function instead of directly returning OkStatus() from the -// switch-case in ParseOpData because this function is used as part of the -// selective registration for the OpResolver implementation in micro. -absl::Status ParseBatchToSpaceNd(const Operator*, BuiltinDataAllocator*, - void**) { - return OkStatus(); -} - -// We have this parse function instead of directly returning OkStatus() from the -// switch-case in ParseOpData because this function is used as part of the -// selective registration for the OpResolver implementation in micro. -absl::Status ParseBroadcastArgs(const Operator*, BuiltinDataAllocator*, - void**) { - return OkStatus(); -} - -// We have this parse function instead of directly returning OkStatus() from the -// switch-case in ParseOpData because this function is used as part of the -// selective registration for the OpResolver implementation in micro. -absl::Status ParseBroadcastTo(const Operator*, BuiltinDataAllocator*, void**) { - return OkStatus(); -} - -absl::Status ParseCallOnce(const Operator* op, BuiltinDataAllocator* allocator, - void** builtin_data) { - CheckParsePointerParams(op, allocator, builtin_data); - - SafeBuiltinDataAllocator safe_allocator(allocator); - std::unique_ptr - params = safe_allocator.Allocate(); - TFL_MIGRATION_ENSURE(params != nullptr); - - const CallOnceOptions* schema_params = - op->builtin_options_as_CallOnceOptions(); - - if (schema_params != nullptr) { - params->init_subgraph_index = schema_params->init_subgraph_index(); - - } else { - // TODO(b/157480169): We should either return kTfLiteError or fill in some - // reasonable defaults in the params struct. We are not doing so until we - // better understand the ramifications of changing the legacy behavior. - } - - *builtin_data = params.release(); - return OkStatus(); -} - -// We have this parse function instead of directly returning OkStatus() from the -// switch-case in ParseOpData because this function is used as part of the -// selective registration for the OpResolver implementation in micro. -absl::Status ParseCast(const Operator* op, BuiltinDataAllocator* allocator, - void** builtin_data) { - CheckParsePointerParams(op, allocator, builtin_data); - - SafeBuiltinDataAllocator safe_allocator(allocator); - auto params = safe_allocator.Allocate(); - TFL_MIGRATION_ENSURE(params != nullptr); - if (const auto* schema_params = op->builtin_options_as_CastOptions()) { - TFL_MIGRATION_ENSURE_STATUS(ConvertTensorType(schema_params->in_data_type(), - ¶ms->in_data_type)); - TFL_MIGRATION_ENSURE_STATUS(ConvertTensorType( - schema_params->out_data_type(), ¶ms->out_data_type)); - } - *builtin_data = params.release(); - return OkStatus(); -} - -// We have this parse function instead of directly returning OkStatus() from the -// switch-case in ParseOpData because this function is used as part of the -// selective registration for the OpResolver implementation in micro. -absl::Status ParseCeil(const Operator*, BuiltinDataAllocator*, void**) { - return OkStatus(); -} - -absl::Status ParseConcatenation(const Operator* op, - BuiltinDataAllocator* allocator, - void** builtin_data) { - CheckParsePointerParams(op, allocator, builtin_data); - - SafeBuiltinDataAllocator safe_allocator(allocator); - std::unique_ptr - params = safe_allocator.Allocate(); - TFL_MIGRATION_ENSURE(params != nullptr); - - const ConcatenationOptions* schema_params = - op->builtin_options_as_ConcatenationOptions(); - - if (schema_params != nullptr) { - params->activation = - ConvertActivation(schema_params->fused_activation_function()); - params->axis = schema_params->axis(); - } else { - // TODO(b/157480169): We should either return kTfLiteError or fill in some - // reasonable defaults in the params struct. We are not doing so until we - // better understand the ramifications of changing the legacy behavior. - } - - *builtin_data = params.release(); - return OkStatus(); -} - -absl::Status ParseConv2D(const Operator* op, BuiltinDataAllocator* allocator, - void** builtin_data) { - CheckParsePointerParams(op, allocator, builtin_data); - - SafeBuiltinDataAllocator safe_allocator(allocator); - std::unique_ptr - params = safe_allocator.Allocate(); - TFL_MIGRATION_ENSURE(params != nullptr); - - const Conv2DOptions* schema_params = op->builtin_options_as_Conv2DOptions(); - - if (schema_params != nullptr) { - params->padding = ConvertPadding(schema_params->padding()); - params->stride_width = schema_params->stride_w(); - params->stride_height = schema_params->stride_h(); - params->activation = - ConvertActivation(schema_params->fused_activation_function()); - - params->dilation_width_factor = schema_params->dilation_w_factor(); - params->dilation_height_factor = schema_params->dilation_h_factor(); - TFL_MIGRATION_ENSURE_STATUS(ConvertTensorType( - schema_params->quantized_bias_type(), ¶ms->quantized_bias_type)); - } else { - // TODO(b/157480169): We should either return kTfLiteError or fill in some - // reasonable defaults in the params struct. We are not doing so until we - // better understand the ramifications of changing the legacy behavior. - } - - *builtin_data = params.release(); - return OkStatus(); -} - -// We have this parse function instead of directly returning OkStatus() from the -// switch-case in ParseOpData because this function is used as part of the -// selective registration for the OpResolver implementation in micro. -absl::Status ParseCumsum(const Operator* op, BuiltinDataAllocator* allocator, - void** builtin_data) { - CheckParsePointerParams(op, allocator, builtin_data); - - SafeBuiltinDataAllocator safe_allocator(allocator); - auto params = safe_allocator.Allocate(); - TFL_MIGRATION_ENSURE(params != nullptr); - if (const auto* cumsum_params = op->builtin_options_as_CumsumOptions()) { - params->exclusive = cumsum_params->exclusive(); - params->reverse = cumsum_params->reverse(); - } - *builtin_data = params.release(); - return OkStatus(); -} - -// We have this parse function instead of directly returning OkStatus() from the -// switch-case in ParseOpData because this function is used as part of the -// selective registration for the OpResolver implementation in micro. -absl::Status ParseCos(const Operator*, BuiltinDataAllocator*, void**) { - return OkStatus(); -} - -absl::Status ParseDepthToSpace(const Operator* op, - BuiltinDataAllocator* allocator, - void** builtin_data) { - CheckParsePointerParams(op, allocator, builtin_data); - - SafeBuiltinDataAllocator safe_allocator(allocator); - std::unique_ptr - params = safe_allocator.Allocate(); - TFL_MIGRATION_ENSURE(params != nullptr); - - const auto* schema_params = op->builtin_options_as_DepthToSpaceOptions(); - if (schema_params != nullptr) { - params->block_size = schema_params->block_size(); - } else { - // TODO(b/157480169): We should either return kTfLiteError or fill in some - // reasonable defaults in the params struct. We are not doing so until we - // better understand the ramifications of changing the legacy behavior. - } - - *builtin_data = params.release(); - return OkStatus(); -} - -absl::Status ParseDepthwiseConv2D(const Operator* op, - BuiltinDataAllocator* allocator, - void** builtin_data) { - CheckParsePointerParams(op, allocator, builtin_data); - - SafeBuiltinDataAllocator safe_allocator(allocator); - - std::unique_ptr - params = safe_allocator.Allocate(); - TFL_MIGRATION_ENSURE(params != nullptr); - - const DepthwiseConv2DOptions* schema_params = - op->builtin_options_as_DepthwiseConv2DOptions(); - - if (schema_params != nullptr) { - params->padding = ConvertPadding(schema_params->padding()); - params->stride_width = schema_params->stride_w(); - params->stride_height = schema_params->stride_h(); - params->depth_multiplier = schema_params->depth_multiplier(); - params->activation = - ConvertActivation(schema_params->fused_activation_function()); - - params->dilation_width_factor = schema_params->dilation_w_factor(); - params->dilation_height_factor = schema_params->dilation_h_factor(); - } else { - // TODO(b/157480169): We should either return kTfLiteError or fill in some - // reasonable defaults in the params struct. We are not doing so until we - // better understand the ramifications of changing the legacy behavior. - } - - *builtin_data = params.release(); - return OkStatus(); -} - -// We have this parse function instead of directly returning OkStatus() from the -// switch-case in ParseOpData because this function is used as part of the -// selective registration for the OpResolver implementation in micro. -absl::Status ParseDequantize(const Operator*, BuiltinDataAllocator*, void**) { - return OkStatus(); -} - -absl::Status ParseDiv(const Operator* op, BuiltinDataAllocator* allocator, - void** builtin_data) { - CheckParsePointerParams(op, allocator, builtin_data); - - SafeBuiltinDataAllocator safe_allocator(allocator); - auto params = safe_allocator.Allocate(); - TFL_MIGRATION_ENSURE(params != nullptr); - if (const auto* schema_params = op->builtin_options_as_DivOptions()) { - params->activation = - ConvertActivation(schema_params->fused_activation_function()); - } - *builtin_data = params.release(); - return OkStatus(); -} - -// We have this parse function instead of directly returning OkStatus() from the -// switch-case in ParseOpData because this function is used as part of the -// selective registration for the OpResolver implementation in micro. -absl::Status ParseElu(const Operator*, BuiltinDataAllocator*, void**) { - return OkStatus(); -} - -// We have this parse function instead of directly returning OkStatus() from the -// switch-case in ParseOpData because this function is used as part of the -// selective registration for the OpResolver implementation in micro. -absl::Status ParseEmbeddingLookup(const Operator*, BuiltinDataAllocator*, - void**) { - return OkStatus(); -} - -// We have this parse function instead of directly returning OkStatus() from the -// switch-case in ParseOpData because this function is used as part of the -// selective registration for the OpResolver implementation in micro. -absl::Status ParseEqual(const Operator*, BuiltinDataAllocator*, void**) { - return OkStatus(); -} - -// We have this parse function instead of directly returning OkStatus() from the -// switch-case in ParseOpData because this function is used as part of the -// selective registration for the OpResolver implementation in micro. -absl::Status ParseExp(const Operator*, BuiltinDataAllocator*, void**) { - return OkStatus(); -} - -// We have this parse function instead of directly returning OkStatus() from the -// switch-case in ParseOpData because this function is used as part of the -// selective registration for the OpResolver implementation in micro. -absl::Status ParseExpandDims(const Operator*, BuiltinDataAllocator*, void**) { - return OkStatus(); -} - -// We have this parse function instead of directly returning OkStatus() from the -// switch-case in ParseOpData because this function is used as part of the -// selective registration for the OpResolver implementation in micro. -absl::Status ParseFill(const Operator*, BuiltinDataAllocator*, void**) { - return OkStatus(); -} - -// We have this parse function instead of directly returning OkStatus() from the -// switch-case in ParseOpData because this function is used as part of the -// selective registration for the OpResolver implementation in micro. -absl::Status ParseFloor(const Operator*, BuiltinDataAllocator*, void**) { - return OkStatus(); -} - -// We have this parse function instead of directly returning OkStatus() from the -// switch-case in ParseOpData because this function is used as part of the -// selective registration for the OpResolver implementation in micro. -absl::Status ParseFloorDiv(const Operator*, BuiltinDataAllocator*, void**) { - return OkStatus(); -} - -// We have this parse function instead of directly returning OkStatus() from the -// switch-case in ParseOpData because this function is used as part of the -// selective registration for the OpResolver implementation in micro. -absl::Status ParseFloorMod(const Operator*, BuiltinDataAllocator*, void**) { - return OkStatus(); -} - -absl::Status ParseFullyConnected(const Operator* op, - BuiltinDataAllocator* allocator, - void** builtin_data) { - CheckParsePointerParams(op, allocator, builtin_data); - - SafeBuiltinDataAllocator safe_allocator(allocator); - - std::unique_ptr - params = safe_allocator.Allocate(); - TFL_MIGRATION_ENSURE(params != nullptr); - - const FullyConnectedOptions* schema_params = - op->builtin_options_as_FullyConnectedOptions(); - - if (schema_params != nullptr) { - params->activation = - ConvertActivation(schema_params->fused_activation_function()); - params->keep_num_dims = schema_params->keep_num_dims(); - params->asymmetric_quantize_inputs = - schema_params->asymmetric_quantize_inputs(); - TFL_MIGRATION_ENSURE_STATUS(ConvertTensorType( - schema_params->quantized_bias_type(), ¶ms->quantized_bias_type)); - switch (schema_params->weights_format()) { - case FullyConnectedOptionsWeightsFormat_DEFAULT: - params->weights_format = kTfLiteFullyConnectedWeightsFormatDefault; - break; - case FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8: - params->weights_format = - kTfLiteFullyConnectedWeightsFormatShuffled4x16Int8; - break; - default: - auto error_message = "Unhandled fully-connected weights format."; - LOG(ERROR) << error_message; - return absl::InvalidArgumentError(error_message); - } - } else { - // TODO(b/157480169): We should either return kTfLiteError or fill in some - // reasonable defaults in the params struct. We are not doing so until we - // better understand the ramifications of changing the legacy behavior. - } - - *builtin_data = params.release(); - return OkStatus(); -} - -// We have this parse function instead of directly returning OkStatus() from the -// switch-case in ParseOpData because this function is used as part of the -// selective registration for the OpResolver implementation in micro. -absl::Status ParseGather(const Operator* op, BuiltinDataAllocator* allocator, - void** builtin_data) { - CheckParsePointerParams(op, allocator, builtin_data); - - SafeBuiltinDataAllocator safe_allocator(allocator); - auto params = safe_allocator.Allocate(); - TFL_MIGRATION_ENSURE(params != nullptr); - params->axis = 0; - params->batch_dims = 0; - if (const auto* gather_params = op->builtin_options_as_GatherOptions()) { - params->axis = gather_params->axis(); - params->batch_dims = gather_params->batch_dims(); - } - - *builtin_data = params.release(); - return OkStatus(); -} - -// We have this parse function instead of directly returning OkStatus() from the -// switch-case in ParseOpData because this function is used as part of the -// selective registration for the OpResolver implementation in micro. -absl::Status ParseGatherNd(const Operator*, BuiltinDataAllocator*, void**) { - return OkStatus(); -} - -// We have this parse function instead of directly returning OkStatus() from the -// switch-case in ParseOpData because this function is used as part of the -// selective registration for the OpResolver implementation in micro. -absl::Status ParseGreater(const Operator*, BuiltinDataAllocator*, void**) { - return OkStatus(); -} - -// We have this parse function instead of directly returning OkStatus() from the -// switch-case in ParseOpData because this function is used as part of the -// selective registration for the OpResolver implementation in micro. -absl::Status ParseGreaterEqual(const Operator*, BuiltinDataAllocator*, void**) { - return OkStatus(); -} - -// We have this parse function instead of directly returning OkStatus() from the -// switch-case in ParseOpData because this function is used as part of the -// selective registration for the OpResolver implementation in micro. -absl::Status ParseHardSwish(const Operator*, BuiltinDataAllocator*, void**) { - return OkStatus(); -} - -absl::Status ParseIf(const Operator* op, BuiltinDataAllocator* allocator, - void** builtin_data) { - CheckParsePointerParams(op, allocator, builtin_data); - - SafeBuiltinDataAllocator safe_allocator(allocator); - std::unique_ptr - params = safe_allocator.Allocate(); - TFL_MIGRATION_ENSURE(params != nullptr); - - const IfOptions* schema_params = op->builtin_options_as_IfOptions(); - - if (schema_params != nullptr) { - params->then_subgraph_index = schema_params->then_subgraph_index(); - params->else_subgraph_index = schema_params->else_subgraph_index(); - } else { - // TODO(b/157480169): We should either return kTfLiteError or fill in some - // reasonable defaults in the params struct. We are not doing so until we - // better understand the ramifications of changing the legacy behavior. - } - - *builtin_data = params.release(); - return OkStatus(); -} - -absl::Status ParseL2Normalization(const Operator* op, - BuiltinDataAllocator* allocator, - void** builtin_data) { - CheckParsePointerParams(op, allocator, builtin_data); - - SafeBuiltinDataAllocator safe_allocator(allocator); - std::unique_ptr - params = safe_allocator.Allocate(); - TFL_MIGRATION_ENSURE(params != nullptr); - - const L2NormOptions* schema_params = op->builtin_options_as_L2NormOptions(); - - if (schema_params != nullptr) { - params->activation = - ConvertActivation(schema_params->fused_activation_function()); - } else { - // TODO(b/157480169): We should either return kTfLiteError or fill in some - // reasonable defaults in the params struct. We are not doing so until we - // better understand the ramifications of changing the legacy behavior. - } - - *builtin_data = params.release(); - return OkStatus(); -} - -absl::Status ParseLeakyRelu(const Operator* op, BuiltinDataAllocator* allocator, - void** builtin_data) { - CheckParsePointerParams(op, allocator, builtin_data); - - SafeBuiltinDataAllocator safe_allocator(allocator); - auto params = safe_allocator.Allocate(); - TFL_MIGRATION_ENSURE(params != nullptr); - if (const auto* leaky_relu_params = - op->builtin_options_as_LeakyReluOptions()) { - params->alpha = leaky_relu_params->alpha(); - } - *builtin_data = params.release(); - return OkStatus(); -} - -// We have this parse function instead of directly returning OkStatus() from the -// switch-case in ParseOpData because this function is used as part of the -// selective registration for the OpResolver implementation in micro. -absl::Status ParseLess(const Operator*, BuiltinDataAllocator*, void**) { - return OkStatus(); -} - -// We have this parse function instead of directly returning OkStatus() from the -// switch-case in ParseOpData because this function is used as part of the -// selective registration for the OpResolver implementation in micro. -absl::Status ParseLessEqual(const Operator*, BuiltinDataAllocator*, void**) { - return OkStatus(); -} - -// We have this parse function instead of directly returning OkStatus() from the -// switch-case in ParseOpData because this function is used as part of the -// selective registration for the OpResolver implementation in micro. -absl::Status ParseLog(const Operator*, BuiltinDataAllocator*, void**) { - return OkStatus(); -} - -// We have this parse function instead of directly returning OkStatus() from the -// switch-case in ParseOpData because this function is used as part of the -// selective registration for the OpResolver implementation in micro. -absl::Status ParseLogicalAnd(const Operator*, BuiltinDataAllocator*, void**) { - return OkStatus(); -} - -// We have this parse function instead of directly returning OkStatus() from the -// switch-case in ParseOpData because this function is used as part of the -// selective registration for the OpResolver implementation in micro. -absl::Status ParseLogicalNot(const Operator*, BuiltinDataAllocator*, void**) { - return OkStatus(); -} - -// We have this parse function instead of directly returning OkStatus() from the -// switch-case in ParseOpData because this function is used as part of the -// selective registration for the OpResolver implementation in micro. -absl::Status ParseLogicalOr(const Operator*, BuiltinDataAllocator*, void**) { - return OkStatus(); -} - -// We have this parse function instead of directly returning OkStatus() from the -// switch-case in ParseOpData because this function is used as part of the -// selective registration for the OpResolver implementation in micro. -absl::Status ParseLogistic(const Operator*, BuiltinDataAllocator*, void**) { - return OkStatus(); -} - -// We have this parse function instead of directly returning OkStatus() from the -// switch-case in ParseOpData because this function is used as part of the -// selective registration for the OpResolver implementation in micro. -absl::Status ParseLogSoftmax(const Operator*, BuiltinDataAllocator*, void**) { - return OkStatus(); -} - -absl::Status ParseLSTM(const Operator* op, BuiltinDataAllocator* allocator, - void** builtin_data) { - CheckParsePointerParams(op, allocator, builtin_data); - - SafeBuiltinDataAllocator safe_allocator(allocator); - auto params = safe_allocator.Allocate(); - TFL_MIGRATION_ENSURE(params != nullptr); - if (const auto* lstm_params = op->builtin_options_as_LSTMOptions()) { - params->activation = - ConvertActivation(lstm_params->fused_activation_function()); - params->cell_clip = lstm_params->cell_clip(); - params->proj_clip = lstm_params->proj_clip(); - switch (lstm_params->kernel_type()) { - case LSTMKernelType_FULL: - params->kernel_type = kTfLiteLSTMFullKernel; - break; - case LSTMKernelType_BASIC: - params->kernel_type = kTfLiteLSTMBasicKernel; - break; - default: - auto error_message = absl::StrFormat("Unhandled LSTM kernel type: %d", - lstm_params->kernel_type()); - LOG(ERROR) << error_message; - return absl::InvalidArgumentError(error_message); - } - params->asymmetric_quantize_inputs = - lstm_params->asymmetric_quantize_inputs(); - } else { - auto error_message = "No valid LSTM builtin options exist"; - LOG(ERROR) << error_message; - return absl::InvalidArgumentError(error_message); - } - *builtin_data = params.release(); - return OkStatus(); -} - -// We have this parse function instead of directly returning OkStatus() from the -// switch-case in ParseOpData because this function is used as part of the -// selective registration for the OpResolver implementation in micro. -absl::Status ParseMaximum(const Operator*, BuiltinDataAllocator*, void**) { - return OkStatus(); -} - -// We have this parse function instead of directly returning OkStatus() from the -// switch-case in ParseOpData because this function is used as part of the -// selective registration for the OpResolver implementation in micro. -absl::Status ParseMinimum(const Operator*, BuiltinDataAllocator*, void**) { - return OkStatus(); -} - -absl::Status ParseMirrorPad(const Operator* op, BuiltinDataAllocator* allocator, - void** builtin_data) { - CheckParsePointerParams(op, allocator, builtin_data); - - SafeBuiltinDataAllocator safe_allocator(allocator); - std::unique_ptr - params = safe_allocator.Allocate(); - TFL_MIGRATION_ENSURE(params != nullptr); - - const MirrorPadOptions* schema_params = - op->builtin_options_as_MirrorPadOptions(); - - if (schema_params != nullptr) { - params->mode = ConvertMirrorPadding(schema_params->mode()); - } else { - // TODO(b/157480169): We should either return kTfLiteError or fill in some - // reasonable defaults in the params struct. We are not doing so until we - // better understand the ramifications of changing the legacy behavior. - } - - *builtin_data = params.release(); - return OkStatus(); -} - -absl::Status ParseMul(const Operator* op, BuiltinDataAllocator* allocator, - void** builtin_data) { - CheckParsePointerParams(op, allocator, builtin_data); - - SafeBuiltinDataAllocator safe_allocator(allocator); - std::unique_ptr - params = safe_allocator.Allocate(); - TFL_MIGRATION_ENSURE(params != nullptr); - - const MulOptions* schema_params = op->builtin_options_as_MulOptions(); - - if (schema_params != nullptr) { - params->activation = - ConvertActivation(schema_params->fused_activation_function()); - } else { - // TODO(b/157480169): We should either return kTfLiteError or fill in some - // reasonable defaults in the params struct. We are not doing so until we - // better understand the ramifications of changing the legacy behavior. - } - - *builtin_data = params.release(); - return OkStatus(); -} - -// We have this parse function instead of directly returning OkStatus() from the -// switch-case in ParseOpData because this function is used as part of the -// selective registration for the OpResolver implementation in micro. -absl::Status ParseNeg(const Operator*, BuiltinDataAllocator*, void**) { - return OkStatus(); -} - -// We have this parse function instead of directly returning OkStatus() from the -// switch-case in ParseOpData because this function is used as part of the -// selective registration for the OpResolver implementation in micro. -absl::Status ParseNotEqual(const Operator*, BuiltinDataAllocator*, void**) { - return OkStatus(); -} - -absl::Status ParsePack(const Operator* op, BuiltinDataAllocator* allocator, - void** builtin_data) { - CheckParsePointerParams(op, allocator, builtin_data); - - SafeBuiltinDataAllocator safe_allocator(allocator); - std::unique_ptr - params = safe_allocator.Allocate(); - TFL_MIGRATION_ENSURE(params != nullptr); - - const PackOptions* schema_params = op->builtin_options_as_PackOptions(); - - if (schema_params != nullptr) { - params->values_count = schema_params->values_count(); - params->axis = schema_params->axis(); - } else { - // TODO(b/157480169): We should either return kTfLiteError or fill in some - // reasonable defaults in the params struct. We are not doing so until we - // better understand the ramifications of changing the legacy behavior. - } - - *builtin_data = params.release(); - return OkStatus(); -} - -// We have this parse function instead of directly returning OkStatus() from the -// switch-case in ParseOpData because this function is used as part of the -// selective registration for the OpResolver implementation in micro. -absl::Status ParsePad(const Operator*, BuiltinDataAllocator*, void**) { - return OkStatus(); -} - -// We have this parse function instead of directly returning OkStatus() from the -// switch-case in ParseOpData because this function is used as part of the -// selective registration for the OpResolver implementation in micro. -absl::Status ParsePadV2(const Operator*, BuiltinDataAllocator*, void**) { - return OkStatus(); -} - -absl::Status ParsePool(const Operator* op, BuiltinDataAllocator* allocator, - void** builtin_data) { - CheckParsePointerParams(op, allocator, builtin_data); - - SafeBuiltinDataAllocator safe_allocator(allocator); - std::unique_ptr - params = safe_allocator.Allocate(); - TFL_MIGRATION_ENSURE(params != nullptr); - - const Pool2DOptions* schema_params = op->builtin_options_as_Pool2DOptions(); - - if (schema_params != nullptr) { - params->padding = ConvertPadding(schema_params->padding()); - params->stride_width = schema_params->stride_w(); - params->stride_height = schema_params->stride_h(); - params->filter_width = schema_params->filter_width(); - params->filter_height = schema_params->filter_height(); - params->activation = - ConvertActivation(schema_params->fused_activation_function()); - } else { - // TODO(b/157480169): We should either return kTfLiteError or fill in some - // reasonable defaults in the params struct. We are not doing so until we - // better understand the ramifications of changing the legacy behavior. - } - - *builtin_data = params.release(); - return OkStatus(); -} - -// We have this parse function instead of directly returning OkStatus() from the -// switch-case in ParseOpData because this function is used as part of the -// selective registration for the OpResolver implementation in micro. -absl::Status ParsePow(const Operator*, BuiltinDataAllocator*, void**) { - return OkStatus(); -} - -// We have this parse function instead of directly returning OkStatus() from the -// switch-case in ParseOpData because this function is used as part of the -// selective registration for the OpResolver implementation in micro. -absl::Status ParsePrelu(const Operator*, BuiltinDataAllocator*, void**) { - return OkStatus(); -} - -// We have this parse function instead of directly returning OkStatus() from the -// switch-case in ParseOpData because this function is used as part of the -// selective registration for the OpResolver implementation in micro. -absl::Status ParseQuantize(const Operator*, BuiltinDataAllocator*, void**) { - return OkStatus(); -} - -// We have this parse function instead of directly returning OkStatus() from the -// switch-case in ParseOpData because this function is used as part of the -// selective registration for the OpResolver implementation in micro. -absl::Status ParseReadVariable(const Operator*, BuiltinDataAllocator*, void**) { - return OkStatus(); -} - -absl::Status ParseReducer(const Operator* op, BuiltinDataAllocator* allocator, - void** builtin_data) { - CheckParsePointerParams(op, allocator, builtin_data); - - SafeBuiltinDataAllocator safe_allocator(allocator); - - std::unique_ptr - params = safe_allocator.Allocate(); - TFL_MIGRATION_ENSURE(params != nullptr); - - const ReducerOptions* schema_params = op->builtin_options_as_ReducerOptions(); - - if (schema_params != nullptr) { - params->keep_dims = schema_params->keep_dims(); - } else { - // TODO(b/157480169): We should either return kTfLiteError or fill in some - // reasonable defaults in the params struct. We are not doing so until we - // better understand the ramifications of changing the legacy behavior. - } - - *builtin_data = params.release(); - return OkStatus(); -} - -// We have this parse function instead of directly returning OkStatus() from the -// switch-case in ParseOpData because this function is used as part of the -// selective registration for the OpResolver implementation in micro. -absl::Status ParseRelu(const Operator*, BuiltinDataAllocator*, void**) { - return OkStatus(); -} - -// We have this parse function instead of directly returning OkStatus() from the -// switch-case in ParseOpData because this function is used as part of the -// selective registration for the OpResolver implementation in micro. -absl::Status ParseRelu6(const Operator*, BuiltinDataAllocator*, void**) { - return OkStatus(); -} - -absl::Status ParseReshape(const Operator* op, BuiltinDataAllocator* allocator, - void** builtin_data) { - CheckParsePointerParams(op, allocator, builtin_data); - - SafeBuiltinDataAllocator safe_allocator(allocator); - - std::unique_ptr - params = safe_allocator.Allocate(); - TFL_MIGRATION_ENSURE(params != nullptr); - - const ReshapeOptions* schema_params = op->builtin_options_as_ReshapeOptions(); - - if (schema_params != nullptr) { - const flatbuffers::Vector* new_shape = schema_params->new_shape(); - if (new_shape != nullptr) { - TFL_MIGRATION_ENSURE_STATUS(FlatBufferIntVectorToArray( - sizeof(params->shape), new_shape, params->shape, "reshape")); - params->num_dimensions = new_shape->size(); - } else { - // TODO(b/157480169) TODO(b/147203660): We should either return - // kTfLiteError or fill in some reasonable defaults in the params struct. - // We are not doing so until we better undertand the ramifications of - // changing the legacy behavior. - } - } else { - // TODO(b/157480169): We should either return kTfLiteError or fill in some - // reasonable defaults in the params struct. We are not doing so until we - // better understand the ramifications of changing the legacy behavior. - } - - *builtin_data = params.release(); - return OkStatus(); -} - -absl::Status ParseResizeBilinear(const Operator* op, - BuiltinDataAllocator* allocator, - void** builtin_data) { - CheckParsePointerParams(op, allocator, builtin_data); - - SafeBuiltinDataAllocator safe_allocator(allocator); - std::unique_ptr - params = safe_allocator.Allocate(); - TFL_MIGRATION_ENSURE(params != nullptr); - - const ResizeBilinearOptions* schema_params = - op->builtin_options_as_ResizeBilinearOptions(); - - if (schema_params != nullptr) { - params->align_corners = schema_params->align_corners(); - params->half_pixel_centers = schema_params->half_pixel_centers(); - } else { - params->align_corners = false; - params->half_pixel_centers = false; - } - - *builtin_data = params.release(); - return OkStatus(); -} - -absl::Status ParseResizeNearestNeighbor(const Operator* op, - BuiltinDataAllocator* allocator, - void** builtin_data) { - CheckParsePointerParams(op, allocator, builtin_data); - - SafeBuiltinDataAllocator safe_allocator(allocator); - std::unique_ptr - params = safe_allocator.Allocate(); - TFL_MIGRATION_ENSURE(params != nullptr); - - const ResizeNearestNeighborOptions* schema_params = - op->builtin_options_as_ResizeNearestNeighborOptions(); - - if (schema_params != nullptr) { - params->align_corners = schema_params->align_corners(); - params->half_pixel_centers = schema_params->half_pixel_centers(); - } else { - params->align_corners = false; - params->half_pixel_centers = false; - } - - *builtin_data = params.release(); - return OkStatus(); -} - -absl::Status ParseStablehloReduceWindow(const Operator* op, - BuiltinDataAllocator* allocator, - void** builtin_data) { - CheckParsePointerParams(op, allocator, builtin_data); - - SafeBuiltinDataAllocator safe_allocator(allocator); - auto params = safe_allocator.Allocate(); - - const StablehloReduceWindowOptions* schema_params = - op->builtin_options_2_as_StablehloReduceWindowOptions(); - if (schema_params) { - if (!schema_params->window_dimensions() || - schema_params->window_dimensions()->size() == 0) { - auto error_message = - "'window_dimensions' attribute is not optional for " - "'stablehlo.reduce_window' and cannot be empty."; - LOG(ERROR) << error_message; - return absl::InvalidArgumentError(error_message); - } - - const size_t rank = schema_params->window_dimensions()->size(); - - auto LoadAttr = [](int64_t* params_array, size_t params_array_size_bytes, - const flatbuffers::Vector* flatbuffer_vector, - const char* attr_name, const size_t expected_size, - const int64_t fill_value) -> absl::Status { - if (flatbuffer_vector && flatbuffer_vector->size()) { - if (expected_size != 0 && flatbuffer_vector->size() != expected_size) { - auto error_message = absl::StrFormat( - "'%s' attribute of 'stablehlo.reduce_window' does not have the " - "expected size (%llu != %llu).", - attr_name, flatbuffer_vector->size(), expected_size); - LOG(ERROR) << error_message; - return absl::InvalidArgumentError(error_message); - } - absl::Status status = FlatBufferIntVectorToArray( - params_array_size_bytes, flatbuffer_vector, params_array, - "stablehlo.reduce_window"); - if (!status.ok()) { - auto error_message = absl::StrFormat("%s Check the '%s' attribute.", - status.message(), attr_name); - LOG(ERROR) << error_message; - return absl::InvalidArgumentError(error_message); - } - } else { - std::fill_n(params_array, params_array_size_bytes / sizeof(int64_t), - fill_value); - } - return OkStatus(); - }; - - TFL_MIGRATION_ENSURE_STATUS( - LoadAttr(params->window_dimensions, sizeof(params->window_dimensions), - schema_params->window_dimensions(), "window_dimensions", - /*expected_size=*/rank, /*fill_value=*/1)); - TFL_MIGRATION_ENSURE_STATUS( - LoadAttr(params->window_strides, sizeof(params->window_strides), - schema_params->window_strides(), "window_strides", - /*expected_size=*/rank, /*fill_value=*/1)); - TFL_MIGRATION_ENSURE_STATUS( - LoadAttr(params->base_dilations, sizeof(params->base_dilations), - schema_params->base_dilations(), "base_dilations", - /*expected_size=*/rank, /*fill_value=*/1)); - TFL_MIGRATION_ENSURE_STATUS( - LoadAttr(params->window_dilations, sizeof(params->window_dilations), - schema_params->window_dilations(), "window_dilations", - /*expected_size=*/rank, /*fill_value=*/1)); - TFL_MIGRATION_ENSURE_STATUS(LoadAttr(params->padding, - sizeof(params->padding), - schema_params->padding(), "padding", - /*expected_size=*/2 * rank, - /*fill_value=*/0)); - - params->body_subgraph_index = schema_params->body_subgraph_index(); - *builtin_data = params.release(); - return OkStatus(); - } - auto error_message = - "Could not get 'stablehlo.reduce_window' operation parameters."; - LOG(ERROR) << error_message; - return absl::InvalidArgumentError(error_message); -} - -absl::Status ParseStablehloScatter(const Operator* op, - BuiltinDataAllocator* allocator, - void** builtin_data) { - CheckParsePointerParams(op, allocator, builtin_data); - - SafeBuiltinDataAllocator safe_allocator(allocator); - std::unique_ptr - params = safe_allocator.Allocate(); - TFL_MIGRATION_ENSURE(params != nullptr); - - const StablehloScatterOptions* schema_params = - op->builtin_options_2_as_StablehloScatterOptions(); - if (schema_params) { - params->indices_are_sorted = schema_params->indices_are_sorted(); - - if (schema_params->update_window_dims()) { - TFL_MIGRATION_ENSURE_STATUS(FlatBufferIntVectorToArray( - schema_params->update_window_dims()->size() * sizeof(int64_t), - schema_params->update_window_dims(), params->update_window_dims, - "stablehlo_scatter")); - params->num_update_window_dims = - schema_params->update_window_dims()->size(); - } - - if (schema_params->inserted_window_dims()) { - TFL_MIGRATION_ENSURE_STATUS(FlatBufferIntVectorToArray( - schema_params->inserted_window_dims()->size() * sizeof(int64_t), - schema_params->inserted_window_dims(), params->inserted_window_dims, - "stablehlo_scatter")); - params->num_inserted_window_dims = - schema_params->inserted_window_dims()->size(); - } - - if (schema_params->scatter_dims_to_operand_dims()) { - TFL_MIGRATION_ENSURE_STATUS(FlatBufferIntVectorToArray( - schema_params->scatter_dims_to_operand_dims()->size() * - sizeof(int64_t), - schema_params->scatter_dims_to_operand_dims(), - params->scatter_dims_to_operand_dims, "stablehlo_scatter")); - params->num_scatter_dims_to_operand_dims = - schema_params->scatter_dims_to_operand_dims()->size(); - } - - params->index_vector_dim = schema_params->index_vector_dim(); - params->unique_indices = schema_params->unique_indices(); - params->update_computation_subgraph_index = - schema_params->update_computation_subgraph_index(); - } else { - // TODO(b/157480169): We should either return kTfLiteError or fill in some - // reasonable defaults in the params struct. We are not doing so until we - // better understand the ramifications of changing the legacy behavior. - } - *builtin_data = params.release(); - return OkStatus(); -} - -absl::Status ParseStablehloRngBitGenerator(const Operator* op, - BuiltinDataAllocator* allocator, - void** builtin_data) { - CheckParsePointerParams(op, allocator, builtin_data); - - SafeBuiltinDataAllocator safe_allocator(allocator); - std::unique_ptr - params = safe_allocator.Allocate(); - TFL_MIGRATION_ENSURE(params != nullptr); - - const StablehloRngBitGeneratorOptions* schema_params = - op->builtin_options_2_as_StablehloRngBitGeneratorOptions(); - if (schema_params != nullptr) { - params->algorithm = ConvertRngAlgorithm(schema_params->algorithm()); - } else { - // TODO(b/157480169): We should either return kTfLiteError or fill in some - // reasonable defaults in the params struct. We are not doing so until we - // better understand the ramifications of changing the legacy behavior. - } - - *builtin_data = params.release(); - return OkStatus(); -} - -absl::Status ParseStablehloGather(const Operator* op, - BuiltinDataAllocator* allocator, - void** builtin_data) { - CheckParsePointerParams(op, allocator, builtin_data); - - SafeBuiltinDataAllocator safe_allocator(allocator); - std::unique_ptr - params = safe_allocator.Allocate(); - TFL_MIGRATION_ENSURE(params != nullptr); - - const StablehloGatherOptions* schema_params = - op->builtin_options_2_as_StablehloGatherOptions(); - - if (schema_params != nullptr) { - TFL_MIGRATION_ENSURE_STATUS(FlatBufferIntVectorToArray( - /*max_size_of_buffer=*/schema_params->offset_dims()->size() * - sizeof(int64_t), - /*flat_vector=*/schema_params->offset_dims(), - /*buffer=*/params->offset_dims, - /*op_name=*/"stablehlo_gather")); - params->num_offset_dims = schema_params->offset_dims()->size(); - - TFL_MIGRATION_ENSURE_STATUS(FlatBufferIntVectorToArray( - schema_params->collapsed_slice_dims()->size() * sizeof(int64_t), - schema_params->collapsed_slice_dims(), params->collapsed_slice_dims, - "stablehlo_gather")); - params->num_collapsed_slice_dims = - schema_params->collapsed_slice_dims()->size(); - - TFL_MIGRATION_ENSURE_STATUS(FlatBufferIntVectorToArray( - schema_params->start_index_map()->size() * sizeof(int64_t), - schema_params->start_index_map(), params->start_index_map, - "stablehlo_gather")); - params->num_start_index_map = schema_params->start_index_map()->size(); - - params->index_vector_dim = schema_params->index_vector_dim(); - - TFL_MIGRATION_ENSURE_STATUS(FlatBufferIntVectorToArray( - schema_params->slice_sizes()->size() * sizeof(int64_t), - schema_params->slice_sizes(), params->slice_sizes, "stablehlo_gather")); - params->num_slice_sizes = schema_params->slice_sizes()->size(); - - params->indices_are_sorted = schema_params->indices_are_sorted(); - } else { - // TODO(b/157480169): We should either return kTfLiteError or fill in some - // reasonable defaults in the params struct. We are not doing so until we - // better understand the ramifications of changing the legacy behavior. - } - - *builtin_data = params.release(); - return OkStatus(); -} - -absl::Status ParseStablehloPad(const Operator* op, - BuiltinDataAllocator* allocator, - void** builtin_data) { - CheckParsePointerParams(op, allocator, builtin_data); - - SafeBuiltinDataAllocator safe_allocator(allocator); - auto params = safe_allocator.Allocate(); - const StablehloPadOptions* schema_params = - op->builtin_options_2_as_StablehloPadOptions(); - - if (schema_params) { - auto LoadAttr = - [](int64_t* params_array, const size_t params_array_size_bytes, - const flatbuffers::Vector* const flatbuffer_vector, - const char* const attr_name) -> absl::Status { - absl::Status status = - FlatBufferIntVectorToArray(params_array_size_bytes, flatbuffer_vector, - params_array, "stablehlo.pad"); - if (!status.ok()) { - auto error_message = absl::StrFormat("%s Check the '%s' attribute.", - status.message(), attr_name); - LOG(ERROR) << error_message; - return absl::InvalidArgumentError(error_message); - } - return status; - }; - - TFL_MIGRATION_ENSURE_STATUS( - LoadAttr(params->edge_padding_low, sizeof(params->edge_padding_low), - schema_params->edge_padding_low(), "edge_padding_low")); - TFL_MIGRATION_ENSURE_STATUS( - LoadAttr(params->edge_padding_high, sizeof(params->edge_padding_high), - schema_params->edge_padding_high(), "edge_padding_high")); - TFL_MIGRATION_ENSURE_STATUS( - LoadAttr(params->interior_padding, sizeof(params->interior_padding), - schema_params->interior_padding(), "interior_padding")); - if (schema_params->edge_padding_low()->size() != - schema_params->edge_padding_high()->size() || - schema_params->edge_padding_low()->size() != - schema_params->interior_padding()->size()) { - auto error_message = - "'stablehlo.pad' operation parameter array sizes are not consistent."; - LOG(ERROR) << error_message; - return absl::InvalidArgumentError(error_message); - } - *builtin_data = params.release(); - return OkStatus(); - } - auto error_message = "Could not get 'stablehlo.pad' operation parameters."; - LOG(ERROR) << error_message; - return absl::InvalidArgumentError(error_message); -} - -absl::Status ParseStablehloComposite(const Operator* op, - BuiltinDataAllocator* allocator, - void** builtin_data) { - CheckParsePointerParams(op, allocator, builtin_data); - - SafeBuiltinDataAllocator safe_allocator(allocator); - auto params = safe_allocator.Allocate(); - const StableHLOCompositeOptions* schema_params = - op->builtin_options_2_as_StableHLOCompositeOptions(); - if (schema_params) { - params->name = schema_params->name()->c_str(); - params->version = schema_params->version(); - params->subgraph_index = schema_params->decomposition_subgraph_index(); - params->attributes = schema_params->composite_attributes()->data(); - params->attributes_size = schema_params->composite_attributes()->size(); - *builtin_data = params.release(); - return OkStatus(); - } - auto error_message = - "Could not get 'stablehlo.composite' operation parameters."; - LOG(ERROR) << error_message; - return absl::InvalidArgumentError(error_message); -} - -absl::Status ParseStablehloShiftLeft(const Operator* op, - BuiltinDataAllocator* allocator, - void** builtin_data) { - return OkStatus(); -} - -// We have this parse function instead of directly returning OkStatus() from the -// switch-case in ParseOpData because this function is used as part of the -// selective registration for the OpResolver implementation in micro. -absl::Status ParseRound(const Operator*, BuiltinDataAllocator*, void**) { - return OkStatus(); -} - -// We have this parse function instead of directly returning OkStatus() from the -// switch-case in ParseOpData because this function is used as part of the -// selective registration for the OpResolver implementation in micro. -absl::Status ParseRsqrt(const Operator*, BuiltinDataAllocator*, void**) { - return OkStatus(); -} - -// We have this parse function instead of directly returning OkStatus() from the -// switch-case in ParseOpData because this function is used as part of the -// selective registration for the OpResolver implementation in micro. -absl::Status ParseSelectV2(const Operator*, BuiltinDataAllocator*, void**) { - return OkStatus(); -} - -absl::Status ParseShape(const Operator* op, BuiltinDataAllocator* allocator, - void** builtin_data) { - SafeBuiltinDataAllocator safe_allocator(allocator); - std::unique_ptr - params = safe_allocator.Allocate(); - TFL_MIGRATION_ENSURE(params != nullptr); - - const ShapeOptions* schema_params = op->builtin_options_as_ShapeOptions(); - - if (schema_params != nullptr) { - TFL_MIGRATION_ENSURE_STATUS( - ConvertTensorType(schema_params->out_type(), ¶ms->out_type)); - } else { - // TODO(b/157480169): We should either return kTfLiteError or fill in some - // reasonable defaults in the params struct. We are not doing so until we - // better understand the ramifications of changing the legacy behavior. - } - - *builtin_data = params.release(); - return OkStatus(); -} - -// We have this parse function instead of directly returning OkStatus() from the -// switch-case in ParseOpData because this function is used as part of the -// selective registration for the OpResolver implementation in micro. -absl::Status ParseSin(const Operator*, BuiltinDataAllocator*, void**) { - return OkStatus(); -} - -// We have this parse function instead of directly returning OkStatus() from the -// switch-case in ParseOpData because this function is used as part of the -// selective registration for the OpResolver implementation in micro. -absl::Status ParseSlice(const Operator*, BuiltinDataAllocator*, void**) { - return OkStatus(); -} - -absl::Status ParseSoftmax(const Operator* op, BuiltinDataAllocator* allocator, - void** builtin_data) { - CheckParsePointerParams(op, allocator, builtin_data); - - SafeBuiltinDataAllocator safe_allocator(allocator); - std::unique_ptr - params = safe_allocator.Allocate(); - TFL_MIGRATION_ENSURE(params != nullptr); - - const SoftmaxOptions* schema_params = op->builtin_options_as_SoftmaxOptions(); - - if (schema_params != nullptr) { - params->beta = schema_params->beta(); - } else { - // TODO(b/157480169): We should either return kTfLiteError or fill in some - // reasonable defaults in the params struct. We are not doing so until we - // better understand the ramifications of changing the legacy behavior. - } - - *builtin_data = params.release(); - return OkStatus(); -} - -// We have this parse function instead of directly returning OkStatus() from the -// switch-case in ParseOpData because this function is used as part of the -// selective registration for the OpResolver implementation in micro. -absl::Status ParseSpaceToBatchNd(const Operator*, BuiltinDataAllocator*, - void**) { - return OkStatus(); -} - -absl::Status ParseSpaceToDepth(const Operator* op, - BuiltinDataAllocator* allocator, - void** builtin_data) { - CheckParsePointerParams(op, allocator, builtin_data); - - SafeBuiltinDataAllocator safe_allocator(allocator); - std::unique_ptr - params = safe_allocator.Allocate(); - TFL_MIGRATION_ENSURE(params != nullptr); - - const auto* schema_params = op->builtin_options_as_SpaceToDepthOptions(); - if (schema_params != nullptr) { - params->block_size = schema_params->block_size(); - } else { - // TODO(b/157480169): We should either return kTfLiteError or fill in some - // reasonable defaults in the params struct. We are not doing so until we - // better understand the ramifications of changing the legacy behavior. - } - - *builtin_data = params.release(); - return OkStatus(); -} - -absl::Status ParseSplit(const Operator* op, BuiltinDataAllocator* allocator, - void** builtin_data) { - CheckParsePointerParams(op, allocator, builtin_data); - - SafeBuiltinDataAllocator safe_allocator(allocator); - std::unique_ptr - params = safe_allocator.Allocate(); - TFL_MIGRATION_ENSURE(params != nullptr); - - const SplitOptions* schema_params = op->builtin_options_as_SplitOptions(); - - if (schema_params != nullptr) { - params->num_splits = schema_params->num_splits(); - } else { - // TODO(b/157480169): We should either return kTfLiteError or fill in some - // reasonable defaults in the params struct. We are not doing so until we - // better understand the ramifications of changing the legacy behavior. - } - - *builtin_data = params.release(); - return OkStatus(); -} - -absl::Status ParseSplitV(const Operator* op, BuiltinDataAllocator* allocator, - void** builtin_data) { - CheckParsePointerParams(op, allocator, builtin_data); - SafeBuiltinDataAllocator safe_allocator(allocator); - - std::unique_ptr - params = safe_allocator.Allocate(); - TFL_MIGRATION_ENSURE(params != nullptr); - - const SplitVOptions* schema_params = op->builtin_options_as_SplitVOptions(); - - if (schema_params != nullptr) { - params->num_splits = schema_params->num_splits(); - } else { - // TODO(b/157480169): We should either return kTfLiteError or fill in some - // reasonable defaults in the params struct. We are not doing so until we - // better understand the ramifications of changing the legacy behavior. - } - - *builtin_data = params.release(); - return OkStatus(); -} - -absl::Status ParseUnidirectionalSequenceLSTM(const Operator* op, - BuiltinDataAllocator* allocator, - void** builtin_data) { - CheckParsePointerParams(op, allocator, builtin_data); - SafeBuiltinDataAllocator safe_allocator(allocator); - auto params = - safe_allocator.Allocate(); - TFL_MIGRATION_ENSURE(params != nullptr); - if (const auto* seq_lstm_params = - op->builtin_options_as_UnidirectionalSequenceLSTMOptions()) { - params->activation = - ConvertActivation(seq_lstm_params->fused_activation_function()); - params->cell_clip = seq_lstm_params->cell_clip(); - params->proj_clip = seq_lstm_params->proj_clip(); - params->time_major = seq_lstm_params->time_major(); - params->asymmetric_quantize_inputs = - seq_lstm_params->asymmetric_quantize_inputs(); - params->diagonal_recurrent_tensors = - seq_lstm_params->diagonal_recurrent_tensors(); - } - *builtin_data = params.release(); - return OkStatus(); -} - -absl::Status ParseSqueeze(const Operator* op, BuiltinDataAllocator* allocator, - void** builtin_data) { - CheckParsePointerParams(op, allocator, builtin_data); - SafeBuiltinDataAllocator safe_allocator(allocator); - - std::unique_ptr - params = safe_allocator.Allocate(); - TFL_MIGRATION_ENSURE(params != nullptr); - - const SqueezeOptions* schema_params = op->builtin_options_as_SqueezeOptions(); - - if (schema_params != nullptr) { - const auto* squeeze_dims = schema_params->squeeze_dims(); - if (squeeze_dims != nullptr) { - TFL_MIGRATION_ENSURE_STATUS( - FlatBufferIntVectorToArray(sizeof(params->squeeze_dims), squeeze_dims, - params->squeeze_dims, "squeeze")); - params->num_squeeze_dims = squeeze_dims->size(); - } else { - params->num_squeeze_dims = 0; - } - } else { - // TODO(b/157480169): We should either return kTfLiteError or fill in some - // reasonable defaults in the params struct. We are not doing so until we - // better understand the ramifications of changing the legacy behavior. - } - - *builtin_data = params.release(); - return OkStatus(); -} - -// We have this parse function instead of directly returning OkStatus() from the -// switch-case in ParseOpData because this function is used as part of the -// selective registration for the OpResolver implementation in micro. -absl::Status ParseSqrt(const Operator*, BuiltinDataAllocator*, void**) { - return OkStatus(); -} - -// We have this parse function instead of directly returning OkStatus() from the -// switch-case in ParseOpData because this function is used as part of the -// selective registration for the OpResolver implementation in micro. -absl::Status ParseSquare(const Operator*, BuiltinDataAllocator*, void**) { - return OkStatus(); -} - -// We have this parse function instead of directly returning OkStatus() from the -// switch-case in ParseOpData because this function is used as part of the -// selective registration for the OpResolver implementation in micro. -absl::Status ParseSquaredDifference(const Operator*, BuiltinDataAllocator*, - void**) { - return OkStatus(); -} - -absl::Status ParseStridedSlice(const Operator* op, - BuiltinDataAllocator* allocator, - void** builtin_data) { - CheckParsePointerParams(op, allocator, builtin_data); - - SafeBuiltinDataAllocator safe_allocator(allocator); - std::unique_ptr - params = safe_allocator.Allocate(); - TFL_MIGRATION_ENSURE(params != nullptr); - - const StridedSliceOptions* schema_params = - op->builtin_options_as_StridedSliceOptions(); - - if (schema_params != nullptr) { - params->begin_mask = schema_params->begin_mask(); - params->end_mask = schema_params->end_mask(); - params->ellipsis_mask = schema_params->ellipsis_mask(); - params->new_axis_mask = schema_params->new_axis_mask(); - params->shrink_axis_mask = schema_params->shrink_axis_mask(); - params->offset = schema_params->offset(); - } else { - // TODO(b/157480169): We should either return kTfLiteError or fill in some - // reasonable defaults in the params struct. We are not doing so until we - // better understand the ramifications of changing the legacy behavior. - } - - *builtin_data = params.release(); - return OkStatus(); -} - -absl::Status ParseSub(const Operator* op, BuiltinDataAllocator* allocator, - void** builtin_data) { - CheckParsePointerParams(op, allocator, builtin_data); - - SafeBuiltinDataAllocator safe_allocator(allocator); - std::unique_ptr - params = safe_allocator.Allocate(); - TFL_MIGRATION_ENSURE(params != nullptr); - - const SubOptions* schema_params = op->builtin_options_as_SubOptions(); - - if (schema_params != nullptr) { - params->activation = - ConvertActivation(schema_params->fused_activation_function()); - params->pot_scale_int16 = schema_params->pot_scale_int16(); - } else { - // TODO(b/157480169): We should either return kTfLiteError or fill in some - // reasonable defaults in the params struct. We are not doing so until we - // better understand the ramifications of changing the legacy behavior. - } - - *builtin_data = params.release(); - return OkStatus(); -} - -absl::Status ParseSvdf(const Operator* op, BuiltinDataAllocator* allocator, - void** builtin_data) { - CheckParsePointerParams(op, allocator, builtin_data); - - SafeBuiltinDataAllocator safe_allocator(allocator); - std::unique_ptr - params = safe_allocator.Allocate(); - TFL_MIGRATION_ENSURE(params != nullptr); - - const SVDFOptions* schema_params = op->builtin_options_as_SVDFOptions(); - if (schema_params != nullptr) { - params->rank = schema_params->rank(); - params->activation = - ConvertActivation(schema_params->fused_activation_function()); - params->asymmetric_quantize_inputs = - schema_params->asymmetric_quantize_inputs(); - } else { - // TODO(b/157480169): We should either return kTfLiteError or fill in some - // reasonable defaults in the params struct. We are not doing so until we - // better understand the ramifications of changing the legacy behavior. - } - - *builtin_data = params.release(); - return OkStatus(); -} - -// We have this parse function instead of directly returning OkStatus() from the -// switch-case in ParseOpData because this function is used as part of the -// selective registration for the OpResolver implementation in micro. -absl::Status ParseTanh(const Operator*, BuiltinDataAllocator*, void**) { - return OkStatus(); -} -// -// We have this parse function instead of directly returning OkStatus() from the -// switch-case in ParseOpData because this function is used as part of the -// selective registration for the OpResolver implementation in micro. -absl::Status ParseTranspose(const Operator*, BuiltinDataAllocator*, void**) { - return OkStatus(); -} - -absl::Status ParseTransposeConv(const Operator* op, - BuiltinDataAllocator* allocator, - void** builtin_data) { - CheckParsePointerParams(op, allocator, builtin_data); - - SafeBuiltinDataAllocator safe_allocator(allocator); - std::unique_ptr - params = safe_allocator.Allocate(); - TFL_MIGRATION_ENSURE(params != nullptr); - const TransposeConvOptions* transpose_conv_params = - op->builtin_options_as_TransposeConvOptions(); - if (transpose_conv_params != nullptr) { - params->padding = ConvertPadding(transpose_conv_params->padding()); - params->stride_width = transpose_conv_params->stride_w(); - params->stride_height = transpose_conv_params->stride_h(); - - params->activation = - ConvertActivation(transpose_conv_params->fused_activation_function()); - TFL_MIGRATION_ENSURE_STATUS( - ConvertTensorType(transpose_conv_params->quantized_bias_type(), - ¶ms->quantized_bias_type)); - } else { - // TODO(b/157480169): We should either return kTfLiteError or fill in some - // reasonable defaults in the params struct. We are not doing so until we - // better understand the ramifications of changing the legacy behavior. - } - *builtin_data = params.release(); - return OkStatus(); -} - -absl::Status ParseUnpack(const Operator* op, BuiltinDataAllocator* allocator, - void** builtin_data) { - CheckParsePointerParams(op, allocator, builtin_data); - - SafeBuiltinDataAllocator safe_allocator(allocator); - std::unique_ptr - params = safe_allocator.Allocate(); - TFL_MIGRATION_ENSURE(params != nullptr); - - const UnpackOptions* schema_params = op->builtin_options_as_UnpackOptions(); - - if (schema_params != nullptr) { - params->num = schema_params->num(); - params->axis = schema_params->axis(); - } else { - // TODO(b/157480169): We should either return kTfLiteError or fill in some - // reasonable defaults in the params struct. We are not doing so until we - // better understand the ramifications of changing the legacy behavior. - } - - *builtin_data = params.release(); - return OkStatus(); -} - -absl::Status ParseVarHandle(const Operator* op, BuiltinDataAllocator* allocator, - void** builtin_data) { - CheckParsePointerParams(op, allocator, builtin_data); - - SafeBuiltinDataAllocator safe_allocator(allocator); - std::unique_ptr - params = safe_allocator.Allocate(); - TFL_MIGRATION_ENSURE(params != nullptr); - - const VarHandleOptions* schema_params = - op->builtin_options_as_VarHandleOptions(); - - if (schema_params != nullptr) { - if (schema_params->container()) { - params->container = schema_params->container()->c_str(); - } - if (schema_params->shared_name()) { - params->shared_name = schema_params->shared_name()->c_str(); - } - } else { - // TODO(b/157480169): We should either return kTfLiteError or fill in some - // reasonable defaults in the params struct. We are not doing so until we - // better understand the ramifications of changing the legacy behavior. - } - - *builtin_data = params.release(); - return OkStatus(); -} - -absl::Status ParseWhile(const Operator* op, BuiltinDataAllocator* allocator, - void** builtin_data) { - CheckParsePointerParams(op, allocator, builtin_data); - - SafeBuiltinDataAllocator safe_allocator(allocator); - std::unique_ptr - params = safe_allocator.Allocate(); - TFL_MIGRATION_ENSURE(params != nullptr); - - const WhileOptions* schema_params = op->builtin_options_as_WhileOptions(); - - if (schema_params != nullptr) { - params->cond_subgraph_index = schema_params->cond_subgraph_index(); - params->body_subgraph_index = schema_params->body_subgraph_index(); - } else { - // TODO(b/157480169): We should either return kTfLiteError or fill in some - // reasonable defaults in the params struct. We are not doing so until we - // better understand the ramifications of changing the legacy behavior. - } - - *builtin_data = params.release(); - return OkStatus(); -} - -// We have this parse function instead of directly returning OkStatus() from the -// switch-case in ParseOpData because this function is used as part of the -// selective registration for the OpResolver implementation in micro. -absl::Status ParseZerosLike(const Operator*, BuiltinDataAllocator*, void**) { - return OkStatus(); -} - -// We have this parse function instead of directly returning OkStatus() from the -// switch-case in ParseOpData because this function is used as part of the -// selective registration for the OpResolver implementation in micro. -absl::Status ParseBitwiseXor(const Operator*, BuiltinDataAllocator*, void**) { - return OkStatus(); -} - -// We have this parse function instead of directly returning OkStatus() from the -// switch-case in ParseOpData because this function is used as part of the -// selective registration for the OpResolver implementation in micro. -absl::Status ParseRightShift(const Operator*, BuiltinDataAllocator*, void**) { - return OkStatus(); -} - -absl::Status ParseOpData(const Operator* op, BuiltinOperator op_type, - BuiltinDataAllocator* allocator, void** builtin_data) { -// TODO(b/145762662): It would be preferable to have the build graph for TF Lite -// Micro not have the ParseOpData function at all. This would require splitting -// the current file into two separate files, one of which defines the -// ParseOpData function and the other that defines the operator specific parse -// functions (e.g. ParseAdd). -// -// Such a split was attempted but was not worth the effort at the time because -// of the following reasons: -// * We could either duplicate the functions and the SafeBuiltinDataAllocator -// class in the anonymous namespace of this file, or attempt to make a common -// library with these helper functions and class. -// * Making a common library with a separate build target was not feasible as -// it introduced circular dependencies due to the ErrorReporter and a common -// .cc and .h within the same api build target the also cause circular -// dependencies due to the BuiltinDataAllocator class. -// * If all the builtin operators were to have their own parse functions, or we -// were ok with some amount of code duplication, then this split of the .cc -// files would be a lot more feasible. -#ifdef TF_LITE_STATIC_MEMORY - auto error_message = - "ParseOpData is unsupported on TfLiteMicro, please use the operator " - "specific parse functions (e.g. ParseAdd etc.).\n"; - LOG(ERROR) << error_message; - return absl::UnimplementedError(error_message); -#else - return ParseOpDataTfLite(op, op_type, allocator, builtin_data); -#endif -} - -} // namespace tflite_migration diff --git a/tensorflow/compiler/mlir/lite/core/api/flatbuffer_conversions.h b/tensorflow/compiler/mlir/lite/core/api/flatbuffer_conversions.h deleted file mode 100644 index 5a6aa526e0971b..00000000000000 --- a/tensorflow/compiler/mlir/lite/core/api/flatbuffer_conversions.h +++ /dev/null @@ -1,440 +0,0 @@ -/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_MLIR_LITE_CORE_API_FLATBUFFER_CONVERSIONS_H_ -#define TENSORFLOW_COMPILER_MLIR_LITE_CORE_API_FLATBUFFER_CONVERSIONS_H_ - -#include -#include - -#include "absl/status/status.h" -#include "tensorflow/compiler/mlir/lite/core/c/tflite_types.h" -#include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" - -namespace tflite_migration { - -using tflite::Operator; - -// Interface class for builtin data allocations. -class BuiltinDataAllocator { - public: - virtual void* Allocate(size_t size, size_t alignment_hint) = 0; - virtual void Deallocate(void* data) = 0; - - // Allocate a structure, but make sure it is a POD structure that doesn't - // require constructors to run. The reason we do this, is that Interpreter's C - // extension part will take ownership so destructors will not be run during - // deallocation. - template - T* AllocatePOD() { - // TODO(b/154346074): Change this to is_trivially_destructible when all - // platform targets support that properly. - static_assert(std::is_pod::value, "Builtin data structure must be POD."); - void* allocated_memory = this->Allocate(sizeof(T), alignof(T)); - return new (allocated_memory) T(); - } - - virtual ~BuiltinDataAllocator() = default; -}; - -// Parse the appropriate data out of the op. -// -// This handles builtin data explicitly as there are flatbuffer schemas. -// If it returns kTfLiteOk, it passes the data out with `builtin_data`. The -// calling function has to pass in an allocator object, and this allocator -// will be called to reserve space for the output data. If the calling -// function's allocator reserves memory on the heap, then it's the calling -// function's responsibility to free it. -// If it returns kTfLiteError, `builtin_data` will be `nullptr`. -absl::Status ParseOpData(const tflite::Operator* op, - tflite::BuiltinOperator op_type, - BuiltinDataAllocator* allocator, void** builtin_data); - -// Converts the tensor data type used in the flat buffer to the representation -// used by the runtime. -absl::Status ConvertTensorType(tflite::TensorType tensor_type, - TfLiteType* type); - -absl::Status ParseAbs(const Operator* op, BuiltinDataAllocator* allocator, - void** builtin_data); - -absl::Status ParseAdd(const Operator* op, BuiltinDataAllocator* allocator, - void** builtin_data); - -absl::Status ParseAddN(const Operator* op, BuiltinDataAllocator* allocator, - void** builtin_data); - -absl::Status ParseArgMax(const Operator* op, BuiltinDataAllocator* allocator, - void** builtin_data); - -absl::Status ParseArgMin(const Operator* op, BuiltinDataAllocator* allocator, - void** builtin_data); - -absl::Status ParseAssignVariable(const Operator* op, - - BuiltinDataAllocator* allocator, - void** builtin_data); - -absl::Status ParseBatchMatMul(const Operator* op, - BuiltinDataAllocator* allocator, - void** builtin_data); - -absl::Status ParseBatchToSpaceNd(const Operator* op, - - BuiltinDataAllocator* allocator, - void** builtin_data); - -absl::Status ParseBroadcastArgs(const Operator* op, - - BuiltinDataAllocator* allocator, - void** builtin_data); - -absl::Status ParseBroadcastTo(const Operator* op, - BuiltinDataAllocator* allocator, - void** builtin_data); - -absl::Status ParseCallOnce(const Operator* op, BuiltinDataAllocator* allocator, - void** builtin_data); - -absl::Status ParseCeil(const Operator* op, BuiltinDataAllocator* allocator, - void** builtin_data); - -absl::Status ParseCast(const Operator* op, BuiltinDataAllocator* allocator, - void** builtin_data); - -absl::Status ParseConcatenation(const Operator* op, - - BuiltinDataAllocator* allocator, - void** builtin_data); - -absl::Status ParseConv2D(const Operator* op, BuiltinDataAllocator* allocator, - void** builtin_data); - -absl::Status ParseCos(const Operator* op, BuiltinDataAllocator* allocator, - void** builtin_data); - -absl::Status ParseCumsum(const Operator* op, BuiltinDataAllocator* allocator, - void** builtin_data); - -absl::Status ParseDepthToSpace(const Operator* op, - - BuiltinDataAllocator* allocator, - void** builtin_data); - -absl::Status ParseDepthwiseConv2D(const Operator* op, - - BuiltinDataAllocator* allocator, - void** builtin_data); - -absl::Status ParseDequantize(const Operator* op, - BuiltinDataAllocator* allocator, - void** builtin_data); - -absl::Status ParseDiv(const Operator* op, BuiltinDataAllocator* allocator, - void** builtin_data); - -absl::Status ParseElu(const Operator* op, BuiltinDataAllocator* allocator, - void** builtin_data); - -absl::Status ParseEmbeddingLookup(const Operator* op, - - BuiltinDataAllocator* allocator, - void** builtin_data); - -absl::Status ParseEqual(const Operator* op, BuiltinDataAllocator* allocator, - void** builtin_data); - -absl::Status ParseExp(const Operator* op, BuiltinDataAllocator* allocator, - void** builtin_data); - -absl::Status ParseExpandDims(const Operator* op, - BuiltinDataAllocator* allocator, - void** builtin_data); - -absl::Status ParseFill(const Operator* op, BuiltinDataAllocator* allocator, - void** builtin_data); - -absl::Status ParseFloor(const Operator* op, BuiltinDataAllocator* allocator, - void** builtin_data); - -absl::Status ParseFloorDiv(const Operator* op, BuiltinDataAllocator* allocator, - void** builtin_data); - -absl::Status ParseFloorMod(const Operator* op, BuiltinDataAllocator* allocator, - void** builtin_data); - -absl::Status ParseFullyConnected(const Operator* op, - - BuiltinDataAllocator* allocator, - void** builtin_data); - -absl::Status ParseGather(const Operator* op, BuiltinDataAllocator* allocator, - void** builtin_data); - -absl::Status ParseGatherNd(const Operator* op, BuiltinDataAllocator* allocator, - void** builtin_data); - -absl::Status ParseGreater(const Operator* op, BuiltinDataAllocator* allocator, - void** builtin_data); - -absl::Status ParseGreaterEqual(const Operator* op, - - BuiltinDataAllocator* allocator, - void** builtin_data); - -absl::Status ParseHardSwish(const Operator* op, BuiltinDataAllocator* allocator, - void** builtin_data); - -absl::Status ParseIf(const Operator* op, BuiltinDataAllocator* allocator, - void** builtin_data); - -absl::Status ParseL2Normalization(const Operator* op, - - BuiltinDataAllocator* allocator, - void** builtin_data); - -absl::Status ParseLeakyRelu(const Operator* op, BuiltinDataAllocator* allocator, - void** builtin_data); - -absl::Status ParseLess(const Operator* op, BuiltinDataAllocator* allocator, - void** builtin_data); - -absl::Status ParseLessEqual(const Operator* op, BuiltinDataAllocator* allocator, - void** builtin_data); - -absl::Status ParseLog(const Operator* op, BuiltinDataAllocator* allocator, - void** builtin_data); - -absl::Status ParseLogicalAnd(const Operator* op, - BuiltinDataAllocator* allocator, - void** builtin_data); - -absl::Status ParseLogicalNot(const Operator* op, - BuiltinDataAllocator* allocator, - void** builtin_data); - -absl::Status ParseLogicalOr(const Operator* op, BuiltinDataAllocator* allocator, - void** builtin_data); - -absl::Status ParseLogistic(const Operator* op, BuiltinDataAllocator* allocator, - void** builtin_data); - -absl::Status ParseLogSoftmax(const Operator* op, - BuiltinDataAllocator* allocator, - void** builtin_data); - -absl::Status ParseLSTM(const Operator* op, BuiltinDataAllocator* allocator, - void** builtin_data); - -absl::Status ParseMaximum(const Operator* op, BuiltinDataAllocator* allocator, - void** builtin_data); - -absl::Status ParseMinimum(const Operator* op, BuiltinDataAllocator* allocator, - void** builtin_data); - -absl::Status ParseMirrorPad(const Operator* op, BuiltinDataAllocator* allocator, - void** builtin_data); - -absl::Status ParseMul(const Operator* op, BuiltinDataAllocator* allocator, - void** builtin_data); - -absl::Status ParseNeg(const Operator* op, BuiltinDataAllocator* allocator, - void** builtin_data); - -absl::Status ParseNotEqual(const Operator* op, BuiltinDataAllocator* allocator, - void** builtin_data); - -absl::Status ParsePack(const Operator* op, BuiltinDataAllocator* allocator, - void** builtin_data); - -absl::Status ParsePad(const Operator* op, BuiltinDataAllocator* allocator, - void** builtin_data); - -absl::Status ParsePadV2(const Operator* op, BuiltinDataAllocator* allocator, - void** builtin_data); - -absl::Status ParsePool(const Operator* op, BuiltinDataAllocator* allocator, - void** builtin_data); - -absl::Status ParsePow(const Operator* op, BuiltinDataAllocator* allocator, - void** builtin_data); - -absl::Status ParsePrelu(const Operator* op, BuiltinDataAllocator* allocator, - void** builtin_data); - -absl::Status ParseQuantize(const Operator* op, BuiltinDataAllocator* allocator, - void** builtin_data); - -absl::Status ParseReadVariable(const Operator* op, - - BuiltinDataAllocator* allocator, - void** builtin_data); - -absl::Status ParseReducer(const Operator* op, BuiltinDataAllocator* allocator, - void** builtin_data); - -absl::Status ParseRelu(const Operator* op, BuiltinDataAllocator* allocator, - void** builtin_data); - -absl::Status ParseRelu6(const Operator* op, BuiltinDataAllocator* allocator, - void** builtin_data); - -absl::Status ParseReshape(const Operator* op, BuiltinDataAllocator* allocator, - void** builtin_data); - -absl::Status ParseResizeBilinear(const Operator* op, - - BuiltinDataAllocator* allocator, - void** builtin_data); - -absl::Status ParseResizeNearestNeighbor(const Operator* op, - - BuiltinDataAllocator* allocator, - void** builtin_data); - -absl::Status ParseRound(const Operator* op, BuiltinDataAllocator* allocator, - void** builtin_data); - -absl::Status ParseRsqrt(const Operator* op, BuiltinDataAllocator* allocator, - void** builtin_data); - -absl::Status ParseSelectV2(const Operator* op, BuiltinDataAllocator* allocator, - void** builtin_data); - -absl::Status ParseShape(const Operator* op, BuiltinDataAllocator* allocator, - void** builtin_data); - -absl::Status ParseSin(const Operator* op, BuiltinDataAllocator* allocator, - void** builtin_data); - -absl::Status ParseSlice(const Operator* op, BuiltinDataAllocator* allocator, - void** builtin_data); - -absl::Status ParseSoftmax(const Operator* op, BuiltinDataAllocator* allocator, - void** builtin_data); - -absl::Status ParseSpaceToBatchNd(const Operator* op, - - BuiltinDataAllocator* allocator, - void** builtin_data); - -absl::Status ParseSpaceToDepth(const Operator* op, - - BuiltinDataAllocator* allocator, - void** builtin_data); - -absl::Status ParseSplit(const Operator* op, BuiltinDataAllocator* allocator, - void** builtin_data); - -absl::Status ParseSplitV(const Operator* op, BuiltinDataAllocator* allocator, - void** builtin_data); - -absl::Status ParseSqueeze(const Operator* op, BuiltinDataAllocator* allocator, - void** builtin_data); - -absl::Status ParseSqrt(const Operator* op, BuiltinDataAllocator* allocator, - void** builtin_data); - -absl::Status ParseSquare(const Operator* op, BuiltinDataAllocator* allocator, - void** builtin_data); - -absl::Status ParseSquaredDifference(const Operator* op, - - BuiltinDataAllocator* allocator, - void** builtin_data); - -absl::Status ParseStridedSlice(const Operator* op, - - BuiltinDataAllocator* allocator, - void** builtin_data); - -absl::Status ParseSub(const Operator* op, BuiltinDataAllocator* allocator, - void** builtin_data); - -absl::Status ParseSvdf(const Operator* op, BuiltinDataAllocator* allocator, - void** builtin_data); - -absl::Status ParseTanh(const Operator* op, BuiltinDataAllocator* allocator, - void** builtin_data); - -absl::Status ParseTranspose(const Operator* op, BuiltinDataAllocator* allocator, - void** builtin_data); - -absl::Status ParseTransposeConv(const Operator* op, - - BuiltinDataAllocator* allocator, - void** builtin_data); - -absl::Status ParseUnpack(const Operator* op, BuiltinDataAllocator* allocator, - void** builtin_data); - -absl::Status ParseUnidirectionalSequenceLSTM(const Operator* op, - - BuiltinDataAllocator* allocator, - void** builtin_data); - -absl::Status ParseVarHandle(const Operator* op, BuiltinDataAllocator* allocator, - void** builtin_data); - -absl::Status ParseWhile(const Operator* op, BuiltinDataAllocator* allocator, - void** builtin_data); - -absl::Status ParseZerosLike(const Operator* op, BuiltinDataAllocator* allocator, - void** builtin_data); - -absl::Status ParseBitwiseXor(const Operator* op, - BuiltinDataAllocator* allocator, - void** builtin_data); - -absl::Status ParseRightShift(const Operator* op, - BuiltinDataAllocator* allocator, - void** builtin_data); - -absl::Status ParseStablehloScatter(const Operator* op, - - BuiltinDataAllocator* allocator, - void** builtin_data); - -absl::Status ParseStablehloRngBitGenerator(const Operator* op, - - BuiltinDataAllocator* allocator, - void** builtin_data); - -absl::Status ParseStablehloGather(const Operator* op, - - BuiltinDataAllocator* allocator, - void** builtin_data); - -absl::Status ParseStablehloReduceWindow(const Operator* op, - - BuiltinDataAllocator* allocator, - void** builtin_data); - -absl::Status ParseStablehloPad(const Operator* op, - - BuiltinDataAllocator* allocator, - void** builtin_data); - -absl::Status ParseStablehloComposite(const Operator* op, - - BuiltinDataAllocator* allocator, - void** builtin_data); - -absl::Status ParseStablehloShiftLeft(const Operator* op, - BuiltinDataAllocator* allocator, - void** builtin_data); - -} // namespace tflite_migration - -#endif // TENSORFLOW_COMPILER_MLIR_LITE_CORE_API_FLATBUFFER_CONVERSIONS_H_ diff --git a/tensorflow/compiler/mlir/lite/core/api/flatbuffer_conversions_test.cc b/tensorflow/compiler/mlir/lite/core/api/flatbuffer_conversions_test.cc deleted file mode 100644 index ac6ba0243eaa7f..00000000000000 --- a/tensorflow/compiler/mlir/lite/core/api/flatbuffer_conversions_test.cc +++ /dev/null @@ -1,873 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/mlir/lite/core/api/flatbuffer_conversions.h" - -#include -#include -#include -#include -#include -#include - -#include -#include -#include "flatbuffers/buffer.h" // from @flatbuffers -#include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers -#include "tensorflow/compiler/mlir/lite/core/c/builtin_op_data.h" -#include "tensorflow/compiler/mlir/lite/core/c/tflite_types.h" -#include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" - -using testing::AllOf; -using testing::Each; -using testing::ElementsAre; -using testing::Eq; -using testing::HasSubstr; -using testing::StrEq; -using tflite::BuiltinOptions; -using tflite::BuiltinOptions2; -using tflite::BuiltinOptions_SqueezeOptions; -using tflite::CustomOptionsFormat_FLEXBUFFERS; - -namespace tflite_migration { -using tflite::ActivationFunctionType_RELU; -using tflite::BuiltinOperator_CONV_2D; -using tflite::BuiltinOperator_CUSTOM; -using tflite::BuiltinOperator_FULLY_CONNECTED; -using tflite::BuiltinOperator_RESHAPE; -using tflite::BuiltinOperator_SQUEEZE; -using tflite::BuiltinOperator_STABLEHLO_PAD; -using tflite::BuiltinOperator_STABLEHLO_REDUCE_WINDOW; -using tflite::BuiltinOptions2_StablehloPadOptions; -using tflite::BuiltinOptions2_StablehloReduceWindowOptions; -using tflite::BuiltinOptions_Conv2DOptions; -using tflite::BuiltinOptions_FullyConnectedOptions; -using tflite::BuiltinOptions_NONE; -using tflite::BuiltinOptions_ReshapeOptions; -using tflite::CreateReshapeOptions; -using tflite::CreateSqueezeOptions; -using tflite::CreateStablehloPadOptions; -using tflite::CreateStablehloReduceWindowOptions; -using tflite::FullyConnectedOptionsWeightsFormat; -using tflite::Padding_SAME; -using tflite::TensorType_BFLOAT16; -using tflite::TensorType_FLOAT16; -using tflite::TensorType_FLOAT32; -using tflite::TensorType_INT4; - -namespace { - -using std::string; - -// Used to determine how the op data parsing function creates its working space. -class MockDataAllocator : public BuiltinDataAllocator { - public: - MockDataAllocator() : is_allocated_(false) {} - void* Allocate(size_t size, size_t alignment_hint) override { - EXPECT_FALSE(is_allocated_); - const int max_size = kBufferSize; - EXPECT_LE(size, max_size); - is_allocated_ = true; - return buffer_; - } - void Deallocate(void* data) override { is_allocated_ = false; } - - private: - static constexpr int kBufferSize = 1024; - char buffer_[kBufferSize]; - bool is_allocated_; -}; - -} // namespace - -class FlatbufferConversionsTest : public ::testing::Test { - public: - const Operator* BuildTestOperator(BuiltinOptions op_type, - flatbuffers::Offset options) { - flatbuffers::Offset offset = - CreateOperatorDirect(builder_, 0, nullptr, nullptr, op_type, options, - nullptr, CustomOptionsFormat_FLEXBUFFERS, nullptr); - builder_.Finish(offset); - void* pointer = builder_.GetBufferPointer(); - return flatbuffers::GetRoot(pointer); - } - - const Operator* BuildTestOperator(BuiltinOptions2 op_type, - flatbuffers::Offset options) { - flatbuffers::Offset offset = CreateOperatorDirect( - builder_, /*opcode_index=*/0, /*inputs=*/nullptr, /*outputs=*/nullptr, - /*builtin_options_type=*/tflite::BuiltinOptions_NONE, - /*builtin_options=*/0, /*custom_options=*/nullptr, - /*custom_options_format=*/tflite::CustomOptionsFormat_FLEXBUFFERS, - /*mutating_variable_inputs=*/nullptr, /*intermediates=*/nullptr, - /*large_custom_options_offset=*/0, /*large_custom_options_size=*/0, - /*builtin_options_2_type=*/op_type, - /*builtin_options_2=*/options); - builder_.Finish(offset); - void* pointer = builder_.GetBufferPointer(); - return flatbuffers::GetRoot(pointer); - } - - protected: - MockDataAllocator mock_allocator_; - flatbuffers::FlatBufferBuilder builder_; -}; - -TEST_F(FlatbufferConversionsTest, ParseSqueezeAll) { - const Operator* op = BuildTestOperator( - BuiltinOptions_SqueezeOptions, CreateSqueezeOptions(builder_).Union()); - void* output_data = nullptr; - EXPECT_TRUE( - ParseOpData(op, BuiltinOperator_SQUEEZE, &mock_allocator_, &output_data) - .ok()); -} - -TEST_F(FlatbufferConversionsTest, ParseDynamicReshape) { - const Operator* op = BuildTestOperator( - BuiltinOptions_ReshapeOptions, CreateReshapeOptions(builder_).Union()); - void* output_data = nullptr; - EXPECT_TRUE( - ParseOpData(op, BuiltinOperator_RESHAPE, &mock_allocator_, &output_data) - .ok()); -} - -TEST_F(FlatbufferConversionsTest, TestParseOpDataConv) { - const Operator* conv_op = - BuildTestOperator(BuiltinOptions_Conv2DOptions, - CreateConv2DOptions(builder_, Padding_SAME, 1, 2, - ActivationFunctionType_RELU, 3, 4) - .Union()); - void* output_data = nullptr; - EXPECT_TRUE(ParseOpData(conv_op, BuiltinOperator_CONV_2D, &mock_allocator_, - &output_data) - .ok()); - EXPECT_NE(nullptr, output_data); - TfLiteConvParams* params = reinterpret_cast(output_data); - EXPECT_EQ(kTfLitePaddingSame, params->padding); - EXPECT_EQ(1, params->stride_width); - EXPECT_EQ(2, params->stride_height); - EXPECT_EQ(kTfLiteActRelu, params->activation); - EXPECT_EQ(3, params->dilation_width_factor); - EXPECT_EQ(4, params->dilation_height_factor); -} - -TEST_F(FlatbufferConversionsTest, ParseBadFullyConnected) { - const Operator* conv_op = BuildTestOperator( - BuiltinOptions_FullyConnectedOptions, - CreateFullyConnectedOptions( - builder_, ActivationFunctionType_RELU, - static_cast(-1), true) - .Union()); - void* output_data = nullptr; - EXPECT_FALSE(ParseOpData(conv_op, BuiltinOperator_FULLY_CONNECTED, - &mock_allocator_, &output_data) - .ok()); -} - -TEST_F(FlatbufferConversionsTest, TestParseOpDataCustom) { - const Operator* custom_op = - BuildTestOperator(BuiltinOptions_NONE, flatbuffers::Offset()); - void* output_data = nullptr; - EXPECT_TRUE(ParseOpData(custom_op, BuiltinOperator_CUSTOM, &mock_allocator_, - &output_data) - .ok()); - EXPECT_EQ(nullptr, output_data); -} - -TEST_F(FlatbufferConversionsTest, TestConvertTensorType) { - TfLiteType type; - EXPECT_TRUE(ConvertTensorType(TensorType_FLOAT32, &type).ok()); - EXPECT_EQ(kTfLiteFloat32, type); -} - -TEST_F(FlatbufferConversionsTest, TestConvertTensorTypeFloat16) { - TfLiteType type; - EXPECT_TRUE(ConvertTensorType(TensorType_FLOAT16, &type).ok()); - EXPECT_EQ(kTfLiteFloat16, type); -} - -TEST_F(FlatbufferConversionsTest, TestConvertTensorTypeBFloat16) { - TfLiteType type; - EXPECT_TRUE(ConvertTensorType(TensorType_BFLOAT16, &type).ok()); - EXPECT_EQ(kTfLiteBFloat16, type); -} - -TEST_F(FlatbufferConversionsTest, TestConvertTensorTypeInt4) { - TfLiteType type; - EXPECT_TRUE(ConvertTensorType(TensorType_INT4, &type).ok()); - EXPECT_EQ(kTfLiteInt4, type); -} - -class StablehloReduceWindowFlatbufferConversionsTest - : public FlatbufferConversionsTest { - public: - static constexpr int kMaxDims = - TFLITE_STABLEHLO_REDUCE_WINDOW_PARAMS_MAX_DIMENSION_COUNT; - static constexpr int64_t kValidValue = 5; - - auto ValidAttr() { - return builder_.CreateVector(std::vector(kMaxDims, kValidValue)); - } - - auto InvalidAttr() { - return builder_.CreateVector( - std::vector(kMaxDims + 1, kValidValue)); - } - - auto ValidPaddingAttr() { - return builder_.CreateVector( - std::vector(2 * kMaxDims, kValidValue)); - } - - auto InvalidPaddingAttr() { - return builder_.CreateVector( - std::vector(2 * kMaxDims + 1, kValidValue)); - } - - auto EmptyAttr() { return builder_.CreateVector({}); } -}; - -TEST_F(StablehloReduceWindowFlatbufferConversionsTest, Succeeds) { - const Operator* stablehlo_reduce_window_op = BuildTestOperator( - BuiltinOptions2_StablehloReduceWindowOptions, - CreateStablehloReduceWindowOptions( - builder_, - /*window_dimensions=*/builder_.CreateVector({1, 2}), - /*window_strides=*/builder_.CreateVector({3, 4}), - /*base_dilations=*/builder_.CreateVector({5, 6}), - /*window_dilations=*/builder_.CreateVector({7, 8}), - /*padding=*/builder_.CreateVector({9, 10, 11, 12}), - /*body_subgraph_index=*/13) - .Union()); - TfLiteStablehloReduceWindowParams* output_data = nullptr; - EXPECT_TRUE(ParseOpData(stablehlo_reduce_window_op, - BuiltinOperator_STABLEHLO_REDUCE_WINDOW, - &mock_allocator_, (void**)&output_data) - .ok()); - - EXPECT_THAT(std::make_tuple(output_data->window_dimensions, 2), - ElementsAre(1, 2)); - EXPECT_THAT(std::make_tuple(output_data->window_strides, 2), - ElementsAre(3, 4)); - EXPECT_THAT(std::make_tuple(output_data->base_dilations, 2), - ElementsAre(5, 6)); - EXPECT_THAT(std::make_tuple(output_data->window_dilations, 2), - ElementsAre(7, 8)); - EXPECT_THAT(std::make_tuple(output_data->padding, 4), - ElementsAre(9, 10, 11, 12)); - EXPECT_THAT(output_data->body_subgraph_index, Eq(13)); -} - -TEST_F(StablehloReduceWindowFlatbufferConversionsTest, - FailsWithNoWindowDimensions) { - TfLiteStablehloReduceWindowParams* output_data = nullptr; - auto status = ParseOpData( - BuildTestOperator( - BuiltinOptions2_StablehloReduceWindowOptions, - CreateStablehloReduceWindowOptions(builder_, - /*window_dimensions=*/0, - /*window_strides=*/ValidAttr(), - /*base_dilations=*/ValidAttr(), - /*window_dilations=*/ValidAttr(), - /*padding=*/ValidPaddingAttr(), - /*body_subgraph_index=*/13) - .Union()), - BuiltinOperator_STABLEHLO_REDUCE_WINDOW, &mock_allocator_, - (void**)&output_data); - EXPECT_FALSE(status.ok()); - EXPECT_THAT(status.message(), - HasSubstr("'window_dimensions' attribute is not optional for " - "'stablehlo.reduce_window' and cannot be empty.")); -} - -TEST_F(StablehloReduceWindowFlatbufferConversionsTest, - SucceedsWithNoWindowStrides) { - TfLiteStablehloReduceWindowParams* output_data = nullptr; - auto status = ParseOpData( - BuildTestOperator( - BuiltinOptions2_StablehloReduceWindowOptions, - CreateStablehloReduceWindowOptions(builder_, - /*window_dimensions=*/ValidAttr(), - /*window_strides=*/0, - /*base_dilations=*/ValidAttr(), - /*window_dilations=*/ValidAttr(), - /*padding=*/ValidPaddingAttr(), - /*body_subgraph_index=*/13) - .Union()), - BuiltinOperator_STABLEHLO_REDUCE_WINDOW, &mock_allocator_, - (void**)&output_data); - EXPECT_TRUE(status.ok()); - EXPECT_THAT(status.message(), StrEq("")); - EXPECT_THAT(std::make_tuple(output_data->window_dimensions, kMaxDims), - Each(kValidValue)); - EXPECT_THAT(std::make_tuple(output_data->window_strides, kMaxDims), Each(1)); - EXPECT_THAT(std::make_tuple(output_data->base_dilations, kMaxDims), - Each(kValidValue)); - EXPECT_THAT(std::make_tuple(output_data->window_dilations, kMaxDims), - Each(kValidValue)); - EXPECT_THAT(std::make_tuple(output_data->padding, 2 * kMaxDims), - Each(kValidValue)); - EXPECT_THAT(output_data->body_subgraph_index, Eq(13)); -} - -TEST_F(StablehloReduceWindowFlatbufferConversionsTest, - SucceedsWithNoBaseDilations) { - TfLiteStablehloReduceWindowParams* output_data = nullptr; - auto status = ParseOpData( - BuildTestOperator( - BuiltinOptions2_StablehloReduceWindowOptions, - CreateStablehloReduceWindowOptions(builder_, - /*window_dimensions=*/ValidAttr(), - /*window_strides=*/ValidAttr(), - /*base_dilations=*/0, - /*window_dilations=*/ValidAttr(), - /*padding=*/ValidPaddingAttr(), - /*body_subgraph_index=*/13) - .Union()), - BuiltinOperator_STABLEHLO_REDUCE_WINDOW, &mock_allocator_, - (void**)&output_data); - EXPECT_TRUE(status.ok()); - EXPECT_THAT(status.message(), StrEq("")); - EXPECT_THAT(std::make_tuple(output_data->window_dimensions, kMaxDims), - Each(kValidValue)); - EXPECT_THAT(std::make_tuple(output_data->window_strides, kMaxDims), - Each(kValidValue)); - EXPECT_THAT(std::make_tuple(output_data->base_dilations, kMaxDims), Each(1)); - EXPECT_THAT(std::make_tuple(output_data->window_dilations, kMaxDims), - Each(kValidValue)); - EXPECT_THAT(std::make_tuple(output_data->padding, 2 * kMaxDims), - Each(kValidValue)); - EXPECT_THAT(output_data->body_subgraph_index, Eq(13)); -} - -TEST_F(StablehloReduceWindowFlatbufferConversionsTest, - SucceedsWithNoWindowDilations) { - TfLiteStablehloReduceWindowParams* output_data = nullptr; - auto status = ParseOpData( - BuildTestOperator( - BuiltinOptions2_StablehloReduceWindowOptions, - CreateStablehloReduceWindowOptions(builder_, - /*window_dimensions=*/ValidAttr(), - /*window_strides=*/ValidAttr(), - /*base_dilations=*/ValidAttr(), - /*window_dilations=*/0, - /*padding=*/ValidPaddingAttr(), - /*body_subgraph_index=*/13) - .Union()), - BuiltinOperator_STABLEHLO_REDUCE_WINDOW, &mock_allocator_, - (void**)&output_data); - EXPECT_TRUE(status.ok()); - EXPECT_THAT(status.message(), StrEq("")); - EXPECT_THAT(std::make_tuple(output_data->window_dimensions, kMaxDims), - Each(kValidValue)); - EXPECT_THAT(std::make_tuple(output_data->window_strides, kMaxDims), - Each(kValidValue)); - EXPECT_THAT(std::make_tuple(output_data->base_dilations, kMaxDims), - Each(kValidValue)); - EXPECT_THAT(std::make_tuple(output_data->window_dilations, kMaxDims), - Each(1)); - EXPECT_THAT(std::make_tuple(output_data->padding, 2 * kMaxDims), - Each(kValidValue)); - EXPECT_THAT(output_data->body_subgraph_index, Eq(13)); -} - -TEST_F(StablehloReduceWindowFlatbufferConversionsTest, SucceedsWithNoPadding) { - TfLiteStablehloReduceWindowParams* output_data = nullptr; - auto status = ParseOpData( - BuildTestOperator( - BuiltinOptions2_StablehloReduceWindowOptions, - CreateStablehloReduceWindowOptions(builder_, - /*window_dimensions=*/ValidAttr(), - /*window_strides=*/ValidAttr(), - /*base_dilations=*/ValidAttr(), - /*window_dilations=*/ValidAttr(), - /*padding=*/0, - /*body_subgraph_index=*/13) - .Union()), - BuiltinOperator_STABLEHLO_REDUCE_WINDOW, &mock_allocator_, - (void**)&output_data); - EXPECT_TRUE(status.ok()); - EXPECT_THAT(status.message(), StrEq("")); - EXPECT_THAT(std::make_tuple(output_data->window_dimensions, kMaxDims), - Each(kValidValue)); - EXPECT_THAT(std::make_tuple(output_data->window_strides, kMaxDims), - Each(kValidValue)); - EXPECT_THAT(std::make_tuple(output_data->base_dilations, kMaxDims), - Each(kValidValue)); - EXPECT_THAT(std::make_tuple(output_data->window_dilations, kMaxDims), - Each(kValidValue)); - EXPECT_THAT(std::make_tuple(output_data->padding, 2 * kMaxDims), Each(0)); - EXPECT_THAT(output_data->body_subgraph_index, Eq(13)); -} - -TEST_F(StablehloReduceWindowFlatbufferConversionsTest, - FailsWithEmptyWindowDimensions) { - TfLiteStablehloReduceWindowParams* output_data = nullptr; - auto status = ParseOpData( - BuildTestOperator( - BuiltinOptions2_StablehloReduceWindowOptions, - CreateStablehloReduceWindowOptions(builder_, - /*window_dimensions=*/EmptyAttr(), - /*window_strides=*/ValidAttr(), - /*base_dilations=*/ValidAttr(), - /*window_dilations=*/ValidAttr(), - /*padding=*/ValidPaddingAttr(), - /*body_subgraph_index=*/13) - .Union()), - BuiltinOperator_STABLEHLO_REDUCE_WINDOW, &mock_allocator_, - (void**)&output_data); - EXPECT_FALSE(status.ok()); - EXPECT_THAT(status.message(), - HasSubstr("'window_dimensions' attribute is not optional for " - "'stablehlo.reduce_window' and cannot be empty.")); -} - -TEST_F(StablehloReduceWindowFlatbufferConversionsTest, - SucceedsWithEmptyWindowStrides) { - TfLiteStablehloReduceWindowParams* output_data = nullptr; - auto status = ParseOpData( - BuildTestOperator( - BuiltinOptions2_StablehloReduceWindowOptions, - CreateStablehloReduceWindowOptions(builder_, - /*window_dimensions=*/ValidAttr(), - /*window_strides=*/EmptyAttr(), - /*base_dilations=*/ValidAttr(), - /*window_dilations=*/ValidAttr(), - /*padding=*/ValidPaddingAttr(), - /*body_subgraph_index=*/13) - .Union()), - BuiltinOperator_STABLEHLO_REDUCE_WINDOW, &mock_allocator_, - (void**)&output_data); - EXPECT_TRUE(status.ok()); - EXPECT_THAT(status.message(), StrEq("")); - EXPECT_THAT(std::make_tuple(output_data->window_dimensions, kMaxDims), - Each(kValidValue)); - EXPECT_THAT(std::make_tuple(output_data->window_strides, kMaxDims), Each(1)); - EXPECT_THAT(std::make_tuple(output_data->base_dilations, kMaxDims), - Each(kValidValue)); - EXPECT_THAT(std::make_tuple(output_data->window_dilations, kMaxDims), - Each(kValidValue)); - EXPECT_THAT(std::make_tuple(output_data->padding, 2 * kMaxDims), - Each(kValidValue)); - EXPECT_THAT(output_data->body_subgraph_index, Eq(13)); -} - -TEST_F(StablehloReduceWindowFlatbufferConversionsTest, - SucceedsWithEmptyBaseDilations) { - TfLiteStablehloReduceWindowParams* output_data = nullptr; - auto status = ParseOpData( - BuildTestOperator( - BuiltinOptions2_StablehloReduceWindowOptions, - CreateStablehloReduceWindowOptions(builder_, - /*window_dimensions=*/ValidAttr(), - /*window_strides=*/ValidAttr(), - /*base_dilations=*/EmptyAttr(), - /*window_dilations=*/ValidAttr(), - /*padding=*/ValidPaddingAttr(), - /*body_subgraph_index=*/13) - .Union()), - BuiltinOperator_STABLEHLO_REDUCE_WINDOW, &mock_allocator_, - (void**)&output_data); - EXPECT_TRUE(status.ok()); - EXPECT_THAT(status.message(), StrEq("")); - EXPECT_THAT(std::make_tuple(output_data->window_dimensions, kMaxDims), - Each(kValidValue)); - EXPECT_THAT(std::make_tuple(output_data->window_strides, kMaxDims), - Each(kValidValue)); - EXPECT_THAT(std::make_tuple(output_data->base_dilations, kMaxDims), Each(1)); - EXPECT_THAT(std::make_tuple(output_data->window_dilations, kMaxDims), - Each(kValidValue)); - EXPECT_THAT(std::make_tuple(output_data->padding, 2 * kMaxDims), - Each(kValidValue)); - EXPECT_THAT(output_data->body_subgraph_index, Eq(13)); -} - -TEST_F(StablehloReduceWindowFlatbufferConversionsTest, - SucceedsWithEmptyWindowDilations) { - TfLiteStablehloReduceWindowParams* output_data = nullptr; - auto status = ParseOpData( - BuildTestOperator( - BuiltinOptions2_StablehloReduceWindowOptions, - CreateStablehloReduceWindowOptions(builder_, - /*window_dimensions=*/ValidAttr(), - /*window_strides=*/ValidAttr(), - /*base_dilations=*/ValidAttr(), - /*window_dilations=*/EmptyAttr(), - /*padding=*/ValidPaddingAttr(), - /*body_subgraph_index=*/13) - .Union()), - BuiltinOperator_STABLEHLO_REDUCE_WINDOW, &mock_allocator_, - (void**)&output_data); - EXPECT_TRUE(status.ok()); - EXPECT_THAT(status.message(), StrEq("")); - EXPECT_THAT(std::make_tuple(output_data->window_dimensions, kMaxDims), - Each(kValidValue)); - EXPECT_THAT(std::make_tuple(output_data->window_strides, kMaxDims), - Each(kValidValue)); - EXPECT_THAT(std::make_tuple(output_data->base_dilations, kMaxDims), - Each(kValidValue)); - EXPECT_THAT(std::make_tuple(output_data->window_dilations, kMaxDims), - Each(1)); - EXPECT_THAT(std::make_tuple(output_data->padding, 2 * kMaxDims), - Each(kValidValue)); - EXPECT_THAT(output_data->body_subgraph_index, Eq(13)); -} - -TEST_F(StablehloReduceWindowFlatbufferConversionsTest, - SucceedsWithEmptyPadding) { - TfLiteStablehloReduceWindowParams* output_data = nullptr; - auto status = ParseOpData( - BuildTestOperator( - BuiltinOptions2_StablehloReduceWindowOptions, - CreateStablehloReduceWindowOptions(builder_, - /*window_dimensions=*/ValidAttr(), - /*window_strides=*/ValidAttr(), - /*base_dilations=*/ValidAttr(), - /*window_dilations=*/ValidAttr(), - /*padding=*/EmptyAttr(), - /*body_subgraph_index=*/13) - .Union()), - BuiltinOperator_STABLEHLO_REDUCE_WINDOW, &mock_allocator_, - (void**)&output_data); - EXPECT_TRUE(status.ok()); - EXPECT_THAT(status.message(), StrEq("")); - EXPECT_THAT(std::make_tuple(output_data->window_dimensions, kMaxDims), - Each(kValidValue)); - EXPECT_THAT(std::make_tuple(output_data->window_strides, kMaxDims), - Each(kValidValue)); - EXPECT_THAT(std::make_tuple(output_data->base_dilations, kMaxDims), - Each(kValidValue)); - EXPECT_THAT(std::make_tuple(output_data->window_dilations, kMaxDims), - Each(kValidValue)); - EXPECT_THAT(std::make_tuple(output_data->padding, 2 * kMaxDims), Each(0)); - EXPECT_THAT(output_data->body_subgraph_index, Eq(13)); -} - -TEST_F(StablehloReduceWindowFlatbufferConversionsTest, - SucceedsWithParamsAtMaxDims) { - TfLiteStablehloReduceWindowParams* output_data = nullptr; - auto status = ParseOpData( - BuildTestOperator( - BuiltinOptions2_StablehloReduceWindowOptions, - CreateStablehloReduceWindowOptions(builder_, - /*window_dimensions=*/ValidAttr(), - /*window_strides=*/ValidAttr(), - /*base_dilations=*/ValidAttr(), - /*window_dilations=*/ValidAttr(), - /*padding=*/ValidPaddingAttr(), - /*body_subgraph_index=*/13) - .Union()), - BuiltinOperator_STABLEHLO_REDUCE_WINDOW, &mock_allocator_, - (void**)&output_data); - EXPECT_TRUE(status.ok()); - EXPECT_THAT(status.message(), StrEq("")); -} - -TEST_F(StablehloReduceWindowFlatbufferConversionsTest, - FailsWhenWindowDimensionsHasMoreThanMaxDims) { - TfLiteStablehloReduceWindowParams* output_data = nullptr; - auto status = ParseOpData( - BuildTestOperator(BuiltinOptions2_StablehloReduceWindowOptions, - CreateStablehloReduceWindowOptions( - builder_, - /*window_dimensions=*/InvalidAttr(), - /*window_strides=*/ValidAttr(), - /*base_dilations=*/ValidAttr(), - /*window_dilations=*/ValidAttr(), - /*padding=*/ValidPaddingAttr(), - /*body_subgraph_index=*/13) - .Union()), - BuiltinOperator_STABLEHLO_REDUCE_WINDOW, &mock_allocator_, - (void**)&output_data); - EXPECT_FALSE(status.ok()); - EXPECT_THAT(status.message(), - AllOf(HasSubstr("Found too many dimensions in the input array of " - "operation 'stablehlo.reduce_window'."), - HasSubstr("Check the 'window_dimensions' attribute."))); -} - -TEST_F(StablehloReduceWindowFlatbufferConversionsTest, - FailsWhenWindowStridesHasWrongDimCount) { - TfLiteStablehloReduceWindowParams* output_data = nullptr; - auto status = ParseOpData( - BuildTestOperator( - BuiltinOptions2_StablehloReduceWindowOptions, - CreateStablehloReduceWindowOptions(builder_, - /*window_dimensions=*/ValidAttr(), - /*window_strides=*/InvalidAttr(), - /*base_dilations=*/ValidAttr(), - /*window_dilations=*/ValidAttr(), - /*padding=*/ValidPaddingAttr(), - /*body_subgraph_index=*/13) - .Union()), - BuiltinOperator_STABLEHLO_REDUCE_WINDOW, &mock_allocator_, - (void**)&output_data); - EXPECT_FALSE(status.ok()); - EXPECT_THAT( - status.message(), - HasSubstr("'window_strides' attribute of 'stablehlo.reduce_window' does " - "not have the expected size")); -} - -TEST_F(StablehloReduceWindowFlatbufferConversionsTest, - FailsWhenBaseDilationsHasWrongDimCount) { - TfLiteStablehloReduceWindowParams* output_data = nullptr; - auto status = ParseOpData( - BuildTestOperator( - BuiltinOptions2_StablehloReduceWindowOptions, - CreateStablehloReduceWindowOptions(builder_, - /*window_dimensions=*/ValidAttr(), - /*window_strides=*/ValidAttr(), - /*base_dilations=*/InvalidAttr(), - /*window_dilations=*/ValidAttr(), - /*padding=*/ValidPaddingAttr(), - /*body_subgraph_index=*/13) - .Union()), - BuiltinOperator_STABLEHLO_REDUCE_WINDOW, &mock_allocator_, - (void**)&output_data); - EXPECT_FALSE(status.ok()); - EXPECT_THAT( - status.message(), - HasSubstr("'base_dilations' attribute of 'stablehlo.reduce_window' does " - "not have the expected size")); -} - -TEST_F(StablehloReduceWindowFlatbufferConversionsTest, - FailsWhenWindowDilationsHasWrongDimCount) { - TfLiteStablehloReduceWindowParams* output_data = nullptr; - auto status = ParseOpData( - BuildTestOperator( - BuiltinOptions2_StablehloReduceWindowOptions, - CreateStablehloReduceWindowOptions(builder_, - /*window_dimensions=*/ValidAttr(), - /*window_strides=*/ValidAttr(), - /*base_dilations=*/ValidAttr(), - /*window_dilations=*/InvalidAttr(), - /*padding=*/ValidPaddingAttr(), - /*body_subgraph_index=*/13) - .Union()), - BuiltinOperator_STABLEHLO_REDUCE_WINDOW, &mock_allocator_, - (void**)&output_data); - EXPECT_FALSE(status.ok()); - EXPECT_THAT( - status.message(), - HasSubstr( - "'window_dilations' attribute of 'stablehlo.reduce_window' does " - "not have the expected size")); -} - -TEST_F(StablehloReduceWindowFlatbufferConversionsTest, - FailsWhenPaddingHasWrongDimCount) { - TfLiteStablehloReduceWindowParams* output_data = nullptr; - auto status = ParseOpData( - BuildTestOperator( - BuiltinOptions2_StablehloReduceWindowOptions, - CreateStablehloReduceWindowOptions(builder_, - /*window_dimensions=*/ValidAttr(), - /*window_strides=*/ValidAttr(), - /*base_dilations=*/ValidAttr(), - /*window_dilations=*/ValidAttr(), - /*padding=*/InvalidPaddingAttr(), - /*body_subgraph_index=*/13) - .Union()), - BuiltinOperator_STABLEHLO_REDUCE_WINDOW, &mock_allocator_, - (void**)&output_data); - EXPECT_FALSE(status.ok()); - EXPECT_THAT(status.message(), - HasSubstr("'padding' attribute of 'stablehlo.reduce_window' does " - "not have the expected size")); -} - -TEST_F(StablehloReduceWindowFlatbufferConversionsTest, FailsWithWrongOptions) { - const Operator* stablehlo_reduce_window_op = - BuildTestOperator(BuiltinOptions2_StablehloReduceWindowOptions, 0); - TfLiteStablehloReduceWindowParams* output_data = nullptr; - auto status = ParseOpData(stablehlo_reduce_window_op, - BuiltinOperator_STABLEHLO_REDUCE_WINDOW, - &mock_allocator_, (void**)&output_data); - EXPECT_FALSE(status.ok()); - EXPECT_THAT( - status.message(), - HasSubstr( - "Could not get 'stablehlo.reduce_window' operation parameters.")); -} - -TEST_F(StablehloReduceWindowFlatbufferConversionsTest, DeathTests) { - const Operator* stablehlo_reduce_window_op = BuildTestOperator( - BuiltinOptions2_StablehloReduceWindowOptions, - CreateStablehloReduceWindowOptions( - builder_, /*window_dimensions=*/ValidAttr(), - /*window_strides=*/ValidAttr(), - /*base_dilations=*/ValidAttr(), - /*window_dilations=*/ValidAttr(), - /*padding=*/ValidPaddingAttr(), /*body_subgraph_index=*/13) - .Union()); - TfLiteStablehloReduceWindowParams* output_data = nullptr; -#ifdef NDEBUG - GTEST_SKIP(); -#endif - EXPECT_DEATH(ParseOpData(nullptr, BuiltinOperator_STABLEHLO_REDUCE_WINDOW, - &mock_allocator_, (void**)&output_data) - .IgnoreError(), - ""); - EXPECT_DEATH(ParseOpData(stablehlo_reduce_window_op, - BuiltinOperator_STABLEHLO_REDUCE_WINDOW, nullptr, - (void**)&output_data) - .IgnoreError(), - ""); - EXPECT_DEATH(ParseOpData(stablehlo_reduce_window_op, - BuiltinOperator_STABLEHLO_REDUCE_WINDOW, - &mock_allocator_, nullptr) - .IgnoreError(), - ""); -} - -class StablehloPadFlatbufferConversionsTest : public FlatbufferConversionsTest { - public: - static constexpr int kMaxDims = - TFLITE_STABLEHLO_PAD_PARAMS_MAX_DIMENSION_COUNT; - static constexpr int64_t kValidValue = 5; -}; - -TEST_F(StablehloPadFlatbufferConversionsTest, Succeeds) { - const Operator* stablehlo_pad_op = BuildTestOperator( - BuiltinOptions2_StablehloPadOptions, - CreateStablehloPadOptions( - builder_, - /*edge_padding_low=*/builder_.CreateVector({1, 0, -1}), - /*edge_padding_high=*/builder_.CreateVector({2, 0, -2}), - /*interior_padding=*/builder_.CreateVector({3, 0, 3})) - .Union()); - TfLiteStablehloPadParams* output_data = nullptr; - EXPECT_TRUE(ParseOpData(stablehlo_pad_op, BuiltinOperator_STABLEHLO_PAD, - &mock_allocator_, (void**)&output_data) - .ok()); - EXPECT_THAT(std::make_tuple(output_data->edge_padding_low, 3), - ElementsAre(1, 0, -1)); - EXPECT_THAT(std::make_tuple(output_data->edge_padding_high, 3), - ElementsAre(2, 0, -2)); - EXPECT_THAT(std::make_tuple(output_data->interior_padding, 3), - ElementsAre(3, 0, 3)); -} - -TEST_F(StablehloPadFlatbufferConversionsTest, FailsWithMissingLowPadding) { - const Operator* stablehlo_pad_op = BuildTestOperator( - BuiltinOptions2_StablehloPadOptions, - CreateStablehloPadOptions( - builder_, - /*edge_padding_low=*/0, - /*edge_padding_high=*/builder_.CreateVector({2, 0, -2}), - /*interior_padding=*/builder_.CreateVector({3, 0, 3})) - .Union()); - TfLiteStablehloPadParams* output_data = nullptr; - auto status = ParseOpData(stablehlo_pad_op, BuiltinOperator_STABLEHLO_PAD, - &mock_allocator_, (void**)&output_data); - EXPECT_FALSE(status.ok()); - EXPECT_THAT( - status.message(), - AllOf( - HasSubstr("Input array not provided for operation 'stablehlo.pad'."), - HasSubstr("Check the 'edge_padding_low' attribute."))); -} - -TEST_F(StablehloPadFlatbufferConversionsTest, FailsWithMissingHighPadding) { - const Operator* stablehlo_pad_op = BuildTestOperator( - BuiltinOptions2_StablehloPadOptions, - CreateStablehloPadOptions( - builder_, - /*edge_padding_low=*/builder_.CreateVector({1, 0, -1}), - /*edge_padding_high=*/0, - /*interior_padding=*/builder_.CreateVector({3, 0, 3})) - .Union()); - TfLiteStablehloPadParams* output_data = nullptr; - auto status = ParseOpData(stablehlo_pad_op, BuiltinOperator_STABLEHLO_PAD, - &mock_allocator_, (void**)&output_data); - EXPECT_FALSE(status.ok()); - EXPECT_THAT( - status.message(), - AllOf( - HasSubstr("Input array not provided for operation 'stablehlo.pad'."), - HasSubstr("Check the 'edge_padding_high' attribute."))); -} - -TEST_F(StablehloPadFlatbufferConversionsTest, FailsWithMissingInteriorPadding) { - const Operator* stablehlo_pad_op = BuildTestOperator( - BuiltinOptions2_StablehloPadOptions, - CreateStablehloPadOptions( - builder_, - /*edge_padding_low=*/builder_.CreateVector({1, 0, -1}), - /*edge_padding_high=*/builder_.CreateVector({2, 0, -2}), - /*interior_padding=*/0) - .Union()); - TfLiteStablehloPadParams* output_data = nullptr; - auto status = ParseOpData(stablehlo_pad_op, BuiltinOperator_STABLEHLO_PAD, - &mock_allocator_, (void**)&output_data); - EXPECT_FALSE(status.ok()); - EXPECT_THAT( - status.message(), - AllOf( - HasSubstr("Input array not provided for operation 'stablehlo.pad'."), - HasSubstr("Check the 'interior_padding' attribute."))); -} - -TEST_F(StablehloPadFlatbufferConversionsTest, FailsInconsistentSizes) { - const Operator* stablehlo_pad_op = BuildTestOperator( - BuiltinOptions2_StablehloPadOptions, - CreateStablehloPadOptions( - builder_, - /*edge_padding_low=*/builder_.CreateVector({1, 0, -1}), - /*edge_padding_high=*/builder_.CreateVector({2, 0, -2}), - /*interior_padding=*/builder_.CreateVector({3, 0, -3, 5})) - .Union()); - TfLiteStablehloPadParams* output_data = nullptr; - auto status = ParseOpData(stablehlo_pad_op, BuiltinOperator_STABLEHLO_PAD, - &mock_allocator_, (void**)&output_data); - EXPECT_FALSE(status.ok()); - EXPECT_THAT(status.message(), - HasSubstr("'stablehlo.pad' operation parameter array sizes are " - "not consistent.")); -} - -TEST_F(StablehloPadFlatbufferConversionsTest, FailsWithWrongOptions) { - const Operator* stablehlo_pad_op = BuildTestOperator(BuiltinOptions_NONE, 0); - TfLiteStablehloPadParams* output_data = nullptr; - auto status = ParseOpData(stablehlo_pad_op, BuiltinOperator_STABLEHLO_PAD, - &mock_allocator_, (void**)&output_data); - EXPECT_FALSE(status.ok()); - EXPECT_THAT(status.message(), - HasSubstr("Could not get 'stablehlo.pad' operation parameters.")); -} - -TEST_F(StablehloPadFlatbufferConversionsTest, DeathTests) { - const Operator* stablehlo_pad_op = BuildTestOperator(BuiltinOptions_NONE, 0); - TfLiteStablehloPadParams* output_data = nullptr; -#ifdef NDEBUG - GTEST_SKIP(); -#endif - EXPECT_DEATH(ParseOpData(nullptr, BuiltinOperator_STABLEHLO_PAD, - &mock_allocator_, (void**)&output_data) - .IgnoreError(), - ""); - EXPECT_DEATH(ParseOpData(stablehlo_pad_op, BuiltinOperator_STABLEHLO_PAD, - nullptr, (void**)&output_data) - .IgnoreError(), - ""); - EXPECT_DEATH(ParseOpData(stablehlo_pad_op, BuiltinOperator_STABLEHLO_PAD, - &mock_allocator_, nullptr) - .IgnoreError(), - ""); -} - -} // namespace tflite_migration diff --git a/tensorflow/compiler/mlir/lite/core/c/BUILD b/tensorflow/compiler/mlir/lite/core/c/BUILD index 2e07c72e817602..3338e5b8940fca 100644 --- a/tensorflow/compiler/mlir/lite/core/c/BUILD +++ b/tensorflow/compiler/mlir/lite/core/c/BUILD @@ -9,30 +9,19 @@ package( licenses = ["notice"], ) -exports_files( - srcs = [ - "builtin_op_data.h", - "tflite_types.h", - ], - visibility = [ - "//tensorflow/lite:__subpackages__", - ], -) - # LINT.IfChange(common) cc_library( name = "tflite_common", srcs = [], hdrs = [ "builtin_op_data.h", - "tflite_types.h", + "dimension_type.h", ], compatible_with = get_compatible_with_portable(), copts = tflite_copts(), visibility = [ "//tensorflow/compiler/mlir/lite:__subpackages__", "//tensorflow/compiler/mlir/quantization/tensorflow/utils:__pkg__", - "//tensorflow/lite/core/c:__subpackages__", ], alwayslink = 1, # Why?? TODO(b/161243354): eliminate this. ) diff --git a/tensorflow/compiler/mlir/lite/core/c/builtin_op_data.h b/tensorflow/compiler/mlir/lite/core/c/builtin_op_data.h index 836d80ab59eabf..7a67c630fe1ebd 100644 --- a/tensorflow/compiler/mlir/lite/core/c/builtin_op_data.h +++ b/tensorflow/compiler/mlir/lite/core/c/builtin_op_data.h @@ -1,4 +1,4 @@ -/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -12,67 +12,16 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -/// WARNING: Users of TensorFlow Lite should not include this file directly, -/// but should instead include -/// "third_party/tensorflow/lite/c/builtin_op_data.h". -/// Only the TensorFlow Lite implementation itself should include this -/// file directly. #ifndef TENSORFLOW_COMPILER_MLIR_LITE_CORE_C_BUILTIN_OP_DATA_H_ #define TENSORFLOW_COMPILER_MLIR_LITE_CORE_C_BUILTIN_OP_DATA_H_ -#include // IWYU pragma: keep -#include -#include - -#include "tensorflow/compiler/mlir/lite/core/c/tflite_types.h" - -#ifdef __cplusplus -extern "C" { -#endif // __cplusplus - -// TfLiteReshapeParams can't have dynamic data so we fix the maximum possible -// number of dimensions. -#define TFLITE_RESHAPE_PARAMS_MAX_DIMENSION_COUNT 8 -#define TFLITE_STABLEHLO_SCATTER_PARAMS_MAX_DIMENSION_COUNT 8 -#define TFLITE_STABLEHLO_GATHER_PARAMS_MAX_DIMENSION_COUNT 8 -#define TFLITE_STABLEHLO_REDUCE_WINDOW_PARAMS_MAX_DIMENSION_COUNT 8 -#define TFLITE_STABLEHLO_PAD_PARAMS_MAX_DIMENSION_COUNT 8 - -// TODO(aselle): Consider using "if this then that" for testing. - -// Useful placeholder to put in otherwise empty structs to avoid size warnings. -typedef struct { - char dummy; -} EmptyStructPlaceholder; - -// IMPORTANT: All new members of structs must be added at the end to ensure -// backwards compatibility. - -// Possible padding types (for convolutions) +// LINT.IfChange(enum) typedef enum { kTfLitePaddingUnknown = 0, kTfLitePaddingSame, kTfLitePaddingValid, } TfLitePadding; -typedef enum { - kTfLiteMirrorPaddingUnknown = 0, - kTfLiteMirrorPaddingReflect, - kTfLiteMirrorPaddingSymmetric, -} TfLiteMirrorPaddingMode; - -// TODO(b/130259536): We should move this out of builtin_op_data. -typedef struct { - int width; - int height; - int width_offset; - int height_offset; -} TfLitePaddingValues; - -typedef struct { - TfLiteMirrorPaddingMode mode; -} TfLiteMirrorPaddingParams; - // Possible fused activation functions. typedef enum { kTfLiteActNone = 0, @@ -83,36 +32,16 @@ typedef enum { kTfLiteActSignBit, kTfLiteActSigmoid, } TfLiteFusedActivation; +// LINT.ThenChange(//tensorflow/lite/core/c/builtin_op_data.h) +// LINT.IfChange(struct) +// TODO(b/130259536): We should move this out of builtin_op_data. typedef struct { - // Parameters for CONV_2D version 1. - TfLitePadding padding; - int stride_width; - int stride_height; - TfLiteFusedActivation activation; - - // Parameters for CONV_2D version 2. - // Note: Version 2 supports dilation values not equal to 1. - int dilation_width_factor; - int dilation_height_factor; - - // Parameters for CONV_2D version 7 or above. - // Used to determine the default value for the quantized bias. - TfLiteType quantized_bias_type; -} TfLiteConvParams; - -typedef struct { - TfLitePadding padding; - int stride_width; - int stride_height; - int stride_depth; - int dilation_width_factor; - int dilation_height_factor; - int dilation_depth_factor; - TfLiteFusedActivation activation; -} TfLiteConv3DParams; - -typedef TfLiteConv3DParams TfLiteConv3DTransposeParams; + int width; + int height; + int width_offset; + int height_offset; +} TfLitePaddingValues; typedef struct { TfLitePadding padding; @@ -125,537 +54,6 @@ typedef struct { TfLitePaddingValues padding; } computed; } TfLitePoolParams; - -typedef struct { - // Parameters for DepthwiseConv version 1 or above. - TfLitePadding padding; - int stride_width; - int stride_height; - // `depth_multiplier` is redundant. It's used by CPU kernels in - // TensorFlow 2.0 or below, but ignored in versions above. - // - // The information can be deduced from the shape of input and the shape of - // weights. Since the TFLiteConverter toolchain doesn't support partially - // specified shapes, relying on `depth_multiplier` stops us from supporting - // graphs with dynamic shape tensors. - // - // Note: Some of the delegates (e.g. NNAPI, GPU) are still relying on this - // field. - int depth_multiplier; - TfLiteFusedActivation activation; - // Parameters for DepthwiseConv version 2 or above. - int dilation_width_factor; - int dilation_height_factor; -} TfLiteDepthwiseConvParams; - -typedef struct { - int rank; - TfLiteFusedActivation activation; - - // Parameter for SVDF version 4. - bool asymmetric_quantize_inputs; -} TfLiteSVDFParams; - -typedef struct { - TfLiteFusedActivation activation; - - // Parameter for RNN version 3. - bool asymmetric_quantize_inputs; -} TfLiteRNNParams; - -typedef struct { - bool time_major; - TfLiteFusedActivation activation; - - // Parameter for Sequence RNN version 3. - bool asymmetric_quantize_inputs; -} TfLiteSequenceRNNParams; - -typedef struct { - bool time_major; - TfLiteFusedActivation activation; - bool merge_outputs; - - // Parameter for Bidirectional RNN version 3. - bool asymmetric_quantize_inputs; -} TfLiteBidirectionalSequenceRNNParams; - -typedef enum { - kTfLiteFullyConnectedWeightsFormatDefault = 0, - kTfLiteFullyConnectedWeightsFormatShuffled4x16Int8 = 1, -} TfLiteFullyConnectedWeightsFormat; - -typedef struct { - // Parameters for FullyConnected version 1 or above. - TfLiteFusedActivation activation; - - // Parameters for FullyConnected version 2 or above. - TfLiteFullyConnectedWeightsFormat weights_format; - - // Parameters for FullyConnected version 5 or above. - // If set to true, then the number of dimensions in the input and the output - // tensors are the same. Furthermore, all but the last dimension of the input - // and output shapes will be equal. - bool keep_num_dims; - - // Parameters for FullyConnected version 7 or above. - // If set to true and the weights are quantized, then non constant inputs - // are quantized at evaluation time with asymmetric quantization. - bool asymmetric_quantize_inputs; - - // Parameters for FullyConnected version 10 or above. - // Used to determine the default value for the quantized bias. - TfLiteType quantized_bias_type; -} TfLiteFullyConnectedParams; - -typedef enum { - kTfLiteLshProjectionUnknown = 0, - kTfLiteLshProjectionSparse = 1, - kTfLiteLshProjectionDense = 2, -} TfLiteLSHProjectionType; - -typedef struct { - TfLiteLSHProjectionType type; -} TfLiteLSHProjectionParams; - -typedef struct { - float beta; -} TfLiteSoftmaxParams; - -typedef struct { - int axis; - TfLiteFusedActivation activation; -} TfLiteConcatenationParams; - -typedef struct { - TfLiteFusedActivation activation; - // Parameter added for the version 4. - bool pot_scale_int16; -} TfLiteAddParams; - -typedef struct { - EmptyStructPlaceholder placeholder; -} TfLiteSpaceToBatchNDParams; - -typedef struct { - EmptyStructPlaceholder placeholder; -} TfLiteBatchToSpaceNDParams; - -typedef struct { - bool adj_x; - bool adj_y; - // Parameters for BatchMatMul version 4 or above. - // If set to true and the weights are quantized, then non constant inputs - // are quantized at evaluation time with asymmetric quantization. - bool asymmetric_quantize_inputs; -} TfLiteBatchMatMulParams; - -typedef struct { - TfLiteFusedActivation activation; -} TfLiteMulParams; - -typedef struct { - TfLiteFusedActivation activation; - // Parameter added for the version 5. - bool pot_scale_int16; -} TfLiteSubParams; - -typedef struct { - TfLiteFusedActivation activation; -} TfLiteDivParams; - -typedef struct { - TfLiteFusedActivation activation; -} TfLiteL2NormParams; - -typedef struct { - int radius; - float bias; - float alpha; - float beta; -} TfLiteLocalResponseNormParams; - -typedef enum { - kTfLiteLSTMFullKernel = 0, - kTfLiteLSTMBasicKernel -} TfLiteLSTMKernelType; - -typedef struct { - // Parameters for LSTM version 1. - TfLiteFusedActivation activation; - float cell_clip; - float proj_clip; - - // Parameters for LSTM version 2. - // kTfLiteLSTMBasicKernel is only supported in version 2 or above. - TfLiteLSTMKernelType kernel_type; - - // Parameters for LSTM version 4. - bool asymmetric_quantize_inputs; -} TfLiteLSTMParams; - -typedef struct { - // Parameters needed for the underlying LSTM. - TfLiteFusedActivation activation; - float cell_clip; - float proj_clip; - - // If set to true then the first dimension is time, otherwise batch. - bool time_major; - - // Parameter for unidirectional sequence RNN version 3. - bool asymmetric_quantize_inputs; - - // Parameter for unidirectional sequence RNN version 4. - bool diagonal_recurrent_tensors; -} TfLiteUnidirectionalSequenceLSTMParams; - -typedef struct { - // Parameters supported by version 1: - // Parameters inherited for the LSTM kernel. - TfLiteFusedActivation activation; - float cell_clip; - float proj_clip; - - // If true, store the outputs of both directions in the first output. - bool merge_outputs; - - // Parameters supported by version 2: - // If set to true then the first dimension is time, otherwise batch. - bool time_major; - - // Parameters supported by version 3: - // If set to true, then hybrid ops use asymmetric quantization for inputs. - bool asymmetric_quantize_inputs; -} TfLiteBidirectionalSequenceLSTMParams; - -typedef struct { - bool align_corners; - // half_pixel_centers assumes pixels are of half the actual dimensions, and - // yields more accurate resizes. Corresponds to the same argument for the - // original TensorFlow op in TF2.0. - bool half_pixel_centers; -} TfLiteResizeBilinearParams; - -typedef struct { - bool align_corners; - bool half_pixel_centers; -} TfLiteResizeNearestNeighborParams; - -typedef struct { - EmptyStructPlaceholder placeholder; -} TfLitePadParams; - -typedef struct { - EmptyStructPlaceholder placeholder; -} TfLitePadV2Params; - -typedef struct { - // These fields are only used in old models for backward compatibility. - // In the current implementation, we use the 2nd input of the op as the shape, - // and these fields are unused. - int32_t shape[TFLITE_RESHAPE_PARAMS_MAX_DIMENSION_COUNT]; - int num_dimensions; -} TfLiteReshapeParams; - -typedef struct { - int ngram_size; - int max_skip_size; - bool include_all_ngrams; -} TfLiteSkipGramParams; - -typedef struct { - int block_size; -} TfLiteSpaceToDepthParams; - -typedef struct { - int block_size; -} TfLiteDepthToSpaceParams; - -typedef struct { - TfLiteType in_data_type; - TfLiteType out_data_type; -} TfLiteCastParams; - -typedef enum { - kTfLiteCombinerTypeSum = 0, - kTfLiteCombinerTypeMean = 1, - kTfLiteCombinerTypeSqrtn = 2, -} TfLiteCombinerType; - -typedef struct { - TfLiteCombinerType combiner; -} TfLiteEmbeddingLookupSparseParams; - -typedef struct { - int axis; - int batch_dims; -} TfLiteGatherParams; - -typedef struct { - EmptyStructPlaceholder placeholder; -} TfLiteTransposeParams; - -typedef struct { - bool keep_dims; -} TfLiteReducerParams; - -typedef struct { - int num_splits; -} TfLiteSplitParams; - -typedef struct { - int num_splits; -} TfLiteSplitVParams; - -typedef struct { - // TODO(ahentz): We can't have dynamic data in this struct, at least not yet. - // For now we will fix the maximum possible number of dimensions. - int32_t squeeze_dims[8]; - int num_squeeze_dims; -} TfLiteSqueezeParams; - -typedef struct { - int begin_mask; - int end_mask; - int ellipsis_mask; - int new_axis_mask; - int shrink_axis_mask; - - // Parameters supported by version 8: - // If true, then the end tensor is an offset of the begin tensor. - bool offset; -} TfLiteStridedSliceParams; - -typedef struct { - TfLiteType output_type; -} TfLiteArgMaxParams; - -typedef struct { - TfLiteType output_type; -} TfLiteArgMinParams; - -typedef struct { - // Parameters supported by version 1: - TfLitePadding padding; - int stride_width; - int stride_height; - - // Parameters supported by version 4: - TfLiteFusedActivation activation; - - // Parameters for TransposeConv version 5 or above. - // Used to determine the default value for the quantized bias. - TfLiteType quantized_bias_type; -} TfLiteTransposeConvParams; - -typedef struct { - bool validate_indices; -} TfLiteSparseToDenseParams; - -typedef struct { - TfLiteType out_type; -} TfLiteShapeParams; - -typedef struct { - EmptyStructPlaceholder placeholder; -} TfLiteRankParams; - -typedef struct { - // Parameters supported by version 1: - float min; - float max; - int num_bits; - - // Parameters supported by version 2: - bool narrow_range; -} TfLiteFakeQuantParams; - -typedef struct { - int values_count; - int axis; -} TfLitePackParams; - -typedef struct { - int axis; -} TfLiteOneHotParams; - -typedef struct { - int num; - int axis; -} TfLiteUnpackParams; - -typedef struct { - float alpha; -} TfLiteLeakyReluParams; - -typedef struct { - TfLiteType index_out_type; -} TfLiteUniqueParams; - -typedef struct { - int seq_dim; - int batch_dim; -} TfLiteReverseSequenceParams; - -typedef struct { - EmptyStructPlaceholder placeholder; -} TfLiteMatrixDiagParams; - -typedef struct { - EmptyStructPlaceholder placeholder; -} TfLiteMatrixSetDiagParams; - -typedef struct { - int then_subgraph_index; - int else_subgraph_index; -} TfLiteIfParams; - -typedef struct { - int cond_subgraph_index; - int body_subgraph_index; -} TfLiteWhileParams; - -typedef struct { - bool exclusive; - bool reverse; -} TfLiteCumsumParams; - -typedef struct { - int init_subgraph_index; -} TfLiteCallOnceParams; - -typedef struct { - int table_id; - TfLiteType key_dtype; - TfLiteType value_dtype; -} TfLiteHashtableParams; - -typedef struct { - const char* container; - const char* shared_name; -} TfLiteVarHandleParams; - -typedef struct { - int seed; - int seed2; -} TfLiteRandomParams; - -typedef struct { - int num_boundaries; - // This points to the memory stored in the model (flatbuffer), - // and is not owned. - const float* boundaries; -} TfLiteBucketizeParams; - -typedef struct { - bool approximate; -} TfLiteGeluParams; - -typedef struct { - int64_t dimension; -} TfLiteStablehloConcatenateParams; - -typedef struct { - // See the stablehlo spec for the explanation of the attributes: - // https://github.com/openxla/stablehlo/blob/main/docs/spec.md#scatter - bool indices_are_sorted; - int64_t - update_window_dims[TFLITE_STABLEHLO_SCATTER_PARAMS_MAX_DIMENSION_COUNT]; - int num_update_window_dims; - int64_t - inserted_window_dims[TFLITE_STABLEHLO_SCATTER_PARAMS_MAX_DIMENSION_COUNT]; - int num_inserted_window_dims; - int64_t scatter_dims_to_operand_dims - [TFLITE_STABLEHLO_SCATTER_PARAMS_MAX_DIMENSION_COUNT]; - int num_scatter_dims_to_operand_dims; - int64_t index_vector_dim; - bool unique_indices; - int update_computation_subgraph_index; -} TfLiteStablehloScatterParams; - -typedef enum { - kTfLiteRngAlgorithmUnknown = 0, - // An algorithm auto-selected by the system according to device type. - kTfLiteRngAlgorithmDefault, - // The Philox algorithm, as described in paper - // ['Parallel Random Numbers: As Easy as 1, 2, 3'] - // (https://www.thesalmons.org/john/random123/papers/random123sc11.pdf) - kTfLiteRngAlgorithmPhilox, - // The ThreeFry algorithm, as described in paper - // ['Parallel Random Numbers: As Easy as 1, 2, 3'] - // (https://www.thesalmons.org/john/random123/papers/random123sc11.pdf) - kTfLiteRngAlgorithmThreefry, -} TfLiteRngAlgorithm; - -typedef struct { - TfLiteRngAlgorithm algorithm; -} TfLiteStablehloRngBitGeneratorParams; - -typedef struct { - // See the stablehlo spec for the explanation of the attributes: - // https://github.com/openxla/stablehlo/blob/main/docs/spec.md#gather - int64_t offset_dims[TFLITE_STABLEHLO_GATHER_PARAMS_MAX_DIMENSION_COUNT]; - int num_offset_dims; - int64_t - collapsed_slice_dims[TFLITE_STABLEHLO_GATHER_PARAMS_MAX_DIMENSION_COUNT]; - int num_collapsed_slice_dims; - int64_t start_index_map[TFLITE_STABLEHLO_GATHER_PARAMS_MAX_DIMENSION_COUNT]; - int num_start_index_map; - int64_t index_vector_dim; - int64_t slice_sizes[TFLITE_STABLEHLO_GATHER_PARAMS_MAX_DIMENSION_COUNT]; - int num_slice_sizes; - bool indices_are_sorted; -} TfLiteStablehloGatherParams; - -typedef struct { - // See the stablehlo spec for the explanation of the attributes: - // https://github.com/openxla/stablehlo/blob/main/docs/spec.md#reduce_window - int64_t window_dimensions - [TFLITE_STABLEHLO_REDUCE_WINDOW_PARAMS_MAX_DIMENSION_COUNT]; - int64_t - window_strides[TFLITE_STABLEHLO_REDUCE_WINDOW_PARAMS_MAX_DIMENSION_COUNT]; - int64_t - base_dilations[TFLITE_STABLEHLO_REDUCE_WINDOW_PARAMS_MAX_DIMENSION_COUNT]; - int64_t window_dilations - [TFLITE_STABLEHLO_REDUCE_WINDOW_PARAMS_MAX_DIMENSION_COUNT]; - int64_t - padding[2 * TFLITE_STABLEHLO_REDUCE_WINDOW_PARAMS_MAX_DIMENSION_COUNT]; - int body_subgraph_index; -} TfLiteStablehloReduceWindowParams; - -enum TfLiteReduceWindowFunction { - TfLiteReduceWindowFunctionUnsupported, - TfLiteReduceWindowFunctionAdd, - TfLiteReduceWindowFunctionMul, - TfLiteReduceWindowFunctionMin, - TfLiteReduceWindowFunctionMax, - TfLiteReduceWindowFunctionAll, - TfLiteReduceWindowFunctionAny -}; - -typedef struct { - enum TfLiteReduceWindowFunction reduce_function; -} TfLiteReduceWindowParams; - -typedef struct { - // See the stablehlo spec for the explanation of the attributes: - // https://github.com/openxla/stablehlo/blob/main/docs/spec.md#pad - int64_t edge_padding_low[TFLITE_STABLEHLO_PAD_PARAMS_MAX_DIMENSION_COUNT]; - int64_t edge_padding_high[TFLITE_STABLEHLO_PAD_PARAMS_MAX_DIMENSION_COUNT]; - int64_t interior_padding[TFLITE_STABLEHLO_PAD_PARAMS_MAX_DIMENSION_COUNT]; -} TfLiteStablehloPadParams; - -typedef struct { - const char* name; - int32_t subgraph_index; - int32_t version; - const uint8_t* attributes; - size_t attributes_size; -} TfLiteStablehloCompositeParams; - -#ifdef __cplusplus -} // extern "C" -#endif // __cplusplus +// LINT.ThenChange(//tensorflow/lite/core/c/builtin_op_data.h) #endif // TENSORFLOW_COMPILER_MLIR_LITE_CORE_C_BUILTIN_OP_DATA_H_ diff --git a/tensorflow/compiler/mlir/lite/core/c/dimension_type.h b/tensorflow/compiler/mlir/lite/core/c/dimension_type.h new file mode 100644 index 00000000000000..fd2c6122897065 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/core/c/dimension_type.h @@ -0,0 +1,38 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_CORE_C_DIMENSION_TYPE_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_CORE_C_DIMENSION_TYPE_H_ + +// LINT.IfChange + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + + +/// Storage format of each dimension in a sparse tensor. +typedef enum TfLiteDimensionType { + kTfLiteDimDense = 0, + kTfLiteDimSparseCSR, +} TfLiteDimensionType; + +#ifdef __cplusplus +} // extern "C" + +#endif // __cplusplus +#endif // TENSORFLOW_COMPILER_MLIR_LITE_CORE_C_DIMENSION_TYPE_H_ + +// LINT.ThenChange(//tensorflow/lite/core/c/common.h) diff --git a/tensorflow/compiler/mlir/lite/core/c/tflite_types.h b/tensorflow/compiler/mlir/lite/core/c/tflite_types.h deleted file mode 100644 index 6006b2d3c2ee5d..00000000000000 --- a/tensorflow/compiler/mlir/lite/core/c/tflite_types.h +++ /dev/null @@ -1,70 +0,0 @@ -/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_MLIR_LITE_CORE_C_TFLITE_TYPES_H_ -#define TENSORFLOW_COMPILER_MLIR_LITE_CORE_C_TFLITE_TYPES_H_ - -#include - -#ifdef __cplusplus -extern "C" { -#endif - -/// Types supported by tensor -// LINT.IfChange -typedef enum { - kTfLiteNoType = 0, - kTfLiteFloat32 = 1, - kTfLiteInt32 = 2, - kTfLiteUInt8 = 3, - kTfLiteInt64 = 4, - kTfLiteString = 5, - kTfLiteBool = 6, - kTfLiteInt16 = 7, - kTfLiteComplex64 = 8, - kTfLiteInt8 = 9, - kTfLiteFloat16 = 10, - kTfLiteFloat64 = 11, - kTfLiteComplex128 = 12, - kTfLiteUInt64 = 13, - kTfLiteResource = 14, - kTfLiteVariant = 15, - kTfLiteUInt32 = 16, - kTfLiteUInt16 = 17, - kTfLiteInt4 = 18, - kTfLiteBFloat16 = 19, -} TfLiteType; -// LINT.ThenChange(//tensorflow/lite/profiling/proto/model_runtime_info.proto:EdgeDataType) - -/// Legacy. Will be deprecated in favor of `TfLiteAffineQuantization`. -/// If per-layer quantization is specified this field will still be populated in -/// addition to `TfLiteAffineQuantization`. -/// Parameters for asymmetric quantization. Quantized values can be converted -/// back to float using: `real_value = scale * (quantized_value - zero_point)` -typedef struct TfLiteQuantizationParams { - float scale; - int32_t zero_point; -} TfLiteQuantizationParams; - -/// Storage format of each dimension in a sparse tensor. -typedef enum TfLiteDimensionType { - kTfLiteDimDense = 0, - kTfLiteDimSparseCSR, -} TfLiteDimensionType; - -#ifdef __cplusplus -} // extern C -#endif - -#endif // TENSORFLOW_COMPILER_MLIR_LITE_CORE_C_TFLITE_TYPES_H_ diff --git a/tensorflow/compiler/mlir/lite/kernels/internal/utils/sparsity_format_converter.cc b/tensorflow/compiler/mlir/lite/kernels/internal/utils/sparsity_format_converter.cc index e5db23a8831872..4a28c1474e9be8 100644 --- a/tensorflow/compiler/mlir/lite/kernels/internal/utils/sparsity_format_converter.cc +++ b/tensorflow/compiler/mlir/lite/kernels/internal/utils/sparsity_format_converter.cc @@ -20,7 +20,7 @@ limitations under the License. #include #include "Eigen/Core" // from @eigen_archive -#include "tensorflow/compiler/mlir/lite/core/c/tflite_types.h" +#include "tensorflow/compiler/mlir/lite/core/c/dimension_type.h" namespace tflite_migration { namespace internal { diff --git a/tensorflow/compiler/mlir/lite/kernels/internal/utils/sparsity_format_converter.h b/tensorflow/compiler/mlir/lite/kernels/internal/utils/sparsity_format_converter.h index 56ba7181098c79..12b54502b46369 100644 --- a/tensorflow/compiler/mlir/lite/kernels/internal/utils/sparsity_format_converter.h +++ b/tensorflow/compiler/mlir/lite/kernels/internal/utils/sparsity_format_converter.h @@ -18,7 +18,7 @@ limitations under the License. #include #include "Eigen/Core" // from @eigen_archive -#include "tensorflow/compiler/mlir/lite/core/c/tflite_types.h" +#include "tensorflow/compiler/mlir/lite/core/c/dimension_type.h" namespace tflite_migration { namespace internal { diff --git a/tensorflow/lite/c/BUILD b/tensorflow/lite/c/BUILD index f7032741c62060..f1664849f36e50 100644 --- a/tensorflow/lite/c/BUILD +++ b/tensorflow/lite/c/BUILD @@ -416,7 +416,6 @@ filegroup( "c_api.h", "c_api_types.h", "common.h", - "//tensorflow/compiler/mlir/lite/core/c:tflite_types.h", ], ) diff --git a/tensorflow/lite/core/c/BUILD b/tensorflow/lite/core/c/BUILD index efbdc19755744f..00a1a27ec6d819 100644 --- a/tensorflow/lite/core/c/BUILD +++ b/tensorflow/lite/core/c/BUILD @@ -284,9 +284,6 @@ tflite_cc_library_with_c_headers_test( "//tensorflow/compiler/mlir/lite/experimental/lrt:__subpackages__", "//tensorflow/lite:__subpackages__", ] + c_api_visibility_allowlist(), - deps = [ - "//tensorflow/compiler/mlir/lite/core/c:tflite_common", - ], ) # Test the C extension API code. @@ -341,7 +338,6 @@ tflite_cc_library_with_c_headers_test( visibility = ["//tensorflow/lite:__subpackages__"] + common_header_visibility_allowlist(), deps = [ ":c_api_types", - "//tensorflow/compiler/mlir/lite/core/c:tflite_common", "//tensorflow/lite:tflite_kernel_use_xnnpack_optional", ] + select({ "//tensorflow/lite:tensorflow_profiler_config": [ diff --git a/tensorflow/lite/core/c/builtin_op_data.h b/tensorflow/lite/core/c/builtin_op_data.h index cfe3d825a7fa2a..e1428e72307134 100644 --- a/tensorflow/lite/core/c/builtin_op_data.h +++ b/tensorflow/lite/core/c/builtin_op_data.h @@ -20,7 +20,642 @@ limitations under the License. #ifndef TENSORFLOW_LITE_CORE_C_BUILTIN_OP_DATA_H_ #define TENSORFLOW_LITE_CORE_C_BUILTIN_OP_DATA_H_ -#include "tensorflow/compiler/mlir/lite/core/c/builtin_op_data.h" // IWYU pragma: export -#include "tensorflow/lite/core/c/common.h" // IWYU pragma: export +#include +#include +#include + +#include "tensorflow/lite/core/c/common.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +// TfLiteReshapeParams can't have dynamic data so we fix the maximum possible +// number of dimensions. +#define TFLITE_RESHAPE_PARAMS_MAX_DIMENSION_COUNT 8 +#define TFLITE_STABLEHLO_SCATTER_PARAMS_MAX_DIMENSION_COUNT 8 +#define TFLITE_STABLEHLO_GATHER_PARAMS_MAX_DIMENSION_COUNT 8 +#define TFLITE_STABLEHLO_REDUCE_WINDOW_PARAMS_MAX_DIMENSION_COUNT 8 +#define TFLITE_STABLEHLO_PAD_PARAMS_MAX_DIMENSION_COUNT 8 + +// TODO(aselle): Consider using "if this then that" for testing. + +// Useful placeholder to put in otherwise empty structs to avoid size warnings. +typedef struct { + char dummy; +} EmptyStructPlaceholder; + +// IMPORTANT: All new members of structs must be added at the end to ensure +// backwards compatibility. + +// Possible padding types (for convolutions) +typedef enum { + kTfLitePaddingUnknown = 0, + kTfLitePaddingSame, + kTfLitePaddingValid, +} TfLitePadding; + +typedef enum { + kTfLiteMirrorPaddingUnknown = 0, + kTfLiteMirrorPaddingReflect, + kTfLiteMirrorPaddingSymmetric, +} TfLiteMirrorPaddingMode; + +// TODO(b/130259536): We should move this out of builtin_op_data. +typedef struct { + int width; + int height; + int width_offset; + int height_offset; +} TfLitePaddingValues; + +typedef struct { + TfLiteMirrorPaddingMode mode; +} TfLiteMirrorPaddingParams; + +// Possible fused activation functions. +typedef enum { + kTfLiteActNone = 0, + kTfLiteActRelu, + kTfLiteActReluN1To1, // min(max(-1, x), 1) + kTfLiteActRelu6, // min(max(0, x), 6) + kTfLiteActTanh, + kTfLiteActSignBit, + kTfLiteActSigmoid, +} TfLiteFusedActivation; + +typedef struct { + // Parameters for CONV_2D version 1. + TfLitePadding padding; + int stride_width; + int stride_height; + TfLiteFusedActivation activation; + + // Parameters for CONV_2D version 2. + // Note: Version 2 supports dilation values not equal to 1. + int dilation_width_factor; + int dilation_height_factor; + + // Parameters for CONV_2D version 7 or above. + // Used to determine the default value for the quantized bias. + TfLiteType quantized_bias_type; +} TfLiteConvParams; + +typedef struct { + TfLitePadding padding; + int stride_width; + int stride_height; + int stride_depth; + int dilation_width_factor; + int dilation_height_factor; + int dilation_depth_factor; + TfLiteFusedActivation activation; +} TfLiteConv3DParams; + +typedef TfLiteConv3DParams TfLiteConv3DTransposeParams; + +typedef struct { + TfLitePadding padding; + int stride_width; + int stride_height; + int filter_width; + int filter_height; + TfLiteFusedActivation activation; + struct { + TfLitePaddingValues padding; + } computed; +} TfLitePoolParams; + +typedef struct { + // Parameters for DepthwiseConv version 1 or above. + TfLitePadding padding; + int stride_width; + int stride_height; + // `depth_multiplier` is redundant. It's used by CPU kernels in + // TensorFlow 2.0 or below, but ignored in versions above. + // + // The information can be deduced from the shape of input and the shape of + // weights. Since the TFLiteConverter toolchain doesn't support partially + // specified shapes, relying on `depth_multiplier` stops us from supporting + // graphs with dynamic shape tensors. + // + // Note: Some of the delegates (e.g. NNAPI, GPU) are still relying on this + // field. + int depth_multiplier; + TfLiteFusedActivation activation; + // Parameters for DepthwiseConv version 2 or above. + int dilation_width_factor; + int dilation_height_factor; +} TfLiteDepthwiseConvParams; + +typedef struct { + int rank; + TfLiteFusedActivation activation; + + // Parameter for SVDF version 4. + bool asymmetric_quantize_inputs; +} TfLiteSVDFParams; + +typedef struct { + TfLiteFusedActivation activation; + + // Parameter for RNN version 3. + bool asymmetric_quantize_inputs; +} TfLiteRNNParams; + +typedef struct { + bool time_major; + TfLiteFusedActivation activation; + + // Parameter for Sequence RNN version 3. + bool asymmetric_quantize_inputs; +} TfLiteSequenceRNNParams; + +typedef struct { + bool time_major; + TfLiteFusedActivation activation; + bool merge_outputs; + + // Parameter for Bidirectional RNN version 3. + bool asymmetric_quantize_inputs; +} TfLiteBidirectionalSequenceRNNParams; + +typedef enum { + kTfLiteFullyConnectedWeightsFormatDefault = 0, + kTfLiteFullyConnectedWeightsFormatShuffled4x16Int8 = 1, +} TfLiteFullyConnectedWeightsFormat; + +typedef struct { + // Parameters for FullyConnected version 1 or above. + TfLiteFusedActivation activation; + + // Parameters for FullyConnected version 2 or above. + TfLiteFullyConnectedWeightsFormat weights_format; + + // Parameters for FullyConnected version 5 or above. + // If set to true, then the number of dimensions in the input and the output + // tensors are the same. Furthermore, all but the last dimension of the input + // and output shapes will be equal. + bool keep_num_dims; + + // Parameters for FullyConnected version 7 or above. + // If set to true and the weights are quantized, then non constant inputs + // are quantized at evaluation time with asymmetric quantization. + bool asymmetric_quantize_inputs; + + // Parameters for FullyConnected version 10 or above. + // Used to determine the default value for the quantized bias. + TfLiteType quantized_bias_type; +} TfLiteFullyConnectedParams; + +typedef enum { + kTfLiteLshProjectionUnknown = 0, + kTfLiteLshProjectionSparse = 1, + kTfLiteLshProjectionDense = 2, +} TfLiteLSHProjectionType; + +typedef struct { + TfLiteLSHProjectionType type; +} TfLiteLSHProjectionParams; + +typedef struct { + float beta; +} TfLiteSoftmaxParams; + +typedef struct { + int axis; + TfLiteFusedActivation activation; +} TfLiteConcatenationParams; + +typedef struct { + TfLiteFusedActivation activation; + // Parameter added for the version 4. + bool pot_scale_int16; +} TfLiteAddParams; + +typedef struct { + EmptyStructPlaceholder placeholder; +} TfLiteSpaceToBatchNDParams; + +typedef struct { + EmptyStructPlaceholder placeholder; +} TfLiteBatchToSpaceNDParams; + +typedef struct { + bool adj_x; + bool adj_y; + // Parameters for BatchMatMul version 4 or above. + // If set to true and the weights are quantized, then non constant inputs + // are quantized at evaluation time with asymmetric quantization. + bool asymmetric_quantize_inputs; +} TfLiteBatchMatMulParams; + +typedef struct { + TfLiteFusedActivation activation; +} TfLiteMulParams; + +typedef struct { + TfLiteFusedActivation activation; + // Parameter added for the version 5. + bool pot_scale_int16; +} TfLiteSubParams; + +typedef struct { + TfLiteFusedActivation activation; +} TfLiteDivParams; + +typedef struct { + TfLiteFusedActivation activation; +} TfLiteL2NormParams; + +typedef struct { + int radius; + float bias; + float alpha; + float beta; +} TfLiteLocalResponseNormParams; + +typedef enum { + kTfLiteLSTMFullKernel = 0, + kTfLiteLSTMBasicKernel +} TfLiteLSTMKernelType; + +typedef struct { + // Parameters for LSTM version 1. + TfLiteFusedActivation activation; + float cell_clip; + float proj_clip; + + // Parameters for LSTM version 2. + // kTfLiteLSTMBasicKernel is only supported in version 2 or above. + TfLiteLSTMKernelType kernel_type; + + // Parameters for LSTM version 4. + bool asymmetric_quantize_inputs; +} TfLiteLSTMParams; + +typedef struct { + // Parameters needed for the underlying LSTM. + TfLiteFusedActivation activation; + float cell_clip; + float proj_clip; + + // If set to true then the first dimension is time, otherwise batch. + bool time_major; + + // Parameter for unidirectional sequence RNN version 3. + bool asymmetric_quantize_inputs; + + // Parameter for unidirectional sequence RNN version 4. + bool diagonal_recurrent_tensors; +} TfLiteUnidirectionalSequenceLSTMParams; + +typedef struct { + // Parameters supported by version 1: + // Parameters inherited for the LSTM kernel. + TfLiteFusedActivation activation; + float cell_clip; + float proj_clip; + + // If true, store the outputs of both directions in the first output. + bool merge_outputs; + + // Parameters supported by version 2: + // If set to true then the first dimension is time, otherwise batch. + bool time_major; + + // Parameters supported by version 3: + // If set to true, then hybrid ops use asymmetric quantization for inputs. + bool asymmetric_quantize_inputs; +} TfLiteBidirectionalSequenceLSTMParams; + +typedef struct { + bool align_corners; + // half_pixel_centers assumes pixels are of half the actual dimensions, and + // yields more accurate resizes. Corresponds to the same argument for the + // original TensorFlow op in TF2.0. + bool half_pixel_centers; +} TfLiteResizeBilinearParams; + +typedef struct { + bool align_corners; + bool half_pixel_centers; +} TfLiteResizeNearestNeighborParams; + +typedef struct { + EmptyStructPlaceholder placeholder; +} TfLitePadParams; + +typedef struct { + EmptyStructPlaceholder placeholder; +} TfLitePadV2Params; + +typedef struct { + // These fields are only used in old models for backward compatibility. + // In the current implementation, we use the 2nd input of the op as the shape, + // and these fields are unused. + int32_t shape[TFLITE_RESHAPE_PARAMS_MAX_DIMENSION_COUNT]; + int num_dimensions; +} TfLiteReshapeParams; + +typedef struct { + int ngram_size; + int max_skip_size; + bool include_all_ngrams; +} TfLiteSkipGramParams; + +typedef struct { + int block_size; +} TfLiteSpaceToDepthParams; + +typedef struct { + int block_size; +} TfLiteDepthToSpaceParams; + +typedef struct { + TfLiteType in_data_type; + TfLiteType out_data_type; +} TfLiteCastParams; + +typedef enum { + kTfLiteCombinerTypeSum = 0, + kTfLiteCombinerTypeMean = 1, + kTfLiteCombinerTypeSqrtn = 2, +} TfLiteCombinerType; + +typedef struct { + TfLiteCombinerType combiner; +} TfLiteEmbeddingLookupSparseParams; + +typedef struct { + int axis; + int batch_dims; +} TfLiteGatherParams; + +typedef struct { + EmptyStructPlaceholder placeholder; +} TfLiteTransposeParams; + +typedef struct { + bool keep_dims; +} TfLiteReducerParams; + +typedef struct { + int num_splits; +} TfLiteSplitParams; + +typedef struct { + int num_splits; +} TfLiteSplitVParams; + +typedef struct { + // TODO(ahentz): We can't have dynamic data in this struct, at least not yet. + // For now we will fix the maximum possible number of dimensions. + int32_t squeeze_dims[8]; + int num_squeeze_dims; +} TfLiteSqueezeParams; + +typedef struct { + int begin_mask; + int end_mask; + int ellipsis_mask; + int new_axis_mask; + int shrink_axis_mask; + + // Parameters supported by version 8: + // If true, then the end tensor is an offset of the begin tensor. + bool offset; +} TfLiteStridedSliceParams; + +typedef struct { + TfLiteType output_type; +} TfLiteArgMaxParams; + +typedef struct { + TfLiteType output_type; +} TfLiteArgMinParams; + +typedef struct { + // Parameters supported by version 1: + TfLitePadding padding; + int stride_width; + int stride_height; + + // Parameters supported by version 4: + TfLiteFusedActivation activation; + + // Parameters for TransposeConv version 5 or above. + // Used to determine the default value for the quantized bias. + TfLiteType quantized_bias_type; +} TfLiteTransposeConvParams; + +typedef struct { + bool validate_indices; +} TfLiteSparseToDenseParams; + +typedef struct { + TfLiteType out_type; +} TfLiteShapeParams; + +typedef struct { + EmptyStructPlaceholder placeholder; +} TfLiteRankParams; + +typedef struct { + // Parameters supported by version 1: + float min; + float max; + int num_bits; + + // Parameters supported by version 2: + bool narrow_range; +} TfLiteFakeQuantParams; + +typedef struct { + int values_count; + int axis; +} TfLitePackParams; + +typedef struct { + int axis; +} TfLiteOneHotParams; + +typedef struct { + int num; + int axis; +} TfLiteUnpackParams; + +typedef struct { + float alpha; +} TfLiteLeakyReluParams; + +typedef struct { + TfLiteType index_out_type; +} TfLiteUniqueParams; + +typedef struct { + int seq_dim; + int batch_dim; +} TfLiteReverseSequenceParams; + +typedef struct { + EmptyStructPlaceholder placeholder; +} TfLiteMatrixDiagParams; + +typedef struct { + EmptyStructPlaceholder placeholder; +} TfLiteMatrixSetDiagParams; + +typedef struct { + int then_subgraph_index; + int else_subgraph_index; +} TfLiteIfParams; + +typedef struct { + int cond_subgraph_index; + int body_subgraph_index; +} TfLiteWhileParams; + +typedef struct { + bool exclusive; + bool reverse; +} TfLiteCumsumParams; + +typedef struct { + int init_subgraph_index; +} TfLiteCallOnceParams; + +typedef struct { + int table_id; + TfLiteType key_dtype; + TfLiteType value_dtype; +} TfLiteHashtableParams; + +typedef struct { + const char* container; + const char* shared_name; +} TfLiteVarHandleParams; + +typedef struct { + int seed; + int seed2; +} TfLiteRandomParams; + +typedef struct { + int num_boundaries; + // This points to the memory stored in the model (flatbuffer), + // and is not owned. + const float* boundaries; +} TfLiteBucketizeParams; + +typedef struct { + bool approximate; +} TfLiteGeluParams; + +typedef struct { + int64_t dimension; +} TfLiteStablehloConcatenateParams; + +typedef struct { + // See the stablehlo spec for the explanation of the attributes: + // https://github.com/openxla/stablehlo/blob/main/docs/spec.md#scatter + bool indices_are_sorted; + int64_t + update_window_dims[TFLITE_STABLEHLO_SCATTER_PARAMS_MAX_DIMENSION_COUNT]; + int num_update_window_dims; + int64_t + inserted_window_dims[TFLITE_STABLEHLO_SCATTER_PARAMS_MAX_DIMENSION_COUNT]; + int num_inserted_window_dims; + int64_t scatter_dims_to_operand_dims + [TFLITE_STABLEHLO_SCATTER_PARAMS_MAX_DIMENSION_COUNT]; + int num_scatter_dims_to_operand_dims; + int64_t index_vector_dim; + bool unique_indices; + int update_computation_subgraph_index; +} TfLiteStablehloScatterParams; + +typedef enum { + kTfLiteRngAlgorithmUnknown = 0, + // An algorithm auto-selected by the system according to device type. + kTfLiteRngAlgorithmDefault, + // The Philox algorithm, as described in paper + // ['Parallel Random Numbers: As Easy as 1, 2, 3'] + // (https://www.thesalmons.org/john/random123/papers/random123sc11.pdf) + kTfLiteRngAlgorithmPhilox, + // The ThreeFry algorithm, as described in paper + // ['Parallel Random Numbers: As Easy as 1, 2, 3'] + // (https://www.thesalmons.org/john/random123/papers/random123sc11.pdf) + kTfLiteRngAlgorithmThreefry, +} TfLiteRngAlgorithm; + +typedef struct { + TfLiteRngAlgorithm algorithm; +} TfLiteStablehloRngBitGeneratorParams; + +typedef struct { + // See the stablehlo spec for the explanation of the attributes: + // https://github.com/openxla/stablehlo/blob/main/docs/spec.md#gather + int64_t offset_dims[TFLITE_STABLEHLO_GATHER_PARAMS_MAX_DIMENSION_COUNT]; + int num_offset_dims; + int64_t + collapsed_slice_dims[TFLITE_STABLEHLO_GATHER_PARAMS_MAX_DIMENSION_COUNT]; + int num_collapsed_slice_dims; + int64_t start_index_map[TFLITE_STABLEHLO_GATHER_PARAMS_MAX_DIMENSION_COUNT]; + int num_start_index_map; + int64_t index_vector_dim; + int64_t slice_sizes[TFLITE_STABLEHLO_GATHER_PARAMS_MAX_DIMENSION_COUNT]; + int num_slice_sizes; + bool indices_are_sorted; +} TfLiteStablehloGatherParams; + +typedef struct { + // See the stablehlo spec for the explanation of the attributes: + // https://github.com/openxla/stablehlo/blob/main/docs/spec.md#reduce_window + int64_t window_dimensions + [TFLITE_STABLEHLO_REDUCE_WINDOW_PARAMS_MAX_DIMENSION_COUNT]; + int64_t + window_strides[TFLITE_STABLEHLO_REDUCE_WINDOW_PARAMS_MAX_DIMENSION_COUNT]; + int64_t + base_dilations[TFLITE_STABLEHLO_REDUCE_WINDOW_PARAMS_MAX_DIMENSION_COUNT]; + int64_t window_dilations + [TFLITE_STABLEHLO_REDUCE_WINDOW_PARAMS_MAX_DIMENSION_COUNT]; + int64_t + padding[2 * TFLITE_STABLEHLO_REDUCE_WINDOW_PARAMS_MAX_DIMENSION_COUNT]; + int body_subgraph_index; +} TfLiteStablehloReduceWindowParams; + +enum TfLiteReduceWindowFunction { + TfLiteReduceWindowFunctionUnsupported, + TfLiteReduceWindowFunctionAdd, + TfLiteReduceWindowFunctionMul, + TfLiteReduceWindowFunctionMin, + TfLiteReduceWindowFunctionMax, + TfLiteReduceWindowFunctionAll, + TfLiteReduceWindowFunctionAny +}; + +typedef struct { + enum TfLiteReduceWindowFunction reduce_function; +} TfLiteReduceWindowParams; + +typedef struct { + // See the stablehlo spec for the explanation of the attributes: + // https://github.com/openxla/stablehlo/blob/main/docs/spec.md#pad + int64_t edge_padding_low[TFLITE_STABLEHLO_PAD_PARAMS_MAX_DIMENSION_COUNT]; + int64_t edge_padding_high[TFLITE_STABLEHLO_PAD_PARAMS_MAX_DIMENSION_COUNT]; + int64_t interior_padding[TFLITE_STABLEHLO_PAD_PARAMS_MAX_DIMENSION_COUNT]; +} TfLiteStablehloPadParams; + +typedef struct { + const char* name; + int32_t subgraph_index; + int32_t version; + const uint8_t* attributes; + size_t attributes_size; +} TfLiteStablehloCompositeParams; + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus #endif // TENSORFLOW_LITE_CORE_C_BUILTIN_OP_DATA_H_ diff --git a/tensorflow/lite/core/c/c_api_types.h b/tensorflow/lite/core/c/c_api_types.h index dc2601bf127169..f0b76bde0258cb 100644 --- a/tensorflow/lite/core/c/c_api_types.h +++ b/tensorflow/lite/core/c/c_api_types.h @@ -36,12 +36,12 @@ limitations under the License. #ifndef TENSORFLOW_LITE_CORE_C_C_API_TYPES_H_ #define TENSORFLOW_LITE_CORE_C_C_API_TYPES_H_ +#include + #ifdef __cplusplus extern "C" { #endif -#include "tensorflow/compiler/mlir/lite/core/c/tflite_types.h" // IWYU pragma: export - // clang-format off // NOLINTBEGIN(whitespace/line_length) /** \defgroup c_api_types lite/c/c_api_types.h @@ -112,6 +112,42 @@ typedef enum TfLiteStatus { kTfLiteCancelled = 8, } TfLiteStatus; +/// Types supported by tensor +// LINT.IfChange +typedef enum { + kTfLiteNoType = 0, + kTfLiteFloat32 = 1, + kTfLiteInt32 = 2, + kTfLiteUInt8 = 3, + kTfLiteInt64 = 4, + kTfLiteString = 5, + kTfLiteBool = 6, + kTfLiteInt16 = 7, + kTfLiteComplex64 = 8, + kTfLiteInt8 = 9, + kTfLiteFloat16 = 10, + kTfLiteFloat64 = 11, + kTfLiteComplex128 = 12, + kTfLiteUInt64 = 13, + kTfLiteResource = 14, + kTfLiteVariant = 15, + kTfLiteUInt32 = 16, + kTfLiteUInt16 = 17, + kTfLiteInt4 = 18, + kTfLiteBFloat16 = 19, +} TfLiteType; +// LINT.ThenChange(//tensorflow/lite/profiling/proto/model_runtime_info.proto:EdgeDataType) + +/// Legacy. Will be deprecated in favor of `TfLiteAffineQuantization`. +/// If per-layer quantization is specified this field will still be populated in +/// addition to `TfLiteAffineQuantization`. +/// Parameters for asymmetric quantization. Quantized values can be converted +/// back to float using: `real_value = scale * (quantized_value - zero_point)` +typedef struct TfLiteQuantizationParams { + float scale; + int32_t zero_point; +} TfLiteQuantizationParams; + // -------------------------------------------------------------------------- // Opaque types used by c_api.h, c_api_opaque.h and common.h. diff --git a/tensorflow/lite/core/c/common.h b/tensorflow/lite/core/c/common.h index 29cf64aa9d351c..5d3100816492ae 100644 --- a/tensorflow/lite/core/c/common.h +++ b/tensorflow/lite/core/c/common.h @@ -442,6 +442,12 @@ enum { kTfLiteNullBufferHandle = -1, }; +/// Storage format of each dimension in a sparse tensor. +typedef enum TfLiteDimensionType { + kTfLiteDimDense = 0, + kTfLiteDimSparseCSR, +} TfLiteDimensionType; + /// Metadata to encode each dimension in a sparse tensor. typedef struct TfLiteDimensionMetadata { TfLiteDimensionType format; diff --git a/tensorflow/lite/java/BUILD b/tensorflow/lite/java/BUILD index c11994f7c86c24..e4c1432e6c8dda 100644 --- a/tensorflow/lite/java/BUILD +++ b/tensorflow/lite/java/BUILD @@ -142,9 +142,7 @@ TFLITE_HEADERS = [ "//tensorflow/lite/core/c:c_api.h", "//tensorflow/lite/core/c:c_api_opaque.h", "//tensorflow/lite/core/c:c_api_types.h", - "//tensorflow/compiler/mlir/lite/core/c:tflite_types.h", "//tensorflow/lite/core/c:builtin_op_data.h", - "//tensorflow/compiler/mlir/lite/core/c:builtin_op_data.h", "//tensorflow/lite/core/c:c_api_experimental.h", "//tensorflow/lite/core/c:common.h", "//tensorflow/lite/core/c:operator.h", From f75a78e9520b6f6666ec53c957911029b244f38c Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 25 Sep 2024 10:06:26 -0700 Subject: [PATCH 262/483] Mark Tensorflow compatible with Protobuf v26+. PiperOrigin-RevId: 678745731 --- tensorflow/tools/pip_package/setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/tools/pip_package/setup.py b/tensorflow/tools/pip_package/setup.py index 1145b2fd83ca17..0d0ada7899a006 100644 --- a/tensorflow/tools/pip_package/setup.py +++ b/tensorflow/tools/pip_package/setup.py @@ -86,7 +86,7 @@ def standard_or_nightly(standard, nightly): 'packaging', # pylint:disable=line-too-long ( - 'protobuf>=3.20.3,<5.0.0dev,!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5' + 'protobuf>=3.20.3,<6.0.0dev,!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5' ), 'requests >= 2.21.0, < 3', 'setuptools', From ff8692d1f3dfe590e6120cbe150141579ab03334 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 25 Sep 2024 10:13:53 -0700 Subject: [PATCH 263/483] Reverts 54acff042a003087f24f9dc0e695d6895f6fee1d PiperOrigin-RevId: 678748794 --- third_party/xla/xla/python/pytree.cc | 12 +++--------- third_party/xla/xla/python/xla_client.py | 2 +- 2 files changed, 4 insertions(+), 10 deletions(-) diff --git a/third_party/xla/xla/python/pytree.cc b/third_party/xla/xla/python/pytree.cc index 65bfb3fe5305e4..5592a96454821d 100644 --- a/third_party/xla/xla/python/pytree.cc +++ b/third_party/xla/xla/python/pytree.cc @@ -595,15 +595,9 @@ nb::list PyTreeDef::FlattenUpTo(nb::handle xs) const { case PyTreeKind::kNone: if (!object.is_none()) { - PythonDeprecationWarning( - /*stacklevel=*/3, - "In a future release of JAX, flatten-up-to will no longer " - "consider None to be a tree-prefix of non-None values, got: " - "%s.\n\n" - "To preserve the current behavior, you can usually write:\n" - " jax.tree.map(lambda x, y: None if x is None else f(x, y), a, " - "b, is_leaf=lambda x: x is None)", - nb::cast(nb::repr(object))); + throw std::invalid_argument( + absl::StrFormat("Expected None, got %s.", + nb::cast(nb::repr(object)))); } break; diff --git a/third_party/xla/xla/python/xla_client.py b/third_party/xla/xla/python/xla_client.py index bcb7d7a96fd249..a692e90f5813c8 100644 --- a/third_party/xla/xla/python/xla_client.py +++ b/third_party/xla/xla/python/xla_client.py @@ -50,7 +50,7 @@ # Just an internal arbitrary increasing number to help with backward-compatible # changes. In JAX, reference this via jax._src.lib.xla_extension_version. -_version = 287 +_version = 288 # Version number for MLIR:Python components. mlir_api_version = 57 From 3ffbfa563ccf51cd9cfbcfac7a24046f49bfbf0a Mon Sep 17 00:00:00 2001 From: Mohammed Anany Date: Wed, 25 Sep 2024 10:13:59 -0700 Subject: [PATCH 264/483] Integrate Triton up to [698e97a7](https://github.com/openai/triton/commits/6152840d3747056c9f10375ab418903e698e97a7) PiperOrigin-RevId: 678748836 --- third_party/triton/workspace.bzl | 4 ++-- third_party/xla/third_party/triton/workspace.bzl | 4 ++-- .../service/gpu/fusions/triton/compilation_pipeline_cuda.cc | 1 + .../service/gpu/fusions/triton/compilation_pipeline_rocm.cc | 1 + 4 files changed, 6 insertions(+), 4 deletions(-) diff --git a/third_party/triton/workspace.bzl b/third_party/triton/workspace.bzl index f952345e93b979..7b8b5c5e4073f4 100644 --- a/third_party/triton/workspace.bzl +++ b/third_party/triton/workspace.bzl @@ -8,8 +8,8 @@ load("//third_party/triton:xla_extensions/series.bzl", "extensions_files_patch_l def repo(): """Imports Triton.""" - TRITON_COMMIT = "cl675928942" - TRITON_SHA256 = "4e31bfdd10d3e9c6277a47b9af9d64fc60a2cc1b81330da3cb7d01d938be1d36" + TRITON_COMMIT = "cl678220124" + TRITON_SHA256 = "d999dfce02398707993fae590d38bc5c8ca8c8bcd820717b6c777747a172c30f" tf_http_archive( name = "triton", sha256 = TRITON_SHA256, diff --git a/third_party/xla/third_party/triton/workspace.bzl b/third_party/xla/third_party/triton/workspace.bzl index f952345e93b979..7b8b5c5e4073f4 100644 --- a/third_party/xla/third_party/triton/workspace.bzl +++ b/third_party/xla/third_party/triton/workspace.bzl @@ -8,8 +8,8 @@ load("//third_party/triton:xla_extensions/series.bzl", "extensions_files_patch_l def repo(): """Imports Triton.""" - TRITON_COMMIT = "cl675928942" - TRITON_SHA256 = "4e31bfdd10d3e9c6277a47b9af9d64fc60a2cc1b81330da3cb7d01d938be1d36" + TRITON_COMMIT = "cl678220124" + TRITON_SHA256 = "d999dfce02398707993fae590d38bc5c8ca8c8bcd820717b6c777747a172c30f" tf_http_archive( name = "triton", sha256 = TRITON_SHA256, diff --git a/third_party/xla/xla/service/gpu/fusions/triton/compilation_pipeline_cuda.cc b/third_party/xla/xla/service/gpu/fusions/triton/compilation_pipeline_cuda.cc index 46a569d265bcdd..9930d06c762c0b 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/compilation_pipeline_cuda.cc +++ b/third_party/xla/xla/service/gpu/fusions/triton/compilation_pipeline_cuda.cc @@ -58,6 +58,7 @@ absl::Status CreateTritonPipeline( pm.addPass(mlir::createCSEPass()); pm.addPass(mlir::createLoopInvariantCodeMotionPass()); pm.addPass(mlir::createSymbolDCEPass()); + pm.addPass(mt::createLoopUnrollPass()); // Based on make_ttgir() in // @triton//:third_party/nvidia/backend/compiler.py diff --git a/third_party/xla/xla/service/gpu/fusions/triton/compilation_pipeline_rocm.cc b/third_party/xla/xla/service/gpu/fusions/triton/compilation_pipeline_rocm.cc index a48e65ab3a6953..19be28128edf69 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/compilation_pipeline_rocm.cc +++ b/third_party/xla/xla/service/gpu/fusions/triton/compilation_pipeline_rocm.cc @@ -70,6 +70,7 @@ absl::Status CreateTritonPipeline( pm.addPass(mlir::createCSEPass()); pm.addPass(mlir::createLoopInvariantCodeMotionPass()); pm.addPass(mlir::createSymbolDCEPass()); + pm.addPass(mt::createLoopUnrollPass()); // Based on make_ttgir() in // @triton//:third_party/amd/backend/compiler.py From d449933cf56fb08be31e20818773e3cf95300626 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 25 Sep 2024 10:26:24 -0700 Subject: [PATCH 265/483] Some refactoring to simplify code associated with error handling at the end of the auto-sharding pass. PiperOrigin-RevId: 678753820 --- .../auto_sharding/auto_sharding.cc | 134 ++++++++++-------- 1 file changed, 71 insertions(+), 63 deletions(-) diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc index 3feff04d56e1b3..41713ce3bd5314 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc @@ -3950,14 +3950,45 @@ absl::Status MoveComputationsFromModuleToModule(HloModule* from_module, AutoSharding::AutoSharding(const AutoShardingOption& option) : option_(option) {} +absl::Time DumpModuleAndRecordPassStart(const HloModule* module) { + XLA_VLOG_LINES(6, + absl::StrCat("Before auto sharding:\n", module->ToString())); + DumpHloModuleIfEnabled(*module, "before_auto_spmd_sharding"); + + // TODO(b/348372403) Explore replacing these with a runtime check, per + // go/no-ifdefs-in-xla +#if !defined(__APPLE__) + // Streamz metrics. + metrics::RecordAutoShardingInvocations(); +#endif + return absl::Now(); +} + +void RecordPassEndAndDumpModule(absl::Time start_time, + const HloModule* module) { + absl::Time end_time = absl::Now(); + absl::Duration duration = end_time - start_time; + LOG(INFO) << "Auto Sharding took " << absl::ToInt64Seconds(duration) + << " seconds"; + // TODO(b/348372403) Explore replacing these with a runtime check, per + // go/no-ifdefs-in-xla +#if !defined(__APPLE__) + metrics::RecordAutoShardingCompilationTime( + absl::ToInt64Microseconds(duration)); +#endif + + XLA_VLOG_LINES(6, absl::StrCat("After auto sharding:\n", module->ToString())); + DumpHloModuleIfEnabled(*module, "after_auto_spmd_sharding"); +} + absl::StatusOr AutoSharding::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { if (!option_.enable) { return false; } - LOG(INFO) << "Starting the auto sharding pass"; + LOG(INFO) << "Starting the auto sharding pass"; // TODO(b/332951306): Remove this check once nested tuples are supported // everywhere if (HasUnsupportedNestedTuples(*module)) { @@ -3967,15 +3998,7 @@ absl::StatusOr AutoSharding::Run( return false; } - XLA_VLOG_LINES(6, - absl::StrCat("Before auto sharding:\n", module->ToString())); - DumpHloModuleIfEnabled(*module, "before_auto_spmd_sharding"); - - absl::Time start_time = absl::Now(); -#if !defined(__APPLE__) - // Streamz metrics. - metrics::RecordAutoShardingInvocations(); -#endif + absl::Time start_time = DumpModuleAndRecordPassStart(module); TF_RETURN_IF_ERROR(module->RemoveUnusedComputations()); TF_RETURN_IF_ERROR(option_.CheckAndSetup()); @@ -4079,9 +4102,8 @@ absl::StatusOr AutoSharding::Run( } } - size_t num_meshes = mesh_shapes.size(); - std::vector> changed( - num_meshes, AutoShardingResult::kModuleUnchanged); + absl::StatusOr min_mesh_pass_result = + AutoShardingResult::kModuleUnchanged; VLOG(1) << "Original mesh shape " << spmd::ToString(option_.device_mesh_shape); @@ -4104,15 +4126,14 @@ absl::StatusOr AutoSharding::Run( absl::StatusOr pass_result = pass->RunAutoSharding(module_clone.get(), replicated_small_tensors, execution_threads, sharding_propagation_solution); - - changed[i] = pass_result; - double this_mesh_objective_value = pass->GetSolverOptimalObjectiveValue(); if (!pass_result.ok()) { VLOG(1) << "Mesh shape " << spmd::ToString(mesh_shapes[i]) << " led to the following error: " << pass_result.status().message(); continue; } + + double this_mesh_objective_value = pass->GetSolverOptimalObjectiveValue(); VLOG(1) << "Mesh shape " << spmd::ToString(mesh_shapes[i]) << " has objective value " << this_mesh_objective_value; if (this_mesh_objective_value >= 0 && @@ -4120,6 +4141,7 @@ absl::StatusOr AutoSharding::Run( min_mesh_shape_index = i; min_mesh_shape_module = std::move(module_clone); min_objective_value = this_mesh_objective_value; + min_mesh_pass_result = pass_result; } if (*pass_result != AutoShardingResult::kModuleUnchangedNoShardingPerformed) { @@ -4127,59 +4149,45 @@ absl::StatusOr AutoSharding::Run( } } - absl::StatusOr module_is_changed; if (skip_auto_sharding) { - module_is_changed = false; // The auto-sharding solver timed out. - } else { - std::string trying_to_find = - option_.try_multiple_mesh_shapes - ? "a device mesh (and the corresponding shardings)" - : "shardings"; - CHECK_GE(min_mesh_shape_index, 0) - << "The auto-sharding pass could not find " << trying_to_find - << " that works for this input. This could be the result of a low " - "memory budget (please refer to the " - "`--xla_tpu_auto_spmd_partitioning_memory_budget_ratio` flag to set " - "a higher budget). If you think you have set a reasonably large " - "memory budget, please report this as a bug."; - - if (!changed[min_mesh_shape_index].ok()) { - module_is_changed = changed[min_mesh_shape_index].status(); - } else { - solver_optimal_objective_value_ = min_objective_value; - if (changed[min_mesh_shape_index].value() == - AutoShardingResult::kModuleChangedShardingPerformed) { - VLOG(1) << "Choosing mesh shape " - << spmd::ToString(mesh_shapes[min_mesh_shape_index]) - << " which had the minimal solver objective value of " - << min_objective_value; - chosen_mesh_shape_ = mesh_shapes[min_mesh_shape_index]; - TF_RETURN_IF_ERROR(MoveComputationsFromModuleToModule( - min_mesh_shape_module.get(), module)); - module_is_changed = true; - } else { - module_is_changed = false; - } - } + RecordPassEndAndDumpModule(start_time, module); + LOG(FATAL) << "The auto-sharding solver has timed out without a solution."; } - absl::Time end_time = absl::Now(); - absl::Duration duration = end_time - start_time; - LOG(INFO) << "Auto Sharding took " << absl::ToInt64Seconds(duration) - << " seconds"; -#if !defined(__APPLE__) - metrics::RecordAutoShardingCompilationTime( - absl::ToInt64Microseconds(duration)); -#endif - - XLA_VLOG_LINES(6, absl::StrCat("After auto sharding:\n", module->ToString())); - DumpHloModuleIfEnabled(*module, "after_auto_spmd_sharding"); + std::string trying_to_find = + option_.try_multiple_mesh_shapes + ? "a device mesh (and the corresponding shardings)" + : "shardings"; + CHECK_GE(min_mesh_shape_index, 0) + << "The auto-sharding pass could not find " << trying_to_find + << " that works for this input. This could be the result of a low memory " + "budget (please refer to the " + "`--xla_tpu_auto_spmd_partitioning_memory_budget_ratio` flag to set a " + "higher budget). If you think you have set a reasonably large memory " + "budget, please report this as a bug."; + + if (!min_mesh_pass_result.ok()) { + RecordPassEndAndDumpModule(start_time, module); + return min_mesh_pass_result.status(); + } - if (skip_auto_sharding) { - LOG(FATAL) << "The auto-sharding solver has timed out without a solution."; + absl::StatusOr module_is_changed; + solver_optimal_objective_value_ = min_objective_value; + if (*min_mesh_pass_result != + AutoShardingResult::kModuleChangedShardingPerformed) { + RecordPassEndAndDumpModule(start_time, module); + return false; } - return module_is_changed; + VLOG(1) << "Choosing mesh shape " + << spmd::ToString(mesh_shapes[min_mesh_shape_index]) + << " which had the minimal solver objective value of " + << min_objective_value; + chosen_mesh_shape_ = mesh_shapes[min_mesh_shape_index]; + TF_RETURN_IF_ERROR( + MoveComputationsFromModuleToModule(min_mesh_shape_module.get(), module)); + RecordPassEndAndDumpModule(start_time, module); + return true; } } // namespace xla From 75246c169a09dd6afb2e52b62238c0b168f9aa4a Mon Sep 17 00:00:00 2001 From: Hyeontaek Lim Date: Wed, 25 Sep 2024 10:52:30 -0700 Subject: [PATCH 266/483] [IFRT] Remove `xla::ifrt::Layout` alias This removes now unused `xla::ifrt::Layout` alias. A new `xla::ifrt::Layout` tpye will be defined soon. PiperOrigin-RevId: 678764034 --- third_party/xla/xla/python/ifrt/array.h | 2 -- 1 file changed, 2 deletions(-) diff --git a/third_party/xla/xla/python/ifrt/array.h b/third_party/xla/xla/python/ifrt/array.h index e080cceeec31c0..d70c17d03b17ae 100644 --- a/third_party/xla/xla/python/ifrt/array.h +++ b/third_party/xla/xla/python/ifrt/array.h @@ -39,8 +39,6 @@ namespace ifrt { class Client; -using Layout = ::xla::PjRtLayout; - // Semantics for operations that may copy or move sharded buffers in an array. enum class ArrayCopySemantics : int { // Always creates new buffers to construct an output array. Mutation of the From 777b74ac8a2f67b97bbe1d03705e133f5709a241 Mon Sep 17 00:00:00 2001 From: Hyeontaek Lim Date: Wed, 25 Sep 2024 10:58:21 -0700 Subject: [PATCH 267/483] [PjRt-IFRT] Remove pjrt_dtype.h include from pjrt_array.h This change remove the pjrt_dtype.h include from pjrt_array.h that was kept while migrating downstream users. PiperOrigin-RevId: 678766470 --- third_party/xla/xla/python/pjrt_ifrt/pjrt_array.h | 1 - 1 file changed, 1 deletion(-) diff --git a/third_party/xla/xla/python/pjrt_ifrt/pjrt_array.h b/third_party/xla/xla/python/pjrt_ifrt/pjrt_array.h index a48156d9ffd61b..8fd648068c0c10 100644 --- a/third_party/xla/xla/python/pjrt_ifrt/pjrt_array.h +++ b/third_party/xla/xla/python/pjrt_ifrt/pjrt_array.h @@ -39,7 +39,6 @@ limitations under the License. #include "xla/python/ifrt/shape.h" #include "xla/python/ifrt/sharding.h" #include "xla/python/pjrt_ifrt/pjrt_client.h" -#include "xla/python/pjrt_ifrt/pjrt_dtype.h" // IWYU pragma: keep // TODO(hyeontaek): Remove this include once downstream users are migrated to use the new header directly. #include "xla/tsl/concurrency/ref_count.h" namespace xla { From 7cbf3d17d732bcbb11d6c29030e5efa34bc8c92b Mon Sep 17 00:00:00 2001 From: Jacques Pienaar Date: Wed, 25 Sep 2024 11:01:27 -0700 Subject: [PATCH 268/483] Remove completed experiment. This experiment completed and not being actively pursued further. PiperOrigin-RevId: 678767769 --- tensorflow/c/BUILD | 2 - tensorflow/cc/experimental/libexport/BUILD | 1 - tensorflow/cc/experimental/libtf/BUILD | 303 -------- tensorflow/cc/experimental/libtf/function.cc | 263 ------- tensorflow/cc/experimental/libtf/function.h | 54 -- tensorflow/cc/experimental/libtf/impl/BUILD | 134 ---- .../cc/experimental/libtf/impl/iostream.cc | 44 -- .../experimental/libtf/impl/iostream_test.cc | 65 -- tensorflow/cc/experimental/libtf/impl/none.cc | 28 - tensorflow/cc/experimental/libtf/impl/none.h | 57 -- .../cc/experimental/libtf/impl/none_test.cc | 42 -- .../cc/experimental/libtf/impl/scalars.h | 70 -- .../experimental/libtf/impl/scalars_test.cc | 29 - .../cc/experimental/libtf/impl/string.cc | 35 - .../cc/experimental/libtf/impl/string.h | 68 -- .../cc/experimental/libtf/impl/string_test.cc | 31 - .../cc/experimental/libtf/impl/tensor_spec.h | 52 -- .../libtf/impl/tensor_spec_test.cc | 46 -- tensorflow/cc/experimental/libtf/mlir/BUILD | 30 - .../experimental/libtf/mlir/mlir_transform.cc | 91 --- .../experimental/libtf/mlir/mlir_transform.h | 30 - tensorflow/cc/experimental/libtf/module.cc | 119 --- tensorflow/cc/experimental/libtf/module.h | 61 -- tensorflow/cc/experimental/libtf/object.cc | 29 - tensorflow/cc/experimental/libtf/object.h | 709 ------------------ .../cc/experimental/libtf/runtime/BUILD | 44 -- .../cc/experimental/libtf/runtime/core/BUILD | 24 - .../experimental/libtf/runtime/core/core.cc | 45 -- .../cc/experimental/libtf/runtime/core/core.h | 33 - .../cc/experimental/libtf/runtime/runtime.cc | 185 ----- .../cc/experimental/libtf/runtime/runtime.h | 107 --- .../experimental/libtf/tests/function_test.cc | 294 -------- .../libtf/tests/generate_testdata.py | 105 --- .../libtf/tests/mlir_transform_test.cc | 55 -- .../experimental/libtf/tests/module_test.cc | 135 ---- .../experimental/libtf/tests/object_test.cc | 184 ----- .../cc/experimental/libtf/tests/perf_test.cc | 99 --- .../experimental/libtf/tests/runtime_test.cc | 131 ---- .../experimental/libtf/tests/runtime_test.h | 44 -- .../libtf/tests/runtime_test_core.cc | 27 - .../experimental/libtf/tests/tensor_test.cc | 129 ---- .../experimental/libtf/tests/testdata/README | 2 - .../data-structure-model/saved_model.pb | Bin 10117 -> 0 bytes .../variables/variables.data-00000-of-00001 | Bin 417 -> 0 bytes .../variables/variables.index | Bin 252 -> 0 bytes .../testdata/simple-model/saved_model.pb | Bin 7696 -> 0 bytes .../variables/variables.data-00000-of-00001 | Bin 25 -> 0 bytes .../simple-model/variables/variables.index | Bin 144 -> 0 bytes .../cc/experimental/libtf/tests/value_test.cc | 114 --- .../experimental/libtf/tests/variable_test.cc | 120 --- .../cc/experimental/libtf/tests/visit_test.cc | 45 -- tensorflow/cc/experimental/libtf/value.h | 596 --------------- .../cc/experimental/libtf/value_iostream.h | 93 --- 53 files changed, 5004 deletions(-) delete mode 100644 tensorflow/cc/experimental/libtf/BUILD delete mode 100644 tensorflow/cc/experimental/libtf/function.cc delete mode 100644 tensorflow/cc/experimental/libtf/function.h delete mode 100644 tensorflow/cc/experimental/libtf/impl/BUILD delete mode 100644 tensorflow/cc/experimental/libtf/impl/iostream.cc delete mode 100644 tensorflow/cc/experimental/libtf/impl/iostream_test.cc delete mode 100644 tensorflow/cc/experimental/libtf/impl/none.cc delete mode 100644 tensorflow/cc/experimental/libtf/impl/none.h delete mode 100644 tensorflow/cc/experimental/libtf/impl/none_test.cc delete mode 100644 tensorflow/cc/experimental/libtf/impl/scalars.h delete mode 100644 tensorflow/cc/experimental/libtf/impl/scalars_test.cc delete mode 100644 tensorflow/cc/experimental/libtf/impl/string.cc delete mode 100644 tensorflow/cc/experimental/libtf/impl/string.h delete mode 100644 tensorflow/cc/experimental/libtf/impl/string_test.cc delete mode 100644 tensorflow/cc/experimental/libtf/impl/tensor_spec.h delete mode 100644 tensorflow/cc/experimental/libtf/impl/tensor_spec_test.cc delete mode 100644 tensorflow/cc/experimental/libtf/mlir/BUILD delete mode 100644 tensorflow/cc/experimental/libtf/mlir/mlir_transform.cc delete mode 100644 tensorflow/cc/experimental/libtf/mlir/mlir_transform.h delete mode 100644 tensorflow/cc/experimental/libtf/module.cc delete mode 100644 tensorflow/cc/experimental/libtf/module.h delete mode 100644 tensorflow/cc/experimental/libtf/object.cc delete mode 100644 tensorflow/cc/experimental/libtf/object.h delete mode 100644 tensorflow/cc/experimental/libtf/runtime/BUILD delete mode 100644 tensorflow/cc/experimental/libtf/runtime/core/BUILD delete mode 100644 tensorflow/cc/experimental/libtf/runtime/core/core.cc delete mode 100644 tensorflow/cc/experimental/libtf/runtime/core/core.h delete mode 100644 tensorflow/cc/experimental/libtf/runtime/runtime.cc delete mode 100644 tensorflow/cc/experimental/libtf/runtime/runtime.h delete mode 100644 tensorflow/cc/experimental/libtf/tests/function_test.cc delete mode 100644 tensorflow/cc/experimental/libtf/tests/generate_testdata.py delete mode 100644 tensorflow/cc/experimental/libtf/tests/mlir_transform_test.cc delete mode 100644 tensorflow/cc/experimental/libtf/tests/module_test.cc delete mode 100644 tensorflow/cc/experimental/libtf/tests/object_test.cc delete mode 100644 tensorflow/cc/experimental/libtf/tests/perf_test.cc delete mode 100644 tensorflow/cc/experimental/libtf/tests/runtime_test.cc delete mode 100644 tensorflow/cc/experimental/libtf/tests/runtime_test.h delete mode 100644 tensorflow/cc/experimental/libtf/tests/runtime_test_core.cc delete mode 100644 tensorflow/cc/experimental/libtf/tests/tensor_test.cc delete mode 100644 tensorflow/cc/experimental/libtf/tests/testdata/README delete mode 100644 tensorflow/cc/experimental/libtf/tests/testdata/data-structure-model/saved_model.pb delete mode 100644 tensorflow/cc/experimental/libtf/tests/testdata/data-structure-model/variables/variables.data-00000-of-00001 delete mode 100644 tensorflow/cc/experimental/libtf/tests/testdata/data-structure-model/variables/variables.index delete mode 100644 tensorflow/cc/experimental/libtf/tests/testdata/simple-model/saved_model.pb delete mode 100644 tensorflow/cc/experimental/libtf/tests/testdata/simple-model/variables/variables.data-00000-of-00001 delete mode 100644 tensorflow/cc/experimental/libtf/tests/testdata/simple-model/variables/variables.index delete mode 100644 tensorflow/cc/experimental/libtf/tests/value_test.cc delete mode 100644 tensorflow/cc/experimental/libtf/tests/variable_test.cc delete mode 100644 tensorflow/cc/experimental/libtf/tests/visit_test.cc delete mode 100644 tensorflow/cc/experimental/libtf/value.h delete mode 100644 tensorflow/cc/experimental/libtf/value_iostream.h diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD index ae24298fad8f4e..4380e1cc70101f 100644 --- a/tensorflow/c/BUILD +++ b/tensorflow/c/BUILD @@ -347,8 +347,6 @@ tf_cuda_library( ], visibility = [ "//tensorflow/c:__subpackages__", - "//tensorflow/cc/experimental/libtf:__pkg__", - "//tensorflow/cc/experimental/libtf:__subpackages__", # copybara:uncomment_begin(google-only) # "//tensorflow/cc/experimental/tf2:__pkg__", # "//tensorflow/cc/experimental/tf2:__subpackages__", diff --git a/tensorflow/cc/experimental/libexport/BUILD b/tensorflow/cc/experimental/libexport/BUILD index d206c115abea65..117bc64b436864 100644 --- a/tensorflow/cc/experimental/libexport/BUILD +++ b/tensorflow/cc/experimental/libexport/BUILD @@ -8,7 +8,6 @@ package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = [ - "//tensorflow/cc/experimental/libtf:__subpackages__", "//tensorflow/python/saved_model:__subpackages__", ], licenses = ["notice"], diff --git a/tensorflow/cc/experimental/libtf/BUILD b/tensorflow/cc/experimental/libtf/BUILD deleted file mode 100644 index def467880f96fb..00000000000000 --- a/tensorflow/cc/experimental/libtf/BUILD +++ /dev/null @@ -1,303 +0,0 @@ -#include "third_party/absl/strings/str_cat.h" -#TODO(aselle) : describe this package. - -load("//tensorflow:strict.default.bzl", "py_strict_binary") -load( - "//tensorflow:tensorflow.bzl", - "tf_cc_test", -) -load("//tensorflow:tensorflow.default.bzl", "filegroup") -load( - "//tensorflow/core/platform:rules_cc.bzl", - "cc_library", -) - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = [ - "//tensorflow/cc/experimental/libtf:__subpackages__", - ], - licenses = ["notice"], -) - -cc_library( - name = "libtf", - srcs = [ - "object.cc", - ], - hdrs = [ - "object.h", - "value.h", - "value_iostream.h", - ], - deps = [ - "//tensorflow/c/eager:abstract_tensor_handle", - "//tensorflow/c/eager:immediate_execution_tensor_handle", - "//tensorflow/cc/experimental/libtf/impl:iostream", - "//tensorflow/cc/experimental/libtf/impl:none", - "//tensorflow/cc/experimental/libtf/impl:scalars", - "//tensorflow/cc/experimental/libtf/impl:string", - "//tensorflow/cc/experimental/libtf/impl:tensor_spec", - "//tensorflow/core/platform:errors", - "//tensorflow/core/platform:intrusive_ptr", - "//tensorflow/core/platform:status", - "//tensorflow/core/platform:statusor", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings", - ], -) - -cc_library( - name = "module", - srcs = ["module.cc"], - hdrs = ["module.h"], - deps = [ - "//tensorflow/cc/experimental/libexport:load", - "//tensorflow/cc/experimental/libtf/runtime", - "//tensorflow/core/platform:errors", - "//tensorflow/core/platform:statusor", - "//tensorflow/core/protobuf:for_core_protos_cc", - ], -) - -cc_library( - name = "function", - srcs = [ - "function.cc", - ], - hdrs = [ - "function.h", - ], - deps = [ - ":libtf", - "//tensorflow/c/eager:abstract_context", - "//tensorflow/c/eager:abstract_function", - "//tensorflow/c/eager:abstract_tensor_handle", - "//tensorflow/core:framework", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core/platform:errors", - "//tensorflow/core/platform:statusor", - "@com_google_absl//absl/cleanup", - ], -) - -py_strict_binary( - name = "generate_testdata", - srcs = ["tests/generate_testdata.py"], - python_version = "PY3", - deps = [ - "//tensorflow/python/compat:v2_compat", - "//tensorflow/python/eager:def_function", - "//tensorflow/python/framework:constant_op", - "//tensorflow/python/framework:dtypes", - "//tensorflow/python/framework:tensor_spec", - "//tensorflow/python/module", - "//tensorflow/python/ops:variables", - "//tensorflow/python/saved_model", - "@absl_py//absl:app", - "@absl_py//absl/flags", - ], -) - -filegroup( - name = "testdata", - srcs = glob([ - "tests/testdata/**", - ]), -) - -tf_cc_test( - name = "libtf_object_test", - size = "medium", - srcs = ["tests/object_test.cc"], - deps = [ - ":libtf", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "//tensorflow/core/platform:statusor", - ], -) - -tf_cc_test( - name = "libtf_perf_test", - size = "medium", - srcs = ["tests/perf_test.cc"], - deps = [ - ":libtf", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - ], -) - -tf_cc_test( - name = "libtf_value_test", - size = "medium", - srcs = ["tests/value_test.cc"], - deps = [ - ":libtf", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - ], -) - -tf_cc_test( - name = "libtf_visit_test", - size = "medium", - srcs = ["tests/visit_test.cc"], - deps = [ - ":libtf", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - ], -) - -cc_library( - name = "runtime_test", - testonly = 1, - srcs = [ - "tests/runtime_test.cc", - ], - hdrs = [ - "tests/runtime_test.h", - ], - data = [":testdata"], - deps = [ - ":libtf", - "//tensorflow/c:tf_datatype", - "//tensorflow/c:tf_status_helper", - "//tensorflow/c/eager:unified_api_testutil", - "//tensorflow/cc/experimental/libtf/runtime", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "//tensorflow/core/platform:resource_loader", - "//tensorflow/core/platform:status_matchers", - "//tensorflow/core/platform:statusor", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings", - ], -) - -tf_cc_test( - name = "libtf_runtime_test", - size = "medium", - srcs = [ - "tests/runtime_test_core.cc", - ], - deps = [ - ":runtime_test", - "//tensorflow/cc/experimental/libtf/runtime", - "//tensorflow/cc/experimental/libtf/runtime/core", - ], -) - -tf_cc_test( - name = "libtf_module_test", - size = "medium", - srcs = ["tests/module_test.cc"], - data = [":testdata"], - deps = [ - ":libtf", - ":module", - "//tensorflow/cc/experimental/libexport:load", - "//tensorflow/cc/experimental/libtf/runtime/core", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "//tensorflow/core/platform:path", - "//tensorflow/core/platform:protobuf", - "//tensorflow/core/platform:resource_loader", - "//tensorflow/core/platform:status_matchers", - "//tensorflow/core/platform:statusor", - ], -) - -tf_cc_test( - name = "libtf_tensor_test", - size = "medium", - srcs = ["tests/tensor_test.cc"], - tags = ["no_oss"], # TODO(b/193268458): Need to disable TFRT. - deps = [ - ":libtf", - "//tensorflow/c:tf_status_helper", - "//tensorflow/c/eager:abstract_context", - "//tensorflow/c/eager:abstract_tensor_handle", - "//tensorflow/c/eager:c_api_experimental", - "//tensorflow/c/eager:c_api_unified_internal", - "//tensorflow/c/eager:unified_api_testutil", - "//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration", - "//tensorflow/core:framework", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "//tensorflow/core/lib/llvm_rtti", - "//tensorflow/core/platform:errors", - ], -) - -tf_cc_test( - name = "function_test", - size = "medium", - srcs = ["tests/function_test.cc"], - tags = ["no_oss"], # TODO(b/193268458): Need to disable TFRT. - deps = [ - ":function", - ":libtf", - "//tensorflow/c:tf_status_helper", - "//tensorflow/c/eager:abstract_context", - "//tensorflow/c/eager:abstract_function", - "//tensorflow/c/eager:abstract_tensor_handle", - "//tensorflow/c/eager:graph_function", - "//tensorflow/c/eager:unified_api_testutil", - "//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration", - "//tensorflow/core:framework", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "//tensorflow/core/platform:errors", - "//tensorflow/core/platform:refcount", - "//tensorflow/core/platform:statusor", - ], -) - -tf_cc_test( - name = "variable_test", - size = "small", - srcs = ["tests/variable_test.cc"], - tags = ["no_oss"], # TODO(b/193268458): Need to disable TFRT. - deps = [ - ":function", - ":libtf", - "//tensorflow/c:tf_status_helper", - "//tensorflow/c/eager:abstract_context", - "//tensorflow/c/eager:abstract_function", - "//tensorflow/c/eager:abstract_tensor_handle", - "//tensorflow/c/eager:graph_function", - "//tensorflow/c/eager:unified_api_testutil", - "//tensorflow/c/experimental/ops:resource_variable_ops", - "//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration", - "//tensorflow/core:framework", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "//tensorflow/core/platform:errors", - "//tensorflow/core/platform:refcount", - "//tensorflow/core/platform:statusor", - ], -) - -tf_cc_test( - name = "libtf_transform_test", - size = "medium", - srcs = ["tests/mlir_transform_test.cc"], - data = [":testdata"], - deps = [ - ":libtf", - "//tensorflow/c/eager:c_api_experimental", - "//tensorflow/cc/experimental/libtf/mlir:transform", - "//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "//tensorflow/core/platform:path", - "//tensorflow/core/platform:resource_loader", - "//tensorflow/core/platform:statusor", - ], -) diff --git a/tensorflow/cc/experimental/libtf/function.cc b/tensorflow/cc/experimental/libtf/function.cc deleted file mode 100644 index 06b7fa15db5a6b..00000000000000 --- a/tensorflow/cc/experimental/libtf/function.cc +++ /dev/null @@ -1,263 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/cc/experimental/libtf/function.h" - -#include -#include - -#include "absl/cleanup/cleanup.h" -#include "tensorflow/c/eager/abstract_function.h" -#include "tensorflow/c/eager/abstract_tensor_handle.h" -#include "tensorflow/cc/experimental/libtf/object.h" -#include "tensorflow/cc/experimental/libtf/value.h" -#include "tensorflow/cc/experimental/libtf/value_iostream.h" -#include "tensorflow/core/framework/op_def.pb.h" -#include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/platform/errors.h" -#include "tensorflow/core/platform/statusor.h" - -namespace tf { -namespace libtf { - -using tensorflow::AbstractContext; -using tensorflow::AbstractFunctionPtr; -using tensorflow::AbstractOperationPtr; -using tensorflow::AbstractTensorHandle; -using tensorflow::Status; -using tensorflow::StatusOr; - -// TODO(srbs): Move this to unified execution API. -tensorflow::Status ExecuteFunction( - AbstractFunctionPtr trace, AbstractContext* ctx, - absl::Span inputs, - absl::Span outputs) { - // TODO(srbs): Provide a function execution API on ctx so that we do not - // expose the internals of how functions are to be executed here. - std::string fname; - { - const tensorflow::FunctionDef* fdef = nullptr; - TF_RETURN_IF_ERROR(trace->GetFunctionDef(&fdef)); - fname = fdef->signature().name(); - } - // TODO(srbs): Update RegisterFunction to accept AbstractFunctionPtr. - TF_RETURN_IF_ERROR(ctx->RegisterFunction(trace.get())); - auto cleanup = absl::MakeCleanup( - [fname, ctx]() { ctx->RemoveFunction(fname).IgnoreError(); }); - auto call_op = AbstractOperationPtr(ctx->CreateOperation()); - TF_RETURN_IF_ERROR( - call_op->Reset(fname.c_str(), /*raw_device_name=*/nullptr)); - for (auto t : inputs) { - TF_RETURN_IF_ERROR(call_op->AddInput(t)); - } - int num_outputs = outputs.size(); - return call_op->Execute(outputs, &num_outputs); -} - -Status VerifySupportedSignature(TaggedValue signature) { - if (signature.type() == TaggedValue::Type::TENSOR_SPEC) { - return absl::OkStatus(); - } - if (signature.type() == TaggedValue::Type::TUPLE) { - for (const auto& t : signature.tuple()) { - if (t.type() != TaggedValue::Type::TENSOR_SPEC) { - break; - } - } - return absl::OkStatus(); - } - return tensorflow::errors::Unimplemented( - "Only functions with inputs/outputs containing a single tensor or a tuple" - " of tensors are supported right now."); -} - -Status VerifySupportedArgs(TaggedValue args) { - if (args.type() == TaggedValue::Type::TENSOR) { - return absl::OkStatus(); - } - if (args.type() == TaggedValue::Type::TUPLE) { - for (const auto& t : args.tuple()) { - if (t.type() != TaggedValue::Type::TENSOR) { - break; - } - } - return absl::OkStatus(); - } - return tensorflow::errors::Unimplemented( - "Only functions with inputs/outputs containing a single tensor or a tuple" - " of tensors are supported right now."); -} - -Status Function::RegisterTrace(AbstractFunctionPtr fn, - TaggedValue input_signature, - TaggedValue output_signature) { - TF_RETURN_IF_ERROR(VerifySupportedSignature(input_signature)); - TF_RETURN_IF_ERROR(VerifySupportedSignature(output_signature)); - concrete_fns_.push_back({fn, input_signature, output_signature}); - return absl::OkStatus(); -} - -bool Match(TaggedValue signature, TaggedValue value) { - // TODO(b/187216309): Extend this to handle more elaborate signatures and - // values. - switch (signature.type()) { - case TaggedValue::Type::TENSOR_SPEC: { - if (value.type() != TaggedValue::Type::TENSOR) { - return false; - } - auto spec = signature.tensor_spec(); - const auto& tensor = value.tensor(); - if (tensor->DataType() != spec.dtype) { - return false; - } - tensorflow::PartialTensorShape tensor_shape; - DCHECK(tensor->Shape(&tensor_shape).ok()); - if (!tensor_shape.IsCompatibleWith(spec.shape)) { - return false; - } - } break; - case TaggedValue::Type::TUPLE: { - if (value.type() != TaggedValue::Type::TUPLE) { - return false; - } - if (value.tuple().size() != signature.tuple().size()) { - return false; - } - for (auto i = 0; i < value.tuple().size(); i++) { - if (!Match(signature.tuple()[i], value.tuple()[i])) { - return false; - } - } - } break; - default: - return false; - } - return true; -} - -// TODO(b/190203981): Move to a separate nest-like library. -void Flatten(const TaggedValue& value, - std::vector* flat_args) { - if (value.type() == TaggedValue::Type::TENSOR) { - flat_args->emplace_back(value.tensor().get()); - } else if (value.type() == TaggedValue::Type::TUPLE) { - for (const auto& t : value.tuple()) { - Flatten(t, flat_args); - } - } else { - // TODO(b/190203981): Supported arbitrary structures. - LOG(ERROR) << "Unimplemented"; - } -} - -absl::StatusOr Unflatten( - absl::Span flat_args, TaggedValue structure) { - if (structure.type() == TaggedValue::Type::TENSOR_SPEC) { - if (flat_args.size() != 1) { - // Denotes a corrupted SavedModel in which output_signature does not match - // FunctionDef outputs. - return tensorflow::errors::Internal("Expected single tensor but found ", - flat_args.size()); - } - TaggedValue wrapped_t = - TaggedValue(impl::TaggedValueTensor(flat_args[0], /*add_ref=*/true)); - if (!Match(structure, wrapped_t)) { - // Denotes a corrupted SavedModel in which output_signature does not match - // FunctionDef outputs. - std::stringstream stream; - stream << "Shape and dtype of tensor " << wrapped_t - << " does not match that in signature " << structure; - return tensorflow::errors::Internal(stream.str()); - } - return wrapped_t; - } else if (structure.type() == TaggedValue::Type::TUPLE) { - // TODO(b/190203981): Remove this check when handling nested structures - // inside tuples. - if (flat_args.size() != structure.tuple().size()) { - return tensorflow::errors::InvalidArgument( - "Tuple length ", structure.tuple().size(), - " does not match length of flat args ", flat_args.size()); - } - auto result = impl::TaggedValue::Tuple(); - for (auto i = 0; i < structure.tuple().size(); i++) { - TF_ASSIGN_OR_RETURN(TaggedValue ele, - Unflatten({flat_args[i]}, structure.tuple()[i])); - result.tuple().emplace_back(std::move(ele)); - } - return result; - } else { - // TODO(b/190203981): Support arbitrary structures. - return tensorflow::errors::Unimplemented( - "Only tensors and tuples of tensors are supported right now."); - } -} - -size_t GetFlatSize(const TaggedValue& value) { - if (value.type() == TaggedValue::Type::TUPLE) { - size_t result = 0; - for (const auto& t : value.tuple()) { - result += GetFlatSize(t); - } - return result; - } else if (value.type() == TaggedValue::Type::LIST) { - size_t result = 0; - for (const auto& t : value.list()) { - result += GetFlatSize(t); - } - return result; - } else if (value.type() == TaggedValue::Type::DICT) { - size_t result = 0; - for (const auto& t : value.dict()) { - result += GetFlatSize(t.second); - } - return result; - } - return 1; -} - -absl::StatusOr Function::Execute(AbstractContext* ctx, - TaggedValue value) const { - TF_RETURN_IF_ERROR(VerifySupportedArgs(value)); - TF_ASSIGN_OR_RETURN(auto concrete_fn, GetConcreteFunction(value)); - std::vector args; - Flatten(value, &args); - std::vector outs( - GetFlatSize(concrete_fn.output_signature)); - TF_RETURN_IF_ERROR( - ExecuteFunction(concrete_fn.trace, ctx, args, absl::MakeSpan(outs))); - auto cleanup_tensors = absl::MakeCleanup([outs]() { - for (auto t : outs) { - t->Unref(); - } - }); - return Unflatten(outs, concrete_fn.output_signature); -} - -absl::StatusOr Function::GetConcreteFunction( - TaggedValue value) const { - if (concrete_fns_.empty()) { - return tensorflow::errors::FailedPrecondition( - "No registered ConcreteFunctions."); - } - for (auto& spec : concrete_fns_) { - if (Match(spec.input_signature, value)) { - return spec; - } - } - return tensorflow::errors::InvalidArgument("No match found."); -} - -} // namespace libtf -} // namespace tf diff --git a/tensorflow/cc/experimental/libtf/function.h b/tensorflow/cc/experimental/libtf/function.h deleted file mode 100644 index 21232dd6fecc69..00000000000000 --- a/tensorflow/cc/experimental/libtf/function.h +++ /dev/null @@ -1,54 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef TENSORFLOW_CC_EXPERIMENTAL_LIBTF_FUNCTION_H_ -#define TENSORFLOW_CC_EXPERIMENTAL_LIBTF_FUNCTION_H_ - -#include - -#include "tensorflow/c/eager/abstract_context.h" -#include "tensorflow/c/eager/abstract_function.h" -#include "tensorflow/c/eager/abstract_tensor_handle.h" -#include "tensorflow/cc/experimental/libtf/object.h" -#include "tensorflow/core/platform/statusor.h" - -namespace tf { -namespace libtf { - -class Function { - public: - tensorflow::Status RegisterTrace(tensorflow::AbstractFunctionPtr, - TaggedValue input_signature, - TaggedValue output_signature); - - // Executes this function under the execution context. - // - // Raises an error is no matching signature is found for TaggedValue. - absl::StatusOr Execute(tensorflow::AbstractContext*, - TaggedValue) const; - - private: - struct ConcreteFunction { - tensorflow::AbstractFunctionPtr trace; - TaggedValue input_signature; - TaggedValue output_signature; - }; - absl::StatusOr GetConcreteFunction(TaggedValue) const; - std::vector concrete_fns_; -}; - -} // namespace libtf -} // namespace tf - -#endif // TENSORFLOW_CC_EXPERIMENTAL_LIBTF_FUNCTION_H_ diff --git a/tensorflow/cc/experimental/libtf/impl/BUILD b/tensorflow/cc/experimental/libtf/impl/BUILD deleted file mode 100644 index 97b06b21682daa..00000000000000 --- a/tensorflow/cc/experimental/libtf/impl/BUILD +++ /dev/null @@ -1,134 +0,0 @@ -# libtf implementation details. - -load( - "//tensorflow:tensorflow.bzl", - "tf_cc_test", -) -load( - "//tensorflow/core/platform:rules_cc.bzl", - "cc_library", -) - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = [ - "//tensorflow/cc/experimental/libtf:__subpackages__", - ], - licenses = ["notice"], -) - -cc_library( - name = "iostream", - srcs = [ - "iostream.cc", - ], - deps = [ - ":none", - ":string", - ":tensor_spec", - ], -) - -tf_cc_test( - name = "iostream_test", - size = "small", - srcs = ["iostream_test.cc"], - deps = [ - ":iostream", - ":none", - ":scalars", - ":string", - ":tensor_spec", - "//tensorflow/core:framework", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - ], -) - -cc_library( - name = "scalars", - hdrs = [ - "scalars.h", - ], -) - -tf_cc_test( - name = "scalars_test", - size = "small", - srcs = ["scalars_test.cc"], - deps = [ - ":scalars", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - ], -) - -cc_library( - name = "string", - srcs = [ - "string.cc", - ], - hdrs = [ - "string.h", - ], -) - -tf_cc_test( - name = "string_test", - size = "small", - srcs = ["string_test.cc"], - deps = [ - ":string", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - ], -) - -cc_library( - name = "none", - srcs = [ - "none.cc", - ], - hdrs = [ - "none.h", - ], -) - -tf_cc_test( - name = "none_test", - size = "small", - srcs = ["none_test.cc"], - deps = [ - ":none", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "@com_google_absl//absl/container:flat_hash_set", - ], -) - -cc_library( - name = "tensor_spec", - hdrs = [ - "tensor_spec.h", - ], - deps = [ - "//tensorflow/core:framework", - "//tensorflow/core:protos_all_cc", - ], -) - -tf_cc_test( - name = "tensor_spec_test", - size = "small", - srcs = ["tensor_spec_test.cc"], - deps = [ - ":iostream", # Necessary for absl::VerifyTypeImplementsAbslHashCorrectly. - ":tensor_spec", - "//tensorflow/core:framework", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "@com_google_absl//absl/hash:hash_testing", - ], -) diff --git a/tensorflow/cc/experimental/libtf/impl/iostream.cc b/tensorflow/cc/experimental/libtf/impl/iostream.cc deleted file mode 100644 index eee899b8704d82..00000000000000 --- a/tensorflow/cc/experimental/libtf/impl/iostream.cc +++ /dev/null @@ -1,44 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -// Specializations of ostream::operator<< for API values. These are defined here -// so that they don't need to be linked in executables that need to be kept -// small (and don't use the functionality). -#include - -#include "tensorflow/cc/experimental/libtf/impl/none.h" -#include "tensorflow/cc/experimental/libtf/impl/string.h" -#include "tensorflow/cc/experimental/libtf/impl/tensor_spec.h" - -namespace tf { -namespace libtf { -namespace impl { - -std::ostream& operator<<(std::ostream& o, const None& none) { - return o << "None"; -} - -std::ostream& operator<<(std::ostream& o, const String& str) { - return o << str.str(); -} - -std::ostream& operator<<(std::ostream& o, const TensorSpec& x) { - o << "TensorSpec(shape = " << x.shape.DebugString() << ", dtype = " << x.dtype - << ")"; - return o; -} - -} // namespace impl -} // namespace libtf -} // namespace tf diff --git a/tensorflow/cc/experimental/libtf/impl/iostream_test.cc b/tensorflow/cc/experimental/libtf/impl/iostream_test.cc deleted file mode 100644 index dede1483d76187..00000000000000 --- a/tensorflow/cc/experimental/libtf/impl/iostream_test.cc +++ /dev/null @@ -1,65 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "tensorflow/cc/experimental/libtf/impl/none.h" -#include "tensorflow/cc/experimental/libtf/impl/scalars.h" -#include "tensorflow/cc/experimental/libtf/impl/string.h" -#include "tensorflow/cc/experimental/libtf/impl/tensor_spec.h" -#include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/framework/types.pb.h" -#include "tensorflow/core/platform/test.h" - -namespace tf { -namespace libtf { -namespace impl { - -TEST(OStreamTest, TestInt64) { - Int64 x(42); - std::stringstream stream; - stream << x; - ASSERT_EQ(stream.str(), "42"); -} - -TEST(OStreamTest, TestFloat32) { - Float32 x(0.375); // Exactly representable as a float. - std::stringstream stream; - stream << x; - ASSERT_EQ(stream.str(), "0.375"); -} - -TEST(OStreamTest, TestString) { - String s("foo"); - std::stringstream stream; - stream << s; - ASSERT_EQ(stream.str(), "foo"); -} - -TEST(OStreamTest, TestNone) { - std::stringstream stream; - stream << None::GetInstance(); - ASSERT_EQ(stream.str(), "None"); -} - -TEST(OStreamTest, TestTensorSpec) { - std::stringstream stream; - TensorSpec tensor_spec; - tensor_spec.shape = tensorflow::PartialTensorShape({2}); - tensor_spec.dtype = tensorflow::DT_FLOAT; - stream << tensor_spec; - ASSERT_EQ(stream.str(), "TensorSpec(shape = [2], dtype = 1)"); -} - -} // namespace impl -} // namespace libtf -} // namespace tf diff --git a/tensorflow/cc/experimental/libtf/impl/none.cc b/tensorflow/cc/experimental/libtf/impl/none.cc deleted file mode 100644 index 8f16b1ed4ab760..00000000000000 --- a/tensorflow/cc/experimental/libtf/impl/none.cc +++ /dev/null @@ -1,28 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "tensorflow/cc/experimental/libtf/impl/none.h" - -namespace tf { -namespace libtf { -namespace impl { - -None& None::GetInstance() { - static None* none_inst = new None(); - return *none_inst; -} - -} // namespace impl -} // namespace libtf -} // namespace tf diff --git a/tensorflow/cc/experimental/libtf/impl/none.h b/tensorflow/cc/experimental/libtf/impl/none.h deleted file mode 100644 index 84dd654a4502b5..00000000000000 --- a/tensorflow/cc/experimental/libtf/impl/none.h +++ /dev/null @@ -1,57 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef TENSORFLOW_CC_EXPERIMENTAL_LIBTF_IMPL_NONE_H_ -#define TENSORFLOW_CC_EXPERIMENTAL_LIBTF_IMPL_NONE_H_ - -#include -#include - -namespace tf { -namespace libtf { -namespace impl { -/// @brief The Singleton `None` class. -/// -/// This class is not user-constructible. To create a `None` instance, use -/// None::GetInstance(). - -class None final { - public: - /// Retrieves the `None` instance. - /// - /// @return Returns the `None` singleton. - static None& GetInstance(); - - /// Equality operator. - bool operator==(const None& other) const { return true; } - - /// Overload AbslHashValue. - template - friend H AbslHashValue(H h, const None& n) { - return H::combine(std::move(h), 34559); - } - - private: - // Private contructor. - None() {} -}; - -// Defined in iostream.cc. -std::ostream& operator<<(std::ostream& o, const None& none); - -} // namespace impl -} // namespace libtf -} // namespace tf - -#endif // TENSORFLOW_CC_EXPERIMENTAL_LIBTF_IMPL_NONE_H_ diff --git a/tensorflow/cc/experimental/libtf/impl/none_test.cc b/tensorflow/cc/experimental/libtf/impl/none_test.cc deleted file mode 100644 index d9629e09704eb5..00000000000000 --- a/tensorflow/cc/experimental/libtf/impl/none_test.cc +++ /dev/null @@ -1,42 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/cc/experimental/libtf/impl/none.h" - -#include "absl/container/flat_hash_set.h" -#include "tensorflow/core/platform/test.h" - -namespace tf { -namespace libtf { -namespace impl { - -TEST(NoneTest, TestSingleton) { - None& a = None::GetInstance(); - None& b = None::GetInstance(); - EXPECT_EQ(&a, &b); -} - -TEST(NoneTest, TestSupportsAbslHash) { - absl::flat_hash_set none_set; - None& a = None::GetInstance(); - None& b = None::GetInstance(); - none_set.insert(a); - none_set.insert(b); - EXPECT_EQ(none_set.size(), 1); -} - -} // namespace impl -} // namespace libtf -} // namespace tf diff --git a/tensorflow/cc/experimental/libtf/impl/scalars.h b/tensorflow/cc/experimental/libtf/impl/scalars.h deleted file mode 100644 index 2345705637e585..00000000000000 --- a/tensorflow/cc/experimental/libtf/impl/scalars.h +++ /dev/null @@ -1,70 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef TENSORFLOW_CC_EXPERIMENTAL_LIBTF_IMPL_SCALARS_H_ -#define TENSORFLOW_CC_EXPERIMENTAL_LIBTF_IMPL_SCALARS_H_ - -#include - -#include -#include - -namespace tf { -namespace libtf { -namespace impl { - -/** A thin wrapper around a C++ scalar value. - * This wrapper makes the scalar immutable. - */ -template -class Scalar final { - public: - explicit Scalar(T x) : value_(x) {} - Scalar(const Scalar& o) : value_(o.value_) {} - - bool operator==(const Scalar& o) const { return o.value_ == value_; } - - T get() const { return value_; } - - /** Absl hash function. */ - template - friend H AbslHashValue(H h, const Scalar& x) { - return H::combine(std::move(h), x.value_); - } - - private: - const T value_; -}; - -template -inline std::ostream& operator<<(std::ostream& o, const Scalar& x) { - return o << x.get(); -} - -/** The overloaded addition operator. */ -template -inline auto operator+(const Scalar& x1, const Scalar& x2) - -> Scalar { - using Ret = decltype(x1 + x2); // Return type of this function. - return Ret(x1.get() + x2.get()); -} - -using Int64 = Scalar; -using Float32 = Scalar; - -} // namespace impl -} // namespace libtf -} // namespace tf - -#endif // TENSORFLOW_CC_EXPERIMENTAL_LIBTF_IMPL_SCALARS_H_ diff --git a/tensorflow/cc/experimental/libtf/impl/scalars_test.cc b/tensorflow/cc/experimental/libtf/impl/scalars_test.cc deleted file mode 100644 index 79c73f194426d5..00000000000000 --- a/tensorflow/cc/experimental/libtf/impl/scalars_test.cc +++ /dev/null @@ -1,29 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "tensorflow/cc/experimental/libtf/impl/scalars.h" - -#include "tensorflow/core/platform/test.h" - -namespace tf { -namespace libtf { -namespace impl { - -TEST(ScalarsTest, TestHeterogeneousAddition) { - ASSERT_EQ((Int64(1) + Float32(0.375)).get(), 1.375); -} - -} // namespace impl -} // namespace libtf -} // namespace tf diff --git a/tensorflow/cc/experimental/libtf/impl/string.cc b/tensorflow/cc/experimental/libtf/impl/string.cc deleted file mode 100644 index 70c716e552a08f..00000000000000 --- a/tensorflow/cc/experimental/libtf/impl/string.cc +++ /dev/null @@ -1,35 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "tensorflow/cc/experimental/libtf/impl/string.h" - -#include - -// It is important for the container below to not invalidate pointers to -// elements when elements are inserted, because the String class stores such -// pointers. This rules out, for example, absl::flat_hash_set. -using StringTable = std::unordered_set; - -namespace tf { -namespace libtf { -namespace impl { - -String::String(const char* s) { - static StringTable* table = new StringTable; - value_ = &*table->insert(s).first; -} - -} // namespace impl -} // namespace libtf -} // namespace tf diff --git a/tensorflow/cc/experimental/libtf/impl/string.h b/tensorflow/cc/experimental/libtf/impl/string.h deleted file mode 100644 index a54fb25b9775c5..00000000000000 --- a/tensorflow/cc/experimental/libtf/impl/string.h +++ /dev/null @@ -1,68 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef TENSORFLOW_CC_EXPERIMENTAL_LIBTF_IMPL_STRING_H_ -#define TENSORFLOW_CC_EXPERIMENTAL_LIBTF_IMPL_STRING_H_ - -#include -#include - -namespace tf { -namespace libtf { -namespace impl { - -/** A string value. - * This class wraps an interned, immutable string value. Currently, interned - * values are never deleted, so memory usage increases without bound as new - * strings are created. - */ -class String final { - public: - /** Interning constructor. - * Interns the given string value. - */ - explicit String(const char* s); - - String() : String("") {} - String(const String& s) : value_(s.value_) {} - - // This is the same as the default equality operator, which works because - // we're interning all strings. It is specified here so we are explicit about - // it. We're not saying "= default;" because we can't use C++20 features yet. - bool operator==(const String& other) const { return value_ == other.value_; } - - const std::string& str() const { return *value_; } - - /** Absl hash function. */ - template - friend H AbslHashValue(H h, const String& s) { - return H::combine(std::move(h), *s.value_); - } - - private: - //! The interned string value. This is never null. - const std::string* value_; -}; - -// This is defined in the `iostream.cc` file in this directory. It is not -// defined inline here because the `iosfwd` header does not provide enough -// functionality (in Windows), and we don't want to include `iostream` to avoid -// increasing the binary size. -std::ostream& operator<<(std::ostream& o, const String& str); - -} // namespace impl -} // namespace libtf -} // namespace tf - -#endif // TENSORFLOW_CC_EXPERIMENTAL_LIBTF_IMPL_STRING_H_ diff --git a/tensorflow/cc/experimental/libtf/impl/string_test.cc b/tensorflow/cc/experimental/libtf/impl/string_test.cc deleted file mode 100644 index 4cc07d07dfa095..00000000000000 --- a/tensorflow/cc/experimental/libtf/impl/string_test.cc +++ /dev/null @@ -1,31 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "tensorflow/cc/experimental/libtf/impl/string.h" - -#include "tensorflow/core/platform/test.h" - -namespace tf { -namespace libtf { -namespace impl { - -TEST(StringTest, TestBasicInterning) { - String s1("foo"); - String s2("foo"); - EXPECT_EQ(&s1.str(), &s2.str()); -} - -} // namespace impl -} // namespace libtf -} // namespace tf diff --git a/tensorflow/cc/experimental/libtf/impl/tensor_spec.h b/tensorflow/cc/experimental/libtf/impl/tensor_spec.h deleted file mode 100644 index be7c19297d8c8c..00000000000000 --- a/tensorflow/cc/experimental/libtf/impl/tensor_spec.h +++ /dev/null @@ -1,52 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef TENSORFLOW_CC_EXPERIMENTAL_LIBTF_IMPL_TENSOR_SPEC_H_ -#define TENSORFLOW_CC_EXPERIMENTAL_LIBTF_IMPL_TENSOR_SPEC_H_ - -#include - -#include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/framework/types.pb.h" - -namespace tf { -namespace libtf { -namespace impl { -/// @brief The TensorSpec struct. -/// -/// The TensorSpec describes the shape and dtype of a Tensor. - -struct TensorSpec { - tensorflow::PartialTensorShape shape; - tensorflow::DataType dtype; - - bool operator==(const TensorSpec& o) const { - return dtype == o.dtype && shape.IsIdenticalTo(o.shape); - } - - /// Overload AbslHashValue to make TensorSpec hashable. - template - friend H AbslHashValue(H h, const TensorSpec& t) { - return H::combine(std::move(h), t.shape.DebugString(), t.dtype); - } -}; - -// Defined in `iostream.cc`. -std::ostream& operator<<(std::ostream& o, const TensorSpec& x); - -} // namespace impl -} // namespace libtf -} // namespace tf - -#endif // TENSORFLOW_CC_EXPERIMENTAL_LIBTF_IMPL_TENSOR_SPEC_H_ diff --git a/tensorflow/cc/experimental/libtf/impl/tensor_spec_test.cc b/tensorflow/cc/experimental/libtf/impl/tensor_spec_test.cc deleted file mode 100644 index dc07f77c7ba9b7..00000000000000 --- a/tensorflow/cc/experimental/libtf/impl/tensor_spec_test.cc +++ /dev/null @@ -1,46 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/cc/experimental/libtf/impl/tensor_spec.h" - -#include "absl/hash/hash_testing.h" -#include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/framework/types.pb.h" -#include "tensorflow/core/platform/test.h" - -namespace tf { -namespace libtf { -namespace impl { - -TEST(TensorSpecTest, TestSupportsAbslHash) { - tensorflow::PartialTensorShape unknown_shape; - TensorSpec ts1; - ts1.shape = unknown_shape; - ts1.dtype = tensorflow::DT_FLOAT; - - TensorSpec ts2; - ts2.shape = tensorflow::PartialTensorShape({2}); - ts2.dtype = tensorflow::DT_FLOAT; - - TensorSpec ts3; - ts3.shape = tensorflow::PartialTensorShape({1, 2}); - ts3.dtype = tensorflow::DT_FLOAT; - - EXPECT_TRUE(absl::VerifyTypeImplementsAbslHashCorrectly({ts1, ts2, ts3})); -} - -} // namespace impl -} // namespace libtf -} // namespace tf diff --git a/tensorflow/cc/experimental/libtf/mlir/BUILD b/tensorflow/cc/experimental/libtf/mlir/BUILD deleted file mode 100644 index db86e4c34e8de9..00000000000000 --- a/tensorflow/cc/experimental/libtf/mlir/BUILD +++ /dev/null @@ -1,30 +0,0 @@ -# Parts of new C++ API that interface with MLIR. - -load( - "//tensorflow/core/platform:rules_cc.bzl", - "cc_library", -) - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = [ - "//tensorflow/cc/experimental/libtf:__subpackages__", - ], - licenses = ["notice"], -) - -cc_library( - name = "transform", - srcs = [ - "mlir_transform.cc", - ], - hdrs = ["mlir_transform.h"], - deps = [ - "//tensorflow/cc/experimental/libtf", - "//tensorflow/cc/saved_model:bundle_v2", - "//tensorflow/compiler/mlir/tensorflow:import_model", - "//tensorflow/compiler/mlir/tensorflow:translate_lib", - "//tensorflow/compiler/mlir/tensorflow/transforms:tensorflow_passes", - "//tensorflow/compiler/mlir/tensorflow/transforms:tf_saved_model_passes", - ], -) diff --git a/tensorflow/cc/experimental/libtf/mlir/mlir_transform.cc b/tensorflow/cc/experimental/libtf/mlir/mlir_transform.cc deleted file mode 100644 index f5bd971caec516..00000000000000 --- a/tensorflow/cc/experimental/libtf/mlir/mlir_transform.cc +++ /dev/null @@ -1,91 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "tensorflow/cc/experimental/libtf/mlir/mlir_transform.h" - -#include -#include - -#include "tensorflow/cc/experimental/libtf/object.h" -#include "tensorflow/cc/experimental/libtf/value.h" -#include "tensorflow/cc/saved_model/bundle_v2.h" -#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h" - -namespace tf { -namespace libtf { - -// TODO(b/190837282): All return None's become errors. -Handle LoadModule(Object self, String saved_model) { - // Parse arguments. - // Load SavedModel into memory. - tensorflow::SavedModelV2Bundle bundle; - tensorflow::Status status = - tensorflow::SavedModelV2Bundle::Load(saved_model.get(), &bundle); - if (!status.ok()) { - return None(); - } - // Fetch MLIR context - auto* context = self.Get(String("_context")) - ->cast(); - - // Load the saved model into MLIR TF dialect. - absl::Span exported_names(nullptr, 0); - auto module_or = - tensorflow::ConvertSavedModelToMlir(&bundle, context, exported_names); - if (!module_or.status().ok()) { - return None(); - } - - // Make a module to wrap MLIR module and allow getting strings and running - // transforms. - // auto obj = TaggedValue::Dict(); - Object obj; - obj.Set( - String("_module"), - Handle(impl::TaggedValue::Capsule(new mlir::OwningOpRef( - std::move(module_or).value())))); - - auto get_string = [](Object self) { - auto ref = self.Get(String("_module")) - ->cast*>(); - return String(tensorflow::MlirModuleToString(ref->get(), false).c_str()); - }; - obj.Set(String("ToString"), Callable(TFLIB_CALLABLE_ADAPTOR(get_string))); - - return obj; -} - -None SaveModule(Object self, Object module, String directory) { - // TODO(b/190835292): Implement save. - return None(); -} - -None Transform(Object self, Object module, List passes) { - // TODO(b/190835292): Implement save. - return None(); -} - -Object MLIR() { - Object obj; - obj.Set(String("LoadSavedModel"), - Callable(TFLIB_CALLABLE_ADAPTOR(LoadModule))); - obj.Set(String("SaveSavedModel"), - Callable(TFLIB_CALLABLE_ADAPTOR(SaveModule))); - obj.Set(String("_context"), - Handle(impl::TaggedValue::Capsule(new mlir::MLIRContext()))); - return obj; -} - -} // namespace libtf -} // namespace tf diff --git a/tensorflow/cc/experimental/libtf/mlir/mlir_transform.h b/tensorflow/cc/experimental/libtf/mlir/mlir_transform.h deleted file mode 100644 index bd5ec58c29fa56..00000000000000 --- a/tensorflow/cc/experimental/libtf/mlir/mlir_transform.h +++ /dev/null @@ -1,30 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef TENSORFLOW_CC_EXPERIMENTAL_LIBTF_MLIR_MLIR_TRANSFORM_H_ -#define TENSORFLOW_CC_EXPERIMENTAL_LIBTF_MLIR_MLIR_TRANSFORM_H_ - -#include "tensorflow/cc/experimental/libtf/object.h" - -namespace tf { -namespace libtf { - -// Returns a MLIR object with methods that can be used to load/save saved -// models, and also do transformations. -Object MLIR(); - -} // namespace libtf -} // namespace tf - -#endif // TENSORFLOW_CC_EXPERIMENTAL_LIBTF_MLIR_MLIR_TRANSFORM_H_ diff --git a/tensorflow/cc/experimental/libtf/module.cc b/tensorflow/cc/experimental/libtf/module.cc deleted file mode 100644 index b2102dc466edd6..00000000000000 --- a/tensorflow/cc/experimental/libtf/module.cc +++ /dev/null @@ -1,119 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "tensorflow/cc/experimental/libtf/module.h" - -#include - -#include "tensorflow/core/platform/errors.h" -#include "tensorflow/core/protobuf/saved_object_graph.pb.h" -namespace tf { -namespace libtf { -namespace impl { - -using tensorflow::libexport::TFPackage; -using tf::libtf::runtime::Runtime; - -// TODO(danielellis): Fill in with implementations. - -// Builds a vector of runtime representations of `SavedObject`s from a -// SavedModel. These are returned as a flat list. The full hierarchy building -// and initialization should be done in a later pass. -absl::StatusOr> BuildObjects(TFPackage& tf_package) { - std::vector objects; - const tensorflow::SavedObjectGraph object_graph = tf_package.GetObjectGraph(); - for (auto& node : object_graph.nodes()) { - if (node.kind_case() == tensorflow::SavedObject::kUserObject) { - absl::StatusOr result = BuildSavedUserObject(node); - if (result.ok()) { - objects.push_back(*result); - } else { - return result.status(); - } - } - } - return objects; -} - -absl::StatusOr BuildSavedUserObject( - tensorflow::SavedObject saved_object_proto) { - if (saved_object_proto.kind_case() != tensorflow::SavedObject::kUserObject) { - return tensorflow::errors::InvalidArgument("Not a UserObject."); - } - - std::string identifier = saved_object_proto.user_object().identifier(); - if (identifier == "trackable_list_wrapper") { - tf::libtf::List user_list; - // TODO(b/191267013): Populate with values. - return user_list; - } - if (identifier == "trackable_dict_wrapper") { - tf::libtf::Dictionary user_dict; - // TODO(b/191267013): Populate with values. - return user_dict; - } - if (identifier == "signature_map") { - tf::libtf::Dictionary signature_map; - // TODO(b/191267013): Populate with values. - return signature_map; - } - if (identifier == "_generic_user_object") { - tf::libtf::Dictionary user_object; - // TODO(b/191267013): Populate with values. - return user_object; - } - return tensorflow::errors::Unimplemented(absl::StrCat( - "UserObject with identifier '", identifier, "' not implemented.")); -} - -// Register all available concrete functions from a SavedModel into a runtime. -tensorflow::Status RegisterConcreteFunctions(Runtime runtime, - TFPackage tf_package) { - return tensorflow::errors::Unimplemented("Not implemented."); -} - -// Initialize any variables found in the SavedModel and attach them to the -// appropriate object representation in the runtime. -tensorflow::Status InitializeVariables(Runtime runtime, TFPackage tf_package, - std::vector objects) { - return tensorflow::errors::Unimplemented("Not implemented."); -} - -// Register concrete functions with their associated polymorphic functions. -tensorflow::Status SetupPolymorphicFunctions(Runtime runtime, - TFPackage tf_package, - std::vector objects) { - return tensorflow::errors::Unimplemented("Not implemented."); -} - -// Register any captures with their associated higher-level functions. -tensorflow::Status SetupFunctionCaptures(Runtime runtime, TFPackage tf_package, - std::vector objects) { - return tensorflow::errors::Unimplemented("Not implemented."); -} - -// Takes a flat list of Handles and builds them into the hierarchical -// representation defined by the SavedModel. -absl::StatusOr BuildObjectHierarchy(TFPackage tf_package, - std::vector objects) { - return tensorflow::errors::Unimplemented("Not implemented."); -} - -absl::StatusOr BuildProgram(Runtime runtime, TFPackage& tf_package) { - return tensorflow::errors::Unimplemented("Not implemented."); -} - -} // namespace impl -} // namespace libtf -} // namespace tf diff --git a/tensorflow/cc/experimental/libtf/module.h b/tensorflow/cc/experimental/libtf/module.h deleted file mode 100644 index c857f702888a82..00000000000000 --- a/tensorflow/cc/experimental/libtf/module.h +++ /dev/null @@ -1,61 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef TENSORFLOW_CC_EXPERIMENTAL_LIBTF_MODULE_H_ -#define TENSORFLOW_CC_EXPERIMENTAL_LIBTF_MODULE_H_ - -#include "tensorflow/cc/experimental/libexport/load.h" -#include "tensorflow/cc/experimental/libtf/runtime/runtime.h" -#include "tensorflow/core/platform/statusor.h" -#include "tensorflow/core/protobuf/saved_object_graph.pb.h" - -namespace tf { -namespace libtf { -namespace impl { - -// The main interface for taking a serialized saved model and getting back a -// fully-built model. -// -// Implementation steps: -// -// 1) For each function def in the SavedModel, register it with the runtime. -// 2) For each object in the object graph def, build it. -// 3) For each variable stored in the checkpoint in the SavedModel, -// restore it, and attach it to the associated variable object. -// 4) For each polymorphic function, associate it with the appropriate -// concrete function(s). -// 5) For each function with captures, bind the appropriate objects as -// captured inputs. -// 6) Take the fully-prepared objects, and build them into a hierarchy. -// 7) Return the prepared model. - -// Converts a SavedUserObject into its corresponding data structure. -// TODO(b/185579152): This method returns empty data structures currently. -absl::StatusOr BuildSavedUserObject( - tensorflow::SavedObject saved_object_proto); - -// "Build" all SavedObjects, ie convert from proto to their runtime -// representation, in the tf_package. -absl::StatusOr> BuildObjects( - tensorflow::libexport::TFPackage& tf_package); - -// Convert tf_package to a program in the runtime. -absl::StatusOr BuildProgram( - runtime::Runtime runtime, tensorflow::libexport::TFPackage& tf_package); - -} // namespace impl -} // namespace libtf -} // namespace tf - -#endif // TENSORFLOW_CC_EXPERIMENTAL_LIBTF_MODULE_H_ diff --git a/tensorflow/cc/experimental/libtf/object.cc b/tensorflow/cc/experimental/libtf/object.cc deleted file mode 100644 index a5f4882d532fce..00000000000000 --- a/tensorflow/cc/experimental/libtf/object.cc +++ /dev/null @@ -1,29 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -// Implementation of objects. -#include "tensorflow/cc/experimental/libtf/object.h" - -#include - -namespace tf { -namespace libtf { - -const String& Object::ParentKey() { - static const String* key = new String("__parent__"); - return *key; -} - -} // namespace libtf -} // namespace tf diff --git a/tensorflow/cc/experimental/libtf/object.h b/tensorflow/cc/experimental/libtf/object.h deleted file mode 100644 index bebf28b6d496d3..00000000000000 --- a/tensorflow/cc/experimental/libtf/object.h +++ /dev/null @@ -1,709 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -/// @file object.h -/// @brief Object hierarchy for the TensorFlow C++ API. All "objects" are -/// derived from the `Handle` class. Instances of `Handle` are referred to as -/// "handles". All handles have a tagged value. -/// -/// Example Usage: -/// Object runtime = GetRuntime("tfrt"); -/// Object module = runtime.Get("Import")("cool_mobilenet") -/// runtime.Get("Tensor")(Tuple(5,5,5), 3.3); -/// Object test = CreateModule("test"); -/// test.Set("cool_function", callable); -#ifndef TENSORFLOW_CC_EXPERIMENTAL_LIBTF_OBJECT_H_ -#define TENSORFLOW_CC_EXPERIMENTAL_LIBTF_OBJECT_H_ - -#include -#include - -#include "absl/status/status.h" -#include "absl/strings/str_cat.h" -#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" -#include "tensorflow/cc/experimental/libtf/value.h" -#include "tensorflow/core/platform/errors.h" -#include "tensorflow/core/platform/status.h" -#include "tensorflow/core/platform/statusor.h" - -namespace tf { -namespace libtf { - -using TaggedValue = impl::TaggedValue; -class Handle; - -// Necessary forward declare. -template -Handle Convert(T value); - -/// @brief Base Handle class that wraps TaggedValue data. All data creation and -/// manipulation should done using Handle instances. Users should not be working -/// with TaggedValues directly. - -/// The `Handle` class contains a TaggedValue in the `value_` member, which -/// contains the underlying data. An object belonging to `Foo`, a derived class -/// of `Handle`, can be referred to as a `Foo` handle. -/// -/// It is important that all derived classes do not add any new data fields. -/// This ensures that it is always safe to slice down (i.e. assign an object of -/// a derived class to the base class) a handle to the base Handle class. -class Handle { - public: - /// Default constructor, which initializes a TaggedValue with type NONE. - Handle() : value_(TaggedValue::None()) {} - - public: - /// Constructs a handle from a TaggedValue. - explicit Handle(TaggedValue value) : value_(std::move(value)) {} - // explicit Handle(TaggedValue value, Handle* class_input) - // : value_(std::move(value)), class_(class_input) {} - // const Handle& type() { return *class_; } - - protected: - /// The wrapped TaggedValue. - TaggedValue value_; - // effectively a "weak reference" to intern'd class value. - // types are compared by comparing pointer values here. - // Handle* class_; // effectively a "weak reference" to intern'd class value. - - /// The Integer handle. - friend class Integer; - /// The Float handle. - friend class Float; - /// The String handle. - friend class String; - /// The Object handle. - friend class Object; - /// The List handle. - friend class List; - /// The Dictionary handle. - friend class Dictionary; - /// The Tuple handle. - friend class Tuple; - /// The Callable handle. - friend class Callable; - /// The Tensor handle. - friend class Tensor; - /// Converts a Handle instance to an instance of a derived class `T`. - template - friend tensorflow::StatusOr Cast(Handle handle); - /// Infrastructure for converting a TaggedValue tuple function signature to an - /// unpacked variable list. - template - friend class UneraseCallHelper; -}; - -// Forward declare. -template -tensorflow::StatusOr Cast(Handle handle); - -/// @brief The None class for holding TaggedValues of type NONE. -class None final : public Handle { - public: - /// Creates a handle that wraps a NONE TaggedValue. - None() : Handle(TaggedValue::None()) {} - - private: - explicit None(TaggedValue v) : Handle(std::move(v)) {} - template - friend tensorflow::StatusOr Cast(Handle handle); -}; - -/// @brief The String class for holding TaggedValues of type STRING. -class String final : public Handle { - public: - /// Creates a handle that wraps a STRING TaggedValue. - explicit String(const char* s) : Handle(TaggedValue(s)) {} - /// Returns the underlying TaggedValue string. - const char* get() const { return value_.s(); } - - private: - // Private since it is in general unsafe. - explicit String(TaggedValue v) : Handle(std::move(v)) {} - template - friend tensorflow::StatusOr Cast(Handle handle); -}; - -/// @brief The `Object` class modeled after Python "objects". -/// -/// An `Object` uses a TaggedValue dictionary to store its attributes. The -/// "__parent__" attribute is reserved. -class Object : public Handle { - public: - /// Constructs a handle that acts as an object. - Object() : Handle(TaggedValue::Dict()) {} - /// Retrieves the key of the object's parent. - static const String& ParentKey(); - - /// @brief Gets an object member attribute`key`. - /// - /// If the `key` is not found in the object, the object's "__parent__" - /// attribute is then searched. - /// - /// @tparam T The desired return type. - /// @param key The key to look up. - /// @return `StatusOr` wrapping the key's value. - template - tensorflow::StatusOr Get(const String& key) { - auto& dict = value_.dict(); - auto it = dict.find(key.value_); - if (it != dict.end()) { - return Cast(Handle(it->second)); - } else { - // Lookup in object stored by reference in attribute "__parent__". - auto it_class = dict.find(ParentKey().value_); - if (it_class != dict.end()) { - auto& class_dict_maybe = it_class->second; - if (class_dict_maybe.type() == TaggedValue::DICT) { - auto& dict = class_dict_maybe.dict(); - auto it = dict.find(key.value_); - if (it != dict.end()) { - return Cast(Handle(it->second)); - } - } - } - } - return absl::NotFoundError("Key not in dictionary."); - } - - /// Sets `key` attribute with the underlying value of `h`. - void Set(const String& key, Handle h) { - value_.dict()[key.value_] = std::move(h.value_); - } - - /// Removes `key` from the object's attributes. - void Unset(const String& key) { value_.dict().erase(key.value_); } - // TODO(b/): Adding dir() is in the future. - private: - // Private since it is in general unsafe. - explicit Object(TaggedValue v) : Handle(std::move(v)) {} - template - friend tensorflow::StatusOr Cast(Handle handle); -}; - -/// @brief The Dictionary class for holding TaggedValues of type DICT. -class Dictionary final : public Handle { - public: - /// Constructs a handle that wraps a DICT TaggedValue. - Dictionary() : Handle(TaggedValue::Dict()) {} - // TODO(aselle): make this private to preserve invariant. - - /// Retrieves `key` with type `T`. - template - tensorflow::StatusOr Get(const Handle& key) { - auto it = value_.dict().find(key.value_); - if (it != value_.dict().end()) return Cast(Handle(it->second)); - return absl::NotFoundError("Key not in dictionary."); - } - /// Sets `key` with value `value`. - void Set(const String& key, Handle value) { - value_.dict()[key.value_] = std::move(value.value_); - } - /// Sets `key` with value `value`. - void Set(const Handle& key, Handle value) { - value_.dict()[key.value_] = std::move(value.value_); - } - /// Retrieves size of dictionary. - size_t size() const { return value_.dict().size(); } - - private: - // Private since it is in general unsafe. - explicit Dictionary(TaggedValue v) : Handle(std::move(v)) {} - template - friend tensorflow::StatusOr Cast(Handle handle); -}; - -/// @brief The Integer class for holding TaggedValues of type INT. -class Integer final : public Handle { - public: - /// Creates a handle that wraps an INT TaggedValue. - explicit Integer(Handle h) : Handle(h.value_) {} - /// Creates a handle that wraps an INT TaggedValue. - explicit Integer(int64_t i) : Handle(TaggedValue(i)) {} - /// Retrieves the underlying integer value. - int64_t get() const { return value_.i64().get(); } - - private: - // Private since it is in general unsafe. - explicit Integer(TaggedValue v) : Handle(std::move(v)) {} - template - friend tensorflow::StatusOr Cast(Handle handle); -}; - -/// @brief The Float class for holding TaggedValues of type FLOAT. -class Float final : public Handle { - public: - /// Constructs a Float handle that wraps a FLOAT TaggedValue. - explicit Float(Handle h) : Handle(h.value_) {} - /// Constructs a Float handle that wraps a FLOAT TaggedValue. - explicit Float(float i) : Handle(TaggedValue(i)) {} - /// Retrieves the underlying float value. - float get() const { return value_.f32().get(); } - - private: - // Private since it is in general unsafe. - explicit Float(TaggedValue v) : Handle(std::move(v)) {} - template - friend tensorflow::StatusOr Cast(Handle handle); -}; - -/// @brief The Tensor class for holding TaggedValues of type TENSOR. -class Tensor final : public Handle { - public: - /// Constructs a Tensor handle from a Handle that wraps a TENSOR TaggedValue. - explicit Tensor(Handle h) : Handle(h.value_) {} - - /// @brief Retrieves the value of the Tensor handle. - - /// @param data Buffer in which to copy contents of the handle. - /// @throws InvalidArgument Raises error if `data` is of invalid size. - template - tensorflow::Status GetValue(absl::Span data) const; - - private: - // Private since it is in general unsafe. - explicit Tensor(TaggedValue v) : Handle(std::move(v)) {} - template - friend tensorflow::StatusOr Cast(Handle handle); -}; - -template -tensorflow::Status Tensor::GetValue(absl::Span data) const { - tensorflow::AbstractTensorPtr t; - { - const auto abstract_t = value_.tensor().get(); - if (!tensorflow::ImmediateExecutionTensorHandle::classof(abstract_t)) { - return absl::InvalidArgumentError( - "Attempting to get value of non eager tensor."); - } - auto imm_t = - static_cast(abstract_t); - tensorflow::Status status; - t.reset(imm_t->Resolve(&status)); - if (!status.ok()) { - return status; - } - } - if (data.size() != t->NumElements()) { - return tensorflow::errors::InvalidArgument(absl::StrCat( - "Mismatched number of elements: \n", "Expected: ", data.size(), "\n", - "Actual: ", t->NumElements(), "\n")); - } - memcpy(data.data(), t->Data(), t->ByteSize()); - return absl::OkStatus(); -} - -/// @brief The Tuple class for holding TaggedValues of type TUPLE. -class Tuple : public Handle { - public: - /// Constructs a Tuple handle. - template - explicit Tuple(T... args) : Handle(TaggedValue::Tuple()) { - add(args...); - } - - /// Retrieves value at index `i`. - template - tensorflow::StatusOr Get(size_t i) { - if (i >= value_.tuple().size()) - return absl::InvalidArgumentError("Out of bounds index."); - return Cast(Handle(value_.tuple()[i])); - } - - /// Retrieves number of elements. - size_t size() const { return value_.tuple().size(); } - - private: - // Add an item to a tuple. Should only be done by special construction - // like Callables (which are a friend). - void add() {} - template - void add(T arg, T2... args) { - value_.tuple().emplace_back(Convert(arg).value_); - add(args...); - } - - // Private since it is in general unsafe. - explicit Tuple(TaggedValue v) : Handle(std::move(v)) {} - template - friend tensorflow::StatusOr Cast(Handle handle); -}; - -/// @brief The List class for holding TaggedValues of type LIST. -class List final : public Handle { - public: - /// Constructs a List handle. - template - explicit List(T... args) : Handle(TaggedValue::List()) {} - /// Retrieves value at index `i`. - template - tensorflow::StatusOr Get(size_t i) { - if (i >= size()) { - return absl::InvalidArgumentError("Out of bounds index."); - } - return Cast(Handle(value_.list()[i])); - } - - /// Sets value `h` at index `i`. - tensorflow::Status Set(size_t i, Handle h) { - if (i >= size()) { - return absl::InvalidArgumentError("Out of bounds index."); - } - value_.list()[i] = std::move(h.value_); - return absl::OkStatus(); - } - - /// Appends `arg` to list. - template - void append(T arg) { - value_.list().emplace_back(Convert(arg).value_); - } - /// Retrieves size of list. - size_t size() const { return value_.list().size(); } - - private: - // Private since it is in general unsafe. - explicit List(TaggedValue v) : Handle(std::move(v)) {} - template - friend tensorflow::StatusOr Cast(Handle handle); -}; - -/// @brief The `KeywordArg` class for storing keyword arguments as name value -/// pairs. -class KeywordArg { - public: - explicit KeywordArg(const char* s) : key_(String(s)), value_() {} - - template - KeywordArg& operator=(const T obj) { - value_ = Convert(obj); - return *this; - } - - friend class Callable; - - private: - String key_; - Handle value_; -}; - -/// @brief The Callable class for creating callables. -class Callable final : public Handle { - private: - // Collect arguments for call - void CollectArgs(Tuple& args, Dictionary& kwargs, int idx) {} - template - void CollectArgs(Tuple& args, Dictionary& kwargs, int idx, T v, - Types... vars) { - const Handle& o = Convert(v); - args.value_.tuple().emplace_back(o.value_); - CollectArgs(args, kwargs, idx + 1, vars...); - } - template - void CollectArgs(Tuple& args, Dictionary& kwargs, int idx, KeywordArg v, - Types... vars) { - kwargs.Set(v.key_, v.value_); - CollectArgs(args, kwargs, idx + 1, vars...); - } - - public: - /// @brief Calls the wrapped TaggedValue function on a variable argument - /// list. - template - tensorflow::StatusOr Call(Types... vars) { - Dictionary kwargs = Dictionary(); - Tuple args; - CollectArgs(args, kwargs, 0, vars...); - auto maybe_value = - value_.func()(std::move(args.value_), std::move(kwargs.value_)); - if (!maybe_value.ok()) { - return maybe_value.status(); - } - return Cast(Handle(maybe_value.value())); - } - - public: - // TODO(aselle): need to find a way to write test w/o this being public. - // Private since it is in general unsafe. - explicit Callable(TaggedValue v) : Handle(std::move(v)) {} - template - friend tensorflow::StatusOr Cast(Handle handle); -}; - -namespace internal { -/// @brief The Capsule class for holding pointers. -class Capsule final : public Handle { - public: - /// Statically cast the TaggedValue capsule to type `T`. - template - T cast() { - return static_cast(value_.capsule()); - } - - private: - // Private since it is in general unsafe. - explicit Capsule(TaggedValue v) : Handle(std::move(v)) {} - template - friend tensorflow::StatusOr tf::libtf::Cast(Handle handle); -}; -} // namespace internal - -/// @defgroup Util Functions for type conversion -/// -/// @brief Functions for retrieving and converting Handle types. -/// @{ - -/// Retrieves tagged type of `T` handle. -template -inline TaggedValue::Type TypeToTaggedType() {} -/// Retrieves tagged type of base class handle. -template <> -inline TaggedValue::Type TypeToTaggedType() { - return TaggedValue::Type::NONE; -} -/// Retrieves tagged type of None handle. -template <> -inline TaggedValue::Type TypeToTaggedType() { - return TaggedValue::Type::NONE; -} -/// Retrieves tagged type of String handle. -template <> -inline TaggedValue::Type TypeToTaggedType() { - return TaggedValue::Type::STRING; -} -/// Retrieves tagged type of Callable handle. -template <> -inline TaggedValue::Type TypeToTaggedType() { - return TaggedValue::Type::FUNC; -} -/// Retrieves tagged type of Integer handle. -template <> -inline TaggedValue::Type TypeToTaggedType() { - return TaggedValue::Type::INT64; -} -/// Retrieves tagged type of Float handle. -template <> -inline TaggedValue::Type TypeToTaggedType() { - return TaggedValue::Type::FLOAT32; -} -/// Retrieves tagged type of Object handle. -template <> -inline TaggedValue::Type TypeToTaggedType() { - return TaggedValue::Type::DICT; -} -/// Retrieves tagged type of Dictionary handle. -template <> -inline TaggedValue::Type TypeToTaggedType() { - return TaggedValue::Type::DICT; -} -/// Retrieves tagged type of List handle. -template <> -inline TaggedValue::Type TypeToTaggedType() { - return TaggedValue::Type::LIST; -} -/// Retrieves tagged type of Tensor handle. -template <> -inline TaggedValue::Type TypeToTaggedType() { - return TaggedValue::Type::TENSOR; -} -/// Retrieves tagged type of Capsule handle. -template <> -inline TaggedValue::Type TypeToTaggedType() { - return TaggedValue::Type::CAPSULE; -} -// TODO(unknown): fully populate - -/// @brief Casts a handle to type `T` -/// -/// @param handle The handle to cast. -/// @tparam T The target handle type. -/// @exception InvalidArgument Raises error if the underlying TaggedValue type -/// of `handle` is not equivalent to `T`. -template -tensorflow::StatusOr Cast(Handle handle) { - if (handle.value_.type() == TypeToTaggedType() || - std::is_same::value) - return T((std::move(handle.value_))); - return absl::InvalidArgumentError("Incompatible cast."); -} - -// Converters for C++ primitives like float and int to handles. Allows callable -// calls and list appends to be more idiomatic. - -/// Converts a C++ const char* to a String handle. -template <> -inline Handle Convert(const char* value) { - return String(value); -} -/// Converts a C++ int32_t to an Integer handle. -template <> -inline Handle Convert(int32_t value) { - return Integer(value); -} -/// Converts a C++ int64_t to an Integer handle. -template <> -inline Handle Convert(int64_t value) { - return Integer(value); -} -/// Converts a C++ float to an Integer handle. -template <> -inline Handle Convert(float value) { - return Float(value); -} -/// Converts a value with primitive type T to a Handle. -template -inline Handle Convert(T value) { - return Handle(std::move(value)); -} - -/// @} - -// in the future it will be possible to make additional hard typed APIs -// by generating code by introspecting objects. - -// Here's a code gen'd example -// The dynamic structure can be turned into it. -/* -class Tf : Object { - Tensor ones(Tensor shape, String dtype); - // ... -} -*/ - -// Adapter to allow users to define Callables. Use TFLIB_CALLABLE_ADAPTOR -// instead. -template -class CallableWrapper; - -// Template extracts arguments from a lambda function. This base -// class definition inherits from a another specialization in order. We use -// this top level template to extract the function pointer associated with -// the created lambda functor class. -template -class CallableWrapperUnpackArgs - : public CallableWrapperUnpackArgs { - public: - CallableWrapperUnpackArgs(TLambda fn, const char* name) - : CallableWrapperUnpackArgs(fn, name) {} -}; - -// This specialization unpacks the arguments from a normal function pointer. -template -class CallableWrapperUnpackArgs - : public CallableWrapper { - using Fn = TReturn (*)(TFuncArgs...); - - public: - CallableWrapperUnpackArgs(Fn fn, const char* name) - : CallableWrapper(fn, name) {} -}; - -// This is the second stage of extracting the arguments from lambda function. -// NOTE: CallableWrapper's first template argument is the type of the -// function or functor (not the member pointer). -template -class CallableWrapperUnpackArgs - : public CallableWrapper { - using Fn = TClass; - - public: - CallableWrapperUnpackArgs(Fn fn, const char* name) - : CallableWrapper(fn, name) {} -}; - -template -class UneraseCallHelper; - -// UneraseCallHelper::Call allows transforming all the incoming arguments -// from a TaggedValue tuple to a variadic list of args. The class template -// starts as a list of argument types and ends empty. The static member -// template starts empty and ends with the unerased types of the signature. - -// Base case (all arguments are processed, so call the function TFunc. -template -class UneraseCallHelper { - public: - template - static absl::StatusOr Call(const char* name, Fn functor_, - int argument_index, - const TaggedValue& args_in, - ArgsOut... args) { - // Call concrete type function - TReturn ret = functor_(args...); - return ret.value_; - } -}; - -// Unpack a single argument case. Each argument is then cast. -template -class UneraseCallHelper { - public: - template - static absl::StatusOr Call(const char* name, Fn fn, - int argument_index, - TaggedValue& args_in, - TArgsOut... args) { - Handle h(std::move(args_in.tuple()[argument_index])); - tensorflow::StatusOr x = Cast(std::move(h)); - if (!x.ok()) - return absl::InvalidArgumentError( - absl::StrCat(std::string("Function ") + name + " Arg " + - std::to_string(argument_index) + - " cannot be cast to desired signature type ")); - return UneraseCallHelper::Call( - name, fn, argument_index + 1, args_in, args..., *x); - } -}; - -// Template specialization that allows extracting arguments from a C function -// pointer. -template -class CallableWrapper { - private: - Fn functor_; - const char* name_; - - public: - explicit CallableWrapper(Fn fn, const char* name) - : functor_(fn), name_(name) {} - - // Entry point of the Adaptor functor. Note args, and kwargs are attempted - // to be moved. - absl::StatusOr operator()(TaggedValue args, TaggedValue kwargs) { - constexpr size_t argument_count = sizeof...(TFuncArgs); - if (argument_count != args.tuple().size()) - return absl::InvalidArgumentError( - absl::StrCat(std::string("Function ") + name_ + " expected " + - std::to_string(argument_count) + " args.")); - return UneraseCallHelper::Call(name_, functor_, - 0, args); - } -}; - -// Wrap a function that uses object handles as arguments and return types -// with one that takes TaggedValues. For example: -// Tuple Pack(Integer, Float, String); -// TaggedValue callable = TFLIB_CALLABLE_ADAPTOR(Pack); -#define TFLIB_CALLABLE_ADAPTOR(x) ::tf::libtf::CreateCallableAdaptor(x, #x) - -template -TaggedValue CreateCallableAdaptor(TF x, const char* name) { - return TaggedValue((CallableWrapperUnpackArgs(x, name))); -} - -} // namespace libtf -} // namespace tf - -#endif // TENSORFLOW_CC_EXPERIMENTAL_LIBTF_OBJECT_H_ diff --git a/tensorflow/cc/experimental/libtf/runtime/BUILD b/tensorflow/cc/experimental/libtf/runtime/BUILD deleted file mode 100644 index b20c0e6e3f903b..00000000000000 --- a/tensorflow/cc/experimental/libtf/runtime/BUILD +++ /dev/null @@ -1,44 +0,0 @@ -load( - "//tensorflow/core/platform:rules_cc.bzl", - "cc_library", -) - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = [ - "//tensorflow/cc/experimental/libtf:__subpackages__", - ], - licenses = ["notice"], -) - -cc_library( - name = "runtime", - srcs = [ - "runtime.cc", - ], - hdrs = [ - "runtime.h", - ], - deps = [ - "//tensorflow/c:tf_datatype", - "//tensorflow/c:tf_status_helper", - "//tensorflow/c:tf_status_internal", - "//tensorflow/c/eager:abstract_context", - "//tensorflow/c/eager:c_api", - "//tensorflow/c/eager:c_api_experimental", - "//tensorflow/c/eager:graph_function", - "//tensorflow/c/eager:immediate_execution_context", - "//tensorflow/c/eager:tfe_context_internal", - "//tensorflow/cc/experimental/libexport:load", - "//tensorflow/cc/experimental/libtf", - "//tensorflow/cc/experimental/libtf:function", - "//tensorflow/core:framework", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core/platform:errors", - "//tensorflow/core/platform:statusor", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - ], -) diff --git a/tensorflow/cc/experimental/libtf/runtime/core/BUILD b/tensorflow/cc/experimental/libtf/runtime/core/BUILD deleted file mode 100644 index 09106ea8cb75b4..00000000000000 --- a/tensorflow/cc/experimental/libtf/runtime/core/BUILD +++ /dev/null @@ -1,24 +0,0 @@ -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = [ - "//tensorflow/cc/experimental/libtf:__subpackages__", - ], - licenses = ["notice"], -) - -cc_library( - name = "core", - srcs = [ - "core.cc", - ], - hdrs = [ - "core.h", - ], - deps = [ - "//tensorflow/c:tf_status_internal", - "//tensorflow/c/eager:c_api", - "//tensorflow/c/eager:tfe_context_internal", - "//tensorflow/cc/experimental/libtf", - "//tensorflow/cc/experimental/libtf/runtime", - ], -) diff --git a/tensorflow/cc/experimental/libtf/runtime/core/core.cc b/tensorflow/cc/experimental/libtf/runtime/core/core.cc deleted file mode 100644 index 5d23c7aa0920da..00000000000000 --- a/tensorflow/cc/experimental/libtf/runtime/core/core.cc +++ /dev/null @@ -1,45 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "tensorflow/cc/experimental/libtf/runtime/core/core.h" - -#include "tensorflow/c/eager/c_api.h" -#include "tensorflow/c/eager/tfe_context_internal.h" -#include "tensorflow/c/tf_status.h" -#include "tensorflow/c/tf_status_internal.h" -#include "tensorflow/cc/experimental/libtf/object.h" -#include "tensorflow/cc/experimental/libtf/runtime/runtime.h" - -namespace tf { -namespace libtf { -namespace runtime { -namespace core { - -runtime::Runtime Runtime() { - TaggedValue ctx_capsule; - TFE_Context* ctx; - TFE_ContextOptions* ctx_options = TFE_NewContextOptions(); - TFE_ContextOptionsSetDevicePlacementPolicy(ctx_options, - TFE_DEVICE_PLACEMENT_WARN); - TF_Status* status = TF_NewStatus(); - ctx = TFE_NewContext(ctx_options, status); - TF_DeleteStatus(status); - TFE_DeleteContextOptions(ctx_options); - return runtime::Runtime(tensorflow::unwrap(ctx)); -} - -} // namespace core -} // namespace runtime -} // namespace libtf -} // namespace tf diff --git a/tensorflow/cc/experimental/libtf/runtime/core/core.h b/tensorflow/cc/experimental/libtf/runtime/core/core.h deleted file mode 100644 index 12ced72eccb79b..00000000000000 --- a/tensorflow/cc/experimental/libtf/runtime/core/core.h +++ /dev/null @@ -1,33 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef TENSORFLOW_CC_EXPERIMENTAL_LIBTF_RUNTIME_CORE_CORE_H_ -#define TENSORFLOW_CC_EXPERIMENTAL_LIBTF_RUNTIME_CORE_CORE_H_ - -#include "tensorflow/cc/experimental/libtf/runtime/runtime.h" - -namespace tf { -namespace libtf { -namespace runtime { -namespace core { - -// Instantiate a Core Runtime. -Runtime Runtime(); - -} // namespace core -} // namespace runtime -} // namespace libtf -} // namespace tf - -#endif // TENSORFLOW_CC_EXPERIMENTAL_LIBTF_RUNTIME_CORE_CORE_H_ diff --git a/tensorflow/cc/experimental/libtf/runtime/runtime.cc b/tensorflow/cc/experimental/libtf/runtime/runtime.cc deleted file mode 100644 index 460964be0f4f29..00000000000000 --- a/tensorflow/cc/experimental/libtf/runtime/runtime.cc +++ /dev/null @@ -1,185 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "tensorflow/cc/experimental/libtf/runtime/runtime.h" - -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "tensorflow/c/eager/abstract_context.h" -#include "tensorflow/c/eager/c_api.h" -#include "tensorflow/c/eager/c_api_experimental.h" -#include "tensorflow/c/eager/graph_function.h" -#include "tensorflow/c/eager/immediate_execution_context.h" -#include "tensorflow/c/eager/tfe_context_internal.h" -#include "tensorflow/c/tf_status_helper.h" -#include "tensorflow/c/tf_status_internal.h" -#include "tensorflow/cc/experimental/libexport/load.h" -#include "tensorflow/cc/experimental/libtf/function.h" -#include "tensorflow/cc/experimental/libtf/object.h" -#include "tensorflow/cc/experimental/libtf/value.h" -#include "tensorflow/core/framework/function.h" -#include "tensorflow/core/framework/types.pb.h" -#include "tensorflow/core/protobuf/saved_object_graph.pb.h" -#include "tensorflow/core/protobuf/struct.pb.h" -#include "tensorflow/core/protobuf/trackable_object_graph.pb.h" - -namespace tf { -namespace libtf { -namespace runtime { - -using tensorflow::AbstractContext; -using tensorflow::AbstractFunctionPtr; -using tensorflow::DataType; -using tensorflow::FunctionDef; -using tensorflow::PartialTensorShape; -using tensorflow::SavedConcreteFunction; -using tensorflow::SavedObjectGraph; -using tensorflow::Status; -using tensorflow::StructuredValue; -using tensorflow::TensorSpecProto; -using tensorflow::libexport::TFPackage; -using tensorflow::protobuf::RepeatedPtrField; -using tensorflow::tracing::graph::GraphFunction; - -TaggedValue MakeCallable(const std::string& fn_name, Function fn, - AbstractContext* ctx) { - auto CallFn = [fn_name, fn, ctx](TaggedValue args_, - TaggedValue kwargs_) -> TaggedValue { - std::cout << "Calling " << fn_name << std::endl; - tensorflow::StatusOr v = fn.Execute(ctx, args_); - return v.value(); - }; - return TaggedValue(CallFn); -} - -// Import a module from a saved model. -// -// Returns a TaggedValue::Dict. All functions found on the root of the module -// will be attached as callables to this TaggedValue. -// -// `name` specifies the full path to the saved model. -// -// `ctx` should outlive the lifetime of the module. -static tensorflow::StatusOr ImportModule(String name, - AbstractContext* ctx) { - // Load the serialized model. - tensorflow::StatusOr tf_package = TFPackage::Load(name.get()); - if (!tf_package.status().ok()) { - return tf_package.status(); - } - TaggedValue module = TaggedValue::Dict(); - - // Initialize concrete function traces. - const RepeatedPtrField function_defs = - tf_package->GetFunctionDefs(); - absl::flat_hash_map traces; - for (auto& fdef : function_defs) { - AbstractFunctionPtr trace(new GraphFunction(fdef), /*add_ref=*/false); - traces[fdef.signature().name()] = trace; - } - - // Setup polymorphic functions and wrap in Callables. - // - // For each child of the root, check what type it is. If it's a - // SavedFunction, attach that function to the module as a Callable. - const SavedObjectGraph object_graph = tf_package->GetObjectGraph(); - auto& nodes = object_graph.nodes(); - // Get a map of the concrete functions to their input / output signatures. - auto& concrete_functions = object_graph.concrete_functions(); - auto& root = nodes.at(0); - for (auto& child : root.children()) { - // The child's name describes the name of the edge that connects to the - // parent object. This name will be the name of the object stored in the - // generated module. - auto& child_node = nodes.at(child.node_id()); - auto child_name = child.local_name().c_str(); - - if (child_node.kind_case() == tensorflow::SavedObject::kFunction) { - Function tf_function; - for (const std::string& fn_name : - child_node.function().concrete_functions()) { - // Setup input signature. - // - // For now, we have to do a lot of manual digging through these and - // assume they are tensorspecs. Once TODO(b/190203981) is done, we - // should be able to pass along the `StructuredValue`s to an API in a - // much cleaner way. - // - // TODO(b/190206621): Implement API for inspecting signatures - SavedConcreteFunction saved_concrete_function = - concrete_functions.at(fn_name); - TaggedValue input_signature = TaggedValue::Tuple(); - const RepeatedPtrField& args = - saved_concrete_function.canonicalized_input_signature() - .tuple_value() - .values(0) - .tuple_value() - .values(); - for (const StructuredValue& arg : args) { - PartialTensorShape shape = arg.tensor_spec_value().shape(); - DataType dtype = arg.tensor_spec_value().dtype(); - TaggedValue tensor_spec(shape, dtype); - input_signature.tuple().emplace_back(tensor_spec); - } - - // Setup output signature. - TensorSpecProto output_tensor_spec_proto = - saved_concrete_function.output_signature().tensor_spec_value(); - PartialTensorShape output_shape = output_tensor_spec_proto.shape(); - DataType output_dtype = output_tensor_spec_proto.dtype(); - TaggedValue output_tensor_spec(output_shape, output_dtype); - - // Register the function trace. - // - // This does *not* currently register the function with the runtime. - // Instead, we're registering JIT at call time. This is likely - // something that we'll change in TODO(b/190070277). - auto& trace = traces[fn_name]; - Status status = tf_function.RegisterTrace( - std::move(trace), input_signature, output_tensor_spec); - } - TaggedValue callable = MakeCallable(child_name, tf_function, ctx); - module.dict()[TaggedValue(child_name)] = callable; - } - } - return module; -} - -// Instantiate the Runtime, creating relevant Callables for later use. -Runtime::Runtime(AbstractContext* ctx) { - TaggedValue ctx_capsule = - TaggedValue::Capsule(static_cast(ctx), [](void* p) { - auto ctx = static_cast(p); - ctx->Release(); - }); - Set(String("ctx"), Handle(ctx_capsule)); - auto Load = [](Object self, String name) -> Object { - auto ctx_capsule = self.Get(String("ctx")).value(); - auto ctx = ctx_capsule.cast(); - // TODO(b/191689645): This needs to do error handling better. - return *Cast(Handle(*ImportModule(name, ctx))); - }; - - Set(String("Load"), Callable(TFLIB_CALLABLE_ADAPTOR(Load))); -} - -tensorflow::StatusOr Runtime::Load(const String& name) { - return Get(String("Load"))->Call(*this, name); -} - -} // namespace runtime -} // namespace libtf -} // namespace tf diff --git a/tensorflow/cc/experimental/libtf/runtime/runtime.h b/tensorflow/cc/experimental/libtf/runtime/runtime.h deleted file mode 100644 index 5c3ac94fbe03c6..00000000000000 --- a/tensorflow/cc/experimental/libtf/runtime/runtime.h +++ /dev/null @@ -1,107 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef TENSORFLOW_CC_EXPERIMENTAL_LIBTF_RUNTIME_RUNTIME_H_ -#define TENSORFLOW_CC_EXPERIMENTAL_LIBTF_RUNTIME_RUNTIME_H_ - -#include - -#include "absl/strings/str_cat.h" -#include "absl/types/span.h" -#include "tensorflow/c/eager/c_api.h" -#include "tensorflow/c/eager/tfe_context_internal.h" -#include "tensorflow/c/tf_status_helper.h" -#include "tensorflow/c/tf_status_internal.h" -#include "tensorflow/cc/experimental/libtf/object.h" -#include "tensorflow/cc/experimental/libtf/value.h" -#include "tensorflow/core/platform/errors.h" -#include "tensorflow/core/platform/statusor.h" - -namespace tf { -namespace libtf { -namespace runtime { - -/// @brief A runtime object capable of loading modules and executing functions. -/// -/// It is the responsibility of the owner of the Runtime to keep it alive longer -/// than all imported modules. -class Runtime : public Object { - public: - // TODO(b/191264214): Remove need for AbstractContext - explicit Runtime(tensorflow::AbstractContext* ctx); - /// @brief Loads the module indicated by `name` and returns it. - /// - /// @param name The name of the module / file path to load - /// @return An `Object` representing the module, if successful. Otherwise, a - /// non-ok `absl::Status`. - tensorflow::StatusOr Load(const String& name); - // TODO(b/186787000): Loading a module with identically-named functions as - // a previously loaded module results in undefined behavior. This - // functionality will be supported in the future. - - // Create a host tensor and copy data into it. - // - // Raises an error if shape or dtype are incompatible with T. - // TODO(b/189458441): Update this when we decide on the representation of - // shape and dtype in this API. - // Disclaimer: This API is subject to change as we add support for creating - // device tensors b/187222691 and enable buffer re-use b/187223179. - // TODO(b/190715501): Make this available via a soft API as well. - template - tensorflow::StatusOr CreateHostTensor(absl::Span shape, - int dtype, - absl::Span data); -}; - -template -tensorflow::StatusOr Runtime::CreateHostTensor( - absl::Span shape, int dtype, absl::Span data) { - size_t num_elements = 1; - for (int dim = 0; dim < shape.size(); dim++) { - if (shape[dim] < 0) { - return tensorflow::errors::InvalidArgument(absl::StrCat( - "Shape must be fully-defined, got: shape[", dim, "] = ", shape[dim])); - } - num_elements *= shape[dim]; - } - if (data.size() != num_elements) { - return tensorflow::errors::InvalidArgument(absl::StrCat( - "Mismatched shape and data size: \n", "Shape num_elements: ", - num_elements, "\n", "Data size: ", data.size(), "\n")); - } - auto maybe_capsule = Get(String("ctx")); - if (!maybe_capsule.status().ok()) { - return maybe_capsule.status(); - } - auto capsule = maybe_capsule.value(); - auto ctx = capsule.cast(); - tensorflow::AbstractTensorPtr t( - ctx->CreateTensor(static_cast(dtype), shape)); - // TODO(srbs): This is still a weak check. Check that dtype and T are - // compatible. - if (t->ByteSize() != sizeof(T) * data.size()) { - return tensorflow::errors::InvalidArgument(absl::StrCat( - "Invalid number of bytes in data buffer\n", "Expected bytes: ", - t->ByteSize(), "\n", "Actual bytes: ", sizeof(T) * data.size())); - } - memcpy(t->Data(), data.data(), t->ByteSize()); - return Tensor(Convert(TaggedValue( - impl::TaggedValueTensor(ctx->CreateLocalHandle(t.get()), false)))); -} - -} // namespace runtime -} // namespace libtf -} // namespace tf - -#endif // TENSORFLOW_CC_EXPERIMENTAL_LIBTF_RUNTIME_RUNTIME_H_ diff --git a/tensorflow/cc/experimental/libtf/tests/function_test.cc b/tensorflow/cc/experimental/libtf/tests/function_test.cc deleted file mode 100644 index fa1f21389df969..00000000000000 --- a/tensorflow/cc/experimental/libtf/tests/function_test.cc +++ /dev/null @@ -1,294 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "tensorflow/cc/experimental/libtf/function.h" - -#include "tensorflow/c/eager/abstract_context.h" -#include "tensorflow/c/eager/abstract_function.h" -#include "tensorflow/c/eager/abstract_tensor_handle.h" -#include "tensorflow/c/eager/graph_function.h" -#include "tensorflow/c/eager/unified_api_testutil.h" -#include "tensorflow/c/tf_status_helper.h" -#include "tensorflow/cc/experimental/libtf/object.h" -#include "tensorflow/cc/experimental/libtf/value.h" -#include "tensorflow/core/framework/function.h" -#include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/platform/errors.h" -#include "tensorflow/core/platform/statusor.h" -#include "tensorflow/core/platform/test.h" - -namespace tf { -namespace libtf { -using tensorflow::AbstractContext; -using tensorflow::AbstractContextPtr; -using tensorflow::AbstractFunctionPtr; -using tensorflow::AbstractTensorHandle; -using tensorflow::DT_FLOAT; -using tensorflow::FunctionDef; -using tensorflow::FunctionDefHelper; -using tensorflow::PartialTensorShape; -using tensorflow::Status; -using tensorflow::StatusOr; -using tensorflow::TF_StatusPtr; -using tensorflow::tracing::graph::GraphFunction; - -class FunctionTest - : public ::testing::TestWithParam> { - public: - template - impl::TaggedValueTensor CreateScalarTensor(T val) { - AbstractTensorHandle* raw = nullptr; - Status s = TestScalarTensorHandle(ctx_.get(), val, &raw); - CHECK_EQ(tensorflow::errors::OK, s.code()) << s.message(); - return impl::TaggedValueTensor(raw, /*add_ref=*/false); - } - - bool UseTfrt() { return std::get<1>(GetParam()); } - - AbstractContextPtr ctx_; - - protected: - void SetUp() override { - // Set the tracing impl, GraphDef vs MLIR. - TF_StatusPtr status(TF_NewStatus()); - TF_SetTracingImplementation(std::get<0>(GetParam()), status.get()); - Status s = tensorflow::StatusFromTF_Status(status.get()); - CHECK_EQ(tensorflow::errors::OK, s.code()) << s.message(); - - // Set the runtime impl, Core RT vs TFRT. - AbstractContext* ctx_raw = nullptr; - s = BuildImmediateExecutionContext(UseTfrt(), &ctx_raw); - CHECK_EQ(tensorflow::errors::OK, s.code()) << s.message(); - ctx_.reset(ctx_raw); - } -}; - -// TODO(b/191361582): Use Abstract* APIs for building functions so that we can -// test with MLIR. -FunctionDef SquareFunc() { - return FunctionDefHelper::Define( - // Function Name - "SquareFunc", - // Args - {"x: float"}, - // Returns - {"y: float"}, - // Attr def - {}, - // Nodes - {{/*ret=*/{"y"}, - /*op=*/"Square", - /*arg=*/{"x"}, - /*attr=*/{{"T", DT_FLOAT}}, - /*dep=*/{}, - /*device=*/"", - /*name=*/"square"}}); -} - -FunctionDef AddFunc() { - return FunctionDefHelper::Define( - // Function Name - "AddFunc", - // Args - {"x: float", "y: float"}, - // Returns - {"z: float"}, - // Attr def - {}, - // Nodes - {{/*ret=*/{"z"}, - /*op=*/"Add", - /*arg=*/{"x", "y"}, - /*attr=*/{{"T", DT_FLOAT}}, - /*dep=*/{}, - /*device=*/"", - /*name=*/"add"}}); -} - -FunctionDef IdentityNFunc() { - return FunctionDefHelper::Define( - // Function Name - "IdentityNFunc", - // Args - {"x: float", "y: float"}, - // Returns - {"u: float", "v: float"}, - // Attr def - {}, - // Nodes - {{/*ret=*/{"u", "v"}, - /*op=*/"IdentityN", - /*arg=*/{"x", "y"}, - /*attr=*/{{"T", tensorflow::DataTypeSlice({DT_FLOAT, DT_FLOAT})}}, - /*dep=*/{}, - /*device=*/""}}); -} - -template -void ExpectEquals(AbstractTensorHandle* t, T expected) { - TF_Tensor* result_t; - Status s = tensorflow::GetValue(t, &result_t); - ASSERT_TRUE(s.ok()) << s.message(); - auto value = static_cast(TF_TensorData(result_t)); - EXPECT_EQ(*value, expected); - TF_DeleteTensor(result_t); -} - -// TODO(srbs): Add tests for captures. -// TODO(srbs): Add tests for polymorphism (different shapes and dtypes). -TEST_P(FunctionTest, Square) { - // Construct a scalar. - impl::TaggedValueTensor x = CreateScalarTensor(2.0f); - FunctionDef fdef = SquareFunc(); - AbstractFunctionPtr trace(new GraphFunction(fdef), /*add_ref=*/false); - Function tf_function; - PartialTensorShape unknown_shape; - TaggedValue signature(unknown_shape, DT_FLOAT); - Status s = tf_function.RegisterTrace(std::move(trace), signature, signature); - ASSERT_TRUE(s.ok()) << s.message(); - TaggedValue args(std::move(x)); - StatusOr v = tf_function.Execute(ctx_.get(), args); - ASSERT_TRUE(v.ok()) << v.status().message(); - const TaggedValue& result = v.value(); - AbstractTensorHandle* t = result.tensor().get(); - ExpectEquals(t, 4.0f); -} - -TEST_P(FunctionTest, Add) { - // Construct a scalar. - impl::TaggedValueTensor x = CreateScalarTensor(2.0f); - FunctionDef fdef = AddFunc(); - AbstractFunctionPtr trace(new GraphFunction(fdef), /*add_ref=*/false); - Function tf_function; - PartialTensorShape unknown_shape; - TaggedValue tensor_spec(unknown_shape, DT_FLOAT); - TaggedValue input_signature = TaggedValue::Tuple(); - input_signature.tuple().emplace_back(tensor_spec); - input_signature.tuple().emplace_back(tensor_spec); - Status s = - tf_function.RegisterTrace(std::move(trace), input_signature, tensor_spec); - ASSERT_TRUE(s.ok()) << s.message(); - TaggedValue args = TaggedValue::Tuple(); - args.tuple().emplace_back(TaggedValue(x)); - args.tuple().emplace_back(TaggedValue(x)); - StatusOr v = tf_function.Execute(ctx_.get(), args); - ASSERT_TRUE(v.ok()) << v.status().message(); - const TaggedValue& result = v.value(); - ExpectEquals(result.tensor().get(), 4.0f); -} - -TEST_P(FunctionTest, IdentityN) { - impl::TaggedValueTensor x = CreateScalarTensor(2.0f); - impl::TaggedValueTensor y = CreateScalarTensor(4.0f); - FunctionDef fdef = IdentityNFunc(); - AbstractFunctionPtr trace(new GraphFunction(fdef), /*add_ref=*/false); - Function tf_function; - PartialTensorShape unknown_shape; - TaggedValue tensor_spec(unknown_shape, DT_FLOAT); - TaggedValue signature = TaggedValue::Tuple(); - signature.tuple().emplace_back(tensor_spec); - signature.tuple().emplace_back(tensor_spec); - Status s = tf_function.RegisterTrace(std::move(trace), signature, signature); - ASSERT_TRUE(s.ok()) << s.message(); - TaggedValue args = TaggedValue::Tuple(); - args.tuple().emplace_back(TaggedValue(x)); - args.tuple().emplace_back(TaggedValue(y)); - StatusOr v = tf_function.Execute(ctx_.get(), args); - ASSERT_TRUE(v.ok()) << v.status().message(); - const TaggedValue& result = v.value(); - ExpectEquals(result.tuple()[0].tensor().get(), 2.0f); - ExpectEquals(result.tuple()[1].tensor().get(), 4.0f); -} - -TEST_P(FunctionTest, UnaryFuncCalledWithMultipleArgsFails) { - // Construct a scalar. - impl::TaggedValueTensor x = CreateScalarTensor(2.0f); - FunctionDef fdef = SquareFunc(); - AbstractFunctionPtr trace(new GraphFunction(fdef), /*add_ref=*/false); - Function tf_function; - PartialTensorShape unknown_shape; - TaggedValue signature(unknown_shape, DT_FLOAT); - Status s = tf_function.RegisterTrace(std::move(trace), signature, signature); - ASSERT_TRUE(s.ok()) << s.message(); - TaggedValue args = TaggedValue::Tuple(); - args.tuple().emplace_back(TaggedValue(x)); - args.tuple().emplace_back(TaggedValue(x)); - StatusOr v = tf_function.Execute(ctx_.get(), args); - ASSERT_TRUE(tensorflow::errors::IsInvalidArgument(v.status())); - ASSERT_TRUE(absl::StrContains(v.status().message(), "No match")); -} - -TEST_P(FunctionTest, IncorrectArityOfOutputSignatureFails) { - if (UseTfrt()) { - GTEST_SKIP() << "TFRT crashes if expected number of output tensors does not" - " match actual."; - } - impl::TaggedValueTensor x = CreateScalarTensor(2.0f); - impl::TaggedValueTensor y = CreateScalarTensor(4.0f); - FunctionDef fdef = IdentityNFunc(); - AbstractFunctionPtr trace(new GraphFunction(fdef), /*add_ref=*/false); - Function tf_function; - PartialTensorShape unknown_shape; - TaggedValue tensor_spec(unknown_shape, DT_FLOAT); - TaggedValue input_signature = TaggedValue::Tuple(); - input_signature.tuple().emplace_back(tensor_spec); - input_signature.tuple().emplace_back(tensor_spec); - // This is wrong! - TaggedValue output_signature(unknown_shape, DT_FLOAT); - Status s = tf_function.RegisterTrace(std::move(trace), input_signature, - output_signature); - ASSERT_TRUE(s.ok()) << s.message(); - TaggedValue args = TaggedValue::Tuple(); - args.tuple().emplace_back(TaggedValue(x)); - args.tuple().emplace_back(TaggedValue(y)); - StatusOr v = tf_function.Execute(ctx_.get(), args); - ASSERT_TRUE(tensorflow::errors::IsInvalidArgument(v.status())) << v.status(); - ASSERT_TRUE(absl::StrContains(v.status().message(), - "Expecting 2 outputs, but *num_retvals is 1")); -} - -TEST_P(FunctionTest, IncorrectDtypeInOutputSignatureFails) { - // Construct a scalar. - impl::TaggedValueTensor x = CreateScalarTensor(2.0f); - FunctionDef fdef = AddFunc(); - AbstractFunctionPtr trace(new GraphFunction(fdef), /*add_ref=*/false); - Function tf_function; - PartialTensorShape unknown_shape; - TaggedValue input_tensor_spec(unknown_shape, tensorflow::DT_FLOAT); - TaggedValue input_signature = TaggedValue::Tuple(); - input_signature.tuple().emplace_back(input_tensor_spec); - input_signature.tuple().emplace_back(input_tensor_spec); - // Incorrect type. - TaggedValue output_tensor_spec(unknown_shape, tensorflow::DT_INT64); - Status s = tf_function.RegisterTrace(std::move(trace), input_signature, - output_tensor_spec); - ASSERT_TRUE(s.ok()) << s.message(); - TaggedValue args = TaggedValue::Tuple(); - args.tuple().emplace_back(TaggedValue(x)); - args.tuple().emplace_back(TaggedValue(x)); - StatusOr v = tf_function.Execute(ctx_.get(), args); - ASSERT_TRUE(tensorflow::errors::IsInternal(v.status())) << v.status(); - ASSERT_TRUE( - absl::StrContains(v.status().message(), "Shape and dtype of tensor")); - ASSERT_TRUE(absl::StrContains(v.status().message(), - "does not match that in signature")); -} - -INSTANTIATE_TEST_SUITE_P(TF2CAPI, FunctionTest, - ::testing::Combine(::testing::Values("graphdef", - "mlir"), - ::testing::Values(false))); - -} // namespace libtf -} // namespace tf diff --git a/tensorflow/cc/experimental/libtf/tests/generate_testdata.py b/tensorflow/cc/experimental/libtf/tests/generate_testdata.py deleted file mode 100644 index 09b84399a00e2f..00000000000000 --- a/tensorflow/cc/experimental/libtf/tests/generate_testdata.py +++ /dev/null @@ -1,105 +0,0 @@ -# /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ==============================================================================*/ -r"""Generate models in testdata for use in tests. - -If this script is being run via ` run`, pass an absolute path. -Otherwise, this script will attempt to write to a non-writable directory. - -Example: - run //third_party/tensorflow/cc/experimental/libtf:generate_testdata - -- \ - --path`pwd`/third_party/tensorflow/cc/experimental/libtf/tests/testdata/ \ - --model_name=simple-model -""" -import os - -from absl import app -from absl import flags - -from tensorflow.python.compat import v2_compat -from tensorflow.python.eager import def_function -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import tensor_spec -from tensorflow.python.module import module -from tensorflow.python.ops import variables -from tensorflow.python.saved_model import saved_model - -TESTDATA_PATH = flags.DEFINE_string( - "path", None, help="Path to testdata directory.") - -MODEL_NAME = flags.DEFINE_string( - "model_name", None, help="Name of model to generate.") - - -class DataStructureModel(module.Module): - """Model used for testing data structures in the C++ API.""" - - def __init__(self): - self.arr1 = [1.] - self.const_arr = [constant_op.constant(1.)] - self.var_arr = [variables.Variable(1.), variables.Variable(2.)] - self.dict1 = {"a": 1.} - self.var_dict = {"a": variables.Variable(1.), "b": variables.Variable(2.)} - - -class SimpleModel(module.Module): - """A simple model used for exercising the C++ API.""" - - @def_function.function(input_signature=[ - tensor_spec.TensorSpec(shape=(), dtype=dtypes.float32), - ]) - def test_float(self, x): - return constant_op.constant(3.0) * x - - @def_function.function(input_signature=[ - tensor_spec.TensorSpec(shape=(), dtype=dtypes.int32), - ]) - def test_int(self, x): - return constant_op.constant(3) * x - - @def_function.function(input_signature=[ - tensor_spec.TensorSpec(shape=(), dtype=dtypes.float32), - tensor_spec.TensorSpec(shape=(), dtype=dtypes.float32), - ]) - def test_add(self, x, y): - # Test a function with multiple arguments. - return x + y - - -TEST_MODELS = { - "simple-model": SimpleModel, - "data-structure-model": DataStructureModel -} - - -def get_model(name): - if name not in TEST_MODELS: - raise ValueError("Model name '{}' not in TEST_MODELS") - return TEST_MODELS[name]() - - -def main(unused_argv): - - model = get_model(MODEL_NAME.value) - path = os.path.join(TESTDATA_PATH.value, MODEL_NAME.value) - saved_model.save(model, path) - - -if __name__ == "__main__": - v2_compat.enable_v2_behavior() - flags.mark_flag_as_required("path") - flags.mark_flag_as_required("model_name") - app.run(main) diff --git a/tensorflow/cc/experimental/libtf/tests/mlir_transform_test.cc b/tensorflow/cc/experimental/libtf/tests/mlir_transform_test.cc deleted file mode 100644 index 897b1235821e49..00000000000000 --- a/tensorflow/cc/experimental/libtf/tests/mlir_transform_test.cc +++ /dev/null @@ -1,55 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "tensorflow/cc/experimental/libtf/mlir/mlir_transform.h" - -#include -#include - -#include "tensorflow/cc/experimental/libtf/object.h" -#include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/platform/resource_loader.h" -#include "tensorflow/core/platform/statusor.h" -#include "tensorflow/core/platform/test.h" - -namespace tf { -namespace libtf { - -TEST(TransformTest, LoadSavedModel) { - Object mlir = MLIR(); - TF_ASSERT_OK_AND_ASSIGN(Callable load, - mlir.Get(String("LoadSavedModel"))); - - TF_ASSERT_OK_AND_ASSIGN( - Handle model_bad, - load.Call(mlir, String("/error/doesnotexist___31284382"))); - TF_ASSERT_OK(Cast(model_bad).status()); - - const std::string model_good_path = tensorflow::GetDataDependencyFilepath( - "tensorflow/cc/experimental/libtf/tests/testdata/simple-model"); - - TF_ASSERT_OK_AND_ASSIGN( - Object model_good, - load.Call(mlir, String(model_good_path.c_str()))); - - TF_ASSERT_OK_AND_ASSIGN(Callable to_string, - model_good.Get(String("ToString"))); - - TF_ASSERT_OK_AND_ASSIGN(String s, to_string.Call(model_good)); - - ASSERT_GT(strlen(s.get()), 0); -} - -} // namespace libtf -} // namespace tf diff --git a/tensorflow/cc/experimental/libtf/tests/module_test.cc b/tensorflow/cc/experimental/libtf/tests/module_test.cc deleted file mode 100644 index 78620846c59aee..00000000000000 --- a/tensorflow/cc/experimental/libtf/tests/module_test.cc +++ /dev/null @@ -1,135 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "tensorflow/cc/experimental/libtf/module.h" - -#include - -#include "tensorflow/cc/experimental/libtf/runtime/core/core.h" -#include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/platform/resource_loader.h" -#include "tensorflow/core/platform/status_matchers.h" -#include "tensorflow/core/platform/statusor.h" -#include "tensorflow/core/platform/test.h" -#include "tensorflow/core/protobuf/error_codes.pb.h" -#include "tensorflow/core/protobuf/saved_object_graph.pb.h" - -namespace tf { -namespace libtf { -namespace impl { - -using ::tensorflow::libexport::TFPackage; -using ::tensorflow::testing::StatusIs; -using ::tf::libtf::runtime::Runtime; - -TEST(ModuleTest, TestStubbedFunctions) { - Runtime runtime = runtime::core::Runtime(); - TFPackage tf_package; - tensorflow::StatusOr result = BuildProgram(runtime, tf_package); - ASSERT_FALSE(result.status().ok()); -} - -TEST(ModuleTest, TestBuildObjectsDataStructures) { - const std::string path = tensorflow::GetDataDependencyFilepath( - "tensorflow/cc/experimental/libtf/tests/testdata/data-structure-model"); - TF_ASSERT_OK_AND_ASSIGN(TFPackage tf_package, TFPackage::Load(path)); - - TF_ASSERT_OK_AND_ASSIGN(std::vector objects, - BuildObjects(tf_package)); - EXPECT_EQ(objects.size(), 7); - // The first node of data-structure-model is a dictionary. - TF_ASSERT_OK_AND_ASSIGN(tf::libtf::Dictionary node, - Cast(objects.front())); - - // The next three nodes of data-structure-model are lists. - for (unsigned int i = 1; i < 4; i++) { - TF_ASSERT_OK_AND_ASSIGN(tf::libtf::List node, - Cast(objects.at(i))); - } - // The last three nodes of data-structure-model are dictionaries. - for (unsigned int i = 4; i < 7; i++) { - TF_ASSERT_OK_AND_ASSIGN(tf::libtf::Dictionary node, - Cast(objects.at(i))); - } -} - -TEST(ModuleTest, TestBuildEmptyList) { - tensorflow::SavedObject saved_object_proto; - const std::string pb_txt = R"pb( - user_object { - identifier: "trackable_list_wrapper" - version { producer: 1 min_consumer: 1 } - } - )pb"; - - ASSERT_TRUE(::tensorflow::protobuf::TextFormat::ParseFromString( - pb_txt, &saved_object_proto)); - TF_ASSERT_OK_AND_ASSIGN(Handle result, - BuildSavedUserObject(saved_object_proto)); - EXPECT_EQ(Cast(result)->size(), 0); -} - -TEST(ModuleTest, TestBuildEmptyDict) { - tensorflow::SavedObject saved_object_proto; - const std::string pb_txt = R"pb( - user_object { - identifier: "trackable_dict_wrapper" - version { producer: 1 min_consumer: 1 } - } - )pb"; - - ASSERT_TRUE(::tensorflow::protobuf::TextFormat::ParseFromString( - pb_txt, &saved_object_proto)); - - TF_ASSERT_OK_AND_ASSIGN(Handle result, - BuildSavedUserObject(saved_object_proto)); - EXPECT_EQ(Cast(result)->size(), 0); -} - -TEST(ModuleTest, TestBuildSignatureMap) { - tensorflow::SavedObject saved_object_proto; - const std::string pb_txt = R"pb( - user_object { - identifier: "signature_map" - version { producer: 1 min_consumer: 1 } - } - )pb"; - - ASSERT_TRUE(::tensorflow::protobuf::TextFormat::ParseFromString( - pb_txt, &saved_object_proto)); - TF_ASSERT_OK_AND_ASSIGN(Handle result, - BuildSavedUserObject(saved_object_proto)); - EXPECT_EQ(Cast(result)->size(), 0); -} - -TEST(ModuleTest, TestUnimplementedUserObject) { - tensorflow::SavedObject saved_object_proto; - const std::string pb_txt = R"pb( - user_object { - identifier: "foo" - version { producer: 1 min_consumer: 1 } - } - )pb"; - - ASSERT_TRUE(::tensorflow::protobuf::TextFormat::ParseFromString( - pb_txt, &saved_object_proto)); - - EXPECT_THAT( - BuildSavedUserObject(saved_object_proto), - StatusIs(tensorflow::error::UNIMPLEMENTED, ::testing::HasSubstr("foo"))); -} - -} // namespace impl -} // namespace libtf -} // namespace tf diff --git a/tensorflow/cc/experimental/libtf/tests/object_test.cc b/tensorflow/cc/experimental/libtf/tests/object_test.cc deleted file mode 100644 index dd0916facf9984..00000000000000 --- a/tensorflow/cc/experimental/libtf/tests/object_test.cc +++ /dev/null @@ -1,184 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "tensorflow/cc/experimental/libtf/object.h" - -#include - -#include "tensorflow/cc/experimental/libtf/value.h" -#include "tensorflow/cc/experimental/libtf/value_iostream.h" -#include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/platform/statusor.h" -#include "tensorflow/core/platform/test.h" - -namespace tf { -namespace libtf { - -TEST(ObjectTest, TestDictionary) { - Dictionary foo; - foo.Set(String("a"), Integer(33)); - foo.Set(String("b"), Integer(33)); - EXPECT_EQ(foo.Get(String("b"))->get(), 33); -} - -TEST(ObjectTest, TestTuple) { - Tuple foo(String("a"), Integer(33), Float(10.f)); - EXPECT_EQ(foo.size(), 3); - EXPECT_EQ(foo.Get(1)->get(), 33); -} - -TEST(ObjectTest, TestList) { - List l; - EXPECT_EQ(l.size(), 0); - l.append(Integer(3)); - EXPECT_EQ(l.Get(0)->get(), 3); - EXPECT_EQ(l.size(), 1); -} - -TaggedValue AddIntegers(TaggedValue args_, TaggedValue kwargs_) { - auto& args = args_.tuple(); - // auto& kwargs = kwargs_.dict(); - return TaggedValue(args[0].i64() + args[1].i64()); -} - -TEST(ObjectTest, TestCast) { - Integer i(3); - auto result = Cast(i); - ASSERT_TRUE(!result.ok()); -} - -TEST(ObjectTest, TestCall) { - TaggedValue add_func(AddIntegers); - Callable add(add_func); - TF_ASSERT_OK_AND_ASSIGN(Integer i, - add.Call(Integer(1), Integer(10))); - EXPECT_EQ(i.get(), 11); - - TF_ASSERT_OK_AND_ASSIGN( - Integer i2, add.Call(1, Integer(10), KeywordArg("foo") = 3)); - EXPECT_EQ(i2.get(), 11); -} - -TEST(ObjectTest, MakeObject) { - // TaggedValue func(f); - Object parent; - parent.Set(String("test3"), Integer(3)); - Object child; - child.Set(String("test1"), Integer(1)); - child.Set(String("test2"), Integer(2)); - child.Set(Object::ParentKey(), parent); - EXPECT_EQ(child.Get(String("test1"))->get(), 1); - EXPECT_EQ(child.Get(String("test2"))->get(), 2); - EXPECT_EQ(child.Get(String("test3"))->get(), 3); - ASSERT_FALSE(child.Get(String("test4")).status().ok()); - TF_ASSERT_OK(child.Get(String("test3")).status()); -} - -TEST(ObjectTest, CallFunctionOnObject) { - Object module; - module.Set(String("add"), Callable(TaggedValue(AddIntegers))); - TF_ASSERT_OK_AND_ASSIGN(Callable method, module.Get(String("add"))); - - TF_ASSERT_OK_AND_ASSIGN(Integer val, method.Call(1, 2)); - EXPECT_EQ(val.get(), 3); -} - -TEST(ObjectTest, Capsule) { - Object obj; - int* hundred = new int(100); - Handle capsule = - Handle(TaggedValue::Capsule(static_cast(hundred), [](void* p) { - delete static_cast(p); - })); - obj.Set(String("hundred"), capsule); - EXPECT_EQ(*static_cast( - obj.Get(String("hundred"))->cast()), - 100); -} - -None AppendIntegerToList(List a, Integer b) { - a.append(b); - return None(); -} -Integer AddIntegersTyped(Integer a, Integer b) { - return Integer(a.get() + b.get()); -} -Integer ReturnFive() { return Integer(5); } - -TEST(TypeUneraseCallTest, TestCallable) { - // Add two integers. - Callable add(TFLIB_CALLABLE_ADAPTOR(AddIntegersTyped)); - auto res = add.Call(Integer(3), Integer(1)); - EXPECT_EQ(res->get(), 4); -} - -TEST(TypeUneraseCallTest, TestAppend) { - // Append some indices to a list. - Callable append(TFLIB_CALLABLE_ADAPTOR(AppendIntegerToList)); - List l; - TF_ASSERT_OK(append.Call(l, Integer(3)).status()); - TF_ASSERT_OK(append.Call(l, Integer(6)).status()); - EXPECT_EQ(l.size(), 2); - EXPECT_EQ(l.Get(0)->get(), 3); - EXPECT_EQ(l.Get(1)->get(), 6); -} - -TEST(TypeUneraseCallTest, TestCallableWrongArgs) { - // Try variants of wrong argument types. - Callable append(TFLIB_CALLABLE_ADAPTOR(AddIntegersTyped)); - ASSERT_FALSE(append.Call(Object(), Integer(3)).ok()); - ASSERT_FALSE(append.Call(Object(), Object()).ok()); - // Try variants of wrong numbers of arguments. - ASSERT_FALSE(append.Call().ok()); - ASSERT_FALSE(append.Call(Integer(3)).ok()); - ASSERT_FALSE(append.Call(Integer(3), Integer(4), Integer(5)).ok()); -} - -Handle Polymorph(Handle a) { - auto i = Cast(a); - if (i.ok()) { - return Integer(i->get() * 2); - } - auto f = Cast(a); - if (f.ok()) { - return Float(f->get() * 2.f); - } - return None(); -} - -TEST(TypeUneraseCallTest, TestCallableGeneric) { - Callable f(TFLIB_CALLABLE_ADAPTOR(Polymorph)); - EXPECT_EQ(f.Call(Float(.2))->get(), .4f); - EXPECT_EQ(Cast(*f.Call(Float(.2)))->get(), .4f); - EXPECT_EQ(f.Call(Integer(3))->get(), 6); -} - -TEST(TypeUneraseCallTest, TestLambda) { - // Test a trivial lambda that doubles an integer. - Callable c( - TFLIB_CALLABLE_ADAPTOR([](Integer a) { return Integer(a.get() * 2); })); - EXPECT_EQ(c.Call(Integer(3))->get(), 6); - // Testa lambda that has captured state (call count). - int call_count = 0; - Callable f(TFLIB_CALLABLE_ADAPTOR([&call_count](Integer a, Integer b) { - call_count++; - return Integer(a.get() + b.get()); - })); - EXPECT_EQ(f.Call(Integer(3), Integer(-1))->get(), 2); - EXPECT_EQ(f.Call(Integer(3), Integer(-3))->get(), 0); - EXPECT_EQ(call_count, 2); -} - -} // namespace libtf -} // namespace tf diff --git a/tensorflow/cc/experimental/libtf/tests/perf_test.cc b/tensorflow/cc/experimental/libtf/tests/perf_test.cc deleted file mode 100644 index 3c40ac0438e77a..00000000000000 --- a/tensorflow/cc/experimental/libtf/tests/perf_test.cc +++ /dev/null @@ -1,99 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include - -#include "tensorflow/cc/experimental/libtf/object.h" -#include "tensorflow/cc/experimental/libtf/value.h" -#include "tensorflow/cc/experimental/libtf/value_iostream.h" -#include "tensorflow/core/platform/test.h" -#include "tensorflow/core/platform/test_benchmark.h" - -namespace tf { -namespace libtf { - -namespace { - -// AddTagged using tagged values -TaggedValue AddTagged(TaggedValue args, TaggedValue kwargs) { - return TaggedValue(args.tuple()[0].i64() + args.tuple()[1].i64()); -} - -int64_t AddRaw(int64_t a, int64_t b) { return a + b; } - -} // namespace - -// Add numbers in a loop by calling a callable. -void CallFunctions(::testing::benchmark::State& state) { - Integer sum(0); - Callable callable((impl::TaggedValue(impl::Func(AddTagged)))); - *callable.Call(sum, Integer(30)); - size_t i = 0; - for (auto dummy : state) { - sum = *callable.Call(sum, Integer(i)); - i++; - } -} - -// Add numbers in a loop by calling a callable, looking up method every -// time by tokenized string. -void CallFunctionsIndirect(::testing::benchmark::State& state) { - Integer sum(0); - Callable callable((impl::TaggedValue(impl::Func(AddTagged)))); - Object o; - String name("f"); - o.Set(name, callable); - size_t i = 0; - for (auto dummy : state) { - sum = *(o.Get(name))->Call(sum, Integer(i)); - i++; - } -} - -// Add numbers in a loop by calling a callable, looking up method every -// time by non-tokenized string. -void CallFunctionsIndirectNaive(::testing::benchmark::State& state) { - Integer sum(0); - Callable callable((impl::TaggedValue(impl::Func(AddTagged)))); - Object o; - o.Set(String("f"), callable); - size_t i = 0; - for (auto dummy : state) { - sum = *(o.Get(String("f")))->Call(sum, Integer(i)); - i++; - } -} - -// Add numbers in a loop by calling a raw C++ function with a function -// pointer. -void CallFunctionsBase(::testing::benchmark::State& state) { - int64_t sum = 0; - typedef int64_t (*Func)(int64_t a, int64_t b); - volatile Func f_raw = AddRaw; - Func f = f_raw; - size_t i = 0; - for (auto dummy : state) { - sum = f(sum, i); - i++; - } - // volatile int64_t result = sum; -} - -BENCHMARK(CallFunctions)->Arg(1 << 10); -BENCHMARK(CallFunctionsIndirect)->Arg(1 << 10); -BENCHMARK(CallFunctionsIndirectNaive)->Arg(1 << 10); -BENCHMARK(CallFunctionsBase)->Arg(1 << 10); - -} // namespace libtf -} // namespace tf diff --git a/tensorflow/cc/experimental/libtf/tests/runtime_test.cc b/tensorflow/cc/experimental/libtf/tests/runtime_test.cc deleted file mode 100644 index 3610b8a964648b..00000000000000 --- a/tensorflow/cc/experimental/libtf/tests/runtime_test.cc +++ /dev/null @@ -1,131 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "tensorflow/cc/experimental/libtf/tests/runtime_test.h" - -namespace tf { -namespace libtf { -namespace runtime { - -using ::tensorflow::testing::StatusIs; -using ::testing::HasSubstr; -using ::tf::libtf::impl::TaggedValueTensor; - -constexpr char kSimpleModel[] = - "tensorflow/cc/experimental/libtf/tests/testdata/simple-model"; - -TEST_P(RuntimeTest, SimpleModelCallableFloatTest) { - Runtime runtime = RuntimeTest::GetParam()(); - - // Import the module and grab the callable - const std::string module_path = - tensorflow::GetDataDependencyFilepath(kSimpleModel); - - TF_ASSERT_OK_AND_ASSIGN(Object module, - runtime.Load(String(module_path.c_str()))); - std::cout << "Module imported." << std::endl; - - TF_ASSERT_OK_AND_ASSIGN(Callable fn, - module.Get(String("test_float"))); - TF_ASSERT_OK_AND_ASSIGN( - Tensor tensor, runtime.CreateHostTensor({}, TF_FLOAT, {2.0f})); - TF_ASSERT_OK_AND_ASSIGN(Tensor result, fn.Call(Tensor(tensor))); - - float out_val[1]; - TF_ASSERT_OK(result.GetValue(absl::MakeSpan(out_val))); - EXPECT_EQ(out_val[0], 6.0); -} - -TEST_P(RuntimeTest, SimpleModelCallableIntTest) { - Runtime runtime = RuntimeTest::GetParam()(); - - // Import the module and grab the callable - const std::string module_path = - tensorflow::GetDataDependencyFilepath(kSimpleModel); - TF_ASSERT_OK_AND_ASSIGN(Object module, - runtime.Load(String(module_path.c_str()))); - - TF_ASSERT_OK_AND_ASSIGN(Callable fn, - module.Get(String("test_int"))); - - // Call the function - TF_ASSERT_OK_AND_ASSIGN(Tensor host_tensor, - runtime.CreateHostTensor({}, TF_INT32, {2})); - - TF_ASSERT_OK_AND_ASSIGN(Tensor tensor, fn.Call(Tensor(host_tensor))); - - int out_val[1]; - TF_ASSERT_OK(tensor.GetValue(absl::MakeSpan(out_val))); - EXPECT_EQ(out_val[0], 6); -} - -TEST_P(RuntimeTest, SimpleModelCallableMultipleArgsTest) { - Runtime runtime = RuntimeTest::GetParam()(); - - // Import the module and grab the callable - const std::string module_path = - tensorflow::GetDataDependencyFilepath(kSimpleModel); - TF_ASSERT_OK_AND_ASSIGN(Object module, - runtime.Load(String(module_path.c_str()))); - - TF_ASSERT_OK_AND_ASSIGN(Callable fn, - module.Get(String("test_add"))); - - TF_ASSERT_OK_AND_ASSIGN(Tensor tensor1, - runtime.CreateHostTensor({}, TF_FLOAT, {2.0f})) - TF_ASSERT_OK_AND_ASSIGN(Tensor tensor2, - runtime.CreateHostTensor({}, TF_FLOAT, {3.0f})) - - TF_ASSERT_OK_AND_ASSIGN(Tensor result_tensor, - fn.Call(tensor1, tensor2)); - float out_val[1]; - TF_ASSERT_OK(result_tensor.GetValue(absl::MakeSpan(out_val))); - EXPECT_EQ(out_val[0], 5.0f); -} - -TEST_P(RuntimeTest, CreateHostTensorIncompatibleShape) { - Runtime runtime = RuntimeTest::GetParam()(); - EXPECT_THAT(runtime.CreateHostTensor({2}, TF_FLOAT, {2.0f}), - StatusIs(absl::StatusCode::kInvalidArgument, - HasSubstr("Mismatched shape and data size"))); -} - -TEST_P(RuntimeTest, CreateHostTensorNonFullyDefinedShapeRaises) { - Runtime runtime = RuntimeTest::GetParam()(); - EXPECT_THAT(runtime.CreateHostTensor({-1}, TF_FLOAT, {2.0f}), - StatusIs(absl::StatusCode::kInvalidArgument, - HasSubstr("Shape must be fully-defined"))); -} - -TEST_P(RuntimeTest, CreateHostTensorIncompatibleDataType) { - Runtime runtime = RuntimeTest::GetParam()(); - EXPECT_THAT(runtime.CreateHostTensor({1}, TF_BOOL, {2.0f}), - StatusIs(absl::StatusCode::kInvalidArgument, - HasSubstr("Invalid number of bytes in data buffer"))); -} - -TEST_P(RuntimeTest, TensorCopyInvalidSize) { - Runtime runtime = RuntimeTest::GetParam()(); - TF_ASSERT_OK_AND_ASSIGN( - Tensor tensor, runtime.CreateHostTensor({1}, TF_FLOAT, {2.0f})) - float val[2]; - - EXPECT_THAT(tensor.GetValue(absl::MakeSpan(val)), - StatusIs(absl::StatusCode::kInvalidArgument, - HasSubstr("Mismatched number of elements"))); -} - -} // namespace runtime -} // namespace libtf -} // namespace tf diff --git a/tensorflow/cc/experimental/libtf/tests/runtime_test.h b/tensorflow/cc/experimental/libtf/tests/runtime_test.h deleted file mode 100644 index 3ae665c663b784..00000000000000 --- a/tensorflow/cc/experimental/libtf/tests/runtime_test.h +++ /dev/null @@ -1,44 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef TENSORFLOW_CC_EXPERIMENTAL_LIBTF_TESTS_RUNTIME_TEST_H_ -#define TENSORFLOW_CC_EXPERIMENTAL_LIBTF_TESTS_RUNTIME_TEST_H_ - -#include "absl/status/status.h" -#include "absl/strings/match.h" -#include "tensorflow/c/eager/unified_api_testutil.h" -#include "tensorflow/c/tf_datatype.h" -#include "tensorflow/c/tf_status_helper.h" -#include "tensorflow/cc/experimental/libtf/object.h" -#include "tensorflow/cc/experimental/libtf/runtime/runtime.h" -#include "tensorflow/cc/experimental/libtf/value.h" -#include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/platform/resource_loader.h" -#include "tensorflow/core/platform/status_matchers.h" -#include "tensorflow/core/platform/statusor.h" -#include "tensorflow/core/platform/test.h" - -namespace tf { -namespace libtf { -namespace runtime { - -typedef Runtime (*RuntimeFn)(); - -class RuntimeTest : public ::testing::TestWithParam {}; - -} // namespace runtime -} // namespace libtf -} // namespace tf - -#endif // TENSORFLOW_CC_EXPERIMENTAL_LIBTF_TESTS_RUNTIME_TEST_H_ diff --git a/tensorflow/cc/experimental/libtf/tests/runtime_test_core.cc b/tensorflow/cc/experimental/libtf/tests/runtime_test_core.cc deleted file mode 100644 index 599520025229f1..00000000000000 --- a/tensorflow/cc/experimental/libtf/tests/runtime_test_core.cc +++ /dev/null @@ -1,27 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "tensorflow/cc/experimental/libtf/runtime/core/core.h" -#include "tensorflow/cc/experimental/libtf/tests/runtime_test.h" - -namespace tf { -namespace libtf { -namespace runtime { - -INSTANTIATE_TEST_SUITE_P(TF2CAPI, RuntimeTest, - ::testing::Values(core::Runtime)); -GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(RuntimeTest); -} // namespace runtime -} // namespace libtf -} // namespace tf diff --git a/tensorflow/cc/experimental/libtf/tests/tensor_test.cc b/tensorflow/cc/experimental/libtf/tests/tensor_test.cc deleted file mode 100644 index 85243dd428775f..00000000000000 --- a/tensorflow/cc/experimental/libtf/tests/tensor_test.cc +++ /dev/null @@ -1,129 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include - -#include "tensorflow/c/eager/abstract_context.h" -#include "tensorflow/c/eager/abstract_tensor_handle.h" -#include "tensorflow/c/eager/c_api_unified_experimental.h" -#include "tensorflow/c/eager/c_api_unified_experimental_internal.h" -#include "tensorflow/c/eager/unified_api_testutil.h" -#include "tensorflow/c/tf_status_helper.h" -#include "tensorflow/cc/experimental/libtf/object.h" -#include "tensorflow/cc/experimental/libtf/value.h" -#include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h" -#include "tensorflow/core/platform/errors.h" -#include "tensorflow/core/platform/test.h" - -namespace tf { -namespace libtf { - -using AbstractContextPtr = tensorflow::AbstractContextPtr; -using AbstractContext = tensorflow::AbstractContext; -using AbstractTensorHandle = tensorflow::AbstractTensorHandle; -using TF_StatusPtr = tensorflow::TF_StatusPtr; -using Status = tensorflow::Status; - -class UnifiedCAPI - : public ::testing::TestWithParam> { - protected: - void SetUp() override { - TF_StatusPtr status(TF_NewStatus()); - TF_SetTracingImplementation(std::get<0>(GetParam()), status.get()); - Status s = tensorflow::StatusFromTF_Status(status.get()); - CHECK_EQ(tensorflow::errors::OK, s.code()) << s.message(); - } -}; - -namespace { -template -TaggedValue MakeContext(T runtime) { - AbstractContext* ctx_raw = nullptr; - Status s = BuildImmediateExecutionContext(runtime, &ctx_raw); - // ASSERT_EQ(tensorflow::errors::OK, s.code()) << s.message(); - return TaggedValue::Capsule(static_cast(ctx_raw), [](void* p) { - tensorflow::internal::AbstractContextDeleter()( - static_cast(p)); - }); -} -} // namespace - -TEST_P(UnifiedCAPI, HoldTensors) { - // Use the parametrized test parameters to make a context. - AbstractContextPtr ctx; - { - AbstractContext* ctx_raw = nullptr; - Status s = - BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw); - ASSERT_EQ(tensorflow::errors::OK, s.code()) << s.message(); - ctx.reset(ctx_raw); - } - - // Construct a scalar. - impl::TaggedValueTensor x; - { - AbstractTensorHandle* x_raw = nullptr; - Status s = TestScalarTensorHandle(ctx.get(), 2.0f, &x_raw); - ASSERT_EQ(tensorflow::errors::OK, s.code()) << s.message(); - x.reset(x_raw, false); - } - // Manually copy pointer so we can later compare the reference count. - impl::TaggedValueTensor x2(x); - - { - // Take ownership of x2 pointer. Semantics of AbstractTensorHandlePtr - // are that it has a reference. Here we steal that reference and put it - // into TaggedValue. If we used release() we would double free. - impl::TaggedValue tensor(std::move(x2)); - auto list = TaggedValue::List(); - // Test adding values by copying and moving. - list.list().emplace_back(3.f); - list.list().push_back(tensor); - list.list().emplace_back(std::move(tensor)); - ASSERT_FALSE(x->RefCountIsOne()); - } - ASSERT_TRUE(x->RefCountIsOne()); -} - -TaggedValue MakeScalarTensor(TaggedValue self, TaggedValue val) { - if (val.type() != TaggedValue::FLOAT32) return TaggedValue::None(); - if (self.type() != TaggedValue::DICT) return TaggedValue::None(); - TaggedValue ctx_capsule = (self.dict())[TaggedValue("context")]; - AbstractContext* ctx = static_cast(ctx_capsule.capsule()); - AbstractTensorHandle* x_raw = nullptr; - Status s = - TestScalarTensorHandle(ctx, val.f32().get(), &x_raw); - if (!s.ok()) return TaggedValue::None(); - return TaggedValue(impl::TaggedValueTensor(x_raw, false)); -} -TEST_P(UnifiedCAPI, SimpleCreationFunctions) { - // Use the parametrized test parameters to make a context. - TaggedValue context = MakeContext(std::get<1>(GetParam())); - Object methods; - methods.Set(String("context"), Handle(MakeContext(std::get<1>(GetParam())))); - methods.Set(String("make_scalar_tensor"), - Callable(TaggedValue(MakeScalarTensor))); - - Handle foo = *methods.Get(String("make_scalar_tensor")) - ->Call(methods, Float(3.f)); -} - -INSTANTIATE_TEST_SUITE_P(Tracing, UnifiedCAPI, - ::testing::Combine(::testing::Values("graphdef", - "mlir"), - ::testing::Values(false))); - -} // namespace libtf -} // namespace tf diff --git a/tensorflow/cc/experimental/libtf/tests/testdata/README b/tensorflow/cc/experimental/libtf/tests/testdata/README deleted file mode 100644 index 84ad79dac73564..00000000000000 --- a/tensorflow/cc/experimental/libtf/tests/testdata/README +++ /dev/null @@ -1,2 +0,0 @@ -The models in this directory are generated using -//third_party/tensorflow/cc/experimental/libtf/tests:generate_testdata \ No newline at end of file diff --git a/tensorflow/cc/experimental/libtf/tests/testdata/data-structure-model/saved_model.pb b/tensorflow/cc/experimental/libtf/tests/testdata/data-structure-model/saved_model.pb deleted file mode 100644 index 60e1a6028942d3de9be5d4f72ec0724fbec26610..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 10117 zcmeHN&2Jn@6`!7OKYyk?u_t9a37v!_BjNG%BvF_HS<5;nGFJ2@bcR(qyxHyL-2 zx_g}D1mc9m0kMZ=B_sqZ;7`~CNC?DAz=1n|09^Phc-5aZ{V^WLS&0L-?CGg`_3G8H z-h1`lt1b)VhuiRPAb$sF!`!-MX-E1C#nhC8p1R#9Q&2Wl%NUp)wOS!FP&ie3163?^ z?YDhZoCf!SXuY?^VwW=ku z0Hmv0wx%n#X6Rl|wInXVY*+26wkjLFuH4m3i{uXsqeruM1h$|ePy1;Q4|p&)mJHGvaWoQ1sdMzfqgnyqGF115KRN=H4$SyeNR)-bYS z1q#-&(pQQ5LG!I_Re%p*W>;0ZY1>85-Yx?sb`cfBL>LkEq6-+pH3w7fHp#kj;s98I zIY$Ip(YvzM(>kn`UV@TKXGujzCoHiF-V=eIX4!X`YTRXjFgSZq+*3~9aW-<^*&g9+ zxVashOv6OS(7PJj{3?(l2u!)yr#ec>9>eV3v0`@B?uS}WWl%7SI8ZF-8F89$Hpc8$ zbEp}h4^CtR$?RIy97V!IZ&0-ae-vPG&sJ>paM0UPOjL7}MD1=Uy&joF?J$olDYvib zxM`l@NcnxVAaNd8^DvgE4^C&K6|Litz`|wG~ZQhlYt)V3O^T+I3W$=NNHGBH9m9239BDfLYuw<24TLs;A5X z$dsevF2r|f7BUExSg=&{RIL@H)d#CpLxKGD5p<#K+biLvuPwlg?4rk1 zsE@q!kZomZlfIA5t_kdmD-r97ECfiv8NLV0{w#IoBAkDI3%d1)-=M+BvBZF50H2xM zxacyLhWBHNaJ)f*tKm#zErlC>{>4aN;bNqr#YpE|jCB6RG{3^dG=~<`Jm+GX<1J>& zJ$s3J0{e)`Tp(rE&c4D-xk!M;PKrH1Cj=S% zMfZXHg*E&r6QSX?Ob|iA79hd_wzLTz;xE>BpRBh(-jrXgfBbwiG-G;X9Xdb#waeEWMc#uiQe+pT36nkO_O7DFpAAK%vcyFn6iI_ z5ska)s8sN-1C>oYEM!CPy~QUt?3dkB+8Ozc zNc^xumY@VVw7)rlJ$!@@87ccT^HBJakg45d@yNB^6ZV@TXz z`Q8KgG7DEyA|pIp#@?IY2&262htCq%-mJ~JAWlF@DwHbt^S_GWZHA$^dmXMOVThZ= zf_M4BL-Max_}c<}kv)@e7ABcd1t}1@0ap|!??VfB@;;+4`;XO1E03NrLvF)LjHnTi z$&HbOQb;I`By5I+O`i}WV$i(IP`+O+VTAR$5X)3Q7f3XDl^aQoQc`1+{y7T^yM8KE z=X0Nwu$ zEN?vC-1yGU_LHrB=Um!uKi%Bemp|HF-+3%m(}1F%2|z;(oj{e#7B7n9iVtI@O#%=z zZSPlF9yK@pQs6V)FzKbdEzoPB^t=YN+_=28MsI`!(d7k1(YWMJOi>n~!t(d%Uu{0f zoE^X9_;zWM6PzA#MUtRgLgxdiT3#~VT8m0iVwDq?LR4eoV2Oi?B0689m;7sBkhvty)y!0}9{?hqaliWyZl#&{owD6KuUZb*lj;xdx z#!9>a-=iH)_WPXdp+VjYz@{Z1Ny?$Msyg zljNmHf*(z!loBYVM9O9YWwW(W&X7ltQci}~%pOdY=+R+X*zYI7{cx{GHtMEy7 zWZptzarAAt!E%~Y&doXL6#M7L(?-84kh^dzO4x8SuW?AF@kyI0q*&ee1aRscSi0q7 z?1mR(C)}^lp9(2u;KbtaP;@nH$h9>4tz8Z)$5_&QX*FB*`f5$Wped>Tj0W!J#fjl% z`Yc?S!CZ+~6q;^NIwhgPDi)k=spQ{g^`NxjRD>EAnP$nKhP3CLX+XBslYZTq3qw9{ z3i=atC*Hk<_eoch)*hHGopFYS-{(A5K}!_;GaVdy4Oawe50&k?68_Nc(pi^$)Qnps zTg|5JJM-J&^#J`Yf1Bcgo;H+^8FMw_qdd_rVeKx{5(LIOyqlpvfXjQsw>ou<=FM*6 z-i}pROqw+#EvC#ve<#H6m7*6-bUs8Q0F_=%6g=#o50@D<4)1;HdZbG+q39p7Fu&*A z*Zn^w9+47XPDF`*!tVTB^rG2fimHuWR5&9UzLGjI{bx3~g7vOqDUB#q%P+cO#=D%b z{efoD>nY0%DaJVx>3GLsoGkGujFTlEgLLQ!jM_+ID3zR`Gf(MkG7FTy9)BmHp725f zab6}E34H0L%QY$)dT!=Ng)l}VA zFwu=qLd~#3!G-cXcg!T>688{ChdvQ0U*_LjOjvNwsUy(0>M|{J@`mvJgCuZIXh#mk_+QjCE(MHO7iW&b2u9 zW=HnsM)u}M_7>86#X|X4#knc^6yJtHqYv+=T5BZj76ZL#T9@2-m=&+VJmb}sPYiq& z39oB3T{ijz{fes@wQs>?*|{$9Ni^(A>y(A&OHkE!K-^Cc0Vs-fX^t>F|wQOWg m#*w>z9p;!ozbIaY>EKcqH>U3iPg-;AQUq(J!L1WqpzXhxb=g?nTRmlEd5;=Y@P4SG)SHwdT2%f diff --git a/tensorflow/cc/experimental/libtf/tests/testdata/data-structure-model/variables/variables.index b/tensorflow/cc/experimental/libtf/tests/testdata/data-structure-model/variables/variables.index deleted file mode 100644 index 52b3dfd35d92cdc0114e469099e7ae378f3bbdf1..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 252 zcmey#_#udafsKPviiv}fL545h*~8V@JHX%5FT~Nw$2H#H$;;I_B;Gy9F~Ea^U5G(J zKw~1a>0;H>2N_hj%My#?6N`%U4fORKLqdW)okBxggZ0B4gFIPnLc$zOK;rP6`nhs8IF!-B>Cz!zLH~?u~U^m^k(~LNC$T=QyJFq~db33&Im}?EjoVMRWFC z+N5PDfq0}j!R>f(TgBUhr*tB?sq4I}^-a&E7pU-DC;=QiwJq}`7(WvXnFyZ@22v_~ zgagMl%d$PqGi=L+@i+UXdE43B>G2NygD8I|aB*Z;f$7w>>-0k+H%^7I^nQMW`+1Mh z7M*%7G_(8pI=AvUQT~Sbjcf%`SCzHljR*+L%efJ{S?Fb}?Aoay!NRG~0(e#yfgA_U z7G9qO*FNA7f*#bm^>94ndP!e+$^?f{sID|7ZBthz^tme+7Z{)LtE!2 zMv&pDc8^}8uNJ)msNA<7^jNt?q{}d!o_(kpmOr}=wP?^LO!)&GwQHE%(hhlhm0sLI zf9Ku4xurRt;d?&UH#F01ude`Jm{f`Wda9J8?00vYyFl;;;{z=UhogWz;-rm^y zNqbeEg9+8{d%eD=y8Bv>yR-u3c1a$Gs_vciI4!S=@-@IfHv60|uCuyGOY1W*HIl z;mmJCqhniLV^8flwr67&UHAnl-5^SN6SE}8Em!a4R&TSdF5t<}2!Aok9Z~7LIwy(e|c@{yFd$z;X*A@EjtMI4k_C~!# zZ$g1r!!@j(LCtircoizrJ#?{R03KU6mMdFP;K zV}^EV&NF!)nxQYmv4!5(F11V2HvkNsTNv0U>KdJXdvM(jV-chLR|IB!-#0qje2*Xd zMem;Gb@plQQ#wm76n>=#; zM*r+mp3xInCW5KS{ZzEn;K{J@vB>QPNwmNx6|4iH3d{r8@f^e2+eDY5sp|~7Y%&~v zmc;wC+$tC0u3)uHB;Ko0vOfyp5a=m&`4D6hp3Ev>(Xfe?&X zUsN~l-r4xk)`Pe2@7`Yj@g4QS`sSUDUG=T)+go>)*<3}~Gvua`f(T)ctD1!c$wVgc zm6e>0m29lTK*^GFu}~oU)Q{W1S{S#%KVw69wG%SM7IQ2+0F7NK*j6&4;i8Nq1sMVo zsbMT;MzHW1!z9!K{)EN}7)B80p{A((Yoi5Zr86idQ5o{N*p}8gICE>_qfq-{tuXEQ zNr1?M{BwrSU`actM(hvdyw#7boNtqfnTu=Ks+?3Op-7T~z~THN;P6!%xZ-8fz%3VdECS?&3}ga)@=`;#^Wu#WEsKR|IYXY{hhNw|8m2ZmL`JICKdenP;PX1R@5H`!VA)SCxvDIGcX?H+;+L2s zPCu%{jkuM8XoTGyDQu-mle89|?V{u+HEG4~u|};{SWa0Ppc-@$C9=!L2Mdv1X)KyU zc7f-yUpf!dIY$X3YjZ$#jCBJ)JNe`=>~*z7X@Mazq?O0EZNen>x4Fkv?AWTh;kd}( z;x;whgWxzqC!iK|gI&~3Cik=IHC72OCSXep4$Q?39C3fLgpMpP8o5(YSr=Sr3O8^3@9<#M5)ju^>!uZujQ9-Q65%ud{Gb0#f{usZ02RN(j)4=oy%KcTrD^nsv@!ib z(#+7qa47n=e|Eu+FT!1>C0cWXei+nn@7P0^SmtS_v+lcbYDSYE^G+Xo!>SLVkKC>N zwx(}^Akxqo5YjMWTOeH+2#=eja6jEQLquv}pT8ygQIO0PrECz)XA~QAId&m6)a<{4 C$c$ - -#include "tensorflow/cc/experimental/libtf/value_iostream.h" -#include "tensorflow/core/platform/test.h" - -namespace tf { -namespace libtf { -namespace impl { - -TEST(ValueTest, TestBasic) { - TaggedValue valuef(3.f); - TaggedValue valuei(int64_t(3)); - TaggedValue list = TaggedValue::List(); - TaggedValue tuple = TaggedValue::Tuple(); - tuple.tuple().push_back(TaggedValue(int64_t(310))); - list.list().push_back(valuei); - list.list().push_back(valuef); - list.list().push_back(tuple); - std::stringstream stream; - stream << list; - ASSERT_EQ(stream.str(), "[3, 3, (310, ), ]"); -} - -TEST(ValueTest, TestString) { - TaggedValue value1a("string1"); - std::string s = "string"; - s += "1"; - TaggedValue value1b(s.c_str()); - // Verify that interned the pointers are the same. - ASSERT_EQ(value1b.s(), value1a.s()); - TaggedValue value2("string2"); - ASSERT_NE(value1a.s(), value2.s()); - ASSERT_STREQ(value1a.s(), "string1"); - ASSERT_STREQ(value2.s(), "string2"); -} - -TEST(Test1, TestDict) { - TaggedValue s1("test1"); - TaggedValue s2("test2"); - TaggedValue d = TaggedValue::Dict(); - d.dict()[s2] = TaggedValue(6.f); - std::stringstream stream; - stream << d; - ASSERT_EQ(stream.str(), "{test2: 6, }"); -} - -namespace { -TaggedValue add(TaggedValue args, TaggedValue kwargs) { - if (args.type() == TaggedValue::TUPLE) { - return TaggedValue(args.tuple()[0].f32() + args.tuple()[1].f32()); - } - return TaggedValue::None(); -} -} // namespace -TEST(Test1, TestFunctionCall) { - TaggedValue f32 = TaggedValue(add); - TaggedValue args = TaggedValue::Tuple(); - args.tuple().emplace_back(TaggedValue(1.f)); - args.tuple().emplace_back(TaggedValue(2.f)); - TaggedValue c = f32.func()(args, TaggedValue::None()).value(); - ASSERT_EQ(c, TaggedValue(3.f)); -} - -namespace { -int alloc_count = 0; -class Cool { - public: - Cool() { alloc_count++; } - ~Cool() { alloc_count--; } -}; -} // namespace - -TEST(Test1, TestCapsule) { - TaggedValue test_moved, test_copy; - ASSERT_EQ(alloc_count, 0); - void* ptr_value = new Cool(); - { - TaggedValue capsule = - TaggedValue::Capsule(static_cast(ptr_value), - [](void* x) { delete static_cast(x); }); - ASSERT_EQ(alloc_count, 1); - ASSERT_EQ(capsule.capsule(), ptr_value); - test_moved = std::move(capsule); - ASSERT_EQ(capsule.type(), TaggedValue::NONE); // NOLINT - test_copy = test_moved; - ASSERT_EQ(test_moved.capsule(), ptr_value); - ASSERT_EQ(test_copy.capsule(), ptr_value); - } - ASSERT_EQ(alloc_count, 1); - test_moved = TaggedValue::None(); - ASSERT_EQ(alloc_count, 1); - test_copy = TaggedValue(3.f); - ASSERT_EQ(alloc_count, 0); -} - -} // namespace impl -} // namespace libtf -} // namespace tf diff --git a/tensorflow/cc/experimental/libtf/tests/variable_test.cc b/tensorflow/cc/experimental/libtf/tests/variable_test.cc deleted file mode 100644 index 1e37ed9cb2b96b..00000000000000 --- a/tensorflow/cc/experimental/libtf/tests/variable_test.cc +++ /dev/null @@ -1,120 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "tensorflow/c/eager/abstract_context.h" -#include "tensorflow/c/eager/abstract_function.h" -#include "tensorflow/c/eager/abstract_tensor_handle.h" -#include "tensorflow/c/eager/graph_function.h" -#include "tensorflow/c/eager/unified_api_testutil.h" -#include "tensorflow/c/experimental/ops/resource_variable_ops.h" -#include "tensorflow/c/tf_status_helper.h" -#include "tensorflow/cc/experimental/libtf/function.h" -#include "tensorflow/cc/experimental/libtf/object.h" -#include "tensorflow/cc/experimental/libtf/value.h" -#include "tensorflow/core/framework/function.h" -#include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/framework/types.pb.h" -#include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/platform/errors.h" -#include "tensorflow/core/platform/statusor.h" -#include "tensorflow/core/platform/test.h" - -namespace tf { -namespace libtf { -using tensorflow::AbstractContext; -using tensorflow::AbstractContextPtr; -using tensorflow::AbstractFunctionPtr; -using tensorflow::AbstractTensorHandle; -using tensorflow::DT_FLOAT; -using tensorflow::PartialTensorShape; -using tensorflow::Status; -using tensorflow::TF_StatusPtr; - -class VariableTest - : public ::testing::TestWithParam> { - public: - template - impl::TaggedValueTensor CreateScalarTensor(T val) { - AbstractTensorHandle* raw = nullptr; - Status s = TestScalarTensorHandle(ctx_.get(), val, &raw); - CHECK_EQ(tensorflow::errors::OK, s.code()) << s.message(); - return impl::TaggedValueTensor(raw, /*add_ref=*/false); - } - - bool UseTfrt() { return std::get<1>(GetParam()); } - - AbstractContextPtr ctx_; - - protected: - void SetUp() override { - // Set the tracing impl, GraphDef vs MLIR. - TF_StatusPtr status(TF_NewStatus()); - TF_SetTracingImplementation(std::get<0>(GetParam()), status.get()); - Status s = tensorflow::StatusFromTF_Status(status.get()); - CHECK_EQ(tensorflow::errors::OK, s.code()) << s.message(); - - // Set the runtime impl, Core RT vs TFRT. - AbstractContext* ctx_raw = nullptr; - s = BuildImmediateExecutionContext(UseTfrt(), &ctx_raw); - CHECK_EQ(tensorflow::errors::OK, s.code()) << s.message(); - ctx_.reset(ctx_raw); - } -}; - -template -void ExpectEquals(AbstractTensorHandle* t, T expected) { - TF_Tensor* result_t; - Status s = tensorflow::GetValue(t, &result_t); - ASSERT_TRUE(s.ok()) << s.message(); - auto value = static_cast(TF_TensorData(result_t)); - EXPECT_EQ(*value, expected); - TF_DeleteTensor(result_t); -} - -TEST_P(VariableTest, CreateAssignReadDestroy) { - // Create uninitialized variable. - tensorflow::AbstractTensorHandlePtr var; - { - AbstractTensorHandle* var_ptr = nullptr; - PartialTensorShape scalar_shape; - TF_EXPECT_OK( - PartialTensorShape::MakePartialShape({}, 0, &scalar_shape)); - TF_EXPECT_OK(tensorflow::ops::VarHandleOp(ctx_.get(), &var_ptr, DT_FLOAT, - scalar_shape)); - var.reset(var_ptr); - } - // Assign a value. - auto x = CreateScalarTensor(2.0f); - TF_EXPECT_OK( - tensorflow::ops::AssignVariableOp(ctx_.get(), var.get(), x.get())); - // Read variable. - tensorflow::AbstractTensorHandlePtr value; - { - AbstractTensorHandle* value_ptr = nullptr; - TF_EXPECT_OK(tensorflow::ops::ReadVariableOp(ctx_.get(), var.get(), - &value_ptr, DT_FLOAT)); - value.reset(value_ptr); - } - ExpectEquals(value.get(), 2.0f); - // Destroy variable. - TF_EXPECT_OK(tensorflow::ops::DestroyResourceOp(ctx_.get(), var.get())); -} - -INSTANTIATE_TEST_SUITE_P(TF2CAPI, VariableTest, - ::testing::Combine(::testing::Values("graphdef", - "mlir"), - ::testing::Values(false))); - -} // namespace libtf -} // namespace tf diff --git a/tensorflow/cc/experimental/libtf/tests/visit_test.cc b/tensorflow/cc/experimental/libtf/tests/visit_test.cc deleted file mode 100644 index fe905d9972a629..00000000000000 --- a/tensorflow/cc/experimental/libtf/tests/visit_test.cc +++ /dev/null @@ -1,45 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include - -#include "tensorflow/cc/experimental/libtf/value.h" -#include "tensorflow/cc/experimental/libtf/value_iostream.h" -#include "tensorflow/core/platform/test.h" - -namespace tf { -namespace libtf { -namespace impl { - -struct Visitor { - const char* operator()(Int64 i) { return "int64"; } - const char* operator()(Float32 f) { return "float32"; } - template - const char* operator()(const T& i) { - return "else"; - } -}; - -TEST(VisitTest, Test1) { - TaggedValue a(Int64(1)), b(Float32(1.1f)); - TaggedValue c = TaggedValue::None(); - - ASSERT_EQ(a.visit(Visitor()), "int64"); - ASSERT_EQ(b.visit(Visitor()), "float32"); - ASSERT_EQ(c.visit(Visitor()), "else"); -} - -} // namespace impl -} // namespace libtf -} // namespace tf diff --git a/tensorflow/cc/experimental/libtf/value.h b/tensorflow/cc/experimental/libtf/value.h deleted file mode 100644 index 61a2888426ee3d..00000000000000 --- a/tensorflow/cc/experimental/libtf/value.h +++ /dev/null @@ -1,596 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -/// @file value.h -/// @brief The TaggedValue struct that supports Python-like behavior in C++. -/// -/// The TaggedValue struct implements a tagged union data structure -/// (https://en.wikipedia.org/wiki/Tagged_union) in the TensorFlow C++ API. It -/// contains a `Type` enum (sometimes referred to as a "tag") -/// and a `Data` union for holding values. - -#ifndef TENSORFLOW_CC_EXPERIMENTAL_LIBTF_VALUE_H_ -#define TENSORFLOW_CC_EXPERIMENTAL_LIBTF_VALUE_H_ - -#include -#include -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "tensorflow/c/eager/abstract_tensor_handle.h" -#include "tensorflow/core/platform/intrusive_ptr.h" -#include "tensorflow/core/platform/statusor.h" - -// TODO(b/195578409): Move all value objects into `impl`. Currently only values -// that do not reference TaggedValue are there. -#include "tensorflow/cc/experimental/libtf/impl/none.h" -#include "tensorflow/cc/experimental/libtf/impl/scalars.h" -#include "tensorflow/cc/experimental/libtf/impl/string.h" -#include "tensorflow/cc/experimental/libtf/impl/tensor_spec.h" - -namespace tf { -namespace libtf { -namespace impl { -// Necessary forward declares. -class TaggedValue; -class Tuple; -template -// TODO(ccrusius): Use absl::Hash specializations instead. -class TaggedValueHash; -using List = std::vector; -using ListPtr = std::shared_ptr; -using Dict = - absl::flat_hash_map>; -using DictPtr = std::shared_ptr; -using TuplePtr = std::shared_ptr; -using Func = - std::function(TaggedValue, TaggedValue)>; -// A capsule holds a pointer and a destructor for the pointer (i.e. a generic -// shared_ptr to void with a custom deleter). -using Capsule = std::shared_ptr; -using TaggedValueTensor = - tensorflow::core::IntrusivePtr; - -// Declare hash types so they can be instantiated below. - -/// @brief TaggedValue hashing infrastructure, which uses absl::hash. -/// -/// Hashable TaggedValues overload `AbslHashValue`. Non-hashable structures -/// return 0. -template <> -struct TaggedValueHash { - size_t operator()(const TaggedValue& v) const; -}; - -/// @brief Hash implementation for TaggedValue Tuples. -template <> -struct TaggedValueHash { - size_t operator()(const Tuple& t) const; -}; - -/// @brief The basic `TaggedValue` tagged union type. -/// -/// A `TaggedValue` contains a `Type` (or "tag") as an enum and a `Value` union. -/// Values include tensors, primitive values, lists, tuples, and dictionaries. -/// In the future we might also want to have representation of python objects in -/// the form of PyObject*. -class TaggedValue final { - public: - /// @brief Enum that describes the possible types a `TaggedValue` can be. - /// - /// A `TaggedValue` must be one of the following types: NONE, INT64, FLOAT32, - /// STRING, FUNC, DICT, LIST, TUPLE, TENSOR, TENSOR_SPEC, CAPSULE. - enum Type { - NONE = 0, - INT64 = 1, - FLOAT32 = 2, - STRING = 3, - FUNC = 4, - DICT = 5, - LIST = 6, - TUPLE = 7, - TENSOR = 8, - TENSOR_SPEC = 9, - CAPSULE = 10, - }; - TaggedValue() : type_(NONE), data_() {} - - /// Move assignment operator. - TaggedValue& operator=(TaggedValue&& v) { - destroy(); - MoveIntoUnion(std::move(v)); - return *this; - } - /// Move constructor. - TaggedValue(TaggedValue&& v) : type_(NONE) { MoveIntoUnion(std::move(v)); } - /// Copy constructor. - TaggedValue(const TaggedValue& v) : type_(NONE) { CopyIntoUnion(v); } - /// Copy assignment operator. - TaggedValue& operator=(const TaggedValue& v) { - destroy(); - CopyIntoUnion(v); - return *this; - } - /// TaggedValue constructor for type TENSOR. - explicit TaggedValue(TaggedValueTensor tensor) - : type_(TENSOR), data_(std::move(tensor)) {} - /// TaggedValue constructor for type TENSOR_SPEC. - explicit TaggedValue(tensorflow::PartialTensorShape shape, - tensorflow::DataType dtype) - : type_(TENSOR_SPEC), data_(shape, dtype) {} - /// TaggedValue constructor for type FUNC. - explicit TaggedValue(Func f32) : type_(FUNC), data_(f32) {} - /// TaggedValue constructor for type FLOAT32. - explicit TaggedValue(float f32) : type_(FLOAT32), data_(Float32(f32)) {} - /// TaggedValue constructor for type INT64. - explicit TaggedValue(int64_t i64) : type_(INT64), data_(Int64(i64)) {} - /// TaggedValue constructor for type FLOAT32. - explicit TaggedValue(Float32 f32) : type_(FLOAT32), data_(f32) {} - /// TaggedValue constructor for type INT64. - explicit TaggedValue(Int64 i64) : type_(INT64), data_(i64) {} - /// TaggedValue constructor for type STRING. - explicit TaggedValue(const char* s) : type_(STRING), data_(s) {} - /// Constructs a TaggedValue with type NONE. - static TaggedValue None() { - TaggedValue v; - v.type_ = NONE; - return v; - } - /// Constructs a TaggedValue with type LIST. - static TaggedValue List() { - TaggedValue v; - v.type_ = LIST; - using T = decltype(v.data_.list); - new (&v.data_.list) T(std::make_shared()); - return v; - } - /// Constructs a TaggedValue with type TUPLE. - static TaggedValue Tuple() { - TaggedValue v; - v.type_ = TUPLE; - using T = decltype(v.data_.tuple); - new (&v.data_.tuple) T(std::make_shared()); - return v; - } - /// Constructs a TaggedValue with type DICT. - static TaggedValue Dict() { - TaggedValue v; - v.type_ = DICT; - using T = decltype(v.data_.dict); - new (&v.data_.dict) T(std::make_shared()); - return v; - } - /// Constructs a TaggedValue with type TENSOR. - static TaggedValue Tensor(tensorflow::AbstractTensorHandle* raw_ptr) { - TaggedValue v; - v.type_ = TENSOR; - using T = decltype(v.data_.tensor); - new (&v.data_.tensor) T(raw_ptr, /*add_ref=*/false); - return v; - } - - /// Constructs a TaggedValue with type CAPSULE with a default destructor. - template - static TaggedValue Capsule(T* data) { - return Capsule(static_cast(data), - [](void* x) { delete static_cast(x); }); - } - /// Constructs a TaggedValue with type CAPSULE with a custom destructor. - static TaggedValue Capsule(void* data, void (*deleter)(void*)) { - TaggedValue v; - v.type_ = CAPSULE; - using T = decltype(v.data_.capsule); - new (&v.data_.capsule) T(data, deleter); - return v; - } - /// Destroys TaggedValue. Shared pointers in unions must be explicitly - /// deleted. - void destroy() { - if (type_ != NONE) { - // Explicitly run the destructor on the correct type. - visit([](auto& x) { - using T = typename std::decay::type; - x.~T(); - }); - // Make the type None, whenever we destroy so we always have an - // initialized value. - type_ = NONE; - } - } - ~TaggedValue() { destroy(); } - - /// @brief Get the underlying value based on type. - /// - /// @tparam T The desired return type. - /// @return The unwrapped value. If this `TaggedValue` type does not currently - /// contain a value of type `T`, the program terminates via a call to - /// `assert`. - template - T& get() { - assert(type_ == EnumValueOf::value); - return UnionAccess::unsafe_reference(*this); - } - - /// @brief Get the underlying value based on type. - /// - /// @tparam T The desired return type. - /// @return The unwrapped value. If this `TaggedValue` type does not currently - /// contain a value of type `T`, the program terminates via a call to - /// `assert`. - template - const T& get() const { - assert(type_ == EnumValueOf::value); - return UnionAccess::unsafe_reference(*this); - } - - /// Retrieves underlying value from a TaggedValue with type INT64. - const Int64& i64() const { return get(); } - - /// Retrieves underlying value from a TaggedValue with type FLOAT32. - const Float32& f32() const { return get(); } - - /// Retrieves underlying value from a TaggedValue with type STRING. - const char* s() const { return get().str().c_str(); } - - /// Retrieves underlying value from a TaggedValue with type LIST. - impl::List& list() { return *get(); } - /// Retrieves underlying value from a TaggedValue with type LIST. - const impl::List& list() const { return *get(); } - - /// Retrieves underlying value from a TaggedValue with type TUPLE. - impl::Tuple& tuple() { return *get(); } - /// Retrieves underlying value from TaggedValues with type TUPLE. - const impl::Tuple& tuple() const { return *get(); } - - /// Retrieves underlying value from a TaggedValue with type DICT. - impl::Dict& dict() { return *get(); } - /// Retrieves underlying value from TaggedValues with type DICT. - const impl::Dict& dict() const { return *get(); } - - /// Retrieves underlying value from a TaggedValue with type FUNC. - impl::Func func() const { return get(); } - - // TODO(danielellis): make const-only if possible, once the API allows for it - /// Retrieves underlying value from a TaggedValue with type TENSOR. - TaggedValueTensor& tensor() { return get(); } - /// Retrieves underlying value from a TaggedValue with type TENSOR. - const TaggedValueTensor& tensor() const { return get(); } - - /// Retrieves underlying value from a TaggedValue with type TENSOR_SPEC. - const TensorSpec& tensor_spec() const { return get(); } - - /// Retrieves underlying value from a TaggedValue with type CAPSULE. - void* capsule() const { return get().get(); } - - /// Retrieves type of TaggedValue. - Type type() const { return type_; } - - /// @brief Implements equality operator for TaggedValue. - bool operator==(const TaggedValue& o) const { - if (type_ != o.type_) return false; - switch (type_) { - case LIST: - return data_.list == o.data_.list; - break; - case TUPLE: - return data_.tuple == o.data_.tuple; - break; - case DICT: - return data_.dict == o.data_.dict; - break; - case FUNC: - // TODO(b/187536093): This is definitely wrong because the exact ptr of - // the function pointer is almost always different, because we hold - // it by value. Two tagged values that hold the same std::function - // will have different std::function ptrs. operator== is not defined - // for std::function's so we need a better solution here, or these - // are not comparable which seems bad. - return &data_.func == &o.data_.func; - break; - case FLOAT32: - return data_.f32 == o.data_.f32; - break; - case INT64: - return data_.i64 == o.data_.i64; - break; - case STRING: - return data_.s == o.data_.s; - break; - case TENSOR: - return data_.tensor == o.data_.tensor; - case TENSOR_SPEC: - return data_.tensor_spec == o.data_.tensor_spec; - case CAPSULE: - return data_.capsule.get() == o.data_.capsule.get(); - case NONE: - return true; - } - } - - /// @brief Implements visitor pattern for doing type-based dispatch. - /// - /// @tparam R The desired return type. - /// @tparam Visitor The visitor class which has a callable operator. - /// @return The `visitor` called on the correct value. - template - R visit(Visitor visitor) { - switch (type_) { - case LIST: - return visitor(data_.list); - case TUPLE: - return visitor(data_.tuple); - case DICT: - return visitor(data_.dict); - case FUNC: - return visitor(data_.func); - case FLOAT32: - return visitor(data_.f32); - case INT64: - return visitor(data_.i64); - case STRING: - return visitor(data_.s); - case TENSOR: - return visitor(data_.tensor); - case TENSOR_SPEC: - return visitor(data_.tensor_spec); - case CAPSULE: - return visitor(data_.capsule); - case NONE: - return visitor(impl::None::GetInstance()); - } - } - - /// @brief Implements visitor pattern for doing type-based dispatch. - /// - /// @tparam R The desired return type. - /// @tparam Visitor The visitor class which has a callable operator. - /// @return The `visitor` called on the correct value. - template - R visit(Visitor visitor) const { - switch (type_) { - case LIST: - return visitor(data_.list); - case TUPLE: - return visitor(data_.tuple); - case DICT: - return visitor(data_.dict); - case FUNC: - return visitor(data_.func); - case FLOAT32: - return visitor(data_.f32); - case INT64: - return visitor(data_.i64); - case STRING: - return visitor(data_.s); - case TENSOR: - return visitor(data_.tensor); - case TENSOR_SPEC: - return visitor(data_.tensor_spec); - case CAPSULE: - return visitor(data_.capsule); - case NONE: - return visitor(impl::None::GetInstance()); - } - } - - private: - /// @brief A utility class for mapping C++ types to Type values. - template - struct EnumValueOf; - - /// @brief A utility class for accessing the `Data` union members. - template - struct UnionAccess; - - // Unsafe Move, because it assumes the union has already been destroyed - // or is new! - void MoveIntoUnion(TaggedValue&& v) { - assert(type_ == NONE); - type_ = v.type_; - if (type_ != NONE) { - visit([&v](auto& left) -> void { - using T = typename std::decay::type; - new (&left) T(std::move(UnionAccess::unsafe_reference(v))); - }); - } - // Destroy the source r-value reference (making it None) - v.destroy(); - } - - // Unsafe Move, because it assumes the union has already been destroyed - // or is new! - void CopyIntoUnion(const TaggedValue& v) { - assert(type_ == NONE); - type_ = v.type_; - if (type_ != NONE) { - visit([&v](auto& left) -> void { - using T = typename std::decay::type; - new (&left) T(UnionAccess::unsafe_reference(v)); - }); - } - } - - /// @brief The type of the TaggedValue, i.e. the "tag" of a tagged union. - /// - /// In principle this could be incorporated into the union - /// for pointer types and non-64bit values, but then int64 and float64 values - /// would need to be indirected. This means that we are aiming for a total - /// data type size of <=16 bytes, comprised of one pointer (8 bytes) and - /// one type (<=8bytes). - Type type_; - - // we use an explicit union here because we want to avoid C++17's - // variant structures due to c++14 compatibility requirements. - // TODO(b/183980966): Compare against absl::variant. - union Data { - explicit Data() {} - explicit Data(Float32 f32) : f32(f32) {} - explicit Data(Int64 i64) : i64(i64) {} - explicit Data(const char* s) : s(String(s)) {} - explicit Data(Func fn) : func(fn) {} - explicit Data(TaggedValueTensor tensor_in) { - new (&tensor) TaggedValueTensor(std::move(tensor_in)); - } - explicit Data(tensorflow::PartialTensorShape shape, - tensorflow::DataType dtype) - : tensor_spec({shape, dtype}) {} - ~Data() {} - Float32 f32; - Int64 i64; - String s; - Func func; - // TODO(aselle): look at tensorflow thing - std::shared_ptr dict; - std::shared_ptr list; - std::shared_ptr tuple; - impl::Capsule capsule; - TaggedValueTensor tensor; - TensorSpec tensor_spec; - } data_; - friend std::ostream& operator<<(std::ostream& o, const TaggedValue& v); - friend TaggedValueHash; -}; - -#define TF_ENUM_VALUE_OF(TYPE, ENUM) \ - template <> \ - struct TaggedValue::EnumValueOf { \ - static constexpr Type value = ENUM; \ - }; - -TF_ENUM_VALUE_OF(impl::Capsule, CAPSULE); -TF_ENUM_VALUE_OF(impl::Float32, FLOAT32); -TF_ENUM_VALUE_OF(impl::Int64, INT64); -TF_ENUM_VALUE_OF(impl::List, LIST); -TF_ENUM_VALUE_OF(impl::ListPtr, LIST); -TF_ENUM_VALUE_OF(impl::Tuple, TUPLE); -TF_ENUM_VALUE_OF(impl::TuplePtr, TUPLE); -TF_ENUM_VALUE_OF(impl::Dict, DICT); -TF_ENUM_VALUE_OF(impl::DictPtr, DICT); -TF_ENUM_VALUE_OF(impl::None, NONE); -TF_ENUM_VALUE_OF(impl::Func, FUNC); -TF_ENUM_VALUE_OF(impl::String, STRING); -TF_ENUM_VALUE_OF(impl::TaggedValueTensor, TENSOR); -TF_ENUM_VALUE_OF(impl::TensorSpec, TENSOR_SPEC); -#undef TF_ENUM_VALUE_OF - -#define TF_UNION_ACCESS_INSTANCE(TYPE, MEMBER) \ - template <> \ - struct TaggedValue::UnionAccess { \ - static TYPE& unsafe_reference(TaggedValue& t) { return t.data_.MEMBER; } \ - static const TYPE& unsafe_reference(const TaggedValue& t) { \ - return t.data_.MEMBER; \ - } \ - }; - -TF_UNION_ACCESS_INSTANCE(impl::Capsule, capsule); -TF_UNION_ACCESS_INSTANCE(impl::Float32, f32); -TF_UNION_ACCESS_INSTANCE(impl::Int64, i64); -TF_UNION_ACCESS_INSTANCE(impl::ListPtr, list); -TF_UNION_ACCESS_INSTANCE(impl::TuplePtr, tuple); -TF_UNION_ACCESS_INSTANCE(impl::DictPtr, dict); -TF_UNION_ACCESS_INSTANCE(impl::Func, func); -TF_UNION_ACCESS_INSTANCE(impl::String, s); -TF_UNION_ACCESS_INSTANCE(impl::TaggedValueTensor, tensor); -TF_UNION_ACCESS_INSTANCE(impl::TensorSpec, tensor_spec); -#undef TF_UNION_ACCESS_INSTANCE - -/// The union accessor for `NoneType`. -template <> -struct TaggedValue::UnionAccess { - static impl::None& unsafe_reference(TaggedValue& t) { - return None::GetInstance(); - } - static const impl::None& unsafe_reference(const TaggedValue& t) { - return None::GetInstance(); - } -}; - -/// @brief The Tuple class for holding tuples of TaggedValues. -/// TODO: Need to wrap vector in Tuple otherwise variant has duplicate types. -class Tuple { - using TU = std::vector; - using value_type = TU::value_type; - using iterator = TU::iterator; - using const_iterator = TU::const_iterator; - TU values_; - - public: - TU::iterator begin() { return values_.begin(); } - TU::iterator end() { return values_.end(); } - TU::const_iterator begin() const { return values_.begin(); } - TU::const_iterator end() const { return values_.end(); } - const TU::value_type& operator[](size_t i) const { return values_[i]; } - TU::value_type& operator[](size_t i) { return values_[i]; } - size_t size() const { return values_.size(); } - void emplace_back(TaggedValue v) { values_.emplace_back(std::move(v)); } - void push_back(const TaggedValue& v) { values_.push_back(v); } -}; - -/// Hashing infrastructure for Tuple. -inline size_t TaggedValueHash::operator()(const Tuple& t) const { - std::size_t hash = 0; - for (auto& i : t) { - hash ^= TaggedValueHash()(i); - } - return hash; -} - -/// @brief The TaggedValueHashVisitor class for doing type-based hashing -/// of TaggedValues. -class TaggedValueHashVisitor { - public: - size_t operator()(const TaggedValueTensor& v) { - assert(false); - return 0; - } - size_t operator()(const ListPtr& v) { - assert(false); - return 0; - } - size_t operator()(const DictPtr& v) { - assert(false); - return 0; - } - size_t operator()(const Capsule& t) { return std::hash()(t); } - size_t operator()(const Func& t) { - assert(false); - return 0; - } - size_t operator()(const TuplePtr& t) { - std::size_t hash = 0; - for (auto it = t->begin(); it != t->end(); ++it) { - hash ^= TaggedValueHash()(*it); - } - return hash; - } - template - size_t operator()(const T& t) { - return absl::Hash()(t); - } -}; - -/// Hashing infrastructure for TaggedValues. Hashable TaggedValues overload -/// `AbslHashValue`. Non-hashable structures return 0, since we have no easy -/// way to abort. -inline size_t TaggedValueHash::operator()( - const TaggedValue& v) const { - return v.visit(TaggedValueHashVisitor()); -} - -} // namespace impl -} // namespace libtf -} // namespace tf - -#endif // TENSORFLOW_CC_EXPERIMENTAL_LIBTF_VALUE_H_ diff --git a/tensorflow/cc/experimental/libtf/value_iostream.h b/tensorflow/cc/experimental/libtf/value_iostream.h deleted file mode 100644 index c26ed493890407..00000000000000 --- a/tensorflow/cc/experimental/libtf/value_iostream.h +++ /dev/null @@ -1,93 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef TENSORFLOW_CC_EXPERIMENTAL_LIBTF_VALUE_IOSTREAM_H_ -#define TENSORFLOW_CC_EXPERIMENTAL_LIBTF_VALUE_IOSTREAM_H_ - -#include - -#include "tensorflow/cc/experimental/libtf/value.h" - -namespace tf { -namespace libtf { -namespace impl { - -inline std::ostream& operator<<(std::ostream& o, const Dict& v) { - o << "{"; - for (auto& x : v) { - o << x.first; - o << ": "; - o << x.second; - o << ", "; - } - o << "}"; - return o; -} -template -inline std::ostream& OutList(std::ostream& o, IT v_start, IT const v_end, - char start, char end) { - o << start; - for (IT p = v_start; p != v_end; ++p) { - o << *p; - o << ", "; - } - o << end; - return o; -} - -class TaggedValueIOStreamVisitor { - std::ostream& o_; - - public: - explicit TaggedValueIOStreamVisitor(std::ostream& o) : o_(o) {} - - std::ostream& operator()(const ListPtr& x) { - OutList(o_, x->begin(), x->end(), '[', ']'); - return o_; - } - std::ostream& operator()(const TuplePtr& x) { - OutList(o_, x->begin(), x->end(), '(', ')'); - return o_; - } - std::ostream& operator()(const DictPtr& x) { - o_ << *x; - return o_; - } - std::ostream& operator()(const Capsule& x) { - o_ << "Capsule(" << x.get() << ")"; - return o_; - } - std::ostream& operator()(const Func& x) { - o_ << "Func"; - return o_; - } - std::ostream& operator()(const TaggedValueTensor& x) { - o_ << "Tensor"; - return o_; - } - - template - std::ostream& operator()(const T& x) { - o_ << x; - return o_; - } -}; - -inline std::ostream& operator<<(std::ostream& o, const TaggedValue& v) { - return v.visit(TaggedValueIOStreamVisitor(o)); -} -} // namespace impl -} // namespace libtf -} // namespace tf -#endif // TENSORFLOW_CC_EXPERIMENTAL_LIBTF_VALUE_IOSTREAM_H_ From 9aa7f647e073652e1e91fede9f470c73df748e18 Mon Sep 17 00:00:00 2001 From: Gunhyun Park Date: Wed, 25 Sep 2024 11:13:27 -0700 Subject: [PATCH 269/483] Integrate StableHLO at openxla/stablehlo@ca13d31b PiperOrigin-RevId: 678772867 --- third_party/stablehlo/workspace.bzl | 4 ++-- third_party/xla/third_party/stablehlo/workspace.bzl | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/third_party/stablehlo/workspace.bzl b/third_party/stablehlo/workspace.bzl index 61608344c772f9..2be9b211c51514 100644 --- a/third_party/stablehlo/workspace.bzl +++ b/third_party/stablehlo/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): # LINT.IfChange - STABLEHLO_COMMIT = "9bb28f84c281795783639364b727e4398dcec570" - STABLEHLO_SHA256 = "44f90a3b6e8c7fba454644a7457b71327f643196ee1b1f69a3b210604ea8b5a9" + STABLEHLO_COMMIT = "ca13d31b5ed0b2053dde0a624480ad765e219ebf" + STABLEHLO_SHA256 = "123462093f087f2576bb6a6cc471370eed2d43c291f881ff359fd4ca812003db" # LINT.ThenChange(Google-internal path) tf_http_archive( diff --git a/third_party/xla/third_party/stablehlo/workspace.bzl b/third_party/xla/third_party/stablehlo/workspace.bzl index 61608344c772f9..2be9b211c51514 100644 --- a/third_party/xla/third_party/stablehlo/workspace.bzl +++ b/third_party/xla/third_party/stablehlo/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): # LINT.IfChange - STABLEHLO_COMMIT = "9bb28f84c281795783639364b727e4398dcec570" - STABLEHLO_SHA256 = "44f90a3b6e8c7fba454644a7457b71327f643196ee1b1f69a3b210604ea8b5a9" + STABLEHLO_COMMIT = "ca13d31b5ed0b2053dde0a624480ad765e219ebf" + STABLEHLO_SHA256 = "123462093f087f2576bb6a6cc471370eed2d43c291f881ff359fd4ca812003db" # LINT.ThenChange(Google-internal path) tf_http_archive( From 2296231a50fe59ee6556bef29db5b3188de63a85 Mon Sep 17 00:00:00 2001 From: Raman Sarokin Date: Wed, 25 Sep 2024 11:15:30 -0700 Subject: [PATCH 270/483] Added A18/A18Pro/M2 and later info to gpu_info. PiperOrigin-RevId: 678773838 --- .../lite/delegates/gpu/common/gpu_info.cc | 120 ++++++++++++------ .../lite/delegates/gpu/common/gpu_info.h | 12 ++ 2 files changed, 90 insertions(+), 42 deletions(-) diff --git a/tensorflow/lite/delegates/gpu/common/gpu_info.cc b/tensorflow/lite/delegates/gpu/common/gpu_info.cc index 2627adda13c6bd..0b7bc03136c842 100644 --- a/tensorflow/lite/delegates/gpu/common/gpu_info.cc +++ b/tensorflow/lite/delegates/gpu/common/gpu_info.cc @@ -386,37 +386,46 @@ int AdrenoInfo::GetComputeUnitsCount() const { } AppleInfo::AppleInfo(const std::string& gpu_description) { - const std::map kMapping = { - {"apple a7 gpu", AppleGpu::kA7}, - {"apple a8 gpu", AppleGpu::kA8}, - {"apple a8x gpu", AppleGpu::kA8X}, - {"apple a9 gpu", AppleGpu::kA9}, - {"apple a9x gpu", AppleGpu::kA9X}, - {"apple a10 gpu", AppleGpu::kA10}, - {"apple a10x gpu", AppleGpu::kA10X}, - {"apple a11 gpu", AppleGpu::kA11}, - {"apple a12 gpu", AppleGpu::kA12}, - {"apple a12x gpu", AppleGpu::kA12X}, - {"apple a12z gpu", AppleGpu::kA12Z}, - {"apple a13 gpu", AppleGpu::kA13}, - {"apple a14 gpu", AppleGpu::kA14}, - {"apple a15 gpu", AppleGpu::kA15}, - {"apple a16 gpu", AppleGpu::kA16}, - {"apple a17 pro gpu", AppleGpu::kA17Pro}, - // on tablets we have metal device name "apple m1 gpu" - // and on notebooks "apple m1" - {"apple m1 gpu", AppleGpu::kM1}, + const std::vector> kMapping = { + {"apple a7", AppleGpu::kA7}, + {"apple a8", AppleGpu::kA8}, + {"apple a8x", AppleGpu::kA8X}, + {"apple a9", AppleGpu::kA9}, + {"apple a9x", AppleGpu::kA9X}, + {"apple a10", AppleGpu::kA10}, + {"apple a10x", AppleGpu::kA10X}, + {"apple a11", AppleGpu::kA11}, + {"apple a12", AppleGpu::kA12}, + {"apple a12x", AppleGpu::kA12X}, + {"apple a12z", AppleGpu::kA12Z}, + {"apple a13", AppleGpu::kA13}, + {"apple a14", AppleGpu::kA14}, + {"apple a15", AppleGpu::kA15}, + {"apple a16", AppleGpu::kA16}, + {"apple a17 pro", AppleGpu::kA17Pro}, + {"apple a18", AppleGpu::kA18}, + {"apple a18 pro", AppleGpu::kA18Pro}, {"apple m1", AppleGpu::kM1}, {"apple m1 pro", AppleGpu::kM1Pro}, {"apple m1 max", AppleGpu::kM1Max}, {"apple m1 ultra", AppleGpu::kM1Ultra}, {"apple m2", AppleGpu::kM2}, + {"apple m2 pro", AppleGpu::kM2Pro}, + {"apple m2 max", AppleGpu::kM2Max}, + {"apple m2 ultra", AppleGpu::kM2Ultra}, + {"apple m3", AppleGpu::kM3}, + {"apple m3 pro", AppleGpu::kM3Pro}, + {"apple m3 max", AppleGpu::kM3Max}, + {"apple m4", AppleGpu::kM4}, }; - auto it = kMapping.find(gpu_description); - if (it != kMapping.end()) { - gpu_type = it->second; - } else { - gpu_type = AppleGpu::kUnknown; + gpu_type = AppleGpu::kUnknown; + std::string gpu_name = ""; + for (const auto& v : kMapping) { + if (gpu_description.find(v.first) != std::string::npos && + v.first.size() > gpu_name.size()) { + gpu_name = v.first; + gpu_type = v.second; + } } gpu_family = GetGpuFamily(); } @@ -439,9 +448,10 @@ AppleInfo::Family AppleInfo::GetGpuFamily() const { } else if (gpu_type == AppleGpu::kA14 || IsM1Series()) { return AppleInfo::Family::kApple7; } else if (gpu_type == AppleGpu::kA15 || gpu_type == AppleGpu::kA16 || - gpu_type == AppleGpu::kM2) { + IsM2Series()) { return AppleInfo::Family::kApple8; - } else if (gpu_type == AppleGpu::kA17Pro) { + } else if (gpu_type == AppleGpu::kA17Pro || gpu_type == AppleGpu::kA18 || + gpu_type == AppleGpu::kA18Pro || IsM3Series() || IsM4Series()) { return AppleInfo::Family::kApple9; } return AppleInfo::Family::kApple1; @@ -496,27 +506,28 @@ bool AppleInfo::IsM1Series() const { gpu_type == AppleGpu::kM1Max || gpu_type == AppleGpu::kM1Ultra; } +bool AppleInfo::IsM2Series() const { + return gpu_type == AppleGpu::kM2 || gpu_type == AppleGpu::kM2Pro || + gpu_type == AppleGpu::kM2Max || gpu_type == AppleGpu::kM2Ultra; +} + +bool AppleInfo::IsM3Series() const { + return gpu_type == AppleGpu::kM3 || gpu_type == AppleGpu::kM3Pro || + gpu_type == AppleGpu::kM3Max; +} + +bool AppleInfo::IsM4Series() const { return gpu_type == AppleGpu::kM4; } + bool AppleInfo::IsBionic() const { - return gpu_type == AppleGpu::kA11 || gpu_type == AppleGpu::kA12 || - gpu_type == AppleGpu::kA12X || gpu_type == AppleGpu::kA12Z || - gpu_type == AppleGpu::kA13 || gpu_type == AppleGpu::kA14 || - gpu_type == AppleGpu::kA15 || gpu_type == AppleGpu::kA16 || - gpu_type == AppleGpu::kA17Pro || gpu_type == AppleGpu::kM1 || - gpu_type == AppleGpu::kM1Pro || gpu_type == AppleGpu::kM1Max || - gpu_type == AppleGpu::kM1Ultra || gpu_type == AppleGpu::kM2; + return gpu_family >= AppleInfo::Family::kApple4; } bool AppleInfo::IsSIMDMatMulSupported() const { - return gpu_type == AppleGpu::kA14 || gpu_type == AppleGpu::kA15 || - gpu_type == AppleGpu::kA16 || gpu_type == AppleGpu::kA17Pro || - gpu_type == AppleGpu::kM1 || gpu_type == AppleGpu::kM1Pro || - gpu_type == AppleGpu::kM1Max || gpu_type == AppleGpu::kM1Ultra || - gpu_type == AppleGpu::kM2; + return gpu_family >= AppleInfo::Family::kApple7; } bool AppleInfo::IsSIMDMatMulFp32Perf2x() const { - return gpu_type == AppleGpu::kA15 || gpu_type == AppleGpu::kA16 || - gpu_type == AppleGpu::kA17Pro || gpu_type == AppleGpu::kM2; + return gpu_family >= AppleInfo::Family::kApple8 || IsM1Series(); } bool AppleInfo::IsRoundToNearestSupported() const { return IsBionic(); } @@ -560,6 +571,10 @@ int AppleInfo::GetComputeUnitsCount() const { return 5; case AppleGpu::kA17Pro: return 6; + case AppleGpu::kA18: + return 5; + case AppleGpu::kA18Pro: + return 6; case AppleGpu::kM1: // approximate, can be 7 or 8 return 8; @@ -573,7 +588,28 @@ int AppleInfo::GetComputeUnitsCount() const { // approximate, 64 is max possible return 64; case AppleGpu::kM2: - // approximate, 10 is max possible + // approximate + return 10; + case AppleGpu::kM2Pro: + // approximate + return 19; + case AppleGpu::kM2Max: + // approximate + return 38; + case AppleGpu::kM2Ultra: + // approximate + return 76; + case AppleGpu::kM3: + // approximate + return 10; + case AppleGpu::kM3Pro: + // approximate + return 18; + case AppleGpu::kM3Max: + // approximate + return 40; + case AppleGpu::kM4: + // approximate return 10; case AppleGpu::kUnknown: return 4; diff --git a/tensorflow/lite/delegates/gpu/common/gpu_info.h b/tensorflow/lite/delegates/gpu/common/gpu_info.h index f5d73a2f341e28..c1b4eb6454e4a1 100644 --- a/tensorflow/lite/delegates/gpu/common/gpu_info.h +++ b/tensorflow/lite/delegates/gpu/common/gpu_info.h @@ -175,11 +175,20 @@ enum class AppleGpu { kA15, kA16, kA17Pro, + kA18, + kA18Pro, kM1, kM1Pro, kM1Max, kM1Ultra, kM2, + kM2Pro, + kM2Max, + kM2Ultra, + kM3, + kM3Pro, + kM3Max, + kM4, }; struct AppleInfo { @@ -216,6 +225,9 @@ struct AppleInfo { bool IsBionic() const; bool IsM1Series() const; + bool IsM2Series() const; + bool IsM3Series() const; + bool IsM4Series() const; bool IsSIMDMatMulSupported() const; // Often, fp32 alu performance is 1/2 of fp16 alu performance From e64cf7c50d5e74023f18297b46e997c47f877fda Mon Sep 17 00:00:00 2001 From: Eric Salo Date: Wed, 25 Sep 2024 11:30:14 -0700 Subject: [PATCH 271/483] cleanup: remove api_version from BUILD files PiperOrigin-RevId: 678781652 --- tensorflow/python/keras/protobuf/BUILD | 3 --- 1 file changed, 3 deletions(-) diff --git a/tensorflow/python/keras/protobuf/BUILD b/tensorflow/python/keras/protobuf/BUILD index 73db8516220ffb..17825a4aeefe1d 100644 --- a/tensorflow/python/keras/protobuf/BUILD +++ b/tensorflow/python/keras/protobuf/BUILD @@ -31,19 +31,16 @@ tf_proto_library( # copybara:uncomment_begin(google-only) # py_proto_library( # name = "saved_metadata_proto_py_pb2", -# api_version = 2, # deps = [":saved_metadata_proto"], # ) # # py_proto_library( # name = "projector_config_proto_py_pb2", -# api_version = 2, # deps = [":projector_config_proto"], # ) # # py_proto_library( # name = "versions_proto_py_pb2", -# api_version = 2, # deps = [":versions_proto"], # ) # copybara:uncomment_end From ba594ccda597637190c0596dd4a737ac05cc4930 Mon Sep 17 00:00:00 2001 From: Gunhyun Park Date: Wed, 25 Sep 2024 11:30:44 -0700 Subject: [PATCH 272/483] Remove unnecessary namespace qualifiers. PiperOrigin-RevId: 678781900 --- .../hlo_to_mhlo/module_attributes_importer.cc | 33 +++++++++---------- 1 file changed, 16 insertions(+), 17 deletions(-) diff --git a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/module_attributes_importer.cc b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/module_attributes_importer.cc index 02eb02fd869d8a..fd363f441bb617 100644 --- a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/module_attributes_importer.cc +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/module_attributes_importer.cc @@ -64,24 +64,23 @@ constexpr char kSpmdParametersShardings[] = "mhlo.spmd_parameters_shardings"; constexpr char kUseAutoSpmdPartitioning[] = "mhlo.use_auto_spmd_partitioning"; mlir::ArrayAttr ConvertCrossProgramPrefetches( - const absl::Span prefetches, - const xla::HloComputation& entryComputation, mlir::Builder* builder, + const absl::Span prefetches, + const HloComputation& entryComputation, mlir::Builder* builder, bool flatten_computation_args_result) { llvm::SmallVector shapes; shapes.reserve(prefetches.size()); if (flatten_computation_args_result) { - llvm::SmallVector> + llvm::SmallVector> original_param_index_to_flattened_arg_index; int64_t arg_index = 0; - for (xla::HloInstruction* param_instruction : + for (HloInstruction* param_instruction : entryComputation.parameter_instructions()) { auto& param_map = original_param_index_to_flattened_arg_index.emplace_back(); - xla::ShapeUtil::ForEachLeafShape( - param_instruction->shape(), - [&](const xla::Shape&, const xla::ShapeIndex& index) { - param_map[index] = arg_index++; - }); + ShapeUtil::ForEachLeafShape(param_instruction->shape(), + [&](const Shape&, const ShapeIndex& index) { + param_map[index] = arg_index++; + }); } for (const auto& [parameter, index, alt_memory_offset] : prefetches) shapes.push_back(mlir::mhlo::CrossProgramPrefetchAttr::get( @@ -100,7 +99,7 @@ mlir::ArrayAttr ConvertCrossProgramPrefetches( } void ImportEntryComputationParameterLayoutAndTiles( - const xla::HloModule& hlo_module, mlir::ModuleOp module, + const HloModule& hlo_module, mlir::ModuleOp module, const ComputationLayout& computation_layout, mlir::Builder builder) { llvm::SmallVector parameter_layouts; llvm::SmallVector parameter_tiles; @@ -132,7 +131,7 @@ void ImportEntryComputationParameterLayoutAndTiles( } void ImportEntryComputationResultLayoutAndTiles( - const xla::HloModule& hlo_module, mlir::ModuleOp module, + const HloModule& hlo_module, mlir::ModuleOp module, const ComputationLayout& computation_layout, mlir::Builder builder) { if (computation_layout.result_layout().shape().IsTuple()) { llvm::SmallVector result_layouts; @@ -162,7 +161,7 @@ void ImportEntryComputationResultLayoutAndTiles( } // namespace -void ImportCrossProgramPrefetches(const xla::HloModule& hlo_module, +void ImportCrossProgramPrefetches(const HloModule& hlo_module, mlir::ModuleOp module, bool flatten_computation_args_result, mlir::Builder builder) { @@ -173,7 +172,7 @@ void ImportCrossProgramPrefetches(const xla::HloModule& hlo_module, flatten_computation_args_result)); } -void ImportEntryComputationLayoutAndTiles(const xla::HloModule& hlo_module, +void ImportEntryComputationLayoutAndTiles(const HloModule& hlo_module, mlir::ModuleOp module, mlir::Builder builder) { const auto& computation_layout = hlo_module.entry_computation_layout(); @@ -195,7 +194,7 @@ void ImportEntryComputationLayoutAndTiles(const xla::HloModule& hlo_module, } } -void ImportFrontendAttributes(const xla::HloModule& hlo_module, +void ImportFrontendAttributes(const HloModule& hlo_module, mlir::ModuleOp module, mlir::Builder builder) { if (!hlo_module.frontend_attributes().map().empty()) { llvm::SmallVector frontend_attributes; @@ -230,7 +229,7 @@ void ImportNumPartitions(const xla::HloModule& hlo_module, } } -void ImportNumReplicas(const xla::HloModule& hlo_module, mlir::ModuleOp module, +void ImportNumReplicas(const HloModule& hlo_module, mlir::ModuleOp module, mlir::Builder builder) { const auto& config = hlo_module.config(); if (config.replica_count() != 1) { @@ -247,7 +246,7 @@ void ImportSpmdOutputSharding(const xla::HloModule& hlo_module, ConvertSharding(hlo_module.spmd_output_sharding(), &builder)); } -void ImportSpmdParametersShardings(const xla::HloModule& hlo_module, +void ImportSpmdParametersShardings(const HloModule& hlo_module, mlir::ModuleOp module, bool flatten_computation_args_result, mlir::Builder builder) { @@ -266,7 +265,7 @@ void ImportSpmdParametersShardings(const xla::HloModule& hlo_module, } } -void ImportUseAutoSpmdPartitioning(const xla::HloModule& hlo_module, +void ImportUseAutoSpmdPartitioning(const HloModule& hlo_module, mlir::ModuleOp module, mlir::Builder builder) { module->setAttr(kUseAutoSpmdPartitioning, From fe5c1b67cfb1063227e9dc089ab89f34353998ea Mon Sep 17 00:00:00 2001 From: Grant Jensen Date: Wed, 25 Sep 2024 11:53:32 -0700 Subject: [PATCH 273/483] [tflite-gpu] Add cbrt to gpu_compatibility PiperOrigin-RevId: 678790686 --- .../lite/tools/versioning/gpu_compatibility.cc | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/tensorflow/lite/tools/versioning/gpu_compatibility.cc b/tensorflow/lite/tools/versioning/gpu_compatibility.cc index a71042ad7f32f0..4b1ec4299dd821 100644 --- a/tensorflow/lite/tools/versioning/gpu_compatibility.cc +++ b/tensorflow/lite/tools/versioning/gpu_compatibility.cc @@ -1152,6 +1152,18 @@ absl::Status CheckGpuDelegateCompatibility(const OpSignature& op_sig, "Require size(indices) = rank(operand)"); } return absl::OkStatus(); + case kTfLiteBuiltinStablehloCbrt: + if (op_sig.inputs[0].type != kTfLiteFloat16 && + op_sig.inputs[0].type != kTfLiteFloat32 && + op_sig.inputs[0].type != kTfLiteBFloat16) { + return absl::InvalidArgumentError("Only support float inputs"); + } + if (op_sig.inputs[0].type != op_sig.outputs[0].type) { + return absl::InvalidArgumentError("Input and output types must match"); + } + return CheckInputsConstsOutputs(op_sig, /*required_runtime_inputs=*/1, + /*required_const_inputs=*/0, + /*required_outputs=*/1); case kTfLiteBuiltinStablehloClamp: if ((op_sig.inputs.at(0).type != op_sig.inputs.at(1).type) || (op_sig.inputs.at(1).type != op_sig.inputs.at(2).type)) { From 36b37ed6fa71ec440a685711242e362e0f33ad96 Mon Sep 17 00:00:00 2001 From: Arian Arfaian Date: Wed, 25 Sep 2024 12:03:35 -0700 Subject: [PATCH 274/483] Reverts 45550c14cfb765a1a3d737abde366723025422c0 PiperOrigin-RevId: 678794783 --- .../lite/python/_pywrap_converter_api.pyi | 2 +- .../mlir/lite/python/converter_python_api.cc | 3 +- .../mlir/lite/python/converter_python_api.h | 7 +- .../python/converter_python_api_wrapper.cc | 20 +- .../mlir/lite/python/wrap_converter.py | 8 +- tensorflow/lite/python/BUILD | 1 - tensorflow/lite/python/convert.py | 218 +++--------------- tensorflow/lite/python/convert_test.py | 31 --- tensorflow/lite/python/lite.py | 1 - tensorflow/lite/python/lite_constants.py | 7 - .../metrics/metrics_nonportable_test.py | 2 - 11 files changed, 45 insertions(+), 255 deletions(-) diff --git a/tensorflow/compiler/mlir/lite/python/_pywrap_converter_api.pyi b/tensorflow/compiler/mlir/lite/python/_pywrap_converter_api.pyi index cdb1e881b7dc9f..989d4f1dbe56fb 100644 --- a/tensorflow/compiler/mlir/lite/python/_pywrap_converter_api.pyi +++ b/tensorflow/compiler/mlir/lite/python/_pywrap_converter_api.pyi @@ -13,7 +13,7 @@ # limitations under the License. # ============================================================================== -def Convert(model_flags_proto_txt_raw: object, toco_flags_proto_txt_raw: object, input_contents_txt_raw: object, extended_return: bool = ..., debug_info_txt_raw: object = ..., enable_mlir_converter: bool = ..., quantization_py_function_library = ...) -> object: ... +def Convert(model_flags_proto_txt_raw: object, converter_flags_proto_txt_raw: object, input_contents_txt_raw: object, extended_return: bool = ..., debug_info_txt_raw: object = ..., quantization_py_function_library = ...) -> object: ... def ExperimentalMlirQuantizeModel(input_contents_txt_raw: object, disable_per_channel: bool = ..., fully_quantize: bool = ..., inference_type: int = ..., input_data_type: int = ..., output_data_type: int = ..., enable_numeric_verify: bool = ..., enable_whole_model_verify: bool = ..., op_blocklist: object = ..., node_blocklist: object = ..., enable_variable_quantization: bool = ..., disable_per_channel_for_dense_layers: bool = ..., debug_options_proto_txt_raw: object = ...) -> object: ... def ExperimentalMlirSparsifyModel(input_contents_txt_raw: object) -> object: ... def FlatBufferToMlir(arg0: str, arg1: bool) -> str: ... diff --git a/tensorflow/compiler/mlir/lite/python/converter_python_api.cc b/tensorflow/compiler/mlir/lite/python/converter_python_api.cc index 1b0250fd535e9a..c7059d721a062f 100644 --- a/tensorflow/compiler/mlir/lite/python/converter_python_api.cc +++ b/tensorflow/compiler/mlir/lite/python/converter_python_api.cc @@ -17,7 +17,6 @@ limitations under the License. #include #include -#include #include #include #include @@ -61,7 +60,7 @@ namespace tflite { PyObject* Convert(PyObject* model_flags_proto_txt_raw, PyObject* converter_flags_proto_txt_raw, PyObject* input_contents_txt_raw, bool extended_return, - PyObject* debug_info_txt_raw, bool enable_mlir_converter, + PyObject* debug_info_txt_raw, const tensorflow::quantization::PyFunctionLibrary* quantization_py_function_library) { // Use Python C API to validate and convert arguments. In py3 (bytes), diff --git a/tensorflow/compiler/mlir/lite/python/converter_python_api.h b/tensorflow/compiler/mlir/lite/python/converter_python_api.h index 6dbcf0603d7e8c..cfcba696d01b7a 100644 --- a/tensorflow/compiler/mlir/lite/python/converter_python_api.h +++ b/tensorflow/compiler/mlir/lite/python/converter_python_api.h @@ -30,15 +30,12 @@ namespace tflite { // representing the contents of the converted model. When extended_return // flag is set to true returns a dictionary that contains string representation // of the converted model and some statistics like arithmetic ops count. -// `debug_info_str` contains the `GraphDebugInfo` proto. When -// `enable_mlir_converter` is True, use MLIR-based conversion instead of -// TOCO conversion. +// `debug_info_str` contains the `GraphDebugInfo` proto. PyObject* Convert(PyObject* model_flags_proto_txt_raw, - PyObject* toco_flags_proto_txt_raw, + PyObject* converter_flags_proto_txt_raw, PyObject* input_contents_txt_raw, bool extended_return = false, PyObject* debug_info_txt_raw = nullptr, - bool enable_mlir_converter = false, const tensorflow::quantization::PyFunctionLibrary* quantization_py_function_library = nullptr); diff --git a/tensorflow/compiler/mlir/lite/python/converter_python_api_wrapper.cc b/tensorflow/compiler/mlir/lite/python/converter_python_api_wrapper.cc index de46a6f9115339..83e3da9e540bcf 100644 --- a/tensorflow/compiler/mlir/lite/python/converter_python_api_wrapper.cc +++ b/tensorflow/compiler/mlir/lite/python/converter_python_api_wrapper.cc @@ -27,21 +27,21 @@ PYBIND11_MODULE(_pywrap_converter_api, m) { m.def( "Convert", [](py::object model_flags_proto_txt_raw, - py::object toco_flags_proto_txt_raw, py::object input_contents_txt_raw, - bool extended_return, py::object debug_info_txt_raw, - bool enable_mlir_converter, + py::object converter_flags_proto_txt_raw, + py::object input_contents_txt_raw, bool extended_return, + py::object debug_info_txt_raw, const tensorflow::quantization::PyFunctionLibrary* quantization_py_function_library) { return tensorflow::PyoOrThrow(tflite::Convert( - model_flags_proto_txt_raw.ptr(), toco_flags_proto_txt_raw.ptr(), - input_contents_txt_raw.ptr(), extended_return, - debug_info_txt_raw.ptr(), enable_mlir_converter, + model_flags_proto_txt_raw.ptr(), + converter_flags_proto_txt_raw.ptr(), input_contents_txt_raw.ptr(), + extended_return, debug_info_txt_raw.ptr(), quantization_py_function_library)); }, - py::arg("model_flags_proto_txt_raw"), py::arg("toco_flags_proto_txt_raw"), + py::arg("model_flags_proto_txt_raw"), + py::arg("converter_flags_proto_txt_raw"), py::arg("input_contents_txt_raw"), py::arg("extended_return") = false, py::arg("debug_info_txt_raw") = py::none(), - py::arg("enable_mlir_converter") = false, py::arg("quantization_py_function_library") = py::none(), R"pbdoc( Convert a model represented in `input_contents`. `model_flags_proto` @@ -50,9 +50,7 @@ PYBIND11_MODULE(_pywrap_converter_api, m) { representing the contents of the converted model. When extended_return flag is set to true returns a dictionary that contains string representation of the converted model and some statistics like arithmetic ops count. - `debug_info_str` contains the `GraphDebugInfo` proto. When - `enable_mlir_converter` is True, tuse MLIR-based conversion instead of - TOCO conversion. + `debug_info_str` contains the `GraphDebugInfo` proto. )pbdoc"); m.def( "ExperimentalMlirQuantizeModel", diff --git a/tensorflow/compiler/mlir/lite/python/wrap_converter.py b/tensorflow/compiler/mlir/lite/python/wrap_converter.py index 1c198f062388fc..ee3c5f2435fd9d 100644 --- a/tensorflow/compiler/mlir/lite/python/wrap_converter.py +++ b/tensorflow/compiler/mlir/lite/python/wrap_converter.py @@ -22,19 +22,17 @@ def wrapped_convert( model_flags_str, - toco_flags_str, + converter_flags_str, input_data_str, debug_info_str, - enable_mlir_converter, ): - """Wraps TocoConvert with lazy loader.""" + """Wraps Convert with lazy loader.""" return _pywrap_converter_api.Convert( model_flags_str, - toco_flags_str, + converter_flags_str, input_data_str, False, # extended_return debug_info_str, - enable_mlir_converter, py_function_lib.PyFunctionLibrary(), ) diff --git a/tensorflow/lite/python/BUILD b/tensorflow/lite/python/BUILD index c5a4ba27639458..d6f626a28594e2 100644 --- a/tensorflow/lite/python/BUILD +++ b/tensorflow/lite/python/BUILD @@ -490,7 +490,6 @@ pytype_strict_library( "//tensorflow/lite/tools:flatbuffer_utils", "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:tensor_shape", - "//tensorflow/python/platform:resource_loader", "//tensorflow/python/util:deprecation", "//tensorflow/python/util:tf_export", ], diff --git a/tensorflow/lite/python/convert.py b/tensorflow/lite/python/convert.py index f4206cb68c932d..96b17f996e45b5 100644 --- a/tensorflow/lite/python/convert.py +++ b/tensorflow/lite/python/convert.py @@ -14,13 +14,8 @@ # ============================================================================== """Converts a frozen graph into a TFLite FlatBuffer.""" -import distutils.spawn import enum import hashlib -import os as _os -import platform as _platform -import subprocess as _subprocess -import tempfile as _tempfile from typing import Optional import warnings @@ -41,7 +36,6 @@ from tensorflow.lite.tools import flatbuffer_utils from tensorflow.python.framework import dtypes from tensorflow.python.framework import tensor_shape -from tensorflow.python.platform import resource_loader as _resource_loader from tensorflow.python.util import deprecation from tensorflow.python.util.tf_export import tf_export as _tf_export @@ -143,18 +137,6 @@ def convert_inference_tf_type_to_tflite_type( return tflite_type -# Find the deprecated conversion binary using the resource loader if using from -# bazel, otherwise we are in a pip where console_scripts already has the tool. -if lite_constants.EXPERIMENTAL_USE_TOCO_API_DIRECTLY: - _deprecated_conversion_binary = "" -else: - _deprecated_conversion_binary = _resource_loader.get_path_to_datafile( - "../toco/python/toco_from_protos" - ) - if not _os.path.exists(_deprecated_conversion_binary): - _deprecated_conversion_binary = "toco_from_protos" - - def _try_convert_to_unicode(output): if output is None: return "" @@ -315,7 +297,6 @@ def convert( conversion_flags: _conversion_flags_pb2.ConverterFlags, input_data_str: Optional[str] = None, debug_info_str: Optional[str] = None, - enable_mlir_converter: bool = True, ): """Converts `input_data_str` to a TFLite model. @@ -327,178 +308,44 @@ def convert( it can be hlo text or proto) debug_info_str: Serialized `GraphDebugInfo` proto describing logging information. - enable_mlir_converter: Enables MLIR-based conversion. Returns: Converted model in serialized form (e.g. a TFLITE model is common). Raises: ConverterError: When conversion fails in TFLiteConverter, usually due to ops not being supported. - RuntimeError: When conversion fails, an exception is raised with the error - message embedded. """ - # Historically, deprecated conversion failures would trigger a crash, so we - # attempt to run the converter out-of-process. The current MLIR conversion - # pipeline surfaces errors instead, and can be safely run in-process. - if enable_mlir_converter or not _deprecated_conversion_binary: - try: - return wrap_converter.wrapped_convert( - model_flags.SerializeToString(), - conversion_flags.SerializeToString(), - input_data_str, - debug_info_str, - enable_mlir_converter, - ) - except Exception as e: - converter_error = ConverterError(str(e)) - - for error_data in _metrics_wrapper.retrieve_collected_errors(): - converter_error.append_error(error_data) - # Seldom we encounter the case where an unsupported - # `StatefulPartitionedCallOp` is not inlined and remains in the final - # IR. If this occurs we can set `guarantee_all_funcs_one_use` and retry. - # This makes the converter copy functions definitions called by - # multiple StatefulPartitionedCall, thus allowing them to be properly - # inlined. - if ( - error_data.error_code - == converter_error_data_pb2.ConverterErrorData.ERROR_STATEFUL_PARTITIONED_CALL_IN_FINAL_IR - and not conversion_flags.guarantee_all_funcs_one_use - ): - conversion_flags.guarantee_all_funcs_one_use = True - return convert( - model_flags, - conversion_flags, - input_data_str, - debug_info_str, - enable_mlir_converter, - ) - raise converter_error - - return _run_deprecated_conversion_binary( - model_flags.SerializeToString(), - conversion_flags.SerializeToString(), - input_data_str, - debug_info_str, - ) - - -@convert_phase( - Component.CONVERT_TF_TO_TFLITE_MODEL, - SubComponent.CONVERT_GRAPHDEF_USING_DEPRECATED_CONVERTER, -) -def _run_deprecated_conversion_binary( - model_flags_str, conversion_flags_str, input_data_str, debug_info_str=None -): - """Convert `input_data_str` using deprecated conversion binary. - - Args: - model_flags_str: Serialized proto describing model properties, see - `model_flags.proto`. - conversion_flags_str: Serialized proto describing TFLite converter - properties, see `compiler/mlir/lite/converter_flags.proto`. - input_data_str: Input data in serialized form (e.g. a graphdef is common) - debug_info_str: Serialized `GraphDebugInfo` proto describing logging - information. (default None) - - Returns: - Converted model in serialized form (e.g. a TFLITE model is common). - Raises: - ConverterError: When cannot find the deprecated conversion binary. - RuntimeError: When conversion fails, an exception is raised with the error - message embedded. - """ - if distutils.spawn.find_executable(_deprecated_conversion_binary) is None: - raise ConverterError("""Could not find `toco_from_protos` binary, make sure -your virtualenv bin directory or pip local bin directory is in your path. -In particular, if you have installed TensorFlow with --user, make sure you -add the install directory to your path. - -For example: -Linux: export PATH=$PATH:~/.local/bin/ -Mac: export PATH=$PATH:~/Library/Python//bin - -Alternative, use virtualenv.""") - # Windows and TemporaryFile are not that useful together, - # since you cannot have two readers/writers. So we have to - # make the temporaries and close and delete them explicitly. - conversion_filename: str = None - model_filename: str = None - input_filename: str = None - output_filename: str = None try: - # Build all input files - with ( - _tempfile.NamedTemporaryFile(delete=False) as fp_conversion, - _tempfile.NamedTemporaryFile(delete=False) as fp_model, - _tempfile.NamedTemporaryFile(delete=False) as fp_input, - _tempfile.NamedTemporaryFile(delete=False) as fp_debug, - ): - conversion_filename = fp_conversion.name - input_filename = fp_input.name - model_filename = fp_model.name - debug_filename = fp_debug.name - - fp_model.write(model_flags_str) - fp_conversion.write(conversion_flags_str) - fp_input.write(input_data_str) - debug_info_str = debug_info_str if debug_info_str else "" - # if debug_info_str contains a "string value", then the call to - # fp_debug.write(debug_info_str) will fail with the following error - # - # TypeError: a bytes-like object is required, not 'str' - # - # Some of the subtests within the "convert_test" unit-test fail - # with the error shown above. So watch out for that scenario and - # convert debug_info_str to bytes where needed - if not isinstance(debug_info_str, bytes): - fp_debug.write(debug_info_str.encode("utf-8")) - else: - fp_debug.write(debug_info_str) - - # Reserve an output file - with _tempfile.NamedTemporaryFile(delete=False) as fp: - output_filename = fp.name - - # Run - cmd = [ - _deprecated_conversion_binary, - model_filename, - conversion_filename, - input_filename, - output_filename, - "--debug_proto_file={}".format(debug_filename), - ] - cmdline = " ".join(cmd) - is_windows = _platform.system() == "Windows" - proc = _subprocess.Popen( - cmdline, - shell=True, - stdout=_subprocess.PIPE, - stderr=_subprocess.STDOUT, - close_fds=not is_windows, + return wrap_converter.wrapped_convert( + model_flags.SerializeToString(), + conversion_flags.SerializeToString(), + input_data_str, + debug_info_str, ) - stdout, stderr = proc.communicate() - exitcode = proc.returncode - if exitcode == 0: - with open(output_filename, "rb") as fp: - return fp.read() - else: - stdout = _try_convert_to_unicode(stdout) - stderr = _try_convert_to_unicode(stderr) - raise ConverterError("See console for info.\n%s\n%s\n" % (stdout, stderr)) - finally: - # Must manually cleanup files. - for filename in [ - conversion_filename, - input_filename, - model_filename, - output_filename, - ]: - try: - _os.unlink(filename) - except (OSError, TypeError): - pass + except Exception as e: + converter_error = ConverterError(str(e)) + + for error_data in _metrics_wrapper.retrieve_collected_errors(): + converter_error.append_error(error_data) + # Seldom we encounter the case where an unsupported + # `StatefulPartitionedCallOp` is not inlined and remains in the final + # IR. If this occurs we can set `guarantee_all_funcs_one_use` and retry. + # This makes the converter copy functions definitions called by + # multiple StatefulPartitionedCall, thus allowing them to be properly + # inlined. + if ( + error_data.error_code + == converter_error_data_pb2.ConverterErrorData.ERROR_STATEFUL_PARTITIONED_CALL_IN_FINAL_IR + and not conversion_flags.guarantee_all_funcs_one_use + ): + conversion_flags.guarantee_all_funcs_one_use = True + return convert( + model_flags, + conversion_flags, + input_data_str, + debug_info_str, + ) + raise converter_error def build_model_flags( @@ -909,7 +756,6 @@ def convert_graphdef_with_arrays( """ model_flags = build_model_flags(**kwargs) conversion_flags = build_conversion_flags(**kwargs) - enable_mlir_converter = kwargs.get("enable_mlir_converter", True) quantized_input_stats = kwargs.get("quantized_input_stats", None) for idx, (name, shape) in enumerate(input_arrays_with_shape): @@ -940,7 +786,6 @@ def convert_graphdef_with_arrays( conversion_flags, input_data.SerializeToString(), debug_info_str=None, - enable_mlir_converter=enable_mlir_converter, ) return data @@ -972,7 +817,6 @@ def convert_graphdef(input_data, input_tensors, output_tensors, **kwargs): conversion_flags = build_conversion_flags(**kwargs) saved_model_dir = kwargs.get("saved_model_dir", None) input_shapes = kwargs.get("input_shapes", None) - enable_mlir_converter = kwargs.get("enable_mlir_converter", True) quantized_input_stats = kwargs.get("quantized_input_stats", None) debug_info = kwargs.get("debug_info", None) @@ -1030,7 +874,6 @@ def convert_graphdef(input_data, input_tensors, output_tensors, **kwargs): conversion_flags, input_data.SerializeToString(), debug_info_str=debug_info.SerializeToString() if debug_info else None, - enable_mlir_converter=enable_mlir_converter, ) return data @@ -1047,7 +890,6 @@ def convert_saved_model(**kwargs): conversion_flags, input_data_str=None, debug_info_str=None, - enable_mlir_converter=True, ) return data @@ -1075,7 +917,6 @@ def convert_jax_hlo(input_content, input_names, is_proto_format, **kwargs): conversion_flags, input_data_str=input_content, debug_info_str=None, - enable_mlir_converter=True, ) return data @@ -1103,7 +944,6 @@ def toco_convert(input_data, input_tensors, output_tensors, *args, **kwargs): Raises: Defined in `convert`. """ - kwargs["enable_mlir_converter"] = kwargs.get("enable_mlir_converter", False) return convert_graphdef( input_data, input_tensors, output_tensors, *args, **kwargs ) diff --git a/tensorflow/lite/python/convert_test.py b/tensorflow/lite/python/convert_test.py index 91b3e810328b16..aad267ed4384ea 100644 --- a/tensorflow/lite/python/convert_test.py +++ b/tensorflow/lite/python/convert_test.py @@ -41,7 +41,6 @@ def _mock_wrapped_convert( conversion_flags_str="", unused_input_data_str="", unused_debug_info_str="", - unused_enable_mlir_converter=True, ): # Simulate the converter throwing and error when # `guarantee_all_funcs_one_use` is not set. @@ -76,32 +75,6 @@ def testBasic(self): ) self.assertTrue(tflite_model) - @mock.patch.object( - convert, - "_deprecated_conversion_binary", - new="tocos_from_proto", - ) - @mock.patch.object( - convert, - "_run_deprecated_conversion_binary", - autospec=True, - ) - def testBasicDeprecatedConversionBinary(self, mock_func): - with ops.Graph().as_default(): - in_tensor = array_ops.placeholder( - shape=[1, 16, 16, 3], dtype=dtypes.float32 - ) - out_tensor = in_tensor + in_tensor - sess = session.Session() - - convert.convert_graphdef( - sess.graph_def, - input_tensors=[in_tensor], - output_tensors=[out_tensor], - enable_mlir_converter=False, - ) - mock_func.assert_called_once() - @mock.patch.object( convert.wrap_converter, "wrapped_convert", new=_mock_wrapped_convert ) @@ -125,7 +98,6 @@ def testConversionStatefulPartitionRetry(self, mock_convert): sess.graph_def, input_tensors=[in_tensor], output_tensors=[out_tensor], - enable_mlir_converter=True, guarantee_all_funcs_one_use=False, ) self.assertTrue(str(model, encoding="utf-8"), "A model") @@ -164,7 +136,6 @@ def testGraphDefBasic(self): output_arrays=["add"], control_output_arrays=None, inference_type=dtypes.float32, - enable_mlir_converter=False, ) self.assertTrue(tflite_model) @@ -209,7 +180,6 @@ def testGraphDefQuantization(self): control_output_arrays=None, inference_type=dtypes.uint8, quantized_input_stats=[(0.0, 1.0), (0.0, 1.0)], - enable_mlir_converter=False, ) self.assertTrue(tflite_model) @@ -263,7 +233,6 @@ def testGraphDefQuantizationInvalid(self): output_arrays=["output"], control_output_arrays=None, inference_type=dtypes.uint8, - enable_mlir_converter=False, ) self.assertEqual( "The `quantized_input_stats` flag must be defined when either " diff --git a/tensorflow/lite/python/lite.py b/tensorflow/lite/python/lite.py index 315e0c64848c62..5f0d1e14632f89 100644 --- a/tensorflow/lite/python/lite.py +++ b/tensorflow/lite/python/lite.py @@ -792,7 +792,6 @@ def _get_base_converter_args(self): "allow_custom_ops": self.allow_custom_ops, "debug_info": self._debug_info, "target_ops": self.target_spec.supported_ops, - "enable_mlir_converter": self.experimental_new_converter, "select_user_tf_ops": self.target_spec.experimental_select_user_tf_ops, "supported_backends": self.target_spec.experimental_supported_backends, "unfold_batchmatmul": self.unfold_batchmatmul, diff --git a/tensorflow/lite/python/lite_constants.py b/tensorflow/lite/python/lite_constants.py index 4fc63f79f8c0c5..0d33eaeef14dd1 100644 --- a/tensorflow/lite/python/lite_constants.py +++ b/tensorflow/lite/python/lite_constants.py @@ -60,12 +60,6 @@ _tf_export(v1=["lite.constants.GRAPHVIZ_DOT"]).export_constant( __name__, "GRAPHVIZ_DOT") -# Currently the default mode of operation is to shell to another python process -# to protect against crashes. However, it breaks some dependent targets because -# it forces us to depend on an external py_binary. The experimental API doesn't -# have that drawback. -EXPERIMENTAL_USE_TOCO_API_DIRECTLY = False - _allowed_symbols = [ "FLOAT", @@ -85,6 +79,5 @@ "KERAS", "JAX", "PYTORCH", - "EXPERIMENTAL_USE_TOCO_API_DIRECTLY", ] remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/lite/python/metrics/metrics_nonportable_test.py b/tensorflow/lite/python/metrics/metrics_nonportable_test.py index ff12a822602c8c..288780db89b7d1 100644 --- a/tensorflow/lite/python/metrics/metrics_nonportable_test.py +++ b/tensorflow/lite/python/metrics/metrics_nonportable_test.py @@ -154,7 +154,6 @@ def test_conversion_from_constructor_success(self): mock.call.increase_counter_converter_success(), mock.call.export_metrics(), mock.call.set_converter_param('input_format', '1'), - mock.call.set_converter_param('enable_mlir_converter', 'True'), mock.call.set_converter_param('allow_custom_ops', 'False'), mock.call.set_converter_param('api_version', '1'), ], any_order=True) # pyformat: disable @@ -275,7 +274,6 @@ def test_conversion_from_saved_model(self): mock.call.increase_counter_converter_success(), mock.call.set_converter_latency(2000), mock.call.export_metrics(), - mock.call.set_converter_param('enable_mlir_converter', 'True'), ], any_order=True) # pyformat: disable def disable_converter_counter_metrics(self, tflite_metrics): From 585eef2ec396690e5fb8632e15e1faca82a1f098 Mon Sep 17 00:00:00 2001 From: Arian Arfaian Date: Wed, 25 Sep 2024 12:48:12 -0700 Subject: [PATCH 275/483] Reverts 5d60f7c58fde119b717b1f4d7b23420f21182d15 PiperOrigin-RevId: 678810842 --- tensorflow/compiler/mlir/lite/BUILD | 2 +- tensorflow/compiler/mlir/lite/stablehlo/BUILD | 6 +- ...casting_op.mlir => fold_broadcast_to.mlir} | 46 +------------ ...g_op_pass.cc => fold_broadcast_to_pass.cc} | 66 ++++++------------- .../mlir/lite/stablehlo/transforms/passes.h | 6 +- .../compiler/mlir/lite/tf_tfl_passes.cc | 4 +- 6 files changed, 29 insertions(+), 101 deletions(-) rename tensorflow/compiler/mlir/lite/stablehlo/tests/{fold_broadcasting_op.mlir => fold_broadcast_to.mlir} (50%) rename tensorflow/compiler/mlir/lite/stablehlo/transforms/{fold_broadcasting_op_pass.cc => fold_broadcast_to_pass.cc} (78%) diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD index 3da89f496218a3..707225f8a67556 100644 --- a/tensorflow/compiler/mlir/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/BUILD @@ -1599,7 +1599,7 @@ cc_library( "//tensorflow/compiler/mlir/lite/stablehlo:build_stablehlo_composite", "//tensorflow/compiler/mlir/lite/stablehlo:compose_uniform_quantized_type_pass", "//tensorflow/compiler/mlir/lite/stablehlo:composite_lowering", - "//tensorflow/compiler/mlir/lite/stablehlo:fold_broadcasting_op_pass", # buildcleaner: keep + "//tensorflow/compiler/mlir/lite/stablehlo:fold_broadcast_to_pass", # buildcleaner: keep "//tensorflow/compiler/mlir/lite/stablehlo:legalize_stablehlo_composite_to_tfl_custom", "//tensorflow/compiler/mlir/lite/stablehlo:legalize_tf_xla_call_module_to_stablehlo_pass", "//tensorflow/compiler/mlir/lite/stablehlo:lift_callsite_loc_caller", diff --git a/tensorflow/compiler/mlir/lite/stablehlo/BUILD b/tensorflow/compiler/mlir/lite/stablehlo/BUILD index 3d34c83ab65274..a0c3febeead92f 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/BUILD +++ b/tensorflow/compiler/mlir/lite/stablehlo/BUILD @@ -354,9 +354,9 @@ cc_library( ) cc_library( - name = "fold_broadcasting_op_pass", + name = "fold_broadcast_to_pass", srcs = [ - "transforms/fold_broadcasting_op_pass.cc", + "transforms/fold_broadcast_to_pass.cc", ], hdrs = [ "transforms/passes.h", @@ -1000,7 +1000,7 @@ tf_cc_binary( deps = [ ":compose_uniform_quantized_type_pass", ":fold_broadcast_pass", - ":fold_broadcasting_op_pass", + ":fold_broadcast_to_pass", ":fuse_convolution_pass", ":legalize_stablehlo_composite_to_tfl_custom", ":legalize_stablehlo_custom_call_to_composite", diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/fold_broadcasting_op.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/fold_broadcast_to.mlir similarity index 50% rename from tensorflow/compiler/mlir/lite/stablehlo/tests/fold_broadcasting_op.mlir rename to tensorflow/compiler/mlir/lite/stablehlo/tests/fold_broadcast_to.mlir index f34d239a25ab20..9bbfb3dee9313f 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/fold_broadcasting_op.mlir +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/fold_broadcast_to.mlir @@ -1,4 +1,4 @@ -// RUN: odml-to-stablehlo-opt %s -fold-broadcasting-op-pass -cse -verify-diagnostics | FileCheck %s +// RUN: odml-to-stablehlo-opt %s -fold-broadcast-to-pass -cse -verify-diagnostics | FileCheck %s // CHECK-LABEL: @broadcast_mul0 func.func @broadcast_mul0(%arg0: tensor<5x7xf32>, %arg1: tensor<7xf32>) -> tensor<5x7xf32> { @@ -45,47 +45,3 @@ func.func @broadcast_batchmatmul(%arg0: tensor<5x30x1024xf32>) -> tensor<5x30x81 return %1 : tensor<5x30x8192xf32> // CHECK: %0 = "tfl.batch_matmul"(%arg0, %cst) <{adj_x = false, adj_y = false}> : (tensor<5x30x1024xf32>, tensor<1024x8192xf32>) -> tensor<5x30x8192xf32> } - -// CHECK-LABEL: @dym_broadcast_mul0 -func.func @dym_broadcast_mul0(%arg0: tensor, %arg1: tensor<7xf32>) -> tensor { - %0 = "tfl.shape"(%arg0): (tensor) -> tensor<2xi32> - %1 = "tfl.broadcast_to"(%arg1, %0) : (tensor<7xf32>, tensor<2xi32>) -> tensor - %2 = "tfl.mul"(%arg0, %1) {fused_activation_function = "NONE"} : (tensor, tensor) -> tensor - func.return %2 : tensor - // CHECK: %0 = tfl.mul(%arg0, %arg1) <{fused_activation_function = "NONE"}> : (tensor, tensor<7xf32>) -> tensor -} - -// CHECK-LABEL: @expanding_reshape_mul -func.func @expanding_reshape_mul(%arg0: tensor<5x7xf32>, %arg1: tensor<7xf32>) -> tensor<5x7xf32> { - %cst = mhlo.constant dense<[1, 7]> : tensor<2xi32> - %0 = "tfl.reshape"(%arg1, %cst) : (tensor<7xf32>, tensor<2xi32>) -> tensor<1x7xf32> - %1 = "tfl.mul"(%arg0, %0) {fused_activation_function = "NONE"} : (tensor<5x7xf32>, tensor<1x7xf32>) -> tensor<5x7xf32> - func.return %1 : tensor<5x7xf32> - // CHECK: %0 = tfl.mul(%arg0, %arg1) <{fused_activation_function = "NONE"}> : (tensor<5x7xf32>, tensor<7xf32>) -> tensor<5x7xf32> -} - -// CHECK-LABEL: @squeezing_reshape_mul -func.func @squeezing_reshape_mul(%arg0: tensor<5x7xf32>, %arg1: tensor<1x7xf32>) -> tensor<5x7xf32> { - %cst = mhlo.constant dense<[7]> : tensor<1xi32> - %0 = "tfl.reshape"(%arg1, %cst) : (tensor<1x7xf32>, tensor<1xi32>) -> tensor<7xf32> - %1 = "tfl.mul"(%arg0, %0) {fused_activation_function = "NONE"} : (tensor<5x7xf32>, tensor<7xf32>) -> tensor<5x7xf32> - func.return %1 : tensor<5x7xf32> - // CHECK: %0 = tfl.mul(%arg0, %arg1) <{fused_activation_function = "NONE"}> : (tensor<5x7xf32>, tensor<1x7xf32>) -> tensor<5x7xf32> -} - -// CHECK-LABEL: @expanddims_mul -func.func @expanddims_mul(%arg0: tensor<5x7xf32>, %arg1: tensor<7xf32>) -> tensor<5x7xf32> { - %cst = mhlo.constant dense<1> : tensor - %0 = "tfl.expand_dims"(%arg1, %cst) : (tensor<7xf32>, tensor) -> tensor<1x7xf32> - %1 = "tfl.mul"(%arg0, %0) {fused_activation_function = "NONE"} : (tensor<5x7xf32>, tensor<1x7xf32>) -> tensor<5x7xf32> - func.return %1 : tensor<5x7xf32> - // CHECK: %0 = tfl.mul(%arg0, %arg1) <{fused_activation_function = "NONE"}> : (tensor<5x7xf32>, tensor<7xf32>) -> tensor<5x7xf32> -} - -// CHECK-LABEL: @squeeze_mul -func.func @squeeze_mul(%arg0: tensor<5x7xf32>, %arg1: tensor<1x7xf32>) -> tensor<5x7xf32> { - %0 = "tfl.squeeze"(%arg1) {squeeze_dims = [0]} : (tensor<1x7xf32>) -> tensor<7xf32> - %1 = "tfl.mul"(%arg0, %0) {fused_activation_function = "NONE"} : (tensor<5x7xf32>, tensor<7xf32>) -> tensor<5x7xf32> - func.return %1 : tensor<5x7xf32> - // CHECK: %0 = tfl.mul(%arg0, %arg1) <{fused_activation_function = "NONE"}> : (tensor<5x7xf32>, tensor<1x7xf32>) -> tensor<5x7xf32> -} \ No newline at end of file diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/fold_broadcasting_op_pass.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/fold_broadcast_to_pass.cc similarity index 78% rename from tensorflow/compiler/mlir/lite/stablehlo/transforms/fold_broadcasting_op_pass.cc rename to tensorflow/compiler/mlir/lite/stablehlo/transforms/fold_broadcast_to_pass.cc index f7153e8ecd9564..204102c9e080a5 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/fold_broadcasting_op_pass.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/fold_broadcast_to_pass.cc @@ -59,7 +59,7 @@ class ConvertResultsBroadcastableShapeOp : public RewritePattern { // Determine op with shapes is valid. TODO: @lukeboyer - Move the // `TFL_OperandsHaveSameShapesOrBroadcastableShape` runtime verification trait // into a standard (not runtime verification) trait and change this function to -// use only that interface. Currently there is no way to query derived runtime +// use only that interface. Curently there is no way to query derived runtime // verification traits. bool IsRankSupported(Operation* op) { // These ops have no rank constraints. @@ -75,14 +75,6 @@ bool IsRankSupported(Operation* op) { return llvm::cast(op->getResultTypes()[0]).getRank() <= 4; } -// Returns true when the op may be a broadcasting op. Broadcasting op is not -// limited to TFL::BroadcastToOp, but also other ops that may change the shape -// of a tensor to match the shape of another operand. -bool MayBeBroadcastingOp(Operation* op) { - return op && llvm::isa(op); -} - LogicalResult ConvertResultsBroadcastableShapeOp::matchAndRewrite( Operation* op, PatternRewriter& rewriter) const { if (op->hasTrait()) @@ -101,7 +93,7 @@ LogicalResult ConvertResultsBroadcastableShapeOp::RewriteOp( // Check that the result shape is fully defined. auto result_type = mlir::dyn_cast_or_null(op->getResultTypes().front()); - if (!result_type) return failure(); + if (!result_type || !result_type.hasStaticShape()) return failure(); if (!IsRankSupported(op)) { return failure(); @@ -110,39 +102,19 @@ LogicalResult ConvertResultsBroadcastableShapeOp::RewriteOp( bool changed = false; for (uint64_t i = 0, e = op->getNumOperands(); i < e; ++i) { // Check that the i'th operand is a broadcast. - auto broadcast = op->getOpOperand(i).get().getDefiningOp(); - if (!broadcast || !MayBeBroadcastingOp(broadcast)) { - continue; - } - - auto broadcast_input = broadcast->getOperand(0); - if (!broadcast_input) { - continue; - } + auto broadcast = llvm::dyn_cast_or_null( + op->getOpOperand(i).get().getDefiningOp()); + if (!broadcast) continue; // Check that the operand of the broadcast has fully defined shape. - // Fusing dynamic broadcasting op (non static broadcast_arg_type shape) - // is experimental and theoretically unsafe, because checking equality on - // unknown dimensions in broadcasted shape is not reliable. - // TODO: Full dynamism support with symbolic shape comparisons. - auto broadcast_arg_type = - mlir::cast(broadcast_input.getType()); - if (!broadcast_arg_type) { - continue; - } + auto broadcast_arg_type = mlir::dyn_cast_or_null( + broadcast.getInput().getType()); + if (!broadcast_arg_type || !broadcast_arg_type.hasStaticShape()) continue; // Check that the other argument has fully defined shape. - auto argument = op->getOperand(1 - i); - auto argument_type = mlir::cast(argument.getType()); - // When two operands are both dynamic broadcasting op, it has high chance - // that the model is doing explicitly broadcasting. In this case, removing - // either broadcasting op may result in incorrect output shape in the - // runtime. - // TODO: Full dynamism support with symbolic shape comparisons. - if (!argument_type || (!broadcast_arg_type.hasStaticShape() && - MayBeBroadcastingOp(argument.getDefiningOp()))) { - continue; - } + auto argument_type = mlir::dyn_cast_or_null( + op->getOpOperand(1 - i).get().getType()); + if (!argument_type || !argument_type.hasStaticShape()) continue; // Get the unbroadcasted shapes in the operand order. std::array, 2> operand_shapes; @@ -162,7 +134,7 @@ LogicalResult ConvertResultsBroadcastableShapeOp::RewriteOp( // Update the operand of the op to be the operand of the broadcast. rewriter.modifyOpInPlace( - op, [&]() { op->getOpOperand(i).set(broadcast->getOperand(0)); }); + op, [&]() { op->getOpOperand(i).set(broadcast.getInput()); }); changed = true; } return success(changed); @@ -238,12 +210,12 @@ LogicalResult ConvertResultsBroadcastableBatchMatMulShapeOp::RewriteOp( } // namespace -class FoldBroadcastingOpPass - : public PassWrapper> { +class FoldBroadcastToPass + : public PassWrapper> { public: - StringRef getArgument() const final { return "fold-broadcasting-op-pass"; } + StringRef getArgument() const final { return "fold-broadcast-to-pass"; } StringRef getDescription() const final { - return "Folds TFL broadcasting/shape changing nodes with subsequent ops"; + return "Folds tfl.BroadcastTo nodes with subsequent ops"; } void runOnOperation() override { @@ -261,11 +233,11 @@ class FoldBroadcastingOpPass }; // TODO(weiyiw): Consider having this as canonicalization? -std::unique_ptr> CreateFoldBroadcastingOpPass() { - return std::make_unique(); +std::unique_ptr> CreateFoldBroadcastToPass() { + return std::make_unique(); } -static PassRegistration pass; +static PassRegistration pass; } // namespace odml } // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h b/tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h index defaf17af16a97..331505e2445e87 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h @@ -70,9 +70,9 @@ std::unique_ptr> CreateLegalizeChloToTflPass(); // Rewrites MHLO in preparation for tflite legalization. std::unique_ptr> CreatePrepareHloPass(); -// Folds TFL broadcasting/shape changing nodes with subsequent ops that -// supports implicit broadcasting. -std::unique_ptr> CreateFoldBroadcastingOpPass(); +// Folds tfl.BroadcastTo nodes with subsequent ops that supports implicit +// broadcasting. +std::unique_ptr> CreateFoldBroadcastToPass(); // Adds the HLO to TF rewrite patterns to the specified pattern list. void PopulateLegalizeHloToTfPatterns(RewritePatternSet* patterns, diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc b/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc index c5a4a766bbd23a..22cd44fca3e212 100644 --- a/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc +++ b/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc @@ -295,12 +295,12 @@ void AddPostQuantizationStableHloToTfPasses( // TODO: b/354280588 - Rewrite this pass into a pattern in PrepareHloPass. pass_manager.addPass(mlir::odml::CreateUnfoldSplatConstantPass()); pass_manager.addPass(mlir::odml::CreateLegalizeHloToTfLitePass()); - // Folds TFL broadcasting ops with subsequent ops if they have built in + // Folds tfl.BroadcastTo ops with subsequent ops if they have built in // broadcasting support. This needs to be run immediately after HLO->TFL // legalization, otherwise the newly generated TFL broadcast ops can fold // and materialize the weights. pass_manager.addNestedPass( - mlir::odml::CreateFoldBroadcastingOpPass()); + mlir::odml::CreateFoldBroadcastToPass()); } // folds tf.BroadcastTo ops with subsequent ops if they have built in // broadcasting support. This needs to be run immediately after HLO->TF From d23bdf591e436ee3c15dfe883ed4b53f80285da5 Mon Sep 17 00:00:00 2001 From: Gregory Pataky Date: Wed, 25 Sep 2024 13:01:20 -0700 Subject: [PATCH 276/483] Bifurcate exhaustive test utilities PiperOrigin-RevId: 678815297 --- third_party/xla/xla/tests/exhaustive/BUILD | 122 +-- .../xla/xla/tests/exhaustive/build_defs.bzl | 5 +- .../exhaustive_binary_test_definitions.h | 211 +--- .../exhaustive_binary_test_definitions.inc | 191 ++++ ...nary_test_f16_and_smaller_instantiation.cc | 24 +- ...ary_test_f16_and_smaller_instantiation.inc | 32 + ...xhaustive_binary_test_f32_instantiation.cc | 47 +- ...haustive_binary_test_f32_instantiation.inc | 55 ++ ...xhaustive_binary_test_f64_instantiation.cc | 50 +- ...haustive_binary_test_f64_instantiation.inc | 58 ++ .../exhaustive_binary_test_functions.cc | 907 +++++------------- .../exhaustive/exhaustive_binary_test_ops.inc | 80 ++ .../xla/tests/exhaustive/exhaustive_op_test.h | 74 ++ .../exhaustive/exhaustive_op_test_base.cc | 9 +- .../exhaustive/exhaustive_op_test_base.h | 93 +- .../exhaustive/exhaustive_op_test_utils.cc | 29 + .../exhaustive/exhaustive_op_test_utils.h | 26 +- .../exhaustive_unary_complex_test.cc | 7 +- .../exhaustive_unary_test_definitions.h | 156 +-- .../exhaustive_unary_test_definitions.inc | 140 +++ ...nary_test_f32_and_smaller_instantiation.cc | 27 +- ...ary_test_f32_and_smaller_instantiation.inc | 33 + ...exhaustive_unary_test_f64_instantiation.cc | 30 +- ...xhaustive_unary_test_f64_instantiation.inc | 38 + .../exhaustive_unary_test_functions.cc | 720 ++++---------- .../exhaustive/exhaustive_unary_test_ops.inc | 217 +++++ .../xla/xla/tests/exhaustive/platform.cc | 103 ++ .../xla/xla/tests/exhaustive/platform.h | 77 ++ .../xla/xla/tests/exhaustive/test_op.h | 247 +++++ 29 files changed, 1969 insertions(+), 1839 deletions(-) create mode 100644 third_party/xla/xla/tests/exhaustive/exhaustive_binary_test_definitions.inc create mode 100644 third_party/xla/xla/tests/exhaustive/exhaustive_binary_test_f16_and_smaller_instantiation.inc create mode 100644 third_party/xla/xla/tests/exhaustive/exhaustive_binary_test_f32_instantiation.inc create mode 100644 third_party/xla/xla/tests/exhaustive/exhaustive_binary_test_f64_instantiation.inc create mode 100644 third_party/xla/xla/tests/exhaustive/exhaustive_binary_test_ops.inc create mode 100644 third_party/xla/xla/tests/exhaustive/exhaustive_op_test.h create mode 100644 third_party/xla/xla/tests/exhaustive/exhaustive_unary_test_definitions.inc create mode 100644 third_party/xla/xla/tests/exhaustive/exhaustive_unary_test_f32_and_smaller_instantiation.inc create mode 100644 third_party/xla/xla/tests/exhaustive/exhaustive_unary_test_f64_instantiation.inc create mode 100644 third_party/xla/xla/tests/exhaustive/exhaustive_unary_test_ops.inc create mode 100644 third_party/xla/xla/tests/exhaustive/platform.cc create mode 100644 third_party/xla/xla/tests/exhaustive/platform.h create mode 100644 third_party/xla/xla/tests/exhaustive/test_op.h diff --git a/third_party/xla/xla/tests/exhaustive/BUILD b/third_party/xla/xla/tests/exhaustive/BUILD index 902c2c696db98f..412e0af5c89a49 100644 --- a/third_party/xla/xla/tests/exhaustive/BUILD +++ b/third_party/xla/xla/tests/exhaustive/BUILD @@ -4,6 +4,7 @@ load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") load("//xla/tests:build_defs.bzl", "xla_test") load("//xla/tests/exhaustive:build_defs.bzl", "exhaustive_xla_test") +load("//xla/tsl:tsl.default.bzl", "filegroup", "get_compatible_with_portable") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -21,18 +22,41 @@ package_group( ], ) -cc_library( - name = "exhaustive_op_test_utils", +filegroup( + name = "exhaustive_op_test_utils_shared_hdrs", + testonly = True, + srcs = [ + "error_spec.h", + "exhaustive_op_test_base.h", + "exhaustive_op_test_utils.h", + ], + compatible_with = get_compatible_with_portable(), +) + +filegroup( + name = "exhaustive_op_test_utils_shared_srcs", testonly = True, srcs = [ "exhaustive_op_test_base.cc", "exhaustive_op_test_utils.cc", ], + compatible_with = get_compatible_with_portable(), +) + +cc_library( + name = "exhaustive_op_test_utils", + testonly = True, + srcs = [ + "platform.cc", + ":exhaustive_op_test_utils_shared_srcs", + ], hdrs = [ - "error_spec.h", - "exhaustive_op_test_base.h", - "exhaustive_op_test_utils.h", + "exhaustive_op_test.h", + "platform.h", + "test_op.h", + ":exhaustive_op_test_utils_shared_hdrs", ], + visibility = ["//visibility:private"], deps = [ "//xla:bit_cast", "//xla:executable_run_options", @@ -46,6 +70,8 @@ cc_library( "//xla/client:xla_builder", "//xla/client:xla_computation", "//xla/service:shaped_buffer", + "//xla/stream_executor:device_description", + "//xla/stream_executor:platform", "//xla/tests:client_library_test_base", "//xla/tsl/lib/core:status_test_util", "//xla/tsl/util:command_line_flags", @@ -64,33 +90,23 @@ cc_library( ], ) -filegroup( - name = "exhaustive_unary_test_srcs", - testonly = True, - srcs = [ - "exhaustive_unary_test_definitions.h", - "exhaustive_unary_test_functions.cc", +cc_library( + name = "exhaustive_unary_test_textual_hdrs", + textual_hdrs = [ + "exhaustive_unary_test_definitions.inc", + "exhaustive_unary_test_f32_and_smaller_instantiation.inc", + "exhaustive_unary_test_f64_instantiation.inc", + "exhaustive_unary_test_ops.inc", ], ) -filegroup( - name = "exhaustive_unary_test_f32_and_smaller_srcs", - testonly = True, - srcs = ["exhaustive_unary_test_f32_and_smaller_instantiation.cc"], -) - -filegroup( - name = "exhaustive_unary_test_f64_srcs", - testonly = True, - srcs = ["exhaustive_unary_test_f64_instantiation.cc"], -) - exhaustive_xla_test( name = "exhaustive_unary_test", timeout = "long", srcs = [ "exhaustive_test_main.cc", - ":exhaustive_unary_test_srcs", + "exhaustive_unary_test_definitions.h", + "exhaustive_unary_test_functions.cc", ], # Nvidia close-sourced libraries are not TSAN friendly, but are doing their own synchronization. # This can lead to TSAN false positives that are hard to track down. @@ -107,8 +123,12 @@ exhaustive_xla_test( # exhaustive_xla_test needs to have all partition names added to allow other build tools to # function. partitions = { - "f32_and_smaller": [":exhaustive_unary_test_f32_and_smaller_srcs"], - "f64": [":exhaustive_unary_test_f64_srcs"], + "f32_and_smaller": [ + "exhaustive_unary_test_f32_and_smaller_instantiation.cc", + ], + "f64": [ + ":exhaustive_unary_test_f64_instantiation.cc", + ], }, real_hardware_only = True, # Very slow on the interpreter. shard_count = 50, @@ -120,8 +140,8 @@ exhaustive_xla_test( ], deps = [ ":exhaustive_op_test_utils", + ":exhaustive_unary_test_textual_hdrs", "//xla:literal", - "//xla:types", "//xla/client:xla_builder", "//xla/client/lib:constants", "//xla/client/lib:math", @@ -172,39 +192,24 @@ xla_test( ], ) -filegroup( - name = "exhaustive_binary_test_srcs", - testonly = True, - srcs = [ - "exhaustive_binary_test_definitions.h", - "exhaustive_binary_test_functions.cc", +cc_library( + name = "exhaustive_binary_test_textual_hdrs", + textual_hdrs = [ + "exhaustive_binary_test_definitions.inc", + "exhaustive_binary_test_f16_and_smaller_instantiation.inc", + "exhaustive_binary_test_f32_instantiation.inc", + "exhaustive_binary_test_f64_instantiation.inc", + "exhaustive_binary_test_ops.inc", ], ) -filegroup( - name = "exhaustive_binary_test_f16_and_smaller_srcs", - testonly = True, - srcs = ["exhaustive_binary_test_f16_and_smaller_instantiation.cc"], -) - -filegroup( - name = "exhaustive_binary_test_f32_srcs", - testonly = True, - srcs = ["exhaustive_binary_test_f32_instantiation.cc"], -) - -filegroup( - name = "exhaustive_binary_test_f64_srcs", - testonly = True, - srcs = ["exhaustive_binary_test_f64_instantiation.cc"], -) - exhaustive_xla_test( name = "exhaustive_binary_test", timeout = "long", srcs = [ + "exhaustive_binary_test_definitions.h", + "exhaustive_binary_test_functions.cc", "exhaustive_test_main.cc", - ":exhaustive_binary_test_srcs", ], # Nvidia close-sourced libraries are not TSAN friendly, but are doing their own synchronization. # This can lead to TSAN false positives that are hard to track down. @@ -221,9 +226,15 @@ exhaustive_xla_test( # exhasutive_xla_test needs to have all partition names added to allow other build tools to # function. partitions = { - "f16_and_smaller": [":exhaustive_binary_test_f16_and_smaller_srcs"], - "f32": [":exhaustive_binary_test_f32_srcs"], - "f64": [":exhaustive_binary_test_f64_srcs"], + "f16_and_smaller": [ + "exhaustive_binary_test_f16_and_smaller_instantiation.cc", + ], + "f32": [ + "exhaustive_binary_test_f32_instantiation.cc", + ], + "f64": [ + "exhaustive_binary_test_f64_instantiation.cc", + ], }, shard_count = 50, tags = [ @@ -233,6 +244,7 @@ exhaustive_xla_test( "no_oss", ], deps = [ + ":exhaustive_binary_test_textual_hdrs", ":exhaustive_op_test_utils", "//xla:literal", "//xla:types", diff --git a/third_party/xla/xla/tests/exhaustive/build_defs.bzl b/third_party/xla/xla/tests/exhaustive/build_defs.bzl index 446fe7de188460..f92de799dcae77 100644 --- a/third_party/xla/xla/tests/exhaustive/build_defs.bzl +++ b/third_party/xla/xla/tests/exhaustive/build_defs.bzl @@ -53,5 +53,8 @@ def exhaustive_xla_test(name, srcs, partitions, tags, **kwargs): register_extension_info( extension = exhaustive_xla_test, # Needs to be kept up-to-date on all partition names defined in the invocations. - label_regex_for_dep = "{extension_name}_(f16_and_smaller|f32_and_smaller|f32|f64)_.*", + # + # For some reason, manually specifying the expansion targets like (cpu|cpu_.*|...) is required + # for build tools. + label_regex_for_dep = "{extension_name}_(f16_and_smaller|f32_and_smaller|f32|f64)_(cpu|cpu_.*|gpu|gpu_.*)", ) diff --git a/third_party/xla/xla/tests/exhaustive/exhaustive_binary_test_definitions.h b/third_party/xla/xla/tests/exhaustive/exhaustive_binary_test_definitions.h index 6ee448ed1c9cf8..e4aee45b302cac 100644 --- a/third_party/xla/xla/tests/exhaustive/exhaustive_binary_test_definitions.h +++ b/third_party/xla/xla/tests/exhaustive/exhaustive_binary_test_definitions.h @@ -16,205 +16,26 @@ limitations under the License. #ifndef XLA_TESTS_EXHAUSTIVE_EXHAUSTIVE_BINARY_TEST_DEFINITIONS_H_ #define XLA_TESTS_EXHAUSTIVE_EXHAUSTIVE_BINARY_TEST_DEFINITIONS_H_ -#include -#include -#include -#include -#include -#include - -#include "absl/log/check.h" -#include "absl/log/log.h" -#include "absl/types/span.h" -#include "xla/literal.h" -#include "xla/tests/exhaustive/exhaustive_op_test_base.h" -#include "xla/tests/exhaustive/exhaustive_op_test_utils.h" -#include "xla/tests/test_macros.h" -#include "xla/types.h" -#include "tsl/platform/test.h" +#include // IWYU pragma: keep, exhaustive_binary_test_definitions.inc +#include // IWYU pragma: keep, exhaustive_binary_test_definitions.inc +#include // IWYU pragma: keep, exhaustive_binary_test_definitions.inc +#include // IWYU pragma: keep, exhaustive_binary_test_definitions.inc +#include // IWYU pragma: keep, exhaustive_binary_test_definitions.inc +#include // IWYU pragma: keep, exhaustive_binary_test_definitions.inc + +#include "absl/log/check.h" // IWYU pragma: keep, exhaustive_binary_test_definitions.inc +#include "absl/log/log.h" // IWYU pragma: keep, exhaustive_binary_test_definitions.inc +#include "absl/types/span.h" // IWYU pragma: keep, exhaustive_binary_test_definitions.inc +#include "xla/literal.h" // IWYU pragma: keep, exhaustive_binary_test_definitions.inc +#include "xla/tests/exhaustive/exhaustive_op_test.h" // IWYU pragma: keep, exhaustive_binary_test_definitions.inc +#include "xla/tests/test_macros.h" // IWYU pragma: keep, exhaustive_binary_test_definitions.inc +#include "xla/types.h" // IWYU pragma: keep, exhaustive_binary_test_definitions.inc +#include "tsl/platform/test.h" // IWYU pragma: keep, exhaustive_binary_test_definitions.inc namespace xla { namespace exhaustive_op_test { -// Exhaustive test for binary operations for 16 bit floating point types, -// including float16 and bfloat. -// -// Test parameter is a pair of (begin, end) for range under test. -template -class Exhaustive16BitBinaryTest - : public ExhaustiveBinaryTest, - public ::testing::WithParamInterface> { - public: - int64_t GetInputSize() override { - int64_t begin, end; - std::tie(begin, end) = GetParam(); - return end - begin; - } - - // Given a range of uint64_t representation, uses bits 0..15 and bits 16..31 - // for the values of src0 and src1 (see below for ordering) for the 16 bit - // binary operation being tested, and generates the cartesian product of the - // two sets as the two inputs for the test. - // - // If `kLeftToRightPacking == true`, bit 31..16 become src0 and 15..0 becomes - // src1. If `kLeftToRightPacking == false`, then bits 31..16 become src1 - // and 15..0 becomes src0. - void FillInput(std::array* input_literals) override { - int64_t input_size = GetInputSize(); - CHECK_EQ(input_size, (*input_literals)[0].element_count()); - CHECK_EQ(input_size, (*input_literals)[1].element_count()); - - int64_t begin, end; - std::tie(begin, end) = GetParam(); - if (VLOG_IS_ON(2)) { - uint16_t left_begin, left_end, right_begin, right_end; - if constexpr (kLeftToRightPacking) { - left_begin = std::bit_cast(static_cast(begin >> 16)); - left_end = std::bit_cast(static_cast(end >> 16)); - right_begin = std::bit_cast(static_cast(begin)); - right_end = std::bit_cast(static_cast(end)); - } else { - left_begin = std::bit_cast(static_cast(begin)); - left_end = std::bit_cast(static_cast(end)); - right_begin = - std::bit_cast(static_cast(begin >> 16)); - right_end = std::bit_cast(static_cast(end >> 16)); - } - - // N.B.: Use INFO directly instead of doing another thread-safe VLOG - // check. - LOG(INFO) << this->SuiteName() << this->TestName() << " Range:"; - LOG(INFO) << "\tfrom=(" << left_begin << ", " << right_begin << "); hex=(" - << std::hex << left_begin << ", " << right_begin << "); float=(" - << *reinterpret_cast(&left_begin) << ", " - << *reinterpret_cast(&right_begin) - << ") (inclusive)"; - LOG(INFO) << "\tto=(" << left_end << ", " << right_end << "); hex=(" - << std::hex << left_end << ", " << right_end << "); float=(" - << *reinterpret_cast(&left_end) << ", " - << *reinterpret_cast(&right_end) - << ") (exclusive)"; - LOG(INFO) << "\ttotal values to test=" << (end - begin); - } - - absl::Span input_arr_0 = (*input_literals)[0].data(); - absl::Span input_arr_1 = (*input_literals)[1].data(); - for (int64_t i = 0; i < input_size; i++) { - uint32_t input_val = i + begin; - // Convert the packed bits to a pair of NativeT and replace known - // incorrect input values with 0. - // - // In either case, we only use 32 bits out of the 64 bits possible. - if constexpr (kLeftToRightPacking) { - // Left is stored at higher 16 bits. - input_arr_0[i] = this->ConvertValue(input_val >> 16); - input_arr_1[i] = this->ConvertValue(input_val); - } else { - // Left is stored at lower 16 bits. - input_arr_0[i] = this->ConvertValue(input_val); - input_arr_1[i] = this->ConvertValue(input_val >> 16); - } - } - } - - protected: - using typename ExhaustiveBinaryTest::NativeT; -}; - -// Exhaustive test for binary operations for float and double. -// -// Test parameter is a tuple of (FpValues, FpValues) describing the possible -// values for each operand. The inputs for the test are the Cartesian product -// of the possible values for the two operands. -template -class Exhaustive32BitOrMoreBinaryTest - : public ExhaustiveBinaryTest, - public ::testing::WithParamInterface> { - protected: - using typename ExhaustiveBinaryTest::NativeT; - - private: - int64_t GetInputSize() override { - FpValues values_0; - FpValues values_1; - std::tie(values_0, values_1) = GetParam(); - return values_0.GetTotalNumValues() * values_1.GetTotalNumValues(); - } - - void FillInput(std::array* input_literals) override { - int64_t input_size = GetInputSize(); - FpValues values_0; - FpValues values_1; - std::tie(values_0, values_1) = GetParam(); - if (VLOG_IS_ON(2)) { - // N.B.: Use INFO directly instead of doing another thread-safe VLOG - // check. - LOG(INFO) << this->SuiteName() << this->TestName() << " Values:"; - LOG(INFO) << "\tleft values=" << values_0.ToString(); - LOG(INFO) << "\tright values=" << values_1.ToString(); - LOG(INFO) << "\ttotal values to test=" << input_size; - } - CHECK(input_size == (*input_literals)[0].element_count() && - input_size == (*input_literals)[1].element_count()); - - absl::Span input_arr_0 = (*input_literals)[0].data(); - absl::Span input_arr_1 = (*input_literals)[1].data(); - - uint64_t i = 0; - for (auto src0 : values_0) { - for (auto src1 : values_1) { - input_arr_0[i] = this->ConvertValue(src0); - input_arr_1[i] = this->ConvertValue(src1); - ++i; - } - } - CHECK_EQ(i, input_size); - } -}; - -using ExhaustiveF16BinaryTest = Exhaustive16BitBinaryTest; -using ExhaustiveBF16BinaryTest = Exhaustive16BitBinaryTest; -using ExhaustiveF32BinaryTest = Exhaustive32BitOrMoreBinaryTest; -using ExhaustiveF64BinaryTest = Exhaustive32BitOrMoreBinaryTest; - -#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16) -#define BINARY_TEST_F16(test_name, ...) \ - XLA_TEST_P(ExhaustiveF16BinaryTest, test_name) \ - __VA_ARGS__ -#else -#define BINARY_TEST_F16(test_name, ...) -#endif - -#if defined(XLA_BACKEND_SUPPORTS_BFLOAT16) -#define BINARY_TEST_BF16(test_name, ...) \ - XLA_TEST_P(ExhaustiveBF16BinaryTest, test_name) \ - __VA_ARGS__ -#else -#define BINARY_TEST_BF16(test_name, ...) -#endif - -#define BINARY_TEST_F32(test_name, ...) \ - XLA_TEST_P(ExhaustiveF32BinaryTest, test_name) \ - __VA_ARGS__ - -#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64) -using ExhaustiveF64BinaryTest = Exhaustive32BitOrMoreBinaryTest; -#define BINARY_TEST_F64(test_name, ...) \ - XLA_TEST_P(ExhaustiveF64BinaryTest, test_name) \ - __VA_ARGS__ -#else -#define BINARY_TEST_F64(test_name, ...) -#endif - -#define BINARY_TEST(test_name, ...) \ - BINARY_TEST_F16(test_name, __VA_ARGS__) \ - BINARY_TEST_BF16(test_name, __VA_ARGS__) \ - BINARY_TEST_F32(test_name, __VA_ARGS__) \ - BINARY_TEST_F64(test_name, __VA_ARGS__) - -#define BINARY_TEST_COMPLEX(test_name, ...) \ - BINARY_TEST_F32(test_name, __VA_ARGS__) \ - BINARY_TEST_F64(test_name, __VA_ARGS__) +#include "xla/tests/exhaustive/exhaustive_binary_test_definitions.inc" } // namespace exhaustive_op_test } // namespace xla diff --git a/third_party/xla/xla/tests/exhaustive/exhaustive_binary_test_definitions.inc b/third_party/xla/xla/tests/exhaustive/exhaustive_binary_test_definitions.inc new file mode 100644 index 00000000000000..8fe0a71d45277d --- /dev/null +++ b/third_party/xla/xla/tests/exhaustive/exhaustive_binary_test_definitions.inc @@ -0,0 +1,191 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Exhaustive test for binary operations for 16 bit floating point types, +// including float16 and bfloat. +// +// Test parameter is a pair of (begin, end) for range under test. +template +class Exhaustive16BitBinaryTest + : public ExhaustiveBinaryTest, + public ::testing::WithParamInterface> { + protected: + int64_t GetInputSize() override { + int64_t begin, end; + std::tie(begin, end) = GetParam(); + return end - begin; + } + + // Given a range of uint64_t representation, uses bits 0..15 and bits 16..31 + // for the values of src0 and src1 (see below for ordering) for the 16 bit + // binary operation being tested, and generates the cartesian product of the + // two sets as the two inputs for the test. + // + // If `kLeftToRightPacking == true`, bit 31..16 become src0 and 15..0 becomes + // src1. If `kLeftToRightPacking == false`, then bits 31..16 become src1 + // and 15..0 becomes src0. + void FillInput(std::array* input_literals) override { + using NativeT = typename ExhaustiveBinaryTest::NativeT; + + int64_t input_size = GetInputSize(); + CHECK_EQ(input_size, (*input_literals)[0].element_count()); + CHECK_EQ(input_size, (*input_literals)[1].element_count()); + + int64_t begin, end; + std::tie(begin, end) = GetParam(); + if (VLOG_IS_ON(2)) { + uint16_t left_begin, left_end, right_begin, right_end; + if constexpr (kLeftToRightPacking) { + left_begin = std::bit_cast(static_cast(begin >> 16)); + left_end = std::bit_cast(static_cast(end >> 16)); + right_begin = std::bit_cast(static_cast(begin)); + right_end = std::bit_cast(static_cast(end)); + } else { + left_begin = std::bit_cast(static_cast(begin)); + left_end = std::bit_cast(static_cast(end)); + right_begin = + std::bit_cast(static_cast(begin >> 16)); + right_end = std::bit_cast(static_cast(end >> 16)); + } + + // N.B.: Use INFO directly instead of doing another thread-safe VLOG + // check. + LOG(INFO) << this->SuiteName() << this->TestName() << " Range:"; + LOG(INFO) << "\tfrom=(" << left_begin << ", " << right_begin << "); hex=(" + << std::hex << left_begin << ", " << right_begin << "); float=(" + << *reinterpret_cast(&left_begin) << ", " + << *reinterpret_cast(&right_begin) + << ") (inclusive)"; + LOG(INFO) << "\tto=(" << left_end << ", " << right_end << "); hex=(" + << std::hex << left_end << ", " << right_end << "); float=(" + << *reinterpret_cast(&left_end) << ", " + << *reinterpret_cast(&right_end) + << ") (exclusive)"; + LOG(INFO) << "\ttotal values to test=" << (end - begin); + } + + absl::Span input_arr_0 = (*input_literals)[0].data(); + absl::Span input_arr_1 = (*input_literals)[1].data(); + for (int64_t i = 0; i < input_size; i++) { + uint32_t input_val = i + begin; + // Convert the packed bits to a pair of NativeT and replace known + // incorrect input values with 0. + // + // In either case, we only use 32 bits out of the 64 bits possible. + if constexpr (kLeftToRightPacking) { + // Left is stored at higher 16 bits. + input_arr_0[i] = this->ConvertValue(input_val >> 16); + input_arr_1[i] = this->ConvertValue(input_val); + } else { + // Left is stored at lower 16 bits. + input_arr_0[i] = this->ConvertValue(input_val); + input_arr_1[i] = this->ConvertValue(input_val >> 16); + } + } + } +}; + +// Exhaustive test for binary operations for float and double. +// +// Test parameter is a tuple of (FpValues, FpValues) describing the possible +// values for each operand. The inputs for the test are the Cartesian product +// of the possible values for the two operands. +template +class Exhaustive32BitOrMoreBinaryTest + : public ExhaustiveBinaryTest, + public ::testing::WithParamInterface> { + protected: + int64_t GetInputSize() override { + FpValues values_0; + FpValues values_1; + std::tie(values_0, values_1) = GetParam(); + return values_0.GetTotalNumValues() * values_1.GetTotalNumValues(); + } + + void FillInput(std::array* input_literals) override { + using NativeT = typename ExhaustiveBinaryTest::NativeT; + + int64_t input_size = GetInputSize(); + FpValues values_0; + FpValues values_1; + std::tie(values_0, values_1) = GetParam(); + if (VLOG_IS_ON(2)) { + // N.B.: Use INFO directly instead of doing another thread-safe VLOG + // check. + LOG(INFO) << this->SuiteName() << this->TestName() << " Values:"; + LOG(INFO) << "\tleft values=" << values_0.ToString(); + LOG(INFO) << "\tright values=" << values_1.ToString(); + LOG(INFO) << "\ttotal values to test=" << input_size; + } + CHECK(input_size == (*input_literals)[0].element_count() && + input_size == (*input_literals)[1].element_count()); + + absl::Span input_arr_0 = (*input_literals)[0].data(); + absl::Span input_arr_1 = (*input_literals)[1].data(); + + uint64_t i = 0; + for (auto src0 : values_0) { + for (auto src1 : values_1) { + input_arr_0[i] = this->ConvertValue(src0); + input_arr_1[i] = this->ConvertValue(src1); + ++i; + } + } + CHECK_EQ(i, input_size); + } +}; + +using ExhaustiveF16BinaryTest = Exhaustive16BitBinaryTest; +using ExhaustiveBF16BinaryTest = Exhaustive16BitBinaryTest; +using ExhaustiveF32BinaryTest = Exhaustive32BitOrMoreBinaryTest; +using ExhaustiveF64BinaryTest = Exhaustive32BitOrMoreBinaryTest; + +#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16) +#define BINARY_TEST_F16(test_name, ...) \ + XLA_TEST_P(ExhaustiveF16BinaryTest, test_name) \ + __VA_ARGS__ +#else +#define BINARY_TEST_F16(test_name, ...) +#endif + +#if defined(XLA_BACKEND_SUPPORTS_BFLOAT16) +#define BINARY_TEST_BF16(test_name, ...) \ + XLA_TEST_P(ExhaustiveBF16BinaryTest, test_name) \ + __VA_ARGS__ +#else +#define BINARY_TEST_BF16(test_name, ...) +#endif + +#define BINARY_TEST_F32(test_name, ...) \ + XLA_TEST_P(ExhaustiveF32BinaryTest, test_name) \ + __VA_ARGS__ + +#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64) +#define BINARY_TEST_F64(test_name, ...) \ + XLA_TEST_P(ExhaustiveF64BinaryTest, test_name) \ + __VA_ARGS__ +#else +#define BINARY_TEST_F64(test_name, ...) +#endif + +#define BINARY_TEST(test_name, ...) \ + BINARY_TEST_F16(test_name, __VA_ARGS__) \ + BINARY_TEST_BF16(test_name, __VA_ARGS__) \ + BINARY_TEST_F32(test_name, __VA_ARGS__) \ + BINARY_TEST_F64(test_name, __VA_ARGS__) + +#define BINARY_TEST_COMPLEX(test_name, ...) \ + BINARY_TEST_F32(test_name, __VA_ARGS__) \ + BINARY_TEST_F64(test_name, __VA_ARGS__) \ No newline at end of file diff --git a/third_party/xla/xla/tests/exhaustive/exhaustive_binary_test_f16_and_smaller_instantiation.cc b/third_party/xla/xla/tests/exhaustive/exhaustive_binary_test_f16_and_smaller_instantiation.cc index d8451068624898..db863ed53af0fe 100644 --- a/third_party/xla/xla/tests/exhaustive/exhaustive_binary_test_f16_and_smaller_instantiation.cc +++ b/third_party/xla/xla/tests/exhaustive/exhaustive_binary_test_f16_and_smaller_instantiation.cc @@ -13,31 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/tests/exhaustive/exhaustive_binary_test_definitions.h" -#include "xla/tests/exhaustive/exhaustive_op_test_utils.h" -#include "tsl/platform/test.h" +#include "xla/tests/exhaustive/exhaustive_binary_test_definitions.h" // IWYU pragma: keep, exhaustive_binary_test_f16_and_smaller_instantiation.inc +#include "xla/tests/exhaustive/exhaustive_op_test_utils.h" // IWYU pragma: keep, exhaustive_binary_test_f16_and_smaller_instantiation.inc +#include "tsl/platform/test.h" // IWYU pragma: keep, exhaustive_binary_test_f16_and_smaller_instantiation.inc namespace xla { namespace exhaustive_op_test { namespace { -#if defined(XLA_BACKEND_SUPPORTS_BFLOAT16) -INSTANTIATE_TEST_SUITE_P(BF16, ExhaustiveBF16BinaryTest, - ::testing::ValuesIn(CreateExhaustiveF32Ranges())); -#else -GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveBF16BinaryTest); -#endif - -#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16) -INSTANTIATE_TEST_SUITE_P(F16, ExhaustiveF16BinaryTest, - ::testing::ValuesIn(CreateExhaustiveF32Ranges())); -#else -GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF16BinaryTest); -#endif - -GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF32BinaryTest); - -GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF64BinaryTest); +#include "xla/tests/exhaustive/exhaustive_binary_test_f16_and_smaller_instantiation.inc" } // namespace } // namespace exhaustive_op_test diff --git a/third_party/xla/xla/tests/exhaustive/exhaustive_binary_test_f16_and_smaller_instantiation.inc b/third_party/xla/xla/tests/exhaustive/exhaustive_binary_test_f16_and_smaller_instantiation.inc new file mode 100644 index 00000000000000..1e88061028a65d --- /dev/null +++ b/third_party/xla/xla/tests/exhaustive/exhaustive_binary_test_f16_and_smaller_instantiation.inc @@ -0,0 +1,32 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#if defined(XLA_BACKEND_SUPPORTS_BFLOAT16) +INSTANTIATE_TEST_SUITE_P(BF16, ExhaustiveBF16BinaryTest, + ::testing::ValuesIn(CreateExhaustiveF32Ranges())); +#else +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveBF16BinaryTest); +#endif + +#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16) +INSTANTIATE_TEST_SUITE_P(F16, ExhaustiveF16BinaryTest, + ::testing::ValuesIn(CreateExhaustiveF32Ranges())); +#else +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF16BinaryTest); +#endif + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF32BinaryTest); + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF64BinaryTest); diff --git a/third_party/xla/xla/tests/exhaustive/exhaustive_binary_test_f32_instantiation.cc b/third_party/xla/xla/tests/exhaustive/exhaustive_binary_test_f32_instantiation.cc index ed28e923a035bf..c625282e987d5a 100644 --- a/third_party/xla/xla/tests/exhaustive/exhaustive_binary_test_f32_instantiation.cc +++ b/third_party/xla/xla/tests/exhaustive/exhaustive_binary_test_f32_instantiation.cc @@ -13,54 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/tests/exhaustive/exhaustive_binary_test_definitions.h" -#include "xla/tests/exhaustive/exhaustive_op_test_utils.h" -#include "tsl/platform/test.h" +#include "xla/tests/exhaustive/exhaustive_binary_test_definitions.h" // IWYU pragma: keep, exhaustive_binary_test_f32_instantiation.inc +#include "xla/tests/exhaustive/exhaustive_op_test_utils.h" // IWYU pragma: keep, exhaustive_binary_test_f32_instantiation.inc +#include "tsl/platform/test.h" // IWYU pragma: keep, exhaustive_binary_test_f32_instantiation.inc namespace xla { namespace exhaustive_op_test { namespace { -GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveBF16BinaryTest); - -GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF16BinaryTest); - -INSTANTIATE_TEST_SUITE_P( - SpecialValues, ExhaustiveF32BinaryTest, - ::testing::Combine( - ::testing::ValuesIn(CreateFpValuesForBoundaryTest()), - ::testing::ValuesIn(CreateFpValuesForBoundaryTest()))); - -INSTANTIATE_TEST_SUITE_P( - SpecialAndNormalValues, ExhaustiveF32BinaryTest, - ::testing::Combine( - ::testing::ValuesIn(CreateFpValuesForBoundaryTest()), - ::testing::Values(GetNormals(2000)))); - -INSTANTIATE_TEST_SUITE_P( - NormalAndSpecialValues, ExhaustiveF32BinaryTest, - ::testing::Combine( - ::testing::Values(GetNormals(2000)), - ::testing::ValuesIn(CreateFpValuesForBoundaryTest()))); - -INSTANTIATE_TEST_SUITE_P( - NormalAndNormalValues, ExhaustiveF32BinaryTest, - ::testing::Combine(::testing::Values(GetNormals(2000)), - ::testing::Values(GetNormals(2000)))); - -// Tests a total of 40000 ^ 2 inputs, with 2000 ^ 2 inputs in each sub-test. -// Comparing with the unary tests, the binary tests use a smaller set of inputs -// for each sub-test to avoid timeout because the implementation of ExpectNear -// more than 2x slower for binary test. -INSTANTIATE_TEST_SUITE_P( - LargeAndSmallMagnitudeNormalValues, ExhaustiveF32BinaryTest, - ::testing::Combine( - ::testing::ValuesIn(GetFpValuesForMagnitudeExtremeNormals(40000, - 2000)), - ::testing::ValuesIn( - GetFpValuesForMagnitudeExtremeNormals(40000, 2000)))); - -GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF64BinaryTest); +#include "xla/tests/exhaustive/exhaustive_binary_test_f32_instantiation.inc" } // namespace } // namespace exhaustive_op_test diff --git a/third_party/xla/xla/tests/exhaustive/exhaustive_binary_test_f32_instantiation.inc b/third_party/xla/xla/tests/exhaustive/exhaustive_binary_test_f32_instantiation.inc new file mode 100644 index 00000000000000..1c8e97d1d5d41e --- /dev/null +++ b/third_party/xla/xla/tests/exhaustive/exhaustive_binary_test_f32_instantiation.inc @@ -0,0 +1,55 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveBF16BinaryTest); + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF16BinaryTest); + +INSTANTIATE_TEST_SUITE_P( + SpecialValues, ExhaustiveF32BinaryTest, + ::testing::Combine( + ::testing::ValuesIn(CreateFpValuesForBoundaryTest()), + ::testing::ValuesIn(CreateFpValuesForBoundaryTest()))); + +INSTANTIATE_TEST_SUITE_P( + SpecialAndNormalValues, ExhaustiveF32BinaryTest, + ::testing::Combine( + ::testing::ValuesIn(CreateFpValuesForBoundaryTest()), + ::testing::Values(GetNormals(2000)))); + +INSTANTIATE_TEST_SUITE_P( + NormalAndSpecialValues, ExhaustiveF32BinaryTest, + ::testing::Combine( + ::testing::Values(GetNormals(2000)), + ::testing::ValuesIn(CreateFpValuesForBoundaryTest()))); + +INSTANTIATE_TEST_SUITE_P( + NormalAndNormalValues, ExhaustiveF32BinaryTest, + ::testing::Combine(::testing::Values(GetNormals(2000)), + ::testing::Values(GetNormals(2000)))); + +// Tests a total of 40000 ^ 2 inputs, with 2000 ^ 2 inputs in each sub-test. +// Comparing with the unary tests, the binary tests use a smaller set of inputs +// for each sub-test to avoid timeout because the implementation of ExpectNear +// more than 2x slower for binary test. +INSTANTIATE_TEST_SUITE_P( + LargeAndSmallMagnitudeNormalValues, ExhaustiveF32BinaryTest, + ::testing::Combine( + ::testing::ValuesIn(GetFpValuesForMagnitudeExtremeNormals(40000, + 2000)), + ::testing::ValuesIn( + GetFpValuesForMagnitudeExtremeNormals(40000, 2000)))); + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF64BinaryTest); diff --git a/third_party/xla/xla/tests/exhaustive/exhaustive_binary_test_f64_instantiation.cc b/third_party/xla/xla/tests/exhaustive/exhaustive_binary_test_f64_instantiation.cc index c948c83703171e..fd2b73a706cca1 100644 --- a/third_party/xla/xla/tests/exhaustive/exhaustive_binary_test_f64_instantiation.cc +++ b/third_party/xla/xla/tests/exhaustive/exhaustive_binary_test_f64_instantiation.cc @@ -13,57 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/tests/exhaustive/exhaustive_binary_test_definitions.h" -#include "xla/tests/exhaustive/exhaustive_op_test_utils.h" -#include "tsl/platform/test.h" +#include "xla/tests/exhaustive/exhaustive_binary_test_definitions.h" // IWYU pragma: keep, exhaustive_binary_test_f64_instantiation.inc +#include "xla/tests/exhaustive/exhaustive_op_test_utils.h" // IWYU pragma: keep, exhaustive_binary_test_f64_instantiation.inc +#include "tsl/platform/test.h" // IWYU pragma: keep, exhaustive_binary_test_f64_instantiation.inc namespace xla { namespace exhaustive_op_test { namespace { -GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveBF16BinaryTest); - -GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF16BinaryTest); - -GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF32BinaryTest); - -#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64) -INSTANTIATE_TEST_SUITE_P( - SpecialValues, ExhaustiveF64BinaryTest, - ::testing::Combine( - ::testing::ValuesIn(CreateFpValuesForBoundaryTest()), - ::testing::ValuesIn(CreateFpValuesForBoundaryTest()))); - -INSTANTIATE_TEST_SUITE_P( - SpecialAndNormalValues, ExhaustiveF64BinaryTest, - ::testing::Combine( - ::testing::ValuesIn(CreateFpValuesForBoundaryTest()), - ::testing::Values(GetNormals(1000)))); - -INSTANTIATE_TEST_SUITE_P( - NormalAndSpecialValues, ExhaustiveF64BinaryTest, - ::testing::Combine( - ::testing::Values(GetNormals(1000)), - ::testing::ValuesIn(CreateFpValuesForBoundaryTest()))); - -INSTANTIATE_TEST_SUITE_P( - NormalAndNormalValues, ExhaustiveF64BinaryTest, - ::testing::Combine(::testing::Values(GetNormals(1000)), - ::testing::Values(GetNormals(1000)))); - -// Tests a total of 40000 ^ 2 inputs, with 2000 ^ 2 inputs in each sub-test. -// Similar to ExhaustiveF64BinaryTest, we use a smaller set of inputs for each -// for each sub-test comparing with the unary test to avoid timeout. -INSTANTIATE_TEST_SUITE_P( - LargeAndSmallMagnitudeNormalValues, ExhaustiveF64BinaryTest, - ::testing::Combine( - ::testing::ValuesIn( - GetFpValuesForMagnitudeExtremeNormals(40000, 2000)), - ::testing::ValuesIn( - GetFpValuesForMagnitudeExtremeNormals(40000, 2000)))); -#else -GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF64BinaryTest); -#endif // !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64) +#include "xla/tests/exhaustive/exhaustive_binary_test_f64_instantiation.inc" } // namespace } // namespace exhaustive_op_test diff --git a/third_party/xla/xla/tests/exhaustive/exhaustive_binary_test_f64_instantiation.inc b/third_party/xla/xla/tests/exhaustive/exhaustive_binary_test_f64_instantiation.inc new file mode 100644 index 00000000000000..0de1e1242d6b7c --- /dev/null +++ b/third_party/xla/xla/tests/exhaustive/exhaustive_binary_test_f64_instantiation.inc @@ -0,0 +1,58 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveBF16BinaryTest); + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF16BinaryTest); + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF32BinaryTest); + +#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64) +INSTANTIATE_TEST_SUITE_P( + SpecialValues, ExhaustiveF64BinaryTest, + ::testing::Combine( + ::testing::ValuesIn(CreateFpValuesForBoundaryTest()), + ::testing::ValuesIn(CreateFpValuesForBoundaryTest()))); + +INSTANTIATE_TEST_SUITE_P( + SpecialAndNormalValues, ExhaustiveF64BinaryTest, + ::testing::Combine( + ::testing::ValuesIn(CreateFpValuesForBoundaryTest()), + ::testing::Values(GetNormals(1000)))); + +INSTANTIATE_TEST_SUITE_P( + NormalAndSpecialValues, ExhaustiveF64BinaryTest, + ::testing::Combine( + ::testing::Values(GetNormals(1000)), + ::testing::ValuesIn(CreateFpValuesForBoundaryTest()))); + +INSTANTIATE_TEST_SUITE_P( + NormalAndNormalValues, ExhaustiveF64BinaryTest, + ::testing::Combine(::testing::Values(GetNormals(1000)), + ::testing::Values(GetNormals(1000)))); + +// Tests a total of 40000 ^ 2 inputs, with 2000 ^ 2 inputs in each sub-test. +// Similar to ExhaustiveF64BinaryTest, we use a smaller set of inputs for each +// for each sub-test comparing with the unary test to avoid timeout. +INSTANTIATE_TEST_SUITE_P( + LargeAndSmallMagnitudeNormalValues, ExhaustiveF64BinaryTest, + ::testing::Combine( + ::testing::ValuesIn( + GetFpValuesForMagnitudeExtremeNormals(40000, 2000)), + ::testing::ValuesIn( + GetFpValuesForMagnitudeExtremeNormals(40000, 2000)))); +#else +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF64BinaryTest); +#endif // !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64) diff --git a/third_party/xla/xla/tests/exhaustive/exhaustive_binary_test_functions.cc b/third_party/xla/xla/tests/exhaustive/exhaustive_binary_test_functions.cc index a74b86ac89d019..07485354cf2427 100644 --- a/third_party/xla/xla/tests/exhaustive/exhaustive_binary_test_functions.cc +++ b/third_party/xla/xla/tests/exhaustive/exhaustive_binary_test_functions.cc @@ -13,16 +13,16 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include -#include -#include -#include +#include // IWYU pragma: keep, exhaustive_binary_test_ops.inc +#include // IWYU pragma: keep, exhaustive_binary_test_ops.inc #include #include -#include "xla/client/xla_builder.h" +#include "xla/client/xla_builder.h" // IWYU pragma: keep, exhaustive_binary_test_ops.inc +#include "xla/tests/exhaustive/error_spec.h" #include "xla/tests/exhaustive/exhaustive_binary_test_definitions.h" #include "xla/tests/exhaustive/exhaustive_op_test_utils.h" +#include "xla/tests/exhaustive/test_op.h" // IWYU pragma: keep, exhaustive_binary_test_ops.inc #include "xla/types.h" #ifdef __FAST_MATH__ @@ -33,346 +33,163 @@ namespace xla { namespace exhaustive_op_test { namespace { +#include "xla/tests/exhaustive/exhaustive_binary_test_ops.inc" + // Can be thought of as an absolute error of // `<= |std::numeric_limits::::min()|`. template -double AddCpuTpuAbsErr(NativeT left, NativeT right) { +double AddCpuAbsErr(NativeT left, NativeT right) { NativeRefT output = static_cast(left) + static_cast(right); - // Hardware flushes subnormal outputs to 0. if (IsSubnormal(output)) { return std::numeric_limits::min(); } - return 0.0; } BINARY_TEST(Add, { - ErrorSpecGen error_spec_gen = +[](NativeT, NativeT) { - return ErrorSpec::Builder().strict_signed_zeros().build(); - }; - - if ((IsCpu(platform_) || IsTpu(platform_))) { - if (std::is_same_v || - std::is_same_v || std::is_same_v) { - error_spec_gen = +[](NativeT left, NativeT right) { - return ErrorSpec::Builder() - .abs_err(AddCpuTpuAbsErr(left, right)) - .strict_signed_zeros() - .build(); - }; - } - } - - Run( - AddEmptyBroadcastDimension(Add), - [](NativeRefT x, NativeRefT y) { return x + y; }, error_spec_gen); + AddOp(this) + .Error(+[](NativeT, NativeT) { + return ErrorSpec::Builder().strict_signed_zeros().build(); + }) + .CpuError(+[](NativeT left, NativeT right) { + if constexpr (std::is_same_v || + std::is_same_v || + std::is_same_v) { + return ErrorSpec::Builder() + .abs_err(AddCpuAbsErr(left, right)) + .strict_signed_zeros() + .build(); + } + return ErrorSpec::Builder().strict_signed_zeros().build(); + }) + .Run(); }) // Can be thought of as an absolute error of // `<= |std::numeric_limits::::min()|`. template -double SubCpuTpuAbsErr(NativeT left, NativeT right) { +double SubCpuAbsErr(NativeT left, NativeT right) { NativeRefT output = static_cast(left) - static_cast(right); - // Hardware flushes subnormal outputs to 0. if (IsSubnormal(output)) { return std::numeric_limits::min(); } - return 0.0; } BINARY_TEST(Sub, { - ErrorSpecGen error_spec_gen = +[](NativeT, NativeT) { - return ErrorSpec::Builder().strict_signed_zeros().build(); - }; - - if (IsCpu(platform_) || IsTpu(platform_)) { - if (std::is_same_v || - std::is_same_v || std::is_same_v) { - error_spec_gen = +[](NativeT left, NativeT right) { - return ErrorSpec::Builder() - .abs_err(SubCpuTpuAbsErr(left, right)) - .strict_signed_zeros() - .build(); - }; - } - } - - Run( - AddEmptyBroadcastDimension(Sub), - [](NativeRefT x, NativeRefT y) { return x - y; }, error_spec_gen); + SubOp(this) + .Error(+[](NativeT, NativeT) { + return ErrorSpec::Builder().strict_signed_zeros().build(); + }) + .CpuError(+[](NativeT left, NativeT right) { + if constexpr (std::is_same_v || + std::is_same_v || + std::is_same_v) { + return ErrorSpec::Builder() + .abs_err(SubCpuAbsErr(left, right)) + .strict_signed_zeros() + .build(); + } + return ErrorSpec::Builder().strict_signed_zeros().build(); + }) + .Run(); }) // Can be thought of as an absolute error of // `<= |std::numeric_limits::::min()|`. template -double MulCpuTpuAbsErr(NativeT left, NativeT right) { +double MulCpuAbsErr(NativeT left, NativeT right) { NativeRefT output = static_cast(left) * static_cast(right); - - // CPU BF16 and TPU (all types) flush subnormals to 0. + // CPU BF16 flush subnormals to 0. auto output_is_subnormal = IsSubnormal(output); if (output_is_subnormal) { return std::numeric_limits::min(); } - return 0.0; } -bool MulCpuTpuBf16Skip(xla::bfloat16 left, xla::bfloat16 right) { - // For CPU and TPU BF16, multiplying a subnormal by infinity will lead to +bool MulCpuBf16Skip(xla::bfloat16 left, xla::bfloat16 right) { + // For CPU BF16, multiplying a subnormal by infinity will lead to // calculating 0 multiplied by infinity due to subnormal flushing, which is // defined to be NaN. However, the calculation in higher precision does not // flush the subnormal value to 0, leading to a result of infinity. - if ((IsSubnormal(left) && std::isinf(right)) || - (std::isinf(left) && IsSubnormal(right))) { - return true; - } - return false; + return (IsSubnormal(left) && std::isinf(right)) || + (std::isinf(left) && IsSubnormal(right)); } BINARY_TEST(Mul, { - ErrorSpecGen error_spec_gen = +[](NativeT left, NativeT right) { - return ErrorSpec::Builder().strict_signed_zeros().build(); - }; - - if (IsCpu(platform_) || IsTpu(platform_)) { - if (std::is_same_v) { - error_spec_gen = +[](NativeT left, NativeT right) { - return ErrorSpec::Builder() - .abs_err(MulCpuTpuAbsErr(left, right)) - .strict_signed_zeros() - .skip_comparison( - MulCpuTpuBf16Skip(static_cast(left), - static_cast(right))) - .build(); - }; - } - if (std::is_same_v) { - error_spec_gen = +[](NativeT left, NativeT right) { - return ErrorSpec::Builder() - .abs_err(MulCpuTpuAbsErr(left, right)) - .strict_signed_zeros() - .build(); - }; - } - } - - Run( - AddEmptyBroadcastDimension(Mul), - [](NativeRefT x, NativeRefT y) { return x * y; }, error_spec_gen); + MulOp(this) + .Error(+[](NativeT left, NativeT right) { + return ErrorSpec::Builder().strict_signed_zeros().build(); + }) + .CpuError(+[](NativeT left, NativeT right) { + if constexpr (std::is_same_v) { + return ErrorSpec::Builder() + .abs_err(MulCpuAbsErr(left, right)) + .strict_signed_zeros() + .skip_comparison( + MulCpuBf16Skip(static_cast(left), + static_cast(right))) + .build(); + } + if constexpr (std::is_same_v) { + return ErrorSpec::Builder() + .abs_err(MulCpuAbsErr(left, right)) + .strict_signed_zeros() + .build(); + } + return ErrorSpec::Builder().strict_signed_zeros().build(); + }) + .Run(); }) // Can be thought of as an absolute error of // `<= |std::numeric_limits::::min()|`. template -double DivCpuTpuAbsErr(NativeT left, NativeT right) { +double DivCpuAbsErr(NativeT left, NativeT right) { NativeRefT output = static_cast(left) / static_cast(right); - // Subnormals are flushed to 0 so we add a absolute error margin that is // larger than any subnormal. if (IsSubnormal(output)) { return std::numeric_limits::min(); } - - return 0.0; -} - -template -double DivTpuAbsErr(NativeT left, NativeT right) { - NativeRefT reciprocal = 1.0f / static_cast(right); - NativeT output = left / right; - NativeRefT output_as_native_ref_t = - static_cast(left) / static_cast(right); - - // If we calculate NaN, we don't need to adjust tolerances. - if (std::isnan(output_as_native_ref_t)) { - return 0.0; - } - - // TPUs perform `left * (1 / right)`, where `left` and `1 / right` are - // flushed to `0` if they are subnormal. Also applies to if reciprocal is min - // normal. - if (IsSubnormal(left) || IsSubnormal(reciprocal)) { - // Subnormals can have a larger value in BF16 than float due to rounding to - // the nearest BF16 value during conversion while having less representation - // bits. For normals, the float value is usually always bigger due to - // greater precision. - return std::max(std::abs(output), std::abs(output_as_native_ref_t)); - } - - // For subnormals, we need to set absolute error to the smallest positive - // representable value due to hardware implementations that truncate - // subnormals to zero. - if (IsSubnormal(output)) { - return std::numeric_limits::min(); - } - - return 0.0; -} - -template -double DivTpuBf16F32AbsErr(NativeT left, NativeT right) { - NativeRefT reciprocal = 1.0f / static_cast(right); - NativeT output = left / right; - NativeRefT output_as_native_ref_t = - static_cast(left) / static_cast(right); - - // If we calculate NaN, we don't need to adjust tolerances. - if (std::isnan(output_as_native_ref_t)) { - return 0.0; - } - - // TPUs perform `left * (1 / right)`, where `left` and `1 / right` are - // flushed to `0` if they are subnormal. Also applies to if reciprocal is min - // normal. - if (IsSubnormal(left) || IsSubnormalOrMinNormal(reciprocal)) { - // Subnormals can have a larger value in BF16 than float due to rounding to - // the nearest BF16 value during conversion while having less representation - // bits. For normals, the float value is usually always bigger due to - // greater precision. - return std::max(std::abs(output), std::abs(output_as_native_ref_t)); - } - - // For subnormals, we need to set absolute error to the smallest positive - // representable value due to hardware implementations that truncate - // subnormals to zero. - if (IsSubnormalOrMinNormal(output)) { - return std::numeric_limits::min(); - } - return 0.0; } -template -bool DivTpuBf16F32Skip(NativeT left, NativeT right) { - NativeRefT reciprocal = 1.0f / right; - - // TPU calculates `left * (1 / right)` and flushed `(1 / right)` to `0` when - // it is subnormal or min normal. It also follows the IEEE multiplication spec - // that inf * 0 is NaN. However, IEEE division of infinity by a subnormal is - // infinity, so we must skip comparison. - if (std::isinf(left) && IsSubnormalOrMinNormal(reciprocal)) { - return true; - } - - return false; -} - BINARY_TEST(Div, { - ErrorSpecGen error_spec_gen = +[](NativeT, NativeT) { - return ErrorSpec::Builder().strict_signed_zeros().build(); - }; - - if (IsCpu(platform_) && - (std::is_same_v || - std::is_same_v || std::is_same_v)) { - error_spec_gen = +[](NativeT left, NativeT right) { - return ErrorSpec::Builder() - .abs_err(DivCpuTpuAbsErr(left, right)) - .strict_signed_zeros() - .build(); - }; - } - - if (IsGpu(platform_)) { - if (std::is_same_v) { - error_spec_gen = +[](NativeT, NativeT) { - return ErrorSpec::Builder() - .distance_err(1) - .strict_signed_zeros() - .build(); - }; - } else if (std::is_same_v) { - error_spec_gen = +[](NativeT, NativeT) { - return ErrorSpec::Builder() - .distance_err(2) - .strict_signed_zeros() - .build(); - }; - } - } - - if (IsTpu(platform_)) { - if constexpr (std::is_same_v) { - error_spec_gen = +[](NativeT left, NativeT right) { - return ErrorSpec::Builder() - .abs_err(DivTpuBf16F32AbsErr(left, right)) - .strict_signed_zeros() - .skip_comparison( - DivTpuBf16F32Skip(left, right)) - .build(); - }; - } else if constexpr (std::is_same_v) { - error_spec_gen = +[](NativeT left, NativeT right) { - return ErrorSpec::Builder() - // This is basically distance_err(1), but is tighter because it - // guarantees this only happens when the abs_err is less than min - // normal. - .abs_err(std::numeric_limits::min()) - .strict_signed_zeros() - .build(); - }; - } else if constexpr (std::is_same_v) { - error_spec_gen = +[](NativeT left, NativeT right) { - return ErrorSpec::Builder() - .abs_err(DivTpuAbsErr(left, right)) - .distance_err(2) - .strict_signed_zeros() - .skip_comparison( - DivTpuBf16F32Skip(left, right)) - .build(); - }; - } - } - if (IsPreV6Tpu(platform_)) { - if constexpr (std::is_same_v) { - error_spec_gen = +[](NativeT left, NativeT right) { - NativeT eps = std::numeric_limits::epsilon(); - return ErrorSpec::Builder() - .abs_err(DivTpuAbsErr(left, right)) - .rel_err(34 * eps) - .strict_signed_zeros() - .skip_comparison( - DivTpuBf16F32Skip(left, right)) - .build(); - }; - } - } - if (IsPreV5Tpu(platform_)) { - if constexpr (std::is_same_v) { - error_spec_gen = +[](NativeT left, NativeT right) { - NativeT eps = std::numeric_limits::epsilon(); - return ErrorSpec::Builder() - .abs_err(DivTpuBf16F32AbsErr(left, right)) - .rel_err(eps) - .strict_signed_zeros() - .skip_comparison( - DivTpuBf16F32Skip(left, right)) - .build(); - }; - } else if constexpr (std::is_same_v) { - error_spec_gen = +[](NativeT left, NativeT right) { - NativeT eps = std::numeric_limits::epsilon(); - return ErrorSpec::Builder() - .abs_err(DivTpuAbsErr(left, right)) - .rel_err(136 * eps) - .strict_signed_zeros() - .skip_comparison( - DivTpuBf16F32Skip(left, right)) - .build(); - }; - } - } - - Run( - AddEmptyBroadcastDimension(Div), - [](NativeRefT x, NativeRefT y) { return x / y; }, error_spec_gen); + DivOp(this) + .CpuError(+[](NativeT left, NativeT right) { + if constexpr (std::is_same_v || + std::is_same_v || + std::is_same_v) { + return ErrorSpec::Builder() + .abs_err(DivCpuAbsErr(left, right)) + .strict_signed_zeros() + .build(); + } + return ErrorSpec::Builder().strict_signed_zeros().build(); + }) + .GpuError(+[](NativeT, NativeT) { + if constexpr (std::is_same_v) { + return ErrorSpec::Builder() + .distance_err(1) + .strict_signed_zeros() + .build(); + } else if constexpr (std::is_same_v) { + return ErrorSpec::Builder() + .distance_err(2) + .strict_signed_zeros() + .build(); + } + return ErrorSpec::Builder().strict_signed_zeros().build(); + }) + .Run(); }) // Can be thought of as an absolute error of @@ -388,61 +205,42 @@ double MaxMinCpuAbsErr(NativeT left, NativeT right) { } BINARY_TEST(Max, { - ErrorSpecGen error_spec_gen = +[](NativeT, NativeT) { - return ErrorSpec::Builder().strict_signed_zeros().build(); - }; - - if (IsCpu(platform_) && - (std::is_same_v || - std::is_same_v || std::is_same_v)) { - error_spec_gen = +[](NativeT left, NativeT right) { - return ErrorSpec::Builder() - .abs_err(MaxMinCpuAbsErr(left, right)) - .strict_signed_zeros() - .build(); - }; - } - - if (IsGpu(platform_) || IsTpu(platform_)) { - error_spec_gen = +[](NativeT, NativeT) { - // A100 and H100 return -0 for max(-0,0). - // - // TPUs return -0 for max(0,-0) and 0 for max(-0,0). - return ErrorSpec::Builder().strict_signed_zeros(false).build(); - }; - } - - Run(AddEmptyBroadcastDimension(Max), ReferenceMax, - error_spec_gen); + MaxOp(this) + .CpuError(+[](NativeT left, NativeT right) { + if ((std::is_same_v || + std::is_same_v || + std::is_same_v)) { + return ErrorSpec::Builder() + .abs_err(MaxMinCpuAbsErr(left, right)) + .strict_signed_zeros() + .build(); + } + return ErrorSpec::Builder().strict_signed_zeros().build(); + }) + .GpuError(+[](NativeT, NativeT) { + // A100 and H100 return -0 for max(-0,0). + return ErrorSpec::Builder().strict_signed_zeros(false).build(); + }) + .Run(); }) BINARY_TEST(Min, { - ErrorSpecGen error_spec_gen = +[](NativeT, NativeT) { - return ErrorSpec::Builder().strict_signed_zeros().build(); - }; - - if (IsCpu(platform_) && - (std::is_same_v || - std::is_same_v || std::is_same_v)) { - error_spec_gen = +[](NativeT left, NativeT right) { - return ErrorSpec::Builder() - .abs_err(MaxMinCpuAbsErr(left, right)) - .strict_signed_zeros() - .build(); - }; - } - - if (IsGpu(platform_) || IsTpu(platform_)) { - error_spec_gen = +[](NativeT, NativeT) { - // A100 and H100 return 0 for min(0,-0). - // - // TPUs return 0 for min(-0,0) and -0 for min(0,-0). - return ErrorSpec::Builder().strict_signed_zeros(false).build(); - }; - } - - Run(AddEmptyBroadcastDimension(Min), ReferenceMin, - error_spec_gen); + MinOp(this) + .CpuError(+[](NativeT left, NativeT right) { + if (std::is_same_v || + std::is_same_v || std::is_same_v) { + return ErrorSpec::Builder() + .abs_err(MaxMinCpuAbsErr(left, right)) + .strict_signed_zeros() + .build(); + } + return ErrorSpec::Builder().strict_signed_zeros().build(); + }) + .GpuError(+[](NativeT, NativeT) { + // A100 and H100 return 0 for min(0,-0). + return ErrorSpec::Builder().strict_signed_zeros(false).build(); + }) + .Run(); }) template @@ -476,17 +274,6 @@ double PowCpuBf16F32AbsErr(NativeT left, NativeT right) { return 0.0; } -double PowTpuBf16AbsErr(xla::bfloat16 left, xla::bfloat16 right) { - float output = std::pow(static_cast(left), static_cast(right)); - - // Output is flushed to 0 if subnormal. - if (IsSubnormal(output)) { - return std::numeric_limits::min(); - } - - return 0.0; -} - bool PowCpuF64Skip(double left, double right) { // Hardware returns 0 if right is positive and inf otherwise. if ((IsSubnormal(left) || std::isinf(left) || left == 0) && @@ -509,119 +296,37 @@ bool PowCpuGpuF16Skip(NativeT left, NativeT right) { return false; } -template -bool PowTpuSkip(NativeT left, NativeT right) { - // Hardware always returns 1 if right is 0 (or subnormal due to - // flushing subnormals to zero before the operation), no matter if left is - // NaN. - if (std::isnan(left) && (right == 0.0f || IsSubnormal(right))) { - return true; - } - // Hardware always returns 1 if left is 1, no matter if right is NaN. - if (left == 1.0f && std::isnan(right)) { - return true; - } - - return false; -} - BINARY_TEST(Pow, { - ErrorSpecGen error_spec_gen = +[](NativeT, NativeT) { - return ErrorSpec::Builder().strict_signed_zeros().build(); - }; - - if (IsCpu(platform_)) { - if constexpr (std::is_same_v) { - error_spec_gen = +[](NativeT left, NativeT right) { - return ErrorSpec::Builder() - .strict_signed_zeros() - .skip_comparison(PowCpuGpuF16Skip(left, right)) - .build(); - }; - } else if constexpr (std::is_same_v || - std::is_same_v) { - error_spec_gen = +[](NativeT left, NativeT right) { - return ErrorSpec::Builder() - .abs_err(PowCpuBf16F32AbsErr(left, right)) - .strict_signed_zeros() - .build(); - }; - } else if constexpr (std::is_same_v) { - error_spec_gen = +[](NativeT left, NativeT right) { - return ErrorSpec::Builder() - .strict_signed_zeros() - .skip_comparison(PowCpuF64Skip(static_cast(left), - static_cast(right))) - .build(); - }; - } - } - - if (IsGpu(platform_)) { - error_spec_gen = +[](NativeT left, NativeT right) { - return ErrorSpec::Builder() - .distance_err(1) - .strict_signed_zeros() - .skip_comparison(PowCpuGpuF16Skip(left, right)) - .build(); - }; - } - - if (IsTpu(platform_)) { - if constexpr (std::is_same_v) { - error_spec_gen = +[](NativeT left, NativeT right) { - return ErrorSpec::Builder() - .abs_err(PowTpuBf16AbsErr(static_cast(left), - static_cast(right))) - .distance_err(1) - .strict_signed_zeros() - .skip_comparison(PowTpuSkip(left, right)) - .build(); - }; - } else if constexpr (std::is_same_v) { - error_spec_gen = +[](NativeT left, NativeT right) { + PowOp(this) + .CpuError(+[](NativeT left, NativeT right) { + if constexpr (std::is_same_v) { + return ErrorSpec::Builder() + .strict_signed_zeros() + .skip_comparison(PowCpuGpuF16Skip(left, right)) + .build(); + } else if constexpr (std::is_same_v || + std::is_same_v) { + return ErrorSpec::Builder() + .abs_err(PowCpuBf16F32AbsErr(left, right)) + .strict_signed_zeros() + .build(); + } else if constexpr (std::is_same_v) { + return ErrorSpec::Builder() + .strict_signed_zeros() + .skip_comparison(PowCpuF64Skip(static_cast(left), + static_cast(right))) + .build(); + } + return ErrorSpec::Builder().strict_signed_zeros().build(); + }) + .GpuError(+[](NativeT left, NativeT right) { return ErrorSpec::Builder() .distance_err(1) .strict_signed_zeros() - .skip_comparison(PowTpuSkip(left, right)) - .build(); - }; - } else if constexpr (std::is_same_v) { - error_spec_gen = +[](NativeT left, NativeT right) { - return ErrorSpec::Builder() - .distance_err(8) - .strict_signed_zeros() - .skip_comparison(PowTpuSkip(left, right)) - .build(); - }; - } - } - if (IsPreV6Tpu(platform_)) { - if constexpr (std::is_same_v) { - error_spec_gen = +[](NativeT left, NativeT right) { - NativeT eps = std::numeric_limits::epsilon(); - return ErrorSpec::Builder() - .rel_err(41 * eps) - .strict_signed_zeros() - .skip_comparison(PowTpuSkip(left, right)) - .build(); - }; - } - } - if (IsPreV5Tpu(platform_)) { - if constexpr (std::is_same_v) { - error_spec_gen = +[](NativeT left, NativeT right) { - NativeT eps = std::numeric_limits::epsilon(); - return ErrorSpec::Builder() - .rel_err(44 * eps) - .strict_signed_zeros() - .skip_comparison(PowTpuSkip(left, right)) + .skip_comparison(PowCpuGpuF16Skip(left, right)) .build(); - }; - } - } - - Run(AddEmptyBroadcastDimension(Pow), std::pow, error_spec_gen); + }) + .Run(); }) // Can be thought of as an absolute error of @@ -630,13 +335,11 @@ template double Atan2CpuBf16F32F64AbsErr(NativeT left, NativeT right) { NativeRefT output = std::atan2(static_cast(left), static_cast(right)); - // If the output would be a subnormal float, we allow some error to account // for BF16 implementation flushing subnormals to zero. if (IsSubnormal(output)) { return std::numeric_limits::min(); } - return 0.0; } @@ -648,151 +351,58 @@ bool Atan2CpuBf16F32Skip(NativeT left, NativeT right) { if (IsSubnormal(left) && IsSubnormal(right)) { return true; } - return false; } -template -double Atan2TpuBf16F32AbsErr(NativeT left, NativeT right) { - NativeT output = static_cast(std::atan2(left, right)); - NativeRefT output_as_float = - std::atan2(static_cast(left), static_cast(right)); - - // If the output would be a subnormal float, we allow some error to account - // for BF16 implementation flushing subnormals to zero. TPUs also seem to - // flush the minimum value to 0 along with subnormals. - if (IsSubnormalOrMinNormal(output_as_float)) { - return std::numeric_limits::min(); - } - - // Implementation of Atan2 on TPUs is that they take the reciprocal of the - // larger of left or right. If this is subnormal or the minimum value, the TPU - // flushes it to 0 before using it in multiplication. When this happens, the - // error is the output calculation, either in BF16 or float, or PI/2, - // depending on which of the three is bigger. - NativeRefT reciprocal_as_float = - 1.0f / std::max(std::abs(static_cast(left)), - std::abs(static_cast(right))); - if (!std::isnan(output_as_float) && IsSubnormal(reciprocal_as_float)) { - return std::max({std::abs(output_as_float), std::abs(output), - static_cast(M_PI_2)}); - } - - return 0.0; -} - BINARY_TEST(Atan2, { - auto error_spec_gen = +[](NativeT, NativeT) { - return ErrorSpec::Builder().strict_signed_zeros().build(); - }; - - if (IsCpu(platform_)) { - if constexpr (std::is_same_v) { - error_spec_gen = +[](NativeT left, NativeT right) { - return ErrorSpec::Builder() - .abs_err(Atan2CpuBf16F32F64AbsErr(left, right)) - .strict_signed_zeros() - .skip_comparison(Atan2CpuBf16F32Skip(left, right)) - .build(); - }; - } else if constexpr (std::is_same_v) { - error_spec_gen = +[](NativeT left, NativeT right) { - return ErrorSpec::Builder() - .abs_err(Atan2CpuBf16F32F64AbsErr(left, right)) - // Only used when right is subnormal. - .distance_err(2) - .strict_signed_zeros() - .skip_comparison(Atan2CpuBf16F32Skip(left, right)) - .build(); - }; - } else if constexpr (std::is_same_v) { - error_spec_gen = +[](NativeT left, NativeT right) { - return ErrorSpec::Builder() - .abs_err(Atan2CpuBf16F32F64AbsErr(left, right)) - .strict_signed_zeros() - .build(); - }; - } - } - - if (IsGpu(platform_)) { - if constexpr (std::is_same_v || - std::is_same_v) { - error_spec_gen = +[](NativeT, NativeT) { - return ErrorSpec::Builder() - .distance_err(1) - .strict_signed_zeros() - .build(); - }; - } else if constexpr (std::is_same_v) { - error_spec_gen = +[](NativeT, NativeT) { - return ErrorSpec::Builder() - .distance_err(3) - .strict_signed_zeros() - .build(); - }; - } else if constexpr (std::is_same_v) { - error_spec_gen = +[](NativeT, NativeT) { - return ErrorSpec::Builder() - .distance_err(2) - .strict_signed_zeros() - .build(); - }; - } - } - - if (IsTpu(platform_)) { - if constexpr (std::is_same_v) { - error_spec_gen = +[](NativeT left, NativeT right) { - return ErrorSpec::Builder() - .abs_err(Atan2TpuBf16F32AbsErr(left, right)) - .distance_err(1) - .strict_signed_zeros() - .build(); - }; - } else if constexpr (std::is_same_v) { - error_spec_gen = +[](NativeT left, NativeT right) { - return ErrorSpec::Builder() - .distance_err(1) - .strict_signed_zeros() - .build(); - }; - } else if constexpr (std::is_same_v) { - error_spec_gen = +[](NativeT left, NativeT right) { - return ErrorSpec::Builder() - .abs_err(Atan2TpuBf16F32AbsErr(left, right)) - .distance_err(3) - .strict_signed_zeros() - .build(); - }; - } - } - if (IsPreV6Tpu(platform_)) { - if constexpr (std::is_same_v) { - error_spec_gen = +[](NativeT left, NativeT right) { - NativeT eps = std::numeric_limits::epsilon(); - return ErrorSpec::Builder() - .abs_err(Atan2TpuBf16F32AbsErr(left, right)) - .rel_err(28 * eps) - .strict_signed_zeros() - .build(); - }; - } - } - if (IsPreV5Tpu(platform_)) { - if constexpr (std::is_same_v) { - error_spec_gen = +[](NativeT left, NativeT right) { - NativeT eps = std::numeric_limits::epsilon(); - return ErrorSpec::Builder() - .abs_err(Atan2TpuBf16F32AbsErr(left, right)) - .rel_err(133 * eps) - .strict_signed_zeros() - .build(); - }; - } - } - - Run(AddEmptyBroadcastDimension(Atan2), std::atan2, error_spec_gen); + Atan2Op(this) + .CpuError([](NativeT left, NativeT right) { + if constexpr (std::is_same_v) { + return ErrorSpec::Builder() + .abs_err( + Atan2CpuBf16F32F64AbsErr(left, right)) + .strict_signed_zeros() + .skip_comparison(Atan2CpuBf16F32Skip(left, right)) + .build(); + } else if constexpr (std::is_same_v) { + return ErrorSpec::Builder() + .abs_err( + Atan2CpuBf16F32F64AbsErr(left, right)) + // Only used when right is subnormal. + .distance_err(2) + .strict_signed_zeros() + .skip_comparison(Atan2CpuBf16F32Skip(left, right)) + .build(); + } else if constexpr (std::is_same_v) { + return ErrorSpec::Builder() + .abs_err( + Atan2CpuBf16F32F64AbsErr(left, right)) + .strict_signed_zeros() + .build(); + } + return ErrorSpec::Builder().strict_signed_zeros().build(); + }) + .GpuError(+[](NativeT, NativeT) { + if constexpr (std::is_same_v || + std::is_same_v) { + return ErrorSpec::Builder() + .distance_err(1) + .strict_signed_zeros() + .build(); + } else if constexpr (std::is_same_v) { + return ErrorSpec::Builder() + .distance_err(3) + .strict_signed_zeros() + .build(); + } else if constexpr (std::is_same_v) { + return ErrorSpec::Builder() + .distance_err(2) + .strict_signed_zeros() + .build(); + } + return ErrorSpec::Builder().strict_signed_zeros().build(); + }) + .Run(); }) // Can be thought of as an absolute error of @@ -810,26 +420,7 @@ double AbsComplexCpuAbsErr(NativeRefT real, NativeRefT imag) { template bool AbsComplexSkip(NativeRefT real, NativeRefT imag) { // TODO(timshen): see b/162664705. - if (std::isnan(real) || std::isnan(imag)) { - return true; - } - return false; -} - -template -double AbsComplexTpuRelErr(NativeRefT real, NativeRefT imag) { - NativeRefT abs_max = std::max(std::abs(real), std::abs(imag)); - NativeRefT kOne(1); - NativeRefT reciprocal = kOne / abs_max; - if (IsSubnormal(reciprocal)) { - // In this case, the reciprocal erroneously returns zero, and - // we get max(|real|, |imag|) instead of sqrt(real^2 + imag^2), - // so the relative error can be as large as (sqrt(2)-1)/sqrt(2) ~= 0.293, - // when using the typical hypot implementation hypot(max, min) = max * - // sqrt(1 + min / max). - return 0.293; - } - return 0.0; + return std::isnan(real) || std::isnan(imag); } // It is more convenient to implement Abs(complex) as a binary op than a unary @@ -838,65 +429,33 @@ double AbsComplexTpuRelErr(NativeRefT real, NativeRefT imag) { // TODO(bixia): May want to move this test to unary test if we will be able to // implement Abs(complex) as unary conveniently. BINARY_TEST_COMPLEX(AbsComplex, { - ErrorSpecGen error_spec_gen = +[](NativeRefT, NativeRefT) { - return ErrorSpec::Builder().strict_signed_zeros().build(); - }; - - if (IsCpu(platform_)) { - if constexpr (std::is_same_v || - std::is_same_v) { - error_spec_gen = +[](NativeRefT real, NativeRefT imag) { - return ErrorSpec::Builder() - .abs_err(AbsComplexCpuAbsErr(real, imag)) - .distance_err(2) - .skip_comparison(AbsComplexSkip(real, imag)) - .build(); - }; - } - } - - if (IsGpu(platform_)) { - if constexpr (std::is_same_v) { - error_spec_gen = +[](NativeRefT real, NativeRefT imag) { - return ErrorSpec::Builder() - .distance_err(3) - .skip_comparison(AbsComplexSkip(real, imag)) - .build(); - }; - } else if constexpr (std::is_same_v) { - error_spec_gen = +[](NativeRefT real, NativeRefT imag) { - return ErrorSpec::Builder() - .distance_err(2) - .skip_comparison(AbsComplexSkip(real, imag)) - .build(); - }; - } - } - - if (IsTpu(platform_)) { - error_spec_gen = +[](NativeRefT real, NativeRefT imag) { - return ErrorSpec::Builder() - .rel_err(AbsComplexTpuRelErr(real, imag)) - .distance_err(3) - .skip_comparison(AbsComplexSkip(real, imag)) - .build(); - }; - } - if (IsPreV6Tpu(platform_)) { - error_spec_gen = +[](NativeRefT real, NativeRefT imag) { - return ErrorSpec::Builder() - .rel_err(AbsComplexTpuRelErr(real, imag)) - .distance_err(125) - .skip_comparison(AbsComplexSkip(real, imag)) - .build(); - }; - } - - Run([](XlaOp x, XlaOp y) { return Abs(Complex(x, y)); }, - [](NativeRefT x, NativeRefT y) { - return std::abs(std::complex(x, y)); - }, - error_spec_gen); + AbsComplexOp(this) + .CpuError(+[](NativeRefT real, NativeRefT imag) { + if constexpr (std::is_same_v || + std::is_same_v) { + return ErrorSpec::Builder() + .abs_err(AbsComplexCpuAbsErr(real, imag)) + .distance_err(2) + .skip_comparison(AbsComplexSkip(real, imag)) + .build(); + } + return ErrorSpec::Builder().strict_signed_zeros().build(); + }) + .GpuError(+[](NativeRefT real, NativeRefT imag) { + if constexpr (std::is_same_v) { + return ErrorSpec::Builder() + .distance_err(3) + .skip_comparison(AbsComplexSkip(real, imag)) + .build(); + } else if constexpr (std::is_same_v) { + return ErrorSpec::Builder() + .distance_err(2) + .skip_comparison(AbsComplexSkip(real, imag)) + .build(); + } + return ErrorSpec::Builder().strict_signed_zeros().build(); + }) + .Run(); }) } // namespace diff --git a/third_party/xla/xla/tests/exhaustive/exhaustive_binary_test_ops.inc b/third_party/xla/xla/tests/exhaustive/exhaustive_binary_test_ops.inc new file mode 100644 index 00000000000000..611eb2c231aa46 --- /dev/null +++ b/third_party/xla/xla/tests/exhaustive/exhaustive_binary_test_ops.inc @@ -0,0 +1,80 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#define DEFINE_BINARY_TEST_OP(NAME, ENQUEUE, EVALUATE) \ + template \ + class NAME final : public BinaryTestOp { \ + public: \ + using Traits = BinaryTestOp::Traits; \ + using Test = BinaryTestOp::Test; \ + \ + explicit NAME(Test* test) : BinaryTestOp(test) {} \ + ~NAME() override {} \ + \ + Traits::EnqueueOp EnqueueOp() const override ENQUEUE; \ + \ + Traits::EvaluateOp EvaluateOp() const override EVALUATE; \ + }; \ + static_assert(true, "") + +DEFINE_BINARY_TEST_OP( + AddOp, { return AddEmptyBroadcastDimension(Add); }, + { + return +[](typename Traits::NativeRefT x, typename Traits::NativeRefT y) { + return x + y; + }; + }); +DEFINE_BINARY_TEST_OP( + SubOp, { return AddEmptyBroadcastDimension(Sub); }, + { + return +[](typename Traits::NativeRefT x, typename Traits::NativeRefT y) { + return x - y; + }; + }); +DEFINE_BINARY_TEST_OP( + MulOp, { return AddEmptyBroadcastDimension(Mul); }, + { + return +[](typename Traits::NativeRefT x, typename Traits::NativeRefT y) { + return x * y; + }; + }); +DEFINE_BINARY_TEST_OP( + DivOp, { return AddEmptyBroadcastDimension(Div); }, + { + return +[](typename Traits::NativeRefT x, typename Traits::NativeRefT y) { + return x / y; + }; + }); +DEFINE_BINARY_TEST_OP( + MaxOp, { return AddEmptyBroadcastDimension(Max); }, + { return ReferenceMax; }); +DEFINE_BINARY_TEST_OP( + MinOp, { return AddEmptyBroadcastDimension(Min); }, + { return ReferenceMin; }); +DEFINE_BINARY_TEST_OP( + PowOp, { return AddEmptyBroadcastDimension(Pow); }, { return std::pow; }); +DEFINE_BINARY_TEST_OP( + Atan2Op, { return AddEmptyBroadcastDimension(Atan2); }, + { return std::atan2; }); +DEFINE_BINARY_TEST_OP( + AbsComplexOp, + { return +[](XlaOp x, XlaOp y) { return Abs(Complex(x, y)); }; }, + { + return +[](typename Traits::NativeRefT x, typename Traits::NativeRefT y) { + return std::abs(std::complex(x, y)); + }; + }); + +#undef DEFINE_BINARY_TEST_OP \ No newline at end of file diff --git a/third_party/xla/xla/tests/exhaustive/exhaustive_op_test.h b/third_party/xla/xla/tests/exhaustive/exhaustive_op_test.h new file mode 100644 index 00000000000000..524ab7f53fb289 --- /dev/null +++ b/third_party/xla/xla/tests/exhaustive/exhaustive_op_test.h @@ -0,0 +1,74 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_TESTS_EXHAUSTIVE_EXHAUSTIVE_OP_TEST_H_ +#define XLA_TESTS_EXHAUSTIVE_EXHAUSTIVE_OP_TEST_H_ + +#include + +#include "xla/tests/exhaustive/exhaustive_op_test_base.h" +#include "xla/tests/exhaustive/platform.h" + +namespace xla { +namespace exhaustive_op_test { + +// openXLA-specific ExhaustiveOpTestBase subclass. +// +// Holds utility functions related to determining the execution platform. +// +// Type Parameters: +// - T: The primitive type being tested. +// - N: The number of operands that the function being tested takes. +// +// Pure Virtual Functions: +// - GetInputSize +// - FillInput +template +class ExhaustiveOpTest : public ExhaustiveOpTestBase { + public: + using Traits = ExhaustiveOpTestBase::Traits; + + ExhaustiveOpTest() : platform_(*this->client_->platform()) {} + + bool RelaxedDenormalSigns() const override { + return !platform_.IsNvidiaGpu(); + } + + const Platform& Platform() { return platform_; } + + // DEPRECATED: Only kept until exhaustive_unary_complex_test is merged into + // exhaustive_unary_test. Use the new TestOp framework for + // exhaustive_unary_test. + bool IsGpu() const { return platform_.IsGpu(); } + bool IsCpu() const { return platform_.IsCpu(); } + + static typename Traits::ErrorSpecGen GetDefaultSpecGenerator() { + return exhaustive_op_test::GetDefaultSpecGenerator(); + } + + protected: + const class Platform platform_; +}; + +template +using ExhaustiveUnaryTest = ExhaustiveOpTest; + +template +using ExhaustiveBinaryTest = ExhaustiveOpTest; + +} // namespace exhaustive_op_test +} // namespace xla + +#endif // XLA_TESTS_EXHAUSTIVE_EXHAUSTIVE_OP_TEST_H_ diff --git a/third_party/xla/xla/tests/exhaustive/exhaustive_op_test_base.cc b/third_party/xla/xla/tests/exhaustive/exhaustive_op_test_base.cc index b18bbec9c5ad04..1e393f20078e32 100644 --- a/third_party/xla/xla/tests/exhaustive/exhaustive_op_test_base.cc +++ b/third_party/xla/xla/tests/exhaustive/exhaustive_op_test_base.cc @@ -60,10 +60,6 @@ limitations under the License. namespace xla { namespace exhaustive_op_test { -int eup_version = 0; - -int GetEupVersion() { return eup_version; } - bool dump_values = false; bool ShouldDumpValues() { return dump_values; } @@ -198,7 +194,6 @@ int GetCacheLocation(const std::array& input) { } // The inverse function of GetCacheLocation. - template ::value>::type* = nullptr> RetT FromCacheLocationComponent(int cache_loc) { @@ -568,11 +563,11 @@ ExhaustiveOpTestBase::GetTestValuesWithSubnormalSubstitutions( ComponentNativeRefT value) { std::vector test_values; if (std::fpclassify(value) == FP_SUBNORMAL) { - test_values.reserve(relaxed_denormal_signs_ ? 3 : 2); + test_values.reserve(RelaxedDenormalSigns() ? 3 : 2); test_values.push_back(std::copysign(0, value)); test_values.push_back( std::copysign(std::numeric_limits::min(), value)); - if (relaxed_denormal_signs_) { + if (RelaxedDenormalSigns()) { test_values.push_back(std::copysign(0, -value)); } } else { diff --git a/third_party/xla/xla/tests/exhaustive/exhaustive_op_test_base.h b/third_party/xla/xla/tests/exhaustive/exhaustive_op_test_base.h index 9f90e64f9bc392..2e8b4ab16e9aa6 100644 --- a/third_party/xla/xla/tests/exhaustive/exhaustive_op_test_base.h +++ b/third_party/xla/xla/tests/exhaustive/exhaustive_op_test_base.h @@ -20,7 +20,6 @@ limitations under the License. #include #include #include -#include #include #include @@ -40,16 +39,11 @@ limitations under the License. namespace xla { namespace exhaustive_op_test { -// Access this through GetEupVersion. -extern int eup_version; - -// Get the TPU EUP version (if it was provided). -int GetEupVersion(); - // Return if the user specified dumping all tested values with their expected // and actual results. bool ShouldDumpValues(); +// Add all extra CLI flags that are used by ExhaustiveOpTestBase. void AddExhaustiveFlags(std::vector& flag_list); // Base class from which all exhaustive tests should inherit. @@ -60,10 +54,16 @@ void AddExhaustiveFlags(std::vector& flag_list); // Type Parameters: // - T: The primitive type being tested. // - N: The number of operands that the function being tested takes. +// +// Pure Virtual Functions: +// - GetInputSize +// - FillInput +// - RelaxedDenormalSigns template class ExhaustiveOpTestBase : public ClientLibraryTestBase { public: using Traits = ExhaustiveOpTestTraits; + static constexpr PrimitiveType kT = Traits::kT; using NativeT = typename Traits::NativeT; using NativeRefT = typename Traits::NativeRefT; @@ -85,10 +85,7 @@ class ExhaustiveOpTestBase : public ClientLibraryTestBase { using ErrorSpecGen = typename Traits::ErrorSpecGen; ExhaustiveOpTestBase() - : ty_(T), - platform_(client_->platform()->Name()), - eup_version_(xla::exhaustive_op_test::GetEupVersion()), - should_dump_values_(xla::exhaustive_op_test::ShouldDumpValues()) { + : should_dump_values_(xla::exhaustive_op_test::ShouldDumpValues()) { SetFastMathDisabled(true); // Run all HLO passes. In particular, constant folding is disabled by @@ -105,8 +102,7 @@ class ExhaustiveOpTestBase : public ClientLibraryTestBase { // uint64_t. This function is used to convert such a bit pattern stored as // uint64_t to the input value for T. static ComponentNativeT ConvertValue(uint64_t bits) { - using I = ComponentIntegralNativeT; - I used_bits = static_cast(bits); + auto used_bits = static_cast(bits); return BitCast(used_bits); } @@ -116,13 +112,22 @@ class ExhaustiveOpTestBase : public ClientLibraryTestBase { // Fills the literals with values to test for. virtual void FillInput(LiteralInputs* literals) = 0; + // If true, allows denormals to be flushed to non-sign-preserving 0. + // + // For example, normally we'd expect sqrt(-denormal) to be either nan (sqrt of + // a negative number) or -inf (flush the denormal to sign-preserving zero, + // then sqrt(-0)). When true, we'll also accept 0 (sqrt(0)). + // + // XLA:GPU preserves denormal signs, but other backends don't. + virtual bool RelaxedDenormalSigns() const = 0; + // Enable debug logging for the invocation of the lambda. // - // This is intended to be used to wrap a call to `Run`, which will then log - // extra debug information for a failure such as the calculated absolute, - // relative, and distance errors. In addition, in an effort to reduce output - // log size, this will trigger an ASSERT failure to early return from a test - // at the first failure. + // This is intended to be used to wrap a call to `Run`, which will then + // log extra debug information for a failure such as the calculated + // absolute, relative, and distance errors. In addition, in an effort to + // reduce output log size, this will trigger an ASSERT failure to early + // return from a test at the first failure. template , int> = 0> void EnableDebugLoggingForScope(Callable&& work) { @@ -218,41 +223,7 @@ class ExhaustiveOpTestBase : public ClientLibraryTestBase { ErrorSpecGen error_spec_gen, OutputRangeCheck check_valid_range = nullptr); - const std::string& Platform() { return platform_; } - - bool IsGpu(const std::string& platform) const { return platform == "CUDA"; } - bool IsCpu(const std::string& platform) const { return platform == "Host"; } - bool IsTpu(const std::string& platform) const { - return !IsGpu(platform) && !IsCpu(platform); - } - - int EupVersion() const { return eup_version_; } - bool IsPreV5Tpu(const std::string& platform) const { - return IsTpu(platform) && eup_version_ < 2; - } - bool IsPreV6Tpu(const std::string& platform) const { - return IsTpu(platform) && eup_version_ < 3; - } - protected: - // The primitive type being tested. - const PrimitiveType ty_; - - // The platform under test. - const std::string platform_; - - // Version of the EUP for a TPU target. Only relevant for TPU platforms. - const int eup_version_; - - // If true, allows denormals to be flushed to non-sign-preserving 0. - // - // For example, normally we'd expect sqrt(-denormal) to be either nan (sqrt of - // a negative number) or -inf (flush the denormal to sign-preserving zero, - // then sqrt(-0)). But with this as true, we'll also accept 0 (sqrt(0)). - // - // XLA:GPU preserves denormal signs, but other backends don't. - bool relaxed_denormal_signs_ = platform_ != "CUDA"; - // Indicates if files of the expected and actual values should be dumped. bool should_dump_values_ = false; @@ -261,24 +232,6 @@ class ExhaustiveOpTestBase : public ClientLibraryTestBase { bool should_emit_debug_logging_ = false; }; -template -class ExhaustiveUnaryTest : public ExhaustiveOpTestBase { - public: - static typename ExhaustiveOpTestTraits::ErrorSpecGen - GetDefaultSpecGenerator() { - return exhaustive_op_test::GetDefaultSpecGenerator(); - } -}; - -template -class ExhaustiveBinaryTest : public ExhaustiveOpTestBase { - public: - static typename ExhaustiveOpTestTraits::ErrorSpecGen - GetDefaultSpecGenerator() { - return exhaustive_op_test::GetDefaultSpecGenerator(); - } -}; - } // namespace exhaustive_op_test } // namespace xla diff --git a/third_party/xla/xla/tests/exhaustive/exhaustive_op_test_utils.cc b/third_party/xla/xla/tests/exhaustive/exhaustive_op_test_utils.cc index aa1a501a73b42c..c8e34d4ba8a880 100644 --- a/third_party/xla/xla/tests/exhaustive/exhaustive_op_test_utils.cc +++ b/third_party/xla/xla/tests/exhaustive/exhaustive_op_test_utils.cc @@ -15,11 +15,40 @@ limitations under the License. #include "xla/tests/exhaustive/exhaustive_op_test_utils.h" +#include + +#include "xla/tests/exhaustive/error_spec.h" #include "xla/types.h" namespace xla { namespace exhaustive_op_test { +template +/* static */ typename ExhaustiveOpTestTraits::ErrorSpecGen +ExhaustiveOpTestTraits::FallbackErrorSpecGen() { + if constexpr (N == 1) { + return +[](NativeT) { return ErrorSpec{}; }; + } else if constexpr (N == 2) { + return +[](NativeT, NativeT) { return ErrorSpec{}; }; + } else { + static_assert(false, + "ExhaustieOpTestTraits::DefaultErrorSpecGen() is only " + "implemented for N == 1 and N == 2."); + } +} + +template class ExhaustiveOpTestTraits; +template class ExhaustiveOpTestTraits; +template class ExhaustiveOpTestTraits; +template class ExhaustiveOpTestTraits; +template class ExhaustiveOpTestTraits; +template class ExhaustiveOpTestTraits; + +template class ExhaustiveOpTestTraits; +template class ExhaustiveOpTestTraits; +template class ExhaustiveOpTestTraits; +template class ExhaustiveOpTestTraits; + bool IsSubnormalReal(xla::complex64 value) { return IsSubnormal(value.real()); } bool IsSubnormalReal(xla::complex128 value) { diff --git a/third_party/xla/xla/tests/exhaustive/exhaustive_op_test_utils.h b/third_party/xla/xla/tests/exhaustive/exhaustive_op_test_utils.h index 79200f6107d8df..62bf6786330ec4 100644 --- a/third_party/xla/xla/tests/exhaustive/exhaustive_op_test_utils.h +++ b/third_party/xla/xla/tests/exhaustive/exhaustive_op_test_utils.h @@ -22,6 +22,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -117,6 +118,12 @@ class ExhaustiveOpTestTraits { N == 1, ErrorSpec (*)(NativeT), std::conditional_t>>; + + // Returns an ErrorSpecGen that sets no error tolerances. + // + // The intention of this default is to force test writers to tighten bounds at + // least somewhat and not rely on overly large default tolerances. + static ErrorSpecGen FallbackErrorSpecGen(); }; template @@ -231,6 +238,21 @@ typename ExhaustiveOpTestTraits::ErrorSpecGen GetDefaultSpecGenerator() { DefaultSpecGenerator); } +template +typename Traits::ErrorSpecGen PickFirstErrorSpecGenPresent( + std::initializer_list error_specs) { + typename Traits::ErrorSpecGen ret = Traits::FallbackErrorSpecGen(); + for (auto it = error_specs.begin(); it != error_specs.end(); it++) { + // Check if the ErrorSpecGen is nullptr to indicate it is not set. Replace + // ret with the first non-nullptr ErrorSpecGen. + if (*it != nullptr) { + ret = *it; + break; + } + } + return ret; +} + // Determines if the real component of the complex number is subnormal (either // sign). // @@ -288,7 +310,7 @@ bool IsMinNormal(NativeT value) { std::is_same_v) { return IsMinNormalReal(value) || IsMinNormalImaginary(value); } else { - return std::abs(value) == std::numeric_limits::min(); + return std::abs(value) == std::numeric_limits::min(); // NOLINT } } @@ -800,7 +822,7 @@ T ReferenceMin(T x, T y) { inline std::function AddEmptyBroadcastDimension( std::function)> build_method) { - return [&](XlaOp src0, XlaOp src1) -> XlaOp { + return [build_method](XlaOp src0, XlaOp src1) -> XlaOp { return build_method(src0, src1, {}); }; } diff --git a/third_party/xla/xla/tests/exhaustive/exhaustive_unary_complex_test.cc b/third_party/xla/xla/tests/exhaustive/exhaustive_unary_complex_test.cc index bdc0ca990268ef..f93f5e31f544b0 100644 --- a/third_party/xla/xla/tests/exhaustive/exhaustive_unary_complex_test.cc +++ b/third_party/xla/xla/tests/exhaustive/exhaustive_unary_complex_test.cc @@ -24,6 +24,7 @@ limitations under the License. #include "xla/client/xla_builder.h" #include "xla/literal.h" #include "xla/tests/exhaustive/error_spec.h" +#include "xla/tests/exhaustive/exhaustive_op_test.h" #include "xla/tests/exhaustive/exhaustive_op_test_base.h" #include "xla/tests/exhaustive/exhaustive_op_test_utils.h" #include "xla/tests/test_macros.h" @@ -160,7 +161,7 @@ UNARY_TEST_COMPLEX_64(Rsqrt, { return ErrorSpec::Builder().strict_signed_zeros().build(); }; - if (IsCpu(platform_)) { + if (IsCpu()) { error_spec_gen = +[](complex64 x) { return ErrorSpec::Builder() .abs_err(RsqrtCpuGpuAbsErr(x)) @@ -171,7 +172,7 @@ UNARY_TEST_COMPLEX_64(Rsqrt, { }; } - if (IsGpu(platform_)) { + if (IsGpu()) { error_spec_gen = +[](complex64 x) { return ErrorSpec::Builder() .abs_err(RsqrtCpuGpuAbsErr(x)) @@ -251,7 +252,7 @@ UNARY_TEST_COMPLEX_128(Log, { return ErrorSpec::Builder().strict_signed_zeros().build(); }; - if (IsCpu(platform_) || IsGpu(platform_)) { + if (IsCpu() || IsGpu()) { error_spec_gen = +[](complex128 x) { // TODO(rmlarsen): see b/162664705 and b/138578594 bool should_skip = std::isnan(x.real()) || std::isnan(x.imag()); diff --git a/third_party/xla/xla/tests/exhaustive/exhaustive_unary_test_definitions.h b/third_party/xla/xla/tests/exhaustive/exhaustive_unary_test_definitions.h index 0c5e839c71aa69..2b3fa8f3c34a3a 100644 --- a/third_party/xla/xla/tests/exhaustive/exhaustive_unary_test_definitions.h +++ b/third_party/xla/xla/tests/exhaustive/exhaustive_unary_test_definitions.h @@ -16,152 +16,24 @@ limitations under the License. #ifndef XLA_TESTS_EXHAUSTIVE_EXHAUSTIVE_UNARY_TEST_DEFINITIONS_H_ #define XLA_TESTS_EXHAUSTIVE_EXHAUSTIVE_UNARY_TEST_DEFINITIONS_H_ -#include -#include -#include -#include - -#include "absl/log/check.h" -#include "absl/log/log.h" -#include "absl/types/span.h" -#include "xla/literal.h" -#include "xla/tests/exhaustive/exhaustive_op_test_base.h" -#include "xla/tests/exhaustive/exhaustive_op_test_utils.h" -#include "xla/tests/test_macros.h" -#include "tsl/platform/test.h" +#include // IWYU pragma: keep, exhaustive_unary_test_definitions.inc +#include // IWYU pragma: keep, exhaustive_unary_test_definitions.inc +#include // IWYU pragma: keep, exhaustive_unary_test_definitions.inc +#include // IWYU pragma: keep, exhaustive_unary_test_definitions.inc + +#include "absl/log/check.h" // IWYU pragma: keep, exhaustive_unary_test_definitions.inc +#include "absl/log/log.h" // IWYU pragma: keep, exhaustive_unary_test_definitions.inc +#include "absl/types/span.h" // IWYU pragma: keep, exhaustive_unary_test_definitions.inc +#include "xla/literal.h" // IWYU pragma: keep, exhaustive_unary_test_definitions.inc +#include "xla/tests/exhaustive/exhaustive_op_test.h" // IWYU pragma: keep, exhaustive_unary_test_definitions.inc +#include "xla/tests/exhaustive/exhaustive_op_test_utils.h" // IWYU pragma: keep, exhaustive_unary_test_definitions.inc +#include "xla/tests/test_macros.h" // IWYU pragma: keep, exhaustive_unary_test_definitions.inc +#include "tsl/platform/test.h" // IWYU pragma: keep, exhaustive_unary_test_definitions.inc namespace xla { namespace exhaustive_op_test { -// Exhaustive test for unary operations for <= 32bit floating point types. -// -// Test parameter is a tuple containing -// - primitive type under test, -// - (begin, end) range under test, as zero-extended int64_ts bitcast to the -// primitive type under test. -template -class Exhaustive32BitOrLessUnaryTest - : public ExhaustiveUnaryTest, - public ::testing::WithParamInterface> { - public: - // Sets error parameters appropriately for testing tan. - void SetParamsForTan(); - - protected: - using typename ExhaustiveUnaryTest::NativeT; - - private: - int64_t GetInputSize() override { - auto [begin, end] = GetParam(); - return end - begin; - } - - // Generates all the input values for the test. The range of the bit - // representation of the input values is described by the test parameter as - // a pair of int64_t representing the starting bit pattern and the ending - // pattern. Each bit representation is first truncated to the integral type of - // the same bit as the type being tested, if needed, and then bitcasted to the - // type being tested. - void FillInput(std::array* input_literal) override { - using IntegralT = - typename ExhaustiveOpTestBase::ComponentIntegralNativeT; - - auto [begin, end] = GetParam(); - if (VLOG_IS_ON(2)) { - // N.B.: Use INFO directly instead of doing another thread-safe VLOG - // check. - LOG(INFO) << this->SuiteName() << this->TestName() << " Range:"; - LOG(INFO) << "\tfrom=" << begin << "; hex=" << std::hex << begin - << "; float=" << *reinterpret_cast(&begin) - << " (inclusive)"; - LOG(INFO) << "\tto=" << end << "; hex=" << std::hex << end - << "; float=" << *reinterpret_cast(&end) - << " (exclusive)"; - LOG(INFO) << "\ttotal values to test=" << (end - begin); - } - - int64_t input_size = (*input_literal)[0].element_count(); - CHECK_EQ(input_size, end - begin); - - absl::Span input_arr = (*input_literal)[0].data(); - for (int64_t i = 0; i < input_size; i++) { - IntegralT input_val = i + begin; - input_arr[i] = this->ConvertValue(input_val); - } - } -}; - -using ExhaustiveF32UnaryTest = Exhaustive32BitOrLessUnaryTest; -using ExhaustiveF16UnaryTest = Exhaustive32BitOrLessUnaryTest; -using ExhaustiveBF16UnaryTest = Exhaustive32BitOrLessUnaryTest; - -// Exhaustive test for unary operations for double. -// -// Test parameter is a tuple containing -// - primitive type under test, -// - FpValues representing a set of double values. -class ExhaustiveF64UnaryTest : public ExhaustiveUnaryTest, - public ::testing::WithParamInterface { - private: - int64_t GetInputSize() override { - FpValues values = GetParam(); - return values.GetTotalNumValues(); - } - - void FillInput(std::array* input_literal) override { - FpValues fp_values = GetParam(); - int64_t input_size = (*input_literal)[0].element_count(); - if (VLOG_IS_ON(2)) { - // N.B.: Use INFO directly instead of doing another thread-safe VLOG - // check. - LOG(INFO) << this->SuiteName() << this->TestName() << " Values:"; - LOG(INFO) << "\t" << fp_values.ToString(); - LOG(INFO) << "\ttotal values to test=" << input_size; - } - - uint64_t i = 0; - absl::Span input_arr = (*input_literal)[0].data(); - for (auto bits : fp_values) { - input_arr[i] = this->ConvertValue(bits); - ++i; - } - CHECK_EQ(i, input_size); - } -}; - -#ifdef XLA_BACKEND_SUPPORTS_BFLOAT16 -#define UNARY_TEST_BF16(test_name, ...) \ - XLA_TEST_P(ExhaustiveBF16UnaryTest, test_name) \ - __VA_ARGS__ -#else -#define UNARY_TEST_BF16(test_name, ...) -#endif - -#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16) -#define UNARY_TEST_F16(test_name, ...) \ - XLA_TEST_P(ExhaustiveF16UnaryTest, test_name) \ - __VA_ARGS__ -#else -#define UNARY_TEST_F16(test_name, ...) -#endif - -#define UNARY_TEST_F32(test_name, ...) \ - XLA_TEST_P(ExhaustiveF32UnaryTest, test_name) \ - __VA_ARGS__ - -#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64) -#define UNARY_TEST_F64(test_name, ...) \ - XLA_TEST_P(ExhaustiveF64UnaryTest, test_name) \ - __VA_ARGS__ -#else -#define UNARY_TEST_F64(test_name, ...) -#endif - -#define UNARY_TEST(test_name, ...) \ - UNARY_TEST_BF16(test_name, __VA_ARGS__) \ - UNARY_TEST_F16(test_name, __VA_ARGS__) \ - UNARY_TEST_F32(test_name, __VA_ARGS__) \ - UNARY_TEST_F64(test_name, __VA_ARGS__) +#include "xla/tests/exhaustive/exhaustive_unary_test_definitions.inc" } // namespace exhaustive_op_test } // namespace xla diff --git a/third_party/xla/xla/tests/exhaustive/exhaustive_unary_test_definitions.inc b/third_party/xla/xla/tests/exhaustive/exhaustive_unary_test_definitions.inc new file mode 100644 index 00000000000000..64f491fbb7bcb7 --- /dev/null +++ b/third_party/xla/xla/tests/exhaustive/exhaustive_unary_test_definitions.inc @@ -0,0 +1,140 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Exhaustive test for unary operations for <= 32bit floating point types. +// +// Test parameter is a tuple containing +// - primitive type under test, +// - (begin, end) range under test, as zero-extended int64_ts bitcast to the +// primitive type under test. +template +class Exhaustive32BitOrLessUnaryTest + : public ExhaustiveUnaryTest, + public ::testing::WithParamInterface> { + protected: + int64_t GetInputSize() override { + auto [begin, end] = GetParam(); + return end - begin; + } + + // Generates all the input values for the test. The range of the bit + // representation of the input values is described by the test parameter as + // a pair of int64_t representing the starting bit pattern and the ending + // pattern. Each bit representation is first truncated to the integral type of + // the same bit as the type being tested, if needed, and then bitcasted to the + // type being tested. + void FillInput(std::array* input_literal) override { + using NativeT = typename ExhaustiveUnaryTest::NativeT; + using ComponentIntegralNativeT = + typename ExhaustiveUnaryTest::ComponentIntegralNativeT; + + auto [begin, end] = GetParam(); + if (VLOG_IS_ON(2)) { + // N.B.: Use INFO directly instead of doing another thread-safe VLOG + // check. + LOG(INFO) << this->SuiteName() << this->TestName() << " Range:"; + LOG(INFO) << "\tfrom=" << begin << "; hex=" << std::hex << begin + << "; float=" << *reinterpret_cast(&begin) + << " (inclusive)"; + LOG(INFO) << "\tto=" << end << "; hex=" << std::hex << end + << "; float=" << *reinterpret_cast(&end) + << " (exclusive)"; + LOG(INFO) << "\ttotal values to test=" << (end - begin); + } + + int64_t input_size = (*input_literal)[0].element_count(); + CHECK_EQ(input_size, end - begin); + + absl::Span input_arr = (*input_literal)[0].data(); + for (int64_t i = 0; i < input_size; i++) { + ComponentIntegralNativeT input_val = + // We guarantee i + begin will be within range. + static_cast(i + begin); + input_arr[i] = this->ConvertValue(input_val); + } + } +}; + +using ExhaustiveF32UnaryTest = Exhaustive32BitOrLessUnaryTest; +using ExhaustiveF16UnaryTest = Exhaustive32BitOrLessUnaryTest; +using ExhaustiveBF16UnaryTest = Exhaustive32BitOrLessUnaryTest; + +// Exhaustive test for unary operations for double. +// +// Test parameter is a tuple containing +// - primitive type under test, +// - FpValues representing a set of double values. +class ExhaustiveF64UnaryTest : public ExhaustiveUnaryTest, + public ::testing::WithParamInterface { + protected: + int64_t GetInputSize() override { + FpValues values = GetParam(); + return values.GetTotalNumValues(); + } + + void FillInput(std::array* input_literal) override { + FpValues fp_values = GetParam(); + int64_t input_size = (*input_literal)[0].element_count(); + if (VLOG_IS_ON(2)) { + // N.B.: Use INFO directly instead of doing another thread-safe VLOG + // check. + LOG(INFO) << this->SuiteName() << this->TestName() << " Values:"; + LOG(INFO) << "\t" << fp_values.ToString(); + LOG(INFO) << "\ttotal values to test=" << input_size; + } + + uint64_t i = 0; + absl::Span input_arr = (*input_literal)[0].data(); + for (auto bits : fp_values) { + input_arr[i] = ExhaustiveOpTestBase::ConvertValue(bits); + ++i; + } + CHECK_EQ(i, input_size); + } +}; + +#ifdef XLA_BACKEND_SUPPORTS_BFLOAT16 +#define UNARY_TEST_BF16(test_name, ...) \ + XLA_TEST_P(ExhaustiveBF16UnaryTest, test_name) \ + __VA_ARGS__ +#else +#define UNARY_TEST_BF16(test_name, ...) +#endif + +#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16) +#define UNARY_TEST_F16(test_name, ...) \ + XLA_TEST_P(ExhaustiveF16UnaryTest, test_name) \ + __VA_ARGS__ +#else +#define UNARY_TEST_F16(test_name, ...) +#endif + +#define UNARY_TEST_F32(test_name, ...) \ + XLA_TEST_P(ExhaustiveF32UnaryTest, test_name) \ + __VA_ARGS__ + +#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64) +#define UNARY_TEST_F64(test_name, ...) \ + XLA_TEST_P(ExhaustiveF64UnaryTest, test_name) \ + __VA_ARGS__ +#else +#define UNARY_TEST_F64(test_name, ...) +#endif + +#define UNARY_TEST(test_name, ...) \ + UNARY_TEST_BF16(test_name, __VA_ARGS__) \ + UNARY_TEST_F16(test_name, __VA_ARGS__) \ + UNARY_TEST_F32(test_name, __VA_ARGS__) \ + UNARY_TEST_F64(test_name, __VA_ARGS__) diff --git a/third_party/xla/xla/tests/exhaustive/exhaustive_unary_test_f32_and_smaller_instantiation.cc b/third_party/xla/xla/tests/exhaustive/exhaustive_unary_test_f32_and_smaller_instantiation.cc index 69ab47f64412ed..f252590a0edfd0 100644 --- a/third_party/xla/xla/tests/exhaustive/exhaustive_unary_test_f32_and_smaller_instantiation.cc +++ b/third_party/xla/xla/tests/exhaustive/exhaustive_unary_test_f32_and_smaller_instantiation.cc @@ -13,34 +13,17 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include +#include // IWYU pragma: keep, exhaustive_unary_test_f32_and_smaller_instantiation.inc -#include "xla/tests/exhaustive/exhaustive_op_test_utils.h" -#include "xla/tests/exhaustive/exhaustive_unary_test_definitions.h" -#include "tsl/platform/test.h" +#include "xla/tests/exhaustive/exhaustive_op_test_utils.h" // IWYU pragma: keep, exhaustive_unary_test_f32_and_smaller_instantiation.inc +#include "xla/tests/exhaustive/exhaustive_unary_test_definitions.h" // IWYU pragma: keep, exhaustive_unary_test_f32_and_smaller_instantiation.inc +#include "tsl/platform/test.h" // IWYU pragma: keep, exhaustive_unary_test_f32_and_smaller_instantiation.inc namespace xla { namespace exhaustive_op_test { namespace { -#ifdef XLA_BACKEND_SUPPORTS_BFLOAT16 -INSTANTIATE_TEST_SUITE_P(BF16, ExhaustiveBF16UnaryTest, - ::testing::Values(std::make_pair(0, 1 << 16))); -#else -GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveBF16UnaryTest); -#endif - -#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16) -INSTANTIATE_TEST_SUITE_P(F16, ExhaustiveF16UnaryTest, - ::testing::Values(std::make_pair(0, 1 << 16))); -#else -GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF16UnaryTest); -#endif - -INSTANTIATE_TEST_SUITE_P(F32, ExhaustiveF32UnaryTest, - ::testing::ValuesIn(CreateExhaustiveF32Ranges())); - -GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF64UnaryTest); +#include "xla/tests/exhaustive/exhaustive_unary_test_f32_and_smaller_instantiation.inc" } // namespace } // namespace exhaustive_op_test diff --git a/third_party/xla/xla/tests/exhaustive/exhaustive_unary_test_f32_and_smaller_instantiation.inc b/third_party/xla/xla/tests/exhaustive/exhaustive_unary_test_f32_and_smaller_instantiation.inc new file mode 100644 index 00000000000000..a958e2bbc88c74 --- /dev/null +++ b/third_party/xla/xla/tests/exhaustive/exhaustive_unary_test_f32_and_smaller_instantiation.inc @@ -0,0 +1,33 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifdef XLA_BACKEND_SUPPORTS_BFLOAT16 +INSTANTIATE_TEST_SUITE_P(BF16, ExhaustiveBF16UnaryTest, + ::testing::Values(std::make_pair(0, 1 << 16))); +#else +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveBF16UnaryTest); +#endif + +#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16) +INSTANTIATE_TEST_SUITE_P(F16, ExhaustiveF16UnaryTest, + ::testing::Values(std::make_pair(0, 1 << 16))); +#else +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF16UnaryTest); +#endif + +INSTANTIATE_TEST_SUITE_P(F32, ExhaustiveF32UnaryTest, + ::testing::ValuesIn(CreateExhaustiveF32Ranges())); + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF64UnaryTest); diff --git a/third_party/xla/xla/tests/exhaustive/exhaustive_unary_test_f64_instantiation.cc b/third_party/xla/xla/tests/exhaustive/exhaustive_unary_test_f64_instantiation.cc index 9b94a7ced85959..6271809d97df5f 100644 --- a/third_party/xla/xla/tests/exhaustive/exhaustive_unary_test_f64_instantiation.cc +++ b/third_party/xla/xla/tests/exhaustive/exhaustive_unary_test_f64_instantiation.cc @@ -13,37 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/tests/exhaustive/exhaustive_op_test_utils.h" -#include "xla/tests/exhaustive/exhaustive_unary_test_definitions.h" -#include "tsl/platform/test.h" +#include "xla/tests/exhaustive/exhaustive_op_test_utils.h" // IWYU pragma: keep, exhaustive_unary_test_f64_instantiation.inc +#include "xla/tests/exhaustive/exhaustive_unary_test_definitions.h" // IWYU pragma: keep, exhaustive_unary_test_f64_instantiation.inc +#include "tsl/platform/test.h" // IWYU pragma: keep, exhaustive_unary_test_f64_instantiation.inc namespace xla { namespace exhaustive_op_test { namespace { -GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveBF16UnaryTest); - -GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF16UnaryTest); - -GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF32UnaryTest); - -#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64) -INSTANTIATE_TEST_SUITE_P( - SpecialValues, ExhaustiveF64UnaryTest, - ::testing::ValuesIn(CreateFpValuesForBoundaryTest())); - -INSTANTIATE_TEST_SUITE_P(NormalValues, ExhaustiveF64UnaryTest, - ::testing::Values(GetNormals(1000))); - -// Tests a total of 4,000,000,000 inputs, with 16,000,000 inputs in each -// sub-test, to keep the peak memory usage low. -INSTANTIATE_TEST_SUITE_P( - LargeAndSmallMagnitudeNormalValues, ExhaustiveF64UnaryTest, - ::testing::ValuesIn(GetFpValuesForMagnitudeExtremeNormals( - 4000000000ull, 16000000))); -#else -GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF64UnaryTest); -#endif // !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64) +#include "xla/tests/exhaustive/exhaustive_unary_test_f64_instantiation.inc" } // namespace } // namespace exhaustive_op_test diff --git a/third_party/xla/xla/tests/exhaustive/exhaustive_unary_test_f64_instantiation.inc b/third_party/xla/xla/tests/exhaustive/exhaustive_unary_test_f64_instantiation.inc new file mode 100644 index 00000000000000..b558fb85f3f8e8 --- /dev/null +++ b/third_party/xla/xla/tests/exhaustive/exhaustive_unary_test_f64_instantiation.inc @@ -0,0 +1,38 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveBF16UnaryTest); + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF16UnaryTest); + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF32UnaryTest); + +#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64) +INSTANTIATE_TEST_SUITE_P( + SpecialValues, ExhaustiveF64UnaryTest, + ::testing::ValuesIn(CreateFpValuesForBoundaryTest())); + +INSTANTIATE_TEST_SUITE_P(NormalValues, ExhaustiveF64UnaryTest, + ::testing::Values(GetNormals(1000))); + +// Tests a total of 4,000,000,000 inputs, with 16,000,000 inputs in each +// sub-test, to keep the peak memory usage low. +INSTANTIATE_TEST_SUITE_P( + LargeAndSmallMagnitudeNormalValues, ExhaustiveF64UnaryTest, + ::testing::ValuesIn(GetFpValuesForMagnitudeExtremeNormals( + 4000000000ull, 16000000))); +#else +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF64UnaryTest); +#endif // !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64) diff --git a/third_party/xla/xla/tests/exhaustive/exhaustive_unary_test_functions.cc b/third_party/xla/xla/tests/exhaustive/exhaustive_unary_test_functions.cc index e07742874f6e98..9320e0c9ccedbc 100644 --- a/third_party/xla/xla/tests/exhaustive/exhaustive_unary_test_functions.cc +++ b/third_party/xla/xla/tests/exhaustive/exhaustive_unary_test_functions.cc @@ -13,23 +13,22 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include -#include +#include // IWYU pragma: keep, exhaustive_unary_test_ops.inc +#include // IWYU pragma: keep, exhaustive_unary_test_ops.inc #include // NOLINT #include -#include -#include +#include // IWYU pragma: keep, exhaustive_unary_test_ops.inc #include #include -#include "xla/client/lib/constants.h" -#include "xla/client/lib/math.h" -#include "xla/client/xla_builder.h" +#include "xla/client/lib/constants.h" // IWYU pragma: keep, exhaustive_unary_test_ops.inc +#include "xla/client/lib/math.h" // IWYU pragma: keep, exhaustive_unary_test_ops.inc +#include "xla/client/xla_builder.h" // IWYU pragma: keep, exhaustive_unary_test_ops.inc #include "xla/tests/exhaustive/error_spec.h" -#include "xla/tests/exhaustive/exhaustive_op_test_base.h" +#include "xla/tests/exhaustive/exhaustive_op_test.h" // IWYU pragma: keep, exhaustive_unary_test_ops.inc #include "xla/tests/exhaustive/exhaustive_op_test_utils.h" #include "xla/tests/exhaustive/exhaustive_unary_test_definitions.h" -#include "xla/types.h" +#include "xla/tests/exhaustive/test_op.h" // IWYU pragma: keep, exhaustive_unary_test_ops.inc #ifdef __FAST_MATH__ #error "Can't be compiled with fast math on" @@ -39,624 +38,279 @@ namespace xla { namespace exhaustive_op_test { namespace { -using Eigen::half; +#include "xla/tests/exhaustive/exhaustive_unary_test_ops.inc" -template -T EvaluatePolynomial(T x, const std::array& coeffs) { - // Evaluate the polynomial as accurately as we can using double precision and - // FMA. - double result = 0; - double x_d = static_cast(x); - for (T c : coeffs) { - result = std::fma(result, x_d, static_cast(c)); - } - return static_cast(result); -} +UNARY_TEST(Log, { LogOp(this).Error(GetDefaultSpecGenerator()).Run(); }) +UNARY_TEST(Log1p, { Log1pOp(this).Error(GetDefaultSpecGenerator()).Run(); }) -// There's no std::erfinv, so we have to implement it ourselves. This follows -// Wichura 1998, https://www.jstor.org/stable/2347330 which, notably, is a -// different implementation from that in math.cc. -template -NativeRefT HostErfInv(NativeRefT x) { - std::array kPolyA = { - 8.8709406962545514830200e2, 1.1819493347062294404278e4, - 2.3782041382114385731252e4, 1.6235862515167575384252e4, - 4.8548868893843886794648e3, 6.9706266534389598238465e2, - 4.7072688112383978012285e1, 1.1975323115670912564578e0, - }; - std::array kPolyB = { - 5.2264952788528545610e3, 2.8729085735721942674e4, 3.9307895800092710610e4, - 2.1213794301586595867e4, 5.3941960214247511077e3, 6.8718700749205790830e2, - 4.2313330701600911252e1, 1.0000000000000000000e0, - }; - std::array kPolyC = { - 7.74545014278341407640e-4, 2.27238449892691845833e-2, - 2.41780725177450611770e-1, 1.27045825245236838258e0, - 3.64784832476320460504e0, 5.76949722146069140550e0, - 4.63033784615654529590e0, 1.42343711074968357734e0, - }; - std::array kPolyD = { - 1.4859850019840355905497876e-9, 7.7441459065157709165577218e-4, - 2.1494160384252876777097297e-2, 2.0945065210512749128288442e-1, - 9.7547832001787427186894837e-1, 2.3707661626024532365971225e0, - 2.9036514445419946173133295e0, 1.4142135623730950488016887e0, - }; - std::array kPolyE = { - 2.01033439929228813265e-7, 2.71155556874348757815e-5, - 1.24266094738807843860e-3, 2.65321895265761230930e-2, - 2.96560571828504891230e-1, 1.78482653991729133580e0, - 5.46378491116411436990e0, 6.65790464350110377720e0, - }; - std::array kPolyF = { - 2.891024605872965461538222e-15, 2.010321207683943062279931e-7, - 2.611088405080593625138020e-5, 1.112800997078859844711555e-3, - 2.103693768272068968719679e-2, 1.936480946950659106176712e-1, - 8.482908416595164588112026e-1, 1.414213562373095048801689e0, - }; - - if (std::abs(x) > 1 || std::isnan(x)) { - return std::numeric_limits::quiet_NaN(); - } - if (std::abs(x) == 1) { - return std::copysign(std::numeric_limits::infinity(), x); - } - - double unsigned_result = [&] { - double y = std::abs(x); - if (y <= 0.85) { - double r = 0.180625 - 0.25 * y * y; - return (y * EvaluatePolynomial(r, kPolyA)) / - EvaluatePolynomial(r, kPolyB); - } else { - double r = std::sqrt(std::log(2.0) - std::log1p(-y)); - if (r <= 5.0) { - r -= 1.6; - return EvaluatePolynomial(r, kPolyC) / EvaluatePolynomial(r, kPolyD); - } else { - r -= 5; - return EvaluatePolynomial(r, kPolyE) / EvaluatePolynomial(r, kPolyF); - } - } - }(); - return static_cast(std::copysign(unsigned_result, x)); -} - -// Digamma implementation using a polynomial from Cephes. Notably this is a -// different implementation from the one in math.cc. -template -NativeRefT HostDigamma(NativeRefT x) { - // Euler-Mascheroni constant - double kGamma = 0.57721566490153286061; - double kPi = M_PI; - - std::array kPoly = { - -4.16666666666666666667E-3, - 3.96825396825396825397E-3, - -8.33333333333333333333E-3, - 8.33333333333333333333E-2, - }; - - double reflection = 0; - if (x <= 0) { - double floor = std::floor(x); - if (x == floor) { - return std::numeric_limits::quiet_NaN(); - } - // Compute reflection term, pi * cot(pi * x). - reflection = x - floor; - if (reflection == 0.5) { - reflection = 0; - } else { - if (reflection > 0.5) { - reflection = x - (floor + 1.0f); - } - reflection = kPi / std::tan(kPi * reflection); - } - x = 1 - x; - } - - double result = 0; - if (x <= 10 && x == std::floor(x)) { - // Special case for integers <= 10. - for (int i = 1; i < x; ++i) { - result += 1.0 / i; - } - result -= kGamma; - } else { - double w = 0; - for (; x < 10; ++x) { - w += 1.0 / x; - } - if (x < 1e8) { - double z = 1.0 / (x * x); - result = z * EvaluatePolynomial(z, kPoly); - } - result = std::log(x) - 0.5 / x - result - w; - } - - // Compute the final, reflected value. - return static_cast(result - reflection); -} - -UNARY_TEST(Log, { - auto error_spec_gen = GetDefaultSpecGenerator(); - if (IsPreV6Tpu(platform_)) { - error_spec_gen = +[](NativeT x) { - NativeT eps = std::numeric_limits::epsilon(); - return ErrorSpec::Builder().abs_err(2e-4).rel_err(eps).build(); - }; - } - Run(Log, std::log, error_spec_gen); -}) -UNARY_TEST(Log1p, { - auto error_spec_gen = GetDefaultSpecGenerator(); - if (IsTpu(platform_)) { - error_spec_gen = +[](NativeT x) { - NativeT eps = std::numeric_limits::epsilon(); - return ErrorSpec::Builder().abs_err(2e-4).rel_err(eps).build(); - }; - } - Run(Log1p, std::log1p, error_spec_gen); -}) -UNARY_TEST(Exp, { - auto error_spec_gen = GetDefaultSpecGenerator(); - if (IsPreV6Tpu(platform_)) { - error_spec_gen = +[](NativeT x) { - NativeT min = std::numeric_limits::min(); - NativeT eps = std::numeric_limits::epsilon(); - return ErrorSpec::Builder().abs_err(min).rel_err(75 * eps).build(); - }; - } else if (IsTpu(platform_)) { - error_spec_gen = +[](NativeT x) { - NativeT min = std::numeric_limits::min(); - NativeT eps = std::numeric_limits::epsilon(); - return ErrorSpec::Builder().abs_err(min).rel_err(33 * eps).build(); - }; - } - Run(Exp, std::exp, error_spec_gen); -}) - -UNARY_TEST(Expm1, { - auto error_spec_gen = GetDefaultSpecGenerator(); - if (IsPreV6Tpu(platform_)) { - error_spec_gen = +[](NativeT x) { - NativeT eps = std::numeric_limits::epsilon(); - return ErrorSpec::Builder().abs_err(1e-5).rel_err(100 * eps).build(); - }; - } else if (IsTpu(platform_)) { - error_spec_gen = +[](NativeT x) { - NativeT min = std::numeric_limits::min(); - NativeT eps = std::numeric_limits::epsilon(); - return ErrorSpec::Builder().abs_err(2 * min).rel_err(33 * eps).build(); - }; - } - - Run(Expm1, std::expm1, error_spec_gen); -}) +UNARY_TEST(Exp, { ExpOp(this).Error(GetDefaultSpecGenerator()).Run(); }) +UNARY_TEST(Expm1, { Expm1Op(this).Error(GetDefaultSpecGenerator()).Run(); }) UNARY_TEST(Logistic, { - // FIXME(rmlarsen): Break into region around zero and everything else. - auto error_spec_gen = GetDefaultSpecGenerator(); - if (IsTpu(platform_)) { - error_spec_gen = +[](NativeT x) { - NativeT eps = std::numeric_limits::epsilon(); - NativeT atol = std::min(static_cast(0.004), - static_cast(200 * eps)); - return ErrorSpec::Builder().abs_err(atol).rel_err(0).build(); - }; - } - EvaluateOp fn = +[](NativeRefT x) { return 1.0f / (1.0f + std::exp(-x)); }; - auto range_checker = +[](NativeInputs in, NativeT out) { - if (Eigen::numext::isnan(in[0])) { - return Eigen::numext::isnan(out); - } - return Eigen::numext::abs(out) <= 1.0f; - }; - Run(Logistic, fn, error_spec_gen, range_checker); + LogisticOp(this) + .OutputRangeCheck(+[](NativeInputs in, NativeT out) { + if (std::isnan(in[0])) { + return std::isnan(out); + } + return std::abs(out) <= 1.0f; + }) + // FIXME(rmlarsen): Break into region around zero and everything else. + .Error(GetDefaultSpecGenerator()) + .Run(); }) // It feels a little overkill to exhaustively test sqrt and pow(x, 0.5), but // this *did* find a bug, namely that some backends were assuming sqrt(x) == // pow(x, 0.5), but this is not true for x == -inf. -UNARY_TEST(PowOneHalf, { - EvaluateOp fn = +[](NativeRefT x) { return std::pow(x, 0.5f); }; - Run([](XlaOp x) { return Pow(x, ScalarLike(x, 0.5)); }, fn); -}) - +UNARY_TEST(PowOneHalf, + { PowOneHalfOp(this).Error(GetDefaultSpecGenerator()).Run(); }) UNARY_TEST(Rsqrt, { - auto error_spec_gen = +[](NativeT x) { - NativeT eps = std::numeric_limits::epsilon(); - return ErrorSpec::Builder() - .abs_err(0) - .rel_err(2 * eps) - .strict_signed_zeros() - .build(); - }; - Run(Rsqrt, +[](NativeRefT x) { return 1 / std::sqrt(x); }, error_spec_gen); + RsqrtOp(this) + .Error(+[](NativeT x) { + NativeT eps = std::numeric_limits::epsilon(); + return ErrorSpec::Builder() + .abs_err(0) + .rel_err(2 * eps) + .strict_signed_zeros() + .build(); + }) + .Run(); }) - UNARY_TEST(Sqrt, { - auto error_spec_gen = +[](NativeT x) { - NativeT eps = std::numeric_limits::epsilon(); - return ErrorSpec::Builder() - .abs_err(0) - .rel_err(2 * eps) - .strict_signed_zeros() - .build(); - }; - Run(Sqrt, std::sqrt, error_spec_gen); -}) - -UNARY_TEST(Cbrt, { - auto error_spec_gen = +[](NativeT x) { - NativeT eps = std::numeric_limits::epsilon(); - return ErrorSpec::Builder() - .abs_err(0) - .rel_err(16 * eps) - .strict_signed_zeros() - .build(); - }; - if (IsCpu(platform_)) { - error_spec_gen = +[](NativeT x) { - // While GPUs and TPUs flush subnormal inputs to zero, the CPU returns a - // relatively inaccurate approximation for such inputs. Therefore we - // allow a small absolute error (e.g. ~9e-16 for F32). This corresponds - // to a 0.5% relative error for the smallest normalized floating point - // values, increasing gradually to 100% for the smallest subnormal - // value. - NativeT denorm_min = std::numeric_limits::denorm_min(); - double abs_err = std::cbrt(denorm_min); - - if constexpr (std::is_same_v) { + SqrtOp(this) + .Error(+[](NativeT x) { NativeT eps = std::numeric_limits::epsilon(); return ErrorSpec::Builder() - .abs_err(abs_err) - .rel_err(70 * eps) + .abs_err(0) + .rel_err(2 * eps) .strict_signed_zeros() .build(); - } else { + }) + .Run(); +}) +UNARY_TEST(Cbrt, { + CbrtOp(this) + .Error(+[](NativeT x) { NativeT eps = std::numeric_limits::epsilon(); return ErrorSpec::Builder() - .abs_err(abs_err) - .rel_err(10 * eps) + .abs_err(0) + .rel_err(16 * eps) .strict_signed_zeros() .build(); - } - }; - } - Run(Cbrt, std::cbrt, error_spec_gen); + }) + .CpuError(+[](NativeT x) { + // While GPUs flush subnormal inputs to zero, CPU returns a relatively + // inaccurate approximation for such inputs. Therefore we allow a small + // absolute error (e.g. ~9e-16 for F32). This corresponds to a 0.5% + // relative error for the smallest normalized floating point values, + // increasing gradually to 100% for the smallest subnormal value. + NativeT denorm_min = std::numeric_limits::denorm_min(); + double abs_err = std::cbrt(denorm_min); + + if constexpr (std::is_same_v) { + NativeT eps = std::numeric_limits::epsilon(); + return ErrorSpec::Builder() + .abs_err(abs_err) + .rel_err(70 * eps) + .strict_signed_zeros() + .build(); + } else { + NativeT eps = std::numeric_limits::epsilon(); + return ErrorSpec::Builder() + .abs_err(abs_err) + .rel_err(10 * eps) + .strict_signed_zeros() + .build(); + } + }) + .Run(); }) // Tests for inverse hyperbolic functions. UNARY_TEST(Acosh, { - auto error_spec_gen = +[](NativeT x) { - NativeT eps = std::numeric_limits::epsilon(); - return ErrorSpec::Builder().abs_err(1e-7).rel_err(50 * eps).build(); - }; - if (IsTpu(platform_)) { - error_spec_gen = +[](NativeT x) { - NativeT eps = std::numeric_limits::epsilon(); - return ErrorSpec::Builder().abs_err(2e-4).rel_err(eps).build(); - }; - } - Run(Acosh, std::acosh, error_spec_gen); -}) -UNARY_TEST(Asinh, { - auto error_spec_gen = GetDefaultSpecGenerator(); - if (IsTpu(platform_)) { - error_spec_gen = +[](NativeT x) { - NativeT eps = std::numeric_limits::epsilon(); - return ErrorSpec::Builder().abs_err(2e-4).rel_err(eps).build(); - }; - } - Run(Asinh, std::asinh, error_spec_gen); -}) -UNARY_TEST(Atanh, { - auto error_spec_gen = GetDefaultSpecGenerator(); - if (IsTpu(platform_)) { - error_spec_gen = +[](NativeT x) { - NativeT eps = std::numeric_limits::epsilon(); - return ErrorSpec::Builder().abs_err(1e-4).rel_err(eps).build(); - }; - } - Run(Atanh, std::atanh, error_spec_gen); + AcoshOp(this) + .Error(+[](NativeT x) { + NativeT eps = std::numeric_limits::epsilon(); + return ErrorSpec::Builder().abs_err(1e-7).rel_err(50 * eps).build(); + }) + .Run(); }) +UNARY_TEST(Asinh, { AsinhOp(this).Error(GetDefaultSpecGenerator()).Run(); }) +UNARY_TEST(Atanh, { AtanhOp(this).Error(GetDefaultSpecGenerator()).Run(); }) // Tests for inverse trigonometric functions. UNARY_TEST(Acos, { - auto error_spec_gen = GetDefaultSpecGenerator(); - if (platform_ != "Host") { - error_spec_gen = +[](NativeT x) { - NativeT eps = std::numeric_limits::epsilon(); - return ErrorSpec::Builder().abs_err(1e-6).rel_err(10 * eps).build(); - }; - } - Run(Acos, std::acos, error_spec_gen); + AcosOp(this) + .Error(GetDefaultSpecGenerator()) + .GpuError(+[](NativeT x) { + NativeT eps = std::numeric_limits::epsilon(); + return ErrorSpec::Builder().abs_err(1e-6).rel_err(10 * eps).build(); + }) + .Run(); }) UNARY_TEST(Asin, { - auto error_spec_gen = +[](NativeT x) { - NativeT min = std::numeric_limits::min(); - NativeT eps = std::numeric_limits::epsilon(); - return ErrorSpec::Builder().abs_err(2.0f * min).rel_err(10 * eps).build(); - }; - Run(Asin, std::asin, error_spec_gen); + AsinOp(this) + .Error(+[](NativeT x) { + NativeT min = std::numeric_limits::min(); + NativeT eps = std::numeric_limits::epsilon(); + return ErrorSpec::Builder() + .abs_err(2.0f * min) + .rel_err(10 * eps) + .build(); + }) + .Run(); }) UNARY_TEST(Atan, { - auto error_spec_gen = +[](NativeT x) { - NativeT min = std::numeric_limits::min(); - NativeT eps = std::numeric_limits::epsilon(); - return ErrorSpec::Builder().abs_err(2.0f * min).rel_err(20 * eps).build(); - }; - Run(Atan, std::atan, error_spec_gen); + AtanOp(this) + .Error(+[](NativeT x) { + NativeT min = std::numeric_limits::min(); + NativeT eps = std::numeric_limits::epsilon(); + return ErrorSpec::Builder() + .abs_err(2.0f * min) + .rel_err(20 * eps) + .build(); + }) + .Run(); }) UNARY_TEST(Cosh, { - auto error_spec_gen = GetDefaultSpecGenerator(); - if (IsTpu(platform_)) { - error_spec_gen = +[](NativeT x) { - NativeT eps = std::numeric_limits::epsilon(); - // Cosh is always greater than or equal to 1, so an absolute - // tolerance does not make sense. - return ErrorSpec::Builder().abs_err(0).rel_err(100 * eps).build(); - }; - } - auto range_checker = - +[](NativeInputs in, NativeT actual) { return !(actual < 1); }; - Run(Cosh, std::cosh, error_spec_gen, range_checker); -}) - -UNARY_TEST(Sinh, { - auto error_spec_gen = GetDefaultSpecGenerator(); - if (IsTpu(platform_)) { - error_spec_gen = +[](NativeT x) { - NativeT eps = std::numeric_limits::epsilon(); - return ErrorSpec::Builder().abs_err(1e-5).rel_err(100 * eps).build(); - }; - } - Run(Sinh, std::sinh, error_spec_gen); + CoshOp(this) + .Error(GetDefaultSpecGenerator()) + .OutputRangeCheck( + +[](NativeInputs in, NativeT actual) { return !(actual < 1); }) + .Run(); }) - +UNARY_TEST(Sinh, { SinhOp(this).Error(GetDefaultSpecGenerator()).Run(); }) UNARY_TEST(Tanh, { - auto error_spec_gen = GetDefaultSpecGenerator(); - if (IsPreV6Tpu(platform_)) { - error_spec_gen = +[](NativeT x) { - // The range of tanh is [-1:1], so no point in giving a relative - // tolerance when we have an absolute one. - return ErrorSpec::Builder().abs_err(5e-4).rel_err(0).build(); - }; - } - Run(Tanh, std::tanh, error_spec_gen, - [](NativeInputs in, NativeT out) -> bool { - if (Eigen::numext::isnan(in[0])) { - return Eigen::numext::isnan(out); + TanhOp(this) + .Error(GetDefaultSpecGenerator()) + .OutputRangeCheck([](NativeInputs in, NativeT out) -> bool { + if (std::isnan(in[0])) { + return std::isnan(out); } - return Eigen::numext::abs(out) <= 1.0f; - }); + return std::abs(out) <= 1.0f; + }) + .Run(); }) UNARY_TEST(Cos, { - auto range_checker = - +[](NativeInputs in, NativeT out) { return !(out < -1 || out > 1); }; - Run( - Cos, std::cos, - +[](NativeT) { + CosOp(this) + .Error(+[](NativeT) { // This error spec corresponds to a maximum relative error of 2 ULP. NativeT eps = std::numeric_limits::epsilon(); return ErrorSpec::Builder().abs_err(0).rel_err(2 * eps).build(); - }, - range_checker); + }) + .OutputRangeCheck( + +[](NativeInputs in, NativeT out) { return !(out < -1 || out > 1); }) + .Run(); }) - UNARY_TEST(Sin, { - auto range_checker = - +[](NativeInputs in, NativeT out) { return !(out < -1 || out > 1); }; - Run( - Sin, std::sin, - +[](NativeT) { + SinOp(this) + .Error(+[](NativeT) { // This error spec corresponds to a maximum relative error of 2 ULP. NativeT eps = std::numeric_limits::epsilon(); return ErrorSpec::Builder().abs_err(0).rel_err(2 * eps).build(); - }, - range_checker); + }) + .OutputRangeCheck( + +[](NativeInputs in, NativeT out) { return !(out < -1 || out > 1); }) + .Run(); }) - UNARY_TEST(Tan, { - Run( - Tan, std::tan, +[](NativeT) { + TanOp(this) + .Error(+[](NativeT) { // This error spec corresponds to a maximum relative error of 4 ULP. NativeT eps = std::numeric_limits::epsilon(); return ErrorSpec::Builder().abs_err(0).rel_err(4 * eps).build(); - }); + }) + .Run(); }) -UNARY_TEST(Erf, { Run(Erf, std::erf); }) +UNARY_TEST(Erf, { ErfOp(this).Error(GetDefaultSpecGenerator()).Run(); }) UNARY_TEST(Erfc, { - auto error_spec_gen = +[](NativeT x) { - NativeT min = std::numeric_limits::min(); - NativeT eps = std::numeric_limits::epsilon(); - return ErrorSpec::Builder().abs_err(2 * min).rel_err(50 * eps).build(); - }; - if (IsTpu(platform_)) { - error_spec_gen = +[](NativeT x) { - NativeT min = std::numeric_limits::min(); - NativeT eps = std::numeric_limits::epsilon(); - return ErrorSpec::Builder().abs_err(2 * min).rel_err(100 * eps).build(); - }; - } - Run(Erfc, std::erfc, error_spec_gen); + ErfcOp(this) + .Error(+[](NativeT x) { + NativeT min = std::numeric_limits::min(); + NativeT eps = std::numeric_limits::epsilon(); + return ErrorSpec::Builder().abs_err(2 * min).rel_err(50 * eps).build(); + }) + .Run(); }) UNARY_TEST(ErfInv, { - auto error_spec_gen = +[](NativeT x) { - NativeT min = std::numeric_limits::min(); - NativeT eps = std::numeric_limits::epsilon(); - return ErrorSpec::Builder().abs_err(2 * min).rel_err(50 * eps).build(); - }; - if (IsTpu(platform_)) { - error_spec_gen = +[](NativeT x) { - NativeT eps = std::numeric_limits::epsilon(); - return ErrorSpec::Builder().abs_err(5e-5).rel_err(2 * eps).build(); - }; - } - Run(ErfInv, HostErfInv, error_spec_gen); + ErfInvOp(this) + .Error(+[](NativeT x) { + NativeT min = std::numeric_limits::min(); + NativeT eps = std::numeric_limits::epsilon(); + return ErrorSpec::Builder().abs_err(2 * min).rel_err(50 * eps).build(); + }) + .Run(); }) UNARY_TEST(Digamma, { - auto error_spec_gen = +[](NativeT x) { - NativeT eps = std::numeric_limits::epsilon(); - return ErrorSpec::Builder().abs_err(2e-5).rel_err(10 * eps).build(); - }; - if (IsTpu(platform_)) { - error_spec_gen = +[](NativeT x) { - NativeT eps = std::numeric_limits::epsilon(); - return ErrorSpec::Builder().abs_err(2e-4).rel_err(10 * eps).build(); - }; - } - Run(Digamma, HostDigamma, error_spec_gen); + DigammaOp(this) + .Error(+[](NativeT x) { + NativeT eps = std::numeric_limits::epsilon(); + return ErrorSpec::Builder().abs_err(2e-5).rel_err(10 * eps).build(); + }) + .Run(); }) UNARY_TEST(Lgamma, { - auto error_spec_gen = +[](NativeT x) { - NativeT eps = std::numeric_limits::epsilon(); - return ErrorSpec::Builder().abs_err(1e-5).rel_err(150 * eps).build(); - }; - if (IsGpu(platform_)) { - error_spec_gen = +[](NativeT x) { - if constexpr (std::is_same_v) { - // Very large error on the smallest subnormal input. - if (static_cast(std::abs(x)) == 4.9406564584124654e-324) { - return ErrorSpec::Builder().abs_err(0.05).build(); + LgammaOp(this) + .Error(+[](NativeT x) { + NativeT eps = std::numeric_limits::epsilon(); + return ErrorSpec::Builder().abs_err(1e-5).rel_err(150 * eps).build(); + }) + .GpuError(+[](NativeT x) { + if constexpr (std::is_same_v) { + // Very large error on the smallest subnormal input. + if (static_cast(std::abs(x)) == 4.9406564584124654e-324) { + return ErrorSpec::Builder().abs_err(0.05).build(); + } else { + return ErrorSpec::Builder().distance_err(2).build(); + } } else { - return ErrorSpec::Builder().distance_err(2).build(); + NativeT eps = std::numeric_limits::epsilon(); + return ErrorSpec::Builder().abs_err(1e-5).rel_err(5000 * eps).build(); } - } else { - NativeT eps = std::numeric_limits::epsilon(); - return ErrorSpec::Builder().abs_err(1e-5).rel_err(5000 * eps).build(); - } - }; - } else if (IsTpu(platform_)) { - error_spec_gen = +[](NativeT x) { - NativeT eps = std::numeric_limits::epsilon(); - return ErrorSpec::Builder().abs_err(5e-4).rel_err(5000 * eps).build(); - }; - } - Run(Lgamma, std::lgamma, error_spec_gen); + }) + .Run(); }) -UNARY_TEST(Round, { Run(Round, std::round); }) - +UNARY_TEST(Round, { RoundOp(this).Error(GetDefaultSpecGenerator()).Run(); }) UNARY_TEST(RoundNearestEven, { - auto error_spec_gen = +[](NativeT) { - return ErrorSpec::Builder().abs_err(0.0).rel_err(0.0).build(); - }; int curr_direction = fegetround(); fesetround(FE_TONEAREST); - Run(RoundNearestEven, std::nearbyint, error_spec_gen); + RoundNearestEvenOp(this).Run(); fesetround(curr_direction); }) UNARY_TEST(Reciprocal, { // Can be thought of as an absolute error of `<= // |std::numeric_limits::min()|`. - auto abs_err = +[](NativeT val) -> double { + auto* abs_err = +[](NativeT val) -> double { NativeT output = static_cast(1.0) / val; if (IsSubnormal(output)) { return std::numeric_limits::min(); } return 0.0; }; - auto abs_err_bf16 = +[](NativeT val) -> double { - NativeT output = static_cast(1.0) / val; - if (IsSubnormalOrMinNormal(output)) { - return std::numeric_limits::min(); - } - return 0.0; - }; - ErrorSpecGen error_spec_gen = [](NativeT) { - return ErrorSpec::Builder().strict_signed_zeros().build(); - }; - if (IsCpu(platform_)) { - error_spec_gen = [&](NativeT val) { - return ErrorSpec::Builder() - .abs_err(abs_err(val)) - .strict_signed_zeros() - .build(); - }; - } - if (IsGpu(platform_)) { - error_spec_gen = [&](NativeT val) { - NativeT eps = std::numeric_limits::epsilon(); - return ErrorSpec::Builder() - .abs_err(abs_err(val)) - .rel_err(eps) - .strict_signed_zeros() - .build(); - }; - } - if (IsTpu(platform_)) { - error_spec_gen = [&](NativeT val) { - if constexpr (std::is_same_v) { - return ErrorSpec::Builder() - .abs_err(abs_err_bf16(val)) - .strict_signed_zeros() - .build(); - } else if constexpr (std::is_same_v) { - return ErrorSpec::Builder().strict_signed_zeros().build(); - } else if constexpr (std::is_same_v) { - NativeT eps = std::numeric_limits::epsilon(); + ReciprocalOp(this) + .CpuError([&](NativeT val) { return ErrorSpec::Builder() .abs_err(abs_err(val)) - .rel_err(eps) .strict_signed_zeros() .build(); - } else { - return ErrorSpec{}; - } - }; - } - if (IsPreV6Tpu(platform_)) { - error_spec_gen = [&](NativeT val) { - if constexpr (std::is_same_v) { - return ErrorSpec::Builder() - .abs_err(abs_err_bf16(val)) - .strict_signed_zeros() - .build(); - } else if constexpr (std::is_same_v) { - return ErrorSpec::Builder().strict_signed_zeros().build(); - } else if constexpr (std::is_same_v) { - NativeT eps = std::numeric_limits::epsilon(); - return ErrorSpec::Builder() - .abs_err(abs_err(val)) - .rel_err(34 * eps) - .strict_signed_zeros() - .build(); - } else { - return ErrorSpec{}; - } - }; - } - if (IsPreV5Tpu(platform_)) { - error_spec_gen = [&](NativeT val) { - if constexpr (std::is_same_v) { - return ErrorSpec::Builder() - .abs_err(abs_err_bf16(val)) - .strict_signed_zeros() - .build(); - } else if constexpr (std::is_same_v) { - return ErrorSpec::Builder().strict_signed_zeros().build(); - } else if constexpr (std::is_same_v) { + }) + .GpuError([&](NativeT val) { NativeT eps = std::numeric_limits::epsilon(); return ErrorSpec::Builder() .abs_err(abs_err(val)) - .rel_err(136 * eps) + .rel_err(eps) .strict_signed_zeros() .build(); - } else { - return ErrorSpec{}; - } - }; - } - Run(Reciprocal, +[](NativeRefT x) { return 1 / x; }, error_spec_gen); + }) + .Run(); }) } // namespace diff --git a/third_party/xla/xla/tests/exhaustive/exhaustive_unary_test_ops.inc b/third_party/xla/xla/tests/exhaustive/exhaustive_unary_test_ops.inc new file mode 100644 index 00000000000000..c1e6608c444176 --- /dev/null +++ b/third_party/xla/xla/tests/exhaustive/exhaustive_unary_test_ops.inc @@ -0,0 +1,217 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +template +T EvaluatePolynomial(T x, const std::array& coeffs) { + // Evaluate the polynomial as accurately as we can using double precision and + // FMA. + double result = 0; + double x_d = static_cast(x); + for (T c : coeffs) { + result = std::fma(result, x_d, static_cast(c)); + } + return static_cast(result); +} + +// There's no std::erfinv, so we have to implement it ourselves. This follows +// Wichura 1998, https://www.jstor.org/stable/2347330 which, notably, is a +// different implementation from that in math.cc. +template +NativeRefT HostErfInv(NativeRefT x) { + const std::array poly_a = { + 8.8709406962545514830200e2, 1.1819493347062294404278e4, + 2.3782041382114385731252e4, 1.6235862515167575384252e4, + 4.8548868893843886794648e3, 6.9706266534389598238465e2, + 4.7072688112383978012285e1, 1.1975323115670912564578e0, + }; + const std::array poly_b = { + 5.2264952788528545610e3, 2.8729085735721942674e4, 3.9307895800092710610e4, + 2.1213794301586595867e4, 5.3941960214247511077e3, 6.8718700749205790830e2, + 4.2313330701600911252e1, 1.0000000000000000000e0, + }; + const std::array poly_c = { + 7.74545014278341407640e-4, 2.27238449892691845833e-2, + 2.41780725177450611770e-1, 1.27045825245236838258e0, + 3.64784832476320460504e0, 5.76949722146069140550e0, + 4.63033784615654529590e0, 1.42343711074968357734e0, + }; + const std::array poly_d = { + 1.4859850019840355905497876e-9, 7.7441459065157709165577218e-4, + 2.1494160384252876777097297e-2, 2.0945065210512749128288442e-1, + 9.7547832001787427186894837e-1, 2.3707661626024532365971225e0, + 2.9036514445419946173133295e0, 1.4142135623730950488016887e0, + }; + const std::array poly_e = { + 2.01033439929228813265e-7, 2.71155556874348757815e-5, + 1.24266094738807843860e-3, 2.65321895265761230930e-2, + 2.96560571828504891230e-1, 1.78482653991729133580e0, + 5.46378491116411436990e0, 6.65790464350110377720e0, + }; + const std::array poly_f = { + 2.891024605872965461538222e-15, 2.010321207683943062279931e-7, + 2.611088405080593625138020e-5, 1.112800997078859844711555e-3, + 2.103693768272068968719679e-2, 1.936480946950659106176712e-1, + 8.482908416595164588112026e-1, 1.414213562373095048801689e0, + }; + + if (std::abs(x) > 1 || std::isnan(x)) { + return std::numeric_limits::quiet_NaN(); + } + if (std::abs(x) == 1) { + return static_cast( + std::copysign(std::numeric_limits::infinity(), x)); + } + + double unsigned_result = [&] { + double y = std::abs(x); + if (y <= 0.85) { + double r = 0.180625 - 0.25 * y * y; + return (y * EvaluatePolynomial(r, poly_a)) / + EvaluatePolynomial(r, poly_b); + } + + double r = std::sqrt(std::log(2.0) - std::log1p(-y)); + if (r <= 5.0) { + r -= 1.6; + return EvaluatePolynomial(r, poly_c) / EvaluatePolynomial(r, poly_d); + } + + r -= 5; + return EvaluatePolynomial(r, poly_e) / EvaluatePolynomial(r, poly_f); + }(); + return static_cast(std::copysign(unsigned_result, x)); +} + +// Digamma implementation using a polynomial from Cephes. Notably this is a +// different implementation from the one in math.cc. +template +NativeRefT HostDigamma(NativeRefT x) { + // Euler-Mascheroni constant + const double gamma_constant = 0.57721566490153286061; + + const std::array poly = { + -4.16666666666666666667E-3, + 3.96825396825396825397E-3, + -8.33333333333333333333E-3, + 8.33333333333333333333E-2, + }; + + double reflection = 0; + if (x <= 0) { + double floor = std::floor(x); + if (x == floor) { + return std::numeric_limits::quiet_NaN(); + } + // Compute reflection term, pi * cot(pi * x). + reflection = x - floor; + if (reflection == 0.5) { + reflection = 0; + } else { + if (reflection > 0.5) { + reflection = x - (floor + 1.0f); + } + reflection = M_PI / std::tan(M_PI * reflection); + } + x = 1 - x; + } + + double result = 0; + if (x <= 10 && x == std::floor(x)) { + // Special case for integers <= 10. + for (size_t i = 1; i < static_cast(std::floor(x)); ++i) { + result += 1.0 / static_cast(i); + } + result -= gamma_constant; + } else { + double w = 0; + while (x < 10) { + w += 1.0 / x; + ++x; + } + if (x < 1e8) { + double z = 1.0 / (x * x); + result = z * EvaluatePolynomial(z, poly); + } + result = std::log(x) - 0.5 / x - result - w; + } + + // Compute the final, reflected value. + return static_cast(result - reflection); +} + +#define DEFINE_UNARY_TEST_OP(NAME, ENQUEUE, EVALUATE) \ + template \ + class NAME final : public UnaryTestOp { \ + public: \ + using Traits = UnaryTestOp::Traits; \ + using Test = UnaryTestOp::Test; \ + \ + explicit NAME(Test* test) : UnaryTestOp(test) {} \ + ~NAME() override {} \ + \ + Traits::EnqueueOp EnqueueOp() const override ENQUEUE; \ + \ + Traits::EvaluateOp EvaluateOp() const override EVALUATE; \ + }; \ + static_assert(true, "") + +DEFINE_UNARY_TEST_OP(LogOp, { return Log; }, { return std::log; }); +DEFINE_UNARY_TEST_OP(Log1pOp, { return Log1p; }, { return std::log1p; }); +DEFINE_UNARY_TEST_OP(ExpOp, { return Exp; }, { return std::exp; }); +DEFINE_UNARY_TEST_OP(Expm1Op, { return Expm1; }, { return std::expm1; }); +DEFINE_UNARY_TEST_OP( + LogisticOp, { return Logistic; }, + { + return +[](Traits::NativeRefT x) { return 1.0f / (1.0f + std::exp(-x)); }; + }); +DEFINE_UNARY_TEST_OP( + PowOneHalfOp, + { return [](XlaOp x) { return Pow(x, ScalarLike(x, 0.5)); }; }, + { return +[](Traits::NativeRefT x) { return std::pow(x, 0.5f); }; }); +DEFINE_UNARY_TEST_OP( + RsqrtOp, { return Rsqrt; }, + { return +[](Traits::NativeRefT x) { return 1 / std::sqrt(x); }; }); +DEFINE_UNARY_TEST_OP(SqrtOp, { return Sqrt; }, { return std::sqrt; }); +DEFINE_UNARY_TEST_OP(CbrtOp, { return Cbrt; }, { return std::cbrt; }); +DEFINE_UNARY_TEST_OP(AcoshOp, { return Acosh; }, { return std::acosh; }); +DEFINE_UNARY_TEST_OP(AsinhOp, { return Asinh; }, { return std::asinh; }); +DEFINE_UNARY_TEST_OP(AtanhOp, { return Atanh; }, { return std::atanh; }); +DEFINE_UNARY_TEST_OP(AcosOp, { return Acos; }, { return std::acos; }); +DEFINE_UNARY_TEST_OP(AsinOp, { return Asin; }, { return std::asin; }); +DEFINE_UNARY_TEST_OP(AtanOp, { return Atan; }, { return std::atan; }); +DEFINE_UNARY_TEST_OP(CoshOp, { return Cosh; }, { return std::cosh; }); +DEFINE_UNARY_TEST_OP(SinhOp, { return Sinh; }, { return std::sinh; }); +DEFINE_UNARY_TEST_OP(TanhOp, { return Tanh; }, { return std::tanh; }); +DEFINE_UNARY_TEST_OP(CosOp, { return Cos; }, { return std::cos; }); +DEFINE_UNARY_TEST_OP(SinOp, { return Sin; }, { return std::sin; }); +DEFINE_UNARY_TEST_OP(TanOp, { return Tan; }, { return std::tan; }); +DEFINE_UNARY_TEST_OP(ErfOp, { return Erf; }, { return std::erf; }); +DEFINE_UNARY_TEST_OP(ErfcOp, { return Erfc; }, { return std::erfc; }); +DEFINE_UNARY_TEST_OP( + ErfInvOp, { return ErfInv; }, + { return HostErfInv; }); +DEFINE_UNARY_TEST_OP( + DigammaOp, { return Digamma; }, + { return HostDigamma; }); +DEFINE_UNARY_TEST_OP(LgammaOp, { return Lgamma; }, { return std::lgamma; }); +DEFINE_UNARY_TEST_OP(RoundOp, { return Round; }, { return std::round; }); +DEFINE_UNARY_TEST_OP( + RoundNearestEvenOp, { return RoundNearestEven; }, + { return std::nearbyint; }); +DEFINE_UNARY_TEST_OP( + ReciprocalOp, { return Reciprocal; }, + { return +[](Traits::NativeRefT x) { return 1 / x; }; }); + +#undef DEFINE_UNARY_TEST_OP diff --git a/third_party/xla/xla/tests/exhaustive/platform.cc b/third_party/xla/xla/tests/exhaustive/platform.cc new file mode 100644 index 00000000000000..704b1a5b8df9be --- /dev/null +++ b/third_party/xla/xla/tests/exhaustive/platform.cc @@ -0,0 +1,103 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/tests/exhaustive/platform.h" + +#include +#include +#include + +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "xla/stream_executor/device_description.h" +#include "xla/stream_executor/platform.h" + +namespace xla { +namespace exhaustive_op_test { + +Platform::Value GetPlatformValue(const stream_executor::Platform& platform) { + if (platform.Name() == "Host") { +// We process these copts in a library instead of the final exhaustive_xla_test +// target because we assume the final target will use the same target CPU arch +// as this target. +#ifdef __x86_64__ + return Platform::CpuValue::X86_64; +#endif +#ifdef __aarch64__ + return Platform::CpuValue::AARCH64; +#endif + } else if (platform.Name() == "CUDA") { + auto device_descriptor_status = platform.DescriptionForDevice(0); + CHECK_OK(device_descriptor_status); + std::unique_ptr device_descriptor = + std::move(*device_descriptor_status); + + auto cuda_compute_compatibility = + device_descriptor->cuda_compute_capability(); + // If not available, CudaComputeCompatibility will have major version 0. + if (cuda_compute_compatibility.IsAtLeast(1, 0)) { + return cuda_compute_compatibility; + } + } else if (platform.Name() == "ROCM") { + auto device_descriptor_status = platform.DescriptionForDevice(0); + CHECK_OK(device_descriptor_status); + std::unique_ptr device_descriptor = + std::move(*device_descriptor_status); + + auto rocm_compute_compatibility = + device_descriptor->rocm_compute_capability(); + // If not available, RocmComputeCompatibility will be an invalid platform + // value. + if (rocm_compute_compatibility.gfx_version() == "gfx000") { + return rocm_compute_compatibility; + } + } + LOG(FATAL) << "Unhandled stream_executor::Platform: " << platform.Name() + << ". Please add support to " __FILE__ "."; +} + +bool Platform::IsNvidiaP100() const { + return std::holds_alternative( + value_) && + !std::get(value_).IsAtLeast( + stream_executor::CudaComputeCapability::Volta()); +} + +bool Platform::IsNvidiaV100() const { + return std::holds_alternative( + value_) && + std::get(value_) == + stream_executor::CudaComputeCapability::Volta(); +} + +bool Platform::IsNvidiaA100() const { + return std::holds_alternative( + value_) && + std::get(value_) == + stream_executor::CudaComputeCapability::Ampere(); +} + +bool Platform::IsNvidiaH100() const { + return std::holds_alternative( + value_) && + std::get(value_) == + stream_executor::CudaComputeCapability::Hopper(); +} + +Platform::Platform(const stream_executor::Platform& platform) + : value_(GetPlatformValue(platform)) {} + +} // namespace exhaustive_op_test +} // namespace xla diff --git a/third_party/xla/xla/tests/exhaustive/platform.h b/third_party/xla/xla/tests/exhaustive/platform.h new file mode 100644 index 00000000000000..7728033ec5ea93 --- /dev/null +++ b/third_party/xla/xla/tests/exhaustive/platform.h @@ -0,0 +1,77 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_TESTS_EXHAUSTIVE_PLATFORM_H_ +#define XLA_TESTS_EXHAUSTIVE_PLATFORM_H_ + +#include + +#include "xla/stream_executor/device_description.h" +#include "xla/stream_executor/platform.h" + +namespace xla { +namespace exhaustive_op_test { + +// Represents an enum class of all possible openXLA execution platforms along +// with helper functions to categorically handle them. +class Platform { + public: + enum class CpuValue { + AARCH64, + X86_64, + }; + + using Value = std::variant; + + explicit Platform(const stream_executor::Platform& platform); + + bool IsCpu() const { return std::holds_alternative(value_); } + + bool IsGpu() const { + return std::holds_alternative( + value_) || + std::holds_alternative( + value_); + } + + bool IsNvidiaGpu() const { + return std::holds_alternative( + value_); + } + + bool IsNvidiaP100() const; + + bool IsNvidiaV100() const; + + bool IsNvidiaA100() const; + + bool IsNvidiaH100() const; + + bool IsAmdGpu() const { + return std::holds_alternative( + value_); + } + + const Value& value() const { return value_; } + + private: + const Value value_; +}; + +} // namespace exhaustive_op_test +} // namespace xla + +#endif // XLA_TESTS_EXHAUSTIVE_PLATFORM_H_ diff --git a/third_party/xla/xla/tests/exhaustive/test_op.h b/third_party/xla/xla/tests/exhaustive/test_op.h new file mode 100644 index 00000000000000..35ad4b51f69ad6 --- /dev/null +++ b/third_party/xla/xla/tests/exhaustive/test_op.h @@ -0,0 +1,247 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_TESTS_EXHAUSTIVE_TEST_OP_H_ +#define XLA_TESTS_EXHAUSTIVE_TEST_OP_H_ + +#include +#include + +#include "xla/tests/exhaustive/exhaustive_op_test.h" +#include "xla/tests/exhaustive/exhaustive_op_test_utils.h" +#include "xla/tests/exhaustive/platform.h" +#include "xla/xla_data.pb.h" + +namespace xla { +namespace exhaustive_op_test { + +// Declares a single exhaustive test operation. +// +// This class is intended to be subclassed by an actual operation implementation +// that configures EnqueueOp() and EvaluateOp() as necessary. +// +// The exhaustive test can be run using the Run() function defined here. +// +// Pure virtual functions: +// - EnqueueOp +// - EvaluateOp +template +class TestOp { + public: + using Traits = ExhaustiveOpTestTraits; + using Test = std::conditional_t< + N == 1, ExhaustiveUnaryTest, + std::conditional_t, + std::enable_if_t>>; + + explicit TestOp(Test* test) : test_(test) {} + + virtual ~TestOp() = default; + + virtual Traits::EnqueueOp EnqueueOp() const = 0; + virtual Traits::EvaluateOp EvaluateOp() const = 0; + + // Establish a verification check that each EnqueueOp() value is within range. + TestOp& OutputRangeCheck(Traits::OutputRangeCheck output_range_check) & { + output_range_check_ = output_range_check; + return *this; + } + TestOp&& OutputRangeCheck(Traits::OutputRangeCheck output_range_check) && { + output_range_check_ = output_range_check; + return std::move(*this); + } + + // The following methods set ErrorSpecGen for associated platforms. There is a + // precedence hierarchy to allow for easily setting fallbacks and overriding + // for certain platforms. + // + // CPU Precedence: + // CPU Make (x86, ARM, etc) Error -> CPU Error -> Error + // + // GPU Precedence: + // GPU Model (P100, V100, etc) Error -> GPU Make (Nvidia) Error -> GPU Error + // -> Error + + TestOp& Error(Traits::ErrorSpecGen error_spec_gen) & { + error_spec_gen_ = error_spec_gen; + return *this; + } + TestOp&& Error(Traits::ErrorSpecGen error_spec_gen) && { + error_spec_gen_ = std::move(error_spec_gen); + return std::move(*this); + } + + TestOp& CpuError(Traits::ErrorSpecGen error_spec_gen) & { + cpu_error_spec_gen_ = error_spec_gen; + return *this; + } + TestOp&& CpuError(Traits::ErrorSpecGen error_spec_gen) && { + cpu_error_spec_gen_ = std::move(error_spec_gen); + return std::move(*this); + } + + TestOp& CpuX86Error(Traits::ErrorSpecGen error_spec_gen) & { + cpu_x86_error_spec_gen_ = error_spec_gen; + return *this; + } + TestOp&& CpuX86Error(Traits::ErrorSpecGen error_spec_gen) && { + cpu_x86_error_spec_gen_ = std::move(error_spec_gen); + return std::move(*this); + } + + TestOp& CpuArmError(Traits::ErrorSpecGen error_spec_gen) & { + cpu_arm_error_spec_gen_ = error_spec_gen; + return *this; + } + TestOp&& CpuArmError(Traits::ErrorSpecGen error_spec_gen) && { + cpu_arm_error_spec_gen_ = std::move(error_spec_gen); + return std::move(*this); + } + + TestOp& GpuError(Traits::ErrorSpecGen error_spec_gen) & { + gpu_error_spec_gen_ = error_spec_gen; + return *this; + } + TestOp&& GpuError(Traits::ErrorSpecGen error_spec_gen) && { + gpu_error_spec_gen_ = std::move(error_spec_gen); + return std::move(*this); + } + + TestOp& GpuNvidiaError(Traits::ErrorSpecGen error_spec_gen) & { + gpu_nv_error_spec_gen_ = error_spec_gen; + return *this; + } + TestOp&& GpuNvidiaError(Traits::ErrorSpecGen error_spec_gen) && { + gpu_nv_error_spec_gen_ = std::move(error_spec_gen); + return std::move(*this); + } + + TestOp& GpuP100Error(Traits::ErrorSpecGen error_spec_gen) & { + gpu_nv_p100_error_spec_gen_ = error_spec_gen; + return *this; + } + TestOp&& GpuP100Error(Traits::ErrorSpecGen error_spec_gen) && { + gpu_nv_p100_error_spec_gen_ = std::move(error_spec_gen); + return std::move(*this); + } + + TestOp& GpuV100Error(Traits::ErrorSpecGen error_spec_gen) & { + gpu_nv_v100_error_spec_gen_ = error_spec_gen; + return *this; + } + TestOp&& GpuV100Error(Traits::ErrorSpecGen error_spec_gen) && { + gpu_nv_v100_error_spec_gen_ = std::move(error_spec_gen); + return std::move(*this); + } + + TestOp& GpuA100Error(Traits::ErrorSpecGen error_spec_gen) & { + gpu_nv_a100_error_spec_gen_ = error_spec_gen; + return *this; + } + TestOp&& GpuA100Error(Traits::ErrorSpecGen error_spec_gen) && { + gpu_nv_a100_error_spec_gen_ = std::move(error_spec_gen); + return std::move(*this); + } + + TestOp& GpuH100Error(Traits::ErrorSpecGen error_spec_gen) & { + gpu_nv_h100_error_spec_gen_ = error_spec_gen; + return *this; + } + TestOp&& GpuH100Error(Traits::ErrorSpecGen error_spec_gen) && { + gpu_nv_h100_error_spec_gen_ = std::move(error_spec_gen); + return std::move(*this); + } + + // Execute the TestCase as configured. + // + // Requires invoking on a TestCase&& to ensure the TestCase is not used + // afterwards. + void Run() && { + typename Traits::ErrorSpecGen error_spec_gen; + if (test_->Platform().IsCpu()) { + switch (std::get(test_->Platform().value())) { + case Platform::CpuValue::X86_64: { + error_spec_gen = PickFirstErrorSpecGenPresent( + {cpu_x86_error_spec_gen_, cpu_error_spec_gen_, error_spec_gen_}); + break; + } + case Platform::CpuValue::AARCH64: { + error_spec_gen = PickFirstErrorSpecGenPresent( + {cpu_arm_error_spec_gen_, cpu_error_spec_gen_, error_spec_gen_}); + break; + } + default: { + error_spec_gen = PickFirstErrorSpecGenPresent( + {cpu_error_spec_gen_, error_spec_gen_}); + break; + } + } + } else if (test_->Platform().IsGpu()) { + if (test_->Platform().IsNvidiaGpu()) { + if (test_->Platform().IsNvidiaP100()) { + error_spec_gen = PickFirstErrorSpecGenPresent( + {gpu_nv_p100_error_spec_gen_, gpu_nv_error_spec_gen_, + gpu_error_spec_gen_, error_spec_gen_}); + } else if (test_->Platform().IsNvidiaV100()) { + error_spec_gen = PickFirstErrorSpecGenPresent( + {gpu_nv_v100_error_spec_gen_, gpu_nv_error_spec_gen_, + gpu_error_spec_gen_, error_spec_gen_}); + } else if (test_->Platform().IsNvidiaA100()) { + error_spec_gen = PickFirstErrorSpecGenPresent( + {gpu_nv_a100_error_spec_gen_, gpu_nv_error_spec_gen_, + gpu_error_spec_gen_, error_spec_gen_}); + } else if (test_->Platform().IsNvidiaH100()) { + error_spec_gen = PickFirstErrorSpecGenPresent( + {gpu_nv_h100_error_spec_gen_, gpu_nv_error_spec_gen_, + gpu_error_spec_gen_, error_spec_gen_}); + } else { + error_spec_gen = PickFirstErrorSpecGenPresent( + {gpu_nv_error_spec_gen_, gpu_error_spec_gen_, error_spec_gen_}); + } + } else { + error_spec_gen = PickFirstErrorSpecGenPresent( + {gpu_error_spec_gen_, error_spec_gen_}); + } + } else { + error_spec_gen = PickFirstErrorSpecGenPresent({error_spec_gen_}); + } + test_->Run(EnqueueOp(), EvaluateOp(), error_spec_gen, output_range_check_); + } + + private: + Test* test_ = nullptr; + Traits::OutputRangeCheck output_range_check_ = nullptr; + Traits::ErrorSpecGen error_spec_gen_ = nullptr; + Traits::ErrorSpecGen cpu_error_spec_gen_ = nullptr; + Traits::ErrorSpecGen cpu_x86_error_spec_gen_ = nullptr; + Traits::ErrorSpecGen cpu_arm_error_spec_gen_ = nullptr; + Traits::ErrorSpecGen gpu_error_spec_gen_ = nullptr; + Traits::ErrorSpecGen gpu_nv_error_spec_gen_ = nullptr; + Traits::ErrorSpecGen gpu_nv_p100_error_spec_gen_ = nullptr; + Traits::ErrorSpecGen gpu_nv_v100_error_spec_gen_ = nullptr; + Traits::ErrorSpecGen gpu_nv_a100_error_spec_gen_ = nullptr; + Traits::ErrorSpecGen gpu_nv_h100_error_spec_gen_ = nullptr; +}; + +template +using UnaryTestOp = TestOp; + +template +using BinaryTestOp = TestOp; + +} // namespace exhaustive_op_test +} // namespace xla + +#endif // XLA_TESTS_EXHAUSTIVE_TEST_OP_H_ From 308d127807e59f70673ec7fcc484f01ae17e35ae Mon Sep 17 00:00:00 2001 From: Augie Fackler Date: Wed, 25 Sep 2024 13:27:22 -0700 Subject: [PATCH 277/483] Integrate LLVM at llvm/llvm-project@9830156f623c Updates LLVM usage to match [9830156f623c](https://github.com/llvm/llvm-project/commit/9830156f623c) PiperOrigin-RevId: 678825303 --- third_party/llvm/generated.patch | 4094 ++++++++++++++++ third_party/llvm/workspace.bzl | 4 +- third_party/shardy/temporary.patch | 4110 ++++++++++++++++- third_party/shardy/workspace.bzl | 4 +- .../xla/third_party/shardy/temporary.patch | 4110 ++++++++++++++++- .../xla/third_party/shardy/workspace.bzl | 4 +- 6 files changed, 12310 insertions(+), 16 deletions(-) diff --git a/third_party/llvm/generated.patch b/third_party/llvm/generated.patch index 509398da979e83..de92cb4da63e52 100644 --- a/third_party/llvm/generated.patch +++ b/third_party/llvm/generated.patch @@ -1 +1,4095 @@ Auto generated patch. Do not edit or delete it, even if empty. +diff -ruN --strip-trailing-cr a/llvm/docs/NVPTXUsage.rst b/llvm/docs/NVPTXUsage.rst +--- a/llvm/docs/NVPTXUsage.rst ++++ b/llvm/docs/NVPTXUsage.rst +@@ -127,6 +127,69 @@ + NVPTX Intrinsics + ================ + ++Address Space Conversion ++------------------------ ++ ++'``llvm.nvvm.ptr.*.to.gen``' Intrinsics ++^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ ++ ++Syntax: ++""""""" ++ ++These are overloaded intrinsics. You can use these on any pointer types. ++ ++.. code-block:: llvm ++ ++ declare ptr @llvm.nvvm.ptr.global.to.gen.p0.p1(ptr addrspace(1)) ++ declare ptr @llvm.nvvm.ptr.shared.to.gen.p0.p3(ptr addrspace(3)) ++ declare ptr @llvm.nvvm.ptr.constant.to.gen.p0.p4(ptr addrspace(4)) ++ declare ptr @llvm.nvvm.ptr.local.to.gen.p0.p5(ptr addrspace(5)) ++ ++Overview: ++""""""""" ++ ++The '``llvm.nvvm.ptr.*.to.gen``' intrinsics convert a pointer in a non-generic ++address space to a generic address space pointer. ++ ++Semantics: ++"""""""""" ++ ++These intrinsics modify the pointer value to be a valid generic address space ++pointer. ++ ++ ++'``llvm.nvvm.ptr.gen.to.*``' Intrinsics ++^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ ++ ++Syntax: ++""""""" ++ ++These are overloaded intrinsics. You can use these on any pointer types. ++ ++.. code-block:: llvm ++ ++ declare ptr addrspace(1) @llvm.nvvm.ptr.gen.to.global.p1.p0(ptr) ++ declare ptr addrspace(3) @llvm.nvvm.ptr.gen.to.shared.p3.p0(ptr) ++ declare ptr addrspace(4) @llvm.nvvm.ptr.gen.to.constant.p4.p0(ptr) ++ declare ptr addrspace(5) @llvm.nvvm.ptr.gen.to.local.p5.p0(ptr) ++ ++Overview: ++""""""""" ++ ++The '``llvm.nvvm.ptr.gen.to.*``' intrinsics convert a pointer in the generic ++address space to a pointer in the target address space. Note that these ++intrinsics are only useful if the address space of the target address space of ++the pointer is known. It is not legal to use address space conversion ++intrinsics to convert a pointer from one non-generic address space to another ++non-generic address space. ++ ++Semantics: ++"""""""""" ++ ++These intrinsics modify the pointer value to be a valid pointer in the target ++non-generic address space. ++ ++ + Reading PTX Special Registers + ----------------------------- + +diff -ruN --strip-trailing-cr a/llvm/docs/ReleaseNotes.rst b/llvm/docs/ReleaseNotes.rst +--- a/llvm/docs/ReleaseNotes.rst ++++ b/llvm/docs/ReleaseNotes.rst +@@ -63,24 +63,6 @@ + * ``llvm.nvvm.bitcast.d2ll`` + * ``llvm.nvvm.bitcast.ll2d`` + +-* Remove the following intrinsics which can be replaced with a funnel-shift: +- +- * ``llvm.nvvm.rotate.b32`` +- * ``llvm.nvvm.rotate.right.b64`` +- * ``llvm.nvvm.rotate.b64`` +- +-* Remove the following intrinsics which can be replaced with an +- ``addrspacecast``: +- +- * ``llvm.nvvm.ptr.gen.to.global`` +- * ``llvm.nvvm.ptr.gen.to.shared`` +- * ``llvm.nvvm.ptr.gen.to.constant`` +- * ``llvm.nvvm.ptr.gen.to.local`` +- * ``llvm.nvvm.ptr.global.to.gen`` +- * ``llvm.nvvm.ptr.shared.to.gen`` +- * ``llvm.nvvm.ptr.constant.to.gen`` +- * ``llvm.nvvm.ptr.local.to.gen`` +- + Changes to LLVM infrastructure + ------------------------------ + +diff -ruN --strip-trailing-cr a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td +--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td ++++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td +@@ -30,18 +30,10 @@ + // * llvm.nvvm.max.ui --> select(x ule y, x, y) + // * llvm.nvvm.max.ull --> ibid. + // * llvm.nvvm.h2f --> llvm.convert.to.fp16.f32 +-// * llvm.nvvm.bitcast.f2i --> bitcast +-// * llvm.nvvm.bitcast.i2f --> ibid. +-// * llvm.nvvm.bitcast.d2ll --> ibid. +-// * llvm.nvvm.bitcast.ll2d --> ibid. +-// * llvm.nvvm.ptr.gen.to.global --> addrspacecast +-// * llvm.nvvm.ptr.gen.to.shared --> ibid. +-// * llvm.nvvm.ptr.gen.to.constant --> ibid. +-// * llvm.nvvm.ptr.gen.to.local --> ibid. +-// * llvm.nvvm.ptr.global.to.gen --> ibid. +-// * llvm.nvvm.ptr.shared.to.gen --> ibid. +-// * llvm.nvvm.ptr.constant.to.gen --> ibid. +-// * llvm.nvvm.ptr.local.to.gen --> ibid. ++// * llvm.nvvm.bitcast.f2i --> bitcast ++// * llvm.nvvm.bitcast.i2f --> ibid. ++// * llvm.nvvm.bitcast.d2ll --> ibid. ++// * llvm.nvvm.bitcast.ll2d --> ibid. + + def llvm_global_ptr_ty : LLVMQualPointerType<1>; // (global)ptr + def llvm_shared_ptr_ty : LLVMQualPointerType<3>; // (shared)ptr +@@ -1610,6 +1602,40 @@ + [IntrReadMem, IntrArgMemOnly, IntrNoCallback, IntrWillReturn, NoCapture>], + "llvm.nvvm.ldg.global.p">; + ++// Use for generic pointers ++// - These intrinsics are used to convert address spaces. ++// - The input pointer and output pointer must have the same type, except for ++// the address-space. (This restriction is not enforced here as there is ++// currently no way to describe it). ++// - This complements the llvm bitcast, which can be used to cast one type ++// of pointer to another type of pointer, while the address space remains ++// the same. ++def int_nvvm_ptr_local_to_gen: DefaultAttrsIntrinsic<[llvm_anyptr_ty], ++ [llvm_anyptr_ty], [IntrNoMem, IntrSpeculatable], ++ "llvm.nvvm.ptr.local.to.gen">; ++def int_nvvm_ptr_shared_to_gen: DefaultAttrsIntrinsic<[llvm_anyptr_ty], ++ [llvm_anyptr_ty], [IntrNoMem, IntrSpeculatable], ++ "llvm.nvvm.ptr.shared.to.gen">; ++def int_nvvm_ptr_global_to_gen: DefaultAttrsIntrinsic<[llvm_anyptr_ty], ++ [llvm_anyptr_ty], [IntrNoMem, IntrSpeculatable], ++ "llvm.nvvm.ptr.global.to.gen">; ++def int_nvvm_ptr_constant_to_gen: DefaultAttrsIntrinsic<[llvm_anyptr_ty], ++ [llvm_anyptr_ty], [IntrNoMem, IntrSpeculatable], ++ "llvm.nvvm.ptr.constant.to.gen">; ++ ++def int_nvvm_ptr_gen_to_global: DefaultAttrsIntrinsic<[llvm_anyptr_ty], ++ [llvm_anyptr_ty], [IntrNoMem, IntrSpeculatable], ++ "llvm.nvvm.ptr.gen.to.global">; ++def int_nvvm_ptr_gen_to_shared: DefaultAttrsIntrinsic<[llvm_anyptr_ty], ++ [llvm_anyptr_ty], [IntrNoMem, IntrSpeculatable], ++ "llvm.nvvm.ptr.gen.to.shared">; ++def int_nvvm_ptr_gen_to_local: DefaultAttrsIntrinsic<[llvm_anyptr_ty], ++ [llvm_anyptr_ty], [IntrNoMem, IntrSpeculatable], ++ "llvm.nvvm.ptr.gen.to.local">; ++def int_nvvm_ptr_gen_to_constant: DefaultAttrsIntrinsic<[llvm_anyptr_ty], ++ [llvm_anyptr_ty], [IntrNoMem, IntrSpeculatable], ++ "llvm.nvvm.ptr.gen.to.constant">; ++ + // Used in nvvm internally to help address space opt and ptx code generation + // This is for params that are passed to kernel functions by pointer by-val. + def int_nvvm_ptr_gen_to_param: Intrinsic<[llvm_anyptr_ty], +@@ -4453,6 +4479,22 @@ + "llvm.nvvm.sust.p.3d.v4i32.trap">, + ClangBuiltin<"__nvvm_sust_p_3d_v4i32_trap">; + ++ ++def int_nvvm_rotate_b32 ++ : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty], ++ [IntrNoMem, IntrSpeculatable], "llvm.nvvm.rotate.b32">, ++ ClangBuiltin<"__nvvm_rotate_b32">; ++ ++def int_nvvm_rotate_b64 ++ : DefaultAttrsIntrinsic<[llvm_i64_ty], [llvm_i64_ty, llvm_i32_ty], ++ [IntrNoMem, IntrSpeculatable], "llvm.nvvm.rotate.b64">, ++ ClangBuiltin<"__nvvm_rotate_b64">; ++ ++def int_nvvm_rotate_right_b64 ++ : DefaultAttrsIntrinsic<[llvm_i64_ty], [llvm_i64_ty, llvm_i32_ty], ++ [IntrNoMem, IntrSpeculatable], "llvm.nvvm.rotate.right.b64">, ++ ClangBuiltin<"__nvvm_rotate_right_b64">; ++ + def int_nvvm_swap_lo_hi_b64 + : DefaultAttrsIntrinsic<[llvm_i64_ty], [llvm_i64_ty], + [IntrNoMem, IntrSpeculatable], "llvm.nvvm.swap.lo.hi.b64">, +diff -ruN --strip-trailing-cr a/llvm/lib/IR/AutoUpgrade.cpp b/llvm/lib/IR/AutoUpgrade.cpp +--- a/llvm/lib/IR/AutoUpgrade.cpp ++++ b/llvm/lib/IR/AutoUpgrade.cpp +@@ -1272,19 +1272,6 @@ + // nvvm.bitcast.{f2i,i2f,ll2d,d2ll} + Expand = + Name == "f2i" || Name == "i2f" || Name == "ll2d" || Name == "d2ll"; +- else if (Name.consume_front("rotate.")) +- // nvvm.rotate.{b32,b64,right.b64} +- Expand = Name == "b32" || Name == "b64" || Name == "right.b64"; +- else if (Name.consume_front("ptr.gen.to.")) +- // nvvm.ptr.gen.to.{local,shared,global,constant} +- Expand = Name.starts_with("local") || Name.starts_with("shared") || +- Name.starts_with("global") || Name.starts_with("constant"); +- else if (Name.consume_front("ptr.")) +- // nvvm.ptr.{local,shared,global,constant}.to.gen +- Expand = +- (Name.consume_front("local") || Name.consume_front("shared") || +- Name.consume_front("global") || Name.consume_front("constant")) && +- Name.starts_with(".to.gen"); + else + Expand = false; + +@@ -2271,117 +2258,6 @@ + } + } + +-static Value *upgradeNVVMIntrinsicCall(StringRef Name, CallBase *CI, +- Function *F, IRBuilder<> &Builder) { +- Value *Rep = nullptr; +- +- if (Name == "abs.i" || Name == "abs.ll") { +- Value *Arg = CI->getArgOperand(0); +- Value *Neg = Builder.CreateNeg(Arg, "neg"); +- Value *Cmp = Builder.CreateICmpSGE( +- Arg, llvm::Constant::getNullValue(Arg->getType()), "abs.cond"); +- Rep = Builder.CreateSelect(Cmp, Arg, Neg, "abs"); +- } else if (Name.starts_with("atomic.load.add.f32.p") || +- Name.starts_with("atomic.load.add.f64.p")) { +- Value *Ptr = CI->getArgOperand(0); +- Value *Val = CI->getArgOperand(1); +- Rep = Builder.CreateAtomicRMW(AtomicRMWInst::FAdd, Ptr, Val, MaybeAlign(), +- AtomicOrdering::SequentiallyConsistent); +- } else if (Name.consume_front("max.") && +- (Name == "s" || Name == "i" || Name == "ll" || Name == "us" || +- Name == "ui" || Name == "ull")) { +- Value *Arg0 = CI->getArgOperand(0); +- Value *Arg1 = CI->getArgOperand(1); +- Value *Cmp = Name.starts_with("u") +- ? Builder.CreateICmpUGE(Arg0, Arg1, "max.cond") +- : Builder.CreateICmpSGE(Arg0, Arg1, "max.cond"); +- Rep = Builder.CreateSelect(Cmp, Arg0, Arg1, "max"); +- } else if (Name.consume_front("min.") && +- (Name == "s" || Name == "i" || Name == "ll" || Name == "us" || +- Name == "ui" || Name == "ull")) { +- Value *Arg0 = CI->getArgOperand(0); +- Value *Arg1 = CI->getArgOperand(1); +- Value *Cmp = Name.starts_with("u") +- ? Builder.CreateICmpULE(Arg0, Arg1, "min.cond") +- : Builder.CreateICmpSLE(Arg0, Arg1, "min.cond"); +- Rep = Builder.CreateSelect(Cmp, Arg0, Arg1, "min"); +- } else if (Name == "clz.ll") { +- // llvm.nvvm.clz.ll returns an i32, but llvm.ctlz.i64 returns an i64. +- Value *Arg = CI->getArgOperand(0); +- Value *Ctlz = Builder.CreateCall( +- Intrinsic::getDeclaration(F->getParent(), Intrinsic::ctlz, +- {Arg->getType()}), +- {Arg, Builder.getFalse()}, "ctlz"); +- Rep = Builder.CreateTrunc(Ctlz, Builder.getInt32Ty(), "ctlz.trunc"); +- } else if (Name == "popc.ll") { +- // llvm.nvvm.popc.ll returns an i32, but llvm.ctpop.i64 returns an +- // i64. +- Value *Arg = CI->getArgOperand(0); +- Value *Popc = Builder.CreateCall( +- Intrinsic::getDeclaration(F->getParent(), Intrinsic::ctpop, +- {Arg->getType()}), +- Arg, "ctpop"); +- Rep = Builder.CreateTrunc(Popc, Builder.getInt32Ty(), "ctpop.trunc"); +- } else if (Name == "h2f") { +- Rep = Builder.CreateCall( +- Intrinsic::getDeclaration(F->getParent(), Intrinsic::convert_from_fp16, +- {Builder.getFloatTy()}), +- CI->getArgOperand(0), "h2f"); +- } else if (Name.consume_front("bitcast.") && +- (Name == "f2i" || Name == "i2f" || Name == "ll2d" || +- Name == "d2ll")) { +- Rep = Builder.CreateBitCast(CI->getArgOperand(0), CI->getType()); +- } else if (Name == "rotate.b32") { +- Value *Arg = CI->getOperand(0); +- Value *ShiftAmt = CI->getOperand(1); +- Rep = Builder.CreateIntrinsic(Builder.getInt32Ty(), Intrinsic::fshl, +- {Arg, Arg, ShiftAmt}); +- } else if (Name == "rotate.b64") { +- Type *Int64Ty = Builder.getInt64Ty(); +- Value *Arg = CI->getOperand(0); +- Value *ZExtShiftAmt = Builder.CreateZExt(CI->getOperand(1), Int64Ty); +- Rep = Builder.CreateIntrinsic(Int64Ty, Intrinsic::fshl, +- {Arg, Arg, ZExtShiftAmt}); +- } else if (Name == "rotate.right.b64") { +- Type *Int64Ty = Builder.getInt64Ty(); +- Value *Arg = CI->getOperand(0); +- Value *ZExtShiftAmt = Builder.CreateZExt(CI->getOperand(1), Int64Ty); +- Rep = Builder.CreateIntrinsic(Int64Ty, Intrinsic::fshr, +- {Arg, Arg, ZExtShiftAmt}); +- } else if ((Name.consume_front("ptr.gen.to.") && +- (Name.starts_with("local") || Name.starts_with("shared") || +- Name.starts_with("global") || Name.starts_with("constant"))) || +- (Name.consume_front("ptr.") && +- (Name.consume_front("local") || Name.consume_front("shared") || +- Name.consume_front("global") || +- Name.consume_front("constant")) && +- Name.starts_with(".to.gen"))) { +- Rep = Builder.CreateAddrSpaceCast(CI->getArgOperand(0), CI->getType()); +- } else { +- Intrinsic::ID IID = shouldUpgradeNVPTXBF16Intrinsic(Name); +- if (IID != Intrinsic::not_intrinsic && +- !F->getReturnType()->getScalarType()->isBFloatTy()) { +- rename(F); +- Function *NewFn = Intrinsic::getDeclaration(F->getParent(), IID); +- SmallVector Args; +- for (size_t I = 0; I < NewFn->arg_size(); ++I) { +- Value *Arg = CI->getArgOperand(I); +- Type *OldType = Arg->getType(); +- Type *NewType = NewFn->getArg(I)->getType(); +- Args.push_back( +- (OldType->isIntegerTy() && NewType->getScalarType()->isBFloatTy()) +- ? Builder.CreateBitCast(Arg, NewType) +- : Arg); +- } +- Rep = Builder.CreateCall(NewFn, Args); +- if (F->getReturnType()->isIntegerTy()) +- Rep = Builder.CreateBitCast(Rep, F->getReturnType()); +- } +- } +- +- return Rep; +-} +- + static Value *upgradeX86IntrinsicCall(StringRef Name, CallBase *CI, Function *F, + IRBuilder<> &Builder) { + LLVMContext &C = F->getContext(); +@@ -4332,8 +4208,85 @@ + + if (!IsX86 && Name == "stackprotectorcheck") { + Rep = nullptr; ++ } else if (IsNVVM && (Name == "abs.i" || Name == "abs.ll")) { ++ Value *Arg = CI->getArgOperand(0); ++ Value *Neg = Builder.CreateNeg(Arg, "neg"); ++ Value *Cmp = Builder.CreateICmpSGE( ++ Arg, llvm::Constant::getNullValue(Arg->getType()), "abs.cond"); ++ Rep = Builder.CreateSelect(Cmp, Arg, Neg, "abs"); ++ } else if (IsNVVM && (Name.starts_with("atomic.load.add.f32.p") || ++ Name.starts_with("atomic.load.add.f64.p"))) { ++ Value *Ptr = CI->getArgOperand(0); ++ Value *Val = CI->getArgOperand(1); ++ Rep = Builder.CreateAtomicRMW(AtomicRMWInst::FAdd, Ptr, Val, MaybeAlign(), ++ AtomicOrdering::SequentiallyConsistent); ++ } else if (IsNVVM && Name.consume_front("max.") && ++ (Name == "s" || Name == "i" || Name == "ll" || Name == "us" || ++ Name == "ui" || Name == "ull")) { ++ Value *Arg0 = CI->getArgOperand(0); ++ Value *Arg1 = CI->getArgOperand(1); ++ Value *Cmp = Name.starts_with("u") ++ ? Builder.CreateICmpUGE(Arg0, Arg1, "max.cond") ++ : Builder.CreateICmpSGE(Arg0, Arg1, "max.cond"); ++ Rep = Builder.CreateSelect(Cmp, Arg0, Arg1, "max"); ++ } else if (IsNVVM && Name.consume_front("min.") && ++ (Name == "s" || Name == "i" || Name == "ll" || Name == "us" || ++ Name == "ui" || Name == "ull")) { ++ Value *Arg0 = CI->getArgOperand(0); ++ Value *Arg1 = CI->getArgOperand(1); ++ Value *Cmp = Name.starts_with("u") ++ ? Builder.CreateICmpULE(Arg0, Arg1, "min.cond") ++ : Builder.CreateICmpSLE(Arg0, Arg1, "min.cond"); ++ Rep = Builder.CreateSelect(Cmp, Arg0, Arg1, "min"); ++ } else if (IsNVVM && Name == "clz.ll") { ++ // llvm.nvvm.clz.ll returns an i32, but llvm.ctlz.i64 returns an i64. ++ Value *Arg = CI->getArgOperand(0); ++ Value *Ctlz = Builder.CreateCall( ++ Intrinsic::getDeclaration(F->getParent(), Intrinsic::ctlz, ++ {Arg->getType()}), ++ {Arg, Builder.getFalse()}, "ctlz"); ++ Rep = Builder.CreateTrunc(Ctlz, Builder.getInt32Ty(), "ctlz.trunc"); ++ } else if (IsNVVM && Name == "popc.ll") { ++ // llvm.nvvm.popc.ll returns an i32, but llvm.ctpop.i64 returns an ++ // i64. ++ Value *Arg = CI->getArgOperand(0); ++ Value *Popc = Builder.CreateCall( ++ Intrinsic::getDeclaration(F->getParent(), Intrinsic::ctpop, ++ {Arg->getType()}), ++ Arg, "ctpop"); ++ Rep = Builder.CreateTrunc(Popc, Builder.getInt32Ty(), "ctpop.trunc"); + } else if (IsNVVM) { +- Rep = upgradeNVVMIntrinsicCall(Name, CI, F, Builder); ++ if (Name == "h2f") { ++ Rep = ++ Builder.CreateCall(Intrinsic::getDeclaration( ++ F->getParent(), Intrinsic::convert_from_fp16, ++ {Builder.getFloatTy()}), ++ CI->getArgOperand(0), "h2f"); ++ } else if (Name.consume_front("bitcast.") && ++ (Name == "f2i" || Name == "i2f" || Name == "ll2d" || ++ Name == "d2ll")) { ++ Rep = Builder.CreateBitCast(CI->getArgOperand(0), CI->getType()); ++ } else { ++ Intrinsic::ID IID = shouldUpgradeNVPTXBF16Intrinsic(Name); ++ if (IID != Intrinsic::not_intrinsic && ++ !F->getReturnType()->getScalarType()->isBFloatTy()) { ++ rename(F); ++ NewFn = Intrinsic::getDeclaration(F->getParent(), IID); ++ SmallVector Args; ++ for (size_t I = 0; I < NewFn->arg_size(); ++I) { ++ Value *Arg = CI->getArgOperand(I); ++ Type *OldType = Arg->getType(); ++ Type *NewType = NewFn->getArg(I)->getType(); ++ Args.push_back((OldType->isIntegerTy() && ++ NewType->getScalarType()->isBFloatTy()) ++ ? Builder.CreateBitCast(Arg, NewType) ++ : Arg); ++ } ++ Rep = Builder.CreateCall(NewFn, Args); ++ if (F->getReturnType()->isIntegerTy()) ++ Rep = Builder.CreateBitCast(Rep, F->getReturnType()); ++ } ++ } + } else if (IsX86) { + Rep = upgradeX86IntrinsicCall(Name, CI, F, Builder); + } else if (IsARM) { +diff -ruN --strip-trailing-cr a/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp b/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp +--- a/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp ++++ b/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp +@@ -292,7 +292,6 @@ + static const LLT S224 = LLT::scalar(224); + static const LLT S256 = LLT::scalar(256); + static const LLT S512 = LLT::scalar(512); +-static const LLT S1024 = LLT::scalar(1024); + static const LLT MaxScalar = LLT::scalar(MaxRegisterSize); + + static const LLT V2S8 = LLT::fixed_vector(2, 8); +@@ -333,8 +332,8 @@ + static const LLT V2S128 = LLT::fixed_vector(2, 128); + static const LLT V4S128 = LLT::fixed_vector(4, 128); + +-static std::initializer_list AllScalarTypes = { +- S32, S64, S96, S128, S160, S224, S256, S512, S1024}; ++static std::initializer_list AllScalarTypes = {S32, S64, S96, S128, ++ S160, S224, S256, S512}; + + static std::initializer_list AllS16Vectors{ + V2S16, V4S16, V6S16, V8S16, V10S16, V12S16, V16S16, V2S128, V4S128}; +@@ -890,11 +889,10 @@ + .clampScalar(0, S16, S64); + + getActionDefinitionsBuilder({G_IMPLICIT_DEF, G_FREEZE}) +- .legalIf(isRegisterClassType(0)) ++ .legalIf(isRegisterType(0)) + // s1 and s16 are special cases because they have legal operations on + // them, but don't really occupy registers in the normal way. + .legalFor({S1, S16}) +- .clampNumElements(0, V16S32, V32S32) + .moreElementsIf(isSmallOddVector(0), oneMoreElement(0)) + .clampScalarOrElt(0, S32, MaxScalar) + .widenScalarToNextPow2(0, 32) +diff -ruN --strip-trailing-cr a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td +--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td ++++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td +@@ -174,6 +174,10 @@ + def hasSHFL : Predicate<"!(Subtarget->getSmVersion() >= 70" + "&& Subtarget->getPTXVersion() >= 64)">; + ++def useShortPtrLocal : Predicate<"TM.is64Bit() && TM.getPointerSizeInBits(ADDRESS_SPACE_LOCAL) == 32">; ++def useShortPtrShared : Predicate<"TM.is64Bit() && TM.getPointerSizeInBits(ADDRESS_SPACE_SHARED) == 32">; ++def useShortPtrConst : Predicate<"TM.is64Bit() && TM.getPointerSizeInBits(ADDRESS_SPACE_CONST) == 32">; ++ + def useFP16Math: Predicate<"Subtarget->allowFP16Math()">; + def hasBF16Math: Predicate<"Subtarget->hasBF16Math()">; + +@@ -1661,6 +1665,167 @@ + "brev.b64 \t$dst, $a;", + [(set Int64Regs:$dst, (bitreverse Int64Regs:$a))]>; + ++// ++// Rotate: Use ptx shf instruction if available. ++// ++ ++// 32 bit r2 = rotl r1, n ++// => ++// r2 = shf.l r1, r1, n ++def ROTL32imm_hw : ++ NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$src, i32imm:$amt), ++ "shf.l.wrap.b32 \t$dst, $src, $src, $amt;", ++ [(set Int32Regs:$dst, (rotl (i32 Int32Regs:$src), (i32 imm:$amt)))]>, ++ Requires<[hasHWROT32]>; ++ ++def ROTL32reg_hw : ++ NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$src, Int32Regs:$amt), ++ "shf.l.wrap.b32 \t$dst, $src, $src, $amt;", ++ [(set Int32Regs:$dst, (rotl (i32 Int32Regs:$src), (i32 Int32Regs:$amt)))]>, ++ Requires<[hasHWROT32]>; ++ ++// 32 bit r2 = rotr r1, n ++// => ++// r2 = shf.r r1, r1, n ++def ROTR32imm_hw : ++ NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$src, i32imm:$amt), ++ "shf.r.wrap.b32 \t$dst, $src, $src, $amt;", ++ [(set Int32Regs:$dst, (rotr (i32 Int32Regs:$src), (i32 imm:$amt)))]>, ++ Requires<[hasHWROT32]>; ++ ++def ROTR32reg_hw : ++ NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$src, Int32Regs:$amt), ++ "shf.r.wrap.b32 \t$dst, $src, $src, $amt;", ++ [(set Int32Regs:$dst, (rotr (i32 Int32Regs:$src), (i32 Int32Regs:$amt)))]>, ++ Requires<[hasHWROT32]>; ++ ++// 32-bit software rotate by immediate. $amt2 should equal 32 - $amt1. ++def ROT32imm_sw : ++ NVPTXInst<(outs Int32Regs:$dst), ++ (ins Int32Regs:$src, i32imm:$amt1, i32imm:$amt2), ++ "{{\n\t" ++ ".reg .b32 %lhs;\n\t" ++ ".reg .b32 %rhs;\n\t" ++ "shl.b32 \t%lhs, $src, $amt1;\n\t" ++ "shr.b32 \t%rhs, $src, $amt2;\n\t" ++ "add.u32 \t$dst, %lhs, %rhs;\n\t" ++ "}}", ++ []>; ++ ++def SUB_FRM_32 : SDNodeXFormgetTargetConstant(32 - N->getZExtValue(), SDLoc(N), MVT::i32); ++}]>; ++ ++def : Pat<(rotl (i32 Int32Regs:$src), (i32 imm:$amt)), ++ (ROT32imm_sw Int32Regs:$src, imm:$amt, (SUB_FRM_32 node:$amt))>, ++ Requires<[noHWROT32]>; ++def : Pat<(rotr (i32 Int32Regs:$src), (i32 imm:$amt)), ++ (ROT32imm_sw Int32Regs:$src, (SUB_FRM_32 node:$amt), imm:$amt)>, ++ Requires<[noHWROT32]>; ++ ++// 32-bit software rotate left by register. ++def ROTL32reg_sw : ++ NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$src, Int32Regs:$amt), ++ "{{\n\t" ++ ".reg .b32 %lhs;\n\t" ++ ".reg .b32 %rhs;\n\t" ++ ".reg .b32 %amt2;\n\t" ++ "shl.b32 \t%lhs, $src, $amt;\n\t" ++ "sub.s32 \t%amt2, 32, $amt;\n\t" ++ "shr.b32 \t%rhs, $src, %amt2;\n\t" ++ "add.u32 \t$dst, %lhs, %rhs;\n\t" ++ "}}", ++ [(set Int32Regs:$dst, (rotl (i32 Int32Regs:$src), (i32 Int32Regs:$amt)))]>, ++ Requires<[noHWROT32]>; ++ ++// 32-bit software rotate right by register. ++def ROTR32reg_sw : ++ NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$src, Int32Regs:$amt), ++ "{{\n\t" ++ ".reg .b32 %lhs;\n\t" ++ ".reg .b32 %rhs;\n\t" ++ ".reg .b32 %amt2;\n\t" ++ "shr.b32 \t%lhs, $src, $amt;\n\t" ++ "sub.s32 \t%amt2, 32, $amt;\n\t" ++ "shl.b32 \t%rhs, $src, %amt2;\n\t" ++ "add.u32 \t$dst, %lhs, %rhs;\n\t" ++ "}}", ++ [(set Int32Regs:$dst, (rotr (i32 Int32Regs:$src), (i32 Int32Regs:$amt)))]>, ++ Requires<[noHWROT32]>; ++ ++// 64-bit software rotate by immediate. $amt2 should equal 64 - $amt1. ++def ROT64imm_sw : ++ NVPTXInst<(outs Int64Regs:$dst), ++ (ins Int64Regs:$src, i32imm:$amt1, i32imm:$amt2), ++ "{{\n\t" ++ ".reg .b64 %lhs;\n\t" ++ ".reg .b64 %rhs;\n\t" ++ "shl.b64 \t%lhs, $src, $amt1;\n\t" ++ "shr.b64 \t%rhs, $src, $amt2;\n\t" ++ "add.u64 \t$dst, %lhs, %rhs;\n\t" ++ "}}", ++ []>; ++ ++def SUB_FRM_64 : SDNodeXFormgetTargetConstant(64-N->getZExtValue(), SDLoc(N), MVT::i32); ++}]>; ++ ++def : Pat<(rotl Int64Regs:$src, (i32 imm:$amt)), ++ (ROT64imm_sw Int64Regs:$src, imm:$amt, (SUB_FRM_64 node:$amt))>; ++def : Pat<(rotr Int64Regs:$src, (i32 imm:$amt)), ++ (ROT64imm_sw Int64Regs:$src, (SUB_FRM_64 node:$amt), imm:$amt)>; ++ ++// 64-bit software rotate left by register. ++def ROTL64reg_sw : ++ NVPTXInst<(outs Int64Regs:$dst), (ins Int64Regs:$src, Int32Regs:$amt), ++ "{{\n\t" ++ ".reg .b64 %lhs;\n\t" ++ ".reg .b64 %rhs;\n\t" ++ ".reg .u32 %amt2;\n\t" ++ "and.b32 \t%amt2, $amt, 63;\n\t" ++ "shl.b64 \t%lhs, $src, %amt2;\n\t" ++ "sub.u32 \t%amt2, 64, %amt2;\n\t" ++ "shr.b64 \t%rhs, $src, %amt2;\n\t" ++ "add.u64 \t$dst, %lhs, %rhs;\n\t" ++ "}}", ++ [(set Int64Regs:$dst, (rotl Int64Regs:$src, (i32 Int32Regs:$amt)))]>; ++ ++def ROTR64reg_sw : ++ NVPTXInst<(outs Int64Regs:$dst), (ins Int64Regs:$src, Int32Regs:$amt), ++ "{{\n\t" ++ ".reg .b64 %lhs;\n\t" ++ ".reg .b64 %rhs;\n\t" ++ ".reg .u32 %amt2;\n\t" ++ "and.b32 \t%amt2, $amt, 63;\n\t" ++ "shr.b64 \t%lhs, $src, %amt2;\n\t" ++ "sub.u32 \t%amt2, 64, %amt2;\n\t" ++ "shl.b64 \t%rhs, $src, %amt2;\n\t" ++ "add.u64 \t$dst, %lhs, %rhs;\n\t" ++ "}}", ++ [(set Int64Regs:$dst, (rotr Int64Regs:$src, (i32 Int32Regs:$amt)))]>; ++ ++// ++// Funnnel shift in clamp mode ++// ++ ++// Create SDNodes so they can be used in the DAG code, e.g. ++// NVPTXISelLowering (LowerShiftLeftParts and LowerShiftRightParts) ++def FUN_SHFL_CLAMP : SDNode<"NVPTXISD::FUN_SHFL_CLAMP", SDTIntShiftDOp, []>; ++def FUN_SHFR_CLAMP : SDNode<"NVPTXISD::FUN_SHFR_CLAMP", SDTIntShiftDOp, []>; ++ ++def FUNSHFLCLAMP : ++ NVPTXInst<(outs Int32Regs:$dst), ++ (ins Int32Regs:$lo, Int32Regs:$hi, Int32Regs:$amt), ++ "shf.l.clamp.b32 \t$dst, $lo, $hi, $amt;", ++ [(set Int32Regs:$dst, ++ (FUN_SHFL_CLAMP (i32 Int32Regs:$lo), (i32 Int32Regs:$hi), (i32 Int32Regs:$amt)))]>; ++ ++def FUNSHFRCLAMP : ++ NVPTXInst<(outs Int32Regs:$dst), ++ (ins Int32Regs:$lo, Int32Regs:$hi, Int32Regs:$amt), ++ "shf.r.clamp.b32 \t$dst, $lo, $hi, $amt;", ++ [(set Int32Regs:$dst, ++ (FUN_SHFR_CLAMP (i32 Int32Regs:$lo), (i32 Int32Regs:$hi), (i32 Int32Regs:$amt)))]>; + + // + // BFE - bit-field extract +@@ -3492,42 +3657,6 @@ + def: Pat<(v2i16 (scalar_to_vector (i16 Int16Regs:$a))), + (CVT_u32_u16 Int16Regs:$a, CvtNONE)>; + +-// +-// Funnel-Shift +-// +- +-// Create SDNodes so they can be used in the DAG code, e.g. +-// NVPTXISelLowering (LowerShiftLeftParts and LowerShiftRightParts) +-def fshl_clamp : SDNode<"NVPTXISD::FUN_SHFL_CLAMP", SDTIntShiftDOp, []>; +-def fshr_clamp : SDNode<"NVPTXISD::FUN_SHFR_CLAMP", SDTIntShiftDOp, []>; +- +-// Funnel shift, requires >= sm_32. Does not trap if amt is out of range, so +-// no side effects. +-let hasSideEffects = false in { +- multiclass ShfInst { +- def _i +- : NVPTXInst<(outs Int32Regs:$dst), +- (ins Int32Regs:$lo, Int32Regs:$hi, i32imm:$amt), +- "shf." # mode # ".b32 \t$dst, $lo, $hi, $amt;", +- [(set Int32Regs:$dst, +- (op (i32 Int32Regs:$lo), (i32 Int32Regs:$hi), (i32 imm:$amt)))]>, +- Requires<[hasHWROT32]>; +- +- def _r +- : NVPTXInst<(outs Int32Regs:$dst), +- (ins Int32Regs:$lo, Int32Regs:$hi, Int32Regs:$amt), +- "shf." # mode # ".b32 \t$dst, $lo, $hi, $amt;", +- [(set Int32Regs:$dst, +- (op (i32 Int32Regs:$lo), (i32 Int32Regs:$hi), (i32 Int32Regs:$amt)))]>, +- Requires<[hasHWROT32]>; +- } +- +- defm SHF_L_CLAMP : ShfInst<"l.clamp", fshl_clamp>; +- defm SHF_R_CLAMP : ShfInst<"r.clamp", fshr_clamp>; +- defm SHF_L_WRAP : ShfInst<"l.wrap", fshl>; +- defm SHF_R_WRAP : ShfInst<"r.wrap", fshr>; +-} +- + // Count leading zeros + let hasSideEffects = false in { + def CLZr32 : NVPTXInst<(outs Int32Regs:$d), (ins Int32Regs:$a), +diff -ruN --strip-trailing-cr a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td +--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td ++++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td +@@ -2537,45 +2537,59 @@ + : VLDG_G_ELE_V4<"v4.f32 \t{{$dst1, $dst2, $dst3, $dst4}}, [$src];", Float32Regs>; + + +-multiclass NG_TO_G { ++multiclass NG_TO_G { + def "" : NVPTXInst<(outs Int32Regs:$result), (ins Int32Regs:$src), +- "cvta." # Str # ".u32 \t$result, $src;", []>; ++ !strconcat("cvta.", Str, ".u32 \t$result, $src;"), ++ [(set Int32Regs:$result, (Intrin Int32Regs:$src))]>; + def _64 : NVPTXInst<(outs Int64Regs:$result), (ins Int64Regs:$src), +- "cvta." # Str # ".u64 \t$result, $src;", []>; ++ !strconcat("cvta.", Str, ".u64 \t$result, $src;"), ++ [(set Int64Regs:$result, (Intrin Int64Regs:$src))]>; ++ def _6432 : NVPTXInst<(outs Int64Regs:$result), (ins Int32Regs:$src), ++ "{{ .reg .b64 %tmp;\n\t" ++ #" cvt.u64.u32 \t%tmp, $src;\n\t" ++ #" cvta." # Str # ".u64 \t$result, %tmp; }}", ++ [(set Int64Regs:$result, (Intrin Int32Regs:$src))]>, ++ Requires<[ShortPtr]>; + } + +-multiclass G_TO_NG { ++multiclass G_TO_NG { + def "" : NVPTXInst<(outs Int32Regs:$result), (ins Int32Regs:$src), +- "cvta.to." # Str # ".u32 \t$result, $src;", []>; ++ !strconcat("cvta.to.", Str, ".u32 \t$result, $src;"), ++ [(set Int32Regs:$result, (Intrin Int32Regs:$src))]>; + def _64 : NVPTXInst<(outs Int64Regs:$result), (ins Int64Regs:$src), +- "cvta.to." # Str # ".u64 \t$result, $src;", []>; ++ !strconcat("cvta.to.", Str, ".u64 \t$result, $src;"), ++ [(set Int64Regs:$result, (Intrin Int64Regs:$src))]>; ++ def _3264 : NVPTXInst<(outs Int32Regs:$result), (ins Int64Regs:$src), ++ "{{ .reg .b64 %tmp;\n\t" ++ #" cvta.to." # Str # ".u64 \t%tmp, $src;\n\t" ++ #" cvt.u32.u64 \t$result, %tmp; }}", ++ [(set Int32Regs:$result, (Intrin Int64Regs:$src))]>, ++ Requires<[ShortPtr]>; + } + +-defm cvta_local : NG_TO_G<"local">; +-defm cvta_shared : NG_TO_G<"shared">; +-defm cvta_global : NG_TO_G<"global">; +-defm cvta_const : NG_TO_G<"const">; +- +-defm cvta_to_local : G_TO_NG<"local">; +-defm cvta_to_shared : G_TO_NG<"shared">; +-defm cvta_to_global : G_TO_NG<"global">; +-defm cvta_to_const : G_TO_NG<"const">; +- +-// nvvm.ptr.param.to.gen +-defm cvta_param : NG_TO_G<"param">; +- +-def : Pat<(int_nvvm_ptr_param_to_gen Int32Regs:$src), +- (cvta_param Int32Regs:$src)>; +- +-def : Pat<(int_nvvm_ptr_param_to_gen Int64Regs:$src), +- (cvta_param_64 Int64Regs:$src)>; ++defm cvta_local : NG_TO_G<"local", int_nvvm_ptr_local_to_gen, useShortPtrLocal>; ++defm cvta_shared : NG_TO_G<"shared", int_nvvm_ptr_shared_to_gen, useShortPtrShared>; ++defm cvta_global : NG_TO_G<"global", int_nvvm_ptr_global_to_gen, False>; ++defm cvta_const : NG_TO_G<"const", int_nvvm_ptr_constant_to_gen, useShortPtrConst>; ++defm cvta_param : NG_TO_G<"param", int_nvvm_ptr_param_to_gen, False>; ++ ++defm cvta_to_local : G_TO_NG<"local", int_nvvm_ptr_gen_to_local, useShortPtrLocal>; ++defm cvta_to_shared : G_TO_NG<"shared", int_nvvm_ptr_gen_to_shared, useShortPtrShared>; ++defm cvta_to_global : G_TO_NG<"global", int_nvvm_ptr_gen_to_global, False>; ++defm cvta_to_const : G_TO_NG<"const", int_nvvm_ptr_gen_to_constant, useShortPtrConst>; + + // nvvm.ptr.gen.to.param +-def : Pat<(int_nvvm_ptr_gen_to_param Int32Regs:$src), +- (IMOV32rr Int32Regs:$src)>; ++def nvvm_ptr_gen_to_param : NVPTXInst<(outs Int32Regs:$result), ++ (ins Int32Regs:$src), ++ "mov.u32 \t$result, $src;", ++ [(set Int32Regs:$result, ++ (int_nvvm_ptr_gen_to_param Int32Regs:$src))]>; ++def nvvm_ptr_gen_to_param_64 : NVPTXInst<(outs Int64Regs:$result), ++ (ins Int64Regs:$src), ++ "mov.u64 \t$result, $src;", ++ [(set Int64Regs:$result, ++ (int_nvvm_ptr_gen_to_param Int64Regs:$src))]>; + +-def : Pat<(int_nvvm_ptr_gen_to_param Int64Regs:$src), +- (IMOV64rr Int64Regs:$src)>; + + // nvvm.move intrinsicc + def nvvm_move_i16 : NVPTXInst<(outs Int16Regs:$r), (ins Int16Regs:$s), +@@ -2618,6 +2632,24 @@ + [(set Int64Regs:$r, + (int_nvvm_move_ptr texternalsym:$s))]>;*/ + ++ ++// MoveParam %r1, param ++// ptr_local_to_gen %r2, %r1 ++// ptr_gen_to_local %r3, %r2 ++// -> ++// mov %r1, param ++ ++// @TODO: Revisit this. There is a type ++// contradiction between iPTRAny and iPTR for the addr defs, so the move_sym ++// instructions are not currently defined. However, we can use the ptr ++// variants and the asm printer will do the right thing. ++def : Pat<(i64 (int_nvvm_ptr_gen_to_local (int_nvvm_ptr_local_to_gen ++ (MoveParam texternalsym:$src)))), ++ (nvvm_move_ptr64 texternalsym:$src)>; ++def : Pat<(i32 (int_nvvm_ptr_gen_to_local (int_nvvm_ptr_local_to_gen ++ (MoveParam texternalsym:$src)))), ++ (nvvm_move_ptr32 texternalsym:$src)>; ++ + def texsurf_handles + : NVPTXInst<(outs Int64Regs:$result), (ins imem:$src), + "mov.u64 \t$result, $src;", []>; +@@ -2701,9 +2733,134 @@ + def : Pat<(int_nvvm_read_ptx_sreg_envreg31), (MOV_SPECIAL ENVREG31)>; + + ++// rotate builtin support ++ ++def ROTATE_B32_HW_IMM ++ : NVPTXInst<(outs Int32Regs:$dst), ++ (ins Int32Regs:$src, i32imm:$amt), ++ "shf.l.wrap.b32 \t$dst, $src, $src, $amt;", ++ [(set Int32Regs:$dst, ++ (int_nvvm_rotate_b32 Int32Regs:$src, (i32 imm:$amt)))]>, ++ Requires<[hasHWROT32]> ; ++ ++def ROTATE_B32_HW_REG ++ : NVPTXInst<(outs Int32Regs:$dst), ++ (ins Int32Regs:$src, Int32Regs:$amt), ++ "shf.l.wrap.b32 \t$dst, $src, $src, $amt;", ++ [(set Int32Regs:$dst, ++ (int_nvvm_rotate_b32 Int32Regs:$src, Int32Regs:$amt))]>, ++ Requires<[hasHWROT32]> ; ++ ++def : Pat<(int_nvvm_rotate_b32 Int32Regs:$src, (i32 imm:$amt)), ++ (ROT32imm_sw Int32Regs:$src, imm:$amt, (SUB_FRM_32 node:$amt))>, ++ Requires<[noHWROT32]> ; ++ ++def : Pat<(int_nvvm_rotate_b32 Int32Regs:$src, Int32Regs:$amt), ++ (ROTL32reg_sw Int32Regs:$src, Int32Regs:$amt)>, ++ Requires<[noHWROT32]> ; ++ ++let hasSideEffects = false in { ++ def GET_LO_INT64 : NVPTXInst<(outs Int32Regs:$dst), (ins Int64Regs:$src), ++ !strconcat("{{\n\t", ++ ".reg .b32 %dummy;\n\t", ++ "mov.b64 \t{$dst,%dummy}, $src;\n\t", ++ "}}"), ++ []> ; ++ ++ def GET_HI_INT64 : NVPTXInst<(outs Int32Regs:$dst), (ins Int64Regs:$src), ++ !strconcat("{{\n\t", ++ ".reg .b32 %dummy;\n\t", ++ "mov.b64 \t{%dummy,$dst}, $src;\n\t", ++ "}}"), ++ []> ; ++} ++ ++let hasSideEffects = false in { ++ def PACK_TWO_INT32 ++ : NVPTXInst<(outs Int64Regs:$dst), (ins Int32Regs:$lo, Int32Regs:$hi), ++ "mov.b64 \t$dst, {{$lo, $hi}};", []> ; ++} ++ + def : Pat<(int_nvvm_swap_lo_hi_b64 Int64Regs:$src), +- (V2I32toI64 (I64toI32H Int64Regs:$src), +- (I64toI32L Int64Regs:$src))> ; ++ (PACK_TWO_INT32 (GET_HI_INT64 Int64Regs:$src), ++ (GET_LO_INT64 Int64Regs:$src))> ; ++ ++// Funnel shift, requires >= sm_32. Does not trap if amt is out of range, so ++// no side effects. ++let hasSideEffects = false in { ++ def SHF_L_WRAP_B32_IMM ++ : NVPTXInst<(outs Int32Regs:$dst), ++ (ins Int32Regs:$lo, Int32Regs:$hi, i32imm:$amt), ++ "shf.l.wrap.b32 \t$dst, $lo, $hi, $amt;",[]>, ++ Requires<[hasHWROT32]>; ++ ++ def SHF_L_WRAP_B32_REG ++ : NVPTXInst<(outs Int32Regs:$dst), ++ (ins Int32Regs:$lo, Int32Regs:$hi, Int32Regs:$amt), ++ "shf.l.wrap.b32 \t$dst, $lo, $hi, $amt;",[]>, ++ Requires<[hasHWROT32]>; ++ ++ def SHF_R_WRAP_B32_IMM ++ : NVPTXInst<(outs Int32Regs:$dst), ++ (ins Int32Regs:$lo, Int32Regs:$hi, i32imm:$amt), ++ "shf.r.wrap.b32 \t$dst, $lo, $hi, $amt;",[]>, ++ Requires<[hasHWROT32]>; ++ ++ def SHF_R_WRAP_B32_REG ++ : NVPTXInst<(outs Int32Regs:$dst), ++ (ins Int32Regs:$lo, Int32Regs:$hi, Int32Regs:$amt), ++ "shf.r.wrap.b32 \t$dst, $lo, $hi, $amt;",[]>, ++ Requires<[hasHWROT32]>; ++} ++ ++// HW version of rotate 64 ++def : Pat<(int_nvvm_rotate_b64 Int64Regs:$src, (i32 imm:$amt)), ++ (PACK_TWO_INT32 ++ (SHF_L_WRAP_B32_IMM (GET_HI_INT64 Int64Regs:$src), ++ (GET_LO_INT64 Int64Regs:$src), imm:$amt), ++ (SHF_L_WRAP_B32_IMM (GET_LO_INT64 Int64Regs:$src), ++ (GET_HI_INT64 Int64Regs:$src), imm:$amt))>, ++ Requires<[hasHWROT32]>; ++ ++def : Pat<(int_nvvm_rotate_b64 Int64Regs:$src, Int32Regs:$amt), ++ (PACK_TWO_INT32 ++ (SHF_L_WRAP_B32_REG (GET_HI_INT64 Int64Regs:$src), ++ (GET_LO_INT64 Int64Regs:$src), Int32Regs:$amt), ++ (SHF_L_WRAP_B32_REG (GET_LO_INT64 Int64Regs:$src), ++ (GET_HI_INT64 Int64Regs:$src), Int32Regs:$amt))>, ++ Requires<[hasHWROT32]>; ++ ++ ++def : Pat<(int_nvvm_rotate_right_b64 Int64Regs:$src, (i32 imm:$amt)), ++ (PACK_TWO_INT32 ++ (SHF_R_WRAP_B32_IMM (GET_LO_INT64 Int64Regs:$src), ++ (GET_HI_INT64 Int64Regs:$src), imm:$amt), ++ (SHF_R_WRAP_B32_IMM (GET_HI_INT64 Int64Regs:$src), ++ (GET_LO_INT64 Int64Regs:$src), imm:$amt))>, ++ Requires<[hasHWROT32]>; ++ ++def : Pat<(int_nvvm_rotate_right_b64 Int64Regs:$src, Int32Regs:$amt), ++ (PACK_TWO_INT32 ++ (SHF_R_WRAP_B32_REG (GET_LO_INT64 Int64Regs:$src), ++ (GET_HI_INT64 Int64Regs:$src), Int32Regs:$amt), ++ (SHF_R_WRAP_B32_REG (GET_HI_INT64 Int64Regs:$src), ++ (GET_LO_INT64 Int64Regs:$src), Int32Regs:$amt))>, ++ Requires<[hasHWROT32]>; ++ ++// SW version of rotate 64 ++def : Pat<(int_nvvm_rotate_b64 Int64Regs:$src, (i32 imm:$amt)), ++ (ROT64imm_sw Int64Regs:$src, imm:$amt, (SUB_FRM_64 node:$amt))>, ++ Requires<[noHWROT32]>; ++def : Pat<(int_nvvm_rotate_b64 Int64Regs:$src, Int32Regs:$amt), ++ (ROTL64reg_sw Int64Regs:$src, Int32Regs:$amt)>, ++ Requires<[noHWROT32]>; ++def : Pat<(int_nvvm_rotate_right_b64 Int64Regs:$src, (i32 imm:$amt)), ++ (ROT64imm_sw Int64Regs:$src, (SUB_FRM_64 node:$amt), imm:$amt)>, ++ Requires<[noHWROT32]>; ++def : Pat<(int_nvvm_rotate_right_b64 Int64Regs:$src, Int32Regs:$amt), ++ (ROTR64reg_sw Int64Regs:$src, Int32Regs:$amt)>, ++ Requires<[noHWROT32]>; ++ + + //----------------------------------- + // Texture Intrinsics +diff -ruN --strip-trailing-cr a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp +--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp ++++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp +@@ -1109,21 +1109,11 @@ + AddrSpaceCastSDNode *CastN = cast(N); + unsigned SrcAddrSpace = CastN->getSrcAddressSpace(); + unsigned DstAddrSpace = CastN->getDestAddressSpace(); +- SDLoc DL(N); + assert(SrcAddrSpace != DstAddrSpace && + "addrspacecast must be between different address spaces"); + + if (DstAddrSpace == ADDRESS_SPACE_GENERIC) { + // Specific to generic +- +- if (TM.is64Bit() && TM.getPointerSizeInBits(SrcAddrSpace) == 32) { +- SDValue CvtNone = +- CurDAG->getTargetConstant(NVPTX::PTXCvtMode::NONE, DL, MVT::i32); +- SDNode *Cvt = CurDAG->getMachineNode(NVPTX::CVT_u64_u32, DL, MVT::i64, +- Src, CvtNone); +- Src = SDValue(Cvt, 0); +- } +- + unsigned Opc; + switch (SrcAddrSpace) { + default: report_fatal_error("Bad address space in addrspacecast"); +@@ -1131,16 +1121,26 @@ + Opc = TM.is64Bit() ? NVPTX::cvta_global_64 : NVPTX::cvta_global; + break; + case ADDRESS_SPACE_SHARED: +- Opc = TM.is64Bit() ? NVPTX::cvta_shared_64 : NVPTX::cvta_shared; ++ Opc = TM.is64Bit() ? (TM.getPointerSizeInBits(SrcAddrSpace) == 32 ++ ? NVPTX::cvta_shared_6432 ++ : NVPTX::cvta_shared_64) ++ : NVPTX::cvta_shared; + break; + case ADDRESS_SPACE_CONST: +- Opc = TM.is64Bit() ? NVPTX::cvta_const_64 : NVPTX::cvta_const; ++ Opc = TM.is64Bit() ? (TM.getPointerSizeInBits(SrcAddrSpace) == 32 ++ ? NVPTX::cvta_const_6432 ++ : NVPTX::cvta_const_64) ++ : NVPTX::cvta_const; + break; + case ADDRESS_SPACE_LOCAL: +- Opc = TM.is64Bit() ? NVPTX::cvta_local_64 : NVPTX::cvta_local; ++ Opc = TM.is64Bit() ? (TM.getPointerSizeInBits(SrcAddrSpace) == 32 ++ ? NVPTX::cvta_local_6432 ++ : NVPTX::cvta_local_64) ++ : NVPTX::cvta_local; + break; + } +- ReplaceNode(N, CurDAG->getMachineNode(Opc, DL, N->getValueType(0), Src)); ++ ReplaceNode(N, CurDAG->getMachineNode(Opc, SDLoc(N), N->getValueType(0), ++ Src)); + return; + } else { + // Generic to specific +@@ -1153,28 +1153,30 @@ + Opc = TM.is64Bit() ? NVPTX::cvta_to_global_64 : NVPTX::cvta_to_global; + break; + case ADDRESS_SPACE_SHARED: +- Opc = TM.is64Bit() ? NVPTX::cvta_to_shared_64 : NVPTX::cvta_to_shared; ++ Opc = TM.is64Bit() ? (TM.getPointerSizeInBits(DstAddrSpace) == 32 ++ ? NVPTX::cvta_to_shared_3264 ++ : NVPTX::cvta_to_shared_64) ++ : NVPTX::cvta_to_shared; + break; + case ADDRESS_SPACE_CONST: +- Opc = TM.is64Bit() ? NVPTX::cvta_to_const_64 : NVPTX::cvta_to_const; ++ Opc = TM.is64Bit() ? (TM.getPointerSizeInBits(DstAddrSpace) == 32 ++ ? NVPTX::cvta_to_const_3264 ++ : NVPTX::cvta_to_const_64) ++ : NVPTX::cvta_to_const; + break; + case ADDRESS_SPACE_LOCAL: +- Opc = TM.is64Bit() ? NVPTX::cvta_to_local_64 : NVPTX::cvta_to_local; ++ Opc = TM.is64Bit() ? (TM.getPointerSizeInBits(DstAddrSpace) == 32 ++ ? NVPTX::cvta_to_local_3264 ++ : NVPTX::cvta_to_local_64) ++ : NVPTX::cvta_to_local; + break; + case ADDRESS_SPACE_PARAM: +- Opc = TM.is64Bit() ? NVPTX::IMOV64rr : NVPTX::IMOV32rr; ++ Opc = TM.is64Bit() ? NVPTX::nvvm_ptr_gen_to_param_64 ++ : NVPTX::nvvm_ptr_gen_to_param; + break; + } +- +- SDNode *CVTA = CurDAG->getMachineNode(Opc, DL, N->getValueType(0), Src); +- if (TM.is64Bit() && TM.getPointerSizeInBits(DstAddrSpace) == 32) { +- SDValue CvtNone = +- CurDAG->getTargetConstant(NVPTX::PTXCvtMode::NONE, DL, MVT::i32); +- CVTA = CurDAG->getMachineNode(NVPTX::CVT_u32_u64, DL, MVT::i32, +- SDValue(CVTA, 0), CvtNone); +- } +- +- ReplaceNode(N, CVTA); ++ ReplaceNode(N, CurDAG->getMachineNode(Opc, SDLoc(N), N->getValueType(0), ++ Src)); + return; + } + } +diff -ruN --strip-trailing-cr a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp +--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp ++++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp +@@ -594,13 +594,20 @@ + setOperationAction(ISD::BITREVERSE, MVT::i32, Legal); + setOperationAction(ISD::BITREVERSE, MVT::i64, Legal); + +- setOperationAction({ISD::ROTL, ISD::ROTR}, +- {MVT::i8, MVT::i16, MVT::v2i16, MVT::i32, MVT::i64}, +- Expand); +- +- if (STI.hasHWROT32()) +- setOperationAction({ISD::FSHL, ISD::FSHR}, MVT::i32, Legal); ++ // TODO: we may consider expanding ROTL/ROTR on older GPUs. Currently on GPUs ++ // that don't have h/w rotation we lower them to multi-instruction assembly. ++ // See ROT*_sw in NVPTXIntrInfo.td ++ setOperationAction(ISD::ROTL, MVT::i64, Legal); ++ setOperationAction(ISD::ROTR, MVT::i64, Legal); ++ setOperationAction(ISD::ROTL, MVT::i32, Legal); ++ setOperationAction(ISD::ROTR, MVT::i32, Legal); + ++ setOperationAction(ISD::ROTL, MVT::i16, Expand); ++ setOperationAction(ISD::ROTL, MVT::v2i16, Expand); ++ setOperationAction(ISD::ROTR, MVT::i16, Expand); ++ setOperationAction(ISD::ROTR, MVT::v2i16, Expand); ++ setOperationAction(ISD::ROTL, MVT::i8, Expand); ++ setOperationAction(ISD::ROTR, MVT::i8, Expand); + setOperationAction(ISD::BSWAP, MVT::i16, Expand); + + setOperationAction(ISD::BR_JT, MVT::Other, Custom); +diff -ruN --strip-trailing-cr a/llvm/test/Assembler/auto_upgrade_nvvm_intrinsics.ll b/llvm/test/Assembler/auto_upgrade_nvvm_intrinsics.ll +--- a/llvm/test/Assembler/auto_upgrade_nvvm_intrinsics.ll ++++ b/llvm/test/Assembler/auto_upgrade_nvvm_intrinsics.ll +@@ -31,19 +31,6 @@ + declare i64 @llvm.nvvm.bitcast.d2ll(double) + declare double @llvm.nvvm.bitcast.ll2d(i64) + +-declare i32 @llvm.nvvm.rotate.b32(i32, i32) +-declare i64 @llvm.nvvm.rotate.right.b64(i64, i32) +-declare i64 @llvm.nvvm.rotate.b64(i64, i32) +- +-declare ptr addrspace(1) @llvm.nvvm.ptr.gen.to.global.p1.p0(ptr) +-declare ptr addrspace(3) @llvm.nvvm.ptr.gen.to.shared.p3.p0(ptr) +-declare ptr addrspace(4) @llvm.nvvm.ptr.gen.to.constant.p4.p0(ptr) +-declare ptr addrspace(5) @llvm.nvvm.ptr.gen.to.local.p5.p0(ptr) +-declare ptr @llvm.nvvm.ptr.global.to.gen.p0.p1(ptr addrspace(1)) +-declare ptr @llvm.nvvm.ptr.shared.to.gen.p0.p3(ptr addrspace(3)) +-declare ptr @llvm.nvvm.ptr.constant.to.gen.p0.p4(ptr addrspace(4)) +-declare ptr @llvm.nvvm.ptr.local.to.gen.p0.p5(ptr addrspace(5)) +- + ; CHECK-LABEL: @simple_upgrade + define void @simple_upgrade(i32 %a, i64 %b, i16 %c) { + ; CHECK: call i32 @llvm.bitreverse.i32(i32 %a) +@@ -152,42 +139,4 @@ + %r4 = call double @llvm.nvvm.bitcast.ll2d(i64 %b) + + ret void +-} +- +-; CHECK-LABEL: @rotate +-define void @rotate(i32 %a, i64 %b) { +-; CHECK: call i32 @llvm.fshl.i32(i32 %a, i32 %a, i32 6) +-; CHECK: call i64 @llvm.fshr.i64(i64 %b, i64 %b, i64 7) +-; CHECK: call i64 @llvm.fshl.i64(i64 %b, i64 %b, i64 8) +-; +- %r1 = call i32 @llvm.nvvm.rotate.b32(i32 %a, i32 6) +- %r2 = call i64 @llvm.nvvm.rotate.right.b64(i64 %b, i32 7) +- %r3 = call i64 @llvm.nvvm.rotate.b64(i64 %b, i32 8) +- ret void +-} +- +-; CHECK-LABEL: @addrspacecast +-define void @addrspacecast(ptr %p0) { +-; CHECK: %1 = addrspacecast ptr %p0 to ptr addrspace(1) +-; CHECK: %2 = addrspacecast ptr addrspace(1) %1 to ptr +-; CHECK: %3 = addrspacecast ptr %2 to ptr addrspace(3) +-; CHECK: %4 = addrspacecast ptr addrspace(3) %3 to ptr +-; CHECK: %5 = addrspacecast ptr %4 to ptr addrspace(4) +-; CHECK: %6 = addrspacecast ptr addrspace(4) %5 to ptr +-; CHECK: %7 = addrspacecast ptr %6 to ptr addrspace(5) +-; CHECK: %8 = addrspacecast ptr addrspace(5) %7 to ptr +-; +- %p1 = call ptr addrspace(1) @llvm.nvvm.ptr.gen.to.global.p1.p0(ptr %p0) +- %p2 = call ptr @llvm.nvvm.ptr.global.to.gen.p0.p1(ptr addrspace(1) %p1) +- +- %p3 = call ptr addrspace(3) @llvm.nvvm.ptr.gen.to.shared.p3.p0(ptr %p2) +- %p4 = call ptr @llvm.nvvm.ptr.shared.to.gen.p0.p3(ptr addrspace(3) %p3) +- +- %p5 = call ptr addrspace(4) @llvm.nvvm.ptr.gen.to.constant.p4.p0(ptr %p4) +- %p6 = call ptr @llvm.nvvm.ptr.constant.to.gen.p0.p4(ptr addrspace(4) %p5) +- +- %p7 = call ptr addrspace(5) @llvm.nvvm.ptr.gen.to.local.p5.p0(ptr %p6) +- %p8 = call ptr @llvm.nvvm.ptr.local.to.gen.p0.p5(ptr addrspace(5) %p7) +- +- ret void +-} ++} +\ No newline at end of file +diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/AMDGPU/freeze.ll b/llvm/test/CodeGen/AMDGPU/freeze.ll +--- a/llvm/test/CodeGen/AMDGPU/freeze.ll ++++ b/llvm/test/CodeGen/AMDGPU/freeze.ll +@@ -1,1856 +0,0 @@ +-; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py +-; RUN: llc -global-isel=0 -mtriple=amdgcn-mesa-mesa3d -mcpu=gfx1010 < %s | FileCheck -check-prefixes=GFX10,GFX10-SDAG %s +-; RUN: llc -global-isel=1 -mtriple=amdgcn-mesa-mesa3d -mcpu=gfx1010 < %s | FileCheck -check-prefixes=GFX10,GFX10-GISEL %s +-; RUN: llc -global-isel=0 -mtriple=amdgcn-mesa-mesa3d -mcpu=gfx1100 -amdgpu-enable-delay-alu=0 < %s | FileCheck -check-prefixes=GFX11,GFX11-SDAG %s +-; RUN: llc -global-isel=1 -mtriple=amdgcn-mesa-mesa3d -mcpu=gfx1100 -amdgpu-enable-delay-alu=0 < %s | FileCheck -check-prefixes=GFX11,GFX11-GISEL %s +- +-define void @freeze_v2i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { +-; GFX10-LABEL: freeze_v2i32: +-; GFX10: ; %bb.0: +-; GFX10-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +-; GFX10-NEXT: global_load_dwordx2 v[0:1], v[0:1], off +-; GFX10-NEXT: s_waitcnt vmcnt(0) +-; GFX10-NEXT: global_store_dwordx2 v[2:3], v[0:1], off +-; GFX10-NEXT: s_setpc_b64 s[30:31] +-; +-; GFX11-LABEL: freeze_v2i32: +-; GFX11: ; %bb.0: +-; GFX11-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +-; GFX11-NEXT: global_load_b64 v[0:1], v[0:1], off +-; GFX11-NEXT: s_waitcnt vmcnt(0) +-; GFX11-NEXT: global_store_b64 v[2:3], v[0:1], off +-; GFX11-NEXT: s_setpc_b64 s[30:31] +- %a = load <2 x i32>, ptr addrspace(1) %ptra, align 4 +- %freeze = freeze <2 x i32> %a +- store <2 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 +- ret void +-} +- +-define void @freeze_v3i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { +-; GFX10-LABEL: freeze_v3i32: +-; GFX10: ; %bb.0: +-; GFX10-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +-; GFX10-NEXT: global_load_dwordx3 v[4:6], v[0:1], off +-; GFX10-NEXT: s_waitcnt vmcnt(0) +-; GFX10-NEXT: global_store_dwordx3 v[2:3], v[4:6], off +-; GFX10-NEXT: s_setpc_b64 s[30:31] +-; +-; GFX11-LABEL: freeze_v3i32: +-; GFX11: ; %bb.0: +-; GFX11-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +-; GFX11-NEXT: global_load_b96 v[4:6], v[0:1], off +-; GFX11-NEXT: s_waitcnt vmcnt(0) +-; GFX11-NEXT: global_store_b96 v[2:3], v[4:6], off +-; GFX11-NEXT: s_setpc_b64 s[30:31] +- %a = load <3 x i32>, ptr addrspace(1) %ptra, align 4 +- %freeze = freeze <3 x i32> %a +- store <3 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 +- ret void +-} +- +-define void @freeze_v4i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { +-; GFX10-LABEL: freeze_v4i32: +-; GFX10: ; %bb.0: +-; GFX10-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +-; GFX10-NEXT: global_load_dwordx4 v[4:7], v[0:1], off +-; GFX10-NEXT: s_waitcnt vmcnt(0) +-; GFX10-NEXT: global_store_dwordx4 v[2:3], v[4:7], off +-; GFX10-NEXT: s_setpc_b64 s[30:31] +-; +-; GFX11-LABEL: freeze_v4i32: +-; GFX11: ; %bb.0: +-; GFX11-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +-; GFX11-NEXT: global_load_b128 v[4:7], v[0:1], off +-; GFX11-NEXT: s_waitcnt vmcnt(0) +-; GFX11-NEXT: global_store_b128 v[2:3], v[4:7], off +-; GFX11-NEXT: s_setpc_b64 s[30:31] +- %a = load <4 x i32>, ptr addrspace(1) %ptra, align 4 +- %freeze = freeze <4 x i32> %a +- store <4 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 +- ret void +-} +- +-define void @freeze_v5i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { +-; GFX10-SDAG-LABEL: freeze_v5i32: +-; GFX10-SDAG: ; %bb.0: +-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +-; GFX10-SDAG-NEXT: s_clause 0x1 +-; GFX10-SDAG-NEXT: global_load_dword v8, v[0:1], off offset:16 +-; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off +-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) +-; GFX10-SDAG-NEXT: global_store_dword v[2:3], v8, off offset:16 +-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) +-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off +-; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] +-; +-; GFX10-GISEL-LABEL: freeze_v5i32: +-; GFX10-GISEL: ; %bb.0: +-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +-; GFX10-GISEL-NEXT: s_clause 0x1 +-; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off +-; GFX10-GISEL-NEXT: global_load_dword v8, v[0:1], off offset:16 +-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) +-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off +-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) +-; GFX10-GISEL-NEXT: global_store_dword v[2:3], v8, off offset:16 +-; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] +-; +-; GFX11-SDAG-LABEL: freeze_v5i32: +-; GFX11-SDAG: ; %bb.0: +-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +-; GFX11-SDAG-NEXT: s_clause 0x1 +-; GFX11-SDAG-NEXT: global_load_b32 v8, v[0:1], off offset:16 +-; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off +-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) +-; GFX11-SDAG-NEXT: global_store_b32 v[2:3], v8, off offset:16 +-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) +-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off +-; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] +-; +-; GFX11-GISEL-LABEL: freeze_v5i32: +-; GFX11-GISEL: ; %bb.0: +-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +-; GFX11-GISEL-NEXT: s_clause 0x1 +-; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off +-; GFX11-GISEL-NEXT: global_load_b32 v0, v[0:1], off offset:16 +-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) +-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off +-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) +-; GFX11-GISEL-NEXT: global_store_b32 v[2:3], v0, off offset:16 +-; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] +- %a = load <5 x i32>, ptr addrspace(1) %ptra, align 4 +- %freeze = freeze <5 x i32> %a +- store <5 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 +- ret void +-} +- +-define void @freeze_v6i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { +-; GFX10-SDAG-LABEL: freeze_v6i32: +-; GFX10-SDAG: ; %bb.0: +-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +-; GFX10-SDAG-NEXT: s_clause 0x1 +-; GFX10-SDAG-NEXT: global_load_dwordx2 v[8:9], v[0:1], off offset:16 +-; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off +-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) +-; GFX10-SDAG-NEXT: global_store_dwordx2 v[2:3], v[8:9], off offset:16 +-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) +-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off +-; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] +-; +-; GFX10-GISEL-LABEL: freeze_v6i32: +-; GFX10-GISEL: ; %bb.0: +-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +-; GFX10-GISEL-NEXT: s_clause 0x1 +-; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off +-; GFX10-GISEL-NEXT: global_load_dwordx2 v[8:9], v[0:1], off offset:16 +-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) +-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off +-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) +-; GFX10-GISEL-NEXT: global_store_dwordx2 v[2:3], v[8:9], off offset:16 +-; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] +-; +-; GFX11-SDAG-LABEL: freeze_v6i32: +-; GFX11-SDAG: ; %bb.0: +-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +-; GFX11-SDAG-NEXT: s_clause 0x1 +-; GFX11-SDAG-NEXT: global_load_b64 v[8:9], v[0:1], off offset:16 +-; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off +-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) +-; GFX11-SDAG-NEXT: global_store_b64 v[2:3], v[8:9], off offset:16 +-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) +-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off +-; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] +-; +-; GFX11-GISEL-LABEL: freeze_v6i32: +-; GFX11-GISEL: ; %bb.0: +-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +-; GFX11-GISEL-NEXT: s_clause 0x1 +-; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off +-; GFX11-GISEL-NEXT: global_load_b64 v[0:1], v[0:1], off offset:16 +-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) +-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off +-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) +-; GFX11-GISEL-NEXT: global_store_b64 v[2:3], v[0:1], off offset:16 +-; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] +- %a = load <6 x i32>, ptr addrspace(1) %ptra, align 4 +- %freeze = freeze <6 x i32> %a +- store <6 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 +- ret void +-} +- +-define void @freeze_v7i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { +-; GFX10-SDAG-LABEL: freeze_v7i32: +-; GFX10-SDAG: ; %bb.0: +-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +-; GFX10-SDAG-NEXT: s_clause 0x1 +-; GFX10-SDAG-NEXT: global_load_dwordx3 v[8:10], v[0:1], off offset:16 +-; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off +-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) +-; GFX10-SDAG-NEXT: global_store_dwordx3 v[2:3], v[8:10], off offset:16 +-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) +-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off +-; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] +-; +-; GFX10-GISEL-LABEL: freeze_v7i32: +-; GFX10-GISEL: ; %bb.0: +-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +-; GFX10-GISEL-NEXT: s_clause 0x1 +-; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off +-; GFX10-GISEL-NEXT: global_load_dwordx3 v[8:10], v[0:1], off offset:16 +-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) +-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off +-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) +-; GFX10-GISEL-NEXT: global_store_dwordx3 v[2:3], v[8:10], off offset:16 +-; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] +-; +-; GFX11-SDAG-LABEL: freeze_v7i32: +-; GFX11-SDAG: ; %bb.0: +-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +-; GFX11-SDAG-NEXT: s_clause 0x1 +-; GFX11-SDAG-NEXT: global_load_b96 v[8:10], v[0:1], off offset:16 +-; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off +-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) +-; GFX11-SDAG-NEXT: global_store_b96 v[2:3], v[8:10], off offset:16 +-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) +-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off +-; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] +-; +-; GFX11-GISEL-LABEL: freeze_v7i32: +-; GFX11-GISEL: ; %bb.0: +-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +-; GFX11-GISEL-NEXT: s_clause 0x1 +-; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off +-; GFX11-GISEL-NEXT: global_load_b96 v[8:10], v[0:1], off offset:16 +-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) +-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off +-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) +-; GFX11-GISEL-NEXT: global_store_b96 v[2:3], v[8:10], off offset:16 +-; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] +- %a = load <7 x i32>, ptr addrspace(1) %ptra, align 4 +- %freeze = freeze <7 x i32> %a +- store <7 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 +- ret void +-} +- +-define void @freeze_v8i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { +-; GFX10-SDAG-LABEL: freeze_v8i32: +-; GFX10-SDAG: ; %bb.0: +-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +-; GFX10-SDAG-NEXT: s_clause 0x1 +-; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:16 +-; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off +-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) +-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:16 +-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) +-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off +-; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] +-; +-; GFX10-GISEL-LABEL: freeze_v8i32: +-; GFX10-GISEL: ; %bb.0: +-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +-; GFX10-GISEL-NEXT: s_clause 0x1 +-; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off +-; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 +-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) +-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off +-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) +-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 +-; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] +-; +-; GFX11-SDAG-LABEL: freeze_v8i32: +-; GFX11-SDAG: ; %bb.0: +-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +-; GFX11-SDAG-NEXT: s_clause 0x1 +-; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:16 +-; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off +-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) +-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:16 +-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) +-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off +-; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] +-; +-; GFX11-GISEL-LABEL: freeze_v8i32: +-; GFX11-GISEL: ; %bb.0: +-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +-; GFX11-GISEL-NEXT: s_clause 0x1 +-; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off +-; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 +-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) +-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off +-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) +-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 +-; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] +- %a = load <8 x i32>, ptr addrspace(1) %ptra, align 4 +- %freeze = freeze <8 x i32> %a +- store <8 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 +- ret void +-} +- +-define void @freeze_v9i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { +-; GFX10-SDAG-LABEL: freeze_v9i32: +-; GFX10-SDAG: ; %bb.0: +-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +-; GFX10-SDAG-NEXT: s_clause 0x2 +-; GFX10-SDAG-NEXT: global_load_dword v12, v[0:1], off offset:32 +-; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off +-; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 +-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) +-; GFX10-SDAG-NEXT: global_store_dword v[2:3], v12, off offset:32 +-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) +-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off +-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) +-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 +-; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] +-; +-; GFX10-GISEL-LABEL: freeze_v9i32: +-; GFX10-GISEL: ; %bb.0: +-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +-; GFX10-GISEL-NEXT: s_clause 0x2 +-; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off +-; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 +-; GFX10-GISEL-NEXT: global_load_dword v12, v[0:1], off offset:32 +-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) +-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off +-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) +-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 +-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) +-; GFX10-GISEL-NEXT: global_store_dword v[2:3], v12, off offset:32 +-; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] +-; +-; GFX11-SDAG-LABEL: freeze_v9i32: +-; GFX11-SDAG: ; %bb.0: +-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +-; GFX11-SDAG-NEXT: s_clause 0x2 +-; GFX11-SDAG-NEXT: global_load_b32 v12, v[0:1], off offset:32 +-; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off +-; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 +-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) +-; GFX11-SDAG-NEXT: global_store_b32 v[2:3], v12, off offset:32 +-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) +-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off +-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) +-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 +-; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] +-; +-; GFX11-GISEL-LABEL: freeze_v9i32: +-; GFX11-GISEL: ; %bb.0: +-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +-; GFX11-GISEL-NEXT: s_clause 0x2 +-; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off +-; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 +-; GFX11-GISEL-NEXT: global_load_b32 v0, v[0:1], off offset:32 +-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) +-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off +-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) +-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 +-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) +-; GFX11-GISEL-NEXT: global_store_b32 v[2:3], v0, off offset:32 +-; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] +- %a = load <9 x i32>, ptr addrspace(1) %ptra, align 4 +- %freeze = freeze <9 x i32> %a +- store <9 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 +- ret void +-} +- +-define void @freeze_v10i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { +-; GFX10-LABEL: freeze_v10i32: +-; GFX10: ; %bb.0: +-; GFX10-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +-; GFX10-NEXT: s_clause 0x2 +-; GFX10-NEXT: global_load_dwordx4 v[4:7], v[0:1], off +-; GFX10-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 +-; GFX10-NEXT: global_load_dwordx2 v[12:13], v[0:1], off offset:32 +-; GFX10-NEXT: s_waitcnt vmcnt(2) +-; GFX10-NEXT: global_store_dwordx4 v[2:3], v[4:7], off +-; GFX10-NEXT: s_waitcnt vmcnt(1) +-; GFX10-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 +-; GFX10-NEXT: s_waitcnt vmcnt(0) +-; GFX10-NEXT: global_store_dwordx2 v[2:3], v[12:13], off offset:32 +-; GFX10-NEXT: s_setpc_b64 s[30:31] +-; +-; GFX11-LABEL: freeze_v10i32: +-; GFX11: ; %bb.0: +-; GFX11-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +-; GFX11-NEXT: s_clause 0x2 +-; GFX11-NEXT: global_load_b128 v[4:7], v[0:1], off +-; GFX11-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 +-; GFX11-NEXT: global_load_b64 v[0:1], v[0:1], off offset:32 +-; GFX11-NEXT: s_waitcnt vmcnt(2) +-; GFX11-NEXT: global_store_b128 v[2:3], v[4:7], off +-; GFX11-NEXT: s_waitcnt vmcnt(1) +-; GFX11-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 +-; GFX11-NEXT: s_waitcnt vmcnt(0) +-; GFX11-NEXT: global_store_b64 v[2:3], v[0:1], off offset:32 +-; GFX11-NEXT: s_setpc_b64 s[30:31] +- %a = load <10 x i32>, ptr addrspace(1) %ptra, align 4 +- %freeze = freeze <10 x i32> %a +- store <10 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 +- ret void +-} +- +-define void @freeze_v11i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { +-; GFX10-LABEL: freeze_v11i32: +-; GFX10: ; %bb.0: +-; GFX10-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +-; GFX10-NEXT: s_clause 0x2 +-; GFX10-NEXT: global_load_dwordx4 v[4:7], v[0:1], off +-; GFX10-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 +-; GFX10-NEXT: global_load_dwordx3 v[12:14], v[0:1], off offset:32 +-; GFX10-NEXT: s_waitcnt vmcnt(2) +-; GFX10-NEXT: global_store_dwordx4 v[2:3], v[4:7], off +-; GFX10-NEXT: s_waitcnt vmcnt(1) +-; GFX10-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 +-; GFX10-NEXT: s_waitcnt vmcnt(0) +-; GFX10-NEXT: global_store_dwordx3 v[2:3], v[12:14], off offset:32 +-; GFX10-NEXT: s_setpc_b64 s[30:31] +-; +-; GFX11-LABEL: freeze_v11i32: +-; GFX11: ; %bb.0: +-; GFX11-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +-; GFX11-NEXT: s_clause 0x2 +-; GFX11-NEXT: global_load_b128 v[4:7], v[0:1], off +-; GFX11-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 +-; GFX11-NEXT: global_load_b96 v[12:14], v[0:1], off offset:32 +-; GFX11-NEXT: s_waitcnt vmcnt(2) +-; GFX11-NEXT: global_store_b128 v[2:3], v[4:7], off +-; GFX11-NEXT: s_waitcnt vmcnt(1) +-; GFX11-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 +-; GFX11-NEXT: s_waitcnt vmcnt(0) +-; GFX11-NEXT: global_store_b96 v[2:3], v[12:14], off offset:32 +-; GFX11-NEXT: s_setpc_b64 s[30:31] +- %a = load <11 x i32>, ptr addrspace(1) %ptra, align 4 +- %freeze = freeze <11 x i32> %a +- store <11 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 +- ret void +-} +- +-define void @freeze_v12i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { +-; GFX10-LABEL: freeze_v12i32: +-; GFX10: ; %bb.0: +-; GFX10-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +-; GFX10-NEXT: s_clause 0x2 +-; GFX10-NEXT: global_load_dwordx4 v[4:7], v[0:1], off +-; GFX10-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 +-; GFX10-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 +-; GFX10-NEXT: s_waitcnt vmcnt(2) +-; GFX10-NEXT: global_store_dwordx4 v[2:3], v[4:7], off +-; GFX10-NEXT: s_waitcnt vmcnt(1) +-; GFX10-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 +-; GFX10-NEXT: s_waitcnt vmcnt(0) +-; GFX10-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 +-; GFX10-NEXT: s_setpc_b64 s[30:31] +-; +-; GFX11-LABEL: freeze_v12i32: +-; GFX11: ; %bb.0: +-; GFX11-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +-; GFX11-NEXT: s_clause 0x2 +-; GFX11-NEXT: global_load_b128 v[4:7], v[0:1], off +-; GFX11-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 +-; GFX11-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 +-; GFX11-NEXT: s_waitcnt vmcnt(2) +-; GFX11-NEXT: global_store_b128 v[2:3], v[4:7], off +-; GFX11-NEXT: s_waitcnt vmcnt(1) +-; GFX11-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 +-; GFX11-NEXT: s_waitcnt vmcnt(0) +-; GFX11-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 +-; GFX11-NEXT: s_setpc_b64 s[30:31] +- %a = load <12 x i32>, ptr addrspace(1) %ptra, align 4 +- %freeze = freeze <12 x i32> %a +- store <12 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 +- ret void +-} +-define void @freeze_v13i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { +-; GFX10-SDAG-LABEL: freeze_v13i32: +-; GFX10-SDAG: ; %bb.0: +-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +-; GFX10-SDAG-NEXT: s_clause 0x3 +-; GFX10-SDAG-NEXT: global_load_dword v16, v[0:1], off offset:48 +-; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:32 +-; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off +-; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:16 +-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) +-; GFX10-SDAG-NEXT: global_store_dword v[2:3], v16, off offset:48 +-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) +-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:32 +-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) +-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off +-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) +-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:16 +-; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] +-; +-; GFX10-GISEL-LABEL: freeze_v13i32: +-; GFX10-GISEL: ; %bb.0: +-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +-; GFX10-GISEL-NEXT: s_clause 0x3 +-; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off +-; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 +-; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 +-; GFX10-GISEL-NEXT: global_load_dword v16, v[0:1], off offset:48 +-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) +-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off +-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) +-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 +-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) +-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 +-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) +-; GFX10-GISEL-NEXT: global_store_dword v[2:3], v16, off offset:48 +-; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] +-; +-; GFX11-SDAG-LABEL: freeze_v13i32: +-; GFX11-SDAG: ; %bb.0: +-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +-; GFX11-SDAG-NEXT: s_clause 0x3 +-; GFX11-SDAG-NEXT: global_load_b32 v16, v[0:1], off offset:48 +-; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:32 +-; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off +-; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off offset:16 +-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) +-; GFX11-SDAG-NEXT: global_store_b32 v[2:3], v16, off offset:48 +-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) +-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:32 +-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) +-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off +-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) +-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off offset:16 +-; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] +-; +-; GFX11-GISEL-LABEL: freeze_v13i32: +-; GFX11-GISEL: ; %bb.0: +-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +-; GFX11-GISEL-NEXT: s_clause 0x3 +-; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off +-; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 +-; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 +-; GFX11-GISEL-NEXT: global_load_b32 v0, v[0:1], off offset:48 +-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) +-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off +-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) +-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 +-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) +-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 +-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) +-; GFX11-GISEL-NEXT: global_store_b32 v[2:3], v0, off offset:48 +-; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] +- %a = load <13 x i32>, ptr addrspace(1) %ptra, align 4 +- %freeze = freeze <13 x i32> %a +- store <13 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 +- ret void +-} +- +-define void @freeze_v14i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { +-; GFX10-SDAG-LABEL: freeze_v14i32: +-; GFX10-SDAG: ; %bb.0: +-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +-; GFX10-SDAG-NEXT: s_clause 0x3 +-; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:32 +-; GFX10-SDAG-NEXT: global_load_dwordx2 v[16:17], v[0:1], off offset:48 +-; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off +-; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:16 +-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) +-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:32 +-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) +-; GFX10-SDAG-NEXT: global_store_dwordx2 v[2:3], v[16:17], off offset:48 +-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) +-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off +-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) +-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:16 +-; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] +-; +-; GFX10-GISEL-LABEL: freeze_v14i32: +-; GFX10-GISEL: ; %bb.0: +-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +-; GFX10-GISEL-NEXT: s_clause 0x3 +-; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off +-; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 +-; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 +-; GFX10-GISEL-NEXT: global_load_dwordx2 v[16:17], v[0:1], off offset:48 +-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) +-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off +-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) +-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 +-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) +-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 +-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) +-; GFX10-GISEL-NEXT: global_store_dwordx2 v[2:3], v[16:17], off offset:48 +-; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] +-; +-; GFX11-SDAG-LABEL: freeze_v14i32: +-; GFX11-SDAG: ; %bb.0: +-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +-; GFX11-SDAG-NEXT: s_clause 0x3 +-; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:32 +-; GFX11-SDAG-NEXT: global_load_b64 v[16:17], v[0:1], off offset:48 +-; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off +-; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off offset:16 +-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) +-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:32 +-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) +-; GFX11-SDAG-NEXT: global_store_b64 v[2:3], v[16:17], off offset:48 +-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) +-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off +-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) +-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off offset:16 +-; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] +-; +-; GFX11-GISEL-LABEL: freeze_v14i32: +-; GFX11-GISEL: ; %bb.0: +-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +-; GFX11-GISEL-NEXT: s_clause 0x3 +-; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off +-; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 +-; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 +-; GFX11-GISEL-NEXT: global_load_b64 v[0:1], v[0:1], off offset:48 +-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) +-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off +-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) +-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 +-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) +-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 +-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) +-; GFX11-GISEL-NEXT: global_store_b64 v[2:3], v[0:1], off offset:48 +-; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] +- %a = load <14 x i32>, ptr addrspace(1) %ptra, align 4 +- %freeze = freeze <14 x i32> %a +- store <14 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 +- ret void +-} +- +-define void @freeze_v15i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { +-; GFX10-SDAG-LABEL: freeze_v15i32: +-; GFX10-SDAG: ; %bb.0: +-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +-; GFX10-SDAG-NEXT: s_clause 0x3 +-; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:32 +-; GFX10-SDAG-NEXT: global_load_dwordx3 v[16:18], v[0:1], off offset:48 +-; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off +-; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:16 +-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) +-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:32 +-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) +-; GFX10-SDAG-NEXT: global_store_dwordx3 v[2:3], v[16:18], off offset:48 +-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) +-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off +-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) +-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:16 +-; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] +-; +-; GFX10-GISEL-LABEL: freeze_v15i32: +-; GFX10-GISEL: ; %bb.0: +-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +-; GFX10-GISEL-NEXT: s_clause 0x3 +-; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off +-; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 +-; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 +-; GFX10-GISEL-NEXT: global_load_dwordx3 v[16:18], v[0:1], off offset:48 +-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) +-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off +-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) +-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 +-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) +-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 +-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) +-; GFX10-GISEL-NEXT: global_store_dwordx3 v[2:3], v[16:18], off offset:48 +-; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] +-; +-; GFX11-SDAG-LABEL: freeze_v15i32: +-; GFX11-SDAG: ; %bb.0: +-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +-; GFX11-SDAG-NEXT: s_clause 0x3 +-; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:32 +-; GFX11-SDAG-NEXT: global_load_b96 v[16:18], v[0:1], off offset:48 +-; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off +-; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off offset:16 +-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) +-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:32 +-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) +-; GFX11-SDAG-NEXT: global_store_b96 v[2:3], v[16:18], off offset:48 +-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) +-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off +-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) +-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off offset:16 +-; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] +-; +-; GFX11-GISEL-LABEL: freeze_v15i32: +-; GFX11-GISEL: ; %bb.0: +-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +-; GFX11-GISEL-NEXT: s_clause 0x3 +-; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off +-; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 +-; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 +-; GFX11-GISEL-NEXT: global_load_b96 v[16:18], v[0:1], off offset:48 +-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) +-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off +-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) +-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 +-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) +-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 +-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) +-; GFX11-GISEL-NEXT: global_store_b96 v[2:3], v[16:18], off offset:48 +-; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] +- %a = load <15 x i32>, ptr addrspace(1) %ptra, align 4 +- %freeze = freeze <15 x i32> %a +- store <15 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 +- ret void +-} +- +-define void @freeze_v16i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { +-; GFX10-SDAG-LABEL: freeze_v16i32: +-; GFX10-SDAG: ; %bb.0: +-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +-; GFX10-SDAG-NEXT: s_clause 0x3 +-; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:32 +-; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:48 +-; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off +-; GFX10-SDAG-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:16 +-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) +-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:32 +-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) +-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:48 +-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) +-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off +-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) +-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:16 +-; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] +-; +-; GFX10-GISEL-LABEL: freeze_v16i32: +-; GFX10-GISEL: ; %bb.0: +-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +-; GFX10-GISEL-NEXT: s_clause 0x3 +-; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off +-; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 +-; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 +-; GFX10-GISEL-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:48 +-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) +-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off +-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) +-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 +-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) +-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 +-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) +-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:48 +-; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] +-; +-; GFX11-SDAG-LABEL: freeze_v16i32: +-; GFX11-SDAG: ; %bb.0: +-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +-; GFX11-SDAG-NEXT: s_clause 0x3 +-; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:32 +-; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off offset:48 +-; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off +-; GFX11-SDAG-NEXT: global_load_b128 v[16:19], v[0:1], off offset:16 +-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) +-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:32 +-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) +-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off offset:48 +-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) +-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off +-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) +-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[16:19], off offset:16 +-; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] +-; +-; GFX11-GISEL-LABEL: freeze_v16i32: +-; GFX11-GISEL: ; %bb.0: +-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +-; GFX11-GISEL-NEXT: s_clause 0x3 +-; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off +-; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 +-; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 +-; GFX11-GISEL-NEXT: global_load_b128 v[16:19], v[0:1], off offset:48 +-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) +-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off +-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) +-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 +-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) +-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 +-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) +-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[16:19], off offset:48 +-; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] +- %a = load <16 x i32>, ptr addrspace(1) %ptra, align 4 +- %freeze = freeze <16 x i32> %a +- store <16 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 +- ret void +-} +- +-define void @freeze_v17i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { +-; GFX10-SDAG-LABEL: freeze_v17i32: +-; GFX10-SDAG: ; %bb.0: +-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +-; GFX10-SDAG-NEXT: s_clause 0x4 +-; GFX10-SDAG-NEXT: global_load_dword v20, v[0:1], off offset:64 +-; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:32 +-; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:48 +-; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off +-; GFX10-SDAG-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:16 +-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(4) +-; GFX10-SDAG-NEXT: global_store_dword v[2:3], v20, off offset:64 +-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) +-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:32 +-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) +-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:48 +-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) +-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off +-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) +-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:16 +-; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] +-; +-; GFX10-GISEL-LABEL: freeze_v17i32: +-; GFX10-GISEL: ; %bb.0: +-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +-; GFX10-GISEL-NEXT: s_clause 0x4 +-; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off +-; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 +-; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 +-; GFX10-GISEL-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:48 +-; GFX10-GISEL-NEXT: global_load_dword v20, v[0:1], off offset:64 +-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(4) +-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off +-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) +-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 +-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) +-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 +-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) +-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:48 +-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) +-; GFX10-GISEL-NEXT: global_store_dword v[2:3], v20, off offset:64 +-; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] +-; +-; GFX11-SDAG-LABEL: freeze_v17i32: +-; GFX11-SDAG: ; %bb.0: +-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +-; GFX11-SDAG-NEXT: s_clause 0x4 +-; GFX11-SDAG-NEXT: global_load_b32 v20, v[0:1], off offset:64 +-; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:32 +-; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off offset:48 +-; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off +-; GFX11-SDAG-NEXT: global_load_b128 v[16:19], v[0:1], off offset:16 +-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(4) +-; GFX11-SDAG-NEXT: global_store_b32 v[2:3], v20, off offset:64 +-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) +-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:32 +-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) +-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off offset:48 +-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) +-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off +-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) +-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[16:19], off offset:16 +-; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] +-; +-; GFX11-GISEL-LABEL: freeze_v17i32: +-; GFX11-GISEL: ; %bb.0: +-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +-; GFX11-GISEL-NEXT: s_clause 0x4 +-; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off +-; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 +-; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 +-; GFX11-GISEL-NEXT: global_load_b128 v[16:19], v[0:1], off offset:48 +-; GFX11-GISEL-NEXT: global_load_b32 v0, v[0:1], off offset:64 +-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(4) +-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off +-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) +-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 +-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) +-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 +-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) +-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[16:19], off offset:48 +-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) +-; GFX11-GISEL-NEXT: global_store_b32 v[2:3], v0, off offset:64 +-; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] +- %a = load <17 x i32>, ptr addrspace(1) %ptra, align 4 +- %freeze = freeze <17 x i32> %a +- store <17 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 +- ret void +-} +- +-define void @freeze_v18i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { +-; GFX10-SDAG-LABEL: freeze_v18i32: +-; GFX10-SDAG: ; %bb.0: +-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +-; GFX10-SDAG-NEXT: s_clause 0x4 +-; GFX10-SDAG-NEXT: global_load_dwordx2 v[20:21], v[0:1], off offset:64 +-; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:32 +-; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:48 +-; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off +-; GFX10-SDAG-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:16 +-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(4) +-; GFX10-SDAG-NEXT: global_store_dwordx2 v[2:3], v[20:21], off offset:64 +-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) +-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:32 +-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) +-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:48 +-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) +-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off +-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) +-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:16 +-; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] +-; +-; GFX10-GISEL-LABEL: freeze_v18i32: +-; GFX10-GISEL: ; %bb.0: +-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +-; GFX10-GISEL-NEXT: s_clause 0x4 +-; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off +-; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 +-; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 +-; GFX10-GISEL-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:48 +-; GFX10-GISEL-NEXT: global_load_dwordx2 v[20:21], v[0:1], off offset:64 +-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(4) +-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off +-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) +-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 +-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) +-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 +-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) +-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:48 +-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) +-; GFX10-GISEL-NEXT: global_store_dwordx2 v[2:3], v[20:21], off offset:64 +-; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] +-; +-; GFX11-SDAG-LABEL: freeze_v18i32: +-; GFX11-SDAG: ; %bb.0: +-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +-; GFX11-SDAG-NEXT: s_clause 0x4 +-; GFX11-SDAG-NEXT: global_load_b64 v[20:21], v[0:1], off offset:64 +-; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:32 +-; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off offset:48 +-; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off +-; GFX11-SDAG-NEXT: global_load_b128 v[16:19], v[0:1], off offset:16 +-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(4) +-; GFX11-SDAG-NEXT: global_store_b64 v[2:3], v[20:21], off offset:64 +-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) +-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:32 +-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) +-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off offset:48 +-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) +-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off +-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) +-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[16:19], off offset:16 +-; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] +-; +-; GFX11-GISEL-LABEL: freeze_v18i32: +-; GFX11-GISEL: ; %bb.0: +-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +-; GFX11-GISEL-NEXT: s_clause 0x4 +-; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off +-; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 +-; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 +-; GFX11-GISEL-NEXT: global_load_b128 v[16:19], v[0:1], off offset:48 +-; GFX11-GISEL-NEXT: global_load_b64 v[0:1], v[0:1], off offset:64 +-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(4) +-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off +-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) +-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 +-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) +-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 +-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) +-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[16:19], off offset:48 +-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) +-; GFX11-GISEL-NEXT: global_store_b64 v[2:3], v[0:1], off offset:64 +-; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] +- %a = load <18 x i32>, ptr addrspace(1) %ptra, align 4 +- %freeze = freeze <18 x i32> %a +- store <18 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 +- ret void +-} +- +-define void @freeze_v19i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { +-; GFX10-SDAG-LABEL: freeze_v19i32: +-; GFX10-SDAG: ; %bb.0: +-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +-; GFX10-SDAG-NEXT: s_clause 0x4 +-; GFX10-SDAG-NEXT: global_load_dwordx3 v[20:22], v[0:1], off offset:64 +-; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:32 +-; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:48 +-; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off +-; GFX10-SDAG-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:16 +-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(4) +-; GFX10-SDAG-NEXT: global_store_dwordx3 v[2:3], v[20:22], off offset:64 +-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) +-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:32 +-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) +-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:48 +-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) +-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off +-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) +-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:16 +-; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] +-; +-; GFX10-GISEL-LABEL: freeze_v19i32: +-; GFX10-GISEL: ; %bb.0: +-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +-; GFX10-GISEL-NEXT: s_clause 0x4 +-; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off +-; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 +-; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 +-; GFX10-GISEL-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:48 +-; GFX10-GISEL-NEXT: global_load_dwordx3 v[20:22], v[0:1], off offset:64 +-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(4) +-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off +-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) +-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 +-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) +-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 +-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) +-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:48 +-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) +-; GFX10-GISEL-NEXT: global_store_dwordx3 v[2:3], v[20:22], off offset:64 +-; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] +-; +-; GFX11-SDAG-LABEL: freeze_v19i32: +-; GFX11-SDAG: ; %bb.0: +-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +-; GFX11-SDAG-NEXT: s_clause 0x4 +-; GFX11-SDAG-NEXT: global_load_b96 v[20:22], v[0:1], off offset:64 +-; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:32 +-; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off offset:48 +-; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off +-; GFX11-SDAG-NEXT: global_load_b128 v[16:19], v[0:1], off offset:16 +-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(4) +-; GFX11-SDAG-NEXT: global_store_b96 v[2:3], v[20:22], off offset:64 +-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) +-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:32 +-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) +-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off offset:48 +-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) +-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off +-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) +-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[16:19], off offset:16 +-; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] +-; +-; GFX11-GISEL-LABEL: freeze_v19i32: +-; GFX11-GISEL: ; %bb.0: +-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +-; GFX11-GISEL-NEXT: s_clause 0x4 +-; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off +-; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 +-; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 +-; GFX11-GISEL-NEXT: global_load_b128 v[16:19], v[0:1], off offset:48 +-; GFX11-GISEL-NEXT: global_load_b96 v[20:22], v[0:1], off offset:64 +-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(4) +-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off +-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) +-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 +-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) +-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 +-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) +-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[16:19], off offset:48 +-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) +-; GFX11-GISEL-NEXT: global_store_b96 v[2:3], v[20:22], off offset:64 +-; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] +- %a = load <19 x i32>, ptr addrspace(1) %ptra, align 4 +- %freeze = freeze <19 x i32> %a +- store <19 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 +- ret void +-} +- +-define void @freeze_v20i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { +-; GFX10-SDAG-LABEL: freeze_v20i32: +-; GFX10-SDAG: ; %bb.0: +-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +-; GFX10-SDAG-NEXT: s_clause 0x4 +-; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:64 +-; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:32 +-; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:48 +-; GFX10-SDAG-NEXT: global_load_dwordx4 v[16:19], v[0:1], off +-; GFX10-SDAG-NEXT: global_load_dwordx4 v[20:23], v[0:1], off offset:16 +-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(4) +-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:64 +-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) +-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:32 +-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) +-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:48 +-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) +-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[16:19], off +-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) +-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[20:23], off offset:16 +-; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] +-; +-; GFX10-GISEL-LABEL: freeze_v20i32: +-; GFX10-GISEL: ; %bb.0: +-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +-; GFX10-GISEL-NEXT: s_clause 0x4 +-; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off +-; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 +-; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 +-; GFX10-GISEL-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:48 +-; GFX10-GISEL-NEXT: global_load_dwordx4 v[20:23], v[0:1], off offset:64 +-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(4) +-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off +-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) +-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 +-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) +-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 +-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) +-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:48 +-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) +-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[20:23], off offset:64 +-; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] +-; +-; GFX11-SDAG-LABEL: freeze_v20i32: +-; GFX11-SDAG: ; %bb.0: +-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +-; GFX11-SDAG-NEXT: s_clause 0x4 +-; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:64 +-; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off offset:32 +-; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off offset:48 +-; GFX11-SDAG-NEXT: global_load_b128 v[16:19], v[0:1], off +-; GFX11-SDAG-NEXT: global_load_b128 v[20:23], v[0:1], off offset:16 +-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(4) +-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:64 +-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) +-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off offset:32 +-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) +-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off offset:48 +-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) +-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[16:19], off +-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) +-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[20:23], off offset:16 +-; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] +-; +-; GFX11-GISEL-LABEL: freeze_v20i32: +-; GFX11-GISEL: ; %bb.0: +-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +-; GFX11-GISEL-NEXT: s_clause 0x4 +-; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off +-; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 +-; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 +-; GFX11-GISEL-NEXT: global_load_b128 v[16:19], v[0:1], off offset:48 +-; GFX11-GISEL-NEXT: global_load_b128 v[20:23], v[0:1], off offset:64 +-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(4) +-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off +-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) +-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 +-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) +-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 +-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) +-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[16:19], off offset:48 +-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) +-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[20:23], off offset:64 +-; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] +- %a = load <20 x i32>, ptr addrspace(1) %ptra, align 4 +- %freeze = freeze <20 x i32> %a +- store <20 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 +- ret void +-} +- +-define void @freeze_v21i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { +-; GFX10-SDAG-LABEL: freeze_v21i32: +-; GFX10-SDAG: ; %bb.0: +-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +-; GFX10-SDAG-NEXT: s_clause 0x5 +-; GFX10-SDAG-NEXT: global_load_dword v24, v[0:1], off offset:80 +-; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:64 +-; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:32 +-; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:48 +-; GFX10-SDAG-NEXT: global_load_dwordx4 v[16:19], v[0:1], off +-; GFX10-SDAG-NEXT: global_load_dwordx4 v[20:23], v[0:1], off offset:16 +-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(5) +-; GFX10-SDAG-NEXT: global_store_dword v[2:3], v24, off offset:80 +-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(4) +-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:64 +-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) +-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:32 +-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) +-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:48 +-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) +-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[16:19], off +-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) +-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[20:23], off offset:16 +-; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] +-; +-; GFX10-GISEL-LABEL: freeze_v21i32: +-; GFX10-GISEL: ; %bb.0: +-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +-; GFX10-GISEL-NEXT: s_clause 0x5 +-; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off +-; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 +-; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 +-; GFX10-GISEL-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:48 +-; GFX10-GISEL-NEXT: global_load_dwordx4 v[20:23], v[0:1], off offset:64 +-; GFX10-GISEL-NEXT: global_load_dword v24, v[0:1], off offset:80 +-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(5) +-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off +-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(4) +-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 +-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) +-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 +-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) +-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:48 +-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) +-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[20:23], off offset:64 +-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) +-; GFX10-GISEL-NEXT: global_store_dword v[2:3], v24, off offset:80 +-; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] +-; +-; GFX11-SDAG-LABEL: freeze_v21i32: +-; GFX11-SDAG: ; %bb.0: +-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +-; GFX11-SDAG-NEXT: s_clause 0x5 +-; GFX11-SDAG-NEXT: global_load_b32 v24, v[0:1], off offset:80 +-; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:64 +-; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off offset:32 +-; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off offset:48 +-; GFX11-SDAG-NEXT: global_load_b128 v[16:19], v[0:1], off +-; GFX11-SDAG-NEXT: global_load_b128 v[20:23], v[0:1], off offset:16 +-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(5) +-; GFX11-SDAG-NEXT: global_store_b32 v[2:3], v24, off offset:80 +-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(4) +-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:64 +-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) +-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off offset:32 +-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) +-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off offset:48 +-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) +-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[16:19], off +-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) +-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[20:23], off offset:16 +-; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] +-; +-; GFX11-GISEL-LABEL: freeze_v21i32: +-; GFX11-GISEL: ; %bb.0: +-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +-; GFX11-GISEL-NEXT: s_clause 0x5 +-; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off +-; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 +-; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 +-; GFX11-GISEL-NEXT: global_load_b128 v[16:19], v[0:1], off offset:48 +-; GFX11-GISEL-NEXT: global_load_b128 v[20:23], v[0:1], off offset:64 +-; GFX11-GISEL-NEXT: global_load_b32 v0, v[0:1], off offset:80 +-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(5) +-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off +-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(4) +-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 +-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) +-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 +-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) +-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[16:19], off offset:48 +-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) +-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[20:23], off offset:64 +-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) +-; GFX11-GISEL-NEXT: global_store_b32 v[2:3], v0, off offset:80 +-; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] +- %a = load <21 x i32>, ptr addrspace(1) %ptra, align 4 +- %freeze = freeze <21 x i32> %a +- store <21 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 +- ret void +-} +- +-define void @freeze_v22i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { +-; GFX10-SDAG-LABEL: freeze_v22i32: +-; GFX10-SDAG: ; %bb.0: +-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +-; GFX10-SDAG-NEXT: s_clause 0x5 +-; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:64 +-; GFX10-SDAG-NEXT: global_load_dwordx2 v[24:25], v[0:1], off offset:80 +-; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:32 +-; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:48 +-; GFX10-SDAG-NEXT: global_load_dwordx4 v[16:19], v[0:1], off +-; GFX10-SDAG-NEXT: global_load_dwordx4 v[20:23], v[0:1], off offset:16 +-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(5) +-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:64 +-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(4) +-; GFX10-SDAG-NEXT: global_store_dwordx2 v[2:3], v[24:25], off offset:80 +-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) +-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:32 +-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) +-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:48 +-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) +-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[16:19], off +-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) +-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[20:23], off offset:16 +-; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] +-; +-; GFX10-GISEL-LABEL: freeze_v22i32: +-; GFX10-GISEL: ; %bb.0: +-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +-; GFX10-GISEL-NEXT: s_clause 0x5 +-; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off +-; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 +-; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 +-; GFX10-GISEL-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:48 +-; GFX10-GISEL-NEXT: global_load_dwordx4 v[20:23], v[0:1], off offset:64 +-; GFX10-GISEL-NEXT: global_load_dwordx2 v[24:25], v[0:1], off offset:80 +-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(5) +-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off +-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(4) +-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 +-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) +-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 +-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) +-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:48 +-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) +-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[20:23], off offset:64 +-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) +-; GFX10-GISEL-NEXT: global_store_dwordx2 v[2:3], v[24:25], off offset:80 +-; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] +-; +-; GFX11-SDAG-LABEL: freeze_v22i32: +-; GFX11-SDAG: ; %bb.0: +-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +-; GFX11-SDAG-NEXT: s_clause 0x5 +-; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:64 +-; GFX11-SDAG-NEXT: global_load_b64 v[24:25], v[0:1], off offset:80 +-; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off offset:32 +-; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off offset:48 +-; GFX11-SDAG-NEXT: global_load_b128 v[16:19], v[0:1], off +-; GFX11-SDAG-NEXT: global_load_b128 v[20:23], v[0:1], off offset:16 +-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(5) +-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:64 +-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(4) +-; GFX11-SDAG-NEXT: global_store_b64 v[2:3], v[24:25], off offset:80 +-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) +-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off offset:32 +-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) +-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off offset:48 +-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) +-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[16:19], off +-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) +-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[20:23], off offset:16 +-; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] +-; +-; GFX11-GISEL-LABEL: freeze_v22i32: +-; GFX11-GISEL: ; %bb.0: +-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +-; GFX11-GISEL-NEXT: s_clause 0x5 +-; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off +-; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 +-; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 +-; GFX11-GISEL-NEXT: global_load_b128 v[16:19], v[0:1], off offset:48 +-; GFX11-GISEL-NEXT: global_load_b128 v[20:23], v[0:1], off offset:64 +-; GFX11-GISEL-NEXT: global_load_b64 v[0:1], v[0:1], off offset:80 +-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(5) +-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off +-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(4) +-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 +-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) +-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 +-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) +-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[16:19], off offset:48 +-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) +-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[20:23], off offset:64 +-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) +-; GFX11-GISEL-NEXT: global_store_b64 v[2:3], v[0:1], off offset:80 +-; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] +- %a = load <22 x i32>, ptr addrspace(1) %ptra, align 4 +- %freeze = freeze <22 x i32> %a +- store <22 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 +- ret void +-} +- +-define void @freeze_v30i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { +-; GFX10-SDAG-LABEL: freeze_v30i32: +-; GFX10-SDAG: ; %bb.0: +-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +-; GFX10-SDAG-NEXT: s_clause 0x7 +-; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:96 +-; GFX10-SDAG-NEXT: global_load_dwordx2 v[32:33], v[0:1], off offset:112 +-; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:64 +-; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:80 +-; GFX10-SDAG-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:32 +-; GFX10-SDAG-NEXT: global_load_dwordx4 v[20:23], v[0:1], off offset:48 +-; GFX10-SDAG-NEXT: global_load_dwordx4 v[24:27], v[0:1], off +-; GFX10-SDAG-NEXT: global_load_dwordx4 v[28:31], v[0:1], off offset:16 +-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(7) +-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:96 +-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(6) +-; GFX10-SDAG-NEXT: global_store_dwordx2 v[2:3], v[32:33], off offset:112 +-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(5) +-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:64 +-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(4) +-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:80 +-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) +-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:32 +-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) +-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[20:23], off offset:48 +-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) +-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[24:27], off +-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) +-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[28:31], off offset:16 +-; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] +-; +-; GFX10-GISEL-LABEL: freeze_v30i32: +-; GFX10-GISEL: ; %bb.0: +-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +-; GFX10-GISEL-NEXT: s_clause 0x7 +-; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off +-; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 +-; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 +-; GFX10-GISEL-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:48 +-; GFX10-GISEL-NEXT: global_load_dwordx4 v[20:23], v[0:1], off offset:64 +-; GFX10-GISEL-NEXT: global_load_dwordx4 v[24:27], v[0:1], off offset:80 +-; GFX10-GISEL-NEXT: global_load_dwordx4 v[28:31], v[0:1], off offset:96 +-; GFX10-GISEL-NEXT: global_load_dwordx2 v[32:33], v[0:1], off offset:112 +-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(7) +-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off +-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(6) +-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 +-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(5) +-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 +-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(4) +-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:48 +-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) +-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[20:23], off offset:64 +-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) +-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[24:27], off offset:80 +-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) +-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[28:31], off offset:96 +-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) +-; GFX10-GISEL-NEXT: global_store_dwordx2 v[2:3], v[32:33], off offset:112 +-; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] +-; +-; GFX11-SDAG-LABEL: freeze_v30i32: +-; GFX11-SDAG: ; %bb.0: +-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +-; GFX11-SDAG-NEXT: s_clause 0x7 +-; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:96 +-; GFX11-SDAG-NEXT: global_load_b64 v[32:33], v[0:1], off offset:112 +-; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off offset:64 +-; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off offset:80 +-; GFX11-SDAG-NEXT: global_load_b128 v[16:19], v[0:1], off offset:32 +-; GFX11-SDAG-NEXT: global_load_b128 v[20:23], v[0:1], off offset:48 +-; GFX11-SDAG-NEXT: global_load_b128 v[24:27], v[0:1], off +-; GFX11-SDAG-NEXT: global_load_b128 v[28:31], v[0:1], off offset:16 +-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(7) +-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:96 +-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(6) +-; GFX11-SDAG-NEXT: global_store_b64 v[2:3], v[32:33], off offset:112 +-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(5) +-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off offset:64 +-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(4) +-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off offset:80 +-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) +-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[16:19], off offset:32 +-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) +-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[20:23], off offset:48 +-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) +-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[24:27], off +-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) +-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[28:31], off offset:16 +-; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] +-; +-; GFX11-GISEL-LABEL: freeze_v30i32: +-; GFX11-GISEL: ; %bb.0: +-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +-; GFX11-GISEL-NEXT: s_clause 0x7 +-; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off +-; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 +-; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 +-; GFX11-GISEL-NEXT: global_load_b128 v[16:19], v[0:1], off offset:48 +-; GFX11-GISEL-NEXT: global_load_b128 v[20:23], v[0:1], off offset:64 +-; GFX11-GISEL-NEXT: global_load_b128 v[24:27], v[0:1], off offset:80 +-; GFX11-GISEL-NEXT: global_load_b128 v[28:31], v[0:1], off offset:96 +-; GFX11-GISEL-NEXT: global_load_b64 v[0:1], v[0:1], off offset:112 +-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(7) +-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off +-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(6) +-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 +-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(5) +-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 +-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(4) +-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[16:19], off offset:48 +-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) +-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[20:23], off offset:64 +-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) +-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[24:27], off offset:80 +-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) +-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[28:31], off offset:96 +-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) +-; GFX11-GISEL-NEXT: global_store_b64 v[2:3], v[0:1], off offset:112 +-; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] +- %a = load <30 x i32>, ptr addrspace(1) %ptra, align 4 +- %freeze = freeze <30 x i32> %a +- store <30 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 +- ret void +-} +- +-define void @freeze_v31i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { +-; GFX10-SDAG-LABEL: freeze_v31i32: +-; GFX10-SDAG: ; %bb.0: +-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +-; GFX10-SDAG-NEXT: s_clause 0x7 +-; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:96 +-; GFX10-SDAG-NEXT: global_load_dwordx3 v[32:34], v[0:1], off offset:112 +-; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:64 +-; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:80 +-; GFX10-SDAG-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:32 +-; GFX10-SDAG-NEXT: global_load_dwordx4 v[20:23], v[0:1], off offset:48 +-; GFX10-SDAG-NEXT: global_load_dwordx4 v[24:27], v[0:1], off +-; GFX10-SDAG-NEXT: global_load_dwordx4 v[28:31], v[0:1], off offset:16 +-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(7) +-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:96 +-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(6) +-; GFX10-SDAG-NEXT: global_store_dwordx3 v[2:3], v[32:34], off offset:112 +-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(5) +-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:64 +-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(4) +-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:80 +-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) +-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:32 +-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) +-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[20:23], off offset:48 +-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) +-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[24:27], off +-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) +-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[28:31], off offset:16 +-; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] +-; +-; GFX10-GISEL-LABEL: freeze_v31i32: +-; GFX10-GISEL: ; %bb.0: +-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +-; GFX10-GISEL-NEXT: s_clause 0x7 +-; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off +-; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 +-; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 +-; GFX10-GISEL-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:48 +-; GFX10-GISEL-NEXT: global_load_dwordx4 v[20:23], v[0:1], off offset:64 +-; GFX10-GISEL-NEXT: global_load_dwordx4 v[24:27], v[0:1], off offset:80 +-; GFX10-GISEL-NEXT: global_load_dwordx4 v[28:31], v[0:1], off offset:96 +-; GFX10-GISEL-NEXT: global_load_dwordx3 v[32:34], v[0:1], off offset:112 +-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(7) +-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off +-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(6) +-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 +-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(5) +-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 +-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(4) +-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:48 +-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) +-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[20:23], off offset:64 +-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) +-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[24:27], off offset:80 +-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) +-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[28:31], off offset:96 +-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) +-; GFX10-GISEL-NEXT: global_store_dwordx3 v[2:3], v[32:34], off offset:112 +-; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] +-; +-; GFX11-SDAG-LABEL: freeze_v31i32: +-; GFX11-SDAG: ; %bb.0: +-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +-; GFX11-SDAG-NEXT: s_clause 0x7 +-; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:96 +-; GFX11-SDAG-NEXT: global_load_b96 v[32:34], v[0:1], off offset:112 +-; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off offset:64 +-; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off offset:80 +-; GFX11-SDAG-NEXT: global_load_b128 v[16:19], v[0:1], off offset:32 +-; GFX11-SDAG-NEXT: global_load_b128 v[20:23], v[0:1], off offset:48 +-; GFX11-SDAG-NEXT: global_load_b128 v[24:27], v[0:1], off +-; GFX11-SDAG-NEXT: global_load_b128 v[28:31], v[0:1], off offset:16 +-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(7) +-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:96 +-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(6) +-; GFX11-SDAG-NEXT: global_store_b96 v[2:3], v[32:34], off offset:112 +-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(5) +-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off offset:64 +-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(4) +-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off offset:80 +-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) +-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[16:19], off offset:32 +-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) +-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[20:23], off offset:48 +-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) +-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[24:27], off +-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) +-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[28:31], off offset:16 +-; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] +-; +-; GFX11-GISEL-LABEL: freeze_v31i32: +-; GFX11-GISEL: ; %bb.0: +-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +-; GFX11-GISEL-NEXT: s_clause 0x7 +-; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off +-; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 +-; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 +-; GFX11-GISEL-NEXT: global_load_b128 v[16:19], v[0:1], off offset:48 +-; GFX11-GISEL-NEXT: global_load_b128 v[20:23], v[0:1], off offset:64 +-; GFX11-GISEL-NEXT: global_load_b128 v[24:27], v[0:1], off offset:80 +-; GFX11-GISEL-NEXT: global_load_b128 v[28:31], v[0:1], off offset:96 +-; GFX11-GISEL-NEXT: global_load_b96 v[32:34], v[0:1], off offset:112 +-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(7) +-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off +-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(6) +-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 +-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(5) +-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 +-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(4) +-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[16:19], off offset:48 +-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) +-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[20:23], off offset:64 +-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) +-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[24:27], off offset:80 +-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) +-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[28:31], off offset:96 +-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) +-; GFX11-GISEL-NEXT: global_store_b96 v[2:3], v[32:34], off offset:112 +-; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] +- %a = load <31 x i32>, ptr addrspace(1) %ptra, align 4 +- %freeze = freeze <31 x i32> %a +- store <31 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 +- ret void +-} +- +-define void @freeze_v32i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { +-; GFX10-SDAG-LABEL: freeze_v32i32: +-; GFX10-SDAG: ; %bb.0: +-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +-; GFX10-SDAG-NEXT: s_clause 0x7 +-; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:96 +-; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:112 +-; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:64 +-; GFX10-SDAG-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:80 +-; GFX10-SDAG-NEXT: global_load_dwordx4 v[20:23], v[0:1], off offset:32 +-; GFX10-SDAG-NEXT: global_load_dwordx4 v[24:27], v[0:1], off offset:48 +-; GFX10-SDAG-NEXT: global_load_dwordx4 v[28:31], v[0:1], off +-; GFX10-SDAG-NEXT: global_load_dwordx4 v[32:35], v[0:1], off offset:16 +-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(7) +-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:96 +-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(6) +-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:112 +-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(5) +-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:64 +-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(4) +-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:80 +-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) +-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[20:23], off offset:32 +-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) +-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[24:27], off offset:48 +-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) +-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[28:31], off +-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) +-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[32:35], off offset:16 +-; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] +-; +-; GFX10-GISEL-LABEL: freeze_v32i32: +-; GFX10-GISEL: ; %bb.0: +-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +-; GFX10-GISEL-NEXT: s_clause 0x7 +-; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off +-; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 +-; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 +-; GFX10-GISEL-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:48 +-; GFX10-GISEL-NEXT: global_load_dwordx4 v[20:23], v[0:1], off offset:64 +-; GFX10-GISEL-NEXT: global_load_dwordx4 v[24:27], v[0:1], off offset:80 +-; GFX10-GISEL-NEXT: global_load_dwordx4 v[28:31], v[0:1], off offset:96 +-; GFX10-GISEL-NEXT: global_load_dwordx4 v[32:35], v[0:1], off offset:112 +-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(7) +-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off +-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(6) +-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 +-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(5) +-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 +-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(4) +-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:48 +-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) +-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[20:23], off offset:64 +-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) +-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[24:27], off offset:80 +-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) +-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[28:31], off offset:96 +-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) +-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[32:35], off offset:112 +-; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] +-; +-; GFX11-SDAG-LABEL: freeze_v32i32: +-; GFX11-SDAG: ; %bb.0: +-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +-; GFX11-SDAG-NEXT: s_clause 0x7 +-; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:96 +-; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off offset:112 +-; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off offset:64 +-; GFX11-SDAG-NEXT: global_load_b128 v[16:19], v[0:1], off offset:80 +-; GFX11-SDAG-NEXT: global_load_b128 v[20:23], v[0:1], off offset:32 +-; GFX11-SDAG-NEXT: global_load_b128 v[24:27], v[0:1], off offset:48 +-; GFX11-SDAG-NEXT: global_load_b128 v[28:31], v[0:1], off +-; GFX11-SDAG-NEXT: global_load_b128 v[32:35], v[0:1], off offset:16 +-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(7) +-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:96 +-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(6) +-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off offset:112 +-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(5) +-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off offset:64 +-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(4) +-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[16:19], off offset:80 +-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) +-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[20:23], off offset:32 +-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) +-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[24:27], off offset:48 +-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) +-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[28:31], off +-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) +-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[32:35], off offset:16 +-; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] +-; +-; GFX11-GISEL-LABEL: freeze_v32i32: +-; GFX11-GISEL: ; %bb.0: +-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +-; GFX11-GISEL-NEXT: s_clause 0x7 +-; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off +-; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 +-; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 +-; GFX11-GISEL-NEXT: global_load_b128 v[16:19], v[0:1], off offset:48 +-; GFX11-GISEL-NEXT: global_load_b128 v[20:23], v[0:1], off offset:64 +-; GFX11-GISEL-NEXT: global_load_b128 v[24:27], v[0:1], off offset:80 +-; GFX11-GISEL-NEXT: global_load_b128 v[28:31], v[0:1], off offset:96 +-; GFX11-GISEL-NEXT: global_load_b128 v[32:35], v[0:1], off offset:112 +-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(7) +-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off +-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(6) +-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 +-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(5) +-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 +-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(4) +-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[16:19], off offset:48 +-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) +-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[20:23], off offset:64 +-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) +-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[24:27], off offset:80 +-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) +-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[28:31], off offset:96 +-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) +-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[32:35], off offset:112 +-; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] +- %a = load <32 x i32>, ptr addrspace(1) %ptra, align 4 +- %freeze = freeze <32 x i32> %a +- store <32 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 +- ret void +-} +- +-define void @freeze_i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { +-; GFX10-LABEL: freeze_i32: +-; GFX10: ; %bb.0: +-; GFX10-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +-; GFX10-NEXT: global_load_dword v0, v[0:1], off +-; GFX10-NEXT: s_waitcnt vmcnt(0) +-; GFX10-NEXT: global_store_dword v[2:3], v0, off +-; GFX10-NEXT: s_setpc_b64 s[30:31] +-; +-; GFX11-LABEL: freeze_i32: +-; GFX11: ; %bb.0: +-; GFX11-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +-; GFX11-NEXT: global_load_b32 v0, v[0:1], off +-; GFX11-NEXT: s_waitcnt vmcnt(0) +-; GFX11-NEXT: global_store_b32 v[2:3], v0, off +-; GFX11-NEXT: s_setpc_b64 s[30:31] +- %a = load i32, ptr addrspace(1) %ptra, align 4 +- %freeze = freeze i32 %a +- store i32 %freeze, ptr addrspace(1) %ptrb, align 4 +- ret void +-} +- +-define void @freeze_i64(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { +-; GFX10-LABEL: freeze_i64: +-; GFX10: ; %bb.0: +-; GFX10-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +-; GFX10-NEXT: global_load_dwordx2 v[0:1], v[0:1], off +-; GFX10-NEXT: s_waitcnt vmcnt(0) +-; GFX10-NEXT: global_store_dwordx2 v[2:3], v[0:1], off +-; GFX10-NEXT: s_setpc_b64 s[30:31] +-; +-; GFX11-LABEL: freeze_i64: +-; GFX11: ; %bb.0: +-; GFX11-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +-; GFX11-NEXT: global_load_b64 v[0:1], v[0:1], off +-; GFX11-NEXT: s_waitcnt vmcnt(0) +-; GFX11-NEXT: global_store_b64 v[2:3], v[0:1], off +-; GFX11-NEXT: s_setpc_b64 s[30:31] +- %a = load i64, ptr addrspace(1) %ptra, align 4 +- %freeze = freeze i64 %a +- store i64 %freeze, ptr addrspace(1) %ptrb, align 4 +- ret void +-} +- +-define void @freeze_float(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { +-; GFX10-LABEL: freeze_float: +-; GFX10: ; %bb.0: +-; GFX10-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +-; GFX10-NEXT: global_load_dword v0, v[0:1], off +-; GFX10-NEXT: s_waitcnt vmcnt(0) +-; GFX10-NEXT: global_store_dword v[2:3], v0, off +-; GFX10-NEXT: s_setpc_b64 s[30:31] +-; +-; GFX11-LABEL: freeze_float: +-; GFX11: ; %bb.0: +-; GFX11-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +-; GFX11-NEXT: global_load_b32 v0, v[0:1], off +-; GFX11-NEXT: s_waitcnt vmcnt(0) +-; GFX11-NEXT: global_store_b32 v[2:3], v0, off +-; GFX11-NEXT: s_setpc_b64 s[30:31] +- %a = load float, ptr addrspace(1) %ptra, align 4 +- %freeze = freeze float %a +- store float %freeze, ptr addrspace(1) %ptrb, align 4 +- ret void +-} +- +-define void @freeze_i128(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { +-; GFX10-LABEL: freeze_i128: +-; GFX10: ; %bb.0: +-; GFX10-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +-; GFX10-NEXT: global_load_dwordx4 v[4:7], v[0:1], off +-; GFX10-NEXT: s_waitcnt vmcnt(0) +-; GFX10-NEXT: global_store_dwordx4 v[2:3], v[4:7], off +-; GFX10-NEXT: s_setpc_b64 s[30:31] +-; +-; GFX11-LABEL: freeze_i128: +-; GFX11: ; %bb.0: +-; GFX11-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +-; GFX11-NEXT: global_load_b128 v[4:7], v[0:1], off +-; GFX11-NEXT: s_waitcnt vmcnt(0) +-; GFX11-NEXT: global_store_b128 v[2:3], v[4:7], off +-; GFX11-NEXT: s_setpc_b64 s[30:31] +- %a = load i128, ptr addrspace(1) %ptra, align 4 +- %freeze = freeze i128 %a +- store i128 %freeze, ptr addrspace(1) %ptrb, align 4 +- ret void +-} +- +-define void @freeze_i256(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { +-; GFX10-SDAG-LABEL: freeze_i256: +-; GFX10-SDAG: ; %bb.0: +-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +-; GFX10-SDAG-NEXT: s_clause 0x1 +-; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:16 +-; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off +-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) +-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:16 +-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) +-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off +-; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] +-; +-; GFX10-GISEL-LABEL: freeze_i256: +-; GFX10-GISEL: ; %bb.0: +-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +-; GFX10-GISEL-NEXT: s_clause 0x1 +-; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off +-; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 +-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) +-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off +-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) +-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 +-; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] +-; +-; GFX11-SDAG-LABEL: freeze_i256: +-; GFX11-SDAG: ; %bb.0: +-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +-; GFX11-SDAG-NEXT: s_clause 0x1 +-; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:16 +-; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off +-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) +-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:16 +-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) +-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off +-; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] +-; +-; GFX11-GISEL-LABEL: freeze_i256: +-; GFX11-GISEL: ; %bb.0: +-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +-; GFX11-GISEL-NEXT: s_clause 0x1 +-; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off +-; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 +-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) +-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off +-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) +-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 +-; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] +- %a = load i256, ptr addrspace(1) %ptra, align 4 +- %freeze = freeze i256 %a +- store i256 %freeze, ptr addrspace(1) %ptrb, align 4 +- ret void +-} +diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/AMDGPU/GlobalISel/inst-select-unmerge-values.mir b/llvm/test/CodeGen/AMDGPU/GlobalISel/inst-select-unmerge-values.mir +--- a/llvm/test/CodeGen/AMDGPU/GlobalISel/inst-select-unmerge-values.mir ++++ b/llvm/test/CodeGen/AMDGPU/GlobalISel/inst-select-unmerge-values.mir +@@ -171,9 +171,11 @@ + ; GCN-LABEL: name: test_unmerge_values_s_s64_s_s64_s64_s_s192 + ; GCN: liveins: $sgpr0_sgpr1_sgpr2_sgpr3 + ; GCN-NEXT: {{ $}} +- ; GCN-NEXT: [[DEF:%[0-9]+]]:sgpr(s192) = G_IMPLICIT_DEF +- ; GCN-NEXT: [[UV:%[0-9]+]]:sgpr(s64), [[UV1:%[0-9]+]]:sgpr(s64), [[UV2:%[0-9]+]]:sgpr(s64) = G_UNMERGE_VALUES [[DEF]](s192) +- ; GCN-NEXT: S_ENDPGM 0, implicit [[UV]](s64), implicit [[UV1]](s64), implicit [[UV2]](s64) ++ ; GCN-NEXT: [[DEF:%[0-9]+]]:sgpr_192 = IMPLICIT_DEF ++ ; GCN-NEXT: [[COPY:%[0-9]+]]:sreg_64 = COPY [[DEF]].sub0_sub1 ++ ; GCN-NEXT: [[COPY1:%[0-9]+]]:sreg_64 = COPY [[DEF]].sub2_sub3 ++ ; GCN-NEXT: [[COPY2:%[0-9]+]]:sreg_64 = COPY [[DEF]].sub4_sub5 ++ ; GCN-NEXT: S_ENDPGM 0, implicit [[COPY]], implicit [[COPY1]], implicit [[COPY2]] + %0:sgpr(s192) = G_IMPLICIT_DEF + %1:sgpr(s64), %2:sgpr(s64), %3:sgpr(s64) = G_UNMERGE_VALUES %0 + S_ENDPGM 0, implicit %1, implicit %2, implicit %3 +@@ -292,11 +294,11 @@ + ; GCN-NEXT: [[CONCAT_VECTORS:%[0-9]+]]:sgpr_384(<12 x s32>) = G_CONCAT_VECTORS [[COPY]](<3 x s32>), [[COPY1]](<3 x s32>), [[COPY2]](<3 x s32>), [[COPY3]](<3 x s32>) + ; GCN-NEXT: [[COPY4:%[0-9]+]]:sgpr_96(<3 x s32>) = COPY [[CONCAT_VECTORS]].sub0_sub1_sub2(<12 x s32>) + ; GCN-NEXT: [[COPY5:%[0-9]+]]:sgpr_96(<3 x s32>) = COPY [[CONCAT_VECTORS]].sub3_sub4_sub5(<12 x s32>) +- ; GCN-NEXT: [[COPY4:%[0-9]+]]:sgpr_96(<3 x s32>), [[COPY5:%[0-9]+]]:sgpr_96(<3 x s32>), [[UV:%[0-9]+]]:sgpr_96(<3 x s32>), [[UV1:%[0-9]+]]:sgpr_96(<3 x s32>) = G_UNMERGE_VALUES [[CONCAT_VECTORS]](<12 x s32>) +- ; GCN-NEXT: $sgpr0_sgpr1_sgpr2 = COPY [[COPY4]](<3 x s32>) +- ; GCN-NEXT: $sgpr4_sgpr5_sgpr6 = COPY [[COPY5]](<3 x s32>) +- ; GCN-NEXT: $sgpr8_sgpr9_sgpr10 = COPY [[UV]](<3 x s32>) +- ; GCN-NEXT: $sgpr12_sgpr13_sgpr14 = COPY [[UV1]](<3 x s32>) ++ ; GCN-NEXT: [[UV:%[0-9]+]]:sgpr_96(<3 x s32>), [[UV1:%[0-9]+]]:sgpr_96(<3 x s32>), [[UV2:%[0-9]+]]:sgpr_96(<3 x s32>), [[UV3:%[0-9]+]]:sgpr_96(<3 x s32>) = G_UNMERGE_VALUES [[CONCAT_VECTORS]](<12 x s32>) ++ ; GCN-NEXT: $sgpr0_sgpr1_sgpr2 = COPY [[UV]](<3 x s32>) ++ ; GCN-NEXT: $sgpr4_sgpr5_sgpr6 = COPY [[UV1]](<3 x s32>) ++ ; GCN-NEXT: $sgpr8_sgpr9_sgpr10 = COPY [[UV2]](<3 x s32>) ++ ; GCN-NEXT: $sgpr12_sgpr13_sgpr14 = COPY [[UV3]](<3 x s32>) + %0:sgpr(<3 x s32>) = COPY $sgpr0_sgpr1_sgpr2 + %1:sgpr(<3 x s32>) = COPY $sgpr4_sgpr5_sgpr6 + %2:sgpr(<3 x s32>) = COPY $sgpr8_sgpr9_sgpr10 +diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-freeze.mir b/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-freeze.mir +--- a/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-freeze.mir ++++ b/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-freeze.mir +@@ -171,8 +171,12 @@ + + ; CHECK-LABEL: name: test_freeze_s448 + ; CHECK: [[COPY:%[0-9]+]]:_(s512) = COPY $vgpr0_vgpr1_vgpr2_vgpr3_vgpr4_vgpr5_vgpr6_vgpr7_vgpr8_vgpr9_vgpr10_vgpr11_vgpr12_vgpr13_vgpr14_vgpr15 +- ; CHECK-NEXT: [[FREEZE:%[0-9]+]]:_(s512) = G_FREEZE [[COPY]] +- ; CHECK-NEXT: $vgpr0_vgpr1_vgpr2_vgpr3_vgpr4_vgpr5_vgpr6_vgpr7_vgpr8_vgpr9_vgpr10_vgpr11_vgpr12_vgpr13_vgpr14_vgpr15 = COPY [[FREEZE]](s512) ++ ; CHECK-NEXT: [[TRUNC:%[0-9]+]]:_(s448) = G_TRUNC [[COPY]](s512) ++ ; CHECK-NEXT: [[FREEZE:%[0-9]+]]:_(s448) = G_FREEZE [[TRUNC]] ++ ; CHECK-NEXT: [[UV:%[0-9]+]]:_(s64), [[UV1:%[0-9]+]]:_(s64), [[UV2:%[0-9]+]]:_(s64), [[UV3:%[0-9]+]]:_(s64), [[UV4:%[0-9]+]]:_(s64), [[UV5:%[0-9]+]]:_(s64), [[UV6:%[0-9]+]]:_(s64) = G_UNMERGE_VALUES [[FREEZE]](s448) ++ ; CHECK-NEXT: [[DEF:%[0-9]+]]:_(s64) = G_IMPLICIT_DEF ++ ; CHECK-NEXT: [[MV:%[0-9]+]]:_(s512) = G_MERGE_VALUES [[UV]](s64), [[UV1]](s64), [[UV2]](s64), [[UV3]](s64), [[UV4]](s64), [[UV5]](s64), [[UV6]](s64), [[DEF]](s64) ++ ; CHECK-NEXT: $vgpr0_vgpr1_vgpr2_vgpr3_vgpr4_vgpr5_vgpr6_vgpr7_vgpr8_vgpr9_vgpr10_vgpr11_vgpr12_vgpr13_vgpr14_vgpr15 = COPY [[MV]](s512) + %0:_(s512) = COPY $vgpr0_vgpr1_vgpr2_vgpr3_vgpr4_vgpr5_vgpr6_vgpr7_vgpr8_vgpr9_vgpr10_vgpr11_vgpr12_vgpr13_vgpr14_vgpr15 + %1:_(s448) = G_TRUNC %0 + %2:_(s448) = G_FREEZE %1 +@@ -395,12 +399,14 @@ + bb.0: + + ; CHECK-LABEL: name: test_freeze_v33s32 +- ; CHECK: [[DEF:%[0-9]+]]:_(<32 x s32>) = G_IMPLICIT_DEF ++ ; CHECK: [[DEF:%[0-9]+]]:_(<16 x s32>) = G_IMPLICIT_DEF + ; CHECK-NEXT: [[DEF1:%[0-9]+]]:_(s32) = G_IMPLICIT_DEF +- ; CHECK-NEXT: [[FREEZE:%[0-9]+]]:_(<32 x s32>) = G_FREEZE [[DEF]] +- ; CHECK-NEXT: [[FREEZE1:%[0-9]+]]:_(s32) = G_FREEZE [[DEF1]] +- ; CHECK-NEXT: [[UV:%[0-9]+]]:_(s32), [[UV1:%[0-9]+]]:_(s32), [[UV2:%[0-9]+]]:_(s32), [[UV3:%[0-9]+]]:_(s32), [[UV4:%[0-9]+]]:_(s32), [[UV5:%[0-9]+]]:_(s32), [[UV6:%[0-9]+]]:_(s32), [[UV7:%[0-9]+]]:_(s32), [[UV8:%[0-9]+]]:_(s32), [[UV9:%[0-9]+]]:_(s32), [[UV10:%[0-9]+]]:_(s32), [[UV11:%[0-9]+]]:_(s32), [[UV12:%[0-9]+]]:_(s32), [[UV13:%[0-9]+]]:_(s32), [[UV14:%[0-9]+]]:_(s32), [[UV15:%[0-9]+]]:_(s32), [[UV16:%[0-9]+]]:_(s32), [[UV17:%[0-9]+]]:_(s32), [[UV18:%[0-9]+]]:_(s32), [[UV19:%[0-9]+]]:_(s32), [[UV20:%[0-9]+]]:_(s32), [[UV21:%[0-9]+]]:_(s32), [[UV22:%[0-9]+]]:_(s32), [[UV23:%[0-9]+]]:_(s32), [[UV24:%[0-9]+]]:_(s32), [[UV25:%[0-9]+]]:_(s32), [[UV26:%[0-9]+]]:_(s32), [[UV27:%[0-9]+]]:_(s32), [[UV28:%[0-9]+]]:_(s32), [[UV29:%[0-9]+]]:_(s32), [[UV30:%[0-9]+]]:_(s32), [[UV31:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[FREEZE]](<32 x s32>) +- ; CHECK-NEXT: [[BUILD_VECTOR:%[0-9]+]]:_(<33 x s32>) = G_BUILD_VECTOR [[UV]](s32), [[UV1]](s32), [[UV2]](s32), [[UV3]](s32), [[UV4]](s32), [[UV5]](s32), [[UV6]](s32), [[UV7]](s32), [[UV8]](s32), [[UV9]](s32), [[UV10]](s32), [[UV11]](s32), [[UV12]](s32), [[UV13]](s32), [[UV14]](s32), [[UV15]](s32), [[UV16]](s32), [[UV17]](s32), [[UV18]](s32), [[UV19]](s32), [[UV20]](s32), [[UV21]](s32), [[UV22]](s32), [[UV23]](s32), [[UV24]](s32), [[UV25]](s32), [[UV26]](s32), [[UV27]](s32), [[UV28]](s32), [[UV29]](s32), [[UV30]](s32), [[UV31]](s32), [[FREEZE1]](s32) ++ ; CHECK-NEXT: [[FREEZE:%[0-9]+]]:_(<16 x s32>) = G_FREEZE [[DEF]] ++ ; CHECK-NEXT: [[FREEZE1:%[0-9]+]]:_(<16 x s32>) = G_FREEZE [[DEF]] ++ ; CHECK-NEXT: [[FREEZE2:%[0-9]+]]:_(s32) = G_FREEZE [[DEF1]] ++ ; CHECK-NEXT: [[UV:%[0-9]+]]:_(s32), [[UV1:%[0-9]+]]:_(s32), [[UV2:%[0-9]+]]:_(s32), [[UV3:%[0-9]+]]:_(s32), [[UV4:%[0-9]+]]:_(s32), [[UV5:%[0-9]+]]:_(s32), [[UV6:%[0-9]+]]:_(s32), [[UV7:%[0-9]+]]:_(s32), [[UV8:%[0-9]+]]:_(s32), [[UV9:%[0-9]+]]:_(s32), [[UV10:%[0-9]+]]:_(s32), [[UV11:%[0-9]+]]:_(s32), [[UV12:%[0-9]+]]:_(s32), [[UV13:%[0-9]+]]:_(s32), [[UV14:%[0-9]+]]:_(s32), [[UV15:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[FREEZE]](<16 x s32>) ++ ; CHECK-NEXT: [[UV16:%[0-9]+]]:_(s32), [[UV17:%[0-9]+]]:_(s32), [[UV18:%[0-9]+]]:_(s32), [[UV19:%[0-9]+]]:_(s32), [[UV20:%[0-9]+]]:_(s32), [[UV21:%[0-9]+]]:_(s32), [[UV22:%[0-9]+]]:_(s32), [[UV23:%[0-9]+]]:_(s32), [[UV24:%[0-9]+]]:_(s32), [[UV25:%[0-9]+]]:_(s32), [[UV26:%[0-9]+]]:_(s32), [[UV27:%[0-9]+]]:_(s32), [[UV28:%[0-9]+]]:_(s32), [[UV29:%[0-9]+]]:_(s32), [[UV30:%[0-9]+]]:_(s32), [[UV31:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[FREEZE1]](<16 x s32>) ++ ; CHECK-NEXT: [[BUILD_VECTOR:%[0-9]+]]:_(<33 x s32>) = G_BUILD_VECTOR [[UV]](s32), [[UV1]](s32), [[UV2]](s32), [[UV3]](s32), [[UV4]](s32), [[UV5]](s32), [[UV6]](s32), [[UV7]](s32), [[UV8]](s32), [[UV9]](s32), [[UV10]](s32), [[UV11]](s32), [[UV12]](s32), [[UV13]](s32), [[UV14]](s32), [[UV15]](s32), [[UV16]](s32), [[UV17]](s32), [[UV18]](s32), [[UV19]](s32), [[UV20]](s32), [[UV21]](s32), [[UV22]](s32), [[UV23]](s32), [[UV24]](s32), [[UV25]](s32), [[UV26]](s32), [[UV27]](s32), [[UV28]](s32), [[UV29]](s32), [[UV30]](s32), [[UV31]](s32), [[FREEZE2]](s32) + ; CHECK-NEXT: S_NOP 0, implicit [[BUILD_VECTOR]](<33 x s32>) + %0:_(<33 x s32>) = G_IMPLICIT_DEF + %1:_(<33 x s32>) = G_FREEZE %0 +@@ -413,10 +419,12 @@ + bb.0: + + ; CHECK-LABEL: name: test_freeze_v64s32 +- ; CHECK: [[DEF:%[0-9]+]]:_(<32 x s32>) = G_IMPLICIT_DEF +- ; CHECK-NEXT: [[FREEZE:%[0-9]+]]:_(<32 x s32>) = G_FREEZE [[DEF]] +- ; CHECK-NEXT: [[FREEZE1:%[0-9]+]]:_(<32 x s32>) = G_FREEZE [[DEF]] +- ; CHECK-NEXT: [[CONCAT_VECTORS:%[0-9]+]]:_(<64 x s32>) = G_CONCAT_VECTORS [[FREEZE]](<32 x s32>), [[FREEZE1]](<32 x s32>) ++ ; CHECK: [[DEF:%[0-9]+]]:_(<16 x s32>) = G_IMPLICIT_DEF ++ ; CHECK-NEXT: [[FREEZE:%[0-9]+]]:_(<16 x s32>) = G_FREEZE [[DEF]] ++ ; CHECK-NEXT: [[FREEZE1:%[0-9]+]]:_(<16 x s32>) = G_FREEZE [[DEF]] ++ ; CHECK-NEXT: [[FREEZE2:%[0-9]+]]:_(<16 x s32>) = G_FREEZE [[DEF]] ++ ; CHECK-NEXT: [[FREEZE3:%[0-9]+]]:_(<16 x s32>) = G_FREEZE [[DEF]] ++ ; CHECK-NEXT: [[CONCAT_VECTORS:%[0-9]+]]:_(<64 x s32>) = G_CONCAT_VECTORS [[FREEZE]](<16 x s32>), [[FREEZE1]](<16 x s32>), [[FREEZE2]](<16 x s32>), [[FREEZE3]](<16 x s32>) + ; CHECK-NEXT: S_NOP 0, implicit [[CONCAT_VECTORS]](<64 x s32>) + %0:_(<64 x s32>) = G_IMPLICIT_DEF + %1:_(<64 x s32>) = G_FREEZE %0 +diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-implicit-def.mir b/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-implicit-def.mir +--- a/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-implicit-def.mir ++++ b/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-implicit-def.mir +@@ -135,9 +135,8 @@ + bb.0: + + ; CHECK-LABEL: name: test_implicit_def_s448 +- ; CHECK: [[DEF:%[0-9]+]]:_(s512) = G_IMPLICIT_DEF +- ; CHECK-NEXT: [[TRUNC:%[0-9]+]]:_(s448) = G_TRUNC [[DEF]](s512) +- ; CHECK-NEXT: [[EXTRACT:%[0-9]+]]:_(s32) = G_EXTRACT [[TRUNC]](s448), 0 ++ ; CHECK: [[DEF:%[0-9]+]]:_(s448) = G_IMPLICIT_DEF ++ ; CHECK-NEXT: [[EXTRACT:%[0-9]+]]:_(s32) = G_EXTRACT [[DEF]](s448), 0 + ; CHECK-NEXT: $vgpr0 = COPY [[EXTRACT]](s32) + %0:_(s448) = G_IMPLICIT_DEF + %1:_(s32) = G_EXTRACT %0, 0 +@@ -297,6 +296,18 @@ + ... + + --- ++name: test_implicit_def_v17s32 ++body: | ++ bb.0: ++ ++ ; CHECK-LABEL: name: test_implicit_def_v17s32 ++ ; CHECK: [[DEF:%[0-9]+]]:_(<17 x s32>) = G_IMPLICIT_DEF ++ ; CHECK-NEXT: S_NOP 0, implicit [[DEF]](<17 x s32>) ++ %0:_(<17 x s32>) = G_IMPLICIT_DEF ++ S_NOP 0, implicit %0 ++... ++ ++--- + name: test_implicit_def_v32s32 + body: | + bb.0: +@@ -317,9 +328,9 @@ + ; CHECK-LABEL: name: test_implicit_def_v33s32 + ; CHECK: liveins: $vgpr0_vgpr1 + ; CHECK-NEXT: {{ $}} +- ; CHECK-NEXT: [[DEF:%[0-9]+]]:_(<32 x s32>) = G_IMPLICIT_DEF ++ ; CHECK-NEXT: [[DEF:%[0-9]+]]:_(<16 x s32>) = G_IMPLICIT_DEF + ; CHECK-NEXT: [[DEF1:%[0-9]+]]:_(s32) = G_IMPLICIT_DEF +- ; CHECK-NEXT: [[UV:%[0-9]+]]:_(s32), [[UV1:%[0-9]+]]:_(s32), [[UV2:%[0-9]+]]:_(s32), [[UV3:%[0-9]+]]:_(s32), [[UV4:%[0-9]+]]:_(s32), [[UV5:%[0-9]+]]:_(s32), [[UV6:%[0-9]+]]:_(s32), [[UV7:%[0-9]+]]:_(s32), [[UV8:%[0-9]+]]:_(s32), [[UV9:%[0-9]+]]:_(s32), [[UV10:%[0-9]+]]:_(s32), [[UV11:%[0-9]+]]:_(s32), [[UV12:%[0-9]+]]:_(s32), [[UV13:%[0-9]+]]:_(s32), [[UV14:%[0-9]+]]:_(s32), [[UV15:%[0-9]+]]:_(s32), [[UV16:%[0-9]+]]:_(s32), [[UV17:%[0-9]+]]:_(s32), [[UV18:%[0-9]+]]:_(s32), [[UV19:%[0-9]+]]:_(s32), [[UV20:%[0-9]+]]:_(s32), [[UV21:%[0-9]+]]:_(s32), [[UV22:%[0-9]+]]:_(s32), [[UV23:%[0-9]+]]:_(s32), [[UV24:%[0-9]+]]:_(s32), [[UV25:%[0-9]+]]:_(s32), [[UV26:%[0-9]+]]:_(s32), [[UV27:%[0-9]+]]:_(s32), [[UV28:%[0-9]+]]:_(s32), [[UV29:%[0-9]+]]:_(s32), [[UV30:%[0-9]+]]:_(s32), [[UV31:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<32 x s32>) ++ ; CHECK-NEXT: [[UV:%[0-9]+]]:_(s32), [[UV1:%[0-9]+]]:_(s32), [[UV2:%[0-9]+]]:_(s32), [[UV3:%[0-9]+]]:_(s32), [[UV4:%[0-9]+]]:_(s32), [[UV5:%[0-9]+]]:_(s32), [[UV6:%[0-9]+]]:_(s32), [[UV7:%[0-9]+]]:_(s32), [[UV8:%[0-9]+]]:_(s32), [[UV9:%[0-9]+]]:_(s32), [[UV10:%[0-9]+]]:_(s32), [[UV11:%[0-9]+]]:_(s32), [[UV12:%[0-9]+]]:_(s32), [[UV13:%[0-9]+]]:_(s32), [[UV14:%[0-9]+]]:_(s32), [[UV15:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) + ; CHECK-NEXT: [[COPY:%[0-9]+]]:_(p1) = COPY $vgpr0_vgpr1 + ; CHECK-NEXT: G_STORE [[UV]](s32), [[COPY]](p1) :: (volatile store (s32), addrspace 1) + ; CHECK-NEXT: G_STORE [[DEF1]](s32), [[COPY]](p1) :: (volatile store (s32), addrspace 1) +@@ -337,9 +348,10 @@ + bb.0: + + ; CHECK-LABEL: name: test_implicit_def_v64s32 +- ; CHECK: [[DEF:%[0-9]+]]:_(<32 x s32>) = G_IMPLICIT_DEF +- ; CHECK-NEXT: [[CONCAT_VECTORS:%[0-9]+]]:_(<64 x s32>) = G_CONCAT_VECTORS [[DEF]](<32 x s32>), [[DEF]](<32 x s32>) +- ; CHECK-NEXT: S_NOP 0, implicit [[CONCAT_VECTORS]](<64 x s32>), implicit [[DEF]](<32 x s32>) ++ ; CHECK: [[DEF:%[0-9]+]]:_(<16 x s32>) = G_IMPLICIT_DEF ++ ; CHECK-NEXT: [[CONCAT_VECTORS:%[0-9]+]]:_(<64 x s32>) = G_CONCAT_VECTORS [[DEF]](<16 x s32>), [[DEF]](<16 x s32>), [[DEF]](<16 x s32>), [[DEF]](<16 x s32>) ++ ; CHECK-NEXT: [[CONCAT_VECTORS1:%[0-9]+]]:_(<32 x s32>) = G_CONCAT_VECTORS [[DEF]](<16 x s32>), [[DEF]](<16 x s32>) ++ ; CHECK-NEXT: S_NOP 0, implicit [[CONCAT_VECTORS]](<64 x s32>), implicit [[CONCAT_VECTORS1]](<32 x s32>) + %0:_(<64 x s32>) = G_IMPLICIT_DEF + %1:_(<32 x s32>), %2:_(<32 x s32>) = G_UNMERGE_VALUES %0 + S_NOP 0, implicit %0, implicit %1 +diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-insert-vector-elt.mir b/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-insert-vector-elt.mir +--- a/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-insert-vector-elt.mir ++++ b/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-insert-vector-elt.mir +@@ -190,11 +190,13 @@ + ; CHECK-LABEL: name: insert_vector_elt_64_65_v64s32 + ; CHECK: liveins: $sgpr0_sgpr1, $vgpr0_vgpr1, $vgpr2_vgpr3 + ; CHECK-NEXT: {{ $}} +- ; CHECK-NEXT: [[DEF:%[0-9]+]]:_(<32 x s32>) = G_IMPLICIT_DEF ++ ; CHECK-NEXT: [[DEF:%[0-9]+]]:_(<16 x s32>) = G_IMPLICIT_DEF + ; CHECK-NEXT: [[COPY:%[0-9]+]]:_(p1) = COPY $vgpr0_vgpr1 + ; CHECK-NEXT: [[COPY1:%[0-9]+]]:_(p1) = COPY $vgpr2_vgpr3 +- ; CHECK-NEXT: [[UV:%[0-9]+]]:_(<4 x s32>), [[UV1:%[0-9]+]]:_(<4 x s32>), [[UV2:%[0-9]+]]:_(<4 x s32>), [[UV3:%[0-9]+]]:_(<4 x s32>), [[UV4:%[0-9]+]]:_(<4 x s32>), [[UV5:%[0-9]+]]:_(<4 x s32>), [[UV6:%[0-9]+]]:_(<4 x s32>), [[UV7:%[0-9]+]]:_(<4 x s32>) = G_UNMERGE_VALUES [[DEF]](<32 x s32>) +- ; CHECK-NEXT: [[UV8:%[0-9]+]]:_(<4 x s32>), [[UV9:%[0-9]+]]:_(<4 x s32>), [[UV10:%[0-9]+]]:_(<4 x s32>), [[UV11:%[0-9]+]]:_(<4 x s32>), [[UV12:%[0-9]+]]:_(<4 x s32>), [[UV13:%[0-9]+]]:_(<4 x s32>), [[UV14:%[0-9]+]]:_(<4 x s32>), [[UV15:%[0-9]+]]:_(<4 x s32>) = G_UNMERGE_VALUES [[DEF]](<32 x s32>) ++ ; CHECK-NEXT: [[UV:%[0-9]+]]:_(<4 x s32>), [[UV1:%[0-9]+]]:_(<4 x s32>), [[UV2:%[0-9]+]]:_(<4 x s32>), [[UV3:%[0-9]+]]:_(<4 x s32>) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) ++ ; CHECK-NEXT: [[UV4:%[0-9]+]]:_(<4 x s32>), [[UV5:%[0-9]+]]:_(<4 x s32>), [[UV6:%[0-9]+]]:_(<4 x s32>), [[UV7:%[0-9]+]]:_(<4 x s32>) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) ++ ; CHECK-NEXT: [[UV8:%[0-9]+]]:_(<4 x s32>), [[UV9:%[0-9]+]]:_(<4 x s32>), [[UV10:%[0-9]+]]:_(<4 x s32>), [[UV11:%[0-9]+]]:_(<4 x s32>) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) ++ ; CHECK-NEXT: [[UV12:%[0-9]+]]:_(<4 x s32>), [[UV13:%[0-9]+]]:_(<4 x s32>), [[UV14:%[0-9]+]]:_(<4 x s32>), [[UV15:%[0-9]+]]:_(<4 x s32>) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) + ; CHECK-NEXT: G_STORE [[UV]](<4 x s32>), [[COPY]](p1) :: (store (<4 x s32>), align 4, addrspace 1) + ; CHECK-NEXT: [[C:%[0-9]+]]:_(s64) = G_CONSTANT i64 16 + ; CHECK-NEXT: [[PTR_ADD:%[0-9]+]]:_(p1) = G_PTR_ADD [[COPY]], [[C]](s64) +@@ -241,8 +243,10 @@ + ; CHECK-NEXT: [[C14:%[0-9]+]]:_(s64) = G_CONSTANT i64 240 + ; CHECK-NEXT: [[PTR_ADD14:%[0-9]+]]:_(p1) = G_PTR_ADD [[COPY]], [[C14]](s64) + ; CHECK-NEXT: G_STORE [[UV15]](<4 x s32>), [[PTR_ADD14]](p1) :: (store (<4 x s32>) into unknown-address + 240, align 4, addrspace 1) +- ; CHECK-NEXT: [[UV16:%[0-9]+]]:_(<4 x s32>), [[UV17:%[0-9]+]]:_(<4 x s32>), [[UV18:%[0-9]+]]:_(<4 x s32>), [[UV19:%[0-9]+]]:_(<4 x s32>), [[UV20:%[0-9]+]]:_(<4 x s32>), [[UV21:%[0-9]+]]:_(<4 x s32>), [[UV22:%[0-9]+]]:_(<4 x s32>), [[UV23:%[0-9]+]]:_(<4 x s32>) = G_UNMERGE_VALUES [[DEF]](<32 x s32>) +- ; CHECK-NEXT: [[UV24:%[0-9]+]]:_(<4 x s32>), [[UV25:%[0-9]+]]:_(<4 x s32>), [[UV26:%[0-9]+]]:_(<4 x s32>), [[UV27:%[0-9]+]]:_(<4 x s32>), [[UV28:%[0-9]+]]:_(<4 x s32>), [[UV29:%[0-9]+]]:_(<4 x s32>), [[UV30:%[0-9]+]]:_(<4 x s32>), [[UV31:%[0-9]+]]:_(<4 x s32>) = G_UNMERGE_VALUES [[DEF]](<32 x s32>) ++ ; CHECK-NEXT: [[UV16:%[0-9]+]]:_(<4 x s32>), [[UV17:%[0-9]+]]:_(<4 x s32>), [[UV18:%[0-9]+]]:_(<4 x s32>), [[UV19:%[0-9]+]]:_(<4 x s32>) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) ++ ; CHECK-NEXT: [[UV20:%[0-9]+]]:_(<4 x s32>), [[UV21:%[0-9]+]]:_(<4 x s32>), [[UV22:%[0-9]+]]:_(<4 x s32>), [[UV23:%[0-9]+]]:_(<4 x s32>) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) ++ ; CHECK-NEXT: [[UV24:%[0-9]+]]:_(<4 x s32>), [[UV25:%[0-9]+]]:_(<4 x s32>), [[UV26:%[0-9]+]]:_(<4 x s32>), [[UV27:%[0-9]+]]:_(<4 x s32>) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) ++ ; CHECK-NEXT: [[UV28:%[0-9]+]]:_(<4 x s32>), [[UV29:%[0-9]+]]:_(<4 x s32>), [[UV30:%[0-9]+]]:_(<4 x s32>), [[UV31:%[0-9]+]]:_(<4 x s32>) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) + ; CHECK-NEXT: G_STORE [[UV16]](<4 x s32>), [[COPY1]](p1) :: (store (<4 x s32>), align 4, addrspace 1) + ; CHECK-NEXT: [[PTR_ADD15:%[0-9]+]]:_(p1) = G_PTR_ADD [[COPY1]], [[C]](s64) + ; CHECK-NEXT: G_STORE [[UV17]](<4 x s32>), [[PTR_ADD15]](p1) :: (store (<4 x s32>) into unknown-address + 16, align 4, addrspace 1) +diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-phi.mir b/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-phi.mir +--- a/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-phi.mir ++++ b/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-phi.mir +@@ -673,86 +673,88 @@ + ; CHECK-NEXT: successors: %bb.1(0x40000000), %bb.2(0x40000000) + ; CHECK-NEXT: liveins: $vgpr0_vgpr1_vgpr2_vgpr3, $vgpr4 + ; CHECK-NEXT: {{ $}} +- ; CHECK-NEXT: [[DEF:%[0-9]+]]:_(<32 x s32>) = G_IMPLICIT_DEF ++ ; CHECK-NEXT: [[DEF:%[0-9]+]]:_(<16 x s32>) = G_IMPLICIT_DEF + ; CHECK-NEXT: [[COPY:%[0-9]+]]:_(s32) = COPY $vgpr4 + ; CHECK-NEXT: [[C:%[0-9]+]]:_(s32) = G_CONSTANT i32 0 + ; CHECK-NEXT: [[ICMP:%[0-9]+]]:_(s1) = G_ICMP intpred(eq), [[COPY]](s32), [[C]] +- ; CHECK-NEXT: [[UV:%[0-9]+]]:_(<16 x s32>), [[UV1:%[0-9]+]]:_(<16 x s32>) = G_UNMERGE_VALUES [[DEF]](<32 x s32>) +- ; CHECK-NEXT: [[UV2:%[0-9]+]]:_(<16 x s32>), [[UV3:%[0-9]+]]:_(<16 x s32>) = G_UNMERGE_VALUES [[DEF]](<32 x s32>) + ; CHECK-NEXT: G_BRCOND [[ICMP]](s1), %bb.1 + ; CHECK-NEXT: G_BR %bb.2 + ; CHECK-NEXT: {{ $}} + ; CHECK-NEXT: bb.1: + ; CHECK-NEXT: successors: %bb.2(0x80000000) + ; CHECK-NEXT: {{ $}} +- ; CHECK-NEXT: [[UV4:%[0-9]+]]:_(s32), [[UV5:%[0-9]+]]:_(s32), [[UV6:%[0-9]+]]:_(s32), [[UV7:%[0-9]+]]:_(s32), [[UV8:%[0-9]+]]:_(s32), [[UV9:%[0-9]+]]:_(s32), [[UV10:%[0-9]+]]:_(s32), [[UV11:%[0-9]+]]:_(s32), [[UV12:%[0-9]+]]:_(s32), [[UV13:%[0-9]+]]:_(s32), [[UV14:%[0-9]+]]:_(s32), [[UV15:%[0-9]+]]:_(s32), [[UV16:%[0-9]+]]:_(s32), [[UV17:%[0-9]+]]:_(s32), [[UV18:%[0-9]+]]:_(s32), [[UV19:%[0-9]+]]:_(s32), [[UV20:%[0-9]+]]:_(s32), [[UV21:%[0-9]+]]:_(s32), [[UV22:%[0-9]+]]:_(s32), [[UV23:%[0-9]+]]:_(s32), [[UV24:%[0-9]+]]:_(s32), [[UV25:%[0-9]+]]:_(s32), [[UV26:%[0-9]+]]:_(s32), [[UV27:%[0-9]+]]:_(s32), [[UV28:%[0-9]+]]:_(s32), [[UV29:%[0-9]+]]:_(s32), [[UV30:%[0-9]+]]:_(s32), [[UV31:%[0-9]+]]:_(s32), [[UV32:%[0-9]+]]:_(s32), [[UV33:%[0-9]+]]:_(s32), [[UV34:%[0-9]+]]:_(s32), [[UV35:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<32 x s32>) +- ; CHECK-NEXT: [[UV36:%[0-9]+]]:_(s32), [[UV37:%[0-9]+]]:_(s32), [[UV38:%[0-9]+]]:_(s32), [[UV39:%[0-9]+]]:_(s32), [[UV40:%[0-9]+]]:_(s32), [[UV41:%[0-9]+]]:_(s32), [[UV42:%[0-9]+]]:_(s32), [[UV43:%[0-9]+]]:_(s32), [[UV44:%[0-9]+]]:_(s32), [[UV45:%[0-9]+]]:_(s32), [[UV46:%[0-9]+]]:_(s32), [[UV47:%[0-9]+]]:_(s32), [[UV48:%[0-9]+]]:_(s32), [[UV49:%[0-9]+]]:_(s32), [[UV50:%[0-9]+]]:_(s32), [[UV51:%[0-9]+]]:_(s32), [[UV52:%[0-9]+]]:_(s32), [[UV53:%[0-9]+]]:_(s32), [[UV54:%[0-9]+]]:_(s32), [[UV55:%[0-9]+]]:_(s32), [[UV56:%[0-9]+]]:_(s32), [[UV57:%[0-9]+]]:_(s32), [[UV58:%[0-9]+]]:_(s32), [[UV59:%[0-9]+]]:_(s32), [[UV60:%[0-9]+]]:_(s32), [[UV61:%[0-9]+]]:_(s32), [[UV62:%[0-9]+]]:_(s32), [[UV63:%[0-9]+]]:_(s32), [[UV64:%[0-9]+]]:_(s32), [[UV65:%[0-9]+]]:_(s32), [[UV66:%[0-9]+]]:_(s32), [[UV67:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<32 x s32>) +- ; CHECK-NEXT: [[UV68:%[0-9]+]]:_(s32), [[UV69:%[0-9]+]]:_(s32), [[UV70:%[0-9]+]]:_(s32), [[UV71:%[0-9]+]]:_(s32), [[UV72:%[0-9]+]]:_(s32), [[UV73:%[0-9]+]]:_(s32), [[UV74:%[0-9]+]]:_(s32), [[UV75:%[0-9]+]]:_(s32), [[UV76:%[0-9]+]]:_(s32), [[UV77:%[0-9]+]]:_(s32), [[UV78:%[0-9]+]]:_(s32), [[UV79:%[0-9]+]]:_(s32), [[UV80:%[0-9]+]]:_(s32), [[UV81:%[0-9]+]]:_(s32), [[UV82:%[0-9]+]]:_(s32), [[UV83:%[0-9]+]]:_(s32), [[UV84:%[0-9]+]]:_(s32), [[UV85:%[0-9]+]]:_(s32), [[UV86:%[0-9]+]]:_(s32), [[UV87:%[0-9]+]]:_(s32), [[UV88:%[0-9]+]]:_(s32), [[UV89:%[0-9]+]]:_(s32), [[UV90:%[0-9]+]]:_(s32), [[UV91:%[0-9]+]]:_(s32), [[UV92:%[0-9]+]]:_(s32), [[UV93:%[0-9]+]]:_(s32), [[UV94:%[0-9]+]]:_(s32), [[UV95:%[0-9]+]]:_(s32), [[UV96:%[0-9]+]]:_(s32), [[UV97:%[0-9]+]]:_(s32), [[UV98:%[0-9]+]]:_(s32), [[UV99:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<32 x s32>) +- ; CHECK-NEXT: [[UV100:%[0-9]+]]:_(s32), [[UV101:%[0-9]+]]:_(s32), [[UV102:%[0-9]+]]:_(s32), [[UV103:%[0-9]+]]:_(s32), [[UV104:%[0-9]+]]:_(s32), [[UV105:%[0-9]+]]:_(s32), [[UV106:%[0-9]+]]:_(s32), [[UV107:%[0-9]+]]:_(s32), [[UV108:%[0-9]+]]:_(s32), [[UV109:%[0-9]+]]:_(s32), [[UV110:%[0-9]+]]:_(s32), [[UV111:%[0-9]+]]:_(s32), [[UV112:%[0-9]+]]:_(s32), [[UV113:%[0-9]+]]:_(s32), [[UV114:%[0-9]+]]:_(s32), [[UV115:%[0-9]+]]:_(s32), [[UV116:%[0-9]+]]:_(s32), [[UV117:%[0-9]+]]:_(s32), [[UV118:%[0-9]+]]:_(s32), [[UV119:%[0-9]+]]:_(s32), [[UV120:%[0-9]+]]:_(s32), [[UV121:%[0-9]+]]:_(s32), [[UV122:%[0-9]+]]:_(s32), [[UV123:%[0-9]+]]:_(s32), [[UV124:%[0-9]+]]:_(s32), [[UV125:%[0-9]+]]:_(s32), [[UV126:%[0-9]+]]:_(s32), [[UV127:%[0-9]+]]:_(s32), [[UV128:%[0-9]+]]:_(s32), [[UV129:%[0-9]+]]:_(s32), [[UV130:%[0-9]+]]:_(s32), [[UV131:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<32 x s32>) +- ; CHECK-NEXT: [[ADD:%[0-9]+]]:_(s32) = G_ADD [[UV4]], [[UV68]] +- ; CHECK-NEXT: [[ADD1:%[0-9]+]]:_(s32) = G_ADD [[UV5]], [[UV69]] +- ; CHECK-NEXT: [[ADD2:%[0-9]+]]:_(s32) = G_ADD [[UV6]], [[UV70]] +- ; CHECK-NEXT: [[ADD3:%[0-9]+]]:_(s32) = G_ADD [[UV7]], [[UV71]] +- ; CHECK-NEXT: [[ADD4:%[0-9]+]]:_(s32) = G_ADD [[UV8]], [[UV72]] +- ; CHECK-NEXT: [[ADD5:%[0-9]+]]:_(s32) = G_ADD [[UV9]], [[UV73]] +- ; CHECK-NEXT: [[ADD6:%[0-9]+]]:_(s32) = G_ADD [[UV10]], [[UV74]] +- ; CHECK-NEXT: [[ADD7:%[0-9]+]]:_(s32) = G_ADD [[UV11]], [[UV75]] +- ; CHECK-NEXT: [[ADD8:%[0-9]+]]:_(s32) = G_ADD [[UV12]], [[UV76]] +- ; CHECK-NEXT: [[ADD9:%[0-9]+]]:_(s32) = G_ADD [[UV13]], [[UV77]] +- ; CHECK-NEXT: [[ADD10:%[0-9]+]]:_(s32) = G_ADD [[UV14]], [[UV78]] +- ; CHECK-NEXT: [[ADD11:%[0-9]+]]:_(s32) = G_ADD [[UV15]], [[UV79]] +- ; CHECK-NEXT: [[ADD12:%[0-9]+]]:_(s32) = G_ADD [[UV16]], [[UV80]] +- ; CHECK-NEXT: [[ADD13:%[0-9]+]]:_(s32) = G_ADD [[UV17]], [[UV81]] +- ; CHECK-NEXT: [[ADD14:%[0-9]+]]:_(s32) = G_ADD [[UV18]], [[UV82]] +- ; CHECK-NEXT: [[ADD15:%[0-9]+]]:_(s32) = G_ADD [[UV19]], [[UV83]] +- ; CHECK-NEXT: [[ADD16:%[0-9]+]]:_(s32) = G_ADD [[UV20]], [[UV84]] +- ; CHECK-NEXT: [[ADD17:%[0-9]+]]:_(s32) = G_ADD [[UV21]], [[UV85]] +- ; CHECK-NEXT: [[ADD18:%[0-9]+]]:_(s32) = G_ADD [[UV22]], [[UV86]] +- ; CHECK-NEXT: [[ADD19:%[0-9]+]]:_(s32) = G_ADD [[UV23]], [[UV87]] +- ; CHECK-NEXT: [[ADD20:%[0-9]+]]:_(s32) = G_ADD [[UV24]], [[UV88]] +- ; CHECK-NEXT: [[ADD21:%[0-9]+]]:_(s32) = G_ADD [[UV25]], [[UV89]] +- ; CHECK-NEXT: [[ADD22:%[0-9]+]]:_(s32) = G_ADD [[UV26]], [[UV90]] +- ; CHECK-NEXT: [[ADD23:%[0-9]+]]:_(s32) = G_ADD [[UV27]], [[UV91]] +- ; CHECK-NEXT: [[ADD24:%[0-9]+]]:_(s32) = G_ADD [[UV28]], [[UV92]] +- ; CHECK-NEXT: [[ADD25:%[0-9]+]]:_(s32) = G_ADD [[UV29]], [[UV93]] +- ; CHECK-NEXT: [[ADD26:%[0-9]+]]:_(s32) = G_ADD [[UV30]], [[UV94]] +- ; CHECK-NEXT: [[ADD27:%[0-9]+]]:_(s32) = G_ADD [[UV31]], [[UV95]] +- ; CHECK-NEXT: [[ADD28:%[0-9]+]]:_(s32) = G_ADD [[UV32]], [[UV96]] +- ; CHECK-NEXT: [[ADD29:%[0-9]+]]:_(s32) = G_ADD [[UV33]], [[UV97]] +- ; CHECK-NEXT: [[ADD30:%[0-9]+]]:_(s32) = G_ADD [[UV34]], [[UV98]] +- ; CHECK-NEXT: [[ADD31:%[0-9]+]]:_(s32) = G_ADD [[UV35]], [[UV99]] +- ; CHECK-NEXT: [[ADD32:%[0-9]+]]:_(s32) = G_ADD [[UV36]], [[UV100]] +- ; CHECK-NEXT: [[ADD33:%[0-9]+]]:_(s32) = G_ADD [[UV37]], [[UV101]] +- ; CHECK-NEXT: [[ADD34:%[0-9]+]]:_(s32) = G_ADD [[UV38]], [[UV102]] +- ; CHECK-NEXT: [[ADD35:%[0-9]+]]:_(s32) = G_ADD [[UV39]], [[UV103]] +- ; CHECK-NEXT: [[ADD36:%[0-9]+]]:_(s32) = G_ADD [[UV40]], [[UV104]] +- ; CHECK-NEXT: [[ADD37:%[0-9]+]]:_(s32) = G_ADD [[UV41]], [[UV105]] +- ; CHECK-NEXT: [[ADD38:%[0-9]+]]:_(s32) = G_ADD [[UV42]], [[UV106]] +- ; CHECK-NEXT: [[ADD39:%[0-9]+]]:_(s32) = G_ADD [[UV43]], [[UV107]] +- ; CHECK-NEXT: [[ADD40:%[0-9]+]]:_(s32) = G_ADD [[UV44]], [[UV108]] +- ; CHECK-NEXT: [[ADD41:%[0-9]+]]:_(s32) = G_ADD [[UV45]], [[UV109]] +- ; CHECK-NEXT: [[ADD42:%[0-9]+]]:_(s32) = G_ADD [[UV46]], [[UV110]] +- ; CHECK-NEXT: [[ADD43:%[0-9]+]]:_(s32) = G_ADD [[UV47]], [[UV111]] +- ; CHECK-NEXT: [[ADD44:%[0-9]+]]:_(s32) = G_ADD [[UV48]], [[UV112]] +- ; CHECK-NEXT: [[ADD45:%[0-9]+]]:_(s32) = G_ADD [[UV49]], [[UV113]] +- ; CHECK-NEXT: [[ADD46:%[0-9]+]]:_(s32) = G_ADD [[UV50]], [[UV114]] +- ; CHECK-NEXT: [[ADD47:%[0-9]+]]:_(s32) = G_ADD [[UV51]], [[UV115]] +- ; CHECK-NEXT: [[ADD48:%[0-9]+]]:_(s32) = G_ADD [[UV52]], [[UV116]] +- ; CHECK-NEXT: [[ADD49:%[0-9]+]]:_(s32) = G_ADD [[UV53]], [[UV117]] +- ; CHECK-NEXT: [[ADD50:%[0-9]+]]:_(s32) = G_ADD [[UV54]], [[UV118]] +- ; CHECK-NEXT: [[ADD51:%[0-9]+]]:_(s32) = G_ADD [[UV55]], [[UV119]] +- ; CHECK-NEXT: [[ADD52:%[0-9]+]]:_(s32) = G_ADD [[UV56]], [[UV120]] +- ; CHECK-NEXT: [[ADD53:%[0-9]+]]:_(s32) = G_ADD [[UV57]], [[UV121]] +- ; CHECK-NEXT: [[ADD54:%[0-9]+]]:_(s32) = G_ADD [[UV58]], [[UV122]] +- ; CHECK-NEXT: [[ADD55:%[0-9]+]]:_(s32) = G_ADD [[UV59]], [[UV123]] +- ; CHECK-NEXT: [[ADD56:%[0-9]+]]:_(s32) = G_ADD [[UV60]], [[UV124]] +- ; CHECK-NEXT: [[ADD57:%[0-9]+]]:_(s32) = G_ADD [[UV61]], [[UV125]] +- ; CHECK-NEXT: [[ADD58:%[0-9]+]]:_(s32) = G_ADD [[UV62]], [[UV126]] +- ; CHECK-NEXT: [[ADD59:%[0-9]+]]:_(s32) = G_ADD [[UV63]], [[UV127]] +- ; CHECK-NEXT: [[ADD60:%[0-9]+]]:_(s32) = G_ADD [[UV64]], [[UV128]] +- ; CHECK-NEXT: [[ADD61:%[0-9]+]]:_(s32) = G_ADD [[UV65]], [[UV129]] +- ; CHECK-NEXT: [[ADD62:%[0-9]+]]:_(s32) = G_ADD [[UV66]], [[UV130]] +- ; CHECK-NEXT: [[ADD63:%[0-9]+]]:_(s32) = G_ADD [[UV67]], [[UV131]] ++ ; CHECK-NEXT: [[UV:%[0-9]+]]:_(s32), [[UV1:%[0-9]+]]:_(s32), [[UV2:%[0-9]+]]:_(s32), [[UV3:%[0-9]+]]:_(s32), [[UV4:%[0-9]+]]:_(s32), [[UV5:%[0-9]+]]:_(s32), [[UV6:%[0-9]+]]:_(s32), [[UV7:%[0-9]+]]:_(s32), [[UV8:%[0-9]+]]:_(s32), [[UV9:%[0-9]+]]:_(s32), [[UV10:%[0-9]+]]:_(s32), [[UV11:%[0-9]+]]:_(s32), [[UV12:%[0-9]+]]:_(s32), [[UV13:%[0-9]+]]:_(s32), [[UV14:%[0-9]+]]:_(s32), [[UV15:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) ++ ; CHECK-NEXT: [[UV16:%[0-9]+]]:_(s32), [[UV17:%[0-9]+]]:_(s32), [[UV18:%[0-9]+]]:_(s32), [[UV19:%[0-9]+]]:_(s32), [[UV20:%[0-9]+]]:_(s32), [[UV21:%[0-9]+]]:_(s32), [[UV22:%[0-9]+]]:_(s32), [[UV23:%[0-9]+]]:_(s32), [[UV24:%[0-9]+]]:_(s32), [[UV25:%[0-9]+]]:_(s32), [[UV26:%[0-9]+]]:_(s32), [[UV27:%[0-9]+]]:_(s32), [[UV28:%[0-9]+]]:_(s32), [[UV29:%[0-9]+]]:_(s32), [[UV30:%[0-9]+]]:_(s32), [[UV31:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) ++ ; CHECK-NEXT: [[UV32:%[0-9]+]]:_(s32), [[UV33:%[0-9]+]]:_(s32), [[UV34:%[0-9]+]]:_(s32), [[UV35:%[0-9]+]]:_(s32), [[UV36:%[0-9]+]]:_(s32), [[UV37:%[0-9]+]]:_(s32), [[UV38:%[0-9]+]]:_(s32), [[UV39:%[0-9]+]]:_(s32), [[UV40:%[0-9]+]]:_(s32), [[UV41:%[0-9]+]]:_(s32), [[UV42:%[0-9]+]]:_(s32), [[UV43:%[0-9]+]]:_(s32), [[UV44:%[0-9]+]]:_(s32), [[UV45:%[0-9]+]]:_(s32), [[UV46:%[0-9]+]]:_(s32), [[UV47:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) ++ ; CHECK-NEXT: [[UV48:%[0-9]+]]:_(s32), [[UV49:%[0-9]+]]:_(s32), [[UV50:%[0-9]+]]:_(s32), [[UV51:%[0-9]+]]:_(s32), [[UV52:%[0-9]+]]:_(s32), [[UV53:%[0-9]+]]:_(s32), [[UV54:%[0-9]+]]:_(s32), [[UV55:%[0-9]+]]:_(s32), [[UV56:%[0-9]+]]:_(s32), [[UV57:%[0-9]+]]:_(s32), [[UV58:%[0-9]+]]:_(s32), [[UV59:%[0-9]+]]:_(s32), [[UV60:%[0-9]+]]:_(s32), [[UV61:%[0-9]+]]:_(s32), [[UV62:%[0-9]+]]:_(s32), [[UV63:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) ++ ; CHECK-NEXT: [[UV64:%[0-9]+]]:_(s32), [[UV65:%[0-9]+]]:_(s32), [[UV66:%[0-9]+]]:_(s32), [[UV67:%[0-9]+]]:_(s32), [[UV68:%[0-9]+]]:_(s32), [[UV69:%[0-9]+]]:_(s32), [[UV70:%[0-9]+]]:_(s32), [[UV71:%[0-9]+]]:_(s32), [[UV72:%[0-9]+]]:_(s32), [[UV73:%[0-9]+]]:_(s32), [[UV74:%[0-9]+]]:_(s32), [[UV75:%[0-9]+]]:_(s32), [[UV76:%[0-9]+]]:_(s32), [[UV77:%[0-9]+]]:_(s32), [[UV78:%[0-9]+]]:_(s32), [[UV79:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) ++ ; CHECK-NEXT: [[UV80:%[0-9]+]]:_(s32), [[UV81:%[0-9]+]]:_(s32), [[UV82:%[0-9]+]]:_(s32), [[UV83:%[0-9]+]]:_(s32), [[UV84:%[0-9]+]]:_(s32), [[UV85:%[0-9]+]]:_(s32), [[UV86:%[0-9]+]]:_(s32), [[UV87:%[0-9]+]]:_(s32), [[UV88:%[0-9]+]]:_(s32), [[UV89:%[0-9]+]]:_(s32), [[UV90:%[0-9]+]]:_(s32), [[UV91:%[0-9]+]]:_(s32), [[UV92:%[0-9]+]]:_(s32), [[UV93:%[0-9]+]]:_(s32), [[UV94:%[0-9]+]]:_(s32), [[UV95:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) ++ ; CHECK-NEXT: [[UV96:%[0-9]+]]:_(s32), [[UV97:%[0-9]+]]:_(s32), [[UV98:%[0-9]+]]:_(s32), [[UV99:%[0-9]+]]:_(s32), [[UV100:%[0-9]+]]:_(s32), [[UV101:%[0-9]+]]:_(s32), [[UV102:%[0-9]+]]:_(s32), [[UV103:%[0-9]+]]:_(s32), [[UV104:%[0-9]+]]:_(s32), [[UV105:%[0-9]+]]:_(s32), [[UV106:%[0-9]+]]:_(s32), [[UV107:%[0-9]+]]:_(s32), [[UV108:%[0-9]+]]:_(s32), [[UV109:%[0-9]+]]:_(s32), [[UV110:%[0-9]+]]:_(s32), [[UV111:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) ++ ; CHECK-NEXT: [[UV112:%[0-9]+]]:_(s32), [[UV113:%[0-9]+]]:_(s32), [[UV114:%[0-9]+]]:_(s32), [[UV115:%[0-9]+]]:_(s32), [[UV116:%[0-9]+]]:_(s32), [[UV117:%[0-9]+]]:_(s32), [[UV118:%[0-9]+]]:_(s32), [[UV119:%[0-9]+]]:_(s32), [[UV120:%[0-9]+]]:_(s32), [[UV121:%[0-9]+]]:_(s32), [[UV122:%[0-9]+]]:_(s32), [[UV123:%[0-9]+]]:_(s32), [[UV124:%[0-9]+]]:_(s32), [[UV125:%[0-9]+]]:_(s32), [[UV126:%[0-9]+]]:_(s32), [[UV127:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) ++ ; CHECK-NEXT: [[ADD:%[0-9]+]]:_(s32) = G_ADD [[UV]], [[UV64]] ++ ; CHECK-NEXT: [[ADD1:%[0-9]+]]:_(s32) = G_ADD [[UV1]], [[UV65]] ++ ; CHECK-NEXT: [[ADD2:%[0-9]+]]:_(s32) = G_ADD [[UV2]], [[UV66]] ++ ; CHECK-NEXT: [[ADD3:%[0-9]+]]:_(s32) = G_ADD [[UV3]], [[UV67]] ++ ; CHECK-NEXT: [[ADD4:%[0-9]+]]:_(s32) = G_ADD [[UV4]], [[UV68]] ++ ; CHECK-NEXT: [[ADD5:%[0-9]+]]:_(s32) = G_ADD [[UV5]], [[UV69]] ++ ; CHECK-NEXT: [[ADD6:%[0-9]+]]:_(s32) = G_ADD [[UV6]], [[UV70]] ++ ; CHECK-NEXT: [[ADD7:%[0-9]+]]:_(s32) = G_ADD [[UV7]], [[UV71]] ++ ; CHECK-NEXT: [[ADD8:%[0-9]+]]:_(s32) = G_ADD [[UV8]], [[UV72]] ++ ; CHECK-NEXT: [[ADD9:%[0-9]+]]:_(s32) = G_ADD [[UV9]], [[UV73]] ++ ; CHECK-NEXT: [[ADD10:%[0-9]+]]:_(s32) = G_ADD [[UV10]], [[UV74]] ++ ; CHECK-NEXT: [[ADD11:%[0-9]+]]:_(s32) = G_ADD [[UV11]], [[UV75]] ++ ; CHECK-NEXT: [[ADD12:%[0-9]+]]:_(s32) = G_ADD [[UV12]], [[UV76]] ++ ; CHECK-NEXT: [[ADD13:%[0-9]+]]:_(s32) = G_ADD [[UV13]], [[UV77]] ++ ; CHECK-NEXT: [[ADD14:%[0-9]+]]:_(s32) = G_ADD [[UV14]], [[UV78]] ++ ; CHECK-NEXT: [[ADD15:%[0-9]+]]:_(s32) = G_ADD [[UV15]], [[UV79]] ++ ; CHECK-NEXT: [[ADD16:%[0-9]+]]:_(s32) = G_ADD [[UV16]], [[UV80]] ++ ; CHECK-NEXT: [[ADD17:%[0-9]+]]:_(s32) = G_ADD [[UV17]], [[UV81]] ++ ; CHECK-NEXT: [[ADD18:%[0-9]+]]:_(s32) = G_ADD [[UV18]], [[UV82]] ++ ; CHECK-NEXT: [[ADD19:%[0-9]+]]:_(s32) = G_ADD [[UV19]], [[UV83]] ++ ; CHECK-NEXT: [[ADD20:%[0-9]+]]:_(s32) = G_ADD [[UV20]], [[UV84]] ++ ; CHECK-NEXT: [[ADD21:%[0-9]+]]:_(s32) = G_ADD [[UV21]], [[UV85]] ++ ; CHECK-NEXT: [[ADD22:%[0-9]+]]:_(s32) = G_ADD [[UV22]], [[UV86]] ++ ; CHECK-NEXT: [[ADD23:%[0-9]+]]:_(s32) = G_ADD [[UV23]], [[UV87]] ++ ; CHECK-NEXT: [[ADD24:%[0-9]+]]:_(s32) = G_ADD [[UV24]], [[UV88]] ++ ; CHECK-NEXT: [[ADD25:%[0-9]+]]:_(s32) = G_ADD [[UV25]], [[UV89]] ++ ; CHECK-NEXT: [[ADD26:%[0-9]+]]:_(s32) = G_ADD [[UV26]], [[UV90]] ++ ; CHECK-NEXT: [[ADD27:%[0-9]+]]:_(s32) = G_ADD [[UV27]], [[UV91]] ++ ; CHECK-NEXT: [[ADD28:%[0-9]+]]:_(s32) = G_ADD [[UV28]], [[UV92]] ++ ; CHECK-NEXT: [[ADD29:%[0-9]+]]:_(s32) = G_ADD [[UV29]], [[UV93]] ++ ; CHECK-NEXT: [[ADD30:%[0-9]+]]:_(s32) = G_ADD [[UV30]], [[UV94]] ++ ; CHECK-NEXT: [[ADD31:%[0-9]+]]:_(s32) = G_ADD [[UV31]], [[UV95]] ++ ; CHECK-NEXT: [[ADD32:%[0-9]+]]:_(s32) = G_ADD [[UV32]], [[UV96]] ++ ; CHECK-NEXT: [[ADD33:%[0-9]+]]:_(s32) = G_ADD [[UV33]], [[UV97]] ++ ; CHECK-NEXT: [[ADD34:%[0-9]+]]:_(s32) = G_ADD [[UV34]], [[UV98]] ++ ; CHECK-NEXT: [[ADD35:%[0-9]+]]:_(s32) = G_ADD [[UV35]], [[UV99]] ++ ; CHECK-NEXT: [[ADD36:%[0-9]+]]:_(s32) = G_ADD [[UV36]], [[UV100]] ++ ; CHECK-NEXT: [[ADD37:%[0-9]+]]:_(s32) = G_ADD [[UV37]], [[UV101]] ++ ; CHECK-NEXT: [[ADD38:%[0-9]+]]:_(s32) = G_ADD [[UV38]], [[UV102]] ++ ; CHECK-NEXT: [[ADD39:%[0-9]+]]:_(s32) = G_ADD [[UV39]], [[UV103]] ++ ; CHECK-NEXT: [[ADD40:%[0-9]+]]:_(s32) = G_ADD [[UV40]], [[UV104]] ++ ; CHECK-NEXT: [[ADD41:%[0-9]+]]:_(s32) = G_ADD [[UV41]], [[UV105]] ++ ; CHECK-NEXT: [[ADD42:%[0-9]+]]:_(s32) = G_ADD [[UV42]], [[UV106]] ++ ; CHECK-NEXT: [[ADD43:%[0-9]+]]:_(s32) = G_ADD [[UV43]], [[UV107]] ++ ; CHECK-NEXT: [[ADD44:%[0-9]+]]:_(s32) = G_ADD [[UV44]], [[UV108]] ++ ; CHECK-NEXT: [[ADD45:%[0-9]+]]:_(s32) = G_ADD [[UV45]], [[UV109]] ++ ; CHECK-NEXT: [[ADD46:%[0-9]+]]:_(s32) = G_ADD [[UV46]], [[UV110]] ++ ; CHECK-NEXT: [[ADD47:%[0-9]+]]:_(s32) = G_ADD [[UV47]], [[UV111]] ++ ; CHECK-NEXT: [[ADD48:%[0-9]+]]:_(s32) = G_ADD [[UV48]], [[UV112]] ++ ; CHECK-NEXT: [[ADD49:%[0-9]+]]:_(s32) = G_ADD [[UV49]], [[UV113]] ++ ; CHECK-NEXT: [[ADD50:%[0-9]+]]:_(s32) = G_ADD [[UV50]], [[UV114]] ++ ; CHECK-NEXT: [[ADD51:%[0-9]+]]:_(s32) = G_ADD [[UV51]], [[UV115]] ++ ; CHECK-NEXT: [[ADD52:%[0-9]+]]:_(s32) = G_ADD [[UV52]], [[UV116]] ++ ; CHECK-NEXT: [[ADD53:%[0-9]+]]:_(s32) = G_ADD [[UV53]], [[UV117]] ++ ; CHECK-NEXT: [[ADD54:%[0-9]+]]:_(s32) = G_ADD [[UV54]], [[UV118]] ++ ; CHECK-NEXT: [[ADD55:%[0-9]+]]:_(s32) = G_ADD [[UV55]], [[UV119]] ++ ; CHECK-NEXT: [[ADD56:%[0-9]+]]:_(s32) = G_ADD [[UV56]], [[UV120]] ++ ; CHECK-NEXT: [[ADD57:%[0-9]+]]:_(s32) = G_ADD [[UV57]], [[UV121]] ++ ; CHECK-NEXT: [[ADD58:%[0-9]+]]:_(s32) = G_ADD [[UV58]], [[UV122]] ++ ; CHECK-NEXT: [[ADD59:%[0-9]+]]:_(s32) = G_ADD [[UV59]], [[UV123]] ++ ; CHECK-NEXT: [[ADD60:%[0-9]+]]:_(s32) = G_ADD [[UV60]], [[UV124]] ++ ; CHECK-NEXT: [[ADD61:%[0-9]+]]:_(s32) = G_ADD [[UV61]], [[UV125]] ++ ; CHECK-NEXT: [[ADD62:%[0-9]+]]:_(s32) = G_ADD [[UV62]], [[UV126]] ++ ; CHECK-NEXT: [[ADD63:%[0-9]+]]:_(s32) = G_ADD [[UV63]], [[UV127]] + ; CHECK-NEXT: [[BUILD_VECTOR:%[0-9]+]]:_(<16 x s32>) = G_BUILD_VECTOR [[ADD]](s32), [[ADD1]](s32), [[ADD2]](s32), [[ADD3]](s32), [[ADD4]](s32), [[ADD5]](s32), [[ADD6]](s32), [[ADD7]](s32), [[ADD8]](s32), [[ADD9]](s32), [[ADD10]](s32), [[ADD11]](s32), [[ADD12]](s32), [[ADD13]](s32), [[ADD14]](s32), [[ADD15]](s32) + ; CHECK-NEXT: [[BUILD_VECTOR1:%[0-9]+]]:_(<16 x s32>) = G_BUILD_VECTOR [[ADD16]](s32), [[ADD17]](s32), [[ADD18]](s32), [[ADD19]](s32), [[ADD20]](s32), [[ADD21]](s32), [[ADD22]](s32), [[ADD23]](s32), [[ADD24]](s32), [[ADD25]](s32), [[ADD26]](s32), [[ADD27]](s32), [[ADD28]](s32), [[ADD29]](s32), [[ADD30]](s32), [[ADD31]](s32) + ; CHECK-NEXT: [[BUILD_VECTOR2:%[0-9]+]]:_(<16 x s32>) = G_BUILD_VECTOR [[ADD32]](s32), [[ADD33]](s32), [[ADD34]](s32), [[ADD35]](s32), [[ADD36]](s32), [[ADD37]](s32), [[ADD38]](s32), [[ADD39]](s32), [[ADD40]](s32), [[ADD41]](s32), [[ADD42]](s32), [[ADD43]](s32), [[ADD44]](s32), [[ADD45]](s32), [[ADD46]](s32), [[ADD47]](s32) +@@ -760,10 +762,10 @@ + ; CHECK-NEXT: G_BR %bb.2 + ; CHECK-NEXT: {{ $}} + ; CHECK-NEXT: bb.2: +- ; CHECK-NEXT: [[PHI:%[0-9]+]]:_(<16 x s32>) = G_PHI [[UV]](<16 x s32>), %bb.0, [[BUILD_VECTOR]](<16 x s32>), %bb.1 +- ; CHECK-NEXT: [[PHI1:%[0-9]+]]:_(<16 x s32>) = G_PHI [[UV1]](<16 x s32>), %bb.0, [[BUILD_VECTOR1]](<16 x s32>), %bb.1 +- ; CHECK-NEXT: [[PHI2:%[0-9]+]]:_(<16 x s32>) = G_PHI [[UV2]](<16 x s32>), %bb.0, [[BUILD_VECTOR2]](<16 x s32>), %bb.1 +- ; CHECK-NEXT: [[PHI3:%[0-9]+]]:_(<16 x s32>) = G_PHI [[UV3]](<16 x s32>), %bb.0, [[BUILD_VECTOR3]](<16 x s32>), %bb.1 ++ ; CHECK-NEXT: [[PHI:%[0-9]+]]:_(<16 x s32>) = G_PHI [[DEF]](<16 x s32>), %bb.0, [[BUILD_VECTOR]](<16 x s32>), %bb.1 ++ ; CHECK-NEXT: [[PHI1:%[0-9]+]]:_(<16 x s32>) = G_PHI [[DEF]](<16 x s32>), %bb.0, [[BUILD_VECTOR1]](<16 x s32>), %bb.1 ++ ; CHECK-NEXT: [[PHI2:%[0-9]+]]:_(<16 x s32>) = G_PHI [[DEF]](<16 x s32>), %bb.0, [[BUILD_VECTOR2]](<16 x s32>), %bb.1 ++ ; CHECK-NEXT: [[PHI3:%[0-9]+]]:_(<16 x s32>) = G_PHI [[DEF]](<16 x s32>), %bb.0, [[BUILD_VECTOR3]](<16 x s32>), %bb.1 + ; CHECK-NEXT: [[CONCAT_VECTORS:%[0-9]+]]:_(<64 x s32>) = G_CONCAT_VECTORS [[PHI]](<16 x s32>), [[PHI1]](<16 x s32>), [[PHI2]](<16 x s32>), [[PHI3]](<16 x s32>) + ; CHECK-NEXT: S_SETPC_B64 undef $sgpr30_sgpr31, implicit [[CONCAT_VECTORS]](<64 x s32>) + bb.0: +diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/AMDGPU/GlobalISel/regbankselect.mir b/llvm/test/CodeGen/AMDGPU/GlobalISel/regbankselect.mir +--- a/llvm/test/CodeGen/AMDGPU/GlobalISel/regbankselect.mir ++++ b/llvm/test/CodeGen/AMDGPU/GlobalISel/regbankselect.mir +@@ -42,6 +42,8 @@ + ret void + } + ++ define void @non_power_of_2() { ret void } ++ + define amdgpu_kernel void @load_constant_v4i16_from_8_align8(ptr addrspace(4) %ptr0) { + ret void + } +@@ -185,6 +187,23 @@ + ... + + --- ++name: non_power_of_2 ++legalized: true ++ ++body: | ++ bb.0: ++ ; CHECK-LABEL: name: non_power_of_2 ++ ; CHECK: [[DEF:%[0-9]+]]:sgpr(s448) = G_IMPLICIT_DEF ++ ; CHECK-NEXT: [[EXTRACT:%[0-9]+]]:sgpr(s32) = G_EXTRACT [[DEF]](s448), 0 ++ ; CHECK-NEXT: $sgpr0 = COPY [[EXTRACT]](s32) ++ ; CHECK-NEXT: SI_RETURN_TO_EPILOG $sgpr0 ++ %0:_(s448) = G_IMPLICIT_DEF ++ %1:_(s32) = G_EXTRACT %0:_(s448), 0 ++ $sgpr0 = COPY %1:_(s32) ++ SI_RETURN_TO_EPILOG $sgpr0 ++... ++ ++--- + name: load_constant_v4i16_from_8_align8 + legalized: true + +diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/NVPTX/intrin-nocapture.ll b/llvm/test/CodeGen/NVPTX/intrin-nocapture.ll +--- a/llvm/test/CodeGen/NVPTX/intrin-nocapture.ll ++++ b/llvm/test/CodeGen/NVPTX/intrin-nocapture.ll +@@ -0,0 +1,21 @@ ++; RUN: opt < %s -O3 -S | FileCheck %s ++ ++; Address space intrinsics were erroneously marked NoCapture, leading to bad ++; optimizations (such as the store below being eliminated as dead code). This ++; test makes sure we don't regress. ++ ++declare void @foo(ptr addrspace(1)) ++ ++declare ptr addrspace(1) @llvm.nvvm.ptr.gen.to.global.p1.p0(ptr) ++ ++; CHECK: @bar ++define void @bar() { ++ %t1 = alloca i32 ++; CHECK: call ptr addrspace(1) @llvm.nvvm.ptr.gen.to.global.p1.p0(ptr nonnull %t1) ++; CHECK-NEXT: store i32 10, ptr %t1 ++ %t2 = call ptr addrspace(1) @llvm.nvvm.ptr.gen.to.global.p1.p0(ptr %t1) ++ store i32 10, ptr %t1 ++ call void @foo(ptr addrspace(1) %t2) ++ ret void ++} ++ +diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/NVPTX/rotate_64.ll b/llvm/test/CodeGen/NVPTX/rotate_64.ll +--- a/llvm/test/CodeGen/NVPTX/rotate_64.ll ++++ b/llvm/test/CodeGen/NVPTX/rotate_64.ll +@@ -1,38 +1,25 @@ +-; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5 + ; RUN: llc < %s -march=nvptx64 | FileCheck %s + ; RUN: %if ptxas %{ llc < %s -march=nvptx64 | %ptxas-verify %} + + declare i64 @llvm.nvvm.rotate.b64(i64, i32) + declare i64 @llvm.nvvm.rotate.right.b64(i64, i32) + ++; CHECK: rotate64 + define i64 @rotate64(i64 %a, i32 %b) { +-; CHECK-LABEL: rotate64( +-; CHECK: { +-; CHECK-NEXT: .reg .b64 %rd<5>; +-; CHECK-EMPTY: +-; CHECK-NEXT: // %bb.0: +-; CHECK-NEXT: ld.param.u64 %rd1, [rotate64_param_0]; +-; CHECK-NEXT: shr.u64 %rd2, %rd1, 61; +-; CHECK-NEXT: shl.b64 %rd3, %rd1, 3; +-; CHECK-NEXT: or.b64 %rd4, %rd3, %rd2; +-; CHECK-NEXT: st.param.b64 [func_retval0+0], %rd4; +-; CHECK-NEXT: ret; ++; CHECK: shl.b64 [[LHS:%.*]], [[RD1:%.*]], 3; ++; CHECK: shr.b64 [[RHS:%.*]], [[RD1]], 61; ++; CHECK: add.u64 [[RD2:%.*]], [[LHS]], [[RHS]]; ++; CHECK: ret + %val = tail call i64 @llvm.nvvm.rotate.b64(i64 %a, i32 3) + ret i64 %val + } + ++; CHECK: rotateright64 + define i64 @rotateright64(i64 %a, i32 %b) { +-; CHECK-LABEL: rotateright64( +-; CHECK: { +-; CHECK-NEXT: .reg .b64 %rd<5>; +-; CHECK-EMPTY: +-; CHECK-NEXT: // %bb.0: +-; CHECK-NEXT: ld.param.u64 %rd1, [rotateright64_param_0]; +-; CHECK-NEXT: shl.b64 %rd2, %rd1, 61; +-; CHECK-NEXT: shr.u64 %rd3, %rd1, 3; +-; CHECK-NEXT: or.b64 %rd4, %rd3, %rd2; +-; CHECK-NEXT: st.param.b64 [func_retval0+0], %rd4; +-; CHECK-NEXT: ret; ++; CHECK: shl.b64 [[LHS:%.*]], [[RD1:%.*]], 61; ++; CHECK: shr.b64 [[RHS:%.*]], [[RD1]], 3; ++; CHECK: add.u64 [[RD2:%.*]], [[LHS]], [[RHS]]; ++; CHECK: ret + %val = tail call i64 @llvm.nvvm.rotate.right.b64(i64 %a, i32 3) + ret i64 %val + } +diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/NVPTX/rotate.ll b/llvm/test/CodeGen/NVPTX/rotate.ll +--- a/llvm/test/CodeGen/NVPTX/rotate.ll ++++ b/llvm/test/CodeGen/NVPTX/rotate.ll +@@ -9,29 +9,26 @@ + declare i64 @llvm.nvvm.rotate.b64(i64, i32) + declare i64 @llvm.nvvm.rotate.right.b64(i64, i32) + +-declare i64 @llvm.fshl.i64(i64, i64, i64) +-declare i64 @llvm.fshr.i64(i64, i64, i64) +-declare i32 @llvm.fshl.i32(i32, i32, i32) +-declare i32 @llvm.fshr.i32(i32, i32, i32) +- +- + ; SM20: rotate32 + ; SM35: rotate32 + define i32 @rotate32(i32 %a, i32 %b) { + ; SM20-LABEL: rotate32( + ; SM20: { +-; SM20-NEXT: .reg .b32 %r<9>; ++; SM20-NEXT: .reg .b32 %r<4>; + ; SM20-EMPTY: + ; SM20-NEXT: // %bb.0: + ; SM20-NEXT: ld.param.u32 %r1, [rotate32_param_0]; + ; SM20-NEXT: ld.param.u32 %r2, [rotate32_param_1]; +-; SM20-NEXT: and.b32 %r3, %r2, 31; +-; SM20-NEXT: shl.b32 %r4, %r1, %r3; +-; SM20-NEXT: neg.s32 %r5, %r2; +-; SM20-NEXT: and.b32 %r6, %r5, 31; +-; SM20-NEXT: shr.u32 %r7, %r1, %r6; +-; SM20-NEXT: or.b32 %r8, %r4, %r7; +-; SM20-NEXT: st.param.b32 [func_retval0+0], %r8; ++; SM20-NEXT: { ++; SM20-NEXT: .reg .b32 %lhs; ++; SM20-NEXT: .reg .b32 %rhs; ++; SM20-NEXT: .reg .b32 %amt2; ++; SM20-NEXT: shl.b32 %lhs, %r1, %r2; ++; SM20-NEXT: sub.s32 %amt2, 32, %r2; ++; SM20-NEXT: shr.b32 %rhs, %r1, %amt2; ++; SM20-NEXT: add.u32 %r3, %lhs, %rhs; ++; SM20-NEXT: } ++; SM20-NEXT: st.param.b32 [func_retval0+0], %r3; + ; SM20-NEXT: ret; + ; + ; SM35-LABEL: rotate32( +@@ -53,36 +50,45 @@ + define i64 @rotate64(i64 %a, i32 %b) { + ; SM20-LABEL: rotate64( + ; SM20: { +-; SM20-NEXT: .reg .b32 %r<5>; +-; SM20-NEXT: .reg .b64 %rd<5>; ++; SM20-NEXT: .reg .b32 %r<2>; ++; SM20-NEXT: .reg .b64 %rd<3>; + ; SM20-EMPTY: + ; SM20-NEXT: // %bb.0: + ; SM20-NEXT: ld.param.u64 %rd1, [rotate64_param_0]; + ; SM20-NEXT: ld.param.u32 %r1, [rotate64_param_1]; +-; SM20-NEXT: and.b32 %r2, %r1, 63; +-; SM20-NEXT: shl.b64 %rd2, %rd1, %r2; +-; SM20-NEXT: neg.s32 %r3, %r1; +-; SM20-NEXT: and.b32 %r4, %r3, 63; +-; SM20-NEXT: shr.u64 %rd3, %rd1, %r4; +-; SM20-NEXT: or.b64 %rd4, %rd2, %rd3; +-; SM20-NEXT: st.param.b64 [func_retval0+0], %rd4; ++; SM20-NEXT: { ++; SM20-NEXT: .reg .b64 %lhs; ++; SM20-NEXT: .reg .b64 %rhs; ++; SM20-NEXT: .reg .u32 %amt2; ++; SM20-NEXT: and.b32 %amt2, %r1, 63; ++; SM20-NEXT: shl.b64 %lhs, %rd1, %amt2; ++; SM20-NEXT: sub.u32 %amt2, 64, %amt2; ++; SM20-NEXT: shr.b64 %rhs, %rd1, %amt2; ++; SM20-NEXT: add.u64 %rd2, %lhs, %rhs; ++; SM20-NEXT: } ++; SM20-NEXT: st.param.b64 [func_retval0+0], %rd2; + ; SM20-NEXT: ret; + ; + ; SM35-LABEL: rotate64( + ; SM35: { +-; SM35-NEXT: .reg .b32 %r<5>; +-; SM35-NEXT: .reg .b64 %rd<5>; ++; SM35-NEXT: .reg .b32 %r<6>; ++; SM35-NEXT: .reg .b64 %rd<3>; + ; SM35-EMPTY: + ; SM35-NEXT: // %bb.0: + ; SM35-NEXT: ld.param.u64 %rd1, [rotate64_param_0]; +-; SM35-NEXT: ld.param.u32 %r1, [rotate64_param_1]; +-; SM35-NEXT: and.b32 %r2, %r1, 63; +-; SM35-NEXT: shl.b64 %rd2, %rd1, %r2; +-; SM35-NEXT: neg.s32 %r3, %r1; +-; SM35-NEXT: and.b32 %r4, %r3, 63; +-; SM35-NEXT: shr.u64 %rd3, %rd1, %r4; +-; SM35-NEXT: or.b64 %rd4, %rd2, %rd3; +-; SM35-NEXT: st.param.b64 [func_retval0+0], %rd4; ++; SM35-NEXT: { ++; SM35-NEXT: .reg .b32 %dummy; ++; SM35-NEXT: mov.b64 {%dummy,%r1}, %rd1; ++; SM35-NEXT: } ++; SM35-NEXT: { ++; SM35-NEXT: .reg .b32 %dummy; ++; SM35-NEXT: mov.b64 {%r2,%dummy}, %rd1; ++; SM35-NEXT: } ++; SM35-NEXT: ld.param.u32 %r3, [rotate64_param_1]; ++; SM35-NEXT: shf.l.wrap.b32 %r4, %r2, %r1, %r3; ++; SM35-NEXT: shf.l.wrap.b32 %r5, %r1, %r2, %r3; ++; SM35-NEXT: mov.b64 %rd2, {%r5, %r4}; ++; SM35-NEXT: st.param.b64 [func_retval0+0], %rd2; + ; SM35-NEXT: ret; + %val = tail call i64 @llvm.nvvm.rotate.b64(i64 %a, i32 %b) + ret i64 %val +@@ -93,36 +99,45 @@ + define i64 @rotateright64(i64 %a, i32 %b) { + ; SM20-LABEL: rotateright64( + ; SM20: { +-; SM20-NEXT: .reg .b32 %r<5>; +-; SM20-NEXT: .reg .b64 %rd<5>; ++; SM20-NEXT: .reg .b32 %r<2>; ++; SM20-NEXT: .reg .b64 %rd<3>; + ; SM20-EMPTY: + ; SM20-NEXT: // %bb.0: + ; SM20-NEXT: ld.param.u64 %rd1, [rotateright64_param_0]; + ; SM20-NEXT: ld.param.u32 %r1, [rotateright64_param_1]; +-; SM20-NEXT: and.b32 %r2, %r1, 63; +-; SM20-NEXT: shr.u64 %rd2, %rd1, %r2; +-; SM20-NEXT: neg.s32 %r3, %r1; +-; SM20-NEXT: and.b32 %r4, %r3, 63; +-; SM20-NEXT: shl.b64 %rd3, %rd1, %r4; +-; SM20-NEXT: or.b64 %rd4, %rd2, %rd3; +-; SM20-NEXT: st.param.b64 [func_retval0+0], %rd4; ++; SM20-NEXT: { ++; SM20-NEXT: .reg .b64 %lhs; ++; SM20-NEXT: .reg .b64 %rhs; ++; SM20-NEXT: .reg .u32 %amt2; ++; SM20-NEXT: and.b32 %amt2, %r1, 63; ++; SM20-NEXT: shr.b64 %lhs, %rd1, %amt2; ++; SM20-NEXT: sub.u32 %amt2, 64, %amt2; ++; SM20-NEXT: shl.b64 %rhs, %rd1, %amt2; ++; SM20-NEXT: add.u64 %rd2, %lhs, %rhs; ++; SM20-NEXT: } ++; SM20-NEXT: st.param.b64 [func_retval0+0], %rd2; + ; SM20-NEXT: ret; + ; + ; SM35-LABEL: rotateright64( + ; SM35: { +-; SM35-NEXT: .reg .b32 %r<5>; +-; SM35-NEXT: .reg .b64 %rd<5>; ++; SM35-NEXT: .reg .b32 %r<6>; ++; SM35-NEXT: .reg .b64 %rd<3>; + ; SM35-EMPTY: + ; SM35-NEXT: // %bb.0: + ; SM35-NEXT: ld.param.u64 %rd1, [rotateright64_param_0]; +-; SM35-NEXT: ld.param.u32 %r1, [rotateright64_param_1]; +-; SM35-NEXT: and.b32 %r2, %r1, 63; +-; SM35-NEXT: shr.u64 %rd2, %rd1, %r2; +-; SM35-NEXT: neg.s32 %r3, %r1; +-; SM35-NEXT: and.b32 %r4, %r3, 63; +-; SM35-NEXT: shl.b64 %rd3, %rd1, %r4; +-; SM35-NEXT: or.b64 %rd4, %rd2, %rd3; +-; SM35-NEXT: st.param.b64 [func_retval0+0], %rd4; ++; SM35-NEXT: { ++; SM35-NEXT: .reg .b32 %dummy; ++; SM35-NEXT: mov.b64 {%r1,%dummy}, %rd1; ++; SM35-NEXT: } ++; SM35-NEXT: { ++; SM35-NEXT: .reg .b32 %dummy; ++; SM35-NEXT: mov.b64 {%dummy,%r2}, %rd1; ++; SM35-NEXT: } ++; SM35-NEXT: ld.param.u32 %r3, [rotateright64_param_1]; ++; SM35-NEXT: shf.r.wrap.b32 %r4, %r2, %r1, %r3; ++; SM35-NEXT: shf.r.wrap.b32 %r5, %r1, %r2, %r3; ++; SM35-NEXT: mov.b64 %rd2, {%r5, %r4}; ++; SM35-NEXT: st.param.b64 [func_retval0+0], %rd2; + ; SM35-NEXT: ret; + %val = tail call i64 @llvm.nvvm.rotate.right.b64(i64 %a, i32 %b) + ret i64 %val +@@ -133,14 +148,18 @@ + define i32 @rotl0(i32 %x) { + ; SM20-LABEL: rotl0( + ; SM20: { +-; SM20-NEXT: .reg .b32 %r<5>; ++; SM20-NEXT: .reg .b32 %r<3>; + ; SM20-EMPTY: + ; SM20-NEXT: // %bb.0: + ; SM20-NEXT: ld.param.u32 %r1, [rotl0_param_0]; +-; SM20-NEXT: shr.u32 %r2, %r1, 24; +-; SM20-NEXT: shl.b32 %r3, %r1, 8; +-; SM20-NEXT: or.b32 %r4, %r3, %r2; +-; SM20-NEXT: st.param.b32 [func_retval0+0], %r4; ++; SM20-NEXT: { ++; SM20-NEXT: .reg .b32 %lhs; ++; SM20-NEXT: .reg .b32 %rhs; ++; SM20-NEXT: shl.b32 %lhs, %r1, 8; ++; SM20-NEXT: shr.b32 %rhs, %r1, 24; ++; SM20-NEXT: add.u32 %r2, %lhs, %rhs; ++; SM20-NEXT: } ++; SM20-NEXT: st.param.b32 [func_retval0+0], %r2; + ; SM20-NEXT: ret; + ; + ; SM35-LABEL: rotl0( +@@ -158,40 +177,51 @@ + ret i32 %t2 + } + ++declare i64 @llvm.fshl.i64(i64, i64, i64) ++declare i64 @llvm.fshr.i64(i64, i64, i64) ++ + ; SM35: rotl64 + define i64 @rotl64(i64 %a, i64 %n) { + ; SM20-LABEL: rotl64( + ; SM20: { +-; SM20-NEXT: .reg .b32 %r<5>; +-; SM20-NEXT: .reg .b64 %rd<5>; ++; SM20-NEXT: .reg .b32 %r<2>; ++; SM20-NEXT: .reg .b64 %rd<3>; + ; SM20-EMPTY: + ; SM20-NEXT: // %bb.0: + ; SM20-NEXT: ld.param.u64 %rd1, [rotl64_param_0]; + ; SM20-NEXT: ld.param.u32 %r1, [rotl64_param_1]; +-; SM20-NEXT: and.b32 %r2, %r1, 63; +-; SM20-NEXT: shl.b64 %rd2, %rd1, %r2; +-; SM20-NEXT: neg.s32 %r3, %r1; +-; SM20-NEXT: and.b32 %r4, %r3, 63; +-; SM20-NEXT: shr.u64 %rd3, %rd1, %r4; +-; SM20-NEXT: or.b64 %rd4, %rd2, %rd3; +-; SM20-NEXT: st.param.b64 [func_retval0+0], %rd4; ++; SM20-NEXT: { ++; SM20-NEXT: .reg .b64 %lhs; ++; SM20-NEXT: .reg .b64 %rhs; ++; SM20-NEXT: .reg .u32 %amt2; ++; SM20-NEXT: and.b32 %amt2, %r1, 63; ++; SM20-NEXT: shl.b64 %lhs, %rd1, %amt2; ++; SM20-NEXT: sub.u32 %amt2, 64, %amt2; ++; SM20-NEXT: shr.b64 %rhs, %rd1, %amt2; ++; SM20-NEXT: add.u64 %rd2, %lhs, %rhs; ++; SM20-NEXT: } ++; SM20-NEXT: st.param.b64 [func_retval0+0], %rd2; + ; SM20-NEXT: ret; + ; + ; SM35-LABEL: rotl64( + ; SM35: { +-; SM35-NEXT: .reg .b32 %r<5>; +-; SM35-NEXT: .reg .b64 %rd<5>; ++; SM35-NEXT: .reg .b32 %r<2>; ++; SM35-NEXT: .reg .b64 %rd<3>; + ; SM35-EMPTY: + ; SM35-NEXT: // %bb.0: + ; SM35-NEXT: ld.param.u64 %rd1, [rotl64_param_0]; + ; SM35-NEXT: ld.param.u32 %r1, [rotl64_param_1]; +-; SM35-NEXT: and.b32 %r2, %r1, 63; +-; SM35-NEXT: shl.b64 %rd2, %rd1, %r2; +-; SM35-NEXT: neg.s32 %r3, %r1; +-; SM35-NEXT: and.b32 %r4, %r3, 63; +-; SM35-NEXT: shr.u64 %rd3, %rd1, %r4; +-; SM35-NEXT: or.b64 %rd4, %rd2, %rd3; +-; SM35-NEXT: st.param.b64 [func_retval0+0], %rd4; ++; SM35-NEXT: { ++; SM35-NEXT: .reg .b64 %lhs; ++; SM35-NEXT: .reg .b64 %rhs; ++; SM35-NEXT: .reg .u32 %amt2; ++; SM35-NEXT: and.b32 %amt2, %r1, 63; ++; SM35-NEXT: shl.b64 %lhs, %rd1, %amt2; ++; SM35-NEXT: sub.u32 %amt2, 64, %amt2; ++; SM35-NEXT: shr.b64 %rhs, %rd1, %amt2; ++; SM35-NEXT: add.u64 %rd2, %lhs, %rhs; ++; SM35-NEXT: } ++; SM35-NEXT: st.param.b64 [func_retval0+0], %rd2; + ; SM35-NEXT: ret; + %val = tail call i64 @llvm.fshl.i64(i64 %a, i64 %a, i64 %n) + ret i64 %val +@@ -201,26 +231,34 @@ + define i64 @rotl64_imm(i64 %a) { + ; SM20-LABEL: rotl64_imm( + ; SM20: { +-; SM20-NEXT: .reg .b64 %rd<5>; ++; SM20-NEXT: .reg .b64 %rd<3>; + ; SM20-EMPTY: + ; SM20-NEXT: // %bb.0: + ; SM20-NEXT: ld.param.u64 %rd1, [rotl64_imm_param_0]; +-; SM20-NEXT: shr.u64 %rd2, %rd1, 62; +-; SM20-NEXT: shl.b64 %rd3, %rd1, 2; +-; SM20-NEXT: or.b64 %rd4, %rd3, %rd2; +-; SM20-NEXT: st.param.b64 [func_retval0+0], %rd4; ++; SM20-NEXT: { ++; SM20-NEXT: .reg .b64 %lhs; ++; SM20-NEXT: .reg .b64 %rhs; ++; SM20-NEXT: shl.b64 %lhs, %rd1, 2; ++; SM20-NEXT: shr.b64 %rhs, %rd1, 62; ++; SM20-NEXT: add.u64 %rd2, %lhs, %rhs; ++; SM20-NEXT: } ++; SM20-NEXT: st.param.b64 [func_retval0+0], %rd2; + ; SM20-NEXT: ret; + ; + ; SM35-LABEL: rotl64_imm( + ; SM35: { +-; SM35-NEXT: .reg .b64 %rd<5>; ++; SM35-NEXT: .reg .b64 %rd<3>; + ; SM35-EMPTY: + ; SM35-NEXT: // %bb.0: + ; SM35-NEXT: ld.param.u64 %rd1, [rotl64_imm_param_0]; +-; SM35-NEXT: shr.u64 %rd2, %rd1, 62; +-; SM35-NEXT: shl.b64 %rd3, %rd1, 2; +-; SM35-NEXT: or.b64 %rd4, %rd3, %rd2; +-; SM35-NEXT: st.param.b64 [func_retval0+0], %rd4; ++; SM35-NEXT: { ++; SM35-NEXT: .reg .b64 %lhs; ++; SM35-NEXT: .reg .b64 %rhs; ++; SM35-NEXT: shl.b64 %lhs, %rd1, 2; ++; SM35-NEXT: shr.b64 %rhs, %rd1, 62; ++; SM35-NEXT: add.u64 %rd2, %lhs, %rhs; ++; SM35-NEXT: } ++; SM35-NEXT: st.param.b64 [func_retval0+0], %rd2; + ; SM35-NEXT: ret; + %val = tail call i64 @llvm.fshl.i64(i64 %a, i64 %a, i64 66) + ret i64 %val +@@ -230,36 +268,44 @@ + define i64 @rotr64(i64 %a, i64 %n) { + ; SM20-LABEL: rotr64( + ; SM20: { +-; SM20-NEXT: .reg .b32 %r<5>; +-; SM20-NEXT: .reg .b64 %rd<5>; ++; SM20-NEXT: .reg .b32 %r<2>; ++; SM20-NEXT: .reg .b64 %rd<3>; + ; SM20-EMPTY: + ; SM20-NEXT: // %bb.0: + ; SM20-NEXT: ld.param.u64 %rd1, [rotr64_param_0]; + ; SM20-NEXT: ld.param.u32 %r1, [rotr64_param_1]; +-; SM20-NEXT: and.b32 %r2, %r1, 63; +-; SM20-NEXT: shr.u64 %rd2, %rd1, %r2; +-; SM20-NEXT: neg.s32 %r3, %r1; +-; SM20-NEXT: and.b32 %r4, %r3, 63; +-; SM20-NEXT: shl.b64 %rd3, %rd1, %r4; +-; SM20-NEXT: or.b64 %rd4, %rd2, %rd3; +-; SM20-NEXT: st.param.b64 [func_retval0+0], %rd4; ++; SM20-NEXT: { ++; SM20-NEXT: .reg .b64 %lhs; ++; SM20-NEXT: .reg .b64 %rhs; ++; SM20-NEXT: .reg .u32 %amt2; ++; SM20-NEXT: and.b32 %amt2, %r1, 63; ++; SM20-NEXT: shr.b64 %lhs, %rd1, %amt2; ++; SM20-NEXT: sub.u32 %amt2, 64, %amt2; ++; SM20-NEXT: shl.b64 %rhs, %rd1, %amt2; ++; SM20-NEXT: add.u64 %rd2, %lhs, %rhs; ++; SM20-NEXT: } ++; SM20-NEXT: st.param.b64 [func_retval0+0], %rd2; + ; SM20-NEXT: ret; + ; + ; SM35-LABEL: rotr64( + ; SM35: { +-; SM35-NEXT: .reg .b32 %r<5>; +-; SM35-NEXT: .reg .b64 %rd<5>; ++; SM35-NEXT: .reg .b32 %r<2>; ++; SM35-NEXT: .reg .b64 %rd<3>; + ; SM35-EMPTY: + ; SM35-NEXT: // %bb.0: + ; SM35-NEXT: ld.param.u64 %rd1, [rotr64_param_0]; + ; SM35-NEXT: ld.param.u32 %r1, [rotr64_param_1]; +-; SM35-NEXT: and.b32 %r2, %r1, 63; +-; SM35-NEXT: shr.u64 %rd2, %rd1, %r2; +-; SM35-NEXT: neg.s32 %r3, %r1; +-; SM35-NEXT: and.b32 %r4, %r3, 63; +-; SM35-NEXT: shl.b64 %rd3, %rd1, %r4; +-; SM35-NEXT: or.b64 %rd4, %rd2, %rd3; +-; SM35-NEXT: st.param.b64 [func_retval0+0], %rd4; ++; SM35-NEXT: { ++; SM35-NEXT: .reg .b64 %lhs; ++; SM35-NEXT: .reg .b64 %rhs; ++; SM35-NEXT: .reg .u32 %amt2; ++; SM35-NEXT: and.b32 %amt2, %r1, 63; ++; SM35-NEXT: shr.b64 %lhs, %rd1, %amt2; ++; SM35-NEXT: sub.u32 %amt2, 64, %amt2; ++; SM35-NEXT: shl.b64 %rhs, %rd1, %amt2; ++; SM35-NEXT: add.u64 %rd2, %lhs, %rhs; ++; SM35-NEXT: } ++; SM35-NEXT: st.param.b64 [func_retval0+0], %rd2; + ; SM35-NEXT: ret; + %val = tail call i64 @llvm.fshr.i64(i64 %a, i64 %a, i64 %n) + ret i64 %val +@@ -269,180 +315,35 @@ + define i64 @rotr64_imm(i64 %a) { + ; SM20-LABEL: rotr64_imm( + ; SM20: { +-; SM20-NEXT: .reg .b64 %rd<5>; ++; SM20-NEXT: .reg .b64 %rd<3>; + ; SM20-EMPTY: + ; SM20-NEXT: // %bb.0: + ; SM20-NEXT: ld.param.u64 %rd1, [rotr64_imm_param_0]; +-; SM20-NEXT: shl.b64 %rd2, %rd1, 62; +-; SM20-NEXT: shr.u64 %rd3, %rd1, 2; +-; SM20-NEXT: or.b64 %rd4, %rd3, %rd2; +-; SM20-NEXT: st.param.b64 [func_retval0+0], %rd4; ++; SM20-NEXT: { ++; SM20-NEXT: .reg .b64 %lhs; ++; SM20-NEXT: .reg .b64 %rhs; ++; SM20-NEXT: shl.b64 %lhs, %rd1, 62; ++; SM20-NEXT: shr.b64 %rhs, %rd1, 2; ++; SM20-NEXT: add.u64 %rd2, %lhs, %rhs; ++; SM20-NEXT: } ++; SM20-NEXT: st.param.b64 [func_retval0+0], %rd2; + ; SM20-NEXT: ret; + ; + ; SM35-LABEL: rotr64_imm( + ; SM35: { +-; SM35-NEXT: .reg .b64 %rd<5>; ++; SM35-NEXT: .reg .b64 %rd<3>; + ; SM35-EMPTY: + ; SM35-NEXT: // %bb.0: + ; SM35-NEXT: ld.param.u64 %rd1, [rotr64_imm_param_0]; +-; SM35-NEXT: shl.b64 %rd2, %rd1, 62; +-; SM35-NEXT: shr.u64 %rd3, %rd1, 2; +-; SM35-NEXT: or.b64 %rd4, %rd3, %rd2; +-; SM35-NEXT: st.param.b64 [func_retval0+0], %rd4; ++; SM35-NEXT: { ++; SM35-NEXT: .reg .b64 %lhs; ++; SM35-NEXT: .reg .b64 %rhs; ++; SM35-NEXT: shl.b64 %lhs, %rd1, 62; ++; SM35-NEXT: shr.b64 %rhs, %rd1, 2; ++; SM35-NEXT: add.u64 %rd2, %lhs, %rhs; ++; SM35-NEXT: } ++; SM35-NEXT: st.param.b64 [func_retval0+0], %rd2; + ; SM35-NEXT: ret; + %val = tail call i64 @llvm.fshr.i64(i64 %a, i64 %a, i64 66) + ret i64 %val + } +- +-define i32 @funnel_shift_right_32(i32 %a, i32 %b, i32 %c) { +-; SM20-LABEL: funnel_shift_right_32( +-; SM20: { +-; SM20-NEXT: .reg .b32 %r<11>; +-; SM20-EMPTY: +-; SM20-NEXT: // %bb.0: +-; SM20-NEXT: ld.param.u32 %r1, [funnel_shift_right_32_param_0]; +-; SM20-NEXT: ld.param.u32 %r2, [funnel_shift_right_32_param_2]; +-; SM20-NEXT: and.b32 %r3, %r2, 31; +-; SM20-NEXT: ld.param.u32 %r4, [funnel_shift_right_32_param_1]; +-; SM20-NEXT: shr.u32 %r5, %r4, %r3; +-; SM20-NEXT: shl.b32 %r6, %r1, 1; +-; SM20-NEXT: not.b32 %r7, %r2; +-; SM20-NEXT: and.b32 %r8, %r7, 31; +-; SM20-NEXT: shl.b32 %r9, %r6, %r8; +-; SM20-NEXT: or.b32 %r10, %r9, %r5; +-; SM20-NEXT: st.param.b32 [func_retval0+0], %r10; +-; SM20-NEXT: ret; +-; +-; SM35-LABEL: funnel_shift_right_32( +-; SM35: { +-; SM35-NEXT: .reg .b32 %r<5>; +-; SM35-EMPTY: +-; SM35-NEXT: // %bb.0: +-; SM35-NEXT: ld.param.u32 %r1, [funnel_shift_right_32_param_0]; +-; SM35-NEXT: ld.param.u32 %r2, [funnel_shift_right_32_param_1]; +-; SM35-NEXT: ld.param.u32 %r3, [funnel_shift_right_32_param_2]; +-; SM35-NEXT: shf.r.wrap.b32 %r4, %r1, %r2, %r3; +-; SM35-NEXT: st.param.b32 [func_retval0+0], %r4; +-; SM35-NEXT: ret; +- %val = call i32 @llvm.fshr.i32(i32 %a, i32 %b, i32 %c) +- ret i32 %val +-} +- +-define i32 @funnel_shift_left_32(i32 %a, i32 %b, i32 %c) { +-; SM20-LABEL: funnel_shift_left_32( +-; SM20: { +-; SM20-NEXT: .reg .b32 %r<11>; +-; SM20-EMPTY: +-; SM20-NEXT: // %bb.0: +-; SM20-NEXT: ld.param.u32 %r1, [funnel_shift_left_32_param_0]; +-; SM20-NEXT: ld.param.u32 %r2, [funnel_shift_left_32_param_2]; +-; SM20-NEXT: and.b32 %r3, %r2, 31; +-; SM20-NEXT: shl.b32 %r4, %r1, %r3; +-; SM20-NEXT: ld.param.u32 %r5, [funnel_shift_left_32_param_1]; +-; SM20-NEXT: shr.u32 %r6, %r5, 1; +-; SM20-NEXT: not.b32 %r7, %r2; +-; SM20-NEXT: and.b32 %r8, %r7, 31; +-; SM20-NEXT: shr.u32 %r9, %r6, %r8; +-; SM20-NEXT: or.b32 %r10, %r4, %r9; +-; SM20-NEXT: st.param.b32 [func_retval0+0], %r10; +-; SM20-NEXT: ret; +-; +-; SM35-LABEL: funnel_shift_left_32( +-; SM35: { +-; SM35-NEXT: .reg .b32 %r<5>; +-; SM35-EMPTY: +-; SM35-NEXT: // %bb.0: +-; SM35-NEXT: ld.param.u32 %r1, [funnel_shift_left_32_param_0]; +-; SM35-NEXT: ld.param.u32 %r2, [funnel_shift_left_32_param_1]; +-; SM35-NEXT: ld.param.u32 %r3, [funnel_shift_left_32_param_2]; +-; SM35-NEXT: shf.l.wrap.b32 %r4, %r1, %r2, %r3; +-; SM35-NEXT: st.param.b32 [func_retval0+0], %r4; +-; SM35-NEXT: ret; +- %val = call i32 @llvm.fshl.i32(i32 %a, i32 %b, i32 %c) +- ret i32 %val +-} +- +-define i64 @funnel_shift_right_64(i64 %a, i64 %b, i64 %c) { +-; SM20-LABEL: funnel_shift_right_64( +-; SM20: { +-; SM20-NEXT: .reg .b32 %r<5>; +-; SM20-NEXT: .reg .b64 %rd<7>; +-; SM20-EMPTY: +-; SM20-NEXT: // %bb.0: +-; SM20-NEXT: ld.param.u64 %rd1, [funnel_shift_right_64_param_0]; +-; SM20-NEXT: ld.param.u32 %r1, [funnel_shift_right_64_param_2]; +-; SM20-NEXT: and.b32 %r2, %r1, 63; +-; SM20-NEXT: ld.param.u64 %rd2, [funnel_shift_right_64_param_1]; +-; SM20-NEXT: shr.u64 %rd3, %rd2, %r2; +-; SM20-NEXT: shl.b64 %rd4, %rd1, 1; +-; SM20-NEXT: not.b32 %r3, %r1; +-; SM20-NEXT: and.b32 %r4, %r3, 63; +-; SM20-NEXT: shl.b64 %rd5, %rd4, %r4; +-; SM20-NEXT: or.b64 %rd6, %rd5, %rd3; +-; SM20-NEXT: st.param.b64 [func_retval0+0], %rd6; +-; SM20-NEXT: ret; +-; +-; SM35-LABEL: funnel_shift_right_64( +-; SM35: { +-; SM35-NEXT: .reg .b32 %r<5>; +-; SM35-NEXT: .reg .b64 %rd<7>; +-; SM35-EMPTY: +-; SM35-NEXT: // %bb.0: +-; SM35-NEXT: ld.param.u64 %rd1, [funnel_shift_right_64_param_0]; +-; SM35-NEXT: ld.param.u32 %r1, [funnel_shift_right_64_param_2]; +-; SM35-NEXT: and.b32 %r2, %r1, 63; +-; SM35-NEXT: ld.param.u64 %rd2, [funnel_shift_right_64_param_1]; +-; SM35-NEXT: shr.u64 %rd3, %rd2, %r2; +-; SM35-NEXT: shl.b64 %rd4, %rd1, 1; +-; SM35-NEXT: not.b32 %r3, %r1; +-; SM35-NEXT: and.b32 %r4, %r3, 63; +-; SM35-NEXT: shl.b64 %rd5, %rd4, %r4; +-; SM35-NEXT: or.b64 %rd6, %rd5, %rd3; +-; SM35-NEXT: st.param.b64 [func_retval0+0], %rd6; +-; SM35-NEXT: ret; +- %val = call i64 @llvm.fshr.i64(i64 %a, i64 %b, i64 %c) +- ret i64 %val +-} +- +-define i64 @funnel_shift_left_64(i64 %a, i64 %b, i64 %c) { +-; SM20-LABEL: funnel_shift_left_64( +-; SM20: { +-; SM20-NEXT: .reg .b32 %r<5>; +-; SM20-NEXT: .reg .b64 %rd<7>; +-; SM20-EMPTY: +-; SM20-NEXT: // %bb.0: +-; SM20-NEXT: ld.param.u64 %rd1, [funnel_shift_left_64_param_0]; +-; SM20-NEXT: ld.param.u32 %r1, [funnel_shift_left_64_param_2]; +-; SM20-NEXT: and.b32 %r2, %r1, 63; +-; SM20-NEXT: shl.b64 %rd2, %rd1, %r2; +-; SM20-NEXT: ld.param.u64 %rd3, [funnel_shift_left_64_param_1]; +-; SM20-NEXT: shr.u64 %rd4, %rd3, 1; +-; SM20-NEXT: not.b32 %r3, %r1; +-; SM20-NEXT: and.b32 %r4, %r3, 63; +-; SM20-NEXT: shr.u64 %rd5, %rd4, %r4; +-; SM20-NEXT: or.b64 %rd6, %rd2, %rd5; +-; SM20-NEXT: st.param.b64 [func_retval0+0], %rd6; +-; SM20-NEXT: ret; +-; +-; SM35-LABEL: funnel_shift_left_64( +-; SM35: { +-; SM35-NEXT: .reg .b32 %r<5>; +-; SM35-NEXT: .reg .b64 %rd<7>; +-; SM35-EMPTY: +-; SM35-NEXT: // %bb.0: +-; SM35-NEXT: ld.param.u64 %rd1, [funnel_shift_left_64_param_0]; +-; SM35-NEXT: ld.param.u32 %r1, [funnel_shift_left_64_param_2]; +-; SM35-NEXT: and.b32 %r2, %r1, 63; +-; SM35-NEXT: shl.b64 %rd2, %rd1, %r2; +-; SM35-NEXT: ld.param.u64 %rd3, [funnel_shift_left_64_param_1]; +-; SM35-NEXT: shr.u64 %rd4, %rd3, 1; +-; SM35-NEXT: not.b32 %r3, %r1; +-; SM35-NEXT: and.b32 %r4, %r3, 63; +-; SM35-NEXT: shr.u64 %rd5, %rd4, %r4; +-; SM35-NEXT: or.b64 %rd6, %rd2, %rd5; +-; SM35-NEXT: st.param.b64 [func_retval0+0], %rd6; +-; SM35-NEXT: ret; +- %val = call i64 @llvm.fshl.i64(i64 %a, i64 %b, i64 %c) +- ret i64 %val +-} +- +diff -ruN --strip-trailing-cr a/llvm/test/DebugInfo/NVPTX/debug-info.ll b/llvm/test/DebugInfo/NVPTX/debug-info.ll +--- a/llvm/test/DebugInfo/NVPTX/debug-info.ll ++++ b/llvm/test/DebugInfo/NVPTX/debug-info.ll +@@ -25,10 +25,6 @@ + ; CHECK-DAG: .reg .b64 %rd<8>; + ; CHECK: .loc [[DEBUG_INFO_CU:[0-9]+]] 5 0 + ; CHECK: ld.param.u32 %r{{.+}}, [{{.+}}]; +-; CHECK: ld.param.u64 %rd{{.+}}, [{{.+}}]; +-; CHECK: cvta.to.global.u64 %rd{{.+}}, %rd{{.+}}; +-; CHECK: ld.param.u64 %rd{{.+}}, [{{.+}}]; +-; CHECK: cvta.to.global.u64 %rd{{.+}}, %rd{{.+}}; + ; CHECK: .loc [[BUILTUIN_VARS_H:[0-9]+]] 78 180 + ; CHECK: mov.u32 %r{{.+}}, %ctaid.x; + ; CHECK: .loc [[BUILTUIN_VARS_H]] 89 180 +@@ -42,6 +38,10 @@ + ; CHECK: .loc [[DEBUG_INFO_CU]] 7 7 + ; CHECK: @%p{{.+}} bra [[BB:\$L__.+]]; + ; CHECK: ld.param.f32 %f{{.+}}, [{{.+}}]; ++; CHECK: ld.param.u64 %rd{{.+}}, [{{.+}}]; ++; CHECK: cvta.to.global.u64 %rd{{.+}}, %rd{{.+}}; ++; CHECK: ld.param.u64 %rd{{.+}}, [{{.+}}]; ++; CHECK: cvta.to.global.u64 %rd{{.+}}, %rd{{.+}}; + ; CHECK: .loc [[DEBUG_INFO_CU]] 8 13 + ; CHECK: mul.wide.u32 %rd{{.+}}, %r{{.+}}, 4; + ; CHECK: add.s64 %rd{{.+}}, %rd{{.+}}, %rd{{.+}}; +@@ -2661,22 +2661,22 @@ + ; CHECK-NEXT:.b32 4579 // DW_AT_type + ; CHECK-NEXT:.b8 25 // Abbrev [25] 0x8aa:0x18 DW_TAG_inlined_subroutine + ; CHECK-NEXT:.b32 707 // DW_AT_abstract_origin +-; CHECK-NEXT:.b64 $L__tmp1 // DW_AT_low_pc +-; CHECK-NEXT:.b64 $L__tmp2 // DW_AT_high_pc ++; CHECK-NEXT:.b64 $L__tmp0 // DW_AT_low_pc ++; CHECK-NEXT:.b64 $L__tmp1 // DW_AT_high_pc + ; CHECK-NEXT:.b8 1 // DW_AT_call_file + ; CHECK-NEXT:.b8 6 // DW_AT_call_line + ; CHECK-NEXT:.b8 11 // DW_AT_call_column + ; CHECK-NEXT:.b8 25 // Abbrev [25] 0x8c2:0x18 DW_TAG_inlined_subroutine + ; CHECK-NEXT:.b32 1466 // DW_AT_abstract_origin +-; CHECK-NEXT:.b64 $L__tmp2 // DW_AT_low_pc +-; CHECK-NEXT:.b64 $L__tmp3 // DW_AT_high_pc ++; CHECK-NEXT:.b64 $L__tmp1 // DW_AT_low_pc ++; CHECK-NEXT:.b64 $L__tmp2 // DW_AT_high_pc + ; CHECK-NEXT:.b8 1 // DW_AT_call_file + ; CHECK-NEXT:.b8 6 // DW_AT_call_line + ; CHECK-NEXT:.b8 24 // DW_AT_call_column + ; CHECK-NEXT:.b8 25 // Abbrev [25] 0x8da:0x18 DW_TAG_inlined_subroutine + ; CHECK-NEXT:.b32 2060 // DW_AT_abstract_origin +-; CHECK-NEXT:.b64 $L__tmp3 // DW_AT_low_pc +-; CHECK-NEXT:.b64 $L__tmp4 // DW_AT_high_pc ++; CHECK-NEXT:.b64 $L__tmp2 // DW_AT_low_pc ++; CHECK-NEXT:.b64 $L__tmp3 // DW_AT_high_pc + ; CHECK-NEXT:.b8 1 // DW_AT_call_file + ; CHECK-NEXT:.b8 6 // DW_AT_call_line + ; CHECK-NEXT:.b8 37 // DW_AT_call_column diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl index abe15efc5e7204..af35fe705c0b99 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "df0864e761107b07e38f5503e0cbee0cebb4c5e8" - LLVM_SHA256 = "5bfcb7306d9d40f420862ace1f7ad3f01979facfb16ffd1fc80b6d91e92019fa" + LLVM_COMMIT = "9830156f623c56062bf6df1b4c4b4bd8ab5bd57c" + LLVM_SHA256 = "85bb9a61cfdaf0d3386890dc7b4bbaa17eecf4b70b60c314307f2ca3919b9035" tf_http_archive( name = name, diff --git a/third_party/shardy/temporary.patch b/third_party/shardy/temporary.patch index 5fd5f295cd7dfc..d3fd21823cce19 100644 --- a/third_party/shardy/temporary.patch +++ b/third_party/shardy/temporary.patch @@ -1,15 +1,4115 @@ +diff --git a/third_party/llvm/generated.patch b/third_party/llvm/generated.patch +index 509398d..de92cb4 100644 +--- a/third_party/llvm/generated.patch ++++ b/third_party/llvm/generated.patch +@@ -1 +1,4095 @@ + Auto generated patch. Do not edit or delete it, even if empty. ++diff -ruN --strip-trailing-cr a/llvm/docs/NVPTXUsage.rst b/llvm/docs/NVPTXUsage.rst ++--- a/llvm/docs/NVPTXUsage.rst +++++ b/llvm/docs/NVPTXUsage.rst ++@@ -127,6 +127,69 @@ ++ NVPTX Intrinsics ++ ================ ++ +++Address Space Conversion +++------------------------ +++ +++'``llvm.nvvm.ptr.*.to.gen``' Intrinsics +++^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +++ +++Syntax: +++""""""" +++ +++These are overloaded intrinsics. You can use these on any pointer types. +++ +++.. code-block:: llvm +++ +++ declare ptr @llvm.nvvm.ptr.global.to.gen.p0.p1(ptr addrspace(1)) +++ declare ptr @llvm.nvvm.ptr.shared.to.gen.p0.p3(ptr addrspace(3)) +++ declare ptr @llvm.nvvm.ptr.constant.to.gen.p0.p4(ptr addrspace(4)) +++ declare ptr @llvm.nvvm.ptr.local.to.gen.p0.p5(ptr addrspace(5)) +++ +++Overview: +++""""""""" +++ +++The '``llvm.nvvm.ptr.*.to.gen``' intrinsics convert a pointer in a non-generic +++address space to a generic address space pointer. +++ +++Semantics: +++"""""""""" +++ +++These intrinsics modify the pointer value to be a valid generic address space +++pointer. +++ +++ +++'``llvm.nvvm.ptr.gen.to.*``' Intrinsics +++^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +++ +++Syntax: +++""""""" +++ +++These are overloaded intrinsics. You can use these on any pointer types. +++ +++.. code-block:: llvm +++ +++ declare ptr addrspace(1) @llvm.nvvm.ptr.gen.to.global.p1.p0(ptr) +++ declare ptr addrspace(3) @llvm.nvvm.ptr.gen.to.shared.p3.p0(ptr) +++ declare ptr addrspace(4) @llvm.nvvm.ptr.gen.to.constant.p4.p0(ptr) +++ declare ptr addrspace(5) @llvm.nvvm.ptr.gen.to.local.p5.p0(ptr) +++ +++Overview: +++""""""""" +++ +++The '``llvm.nvvm.ptr.gen.to.*``' intrinsics convert a pointer in the generic +++address space to a pointer in the target address space. Note that these +++intrinsics are only useful if the address space of the target address space of +++the pointer is known. It is not legal to use address space conversion +++intrinsics to convert a pointer from one non-generic address space to another +++non-generic address space. +++ +++Semantics: +++"""""""""" +++ +++These intrinsics modify the pointer value to be a valid pointer in the target +++non-generic address space. +++ +++ ++ Reading PTX Special Registers ++ ----------------------------- ++ ++diff -ruN --strip-trailing-cr a/llvm/docs/ReleaseNotes.rst b/llvm/docs/ReleaseNotes.rst ++--- a/llvm/docs/ReleaseNotes.rst +++++ b/llvm/docs/ReleaseNotes.rst ++@@ -63,24 +63,6 @@ ++ * ``llvm.nvvm.bitcast.d2ll`` ++ * ``llvm.nvvm.bitcast.ll2d`` ++ ++-* Remove the following intrinsics which can be replaced with a funnel-shift: ++- ++- * ``llvm.nvvm.rotate.b32`` ++- * ``llvm.nvvm.rotate.right.b64`` ++- * ``llvm.nvvm.rotate.b64`` ++- ++-* Remove the following intrinsics which can be replaced with an ++- ``addrspacecast``: ++- ++- * ``llvm.nvvm.ptr.gen.to.global`` ++- * ``llvm.nvvm.ptr.gen.to.shared`` ++- * ``llvm.nvvm.ptr.gen.to.constant`` ++- * ``llvm.nvvm.ptr.gen.to.local`` ++- * ``llvm.nvvm.ptr.global.to.gen`` ++- * ``llvm.nvvm.ptr.shared.to.gen`` ++- * ``llvm.nvvm.ptr.constant.to.gen`` ++- * ``llvm.nvvm.ptr.local.to.gen`` ++- ++ Changes to LLVM infrastructure ++ ------------------------------ ++ ++diff -ruN --strip-trailing-cr a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td ++--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td +++++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td ++@@ -30,18 +30,10 @@ ++ // * llvm.nvvm.max.ui --> select(x ule y, x, y) ++ // * llvm.nvvm.max.ull --> ibid. ++ // * llvm.nvvm.h2f --> llvm.convert.to.fp16.f32 ++-// * llvm.nvvm.bitcast.f2i --> bitcast ++-// * llvm.nvvm.bitcast.i2f --> ibid. ++-// * llvm.nvvm.bitcast.d2ll --> ibid. ++-// * llvm.nvvm.bitcast.ll2d --> ibid. ++-// * llvm.nvvm.ptr.gen.to.global --> addrspacecast ++-// * llvm.nvvm.ptr.gen.to.shared --> ibid. ++-// * llvm.nvvm.ptr.gen.to.constant --> ibid. ++-// * llvm.nvvm.ptr.gen.to.local --> ibid. ++-// * llvm.nvvm.ptr.global.to.gen --> ibid. ++-// * llvm.nvvm.ptr.shared.to.gen --> ibid. ++-// * llvm.nvvm.ptr.constant.to.gen --> ibid. ++-// * llvm.nvvm.ptr.local.to.gen --> ibid. +++// * llvm.nvvm.bitcast.f2i --> bitcast +++// * llvm.nvvm.bitcast.i2f --> ibid. +++// * llvm.nvvm.bitcast.d2ll --> ibid. +++// * llvm.nvvm.bitcast.ll2d --> ibid. ++ ++ def llvm_global_ptr_ty : LLVMQualPointerType<1>; // (global)ptr ++ def llvm_shared_ptr_ty : LLVMQualPointerType<3>; // (shared)ptr ++@@ -1610,6 +1602,40 @@ ++ [IntrReadMem, IntrArgMemOnly, IntrNoCallback, IntrWillReturn, NoCapture>], ++ "llvm.nvvm.ldg.global.p">; ++ +++// Use for generic pointers +++// - These intrinsics are used to convert address spaces. +++// - The input pointer and output pointer must have the same type, except for +++// the address-space. (This restriction is not enforced here as there is +++// currently no way to describe it). +++// - This complements the llvm bitcast, which can be used to cast one type +++// of pointer to another type of pointer, while the address space remains +++// the same. +++def int_nvvm_ptr_local_to_gen: DefaultAttrsIntrinsic<[llvm_anyptr_ty], +++ [llvm_anyptr_ty], [IntrNoMem, IntrSpeculatable], +++ "llvm.nvvm.ptr.local.to.gen">; +++def int_nvvm_ptr_shared_to_gen: DefaultAttrsIntrinsic<[llvm_anyptr_ty], +++ [llvm_anyptr_ty], [IntrNoMem, IntrSpeculatable], +++ "llvm.nvvm.ptr.shared.to.gen">; +++def int_nvvm_ptr_global_to_gen: DefaultAttrsIntrinsic<[llvm_anyptr_ty], +++ [llvm_anyptr_ty], [IntrNoMem, IntrSpeculatable], +++ "llvm.nvvm.ptr.global.to.gen">; +++def int_nvvm_ptr_constant_to_gen: DefaultAttrsIntrinsic<[llvm_anyptr_ty], +++ [llvm_anyptr_ty], [IntrNoMem, IntrSpeculatable], +++ "llvm.nvvm.ptr.constant.to.gen">; +++ +++def int_nvvm_ptr_gen_to_global: DefaultAttrsIntrinsic<[llvm_anyptr_ty], +++ [llvm_anyptr_ty], [IntrNoMem, IntrSpeculatable], +++ "llvm.nvvm.ptr.gen.to.global">; +++def int_nvvm_ptr_gen_to_shared: DefaultAttrsIntrinsic<[llvm_anyptr_ty], +++ [llvm_anyptr_ty], [IntrNoMem, IntrSpeculatable], +++ "llvm.nvvm.ptr.gen.to.shared">; +++def int_nvvm_ptr_gen_to_local: DefaultAttrsIntrinsic<[llvm_anyptr_ty], +++ [llvm_anyptr_ty], [IntrNoMem, IntrSpeculatable], +++ "llvm.nvvm.ptr.gen.to.local">; +++def int_nvvm_ptr_gen_to_constant: DefaultAttrsIntrinsic<[llvm_anyptr_ty], +++ [llvm_anyptr_ty], [IntrNoMem, IntrSpeculatable], +++ "llvm.nvvm.ptr.gen.to.constant">; +++ ++ // Used in nvvm internally to help address space opt and ptx code generation ++ // This is for params that are passed to kernel functions by pointer by-val. ++ def int_nvvm_ptr_gen_to_param: Intrinsic<[llvm_anyptr_ty], ++@@ -4453,6 +4479,22 @@ ++ "llvm.nvvm.sust.p.3d.v4i32.trap">, ++ ClangBuiltin<"__nvvm_sust_p_3d_v4i32_trap">; ++ +++ +++def int_nvvm_rotate_b32 +++ : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty], +++ [IntrNoMem, IntrSpeculatable], "llvm.nvvm.rotate.b32">, +++ ClangBuiltin<"__nvvm_rotate_b32">; +++ +++def int_nvvm_rotate_b64 +++ : DefaultAttrsIntrinsic<[llvm_i64_ty], [llvm_i64_ty, llvm_i32_ty], +++ [IntrNoMem, IntrSpeculatable], "llvm.nvvm.rotate.b64">, +++ ClangBuiltin<"__nvvm_rotate_b64">; +++ +++def int_nvvm_rotate_right_b64 +++ : DefaultAttrsIntrinsic<[llvm_i64_ty], [llvm_i64_ty, llvm_i32_ty], +++ [IntrNoMem, IntrSpeculatable], "llvm.nvvm.rotate.right.b64">, +++ ClangBuiltin<"__nvvm_rotate_right_b64">; +++ ++ def int_nvvm_swap_lo_hi_b64 ++ : DefaultAttrsIntrinsic<[llvm_i64_ty], [llvm_i64_ty], ++ [IntrNoMem, IntrSpeculatable], "llvm.nvvm.swap.lo.hi.b64">, ++diff -ruN --strip-trailing-cr a/llvm/lib/IR/AutoUpgrade.cpp b/llvm/lib/IR/AutoUpgrade.cpp ++--- a/llvm/lib/IR/AutoUpgrade.cpp +++++ b/llvm/lib/IR/AutoUpgrade.cpp ++@@ -1272,19 +1272,6 @@ ++ // nvvm.bitcast.{f2i,i2f,ll2d,d2ll} ++ Expand = ++ Name == "f2i" || Name == "i2f" || Name == "ll2d" || Name == "d2ll"; ++- else if (Name.consume_front("rotate.")) ++- // nvvm.rotate.{b32,b64,right.b64} ++- Expand = Name == "b32" || Name == "b64" || Name == "right.b64"; ++- else if (Name.consume_front("ptr.gen.to.")) ++- // nvvm.ptr.gen.to.{local,shared,global,constant} ++- Expand = Name.starts_with("local") || Name.starts_with("shared") || ++- Name.starts_with("global") || Name.starts_with("constant"); ++- else if (Name.consume_front("ptr.")) ++- // nvvm.ptr.{local,shared,global,constant}.to.gen ++- Expand = ++- (Name.consume_front("local") || Name.consume_front("shared") || ++- Name.consume_front("global") || Name.consume_front("constant")) && ++- Name.starts_with(".to.gen"); ++ else ++ Expand = false; ++ ++@@ -2271,117 +2258,6 @@ ++ } ++ } ++ ++-static Value *upgradeNVVMIntrinsicCall(StringRef Name, CallBase *CI, ++- Function *F, IRBuilder<> &Builder) { ++- Value *Rep = nullptr; ++- ++- if (Name == "abs.i" || Name == "abs.ll") { ++- Value *Arg = CI->getArgOperand(0); ++- Value *Neg = Builder.CreateNeg(Arg, "neg"); ++- Value *Cmp = Builder.CreateICmpSGE( ++- Arg, llvm::Constant::getNullValue(Arg->getType()), "abs.cond"); ++- Rep = Builder.CreateSelect(Cmp, Arg, Neg, "abs"); ++- } else if (Name.starts_with("atomic.load.add.f32.p") || ++- Name.starts_with("atomic.load.add.f64.p")) { ++- Value *Ptr = CI->getArgOperand(0); ++- Value *Val = CI->getArgOperand(1); ++- Rep = Builder.CreateAtomicRMW(AtomicRMWInst::FAdd, Ptr, Val, MaybeAlign(), ++- AtomicOrdering::SequentiallyConsistent); ++- } else if (Name.consume_front("max.") && ++- (Name == "s" || Name == "i" || Name == "ll" || Name == "us" || ++- Name == "ui" || Name == "ull")) { ++- Value *Arg0 = CI->getArgOperand(0); ++- Value *Arg1 = CI->getArgOperand(1); ++- Value *Cmp = Name.starts_with("u") ++- ? Builder.CreateICmpUGE(Arg0, Arg1, "max.cond") ++- : Builder.CreateICmpSGE(Arg0, Arg1, "max.cond"); ++- Rep = Builder.CreateSelect(Cmp, Arg0, Arg1, "max"); ++- } else if (Name.consume_front("min.") && ++- (Name == "s" || Name == "i" || Name == "ll" || Name == "us" || ++- Name == "ui" || Name == "ull")) { ++- Value *Arg0 = CI->getArgOperand(0); ++- Value *Arg1 = CI->getArgOperand(1); ++- Value *Cmp = Name.starts_with("u") ++- ? Builder.CreateICmpULE(Arg0, Arg1, "min.cond") ++- : Builder.CreateICmpSLE(Arg0, Arg1, "min.cond"); ++- Rep = Builder.CreateSelect(Cmp, Arg0, Arg1, "min"); ++- } else if (Name == "clz.ll") { ++- // llvm.nvvm.clz.ll returns an i32, but llvm.ctlz.i64 returns an i64. ++- Value *Arg = CI->getArgOperand(0); ++- Value *Ctlz = Builder.CreateCall( ++- Intrinsic::getDeclaration(F->getParent(), Intrinsic::ctlz, ++- {Arg->getType()}), ++- {Arg, Builder.getFalse()}, "ctlz"); ++- Rep = Builder.CreateTrunc(Ctlz, Builder.getInt32Ty(), "ctlz.trunc"); ++- } else if (Name == "popc.ll") { ++- // llvm.nvvm.popc.ll returns an i32, but llvm.ctpop.i64 returns an ++- // i64. ++- Value *Arg = CI->getArgOperand(0); ++- Value *Popc = Builder.CreateCall( ++- Intrinsic::getDeclaration(F->getParent(), Intrinsic::ctpop, ++- {Arg->getType()}), ++- Arg, "ctpop"); ++- Rep = Builder.CreateTrunc(Popc, Builder.getInt32Ty(), "ctpop.trunc"); ++- } else if (Name == "h2f") { ++- Rep = Builder.CreateCall( ++- Intrinsic::getDeclaration(F->getParent(), Intrinsic::convert_from_fp16, ++- {Builder.getFloatTy()}), ++- CI->getArgOperand(0), "h2f"); ++- } else if (Name.consume_front("bitcast.") && ++- (Name == "f2i" || Name == "i2f" || Name == "ll2d" || ++- Name == "d2ll")) { ++- Rep = Builder.CreateBitCast(CI->getArgOperand(0), CI->getType()); ++- } else if (Name == "rotate.b32") { ++- Value *Arg = CI->getOperand(0); ++- Value *ShiftAmt = CI->getOperand(1); ++- Rep = Builder.CreateIntrinsic(Builder.getInt32Ty(), Intrinsic::fshl, ++- {Arg, Arg, ShiftAmt}); ++- } else if (Name == "rotate.b64") { ++- Type *Int64Ty = Builder.getInt64Ty(); ++- Value *Arg = CI->getOperand(0); ++- Value *ZExtShiftAmt = Builder.CreateZExt(CI->getOperand(1), Int64Ty); ++- Rep = Builder.CreateIntrinsic(Int64Ty, Intrinsic::fshl, ++- {Arg, Arg, ZExtShiftAmt}); ++- } else if (Name == "rotate.right.b64") { ++- Type *Int64Ty = Builder.getInt64Ty(); ++- Value *Arg = CI->getOperand(0); ++- Value *ZExtShiftAmt = Builder.CreateZExt(CI->getOperand(1), Int64Ty); ++- Rep = Builder.CreateIntrinsic(Int64Ty, Intrinsic::fshr, ++- {Arg, Arg, ZExtShiftAmt}); ++- } else if ((Name.consume_front("ptr.gen.to.") && ++- (Name.starts_with("local") || Name.starts_with("shared") || ++- Name.starts_with("global") || Name.starts_with("constant"))) || ++- (Name.consume_front("ptr.") && ++- (Name.consume_front("local") || Name.consume_front("shared") || ++- Name.consume_front("global") || ++- Name.consume_front("constant")) && ++- Name.starts_with(".to.gen"))) { ++- Rep = Builder.CreateAddrSpaceCast(CI->getArgOperand(0), CI->getType()); ++- } else { ++- Intrinsic::ID IID = shouldUpgradeNVPTXBF16Intrinsic(Name); ++- if (IID != Intrinsic::not_intrinsic && ++- !F->getReturnType()->getScalarType()->isBFloatTy()) { ++- rename(F); ++- Function *NewFn = Intrinsic::getDeclaration(F->getParent(), IID); ++- SmallVector Args; ++- for (size_t I = 0; I < NewFn->arg_size(); ++I) { ++- Value *Arg = CI->getArgOperand(I); ++- Type *OldType = Arg->getType(); ++- Type *NewType = NewFn->getArg(I)->getType(); ++- Args.push_back( ++- (OldType->isIntegerTy() && NewType->getScalarType()->isBFloatTy()) ++- ? Builder.CreateBitCast(Arg, NewType) ++- : Arg); ++- } ++- Rep = Builder.CreateCall(NewFn, Args); ++- if (F->getReturnType()->isIntegerTy()) ++- Rep = Builder.CreateBitCast(Rep, F->getReturnType()); ++- } ++- } ++- ++- return Rep; ++-} ++- ++ static Value *upgradeX86IntrinsicCall(StringRef Name, CallBase *CI, Function *F, ++ IRBuilder<> &Builder) { ++ LLVMContext &C = F->getContext(); ++@@ -4332,8 +4208,85 @@ ++ ++ if (!IsX86 && Name == "stackprotectorcheck") { ++ Rep = nullptr; +++ } else if (IsNVVM && (Name == "abs.i" || Name == "abs.ll")) { +++ Value *Arg = CI->getArgOperand(0); +++ Value *Neg = Builder.CreateNeg(Arg, "neg"); +++ Value *Cmp = Builder.CreateICmpSGE( +++ Arg, llvm::Constant::getNullValue(Arg->getType()), "abs.cond"); +++ Rep = Builder.CreateSelect(Cmp, Arg, Neg, "abs"); +++ } else if (IsNVVM && (Name.starts_with("atomic.load.add.f32.p") || +++ Name.starts_with("atomic.load.add.f64.p"))) { +++ Value *Ptr = CI->getArgOperand(0); +++ Value *Val = CI->getArgOperand(1); +++ Rep = Builder.CreateAtomicRMW(AtomicRMWInst::FAdd, Ptr, Val, MaybeAlign(), +++ AtomicOrdering::SequentiallyConsistent); +++ } else if (IsNVVM && Name.consume_front("max.") && +++ (Name == "s" || Name == "i" || Name == "ll" || Name == "us" || +++ Name == "ui" || Name == "ull")) { +++ Value *Arg0 = CI->getArgOperand(0); +++ Value *Arg1 = CI->getArgOperand(1); +++ Value *Cmp = Name.starts_with("u") +++ ? Builder.CreateICmpUGE(Arg0, Arg1, "max.cond") +++ : Builder.CreateICmpSGE(Arg0, Arg1, "max.cond"); +++ Rep = Builder.CreateSelect(Cmp, Arg0, Arg1, "max"); +++ } else if (IsNVVM && Name.consume_front("min.") && +++ (Name == "s" || Name == "i" || Name == "ll" || Name == "us" || +++ Name == "ui" || Name == "ull")) { +++ Value *Arg0 = CI->getArgOperand(0); +++ Value *Arg1 = CI->getArgOperand(1); +++ Value *Cmp = Name.starts_with("u") +++ ? Builder.CreateICmpULE(Arg0, Arg1, "min.cond") +++ : Builder.CreateICmpSLE(Arg0, Arg1, "min.cond"); +++ Rep = Builder.CreateSelect(Cmp, Arg0, Arg1, "min"); +++ } else if (IsNVVM && Name == "clz.ll") { +++ // llvm.nvvm.clz.ll returns an i32, but llvm.ctlz.i64 returns an i64. +++ Value *Arg = CI->getArgOperand(0); +++ Value *Ctlz = Builder.CreateCall( +++ Intrinsic::getDeclaration(F->getParent(), Intrinsic::ctlz, +++ {Arg->getType()}), +++ {Arg, Builder.getFalse()}, "ctlz"); +++ Rep = Builder.CreateTrunc(Ctlz, Builder.getInt32Ty(), "ctlz.trunc"); +++ } else if (IsNVVM && Name == "popc.ll") { +++ // llvm.nvvm.popc.ll returns an i32, but llvm.ctpop.i64 returns an +++ // i64. +++ Value *Arg = CI->getArgOperand(0); +++ Value *Popc = Builder.CreateCall( +++ Intrinsic::getDeclaration(F->getParent(), Intrinsic::ctpop, +++ {Arg->getType()}), +++ Arg, "ctpop"); +++ Rep = Builder.CreateTrunc(Popc, Builder.getInt32Ty(), "ctpop.trunc"); ++ } else if (IsNVVM) { ++- Rep = upgradeNVVMIntrinsicCall(Name, CI, F, Builder); +++ if (Name == "h2f") { +++ Rep = +++ Builder.CreateCall(Intrinsic::getDeclaration( +++ F->getParent(), Intrinsic::convert_from_fp16, +++ {Builder.getFloatTy()}), +++ CI->getArgOperand(0), "h2f"); +++ } else if (Name.consume_front("bitcast.") && +++ (Name == "f2i" || Name == "i2f" || Name == "ll2d" || +++ Name == "d2ll")) { +++ Rep = Builder.CreateBitCast(CI->getArgOperand(0), CI->getType()); +++ } else { +++ Intrinsic::ID IID = shouldUpgradeNVPTXBF16Intrinsic(Name); +++ if (IID != Intrinsic::not_intrinsic && +++ !F->getReturnType()->getScalarType()->isBFloatTy()) { +++ rename(F); +++ NewFn = Intrinsic::getDeclaration(F->getParent(), IID); +++ SmallVector Args; +++ for (size_t I = 0; I < NewFn->arg_size(); ++I) { +++ Value *Arg = CI->getArgOperand(I); +++ Type *OldType = Arg->getType(); +++ Type *NewType = NewFn->getArg(I)->getType(); +++ Args.push_back((OldType->isIntegerTy() && +++ NewType->getScalarType()->isBFloatTy()) +++ ? Builder.CreateBitCast(Arg, NewType) +++ : Arg); +++ } +++ Rep = Builder.CreateCall(NewFn, Args); +++ if (F->getReturnType()->isIntegerTy()) +++ Rep = Builder.CreateBitCast(Rep, F->getReturnType()); +++ } +++ } ++ } else if (IsX86) { ++ Rep = upgradeX86IntrinsicCall(Name, CI, F, Builder); ++ } else if (IsARM) { ++diff -ruN --strip-trailing-cr a/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp b/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp ++--- a/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp +++++ b/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp ++@@ -292,7 +292,6 @@ ++ static const LLT S224 = LLT::scalar(224); ++ static const LLT S256 = LLT::scalar(256); ++ static const LLT S512 = LLT::scalar(512); ++-static const LLT S1024 = LLT::scalar(1024); ++ static const LLT MaxScalar = LLT::scalar(MaxRegisterSize); ++ ++ static const LLT V2S8 = LLT::fixed_vector(2, 8); ++@@ -333,8 +332,8 @@ ++ static const LLT V2S128 = LLT::fixed_vector(2, 128); ++ static const LLT V4S128 = LLT::fixed_vector(4, 128); ++ ++-static std::initializer_list AllScalarTypes = { ++- S32, S64, S96, S128, S160, S224, S256, S512, S1024}; +++static std::initializer_list AllScalarTypes = {S32, S64, S96, S128, +++ S160, S224, S256, S512}; ++ ++ static std::initializer_list AllS16Vectors{ ++ V2S16, V4S16, V6S16, V8S16, V10S16, V12S16, V16S16, V2S128, V4S128}; ++@@ -890,11 +889,10 @@ ++ .clampScalar(0, S16, S64); ++ ++ getActionDefinitionsBuilder({G_IMPLICIT_DEF, G_FREEZE}) ++- .legalIf(isRegisterClassType(0)) +++ .legalIf(isRegisterType(0)) ++ // s1 and s16 are special cases because they have legal operations on ++ // them, but don't really occupy registers in the normal way. ++ .legalFor({S1, S16}) ++- .clampNumElements(0, V16S32, V32S32) ++ .moreElementsIf(isSmallOddVector(0), oneMoreElement(0)) ++ .clampScalarOrElt(0, S32, MaxScalar) ++ .widenScalarToNextPow2(0, 32) ++diff -ruN --strip-trailing-cr a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td ++--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td +++++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td ++@@ -174,6 +174,10 @@ ++ def hasSHFL : Predicate<"!(Subtarget->getSmVersion() >= 70" ++ "&& Subtarget->getPTXVersion() >= 64)">; ++ +++def useShortPtrLocal : Predicate<"TM.is64Bit() && TM.getPointerSizeInBits(ADDRESS_SPACE_LOCAL) == 32">; +++def useShortPtrShared : Predicate<"TM.is64Bit() && TM.getPointerSizeInBits(ADDRESS_SPACE_SHARED) == 32">; +++def useShortPtrConst : Predicate<"TM.is64Bit() && TM.getPointerSizeInBits(ADDRESS_SPACE_CONST) == 32">; +++ ++ def useFP16Math: Predicate<"Subtarget->allowFP16Math()">; ++ def hasBF16Math: Predicate<"Subtarget->hasBF16Math()">; ++ ++@@ -1661,6 +1665,167 @@ ++ "brev.b64 \t$dst, $a;", ++ [(set Int64Regs:$dst, (bitreverse Int64Regs:$a))]>; ++ +++// +++// Rotate: Use ptx shf instruction if available. +++// +++ +++// 32 bit r2 = rotl r1, n +++// => +++// r2 = shf.l r1, r1, n +++def ROTL32imm_hw : +++ NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$src, i32imm:$amt), +++ "shf.l.wrap.b32 \t$dst, $src, $src, $amt;", +++ [(set Int32Regs:$dst, (rotl (i32 Int32Regs:$src), (i32 imm:$amt)))]>, +++ Requires<[hasHWROT32]>; +++ +++def ROTL32reg_hw : +++ NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$src, Int32Regs:$amt), +++ "shf.l.wrap.b32 \t$dst, $src, $src, $amt;", +++ [(set Int32Regs:$dst, (rotl (i32 Int32Regs:$src), (i32 Int32Regs:$amt)))]>, +++ Requires<[hasHWROT32]>; +++ +++// 32 bit r2 = rotr r1, n +++// => +++// r2 = shf.r r1, r1, n +++def ROTR32imm_hw : +++ NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$src, i32imm:$amt), +++ "shf.r.wrap.b32 \t$dst, $src, $src, $amt;", +++ [(set Int32Regs:$dst, (rotr (i32 Int32Regs:$src), (i32 imm:$amt)))]>, +++ Requires<[hasHWROT32]>; +++ +++def ROTR32reg_hw : +++ NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$src, Int32Regs:$amt), +++ "shf.r.wrap.b32 \t$dst, $src, $src, $amt;", +++ [(set Int32Regs:$dst, (rotr (i32 Int32Regs:$src), (i32 Int32Regs:$amt)))]>, +++ Requires<[hasHWROT32]>; +++ +++// 32-bit software rotate by immediate. $amt2 should equal 32 - $amt1. +++def ROT32imm_sw : +++ NVPTXInst<(outs Int32Regs:$dst), +++ (ins Int32Regs:$src, i32imm:$amt1, i32imm:$amt2), +++ "{{\n\t" +++ ".reg .b32 %lhs;\n\t" +++ ".reg .b32 %rhs;\n\t" +++ "shl.b32 \t%lhs, $src, $amt1;\n\t" +++ "shr.b32 \t%rhs, $src, $amt2;\n\t" +++ "add.u32 \t$dst, %lhs, %rhs;\n\t" +++ "}}", +++ []>; +++ +++def SUB_FRM_32 : SDNodeXFormgetTargetConstant(32 - N->getZExtValue(), SDLoc(N), MVT::i32); +++}]>; +++ +++def : Pat<(rotl (i32 Int32Regs:$src), (i32 imm:$amt)), +++ (ROT32imm_sw Int32Regs:$src, imm:$amt, (SUB_FRM_32 node:$amt))>, +++ Requires<[noHWROT32]>; +++def : Pat<(rotr (i32 Int32Regs:$src), (i32 imm:$amt)), +++ (ROT32imm_sw Int32Regs:$src, (SUB_FRM_32 node:$amt), imm:$amt)>, +++ Requires<[noHWROT32]>; +++ +++// 32-bit software rotate left by register. +++def ROTL32reg_sw : +++ NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$src, Int32Regs:$amt), +++ "{{\n\t" +++ ".reg .b32 %lhs;\n\t" +++ ".reg .b32 %rhs;\n\t" +++ ".reg .b32 %amt2;\n\t" +++ "shl.b32 \t%lhs, $src, $amt;\n\t" +++ "sub.s32 \t%amt2, 32, $amt;\n\t" +++ "shr.b32 \t%rhs, $src, %amt2;\n\t" +++ "add.u32 \t$dst, %lhs, %rhs;\n\t" +++ "}}", +++ [(set Int32Regs:$dst, (rotl (i32 Int32Regs:$src), (i32 Int32Regs:$amt)))]>, +++ Requires<[noHWROT32]>; +++ +++// 32-bit software rotate right by register. +++def ROTR32reg_sw : +++ NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$src, Int32Regs:$amt), +++ "{{\n\t" +++ ".reg .b32 %lhs;\n\t" +++ ".reg .b32 %rhs;\n\t" +++ ".reg .b32 %amt2;\n\t" +++ "shr.b32 \t%lhs, $src, $amt;\n\t" +++ "sub.s32 \t%amt2, 32, $amt;\n\t" +++ "shl.b32 \t%rhs, $src, %amt2;\n\t" +++ "add.u32 \t$dst, %lhs, %rhs;\n\t" +++ "}}", +++ [(set Int32Regs:$dst, (rotr (i32 Int32Regs:$src), (i32 Int32Regs:$amt)))]>, +++ Requires<[noHWROT32]>; +++ +++// 64-bit software rotate by immediate. $amt2 should equal 64 - $amt1. +++def ROT64imm_sw : +++ NVPTXInst<(outs Int64Regs:$dst), +++ (ins Int64Regs:$src, i32imm:$amt1, i32imm:$amt2), +++ "{{\n\t" +++ ".reg .b64 %lhs;\n\t" +++ ".reg .b64 %rhs;\n\t" +++ "shl.b64 \t%lhs, $src, $amt1;\n\t" +++ "shr.b64 \t%rhs, $src, $amt2;\n\t" +++ "add.u64 \t$dst, %lhs, %rhs;\n\t" +++ "}}", +++ []>; +++ +++def SUB_FRM_64 : SDNodeXFormgetTargetConstant(64-N->getZExtValue(), SDLoc(N), MVT::i32); +++}]>; +++ +++def : Pat<(rotl Int64Regs:$src, (i32 imm:$amt)), +++ (ROT64imm_sw Int64Regs:$src, imm:$amt, (SUB_FRM_64 node:$amt))>; +++def : Pat<(rotr Int64Regs:$src, (i32 imm:$amt)), +++ (ROT64imm_sw Int64Regs:$src, (SUB_FRM_64 node:$amt), imm:$amt)>; +++ +++// 64-bit software rotate left by register. +++def ROTL64reg_sw : +++ NVPTXInst<(outs Int64Regs:$dst), (ins Int64Regs:$src, Int32Regs:$amt), +++ "{{\n\t" +++ ".reg .b64 %lhs;\n\t" +++ ".reg .b64 %rhs;\n\t" +++ ".reg .u32 %amt2;\n\t" +++ "and.b32 \t%amt2, $amt, 63;\n\t" +++ "shl.b64 \t%lhs, $src, %amt2;\n\t" +++ "sub.u32 \t%amt2, 64, %amt2;\n\t" +++ "shr.b64 \t%rhs, $src, %amt2;\n\t" +++ "add.u64 \t$dst, %lhs, %rhs;\n\t" +++ "}}", +++ [(set Int64Regs:$dst, (rotl Int64Regs:$src, (i32 Int32Regs:$amt)))]>; +++ +++def ROTR64reg_sw : +++ NVPTXInst<(outs Int64Regs:$dst), (ins Int64Regs:$src, Int32Regs:$amt), +++ "{{\n\t" +++ ".reg .b64 %lhs;\n\t" +++ ".reg .b64 %rhs;\n\t" +++ ".reg .u32 %amt2;\n\t" +++ "and.b32 \t%amt2, $amt, 63;\n\t" +++ "shr.b64 \t%lhs, $src, %amt2;\n\t" +++ "sub.u32 \t%amt2, 64, %amt2;\n\t" +++ "shl.b64 \t%rhs, $src, %amt2;\n\t" +++ "add.u64 \t$dst, %lhs, %rhs;\n\t" +++ "}}", +++ [(set Int64Regs:$dst, (rotr Int64Regs:$src, (i32 Int32Regs:$amt)))]>; +++ +++// +++// Funnnel shift in clamp mode +++// +++ +++// Create SDNodes so they can be used in the DAG code, e.g. +++// NVPTXISelLowering (LowerShiftLeftParts and LowerShiftRightParts) +++def FUN_SHFL_CLAMP : SDNode<"NVPTXISD::FUN_SHFL_CLAMP", SDTIntShiftDOp, []>; +++def FUN_SHFR_CLAMP : SDNode<"NVPTXISD::FUN_SHFR_CLAMP", SDTIntShiftDOp, []>; +++ +++def FUNSHFLCLAMP : +++ NVPTXInst<(outs Int32Regs:$dst), +++ (ins Int32Regs:$lo, Int32Regs:$hi, Int32Regs:$amt), +++ "shf.l.clamp.b32 \t$dst, $lo, $hi, $amt;", +++ [(set Int32Regs:$dst, +++ (FUN_SHFL_CLAMP (i32 Int32Regs:$lo), (i32 Int32Regs:$hi), (i32 Int32Regs:$amt)))]>; +++ +++def FUNSHFRCLAMP : +++ NVPTXInst<(outs Int32Regs:$dst), +++ (ins Int32Regs:$lo, Int32Regs:$hi, Int32Regs:$amt), +++ "shf.r.clamp.b32 \t$dst, $lo, $hi, $amt;", +++ [(set Int32Regs:$dst, +++ (FUN_SHFR_CLAMP (i32 Int32Regs:$lo), (i32 Int32Regs:$hi), (i32 Int32Regs:$amt)))]>; ++ ++ // ++ // BFE - bit-field extract ++@@ -3492,42 +3657,6 @@ ++ def: Pat<(v2i16 (scalar_to_vector (i16 Int16Regs:$a))), ++ (CVT_u32_u16 Int16Regs:$a, CvtNONE)>; ++ ++-// ++-// Funnel-Shift ++-// ++- ++-// Create SDNodes so they can be used in the DAG code, e.g. ++-// NVPTXISelLowering (LowerShiftLeftParts and LowerShiftRightParts) ++-def fshl_clamp : SDNode<"NVPTXISD::FUN_SHFL_CLAMP", SDTIntShiftDOp, []>; ++-def fshr_clamp : SDNode<"NVPTXISD::FUN_SHFR_CLAMP", SDTIntShiftDOp, []>; ++- ++-// Funnel shift, requires >= sm_32. Does not trap if amt is out of range, so ++-// no side effects. ++-let hasSideEffects = false in { ++- multiclass ShfInst { ++- def _i ++- : NVPTXInst<(outs Int32Regs:$dst), ++- (ins Int32Regs:$lo, Int32Regs:$hi, i32imm:$amt), ++- "shf." # mode # ".b32 \t$dst, $lo, $hi, $amt;", ++- [(set Int32Regs:$dst, ++- (op (i32 Int32Regs:$lo), (i32 Int32Regs:$hi), (i32 imm:$amt)))]>, ++- Requires<[hasHWROT32]>; ++- ++- def _r ++- : NVPTXInst<(outs Int32Regs:$dst), ++- (ins Int32Regs:$lo, Int32Regs:$hi, Int32Regs:$amt), ++- "shf." # mode # ".b32 \t$dst, $lo, $hi, $amt;", ++- [(set Int32Regs:$dst, ++- (op (i32 Int32Regs:$lo), (i32 Int32Regs:$hi), (i32 Int32Regs:$amt)))]>, ++- Requires<[hasHWROT32]>; ++- } ++- ++- defm SHF_L_CLAMP : ShfInst<"l.clamp", fshl_clamp>; ++- defm SHF_R_CLAMP : ShfInst<"r.clamp", fshr_clamp>; ++- defm SHF_L_WRAP : ShfInst<"l.wrap", fshl>; ++- defm SHF_R_WRAP : ShfInst<"r.wrap", fshr>; ++-} ++- ++ // Count leading zeros ++ let hasSideEffects = false in { ++ def CLZr32 : NVPTXInst<(outs Int32Regs:$d), (ins Int32Regs:$a), ++diff -ruN --strip-trailing-cr a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td ++--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td +++++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td ++@@ -2537,45 +2537,59 @@ ++ : VLDG_G_ELE_V4<"v4.f32 \t{{$dst1, $dst2, $dst3, $dst4}}, [$src];", Float32Regs>; ++ ++ ++-multiclass NG_TO_G { +++multiclass NG_TO_G { ++ def "" : NVPTXInst<(outs Int32Regs:$result), (ins Int32Regs:$src), ++- "cvta." # Str # ".u32 \t$result, $src;", []>; +++ !strconcat("cvta.", Str, ".u32 \t$result, $src;"), +++ [(set Int32Regs:$result, (Intrin Int32Regs:$src))]>; ++ def _64 : NVPTXInst<(outs Int64Regs:$result), (ins Int64Regs:$src), ++- "cvta." # Str # ".u64 \t$result, $src;", []>; +++ !strconcat("cvta.", Str, ".u64 \t$result, $src;"), +++ [(set Int64Regs:$result, (Intrin Int64Regs:$src))]>; +++ def _6432 : NVPTXInst<(outs Int64Regs:$result), (ins Int32Regs:$src), +++ "{{ .reg .b64 %tmp;\n\t" +++ #" cvt.u64.u32 \t%tmp, $src;\n\t" +++ #" cvta." # Str # ".u64 \t$result, %tmp; }}", +++ [(set Int64Regs:$result, (Intrin Int32Regs:$src))]>, +++ Requires<[ShortPtr]>; ++ } ++ ++-multiclass G_TO_NG { +++multiclass G_TO_NG { ++ def "" : NVPTXInst<(outs Int32Regs:$result), (ins Int32Regs:$src), ++- "cvta.to." # Str # ".u32 \t$result, $src;", []>; +++ !strconcat("cvta.to.", Str, ".u32 \t$result, $src;"), +++ [(set Int32Regs:$result, (Intrin Int32Regs:$src))]>; ++ def _64 : NVPTXInst<(outs Int64Regs:$result), (ins Int64Regs:$src), ++- "cvta.to." # Str # ".u64 \t$result, $src;", []>; +++ !strconcat("cvta.to.", Str, ".u64 \t$result, $src;"), +++ [(set Int64Regs:$result, (Intrin Int64Regs:$src))]>; +++ def _3264 : NVPTXInst<(outs Int32Regs:$result), (ins Int64Regs:$src), +++ "{{ .reg .b64 %tmp;\n\t" +++ #" cvta.to." # Str # ".u64 \t%tmp, $src;\n\t" +++ #" cvt.u32.u64 \t$result, %tmp; }}", +++ [(set Int32Regs:$result, (Intrin Int64Regs:$src))]>, +++ Requires<[ShortPtr]>; ++ } ++ ++-defm cvta_local : NG_TO_G<"local">; ++-defm cvta_shared : NG_TO_G<"shared">; ++-defm cvta_global : NG_TO_G<"global">; ++-defm cvta_const : NG_TO_G<"const">; ++- ++-defm cvta_to_local : G_TO_NG<"local">; ++-defm cvta_to_shared : G_TO_NG<"shared">; ++-defm cvta_to_global : G_TO_NG<"global">; ++-defm cvta_to_const : G_TO_NG<"const">; ++- ++-// nvvm.ptr.param.to.gen ++-defm cvta_param : NG_TO_G<"param">; ++- ++-def : Pat<(int_nvvm_ptr_param_to_gen Int32Regs:$src), ++- (cvta_param Int32Regs:$src)>; ++- ++-def : Pat<(int_nvvm_ptr_param_to_gen Int64Regs:$src), ++- (cvta_param_64 Int64Regs:$src)>; +++defm cvta_local : NG_TO_G<"local", int_nvvm_ptr_local_to_gen, useShortPtrLocal>; +++defm cvta_shared : NG_TO_G<"shared", int_nvvm_ptr_shared_to_gen, useShortPtrShared>; +++defm cvta_global : NG_TO_G<"global", int_nvvm_ptr_global_to_gen, False>; +++defm cvta_const : NG_TO_G<"const", int_nvvm_ptr_constant_to_gen, useShortPtrConst>; +++defm cvta_param : NG_TO_G<"param", int_nvvm_ptr_param_to_gen, False>; +++ +++defm cvta_to_local : G_TO_NG<"local", int_nvvm_ptr_gen_to_local, useShortPtrLocal>; +++defm cvta_to_shared : G_TO_NG<"shared", int_nvvm_ptr_gen_to_shared, useShortPtrShared>; +++defm cvta_to_global : G_TO_NG<"global", int_nvvm_ptr_gen_to_global, False>; +++defm cvta_to_const : G_TO_NG<"const", int_nvvm_ptr_gen_to_constant, useShortPtrConst>; ++ ++ // nvvm.ptr.gen.to.param ++-def : Pat<(int_nvvm_ptr_gen_to_param Int32Regs:$src), ++- (IMOV32rr Int32Regs:$src)>; +++def nvvm_ptr_gen_to_param : NVPTXInst<(outs Int32Regs:$result), +++ (ins Int32Regs:$src), +++ "mov.u32 \t$result, $src;", +++ [(set Int32Regs:$result, +++ (int_nvvm_ptr_gen_to_param Int32Regs:$src))]>; +++def nvvm_ptr_gen_to_param_64 : NVPTXInst<(outs Int64Regs:$result), +++ (ins Int64Regs:$src), +++ "mov.u64 \t$result, $src;", +++ [(set Int64Regs:$result, +++ (int_nvvm_ptr_gen_to_param Int64Regs:$src))]>; ++ ++-def : Pat<(int_nvvm_ptr_gen_to_param Int64Regs:$src), ++- (IMOV64rr Int64Regs:$src)>; ++ ++ // nvvm.move intrinsicc ++ def nvvm_move_i16 : NVPTXInst<(outs Int16Regs:$r), (ins Int16Regs:$s), ++@@ -2618,6 +2632,24 @@ ++ [(set Int64Regs:$r, ++ (int_nvvm_move_ptr texternalsym:$s))]>;*/ ++ +++ +++// MoveParam %r1, param +++// ptr_local_to_gen %r2, %r1 +++// ptr_gen_to_local %r3, %r2 +++// -> +++// mov %r1, param +++ +++// @TODO: Revisit this. There is a type +++// contradiction between iPTRAny and iPTR for the addr defs, so the move_sym +++// instructions are not currently defined. However, we can use the ptr +++// variants and the asm printer will do the right thing. +++def : Pat<(i64 (int_nvvm_ptr_gen_to_local (int_nvvm_ptr_local_to_gen +++ (MoveParam texternalsym:$src)))), +++ (nvvm_move_ptr64 texternalsym:$src)>; +++def : Pat<(i32 (int_nvvm_ptr_gen_to_local (int_nvvm_ptr_local_to_gen +++ (MoveParam texternalsym:$src)))), +++ (nvvm_move_ptr32 texternalsym:$src)>; +++ ++ def texsurf_handles ++ : NVPTXInst<(outs Int64Regs:$result), (ins imem:$src), ++ "mov.u64 \t$result, $src;", []>; ++@@ -2701,9 +2733,134 @@ ++ def : Pat<(int_nvvm_read_ptx_sreg_envreg31), (MOV_SPECIAL ENVREG31)>; ++ ++ +++// rotate builtin support +++ +++def ROTATE_B32_HW_IMM +++ : NVPTXInst<(outs Int32Regs:$dst), +++ (ins Int32Regs:$src, i32imm:$amt), +++ "shf.l.wrap.b32 \t$dst, $src, $src, $amt;", +++ [(set Int32Regs:$dst, +++ (int_nvvm_rotate_b32 Int32Regs:$src, (i32 imm:$amt)))]>, +++ Requires<[hasHWROT32]> ; +++ +++def ROTATE_B32_HW_REG +++ : NVPTXInst<(outs Int32Regs:$dst), +++ (ins Int32Regs:$src, Int32Regs:$amt), +++ "shf.l.wrap.b32 \t$dst, $src, $src, $amt;", +++ [(set Int32Regs:$dst, +++ (int_nvvm_rotate_b32 Int32Regs:$src, Int32Regs:$amt))]>, +++ Requires<[hasHWROT32]> ; +++ +++def : Pat<(int_nvvm_rotate_b32 Int32Regs:$src, (i32 imm:$amt)), +++ (ROT32imm_sw Int32Regs:$src, imm:$amt, (SUB_FRM_32 node:$amt))>, +++ Requires<[noHWROT32]> ; +++ +++def : Pat<(int_nvvm_rotate_b32 Int32Regs:$src, Int32Regs:$amt), +++ (ROTL32reg_sw Int32Regs:$src, Int32Regs:$amt)>, +++ Requires<[noHWROT32]> ; +++ +++let hasSideEffects = false in { +++ def GET_LO_INT64 : NVPTXInst<(outs Int32Regs:$dst), (ins Int64Regs:$src), +++ !strconcat("{{\n\t", +++ ".reg .b32 %dummy;\n\t", +++ "mov.b64 \t{$dst,%dummy}, $src;\n\t", +++ "}}"), +++ []> ; +++ +++ def GET_HI_INT64 : NVPTXInst<(outs Int32Regs:$dst), (ins Int64Regs:$src), +++ !strconcat("{{\n\t", +++ ".reg .b32 %dummy;\n\t", +++ "mov.b64 \t{%dummy,$dst}, $src;\n\t", +++ "}}"), +++ []> ; +++} +++ +++let hasSideEffects = false in { +++ def PACK_TWO_INT32 +++ : NVPTXInst<(outs Int64Regs:$dst), (ins Int32Regs:$lo, Int32Regs:$hi), +++ "mov.b64 \t$dst, {{$lo, $hi}};", []> ; +++} +++ ++ def : Pat<(int_nvvm_swap_lo_hi_b64 Int64Regs:$src), ++- (V2I32toI64 (I64toI32H Int64Regs:$src), ++- (I64toI32L Int64Regs:$src))> ; +++ (PACK_TWO_INT32 (GET_HI_INT64 Int64Regs:$src), +++ (GET_LO_INT64 Int64Regs:$src))> ; +++ +++// Funnel shift, requires >= sm_32. Does not trap if amt is out of range, so +++// no side effects. +++let hasSideEffects = false in { +++ def SHF_L_WRAP_B32_IMM +++ : NVPTXInst<(outs Int32Regs:$dst), +++ (ins Int32Regs:$lo, Int32Regs:$hi, i32imm:$amt), +++ "shf.l.wrap.b32 \t$dst, $lo, $hi, $amt;",[]>, +++ Requires<[hasHWROT32]>; +++ +++ def SHF_L_WRAP_B32_REG +++ : NVPTXInst<(outs Int32Regs:$dst), +++ (ins Int32Regs:$lo, Int32Regs:$hi, Int32Regs:$amt), +++ "shf.l.wrap.b32 \t$dst, $lo, $hi, $amt;",[]>, +++ Requires<[hasHWROT32]>; +++ +++ def SHF_R_WRAP_B32_IMM +++ : NVPTXInst<(outs Int32Regs:$dst), +++ (ins Int32Regs:$lo, Int32Regs:$hi, i32imm:$amt), +++ "shf.r.wrap.b32 \t$dst, $lo, $hi, $amt;",[]>, +++ Requires<[hasHWROT32]>; +++ +++ def SHF_R_WRAP_B32_REG +++ : NVPTXInst<(outs Int32Regs:$dst), +++ (ins Int32Regs:$lo, Int32Regs:$hi, Int32Regs:$amt), +++ "shf.r.wrap.b32 \t$dst, $lo, $hi, $amt;",[]>, +++ Requires<[hasHWROT32]>; +++} +++ +++// HW version of rotate 64 +++def : Pat<(int_nvvm_rotate_b64 Int64Regs:$src, (i32 imm:$amt)), +++ (PACK_TWO_INT32 +++ (SHF_L_WRAP_B32_IMM (GET_HI_INT64 Int64Regs:$src), +++ (GET_LO_INT64 Int64Regs:$src), imm:$amt), +++ (SHF_L_WRAP_B32_IMM (GET_LO_INT64 Int64Regs:$src), +++ (GET_HI_INT64 Int64Regs:$src), imm:$amt))>, +++ Requires<[hasHWROT32]>; +++ +++def : Pat<(int_nvvm_rotate_b64 Int64Regs:$src, Int32Regs:$amt), +++ (PACK_TWO_INT32 +++ (SHF_L_WRAP_B32_REG (GET_HI_INT64 Int64Regs:$src), +++ (GET_LO_INT64 Int64Regs:$src), Int32Regs:$amt), +++ (SHF_L_WRAP_B32_REG (GET_LO_INT64 Int64Regs:$src), +++ (GET_HI_INT64 Int64Regs:$src), Int32Regs:$amt))>, +++ Requires<[hasHWROT32]>; +++ +++ +++def : Pat<(int_nvvm_rotate_right_b64 Int64Regs:$src, (i32 imm:$amt)), +++ (PACK_TWO_INT32 +++ (SHF_R_WRAP_B32_IMM (GET_LO_INT64 Int64Regs:$src), +++ (GET_HI_INT64 Int64Regs:$src), imm:$amt), +++ (SHF_R_WRAP_B32_IMM (GET_HI_INT64 Int64Regs:$src), +++ (GET_LO_INT64 Int64Regs:$src), imm:$amt))>, +++ Requires<[hasHWROT32]>; +++ +++def : Pat<(int_nvvm_rotate_right_b64 Int64Regs:$src, Int32Regs:$amt), +++ (PACK_TWO_INT32 +++ (SHF_R_WRAP_B32_REG (GET_LO_INT64 Int64Regs:$src), +++ (GET_HI_INT64 Int64Regs:$src), Int32Regs:$amt), +++ (SHF_R_WRAP_B32_REG (GET_HI_INT64 Int64Regs:$src), +++ (GET_LO_INT64 Int64Regs:$src), Int32Regs:$amt))>, +++ Requires<[hasHWROT32]>; +++ +++// SW version of rotate 64 +++def : Pat<(int_nvvm_rotate_b64 Int64Regs:$src, (i32 imm:$amt)), +++ (ROT64imm_sw Int64Regs:$src, imm:$amt, (SUB_FRM_64 node:$amt))>, +++ Requires<[noHWROT32]>; +++def : Pat<(int_nvvm_rotate_b64 Int64Regs:$src, Int32Regs:$amt), +++ (ROTL64reg_sw Int64Regs:$src, Int32Regs:$amt)>, +++ Requires<[noHWROT32]>; +++def : Pat<(int_nvvm_rotate_right_b64 Int64Regs:$src, (i32 imm:$amt)), +++ (ROT64imm_sw Int64Regs:$src, (SUB_FRM_64 node:$amt), imm:$amt)>, +++ Requires<[noHWROT32]>; +++def : Pat<(int_nvvm_rotate_right_b64 Int64Regs:$src, Int32Regs:$amt), +++ (ROTR64reg_sw Int64Regs:$src, Int32Regs:$amt)>, +++ Requires<[noHWROT32]>; +++ ++ ++ //----------------------------------- ++ // Texture Intrinsics ++diff -ruN --strip-trailing-cr a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp ++--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp +++++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp ++@@ -1109,21 +1109,11 @@ ++ AddrSpaceCastSDNode *CastN = cast(N); ++ unsigned SrcAddrSpace = CastN->getSrcAddressSpace(); ++ unsigned DstAddrSpace = CastN->getDestAddressSpace(); ++- SDLoc DL(N); ++ assert(SrcAddrSpace != DstAddrSpace && ++ "addrspacecast must be between different address spaces"); ++ ++ if (DstAddrSpace == ADDRESS_SPACE_GENERIC) { ++ // Specific to generic ++- ++- if (TM.is64Bit() && TM.getPointerSizeInBits(SrcAddrSpace) == 32) { ++- SDValue CvtNone = ++- CurDAG->getTargetConstant(NVPTX::PTXCvtMode::NONE, DL, MVT::i32); ++- SDNode *Cvt = CurDAG->getMachineNode(NVPTX::CVT_u64_u32, DL, MVT::i64, ++- Src, CvtNone); ++- Src = SDValue(Cvt, 0); ++- } ++- ++ unsigned Opc; ++ switch (SrcAddrSpace) { ++ default: report_fatal_error("Bad address space in addrspacecast"); ++@@ -1131,16 +1121,26 @@ ++ Opc = TM.is64Bit() ? NVPTX::cvta_global_64 : NVPTX::cvta_global; ++ break; ++ case ADDRESS_SPACE_SHARED: ++- Opc = TM.is64Bit() ? NVPTX::cvta_shared_64 : NVPTX::cvta_shared; +++ Opc = TM.is64Bit() ? (TM.getPointerSizeInBits(SrcAddrSpace) == 32 +++ ? NVPTX::cvta_shared_6432 +++ : NVPTX::cvta_shared_64) +++ : NVPTX::cvta_shared; ++ break; ++ case ADDRESS_SPACE_CONST: ++- Opc = TM.is64Bit() ? NVPTX::cvta_const_64 : NVPTX::cvta_const; +++ Opc = TM.is64Bit() ? (TM.getPointerSizeInBits(SrcAddrSpace) == 32 +++ ? NVPTX::cvta_const_6432 +++ : NVPTX::cvta_const_64) +++ : NVPTX::cvta_const; ++ break; ++ case ADDRESS_SPACE_LOCAL: ++- Opc = TM.is64Bit() ? NVPTX::cvta_local_64 : NVPTX::cvta_local; +++ Opc = TM.is64Bit() ? (TM.getPointerSizeInBits(SrcAddrSpace) == 32 +++ ? NVPTX::cvta_local_6432 +++ : NVPTX::cvta_local_64) +++ : NVPTX::cvta_local; ++ break; ++ } ++- ReplaceNode(N, CurDAG->getMachineNode(Opc, DL, N->getValueType(0), Src)); +++ ReplaceNode(N, CurDAG->getMachineNode(Opc, SDLoc(N), N->getValueType(0), +++ Src)); ++ return; ++ } else { ++ // Generic to specific ++@@ -1153,28 +1153,30 @@ ++ Opc = TM.is64Bit() ? NVPTX::cvta_to_global_64 : NVPTX::cvta_to_global; ++ break; ++ case ADDRESS_SPACE_SHARED: ++- Opc = TM.is64Bit() ? NVPTX::cvta_to_shared_64 : NVPTX::cvta_to_shared; +++ Opc = TM.is64Bit() ? (TM.getPointerSizeInBits(DstAddrSpace) == 32 +++ ? NVPTX::cvta_to_shared_3264 +++ : NVPTX::cvta_to_shared_64) +++ : NVPTX::cvta_to_shared; ++ break; ++ case ADDRESS_SPACE_CONST: ++- Opc = TM.is64Bit() ? NVPTX::cvta_to_const_64 : NVPTX::cvta_to_const; +++ Opc = TM.is64Bit() ? (TM.getPointerSizeInBits(DstAddrSpace) == 32 +++ ? NVPTX::cvta_to_const_3264 +++ : NVPTX::cvta_to_const_64) +++ : NVPTX::cvta_to_const; ++ break; ++ case ADDRESS_SPACE_LOCAL: ++- Opc = TM.is64Bit() ? NVPTX::cvta_to_local_64 : NVPTX::cvta_to_local; +++ Opc = TM.is64Bit() ? (TM.getPointerSizeInBits(DstAddrSpace) == 32 +++ ? NVPTX::cvta_to_local_3264 +++ : NVPTX::cvta_to_local_64) +++ : NVPTX::cvta_to_local; ++ break; ++ case ADDRESS_SPACE_PARAM: ++- Opc = TM.is64Bit() ? NVPTX::IMOV64rr : NVPTX::IMOV32rr; +++ Opc = TM.is64Bit() ? NVPTX::nvvm_ptr_gen_to_param_64 +++ : NVPTX::nvvm_ptr_gen_to_param; ++ break; ++ } ++- ++- SDNode *CVTA = CurDAG->getMachineNode(Opc, DL, N->getValueType(0), Src); ++- if (TM.is64Bit() && TM.getPointerSizeInBits(DstAddrSpace) == 32) { ++- SDValue CvtNone = ++- CurDAG->getTargetConstant(NVPTX::PTXCvtMode::NONE, DL, MVT::i32); ++- CVTA = CurDAG->getMachineNode(NVPTX::CVT_u32_u64, DL, MVT::i32, ++- SDValue(CVTA, 0), CvtNone); ++- } ++- ++- ReplaceNode(N, CVTA); +++ ReplaceNode(N, CurDAG->getMachineNode(Opc, SDLoc(N), N->getValueType(0), +++ Src)); ++ return; ++ } ++ } ++diff -ruN --strip-trailing-cr a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp ++--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp +++++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp ++@@ -594,13 +594,20 @@ ++ setOperationAction(ISD::BITREVERSE, MVT::i32, Legal); ++ setOperationAction(ISD::BITREVERSE, MVT::i64, Legal); ++ ++- setOperationAction({ISD::ROTL, ISD::ROTR}, ++- {MVT::i8, MVT::i16, MVT::v2i16, MVT::i32, MVT::i64}, ++- Expand); ++- ++- if (STI.hasHWROT32()) ++- setOperationAction({ISD::FSHL, ISD::FSHR}, MVT::i32, Legal); +++ // TODO: we may consider expanding ROTL/ROTR on older GPUs. Currently on GPUs +++ // that don't have h/w rotation we lower them to multi-instruction assembly. +++ // See ROT*_sw in NVPTXIntrInfo.td +++ setOperationAction(ISD::ROTL, MVT::i64, Legal); +++ setOperationAction(ISD::ROTR, MVT::i64, Legal); +++ setOperationAction(ISD::ROTL, MVT::i32, Legal); +++ setOperationAction(ISD::ROTR, MVT::i32, Legal); ++ +++ setOperationAction(ISD::ROTL, MVT::i16, Expand); +++ setOperationAction(ISD::ROTL, MVT::v2i16, Expand); +++ setOperationAction(ISD::ROTR, MVT::i16, Expand); +++ setOperationAction(ISD::ROTR, MVT::v2i16, Expand); +++ setOperationAction(ISD::ROTL, MVT::i8, Expand); +++ setOperationAction(ISD::ROTR, MVT::i8, Expand); ++ setOperationAction(ISD::BSWAP, MVT::i16, Expand); ++ ++ setOperationAction(ISD::BR_JT, MVT::Other, Custom); ++diff -ruN --strip-trailing-cr a/llvm/test/Assembler/auto_upgrade_nvvm_intrinsics.ll b/llvm/test/Assembler/auto_upgrade_nvvm_intrinsics.ll ++--- a/llvm/test/Assembler/auto_upgrade_nvvm_intrinsics.ll +++++ b/llvm/test/Assembler/auto_upgrade_nvvm_intrinsics.ll ++@@ -31,19 +31,6 @@ ++ declare i64 @llvm.nvvm.bitcast.d2ll(double) ++ declare double @llvm.nvvm.bitcast.ll2d(i64) ++ ++-declare i32 @llvm.nvvm.rotate.b32(i32, i32) ++-declare i64 @llvm.nvvm.rotate.right.b64(i64, i32) ++-declare i64 @llvm.nvvm.rotate.b64(i64, i32) ++- ++-declare ptr addrspace(1) @llvm.nvvm.ptr.gen.to.global.p1.p0(ptr) ++-declare ptr addrspace(3) @llvm.nvvm.ptr.gen.to.shared.p3.p0(ptr) ++-declare ptr addrspace(4) @llvm.nvvm.ptr.gen.to.constant.p4.p0(ptr) ++-declare ptr addrspace(5) @llvm.nvvm.ptr.gen.to.local.p5.p0(ptr) ++-declare ptr @llvm.nvvm.ptr.global.to.gen.p0.p1(ptr addrspace(1)) ++-declare ptr @llvm.nvvm.ptr.shared.to.gen.p0.p3(ptr addrspace(3)) ++-declare ptr @llvm.nvvm.ptr.constant.to.gen.p0.p4(ptr addrspace(4)) ++-declare ptr @llvm.nvvm.ptr.local.to.gen.p0.p5(ptr addrspace(5)) ++- ++ ; CHECK-LABEL: @simple_upgrade ++ define void @simple_upgrade(i32 %a, i64 %b, i16 %c) { ++ ; CHECK: call i32 @llvm.bitreverse.i32(i32 %a) ++@@ -152,42 +139,4 @@ ++ %r4 = call double @llvm.nvvm.bitcast.ll2d(i64 %b) ++ ++ ret void ++-} ++- ++-; CHECK-LABEL: @rotate ++-define void @rotate(i32 %a, i64 %b) { ++-; CHECK: call i32 @llvm.fshl.i32(i32 %a, i32 %a, i32 6) ++-; CHECK: call i64 @llvm.fshr.i64(i64 %b, i64 %b, i64 7) ++-; CHECK: call i64 @llvm.fshl.i64(i64 %b, i64 %b, i64 8) ++-; ++- %r1 = call i32 @llvm.nvvm.rotate.b32(i32 %a, i32 6) ++- %r2 = call i64 @llvm.nvvm.rotate.right.b64(i64 %b, i32 7) ++- %r3 = call i64 @llvm.nvvm.rotate.b64(i64 %b, i32 8) ++- ret void ++-} ++- ++-; CHECK-LABEL: @addrspacecast ++-define void @addrspacecast(ptr %p0) { ++-; CHECK: %1 = addrspacecast ptr %p0 to ptr addrspace(1) ++-; CHECK: %2 = addrspacecast ptr addrspace(1) %1 to ptr ++-; CHECK: %3 = addrspacecast ptr %2 to ptr addrspace(3) ++-; CHECK: %4 = addrspacecast ptr addrspace(3) %3 to ptr ++-; CHECK: %5 = addrspacecast ptr %4 to ptr addrspace(4) ++-; CHECK: %6 = addrspacecast ptr addrspace(4) %5 to ptr ++-; CHECK: %7 = addrspacecast ptr %6 to ptr addrspace(5) ++-; CHECK: %8 = addrspacecast ptr addrspace(5) %7 to ptr ++-; ++- %p1 = call ptr addrspace(1) @llvm.nvvm.ptr.gen.to.global.p1.p0(ptr %p0) ++- %p2 = call ptr @llvm.nvvm.ptr.global.to.gen.p0.p1(ptr addrspace(1) %p1) ++- ++- %p3 = call ptr addrspace(3) @llvm.nvvm.ptr.gen.to.shared.p3.p0(ptr %p2) ++- %p4 = call ptr @llvm.nvvm.ptr.shared.to.gen.p0.p3(ptr addrspace(3) %p3) ++- ++- %p5 = call ptr addrspace(4) @llvm.nvvm.ptr.gen.to.constant.p4.p0(ptr %p4) ++- %p6 = call ptr @llvm.nvvm.ptr.constant.to.gen.p0.p4(ptr addrspace(4) %p5) ++- ++- %p7 = call ptr addrspace(5) @llvm.nvvm.ptr.gen.to.local.p5.p0(ptr %p6) ++- %p8 = call ptr @llvm.nvvm.ptr.local.to.gen.p0.p5(ptr addrspace(5) %p7) ++- ++- ret void ++-} +++} ++\ No newline at end of file ++diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/AMDGPU/freeze.ll b/llvm/test/CodeGen/AMDGPU/freeze.ll ++--- a/llvm/test/CodeGen/AMDGPU/freeze.ll +++++ b/llvm/test/CodeGen/AMDGPU/freeze.ll ++@@ -1,1856 +0,0 @@ ++-; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py ++-; RUN: llc -global-isel=0 -mtriple=amdgcn-mesa-mesa3d -mcpu=gfx1010 < %s | FileCheck -check-prefixes=GFX10,GFX10-SDAG %s ++-; RUN: llc -global-isel=1 -mtriple=amdgcn-mesa-mesa3d -mcpu=gfx1010 < %s | FileCheck -check-prefixes=GFX10,GFX10-GISEL %s ++-; RUN: llc -global-isel=0 -mtriple=amdgcn-mesa-mesa3d -mcpu=gfx1100 -amdgpu-enable-delay-alu=0 < %s | FileCheck -check-prefixes=GFX11,GFX11-SDAG %s ++-; RUN: llc -global-isel=1 -mtriple=amdgcn-mesa-mesa3d -mcpu=gfx1100 -amdgpu-enable-delay-alu=0 < %s | FileCheck -check-prefixes=GFX11,GFX11-GISEL %s ++- ++-define void @freeze_v2i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { ++-; GFX10-LABEL: freeze_v2i32: ++-; GFX10: ; %bb.0: ++-; GFX10-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX10-NEXT: global_load_dwordx2 v[0:1], v[0:1], off ++-; GFX10-NEXT: s_waitcnt vmcnt(0) ++-; GFX10-NEXT: global_store_dwordx2 v[2:3], v[0:1], off ++-; GFX10-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX11-LABEL: freeze_v2i32: ++-; GFX11: ; %bb.0: ++-; GFX11-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX11-NEXT: global_load_b64 v[0:1], v[0:1], off ++-; GFX11-NEXT: s_waitcnt vmcnt(0) ++-; GFX11-NEXT: global_store_b64 v[2:3], v[0:1], off ++-; GFX11-NEXT: s_setpc_b64 s[30:31] ++- %a = load <2 x i32>, ptr addrspace(1) %ptra, align 4 ++- %freeze = freeze <2 x i32> %a ++- store <2 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 ++- ret void ++-} ++- ++-define void @freeze_v3i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { ++-; GFX10-LABEL: freeze_v3i32: ++-; GFX10: ; %bb.0: ++-; GFX10-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX10-NEXT: global_load_dwordx3 v[4:6], v[0:1], off ++-; GFX10-NEXT: s_waitcnt vmcnt(0) ++-; GFX10-NEXT: global_store_dwordx3 v[2:3], v[4:6], off ++-; GFX10-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX11-LABEL: freeze_v3i32: ++-; GFX11: ; %bb.0: ++-; GFX11-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX11-NEXT: global_load_b96 v[4:6], v[0:1], off ++-; GFX11-NEXT: s_waitcnt vmcnt(0) ++-; GFX11-NEXT: global_store_b96 v[2:3], v[4:6], off ++-; GFX11-NEXT: s_setpc_b64 s[30:31] ++- %a = load <3 x i32>, ptr addrspace(1) %ptra, align 4 ++- %freeze = freeze <3 x i32> %a ++- store <3 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 ++- ret void ++-} ++- ++-define void @freeze_v4i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { ++-; GFX10-LABEL: freeze_v4i32: ++-; GFX10: ; %bb.0: ++-; GFX10-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX10-NEXT: global_load_dwordx4 v[4:7], v[0:1], off ++-; GFX10-NEXT: s_waitcnt vmcnt(0) ++-; GFX10-NEXT: global_store_dwordx4 v[2:3], v[4:7], off ++-; GFX10-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX11-LABEL: freeze_v4i32: ++-; GFX11: ; %bb.0: ++-; GFX11-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX11-NEXT: global_load_b128 v[4:7], v[0:1], off ++-; GFX11-NEXT: s_waitcnt vmcnt(0) ++-; GFX11-NEXT: global_store_b128 v[2:3], v[4:7], off ++-; GFX11-NEXT: s_setpc_b64 s[30:31] ++- %a = load <4 x i32>, ptr addrspace(1) %ptra, align 4 ++- %freeze = freeze <4 x i32> %a ++- store <4 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 ++- ret void ++-} ++- ++-define void @freeze_v5i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { ++-; GFX10-SDAG-LABEL: freeze_v5i32: ++-; GFX10-SDAG: ; %bb.0: ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX10-SDAG-NEXT: s_clause 0x1 ++-; GFX10-SDAG-NEXT: global_load_dword v8, v[0:1], off offset:16 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) ++-; GFX10-SDAG-NEXT: global_store_dword v[2:3], v8, off offset:16 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off ++-; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX10-GISEL-LABEL: freeze_v5i32: ++-; GFX10-GISEL: ; %bb.0: ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX10-GISEL-NEXT: s_clause 0x1 ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off ++-; GFX10-GISEL-NEXT: global_load_dword v8, v[0:1], off offset:16 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) ++-; GFX10-GISEL-NEXT: global_store_dword v[2:3], v8, off offset:16 ++-; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX11-SDAG-LABEL: freeze_v5i32: ++-; GFX11-SDAG: ; %bb.0: ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX11-SDAG-NEXT: s_clause 0x1 ++-; GFX11-SDAG-NEXT: global_load_b32 v8, v[0:1], off offset:16 ++-; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) ++-; GFX11-SDAG-NEXT: global_store_b32 v[2:3], v8, off offset:16 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off ++-; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX11-GISEL-LABEL: freeze_v5i32: ++-; GFX11-GISEL: ; %bb.0: ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX11-GISEL-NEXT: s_clause 0x1 ++-; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off ++-; GFX11-GISEL-NEXT: global_load_b32 v0, v[0:1], off offset:16 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) ++-; GFX11-GISEL-NEXT: global_store_b32 v[2:3], v0, off offset:16 ++-; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] ++- %a = load <5 x i32>, ptr addrspace(1) %ptra, align 4 ++- %freeze = freeze <5 x i32> %a ++- store <5 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 ++- ret void ++-} ++- ++-define void @freeze_v6i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { ++-; GFX10-SDAG-LABEL: freeze_v6i32: ++-; GFX10-SDAG: ; %bb.0: ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX10-SDAG-NEXT: s_clause 0x1 ++-; GFX10-SDAG-NEXT: global_load_dwordx2 v[8:9], v[0:1], off offset:16 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) ++-; GFX10-SDAG-NEXT: global_store_dwordx2 v[2:3], v[8:9], off offset:16 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off ++-; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX10-GISEL-LABEL: freeze_v6i32: ++-; GFX10-GISEL: ; %bb.0: ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX10-GISEL-NEXT: s_clause 0x1 ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off ++-; GFX10-GISEL-NEXT: global_load_dwordx2 v[8:9], v[0:1], off offset:16 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) ++-; GFX10-GISEL-NEXT: global_store_dwordx2 v[2:3], v[8:9], off offset:16 ++-; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX11-SDAG-LABEL: freeze_v6i32: ++-; GFX11-SDAG: ; %bb.0: ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX11-SDAG-NEXT: s_clause 0x1 ++-; GFX11-SDAG-NEXT: global_load_b64 v[8:9], v[0:1], off offset:16 ++-; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) ++-; GFX11-SDAG-NEXT: global_store_b64 v[2:3], v[8:9], off offset:16 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off ++-; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX11-GISEL-LABEL: freeze_v6i32: ++-; GFX11-GISEL: ; %bb.0: ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX11-GISEL-NEXT: s_clause 0x1 ++-; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off ++-; GFX11-GISEL-NEXT: global_load_b64 v[0:1], v[0:1], off offset:16 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) ++-; GFX11-GISEL-NEXT: global_store_b64 v[2:3], v[0:1], off offset:16 ++-; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] ++- %a = load <6 x i32>, ptr addrspace(1) %ptra, align 4 ++- %freeze = freeze <6 x i32> %a ++- store <6 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 ++- ret void ++-} ++- ++-define void @freeze_v7i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { ++-; GFX10-SDAG-LABEL: freeze_v7i32: ++-; GFX10-SDAG: ; %bb.0: ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX10-SDAG-NEXT: s_clause 0x1 ++-; GFX10-SDAG-NEXT: global_load_dwordx3 v[8:10], v[0:1], off offset:16 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) ++-; GFX10-SDAG-NEXT: global_store_dwordx3 v[2:3], v[8:10], off offset:16 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off ++-; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX10-GISEL-LABEL: freeze_v7i32: ++-; GFX10-GISEL: ; %bb.0: ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX10-GISEL-NEXT: s_clause 0x1 ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off ++-; GFX10-GISEL-NEXT: global_load_dwordx3 v[8:10], v[0:1], off offset:16 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) ++-; GFX10-GISEL-NEXT: global_store_dwordx3 v[2:3], v[8:10], off offset:16 ++-; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX11-SDAG-LABEL: freeze_v7i32: ++-; GFX11-SDAG: ; %bb.0: ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX11-SDAG-NEXT: s_clause 0x1 ++-; GFX11-SDAG-NEXT: global_load_b96 v[8:10], v[0:1], off offset:16 ++-; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) ++-; GFX11-SDAG-NEXT: global_store_b96 v[2:3], v[8:10], off offset:16 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off ++-; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX11-GISEL-LABEL: freeze_v7i32: ++-; GFX11-GISEL: ; %bb.0: ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX11-GISEL-NEXT: s_clause 0x1 ++-; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off ++-; GFX11-GISEL-NEXT: global_load_b96 v[8:10], v[0:1], off offset:16 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) ++-; GFX11-GISEL-NEXT: global_store_b96 v[2:3], v[8:10], off offset:16 ++-; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] ++- %a = load <7 x i32>, ptr addrspace(1) %ptra, align 4 ++- %freeze = freeze <7 x i32> %a ++- store <7 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 ++- ret void ++-} ++- ++-define void @freeze_v8i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { ++-; GFX10-SDAG-LABEL: freeze_v8i32: ++-; GFX10-SDAG: ; %bb.0: ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX10-SDAG-NEXT: s_clause 0x1 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:16 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:16 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off ++-; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX10-GISEL-LABEL: freeze_v8i32: ++-; GFX10-GISEL: ; %bb.0: ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX10-GISEL-NEXT: s_clause 0x1 ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 ++-; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX11-SDAG-LABEL: freeze_v8i32: ++-; GFX11-SDAG: ; %bb.0: ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX11-SDAG-NEXT: s_clause 0x1 ++-; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:16 ++-; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:16 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off ++-; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX11-GISEL-LABEL: freeze_v8i32: ++-; GFX11-GISEL: ; %bb.0: ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX11-GISEL-NEXT: s_clause 0x1 ++-; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off ++-; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 ++-; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] ++- %a = load <8 x i32>, ptr addrspace(1) %ptra, align 4 ++- %freeze = freeze <8 x i32> %a ++- store <8 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 ++- ret void ++-} ++- ++-define void @freeze_v9i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { ++-; GFX10-SDAG-LABEL: freeze_v9i32: ++-; GFX10-SDAG: ; %bb.0: ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX10-SDAG-NEXT: s_clause 0x2 ++-; GFX10-SDAG-NEXT: global_load_dword v12, v[0:1], off offset:32 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) ++-; GFX10-SDAG-NEXT: global_store_dword v[2:3], v12, off offset:32 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 ++-; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX10-GISEL-LABEL: freeze_v9i32: ++-; GFX10-GISEL: ; %bb.0: ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX10-GISEL-NEXT: s_clause 0x2 ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 ++-; GFX10-GISEL-NEXT: global_load_dword v12, v[0:1], off offset:32 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) ++-; GFX10-GISEL-NEXT: global_store_dword v[2:3], v12, off offset:32 ++-; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX11-SDAG-LABEL: freeze_v9i32: ++-; GFX11-SDAG: ; %bb.0: ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX11-SDAG-NEXT: s_clause 0x2 ++-; GFX11-SDAG-NEXT: global_load_b32 v12, v[0:1], off offset:32 ++-; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off ++-; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) ++-; GFX11-SDAG-NEXT: global_store_b32 v[2:3], v12, off offset:32 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 ++-; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX11-GISEL-LABEL: freeze_v9i32: ++-; GFX11-GISEL: ; %bb.0: ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX11-GISEL-NEXT: s_clause 0x2 ++-; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off ++-; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 ++-; GFX11-GISEL-NEXT: global_load_b32 v0, v[0:1], off offset:32 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) ++-; GFX11-GISEL-NEXT: global_store_b32 v[2:3], v0, off offset:32 ++-; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] ++- %a = load <9 x i32>, ptr addrspace(1) %ptra, align 4 ++- %freeze = freeze <9 x i32> %a ++- store <9 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 ++- ret void ++-} ++- ++-define void @freeze_v10i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { ++-; GFX10-LABEL: freeze_v10i32: ++-; GFX10: ; %bb.0: ++-; GFX10-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX10-NEXT: s_clause 0x2 ++-; GFX10-NEXT: global_load_dwordx4 v[4:7], v[0:1], off ++-; GFX10-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 ++-; GFX10-NEXT: global_load_dwordx2 v[12:13], v[0:1], off offset:32 ++-; GFX10-NEXT: s_waitcnt vmcnt(2) ++-; GFX10-NEXT: global_store_dwordx4 v[2:3], v[4:7], off ++-; GFX10-NEXT: s_waitcnt vmcnt(1) ++-; GFX10-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 ++-; GFX10-NEXT: s_waitcnt vmcnt(0) ++-; GFX10-NEXT: global_store_dwordx2 v[2:3], v[12:13], off offset:32 ++-; GFX10-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX11-LABEL: freeze_v10i32: ++-; GFX11: ; %bb.0: ++-; GFX11-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX11-NEXT: s_clause 0x2 ++-; GFX11-NEXT: global_load_b128 v[4:7], v[0:1], off ++-; GFX11-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 ++-; GFX11-NEXT: global_load_b64 v[0:1], v[0:1], off offset:32 ++-; GFX11-NEXT: s_waitcnt vmcnt(2) ++-; GFX11-NEXT: global_store_b128 v[2:3], v[4:7], off ++-; GFX11-NEXT: s_waitcnt vmcnt(1) ++-; GFX11-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 ++-; GFX11-NEXT: s_waitcnt vmcnt(0) ++-; GFX11-NEXT: global_store_b64 v[2:3], v[0:1], off offset:32 ++-; GFX11-NEXT: s_setpc_b64 s[30:31] ++- %a = load <10 x i32>, ptr addrspace(1) %ptra, align 4 ++- %freeze = freeze <10 x i32> %a ++- store <10 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 ++- ret void ++-} ++- ++-define void @freeze_v11i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { ++-; GFX10-LABEL: freeze_v11i32: ++-; GFX10: ; %bb.0: ++-; GFX10-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX10-NEXT: s_clause 0x2 ++-; GFX10-NEXT: global_load_dwordx4 v[4:7], v[0:1], off ++-; GFX10-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 ++-; GFX10-NEXT: global_load_dwordx3 v[12:14], v[0:1], off offset:32 ++-; GFX10-NEXT: s_waitcnt vmcnt(2) ++-; GFX10-NEXT: global_store_dwordx4 v[2:3], v[4:7], off ++-; GFX10-NEXT: s_waitcnt vmcnt(1) ++-; GFX10-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 ++-; GFX10-NEXT: s_waitcnt vmcnt(0) ++-; GFX10-NEXT: global_store_dwordx3 v[2:3], v[12:14], off offset:32 ++-; GFX10-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX11-LABEL: freeze_v11i32: ++-; GFX11: ; %bb.0: ++-; GFX11-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX11-NEXT: s_clause 0x2 ++-; GFX11-NEXT: global_load_b128 v[4:7], v[0:1], off ++-; GFX11-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 ++-; GFX11-NEXT: global_load_b96 v[12:14], v[0:1], off offset:32 ++-; GFX11-NEXT: s_waitcnt vmcnt(2) ++-; GFX11-NEXT: global_store_b128 v[2:3], v[4:7], off ++-; GFX11-NEXT: s_waitcnt vmcnt(1) ++-; GFX11-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 ++-; GFX11-NEXT: s_waitcnt vmcnt(0) ++-; GFX11-NEXT: global_store_b96 v[2:3], v[12:14], off offset:32 ++-; GFX11-NEXT: s_setpc_b64 s[30:31] ++- %a = load <11 x i32>, ptr addrspace(1) %ptra, align 4 ++- %freeze = freeze <11 x i32> %a ++- store <11 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 ++- ret void ++-} ++- ++-define void @freeze_v12i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { ++-; GFX10-LABEL: freeze_v12i32: ++-; GFX10: ; %bb.0: ++-; GFX10-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX10-NEXT: s_clause 0x2 ++-; GFX10-NEXT: global_load_dwordx4 v[4:7], v[0:1], off ++-; GFX10-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 ++-; GFX10-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 ++-; GFX10-NEXT: s_waitcnt vmcnt(2) ++-; GFX10-NEXT: global_store_dwordx4 v[2:3], v[4:7], off ++-; GFX10-NEXT: s_waitcnt vmcnt(1) ++-; GFX10-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 ++-; GFX10-NEXT: s_waitcnt vmcnt(0) ++-; GFX10-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 ++-; GFX10-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX11-LABEL: freeze_v12i32: ++-; GFX11: ; %bb.0: ++-; GFX11-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX11-NEXT: s_clause 0x2 ++-; GFX11-NEXT: global_load_b128 v[4:7], v[0:1], off ++-; GFX11-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 ++-; GFX11-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 ++-; GFX11-NEXT: s_waitcnt vmcnt(2) ++-; GFX11-NEXT: global_store_b128 v[2:3], v[4:7], off ++-; GFX11-NEXT: s_waitcnt vmcnt(1) ++-; GFX11-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 ++-; GFX11-NEXT: s_waitcnt vmcnt(0) ++-; GFX11-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 ++-; GFX11-NEXT: s_setpc_b64 s[30:31] ++- %a = load <12 x i32>, ptr addrspace(1) %ptra, align 4 ++- %freeze = freeze <12 x i32> %a ++- store <12 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 ++- ret void ++-} ++-define void @freeze_v13i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { ++-; GFX10-SDAG-LABEL: freeze_v13i32: ++-; GFX10-SDAG: ; %bb.0: ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX10-SDAG-NEXT: s_clause 0x3 ++-; GFX10-SDAG-NEXT: global_load_dword v16, v[0:1], off offset:48 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:32 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:16 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) ++-; GFX10-SDAG-NEXT: global_store_dword v[2:3], v16, off offset:48 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:32 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:16 ++-; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX10-GISEL-LABEL: freeze_v13i32: ++-; GFX10-GISEL: ; %bb.0: ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX10-GISEL-NEXT: s_clause 0x3 ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 ++-; GFX10-GISEL-NEXT: global_load_dword v16, v[0:1], off offset:48 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) ++-; GFX10-GISEL-NEXT: global_store_dword v[2:3], v16, off offset:48 ++-; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX11-SDAG-LABEL: freeze_v13i32: ++-; GFX11-SDAG: ; %bb.0: ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX11-SDAG-NEXT: s_clause 0x3 ++-; GFX11-SDAG-NEXT: global_load_b32 v16, v[0:1], off offset:48 ++-; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:32 ++-; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off ++-; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off offset:16 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) ++-; GFX11-SDAG-NEXT: global_store_b32 v[2:3], v16, off offset:48 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:32 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off offset:16 ++-; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX11-GISEL-LABEL: freeze_v13i32: ++-; GFX11-GISEL: ; %bb.0: ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX11-GISEL-NEXT: s_clause 0x3 ++-; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off ++-; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 ++-; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 ++-; GFX11-GISEL-NEXT: global_load_b32 v0, v[0:1], off offset:48 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) ++-; GFX11-GISEL-NEXT: global_store_b32 v[2:3], v0, off offset:48 ++-; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] ++- %a = load <13 x i32>, ptr addrspace(1) %ptra, align 4 ++- %freeze = freeze <13 x i32> %a ++- store <13 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 ++- ret void ++-} ++- ++-define void @freeze_v14i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { ++-; GFX10-SDAG-LABEL: freeze_v14i32: ++-; GFX10-SDAG: ; %bb.0: ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX10-SDAG-NEXT: s_clause 0x3 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:32 ++-; GFX10-SDAG-NEXT: global_load_dwordx2 v[16:17], v[0:1], off offset:48 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:16 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:32 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) ++-; GFX10-SDAG-NEXT: global_store_dwordx2 v[2:3], v[16:17], off offset:48 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:16 ++-; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX10-GISEL-LABEL: freeze_v14i32: ++-; GFX10-GISEL: ; %bb.0: ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX10-GISEL-NEXT: s_clause 0x3 ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 ++-; GFX10-GISEL-NEXT: global_load_dwordx2 v[16:17], v[0:1], off offset:48 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) ++-; GFX10-GISEL-NEXT: global_store_dwordx2 v[2:3], v[16:17], off offset:48 ++-; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX11-SDAG-LABEL: freeze_v14i32: ++-; GFX11-SDAG: ; %bb.0: ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX11-SDAG-NEXT: s_clause 0x3 ++-; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:32 ++-; GFX11-SDAG-NEXT: global_load_b64 v[16:17], v[0:1], off offset:48 ++-; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off ++-; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off offset:16 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:32 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) ++-; GFX11-SDAG-NEXT: global_store_b64 v[2:3], v[16:17], off offset:48 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off offset:16 ++-; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX11-GISEL-LABEL: freeze_v14i32: ++-; GFX11-GISEL: ; %bb.0: ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX11-GISEL-NEXT: s_clause 0x3 ++-; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off ++-; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 ++-; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 ++-; GFX11-GISEL-NEXT: global_load_b64 v[0:1], v[0:1], off offset:48 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) ++-; GFX11-GISEL-NEXT: global_store_b64 v[2:3], v[0:1], off offset:48 ++-; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] ++- %a = load <14 x i32>, ptr addrspace(1) %ptra, align 4 ++- %freeze = freeze <14 x i32> %a ++- store <14 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 ++- ret void ++-} ++- ++-define void @freeze_v15i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { ++-; GFX10-SDAG-LABEL: freeze_v15i32: ++-; GFX10-SDAG: ; %bb.0: ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX10-SDAG-NEXT: s_clause 0x3 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:32 ++-; GFX10-SDAG-NEXT: global_load_dwordx3 v[16:18], v[0:1], off offset:48 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:16 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:32 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) ++-; GFX10-SDAG-NEXT: global_store_dwordx3 v[2:3], v[16:18], off offset:48 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:16 ++-; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX10-GISEL-LABEL: freeze_v15i32: ++-; GFX10-GISEL: ; %bb.0: ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX10-GISEL-NEXT: s_clause 0x3 ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 ++-; GFX10-GISEL-NEXT: global_load_dwordx3 v[16:18], v[0:1], off offset:48 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) ++-; GFX10-GISEL-NEXT: global_store_dwordx3 v[2:3], v[16:18], off offset:48 ++-; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX11-SDAG-LABEL: freeze_v15i32: ++-; GFX11-SDAG: ; %bb.0: ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX11-SDAG-NEXT: s_clause 0x3 ++-; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:32 ++-; GFX11-SDAG-NEXT: global_load_b96 v[16:18], v[0:1], off offset:48 ++-; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off ++-; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off offset:16 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:32 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) ++-; GFX11-SDAG-NEXT: global_store_b96 v[2:3], v[16:18], off offset:48 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off offset:16 ++-; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX11-GISEL-LABEL: freeze_v15i32: ++-; GFX11-GISEL: ; %bb.0: ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX11-GISEL-NEXT: s_clause 0x3 ++-; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off ++-; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 ++-; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 ++-; GFX11-GISEL-NEXT: global_load_b96 v[16:18], v[0:1], off offset:48 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) ++-; GFX11-GISEL-NEXT: global_store_b96 v[2:3], v[16:18], off offset:48 ++-; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] ++- %a = load <15 x i32>, ptr addrspace(1) %ptra, align 4 ++- %freeze = freeze <15 x i32> %a ++- store <15 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 ++- ret void ++-} ++- ++-define void @freeze_v16i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { ++-; GFX10-SDAG-LABEL: freeze_v16i32: ++-; GFX10-SDAG: ; %bb.0: ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX10-SDAG-NEXT: s_clause 0x3 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:32 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:48 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:16 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:32 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:48 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:16 ++-; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX10-GISEL-LABEL: freeze_v16i32: ++-; GFX10-GISEL: ; %bb.0: ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX10-GISEL-NEXT: s_clause 0x3 ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:48 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:48 ++-; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX11-SDAG-LABEL: freeze_v16i32: ++-; GFX11-SDAG: ; %bb.0: ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX11-SDAG-NEXT: s_clause 0x3 ++-; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:32 ++-; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off offset:48 ++-; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off ++-; GFX11-SDAG-NEXT: global_load_b128 v[16:19], v[0:1], off offset:16 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:32 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off offset:48 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[16:19], off offset:16 ++-; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX11-GISEL-LABEL: freeze_v16i32: ++-; GFX11-GISEL: ; %bb.0: ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX11-GISEL-NEXT: s_clause 0x3 ++-; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off ++-; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 ++-; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 ++-; GFX11-GISEL-NEXT: global_load_b128 v[16:19], v[0:1], off offset:48 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[16:19], off offset:48 ++-; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] ++- %a = load <16 x i32>, ptr addrspace(1) %ptra, align 4 ++- %freeze = freeze <16 x i32> %a ++- store <16 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 ++- ret void ++-} ++- ++-define void @freeze_v17i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { ++-; GFX10-SDAG-LABEL: freeze_v17i32: ++-; GFX10-SDAG: ; %bb.0: ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX10-SDAG-NEXT: s_clause 0x4 ++-; GFX10-SDAG-NEXT: global_load_dword v20, v[0:1], off offset:64 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:32 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:48 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:16 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(4) ++-; GFX10-SDAG-NEXT: global_store_dword v[2:3], v20, off offset:64 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:32 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:48 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:16 ++-; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX10-GISEL-LABEL: freeze_v17i32: ++-; GFX10-GISEL: ; %bb.0: ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX10-GISEL-NEXT: s_clause 0x4 ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:48 ++-; GFX10-GISEL-NEXT: global_load_dword v20, v[0:1], off offset:64 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(4) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:48 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) ++-; GFX10-GISEL-NEXT: global_store_dword v[2:3], v20, off offset:64 ++-; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX11-SDAG-LABEL: freeze_v17i32: ++-; GFX11-SDAG: ; %bb.0: ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX11-SDAG-NEXT: s_clause 0x4 ++-; GFX11-SDAG-NEXT: global_load_b32 v20, v[0:1], off offset:64 ++-; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:32 ++-; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off offset:48 ++-; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off ++-; GFX11-SDAG-NEXT: global_load_b128 v[16:19], v[0:1], off offset:16 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(4) ++-; GFX11-SDAG-NEXT: global_store_b32 v[2:3], v20, off offset:64 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:32 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off offset:48 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[16:19], off offset:16 ++-; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX11-GISEL-LABEL: freeze_v17i32: ++-; GFX11-GISEL: ; %bb.0: ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX11-GISEL-NEXT: s_clause 0x4 ++-; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off ++-; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 ++-; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 ++-; GFX11-GISEL-NEXT: global_load_b128 v[16:19], v[0:1], off offset:48 ++-; GFX11-GISEL-NEXT: global_load_b32 v0, v[0:1], off offset:64 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(4) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[16:19], off offset:48 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) ++-; GFX11-GISEL-NEXT: global_store_b32 v[2:3], v0, off offset:64 ++-; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] ++- %a = load <17 x i32>, ptr addrspace(1) %ptra, align 4 ++- %freeze = freeze <17 x i32> %a ++- store <17 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 ++- ret void ++-} ++- ++-define void @freeze_v18i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { ++-; GFX10-SDAG-LABEL: freeze_v18i32: ++-; GFX10-SDAG: ; %bb.0: ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX10-SDAG-NEXT: s_clause 0x4 ++-; GFX10-SDAG-NEXT: global_load_dwordx2 v[20:21], v[0:1], off offset:64 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:32 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:48 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:16 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(4) ++-; GFX10-SDAG-NEXT: global_store_dwordx2 v[2:3], v[20:21], off offset:64 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:32 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:48 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:16 ++-; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX10-GISEL-LABEL: freeze_v18i32: ++-; GFX10-GISEL: ; %bb.0: ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX10-GISEL-NEXT: s_clause 0x4 ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:48 ++-; GFX10-GISEL-NEXT: global_load_dwordx2 v[20:21], v[0:1], off offset:64 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(4) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:48 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) ++-; GFX10-GISEL-NEXT: global_store_dwordx2 v[2:3], v[20:21], off offset:64 ++-; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX11-SDAG-LABEL: freeze_v18i32: ++-; GFX11-SDAG: ; %bb.0: ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX11-SDAG-NEXT: s_clause 0x4 ++-; GFX11-SDAG-NEXT: global_load_b64 v[20:21], v[0:1], off offset:64 ++-; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:32 ++-; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off offset:48 ++-; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off ++-; GFX11-SDAG-NEXT: global_load_b128 v[16:19], v[0:1], off offset:16 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(4) ++-; GFX11-SDAG-NEXT: global_store_b64 v[2:3], v[20:21], off offset:64 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:32 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off offset:48 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[16:19], off offset:16 ++-; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX11-GISEL-LABEL: freeze_v18i32: ++-; GFX11-GISEL: ; %bb.0: ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX11-GISEL-NEXT: s_clause 0x4 ++-; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off ++-; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 ++-; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 ++-; GFX11-GISEL-NEXT: global_load_b128 v[16:19], v[0:1], off offset:48 ++-; GFX11-GISEL-NEXT: global_load_b64 v[0:1], v[0:1], off offset:64 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(4) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[16:19], off offset:48 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) ++-; GFX11-GISEL-NEXT: global_store_b64 v[2:3], v[0:1], off offset:64 ++-; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] ++- %a = load <18 x i32>, ptr addrspace(1) %ptra, align 4 ++- %freeze = freeze <18 x i32> %a ++- store <18 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 ++- ret void ++-} ++- ++-define void @freeze_v19i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { ++-; GFX10-SDAG-LABEL: freeze_v19i32: ++-; GFX10-SDAG: ; %bb.0: ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX10-SDAG-NEXT: s_clause 0x4 ++-; GFX10-SDAG-NEXT: global_load_dwordx3 v[20:22], v[0:1], off offset:64 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:32 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:48 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:16 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(4) ++-; GFX10-SDAG-NEXT: global_store_dwordx3 v[2:3], v[20:22], off offset:64 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:32 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:48 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:16 ++-; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX10-GISEL-LABEL: freeze_v19i32: ++-; GFX10-GISEL: ; %bb.0: ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX10-GISEL-NEXT: s_clause 0x4 ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:48 ++-; GFX10-GISEL-NEXT: global_load_dwordx3 v[20:22], v[0:1], off offset:64 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(4) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:48 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) ++-; GFX10-GISEL-NEXT: global_store_dwordx3 v[2:3], v[20:22], off offset:64 ++-; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX11-SDAG-LABEL: freeze_v19i32: ++-; GFX11-SDAG: ; %bb.0: ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX11-SDAG-NEXT: s_clause 0x4 ++-; GFX11-SDAG-NEXT: global_load_b96 v[20:22], v[0:1], off offset:64 ++-; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:32 ++-; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off offset:48 ++-; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off ++-; GFX11-SDAG-NEXT: global_load_b128 v[16:19], v[0:1], off offset:16 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(4) ++-; GFX11-SDAG-NEXT: global_store_b96 v[2:3], v[20:22], off offset:64 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:32 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off offset:48 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[16:19], off offset:16 ++-; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX11-GISEL-LABEL: freeze_v19i32: ++-; GFX11-GISEL: ; %bb.0: ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX11-GISEL-NEXT: s_clause 0x4 ++-; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off ++-; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 ++-; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 ++-; GFX11-GISEL-NEXT: global_load_b128 v[16:19], v[0:1], off offset:48 ++-; GFX11-GISEL-NEXT: global_load_b96 v[20:22], v[0:1], off offset:64 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(4) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[16:19], off offset:48 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) ++-; GFX11-GISEL-NEXT: global_store_b96 v[2:3], v[20:22], off offset:64 ++-; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] ++- %a = load <19 x i32>, ptr addrspace(1) %ptra, align 4 ++- %freeze = freeze <19 x i32> %a ++- store <19 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 ++- ret void ++-} ++- ++-define void @freeze_v20i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { ++-; GFX10-SDAG-LABEL: freeze_v20i32: ++-; GFX10-SDAG: ; %bb.0: ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX10-SDAG-NEXT: s_clause 0x4 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:64 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:32 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:48 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[16:19], v[0:1], off ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[20:23], v[0:1], off offset:16 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(4) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:64 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:32 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:48 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[16:19], off ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[20:23], off offset:16 ++-; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX10-GISEL-LABEL: freeze_v20i32: ++-; GFX10-GISEL: ; %bb.0: ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX10-GISEL-NEXT: s_clause 0x4 ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:48 ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[20:23], v[0:1], off offset:64 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(4) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:48 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[20:23], off offset:64 ++-; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX11-SDAG-LABEL: freeze_v20i32: ++-; GFX11-SDAG: ; %bb.0: ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX11-SDAG-NEXT: s_clause 0x4 ++-; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:64 ++-; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off offset:32 ++-; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off offset:48 ++-; GFX11-SDAG-NEXT: global_load_b128 v[16:19], v[0:1], off ++-; GFX11-SDAG-NEXT: global_load_b128 v[20:23], v[0:1], off offset:16 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(4) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:64 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off offset:32 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off offset:48 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[16:19], off ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[20:23], off offset:16 ++-; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX11-GISEL-LABEL: freeze_v20i32: ++-; GFX11-GISEL: ; %bb.0: ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX11-GISEL-NEXT: s_clause 0x4 ++-; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off ++-; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 ++-; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 ++-; GFX11-GISEL-NEXT: global_load_b128 v[16:19], v[0:1], off offset:48 ++-; GFX11-GISEL-NEXT: global_load_b128 v[20:23], v[0:1], off offset:64 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(4) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[16:19], off offset:48 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[20:23], off offset:64 ++-; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] ++- %a = load <20 x i32>, ptr addrspace(1) %ptra, align 4 ++- %freeze = freeze <20 x i32> %a ++- store <20 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 ++- ret void ++-} ++- ++-define void @freeze_v21i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { ++-; GFX10-SDAG-LABEL: freeze_v21i32: ++-; GFX10-SDAG: ; %bb.0: ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX10-SDAG-NEXT: s_clause 0x5 ++-; GFX10-SDAG-NEXT: global_load_dword v24, v[0:1], off offset:80 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:64 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:32 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:48 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[16:19], v[0:1], off ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[20:23], v[0:1], off offset:16 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(5) ++-; GFX10-SDAG-NEXT: global_store_dword v[2:3], v24, off offset:80 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(4) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:64 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:32 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:48 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[16:19], off ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[20:23], off offset:16 ++-; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX10-GISEL-LABEL: freeze_v21i32: ++-; GFX10-GISEL: ; %bb.0: ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX10-GISEL-NEXT: s_clause 0x5 ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:48 ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[20:23], v[0:1], off offset:64 ++-; GFX10-GISEL-NEXT: global_load_dword v24, v[0:1], off offset:80 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(5) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(4) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:48 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[20:23], off offset:64 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) ++-; GFX10-GISEL-NEXT: global_store_dword v[2:3], v24, off offset:80 ++-; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX11-SDAG-LABEL: freeze_v21i32: ++-; GFX11-SDAG: ; %bb.0: ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX11-SDAG-NEXT: s_clause 0x5 ++-; GFX11-SDAG-NEXT: global_load_b32 v24, v[0:1], off offset:80 ++-; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:64 ++-; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off offset:32 ++-; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off offset:48 ++-; GFX11-SDAG-NEXT: global_load_b128 v[16:19], v[0:1], off ++-; GFX11-SDAG-NEXT: global_load_b128 v[20:23], v[0:1], off offset:16 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(5) ++-; GFX11-SDAG-NEXT: global_store_b32 v[2:3], v24, off offset:80 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(4) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:64 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off offset:32 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off offset:48 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[16:19], off ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[20:23], off offset:16 ++-; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX11-GISEL-LABEL: freeze_v21i32: ++-; GFX11-GISEL: ; %bb.0: ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX11-GISEL-NEXT: s_clause 0x5 ++-; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off ++-; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 ++-; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 ++-; GFX11-GISEL-NEXT: global_load_b128 v[16:19], v[0:1], off offset:48 ++-; GFX11-GISEL-NEXT: global_load_b128 v[20:23], v[0:1], off offset:64 ++-; GFX11-GISEL-NEXT: global_load_b32 v0, v[0:1], off offset:80 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(5) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(4) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[16:19], off offset:48 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[20:23], off offset:64 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) ++-; GFX11-GISEL-NEXT: global_store_b32 v[2:3], v0, off offset:80 ++-; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] ++- %a = load <21 x i32>, ptr addrspace(1) %ptra, align 4 ++- %freeze = freeze <21 x i32> %a ++- store <21 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 ++- ret void ++-} ++- ++-define void @freeze_v22i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { ++-; GFX10-SDAG-LABEL: freeze_v22i32: ++-; GFX10-SDAG: ; %bb.0: ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX10-SDAG-NEXT: s_clause 0x5 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:64 ++-; GFX10-SDAG-NEXT: global_load_dwordx2 v[24:25], v[0:1], off offset:80 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:32 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:48 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[16:19], v[0:1], off ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[20:23], v[0:1], off offset:16 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(5) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:64 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(4) ++-; GFX10-SDAG-NEXT: global_store_dwordx2 v[2:3], v[24:25], off offset:80 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:32 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:48 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[16:19], off ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[20:23], off offset:16 ++-; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX10-GISEL-LABEL: freeze_v22i32: ++-; GFX10-GISEL: ; %bb.0: ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX10-GISEL-NEXT: s_clause 0x5 ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:48 ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[20:23], v[0:1], off offset:64 ++-; GFX10-GISEL-NEXT: global_load_dwordx2 v[24:25], v[0:1], off offset:80 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(5) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(4) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:48 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[20:23], off offset:64 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) ++-; GFX10-GISEL-NEXT: global_store_dwordx2 v[2:3], v[24:25], off offset:80 ++-; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX11-SDAG-LABEL: freeze_v22i32: ++-; GFX11-SDAG: ; %bb.0: ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX11-SDAG-NEXT: s_clause 0x5 ++-; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:64 ++-; GFX11-SDAG-NEXT: global_load_b64 v[24:25], v[0:1], off offset:80 ++-; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off offset:32 ++-; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off offset:48 ++-; GFX11-SDAG-NEXT: global_load_b128 v[16:19], v[0:1], off ++-; GFX11-SDAG-NEXT: global_load_b128 v[20:23], v[0:1], off offset:16 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(5) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:64 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(4) ++-; GFX11-SDAG-NEXT: global_store_b64 v[2:3], v[24:25], off offset:80 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off offset:32 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off offset:48 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[16:19], off ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[20:23], off offset:16 ++-; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX11-GISEL-LABEL: freeze_v22i32: ++-; GFX11-GISEL: ; %bb.0: ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX11-GISEL-NEXT: s_clause 0x5 ++-; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off ++-; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 ++-; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 ++-; GFX11-GISEL-NEXT: global_load_b128 v[16:19], v[0:1], off offset:48 ++-; GFX11-GISEL-NEXT: global_load_b128 v[20:23], v[0:1], off offset:64 ++-; GFX11-GISEL-NEXT: global_load_b64 v[0:1], v[0:1], off offset:80 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(5) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(4) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[16:19], off offset:48 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[20:23], off offset:64 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) ++-; GFX11-GISEL-NEXT: global_store_b64 v[2:3], v[0:1], off offset:80 ++-; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] ++- %a = load <22 x i32>, ptr addrspace(1) %ptra, align 4 ++- %freeze = freeze <22 x i32> %a ++- store <22 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 ++- ret void ++-} ++- ++-define void @freeze_v30i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { ++-; GFX10-SDAG-LABEL: freeze_v30i32: ++-; GFX10-SDAG: ; %bb.0: ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX10-SDAG-NEXT: s_clause 0x7 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:96 ++-; GFX10-SDAG-NEXT: global_load_dwordx2 v[32:33], v[0:1], off offset:112 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:64 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:80 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:32 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[20:23], v[0:1], off offset:48 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[24:27], v[0:1], off ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[28:31], v[0:1], off offset:16 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(7) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:96 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(6) ++-; GFX10-SDAG-NEXT: global_store_dwordx2 v[2:3], v[32:33], off offset:112 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(5) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:64 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(4) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:80 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:32 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[20:23], off offset:48 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[24:27], off ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[28:31], off offset:16 ++-; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX10-GISEL-LABEL: freeze_v30i32: ++-; GFX10-GISEL: ; %bb.0: ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX10-GISEL-NEXT: s_clause 0x7 ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:48 ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[20:23], v[0:1], off offset:64 ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[24:27], v[0:1], off offset:80 ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[28:31], v[0:1], off offset:96 ++-; GFX10-GISEL-NEXT: global_load_dwordx2 v[32:33], v[0:1], off offset:112 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(7) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(6) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(5) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(4) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:48 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[20:23], off offset:64 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[24:27], off offset:80 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[28:31], off offset:96 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) ++-; GFX10-GISEL-NEXT: global_store_dwordx2 v[2:3], v[32:33], off offset:112 ++-; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX11-SDAG-LABEL: freeze_v30i32: ++-; GFX11-SDAG: ; %bb.0: ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX11-SDAG-NEXT: s_clause 0x7 ++-; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:96 ++-; GFX11-SDAG-NEXT: global_load_b64 v[32:33], v[0:1], off offset:112 ++-; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off offset:64 ++-; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off offset:80 ++-; GFX11-SDAG-NEXT: global_load_b128 v[16:19], v[0:1], off offset:32 ++-; GFX11-SDAG-NEXT: global_load_b128 v[20:23], v[0:1], off offset:48 ++-; GFX11-SDAG-NEXT: global_load_b128 v[24:27], v[0:1], off ++-; GFX11-SDAG-NEXT: global_load_b128 v[28:31], v[0:1], off offset:16 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(7) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:96 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(6) ++-; GFX11-SDAG-NEXT: global_store_b64 v[2:3], v[32:33], off offset:112 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(5) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off offset:64 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(4) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off offset:80 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[16:19], off offset:32 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[20:23], off offset:48 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[24:27], off ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[28:31], off offset:16 ++-; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX11-GISEL-LABEL: freeze_v30i32: ++-; GFX11-GISEL: ; %bb.0: ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX11-GISEL-NEXT: s_clause 0x7 ++-; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off ++-; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 ++-; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 ++-; GFX11-GISEL-NEXT: global_load_b128 v[16:19], v[0:1], off offset:48 ++-; GFX11-GISEL-NEXT: global_load_b128 v[20:23], v[0:1], off offset:64 ++-; GFX11-GISEL-NEXT: global_load_b128 v[24:27], v[0:1], off offset:80 ++-; GFX11-GISEL-NEXT: global_load_b128 v[28:31], v[0:1], off offset:96 ++-; GFX11-GISEL-NEXT: global_load_b64 v[0:1], v[0:1], off offset:112 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(7) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(6) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(5) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(4) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[16:19], off offset:48 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[20:23], off offset:64 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[24:27], off offset:80 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[28:31], off offset:96 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) ++-; GFX11-GISEL-NEXT: global_store_b64 v[2:3], v[0:1], off offset:112 ++-; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] ++- %a = load <30 x i32>, ptr addrspace(1) %ptra, align 4 ++- %freeze = freeze <30 x i32> %a ++- store <30 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 ++- ret void ++-} ++- ++-define void @freeze_v31i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { ++-; GFX10-SDAG-LABEL: freeze_v31i32: ++-; GFX10-SDAG: ; %bb.0: ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX10-SDAG-NEXT: s_clause 0x7 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:96 ++-; GFX10-SDAG-NEXT: global_load_dwordx3 v[32:34], v[0:1], off offset:112 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:64 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:80 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:32 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[20:23], v[0:1], off offset:48 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[24:27], v[0:1], off ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[28:31], v[0:1], off offset:16 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(7) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:96 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(6) ++-; GFX10-SDAG-NEXT: global_store_dwordx3 v[2:3], v[32:34], off offset:112 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(5) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:64 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(4) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:80 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:32 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[20:23], off offset:48 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[24:27], off ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[28:31], off offset:16 ++-; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX10-GISEL-LABEL: freeze_v31i32: ++-; GFX10-GISEL: ; %bb.0: ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX10-GISEL-NEXT: s_clause 0x7 ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:48 ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[20:23], v[0:1], off offset:64 ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[24:27], v[0:1], off offset:80 ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[28:31], v[0:1], off offset:96 ++-; GFX10-GISEL-NEXT: global_load_dwordx3 v[32:34], v[0:1], off offset:112 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(7) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(6) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(5) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(4) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:48 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[20:23], off offset:64 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[24:27], off offset:80 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[28:31], off offset:96 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) ++-; GFX10-GISEL-NEXT: global_store_dwordx3 v[2:3], v[32:34], off offset:112 ++-; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX11-SDAG-LABEL: freeze_v31i32: ++-; GFX11-SDAG: ; %bb.0: ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX11-SDAG-NEXT: s_clause 0x7 ++-; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:96 ++-; GFX11-SDAG-NEXT: global_load_b96 v[32:34], v[0:1], off offset:112 ++-; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off offset:64 ++-; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off offset:80 ++-; GFX11-SDAG-NEXT: global_load_b128 v[16:19], v[0:1], off offset:32 ++-; GFX11-SDAG-NEXT: global_load_b128 v[20:23], v[0:1], off offset:48 ++-; GFX11-SDAG-NEXT: global_load_b128 v[24:27], v[0:1], off ++-; GFX11-SDAG-NEXT: global_load_b128 v[28:31], v[0:1], off offset:16 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(7) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:96 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(6) ++-; GFX11-SDAG-NEXT: global_store_b96 v[2:3], v[32:34], off offset:112 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(5) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off offset:64 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(4) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off offset:80 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[16:19], off offset:32 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[20:23], off offset:48 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[24:27], off ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[28:31], off offset:16 ++-; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX11-GISEL-LABEL: freeze_v31i32: ++-; GFX11-GISEL: ; %bb.0: ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX11-GISEL-NEXT: s_clause 0x7 ++-; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off ++-; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 ++-; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 ++-; GFX11-GISEL-NEXT: global_load_b128 v[16:19], v[0:1], off offset:48 ++-; GFX11-GISEL-NEXT: global_load_b128 v[20:23], v[0:1], off offset:64 ++-; GFX11-GISEL-NEXT: global_load_b128 v[24:27], v[0:1], off offset:80 ++-; GFX11-GISEL-NEXT: global_load_b128 v[28:31], v[0:1], off offset:96 ++-; GFX11-GISEL-NEXT: global_load_b96 v[32:34], v[0:1], off offset:112 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(7) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(6) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(5) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(4) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[16:19], off offset:48 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[20:23], off offset:64 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[24:27], off offset:80 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[28:31], off offset:96 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) ++-; GFX11-GISEL-NEXT: global_store_b96 v[2:3], v[32:34], off offset:112 ++-; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] ++- %a = load <31 x i32>, ptr addrspace(1) %ptra, align 4 ++- %freeze = freeze <31 x i32> %a ++- store <31 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 ++- ret void ++-} ++- ++-define void @freeze_v32i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { ++-; GFX10-SDAG-LABEL: freeze_v32i32: ++-; GFX10-SDAG: ; %bb.0: ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX10-SDAG-NEXT: s_clause 0x7 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:96 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:112 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:64 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:80 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[20:23], v[0:1], off offset:32 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[24:27], v[0:1], off offset:48 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[28:31], v[0:1], off ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[32:35], v[0:1], off offset:16 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(7) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:96 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(6) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:112 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(5) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:64 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(4) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:80 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[20:23], off offset:32 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[24:27], off offset:48 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[28:31], off ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[32:35], off offset:16 ++-; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX10-GISEL-LABEL: freeze_v32i32: ++-; GFX10-GISEL: ; %bb.0: ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX10-GISEL-NEXT: s_clause 0x7 ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:48 ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[20:23], v[0:1], off offset:64 ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[24:27], v[0:1], off offset:80 ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[28:31], v[0:1], off offset:96 ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[32:35], v[0:1], off offset:112 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(7) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(6) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(5) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(4) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:48 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[20:23], off offset:64 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[24:27], off offset:80 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[28:31], off offset:96 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[32:35], off offset:112 ++-; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX11-SDAG-LABEL: freeze_v32i32: ++-; GFX11-SDAG: ; %bb.0: ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX11-SDAG-NEXT: s_clause 0x7 ++-; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:96 ++-; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off offset:112 ++-; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off offset:64 ++-; GFX11-SDAG-NEXT: global_load_b128 v[16:19], v[0:1], off offset:80 ++-; GFX11-SDAG-NEXT: global_load_b128 v[20:23], v[0:1], off offset:32 ++-; GFX11-SDAG-NEXT: global_load_b128 v[24:27], v[0:1], off offset:48 ++-; GFX11-SDAG-NEXT: global_load_b128 v[28:31], v[0:1], off ++-; GFX11-SDAG-NEXT: global_load_b128 v[32:35], v[0:1], off offset:16 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(7) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:96 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(6) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off offset:112 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(5) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off offset:64 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(4) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[16:19], off offset:80 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[20:23], off offset:32 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[24:27], off offset:48 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[28:31], off ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[32:35], off offset:16 ++-; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX11-GISEL-LABEL: freeze_v32i32: ++-; GFX11-GISEL: ; %bb.0: ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX11-GISEL-NEXT: s_clause 0x7 ++-; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off ++-; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 ++-; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 ++-; GFX11-GISEL-NEXT: global_load_b128 v[16:19], v[0:1], off offset:48 ++-; GFX11-GISEL-NEXT: global_load_b128 v[20:23], v[0:1], off offset:64 ++-; GFX11-GISEL-NEXT: global_load_b128 v[24:27], v[0:1], off offset:80 ++-; GFX11-GISEL-NEXT: global_load_b128 v[28:31], v[0:1], off offset:96 ++-; GFX11-GISEL-NEXT: global_load_b128 v[32:35], v[0:1], off offset:112 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(7) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(6) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(5) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(4) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[16:19], off offset:48 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[20:23], off offset:64 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[24:27], off offset:80 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[28:31], off offset:96 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[32:35], off offset:112 ++-; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] ++- %a = load <32 x i32>, ptr addrspace(1) %ptra, align 4 ++- %freeze = freeze <32 x i32> %a ++- store <32 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 ++- ret void ++-} ++- ++-define void @freeze_i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { ++-; GFX10-LABEL: freeze_i32: ++-; GFX10: ; %bb.0: ++-; GFX10-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX10-NEXT: global_load_dword v0, v[0:1], off ++-; GFX10-NEXT: s_waitcnt vmcnt(0) ++-; GFX10-NEXT: global_store_dword v[2:3], v0, off ++-; GFX10-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX11-LABEL: freeze_i32: ++-; GFX11: ; %bb.0: ++-; GFX11-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX11-NEXT: global_load_b32 v0, v[0:1], off ++-; GFX11-NEXT: s_waitcnt vmcnt(0) ++-; GFX11-NEXT: global_store_b32 v[2:3], v0, off ++-; GFX11-NEXT: s_setpc_b64 s[30:31] ++- %a = load i32, ptr addrspace(1) %ptra, align 4 ++- %freeze = freeze i32 %a ++- store i32 %freeze, ptr addrspace(1) %ptrb, align 4 ++- ret void ++-} ++- ++-define void @freeze_i64(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { ++-; GFX10-LABEL: freeze_i64: ++-; GFX10: ; %bb.0: ++-; GFX10-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX10-NEXT: global_load_dwordx2 v[0:1], v[0:1], off ++-; GFX10-NEXT: s_waitcnt vmcnt(0) ++-; GFX10-NEXT: global_store_dwordx2 v[2:3], v[0:1], off ++-; GFX10-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX11-LABEL: freeze_i64: ++-; GFX11: ; %bb.0: ++-; GFX11-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX11-NEXT: global_load_b64 v[0:1], v[0:1], off ++-; GFX11-NEXT: s_waitcnt vmcnt(0) ++-; GFX11-NEXT: global_store_b64 v[2:3], v[0:1], off ++-; GFX11-NEXT: s_setpc_b64 s[30:31] ++- %a = load i64, ptr addrspace(1) %ptra, align 4 ++- %freeze = freeze i64 %a ++- store i64 %freeze, ptr addrspace(1) %ptrb, align 4 ++- ret void ++-} ++- ++-define void @freeze_float(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { ++-; GFX10-LABEL: freeze_float: ++-; GFX10: ; %bb.0: ++-; GFX10-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX10-NEXT: global_load_dword v0, v[0:1], off ++-; GFX10-NEXT: s_waitcnt vmcnt(0) ++-; GFX10-NEXT: global_store_dword v[2:3], v0, off ++-; GFX10-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX11-LABEL: freeze_float: ++-; GFX11: ; %bb.0: ++-; GFX11-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX11-NEXT: global_load_b32 v0, v[0:1], off ++-; GFX11-NEXT: s_waitcnt vmcnt(0) ++-; GFX11-NEXT: global_store_b32 v[2:3], v0, off ++-; GFX11-NEXT: s_setpc_b64 s[30:31] ++- %a = load float, ptr addrspace(1) %ptra, align 4 ++- %freeze = freeze float %a ++- store float %freeze, ptr addrspace(1) %ptrb, align 4 ++- ret void ++-} ++- ++-define void @freeze_i128(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { ++-; GFX10-LABEL: freeze_i128: ++-; GFX10: ; %bb.0: ++-; GFX10-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX10-NEXT: global_load_dwordx4 v[4:7], v[0:1], off ++-; GFX10-NEXT: s_waitcnt vmcnt(0) ++-; GFX10-NEXT: global_store_dwordx4 v[2:3], v[4:7], off ++-; GFX10-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX11-LABEL: freeze_i128: ++-; GFX11: ; %bb.0: ++-; GFX11-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX11-NEXT: global_load_b128 v[4:7], v[0:1], off ++-; GFX11-NEXT: s_waitcnt vmcnt(0) ++-; GFX11-NEXT: global_store_b128 v[2:3], v[4:7], off ++-; GFX11-NEXT: s_setpc_b64 s[30:31] ++- %a = load i128, ptr addrspace(1) %ptra, align 4 ++- %freeze = freeze i128 %a ++- store i128 %freeze, ptr addrspace(1) %ptrb, align 4 ++- ret void ++-} ++- ++-define void @freeze_i256(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { ++-; GFX10-SDAG-LABEL: freeze_i256: ++-; GFX10-SDAG: ; %bb.0: ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX10-SDAG-NEXT: s_clause 0x1 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:16 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:16 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off ++-; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX10-GISEL-LABEL: freeze_i256: ++-; GFX10-GISEL: ; %bb.0: ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX10-GISEL-NEXT: s_clause 0x1 ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 ++-; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX11-SDAG-LABEL: freeze_i256: ++-; GFX11-SDAG: ; %bb.0: ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX11-SDAG-NEXT: s_clause 0x1 ++-; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:16 ++-; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:16 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off ++-; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX11-GISEL-LABEL: freeze_i256: ++-; GFX11-GISEL: ; %bb.0: ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX11-GISEL-NEXT: s_clause 0x1 ++-; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off ++-; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 ++-; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] ++- %a = load i256, ptr addrspace(1) %ptra, align 4 ++- %freeze = freeze i256 %a ++- store i256 %freeze, ptr addrspace(1) %ptrb, align 4 ++- ret void ++-} ++diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/AMDGPU/GlobalISel/inst-select-unmerge-values.mir b/llvm/test/CodeGen/AMDGPU/GlobalISel/inst-select-unmerge-values.mir ++--- a/llvm/test/CodeGen/AMDGPU/GlobalISel/inst-select-unmerge-values.mir +++++ b/llvm/test/CodeGen/AMDGPU/GlobalISel/inst-select-unmerge-values.mir ++@@ -171,9 +171,11 @@ ++ ; GCN-LABEL: name: test_unmerge_values_s_s64_s_s64_s64_s_s192 ++ ; GCN: liveins: $sgpr0_sgpr1_sgpr2_sgpr3 ++ ; GCN-NEXT: {{ $}} ++- ; GCN-NEXT: [[DEF:%[0-9]+]]:sgpr(s192) = G_IMPLICIT_DEF ++- ; GCN-NEXT: [[UV:%[0-9]+]]:sgpr(s64), [[UV1:%[0-9]+]]:sgpr(s64), [[UV2:%[0-9]+]]:sgpr(s64) = G_UNMERGE_VALUES [[DEF]](s192) ++- ; GCN-NEXT: S_ENDPGM 0, implicit [[UV]](s64), implicit [[UV1]](s64), implicit [[UV2]](s64) +++ ; GCN-NEXT: [[DEF:%[0-9]+]]:sgpr_192 = IMPLICIT_DEF +++ ; GCN-NEXT: [[COPY:%[0-9]+]]:sreg_64 = COPY [[DEF]].sub0_sub1 +++ ; GCN-NEXT: [[COPY1:%[0-9]+]]:sreg_64 = COPY [[DEF]].sub2_sub3 +++ ; GCN-NEXT: [[COPY2:%[0-9]+]]:sreg_64 = COPY [[DEF]].sub4_sub5 +++ ; GCN-NEXT: S_ENDPGM 0, implicit [[COPY]], implicit [[COPY1]], implicit [[COPY2]] ++ %0:sgpr(s192) = G_IMPLICIT_DEF ++ %1:sgpr(s64), %2:sgpr(s64), %3:sgpr(s64) = G_UNMERGE_VALUES %0 ++ S_ENDPGM 0, implicit %1, implicit %2, implicit %3 ++@@ -292,11 +294,11 @@ ++ ; GCN-NEXT: [[CONCAT_VECTORS:%[0-9]+]]:sgpr_384(<12 x s32>) = G_CONCAT_VECTORS [[COPY]](<3 x s32>), [[COPY1]](<3 x s32>), [[COPY2]](<3 x s32>), [[COPY3]](<3 x s32>) ++ ; GCN-NEXT: [[COPY4:%[0-9]+]]:sgpr_96(<3 x s32>) = COPY [[CONCAT_VECTORS]].sub0_sub1_sub2(<12 x s32>) ++ ; GCN-NEXT: [[COPY5:%[0-9]+]]:sgpr_96(<3 x s32>) = COPY [[CONCAT_VECTORS]].sub3_sub4_sub5(<12 x s32>) ++- ; GCN-NEXT: [[COPY4:%[0-9]+]]:sgpr_96(<3 x s32>), [[COPY5:%[0-9]+]]:sgpr_96(<3 x s32>), [[UV:%[0-9]+]]:sgpr_96(<3 x s32>), [[UV1:%[0-9]+]]:sgpr_96(<3 x s32>) = G_UNMERGE_VALUES [[CONCAT_VECTORS]](<12 x s32>) ++- ; GCN-NEXT: $sgpr0_sgpr1_sgpr2 = COPY [[COPY4]](<3 x s32>) ++- ; GCN-NEXT: $sgpr4_sgpr5_sgpr6 = COPY [[COPY5]](<3 x s32>) ++- ; GCN-NEXT: $sgpr8_sgpr9_sgpr10 = COPY [[UV]](<3 x s32>) ++- ; GCN-NEXT: $sgpr12_sgpr13_sgpr14 = COPY [[UV1]](<3 x s32>) +++ ; GCN-NEXT: [[UV:%[0-9]+]]:sgpr_96(<3 x s32>), [[UV1:%[0-9]+]]:sgpr_96(<3 x s32>), [[UV2:%[0-9]+]]:sgpr_96(<3 x s32>), [[UV3:%[0-9]+]]:sgpr_96(<3 x s32>) = G_UNMERGE_VALUES [[CONCAT_VECTORS]](<12 x s32>) +++ ; GCN-NEXT: $sgpr0_sgpr1_sgpr2 = COPY [[UV]](<3 x s32>) +++ ; GCN-NEXT: $sgpr4_sgpr5_sgpr6 = COPY [[UV1]](<3 x s32>) +++ ; GCN-NEXT: $sgpr8_sgpr9_sgpr10 = COPY [[UV2]](<3 x s32>) +++ ; GCN-NEXT: $sgpr12_sgpr13_sgpr14 = COPY [[UV3]](<3 x s32>) ++ %0:sgpr(<3 x s32>) = COPY $sgpr0_sgpr1_sgpr2 ++ %1:sgpr(<3 x s32>) = COPY $sgpr4_sgpr5_sgpr6 ++ %2:sgpr(<3 x s32>) = COPY $sgpr8_sgpr9_sgpr10 ++diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-freeze.mir b/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-freeze.mir ++--- a/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-freeze.mir +++++ b/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-freeze.mir ++@@ -171,8 +171,12 @@ ++ ++ ; CHECK-LABEL: name: test_freeze_s448 ++ ; CHECK: [[COPY:%[0-9]+]]:_(s512) = COPY $vgpr0_vgpr1_vgpr2_vgpr3_vgpr4_vgpr5_vgpr6_vgpr7_vgpr8_vgpr9_vgpr10_vgpr11_vgpr12_vgpr13_vgpr14_vgpr15 ++- ; CHECK-NEXT: [[FREEZE:%[0-9]+]]:_(s512) = G_FREEZE [[COPY]] ++- ; CHECK-NEXT: $vgpr0_vgpr1_vgpr2_vgpr3_vgpr4_vgpr5_vgpr6_vgpr7_vgpr8_vgpr9_vgpr10_vgpr11_vgpr12_vgpr13_vgpr14_vgpr15 = COPY [[FREEZE]](s512) +++ ; CHECK-NEXT: [[TRUNC:%[0-9]+]]:_(s448) = G_TRUNC [[COPY]](s512) +++ ; CHECK-NEXT: [[FREEZE:%[0-9]+]]:_(s448) = G_FREEZE [[TRUNC]] +++ ; CHECK-NEXT: [[UV:%[0-9]+]]:_(s64), [[UV1:%[0-9]+]]:_(s64), [[UV2:%[0-9]+]]:_(s64), [[UV3:%[0-9]+]]:_(s64), [[UV4:%[0-9]+]]:_(s64), [[UV5:%[0-9]+]]:_(s64), [[UV6:%[0-9]+]]:_(s64) = G_UNMERGE_VALUES [[FREEZE]](s448) +++ ; CHECK-NEXT: [[DEF:%[0-9]+]]:_(s64) = G_IMPLICIT_DEF +++ ; CHECK-NEXT: [[MV:%[0-9]+]]:_(s512) = G_MERGE_VALUES [[UV]](s64), [[UV1]](s64), [[UV2]](s64), [[UV3]](s64), [[UV4]](s64), [[UV5]](s64), [[UV6]](s64), [[DEF]](s64) +++ ; CHECK-NEXT: $vgpr0_vgpr1_vgpr2_vgpr3_vgpr4_vgpr5_vgpr6_vgpr7_vgpr8_vgpr9_vgpr10_vgpr11_vgpr12_vgpr13_vgpr14_vgpr15 = COPY [[MV]](s512) ++ %0:_(s512) = COPY $vgpr0_vgpr1_vgpr2_vgpr3_vgpr4_vgpr5_vgpr6_vgpr7_vgpr8_vgpr9_vgpr10_vgpr11_vgpr12_vgpr13_vgpr14_vgpr15 ++ %1:_(s448) = G_TRUNC %0 ++ %2:_(s448) = G_FREEZE %1 ++@@ -395,12 +399,14 @@ ++ bb.0: ++ ++ ; CHECK-LABEL: name: test_freeze_v33s32 ++- ; CHECK: [[DEF:%[0-9]+]]:_(<32 x s32>) = G_IMPLICIT_DEF +++ ; CHECK: [[DEF:%[0-9]+]]:_(<16 x s32>) = G_IMPLICIT_DEF ++ ; CHECK-NEXT: [[DEF1:%[0-9]+]]:_(s32) = G_IMPLICIT_DEF ++- ; CHECK-NEXT: [[FREEZE:%[0-9]+]]:_(<32 x s32>) = G_FREEZE [[DEF]] ++- ; CHECK-NEXT: [[FREEZE1:%[0-9]+]]:_(s32) = G_FREEZE [[DEF1]] ++- ; CHECK-NEXT: [[UV:%[0-9]+]]:_(s32), [[UV1:%[0-9]+]]:_(s32), [[UV2:%[0-9]+]]:_(s32), [[UV3:%[0-9]+]]:_(s32), [[UV4:%[0-9]+]]:_(s32), [[UV5:%[0-9]+]]:_(s32), [[UV6:%[0-9]+]]:_(s32), [[UV7:%[0-9]+]]:_(s32), [[UV8:%[0-9]+]]:_(s32), [[UV9:%[0-9]+]]:_(s32), [[UV10:%[0-9]+]]:_(s32), [[UV11:%[0-9]+]]:_(s32), [[UV12:%[0-9]+]]:_(s32), [[UV13:%[0-9]+]]:_(s32), [[UV14:%[0-9]+]]:_(s32), [[UV15:%[0-9]+]]:_(s32), [[UV16:%[0-9]+]]:_(s32), [[UV17:%[0-9]+]]:_(s32), [[UV18:%[0-9]+]]:_(s32), [[UV19:%[0-9]+]]:_(s32), [[UV20:%[0-9]+]]:_(s32), [[UV21:%[0-9]+]]:_(s32), [[UV22:%[0-9]+]]:_(s32), [[UV23:%[0-9]+]]:_(s32), [[UV24:%[0-9]+]]:_(s32), [[UV25:%[0-9]+]]:_(s32), [[UV26:%[0-9]+]]:_(s32), [[UV27:%[0-9]+]]:_(s32), [[UV28:%[0-9]+]]:_(s32), [[UV29:%[0-9]+]]:_(s32), [[UV30:%[0-9]+]]:_(s32), [[UV31:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[FREEZE]](<32 x s32>) ++- ; CHECK-NEXT: [[BUILD_VECTOR:%[0-9]+]]:_(<33 x s32>) = G_BUILD_VECTOR [[UV]](s32), [[UV1]](s32), [[UV2]](s32), [[UV3]](s32), [[UV4]](s32), [[UV5]](s32), [[UV6]](s32), [[UV7]](s32), [[UV8]](s32), [[UV9]](s32), [[UV10]](s32), [[UV11]](s32), [[UV12]](s32), [[UV13]](s32), [[UV14]](s32), [[UV15]](s32), [[UV16]](s32), [[UV17]](s32), [[UV18]](s32), [[UV19]](s32), [[UV20]](s32), [[UV21]](s32), [[UV22]](s32), [[UV23]](s32), [[UV24]](s32), [[UV25]](s32), [[UV26]](s32), [[UV27]](s32), [[UV28]](s32), [[UV29]](s32), [[UV30]](s32), [[UV31]](s32), [[FREEZE1]](s32) +++ ; CHECK-NEXT: [[FREEZE:%[0-9]+]]:_(<16 x s32>) = G_FREEZE [[DEF]] +++ ; CHECK-NEXT: [[FREEZE1:%[0-9]+]]:_(<16 x s32>) = G_FREEZE [[DEF]] +++ ; CHECK-NEXT: [[FREEZE2:%[0-9]+]]:_(s32) = G_FREEZE [[DEF1]] +++ ; CHECK-NEXT: [[UV:%[0-9]+]]:_(s32), [[UV1:%[0-9]+]]:_(s32), [[UV2:%[0-9]+]]:_(s32), [[UV3:%[0-9]+]]:_(s32), [[UV4:%[0-9]+]]:_(s32), [[UV5:%[0-9]+]]:_(s32), [[UV6:%[0-9]+]]:_(s32), [[UV7:%[0-9]+]]:_(s32), [[UV8:%[0-9]+]]:_(s32), [[UV9:%[0-9]+]]:_(s32), [[UV10:%[0-9]+]]:_(s32), [[UV11:%[0-9]+]]:_(s32), [[UV12:%[0-9]+]]:_(s32), [[UV13:%[0-9]+]]:_(s32), [[UV14:%[0-9]+]]:_(s32), [[UV15:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[FREEZE]](<16 x s32>) +++ ; CHECK-NEXT: [[UV16:%[0-9]+]]:_(s32), [[UV17:%[0-9]+]]:_(s32), [[UV18:%[0-9]+]]:_(s32), [[UV19:%[0-9]+]]:_(s32), [[UV20:%[0-9]+]]:_(s32), [[UV21:%[0-9]+]]:_(s32), [[UV22:%[0-9]+]]:_(s32), [[UV23:%[0-9]+]]:_(s32), [[UV24:%[0-9]+]]:_(s32), [[UV25:%[0-9]+]]:_(s32), [[UV26:%[0-9]+]]:_(s32), [[UV27:%[0-9]+]]:_(s32), [[UV28:%[0-9]+]]:_(s32), [[UV29:%[0-9]+]]:_(s32), [[UV30:%[0-9]+]]:_(s32), [[UV31:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[FREEZE1]](<16 x s32>) +++ ; CHECK-NEXT: [[BUILD_VECTOR:%[0-9]+]]:_(<33 x s32>) = G_BUILD_VECTOR [[UV]](s32), [[UV1]](s32), [[UV2]](s32), [[UV3]](s32), [[UV4]](s32), [[UV5]](s32), [[UV6]](s32), [[UV7]](s32), [[UV8]](s32), [[UV9]](s32), [[UV10]](s32), [[UV11]](s32), [[UV12]](s32), [[UV13]](s32), [[UV14]](s32), [[UV15]](s32), [[UV16]](s32), [[UV17]](s32), [[UV18]](s32), [[UV19]](s32), [[UV20]](s32), [[UV21]](s32), [[UV22]](s32), [[UV23]](s32), [[UV24]](s32), [[UV25]](s32), [[UV26]](s32), [[UV27]](s32), [[UV28]](s32), [[UV29]](s32), [[UV30]](s32), [[UV31]](s32), [[FREEZE2]](s32) ++ ; CHECK-NEXT: S_NOP 0, implicit [[BUILD_VECTOR]](<33 x s32>) ++ %0:_(<33 x s32>) = G_IMPLICIT_DEF ++ %1:_(<33 x s32>) = G_FREEZE %0 ++@@ -413,10 +419,12 @@ ++ bb.0: ++ ++ ; CHECK-LABEL: name: test_freeze_v64s32 ++- ; CHECK: [[DEF:%[0-9]+]]:_(<32 x s32>) = G_IMPLICIT_DEF ++- ; CHECK-NEXT: [[FREEZE:%[0-9]+]]:_(<32 x s32>) = G_FREEZE [[DEF]] ++- ; CHECK-NEXT: [[FREEZE1:%[0-9]+]]:_(<32 x s32>) = G_FREEZE [[DEF]] ++- ; CHECK-NEXT: [[CONCAT_VECTORS:%[0-9]+]]:_(<64 x s32>) = G_CONCAT_VECTORS [[FREEZE]](<32 x s32>), [[FREEZE1]](<32 x s32>) +++ ; CHECK: [[DEF:%[0-9]+]]:_(<16 x s32>) = G_IMPLICIT_DEF +++ ; CHECK-NEXT: [[FREEZE:%[0-9]+]]:_(<16 x s32>) = G_FREEZE [[DEF]] +++ ; CHECK-NEXT: [[FREEZE1:%[0-9]+]]:_(<16 x s32>) = G_FREEZE [[DEF]] +++ ; CHECK-NEXT: [[FREEZE2:%[0-9]+]]:_(<16 x s32>) = G_FREEZE [[DEF]] +++ ; CHECK-NEXT: [[FREEZE3:%[0-9]+]]:_(<16 x s32>) = G_FREEZE [[DEF]] +++ ; CHECK-NEXT: [[CONCAT_VECTORS:%[0-9]+]]:_(<64 x s32>) = G_CONCAT_VECTORS [[FREEZE]](<16 x s32>), [[FREEZE1]](<16 x s32>), [[FREEZE2]](<16 x s32>), [[FREEZE3]](<16 x s32>) ++ ; CHECK-NEXT: S_NOP 0, implicit [[CONCAT_VECTORS]](<64 x s32>) ++ %0:_(<64 x s32>) = G_IMPLICIT_DEF ++ %1:_(<64 x s32>) = G_FREEZE %0 ++diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-implicit-def.mir b/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-implicit-def.mir ++--- a/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-implicit-def.mir +++++ b/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-implicit-def.mir ++@@ -135,9 +135,8 @@ ++ bb.0: ++ ++ ; CHECK-LABEL: name: test_implicit_def_s448 ++- ; CHECK: [[DEF:%[0-9]+]]:_(s512) = G_IMPLICIT_DEF ++- ; CHECK-NEXT: [[TRUNC:%[0-9]+]]:_(s448) = G_TRUNC [[DEF]](s512) ++- ; CHECK-NEXT: [[EXTRACT:%[0-9]+]]:_(s32) = G_EXTRACT [[TRUNC]](s448), 0 +++ ; CHECK: [[DEF:%[0-9]+]]:_(s448) = G_IMPLICIT_DEF +++ ; CHECK-NEXT: [[EXTRACT:%[0-9]+]]:_(s32) = G_EXTRACT [[DEF]](s448), 0 ++ ; CHECK-NEXT: $vgpr0 = COPY [[EXTRACT]](s32) ++ %0:_(s448) = G_IMPLICIT_DEF ++ %1:_(s32) = G_EXTRACT %0, 0 ++@@ -297,6 +296,18 @@ ++ ... ++ ++ --- +++name: test_implicit_def_v17s32 +++body: | +++ bb.0: +++ +++ ; CHECK-LABEL: name: test_implicit_def_v17s32 +++ ; CHECK: [[DEF:%[0-9]+]]:_(<17 x s32>) = G_IMPLICIT_DEF +++ ; CHECK-NEXT: S_NOP 0, implicit [[DEF]](<17 x s32>) +++ %0:_(<17 x s32>) = G_IMPLICIT_DEF +++ S_NOP 0, implicit %0 +++... +++ +++--- ++ name: test_implicit_def_v32s32 ++ body: | ++ bb.0: ++@@ -317,9 +328,9 @@ ++ ; CHECK-LABEL: name: test_implicit_def_v33s32 ++ ; CHECK: liveins: $vgpr0_vgpr1 ++ ; CHECK-NEXT: {{ $}} ++- ; CHECK-NEXT: [[DEF:%[0-9]+]]:_(<32 x s32>) = G_IMPLICIT_DEF +++ ; CHECK-NEXT: [[DEF:%[0-9]+]]:_(<16 x s32>) = G_IMPLICIT_DEF ++ ; CHECK-NEXT: [[DEF1:%[0-9]+]]:_(s32) = G_IMPLICIT_DEF ++- ; CHECK-NEXT: [[UV:%[0-9]+]]:_(s32), [[UV1:%[0-9]+]]:_(s32), [[UV2:%[0-9]+]]:_(s32), [[UV3:%[0-9]+]]:_(s32), [[UV4:%[0-9]+]]:_(s32), [[UV5:%[0-9]+]]:_(s32), [[UV6:%[0-9]+]]:_(s32), [[UV7:%[0-9]+]]:_(s32), [[UV8:%[0-9]+]]:_(s32), [[UV9:%[0-9]+]]:_(s32), [[UV10:%[0-9]+]]:_(s32), [[UV11:%[0-9]+]]:_(s32), [[UV12:%[0-9]+]]:_(s32), [[UV13:%[0-9]+]]:_(s32), [[UV14:%[0-9]+]]:_(s32), [[UV15:%[0-9]+]]:_(s32), [[UV16:%[0-9]+]]:_(s32), [[UV17:%[0-9]+]]:_(s32), [[UV18:%[0-9]+]]:_(s32), [[UV19:%[0-9]+]]:_(s32), [[UV20:%[0-9]+]]:_(s32), [[UV21:%[0-9]+]]:_(s32), [[UV22:%[0-9]+]]:_(s32), [[UV23:%[0-9]+]]:_(s32), [[UV24:%[0-9]+]]:_(s32), [[UV25:%[0-9]+]]:_(s32), [[UV26:%[0-9]+]]:_(s32), [[UV27:%[0-9]+]]:_(s32), [[UV28:%[0-9]+]]:_(s32), [[UV29:%[0-9]+]]:_(s32), [[UV30:%[0-9]+]]:_(s32), [[UV31:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<32 x s32>) +++ ; CHECK-NEXT: [[UV:%[0-9]+]]:_(s32), [[UV1:%[0-9]+]]:_(s32), [[UV2:%[0-9]+]]:_(s32), [[UV3:%[0-9]+]]:_(s32), [[UV4:%[0-9]+]]:_(s32), [[UV5:%[0-9]+]]:_(s32), [[UV6:%[0-9]+]]:_(s32), [[UV7:%[0-9]+]]:_(s32), [[UV8:%[0-9]+]]:_(s32), [[UV9:%[0-9]+]]:_(s32), [[UV10:%[0-9]+]]:_(s32), [[UV11:%[0-9]+]]:_(s32), [[UV12:%[0-9]+]]:_(s32), [[UV13:%[0-9]+]]:_(s32), [[UV14:%[0-9]+]]:_(s32), [[UV15:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) ++ ; CHECK-NEXT: [[COPY:%[0-9]+]]:_(p1) = COPY $vgpr0_vgpr1 ++ ; CHECK-NEXT: G_STORE [[UV]](s32), [[COPY]](p1) :: (volatile store (s32), addrspace 1) ++ ; CHECK-NEXT: G_STORE [[DEF1]](s32), [[COPY]](p1) :: (volatile store (s32), addrspace 1) ++@@ -337,9 +348,10 @@ ++ bb.0: ++ ++ ; CHECK-LABEL: name: test_implicit_def_v64s32 ++- ; CHECK: [[DEF:%[0-9]+]]:_(<32 x s32>) = G_IMPLICIT_DEF ++- ; CHECK-NEXT: [[CONCAT_VECTORS:%[0-9]+]]:_(<64 x s32>) = G_CONCAT_VECTORS [[DEF]](<32 x s32>), [[DEF]](<32 x s32>) ++- ; CHECK-NEXT: S_NOP 0, implicit [[CONCAT_VECTORS]](<64 x s32>), implicit [[DEF]](<32 x s32>) +++ ; CHECK: [[DEF:%[0-9]+]]:_(<16 x s32>) = G_IMPLICIT_DEF +++ ; CHECK-NEXT: [[CONCAT_VECTORS:%[0-9]+]]:_(<64 x s32>) = G_CONCAT_VECTORS [[DEF]](<16 x s32>), [[DEF]](<16 x s32>), [[DEF]](<16 x s32>), [[DEF]](<16 x s32>) +++ ; CHECK-NEXT: [[CONCAT_VECTORS1:%[0-9]+]]:_(<32 x s32>) = G_CONCAT_VECTORS [[DEF]](<16 x s32>), [[DEF]](<16 x s32>) +++ ; CHECK-NEXT: S_NOP 0, implicit [[CONCAT_VECTORS]](<64 x s32>), implicit [[CONCAT_VECTORS1]](<32 x s32>) ++ %0:_(<64 x s32>) = G_IMPLICIT_DEF ++ %1:_(<32 x s32>), %2:_(<32 x s32>) = G_UNMERGE_VALUES %0 ++ S_NOP 0, implicit %0, implicit %1 ++diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-insert-vector-elt.mir b/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-insert-vector-elt.mir ++--- a/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-insert-vector-elt.mir +++++ b/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-insert-vector-elt.mir ++@@ -190,11 +190,13 @@ ++ ; CHECK-LABEL: name: insert_vector_elt_64_65_v64s32 ++ ; CHECK: liveins: $sgpr0_sgpr1, $vgpr0_vgpr1, $vgpr2_vgpr3 ++ ; CHECK-NEXT: {{ $}} ++- ; CHECK-NEXT: [[DEF:%[0-9]+]]:_(<32 x s32>) = G_IMPLICIT_DEF +++ ; CHECK-NEXT: [[DEF:%[0-9]+]]:_(<16 x s32>) = G_IMPLICIT_DEF ++ ; CHECK-NEXT: [[COPY:%[0-9]+]]:_(p1) = COPY $vgpr0_vgpr1 ++ ; CHECK-NEXT: [[COPY1:%[0-9]+]]:_(p1) = COPY $vgpr2_vgpr3 ++- ; CHECK-NEXT: [[UV:%[0-9]+]]:_(<4 x s32>), [[UV1:%[0-9]+]]:_(<4 x s32>), [[UV2:%[0-9]+]]:_(<4 x s32>), [[UV3:%[0-9]+]]:_(<4 x s32>), [[UV4:%[0-9]+]]:_(<4 x s32>), [[UV5:%[0-9]+]]:_(<4 x s32>), [[UV6:%[0-9]+]]:_(<4 x s32>), [[UV7:%[0-9]+]]:_(<4 x s32>) = G_UNMERGE_VALUES [[DEF]](<32 x s32>) ++- ; CHECK-NEXT: [[UV8:%[0-9]+]]:_(<4 x s32>), [[UV9:%[0-9]+]]:_(<4 x s32>), [[UV10:%[0-9]+]]:_(<4 x s32>), [[UV11:%[0-9]+]]:_(<4 x s32>), [[UV12:%[0-9]+]]:_(<4 x s32>), [[UV13:%[0-9]+]]:_(<4 x s32>), [[UV14:%[0-9]+]]:_(<4 x s32>), [[UV15:%[0-9]+]]:_(<4 x s32>) = G_UNMERGE_VALUES [[DEF]](<32 x s32>) +++ ; CHECK-NEXT: [[UV:%[0-9]+]]:_(<4 x s32>), [[UV1:%[0-9]+]]:_(<4 x s32>), [[UV2:%[0-9]+]]:_(<4 x s32>), [[UV3:%[0-9]+]]:_(<4 x s32>) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) +++ ; CHECK-NEXT: [[UV4:%[0-9]+]]:_(<4 x s32>), [[UV5:%[0-9]+]]:_(<4 x s32>), [[UV6:%[0-9]+]]:_(<4 x s32>), [[UV7:%[0-9]+]]:_(<4 x s32>) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) +++ ; CHECK-NEXT: [[UV8:%[0-9]+]]:_(<4 x s32>), [[UV9:%[0-9]+]]:_(<4 x s32>), [[UV10:%[0-9]+]]:_(<4 x s32>), [[UV11:%[0-9]+]]:_(<4 x s32>) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) +++ ; CHECK-NEXT: [[UV12:%[0-9]+]]:_(<4 x s32>), [[UV13:%[0-9]+]]:_(<4 x s32>), [[UV14:%[0-9]+]]:_(<4 x s32>), [[UV15:%[0-9]+]]:_(<4 x s32>) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) ++ ; CHECK-NEXT: G_STORE [[UV]](<4 x s32>), [[COPY]](p1) :: (store (<4 x s32>), align 4, addrspace 1) ++ ; CHECK-NEXT: [[C:%[0-9]+]]:_(s64) = G_CONSTANT i64 16 ++ ; CHECK-NEXT: [[PTR_ADD:%[0-9]+]]:_(p1) = G_PTR_ADD [[COPY]], [[C]](s64) ++@@ -241,8 +243,10 @@ ++ ; CHECK-NEXT: [[C14:%[0-9]+]]:_(s64) = G_CONSTANT i64 240 ++ ; CHECK-NEXT: [[PTR_ADD14:%[0-9]+]]:_(p1) = G_PTR_ADD [[COPY]], [[C14]](s64) ++ ; CHECK-NEXT: G_STORE [[UV15]](<4 x s32>), [[PTR_ADD14]](p1) :: (store (<4 x s32>) into unknown-address + 240, align 4, addrspace 1) ++- ; CHECK-NEXT: [[UV16:%[0-9]+]]:_(<4 x s32>), [[UV17:%[0-9]+]]:_(<4 x s32>), [[UV18:%[0-9]+]]:_(<4 x s32>), [[UV19:%[0-9]+]]:_(<4 x s32>), [[UV20:%[0-9]+]]:_(<4 x s32>), [[UV21:%[0-9]+]]:_(<4 x s32>), [[UV22:%[0-9]+]]:_(<4 x s32>), [[UV23:%[0-9]+]]:_(<4 x s32>) = G_UNMERGE_VALUES [[DEF]](<32 x s32>) ++- ; CHECK-NEXT: [[UV24:%[0-9]+]]:_(<4 x s32>), [[UV25:%[0-9]+]]:_(<4 x s32>), [[UV26:%[0-9]+]]:_(<4 x s32>), [[UV27:%[0-9]+]]:_(<4 x s32>), [[UV28:%[0-9]+]]:_(<4 x s32>), [[UV29:%[0-9]+]]:_(<4 x s32>), [[UV30:%[0-9]+]]:_(<4 x s32>), [[UV31:%[0-9]+]]:_(<4 x s32>) = G_UNMERGE_VALUES [[DEF]](<32 x s32>) +++ ; CHECK-NEXT: [[UV16:%[0-9]+]]:_(<4 x s32>), [[UV17:%[0-9]+]]:_(<4 x s32>), [[UV18:%[0-9]+]]:_(<4 x s32>), [[UV19:%[0-9]+]]:_(<4 x s32>) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) +++ ; CHECK-NEXT: [[UV20:%[0-9]+]]:_(<4 x s32>), [[UV21:%[0-9]+]]:_(<4 x s32>), [[UV22:%[0-9]+]]:_(<4 x s32>), [[UV23:%[0-9]+]]:_(<4 x s32>) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) +++ ; CHECK-NEXT: [[UV24:%[0-9]+]]:_(<4 x s32>), [[UV25:%[0-9]+]]:_(<4 x s32>), [[UV26:%[0-9]+]]:_(<4 x s32>), [[UV27:%[0-9]+]]:_(<4 x s32>) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) +++ ; CHECK-NEXT: [[UV28:%[0-9]+]]:_(<4 x s32>), [[UV29:%[0-9]+]]:_(<4 x s32>), [[UV30:%[0-9]+]]:_(<4 x s32>), [[UV31:%[0-9]+]]:_(<4 x s32>) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) ++ ; CHECK-NEXT: G_STORE [[UV16]](<4 x s32>), [[COPY1]](p1) :: (store (<4 x s32>), align 4, addrspace 1) ++ ; CHECK-NEXT: [[PTR_ADD15:%[0-9]+]]:_(p1) = G_PTR_ADD [[COPY1]], [[C]](s64) ++ ; CHECK-NEXT: G_STORE [[UV17]](<4 x s32>), [[PTR_ADD15]](p1) :: (store (<4 x s32>) into unknown-address + 16, align 4, addrspace 1) ++diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-phi.mir b/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-phi.mir ++--- a/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-phi.mir +++++ b/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-phi.mir ++@@ -673,86 +673,88 @@ ++ ; CHECK-NEXT: successors: %bb.1(0x40000000), %bb.2(0x40000000) ++ ; CHECK-NEXT: liveins: $vgpr0_vgpr1_vgpr2_vgpr3, $vgpr4 ++ ; CHECK-NEXT: {{ $}} ++- ; CHECK-NEXT: [[DEF:%[0-9]+]]:_(<32 x s32>) = G_IMPLICIT_DEF +++ ; CHECK-NEXT: [[DEF:%[0-9]+]]:_(<16 x s32>) = G_IMPLICIT_DEF ++ ; CHECK-NEXT: [[COPY:%[0-9]+]]:_(s32) = COPY $vgpr4 ++ ; CHECK-NEXT: [[C:%[0-9]+]]:_(s32) = G_CONSTANT i32 0 ++ ; CHECK-NEXT: [[ICMP:%[0-9]+]]:_(s1) = G_ICMP intpred(eq), [[COPY]](s32), [[C]] ++- ; CHECK-NEXT: [[UV:%[0-9]+]]:_(<16 x s32>), [[UV1:%[0-9]+]]:_(<16 x s32>) = G_UNMERGE_VALUES [[DEF]](<32 x s32>) ++- ; CHECK-NEXT: [[UV2:%[0-9]+]]:_(<16 x s32>), [[UV3:%[0-9]+]]:_(<16 x s32>) = G_UNMERGE_VALUES [[DEF]](<32 x s32>) ++ ; CHECK-NEXT: G_BRCOND [[ICMP]](s1), %bb.1 ++ ; CHECK-NEXT: G_BR %bb.2 ++ ; CHECK-NEXT: {{ $}} ++ ; CHECK-NEXT: bb.1: ++ ; CHECK-NEXT: successors: %bb.2(0x80000000) ++ ; CHECK-NEXT: {{ $}} ++- ; CHECK-NEXT: [[UV4:%[0-9]+]]:_(s32), [[UV5:%[0-9]+]]:_(s32), [[UV6:%[0-9]+]]:_(s32), [[UV7:%[0-9]+]]:_(s32), [[UV8:%[0-9]+]]:_(s32), [[UV9:%[0-9]+]]:_(s32), [[UV10:%[0-9]+]]:_(s32), [[UV11:%[0-9]+]]:_(s32), [[UV12:%[0-9]+]]:_(s32), [[UV13:%[0-9]+]]:_(s32), [[UV14:%[0-9]+]]:_(s32), [[UV15:%[0-9]+]]:_(s32), [[UV16:%[0-9]+]]:_(s32), [[UV17:%[0-9]+]]:_(s32), [[UV18:%[0-9]+]]:_(s32), [[UV19:%[0-9]+]]:_(s32), [[UV20:%[0-9]+]]:_(s32), [[UV21:%[0-9]+]]:_(s32), [[UV22:%[0-9]+]]:_(s32), [[UV23:%[0-9]+]]:_(s32), [[UV24:%[0-9]+]]:_(s32), [[UV25:%[0-9]+]]:_(s32), [[UV26:%[0-9]+]]:_(s32), [[UV27:%[0-9]+]]:_(s32), [[UV28:%[0-9]+]]:_(s32), [[UV29:%[0-9]+]]:_(s32), [[UV30:%[0-9]+]]:_(s32), [[UV31:%[0-9]+]]:_(s32), [[UV32:%[0-9]+]]:_(s32), [[UV33:%[0-9]+]]:_(s32), [[UV34:%[0-9]+]]:_(s32), [[UV35:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<32 x s32>) ++- ; CHECK-NEXT: [[UV36:%[0-9]+]]:_(s32), [[UV37:%[0-9]+]]:_(s32), [[UV38:%[0-9]+]]:_(s32), [[UV39:%[0-9]+]]:_(s32), [[UV40:%[0-9]+]]:_(s32), [[UV41:%[0-9]+]]:_(s32), [[UV42:%[0-9]+]]:_(s32), [[UV43:%[0-9]+]]:_(s32), [[UV44:%[0-9]+]]:_(s32), [[UV45:%[0-9]+]]:_(s32), [[UV46:%[0-9]+]]:_(s32), [[UV47:%[0-9]+]]:_(s32), [[UV48:%[0-9]+]]:_(s32), [[UV49:%[0-9]+]]:_(s32), [[UV50:%[0-9]+]]:_(s32), [[UV51:%[0-9]+]]:_(s32), [[UV52:%[0-9]+]]:_(s32), [[UV53:%[0-9]+]]:_(s32), [[UV54:%[0-9]+]]:_(s32), [[UV55:%[0-9]+]]:_(s32), [[UV56:%[0-9]+]]:_(s32), [[UV57:%[0-9]+]]:_(s32), [[UV58:%[0-9]+]]:_(s32), [[UV59:%[0-9]+]]:_(s32), [[UV60:%[0-9]+]]:_(s32), [[UV61:%[0-9]+]]:_(s32), [[UV62:%[0-9]+]]:_(s32), [[UV63:%[0-9]+]]:_(s32), [[UV64:%[0-9]+]]:_(s32), [[UV65:%[0-9]+]]:_(s32), [[UV66:%[0-9]+]]:_(s32), [[UV67:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<32 x s32>) ++- ; CHECK-NEXT: [[UV68:%[0-9]+]]:_(s32), [[UV69:%[0-9]+]]:_(s32), [[UV70:%[0-9]+]]:_(s32), [[UV71:%[0-9]+]]:_(s32), [[UV72:%[0-9]+]]:_(s32), [[UV73:%[0-9]+]]:_(s32), [[UV74:%[0-9]+]]:_(s32), [[UV75:%[0-9]+]]:_(s32), [[UV76:%[0-9]+]]:_(s32), [[UV77:%[0-9]+]]:_(s32), [[UV78:%[0-9]+]]:_(s32), [[UV79:%[0-9]+]]:_(s32), [[UV80:%[0-9]+]]:_(s32), [[UV81:%[0-9]+]]:_(s32), [[UV82:%[0-9]+]]:_(s32), [[UV83:%[0-9]+]]:_(s32), [[UV84:%[0-9]+]]:_(s32), [[UV85:%[0-9]+]]:_(s32), [[UV86:%[0-9]+]]:_(s32), [[UV87:%[0-9]+]]:_(s32), [[UV88:%[0-9]+]]:_(s32), [[UV89:%[0-9]+]]:_(s32), [[UV90:%[0-9]+]]:_(s32), [[UV91:%[0-9]+]]:_(s32), [[UV92:%[0-9]+]]:_(s32), [[UV93:%[0-9]+]]:_(s32), [[UV94:%[0-9]+]]:_(s32), [[UV95:%[0-9]+]]:_(s32), [[UV96:%[0-9]+]]:_(s32), [[UV97:%[0-9]+]]:_(s32), [[UV98:%[0-9]+]]:_(s32), [[UV99:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<32 x s32>) ++- ; CHECK-NEXT: [[UV100:%[0-9]+]]:_(s32), [[UV101:%[0-9]+]]:_(s32), [[UV102:%[0-9]+]]:_(s32), [[UV103:%[0-9]+]]:_(s32), [[UV104:%[0-9]+]]:_(s32), [[UV105:%[0-9]+]]:_(s32), [[UV106:%[0-9]+]]:_(s32), [[UV107:%[0-9]+]]:_(s32), [[UV108:%[0-9]+]]:_(s32), [[UV109:%[0-9]+]]:_(s32), [[UV110:%[0-9]+]]:_(s32), [[UV111:%[0-9]+]]:_(s32), [[UV112:%[0-9]+]]:_(s32), [[UV113:%[0-9]+]]:_(s32), [[UV114:%[0-9]+]]:_(s32), [[UV115:%[0-9]+]]:_(s32), [[UV116:%[0-9]+]]:_(s32), [[UV117:%[0-9]+]]:_(s32), [[UV118:%[0-9]+]]:_(s32), [[UV119:%[0-9]+]]:_(s32), [[UV120:%[0-9]+]]:_(s32), [[UV121:%[0-9]+]]:_(s32), [[UV122:%[0-9]+]]:_(s32), [[UV123:%[0-9]+]]:_(s32), [[UV124:%[0-9]+]]:_(s32), [[UV125:%[0-9]+]]:_(s32), [[UV126:%[0-9]+]]:_(s32), [[UV127:%[0-9]+]]:_(s32), [[UV128:%[0-9]+]]:_(s32), [[UV129:%[0-9]+]]:_(s32), [[UV130:%[0-9]+]]:_(s32), [[UV131:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<32 x s32>) ++- ; CHECK-NEXT: [[ADD:%[0-9]+]]:_(s32) = G_ADD [[UV4]], [[UV68]] ++- ; CHECK-NEXT: [[ADD1:%[0-9]+]]:_(s32) = G_ADD [[UV5]], [[UV69]] ++- ; CHECK-NEXT: [[ADD2:%[0-9]+]]:_(s32) = G_ADD [[UV6]], [[UV70]] ++- ; CHECK-NEXT: [[ADD3:%[0-9]+]]:_(s32) = G_ADD [[UV7]], [[UV71]] ++- ; CHECK-NEXT: [[ADD4:%[0-9]+]]:_(s32) = G_ADD [[UV8]], [[UV72]] ++- ; CHECK-NEXT: [[ADD5:%[0-9]+]]:_(s32) = G_ADD [[UV9]], [[UV73]] ++- ; CHECK-NEXT: [[ADD6:%[0-9]+]]:_(s32) = G_ADD [[UV10]], [[UV74]] ++- ; CHECK-NEXT: [[ADD7:%[0-9]+]]:_(s32) = G_ADD [[UV11]], [[UV75]] ++- ; CHECK-NEXT: [[ADD8:%[0-9]+]]:_(s32) = G_ADD [[UV12]], [[UV76]] ++- ; CHECK-NEXT: [[ADD9:%[0-9]+]]:_(s32) = G_ADD [[UV13]], [[UV77]] ++- ; CHECK-NEXT: [[ADD10:%[0-9]+]]:_(s32) = G_ADD [[UV14]], [[UV78]] ++- ; CHECK-NEXT: [[ADD11:%[0-9]+]]:_(s32) = G_ADD [[UV15]], [[UV79]] ++- ; CHECK-NEXT: [[ADD12:%[0-9]+]]:_(s32) = G_ADD [[UV16]], [[UV80]] ++- ; CHECK-NEXT: [[ADD13:%[0-9]+]]:_(s32) = G_ADD [[UV17]], [[UV81]] ++- ; CHECK-NEXT: [[ADD14:%[0-9]+]]:_(s32) = G_ADD [[UV18]], [[UV82]] ++- ; CHECK-NEXT: [[ADD15:%[0-9]+]]:_(s32) = G_ADD [[UV19]], [[UV83]] ++- ; CHECK-NEXT: [[ADD16:%[0-9]+]]:_(s32) = G_ADD [[UV20]], [[UV84]] ++- ; CHECK-NEXT: [[ADD17:%[0-9]+]]:_(s32) = G_ADD [[UV21]], [[UV85]] ++- ; CHECK-NEXT: [[ADD18:%[0-9]+]]:_(s32) = G_ADD [[UV22]], [[UV86]] ++- ; CHECK-NEXT: [[ADD19:%[0-9]+]]:_(s32) = G_ADD [[UV23]], [[UV87]] ++- ; CHECK-NEXT: [[ADD20:%[0-9]+]]:_(s32) = G_ADD [[UV24]], [[UV88]] ++- ; CHECK-NEXT: [[ADD21:%[0-9]+]]:_(s32) = G_ADD [[UV25]], [[UV89]] ++- ; CHECK-NEXT: [[ADD22:%[0-9]+]]:_(s32) = G_ADD [[UV26]], [[UV90]] ++- ; CHECK-NEXT: [[ADD23:%[0-9]+]]:_(s32) = G_ADD [[UV27]], [[UV91]] ++- ; CHECK-NEXT: [[ADD24:%[0-9]+]]:_(s32) = G_ADD [[UV28]], [[UV92]] ++- ; CHECK-NEXT: [[ADD25:%[0-9]+]]:_(s32) = G_ADD [[UV29]], [[UV93]] ++- ; CHECK-NEXT: [[ADD26:%[0-9]+]]:_(s32) = G_ADD [[UV30]], [[UV94]] ++- ; CHECK-NEXT: [[ADD27:%[0-9]+]]:_(s32) = G_ADD [[UV31]], [[UV95]] ++- ; CHECK-NEXT: [[ADD28:%[0-9]+]]:_(s32) = G_ADD [[UV32]], [[UV96]] ++- ; CHECK-NEXT: [[ADD29:%[0-9]+]]:_(s32) = G_ADD [[UV33]], [[UV97]] ++- ; CHECK-NEXT: [[ADD30:%[0-9]+]]:_(s32) = G_ADD [[UV34]], [[UV98]] ++- ; CHECK-NEXT: [[ADD31:%[0-9]+]]:_(s32) = G_ADD [[UV35]], [[UV99]] ++- ; CHECK-NEXT: [[ADD32:%[0-9]+]]:_(s32) = G_ADD [[UV36]], [[UV100]] ++- ; CHECK-NEXT: [[ADD33:%[0-9]+]]:_(s32) = G_ADD [[UV37]], [[UV101]] ++- ; CHECK-NEXT: [[ADD34:%[0-9]+]]:_(s32) = G_ADD [[UV38]], [[UV102]] ++- ; CHECK-NEXT: [[ADD35:%[0-9]+]]:_(s32) = G_ADD [[UV39]], [[UV103]] ++- ; CHECK-NEXT: [[ADD36:%[0-9]+]]:_(s32) = G_ADD [[UV40]], [[UV104]] ++- ; CHECK-NEXT: [[ADD37:%[0-9]+]]:_(s32) = G_ADD [[UV41]], [[UV105]] ++- ; CHECK-NEXT: [[ADD38:%[0-9]+]]:_(s32) = G_ADD [[UV42]], [[UV106]] ++- ; CHECK-NEXT: [[ADD39:%[0-9]+]]:_(s32) = G_ADD [[UV43]], [[UV107]] ++- ; CHECK-NEXT: [[ADD40:%[0-9]+]]:_(s32) = G_ADD [[UV44]], [[UV108]] ++- ; CHECK-NEXT: [[ADD41:%[0-9]+]]:_(s32) = G_ADD [[UV45]], [[UV109]] ++- ; CHECK-NEXT: [[ADD42:%[0-9]+]]:_(s32) = G_ADD [[UV46]], [[UV110]] ++- ; CHECK-NEXT: [[ADD43:%[0-9]+]]:_(s32) = G_ADD [[UV47]], [[UV111]] ++- ; CHECK-NEXT: [[ADD44:%[0-9]+]]:_(s32) = G_ADD [[UV48]], [[UV112]] ++- ; CHECK-NEXT: [[ADD45:%[0-9]+]]:_(s32) = G_ADD [[UV49]], [[UV113]] ++- ; CHECK-NEXT: [[ADD46:%[0-9]+]]:_(s32) = G_ADD [[UV50]], [[UV114]] ++- ; CHECK-NEXT: [[ADD47:%[0-9]+]]:_(s32) = G_ADD [[UV51]], [[UV115]] ++- ; CHECK-NEXT: [[ADD48:%[0-9]+]]:_(s32) = G_ADD [[UV52]], [[UV116]] ++- ; CHECK-NEXT: [[ADD49:%[0-9]+]]:_(s32) = G_ADD [[UV53]], [[UV117]] ++- ; CHECK-NEXT: [[ADD50:%[0-9]+]]:_(s32) = G_ADD [[UV54]], [[UV118]] ++- ; CHECK-NEXT: [[ADD51:%[0-9]+]]:_(s32) = G_ADD [[UV55]], [[UV119]] ++- ; CHECK-NEXT: [[ADD52:%[0-9]+]]:_(s32) = G_ADD [[UV56]], [[UV120]] ++- ; CHECK-NEXT: [[ADD53:%[0-9]+]]:_(s32) = G_ADD [[UV57]], [[UV121]] ++- ; CHECK-NEXT: [[ADD54:%[0-9]+]]:_(s32) = G_ADD [[UV58]], [[UV122]] ++- ; CHECK-NEXT: [[ADD55:%[0-9]+]]:_(s32) = G_ADD [[UV59]], [[UV123]] ++- ; CHECK-NEXT: [[ADD56:%[0-9]+]]:_(s32) = G_ADD [[UV60]], [[UV124]] ++- ; CHECK-NEXT: [[ADD57:%[0-9]+]]:_(s32) = G_ADD [[UV61]], [[UV125]] ++- ; CHECK-NEXT: [[ADD58:%[0-9]+]]:_(s32) = G_ADD [[UV62]], [[UV126]] ++- ; CHECK-NEXT: [[ADD59:%[0-9]+]]:_(s32) = G_ADD [[UV63]], [[UV127]] ++- ; CHECK-NEXT: [[ADD60:%[0-9]+]]:_(s32) = G_ADD [[UV64]], [[UV128]] ++- ; CHECK-NEXT: [[ADD61:%[0-9]+]]:_(s32) = G_ADD [[UV65]], [[UV129]] ++- ; CHECK-NEXT: [[ADD62:%[0-9]+]]:_(s32) = G_ADD [[UV66]], [[UV130]] ++- ; CHECK-NEXT: [[ADD63:%[0-9]+]]:_(s32) = G_ADD [[UV67]], [[UV131]] +++ ; CHECK-NEXT: [[UV:%[0-9]+]]:_(s32), [[UV1:%[0-9]+]]:_(s32), [[UV2:%[0-9]+]]:_(s32), [[UV3:%[0-9]+]]:_(s32), [[UV4:%[0-9]+]]:_(s32), [[UV5:%[0-9]+]]:_(s32), [[UV6:%[0-9]+]]:_(s32), [[UV7:%[0-9]+]]:_(s32), [[UV8:%[0-9]+]]:_(s32), [[UV9:%[0-9]+]]:_(s32), [[UV10:%[0-9]+]]:_(s32), [[UV11:%[0-9]+]]:_(s32), [[UV12:%[0-9]+]]:_(s32), [[UV13:%[0-9]+]]:_(s32), [[UV14:%[0-9]+]]:_(s32), [[UV15:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) +++ ; CHECK-NEXT: [[UV16:%[0-9]+]]:_(s32), [[UV17:%[0-9]+]]:_(s32), [[UV18:%[0-9]+]]:_(s32), [[UV19:%[0-9]+]]:_(s32), [[UV20:%[0-9]+]]:_(s32), [[UV21:%[0-9]+]]:_(s32), [[UV22:%[0-9]+]]:_(s32), [[UV23:%[0-9]+]]:_(s32), [[UV24:%[0-9]+]]:_(s32), [[UV25:%[0-9]+]]:_(s32), [[UV26:%[0-9]+]]:_(s32), [[UV27:%[0-9]+]]:_(s32), [[UV28:%[0-9]+]]:_(s32), [[UV29:%[0-9]+]]:_(s32), [[UV30:%[0-9]+]]:_(s32), [[UV31:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) +++ ; CHECK-NEXT: [[UV32:%[0-9]+]]:_(s32), [[UV33:%[0-9]+]]:_(s32), [[UV34:%[0-9]+]]:_(s32), [[UV35:%[0-9]+]]:_(s32), [[UV36:%[0-9]+]]:_(s32), [[UV37:%[0-9]+]]:_(s32), [[UV38:%[0-9]+]]:_(s32), [[UV39:%[0-9]+]]:_(s32), [[UV40:%[0-9]+]]:_(s32), [[UV41:%[0-9]+]]:_(s32), [[UV42:%[0-9]+]]:_(s32), [[UV43:%[0-9]+]]:_(s32), [[UV44:%[0-9]+]]:_(s32), [[UV45:%[0-9]+]]:_(s32), [[UV46:%[0-9]+]]:_(s32), [[UV47:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) +++ ; CHECK-NEXT: [[UV48:%[0-9]+]]:_(s32), [[UV49:%[0-9]+]]:_(s32), [[UV50:%[0-9]+]]:_(s32), [[UV51:%[0-9]+]]:_(s32), [[UV52:%[0-9]+]]:_(s32), [[UV53:%[0-9]+]]:_(s32), [[UV54:%[0-9]+]]:_(s32), [[UV55:%[0-9]+]]:_(s32), [[UV56:%[0-9]+]]:_(s32), [[UV57:%[0-9]+]]:_(s32), [[UV58:%[0-9]+]]:_(s32), [[UV59:%[0-9]+]]:_(s32), [[UV60:%[0-9]+]]:_(s32), [[UV61:%[0-9]+]]:_(s32), [[UV62:%[0-9]+]]:_(s32), [[UV63:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) +++ ; CHECK-NEXT: [[UV64:%[0-9]+]]:_(s32), [[UV65:%[0-9]+]]:_(s32), [[UV66:%[0-9]+]]:_(s32), [[UV67:%[0-9]+]]:_(s32), [[UV68:%[0-9]+]]:_(s32), [[UV69:%[0-9]+]]:_(s32), [[UV70:%[0-9]+]]:_(s32), [[UV71:%[0-9]+]]:_(s32), [[UV72:%[0-9]+]]:_(s32), [[UV73:%[0-9]+]]:_(s32), [[UV74:%[0-9]+]]:_(s32), [[UV75:%[0-9]+]]:_(s32), [[UV76:%[0-9]+]]:_(s32), [[UV77:%[0-9]+]]:_(s32), [[UV78:%[0-9]+]]:_(s32), [[UV79:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) +++ ; CHECK-NEXT: [[UV80:%[0-9]+]]:_(s32), [[UV81:%[0-9]+]]:_(s32), [[UV82:%[0-9]+]]:_(s32), [[UV83:%[0-9]+]]:_(s32), [[UV84:%[0-9]+]]:_(s32), [[UV85:%[0-9]+]]:_(s32), [[UV86:%[0-9]+]]:_(s32), [[UV87:%[0-9]+]]:_(s32), [[UV88:%[0-9]+]]:_(s32), [[UV89:%[0-9]+]]:_(s32), [[UV90:%[0-9]+]]:_(s32), [[UV91:%[0-9]+]]:_(s32), [[UV92:%[0-9]+]]:_(s32), [[UV93:%[0-9]+]]:_(s32), [[UV94:%[0-9]+]]:_(s32), [[UV95:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) +++ ; CHECK-NEXT: [[UV96:%[0-9]+]]:_(s32), [[UV97:%[0-9]+]]:_(s32), [[UV98:%[0-9]+]]:_(s32), [[UV99:%[0-9]+]]:_(s32), [[UV100:%[0-9]+]]:_(s32), [[UV101:%[0-9]+]]:_(s32), [[UV102:%[0-9]+]]:_(s32), [[UV103:%[0-9]+]]:_(s32), [[UV104:%[0-9]+]]:_(s32), [[UV105:%[0-9]+]]:_(s32), [[UV106:%[0-9]+]]:_(s32), [[UV107:%[0-9]+]]:_(s32), [[UV108:%[0-9]+]]:_(s32), [[UV109:%[0-9]+]]:_(s32), [[UV110:%[0-9]+]]:_(s32), [[UV111:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) +++ ; CHECK-NEXT: [[UV112:%[0-9]+]]:_(s32), [[UV113:%[0-9]+]]:_(s32), [[UV114:%[0-9]+]]:_(s32), [[UV115:%[0-9]+]]:_(s32), [[UV116:%[0-9]+]]:_(s32), [[UV117:%[0-9]+]]:_(s32), [[UV118:%[0-9]+]]:_(s32), [[UV119:%[0-9]+]]:_(s32), [[UV120:%[0-9]+]]:_(s32), [[UV121:%[0-9]+]]:_(s32), [[UV122:%[0-9]+]]:_(s32), [[UV123:%[0-9]+]]:_(s32), [[UV124:%[0-9]+]]:_(s32), [[UV125:%[0-9]+]]:_(s32), [[UV126:%[0-9]+]]:_(s32), [[UV127:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) +++ ; CHECK-NEXT: [[ADD:%[0-9]+]]:_(s32) = G_ADD [[UV]], [[UV64]] +++ ; CHECK-NEXT: [[ADD1:%[0-9]+]]:_(s32) = G_ADD [[UV1]], [[UV65]] +++ ; CHECK-NEXT: [[ADD2:%[0-9]+]]:_(s32) = G_ADD [[UV2]], [[UV66]] +++ ; CHECK-NEXT: [[ADD3:%[0-9]+]]:_(s32) = G_ADD [[UV3]], [[UV67]] +++ ; CHECK-NEXT: [[ADD4:%[0-9]+]]:_(s32) = G_ADD [[UV4]], [[UV68]] +++ ; CHECK-NEXT: [[ADD5:%[0-9]+]]:_(s32) = G_ADD [[UV5]], [[UV69]] +++ ; CHECK-NEXT: [[ADD6:%[0-9]+]]:_(s32) = G_ADD [[UV6]], [[UV70]] +++ ; CHECK-NEXT: [[ADD7:%[0-9]+]]:_(s32) = G_ADD [[UV7]], [[UV71]] +++ ; CHECK-NEXT: [[ADD8:%[0-9]+]]:_(s32) = G_ADD [[UV8]], [[UV72]] +++ ; CHECK-NEXT: [[ADD9:%[0-9]+]]:_(s32) = G_ADD [[UV9]], [[UV73]] +++ ; CHECK-NEXT: [[ADD10:%[0-9]+]]:_(s32) = G_ADD [[UV10]], [[UV74]] +++ ; CHECK-NEXT: [[ADD11:%[0-9]+]]:_(s32) = G_ADD [[UV11]], [[UV75]] +++ ; CHECK-NEXT: [[ADD12:%[0-9]+]]:_(s32) = G_ADD [[UV12]], [[UV76]] +++ ; CHECK-NEXT: [[ADD13:%[0-9]+]]:_(s32) = G_ADD [[UV13]], [[UV77]] +++ ; CHECK-NEXT: [[ADD14:%[0-9]+]]:_(s32) = G_ADD [[UV14]], [[UV78]] +++ ; CHECK-NEXT: [[ADD15:%[0-9]+]]:_(s32) = G_ADD [[UV15]], [[UV79]] +++ ; CHECK-NEXT: [[ADD16:%[0-9]+]]:_(s32) = G_ADD [[UV16]], [[UV80]] +++ ; CHECK-NEXT: [[ADD17:%[0-9]+]]:_(s32) = G_ADD [[UV17]], [[UV81]] +++ ; CHECK-NEXT: [[ADD18:%[0-9]+]]:_(s32) = G_ADD [[UV18]], [[UV82]] +++ ; CHECK-NEXT: [[ADD19:%[0-9]+]]:_(s32) = G_ADD [[UV19]], [[UV83]] +++ ; CHECK-NEXT: [[ADD20:%[0-9]+]]:_(s32) = G_ADD [[UV20]], [[UV84]] +++ ; CHECK-NEXT: [[ADD21:%[0-9]+]]:_(s32) = G_ADD [[UV21]], [[UV85]] +++ ; CHECK-NEXT: [[ADD22:%[0-9]+]]:_(s32) = G_ADD [[UV22]], [[UV86]] +++ ; CHECK-NEXT: [[ADD23:%[0-9]+]]:_(s32) = G_ADD [[UV23]], [[UV87]] +++ ; CHECK-NEXT: [[ADD24:%[0-9]+]]:_(s32) = G_ADD [[UV24]], [[UV88]] +++ ; CHECK-NEXT: [[ADD25:%[0-9]+]]:_(s32) = G_ADD [[UV25]], [[UV89]] +++ ; CHECK-NEXT: [[ADD26:%[0-9]+]]:_(s32) = G_ADD [[UV26]], [[UV90]] +++ ; CHECK-NEXT: [[ADD27:%[0-9]+]]:_(s32) = G_ADD [[UV27]], [[UV91]] +++ ; CHECK-NEXT: [[ADD28:%[0-9]+]]:_(s32) = G_ADD [[UV28]], [[UV92]] +++ ; CHECK-NEXT: [[ADD29:%[0-9]+]]:_(s32) = G_ADD [[UV29]], [[UV93]] +++ ; CHECK-NEXT: [[ADD30:%[0-9]+]]:_(s32) = G_ADD [[UV30]], [[UV94]] +++ ; CHECK-NEXT: [[ADD31:%[0-9]+]]:_(s32) = G_ADD [[UV31]], [[UV95]] +++ ; CHECK-NEXT: [[ADD32:%[0-9]+]]:_(s32) = G_ADD [[UV32]], [[UV96]] +++ ; CHECK-NEXT: [[ADD33:%[0-9]+]]:_(s32) = G_ADD [[UV33]], [[UV97]] +++ ; CHECK-NEXT: [[ADD34:%[0-9]+]]:_(s32) = G_ADD [[UV34]], [[UV98]] +++ ; CHECK-NEXT: [[ADD35:%[0-9]+]]:_(s32) = G_ADD [[UV35]], [[UV99]] +++ ; CHECK-NEXT: [[ADD36:%[0-9]+]]:_(s32) = G_ADD [[UV36]], [[UV100]] +++ ; CHECK-NEXT: [[ADD37:%[0-9]+]]:_(s32) = G_ADD [[UV37]], [[UV101]] +++ ; CHECK-NEXT: [[ADD38:%[0-9]+]]:_(s32) = G_ADD [[UV38]], [[UV102]] +++ ; CHECK-NEXT: [[ADD39:%[0-9]+]]:_(s32) = G_ADD [[UV39]], [[UV103]] +++ ; CHECK-NEXT: [[ADD40:%[0-9]+]]:_(s32) = G_ADD [[UV40]], [[UV104]] +++ ; CHECK-NEXT: [[ADD41:%[0-9]+]]:_(s32) = G_ADD [[UV41]], [[UV105]] +++ ; CHECK-NEXT: [[ADD42:%[0-9]+]]:_(s32) = G_ADD [[UV42]], [[UV106]] +++ ; CHECK-NEXT: [[ADD43:%[0-9]+]]:_(s32) = G_ADD [[UV43]], [[UV107]] +++ ; CHECK-NEXT: [[ADD44:%[0-9]+]]:_(s32) = G_ADD [[UV44]], [[UV108]] +++ ; CHECK-NEXT: [[ADD45:%[0-9]+]]:_(s32) = G_ADD [[UV45]], [[UV109]] +++ ; CHECK-NEXT: [[ADD46:%[0-9]+]]:_(s32) = G_ADD [[UV46]], [[UV110]] +++ ; CHECK-NEXT: [[ADD47:%[0-9]+]]:_(s32) = G_ADD [[UV47]], [[UV111]] +++ ; CHECK-NEXT: [[ADD48:%[0-9]+]]:_(s32) = G_ADD [[UV48]], [[UV112]] +++ ; CHECK-NEXT: [[ADD49:%[0-9]+]]:_(s32) = G_ADD [[UV49]], [[UV113]] +++ ; CHECK-NEXT: [[ADD50:%[0-9]+]]:_(s32) = G_ADD [[UV50]], [[UV114]] +++ ; CHECK-NEXT: [[ADD51:%[0-9]+]]:_(s32) = G_ADD [[UV51]], [[UV115]] +++ ; CHECK-NEXT: [[ADD52:%[0-9]+]]:_(s32) = G_ADD [[UV52]], [[UV116]] +++ ; CHECK-NEXT: [[ADD53:%[0-9]+]]:_(s32) = G_ADD [[UV53]], [[UV117]] +++ ; CHECK-NEXT: [[ADD54:%[0-9]+]]:_(s32) = G_ADD [[UV54]], [[UV118]] +++ ; CHECK-NEXT: [[ADD55:%[0-9]+]]:_(s32) = G_ADD [[UV55]], [[UV119]] +++ ; CHECK-NEXT: [[ADD56:%[0-9]+]]:_(s32) = G_ADD [[UV56]], [[UV120]] +++ ; CHECK-NEXT: [[ADD57:%[0-9]+]]:_(s32) = G_ADD [[UV57]], [[UV121]] +++ ; CHECK-NEXT: [[ADD58:%[0-9]+]]:_(s32) = G_ADD [[UV58]], [[UV122]] +++ ; CHECK-NEXT: [[ADD59:%[0-9]+]]:_(s32) = G_ADD [[UV59]], [[UV123]] +++ ; CHECK-NEXT: [[ADD60:%[0-9]+]]:_(s32) = G_ADD [[UV60]], [[UV124]] +++ ; CHECK-NEXT: [[ADD61:%[0-9]+]]:_(s32) = G_ADD [[UV61]], [[UV125]] +++ ; CHECK-NEXT: [[ADD62:%[0-9]+]]:_(s32) = G_ADD [[UV62]], [[UV126]] +++ ; CHECK-NEXT: [[ADD63:%[0-9]+]]:_(s32) = G_ADD [[UV63]], [[UV127]] ++ ; CHECK-NEXT: [[BUILD_VECTOR:%[0-9]+]]:_(<16 x s32>) = G_BUILD_VECTOR [[ADD]](s32), [[ADD1]](s32), [[ADD2]](s32), [[ADD3]](s32), [[ADD4]](s32), [[ADD5]](s32), [[ADD6]](s32), [[ADD7]](s32), [[ADD8]](s32), [[ADD9]](s32), [[ADD10]](s32), [[ADD11]](s32), [[ADD12]](s32), [[ADD13]](s32), [[ADD14]](s32), [[ADD15]](s32) ++ ; CHECK-NEXT: [[BUILD_VECTOR1:%[0-9]+]]:_(<16 x s32>) = G_BUILD_VECTOR [[ADD16]](s32), [[ADD17]](s32), [[ADD18]](s32), [[ADD19]](s32), [[ADD20]](s32), [[ADD21]](s32), [[ADD22]](s32), [[ADD23]](s32), [[ADD24]](s32), [[ADD25]](s32), [[ADD26]](s32), [[ADD27]](s32), [[ADD28]](s32), [[ADD29]](s32), [[ADD30]](s32), [[ADD31]](s32) ++ ; CHECK-NEXT: [[BUILD_VECTOR2:%[0-9]+]]:_(<16 x s32>) = G_BUILD_VECTOR [[ADD32]](s32), [[ADD33]](s32), [[ADD34]](s32), [[ADD35]](s32), [[ADD36]](s32), [[ADD37]](s32), [[ADD38]](s32), [[ADD39]](s32), [[ADD40]](s32), [[ADD41]](s32), [[ADD42]](s32), [[ADD43]](s32), [[ADD44]](s32), [[ADD45]](s32), [[ADD46]](s32), [[ADD47]](s32) ++@@ -760,10 +762,10 @@ ++ ; CHECK-NEXT: G_BR %bb.2 ++ ; CHECK-NEXT: {{ $}} ++ ; CHECK-NEXT: bb.2: ++- ; CHECK-NEXT: [[PHI:%[0-9]+]]:_(<16 x s32>) = G_PHI [[UV]](<16 x s32>), %bb.0, [[BUILD_VECTOR]](<16 x s32>), %bb.1 ++- ; CHECK-NEXT: [[PHI1:%[0-9]+]]:_(<16 x s32>) = G_PHI [[UV1]](<16 x s32>), %bb.0, [[BUILD_VECTOR1]](<16 x s32>), %bb.1 ++- ; CHECK-NEXT: [[PHI2:%[0-9]+]]:_(<16 x s32>) = G_PHI [[UV2]](<16 x s32>), %bb.0, [[BUILD_VECTOR2]](<16 x s32>), %bb.1 ++- ; CHECK-NEXT: [[PHI3:%[0-9]+]]:_(<16 x s32>) = G_PHI [[UV3]](<16 x s32>), %bb.0, [[BUILD_VECTOR3]](<16 x s32>), %bb.1 +++ ; CHECK-NEXT: [[PHI:%[0-9]+]]:_(<16 x s32>) = G_PHI [[DEF]](<16 x s32>), %bb.0, [[BUILD_VECTOR]](<16 x s32>), %bb.1 +++ ; CHECK-NEXT: [[PHI1:%[0-9]+]]:_(<16 x s32>) = G_PHI [[DEF]](<16 x s32>), %bb.0, [[BUILD_VECTOR1]](<16 x s32>), %bb.1 +++ ; CHECK-NEXT: [[PHI2:%[0-9]+]]:_(<16 x s32>) = G_PHI [[DEF]](<16 x s32>), %bb.0, [[BUILD_VECTOR2]](<16 x s32>), %bb.1 +++ ; CHECK-NEXT: [[PHI3:%[0-9]+]]:_(<16 x s32>) = G_PHI [[DEF]](<16 x s32>), %bb.0, [[BUILD_VECTOR3]](<16 x s32>), %bb.1 ++ ; CHECK-NEXT: [[CONCAT_VECTORS:%[0-9]+]]:_(<64 x s32>) = G_CONCAT_VECTORS [[PHI]](<16 x s32>), [[PHI1]](<16 x s32>), [[PHI2]](<16 x s32>), [[PHI3]](<16 x s32>) ++ ; CHECK-NEXT: S_SETPC_B64 undef $sgpr30_sgpr31, implicit [[CONCAT_VECTORS]](<64 x s32>) ++ bb.0: ++diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/AMDGPU/GlobalISel/regbankselect.mir b/llvm/test/CodeGen/AMDGPU/GlobalISel/regbankselect.mir ++--- a/llvm/test/CodeGen/AMDGPU/GlobalISel/regbankselect.mir +++++ b/llvm/test/CodeGen/AMDGPU/GlobalISel/regbankselect.mir ++@@ -42,6 +42,8 @@ ++ ret void ++ } ++ +++ define void @non_power_of_2() { ret void } +++ ++ define amdgpu_kernel void @load_constant_v4i16_from_8_align8(ptr addrspace(4) %ptr0) { ++ ret void ++ } ++@@ -185,6 +187,23 @@ ++ ... ++ ++ --- +++name: non_power_of_2 +++legalized: true +++ +++body: | +++ bb.0: +++ ; CHECK-LABEL: name: non_power_of_2 +++ ; CHECK: [[DEF:%[0-9]+]]:sgpr(s448) = G_IMPLICIT_DEF +++ ; CHECK-NEXT: [[EXTRACT:%[0-9]+]]:sgpr(s32) = G_EXTRACT [[DEF]](s448), 0 +++ ; CHECK-NEXT: $sgpr0 = COPY [[EXTRACT]](s32) +++ ; CHECK-NEXT: SI_RETURN_TO_EPILOG $sgpr0 +++ %0:_(s448) = G_IMPLICIT_DEF +++ %1:_(s32) = G_EXTRACT %0:_(s448), 0 +++ $sgpr0 = COPY %1:_(s32) +++ SI_RETURN_TO_EPILOG $sgpr0 +++... +++ +++--- ++ name: load_constant_v4i16_from_8_align8 ++ legalized: true ++ ++diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/NVPTX/intrin-nocapture.ll b/llvm/test/CodeGen/NVPTX/intrin-nocapture.ll ++--- a/llvm/test/CodeGen/NVPTX/intrin-nocapture.ll +++++ b/llvm/test/CodeGen/NVPTX/intrin-nocapture.ll ++@@ -0,0 +1,21 @@ +++; RUN: opt < %s -O3 -S | FileCheck %s +++ +++; Address space intrinsics were erroneously marked NoCapture, leading to bad +++; optimizations (such as the store below being eliminated as dead code). This +++; test makes sure we don't regress. +++ +++declare void @foo(ptr addrspace(1)) +++ +++declare ptr addrspace(1) @llvm.nvvm.ptr.gen.to.global.p1.p0(ptr) +++ +++; CHECK: @bar +++define void @bar() { +++ %t1 = alloca i32 +++; CHECK: call ptr addrspace(1) @llvm.nvvm.ptr.gen.to.global.p1.p0(ptr nonnull %t1) +++; CHECK-NEXT: store i32 10, ptr %t1 +++ %t2 = call ptr addrspace(1) @llvm.nvvm.ptr.gen.to.global.p1.p0(ptr %t1) +++ store i32 10, ptr %t1 +++ call void @foo(ptr addrspace(1) %t2) +++ ret void +++} +++ ++diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/NVPTX/rotate_64.ll b/llvm/test/CodeGen/NVPTX/rotate_64.ll ++--- a/llvm/test/CodeGen/NVPTX/rotate_64.ll +++++ b/llvm/test/CodeGen/NVPTX/rotate_64.ll ++@@ -1,38 +1,25 @@ ++-; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5 ++ ; RUN: llc < %s -march=nvptx64 | FileCheck %s ++ ; RUN: %if ptxas %{ llc < %s -march=nvptx64 | %ptxas-verify %} ++ ++ declare i64 @llvm.nvvm.rotate.b64(i64, i32) ++ declare i64 @llvm.nvvm.rotate.right.b64(i64, i32) ++ +++; CHECK: rotate64 ++ define i64 @rotate64(i64 %a, i32 %b) { ++-; CHECK-LABEL: rotate64( ++-; CHECK: { ++-; CHECK-NEXT: .reg .b64 %rd<5>; ++-; CHECK-EMPTY: ++-; CHECK-NEXT: // %bb.0: ++-; CHECK-NEXT: ld.param.u64 %rd1, [rotate64_param_0]; ++-; CHECK-NEXT: shr.u64 %rd2, %rd1, 61; ++-; CHECK-NEXT: shl.b64 %rd3, %rd1, 3; ++-; CHECK-NEXT: or.b64 %rd4, %rd3, %rd2; ++-; CHECK-NEXT: st.param.b64 [func_retval0+0], %rd4; ++-; CHECK-NEXT: ret; +++; CHECK: shl.b64 [[LHS:%.*]], [[RD1:%.*]], 3; +++; CHECK: shr.b64 [[RHS:%.*]], [[RD1]], 61; +++; CHECK: add.u64 [[RD2:%.*]], [[LHS]], [[RHS]]; +++; CHECK: ret ++ %val = tail call i64 @llvm.nvvm.rotate.b64(i64 %a, i32 3) ++ ret i64 %val ++ } ++ +++; CHECK: rotateright64 ++ define i64 @rotateright64(i64 %a, i32 %b) { ++-; CHECK-LABEL: rotateright64( ++-; CHECK: { ++-; CHECK-NEXT: .reg .b64 %rd<5>; ++-; CHECK-EMPTY: ++-; CHECK-NEXT: // %bb.0: ++-; CHECK-NEXT: ld.param.u64 %rd1, [rotateright64_param_0]; ++-; CHECK-NEXT: shl.b64 %rd2, %rd1, 61; ++-; CHECK-NEXT: shr.u64 %rd3, %rd1, 3; ++-; CHECK-NEXT: or.b64 %rd4, %rd3, %rd2; ++-; CHECK-NEXT: st.param.b64 [func_retval0+0], %rd4; ++-; CHECK-NEXT: ret; +++; CHECK: shl.b64 [[LHS:%.*]], [[RD1:%.*]], 61; +++; CHECK: shr.b64 [[RHS:%.*]], [[RD1]], 3; +++; CHECK: add.u64 [[RD2:%.*]], [[LHS]], [[RHS]]; +++; CHECK: ret ++ %val = tail call i64 @llvm.nvvm.rotate.right.b64(i64 %a, i32 3) ++ ret i64 %val ++ } ++diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/NVPTX/rotate.ll b/llvm/test/CodeGen/NVPTX/rotate.ll ++--- a/llvm/test/CodeGen/NVPTX/rotate.ll +++++ b/llvm/test/CodeGen/NVPTX/rotate.ll ++@@ -9,29 +9,26 @@ ++ declare i64 @llvm.nvvm.rotate.b64(i64, i32) ++ declare i64 @llvm.nvvm.rotate.right.b64(i64, i32) ++ ++-declare i64 @llvm.fshl.i64(i64, i64, i64) ++-declare i64 @llvm.fshr.i64(i64, i64, i64) ++-declare i32 @llvm.fshl.i32(i32, i32, i32) ++-declare i32 @llvm.fshr.i32(i32, i32, i32) ++- ++- ++ ; SM20: rotate32 ++ ; SM35: rotate32 ++ define i32 @rotate32(i32 %a, i32 %b) { ++ ; SM20-LABEL: rotate32( ++ ; SM20: { ++-; SM20-NEXT: .reg .b32 %r<9>; +++; SM20-NEXT: .reg .b32 %r<4>; ++ ; SM20-EMPTY: ++ ; SM20-NEXT: // %bb.0: ++ ; SM20-NEXT: ld.param.u32 %r1, [rotate32_param_0]; ++ ; SM20-NEXT: ld.param.u32 %r2, [rotate32_param_1]; ++-; SM20-NEXT: and.b32 %r3, %r2, 31; ++-; SM20-NEXT: shl.b32 %r4, %r1, %r3; ++-; SM20-NEXT: neg.s32 %r5, %r2; ++-; SM20-NEXT: and.b32 %r6, %r5, 31; ++-; SM20-NEXT: shr.u32 %r7, %r1, %r6; ++-; SM20-NEXT: or.b32 %r8, %r4, %r7; ++-; SM20-NEXT: st.param.b32 [func_retval0+0], %r8; +++; SM20-NEXT: { +++; SM20-NEXT: .reg .b32 %lhs; +++; SM20-NEXT: .reg .b32 %rhs; +++; SM20-NEXT: .reg .b32 %amt2; +++; SM20-NEXT: shl.b32 %lhs, %r1, %r2; +++; SM20-NEXT: sub.s32 %amt2, 32, %r2; +++; SM20-NEXT: shr.b32 %rhs, %r1, %amt2; +++; SM20-NEXT: add.u32 %r3, %lhs, %rhs; +++; SM20-NEXT: } +++; SM20-NEXT: st.param.b32 [func_retval0+0], %r3; ++ ; SM20-NEXT: ret; ++ ; ++ ; SM35-LABEL: rotate32( ++@@ -53,36 +50,45 @@ ++ define i64 @rotate64(i64 %a, i32 %b) { ++ ; SM20-LABEL: rotate64( ++ ; SM20: { ++-; SM20-NEXT: .reg .b32 %r<5>; ++-; SM20-NEXT: .reg .b64 %rd<5>; +++; SM20-NEXT: .reg .b32 %r<2>; +++; SM20-NEXT: .reg .b64 %rd<3>; ++ ; SM20-EMPTY: ++ ; SM20-NEXT: // %bb.0: ++ ; SM20-NEXT: ld.param.u64 %rd1, [rotate64_param_0]; ++ ; SM20-NEXT: ld.param.u32 %r1, [rotate64_param_1]; ++-; SM20-NEXT: and.b32 %r2, %r1, 63; ++-; SM20-NEXT: shl.b64 %rd2, %rd1, %r2; ++-; SM20-NEXT: neg.s32 %r3, %r1; ++-; SM20-NEXT: and.b32 %r4, %r3, 63; ++-; SM20-NEXT: shr.u64 %rd3, %rd1, %r4; ++-; SM20-NEXT: or.b64 %rd4, %rd2, %rd3; ++-; SM20-NEXT: st.param.b64 [func_retval0+0], %rd4; +++; SM20-NEXT: { +++; SM20-NEXT: .reg .b64 %lhs; +++; SM20-NEXT: .reg .b64 %rhs; +++; SM20-NEXT: .reg .u32 %amt2; +++; SM20-NEXT: and.b32 %amt2, %r1, 63; +++; SM20-NEXT: shl.b64 %lhs, %rd1, %amt2; +++; SM20-NEXT: sub.u32 %amt2, 64, %amt2; +++; SM20-NEXT: shr.b64 %rhs, %rd1, %amt2; +++; SM20-NEXT: add.u64 %rd2, %lhs, %rhs; +++; SM20-NEXT: } +++; SM20-NEXT: st.param.b64 [func_retval0+0], %rd2; ++ ; SM20-NEXT: ret; ++ ; ++ ; SM35-LABEL: rotate64( ++ ; SM35: { ++-; SM35-NEXT: .reg .b32 %r<5>; ++-; SM35-NEXT: .reg .b64 %rd<5>; +++; SM35-NEXT: .reg .b32 %r<6>; +++; SM35-NEXT: .reg .b64 %rd<3>; ++ ; SM35-EMPTY: ++ ; SM35-NEXT: // %bb.0: ++ ; SM35-NEXT: ld.param.u64 %rd1, [rotate64_param_0]; ++-; SM35-NEXT: ld.param.u32 %r1, [rotate64_param_1]; ++-; SM35-NEXT: and.b32 %r2, %r1, 63; ++-; SM35-NEXT: shl.b64 %rd2, %rd1, %r2; ++-; SM35-NEXT: neg.s32 %r3, %r1; ++-; SM35-NEXT: and.b32 %r4, %r3, 63; ++-; SM35-NEXT: shr.u64 %rd3, %rd1, %r4; ++-; SM35-NEXT: or.b64 %rd4, %rd2, %rd3; ++-; SM35-NEXT: st.param.b64 [func_retval0+0], %rd4; +++; SM35-NEXT: { +++; SM35-NEXT: .reg .b32 %dummy; +++; SM35-NEXT: mov.b64 {%dummy,%r1}, %rd1; +++; SM35-NEXT: } +++; SM35-NEXT: { +++; SM35-NEXT: .reg .b32 %dummy; +++; SM35-NEXT: mov.b64 {%r2,%dummy}, %rd1; +++; SM35-NEXT: } +++; SM35-NEXT: ld.param.u32 %r3, [rotate64_param_1]; +++; SM35-NEXT: shf.l.wrap.b32 %r4, %r2, %r1, %r3; +++; SM35-NEXT: shf.l.wrap.b32 %r5, %r1, %r2, %r3; +++; SM35-NEXT: mov.b64 %rd2, {%r5, %r4}; +++; SM35-NEXT: st.param.b64 [func_retval0+0], %rd2; ++ ; SM35-NEXT: ret; ++ %val = tail call i64 @llvm.nvvm.rotate.b64(i64 %a, i32 %b) ++ ret i64 %val ++@@ -93,36 +99,45 @@ ++ define i64 @rotateright64(i64 %a, i32 %b) { ++ ; SM20-LABEL: rotateright64( ++ ; SM20: { ++-; SM20-NEXT: .reg .b32 %r<5>; ++-; SM20-NEXT: .reg .b64 %rd<5>; +++; SM20-NEXT: .reg .b32 %r<2>; +++; SM20-NEXT: .reg .b64 %rd<3>; ++ ; SM20-EMPTY: ++ ; SM20-NEXT: // %bb.0: ++ ; SM20-NEXT: ld.param.u64 %rd1, [rotateright64_param_0]; ++ ; SM20-NEXT: ld.param.u32 %r1, [rotateright64_param_1]; ++-; SM20-NEXT: and.b32 %r2, %r1, 63; ++-; SM20-NEXT: shr.u64 %rd2, %rd1, %r2; ++-; SM20-NEXT: neg.s32 %r3, %r1; ++-; SM20-NEXT: and.b32 %r4, %r3, 63; ++-; SM20-NEXT: shl.b64 %rd3, %rd1, %r4; ++-; SM20-NEXT: or.b64 %rd4, %rd2, %rd3; ++-; SM20-NEXT: st.param.b64 [func_retval0+0], %rd4; +++; SM20-NEXT: { +++; SM20-NEXT: .reg .b64 %lhs; +++; SM20-NEXT: .reg .b64 %rhs; +++; SM20-NEXT: .reg .u32 %amt2; +++; SM20-NEXT: and.b32 %amt2, %r1, 63; +++; SM20-NEXT: shr.b64 %lhs, %rd1, %amt2; +++; SM20-NEXT: sub.u32 %amt2, 64, %amt2; +++; SM20-NEXT: shl.b64 %rhs, %rd1, %amt2; +++; SM20-NEXT: add.u64 %rd2, %lhs, %rhs; +++; SM20-NEXT: } +++; SM20-NEXT: st.param.b64 [func_retval0+0], %rd2; ++ ; SM20-NEXT: ret; ++ ; ++ ; SM35-LABEL: rotateright64( ++ ; SM35: { ++-; SM35-NEXT: .reg .b32 %r<5>; ++-; SM35-NEXT: .reg .b64 %rd<5>; +++; SM35-NEXT: .reg .b32 %r<6>; +++; SM35-NEXT: .reg .b64 %rd<3>; ++ ; SM35-EMPTY: ++ ; SM35-NEXT: // %bb.0: ++ ; SM35-NEXT: ld.param.u64 %rd1, [rotateright64_param_0]; ++-; SM35-NEXT: ld.param.u32 %r1, [rotateright64_param_1]; ++-; SM35-NEXT: and.b32 %r2, %r1, 63; ++-; SM35-NEXT: shr.u64 %rd2, %rd1, %r2; ++-; SM35-NEXT: neg.s32 %r3, %r1; ++-; SM35-NEXT: and.b32 %r4, %r3, 63; ++-; SM35-NEXT: shl.b64 %rd3, %rd1, %r4; ++-; SM35-NEXT: or.b64 %rd4, %rd2, %rd3; ++-; SM35-NEXT: st.param.b64 [func_retval0+0], %rd4; +++; SM35-NEXT: { +++; SM35-NEXT: .reg .b32 %dummy; +++; SM35-NEXT: mov.b64 {%r1,%dummy}, %rd1; +++; SM35-NEXT: } +++; SM35-NEXT: { +++; SM35-NEXT: .reg .b32 %dummy; +++; SM35-NEXT: mov.b64 {%dummy,%r2}, %rd1; +++; SM35-NEXT: } +++; SM35-NEXT: ld.param.u32 %r3, [rotateright64_param_1]; +++; SM35-NEXT: shf.r.wrap.b32 %r4, %r2, %r1, %r3; +++; SM35-NEXT: shf.r.wrap.b32 %r5, %r1, %r2, %r3; +++; SM35-NEXT: mov.b64 %rd2, {%r5, %r4}; +++; SM35-NEXT: st.param.b64 [func_retval0+0], %rd2; ++ ; SM35-NEXT: ret; ++ %val = tail call i64 @llvm.nvvm.rotate.right.b64(i64 %a, i32 %b) ++ ret i64 %val ++@@ -133,14 +148,18 @@ ++ define i32 @rotl0(i32 %x) { ++ ; SM20-LABEL: rotl0( ++ ; SM20: { ++-; SM20-NEXT: .reg .b32 %r<5>; +++; SM20-NEXT: .reg .b32 %r<3>; ++ ; SM20-EMPTY: ++ ; SM20-NEXT: // %bb.0: ++ ; SM20-NEXT: ld.param.u32 %r1, [rotl0_param_0]; ++-; SM20-NEXT: shr.u32 %r2, %r1, 24; ++-; SM20-NEXT: shl.b32 %r3, %r1, 8; ++-; SM20-NEXT: or.b32 %r4, %r3, %r2; ++-; SM20-NEXT: st.param.b32 [func_retval0+0], %r4; +++; SM20-NEXT: { +++; SM20-NEXT: .reg .b32 %lhs; +++; SM20-NEXT: .reg .b32 %rhs; +++; SM20-NEXT: shl.b32 %lhs, %r1, 8; +++; SM20-NEXT: shr.b32 %rhs, %r1, 24; +++; SM20-NEXT: add.u32 %r2, %lhs, %rhs; +++; SM20-NEXT: } +++; SM20-NEXT: st.param.b32 [func_retval0+0], %r2; ++ ; SM20-NEXT: ret; ++ ; ++ ; SM35-LABEL: rotl0( ++@@ -158,40 +177,51 @@ ++ ret i32 %t2 ++ } ++ +++declare i64 @llvm.fshl.i64(i64, i64, i64) +++declare i64 @llvm.fshr.i64(i64, i64, i64) +++ ++ ; SM35: rotl64 ++ define i64 @rotl64(i64 %a, i64 %n) { ++ ; SM20-LABEL: rotl64( ++ ; SM20: { ++-; SM20-NEXT: .reg .b32 %r<5>; ++-; SM20-NEXT: .reg .b64 %rd<5>; +++; SM20-NEXT: .reg .b32 %r<2>; +++; SM20-NEXT: .reg .b64 %rd<3>; ++ ; SM20-EMPTY: ++ ; SM20-NEXT: // %bb.0: ++ ; SM20-NEXT: ld.param.u64 %rd1, [rotl64_param_0]; ++ ; SM20-NEXT: ld.param.u32 %r1, [rotl64_param_1]; ++-; SM20-NEXT: and.b32 %r2, %r1, 63; ++-; SM20-NEXT: shl.b64 %rd2, %rd1, %r2; ++-; SM20-NEXT: neg.s32 %r3, %r1; ++-; SM20-NEXT: and.b32 %r4, %r3, 63; ++-; SM20-NEXT: shr.u64 %rd3, %rd1, %r4; ++-; SM20-NEXT: or.b64 %rd4, %rd2, %rd3; ++-; SM20-NEXT: st.param.b64 [func_retval0+0], %rd4; +++; SM20-NEXT: { +++; SM20-NEXT: .reg .b64 %lhs; +++; SM20-NEXT: .reg .b64 %rhs; +++; SM20-NEXT: .reg .u32 %amt2; +++; SM20-NEXT: and.b32 %amt2, %r1, 63; +++; SM20-NEXT: shl.b64 %lhs, %rd1, %amt2; +++; SM20-NEXT: sub.u32 %amt2, 64, %amt2; +++; SM20-NEXT: shr.b64 %rhs, %rd1, %amt2; +++; SM20-NEXT: add.u64 %rd2, %lhs, %rhs; +++; SM20-NEXT: } +++; SM20-NEXT: st.param.b64 [func_retval0+0], %rd2; ++ ; SM20-NEXT: ret; ++ ; ++ ; SM35-LABEL: rotl64( ++ ; SM35: { ++-; SM35-NEXT: .reg .b32 %r<5>; ++-; SM35-NEXT: .reg .b64 %rd<5>; +++; SM35-NEXT: .reg .b32 %r<2>; +++; SM35-NEXT: .reg .b64 %rd<3>; ++ ; SM35-EMPTY: ++ ; SM35-NEXT: // %bb.0: ++ ; SM35-NEXT: ld.param.u64 %rd1, [rotl64_param_0]; ++ ; SM35-NEXT: ld.param.u32 %r1, [rotl64_param_1]; ++-; SM35-NEXT: and.b32 %r2, %r1, 63; ++-; SM35-NEXT: shl.b64 %rd2, %rd1, %r2; ++-; SM35-NEXT: neg.s32 %r3, %r1; ++-; SM35-NEXT: and.b32 %r4, %r3, 63; ++-; SM35-NEXT: shr.u64 %rd3, %rd1, %r4; ++-; SM35-NEXT: or.b64 %rd4, %rd2, %rd3; ++-; SM35-NEXT: st.param.b64 [func_retval0+0], %rd4; +++; SM35-NEXT: { +++; SM35-NEXT: .reg .b64 %lhs; +++; SM35-NEXT: .reg .b64 %rhs; +++; SM35-NEXT: .reg .u32 %amt2; +++; SM35-NEXT: and.b32 %amt2, %r1, 63; +++; SM35-NEXT: shl.b64 %lhs, %rd1, %amt2; +++; SM35-NEXT: sub.u32 %amt2, 64, %amt2; +++; SM35-NEXT: shr.b64 %rhs, %rd1, %amt2; +++; SM35-NEXT: add.u64 %rd2, %lhs, %rhs; +++; SM35-NEXT: } +++; SM35-NEXT: st.param.b64 [func_retval0+0], %rd2; ++ ; SM35-NEXT: ret; ++ %val = tail call i64 @llvm.fshl.i64(i64 %a, i64 %a, i64 %n) ++ ret i64 %val ++@@ -201,26 +231,34 @@ ++ define i64 @rotl64_imm(i64 %a) { ++ ; SM20-LABEL: rotl64_imm( ++ ; SM20: { ++-; SM20-NEXT: .reg .b64 %rd<5>; +++; SM20-NEXT: .reg .b64 %rd<3>; ++ ; SM20-EMPTY: ++ ; SM20-NEXT: // %bb.0: ++ ; SM20-NEXT: ld.param.u64 %rd1, [rotl64_imm_param_0]; ++-; SM20-NEXT: shr.u64 %rd2, %rd1, 62; ++-; SM20-NEXT: shl.b64 %rd3, %rd1, 2; ++-; SM20-NEXT: or.b64 %rd4, %rd3, %rd2; ++-; SM20-NEXT: st.param.b64 [func_retval0+0], %rd4; +++; SM20-NEXT: { +++; SM20-NEXT: .reg .b64 %lhs; +++; SM20-NEXT: .reg .b64 %rhs; +++; SM20-NEXT: shl.b64 %lhs, %rd1, 2; +++; SM20-NEXT: shr.b64 %rhs, %rd1, 62; +++; SM20-NEXT: add.u64 %rd2, %lhs, %rhs; +++; SM20-NEXT: } +++; SM20-NEXT: st.param.b64 [func_retval0+0], %rd2; ++ ; SM20-NEXT: ret; ++ ; ++ ; SM35-LABEL: rotl64_imm( ++ ; SM35: { ++-; SM35-NEXT: .reg .b64 %rd<5>; +++; SM35-NEXT: .reg .b64 %rd<3>; ++ ; SM35-EMPTY: ++ ; SM35-NEXT: // %bb.0: ++ ; SM35-NEXT: ld.param.u64 %rd1, [rotl64_imm_param_0]; ++-; SM35-NEXT: shr.u64 %rd2, %rd1, 62; ++-; SM35-NEXT: shl.b64 %rd3, %rd1, 2; ++-; SM35-NEXT: or.b64 %rd4, %rd3, %rd2; ++-; SM35-NEXT: st.param.b64 [func_retval0+0], %rd4; +++; SM35-NEXT: { +++; SM35-NEXT: .reg .b64 %lhs; +++; SM35-NEXT: .reg .b64 %rhs; +++; SM35-NEXT: shl.b64 %lhs, %rd1, 2; +++; SM35-NEXT: shr.b64 %rhs, %rd1, 62; +++; SM35-NEXT: add.u64 %rd2, %lhs, %rhs; +++; SM35-NEXT: } +++; SM35-NEXT: st.param.b64 [func_retval0+0], %rd2; ++ ; SM35-NEXT: ret; ++ %val = tail call i64 @llvm.fshl.i64(i64 %a, i64 %a, i64 66) ++ ret i64 %val ++@@ -230,36 +268,44 @@ ++ define i64 @rotr64(i64 %a, i64 %n) { ++ ; SM20-LABEL: rotr64( ++ ; SM20: { ++-; SM20-NEXT: .reg .b32 %r<5>; ++-; SM20-NEXT: .reg .b64 %rd<5>; +++; SM20-NEXT: .reg .b32 %r<2>; +++; SM20-NEXT: .reg .b64 %rd<3>; ++ ; SM20-EMPTY: ++ ; SM20-NEXT: // %bb.0: ++ ; SM20-NEXT: ld.param.u64 %rd1, [rotr64_param_0]; ++ ; SM20-NEXT: ld.param.u32 %r1, [rotr64_param_1]; ++-; SM20-NEXT: and.b32 %r2, %r1, 63; ++-; SM20-NEXT: shr.u64 %rd2, %rd1, %r2; ++-; SM20-NEXT: neg.s32 %r3, %r1; ++-; SM20-NEXT: and.b32 %r4, %r3, 63; ++-; SM20-NEXT: shl.b64 %rd3, %rd1, %r4; ++-; SM20-NEXT: or.b64 %rd4, %rd2, %rd3; ++-; SM20-NEXT: st.param.b64 [func_retval0+0], %rd4; +++; SM20-NEXT: { +++; SM20-NEXT: .reg .b64 %lhs; +++; SM20-NEXT: .reg .b64 %rhs; +++; SM20-NEXT: .reg .u32 %amt2; +++; SM20-NEXT: and.b32 %amt2, %r1, 63; +++; SM20-NEXT: shr.b64 %lhs, %rd1, %amt2; +++; SM20-NEXT: sub.u32 %amt2, 64, %amt2; +++; SM20-NEXT: shl.b64 %rhs, %rd1, %amt2; +++; SM20-NEXT: add.u64 %rd2, %lhs, %rhs; +++; SM20-NEXT: } +++; SM20-NEXT: st.param.b64 [func_retval0+0], %rd2; ++ ; SM20-NEXT: ret; ++ ; ++ ; SM35-LABEL: rotr64( ++ ; SM35: { ++-; SM35-NEXT: .reg .b32 %r<5>; ++-; SM35-NEXT: .reg .b64 %rd<5>; +++; SM35-NEXT: .reg .b32 %r<2>; +++; SM35-NEXT: .reg .b64 %rd<3>; ++ ; SM35-EMPTY: ++ ; SM35-NEXT: // %bb.0: ++ ; SM35-NEXT: ld.param.u64 %rd1, [rotr64_param_0]; ++ ; SM35-NEXT: ld.param.u32 %r1, [rotr64_param_1]; ++-; SM35-NEXT: and.b32 %r2, %r1, 63; ++-; SM35-NEXT: shr.u64 %rd2, %rd1, %r2; ++-; SM35-NEXT: neg.s32 %r3, %r1; ++-; SM35-NEXT: and.b32 %r4, %r3, 63; ++-; SM35-NEXT: shl.b64 %rd3, %rd1, %r4; ++-; SM35-NEXT: or.b64 %rd4, %rd2, %rd3; ++-; SM35-NEXT: st.param.b64 [func_retval0+0], %rd4; +++; SM35-NEXT: { +++; SM35-NEXT: .reg .b64 %lhs; +++; SM35-NEXT: .reg .b64 %rhs; +++; SM35-NEXT: .reg .u32 %amt2; +++; SM35-NEXT: and.b32 %amt2, %r1, 63; +++; SM35-NEXT: shr.b64 %lhs, %rd1, %amt2; +++; SM35-NEXT: sub.u32 %amt2, 64, %amt2; +++; SM35-NEXT: shl.b64 %rhs, %rd1, %amt2; +++; SM35-NEXT: add.u64 %rd2, %lhs, %rhs; +++; SM35-NEXT: } +++; SM35-NEXT: st.param.b64 [func_retval0+0], %rd2; ++ ; SM35-NEXT: ret; ++ %val = tail call i64 @llvm.fshr.i64(i64 %a, i64 %a, i64 %n) ++ ret i64 %val ++@@ -269,180 +315,35 @@ ++ define i64 @rotr64_imm(i64 %a) { ++ ; SM20-LABEL: rotr64_imm( ++ ; SM20: { ++-; SM20-NEXT: .reg .b64 %rd<5>; +++; SM20-NEXT: .reg .b64 %rd<3>; ++ ; SM20-EMPTY: ++ ; SM20-NEXT: // %bb.0: ++ ; SM20-NEXT: ld.param.u64 %rd1, [rotr64_imm_param_0]; ++-; SM20-NEXT: shl.b64 %rd2, %rd1, 62; ++-; SM20-NEXT: shr.u64 %rd3, %rd1, 2; ++-; SM20-NEXT: or.b64 %rd4, %rd3, %rd2; ++-; SM20-NEXT: st.param.b64 [func_retval0+0], %rd4; +++; SM20-NEXT: { +++; SM20-NEXT: .reg .b64 %lhs; +++; SM20-NEXT: .reg .b64 %rhs; +++; SM20-NEXT: shl.b64 %lhs, %rd1, 62; +++; SM20-NEXT: shr.b64 %rhs, %rd1, 2; +++; SM20-NEXT: add.u64 %rd2, %lhs, %rhs; +++; SM20-NEXT: } +++; SM20-NEXT: st.param.b64 [func_retval0+0], %rd2; ++ ; SM20-NEXT: ret; ++ ; ++ ; SM35-LABEL: rotr64_imm( ++ ; SM35: { ++-; SM35-NEXT: .reg .b64 %rd<5>; +++; SM35-NEXT: .reg .b64 %rd<3>; ++ ; SM35-EMPTY: ++ ; SM35-NEXT: // %bb.0: ++ ; SM35-NEXT: ld.param.u64 %rd1, [rotr64_imm_param_0]; ++-; SM35-NEXT: shl.b64 %rd2, %rd1, 62; ++-; SM35-NEXT: shr.u64 %rd3, %rd1, 2; ++-; SM35-NEXT: or.b64 %rd4, %rd3, %rd2; ++-; SM35-NEXT: st.param.b64 [func_retval0+0], %rd4; +++; SM35-NEXT: { +++; SM35-NEXT: .reg .b64 %lhs; +++; SM35-NEXT: .reg .b64 %rhs; +++; SM35-NEXT: shl.b64 %lhs, %rd1, 62; +++; SM35-NEXT: shr.b64 %rhs, %rd1, 2; +++; SM35-NEXT: add.u64 %rd2, %lhs, %rhs; +++; SM35-NEXT: } +++; SM35-NEXT: st.param.b64 [func_retval0+0], %rd2; ++ ; SM35-NEXT: ret; ++ %val = tail call i64 @llvm.fshr.i64(i64 %a, i64 %a, i64 66) ++ ret i64 %val ++ } ++- ++-define i32 @funnel_shift_right_32(i32 %a, i32 %b, i32 %c) { ++-; SM20-LABEL: funnel_shift_right_32( ++-; SM20: { ++-; SM20-NEXT: .reg .b32 %r<11>; ++-; SM20-EMPTY: ++-; SM20-NEXT: // %bb.0: ++-; SM20-NEXT: ld.param.u32 %r1, [funnel_shift_right_32_param_0]; ++-; SM20-NEXT: ld.param.u32 %r2, [funnel_shift_right_32_param_2]; ++-; SM20-NEXT: and.b32 %r3, %r2, 31; ++-; SM20-NEXT: ld.param.u32 %r4, [funnel_shift_right_32_param_1]; ++-; SM20-NEXT: shr.u32 %r5, %r4, %r3; ++-; SM20-NEXT: shl.b32 %r6, %r1, 1; ++-; SM20-NEXT: not.b32 %r7, %r2; ++-; SM20-NEXT: and.b32 %r8, %r7, 31; ++-; SM20-NEXT: shl.b32 %r9, %r6, %r8; ++-; SM20-NEXT: or.b32 %r10, %r9, %r5; ++-; SM20-NEXT: st.param.b32 [func_retval0+0], %r10; ++-; SM20-NEXT: ret; ++-; ++-; SM35-LABEL: funnel_shift_right_32( ++-; SM35: { ++-; SM35-NEXT: .reg .b32 %r<5>; ++-; SM35-EMPTY: ++-; SM35-NEXT: // %bb.0: ++-; SM35-NEXT: ld.param.u32 %r1, [funnel_shift_right_32_param_0]; ++-; SM35-NEXT: ld.param.u32 %r2, [funnel_shift_right_32_param_1]; ++-; SM35-NEXT: ld.param.u32 %r3, [funnel_shift_right_32_param_2]; ++-; SM35-NEXT: shf.r.wrap.b32 %r4, %r1, %r2, %r3; ++-; SM35-NEXT: st.param.b32 [func_retval0+0], %r4; ++-; SM35-NEXT: ret; ++- %val = call i32 @llvm.fshr.i32(i32 %a, i32 %b, i32 %c) ++- ret i32 %val ++-} ++- ++-define i32 @funnel_shift_left_32(i32 %a, i32 %b, i32 %c) { ++-; SM20-LABEL: funnel_shift_left_32( ++-; SM20: { ++-; SM20-NEXT: .reg .b32 %r<11>; ++-; SM20-EMPTY: ++-; SM20-NEXT: // %bb.0: ++-; SM20-NEXT: ld.param.u32 %r1, [funnel_shift_left_32_param_0]; ++-; SM20-NEXT: ld.param.u32 %r2, [funnel_shift_left_32_param_2]; ++-; SM20-NEXT: and.b32 %r3, %r2, 31; ++-; SM20-NEXT: shl.b32 %r4, %r1, %r3; ++-; SM20-NEXT: ld.param.u32 %r5, [funnel_shift_left_32_param_1]; ++-; SM20-NEXT: shr.u32 %r6, %r5, 1; ++-; SM20-NEXT: not.b32 %r7, %r2; ++-; SM20-NEXT: and.b32 %r8, %r7, 31; ++-; SM20-NEXT: shr.u32 %r9, %r6, %r8; ++-; SM20-NEXT: or.b32 %r10, %r4, %r9; ++-; SM20-NEXT: st.param.b32 [func_retval0+0], %r10; ++-; SM20-NEXT: ret; ++-; ++-; SM35-LABEL: funnel_shift_left_32( ++-; SM35: { ++-; SM35-NEXT: .reg .b32 %r<5>; ++-; SM35-EMPTY: ++-; SM35-NEXT: // %bb.0: ++-; SM35-NEXT: ld.param.u32 %r1, [funnel_shift_left_32_param_0]; ++-; SM35-NEXT: ld.param.u32 %r2, [funnel_shift_left_32_param_1]; ++-; SM35-NEXT: ld.param.u32 %r3, [funnel_shift_left_32_param_2]; ++-; SM35-NEXT: shf.l.wrap.b32 %r4, %r1, %r2, %r3; ++-; SM35-NEXT: st.param.b32 [func_retval0+0], %r4; ++-; SM35-NEXT: ret; ++- %val = call i32 @llvm.fshl.i32(i32 %a, i32 %b, i32 %c) ++- ret i32 %val ++-} ++- ++-define i64 @funnel_shift_right_64(i64 %a, i64 %b, i64 %c) { ++-; SM20-LABEL: funnel_shift_right_64( ++-; SM20: { ++-; SM20-NEXT: .reg .b32 %r<5>; ++-; SM20-NEXT: .reg .b64 %rd<7>; ++-; SM20-EMPTY: ++-; SM20-NEXT: // %bb.0: ++-; SM20-NEXT: ld.param.u64 %rd1, [funnel_shift_right_64_param_0]; ++-; SM20-NEXT: ld.param.u32 %r1, [funnel_shift_right_64_param_2]; ++-; SM20-NEXT: and.b32 %r2, %r1, 63; ++-; SM20-NEXT: ld.param.u64 %rd2, [funnel_shift_right_64_param_1]; ++-; SM20-NEXT: shr.u64 %rd3, %rd2, %r2; ++-; SM20-NEXT: shl.b64 %rd4, %rd1, 1; ++-; SM20-NEXT: not.b32 %r3, %r1; ++-; SM20-NEXT: and.b32 %r4, %r3, 63; ++-; SM20-NEXT: shl.b64 %rd5, %rd4, %r4; ++-; SM20-NEXT: or.b64 %rd6, %rd5, %rd3; ++-; SM20-NEXT: st.param.b64 [func_retval0+0], %rd6; ++-; SM20-NEXT: ret; ++-; ++-; SM35-LABEL: funnel_shift_right_64( ++-; SM35: { ++-; SM35-NEXT: .reg .b32 %r<5>; ++-; SM35-NEXT: .reg .b64 %rd<7>; ++-; SM35-EMPTY: ++-; SM35-NEXT: // %bb.0: ++-; SM35-NEXT: ld.param.u64 %rd1, [funnel_shift_right_64_param_0]; ++-; SM35-NEXT: ld.param.u32 %r1, [funnel_shift_right_64_param_2]; ++-; SM35-NEXT: and.b32 %r2, %r1, 63; ++-; SM35-NEXT: ld.param.u64 %rd2, [funnel_shift_right_64_param_1]; ++-; SM35-NEXT: shr.u64 %rd3, %rd2, %r2; ++-; SM35-NEXT: shl.b64 %rd4, %rd1, 1; ++-; SM35-NEXT: not.b32 %r3, %r1; ++-; SM35-NEXT: and.b32 %r4, %r3, 63; ++-; SM35-NEXT: shl.b64 %rd5, %rd4, %r4; ++-; SM35-NEXT: or.b64 %rd6, %rd5, %rd3; ++-; SM35-NEXT: st.param.b64 [func_retval0+0], %rd6; ++-; SM35-NEXT: ret; ++- %val = call i64 @llvm.fshr.i64(i64 %a, i64 %b, i64 %c) ++- ret i64 %val ++-} ++- ++-define i64 @funnel_shift_left_64(i64 %a, i64 %b, i64 %c) { ++-; SM20-LABEL: funnel_shift_left_64( ++-; SM20: { ++-; SM20-NEXT: .reg .b32 %r<5>; ++-; SM20-NEXT: .reg .b64 %rd<7>; ++-; SM20-EMPTY: ++-; SM20-NEXT: // %bb.0: ++-; SM20-NEXT: ld.param.u64 %rd1, [funnel_shift_left_64_param_0]; ++-; SM20-NEXT: ld.param.u32 %r1, [funnel_shift_left_64_param_2]; ++-; SM20-NEXT: and.b32 %r2, %r1, 63; ++-; SM20-NEXT: shl.b64 %rd2, %rd1, %r2; ++-; SM20-NEXT: ld.param.u64 %rd3, [funnel_shift_left_64_param_1]; ++-; SM20-NEXT: shr.u64 %rd4, %rd3, 1; ++-; SM20-NEXT: not.b32 %r3, %r1; ++-; SM20-NEXT: and.b32 %r4, %r3, 63; ++-; SM20-NEXT: shr.u64 %rd5, %rd4, %r4; ++-; SM20-NEXT: or.b64 %rd6, %rd2, %rd5; ++-; SM20-NEXT: st.param.b64 [func_retval0+0], %rd6; ++-; SM20-NEXT: ret; ++-; ++-; SM35-LABEL: funnel_shift_left_64( ++-; SM35: { ++-; SM35-NEXT: .reg .b32 %r<5>; ++-; SM35-NEXT: .reg .b64 %rd<7>; ++-; SM35-EMPTY: ++-; SM35-NEXT: // %bb.0: ++-; SM35-NEXT: ld.param.u64 %rd1, [funnel_shift_left_64_param_0]; ++-; SM35-NEXT: ld.param.u32 %r1, [funnel_shift_left_64_param_2]; ++-; SM35-NEXT: and.b32 %r2, %r1, 63; ++-; SM35-NEXT: shl.b64 %rd2, %rd1, %r2; ++-; SM35-NEXT: ld.param.u64 %rd3, [funnel_shift_left_64_param_1]; ++-; SM35-NEXT: shr.u64 %rd4, %rd3, 1; ++-; SM35-NEXT: not.b32 %r3, %r1; ++-; SM35-NEXT: and.b32 %r4, %r3, 63; ++-; SM35-NEXT: shr.u64 %rd5, %rd4, %r4; ++-; SM35-NEXT: or.b64 %rd6, %rd2, %rd5; ++-; SM35-NEXT: st.param.b64 [func_retval0+0], %rd6; ++-; SM35-NEXT: ret; ++- %val = call i64 @llvm.fshl.i64(i64 %a, i64 %b, i64 %c) ++- ret i64 %val ++-} ++- ++diff -ruN --strip-trailing-cr a/llvm/test/DebugInfo/NVPTX/debug-info.ll b/llvm/test/DebugInfo/NVPTX/debug-info.ll ++--- a/llvm/test/DebugInfo/NVPTX/debug-info.ll +++++ b/llvm/test/DebugInfo/NVPTX/debug-info.ll ++@@ -25,10 +25,6 @@ ++ ; CHECK-DAG: .reg .b64 %rd<8>; ++ ; CHECK: .loc [[DEBUG_INFO_CU:[0-9]+]] 5 0 ++ ; CHECK: ld.param.u32 %r{{.+}}, [{{.+}}]; ++-; CHECK: ld.param.u64 %rd{{.+}}, [{{.+}}]; ++-; CHECK: cvta.to.global.u64 %rd{{.+}}, %rd{{.+}}; ++-; CHECK: ld.param.u64 %rd{{.+}}, [{{.+}}]; ++-; CHECK: cvta.to.global.u64 %rd{{.+}}, %rd{{.+}}; ++ ; CHECK: .loc [[BUILTUIN_VARS_H:[0-9]+]] 78 180 ++ ; CHECK: mov.u32 %r{{.+}}, %ctaid.x; ++ ; CHECK: .loc [[BUILTUIN_VARS_H]] 89 180 ++@@ -42,6 +38,10 @@ ++ ; CHECK: .loc [[DEBUG_INFO_CU]] 7 7 ++ ; CHECK: @%p{{.+}} bra [[BB:\$L__.+]]; ++ ; CHECK: ld.param.f32 %f{{.+}}, [{{.+}}]; +++; CHECK: ld.param.u64 %rd{{.+}}, [{{.+}}]; +++; CHECK: cvta.to.global.u64 %rd{{.+}}, %rd{{.+}}; +++; CHECK: ld.param.u64 %rd{{.+}}, [{{.+}}]; +++; CHECK: cvta.to.global.u64 %rd{{.+}}, %rd{{.+}}; ++ ; CHECK: .loc [[DEBUG_INFO_CU]] 8 13 ++ ; CHECK: mul.wide.u32 %rd{{.+}}, %r{{.+}}, 4; ++ ; CHECK: add.s64 %rd{{.+}}, %rd{{.+}}, %rd{{.+}}; ++@@ -2661,22 +2661,22 @@ ++ ; CHECK-NEXT:.b32 4579 // DW_AT_type ++ ; CHECK-NEXT:.b8 25 // Abbrev [25] 0x8aa:0x18 DW_TAG_inlined_subroutine ++ ; CHECK-NEXT:.b32 707 // DW_AT_abstract_origin ++-; CHECK-NEXT:.b64 $L__tmp1 // DW_AT_low_pc ++-; CHECK-NEXT:.b64 $L__tmp2 // DW_AT_high_pc +++; CHECK-NEXT:.b64 $L__tmp0 // DW_AT_low_pc +++; CHECK-NEXT:.b64 $L__tmp1 // DW_AT_high_pc ++ ; CHECK-NEXT:.b8 1 // DW_AT_call_file ++ ; CHECK-NEXT:.b8 6 // DW_AT_call_line ++ ; CHECK-NEXT:.b8 11 // DW_AT_call_column ++ ; CHECK-NEXT:.b8 25 // Abbrev [25] 0x8c2:0x18 DW_TAG_inlined_subroutine ++ ; CHECK-NEXT:.b32 1466 // DW_AT_abstract_origin ++-; CHECK-NEXT:.b64 $L__tmp2 // DW_AT_low_pc ++-; CHECK-NEXT:.b64 $L__tmp3 // DW_AT_high_pc +++; CHECK-NEXT:.b64 $L__tmp1 // DW_AT_low_pc +++; CHECK-NEXT:.b64 $L__tmp2 // DW_AT_high_pc ++ ; CHECK-NEXT:.b8 1 // DW_AT_call_file ++ ; CHECK-NEXT:.b8 6 // DW_AT_call_line ++ ; CHECK-NEXT:.b8 24 // DW_AT_call_column ++ ; CHECK-NEXT:.b8 25 // Abbrev [25] 0x8da:0x18 DW_TAG_inlined_subroutine ++ ; CHECK-NEXT:.b32 2060 // DW_AT_abstract_origin ++-; CHECK-NEXT:.b64 $L__tmp3 // DW_AT_low_pc ++-; CHECK-NEXT:.b64 $L__tmp4 // DW_AT_high_pc +++; CHECK-NEXT:.b64 $L__tmp2 // DW_AT_low_pc +++; CHECK-NEXT:.b64 $L__tmp3 // DW_AT_high_pc ++ ; CHECK-NEXT:.b8 1 // DW_AT_call_file ++ ; CHECK-NEXT:.b8 6 // DW_AT_call_line ++ ; CHECK-NEXT:.b8 37 // DW_AT_call_column diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl -index 726a367..abe15ef 100644 +index abe15ef..af35fe7 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" -- LLVM_COMMIT = "8b4b7d28f7c344c728a9812aa99d9ad24edb40a2" -- LLVM_SHA256 = "f585b8955f66849929bbe0b657ea7ff5fe8f49880066a58b2a744065ddd4a521" -+ LLVM_COMMIT = "df0864e761107b07e38f5503e0cbee0cebb4c5e8" -+ LLVM_SHA256 = "5bfcb7306d9d40f420862ace1f7ad3f01979facfb16ffd1fc80b6d91e92019fa" +- LLVM_COMMIT = "df0864e761107b07e38f5503e0cbee0cebb4c5e8" +- LLVM_SHA256 = "5bfcb7306d9d40f420862ace1f7ad3f01979facfb16ffd1fc80b6d91e92019fa" ++ LLVM_COMMIT = "9830156f623c56062bf6df1b4c4b4bd8ab5bd57c" ++ LLVM_SHA256 = "85bb9a61cfdaf0d3386890dc7b4bbaa17eecf4b70b60c314307f2ca3919b9035" tf_http_archive( name = name, diff --git a/third_party/shardy/workspace.bzl b/third_party/shardy/workspace.bzl index b62c918736eb7e..4f6a0785270667 100644 --- a/third_party/shardy/workspace.bzl +++ b/third_party/shardy/workspace.bzl @@ -3,8 +3,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): - SHARDY_COMMIT = "f9efe2966f00f8e7da8f7af3f8c8b3255cc158b8" - SHARDY_SHA256 = "6ca4c5f2de2102eca2a78ab64a443b2d327fd7b0ceb8c633a67cd1a2a316a2db" + SHARDY_COMMIT = "22e68fa19cfb2d28434a75d4d20d0efc182b166a" + SHARDY_SHA256 = "2b47b0ee994feca2bd782e20aca7d709e29bc870c2ac435aca967f7664c9f949" tf_http_archive( name = "shardy", diff --git a/third_party/xla/third_party/shardy/temporary.patch b/third_party/xla/third_party/shardy/temporary.patch index 5fd5f295cd7dfc..d3fd21823cce19 100644 --- a/third_party/xla/third_party/shardy/temporary.patch +++ b/third_party/xla/third_party/shardy/temporary.patch @@ -1,15 +1,4115 @@ +diff --git a/third_party/llvm/generated.patch b/third_party/llvm/generated.patch +index 509398d..de92cb4 100644 +--- a/third_party/llvm/generated.patch ++++ b/third_party/llvm/generated.patch +@@ -1 +1,4095 @@ + Auto generated patch. Do not edit or delete it, even if empty. ++diff -ruN --strip-trailing-cr a/llvm/docs/NVPTXUsage.rst b/llvm/docs/NVPTXUsage.rst ++--- a/llvm/docs/NVPTXUsage.rst +++++ b/llvm/docs/NVPTXUsage.rst ++@@ -127,6 +127,69 @@ ++ NVPTX Intrinsics ++ ================ ++ +++Address Space Conversion +++------------------------ +++ +++'``llvm.nvvm.ptr.*.to.gen``' Intrinsics +++^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +++ +++Syntax: +++""""""" +++ +++These are overloaded intrinsics. You can use these on any pointer types. +++ +++.. code-block:: llvm +++ +++ declare ptr @llvm.nvvm.ptr.global.to.gen.p0.p1(ptr addrspace(1)) +++ declare ptr @llvm.nvvm.ptr.shared.to.gen.p0.p3(ptr addrspace(3)) +++ declare ptr @llvm.nvvm.ptr.constant.to.gen.p0.p4(ptr addrspace(4)) +++ declare ptr @llvm.nvvm.ptr.local.to.gen.p0.p5(ptr addrspace(5)) +++ +++Overview: +++""""""""" +++ +++The '``llvm.nvvm.ptr.*.to.gen``' intrinsics convert a pointer in a non-generic +++address space to a generic address space pointer. +++ +++Semantics: +++"""""""""" +++ +++These intrinsics modify the pointer value to be a valid generic address space +++pointer. +++ +++ +++'``llvm.nvvm.ptr.gen.to.*``' Intrinsics +++^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +++ +++Syntax: +++""""""" +++ +++These are overloaded intrinsics. You can use these on any pointer types. +++ +++.. code-block:: llvm +++ +++ declare ptr addrspace(1) @llvm.nvvm.ptr.gen.to.global.p1.p0(ptr) +++ declare ptr addrspace(3) @llvm.nvvm.ptr.gen.to.shared.p3.p0(ptr) +++ declare ptr addrspace(4) @llvm.nvvm.ptr.gen.to.constant.p4.p0(ptr) +++ declare ptr addrspace(5) @llvm.nvvm.ptr.gen.to.local.p5.p0(ptr) +++ +++Overview: +++""""""""" +++ +++The '``llvm.nvvm.ptr.gen.to.*``' intrinsics convert a pointer in the generic +++address space to a pointer in the target address space. Note that these +++intrinsics are only useful if the address space of the target address space of +++the pointer is known. It is not legal to use address space conversion +++intrinsics to convert a pointer from one non-generic address space to another +++non-generic address space. +++ +++Semantics: +++"""""""""" +++ +++These intrinsics modify the pointer value to be a valid pointer in the target +++non-generic address space. +++ +++ ++ Reading PTX Special Registers ++ ----------------------------- ++ ++diff -ruN --strip-trailing-cr a/llvm/docs/ReleaseNotes.rst b/llvm/docs/ReleaseNotes.rst ++--- a/llvm/docs/ReleaseNotes.rst +++++ b/llvm/docs/ReleaseNotes.rst ++@@ -63,24 +63,6 @@ ++ * ``llvm.nvvm.bitcast.d2ll`` ++ * ``llvm.nvvm.bitcast.ll2d`` ++ ++-* Remove the following intrinsics which can be replaced with a funnel-shift: ++- ++- * ``llvm.nvvm.rotate.b32`` ++- * ``llvm.nvvm.rotate.right.b64`` ++- * ``llvm.nvvm.rotate.b64`` ++- ++-* Remove the following intrinsics which can be replaced with an ++- ``addrspacecast``: ++- ++- * ``llvm.nvvm.ptr.gen.to.global`` ++- * ``llvm.nvvm.ptr.gen.to.shared`` ++- * ``llvm.nvvm.ptr.gen.to.constant`` ++- * ``llvm.nvvm.ptr.gen.to.local`` ++- * ``llvm.nvvm.ptr.global.to.gen`` ++- * ``llvm.nvvm.ptr.shared.to.gen`` ++- * ``llvm.nvvm.ptr.constant.to.gen`` ++- * ``llvm.nvvm.ptr.local.to.gen`` ++- ++ Changes to LLVM infrastructure ++ ------------------------------ ++ ++diff -ruN --strip-trailing-cr a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td ++--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td +++++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td ++@@ -30,18 +30,10 @@ ++ // * llvm.nvvm.max.ui --> select(x ule y, x, y) ++ // * llvm.nvvm.max.ull --> ibid. ++ // * llvm.nvvm.h2f --> llvm.convert.to.fp16.f32 ++-// * llvm.nvvm.bitcast.f2i --> bitcast ++-// * llvm.nvvm.bitcast.i2f --> ibid. ++-// * llvm.nvvm.bitcast.d2ll --> ibid. ++-// * llvm.nvvm.bitcast.ll2d --> ibid. ++-// * llvm.nvvm.ptr.gen.to.global --> addrspacecast ++-// * llvm.nvvm.ptr.gen.to.shared --> ibid. ++-// * llvm.nvvm.ptr.gen.to.constant --> ibid. ++-// * llvm.nvvm.ptr.gen.to.local --> ibid. ++-// * llvm.nvvm.ptr.global.to.gen --> ibid. ++-// * llvm.nvvm.ptr.shared.to.gen --> ibid. ++-// * llvm.nvvm.ptr.constant.to.gen --> ibid. ++-// * llvm.nvvm.ptr.local.to.gen --> ibid. +++// * llvm.nvvm.bitcast.f2i --> bitcast +++// * llvm.nvvm.bitcast.i2f --> ibid. +++// * llvm.nvvm.bitcast.d2ll --> ibid. +++// * llvm.nvvm.bitcast.ll2d --> ibid. ++ ++ def llvm_global_ptr_ty : LLVMQualPointerType<1>; // (global)ptr ++ def llvm_shared_ptr_ty : LLVMQualPointerType<3>; // (shared)ptr ++@@ -1610,6 +1602,40 @@ ++ [IntrReadMem, IntrArgMemOnly, IntrNoCallback, IntrWillReturn, NoCapture>], ++ "llvm.nvvm.ldg.global.p">; ++ +++// Use for generic pointers +++// - These intrinsics are used to convert address spaces. +++// - The input pointer and output pointer must have the same type, except for +++// the address-space. (This restriction is not enforced here as there is +++// currently no way to describe it). +++// - This complements the llvm bitcast, which can be used to cast one type +++// of pointer to another type of pointer, while the address space remains +++// the same. +++def int_nvvm_ptr_local_to_gen: DefaultAttrsIntrinsic<[llvm_anyptr_ty], +++ [llvm_anyptr_ty], [IntrNoMem, IntrSpeculatable], +++ "llvm.nvvm.ptr.local.to.gen">; +++def int_nvvm_ptr_shared_to_gen: DefaultAttrsIntrinsic<[llvm_anyptr_ty], +++ [llvm_anyptr_ty], [IntrNoMem, IntrSpeculatable], +++ "llvm.nvvm.ptr.shared.to.gen">; +++def int_nvvm_ptr_global_to_gen: DefaultAttrsIntrinsic<[llvm_anyptr_ty], +++ [llvm_anyptr_ty], [IntrNoMem, IntrSpeculatable], +++ "llvm.nvvm.ptr.global.to.gen">; +++def int_nvvm_ptr_constant_to_gen: DefaultAttrsIntrinsic<[llvm_anyptr_ty], +++ [llvm_anyptr_ty], [IntrNoMem, IntrSpeculatable], +++ "llvm.nvvm.ptr.constant.to.gen">; +++ +++def int_nvvm_ptr_gen_to_global: DefaultAttrsIntrinsic<[llvm_anyptr_ty], +++ [llvm_anyptr_ty], [IntrNoMem, IntrSpeculatable], +++ "llvm.nvvm.ptr.gen.to.global">; +++def int_nvvm_ptr_gen_to_shared: DefaultAttrsIntrinsic<[llvm_anyptr_ty], +++ [llvm_anyptr_ty], [IntrNoMem, IntrSpeculatable], +++ "llvm.nvvm.ptr.gen.to.shared">; +++def int_nvvm_ptr_gen_to_local: DefaultAttrsIntrinsic<[llvm_anyptr_ty], +++ [llvm_anyptr_ty], [IntrNoMem, IntrSpeculatable], +++ "llvm.nvvm.ptr.gen.to.local">; +++def int_nvvm_ptr_gen_to_constant: DefaultAttrsIntrinsic<[llvm_anyptr_ty], +++ [llvm_anyptr_ty], [IntrNoMem, IntrSpeculatable], +++ "llvm.nvvm.ptr.gen.to.constant">; +++ ++ // Used in nvvm internally to help address space opt and ptx code generation ++ // This is for params that are passed to kernel functions by pointer by-val. ++ def int_nvvm_ptr_gen_to_param: Intrinsic<[llvm_anyptr_ty], ++@@ -4453,6 +4479,22 @@ ++ "llvm.nvvm.sust.p.3d.v4i32.trap">, ++ ClangBuiltin<"__nvvm_sust_p_3d_v4i32_trap">; ++ +++ +++def int_nvvm_rotate_b32 +++ : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty], +++ [IntrNoMem, IntrSpeculatable], "llvm.nvvm.rotate.b32">, +++ ClangBuiltin<"__nvvm_rotate_b32">; +++ +++def int_nvvm_rotate_b64 +++ : DefaultAttrsIntrinsic<[llvm_i64_ty], [llvm_i64_ty, llvm_i32_ty], +++ [IntrNoMem, IntrSpeculatable], "llvm.nvvm.rotate.b64">, +++ ClangBuiltin<"__nvvm_rotate_b64">; +++ +++def int_nvvm_rotate_right_b64 +++ : DefaultAttrsIntrinsic<[llvm_i64_ty], [llvm_i64_ty, llvm_i32_ty], +++ [IntrNoMem, IntrSpeculatable], "llvm.nvvm.rotate.right.b64">, +++ ClangBuiltin<"__nvvm_rotate_right_b64">; +++ ++ def int_nvvm_swap_lo_hi_b64 ++ : DefaultAttrsIntrinsic<[llvm_i64_ty], [llvm_i64_ty], ++ [IntrNoMem, IntrSpeculatable], "llvm.nvvm.swap.lo.hi.b64">, ++diff -ruN --strip-trailing-cr a/llvm/lib/IR/AutoUpgrade.cpp b/llvm/lib/IR/AutoUpgrade.cpp ++--- a/llvm/lib/IR/AutoUpgrade.cpp +++++ b/llvm/lib/IR/AutoUpgrade.cpp ++@@ -1272,19 +1272,6 @@ ++ // nvvm.bitcast.{f2i,i2f,ll2d,d2ll} ++ Expand = ++ Name == "f2i" || Name == "i2f" || Name == "ll2d" || Name == "d2ll"; ++- else if (Name.consume_front("rotate.")) ++- // nvvm.rotate.{b32,b64,right.b64} ++- Expand = Name == "b32" || Name == "b64" || Name == "right.b64"; ++- else if (Name.consume_front("ptr.gen.to.")) ++- // nvvm.ptr.gen.to.{local,shared,global,constant} ++- Expand = Name.starts_with("local") || Name.starts_with("shared") || ++- Name.starts_with("global") || Name.starts_with("constant"); ++- else if (Name.consume_front("ptr.")) ++- // nvvm.ptr.{local,shared,global,constant}.to.gen ++- Expand = ++- (Name.consume_front("local") || Name.consume_front("shared") || ++- Name.consume_front("global") || Name.consume_front("constant")) && ++- Name.starts_with(".to.gen"); ++ else ++ Expand = false; ++ ++@@ -2271,117 +2258,6 @@ ++ } ++ } ++ ++-static Value *upgradeNVVMIntrinsicCall(StringRef Name, CallBase *CI, ++- Function *F, IRBuilder<> &Builder) { ++- Value *Rep = nullptr; ++- ++- if (Name == "abs.i" || Name == "abs.ll") { ++- Value *Arg = CI->getArgOperand(0); ++- Value *Neg = Builder.CreateNeg(Arg, "neg"); ++- Value *Cmp = Builder.CreateICmpSGE( ++- Arg, llvm::Constant::getNullValue(Arg->getType()), "abs.cond"); ++- Rep = Builder.CreateSelect(Cmp, Arg, Neg, "abs"); ++- } else if (Name.starts_with("atomic.load.add.f32.p") || ++- Name.starts_with("atomic.load.add.f64.p")) { ++- Value *Ptr = CI->getArgOperand(0); ++- Value *Val = CI->getArgOperand(1); ++- Rep = Builder.CreateAtomicRMW(AtomicRMWInst::FAdd, Ptr, Val, MaybeAlign(), ++- AtomicOrdering::SequentiallyConsistent); ++- } else if (Name.consume_front("max.") && ++- (Name == "s" || Name == "i" || Name == "ll" || Name == "us" || ++- Name == "ui" || Name == "ull")) { ++- Value *Arg0 = CI->getArgOperand(0); ++- Value *Arg1 = CI->getArgOperand(1); ++- Value *Cmp = Name.starts_with("u") ++- ? Builder.CreateICmpUGE(Arg0, Arg1, "max.cond") ++- : Builder.CreateICmpSGE(Arg0, Arg1, "max.cond"); ++- Rep = Builder.CreateSelect(Cmp, Arg0, Arg1, "max"); ++- } else if (Name.consume_front("min.") && ++- (Name == "s" || Name == "i" || Name == "ll" || Name == "us" || ++- Name == "ui" || Name == "ull")) { ++- Value *Arg0 = CI->getArgOperand(0); ++- Value *Arg1 = CI->getArgOperand(1); ++- Value *Cmp = Name.starts_with("u") ++- ? Builder.CreateICmpULE(Arg0, Arg1, "min.cond") ++- : Builder.CreateICmpSLE(Arg0, Arg1, "min.cond"); ++- Rep = Builder.CreateSelect(Cmp, Arg0, Arg1, "min"); ++- } else if (Name == "clz.ll") { ++- // llvm.nvvm.clz.ll returns an i32, but llvm.ctlz.i64 returns an i64. ++- Value *Arg = CI->getArgOperand(0); ++- Value *Ctlz = Builder.CreateCall( ++- Intrinsic::getDeclaration(F->getParent(), Intrinsic::ctlz, ++- {Arg->getType()}), ++- {Arg, Builder.getFalse()}, "ctlz"); ++- Rep = Builder.CreateTrunc(Ctlz, Builder.getInt32Ty(), "ctlz.trunc"); ++- } else if (Name == "popc.ll") { ++- // llvm.nvvm.popc.ll returns an i32, but llvm.ctpop.i64 returns an ++- // i64. ++- Value *Arg = CI->getArgOperand(0); ++- Value *Popc = Builder.CreateCall( ++- Intrinsic::getDeclaration(F->getParent(), Intrinsic::ctpop, ++- {Arg->getType()}), ++- Arg, "ctpop"); ++- Rep = Builder.CreateTrunc(Popc, Builder.getInt32Ty(), "ctpop.trunc"); ++- } else if (Name == "h2f") { ++- Rep = Builder.CreateCall( ++- Intrinsic::getDeclaration(F->getParent(), Intrinsic::convert_from_fp16, ++- {Builder.getFloatTy()}), ++- CI->getArgOperand(0), "h2f"); ++- } else if (Name.consume_front("bitcast.") && ++- (Name == "f2i" || Name == "i2f" || Name == "ll2d" || ++- Name == "d2ll")) { ++- Rep = Builder.CreateBitCast(CI->getArgOperand(0), CI->getType()); ++- } else if (Name == "rotate.b32") { ++- Value *Arg = CI->getOperand(0); ++- Value *ShiftAmt = CI->getOperand(1); ++- Rep = Builder.CreateIntrinsic(Builder.getInt32Ty(), Intrinsic::fshl, ++- {Arg, Arg, ShiftAmt}); ++- } else if (Name == "rotate.b64") { ++- Type *Int64Ty = Builder.getInt64Ty(); ++- Value *Arg = CI->getOperand(0); ++- Value *ZExtShiftAmt = Builder.CreateZExt(CI->getOperand(1), Int64Ty); ++- Rep = Builder.CreateIntrinsic(Int64Ty, Intrinsic::fshl, ++- {Arg, Arg, ZExtShiftAmt}); ++- } else if (Name == "rotate.right.b64") { ++- Type *Int64Ty = Builder.getInt64Ty(); ++- Value *Arg = CI->getOperand(0); ++- Value *ZExtShiftAmt = Builder.CreateZExt(CI->getOperand(1), Int64Ty); ++- Rep = Builder.CreateIntrinsic(Int64Ty, Intrinsic::fshr, ++- {Arg, Arg, ZExtShiftAmt}); ++- } else if ((Name.consume_front("ptr.gen.to.") && ++- (Name.starts_with("local") || Name.starts_with("shared") || ++- Name.starts_with("global") || Name.starts_with("constant"))) || ++- (Name.consume_front("ptr.") && ++- (Name.consume_front("local") || Name.consume_front("shared") || ++- Name.consume_front("global") || ++- Name.consume_front("constant")) && ++- Name.starts_with(".to.gen"))) { ++- Rep = Builder.CreateAddrSpaceCast(CI->getArgOperand(0), CI->getType()); ++- } else { ++- Intrinsic::ID IID = shouldUpgradeNVPTXBF16Intrinsic(Name); ++- if (IID != Intrinsic::not_intrinsic && ++- !F->getReturnType()->getScalarType()->isBFloatTy()) { ++- rename(F); ++- Function *NewFn = Intrinsic::getDeclaration(F->getParent(), IID); ++- SmallVector Args; ++- for (size_t I = 0; I < NewFn->arg_size(); ++I) { ++- Value *Arg = CI->getArgOperand(I); ++- Type *OldType = Arg->getType(); ++- Type *NewType = NewFn->getArg(I)->getType(); ++- Args.push_back( ++- (OldType->isIntegerTy() && NewType->getScalarType()->isBFloatTy()) ++- ? Builder.CreateBitCast(Arg, NewType) ++- : Arg); ++- } ++- Rep = Builder.CreateCall(NewFn, Args); ++- if (F->getReturnType()->isIntegerTy()) ++- Rep = Builder.CreateBitCast(Rep, F->getReturnType()); ++- } ++- } ++- ++- return Rep; ++-} ++- ++ static Value *upgradeX86IntrinsicCall(StringRef Name, CallBase *CI, Function *F, ++ IRBuilder<> &Builder) { ++ LLVMContext &C = F->getContext(); ++@@ -4332,8 +4208,85 @@ ++ ++ if (!IsX86 && Name == "stackprotectorcheck") { ++ Rep = nullptr; +++ } else if (IsNVVM && (Name == "abs.i" || Name == "abs.ll")) { +++ Value *Arg = CI->getArgOperand(0); +++ Value *Neg = Builder.CreateNeg(Arg, "neg"); +++ Value *Cmp = Builder.CreateICmpSGE( +++ Arg, llvm::Constant::getNullValue(Arg->getType()), "abs.cond"); +++ Rep = Builder.CreateSelect(Cmp, Arg, Neg, "abs"); +++ } else if (IsNVVM && (Name.starts_with("atomic.load.add.f32.p") || +++ Name.starts_with("atomic.load.add.f64.p"))) { +++ Value *Ptr = CI->getArgOperand(0); +++ Value *Val = CI->getArgOperand(1); +++ Rep = Builder.CreateAtomicRMW(AtomicRMWInst::FAdd, Ptr, Val, MaybeAlign(), +++ AtomicOrdering::SequentiallyConsistent); +++ } else if (IsNVVM && Name.consume_front("max.") && +++ (Name == "s" || Name == "i" || Name == "ll" || Name == "us" || +++ Name == "ui" || Name == "ull")) { +++ Value *Arg0 = CI->getArgOperand(0); +++ Value *Arg1 = CI->getArgOperand(1); +++ Value *Cmp = Name.starts_with("u") +++ ? Builder.CreateICmpUGE(Arg0, Arg1, "max.cond") +++ : Builder.CreateICmpSGE(Arg0, Arg1, "max.cond"); +++ Rep = Builder.CreateSelect(Cmp, Arg0, Arg1, "max"); +++ } else if (IsNVVM && Name.consume_front("min.") && +++ (Name == "s" || Name == "i" || Name == "ll" || Name == "us" || +++ Name == "ui" || Name == "ull")) { +++ Value *Arg0 = CI->getArgOperand(0); +++ Value *Arg1 = CI->getArgOperand(1); +++ Value *Cmp = Name.starts_with("u") +++ ? Builder.CreateICmpULE(Arg0, Arg1, "min.cond") +++ : Builder.CreateICmpSLE(Arg0, Arg1, "min.cond"); +++ Rep = Builder.CreateSelect(Cmp, Arg0, Arg1, "min"); +++ } else if (IsNVVM && Name == "clz.ll") { +++ // llvm.nvvm.clz.ll returns an i32, but llvm.ctlz.i64 returns an i64. +++ Value *Arg = CI->getArgOperand(0); +++ Value *Ctlz = Builder.CreateCall( +++ Intrinsic::getDeclaration(F->getParent(), Intrinsic::ctlz, +++ {Arg->getType()}), +++ {Arg, Builder.getFalse()}, "ctlz"); +++ Rep = Builder.CreateTrunc(Ctlz, Builder.getInt32Ty(), "ctlz.trunc"); +++ } else if (IsNVVM && Name == "popc.ll") { +++ // llvm.nvvm.popc.ll returns an i32, but llvm.ctpop.i64 returns an +++ // i64. +++ Value *Arg = CI->getArgOperand(0); +++ Value *Popc = Builder.CreateCall( +++ Intrinsic::getDeclaration(F->getParent(), Intrinsic::ctpop, +++ {Arg->getType()}), +++ Arg, "ctpop"); +++ Rep = Builder.CreateTrunc(Popc, Builder.getInt32Ty(), "ctpop.trunc"); ++ } else if (IsNVVM) { ++- Rep = upgradeNVVMIntrinsicCall(Name, CI, F, Builder); +++ if (Name == "h2f") { +++ Rep = +++ Builder.CreateCall(Intrinsic::getDeclaration( +++ F->getParent(), Intrinsic::convert_from_fp16, +++ {Builder.getFloatTy()}), +++ CI->getArgOperand(0), "h2f"); +++ } else if (Name.consume_front("bitcast.") && +++ (Name == "f2i" || Name == "i2f" || Name == "ll2d" || +++ Name == "d2ll")) { +++ Rep = Builder.CreateBitCast(CI->getArgOperand(0), CI->getType()); +++ } else { +++ Intrinsic::ID IID = shouldUpgradeNVPTXBF16Intrinsic(Name); +++ if (IID != Intrinsic::not_intrinsic && +++ !F->getReturnType()->getScalarType()->isBFloatTy()) { +++ rename(F); +++ NewFn = Intrinsic::getDeclaration(F->getParent(), IID); +++ SmallVector Args; +++ for (size_t I = 0; I < NewFn->arg_size(); ++I) { +++ Value *Arg = CI->getArgOperand(I); +++ Type *OldType = Arg->getType(); +++ Type *NewType = NewFn->getArg(I)->getType(); +++ Args.push_back((OldType->isIntegerTy() && +++ NewType->getScalarType()->isBFloatTy()) +++ ? Builder.CreateBitCast(Arg, NewType) +++ : Arg); +++ } +++ Rep = Builder.CreateCall(NewFn, Args); +++ if (F->getReturnType()->isIntegerTy()) +++ Rep = Builder.CreateBitCast(Rep, F->getReturnType()); +++ } +++ } ++ } else if (IsX86) { ++ Rep = upgradeX86IntrinsicCall(Name, CI, F, Builder); ++ } else if (IsARM) { ++diff -ruN --strip-trailing-cr a/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp b/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp ++--- a/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp +++++ b/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp ++@@ -292,7 +292,6 @@ ++ static const LLT S224 = LLT::scalar(224); ++ static const LLT S256 = LLT::scalar(256); ++ static const LLT S512 = LLT::scalar(512); ++-static const LLT S1024 = LLT::scalar(1024); ++ static const LLT MaxScalar = LLT::scalar(MaxRegisterSize); ++ ++ static const LLT V2S8 = LLT::fixed_vector(2, 8); ++@@ -333,8 +332,8 @@ ++ static const LLT V2S128 = LLT::fixed_vector(2, 128); ++ static const LLT V4S128 = LLT::fixed_vector(4, 128); ++ ++-static std::initializer_list AllScalarTypes = { ++- S32, S64, S96, S128, S160, S224, S256, S512, S1024}; +++static std::initializer_list AllScalarTypes = {S32, S64, S96, S128, +++ S160, S224, S256, S512}; ++ ++ static std::initializer_list AllS16Vectors{ ++ V2S16, V4S16, V6S16, V8S16, V10S16, V12S16, V16S16, V2S128, V4S128}; ++@@ -890,11 +889,10 @@ ++ .clampScalar(0, S16, S64); ++ ++ getActionDefinitionsBuilder({G_IMPLICIT_DEF, G_FREEZE}) ++- .legalIf(isRegisterClassType(0)) +++ .legalIf(isRegisterType(0)) ++ // s1 and s16 are special cases because they have legal operations on ++ // them, but don't really occupy registers in the normal way. ++ .legalFor({S1, S16}) ++- .clampNumElements(0, V16S32, V32S32) ++ .moreElementsIf(isSmallOddVector(0), oneMoreElement(0)) ++ .clampScalarOrElt(0, S32, MaxScalar) ++ .widenScalarToNextPow2(0, 32) ++diff -ruN --strip-trailing-cr a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td ++--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td +++++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td ++@@ -174,6 +174,10 @@ ++ def hasSHFL : Predicate<"!(Subtarget->getSmVersion() >= 70" ++ "&& Subtarget->getPTXVersion() >= 64)">; ++ +++def useShortPtrLocal : Predicate<"TM.is64Bit() && TM.getPointerSizeInBits(ADDRESS_SPACE_LOCAL) == 32">; +++def useShortPtrShared : Predicate<"TM.is64Bit() && TM.getPointerSizeInBits(ADDRESS_SPACE_SHARED) == 32">; +++def useShortPtrConst : Predicate<"TM.is64Bit() && TM.getPointerSizeInBits(ADDRESS_SPACE_CONST) == 32">; +++ ++ def useFP16Math: Predicate<"Subtarget->allowFP16Math()">; ++ def hasBF16Math: Predicate<"Subtarget->hasBF16Math()">; ++ ++@@ -1661,6 +1665,167 @@ ++ "brev.b64 \t$dst, $a;", ++ [(set Int64Regs:$dst, (bitreverse Int64Regs:$a))]>; ++ +++// +++// Rotate: Use ptx shf instruction if available. +++// +++ +++// 32 bit r2 = rotl r1, n +++// => +++// r2 = shf.l r1, r1, n +++def ROTL32imm_hw : +++ NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$src, i32imm:$amt), +++ "shf.l.wrap.b32 \t$dst, $src, $src, $amt;", +++ [(set Int32Regs:$dst, (rotl (i32 Int32Regs:$src), (i32 imm:$amt)))]>, +++ Requires<[hasHWROT32]>; +++ +++def ROTL32reg_hw : +++ NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$src, Int32Regs:$amt), +++ "shf.l.wrap.b32 \t$dst, $src, $src, $amt;", +++ [(set Int32Regs:$dst, (rotl (i32 Int32Regs:$src), (i32 Int32Regs:$amt)))]>, +++ Requires<[hasHWROT32]>; +++ +++// 32 bit r2 = rotr r1, n +++// => +++// r2 = shf.r r1, r1, n +++def ROTR32imm_hw : +++ NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$src, i32imm:$amt), +++ "shf.r.wrap.b32 \t$dst, $src, $src, $amt;", +++ [(set Int32Regs:$dst, (rotr (i32 Int32Regs:$src), (i32 imm:$amt)))]>, +++ Requires<[hasHWROT32]>; +++ +++def ROTR32reg_hw : +++ NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$src, Int32Regs:$amt), +++ "shf.r.wrap.b32 \t$dst, $src, $src, $amt;", +++ [(set Int32Regs:$dst, (rotr (i32 Int32Regs:$src), (i32 Int32Regs:$amt)))]>, +++ Requires<[hasHWROT32]>; +++ +++// 32-bit software rotate by immediate. $amt2 should equal 32 - $amt1. +++def ROT32imm_sw : +++ NVPTXInst<(outs Int32Regs:$dst), +++ (ins Int32Regs:$src, i32imm:$amt1, i32imm:$amt2), +++ "{{\n\t" +++ ".reg .b32 %lhs;\n\t" +++ ".reg .b32 %rhs;\n\t" +++ "shl.b32 \t%lhs, $src, $amt1;\n\t" +++ "shr.b32 \t%rhs, $src, $amt2;\n\t" +++ "add.u32 \t$dst, %lhs, %rhs;\n\t" +++ "}}", +++ []>; +++ +++def SUB_FRM_32 : SDNodeXFormgetTargetConstant(32 - N->getZExtValue(), SDLoc(N), MVT::i32); +++}]>; +++ +++def : Pat<(rotl (i32 Int32Regs:$src), (i32 imm:$amt)), +++ (ROT32imm_sw Int32Regs:$src, imm:$amt, (SUB_FRM_32 node:$amt))>, +++ Requires<[noHWROT32]>; +++def : Pat<(rotr (i32 Int32Regs:$src), (i32 imm:$amt)), +++ (ROT32imm_sw Int32Regs:$src, (SUB_FRM_32 node:$amt), imm:$amt)>, +++ Requires<[noHWROT32]>; +++ +++// 32-bit software rotate left by register. +++def ROTL32reg_sw : +++ NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$src, Int32Regs:$amt), +++ "{{\n\t" +++ ".reg .b32 %lhs;\n\t" +++ ".reg .b32 %rhs;\n\t" +++ ".reg .b32 %amt2;\n\t" +++ "shl.b32 \t%lhs, $src, $amt;\n\t" +++ "sub.s32 \t%amt2, 32, $amt;\n\t" +++ "shr.b32 \t%rhs, $src, %amt2;\n\t" +++ "add.u32 \t$dst, %lhs, %rhs;\n\t" +++ "}}", +++ [(set Int32Regs:$dst, (rotl (i32 Int32Regs:$src), (i32 Int32Regs:$amt)))]>, +++ Requires<[noHWROT32]>; +++ +++// 32-bit software rotate right by register. +++def ROTR32reg_sw : +++ NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$src, Int32Regs:$amt), +++ "{{\n\t" +++ ".reg .b32 %lhs;\n\t" +++ ".reg .b32 %rhs;\n\t" +++ ".reg .b32 %amt2;\n\t" +++ "shr.b32 \t%lhs, $src, $amt;\n\t" +++ "sub.s32 \t%amt2, 32, $amt;\n\t" +++ "shl.b32 \t%rhs, $src, %amt2;\n\t" +++ "add.u32 \t$dst, %lhs, %rhs;\n\t" +++ "}}", +++ [(set Int32Regs:$dst, (rotr (i32 Int32Regs:$src), (i32 Int32Regs:$amt)))]>, +++ Requires<[noHWROT32]>; +++ +++// 64-bit software rotate by immediate. $amt2 should equal 64 - $amt1. +++def ROT64imm_sw : +++ NVPTXInst<(outs Int64Regs:$dst), +++ (ins Int64Regs:$src, i32imm:$amt1, i32imm:$amt2), +++ "{{\n\t" +++ ".reg .b64 %lhs;\n\t" +++ ".reg .b64 %rhs;\n\t" +++ "shl.b64 \t%lhs, $src, $amt1;\n\t" +++ "shr.b64 \t%rhs, $src, $amt2;\n\t" +++ "add.u64 \t$dst, %lhs, %rhs;\n\t" +++ "}}", +++ []>; +++ +++def SUB_FRM_64 : SDNodeXFormgetTargetConstant(64-N->getZExtValue(), SDLoc(N), MVT::i32); +++}]>; +++ +++def : Pat<(rotl Int64Regs:$src, (i32 imm:$amt)), +++ (ROT64imm_sw Int64Regs:$src, imm:$amt, (SUB_FRM_64 node:$amt))>; +++def : Pat<(rotr Int64Regs:$src, (i32 imm:$amt)), +++ (ROT64imm_sw Int64Regs:$src, (SUB_FRM_64 node:$amt), imm:$amt)>; +++ +++// 64-bit software rotate left by register. +++def ROTL64reg_sw : +++ NVPTXInst<(outs Int64Regs:$dst), (ins Int64Regs:$src, Int32Regs:$amt), +++ "{{\n\t" +++ ".reg .b64 %lhs;\n\t" +++ ".reg .b64 %rhs;\n\t" +++ ".reg .u32 %amt2;\n\t" +++ "and.b32 \t%amt2, $amt, 63;\n\t" +++ "shl.b64 \t%lhs, $src, %amt2;\n\t" +++ "sub.u32 \t%amt2, 64, %amt2;\n\t" +++ "shr.b64 \t%rhs, $src, %amt2;\n\t" +++ "add.u64 \t$dst, %lhs, %rhs;\n\t" +++ "}}", +++ [(set Int64Regs:$dst, (rotl Int64Regs:$src, (i32 Int32Regs:$amt)))]>; +++ +++def ROTR64reg_sw : +++ NVPTXInst<(outs Int64Regs:$dst), (ins Int64Regs:$src, Int32Regs:$amt), +++ "{{\n\t" +++ ".reg .b64 %lhs;\n\t" +++ ".reg .b64 %rhs;\n\t" +++ ".reg .u32 %amt2;\n\t" +++ "and.b32 \t%amt2, $amt, 63;\n\t" +++ "shr.b64 \t%lhs, $src, %amt2;\n\t" +++ "sub.u32 \t%amt2, 64, %amt2;\n\t" +++ "shl.b64 \t%rhs, $src, %amt2;\n\t" +++ "add.u64 \t$dst, %lhs, %rhs;\n\t" +++ "}}", +++ [(set Int64Regs:$dst, (rotr Int64Regs:$src, (i32 Int32Regs:$amt)))]>; +++ +++// +++// Funnnel shift in clamp mode +++// +++ +++// Create SDNodes so they can be used in the DAG code, e.g. +++// NVPTXISelLowering (LowerShiftLeftParts and LowerShiftRightParts) +++def FUN_SHFL_CLAMP : SDNode<"NVPTXISD::FUN_SHFL_CLAMP", SDTIntShiftDOp, []>; +++def FUN_SHFR_CLAMP : SDNode<"NVPTXISD::FUN_SHFR_CLAMP", SDTIntShiftDOp, []>; +++ +++def FUNSHFLCLAMP : +++ NVPTXInst<(outs Int32Regs:$dst), +++ (ins Int32Regs:$lo, Int32Regs:$hi, Int32Regs:$amt), +++ "shf.l.clamp.b32 \t$dst, $lo, $hi, $amt;", +++ [(set Int32Regs:$dst, +++ (FUN_SHFL_CLAMP (i32 Int32Regs:$lo), (i32 Int32Regs:$hi), (i32 Int32Regs:$amt)))]>; +++ +++def FUNSHFRCLAMP : +++ NVPTXInst<(outs Int32Regs:$dst), +++ (ins Int32Regs:$lo, Int32Regs:$hi, Int32Regs:$amt), +++ "shf.r.clamp.b32 \t$dst, $lo, $hi, $amt;", +++ [(set Int32Regs:$dst, +++ (FUN_SHFR_CLAMP (i32 Int32Regs:$lo), (i32 Int32Regs:$hi), (i32 Int32Regs:$amt)))]>; ++ ++ // ++ // BFE - bit-field extract ++@@ -3492,42 +3657,6 @@ ++ def: Pat<(v2i16 (scalar_to_vector (i16 Int16Regs:$a))), ++ (CVT_u32_u16 Int16Regs:$a, CvtNONE)>; ++ ++-// ++-// Funnel-Shift ++-// ++- ++-// Create SDNodes so they can be used in the DAG code, e.g. ++-// NVPTXISelLowering (LowerShiftLeftParts and LowerShiftRightParts) ++-def fshl_clamp : SDNode<"NVPTXISD::FUN_SHFL_CLAMP", SDTIntShiftDOp, []>; ++-def fshr_clamp : SDNode<"NVPTXISD::FUN_SHFR_CLAMP", SDTIntShiftDOp, []>; ++- ++-// Funnel shift, requires >= sm_32. Does not trap if amt is out of range, so ++-// no side effects. ++-let hasSideEffects = false in { ++- multiclass ShfInst { ++- def _i ++- : NVPTXInst<(outs Int32Regs:$dst), ++- (ins Int32Regs:$lo, Int32Regs:$hi, i32imm:$amt), ++- "shf." # mode # ".b32 \t$dst, $lo, $hi, $amt;", ++- [(set Int32Regs:$dst, ++- (op (i32 Int32Regs:$lo), (i32 Int32Regs:$hi), (i32 imm:$amt)))]>, ++- Requires<[hasHWROT32]>; ++- ++- def _r ++- : NVPTXInst<(outs Int32Regs:$dst), ++- (ins Int32Regs:$lo, Int32Regs:$hi, Int32Regs:$amt), ++- "shf." # mode # ".b32 \t$dst, $lo, $hi, $amt;", ++- [(set Int32Regs:$dst, ++- (op (i32 Int32Regs:$lo), (i32 Int32Regs:$hi), (i32 Int32Regs:$amt)))]>, ++- Requires<[hasHWROT32]>; ++- } ++- ++- defm SHF_L_CLAMP : ShfInst<"l.clamp", fshl_clamp>; ++- defm SHF_R_CLAMP : ShfInst<"r.clamp", fshr_clamp>; ++- defm SHF_L_WRAP : ShfInst<"l.wrap", fshl>; ++- defm SHF_R_WRAP : ShfInst<"r.wrap", fshr>; ++-} ++- ++ // Count leading zeros ++ let hasSideEffects = false in { ++ def CLZr32 : NVPTXInst<(outs Int32Regs:$d), (ins Int32Regs:$a), ++diff -ruN --strip-trailing-cr a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td ++--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td +++++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td ++@@ -2537,45 +2537,59 @@ ++ : VLDG_G_ELE_V4<"v4.f32 \t{{$dst1, $dst2, $dst3, $dst4}}, [$src];", Float32Regs>; ++ ++ ++-multiclass NG_TO_G { +++multiclass NG_TO_G { ++ def "" : NVPTXInst<(outs Int32Regs:$result), (ins Int32Regs:$src), ++- "cvta." # Str # ".u32 \t$result, $src;", []>; +++ !strconcat("cvta.", Str, ".u32 \t$result, $src;"), +++ [(set Int32Regs:$result, (Intrin Int32Regs:$src))]>; ++ def _64 : NVPTXInst<(outs Int64Regs:$result), (ins Int64Regs:$src), ++- "cvta." # Str # ".u64 \t$result, $src;", []>; +++ !strconcat("cvta.", Str, ".u64 \t$result, $src;"), +++ [(set Int64Regs:$result, (Intrin Int64Regs:$src))]>; +++ def _6432 : NVPTXInst<(outs Int64Regs:$result), (ins Int32Regs:$src), +++ "{{ .reg .b64 %tmp;\n\t" +++ #" cvt.u64.u32 \t%tmp, $src;\n\t" +++ #" cvta." # Str # ".u64 \t$result, %tmp; }}", +++ [(set Int64Regs:$result, (Intrin Int32Regs:$src))]>, +++ Requires<[ShortPtr]>; ++ } ++ ++-multiclass G_TO_NG { +++multiclass G_TO_NG { ++ def "" : NVPTXInst<(outs Int32Regs:$result), (ins Int32Regs:$src), ++- "cvta.to." # Str # ".u32 \t$result, $src;", []>; +++ !strconcat("cvta.to.", Str, ".u32 \t$result, $src;"), +++ [(set Int32Regs:$result, (Intrin Int32Regs:$src))]>; ++ def _64 : NVPTXInst<(outs Int64Regs:$result), (ins Int64Regs:$src), ++- "cvta.to." # Str # ".u64 \t$result, $src;", []>; +++ !strconcat("cvta.to.", Str, ".u64 \t$result, $src;"), +++ [(set Int64Regs:$result, (Intrin Int64Regs:$src))]>; +++ def _3264 : NVPTXInst<(outs Int32Regs:$result), (ins Int64Regs:$src), +++ "{{ .reg .b64 %tmp;\n\t" +++ #" cvta.to." # Str # ".u64 \t%tmp, $src;\n\t" +++ #" cvt.u32.u64 \t$result, %tmp; }}", +++ [(set Int32Regs:$result, (Intrin Int64Regs:$src))]>, +++ Requires<[ShortPtr]>; ++ } ++ ++-defm cvta_local : NG_TO_G<"local">; ++-defm cvta_shared : NG_TO_G<"shared">; ++-defm cvta_global : NG_TO_G<"global">; ++-defm cvta_const : NG_TO_G<"const">; ++- ++-defm cvta_to_local : G_TO_NG<"local">; ++-defm cvta_to_shared : G_TO_NG<"shared">; ++-defm cvta_to_global : G_TO_NG<"global">; ++-defm cvta_to_const : G_TO_NG<"const">; ++- ++-// nvvm.ptr.param.to.gen ++-defm cvta_param : NG_TO_G<"param">; ++- ++-def : Pat<(int_nvvm_ptr_param_to_gen Int32Regs:$src), ++- (cvta_param Int32Regs:$src)>; ++- ++-def : Pat<(int_nvvm_ptr_param_to_gen Int64Regs:$src), ++- (cvta_param_64 Int64Regs:$src)>; +++defm cvta_local : NG_TO_G<"local", int_nvvm_ptr_local_to_gen, useShortPtrLocal>; +++defm cvta_shared : NG_TO_G<"shared", int_nvvm_ptr_shared_to_gen, useShortPtrShared>; +++defm cvta_global : NG_TO_G<"global", int_nvvm_ptr_global_to_gen, False>; +++defm cvta_const : NG_TO_G<"const", int_nvvm_ptr_constant_to_gen, useShortPtrConst>; +++defm cvta_param : NG_TO_G<"param", int_nvvm_ptr_param_to_gen, False>; +++ +++defm cvta_to_local : G_TO_NG<"local", int_nvvm_ptr_gen_to_local, useShortPtrLocal>; +++defm cvta_to_shared : G_TO_NG<"shared", int_nvvm_ptr_gen_to_shared, useShortPtrShared>; +++defm cvta_to_global : G_TO_NG<"global", int_nvvm_ptr_gen_to_global, False>; +++defm cvta_to_const : G_TO_NG<"const", int_nvvm_ptr_gen_to_constant, useShortPtrConst>; ++ ++ // nvvm.ptr.gen.to.param ++-def : Pat<(int_nvvm_ptr_gen_to_param Int32Regs:$src), ++- (IMOV32rr Int32Regs:$src)>; +++def nvvm_ptr_gen_to_param : NVPTXInst<(outs Int32Regs:$result), +++ (ins Int32Regs:$src), +++ "mov.u32 \t$result, $src;", +++ [(set Int32Regs:$result, +++ (int_nvvm_ptr_gen_to_param Int32Regs:$src))]>; +++def nvvm_ptr_gen_to_param_64 : NVPTXInst<(outs Int64Regs:$result), +++ (ins Int64Regs:$src), +++ "mov.u64 \t$result, $src;", +++ [(set Int64Regs:$result, +++ (int_nvvm_ptr_gen_to_param Int64Regs:$src))]>; ++ ++-def : Pat<(int_nvvm_ptr_gen_to_param Int64Regs:$src), ++- (IMOV64rr Int64Regs:$src)>; ++ ++ // nvvm.move intrinsicc ++ def nvvm_move_i16 : NVPTXInst<(outs Int16Regs:$r), (ins Int16Regs:$s), ++@@ -2618,6 +2632,24 @@ ++ [(set Int64Regs:$r, ++ (int_nvvm_move_ptr texternalsym:$s))]>;*/ ++ +++ +++// MoveParam %r1, param +++// ptr_local_to_gen %r2, %r1 +++// ptr_gen_to_local %r3, %r2 +++// -> +++// mov %r1, param +++ +++// @TODO: Revisit this. There is a type +++// contradiction between iPTRAny and iPTR for the addr defs, so the move_sym +++// instructions are not currently defined. However, we can use the ptr +++// variants and the asm printer will do the right thing. +++def : Pat<(i64 (int_nvvm_ptr_gen_to_local (int_nvvm_ptr_local_to_gen +++ (MoveParam texternalsym:$src)))), +++ (nvvm_move_ptr64 texternalsym:$src)>; +++def : Pat<(i32 (int_nvvm_ptr_gen_to_local (int_nvvm_ptr_local_to_gen +++ (MoveParam texternalsym:$src)))), +++ (nvvm_move_ptr32 texternalsym:$src)>; +++ ++ def texsurf_handles ++ : NVPTXInst<(outs Int64Regs:$result), (ins imem:$src), ++ "mov.u64 \t$result, $src;", []>; ++@@ -2701,9 +2733,134 @@ ++ def : Pat<(int_nvvm_read_ptx_sreg_envreg31), (MOV_SPECIAL ENVREG31)>; ++ ++ +++// rotate builtin support +++ +++def ROTATE_B32_HW_IMM +++ : NVPTXInst<(outs Int32Regs:$dst), +++ (ins Int32Regs:$src, i32imm:$amt), +++ "shf.l.wrap.b32 \t$dst, $src, $src, $amt;", +++ [(set Int32Regs:$dst, +++ (int_nvvm_rotate_b32 Int32Regs:$src, (i32 imm:$amt)))]>, +++ Requires<[hasHWROT32]> ; +++ +++def ROTATE_B32_HW_REG +++ : NVPTXInst<(outs Int32Regs:$dst), +++ (ins Int32Regs:$src, Int32Regs:$amt), +++ "shf.l.wrap.b32 \t$dst, $src, $src, $amt;", +++ [(set Int32Regs:$dst, +++ (int_nvvm_rotate_b32 Int32Regs:$src, Int32Regs:$amt))]>, +++ Requires<[hasHWROT32]> ; +++ +++def : Pat<(int_nvvm_rotate_b32 Int32Regs:$src, (i32 imm:$amt)), +++ (ROT32imm_sw Int32Regs:$src, imm:$amt, (SUB_FRM_32 node:$amt))>, +++ Requires<[noHWROT32]> ; +++ +++def : Pat<(int_nvvm_rotate_b32 Int32Regs:$src, Int32Regs:$amt), +++ (ROTL32reg_sw Int32Regs:$src, Int32Regs:$amt)>, +++ Requires<[noHWROT32]> ; +++ +++let hasSideEffects = false in { +++ def GET_LO_INT64 : NVPTXInst<(outs Int32Regs:$dst), (ins Int64Regs:$src), +++ !strconcat("{{\n\t", +++ ".reg .b32 %dummy;\n\t", +++ "mov.b64 \t{$dst,%dummy}, $src;\n\t", +++ "}}"), +++ []> ; +++ +++ def GET_HI_INT64 : NVPTXInst<(outs Int32Regs:$dst), (ins Int64Regs:$src), +++ !strconcat("{{\n\t", +++ ".reg .b32 %dummy;\n\t", +++ "mov.b64 \t{%dummy,$dst}, $src;\n\t", +++ "}}"), +++ []> ; +++} +++ +++let hasSideEffects = false in { +++ def PACK_TWO_INT32 +++ : NVPTXInst<(outs Int64Regs:$dst), (ins Int32Regs:$lo, Int32Regs:$hi), +++ "mov.b64 \t$dst, {{$lo, $hi}};", []> ; +++} +++ ++ def : Pat<(int_nvvm_swap_lo_hi_b64 Int64Regs:$src), ++- (V2I32toI64 (I64toI32H Int64Regs:$src), ++- (I64toI32L Int64Regs:$src))> ; +++ (PACK_TWO_INT32 (GET_HI_INT64 Int64Regs:$src), +++ (GET_LO_INT64 Int64Regs:$src))> ; +++ +++// Funnel shift, requires >= sm_32. Does not trap if amt is out of range, so +++// no side effects. +++let hasSideEffects = false in { +++ def SHF_L_WRAP_B32_IMM +++ : NVPTXInst<(outs Int32Regs:$dst), +++ (ins Int32Regs:$lo, Int32Regs:$hi, i32imm:$amt), +++ "shf.l.wrap.b32 \t$dst, $lo, $hi, $amt;",[]>, +++ Requires<[hasHWROT32]>; +++ +++ def SHF_L_WRAP_B32_REG +++ : NVPTXInst<(outs Int32Regs:$dst), +++ (ins Int32Regs:$lo, Int32Regs:$hi, Int32Regs:$amt), +++ "shf.l.wrap.b32 \t$dst, $lo, $hi, $amt;",[]>, +++ Requires<[hasHWROT32]>; +++ +++ def SHF_R_WRAP_B32_IMM +++ : NVPTXInst<(outs Int32Regs:$dst), +++ (ins Int32Regs:$lo, Int32Regs:$hi, i32imm:$amt), +++ "shf.r.wrap.b32 \t$dst, $lo, $hi, $amt;",[]>, +++ Requires<[hasHWROT32]>; +++ +++ def SHF_R_WRAP_B32_REG +++ : NVPTXInst<(outs Int32Regs:$dst), +++ (ins Int32Regs:$lo, Int32Regs:$hi, Int32Regs:$amt), +++ "shf.r.wrap.b32 \t$dst, $lo, $hi, $amt;",[]>, +++ Requires<[hasHWROT32]>; +++} +++ +++// HW version of rotate 64 +++def : Pat<(int_nvvm_rotate_b64 Int64Regs:$src, (i32 imm:$amt)), +++ (PACK_TWO_INT32 +++ (SHF_L_WRAP_B32_IMM (GET_HI_INT64 Int64Regs:$src), +++ (GET_LO_INT64 Int64Regs:$src), imm:$amt), +++ (SHF_L_WRAP_B32_IMM (GET_LO_INT64 Int64Regs:$src), +++ (GET_HI_INT64 Int64Regs:$src), imm:$amt))>, +++ Requires<[hasHWROT32]>; +++ +++def : Pat<(int_nvvm_rotate_b64 Int64Regs:$src, Int32Regs:$amt), +++ (PACK_TWO_INT32 +++ (SHF_L_WRAP_B32_REG (GET_HI_INT64 Int64Regs:$src), +++ (GET_LO_INT64 Int64Regs:$src), Int32Regs:$amt), +++ (SHF_L_WRAP_B32_REG (GET_LO_INT64 Int64Regs:$src), +++ (GET_HI_INT64 Int64Regs:$src), Int32Regs:$amt))>, +++ Requires<[hasHWROT32]>; +++ +++ +++def : Pat<(int_nvvm_rotate_right_b64 Int64Regs:$src, (i32 imm:$amt)), +++ (PACK_TWO_INT32 +++ (SHF_R_WRAP_B32_IMM (GET_LO_INT64 Int64Regs:$src), +++ (GET_HI_INT64 Int64Regs:$src), imm:$amt), +++ (SHF_R_WRAP_B32_IMM (GET_HI_INT64 Int64Regs:$src), +++ (GET_LO_INT64 Int64Regs:$src), imm:$amt))>, +++ Requires<[hasHWROT32]>; +++ +++def : Pat<(int_nvvm_rotate_right_b64 Int64Regs:$src, Int32Regs:$amt), +++ (PACK_TWO_INT32 +++ (SHF_R_WRAP_B32_REG (GET_LO_INT64 Int64Regs:$src), +++ (GET_HI_INT64 Int64Regs:$src), Int32Regs:$amt), +++ (SHF_R_WRAP_B32_REG (GET_HI_INT64 Int64Regs:$src), +++ (GET_LO_INT64 Int64Regs:$src), Int32Regs:$amt))>, +++ Requires<[hasHWROT32]>; +++ +++// SW version of rotate 64 +++def : Pat<(int_nvvm_rotate_b64 Int64Regs:$src, (i32 imm:$amt)), +++ (ROT64imm_sw Int64Regs:$src, imm:$amt, (SUB_FRM_64 node:$amt))>, +++ Requires<[noHWROT32]>; +++def : Pat<(int_nvvm_rotate_b64 Int64Regs:$src, Int32Regs:$amt), +++ (ROTL64reg_sw Int64Regs:$src, Int32Regs:$amt)>, +++ Requires<[noHWROT32]>; +++def : Pat<(int_nvvm_rotate_right_b64 Int64Regs:$src, (i32 imm:$amt)), +++ (ROT64imm_sw Int64Regs:$src, (SUB_FRM_64 node:$amt), imm:$amt)>, +++ Requires<[noHWROT32]>; +++def : Pat<(int_nvvm_rotate_right_b64 Int64Regs:$src, Int32Regs:$amt), +++ (ROTR64reg_sw Int64Regs:$src, Int32Regs:$amt)>, +++ Requires<[noHWROT32]>; +++ ++ ++ //----------------------------------- ++ // Texture Intrinsics ++diff -ruN --strip-trailing-cr a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp ++--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp +++++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp ++@@ -1109,21 +1109,11 @@ ++ AddrSpaceCastSDNode *CastN = cast(N); ++ unsigned SrcAddrSpace = CastN->getSrcAddressSpace(); ++ unsigned DstAddrSpace = CastN->getDestAddressSpace(); ++- SDLoc DL(N); ++ assert(SrcAddrSpace != DstAddrSpace && ++ "addrspacecast must be between different address spaces"); ++ ++ if (DstAddrSpace == ADDRESS_SPACE_GENERIC) { ++ // Specific to generic ++- ++- if (TM.is64Bit() && TM.getPointerSizeInBits(SrcAddrSpace) == 32) { ++- SDValue CvtNone = ++- CurDAG->getTargetConstant(NVPTX::PTXCvtMode::NONE, DL, MVT::i32); ++- SDNode *Cvt = CurDAG->getMachineNode(NVPTX::CVT_u64_u32, DL, MVT::i64, ++- Src, CvtNone); ++- Src = SDValue(Cvt, 0); ++- } ++- ++ unsigned Opc; ++ switch (SrcAddrSpace) { ++ default: report_fatal_error("Bad address space in addrspacecast"); ++@@ -1131,16 +1121,26 @@ ++ Opc = TM.is64Bit() ? NVPTX::cvta_global_64 : NVPTX::cvta_global; ++ break; ++ case ADDRESS_SPACE_SHARED: ++- Opc = TM.is64Bit() ? NVPTX::cvta_shared_64 : NVPTX::cvta_shared; +++ Opc = TM.is64Bit() ? (TM.getPointerSizeInBits(SrcAddrSpace) == 32 +++ ? NVPTX::cvta_shared_6432 +++ : NVPTX::cvta_shared_64) +++ : NVPTX::cvta_shared; ++ break; ++ case ADDRESS_SPACE_CONST: ++- Opc = TM.is64Bit() ? NVPTX::cvta_const_64 : NVPTX::cvta_const; +++ Opc = TM.is64Bit() ? (TM.getPointerSizeInBits(SrcAddrSpace) == 32 +++ ? NVPTX::cvta_const_6432 +++ : NVPTX::cvta_const_64) +++ : NVPTX::cvta_const; ++ break; ++ case ADDRESS_SPACE_LOCAL: ++- Opc = TM.is64Bit() ? NVPTX::cvta_local_64 : NVPTX::cvta_local; +++ Opc = TM.is64Bit() ? (TM.getPointerSizeInBits(SrcAddrSpace) == 32 +++ ? NVPTX::cvta_local_6432 +++ : NVPTX::cvta_local_64) +++ : NVPTX::cvta_local; ++ break; ++ } ++- ReplaceNode(N, CurDAG->getMachineNode(Opc, DL, N->getValueType(0), Src)); +++ ReplaceNode(N, CurDAG->getMachineNode(Opc, SDLoc(N), N->getValueType(0), +++ Src)); ++ return; ++ } else { ++ // Generic to specific ++@@ -1153,28 +1153,30 @@ ++ Opc = TM.is64Bit() ? NVPTX::cvta_to_global_64 : NVPTX::cvta_to_global; ++ break; ++ case ADDRESS_SPACE_SHARED: ++- Opc = TM.is64Bit() ? NVPTX::cvta_to_shared_64 : NVPTX::cvta_to_shared; +++ Opc = TM.is64Bit() ? (TM.getPointerSizeInBits(DstAddrSpace) == 32 +++ ? NVPTX::cvta_to_shared_3264 +++ : NVPTX::cvta_to_shared_64) +++ : NVPTX::cvta_to_shared; ++ break; ++ case ADDRESS_SPACE_CONST: ++- Opc = TM.is64Bit() ? NVPTX::cvta_to_const_64 : NVPTX::cvta_to_const; +++ Opc = TM.is64Bit() ? (TM.getPointerSizeInBits(DstAddrSpace) == 32 +++ ? NVPTX::cvta_to_const_3264 +++ : NVPTX::cvta_to_const_64) +++ : NVPTX::cvta_to_const; ++ break; ++ case ADDRESS_SPACE_LOCAL: ++- Opc = TM.is64Bit() ? NVPTX::cvta_to_local_64 : NVPTX::cvta_to_local; +++ Opc = TM.is64Bit() ? (TM.getPointerSizeInBits(DstAddrSpace) == 32 +++ ? NVPTX::cvta_to_local_3264 +++ : NVPTX::cvta_to_local_64) +++ : NVPTX::cvta_to_local; ++ break; ++ case ADDRESS_SPACE_PARAM: ++- Opc = TM.is64Bit() ? NVPTX::IMOV64rr : NVPTX::IMOV32rr; +++ Opc = TM.is64Bit() ? NVPTX::nvvm_ptr_gen_to_param_64 +++ : NVPTX::nvvm_ptr_gen_to_param; ++ break; ++ } ++- ++- SDNode *CVTA = CurDAG->getMachineNode(Opc, DL, N->getValueType(0), Src); ++- if (TM.is64Bit() && TM.getPointerSizeInBits(DstAddrSpace) == 32) { ++- SDValue CvtNone = ++- CurDAG->getTargetConstant(NVPTX::PTXCvtMode::NONE, DL, MVT::i32); ++- CVTA = CurDAG->getMachineNode(NVPTX::CVT_u32_u64, DL, MVT::i32, ++- SDValue(CVTA, 0), CvtNone); ++- } ++- ++- ReplaceNode(N, CVTA); +++ ReplaceNode(N, CurDAG->getMachineNode(Opc, SDLoc(N), N->getValueType(0), +++ Src)); ++ return; ++ } ++ } ++diff -ruN --strip-trailing-cr a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp ++--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp +++++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp ++@@ -594,13 +594,20 @@ ++ setOperationAction(ISD::BITREVERSE, MVT::i32, Legal); ++ setOperationAction(ISD::BITREVERSE, MVT::i64, Legal); ++ ++- setOperationAction({ISD::ROTL, ISD::ROTR}, ++- {MVT::i8, MVT::i16, MVT::v2i16, MVT::i32, MVT::i64}, ++- Expand); ++- ++- if (STI.hasHWROT32()) ++- setOperationAction({ISD::FSHL, ISD::FSHR}, MVT::i32, Legal); +++ // TODO: we may consider expanding ROTL/ROTR on older GPUs. Currently on GPUs +++ // that don't have h/w rotation we lower them to multi-instruction assembly. +++ // See ROT*_sw in NVPTXIntrInfo.td +++ setOperationAction(ISD::ROTL, MVT::i64, Legal); +++ setOperationAction(ISD::ROTR, MVT::i64, Legal); +++ setOperationAction(ISD::ROTL, MVT::i32, Legal); +++ setOperationAction(ISD::ROTR, MVT::i32, Legal); ++ +++ setOperationAction(ISD::ROTL, MVT::i16, Expand); +++ setOperationAction(ISD::ROTL, MVT::v2i16, Expand); +++ setOperationAction(ISD::ROTR, MVT::i16, Expand); +++ setOperationAction(ISD::ROTR, MVT::v2i16, Expand); +++ setOperationAction(ISD::ROTL, MVT::i8, Expand); +++ setOperationAction(ISD::ROTR, MVT::i8, Expand); ++ setOperationAction(ISD::BSWAP, MVT::i16, Expand); ++ ++ setOperationAction(ISD::BR_JT, MVT::Other, Custom); ++diff -ruN --strip-trailing-cr a/llvm/test/Assembler/auto_upgrade_nvvm_intrinsics.ll b/llvm/test/Assembler/auto_upgrade_nvvm_intrinsics.ll ++--- a/llvm/test/Assembler/auto_upgrade_nvvm_intrinsics.ll +++++ b/llvm/test/Assembler/auto_upgrade_nvvm_intrinsics.ll ++@@ -31,19 +31,6 @@ ++ declare i64 @llvm.nvvm.bitcast.d2ll(double) ++ declare double @llvm.nvvm.bitcast.ll2d(i64) ++ ++-declare i32 @llvm.nvvm.rotate.b32(i32, i32) ++-declare i64 @llvm.nvvm.rotate.right.b64(i64, i32) ++-declare i64 @llvm.nvvm.rotate.b64(i64, i32) ++- ++-declare ptr addrspace(1) @llvm.nvvm.ptr.gen.to.global.p1.p0(ptr) ++-declare ptr addrspace(3) @llvm.nvvm.ptr.gen.to.shared.p3.p0(ptr) ++-declare ptr addrspace(4) @llvm.nvvm.ptr.gen.to.constant.p4.p0(ptr) ++-declare ptr addrspace(5) @llvm.nvvm.ptr.gen.to.local.p5.p0(ptr) ++-declare ptr @llvm.nvvm.ptr.global.to.gen.p0.p1(ptr addrspace(1)) ++-declare ptr @llvm.nvvm.ptr.shared.to.gen.p0.p3(ptr addrspace(3)) ++-declare ptr @llvm.nvvm.ptr.constant.to.gen.p0.p4(ptr addrspace(4)) ++-declare ptr @llvm.nvvm.ptr.local.to.gen.p0.p5(ptr addrspace(5)) ++- ++ ; CHECK-LABEL: @simple_upgrade ++ define void @simple_upgrade(i32 %a, i64 %b, i16 %c) { ++ ; CHECK: call i32 @llvm.bitreverse.i32(i32 %a) ++@@ -152,42 +139,4 @@ ++ %r4 = call double @llvm.nvvm.bitcast.ll2d(i64 %b) ++ ++ ret void ++-} ++- ++-; CHECK-LABEL: @rotate ++-define void @rotate(i32 %a, i64 %b) { ++-; CHECK: call i32 @llvm.fshl.i32(i32 %a, i32 %a, i32 6) ++-; CHECK: call i64 @llvm.fshr.i64(i64 %b, i64 %b, i64 7) ++-; CHECK: call i64 @llvm.fshl.i64(i64 %b, i64 %b, i64 8) ++-; ++- %r1 = call i32 @llvm.nvvm.rotate.b32(i32 %a, i32 6) ++- %r2 = call i64 @llvm.nvvm.rotate.right.b64(i64 %b, i32 7) ++- %r3 = call i64 @llvm.nvvm.rotate.b64(i64 %b, i32 8) ++- ret void ++-} ++- ++-; CHECK-LABEL: @addrspacecast ++-define void @addrspacecast(ptr %p0) { ++-; CHECK: %1 = addrspacecast ptr %p0 to ptr addrspace(1) ++-; CHECK: %2 = addrspacecast ptr addrspace(1) %1 to ptr ++-; CHECK: %3 = addrspacecast ptr %2 to ptr addrspace(3) ++-; CHECK: %4 = addrspacecast ptr addrspace(3) %3 to ptr ++-; CHECK: %5 = addrspacecast ptr %4 to ptr addrspace(4) ++-; CHECK: %6 = addrspacecast ptr addrspace(4) %5 to ptr ++-; CHECK: %7 = addrspacecast ptr %6 to ptr addrspace(5) ++-; CHECK: %8 = addrspacecast ptr addrspace(5) %7 to ptr ++-; ++- %p1 = call ptr addrspace(1) @llvm.nvvm.ptr.gen.to.global.p1.p0(ptr %p0) ++- %p2 = call ptr @llvm.nvvm.ptr.global.to.gen.p0.p1(ptr addrspace(1) %p1) ++- ++- %p3 = call ptr addrspace(3) @llvm.nvvm.ptr.gen.to.shared.p3.p0(ptr %p2) ++- %p4 = call ptr @llvm.nvvm.ptr.shared.to.gen.p0.p3(ptr addrspace(3) %p3) ++- ++- %p5 = call ptr addrspace(4) @llvm.nvvm.ptr.gen.to.constant.p4.p0(ptr %p4) ++- %p6 = call ptr @llvm.nvvm.ptr.constant.to.gen.p0.p4(ptr addrspace(4) %p5) ++- ++- %p7 = call ptr addrspace(5) @llvm.nvvm.ptr.gen.to.local.p5.p0(ptr %p6) ++- %p8 = call ptr @llvm.nvvm.ptr.local.to.gen.p0.p5(ptr addrspace(5) %p7) ++- ++- ret void ++-} +++} ++\ No newline at end of file ++diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/AMDGPU/freeze.ll b/llvm/test/CodeGen/AMDGPU/freeze.ll ++--- a/llvm/test/CodeGen/AMDGPU/freeze.ll +++++ b/llvm/test/CodeGen/AMDGPU/freeze.ll ++@@ -1,1856 +0,0 @@ ++-; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py ++-; RUN: llc -global-isel=0 -mtriple=amdgcn-mesa-mesa3d -mcpu=gfx1010 < %s | FileCheck -check-prefixes=GFX10,GFX10-SDAG %s ++-; RUN: llc -global-isel=1 -mtriple=amdgcn-mesa-mesa3d -mcpu=gfx1010 < %s | FileCheck -check-prefixes=GFX10,GFX10-GISEL %s ++-; RUN: llc -global-isel=0 -mtriple=amdgcn-mesa-mesa3d -mcpu=gfx1100 -amdgpu-enable-delay-alu=0 < %s | FileCheck -check-prefixes=GFX11,GFX11-SDAG %s ++-; RUN: llc -global-isel=1 -mtriple=amdgcn-mesa-mesa3d -mcpu=gfx1100 -amdgpu-enable-delay-alu=0 < %s | FileCheck -check-prefixes=GFX11,GFX11-GISEL %s ++- ++-define void @freeze_v2i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { ++-; GFX10-LABEL: freeze_v2i32: ++-; GFX10: ; %bb.0: ++-; GFX10-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX10-NEXT: global_load_dwordx2 v[0:1], v[0:1], off ++-; GFX10-NEXT: s_waitcnt vmcnt(0) ++-; GFX10-NEXT: global_store_dwordx2 v[2:3], v[0:1], off ++-; GFX10-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX11-LABEL: freeze_v2i32: ++-; GFX11: ; %bb.0: ++-; GFX11-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX11-NEXT: global_load_b64 v[0:1], v[0:1], off ++-; GFX11-NEXT: s_waitcnt vmcnt(0) ++-; GFX11-NEXT: global_store_b64 v[2:3], v[0:1], off ++-; GFX11-NEXT: s_setpc_b64 s[30:31] ++- %a = load <2 x i32>, ptr addrspace(1) %ptra, align 4 ++- %freeze = freeze <2 x i32> %a ++- store <2 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 ++- ret void ++-} ++- ++-define void @freeze_v3i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { ++-; GFX10-LABEL: freeze_v3i32: ++-; GFX10: ; %bb.0: ++-; GFX10-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX10-NEXT: global_load_dwordx3 v[4:6], v[0:1], off ++-; GFX10-NEXT: s_waitcnt vmcnt(0) ++-; GFX10-NEXT: global_store_dwordx3 v[2:3], v[4:6], off ++-; GFX10-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX11-LABEL: freeze_v3i32: ++-; GFX11: ; %bb.0: ++-; GFX11-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX11-NEXT: global_load_b96 v[4:6], v[0:1], off ++-; GFX11-NEXT: s_waitcnt vmcnt(0) ++-; GFX11-NEXT: global_store_b96 v[2:3], v[4:6], off ++-; GFX11-NEXT: s_setpc_b64 s[30:31] ++- %a = load <3 x i32>, ptr addrspace(1) %ptra, align 4 ++- %freeze = freeze <3 x i32> %a ++- store <3 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 ++- ret void ++-} ++- ++-define void @freeze_v4i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { ++-; GFX10-LABEL: freeze_v4i32: ++-; GFX10: ; %bb.0: ++-; GFX10-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX10-NEXT: global_load_dwordx4 v[4:7], v[0:1], off ++-; GFX10-NEXT: s_waitcnt vmcnt(0) ++-; GFX10-NEXT: global_store_dwordx4 v[2:3], v[4:7], off ++-; GFX10-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX11-LABEL: freeze_v4i32: ++-; GFX11: ; %bb.0: ++-; GFX11-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX11-NEXT: global_load_b128 v[4:7], v[0:1], off ++-; GFX11-NEXT: s_waitcnt vmcnt(0) ++-; GFX11-NEXT: global_store_b128 v[2:3], v[4:7], off ++-; GFX11-NEXT: s_setpc_b64 s[30:31] ++- %a = load <4 x i32>, ptr addrspace(1) %ptra, align 4 ++- %freeze = freeze <4 x i32> %a ++- store <4 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 ++- ret void ++-} ++- ++-define void @freeze_v5i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { ++-; GFX10-SDAG-LABEL: freeze_v5i32: ++-; GFX10-SDAG: ; %bb.0: ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX10-SDAG-NEXT: s_clause 0x1 ++-; GFX10-SDAG-NEXT: global_load_dword v8, v[0:1], off offset:16 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) ++-; GFX10-SDAG-NEXT: global_store_dword v[2:3], v8, off offset:16 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off ++-; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX10-GISEL-LABEL: freeze_v5i32: ++-; GFX10-GISEL: ; %bb.0: ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX10-GISEL-NEXT: s_clause 0x1 ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off ++-; GFX10-GISEL-NEXT: global_load_dword v8, v[0:1], off offset:16 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) ++-; GFX10-GISEL-NEXT: global_store_dword v[2:3], v8, off offset:16 ++-; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX11-SDAG-LABEL: freeze_v5i32: ++-; GFX11-SDAG: ; %bb.0: ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX11-SDAG-NEXT: s_clause 0x1 ++-; GFX11-SDAG-NEXT: global_load_b32 v8, v[0:1], off offset:16 ++-; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) ++-; GFX11-SDAG-NEXT: global_store_b32 v[2:3], v8, off offset:16 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off ++-; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX11-GISEL-LABEL: freeze_v5i32: ++-; GFX11-GISEL: ; %bb.0: ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX11-GISEL-NEXT: s_clause 0x1 ++-; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off ++-; GFX11-GISEL-NEXT: global_load_b32 v0, v[0:1], off offset:16 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) ++-; GFX11-GISEL-NEXT: global_store_b32 v[2:3], v0, off offset:16 ++-; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] ++- %a = load <5 x i32>, ptr addrspace(1) %ptra, align 4 ++- %freeze = freeze <5 x i32> %a ++- store <5 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 ++- ret void ++-} ++- ++-define void @freeze_v6i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { ++-; GFX10-SDAG-LABEL: freeze_v6i32: ++-; GFX10-SDAG: ; %bb.0: ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX10-SDAG-NEXT: s_clause 0x1 ++-; GFX10-SDAG-NEXT: global_load_dwordx2 v[8:9], v[0:1], off offset:16 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) ++-; GFX10-SDAG-NEXT: global_store_dwordx2 v[2:3], v[8:9], off offset:16 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off ++-; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX10-GISEL-LABEL: freeze_v6i32: ++-; GFX10-GISEL: ; %bb.0: ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX10-GISEL-NEXT: s_clause 0x1 ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off ++-; GFX10-GISEL-NEXT: global_load_dwordx2 v[8:9], v[0:1], off offset:16 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) ++-; GFX10-GISEL-NEXT: global_store_dwordx2 v[2:3], v[8:9], off offset:16 ++-; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX11-SDAG-LABEL: freeze_v6i32: ++-; GFX11-SDAG: ; %bb.0: ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX11-SDAG-NEXT: s_clause 0x1 ++-; GFX11-SDAG-NEXT: global_load_b64 v[8:9], v[0:1], off offset:16 ++-; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) ++-; GFX11-SDAG-NEXT: global_store_b64 v[2:3], v[8:9], off offset:16 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off ++-; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX11-GISEL-LABEL: freeze_v6i32: ++-; GFX11-GISEL: ; %bb.0: ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX11-GISEL-NEXT: s_clause 0x1 ++-; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off ++-; GFX11-GISEL-NEXT: global_load_b64 v[0:1], v[0:1], off offset:16 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) ++-; GFX11-GISEL-NEXT: global_store_b64 v[2:3], v[0:1], off offset:16 ++-; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] ++- %a = load <6 x i32>, ptr addrspace(1) %ptra, align 4 ++- %freeze = freeze <6 x i32> %a ++- store <6 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 ++- ret void ++-} ++- ++-define void @freeze_v7i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { ++-; GFX10-SDAG-LABEL: freeze_v7i32: ++-; GFX10-SDAG: ; %bb.0: ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX10-SDAG-NEXT: s_clause 0x1 ++-; GFX10-SDAG-NEXT: global_load_dwordx3 v[8:10], v[0:1], off offset:16 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) ++-; GFX10-SDAG-NEXT: global_store_dwordx3 v[2:3], v[8:10], off offset:16 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off ++-; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX10-GISEL-LABEL: freeze_v7i32: ++-; GFX10-GISEL: ; %bb.0: ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX10-GISEL-NEXT: s_clause 0x1 ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off ++-; GFX10-GISEL-NEXT: global_load_dwordx3 v[8:10], v[0:1], off offset:16 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) ++-; GFX10-GISEL-NEXT: global_store_dwordx3 v[2:3], v[8:10], off offset:16 ++-; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX11-SDAG-LABEL: freeze_v7i32: ++-; GFX11-SDAG: ; %bb.0: ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX11-SDAG-NEXT: s_clause 0x1 ++-; GFX11-SDAG-NEXT: global_load_b96 v[8:10], v[0:1], off offset:16 ++-; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) ++-; GFX11-SDAG-NEXT: global_store_b96 v[2:3], v[8:10], off offset:16 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off ++-; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX11-GISEL-LABEL: freeze_v7i32: ++-; GFX11-GISEL: ; %bb.0: ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX11-GISEL-NEXT: s_clause 0x1 ++-; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off ++-; GFX11-GISEL-NEXT: global_load_b96 v[8:10], v[0:1], off offset:16 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) ++-; GFX11-GISEL-NEXT: global_store_b96 v[2:3], v[8:10], off offset:16 ++-; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] ++- %a = load <7 x i32>, ptr addrspace(1) %ptra, align 4 ++- %freeze = freeze <7 x i32> %a ++- store <7 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 ++- ret void ++-} ++- ++-define void @freeze_v8i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { ++-; GFX10-SDAG-LABEL: freeze_v8i32: ++-; GFX10-SDAG: ; %bb.0: ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX10-SDAG-NEXT: s_clause 0x1 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:16 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:16 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off ++-; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX10-GISEL-LABEL: freeze_v8i32: ++-; GFX10-GISEL: ; %bb.0: ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX10-GISEL-NEXT: s_clause 0x1 ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 ++-; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX11-SDAG-LABEL: freeze_v8i32: ++-; GFX11-SDAG: ; %bb.0: ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX11-SDAG-NEXT: s_clause 0x1 ++-; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:16 ++-; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:16 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off ++-; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX11-GISEL-LABEL: freeze_v8i32: ++-; GFX11-GISEL: ; %bb.0: ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX11-GISEL-NEXT: s_clause 0x1 ++-; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off ++-; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 ++-; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] ++- %a = load <8 x i32>, ptr addrspace(1) %ptra, align 4 ++- %freeze = freeze <8 x i32> %a ++- store <8 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 ++- ret void ++-} ++- ++-define void @freeze_v9i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { ++-; GFX10-SDAG-LABEL: freeze_v9i32: ++-; GFX10-SDAG: ; %bb.0: ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX10-SDAG-NEXT: s_clause 0x2 ++-; GFX10-SDAG-NEXT: global_load_dword v12, v[0:1], off offset:32 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) ++-; GFX10-SDAG-NEXT: global_store_dword v[2:3], v12, off offset:32 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 ++-; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX10-GISEL-LABEL: freeze_v9i32: ++-; GFX10-GISEL: ; %bb.0: ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX10-GISEL-NEXT: s_clause 0x2 ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 ++-; GFX10-GISEL-NEXT: global_load_dword v12, v[0:1], off offset:32 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) ++-; GFX10-GISEL-NEXT: global_store_dword v[2:3], v12, off offset:32 ++-; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX11-SDAG-LABEL: freeze_v9i32: ++-; GFX11-SDAG: ; %bb.0: ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX11-SDAG-NEXT: s_clause 0x2 ++-; GFX11-SDAG-NEXT: global_load_b32 v12, v[0:1], off offset:32 ++-; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off ++-; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) ++-; GFX11-SDAG-NEXT: global_store_b32 v[2:3], v12, off offset:32 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 ++-; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX11-GISEL-LABEL: freeze_v9i32: ++-; GFX11-GISEL: ; %bb.0: ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX11-GISEL-NEXT: s_clause 0x2 ++-; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off ++-; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 ++-; GFX11-GISEL-NEXT: global_load_b32 v0, v[0:1], off offset:32 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) ++-; GFX11-GISEL-NEXT: global_store_b32 v[2:3], v0, off offset:32 ++-; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] ++- %a = load <9 x i32>, ptr addrspace(1) %ptra, align 4 ++- %freeze = freeze <9 x i32> %a ++- store <9 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 ++- ret void ++-} ++- ++-define void @freeze_v10i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { ++-; GFX10-LABEL: freeze_v10i32: ++-; GFX10: ; %bb.0: ++-; GFX10-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX10-NEXT: s_clause 0x2 ++-; GFX10-NEXT: global_load_dwordx4 v[4:7], v[0:1], off ++-; GFX10-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 ++-; GFX10-NEXT: global_load_dwordx2 v[12:13], v[0:1], off offset:32 ++-; GFX10-NEXT: s_waitcnt vmcnt(2) ++-; GFX10-NEXT: global_store_dwordx4 v[2:3], v[4:7], off ++-; GFX10-NEXT: s_waitcnt vmcnt(1) ++-; GFX10-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 ++-; GFX10-NEXT: s_waitcnt vmcnt(0) ++-; GFX10-NEXT: global_store_dwordx2 v[2:3], v[12:13], off offset:32 ++-; GFX10-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX11-LABEL: freeze_v10i32: ++-; GFX11: ; %bb.0: ++-; GFX11-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX11-NEXT: s_clause 0x2 ++-; GFX11-NEXT: global_load_b128 v[4:7], v[0:1], off ++-; GFX11-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 ++-; GFX11-NEXT: global_load_b64 v[0:1], v[0:1], off offset:32 ++-; GFX11-NEXT: s_waitcnt vmcnt(2) ++-; GFX11-NEXT: global_store_b128 v[2:3], v[4:7], off ++-; GFX11-NEXT: s_waitcnt vmcnt(1) ++-; GFX11-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 ++-; GFX11-NEXT: s_waitcnt vmcnt(0) ++-; GFX11-NEXT: global_store_b64 v[2:3], v[0:1], off offset:32 ++-; GFX11-NEXT: s_setpc_b64 s[30:31] ++- %a = load <10 x i32>, ptr addrspace(1) %ptra, align 4 ++- %freeze = freeze <10 x i32> %a ++- store <10 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 ++- ret void ++-} ++- ++-define void @freeze_v11i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { ++-; GFX10-LABEL: freeze_v11i32: ++-; GFX10: ; %bb.0: ++-; GFX10-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX10-NEXT: s_clause 0x2 ++-; GFX10-NEXT: global_load_dwordx4 v[4:7], v[0:1], off ++-; GFX10-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 ++-; GFX10-NEXT: global_load_dwordx3 v[12:14], v[0:1], off offset:32 ++-; GFX10-NEXT: s_waitcnt vmcnt(2) ++-; GFX10-NEXT: global_store_dwordx4 v[2:3], v[4:7], off ++-; GFX10-NEXT: s_waitcnt vmcnt(1) ++-; GFX10-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 ++-; GFX10-NEXT: s_waitcnt vmcnt(0) ++-; GFX10-NEXT: global_store_dwordx3 v[2:3], v[12:14], off offset:32 ++-; GFX10-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX11-LABEL: freeze_v11i32: ++-; GFX11: ; %bb.0: ++-; GFX11-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX11-NEXT: s_clause 0x2 ++-; GFX11-NEXT: global_load_b128 v[4:7], v[0:1], off ++-; GFX11-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 ++-; GFX11-NEXT: global_load_b96 v[12:14], v[0:1], off offset:32 ++-; GFX11-NEXT: s_waitcnt vmcnt(2) ++-; GFX11-NEXT: global_store_b128 v[2:3], v[4:7], off ++-; GFX11-NEXT: s_waitcnt vmcnt(1) ++-; GFX11-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 ++-; GFX11-NEXT: s_waitcnt vmcnt(0) ++-; GFX11-NEXT: global_store_b96 v[2:3], v[12:14], off offset:32 ++-; GFX11-NEXT: s_setpc_b64 s[30:31] ++- %a = load <11 x i32>, ptr addrspace(1) %ptra, align 4 ++- %freeze = freeze <11 x i32> %a ++- store <11 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 ++- ret void ++-} ++- ++-define void @freeze_v12i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { ++-; GFX10-LABEL: freeze_v12i32: ++-; GFX10: ; %bb.0: ++-; GFX10-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX10-NEXT: s_clause 0x2 ++-; GFX10-NEXT: global_load_dwordx4 v[4:7], v[0:1], off ++-; GFX10-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 ++-; GFX10-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 ++-; GFX10-NEXT: s_waitcnt vmcnt(2) ++-; GFX10-NEXT: global_store_dwordx4 v[2:3], v[4:7], off ++-; GFX10-NEXT: s_waitcnt vmcnt(1) ++-; GFX10-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 ++-; GFX10-NEXT: s_waitcnt vmcnt(0) ++-; GFX10-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 ++-; GFX10-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX11-LABEL: freeze_v12i32: ++-; GFX11: ; %bb.0: ++-; GFX11-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX11-NEXT: s_clause 0x2 ++-; GFX11-NEXT: global_load_b128 v[4:7], v[0:1], off ++-; GFX11-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 ++-; GFX11-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 ++-; GFX11-NEXT: s_waitcnt vmcnt(2) ++-; GFX11-NEXT: global_store_b128 v[2:3], v[4:7], off ++-; GFX11-NEXT: s_waitcnt vmcnt(1) ++-; GFX11-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 ++-; GFX11-NEXT: s_waitcnt vmcnt(0) ++-; GFX11-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 ++-; GFX11-NEXT: s_setpc_b64 s[30:31] ++- %a = load <12 x i32>, ptr addrspace(1) %ptra, align 4 ++- %freeze = freeze <12 x i32> %a ++- store <12 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 ++- ret void ++-} ++-define void @freeze_v13i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { ++-; GFX10-SDAG-LABEL: freeze_v13i32: ++-; GFX10-SDAG: ; %bb.0: ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX10-SDAG-NEXT: s_clause 0x3 ++-; GFX10-SDAG-NEXT: global_load_dword v16, v[0:1], off offset:48 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:32 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:16 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) ++-; GFX10-SDAG-NEXT: global_store_dword v[2:3], v16, off offset:48 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:32 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:16 ++-; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX10-GISEL-LABEL: freeze_v13i32: ++-; GFX10-GISEL: ; %bb.0: ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX10-GISEL-NEXT: s_clause 0x3 ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 ++-; GFX10-GISEL-NEXT: global_load_dword v16, v[0:1], off offset:48 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) ++-; GFX10-GISEL-NEXT: global_store_dword v[2:3], v16, off offset:48 ++-; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX11-SDAG-LABEL: freeze_v13i32: ++-; GFX11-SDAG: ; %bb.0: ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX11-SDAG-NEXT: s_clause 0x3 ++-; GFX11-SDAG-NEXT: global_load_b32 v16, v[0:1], off offset:48 ++-; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:32 ++-; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off ++-; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off offset:16 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) ++-; GFX11-SDAG-NEXT: global_store_b32 v[2:3], v16, off offset:48 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:32 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off offset:16 ++-; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX11-GISEL-LABEL: freeze_v13i32: ++-; GFX11-GISEL: ; %bb.0: ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX11-GISEL-NEXT: s_clause 0x3 ++-; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off ++-; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 ++-; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 ++-; GFX11-GISEL-NEXT: global_load_b32 v0, v[0:1], off offset:48 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) ++-; GFX11-GISEL-NEXT: global_store_b32 v[2:3], v0, off offset:48 ++-; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] ++- %a = load <13 x i32>, ptr addrspace(1) %ptra, align 4 ++- %freeze = freeze <13 x i32> %a ++- store <13 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 ++- ret void ++-} ++- ++-define void @freeze_v14i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { ++-; GFX10-SDAG-LABEL: freeze_v14i32: ++-; GFX10-SDAG: ; %bb.0: ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX10-SDAG-NEXT: s_clause 0x3 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:32 ++-; GFX10-SDAG-NEXT: global_load_dwordx2 v[16:17], v[0:1], off offset:48 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:16 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:32 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) ++-; GFX10-SDAG-NEXT: global_store_dwordx2 v[2:3], v[16:17], off offset:48 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:16 ++-; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX10-GISEL-LABEL: freeze_v14i32: ++-; GFX10-GISEL: ; %bb.0: ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX10-GISEL-NEXT: s_clause 0x3 ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 ++-; GFX10-GISEL-NEXT: global_load_dwordx2 v[16:17], v[0:1], off offset:48 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) ++-; GFX10-GISEL-NEXT: global_store_dwordx2 v[2:3], v[16:17], off offset:48 ++-; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX11-SDAG-LABEL: freeze_v14i32: ++-; GFX11-SDAG: ; %bb.0: ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX11-SDAG-NEXT: s_clause 0x3 ++-; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:32 ++-; GFX11-SDAG-NEXT: global_load_b64 v[16:17], v[0:1], off offset:48 ++-; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off ++-; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off offset:16 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:32 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) ++-; GFX11-SDAG-NEXT: global_store_b64 v[2:3], v[16:17], off offset:48 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off offset:16 ++-; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX11-GISEL-LABEL: freeze_v14i32: ++-; GFX11-GISEL: ; %bb.0: ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX11-GISEL-NEXT: s_clause 0x3 ++-; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off ++-; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 ++-; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 ++-; GFX11-GISEL-NEXT: global_load_b64 v[0:1], v[0:1], off offset:48 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) ++-; GFX11-GISEL-NEXT: global_store_b64 v[2:3], v[0:1], off offset:48 ++-; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] ++- %a = load <14 x i32>, ptr addrspace(1) %ptra, align 4 ++- %freeze = freeze <14 x i32> %a ++- store <14 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 ++- ret void ++-} ++- ++-define void @freeze_v15i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { ++-; GFX10-SDAG-LABEL: freeze_v15i32: ++-; GFX10-SDAG: ; %bb.0: ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX10-SDAG-NEXT: s_clause 0x3 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:32 ++-; GFX10-SDAG-NEXT: global_load_dwordx3 v[16:18], v[0:1], off offset:48 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:16 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:32 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) ++-; GFX10-SDAG-NEXT: global_store_dwordx3 v[2:3], v[16:18], off offset:48 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:16 ++-; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX10-GISEL-LABEL: freeze_v15i32: ++-; GFX10-GISEL: ; %bb.0: ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX10-GISEL-NEXT: s_clause 0x3 ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 ++-; GFX10-GISEL-NEXT: global_load_dwordx3 v[16:18], v[0:1], off offset:48 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) ++-; GFX10-GISEL-NEXT: global_store_dwordx3 v[2:3], v[16:18], off offset:48 ++-; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX11-SDAG-LABEL: freeze_v15i32: ++-; GFX11-SDAG: ; %bb.0: ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX11-SDAG-NEXT: s_clause 0x3 ++-; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:32 ++-; GFX11-SDAG-NEXT: global_load_b96 v[16:18], v[0:1], off offset:48 ++-; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off ++-; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off offset:16 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:32 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) ++-; GFX11-SDAG-NEXT: global_store_b96 v[2:3], v[16:18], off offset:48 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off offset:16 ++-; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX11-GISEL-LABEL: freeze_v15i32: ++-; GFX11-GISEL: ; %bb.0: ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX11-GISEL-NEXT: s_clause 0x3 ++-; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off ++-; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 ++-; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 ++-; GFX11-GISEL-NEXT: global_load_b96 v[16:18], v[0:1], off offset:48 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) ++-; GFX11-GISEL-NEXT: global_store_b96 v[2:3], v[16:18], off offset:48 ++-; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] ++- %a = load <15 x i32>, ptr addrspace(1) %ptra, align 4 ++- %freeze = freeze <15 x i32> %a ++- store <15 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 ++- ret void ++-} ++- ++-define void @freeze_v16i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { ++-; GFX10-SDAG-LABEL: freeze_v16i32: ++-; GFX10-SDAG: ; %bb.0: ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX10-SDAG-NEXT: s_clause 0x3 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:32 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:48 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:16 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:32 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:48 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:16 ++-; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX10-GISEL-LABEL: freeze_v16i32: ++-; GFX10-GISEL: ; %bb.0: ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX10-GISEL-NEXT: s_clause 0x3 ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:48 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:48 ++-; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX11-SDAG-LABEL: freeze_v16i32: ++-; GFX11-SDAG: ; %bb.0: ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX11-SDAG-NEXT: s_clause 0x3 ++-; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:32 ++-; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off offset:48 ++-; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off ++-; GFX11-SDAG-NEXT: global_load_b128 v[16:19], v[0:1], off offset:16 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:32 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off offset:48 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[16:19], off offset:16 ++-; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX11-GISEL-LABEL: freeze_v16i32: ++-; GFX11-GISEL: ; %bb.0: ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX11-GISEL-NEXT: s_clause 0x3 ++-; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off ++-; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 ++-; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 ++-; GFX11-GISEL-NEXT: global_load_b128 v[16:19], v[0:1], off offset:48 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[16:19], off offset:48 ++-; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] ++- %a = load <16 x i32>, ptr addrspace(1) %ptra, align 4 ++- %freeze = freeze <16 x i32> %a ++- store <16 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 ++- ret void ++-} ++- ++-define void @freeze_v17i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { ++-; GFX10-SDAG-LABEL: freeze_v17i32: ++-; GFX10-SDAG: ; %bb.0: ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX10-SDAG-NEXT: s_clause 0x4 ++-; GFX10-SDAG-NEXT: global_load_dword v20, v[0:1], off offset:64 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:32 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:48 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:16 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(4) ++-; GFX10-SDAG-NEXT: global_store_dword v[2:3], v20, off offset:64 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:32 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:48 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:16 ++-; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX10-GISEL-LABEL: freeze_v17i32: ++-; GFX10-GISEL: ; %bb.0: ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX10-GISEL-NEXT: s_clause 0x4 ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:48 ++-; GFX10-GISEL-NEXT: global_load_dword v20, v[0:1], off offset:64 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(4) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:48 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) ++-; GFX10-GISEL-NEXT: global_store_dword v[2:3], v20, off offset:64 ++-; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX11-SDAG-LABEL: freeze_v17i32: ++-; GFX11-SDAG: ; %bb.0: ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX11-SDAG-NEXT: s_clause 0x4 ++-; GFX11-SDAG-NEXT: global_load_b32 v20, v[0:1], off offset:64 ++-; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:32 ++-; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off offset:48 ++-; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off ++-; GFX11-SDAG-NEXT: global_load_b128 v[16:19], v[0:1], off offset:16 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(4) ++-; GFX11-SDAG-NEXT: global_store_b32 v[2:3], v20, off offset:64 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:32 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off offset:48 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[16:19], off offset:16 ++-; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX11-GISEL-LABEL: freeze_v17i32: ++-; GFX11-GISEL: ; %bb.0: ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX11-GISEL-NEXT: s_clause 0x4 ++-; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off ++-; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 ++-; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 ++-; GFX11-GISEL-NEXT: global_load_b128 v[16:19], v[0:1], off offset:48 ++-; GFX11-GISEL-NEXT: global_load_b32 v0, v[0:1], off offset:64 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(4) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[16:19], off offset:48 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) ++-; GFX11-GISEL-NEXT: global_store_b32 v[2:3], v0, off offset:64 ++-; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] ++- %a = load <17 x i32>, ptr addrspace(1) %ptra, align 4 ++- %freeze = freeze <17 x i32> %a ++- store <17 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 ++- ret void ++-} ++- ++-define void @freeze_v18i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { ++-; GFX10-SDAG-LABEL: freeze_v18i32: ++-; GFX10-SDAG: ; %bb.0: ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX10-SDAG-NEXT: s_clause 0x4 ++-; GFX10-SDAG-NEXT: global_load_dwordx2 v[20:21], v[0:1], off offset:64 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:32 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:48 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:16 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(4) ++-; GFX10-SDAG-NEXT: global_store_dwordx2 v[2:3], v[20:21], off offset:64 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:32 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:48 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:16 ++-; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX10-GISEL-LABEL: freeze_v18i32: ++-; GFX10-GISEL: ; %bb.0: ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX10-GISEL-NEXT: s_clause 0x4 ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:48 ++-; GFX10-GISEL-NEXT: global_load_dwordx2 v[20:21], v[0:1], off offset:64 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(4) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:48 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) ++-; GFX10-GISEL-NEXT: global_store_dwordx2 v[2:3], v[20:21], off offset:64 ++-; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX11-SDAG-LABEL: freeze_v18i32: ++-; GFX11-SDAG: ; %bb.0: ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX11-SDAG-NEXT: s_clause 0x4 ++-; GFX11-SDAG-NEXT: global_load_b64 v[20:21], v[0:1], off offset:64 ++-; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:32 ++-; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off offset:48 ++-; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off ++-; GFX11-SDAG-NEXT: global_load_b128 v[16:19], v[0:1], off offset:16 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(4) ++-; GFX11-SDAG-NEXT: global_store_b64 v[2:3], v[20:21], off offset:64 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:32 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off offset:48 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[16:19], off offset:16 ++-; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX11-GISEL-LABEL: freeze_v18i32: ++-; GFX11-GISEL: ; %bb.0: ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX11-GISEL-NEXT: s_clause 0x4 ++-; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off ++-; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 ++-; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 ++-; GFX11-GISEL-NEXT: global_load_b128 v[16:19], v[0:1], off offset:48 ++-; GFX11-GISEL-NEXT: global_load_b64 v[0:1], v[0:1], off offset:64 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(4) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[16:19], off offset:48 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) ++-; GFX11-GISEL-NEXT: global_store_b64 v[2:3], v[0:1], off offset:64 ++-; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] ++- %a = load <18 x i32>, ptr addrspace(1) %ptra, align 4 ++- %freeze = freeze <18 x i32> %a ++- store <18 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 ++- ret void ++-} ++- ++-define void @freeze_v19i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { ++-; GFX10-SDAG-LABEL: freeze_v19i32: ++-; GFX10-SDAG: ; %bb.0: ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX10-SDAG-NEXT: s_clause 0x4 ++-; GFX10-SDAG-NEXT: global_load_dwordx3 v[20:22], v[0:1], off offset:64 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:32 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:48 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:16 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(4) ++-; GFX10-SDAG-NEXT: global_store_dwordx3 v[2:3], v[20:22], off offset:64 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:32 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:48 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:16 ++-; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX10-GISEL-LABEL: freeze_v19i32: ++-; GFX10-GISEL: ; %bb.0: ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX10-GISEL-NEXT: s_clause 0x4 ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:48 ++-; GFX10-GISEL-NEXT: global_load_dwordx3 v[20:22], v[0:1], off offset:64 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(4) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:48 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) ++-; GFX10-GISEL-NEXT: global_store_dwordx3 v[2:3], v[20:22], off offset:64 ++-; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX11-SDAG-LABEL: freeze_v19i32: ++-; GFX11-SDAG: ; %bb.0: ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX11-SDAG-NEXT: s_clause 0x4 ++-; GFX11-SDAG-NEXT: global_load_b96 v[20:22], v[0:1], off offset:64 ++-; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:32 ++-; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off offset:48 ++-; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off ++-; GFX11-SDAG-NEXT: global_load_b128 v[16:19], v[0:1], off offset:16 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(4) ++-; GFX11-SDAG-NEXT: global_store_b96 v[2:3], v[20:22], off offset:64 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:32 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off offset:48 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[16:19], off offset:16 ++-; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX11-GISEL-LABEL: freeze_v19i32: ++-; GFX11-GISEL: ; %bb.0: ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX11-GISEL-NEXT: s_clause 0x4 ++-; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off ++-; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 ++-; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 ++-; GFX11-GISEL-NEXT: global_load_b128 v[16:19], v[0:1], off offset:48 ++-; GFX11-GISEL-NEXT: global_load_b96 v[20:22], v[0:1], off offset:64 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(4) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[16:19], off offset:48 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) ++-; GFX11-GISEL-NEXT: global_store_b96 v[2:3], v[20:22], off offset:64 ++-; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] ++- %a = load <19 x i32>, ptr addrspace(1) %ptra, align 4 ++- %freeze = freeze <19 x i32> %a ++- store <19 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 ++- ret void ++-} ++- ++-define void @freeze_v20i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { ++-; GFX10-SDAG-LABEL: freeze_v20i32: ++-; GFX10-SDAG: ; %bb.0: ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX10-SDAG-NEXT: s_clause 0x4 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:64 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:32 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:48 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[16:19], v[0:1], off ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[20:23], v[0:1], off offset:16 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(4) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:64 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:32 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:48 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[16:19], off ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[20:23], off offset:16 ++-; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX10-GISEL-LABEL: freeze_v20i32: ++-; GFX10-GISEL: ; %bb.0: ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX10-GISEL-NEXT: s_clause 0x4 ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:48 ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[20:23], v[0:1], off offset:64 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(4) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:48 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[20:23], off offset:64 ++-; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX11-SDAG-LABEL: freeze_v20i32: ++-; GFX11-SDAG: ; %bb.0: ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX11-SDAG-NEXT: s_clause 0x4 ++-; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:64 ++-; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off offset:32 ++-; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off offset:48 ++-; GFX11-SDAG-NEXT: global_load_b128 v[16:19], v[0:1], off ++-; GFX11-SDAG-NEXT: global_load_b128 v[20:23], v[0:1], off offset:16 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(4) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:64 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off offset:32 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off offset:48 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[16:19], off ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[20:23], off offset:16 ++-; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX11-GISEL-LABEL: freeze_v20i32: ++-; GFX11-GISEL: ; %bb.0: ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX11-GISEL-NEXT: s_clause 0x4 ++-; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off ++-; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 ++-; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 ++-; GFX11-GISEL-NEXT: global_load_b128 v[16:19], v[0:1], off offset:48 ++-; GFX11-GISEL-NEXT: global_load_b128 v[20:23], v[0:1], off offset:64 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(4) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[16:19], off offset:48 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[20:23], off offset:64 ++-; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] ++- %a = load <20 x i32>, ptr addrspace(1) %ptra, align 4 ++- %freeze = freeze <20 x i32> %a ++- store <20 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 ++- ret void ++-} ++- ++-define void @freeze_v21i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { ++-; GFX10-SDAG-LABEL: freeze_v21i32: ++-; GFX10-SDAG: ; %bb.0: ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX10-SDAG-NEXT: s_clause 0x5 ++-; GFX10-SDAG-NEXT: global_load_dword v24, v[0:1], off offset:80 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:64 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:32 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:48 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[16:19], v[0:1], off ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[20:23], v[0:1], off offset:16 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(5) ++-; GFX10-SDAG-NEXT: global_store_dword v[2:3], v24, off offset:80 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(4) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:64 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:32 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:48 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[16:19], off ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[20:23], off offset:16 ++-; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX10-GISEL-LABEL: freeze_v21i32: ++-; GFX10-GISEL: ; %bb.0: ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX10-GISEL-NEXT: s_clause 0x5 ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:48 ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[20:23], v[0:1], off offset:64 ++-; GFX10-GISEL-NEXT: global_load_dword v24, v[0:1], off offset:80 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(5) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(4) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:48 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[20:23], off offset:64 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) ++-; GFX10-GISEL-NEXT: global_store_dword v[2:3], v24, off offset:80 ++-; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX11-SDAG-LABEL: freeze_v21i32: ++-; GFX11-SDAG: ; %bb.0: ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX11-SDAG-NEXT: s_clause 0x5 ++-; GFX11-SDAG-NEXT: global_load_b32 v24, v[0:1], off offset:80 ++-; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:64 ++-; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off offset:32 ++-; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off offset:48 ++-; GFX11-SDAG-NEXT: global_load_b128 v[16:19], v[0:1], off ++-; GFX11-SDAG-NEXT: global_load_b128 v[20:23], v[0:1], off offset:16 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(5) ++-; GFX11-SDAG-NEXT: global_store_b32 v[2:3], v24, off offset:80 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(4) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:64 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off offset:32 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off offset:48 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[16:19], off ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[20:23], off offset:16 ++-; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX11-GISEL-LABEL: freeze_v21i32: ++-; GFX11-GISEL: ; %bb.0: ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX11-GISEL-NEXT: s_clause 0x5 ++-; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off ++-; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 ++-; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 ++-; GFX11-GISEL-NEXT: global_load_b128 v[16:19], v[0:1], off offset:48 ++-; GFX11-GISEL-NEXT: global_load_b128 v[20:23], v[0:1], off offset:64 ++-; GFX11-GISEL-NEXT: global_load_b32 v0, v[0:1], off offset:80 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(5) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(4) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[16:19], off offset:48 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[20:23], off offset:64 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) ++-; GFX11-GISEL-NEXT: global_store_b32 v[2:3], v0, off offset:80 ++-; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] ++- %a = load <21 x i32>, ptr addrspace(1) %ptra, align 4 ++- %freeze = freeze <21 x i32> %a ++- store <21 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 ++- ret void ++-} ++- ++-define void @freeze_v22i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { ++-; GFX10-SDAG-LABEL: freeze_v22i32: ++-; GFX10-SDAG: ; %bb.0: ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX10-SDAG-NEXT: s_clause 0x5 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:64 ++-; GFX10-SDAG-NEXT: global_load_dwordx2 v[24:25], v[0:1], off offset:80 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:32 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:48 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[16:19], v[0:1], off ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[20:23], v[0:1], off offset:16 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(5) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:64 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(4) ++-; GFX10-SDAG-NEXT: global_store_dwordx2 v[2:3], v[24:25], off offset:80 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:32 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:48 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[16:19], off ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[20:23], off offset:16 ++-; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX10-GISEL-LABEL: freeze_v22i32: ++-; GFX10-GISEL: ; %bb.0: ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX10-GISEL-NEXT: s_clause 0x5 ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:48 ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[20:23], v[0:1], off offset:64 ++-; GFX10-GISEL-NEXT: global_load_dwordx2 v[24:25], v[0:1], off offset:80 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(5) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(4) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:48 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[20:23], off offset:64 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) ++-; GFX10-GISEL-NEXT: global_store_dwordx2 v[2:3], v[24:25], off offset:80 ++-; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX11-SDAG-LABEL: freeze_v22i32: ++-; GFX11-SDAG: ; %bb.0: ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX11-SDAG-NEXT: s_clause 0x5 ++-; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:64 ++-; GFX11-SDAG-NEXT: global_load_b64 v[24:25], v[0:1], off offset:80 ++-; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off offset:32 ++-; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off offset:48 ++-; GFX11-SDAG-NEXT: global_load_b128 v[16:19], v[0:1], off ++-; GFX11-SDAG-NEXT: global_load_b128 v[20:23], v[0:1], off offset:16 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(5) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:64 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(4) ++-; GFX11-SDAG-NEXT: global_store_b64 v[2:3], v[24:25], off offset:80 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off offset:32 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off offset:48 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[16:19], off ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[20:23], off offset:16 ++-; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX11-GISEL-LABEL: freeze_v22i32: ++-; GFX11-GISEL: ; %bb.0: ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX11-GISEL-NEXT: s_clause 0x5 ++-; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off ++-; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 ++-; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 ++-; GFX11-GISEL-NEXT: global_load_b128 v[16:19], v[0:1], off offset:48 ++-; GFX11-GISEL-NEXT: global_load_b128 v[20:23], v[0:1], off offset:64 ++-; GFX11-GISEL-NEXT: global_load_b64 v[0:1], v[0:1], off offset:80 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(5) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(4) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[16:19], off offset:48 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[20:23], off offset:64 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) ++-; GFX11-GISEL-NEXT: global_store_b64 v[2:3], v[0:1], off offset:80 ++-; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] ++- %a = load <22 x i32>, ptr addrspace(1) %ptra, align 4 ++- %freeze = freeze <22 x i32> %a ++- store <22 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 ++- ret void ++-} ++- ++-define void @freeze_v30i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { ++-; GFX10-SDAG-LABEL: freeze_v30i32: ++-; GFX10-SDAG: ; %bb.0: ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX10-SDAG-NEXT: s_clause 0x7 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:96 ++-; GFX10-SDAG-NEXT: global_load_dwordx2 v[32:33], v[0:1], off offset:112 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:64 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:80 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:32 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[20:23], v[0:1], off offset:48 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[24:27], v[0:1], off ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[28:31], v[0:1], off offset:16 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(7) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:96 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(6) ++-; GFX10-SDAG-NEXT: global_store_dwordx2 v[2:3], v[32:33], off offset:112 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(5) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:64 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(4) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:80 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:32 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[20:23], off offset:48 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[24:27], off ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[28:31], off offset:16 ++-; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX10-GISEL-LABEL: freeze_v30i32: ++-; GFX10-GISEL: ; %bb.0: ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX10-GISEL-NEXT: s_clause 0x7 ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:48 ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[20:23], v[0:1], off offset:64 ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[24:27], v[0:1], off offset:80 ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[28:31], v[0:1], off offset:96 ++-; GFX10-GISEL-NEXT: global_load_dwordx2 v[32:33], v[0:1], off offset:112 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(7) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(6) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(5) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(4) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:48 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[20:23], off offset:64 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[24:27], off offset:80 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[28:31], off offset:96 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) ++-; GFX10-GISEL-NEXT: global_store_dwordx2 v[2:3], v[32:33], off offset:112 ++-; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX11-SDAG-LABEL: freeze_v30i32: ++-; GFX11-SDAG: ; %bb.0: ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX11-SDAG-NEXT: s_clause 0x7 ++-; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:96 ++-; GFX11-SDAG-NEXT: global_load_b64 v[32:33], v[0:1], off offset:112 ++-; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off offset:64 ++-; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off offset:80 ++-; GFX11-SDAG-NEXT: global_load_b128 v[16:19], v[0:1], off offset:32 ++-; GFX11-SDAG-NEXT: global_load_b128 v[20:23], v[0:1], off offset:48 ++-; GFX11-SDAG-NEXT: global_load_b128 v[24:27], v[0:1], off ++-; GFX11-SDAG-NEXT: global_load_b128 v[28:31], v[0:1], off offset:16 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(7) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:96 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(6) ++-; GFX11-SDAG-NEXT: global_store_b64 v[2:3], v[32:33], off offset:112 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(5) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off offset:64 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(4) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off offset:80 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[16:19], off offset:32 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[20:23], off offset:48 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[24:27], off ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[28:31], off offset:16 ++-; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX11-GISEL-LABEL: freeze_v30i32: ++-; GFX11-GISEL: ; %bb.0: ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX11-GISEL-NEXT: s_clause 0x7 ++-; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off ++-; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 ++-; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 ++-; GFX11-GISEL-NEXT: global_load_b128 v[16:19], v[0:1], off offset:48 ++-; GFX11-GISEL-NEXT: global_load_b128 v[20:23], v[0:1], off offset:64 ++-; GFX11-GISEL-NEXT: global_load_b128 v[24:27], v[0:1], off offset:80 ++-; GFX11-GISEL-NEXT: global_load_b128 v[28:31], v[0:1], off offset:96 ++-; GFX11-GISEL-NEXT: global_load_b64 v[0:1], v[0:1], off offset:112 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(7) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(6) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(5) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(4) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[16:19], off offset:48 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[20:23], off offset:64 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[24:27], off offset:80 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[28:31], off offset:96 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) ++-; GFX11-GISEL-NEXT: global_store_b64 v[2:3], v[0:1], off offset:112 ++-; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] ++- %a = load <30 x i32>, ptr addrspace(1) %ptra, align 4 ++- %freeze = freeze <30 x i32> %a ++- store <30 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 ++- ret void ++-} ++- ++-define void @freeze_v31i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { ++-; GFX10-SDAG-LABEL: freeze_v31i32: ++-; GFX10-SDAG: ; %bb.0: ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX10-SDAG-NEXT: s_clause 0x7 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:96 ++-; GFX10-SDAG-NEXT: global_load_dwordx3 v[32:34], v[0:1], off offset:112 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:64 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:80 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:32 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[20:23], v[0:1], off offset:48 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[24:27], v[0:1], off ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[28:31], v[0:1], off offset:16 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(7) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:96 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(6) ++-; GFX10-SDAG-NEXT: global_store_dwordx3 v[2:3], v[32:34], off offset:112 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(5) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:64 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(4) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:80 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:32 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[20:23], off offset:48 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[24:27], off ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[28:31], off offset:16 ++-; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX10-GISEL-LABEL: freeze_v31i32: ++-; GFX10-GISEL: ; %bb.0: ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX10-GISEL-NEXT: s_clause 0x7 ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:48 ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[20:23], v[0:1], off offset:64 ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[24:27], v[0:1], off offset:80 ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[28:31], v[0:1], off offset:96 ++-; GFX10-GISEL-NEXT: global_load_dwordx3 v[32:34], v[0:1], off offset:112 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(7) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(6) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(5) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(4) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:48 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[20:23], off offset:64 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[24:27], off offset:80 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[28:31], off offset:96 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) ++-; GFX10-GISEL-NEXT: global_store_dwordx3 v[2:3], v[32:34], off offset:112 ++-; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX11-SDAG-LABEL: freeze_v31i32: ++-; GFX11-SDAG: ; %bb.0: ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX11-SDAG-NEXT: s_clause 0x7 ++-; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:96 ++-; GFX11-SDAG-NEXT: global_load_b96 v[32:34], v[0:1], off offset:112 ++-; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off offset:64 ++-; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off offset:80 ++-; GFX11-SDAG-NEXT: global_load_b128 v[16:19], v[0:1], off offset:32 ++-; GFX11-SDAG-NEXT: global_load_b128 v[20:23], v[0:1], off offset:48 ++-; GFX11-SDAG-NEXT: global_load_b128 v[24:27], v[0:1], off ++-; GFX11-SDAG-NEXT: global_load_b128 v[28:31], v[0:1], off offset:16 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(7) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:96 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(6) ++-; GFX11-SDAG-NEXT: global_store_b96 v[2:3], v[32:34], off offset:112 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(5) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off offset:64 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(4) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off offset:80 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[16:19], off offset:32 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[20:23], off offset:48 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[24:27], off ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[28:31], off offset:16 ++-; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX11-GISEL-LABEL: freeze_v31i32: ++-; GFX11-GISEL: ; %bb.0: ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX11-GISEL-NEXT: s_clause 0x7 ++-; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off ++-; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 ++-; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 ++-; GFX11-GISEL-NEXT: global_load_b128 v[16:19], v[0:1], off offset:48 ++-; GFX11-GISEL-NEXT: global_load_b128 v[20:23], v[0:1], off offset:64 ++-; GFX11-GISEL-NEXT: global_load_b128 v[24:27], v[0:1], off offset:80 ++-; GFX11-GISEL-NEXT: global_load_b128 v[28:31], v[0:1], off offset:96 ++-; GFX11-GISEL-NEXT: global_load_b96 v[32:34], v[0:1], off offset:112 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(7) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(6) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(5) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(4) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[16:19], off offset:48 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[20:23], off offset:64 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[24:27], off offset:80 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[28:31], off offset:96 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) ++-; GFX11-GISEL-NEXT: global_store_b96 v[2:3], v[32:34], off offset:112 ++-; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] ++- %a = load <31 x i32>, ptr addrspace(1) %ptra, align 4 ++- %freeze = freeze <31 x i32> %a ++- store <31 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 ++- ret void ++-} ++- ++-define void @freeze_v32i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { ++-; GFX10-SDAG-LABEL: freeze_v32i32: ++-; GFX10-SDAG: ; %bb.0: ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX10-SDAG-NEXT: s_clause 0x7 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:96 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:112 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:64 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:80 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[20:23], v[0:1], off offset:32 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[24:27], v[0:1], off offset:48 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[28:31], v[0:1], off ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[32:35], v[0:1], off offset:16 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(7) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:96 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(6) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:112 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(5) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:64 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(4) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:80 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[20:23], off offset:32 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[24:27], off offset:48 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[28:31], off ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[32:35], off offset:16 ++-; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX10-GISEL-LABEL: freeze_v32i32: ++-; GFX10-GISEL: ; %bb.0: ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX10-GISEL-NEXT: s_clause 0x7 ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:48 ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[20:23], v[0:1], off offset:64 ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[24:27], v[0:1], off offset:80 ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[28:31], v[0:1], off offset:96 ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[32:35], v[0:1], off offset:112 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(7) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(6) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(5) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(4) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:48 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[20:23], off offset:64 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[24:27], off offset:80 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[28:31], off offset:96 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[32:35], off offset:112 ++-; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX11-SDAG-LABEL: freeze_v32i32: ++-; GFX11-SDAG: ; %bb.0: ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX11-SDAG-NEXT: s_clause 0x7 ++-; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:96 ++-; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off offset:112 ++-; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off offset:64 ++-; GFX11-SDAG-NEXT: global_load_b128 v[16:19], v[0:1], off offset:80 ++-; GFX11-SDAG-NEXT: global_load_b128 v[20:23], v[0:1], off offset:32 ++-; GFX11-SDAG-NEXT: global_load_b128 v[24:27], v[0:1], off offset:48 ++-; GFX11-SDAG-NEXT: global_load_b128 v[28:31], v[0:1], off ++-; GFX11-SDAG-NEXT: global_load_b128 v[32:35], v[0:1], off offset:16 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(7) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:96 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(6) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off offset:112 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(5) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off offset:64 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(4) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[16:19], off offset:80 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[20:23], off offset:32 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[24:27], off offset:48 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[28:31], off ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[32:35], off offset:16 ++-; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX11-GISEL-LABEL: freeze_v32i32: ++-; GFX11-GISEL: ; %bb.0: ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX11-GISEL-NEXT: s_clause 0x7 ++-; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off ++-; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 ++-; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 ++-; GFX11-GISEL-NEXT: global_load_b128 v[16:19], v[0:1], off offset:48 ++-; GFX11-GISEL-NEXT: global_load_b128 v[20:23], v[0:1], off offset:64 ++-; GFX11-GISEL-NEXT: global_load_b128 v[24:27], v[0:1], off offset:80 ++-; GFX11-GISEL-NEXT: global_load_b128 v[28:31], v[0:1], off offset:96 ++-; GFX11-GISEL-NEXT: global_load_b128 v[32:35], v[0:1], off offset:112 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(7) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(6) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(5) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(4) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[16:19], off offset:48 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[20:23], off offset:64 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[24:27], off offset:80 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[28:31], off offset:96 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[32:35], off offset:112 ++-; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] ++- %a = load <32 x i32>, ptr addrspace(1) %ptra, align 4 ++- %freeze = freeze <32 x i32> %a ++- store <32 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 ++- ret void ++-} ++- ++-define void @freeze_i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { ++-; GFX10-LABEL: freeze_i32: ++-; GFX10: ; %bb.0: ++-; GFX10-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX10-NEXT: global_load_dword v0, v[0:1], off ++-; GFX10-NEXT: s_waitcnt vmcnt(0) ++-; GFX10-NEXT: global_store_dword v[2:3], v0, off ++-; GFX10-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX11-LABEL: freeze_i32: ++-; GFX11: ; %bb.0: ++-; GFX11-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX11-NEXT: global_load_b32 v0, v[0:1], off ++-; GFX11-NEXT: s_waitcnt vmcnt(0) ++-; GFX11-NEXT: global_store_b32 v[2:3], v0, off ++-; GFX11-NEXT: s_setpc_b64 s[30:31] ++- %a = load i32, ptr addrspace(1) %ptra, align 4 ++- %freeze = freeze i32 %a ++- store i32 %freeze, ptr addrspace(1) %ptrb, align 4 ++- ret void ++-} ++- ++-define void @freeze_i64(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { ++-; GFX10-LABEL: freeze_i64: ++-; GFX10: ; %bb.0: ++-; GFX10-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX10-NEXT: global_load_dwordx2 v[0:1], v[0:1], off ++-; GFX10-NEXT: s_waitcnt vmcnt(0) ++-; GFX10-NEXT: global_store_dwordx2 v[2:3], v[0:1], off ++-; GFX10-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX11-LABEL: freeze_i64: ++-; GFX11: ; %bb.0: ++-; GFX11-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX11-NEXT: global_load_b64 v[0:1], v[0:1], off ++-; GFX11-NEXT: s_waitcnt vmcnt(0) ++-; GFX11-NEXT: global_store_b64 v[2:3], v[0:1], off ++-; GFX11-NEXT: s_setpc_b64 s[30:31] ++- %a = load i64, ptr addrspace(1) %ptra, align 4 ++- %freeze = freeze i64 %a ++- store i64 %freeze, ptr addrspace(1) %ptrb, align 4 ++- ret void ++-} ++- ++-define void @freeze_float(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { ++-; GFX10-LABEL: freeze_float: ++-; GFX10: ; %bb.0: ++-; GFX10-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX10-NEXT: global_load_dword v0, v[0:1], off ++-; GFX10-NEXT: s_waitcnt vmcnt(0) ++-; GFX10-NEXT: global_store_dword v[2:3], v0, off ++-; GFX10-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX11-LABEL: freeze_float: ++-; GFX11: ; %bb.0: ++-; GFX11-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX11-NEXT: global_load_b32 v0, v[0:1], off ++-; GFX11-NEXT: s_waitcnt vmcnt(0) ++-; GFX11-NEXT: global_store_b32 v[2:3], v0, off ++-; GFX11-NEXT: s_setpc_b64 s[30:31] ++- %a = load float, ptr addrspace(1) %ptra, align 4 ++- %freeze = freeze float %a ++- store float %freeze, ptr addrspace(1) %ptrb, align 4 ++- ret void ++-} ++- ++-define void @freeze_i128(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { ++-; GFX10-LABEL: freeze_i128: ++-; GFX10: ; %bb.0: ++-; GFX10-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX10-NEXT: global_load_dwordx4 v[4:7], v[0:1], off ++-; GFX10-NEXT: s_waitcnt vmcnt(0) ++-; GFX10-NEXT: global_store_dwordx4 v[2:3], v[4:7], off ++-; GFX10-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX11-LABEL: freeze_i128: ++-; GFX11: ; %bb.0: ++-; GFX11-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX11-NEXT: global_load_b128 v[4:7], v[0:1], off ++-; GFX11-NEXT: s_waitcnt vmcnt(0) ++-; GFX11-NEXT: global_store_b128 v[2:3], v[4:7], off ++-; GFX11-NEXT: s_setpc_b64 s[30:31] ++- %a = load i128, ptr addrspace(1) %ptra, align 4 ++- %freeze = freeze i128 %a ++- store i128 %freeze, ptr addrspace(1) %ptrb, align 4 ++- ret void ++-} ++- ++-define void @freeze_i256(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { ++-; GFX10-SDAG-LABEL: freeze_i256: ++-; GFX10-SDAG: ; %bb.0: ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX10-SDAG-NEXT: s_clause 0x1 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:16 ++-; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:16 ++-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) ++-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off ++-; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX10-GISEL-LABEL: freeze_i256: ++-; GFX10-GISEL: ; %bb.0: ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX10-GISEL-NEXT: s_clause 0x1 ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off ++-; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off ++-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) ++-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 ++-; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX11-SDAG-LABEL: freeze_i256: ++-; GFX11-SDAG: ; %bb.0: ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX11-SDAG-NEXT: s_clause 0x1 ++-; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:16 ++-; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:16 ++-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) ++-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off ++-; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] ++-; ++-; GFX11-GISEL-LABEL: freeze_i256: ++-; GFX11-GISEL: ; %bb.0: ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ++-; GFX11-GISEL-NEXT: s_clause 0x1 ++-; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off ++-; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off ++-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) ++-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 ++-; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] ++- %a = load i256, ptr addrspace(1) %ptra, align 4 ++- %freeze = freeze i256 %a ++- store i256 %freeze, ptr addrspace(1) %ptrb, align 4 ++- ret void ++-} ++diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/AMDGPU/GlobalISel/inst-select-unmerge-values.mir b/llvm/test/CodeGen/AMDGPU/GlobalISel/inst-select-unmerge-values.mir ++--- a/llvm/test/CodeGen/AMDGPU/GlobalISel/inst-select-unmerge-values.mir +++++ b/llvm/test/CodeGen/AMDGPU/GlobalISel/inst-select-unmerge-values.mir ++@@ -171,9 +171,11 @@ ++ ; GCN-LABEL: name: test_unmerge_values_s_s64_s_s64_s64_s_s192 ++ ; GCN: liveins: $sgpr0_sgpr1_sgpr2_sgpr3 ++ ; GCN-NEXT: {{ $}} ++- ; GCN-NEXT: [[DEF:%[0-9]+]]:sgpr(s192) = G_IMPLICIT_DEF ++- ; GCN-NEXT: [[UV:%[0-9]+]]:sgpr(s64), [[UV1:%[0-9]+]]:sgpr(s64), [[UV2:%[0-9]+]]:sgpr(s64) = G_UNMERGE_VALUES [[DEF]](s192) ++- ; GCN-NEXT: S_ENDPGM 0, implicit [[UV]](s64), implicit [[UV1]](s64), implicit [[UV2]](s64) +++ ; GCN-NEXT: [[DEF:%[0-9]+]]:sgpr_192 = IMPLICIT_DEF +++ ; GCN-NEXT: [[COPY:%[0-9]+]]:sreg_64 = COPY [[DEF]].sub0_sub1 +++ ; GCN-NEXT: [[COPY1:%[0-9]+]]:sreg_64 = COPY [[DEF]].sub2_sub3 +++ ; GCN-NEXT: [[COPY2:%[0-9]+]]:sreg_64 = COPY [[DEF]].sub4_sub5 +++ ; GCN-NEXT: S_ENDPGM 0, implicit [[COPY]], implicit [[COPY1]], implicit [[COPY2]] ++ %0:sgpr(s192) = G_IMPLICIT_DEF ++ %1:sgpr(s64), %2:sgpr(s64), %3:sgpr(s64) = G_UNMERGE_VALUES %0 ++ S_ENDPGM 0, implicit %1, implicit %2, implicit %3 ++@@ -292,11 +294,11 @@ ++ ; GCN-NEXT: [[CONCAT_VECTORS:%[0-9]+]]:sgpr_384(<12 x s32>) = G_CONCAT_VECTORS [[COPY]](<3 x s32>), [[COPY1]](<3 x s32>), [[COPY2]](<3 x s32>), [[COPY3]](<3 x s32>) ++ ; GCN-NEXT: [[COPY4:%[0-9]+]]:sgpr_96(<3 x s32>) = COPY [[CONCAT_VECTORS]].sub0_sub1_sub2(<12 x s32>) ++ ; GCN-NEXT: [[COPY5:%[0-9]+]]:sgpr_96(<3 x s32>) = COPY [[CONCAT_VECTORS]].sub3_sub4_sub5(<12 x s32>) ++- ; GCN-NEXT: [[COPY4:%[0-9]+]]:sgpr_96(<3 x s32>), [[COPY5:%[0-9]+]]:sgpr_96(<3 x s32>), [[UV:%[0-9]+]]:sgpr_96(<3 x s32>), [[UV1:%[0-9]+]]:sgpr_96(<3 x s32>) = G_UNMERGE_VALUES [[CONCAT_VECTORS]](<12 x s32>) ++- ; GCN-NEXT: $sgpr0_sgpr1_sgpr2 = COPY [[COPY4]](<3 x s32>) ++- ; GCN-NEXT: $sgpr4_sgpr5_sgpr6 = COPY [[COPY5]](<3 x s32>) ++- ; GCN-NEXT: $sgpr8_sgpr9_sgpr10 = COPY [[UV]](<3 x s32>) ++- ; GCN-NEXT: $sgpr12_sgpr13_sgpr14 = COPY [[UV1]](<3 x s32>) +++ ; GCN-NEXT: [[UV:%[0-9]+]]:sgpr_96(<3 x s32>), [[UV1:%[0-9]+]]:sgpr_96(<3 x s32>), [[UV2:%[0-9]+]]:sgpr_96(<3 x s32>), [[UV3:%[0-9]+]]:sgpr_96(<3 x s32>) = G_UNMERGE_VALUES [[CONCAT_VECTORS]](<12 x s32>) +++ ; GCN-NEXT: $sgpr0_sgpr1_sgpr2 = COPY [[UV]](<3 x s32>) +++ ; GCN-NEXT: $sgpr4_sgpr5_sgpr6 = COPY [[UV1]](<3 x s32>) +++ ; GCN-NEXT: $sgpr8_sgpr9_sgpr10 = COPY [[UV2]](<3 x s32>) +++ ; GCN-NEXT: $sgpr12_sgpr13_sgpr14 = COPY [[UV3]](<3 x s32>) ++ %0:sgpr(<3 x s32>) = COPY $sgpr0_sgpr1_sgpr2 ++ %1:sgpr(<3 x s32>) = COPY $sgpr4_sgpr5_sgpr6 ++ %2:sgpr(<3 x s32>) = COPY $sgpr8_sgpr9_sgpr10 ++diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-freeze.mir b/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-freeze.mir ++--- a/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-freeze.mir +++++ b/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-freeze.mir ++@@ -171,8 +171,12 @@ ++ ++ ; CHECK-LABEL: name: test_freeze_s448 ++ ; CHECK: [[COPY:%[0-9]+]]:_(s512) = COPY $vgpr0_vgpr1_vgpr2_vgpr3_vgpr4_vgpr5_vgpr6_vgpr7_vgpr8_vgpr9_vgpr10_vgpr11_vgpr12_vgpr13_vgpr14_vgpr15 ++- ; CHECK-NEXT: [[FREEZE:%[0-9]+]]:_(s512) = G_FREEZE [[COPY]] ++- ; CHECK-NEXT: $vgpr0_vgpr1_vgpr2_vgpr3_vgpr4_vgpr5_vgpr6_vgpr7_vgpr8_vgpr9_vgpr10_vgpr11_vgpr12_vgpr13_vgpr14_vgpr15 = COPY [[FREEZE]](s512) +++ ; CHECK-NEXT: [[TRUNC:%[0-9]+]]:_(s448) = G_TRUNC [[COPY]](s512) +++ ; CHECK-NEXT: [[FREEZE:%[0-9]+]]:_(s448) = G_FREEZE [[TRUNC]] +++ ; CHECK-NEXT: [[UV:%[0-9]+]]:_(s64), [[UV1:%[0-9]+]]:_(s64), [[UV2:%[0-9]+]]:_(s64), [[UV3:%[0-9]+]]:_(s64), [[UV4:%[0-9]+]]:_(s64), [[UV5:%[0-9]+]]:_(s64), [[UV6:%[0-9]+]]:_(s64) = G_UNMERGE_VALUES [[FREEZE]](s448) +++ ; CHECK-NEXT: [[DEF:%[0-9]+]]:_(s64) = G_IMPLICIT_DEF +++ ; CHECK-NEXT: [[MV:%[0-9]+]]:_(s512) = G_MERGE_VALUES [[UV]](s64), [[UV1]](s64), [[UV2]](s64), [[UV3]](s64), [[UV4]](s64), [[UV5]](s64), [[UV6]](s64), [[DEF]](s64) +++ ; CHECK-NEXT: $vgpr0_vgpr1_vgpr2_vgpr3_vgpr4_vgpr5_vgpr6_vgpr7_vgpr8_vgpr9_vgpr10_vgpr11_vgpr12_vgpr13_vgpr14_vgpr15 = COPY [[MV]](s512) ++ %0:_(s512) = COPY $vgpr0_vgpr1_vgpr2_vgpr3_vgpr4_vgpr5_vgpr6_vgpr7_vgpr8_vgpr9_vgpr10_vgpr11_vgpr12_vgpr13_vgpr14_vgpr15 ++ %1:_(s448) = G_TRUNC %0 ++ %2:_(s448) = G_FREEZE %1 ++@@ -395,12 +399,14 @@ ++ bb.0: ++ ++ ; CHECK-LABEL: name: test_freeze_v33s32 ++- ; CHECK: [[DEF:%[0-9]+]]:_(<32 x s32>) = G_IMPLICIT_DEF +++ ; CHECK: [[DEF:%[0-9]+]]:_(<16 x s32>) = G_IMPLICIT_DEF ++ ; CHECK-NEXT: [[DEF1:%[0-9]+]]:_(s32) = G_IMPLICIT_DEF ++- ; CHECK-NEXT: [[FREEZE:%[0-9]+]]:_(<32 x s32>) = G_FREEZE [[DEF]] ++- ; CHECK-NEXT: [[FREEZE1:%[0-9]+]]:_(s32) = G_FREEZE [[DEF1]] ++- ; CHECK-NEXT: [[UV:%[0-9]+]]:_(s32), [[UV1:%[0-9]+]]:_(s32), [[UV2:%[0-9]+]]:_(s32), [[UV3:%[0-9]+]]:_(s32), [[UV4:%[0-9]+]]:_(s32), [[UV5:%[0-9]+]]:_(s32), [[UV6:%[0-9]+]]:_(s32), [[UV7:%[0-9]+]]:_(s32), [[UV8:%[0-9]+]]:_(s32), [[UV9:%[0-9]+]]:_(s32), [[UV10:%[0-9]+]]:_(s32), [[UV11:%[0-9]+]]:_(s32), [[UV12:%[0-9]+]]:_(s32), [[UV13:%[0-9]+]]:_(s32), [[UV14:%[0-9]+]]:_(s32), [[UV15:%[0-9]+]]:_(s32), [[UV16:%[0-9]+]]:_(s32), [[UV17:%[0-9]+]]:_(s32), [[UV18:%[0-9]+]]:_(s32), [[UV19:%[0-9]+]]:_(s32), [[UV20:%[0-9]+]]:_(s32), [[UV21:%[0-9]+]]:_(s32), [[UV22:%[0-9]+]]:_(s32), [[UV23:%[0-9]+]]:_(s32), [[UV24:%[0-9]+]]:_(s32), [[UV25:%[0-9]+]]:_(s32), [[UV26:%[0-9]+]]:_(s32), [[UV27:%[0-9]+]]:_(s32), [[UV28:%[0-9]+]]:_(s32), [[UV29:%[0-9]+]]:_(s32), [[UV30:%[0-9]+]]:_(s32), [[UV31:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[FREEZE]](<32 x s32>) ++- ; CHECK-NEXT: [[BUILD_VECTOR:%[0-9]+]]:_(<33 x s32>) = G_BUILD_VECTOR [[UV]](s32), [[UV1]](s32), [[UV2]](s32), [[UV3]](s32), [[UV4]](s32), [[UV5]](s32), [[UV6]](s32), [[UV7]](s32), [[UV8]](s32), [[UV9]](s32), [[UV10]](s32), [[UV11]](s32), [[UV12]](s32), [[UV13]](s32), [[UV14]](s32), [[UV15]](s32), [[UV16]](s32), [[UV17]](s32), [[UV18]](s32), [[UV19]](s32), [[UV20]](s32), [[UV21]](s32), [[UV22]](s32), [[UV23]](s32), [[UV24]](s32), [[UV25]](s32), [[UV26]](s32), [[UV27]](s32), [[UV28]](s32), [[UV29]](s32), [[UV30]](s32), [[UV31]](s32), [[FREEZE1]](s32) +++ ; CHECK-NEXT: [[FREEZE:%[0-9]+]]:_(<16 x s32>) = G_FREEZE [[DEF]] +++ ; CHECK-NEXT: [[FREEZE1:%[0-9]+]]:_(<16 x s32>) = G_FREEZE [[DEF]] +++ ; CHECK-NEXT: [[FREEZE2:%[0-9]+]]:_(s32) = G_FREEZE [[DEF1]] +++ ; CHECK-NEXT: [[UV:%[0-9]+]]:_(s32), [[UV1:%[0-9]+]]:_(s32), [[UV2:%[0-9]+]]:_(s32), [[UV3:%[0-9]+]]:_(s32), [[UV4:%[0-9]+]]:_(s32), [[UV5:%[0-9]+]]:_(s32), [[UV6:%[0-9]+]]:_(s32), [[UV7:%[0-9]+]]:_(s32), [[UV8:%[0-9]+]]:_(s32), [[UV9:%[0-9]+]]:_(s32), [[UV10:%[0-9]+]]:_(s32), [[UV11:%[0-9]+]]:_(s32), [[UV12:%[0-9]+]]:_(s32), [[UV13:%[0-9]+]]:_(s32), [[UV14:%[0-9]+]]:_(s32), [[UV15:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[FREEZE]](<16 x s32>) +++ ; CHECK-NEXT: [[UV16:%[0-9]+]]:_(s32), [[UV17:%[0-9]+]]:_(s32), [[UV18:%[0-9]+]]:_(s32), [[UV19:%[0-9]+]]:_(s32), [[UV20:%[0-9]+]]:_(s32), [[UV21:%[0-9]+]]:_(s32), [[UV22:%[0-9]+]]:_(s32), [[UV23:%[0-9]+]]:_(s32), [[UV24:%[0-9]+]]:_(s32), [[UV25:%[0-9]+]]:_(s32), [[UV26:%[0-9]+]]:_(s32), [[UV27:%[0-9]+]]:_(s32), [[UV28:%[0-9]+]]:_(s32), [[UV29:%[0-9]+]]:_(s32), [[UV30:%[0-9]+]]:_(s32), [[UV31:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[FREEZE1]](<16 x s32>) +++ ; CHECK-NEXT: [[BUILD_VECTOR:%[0-9]+]]:_(<33 x s32>) = G_BUILD_VECTOR [[UV]](s32), [[UV1]](s32), [[UV2]](s32), [[UV3]](s32), [[UV4]](s32), [[UV5]](s32), [[UV6]](s32), [[UV7]](s32), [[UV8]](s32), [[UV9]](s32), [[UV10]](s32), [[UV11]](s32), [[UV12]](s32), [[UV13]](s32), [[UV14]](s32), [[UV15]](s32), [[UV16]](s32), [[UV17]](s32), [[UV18]](s32), [[UV19]](s32), [[UV20]](s32), [[UV21]](s32), [[UV22]](s32), [[UV23]](s32), [[UV24]](s32), [[UV25]](s32), [[UV26]](s32), [[UV27]](s32), [[UV28]](s32), [[UV29]](s32), [[UV30]](s32), [[UV31]](s32), [[FREEZE2]](s32) ++ ; CHECK-NEXT: S_NOP 0, implicit [[BUILD_VECTOR]](<33 x s32>) ++ %0:_(<33 x s32>) = G_IMPLICIT_DEF ++ %1:_(<33 x s32>) = G_FREEZE %0 ++@@ -413,10 +419,12 @@ ++ bb.0: ++ ++ ; CHECK-LABEL: name: test_freeze_v64s32 ++- ; CHECK: [[DEF:%[0-9]+]]:_(<32 x s32>) = G_IMPLICIT_DEF ++- ; CHECK-NEXT: [[FREEZE:%[0-9]+]]:_(<32 x s32>) = G_FREEZE [[DEF]] ++- ; CHECK-NEXT: [[FREEZE1:%[0-9]+]]:_(<32 x s32>) = G_FREEZE [[DEF]] ++- ; CHECK-NEXT: [[CONCAT_VECTORS:%[0-9]+]]:_(<64 x s32>) = G_CONCAT_VECTORS [[FREEZE]](<32 x s32>), [[FREEZE1]](<32 x s32>) +++ ; CHECK: [[DEF:%[0-9]+]]:_(<16 x s32>) = G_IMPLICIT_DEF +++ ; CHECK-NEXT: [[FREEZE:%[0-9]+]]:_(<16 x s32>) = G_FREEZE [[DEF]] +++ ; CHECK-NEXT: [[FREEZE1:%[0-9]+]]:_(<16 x s32>) = G_FREEZE [[DEF]] +++ ; CHECK-NEXT: [[FREEZE2:%[0-9]+]]:_(<16 x s32>) = G_FREEZE [[DEF]] +++ ; CHECK-NEXT: [[FREEZE3:%[0-9]+]]:_(<16 x s32>) = G_FREEZE [[DEF]] +++ ; CHECK-NEXT: [[CONCAT_VECTORS:%[0-9]+]]:_(<64 x s32>) = G_CONCAT_VECTORS [[FREEZE]](<16 x s32>), [[FREEZE1]](<16 x s32>), [[FREEZE2]](<16 x s32>), [[FREEZE3]](<16 x s32>) ++ ; CHECK-NEXT: S_NOP 0, implicit [[CONCAT_VECTORS]](<64 x s32>) ++ %0:_(<64 x s32>) = G_IMPLICIT_DEF ++ %1:_(<64 x s32>) = G_FREEZE %0 ++diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-implicit-def.mir b/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-implicit-def.mir ++--- a/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-implicit-def.mir +++++ b/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-implicit-def.mir ++@@ -135,9 +135,8 @@ ++ bb.0: ++ ++ ; CHECK-LABEL: name: test_implicit_def_s448 ++- ; CHECK: [[DEF:%[0-9]+]]:_(s512) = G_IMPLICIT_DEF ++- ; CHECK-NEXT: [[TRUNC:%[0-9]+]]:_(s448) = G_TRUNC [[DEF]](s512) ++- ; CHECK-NEXT: [[EXTRACT:%[0-9]+]]:_(s32) = G_EXTRACT [[TRUNC]](s448), 0 +++ ; CHECK: [[DEF:%[0-9]+]]:_(s448) = G_IMPLICIT_DEF +++ ; CHECK-NEXT: [[EXTRACT:%[0-9]+]]:_(s32) = G_EXTRACT [[DEF]](s448), 0 ++ ; CHECK-NEXT: $vgpr0 = COPY [[EXTRACT]](s32) ++ %0:_(s448) = G_IMPLICIT_DEF ++ %1:_(s32) = G_EXTRACT %0, 0 ++@@ -297,6 +296,18 @@ ++ ... ++ ++ --- +++name: test_implicit_def_v17s32 +++body: | +++ bb.0: +++ +++ ; CHECK-LABEL: name: test_implicit_def_v17s32 +++ ; CHECK: [[DEF:%[0-9]+]]:_(<17 x s32>) = G_IMPLICIT_DEF +++ ; CHECK-NEXT: S_NOP 0, implicit [[DEF]](<17 x s32>) +++ %0:_(<17 x s32>) = G_IMPLICIT_DEF +++ S_NOP 0, implicit %0 +++... +++ +++--- ++ name: test_implicit_def_v32s32 ++ body: | ++ bb.0: ++@@ -317,9 +328,9 @@ ++ ; CHECK-LABEL: name: test_implicit_def_v33s32 ++ ; CHECK: liveins: $vgpr0_vgpr1 ++ ; CHECK-NEXT: {{ $}} ++- ; CHECK-NEXT: [[DEF:%[0-9]+]]:_(<32 x s32>) = G_IMPLICIT_DEF +++ ; CHECK-NEXT: [[DEF:%[0-9]+]]:_(<16 x s32>) = G_IMPLICIT_DEF ++ ; CHECK-NEXT: [[DEF1:%[0-9]+]]:_(s32) = G_IMPLICIT_DEF ++- ; CHECK-NEXT: [[UV:%[0-9]+]]:_(s32), [[UV1:%[0-9]+]]:_(s32), [[UV2:%[0-9]+]]:_(s32), [[UV3:%[0-9]+]]:_(s32), [[UV4:%[0-9]+]]:_(s32), [[UV5:%[0-9]+]]:_(s32), [[UV6:%[0-9]+]]:_(s32), [[UV7:%[0-9]+]]:_(s32), [[UV8:%[0-9]+]]:_(s32), [[UV9:%[0-9]+]]:_(s32), [[UV10:%[0-9]+]]:_(s32), [[UV11:%[0-9]+]]:_(s32), [[UV12:%[0-9]+]]:_(s32), [[UV13:%[0-9]+]]:_(s32), [[UV14:%[0-9]+]]:_(s32), [[UV15:%[0-9]+]]:_(s32), [[UV16:%[0-9]+]]:_(s32), [[UV17:%[0-9]+]]:_(s32), [[UV18:%[0-9]+]]:_(s32), [[UV19:%[0-9]+]]:_(s32), [[UV20:%[0-9]+]]:_(s32), [[UV21:%[0-9]+]]:_(s32), [[UV22:%[0-9]+]]:_(s32), [[UV23:%[0-9]+]]:_(s32), [[UV24:%[0-9]+]]:_(s32), [[UV25:%[0-9]+]]:_(s32), [[UV26:%[0-9]+]]:_(s32), [[UV27:%[0-9]+]]:_(s32), [[UV28:%[0-9]+]]:_(s32), [[UV29:%[0-9]+]]:_(s32), [[UV30:%[0-9]+]]:_(s32), [[UV31:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<32 x s32>) +++ ; CHECK-NEXT: [[UV:%[0-9]+]]:_(s32), [[UV1:%[0-9]+]]:_(s32), [[UV2:%[0-9]+]]:_(s32), [[UV3:%[0-9]+]]:_(s32), [[UV4:%[0-9]+]]:_(s32), [[UV5:%[0-9]+]]:_(s32), [[UV6:%[0-9]+]]:_(s32), [[UV7:%[0-9]+]]:_(s32), [[UV8:%[0-9]+]]:_(s32), [[UV9:%[0-9]+]]:_(s32), [[UV10:%[0-9]+]]:_(s32), [[UV11:%[0-9]+]]:_(s32), [[UV12:%[0-9]+]]:_(s32), [[UV13:%[0-9]+]]:_(s32), [[UV14:%[0-9]+]]:_(s32), [[UV15:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) ++ ; CHECK-NEXT: [[COPY:%[0-9]+]]:_(p1) = COPY $vgpr0_vgpr1 ++ ; CHECK-NEXT: G_STORE [[UV]](s32), [[COPY]](p1) :: (volatile store (s32), addrspace 1) ++ ; CHECK-NEXT: G_STORE [[DEF1]](s32), [[COPY]](p1) :: (volatile store (s32), addrspace 1) ++@@ -337,9 +348,10 @@ ++ bb.0: ++ ++ ; CHECK-LABEL: name: test_implicit_def_v64s32 ++- ; CHECK: [[DEF:%[0-9]+]]:_(<32 x s32>) = G_IMPLICIT_DEF ++- ; CHECK-NEXT: [[CONCAT_VECTORS:%[0-9]+]]:_(<64 x s32>) = G_CONCAT_VECTORS [[DEF]](<32 x s32>), [[DEF]](<32 x s32>) ++- ; CHECK-NEXT: S_NOP 0, implicit [[CONCAT_VECTORS]](<64 x s32>), implicit [[DEF]](<32 x s32>) +++ ; CHECK: [[DEF:%[0-9]+]]:_(<16 x s32>) = G_IMPLICIT_DEF +++ ; CHECK-NEXT: [[CONCAT_VECTORS:%[0-9]+]]:_(<64 x s32>) = G_CONCAT_VECTORS [[DEF]](<16 x s32>), [[DEF]](<16 x s32>), [[DEF]](<16 x s32>), [[DEF]](<16 x s32>) +++ ; CHECK-NEXT: [[CONCAT_VECTORS1:%[0-9]+]]:_(<32 x s32>) = G_CONCAT_VECTORS [[DEF]](<16 x s32>), [[DEF]](<16 x s32>) +++ ; CHECK-NEXT: S_NOP 0, implicit [[CONCAT_VECTORS]](<64 x s32>), implicit [[CONCAT_VECTORS1]](<32 x s32>) ++ %0:_(<64 x s32>) = G_IMPLICIT_DEF ++ %1:_(<32 x s32>), %2:_(<32 x s32>) = G_UNMERGE_VALUES %0 ++ S_NOP 0, implicit %0, implicit %1 ++diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-insert-vector-elt.mir b/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-insert-vector-elt.mir ++--- a/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-insert-vector-elt.mir +++++ b/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-insert-vector-elt.mir ++@@ -190,11 +190,13 @@ ++ ; CHECK-LABEL: name: insert_vector_elt_64_65_v64s32 ++ ; CHECK: liveins: $sgpr0_sgpr1, $vgpr0_vgpr1, $vgpr2_vgpr3 ++ ; CHECK-NEXT: {{ $}} ++- ; CHECK-NEXT: [[DEF:%[0-9]+]]:_(<32 x s32>) = G_IMPLICIT_DEF +++ ; CHECK-NEXT: [[DEF:%[0-9]+]]:_(<16 x s32>) = G_IMPLICIT_DEF ++ ; CHECK-NEXT: [[COPY:%[0-9]+]]:_(p1) = COPY $vgpr0_vgpr1 ++ ; CHECK-NEXT: [[COPY1:%[0-9]+]]:_(p1) = COPY $vgpr2_vgpr3 ++- ; CHECK-NEXT: [[UV:%[0-9]+]]:_(<4 x s32>), [[UV1:%[0-9]+]]:_(<4 x s32>), [[UV2:%[0-9]+]]:_(<4 x s32>), [[UV3:%[0-9]+]]:_(<4 x s32>), [[UV4:%[0-9]+]]:_(<4 x s32>), [[UV5:%[0-9]+]]:_(<4 x s32>), [[UV6:%[0-9]+]]:_(<4 x s32>), [[UV7:%[0-9]+]]:_(<4 x s32>) = G_UNMERGE_VALUES [[DEF]](<32 x s32>) ++- ; CHECK-NEXT: [[UV8:%[0-9]+]]:_(<4 x s32>), [[UV9:%[0-9]+]]:_(<4 x s32>), [[UV10:%[0-9]+]]:_(<4 x s32>), [[UV11:%[0-9]+]]:_(<4 x s32>), [[UV12:%[0-9]+]]:_(<4 x s32>), [[UV13:%[0-9]+]]:_(<4 x s32>), [[UV14:%[0-9]+]]:_(<4 x s32>), [[UV15:%[0-9]+]]:_(<4 x s32>) = G_UNMERGE_VALUES [[DEF]](<32 x s32>) +++ ; CHECK-NEXT: [[UV:%[0-9]+]]:_(<4 x s32>), [[UV1:%[0-9]+]]:_(<4 x s32>), [[UV2:%[0-9]+]]:_(<4 x s32>), [[UV3:%[0-9]+]]:_(<4 x s32>) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) +++ ; CHECK-NEXT: [[UV4:%[0-9]+]]:_(<4 x s32>), [[UV5:%[0-9]+]]:_(<4 x s32>), [[UV6:%[0-9]+]]:_(<4 x s32>), [[UV7:%[0-9]+]]:_(<4 x s32>) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) +++ ; CHECK-NEXT: [[UV8:%[0-9]+]]:_(<4 x s32>), [[UV9:%[0-9]+]]:_(<4 x s32>), [[UV10:%[0-9]+]]:_(<4 x s32>), [[UV11:%[0-9]+]]:_(<4 x s32>) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) +++ ; CHECK-NEXT: [[UV12:%[0-9]+]]:_(<4 x s32>), [[UV13:%[0-9]+]]:_(<4 x s32>), [[UV14:%[0-9]+]]:_(<4 x s32>), [[UV15:%[0-9]+]]:_(<4 x s32>) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) ++ ; CHECK-NEXT: G_STORE [[UV]](<4 x s32>), [[COPY]](p1) :: (store (<4 x s32>), align 4, addrspace 1) ++ ; CHECK-NEXT: [[C:%[0-9]+]]:_(s64) = G_CONSTANT i64 16 ++ ; CHECK-NEXT: [[PTR_ADD:%[0-9]+]]:_(p1) = G_PTR_ADD [[COPY]], [[C]](s64) ++@@ -241,8 +243,10 @@ ++ ; CHECK-NEXT: [[C14:%[0-9]+]]:_(s64) = G_CONSTANT i64 240 ++ ; CHECK-NEXT: [[PTR_ADD14:%[0-9]+]]:_(p1) = G_PTR_ADD [[COPY]], [[C14]](s64) ++ ; CHECK-NEXT: G_STORE [[UV15]](<4 x s32>), [[PTR_ADD14]](p1) :: (store (<4 x s32>) into unknown-address + 240, align 4, addrspace 1) ++- ; CHECK-NEXT: [[UV16:%[0-9]+]]:_(<4 x s32>), [[UV17:%[0-9]+]]:_(<4 x s32>), [[UV18:%[0-9]+]]:_(<4 x s32>), [[UV19:%[0-9]+]]:_(<4 x s32>), [[UV20:%[0-9]+]]:_(<4 x s32>), [[UV21:%[0-9]+]]:_(<4 x s32>), [[UV22:%[0-9]+]]:_(<4 x s32>), [[UV23:%[0-9]+]]:_(<4 x s32>) = G_UNMERGE_VALUES [[DEF]](<32 x s32>) ++- ; CHECK-NEXT: [[UV24:%[0-9]+]]:_(<4 x s32>), [[UV25:%[0-9]+]]:_(<4 x s32>), [[UV26:%[0-9]+]]:_(<4 x s32>), [[UV27:%[0-9]+]]:_(<4 x s32>), [[UV28:%[0-9]+]]:_(<4 x s32>), [[UV29:%[0-9]+]]:_(<4 x s32>), [[UV30:%[0-9]+]]:_(<4 x s32>), [[UV31:%[0-9]+]]:_(<4 x s32>) = G_UNMERGE_VALUES [[DEF]](<32 x s32>) +++ ; CHECK-NEXT: [[UV16:%[0-9]+]]:_(<4 x s32>), [[UV17:%[0-9]+]]:_(<4 x s32>), [[UV18:%[0-9]+]]:_(<4 x s32>), [[UV19:%[0-9]+]]:_(<4 x s32>) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) +++ ; CHECK-NEXT: [[UV20:%[0-9]+]]:_(<4 x s32>), [[UV21:%[0-9]+]]:_(<4 x s32>), [[UV22:%[0-9]+]]:_(<4 x s32>), [[UV23:%[0-9]+]]:_(<4 x s32>) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) +++ ; CHECK-NEXT: [[UV24:%[0-9]+]]:_(<4 x s32>), [[UV25:%[0-9]+]]:_(<4 x s32>), [[UV26:%[0-9]+]]:_(<4 x s32>), [[UV27:%[0-9]+]]:_(<4 x s32>) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) +++ ; CHECK-NEXT: [[UV28:%[0-9]+]]:_(<4 x s32>), [[UV29:%[0-9]+]]:_(<4 x s32>), [[UV30:%[0-9]+]]:_(<4 x s32>), [[UV31:%[0-9]+]]:_(<4 x s32>) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) ++ ; CHECK-NEXT: G_STORE [[UV16]](<4 x s32>), [[COPY1]](p1) :: (store (<4 x s32>), align 4, addrspace 1) ++ ; CHECK-NEXT: [[PTR_ADD15:%[0-9]+]]:_(p1) = G_PTR_ADD [[COPY1]], [[C]](s64) ++ ; CHECK-NEXT: G_STORE [[UV17]](<4 x s32>), [[PTR_ADD15]](p1) :: (store (<4 x s32>) into unknown-address + 16, align 4, addrspace 1) ++diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-phi.mir b/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-phi.mir ++--- a/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-phi.mir +++++ b/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-phi.mir ++@@ -673,86 +673,88 @@ ++ ; CHECK-NEXT: successors: %bb.1(0x40000000), %bb.2(0x40000000) ++ ; CHECK-NEXT: liveins: $vgpr0_vgpr1_vgpr2_vgpr3, $vgpr4 ++ ; CHECK-NEXT: {{ $}} ++- ; CHECK-NEXT: [[DEF:%[0-9]+]]:_(<32 x s32>) = G_IMPLICIT_DEF +++ ; CHECK-NEXT: [[DEF:%[0-9]+]]:_(<16 x s32>) = G_IMPLICIT_DEF ++ ; CHECK-NEXT: [[COPY:%[0-9]+]]:_(s32) = COPY $vgpr4 ++ ; CHECK-NEXT: [[C:%[0-9]+]]:_(s32) = G_CONSTANT i32 0 ++ ; CHECK-NEXT: [[ICMP:%[0-9]+]]:_(s1) = G_ICMP intpred(eq), [[COPY]](s32), [[C]] ++- ; CHECK-NEXT: [[UV:%[0-9]+]]:_(<16 x s32>), [[UV1:%[0-9]+]]:_(<16 x s32>) = G_UNMERGE_VALUES [[DEF]](<32 x s32>) ++- ; CHECK-NEXT: [[UV2:%[0-9]+]]:_(<16 x s32>), [[UV3:%[0-9]+]]:_(<16 x s32>) = G_UNMERGE_VALUES [[DEF]](<32 x s32>) ++ ; CHECK-NEXT: G_BRCOND [[ICMP]](s1), %bb.1 ++ ; CHECK-NEXT: G_BR %bb.2 ++ ; CHECK-NEXT: {{ $}} ++ ; CHECK-NEXT: bb.1: ++ ; CHECK-NEXT: successors: %bb.2(0x80000000) ++ ; CHECK-NEXT: {{ $}} ++- ; CHECK-NEXT: [[UV4:%[0-9]+]]:_(s32), [[UV5:%[0-9]+]]:_(s32), [[UV6:%[0-9]+]]:_(s32), [[UV7:%[0-9]+]]:_(s32), [[UV8:%[0-9]+]]:_(s32), [[UV9:%[0-9]+]]:_(s32), [[UV10:%[0-9]+]]:_(s32), [[UV11:%[0-9]+]]:_(s32), [[UV12:%[0-9]+]]:_(s32), [[UV13:%[0-9]+]]:_(s32), [[UV14:%[0-9]+]]:_(s32), [[UV15:%[0-9]+]]:_(s32), [[UV16:%[0-9]+]]:_(s32), [[UV17:%[0-9]+]]:_(s32), [[UV18:%[0-9]+]]:_(s32), [[UV19:%[0-9]+]]:_(s32), [[UV20:%[0-9]+]]:_(s32), [[UV21:%[0-9]+]]:_(s32), [[UV22:%[0-9]+]]:_(s32), [[UV23:%[0-9]+]]:_(s32), [[UV24:%[0-9]+]]:_(s32), [[UV25:%[0-9]+]]:_(s32), [[UV26:%[0-9]+]]:_(s32), [[UV27:%[0-9]+]]:_(s32), [[UV28:%[0-9]+]]:_(s32), [[UV29:%[0-9]+]]:_(s32), [[UV30:%[0-9]+]]:_(s32), [[UV31:%[0-9]+]]:_(s32), [[UV32:%[0-9]+]]:_(s32), [[UV33:%[0-9]+]]:_(s32), [[UV34:%[0-9]+]]:_(s32), [[UV35:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<32 x s32>) ++- ; CHECK-NEXT: [[UV36:%[0-9]+]]:_(s32), [[UV37:%[0-9]+]]:_(s32), [[UV38:%[0-9]+]]:_(s32), [[UV39:%[0-9]+]]:_(s32), [[UV40:%[0-9]+]]:_(s32), [[UV41:%[0-9]+]]:_(s32), [[UV42:%[0-9]+]]:_(s32), [[UV43:%[0-9]+]]:_(s32), [[UV44:%[0-9]+]]:_(s32), [[UV45:%[0-9]+]]:_(s32), [[UV46:%[0-9]+]]:_(s32), [[UV47:%[0-9]+]]:_(s32), [[UV48:%[0-9]+]]:_(s32), [[UV49:%[0-9]+]]:_(s32), [[UV50:%[0-9]+]]:_(s32), [[UV51:%[0-9]+]]:_(s32), [[UV52:%[0-9]+]]:_(s32), [[UV53:%[0-9]+]]:_(s32), [[UV54:%[0-9]+]]:_(s32), [[UV55:%[0-9]+]]:_(s32), [[UV56:%[0-9]+]]:_(s32), [[UV57:%[0-9]+]]:_(s32), [[UV58:%[0-9]+]]:_(s32), [[UV59:%[0-9]+]]:_(s32), [[UV60:%[0-9]+]]:_(s32), [[UV61:%[0-9]+]]:_(s32), [[UV62:%[0-9]+]]:_(s32), [[UV63:%[0-9]+]]:_(s32), [[UV64:%[0-9]+]]:_(s32), [[UV65:%[0-9]+]]:_(s32), [[UV66:%[0-9]+]]:_(s32), [[UV67:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<32 x s32>) ++- ; CHECK-NEXT: [[UV68:%[0-9]+]]:_(s32), [[UV69:%[0-9]+]]:_(s32), [[UV70:%[0-9]+]]:_(s32), [[UV71:%[0-9]+]]:_(s32), [[UV72:%[0-9]+]]:_(s32), [[UV73:%[0-9]+]]:_(s32), [[UV74:%[0-9]+]]:_(s32), [[UV75:%[0-9]+]]:_(s32), [[UV76:%[0-9]+]]:_(s32), [[UV77:%[0-9]+]]:_(s32), [[UV78:%[0-9]+]]:_(s32), [[UV79:%[0-9]+]]:_(s32), [[UV80:%[0-9]+]]:_(s32), [[UV81:%[0-9]+]]:_(s32), [[UV82:%[0-9]+]]:_(s32), [[UV83:%[0-9]+]]:_(s32), [[UV84:%[0-9]+]]:_(s32), [[UV85:%[0-9]+]]:_(s32), [[UV86:%[0-9]+]]:_(s32), [[UV87:%[0-9]+]]:_(s32), [[UV88:%[0-9]+]]:_(s32), [[UV89:%[0-9]+]]:_(s32), [[UV90:%[0-9]+]]:_(s32), [[UV91:%[0-9]+]]:_(s32), [[UV92:%[0-9]+]]:_(s32), [[UV93:%[0-9]+]]:_(s32), [[UV94:%[0-9]+]]:_(s32), [[UV95:%[0-9]+]]:_(s32), [[UV96:%[0-9]+]]:_(s32), [[UV97:%[0-9]+]]:_(s32), [[UV98:%[0-9]+]]:_(s32), [[UV99:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<32 x s32>) ++- ; CHECK-NEXT: [[UV100:%[0-9]+]]:_(s32), [[UV101:%[0-9]+]]:_(s32), [[UV102:%[0-9]+]]:_(s32), [[UV103:%[0-9]+]]:_(s32), [[UV104:%[0-9]+]]:_(s32), [[UV105:%[0-9]+]]:_(s32), [[UV106:%[0-9]+]]:_(s32), [[UV107:%[0-9]+]]:_(s32), [[UV108:%[0-9]+]]:_(s32), [[UV109:%[0-9]+]]:_(s32), [[UV110:%[0-9]+]]:_(s32), [[UV111:%[0-9]+]]:_(s32), [[UV112:%[0-9]+]]:_(s32), [[UV113:%[0-9]+]]:_(s32), [[UV114:%[0-9]+]]:_(s32), [[UV115:%[0-9]+]]:_(s32), [[UV116:%[0-9]+]]:_(s32), [[UV117:%[0-9]+]]:_(s32), [[UV118:%[0-9]+]]:_(s32), [[UV119:%[0-9]+]]:_(s32), [[UV120:%[0-9]+]]:_(s32), [[UV121:%[0-9]+]]:_(s32), [[UV122:%[0-9]+]]:_(s32), [[UV123:%[0-9]+]]:_(s32), [[UV124:%[0-9]+]]:_(s32), [[UV125:%[0-9]+]]:_(s32), [[UV126:%[0-9]+]]:_(s32), [[UV127:%[0-9]+]]:_(s32), [[UV128:%[0-9]+]]:_(s32), [[UV129:%[0-9]+]]:_(s32), [[UV130:%[0-9]+]]:_(s32), [[UV131:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<32 x s32>) ++- ; CHECK-NEXT: [[ADD:%[0-9]+]]:_(s32) = G_ADD [[UV4]], [[UV68]] ++- ; CHECK-NEXT: [[ADD1:%[0-9]+]]:_(s32) = G_ADD [[UV5]], [[UV69]] ++- ; CHECK-NEXT: [[ADD2:%[0-9]+]]:_(s32) = G_ADD [[UV6]], [[UV70]] ++- ; CHECK-NEXT: [[ADD3:%[0-9]+]]:_(s32) = G_ADD [[UV7]], [[UV71]] ++- ; CHECK-NEXT: [[ADD4:%[0-9]+]]:_(s32) = G_ADD [[UV8]], [[UV72]] ++- ; CHECK-NEXT: [[ADD5:%[0-9]+]]:_(s32) = G_ADD [[UV9]], [[UV73]] ++- ; CHECK-NEXT: [[ADD6:%[0-9]+]]:_(s32) = G_ADD [[UV10]], [[UV74]] ++- ; CHECK-NEXT: [[ADD7:%[0-9]+]]:_(s32) = G_ADD [[UV11]], [[UV75]] ++- ; CHECK-NEXT: [[ADD8:%[0-9]+]]:_(s32) = G_ADD [[UV12]], [[UV76]] ++- ; CHECK-NEXT: [[ADD9:%[0-9]+]]:_(s32) = G_ADD [[UV13]], [[UV77]] ++- ; CHECK-NEXT: [[ADD10:%[0-9]+]]:_(s32) = G_ADD [[UV14]], [[UV78]] ++- ; CHECK-NEXT: [[ADD11:%[0-9]+]]:_(s32) = G_ADD [[UV15]], [[UV79]] ++- ; CHECK-NEXT: [[ADD12:%[0-9]+]]:_(s32) = G_ADD [[UV16]], [[UV80]] ++- ; CHECK-NEXT: [[ADD13:%[0-9]+]]:_(s32) = G_ADD [[UV17]], [[UV81]] ++- ; CHECK-NEXT: [[ADD14:%[0-9]+]]:_(s32) = G_ADD [[UV18]], [[UV82]] ++- ; CHECK-NEXT: [[ADD15:%[0-9]+]]:_(s32) = G_ADD [[UV19]], [[UV83]] ++- ; CHECK-NEXT: [[ADD16:%[0-9]+]]:_(s32) = G_ADD [[UV20]], [[UV84]] ++- ; CHECK-NEXT: [[ADD17:%[0-9]+]]:_(s32) = G_ADD [[UV21]], [[UV85]] ++- ; CHECK-NEXT: [[ADD18:%[0-9]+]]:_(s32) = G_ADD [[UV22]], [[UV86]] ++- ; CHECK-NEXT: [[ADD19:%[0-9]+]]:_(s32) = G_ADD [[UV23]], [[UV87]] ++- ; CHECK-NEXT: [[ADD20:%[0-9]+]]:_(s32) = G_ADD [[UV24]], [[UV88]] ++- ; CHECK-NEXT: [[ADD21:%[0-9]+]]:_(s32) = G_ADD [[UV25]], [[UV89]] ++- ; CHECK-NEXT: [[ADD22:%[0-9]+]]:_(s32) = G_ADD [[UV26]], [[UV90]] ++- ; CHECK-NEXT: [[ADD23:%[0-9]+]]:_(s32) = G_ADD [[UV27]], [[UV91]] ++- ; CHECK-NEXT: [[ADD24:%[0-9]+]]:_(s32) = G_ADD [[UV28]], [[UV92]] ++- ; CHECK-NEXT: [[ADD25:%[0-9]+]]:_(s32) = G_ADD [[UV29]], [[UV93]] ++- ; CHECK-NEXT: [[ADD26:%[0-9]+]]:_(s32) = G_ADD [[UV30]], [[UV94]] ++- ; CHECK-NEXT: [[ADD27:%[0-9]+]]:_(s32) = G_ADD [[UV31]], [[UV95]] ++- ; CHECK-NEXT: [[ADD28:%[0-9]+]]:_(s32) = G_ADD [[UV32]], [[UV96]] ++- ; CHECK-NEXT: [[ADD29:%[0-9]+]]:_(s32) = G_ADD [[UV33]], [[UV97]] ++- ; CHECK-NEXT: [[ADD30:%[0-9]+]]:_(s32) = G_ADD [[UV34]], [[UV98]] ++- ; CHECK-NEXT: [[ADD31:%[0-9]+]]:_(s32) = G_ADD [[UV35]], [[UV99]] ++- ; CHECK-NEXT: [[ADD32:%[0-9]+]]:_(s32) = G_ADD [[UV36]], [[UV100]] ++- ; CHECK-NEXT: [[ADD33:%[0-9]+]]:_(s32) = G_ADD [[UV37]], [[UV101]] ++- ; CHECK-NEXT: [[ADD34:%[0-9]+]]:_(s32) = G_ADD [[UV38]], [[UV102]] ++- ; CHECK-NEXT: [[ADD35:%[0-9]+]]:_(s32) = G_ADD [[UV39]], [[UV103]] ++- ; CHECK-NEXT: [[ADD36:%[0-9]+]]:_(s32) = G_ADD [[UV40]], [[UV104]] ++- ; CHECK-NEXT: [[ADD37:%[0-9]+]]:_(s32) = G_ADD [[UV41]], [[UV105]] ++- ; CHECK-NEXT: [[ADD38:%[0-9]+]]:_(s32) = G_ADD [[UV42]], [[UV106]] ++- ; CHECK-NEXT: [[ADD39:%[0-9]+]]:_(s32) = G_ADD [[UV43]], [[UV107]] ++- ; CHECK-NEXT: [[ADD40:%[0-9]+]]:_(s32) = G_ADD [[UV44]], [[UV108]] ++- ; CHECK-NEXT: [[ADD41:%[0-9]+]]:_(s32) = G_ADD [[UV45]], [[UV109]] ++- ; CHECK-NEXT: [[ADD42:%[0-9]+]]:_(s32) = G_ADD [[UV46]], [[UV110]] ++- ; CHECK-NEXT: [[ADD43:%[0-9]+]]:_(s32) = G_ADD [[UV47]], [[UV111]] ++- ; CHECK-NEXT: [[ADD44:%[0-9]+]]:_(s32) = G_ADD [[UV48]], [[UV112]] ++- ; CHECK-NEXT: [[ADD45:%[0-9]+]]:_(s32) = G_ADD [[UV49]], [[UV113]] ++- ; CHECK-NEXT: [[ADD46:%[0-9]+]]:_(s32) = G_ADD [[UV50]], [[UV114]] ++- ; CHECK-NEXT: [[ADD47:%[0-9]+]]:_(s32) = G_ADD [[UV51]], [[UV115]] ++- ; CHECK-NEXT: [[ADD48:%[0-9]+]]:_(s32) = G_ADD [[UV52]], [[UV116]] ++- ; CHECK-NEXT: [[ADD49:%[0-9]+]]:_(s32) = G_ADD [[UV53]], [[UV117]] ++- ; CHECK-NEXT: [[ADD50:%[0-9]+]]:_(s32) = G_ADD [[UV54]], [[UV118]] ++- ; CHECK-NEXT: [[ADD51:%[0-9]+]]:_(s32) = G_ADD [[UV55]], [[UV119]] ++- ; CHECK-NEXT: [[ADD52:%[0-9]+]]:_(s32) = G_ADD [[UV56]], [[UV120]] ++- ; CHECK-NEXT: [[ADD53:%[0-9]+]]:_(s32) = G_ADD [[UV57]], [[UV121]] ++- ; CHECK-NEXT: [[ADD54:%[0-9]+]]:_(s32) = G_ADD [[UV58]], [[UV122]] ++- ; CHECK-NEXT: [[ADD55:%[0-9]+]]:_(s32) = G_ADD [[UV59]], [[UV123]] ++- ; CHECK-NEXT: [[ADD56:%[0-9]+]]:_(s32) = G_ADD [[UV60]], [[UV124]] ++- ; CHECK-NEXT: [[ADD57:%[0-9]+]]:_(s32) = G_ADD [[UV61]], [[UV125]] ++- ; CHECK-NEXT: [[ADD58:%[0-9]+]]:_(s32) = G_ADD [[UV62]], [[UV126]] ++- ; CHECK-NEXT: [[ADD59:%[0-9]+]]:_(s32) = G_ADD [[UV63]], [[UV127]] ++- ; CHECK-NEXT: [[ADD60:%[0-9]+]]:_(s32) = G_ADD [[UV64]], [[UV128]] ++- ; CHECK-NEXT: [[ADD61:%[0-9]+]]:_(s32) = G_ADD [[UV65]], [[UV129]] ++- ; CHECK-NEXT: [[ADD62:%[0-9]+]]:_(s32) = G_ADD [[UV66]], [[UV130]] ++- ; CHECK-NEXT: [[ADD63:%[0-9]+]]:_(s32) = G_ADD [[UV67]], [[UV131]] +++ ; CHECK-NEXT: [[UV:%[0-9]+]]:_(s32), [[UV1:%[0-9]+]]:_(s32), [[UV2:%[0-9]+]]:_(s32), [[UV3:%[0-9]+]]:_(s32), [[UV4:%[0-9]+]]:_(s32), [[UV5:%[0-9]+]]:_(s32), [[UV6:%[0-9]+]]:_(s32), [[UV7:%[0-9]+]]:_(s32), [[UV8:%[0-9]+]]:_(s32), [[UV9:%[0-9]+]]:_(s32), [[UV10:%[0-9]+]]:_(s32), [[UV11:%[0-9]+]]:_(s32), [[UV12:%[0-9]+]]:_(s32), [[UV13:%[0-9]+]]:_(s32), [[UV14:%[0-9]+]]:_(s32), [[UV15:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) +++ ; CHECK-NEXT: [[UV16:%[0-9]+]]:_(s32), [[UV17:%[0-9]+]]:_(s32), [[UV18:%[0-9]+]]:_(s32), [[UV19:%[0-9]+]]:_(s32), [[UV20:%[0-9]+]]:_(s32), [[UV21:%[0-9]+]]:_(s32), [[UV22:%[0-9]+]]:_(s32), [[UV23:%[0-9]+]]:_(s32), [[UV24:%[0-9]+]]:_(s32), [[UV25:%[0-9]+]]:_(s32), [[UV26:%[0-9]+]]:_(s32), [[UV27:%[0-9]+]]:_(s32), [[UV28:%[0-9]+]]:_(s32), [[UV29:%[0-9]+]]:_(s32), [[UV30:%[0-9]+]]:_(s32), [[UV31:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) +++ ; CHECK-NEXT: [[UV32:%[0-9]+]]:_(s32), [[UV33:%[0-9]+]]:_(s32), [[UV34:%[0-9]+]]:_(s32), [[UV35:%[0-9]+]]:_(s32), [[UV36:%[0-9]+]]:_(s32), [[UV37:%[0-9]+]]:_(s32), [[UV38:%[0-9]+]]:_(s32), [[UV39:%[0-9]+]]:_(s32), [[UV40:%[0-9]+]]:_(s32), [[UV41:%[0-9]+]]:_(s32), [[UV42:%[0-9]+]]:_(s32), [[UV43:%[0-9]+]]:_(s32), [[UV44:%[0-9]+]]:_(s32), [[UV45:%[0-9]+]]:_(s32), [[UV46:%[0-9]+]]:_(s32), [[UV47:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) +++ ; CHECK-NEXT: [[UV48:%[0-9]+]]:_(s32), [[UV49:%[0-9]+]]:_(s32), [[UV50:%[0-9]+]]:_(s32), [[UV51:%[0-9]+]]:_(s32), [[UV52:%[0-9]+]]:_(s32), [[UV53:%[0-9]+]]:_(s32), [[UV54:%[0-9]+]]:_(s32), [[UV55:%[0-9]+]]:_(s32), [[UV56:%[0-9]+]]:_(s32), [[UV57:%[0-9]+]]:_(s32), [[UV58:%[0-9]+]]:_(s32), [[UV59:%[0-9]+]]:_(s32), [[UV60:%[0-9]+]]:_(s32), [[UV61:%[0-9]+]]:_(s32), [[UV62:%[0-9]+]]:_(s32), [[UV63:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) +++ ; CHECK-NEXT: [[UV64:%[0-9]+]]:_(s32), [[UV65:%[0-9]+]]:_(s32), [[UV66:%[0-9]+]]:_(s32), [[UV67:%[0-9]+]]:_(s32), [[UV68:%[0-9]+]]:_(s32), [[UV69:%[0-9]+]]:_(s32), [[UV70:%[0-9]+]]:_(s32), [[UV71:%[0-9]+]]:_(s32), [[UV72:%[0-9]+]]:_(s32), [[UV73:%[0-9]+]]:_(s32), [[UV74:%[0-9]+]]:_(s32), [[UV75:%[0-9]+]]:_(s32), [[UV76:%[0-9]+]]:_(s32), [[UV77:%[0-9]+]]:_(s32), [[UV78:%[0-9]+]]:_(s32), [[UV79:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) +++ ; CHECK-NEXT: [[UV80:%[0-9]+]]:_(s32), [[UV81:%[0-9]+]]:_(s32), [[UV82:%[0-9]+]]:_(s32), [[UV83:%[0-9]+]]:_(s32), [[UV84:%[0-9]+]]:_(s32), [[UV85:%[0-9]+]]:_(s32), [[UV86:%[0-9]+]]:_(s32), [[UV87:%[0-9]+]]:_(s32), [[UV88:%[0-9]+]]:_(s32), [[UV89:%[0-9]+]]:_(s32), [[UV90:%[0-9]+]]:_(s32), [[UV91:%[0-9]+]]:_(s32), [[UV92:%[0-9]+]]:_(s32), [[UV93:%[0-9]+]]:_(s32), [[UV94:%[0-9]+]]:_(s32), [[UV95:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) +++ ; CHECK-NEXT: [[UV96:%[0-9]+]]:_(s32), [[UV97:%[0-9]+]]:_(s32), [[UV98:%[0-9]+]]:_(s32), [[UV99:%[0-9]+]]:_(s32), [[UV100:%[0-9]+]]:_(s32), [[UV101:%[0-9]+]]:_(s32), [[UV102:%[0-9]+]]:_(s32), [[UV103:%[0-9]+]]:_(s32), [[UV104:%[0-9]+]]:_(s32), [[UV105:%[0-9]+]]:_(s32), [[UV106:%[0-9]+]]:_(s32), [[UV107:%[0-9]+]]:_(s32), [[UV108:%[0-9]+]]:_(s32), [[UV109:%[0-9]+]]:_(s32), [[UV110:%[0-9]+]]:_(s32), [[UV111:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) +++ ; CHECK-NEXT: [[UV112:%[0-9]+]]:_(s32), [[UV113:%[0-9]+]]:_(s32), [[UV114:%[0-9]+]]:_(s32), [[UV115:%[0-9]+]]:_(s32), [[UV116:%[0-9]+]]:_(s32), [[UV117:%[0-9]+]]:_(s32), [[UV118:%[0-9]+]]:_(s32), [[UV119:%[0-9]+]]:_(s32), [[UV120:%[0-9]+]]:_(s32), [[UV121:%[0-9]+]]:_(s32), [[UV122:%[0-9]+]]:_(s32), [[UV123:%[0-9]+]]:_(s32), [[UV124:%[0-9]+]]:_(s32), [[UV125:%[0-9]+]]:_(s32), [[UV126:%[0-9]+]]:_(s32), [[UV127:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) +++ ; CHECK-NEXT: [[ADD:%[0-9]+]]:_(s32) = G_ADD [[UV]], [[UV64]] +++ ; CHECK-NEXT: [[ADD1:%[0-9]+]]:_(s32) = G_ADD [[UV1]], [[UV65]] +++ ; CHECK-NEXT: [[ADD2:%[0-9]+]]:_(s32) = G_ADD [[UV2]], [[UV66]] +++ ; CHECK-NEXT: [[ADD3:%[0-9]+]]:_(s32) = G_ADD [[UV3]], [[UV67]] +++ ; CHECK-NEXT: [[ADD4:%[0-9]+]]:_(s32) = G_ADD [[UV4]], [[UV68]] +++ ; CHECK-NEXT: [[ADD5:%[0-9]+]]:_(s32) = G_ADD [[UV5]], [[UV69]] +++ ; CHECK-NEXT: [[ADD6:%[0-9]+]]:_(s32) = G_ADD [[UV6]], [[UV70]] +++ ; CHECK-NEXT: [[ADD7:%[0-9]+]]:_(s32) = G_ADD [[UV7]], [[UV71]] +++ ; CHECK-NEXT: [[ADD8:%[0-9]+]]:_(s32) = G_ADD [[UV8]], [[UV72]] +++ ; CHECK-NEXT: [[ADD9:%[0-9]+]]:_(s32) = G_ADD [[UV9]], [[UV73]] +++ ; CHECK-NEXT: [[ADD10:%[0-9]+]]:_(s32) = G_ADD [[UV10]], [[UV74]] +++ ; CHECK-NEXT: [[ADD11:%[0-9]+]]:_(s32) = G_ADD [[UV11]], [[UV75]] +++ ; CHECK-NEXT: [[ADD12:%[0-9]+]]:_(s32) = G_ADD [[UV12]], [[UV76]] +++ ; CHECK-NEXT: [[ADD13:%[0-9]+]]:_(s32) = G_ADD [[UV13]], [[UV77]] +++ ; CHECK-NEXT: [[ADD14:%[0-9]+]]:_(s32) = G_ADD [[UV14]], [[UV78]] +++ ; CHECK-NEXT: [[ADD15:%[0-9]+]]:_(s32) = G_ADD [[UV15]], [[UV79]] +++ ; CHECK-NEXT: [[ADD16:%[0-9]+]]:_(s32) = G_ADD [[UV16]], [[UV80]] +++ ; CHECK-NEXT: [[ADD17:%[0-9]+]]:_(s32) = G_ADD [[UV17]], [[UV81]] +++ ; CHECK-NEXT: [[ADD18:%[0-9]+]]:_(s32) = G_ADD [[UV18]], [[UV82]] +++ ; CHECK-NEXT: [[ADD19:%[0-9]+]]:_(s32) = G_ADD [[UV19]], [[UV83]] +++ ; CHECK-NEXT: [[ADD20:%[0-9]+]]:_(s32) = G_ADD [[UV20]], [[UV84]] +++ ; CHECK-NEXT: [[ADD21:%[0-9]+]]:_(s32) = G_ADD [[UV21]], [[UV85]] +++ ; CHECK-NEXT: [[ADD22:%[0-9]+]]:_(s32) = G_ADD [[UV22]], [[UV86]] +++ ; CHECK-NEXT: [[ADD23:%[0-9]+]]:_(s32) = G_ADD [[UV23]], [[UV87]] +++ ; CHECK-NEXT: [[ADD24:%[0-9]+]]:_(s32) = G_ADD [[UV24]], [[UV88]] +++ ; CHECK-NEXT: [[ADD25:%[0-9]+]]:_(s32) = G_ADD [[UV25]], [[UV89]] +++ ; CHECK-NEXT: [[ADD26:%[0-9]+]]:_(s32) = G_ADD [[UV26]], [[UV90]] +++ ; CHECK-NEXT: [[ADD27:%[0-9]+]]:_(s32) = G_ADD [[UV27]], [[UV91]] +++ ; CHECK-NEXT: [[ADD28:%[0-9]+]]:_(s32) = G_ADD [[UV28]], [[UV92]] +++ ; CHECK-NEXT: [[ADD29:%[0-9]+]]:_(s32) = G_ADD [[UV29]], [[UV93]] +++ ; CHECK-NEXT: [[ADD30:%[0-9]+]]:_(s32) = G_ADD [[UV30]], [[UV94]] +++ ; CHECK-NEXT: [[ADD31:%[0-9]+]]:_(s32) = G_ADD [[UV31]], [[UV95]] +++ ; CHECK-NEXT: [[ADD32:%[0-9]+]]:_(s32) = G_ADD [[UV32]], [[UV96]] +++ ; CHECK-NEXT: [[ADD33:%[0-9]+]]:_(s32) = G_ADD [[UV33]], [[UV97]] +++ ; CHECK-NEXT: [[ADD34:%[0-9]+]]:_(s32) = G_ADD [[UV34]], [[UV98]] +++ ; CHECK-NEXT: [[ADD35:%[0-9]+]]:_(s32) = G_ADD [[UV35]], [[UV99]] +++ ; CHECK-NEXT: [[ADD36:%[0-9]+]]:_(s32) = G_ADD [[UV36]], [[UV100]] +++ ; CHECK-NEXT: [[ADD37:%[0-9]+]]:_(s32) = G_ADD [[UV37]], [[UV101]] +++ ; CHECK-NEXT: [[ADD38:%[0-9]+]]:_(s32) = G_ADD [[UV38]], [[UV102]] +++ ; CHECK-NEXT: [[ADD39:%[0-9]+]]:_(s32) = G_ADD [[UV39]], [[UV103]] +++ ; CHECK-NEXT: [[ADD40:%[0-9]+]]:_(s32) = G_ADD [[UV40]], [[UV104]] +++ ; CHECK-NEXT: [[ADD41:%[0-9]+]]:_(s32) = G_ADD [[UV41]], [[UV105]] +++ ; CHECK-NEXT: [[ADD42:%[0-9]+]]:_(s32) = G_ADD [[UV42]], [[UV106]] +++ ; CHECK-NEXT: [[ADD43:%[0-9]+]]:_(s32) = G_ADD [[UV43]], [[UV107]] +++ ; CHECK-NEXT: [[ADD44:%[0-9]+]]:_(s32) = G_ADD [[UV44]], [[UV108]] +++ ; CHECK-NEXT: [[ADD45:%[0-9]+]]:_(s32) = G_ADD [[UV45]], [[UV109]] +++ ; CHECK-NEXT: [[ADD46:%[0-9]+]]:_(s32) = G_ADD [[UV46]], [[UV110]] +++ ; CHECK-NEXT: [[ADD47:%[0-9]+]]:_(s32) = G_ADD [[UV47]], [[UV111]] +++ ; CHECK-NEXT: [[ADD48:%[0-9]+]]:_(s32) = G_ADD [[UV48]], [[UV112]] +++ ; CHECK-NEXT: [[ADD49:%[0-9]+]]:_(s32) = G_ADD [[UV49]], [[UV113]] +++ ; CHECK-NEXT: [[ADD50:%[0-9]+]]:_(s32) = G_ADD [[UV50]], [[UV114]] +++ ; CHECK-NEXT: [[ADD51:%[0-9]+]]:_(s32) = G_ADD [[UV51]], [[UV115]] +++ ; CHECK-NEXT: [[ADD52:%[0-9]+]]:_(s32) = G_ADD [[UV52]], [[UV116]] +++ ; CHECK-NEXT: [[ADD53:%[0-9]+]]:_(s32) = G_ADD [[UV53]], [[UV117]] +++ ; CHECK-NEXT: [[ADD54:%[0-9]+]]:_(s32) = G_ADD [[UV54]], [[UV118]] +++ ; CHECK-NEXT: [[ADD55:%[0-9]+]]:_(s32) = G_ADD [[UV55]], [[UV119]] +++ ; CHECK-NEXT: [[ADD56:%[0-9]+]]:_(s32) = G_ADD [[UV56]], [[UV120]] +++ ; CHECK-NEXT: [[ADD57:%[0-9]+]]:_(s32) = G_ADD [[UV57]], [[UV121]] +++ ; CHECK-NEXT: [[ADD58:%[0-9]+]]:_(s32) = G_ADD [[UV58]], [[UV122]] +++ ; CHECK-NEXT: [[ADD59:%[0-9]+]]:_(s32) = G_ADD [[UV59]], [[UV123]] +++ ; CHECK-NEXT: [[ADD60:%[0-9]+]]:_(s32) = G_ADD [[UV60]], [[UV124]] +++ ; CHECK-NEXT: [[ADD61:%[0-9]+]]:_(s32) = G_ADD [[UV61]], [[UV125]] +++ ; CHECK-NEXT: [[ADD62:%[0-9]+]]:_(s32) = G_ADD [[UV62]], [[UV126]] +++ ; CHECK-NEXT: [[ADD63:%[0-9]+]]:_(s32) = G_ADD [[UV63]], [[UV127]] ++ ; CHECK-NEXT: [[BUILD_VECTOR:%[0-9]+]]:_(<16 x s32>) = G_BUILD_VECTOR [[ADD]](s32), [[ADD1]](s32), [[ADD2]](s32), [[ADD3]](s32), [[ADD4]](s32), [[ADD5]](s32), [[ADD6]](s32), [[ADD7]](s32), [[ADD8]](s32), [[ADD9]](s32), [[ADD10]](s32), [[ADD11]](s32), [[ADD12]](s32), [[ADD13]](s32), [[ADD14]](s32), [[ADD15]](s32) ++ ; CHECK-NEXT: [[BUILD_VECTOR1:%[0-9]+]]:_(<16 x s32>) = G_BUILD_VECTOR [[ADD16]](s32), [[ADD17]](s32), [[ADD18]](s32), [[ADD19]](s32), [[ADD20]](s32), [[ADD21]](s32), [[ADD22]](s32), [[ADD23]](s32), [[ADD24]](s32), [[ADD25]](s32), [[ADD26]](s32), [[ADD27]](s32), [[ADD28]](s32), [[ADD29]](s32), [[ADD30]](s32), [[ADD31]](s32) ++ ; CHECK-NEXT: [[BUILD_VECTOR2:%[0-9]+]]:_(<16 x s32>) = G_BUILD_VECTOR [[ADD32]](s32), [[ADD33]](s32), [[ADD34]](s32), [[ADD35]](s32), [[ADD36]](s32), [[ADD37]](s32), [[ADD38]](s32), [[ADD39]](s32), [[ADD40]](s32), [[ADD41]](s32), [[ADD42]](s32), [[ADD43]](s32), [[ADD44]](s32), [[ADD45]](s32), [[ADD46]](s32), [[ADD47]](s32) ++@@ -760,10 +762,10 @@ ++ ; CHECK-NEXT: G_BR %bb.2 ++ ; CHECK-NEXT: {{ $}} ++ ; CHECK-NEXT: bb.2: ++- ; CHECK-NEXT: [[PHI:%[0-9]+]]:_(<16 x s32>) = G_PHI [[UV]](<16 x s32>), %bb.0, [[BUILD_VECTOR]](<16 x s32>), %bb.1 ++- ; CHECK-NEXT: [[PHI1:%[0-9]+]]:_(<16 x s32>) = G_PHI [[UV1]](<16 x s32>), %bb.0, [[BUILD_VECTOR1]](<16 x s32>), %bb.1 ++- ; CHECK-NEXT: [[PHI2:%[0-9]+]]:_(<16 x s32>) = G_PHI [[UV2]](<16 x s32>), %bb.0, [[BUILD_VECTOR2]](<16 x s32>), %bb.1 ++- ; CHECK-NEXT: [[PHI3:%[0-9]+]]:_(<16 x s32>) = G_PHI [[UV3]](<16 x s32>), %bb.0, [[BUILD_VECTOR3]](<16 x s32>), %bb.1 +++ ; CHECK-NEXT: [[PHI:%[0-9]+]]:_(<16 x s32>) = G_PHI [[DEF]](<16 x s32>), %bb.0, [[BUILD_VECTOR]](<16 x s32>), %bb.1 +++ ; CHECK-NEXT: [[PHI1:%[0-9]+]]:_(<16 x s32>) = G_PHI [[DEF]](<16 x s32>), %bb.0, [[BUILD_VECTOR1]](<16 x s32>), %bb.1 +++ ; CHECK-NEXT: [[PHI2:%[0-9]+]]:_(<16 x s32>) = G_PHI [[DEF]](<16 x s32>), %bb.0, [[BUILD_VECTOR2]](<16 x s32>), %bb.1 +++ ; CHECK-NEXT: [[PHI3:%[0-9]+]]:_(<16 x s32>) = G_PHI [[DEF]](<16 x s32>), %bb.0, [[BUILD_VECTOR3]](<16 x s32>), %bb.1 ++ ; CHECK-NEXT: [[CONCAT_VECTORS:%[0-9]+]]:_(<64 x s32>) = G_CONCAT_VECTORS [[PHI]](<16 x s32>), [[PHI1]](<16 x s32>), [[PHI2]](<16 x s32>), [[PHI3]](<16 x s32>) ++ ; CHECK-NEXT: S_SETPC_B64 undef $sgpr30_sgpr31, implicit [[CONCAT_VECTORS]](<64 x s32>) ++ bb.0: ++diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/AMDGPU/GlobalISel/regbankselect.mir b/llvm/test/CodeGen/AMDGPU/GlobalISel/regbankselect.mir ++--- a/llvm/test/CodeGen/AMDGPU/GlobalISel/regbankselect.mir +++++ b/llvm/test/CodeGen/AMDGPU/GlobalISel/regbankselect.mir ++@@ -42,6 +42,8 @@ ++ ret void ++ } ++ +++ define void @non_power_of_2() { ret void } +++ ++ define amdgpu_kernel void @load_constant_v4i16_from_8_align8(ptr addrspace(4) %ptr0) { ++ ret void ++ } ++@@ -185,6 +187,23 @@ ++ ... ++ ++ --- +++name: non_power_of_2 +++legalized: true +++ +++body: | +++ bb.0: +++ ; CHECK-LABEL: name: non_power_of_2 +++ ; CHECK: [[DEF:%[0-9]+]]:sgpr(s448) = G_IMPLICIT_DEF +++ ; CHECK-NEXT: [[EXTRACT:%[0-9]+]]:sgpr(s32) = G_EXTRACT [[DEF]](s448), 0 +++ ; CHECK-NEXT: $sgpr0 = COPY [[EXTRACT]](s32) +++ ; CHECK-NEXT: SI_RETURN_TO_EPILOG $sgpr0 +++ %0:_(s448) = G_IMPLICIT_DEF +++ %1:_(s32) = G_EXTRACT %0:_(s448), 0 +++ $sgpr0 = COPY %1:_(s32) +++ SI_RETURN_TO_EPILOG $sgpr0 +++... +++ +++--- ++ name: load_constant_v4i16_from_8_align8 ++ legalized: true ++ ++diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/NVPTX/intrin-nocapture.ll b/llvm/test/CodeGen/NVPTX/intrin-nocapture.ll ++--- a/llvm/test/CodeGen/NVPTX/intrin-nocapture.ll +++++ b/llvm/test/CodeGen/NVPTX/intrin-nocapture.ll ++@@ -0,0 +1,21 @@ +++; RUN: opt < %s -O3 -S | FileCheck %s +++ +++; Address space intrinsics were erroneously marked NoCapture, leading to bad +++; optimizations (such as the store below being eliminated as dead code). This +++; test makes sure we don't regress. +++ +++declare void @foo(ptr addrspace(1)) +++ +++declare ptr addrspace(1) @llvm.nvvm.ptr.gen.to.global.p1.p0(ptr) +++ +++; CHECK: @bar +++define void @bar() { +++ %t1 = alloca i32 +++; CHECK: call ptr addrspace(1) @llvm.nvvm.ptr.gen.to.global.p1.p0(ptr nonnull %t1) +++; CHECK-NEXT: store i32 10, ptr %t1 +++ %t2 = call ptr addrspace(1) @llvm.nvvm.ptr.gen.to.global.p1.p0(ptr %t1) +++ store i32 10, ptr %t1 +++ call void @foo(ptr addrspace(1) %t2) +++ ret void +++} +++ ++diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/NVPTX/rotate_64.ll b/llvm/test/CodeGen/NVPTX/rotate_64.ll ++--- a/llvm/test/CodeGen/NVPTX/rotate_64.ll +++++ b/llvm/test/CodeGen/NVPTX/rotate_64.ll ++@@ -1,38 +1,25 @@ ++-; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5 ++ ; RUN: llc < %s -march=nvptx64 | FileCheck %s ++ ; RUN: %if ptxas %{ llc < %s -march=nvptx64 | %ptxas-verify %} ++ ++ declare i64 @llvm.nvvm.rotate.b64(i64, i32) ++ declare i64 @llvm.nvvm.rotate.right.b64(i64, i32) ++ +++; CHECK: rotate64 ++ define i64 @rotate64(i64 %a, i32 %b) { ++-; CHECK-LABEL: rotate64( ++-; CHECK: { ++-; CHECK-NEXT: .reg .b64 %rd<5>; ++-; CHECK-EMPTY: ++-; CHECK-NEXT: // %bb.0: ++-; CHECK-NEXT: ld.param.u64 %rd1, [rotate64_param_0]; ++-; CHECK-NEXT: shr.u64 %rd2, %rd1, 61; ++-; CHECK-NEXT: shl.b64 %rd3, %rd1, 3; ++-; CHECK-NEXT: or.b64 %rd4, %rd3, %rd2; ++-; CHECK-NEXT: st.param.b64 [func_retval0+0], %rd4; ++-; CHECK-NEXT: ret; +++; CHECK: shl.b64 [[LHS:%.*]], [[RD1:%.*]], 3; +++; CHECK: shr.b64 [[RHS:%.*]], [[RD1]], 61; +++; CHECK: add.u64 [[RD2:%.*]], [[LHS]], [[RHS]]; +++; CHECK: ret ++ %val = tail call i64 @llvm.nvvm.rotate.b64(i64 %a, i32 3) ++ ret i64 %val ++ } ++ +++; CHECK: rotateright64 ++ define i64 @rotateright64(i64 %a, i32 %b) { ++-; CHECK-LABEL: rotateright64( ++-; CHECK: { ++-; CHECK-NEXT: .reg .b64 %rd<5>; ++-; CHECK-EMPTY: ++-; CHECK-NEXT: // %bb.0: ++-; CHECK-NEXT: ld.param.u64 %rd1, [rotateright64_param_0]; ++-; CHECK-NEXT: shl.b64 %rd2, %rd1, 61; ++-; CHECK-NEXT: shr.u64 %rd3, %rd1, 3; ++-; CHECK-NEXT: or.b64 %rd4, %rd3, %rd2; ++-; CHECK-NEXT: st.param.b64 [func_retval0+0], %rd4; ++-; CHECK-NEXT: ret; +++; CHECK: shl.b64 [[LHS:%.*]], [[RD1:%.*]], 61; +++; CHECK: shr.b64 [[RHS:%.*]], [[RD1]], 3; +++; CHECK: add.u64 [[RD2:%.*]], [[LHS]], [[RHS]]; +++; CHECK: ret ++ %val = tail call i64 @llvm.nvvm.rotate.right.b64(i64 %a, i32 3) ++ ret i64 %val ++ } ++diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/NVPTX/rotate.ll b/llvm/test/CodeGen/NVPTX/rotate.ll ++--- a/llvm/test/CodeGen/NVPTX/rotate.ll +++++ b/llvm/test/CodeGen/NVPTX/rotate.ll ++@@ -9,29 +9,26 @@ ++ declare i64 @llvm.nvvm.rotate.b64(i64, i32) ++ declare i64 @llvm.nvvm.rotate.right.b64(i64, i32) ++ ++-declare i64 @llvm.fshl.i64(i64, i64, i64) ++-declare i64 @llvm.fshr.i64(i64, i64, i64) ++-declare i32 @llvm.fshl.i32(i32, i32, i32) ++-declare i32 @llvm.fshr.i32(i32, i32, i32) ++- ++- ++ ; SM20: rotate32 ++ ; SM35: rotate32 ++ define i32 @rotate32(i32 %a, i32 %b) { ++ ; SM20-LABEL: rotate32( ++ ; SM20: { ++-; SM20-NEXT: .reg .b32 %r<9>; +++; SM20-NEXT: .reg .b32 %r<4>; ++ ; SM20-EMPTY: ++ ; SM20-NEXT: // %bb.0: ++ ; SM20-NEXT: ld.param.u32 %r1, [rotate32_param_0]; ++ ; SM20-NEXT: ld.param.u32 %r2, [rotate32_param_1]; ++-; SM20-NEXT: and.b32 %r3, %r2, 31; ++-; SM20-NEXT: shl.b32 %r4, %r1, %r3; ++-; SM20-NEXT: neg.s32 %r5, %r2; ++-; SM20-NEXT: and.b32 %r6, %r5, 31; ++-; SM20-NEXT: shr.u32 %r7, %r1, %r6; ++-; SM20-NEXT: or.b32 %r8, %r4, %r7; ++-; SM20-NEXT: st.param.b32 [func_retval0+0], %r8; +++; SM20-NEXT: { +++; SM20-NEXT: .reg .b32 %lhs; +++; SM20-NEXT: .reg .b32 %rhs; +++; SM20-NEXT: .reg .b32 %amt2; +++; SM20-NEXT: shl.b32 %lhs, %r1, %r2; +++; SM20-NEXT: sub.s32 %amt2, 32, %r2; +++; SM20-NEXT: shr.b32 %rhs, %r1, %amt2; +++; SM20-NEXT: add.u32 %r3, %lhs, %rhs; +++; SM20-NEXT: } +++; SM20-NEXT: st.param.b32 [func_retval0+0], %r3; ++ ; SM20-NEXT: ret; ++ ; ++ ; SM35-LABEL: rotate32( ++@@ -53,36 +50,45 @@ ++ define i64 @rotate64(i64 %a, i32 %b) { ++ ; SM20-LABEL: rotate64( ++ ; SM20: { ++-; SM20-NEXT: .reg .b32 %r<5>; ++-; SM20-NEXT: .reg .b64 %rd<5>; +++; SM20-NEXT: .reg .b32 %r<2>; +++; SM20-NEXT: .reg .b64 %rd<3>; ++ ; SM20-EMPTY: ++ ; SM20-NEXT: // %bb.0: ++ ; SM20-NEXT: ld.param.u64 %rd1, [rotate64_param_0]; ++ ; SM20-NEXT: ld.param.u32 %r1, [rotate64_param_1]; ++-; SM20-NEXT: and.b32 %r2, %r1, 63; ++-; SM20-NEXT: shl.b64 %rd2, %rd1, %r2; ++-; SM20-NEXT: neg.s32 %r3, %r1; ++-; SM20-NEXT: and.b32 %r4, %r3, 63; ++-; SM20-NEXT: shr.u64 %rd3, %rd1, %r4; ++-; SM20-NEXT: or.b64 %rd4, %rd2, %rd3; ++-; SM20-NEXT: st.param.b64 [func_retval0+0], %rd4; +++; SM20-NEXT: { +++; SM20-NEXT: .reg .b64 %lhs; +++; SM20-NEXT: .reg .b64 %rhs; +++; SM20-NEXT: .reg .u32 %amt2; +++; SM20-NEXT: and.b32 %amt2, %r1, 63; +++; SM20-NEXT: shl.b64 %lhs, %rd1, %amt2; +++; SM20-NEXT: sub.u32 %amt2, 64, %amt2; +++; SM20-NEXT: shr.b64 %rhs, %rd1, %amt2; +++; SM20-NEXT: add.u64 %rd2, %lhs, %rhs; +++; SM20-NEXT: } +++; SM20-NEXT: st.param.b64 [func_retval0+0], %rd2; ++ ; SM20-NEXT: ret; ++ ; ++ ; SM35-LABEL: rotate64( ++ ; SM35: { ++-; SM35-NEXT: .reg .b32 %r<5>; ++-; SM35-NEXT: .reg .b64 %rd<5>; +++; SM35-NEXT: .reg .b32 %r<6>; +++; SM35-NEXT: .reg .b64 %rd<3>; ++ ; SM35-EMPTY: ++ ; SM35-NEXT: // %bb.0: ++ ; SM35-NEXT: ld.param.u64 %rd1, [rotate64_param_0]; ++-; SM35-NEXT: ld.param.u32 %r1, [rotate64_param_1]; ++-; SM35-NEXT: and.b32 %r2, %r1, 63; ++-; SM35-NEXT: shl.b64 %rd2, %rd1, %r2; ++-; SM35-NEXT: neg.s32 %r3, %r1; ++-; SM35-NEXT: and.b32 %r4, %r3, 63; ++-; SM35-NEXT: shr.u64 %rd3, %rd1, %r4; ++-; SM35-NEXT: or.b64 %rd4, %rd2, %rd3; ++-; SM35-NEXT: st.param.b64 [func_retval0+0], %rd4; +++; SM35-NEXT: { +++; SM35-NEXT: .reg .b32 %dummy; +++; SM35-NEXT: mov.b64 {%dummy,%r1}, %rd1; +++; SM35-NEXT: } +++; SM35-NEXT: { +++; SM35-NEXT: .reg .b32 %dummy; +++; SM35-NEXT: mov.b64 {%r2,%dummy}, %rd1; +++; SM35-NEXT: } +++; SM35-NEXT: ld.param.u32 %r3, [rotate64_param_1]; +++; SM35-NEXT: shf.l.wrap.b32 %r4, %r2, %r1, %r3; +++; SM35-NEXT: shf.l.wrap.b32 %r5, %r1, %r2, %r3; +++; SM35-NEXT: mov.b64 %rd2, {%r5, %r4}; +++; SM35-NEXT: st.param.b64 [func_retval0+0], %rd2; ++ ; SM35-NEXT: ret; ++ %val = tail call i64 @llvm.nvvm.rotate.b64(i64 %a, i32 %b) ++ ret i64 %val ++@@ -93,36 +99,45 @@ ++ define i64 @rotateright64(i64 %a, i32 %b) { ++ ; SM20-LABEL: rotateright64( ++ ; SM20: { ++-; SM20-NEXT: .reg .b32 %r<5>; ++-; SM20-NEXT: .reg .b64 %rd<5>; +++; SM20-NEXT: .reg .b32 %r<2>; +++; SM20-NEXT: .reg .b64 %rd<3>; ++ ; SM20-EMPTY: ++ ; SM20-NEXT: // %bb.0: ++ ; SM20-NEXT: ld.param.u64 %rd1, [rotateright64_param_0]; ++ ; SM20-NEXT: ld.param.u32 %r1, [rotateright64_param_1]; ++-; SM20-NEXT: and.b32 %r2, %r1, 63; ++-; SM20-NEXT: shr.u64 %rd2, %rd1, %r2; ++-; SM20-NEXT: neg.s32 %r3, %r1; ++-; SM20-NEXT: and.b32 %r4, %r3, 63; ++-; SM20-NEXT: shl.b64 %rd3, %rd1, %r4; ++-; SM20-NEXT: or.b64 %rd4, %rd2, %rd3; ++-; SM20-NEXT: st.param.b64 [func_retval0+0], %rd4; +++; SM20-NEXT: { +++; SM20-NEXT: .reg .b64 %lhs; +++; SM20-NEXT: .reg .b64 %rhs; +++; SM20-NEXT: .reg .u32 %amt2; +++; SM20-NEXT: and.b32 %amt2, %r1, 63; +++; SM20-NEXT: shr.b64 %lhs, %rd1, %amt2; +++; SM20-NEXT: sub.u32 %amt2, 64, %amt2; +++; SM20-NEXT: shl.b64 %rhs, %rd1, %amt2; +++; SM20-NEXT: add.u64 %rd2, %lhs, %rhs; +++; SM20-NEXT: } +++; SM20-NEXT: st.param.b64 [func_retval0+0], %rd2; ++ ; SM20-NEXT: ret; ++ ; ++ ; SM35-LABEL: rotateright64( ++ ; SM35: { ++-; SM35-NEXT: .reg .b32 %r<5>; ++-; SM35-NEXT: .reg .b64 %rd<5>; +++; SM35-NEXT: .reg .b32 %r<6>; +++; SM35-NEXT: .reg .b64 %rd<3>; ++ ; SM35-EMPTY: ++ ; SM35-NEXT: // %bb.0: ++ ; SM35-NEXT: ld.param.u64 %rd1, [rotateright64_param_0]; ++-; SM35-NEXT: ld.param.u32 %r1, [rotateright64_param_1]; ++-; SM35-NEXT: and.b32 %r2, %r1, 63; ++-; SM35-NEXT: shr.u64 %rd2, %rd1, %r2; ++-; SM35-NEXT: neg.s32 %r3, %r1; ++-; SM35-NEXT: and.b32 %r4, %r3, 63; ++-; SM35-NEXT: shl.b64 %rd3, %rd1, %r4; ++-; SM35-NEXT: or.b64 %rd4, %rd2, %rd3; ++-; SM35-NEXT: st.param.b64 [func_retval0+0], %rd4; +++; SM35-NEXT: { +++; SM35-NEXT: .reg .b32 %dummy; +++; SM35-NEXT: mov.b64 {%r1,%dummy}, %rd1; +++; SM35-NEXT: } +++; SM35-NEXT: { +++; SM35-NEXT: .reg .b32 %dummy; +++; SM35-NEXT: mov.b64 {%dummy,%r2}, %rd1; +++; SM35-NEXT: } +++; SM35-NEXT: ld.param.u32 %r3, [rotateright64_param_1]; +++; SM35-NEXT: shf.r.wrap.b32 %r4, %r2, %r1, %r3; +++; SM35-NEXT: shf.r.wrap.b32 %r5, %r1, %r2, %r3; +++; SM35-NEXT: mov.b64 %rd2, {%r5, %r4}; +++; SM35-NEXT: st.param.b64 [func_retval0+0], %rd2; ++ ; SM35-NEXT: ret; ++ %val = tail call i64 @llvm.nvvm.rotate.right.b64(i64 %a, i32 %b) ++ ret i64 %val ++@@ -133,14 +148,18 @@ ++ define i32 @rotl0(i32 %x) { ++ ; SM20-LABEL: rotl0( ++ ; SM20: { ++-; SM20-NEXT: .reg .b32 %r<5>; +++; SM20-NEXT: .reg .b32 %r<3>; ++ ; SM20-EMPTY: ++ ; SM20-NEXT: // %bb.0: ++ ; SM20-NEXT: ld.param.u32 %r1, [rotl0_param_0]; ++-; SM20-NEXT: shr.u32 %r2, %r1, 24; ++-; SM20-NEXT: shl.b32 %r3, %r1, 8; ++-; SM20-NEXT: or.b32 %r4, %r3, %r2; ++-; SM20-NEXT: st.param.b32 [func_retval0+0], %r4; +++; SM20-NEXT: { +++; SM20-NEXT: .reg .b32 %lhs; +++; SM20-NEXT: .reg .b32 %rhs; +++; SM20-NEXT: shl.b32 %lhs, %r1, 8; +++; SM20-NEXT: shr.b32 %rhs, %r1, 24; +++; SM20-NEXT: add.u32 %r2, %lhs, %rhs; +++; SM20-NEXT: } +++; SM20-NEXT: st.param.b32 [func_retval0+0], %r2; ++ ; SM20-NEXT: ret; ++ ; ++ ; SM35-LABEL: rotl0( ++@@ -158,40 +177,51 @@ ++ ret i32 %t2 ++ } ++ +++declare i64 @llvm.fshl.i64(i64, i64, i64) +++declare i64 @llvm.fshr.i64(i64, i64, i64) +++ ++ ; SM35: rotl64 ++ define i64 @rotl64(i64 %a, i64 %n) { ++ ; SM20-LABEL: rotl64( ++ ; SM20: { ++-; SM20-NEXT: .reg .b32 %r<5>; ++-; SM20-NEXT: .reg .b64 %rd<5>; +++; SM20-NEXT: .reg .b32 %r<2>; +++; SM20-NEXT: .reg .b64 %rd<3>; ++ ; SM20-EMPTY: ++ ; SM20-NEXT: // %bb.0: ++ ; SM20-NEXT: ld.param.u64 %rd1, [rotl64_param_0]; ++ ; SM20-NEXT: ld.param.u32 %r1, [rotl64_param_1]; ++-; SM20-NEXT: and.b32 %r2, %r1, 63; ++-; SM20-NEXT: shl.b64 %rd2, %rd1, %r2; ++-; SM20-NEXT: neg.s32 %r3, %r1; ++-; SM20-NEXT: and.b32 %r4, %r3, 63; ++-; SM20-NEXT: shr.u64 %rd3, %rd1, %r4; ++-; SM20-NEXT: or.b64 %rd4, %rd2, %rd3; ++-; SM20-NEXT: st.param.b64 [func_retval0+0], %rd4; +++; SM20-NEXT: { +++; SM20-NEXT: .reg .b64 %lhs; +++; SM20-NEXT: .reg .b64 %rhs; +++; SM20-NEXT: .reg .u32 %amt2; +++; SM20-NEXT: and.b32 %amt2, %r1, 63; +++; SM20-NEXT: shl.b64 %lhs, %rd1, %amt2; +++; SM20-NEXT: sub.u32 %amt2, 64, %amt2; +++; SM20-NEXT: shr.b64 %rhs, %rd1, %amt2; +++; SM20-NEXT: add.u64 %rd2, %lhs, %rhs; +++; SM20-NEXT: } +++; SM20-NEXT: st.param.b64 [func_retval0+0], %rd2; ++ ; SM20-NEXT: ret; ++ ; ++ ; SM35-LABEL: rotl64( ++ ; SM35: { ++-; SM35-NEXT: .reg .b32 %r<5>; ++-; SM35-NEXT: .reg .b64 %rd<5>; +++; SM35-NEXT: .reg .b32 %r<2>; +++; SM35-NEXT: .reg .b64 %rd<3>; ++ ; SM35-EMPTY: ++ ; SM35-NEXT: // %bb.0: ++ ; SM35-NEXT: ld.param.u64 %rd1, [rotl64_param_0]; ++ ; SM35-NEXT: ld.param.u32 %r1, [rotl64_param_1]; ++-; SM35-NEXT: and.b32 %r2, %r1, 63; ++-; SM35-NEXT: shl.b64 %rd2, %rd1, %r2; ++-; SM35-NEXT: neg.s32 %r3, %r1; ++-; SM35-NEXT: and.b32 %r4, %r3, 63; ++-; SM35-NEXT: shr.u64 %rd3, %rd1, %r4; ++-; SM35-NEXT: or.b64 %rd4, %rd2, %rd3; ++-; SM35-NEXT: st.param.b64 [func_retval0+0], %rd4; +++; SM35-NEXT: { +++; SM35-NEXT: .reg .b64 %lhs; +++; SM35-NEXT: .reg .b64 %rhs; +++; SM35-NEXT: .reg .u32 %amt2; +++; SM35-NEXT: and.b32 %amt2, %r1, 63; +++; SM35-NEXT: shl.b64 %lhs, %rd1, %amt2; +++; SM35-NEXT: sub.u32 %amt2, 64, %amt2; +++; SM35-NEXT: shr.b64 %rhs, %rd1, %amt2; +++; SM35-NEXT: add.u64 %rd2, %lhs, %rhs; +++; SM35-NEXT: } +++; SM35-NEXT: st.param.b64 [func_retval0+0], %rd2; ++ ; SM35-NEXT: ret; ++ %val = tail call i64 @llvm.fshl.i64(i64 %a, i64 %a, i64 %n) ++ ret i64 %val ++@@ -201,26 +231,34 @@ ++ define i64 @rotl64_imm(i64 %a) { ++ ; SM20-LABEL: rotl64_imm( ++ ; SM20: { ++-; SM20-NEXT: .reg .b64 %rd<5>; +++; SM20-NEXT: .reg .b64 %rd<3>; ++ ; SM20-EMPTY: ++ ; SM20-NEXT: // %bb.0: ++ ; SM20-NEXT: ld.param.u64 %rd1, [rotl64_imm_param_0]; ++-; SM20-NEXT: shr.u64 %rd2, %rd1, 62; ++-; SM20-NEXT: shl.b64 %rd3, %rd1, 2; ++-; SM20-NEXT: or.b64 %rd4, %rd3, %rd2; ++-; SM20-NEXT: st.param.b64 [func_retval0+0], %rd4; +++; SM20-NEXT: { +++; SM20-NEXT: .reg .b64 %lhs; +++; SM20-NEXT: .reg .b64 %rhs; +++; SM20-NEXT: shl.b64 %lhs, %rd1, 2; +++; SM20-NEXT: shr.b64 %rhs, %rd1, 62; +++; SM20-NEXT: add.u64 %rd2, %lhs, %rhs; +++; SM20-NEXT: } +++; SM20-NEXT: st.param.b64 [func_retval0+0], %rd2; ++ ; SM20-NEXT: ret; ++ ; ++ ; SM35-LABEL: rotl64_imm( ++ ; SM35: { ++-; SM35-NEXT: .reg .b64 %rd<5>; +++; SM35-NEXT: .reg .b64 %rd<3>; ++ ; SM35-EMPTY: ++ ; SM35-NEXT: // %bb.0: ++ ; SM35-NEXT: ld.param.u64 %rd1, [rotl64_imm_param_0]; ++-; SM35-NEXT: shr.u64 %rd2, %rd1, 62; ++-; SM35-NEXT: shl.b64 %rd3, %rd1, 2; ++-; SM35-NEXT: or.b64 %rd4, %rd3, %rd2; ++-; SM35-NEXT: st.param.b64 [func_retval0+0], %rd4; +++; SM35-NEXT: { +++; SM35-NEXT: .reg .b64 %lhs; +++; SM35-NEXT: .reg .b64 %rhs; +++; SM35-NEXT: shl.b64 %lhs, %rd1, 2; +++; SM35-NEXT: shr.b64 %rhs, %rd1, 62; +++; SM35-NEXT: add.u64 %rd2, %lhs, %rhs; +++; SM35-NEXT: } +++; SM35-NEXT: st.param.b64 [func_retval0+0], %rd2; ++ ; SM35-NEXT: ret; ++ %val = tail call i64 @llvm.fshl.i64(i64 %a, i64 %a, i64 66) ++ ret i64 %val ++@@ -230,36 +268,44 @@ ++ define i64 @rotr64(i64 %a, i64 %n) { ++ ; SM20-LABEL: rotr64( ++ ; SM20: { ++-; SM20-NEXT: .reg .b32 %r<5>; ++-; SM20-NEXT: .reg .b64 %rd<5>; +++; SM20-NEXT: .reg .b32 %r<2>; +++; SM20-NEXT: .reg .b64 %rd<3>; ++ ; SM20-EMPTY: ++ ; SM20-NEXT: // %bb.0: ++ ; SM20-NEXT: ld.param.u64 %rd1, [rotr64_param_0]; ++ ; SM20-NEXT: ld.param.u32 %r1, [rotr64_param_1]; ++-; SM20-NEXT: and.b32 %r2, %r1, 63; ++-; SM20-NEXT: shr.u64 %rd2, %rd1, %r2; ++-; SM20-NEXT: neg.s32 %r3, %r1; ++-; SM20-NEXT: and.b32 %r4, %r3, 63; ++-; SM20-NEXT: shl.b64 %rd3, %rd1, %r4; ++-; SM20-NEXT: or.b64 %rd4, %rd2, %rd3; ++-; SM20-NEXT: st.param.b64 [func_retval0+0], %rd4; +++; SM20-NEXT: { +++; SM20-NEXT: .reg .b64 %lhs; +++; SM20-NEXT: .reg .b64 %rhs; +++; SM20-NEXT: .reg .u32 %amt2; +++; SM20-NEXT: and.b32 %amt2, %r1, 63; +++; SM20-NEXT: shr.b64 %lhs, %rd1, %amt2; +++; SM20-NEXT: sub.u32 %amt2, 64, %amt2; +++; SM20-NEXT: shl.b64 %rhs, %rd1, %amt2; +++; SM20-NEXT: add.u64 %rd2, %lhs, %rhs; +++; SM20-NEXT: } +++; SM20-NEXT: st.param.b64 [func_retval0+0], %rd2; ++ ; SM20-NEXT: ret; ++ ; ++ ; SM35-LABEL: rotr64( ++ ; SM35: { ++-; SM35-NEXT: .reg .b32 %r<5>; ++-; SM35-NEXT: .reg .b64 %rd<5>; +++; SM35-NEXT: .reg .b32 %r<2>; +++; SM35-NEXT: .reg .b64 %rd<3>; ++ ; SM35-EMPTY: ++ ; SM35-NEXT: // %bb.0: ++ ; SM35-NEXT: ld.param.u64 %rd1, [rotr64_param_0]; ++ ; SM35-NEXT: ld.param.u32 %r1, [rotr64_param_1]; ++-; SM35-NEXT: and.b32 %r2, %r1, 63; ++-; SM35-NEXT: shr.u64 %rd2, %rd1, %r2; ++-; SM35-NEXT: neg.s32 %r3, %r1; ++-; SM35-NEXT: and.b32 %r4, %r3, 63; ++-; SM35-NEXT: shl.b64 %rd3, %rd1, %r4; ++-; SM35-NEXT: or.b64 %rd4, %rd2, %rd3; ++-; SM35-NEXT: st.param.b64 [func_retval0+0], %rd4; +++; SM35-NEXT: { +++; SM35-NEXT: .reg .b64 %lhs; +++; SM35-NEXT: .reg .b64 %rhs; +++; SM35-NEXT: .reg .u32 %amt2; +++; SM35-NEXT: and.b32 %amt2, %r1, 63; +++; SM35-NEXT: shr.b64 %lhs, %rd1, %amt2; +++; SM35-NEXT: sub.u32 %amt2, 64, %amt2; +++; SM35-NEXT: shl.b64 %rhs, %rd1, %amt2; +++; SM35-NEXT: add.u64 %rd2, %lhs, %rhs; +++; SM35-NEXT: } +++; SM35-NEXT: st.param.b64 [func_retval0+0], %rd2; ++ ; SM35-NEXT: ret; ++ %val = tail call i64 @llvm.fshr.i64(i64 %a, i64 %a, i64 %n) ++ ret i64 %val ++@@ -269,180 +315,35 @@ ++ define i64 @rotr64_imm(i64 %a) { ++ ; SM20-LABEL: rotr64_imm( ++ ; SM20: { ++-; SM20-NEXT: .reg .b64 %rd<5>; +++; SM20-NEXT: .reg .b64 %rd<3>; ++ ; SM20-EMPTY: ++ ; SM20-NEXT: // %bb.0: ++ ; SM20-NEXT: ld.param.u64 %rd1, [rotr64_imm_param_0]; ++-; SM20-NEXT: shl.b64 %rd2, %rd1, 62; ++-; SM20-NEXT: shr.u64 %rd3, %rd1, 2; ++-; SM20-NEXT: or.b64 %rd4, %rd3, %rd2; ++-; SM20-NEXT: st.param.b64 [func_retval0+0], %rd4; +++; SM20-NEXT: { +++; SM20-NEXT: .reg .b64 %lhs; +++; SM20-NEXT: .reg .b64 %rhs; +++; SM20-NEXT: shl.b64 %lhs, %rd1, 62; +++; SM20-NEXT: shr.b64 %rhs, %rd1, 2; +++; SM20-NEXT: add.u64 %rd2, %lhs, %rhs; +++; SM20-NEXT: } +++; SM20-NEXT: st.param.b64 [func_retval0+0], %rd2; ++ ; SM20-NEXT: ret; ++ ; ++ ; SM35-LABEL: rotr64_imm( ++ ; SM35: { ++-; SM35-NEXT: .reg .b64 %rd<5>; +++; SM35-NEXT: .reg .b64 %rd<3>; ++ ; SM35-EMPTY: ++ ; SM35-NEXT: // %bb.0: ++ ; SM35-NEXT: ld.param.u64 %rd1, [rotr64_imm_param_0]; ++-; SM35-NEXT: shl.b64 %rd2, %rd1, 62; ++-; SM35-NEXT: shr.u64 %rd3, %rd1, 2; ++-; SM35-NEXT: or.b64 %rd4, %rd3, %rd2; ++-; SM35-NEXT: st.param.b64 [func_retval0+0], %rd4; +++; SM35-NEXT: { +++; SM35-NEXT: .reg .b64 %lhs; +++; SM35-NEXT: .reg .b64 %rhs; +++; SM35-NEXT: shl.b64 %lhs, %rd1, 62; +++; SM35-NEXT: shr.b64 %rhs, %rd1, 2; +++; SM35-NEXT: add.u64 %rd2, %lhs, %rhs; +++; SM35-NEXT: } +++; SM35-NEXT: st.param.b64 [func_retval0+0], %rd2; ++ ; SM35-NEXT: ret; ++ %val = tail call i64 @llvm.fshr.i64(i64 %a, i64 %a, i64 66) ++ ret i64 %val ++ } ++- ++-define i32 @funnel_shift_right_32(i32 %a, i32 %b, i32 %c) { ++-; SM20-LABEL: funnel_shift_right_32( ++-; SM20: { ++-; SM20-NEXT: .reg .b32 %r<11>; ++-; SM20-EMPTY: ++-; SM20-NEXT: // %bb.0: ++-; SM20-NEXT: ld.param.u32 %r1, [funnel_shift_right_32_param_0]; ++-; SM20-NEXT: ld.param.u32 %r2, [funnel_shift_right_32_param_2]; ++-; SM20-NEXT: and.b32 %r3, %r2, 31; ++-; SM20-NEXT: ld.param.u32 %r4, [funnel_shift_right_32_param_1]; ++-; SM20-NEXT: shr.u32 %r5, %r4, %r3; ++-; SM20-NEXT: shl.b32 %r6, %r1, 1; ++-; SM20-NEXT: not.b32 %r7, %r2; ++-; SM20-NEXT: and.b32 %r8, %r7, 31; ++-; SM20-NEXT: shl.b32 %r9, %r6, %r8; ++-; SM20-NEXT: or.b32 %r10, %r9, %r5; ++-; SM20-NEXT: st.param.b32 [func_retval0+0], %r10; ++-; SM20-NEXT: ret; ++-; ++-; SM35-LABEL: funnel_shift_right_32( ++-; SM35: { ++-; SM35-NEXT: .reg .b32 %r<5>; ++-; SM35-EMPTY: ++-; SM35-NEXT: // %bb.0: ++-; SM35-NEXT: ld.param.u32 %r1, [funnel_shift_right_32_param_0]; ++-; SM35-NEXT: ld.param.u32 %r2, [funnel_shift_right_32_param_1]; ++-; SM35-NEXT: ld.param.u32 %r3, [funnel_shift_right_32_param_2]; ++-; SM35-NEXT: shf.r.wrap.b32 %r4, %r1, %r2, %r3; ++-; SM35-NEXT: st.param.b32 [func_retval0+0], %r4; ++-; SM35-NEXT: ret; ++- %val = call i32 @llvm.fshr.i32(i32 %a, i32 %b, i32 %c) ++- ret i32 %val ++-} ++- ++-define i32 @funnel_shift_left_32(i32 %a, i32 %b, i32 %c) { ++-; SM20-LABEL: funnel_shift_left_32( ++-; SM20: { ++-; SM20-NEXT: .reg .b32 %r<11>; ++-; SM20-EMPTY: ++-; SM20-NEXT: // %bb.0: ++-; SM20-NEXT: ld.param.u32 %r1, [funnel_shift_left_32_param_0]; ++-; SM20-NEXT: ld.param.u32 %r2, [funnel_shift_left_32_param_2]; ++-; SM20-NEXT: and.b32 %r3, %r2, 31; ++-; SM20-NEXT: shl.b32 %r4, %r1, %r3; ++-; SM20-NEXT: ld.param.u32 %r5, [funnel_shift_left_32_param_1]; ++-; SM20-NEXT: shr.u32 %r6, %r5, 1; ++-; SM20-NEXT: not.b32 %r7, %r2; ++-; SM20-NEXT: and.b32 %r8, %r7, 31; ++-; SM20-NEXT: shr.u32 %r9, %r6, %r8; ++-; SM20-NEXT: or.b32 %r10, %r4, %r9; ++-; SM20-NEXT: st.param.b32 [func_retval0+0], %r10; ++-; SM20-NEXT: ret; ++-; ++-; SM35-LABEL: funnel_shift_left_32( ++-; SM35: { ++-; SM35-NEXT: .reg .b32 %r<5>; ++-; SM35-EMPTY: ++-; SM35-NEXT: // %bb.0: ++-; SM35-NEXT: ld.param.u32 %r1, [funnel_shift_left_32_param_0]; ++-; SM35-NEXT: ld.param.u32 %r2, [funnel_shift_left_32_param_1]; ++-; SM35-NEXT: ld.param.u32 %r3, [funnel_shift_left_32_param_2]; ++-; SM35-NEXT: shf.l.wrap.b32 %r4, %r1, %r2, %r3; ++-; SM35-NEXT: st.param.b32 [func_retval0+0], %r4; ++-; SM35-NEXT: ret; ++- %val = call i32 @llvm.fshl.i32(i32 %a, i32 %b, i32 %c) ++- ret i32 %val ++-} ++- ++-define i64 @funnel_shift_right_64(i64 %a, i64 %b, i64 %c) { ++-; SM20-LABEL: funnel_shift_right_64( ++-; SM20: { ++-; SM20-NEXT: .reg .b32 %r<5>; ++-; SM20-NEXT: .reg .b64 %rd<7>; ++-; SM20-EMPTY: ++-; SM20-NEXT: // %bb.0: ++-; SM20-NEXT: ld.param.u64 %rd1, [funnel_shift_right_64_param_0]; ++-; SM20-NEXT: ld.param.u32 %r1, [funnel_shift_right_64_param_2]; ++-; SM20-NEXT: and.b32 %r2, %r1, 63; ++-; SM20-NEXT: ld.param.u64 %rd2, [funnel_shift_right_64_param_1]; ++-; SM20-NEXT: shr.u64 %rd3, %rd2, %r2; ++-; SM20-NEXT: shl.b64 %rd4, %rd1, 1; ++-; SM20-NEXT: not.b32 %r3, %r1; ++-; SM20-NEXT: and.b32 %r4, %r3, 63; ++-; SM20-NEXT: shl.b64 %rd5, %rd4, %r4; ++-; SM20-NEXT: or.b64 %rd6, %rd5, %rd3; ++-; SM20-NEXT: st.param.b64 [func_retval0+0], %rd6; ++-; SM20-NEXT: ret; ++-; ++-; SM35-LABEL: funnel_shift_right_64( ++-; SM35: { ++-; SM35-NEXT: .reg .b32 %r<5>; ++-; SM35-NEXT: .reg .b64 %rd<7>; ++-; SM35-EMPTY: ++-; SM35-NEXT: // %bb.0: ++-; SM35-NEXT: ld.param.u64 %rd1, [funnel_shift_right_64_param_0]; ++-; SM35-NEXT: ld.param.u32 %r1, [funnel_shift_right_64_param_2]; ++-; SM35-NEXT: and.b32 %r2, %r1, 63; ++-; SM35-NEXT: ld.param.u64 %rd2, [funnel_shift_right_64_param_1]; ++-; SM35-NEXT: shr.u64 %rd3, %rd2, %r2; ++-; SM35-NEXT: shl.b64 %rd4, %rd1, 1; ++-; SM35-NEXT: not.b32 %r3, %r1; ++-; SM35-NEXT: and.b32 %r4, %r3, 63; ++-; SM35-NEXT: shl.b64 %rd5, %rd4, %r4; ++-; SM35-NEXT: or.b64 %rd6, %rd5, %rd3; ++-; SM35-NEXT: st.param.b64 [func_retval0+0], %rd6; ++-; SM35-NEXT: ret; ++- %val = call i64 @llvm.fshr.i64(i64 %a, i64 %b, i64 %c) ++- ret i64 %val ++-} ++- ++-define i64 @funnel_shift_left_64(i64 %a, i64 %b, i64 %c) { ++-; SM20-LABEL: funnel_shift_left_64( ++-; SM20: { ++-; SM20-NEXT: .reg .b32 %r<5>; ++-; SM20-NEXT: .reg .b64 %rd<7>; ++-; SM20-EMPTY: ++-; SM20-NEXT: // %bb.0: ++-; SM20-NEXT: ld.param.u64 %rd1, [funnel_shift_left_64_param_0]; ++-; SM20-NEXT: ld.param.u32 %r1, [funnel_shift_left_64_param_2]; ++-; SM20-NEXT: and.b32 %r2, %r1, 63; ++-; SM20-NEXT: shl.b64 %rd2, %rd1, %r2; ++-; SM20-NEXT: ld.param.u64 %rd3, [funnel_shift_left_64_param_1]; ++-; SM20-NEXT: shr.u64 %rd4, %rd3, 1; ++-; SM20-NEXT: not.b32 %r3, %r1; ++-; SM20-NEXT: and.b32 %r4, %r3, 63; ++-; SM20-NEXT: shr.u64 %rd5, %rd4, %r4; ++-; SM20-NEXT: or.b64 %rd6, %rd2, %rd5; ++-; SM20-NEXT: st.param.b64 [func_retval0+0], %rd6; ++-; SM20-NEXT: ret; ++-; ++-; SM35-LABEL: funnel_shift_left_64( ++-; SM35: { ++-; SM35-NEXT: .reg .b32 %r<5>; ++-; SM35-NEXT: .reg .b64 %rd<7>; ++-; SM35-EMPTY: ++-; SM35-NEXT: // %bb.0: ++-; SM35-NEXT: ld.param.u64 %rd1, [funnel_shift_left_64_param_0]; ++-; SM35-NEXT: ld.param.u32 %r1, [funnel_shift_left_64_param_2]; ++-; SM35-NEXT: and.b32 %r2, %r1, 63; ++-; SM35-NEXT: shl.b64 %rd2, %rd1, %r2; ++-; SM35-NEXT: ld.param.u64 %rd3, [funnel_shift_left_64_param_1]; ++-; SM35-NEXT: shr.u64 %rd4, %rd3, 1; ++-; SM35-NEXT: not.b32 %r3, %r1; ++-; SM35-NEXT: and.b32 %r4, %r3, 63; ++-; SM35-NEXT: shr.u64 %rd5, %rd4, %r4; ++-; SM35-NEXT: or.b64 %rd6, %rd2, %rd5; ++-; SM35-NEXT: st.param.b64 [func_retval0+0], %rd6; ++-; SM35-NEXT: ret; ++- %val = call i64 @llvm.fshl.i64(i64 %a, i64 %b, i64 %c) ++- ret i64 %val ++-} ++- ++diff -ruN --strip-trailing-cr a/llvm/test/DebugInfo/NVPTX/debug-info.ll b/llvm/test/DebugInfo/NVPTX/debug-info.ll ++--- a/llvm/test/DebugInfo/NVPTX/debug-info.ll +++++ b/llvm/test/DebugInfo/NVPTX/debug-info.ll ++@@ -25,10 +25,6 @@ ++ ; CHECK-DAG: .reg .b64 %rd<8>; ++ ; CHECK: .loc [[DEBUG_INFO_CU:[0-9]+]] 5 0 ++ ; CHECK: ld.param.u32 %r{{.+}}, [{{.+}}]; ++-; CHECK: ld.param.u64 %rd{{.+}}, [{{.+}}]; ++-; CHECK: cvta.to.global.u64 %rd{{.+}}, %rd{{.+}}; ++-; CHECK: ld.param.u64 %rd{{.+}}, [{{.+}}]; ++-; CHECK: cvta.to.global.u64 %rd{{.+}}, %rd{{.+}}; ++ ; CHECK: .loc [[BUILTUIN_VARS_H:[0-9]+]] 78 180 ++ ; CHECK: mov.u32 %r{{.+}}, %ctaid.x; ++ ; CHECK: .loc [[BUILTUIN_VARS_H]] 89 180 ++@@ -42,6 +38,10 @@ ++ ; CHECK: .loc [[DEBUG_INFO_CU]] 7 7 ++ ; CHECK: @%p{{.+}} bra [[BB:\$L__.+]]; ++ ; CHECK: ld.param.f32 %f{{.+}}, [{{.+}}]; +++; CHECK: ld.param.u64 %rd{{.+}}, [{{.+}}]; +++; CHECK: cvta.to.global.u64 %rd{{.+}}, %rd{{.+}}; +++; CHECK: ld.param.u64 %rd{{.+}}, [{{.+}}]; +++; CHECK: cvta.to.global.u64 %rd{{.+}}, %rd{{.+}}; ++ ; CHECK: .loc [[DEBUG_INFO_CU]] 8 13 ++ ; CHECK: mul.wide.u32 %rd{{.+}}, %r{{.+}}, 4; ++ ; CHECK: add.s64 %rd{{.+}}, %rd{{.+}}, %rd{{.+}}; ++@@ -2661,22 +2661,22 @@ ++ ; CHECK-NEXT:.b32 4579 // DW_AT_type ++ ; CHECK-NEXT:.b8 25 // Abbrev [25] 0x8aa:0x18 DW_TAG_inlined_subroutine ++ ; CHECK-NEXT:.b32 707 // DW_AT_abstract_origin ++-; CHECK-NEXT:.b64 $L__tmp1 // DW_AT_low_pc ++-; CHECK-NEXT:.b64 $L__tmp2 // DW_AT_high_pc +++; CHECK-NEXT:.b64 $L__tmp0 // DW_AT_low_pc +++; CHECK-NEXT:.b64 $L__tmp1 // DW_AT_high_pc ++ ; CHECK-NEXT:.b8 1 // DW_AT_call_file ++ ; CHECK-NEXT:.b8 6 // DW_AT_call_line ++ ; CHECK-NEXT:.b8 11 // DW_AT_call_column ++ ; CHECK-NEXT:.b8 25 // Abbrev [25] 0x8c2:0x18 DW_TAG_inlined_subroutine ++ ; CHECK-NEXT:.b32 1466 // DW_AT_abstract_origin ++-; CHECK-NEXT:.b64 $L__tmp2 // DW_AT_low_pc ++-; CHECK-NEXT:.b64 $L__tmp3 // DW_AT_high_pc +++; CHECK-NEXT:.b64 $L__tmp1 // DW_AT_low_pc +++; CHECK-NEXT:.b64 $L__tmp2 // DW_AT_high_pc ++ ; CHECK-NEXT:.b8 1 // DW_AT_call_file ++ ; CHECK-NEXT:.b8 6 // DW_AT_call_line ++ ; CHECK-NEXT:.b8 24 // DW_AT_call_column ++ ; CHECK-NEXT:.b8 25 // Abbrev [25] 0x8da:0x18 DW_TAG_inlined_subroutine ++ ; CHECK-NEXT:.b32 2060 // DW_AT_abstract_origin ++-; CHECK-NEXT:.b64 $L__tmp3 // DW_AT_low_pc ++-; CHECK-NEXT:.b64 $L__tmp4 // DW_AT_high_pc +++; CHECK-NEXT:.b64 $L__tmp2 // DW_AT_low_pc +++; CHECK-NEXT:.b64 $L__tmp3 // DW_AT_high_pc ++ ; CHECK-NEXT:.b8 1 // DW_AT_call_file ++ ; CHECK-NEXT:.b8 6 // DW_AT_call_line ++ ; CHECK-NEXT:.b8 37 // DW_AT_call_column diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl -index 726a367..abe15ef 100644 +index abe15ef..af35fe7 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" -- LLVM_COMMIT = "8b4b7d28f7c344c728a9812aa99d9ad24edb40a2" -- LLVM_SHA256 = "f585b8955f66849929bbe0b657ea7ff5fe8f49880066a58b2a744065ddd4a521" -+ LLVM_COMMIT = "df0864e761107b07e38f5503e0cbee0cebb4c5e8" -+ LLVM_SHA256 = "5bfcb7306d9d40f420862ace1f7ad3f01979facfb16ffd1fc80b6d91e92019fa" +- LLVM_COMMIT = "df0864e761107b07e38f5503e0cbee0cebb4c5e8" +- LLVM_SHA256 = "5bfcb7306d9d40f420862ace1f7ad3f01979facfb16ffd1fc80b6d91e92019fa" ++ LLVM_COMMIT = "9830156f623c56062bf6df1b4c4b4bd8ab5bd57c" ++ LLVM_SHA256 = "85bb9a61cfdaf0d3386890dc7b4bbaa17eecf4b70b60c314307f2ca3919b9035" tf_http_archive( name = name, diff --git a/third_party/xla/third_party/shardy/workspace.bzl b/third_party/xla/third_party/shardy/workspace.bzl index b62c918736eb7e..4f6a0785270667 100644 --- a/third_party/xla/third_party/shardy/workspace.bzl +++ b/third_party/xla/third_party/shardy/workspace.bzl @@ -3,8 +3,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): - SHARDY_COMMIT = "f9efe2966f00f8e7da8f7af3f8c8b3255cc158b8" - SHARDY_SHA256 = "6ca4c5f2de2102eca2a78ab64a443b2d327fd7b0ceb8c633a67cd1a2a316a2db" + SHARDY_COMMIT = "22e68fa19cfb2d28434a75d4d20d0efc182b166a" + SHARDY_SHA256 = "2b47b0ee994feca2bd782e20aca7d709e29bc870c2ac435aca967f7664c9f949" tf_http_archive( name = "shardy", From 8c9ae35d2302eebcd3bc8e793e054bf78bfe4199 Mon Sep 17 00:00:00 2001 From: Eric Salo Date: Wed, 25 Sep 2024 13:29:18 -0700 Subject: [PATCH 278/483] cleanup: remove api_version from BUILD files PiperOrigin-RevId: 678826174 --- tensorflow/compiler/mlir/lite/BUILD | 3 --- tensorflow/compiler/mlir/lite/metrics/BUILD | 1 - tensorflow/compiler/mlir/quantization/tensorflow/BUILD | 2 -- .../compiler/mlir/quantization/tensorflow/calibrator/BUILD | 1 - tensorflow/compiler/tf2tensorrt/BUILD | 1 - tensorflow/core/BUILD | 1 - tensorflow/core/debug/BUILD | 2 -- tensorflow/core/framework/BUILD | 1 - tensorflow/core/profiler/BUILD | 2 -- tensorflow/core/profiler/protobuf/BUILD | 4 ---- tensorflow/core/tfrt/graph_executor/BUILD | 1 - tensorflow/core/tpu/kernels/BUILD | 1 - tensorflow/core/util/autotune_maps/BUILD | 1 - tensorflow/distribute/experimental/rpc/proto/BUILD | 1 - tensorflow/python/BUILD | 1 - tensorflow/python/framework/BUILD | 1 - tensorflow/python/kernel_tests/proto/BUILD | 1 - tensorflow/python/tpu/BUILD | 1 - tensorflow/python/training/BUILD | 1 - 19 files changed, 27 deletions(-) diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD index 707225f8a67556..4008cf266f9a22 100644 --- a/tensorflow/compiler/mlir/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/BUILD @@ -1882,21 +1882,18 @@ tf_proto_library( # copybara:uncomment_begin(google-only) # py_proto_library( # name = "types_py_proto", -# api_version = 2, # visibility = ["//visibility:public"], # deps = [":types_proto"], # ) # # py_proto_library( # name = "model_flags_py_proto", -# api_version = 2, # visibility = ["//visibility:public"], # deps = [":model_flags_proto"], # ) # # py_proto_library( # name = "converter_flags_py_proto", -# api_version = 2, # visibility = ["//visibility:public"], # deps = [":converter_flags_proto"], # ) diff --git a/tensorflow/compiler/mlir/lite/metrics/BUILD b/tensorflow/compiler/mlir/lite/metrics/BUILD index 2deb787ee627d2..1325962d8c0385 100644 --- a/tensorflow/compiler/mlir/lite/metrics/BUILD +++ b/tensorflow/compiler/mlir/lite/metrics/BUILD @@ -93,7 +93,6 @@ tf_proto_library( # copybara:uncomment_begin(google-only) # py_proto_library( # name = "converter_error_data_proto_py", -# api_version = 2, # visibility = [ # "//visibility:public", # ], diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/BUILD b/tensorflow/compiler/mlir/quantization/tensorflow/BUILD index 9c86ba1366869f..10b545777971f6 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/quantization/tensorflow/BUILD @@ -555,7 +555,6 @@ tf_proto_library( # copybara:uncomment_begin(google-only) # py_proto_library( # name = "quantization_options_py_pb2", -# api_version = 2, # visibility = ["//visibility:public"], # deps = [":quantization_options_proto"], # ) @@ -582,7 +581,6 @@ tf_proto_library( # copybara:uncomment_begin(google-only) # py_proto_library( # name = "exported_model_py_pb2", -# api_version = 2, # deps = [":exported_model_proto"], # ) # copybara:uncomment_end diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/BUILD b/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/BUILD index 7931e3cd51e9db..b1f7caa3b0b0ec 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/BUILD +++ b/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/BUILD @@ -110,7 +110,6 @@ tf_proto_library( # copybara:uncomment_begin(google-only) # py_proto_library( # name = "calibration_statistics_py_pb2", -# api_version = 2, # deps = [ # ":calibration_statistics_proto", # ], diff --git a/tensorflow/compiler/tf2tensorrt/BUILD b/tensorflow/compiler/tf2tensorrt/BUILD index 86356d89f63b4e..fb88453879d5cf 100644 --- a/tensorflow/compiler/tf2tensorrt/BUILD +++ b/tensorflow/compiler/tf2tensorrt/BUILD @@ -1112,7 +1112,6 @@ pybind_extension( # py_proto_library( # name = "trt_engine_instance_proto_py_pb2", # has_services = 0, -# api_version = 2, # deps = [":trt_engine_instance_proto"], # ) # copybara:uncomment_end diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 1bc6f7ba2ece5f..f262e1dec8d02c 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -2032,7 +2032,6 @@ transitive_hdrs( # py_proto_library( # name = "protos_all_py_pb2", # has_services = 0, -# api_version = 2, # visibility = ["//visibility:public"], # deps = [":protos_all"], # ) diff --git a/tensorflow/core/debug/BUILD b/tensorflow/core/debug/BUILD index cfd9fb86d55d6f..38dc34ed2588a7 100644 --- a/tensorflow/core/debug/BUILD +++ b/tensorflow/core/debug/BUILD @@ -287,13 +287,11 @@ tf_cc_binary( # py_proto_library( # name = "debug_service_py_pb2", # has_services = 1, -# api_version = 2, # deps = [":debug_service_proto"], # ) # # py_proto_library( # name = "debugger_event_metadata_py_pb2", -# api_version = 2, # deps = [":debugger_event_metadata_proto"], # ) # copybara:uncomment_end diff --git a/tensorflow/core/framework/BUILD b/tensorflow/core/framework/BUILD index ee5080c90dd9c8..83176eb487262d 100644 --- a/tensorflow/core/framework/BUILD +++ b/tensorflow/core/framework/BUILD @@ -1807,7 +1807,6 @@ tf_proto_library( # py_proto_library( # name = "function_proto_py_pb2", # has_services = 0, -# api_version = 2, # deps = [ # ":function_proto", # ], diff --git a/tensorflow/core/profiler/BUILD b/tensorflow/core/profiler/BUILD index c5477810bf0fde..45b60e7f264704 100644 --- a/tensorflow/core/profiler/BUILD +++ b/tensorflow/core/profiler/BUILD @@ -179,14 +179,12 @@ filegroup( # py_proto_library( # name = "profiler_analysis_proto_py_pb2", # has_services = 1, -# api_version = 2, # visibility = ["//visibility:public"], # deps = [":profiler_analysis_proto"], # ) # # py_proto_library( # name = "protos_all_py_pb2", -# api_version = 2, # visibility = [":friends"], # deps = [":protos_all"], # ) diff --git a/tensorflow/core/profiler/protobuf/BUILD b/tensorflow/core/profiler/protobuf/BUILD index 7a79e4a8ba7939..13cce56d193865 100644 --- a/tensorflow/core/profiler/protobuf/BUILD +++ b/tensorflow/core/profiler/protobuf/BUILD @@ -200,28 +200,24 @@ tf_proto_library( # copybara:uncomment_begin(google-only) # py_proto_library( # name = "xplane_py_pb2", -# api_version = 2, # visibility = [":friends"], # deps = [":xplane_proto"], # ) # # py_proto_library( # name = "memory_viewer_preprocess_py_pb2", -# api_version = 2, # visibility = [":memory_viewer_friends"], # deps = [":memory_viewer_preprocess_proto"], # ) # # py_proto_library( # name = "op_profile_py_pb2", -# api_version = 2, # visibility = ["//visibility:public"], # deps = [":op_profile_proto"], # ) # # py_proto_library( # name = "op_metrics_py_pb2", -# api_version = 2, # visibility = ["//visibility:public"], # deps = [":op_metrics_proto"], # ) diff --git a/tensorflow/core/tfrt/graph_executor/BUILD b/tensorflow/core/tfrt/graph_executor/BUILD index 61d869fe2a767b..7c08bae7645d29 100644 --- a/tensorflow/core/tfrt/graph_executor/BUILD +++ b/tensorflow/core/tfrt/graph_executor/BUILD @@ -221,7 +221,6 @@ tf_proto_library( # copybara:uncomment_begin(google-only) # py_proto_library( # name = "config_proto_py_pb2", -# api_version = 2, # visibility = ["//visibility:public"], # deps = [":config_proto"], # ) diff --git a/tensorflow/core/tpu/kernels/BUILD b/tensorflow/core/tpu/kernels/BUILD index 5c1dc5da889ede..23da9e8116a074 100644 --- a/tensorflow/core/tpu/kernels/BUILD +++ b/tensorflow/core/tpu/kernels/BUILD @@ -1551,7 +1551,6 @@ tf_proto_library( # copybara:uncomment_begin(google-only) # py_proto_library( # name = "sparse_core_layout_py_pb2", -# api_version = 2, # deps = [":sparse_core_layout_proto"], # ) # copybara:uncomment_end diff --git a/tensorflow/core/util/autotune_maps/BUILD b/tensorflow/core/util/autotune_maps/BUILD index e428e14e170b40..f2d08b34f35a5e 100644 --- a/tensorflow/core/util/autotune_maps/BUILD +++ b/tensorflow/core/util/autotune_maps/BUILD @@ -148,7 +148,6 @@ tf_proto_library( # py_proto_library( # name = "autotune_map_py_pb2", # has_services = 0, -# api_version = 2, # visibility = ["//waymo/ml/deploy/system/autotuning:__subpackages__"], # deps = [":autotune_map_proto"], # ) diff --git a/tensorflow/distribute/experimental/rpc/proto/BUILD b/tensorflow/distribute/experimental/rpc/proto/BUILD index 097b4b38797619..6acb10c5140d30 100644 --- a/tensorflow/distribute/experimental/rpc/proto/BUILD +++ b/tensorflow/distribute/experimental/rpc/proto/BUILD @@ -28,7 +28,6 @@ tf_proto_library( # copybara:uncomment_begin(google-only) # py_proto_library( # name = "tf_rpc_service_py_pb2", -# api_version = 2, # deps = [":tf_rpc_service_proto"], # ) # copybara:uncomment_end diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 4331c4a77d8a21..3942f1712faaad 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -1405,7 +1405,6 @@ py_strict_library( # py_proto_library( # name = "protos_all_py_pb2", # has_services = 0, -# api_version = 2, # deps = [":protos_all"], # ) # diff --git a/tensorflow/python/framework/BUILD b/tensorflow/python/framework/BUILD index f08a3e1347ec99..29b3cc611f3f2d 100644 --- a/tensorflow/python/framework/BUILD +++ b/tensorflow/python/framework/BUILD @@ -3315,7 +3315,6 @@ tf_python_pybind_extension( # py_proto_library( # name = "cpp_shape_inference_proto_py_pb2", # has_services = 0, -# api_version = 2, # deps = [":cpp_shape_inference_proto"], # ) # copybara:uncomment_end diff --git a/tensorflow/python/kernel_tests/proto/BUILD b/tensorflow/python/kernel_tests/proto/BUILD index 10bc4c327c5a59..e50955770c1def 100644 --- a/tensorflow/python/kernel_tests/proto/BUILD +++ b/tensorflow/python/kernel_tests/proto/BUILD @@ -155,7 +155,6 @@ tf_py_strict_test( # copybara:uncomment_begin(google-only) # py_proto_library( # name = "test_example_proto_py", -# api_version = 2, # deps = [":test_example_proto"], # ) # copybara:uncomment_end diff --git a/tensorflow/python/tpu/BUILD b/tensorflow/python/tpu/BUILD index 6e690be7feb65a..d7c88da50a9959 100644 --- a/tensorflow/python/tpu/BUILD +++ b/tensorflow/python/tpu/BUILD @@ -1013,7 +1013,6 @@ tf_proto_library( # copybara:uncomment_begin(google-only) # py_proto_library( # name = "tensor_tracer_py_pb2", -# api_version = 2, # visibility = ["//visibility:public"], # deps = [":tensor_tracer_proto"], # ) diff --git a/tensorflow/python/training/BUILD b/tensorflow/python/training/BUILD index a32f63e995d56f..ce75ffbc860031 100644 --- a/tensorflow/python/training/BUILD +++ b/tensorflow/python/training/BUILD @@ -781,7 +781,6 @@ tf_proto_library( # py_proto_library( # name = "checkpoint_state_py_pb2", # testonly = 0, -# api_version = 2, # deps = [":checkpoint_state"], # ) # copybara:uncomment_end From 6868479e8b39075323716fda8ecb6f1325b73406 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 25 Sep 2024 13:47:31 -0700 Subject: [PATCH 279/483] Add collective permute and collective broadcast tests with two GPUs. PiperOrigin-RevId: 678833130 --- .../xla/xla/tests/collective_ops_test.cc | 70 +++++++++++++++++++ 1 file changed, 70 insertions(+) diff --git a/third_party/xla/xla/tests/collective_ops_test.cc b/third_party/xla/xla/tests/collective_ops_test.cc index fcecf8f4a66cef..89245736389bda 100644 --- a/third_party/xla/xla/tests/collective_ops_test.cc +++ b/third_party/xla/xla/tests/collective_ops_test.cc @@ -652,6 +652,44 @@ XLA_TEST_F(CollectiveOpsTest, ReplicaId) { } } +XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(CollectiveBroadcast_TwoGPUs)) { + const char* const kModuleStr = R"( + HloModule test + + collective_broadcast { + p0 = u32[2] parameter(0) + ROOT result = u32[2] collective-broadcast(p0), replica_groups={{1, 0}} + } + + ENTRY test_computation { + replica = u32[] replica-id() + ten = u32[] constant(10) + sum = u32[] add(replica, ten) + p = u32[2] broadcast(sum), dimensions={} + cb = ((u32[2]), u32[2]) async-start(u32[2] %p), calls=collective_broadcast + ROOT res = u32[2] async-done(cb), calls=collective_broadcast + } + )"; + const int64_t kNumReplicas = 2; + SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas); + + HloModuleConfig config = + GetModuleConfigForTest(/*replica_count=*/kNumReplicas); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr, config)); + + TF_ASSERT_OK_AND_ASSIGN( + std::vector results, + ExecuteReplicated(std::move(module), absl::Span{}, + kNumReplicas, + /*use_threads=*/true)); + ASSERT_EQ(results.size(), kNumReplicas); + EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR1({11, 11}), + results[0])); + EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR1({11, 11}), + results[1])); +} + XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(CollectiveBroadcast_Simple)) { const char* const kModuleStr = R"( HloModule test @@ -694,6 +732,38 @@ XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(CollectiveBroadcast_Simple)) { results[3])); } +XLA_TEST_F(CollectiveOpsTest, CollectivePermute_TwoGPUs) { + const char* const kModuleStr = R"( + HloModule test + ENTRY test_computation { + replica = u32[] replica-id() + ten = u32[] constant(10) + sum = u32[] add(replica, ten) + p = u32[2] broadcast(sum), dimensions={} + permute = u32[2] collective-permute(p), source_target_pairs={{1,0}, {0,1}} + ROOT copy = u32[2] copy(permute) + } + )"; + const int64_t kNumReplicas = 2; + SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas); + + HloModuleConfig config = + GetModuleConfigForTest(/*replica_count=*/kNumReplicas); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr, config)); + + TF_ASSERT_OK_AND_ASSIGN( + std::vector results, + ExecuteReplicated(std::move(module), absl::Span{}, + kNumReplicas, + /*use_threads=*/true, /*run_hlo_passes=*/true)); + ASSERT_EQ(results.size(), kNumReplicas); + EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR1({11, 11}), + results[0])); + EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR1({10, 10}), + results[1])); +} + XLA_TEST_F(CollectiveOpsTest, CollectivePermute_Simple) { const char* const kModuleStr = R"( HloModule test From 082f2451e09209bc0bdf8540723f6e0f66bb5fe2 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 25 Sep 2024 13:51:28 -0700 Subject: [PATCH 280/483] Reduce the number of comments so that transformations play nicely with dependencies. PiperOrigin-RevId: 678834709 --- tensorflow/lite/schema/BUILD | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tensorflow/lite/schema/BUILD b/tensorflow/lite/schema/BUILD index fbd167c1ac5625..a3a922051fb1b9 100644 --- a/tensorflow/lite/schema/BUILD +++ b/tensorflow/lite/schema/BUILD @@ -111,9 +111,10 @@ exports_files([ # srcs = ["//tensorflow/compiler/mlir/lite/schema:schema.fbs"], # compatible_with = get_compatible_with_portable(), # ) -# copybara:uncomment_end_and_comment_begin +# copybara:uncomment_end(google-only) + cc_library( - name = "schema_fbs", + name = "schema_fbs", # copybara:comment_replace name = "schema_fbs_for_oss", hdrs = [ ":schema_generated.h", "//tensorflow/compiler/mlir/lite/schema:schema_generated.h", @@ -123,7 +124,6 @@ cc_library( "@flatbuffers//:runtime_cc", ], ) -# copybara:comment_end # Generic schema for flatbuffer converter (but with mutable makes bigger). flatbuffer_cc_library( From 7cb5c59d54468e3352bc53b234b7c50f4bfce8e4 Mon Sep 17 00:00:00 2001 From: Farzin Houshmand Date: Wed, 25 Sep 2024 14:02:21 -0700 Subject: [PATCH 281/483] Add a method to unfuse a given instruction from a fusion computation. PiperOrigin-RevId: 678839106 --- third_party/xla/xla/hlo/ir/hlo_instruction.cc | 49 +++++++++ third_party/xla/xla/hlo/ir/hlo_instruction.h | 7 ++ .../xla/xla/service/hlo_instruction_test.cc | 99 +++++++++++++++++++ 3 files changed, 155 insertions(+) diff --git a/third_party/xla/xla/hlo/ir/hlo_instruction.cc b/third_party/xla/xla/hlo/ir/hlo_instruction.cc index ffcfafbcc663d7..4b200a2c831c41 100644 --- a/third_party/xla/xla/hlo/ir/hlo_instruction.cc +++ b/third_party/xla/xla/hlo/ir/hlo_instruction.cc @@ -3253,6 +3253,55 @@ absl::Status HloInstruction::Defuse() { return module->RemoveEmbeddedComputation(fused_computation); } +absl::StatusOr HloInstruction::UnfuseInstruction( + HloInstruction* instruction) { + CHECK_EQ(opcode(), HloOpcode::kFusion); + + std::vector new_operands; + // Gather the operands that need to be extracted from the fusion. + for (int64_t operand_num = 0; operand_num < instruction->operand_count(); + ++operand_num) { + HloInstruction* operand = instruction->mutable_operand(operand_num); + if (operand->opcode() == HloOpcode::kParameter) { + // If the operand is a parameter of the fusion, we need to extract it. + HloInstruction* extracted_operand = + mutable_operand(operand->parameter_number()); + new_operands.push_back(extracted_operand); + } else if (operand->opcode() == HloOpcode::kConstant) { + HloInstruction* cloned_constant = AddInstruction(operand->Clone()); + new_operands.push_back(cloned_constant); + } else if (operand->opcode() == HloOpcode::kBroadcast && + operand->operand(0)->opcode() == HloOpcode::kConstant) { + HloInstruction* cloned_constant = + AddInstruction(operand->operand(0)->Clone()); + new_operands.push_back(AddInstruction( + operand->CloneWithNewOperands(operand->shape(), {cloned_constant}))); + } else { + return InvalidArgument( + "Unsupported operand type for unfusing: %s. Currently only " + "parameters and constants are supported.", + operand->ToString()); + } + } + + // Clone the instruction to be unfused. + HloInstruction* unfused_instruction = AddInstruction( + instruction->CloneWithNewOperands(instruction->shape(), new_operands)); + + // Add the unfused instruction as a parameter to the fusion instruction. + HloComputation* fusion_computation = fused_instructions_computation(); + + HloInstruction* new_parameter = AddFusionOperand(unfused_instruction); + // Replace the instruction in the fusion computation with the new parameter. + TF_RETURN_IF_ERROR(instruction->ReplaceAllUsesWith(new_parameter)); + + // Remove the original instruction from the fusion computation. + TF_RETURN_IF_ERROR( + fusion_computation->RemoveInstructionAndUnusedOperands(instruction)); + + return unfused_instruction; +} + absl::Status HloInstruction::ReplaceUsesWith( absl::Span users, HloInstruction* new_producer) { TF_RET_CHECK( diff --git a/third_party/xla/xla/hlo/ir/hlo_instruction.h b/third_party/xla/xla/hlo/ir/hlo_instruction.h index c1b64c06c7633c..3ef42bfc41adc6 100644 --- a/third_party/xla/xla/hlo/ir/hlo_instruction.h +++ b/third_party/xla/xla/hlo/ir/hlo_instruction.h @@ -1685,6 +1685,13 @@ class HloInstruction { // Decomposes fusion back to individual parts. absl::Status Defuse(); + // Unfuses the given instruction from its fusion computation. If the given + // instruction is not fused, this is a no-op and returns nullptr. Returns a + // pointer to the newly unfused instruction if successful. Currently, fused + // instructions with parameter or constant operands are supported. + absl::StatusOr UnfuseInstruction( + HloInstruction* instruction); + // Replaces all uses of this instruction with the new producer. If // new_producer is a user of this instruction then new_producer remains a use // of this instruction to avoid introducing cycles into the graph. diff --git a/third_party/xla/xla/service/hlo_instruction_test.cc b/third_party/xla/xla/service/hlo_instruction_test.cc index 7709bda6032e7f..1db4a17c8fc39f 100644 --- a/third_party/xla/xla/service/hlo_instruction_test.cc +++ b/third_party/xla/xla/service/hlo_instruction_test.cc @@ -3058,5 +3058,104 @@ TEST_F(HloInstructionTest, m::Add(m::Parameter(0), m::Parameter(1))))); } +TEST_F(HloInstructionTest, UnfuseInstruction) { + const std::string& hlo_string = R"( + HloModule mof + fusion_comp { + param0 = f32[10]{0} parameter(0) + param1 = f32[10]{0} parameter(1) + add = f32[10]{0} add(param0, param1) + ROOT res = (f32[10]{0}, f32[10]{0}) tuple(param1, add) + } + + ENTRY main { + p0 = f32[10]{0} parameter(0) + p1 = f32[10]{0} parameter(1) + fusion.1 = (f32[10]{0}, f32[10]{0}) fusion(p0, p1), kind=kLoop, calls=fusion_comp + gte0 = f32[10]{0} get-tuple-element(fusion.1), index=0 + gte1 = f32[10]{0} get-tuple-element(fusion.1), index=1 + ROOT res = (f32[10]{0}, f32[10]{0}) tuple(gte0, gte1) + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + HloInstruction* fusion = FindInstruction(module.get(), "fusion.1"); + HloInstruction* add = fusion->fused_instructions_computation() + ->root_instruction() + ->mutable_operand(1); + TF_ASSERT_OK_AND_ASSIGN(auto unfused, fusion->UnfuseInstruction(add)); + EXPECT_THAT(unfused, GmockMatch(m::Add(m::Parameter(0), m::Parameter(1)))); +} + +TEST_F(HloInstructionTest, UnfuseInstruction2) { + const std::string& hlo_string = R"( + HloModule mof + fusion_comp { + param0 = f32[10]{0} parameter(0) + param1 = f32[10]{0} parameter(1) + add = f32[10]{0} add(param0, param1) + add2 = f32[10]{0} add(add, param1) + ROOT res = (f32[10]{0}, f32[10]{0}) tuple(param1, add2) + } + + ENTRY main { + p0 = f32[10]{0} parameter(0) + p1 = f32[10]{0} parameter(1) + fusion.1 = (f32[10]{0}, f32[10]{0}) fusion(p0, p1), kind=kLoop, calls=fusion_comp + gte0 = f32[10]{0} get-tuple-element(fusion.1), index=0 + gte1 = f32[10]{0} get-tuple-element(fusion.1), index=1 + ROOT res = (f32[10]{0}, f32[10]{0}) tuple(gte0, gte1) + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + HloInstruction* fusion = FindInstruction(module.get(), "fusion.1"); + HloInstruction* add2 = fusion->fused_instructions_computation() + ->root_instruction() + ->mutable_operand(1); + HloInstruction* add = add2->mutable_operand(0); + + // add2 is not unfusable since it has non-const non-parameter operands. + EXPECT_FALSE(fusion->UnfuseInstruction(add2).ok()); + + TF_ASSERT_OK_AND_ASSIGN(auto unfused, fusion->UnfuseInstruction(add)); + EXPECT_THAT(unfused, GmockMatch(m::Add(m::Parameter(0), m::Parameter(1)))); +} + +TEST_F(HloInstructionTest, UnfuseInstructionWithConstantOperand) { + const std::string& hlo_string = R"( + HloModule mof + fusion_comp { + param0 = f32[10]{0} parameter(0) + param1 = f32[10]{0} parameter(1) + const = f32[] constant(1.0) + broadcast = f32[10]{0} broadcast(const), dimensions={} + add = f32[10]{0} add(param0, broadcast) + ROOT res = (f32[10]{0}, f32[10]{0}) tuple(param1, add) + } + + ENTRY main { + p0 = f32[10]{0} parameter(0) + p1 = f32[10]{0} parameter(1) + fusion.1 = (f32[10]{0}, f32[10]{0}) fusion(p0, p1), kind=kLoop, calls=fusion_comp + gte0 = f32[10]{0} get-tuple-element(fusion.1), index=0 + gte1 = f32[10]{0} get-tuple-element(fusion.1), index=1 + ROOT res = (f32[10]{0}, f32[10]{0}) tuple(gte0, gte1) + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + HloInstruction* fusion = FindInstruction(module.get(), "fusion.1"); + HloInstruction* add = fusion->fused_instructions_computation() + ->root_instruction() + ->mutable_operand(1); + TF_ASSERT_OK_AND_ASSIGN(auto unfused, fusion->UnfuseInstruction(add)); + EXPECT_THAT(unfused, + GmockMatch(m::Add(m::Parameter(0), m::Broadcast(m::Constant())))); +} + } // namespace } // namespace xla From e41bf850299b093db9734f3f81fccf3ab1711948 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 25 Sep 2024 14:04:12 -0700 Subject: [PATCH 282/483] [XLA:Python] Use nanobind::hash instead of our own home-grown version. Cleanup only; no functional changes intended. PiperOrigin-RevId: 678839983 --- third_party/xla/xla/python/jax_jit.h | 2 +- third_party/xla/xla/python/nb_helpers.cc | 8 -------- third_party/xla/xla/python/nb_helpers.h | 4 ---- third_party/xla/xla/python/pjit.cc | 2 +- third_party/xla/xla/python/py_device_list.cc | 2 +- third_party/xla/xla/python/sharding.cc | 2 +- third_party/xla/xla/python/types.cc | 2 +- third_party/xla/xla/python/weakref_lru_cache.cc | 8 +++----- 8 files changed, 8 insertions(+), 22 deletions(-) diff --git a/third_party/xla/xla/python/jax_jit.h b/third_party/xla/xla/python/jax_jit.h index fa0fc2b78e89a0..79552702765061 100644 --- a/third_party/xla/xla/python/jax_jit.h +++ b/third_party/xla/xla/python/jax_jit.h @@ -134,7 +134,7 @@ H AbslHashValue(H h, const ArgumentSignature& s) { const auto& static_arg = s.static_args[i]; Py_hash_t hash; try { - hash = xla::nb_hash(static_arg); + hash = nanobind::hash(static_arg); } catch (const nanobind::python_error& e) { if (!e.matches(PyExc_TypeError)) throw; throw std::invalid_argument(absl::StrCat( diff --git a/third_party/xla/xla/python/nb_helpers.cc b/third_party/xla/xla/python/nb_helpers.cc index 80e3a6ee6d11f6..6a241ca79cf6ff 100644 --- a/third_party/xla/xla/python/nb_helpers.cc +++ b/third_party/xla/xla/python/nb_helpers.cc @@ -23,14 +23,6 @@ namespace nb = nanobind; namespace xla { -Py_hash_t nb_hash(nb::handle o) { - Py_hash_t h = PyObject_Hash(o.ptr()); - if (h == -1) { - throw nb::python_error(); - } - return h; -} - bool nb_isinstance(nanobind::handle inst, nanobind::handle cls) { int ret = PyObject_IsInstance(inst.ptr(), cls.ptr()); if (ret == -1) { diff --git a/third_party/xla/xla/python/nb_helpers.h b/third_party/xla/xla/python/nb_helpers.h index 845adb8692cf5e..c8d69acaa7bdf5 100644 --- a/third_party/xla/xla/python/nb_helpers.h +++ b/third_party/xla/xla/python/nb_helpers.h @@ -23,10 +23,6 @@ limitations under the License. namespace xla { -// Calls Python hash() on an object. -// TODO(phawkins): consider upstreaming this to nanobind. -Py_hash_t nb_hash(nanobind::handle o); - // Calls Python isinstance(inst, cls). // TODO(phawkins): consider upstreaming this to nanobind. bool nb_isinstance(nanobind::handle inst, nanobind::handle cls); diff --git a/third_party/xla/xla/python/pjit.cc b/third_party/xla/xla/python/pjit.cc index c14556d1b58651..a64ab9808459b0 100644 --- a/third_party/xla/xla/python/pjit.cc +++ b/third_party/xla/xla/python/pjit.cc @@ -164,7 +164,7 @@ class PjitFunctionCache { h = H::combine(std::move(h), key.function.ptr()); Py_hash_t hash; try { - hash = xla::nb_hash(key.global_cache_key); + hash = nb::hash(key.global_cache_key); } catch (const nanobind::python_error& e) { if (!e.matches(PyExc_TypeError)) throw; throw std::invalid_argument(absl::StrCat( diff --git a/third_party/xla/xla/python/py_device_list.cc b/third_party/xla/xla/python/py_device_list.cc index 22d701a5fb361f..a0ea40ce1efb81 100644 --- a/third_party/xla/xla/python/py_device_list.cc +++ b/third_party/xla/xla/python/py_device_list.cc @@ -104,7 +104,7 @@ int64_t PyDeviceList::Hash() { hash_ = absl::HashOf(std::get<0>(device_list_)); break; case 1: - hash_ = xla::nb_hash(std::get<1>(device_list_)); + hash_ = nb::hash(std::get<1>(device_list_)); break; default: throw nb::value_error("Unrecognized DeviceList type"); diff --git a/third_party/xla/xla/python/sharding.cc b/third_party/xla/xla/python/sharding.cc index cdad90eb430794..2c3d70a465d63b 100644 --- a/third_party/xla/xla/python/sharding.cc +++ b/third_party/xla/xla/python/sharding.cc @@ -128,7 +128,7 @@ size_t ShardingHash(nb::handle sharding) { return absl::Hash()(single_device_sharding->device().ptr()); } - return xla::nb_hash(sharding); + return nb::hash(sharding); } bool ShardingEqual(nb::handle a, nb::handle b) { diff --git a/third_party/xla/xla/python/types.cc b/third_party/xla/xla/python/types.cc index eaad6db5f16667..4a1de389cd5b5d 100644 --- a/third_party/xla/xla/python/types.cc +++ b/third_party/xla/xla/python/types.cc @@ -133,7 +133,7 @@ absl::StatusOr DtypeToPrimitiveType(const nb_dtype& np_type) { } }; struct DtypeHash { - ssize_t operator()(const nb_dtype& key) const { return nb_hash(key); } + ssize_t operator()(const nb_dtype& key) const { return nb::hash(key); } }; static auto* custom_dtype_map = []() { const CustomDtypes& custom_dtypes = GetCustomDtypes(); diff --git a/third_party/xla/xla/python/weakref_lru_cache.cc b/third_party/xla/xla/python/weakref_lru_cache.cc index 1767e1dabb9cb1..2c2e5bfc222e2a 100644 --- a/third_party/xla/xla/python/weakref_lru_cache.cc +++ b/third_party/xla/xla/python/weakref_lru_cache.cc @@ -51,8 +51,7 @@ class HashablePyDictValue { template friend H AbslHashValue(H h, const HashablePyDictValue& value) { auto kv = *value.iter_; - return H::combine(std::move(h), xla::nb_hash(kv.first), - xla::nb_hash(kv.second)); + return H::combine(std::move(h), nb::hash(kv.first), nb::hash(kv.second)); } explicit HashablePyDictValue(const Iter& iter) : iter_(iter) {} @@ -93,8 +92,7 @@ class WeakrefLRUCache : public std::enable_shared_from_this { template friend H AbslHashValue(H h, const Key& key) { - h = H::combine(std::move(h), xla::nb_hash(key.context), - xla::nb_hash(key.args)); + h = H::combine(std::move(h), nb::hash(key.context), nb::hash(key.args)); h = H::combine_unordered(std::move(h), HashablePyDictIter(key.kwargs.begin()), HashablePyDictIter(key.kwargs.end())); @@ -192,7 +190,7 @@ class WeakrefLRUCache : public std::enable_shared_from_this { nb::kwargs kwargs) ABSL_NO_THREAD_SAFETY_ANALYSIS { nb::object context = cache_context_fn_(); std::shared_ptr cache_ptr = GetCache(UnboundWeakrefCacheEntry{ - weakref_key, this, static_cast(xla::nb_hash(weakref_key))}); + weakref_key, this, static_cast(nb::hash(weakref_key))}); Cache& cache = *cache_ptr; ++total_queries_; From 7a0737aab19dce0bab5408d3370c1bed8cb885fc Mon Sep 17 00:00:00 2001 From: Kanglan Tang Date: Wed, 25 Sep 2024 14:25:30 -0700 Subject: [PATCH 283/483] Modify numpy upgrade info in the release notes PiperOrigin-RevId: 678847763 --- RELEASE.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/RELEASE.md b/RELEASE.md index a156363a53d523..9820f420589ac5 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -20,10 +20,6 @@ * * -* TensorFlow now supports and is compiled with NumPy 2.0 by default. - Compatibility with NumPy 1.26 will be maintained until 2025, aligning with - community standard deprecation timeline [here](https://scientific-python.org/specs/spec-0000/). - ### Bug Fixes and Other Changes * @@ -75,6 +71,10 @@ This release contains contributions from many people at Google, as well as: * TensorRT support is disabled in CUDA builds for code health improvement. +* TensorFlow now supports and is compiled with NumPy 2.0 by default. Please see the [NumPy 2 release notes](https://numpy.org/doc/stable/release/2.0.0-notes.html) and the [NumPy 2 migration guide](https://numpy.org/devdocs/numpy_2_0_migration_guide.html#numpy-2-migration-guide). + * Note that NumPy's type promotion rules have been changed(See [NEP 50](https://numpy.org/neps/nep-0050-scalar-promotion.html#nep50)for details). This may change the precision at which computations happen, leading either to type errors or to numerical changes to results. + * Tensorflow will continue to support NumPy 1.26 until 2025, aligning with community standard deprecation timeline [here](https://scientific-python.org/specs/spec-0000/). + * Hermetic CUDA support is added. Hermetic CUDA uses a specific downloadable version of CUDA instead of the user’s locally installed CUDA. Bazel will download CUDA, CUDNN and NCCL distributions, and then use CUDA libraries and tools as dependencies in various Bazel targets. This enables more reproducible builds for Google ML projects and supported CUDA versions. From 10e7c9d930263b66ad98d7c3389f9a4b9ffc7ac0 Mon Sep 17 00:00:00 2001 From: David Dunleavy Date: Wed, 25 Sep 2024 14:27:56 -0700 Subject: [PATCH 284/483] Move `tsl/protobuf/*` besides `error_codes.proto` to `xla/tsl/protobuf` `error_codes.proto` will be moved in a separate change PiperOrigin-RevId: 678848832 --- .../mlir/quantization/stablehlo/BUILD | 2 +- .../mlir/quantization/tensorflow/BUILD | 2 +- tensorflow/core/BUILD | 8 +- .../eager/context_distributed_manager.cc | 2 +- .../next_pluggable_device/BUILD | 4 +- ..._plugin_coordination_service_agent_test.cc | 4 +- tensorflow/core/data/BUILD | 2 +- tensorflow/core/data/service/BUILD | 4 +- tensorflow/core/data/service/client/BUILD | 4 +- tensorflow/core/data/service/common.proto | 2 +- tensorflow/core/data/service/snapshot/BUILD | 16 ++-- .../core/data/service/snapshot/file_utils.cc | 2 +- .../snapshot/snapshot_chunk_provider.cc | 2 +- .../snapshot/snapshot_chunk_provider_test.cc | 2 +- .../data/service/snapshot/snapshot_manager.cc | 2 +- .../data/service/snapshot/snapshot_manager.h | 2 +- .../service/snapshot/snapshot_manager_test.cc | 2 +- tensorflow/core/data/service/worker_impl.cc | 2 +- tensorflow/core/distributed_runtime/BUILD | 6 +- .../distributed_runtime/coordination/BUILD | 6 +- .../coordination_service_barrier_proxy.cc | 2 +- .../coordination_service_barrier_proxy.h | 2 +- ...coordination_service_barrier_proxy_test.cc | 4 +- .../eager/eager_service_impl.cc | 2 +- .../c_api_coordination_test.cc | 2 +- .../c_api_multi_client_function_test.cc | 2 +- .../c_api_multi_client_test.cc | 2 +- .../c_api_recoverable_jobs_test.cc | 2 +- .../c_api_session_coordination_test.cc | 2 +- tensorflow/core/distributed_runtime/master.cc | 2 +- .../core/distributed_runtime/master_env.h | 2 +- .../distributed_runtime/master_session.cc | 2 +- tensorflow/core/distributed_runtime/rpc/BUILD | 2 +- .../rpc/grpc_worker_service.cc | 2 +- .../core/distributed_runtime/session_mgr.cc | 6 +- tensorflow/core/distributed_runtime/worker.cc | 2 +- tensorflow/core/framework/BUILD | 4 +- tensorflow/core/framework/summary.proto | 2 +- tensorflow/core/function/runtime_client/BUILD | 4 +- .../core/platform/build_config.default.bzl | 2 +- tensorflow/core/protobuf/BUILD | 12 +-- tensorflow/core/protobuf/config.proto | 2 +- .../core/protobuf/conv_autotuning.proto | 2 +- tensorflow/core/protobuf/rpc_options.proto | 2 +- tensorflow/core/protobuf/status.proto | 2 +- tensorflow/core/util/autotune_maps/BUILD | 10 +-- .../util/autotune_maps/autotune_map.proto | 2 +- .../util/autotune_maps/autotune_serialize.cc | 2 +- .../util/autotune_maps/conv_map_wrapper.cc | 2 +- .../autotune_maps/conv_map_wrapper_test.cc | 2 +- .../util/autotune_maps/conv_parameters.proto | 2 +- tensorflow/python/BUILD | 4 +- tensorflow/python/client/BUILD | 4 +- .../collective_all_reduce_strategy.py | 2 +- .../parameter_server_strategy_v2.py | 2 +- tensorflow/python/eager/context.py | 2 +- tensorflow/python/framework/BUILD | 12 +-- tensorflow/python/proto_exports.py | 2 +- tensorflow/python/training/server_lib_test.py | 2 +- .../xla/third_party/tsl/tsl/platform/BUILD | 4 +- .../tsl/tsl/platform/default/build_config.bzl | 8 +- .../tsl/tsl/platform/status_test.cc | 2 +- .../tsl/tsl/platform/status_to_from_proto.cc | 2 +- .../tsl/tsl/platform/status_to_from_proto.h | 2 +- .../xla/third_party/tsl/tsl/protobuf/BUILD | 84 ----------------- third_party/xla/xla/BUILD | 2 +- third_party/xla/xla/autotuning.proto | 2 +- third_party/xla/xla/pjrt/distributed/BUILD | 6 +- .../xla/xla/pjrt/distributed/client.cc | 4 +- .../xla/xla/pjrt/distributed/service.cc | 2 +- third_party/xla/xla/python/ifrt/BUILD | 2 +- .../python/ifrt/plugin_program_serdes_test.cc | 2 +- .../xla/xla/python/ifrt_proxy/client/BUILD | 2 +- .../ifrt_proxy/client/grpc_host_buffer.cc | 2 +- .../xla/xla/python/ifrt_proxy/common/BUILD | 2 +- .../ifrt_proxy/common/ifrt_service.proto | 2 +- .../xla/xla/python/ifrt_proxy/server/BUILD | 2 +- .../ifrt_proxy/server/ifrt_backend_test.cc | 2 +- third_party/xla/xla/service/BUILD | 2 +- third_party/xla/xla/service/gpu/BUILD | 6 +- .../xla/xla/service/gpu/autotuning/BUILD | 2 +- .../autotuning/gemm_algorithm_picker_test.cc | 2 +- .../xla/xla/service/gpu/backend_configs.proto | 2 +- .../xla/service/gpu/ir_emitter_unnested.cc | 2 +- .../xla/service/gpu/stream_executor_util.cc | 2 +- .../xla/service/gpu/stream_executor_util.h | 2 +- .../xla/xla/service/gpu/transforms/BUILD | 4 +- .../gpu/transforms/cudnn_norm_rewriter.cc | 2 +- .../service/gpu/transforms/gemm_rewriter.cc | 2 +- .../xla/xla/service/xla_compile_result.proto | 2 +- third_party/xla/xla/stream_executor/BUILD | 12 +-- third_party/xla/xla/stream_executor/blas.h | 2 +- .../xla/xla/stream_executor/cuda/BUILD | 4 +- .../xla/xla/stream_executor/cuda/cuda_blas.cc | 2 +- .../xla/xla/stream_executor/cuda/cuda_dnn.cc | 2 +- .../xla/xla/stream_executor/cuda/cuda_dnn.h | 2 +- .../xla/xla/stream_executor/data_type.h | 2 +- third_party/xla/xla/stream_executor/dnn.cc | 2 +- third_party/xla/xla/stream_executor/dnn.h | 2 +- third_party/xla/xla/stream_executor/gpu/BUILD | 2 +- .../xla/stream_executor/gpu/gpu_blas_lt.cc | 2 +- .../xla/xla/stream_executor/lazy_op_runner.h | 2 +- third_party/xla/xla/tools/BUILD | 4 +- .../xla/xla/tools/xla_cpu_compile_lib_test.cc | 2 +- .../xla/xla/tools/xla_gpu_compile_lib_test.cc | 2 +- .../distributed_runtime/coordination/BUILD | 34 +++---- .../coordination/coordination_client.h | 2 +- .../coordination/coordination_service.cc | 4 +- .../coordination/coordination_service.h | 4 +- .../coordination_service_agent.cc | 4 +- .../coordination/coordination_service_agent.h | 2 +- .../coordination_service_agent_test.cc | 4 +- .../coordination_service_error_util.h | 2 +- .../coordination_service_error_util_test.cc | 2 +- ...ordination_service_recoverable_job_test.cc | 2 +- .../coordination_service_rpc_handler.cc | 2 +- .../coordination_service_rpc_handler.h | 2 +- .../coordination/coordination_service_test.cc | 4 +- .../tsl/distributed_runtime/preemption/BUILD | 8 +- .../preemption/preemption_sync_manager.cc | 2 +- .../preemption_sync_manager_test.cc | 2 +- .../xla/xla/tsl/distributed_runtime/rpc/BUILD | 8 +- .../rpc/coordination/BUILD | 6 +- .../coordination/grpc_coordination_client.cc | 2 +- .../grpc_coordination_service_impl.h | 4 +- .../distributed_runtime/rpc/grpc_channel.cc | 2 +- .../distributed_runtime/rpc/grpc_channel.h | 2 +- .../rpc/grpc_channel_test.cc | 2 +- .../tsl/distributed_runtime/rpc/grpc_util.h | 2 +- third_party/xla/xla/tsl/lib/histogram/BUILD | 4 +- .../xla/xla/tsl/lib/histogram/histogram.cc | 2 +- .../xla/tsl/lib/histogram/histogram_test.cc | 2 +- third_party/xla/xla/tsl/lib/monitoring/BUILD | 10 +-- .../tsl/lib/monitoring/collected_metrics.h | 2 +- .../tsl/lib/monitoring/collection_registry.h | 2 +- .../xla/xla/tsl/lib/monitoring/metric_def.h | 2 +- .../xla/xla/tsl/lib/monitoring/sampler.h | 4 +- .../xla/xla/tsl/lib/monitoring/test_utils.cc | 2 +- .../xla/xla/tsl/lib/monitoring/test_utils.h | 2 +- third_party/xla/xla/tsl/protobuf/BUILD | 90 ++++++++++++++++++- .../tsl/protobuf/coordination_config.proto | 0 .../tsl/protobuf/coordination_service.proto | 0 .../distributed_runtime_payloads.proto | 0 .../tsl => xla}/tsl/protobuf/dnn.proto | 0 .../tsl => xla}/tsl/protobuf/histogram.proto | 0 .../tsl/protobuf/rpc_options.proto | 0 .../tsl => xla}/tsl/protobuf/status.proto | 0 third_party/xla/xla/xla.bzl | 2 +- 148 files changed, 317 insertions(+), 319 deletions(-) rename third_party/xla/{third_party/tsl => xla}/tsl/protobuf/coordination_config.proto (100%) rename third_party/xla/{third_party/tsl => xla}/tsl/protobuf/coordination_service.proto (100%) rename third_party/xla/{third_party/tsl => xla}/tsl/protobuf/distributed_runtime_payloads.proto (100%) rename third_party/xla/{third_party/tsl => xla}/tsl/protobuf/dnn.proto (100%) rename third_party/xla/{third_party/tsl => xla}/tsl/protobuf/histogram.proto (100%) rename third_party/xla/{third_party/tsl => xla}/tsl/protobuf/rpc_options.proto (100%) rename third_party/xla/{third_party/tsl => xla}/tsl/protobuf/status.proto (100%) diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/BUILD b/tensorflow/compiler/mlir/quantization/stablehlo/BUILD index 47510e04c27abe..ba674527171041 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/BUILD +++ b/tensorflow/compiler/mlir/quantization/stablehlo/BUILD @@ -144,10 +144,10 @@ cc_library( "@local_tsl//tsl/platform:protobuf", "@local_tsl//tsl/platform:regexp", "@local_tsl//tsl/platform:str_util", - "@local_tsl//tsl/protobuf:protos_all_cc", "@local_xla//xla/mlir_hlo", "@local_xla//xla/mlir_hlo:mhlo_passes", "@local_xla//xla/mlir_hlo:unfuse_batch_norm", + "@local_xla//xla/tsl/protobuf:protos_all_cc", "@stablehlo//:chlo_ops", "@stablehlo//:stablehlo_ops", "@stablehlo//:stablehlo_portable_api", diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/BUILD b/tensorflow/compiler/mlir/quantization/tensorflow/BUILD index 10b545777971f6..e5696e4db86720 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/quantization/tensorflow/BUILD @@ -463,9 +463,9 @@ cc_library( "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", "@local_tsl//tsl/platform:str_util", - "@local_tsl//tsl/protobuf:protos_all_cc", "@local_xla//xla:xla_data_proto_cc", "@local_xla//xla/mlir_hlo", + "@local_xla//xla/tsl/protobuf:protos_all_cc", "@stablehlo//:chlo_ops", ], # Alwayslink is required for registering the MLIR passes. diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index f262e1dec8d02c..d0b9f6a8f21e41 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -175,9 +175,9 @@ tf_proto_library( "//tensorflow/core/grappler/costs:op_performance_data", "@local_tsl//tsl/profiler/protobuf:profiler_options_proto", "@local_tsl//tsl/profiler/protobuf:xplane_proto", - "@local_tsl//tsl/protobuf:coordination_config_proto", - "@local_tsl//tsl/protobuf:distributed_runtime_payloads_proto", - "@local_tsl//tsl/protobuf:status_proto", + "@local_xla//xla/tsl/protobuf:coordination_config_proto", + "@local_xla//xla/tsl/protobuf:distributed_runtime_payloads_proto", + "@local_xla//xla/tsl/protobuf:status_proto", ], visibility = ["//visibility:public"], ) @@ -1468,7 +1468,7 @@ cc_library( "@local_xla//xla:autotune_results_proto_cc_impl", "@local_xla//xla:autotuning_proto_cc_impl", "@local_tsl//tsl/profiler/protobuf:profiler_options_proto_cc_impl", - "@local_tsl//tsl/protobuf:protos_all_cc_impl", + "@local_xla//xla/tsl/protobuf:protos_all_cc_impl", "@local_xla//xla:xla_proto_cc_impl", "@local_xla//xla:xla_data_proto_cc_impl", "@local_xla//xla/service:hlo_proto_cc_impl", diff --git a/tensorflow/core/common_runtime/eager/context_distributed_manager.cc b/tensorflow/core/common_runtime/eager/context_distributed_manager.cc index b36e2a2c6c6362..a82dfb8f1c7363 100644 --- a/tensorflow/core/common_runtime/eager/context_distributed_manager.cc +++ b/tensorflow/core/common_runtime/eager/context_distributed_manager.cc @@ -29,6 +29,7 @@ limitations under the License. #include "google/protobuf/any.pb.h" #include "absl/time/time.h" +#include "xla/tsl/protobuf/coordination_config.pb.h" #include "tensorflow/core/common_runtime/copy_tensor.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device_mgr.h" @@ -54,7 +55,6 @@ limitations under the License. #include "tsl/platform/errors.h" #include "tsl/platform/macros.h" #include "tsl/platform/statusor.h" -#include "tsl/protobuf/coordination_config.pb.h" #if !defined(IS_MOBILE_PLATFORM) #include "absl/base/thread_annotations.h" diff --git a/tensorflow/core/common_runtime/next_pluggable_device/BUILD b/tensorflow/core/common_runtime/next_pluggable_device/BUILD index 57708229f106ad..1e46fb453f6e7d 100644 --- a/tensorflow/core/common_runtime/next_pluggable_device/BUILD +++ b/tensorflow/core/common_runtime/next_pluggable_device/BUILD @@ -336,12 +336,12 @@ tf_cc_test( "@local_tsl//tsl/platform:protobuf", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_main", - "@local_tsl//tsl/protobuf:coordination_config_proto_cc", - "@local_tsl//tsl/protobuf:coordination_service_proto_cc", "@local_xla//xla/tsl/distributed_runtime:call_options", "@local_xla//xla/tsl/distributed_runtime/coordination:coordination_client", "@local_xla//xla/tsl/distributed_runtime/coordination:coordination_service_agent", "@local_xla//xla/tsl/lib/core:status_test_util", + "@local_xla//xla/tsl/protobuf:coordination_config_proto_cc", + "@local_xla//xla/tsl/protobuf:coordination_service_proto_cc", ], ) diff --git a/tensorflow/core/common_runtime/next_pluggable_device/c_plugin_coordination_service_agent_test.cc b/tensorflow/core/common_runtime/next_pluggable_device/c_plugin_coordination_service_agent_test.cc index 5d62f8c58668c6..61811c5435c208 100644 --- a/tensorflow/core/common_runtime/next_pluggable_device/c_plugin_coordination_service_agent_test.cc +++ b/tensorflow/core/common_runtime/next_pluggable_device/c_plugin_coordination_service_agent_test.cc @@ -28,11 +28,11 @@ limitations under the License. #include "xla/tsl/distributed_runtime/coordination/coordination_client.h" #include "xla/tsl/distributed_runtime/coordination/coordination_service_agent.h" #include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/protobuf/coordination_config.pb.h" +#include "xla/tsl/protobuf/coordination_service.pb.h" #include "tensorflow/core/platform/status.h" #include "tsl/platform/env.h" #include "tsl/platform/test.h" -#include "tsl/protobuf/coordination_config.pb.h" -#include "tsl/protobuf/coordination_service.pb.h" namespace tensorflow { namespace { diff --git a/tensorflow/core/data/BUILD b/tensorflow/core/data/BUILD index 748dfc17ce213e..00a7e14585c929 100644 --- a/tensorflow/core/data/BUILD +++ b/tensorflow/core/data/BUILD @@ -612,7 +612,7 @@ tf_cc_test( "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:status_matchers", "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/protobuf:protos_all_cc", + "@local_xla//xla/tsl/protobuf:protos_all_cc", ] + tf_protos_all(), ) diff --git a/tensorflow/core/data/service/BUILD b/tensorflow/core/data/service/BUILD index 188418148f2fc3..8a76428a848dde 100644 --- a/tensorflow/core/data/service/BUILD +++ b/tensorflow/core/data/service/BUILD @@ -409,7 +409,7 @@ tf_cc_test( "@com_google_absl//absl/status", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:path", - "@local_tsl//tsl/protobuf:protos_all_cc", + "@local_xla//xla/tsl/protobuf:protos_all_cc", ] + tf_grpc_cc_dependencies() + tf_protos_profiler_service(), ) @@ -1135,7 +1135,7 @@ cc_library( "@com_google_absl//absl/strings", "@com_google_absl//absl/time", "@local_tsl//tsl/platform:status_to_from_proto", - "@local_tsl//tsl/protobuf:protos_all_cc", + "@local_xla//xla/tsl/protobuf:protos_all_cc", ] + tf_grpc_cc_dependencies(), ) diff --git a/tensorflow/core/data/service/client/BUILD b/tensorflow/core/data/service/client/BUILD index 60e2da30b5a8ac..025645c42e4142 100644 --- a/tensorflow/core/data/service/client/BUILD +++ b/tensorflow/core/data/service/client/BUILD @@ -102,7 +102,7 @@ cc_library( "@com_google_absl//absl/strings", "@com_google_absl//absl/time", "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/protobuf:protos_all_cc", + "@local_xla//xla/tsl/protobuf:protos_all_cc", ], ) @@ -121,8 +121,8 @@ tf_cc_test( "//tensorflow/core/data/service:test_util", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:status_matchers", - "@local_tsl//tsl/protobuf:protos_all_cc", "@local_xla//xla/tsl/lib/core:status_test_util", + "@local_xla//xla/tsl/protobuf:protos_all_cc", ] + tf_grpc_cc_dependencies() + tf_protos_profiler_service(), ) diff --git a/tensorflow/core/data/service/common.proto b/tensorflow/core/data/service/common.proto index f0d9fd70a9e8e0..9d2825082efed1 100644 --- a/tensorflow/core/data/service/common.proto +++ b/tensorflow/core/data/service/common.proto @@ -2,10 +2,10 @@ syntax = "proto3"; package tensorflow.data; +import "xla/tsl/protobuf/status.proto"; import "tensorflow/core/framework/graph.proto"; import "tensorflow/core/protobuf/data_service.proto"; import "tensorflow/core/protobuf/snapshot.proto"; -import "tsl/protobuf/status.proto"; // Next tag: 2 message DatasetDef { diff --git a/tensorflow/core/data/service/snapshot/BUILD b/tensorflow/core/data/service/snapshot/BUILD index 2b7d59674fabc3..ffc34db5936595 100644 --- a/tensorflow/core/data/service/snapshot/BUILD +++ b/tensorflow/core/data/service/snapshot/BUILD @@ -81,7 +81,7 @@ tf_cc_test( "@local_tsl//tsl/platform:path", "@local_tsl//tsl/platform:status_matchers", "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/protobuf:protos_all_cc", + "@local_xla//xla/tsl/protobuf:protos_all_cc", ], ) @@ -181,7 +181,7 @@ tf_cc_test( "//tensorflow/core:test", "//tensorflow/core:test_main", "@local_tsl//tsl/platform:status_matchers", - "@local_tsl//tsl/protobuf:protos_all_cc", + "@local_xla//xla/tsl/protobuf:protos_all_cc", ], ) @@ -287,7 +287,7 @@ tf_cc_test( "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:path", "@local_tsl//tsl/platform:status_matchers", - "@local_tsl//tsl/protobuf:protos_all_cc", + "@local_xla//xla/tsl/protobuf:protos_all_cc", ], ) @@ -321,7 +321,7 @@ cc_library( "@local_tsl//tsl/platform:path", "@local_tsl//tsl/platform:status_to_from_proto", "@local_tsl//tsl/platform:thread_annotations", - "@local_tsl//tsl/protobuf:protos_all_cc", + "@local_xla//xla/tsl/protobuf:protos_all_cc", ], ) @@ -344,7 +344,7 @@ tf_cc_test( "@local_tsl//tsl/platform:status_matchers", "@local_tsl//tsl/platform:status_to_from_proto", "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/protobuf:protos_all_cc", + "@local_xla//xla/tsl/protobuf:protos_all_cc", ], ) @@ -394,7 +394,7 @@ cc_library( "@local_tsl//tsl/platform:status_to_from_proto", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:tstring", - "@local_tsl//tsl/protobuf:status_proto_cc", + "@local_xla//xla/tsl/protobuf:status_proto_cc", ], ) @@ -423,7 +423,7 @@ tf_cc_test( "@local_tsl//tsl/platform:status_to_from_proto", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:tstring", - "@local_tsl//tsl/protobuf:status_proto_cc", + "@local_xla//xla/tsl/protobuf:status_proto_cc", ], ) @@ -509,8 +509,8 @@ tf_cc_test( "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:path", "@local_tsl//tsl/platform:status_matchers", - "@local_tsl//tsl/protobuf:protos_all_cc", "@local_xla//xla/tsl/lib/monitoring:cell_reader", + "@local_xla//xla/tsl/protobuf:protos_all_cc", ], ) diff --git a/tensorflow/core/data/service/snapshot/file_utils.cc b/tensorflow/core/data/service/snapshot/file_utils.cc index 0440b00b34f7f0..ec5397bdfd1ed9 100644 --- a/tensorflow/core/data/service/snapshot/file_utils.cc +++ b/tensorflow/core/data/service/snapshot/file_utils.cc @@ -25,6 +25,7 @@ limitations under the License. #include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "xla/tsl/protobuf/status.pb.h" #include "tensorflow/core/data/service/snapshot/path_utils.h" #include "tensorflow/core/data/snapshot_utils.h" #include "tensorflow/core/framework/dataset.h" @@ -34,7 +35,6 @@ limitations under the License. #include "tsl/platform/protobuf.h" #include "tsl/platform/random.h" #include "tsl/platform/status_to_from_proto.h" -#include "tsl/protobuf/status.pb.h" namespace tensorflow { namespace data { diff --git a/tensorflow/core/data/service/snapshot/snapshot_chunk_provider.cc b/tensorflow/core/data/service/snapshot/snapshot_chunk_provider.cc index 731f9435ffcd13..ff1e2caea35b00 100644 --- a/tensorflow/core/data/service/snapshot/snapshot_chunk_provider.cc +++ b/tensorflow/core/data/service/snapshot/snapshot_chunk_provider.cc @@ -33,6 +33,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" #include "absl/time/time.h" +#include "xla/tsl/protobuf/status.pb.h" #include "tensorflow/core/data/service/snapshot/file_utils.h" #include "tensorflow/core/data/service/snapshot/path_utils.h" #include "tensorflow/core/framework/dataset.h" @@ -46,7 +47,6 @@ limitations under the License. #include "tsl/platform/status_to_from_proto.h" #include "tsl/platform/statusor.h" #include "tsl/platform/tstring.h" -#include "tsl/protobuf/status.pb.h" namespace tensorflow { namespace data { diff --git a/tensorflow/core/data/service/snapshot/snapshot_chunk_provider_test.cc b/tensorflow/core/data/service/snapshot/snapshot_chunk_provider_test.cc index e40fd0ad918387..e6fcd97ef6d5dd 100644 --- a/tensorflow/core/data/service/snapshot/snapshot_chunk_provider_test.cc +++ b/tensorflow/core/data/service/snapshot/snapshot_chunk_provider_test.cc @@ -27,6 +27,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/synchronization/mutex.h" #include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/protobuf/status.pb.h" #include "tensorflow/core/data/serialization_utils.h" #include "tensorflow/core/data/service/snapshot/file_utils.h" #include "tensorflow/core/data/service/snapshot/path_utils.h" @@ -41,7 +42,6 @@ limitations under the License. #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" #include "tsl/platform/tstring.h" -#include "tsl/protobuf/status.pb.h" namespace tensorflow { namespace data { diff --git a/tensorflow/core/data/service/snapshot/snapshot_manager.cc b/tensorflow/core/data/service/snapshot/snapshot_manager.cc index a3845361762bd8..3351934b555fa6 100644 --- a/tensorflow/core/data/service/snapshot/snapshot_manager.cc +++ b/tensorflow/core/data/service/snapshot/snapshot_manager.cc @@ -35,6 +35,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/time/time.h" #include "xla/tsl/lib/io/compression.h" +#include "xla/tsl/protobuf/status.pb.h" #include "tensorflow/core/data/service/common.pb.h" #include "tensorflow/core/data/service/dispatcher.pb.h" #include "tensorflow/core/data/service/snapshot/file_utils.h" @@ -52,7 +53,6 @@ limitations under the License. #include "tsl/platform/thread_annotations.h" #include "tsl/platform/threadpool.h" #include "tsl/protobuf/error_codes.pb.h" -#include "tsl/protobuf/status.pb.h" namespace tensorflow { namespace data { diff --git a/tensorflow/core/data/service/snapshot/snapshot_manager.h b/tensorflow/core/data/service/snapshot/snapshot_manager.h index 8c53ae98650878..5db495f16c87ce 100644 --- a/tensorflow/core/data/service/snapshot/snapshot_manager.h +++ b/tensorflow/core/data/service/snapshot/snapshot_manager.h @@ -30,6 +30,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/strings/substitute.h" #include "absl/time/time.h" +#include "xla/tsl/protobuf/status.pb.h" #include "tensorflow/core/data/service/dispatcher.pb.h" #include "tensorflow/core/data/service/snapshot/prefetched_split_provider.h" #include "tensorflow/core/framework/dataset.h" @@ -37,7 +38,6 @@ limitations under the License. #include "tsl/platform/env.h" #include "tsl/platform/mutex.h" #include "tsl/platform/thread_annotations.h" -#include "tsl/protobuf/status.pb.h" namespace tensorflow { namespace data { diff --git a/tensorflow/core/data/service/snapshot/snapshot_manager_test.cc b/tensorflow/core/data/service/snapshot/snapshot_manager_test.cc index 65b3c59e8ecba4..7ab6e10f4a10c1 100644 --- a/tensorflow/core/data/service/snapshot/snapshot_manager_test.cc +++ b/tensorflow/core/data/service/snapshot/snapshot_manager_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/protobuf/status.pb.h" #include "tensorflow/core/data/service/common.pb.h" #include "tensorflow/core/data/service/dispatcher.pb.h" #include "tensorflow/core/data/service/snapshot/path_utils.h" @@ -31,7 +32,6 @@ limitations under the License. #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" #include "tsl/protobuf/error_codes.pb.h" -#include "tsl/protobuf/status.pb.h" namespace tensorflow { namespace data { diff --git a/tensorflow/core/data/service/worker_impl.cc b/tensorflow/core/data/service/worker_impl.cc index ebddaf184ce254..f8182f16e3e161 100644 --- a/tensorflow/core/data/service/worker_impl.cc +++ b/tensorflow/core/data/service/worker_impl.cc @@ -29,6 +29,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/strings/substitute.h" #include "absl/time/time.h" +#include "xla/tsl/protobuf/status.pb.h" #include "tensorflow/core/data/service/byte_size.h" #include "tensorflow/core/data/service/common.h" #include "tensorflow/core/data/service/common.pb.h" @@ -67,7 +68,6 @@ limitations under the License. #include "tsl/platform/errors.h" #include "tsl/platform/status_to_from_proto.h" #include "tsl/platform/statusor.h" -#include "tsl/protobuf/status.pb.h" namespace tensorflow { namespace data { diff --git a/tensorflow/core/distributed_runtime/BUILD b/tensorflow/core/distributed_runtime/BUILD index 143681cee500ed..00515c71df7917 100644 --- a/tensorflow/core/distributed_runtime/BUILD +++ b/tensorflow/core/distributed_runtime/BUILD @@ -172,10 +172,10 @@ cc_library( "//tensorflow/core:protos_all_cc", "//tensorflow/core/activity_watcher", "//tensorflow/core/protobuf:worker_proto_cc", - "@local_tsl//tsl/protobuf:coordination_service_proto_cc", "@local_xla//xla/tsl/distributed_runtime/coordination:coordination_service", "@local_xla//xla/tsl/distributed_runtime/coordination:coordination_service_agent", "@local_xla//xla/tsl/distributed_runtime/coordination:coordination_service_rpc_handler", + "@local_xla//xla/tsl/protobuf:coordination_service_proto_cc", ], ) @@ -392,7 +392,7 @@ cc_library( "//tensorflow/core/platform:regexp", "//tensorflow/core/protobuf:master_proto_cc", "//tensorflow/core/protobuf:worker_proto_cc", - "@local_tsl//tsl/protobuf:rpc_options_proto_cc", + "@local_xla//xla/tsl/protobuf:rpc_options_proto_cc", ], ) @@ -491,7 +491,7 @@ cc_library( ":worker_cache", "//tensorflow/core:protos_all_cc", "//tensorflow/core:session_options", - "@local_tsl//tsl/protobuf:rpc_options_proto_cc", + "@local_xla//xla/tsl/protobuf:rpc_options_proto_cc", ], ) diff --git a/tensorflow/core/distributed_runtime/coordination/BUILD b/tensorflow/core/distributed_runtime/coordination/BUILD index a964cb8d1a0f43..6c0bb0e705517c 100644 --- a/tensorflow/core/distributed_runtime/coordination/BUILD +++ b/tensorflow/core/distributed_runtime/coordination/BUILD @@ -60,8 +60,8 @@ cc_library( "@com_google_absl//absl/time", "@local_tsl//tsl/profiler/lib:traceme", "@local_tsl//tsl/profiler/lib:traceme_encode", - "@local_tsl//tsl/protobuf:coordination_service_proto_cc", "@local_xla//xla/tsl/distributed_runtime/coordination:coordination_service_agent", + "@local_xla//xla/tsl/protobuf:coordination_service_proto_cc", ], ) @@ -81,11 +81,11 @@ tf_cc_test( "@com_google_absl//absl/strings", "@com_google_absl//absl/time", "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/protobuf:coordination_config_proto_cc", - "@local_tsl//tsl/protobuf:coordination_service_proto_cc", "@local_xla//xla/tsl/distributed_runtime:call_options", "@local_xla//xla/tsl/distributed_runtime/coordination:coordination_client", "@local_xla//xla/tsl/distributed_runtime/coordination:coordination_service_agent", + "@local_xla//xla/tsl/protobuf:coordination_config_proto_cc", + "@local_xla//xla/tsl/protobuf:coordination_service_proto_cc", ], ) diff --git a/tensorflow/core/distributed_runtime/coordination/coordination_service_barrier_proxy.cc b/tensorflow/core/distributed_runtime/coordination/coordination_service_barrier_proxy.cc index d24ef5f03a2896..0a97ac4ca79a44 100644 --- a/tensorflow/core/distributed_runtime/coordination/coordination_service_barrier_proxy.cc +++ b/tensorflow/core/distributed_runtime/coordination/coordination_service_barrier_proxy.cc @@ -27,11 +27,11 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/time/time.h" #include "xla/tsl/distributed_runtime/coordination/coordination_service_agent.h" +#include "xla/tsl/protobuf/coordination_service.pb.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/status.h" #include "tsl/profiler/lib/traceme.h" #include "tsl/profiler/lib/traceme_encode.h" -#include "tsl/protobuf/coordination_service.pb.h" namespace tensorflow { diff --git a/tensorflow/core/distributed_runtime/coordination/coordination_service_barrier_proxy.h b/tensorflow/core/distributed_runtime/coordination/coordination_service_barrier_proxy.h index 5d9aeeec3debc4..f1e1342502eac9 100644 --- a/tensorflow/core/distributed_runtime/coordination/coordination_service_barrier_proxy.h +++ b/tensorflow/core/distributed_runtime/coordination/coordination_service_barrier_proxy.h @@ -25,11 +25,11 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/time/time.h" #include "xla/tsl/distributed_runtime/coordination/coordination_service_agent.h" +#include "xla/tsl/protobuf/coordination_service.pb.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/thread_annotations.h" -#include "tsl/protobuf/coordination_service.pb.h" namespace tensorflow { diff --git a/tensorflow/core/distributed_runtime/coordination/coordination_service_barrier_proxy_test.cc b/tensorflow/core/distributed_runtime/coordination/coordination_service_barrier_proxy_test.cc index b74535ccb44d93..3e4eb73db5c85a 100644 --- a/tensorflow/core/distributed_runtime/coordination/coordination_service_barrier_proxy_test.cc +++ b/tensorflow/core/distributed_runtime/coordination/coordination_service_barrier_proxy_test.cc @@ -32,13 +32,13 @@ limitations under the License. #include "xla/tsl/distributed_runtime/call_options.h" #include "xla/tsl/distributed_runtime/coordination/coordination_client.h" #include "xla/tsl/distributed_runtime/coordination/coordination_service_agent.h" +#include "xla/tsl/protobuf/coordination_config.pb.h" +#include "xla/tsl/protobuf/coordination_service.pb.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/threadpool.h" -#include "tsl/protobuf/coordination_config.pb.h" -#include "tsl/protobuf/coordination_service.pb.h" namespace tensorflow { namespace { diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc index 71addd352b76d5..2130100df90c33 100644 --- a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc +++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc @@ -29,6 +29,7 @@ limitations under the License. #include "absl/types/optional.h" #include "tensorflow/c/eager/immediate_execution_distributed_manager.h" #include "xla/tsl/distributed_runtime/preemption/preemption_notifier.h" +#include "xla/tsl/protobuf/coordination_config.pb.h" #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/eager/context.h" #include "tensorflow/core/common_runtime/eager/context_distributed_manager.h" @@ -53,7 +54,6 @@ limitations under the License. #include "tensorflow/core/platform/stringprintf.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/profiler/lib/traceme.h" -#include "tsl/protobuf/coordination_config.pb.h" namespace tensorflow { namespace eager { diff --git a/tensorflow/core/distributed_runtime/integration_test/c_api_coordination_test.cc b/tensorflow/core/distributed_runtime/integration_test/c_api_coordination_test.cc index 356f0a08412fd9..9f803991417dce 100644 --- a/tensorflow/core/distributed_runtime/integration_test/c_api_coordination_test.cc +++ b/tensorflow/core/distributed_runtime/integration_test/c_api_coordination_test.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/c/eager/c_api_test_util.h" #include "tensorflow/c/eager/tfe_tensorhandle_internal.h" #include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/protobuf/coordination_config.pb.h" #include "tensorflow/core/distributed_runtime/server_lib.h" #include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/platform/blocking_counter.h" @@ -32,7 +33,6 @@ limitations under the License. #include "tensorflow/core/protobuf/cluster.pb.h" #include "tensorflow/core/protobuf/rewriter_config.pb.h" #include "tensorflow/core/protobuf/tensorflow_server.pb.h" -#include "tsl/protobuf/coordination_config.pb.h" namespace tensorflow { namespace { diff --git a/tensorflow/core/distributed_runtime/integration_test/c_api_multi_client_function_test.cc b/tensorflow/core/distributed_runtime/integration_test/c_api_multi_client_function_test.cc index 7a077cbe13b8d6..25eb3da148a23f 100644 --- a/tensorflow/core/distributed_runtime/integration_test/c_api_multi_client_function_test.cc +++ b/tensorflow/core/distributed_runtime/integration_test/c_api_multi_client_function_test.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/c/eager/tfe_context_internal.h" #include "tensorflow/c/eager/tfe_op_internal.h" #include "tensorflow/c/eager/tfe_tensorhandle_internal.h" +#include "xla/tsl/protobuf/coordination_config.pb.h" #include "tensorflow/core/common_runtime/eager/context.h" #include "tensorflow/core/common_runtime/eager/eager_operation.h" #include "tensorflow/core/common_runtime/eager/kernel_and_device.h" @@ -33,7 +34,6 @@ limitations under the License. #include "tensorflow/core/platform/test.h" #include "tensorflow/core/protobuf/cluster.pb.h" #include "tensorflow/core/protobuf/tensorflow_server.pb.h" -#include "tsl/protobuf/coordination_config.pb.h" namespace { diff --git a/tensorflow/core/distributed_runtime/integration_test/c_api_multi_client_test.cc b/tensorflow/core/distributed_runtime/integration_test/c_api_multi_client_test.cc index cffe93d297a8df..01ab92e1939ceb 100644 --- a/tensorflow/core/distributed_runtime/integration_test/c_api_multi_client_test.cc +++ b/tensorflow/core/distributed_runtime/integration_test/c_api_multi_client_test.cc @@ -20,13 +20,13 @@ limitations under the License. #include "tensorflow/c/eager/c_api_test_util.h" #include "tensorflow/c/eager/tfe_context_internal.h" #include "tensorflow/c/eager/tfe_tensorhandle_internal.h" +#include "xla/tsl/protobuf/coordination_config.pb.h" #include "tensorflow/core/common_runtime/eager/context.h" #include "tensorflow/core/framework/device_attributes.pb.h" #include "tensorflow/core/platform/strcat.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/protobuf/cluster.pb.h" #include "tensorflow/core/protobuf/tensorflow_server.pb.h" -#include "tsl/protobuf/coordination_config.pb.h" namespace { diff --git a/tensorflow/core/distributed_runtime/integration_test/c_api_recoverable_jobs_test.cc b/tensorflow/core/distributed_runtime/integration_test/c_api_recoverable_jobs_test.cc index 3d9ff3c459181f..3c04a22afb174b 100644 --- a/tensorflow/core/distributed_runtime/integration_test/c_api_recoverable_jobs_test.cc +++ b/tensorflow/core/distributed_runtime/integration_test/c_api_recoverable_jobs_test.cc @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/c/eager/c_api_test_util.h" #include "tensorflow/c/eager/tfe_tensorhandle_internal.h" #include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/protobuf/coordination_config.pb.h" #include "tensorflow/core/distributed_runtime/server_lib.h" #include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/platform/strcat.h" @@ -32,7 +33,6 @@ limitations under the License. #include "tensorflow/core/protobuf/cluster.pb.h" #include "tensorflow/core/protobuf/rewriter_config.pb.h" #include "tensorflow/core/protobuf/tensorflow_server.pb.h" -#include "tsl/protobuf/coordination_config.pb.h" namespace tensorflow { namespace { diff --git a/tensorflow/core/distributed_runtime/integration_test/c_api_session_coordination_test.cc b/tensorflow/core/distributed_runtime/integration_test/c_api_session_coordination_test.cc index e5e9aad6f06ddf..1d78c7008b91e1 100644 --- a/tensorflow/core/distributed_runtime/integration_test/c_api_session_coordination_test.cc +++ b/tensorflow/core/distributed_runtime/integration_test/c_api_session_coordination_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/c/c_test_util.h" #include "tensorflow/c/eager/c_api_test_util.h" #include "tensorflow/c/tf_datatype.h" +#include "xla/tsl/protobuf/coordination_config.pb.h" #include "tensorflow/core/distributed_runtime/server_lib.h" #include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/framework/graph.pb.h" @@ -28,7 +29,6 @@ limitations under the License. #include "tensorflow/core/protobuf/config.pb.h" #include "tensorflow/core/protobuf/rewriter_config.pb.h" #include "tensorflow/core/protobuf/tensorflow_server.pb.h" -#include "tsl/protobuf/coordination_config.pb.h" namespace tensorflow { namespace { diff --git a/tensorflow/core/distributed_runtime/master.cc b/tensorflow/core/distributed_runtime/master.cc index f67f2b6052bcc1..6084a38f4d5563 100644 --- a/tensorflow/core/distributed_runtime/master.cc +++ b/tensorflow/core/distributed_runtime/master.cc @@ -35,6 +35,7 @@ limitations under the License. #include #include +#include "xla/tsl/protobuf/rpc_options.pb.h" #include "tensorflow/core/common_runtime/device_set.h" #include "tensorflow/core/common_runtime/process_util.h" #include "tensorflow/core/distributed_runtime/remote_device.h" @@ -56,7 +57,6 @@ limitations under the License. #include "tensorflow/core/protobuf/worker.pb.h" #include "tensorflow/core/public/session_options.h" #include "tensorflow/core/util/device_name_utils.h" -#include "tsl/protobuf/rpc_options.pb.h" namespace tensorflow { diff --git a/tensorflow/core/distributed_runtime/master_env.h b/tensorflow/core/distributed_runtime/master_env.h index 633e21df361386..51e99d7126c5d6 100644 --- a/tensorflow/core/distributed_runtime/master_env.h +++ b/tensorflow/core/distributed_runtime/master_env.h @@ -19,12 +19,12 @@ limitations under the License. #include #include +#include "xla/tsl/protobuf/rpc_options.pb.h" #include "tensorflow/core/distributed_runtime/worker_cache.h" #include "tensorflow/core/protobuf/cluster.pb.h" #include "tensorflow/core/protobuf/config.pb.h" #include "tensorflow/core/protobuf/tensorflow_server.pb.h" #include "tensorflow/core/public/session_options.h" -#include "tsl/protobuf/rpc_options.pb.h" namespace tsl { class Env; diff --git a/tensorflow/core/distributed_runtime/master_session.cc b/tensorflow/core/distributed_runtime/master_session.cc index cb08a53815fd73..b1197b3e137504 100644 --- a/tensorflow/core/distributed_runtime/master_session.cc +++ b/tensorflow/core/distributed_runtime/master_session.cc @@ -25,6 +25,7 @@ limitations under the License. #include #include "absl/status/status.h" +#include "xla/tsl/protobuf/coordination_config.pb.h" #include "tensorflow/core/common_runtime/process_util.h" #include "tensorflow/core/common_runtime/profile_handler.h" #include "tensorflow/core/common_runtime/stats_publisher_interface.h" @@ -64,7 +65,6 @@ limitations under the License. #include "tensorflow/core/public/session_options.h" #include "tensorflow/core/util/device_name_utils.h" #include "tsl/platform/tracing.h" -#include "tsl/protobuf/coordination_config.pb.h" namespace tensorflow { diff --git a/tensorflow/core/distributed_runtime/rpc/BUILD b/tensorflow/core/distributed_runtime/rpc/BUILD index 959a8abae6518a..0b7b731e6f51fb 100644 --- a/tensorflow/core/distributed_runtime/rpc/BUILD +++ b/tensorflow/core/distributed_runtime/rpc/BUILD @@ -208,9 +208,9 @@ tf_cuda_library( "//tensorflow/core/profiler/lib:scoped_memory_debug_annotation", "//tensorflow/core/protobuf:worker_proto_cc", "@com_google_absl//absl/container:flat_hash_map", - "@local_tsl//tsl/protobuf:rpc_options_proto_cc", "@local_xla//xla/tsl/distributed_runtime/rpc:async_service_interface", "@local_xla//xla/tsl/distributed_runtime/rpc:grpc_call", + "@local_xla//xla/tsl/protobuf:rpc_options_proto_cc", ] + tf_grpc_cc_dependencies(), ) diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc index bb705a9b3d3f19..ba00a466a0f841 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc @@ -25,6 +25,7 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "xla/tsl/distributed_runtime/rpc/async_service_interface.h" #include "xla/tsl/distributed_runtime/rpc/grpc_call.h" +#include "xla/tsl/protobuf/rpc_options.pb.h" #include "tensorflow/core/common_runtime/buf_rendezvous.h" #include "tensorflow/core/common_runtime/copy_tensor.h" #include "tensorflow/core/common_runtime/device.h" @@ -56,7 +57,6 @@ limitations under the License. #include "tensorflow/core/protobuf/transport_options.pb.h" #include "tensorflow/core/protobuf/worker.pb.h" #include "tsl/platform/tracing.h" -#include "tsl/protobuf/rpc_options.pb.h" namespace tensorflow { diff --git a/tensorflow/core/distributed_runtime/session_mgr.cc b/tensorflow/core/distributed_runtime/session_mgr.cc index a6b4df397b6b25..94a943c99ea77a 100644 --- a/tensorflow/core/distributed_runtime/session_mgr.cc +++ b/tensorflow/core/distributed_runtime/session_mgr.cc @@ -23,6 +23,9 @@ limitations under the License. #include "xla/tsl/distributed_runtime/coordination/coordination_service.h" #include "xla/tsl/distributed_runtime/coordination/coordination_service_agent.h" +#include "xla/tsl/protobuf/coordination_config.pb.h" +#include "xla/tsl/protobuf/coordination_service.pb.h" +#include "xla/tsl/protobuf/distributed_runtime_payloads.pb.h" #include "tensorflow/core/activity_watcher/activity.h" #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/renamed_device.h" @@ -35,9 +38,6 @@ limitations under the License. #include "tensorflow/core/protobuf/cluster.pb.h" #include "tensorflow/core/protobuf/tensorflow_server.pb.h" #include "tensorflow/core/util/device_name_utils.h" -#include "tsl/protobuf/coordination_config.pb.h" -#include "tsl/protobuf/coordination_service.pb.h" -#include "tsl/protobuf/distributed_runtime_payloads.pb.h" namespace tensorflow { namespace { diff --git a/tensorflow/core/distributed_runtime/worker.cc b/tensorflow/core/distributed_runtime/worker.cc index 0922b04de0b0f8..27e8a205c65d75 100644 --- a/tensorflow/core/distributed_runtime/worker.cc +++ b/tensorflow/core/distributed_runtime/worker.cc @@ -17,6 +17,7 @@ limitations under the License. #include +#include "xla/tsl/protobuf/distributed_runtime_payloads.pb.h" #include "tensorflow/core/common_runtime/collective_executor_mgr.h" #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/process_util.h" @@ -29,7 +30,6 @@ limitations under the License. #include "tensorflow/core/framework/collective.h" #include "tensorflow/core/profiler/lib/device_profiler_session.h" #include "tsl/platform/tracing.h" -#include "tsl/protobuf/distributed_runtime_payloads.pb.h" namespace tensorflow { diff --git a/tensorflow/core/framework/BUILD b/tensorflow/core/framework/BUILD index 83176eb487262d..5931211307f39b 100644 --- a/tensorflow/core/framework/BUILD +++ b/tensorflow/core/framework/BUILD @@ -1822,9 +1822,9 @@ tf_proto_library( ":tensor_proto", ":tensor_shape_proto", ":types_proto", - "@local_tsl//tsl/protobuf:histogram_proto", + "@local_xla//xla/tsl/protobuf:histogram_proto", ], - exports = ["@local_tsl//tsl/protobuf:histogram_proto"], + exports = ["@local_xla//xla/tsl/protobuf:histogram_proto"], ) tf_proto_library( diff --git a/tensorflow/core/framework/summary.proto b/tensorflow/core/framework/summary.proto index c6b515abea2517..9e219b027fc26a 100644 --- a/tensorflow/core/framework/summary.proto +++ b/tensorflow/core/framework/summary.proto @@ -2,7 +2,7 @@ syntax = "proto3"; package tensorflow; -import public "tsl/protobuf/histogram.proto"; +import public "xla/tsl/protobuf/histogram.proto"; import "tensorflow/core/framework/tensor.proto"; diff --git a/tensorflow/core/function/runtime_client/BUILD b/tensorflow/core/function/runtime_client/BUILD index 970bfd92d7accd..e2576d14ebc18e 100644 --- a/tensorflow/core/function/runtime_client/BUILD +++ b/tensorflow/core/function/runtime_client/BUILD @@ -125,14 +125,14 @@ tf_python_pybind_extension( "//tensorflow/core/protobuf:eager_service_proto_cc", "//tensorflow/core/protobuf:master_proto_cc", "//tensorflow/core/protobuf:worker_proto_cc", - "@local_tsl//tsl/protobuf:coordination_service_proto_cc", + "@local_xla//xla/tsl/protobuf:coordination_service_proto_cc", ], otherwise = [ "//tensorflow/core/framework:function_proto_cc_headers_only", "//tensorflow/core/protobuf:eager_service_proto_cc_headers_only", "//tensorflow/core/protobuf:master_proto_cc_headers_only", "//tensorflow/core/protobuf:worker_proto_cc_headers_only", - "@local_tsl//tsl/protobuf:coordination_service_proto_cc_headers_only", + "@local_xla//xla/tsl/protobuf:coordination_service_proto_cc_headers_only", ], ), ) diff --git a/tensorflow/core/platform/build_config.default.bzl b/tensorflow/core/platform/build_config.default.bzl index b39d6b913740e5..4eeea4882183f3 100644 --- a/tensorflow/core/platform/build_config.default.bzl +++ b/tensorflow/core/platform/build_config.default.bzl @@ -45,7 +45,7 @@ def tf_protos_all(): Label("//tensorflow/core:protos_all_cc_impl"), "@local_xla//xla:autotune_results_proto_cc_impl", "@local_xla//xla:autotuning_proto_cc_impl", - "@local_tsl//tsl/protobuf:protos_all_cc_impl", + "@local_xla//xla/tsl/protobuf:protos_all_cc_impl", ], otherwise = [Label("//tensorflow/core:protos_all_cc")], ) diff --git a/tensorflow/core/protobuf/BUILD b/tensorflow/core/protobuf/BUILD index 30a0b7283a2e57..a6760507145143 100644 --- a/tensorflow/core/protobuf/BUILD +++ b/tensorflow/core/protobuf/BUILD @@ -71,7 +71,7 @@ tf_proto_library( srcs = ["conv_autotuning.proto"], make_default_target_header_only = True, protodeps = [ - "@local_tsl//tsl/protobuf:dnn_proto", + "@local_xla//xla/tsl/protobuf:dnn_proto", ], ) @@ -200,16 +200,16 @@ tf_proto_library( ":error_codes_proto_impl", "//tensorflow/core/framework:protos_all", "@local_xla//xla/tsl/protobuf:bfc_memory_map_proto", - "@local_tsl//tsl/protobuf:coordination_config_proto", - "@local_tsl//tsl/protobuf:rpc_options_proto", - "@local_tsl//tsl/protobuf:status_proto", + "@local_xla//xla/tsl/protobuf:coordination_config_proto", + "@local_xla//xla/tsl/protobuf:rpc_options_proto", + "@local_xla//xla/tsl/protobuf:status_proto", ], tags = ["alt_dep=//third_party/tensorflow/core:protos_all"], visibility = ["//visibility:public"], exports = [ - "@local_tsl//tsl/protobuf:rpc_options_proto", - "@local_tsl//tsl/protobuf:status_proto", "@local_xla//xla/tsl/protobuf:bfc_memory_map_proto", + "@local_xla//xla/tsl/protobuf:rpc_options_proto", + "@local_xla//xla/tsl/protobuf:status_proto", ], ) diff --git a/tensorflow/core/protobuf/config.proto b/tensorflow/core/protobuf/config.proto index f5ee231fac28e4..5c4f9bc2b54b68 100644 --- a/tensorflow/core/protobuf/config.proto +++ b/tensorflow/core/protobuf/config.proto @@ -2,6 +2,7 @@ syntax = "proto3"; package tensorflow; +import "xla/tsl/protobuf/coordination_config.proto"; import "tensorflow/core/framework/cost_graph.proto"; import "tensorflow/core/framework/graph.proto"; import "tensorflow/core/framework/step_stats.proto"; @@ -9,7 +10,6 @@ import "tensorflow/core/protobuf/cluster.proto"; import "tensorflow/core/protobuf/debug.proto"; import "tensorflow/core/protobuf/rewriter_config.proto"; import "tensorflow/core/protobuf/rpc_options.proto"; -import "tsl/protobuf/coordination_config.proto"; option cc_enable_arenas = true; option java_outer_classname = "ConfigProtos"; diff --git a/tensorflow/core/protobuf/conv_autotuning.proto b/tensorflow/core/protobuf/conv_autotuning.proto index 21f1c2adbf5613..47ed3a1174899b 100644 --- a/tensorflow/core/protobuf/conv_autotuning.proto +++ b/tensorflow/core/protobuf/conv_autotuning.proto @@ -4,7 +4,7 @@ syntax = "proto3"; package tensorflow; -import "tsl/protobuf/dnn.proto"; +import "xla/tsl/protobuf/dnn.proto"; option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/protobuf/for_core_protos_go_proto"; diff --git a/tensorflow/core/protobuf/rpc_options.proto b/tensorflow/core/protobuf/rpc_options.proto index db9216a7e7bb4c..03593a682e81cb 100644 --- a/tensorflow/core/protobuf/rpc_options.proto +++ b/tensorflow/core/protobuf/rpc_options.proto @@ -2,6 +2,6 @@ syntax = "proto3"; package tensorflow.dummy; -import public "tsl/protobuf/rpc_options.proto"; +import public "xla/tsl/protobuf/rpc_options.proto"; option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/protobuf/for_core_protos_go_proto"; diff --git a/tensorflow/core/protobuf/status.proto b/tensorflow/core/protobuf/status.proto index dd6f703d100ae2..d7df8cf3e05af0 100644 --- a/tensorflow/core/protobuf/status.proto +++ b/tensorflow/core/protobuf/status.proto @@ -6,6 +6,6 @@ syntax = "proto3"; // code for some users that use JS through J2CL. package tensorflow.dummy; -import public "tsl/protobuf/status.proto"; +import public "xla/tsl/protobuf/status.proto"; option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/protobuf/for_core_protos_go_proto"; diff --git a/tensorflow/core/util/autotune_maps/BUILD b/tensorflow/core/util/autotune_maps/BUILD index f2d08b34f35a5e..1d5ce0d8676788 100644 --- a/tensorflow/core/util/autotune_maps/BUILD +++ b/tensorflow/core/util/autotune_maps/BUILD @@ -52,8 +52,8 @@ cc_library( "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", - "@local_tsl//tsl/protobuf:dnn_proto_cc", "@local_xla//xla/tsl/lib/strings:proto_serialization", + "@local_xla//xla/tsl/protobuf:dnn_proto_cc", ], ) @@ -66,8 +66,8 @@ tf_cc_test( ":conv_parameters_proto_cc", "@com_google_googletest//:gtest_main", "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/protobuf:dnn_proto_cc", "@local_xla//xla:test", + "@local_xla//xla/tsl/protobuf:dnn_proto_cc", ], ) @@ -78,7 +78,7 @@ tf_proto_library( ], protodeps = [ "//tensorflow/core/framework:types_proto", - "@local_tsl//tsl/protobuf:dnn_proto", + "@local_xla//xla/tsl/protobuf:dnn_proto", ], ) @@ -136,7 +136,7 @@ tf_proto_library( ], protodeps = [ "//tensorflow/core/util/autotune_maps:conv_parameters_proto", - "@local_tsl//tsl/protobuf:dnn_proto", + "@local_xla//xla/tsl/protobuf:dnn_proto", ], visibility = [ "//waymo/ml/deploy/benchmark:__subpackages__", @@ -179,12 +179,12 @@ tf_cuda_library( "//tensorflow/core:framework", "//tensorflow/core/platform:status", "//tensorflow/core/platform:str_util", - "@local_tsl//tsl/protobuf:dnn_proto_cc", "@local_xla//xla:status_macros", "@local_xla//xla/stream_executor:dnn", "@local_xla//xla/stream_executor:platform_manager", "@local_xla//xla/stream_executor/gpu:gpu_init", "@local_xla//xla/tsl/lib/strings:proto_serialization", + "@local_xla//xla/tsl/protobuf:dnn_proto_cc", ], ) diff --git a/tensorflow/core/util/autotune_maps/autotune_map.proto b/tensorflow/core/util/autotune_maps/autotune_map.proto index c655b3c1a5927d..79192075761bd9 100644 --- a/tensorflow/core/util/autotune_maps/autotune_map.proto +++ b/tensorflow/core/util/autotune_maps/autotune_map.proto @@ -21,8 +21,8 @@ syntax = "proto3"; package tensorflow; +import "xla/tsl/protobuf/dnn.proto"; import "tensorflow/core/util/autotune_maps/conv_parameters.proto"; -import "tsl/protobuf/dnn.proto"; message ConvMapProto { message Entry { diff --git a/tensorflow/core/util/autotune_maps/autotune_serialize.cc b/tensorflow/core/util/autotune_maps/autotune_serialize.cc index c601502a0d0512..fec18867bf3cc0 100644 --- a/tensorflow/core/util/autotune_maps/autotune_serialize.cc +++ b/tensorflow/core/util/autotune_maps/autotune_serialize.cc @@ -26,13 +26,13 @@ limitations under the License. #include "xla/stream_executor/gpu/gpu_init.h" #include "xla/stream_executor/platform_manager.h" #include "xla/tsl/lib/strings/proto_serialization.h" +#include "xla/tsl/protobuf/dnn.pb.h" #include "tensorflow/core/platform/str_util.h" #include "tensorflow/core/util/activation_mode.h" #include "tensorflow/core/util/autotune_maps/autotune_map.pb.h" #include "tensorflow/core/util/autotune_maps/conv_autotune_maps.h" #include "tensorflow/core/util/autotune_maps/conv_parameters.h" #include "tensorflow/core/util/autotune_maps/conv_parameters.pb.h" -#include "tsl/protobuf/dnn.pb.h" namespace tensorflow { diff --git a/tensorflow/core/util/autotune_maps/conv_map_wrapper.cc b/tensorflow/core/util/autotune_maps/conv_map_wrapper.cc index 0bd1122c132238..8044441680501b 100644 --- a/tensorflow/core/util/autotune_maps/conv_map_wrapper.cc +++ b/tensorflow/core/util/autotune_maps/conv_map_wrapper.cc @@ -20,9 +20,9 @@ limitations under the License. #include "absl/log/check.h" #include "absl/status/status.h" #include "xla/tsl/lib/strings/proto_serialization.h" +#include "xla/tsl/protobuf/dnn.pb.h" #include "tensorflow/core/util/autotune_maps/autotune_map.pb.h" #include "tensorflow/core/util/autotune_maps/conv_parameters.pb.h" -#include "tsl/protobuf/dnn.pb.h" namespace tensorflow { diff --git a/tensorflow/core/util/autotune_maps/conv_map_wrapper_test.cc b/tensorflow/core/util/autotune_maps/conv_map_wrapper_test.cc index 6279cd03ae25ac..5443e8e28c7193 100644 --- a/tensorflow/core/util/autotune_maps/conv_map_wrapper_test.cc +++ b/tensorflow/core/util/autotune_maps/conv_map_wrapper_test.cc @@ -20,10 +20,10 @@ limitations under the License. #include #include "xla/test.h" +#include "xla/tsl/protobuf/dnn.pb.h" #include "tensorflow/core/util/autotune_maps/autotune_map.pb.h" #include "tensorflow/core/util/autotune_maps/conv_parameters.pb.h" #include "tsl/platform/statusor.h" -#include "tsl/protobuf/dnn.pb.h" namespace tensorflow { namespace { diff --git a/tensorflow/core/util/autotune_maps/conv_parameters.proto b/tensorflow/core/util/autotune_maps/conv_parameters.proto index 03a9cfd005d6f6..aee217e80968c9 100644 --- a/tensorflow/core/util/autotune_maps/conv_parameters.proto +++ b/tensorflow/core/util/autotune_maps/conv_parameters.proto @@ -22,8 +22,8 @@ syntax = "proto3"; package tensorflow; +import "xla/tsl/protobuf/dnn.proto"; import "tensorflow/core/framework/types.proto"; -import "tsl/protobuf/dnn.proto"; // LINT.IfChange diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 3942f1712faaad..b42ff159414e4e 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -1257,13 +1257,13 @@ tf_python_pybind_extension( "//tensorflow/core/protobuf:eager_service_proto_cc", "//tensorflow/core/protobuf:master_proto_cc", "//tensorflow/core/protobuf:worker_proto_cc", - "@local_tsl//tsl/protobuf:coordination_service_proto_cc", + "@local_xla//xla/tsl/protobuf:coordination_service_proto_cc", ], otherwise = [ "//tensorflow/core/protobuf:eager_service_proto_cc_headers_only", "//tensorflow/core/protobuf:master_proto_cc_headers_only", "//tensorflow/core/protobuf:worker_proto_cc_headers_only", - "@local_tsl//tsl/protobuf:coordination_service_proto_cc_headers_only", + "@local_xla//xla/tsl/protobuf:coordination_service_proto_cc_headers_only", ], ), ) diff --git a/tensorflow/python/client/BUILD b/tensorflow/python/client/BUILD index a1d8ae68ecdd80..ed9e8466acea24 100644 --- a/tensorflow/python/client/BUILD +++ b/tensorflow/python/client/BUILD @@ -78,13 +78,13 @@ tf_python_pybind_extension( "//tensorflow/core/protobuf:master_proto_cc", "//tensorflow/core/protobuf:worker_proto_cc", "//tensorflow/core:version_lib", - "@local_tsl//tsl/protobuf:coordination_service_proto_cc", + "@local_xla//xla/tsl/protobuf:coordination_service_proto_cc", ], otherwise = [ "//tensorflow/core/protobuf:eager_service_proto_cc_headers_only", "//tensorflow/core/protobuf:master_proto_cc_headers_only", "//tensorflow/core/protobuf:worker_proto_cc_headers_only", - "@local_tsl//tsl/protobuf:coordination_service_proto_cc_headers_only", + "@local_xla//xla/tsl/protobuf:coordination_service_proto_cc_headers_only", ], ), ) diff --git a/tensorflow/python/distribute/collective_all_reduce_strategy.py b/tensorflow/python/distribute/collective_all_reduce_strategy.py index c7dac643cefb37..8899a7d2c04cd6 100644 --- a/tensorflow/python/distribute/collective_all_reduce_strategy.py +++ b/tensorflow/python/distribute/collective_all_reduce_strategy.py @@ -19,6 +19,7 @@ import time import weakref +from xla.tsl.protobuf import coordination_config_pb2 from tensorflow.core.protobuf import rewriter_config_pb2 from tensorflow.core.protobuf import tensorflow_server_pb2 from tensorflow.python.distribute import collective_util @@ -49,7 +50,6 @@ from tensorflow.python.trackable import base from tensorflow.python.util import deprecation from tensorflow.python.util.tf_export import tf_export -from tsl.protobuf import coordination_config_pb2 # pylint: disable=line-too-long diff --git a/tensorflow/python/distribute/parameter_server_strategy_v2.py b/tensorflow/python/distribute/parameter_server_strategy_v2.py index d6bdcd0ac13d6f..9457fe576638b2 100644 --- a/tensorflow/python/distribute/parameter_server_strategy_v2.py +++ b/tensorflow/python/distribute/parameter_server_strategy_v2.py @@ -21,6 +21,7 @@ import os import threading +from xla.tsl.protobuf import coordination_config_pb2 from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib from tensorflow.python.distribute import device_util from tensorflow.python.distribute import distribute_lib @@ -50,7 +51,6 @@ from tensorflow.python.util import nest from tensorflow.python.util import tf_inspect from tensorflow.python.util.tf_export import tf_export -from tsl.protobuf import coordination_config_pb2 ALLOWED_TASK_TYPES = ("chief", "worker", "ps") diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py index 61e64a594a42cb..8fbb75dd8e9b9a 100644 --- a/tensorflow/python/eager/context.py +++ b/tensorflow/python/eager/context.py @@ -26,6 +26,7 @@ from absl import logging import numpy as np +from xla.tsl.protobuf import coordination_config_pb2 from tensorflow.core.framework import function_pb2 from tensorflow.core.framework import graph_debug_info_pb2 from tensorflow.core.protobuf import config_pb2 @@ -46,7 +47,6 @@ from tensorflow.python.util import tf_contextlib from tensorflow.python.util.deprecation import deprecated from tensorflow.python.util.tf_export import tf_export -from tsl.protobuf import coordination_config_pb2 # TODO(b/307794935): Remove after a solution is found. diff --git a/tensorflow/python/framework/BUILD b/tensorflow/python/framework/BUILD index 29b3cc611f3f2d..5b3aef335e0586 100644 --- a/tensorflow/python/framework/BUILD +++ b/tensorflow/python/framework/BUILD @@ -895,13 +895,13 @@ tf_python_pybind_extension( "//tensorflow/core/protobuf:eager_service_proto_cc", "//tensorflow/core/protobuf:master_proto_cc", "//tensorflow/core/protobuf:worker_proto_cc", - "@local_tsl//tsl/protobuf:coordination_service_proto_cc", + "@local_xla//xla/tsl/protobuf:coordination_service_proto_cc", ], otherwise = [ "//tensorflow/core/protobuf:eager_service_proto_cc_headers_only", "//tensorflow/core/protobuf:master_proto_cc_headers_only", "//tensorflow/core/protobuf:worker_proto_cc_headers_only", - "@local_tsl//tsl/protobuf:coordination_service_proto_cc_headers_only", + "@local_xla//xla/tsl/protobuf:coordination_service_proto_cc_headers_only", ], ), ) @@ -992,13 +992,13 @@ tf_python_pybind_extension( "@pybind11", ] + if_static( extra_deps = [ - "@local_tsl//tsl/protobuf:coordination_service_proto_cc", + "@local_xla//xla/tsl/protobuf:coordination_service_proto_cc", "//tensorflow/core/protobuf:eager_service_proto_cc", "//tensorflow/core/protobuf:master_proto_cc", "//tensorflow/core/protobuf:worker_proto_cc", ], otherwise = [ - "@local_tsl//tsl/protobuf:coordination_service_proto_cc_headers_only", + "@local_xla//xla/tsl/protobuf:coordination_service_proto_cc_headers_only", "//tensorflow/core/protobuf:eager_service_proto_cc_headers_only", "//tensorflow/core/protobuf:master_proto_cc_headers_only", "//tensorflow/core/protobuf:worker_proto_cc_headers_only", @@ -1143,13 +1143,13 @@ tf_python_pybind_extension( "@pybind11", ] + if_static( extra_deps = [ - "@local_tsl//tsl/protobuf:coordination_service_proto_cc", + "@local_xla//xla/tsl/protobuf:coordination_service_proto_cc", "//tensorflow/core/protobuf:eager_service_proto_cc", "//tensorflow/core/protobuf:master_proto_cc", "//tensorflow/core/protobuf:worker_proto_cc", ], otherwise = [ - "@local_tsl//tsl/protobuf:coordination_service_proto_cc_headers_only", + "@local_xla//xla/tsl/protobuf:coordination_service_proto_cc_headers_only", "//tensorflow/core/protobuf:eager_service_proto_cc_headers_only", "//tensorflow/core/protobuf:master_proto_cc_headers_only", "//tensorflow/core/protobuf:worker_proto_cc_headers_only", diff --git a/tensorflow/python/proto_exports.py b/tensorflow/python/proto_exports.py index c414936539df3d..34475ffb3a15f7 100644 --- a/tensorflow/python/proto_exports.py +++ b/tensorflow/python/proto_exports.py @@ -14,6 +14,7 @@ # ============================================================================== """Registers protos with tf_export that should be public.""" +from xla.tsl.protobuf import histogram_pb2 from tensorflow.core.framework import attr_value_pb2 from tensorflow.core.framework import node_def_pb2 from tensorflow.core.framework import summary_pb2 @@ -21,7 +22,6 @@ from tensorflow.core.protobuf import meta_graph_pb2 from tensorflow.core.util import event_pb2 from tensorflow.python.util import tf_export -from tsl.protobuf import histogram_pb2 AttrValue = tf_export.tf_export(v1=['AttrValue'])(attr_value_pb2.AttrValue) ConfigProto = tf_export.tf_export(v1=['ConfigProto'])(config_pb2.ConfigProto) diff --git a/tensorflow/python/training/server_lib_test.py b/tensorflow/python/training/server_lib_test.py index 2478ab64e6be80..7bfddf38185b5f 100644 --- a/tensorflow/python/training/server_lib_test.py +++ b/tensorflow/python/training/server_lib_test.py @@ -18,6 +18,7 @@ import numpy as np +from xla.tsl.protobuf import rpc_options_pb2 from tensorflow.core.protobuf import cluster_pb2 from tensorflow.core.protobuf import config_pb2 from tensorflow.core.protobuf import tensorflow_server_pb2 @@ -35,7 +36,6 @@ from tensorflow.python.training import input as input_ops from tensorflow.python.training import queue_runner_impl from tensorflow.python.training import server_lib -from tsl.protobuf import rpc_options_pb2 class GrpcServerTest(test.TestCase): diff --git a/third_party/xla/third_party/tsl/tsl/platform/BUILD b/third_party/xla/third_party/tsl/tsl/platform/BUILD index 88ee88876195ef..dfe09b5a36e1b9 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/BUILD +++ b/third_party/xla/third_party/tsl/tsl/platform/BUILD @@ -323,7 +323,7 @@ cc_library( deps = [ ":status", "//tsl/protobuf:error_codes_proto_impl_cc", - "//tsl/protobuf:status_proto_cc", + "@local_xla//xla/tsl/protobuf:status_proto_cc", ] + tf_platform_deps("status"), ) @@ -1349,10 +1349,10 @@ tsl_cc_test( ":test", ":test_main", "//tsl/protobuf:error_codes_proto_impl_cc", - "//tsl/protobuf:status_proto_cc", "@com_google_absl//absl/status", "@com_google_absl//absl/strings:cord", "@com_google_absl//absl/strings:str_format", + "@local_xla//xla/tsl/protobuf:status_proto_cc", ], ) diff --git a/third_party/xla/third_party/tsl/tsl/platform/default/build_config.bzl b/third_party/xla/third_party/tsl/tsl/platform/default/build_config.bzl index 726f54634c5661..32b91de84c7e93 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/default/build_config.bzl +++ b/third_party/xla/third_party/tsl/tsl/platform/default/build_config.bzl @@ -726,7 +726,7 @@ def tf_lib_proto_parsing_deps(): return [ ":protos_all_cc", clean_dep("@eigen_archive//:eigen3"), - clean_dep("//tsl/protobuf:protos_all_cc"), + clean_dep("@local_xla//xla/tsl/protobuf:protos_all_cc"), ] def tf_py_clif_cc(name, visibility = None, **kwargs): @@ -779,8 +779,8 @@ def tsl_cc_test( # TODO(ddunleavy) remove these and add proto deps to tests # granularly clean_dep("//tsl/protobuf:error_codes_proto_impl_cc_impl"), - clean_dep("//tsl/protobuf:histogram_proto_cc_impl"), - clean_dep("//tsl/protobuf:status_proto_cc_impl"), + clean_dep("@local_xla//xla/tsl/protobuf:histogram_proto_cc_impl"), + clean_dep("@local_xla//xla/tsl/protobuf:status_proto_cc_impl"), clean_dep("//tsl/profiler/protobuf:xplane_proto_cc_impl"), clean_dep("//tsl/profiler/protobuf:profiler_options_proto_cc_impl"), ], @@ -789,7 +789,7 @@ def tsl_cc_test( ) def tf_portable_proto_lib(): - return ["//tensorflow/core:protos_all_cc_impl", clean_dep("//tsl/protobuf:protos_all_cc_impl")] + return ["//tensorflow/core:protos_all_cc_impl", clean_dep("@local_xla//xla/tsl/protobuf:protos_all_cc_impl")] def tf_protobuf_compiler_deps(): return if_static( diff --git a/third_party/xla/third_party/tsl/tsl/platform/status_test.cc b/third_party/xla/third_party/tsl/tsl/platform/status_test.cc index 6d9948fa68d99b..fbfdab8910d818 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/status_test.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/status_test.cc @@ -18,13 +18,13 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/cord.h" #include "absl/strings/str_format.h" +#include "xla/tsl/protobuf/status.pb.h" #include "tsl/platform/errors.h" #include "tsl/platform/stack_frame.h" #include "tsl/platform/status_matchers.h" #include "tsl/platform/status_to_from_proto.h" #include "tsl/platform/test.h" #include "tsl/protobuf/error_codes.pb.h" -#include "tsl/protobuf/status.pb.h" namespace tsl { namespace { diff --git a/third_party/xla/third_party/tsl/tsl/platform/status_to_from_proto.cc b/third_party/xla/third_party/tsl/tsl/platform/status_to_from_proto.cc index 96ad290f92c71a..e83fa7d1bbc223 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/status_to_from_proto.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/status_to_from_proto.cc @@ -16,9 +16,9 @@ limitations under the License. #include +#include "xla/tsl/protobuf/status.pb.h" #include "tsl/platform/status.h" #include "tsl/protobuf/error_codes.pb.h" -#include "tsl/protobuf/status.pb.h" namespace tsl { diff --git a/third_party/xla/third_party/tsl/tsl/platform/status_to_from_proto.h b/third_party/xla/third_party/tsl/tsl/platform/status_to_from_proto.h index 9891737f08159c..021e002ae4041d 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/status_to_from_proto.h +++ b/third_party/xla/third_party/tsl/tsl/platform/status_to_from_proto.h @@ -15,8 +15,8 @@ limitations under the License. #ifndef TENSORFLOW_TSL_PLATFORM_STATUS_TO_FROM_PROTO_H_ #define TENSORFLOW_TSL_PLATFORM_STATUS_TO_FROM_PROTO_H_ +#include "xla/tsl/protobuf/status.pb.h" #include "tsl/platform/status.h" -#include "tsl/protobuf/status.pb.h" namespace tsl { diff --git a/third_party/xla/third_party/tsl/tsl/protobuf/BUILD b/third_party/xla/third_party/tsl/tsl/protobuf/BUILD index 6922da086f24f4..fdf0014617a391 100644 --- a/third_party/xla/third_party/tsl/tsl/protobuf/BUILD +++ b/third_party/xla/third_party/tsl/tsl/protobuf/BUILD @@ -1,4 +1,3 @@ -# Placeholder: load py_proto_library load( "@local_xla//xla/tsl:tsl.bzl", "if_google", @@ -20,14 +19,6 @@ package( licenses = ["notice"], ) -tf_proto_library( - name = "dnn_proto", - srcs = ["dnn.proto"], - make_default_target_header_only = True, - protodeps = if_google(["//google/protobuf:wrappers"]), - visibility = ["//visibility:public"], -) - tf_proto_library( name = "error_codes_proto_impl", srcs = ["error_codes.proto"], @@ -35,78 +26,3 @@ tf_proto_library( protodeps = if_google(["//google/protobuf:any"]), visibility = ["//visibility:public"], ) - -tf_proto_library( - name = "status_proto", - srcs = ["status.proto"], - make_default_target_header_only = True, - protodeps = [":error_codes_proto_impl"], - visibility = ["//visibility:public"], -) - -tf_proto_library( - name = "histogram_proto", - srcs = ["histogram.proto"], - make_default_target_header_only = True, - visibility = ["//visibility:public"], -) - -tf_proto_library( - name = "coordination_config_proto", - srcs = ["coordination_config.proto"], - make_default_target_header_only = True, - visibility = ["//visibility:public"], -) - -tf_proto_library( - name = "coordination_service_proto", - srcs = ["coordination_service.proto"], - has_services = 1, - create_grpc_library = True, - create_java_proto = False, - create_service = True, - protodeps = if_google(["//google/protobuf:any"]), - visibility = ["//visibility:public"], -) - -# copybara:uncomment_begin(google-only) -# py_proto_library( -# name = "coordination_service_py_pb2", -# api_version = 2, -# visibility = ["//visibility:public"], -# deps = [":coordination_service_proto"], -# ) -# copybara:uncomment_end - -tf_proto_library( - name = "distributed_runtime_payloads_proto", - srcs = ["distributed_runtime_payloads.proto"], - make_default_target_header_only = True, - visibility = ["//visibility:public"], -) - -tf_proto_library( - name = "rpc_options_proto", - srcs = ["rpc_options.proto"], - make_default_target_header_only = True, - visibility = ["//visibility:public"], -) - -tf_proto_library( - name = "protos_all", - create_go_proto = False, - make_default_target_header_only = True, - protodeps = [ - # TODO(tlongeri): Conceptually, these fit into protos_all but adding them currently causes - # breakages (and they are not actually used). - "@local_xla//xla/tsl/protobuf:bfc_memory_map_proto", - ":coordination_config_proto", - ":distributed_runtime_payloads_proto", - ":error_codes_proto_impl", - ":histogram_proto", - ":rpc_options_proto", - ":status_proto", - "@local_xla//xla/tsl/protobuf:test_log_proto", - ] + if_google(["//google/protobuf:any"]), - visibility = ["//visibility:public"], -) diff --git a/third_party/xla/xla/BUILD b/third_party/xla/xla/BUILD index 039f05333174db..64553ebfafb0ee 100644 --- a/third_party/xla/xla/BUILD +++ b/third_party/xla/xla/BUILD @@ -1248,7 +1248,7 @@ tf_proto_library( srcs = ["autotuning.proto"], make_default_target_header_only = True, protodeps = [ - "@local_tsl//tsl/protobuf:dnn_proto", + "//xla/tsl/protobuf:dnn_proto", ] + if_google([ "@com_google_protobuf//:any", "@com_google_protobuf//:duration", diff --git a/third_party/xla/xla/autotuning.proto b/third_party/xla/xla/autotuning.proto index 4cadf6dbb250eb..b3d6b8e380b4b8 100644 --- a/third_party/xla/xla/autotuning.proto +++ b/third_party/xla/xla/autotuning.proto @@ -9,7 +9,7 @@ package xla; import "google/protobuf/any.proto"; import "google/protobuf/duration.proto"; -import "tsl/protobuf/dnn.proto"; +import "xla/tsl/protobuf/dnn.proto"; message CudnnVersion { int32 major = 1; diff --git a/third_party/xla/xla/pjrt/distributed/BUILD b/third_party/xla/xla/pjrt/distributed/BUILD index 481ddfa0cabc29..dc4933c984ddb4 100644 --- a/third_party/xla/xla/pjrt/distributed/BUILD +++ b/third_party/xla/xla/pjrt/distributed/BUILD @@ -28,6 +28,7 @@ cc_library( "//xla/tsl/distributed_runtime/coordination:coordination_service_impl", "//xla/tsl/distributed_runtime/rpc:async_service_interface", "//xla/tsl/distributed_runtime/rpc/coordination:grpc_coordination_service_impl", + "//xla/tsl/protobuf:coordination_config_proto_cc", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/memory", "@com_google_absl//absl/status:statusor", @@ -37,7 +38,6 @@ cc_library( "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:random", - "@local_tsl//tsl/protobuf:coordination_config_proto_cc", ] + tsl_grpc_cc_dependencies(), ) @@ -73,6 +73,8 @@ cc_library( "//xla/tsl/distributed_runtime/coordination:coordination_client", "//xla/tsl/distributed_runtime/coordination:coordination_service_agent", "//xla/tsl/distributed_runtime/rpc/coordination:grpc_coordination_client", + "//xla/tsl/protobuf:coordination_config_proto_cc", + "//xla/tsl/protobuf:coordination_service_proto_cc", "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -81,8 +83,6 @@ cc_library( "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/protobuf:coordination_config_proto_cc", - "@local_tsl//tsl/protobuf:coordination_service_proto_cc", ] + tsl_grpc_cc_dependencies(), ) diff --git a/third_party/xla/xla/pjrt/distributed/client.cc b/third_party/xla/xla/pjrt/distributed/client.cc index 0f4f7fff9d809d..e447afccd73c01 100644 --- a/third_party/xla/xla/pjrt/distributed/client.cc +++ b/third_party/xla/xla/pjrt/distributed/client.cc @@ -36,9 +36,9 @@ limitations under the License. #include "xla/tsl/distributed_runtime/coordination/coordination_client.h" #include "xla/tsl/distributed_runtime/coordination/coordination_service_agent.h" #include "xla/tsl/distributed_runtime/rpc/coordination/grpc_coordination_client.h" +#include "xla/tsl/protobuf/coordination_config.pb.h" +#include "xla/tsl/protobuf/coordination_service.pb.h" #include "tsl/platform/statusor.h" -#include "tsl/protobuf/coordination_config.pb.h" -#include "tsl/protobuf/coordination_service.pb.h" namespace xla { diff --git a/third_party/xla/xla/pjrt/distributed/service.cc b/third_party/xla/xla/pjrt/distributed/service.cc index 6a8a77a5fca534..51729532b63709 100644 --- a/third_party/xla/xla/pjrt/distributed/service.cc +++ b/third_party/xla/xla/pjrt/distributed/service.cc @@ -24,10 +24,10 @@ limitations under the License. #include "xla/tsl/distributed_runtime/coordination/coordination_service.h" #include "xla/tsl/distributed_runtime/rpc/async_service_interface.h" #include "xla/tsl/distributed_runtime/rpc/coordination/grpc_coordination_service_impl.h" +#include "xla/tsl/protobuf/coordination_config.pb.h" #include "xla/util.h" #include "tsl/platform/env.h" #include "tsl/platform/threadpool.h" -#include "tsl/protobuf/coordination_config.pb.h" namespace { diff --git a/third_party/xla/xla/python/ifrt/BUILD b/third_party/xla/xla/python/ifrt/BUILD index 6665d25c6cc013..51795c0f3b32ba 100644 --- a/third_party/xla/xla/python/ifrt/BUILD +++ b/third_party/xla/xla/python/ifrt/BUILD @@ -651,10 +651,10 @@ xla_cc_test( ":serdes", ":serdes_proto_cc", "//xla/tsl/lib/core:status_test_util", + "//xla/tsl/protobuf:status_proto_cc", "@com_google_googletest//:gtest_main", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/protobuf:error_codes_proto_impl_cc", - "@local_tsl//tsl/protobuf:status_proto_cc", ], ) diff --git a/third_party/xla/xla/python/ifrt/plugin_program_serdes_test.cc b/third_party/xla/xla/python/ifrt/plugin_program_serdes_test.cc index 4edfae40571cae..163b677e372ad1 100644 --- a/third_party/xla/xla/python/ifrt/plugin_program_serdes_test.cc +++ b/third_party/xla/xla/python/ifrt/plugin_program_serdes_test.cc @@ -19,9 +19,9 @@ #include "xla/python/ifrt/serdes.h" #include "xla/python/ifrt/serdes.pb.h" #include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/protobuf/status.pb.h" #include "tsl/platform/statusor.h" #include "tsl/protobuf/error_codes.pb.h" -#include "tsl/protobuf/status.pb.h" namespace xla { namespace ifrt { diff --git a/third_party/xla/xla/python/ifrt_proxy/client/BUILD b/third_party/xla/xla/python/ifrt_proxy/client/BUILD index 8cf33e364e988f..ac827cfa0c643d 100644 --- a/third_party/xla/xla/python/ifrt_proxy/client/BUILD +++ b/third_party/xla/xla/python/ifrt_proxy/client/BUILD @@ -433,6 +433,7 @@ cc_library( "//xla/python/ifrt", "//xla/python/ifrt_proxy/common:grpc_ifrt_service_cc_grpc_proto", "//xla/python/ifrt_proxy/common:grpc_ifrt_service_proto_cc", + "//xla/tsl/protobuf:status_proto_cc", "@com_github_grpc_grpc//:grpc++", "@com_google_absl//absl/log", "@com_google_absl//absl/status", @@ -441,7 +442,6 @@ cc_library( "@com_google_absl//absl/strings:string_view", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:unbounded_work_queue", - "@local_tsl//tsl/protobuf:status_proto_cc", ], ) diff --git a/third_party/xla/xla/python/ifrt_proxy/client/grpc_host_buffer.cc b/third_party/xla/xla/python/ifrt_proxy/client/grpc_host_buffer.cc index 442381ef0abd50..b80f84b0593779 100644 --- a/third_party/xla/xla/python/ifrt_proxy/client/grpc_host_buffer.cc +++ b/third_party/xla/xla/python/ifrt_proxy/client/grpc_host_buffer.cc @@ -33,9 +33,9 @@ #include "xla/python/ifrt/future.h" #include "xla/python/ifrt_proxy/common/grpc_ifrt_service.grpc.pb.h" #include "xla/python/ifrt_proxy/common/grpc_ifrt_service.pb.h" +#include "xla/tsl/protobuf/status.pb.h" #include "tsl/platform/env.h" #include "tsl/platform/unbounded_work_queue.h" -#include "tsl/protobuf/status.pb.h" namespace xla { namespace ifrt { diff --git a/third_party/xla/xla/python/ifrt_proxy/common/BUILD b/third_party/xla/xla/python/ifrt_proxy/common/BUILD index 724e22bb0659c4..23859f59e1ad8f 100644 --- a/third_party/xla/xla/python/ifrt_proxy/common/BUILD +++ b/third_party/xla/xla/python/ifrt_proxy/common/BUILD @@ -76,7 +76,7 @@ tf_proto_library( "//xla/python/ifrt:serdes_proto", "//xla/python/ifrt:shape_proto", "//xla/python/ifrt:sharding_proto", - "@local_tsl//tsl/protobuf:status_proto", + "//xla/tsl/protobuf:status_proto", ], ) diff --git a/third_party/xla/xla/python/ifrt_proxy/common/ifrt_service.proto b/third_party/xla/xla/python/ifrt_proxy/common/ifrt_service.proto index 3a047542402488..4a342c253af9cc 100644 --- a/third_party/xla/xla/python/ifrt_proxy/common/ifrt_service.proto +++ b/third_party/xla/xla/python/ifrt_proxy/common/ifrt_service.proto @@ -25,8 +25,8 @@ import "xla/python/ifrt/serdes.proto"; import "xla/python/ifrt/shape.proto"; import "xla/python/ifrt/sharding.proto"; import "xla/python/ifrt_proxy/common/types.proto"; +import "xla/tsl/protobuf/status.proto"; import "xla/xla_data.proto"; -import "tsl/protobuf/status.proto"; option cc_enable_arenas = true; diff --git a/third_party/xla/xla/python/ifrt_proxy/server/BUILD b/third_party/xla/xla/python/ifrt_proxy/server/BUILD index d426fd428fe62c..6eb717ed18dc5e 100644 --- a/third_party/xla/xla/python/ifrt_proxy/server/BUILD +++ b/third_party/xla/xla/python/ifrt_proxy/server/BUILD @@ -183,6 +183,7 @@ ifrt_proxy_cc_test( "//xla/service:computation_placer_hdr", "//xla/tsl/concurrency:ref_count", "//xla/tsl/lib/core:status_test_util", + "//xla/tsl/protobuf:status_proto_cc", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log:check", @@ -203,7 +204,6 @@ ifrt_proxy_cc_test( "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/protobuf:error_codes_proto_impl_cc", - "@local_tsl//tsl/protobuf:status_proto_cc", ], ) diff --git a/third_party/xla/xla/python/ifrt_proxy/server/ifrt_backend_test.cc b/third_party/xla/xla/python/ifrt_proxy/server/ifrt_backend_test.cc index 160d6fe2885a61..fa23f80fb43725 100644 --- a/third_party/xla/xla/python/ifrt_proxy/server/ifrt_backend_test.cc +++ b/third_party/xla/xla/python/ifrt_proxy/server/ifrt_backend_test.cc @@ -70,6 +70,7 @@ #include "xla/test.h" #include "xla/tsl/concurrency/ref_count.h" #include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/protobuf/status.pb.h" #include "xla/xla_data.pb.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" @@ -79,7 +80,6 @@ #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" #include "tsl/protobuf/error_codes.pb.h" -#include "tsl/protobuf/status.pb.h" namespace xla { namespace ifrt { diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD index 6fa72a430d8fde..93e581a80c78e6 100644 --- a/third_party/xla/xla/service/BUILD +++ b/third_party/xla/xla/service/BUILD @@ -8475,7 +8475,7 @@ tf_proto_library( make_default_target_header_only = True, protodeps = [ ":hlo_proto", - "@local_tsl//tsl/protobuf:status_proto", + "//xla/tsl/protobuf:status_proto", ] + if_google(["@com_google_protobuf//:duration"]), visibility = ["//visibility:public"], ) diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index 7045d03c85cb73..edc11b7c230bf2 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -73,7 +73,7 @@ tf_proto_library( protodeps = [ "//xla:xla_data_proto", "//xla:autotuning_proto", - "@local_tsl//tsl/protobuf:dnn_proto", + "//xla/tsl/protobuf:dnn_proto", ], ) @@ -392,6 +392,7 @@ cc_library( "//xla/stream_executor:launch_dim", "//xla/stream_executor/gpu:gpu_blas_lt", "//xla/stream_executor/integrations:device_mem_allocator", + "//xla/tsl/protobuf:dnn_proto_cc", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", @@ -419,7 +420,6 @@ cc_library( "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:human_readable_json", "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/protobuf:dnn_proto_cc", "@triton//:TritonDialects", ] + if_gpu_is_configured([ "//xla/service/gpu/runtime:cholesky_thunk", @@ -2290,6 +2290,7 @@ cc_library( "//xla/stream_executor:kernel_spec", "//xla/stream_executor:launch_dim", "//xla/stream_executor:typed_kernel_factory", + "//xla/tsl/protobuf:dnn_proto_cc", "//xla/tsl/util:env_var", "//xla/tsl/util/proto:proto_utils", "@com_google_absl//absl/algorithm:container", @@ -2306,7 +2307,6 @@ cc_library( "@local_tsl//tsl/platform:ml_dtypes", "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/protobuf:dnn_proto_cc", ] + if_cuda_is_configured([ ":stream_executor_util_kernel", ]), diff --git a/third_party/xla/xla/service/gpu/autotuning/BUILD b/third_party/xla/xla/service/gpu/autotuning/BUILD index 42d9d85bf528a9..7a51cf0c3c8e11 100644 --- a/third_party/xla/xla/service/gpu/autotuning/BUILD +++ b/third_party/xla/xla/service/gpu/autotuning/BUILD @@ -325,12 +325,12 @@ xla_test( "//xla/stream_executor:semantic_version", "//xla/tests:hlo_test_base", "//xla/tsl/lib/core:status_test_util", + "//xla/tsl/protobuf:dnn_proto_cc", "@com_google_absl//absl/log", "@com_google_absl//absl/strings:string_view", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_main", - "@local_tsl//tsl/protobuf:dnn_proto_cc", ], ) diff --git a/third_party/xla/xla/service/gpu/autotuning/gemm_algorithm_picker_test.cc b/third_party/xla/xla/service/gpu/autotuning/gemm_algorithm_picker_test.cc index b88c2b0916ce08..906c908d8b281d 100644 --- a/third_party/xla/xla/service/gpu/autotuning/gemm_algorithm_picker_test.cc +++ b/third_party/xla/xla/service/gpu/autotuning/gemm_algorithm_picker_test.cc @@ -35,10 +35,10 @@ limitations under the License. #include "xla/stream_executor/semantic_version.h" #include "xla/tests/hlo_test_base.h" #include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/protobuf/dnn.pb.h" #include "xla/xla.pb.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" -#include "tsl/protobuf/dnn.pb.h" namespace xla::gpu { namespace { diff --git a/third_party/xla/xla/service/gpu/backend_configs.proto b/third_party/xla/xla/service/gpu/backend_configs.proto index b23fe4f95629a7..b77d302d64e9ae 100644 --- a/third_party/xla/xla/service/gpu/backend_configs.proto +++ b/third_party/xla/xla/service/gpu/backend_configs.proto @@ -3,8 +3,8 @@ syntax = "proto3"; package xla.gpu; import "xla/autotuning.proto"; +import "xla/tsl/protobuf/dnn.proto"; import "xla/xla_data.proto"; -import "tsl/protobuf/dnn.proto"; // Backend configs for XLA:GPU. // diff --git a/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc b/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc index b4406783adcb0e..b7da6271befc9f 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc +++ b/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc @@ -159,12 +159,12 @@ limitations under the License. #include "xla/stream_executor/gpu/gpu_blas_lt.h" #include "xla/stream_executor/integrations/device_mem_allocator.h" #include "xla/stream_executor/launch_dim.h" +#include "xla/tsl/protobuf/dnn.pb.h" #include "xla/util.h" #include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" #include "tsl/platform/human_readable_json.h" #include "tsl/platform/statusor.h" -#include "tsl/protobuf/dnn.pb.h" #include "triton/Dialect/Triton/IR/Dialect.h" #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM diff --git a/third_party/xla/xla/service/gpu/stream_executor_util.cc b/third_party/xla/xla/service/gpu/stream_executor_util.cc index 5a515a8a2d5ce8..7f8bed25afd430 100644 --- a/third_party/xla/xla/service/gpu/stream_executor_util.cc +++ b/third_party/xla/xla/service/gpu/stream_executor_util.cc @@ -58,13 +58,13 @@ limitations under the License. #include "xla/stream_executor/platform.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/typed_kernel_factory.h" +#include "xla/tsl/protobuf/dnn.pb.h" #include "xla/tsl/util/proto/proto_utils.h" #include "xla/util.h" #include "xla/xla_data.pb.h" #include "tsl/platform/ml_dtypes.h" #include "tsl/platform/status.h" #include "tsl/platform/statusor.h" -#include "tsl/protobuf/dnn.pb.h" namespace xla { namespace gpu { diff --git a/third_party/xla/xla/service/gpu/stream_executor_util.h b/third_party/xla/xla/service/gpu/stream_executor_util.h index 877ad8bcc62f48..d0338595f9f17d 100644 --- a/third_party/xla/xla/service/gpu/stream_executor_util.h +++ b/third_party/xla/xla/service/gpu/stream_executor_util.h @@ -37,8 +37,8 @@ limitations under the License. #include "xla/stream_executor/kernel_spec.h" #include "xla/stream_executor/launch_dim.h" #include "xla/stream_executor/stream_executor.h" +#include "xla/tsl/protobuf/dnn.pb.h" #include "xla/xla_data.pb.h" -#include "tsl/protobuf/dnn.pb.h" // Helper functions for interacting with StreamExecutor. diff --git a/third_party/xla/xla/service/gpu/transforms/BUILD b/third_party/xla/xla/service/gpu/transforms/BUILD index e70d040a304457..1a04da032cfce6 100644 --- a/third_party/xla/xla/service/gpu/transforms/BUILD +++ b/third_party/xla/xla/service/gpu/transforms/BUILD @@ -987,6 +987,7 @@ cc_library( "//xla/service/gpu:backend_configs_cc", "//xla/service/gpu:cublas_cudnn", "//xla/stream_executor", + "//xla/tsl/protobuf:dnn_proto_cc", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", @@ -997,7 +998,6 @@ cc_library( "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/protobuf:dnn_proto_cc", ] + if_cuda_is_configured([ "@local_config_cuda//cuda:cuda_headers", "@local_config_cuda//cuda:cudnn_header", @@ -1690,6 +1690,7 @@ cc_library( "//xla/stream_executor:device_description", "//xla/stream_executor:semantic_version", "//xla/stream_executor/gpu:gpu_blas_lt", + "//xla/tsl/protobuf:dnn_proto_cc", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", @@ -1701,7 +1702,6 @@ cc_library( "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:ml_dtypes", "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/protobuf:dnn_proto_cc", ], ) diff --git a/third_party/xla/xla/service/gpu/transforms/cudnn_norm_rewriter.cc b/third_party/xla/xla/service/gpu/transforms/cudnn_norm_rewriter.cc index 5d5e089933fd88..752549dc7ec501 100644 --- a/third_party/xla/xla/service/gpu/transforms/cudnn_norm_rewriter.cc +++ b/third_party/xla/xla/service/gpu/transforms/cudnn_norm_rewriter.cc @@ -43,12 +43,12 @@ limitations under the License. #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/stream_executor/device_description.h" +#include "xla/tsl/protobuf/dnn.pb.h" #include "xla/types.h" #include "xla/util.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" #include "tsl/platform/statusor.h" -#include "tsl/protobuf/dnn.pb.h" #if GOOGLE_CUDA #include "third_party/gpus/cuda/include/cuda.h" // IWYU pragma: keep diff --git a/third_party/xla/xla/service/gpu/transforms/gemm_rewriter.cc b/third_party/xla/xla/service/gpu/transforms/gemm_rewriter.cc index 32ed147415b962..9ba39cf977c0bb 100644 --- a/third_party/xla/xla/service/gpu/transforms/gemm_rewriter.cc +++ b/third_party/xla/xla/service/gpu/transforms/gemm_rewriter.cc @@ -63,13 +63,13 @@ limitations under the License. #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/gpu/gpu_blas_lt.h" #include "xla/stream_executor/semantic_version.h" +#include "xla/tsl/protobuf/dnn.pb.h" #include "xla/types.h" #include "xla/util.h" #include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" #include "tsl/platform/ml_dtypes.h" #include "tsl/platform/statusor.h" -#include "tsl/protobuf/dnn.pb.h" namespace xla { namespace gpu { diff --git a/third_party/xla/xla/service/xla_compile_result.proto b/third_party/xla/xla/service/xla_compile_result.proto index 5846b8b11bacc2..7634661596a5ad 100644 --- a/third_party/xla/xla/service/xla_compile_result.proto +++ b/third_party/xla/xla/service/xla_compile_result.proto @@ -19,7 +19,7 @@ package xla; import "google/protobuf/duration.proto"; import "xla/service/hlo.proto"; -import "tsl/protobuf/status.proto"; +import "xla/tsl/protobuf/status.proto"; // Statistics on how long various parts of compilation took. // Not all durations may be relevant for all producers of this message, in diff --git a/third_party/xla/xla/stream_executor/BUILD b/third_party/xla/xla/stream_executor/BUILD index 295abcb3cd7740..b0e8b0bdf5d299 100644 --- a/third_party/xla/xla/stream_executor/BUILD +++ b/third_party/xla/xla/stream_executor/BUILD @@ -108,6 +108,7 @@ cc_library( "//xla/tsl/framework:device_id", "//xla/tsl/framework:device_type", "//xla/tsl/lib/gtl:int_type", + "//xla/tsl/protobuf:dnn_proto_cc", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:inlined_vector", @@ -130,7 +131,6 @@ cc_library( "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:thread_annotations", "@local_tsl//tsl/platform:types", - "@local_tsl//tsl/protobuf:dnn_proto_cc", ] + if_static([ ":stream_executor_impl", ]) + if_google([ @@ -282,8 +282,8 @@ cc_library( name = "data_type", hdrs = ["data_type.h"], deps = [ + "//xla/tsl/protobuf:dnn_proto_cc", "@local_tsl//tsl/platform:ml_dtypes", - "@local_tsl//tsl/protobuf:dnn_proto_cc", ], ) @@ -412,12 +412,12 @@ cc_library( ":device_memory", ":numeric_options", "//xla/stream_executor/platform", + "//xla/tsl/protobuf:dnn_proto_cc", "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/protobuf:dnn_proto_cc", ], ) @@ -432,6 +432,7 @@ cc_library( ":numeric_options", "//xla/stream_executor/platform", "//xla/tsl/lib/strings:proto_serialization", + "//xla/tsl/protobuf:dnn_proto_cc", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:btree", "@com_google_absl//absl/container:flat_hash_map", @@ -447,7 +448,6 @@ cc_library( "@local_tsl//tsl/platform:ml_dtypes", "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/protobuf:dnn_proto_cc", ] + if_google(["@com_google_protobuf//:wrappers_cc_proto"]), ) @@ -466,11 +466,11 @@ cc_library( ":dnn", ":stream", ":stream_executor_h", + "//xla/tsl/protobuf:dnn_proto_cc", "@com_google_absl//absl/base", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/protobuf:dnn_proto_cc", ], ) @@ -791,7 +791,7 @@ cc_library( ":stream_common", ":stream_executor_common", ":stream_executor_h", - ] + if_oss(["@local_tsl//tsl/protobuf:dnn_proto_cc_impl"]), + ] + if_oss(["//xla/tsl/protobuf:dnn_proto_cc_impl"]), ) #===--------------------------------------------------------------------------------------------===# diff --git a/third_party/xla/xla/stream_executor/blas.h b/third_party/xla/xla/stream_executor/blas.h index cf235c83a9729c..73814f0e467a3e 100644 --- a/third_party/xla/xla/stream_executor/blas.h +++ b/third_party/xla/xla/stream_executor/blas.h @@ -35,8 +35,8 @@ limitations under the License. #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/numeric_options.h" #include "xla/stream_executor/platform/port.h" +#include "xla/tsl/protobuf/dnn.pb.h" #include "tsl/platform/errors.h" -#include "tsl/protobuf/dnn.pb.h" namespace Eigen { struct half; diff --git a/third_party/xla/xla/stream_executor/cuda/BUILD b/third_party/xla/xla/stream_executor/cuda/BUILD index 68c13a736099fb..a94edde0814b72 100644 --- a/third_party/xla/xla/stream_executor/cuda/BUILD +++ b/third_party/xla/xla/stream_executor/cuda/BUILD @@ -379,6 +379,7 @@ cuda_only_cc_library( "//xla/stream_executor/platform", "//xla/tsl/cuda:cublas", "//xla/tsl/cuda:cublas_lt", + "//xla/tsl/protobuf:dnn_proto_cc", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/log", "@com_google_absl//absl/status", @@ -395,7 +396,6 @@ cuda_only_cc_library( "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:tensor_float_32_hdr_lib", - "@local_tsl//tsl/protobuf:dnn_proto_cc", ] + if_static([ "@local_tsl//tsl/platform:tensor_float_32_utils", ]), @@ -495,6 +495,7 @@ cuda_only_cc_library( "//xla/stream_executor/gpu:scoped_activate_context", "//xla/stream_executor/platform", "//xla/tsl/cuda:cudnn", + "//xla/tsl/protobuf:dnn_proto_cc", "//xla/tsl/util:env_var", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", @@ -519,7 +520,6 @@ cuda_only_cc_library( "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:tensor_float_32_hdr_lib", "@local_tsl//tsl/platform:tensor_float_32_utils", - "@local_tsl//tsl/protobuf:dnn_proto_cc", ], alwayslink = True, ) diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_blas.cc b/third_party/xla/xla/stream_executor/cuda/cuda_blas.cc index f6628ab0edfad4..297452d7c3ea78 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_blas.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_blas.cc @@ -54,11 +54,11 @@ limitations under the License. #include "xla/stream_executor/plugin_registry.h" #include "xla/stream_executor/scratch_allocator.h" #include "xla/stream_executor/stream_executor.h" +#include "xla/tsl/protobuf/dnn.pb.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" #include "tsl/platform/statusor.h" #include "tsl/platform/tensor_float_32_utils.h" -#include "tsl/protobuf/dnn.pb.h" namespace stream_executor { namespace cuda { diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_dnn.cc b/third_party/xla/xla/stream_executor/cuda/cuda_dnn.cc index 6d09f0f627ad91..29f7ddd1754df2 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_dnn.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_dnn.cc @@ -67,13 +67,13 @@ limitations under the License. #include "xla/stream_executor/scratch_allocator.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" +#include "xla/tsl/protobuf/dnn.pb.h" #include "xla/tsl/util/env_var.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" #include "tsl/platform/status.h" #include "tsl/platform/statusor.h" #include "tsl/platform/tensor_float_32_utils.h" -#include "tsl/protobuf/dnn.pb.h" // clang-format off #include "third_party/gpus/cuda/include/library_types.h" diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_dnn.h b/third_party/xla/xla/stream_executor/cuda/cuda_dnn.h index 3a223731347766..3eb702b9f4415e 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_dnn.h +++ b/third_party/xla/xla/stream_executor/cuda/cuda_dnn.h @@ -34,7 +34,7 @@ limitations under the License. #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/dnn.h" #include "xla/stream_executor/numeric_options.h" -#include "tsl/protobuf/dnn.pb.h" +#include "xla/tsl/protobuf/dnn.pb.h" #if CUDNN_VERSION >= 8100 #include "third_party/cudnn_frontend/include/cudnn_frontend.h" diff --git a/third_party/xla/xla/stream_executor/data_type.h b/third_party/xla/xla/stream_executor/data_type.h index ebac59ba7c4eae..03b09f9b644f07 100644 --- a/third_party/xla/xla/stream_executor/data_type.h +++ b/third_party/xla/xla/stream_executor/data_type.h @@ -19,8 +19,8 @@ limitations under the License. #include #include +#include "xla/tsl/protobuf/dnn.pb.h" #include "tsl/platform/ml_dtypes.h" -#include "tsl/protobuf/dnn.pb.h" namespace Eigen { struct bfloat16; diff --git a/third_party/xla/xla/stream_executor/dnn.cc b/third_party/xla/xla/stream_executor/dnn.cc index 951b2f6e147cd8..10270d0b3c1be2 100644 --- a/third_party/xla/xla/stream_executor/dnn.cc +++ b/third_party/xla/xla/stream_executor/dnn.cc @@ -42,8 +42,8 @@ limitations under the License. #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/numeric_options.h" #include "xla/tsl/lib/strings/proto_serialization.h" +#include "xla/tsl/protobuf/dnn.pb.h" #include "tsl/platform/ml_dtypes.h" -#include "tsl/protobuf/dnn.pb.h" namespace stream_executor { namespace dnn { diff --git a/third_party/xla/xla/stream_executor/dnn.h b/third_party/xla/xla/stream_executor/dnn.h index b1b89ff1c59d59..e6d3e67c68c87d 100644 --- a/third_party/xla/xla/stream_executor/dnn.h +++ b/third_party/xla/xla/stream_executor/dnn.h @@ -44,8 +44,8 @@ limitations under the License. #include "xla/stream_executor/device_description.pb.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/numeric_options.h" +#include "xla/tsl/protobuf/dnn.pb.h" #include "tsl/platform/logging.h" -#include "tsl/protobuf/dnn.pb.h" namespace Eigen { struct half; diff --git a/third_party/xla/xla/stream_executor/gpu/BUILD b/third_party/xla/xla/stream_executor/gpu/BUILD index 711104c6ffc673..7c0bdb7d3c0361 100644 --- a/third_party/xla/xla/stream_executor/gpu/BUILD +++ b/third_party/xla/xla/stream_executor/gpu/BUILD @@ -657,6 +657,7 @@ cc_library( "//xla/stream_executor", "//xla/stream_executor:blas", "//xla/stream_executor:host_or_device_scalar", + "//xla/tsl/protobuf:dnn_proto_cc", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -664,7 +665,6 @@ cc_library( "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/protobuf:dnn_proto_cc", ] + if_cuda_is_configured([ "@local_tsl//tsl/platform:tensor_float_32_hdr_lib", ]) + if_static([ diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_blas_lt.cc b/third_party/xla/xla/stream_executor/gpu/gpu_blas_lt.cc index 6f931aeb6324fd..6a604e20619455 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_blas_lt.cc +++ b/third_party/xla/xla/stream_executor/gpu/gpu_blas_lt.cc @@ -27,9 +27,9 @@ limitations under the License. #include "xla/stream_executor/blas.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" +#include "xla/tsl/protobuf/dnn.pb.h" #include "xla/util.h" #include "xla/xla_data.pb.h" -#include "tsl/protobuf/dnn.pb.h" #if GOOGLE_CUDA #include "tsl/platform/tensor_float_32_utils.h" #endif diff --git a/third_party/xla/xla/stream_executor/lazy_op_runner.h b/third_party/xla/xla/stream_executor/lazy_op_runner.h index bf964e05bbaae6..f3c8d004397639 100644 --- a/third_party/xla/xla/stream_executor/lazy_op_runner.h +++ b/third_party/xla/xla/stream_executor/lazy_op_runner.h @@ -29,8 +29,8 @@ limitations under the License. #include "xla/stream_executor/dnn.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" +#include "xla/tsl/protobuf/dnn.pb.h" #include "tsl/platform/statusor.h" -#include "tsl/protobuf/dnn.pb.h" namespace stream_executor { namespace dnn { diff --git a/third_party/xla/xla/tools/BUILD b/third_party/xla/xla/tools/BUILD index 49aab5056cd884..09fddc032ccd9a 100644 --- a/third_party/xla/xla/tools/BUILD +++ b/third_party/xla/xla/tools/BUILD @@ -853,6 +853,7 @@ xla_test( "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", # fixdeps: keep "//xla/tsl/lib/core:status_test_util", + "//xla/tsl/protobuf:status_proto_cc", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", "@com_google_googletest//:gtest", @@ -864,7 +865,6 @@ xla_test( "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/protobuf:error_codes_proto_impl_cc", - "@local_tsl//tsl/protobuf:status_proto_cc", ] + if_google(["@com_google_protobuf//:duration_cc_proto"]), ) @@ -896,6 +896,7 @@ xla_test( "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", # fixdeps: keep "//xla/tsl/lib/core:status_test_util", + "//xla/tsl/protobuf:status_proto_cc", "@com_google_googletest//:gtest", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:path", @@ -903,7 +904,6 @@ xla_test( "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/protobuf:error_codes_proto_impl_cc", - "@local_tsl//tsl/protobuf:status_proto_cc", ], ) diff --git a/third_party/xla/xla/tools/xla_cpu_compile_lib_test.cc b/third_party/xla/xla/tools/xla_cpu_compile_lib_test.cc index 62c06734ddb990..fcfa0001c7d1c3 100644 --- a/third_party/xla/xla/tools/xla_cpu_compile_lib_test.cc +++ b/third_party/xla/xla/tools/xla_cpu_compile_lib_test.cc @@ -29,6 +29,7 @@ limitations under the License. #include "xla/tests/hlo_test_base.h" #include "xla/tools/xla_compile_lib.h" #include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/protobuf/status.pb.h" #include "xla/util.h" #include "tsl/platform/env.h" #include "tsl/platform/env_time.h" @@ -38,7 +39,6 @@ limitations under the License. #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" #include "tsl/protobuf/error_codes.pb.h" -#include "tsl/protobuf/status.pb.h" namespace xla { namespace { diff --git a/third_party/xla/xla/tools/xla_gpu_compile_lib_test.cc b/third_party/xla/xla/tools/xla_gpu_compile_lib_test.cc index bc34c8790fb14e..87084469f4cfe2 100644 --- a/third_party/xla/xla/tools/xla_gpu_compile_lib_test.cc +++ b/third_party/xla/xla/tools/xla_gpu_compile_lib_test.cc @@ -30,6 +30,7 @@ limitations under the License. #include "xla/tests/hlo_test_base.h" #include "xla/tools/xla_compile_lib.h" #include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/protobuf/status.pb.h" #include "xla/util.h" #include "tsl/platform/env.h" #include "tsl/platform/path.h" @@ -37,7 +38,6 @@ limitations under the License. #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" #include "tsl/protobuf/error_codes.pb.h" -#include "tsl/protobuf/status.pb.h" namespace xla { namespace { diff --git a/third_party/xla/xla/tsl/distributed_runtime/coordination/BUILD b/third_party/xla/xla/tsl/distributed_runtime/coordination/BUILD index 91e77bda9d79e6..1caa5d166f5ce9 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/coordination/BUILD +++ b/third_party/xla/xla/tsl/distributed_runtime/coordination/BUILD @@ -16,11 +16,11 @@ cc_library( srcs = ["coordination_service_error_util.cc"], hdrs = ["coordination_service_error_util.h"], deps = [ + "//xla/tsl/protobuf:coordination_service_proto_cc", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:cord", "@local_tsl//tsl/platform:regexp", - "@local_tsl//tsl/protobuf:coordination_service_proto_cc", ], ) @@ -29,12 +29,12 @@ tsl_cc_test( srcs = ["coordination_service_error_util_test.cc"], deps = [ ":coordination_service_error_util", + "//xla/tsl/protobuf:coordination_service_proto_cc", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_main", - "@local_tsl//tsl/protobuf:coordination_service_proto_cc", ], ) @@ -43,8 +43,8 @@ cc_library( hdrs = ["coordination_client.h"], deps = [ "//xla/tsl/distributed_runtime:call_options", + "//xla/tsl/protobuf:coordination_service_proto_cc", "@local_tsl//tsl/platform:status", - "@local_tsl//tsl/protobuf:coordination_service_proto_cc", ], ) @@ -53,14 +53,14 @@ cc_library( hdrs = ["coordination_service.h"], deps = [ ":coordination_client", + "//xla/tsl/protobuf:coordination_config_proto_cc", + "//xla/tsl/protobuf:coordination_service_proto_cc", "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/time", "@local_tsl//tsl/platform:macros", "@local_tsl//tsl/platform:status", - "@local_tsl//tsl/protobuf:coordination_config_proto_cc", - "@local_tsl//tsl/protobuf:coordination_service_proto_cc", ], ) @@ -75,6 +75,8 @@ tsl_gpu_library( ":coordination_service", ":coordination_service_error_util", "//xla/tsl/distributed_runtime:call_options", + "//xla/tsl/protobuf:coordination_config_proto_cc", + "//xla/tsl/protobuf:coordination_service_proto_cc", "//xla/tsl/util:device_name_utils", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", @@ -91,8 +93,6 @@ tsl_gpu_library( "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:random", "@local_tsl//tsl/platform:status", - "@local_tsl//tsl/protobuf:coordination_config_proto_cc", - "@local_tsl//tsl/protobuf:coordination_service_proto_cc", ], alwayslink = 1, ) @@ -118,6 +118,8 @@ tsl_cc_test( ":test_device_proto_cc", "//xla/tsl/distributed_runtime:call_options", "//xla/tsl/lib/core:status_test_util", + "//xla/tsl/protobuf:coordination_config_proto_cc", + "//xla/tsl/protobuf:coordination_service_proto_cc", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/log", "@com_google_absl//absl/status", @@ -131,8 +133,6 @@ tsl_cc_test( "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_main", "@local_tsl//tsl/platform:types", - "@local_tsl//tsl/protobuf:coordination_config_proto_cc", - "@local_tsl//tsl/protobuf:coordination_service_proto_cc", ], ) @@ -146,6 +146,8 @@ tsl_gpu_library( "//xla/tsl/distributed_runtime:call_options", "//xla/tsl/framework:cancellation", "//xla/tsl/lib/monitoring:gauge", + "//xla/tsl/protobuf:coordination_config_proto_cc", + "//xla/tsl/protobuf:coordination_service_proto_cc", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/functional:bind_front", "@com_google_absl//absl/log", @@ -159,8 +161,6 @@ tsl_gpu_library( "@local_tsl//tsl/platform:random", "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:thread_annotations", - "@local_tsl//tsl/protobuf:coordination_config_proto_cc", - "@local_tsl//tsl/protobuf:coordination_service_proto_cc", ], ) @@ -172,6 +172,8 @@ tsl_cc_test( ":coordination_service_agent", "//xla/tsl/distributed_runtime:call_options", "//xla/tsl/lib/core:status_test_util", + "//xla/tsl/protobuf:coordination_config_proto_cc_impl", + "//xla/tsl/protobuf:coordination_service_proto_cc_impl", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/memory", @@ -182,8 +184,6 @@ tsl_cc_test( "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_main", - "@local_tsl//tsl/protobuf:coordination_config_proto_cc_impl", - "@local_tsl//tsl/protobuf:coordination_service_proto_cc_impl", ], ) @@ -197,6 +197,7 @@ cc_library( ":coordination_service", ":coordination_service_agent", ":coordination_service_error_util", + "//xla/tsl/protobuf:coordination_service_proto_cc", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -206,7 +207,6 @@ cc_library( "@local_tsl//tsl/platform:protobuf", "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:thread_annotations", - "@local_tsl//tsl/protobuf:coordination_service_proto_cc", ], ) @@ -222,6 +222,9 @@ tsl_cc_test( "//xla/tsl/distributed_runtime/rpc/coordination:grpc_coordination_client", "//xla/tsl/distributed_runtime/rpc/coordination:grpc_coordination_service_impl", "//xla/tsl/lib/core:status_test_util", + "//xla/tsl/protobuf:coordination_config_proto_cc_impl", + "//xla/tsl/protobuf:coordination_service_proto_cc_impl", + "//xla/tsl/protobuf:distributed_runtime_payloads_proto_cc_impl", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log", @@ -233,9 +236,6 @@ tsl_cc_test( "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_main", - "@local_tsl//tsl/protobuf:coordination_config_proto_cc_impl", - "@local_tsl//tsl/protobuf:coordination_service_proto_cc_impl", - "@local_tsl//tsl/protobuf:distributed_runtime_payloads_proto_cc_impl", ] + tsl_grpc_cc_dependencies(), ) diff --git a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_client.h b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_client.h index cea5ba4890d37b..71bc536af63135 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_client.h +++ b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_client.h @@ -20,8 +20,8 @@ limitations under the License. #include #include "xla/tsl/distributed_runtime/call_options.h" +#include "xla/tsl/protobuf/coordination_service.pb.h" #include "tsl/platform/status.h" -#include "tsl/protobuf/coordination_service.pb.h" namespace tsl { using tensorflow::BarrierRequest; diff --git a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service.cc b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service.cc index a6e835c8e6a2d4..de0033e92fb33e 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service.cc +++ b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service.cc @@ -46,12 +46,12 @@ limitations under the License. #include "xla/tsl/distributed_runtime/call_options.h" #include "xla/tsl/distributed_runtime/coordination/coordination_client.h" #include "xla/tsl/distributed_runtime/coordination/coordination_service_error_util.h" +#include "xla/tsl/protobuf/coordination_config.pb.h" +#include "xla/tsl/protobuf/coordination_service.pb.h" #include "xla/tsl/util/device_name_utils.h" #include "tsl/platform/env.h" #include "tsl/platform/random.h" #include "tsl/platform/status.h" -#include "tsl/protobuf/coordination_config.pb.h" -#include "tsl/protobuf/coordination_service.pb.h" namespace tsl { namespace { diff --git a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service.h b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service.h index 9ef96f1f6b425a..75fe592e59a7b1 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service.h +++ b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service.h @@ -30,10 +30,10 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/time/time.h" #include "xla/tsl/distributed_runtime/coordination/coordination_client.h" +#include "xla/tsl/protobuf/coordination_config.pb.h" +#include "xla/tsl/protobuf/coordination_service.pb.h" #include "tsl/platform/macros.h" #include "tsl/platform/status.h" -#include "tsl/protobuf/coordination_config.pb.h" -#include "tsl/protobuf/coordination_service.pb.h" namespace tsl { class Env; diff --git a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_agent.cc b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_agent.cc index 4290fba754f880..b72ae29db85056 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_agent.cc +++ b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_agent.cc @@ -45,12 +45,12 @@ limitations under the License. #include "xla/tsl/distributed_runtime/coordination/coordination_service_error_util.h" #include "xla/tsl/framework/cancellation.h" #include "xla/tsl/lib/monitoring/gauge.h" +#include "xla/tsl/protobuf/coordination_config.pb.h" +#include "xla/tsl/protobuf/coordination_service.pb.h" #include "tsl/platform/env.h" #include "tsl/platform/random.h" #include "tsl/platform/status.h" #include "tsl/platform/thread_annotations.h" -#include "tsl/protobuf/coordination_config.pb.h" -#include "tsl/protobuf/coordination_service.pb.h" namespace tsl { using tensorflow::CoordinatedTask; diff --git a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_agent.h b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_agent.h index 3ec188ac251801..6c58501d51e112 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_agent.h +++ b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_agent.h @@ -28,8 +28,8 @@ limitations under the License. #include "absl/time/time.h" #include "xla/tsl/distributed_runtime/call_options.h" #include "xla/tsl/distributed_runtime/coordination/coordination_client.h" +#include "xla/tsl/protobuf/coordination_service.pb.h" #include "tsl/platform/status.h" -#include "tsl/protobuf/coordination_service.pb.h" namespace tensorflow { class CoordinationServiceConfig; diff --git a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_agent_test.cc b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_agent_test.cc index 1281ea8f78988f..8e6783f0bccc97 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_agent_test.cc +++ b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_agent_test.cc @@ -30,11 +30,11 @@ limitations under the License. #include "xla/tsl/distributed_runtime/call_options.h" #include "xla/tsl/distributed_runtime/coordination/coordination_client.h" #include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/protobuf/coordination_config.pb.h" +#include "xla/tsl/protobuf/coordination_service.pb.h" #include "tsl/platform/env.h" #include "tsl/platform/status.h" #include "tsl/platform/test.h" -#include "tsl/protobuf/coordination_config.pb.h" -#include "tsl/protobuf/coordination_service.pb.h" namespace tsl { namespace { diff --git a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_error_util.h b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_error_util.h index e1a3cdc06eefe9..07df399979a6fa 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_error_util.h +++ b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_error_util.h @@ -18,7 +18,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/cord.h" #include "absl/strings/string_view.h" -#include "tsl/protobuf/coordination_service.pb.h" +#include "xla/tsl/protobuf/coordination_service.pb.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_error_util_test.cc b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_error_util_test.cc index 535f471f0a3fc1..f021ae25a15435 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_error_util_test.cc +++ b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_error_util_test.cc @@ -18,8 +18,8 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/match.h" +#include "xla/tsl/protobuf/coordination_service.pb.h" #include "tsl/platform/test.h" -#include "tsl/protobuf/coordination_service.pb.h" namespace tsl { namespace { using ::tensorflow::CoordinatedTask; diff --git a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_recoverable_job_test.cc b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_recoverable_job_test.cc index 3ec3290c9507e1..737091b1ca7fc3 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_recoverable_job_test.cc +++ b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_recoverable_job_test.cc @@ -32,11 +32,11 @@ limitations under the License. #include "xla/tsl/distributed_runtime/rpc/coordination/grpc_coordination_client.h" #include "xla/tsl/distributed_runtime/rpc/coordination/grpc_coordination_service_impl.h" #include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/protobuf/coordination_config.pb.h" #include "tsl/platform/env.h" #include "tsl/platform/status.h" #include "tsl/platform/test.h" #include "tsl/platform/threadpool.h" -#include "tsl/protobuf/coordination_config.pb.h" namespace tsl { namespace { diff --git a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_rpc_handler.cc b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_rpc_handler.cc index 200db9df7ee232..0294d2bbbaab75 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_rpc_handler.cc +++ b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_rpc_handler.cc @@ -30,9 +30,9 @@ limitations under the License. #include "xla/tsl/distributed_runtime/coordination/coordination_service.h" #include "xla/tsl/distributed_runtime/coordination/coordination_service_agent.h" #include "xla/tsl/distributed_runtime/coordination/coordination_service_error_util.h" +#include "xla/tsl/protobuf/coordination_service.pb.h" #include "tsl/platform/protobuf.h" #include "tsl/platform/status.h" -#include "tsl/protobuf/coordination_service.pb.h" namespace tsl { namespace { diff --git a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_rpc_handler.h b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_rpc_handler.h index 537a5d5be3a652..2b9ca2ef9f3d2e 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_rpc_handler.h +++ b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_rpc_handler.h @@ -20,9 +20,9 @@ limitations under the License. #include "absl/synchronization/mutex.h" #include "xla/tsl/distributed_runtime/coordination/coordination_service.h" #include "xla/tsl/distributed_runtime/coordination/coordination_service_agent.h" +#include "xla/tsl/protobuf/coordination_service.pb.h" #include "tsl/platform/status.h" #include "tsl/platform/thread_annotations.h" -#include "tsl/protobuf/coordination_service.pb.h" namespace tsl { class CoordinationServiceRpcHandler { diff --git a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_test.cc b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_test.cc index 5a0093c9d6c972..84d4f19d61f4de 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_test.cc +++ b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_test.cc @@ -37,13 +37,13 @@ limitations under the License. #include "xla/tsl/distributed_runtime/coordination/coordination_service_error_util.h" #include "xla/tsl/distributed_runtime/coordination/test_device.pb.h" #include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/protobuf/coordination_config.pb.h" +#include "xla/tsl/protobuf/coordination_service.pb.h" #include "tsl/platform/env.h" #include "tsl/platform/random.h" #include "tsl/platform/status.h" #include "tsl/platform/test.h" #include "tsl/platform/types.h" -#include "tsl/protobuf/coordination_config.pb.h" -#include "tsl/protobuf/coordination_service.pb.h" namespace tsl { namespace { diff --git a/third_party/xla/xla/tsl/distributed_runtime/preemption/BUILD b/third_party/xla/xla/tsl/distributed_runtime/preemption/BUILD index b533f7b4bd88ef..1789c00cdd4316 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/preemption/BUILD +++ b/third_party/xla/xla/tsl/distributed_runtime/preemption/BUILD @@ -55,6 +55,7 @@ cc_library( "//xla/tsl/distributed_runtime:call_options", "//xla/tsl/distributed_runtime/coordination:coordination_service_agent", "//xla/tsl/lib/monitoring:gauge", + "//xla/tsl/protobuf:coordination_service_proto_cc", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/log", "@com_google_absl//absl/memory", @@ -65,7 +66,6 @@ cc_library( "@com_google_absl//absl/time", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/protobuf:coordination_service_proto_cc", ], ) @@ -83,6 +83,9 @@ tsl_cc_test( "//xla/tsl/distributed_runtime/rpc:async_service_interface", "//xla/tsl/distributed_runtime/rpc/coordination:grpc_coordination_client", "//xla/tsl/distributed_runtime/rpc/coordination:grpc_coordination_service_impl", + "//xla/tsl/protobuf:coordination_config_proto_cc_impl", + "//xla/tsl/protobuf:coordination_service_proto_cc_impl", + "//xla/tsl/protobuf:distributed_runtime_payloads_proto_cc_impl", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/memory", @@ -92,8 +95,5 @@ tsl_cc_test( "@local_tsl//tsl/platform:env_impl", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_main", - "@local_tsl//tsl/protobuf:coordination_config_proto_cc_impl", - "@local_tsl//tsl/protobuf:coordination_service_proto_cc_impl", - "@local_tsl//tsl/protobuf:distributed_runtime_payloads_proto_cc_impl", ] + tsl_grpc_cc_dependencies(), ) diff --git a/third_party/xla/xla/tsl/distributed_runtime/preemption/preemption_sync_manager.cc b/third_party/xla/xla/tsl/distributed_runtime/preemption/preemption_sync_manager.cc index ee85f70e04b771..c6e41a9f030f62 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/preemption/preemption_sync_manager.cc +++ b/third_party/xla/xla/tsl/distributed_runtime/preemption/preemption_sync_manager.cc @@ -37,9 +37,9 @@ limitations under the License. #include "xla/tsl/distributed_runtime/coordination/coordination_service_agent.h" #include "xla/tsl/distributed_runtime/preemption/preemption_notifier.h" #include "xla/tsl/lib/monitoring/gauge.h" +#include "xla/tsl/protobuf/coordination_service.pb.h" #include "tsl/platform/env.h" #include "tsl/platform/statusor.h" -#include "tsl/protobuf/coordination_service.pb.h" namespace tsl { namespace { diff --git a/third_party/xla/xla/tsl/distributed_runtime/preemption/preemption_sync_manager_test.cc b/third_party/xla/xla/tsl/distributed_runtime/preemption/preemption_sync_manager_test.cc index e02c7b03b7f917..616b8ccd5fcf99 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/preemption/preemption_sync_manager_test.cc +++ b/third_party/xla/xla/tsl/distributed_runtime/preemption/preemption_sync_manager_test.cc @@ -34,10 +34,10 @@ limitations under the License. #include "xla/tsl/distributed_runtime/rpc/async_service_interface.h" #include "xla/tsl/distributed_runtime/rpc/coordination/grpc_coordination_client.h" #include "xla/tsl/distributed_runtime/rpc/coordination/grpc_coordination_service_impl.h" +#include "xla/tsl/protobuf/coordination_config.pb.h" #include "tsl/platform/env.h" #include "tsl/platform/test.h" #include "tsl/platform/threadpool.h" -#include "tsl/protobuf/coordination_config.pb.h" namespace tsl { namespace { diff --git a/third_party/xla/xla/tsl/distributed_runtime/rpc/BUILD b/third_party/xla/xla/tsl/distributed_runtime/rpc/BUILD index 5b7462f31a7ffd..985a709fe09aeb 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/rpc/BUILD +++ b/third_party/xla/xla/tsl/distributed_runtime/rpc/BUILD @@ -36,13 +36,13 @@ cc_library( srcs = ["grpc_util.cc"], hdrs = ["grpc_util.h"], deps = [ + "//xla/tsl/protobuf:distributed_runtime_payloads_proto_cc", "@com_google_absl//absl/status", "@com_google_absl//absl/strings:cord", "@local_tsl//tsl/platform:protobuf", "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:stringpiece", "@local_tsl//tsl/platform:stringprintf", - "@local_tsl//tsl/protobuf:distributed_runtime_payloads_proto_cc", ] + tsl_grpc_cc_dependencies(), ) @@ -56,12 +56,12 @@ tsl_cc_test( deps = [ ":grpc_util", ":test_request_proto_cc_impl", + "//xla/tsl/protobuf:distributed_runtime_payloads_proto_cc_impl", "@local_tsl//tsl/platform:env_impl", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_benchmark", "@local_tsl//tsl/platform:test_main", - "@local_tsl//tsl/protobuf:distributed_runtime_payloads_proto_cc_impl", ] + tsl_grpc_cc_dependencies(), ) @@ -84,6 +84,7 @@ cc_library( ":grpc_channel_common", ":grpc_util", "//xla/tsl/lib/gtl:map_util", + "//xla/tsl/protobuf:rpc_options_proto_cc", "//xla/tsl/util:device_name_utils", "@com_google_absl//absl/strings", "@local_tsl//tsl/platform:errors", @@ -96,7 +97,6 @@ cc_library( "@local_tsl//tsl/platform:strcat", "@local_tsl//tsl/platform:thread_annotations", "@local_tsl//tsl/platform:types", - "@local_tsl//tsl/protobuf:rpc_options_proto_cc", ] + tsl_grpc_cc_dependencies(), ) @@ -109,12 +109,12 @@ tsl_cc_test( deps = [ ":grpc_channel", "//xla/tsl/lib/core:status_test_util", + "//xla/tsl/protobuf:rpc_options_proto_cc_impl", "//xla/tsl/util:device_name_utils", "@local_tsl//tsl/platform:env_impl", "@local_tsl//tsl/platform:strcat", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_main", - "@local_tsl//tsl/protobuf:rpc_options_proto_cc_impl", ], ) diff --git a/third_party/xla/xla/tsl/distributed_runtime/rpc/coordination/BUILD b/third_party/xla/xla/tsl/distributed_runtime/rpc/coordination/BUILD index 838c2cbdf5ab5c..c14352c8375163 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/rpc/coordination/BUILD +++ b/third_party/xla/xla/tsl/distributed_runtime/rpc/coordination/BUILD @@ -21,13 +21,13 @@ cc_library( "//xla/tsl/distributed_runtime/rpc:grpc_client_cq_tag", "//xla/tsl/distributed_runtime/rpc:grpc_state", "//xla/tsl/distributed_runtime/rpc:grpc_util", + "//xla/tsl/protobuf:coordination_service_proto_cc", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/log", "@com_google_absl//absl/synchronization", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:protobuf", "@local_tsl//tsl/platform:status", - "@local_tsl//tsl/protobuf:coordination_service_proto_cc", ] + tsl_grpc_cc_dependencies(), ) @@ -42,12 +42,12 @@ cc_library( "//xla/tsl/distributed_runtime/rpc:async_service_interface", "//xla/tsl/distributed_runtime/rpc:grpc_call", "//xla/tsl/distributed_runtime/rpc:grpc_util", + "//xla/tsl/protobuf:coordination_service_cc_grpc_proto", + "//xla/tsl/protobuf:coordination_service_proto_cc", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/synchronization", "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/protobuf:coordination_service_cc_grpc_proto", - "@local_tsl//tsl/protobuf:coordination_service_proto_cc", ] + tsl_grpc_cc_dependencies(), ) diff --git a/third_party/xla/xla/tsl/distributed_runtime/rpc/coordination/grpc_coordination_client.cc b/third_party/xla/xla/tsl/distributed_runtime/rpc/coordination/grpc_coordination_client.cc index 6bd7885d2cafb7..639e2f2e10ec25 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/rpc/coordination/grpc_coordination_client.cc +++ b/third_party/xla/xla/tsl/distributed_runtime/rpc/coordination/grpc_coordination_client.cc @@ -34,10 +34,10 @@ limitations under the License. #include "xla/tsl/distributed_runtime/rpc/grpc_client_cq_tag.h" #include "xla/tsl/distributed_runtime/rpc/grpc_state.h" #include "xla/tsl/distributed_runtime/rpc/grpc_util.h" +#include "xla/tsl/protobuf/coordination_service.pb.h" #include "tsl/platform/env.h" #include "tsl/platform/protobuf.h" #include "tsl/platform/status.h" -#include "tsl/protobuf/coordination_service.pb.h" namespace tsl { namespace { diff --git a/third_party/xla/xla/tsl/distributed_runtime/rpc/coordination/grpc_coordination_service_impl.h b/third_party/xla/xla/tsl/distributed_runtime/rpc/coordination/grpc_coordination_service_impl.h index 4b6c74e0b870af..0fdaafc9f579bb 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/rpc/coordination/grpc_coordination_service_impl.h +++ b/third_party/xla/xla/tsl/distributed_runtime/rpc/coordination/grpc_coordination_service_impl.h @@ -30,9 +30,9 @@ limitations under the License. #include "xla/tsl/distributed_runtime/rpc/async_service_interface.h" #include "xla/tsl/distributed_runtime/rpc/grpc_call.h" #include "xla/tsl/distributed_runtime/rpc/grpc_util.h" +#include "xla/tsl/protobuf/coordination_service.grpc.pb.h" +#include "xla/tsl/protobuf/coordination_service.pb.h" #include "tsl/platform/threadpool.h" -#include "tsl/protobuf/coordination_service.grpc.pb.h" -#include "tsl/protobuf/coordination_service.pb.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/distributed_runtime/rpc/grpc_channel.cc b/third_party/xla/xla/tsl/distributed_runtime/rpc/grpc_channel.cc index bd925a64bf0335..2ebb8cc7e9499b 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/rpc/grpc_channel.cc +++ b/third_party/xla/xla/tsl/distributed_runtime/rpc/grpc_channel.cc @@ -27,6 +27,7 @@ limitations under the License. #include "grpcpp/create_channel.h" #include "xla/tsl/distributed_runtime/rpc/grpc_channel_common.h" #include "xla/tsl/lib/gtl/map_util.h" +#include "xla/tsl/protobuf/rpc_options.pb.h" #include "xla/tsl/util/device_name_utils.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" @@ -38,7 +39,6 @@ limitations under the License. #include "tsl/platform/strcat.h" #include "tsl/platform/thread_annotations.h" #include "tsl/platform/types.h" -#include "tsl/protobuf/rpc_options.pb.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/distributed_runtime/rpc/grpc_channel.h b/third_party/xla/xla/tsl/distributed_runtime/rpc/grpc_channel.h index d1fcba72793483..de9aadff1db4af 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/rpc/grpc_channel.h +++ b/third_party/xla/xla/tsl/distributed_runtime/rpc/grpc_channel.h @@ -24,7 +24,7 @@ limitations under the License. #include "grpcpp/grpcpp.h" #include "xla/tsl/distributed_runtime/rpc/grpc_util.h" -#include "tsl/protobuf/rpc_options.pb.h" +#include "xla/tsl/protobuf/rpc_options.pb.h" namespace tsl { using tensorflow::RPCOptions; diff --git a/third_party/xla/xla/tsl/distributed_runtime/rpc/grpc_channel_test.cc b/third_party/xla/xla/tsl/distributed_runtime/rpc/grpc_channel_test.cc index 80c976640fa6f1..2790b0cd65dc44 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/rpc/grpc_channel_test.cc +++ b/third_party/xla/xla/tsl/distributed_runtime/rpc/grpc_channel_test.cc @@ -19,10 +19,10 @@ limitations under the License. #include #include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/protobuf/rpc_options.pb.h" #include "xla/tsl/util/device_name_utils.h" #include "tsl/platform/strcat.h" #include "tsl/platform/test.h" -#include "tsl/protobuf/rpc_options.pb.h" namespace tsl { #define IsSameAddrSp DeviceNameUtils::IsSameAddressSpace diff --git a/third_party/xla/xla/tsl/distributed_runtime/rpc/grpc_util.h b/third_party/xla/xla/tsl/distributed_runtime/rpc/grpc_util.h index 8e3e328cefd8eb..d39eb8e0f1be56 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/rpc/grpc_util.h +++ b/third_party/xla/xla/tsl/distributed_runtime/rpc/grpc_util.h @@ -23,11 +23,11 @@ limitations under the License. #include "absl/strings/cord.h" #include "grpcpp/grpcpp.h" #include "grpcpp/support/byte_buffer.h" +#include "xla/tsl/protobuf/distributed_runtime_payloads.pb.h" #include "tsl/platform/protobuf.h" #include "tsl/platform/status.h" #include "tsl/platform/stringpiece.h" #include "tsl/platform/stringprintf.h" -#include "tsl/protobuf/distributed_runtime_payloads.pb.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/lib/histogram/BUILD b/third_party/xla/xla/tsl/lib/histogram/BUILD index cbd206f6bd8083..fc486455c63c14 100644 --- a/third_party/xla/xla/tsl/lib/histogram/BUILD +++ b/third_party/xla/xla/tsl/lib/histogram/BUILD @@ -20,12 +20,12 @@ cc_library( hdrs = ["histogram.h"], visibility = ["//visibility:public"], deps = [ + "//xla/tsl/protobuf:histogram_proto_cc", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:macros", "@local_tsl//tsl/platform:mutex", "@local_tsl//tsl/platform:thread_annotations", "@local_tsl//tsl/platform:types", - "@local_tsl//tsl/protobuf:histogram_proto_cc", ], alwayslink = True, ) @@ -55,9 +55,9 @@ tsl_cc_test( ], deps = [ ":histogram", + "//xla/tsl/protobuf:histogram_proto_cc", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_main", - "@local_tsl//tsl/protobuf:histogram_proto_cc", ], ) diff --git a/third_party/xla/xla/tsl/lib/histogram/histogram.cc b/third_party/xla/xla/tsl/lib/histogram/histogram.cc index e8203549272547..35ff514e1fe1dd 100644 --- a/third_party/xla/xla/tsl/lib/histogram/histogram.cc +++ b/third_party/xla/xla/tsl/lib/histogram/histogram.cc @@ -20,10 +20,10 @@ limitations under the License. #include +#include "xla/tsl/protobuf/histogram.pb.h" #include "tsl/platform/logging.h" #include "tsl/platform/mutex.h" #include "tsl/platform/types.h" -#include "tsl/protobuf/histogram.pb.h" namespace tsl { namespace histogram { diff --git a/third_party/xla/xla/tsl/lib/histogram/histogram_test.cc b/third_party/xla/xla/tsl/lib/histogram/histogram_test.cc index 4051d98f49ab97..1b2f1827521a17 100644 --- a/third_party/xla/xla/tsl/lib/histogram/histogram_test.cc +++ b/third_party/xla/xla/tsl/lib/histogram/histogram_test.cc @@ -17,9 +17,9 @@ limitations under the License. #include +#include "xla/tsl/protobuf/histogram.pb.h" #include "tsl/platform/logging.h" #include "tsl/platform/test.h" -#include "tsl/protobuf/histogram.pb.h" namespace tsl { namespace histogram { diff --git a/third_party/xla/xla/tsl/lib/monitoring/BUILD b/third_party/xla/xla/tsl/lib/monitoring/BUILD index 008246504b846d..138efecd6b2580 100644 --- a/third_party/xla/xla/tsl/lib/monitoring/BUILD +++ b/third_party/xla/xla/tsl/lib/monitoring/BUILD @@ -72,6 +72,7 @@ cc_library( ":collection_registry", ":metric_def", "//xla/tsl/lib/histogram", + "//xla/tsl/protobuf:histogram_proto_cc", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@local_tsl//tsl/platform", @@ -80,7 +81,6 @@ cc_library( "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:thread_annotations", "@local_tsl//tsl/platform:types", - "@local_tsl//tsl/protobuf:histogram_proto_cc", ], ) @@ -100,9 +100,9 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":types", + "//xla/tsl/protobuf:histogram_proto_cc", "@local_tsl//tsl/platform:stringpiece", "@local_tsl//tsl/platform:types", - "@local_tsl//tsl/protobuf:histogram_proto_cc", ], ) @@ -115,6 +115,7 @@ cc_library( ":collected_metrics", ":metric_def", ":types", + "//xla/tsl/protobuf:histogram_proto_cc", "@local_tsl//tsl/platform", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:logging", @@ -123,7 +124,6 @@ cc_library( "@local_tsl//tsl/platform:stringpiece", "@local_tsl//tsl/platform:thread_annotations", "@local_tsl//tsl/platform:types", - "@local_tsl//tsl/protobuf:histogram_proto_cc", ], ) @@ -135,7 +135,7 @@ cc_library( deps = [ ":metric_def", ":types", - "@local_tsl//tsl/protobuf:histogram_proto_cc", + "//xla/tsl/protobuf:histogram_proto_cc", ], ) @@ -201,12 +201,12 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":types", + "//xla/tsl/protobuf:histogram_proto_cc", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/protobuf:histogram_proto_cc", ], ) diff --git a/third_party/xla/xla/tsl/lib/monitoring/collected_metrics.h b/third_party/xla/xla/tsl/lib/monitoring/collected_metrics.h index 48b655c2a8a2ba..8e305493e83c6b 100644 --- a/third_party/xla/xla/tsl/lib/monitoring/collected_metrics.h +++ b/third_party/xla/xla/tsl/lib/monitoring/collected_metrics.h @@ -27,7 +27,7 @@ limitations under the License. #include "xla/tsl/lib/monitoring/metric_def.h" #include "xla/tsl/lib/monitoring/types.h" -#include "tsl/protobuf/histogram.pb.h" +#include "xla/tsl/protobuf/histogram.pb.h" namespace tsl { namespace monitoring { diff --git a/third_party/xla/xla/tsl/lib/monitoring/collection_registry.h b/third_party/xla/xla/tsl/lib/monitoring/collection_registry.h index 46e93e5c2e46f9..6c48ea9114c8db 100644 --- a/third_party/xla/xla/tsl/lib/monitoring/collection_registry.h +++ b/third_party/xla/xla/tsl/lib/monitoring/collection_registry.h @@ -110,6 +110,7 @@ class CollectionRegistry { #include "xla/tsl/lib/monitoring/collected_metrics.h" #include "xla/tsl/lib/monitoring/metric_def.h" #include "xla/tsl/lib/monitoring/types.h" +#include "xla/tsl/protobuf/histogram.pb.h" #include "tsl/platform/env.h" #include "tsl/platform/logging.h" #include "tsl/platform/macros.h" @@ -117,7 +118,6 @@ class CollectionRegistry { #include "tsl/platform/stringpiece.h" #include "tsl/platform/thread_annotations.h" #include "tsl/platform/types.h" -#include "tsl/protobuf/histogram.pb.h" namespace tsl { namespace monitoring { diff --git a/third_party/xla/xla/tsl/lib/monitoring/metric_def.h b/third_party/xla/xla/tsl/lib/monitoring/metric_def.h index 05a1c44da5b9a9..dcee3f92db4c30 100644 --- a/third_party/xla/xla/tsl/lib/monitoring/metric_def.h +++ b/third_party/xla/xla/tsl/lib/monitoring/metric_def.h @@ -22,9 +22,9 @@ limitations under the License. #include #include "xla/tsl/lib/monitoring/types.h" +#include "xla/tsl/protobuf/histogram.pb.h" #include "tsl/platform/stringpiece.h" #include "tsl/platform/types.h" -#include "tsl/protobuf/histogram.pb.h" namespace tsl { namespace monitoring { diff --git a/third_party/xla/xla/tsl/lib/monitoring/sampler.h b/third_party/xla/xla/tsl/lib/monitoring/sampler.h index 34e7d79b9cced3..3976e312876cb4 100644 --- a/third_party/xla/xla/tsl/lib/monitoring/sampler.h +++ b/third_party/xla/xla/tsl/lib/monitoring/sampler.h @@ -29,10 +29,10 @@ limitations under the License. #include #include "xla/tsl/lib/monitoring/metric_def.h" +#include "xla/tsl/protobuf/histogram.pb.h" #include "tsl/platform/macros.h" #include "tsl/platform/status.h" #include "tsl/platform/types.h" -#include "tsl/protobuf/histogram.pb.h" namespace tsl { namespace monitoring { @@ -125,11 +125,11 @@ class Sampler { #include "xla/tsl/lib/histogram/histogram.h" #include "xla/tsl/lib/monitoring/collection_registry.h" #include "xla/tsl/lib/monitoring/metric_def.h" +#include "xla/tsl/protobuf/histogram.pb.h" #include "tsl/platform/macros.h" #include "tsl/platform/mutex.h" #include "tsl/platform/status.h" #include "tsl/platform/thread_annotations.h" -#include "tsl/protobuf/histogram.pb.h" namespace tsl { namespace monitoring { diff --git a/third_party/xla/xla/tsl/lib/monitoring/test_utils.cc b/third_party/xla/xla/tsl/lib/monitoring/test_utils.cc index b895ead9d0b923..3691130880ab24 100644 --- a/third_party/xla/xla/tsl/lib/monitoring/test_utils.cc +++ b/third_party/xla/xla/tsl/lib/monitoring/test_utils.cc @@ -21,8 +21,8 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_join.h" #include "xla/tsl/lib/monitoring/types.h" +#include "xla/tsl/protobuf/histogram.pb.h" #include "tsl/platform/errors.h" -#include "tsl/protobuf/histogram.pb.h" namespace tsl { namespace monitoring { diff --git a/third_party/xla/xla/tsl/lib/monitoring/test_utils.h b/third_party/xla/xla/tsl/lib/monitoring/test_utils.h index d761c9d27ce039..85101ebffc6d69 100644 --- a/third_party/xla/xla/tsl/lib/monitoring/test_utils.h +++ b/third_party/xla/xla/tsl/lib/monitoring/test_utils.h @@ -19,8 +19,8 @@ limitations under the License. #include "absl/status/statusor.h" #include "xla/tsl/lib/monitoring/types.h" +#include "xla/tsl/protobuf/histogram.pb.h" #include "tsl/platform/statusor.h" -#include "tsl/protobuf/histogram.pb.h" namespace tsl { namespace monitoring { diff --git a/third_party/xla/xla/tsl/protobuf/BUILD b/third_party/xla/xla/tsl/protobuf/BUILD index d4d6f822814797..b80695437d3423 100644 --- a/third_party/xla/xla/tsl/protobuf/BUILD +++ b/third_party/xla/xla/tsl/protobuf/BUILD @@ -1,7 +1,6 @@ -load( - "@local_tsl//tsl/platform:build_config.bzl", - "tf_proto_library", -) +load("@local_tsl//tsl/platform:build_config.bzl", "tf_proto_library") + +# copybara:uncomment load("@rules_python//python:proto.bzl", "py_proto_library") load( "//xla/tsl:tsl.bzl", "if_google", @@ -36,3 +35,86 @@ tf_proto_library( ]), visibility = ["//visibility:public"], ) + +tf_proto_library( + name = "dnn_proto", + srcs = ["dnn.proto"], + make_default_target_header_only = True, + protodeps = if_google(["@com_google_protobuf//:wrappers"]), + visibility = ["//visibility:public"], +) + +tf_proto_library( + name = "status_proto", + srcs = ["status.proto"], + make_default_target_header_only = True, + protodeps = ["@local_tsl//tsl/protobuf:error_codes_proto_impl"], + visibility = ["//visibility:public"], +) + +tf_proto_library( + name = "histogram_proto", + srcs = ["histogram.proto"], + make_default_target_header_only = True, + visibility = ["//visibility:public"], +) + +tf_proto_library( + name = "coordination_config_proto", + srcs = ["coordination_config.proto"], + make_default_target_header_only = True, + visibility = ["//visibility:public"], +) + +tf_proto_library( + name = "coordination_service_proto", + srcs = ["coordination_service.proto"], + has_services = 1, + create_grpc_library = True, + create_java_proto = False, + create_service = True, + protodeps = if_google(["@com_google_protobuf//:any"]), + visibility = ["//visibility:public"], +) + +# copybara:uncomment_begin(google-only) +# py_proto_library( +# name = "coordination_service_py_pb2", +# api_version = 2, +# visibility = ["//visibility:public"], +# deps = [":coordination_service_proto"], +# ) +# copybara:uncomment_end + +tf_proto_library( + name = "distributed_runtime_payloads_proto", + srcs = ["distributed_runtime_payloads.proto"], + make_default_target_header_only = True, + visibility = ["//visibility:public"], +) + +tf_proto_library( + name = "rpc_options_proto", + srcs = ["rpc_options.proto"], + make_default_target_header_only = True, + visibility = ["//visibility:public"], +) + +tf_proto_library( + name = "protos_all", + create_go_proto = False, + make_default_target_header_only = True, + protodeps = [ + # TODO(tlongeri): Conceptually, these fit into protos_all but adding them currently causes + # breakages (and they are not actually used). + ":bfc_memory_map_proto", + ":coordination_config_proto", + ":distributed_runtime_payloads_proto", + "@local_tsl//tsl/protobuf:error_codes_proto_impl", + ":histogram_proto", + ":rpc_options_proto", + ":status_proto", + ":test_log_proto", + ] + if_google(["@com_google_protobuf//:any"]), + visibility = ["//visibility:public"], +) diff --git a/third_party/xla/third_party/tsl/tsl/protobuf/coordination_config.proto b/third_party/xla/xla/tsl/protobuf/coordination_config.proto similarity index 100% rename from third_party/xla/third_party/tsl/tsl/protobuf/coordination_config.proto rename to third_party/xla/xla/tsl/protobuf/coordination_config.proto diff --git a/third_party/xla/third_party/tsl/tsl/protobuf/coordination_service.proto b/third_party/xla/xla/tsl/protobuf/coordination_service.proto similarity index 100% rename from third_party/xla/third_party/tsl/tsl/protobuf/coordination_service.proto rename to third_party/xla/xla/tsl/protobuf/coordination_service.proto diff --git a/third_party/xla/third_party/tsl/tsl/protobuf/distributed_runtime_payloads.proto b/third_party/xla/xla/tsl/protobuf/distributed_runtime_payloads.proto similarity index 100% rename from third_party/xla/third_party/tsl/tsl/protobuf/distributed_runtime_payloads.proto rename to third_party/xla/xla/tsl/protobuf/distributed_runtime_payloads.proto diff --git a/third_party/xla/third_party/tsl/tsl/protobuf/dnn.proto b/third_party/xla/xla/tsl/protobuf/dnn.proto similarity index 100% rename from third_party/xla/third_party/tsl/tsl/protobuf/dnn.proto rename to third_party/xla/xla/tsl/protobuf/dnn.proto diff --git a/third_party/xla/third_party/tsl/tsl/protobuf/histogram.proto b/third_party/xla/xla/tsl/protobuf/histogram.proto similarity index 100% rename from third_party/xla/third_party/tsl/tsl/protobuf/histogram.proto rename to third_party/xla/xla/tsl/protobuf/histogram.proto diff --git a/third_party/xla/third_party/tsl/tsl/protobuf/rpc_options.proto b/third_party/xla/xla/tsl/protobuf/rpc_options.proto similarity index 100% rename from third_party/xla/third_party/tsl/tsl/protobuf/rpc_options.proto rename to third_party/xla/xla/tsl/protobuf/rpc_options.proto diff --git a/third_party/xla/third_party/tsl/tsl/protobuf/status.proto b/third_party/xla/xla/tsl/protobuf/status.proto similarity index 100% rename from third_party/xla/third_party/tsl/tsl/protobuf/status.proto rename to third_party/xla/xla/tsl/protobuf/status.proto diff --git a/third_party/xla/xla/xla.bzl b/third_party/xla/xla/xla.bzl index a81deeb562893c..43f002ab499ee1 100644 --- a/third_party/xla/xla/xla.bzl +++ b/third_party/xla/xla/xla.bzl @@ -54,7 +54,7 @@ _XLA_SHARED_OBJECT_SENSITIVE_DEPS = if_static(extra_deps = [], otherwise = [ "@local_tsl//tsl/profiler/protobuf:profiler_options_proto_cc_impl", "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc_impl", "@local_tsl//tsl/profiler/utils:time_utils_impl", - "@local_tsl//tsl/protobuf:protos_all_cc_impl", + "//xla/tsl/protobuf:protos_all_cc_impl", ]) + if_cuda_is_configured([ Label("//xla/stream_executor/cuda:all_runtime"), Label("//xla/stream_executor/cuda:stream_executor_cuda"), From cd7b8d76c2f21a27935d6f67459c87e3b69d4f62 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 25 Sep 2024 15:32:03 -0700 Subject: [PATCH 285/483] Simplify barrier time out logging. PiperOrigin-RevId: 678871568 --- .../coordination/coordination_service.cc | 64 +++++++++---------- 1 file changed, 30 insertions(+), 34 deletions(-) diff --git a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service.cc b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service.cc index de0033e92fb33e..8f6cd25717157f 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service.cc +++ b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service.cc @@ -22,6 +22,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -505,9 +506,9 @@ void CoordinationServiceStandaloneImpl::CheckHeartbeatTimeout() { } void CoordinationServiceStandaloneImpl::CheckBarrierTimeout() { - const bool has_service_to_client_connection = client_cache_ != nullptr; absl::flat_hash_map expired_barriers; uint64_t current_time_micros = Env::Default()->NowMicros(); + std::optional shutdown_error; { absl::MutexLock l(&state_mu_); // Gather barriers which have timed out. @@ -521,46 +522,40 @@ void CoordinationServiceStandaloneImpl::CheckBarrierTimeout() { for (const auto& [barrier_id, barrier] : expired_barriers) { std::string pending_tasks; int pending_task_count = 0; + // Count and track pending tasks that have not reached the barrier. for (const auto& [task, at_barrier] : barrier->tasks_at_barrier) { if (at_barrier) { continue; } ++pending_task_count; - if (pending_task_count > kPendingTaskLogLimit) { - break; + if (pending_task_count < kPendingTaskLogLimit) { + absl::StrAppend(&pending_tasks, GetTaskName(task), "\n"); } - absl::StrAppend(&pending_tasks, GetTaskName(task), "\n"); } + const int64_t tasks_at_barrier = + barrier->tasks_at_barrier.size() - pending_task_count; std::string error_message = absl::StrFormat( - "Barrier timed out. This usually happens because a task " - "triggered the barrier unexpectedly early, or some tasks are " - "too slow. Please look at the other task logs to debug " - "further. Barrier_id: %s. The first task at the barrier: " - "%s. ", - barrier_id, GetTaskName(barrier->initiating_task)); - if (pending_task_count > kPendingTaskLogLimit) { - absl::StrAppend( - &error_message, "Too many tasks have timed out. The first ", - kPendingTaskLogLimit, " timed out task names:\n", pending_tasks); - } else { - absl::StrAppend(&error_message, - "Total Number of tasks already at the barrier: ", - barrier->tasks_at_barrier.size() - pending_task_count, - "/", barrier->tasks_at_barrier.size(), - ". Timed out task names:\n", pending_tasks); + "Barrier timed out. Id: %s. This usually happens because a task " + "triggered the barrier too early or too slowly. Please look at the " + "task logs (both timed out and first task) to debug further.\n" + "# of tasks that reached the barrier: %d/%d.\nThe first " + "task at the barrier: %s. Some timed out task names:\n%s", + barrier_id, tasks_at_barrier, barrier->tasks_at_barrier.size(), + GetTaskName(barrier->initiating_task), pending_tasks); + if (barrier_id == shutdown_barrier_id_) { + shutdown_error = error_message; } const absl::Status error = MakeCoordinationError(absl::DeadlineExceededError(error_message)); PassBarrier(barrier_id, error, barrier); } } - if (!has_service_to_client_connection && - expired_barriers.contains(shutdown_barrier_id_)) { + const bool has_service_to_client_connection = client_cache_ != nullptr; + if (!has_service_to_client_connection && shutdown_error) { // Error cannot be propagated through service-to-client connection. SendErrorPollingResponseOrStopService( - MakeCoordinationError(absl::DeadlineExceededError( - "Shutdown barrier timed out. Check the task logs for an " - "earlier error."))); + MakeCoordinationError(absl::DeadlineExceededError(absl::StrCat( + "Shutdown barrier timed out. Error: ", *shutdown_error)))); } } @@ -815,8 +810,8 @@ absl::Status CoordinationServiceStandaloneImpl::DisconnectTask( for (const auto& barrier_id : cluster_state_[task_name]->GetOngoingBarriers()) { absl::Status error = MakeCoordinationError(absl::InternalError(absl::StrCat( - "Barrier failed from a disconnected task. Barrier Id: ", barrier_id, - ", Task: ", task_name))); + "Barrier failed because a task has disconnected. Barrier Id: ", + barrier_id, ", Task: ", task_name))); PassBarrier(barrier_id, error, &barriers_[barrier_id]); } @@ -1161,10 +1156,11 @@ void CoordinationServiceStandaloneImpl::SetTaskError(std::string_view task_name, cluster_state_[task_name]->SetError(error); for (const auto& barrier_id : cluster_state_[task_name]->GetOngoingBarriers()) { - absl::Status error = MakeCoordinationError(absl::InternalError(absl::StrCat( - "Barrier failed from a task error. Barrier Id: ", barrier_id, - ", Task: ", task_name))); - PassBarrier(barrier_id, error, &barriers_[barrier_id]); + absl::Status barrier_error = + MakeCoordinationError(absl::InternalError(absl::StrCat( + "Barrier failed beacuse a task is in error. Barrier Id: ", + barrier_id, ", Task: ", task_name, "Error: ", error.message()))); + PassBarrier(barrier_id, barrier_error, &barriers_[barrier_id]); } LOG(ERROR) << task_name @@ -1452,9 +1448,9 @@ void CoordinationServiceStandaloneImpl::PassBarrier(std::string_view barrier_id, "an earlier error to identify the root cause."; } absl::Status shutdown_error = MakeCoordinationError(absl::InternalError( - absl::StrCat("Shutdown barrier has been passed with status: '", - barrier->result.ToString(), - "', but this task is not at the barrier yet."))); + absl::StrCat("Shutdown barrier has failed, but this task is not at the " + "barrier yet.\nBarrier result: '", + barrier->result.message()))); for (const auto& [task, at_barrier] : barrier->tasks_at_barrier) { if (at_barrier) { // Disconnect tasks that reached the barrier. From 85af80edf0257e8dfc7e82ecc783ae082df1f7a2 Mon Sep 17 00:00:00 2001 From: Luke Boyer Date: Wed, 25 Sep 2024 15:52:09 -0700 Subject: [PATCH 286/483] Move lrt to runtime PiperOrigin-RevId: 678878146 --- tensorflow/compiler/mlir/lite/BUILD | 1 + tensorflow/compiler/mlir/lite/core/BUILD | 1 + .../compiler/mlir/lite/experimental/lrt/BUILD | 35 ------------------- tensorflow/lite/core/c/BUILD | 1 - tensorflow/lite/experimental/lrt/BUILD | 17 +++++++++ .../lite/experimental/lrt/apply_plugin.cc | 16 ++++----- .../mlir => }/lite/experimental/lrt/c/BUILD | 2 +- .../lite/experimental/lrt/c/lite_rt_common.h | 6 ++-- .../lrt/c/lite_rt_compiler_plugin.h | 10 +++--- .../lite/experimental/lrt/c/lite_rt_model.h | 10 +++--- .../lite/experimental/lrt/c/lite_rt_op_code.h | 6 ++-- .../lite/experimental/lrt/c/lite_rt_support.h | 8 ++--- .../mlir => }/lite/experimental/lrt/cc/BUILD | 4 +-- .../experimental/lrt/cc/lite_rt_support.h | 12 +++---- .../lite/experimental/lrt/core/BUILD | 20 +++++------ .../lite/experimental/lrt/core/algo.h | 14 ++++---- .../lite/experimental/lrt/core/algo_test.cc | 14 ++++---- .../lite/experimental/lrt/core/graph_tools.h | 14 ++++---- .../experimental/lrt/core/lite_rt_common.cc | 2 +- .../lrt/core/lite_rt_model_init.cc | 13 ++++--- .../lrt/core/lite_rt_model_init.h | 8 ++--- .../lite/experimental/lrt/core/model.cc | 8 ++--- .../lite/experimental/lrt/core/model.h | 14 ++++---- .../lite/experimental/lrt/core/model_test.cc | 14 ++++---- .../lite/experimental/lrt/examples/BUILD | 18 +++++----- .../lrt/examples/mul_op_plugin.cc | 12 +++---- .../lrt/examples/mul_op_plugin_test.cc | 14 ++++---- .../lite/experimental/lrt/test_data/BUILD | 12 +++---- .../experimental/lrt/test_data/add_cst.mlir | 0 .../lrt/test_data/add_simple.mlir | 0 .../lrt/test_data/mul_simple.mlir | 0 .../lrt/test_data/simple_multi_op.mlir | 0 .../lrt/test_data/test_data_util.h | 14 ++++---- 33 files changed, 151 insertions(+), 169 deletions(-) delete mode 100644 tensorflow/compiler/mlir/lite/experimental/lrt/BUILD rename tensorflow/{compiler/mlir => }/lite/experimental/lrt/apply_plugin.cc (91%) rename tensorflow/{compiler/mlir => }/lite/experimental/lrt/c/BUILD (91%) rename tensorflow/{compiler/mlir => }/lite/experimental/lrt/c/lite_rt_common.h (89%) rename tensorflow/{compiler/mlir => }/lite/experimental/lrt/c/lite_rt_compiler_plugin.h (89%) rename tensorflow/{compiler/mlir => }/lite/experimental/lrt/c/lite_rt_model.h (93%) rename tensorflow/{compiler/mlir => }/lite/experimental/lrt/c/lite_rt_op_code.h (98%) rename tensorflow/{compiler/mlir => }/lite/experimental/lrt/c/lite_rt_support.h (84%) rename tensorflow/{compiler/mlir => }/lite/experimental/lrt/cc/BUILD (84%) rename tensorflow/{compiler/mlir => }/lite/experimental/lrt/cc/lite_rt_support.h (92%) rename tensorflow/{compiler/mlir => }/lite/experimental/lrt/core/BUILD (78%) rename tensorflow/{compiler/mlir => }/lite/experimental/lrt/core/algo.h (95%) rename tensorflow/{compiler/mlir => }/lite/experimental/lrt/core/algo_test.cc (94%) rename tensorflow/{compiler/mlir => }/lite/experimental/lrt/core/graph_tools.h (95%) rename tensorflow/{compiler/mlir => }/lite/experimental/lrt/core/lite_rt_common.cc (92%) rename tensorflow/{compiler/mlir => }/lite/experimental/lrt/core/lite_rt_model_init.cc (97%) rename tensorflow/{compiler/mlir => }/lite/experimental/lrt/core/lite_rt_model_init.h (86%) rename tensorflow/{compiler/mlir => }/lite/experimental/lrt/core/model.cc (93%) rename tensorflow/{compiler/mlir => }/lite/experimental/lrt/core/model.h (91%) rename tensorflow/{compiler/mlir => }/lite/experimental/lrt/core/model_test.cc (95%) rename tensorflow/{compiler/mlir => }/lite/experimental/lrt/examples/BUILD (63%) rename tensorflow/{compiler/mlir => }/lite/experimental/lrt/examples/mul_op_plugin.cc (91%) rename tensorflow/{compiler/mlir => }/lite/experimental/lrt/examples/mul_op_plugin_test.cc (83%) rename tensorflow/{compiler/mlir => }/lite/experimental/lrt/test_data/BUILD (75%) rename tensorflow/{compiler/mlir => }/lite/experimental/lrt/test_data/add_cst.mlir (100%) rename tensorflow/{compiler/mlir => }/lite/experimental/lrt/test_data/add_simple.mlir (100%) rename tensorflow/{compiler/mlir => }/lite/experimental/lrt/test_data/mul_simple.mlir (100%) rename tensorflow/{compiler/mlir => }/lite/experimental/lrt/test_data/simple_multi_op.mlir (100%) rename tensorflow/{compiler/mlir => }/lite/experimental/lrt/test_data/test_data_util.h (78%) diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD index 4008cf266f9a22..7bb70a19f4f116 100644 --- a/tensorflow/compiler/mlir/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/BUILD @@ -27,6 +27,7 @@ package_group( "//third_party/iree/...", "//third_party/odml/infra/...", "//tensorflow/compiler/mlir/...", + "//tensorflow/lite/experimental/lrt/...", "//tensorflow/lite/python/...", "//waymo/accelerator/alpine/tools/...", "//waymo/ml/compiler/mlir/...", diff --git a/tensorflow/compiler/mlir/lite/core/BUILD b/tensorflow/compiler/mlir/lite/core/BUILD index d76299aa723d51..d359046609313f 100644 --- a/tensorflow/compiler/mlir/lite/core/BUILD +++ b/tensorflow/compiler/mlir/lite/core/BUILD @@ -29,6 +29,7 @@ cc_library( visibility = [ "//tensorflow/compiler/mlir/lite:__subpackages__", "//tensorflow/lite/core:__pkg__", + "//tensorflow/lite/experimental/lrt:__subpackages__", ], deps = [ ":macros", diff --git a/tensorflow/compiler/mlir/lite/experimental/lrt/BUILD b/tensorflow/compiler/mlir/lite/experimental/lrt/BUILD deleted file mode 100644 index 5d6dbd10c9c94f..00000000000000 --- a/tensorflow/compiler/mlir/lite/experimental/lrt/BUILD +++ /dev/null @@ -1,35 +0,0 @@ -# Copyright 2024 Google LLC. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = ["//tensorflow/compiler/mlir/lite/experimental/lrt:__subpackages__"], -) - -cc_binary( - name = "apply_plugin", - srcs = [ - "apply_plugin.cc", - # TODO: b/366821557 - Support pre-compiled plugins as data dependencies. - "//tensorflow/compiler/mlir/lite/experimental/lrt/examples:mul_op_plugin_so", - ], - deps = [ - "//tensorflow/compiler/mlir/lite/experimental/lrt/core:algo", - "//tensorflow/compiler/mlir/lite/experimental/lrt/core:api_internal", - "//tensorflow/compiler/mlir/lite/experimental/lrt/core:lite_rt_model_init", - "//tensorflow/compiler/mlir/lite/experimental/lrt/core:model", - "//tensorflow/lite/schema:schema_fbs", - "@llvm-project//llvm:Support", - ], -) diff --git a/tensorflow/lite/core/c/BUILD b/tensorflow/lite/core/c/BUILD index 00a1a27ec6d819..af5bb111400c5b 100644 --- a/tensorflow/lite/core/c/BUILD +++ b/tensorflow/lite/core/c/BUILD @@ -15,7 +15,6 @@ load( package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = [ - "//tensorflow/compiler/mlir/lite/experimental/lrt:__subpackages__", "//tensorflow/lite:__subpackages__", ], licenses = ["notice"], diff --git a/tensorflow/lite/experimental/lrt/BUILD b/tensorflow/lite/experimental/lrt/BUILD index cd9efefb75dcab..04b5e0b2ab3c7c 100644 --- a/tensorflow/lite/experimental/lrt/BUILD +++ b/tensorflow/lite/experimental/lrt/BUILD @@ -16,3 +16,20 @@ package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = ["//tensorflow/lite/experimental/lrt:__subpackages__"], ) + +cc_binary( + name = "apply_plugin", + srcs = [ + "apply_plugin.cc", + # TODO: b/366821557 - Support pre-compiled plugins as data dependencies. + "//tensorflow/lite/experimental/lrt/examples:mul_op_plugin_so", + ], + deps = [ + "//tensorflow/lite/experimental/lrt/core:algo", + "//tensorflow/lite/experimental/lrt/core:api_internal", + "//tensorflow/lite/experimental/lrt/core:lite_rt_model_init", + "//tensorflow/lite/experimental/lrt/core:model", + "//tensorflow/lite/schema:schema_fbs", + "@llvm-project//llvm:Support", + ], +) diff --git a/tensorflow/compiler/mlir/lite/experimental/lrt/apply_plugin.cc b/tensorflow/lite/experimental/lrt/apply_plugin.cc similarity index 91% rename from tensorflow/compiler/mlir/lite/experimental/lrt/apply_plugin.cc rename to tensorflow/lite/experimental/lrt/apply_plugin.cc index c6e5543d000484..49756ddebd1557 100644 --- a/tensorflow/compiler/mlir/lite/experimental/lrt/apply_plugin.cc +++ b/tensorflow/lite/experimental/lrt/apply_plugin.cc @@ -21,14 +21,14 @@ #include #include "llvm/Support/CommandLine.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_common.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_compiler_plugin.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_model.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_support.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/cc/lite_rt_support.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/core/algo.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/core/lite_rt_model_init.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/core/model.h" +#include "tensorflow/lite/experimental/lrt/c/lite_rt_common.h" +#include "tensorflow/lite/experimental/lrt/c/lite_rt_compiler_plugin.h" +#include "tensorflow/lite/experimental/lrt/c/lite_rt_model.h" +#include "tensorflow/lite/experimental/lrt/c/lite_rt_support.h" +#include "tensorflow/lite/experimental/lrt/cc/lite_rt_support.h" +#include "tensorflow/lite/experimental/lrt/core/algo.h" +#include "tensorflow/lite/experimental/lrt/core/lite_rt_model_init.h" +#include "tensorflow/lite/experimental/lrt/core/model.h" // NOLINTNEXTLINE static llvm::cl::opt model_path( diff --git a/tensorflow/compiler/mlir/lite/experimental/lrt/c/BUILD b/tensorflow/lite/experimental/lrt/c/BUILD similarity index 91% rename from tensorflow/compiler/mlir/lite/experimental/lrt/c/BUILD rename to tensorflow/lite/experimental/lrt/c/BUILD index d7cd167f2e40f3..e384ae13966e75 100644 --- a/tensorflow/compiler/mlir/lite/experimental/lrt/c/BUILD +++ b/tensorflow/lite/experimental/lrt/c/BUILD @@ -14,7 +14,7 @@ package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = ["//tensorflow/compiler/mlir/lite/experimental/lrt:__subpackages__"], + default_visibility = ["//tensorflow/lite/experimental/lrt:__subpackages__"], ) cc_library( diff --git a/tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_common.h b/tensorflow/lite/experimental/lrt/c/lite_rt_common.h similarity index 89% rename from tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_common.h rename to tensorflow/lite/experimental/lrt/c/lite_rt_common.h index 9cd70dd4bc5168..423f208e90d7e0 100644 --- a/tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_common.h +++ b/tensorflow/lite/experimental/lrt/c/lite_rt_common.h @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_LRT_C_LITE_RT_COMMON_H_ -#define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_LRT_C_LITE_RT_COMMON_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LRT_C_LITE_RT_COMMON_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LRT_C_LITE_RT_COMMON_H_ #ifdef __cplusplus extern "C" { @@ -66,4 +66,4 @@ LrtStatus StatusOk(); } #endif // __cplusplus -#endif // TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_LRT_C_LITE_RT_COMMON_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LRT_C_LITE_RT_COMMON_H_ diff --git a/tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_compiler_plugin.h b/tensorflow/lite/experimental/lrt/c/lite_rt_compiler_plugin.h similarity index 89% rename from tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_compiler_plugin.h rename to tensorflow/lite/experimental/lrt/c/lite_rt_compiler_plugin.h index a254ae00a85ec5..20f82ebcdbc3e6 100644 --- a/tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_compiler_plugin.h +++ b/tensorflow/lite/experimental/lrt/c/lite_rt_compiler_plugin.h @@ -12,13 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_LRT_C_LITE_RT_COMPILER_PLUGIN_H_ -#define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_LRT_C_LITE_RT_COMPILER_PLUGIN_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LRT_C_LITE_RT_COMPILER_PLUGIN_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LRT_C_LITE_RT_COMPILER_PLUGIN_H_ #include -#include "tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_common.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_model.h" +#include "tensorflow/lite/experimental/lrt/c/lite_rt_common.h" +#include "tensorflow/lite/experimental/lrt/c/lite_rt_model.h" #ifdef __cplusplus extern "C" { @@ -92,4 +92,4 @@ LrtStatus LrtCompiledResultGetNumCalls(LrtCompiledResult compiled_result, } #endif // __cplusplus -#endif // TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_LRT_C_LITE_RT_COMPILER_PLUGIN_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LRT_C_LITE_RT_COMPILER_PLUGIN_H_ diff --git a/tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_model.h b/tensorflow/lite/experimental/lrt/c/lite_rt_model.h similarity index 93% rename from tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_model.h rename to tensorflow/lite/experimental/lrt/c/lite_rt_model.h index 190294c2d8a287..41d085aeab116b 100644 --- a/tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_model.h +++ b/tensorflow/lite/experimental/lrt/c/lite_rt_model.h @@ -12,15 +12,15 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_LRT_C_LITE_RT_MODEL_H_ -#define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_LRT_C_LITE_RT_MODEL_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LRT_C_LITE_RT_MODEL_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LRT_C_LITE_RT_MODEL_H_ #include #include -#include "tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_common.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_op_code.h" #include "tensorflow/lite/core/c/c_api_types.h" +#include "tensorflow/lite/experimental/lrt/c/lite_rt_common.h" +#include "tensorflow/lite/experimental/lrt/c/lite_rt_op_code.h" #ifdef __cplusplus extern "C" { @@ -178,4 +178,4 @@ LrtStatus PushOp(LrtOpList op_list, LrtOp op); } #endif // __cplusplus -#endif // TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_LRT_C_LITE_RT_MODEL_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LRT_C_LITE_RT_MODEL_H_ diff --git a/tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_op_code.h b/tensorflow/lite/experimental/lrt/c/lite_rt_op_code.h similarity index 98% rename from tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_op_code.h rename to tensorflow/lite/experimental/lrt/c/lite_rt_op_code.h index 52e39f463a0e07..ee0f0d277229f4 100644 --- a/tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_op_code.h +++ b/tensorflow/lite/experimental/lrt/c/lite_rt_op_code.h @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_LRT_C_LITE_RT_OP_CODE_H_ -#define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_LRT_C_LITE_RT_OP_CODE_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LRT_C_LITE_RT_OP_CODE_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LRT_C_LITE_RT_OP_CODE_H_ #include "tensorflow/lite/builtin_ops.h" @@ -241,4 +241,4 @@ typedef enum { } #endif // __cplusplus -#endif // TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_LRT_C_LITE_RT_OP_CODE_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LRT_C_LITE_RT_OP_CODE_H_ diff --git a/tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_support.h b/tensorflow/lite/experimental/lrt/c/lite_rt_support.h similarity index 84% rename from tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_support.h rename to tensorflow/lite/experimental/lrt/c/lite_rt_support.h index 20c8ce08be76ef..e1a55c5590b308 100644 --- a/tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_support.h +++ b/tensorflow/lite/experimental/lrt/c/lite_rt_support.h @@ -12,12 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_LRT_C_LITE_RT_SUPPORT_H_ -#define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_LRT_C_LITE_RT_SUPPORT_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LRT_C_LITE_RT_SUPPORT_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LRT_C_LITE_RT_SUPPORT_H_ #include -#include "tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_common.h" // IWYU pragma: keep +#include "tensorflow/lite/experimental/lrt/c/lite_rt_common.h" // IWYU pragma: keep #ifdef __cplusplus extern "C" { @@ -58,4 +58,4 @@ extern "C" { } // extern "C" #endif -#endif // TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_LRT_C_LITE_RT_SUPPORT_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LRT_C_LITE_RT_SUPPORT_H_ diff --git a/tensorflow/compiler/mlir/lite/experimental/lrt/cc/BUILD b/tensorflow/lite/experimental/lrt/cc/BUILD similarity index 84% rename from tensorflow/compiler/mlir/lite/experimental/lrt/cc/BUILD rename to tensorflow/lite/experimental/lrt/cc/BUILD index 351b5d14748d78..e1c81536225436 100644 --- a/tensorflow/compiler/mlir/lite/experimental/lrt/cc/BUILD +++ b/tensorflow/lite/experimental/lrt/cc/BUILD @@ -14,7 +14,7 @@ package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = ["//tensorflow/compiler/mlir/lite/experimental/lrt:__subpackages__"], + default_visibility = ["//tensorflow/lite/experimental/lrt:__subpackages__"], ) cc_library( @@ -24,8 +24,8 @@ cc_library( ], deps = [ "//tensorflow/compiler/mlir/lite/core:model_builder_base", - "//tensorflow/compiler/mlir/lite/experimental/lrt/c:lite_rt_c_api", "//tensorflow/lite/c:c_api_types", + "//tensorflow/lite/experimental/lrt/c:lite_rt_c_api", ], ) diff --git a/tensorflow/compiler/mlir/lite/experimental/lrt/cc/lite_rt_support.h b/tensorflow/lite/experimental/lrt/cc/lite_rt_support.h similarity index 92% rename from tensorflow/compiler/mlir/lite/experimental/lrt/cc/lite_rt_support.h rename to tensorflow/lite/experimental/lrt/cc/lite_rt_support.h index f030c37079a616..b95d874cacc886 100644 --- a/tensorflow/compiler/mlir/lite/experimental/lrt/cc/lite_rt_support.h +++ b/tensorflow/lite/experimental/lrt/cc/lite_rt_support.h @@ -12,17 +12,17 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_LRT_CC_LITE_RT_SUPPORT_H_ -#define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_LRT_CC_LITE_RT_SUPPORT_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LRT_CC_LITE_RT_SUPPORT_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LRT_CC_LITE_RT_SUPPORT_H_ #include #include #include -#include "tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_common.h" // IWYU pragma: keep -#include "tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_compiler_plugin.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_support.h" // IWYU pragma: export +#include "tensorflow/lite/experimental/lrt/c/lite_rt_common.h" // IWYU pragma: keep +#include "tensorflow/lite/experimental/lrt/c/lite_rt_compiler_plugin.h" +#include "tensorflow/lite/experimental/lrt/c/lite_rt_support.h" // IWYU pragma: export #ifndef NDEBUG #include // IWYU pragma: keep #endif @@ -191,4 +191,4 @@ class LrtResult { #define LRT_ASSIGN_OR_RETURN_RESULT(decl, expr, ty) \ _ASSIGN_OR_RETURN_RESULT(decl, expr, ty, _CONCAT_NAME(_result, __COUNTER__)) -#endif // TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_LRT_CC_LITE_RT_SUPPORT_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LRT_CC_LITE_RT_SUPPORT_H_ diff --git a/tensorflow/compiler/mlir/lite/experimental/lrt/core/BUILD b/tensorflow/lite/experimental/lrt/core/BUILD similarity index 78% rename from tensorflow/compiler/mlir/lite/experimental/lrt/core/BUILD rename to tensorflow/lite/experimental/lrt/core/BUILD index d584670b49730e..1776aff33b41ba 100644 --- a/tensorflow/compiler/mlir/lite/experimental/lrt/core/BUILD +++ b/tensorflow/lite/experimental/lrt/core/BUILD @@ -14,19 +14,19 @@ package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = ["//tensorflow/compiler/mlir/lite/experimental/lrt:__subpackages__"], + default_visibility = ["//tensorflow/lite/experimental/lrt:__subpackages__"], ) cc_library( name = "api_internal", srcs = ["lite_rt_common.cc"], hdrs = [ - "//tensorflow/compiler/mlir/lite/experimental/lrt/c:lite_rt_common.h", - "//tensorflow/compiler/mlir/lite/experimental/lrt/c:lite_rt_compiler_plugin.h", - "//tensorflow/compiler/mlir/lite/experimental/lrt/c:lite_rt_model.h", - "//tensorflow/compiler/mlir/lite/experimental/lrt/c:lite_rt_op_code.h", - "//tensorflow/compiler/mlir/lite/experimental/lrt/c:lite_rt_support.h", - "//tensorflow/compiler/mlir/lite/experimental/lrt/cc:lite_rt_support.h", + "//tensorflow/lite/experimental/lrt/c:lite_rt_common.h", + "//tensorflow/lite/experimental/lrt/c:lite_rt_compiler_plugin.h", + "//tensorflow/lite/experimental/lrt/c:lite_rt_model.h", + "//tensorflow/lite/experimental/lrt/c:lite_rt_op_code.h", + "//tensorflow/lite/experimental/lrt/c:lite_rt_support.h", + "//tensorflow/lite/experimental/lrt/cc:lite_rt_support.h", ], deps = [ "//tensorflow/lite:builtin_ops", @@ -56,8 +56,8 @@ cc_library( deps = [ ":api_internal", ":model", - "//tensorflow/compiler/mlir/lite:allocation", "//tensorflow/compiler/mlir/lite/core:model_builder_base", + "//tensorflow/lite:allocation", "//tensorflow/lite:framework", "//tensorflow/lite:stderr_reporter", "//tensorflow/lite/c:c_api_types", @@ -75,7 +75,7 @@ cc_test( ":api_internal", ":graph_tools", ":lite_rt_model_init", - "//tensorflow/compiler/mlir/lite/experimental/lrt/test_data:test_data_util", + "//tensorflow/lite/experimental/lrt/test_data:test_data_util", "//tensorflow/lite/schema:schema_fbs", "@com_google_googletest//:gtest_main", "@flatbuffers//:runtime_cc", @@ -104,7 +104,7 @@ cc_test( ":api_internal", ":graph_tools", ":model", - "//tensorflow/compiler/mlir/lite/experimental/lrt/test_data:test_data_util", + "//tensorflow/lite/experimental/lrt/test_data:test_data_util", "@com_google_googletest//:gtest_main", ], ) diff --git a/tensorflow/compiler/mlir/lite/experimental/lrt/core/algo.h b/tensorflow/lite/experimental/lrt/core/algo.h similarity index 95% rename from tensorflow/compiler/mlir/lite/experimental/lrt/core/algo.h rename to tensorflow/lite/experimental/lrt/core/algo.h index d489b2d287c652..96672f1f0aaa3c 100644 --- a/tensorflow/compiler/mlir/lite/experimental/lrt/core/algo.h +++ b/tensorflow/lite/experimental/lrt/core/algo.h @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_LRT_CORE_ALGO_H_ -#define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_LRT_CORE_ALGO_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LRT_CORE_ALGO_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LRT_CORE_ALGO_H_ #include #include @@ -24,10 +24,10 @@ #include "absl/log/check.h" #include "llvm/ADT/MapVector.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_model.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_op_code.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/cc/lite_rt_support.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/core/model.h" +#include "tensorflow/lite/experimental/lrt/c/lite_rt_model.h" +#include "tensorflow/lite/experimental/lrt/c/lite_rt_op_code.h" +#include "tensorflow/lite/experimental/lrt/cc/lite_rt_support.h" +#include "tensorflow/lite/experimental/lrt/core/model.h" #include "tensorflow/lite/schema/schema_generated.h" // NOLINTBEGIN @@ -358,6 +358,6 @@ inline void GraphSlicer::CloneInto(const LrtOpT& old_op) { } // namespace algo -#endif // TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_LRT_CORE_ALGO_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LRT_CORE_ALGO_H_ // NOLINTEND diff --git a/tensorflow/compiler/mlir/lite/experimental/lrt/core/algo_test.cc b/tensorflow/lite/experimental/lrt/core/algo_test.cc similarity index 94% rename from tensorflow/compiler/mlir/lite/experimental/lrt/core/algo_test.cc rename to tensorflow/lite/experimental/lrt/core/algo_test.cc index 90c1b55e6e2abf..fa09a5907cc5ac 100644 --- a/tensorflow/compiler/mlir/lite/experimental/lrt/core/algo_test.cc +++ b/tensorflow/lite/experimental/lrt/core/algo_test.cc @@ -12,19 +12,19 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "tensorflow/compiler/mlir/lite/experimental/lrt/core/algo.h" +#include "tensorflow/lite/experimental/lrt/core/algo.h" #include #include #include #include -#include "tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_model.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_op_code.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/cc/lite_rt_support.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/core/graph_tools.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/core/model.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/test_data/test_data_util.h" +#include "tensorflow/lite/experimental/lrt/c/lite_rt_model.h" +#include "tensorflow/lite/experimental/lrt/c/lite_rt_op_code.h" +#include "tensorflow/lite/experimental/lrt/cc/lite_rt_support.h" +#include "tensorflow/lite/experimental/lrt/core/graph_tools.h" +#include "tensorflow/lite/experimental/lrt/core/model.h" +#include "tensorflow/lite/experimental/lrt/test_data/test_data_util.h" namespace { diff --git a/tensorflow/compiler/mlir/lite/experimental/lrt/core/graph_tools.h b/tensorflow/lite/experimental/lrt/core/graph_tools.h similarity index 95% rename from tensorflow/compiler/mlir/lite/experimental/lrt/core/graph_tools.h rename to tensorflow/lite/experimental/lrt/core/graph_tools.h index 49c937fb74af20..2274c943377cba 100644 --- a/tensorflow/compiler/mlir/lite/experimental/lrt/core/graph_tools.h +++ b/tensorflow/lite/experimental/lrt/core/graph_tools.h @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_LRT_CORE_GRAPH_TOOLS_H_ -#define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_LRT_CORE_GRAPH_TOOLS_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LRT_CORE_GRAPH_TOOLS_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LRT_CORE_GRAPH_TOOLS_H_ #include #include @@ -27,10 +27,10 @@ #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallVector.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_common.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_model.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_op_code.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/cc/lite_rt_support.h" +#include "tensorflow/lite/experimental/lrt/c/lite_rt_common.h" +#include "tensorflow/lite/experimental/lrt/c/lite_rt_model.h" +#include "tensorflow/lite/experimental/lrt/c/lite_rt_op_code.h" +#include "tensorflow/lite/experimental/lrt/cc/lite_rt_support.h" #define _D_MATCH_TRUE(v) \ { \ @@ -352,4 +352,4 @@ inline bool MatchNoBuffer(LrtTensor tensor) { } } // namespace graph_tools -#endif // TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_LRT_CORE_GRAPH_TOOLS_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LRT_CORE_GRAPH_TOOLS_H_ diff --git a/tensorflow/compiler/mlir/lite/experimental/lrt/core/lite_rt_common.cc b/tensorflow/lite/experimental/lrt/core/lite_rt_common.cc similarity index 92% rename from tensorflow/compiler/mlir/lite/experimental/lrt/core/lite_rt_common.cc rename to tensorflow/lite/experimental/lrt/core/lite_rt_common.cc index 0af9d776ffdd40..5592709e6c9548 100644 --- a/tensorflow/compiler/mlir/lite/experimental/lrt/core/lite_rt_common.cc +++ b/tensorflow/lite/experimental/lrt/core/lite_rt_common.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_common.h" +#include "tensorflow/lite/experimental/lrt/c/lite_rt_common.h" struct LrtStatusT { LrtStatusCode code; diff --git a/tensorflow/compiler/mlir/lite/experimental/lrt/core/lite_rt_model_init.cc b/tensorflow/lite/experimental/lrt/core/lite_rt_model_init.cc similarity index 97% rename from tensorflow/compiler/mlir/lite/experimental/lrt/core/lite_rt_model_init.cc rename to tensorflow/lite/experimental/lrt/core/lite_rt_model_init.cc index 4d9d94f18de0fd..d569edd6a9b0d0 100644 --- a/tensorflow/compiler/mlir/lite/experimental/lrt/core/lite_rt_model_init.cc +++ b/tensorflow/lite/experimental/lrt/core/lite_rt_model_init.cc @@ -29,14 +29,13 @@ #include "absl/log/check.h" #include "flatbuffers/verifier.h" // from @flatbuffers -#include "tensorflow/compiler/mlir/lite/allocation.h" #include "tensorflow/compiler/mlir/lite/core/model_builder_base.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_common.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_model.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_op_code.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/cc/lite_rt_support.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/core/lite_rt_model_init.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/core/model.h" +#include "tensorflow/lite/experimental/lrt/c/lite_rt_common.h" +#include "tensorflow/lite/experimental/lrt/c/lite_rt_model.h" +#include "tensorflow/lite/experimental/lrt/c/lite_rt_op_code.h" +#include "tensorflow/lite/experimental/lrt/cc/lite_rt_support.h" +#include "tensorflow/lite/experimental/lrt/core/lite_rt_model_init.h" +#include "tensorflow/lite/experimental/lrt/core/model.h" #include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/stderr_reporter.h" diff --git a/tensorflow/compiler/mlir/lite/experimental/lrt/core/lite_rt_model_init.h b/tensorflow/lite/experimental/lrt/core/lite_rt_model_init.h similarity index 86% rename from tensorflow/compiler/mlir/lite/experimental/lrt/core/lite_rt_model_init.h rename to tensorflow/lite/experimental/lrt/core/lite_rt_model_init.h index 1e5219ccd7d050..61b044931c6157 100644 --- a/tensorflow/compiler/mlir/lite/experimental/lrt/core/lite_rt_model_init.h +++ b/tensorflow/lite/experimental/lrt/core/lite_rt_model_init.h @@ -12,10 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_LRT_CORE_LITE_RT_MODEL_INIT_H_ -#define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_LRT_CORE_LITE_RT_MODEL_INIT_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LRT_CORE_LITE_RT_MODEL_INIT_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LRT_CORE_LITE_RT_MODEL_INIT_H_ -#include "tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_model.h" +#include "tensorflow/lite/experimental/lrt/c/lite_rt_model.h" #ifdef __cplusplus extern "C" { @@ -63,4 +63,4 @@ using UniqueLrtModel = std::unique_ptr; #endif // __cplusplus -#endif // TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_LRT_CORE_LITE_RT_MODEL_INIT_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LRT_CORE_LITE_RT_MODEL_INIT_H_ diff --git a/tensorflow/compiler/mlir/lite/experimental/lrt/core/model.cc b/tensorflow/lite/experimental/lrt/core/model.cc similarity index 93% rename from tensorflow/compiler/mlir/lite/experimental/lrt/core/model.cc rename to tensorflow/lite/experimental/lrt/core/model.cc index 02e024532c98fe..44a96f15feb31f 100644 --- a/tensorflow/compiler/mlir/lite/experimental/lrt/core/model.cc +++ b/tensorflow/lite/experimental/lrt/core/model.cc @@ -12,13 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "tensorflow/compiler/mlir/lite/experimental/lrt/core/model.h" +#include "tensorflow/lite/experimental/lrt/core/model.h" #include -#include "tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_common.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_model.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_op_code.h" +#include "tensorflow/lite/experimental/lrt/c/lite_rt_common.h" +#include "tensorflow/lite/experimental/lrt/c/lite_rt_model.h" +#include "tensorflow/lite/experimental/lrt/c/lite_rt_op_code.h" // // Model diff --git a/tensorflow/compiler/mlir/lite/experimental/lrt/core/model.h b/tensorflow/lite/experimental/lrt/core/model.h similarity index 91% rename from tensorflow/compiler/mlir/lite/experimental/lrt/core/model.h rename to tensorflow/lite/experimental/lrt/core/model.h index 72d0f7d4e0e8af..ad7709ca331fe4 100644 --- a/tensorflow/compiler/mlir/lite/experimental/lrt/core/model.h +++ b/tensorflow/lite/experimental/lrt/core/model.h @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_LRT_CORE_MODEL_H_ -#define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_LRT_CORE_MODEL_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LRT_CORE_MODEL_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LRT_CORE_MODEL_H_ #include #ifndef NDEBUG @@ -24,11 +24,11 @@ #include #include -#include "tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_common.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_model.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_op_code.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/cc/lite_rt_support.h" #include "tensorflow/lite/core/c/c_api_types.h" +#include "tensorflow/lite/experimental/lrt/c/lite_rt_common.h" +#include "tensorflow/lite/experimental/lrt/c/lite_rt_model.h" +#include "tensorflow/lite/experimental/lrt/c/lite_rt_op_code.h" +#include "tensorflow/lite/experimental/lrt/cc/lite_rt_support.h" #include "tensorflow/lite/schema/schema_generated.h" // @@ -239,4 +239,4 @@ inline void DumpOp(const LrtOpT& op) { _LRT_D_MSG(""); \ debug::DumpOp(op); -#endif // TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_LRT_CORE_MODEL_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LRT_CORE_MODEL_H_ diff --git a/tensorflow/compiler/mlir/lite/experimental/lrt/core/model_test.cc b/tensorflow/lite/experimental/lrt/core/model_test.cc similarity index 95% rename from tensorflow/compiler/mlir/lite/experimental/lrt/core/model_test.cc rename to tensorflow/lite/experimental/lrt/core/model_test.cc index 34923c6bc433e6..6ad6356b977381 100644 --- a/tensorflow/compiler/mlir/lite/experimental/lrt/core/model_test.cc +++ b/tensorflow/lite/experimental/lrt/core/model_test.cc @@ -26,13 +26,13 @@ #include "flatbuffers/verifier.h" // from @flatbuffers #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallVector.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_common.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_model.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_op_code.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/cc/lite_rt_support.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/core/graph_tools.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/core/lite_rt_model_init.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/test_data/test_data_util.h" +#include "tensorflow/lite/experimental/lrt/c/lite_rt_common.h" +#include "tensorflow/lite/experimental/lrt/c/lite_rt_model.h" +#include "tensorflow/lite/experimental/lrt/c/lite_rt_op_code.h" +#include "tensorflow/lite/experimental/lrt/cc/lite_rt_support.h" +#include "tensorflow/lite/experimental/lrt/core/graph_tools.h" +#include "tensorflow/lite/experimental/lrt/core/lite_rt_model_init.h" +#include "tensorflow/lite/experimental/lrt/test_data/test_data_util.h" #include "tensorflow/lite/schema/schema_generated.h" namespace { diff --git a/tensorflow/compiler/mlir/lite/experimental/lrt/examples/BUILD b/tensorflow/lite/experimental/lrt/examples/BUILD similarity index 63% rename from tensorflow/compiler/mlir/lite/experimental/lrt/examples/BUILD rename to tensorflow/lite/experimental/lrt/examples/BUILD index fbb21622ab2d30..b9b31dce413d76 100644 --- a/tensorflow/compiler/mlir/lite/experimental/lrt/examples/BUILD +++ b/tensorflow/lite/experimental/lrt/examples/BUILD @@ -21,16 +21,16 @@ cc_library( name = "mul_op_plugin", srcs = ["mul_op_plugin.cc"], deps = [ - "//tensorflow/compiler/mlir/lite/experimental/lrt/c:lite_rt_c_api", - "//tensorflow/compiler/mlir/lite/experimental/lrt/cc:lite_rt_cc_api", - "//tensorflow/compiler/mlir/lite/experimental/lrt/core:graph_tools", + "//tensorflow/lite/experimental/lrt/c:lite_rt_c_api", + "//tensorflow/lite/experimental/lrt/cc:lite_rt_cc_api", + "//tensorflow/lite/experimental/lrt/core:graph_tools", ], ) cc_shared_library( name = "mul_op_plugin_so", shared_lib_name = "mul_op_plugin.so", - visibility = ["//tensorflow/compiler/mlir/lite/experimental/lrt:__subpackages__"], + visibility = ["//tensorflow/lite/experimental/lrt:__subpackages__"], deps = [":mul_op_plugin"], ) @@ -40,11 +40,11 @@ cc_test( tags = ["no_oss"], deps = [ ":mul_op_plugin", # buildcleaner: keep - "//tensorflow/compiler/mlir/lite/experimental/lrt/c:lite_rt_c_api", - "//tensorflow/compiler/mlir/lite/experimental/lrt/cc:lite_rt_cc_api", - "//tensorflow/compiler/mlir/lite/experimental/lrt/core:graph_tools", - "//tensorflow/compiler/mlir/lite/experimental/lrt/core:model", - "//tensorflow/compiler/mlir/lite/experimental/lrt/test_data:test_data_util", + "//tensorflow/lite/experimental/lrt/c:lite_rt_c_api", + "//tensorflow/lite/experimental/lrt/cc:lite_rt_cc_api", + "//tensorflow/lite/experimental/lrt/core:graph_tools", + "//tensorflow/lite/experimental/lrt/core:model", + "//tensorflow/lite/experimental/lrt/test_data:test_data_util", "@com_google_absl//absl/log:check", "@com_google_googletest//:gtest_main", ], diff --git a/tensorflow/compiler/mlir/lite/experimental/lrt/examples/mul_op_plugin.cc b/tensorflow/lite/experimental/lrt/examples/mul_op_plugin.cc similarity index 91% rename from tensorflow/compiler/mlir/lite/experimental/lrt/examples/mul_op_plugin.cc rename to tensorflow/lite/experimental/lrt/examples/mul_op_plugin.cc index 867195abf29377..78e68442a0855f 100644 --- a/tensorflow/compiler/mlir/lite/experimental/lrt/examples/mul_op_plugin.cc +++ b/tensorflow/lite/experimental/lrt/examples/mul_op_plugin.cc @@ -19,12 +19,12 @@ #include #include -#include "tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_common.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_compiler_plugin.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_model.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_op_code.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/cc/lite_rt_support.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/core/graph_tools.h" +#include "tensorflow/lite/experimental/lrt/c/lite_rt_common.h" +#include "tensorflow/lite/experimental/lrt/c/lite_rt_compiler_plugin.h" +#include "tensorflow/lite/experimental/lrt/c/lite_rt_model.h" +#include "tensorflow/lite/experimental/lrt/c/lite_rt_op_code.h" +#include "tensorflow/lite/experimental/lrt/cc/lite_rt_support.h" +#include "tensorflow/lite/experimental/lrt/core/graph_tools.h" // // Configurations diff --git a/tensorflow/compiler/mlir/lite/experimental/lrt/examples/mul_op_plugin_test.cc b/tensorflow/lite/experimental/lrt/examples/mul_op_plugin_test.cc similarity index 83% rename from tensorflow/compiler/mlir/lite/experimental/lrt/examples/mul_op_plugin_test.cc rename to tensorflow/lite/experimental/lrt/examples/mul_op_plugin_test.cc index 46eafd32136a38..703d5af8fefc67 100644 --- a/tensorflow/compiler/mlir/lite/experimental/lrt/examples/mul_op_plugin_test.cc +++ b/tensorflow/lite/experimental/lrt/examples/mul_op_plugin_test.cc @@ -19,13 +19,13 @@ #include #include "absl/log/check.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_compiler_plugin.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_model.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_op_code.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/cc/lite_rt_support.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/core/graph_tools.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/core/model.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/test_data/test_data_util.h" +#include "tensorflow/lite/experimental/lrt/c/lite_rt_compiler_plugin.h" +#include "tensorflow/lite/experimental/lrt/c/lite_rt_model.h" +#include "tensorflow/lite/experimental/lrt/c/lite_rt_op_code.h" +#include "tensorflow/lite/experimental/lrt/cc/lite_rt_support.h" +#include "tensorflow/lite/experimental/lrt/core/graph_tools.h" +#include "tensorflow/lite/experimental/lrt/core/model.h" +#include "tensorflow/lite/experimental/lrt/test_data/test_data_util.h" namespace { diff --git a/tensorflow/compiler/mlir/lite/experimental/lrt/test_data/BUILD b/tensorflow/lite/experimental/lrt/test_data/BUILD similarity index 75% rename from tensorflow/compiler/mlir/lite/experimental/lrt/test_data/BUILD rename to tensorflow/lite/experimental/lrt/test_data/BUILD index cee72cde127b9a..1123dc36108ca7 100644 --- a/tensorflow/compiler/mlir/lite/experimental/lrt/test_data/BUILD +++ b/tensorflow/lite/experimental/lrt/test_data/BUILD @@ -14,7 +14,7 @@ package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = ["//tensorflow/compiler/mlir/lite/experimental/lrt:__subpackages__"], + default_visibility = ["//tensorflow/lite/experimental/lrt:__subpackages__"], ) # TODO: b/365295276 - Make custom rule and move to `.sh`. @@ -37,7 +37,7 @@ genrule( srcs = glob(["*.mlir"]), outs = [s.removesuffix(".mlir") + ".tflite" for s in glob(["*.mlir"])], cmd = CMD, - tools = ["//tensorflow/compiler/mlir/lite:tf_tfl_translate"], + tools = [CONVERTER], ) cc_library( @@ -46,10 +46,10 @@ cc_library( hdrs = ["test_data_util.h"], data = [":tflite_test_data"], deps = [ - "//tensorflow/compiler/mlir/lite/experimental/lrt/c:lite_rt_c_api", - "//tensorflow/compiler/mlir/lite/experimental/lrt/cc:lite_rt_cc_api", - "//tensorflow/compiler/mlir/lite/experimental/lrt/core:lite_rt_model_init", - "//tensorflow/compiler/mlir/lite/experimental/lrt/core:model", + "//tensorflow/lite/experimental/lrt/c:lite_rt_c_api", + "//tensorflow/lite/experimental/lrt/cc:lite_rt_cc_api", + "//tensorflow/lite/experimental/lrt/core:lite_rt_model_init", + "//tensorflow/lite/experimental/lrt/core:model", "@com_google_absl//absl/log:check", "@local_tsl//tsl/platform", ], diff --git a/tensorflow/compiler/mlir/lite/experimental/lrt/test_data/add_cst.mlir b/tensorflow/lite/experimental/lrt/test_data/add_cst.mlir similarity index 100% rename from tensorflow/compiler/mlir/lite/experimental/lrt/test_data/add_cst.mlir rename to tensorflow/lite/experimental/lrt/test_data/add_cst.mlir diff --git a/tensorflow/compiler/mlir/lite/experimental/lrt/test_data/add_simple.mlir b/tensorflow/lite/experimental/lrt/test_data/add_simple.mlir similarity index 100% rename from tensorflow/compiler/mlir/lite/experimental/lrt/test_data/add_simple.mlir rename to tensorflow/lite/experimental/lrt/test_data/add_simple.mlir diff --git a/tensorflow/compiler/mlir/lite/experimental/lrt/test_data/mul_simple.mlir b/tensorflow/lite/experimental/lrt/test_data/mul_simple.mlir similarity index 100% rename from tensorflow/compiler/mlir/lite/experimental/lrt/test_data/mul_simple.mlir rename to tensorflow/lite/experimental/lrt/test_data/mul_simple.mlir diff --git a/tensorflow/compiler/mlir/lite/experimental/lrt/test_data/simple_multi_op.mlir b/tensorflow/lite/experimental/lrt/test_data/simple_multi_op.mlir similarity index 100% rename from tensorflow/compiler/mlir/lite/experimental/lrt/test_data/simple_multi_op.mlir rename to tensorflow/lite/experimental/lrt/test_data/simple_multi_op.mlir diff --git a/tensorflow/compiler/mlir/lite/experimental/lrt/test_data/test_data_util.h b/tensorflow/lite/experimental/lrt/test_data/test_data_util.h similarity index 78% rename from tensorflow/compiler/mlir/lite/experimental/lrt/test_data/test_data_util.h rename to tensorflow/lite/experimental/lrt/test_data/test_data_util.h index 889a19711b190f..8ef958781c80bf 100644 --- a/tensorflow/compiler/mlir/lite/experimental/lrt/test_data/test_data_util.h +++ b/tensorflow/lite/experimental/lrt/test_data/test_data_util.h @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_LRT_TEST_DATA_TEST_DATA_UTIL_H_ -#define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_LRT_TEST_DATA_TEST_DATA_UTIL_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LRT_TEST_DATA_TEST_DATA_UTIL_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LRT_TEST_DATA_TEST_DATA_UTIL_H_ // NOLINTNEXTLINE #include @@ -21,9 +21,9 @@ #include #include "absl/log/check.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_model.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/cc/lite_rt_support.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/core/lite_rt_model_init.h" +#include "tensorflow/lite/experimental/lrt/c/lite_rt_model.h" +#include "tensorflow/lite/experimental/lrt/cc/lite_rt_support.h" +#include "tensorflow/lite/experimental/lrt/core/lite_rt_model_init.h" #include "tsl/platform/platform.h" #define _ASSERT_RESULT_OK_ASSIGN(decl, expr, result) \ @@ -46,7 +46,7 @@ inline std::string GetTestFilePath(std::string_view filename) { static constexpr std::string_view kTestDataDir = - "tensorflow/compiler/mlir/lite/experimental/lrt/" + "tensorflow/lite/experimental/lrt/" "test_data/"; std::filesystem::path result_path; @@ -68,4 +68,4 @@ inline UniqueLrtModel LoadTestFileModel(std::string_view filename) { return UniqueLrtModel(model); } -#endif // TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_LRT_TEST_DATA_TEST_DATA_UTIL_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LRT_TEST_DATA_TEST_DATA_UTIL_H_ From 571137ac8ed47220850f43b5838e9d66303584c8 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 25 Sep 2024 16:37:33 -0700 Subject: [PATCH 287/483] [refactor]Move shutdown barrier hook to a separate method. PiperOrigin-RevId: 678892880 --- .../coordination/coordination_service.cc | 71 +++++++++++-------- 1 file changed, 41 insertions(+), 30 deletions(-) diff --git a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service.cc b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service.cc index 8f6cd25717157f..bb10595516c73b 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service.cc +++ b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service.cc @@ -174,7 +174,6 @@ class CoordinationServiceStandaloneImpl : public CoordinationServiceInterface { ABSL_LOCKS_EXCLUDED(state_mu_); void SetTaskError(std::string_view task_name, absl::Status error) ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_); - void AggregateClusterDevices() ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_); absl::Status DisconnectTask(const CoordinatedTask& task) ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_); @@ -196,6 +195,12 @@ class CoordinationServiceStandaloneImpl : public CoordinationServiceInterface { void PassBarrier(std::string_view barrier_id, absl::Status result, BarrierState* barrier) ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_); + // Post-barrier hook to aggregate device info. + void AggregateClusterDevices() ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_); + // Post-shutdown barrier hook to disconnect tasks that acked and propagate + // errors to those that have not. + void CompleteShutdownAfterBarrier(absl::Status result, BarrierState* barrier) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_); // Check if participating tasks are specified correctly across barrier calls. bool ValidateTaskArgs( const std::vector& tasks_args, @@ -1435,36 +1440,10 @@ void CoordinationServiceStandaloneImpl::PassBarrier(std::string_view barrier_id, cluster_state_[GetTaskName(task)]->ExitBarrier(barrier_id); } - // Special hook for shutdown barrier to disconnect tasks at the barrier. + // Special hook for shutdown barrier to disconnect tasks at the barrier and + // propagate errors to those that have not. if (barrier_id == shutdown_barrier_id_) { - if (result.ok()) { - LOG(INFO) << "Shutdown barrier in coordination service has passed."; - } else { - LOG(ERROR) << "Shutdown barrier in coordination service has failed:\n" - << result - << "\nThis suggests that the workers are out of sync. Either " - "at least one worker is too fast in its execution / " - "crashed early or too slow / hanging. Check the logs for " - "an earlier error to identify the root cause."; - } - absl::Status shutdown_error = MakeCoordinationError(absl::InternalError( - absl::StrCat("Shutdown barrier has failed, but this task is not at the " - "barrier yet.\nBarrier result: '", - barrier->result.message()))); - for (const auto& [task, at_barrier] : barrier->tasks_at_barrier) { - if (at_barrier) { - // Disconnect tasks that reached the barrier. - absl::Status disconnect_status = DisconnectTask(task); - if (!disconnect_status.ok()) { - LOG(ERROR) << disconnect_status; - } - } else { - // Propagate errors to straggling tasks that have not reached the - // barrier. The barrier must have failed if any task did not reach the - // barrier. - ReportServiceErrorToTaskAsync(task, shutdown_error); - } - } + CompleteShutdownAfterBarrier(result, barrier); } barrier->tasks_at_barrier.clear(); ongoing_barriers_.erase(barrier_id); @@ -1557,6 +1536,38 @@ void CoordinationServiceStandaloneImpl::AggregateClusterDevices() { cluster_devices_ = post_aggregate_device_fn_(cluster_devices_); } } + +void CoordinationServiceStandaloneImpl::CompleteShutdownAfterBarrier( + absl::Status result, BarrierState* barrier) { + if (result.ok()) { + LOG(INFO) << "Shutdown barrier in coordination service has passed."; + } else { + LOG(ERROR) << "Shutdown barrier in coordination service has failed:\n" + << result + << "\nThis suggests that the workers are out of sync. Either " + "at least one worker is too fast in its execution / " + "crashed early or too slow / hanging. Check the logs for " + "an earlier error to identify the root cause."; + } + absl::Status shutdown_error = MakeCoordinationError(absl::InternalError( + absl::StrCat("Shutdown barrier has failed, but this task is not at the " + "barrier yet.\nBarrier result: '", + barrier->result.message()))); + for (const auto& [task, at_barrier] : barrier->tasks_at_barrier) { + if (at_barrier) { + // Disconnect tasks that reached the barrier. + absl::Status disconnect_status = DisconnectTask(task); + if (!disconnect_status.ok()) { + LOG(ERROR) << disconnect_status; + } + } else { + // Propagate errors to straggling tasks that have not reached the + // barrier. The barrier must have failed if any task did not reach the + // barrier. + ReportServiceErrorToTaskAsync(task, shutdown_error); + } + } +} } // namespace std::unique_ptr EnableCoordinationService( From 0b13c8e0d26daacf71ed70959ab6d0e5ba32a055 Mon Sep 17 00:00:00 2001 From: Ionel Gog Date: Wed, 25 Sep 2024 16:40:23 -0700 Subject: [PATCH 288/483] [IFRT] Add simple serialization and deserialization of IFRT IR programs. PiperOrigin-RevId: 678893745 --- .../xla/xla/python/ifrt/hlo/hlo_program.h | 2 - third_party/xla/xla/python/ifrt/ir/BUILD | 41 ++++++ .../xla/xla/python/ifrt/ir/ifrt_ir_program.h | 11 ++ .../python/ifrt/ir/ifrt_ir_program_serdes.cc | 92 +++++++++++++ .../ifrt/ir/ifrt_ir_program_serdes_test.cc | 126 ++++++++++++++++++ third_party/xla/xla/python/ifrt/support/BUILD | 29 ++++ .../xla/python/ifrt/support/module_parsing.cc | 73 ++++++++++ .../xla/python/ifrt/support/module_parsing.h | 40 ++++++ 8 files changed, 412 insertions(+), 2 deletions(-) create mode 100644 third_party/xla/xla/python/ifrt/ir/ifrt_ir_program_serdes.cc create mode 100644 third_party/xla/xla/python/ifrt/ir/ifrt_ir_program_serdes_test.cc create mode 100644 third_party/xla/xla/python/ifrt/support/module_parsing.cc create mode 100644 third_party/xla/xla/python/ifrt/support/module_parsing.h diff --git a/third_party/xla/xla/python/ifrt/hlo/hlo_program.h b/third_party/xla/xla/python/ifrt/hlo/hlo_program.h index b084c987b9931f..37802019e4bb7b 100644 --- a/third_party/xla/xla/python/ifrt/hlo/hlo_program.h +++ b/third_party/xla/xla/python/ifrt/hlo/hlo_program.h @@ -17,9 +17,7 @@ limitations under the License. #define XLA_PYTHON_IFRT_HLO_HLO_PROGRAM_H_ #include -#include #include -#include #include "llvm/Support/ExtensibleRTTI.h" #include "mlir/IR/BuiltinOps.h" diff --git a/third_party/xla/xla/python/ifrt/ir/BUILD b/third_party/xla/xla/python/ifrt/ir/BUILD index 22d1979f9e3ed0..e27d2cdc70374d 100644 --- a/third_party/xla/xla/python/ifrt/ir/BUILD +++ b/third_party/xla/xla/python/ifrt/ir/BUILD @@ -1,4 +1,5 @@ load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") +load("//xla:xla.bzl", "xla_cc_test") load("//xla/tsl:tsl.default.bzl", "get_compatible_with_portable") package( @@ -184,3 +185,43 @@ cc_library( "@llvm-project//mlir:IR", ], ) + +cc_library( + name = "ifrt_ir_program_serdes", + srcs = ["ifrt_ir_program_serdes.cc"], + compatible_with = get_compatible_with_portable(), + deps = [ + ":ifrt_ir_program", + "//xla/mlir/utils:error_util", + "//xla/python/ifrt:serdes", + "//xla/python/ifrt/support:module_parsing", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:BytecodeWriter", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + "@local_tsl//tsl/platform:statusor", + ], + alwayslink = True, +) + +xla_cc_test( + name = "ifrt_ir_program_serdes_test", + srcs = ["ifrt_ir_program_serdes_test.cc"], + deps = [ + ":ifrt_ir_program", + ":ifrt_ir_program_serdes", + "//xla/python/ifrt:serdes", + "//xla/python/ifrt/support:module_parsing", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@local_tsl//tsl/platform:status_matchers", + "@local_tsl//tsl/platform:statusor", + ], +) diff --git a/third_party/xla/xla/python/ifrt/ir/ifrt_ir_program.h b/third_party/xla/xla/python/ifrt/ir/ifrt_ir_program.h index 809afb7819322f..c8f5e6cde1ca1d 100644 --- a/third_party/xla/xla/python/ifrt/ir/ifrt_ir_program.h +++ b/third_party/xla/xla/python/ifrt/ir/ifrt_ir_program.h @@ -25,6 +25,8 @@ limitations under the License. #include "absl/status/statusor.h" #include "llvm/Support/ExtensibleRTTI.h" #include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OwningOpRef.h" #include "xla/python/ifrt/compiler.h" #include "xla/python/ifrt/device.h" #include "xla/python/ifrt/executable.h" @@ -37,10 +39,19 @@ struct IfrtIRProgram : llvm::RTTIExtends { IfrtIRProgram() = default; explicit IfrtIRProgram(mlir::ModuleOp mlir_module) : mlir_module(std::move(mlir_module)) {} + IfrtIRProgram(std::unique_ptr context, + mlir::OwningOpRef module) + : mlir_module(*module), + mlir_context(std::move(context)), + owning_mlir_module(std::move(module)) {} mlir::ModuleOp mlir_module; static char ID; // NOLINT + + private: + std::unique_ptr mlir_context; + mlir::OwningOpRef owning_mlir_module; }; // CompileOptions for an IFRT IR program. diff --git a/third_party/xla/xla/python/ifrt/ir/ifrt_ir_program_serdes.cc b/third_party/xla/xla/python/ifrt/ir/ifrt_ir_program_serdes.cc new file mode 100644 index 00000000000000..e666dbf6275353 --- /dev/null +++ b/third_party/xla/xla/python/ifrt/ir/ifrt_ir_program_serdes.cc @@ -0,0 +1,92 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/ExtensibleRTTI.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/Bytecode/BytecodeWriter.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OwningOpRef.h" +#include "mlir/Support/LLVM.h" +#include "xla/mlir/utils/error_util.h" +#include "xla/python/ifrt/ir/ifrt_ir_program.h" +#include "xla/python/ifrt/serdes.h" +#include "xla/python/ifrt/support/module_parsing.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace ifrt { + +namespace { + +class IfrtIRProgramSerDes + : public llvm::RTTIExtends { + public: + absl::string_view type_name() const override { + return "xla::ifrt::IfrtIRProgram"; + } + + absl::StatusOr Serialize(Serializable& serializable) override { + const auto& program = llvm::cast(serializable); + if (program.mlir_module == nullptr) { + return absl::InvalidArgumentError("Unable to serialize null MLIR module"); + } + std::string serialized; + llvm::raw_string_ostream out(serialized); + mlir::BytecodeWriterConfig config; + mlir::BaseScopedDiagnosticHandler diagnostic_handler( + program.mlir_module->getContext()); + if (mlir::failed( + mlir::writeBytecodeToFile(program.mlir_module, out, config))) { + return absl::InvalidArgumentError( + absl::StrFormat("Failed to serialize IFRT IR module string: %s", + diagnostic_handler.ConsumeStatus().message())); + } + return serialized; + } + + absl::StatusOr> Deserialize( + const std::string& serialized, + std::unique_ptr) override { + auto context = std::make_unique(); + TF_ASSIGN_OR_RETURN(auto module, + support::ParseMlirModuleString(serialized, *context)); + return std::make_unique(std::move(context), + std::move(module)); + } + + static char ID; // NOLINT +}; + +char IfrtIRProgramSerDes::ID = 0; // NOLINT + +// clang-format off +bool register_ifrt_ir_program_serdes = ([]() { + RegisterSerDes(std::make_unique()); +}(), true); +// clang-format on + +} // namespace + +} // namespace ifrt +} // namespace xla diff --git a/third_party/xla/xla/python/ifrt/ir/ifrt_ir_program_serdes_test.cc b/third_party/xla/xla/python/ifrt/ir/ifrt_ir_program_serdes_test.cc new file mode 100644 index 00000000000000..019f3599d73ed2 --- /dev/null +++ b/third_party/xla/xla/python/ifrt/ir/ifrt_ir_program_serdes_test.cc @@ -0,0 +1,126 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include + +#include +#include +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/IR/OwningOpRef.h" +#include "xla/python/ifrt/ir/ifrt_ir_program.h" +#include "xla/python/ifrt/serdes.h" +#include "xla/python/ifrt/support/module_parsing.h" +#include "tsl/platform/status_matchers.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace ifrt { +namespace { + +using ::testing::HasSubstr; +using ::testing::Not; +using ::tsl::testing::StatusIs; + +std::string PrintModule(mlir::ModuleOp module) { + std::string module_str; + llvm::raw_string_ostream os(module_str); + module->print(os, mlir::OpPrintingFlags().enableDebugInfo()); + return module_str; +} + +TEST(IfrtIRProgramSerDesTest, RoundTrip) { + static constexpr absl::string_view kMlirModuleStr = R"( +!array = !ifrt.array, #ifrt.sharding_param<1 to [0] on 1>, [0]> +module { + func.func @main(%arg0: !array) -> !array attributes {ifrt.function} { + %0, %ctrl_0 = ifrt.Call @add_one::@main(%arg0) on devices [0] + : (!array) -> !array + return %0 : !array + } + + module @add_one { + func.func @main(%arg0: tensor<2xi32>) -> tensor<2xi32> { + %0 = mhlo.constant dense<1> : tensor<2xi32> + %1 = mhlo.add %arg0, %0 : tensor<2xi32> + return %1 : tensor<2xi32> + } + } +} + )"; + + Serialized serialized; + auto context = std::make_unique(); + TF_ASSERT_OK_AND_ASSIGN( + mlir::OwningOpRef module, + support::ParseMlirModuleString(kMlirModuleStr, *context)); + auto initial_program = + std::make_unique(std::move(context), std::move(module)); + TF_ASSERT_OK_AND_ASSIGN(serialized, Serialize(*initial_program)); + + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr deserialized_program, + Deserialize(serialized, /*options=*/nullptr)); + + EXPECT_EQ(PrintModule(initial_program->mlir_module), + PrintModule(deserialized_program->mlir_module)); +} + +TEST(IfrtIRProgramSerDesTest, DeserializationError) { + static constexpr absl::string_view kMlirModuleStr = R"( +!array = !ifrt.array, #ifrt.sharding_param<1 to [0] on 1>, [0]> +module { + func.func @main(%arg0: !array) -> !array attributes {ifrt.function} { + %0, %ctrl_0 = ifrt.Call @add_one::@main(%arg0) on devices [0] + : (!array) -> !array + return %0 : !array + } + + module @add_one { + func.func @main(%arg0: tensor<2xi32>) -> tensor<2xi32> { + %0 = mhlo.constant dense<1> : tensor<2xi32> + %1 = mhlo.add %arg0, %0 : tensor<2xi32> + return %1 : tensor<2xi32> + } + } +} + )"; + Serialized serialized; + { + auto context = std::make_unique(); + TF_ASSERT_OK_AND_ASSIGN( + mlir::OwningOpRef module, + support::ParseMlirModuleString(kMlirModuleStr, *context)); + auto program = + std::make_unique(std::move(context), std::move(module)); + TF_ASSERT_OK_AND_ASSIGN(serialized, Serialize(*program)); + } + + serialized.set_data("invalid data"); + + EXPECT_THAT(Deserialize(serialized, /*options=*/nullptr), + StatusIs(Not(absl::StatusCode::kOk), + HasSubstr("Failed to parse IFRT IR module string"))); +} + +} // namespace +} // namespace ifrt +} // namespace xla diff --git a/third_party/xla/xla/python/ifrt/support/BUILD b/third_party/xla/xla/python/ifrt/support/BUILD index 1c287ac13ad2ce..05893d160e9296 100644 --- a/third_party/xla/xla/python/ifrt/support/BUILD +++ b/third_party/xla/xla/python/ifrt/support/BUILD @@ -1,14 +1,43 @@ load("//xla:xla.bzl", "xla_cc_test") +load("//xla/tsl:tsl.default.bzl", "get_compatible_with_portable") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], licenses = ["notice"], ) +cc_library( + name = "module_parsing", + srcs = ["module_parsing.cc"], + hdrs = ["module_parsing.h"], + compatible_with = get_compatible_with_portable(), + visibility = ["//xla/python/ifrt:friends"], + deps = [ + "//xla/mlir/utils:error_util", + "//xla/mlir_hlo:hlo_dialect_registration", + "//xla/python/ifrt/ir", + "//xla/python/ifrt/ir/transforms:built_in_spmd_expansions", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:FuncExtensions", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:MLProgramDialect", + "@llvm-project//mlir:Parser", + "@llvm-project//mlir:ShapeDialect", + "@shardy//shardy/dialect/sdy/ir:register", + "@stablehlo//:register", + ], +) + cc_library( name = "sharding_conversions", srcs = ["sharding_conversions.cc"], hdrs = ["sharding_conversions.h"], + compatible_with = get_compatible_with_portable(), visibility = ["//xla/python/ifrt:friends"], deps = [ "//xla:xla_data_proto_cc", diff --git a/third_party/xla/xla/python/ifrt/support/module_parsing.cc b/third_party/xla/xla/python/ifrt/support/module_parsing.cc new file mode 100644 index 00000000000000..ef514108fa9374 --- /dev/null +++ b/third_party/xla/xla/python/ifrt/support/module_parsing.cc @@ -0,0 +1,73 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/python/ifrt/support/module_parsing.h" + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/Extensions/AllExtensions.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MLProgram/IR/MLProgram.h" +#include "mlir/Dialect/Shape/IR/Shape.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OwningOpRef.h" +#include "mlir/Parser/Parser.h" +#include "shardy/dialect/sdy/ir/register.h" +#include "stablehlo/dialect/Register.h" +#include "xla/mlir/utils/error_util.h" +#include "xla/mlir_hlo/mhlo/IR/register.h" +#include "xla/python/ifrt/ir/ifrt_dialect.h" +#include "xla/python/ifrt/ir/transforms/built_in_spmd_expansions.h" + +namespace xla { +namespace ifrt { +namespace support { + +void RegisterMlirDialects(mlir::MLIRContext& context) { + mlir::DialectRegistry registry; + registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); + mlir::func::registerAllExtensions(registry); + mlir::mhlo::registerAllMhloDialects(registry); + mlir::sdy::registerAllDialects(registry); + mlir::stablehlo::registerAllDialects(registry); + xla::ifrt::AttachBuiltInSpmdExpansions(registry); + context.appendDialectRegistry(registry); +} + +absl::StatusOr> ParseMlirModuleString( + absl::string_view mlir_module_str, mlir::MLIRContext& context) { + RegisterMlirDialects(context); + mlir::BaseScopedDiagnosticHandler diagnostic_handler(&context); + mlir::OwningOpRef module = + mlir::parseSourceString(mlir_module_str, &context); + if (!module) { + return absl::InvalidArgumentError( + absl::StrFormat("Failed to parse IFRT IR module string: %s", + diagnostic_handler.ConsumeStatus().message())); + } + return module; +} + +} // namespace support +} // namespace ifrt +} // namespace xla diff --git a/third_party/xla/xla/python/ifrt/support/module_parsing.h b/third_party/xla/xla/python/ifrt/support/module_parsing.h new file mode 100644 index 00000000000000..b93f0f8dbb84e6 --- /dev/null +++ b/third_party/xla/xla/python/ifrt/support/module_parsing.h @@ -0,0 +1,40 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_PYTHON_IFRT_SUPPORT_MODULE_PARSING_H_ +#define XLA_PYTHON_IFRT_SUPPORT_MODULE_PARSING_H_ + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OwningOpRef.h" + +namespace xla { +namespace ifrt { +namespace support { + +// Registers all dialects required by IFRT IR modules. +void RegisterMlirDialects(mlir::MLIRContext& context); + +// Converts an IFRT IR module string to an mlir::Module. +absl::StatusOr> ParseMlirModuleString( + absl::string_view mlir_module_str, mlir::MLIRContext& context); + +} // namespace support +} // namespace ifrt +} // namespace xla + +#endif // XLA_PYTHON_IFRT_SUPPORT_MODULE_PARSING_H_ From 7b38658d0d4e121700d5993a4462137f771c9d76 Mon Sep 17 00:00:00 2001 From: Matthias Kramm Date: Wed, 25 Sep 2024 17:40:00 -0700 Subject: [PATCH 289/483] Fix floating point comparisons in the presence of non-default MXCSR settings. Some models (e.g. yggdrasil-decision-forests) store model information in float arrays that are later reinterpreted as (i.e., bitcast to) integers. This causes problems for very small floats. For example, the int32s 1 and 2 are, bitcast to float, 1.401298e-45 and 2.802597e-45. Under SSE, these numbers would in fact compare as identical if denormals-as-zero behavior is configured. (And would thus, before this change, make the "convert an array of identical constants to a broadcast" logic kick in, in XLA, thus changing the array.) Hence, don't compare floating point numbers using FP logic, but just compare their bit representation. PiperOrigin-RevId: 678912465 --- third_party/xla/xla/client/xla_builder_test.cc | 17 +++++++++++++++++ third_party/xla/xla/literal.cc | 2 +- third_party/xla/xla/literal_test.cc | 18 ++++++++++++++++++ 3 files changed, 36 insertions(+), 1 deletion(-) diff --git a/third_party/xla/xla/client/xla_builder_test.cc b/third_party/xla/xla/client/xla_builder_test.cc index 8ecf2434fc1d3f..5aa2ef9fb13c19 100644 --- a/third_party/xla/xla/client/xla_builder_test.cc +++ b/third_party/xla/xla/client/xla_builder_test.cc @@ -1560,6 +1560,23 @@ TEST(XlaBuilderTest, CheckBufferDonor) { EXPECT_FALSE(config.ParameterIsBufferDonor(1, {})); } +TEST(XlaBuilderTest, ConstantLiteral) { + XlaBuilder b(TestName()); +#if defined(__x86_64__) && defined(_MM_DENORMALS_ZERO_ON) + int old_csr = _mm_getcsr(); + // Treat denormals as zero. This will make the small number below equal to + // 0.0, as far as the FP unit is concerned. + _mm_setcsr(old_csr | _MM_DENORMALS_ZERO_ON); +#endif + ConstantR1(&b, {0.0f, 1.401298e-45f}); +#if defined(__x86_64__) && defined(_MM_DENORMALS_ZERO_ON) + _mm_setcsr(old_csr); +#endif + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); + const HloInstruction* root = GetRoot(*module); + ASSERT_THAT(root, GmockMatch(m::Constant())); +} + TEST(XlaBuilderTest, InvalidInputOutputAliasBufferDonor) { XlaBuilder b(TestName()); diff --git a/third_party/xla/xla/literal.cc b/third_party/xla/xla/literal.cc index c1026718435087..971b1d48ac563b 100644 --- a/third_party/xla/xla/literal.cc +++ b/third_party/xla/xla/literal.cc @@ -1950,7 +1950,7 @@ template static bool AllElementsEqualValue(absl::Span data, NativeT value) { for (int64_t i = 0; i < data.size(); ++i) { - if (!EqualIncludingNan(data[i], value)) { + if (memcmp(&data[i], &value, sizeof value)) { return false; } } diff --git a/third_party/xla/xla/literal_test.cc b/third_party/xla/xla/literal_test.cc index dd9c1df6a3eb24..767fe581121db3 100644 --- a/third_party/xla/xla/literal_test.cc +++ b/third_party/xla/xla/literal_test.cc @@ -744,6 +744,24 @@ TEST_F(LiteralUtilTest, IsAllFirst) { complex64 c7_9 = {7, 9}; EXPECT_TRUE(LiteralUtil::CreateR2({{c8_9}, {c8_9}}).IsAllFirst()); EXPECT_FALSE(LiteralUtil::CreateR2({{c7_9}, {c8_9}}).IsAllFirst()); + +#if defined(__x86_64__) && defined(_MM_DENORMALS_ZERO_ON) + int old_csr = _mm_getcsr(); + // Treat denormals as zero. This will make the small numbers below equal to + // 0.0, as far as the FP unit is concerned. + _mm_setcsr(old_csr | _MM_DENORMALS_ZERO_ON); +#endif + bool eq0 = LiteralUtil::CreateR1({0.0, 1.401298e-45}).IsAllFirst(); + bool eq1 = LiteralUtil::CreateR1({0.0, 2.802597e-45}).IsAllFirst(); + bool eq2 = + LiteralUtil::CreateR1({4.203895e-45, 7.006492e-45}).IsAllFirst(); +#if defined(__x86_64__) && defined(_MM_DENORMALS_ZERO_ON) + _mm_setcsr(old_csr); +#endif + + EXPECT_FALSE(eq0); + EXPECT_FALSE(eq1); + EXPECT_FALSE(eq2); } TEST_F(LiteralUtilTest, CountEqualInt) { From eecc6f2097f84e6d8c51011aacbbc54c8d133338 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 25 Sep 2024 18:20:22 -0700 Subject: [PATCH 290/483] Drop shard_as/shard_like unknown shardings as well when invoking ShardingPropagation::ProcessShardingInstruction from auto-sharding as auto-sharding currently does not support these sharding annotations. This requires us to add an additional default parameter to ShardingPropagation::ProcessShardingInstruction to control this behavior. Also performed some cleanup along the way. PiperOrigin-RevId: 678922869 --- .../auto_sharding/auto_sharding.cc | 86 ++++++----- .../auto_sharding/auto_sharding.h | 11 +- .../auto_sharding/auto_sharding_test.cc | 143 ++++++++++++------ .../xla/xla/service/sharding_propagation.cc | 5 +- .../xla/xla/service/sharding_propagation.h | 3 +- 5 files changed, 162 insertions(+), 86 deletions(-) diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc index 41713ce3bd5314..4841c6a5c76e02 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc @@ -2573,32 +2573,39 @@ std::string PrintSolutionMemoryUsage(const LivenessSet& liveness_set, return str; } -void SaveShardingForInstruction( +absl::Status SaveShardingForInstruction( const HloInstruction* inst, bool save_for_copy_users, absl::flat_hash_map>& preserve_shardings) { - auto save_sharding = [&preserve_shardings](const HloInstruction* inst) { + auto save_sharding = + [&preserve_shardings](const HloInstruction* inst) -> absl::Status { if (!inst->has_sharding()) { - return; + return absl::OkStatus(); + } + if (inst->sharding().IsUnknown()) { + return absl::UnimplementedError( + "Auto-sharding currently does not support shard_as/shard_like " + "sharding annotations"); } if (!inst->sharding().IsTuple()) { preserve_shardings[inst->name()] = {inst->sharding()}; } else { preserve_shardings[inst->name()] = inst->sharding().tuple_elements(); } + return absl::OkStatus(); }; - save_sharding(inst); + TF_RETURN_IF_ERROR(save_sharding(inst)); + // Also preserve the shardings of copy users of theinstruction. if (save_for_copy_users) { for (const auto user : inst->users()) { - // Also preserve the shardings of copy ops that are the users of those - // instructions. if (user->opcode() == HloOpcode::kCopy) { - save_sharding(user); + TF_RETURN_IF_ERROR(save_sharding(user)); } } } + return absl::OkStatus(); } // Check whether the shardings that need to be preserved are preserved. @@ -3272,13 +3279,14 @@ bool HasReduceScatterOpportunity(const HloInstruction* inst, } // namespace spmd -std::pair>, bool> +absl::StatusOr AutoShardingImplementation::SaveAndRemoveShardingAnnotation( HloModule* module, const absl::flat_hash_set& instructions_to_shard, const absl::flat_hash_set& replicated_small_tensors, const absl::flat_hash_set& execution_threads) { - absl::flat_hash_map> preserve_shardings; + absl::flat_hash_map> + preserved_shardings; absl::flat_hash_set keep_inst; for (const HloComputation* computation : @@ -3289,16 +3297,16 @@ AutoShardingImplementation::SaveAndRemoveShardingAnnotation( inst->opcode() == HloOpcode::kRecvDone || inst->opcode() == HloOpcode::kSend || inst->opcode() == HloOpcode::kSendDone) { - spmd::SaveShardingForInstruction(inst, - /* save_for_copy_users */ false, - preserve_shardings); + TF_RETURN_IF_ERROR(spmd::SaveShardingForInstruction( + inst, + /*save_for_copy_users=*/false, preserved_shardings)); continue; } if (spmd::IsInstructionBeforeSPMDFullToShardShapeCustomCall(inst) || spmd::IsSPMDShardToFullShapeCustomCall(inst)) { - spmd::SaveShardingForInstruction(inst, - /* save_for_copy_users */ false, - preserve_shardings); + TF_RETURN_IF_ERROR(spmd::SaveShardingForInstruction( + inst, + /*save_for_copy_users=*/false, preserved_shardings)); } if (inst->has_sharding() && spmd::IsShardingMisaligned(inst->sharding(), inst->shape()) && @@ -3318,12 +3326,12 @@ AutoShardingImplementation::SaveAndRemoveShardingAnnotation( for (const HloComputation* computation : module->computations(execution_threads)) { for (const auto inst : computation->instructions()) { - spmd::SaveShardingForInstruction(inst, - /* save_for_copy_users */ true, - preserve_shardings); + TF_RETURN_IF_ERROR(spmd::SaveShardingForInstruction( + inst, + /*save_for_copy_users=*/true, preserved_shardings)); } } - return std::make_pair(preserve_shardings, /* module_is_changed */ false); + return SaveShardingAnnotationsResult{preserved_shardings, false}; } bool module_is_changed = false; @@ -3335,23 +3343,23 @@ AutoShardingImplementation::SaveAndRemoveShardingAnnotation( // they are small tensors if (replicated_small_tensors.count(ins->name())) { keep_inst.insert(ins); - spmd::SaveShardingForInstruction(ins, - /* save_for_copy_users */ false, - preserve_shardings); + TF_RETURN_IF_ERROR(spmd::SaveShardingForInstruction( + ins, + /*save_for_copy_users=*/false, preserved_shardings)); continue; } // Do not remove entry computation's parameter and root instruction's - // sharding if preserve_shardings is kKeepInputOutputShardings. + // sharding if preserved_shardings is kKeepInputOutputShardings. if (option_.preserve_shardings == AutoShardingOption::PreserveShardingsType:: kKeepInputOutputShardings && is_entry_computation && (ins->opcode() == HloOpcode::kParameter || ins->IsRoot())) { keep_inst.insert(ins); - spmd::SaveShardingForInstruction( + TF_RETURN_IF_ERROR(spmd::SaveShardingForInstruction( ins, - /* save_for_copy_users */ ins->opcode() == HloOpcode::kParameter, - preserve_shardings); + /*save_for_copy_users=*/ins->opcode() == HloOpcode::kParameter, + preserved_shardings)); continue; } @@ -3375,7 +3383,7 @@ AutoShardingImplementation::SaveAndRemoveShardingAnnotation( } } } - return std::make_pair(preserve_shardings, module_is_changed); + return SaveShardingAnnotationsResult{preserved_shardings, module_is_changed}; } absl::Status AutoShardingImplementation::CanonicalizeLayouts( @@ -3527,7 +3535,14 @@ absl::StatusOr AutoShardingImplementation::RunAutoSharding( ProcessShardingInstruction( module, execution_threads, /*replace_sharding_with_copy=*/true, &unspecified_dims, /*saved_root_shardings=*/nullptr, - /*saved_parameter_shardings=*/nullptr)); + /*saved_parameter_shardings=*/nullptr, + /*instruction_to_shard_group_id=*/nullptr, + /*shard_group_id_to_shard_as_group=*/nullptr, + /*shard_group_id_to_shard_like_group=*/nullptr, + /*allow_spmd_sharding_propagation_to_parameters_vector=*/nullptr, + /*remove_unknown_shardings=*/true)); + + DumpHloModuleIfEnabled(*module, "after_spmd_calls"); if (changed) { module_is_changed = true; VLOG(3) << "CustomCalls with custom_call_target=Sharding are removed and " @@ -3584,13 +3599,13 @@ absl::StatusOr AutoShardingImplementation::RunAutoSharding( const absl::flat_hash_set& instructions_to_shard = ComputeInstructionsToShard(*module, sequence); - std::pair>, bool> - preserve_shardings_result = SaveAndRemoveShardingAnnotation( - module, instructions_to_shard, replicated_small_tensors, - execution_threads); + TF_ASSIGN_OR_RETURN(SaveShardingAnnotationsResult saved_sharding_result, + SaveAndRemoveShardingAnnotation( + module, instructions_to_shard, + replicated_small_tensors, execution_threads)); absl::flat_hash_map> - preserve_shardings = std::move(preserve_shardings_result.first); - module_is_changed |= preserve_shardings_result.second; + preserve_shardings = std::move(saved_sharding_result.preserved_shardings); + module_is_changed |= saved_sharding_result.module_is_changed; absl::flat_hash_map instruction_execution_counts = spmd::ComputeInstructionExecutionCounts( @@ -3867,7 +3882,8 @@ absl::StatusOr AutoShardingImplementation::RunAutoSharding( CHECK(instruction->has_sharding()); CHECK(!instruction->sharding().IsManual()); CHECK(instruction->operand(0)->has_sharding()); - CHECK(instruction->operand(0)->sharding().IsManual()); + CHECK(instruction->operand(0)->sharding().IsManual()) + << instruction->ToString(); } } diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.h b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.h index 1a1b67c83873ff..e749fb8682532d 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.h +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.h @@ -68,11 +68,16 @@ class AutoShardingImplementation { const absl::flat_hash_map& sharding_propagation_solution = {}); + struct SaveShardingAnnotationsResult { + absl::flat_hash_map> + preserved_shardings; + bool module_is_changed; + }; + // Returns sharding annotations that need to be preserved in a map (for // verification after auto-sharding is done), and removes any sharding - // anotations that need to be removed. - std::pair>, bool> - SaveAndRemoveShardingAnnotation( + // annotations that need to be removed. + absl::StatusOr SaveAndRemoveShardingAnnotation( HloModule* module, const absl::flat_hash_set& instructions_to_shard, const absl::flat_hash_set& replicated_small_tensors, diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_test.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_test.cc index 2b0c2aec59e6f9..55d53e34619a82 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_test.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_test.cc @@ -1473,16 +1473,17 @@ ENTRY %entry (param0: f32[4,256,64], param1: f32[4,256,32]) -> f32[64,32] { absl::flat_hash_set instructions_to_shard( module->entry_computation()->instructions().begin(), module->entry_computation()->instructions().end()); - std::pair>, bool> - saved_shardings_result = - AutoShardingImplementation(option).SaveAndRemoveShardingAnnotation( - module.get(), instructions_to_shard, - /* replicated_small_tensors */ {}, - /* execution_threads */ {}); + + TF_ASSERT_OK_AND_ASSIGN( + AutoShardingImplementation::SaveShardingAnnotationsResult + saved_shardings_result, + AutoShardingImplementation(option).SaveAndRemoveShardingAnnotation( + module.get(), instructions_to_shard, + /* replicated_small_tensors */ {}, + /* execution_threads */ {})); absl::flat_hash_map> saved_shardings = - saved_shardings_result.first; - bool changed = saved_shardings_result.second; - EXPECT_FALSE(changed); + saved_shardings_result.preserved_shardings; + EXPECT_FALSE(saved_shardings_result.module_is_changed); std::vector instructions = module->entry_computation()->MakeInstructionPostOrder(); EXPECT_THAT(instructions, @@ -1531,16 +1532,16 @@ ENTRY %entry (param0: f32[4,256,64], param1: f32[4,256,32]) -> f32[64,32] { absl::flat_hash_set instructions_to_shard( module->entry_computation()->instructions().begin(), module->entry_computation()->instructions().end()); - std::pair>, bool> - saved_shardings_result = - AutoShardingImplementation(option).SaveAndRemoveShardingAnnotation( - module.get(), instructions_to_shard, - /* replicated_small_tensors */ {"dot"}, - /* execution_threads */ {}); + TF_ASSERT_OK_AND_ASSIGN( + AutoShardingImplementation::SaveShardingAnnotationsResult + saved_shardings_result, + AutoShardingImplementation(option).SaveAndRemoveShardingAnnotation( + module.get(), instructions_to_shard, + /* replicated_small_tensors */ {"dot"}, + /* execution_threads */ {})); absl::flat_hash_map> saved_shardings = - saved_shardings_result.first; - bool changed = saved_shardings_result.second; - EXPECT_FALSE(changed); + saved_shardings_result.preserved_shardings; + EXPECT_FALSE(saved_shardings_result.module_is_changed); std::vector instructions = module->entry_computation()->MakeInstructionPostOrder(); EXPECT_THAT(instructions, @@ -1586,16 +1587,17 @@ ENTRY %entry (param0: f32[4,256,64], param1: f32[4,256,32]) -> f32[64,32] { absl::flat_hash_set instructions_to_shard( module->entry_computation()->instructions().begin(), module->entry_computation()->instructions().end()); - std::pair>, bool> - saved_shardings_result = - AutoShardingImplementation(option).SaveAndRemoveShardingAnnotation( - module.get(), instructions_to_shard, - /* replicated_small_tensors */ {}, - /* execution_threads */ {}); + + TF_ASSERT_OK_AND_ASSIGN( + AutoShardingImplementation::SaveShardingAnnotationsResult + saved_shardings_result, + AutoShardingImplementation(option).SaveAndRemoveShardingAnnotation( + module.get(), instructions_to_shard, + /* replicated_small_tensors */ {}, + /* execution_threads */ {})); absl::flat_hash_map> saved_shardings = - saved_shardings_result.first; - bool changed = saved_shardings_result.second; - EXPECT_TRUE(changed); + saved_shardings_result.preserved_shardings; + EXPECT_TRUE(saved_shardings_result.module_is_changed); // Dot does not have shardings anymore. const HloInstruction* dot = FindInstruction(module.get(), "dot"); @@ -1670,16 +1672,17 @@ ENTRY %entry (param0: f32[4,256,64], param1: f32[4,256,32]) -> f32[64,32] { absl::flat_hash_set instructions_to_shard( module->entry_computation()->instructions().begin(), module->entry_computation()->instructions().end()); - std::pair>, bool> - saved_shardings_result = - AutoShardingImplementation(option).SaveAndRemoveShardingAnnotation( - module.get(), instructions_to_shard, - /* replicated_small_tensors */ {}, - /* execution_threads */ {}); + + TF_ASSERT_OK_AND_ASSIGN( + AutoShardingImplementation::SaveShardingAnnotationsResult + saved_shardings_result, + AutoShardingImplementation(option).SaveAndRemoveShardingAnnotation( + module.get(), instructions_to_shard, + /* replicated_small_tensors */ {}, + /* execution_threads */ {})); absl::flat_hash_map> saved_shardings = - saved_shardings_result.first; - bool changed = saved_shardings_result.second; - EXPECT_TRUE(changed); + saved_shardings_result.preserved_shardings; + EXPECT_TRUE(saved_shardings_result.module_is_changed); EXPECT_THAT(saved_shardings, IsEmpty()); std::vector instructions = module->entry_computation()->MakeInstructionPostOrder(); @@ -1708,16 +1711,16 @@ ENTRY %entry (param0: f32[4,256,64], param1: f32[4,256,32]) -> f32[64,32] { absl::flat_hash_set instructions_to_shard( module->entry_computation()->instructions().begin(), module->entry_computation()->instructions().end()); - std::pair>, bool> - saved_shardings_result = - AutoShardingImplementation(option).SaveAndRemoveShardingAnnotation( - module.get(), instructions_to_shard, - /* replicated_small_tensors */ {"dot", "copy"}, - /* execution_threads */ {}); + TF_ASSERT_OK_AND_ASSIGN( + AutoShardingImplementation::SaveShardingAnnotationsResult + saved_shardings_result, + AutoShardingImplementation(option).SaveAndRemoveShardingAnnotation( + module.get(), instructions_to_shard, + /* replicated_small_tensors */ {"dot", "copy"}, + /* execution_threads */ {})); absl::flat_hash_map> saved_shardings = - saved_shardings_result.first; - bool changed = saved_shardings_result.second; - EXPECT_TRUE(changed); + saved_shardings_result.preserved_shardings; + EXPECT_TRUE(saved_shardings_result.module_is_changed); // params have no shardings. const HloInstruction* param0 = FindInstruction(module.get(), "param0"); @@ -2807,6 +2810,56 @@ ENTRY %entry { EXPECT_THAT(slice1, op::Sharding("{replicated}")); } +TEST_F(AutoShardingTest, CrashIfAskedToRespectShardAsShardLike) { + const char* const kHloString = R"( +HloModule module +ENTRY matmul { + param1 = f32[32,64]{1,0} parameter(0) + param2 = f32[64,128]{1,0} parameter(1) + custom-call1 = f32[32,64]{1,0} custom-call(param1), custom_call_target="Sharding", custom_call_has_side_effect=true, sharding={unknown shard_as 0} + custom-call2 = f32[64,128]{1,0} custom-call(param2), custom_call_target="Sharding", custom_call_has_side_effect=true, sharding={unknown shard_as 0} + ROOT root = f32[32,128]{1,0} dot(custom-call1, custom-call2), lhs_contracting_dims={1}, rhs_contracting_dims={0} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kHloString)); + AutoShardingOption option; + option.preserve_shardings = + AutoShardingOption::PreserveShardingsType::kKeepAllShardings; + option.enable = true; + option.device_mesh_shape = {4, 1}; + option.device_mesh_alpha = {1.0, 1.0}; + option.device_mesh_beta = {0.01, 1.0}; + // TODO(b/369616683) Fix the error message output in this case. + EXPECT_DEATH( + absl::StatusOr status = AutoSharding(option).Run(module.get()), + "The auto-sharding solver has timed out without a solution."); +} + +TEST_F(AutoShardingTest, IgnoreShardAsShardLike) { + const char* const kHloString = R"( +HloModule module +ENTRY matmul { + param1 = f32[32,64]{1,0} parameter(0) + param2 = f32[64,128]{1,0} parameter(1) + custom-call1 = f32[32,64]{1,0} custom-call(param1), custom_call_target="Sharding", custom_call_has_side_effect=true, sharding={unknown shard_as 0} + custom-call2 = f32[64,128]{1,0} custom-call(param2), custom_call_target="Sharding", custom_call_has_side_effect=true, sharding={unknown shard_as 0} + ROOT root = f32[32,128]{1,0} dot(custom-call1, custom-call2), lhs_contracting_dims={1}, rhs_contracting_dims={0} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kHloString)); + AutoShardingOption option; + option.preserve_shardings = + AutoShardingOption::PreserveShardingsType::kRemoveAllShardings; + option.enable = true; + option.device_mesh_shape = {4, 1}; + option.device_mesh_alpha = {1.0, 1.0}; + option.device_mesh_beta = {0.01, 1.0}; + TF_ASSERT_OK_AND_ASSIGN(bool changed, AutoSharding(option).Run(module.get())); + EXPECT_TRUE(changed); +} + TEST(NormalizeTest, NormalizeHandlesNegativeCosts) { EdgeReshardingCostMatrix edge_cost(2, 2); edge_cost(0, 0).communication_cost = -100; diff --git a/third_party/xla/xla/service/sharding_propagation.cc b/third_party/xla/xla/service/sharding_propagation.cc index 3ffd85687793e5..bf44306a1cccd0 100644 --- a/third_party/xla/xla/service/sharding_propagation.cc +++ b/third_party/xla/xla/service/sharding_propagation.cc @@ -1383,7 +1383,8 @@ absl::StatusOr ProcessShardingInstruction( absl::flat_hash_map>* shard_group_id_to_shard_like_group, const std::vector* - allow_spmd_sharding_propagation_to_parameters_vector) { + allow_spmd_sharding_propagation_to_parameters_vector, + bool remove_unknown_shardings) { bool changed = false; const bool use_shard_group = instruction_to_shard_group_id && @@ -1475,7 +1476,7 @@ absl::StatusOr ProcessShardingInstruction( bool replaced_with_copy = replace_sharding_with_copy && - (!original_sharding.IsUnknown() || + (!original_sharding.IsUnknown() || remove_unknown_shardings || instruction->operand(0)->opcode() == HloOpcode::kParameter); // Replace the sharding instruction with a copy node so that it does not // need special handling. diff --git a/third_party/xla/xla/service/sharding_propagation.h b/third_party/xla/xla/service/sharding_propagation.h index 27cef820977436..c3dc49ec4daa34 100644 --- a/third_party/xla/xla/service/sharding_propagation.h +++ b/third_party/xla/xla/service/sharding_propagation.h @@ -72,7 +72,8 @@ absl::StatusOr ProcessShardingInstruction( absl::flat_hash_map>* shard_group_id_to_shard_like_group = nullptr, const std::vector* - allow_spmd_sharding_propagation_to_parameters_vector = nullptr); + allow_spmd_sharding_propagation_to_parameters_vector = nullptr, + bool remove_unknown_shardings = false); int64_t ComputeNonRootUsers(const HloInstruction* instr); From 7b650dbb1d5f0a0f171984fa9e130cea7fb83057 Mon Sep 17 00:00:00 2001 From: Yifan Jiang Date: Wed, 25 Sep 2024 18:33:17 -0700 Subject: [PATCH 291/483] Add a missing log to log error if XLA_PJRT_GPU_ALLOW_DELETE_BEFORE_FULFILL fails to read. PiperOrigin-RevId: 678926392 --- third_party/xla/xla/pjrt/BUILD | 1 + third_party/xla/xla/pjrt/local_device_state.cc | 5 +++++ 2 files changed, 6 insertions(+) diff --git a/third_party/xla/xla/pjrt/BUILD b/third_party/xla/xla/pjrt/BUILD index d7c70c22f9454f..8c1bacdd6e6be7 100644 --- a/third_party/xla/xla/pjrt/BUILD +++ b/third_party/xla/xla/pjrt/BUILD @@ -157,6 +157,7 @@ cc_library( "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/profiler/lib:traceme", "@local_tsl//tsl/protobuf:error_codes_proto_impl_cc", diff --git a/third_party/xla/xla/pjrt/local_device_state.cc b/third_party/xla/xla/pjrt/local_device_state.cc index dd037deade6a0c..4f6caed0e5ac47 100644 --- a/third_party/xla/xla/pjrt/local_device_state.cc +++ b/third_party/xla/xla/pjrt/local_device_state.cc @@ -30,6 +30,7 @@ limitations under the License. #include "xla/tsl/util/env_var.h" #include "xla/util.h" #include "tsl/platform/errors.h" +#include "tsl/platform/logging.h" #include "tsl/platform/statusor.h" #include "tsl/profiler/lib/traceme.h" #include "tsl/protobuf/error_codes.pb.h" @@ -61,6 +62,10 @@ LocalDeviceState::LocalDeviceState(se::StreamExecutor* executor, absl::Status status = tsl::ReadBoolFromEnvVar("XLA_PJRT_GPU_ALLOW_DELETE_BEFORE_FULFILL", true, &allow_delete_before_fulfill_); + if (!status.ok()) { + LOG(ERROR) << "Failed to read XLA_PJRT_GPU_ALLOW_DELETE_BEFORE_FULFILL: " + << status; + } local_hardware_id_ = executor_->device_ordinal(); local_device_id_ = From 96c8197bcf8a8e48187a609c946849e6b276a8fb Mon Sep 17 00:00:00 2001 From: Zixuan Jiang Date: Wed, 25 Sep 2024 19:07:47 -0700 Subject: [PATCH 292/483] [XLA:SPMD] Use stable sort to fix a flaky test. PiperOrigin-RevId: 678935741 --- .../xla/xla/service/spmd/spmd_partitioner.cc | 10 ++++------ .../xla/xla/service/spmd/spmd_partitioner_util.cc | 14 +++++++------- 2 files changed, 11 insertions(+), 13 deletions(-) diff --git a/third_party/xla/xla/service/spmd/spmd_partitioner.cc b/third_party/xla/xla/service/spmd/spmd_partitioner.cc index ceb9553f9ee095..a301c2a3c0d9af 100644 --- a/third_party/xla/xla/service/spmd/spmd_partitioner.cc +++ b/third_party/xla/xla/service/spmd/spmd_partitioner.cc @@ -1143,8 +1143,6 @@ PartitionedHlo::ReshardAsWindowedInput(const Window& window, std::vector(halo_exchange_base_shape.rank(), 1))); } - std::vector left_halo_size_functions(base_shape_.rank()); - std::vector right_halo_size_functions(base_shape_.rank()); // TODO(yuanzx): We are concatenating on each sharded dimension one at time, // and in the second dimension (and beyond) we create halos by slicing the // concat in the previous dimension, which is not optimal. We should generate @@ -1162,18 +1160,18 @@ PartitionedHlo::ReshardAsWindowedInput(const Window& window, // partition. MultiplyAddDivideOffsetCalculation shard_limit_of_previous_on_padded( input_shard_size, explicit_left_padding[dim], 1); - left_halo_size_functions[dim] = + OffsetCalculation left_halo_size_functions = shard_limit_of_previous_on_padded - start_on_padded_calculations[dim]; // Right halo. MultiplyAddDivideOffsetCalculation shard_start_of_next_on_padded( input_shard_size, input_shard_size + explicit_left_padding[dim], 1); - right_halo_size_functions[dim] = + OffsetCalculation right_halo_size_functions = limit_on_padded_calculations[dim] - shard_start_of_next_on_padded; auto resharded = ExchangeHaloAndGetValidData( - visiting_hlo, halo_exchange_base_shape, left_halo_size_functions[dim], - right_halo_size_functions[dim], explicit_left_padding[dim], + visiting_hlo, halo_exchange_base_shape, left_halo_size_functions, + right_halo_size_functions, explicit_left_padding[dim], padded_shape.dimensions(dim), shard_shape.dimensions(dim), dim, *halo_exchange_target, offsets_on_padded_shape[dim], pad_value, partition_ordinals[dim], state_.collective_ops_creator, diff --git a/third_party/xla/xla/service/spmd/spmd_partitioner_util.cc b/third_party/xla/xla/service/spmd/spmd_partitioner_util.cc index 5034d168589bf9..ac1d272bac56f0 100644 --- a/third_party/xla/xla/service/spmd/spmd_partitioner_util.cc +++ b/third_party/xla/xla/service/spmd/spmd_partitioner_util.cc @@ -956,8 +956,7 @@ HloInstruction* ExchangeHaloCompact( (i + 1) * input_shard_size + right_halo_size_function.Calculate(i); max_window_size = std::max(max_window_size, limit - start); while (next_start < limit) { - halos[i].emplace_back(); - Halo& halo = halos[i].back(); + Halo& halo = halos[i].emplace_back(); halo.my_index = i; halo.halo_offset = next_start - start; halo.start = next_start % input_shard_size; @@ -1038,11 +1037,12 @@ HloInstruction* ExchangeHaloCompact( // Sort halos that are from the same src according to halo_offset, so that // they are more likely to have similar characteristics. for (int64_t i = 0; i < src_to_dst.size(); ++i) { - absl::c_sort(src_to_dst[i], [&](const std::pair& a, - const std::pair& b) { - return halos[a.first][a.second].halo_offset < - halos[b.first][b.second].halo_offset; - }); + absl::c_stable_sort(src_to_dst[i], + [&](const std::pair& a, + const std::pair& b) { + return halos[a.first][a.second].halo_offset < + halos[b.first][b.second].halo_offset; + }); } // Build collective permutes with distinct src/dst values. From 3997b3b5d4e5f0babb6e1997c7f327dda364df7e Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 25 Sep 2024 19:57:23 -0700 Subject: [PATCH 293/483] Add logic to track definitions across calls in the scheduler. PiperOrigin-RevId: 678948517 --- .../xla/service/latency_hiding_scheduler.cc | 66 ++++++++++++++++++- 1 file changed, 63 insertions(+), 3 deletions(-) diff --git a/third_party/xla/xla/service/latency_hiding_scheduler.cc b/third_party/xla/xla/service/latency_hiding_scheduler.cc index f0e5af7ac3c9a3..c4e812e7e5aba0 100644 --- a/third_party/xla/xla/service/latency_hiding_scheduler.cc +++ b/third_party/xla/xla/service/latency_hiding_scheduler.cc @@ -73,12 +73,72 @@ bool IsNopInstruction(const HloInstruction& hlo) { (op == HloOpcode::kTuple && hlo.user_count() == 1 && hlo.users().front()->opcode() == HloOpcode::kWhile); } + +bool InstructionDefinesValue(const HloInstruction* instruction, + const HloValue* value) { + if (value->defining_instruction() == instruction) { + return true; + } + if (value->shape().has_layout() && + value->shape().layout().memory_space() != kDefaultMemorySpace) { + return false; + } + // Also check if the instruction is a call to a computation that defines the + // value. This is needed in cases, e.g., where we wrap a value-defining + // instruction in a async call for offloading, and the async call itself will + // effectively define the value in the current scope that the scheduler is + // running in. + if (instruction->opcode() == HloOpcode::kAsyncStart || + instruction->opcode() == HloOpcode::kAsyncDone) { + if (instruction->async_wrapped_opcode() == HloOpcode::kCall) { + return instruction->async_wrapped_instruction() + ->called_computations()[0] + ->root_instruction() == value->defining_instruction(); + } + return instruction->async_wrapped_instruction() == + value->defining_instruction(); + } + return false; +} + +bool InstructionFirstDefinesBuffer( + const HloInstruction* instruction, + const BufferInfoTracker::ValueInfo& buffer_value_info) { + if (buffer_value_info.first_definition == instruction) { + return true; + } + if (buffer_value_info.value->values()[0]->shape().has_layout() && + buffer_value_info.value->values()[0]->shape().layout().memory_space() != + kDefaultMemorySpace) { + return false; + } + // Similar to logic above, also check if the instruction is a call to a + // computation that defines the value. + if (instruction->opcode() == HloOpcode::kAsyncStart || + instruction->opcode() == HloOpcode::kAsyncDone) { + if (instruction->async_wrapped_opcode() == HloOpcode::kCall) { + return instruction->async_wrapped_instruction() + ->called_computations()[0] + ->root_instruction() == buffer_value_info.first_definition; + } + return instruction->async_wrapped_instruction() == + buffer_value_info.first_definition; + } + return false; +} + } // namespace CanonicalAsyncOp DefaultGetCanonicalAsyncOp(const HloInstruction& hlo) { switch (hlo.opcode()) { case HloOpcode::kAsyncStart: case HloOpcode::kAsyncDone: + if (hlo.async_wrapped_opcode() == HloOpcode::kCall) { + return {hlo.opcode(), hlo.async_wrapped_instruction() + ->called_computations()[0] + ->root_instruction() + ->opcode()}; + } return {hlo.opcode(), hlo.async_wrapped_opcode()}; case HloOpcode::kAllReduceStart: return {HloOpcode::kAsyncStart, HloOpcode::kAllReduce}; @@ -596,7 +656,7 @@ void MemoryPressureTracker::Initialize( output_values.push_back(std::make_pair( buffer_tracker_.GetBufferInfo(buffer->id()), index)); if (absl::c_any_of(buffer->values(), [&](const HloValue* value) { - return value->defining_instruction() == instruction; + return InstructionDefinesValue(instruction, value); })) { defined_values.push_back( buffer_tracker_.GetBufferInfo(buffer->id())); @@ -663,7 +723,7 @@ void MemoryPressureTracker::UpdateBuffers(const HloInstruction* instruction) { continue; } if (live_buffers_[b.value->id()] != 0) { - if (b.first_definition == instruction) { + if (InstructionFirstDefinesBuffer(instruction, b)) { live_memory_usage_ -= b.buffer_size; live_buffers_set_.erase(b.value->id()); } @@ -721,7 +781,7 @@ std::pair MemoryPressureTracker::MemoryPressureDifference( continue; } if (live_buffers_[b.value->id()]) { - if (b.first_definition == instruction) { + if (InstructionFirstDefinesBuffer(instruction, b)) { increase -= b.buffer_size; } } From 650dac3697862848aa5471221745d80a713a2c31 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 25 Sep 2024 21:13:55 -0700 Subject: [PATCH 294/483] Load the builtin Bazel java rules from @rules_java PiperOrigin-RevId: 678972616 --- tensorflow/java/BUILD | 1 + tensorflow/java/src/main/java/org/tensorflow/examples/BUILD | 2 ++ 2 files changed, 3 insertions(+) diff --git a/tensorflow/java/BUILD b/tensorflow/java/BUILD index c86ebca819cf24..a0bfea1d34accc 100644 --- a/tensorflow/java/BUILD +++ b/tensorflow/java/BUILD @@ -1,6 +1,7 @@ # Description: # TensorFlow Java API. +load("@rules_java//java:defs.bzl", "java_library", "java_plugin") load( "//tensorflow:tensorflow.bzl", "VERSION", diff --git a/tensorflow/java/src/main/java/org/tensorflow/examples/BUILD b/tensorflow/java/src/main/java/org/tensorflow/examples/BUILD index c8e8abbf1c4947..c1e7724652148b 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/examples/BUILD +++ b/tensorflow/java/src/main/java/org/tensorflow/examples/BUILD @@ -1,6 +1,8 @@ # Description: # TensorFlow Java examples. +load("@rules_java//java:defs.bzl", "java_binary") + package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = ["//visibility:private"], From a260bed13b5c28e0c9026c93d1fe2d2a030298ac Mon Sep 17 00:00:00 2001 From: Henning Becker Date: Wed, 25 Sep 2024 23:14:33 -0700 Subject: [PATCH 295/483] Introduce derived classes CudaKernel and RocmKernel This change makes `GpuKernel` an abstract base class and moves its implementation into the derived classes `CudaKernel` and `RocmKernel`. This avoids having two implementations for the same functions and also reduces the exposure of gpu_types.h which we want to get rid of. I'm also adding some basic tests for the new classes. PiperOrigin-RevId: 679003532 --- .../xla/xla/stream_executor/cuda/BUILD | 24 +++++++ .../xla/stream_executor/cuda/cuda_executor.cc | 3 +- .../xla/stream_executor/cuda/cuda_kernel.cc | 11 +-- .../xla/stream_executor/cuda/cuda_kernel.h | 69 +++++++++++++++++++ .../stream_executor/cuda/cuda_kernel_test.cc | 60 ++++++++++++++++ .../xla/xla/stream_executor/gpu/gpu_kernel.h | 41 +---------- .../xla/xla/stream_executor/rocm/BUILD | 29 +++++++- .../xla/stream_executor/rocm/rocm_executor.cc | 4 +- .../xla/stream_executor/rocm/rocm_kernel.cc | 15 ++-- .../xla/stream_executor/rocm/rocm_kernel.h | 69 +++++++++++++++++++ .../stream_executor/rocm/rocm_kernel_test.cc | 60 ++++++++++++++++ 11 files changed, 334 insertions(+), 51 deletions(-) create mode 100644 third_party/xla/xla/stream_executor/cuda/cuda_kernel.h create mode 100644 third_party/xla/xla/stream_executor/cuda/cuda_kernel_test.cc create mode 100644 third_party/xla/xla/stream_executor/rocm/rocm_kernel.h create mode 100644 third_party/xla/xla/stream_executor/rocm/rocm_kernel_test.cc diff --git a/third_party/xla/xla/stream_executor/cuda/BUILD b/third_party/xla/xla/stream_executor/cuda/BUILD index a94edde0814b72..529a0b6197935f 100644 --- a/third_party/xla/xla/stream_executor/cuda/BUILD +++ b/third_party/xla/xla/stream_executor/cuda/BUILD @@ -527,12 +527,36 @@ cuda_only_cc_library( cuda_only_cc_library( name = "cuda_kernel", srcs = ["cuda_kernel.cc"], + hdrs = ["cuda_kernel.h"], deps = [ "//xla/stream_executor", "//xla/stream_executor/gpu:gpu_driver_header", + "//xla/stream_executor/gpu:gpu_executor_header", "//xla/stream_executor/gpu:gpu_kernel_header", + "//xla/stream_executor/gpu:gpu_types_header", "@com_google_absl//absl/log", "@com_google_absl//absl/status:statusor", + "@local_tsl//tsl/platform:logging", + ], +) + +xla_test( + name = "cuda_kernel_test", + srcs = ["cuda_kernel_test.cc"], + backends = ["gpu_any"], + deps = [ + ":cuda_kernel", + ":cuda_runtime", + "//xla/stream_executor:launch_dim", + "//xla/stream_executor:platform", + "//xla/stream_executor:platform_manager", + "//xla/stream_executor/gpu:gpu_executor_header", + "//xla/stream_executor/gpu:gpu_test_kernels_cuda", + "@com_google_googletest//:gtest_main", + "@local_config_cuda//cuda:cuda_headers", + "@local_tsl//tsl/platform:status_matchers", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test", ], ) diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc b/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc index 98441a9bdc0549..5dc8838aa02a44 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc @@ -41,6 +41,7 @@ limitations under the License. #include "xla/stream_executor/command_buffer.h" #include "xla/stream_executor/cuda/cuda_collectives.h" #include "xla/stream_executor/cuda/cuda_event.h" +#include "xla/stream_executor/cuda/cuda_kernel.h" #include "xla/stream_executor/cuda/cuda_platform_id.h" #include "xla/stream_executor/cuda/cuda_runtime.h" #include "xla/stream_executor/cuda/cuda_status.h" @@ -191,7 +192,7 @@ absl::Status CudaExecutor::LoadModuleFromHsaco(const char* hsaco, absl::StatusOr> CudaExecutor::LoadKernel( const MultiKernelLoaderSpec& spec) { - auto cuda_kernel = std::make_unique(this); + auto cuda_kernel = std::make_unique(this); CUmodule module; const std::string* kernel_name; diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_kernel.cc b/third_party/xla/xla/stream_executor/cuda/cuda_kernel.cc index c62a8f99ae3298..33d460b64bd210 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_kernel.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_kernel.cc @@ -13,28 +13,29 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "xla/stream_executor/cuda/cuda_kernel.h" + #include #include #include "absl/log/log.h" #include "absl/status/statusor.h" #include "xla/stream_executor/gpu/gpu_driver.h" -#include "xla/stream_executor/gpu/gpu_kernel.h" #include "xla/stream_executor/launch_dim.h" namespace stream_executor { namespace gpu { -absl::StatusOr GpuKernel::GetMaxOccupiedBlocksPerCore( +absl::StatusOr CudaKernel::GetMaxOccupiedBlocksPerCore( ThreadDim threads, size_t dynamic_shared_memory_bytes) const { int32_t threads_per_block = threads.x * threads.y * threads.z; VLOG(3) << "Get kernel block occupancy: " << name() << "; threads_per_block: " << threads_per_block << "; dynamic_shared_memory_bytes: " << dynamic_shared_memory_bytes; - return GpuDriver::GetMaxOccupiedBlocksPerCore(gpu_context_, gpu_function_, - threads_per_block, - dynamic_shared_memory_bytes); + return GpuDriver::GetMaxOccupiedBlocksPerCore( + gpu_executor_->gpu_context(), gpu_function_, threads_per_block, + dynamic_shared_memory_bytes); } } // namespace gpu diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_kernel.h b/third_party/xla/xla/stream_executor/cuda/cuda_kernel.h new file mode 100644 index 00000000000000..be44bd940ebe48 --- /dev/null +++ b/third_party/xla/xla/stream_executor/cuda/cuda_kernel.h @@ -0,0 +1,69 @@ +/* Copyright 2019 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// The CUDA implementation of the StreamExecutor functionality. +// CUDA inclusions are ideally confined to this implementation file. +// +// The notions from the StreamExecutor basically correspond to the CUDA streams +// programming model provided by the libcuda.so driver APIs, so we don't have +// to do much more than wrap the calls to the libraries appropriately. +#ifndef XLA_STREAM_EXECUTOR_CUDA_CUDA_KERNEL_H_ +#define XLA_STREAM_EXECUTOR_CUDA_CUDA_KERNEL_H_ + +#include +#include + +#include "absl/status/statusor.h" +#include "xla/stream_executor/gpu/gpu_executor.h" +#include "xla/stream_executor/gpu/gpu_kernel.h" +#include "xla/stream_executor/gpu/gpu_types.h" +#include "xla/stream_executor/launch_dim.h" +#include "tsl/platform/logging.h" + +namespace stream_executor::gpu { + +class CudaKernel : public GpuKernel { + public: + explicit CudaKernel(GpuExecutor* gpu_executor) + : gpu_executor_(gpu_executor) {} + + // Note that the function is unloaded when the module is unloaded, and the + // module that the function is contained in is owned by the GpuExecutor. + ~CudaKernel() override { gpu_executor_->UnloadKernel(this); } + + // As arity cannot be reflected upon using the CUDA API, the arity is + // explicitly set during the GpuExecutor::GetKernel initialization process. + void set_arity(unsigned arity) { arity_ = arity; } + unsigned Arity() const override { return arity_; } + + absl::StatusOr GetMaxOccupiedBlocksPerCore( + ThreadDim threads, size_t dynamic_shared_memory_bytes) const override; + + // Simple accessor methods. + GpuFunctionHandle gpu_function() const override { return gpu_function_; } + void set_gpu_function(GpuFunctionHandle gpu_function) { + gpu_function_ = gpu_function; + } + + private: + GpuExecutor* gpu_executor_ = nullptr; + + CUfunction gpu_function_ = nullptr; // wrapped CUDA kernel handle + unsigned arity_ = 0; // number of formal parameters the kernel takes +}; + +} // namespace stream_executor::gpu + +#endif // XLA_STREAM_EXECUTOR_CUDA_CUDA_KERNEL_H_ diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_kernel_test.cc b/third_party/xla/xla/stream_executor/cuda/cuda_kernel_test.cc new file mode 100644 index 00000000000000..94620a640b3f4d --- /dev/null +++ b/third_party/xla/xla/stream_executor/cuda/cuda_kernel_test.cc @@ -0,0 +1,60 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/stream_executor/cuda/cuda_kernel.h" + +#include +#include "third_party/gpus/cuda/include/cuda.h" +#include "xla/stream_executor/cuda/cuda_runtime.h" +#include "xla/stream_executor/gpu/gpu_executor.h" +#include "xla/stream_executor/gpu/gpu_test_kernels.h" +#include "xla/stream_executor/launch_dim.h" +#include "xla/stream_executor/platform.h" +#include "xla/stream_executor/platform_manager.h" +#include "tsl/platform/status_matchers.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/test.h" + +namespace stream_executor::gpu { +namespace { +using testing::Ge; +using tsl::testing::IsOkAndHolds; + +TEST(CudaKernelTest, GetMaxOccupiedBlocksPerCore) { + TF_ASSERT_OK_AND_ASSIGN(Platform * platform, + PlatformManager::PlatformWithName("CUDA")); + TF_ASSERT_OK_AND_ASSIGN(StreamExecutor * executor, + platform->ExecutorForDevice(0)); + GpuExecutor* gpu_executor = ExtractGpuExecutor(executor); + + CudaKernel cuda_kernel(gpu_executor); + cuda_kernel.set_arity(3); + + TF_ASSERT_OK_AND_ASSIGN( + CUfunction function, + CudaRuntime::GetFuncBySymbol(internal::GetAddI32Kernel())); + + cuda_kernel.set_gpu_function(function); + + EXPECT_EQ(cuda_kernel.Arity(), 3); + EXPECT_EQ(cuda_kernel.gpu_function(), function); + + EXPECT_THAT(cuda_kernel.GetMaxOccupiedBlocksPerCore( + ThreadDim(1, 1, 1), /*dynamic_shared_memory_bytes=*/0), + IsOkAndHolds(Ge(1))); +} + +} // namespace +} // namespace stream_executor::gpu diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_kernel.h b/third_party/xla/xla/stream_executor/gpu/gpu_kernel.h index fb3439c507d1cb..c6714cce9def47 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_kernel.h +++ b/third_party/xla/xla/stream_executor/gpu/gpu_kernel.h @@ -22,51 +22,16 @@ limitations under the License. #ifndef XLA_STREAM_EXECUTOR_GPU_GPU_KERNEL_H_ #define XLA_STREAM_EXECUTOR_GPU_GPU_KERNEL_H_ -#include -#include -#include -#include - -#include "absl/status/statusor.h" -#include "xla/stream_executor/gpu/context.h" -#include "xla/stream_executor/gpu/gpu_executor.h" #include "xla/stream_executor/gpu/gpu_types.h" #include "xla/stream_executor/kernel.h" -#include "xla/stream_executor/launch_dim.h" -#include "tsl/platform/logging.h" namespace stream_executor::gpu { +// A GpuKernel is a `Kernel` that can be launched on a GPU. It allows +// access to the underlying GPU function through `gpu_function()`. class GpuKernel : public Kernel { public: - explicit GpuKernel(GpuExecutor* gpu_executor) - : gpu_executor_(gpu_executor), - gpu_context_(gpu_executor->gpu_context()) {} - - // Note that the function is unloaded when the module is unloaded, and the - // module that the function is contained in is owned by the GpuExecutor. - ~GpuKernel() override { gpu_executor_->UnloadKernel(this); } - - // As arity cannot be reflected upon using the CUDA API, the arity is - // explicitly set during the GpuExecutor::GetKernel initialization process. - void set_arity(unsigned arity) { arity_ = arity; } - unsigned Arity() const override { return arity_; } - - absl::StatusOr GetMaxOccupiedBlocksPerCore( - ThreadDim threads, size_t dynamic_shared_memory_bytes) const override; - - // Simple accessor methods. - GpuFunctionHandle gpu_function() const { return gpu_function_; } - void set_gpu_function(GpuFunctionHandle gpu_function) { - gpu_function_ = gpu_function; - } - - private: - GpuExecutor* gpu_executor_ = nullptr; - Context* gpu_context_ = nullptr; // context where kernel is loaded - - GpuFunctionHandle gpu_function_ = nullptr; // wrapped CUDA kernel handle - unsigned arity_ = 0; // number of formal parameters the kernel takes + virtual GpuFunctionHandle gpu_function() const = 0; }; inline const GpuKernel* AsGpuKernel(const Kernel* kernel) { diff --git a/third_party/xla/xla/stream_executor/rocm/BUILD b/third_party/xla/xla/stream_executor/rocm/BUILD index aa6a019eb25a26..8d6c8b565f4af7 100644 --- a/third_party/xla/xla/stream_executor/rocm/BUILD +++ b/third_party/xla/xla/stream_executor/rocm/BUILD @@ -12,6 +12,7 @@ load( "//xla/stream_executor:build_defs.bzl", "stream_executor_friends", ) +load("//xla/tests:build_defs.bzl", "xla_test") load( "//xla/tsl:tsl.bzl", "if_google", @@ -219,6 +220,7 @@ cc_library( cc_library( name = "rocm_kernel", srcs = ["rocm_kernel.cc"], + hdrs = ["rocm_kernel.h"], tags = [ "gpu", "rocm-only", @@ -228,10 +230,35 @@ cc_library( ]), visibility = ["//visibility:public"], deps = [ + "//xla/stream_executor:launch_dim", "//xla/stream_executor/gpu:gpu_driver_header", + "//xla/stream_executor/gpu:gpu_executor_header", "//xla/stream_executor/gpu:gpu_kernel_header", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status:statusor", + "@local_config_rocm//rocm:rocm_headers", + "@local_tsl//tsl/platform:logging", + ], +) + +xla_test( + name = "rocm_kernel_test", + srcs = ["rocm_kernel_test.cc"], + backends = ["gpu_amd_any"], + deps = [ + ":rocm_kernel", + ":rocm_runtime", + "//xla/stream_executor:launch_dim", + "//xla/stream_executor:platform", + "//xla/stream_executor:platform_manager", + "//xla/stream_executor/gpu:gpu_executor_header", + "//xla/stream_executor/gpu:gpu_test_kernels", + "@com_google_googletest//:gtest_main", + "@local_config_rocm//rocm:rocm_headers", + "@local_tsl//tsl/platform:status_matchers", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test", ], - alwayslink = True, ) cc_library( diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_executor.cc b/third_party/xla/xla/stream_executor/rocm/rocm_executor.cc index c7bca9c3851c3c..ac8f253a955ef4 100644 --- a/third_party/xla/xla/stream_executor/rocm/rocm_executor.cc +++ b/third_party/xla/xla/stream_executor/rocm/rocm_executor.cc @@ -36,6 +36,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" #include "absl/types/span.h" +#include "rocm/include/hip/hip_runtime.h" #include "rocm/include/hip/hip_version.h" #include "rocm/rocm_config.h" #include "xla/stream_executor/blas.h" @@ -71,6 +72,7 @@ limitations under the License. #include "xla/stream_executor/rocm/rocm_driver.h" #include "xla/stream_executor/rocm/rocm_driver_wrapper.h" #include "xla/stream_executor/rocm/rocm_event.h" +#include "xla/stream_executor/rocm/rocm_kernel.h" #include "xla/stream_executor/rocm/rocm_platform_id.h" #include "xla/stream_executor/rocm/rocm_runtime.h" #include "xla/stream_executor/rocm/rocm_version_parser.h" @@ -273,7 +275,7 @@ absl::Status RocmExecutor::Init() { absl::StatusOr> RocmExecutor::LoadKernel( const MultiKernelLoaderSpec& spec) { - auto rocm_kernel = std::make_unique(this); + auto rocm_kernel = std::make_unique(this); hipModule_t module = nullptr; const std::string* kernel_name; diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_kernel.cc b/third_party/xla/xla/stream_executor/rocm/rocm_kernel.cc index e8eb84a89092e1..49b65415e73724 100644 --- a/third_party/xla/xla/stream_executor/rocm/rocm_kernel.cc +++ b/third_party/xla/xla/stream_executor/rocm/rocm_kernel.cc @@ -13,24 +13,29 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "xla/stream_executor/rocm/rocm_kernel.h" + +#include #include +#include "absl/log/log.h" +#include "absl/status/statusor.h" #include "xla/stream_executor/gpu/gpu_driver.h" -#include "xla/stream_executor/gpu/gpu_kernel.h" +#include "xla/stream_executor/launch_dim.h" namespace stream_executor { namespace gpu { -absl::StatusOr GpuKernel::GetMaxOccupiedBlocksPerCore( +absl::StatusOr RocmKernel::GetMaxOccupiedBlocksPerCore( ThreadDim threads, size_t dynamic_shared_memory_bytes) const { int32_t threads_per_block = threads.x * threads.y * threads.z; VLOG(0) << "Get kernel block occupancy: " << name() << "; threads_per_block: " << threads_per_block << "; dynamic_shared_memory_bytes: " << dynamic_shared_memory_bytes; - return GpuDriver::GetMaxOccupiedBlocksPerCore(gpu_context_, gpu_function_, - threads_per_block, - dynamic_shared_memory_bytes); + return GpuDriver::GetMaxOccupiedBlocksPerCore( + gpu_executor_->gpu_context(), rocm_function_, threads_per_block, + dynamic_shared_memory_bytes); } } // namespace gpu diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_kernel.h b/third_party/xla/xla/stream_executor/rocm/rocm_kernel.h new file mode 100644 index 00000000000000..26d20b667e7609 --- /dev/null +++ b/third_party/xla/xla/stream_executor/rocm/rocm_kernel.h @@ -0,0 +1,69 @@ +/* Copyright 2019 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// The CUDA implementation of the StreamExecutor functionality. +// CUDA inclusions are ideally confined to this implementation file. +// +// The notions from the StreamExecutor basically correspond to the CUDA streams +// programming model provided by the libcuda.so driver APIs, so we don't have +// to do much more than wrap the calls to the libraries appropriately. +#ifndef XLA_STREAM_EXECUTOR_ROCM_ROCM_KERNEL_H_ +#define XLA_STREAM_EXECUTOR_ROCM_ROCM_KERNEL_H_ + +#include +#include + +#include "absl/status/statusor.h" +#include "rocm/include/hip/hip_runtime.h" +#include "xla/stream_executor/gpu/gpu_executor.h" +#include "xla/stream_executor/gpu/gpu_kernel.h" +#include "xla/stream_executor/launch_dim.h" +#include "tsl/platform/logging.h" + +namespace stream_executor::gpu { + +class RocmKernel : public GpuKernel { + public: + explicit RocmKernel(GpuExecutor* gpu_executor) + : gpu_executor_(gpu_executor) {} + + // Note that the function is unloaded when the module is unloaded, and the + // module that the function is contained in is owned by the GpuExecutor. + ~RocmKernel() override { gpu_executor_->UnloadKernel(this); } + + // As arity cannot be reflected upon using the HIP API, the arity is + // explicitly set during the RocmExecutor::GetKernel initialization process. + void set_arity(unsigned arity) { arity_ = arity; } + unsigned Arity() const override { return arity_; } + + absl::StatusOr GetMaxOccupiedBlocksPerCore( + ThreadDim threads, size_t dynamic_shared_memory_bytes) const override; + + // Simple accessor methods. + hipFunction_t gpu_function() const override { return rocm_function_; } + void set_gpu_function(hipFunction_t rocm_function) { + rocm_function_ = rocm_function; + } + + private: + GpuExecutor* gpu_executor_ = nullptr; + + hipFunction_t rocm_function_ = nullptr; // wrapped CUDA kernel handle + unsigned arity_ = 0; // number of formal parameters the kernel takes +}; + +} // namespace stream_executor::gpu + +#endif // XLA_STREAM_EXECUTOR_ROCM_ROCM_KERNEL_H_ diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_kernel_test.cc b/third_party/xla/xla/stream_executor/rocm/rocm_kernel_test.cc new file mode 100644 index 00000000000000..cfff348f9b5b11 --- /dev/null +++ b/third_party/xla/xla/stream_executor/rocm/rocm_kernel_test.cc @@ -0,0 +1,60 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/stream_executor/rocm/rocm_kernel.h" + +#include +#include "rocm/include/hip/hip_runtime.h" +#include "xla/stream_executor/gpu/gpu_executor.h" +#include "xla/stream_executor/gpu/gpu_test_kernels.h" +#include "xla/stream_executor/launch_dim.h" +#include "xla/stream_executor/platform.h" +#include "xla/stream_executor/platform_manager.h" +#include "xla/stream_executor/rocm/rocm_runtime.h" +#include "tsl/platform/status_matchers.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/test.h" + +namespace stream_executor::gpu { +namespace { +using testing::Ge; +using tsl::testing::IsOkAndHolds; + +TEST(RocmKernelTest, GetMaxOccupiedBlocksPerCore) { + TF_ASSERT_OK_AND_ASSIGN(Platform * platform, + PlatformManager::PlatformWithName("ROCM")); + TF_ASSERT_OK_AND_ASSIGN(StreamExecutor * executor, + platform->ExecutorForDevice(0)); + GpuExecutor* gpu_executor = ExtractGpuExecutor(executor); + + RocmKernel rocm_kernel(gpu_executor); + rocm_kernel.set_arity(3); + + TF_ASSERT_OK_AND_ASSIGN( + hipFunction_t function, + RocmRuntime::GetFuncBySymbol(internal::GetAddI32Kernel())); + + rocm_kernel.set_gpu_function(function); + + EXPECT_EQ(rocm_kernel.Arity(), 3); + EXPECT_EQ(rocm_kernel.gpu_function(), function); + + EXPECT_THAT(rocm_kernel.GetMaxOccupiedBlocksPerCore( + ThreadDim(1, 1, 1), /*dynamic_shared_memory_bytes=*/0), + IsOkAndHolds(Ge(1))); +} + +} // namespace +} // namespace stream_executor::gpu From 60e87ffaf0171352477ff6a30a1a8c91741cfc9e Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 25 Sep 2024 23:22:38 -0700 Subject: [PATCH 296/483] Destroy distributed client before service to avoid shutdown errors. PiperOrigin-RevId: 679005510 --- third_party/xla/xla/tools/multihost_hlo_runner/create_client.h | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/third_party/xla/xla/tools/multihost_hlo_runner/create_client.h b/third_party/xla/xla/tools/multihost_hlo_runner/create_client.h index 230a39e67a0939..6c8bb2a4ca685c 100644 --- a/third_party/xla/xla/tools/multihost_hlo_runner/create_client.h +++ b/third_party/xla/xla/tools/multihost_hlo_runner/create_client.h @@ -31,8 +31,9 @@ limitations under the License. namespace xla { struct PjRtEnvironment { - std::unique_ptr client; + // Sequence matters here, client should be destroyed before service. std::unique_ptr service; + std::unique_ptr client; std::shared_ptr kv_store; std::shared_ptr distributed_client; }; From 85d25fe32247850f8b6558a77c2748b005ad07d9 Mon Sep 17 00:00:00 2001 From: Ilya Tikhonovskiy Date: Wed, 25 Sep 2024 23:23:51 -0700 Subject: [PATCH 297/483] [XLA:GPU] Fix forward the flakiness of the test that was introduced in the cl/678283878 The test is flaky because on the different platforms different versions win the competition in autotuner. I.e. from time to time cublas version wins because it is faster than triton version. At the same time both versions are fine. PiperOrigin-RevId: 679005859 --- third_party/xla/xla/service/gpu/dot_algorithm_support_test.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/xla/xla/service/gpu/dot_algorithm_support_test.cc b/third_party/xla/xla/service/gpu/dot_algorithm_support_test.cc index f731049a8f6f6f..23d21d9fbd1b56 100644 --- a/third_party/xla/xla/service/gpu/dot_algorithm_support_test.cc +++ b/third_party/xla/xla/service/gpu/dot_algorithm_support_test.cc @@ -228,7 +228,7 @@ INSTANTIATE_TEST_SUITE_P(DotF32ForBf16Bf16F32Tests, DotAlgorithmSupportTest, Combine(Values(PC::ALG_DOT_BF16_BF16_F32), Values(F32), Values(F32), Values(CC(8, 0)), Values(SemanticVersion{6, 0, 0}), - Values(BackendRestriction::kTritonOnly), + Values(BackendRestriction::kNoRestriction), Values(Sizes{32, 32}, Sizes{16, 2})), TestParamsToString); From 748c1437ae9cce2fbf36c3df619d93e771e5dee1 Mon Sep 17 00:00:00 2001 From: Henning Becker Date: Thu, 26 Sep 2024 00:02:02 -0700 Subject: [PATCH 298/483] Tag missing rocm-only targets as manual PiperOrigin-RevId: 679016215 --- third_party/xla/xla/backends/profiler/gpu/BUILD | 11 +++++++++-- third_party/xla/xla/stream_executor/BUILD | 5 ++++- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/third_party/xla/xla/backends/profiler/gpu/BUILD b/third_party/xla/xla/backends/profiler/gpu/BUILD index 36275f118c52d8..b7fcada623f68e 100644 --- a/third_party/xla/xla/backends/profiler/gpu/BUILD +++ b/third_party/xla/xla/backends/profiler/gpu/BUILD @@ -11,6 +11,7 @@ load( load("//xla/tests:build_defs.bzl", "xla_test") load( "//xla/tsl:tsl.bzl", + "if_google", "internal_visibility", "tsl_copts", "tsl_gpu_library", @@ -223,7 +224,10 @@ tsl_gpu_library( tags = [ "gpu", "rocm-only", - ], + ] + if_google([ + # TODO(b/360374983): Remove this tag once the target can be built without --config=rocm. + "manual", + ]), visibility = ["//visibility:public"], deps = [ "//xla/stream_executor/rocm:roctracer_wrapper", @@ -262,7 +266,10 @@ tsl_gpu_library( tags = [ "gpu", "rocm-only", - ], + ] + if_google([ + # TODO(b/360374983): Remove this tag once the target can be built without --config=rocm. + "manual", + ]), visibility = ["//visibility:public"], deps = [ ":rocm_collector", diff --git a/third_party/xla/xla/stream_executor/BUILD b/third_party/xla/xla/stream_executor/BUILD index b0e8b0bdf5d299..639686129cbd35 100644 --- a/third_party/xla/xla/stream_executor/BUILD +++ b/third_party/xla/xla/stream_executor/BUILD @@ -937,7 +937,10 @@ alias( tags = [ "gpu", "rocm-only", - ], + ] + if_google([ + # TODO(b/360374983): Remove this tag once the target can be built without --config=rocm. + "manual", + ]), ) alias( From f3b81492332df9af8923467e312710921066dc53 Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Thu, 26 Sep 2024 00:22:44 -0700 Subject: [PATCH 299/483] [XLA:GPU] Add support for iota in the Triton fusion emitter. `Iota` must be treated like a parameter, i.e. it needs to be offset, and potentially strided. We therefore need to ensure that`tile_offsets_indexing` is always derived for the instruction. PiperOrigin-RevId: 679023133 --- .../xla/xla/service/gpu/fusions/triton/BUILD | 2 + .../fusions/triton/triton_fusion_emitter.cc | 54 +++++++++++++++ .../triton_fusion_emitter_device_test.cc | 69 +++++++++++++++++++ .../gpu/fusions/triton/triton_support.cc | 9 +++ third_party/xla/xla/service/gpu/model/BUILD | 1 + .../gpu/model/symbolic_tile_analysis.cc | 5 +- .../gpu/model/symbolic_tile_analysis.h | 2 +- .../gpu/model/symbolic_tile_analysis_test.cc | 24 +++++++ 8 files changed, 163 insertions(+), 3 deletions(-) diff --git a/third_party/xla/xla/service/gpu/fusions/triton/BUILD b/third_party/xla/xla/service/gpu/fusions/triton/BUILD index 7532b1a4e1c7b9..e3c95865b8e5db 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/BUILD +++ b/third_party/xla/xla/service/gpu/fusions/triton/BUILD @@ -261,6 +261,8 @@ xla_test( ":triton_test_utils", "//xla:autotuning_proto_cc", "//xla:error_spec", + "//xla:shape_util", + "//xla:xla_data_proto_cc", "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", "//xla/service/gpu:backend_configs_cc", diff --git a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter.cc b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter.cc index 4a4245c1590670..a105789fa86f01 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter.cc +++ b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter.cc @@ -928,6 +928,56 @@ Value EmitTiledBroadcast( padded_output_tile_shape); } +absl::StatusOr EmitTiledIota(ImplicitLocOpBuilder& b, + ValueRange tile_multi_index, + const TiledHloInstruction& tiled_iota) { + const HloIotaInstruction* hlo_iota = + ::xla::Cast(tiled_iota.hlo()); + int64_t iota_dim = hlo_iota->iota_dimension(); + + SmallVector padded_tile_sizes = + GetPaddedTileSizes(tiled_iota.tile_sizes()); + + // We can treat iota more or less as a parameter load, except that we need to + // generate the right values in the right place as opposed to loading them. + TF_ASSIGN_OR_RETURN(IndexingMap tile_offsets_indexing, + tiled_iota.tile_offsets_indexing()); + + auto iota_dim_offset = b.create( + b.getI32Type(), mlir_converter::ApplyIndexing( + tile_offsets_indexing, /*dims=*/tile_multi_index, + /*symbols=*/{}, b)[iota_dim]); + + // First, stride as needed between the iota components. + Value range = b.create( + Range(b, padded_tile_sizes[iota_dim]), + Splat(b, + CreateConst(b, b.getI32Type(), tiled_iota.tile_strides()[iota_dim]), + padded_tile_sizes[iota_dim])); + + // Then, add the base offset to the iota components. + range = b.create( + range, Splat(b, iota_dim_offset, padded_tile_sizes[iota_dim])); + + // Cast the result to the targeted type. + TF_ASSIGN_OR_RETURN(Type iota_element_type, + TritonType(b, hlo_iota->shape().element_type())); + + range = Cast(b, range, iota_element_type); + + // And finally, produce a broadcast along the non-iota dimensions in order to + // produce the whole iota tile. + for (int i = 0; i < padded_tile_sizes.size() - 1; i++) { + if (i < iota_dim) { + range = b.create(range, /*axis=*/0); + } else { + range = b.create(range, /*axis=*/i + 1); + } + } + + return Broadcast(b, mlir::cast(range), padded_tile_sizes); +} + Value EmitTiledReshape(ImplicitLocOpBuilder& b, ArrayRef tile_sizes, Value input) { SmallVector padded_tile_sizes = GetPaddedTileSizes(tile_sizes); @@ -1057,6 +1107,10 @@ absl::StatusOr EmitTiledHloInstruction( absl::StrCat("Unsupported non-scalar constant ", hlo->ToString())); } + if (hlo->opcode() == HloOpcode::kIota) { + return EmitTiledIota(b, tile_multi_index, tiled_hlo); + } + if (hlo->opcode() == HloOpcode::kBroadcast) { return EmitTiledBroadcast(b, tiled_hlo, values); } diff --git a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_test.cc b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_test.cc index f136f7190d1a64..d79a54f4615681 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_test.cc @@ -21,6 +21,7 @@ limitations under the License. #include #include "absl/status/status.h" #include "absl/strings/string_view.h" +#include "absl/strings/substitute.h" #include "llvm/IR/LLVMContext.h" #include "mlir/IR/MLIRContext.h" #include "mlir/Pass/PassManager.h" @@ -30,6 +31,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" +#include "xla/primitive_util.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/fusions/triton/triton_fusion_emitter.h" #include "xla/service/gpu/fusions/triton/triton_test_utils.h" @@ -40,6 +42,7 @@ limitations under the License. #include "xla/tests/verified_hlo_module.h" #include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla.pb.h" +#include "xla/xla_data.pb.h" #include "tsl/platform/status_matchers.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" @@ -1196,6 +1199,72 @@ ENTRY main { RunAndCompareNoHloPasses(kHloText, ErrorSpec{/*aabs=*/0, /*arel=*/0})); } +TEST_F(TritonEmitterTest, StridedIota4DIsCodegeneratedCorrectly) { + constexpr std::string_view kHloText = R"( +triton_computation { + iota = f32[3,4,1000,5] iota(), iota_dimension=2 + ROOT slice = f32[3,4,182,5] slice(iota), slice={[0:3], [0:4], [91:1000:5], [0:5]} +} + +ENTRY main { + ROOT triton_fusion = f32[3,4,182,5] fusion(), + kind=kCustom, calls=triton_computation, + backend_config={"fusion_backend_config": + {"kind":"__triton", + "block_level_fusion_config":{"output_tile_sizes":["1","2","64","8"], + "num_warps":"1"}}} +})"; + + TF_EXPECT_OK( + CreateTritonIrAndFileCheck(this, kHloText, "triton_computation", R"( +CHECK: %[[RANGE:.*]] = tt.make_range {{.*}} : tensor<64xi32> +CHECK: arith.muli{{.*}} %[[RANGE]] +)")); + + EXPECT_TRUE( + RunAndCompareNoHloPasses(kHloText, ErrorSpec{/*aabs=*/0, /*arel=*/0})); +} + +class IotaEmitterParametrizedTest + : public TritonEmitterTest, + public ::testing::WithParamInterface {}; + +TEST_P(IotaEmitterParametrizedTest, Iota4DIsCodegeneratedCorrectly) { + auto data_type = GetParam(); + const std::string kHloText = + absl::Substitute(R"( +triton_computation { + ROOT iota = $0[3,4,1000,5] iota(), iota_dimension=2 +} + +ENTRY main { + ROOT triton_fusion = $0[3,4,1000,5] fusion(), + kind=kCustom, calls=triton_computation, + backend_config={"fusion_backend_config": + {"kind":"__triton", + "block_level_fusion_config":{"output_tile_sizes":["1","2","64","8"], + "num_warps":"1"}}} +})", + primitive_util::LowercasePrimitiveTypeName(data_type)); + + TF_EXPECT_OK( + CreateTritonIrAndFileCheck(this, kHloText, "triton_computation", R"( +CHECK: %[[RANGE:.*]] = tt.make_range {{.*}} : tensor<64xi32> +CHECK: arith.addi{{.*}} %[[RANGE]] + // Omit the data type below, since it depends on a test parameter + // and is not abbreviated the same as in HLO. +CHECK: tt.broadcast {{.*}} -> tensor<1x2x64x8x +)")); + + EXPECT_TRUE( + RunAndCompareNoHloPasses(kHloText, ErrorSpec{/*aabs=*/0, /*arel=*/0})); +} + +INSTANTIATE_TEST_SUITE_P(IotaEmitterParametrizedTestSuite, + IotaEmitterParametrizedTest, + ::testing::ValuesIn({S8, S16, S32, S64, BF16, F16, F32, + F64})); + } // namespace } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/fusions/triton/triton_support.cc b/third_party/xla/xla/service/gpu/fusions/triton/triton_support.cc index d0a33343fa2237..d006ae65fcc550 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/triton_support.cc +++ b/third_party/xla/xla/service/gpu/fusions/triton/triton_support.cc @@ -285,6 +285,15 @@ CodegenDecision IsTritonSupportedInstructionImpl( "Only scalar constants are supported in Triton."); } + if (instr.opcode() == HloOpcode::kIota) { + PrimitiveType element_type = instr.shape().element_type(); + return element_type != PrimitiveType::F8E4M3FN && + element_type != PrimitiveType::F8E5M2 + ? CodegenDecision::Allow() + : CodegenDecision::Forbid( + "F8E4M3FN and F8E5M2 are not supported for iota."); + } + if (instr.IsElementwise()) { if (!IsTritonSupportedElementwise( instr.opcode(), diff --git a/third_party/xla/xla/service/gpu/model/BUILD b/third_party/xla/xla/service/gpu/model/BUILD index 7a4790b1c3adaf..fccc5e66360eb8 100644 --- a/third_party/xla/xla/service/gpu/model/BUILD +++ b/third_party/xla/xla/service/gpu/model/BUILD @@ -780,6 +780,7 @@ xla_cc_test( "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest_main", "@llvm-project//mlir:IR", + "@local_tsl//tsl/platform:status_matchers", "@local_tsl//tsl/platform:statusor", ], ) diff --git a/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis.cc b/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis.cc index 9b9e637904210a..a08b8aaa3ea908 100644 --- a/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis.cc +++ b/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis.cc @@ -345,7 +345,7 @@ void SortTiledHloInstructionsInPostOrder( }); } -} // namespace +} // anonymous namespace /*static*/ SymbolicTileAnalysisOrError SymbolicTileAnalysis::AnalyzeComputation( const HloComputation& computation, MLIRContext* ctx, @@ -562,7 +562,8 @@ SymbolicTileAnalysis::ComputeTiledHloInstructions( std::optional tile_offset_indexing; if (compute_all_tile_offset_indexing_maps || - parameters_with_offset_indexing.contains(symbolic_tiled_hlo->hlo())) { + parameters_with_offset_indexing.contains(symbolic_tiled_hlo->hlo()) || + symbolic_tiled_hlo->hlo()->opcode() == HloOpcode::kIota) { TF_ASSIGN_OR_RETURN( tile_offset_indexing, ComputeTileOffsetIndexing( diff --git a/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis.h b/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis.h index 58a08afde9ba17..775de1670f51ea 100644 --- a/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis.h +++ b/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis.h @@ -91,7 +91,7 @@ class SymbolicTileAnalysis { // Returns a graph of HLO instructions tiled with the given tile parameters. // The provided tile parameters must satisfy the analysis's constraints. - // By default, `ComputetiledHloInstructions` performs a check that the + // By default, `ComputeTiledHloInstructions` performs a check that the // constraints are satisfied by the chosen tiled parameters. Setting // `constraints_are_known_satisfied` to true bypasses this check. // diff --git a/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis_test.cc b/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis_test.cc index 99f136d3aecae2..f4166076234a7e 100644 --- a/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis_test.cc +++ b/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis_test.cc @@ -43,6 +43,7 @@ limitations under the License. #include "xla/tests/verified_hlo_module.h" #include "xla/tsl/lib/core/status_test_util.h" #include "xla/util.h" +#include "tsl/platform/status_matchers.h" #include "tsl/platform/statusor.h" namespace xla { @@ -1020,6 +1021,29 @@ ENTRY main { EXPECT_TRUE(analysis.has_value()); } +TEST_F(SymbolicTileAnalysisTest, IotaAlwaysHasTileOffsetsIndexingSet) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +fusion { + ROOT iota = s32[100] iota(), iota_dimension=0 +} + +ENTRY main { + ROOT fusion = s32[100] fusion(), kind=kLoop, calls=fusion +})")); + std::optional analysis = TryAnalyzeModule(module.get()); + ASSERT_TRUE(analysis.has_value()); + + TF_ASSERT_OK_AND_ASSIGN(TiledHloComputation tiled_hlo_computation, + analysis->ComputeTiledHloInstructions( + /*tile_parameters=*/{4}, + /*constraints_are_known_satisfied=*/false, + /*compute_all_tile_offset_indexing_maps=*/false)); + + const TiledHloInstruction* iota = tiled_hlo_computation.GetRoot(); + EXPECT_THAT(iota->tile_offsets_indexing().status(), ::tsl::testing::IsOk()); +} + } // namespace } // namespace gpu } // namespace xla From 53436d175ea9ac41a938c3b9c54efb18d9d85caa Mon Sep 17 00:00:00 2001 From: Tori Baker Date: Thu, 26 Sep 2024 01:03:08 -0700 Subject: [PATCH 300/483] Fix flake due to unordered elements PiperOrigin-RevId: 679033903 --- third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops.cc | 4 ++-- third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops.h | 5 +++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops.cc b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops.cc index 967df31ba84397..38c11e1ff8023a 100644 --- a/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops.cc +++ b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops.cc @@ -957,9 +957,9 @@ VariableConstraints GetConstraintsForVariables(const IndexingMap& map) { for (const auto& constraint : map.GetConstraints()) { constraint.first.walk([&](mlir::AffineExpr leaf) { if (auto dim = mlir::dyn_cast(leaf)) { - result.constraints_for_dims[dim.getPosition()].push_back(constraint); + result.constraints_for_dims[dim.getPosition()].insert(constraint); } else if (auto sym = mlir::dyn_cast(leaf)) { - result.constraints_for_symbols[sym.getPosition()].push_back(constraint); + result.constraints_for_symbols[sym.getPosition()].insert(constraint); } }); } diff --git a/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops.h b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops.h index e47f2cf3e2d323..28e46396a34091 100644 --- a/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops.h +++ b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops.h @@ -17,6 +17,7 @@ limitations under the License. #include +#include "llvm/ADT/DenseMap.h" #include "llvm/ADT/SmallVector.h" #include "mlir/Bytecode/BytecodeOpInterface.h" // IWYU pragma: keep #include "mlir/Dialect/Func/IR/FuncOps.h" // IWYU pragma: keep @@ -42,9 +43,9 @@ limitations under the License. namespace xla::gpu { struct VariableConstraints { - llvm::SmallVector>> + llvm::SmallVector> constraints_for_dims; - llvm::SmallVector>> + llvm::SmallVector> constraints_for_symbols; }; VariableConstraints GetConstraintsForVariables(const IndexingMap& map); From c26670cb1ac98244e0fcb5655a55603df52383c1 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 26 Sep 2024 01:06:13 -0700 Subject: [PATCH 301/483] Temporary enable flatbuffer verification assertion to investigate a crash on Windows. PiperOrigin-RevId: 679034841 --- .../compiler/mlir/lite/core/model_builder_base.h | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/tensorflow/compiler/mlir/lite/core/model_builder_base.h b/tensorflow/compiler/mlir/lite/core/model_builder_base.h index 3484f4e3d071d6..6856ac7260550c 100644 --- a/tensorflow/compiler/mlir/lite/core/model_builder_base.h +++ b/tensorflow/compiler/mlir/lite/core/model_builder_base.h @@ -386,9 +386,15 @@ class FlatBufferModelBase { size_t allocation_size = std::min(allocation->bytes(), static_cast(FLATBUFFERS_MAX_BUFFER_SIZE - 1)); + flatbuffers::Verifier::Options options; + // TODO(b/366118885): Remove after the root cause of the crash on Windows + // is found. +#if defined(_WIN32) + options.assert = true; +#endif flatbuffers::Verifier base_verifier( - reinterpret_cast(allocation->base()), - allocation_size); + reinterpret_cast(allocation->base()), allocation_size, + options); if (!VerifyModelBuffer(base_verifier)) { TF_LITE_REPORT_ERROR(error_reporter, "The model is not a valid Flatbuffer buffer"); From dd53fe8d7157662683395c4d14f21aaedd6ad2e0 Mon Sep 17 00:00:00 2001 From: "Dimitar (Mitko) Asenov" Date: Thu, 26 Sep 2024 01:15:44 -0700 Subject: [PATCH 302/483] [XLA:GPU] Disable a flaky test PiperOrigin-RevId: 679037782 --- third_party/xla/xla/tests/collective_ops_e2e_test.cc | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/third_party/xla/xla/tests/collective_ops_e2e_test.cc b/third_party/xla/xla/tests/collective_ops_e2e_test.cc index cdcb82919b67f8..479935f7d01f66 100644 --- a/third_party/xla/xla/tests/collective_ops_e2e_test.cc +++ b/third_party/xla/xla/tests/collective_ops_e2e_test.cc @@ -450,6 +450,10 @@ XLA_TEST_P(AsyncCollectiveOps, AsyncAllToAllWithSplitDim) { } TEST_F(CollectiveOpsTestE2E, AsyncAllToAllMemCpy) { + // TODO(b/369751308): Re-enable this test after the threading issues are + // fixed. + GTEST_SKIP() << "This test is flaky. See b/369751308"; + const absl::string_view kModuleStr = R"( HloModule test ENTRY test_computation { From f0933e089a6319eed606c0b2dbdc5508ee6bdb12 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 26 Sep 2024 01:39:23 -0700 Subject: [PATCH 303/483] Automated Code Change PiperOrigin-RevId: 679044777 --- .../core/common_runtime/optimized_function_graph_info.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/core/common_runtime/optimized_function_graph_info.h b/tensorflow/core/common_runtime/optimized_function_graph_info.h index b15790dbeede36..c23d722176bf07 100644 --- a/tensorflow/core/common_runtime/optimized_function_graph_info.h +++ b/tensorflow/core/common_runtime/optimized_function_graph_info.h @@ -73,8 +73,8 @@ struct OptimizedFunctionGraphInfo { delete; OptimizedFunctionGraphInfo(OptimizedFunctionGraphInfo&& info) = default; // NOLINT - OptimizedFunctionGraphInfo& operator=(OptimizedFunctionGraphInfo&& info) = - default; // NOLINT + OptimizedFunctionGraphInfo& operator=( + OptimizedFunctionGraphInfo&& info) noexcept = default; // NOLINT // Converts from the struct to OptimizedFunctionGraph proto. static OptimizedFunctionGraph ToProto(const OptimizedFunctionGraphInfo& info); From 5de3027e14a1fe82928adcc05a012d9544a2eed1 Mon Sep 17 00:00:00 2001 From: Jaroslav Sevcik Date: Thu, 26 Sep 2024 01:46:52 -0700 Subject: [PATCH 304/483] PR #16913: [PJRT:GPU] Enable creating topology without a GPU device Imported from GitHub PR https://github.com/openxla/xla/pull/16913 Currently PJRT_TopologyDescription_Create always creates topology from the local client. This requires having a local GPU device. This patch allows explicitly specifying topology shape and device config in PJRT_TopologyDescription_Create call, without querying local client. This enables deviceless compilation. Copybara import of the project: -- bc85038cbdfbed4b43b5859037515ef9049ecfec by Jaroslav Sevcik : [PJRT:GPU] Enable creating topology without a GPU device -- e4eb44ecf356f0c7d859c969ff38b69e52d569c5 by Jaroslav Sevcik : Enable overlaying topology on local device -- ff25a5790bfba35262f9649cb6df234fc2af5a3d by Jaroslav Sevcik : Address reviewer comments -- 8bfb7a251f1f687b65e6db3744dc7c76d60a321e by Jaroslav Sevcik : Cleanup Merging this change closes #16913 PiperOrigin-RevId: 679047171 --- .../xla/xla/pjrt/c/pjrt_c_api_gpu_internal.cc | 121 +++++++++++++++--- .../xla/xla/pjrt/c/pjrt_c_api_gpu_test.cc | 106 +++++++++++++++ 2 files changed, 208 insertions(+), 19 deletions(-) diff --git a/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_internal.cc b/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_internal.cc index 5bddd6f2660e8e..ae6d84ae039a05 100644 --- a/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_internal.cc +++ b/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_internal.cc @@ -167,19 +167,37 @@ PJRT_Error* PJRT_ExecuteContext_Create(PJRT_ExecuteContext_Create_Args* args) { return nullptr; } -PJRT_Error* PJRT_GpuDeviceTopology_Create( - PJRT_TopologyDescription_Create_Args* args) { - PJRT_RETURN_IF_ERROR(ActualStructSizeIsGreaterOrEqual( - "PJRT_TopologyDescription_Create_Args", - PJRT_TopologyDescription_Create_Args_STRUCT_SIZE, args->struct_size)); +namespace { - PJRT_ASSIGN_OR_RETURN(xla::LocalClient * xla_client, - xla::GetGpuXlaClient(/*platform_name=*/std::nullopt, - /*allowed_devices=*/std::nullopt)); +struct TargetConfigAndDevices { + stream_executor::GpuTargetConfigProto target_config_proto; + std::vector device_ids; +}; + +// Parses the 'target_config' entry in 'options'. The option is +// parsed as GpuTargetConfigProto. If there is no 'target_config' in +// 'options', the function falls back to creating a local client, +// returning the local client's target config. +absl::StatusOr GetTargetConfigFromOptions( + const absl::flat_hash_map& options) { + if (auto target_config_it = options.find("target_config"); + target_config_it != options.end()) { + std::string target_config_proto_string = + std::get(target_config_it->second); + stream_executor::GpuTargetConfigProto target_config_proto; + if (!tsl::protobuf::TextFormat::ParseFromString(target_config_proto_string, + &target_config_proto)) { + return absl::FailedPreconditionError( + "Failed to parse GpuTargetConfigProto " + "from the 'target_config' parameter."); + } + return {{target_config_proto, {}}}; + } + TF_ASSIGN_OR_RETURN(xla::LocalClient * xla_client, + xla::GetGpuXlaClient(/*platform_name=*/std::nullopt, + /*allowed_devices=*/std::nullopt)); stream_executor::StreamExecutor* executor = xla_client->backend().default_stream_executor(); - const stream_executor::DeviceDescription& description = - executor->GetDeviceDescription(); std::vector device_ids; device_ids.reserve(xla_client->backend().stream_executors().size()); for (stream_executor::StreamExecutor* executor : @@ -187,13 +205,43 @@ PJRT_Error* PJRT_GpuDeviceTopology_Create( device_ids.push_back(executor->device_ordinal()); } auto gpu_target_config = xla::Compiler::TargetConfig(executor); - // TODO(b/341334898): Create a single-host GPU topology. Will be updated for - // multi-host support in the future. - auto gpu_topology = std::make_shared( - device_ids, description.name(), - /*num_slices=*/1, - /*num_hosts_per_slice=*/1, - /*num_devices_per_host=*/device_ids.size()); + return {{gpu_target_config.ToProto(), device_ids}}; +} + +struct TopologySizes { + int num_slices = 0; + int num_hosts_per_slice = 0; + int num_devices_per_host = 0; + + int GetDeviceCount() { + return num_slices * num_hosts_per_slice * num_devices_per_host; + } + + static absl::StatusOr FromString( + std::string_view topology_string) { + TopologySizes sizes; + std::vector topology_components = + absl::StrSplit(topology_string, 'x'); + if (topology_components.size() != 3 || + !absl::SimpleAtoi(topology_components[0], &sizes.num_slices) || + !absl::SimpleAtoi(topology_components[1], &sizes.num_hosts_per_slice) || + !absl::SimpleAtoi(topology_components[2], + &sizes.num_devices_per_host)) { + return absl::InternalError( + "topology must be of shape " + "\"xx\""); + } + return sizes; + } +}; + +} // namespace + +PJRT_Error* PJRT_GpuDeviceTopology_Create( + PJRT_TopologyDescription_Create_Args* args) { + PJRT_RETURN_IF_ERROR(ActualStructSizeIsGreaterOrEqual( + "PJRT_TopologyDescription_Create_Args", + PJRT_TopologyDescription_Create_Args_STRUCT_SIZE, args->struct_size)); // Determine the platform ID and name based on the platform. xla::PjRtPlatformId platform_id = @@ -203,12 +251,47 @@ PJRT_Error* PJRT_GpuDeviceTopology_Create( (std::string(PJRT_GPU_PLUGIN_PLATFORM_NAME) == "ROCM") ? xla::RocmName() : xla::CudaName(); + absl::flat_hash_map create_options = + pjrt::ConvertFromPjRtNamedValueList(args->create_options, + args->num_options); + + PJRT_ASSIGN_OR_RETURN(TargetConfigAndDevices target_config_and_devices, + GetTargetConfigFromOptions(create_options)); + + std::vector& device_ids = target_config_and_devices.device_ids; + stream_executor::GpuTargetConfigProto& target_config_proto = + target_config_and_devices.target_config_proto; + TopologySizes sizes{1, 1, static_cast(device_ids.size())}; + + if (auto topology_it = create_options.find("topology"); + topology_it != create_options.end()) { + std::string topology_string = std::get(topology_it->second); + PJRT_ASSIGN_OR_RETURN(sizes, TopologySizes::FromString(topology_string)); + } + + if (sizes.GetDeviceCount() == 0) { + // If the user did not specify the topology and we did not + // get any devices from the client, then error out because + // we do not know how many devices the topology should have. + return new PJRT_Error{absl::FailedPreconditionError( + "Cannot create topology without an explicit topology shape or without " + "a client")}; + } + + if (sizes.GetDeviceCount() != device_ids.size()) { + device_ids.resize(sizes.GetDeviceCount()); + absl::c_iota(device_ids, sizes.GetDeviceCount()); + } + + auto gpu_topology = std::make_shared( + device_ids, target_config_proto.device_description_str(), + sizes.num_slices, sizes.num_hosts_per_slice, sizes.num_devices_per_host); + auto pjrt_topology = std::make_unique( platform_id, platform_name, std::move(gpu_topology), absl::flat_hash_map{ - {"target_config", - gpu_target_config.ToProto().SerializeAsString()}}); + {"target_config", target_config_proto.SerializeAsString()}}); args->topology = CreateWrapperDeviceTopology(std::move(pjrt_topology)); return nullptr; } diff --git a/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_test.cc b/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_test.cc index 17d9c9d72228f3..57d2c31d320107 100644 --- a/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_test.cc +++ b/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_test.cc @@ -427,6 +427,8 @@ TEST(PJRTGpuDeviceTopologyTest, CreateGpuTopology) { args.struct_size = PJRT_TopologyDescription_Create_Args_STRUCT_SIZE; args.extension_start = nullptr; args.topology = nullptr; + args.num_options = 0; + args.create_options = nullptr; PJRT_Error* error = pjrt_api->PJRT_TopologyDescription_Create(&args); EXPECT_EQ(error, nullptr) << error->status.message(); @@ -452,6 +454,110 @@ TEST(PJRTGpuDeviceTopologyTest, CreateGpuTopology) { EXPECT_EQ(destroy_error, nullptr) << destroy_error->status.message(); } +constexpr char const* kTargetConfigString = R"(gpu_device_info { + threads_per_block_limit: 1024 + threads_per_warp: 32 + shared_memory_per_block: 49152 + shared_memory_per_core: 98304 + threads_per_core_limit: 2048 + core_count: 80 + fpus_per_core: 64 + block_dim_limit_x: 2147483647 + block_dim_limit_y: 65535 + block_dim_limit_z: 65535 + memory_bandwidth: 898048000000 + l2_cache_size: 6291456 + clock_rate_ghz: 1.53 + device_memory_size: 34072559616 + shared_memory_per_block_optin: 98304 + cuda_compute_capability { + major: 7 + } + registers_per_core_limit: 65536 + registers_per_block_limit: 65536 +} +platform_name: "CUDA" +dnn_version_info { + major: 9 + minor: 3 +} +device_description_str: "Tesla V100-SXM2-32GB" +)"; + +TEST(PJRTGpuDeviceTopologyTest, CreateExplicitGpuTopologyAndTargetConfig) { + auto pjrt_api = gpu_plugin::GetGpuPjrtApi(); + + absl::flat_hash_map options = { + {"topology", static_cast("16 x 2 x 4")}, + {"target_config", static_cast(kTargetConfigString)}}; + TF_ASSERT_OK_AND_ASSIGN(std::vector c_options, + ::pjrt::ConvertToPjRtNamedValueList(options)); + + PJRT_TopologyDescription_Create_Args args; + args.struct_size = PJRT_TopologyDescription_Create_Args_STRUCT_SIZE; + args.extension_start = nullptr; + args.topology = nullptr; + args.num_options = c_options.size(); + args.create_options = c_options.data(); + + PJRT_Error* error = pjrt_api->PJRT_TopologyDescription_Create(&args); + EXPECT_EQ(error, nullptr) << error->status.message(); + + auto pjrt_topology = + reinterpret_cast(args.topology); + ASSERT_NE(pjrt_topology, nullptr); + + EXPECT_EQ(pjrt_topology->topology->platform_id(), xla::CudaId()); + EXPECT_EQ(pjrt_topology->topology->platform_name(), xla::CudaName()); + + EXPECT_EQ(pjrt_topology->topology->ProcessCount().value(), 16 * 2); + EXPECT_EQ(pjrt_topology->topology->DeviceDescriptions().size(), 16 * 2 * 4); + EXPECT_EQ(pjrt_topology->topology->DeviceDescriptions()[0]->device_kind(), + "Tesla V100-SXM2-32GB"); + + PJRT_TopologyDescription_Destroy_Args destroy_args; + destroy_args.struct_size = PJRT_TopologyDescription_Destroy_Args_STRUCT_SIZE; + destroy_args.extension_start = nullptr; + destroy_args.topology = const_cast(pjrt_topology); + PJRT_Error* destroy_error = + pjrt_api->PJRT_TopologyDescription_Destroy(&destroy_args); + EXPECT_EQ(destroy_error, nullptr) << destroy_error->status.message(); +} + +TEST(PJRTGpuDeviceTopologyTest, CreateExplicitGpuTopology) { + auto pjrt_api = gpu_plugin::GetGpuPjrtApi(); + + absl::flat_hash_map options = { + {"topology", static_cast("16 x 2 x 4")}}; + TF_ASSERT_OK_AND_ASSIGN(std::vector c_options, + ::pjrt::ConvertToPjRtNamedValueList(options)); + + PJRT_TopologyDescription_Create_Args args; + args.struct_size = PJRT_TopologyDescription_Create_Args_STRUCT_SIZE; + args.extension_start = nullptr; + args.topology = nullptr; + args.num_options = c_options.size(); + args.create_options = c_options.data(); + + PJRT_Error* error = pjrt_api->PJRT_TopologyDescription_Create(&args); + EXPECT_EQ(error, nullptr) << error->status.message(); + + auto pjrt_topology = + reinterpret_cast(args.topology); + ASSERT_NE(pjrt_topology, nullptr); + + EXPECT_EQ(pjrt_topology->topology->ProcessCount().value(), 16 * 2); + EXPECT_EQ(pjrt_topology->topology->DeviceDescriptions().size(), 16 * 2 * 4); + + PJRT_TopologyDescription_Destroy_Args destroy_args; + destroy_args.struct_size = PJRT_TopologyDescription_Destroy_Args_STRUCT_SIZE; + destroy_args.extension_start = nullptr; + destroy_args.topology = const_cast(pjrt_topology); + PJRT_Error* destroy_error = + pjrt_api->PJRT_TopologyDescription_Destroy(&destroy_args); + EXPECT_EQ(destroy_error, nullptr) << destroy_error->status.message(); +} + void TestCustomCallV2() {} TEST(PjrtCApiGpuExtensionTest, CustomCallUntyped) { From 37f9ca56611703e8960ef1434ff5b5d183d4c52b Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 26 Sep 2024 02:16:43 -0700 Subject: [PATCH 305/483] Automated Code Change PiperOrigin-RevId: 679056498 --- tensorflow/compiler/tf2xla/BUILD | 22 +++++++++++++++ tensorflow/compiler/tf2xla/xla_op_kernel.h | 10 +++++++ tensorflow/compiler/tf2xla/xla_op_registry.cc | 27 ++++++++++++++----- tensorflow/compiler/tf2xla/xla_op_registry.h | 9 +++++++ .../compiler/tf2xla/xla_op_registry_test.cc | 5 ++++ tensorflow/compiler/tf2xla/xla_resource.cc | 17 +++++++++--- tensorflow/compiler/tf2xla/xla_resource.h | 3 +++ 7 files changed, 83 insertions(+), 10 deletions(-) diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index 257720495c651e..b85f9cbb58751e 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -520,15 +520,18 @@ cc_library( "//tensorflow/core/tpu:tpu_defs", "//tensorflow/core/util:overflow", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", "@com_google_absl//absl/types:variant", "@local_tsl//tsl/platform:tensor_float_32_hdr_lib", "@local_xla//xla:executable_run_options", + "@local_xla//xla:literal", "@local_xla//xla:protobuf_util", "@local_xla//xla:shape_util", "@local_xla//xla:status_macros", @@ -634,6 +637,16 @@ cc_library( "//tensorflow/core/common_runtime/next_pluggable_device:next_pluggable_device_factory_hdrs", "//tensorflow/core/platform:stream_executor_no_cuda", "//tensorflow/core/tfrt/common:pjrt_util", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:status", + "@local_xla//xla:util", "@local_xla//xla/client:client_library", ], alwayslink = 1, @@ -681,8 +694,14 @@ cc_library( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + "@local_xla//xla:shape_util", + "@local_xla//xla:status_macros", "@local_xla//xla:xla_data_proto_cc", "@local_xla//xla/client:xla_builder", ], @@ -1384,8 +1403,11 @@ tf_cc_test( srcs = ["xla_op_registry_test.cc"], deps = [ ":xla_compiler", + "//tensorflow/core:framework", + "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", + "@com_google_absl//absl/log", ], ) diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h index a66be8384f003c..602075c2f1bff4 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.h +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h @@ -16,6 +16,10 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_OP_KERNEL_H_ #define TENSORFLOW_COMPILER_TF2XLA_XLA_OP_KERNEL_H_ +#include "absl/base/attributes.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/tf2xla/xla_expression.h" @@ -23,9 +27,15 @@ limitations under the License. #include "xla/client/value_inference.h" #include "xla/client/xla_builder.h" #include "xla/client/xla_computation.h" +#include "xla/literal.h" +#include "xla/shape.h" #include "xla/xla_data.pb.h" #include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/status.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.cc b/tensorflow/compiler/tf2xla/xla_op_registry.cc index 0109f6a3f07ef3..95e80fc1fb9205 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.cc +++ b/tensorflow/compiler/tf2xla/xla_op_registry.cc @@ -19,21 +19,34 @@ limitations under the License. #include #include +#include "absl/algorithm/container.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/jit/xla_cluster_util.h" -#include "tensorflow/compiler/tf2xla/type_util.h" -#include "tensorflow/compiler/tf2xla/xla_context.h" -#include "xla/client/client_library.h" -#include "tensorflow/core/common_runtime/device_factory.h" -#include "tensorflow/core/common_runtime/local_device.h" +#include "xla/util.h" #include "tensorflow/core/common_runtime/next_pluggable_device/next_pluggable_device_factory.h" #include "tensorflow/core/framework/device_base.h" +#include "tensorflow/core/framework/device_factory.h" #include "tensorflow/core/framework/kernel_def.pb.h" #include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_def.pb.h" #include "tensorflow/core/framework/op_def_util.h" -#include "tensorflow/core/platform/mem.h" -#include "tensorflow/core/platform/stream_executor_no_cuda.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/types.h" #include "tensorflow/core/tfrt/common/pjrt_util.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/status.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.h b/tensorflow/compiler/tf2xla/xla_op_registry.h index 333a9168f3deda..aa3130e57a70fd 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.h +++ b/tensorflow/compiler/tf2xla/xla_op_registry.h @@ -22,16 +22,25 @@ limitations under the License. #include #include +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/local_device.h" #include "tensorflow/core/framework/device_base.h" +#include "tensorflow/core/framework/kernel_def.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/mem.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/platform/types.h" #include "tensorflow/core/public/session_options.h" +#include "tsl/platform/errors.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/xla_op_registry_test.cc b/tensorflow/compiler/tf2xla/xla_op_registry_test.cc index 4d8e1bc31f8d58..13b648b78004ac 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry_test.cc +++ b/tensorflow/compiler/tf2xla/xla_op_registry_test.cc @@ -14,7 +14,12 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/tf2xla/xla_op_registry.h" + +#include "absl/log/log.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/xla_resource.cc b/tensorflow/compiler/tf2xla/xla_resource.cc index 0e1d33a0c1c718..40b18e5c4e02f9 100644 --- a/tensorflow/compiler/tf2xla/xla_resource.cc +++ b/tensorflow/compiler/tf2xla/xla_resource.cc @@ -18,11 +18,22 @@ limitations under the License. #include #include -#include "absl/memory/memory.h" -#include "tensorflow/compiler/tf2xla/shape_util.h" -#include "tensorflow/compiler/tf2xla/sharding_util.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "xla/client/xla_builder.h" +#include "xla/status_macros.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/managed_stack_trace.h" +#include "tsl/platform/errors.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/xla_resource.h b/tensorflow/compiler/tf2xla/xla_resource.h index 902c62edd5664a..f08508cfe7591d 100644 --- a/tensorflow/compiler/tf2xla/xla_resource.h +++ b/tensorflow/compiler/tf2xla/xla_resource.h @@ -19,11 +19,14 @@ limitations under the License. #include #include "absl/strings/string_view.h" +#include "absl/types/optional.h" #include "xla/client/xla_builder.h" +#include "xla/shape.h" #include "xla/xla_data.pb.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/managed_stack_trace.h" namespace tensorflow { From e8778b598ba6fad3aac852fbf27d73fc0249abdc Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 26 Sep 2024 02:23:41 -0700 Subject: [PATCH 306/483] Automated Code Change PiperOrigin-RevId: 679058248 --- .../quantization/common/quantization_lib/quantization_driver.cc | 1 - .../common/quantization_lib/quantization_driver_test.cc | 1 - 2 files changed, 2 deletions(-) diff --git a/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_driver.cc b/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_driver.cc index 7645177160fc62..4ab942e6864a6d 100644 --- a/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_driver.cc +++ b/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_driver.cc @@ -24,7 +24,6 @@ limitations under the License. #include #include "llvm/ADT/DenseMap.h" -#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_driver_test.cc b/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_driver_test.cc index f017054cbe7044..60f5c05530820b 100644 --- a/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_driver_test.cc +++ b/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_driver_test.cc @@ -27,7 +27,6 @@ limitations under the License. #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project -#include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project From 1cb9ed26bcb2ccee706f27a13e8ddc0e7e8927eb Mon Sep 17 00:00:00 2001 From: Henning Becker Date: Thu, 26 Sep 2024 02:39:38 -0700 Subject: [PATCH 307/483] Remove enable_xlir build flag It's not in use anymore. PiperOrigin-RevId: 679062623 --- third_party/xla/xla/service/gpu/BUILD | 7 ------- 1 file changed, 7 deletions(-) diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index edc11b7c230bf2..a33e71fc84dac7 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -1,7 +1,6 @@ # Description: # GPU-specific components in XLA service implementation. -load("@bazel_skylib//rules:common_settings.bzl", "bool_flag") load("@local_config_cuda//cuda:build_defs.bzl", "cuda_library") load( "@local_config_rocm//rocm:build_defs.bzl", @@ -539,12 +538,6 @@ cc_library( ], ) -# TODO(b/244780257): Remove this config. -bool_flag( - name = "enable_xlir", - build_setting_default = if_google(True, False), -) - cc_library( name = "gpu_executable", srcs = [ From 49adb615f69ebc8df98959108d5293ca01dbdff2 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 26 Sep 2024 03:39:48 -0700 Subject: [PATCH 308/483] Automated Code Change PiperOrigin-RevId: 679079122 --- .../tools/evaluation/tasks/imagenet_image_classification/BUILD | 1 + .../evaluation/tasks/imagenet_image_classification/run_eval.cc | 1 + 2 files changed, 2 insertions(+) diff --git a/tensorflow/lite/tools/evaluation/tasks/imagenet_image_classification/BUILD b/tensorflow/lite/tools/evaluation/tasks/imagenet_image_classification/BUILD index 3a0b2fb9eb0633..82b1567212e997 100644 --- a/tensorflow/lite/tools/evaluation/tasks/imagenet_image_classification/BUILD +++ b/tensorflow/lite/tools/evaluation/tasks/imagenet_image_classification/BUILD @@ -18,6 +18,7 @@ cc_library( "//tensorflow/lite/core/c:common", "//tensorflow/lite/tools:command_line_flags", "//tensorflow/lite/tools:logging", + "//tensorflow/lite/tools/evaluation:evaluation_delegate_provider", "//tensorflow/lite/tools/evaluation:evaluation_stage", "//tensorflow/lite/tools/evaluation:utils", "//tensorflow/lite/tools/evaluation/proto:evaluation_config_cc_proto", diff --git a/tensorflow/lite/tools/evaluation/tasks/imagenet_image_classification/run_eval.cc b/tensorflow/lite/tools/evaluation/tasks/imagenet_image_classification/run_eval.cc index c2a082447866aa..1dbb26a0176d91 100644 --- a/tensorflow/lite/tools/evaluation/tasks/imagenet_image_classification/run_eval.cc +++ b/tensorflow/lite/tools/evaluation/tasks/imagenet_image_classification/run_eval.cc @@ -21,6 +21,7 @@ limitations under the License. #include "absl/types/optional.h" #include "tensorflow/lite/core/c/common.h" #include "tensorflow/lite/tools/command_line_flags.h" +#include "tensorflow/lite/tools/evaluation/evaluation_delegate_provider.h" #include "tensorflow/lite/tools/evaluation/proto/evaluation_config.pb.h" #include "tensorflow/lite/tools/evaluation/proto/evaluation_stages.pb.h" #include "tensorflow/lite/tools/evaluation/stages/image_classification_stage.h" From 9e07b78bca3e8e4a6908c1e54955f3b2f1cfa348 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 26 Sep 2024 04:26:51 -0700 Subject: [PATCH 309/483] Update GraphDef version to 1997. PiperOrigin-RevId: 679092055 --- tensorflow/core/public/version.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index b7d03cb07c7474..ce31c7a2e1df87 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -108,7 +108,7 @@ limitations under the License. #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0 -#define TF_GRAPH_DEF_VERSION 1996 // Updated: 2024/9/25 +#define TF_GRAPH_DEF_VERSION 1997 // Updated: 2024/9/26 // Checkpoint compatibility versions (the versions field in SavedSliceMeta). // From 07ba96e56a9d7ba29bd831f8d38378ba66e7dd63 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 26 Sep 2024 04:26:52 -0700 Subject: [PATCH 310/483] compat: Update forward compatibility horizon to 2024-09-26 PiperOrigin-RevId: 679092057 --- tensorflow/python/compat/compat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index 39c9f37847b730..f8107f5d868557 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -29,7 +29,7 @@ # This value changes every day with an automatic CL. It can be modified in code # via `forward_compatibility_horizon()` or with the environment variable # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date. -_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2024, 9, 25) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2024, 9, 26) _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS" _FORWARD_COMPATIBILITY_DATE_NUMBER = None From 9a8d5f81406eb25f3448f8575670459f5982815a Mon Sep 17 00:00:00 2001 From: Ilia Sergachev Date: Thu, 26 Sep 2024 04:44:09 -0700 Subject: [PATCH 311/483] PR #17625: [GPU] Optimize zero-clamping of index operands known to be non-negative. Imported from GitHub PR https://github.com/openxla/xla/pull/17625 This is a minor optimization and fixes LaxBackedScipyTests.testSphHarmAccuracy from JAX lax_scipy_test on H100 with CUDA 12.5.0 which has incorrect handling of max(min(abs(x), y), z) in ptxas. Copybara import of the project: -- 02fb5c50395d970ee33a9ecd8b0cb9df86037a17 by Ilia Sergachev : [GPU] Optimize zero-clamping of index operands known to be non-negative. This is a minor optimization and fixes LaxBackedScipyTests.testSphHarmAccuracy from JAX lax_scipy_test on H100 with CUDA 12.5.0 which has incorrect handling of max(min(abs(x), y), z) in ptxas. -- 68925a17eb7f0c9fb265c8b5891b42e035f597a7 by Ilia Sergachev : Address review comments. Support arbitrary integer data type width; propagate range of the input if known. Merging this change closes #17625 PiperOrigin-RevId: 679096544 --- .../gpu/fusions/transforms/simplify_arith.cc | 34 +++++++++++ .../transforms/tests/simplify_arith.mlir | 57 +++++++++++++++++++ third_party/xla/xla/service/gpu/tests/BUILD | 1 + .../gpu/tests/zero_clamp_abs_index.hlo | 13 +++++ 4 files changed, 105 insertions(+) create mode 100644 third_party/xla/xla/service/gpu/tests/zero_clamp_abs_index.hlo diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/simplify_arith.cc b/third_party/xla/xla/service/gpu/fusions/transforms/simplify_arith.cc index f3d67e24ee3248..72b1c0a22628e4 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/simplify_arith.cc +++ b/third_party/xla/xla/service/gpu/fusions/transforms/simplify_arith.cc @@ -24,6 +24,7 @@ limitations under the License. #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/IR/Builders.h" +#include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/Value.h" #include "mlir/Pass/Pass.h" @@ -228,6 +229,35 @@ struct RewriteTruncExtShuffle : public OpRewritePattern { } }; +static std::optional GetSelectRange(mlir::Operation* sel) { + // Match |x| implemented as (x >= 0) ? x : (0 - x). + mlir::Value x = sel->getOperand(1); + auto m_x = mlir::matchers::m_Val(x); + if (!x.getType().isSignlessIntOrIndex() || + !mlir::matchPattern( + sel, mlir::m_Op( + mlir::m_Op(m_x, mlir::m_Zero()), m_x, + mlir::m_Op(mlir::m_Zero(), m_x)))) { + return std::nullopt; + } + if (sel->getOperand(0).getDefiningOp().getPredicate() != + CmpIPredicate::sge) { + return std::nullopt; + } + // Annotate |x| as >= 0. + Interval result{0, + static_cast( + (1ull << (x.getType().getIntOrFloatBitWidth() - 1)) - 1)}; + std::optional x_range = GetRange(x); + if (x_range.has_value()) { + Interval positive_range = x_range->max({0, 0}); + Interval negative_range = -x_range->min({0, 0}); + Interval abs_range = positive_range.Union(negative_range); + return result.Intersect(abs_range); + } + return result; +} + void AnnotateRanges(mlir::func::FuncOp func) { func->walk([](mlir::Operation* op) { if (op->getNumResults() != 1) { @@ -262,6 +292,10 @@ void AnnotateRanges(mlir::func::FuncOp func) { } else { out_range = lhs_range * rhs_range; } + } else if (mlir::isa(op)) { + out_range = GetRange(op->getOperand(0)); + } else if (mlir::isa(op)) { + out_range = GetSelectRange(op); } if (out_range) { diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/tests/simplify_arith.mlir b/third_party/xla/xla/service/gpu/fusions/transforms/tests/simplify_arith.mlir index 9524c3d32cc6c2..b301a3bbc93a74 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/tests/simplify_arith.mlir +++ b/third_party/xla/xla/service/gpu/fusions/transforms/tests/simplify_arith.mlir @@ -348,3 +348,60 @@ func.func @dus(%arg0: tensor<20x30xf32>, %arg1: tensor<5x6xf32>, %arg2: i32, %ar // CHECK-SAME: xla.range = [0 : index, 19 : index] // CHECK: arith.addi // CHECK-SAME: xla.range = [0 : index, 29 : index] + +// ----- + +module { + func.func @annotate_range_abs_index(%v: i32) -> index { + %c0_i32 = arith.constant 0 : i32 + %0 = arith.cmpi sge, %v, %c0_i32 : i32 + %1 = arith.subi %c0_i32, %v : i32 + %2 = arith.select %0, %v, %1 : i32 + %3 = arith.index_cast %2 : i32 to index + return %3: index + } +} + +// CHECK-LABEL: @annotate_range_abs +// CHECK: arith.select +// CHECK-SAME: xla.range = [0 : index, 2147483647 : index] +// CHECK-NEXT: arith.index_cast +// CHECK-SAME: xla.range = [0 : index, 2147483647 : index] + +// ----- + +module { + func.func @annotate_range_abs_index(%v: i32 {xla.range = [-31 : i32, 17 : i32]}) -> index { + %c0_i32 = arith.constant 0 : i32 + %0 = arith.cmpi sge, %v, %c0_i32 : i32 + %1 = arith.subi %c0_i32, %v : i32 + %2 = arith.select %0, %v, %1 : i32 + %3 = arith.index_cast %2 : i32 to index + return %3: index + } +} + +// CHECK-LABEL: @annotate_range_abs +// CHECK: arith.select +// CHECK-SAME: xla.range = [0 : index, 31 : index] +// CHECK-NEXT: arith.index_cast +// CHECK-SAME: xla.range = [0 : index, 31 : index] + +// ----- + +module { + func.func @annotate_range_abs_index(%v: i32 {xla.range = [-5 : i32, 3 : i32]}) -> index { + %c0_i32 = arith.constant 0 : i32 + %0 = arith.cmpi sge, %v, %c0_i32 : i32 + %1 = arith.subi %c0_i32, %v : i32 + %2 = arith.select %0, %v, %1 : i32 + %3 = arith.index_cast %2 : i32 to index + return %3: index + } +} + +// CHECK-LABEL: @annotate_range_abs +// CHECK: arith.select +// CHECK-SAME: xla.range = [0 : index, 5 : index] +// CHECK-NEXT: arith.index_cast +// CHECK-SAME: xla.range = [0 : index, 5 : index] diff --git a/third_party/xla/xla/service/gpu/tests/BUILD b/third_party/xla/xla/service/gpu/tests/BUILD index c8632c377da2d7..76946321d410f2 100644 --- a/third_party/xla/xla/service/gpu/tests/BUILD +++ b/third_party/xla/xla/service/gpu/tests/BUILD @@ -650,6 +650,7 @@ lit_test_suite( "transpose_210.hlo", "transpose_210_extra_output.hlo", "triton_naming.hlo", + "zero_clamp_abs_index.hlo", ], include = [ "*.hlo", diff --git a/third_party/xla/xla/service/gpu/tests/zero_clamp_abs_index.hlo b/third_party/xla/xla/service/gpu/tests/zero_clamp_abs_index.hlo new file mode 100644 index 00000000000000..59f448644172d4 --- /dev/null +++ b/third_party/xla/xla/service/gpu/tests/zero_clamp_abs_index.hlo @@ -0,0 +1,13 @@ +// RUN: hlo-opt %s --platform=gpu --stage=llvm-before-optimizations --xla_gpu_target_config_filename=%S/../../../tools/hlo_opt/gpu_specs/%{GPU}.txtpb | FileCheck %s + +e { + p0 = s32[8,9] parameter(0) + p1 = s32[5] parameter(1) + a = s32[5] abs(p1) + ROOT r = s32[5,2,3] gather(p0, a), + offset_dims={1,2}, collapsed_slice_dims={}, start_index_map={0}, + index_vector_dim=1, slice_sizes={2,3} +} + +// CHECK: llvm.smin.i32 +// CHECK-NOT: llvm.smax.i32 From 2ff6cdb83899a6f2c406e1b78aeac97ba383c7a2 Mon Sep 17 00:00:00 2001 From: Henning Becker Date: Thu, 26 Sep 2024 05:28:47 -0700 Subject: [PATCH 312/483] Reverts 1cb9ed26bcb2ccee706f27a13e8ddc0e7e8927eb PiperOrigin-RevId: 679108603 --- third_party/xla/xla/service/gpu/BUILD | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index a33e71fc84dac7..edc11b7c230bf2 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -1,6 +1,7 @@ # Description: # GPU-specific components in XLA service implementation. +load("@bazel_skylib//rules:common_settings.bzl", "bool_flag") load("@local_config_cuda//cuda:build_defs.bzl", "cuda_library") load( "@local_config_rocm//rocm:build_defs.bzl", @@ -538,6 +539,12 @@ cc_library( ], ) +# TODO(b/244780257): Remove this config. +bool_flag( + name = "enable_xlir", + build_setting_default = if_google(True, False), +) + cc_library( name = "gpu_executable", srcs = [ From 8a7c8178e7f640de4fc93ff6b05f08953bfbb5d3 Mon Sep 17 00:00:00 2001 From: Christian Sigg Date: Thu, 26 Sep 2024 05:29:37 -0700 Subject: [PATCH 313/483] Don't bail out analyzing dots. The logic is correct and it doesn't hurt existing code. This is in preparation for adding a pass to nest gemm fusions. PiperOrigin-RevId: 679108860 --- .../gpu/model/symbolic_tile_analysis.cc | 3 +- .../gpu/model/symbolic_tile_analysis_test.cc | 57 ++++++++++++++++--- 2 files changed, 49 insertions(+), 11 deletions(-) diff --git a/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis.cc b/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis.cc index a08b8aaa3ea908..a3f2e77c97656c 100644 --- a/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis.cc +++ b/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis.cc @@ -228,8 +228,7 @@ FusionDecision ShouldProceedWithSymbolicTileDerivation( // Bail out on instructions that are known to cause problems down the // line. This is not an inherent limitation of the approach, but simply // issues to be resolved in the current implementation. - if (hlo->opcode() == HloOpcode::kDot || - hlo->opcode() == HloOpcode::kConcatenate) { + if (hlo->opcode() == HloOpcode::kConcatenate) { return FusionDecision::Forbid("Bailing out on ") << hlo->ToString(); } diff --git a/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis_test.cc b/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis_test.cc index f4166076234a7e..db62388f89f099 100644 --- a/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis_test.cc +++ b/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis_test.cc @@ -399,23 +399,62 @@ ENTRY main { )")); } -TEST_F(SymbolicTileAnalysisTest, BailOutOnUnsupportedDot) { +TEST_F(SymbolicTileAnalysisTest, DotOffsetIndexingIsCorrect) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(R"( fusion { - p0 = f32[1,2]{1,0} parameter(0) - p1 = f32[2,3]{1,0} parameter(1) - ROOT dot = f32[1,3]{1,0} dot(p0, p1), - lhs_batch_dims={}, rhs_batch_dims={}, + p0 = f32[4,8] parameter(0) + p1 = f32[8,16] parameter(1) + ROOT dot = f32[4,16] dot(p0, p1), lhs_contracting_dims={1}, rhs_contracting_dims={0} } ENTRY main { - p0 = f32[1,2]{1,0} parameter(0) - p1 = f32[2,3]{1,0} parameter(1) - ROOT fusion = f32[1,3]{1,0} fusion(p0, p1), kind=kLoop, calls=fusion + p0 = f32[4,8] parameter(0) + p1 = f32[8,16] parameter(1) + ROOT fusion = f32[4,16] fusion(p0, p1), kind=kLoop, calls=fusion })")); - EXPECT_FALSE(TryAnalyzeModule(module.get()).has_value()); + std::optional analysis = TryAnalyzeModule(module.get()); + ASSERT_TRUE(analysis.has_value()); + + TF_ASSERT_OK_AND_ASSIGN(TiledHloComputation tiled_hlo_computation, + analysis->ComputeTiledHloInstructions( + /*tile_parameters=*/{2, 2}, + /*constraints_are_known_satisfied=*/false, + /*compute_all_tile_offset_indexing_maps=*/true)); + + const TiledHloInstruction* dot = tiled_hlo_computation.GetRoot(); + EXPECT_THAT(*dot, MatchTiledHloInstruction( + /*tile_sizes=*/{2, 2}, /*tile_strides=*/{1, 1}, + /*tile_offsets_indexing=*/R"( + (d0, d1) -> (d0 * 2, d1 * 2), + domain: + d0 in [0, 1], + d1 in [0, 7], + is_simplified: true + )")); + + const TiledHloInstruction* lhs = dot->operand(0); + EXPECT_THAT(*lhs, MatchTiledHloInstruction( + /*tile_sizes=*/{2, 8}, /*tile_strides=*/{1, 1}, + /*tile_offsets_indexing=*/R"( + (d0, d1) -> (d0 * 2, 0), + domain: + d0 in [0, 1], + d1 in [0, 7], + is_simplified: true + )")); + + const TiledHloInstruction* rhs = dot->operand(1); + EXPECT_THAT(*rhs, MatchTiledHloInstruction( + /*tile_sizes=*/{8, 2}, /*tile_strides=*/{1, 1}, + /*tile_offsets_indexing=*/R"( + (d0, d1) -> (0, d1 * 2), + domain: + d0 in [0, 1], + d1 in [0, 7], + is_simplified: true + )")); } TEST_F(SymbolicTileAnalysisTest, DoesNotBailOutOnConstrainedReshape) { From 351e4745981fe68cf5509481d7d4a4537d5c9ab4 Mon Sep 17 00:00:00 2001 From: Alexander Belyaev Date: Thu, 26 Sep 2024 06:05:48 -0700 Subject: [PATCH 314/483] [XLA:GPU] Remove AffineMapPrinter. This is a preparation step for printing variable names depending on their VariableType. PiperOrigin-RevId: 679118294 --- .../xla/xla/service/gpu/fusions/ir/BUILD | 1 - .../service/gpu/fusions/ir/xla_gpu_attrs.cc | 4 +- .../xla/service/gpu/fusions/ir/xla_gpu_ops.cc | 5 +- .../xla/xla/service/gpu/fusions/legacy/BUILD | 11 +- .../gpu/fusions/legacy/concatenate_test.cc | 37 +-- .../in_place_dynamic_update_slice_test.cc | 15 +- .../gpu/fusions/legacy/input_slices_test.cc | 15 +- .../service/gpu/fusions/legacy/loop_test.cc | 32 +- .../gpu/fusions/legacy/reduction_test.cc | 10 +- .../gpu/fusions/legacy/scatter_test.cc | 57 ++-- .../gpu/fusions/legacy/transpose_test.cc | 21 +- .../gpu/fusions/mlir/elemental_hlo_to_mlir.cc | 6 +- .../gpu/fusions/tools/test_correctness.cc | 3 +- .../gpu/fusions/transforms/peel_loops.cc | 7 +- .../xla/xla/service/gpu/fusions/triton/BUILD | 1 - third_party/xla/xla/service/gpu/model/BUILD | 51 +-- .../service/gpu/model/affine_map_printer.cc | 269 ---------------- .../service/gpu/model/affine_map_printer.h | 67 ---- .../gpu/model/affine_map_printer_test.cc | 59 ---- .../model/gpu_indexing_performance_model.cc | 2 +- .../service/gpu/model/indexing_analysis.cc | 23 +- .../xla/service/gpu/model/indexing_analysis.h | 5 +- .../gpu/model/indexing_analysis_test.cc | 15 +- .../xla/xla/service/gpu/model/indexing_map.cc | 87 +---- .../xla/xla/service/gpu/model/indexing_map.h | 31 +- .../gpu/model/indexing_map_serialization.cc | 304 +++++++++++++++++- .../gpu/model/indexing_map_serialization.h | 37 +++ .../model/indexing_map_serialization_test.cc | 21 +- .../service/gpu/model/indexing_map_test.cc | 152 +++++---- .../service/gpu/model/indexing_test_utils.h | 3 +- .../xla/service/gpu/model/symbolic_tile.cc | 83 ++--- .../xla/xla/service/gpu/model/symbolic_tile.h | 13 +- .../gpu/model/symbolic_tile_analysis.cc | 13 +- .../gpu/model/symbolic_tile_analysis.h | 4 +- .../model/symbolic_tiled_hlo_instruction.cc | 2 +- .../gpu/model/tiled_hlo_instruction.cc | 8 +- 36 files changed, 608 insertions(+), 866 deletions(-) delete mode 100644 third_party/xla/xla/service/gpu/model/affine_map_printer.cc delete mode 100644 third_party/xla/xla/service/gpu/model/affine_map_printer.h delete mode 100644 third_party/xla/xla/service/gpu/model/affine_map_printer_test.cc diff --git a/third_party/xla/xla/service/gpu/fusions/ir/BUILD b/third_party/xla/xla/service/gpu/fusions/ir/BUILD index 1d60d912a8345f..d7e85cc8c89cc0 100644 --- a/third_party/xla/xla/service/gpu/fusions/ir/BUILD +++ b/third_party/xla/xla/service/gpu/fusions/ir/BUILD @@ -135,7 +135,6 @@ cc_library( ":xla_gpu_ops_inc_gen", ":xla_gpu_types_inc_gen", "//xla/service/gpu/model:indexing_analysis", - "//xla/service/gpu/model:indexing_map_serialization", "@llvm-project//llvm:Support", "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:BytecodeOpInterface", diff --git a/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_attrs.cc b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_attrs.cc index 577ec1262970c6..5a2cd2b9a29584 100644 --- a/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_attrs.cc +++ b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_attrs.cc @@ -73,7 +73,7 @@ mlir::Attribute IndexingMapAttr::parse(mlir::AsmParser& parser, mlir::Type) { } void IndexingMapAttr::print(mlir::AsmPrinter& printer) const { - printer << "<\"" << getIndexingMap().ToString() << "\">"; + printer << "<\"" << ToString(getIndexingMap()) << "\">"; } IndexingMapAttr IndexingMapAttr::get(mlir::MLIRContext* context, @@ -135,7 +135,7 @@ mlir::Attribute LayoutAttr::parse(mlir::AsmParser& parser, mlir::Type) { void LayoutAttr::print(mlir::AsmPrinter& printer) const { printer << "<\"" << stringifyMemorySpace(getMemorySpace().getValue()) - << "\", \"" << getThreadMap().getIndexingMap().ToString() << "\">"; + << "\", \"" << ToString(getThreadMap().getIndexingMap()) << "\">"; } } // namespace gpu diff --git a/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops.cc b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops.cc index 38c11e1ff8023a..a4724eb8b5c9f6 100644 --- a/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops.cc +++ b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops.cc @@ -47,6 +47,7 @@ limitations under the License. #include "mlir/Support/LogicalResult.h" #include "xla/service/gpu/fusions/ir/xla_gpu_dialect.cc.inc" #include "xla/service/gpu/model/indexing_map.h" +#include "xla/service/gpu/model/indexing_map_serialization.h" namespace xla { namespace gpu { @@ -834,13 +835,13 @@ LogicalResult LoopOp::verify() { return emitOpError() << "mismatch in number of induction variables " << getNumInductionVars() << " and RangeVars in the indexing map " - << indexing_map.ToString(); + << ToString(indexing_map); } if (indexing_map.GetDimVarsCount() != getDims().size()) { return emitOpError() << "mismatch in number of dims operands " << getDims().size() << " and DimVars in the indexing map " - << indexing_map.ToString(); + << ToString(indexing_map); } for (auto [bb_arg, result_type, init] : llvm::zip(getRegionIterArgs(), getResultTypes(), getInits())) { diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/BUILD b/third_party/xla/xla/service/gpu/fusions/legacy/BUILD index 98d8ade7c5e5c3..8b9f3a34073441 100644 --- a/third_party/xla/xla/service/gpu/fusions/legacy/BUILD +++ b/third_party/xla/xla/service/gpu/fusions/legacy/BUILD @@ -38,7 +38,7 @@ xla_cc_test( "//xla/service/gpu:gpu_device_info_for_tests", "//xla/service/gpu:hlo_fusion_analysis", "//xla/service/gpu/fusions", - "//xla/service/gpu/model:affine_map_printer", + "//xla/service/gpu/model:indexing_analysis", "//xla/service/gpu/model:indexing_test_utils", "//xla/stream_executor:device_description", "//xla/tests:hlo_test_base", @@ -89,7 +89,7 @@ xla_cc_test( "//xla/service/gpu:hlo_fusion_analysis", "//xla/service/gpu/fusions", "//xla/service/gpu/fusions:fusion_emitter", - "//xla/service/gpu/model:affine_map_printer", + "//xla/service/gpu/model:indexing_analysis", "//xla/service/gpu/model:indexing_test_utils", "//xla/stream_executor:device_description", "//xla/tests:hlo_test_base", @@ -140,7 +140,7 @@ xla_cc_test( "//xla/service/gpu:gpu_device_info_for_tests", "//xla/service/gpu:hlo_fusion_analysis", "//xla/service/gpu/fusions", - "//xla/service/gpu/model:affine_map_printer", + "//xla/service/gpu/model:indexing_analysis", "//xla/service/gpu/model:indexing_test_utils", "//xla/stream_executor:device_description", "//xla/tests:hlo_test_base", @@ -290,7 +290,7 @@ xla_cc_test( "//xla/service/gpu:gpu_device_info_for_tests", "//xla/service/gpu:hlo_fusion_analysis", "//xla/service/gpu/fusions", - "//xla/service/gpu/model:affine_map_printer", + "//xla/service/gpu/model:indexing_analysis", "//xla/service/gpu/model:indexing_test_utils", "//xla/stream_executor:device_description", "//xla/tests:hlo_test_base", @@ -344,6 +344,7 @@ xla_cc_test( "//xla/service/gpu:gpu_device_info_for_tests", "//xla/service/gpu:hlo_fusion_analysis", "//xla/service/gpu/fusions", + "//xla/service/gpu/model:indexing_analysis", "//xla/service/gpu/model:indexing_test_utils", "//xla/stream_executor:device_description", "//xla/tests:hlo_test_base", @@ -395,7 +396,7 @@ xla_cc_test( "//xla/service/gpu:gpu_device_info_for_tests", "//xla/service/gpu:hlo_fusion_analysis", "//xla/service/gpu/fusions", - "//xla/service/gpu/model:affine_map_printer", + "//xla/service/gpu/model:indexing_analysis", "//xla/service/gpu/model:indexing_test_utils", "//xla/stream_executor:device_description", "//xla/tests:hlo_test_base", diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/concatenate_test.cc b/third_party/xla/xla/service/gpu/fusions/legacy/concatenate_test.cc index 9a9bdc2dd488b2..32437d5bca3772 100644 --- a/third_party/xla/xla/service/gpu/fusions/legacy/concatenate_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/legacy/concatenate_test.cc @@ -14,15 +14,13 @@ limitations under the License. ==============================================================================*/ #include "xla/service/gpu/fusions/legacy/concatenate.h" -#include - #include #include #include "mlir/IR/MLIRContext.h" #include "xla/service/gpu/fusions/fusions.h" #include "xla/service/gpu/gpu_device_info_for_tests.h" #include "xla/service/gpu/hlo_fusion_analysis.h" -#include "xla/service/gpu/model/affine_map_printer.h" +#include "xla/service/gpu/model/indexing_map_serialization.h" #include "xla/service/gpu/model/indexing_test_utils.h" #include "xla/stream_executor/device_description.h" #include "xla/tests/hlo_test_base.h" @@ -32,21 +30,12 @@ namespace gpu { namespace { class ConcatenateTest : public HloTestBase { - public: - void SetUp() override { - HloTestBase::SetUp(); - printer_ = - AffineMapPrinter({"th_x", "th_y", "th_z", "bl_x", "bl_y", "bl_z"}, - {"chunk_id", "unroll_id"}); - } - protected: DebugOptions GetDebugOptionsForTest() override { auto opts = HloTestBase::GetDebugOptionsForTest(); opts.set_xla_gpu_mlir_emitter_level(0); return opts; } - AffineMapPrinter printer_; mlir::MLIRContext mlir_context_; }; @@ -97,22 +86,22 @@ TEST_F(ConcatenateTest, ThreadIndexing) { is_simplified: true )"; EXPECT_THAT( - fusion - ->ComputeThreadIdToInputIndexing( - /*root_index=*/0, /*hero_operand_index=*/0, &mlir_context_) - ->ToString(printer_), + ToString(*fusion->ComputeThreadIdToInputIndexing( + /*root_index=*/0, /*hero_operand_index=*/0, &mlir_context_), + {"th_x", "th_y", "th_z", "bl_x", "bl_y", "bl_z"}, + {"chunk_id", "unroll_id"}), MatchIndexingString(kIndexing)); EXPECT_THAT( - fusion - ->ComputeThreadIdToInputIndexing( - /*root_index=*/0, /*hero_operand_index=*/1, &mlir_context_) - ->ToString(printer_), + ToString(*fusion->ComputeThreadIdToInputIndexing( + /*root_index=*/0, /*hero_operand_index=*/1, &mlir_context_), + {"th_x", "th_y", "th_z", "bl_x", "bl_y", "bl_z"}, + {"chunk_id", "unroll_id"}), MatchIndexingString(kIndexing)); EXPECT_THAT( - fusion - ->ComputeThreadIdToInputIndexing( - /*root_index=*/0, /*hero_operand_index=*/2, &mlir_context_) - ->ToString(printer_), + ToString(*fusion->ComputeThreadIdToInputIndexing( + /*root_index=*/0, /*hero_operand_index=*/2, &mlir_context_), + {"th_x", "th_y", "th_z", "bl_x", "bl_y", "bl_z"}, + {"chunk_id", "unroll_id"}), MatchIndexingString(kIndexing)); } diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/in_place_dynamic_update_slice_test.cc b/third_party/xla/xla/service/gpu/fusions/legacy/in_place_dynamic_update_slice_test.cc index 53be6363567cdd..6bf9ea865e1c45 100644 --- a/third_party/xla/xla/service/gpu/fusions/legacy/in_place_dynamic_update_slice_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/legacy/in_place_dynamic_update_slice_test.cc @@ -22,7 +22,7 @@ limitations under the License. #include "xla/service/gpu/fusions/fusions.h" #include "xla/service/gpu/gpu_device_info_for_tests.h" #include "xla/service/gpu/hlo_fusion_analysis.h" -#include "xla/service/gpu/model/affine_map_printer.h" +#include "xla/service/gpu/model/indexing_map_serialization.h" #include "xla/service/gpu/model/indexing_test_utils.h" #include "xla/stream_executor/device_description.h" #include "xla/tests/hlo_test_base.h" @@ -33,21 +33,12 @@ namespace gpu { namespace { class InPlaceDynamicUpdateSliceFusionTest : public HloTestBase { - public: - void SetUp() override { - HloTestBase::SetUp(); - printer_ = - AffineMapPrinter({"th_x", "th_y", "th_z", "bl_x", "bl_y", "bl_z"}, - {"chunk_id", "unroll_id"}); - } - protected: DebugOptions GetDebugOptionsForTest() override { auto opts = HloTestBase::GetDebugOptionsForTest(); opts.set_xla_gpu_mlir_emitter_level(0); return opts; } - AffineMapPrinter printer_; mlir::MLIRContext mlir_context_; stream_executor::DeviceDescription device_info_ = TestGpuDeviceInfo::RTXA6000DeviceInfo(); @@ -83,7 +74,9 @@ TEST_F(InPlaceDynamicUpdateSliceFusionTest, ThreadIndexing) { auto thread_id_update_indexing = fusion->ComputeThreadIdToInputIndexing( /*root_index=*/0, /*hero_operand_index=*/1, &mlir_context_); - EXPECT_THAT(thread_id_update_indexing->ToString(printer_), + EXPECT_THAT(ToString(*thread_id_update_indexing, + {"th_x", "th_y", "th_z", "bl_x", "bl_y", "bl_z"}, + {"chunk_id", "unroll_id"}), MatchIndexingString(R"( (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> ( th_x floordiv 6, th_x mod 6), diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/input_slices_test.cc b/third_party/xla/xla/service/gpu/fusions/legacy/input_slices_test.cc index 0c604502bd51d1..08fcc0d387c777 100644 --- a/third_party/xla/xla/service/gpu/fusions/legacy/input_slices_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/legacy/input_slices_test.cc @@ -22,7 +22,7 @@ limitations under the License. #include "xla/service/gpu/fusions/fusions.h" #include "xla/service/gpu/gpu_device_info_for_tests.h" #include "xla/service/gpu/hlo_fusion_analysis.h" -#include "xla/service/gpu/model/affine_map_printer.h" +#include "xla/service/gpu/model/indexing_map_serialization.h" #include "xla/service/gpu/model/indexing_test_utils.h" #include "xla/stream_executor/device_description.h" #include "xla/tests/hlo_test_base.h" @@ -32,21 +32,12 @@ namespace gpu { namespace { class InputSlicesTest : public HloTestBase { - public: - void SetUp() override { - HloTestBase::SetUp(); - printer_ = - AffineMapPrinter({"th_x", "th_y", "th_z", "bl_x", "bl_y", "bl_z"}, - {"chunk_id", "unroll_id"}); - } - protected: DebugOptions GetDebugOptionsForTest() override { auto opts = HloTestBase::GetDebugOptionsForTest(); opts.set_xla_gpu_mlir_emitter_level(0); return opts; } - AffineMapPrinter printer_; mlir::MLIRContext mlir_context_; }; @@ -80,7 +71,9 @@ TEST_F(InputSlicesTest, ThreadIndexing) { auto thread_id_to_output_indexing = fusion->ComputeThreadIdToOutputIndexing(0, &mlir_context_); - EXPECT_THAT(thread_id_to_output_indexing->ToString(printer_), + EXPECT_THAT(ToString(*thread_id_to_output_indexing, + {"th_x", "th_y", "th_z", "bl_x", "bl_y", "bl_z"}, + {"chunk_id", "unroll_id"}), MatchIndexingString(R"( (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> (0, ((bl_x * 128 + th_x) floordiv 3) mod 2, diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/loop_test.cc b/third_party/xla/xla/service/gpu/fusions/legacy/loop_test.cc index 82a9de34c7cc49..60ae18e5cc6a17 100644 --- a/third_party/xla/xla/service/gpu/fusions/legacy/loop_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/legacy/loop_test.cc @@ -24,7 +24,7 @@ limitations under the License. #include "xla/service/gpu/fusions/fusions.h" #include "xla/service/gpu/gpu_device_info_for_tests.h" #include "xla/service/gpu/hlo_fusion_analysis.h" -#include "xla/service/gpu/model/affine_map_printer.h" +#include "xla/service/gpu/model/indexing_map_serialization.h" #include "xla/service/gpu/model/indexing_test_utils.h" #include "xla/status_macros.h" #include "xla/stream_executor/device_description.h" @@ -36,19 +36,9 @@ namespace gpu { namespace { class LoopTest : public HloTestBase { - public: - void SetUp() override { - HloTestBase::SetUp(); - - printer_ = - AffineMapPrinter({"th_x", "th_y", "th_z", "bl_x", "bl_y", "bl_z"}, - {"chunk_id", "unroll_id"}); - } - protected: stream_executor::DeviceDescription device_info_ = TestGpuDeviceInfo::RTXA6000DeviceInfo(); - AffineMapPrinter printer_; mlir::MLIRContext mlir_context_; }; @@ -85,7 +75,9 @@ TEST_F(LoopTest, ThreadIndexingUnrolled) { loop_fusion->ComputeThreadIdToOutputIndexing(/*root_index=*/0, &mlir_context_); - EXPECT_THAT(thread_id_to_output_indexing->ToString(printer_), + EXPECT_THAT(ToString(*thread_id_to_output_indexing, + {"th_x", "th_y", "th_z", "bl_x", "bl_y", "bl_z"}, + {"chunk_id", "unroll_id"}), MatchIndexingString(R"( (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> ( (bl_x * 128 + th_x) floordiv 15000, @@ -128,7 +120,9 @@ TEST_F(LoopTest, ThreadIndexingNotUnrolled) { auto thread_id_to_output_indexing = loop_fusion->ComputeThreadIdToOutputIndexing(/*root_index=*/0, &mlir_context_); - EXPECT_THAT(thread_id_to_output_indexing->ToString(printer_), + EXPECT_THAT(ToString(*thread_id_to_output_indexing, + {"th_x", "th_y", "th_z", "bl_x", "bl_y", "bl_z"}, + {"chunk_id", "unroll_id"}), MatchIndexingString(R"( (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> (th_x), domain: @@ -145,7 +139,9 @@ TEST_F(LoopTest, ThreadIndexingNotUnrolled) { auto thread_id_to_input_indexing = loop_fusion->ComputeThreadIdToInputIndexing( /*root_index=*/0, /*hero_operand_index=*/0, &mlir_context_); - EXPECT_THAT(thread_id_to_input_indexing->ToString(printer_), + EXPECT_THAT(ToString(*thread_id_to_input_indexing, + {"th_x", "th_y", "th_z", "bl_x", "bl_y", "bl_z"}, + {"chunk_id", "unroll_id"}), MatchIndexingString(R"( (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> (th_x), domain: @@ -183,7 +179,9 @@ TEST_F(LoopTest, Broadcast) { auto thread_id_to_output_indexing = loop_fusion->ComputeThreadIdToOutputIndexing(/*root_index=*/0, &mlir_context_); - EXPECT_THAT(thread_id_to_output_indexing->ToString(printer_), + EXPECT_THAT(ToString(*thread_id_to_output_indexing, + {"th_x", "th_y", "th_z", "bl_x", "bl_y", "bl_z"}, + {"chunk_id", "unroll_id"}), MatchIndexingString(R"( (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> ( (bl_x * 128 + th_x) floordiv 600, @@ -204,7 +202,9 @@ TEST_F(LoopTest, Broadcast) { auto thread_id_to_input_indexing = loop_fusion->ComputeThreadIdToInputIndexing( /*root_index=*/0, /*hero_operand_index=*/0, &mlir_context_); - EXPECT_THAT(thread_id_to_input_indexing->ToString(printer_), + EXPECT_THAT(ToString(*thread_id_to_input_indexing, + {"th_x", "th_y", "th_z", "bl_x", "bl_y", "bl_z"}, + {"chunk_id", "unroll_id"}), MatchIndexingString(R"( (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> (((bl_x * 128 + th_x) floordiv 30) mod 20), diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/reduction_test.cc b/third_party/xla/xla/service/gpu/fusions/legacy/reduction_test.cc index 144159ce442424..46c7a26970e538 100644 --- a/third_party/xla/xla/service/gpu/fusions/legacy/reduction_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/legacy/reduction_test.cc @@ -19,15 +19,11 @@ limitations under the License. #include #include -#include "absl/status/status.h" #include "absl/status/statusor.h" #include "mlir/IR/MLIRContext.h" -#include "xla/hlo/ir/hlo_instructions.h" -#include "xla/service/gpu/fusions/fusion_emitter.h" #include "xla/service/gpu/gpu_device_info_for_tests.h" #include "xla/service/gpu/hlo_fusion_analysis.h" -#include "xla/service/gpu/ir_emitter_context.h" -#include "xla/service/gpu/model/indexing_analysis.h" +#include "xla/service/gpu/model/indexing_map_serialization.h" #include "xla/service/gpu/model/indexing_test_utils.h" #include "xla/stream_executor/device_description.h" #include "xla/tests/hlo_test_base.h" @@ -73,7 +69,7 @@ TEST_F(ReductionTest, ThreadIndexingRowReduction) { ReductionFusion fusion(analysis); EXPECT_THAT( - fusion.ComputeThreadIdToInputIndexing(0, 0, &mlir_context_)->ToString(), + ToString(*fusion.ComputeThreadIdToInputIndexing(0, 0, &mlir_context_)), MatchIndexingString(R"( (d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3] -> ( d3 floordiv 8, @@ -94,7 +90,7 @@ TEST_F(ReductionTest, ThreadIndexingRowReduction) { is_simplified: true )")); EXPECT_THAT( - fusion.ComputeThreadIdToOutputIndexing(0, &mlir_context_)->ToString(), + ToString(*fusion.ComputeThreadIdToOutputIndexing(0, &mlir_context_)), MatchIndexingString(R"( (d0, d1, d2, d3, d4, d5) -> ( d3 floordiv 8, diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/scatter_test.cc b/third_party/xla/xla/service/gpu/fusions/legacy/scatter_test.cc index 8c6674d4a2b546..7381d375645660 100644 --- a/third_party/xla/xla/service/gpu/fusions/legacy/scatter_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/legacy/scatter_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "xla/service/gpu/fusions/legacy/scatter.h" #include +#include #include #include @@ -22,7 +23,7 @@ limitations under the License. #include "xla/service/gpu/fusions/fusions.h" #include "xla/service/gpu/gpu_device_info_for_tests.h" #include "xla/service/gpu/hlo_fusion_analysis.h" -#include "xla/service/gpu/model/affine_map_printer.h" +#include "xla/service/gpu/model/indexing_map_serialization.h" #include "xla/service/gpu/model/indexing_test_utils.h" #include "xla/stream_executor/device_description.h" #include "xla/tests/hlo_test_base.h" @@ -33,13 +34,6 @@ namespace gpu { namespace { class ScatterFusionTest : public HloTestBase { - public: - void SetUp() override { - HloTestBase::SetUp(); - printer_ = - AffineMapPrinter({"th_x", "th_y", "th_z", "bl_x", "bl_y", "bl_z"}, - {"chunk_id", "unroll_id", "index_id"}); - } DebugOptions GetDebugOptionsForTest() override { auto opts = HloTestBase::GetDebugOptionsForTest(); opts.set_xla_gpu_mlir_emitter_level(0); @@ -47,7 +41,6 @@ class ScatterFusionTest : public HloTestBase { } protected: - AffineMapPrinter printer_; mlir::MLIRContext mlir_context_; }; @@ -166,31 +159,31 @@ TEST_F(ScatterFusionTest, ThreadIdIndexing) { bl_x * 128 + th_x in [0, 8399], is_simplified: true )"; + mlir::SmallVector dim_names = {"th_x", "th_y", "th_z", + "bl_x", "bl_y", "bl_z"}; + mlir::SmallVector symbol_names = {"chunk_id", "unroll_id"}; EXPECT_THAT( - fusion - ->ComputeThreadIdToInputIndexing( - /*root_index=*/0, /*hero_operand_index=*/3, &mlir_context_) - ->ToString(printer_), + ToString(*fusion->ComputeThreadIdToInputIndexing( + /*root_index=*/0, /*hero_operand_index=*/3, &mlir_context_), + dim_names, symbol_names), MatchIndexingString(kUpdatesIndexing)); EXPECT_THAT( - fusion - ->ComputeThreadIdToInputIndexing( - /*root_index=*/0, /*hero_operand_index=*/4, &mlir_context_) - ->ToString(printer_), + ToString(*fusion->ComputeThreadIdToInputIndexing( + /*root_index=*/0, /*hero_operand_index=*/4, &mlir_context_), + dim_names, symbol_names), MatchIndexingString(kUpdatesIndexing)); EXPECT_THAT( - fusion - ->ComputeThreadIdToInputIndexing( - /*root_index=*/1, /*hero_operand_index=*/3, &mlir_context_) - ->ToString(printer_), + ToString(*fusion->ComputeThreadIdToInputIndexing( + /*root_index=*/1, /*hero_operand_index=*/3, &mlir_context_), + dim_names, symbol_names), MatchIndexingString(kUpdatesIndexing)); EXPECT_THAT( - fusion - ->ComputeThreadIdToInputIndexing( - /*root_index=*/1, /*hero_operand_index=*/4, &mlir_context_) - ->ToString(printer_), + ToString(*fusion->ComputeThreadIdToInputIndexing( + /*root_index=*/1, /*hero_operand_index=*/4, &mlir_context_), + dim_names, symbol_names), MatchIndexingString(kUpdatesIndexing)); + symbol_names.push_back("index_id"); constexpr auto kIndicesIndexing = R"( (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id, index_id] -> ((bl_x * 128 + th_x) floordiv 200, 0), @@ -208,16 +201,14 @@ TEST_F(ScatterFusionTest, ThreadIdIndexing) { is_simplified: true )"; EXPECT_THAT( - fusion - ->ComputeThreadIdToInputIndexing( - /*root_index=*/0, /*hero_operand_index=*/2, &mlir_context_) - ->ToString(printer_), + ToString(*fusion->ComputeThreadIdToInputIndexing( + /*root_index=*/0, /*hero_operand_index=*/2, &mlir_context_), + dim_names, symbol_names), MatchIndexingString(kIndicesIndexing)); EXPECT_THAT( - fusion - ->ComputeThreadIdToInputIndexing( - /*root_index=*/1, /*hero_operand_index=*/2, &mlir_context_) - ->ToString(printer_), + ToString(*fusion->ComputeThreadIdToInputIndexing( + /*root_index=*/1, /*hero_operand_index=*/2, &mlir_context_), + dim_names, symbol_names), MatchIndexingString(kIndicesIndexing)); } diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/transpose_test.cc b/third_party/xla/xla/service/gpu/fusions/legacy/transpose_test.cc index bba3d721368e5b..c66094061e6366 100644 --- a/third_party/xla/xla/service/gpu/fusions/legacy/transpose_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/legacy/transpose_test.cc @@ -24,6 +24,7 @@ limitations under the License. #include "xla/service/gpu/fusions/fusions.h" #include "xla/service/gpu/gpu_device_info_for_tests.h" #include "xla/service/gpu/hlo_fusion_analysis.h" +#include "xla/service/gpu/model/indexing_map_serialization.h" #include "xla/service/gpu/model/indexing_test_utils.h" #include "xla/status_macros.h" #include "xla/stream_executor/device_description.h" @@ -77,7 +78,7 @@ TEST_F(TransposeTest, ThreadIndexing021) { mlir::MLIRContext mlir_context; EXPECT_THAT( - fusion->ComputeThreadIdToInputIndexing(0, 0, &mlir_context)->ToString(), + ToString(*fusion->ComputeThreadIdToInputIndexing(0, 0, &mlir_context)), MatchIndexingString(R"( (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> ( d3 floordiv 2, @@ -98,7 +99,7 @@ TEST_F(TransposeTest, ThreadIndexing021) { is_simplified: true )")); EXPECT_THAT( - fusion->ComputeThreadIdToOutputIndexing(0, &mlir_context)->ToString(), + ToString(*fusion->ComputeThreadIdToOutputIndexing(0, &mlir_context)), MatchIndexingString(R"( (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> ( d3 floordiv 2, @@ -141,7 +142,7 @@ TEST_F(TransposeTest, ThreadIndexing201_SimplifiedTo021) { TF_ASSERT_OK_AND_ASSIGN(auto fusion, GetTransposeFusion(analysis)); mlir::MLIRContext mlir_context; EXPECT_THAT( - fusion->ComputeThreadIdToInputIndexing(0, 0, &mlir_context)->ToString(), + ToString(*fusion->ComputeThreadIdToInputIndexing(0, 0, &mlir_context)), MatchIndexingString(R"( (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> ( 0, @@ -162,7 +163,7 @@ TEST_F(TransposeTest, ThreadIndexing201_SimplifiedTo021) { is_simplified: true )")); EXPECT_THAT( - fusion->ComputeThreadIdToOutputIndexing(0, &mlir_context)->ToString(), + ToString(*fusion->ComputeThreadIdToOutputIndexing(0, &mlir_context)), MatchIndexingString(R"( (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> ( 0, @@ -207,7 +208,7 @@ TEST_F(TransposeTest, ThreadIndexingPartialBlock) { TF_ASSERT_OK_AND_ASSIGN(auto fusion, GetTransposeFusion(analysis)); mlir::MLIRContext mlir_context; EXPECT_THAT( - fusion->ComputeThreadIdToInputIndexing(0, 0, &mlir_context)->ToString(), + ToString(*fusion->ComputeThreadIdToInputIndexing(0, 0, &mlir_context)), MatchIndexingString(R"( (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> ( d0 floordiv 32 + s0 * 4, @@ -228,7 +229,7 @@ TEST_F(TransposeTest, ThreadIndexingPartialBlock) { is_simplified: true )")); EXPECT_THAT( - fusion->ComputeThreadIdToOutputIndexing(0, &mlir_context)->ToString(), + ToString(*fusion->ComputeThreadIdToOutputIndexing(0, &mlir_context)), MatchIndexingString(R"( (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> ( d0 floordiv 32 + s0 * 4, @@ -274,8 +275,8 @@ TEST_F(TransposeTest, SameInputIndexingForRealHeroAndSideOutput) { mlir::MLIRContext mlir_context; EXPECT_THAT( - fusion->ComputeThreadIdToInputIndexing(0, 0, &mlir_context)->ToString(), - fusion->ComputeThreadIdToInputIndexing(1, 0, &mlir_context)->ToString()); + ToString(*fusion->ComputeThreadIdToInputIndexing(0, 0, &mlir_context)), + ToString(*fusion->ComputeThreadIdToInputIndexing(1, 0, &mlir_context))); } TEST_F(TransposeTest, ThreadIndexingSideOutput) { @@ -305,7 +306,7 @@ TEST_F(TransposeTest, ThreadIndexingSideOutput) { // Check if side output `%broadcast` get the correct input indexing, which // should corresponds to `%input1` with shape [100,32]. EXPECT_THAT( - fusion->ComputeThreadIdToInputIndexing(1, 0, &mlir_context)->ToString(), + ToString(*fusion->ComputeThreadIdToInputIndexing(1, 0, &mlir_context)), MatchIndexingString(R"( (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> ( d3 floordiv 2, @@ -325,7 +326,7 @@ TEST_F(TransposeTest, ThreadIndexingSideOutput) { is_simplified: true )")); EXPECT_THAT( - fusion->ComputeThreadIdToOutputIndexing(1, &mlir_context)->ToString(), + ToString(*fusion->ComputeThreadIdToOutputIndexing(1, &mlir_context)), MatchIndexingString(R"( (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> ( d3 floordiv 2, diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.cc b/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.cc index 292659c1916898..4c2340973f9f72 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.cc +++ b/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.cc @@ -1627,9 +1627,9 @@ ValueRange EmitLoopNest(ImplicitLocOpBuilder& b, ValueRange dim_values, remainder.GetMutableSymbolBound(sym_index).lower = bound.upper; remainder.Simplify(); - VLOG(5) << "Peeled indexing map " << indexing_map.ToString() << "\n into " - << peeled_map.ToString() << "\nand remainder\n" - << remainder.ToString(); + VLOG(5) << "Peeled indexing map " << indexing_map << "\n into " + << peeled_map << "\nand remainder\n" + << remainder; return EmitLoopNestImpl(b, dim_values, first_results, remainder, create_body, vectorize); } diff --git a/third_party/xla/xla/service/gpu/fusions/tools/test_correctness.cc b/third_party/xla/xla/service/gpu/fusions/tools/test_correctness.cc index 72529cd6545c4d..c812a24fab915c 100644 --- a/third_party/xla/xla/service/gpu/fusions/tools/test_correctness.cc +++ b/third_party/xla/xla/service/gpu/fusions/tools/test_correctness.cc @@ -32,6 +32,7 @@ limitations under the License. #include "xla/service/gpu/fusions/tools/test_lib.h" #include "xla/service/gpu/hlo_fusion_analysis.h" #include "xla/service/gpu/model/indexing_map.h" +#include "xla/service/gpu/model/indexing_map_serialization.h" #include "xla/service/gpu/model/indexing_test_utils.h" #include "xla/shape.h" #include "xla/tests/hlo_test_base.h" @@ -71,7 +72,7 @@ absl::Status TestBijection(const IndexingMap& map, auto status = VerifyBijection(map, intervals); if (status.ok()) return status; return absl::FailedPreconditionError( - absl::StrCat(status.message(), " in map ", map.ToString())); + absl::StrCat(status.message(), " in map ", ToString(map))); } TEST_F(CorrectnessTest, RunAndCompare) { diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/peel_loops.cc b/third_party/xla/xla/service/gpu/fusions/transforms/peel_loops.cc index 9e95e3c3264239..63d9bb3d6923dc 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/peel_loops.cc +++ b/third_party/xla/xla/service/gpu/fusions/transforms/peel_loops.cc @@ -34,6 +34,7 @@ limitations under the License. #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "xla/service/gpu/fusions/ir/xla_gpu_ops.h" #include "xla/service/gpu/model/indexing_map.h" +#include "xla/service/gpu/model/indexing_map_serialization.h" namespace xla { namespace gpu { @@ -88,9 +89,9 @@ struct PeelLoop : public OpRewritePattern { tail_map.Simplify(); VLOG(5) << "Peeled indexing map\n" - << indexing_map.ToString() << "into\n" - << peeled_map.ToString() << "and\n" - << tail_map.ToString() << "\n"; + << ToString(indexing_map) << "into\n" + << ToString(peeled_map) << "and\n" + << ToString(tail_map) << "\n"; indexing_maps.pop_back(); indexing_maps.push_back(tail_map); indexing_maps.push_back(peeled_map); diff --git a/third_party/xla/xla/service/gpu/fusions/triton/BUILD b/third_party/xla/xla/service/gpu/fusions/triton/BUILD index e3c95865b8e5db..4a16dfb0fe8dca 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/BUILD +++ b/third_party/xla/xla/service/gpu/fusions/triton/BUILD @@ -66,7 +66,6 @@ cc_library( "//xla/service/gpu/fusions/mlir:elemental_hlo_to_mlir", "//xla/service/gpu/fusions/transforms:passes", "//xla/service/gpu/llvm_gpu_backend", - "//xla/service/gpu/model:affine_map_printer", "//xla/service/gpu/model:indexing_analysis", "//xla/service/gpu/model:symbolic_tile_analysis", "//xla/service/gpu/model:symbolic_tiled_hlo_instruction", diff --git a/third_party/xla/xla/service/gpu/model/BUILD b/third_party/xla/xla/service/gpu/model/BUILD index fccc5e66360eb8..f4d085671b2abe 100644 --- a/third_party/xla/xla/service/gpu/model/BUILD +++ b/third_party/xla/xla/service/gpu/model/BUILD @@ -413,32 +413,6 @@ xla_cc_test( ], ) -cc_library( - name = "affine_map_printer", - srcs = ["affine_map_printer.cc"], - hdrs = ["affine_map_printer.h"], - deps = [ - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Support", - ], -) - -xla_cc_test( - name = "affine_map_printer_test", - srcs = ["affine_map_printer_test.cc"], - deps = [ - ":affine_map_printer", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - "@com_google_googletest//:gtest", - "@llvm-project//mlir:IR", - "@local_tsl//tsl/platform:test", - ], -) - cc_library( name = "affine_map_evaluator", srcs = ["affine_map_evaluator.cc"], @@ -469,13 +443,14 @@ cc_library( srcs = [ "indexing_analysis.cc", "indexing_map.cc", + "indexing_map_serialization.cc", ], hdrs = [ "indexing_analysis.h", "indexing_map.h", + "indexing_map_serialization.h", ], deps = [ - ":affine_map_printer", "//xla:permutation_util", "//xla:shape_util", "//xla:util", @@ -491,9 +466,11 @@ cc_library( "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/log:check", "@com_google_absl//absl/numeric:int128", + "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", + "@llvm-project//mlir:AsmParser", "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", "@local_tsl//tsl/platform:logging", @@ -504,9 +481,7 @@ xla_cc_test( name = "indexing_map_test", srcs = ["indexing_map_test.cc"], deps = [ - ":affine_map_printer", ":indexing_analysis", - ":indexing_map_serialization", ":indexing_test_utils", "//xla/hlo/ir:hlo", "//xla/tests:hlo_test_base", @@ -523,25 +498,11 @@ xla_cc_test( ], ) -cc_library( - name = "indexing_map_serialization", - srcs = ["indexing_map_serialization.cc"], - hdrs = ["indexing_map_serialization.h"], - deps = [ - ":indexing_analysis", - "@com_google_absl//absl/strings", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:AsmParser", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Support", - ], -) - xla_cc_test( name = "indexing_map_serialization_test", srcs = ["indexing_map_serialization_test.cc"], deps = [ - ":indexing_map_serialization", + ":indexing_analysis", ":indexing_test_utils", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", @@ -602,7 +563,6 @@ cc_library( hdrs = ["symbolic_tile.h"], deps = [ ":affine_map_evaluator", - ":affine_map_printer", ":indexing_analysis", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", @@ -728,7 +688,6 @@ cc_library( hdrs = ["symbolic_tile_analysis.h"], deps = [ ":affine_map_evaluator", - ":affine_map_printer", ":indexing_analysis", ":symbolic_tile", ":symbolic_tiled_hlo_instruction", diff --git a/third_party/xla/xla/service/gpu/model/affine_map_printer.cc b/third_party/xla/xla/service/gpu/model/affine_map_printer.cc deleted file mode 100644 index 83b68eca0473d8..00000000000000 --- a/third_party/xla/xla/service/gpu/model/affine_map_printer.cc +++ /dev/null @@ -1,269 +0,0 @@ -/* Copyright 2024 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/gpu/model/affine_map_printer.h" - -#include -#include -#include -#include - -#include "absl/strings/str_cat.h" -#include "absl/types/span.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/Support/raw_ostream.h" -#include "mlir/IR/AffineExpr.h" -#include "mlir/IR/AffineMap.h" -#include "mlir/Support/LLVM.h" - -namespace xla { -namespace gpu { -namespace { - -using mlir::AffineBinaryOpExpr; -using mlir::AffineConstantExpr; -using mlir::AffineDimExpr; -using mlir::AffineExpr; -using mlir::AffineExprKind; -using mlir::AffineMap; -using mlir::AffineSymbolExpr; - -} // namespace - -AffineMapPrinter::AffineMapPrinter( - absl::Span dim_names, - absl::Span symbol_names) { - dim_id_to_name_.reserve(dim_names.size()); - for (const auto& [index, name] : llvm::enumerate(dim_names)) { - dim_id_to_name_[index] = name; - } - symbol_id_to_name_.reserve(symbol_names.size()); - for (const auto& [index, name] : llvm::enumerate(symbol_names)) { - symbol_id_to_name_[index] = name; - } -} - -void AffineMapPrinter::Print(std::ostream& out, AffineMap affine_map) const { - out << ToString(affine_map); -} - -std::string AffineMapPrinter::ToString(AffineMap affine_map) const { - std::string s; - llvm::raw_string_ostream ss(s); - - if (dim_id_to_name_.empty() && symbol_id_to_name_.empty()) { - affine_map.print(ss); - return s; - } - // Dimension identifiers. - int dim_count = affine_map.getNumDims(); - ss << '('; - for (int i = 0; i < dim_count - 1; ++i) { - ss << GetDimensionName(i) << ", "; - } - if (dim_count >= 1) { - ss << GetDimensionName(dim_count - 1); - } - ss << ')'; - // Symbolic identifiers. - int symbol_count = affine_map.getNumSymbols(); - if (symbol_count != 0) { - ss << '['; - for (unsigned i = 0; i < symbol_count - 1; ++i) { - ss << GetSymbolName(i) << ", "; - } - if (affine_map.getNumSymbols() >= 1) { - ss << GetSymbolName(symbol_count - 1); - } - ss << ']'; - } - // Result affine expressions. - ss << " -> ("; - llvm::interleaveComma(affine_map.getResults(), ss, [&](AffineExpr expr) { - PrintExprImpl(expr, /*add_parentheses=*/false, ss); - }); - ss << ')'; - return s; -} - -void AffineMapPrinter::Print(std::ostream& out, - mlir::AffineExpr affine_expr) const { - out << ToString(affine_expr); -} - -std::string AffineMapPrinter::ToString(mlir::AffineExpr affine_expr) const { - std::string s; - llvm::raw_string_ostream ss(s); - PrintExprImpl(affine_expr, /*add_parentheses=*/false, ss); - return s; -} - -void AffineMapPrinter::PrintExprImpl(const mlir::AffineExpr affine_expr, - bool add_parentheses, - llvm::raw_ostream& os) const { - const char* binopSpelling = nullptr; - switch (affine_expr.getKind()) { - case AffineExprKind::SymbolId: { - unsigned symbol_id = - mlir::cast(affine_expr).getPosition(); - os << GetSymbolName(symbol_id); - return; - } - case AffineExprKind::DimId: { - unsigned dim_id = mlir::cast(affine_expr).getPosition(); - os << GetDimensionName(dim_id); - return; - } - case AffineExprKind::Constant: - os << mlir::cast(affine_expr).getValue(); - return; - case AffineExprKind::Add: - binopSpelling = " + "; - break; - case AffineExprKind::Mul: - binopSpelling = " * "; - break; - case AffineExprKind::FloorDiv: - binopSpelling = " floordiv "; - break; - case AffineExprKind::CeilDiv: - binopSpelling = " ceildiv "; - break; - case AffineExprKind::Mod: - binopSpelling = " mod "; - break; - } - - auto binOp = mlir::cast(affine_expr); - AffineExpr lhsExpr = binOp.getLHS(); - AffineExpr rhsExpr = binOp.getRHS(); - - // Handle tightly binding binary operators. - if (binOp.getKind() != AffineExprKind::Add) { - if (add_parentheses) { - os << '('; - } - - // Pretty print multiplication with -1. - auto rhsConst = mlir::dyn_cast(rhsExpr); - if (rhsConst && binOp.getKind() == AffineExprKind::Mul && - rhsConst.getValue() == -1) { - os << "-"; - PrintExprImpl(lhsExpr, /*add_parentheses=*/true, os); - if (add_parentheses) { - os << ')'; - } - return; - } - - PrintExprImpl(lhsExpr, /*add_parentheses=*/true, os); - - os << binopSpelling; - PrintExprImpl(rhsExpr, /*add_parentheses=*/true, os); - - if (add_parentheses) { - os << ')'; - } - return; - } - - // Print out special "pretty" forms for add. - if (add_parentheses) { - os << '('; - } - - // Pretty print addition to a product that has a negative operand as a - // subtraction. - if (auto rhs = mlir::dyn_cast(rhsExpr)) { - if (rhs.getKind() == AffineExprKind::Mul) { - AffineExpr rrhsExpr = rhs.getRHS(); - if (auto rrhs = mlir::dyn_cast(rrhsExpr)) { - if (rrhs.getValue() == -1) { - PrintExprImpl(lhsExpr, /*add_parentheses=*/false, os); - os << " - "; - if (rhs.getLHS().getKind() == AffineExprKind::Add) { - PrintExprImpl(rhs.getLHS(), /*add_parentheses=*/true, os); - } else { - PrintExprImpl(rhs.getLHS(), /*add_parentheses=*/false, os); - } - - if (add_parentheses) { - os << ')'; - } - return; - } - - if (rrhs.getValue() < -1) { - PrintExprImpl(lhsExpr, /*add_parentheses=*/false, os); - os << " - "; - PrintExprImpl(rhs.getLHS(), /*add_parentheses=*/true, os); - os << " * " << -rrhs.getValue(); - if (add_parentheses) { - os << ')'; - } - return; - } - } - } - } - - // Pretty print addition to a negative number as a subtraction. - if (auto rhsConst = mlir::dyn_cast(rhsExpr)) { - if (rhsConst.getValue() < 0) { - PrintExprImpl(lhsExpr, /*add_parentheses=*/false, os); - os << " - " << -rhsConst.getValue(); - if (add_parentheses) { - os << ')'; - } - return; - } - } - - PrintExprImpl(lhsExpr, /*add_parentheses=*/false, os); - - os << " + "; - PrintExprImpl(rhsExpr, /*add_parentheses=*/false, os); - - if (add_parentheses) { - os << ')'; - } -} - -void AffineMapPrinter::SetSymbolName(int64_t symbol_id, llvm::StringRef name) { - symbol_id_to_name_[symbol_id] = name; -} - -void AffineMapPrinter::SetDimensionName(int64_t dim_id, llvm::StringRef name) { - dim_id_to_name_[dim_id] = name; -} - -std::string AffineMapPrinter::GetSymbolName(int64_t symbol_id) const { - auto it = symbol_id_to_name_.find(symbol_id); - if (it == symbol_id_to_name_.end()) { - return absl::StrCat("s", symbol_id); - } - return it->second; -} - -std::string AffineMapPrinter::GetDimensionName(int64_t dim_id) const { - auto it = dim_id_to_name_.find(dim_id); - if (it == dim_id_to_name_.end()) { - return absl::StrCat("d", dim_id); - } - return it->second; -} - -} // namespace gpu -} // namespace xla diff --git a/third_party/xla/xla/service/gpu/model/affine_map_printer.h b/third_party/xla/xla/service/gpu/model/affine_map_printer.h deleted file mode 100644 index bb1f6fbeb6902b..00000000000000 --- a/third_party/xla/xla/service/gpu/model/affine_map_printer.h +++ /dev/null @@ -1,67 +0,0 @@ -/* Copyright 2024 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_GPU_MODEL_AFFINE_MAP_PRINTER_H_ -#define XLA_SERVICE_GPU_MODEL_AFFINE_MAP_PRINTER_H_ - -#include -#include -#include -#include - -#include "absl/types/span.h" -#include "llvm/ADT/DenseMap.h" -#include "llvm/ADT/StringRef.h" -#include "llvm/ADT/Twine.h" -#include "mlir/IR/AffineExpr.h" -#include "mlir/IR/AffineMap.h" - -namespace xla { -namespace gpu { - -// AffineMapPrinter allows to "pretty print" mlir::AffineMap by setting custom -// symbol and dimension names. -class AffineMapPrinter { - public: - AffineMapPrinter() = default; - AffineMapPrinter(AffineMapPrinter&& other) = default; - AffineMapPrinter& operator=(AffineMapPrinter&& other) = default; - AffineMapPrinter(absl::Span dim_names, - absl::Span symbol_names); - - void SetSymbolName(int64_t symbol_id, llvm::StringRef name); - void SetDimensionName(int64_t dim_id, llvm::StringRef name); - - std::string GetSymbolName(int64_t symbol_id) const; - std::string GetDimensionName(int64_t dim_id) const; - - void Print(std::ostream& out, mlir::AffineMap affine_map) const; - std::string ToString(mlir::AffineMap affine_map) const; - - void Print(std::ostream& out, mlir::AffineExpr affine_expr) const; - std::string ToString(mlir::AffineExpr affine_expr) const; - - private: - void PrintExprImpl(mlir::AffineExpr affine_expr, bool add_parentheses, - llvm::raw_ostream& os) const; - - llvm::DenseMap dim_id_to_name_; - llvm::DenseMap symbol_id_to_name_; -}; - -} // namespace gpu -} // namespace xla - -#endif // XLA_SERVICE_GPU_MODEL_AFFINE_MAP_PRINTER_H_ diff --git a/third_party/xla/xla/service/gpu/model/affine_map_printer_test.cc b/third_party/xla/xla/service/gpu/model/affine_map_printer_test.cc deleted file mode 100644 index 01c6092b4d02c3..00000000000000 --- a/third_party/xla/xla/service/gpu/model/affine_map_printer_test.cc +++ /dev/null @@ -1,59 +0,0 @@ -/* Copyright 2024 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/gpu/model/affine_map_printer.h" - -#include -#include "mlir/IR/AffineExpr.h" -#include "mlir/IR/AffineMap.h" -#include "mlir/IR/MLIRContext.h" -#include "xla/tests/hlo_test_base.h" -#include "tsl/platform/test.h" - -namespace xla { -namespace gpu { -namespace { - -using ::mlir::AffineExpr; -using ::mlir::AffineMap; -using ::mlir::bindDims; -using ::mlir::bindSymbols; -using ::testing::HasSubstr; - -class IndexingMapTest : public HloTestBase { - public: - mlir::MLIRContext mlir_context_; - AffineMapPrinter printer_; -}; - -TEST_F(IndexingMapTest, AffineMapPrinterTest) { - AffineExpr d0, d1, s0, s1; - bindDims(&mlir_context_, d0, d1); - bindSymbols(&mlir_context_, s0, s1); - - // (d0, d1)[s0, s1] -> (d0 + d1 floordiv 8, s0 + s1 mod 16). - auto map = - AffineMap::get(2, 2, {d0 + d1.floorDiv(8), s0 + s1 % 16}, &mlir_context_); - - printer_.SetDimensionName(0, "offset"); - printer_.SetSymbolName(1, "linear_index"); - EXPECT_THAT(printer_.ToString(map), - HasSubstr("(offset, d1)[s0, linear_index] -> " - "(offset + d1 floordiv 8, s0 + linear_index mod 16)")); -} - -} // namespace -} // namespace gpu -} // namespace xla diff --git a/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.cc b/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.cc index 13a4a0ac6e4d54..f8ac967dd1ac68 100644 --- a/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.cc +++ b/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.cc @@ -274,7 +274,7 @@ GpuPerformanceModelWithIndexingAnalysis::EstimateRunTimeForFusion( auto element_type = instr->shape().element_type(); int64_t n_bytes_total = 0; for (const auto& indexing_map : indexing_maps) { - VLOG(10) << indexing_map.ToString(); + VLOG(10) << indexing_map; int64_t num_iters = GetIterationSpaceSize(indexing_map, instr); diff --git a/third_party/xla/xla/service/gpu/model/indexing_analysis.cc b/third_party/xla/xla/service/gpu/model/indexing_analysis.cc index ad9215c921b09f..9477826b0f801f 100644 --- a/third_party/xla/xla/service/gpu/model/indexing_analysis.cc +++ b/third_party/xla/xla/service/gpu/model/indexing_analysis.cc @@ -49,7 +49,6 @@ limitations under the License. #include "xla/service/gather_simplifier.h" #include "xla/service/gpu/hlo_traversal.h" #include "xla/service/gpu/matmul_utils.h" -#include "xla/service/gpu/model/affine_map_printer.h" #include "xla/service/gpu/model/indexing_map.h" #include "xla/shape.h" #include "xla/shape_util.h" @@ -1257,33 +1256,25 @@ HloInstructionIndexing HloInstructionIndexing::FromIndexingMaps( return instr_indexing; } -std::string HloInstructionIndexing::ToString( - const AffineMapPrinter& printer) const { - std::string s; - std::stringstream ss(s); - Print(ss, printer); +std::string HloInstructionIndexing::ToString() const { + std::stringstream ss; + ss << *this; return ss.str(); } -void HloInstructionIndexing::Print(std::ostream& out, - const AffineMapPrinter& printer) const { +std::ostream& operator<<(std::ostream& out, + const HloInstructionIndexing& instr_indexing) { for (const auto& [operand_id, indexing_maps] : - llvm::enumerate(indexing_maps)) { + llvm::enumerate(instr_indexing.indexing_maps)) { out << "operand id = " << operand_id << ' '; for (const auto& indexing_map : indexing_maps) { if (indexing_map.IsUndefined()) { out << "unknown indexing"; continue; } - indexing_map.Print(out, printer); + out << indexing_map; } } -} - -std::ostream& operator<<(std::ostream& out, - const HloInstructionIndexing& instr_indexing) { - AffineMapPrinter printer; - instr_indexing.Print(out, printer); return out; } diff --git a/third_party/xla/xla/service/gpu/model/indexing_analysis.h b/third_party/xla/xla/service/gpu/model/indexing_analysis.h index d4c170aace2063..e05a598cc2a322 100644 --- a/third_party/xla/xla/service/gpu/model/indexing_analysis.h +++ b/third_party/xla/xla/service/gpu/model/indexing_analysis.h @@ -31,7 +31,6 @@ limitations under the License. #include "mlir/IR/MLIRContext.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/service/gpu/hlo_traversal.h" -#include "xla/service/gpu/model/affine_map_printer.h" #include "xla/service/gpu/model/indexing_map.h" #include "xla/shape.h" @@ -43,9 +42,7 @@ using IndexingMapSet = absl::flat_hash_set; // Contains indexing maps for all N-dimensional tensor input operands that // correspond to a particular output. struct HloInstructionIndexing { - std::string ToString( - const AffineMapPrinter& printer = AffineMapPrinter()) const; - void Print(std::ostream& out, const AffineMapPrinter& printer) const; + std::string ToString() const; // Returns true if the indexing was simplified. bool Simplify(); diff --git a/third_party/xla/xla/service/gpu/model/indexing_analysis_test.cc b/third_party/xla/xla/service/gpu/model/indexing_analysis_test.cc index 440c06f44bb19b..b3f4043d73825f 100644 --- a/third_party/xla/xla/service/gpu/model/indexing_analysis_test.cc +++ b/third_party/xla/xla/service/gpu/model/indexing_analysis_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/service/gpu/hlo_traversal.h" +#include "xla/service/gpu/model/indexing_map_serialization.h" #include "xla/service/gpu/model/indexing_test_utils.h" #include "tsl/platform/test.h" @@ -2679,10 +2680,9 @@ TEST_F(IndexingAnalysisTest, EpilogueIndexing) { HloInstructionAdaptor log(*computation->GetInstructionWithName("log"), fusion.get()); - EXPECT_THAT( - ComputeEpilogueInputToOutputIndexing(transpose, log, &mlir_context_) - .ToString(), - MatchIndexingString(R"( + EXPECT_THAT(ToString(ComputeEpilogueInputToOutputIndexing(transpose, log, + &mlir_context_)), + MatchIndexingString(R"( (d0, d1) -> (d1 * 1000 + d0), domain: d0 in [0, 999], @@ -2710,10 +2710,9 @@ TEST_F(IndexingAnalysisTest, EpilogueIndexing_NoEpilogue) { HloInstructionAdaptor transpose(*computation->GetInstructionWithName("t"), fusion.get()); - EXPECT_THAT( - ComputeEpilogueInputToOutputIndexing(transpose, transpose, &mlir_context_) - .ToString(), - MatchIndexingString(R"( + EXPECT_THAT(ToString(ComputeEpilogueInputToOutputIndexing( + transpose, transpose, &mlir_context_)), + MatchIndexingString(R"( (d0, d1) -> (d0, d1), domain: d0 in [0, 999], diff --git a/third_party/xla/xla/service/gpu/model/indexing_map.cc b/third_party/xla/xla/service/gpu/model/indexing_map.cc index 7add913b3e5942..c1964b351b35bc 100644 --- a/third_party/xla/xla/service/gpu/model/indexing_map.cc +++ b/third_party/xla/xla/service/gpu/model/indexing_map.cc @@ -31,6 +31,7 @@ limitations under the License. #include "absl/base/optimization.h" #include "absl/numeric/int128.h" +#include "absl/strings/str_format.h" #include "absl/types/span.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseSet.h" @@ -48,7 +49,6 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" -#include "xla/service/gpu/model/affine_map_printer.h" #include "tsl/platform/logging.h" // IWYU pragma: keep namespace xla { @@ -831,14 +831,21 @@ IndexingMap GetIndexingMapForInstruction(const HloInstruction* instr, int64_t operand_idx, mlir::MLIRContext* mlir_context); +std::ostream& operator<<(std::ostream& out, const Interval& interval) { + out << absl::StrFormat("[%d, %d]", interval.lower, interval.upper); + return out; +} + std::string Interval::ToString() const { std::stringstream ss; - Print(ss); + ss << *this; return ss.str(); } -void Interval::Print(std::ostream& out) const { - out << '[' << lower << ", " << upper << "]"; +inline llvm::raw_ostream& operator<<(llvm::raw_ostream& os, + const Interval& interval) { + os << absl::StrFormat("[%d, %d]", interval.lower, interval.upper); + return os; } int64_t Interval::GetLoopTripCount() const { @@ -958,11 +965,6 @@ Interval Interval::FloorDiv(int64_t rhs) const { return {std::min(a, b), std::max(a, b)}; } -std::ostream& operator<<(std::ostream& out, const Interval& range) { - range.Print(out); - return out; -} - bool operator==(const DimVar& lhs, const DimVar& rhs) { return lhs.bounds == rhs.bounds; } @@ -1256,77 +1258,10 @@ Interval RangeEvaluator::ComputeExpressionRange(AffineExpr expr) { return result; } -std::string IndexingMap::ToString(const AffineMapPrinter& printer) const { - std::stringstream ss; - Print(ss, printer); - return ss.str(); -} - -void PrintRTVars(const std::vector& rt_vars, - int first_rt_var_symbol_index, std::ostream& out, - const AffineMapPrinter& printer) { - for (const auto& [index, rt_var] : llvm::enumerate(rt_vars)) { - out << printer.GetSymbolName( - static_cast(first_rt_var_symbol_index + index)) - << " in "; - rt_var.feasible_values.Print(out); - out << ", hlo: " - << (rt_var.hlo == nullptr ? "NULL" : rt_var.hlo->ToString()) << ", "; - printer.Print(out, rt_var.map); - out << ", "; - } -} - -void IndexingMap::Print(std::ostream& out, - const AffineMapPrinter& printer) const { - if (IsKnownEmpty()) { - out << "KNOWN EMPTY\n"; - return; - } - printer.Print(out, affine_map_); - if (dim_vars_.empty() && range_vars_.empty() && rt_vars_.empty()) { - return; - } - out << ", domain: "; - for (const auto& [index, dim_var] : llvm::enumerate(dim_vars_)) { - out << printer.GetDimensionName(static_cast(index)) << " in "; - dim_var.bounds.Print(out); - out << ", "; - } - for (const auto& [index, range_var] : llvm::enumerate(range_vars_)) { - out << printer.GetSymbolName(static_cast(index)) << " in "; - range_var.range.Print(out); - out << ", "; - } - int64_t range_vars_count = GetRangeVarsCount(); - PrintRTVars(rt_vars_, /*first_rt_var_symbol_index=*/range_vars_count, out, - printer); - std::vector expr_range_strings; - expr_range_strings.reserve(constraints_.size()); - for (const auto& [expr, range] : constraints_) { - std::stringstream ss; - printer.Print(ss, expr); - ss << " in "; - range.Print(ss); - expr_range_strings.push_back(ss.str()); - } - std::sort(expr_range_strings.begin(), expr_range_strings.end()); - for (const auto& expr_range_string : expr_range_strings) { - out << expr_range_string << ", "; - } - out << "is_simplified: " << (is_simplified_ ? "true" : "false"); -} - MLIRContext* IndexingMap::GetMLIRContext() const { return IsUndefined() ? nullptr : affine_map_.getContext(); } -std::ostream& operator<<(std::ostream& out, const IndexingMap& indexing_map) { - AffineMapPrinter printer; - indexing_map.Print(out, printer); - return out; -} - bool operator==(const IndexingMap& lhs, const IndexingMap& rhs) { return lhs.GetAffineMap() == rhs.GetAffineMap() && lhs.GetDimVars() == rhs.GetDimVars() && diff --git a/third_party/xla/xla/service/gpu/model/indexing_map.h b/third_party/xla/xla/service/gpu/model/indexing_map.h index 25d40abd47c3f1..81ca0e9d03588c 100644 --- a/third_party/xla/xla/service/gpu/model/indexing_map.h +++ b/third_party/xla/xla/service/gpu/model/indexing_map.h @@ -26,7 +26,6 @@ limitations under the License. #include #include -#include "absl/strings/str_format.h" #include "absl/types/span.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/Hashing.h" @@ -37,7 +36,6 @@ limitations under the License. #include "mlir/IR/MLIRContext.h" #include "mlir/Support/LLVM.h" #include "xla/hlo/ir/hlo_instruction.h" -#include "xla/service/gpu/model/affine_map_printer.h" namespace xla { namespace gpu { @@ -61,8 +59,6 @@ std::ostream& operator<<(std::ostream& out, VariableKind var_type); // Interval represents a closed interval [lower_bound, upper_bound]. struct Interval { std::string ToString() const; - void Print(std::ostream& out) const; - bool IsPoint() const { return lower == upper; } bool IsFeasible() const { return lower <= upper; } @@ -161,12 +157,9 @@ struct Interval { int64_t upper = 0; }; -std::ostream& operator<<(std::ostream& out, const Interval& range); +std::ostream& operator<<(std::ostream& out, const Interval& interval); inline llvm::raw_ostream& operator<<(llvm::raw_ostream& os, - const Interval& interval) { - os << absl::StrFormat("[%d, %d]", interval.lower, interval.upper); - return os; -} + const Interval& interval); template H AbslHashValue(H h, const Interval& range) { @@ -324,11 +317,6 @@ class IndexingMap { absl::Span symbol_upper_bounds, bool is_simplified = false); - std::string ToString( - const AffineMapPrinter& printer = AffineMapPrinter()) const; - - void Print(std::ostream& out, const AffineMapPrinter& printer) const; - // Returns true if the map was simplified. bool Simplify(); @@ -497,21 +485,6 @@ IndexingMap operator*(const IndexingMap& lhs, const IndexingMap& rhs); IndexingMap ComposeIndexingMaps(const IndexingMap& first, const IndexingMap& second); -// Prints the RTVars. -// -// This is exposed to allow SymbolicTile to reuse it. -// -// `first_rt_var_symbol_index`: The index of the symbol associated with the -// first RTVar. The RTVars will be printed with consequent symbol indices -// starting with `first_rt_var_symbol_index`. For example, if `rt_vars.size() -// == 3` and `first_rt_var_symbol_index == 4`, then the symbol names "s4", -// "s5" and "s6" will be used. -// -// TODO(b/334043862): Unexpose this function if possible. -void PrintRTVars(const std::vector& rt_vars, - int first_rt_var_symbol_index, std::ostream& out, - const AffineMapPrinter& printer); - template H AbslHashValue(H h, const IndexingMap& indexing_map) { llvm::hash_code affine_map_hash = diff --git a/third_party/xla/xla/service/gpu/model/indexing_map_serialization.cc b/third_party/xla/xla/service/gpu/model/indexing_map_serialization.cc index 31e7bfaa53ec21..4e72a4b56dd94f 100644 --- a/third_party/xla/xla/service/gpu/model/indexing_map_serialization.cc +++ b/third_party/xla/xla/service/gpu/model/indexing_map_serialization.cc @@ -15,15 +15,20 @@ limitations under the License. #include "xla/service/gpu/model/indexing_map_serialization.h" +#include #include #include #include +#include #include #include #include #include +#include "absl/log/check.h" +#include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" +#include "absl/types/span.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/STLFunctionalExtras.h" #include "llvm/ADT/SmallVector.h" @@ -43,8 +48,14 @@ namespace { using llvm::SmallVector; using llvm::SmallVectorImpl; using llvm::StringRef; +using mlir::AffineBinaryOpExpr; +using mlir::AffineConstantExpr; +using mlir::AffineDimExpr; using mlir::AffineExpr; +using mlir::AffineExprKind; using mlir::AffineMap; +using mlir::AffineMapAttr; +using mlir::AffineSymbolExpr; using mlir::ArrayRef; using mlir::MLIRContext; @@ -376,16 +387,172 @@ bool ParseAffineExprsWithMLIR(ArrayRef dim_var_names, llvm::errs() << "Failed to parse affine map: " << ss.str() << "\n"; return false; } - mlir::AffineMap affine_map = - mlir::cast(affine_map_attr).getValue(); + AffineMap affine_map = mlir::cast(affine_map_attr).getValue(); affine_exprs = llvm::to_vector(affine_map.getResults()); return true; } +std::string GetSymbolName(int64_t symbol_id, + absl::Span symbol_names = {}) { + if (symbol_names.empty()) { + return absl::StrCat("s", symbol_id); + } + return symbol_names.at(symbol_id); +} + +std::string GetDimensionName(int64_t dim_id, + absl::Span dim_names = {}) { + if (dim_names.empty()) { + return absl::StrCat("d", dim_id); + } + return dim_names.at(dim_id); +} + +void PrintAffineExprImpl(const AffineExpr affine_expr, + absl::Span dim_names, + absl::Span symbol_names, + bool add_parentheses, llvm::raw_ostream& os) { + const char* binopSpelling = nullptr; + switch (affine_expr.getKind()) { + case AffineExprKind::SymbolId: { + unsigned symbol_id = + mlir::cast(affine_expr).getPosition(); + os << GetSymbolName(symbol_id, symbol_names); + return; + } + case AffineExprKind::DimId: { + unsigned dim_id = mlir::cast(affine_expr).getPosition(); + os << GetDimensionName(dim_id, dim_names); + return; + } + case AffineExprKind::Constant: + os << mlir::cast(affine_expr).getValue(); + return; + case AffineExprKind::Add: + binopSpelling = " + "; + break; + case AffineExprKind::Mul: + binopSpelling = " * "; + break; + case AffineExprKind::FloorDiv: + binopSpelling = " floordiv "; + break; + case AffineExprKind::CeilDiv: + binopSpelling = " ceildiv "; + break; + case AffineExprKind::Mod: + binopSpelling = " mod "; + break; + } + + auto binOp = mlir::cast(affine_expr); + AffineExpr lhsExpr = binOp.getLHS(); + AffineExpr rhsExpr = binOp.getRHS(); + + // Handle tightly binding binary operators. + if (binOp.getKind() != AffineExprKind::Add) { + if (add_parentheses) { + os << '('; + } + + // Pretty print multiplication with -1. + auto rhsConst = mlir::dyn_cast(rhsExpr); + if (rhsConst && binOp.getKind() == AffineExprKind::Mul && + rhsConst.getValue() == -1) { + os << "-"; + PrintAffineExprImpl(lhsExpr, dim_names, symbol_names, + /*add_parentheses=*/true, os); + if (add_parentheses) { + os << ')'; + } + return; + } + PrintAffineExprImpl(lhsExpr, dim_names, symbol_names, + /*add_parentheses=*/true, os); + + os << binopSpelling; + PrintAffineExprImpl(rhsExpr, dim_names, symbol_names, + /*add_parentheses=*/true, os); + + if (add_parentheses) { + os << ')'; + } + return; + } + + // Print out special "pretty" forms for add. + if (add_parentheses) { + os << '('; + } + + // Pretty print addition to a product that has a negative operand as a + // subtraction. + if (auto rhs = mlir::dyn_cast(rhsExpr)) { + if (rhs.getKind() == AffineExprKind::Mul) { + AffineExpr rrhsExpr = rhs.getRHS(); + if (auto rrhs = mlir::dyn_cast(rrhsExpr)) { + if (rrhs.getValue() == -1) { + PrintAffineExprImpl(lhsExpr, dim_names, symbol_names, + /*add_parentheses=*/false, os); + os << " - "; + if (rhs.getLHS().getKind() == AffineExprKind::Add) { + PrintAffineExprImpl(rhs.getLHS(), dim_names, symbol_names, + /*add_parentheses=*/true, os); + } else { + PrintAffineExprImpl(rhs.getLHS(), dim_names, symbol_names, + /*add_parentheses=*/false, os); + } + if (add_parentheses) { + os << ')'; + } + return; + } + + if (rrhs.getValue() < -1) { + PrintAffineExprImpl(lhsExpr, dim_names, symbol_names, + /*add_parentheses=*/false, os); + os << " - "; + PrintAffineExprImpl(rhs.getLHS(), dim_names, symbol_names, + /*add_parentheses=*/true, os); + os << " * " << -rrhs.getValue(); + if (add_parentheses) { + os << ')'; + } + return; + } + } + } + } + + // Pretty print addition to a negative number as a subtraction. + if (auto rhsConst = mlir::dyn_cast(rhsExpr)) { + if (rhsConst.getValue() < 0) { + PrintAffineExprImpl(lhsExpr, dim_names, symbol_names, + /*add_parentheses=*/false, os); + os << " - " << -rhsConst.getValue(); + if (add_parentheses) { + os << ')'; + } + return; + } + } + + PrintAffineExprImpl(lhsExpr, dim_names, symbol_names, + /*add_parentheses=*/false, os); + + os << " + "; + PrintAffineExprImpl(rhsExpr, dim_names, symbol_names, + /*add_parentheses=*/false, os); + + if (add_parentheses) { + os << ')'; + } +} + } // namespace std::optional ParseIndexingMap(llvm::StringRef input, - mlir::MLIRContext* context) { + MLIRContext* context) { Parser parser(input); // Parse variable names. @@ -512,5 +679,136 @@ std::optional ParseIndexingMap(llvm::StringRef input, constraints, is_simplified}; } +std::string ToString(AffineExpr affine_expr) { + return ToString(affine_expr, /*dim_names=*/{}, /*symbol_names=*/{}); +} + +std::ostream& operator<<(std::ostream& out, AffineExpr affine_expr) { + out << ToString(affine_expr); + return out; +} + +std::string ToString(AffineExpr affine_expr, + absl::Span dim_names, + absl::Span symbol_names) { + std::string s; + llvm::raw_string_ostream ss(s); + PrintAffineExprImpl(affine_expr, dim_names, symbol_names, + /*add_parentheses=*/false, ss); + return s; +} + +std::string ToString(AffineMap affine_map) { + int dim_count = affine_map.getNumDims(); + SmallVector dim_names; + dim_names.reserve(affine_map.getNumDims()); + for (int64_t dim_id = 0; dim_id < dim_count; ++dim_id) { + dim_names.push_back(GetDimensionName(dim_id)); + } + int symbol_count = affine_map.getNumSymbols(); + SmallVector symbol_names; + symbol_names.reserve(affine_map.getNumSymbols()); + for (int64_t symbol_id = 0; symbol_id < symbol_count; ++symbol_id) { + symbol_names.push_back(GetSymbolName(symbol_id)); + } + return ToString(affine_map, dim_names, symbol_names); +} + +std::ostream& operator<<(std::ostream& out, AffineMap affine_map) { + out << ToString(affine_map); + return out; +} + +std::string ToString(AffineMap affine_map, + absl::Span dim_names, + absl::Span symbol_names) { + CHECK_EQ(dim_names.size(), affine_map.getNumDims()); + CHECK_EQ(symbol_names.size(), affine_map.getNumSymbols()); + + std::string s; + llvm::raw_string_ostream ss(s); + + // Dimension identifiers. + ss << '(' << absl::StrJoin(dim_names, ", ") << ')'; + // Symbolic identifiers. + if (affine_map.getNumSymbols() != 0) { + ss << '[' << absl::StrJoin(symbol_names, ", ") << ']'; + } + // Result affine expressions. + ss << " -> ("; + llvm::interleaveComma(affine_map.getResults(), ss, [&](AffineExpr expr) { + PrintAffineExprImpl(expr, dim_names, symbol_names, + /*add_parentheses=*/false, ss); + }); + ss << ')'; + return s; +} + +std::string ToString(const IndexingMap& indexing_map) { + const auto& affine_map = indexing_map.GetAffineMap(); + int dim_count = affine_map.getNumDims(); + SmallVector dim_names; + dim_names.reserve(affine_map.getNumDims()); + for (int64_t dim_id = 0; dim_id < dim_count; ++dim_id) { + dim_names.push_back(GetDimensionName(dim_id)); + } + int symbol_count = affine_map.getNumSymbols(); + SmallVector symbol_names; + symbol_names.reserve(affine_map.getNumSymbols()); + for (int64_t symbol_id = 0; symbol_id < symbol_count; ++symbol_id) { + symbol_names.push_back(GetSymbolName(symbol_id)); + } + return ToString(indexing_map, dim_names, symbol_names); +} + +std::ostream& operator<<(std::ostream& out, const IndexingMap& indexing_map) { + out << ToString(indexing_map); + return out; +} + +std::string ToString(const IndexingMap& indexing_map, + absl::Span dim_names, + absl::Span symbol_names) { + std::stringstream ss; + if (indexing_map.IsKnownEmpty()) { + ss << "KNOWN EMPTY\n"; + return ss.str(); + } + const auto& dim_vars = indexing_map.GetDimVars(); + const auto& range_vars = indexing_map.GetRangeVars(); + const auto& rt_vars = indexing_map.GetRTVars(); + ss << ToString(indexing_map.GetAffineMap(), dim_names, symbol_names); + if (dim_vars.empty() && range_vars.empty() && rt_vars.empty()) { + return ss.str(); + } + ss << ", domain: "; + for (const auto& [index, dim_var] : llvm::enumerate(dim_vars)) { + ss << dim_names[index] << " in " << dim_var.bounds << ", "; + } + for (const auto& [index, range_var] : llvm::enumerate(range_vars)) { + ss << symbol_names[index] << " in " << range_var.range << ", "; + } + int64_t num_range_vars = range_vars.size(); + for (const auto& [index, rt_var] : llvm::enumerate(rt_vars)) { + ss << GetSymbolName(num_range_vars + index, symbol_names) << " in " + << rt_var.feasible_values << ", hlo: " + << (rt_var.hlo == nullptr ? "NULL" : rt_var.hlo->ToString()) << ", " + << ToString(rt_var.map) << ", "; + } + std::vector expr_range_strings; + const auto& constraints = indexing_map.GetConstraints(); + expr_range_strings.reserve(constraints.size()); + for (const auto& [expr, range] : constraints) { + expr_range_strings.push_back(absl::StrCat( + ToString(expr, dim_names, symbol_names), " in ", range.ToString())); + } + std::sort(expr_range_strings.begin(), expr_range_strings.end()); + for (const auto& expr_range_string : expr_range_strings) { + ss << expr_range_string << ", "; + } + ss << "is_simplified: " << (indexing_map.IsSimplified() ? "true" : "false"); + return ss.str(); +} + } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/model/indexing_map_serialization.h b/third_party/xla/xla/service/gpu/model/indexing_map_serialization.h index ce09a90e65a36d..f308bb16182862 100644 --- a/third_party/xla/xla/service/gpu/model/indexing_map_serialization.h +++ b/third_party/xla/xla/service/gpu/model/indexing_map_serialization.h @@ -17,8 +17,13 @@ limitations under the License. #define XLA_SERVICE_GPU_MODEL_INDEXING_MAP_SERIALIZATION_H_ #include +#include +#include +#include "absl/types/span.h" #include "llvm/ADT/StringRef.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/AffineMap.h" #include "mlir/IR/MLIRContext.h" #include "xla/service/gpu/model/indexing_map.h" @@ -29,6 +34,38 @@ namespace gpu { std::optional ParseIndexingMap(llvm::StringRef input, mlir::MLIRContext* context); +// Prints AffineExpr using the default (d0, d1, ..., s0, s1, ...) variable +// names. +std::string ToString(mlir::AffineExpr affine_expr); + +// Prints AffineExpr using the provided variable names. +std::string ToString(mlir::AffineExpr affine_expr, + absl::Span dim_names, + absl::Span symbol_names); + +std::ostream& operator<<(std::ostream& out, mlir::AffineExpr affine_expr); + +// Prints AffineMap using the default (d0, d1, ..., s0, s1, ...) variable names. +std::string ToString(mlir::AffineMap affine_map); + +// Prints AffineMap using the provided variable names. +std::string ToString(mlir::AffineMap affine_map, + absl::Span dim_names, + absl::Span symbol_names); + +std::ostream& operator<<(std::ostream& out, mlir::AffineMap affine_map); + +// Prints IndexingMap using the default (d0, d1, ..., s0, s1, ...) variable +// names. +std::string ToString(const IndexingMap& indexing_map); + +// Prints IndexingMap using the provided variable names. +std::string ToString(const IndexingMap& indexing_map, + absl::Span dim_names, + absl::Span symbol_names); + +std::ostream& operator<<(std::ostream& out, const IndexingMap& indexing_map); + } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/model/indexing_map_serialization_test.cc b/third_party/xla/xla/service/gpu/model/indexing_map_serialization_test.cc index c7d39e8c8690fc..28a7f7b60b4ac8 100644 --- a/third_party/xla/xla/service/gpu/model/indexing_map_serialization_test.cc +++ b/third_party/xla/xla/service/gpu/model/indexing_map_serialization_test.cc @@ -18,6 +18,8 @@ limitations under the License. #include #include #include "absl/strings/string_view.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/AffineMap.h" #include "mlir/IR/MLIRContext.h" #include "xla/service/gpu/model/indexing_test_utils.h" #include "xla/tests/hlo_test_base.h" @@ -27,14 +29,15 @@ namespace xla { namespace gpu { namespace { +using ::testing::HasSubstr; + class IndexingMapSerializationTest : public HloTestBase { public: mlir::MLIRContext mlir_context_; void ParseAndCheck(absl::string_view indexing_map_str) { auto indexing_map = ParseIndexingMap(indexing_map_str, &mlir_context_); ASSERT_TRUE(indexing_map.has_value()); - EXPECT_THAT(indexing_map->ToString(), - MatchIndexingString(indexing_map_str)); + EXPECT_THAT(ToString(*indexing_map), MatchIndexingString(indexing_map_str)); } }; @@ -130,10 +133,22 @@ TEST_F(IndexingMapSerializationTest, CustomNames) { )"; auto indexing_map = ParseIndexingMap(indexing_map_str, &mlir_context_); ASSERT_TRUE(indexing_map.has_value()); - EXPECT_THAT(indexing_map->ToString(), + EXPECT_THAT(ToString(*indexing_map), MatchIndexingString(indexing_map_golden)); } +TEST_F(IndexingMapSerializationTest, AffineMapPrinterTest) { + mlir::AffineExpr d0, d1, s0, s1; + mlir::bindDims(&mlir_context_, d0, d1); + mlir::bindSymbols(&mlir_context_, s0, s1); + + // (d0, d1)[s0, s1] -> (d0 + d1 floordiv 8, s0 + s1 mod 16). + auto map = mlir::AffineMap::get(2, 2, {d0 + d1.floorDiv(8), s0 + s1 % 16}, + &mlir_context_); + EXPECT_THAT(ToString(map, {"offset", "d1"}, {"s0", "linear_index"}), + HasSubstr("(offset, d1)[s0, linear_index] -> " + "(offset + d1 floordiv 8, s0 + linear_index mod 16)")); +} } // namespace } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/model/indexing_map_test.cc b/third_party/xla/xla/service/gpu/model/indexing_map_test.cc index c3b8669f1b46c7..e5fe84b49bfe6c 100644 --- a/third_party/xla/xla/service/gpu/model/indexing_map_test.cc +++ b/third_party/xla/xla/service/gpu/model/indexing_map_test.cc @@ -33,7 +33,6 @@ limitations under the License. #include "mlir/IR/AffineMap.h" #include "mlir/IR/MLIRContext.h" #include "xla/hlo/ir/hlo_instruction.h" -#include "xla/service/gpu/model/affine_map_printer.h" #include "xla/service/gpu/model/indexing_map_serialization.h" #include "xla/service/gpu/model/indexing_test_utils.h" #include "xla/tests/hlo_test_base.h" @@ -58,7 +57,6 @@ class IndexingMapTest : public HloTestBase { } mlir::MLIRContext mlir_context_; - AffineMapPrinter printer_; }; std::vector ConvertToSTL(const llvm::SmallBitVector& bit_vector) { @@ -89,7 +87,8 @@ TEST_F(IndexingMapTest, VariableKind) { } TEST_F(IndexingMapTest, RTVar) { - auto zero_dim_map = AffineMap::get(&mlir_context_); + auto zero_dim_map = + AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0, &mlir_context_); std::vector rt_vars{RTVar{Interval{0, 2}, /*instr=*/nullptr, zero_dim_map}, RTVar({Interval{0, 7}, @@ -100,10 +99,8 @@ TEST_F(IndexingMapTest, RTVar) { &mlir_context_), {DimVar{{0, 99}}, DimVar{{0, 43}}}, {RangeVar{{-99, 99}}}, std::move(rt_vars)); - printer_.SetSymbolName(0, "range"); - printer_.SetSymbolName(1, "rt_0"); - printer_.SetSymbolName(2, "rt_1"); - EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( + EXPECT_THAT(ToString(indexing_map, {"d0", "d1"}, {"range", "rt_0", "rt_1"}), + MatchIndexingString(R"( (d0, d1)[range, rt_0, rt_1] -> (d1, d0, range + rt_0, rt_0), domain: d0 in [0, 99], @@ -111,10 +108,10 @@ TEST_F(IndexingMapTest, RTVar) { range in [-99, 99], rt_0 in [0, 2], hlo: NULL, - () -> (), + (d0, d1) -> (), rt_1 in [0, 7], hlo: NULL, - () -> (), + (d0, d1) -> (), is_simplified: false )")); } @@ -263,7 +260,8 @@ TEST_F(IndexingMapTest, Composition_ProducerAndConsumerHaveConstraints) { } TEST_F(IndexingMapTest, Composition_RTVar) { - auto zero_dim_map = AffineMap::get(&mlir_context_); + auto zero_dim_map = + AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0, &mlir_context_); std::vector rt_vars{ RTVar{Interval{0, 0}, /*instr=*/nullptr, zero_dim_map}, @@ -279,13 +277,10 @@ TEST_F(IndexingMapTest, Composition_RTVar) { IndexingMap consumer( ParseAffineMap("(d0, d1)[s0] -> (0, d1, s0)", &mlir_context_), {DimVar{{0, 0}}, DimVar{{0, 1}}}, {RangeVar{0, 31}}, {}); - printer_.SetSymbolName(0, "s"); - printer_.SetSymbolName(1, "rt_0"); - printer_.SetSymbolName(2, "rt_1"); - printer_.SetSymbolName(3, "rt_2"); auto composed = ComposeIndexingMaps(consumer, producer); - EXPECT_THAT(composed.ToString(printer_), MatchIndexingString(R"( + EXPECT_THAT(ToString(composed, {"d0", "d1"}, {"s", "rt_0", "rt_1", "rt_2"}), + MatchIndexingString(R"( (d0, d1)[s, rt_0, rt_1, rt_2] -> (rt_0, d1 + rt_1, s + rt_2), domain: d0 in [0, 0], @@ -293,19 +288,20 @@ TEST_F(IndexingMapTest, Composition_RTVar) { s in [0, 31], rt_0 in [0, 0], hlo: NULL, - () -> (), + (d0, d1) -> (), rt_1 in [0, 1], hlo: NULL, - () -> (), + (d0, d1) -> (), rt_2 in [0, 226], hlo: NULL, - () -> (), + (d0, d1) -> (), is_simplified: false )")); } TEST_F(IndexingMapTest, Composition_OnlyRTVars) { - auto zero_dim_map = AffineMap::get(&mlir_context_); + auto zero_dim_map = + AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0, &mlir_context_); IndexingMap producer( ParseAffineMap("(d0, d1)[s0, s1] -> (d0 + s0, d1 + 4 * s1)", @@ -322,13 +318,10 @@ TEST_F(IndexingMapTest, Composition_OnlyRTVars) { {RTVar({Interval{0, 25}, /*instr=*/nullptr, zero_dim_map}), RTVar({Interval{0, 16}, /*instr=*/nullptr, zero_dim_map})}); - printer_.SetSymbolName(0, "ps_0"); - printer_.SetSymbolName(1, "ps_1"); - printer_.SetSymbolName(2, "cs_0"); - printer_.SetSymbolName(3, "cs_1"); - auto composed = ComposeIndexingMaps(consumer, producer); - EXPECT_THAT(composed.ToString(printer_), MatchIndexingString(R"( + EXPECT_THAT( + ToString(composed, {"d0", "d1"}, {"ps_0", "ps_1", "cs_0", "cs_1"}), + MatchIndexingString(R"( (d0, d1)[ps_0, ps_1, cs_0, cs_1] -> (d0 + cs_0 * 2 + ps_0, d1 + cs_1 * 3 + ps_1 * 4), domain: @@ -336,16 +329,16 @@ TEST_F(IndexingMapTest, Composition_OnlyRTVars) { d1 in [0, 15], ps_0 in [0, 2], hlo: NULL, - () -> (), + (d0, d1) -> (), ps_1 in [0, 1], hlo: NULL, - () -> (), + (d0, d1) -> (), cs_0 in [0, 25], hlo: NULL, - () -> (), + (d0, d1) -> (), cs_1 in [0, 16], hlo: NULL, - () -> (), + (d0, d1) -> (), d0 + cs_0 * 2 in [0, 24], d1 + cs_1 * 3 in [0, 15], is_simplified: false @@ -604,7 +597,8 @@ TEST_F(IndexingMapTest, RemoveUnusedSymbols_ConstraintsWithManySymbols) { } TEST_F(IndexingMapTest, RemoveUnusedSymbols_ConstraintsWithRTVars) { - auto zero_dim_map = AffineMap::get(&mlir_context_); + auto zero_dim_map = + AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, &mlir_context_); IndexingMap indexing_map( ParseAffineMap("(d0)[s0, s1, s2, s3, s4] -> (d0 * 4 + s1 + s3 - 42)", &mlir_context_), @@ -624,7 +618,7 @@ TEST_F(IndexingMapTest, RemoveUnusedSymbols_ConstraintsWithRTVars) { s0 in [0, 1], s1 in [0, 3], hlo: NULL, - () -> (), + (d0) -> (), d0 * 4 + s0 + s1 in [24, 459], is_simplified: false )")); @@ -664,7 +658,7 @@ TEST_F(IndexingMapTest, ConstraintIntervalSimplification_Sum) { is_simplified: false )"); EXPECT_TRUE(indexing_map.Simplify()); - EXPECT_THAT(indexing_map.ToString(), MatchIndexingString(R"( + EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0) -> (d0), domain: d0 in [0, 99], @@ -685,7 +679,7 @@ TEST_F(IndexingMapTest, is_simplified: false )"); EXPECT_TRUE(indexing_map.Simplify()); - EXPECT_THAT(indexing_map.ToString(), MatchIndexingString(R"( + EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0)[s0, s1] -> (d0 * 6 + s0 * 3 + s1), domain: d0 in [0, 99], @@ -719,7 +713,7 @@ TEST_F(IndexingMapTest, ConstraintIntervalSimplification_Sum_GcdGreaterOne) { is_simplified: false )"); EXPECT_TRUE(indexing_map.Simplify()); - EXPECT_THAT(indexing_map.ToString(), MatchIndexingString(R"( + EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0)[s0] -> (d0 * 6 + s0 * 3), domain: d0 in [0, 99], @@ -738,7 +732,7 @@ TEST_F(IndexingMapTest, is_simplified: false )"); EXPECT_TRUE(indexing_map.Simplify()); - EXPECT_THAT(indexing_map.ToString(), MatchIndexingString(R"( + EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0) -> (d0), domain: d0 in [40, 95], @@ -757,7 +751,7 @@ TEST_F(IndexingMapTest, is_simplified: false )"); EXPECT_TRUE(indexing_map.Simplify()); - EXPECT_THAT(indexing_map.ToString(), MatchIndexingString(R"( + EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0)[s0] -> (d0), domain: d0 in [0, 99], @@ -777,7 +771,7 @@ TEST_F(IndexingMapTest, is_simplified: false )"); EXPECT_TRUE(indexing_map.Simplify()); - EXPECT_THAT(indexing_map.ToString(), MatchIndexingString(R"( + EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0)[s0] -> (d0), domain: d0 in [0, 99], @@ -796,7 +790,7 @@ TEST_F(IndexingMapTest, is_simplified: false )"); EXPECT_TRUE(indexing_map.Simplify()); - EXPECT_THAT(indexing_map.ToString(), MatchIndexingString(R"( + EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0) -> (d0), domain: d0 in [2, 4], @@ -815,7 +809,7 @@ TEST_F(IndexingMapTest, is_simplified: false )"); EXPECT_TRUE(indexing_map.Simplify()); - EXPECT_THAT(indexing_map.ToString(), MatchIndexingString(R"( + EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0)[s0] -> (d0), domain: d0 in [0, 99], @@ -835,7 +829,7 @@ TEST_F(IndexingMapTest, is_simplified: false )"); EXPECT_TRUE(indexing_map.Simplify()); - EXPECT_THAT(indexing_map.ToString(), MatchIndexingString(R"( + EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0)[s0] -> (d0), domain: d0 in [0, 99], @@ -858,7 +852,7 @@ TEST_F(IndexingMapTest, ConstraintMerge_Mod) { is_simplified: false )"); EXPECT_TRUE(indexing_map.Simplify()); - EXPECT_THAT(indexing_map.ToString(), MatchIndexingString(R"( + EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0)[s0, s1] -> (d0, s1, s0), domain: d0 in [0, 3], @@ -879,7 +873,7 @@ TEST_F(IndexingMapTest, AffineMapSimplification_ConstantDims) { is_simplified: false )"); EXPECT_TRUE(indexing_map.Simplify()); - EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( + EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0) -> (5), domain: d0 in [5, 5], @@ -927,7 +921,7 @@ TEST_F(IndexingMapTest, AffineMapSimplification_FloorDivRegression) { is_simplified: false )"); EXPECT_TRUE(indexing_map.Simplify()); - EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( + EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0, d1) -> (d0 floordiv 6), domain: d0 in [0, 11], @@ -944,7 +938,7 @@ TEST_F(IndexingMapTest, AffineMapSimplification_ModIsSub) { is_simplified: false )"); EXPECT_TRUE(indexing_map.Simplify()); - EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( + EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0) -> (d0 - 42), domain: d0 in [53, 71], @@ -960,7 +954,7 @@ TEST_F(IndexingMapTest, AffineMapSimplification_ModIsAdd) { is_simplified: false )"); EXPECT_TRUE(indexing_map.Simplify()); - EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( + EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0) -> (d0 + 5), domain: d0 in [-5, -1], @@ -986,7 +980,7 @@ TEST_F(IndexingMapTest, AffineMapSimplification_SubIsMod) { is_simplified: false )"); EXPECT_TRUE(indexing_map.Simplify()); - EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( + EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0)[s0] -> (d0 + s0 mod 3), domain: d0 in [0, 1], @@ -1004,7 +998,7 @@ TEST_F(IndexingMapTest, AffineMapSimplification_SubIsModMultiplied) { is_simplified: false )"); EXPECT_TRUE(indexing_map.Simplify()); - EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( + EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0)[s0] -> (d0 + (s0 mod 3) * 4 + s0 * 3), domain: d0 in [0, 1], @@ -1022,7 +1016,7 @@ TEST_F(IndexingMapTest, AffineMapSimplification_SubIsModSum) { is_simplified: false )"); EXPECT_TRUE(indexing_map.Simplify()); - EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( + EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0)[s0] -> (d0 + (s0 + 1) mod 3), domain: d0 in [0, 1], @@ -1041,7 +1035,7 @@ TEST_F(IndexingMapTest, is_simplified: false )"); EXPECT_TRUE(indexing_map.Simplify()); - EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( + EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0, d1) -> (d0, d1), domain: d0 in [0, 7], @@ -1062,7 +1056,7 @@ TEST_F(IndexingMapTest, AffineMapSimplification_DivsAndModsWithMultipliers) { is_simplified: false )"); EXPECT_TRUE(indexing_map.Simplify()); - EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( + EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0, d1, d2) -> (d0, d1, d2), domain: d0 in [0, 8], @@ -1084,7 +1078,7 @@ TEST_F(IndexingMapTest, is_simplified: false )"); EXPECT_TRUE(indexing_map.Simplify()); - EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( + EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0, d1, d2) -> (d0 * 2 + (d1 * 4 + d2) floordiv 8, (d1 * 4 + d2) mod 8), domain: @@ -1105,7 +1099,7 @@ TEST_F(IndexingMapTest, AffineMapSimplification_DivsAndModsWithReverse) { is_simplified: false )"); EXPECT_TRUE(indexing_map.Simplify()); - EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( + EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0, d1) -> (d0, d1), domain: d0 in [0, 7], @@ -1122,7 +1116,7 @@ TEST_F(IndexingMapTest, AffineMapSimplification_SimplifyReshape) { is_simplified: false )"); EXPECT_TRUE(indexing_map.Simplify()); - EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( + EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( ()[s0] -> (s0 * 128), domain: s0 in [0, 127], @@ -1140,7 +1134,7 @@ TEST_F(IndexingMapTest, AffineMapSimplification_SimplifyReshape2) { )"); ; EXPECT_TRUE(indexing_map.Simplify()); - EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( + EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0, d1) -> (d0 * 128 + d1), domain: d0 in [0, 1023], @@ -1159,7 +1153,7 @@ TEST_F(IndexingMapTest, AffineMapSimplification_SimplifyReshape3) { is_simplified: false )"); EXPECT_TRUE(indexing_map.Simplify()); - EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( + EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0, d1) -> (d0 * 4 + d1 * 512), domain: d0 in [0, 127], @@ -1177,7 +1171,7 @@ TEST_F(IndexingMapTest, is_simplified: false )"); EXPECT_FALSE(indexing_map.Simplify()); - EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( + EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0) -> ((-d0) mod 2), domain: d0 in [0, 127], @@ -1200,7 +1194,7 @@ TEST_F(IndexingMapTest, AffineMapSimplification_SimplifyBitcastAndBack) { is_simplified: false )"); EXPECT_TRUE(indexing_map.Simplify()); - EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( + EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0, d1) -> (d0 * 512 + d1 * 4), domain: d0 in [0, 3071], @@ -1218,7 +1212,7 @@ TEST_F(IndexingMapTest, AffineMapSimplification_SimplifyReshape_Regression) { is_simplified: false )"); EXPECT_TRUE(indexing_map.Simplify()); - EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( + EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( ()[s0] -> (((s0 * 64) floordiv 715) * 715 + (s0 * 128) mod 715), domain: s0 in [0, 127], @@ -1234,7 +1228,7 @@ TEST_F(IndexingMapTest, AffineMapSimplification_DivsInSequence) { is_simplified: false )"); EXPECT_TRUE(indexing_map.Simplify()); - EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( + EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( ()[s0] -> (s0), domain: s0 in [0, 1233], @@ -1251,7 +1245,7 @@ TEST_F(IndexingMapTest, AffineMapSimplification_DivDiv) { is_simplified: false )"); EXPECT_TRUE(indexing_map.Simplify()); - EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( + EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( ()[s0, s1] -> ((s0 * 128 + s1) floordiv 192), domain: s0 in [0, 1233], @@ -1268,7 +1262,7 @@ TEST_F(IndexingMapTest, AffineMapSimplification_DivSumConstant) { is_simplified: false )"); EXPECT_TRUE(indexing_map.Simplify()); - EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( + EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( ()[s0] -> ((s0 * 2 + 3) floordiv 6), domain: s0 in [0, 1233], @@ -1312,7 +1306,7 @@ TEST_F(IndexingMapTest, AffineMapSimplification_ExtractFromMod) { is_simplified: false )"); EXPECT_TRUE(indexing_map.Simplify()); - EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( + EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( ()[s0, s1, s2, s3] -> ( ((s0 * 114688 + s3 * 128 + s2) mod 5000) * 4 + s1 ), @@ -1336,7 +1330,7 @@ TEST_F(IndexingMapTest, is_simplified: false )"); EXPECT_TRUE(indexing_map.Simplify()); - EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( + EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( ()[s0, s1] -> ( s0 * 4 + s1 floordiv 32 ), @@ -1359,7 +1353,7 @@ TEST_F(IndexingMapTest, RescaleSymbols_Simple) { is_simplified: false )"); EXPECT_TRUE(indexing_map.RescaleSymbols()); - EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( + EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0)[s0, s1, s2] -> (s2, d0, s1, s0), domain: d0 in [0, 3], @@ -1384,7 +1378,7 @@ TEST_F(IndexingMapTest, RescaleSymbols_WithShift) { // [BEFORE] Allowed values for s0: 3, 9, 15, ..., 39 = (6 * 6 + 3) // [AFTER] Allowed values for s0: 0, 1, 2, ..., 6 EXPECT_TRUE(indexing_map.RescaleSymbols()); - EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( + EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0)[s0, s1, s2] -> (s2, d0, s1, s0 * 6 + 3), domain: d0 in [0, 3], @@ -1408,7 +1402,7 @@ TEST_F(IndexingMapTest, RescaleSymbols_TwoModConstraints) { is_simplified: false )"); EXPECT_TRUE(indexing_map.RescaleSymbols()); - EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( + EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0)[s0, s1, s2] -> (s2, d0, s1, s0), domain: d0 in [0, 3], @@ -1432,7 +1426,7 @@ TEST_F(IndexingMapTest, RescaleSymbols_RescaledSymbolInOtherNonModConstraint) { is_simplified: false )"); EXPECT_TRUE(indexing_map.RescaleSymbols()); - EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( + EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0)[s0, s1, s2] -> (s2, d0, s1, s0 * 6 + 3), domain: d0 in [0, 3], @@ -1687,8 +1681,7 @@ TEST_F(IndexingMapTest, ReplaceConstantRTVars_ScalarConstant) { EXPECT_TRUE(indexing_map.Simplify()); indexing_map.RemoveUnusedSymbols(); - EXPECT_THAT(indexing_map.ToString(printer_), - MatchIndexingString("() -> (42)")); + EXPECT_THAT(ToString(indexing_map), MatchIndexingString("() -> (42)")); } TEST_F(IndexingMapTest, ReplaceConstantRTVars_StaticIndexIntoTensorConstant) { @@ -1714,8 +1707,7 @@ TEST_F(IndexingMapTest, ReplaceConstantRTVars_StaticIndexIntoTensorConstant) { EXPECT_TRUE(indexing_map.Simplify()); indexing_map.RemoveUnusedSymbols(); - EXPECT_THAT(indexing_map.ToString(printer_), - MatchIndexingString("() -> (13)")); + EXPECT_THAT(ToString(indexing_map), MatchIndexingString("() -> (13)")); } TEST_F(IndexingMapTest, ReplaceConstantRTVars_NonFoldableTensor) { @@ -1764,7 +1756,7 @@ TEST_F(IndexingMapTest, ReplaceConstantRTVars_Iota) { EXPECT_TRUE(indexing_map.Simplify()); indexing_map.RemoveUnusedSymbols(); - EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( + EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0) -> (d0, d0), domain: d0 in [0, 255], @@ -1795,7 +1787,7 @@ TEST_F(IndexingMapTest, ReplaceConstantRTVars_IotaAsConstant) { EXPECT_TRUE(indexing_map.Simplify()); indexing_map.RemoveUnusedSymbols(); - EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( + EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0) -> (d0, 7), domain: d0 in [0, 255], @@ -1828,7 +1820,7 @@ TEST_F(IndexingMapTest, ReplaceConstantRTVars_ConstraintsGetUpdated) { EXPECT_TRUE(indexing_map.Simplify()); indexing_map.RemoveUnusedSymbols(); - EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( + EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0) -> (d0, d0), domain: d0 in [0, 254], @@ -1863,7 +1855,7 @@ TEST_F(IndexingMapTest, ReplaceConstantRTVars_Broadcast) { EXPECT_TRUE(indexing_map.Simplify()); indexing_map.RemoveUnusedSymbols(); - EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( + EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0) -> (d0, 11), domain: d0 in [0, 31], @@ -1906,7 +1898,7 @@ TEST_F(IndexingMapTest, ReplaceConstantRTVars_ChainedNoncomputeOps) { EXPECT_TRUE(indexing_map.Simplify()); indexing_map.RemoveUnusedSymbols(); - EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( + EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0) -> (d0, (d0 floordiv 12) * -4 + 8), domain: d0 in [0, 35], @@ -1940,7 +1932,7 @@ TEST_F(IndexingMapTest, ReplaceConstantRTVars_PartialRTVarRemoval) { EXPECT_TRUE(indexing_map.Simplify()); - EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( + EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0)[s0] -> (d0, s0), domain: d0 in [0, 23], @@ -1979,7 +1971,7 @@ TEST_F(IndexingMapTest, ReplaceConstantRTVars_Add) { EXPECT_TRUE(indexing_map.Simplify()); indexing_map.RemoveUnusedSymbols(); - EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( + EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0) -> (d0, d0 * 2 + 42), domain: d0 in [0, 11], @@ -2020,7 +2012,7 @@ TEST_F(IndexingMapTest, ReplaceConstantRTVars_Multiply) { EXPECT_TRUE(indexing_map.Simplify()); indexing_map.RemoveUnusedSymbols(); - EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( + EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0) -> (d0, (-d0 + 11) * d0), domain: d0 in [0, 11], @@ -2057,7 +2049,7 @@ TEST_F(IndexingMapTest, ReplaceConstantRTVars_PartiallyOptimizableAdd) { EXPECT_TRUE(indexing_map.Simplify()); - EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( + EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0)[s0] -> (d0, d0 * 2 + s0), domain: d0 in [0, 11], diff --git a/third_party/xla/xla/service/gpu/model/indexing_test_utils.h b/third_party/xla/xla/service/gpu/model/indexing_test_utils.h index c4bb2910fa36cd..2e880208f19486 100644 --- a/third_party/xla/xla/service/gpu/model/indexing_test_utils.h +++ b/third_party/xla/xla/service/gpu/model/indexing_test_utils.h @@ -33,6 +33,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/service/gpu/model/indexing_analysis.h" #include "xla/service/gpu/model/indexing_map.h" +#include "xla/service/gpu/model/indexing_map_serialization.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/verified_hlo_module.h" @@ -49,7 +50,7 @@ MATCHER_P(MatchIndexingMap, indexing_string, "") { return false; } return ExplainMatchResult( - true, ApproximateMatch(indexing_string, arg.ToString()), result_listener); + true, ApproximateMatch(indexing_string, ToString(arg)), result_listener); } MATCHER_P(MatchIndexingString, indexing_string, "") { diff --git a/third_party/xla/xla/service/gpu/model/symbolic_tile.cc b/third_party/xla/xla/service/gpu/model/symbolic_tile.cc index f2e14e0c655bc0..612b97146e01e0 100644 --- a/third_party/xla/xla/service/gpu/model/symbolic_tile.cc +++ b/third_party/xla/xla/service/gpu/model/symbolic_tile.cc @@ -30,6 +30,7 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/log/check.h" #include "absl/log/log.h" +#include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "absl/types/span.h" #include "llvm/ADT/DenseMap.h" @@ -41,8 +42,8 @@ limitations under the License. #include "mlir/IR/MLIRContext.h" #include "mlir/Support/LLVM.h" #include "xla/service/gpu/model/affine_map_evaluator.h" -#include "xla/service/gpu/model/affine_map_printer.h" #include "xla/service/gpu/model/indexing_map.h" +#include "xla/service/gpu/model/indexing_map_serialization.h" namespace xla { namespace gpu { @@ -231,8 +232,7 @@ ExtractSizesAndStridesFromMultivariateSummation( std::optional maybe_size_and_stride = ExtractSizeAndStride(summand, dimension_intervals, symbol_intervals); if (!maybe_size_and_stride.has_value()) { - VLOG(1) << "Couldn't extract size and stride from " - << AffineMapPrinter().ToString(summand); + VLOG(1) << "Couldn't extract size and stride from " << ToString(summand); return std::nullopt; } sizes_and_strides.push_back(*maybe_size_and_stride); @@ -320,8 +320,8 @@ std::optional TryGetSizeExpressionRangeSize( // working well with concatenations. Nevertheless, we can take a look // later. VLOG(1) << "Attempted to combine strides but got dimension " - << AffineMapPrinter().ToString(size) << " with lower bound " - << interval.lower << " != 0"; + << ToString(size) << " with lower bound " << interval.lower + << " != 0"; return std::nullopt; } // We need to add 1 to the upper bound of the interval to describe the @@ -364,7 +364,7 @@ std::optional CombineStrides( for (const SizeAndStrideExpression& size_and_stride : sizes_and_strides) { if (size_and_stride.stride.getKind() != AffineExprKind::Constant) { VLOG(1) << "Attempted to combine non-constant stride: " - << AffineMapPrinter().ToString(size_and_stride.stride); + << ToString(size_and_stride.stride); return std::nullopt; } @@ -379,7 +379,7 @@ std::optional CombineStrides( size_and_stride.size.getKind() != AffineExprKind::DimId) { VLOG(1) << "Attempted to combine strides but got non-constant, " "non-dimension size " - << AffineMapPrinter().ToString(size_and_stride.size); + << ToString(size_and_stride.size); return std::nullopt; } } @@ -567,9 +567,8 @@ std::optional CombineSizesAndStrides( if (VLOG_IS_ON(1)) { for (const SizeAndStrideExpression& size_and_stride : sizes_and_strides) { LOG(INFO) << "CombineSizesAndStrides:"; - LOG(INFO) << "size: " << AffineMapPrinter().ToString(size_and_stride.size) - << " stride: " - << AffineMapPrinter().ToString(size_and_stride.stride); + LOG(INFO) << "size: " << ToString(size_and_stride.size) + << " stride: " << ToString(size_and_stride.stride); } } @@ -603,7 +602,6 @@ std::optional ExtractSizeAndStride( AffineExpr strided_indexing, absl::Span dimension_intervals, absl::Span symbol_intervals) { MLIRContext* ctx = strided_indexing.getContext(); - AffineMapPrinter printer; switch (strided_indexing.getKind()) { case AffineExprKind::DimId: @@ -711,9 +709,8 @@ std::optional TryIntersectConjointConstraints( auto& [result_expr, result_interval] = *result_it; result_interval = result_interval.Intersect(interval); if (!result_interval.IsFeasible()) { - AffineMapPrinter printer; VLOG(1) << "Got two incompatible intervals for expression " - << printer.ToString(expr); + << ToString(expr); return std::nullopt; } } else { @@ -866,15 +863,13 @@ bool ConstraintExpression::IsSatisfiedBy( return constraints_are_satisfied; } -std::string ConstraintExpression::ToString( - const AffineMapPrinter& printer) const { +std::string ConstraintExpression::ToString() const { std::stringstream ss; - Print(ss, printer); + Print(ss); return ss.str(); } -void ConstraintExpression::Print(std::ostream& out, - const AffineMapPrinter& printer) const { +void ConstraintExpression::Print(std::ostream& out) const { if (IsAlwaysSatisfied()) { out << "always satisfied"; } else if (is_satisfiable()) { @@ -886,11 +881,8 @@ void ConstraintExpression::Print(std::ostream& out, std::vector constraint_strings; constraint_strings.reserve(disjunction.size()); for (const auto& [expr, interval] : disjunction) { - std::stringstream ss; - printer.Print(ss, expr); - ss << " in "; - interval.Print(ss); - constraint_strings.push_back(ss.str()); + constraint_strings.push_back(absl::StrCat(xla::gpu::ToString(expr), + " in ", interval.ToString())); } std::sort(constraint_strings.begin(), constraint_strings.end()); conjunction_strings.push_back(absl::StrJoin(constraint_strings, " && ")); @@ -1019,7 +1011,7 @@ void ConstraintExpression::Simplify() { /*static*/ std::optional SymbolicTile::FromIndexingMap( IndexingMap indexing_map) { - VLOG(1) << "SymbolicTile::FromIndexingMap: " << indexing_map.ToString(); + VLOG(1) << "SymbolicTile::FromIndexingMap: " << indexing_map; // We do not handle indexing maps with pre-existing constraints for now. // Let's try to simplify the indexing map, because the constraints my be @@ -1030,7 +1022,7 @@ void ConstraintExpression::Simplify() { if (indexing_map.GetConstraintsCount() != 0) { VLOG(1) << "Deriving symbolic tile from indexing map with pre-existing " << "constraints might produce spurious constraints. Bailing out. " - << indexing_map.ToString(); + << indexing_map; return std::nullopt; } @@ -1104,9 +1096,8 @@ void ConstraintExpression::Simplify() { offset = offset + size * stride - stride; stride = -stride; } else if (!constant) { - AffineMapPrinter printer; VLOG(1) << "Unexpected non-constant stride expression: " - << printer.ToString(stride); + << xla::gpu::ToString(stride); } } @@ -1139,45 +1130,35 @@ void ConstraintExpression::Simplify() { /*rt_vars=*/indexing_map.GetRTVars()); tile_map.RemoveUnusedSymbols(); CHECK_EQ(tile_map.GetRangeVarsCount(), 0); - VLOG(1) << "tile_map: " << tile_map.ToString(); + VLOG(1) << "tile_map: " << tile_map; constraints.Simplify(); return SymbolicTile(std::move(tile_map), std::move(constraints)); } -std::string SymbolicTile::RtVarsToString( - const AffineMapPrinter& printer) const { - std::string s; - std::stringstream ss(s); - PrintRTVars(tile_map_.GetRTVars(), /*first_rt_var_symbol_index=*/0, ss, - printer); - return ss.str(); -} - -std::string SymbolicTile::ToString(const AffineMapPrinter& printer) const { - std::string s; - std::stringstream ss(s); - Print(ss, printer); +std::string SymbolicTile::ToString() const { + std::stringstream ss; + Print(ss); return ss.str(); } -void SymbolicTile::Print(std::ostream& out, - const AffineMapPrinter& printer) const { +void SymbolicTile::Print(std::ostream& out) const { out << "Symbolic tile with \n"; - out << "\toffset_map: "; - printer.Print(out, offset_map()); - out << "\n\tsize_map: "; - printer.Print(out, size_map()); - out << "\n\tstride_map: "; - printer.Print(out, stride_map()); + out << "\toffset_map: " << offset_map(); + out << "\n\tsize_map: " << size_map(); + out << "\n\tstride_map: " << stride_map(); const std::vector& rt_vars = tile_map_.GetRTVars(); if (!rt_vars.empty()) { out << "\n\trt_vars: "; - PrintRTVars(rt_vars, /*first_rt_var_symbol_index=*/0, out, printer); + for (const auto& [index, rt_var] : llvm::enumerate(rt_vars)) { + out << 's' << index << " in " << rt_var.feasible_values << ", hlo: " + << (rt_var.hlo == nullptr ? "NULL" : rt_var.hlo->ToString()) << ", " + << rt_var.map << ", "; + } } if (!constraints_.IsAlwaysSatisfied()) { out << "\n\tconstraints: "; - constraints_.Print(out, printer); + constraints_.Print(out); } } diff --git a/third_party/xla/xla/service/gpu/model/symbolic_tile.h b/third_party/xla/xla/service/gpu/model/symbolic_tile.h index a86cd363daf1e8..c8e19f27112e4e 100644 --- a/third_party/xla/xla/service/gpu/model/symbolic_tile.h +++ b/third_party/xla/xla/service/gpu/model/symbolic_tile.h @@ -29,7 +29,6 @@ limitations under the License. #include "llvm/ADT/SmallVector.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" -#include "xla/service/gpu/model/affine_map_printer.h" #include "xla/service/gpu/model/indexing_map.h" namespace xla { @@ -110,10 +109,9 @@ class ConstraintExpression { return disjoint_conjoint_constraints_; } - std::string ToString( - const AffineMapPrinter& printer = AffineMapPrinter()) const; + std::string ToString() const; - void Print(std::ostream& out, const AffineMapPrinter& printer) const; + void Print(std::ostream& out) const; // Simplifies the constraint expression. // @@ -285,12 +283,9 @@ class SymbolicTile { static std::optional FromIndexingMap(IndexingMap indexing_map); // For printing in tests. - std::string RtVarsToString( - const AffineMapPrinter& printer = AffineMapPrinter()) const; - std::string ToString( - const AffineMapPrinter& printer = AffineMapPrinter()) const; + std::string ToString() const; - void Print(std::ostream& out, const AffineMapPrinter& printer) const; + void Print(std::ostream& out) const; mlir::AffineMap offset_map() const; mlir::AffineMap size_map() const; diff --git a/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis.cc b/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis.cc index a3f2e77c97656c..43358a5781ce1d 100644 --- a/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis.cc +++ b/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis.cc @@ -50,9 +50,9 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/gpu/hlo_traversal.h" -#include "xla/service/gpu/model/affine_map_printer.h" #include "xla/service/gpu/model/indexing_analysis.h" #include "xla/service/gpu/model/indexing_map.h" +#include "xla/service/gpu/model/indexing_map_serialization.h" #include "xla/service/gpu/model/symbolic_tile.h" #include "xla/service/gpu/model/symbolic_tiled_hlo_instruction.h" #include "xla/service/gpu/model/tiled_hlo_computation.h" @@ -132,7 +132,7 @@ absl::StatusOr ComputeTileOffsetIndexing( })) { return absl::FailedPreconditionError( absl::StrCat("Symbol lower bound is not zero. ", - tiled_hlo.indexing_map().ToString())); + ToString(tiled_hlo.indexing_map()))); } std::vector symbol_lower_bounds( @@ -255,7 +255,7 @@ FusionDecision ShouldProceedWithSymbolicTileDerivation( if (!reshape_symbolic_tile.has_value()) { return FusionDecision::Forbid("Bailing out on reshape ") << hlo->ToString() << " with indexing map " - << reshape_indexing_map.ToString(); + << ToString(reshape_indexing_map); } } @@ -292,13 +292,13 @@ SetSymbolicTilesAndComputeConstraints( auto symbolic_tile = SymbolicTile::FromIndexingMap(indexing_map); if (!symbolic_tile.has_value()) { return FusionDecision::Forbid("Failed to compute symbolic tile for ") - << indexing_map.ToString() << " for HLO " << hlo->ToString(); + << ToString(indexing_map) << " for HLO " << hlo->ToString(); } if (!symbolic_tile->is_satisfiable()) { return FusionDecision::Forbid("Symbolic tile ") << symbolic_tile->ToString() << " is not satisfiable for " - << indexing_map.ToString() << " for HLO " << hlo->ToString(); + << ToString(indexing_map) << " for HLO " << hlo->ToString(); } constraints = ConstraintExpression::And(std::move(constraints), @@ -592,8 +592,7 @@ SymbolicTileAnalysis::ComputeTiledHloInstructions( output_tiling_info.num_output_tiles_per_dim); } -std::string SymbolicTileAnalysis::ToString( - const AffineMapPrinter& printer) const { +std::string SymbolicTileAnalysis::ToString() const { std::stringstream ss; NameUniquer name_uniquer("_"); absl::flat_hash_map tile_names; diff --git a/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis.h b/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis.h index 775de1670f51ea..e8e0cea2fef4b8 100644 --- a/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis.h +++ b/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis.h @@ -31,7 +31,6 @@ limitations under the License. #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/service/gpu/hlo_traversal.h" -#include "xla/service/gpu/model/affine_map_printer.h" #include "xla/service/gpu/model/symbolic_tile.h" #include "xla/service/gpu/model/symbolic_tiled_hlo_instruction.h" #include "xla/service/gpu/model/tiled_hlo_computation.h" @@ -141,8 +140,7 @@ class SymbolicTileAnalysis { // Returns a string representation of the analysis. Used only for error // messages and debugging. - std::string ToString( - const AffineMapPrinter& printer = AffineMapPrinter()) const; + std::string ToString() const; // Returns a list of tilings for the symbolic tiled HLO computation of the // analysis that are expected to perform well. diff --git a/third_party/xla/xla/service/gpu/model/symbolic_tiled_hlo_instruction.cc b/third_party/xla/xla/service/gpu/model/symbolic_tiled_hlo_instruction.cc index 4a6c067638cebd..fb687372b9d0be 100644 --- a/third_party/xla/xla/service/gpu/model/symbolic_tiled_hlo_instruction.cc +++ b/third_party/xla/xla/service/gpu/model/symbolic_tiled_hlo_instruction.cc @@ -49,7 +49,7 @@ std::string SymbolicTiledHloInstruction::ToString() const { std::stringstream ss; ss << "\thlo: " << hlo_->ToString() << "\n"; ss << "\t" << symbolic_tile().ToString() << "\n"; - ss << "\tindexing map: " << indexing_map_.ToString() << "\n"; + ss << "\tindexing map: " << indexing_map_ << "\n"; return ss.str(); } diff --git a/third_party/xla/xla/service/gpu/model/tiled_hlo_instruction.cc b/third_party/xla/xla/service/gpu/model/tiled_hlo_instruction.cc index e68db3040c816a..997556007fbcb6 100644 --- a/third_party/xla/xla/service/gpu/model/tiled_hlo_instruction.cc +++ b/third_party/xla/xla/service/gpu/model/tiled_hlo_instruction.cc @@ -31,6 +31,7 @@ limitations under the License. #include "llvm/ADT/SmallVector.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/service/gpu/model/indexing_map.h" +#include "xla/service/gpu/model/indexing_map_serialization.h" #include "xla/service/gpu/model/tiled_hlo_computation.h" #include "xla/util.h" #include "tsl/platform/errors.h" @@ -67,7 +68,7 @@ absl::Status VerifyTiledHloInstructionConstructorPreconditions( return absl::InvalidArgumentError(absl::StrFormat( "tile_offsets_indexing must have the same number of results as the " "rank of the hlo shape. tile_offsets_indexing = %s, hlo = %s", - tile_offsets_indexing->ToString(), hlo->ToString())); + ToString(*tile_offsets_indexing), hlo->ToString())); } return absl::OkStatus(); @@ -97,8 +98,9 @@ std::string TiledHloInstruction::ToString() const { ss << "\ttile_sizes: (" << absl::StrJoin(tile_sizes_, ", ") << ")\n"; ss << "\ttile_strides: (" << absl::StrJoin(tile_strides_, ", ") << ")\n"; ss << "\ttile_offsets_indexing: " - << (tile_offsets_indexing_.has_value() ? tile_offsets_indexing_->ToString() - : "nullopt"); + << (tile_offsets_indexing_.has_value() + ? gpu::ToString(*tile_offsets_indexing_) + : "nullopt"); return ss.str(); } From 39560bac6bf0f1444f58c9cb7d8678930e95da60 Mon Sep 17 00:00:00 2001 From: "Dimitar (Mitko) Asenov" Date: Thu, 26 Sep 2024 06:44:24 -0700 Subject: [PATCH 315/483] [XLA:GPU] Do not fuse custom fusions in horizontal_input_fusion. PiperOrigin-RevId: 679128773 --- .../gpu/transforms/horizontal_input_fusion.cc | 3 +- .../horizontal_input_fusion_test.cc | 32 +++++++++++++++++++ 2 files changed, 34 insertions(+), 1 deletion(-) diff --git a/third_party/xla/xla/service/gpu/transforms/horizontal_input_fusion.cc b/third_party/xla/xla/service/gpu/transforms/horizontal_input_fusion.cc index befe869ac072df..bb09aa02c77de5 100644 --- a/third_party/xla/xla/service/gpu/transforms/horizontal_input_fusion.cc +++ b/third_party/xla/xla/service/gpu/transforms/horizontal_input_fusion.cc @@ -95,7 +95,8 @@ std::vector FindAndSortFusionCandidates( // Find out the input fusion instructions whose only consumer is `consumer`. // This guarantees that fusing these candidates will never create cycles, as // there is no back edge. - if (IsInputFusibleReduction(*predecessor) && + if (!predecessor->IsCustomFusion() && + IsInputFusibleReduction(*predecessor) && IsConsumerTheOnlyNonRootUser(*predecessor, *consumer)) { if (fusion_instr_set.insert(predecessor).second) { fusion_instrs.push_back(predecessor); diff --git a/third_party/xla/xla/service/gpu/transforms/horizontal_input_fusion_test.cc b/third_party/xla/xla/service/gpu/transforms/horizontal_input_fusion_test.cc index 5fc1a54acd8d53..dc8f5f3bfa1b5f 100644 --- a/third_party/xla/xla/service/gpu/transforms/horizontal_input_fusion_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/horizontal_input_fusion_test.cc @@ -265,6 +265,38 @@ TEST_F(HorizontalInputFusionTest, NonfusionInstrs) { GmockMatch(m::Tuple(m::Reduce(), m::Reduce()))); } +TEST_F(HorizontalInputFusionTest, DoesNotFuseCustomFusions) { + auto module = ParseAndReturnVerifiedModule(R"( +max { + p0 = f16[] parameter(0) + p1 = f16[] parameter(1) + ROOT max = f16[] maximum(p0, p1) +} + +triton_a { + p = f16[128,256] parameter(0) + c = f16[] constant(0) + ROOT n = f16[128] reduce(p, c), dimensions={1}, to_apply=max +} + +triton_b { + p = f16[128,256] parameter(0) + c = f16[] constant(0) + ROOT n = f16[128] reduce(p, c), dimensions={1}, to_apply=max +} + + ENTRY entry_computation { + p = f16[128,256] parameter(0) + fa = f16[128] fusion(p), kind=kCustom, calls=triton_a + fb = f16[128] fusion(p), kind=kCustom, calls=triton_b + ROOT tuple = (f16[128], f16[128]) tuple(fa, fb) + } +)") + .value(); + + EXPECT_FALSE(horizontal_input_fusion_.Run(module.get()).value()); +} + } // namespace } // namespace gpu } // namespace xla From 1d22e69e361183e05071a94ebc327ac0ef0ed020 Mon Sep 17 00:00:00 2001 From: Bart Chrzaszcz Date: Thu, 26 Sep 2024 06:44:34 -0700 Subject: [PATCH 316/483] Move `has_backend_config` check to parent inliner class. PiperOrigin-RevId: 679128822 --- third_party/xla/xla/service/call_inliner.cc | 1 + third_party/xla/xla/service/spmd/shardy/shardy_call_inliner.cc | 1 - 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/xla/xla/service/call_inliner.cc b/third_party/xla/xla/service/call_inliner.cc index 0605fbd6457ff7..a879e560c1cd17 100644 --- a/third_party/xla/xla/service/call_inliner.cc +++ b/third_party/xla/xla/service/call_inliner.cc @@ -160,6 +160,7 @@ CallInliner::Inline(HloInstruction* call) { bool CallInliner::IsInlineableCallOp(HloInstruction* instruction) const { return instruction->opcode() == HloOpcode::kCall && + !instruction->has_backend_config() && !instruction->parent()->IsAsyncComputation(); } diff --git a/third_party/xla/xla/service/spmd/shardy/shardy_call_inliner.cc b/third_party/xla/xla/service/spmd/shardy/shardy_call_inliner.cc index 9f863e23a6715d..c7564dbf3ed140 100644 --- a/third_party/xla/xla/service/spmd/shardy/shardy_call_inliner.cc +++ b/third_party/xla/xla/service/spmd/shardy/shardy_call_inliner.cc @@ -23,7 +23,6 @@ namespace xla { bool ShardyCallInliner::IsInlineableCallOp(HloInstruction* instruction) const { return CallInliner::IsInlineableCallOp(instruction) && - !instruction->has_backend_config() && !(instruction->GetModule()->config().use_shardy_partitioner() && absl::StrContains(instruction->to_apply()->name(), "shmap_body")); } From 7d3ab91b8e2bdf936d7cf2ef785a04e78733f34f Mon Sep 17 00:00:00 2001 From: Ilia Sergachev Date: Thu, 26 Sep 2024 06:51:32 -0700 Subject: [PATCH 317/483] PR #17580: Algebraic simplifier: optimize comparisons of all non-negative instructions to zero. Imported from GitHub PR https://github.com/openxla/xla/pull/17580 PR stacked with https://github.com/openxla/xla/pull/17579 Copybara import of the project: -- 02c09a8dd5bb62ffd3729a23813a0e66f672a5a3 by Ilia Sergachev : Algebraic simplifier: mark iota non-negative. -- 4735edc2bac278ea1e87035f128a2f5d0f2a7a59 by Ilia Sergachev : Fix unrelated clang-format issues to make CI happy -- 94947974244caa09eff280647491872207be144e by Ilia Sergachev : Algebraic simplifier: optimize comparisons of all non-negative instructions to zero. Merging this change closes #17580 PiperOrigin-RevId: 679130659 --- .../xla/xla/service/algebraic_simplifier.cc | 8 ++++---- .../xla/xla/service/algebraic_simplifier_test.cc | 15 +++++++++++++++ 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/third_party/xla/xla/service/algebraic_simplifier.cc b/third_party/xla/xla/service/algebraic_simplifier.cc index ff9a2f688cc874..8e392d6a972c91 100644 --- a/third_party/xla/xla/service/algebraic_simplifier.cc +++ b/third_party/xla/xla/service/algebraic_simplifier.cc @@ -5096,16 +5096,16 @@ absl::Status AlgebraicSimplifierVisitor::HandleCompare( } if (compare->comparison_direction() == ComparisonDirection::kLt && - lhs->opcode() == HloOpcode::kIota && IsAll(rhs, 0)) { + IsNonNegative(lhs, options_) && IsAll(rhs, 0)) { return ReplaceInstruction(compare, MakeScalarLike(compare, false)); } else if (compare->comparison_direction() == ComparisonDirection::kGt && - IsAll(lhs, 0) && rhs->opcode() == HloOpcode::kIota) { + IsAll(lhs, 0) && IsNonNegative(rhs, options_)) { return ReplaceInstruction(compare, MakeScalarLike(compare, false)); } else if (compare->comparison_direction() == ComparisonDirection::kGe && - lhs->opcode() == HloOpcode::kIota && IsAll(rhs, 0)) { + IsNonNegative(lhs, options_) && IsAll(rhs, 0)) { return ReplaceInstruction(compare, MakeScalarLike(compare, true)); } else if (compare->comparison_direction() == ComparisonDirection::kLe && - IsAll(lhs, 0) && rhs->opcode() == HloOpcode::kIota) { + IsAll(lhs, 0) && IsNonNegative(rhs, options_)) { return ReplaceInstruction(compare, MakeScalarLike(compare, true)); } if (lhs == rhs && diff --git a/third_party/xla/xla/service/algebraic_simplifier_test.cc b/third_party/xla/xla/service/algebraic_simplifier_test.cc index ea67d07a141967..430f445d7fa8c2 100644 --- a/third_party/xla/xla/service/algebraic_simplifier_test.cc +++ b/third_party/xla/xla/service/algebraic_simplifier_test.cc @@ -8985,6 +8985,21 @@ TEST_F(AlgebraicSimplifierTest, CompareIota) { GmockMatch(m::Broadcast(m::ConstantScalar(false)))); } +TEST_F(AlgebraicSimplifierTest, CompareAbsLtZeroBecomesFalse) { + // |x| < 0 -> false + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(R"( +m { + p = s32[5] parameter(0) + a = s32[5] abs(p) + z = s32[] constant(0) + b = s32[5] broadcast(z) + ROOT r = pred[5] compare(a, b), direction=LT +})")); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).value()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::Broadcast(m::ConstantScalar(false)))); +} + TEST_F(AlgebraicSimplifierTest, CompareLtZero) { const char* kModuleStr = R"( HloModule m From c87444d00d87064005dcf6af72a7ba380d1ef2a9 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 26 Sep 2024 07:12:05 -0700 Subject: [PATCH 318/483] Integrate LLVM at llvm/llvm-project@29b92d07746f Updates LLVM usage to match [29b92d07746f](https://github.com/llvm/llvm-project/commit/29b92d07746f) PiperOrigin-RevId: 679135781 --- third_party/llvm/generated.patch | 4094 -------- third_party/llvm/workspace.bzl | 4 +- third_party/shardy/temporary.patch | 8202 ++++++++--------- third_party/shardy/workspace.bzl | 4 +- .../xla/third_party/shardy/temporary.patch | 8202 ++++++++--------- .../xla/third_party/shardy/workspace.bzl | 4 +- 6 files changed, 8208 insertions(+), 12302 deletions(-) diff --git a/third_party/llvm/generated.patch b/third_party/llvm/generated.patch index de92cb4da63e52..509398da979e83 100644 --- a/third_party/llvm/generated.patch +++ b/third_party/llvm/generated.patch @@ -1,4095 +1 @@ Auto generated patch. Do not edit or delete it, even if empty. -diff -ruN --strip-trailing-cr a/llvm/docs/NVPTXUsage.rst b/llvm/docs/NVPTXUsage.rst ---- a/llvm/docs/NVPTXUsage.rst -+++ b/llvm/docs/NVPTXUsage.rst -@@ -127,6 +127,69 @@ - NVPTX Intrinsics - ================ - -+Address Space Conversion -+------------------------ -+ -+'``llvm.nvvm.ptr.*.to.gen``' Intrinsics -+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -+ -+Syntax: -+""""""" -+ -+These are overloaded intrinsics. You can use these on any pointer types. -+ -+.. code-block:: llvm -+ -+ declare ptr @llvm.nvvm.ptr.global.to.gen.p0.p1(ptr addrspace(1)) -+ declare ptr @llvm.nvvm.ptr.shared.to.gen.p0.p3(ptr addrspace(3)) -+ declare ptr @llvm.nvvm.ptr.constant.to.gen.p0.p4(ptr addrspace(4)) -+ declare ptr @llvm.nvvm.ptr.local.to.gen.p0.p5(ptr addrspace(5)) -+ -+Overview: -+""""""""" -+ -+The '``llvm.nvvm.ptr.*.to.gen``' intrinsics convert a pointer in a non-generic -+address space to a generic address space pointer. -+ -+Semantics: -+"""""""""" -+ -+These intrinsics modify the pointer value to be a valid generic address space -+pointer. -+ -+ -+'``llvm.nvvm.ptr.gen.to.*``' Intrinsics -+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -+ -+Syntax: -+""""""" -+ -+These are overloaded intrinsics. You can use these on any pointer types. -+ -+.. code-block:: llvm -+ -+ declare ptr addrspace(1) @llvm.nvvm.ptr.gen.to.global.p1.p0(ptr) -+ declare ptr addrspace(3) @llvm.nvvm.ptr.gen.to.shared.p3.p0(ptr) -+ declare ptr addrspace(4) @llvm.nvvm.ptr.gen.to.constant.p4.p0(ptr) -+ declare ptr addrspace(5) @llvm.nvvm.ptr.gen.to.local.p5.p0(ptr) -+ -+Overview: -+""""""""" -+ -+The '``llvm.nvvm.ptr.gen.to.*``' intrinsics convert a pointer in the generic -+address space to a pointer in the target address space. Note that these -+intrinsics are only useful if the address space of the target address space of -+the pointer is known. It is not legal to use address space conversion -+intrinsics to convert a pointer from one non-generic address space to another -+non-generic address space. -+ -+Semantics: -+"""""""""" -+ -+These intrinsics modify the pointer value to be a valid pointer in the target -+non-generic address space. -+ -+ - Reading PTX Special Registers - ----------------------------- - -diff -ruN --strip-trailing-cr a/llvm/docs/ReleaseNotes.rst b/llvm/docs/ReleaseNotes.rst ---- a/llvm/docs/ReleaseNotes.rst -+++ b/llvm/docs/ReleaseNotes.rst -@@ -63,24 +63,6 @@ - * ``llvm.nvvm.bitcast.d2ll`` - * ``llvm.nvvm.bitcast.ll2d`` - --* Remove the following intrinsics which can be replaced with a funnel-shift: -- -- * ``llvm.nvvm.rotate.b32`` -- * ``llvm.nvvm.rotate.right.b64`` -- * ``llvm.nvvm.rotate.b64`` -- --* Remove the following intrinsics which can be replaced with an -- ``addrspacecast``: -- -- * ``llvm.nvvm.ptr.gen.to.global`` -- * ``llvm.nvvm.ptr.gen.to.shared`` -- * ``llvm.nvvm.ptr.gen.to.constant`` -- * ``llvm.nvvm.ptr.gen.to.local`` -- * ``llvm.nvvm.ptr.global.to.gen`` -- * ``llvm.nvvm.ptr.shared.to.gen`` -- * ``llvm.nvvm.ptr.constant.to.gen`` -- * ``llvm.nvvm.ptr.local.to.gen`` -- - Changes to LLVM infrastructure - ------------------------------ - -diff -ruN --strip-trailing-cr a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td ---- a/llvm/include/llvm/IR/IntrinsicsNVVM.td -+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td -@@ -30,18 +30,10 @@ - // * llvm.nvvm.max.ui --> select(x ule y, x, y) - // * llvm.nvvm.max.ull --> ibid. - // * llvm.nvvm.h2f --> llvm.convert.to.fp16.f32 --// * llvm.nvvm.bitcast.f2i --> bitcast --// * llvm.nvvm.bitcast.i2f --> ibid. --// * llvm.nvvm.bitcast.d2ll --> ibid. --// * llvm.nvvm.bitcast.ll2d --> ibid. --// * llvm.nvvm.ptr.gen.to.global --> addrspacecast --// * llvm.nvvm.ptr.gen.to.shared --> ibid. --// * llvm.nvvm.ptr.gen.to.constant --> ibid. --// * llvm.nvvm.ptr.gen.to.local --> ibid. --// * llvm.nvvm.ptr.global.to.gen --> ibid. --// * llvm.nvvm.ptr.shared.to.gen --> ibid. --// * llvm.nvvm.ptr.constant.to.gen --> ibid. --// * llvm.nvvm.ptr.local.to.gen --> ibid. -+// * llvm.nvvm.bitcast.f2i --> bitcast -+// * llvm.nvvm.bitcast.i2f --> ibid. -+// * llvm.nvvm.bitcast.d2ll --> ibid. -+// * llvm.nvvm.bitcast.ll2d --> ibid. - - def llvm_global_ptr_ty : LLVMQualPointerType<1>; // (global)ptr - def llvm_shared_ptr_ty : LLVMQualPointerType<3>; // (shared)ptr -@@ -1610,6 +1602,40 @@ - [IntrReadMem, IntrArgMemOnly, IntrNoCallback, IntrWillReturn, NoCapture>], - "llvm.nvvm.ldg.global.p">; - -+// Use for generic pointers -+// - These intrinsics are used to convert address spaces. -+// - The input pointer and output pointer must have the same type, except for -+// the address-space. (This restriction is not enforced here as there is -+// currently no way to describe it). -+// - This complements the llvm bitcast, which can be used to cast one type -+// of pointer to another type of pointer, while the address space remains -+// the same. -+def int_nvvm_ptr_local_to_gen: DefaultAttrsIntrinsic<[llvm_anyptr_ty], -+ [llvm_anyptr_ty], [IntrNoMem, IntrSpeculatable], -+ "llvm.nvvm.ptr.local.to.gen">; -+def int_nvvm_ptr_shared_to_gen: DefaultAttrsIntrinsic<[llvm_anyptr_ty], -+ [llvm_anyptr_ty], [IntrNoMem, IntrSpeculatable], -+ "llvm.nvvm.ptr.shared.to.gen">; -+def int_nvvm_ptr_global_to_gen: DefaultAttrsIntrinsic<[llvm_anyptr_ty], -+ [llvm_anyptr_ty], [IntrNoMem, IntrSpeculatable], -+ "llvm.nvvm.ptr.global.to.gen">; -+def int_nvvm_ptr_constant_to_gen: DefaultAttrsIntrinsic<[llvm_anyptr_ty], -+ [llvm_anyptr_ty], [IntrNoMem, IntrSpeculatable], -+ "llvm.nvvm.ptr.constant.to.gen">; -+ -+def int_nvvm_ptr_gen_to_global: DefaultAttrsIntrinsic<[llvm_anyptr_ty], -+ [llvm_anyptr_ty], [IntrNoMem, IntrSpeculatable], -+ "llvm.nvvm.ptr.gen.to.global">; -+def int_nvvm_ptr_gen_to_shared: DefaultAttrsIntrinsic<[llvm_anyptr_ty], -+ [llvm_anyptr_ty], [IntrNoMem, IntrSpeculatable], -+ "llvm.nvvm.ptr.gen.to.shared">; -+def int_nvvm_ptr_gen_to_local: DefaultAttrsIntrinsic<[llvm_anyptr_ty], -+ [llvm_anyptr_ty], [IntrNoMem, IntrSpeculatable], -+ "llvm.nvvm.ptr.gen.to.local">; -+def int_nvvm_ptr_gen_to_constant: DefaultAttrsIntrinsic<[llvm_anyptr_ty], -+ [llvm_anyptr_ty], [IntrNoMem, IntrSpeculatable], -+ "llvm.nvvm.ptr.gen.to.constant">; -+ - // Used in nvvm internally to help address space opt and ptx code generation - // This is for params that are passed to kernel functions by pointer by-val. - def int_nvvm_ptr_gen_to_param: Intrinsic<[llvm_anyptr_ty], -@@ -4453,6 +4479,22 @@ - "llvm.nvvm.sust.p.3d.v4i32.trap">, - ClangBuiltin<"__nvvm_sust_p_3d_v4i32_trap">; - -+ -+def int_nvvm_rotate_b32 -+ : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty], -+ [IntrNoMem, IntrSpeculatable], "llvm.nvvm.rotate.b32">, -+ ClangBuiltin<"__nvvm_rotate_b32">; -+ -+def int_nvvm_rotate_b64 -+ : DefaultAttrsIntrinsic<[llvm_i64_ty], [llvm_i64_ty, llvm_i32_ty], -+ [IntrNoMem, IntrSpeculatable], "llvm.nvvm.rotate.b64">, -+ ClangBuiltin<"__nvvm_rotate_b64">; -+ -+def int_nvvm_rotate_right_b64 -+ : DefaultAttrsIntrinsic<[llvm_i64_ty], [llvm_i64_ty, llvm_i32_ty], -+ [IntrNoMem, IntrSpeculatable], "llvm.nvvm.rotate.right.b64">, -+ ClangBuiltin<"__nvvm_rotate_right_b64">; -+ - def int_nvvm_swap_lo_hi_b64 - : DefaultAttrsIntrinsic<[llvm_i64_ty], [llvm_i64_ty], - [IntrNoMem, IntrSpeculatable], "llvm.nvvm.swap.lo.hi.b64">, -diff -ruN --strip-trailing-cr a/llvm/lib/IR/AutoUpgrade.cpp b/llvm/lib/IR/AutoUpgrade.cpp ---- a/llvm/lib/IR/AutoUpgrade.cpp -+++ b/llvm/lib/IR/AutoUpgrade.cpp -@@ -1272,19 +1272,6 @@ - // nvvm.bitcast.{f2i,i2f,ll2d,d2ll} - Expand = - Name == "f2i" || Name == "i2f" || Name == "ll2d" || Name == "d2ll"; -- else if (Name.consume_front("rotate.")) -- // nvvm.rotate.{b32,b64,right.b64} -- Expand = Name == "b32" || Name == "b64" || Name == "right.b64"; -- else if (Name.consume_front("ptr.gen.to.")) -- // nvvm.ptr.gen.to.{local,shared,global,constant} -- Expand = Name.starts_with("local") || Name.starts_with("shared") || -- Name.starts_with("global") || Name.starts_with("constant"); -- else if (Name.consume_front("ptr.")) -- // nvvm.ptr.{local,shared,global,constant}.to.gen -- Expand = -- (Name.consume_front("local") || Name.consume_front("shared") || -- Name.consume_front("global") || Name.consume_front("constant")) && -- Name.starts_with(".to.gen"); - else - Expand = false; - -@@ -2271,117 +2258,6 @@ - } - } - --static Value *upgradeNVVMIntrinsicCall(StringRef Name, CallBase *CI, -- Function *F, IRBuilder<> &Builder) { -- Value *Rep = nullptr; -- -- if (Name == "abs.i" || Name == "abs.ll") { -- Value *Arg = CI->getArgOperand(0); -- Value *Neg = Builder.CreateNeg(Arg, "neg"); -- Value *Cmp = Builder.CreateICmpSGE( -- Arg, llvm::Constant::getNullValue(Arg->getType()), "abs.cond"); -- Rep = Builder.CreateSelect(Cmp, Arg, Neg, "abs"); -- } else if (Name.starts_with("atomic.load.add.f32.p") || -- Name.starts_with("atomic.load.add.f64.p")) { -- Value *Ptr = CI->getArgOperand(0); -- Value *Val = CI->getArgOperand(1); -- Rep = Builder.CreateAtomicRMW(AtomicRMWInst::FAdd, Ptr, Val, MaybeAlign(), -- AtomicOrdering::SequentiallyConsistent); -- } else if (Name.consume_front("max.") && -- (Name == "s" || Name == "i" || Name == "ll" || Name == "us" || -- Name == "ui" || Name == "ull")) { -- Value *Arg0 = CI->getArgOperand(0); -- Value *Arg1 = CI->getArgOperand(1); -- Value *Cmp = Name.starts_with("u") -- ? Builder.CreateICmpUGE(Arg0, Arg1, "max.cond") -- : Builder.CreateICmpSGE(Arg0, Arg1, "max.cond"); -- Rep = Builder.CreateSelect(Cmp, Arg0, Arg1, "max"); -- } else if (Name.consume_front("min.") && -- (Name == "s" || Name == "i" || Name == "ll" || Name == "us" || -- Name == "ui" || Name == "ull")) { -- Value *Arg0 = CI->getArgOperand(0); -- Value *Arg1 = CI->getArgOperand(1); -- Value *Cmp = Name.starts_with("u") -- ? Builder.CreateICmpULE(Arg0, Arg1, "min.cond") -- : Builder.CreateICmpSLE(Arg0, Arg1, "min.cond"); -- Rep = Builder.CreateSelect(Cmp, Arg0, Arg1, "min"); -- } else if (Name == "clz.ll") { -- // llvm.nvvm.clz.ll returns an i32, but llvm.ctlz.i64 returns an i64. -- Value *Arg = CI->getArgOperand(0); -- Value *Ctlz = Builder.CreateCall( -- Intrinsic::getDeclaration(F->getParent(), Intrinsic::ctlz, -- {Arg->getType()}), -- {Arg, Builder.getFalse()}, "ctlz"); -- Rep = Builder.CreateTrunc(Ctlz, Builder.getInt32Ty(), "ctlz.trunc"); -- } else if (Name == "popc.ll") { -- // llvm.nvvm.popc.ll returns an i32, but llvm.ctpop.i64 returns an -- // i64. -- Value *Arg = CI->getArgOperand(0); -- Value *Popc = Builder.CreateCall( -- Intrinsic::getDeclaration(F->getParent(), Intrinsic::ctpop, -- {Arg->getType()}), -- Arg, "ctpop"); -- Rep = Builder.CreateTrunc(Popc, Builder.getInt32Ty(), "ctpop.trunc"); -- } else if (Name == "h2f") { -- Rep = Builder.CreateCall( -- Intrinsic::getDeclaration(F->getParent(), Intrinsic::convert_from_fp16, -- {Builder.getFloatTy()}), -- CI->getArgOperand(0), "h2f"); -- } else if (Name.consume_front("bitcast.") && -- (Name == "f2i" || Name == "i2f" || Name == "ll2d" || -- Name == "d2ll")) { -- Rep = Builder.CreateBitCast(CI->getArgOperand(0), CI->getType()); -- } else if (Name == "rotate.b32") { -- Value *Arg = CI->getOperand(0); -- Value *ShiftAmt = CI->getOperand(1); -- Rep = Builder.CreateIntrinsic(Builder.getInt32Ty(), Intrinsic::fshl, -- {Arg, Arg, ShiftAmt}); -- } else if (Name == "rotate.b64") { -- Type *Int64Ty = Builder.getInt64Ty(); -- Value *Arg = CI->getOperand(0); -- Value *ZExtShiftAmt = Builder.CreateZExt(CI->getOperand(1), Int64Ty); -- Rep = Builder.CreateIntrinsic(Int64Ty, Intrinsic::fshl, -- {Arg, Arg, ZExtShiftAmt}); -- } else if (Name == "rotate.right.b64") { -- Type *Int64Ty = Builder.getInt64Ty(); -- Value *Arg = CI->getOperand(0); -- Value *ZExtShiftAmt = Builder.CreateZExt(CI->getOperand(1), Int64Ty); -- Rep = Builder.CreateIntrinsic(Int64Ty, Intrinsic::fshr, -- {Arg, Arg, ZExtShiftAmt}); -- } else if ((Name.consume_front("ptr.gen.to.") && -- (Name.starts_with("local") || Name.starts_with("shared") || -- Name.starts_with("global") || Name.starts_with("constant"))) || -- (Name.consume_front("ptr.") && -- (Name.consume_front("local") || Name.consume_front("shared") || -- Name.consume_front("global") || -- Name.consume_front("constant")) && -- Name.starts_with(".to.gen"))) { -- Rep = Builder.CreateAddrSpaceCast(CI->getArgOperand(0), CI->getType()); -- } else { -- Intrinsic::ID IID = shouldUpgradeNVPTXBF16Intrinsic(Name); -- if (IID != Intrinsic::not_intrinsic && -- !F->getReturnType()->getScalarType()->isBFloatTy()) { -- rename(F); -- Function *NewFn = Intrinsic::getDeclaration(F->getParent(), IID); -- SmallVector Args; -- for (size_t I = 0; I < NewFn->arg_size(); ++I) { -- Value *Arg = CI->getArgOperand(I); -- Type *OldType = Arg->getType(); -- Type *NewType = NewFn->getArg(I)->getType(); -- Args.push_back( -- (OldType->isIntegerTy() && NewType->getScalarType()->isBFloatTy()) -- ? Builder.CreateBitCast(Arg, NewType) -- : Arg); -- } -- Rep = Builder.CreateCall(NewFn, Args); -- if (F->getReturnType()->isIntegerTy()) -- Rep = Builder.CreateBitCast(Rep, F->getReturnType()); -- } -- } -- -- return Rep; --} -- - static Value *upgradeX86IntrinsicCall(StringRef Name, CallBase *CI, Function *F, - IRBuilder<> &Builder) { - LLVMContext &C = F->getContext(); -@@ -4332,8 +4208,85 @@ - - if (!IsX86 && Name == "stackprotectorcheck") { - Rep = nullptr; -+ } else if (IsNVVM && (Name == "abs.i" || Name == "abs.ll")) { -+ Value *Arg = CI->getArgOperand(0); -+ Value *Neg = Builder.CreateNeg(Arg, "neg"); -+ Value *Cmp = Builder.CreateICmpSGE( -+ Arg, llvm::Constant::getNullValue(Arg->getType()), "abs.cond"); -+ Rep = Builder.CreateSelect(Cmp, Arg, Neg, "abs"); -+ } else if (IsNVVM && (Name.starts_with("atomic.load.add.f32.p") || -+ Name.starts_with("atomic.load.add.f64.p"))) { -+ Value *Ptr = CI->getArgOperand(0); -+ Value *Val = CI->getArgOperand(1); -+ Rep = Builder.CreateAtomicRMW(AtomicRMWInst::FAdd, Ptr, Val, MaybeAlign(), -+ AtomicOrdering::SequentiallyConsistent); -+ } else if (IsNVVM && Name.consume_front("max.") && -+ (Name == "s" || Name == "i" || Name == "ll" || Name == "us" || -+ Name == "ui" || Name == "ull")) { -+ Value *Arg0 = CI->getArgOperand(0); -+ Value *Arg1 = CI->getArgOperand(1); -+ Value *Cmp = Name.starts_with("u") -+ ? Builder.CreateICmpUGE(Arg0, Arg1, "max.cond") -+ : Builder.CreateICmpSGE(Arg0, Arg1, "max.cond"); -+ Rep = Builder.CreateSelect(Cmp, Arg0, Arg1, "max"); -+ } else if (IsNVVM && Name.consume_front("min.") && -+ (Name == "s" || Name == "i" || Name == "ll" || Name == "us" || -+ Name == "ui" || Name == "ull")) { -+ Value *Arg0 = CI->getArgOperand(0); -+ Value *Arg1 = CI->getArgOperand(1); -+ Value *Cmp = Name.starts_with("u") -+ ? Builder.CreateICmpULE(Arg0, Arg1, "min.cond") -+ : Builder.CreateICmpSLE(Arg0, Arg1, "min.cond"); -+ Rep = Builder.CreateSelect(Cmp, Arg0, Arg1, "min"); -+ } else if (IsNVVM && Name == "clz.ll") { -+ // llvm.nvvm.clz.ll returns an i32, but llvm.ctlz.i64 returns an i64. -+ Value *Arg = CI->getArgOperand(0); -+ Value *Ctlz = Builder.CreateCall( -+ Intrinsic::getDeclaration(F->getParent(), Intrinsic::ctlz, -+ {Arg->getType()}), -+ {Arg, Builder.getFalse()}, "ctlz"); -+ Rep = Builder.CreateTrunc(Ctlz, Builder.getInt32Ty(), "ctlz.trunc"); -+ } else if (IsNVVM && Name == "popc.ll") { -+ // llvm.nvvm.popc.ll returns an i32, but llvm.ctpop.i64 returns an -+ // i64. -+ Value *Arg = CI->getArgOperand(0); -+ Value *Popc = Builder.CreateCall( -+ Intrinsic::getDeclaration(F->getParent(), Intrinsic::ctpop, -+ {Arg->getType()}), -+ Arg, "ctpop"); -+ Rep = Builder.CreateTrunc(Popc, Builder.getInt32Ty(), "ctpop.trunc"); - } else if (IsNVVM) { -- Rep = upgradeNVVMIntrinsicCall(Name, CI, F, Builder); -+ if (Name == "h2f") { -+ Rep = -+ Builder.CreateCall(Intrinsic::getDeclaration( -+ F->getParent(), Intrinsic::convert_from_fp16, -+ {Builder.getFloatTy()}), -+ CI->getArgOperand(0), "h2f"); -+ } else if (Name.consume_front("bitcast.") && -+ (Name == "f2i" || Name == "i2f" || Name == "ll2d" || -+ Name == "d2ll")) { -+ Rep = Builder.CreateBitCast(CI->getArgOperand(0), CI->getType()); -+ } else { -+ Intrinsic::ID IID = shouldUpgradeNVPTXBF16Intrinsic(Name); -+ if (IID != Intrinsic::not_intrinsic && -+ !F->getReturnType()->getScalarType()->isBFloatTy()) { -+ rename(F); -+ NewFn = Intrinsic::getDeclaration(F->getParent(), IID); -+ SmallVector Args; -+ for (size_t I = 0; I < NewFn->arg_size(); ++I) { -+ Value *Arg = CI->getArgOperand(I); -+ Type *OldType = Arg->getType(); -+ Type *NewType = NewFn->getArg(I)->getType(); -+ Args.push_back((OldType->isIntegerTy() && -+ NewType->getScalarType()->isBFloatTy()) -+ ? Builder.CreateBitCast(Arg, NewType) -+ : Arg); -+ } -+ Rep = Builder.CreateCall(NewFn, Args); -+ if (F->getReturnType()->isIntegerTy()) -+ Rep = Builder.CreateBitCast(Rep, F->getReturnType()); -+ } -+ } - } else if (IsX86) { - Rep = upgradeX86IntrinsicCall(Name, CI, F, Builder); - } else if (IsARM) { -diff -ruN --strip-trailing-cr a/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp b/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp ---- a/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp -+++ b/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp -@@ -292,7 +292,6 @@ - static const LLT S224 = LLT::scalar(224); - static const LLT S256 = LLT::scalar(256); - static const LLT S512 = LLT::scalar(512); --static const LLT S1024 = LLT::scalar(1024); - static const LLT MaxScalar = LLT::scalar(MaxRegisterSize); - - static const LLT V2S8 = LLT::fixed_vector(2, 8); -@@ -333,8 +332,8 @@ - static const LLT V2S128 = LLT::fixed_vector(2, 128); - static const LLT V4S128 = LLT::fixed_vector(4, 128); - --static std::initializer_list AllScalarTypes = { -- S32, S64, S96, S128, S160, S224, S256, S512, S1024}; -+static std::initializer_list AllScalarTypes = {S32, S64, S96, S128, -+ S160, S224, S256, S512}; - - static std::initializer_list AllS16Vectors{ - V2S16, V4S16, V6S16, V8S16, V10S16, V12S16, V16S16, V2S128, V4S128}; -@@ -890,11 +889,10 @@ - .clampScalar(0, S16, S64); - - getActionDefinitionsBuilder({G_IMPLICIT_DEF, G_FREEZE}) -- .legalIf(isRegisterClassType(0)) -+ .legalIf(isRegisterType(0)) - // s1 and s16 are special cases because they have legal operations on - // them, but don't really occupy registers in the normal way. - .legalFor({S1, S16}) -- .clampNumElements(0, V16S32, V32S32) - .moreElementsIf(isSmallOddVector(0), oneMoreElement(0)) - .clampScalarOrElt(0, S32, MaxScalar) - .widenScalarToNextPow2(0, 32) -diff -ruN --strip-trailing-cr a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td ---- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td -+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td -@@ -174,6 +174,10 @@ - def hasSHFL : Predicate<"!(Subtarget->getSmVersion() >= 70" - "&& Subtarget->getPTXVersion() >= 64)">; - -+def useShortPtrLocal : Predicate<"TM.is64Bit() && TM.getPointerSizeInBits(ADDRESS_SPACE_LOCAL) == 32">; -+def useShortPtrShared : Predicate<"TM.is64Bit() && TM.getPointerSizeInBits(ADDRESS_SPACE_SHARED) == 32">; -+def useShortPtrConst : Predicate<"TM.is64Bit() && TM.getPointerSizeInBits(ADDRESS_SPACE_CONST) == 32">; -+ - def useFP16Math: Predicate<"Subtarget->allowFP16Math()">; - def hasBF16Math: Predicate<"Subtarget->hasBF16Math()">; - -@@ -1661,6 +1665,167 @@ - "brev.b64 \t$dst, $a;", - [(set Int64Regs:$dst, (bitreverse Int64Regs:$a))]>; - -+// -+// Rotate: Use ptx shf instruction if available. -+// -+ -+// 32 bit r2 = rotl r1, n -+// => -+// r2 = shf.l r1, r1, n -+def ROTL32imm_hw : -+ NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$src, i32imm:$amt), -+ "shf.l.wrap.b32 \t$dst, $src, $src, $amt;", -+ [(set Int32Regs:$dst, (rotl (i32 Int32Regs:$src), (i32 imm:$amt)))]>, -+ Requires<[hasHWROT32]>; -+ -+def ROTL32reg_hw : -+ NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$src, Int32Regs:$amt), -+ "shf.l.wrap.b32 \t$dst, $src, $src, $amt;", -+ [(set Int32Regs:$dst, (rotl (i32 Int32Regs:$src), (i32 Int32Regs:$amt)))]>, -+ Requires<[hasHWROT32]>; -+ -+// 32 bit r2 = rotr r1, n -+// => -+// r2 = shf.r r1, r1, n -+def ROTR32imm_hw : -+ NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$src, i32imm:$amt), -+ "shf.r.wrap.b32 \t$dst, $src, $src, $amt;", -+ [(set Int32Regs:$dst, (rotr (i32 Int32Regs:$src), (i32 imm:$amt)))]>, -+ Requires<[hasHWROT32]>; -+ -+def ROTR32reg_hw : -+ NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$src, Int32Regs:$amt), -+ "shf.r.wrap.b32 \t$dst, $src, $src, $amt;", -+ [(set Int32Regs:$dst, (rotr (i32 Int32Regs:$src), (i32 Int32Regs:$amt)))]>, -+ Requires<[hasHWROT32]>; -+ -+// 32-bit software rotate by immediate. $amt2 should equal 32 - $amt1. -+def ROT32imm_sw : -+ NVPTXInst<(outs Int32Regs:$dst), -+ (ins Int32Regs:$src, i32imm:$amt1, i32imm:$amt2), -+ "{{\n\t" -+ ".reg .b32 %lhs;\n\t" -+ ".reg .b32 %rhs;\n\t" -+ "shl.b32 \t%lhs, $src, $amt1;\n\t" -+ "shr.b32 \t%rhs, $src, $amt2;\n\t" -+ "add.u32 \t$dst, %lhs, %rhs;\n\t" -+ "}}", -+ []>; -+ -+def SUB_FRM_32 : SDNodeXFormgetTargetConstant(32 - N->getZExtValue(), SDLoc(N), MVT::i32); -+}]>; -+ -+def : Pat<(rotl (i32 Int32Regs:$src), (i32 imm:$amt)), -+ (ROT32imm_sw Int32Regs:$src, imm:$amt, (SUB_FRM_32 node:$amt))>, -+ Requires<[noHWROT32]>; -+def : Pat<(rotr (i32 Int32Regs:$src), (i32 imm:$amt)), -+ (ROT32imm_sw Int32Regs:$src, (SUB_FRM_32 node:$amt), imm:$amt)>, -+ Requires<[noHWROT32]>; -+ -+// 32-bit software rotate left by register. -+def ROTL32reg_sw : -+ NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$src, Int32Regs:$amt), -+ "{{\n\t" -+ ".reg .b32 %lhs;\n\t" -+ ".reg .b32 %rhs;\n\t" -+ ".reg .b32 %amt2;\n\t" -+ "shl.b32 \t%lhs, $src, $amt;\n\t" -+ "sub.s32 \t%amt2, 32, $amt;\n\t" -+ "shr.b32 \t%rhs, $src, %amt2;\n\t" -+ "add.u32 \t$dst, %lhs, %rhs;\n\t" -+ "}}", -+ [(set Int32Regs:$dst, (rotl (i32 Int32Regs:$src), (i32 Int32Regs:$amt)))]>, -+ Requires<[noHWROT32]>; -+ -+// 32-bit software rotate right by register. -+def ROTR32reg_sw : -+ NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$src, Int32Regs:$amt), -+ "{{\n\t" -+ ".reg .b32 %lhs;\n\t" -+ ".reg .b32 %rhs;\n\t" -+ ".reg .b32 %amt2;\n\t" -+ "shr.b32 \t%lhs, $src, $amt;\n\t" -+ "sub.s32 \t%amt2, 32, $amt;\n\t" -+ "shl.b32 \t%rhs, $src, %amt2;\n\t" -+ "add.u32 \t$dst, %lhs, %rhs;\n\t" -+ "}}", -+ [(set Int32Regs:$dst, (rotr (i32 Int32Regs:$src), (i32 Int32Regs:$amt)))]>, -+ Requires<[noHWROT32]>; -+ -+// 64-bit software rotate by immediate. $amt2 should equal 64 - $amt1. -+def ROT64imm_sw : -+ NVPTXInst<(outs Int64Regs:$dst), -+ (ins Int64Regs:$src, i32imm:$amt1, i32imm:$amt2), -+ "{{\n\t" -+ ".reg .b64 %lhs;\n\t" -+ ".reg .b64 %rhs;\n\t" -+ "shl.b64 \t%lhs, $src, $amt1;\n\t" -+ "shr.b64 \t%rhs, $src, $amt2;\n\t" -+ "add.u64 \t$dst, %lhs, %rhs;\n\t" -+ "}}", -+ []>; -+ -+def SUB_FRM_64 : SDNodeXFormgetTargetConstant(64-N->getZExtValue(), SDLoc(N), MVT::i32); -+}]>; -+ -+def : Pat<(rotl Int64Regs:$src, (i32 imm:$amt)), -+ (ROT64imm_sw Int64Regs:$src, imm:$amt, (SUB_FRM_64 node:$amt))>; -+def : Pat<(rotr Int64Regs:$src, (i32 imm:$amt)), -+ (ROT64imm_sw Int64Regs:$src, (SUB_FRM_64 node:$amt), imm:$amt)>; -+ -+// 64-bit software rotate left by register. -+def ROTL64reg_sw : -+ NVPTXInst<(outs Int64Regs:$dst), (ins Int64Regs:$src, Int32Regs:$amt), -+ "{{\n\t" -+ ".reg .b64 %lhs;\n\t" -+ ".reg .b64 %rhs;\n\t" -+ ".reg .u32 %amt2;\n\t" -+ "and.b32 \t%amt2, $amt, 63;\n\t" -+ "shl.b64 \t%lhs, $src, %amt2;\n\t" -+ "sub.u32 \t%amt2, 64, %amt2;\n\t" -+ "shr.b64 \t%rhs, $src, %amt2;\n\t" -+ "add.u64 \t$dst, %lhs, %rhs;\n\t" -+ "}}", -+ [(set Int64Regs:$dst, (rotl Int64Regs:$src, (i32 Int32Regs:$amt)))]>; -+ -+def ROTR64reg_sw : -+ NVPTXInst<(outs Int64Regs:$dst), (ins Int64Regs:$src, Int32Regs:$amt), -+ "{{\n\t" -+ ".reg .b64 %lhs;\n\t" -+ ".reg .b64 %rhs;\n\t" -+ ".reg .u32 %amt2;\n\t" -+ "and.b32 \t%amt2, $amt, 63;\n\t" -+ "shr.b64 \t%lhs, $src, %amt2;\n\t" -+ "sub.u32 \t%amt2, 64, %amt2;\n\t" -+ "shl.b64 \t%rhs, $src, %amt2;\n\t" -+ "add.u64 \t$dst, %lhs, %rhs;\n\t" -+ "}}", -+ [(set Int64Regs:$dst, (rotr Int64Regs:$src, (i32 Int32Regs:$amt)))]>; -+ -+// -+// Funnnel shift in clamp mode -+// -+ -+// Create SDNodes so they can be used in the DAG code, e.g. -+// NVPTXISelLowering (LowerShiftLeftParts and LowerShiftRightParts) -+def FUN_SHFL_CLAMP : SDNode<"NVPTXISD::FUN_SHFL_CLAMP", SDTIntShiftDOp, []>; -+def FUN_SHFR_CLAMP : SDNode<"NVPTXISD::FUN_SHFR_CLAMP", SDTIntShiftDOp, []>; -+ -+def FUNSHFLCLAMP : -+ NVPTXInst<(outs Int32Regs:$dst), -+ (ins Int32Regs:$lo, Int32Regs:$hi, Int32Regs:$amt), -+ "shf.l.clamp.b32 \t$dst, $lo, $hi, $amt;", -+ [(set Int32Regs:$dst, -+ (FUN_SHFL_CLAMP (i32 Int32Regs:$lo), (i32 Int32Regs:$hi), (i32 Int32Regs:$amt)))]>; -+ -+def FUNSHFRCLAMP : -+ NVPTXInst<(outs Int32Regs:$dst), -+ (ins Int32Regs:$lo, Int32Regs:$hi, Int32Regs:$amt), -+ "shf.r.clamp.b32 \t$dst, $lo, $hi, $amt;", -+ [(set Int32Regs:$dst, -+ (FUN_SHFR_CLAMP (i32 Int32Regs:$lo), (i32 Int32Regs:$hi), (i32 Int32Regs:$amt)))]>; - - // - // BFE - bit-field extract -@@ -3492,42 +3657,6 @@ - def: Pat<(v2i16 (scalar_to_vector (i16 Int16Regs:$a))), - (CVT_u32_u16 Int16Regs:$a, CvtNONE)>; - --// --// Funnel-Shift --// -- --// Create SDNodes so they can be used in the DAG code, e.g. --// NVPTXISelLowering (LowerShiftLeftParts and LowerShiftRightParts) --def fshl_clamp : SDNode<"NVPTXISD::FUN_SHFL_CLAMP", SDTIntShiftDOp, []>; --def fshr_clamp : SDNode<"NVPTXISD::FUN_SHFR_CLAMP", SDTIntShiftDOp, []>; -- --// Funnel shift, requires >= sm_32. Does not trap if amt is out of range, so --// no side effects. --let hasSideEffects = false in { -- multiclass ShfInst { -- def _i -- : NVPTXInst<(outs Int32Regs:$dst), -- (ins Int32Regs:$lo, Int32Regs:$hi, i32imm:$amt), -- "shf." # mode # ".b32 \t$dst, $lo, $hi, $amt;", -- [(set Int32Regs:$dst, -- (op (i32 Int32Regs:$lo), (i32 Int32Regs:$hi), (i32 imm:$amt)))]>, -- Requires<[hasHWROT32]>; -- -- def _r -- : NVPTXInst<(outs Int32Regs:$dst), -- (ins Int32Regs:$lo, Int32Regs:$hi, Int32Regs:$amt), -- "shf." # mode # ".b32 \t$dst, $lo, $hi, $amt;", -- [(set Int32Regs:$dst, -- (op (i32 Int32Regs:$lo), (i32 Int32Regs:$hi), (i32 Int32Regs:$amt)))]>, -- Requires<[hasHWROT32]>; -- } -- -- defm SHF_L_CLAMP : ShfInst<"l.clamp", fshl_clamp>; -- defm SHF_R_CLAMP : ShfInst<"r.clamp", fshr_clamp>; -- defm SHF_L_WRAP : ShfInst<"l.wrap", fshl>; -- defm SHF_R_WRAP : ShfInst<"r.wrap", fshr>; --} -- - // Count leading zeros - let hasSideEffects = false in { - def CLZr32 : NVPTXInst<(outs Int32Regs:$d), (ins Int32Regs:$a), -diff -ruN --strip-trailing-cr a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td ---- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td -+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td -@@ -2537,45 +2537,59 @@ - : VLDG_G_ELE_V4<"v4.f32 \t{{$dst1, $dst2, $dst3, $dst4}}, [$src];", Float32Regs>; - - --multiclass NG_TO_G { -+multiclass NG_TO_G { - def "" : NVPTXInst<(outs Int32Regs:$result), (ins Int32Regs:$src), -- "cvta." # Str # ".u32 \t$result, $src;", []>; -+ !strconcat("cvta.", Str, ".u32 \t$result, $src;"), -+ [(set Int32Regs:$result, (Intrin Int32Regs:$src))]>; - def _64 : NVPTXInst<(outs Int64Regs:$result), (ins Int64Regs:$src), -- "cvta." # Str # ".u64 \t$result, $src;", []>; -+ !strconcat("cvta.", Str, ".u64 \t$result, $src;"), -+ [(set Int64Regs:$result, (Intrin Int64Regs:$src))]>; -+ def _6432 : NVPTXInst<(outs Int64Regs:$result), (ins Int32Regs:$src), -+ "{{ .reg .b64 %tmp;\n\t" -+ #" cvt.u64.u32 \t%tmp, $src;\n\t" -+ #" cvta." # Str # ".u64 \t$result, %tmp; }}", -+ [(set Int64Regs:$result, (Intrin Int32Regs:$src))]>, -+ Requires<[ShortPtr]>; - } - --multiclass G_TO_NG { -+multiclass G_TO_NG { - def "" : NVPTXInst<(outs Int32Regs:$result), (ins Int32Regs:$src), -- "cvta.to." # Str # ".u32 \t$result, $src;", []>; -+ !strconcat("cvta.to.", Str, ".u32 \t$result, $src;"), -+ [(set Int32Regs:$result, (Intrin Int32Regs:$src))]>; - def _64 : NVPTXInst<(outs Int64Regs:$result), (ins Int64Regs:$src), -- "cvta.to." # Str # ".u64 \t$result, $src;", []>; -+ !strconcat("cvta.to.", Str, ".u64 \t$result, $src;"), -+ [(set Int64Regs:$result, (Intrin Int64Regs:$src))]>; -+ def _3264 : NVPTXInst<(outs Int32Regs:$result), (ins Int64Regs:$src), -+ "{{ .reg .b64 %tmp;\n\t" -+ #" cvta.to." # Str # ".u64 \t%tmp, $src;\n\t" -+ #" cvt.u32.u64 \t$result, %tmp; }}", -+ [(set Int32Regs:$result, (Intrin Int64Regs:$src))]>, -+ Requires<[ShortPtr]>; - } - --defm cvta_local : NG_TO_G<"local">; --defm cvta_shared : NG_TO_G<"shared">; --defm cvta_global : NG_TO_G<"global">; --defm cvta_const : NG_TO_G<"const">; -- --defm cvta_to_local : G_TO_NG<"local">; --defm cvta_to_shared : G_TO_NG<"shared">; --defm cvta_to_global : G_TO_NG<"global">; --defm cvta_to_const : G_TO_NG<"const">; -- --// nvvm.ptr.param.to.gen --defm cvta_param : NG_TO_G<"param">; -- --def : Pat<(int_nvvm_ptr_param_to_gen Int32Regs:$src), -- (cvta_param Int32Regs:$src)>; -- --def : Pat<(int_nvvm_ptr_param_to_gen Int64Regs:$src), -- (cvta_param_64 Int64Regs:$src)>; -+defm cvta_local : NG_TO_G<"local", int_nvvm_ptr_local_to_gen, useShortPtrLocal>; -+defm cvta_shared : NG_TO_G<"shared", int_nvvm_ptr_shared_to_gen, useShortPtrShared>; -+defm cvta_global : NG_TO_G<"global", int_nvvm_ptr_global_to_gen, False>; -+defm cvta_const : NG_TO_G<"const", int_nvvm_ptr_constant_to_gen, useShortPtrConst>; -+defm cvta_param : NG_TO_G<"param", int_nvvm_ptr_param_to_gen, False>; -+ -+defm cvta_to_local : G_TO_NG<"local", int_nvvm_ptr_gen_to_local, useShortPtrLocal>; -+defm cvta_to_shared : G_TO_NG<"shared", int_nvvm_ptr_gen_to_shared, useShortPtrShared>; -+defm cvta_to_global : G_TO_NG<"global", int_nvvm_ptr_gen_to_global, False>; -+defm cvta_to_const : G_TO_NG<"const", int_nvvm_ptr_gen_to_constant, useShortPtrConst>; - - // nvvm.ptr.gen.to.param --def : Pat<(int_nvvm_ptr_gen_to_param Int32Regs:$src), -- (IMOV32rr Int32Regs:$src)>; -+def nvvm_ptr_gen_to_param : NVPTXInst<(outs Int32Regs:$result), -+ (ins Int32Regs:$src), -+ "mov.u32 \t$result, $src;", -+ [(set Int32Regs:$result, -+ (int_nvvm_ptr_gen_to_param Int32Regs:$src))]>; -+def nvvm_ptr_gen_to_param_64 : NVPTXInst<(outs Int64Regs:$result), -+ (ins Int64Regs:$src), -+ "mov.u64 \t$result, $src;", -+ [(set Int64Regs:$result, -+ (int_nvvm_ptr_gen_to_param Int64Regs:$src))]>; - --def : Pat<(int_nvvm_ptr_gen_to_param Int64Regs:$src), -- (IMOV64rr Int64Regs:$src)>; - - // nvvm.move intrinsicc - def nvvm_move_i16 : NVPTXInst<(outs Int16Regs:$r), (ins Int16Regs:$s), -@@ -2618,6 +2632,24 @@ - [(set Int64Regs:$r, - (int_nvvm_move_ptr texternalsym:$s))]>;*/ - -+ -+// MoveParam %r1, param -+// ptr_local_to_gen %r2, %r1 -+// ptr_gen_to_local %r3, %r2 -+// -> -+// mov %r1, param -+ -+// @TODO: Revisit this. There is a type -+// contradiction between iPTRAny and iPTR for the addr defs, so the move_sym -+// instructions are not currently defined. However, we can use the ptr -+// variants and the asm printer will do the right thing. -+def : Pat<(i64 (int_nvvm_ptr_gen_to_local (int_nvvm_ptr_local_to_gen -+ (MoveParam texternalsym:$src)))), -+ (nvvm_move_ptr64 texternalsym:$src)>; -+def : Pat<(i32 (int_nvvm_ptr_gen_to_local (int_nvvm_ptr_local_to_gen -+ (MoveParam texternalsym:$src)))), -+ (nvvm_move_ptr32 texternalsym:$src)>; -+ - def texsurf_handles - : NVPTXInst<(outs Int64Regs:$result), (ins imem:$src), - "mov.u64 \t$result, $src;", []>; -@@ -2701,9 +2733,134 @@ - def : Pat<(int_nvvm_read_ptx_sreg_envreg31), (MOV_SPECIAL ENVREG31)>; - - -+// rotate builtin support -+ -+def ROTATE_B32_HW_IMM -+ : NVPTXInst<(outs Int32Regs:$dst), -+ (ins Int32Regs:$src, i32imm:$amt), -+ "shf.l.wrap.b32 \t$dst, $src, $src, $amt;", -+ [(set Int32Regs:$dst, -+ (int_nvvm_rotate_b32 Int32Regs:$src, (i32 imm:$amt)))]>, -+ Requires<[hasHWROT32]> ; -+ -+def ROTATE_B32_HW_REG -+ : NVPTXInst<(outs Int32Regs:$dst), -+ (ins Int32Regs:$src, Int32Regs:$amt), -+ "shf.l.wrap.b32 \t$dst, $src, $src, $amt;", -+ [(set Int32Regs:$dst, -+ (int_nvvm_rotate_b32 Int32Regs:$src, Int32Regs:$amt))]>, -+ Requires<[hasHWROT32]> ; -+ -+def : Pat<(int_nvvm_rotate_b32 Int32Regs:$src, (i32 imm:$amt)), -+ (ROT32imm_sw Int32Regs:$src, imm:$amt, (SUB_FRM_32 node:$amt))>, -+ Requires<[noHWROT32]> ; -+ -+def : Pat<(int_nvvm_rotate_b32 Int32Regs:$src, Int32Regs:$amt), -+ (ROTL32reg_sw Int32Regs:$src, Int32Regs:$amt)>, -+ Requires<[noHWROT32]> ; -+ -+let hasSideEffects = false in { -+ def GET_LO_INT64 : NVPTXInst<(outs Int32Regs:$dst), (ins Int64Regs:$src), -+ !strconcat("{{\n\t", -+ ".reg .b32 %dummy;\n\t", -+ "mov.b64 \t{$dst,%dummy}, $src;\n\t", -+ "}}"), -+ []> ; -+ -+ def GET_HI_INT64 : NVPTXInst<(outs Int32Regs:$dst), (ins Int64Regs:$src), -+ !strconcat("{{\n\t", -+ ".reg .b32 %dummy;\n\t", -+ "mov.b64 \t{%dummy,$dst}, $src;\n\t", -+ "}}"), -+ []> ; -+} -+ -+let hasSideEffects = false in { -+ def PACK_TWO_INT32 -+ : NVPTXInst<(outs Int64Regs:$dst), (ins Int32Regs:$lo, Int32Regs:$hi), -+ "mov.b64 \t$dst, {{$lo, $hi}};", []> ; -+} -+ - def : Pat<(int_nvvm_swap_lo_hi_b64 Int64Regs:$src), -- (V2I32toI64 (I64toI32H Int64Regs:$src), -- (I64toI32L Int64Regs:$src))> ; -+ (PACK_TWO_INT32 (GET_HI_INT64 Int64Regs:$src), -+ (GET_LO_INT64 Int64Regs:$src))> ; -+ -+// Funnel shift, requires >= sm_32. Does not trap if amt is out of range, so -+// no side effects. -+let hasSideEffects = false in { -+ def SHF_L_WRAP_B32_IMM -+ : NVPTXInst<(outs Int32Regs:$dst), -+ (ins Int32Regs:$lo, Int32Regs:$hi, i32imm:$amt), -+ "shf.l.wrap.b32 \t$dst, $lo, $hi, $amt;",[]>, -+ Requires<[hasHWROT32]>; -+ -+ def SHF_L_WRAP_B32_REG -+ : NVPTXInst<(outs Int32Regs:$dst), -+ (ins Int32Regs:$lo, Int32Regs:$hi, Int32Regs:$amt), -+ "shf.l.wrap.b32 \t$dst, $lo, $hi, $amt;",[]>, -+ Requires<[hasHWROT32]>; -+ -+ def SHF_R_WRAP_B32_IMM -+ : NVPTXInst<(outs Int32Regs:$dst), -+ (ins Int32Regs:$lo, Int32Regs:$hi, i32imm:$amt), -+ "shf.r.wrap.b32 \t$dst, $lo, $hi, $amt;",[]>, -+ Requires<[hasHWROT32]>; -+ -+ def SHF_R_WRAP_B32_REG -+ : NVPTXInst<(outs Int32Regs:$dst), -+ (ins Int32Regs:$lo, Int32Regs:$hi, Int32Regs:$amt), -+ "shf.r.wrap.b32 \t$dst, $lo, $hi, $amt;",[]>, -+ Requires<[hasHWROT32]>; -+} -+ -+// HW version of rotate 64 -+def : Pat<(int_nvvm_rotate_b64 Int64Regs:$src, (i32 imm:$amt)), -+ (PACK_TWO_INT32 -+ (SHF_L_WRAP_B32_IMM (GET_HI_INT64 Int64Regs:$src), -+ (GET_LO_INT64 Int64Regs:$src), imm:$amt), -+ (SHF_L_WRAP_B32_IMM (GET_LO_INT64 Int64Regs:$src), -+ (GET_HI_INT64 Int64Regs:$src), imm:$amt))>, -+ Requires<[hasHWROT32]>; -+ -+def : Pat<(int_nvvm_rotate_b64 Int64Regs:$src, Int32Regs:$amt), -+ (PACK_TWO_INT32 -+ (SHF_L_WRAP_B32_REG (GET_HI_INT64 Int64Regs:$src), -+ (GET_LO_INT64 Int64Regs:$src), Int32Regs:$amt), -+ (SHF_L_WRAP_B32_REG (GET_LO_INT64 Int64Regs:$src), -+ (GET_HI_INT64 Int64Regs:$src), Int32Regs:$amt))>, -+ Requires<[hasHWROT32]>; -+ -+ -+def : Pat<(int_nvvm_rotate_right_b64 Int64Regs:$src, (i32 imm:$amt)), -+ (PACK_TWO_INT32 -+ (SHF_R_WRAP_B32_IMM (GET_LO_INT64 Int64Regs:$src), -+ (GET_HI_INT64 Int64Regs:$src), imm:$amt), -+ (SHF_R_WRAP_B32_IMM (GET_HI_INT64 Int64Regs:$src), -+ (GET_LO_INT64 Int64Regs:$src), imm:$amt))>, -+ Requires<[hasHWROT32]>; -+ -+def : Pat<(int_nvvm_rotate_right_b64 Int64Regs:$src, Int32Regs:$amt), -+ (PACK_TWO_INT32 -+ (SHF_R_WRAP_B32_REG (GET_LO_INT64 Int64Regs:$src), -+ (GET_HI_INT64 Int64Regs:$src), Int32Regs:$amt), -+ (SHF_R_WRAP_B32_REG (GET_HI_INT64 Int64Regs:$src), -+ (GET_LO_INT64 Int64Regs:$src), Int32Regs:$amt))>, -+ Requires<[hasHWROT32]>; -+ -+// SW version of rotate 64 -+def : Pat<(int_nvvm_rotate_b64 Int64Regs:$src, (i32 imm:$amt)), -+ (ROT64imm_sw Int64Regs:$src, imm:$amt, (SUB_FRM_64 node:$amt))>, -+ Requires<[noHWROT32]>; -+def : Pat<(int_nvvm_rotate_b64 Int64Regs:$src, Int32Regs:$amt), -+ (ROTL64reg_sw Int64Regs:$src, Int32Regs:$amt)>, -+ Requires<[noHWROT32]>; -+def : Pat<(int_nvvm_rotate_right_b64 Int64Regs:$src, (i32 imm:$amt)), -+ (ROT64imm_sw Int64Regs:$src, (SUB_FRM_64 node:$amt), imm:$amt)>, -+ Requires<[noHWROT32]>; -+def : Pat<(int_nvvm_rotate_right_b64 Int64Regs:$src, Int32Regs:$amt), -+ (ROTR64reg_sw Int64Regs:$src, Int32Regs:$amt)>, -+ Requires<[noHWROT32]>; -+ - - //----------------------------------- - // Texture Intrinsics -diff -ruN --strip-trailing-cr a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp ---- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp -+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp -@@ -1109,21 +1109,11 @@ - AddrSpaceCastSDNode *CastN = cast(N); - unsigned SrcAddrSpace = CastN->getSrcAddressSpace(); - unsigned DstAddrSpace = CastN->getDestAddressSpace(); -- SDLoc DL(N); - assert(SrcAddrSpace != DstAddrSpace && - "addrspacecast must be between different address spaces"); - - if (DstAddrSpace == ADDRESS_SPACE_GENERIC) { - // Specific to generic -- -- if (TM.is64Bit() && TM.getPointerSizeInBits(SrcAddrSpace) == 32) { -- SDValue CvtNone = -- CurDAG->getTargetConstant(NVPTX::PTXCvtMode::NONE, DL, MVT::i32); -- SDNode *Cvt = CurDAG->getMachineNode(NVPTX::CVT_u64_u32, DL, MVT::i64, -- Src, CvtNone); -- Src = SDValue(Cvt, 0); -- } -- - unsigned Opc; - switch (SrcAddrSpace) { - default: report_fatal_error("Bad address space in addrspacecast"); -@@ -1131,16 +1121,26 @@ - Opc = TM.is64Bit() ? NVPTX::cvta_global_64 : NVPTX::cvta_global; - break; - case ADDRESS_SPACE_SHARED: -- Opc = TM.is64Bit() ? NVPTX::cvta_shared_64 : NVPTX::cvta_shared; -+ Opc = TM.is64Bit() ? (TM.getPointerSizeInBits(SrcAddrSpace) == 32 -+ ? NVPTX::cvta_shared_6432 -+ : NVPTX::cvta_shared_64) -+ : NVPTX::cvta_shared; - break; - case ADDRESS_SPACE_CONST: -- Opc = TM.is64Bit() ? NVPTX::cvta_const_64 : NVPTX::cvta_const; -+ Opc = TM.is64Bit() ? (TM.getPointerSizeInBits(SrcAddrSpace) == 32 -+ ? NVPTX::cvta_const_6432 -+ : NVPTX::cvta_const_64) -+ : NVPTX::cvta_const; - break; - case ADDRESS_SPACE_LOCAL: -- Opc = TM.is64Bit() ? NVPTX::cvta_local_64 : NVPTX::cvta_local; -+ Opc = TM.is64Bit() ? (TM.getPointerSizeInBits(SrcAddrSpace) == 32 -+ ? NVPTX::cvta_local_6432 -+ : NVPTX::cvta_local_64) -+ : NVPTX::cvta_local; - break; - } -- ReplaceNode(N, CurDAG->getMachineNode(Opc, DL, N->getValueType(0), Src)); -+ ReplaceNode(N, CurDAG->getMachineNode(Opc, SDLoc(N), N->getValueType(0), -+ Src)); - return; - } else { - // Generic to specific -@@ -1153,28 +1153,30 @@ - Opc = TM.is64Bit() ? NVPTX::cvta_to_global_64 : NVPTX::cvta_to_global; - break; - case ADDRESS_SPACE_SHARED: -- Opc = TM.is64Bit() ? NVPTX::cvta_to_shared_64 : NVPTX::cvta_to_shared; -+ Opc = TM.is64Bit() ? (TM.getPointerSizeInBits(DstAddrSpace) == 32 -+ ? NVPTX::cvta_to_shared_3264 -+ : NVPTX::cvta_to_shared_64) -+ : NVPTX::cvta_to_shared; - break; - case ADDRESS_SPACE_CONST: -- Opc = TM.is64Bit() ? NVPTX::cvta_to_const_64 : NVPTX::cvta_to_const; -+ Opc = TM.is64Bit() ? (TM.getPointerSizeInBits(DstAddrSpace) == 32 -+ ? NVPTX::cvta_to_const_3264 -+ : NVPTX::cvta_to_const_64) -+ : NVPTX::cvta_to_const; - break; - case ADDRESS_SPACE_LOCAL: -- Opc = TM.is64Bit() ? NVPTX::cvta_to_local_64 : NVPTX::cvta_to_local; -+ Opc = TM.is64Bit() ? (TM.getPointerSizeInBits(DstAddrSpace) == 32 -+ ? NVPTX::cvta_to_local_3264 -+ : NVPTX::cvta_to_local_64) -+ : NVPTX::cvta_to_local; - break; - case ADDRESS_SPACE_PARAM: -- Opc = TM.is64Bit() ? NVPTX::IMOV64rr : NVPTX::IMOV32rr; -+ Opc = TM.is64Bit() ? NVPTX::nvvm_ptr_gen_to_param_64 -+ : NVPTX::nvvm_ptr_gen_to_param; - break; - } -- -- SDNode *CVTA = CurDAG->getMachineNode(Opc, DL, N->getValueType(0), Src); -- if (TM.is64Bit() && TM.getPointerSizeInBits(DstAddrSpace) == 32) { -- SDValue CvtNone = -- CurDAG->getTargetConstant(NVPTX::PTXCvtMode::NONE, DL, MVT::i32); -- CVTA = CurDAG->getMachineNode(NVPTX::CVT_u32_u64, DL, MVT::i32, -- SDValue(CVTA, 0), CvtNone); -- } -- -- ReplaceNode(N, CVTA); -+ ReplaceNode(N, CurDAG->getMachineNode(Opc, SDLoc(N), N->getValueType(0), -+ Src)); - return; - } - } -diff -ruN --strip-trailing-cr a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp ---- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp -+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp -@@ -594,13 +594,20 @@ - setOperationAction(ISD::BITREVERSE, MVT::i32, Legal); - setOperationAction(ISD::BITREVERSE, MVT::i64, Legal); - -- setOperationAction({ISD::ROTL, ISD::ROTR}, -- {MVT::i8, MVT::i16, MVT::v2i16, MVT::i32, MVT::i64}, -- Expand); -- -- if (STI.hasHWROT32()) -- setOperationAction({ISD::FSHL, ISD::FSHR}, MVT::i32, Legal); -+ // TODO: we may consider expanding ROTL/ROTR on older GPUs. Currently on GPUs -+ // that don't have h/w rotation we lower them to multi-instruction assembly. -+ // See ROT*_sw in NVPTXIntrInfo.td -+ setOperationAction(ISD::ROTL, MVT::i64, Legal); -+ setOperationAction(ISD::ROTR, MVT::i64, Legal); -+ setOperationAction(ISD::ROTL, MVT::i32, Legal); -+ setOperationAction(ISD::ROTR, MVT::i32, Legal); - -+ setOperationAction(ISD::ROTL, MVT::i16, Expand); -+ setOperationAction(ISD::ROTL, MVT::v2i16, Expand); -+ setOperationAction(ISD::ROTR, MVT::i16, Expand); -+ setOperationAction(ISD::ROTR, MVT::v2i16, Expand); -+ setOperationAction(ISD::ROTL, MVT::i8, Expand); -+ setOperationAction(ISD::ROTR, MVT::i8, Expand); - setOperationAction(ISD::BSWAP, MVT::i16, Expand); - - setOperationAction(ISD::BR_JT, MVT::Other, Custom); -diff -ruN --strip-trailing-cr a/llvm/test/Assembler/auto_upgrade_nvvm_intrinsics.ll b/llvm/test/Assembler/auto_upgrade_nvvm_intrinsics.ll ---- a/llvm/test/Assembler/auto_upgrade_nvvm_intrinsics.ll -+++ b/llvm/test/Assembler/auto_upgrade_nvvm_intrinsics.ll -@@ -31,19 +31,6 @@ - declare i64 @llvm.nvvm.bitcast.d2ll(double) - declare double @llvm.nvvm.bitcast.ll2d(i64) - --declare i32 @llvm.nvvm.rotate.b32(i32, i32) --declare i64 @llvm.nvvm.rotate.right.b64(i64, i32) --declare i64 @llvm.nvvm.rotate.b64(i64, i32) -- --declare ptr addrspace(1) @llvm.nvvm.ptr.gen.to.global.p1.p0(ptr) --declare ptr addrspace(3) @llvm.nvvm.ptr.gen.to.shared.p3.p0(ptr) --declare ptr addrspace(4) @llvm.nvvm.ptr.gen.to.constant.p4.p0(ptr) --declare ptr addrspace(5) @llvm.nvvm.ptr.gen.to.local.p5.p0(ptr) --declare ptr @llvm.nvvm.ptr.global.to.gen.p0.p1(ptr addrspace(1)) --declare ptr @llvm.nvvm.ptr.shared.to.gen.p0.p3(ptr addrspace(3)) --declare ptr @llvm.nvvm.ptr.constant.to.gen.p0.p4(ptr addrspace(4)) --declare ptr @llvm.nvvm.ptr.local.to.gen.p0.p5(ptr addrspace(5)) -- - ; CHECK-LABEL: @simple_upgrade - define void @simple_upgrade(i32 %a, i64 %b, i16 %c) { - ; CHECK: call i32 @llvm.bitreverse.i32(i32 %a) -@@ -152,42 +139,4 @@ - %r4 = call double @llvm.nvvm.bitcast.ll2d(i64 %b) - - ret void --} -- --; CHECK-LABEL: @rotate --define void @rotate(i32 %a, i64 %b) { --; CHECK: call i32 @llvm.fshl.i32(i32 %a, i32 %a, i32 6) --; CHECK: call i64 @llvm.fshr.i64(i64 %b, i64 %b, i64 7) --; CHECK: call i64 @llvm.fshl.i64(i64 %b, i64 %b, i64 8) --; -- %r1 = call i32 @llvm.nvvm.rotate.b32(i32 %a, i32 6) -- %r2 = call i64 @llvm.nvvm.rotate.right.b64(i64 %b, i32 7) -- %r3 = call i64 @llvm.nvvm.rotate.b64(i64 %b, i32 8) -- ret void --} -- --; CHECK-LABEL: @addrspacecast --define void @addrspacecast(ptr %p0) { --; CHECK: %1 = addrspacecast ptr %p0 to ptr addrspace(1) --; CHECK: %2 = addrspacecast ptr addrspace(1) %1 to ptr --; CHECK: %3 = addrspacecast ptr %2 to ptr addrspace(3) --; CHECK: %4 = addrspacecast ptr addrspace(3) %3 to ptr --; CHECK: %5 = addrspacecast ptr %4 to ptr addrspace(4) --; CHECK: %6 = addrspacecast ptr addrspace(4) %5 to ptr --; CHECK: %7 = addrspacecast ptr %6 to ptr addrspace(5) --; CHECK: %8 = addrspacecast ptr addrspace(5) %7 to ptr --; -- %p1 = call ptr addrspace(1) @llvm.nvvm.ptr.gen.to.global.p1.p0(ptr %p0) -- %p2 = call ptr @llvm.nvvm.ptr.global.to.gen.p0.p1(ptr addrspace(1) %p1) -- -- %p3 = call ptr addrspace(3) @llvm.nvvm.ptr.gen.to.shared.p3.p0(ptr %p2) -- %p4 = call ptr @llvm.nvvm.ptr.shared.to.gen.p0.p3(ptr addrspace(3) %p3) -- -- %p5 = call ptr addrspace(4) @llvm.nvvm.ptr.gen.to.constant.p4.p0(ptr %p4) -- %p6 = call ptr @llvm.nvvm.ptr.constant.to.gen.p0.p4(ptr addrspace(4) %p5) -- -- %p7 = call ptr addrspace(5) @llvm.nvvm.ptr.gen.to.local.p5.p0(ptr %p6) -- %p8 = call ptr @llvm.nvvm.ptr.local.to.gen.p0.p5(ptr addrspace(5) %p7) -- -- ret void --} -+} -\ No newline at end of file -diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/AMDGPU/freeze.ll b/llvm/test/CodeGen/AMDGPU/freeze.ll ---- a/llvm/test/CodeGen/AMDGPU/freeze.ll -+++ b/llvm/test/CodeGen/AMDGPU/freeze.ll -@@ -1,1856 +0,0 @@ --; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py --; RUN: llc -global-isel=0 -mtriple=amdgcn-mesa-mesa3d -mcpu=gfx1010 < %s | FileCheck -check-prefixes=GFX10,GFX10-SDAG %s --; RUN: llc -global-isel=1 -mtriple=amdgcn-mesa-mesa3d -mcpu=gfx1010 < %s | FileCheck -check-prefixes=GFX10,GFX10-GISEL %s --; RUN: llc -global-isel=0 -mtriple=amdgcn-mesa-mesa3d -mcpu=gfx1100 -amdgpu-enable-delay-alu=0 < %s | FileCheck -check-prefixes=GFX11,GFX11-SDAG %s --; RUN: llc -global-isel=1 -mtriple=amdgcn-mesa-mesa3d -mcpu=gfx1100 -amdgpu-enable-delay-alu=0 < %s | FileCheck -check-prefixes=GFX11,GFX11-GISEL %s -- --define void @freeze_v2i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { --; GFX10-LABEL: freeze_v2i32: --; GFX10: ; %bb.0: --; GFX10-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX10-NEXT: global_load_dwordx2 v[0:1], v[0:1], off --; GFX10-NEXT: s_waitcnt vmcnt(0) --; GFX10-NEXT: global_store_dwordx2 v[2:3], v[0:1], off --; GFX10-NEXT: s_setpc_b64 s[30:31] --; --; GFX11-LABEL: freeze_v2i32: --; GFX11: ; %bb.0: --; GFX11-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX11-NEXT: global_load_b64 v[0:1], v[0:1], off --; GFX11-NEXT: s_waitcnt vmcnt(0) --; GFX11-NEXT: global_store_b64 v[2:3], v[0:1], off --; GFX11-NEXT: s_setpc_b64 s[30:31] -- %a = load <2 x i32>, ptr addrspace(1) %ptra, align 4 -- %freeze = freeze <2 x i32> %a -- store <2 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 -- ret void --} -- --define void @freeze_v3i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { --; GFX10-LABEL: freeze_v3i32: --; GFX10: ; %bb.0: --; GFX10-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX10-NEXT: global_load_dwordx3 v[4:6], v[0:1], off --; GFX10-NEXT: s_waitcnt vmcnt(0) --; GFX10-NEXT: global_store_dwordx3 v[2:3], v[4:6], off --; GFX10-NEXT: s_setpc_b64 s[30:31] --; --; GFX11-LABEL: freeze_v3i32: --; GFX11: ; %bb.0: --; GFX11-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX11-NEXT: global_load_b96 v[4:6], v[0:1], off --; GFX11-NEXT: s_waitcnt vmcnt(0) --; GFX11-NEXT: global_store_b96 v[2:3], v[4:6], off --; GFX11-NEXT: s_setpc_b64 s[30:31] -- %a = load <3 x i32>, ptr addrspace(1) %ptra, align 4 -- %freeze = freeze <3 x i32> %a -- store <3 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 -- ret void --} -- --define void @freeze_v4i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { --; GFX10-LABEL: freeze_v4i32: --; GFX10: ; %bb.0: --; GFX10-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX10-NEXT: global_load_dwordx4 v[4:7], v[0:1], off --; GFX10-NEXT: s_waitcnt vmcnt(0) --; GFX10-NEXT: global_store_dwordx4 v[2:3], v[4:7], off --; GFX10-NEXT: s_setpc_b64 s[30:31] --; --; GFX11-LABEL: freeze_v4i32: --; GFX11: ; %bb.0: --; GFX11-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX11-NEXT: global_load_b128 v[4:7], v[0:1], off --; GFX11-NEXT: s_waitcnt vmcnt(0) --; GFX11-NEXT: global_store_b128 v[2:3], v[4:7], off --; GFX11-NEXT: s_setpc_b64 s[30:31] -- %a = load <4 x i32>, ptr addrspace(1) %ptra, align 4 -- %freeze = freeze <4 x i32> %a -- store <4 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 -- ret void --} -- --define void @freeze_v5i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { --; GFX10-SDAG-LABEL: freeze_v5i32: --; GFX10-SDAG: ; %bb.0: --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX10-SDAG-NEXT: s_clause 0x1 --; GFX10-SDAG-NEXT: global_load_dword v8, v[0:1], off offset:16 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) --; GFX10-SDAG-NEXT: global_store_dword v[2:3], v8, off offset:16 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off --; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] --; --; GFX10-GISEL-LABEL: freeze_v5i32: --; GFX10-GISEL: ; %bb.0: --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX10-GISEL-NEXT: s_clause 0x1 --; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off --; GFX10-GISEL-NEXT: global_load_dword v8, v[0:1], off offset:16 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) --; GFX10-GISEL-NEXT: global_store_dword v[2:3], v8, off offset:16 --; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] --; --; GFX11-SDAG-LABEL: freeze_v5i32: --; GFX11-SDAG: ; %bb.0: --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX11-SDAG-NEXT: s_clause 0x1 --; GFX11-SDAG-NEXT: global_load_b32 v8, v[0:1], off offset:16 --; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) --; GFX11-SDAG-NEXT: global_store_b32 v[2:3], v8, off offset:16 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off --; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] --; --; GFX11-GISEL-LABEL: freeze_v5i32: --; GFX11-GISEL: ; %bb.0: --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX11-GISEL-NEXT: s_clause 0x1 --; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off --; GFX11-GISEL-NEXT: global_load_b32 v0, v[0:1], off offset:16 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) --; GFX11-GISEL-NEXT: global_store_b32 v[2:3], v0, off offset:16 --; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] -- %a = load <5 x i32>, ptr addrspace(1) %ptra, align 4 -- %freeze = freeze <5 x i32> %a -- store <5 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 -- ret void --} -- --define void @freeze_v6i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { --; GFX10-SDAG-LABEL: freeze_v6i32: --; GFX10-SDAG: ; %bb.0: --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX10-SDAG-NEXT: s_clause 0x1 --; GFX10-SDAG-NEXT: global_load_dwordx2 v[8:9], v[0:1], off offset:16 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) --; GFX10-SDAG-NEXT: global_store_dwordx2 v[2:3], v[8:9], off offset:16 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off --; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] --; --; GFX10-GISEL-LABEL: freeze_v6i32: --; GFX10-GISEL: ; %bb.0: --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX10-GISEL-NEXT: s_clause 0x1 --; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off --; GFX10-GISEL-NEXT: global_load_dwordx2 v[8:9], v[0:1], off offset:16 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) --; GFX10-GISEL-NEXT: global_store_dwordx2 v[2:3], v[8:9], off offset:16 --; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] --; --; GFX11-SDAG-LABEL: freeze_v6i32: --; GFX11-SDAG: ; %bb.0: --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX11-SDAG-NEXT: s_clause 0x1 --; GFX11-SDAG-NEXT: global_load_b64 v[8:9], v[0:1], off offset:16 --; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) --; GFX11-SDAG-NEXT: global_store_b64 v[2:3], v[8:9], off offset:16 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off --; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] --; --; GFX11-GISEL-LABEL: freeze_v6i32: --; GFX11-GISEL: ; %bb.0: --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX11-GISEL-NEXT: s_clause 0x1 --; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off --; GFX11-GISEL-NEXT: global_load_b64 v[0:1], v[0:1], off offset:16 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) --; GFX11-GISEL-NEXT: global_store_b64 v[2:3], v[0:1], off offset:16 --; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] -- %a = load <6 x i32>, ptr addrspace(1) %ptra, align 4 -- %freeze = freeze <6 x i32> %a -- store <6 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 -- ret void --} -- --define void @freeze_v7i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { --; GFX10-SDAG-LABEL: freeze_v7i32: --; GFX10-SDAG: ; %bb.0: --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX10-SDAG-NEXT: s_clause 0x1 --; GFX10-SDAG-NEXT: global_load_dwordx3 v[8:10], v[0:1], off offset:16 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) --; GFX10-SDAG-NEXT: global_store_dwordx3 v[2:3], v[8:10], off offset:16 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off --; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] --; --; GFX10-GISEL-LABEL: freeze_v7i32: --; GFX10-GISEL: ; %bb.0: --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX10-GISEL-NEXT: s_clause 0x1 --; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off --; GFX10-GISEL-NEXT: global_load_dwordx3 v[8:10], v[0:1], off offset:16 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) --; GFX10-GISEL-NEXT: global_store_dwordx3 v[2:3], v[8:10], off offset:16 --; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] --; --; GFX11-SDAG-LABEL: freeze_v7i32: --; GFX11-SDAG: ; %bb.0: --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX11-SDAG-NEXT: s_clause 0x1 --; GFX11-SDAG-NEXT: global_load_b96 v[8:10], v[0:1], off offset:16 --; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) --; GFX11-SDAG-NEXT: global_store_b96 v[2:3], v[8:10], off offset:16 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off --; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] --; --; GFX11-GISEL-LABEL: freeze_v7i32: --; GFX11-GISEL: ; %bb.0: --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX11-GISEL-NEXT: s_clause 0x1 --; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off --; GFX11-GISEL-NEXT: global_load_b96 v[8:10], v[0:1], off offset:16 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) --; GFX11-GISEL-NEXT: global_store_b96 v[2:3], v[8:10], off offset:16 --; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] -- %a = load <7 x i32>, ptr addrspace(1) %ptra, align 4 -- %freeze = freeze <7 x i32> %a -- store <7 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 -- ret void --} -- --define void @freeze_v8i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { --; GFX10-SDAG-LABEL: freeze_v8i32: --; GFX10-SDAG: ; %bb.0: --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX10-SDAG-NEXT: s_clause 0x1 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:16 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:16 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off --; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] --; --; GFX10-GISEL-LABEL: freeze_v8i32: --; GFX10-GISEL: ; %bb.0: --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX10-GISEL-NEXT: s_clause 0x1 --; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off --; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 --; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] --; --; GFX11-SDAG-LABEL: freeze_v8i32: --; GFX11-SDAG: ; %bb.0: --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX11-SDAG-NEXT: s_clause 0x1 --; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:16 --; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:16 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off --; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] --; --; GFX11-GISEL-LABEL: freeze_v8i32: --; GFX11-GISEL: ; %bb.0: --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX11-GISEL-NEXT: s_clause 0x1 --; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off --; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 --; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] -- %a = load <8 x i32>, ptr addrspace(1) %ptra, align 4 -- %freeze = freeze <8 x i32> %a -- store <8 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 -- ret void --} -- --define void @freeze_v9i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { --; GFX10-SDAG-LABEL: freeze_v9i32: --; GFX10-SDAG: ; %bb.0: --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX10-SDAG-NEXT: s_clause 0x2 --; GFX10-SDAG-NEXT: global_load_dword v12, v[0:1], off offset:32 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off --; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) --; GFX10-SDAG-NEXT: global_store_dword v[2:3], v12, off offset:32 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 --; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] --; --; GFX10-GISEL-LABEL: freeze_v9i32: --; GFX10-GISEL: ; %bb.0: --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX10-GISEL-NEXT: s_clause 0x2 --; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off --; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 --; GFX10-GISEL-NEXT: global_load_dword v12, v[0:1], off offset:32 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) --; GFX10-GISEL-NEXT: global_store_dword v[2:3], v12, off offset:32 --; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] --; --; GFX11-SDAG-LABEL: freeze_v9i32: --; GFX11-SDAG: ; %bb.0: --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX11-SDAG-NEXT: s_clause 0x2 --; GFX11-SDAG-NEXT: global_load_b32 v12, v[0:1], off offset:32 --; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off --; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) --; GFX11-SDAG-NEXT: global_store_b32 v[2:3], v12, off offset:32 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 --; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] --; --; GFX11-GISEL-LABEL: freeze_v9i32: --; GFX11-GISEL: ; %bb.0: --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX11-GISEL-NEXT: s_clause 0x2 --; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off --; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 --; GFX11-GISEL-NEXT: global_load_b32 v0, v[0:1], off offset:32 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) --; GFX11-GISEL-NEXT: global_store_b32 v[2:3], v0, off offset:32 --; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] -- %a = load <9 x i32>, ptr addrspace(1) %ptra, align 4 -- %freeze = freeze <9 x i32> %a -- store <9 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 -- ret void --} -- --define void @freeze_v10i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { --; GFX10-LABEL: freeze_v10i32: --; GFX10: ; %bb.0: --; GFX10-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX10-NEXT: s_clause 0x2 --; GFX10-NEXT: global_load_dwordx4 v[4:7], v[0:1], off --; GFX10-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 --; GFX10-NEXT: global_load_dwordx2 v[12:13], v[0:1], off offset:32 --; GFX10-NEXT: s_waitcnt vmcnt(2) --; GFX10-NEXT: global_store_dwordx4 v[2:3], v[4:7], off --; GFX10-NEXT: s_waitcnt vmcnt(1) --; GFX10-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 --; GFX10-NEXT: s_waitcnt vmcnt(0) --; GFX10-NEXT: global_store_dwordx2 v[2:3], v[12:13], off offset:32 --; GFX10-NEXT: s_setpc_b64 s[30:31] --; --; GFX11-LABEL: freeze_v10i32: --; GFX11: ; %bb.0: --; GFX11-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX11-NEXT: s_clause 0x2 --; GFX11-NEXT: global_load_b128 v[4:7], v[0:1], off --; GFX11-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 --; GFX11-NEXT: global_load_b64 v[0:1], v[0:1], off offset:32 --; GFX11-NEXT: s_waitcnt vmcnt(2) --; GFX11-NEXT: global_store_b128 v[2:3], v[4:7], off --; GFX11-NEXT: s_waitcnt vmcnt(1) --; GFX11-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 --; GFX11-NEXT: s_waitcnt vmcnt(0) --; GFX11-NEXT: global_store_b64 v[2:3], v[0:1], off offset:32 --; GFX11-NEXT: s_setpc_b64 s[30:31] -- %a = load <10 x i32>, ptr addrspace(1) %ptra, align 4 -- %freeze = freeze <10 x i32> %a -- store <10 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 -- ret void --} -- --define void @freeze_v11i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { --; GFX10-LABEL: freeze_v11i32: --; GFX10: ; %bb.0: --; GFX10-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX10-NEXT: s_clause 0x2 --; GFX10-NEXT: global_load_dwordx4 v[4:7], v[0:1], off --; GFX10-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 --; GFX10-NEXT: global_load_dwordx3 v[12:14], v[0:1], off offset:32 --; GFX10-NEXT: s_waitcnt vmcnt(2) --; GFX10-NEXT: global_store_dwordx4 v[2:3], v[4:7], off --; GFX10-NEXT: s_waitcnt vmcnt(1) --; GFX10-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 --; GFX10-NEXT: s_waitcnt vmcnt(0) --; GFX10-NEXT: global_store_dwordx3 v[2:3], v[12:14], off offset:32 --; GFX10-NEXT: s_setpc_b64 s[30:31] --; --; GFX11-LABEL: freeze_v11i32: --; GFX11: ; %bb.0: --; GFX11-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX11-NEXT: s_clause 0x2 --; GFX11-NEXT: global_load_b128 v[4:7], v[0:1], off --; GFX11-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 --; GFX11-NEXT: global_load_b96 v[12:14], v[0:1], off offset:32 --; GFX11-NEXT: s_waitcnt vmcnt(2) --; GFX11-NEXT: global_store_b128 v[2:3], v[4:7], off --; GFX11-NEXT: s_waitcnt vmcnt(1) --; GFX11-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 --; GFX11-NEXT: s_waitcnt vmcnt(0) --; GFX11-NEXT: global_store_b96 v[2:3], v[12:14], off offset:32 --; GFX11-NEXT: s_setpc_b64 s[30:31] -- %a = load <11 x i32>, ptr addrspace(1) %ptra, align 4 -- %freeze = freeze <11 x i32> %a -- store <11 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 -- ret void --} -- --define void @freeze_v12i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { --; GFX10-LABEL: freeze_v12i32: --; GFX10: ; %bb.0: --; GFX10-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX10-NEXT: s_clause 0x2 --; GFX10-NEXT: global_load_dwordx4 v[4:7], v[0:1], off --; GFX10-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 --; GFX10-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 --; GFX10-NEXT: s_waitcnt vmcnt(2) --; GFX10-NEXT: global_store_dwordx4 v[2:3], v[4:7], off --; GFX10-NEXT: s_waitcnt vmcnt(1) --; GFX10-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 --; GFX10-NEXT: s_waitcnt vmcnt(0) --; GFX10-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 --; GFX10-NEXT: s_setpc_b64 s[30:31] --; --; GFX11-LABEL: freeze_v12i32: --; GFX11: ; %bb.0: --; GFX11-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX11-NEXT: s_clause 0x2 --; GFX11-NEXT: global_load_b128 v[4:7], v[0:1], off --; GFX11-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 --; GFX11-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 --; GFX11-NEXT: s_waitcnt vmcnt(2) --; GFX11-NEXT: global_store_b128 v[2:3], v[4:7], off --; GFX11-NEXT: s_waitcnt vmcnt(1) --; GFX11-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 --; GFX11-NEXT: s_waitcnt vmcnt(0) --; GFX11-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 --; GFX11-NEXT: s_setpc_b64 s[30:31] -- %a = load <12 x i32>, ptr addrspace(1) %ptra, align 4 -- %freeze = freeze <12 x i32> %a -- store <12 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 -- ret void --} --define void @freeze_v13i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { --; GFX10-SDAG-LABEL: freeze_v13i32: --; GFX10-SDAG: ; %bb.0: --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX10-SDAG-NEXT: s_clause 0x3 --; GFX10-SDAG-NEXT: global_load_dword v16, v[0:1], off offset:48 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:32 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off --; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:16 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) --; GFX10-SDAG-NEXT: global_store_dword v[2:3], v16, off offset:48 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:32 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:16 --; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] --; --; GFX10-GISEL-LABEL: freeze_v13i32: --; GFX10-GISEL: ; %bb.0: --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX10-GISEL-NEXT: s_clause 0x3 --; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off --; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 --; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 --; GFX10-GISEL-NEXT: global_load_dword v16, v[0:1], off offset:48 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) --; GFX10-GISEL-NEXT: global_store_dword v[2:3], v16, off offset:48 --; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] --; --; GFX11-SDAG-LABEL: freeze_v13i32: --; GFX11-SDAG: ; %bb.0: --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX11-SDAG-NEXT: s_clause 0x3 --; GFX11-SDAG-NEXT: global_load_b32 v16, v[0:1], off offset:48 --; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:32 --; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off --; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off offset:16 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) --; GFX11-SDAG-NEXT: global_store_b32 v[2:3], v16, off offset:48 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:32 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off offset:16 --; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] --; --; GFX11-GISEL-LABEL: freeze_v13i32: --; GFX11-GISEL: ; %bb.0: --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX11-GISEL-NEXT: s_clause 0x3 --; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off --; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 --; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 --; GFX11-GISEL-NEXT: global_load_b32 v0, v[0:1], off offset:48 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) --; GFX11-GISEL-NEXT: global_store_b32 v[2:3], v0, off offset:48 --; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] -- %a = load <13 x i32>, ptr addrspace(1) %ptra, align 4 -- %freeze = freeze <13 x i32> %a -- store <13 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 -- ret void --} -- --define void @freeze_v14i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { --; GFX10-SDAG-LABEL: freeze_v14i32: --; GFX10-SDAG: ; %bb.0: --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX10-SDAG-NEXT: s_clause 0x3 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:32 --; GFX10-SDAG-NEXT: global_load_dwordx2 v[16:17], v[0:1], off offset:48 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off --; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:16 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:32 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) --; GFX10-SDAG-NEXT: global_store_dwordx2 v[2:3], v[16:17], off offset:48 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:16 --; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] --; --; GFX10-GISEL-LABEL: freeze_v14i32: --; GFX10-GISEL: ; %bb.0: --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX10-GISEL-NEXT: s_clause 0x3 --; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off --; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 --; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 --; GFX10-GISEL-NEXT: global_load_dwordx2 v[16:17], v[0:1], off offset:48 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) --; GFX10-GISEL-NEXT: global_store_dwordx2 v[2:3], v[16:17], off offset:48 --; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] --; --; GFX11-SDAG-LABEL: freeze_v14i32: --; GFX11-SDAG: ; %bb.0: --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX11-SDAG-NEXT: s_clause 0x3 --; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:32 --; GFX11-SDAG-NEXT: global_load_b64 v[16:17], v[0:1], off offset:48 --; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off --; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off offset:16 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:32 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) --; GFX11-SDAG-NEXT: global_store_b64 v[2:3], v[16:17], off offset:48 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off offset:16 --; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] --; --; GFX11-GISEL-LABEL: freeze_v14i32: --; GFX11-GISEL: ; %bb.0: --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX11-GISEL-NEXT: s_clause 0x3 --; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off --; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 --; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 --; GFX11-GISEL-NEXT: global_load_b64 v[0:1], v[0:1], off offset:48 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) --; GFX11-GISEL-NEXT: global_store_b64 v[2:3], v[0:1], off offset:48 --; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] -- %a = load <14 x i32>, ptr addrspace(1) %ptra, align 4 -- %freeze = freeze <14 x i32> %a -- store <14 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 -- ret void --} -- --define void @freeze_v15i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { --; GFX10-SDAG-LABEL: freeze_v15i32: --; GFX10-SDAG: ; %bb.0: --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX10-SDAG-NEXT: s_clause 0x3 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:32 --; GFX10-SDAG-NEXT: global_load_dwordx3 v[16:18], v[0:1], off offset:48 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off --; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:16 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:32 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) --; GFX10-SDAG-NEXT: global_store_dwordx3 v[2:3], v[16:18], off offset:48 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:16 --; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] --; --; GFX10-GISEL-LABEL: freeze_v15i32: --; GFX10-GISEL: ; %bb.0: --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX10-GISEL-NEXT: s_clause 0x3 --; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off --; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 --; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 --; GFX10-GISEL-NEXT: global_load_dwordx3 v[16:18], v[0:1], off offset:48 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) --; GFX10-GISEL-NEXT: global_store_dwordx3 v[2:3], v[16:18], off offset:48 --; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] --; --; GFX11-SDAG-LABEL: freeze_v15i32: --; GFX11-SDAG: ; %bb.0: --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX11-SDAG-NEXT: s_clause 0x3 --; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:32 --; GFX11-SDAG-NEXT: global_load_b96 v[16:18], v[0:1], off offset:48 --; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off --; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off offset:16 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:32 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) --; GFX11-SDAG-NEXT: global_store_b96 v[2:3], v[16:18], off offset:48 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off offset:16 --; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] --; --; GFX11-GISEL-LABEL: freeze_v15i32: --; GFX11-GISEL: ; %bb.0: --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX11-GISEL-NEXT: s_clause 0x3 --; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off --; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 --; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 --; GFX11-GISEL-NEXT: global_load_b96 v[16:18], v[0:1], off offset:48 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) --; GFX11-GISEL-NEXT: global_store_b96 v[2:3], v[16:18], off offset:48 --; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] -- %a = load <15 x i32>, ptr addrspace(1) %ptra, align 4 -- %freeze = freeze <15 x i32> %a -- store <15 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 -- ret void --} -- --define void @freeze_v16i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { --; GFX10-SDAG-LABEL: freeze_v16i32: --; GFX10-SDAG: ; %bb.0: --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX10-SDAG-NEXT: s_clause 0x3 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:32 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:48 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off --; GFX10-SDAG-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:16 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:32 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:48 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:16 --; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] --; --; GFX10-GISEL-LABEL: freeze_v16i32: --; GFX10-GISEL: ; %bb.0: --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX10-GISEL-NEXT: s_clause 0x3 --; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off --; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 --; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 --; GFX10-GISEL-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:48 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:48 --; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] --; --; GFX11-SDAG-LABEL: freeze_v16i32: --; GFX11-SDAG: ; %bb.0: --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX11-SDAG-NEXT: s_clause 0x3 --; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:32 --; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off offset:48 --; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off --; GFX11-SDAG-NEXT: global_load_b128 v[16:19], v[0:1], off offset:16 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:32 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off offset:48 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[16:19], off offset:16 --; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] --; --; GFX11-GISEL-LABEL: freeze_v16i32: --; GFX11-GISEL: ; %bb.0: --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX11-GISEL-NEXT: s_clause 0x3 --; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off --; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 --; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 --; GFX11-GISEL-NEXT: global_load_b128 v[16:19], v[0:1], off offset:48 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[16:19], off offset:48 --; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] -- %a = load <16 x i32>, ptr addrspace(1) %ptra, align 4 -- %freeze = freeze <16 x i32> %a -- store <16 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 -- ret void --} -- --define void @freeze_v17i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { --; GFX10-SDAG-LABEL: freeze_v17i32: --; GFX10-SDAG: ; %bb.0: --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX10-SDAG-NEXT: s_clause 0x4 --; GFX10-SDAG-NEXT: global_load_dword v20, v[0:1], off offset:64 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:32 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:48 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off --; GFX10-SDAG-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:16 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(4) --; GFX10-SDAG-NEXT: global_store_dword v[2:3], v20, off offset:64 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:32 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:48 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:16 --; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] --; --; GFX10-GISEL-LABEL: freeze_v17i32: --; GFX10-GISEL: ; %bb.0: --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX10-GISEL-NEXT: s_clause 0x4 --; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off --; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 --; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 --; GFX10-GISEL-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:48 --; GFX10-GISEL-NEXT: global_load_dword v20, v[0:1], off offset:64 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(4) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:48 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) --; GFX10-GISEL-NEXT: global_store_dword v[2:3], v20, off offset:64 --; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] --; --; GFX11-SDAG-LABEL: freeze_v17i32: --; GFX11-SDAG: ; %bb.0: --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX11-SDAG-NEXT: s_clause 0x4 --; GFX11-SDAG-NEXT: global_load_b32 v20, v[0:1], off offset:64 --; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:32 --; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off offset:48 --; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off --; GFX11-SDAG-NEXT: global_load_b128 v[16:19], v[0:1], off offset:16 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(4) --; GFX11-SDAG-NEXT: global_store_b32 v[2:3], v20, off offset:64 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:32 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off offset:48 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[16:19], off offset:16 --; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] --; --; GFX11-GISEL-LABEL: freeze_v17i32: --; GFX11-GISEL: ; %bb.0: --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX11-GISEL-NEXT: s_clause 0x4 --; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off --; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 --; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 --; GFX11-GISEL-NEXT: global_load_b128 v[16:19], v[0:1], off offset:48 --; GFX11-GISEL-NEXT: global_load_b32 v0, v[0:1], off offset:64 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(4) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[16:19], off offset:48 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) --; GFX11-GISEL-NEXT: global_store_b32 v[2:3], v0, off offset:64 --; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] -- %a = load <17 x i32>, ptr addrspace(1) %ptra, align 4 -- %freeze = freeze <17 x i32> %a -- store <17 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 -- ret void --} -- --define void @freeze_v18i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { --; GFX10-SDAG-LABEL: freeze_v18i32: --; GFX10-SDAG: ; %bb.0: --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX10-SDAG-NEXT: s_clause 0x4 --; GFX10-SDAG-NEXT: global_load_dwordx2 v[20:21], v[0:1], off offset:64 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:32 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:48 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off --; GFX10-SDAG-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:16 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(4) --; GFX10-SDAG-NEXT: global_store_dwordx2 v[2:3], v[20:21], off offset:64 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:32 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:48 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:16 --; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] --; --; GFX10-GISEL-LABEL: freeze_v18i32: --; GFX10-GISEL: ; %bb.0: --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX10-GISEL-NEXT: s_clause 0x4 --; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off --; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 --; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 --; GFX10-GISEL-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:48 --; GFX10-GISEL-NEXT: global_load_dwordx2 v[20:21], v[0:1], off offset:64 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(4) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:48 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) --; GFX10-GISEL-NEXT: global_store_dwordx2 v[2:3], v[20:21], off offset:64 --; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] --; --; GFX11-SDAG-LABEL: freeze_v18i32: --; GFX11-SDAG: ; %bb.0: --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX11-SDAG-NEXT: s_clause 0x4 --; GFX11-SDAG-NEXT: global_load_b64 v[20:21], v[0:1], off offset:64 --; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:32 --; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off offset:48 --; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off --; GFX11-SDAG-NEXT: global_load_b128 v[16:19], v[0:1], off offset:16 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(4) --; GFX11-SDAG-NEXT: global_store_b64 v[2:3], v[20:21], off offset:64 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:32 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off offset:48 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[16:19], off offset:16 --; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] --; --; GFX11-GISEL-LABEL: freeze_v18i32: --; GFX11-GISEL: ; %bb.0: --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX11-GISEL-NEXT: s_clause 0x4 --; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off --; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 --; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 --; GFX11-GISEL-NEXT: global_load_b128 v[16:19], v[0:1], off offset:48 --; GFX11-GISEL-NEXT: global_load_b64 v[0:1], v[0:1], off offset:64 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(4) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[16:19], off offset:48 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) --; GFX11-GISEL-NEXT: global_store_b64 v[2:3], v[0:1], off offset:64 --; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] -- %a = load <18 x i32>, ptr addrspace(1) %ptra, align 4 -- %freeze = freeze <18 x i32> %a -- store <18 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 -- ret void --} -- --define void @freeze_v19i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { --; GFX10-SDAG-LABEL: freeze_v19i32: --; GFX10-SDAG: ; %bb.0: --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX10-SDAG-NEXT: s_clause 0x4 --; GFX10-SDAG-NEXT: global_load_dwordx3 v[20:22], v[0:1], off offset:64 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:32 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:48 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off --; GFX10-SDAG-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:16 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(4) --; GFX10-SDAG-NEXT: global_store_dwordx3 v[2:3], v[20:22], off offset:64 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:32 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:48 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:16 --; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] --; --; GFX10-GISEL-LABEL: freeze_v19i32: --; GFX10-GISEL: ; %bb.0: --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX10-GISEL-NEXT: s_clause 0x4 --; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off --; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 --; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 --; GFX10-GISEL-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:48 --; GFX10-GISEL-NEXT: global_load_dwordx3 v[20:22], v[0:1], off offset:64 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(4) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:48 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) --; GFX10-GISEL-NEXT: global_store_dwordx3 v[2:3], v[20:22], off offset:64 --; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] --; --; GFX11-SDAG-LABEL: freeze_v19i32: --; GFX11-SDAG: ; %bb.0: --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX11-SDAG-NEXT: s_clause 0x4 --; GFX11-SDAG-NEXT: global_load_b96 v[20:22], v[0:1], off offset:64 --; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:32 --; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off offset:48 --; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off --; GFX11-SDAG-NEXT: global_load_b128 v[16:19], v[0:1], off offset:16 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(4) --; GFX11-SDAG-NEXT: global_store_b96 v[2:3], v[20:22], off offset:64 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:32 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off offset:48 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[16:19], off offset:16 --; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] --; --; GFX11-GISEL-LABEL: freeze_v19i32: --; GFX11-GISEL: ; %bb.0: --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX11-GISEL-NEXT: s_clause 0x4 --; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off --; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 --; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 --; GFX11-GISEL-NEXT: global_load_b128 v[16:19], v[0:1], off offset:48 --; GFX11-GISEL-NEXT: global_load_b96 v[20:22], v[0:1], off offset:64 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(4) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[16:19], off offset:48 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) --; GFX11-GISEL-NEXT: global_store_b96 v[2:3], v[20:22], off offset:64 --; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] -- %a = load <19 x i32>, ptr addrspace(1) %ptra, align 4 -- %freeze = freeze <19 x i32> %a -- store <19 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 -- ret void --} -- --define void @freeze_v20i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { --; GFX10-SDAG-LABEL: freeze_v20i32: --; GFX10-SDAG: ; %bb.0: --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX10-SDAG-NEXT: s_clause 0x4 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:64 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:32 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:48 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[16:19], v[0:1], off --; GFX10-SDAG-NEXT: global_load_dwordx4 v[20:23], v[0:1], off offset:16 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(4) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:64 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:32 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:48 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[16:19], off --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[20:23], off offset:16 --; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] --; --; GFX10-GISEL-LABEL: freeze_v20i32: --; GFX10-GISEL: ; %bb.0: --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX10-GISEL-NEXT: s_clause 0x4 --; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off --; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 --; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 --; GFX10-GISEL-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:48 --; GFX10-GISEL-NEXT: global_load_dwordx4 v[20:23], v[0:1], off offset:64 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(4) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:48 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[20:23], off offset:64 --; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] --; --; GFX11-SDAG-LABEL: freeze_v20i32: --; GFX11-SDAG: ; %bb.0: --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX11-SDAG-NEXT: s_clause 0x4 --; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:64 --; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off offset:32 --; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off offset:48 --; GFX11-SDAG-NEXT: global_load_b128 v[16:19], v[0:1], off --; GFX11-SDAG-NEXT: global_load_b128 v[20:23], v[0:1], off offset:16 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(4) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:64 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off offset:32 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off offset:48 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[16:19], off --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[20:23], off offset:16 --; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] --; --; GFX11-GISEL-LABEL: freeze_v20i32: --; GFX11-GISEL: ; %bb.0: --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX11-GISEL-NEXT: s_clause 0x4 --; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off --; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 --; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 --; GFX11-GISEL-NEXT: global_load_b128 v[16:19], v[0:1], off offset:48 --; GFX11-GISEL-NEXT: global_load_b128 v[20:23], v[0:1], off offset:64 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(4) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[16:19], off offset:48 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[20:23], off offset:64 --; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] -- %a = load <20 x i32>, ptr addrspace(1) %ptra, align 4 -- %freeze = freeze <20 x i32> %a -- store <20 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 -- ret void --} -- --define void @freeze_v21i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { --; GFX10-SDAG-LABEL: freeze_v21i32: --; GFX10-SDAG: ; %bb.0: --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX10-SDAG-NEXT: s_clause 0x5 --; GFX10-SDAG-NEXT: global_load_dword v24, v[0:1], off offset:80 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:64 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:32 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:48 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[16:19], v[0:1], off --; GFX10-SDAG-NEXT: global_load_dwordx4 v[20:23], v[0:1], off offset:16 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(5) --; GFX10-SDAG-NEXT: global_store_dword v[2:3], v24, off offset:80 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(4) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:64 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:32 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:48 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[16:19], off --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[20:23], off offset:16 --; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] --; --; GFX10-GISEL-LABEL: freeze_v21i32: --; GFX10-GISEL: ; %bb.0: --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX10-GISEL-NEXT: s_clause 0x5 --; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off --; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 --; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 --; GFX10-GISEL-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:48 --; GFX10-GISEL-NEXT: global_load_dwordx4 v[20:23], v[0:1], off offset:64 --; GFX10-GISEL-NEXT: global_load_dword v24, v[0:1], off offset:80 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(5) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(4) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:48 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[20:23], off offset:64 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) --; GFX10-GISEL-NEXT: global_store_dword v[2:3], v24, off offset:80 --; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] --; --; GFX11-SDAG-LABEL: freeze_v21i32: --; GFX11-SDAG: ; %bb.0: --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX11-SDAG-NEXT: s_clause 0x5 --; GFX11-SDAG-NEXT: global_load_b32 v24, v[0:1], off offset:80 --; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:64 --; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off offset:32 --; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off offset:48 --; GFX11-SDAG-NEXT: global_load_b128 v[16:19], v[0:1], off --; GFX11-SDAG-NEXT: global_load_b128 v[20:23], v[0:1], off offset:16 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(5) --; GFX11-SDAG-NEXT: global_store_b32 v[2:3], v24, off offset:80 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(4) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:64 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off offset:32 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off offset:48 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[16:19], off --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[20:23], off offset:16 --; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] --; --; GFX11-GISEL-LABEL: freeze_v21i32: --; GFX11-GISEL: ; %bb.0: --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX11-GISEL-NEXT: s_clause 0x5 --; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off --; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 --; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 --; GFX11-GISEL-NEXT: global_load_b128 v[16:19], v[0:1], off offset:48 --; GFX11-GISEL-NEXT: global_load_b128 v[20:23], v[0:1], off offset:64 --; GFX11-GISEL-NEXT: global_load_b32 v0, v[0:1], off offset:80 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(5) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(4) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[16:19], off offset:48 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[20:23], off offset:64 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) --; GFX11-GISEL-NEXT: global_store_b32 v[2:3], v0, off offset:80 --; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] -- %a = load <21 x i32>, ptr addrspace(1) %ptra, align 4 -- %freeze = freeze <21 x i32> %a -- store <21 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 -- ret void --} -- --define void @freeze_v22i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { --; GFX10-SDAG-LABEL: freeze_v22i32: --; GFX10-SDAG: ; %bb.0: --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX10-SDAG-NEXT: s_clause 0x5 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:64 --; GFX10-SDAG-NEXT: global_load_dwordx2 v[24:25], v[0:1], off offset:80 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:32 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:48 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[16:19], v[0:1], off --; GFX10-SDAG-NEXT: global_load_dwordx4 v[20:23], v[0:1], off offset:16 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(5) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:64 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(4) --; GFX10-SDAG-NEXT: global_store_dwordx2 v[2:3], v[24:25], off offset:80 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:32 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:48 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[16:19], off --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[20:23], off offset:16 --; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] --; --; GFX10-GISEL-LABEL: freeze_v22i32: --; GFX10-GISEL: ; %bb.0: --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX10-GISEL-NEXT: s_clause 0x5 --; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off --; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 --; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 --; GFX10-GISEL-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:48 --; GFX10-GISEL-NEXT: global_load_dwordx4 v[20:23], v[0:1], off offset:64 --; GFX10-GISEL-NEXT: global_load_dwordx2 v[24:25], v[0:1], off offset:80 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(5) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(4) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:48 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[20:23], off offset:64 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) --; GFX10-GISEL-NEXT: global_store_dwordx2 v[2:3], v[24:25], off offset:80 --; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] --; --; GFX11-SDAG-LABEL: freeze_v22i32: --; GFX11-SDAG: ; %bb.0: --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX11-SDAG-NEXT: s_clause 0x5 --; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:64 --; GFX11-SDAG-NEXT: global_load_b64 v[24:25], v[0:1], off offset:80 --; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off offset:32 --; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off offset:48 --; GFX11-SDAG-NEXT: global_load_b128 v[16:19], v[0:1], off --; GFX11-SDAG-NEXT: global_load_b128 v[20:23], v[0:1], off offset:16 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(5) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:64 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(4) --; GFX11-SDAG-NEXT: global_store_b64 v[2:3], v[24:25], off offset:80 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off offset:32 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off offset:48 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[16:19], off --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[20:23], off offset:16 --; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] --; --; GFX11-GISEL-LABEL: freeze_v22i32: --; GFX11-GISEL: ; %bb.0: --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX11-GISEL-NEXT: s_clause 0x5 --; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off --; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 --; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 --; GFX11-GISEL-NEXT: global_load_b128 v[16:19], v[0:1], off offset:48 --; GFX11-GISEL-NEXT: global_load_b128 v[20:23], v[0:1], off offset:64 --; GFX11-GISEL-NEXT: global_load_b64 v[0:1], v[0:1], off offset:80 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(5) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(4) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[16:19], off offset:48 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[20:23], off offset:64 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) --; GFX11-GISEL-NEXT: global_store_b64 v[2:3], v[0:1], off offset:80 --; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] -- %a = load <22 x i32>, ptr addrspace(1) %ptra, align 4 -- %freeze = freeze <22 x i32> %a -- store <22 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 -- ret void --} -- --define void @freeze_v30i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { --; GFX10-SDAG-LABEL: freeze_v30i32: --; GFX10-SDAG: ; %bb.0: --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX10-SDAG-NEXT: s_clause 0x7 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:96 --; GFX10-SDAG-NEXT: global_load_dwordx2 v[32:33], v[0:1], off offset:112 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:64 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:80 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:32 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[20:23], v[0:1], off offset:48 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[24:27], v[0:1], off --; GFX10-SDAG-NEXT: global_load_dwordx4 v[28:31], v[0:1], off offset:16 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(7) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:96 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(6) --; GFX10-SDAG-NEXT: global_store_dwordx2 v[2:3], v[32:33], off offset:112 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(5) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:64 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(4) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:80 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:32 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[20:23], off offset:48 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[24:27], off --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[28:31], off offset:16 --; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] --; --; GFX10-GISEL-LABEL: freeze_v30i32: --; GFX10-GISEL: ; %bb.0: --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX10-GISEL-NEXT: s_clause 0x7 --; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off --; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 --; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 --; GFX10-GISEL-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:48 --; GFX10-GISEL-NEXT: global_load_dwordx4 v[20:23], v[0:1], off offset:64 --; GFX10-GISEL-NEXT: global_load_dwordx4 v[24:27], v[0:1], off offset:80 --; GFX10-GISEL-NEXT: global_load_dwordx4 v[28:31], v[0:1], off offset:96 --; GFX10-GISEL-NEXT: global_load_dwordx2 v[32:33], v[0:1], off offset:112 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(7) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(6) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(5) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(4) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:48 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[20:23], off offset:64 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[24:27], off offset:80 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[28:31], off offset:96 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) --; GFX10-GISEL-NEXT: global_store_dwordx2 v[2:3], v[32:33], off offset:112 --; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] --; --; GFX11-SDAG-LABEL: freeze_v30i32: --; GFX11-SDAG: ; %bb.0: --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX11-SDAG-NEXT: s_clause 0x7 --; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:96 --; GFX11-SDAG-NEXT: global_load_b64 v[32:33], v[0:1], off offset:112 --; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off offset:64 --; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off offset:80 --; GFX11-SDAG-NEXT: global_load_b128 v[16:19], v[0:1], off offset:32 --; GFX11-SDAG-NEXT: global_load_b128 v[20:23], v[0:1], off offset:48 --; GFX11-SDAG-NEXT: global_load_b128 v[24:27], v[0:1], off --; GFX11-SDAG-NEXT: global_load_b128 v[28:31], v[0:1], off offset:16 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(7) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:96 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(6) --; GFX11-SDAG-NEXT: global_store_b64 v[2:3], v[32:33], off offset:112 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(5) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off offset:64 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(4) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off offset:80 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[16:19], off offset:32 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[20:23], off offset:48 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[24:27], off --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[28:31], off offset:16 --; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] --; --; GFX11-GISEL-LABEL: freeze_v30i32: --; GFX11-GISEL: ; %bb.0: --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX11-GISEL-NEXT: s_clause 0x7 --; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off --; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 --; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 --; GFX11-GISEL-NEXT: global_load_b128 v[16:19], v[0:1], off offset:48 --; GFX11-GISEL-NEXT: global_load_b128 v[20:23], v[0:1], off offset:64 --; GFX11-GISEL-NEXT: global_load_b128 v[24:27], v[0:1], off offset:80 --; GFX11-GISEL-NEXT: global_load_b128 v[28:31], v[0:1], off offset:96 --; GFX11-GISEL-NEXT: global_load_b64 v[0:1], v[0:1], off offset:112 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(7) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(6) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(5) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(4) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[16:19], off offset:48 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[20:23], off offset:64 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[24:27], off offset:80 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[28:31], off offset:96 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) --; GFX11-GISEL-NEXT: global_store_b64 v[2:3], v[0:1], off offset:112 --; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] -- %a = load <30 x i32>, ptr addrspace(1) %ptra, align 4 -- %freeze = freeze <30 x i32> %a -- store <30 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 -- ret void --} -- --define void @freeze_v31i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { --; GFX10-SDAG-LABEL: freeze_v31i32: --; GFX10-SDAG: ; %bb.0: --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX10-SDAG-NEXT: s_clause 0x7 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:96 --; GFX10-SDAG-NEXT: global_load_dwordx3 v[32:34], v[0:1], off offset:112 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:64 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:80 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:32 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[20:23], v[0:1], off offset:48 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[24:27], v[0:1], off --; GFX10-SDAG-NEXT: global_load_dwordx4 v[28:31], v[0:1], off offset:16 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(7) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:96 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(6) --; GFX10-SDAG-NEXT: global_store_dwordx3 v[2:3], v[32:34], off offset:112 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(5) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:64 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(4) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:80 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:32 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[20:23], off offset:48 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[24:27], off --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[28:31], off offset:16 --; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] --; --; GFX10-GISEL-LABEL: freeze_v31i32: --; GFX10-GISEL: ; %bb.0: --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX10-GISEL-NEXT: s_clause 0x7 --; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off --; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 --; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 --; GFX10-GISEL-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:48 --; GFX10-GISEL-NEXT: global_load_dwordx4 v[20:23], v[0:1], off offset:64 --; GFX10-GISEL-NEXT: global_load_dwordx4 v[24:27], v[0:1], off offset:80 --; GFX10-GISEL-NEXT: global_load_dwordx4 v[28:31], v[0:1], off offset:96 --; GFX10-GISEL-NEXT: global_load_dwordx3 v[32:34], v[0:1], off offset:112 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(7) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(6) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(5) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(4) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:48 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[20:23], off offset:64 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[24:27], off offset:80 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[28:31], off offset:96 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) --; GFX10-GISEL-NEXT: global_store_dwordx3 v[2:3], v[32:34], off offset:112 --; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] --; --; GFX11-SDAG-LABEL: freeze_v31i32: --; GFX11-SDAG: ; %bb.0: --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX11-SDAG-NEXT: s_clause 0x7 --; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:96 --; GFX11-SDAG-NEXT: global_load_b96 v[32:34], v[0:1], off offset:112 --; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off offset:64 --; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off offset:80 --; GFX11-SDAG-NEXT: global_load_b128 v[16:19], v[0:1], off offset:32 --; GFX11-SDAG-NEXT: global_load_b128 v[20:23], v[0:1], off offset:48 --; GFX11-SDAG-NEXT: global_load_b128 v[24:27], v[0:1], off --; GFX11-SDAG-NEXT: global_load_b128 v[28:31], v[0:1], off offset:16 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(7) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:96 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(6) --; GFX11-SDAG-NEXT: global_store_b96 v[2:3], v[32:34], off offset:112 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(5) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off offset:64 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(4) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off offset:80 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[16:19], off offset:32 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[20:23], off offset:48 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[24:27], off --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[28:31], off offset:16 --; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] --; --; GFX11-GISEL-LABEL: freeze_v31i32: --; GFX11-GISEL: ; %bb.0: --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX11-GISEL-NEXT: s_clause 0x7 --; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off --; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 --; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 --; GFX11-GISEL-NEXT: global_load_b128 v[16:19], v[0:1], off offset:48 --; GFX11-GISEL-NEXT: global_load_b128 v[20:23], v[0:1], off offset:64 --; GFX11-GISEL-NEXT: global_load_b128 v[24:27], v[0:1], off offset:80 --; GFX11-GISEL-NEXT: global_load_b128 v[28:31], v[0:1], off offset:96 --; GFX11-GISEL-NEXT: global_load_b96 v[32:34], v[0:1], off offset:112 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(7) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(6) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(5) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(4) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[16:19], off offset:48 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[20:23], off offset:64 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[24:27], off offset:80 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[28:31], off offset:96 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) --; GFX11-GISEL-NEXT: global_store_b96 v[2:3], v[32:34], off offset:112 --; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] -- %a = load <31 x i32>, ptr addrspace(1) %ptra, align 4 -- %freeze = freeze <31 x i32> %a -- store <31 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 -- ret void --} -- --define void @freeze_v32i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { --; GFX10-SDAG-LABEL: freeze_v32i32: --; GFX10-SDAG: ; %bb.0: --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX10-SDAG-NEXT: s_clause 0x7 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:96 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:112 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:64 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:80 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[20:23], v[0:1], off offset:32 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[24:27], v[0:1], off offset:48 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[28:31], v[0:1], off --; GFX10-SDAG-NEXT: global_load_dwordx4 v[32:35], v[0:1], off offset:16 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(7) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:96 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(6) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:112 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(5) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:64 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(4) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:80 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[20:23], off offset:32 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[24:27], off offset:48 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[28:31], off --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[32:35], off offset:16 --; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] --; --; GFX10-GISEL-LABEL: freeze_v32i32: --; GFX10-GISEL: ; %bb.0: --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX10-GISEL-NEXT: s_clause 0x7 --; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off --; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 --; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 --; GFX10-GISEL-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:48 --; GFX10-GISEL-NEXT: global_load_dwordx4 v[20:23], v[0:1], off offset:64 --; GFX10-GISEL-NEXT: global_load_dwordx4 v[24:27], v[0:1], off offset:80 --; GFX10-GISEL-NEXT: global_load_dwordx4 v[28:31], v[0:1], off offset:96 --; GFX10-GISEL-NEXT: global_load_dwordx4 v[32:35], v[0:1], off offset:112 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(7) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(6) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(5) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(4) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:48 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[20:23], off offset:64 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[24:27], off offset:80 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[28:31], off offset:96 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[32:35], off offset:112 --; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] --; --; GFX11-SDAG-LABEL: freeze_v32i32: --; GFX11-SDAG: ; %bb.0: --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX11-SDAG-NEXT: s_clause 0x7 --; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:96 --; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off offset:112 --; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off offset:64 --; GFX11-SDAG-NEXT: global_load_b128 v[16:19], v[0:1], off offset:80 --; GFX11-SDAG-NEXT: global_load_b128 v[20:23], v[0:1], off offset:32 --; GFX11-SDAG-NEXT: global_load_b128 v[24:27], v[0:1], off offset:48 --; GFX11-SDAG-NEXT: global_load_b128 v[28:31], v[0:1], off --; GFX11-SDAG-NEXT: global_load_b128 v[32:35], v[0:1], off offset:16 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(7) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:96 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(6) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off offset:112 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(5) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off offset:64 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(4) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[16:19], off offset:80 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[20:23], off offset:32 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[24:27], off offset:48 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[28:31], off --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[32:35], off offset:16 --; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] --; --; GFX11-GISEL-LABEL: freeze_v32i32: --; GFX11-GISEL: ; %bb.0: --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX11-GISEL-NEXT: s_clause 0x7 --; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off --; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 --; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 --; GFX11-GISEL-NEXT: global_load_b128 v[16:19], v[0:1], off offset:48 --; GFX11-GISEL-NEXT: global_load_b128 v[20:23], v[0:1], off offset:64 --; GFX11-GISEL-NEXT: global_load_b128 v[24:27], v[0:1], off offset:80 --; GFX11-GISEL-NEXT: global_load_b128 v[28:31], v[0:1], off offset:96 --; GFX11-GISEL-NEXT: global_load_b128 v[32:35], v[0:1], off offset:112 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(7) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(6) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(5) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(4) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[16:19], off offset:48 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[20:23], off offset:64 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[24:27], off offset:80 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[28:31], off offset:96 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[32:35], off offset:112 --; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] -- %a = load <32 x i32>, ptr addrspace(1) %ptra, align 4 -- %freeze = freeze <32 x i32> %a -- store <32 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 -- ret void --} -- --define void @freeze_i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { --; GFX10-LABEL: freeze_i32: --; GFX10: ; %bb.0: --; GFX10-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX10-NEXT: global_load_dword v0, v[0:1], off --; GFX10-NEXT: s_waitcnt vmcnt(0) --; GFX10-NEXT: global_store_dword v[2:3], v0, off --; GFX10-NEXT: s_setpc_b64 s[30:31] --; --; GFX11-LABEL: freeze_i32: --; GFX11: ; %bb.0: --; GFX11-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX11-NEXT: global_load_b32 v0, v[0:1], off --; GFX11-NEXT: s_waitcnt vmcnt(0) --; GFX11-NEXT: global_store_b32 v[2:3], v0, off --; GFX11-NEXT: s_setpc_b64 s[30:31] -- %a = load i32, ptr addrspace(1) %ptra, align 4 -- %freeze = freeze i32 %a -- store i32 %freeze, ptr addrspace(1) %ptrb, align 4 -- ret void --} -- --define void @freeze_i64(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { --; GFX10-LABEL: freeze_i64: --; GFX10: ; %bb.0: --; GFX10-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX10-NEXT: global_load_dwordx2 v[0:1], v[0:1], off --; GFX10-NEXT: s_waitcnt vmcnt(0) --; GFX10-NEXT: global_store_dwordx2 v[2:3], v[0:1], off --; GFX10-NEXT: s_setpc_b64 s[30:31] --; --; GFX11-LABEL: freeze_i64: --; GFX11: ; %bb.0: --; GFX11-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX11-NEXT: global_load_b64 v[0:1], v[0:1], off --; GFX11-NEXT: s_waitcnt vmcnt(0) --; GFX11-NEXT: global_store_b64 v[2:3], v[0:1], off --; GFX11-NEXT: s_setpc_b64 s[30:31] -- %a = load i64, ptr addrspace(1) %ptra, align 4 -- %freeze = freeze i64 %a -- store i64 %freeze, ptr addrspace(1) %ptrb, align 4 -- ret void --} -- --define void @freeze_float(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { --; GFX10-LABEL: freeze_float: --; GFX10: ; %bb.0: --; GFX10-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX10-NEXT: global_load_dword v0, v[0:1], off --; GFX10-NEXT: s_waitcnt vmcnt(0) --; GFX10-NEXT: global_store_dword v[2:3], v0, off --; GFX10-NEXT: s_setpc_b64 s[30:31] --; --; GFX11-LABEL: freeze_float: --; GFX11: ; %bb.0: --; GFX11-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX11-NEXT: global_load_b32 v0, v[0:1], off --; GFX11-NEXT: s_waitcnt vmcnt(0) --; GFX11-NEXT: global_store_b32 v[2:3], v0, off --; GFX11-NEXT: s_setpc_b64 s[30:31] -- %a = load float, ptr addrspace(1) %ptra, align 4 -- %freeze = freeze float %a -- store float %freeze, ptr addrspace(1) %ptrb, align 4 -- ret void --} -- --define void @freeze_i128(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { --; GFX10-LABEL: freeze_i128: --; GFX10: ; %bb.0: --; GFX10-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX10-NEXT: global_load_dwordx4 v[4:7], v[0:1], off --; GFX10-NEXT: s_waitcnt vmcnt(0) --; GFX10-NEXT: global_store_dwordx4 v[2:3], v[4:7], off --; GFX10-NEXT: s_setpc_b64 s[30:31] --; --; GFX11-LABEL: freeze_i128: --; GFX11: ; %bb.0: --; GFX11-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX11-NEXT: global_load_b128 v[4:7], v[0:1], off --; GFX11-NEXT: s_waitcnt vmcnt(0) --; GFX11-NEXT: global_store_b128 v[2:3], v[4:7], off --; GFX11-NEXT: s_setpc_b64 s[30:31] -- %a = load i128, ptr addrspace(1) %ptra, align 4 -- %freeze = freeze i128 %a -- store i128 %freeze, ptr addrspace(1) %ptrb, align 4 -- ret void --} -- --define void @freeze_i256(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { --; GFX10-SDAG-LABEL: freeze_i256: --; GFX10-SDAG: ; %bb.0: --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX10-SDAG-NEXT: s_clause 0x1 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:16 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:16 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off --; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] --; --; GFX10-GISEL-LABEL: freeze_i256: --; GFX10-GISEL: ; %bb.0: --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX10-GISEL-NEXT: s_clause 0x1 --; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off --; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 --; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] --; --; GFX11-SDAG-LABEL: freeze_i256: --; GFX11-SDAG: ; %bb.0: --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX11-SDAG-NEXT: s_clause 0x1 --; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:16 --; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:16 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off --; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] --; --; GFX11-GISEL-LABEL: freeze_i256: --; GFX11-GISEL: ; %bb.0: --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX11-GISEL-NEXT: s_clause 0x1 --; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off --; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 --; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] -- %a = load i256, ptr addrspace(1) %ptra, align 4 -- %freeze = freeze i256 %a -- store i256 %freeze, ptr addrspace(1) %ptrb, align 4 -- ret void --} -diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/AMDGPU/GlobalISel/inst-select-unmerge-values.mir b/llvm/test/CodeGen/AMDGPU/GlobalISel/inst-select-unmerge-values.mir ---- a/llvm/test/CodeGen/AMDGPU/GlobalISel/inst-select-unmerge-values.mir -+++ b/llvm/test/CodeGen/AMDGPU/GlobalISel/inst-select-unmerge-values.mir -@@ -171,9 +171,11 @@ - ; GCN-LABEL: name: test_unmerge_values_s_s64_s_s64_s64_s_s192 - ; GCN: liveins: $sgpr0_sgpr1_sgpr2_sgpr3 - ; GCN-NEXT: {{ $}} -- ; GCN-NEXT: [[DEF:%[0-9]+]]:sgpr(s192) = G_IMPLICIT_DEF -- ; GCN-NEXT: [[UV:%[0-9]+]]:sgpr(s64), [[UV1:%[0-9]+]]:sgpr(s64), [[UV2:%[0-9]+]]:sgpr(s64) = G_UNMERGE_VALUES [[DEF]](s192) -- ; GCN-NEXT: S_ENDPGM 0, implicit [[UV]](s64), implicit [[UV1]](s64), implicit [[UV2]](s64) -+ ; GCN-NEXT: [[DEF:%[0-9]+]]:sgpr_192 = IMPLICIT_DEF -+ ; GCN-NEXT: [[COPY:%[0-9]+]]:sreg_64 = COPY [[DEF]].sub0_sub1 -+ ; GCN-NEXT: [[COPY1:%[0-9]+]]:sreg_64 = COPY [[DEF]].sub2_sub3 -+ ; GCN-NEXT: [[COPY2:%[0-9]+]]:sreg_64 = COPY [[DEF]].sub4_sub5 -+ ; GCN-NEXT: S_ENDPGM 0, implicit [[COPY]], implicit [[COPY1]], implicit [[COPY2]] - %0:sgpr(s192) = G_IMPLICIT_DEF - %1:sgpr(s64), %2:sgpr(s64), %3:sgpr(s64) = G_UNMERGE_VALUES %0 - S_ENDPGM 0, implicit %1, implicit %2, implicit %3 -@@ -292,11 +294,11 @@ - ; GCN-NEXT: [[CONCAT_VECTORS:%[0-9]+]]:sgpr_384(<12 x s32>) = G_CONCAT_VECTORS [[COPY]](<3 x s32>), [[COPY1]](<3 x s32>), [[COPY2]](<3 x s32>), [[COPY3]](<3 x s32>) - ; GCN-NEXT: [[COPY4:%[0-9]+]]:sgpr_96(<3 x s32>) = COPY [[CONCAT_VECTORS]].sub0_sub1_sub2(<12 x s32>) - ; GCN-NEXT: [[COPY5:%[0-9]+]]:sgpr_96(<3 x s32>) = COPY [[CONCAT_VECTORS]].sub3_sub4_sub5(<12 x s32>) -- ; GCN-NEXT: [[COPY4:%[0-9]+]]:sgpr_96(<3 x s32>), [[COPY5:%[0-9]+]]:sgpr_96(<3 x s32>), [[UV:%[0-9]+]]:sgpr_96(<3 x s32>), [[UV1:%[0-9]+]]:sgpr_96(<3 x s32>) = G_UNMERGE_VALUES [[CONCAT_VECTORS]](<12 x s32>) -- ; GCN-NEXT: $sgpr0_sgpr1_sgpr2 = COPY [[COPY4]](<3 x s32>) -- ; GCN-NEXT: $sgpr4_sgpr5_sgpr6 = COPY [[COPY5]](<3 x s32>) -- ; GCN-NEXT: $sgpr8_sgpr9_sgpr10 = COPY [[UV]](<3 x s32>) -- ; GCN-NEXT: $sgpr12_sgpr13_sgpr14 = COPY [[UV1]](<3 x s32>) -+ ; GCN-NEXT: [[UV:%[0-9]+]]:sgpr_96(<3 x s32>), [[UV1:%[0-9]+]]:sgpr_96(<3 x s32>), [[UV2:%[0-9]+]]:sgpr_96(<3 x s32>), [[UV3:%[0-9]+]]:sgpr_96(<3 x s32>) = G_UNMERGE_VALUES [[CONCAT_VECTORS]](<12 x s32>) -+ ; GCN-NEXT: $sgpr0_sgpr1_sgpr2 = COPY [[UV]](<3 x s32>) -+ ; GCN-NEXT: $sgpr4_sgpr5_sgpr6 = COPY [[UV1]](<3 x s32>) -+ ; GCN-NEXT: $sgpr8_sgpr9_sgpr10 = COPY [[UV2]](<3 x s32>) -+ ; GCN-NEXT: $sgpr12_sgpr13_sgpr14 = COPY [[UV3]](<3 x s32>) - %0:sgpr(<3 x s32>) = COPY $sgpr0_sgpr1_sgpr2 - %1:sgpr(<3 x s32>) = COPY $sgpr4_sgpr5_sgpr6 - %2:sgpr(<3 x s32>) = COPY $sgpr8_sgpr9_sgpr10 -diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-freeze.mir b/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-freeze.mir ---- a/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-freeze.mir -+++ b/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-freeze.mir -@@ -171,8 +171,12 @@ - - ; CHECK-LABEL: name: test_freeze_s448 - ; CHECK: [[COPY:%[0-9]+]]:_(s512) = COPY $vgpr0_vgpr1_vgpr2_vgpr3_vgpr4_vgpr5_vgpr6_vgpr7_vgpr8_vgpr9_vgpr10_vgpr11_vgpr12_vgpr13_vgpr14_vgpr15 -- ; CHECK-NEXT: [[FREEZE:%[0-9]+]]:_(s512) = G_FREEZE [[COPY]] -- ; CHECK-NEXT: $vgpr0_vgpr1_vgpr2_vgpr3_vgpr4_vgpr5_vgpr6_vgpr7_vgpr8_vgpr9_vgpr10_vgpr11_vgpr12_vgpr13_vgpr14_vgpr15 = COPY [[FREEZE]](s512) -+ ; CHECK-NEXT: [[TRUNC:%[0-9]+]]:_(s448) = G_TRUNC [[COPY]](s512) -+ ; CHECK-NEXT: [[FREEZE:%[0-9]+]]:_(s448) = G_FREEZE [[TRUNC]] -+ ; CHECK-NEXT: [[UV:%[0-9]+]]:_(s64), [[UV1:%[0-9]+]]:_(s64), [[UV2:%[0-9]+]]:_(s64), [[UV3:%[0-9]+]]:_(s64), [[UV4:%[0-9]+]]:_(s64), [[UV5:%[0-9]+]]:_(s64), [[UV6:%[0-9]+]]:_(s64) = G_UNMERGE_VALUES [[FREEZE]](s448) -+ ; CHECK-NEXT: [[DEF:%[0-9]+]]:_(s64) = G_IMPLICIT_DEF -+ ; CHECK-NEXT: [[MV:%[0-9]+]]:_(s512) = G_MERGE_VALUES [[UV]](s64), [[UV1]](s64), [[UV2]](s64), [[UV3]](s64), [[UV4]](s64), [[UV5]](s64), [[UV6]](s64), [[DEF]](s64) -+ ; CHECK-NEXT: $vgpr0_vgpr1_vgpr2_vgpr3_vgpr4_vgpr5_vgpr6_vgpr7_vgpr8_vgpr9_vgpr10_vgpr11_vgpr12_vgpr13_vgpr14_vgpr15 = COPY [[MV]](s512) - %0:_(s512) = COPY $vgpr0_vgpr1_vgpr2_vgpr3_vgpr4_vgpr5_vgpr6_vgpr7_vgpr8_vgpr9_vgpr10_vgpr11_vgpr12_vgpr13_vgpr14_vgpr15 - %1:_(s448) = G_TRUNC %0 - %2:_(s448) = G_FREEZE %1 -@@ -395,12 +399,14 @@ - bb.0: - - ; CHECK-LABEL: name: test_freeze_v33s32 -- ; CHECK: [[DEF:%[0-9]+]]:_(<32 x s32>) = G_IMPLICIT_DEF -+ ; CHECK: [[DEF:%[0-9]+]]:_(<16 x s32>) = G_IMPLICIT_DEF - ; CHECK-NEXT: [[DEF1:%[0-9]+]]:_(s32) = G_IMPLICIT_DEF -- ; CHECK-NEXT: [[FREEZE:%[0-9]+]]:_(<32 x s32>) = G_FREEZE [[DEF]] -- ; CHECK-NEXT: [[FREEZE1:%[0-9]+]]:_(s32) = G_FREEZE [[DEF1]] -- ; CHECK-NEXT: [[UV:%[0-9]+]]:_(s32), [[UV1:%[0-9]+]]:_(s32), [[UV2:%[0-9]+]]:_(s32), [[UV3:%[0-9]+]]:_(s32), [[UV4:%[0-9]+]]:_(s32), [[UV5:%[0-9]+]]:_(s32), [[UV6:%[0-9]+]]:_(s32), [[UV7:%[0-9]+]]:_(s32), [[UV8:%[0-9]+]]:_(s32), [[UV9:%[0-9]+]]:_(s32), [[UV10:%[0-9]+]]:_(s32), [[UV11:%[0-9]+]]:_(s32), [[UV12:%[0-9]+]]:_(s32), [[UV13:%[0-9]+]]:_(s32), [[UV14:%[0-9]+]]:_(s32), [[UV15:%[0-9]+]]:_(s32), [[UV16:%[0-9]+]]:_(s32), [[UV17:%[0-9]+]]:_(s32), [[UV18:%[0-9]+]]:_(s32), [[UV19:%[0-9]+]]:_(s32), [[UV20:%[0-9]+]]:_(s32), [[UV21:%[0-9]+]]:_(s32), [[UV22:%[0-9]+]]:_(s32), [[UV23:%[0-9]+]]:_(s32), [[UV24:%[0-9]+]]:_(s32), [[UV25:%[0-9]+]]:_(s32), [[UV26:%[0-9]+]]:_(s32), [[UV27:%[0-9]+]]:_(s32), [[UV28:%[0-9]+]]:_(s32), [[UV29:%[0-9]+]]:_(s32), [[UV30:%[0-9]+]]:_(s32), [[UV31:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[FREEZE]](<32 x s32>) -- ; CHECK-NEXT: [[BUILD_VECTOR:%[0-9]+]]:_(<33 x s32>) = G_BUILD_VECTOR [[UV]](s32), [[UV1]](s32), [[UV2]](s32), [[UV3]](s32), [[UV4]](s32), [[UV5]](s32), [[UV6]](s32), [[UV7]](s32), [[UV8]](s32), [[UV9]](s32), [[UV10]](s32), [[UV11]](s32), [[UV12]](s32), [[UV13]](s32), [[UV14]](s32), [[UV15]](s32), [[UV16]](s32), [[UV17]](s32), [[UV18]](s32), [[UV19]](s32), [[UV20]](s32), [[UV21]](s32), [[UV22]](s32), [[UV23]](s32), [[UV24]](s32), [[UV25]](s32), [[UV26]](s32), [[UV27]](s32), [[UV28]](s32), [[UV29]](s32), [[UV30]](s32), [[UV31]](s32), [[FREEZE1]](s32) -+ ; CHECK-NEXT: [[FREEZE:%[0-9]+]]:_(<16 x s32>) = G_FREEZE [[DEF]] -+ ; CHECK-NEXT: [[FREEZE1:%[0-9]+]]:_(<16 x s32>) = G_FREEZE [[DEF]] -+ ; CHECK-NEXT: [[FREEZE2:%[0-9]+]]:_(s32) = G_FREEZE [[DEF1]] -+ ; CHECK-NEXT: [[UV:%[0-9]+]]:_(s32), [[UV1:%[0-9]+]]:_(s32), [[UV2:%[0-9]+]]:_(s32), [[UV3:%[0-9]+]]:_(s32), [[UV4:%[0-9]+]]:_(s32), [[UV5:%[0-9]+]]:_(s32), [[UV6:%[0-9]+]]:_(s32), [[UV7:%[0-9]+]]:_(s32), [[UV8:%[0-9]+]]:_(s32), [[UV9:%[0-9]+]]:_(s32), [[UV10:%[0-9]+]]:_(s32), [[UV11:%[0-9]+]]:_(s32), [[UV12:%[0-9]+]]:_(s32), [[UV13:%[0-9]+]]:_(s32), [[UV14:%[0-9]+]]:_(s32), [[UV15:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[FREEZE]](<16 x s32>) -+ ; CHECK-NEXT: [[UV16:%[0-9]+]]:_(s32), [[UV17:%[0-9]+]]:_(s32), [[UV18:%[0-9]+]]:_(s32), [[UV19:%[0-9]+]]:_(s32), [[UV20:%[0-9]+]]:_(s32), [[UV21:%[0-9]+]]:_(s32), [[UV22:%[0-9]+]]:_(s32), [[UV23:%[0-9]+]]:_(s32), [[UV24:%[0-9]+]]:_(s32), [[UV25:%[0-9]+]]:_(s32), [[UV26:%[0-9]+]]:_(s32), [[UV27:%[0-9]+]]:_(s32), [[UV28:%[0-9]+]]:_(s32), [[UV29:%[0-9]+]]:_(s32), [[UV30:%[0-9]+]]:_(s32), [[UV31:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[FREEZE1]](<16 x s32>) -+ ; CHECK-NEXT: [[BUILD_VECTOR:%[0-9]+]]:_(<33 x s32>) = G_BUILD_VECTOR [[UV]](s32), [[UV1]](s32), [[UV2]](s32), [[UV3]](s32), [[UV4]](s32), [[UV5]](s32), [[UV6]](s32), [[UV7]](s32), [[UV8]](s32), [[UV9]](s32), [[UV10]](s32), [[UV11]](s32), [[UV12]](s32), [[UV13]](s32), [[UV14]](s32), [[UV15]](s32), [[UV16]](s32), [[UV17]](s32), [[UV18]](s32), [[UV19]](s32), [[UV20]](s32), [[UV21]](s32), [[UV22]](s32), [[UV23]](s32), [[UV24]](s32), [[UV25]](s32), [[UV26]](s32), [[UV27]](s32), [[UV28]](s32), [[UV29]](s32), [[UV30]](s32), [[UV31]](s32), [[FREEZE2]](s32) - ; CHECK-NEXT: S_NOP 0, implicit [[BUILD_VECTOR]](<33 x s32>) - %0:_(<33 x s32>) = G_IMPLICIT_DEF - %1:_(<33 x s32>) = G_FREEZE %0 -@@ -413,10 +419,12 @@ - bb.0: - - ; CHECK-LABEL: name: test_freeze_v64s32 -- ; CHECK: [[DEF:%[0-9]+]]:_(<32 x s32>) = G_IMPLICIT_DEF -- ; CHECK-NEXT: [[FREEZE:%[0-9]+]]:_(<32 x s32>) = G_FREEZE [[DEF]] -- ; CHECK-NEXT: [[FREEZE1:%[0-9]+]]:_(<32 x s32>) = G_FREEZE [[DEF]] -- ; CHECK-NEXT: [[CONCAT_VECTORS:%[0-9]+]]:_(<64 x s32>) = G_CONCAT_VECTORS [[FREEZE]](<32 x s32>), [[FREEZE1]](<32 x s32>) -+ ; CHECK: [[DEF:%[0-9]+]]:_(<16 x s32>) = G_IMPLICIT_DEF -+ ; CHECK-NEXT: [[FREEZE:%[0-9]+]]:_(<16 x s32>) = G_FREEZE [[DEF]] -+ ; CHECK-NEXT: [[FREEZE1:%[0-9]+]]:_(<16 x s32>) = G_FREEZE [[DEF]] -+ ; CHECK-NEXT: [[FREEZE2:%[0-9]+]]:_(<16 x s32>) = G_FREEZE [[DEF]] -+ ; CHECK-NEXT: [[FREEZE3:%[0-9]+]]:_(<16 x s32>) = G_FREEZE [[DEF]] -+ ; CHECK-NEXT: [[CONCAT_VECTORS:%[0-9]+]]:_(<64 x s32>) = G_CONCAT_VECTORS [[FREEZE]](<16 x s32>), [[FREEZE1]](<16 x s32>), [[FREEZE2]](<16 x s32>), [[FREEZE3]](<16 x s32>) - ; CHECK-NEXT: S_NOP 0, implicit [[CONCAT_VECTORS]](<64 x s32>) - %0:_(<64 x s32>) = G_IMPLICIT_DEF - %1:_(<64 x s32>) = G_FREEZE %0 -diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-implicit-def.mir b/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-implicit-def.mir ---- a/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-implicit-def.mir -+++ b/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-implicit-def.mir -@@ -135,9 +135,8 @@ - bb.0: - - ; CHECK-LABEL: name: test_implicit_def_s448 -- ; CHECK: [[DEF:%[0-9]+]]:_(s512) = G_IMPLICIT_DEF -- ; CHECK-NEXT: [[TRUNC:%[0-9]+]]:_(s448) = G_TRUNC [[DEF]](s512) -- ; CHECK-NEXT: [[EXTRACT:%[0-9]+]]:_(s32) = G_EXTRACT [[TRUNC]](s448), 0 -+ ; CHECK: [[DEF:%[0-9]+]]:_(s448) = G_IMPLICIT_DEF -+ ; CHECK-NEXT: [[EXTRACT:%[0-9]+]]:_(s32) = G_EXTRACT [[DEF]](s448), 0 - ; CHECK-NEXT: $vgpr0 = COPY [[EXTRACT]](s32) - %0:_(s448) = G_IMPLICIT_DEF - %1:_(s32) = G_EXTRACT %0, 0 -@@ -297,6 +296,18 @@ - ... - - --- -+name: test_implicit_def_v17s32 -+body: | -+ bb.0: -+ -+ ; CHECK-LABEL: name: test_implicit_def_v17s32 -+ ; CHECK: [[DEF:%[0-9]+]]:_(<17 x s32>) = G_IMPLICIT_DEF -+ ; CHECK-NEXT: S_NOP 0, implicit [[DEF]](<17 x s32>) -+ %0:_(<17 x s32>) = G_IMPLICIT_DEF -+ S_NOP 0, implicit %0 -+... -+ -+--- - name: test_implicit_def_v32s32 - body: | - bb.0: -@@ -317,9 +328,9 @@ - ; CHECK-LABEL: name: test_implicit_def_v33s32 - ; CHECK: liveins: $vgpr0_vgpr1 - ; CHECK-NEXT: {{ $}} -- ; CHECK-NEXT: [[DEF:%[0-9]+]]:_(<32 x s32>) = G_IMPLICIT_DEF -+ ; CHECK-NEXT: [[DEF:%[0-9]+]]:_(<16 x s32>) = G_IMPLICIT_DEF - ; CHECK-NEXT: [[DEF1:%[0-9]+]]:_(s32) = G_IMPLICIT_DEF -- ; CHECK-NEXT: [[UV:%[0-9]+]]:_(s32), [[UV1:%[0-9]+]]:_(s32), [[UV2:%[0-9]+]]:_(s32), [[UV3:%[0-9]+]]:_(s32), [[UV4:%[0-9]+]]:_(s32), [[UV5:%[0-9]+]]:_(s32), [[UV6:%[0-9]+]]:_(s32), [[UV7:%[0-9]+]]:_(s32), [[UV8:%[0-9]+]]:_(s32), [[UV9:%[0-9]+]]:_(s32), [[UV10:%[0-9]+]]:_(s32), [[UV11:%[0-9]+]]:_(s32), [[UV12:%[0-9]+]]:_(s32), [[UV13:%[0-9]+]]:_(s32), [[UV14:%[0-9]+]]:_(s32), [[UV15:%[0-9]+]]:_(s32), [[UV16:%[0-9]+]]:_(s32), [[UV17:%[0-9]+]]:_(s32), [[UV18:%[0-9]+]]:_(s32), [[UV19:%[0-9]+]]:_(s32), [[UV20:%[0-9]+]]:_(s32), [[UV21:%[0-9]+]]:_(s32), [[UV22:%[0-9]+]]:_(s32), [[UV23:%[0-9]+]]:_(s32), [[UV24:%[0-9]+]]:_(s32), [[UV25:%[0-9]+]]:_(s32), [[UV26:%[0-9]+]]:_(s32), [[UV27:%[0-9]+]]:_(s32), [[UV28:%[0-9]+]]:_(s32), [[UV29:%[0-9]+]]:_(s32), [[UV30:%[0-9]+]]:_(s32), [[UV31:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<32 x s32>) -+ ; CHECK-NEXT: [[UV:%[0-9]+]]:_(s32), [[UV1:%[0-9]+]]:_(s32), [[UV2:%[0-9]+]]:_(s32), [[UV3:%[0-9]+]]:_(s32), [[UV4:%[0-9]+]]:_(s32), [[UV5:%[0-9]+]]:_(s32), [[UV6:%[0-9]+]]:_(s32), [[UV7:%[0-9]+]]:_(s32), [[UV8:%[0-9]+]]:_(s32), [[UV9:%[0-9]+]]:_(s32), [[UV10:%[0-9]+]]:_(s32), [[UV11:%[0-9]+]]:_(s32), [[UV12:%[0-9]+]]:_(s32), [[UV13:%[0-9]+]]:_(s32), [[UV14:%[0-9]+]]:_(s32), [[UV15:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) - ; CHECK-NEXT: [[COPY:%[0-9]+]]:_(p1) = COPY $vgpr0_vgpr1 - ; CHECK-NEXT: G_STORE [[UV]](s32), [[COPY]](p1) :: (volatile store (s32), addrspace 1) - ; CHECK-NEXT: G_STORE [[DEF1]](s32), [[COPY]](p1) :: (volatile store (s32), addrspace 1) -@@ -337,9 +348,10 @@ - bb.0: - - ; CHECK-LABEL: name: test_implicit_def_v64s32 -- ; CHECK: [[DEF:%[0-9]+]]:_(<32 x s32>) = G_IMPLICIT_DEF -- ; CHECK-NEXT: [[CONCAT_VECTORS:%[0-9]+]]:_(<64 x s32>) = G_CONCAT_VECTORS [[DEF]](<32 x s32>), [[DEF]](<32 x s32>) -- ; CHECK-NEXT: S_NOP 0, implicit [[CONCAT_VECTORS]](<64 x s32>), implicit [[DEF]](<32 x s32>) -+ ; CHECK: [[DEF:%[0-9]+]]:_(<16 x s32>) = G_IMPLICIT_DEF -+ ; CHECK-NEXT: [[CONCAT_VECTORS:%[0-9]+]]:_(<64 x s32>) = G_CONCAT_VECTORS [[DEF]](<16 x s32>), [[DEF]](<16 x s32>), [[DEF]](<16 x s32>), [[DEF]](<16 x s32>) -+ ; CHECK-NEXT: [[CONCAT_VECTORS1:%[0-9]+]]:_(<32 x s32>) = G_CONCAT_VECTORS [[DEF]](<16 x s32>), [[DEF]](<16 x s32>) -+ ; CHECK-NEXT: S_NOP 0, implicit [[CONCAT_VECTORS]](<64 x s32>), implicit [[CONCAT_VECTORS1]](<32 x s32>) - %0:_(<64 x s32>) = G_IMPLICIT_DEF - %1:_(<32 x s32>), %2:_(<32 x s32>) = G_UNMERGE_VALUES %0 - S_NOP 0, implicit %0, implicit %1 -diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-insert-vector-elt.mir b/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-insert-vector-elt.mir ---- a/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-insert-vector-elt.mir -+++ b/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-insert-vector-elt.mir -@@ -190,11 +190,13 @@ - ; CHECK-LABEL: name: insert_vector_elt_64_65_v64s32 - ; CHECK: liveins: $sgpr0_sgpr1, $vgpr0_vgpr1, $vgpr2_vgpr3 - ; CHECK-NEXT: {{ $}} -- ; CHECK-NEXT: [[DEF:%[0-9]+]]:_(<32 x s32>) = G_IMPLICIT_DEF -+ ; CHECK-NEXT: [[DEF:%[0-9]+]]:_(<16 x s32>) = G_IMPLICIT_DEF - ; CHECK-NEXT: [[COPY:%[0-9]+]]:_(p1) = COPY $vgpr0_vgpr1 - ; CHECK-NEXT: [[COPY1:%[0-9]+]]:_(p1) = COPY $vgpr2_vgpr3 -- ; CHECK-NEXT: [[UV:%[0-9]+]]:_(<4 x s32>), [[UV1:%[0-9]+]]:_(<4 x s32>), [[UV2:%[0-9]+]]:_(<4 x s32>), [[UV3:%[0-9]+]]:_(<4 x s32>), [[UV4:%[0-9]+]]:_(<4 x s32>), [[UV5:%[0-9]+]]:_(<4 x s32>), [[UV6:%[0-9]+]]:_(<4 x s32>), [[UV7:%[0-9]+]]:_(<4 x s32>) = G_UNMERGE_VALUES [[DEF]](<32 x s32>) -- ; CHECK-NEXT: [[UV8:%[0-9]+]]:_(<4 x s32>), [[UV9:%[0-9]+]]:_(<4 x s32>), [[UV10:%[0-9]+]]:_(<4 x s32>), [[UV11:%[0-9]+]]:_(<4 x s32>), [[UV12:%[0-9]+]]:_(<4 x s32>), [[UV13:%[0-9]+]]:_(<4 x s32>), [[UV14:%[0-9]+]]:_(<4 x s32>), [[UV15:%[0-9]+]]:_(<4 x s32>) = G_UNMERGE_VALUES [[DEF]](<32 x s32>) -+ ; CHECK-NEXT: [[UV:%[0-9]+]]:_(<4 x s32>), [[UV1:%[0-9]+]]:_(<4 x s32>), [[UV2:%[0-9]+]]:_(<4 x s32>), [[UV3:%[0-9]+]]:_(<4 x s32>) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) -+ ; CHECK-NEXT: [[UV4:%[0-9]+]]:_(<4 x s32>), [[UV5:%[0-9]+]]:_(<4 x s32>), [[UV6:%[0-9]+]]:_(<4 x s32>), [[UV7:%[0-9]+]]:_(<4 x s32>) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) -+ ; CHECK-NEXT: [[UV8:%[0-9]+]]:_(<4 x s32>), [[UV9:%[0-9]+]]:_(<4 x s32>), [[UV10:%[0-9]+]]:_(<4 x s32>), [[UV11:%[0-9]+]]:_(<4 x s32>) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) -+ ; CHECK-NEXT: [[UV12:%[0-9]+]]:_(<4 x s32>), [[UV13:%[0-9]+]]:_(<4 x s32>), [[UV14:%[0-9]+]]:_(<4 x s32>), [[UV15:%[0-9]+]]:_(<4 x s32>) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) - ; CHECK-NEXT: G_STORE [[UV]](<4 x s32>), [[COPY]](p1) :: (store (<4 x s32>), align 4, addrspace 1) - ; CHECK-NEXT: [[C:%[0-9]+]]:_(s64) = G_CONSTANT i64 16 - ; CHECK-NEXT: [[PTR_ADD:%[0-9]+]]:_(p1) = G_PTR_ADD [[COPY]], [[C]](s64) -@@ -241,8 +243,10 @@ - ; CHECK-NEXT: [[C14:%[0-9]+]]:_(s64) = G_CONSTANT i64 240 - ; CHECK-NEXT: [[PTR_ADD14:%[0-9]+]]:_(p1) = G_PTR_ADD [[COPY]], [[C14]](s64) - ; CHECK-NEXT: G_STORE [[UV15]](<4 x s32>), [[PTR_ADD14]](p1) :: (store (<4 x s32>) into unknown-address + 240, align 4, addrspace 1) -- ; CHECK-NEXT: [[UV16:%[0-9]+]]:_(<4 x s32>), [[UV17:%[0-9]+]]:_(<4 x s32>), [[UV18:%[0-9]+]]:_(<4 x s32>), [[UV19:%[0-9]+]]:_(<4 x s32>), [[UV20:%[0-9]+]]:_(<4 x s32>), [[UV21:%[0-9]+]]:_(<4 x s32>), [[UV22:%[0-9]+]]:_(<4 x s32>), [[UV23:%[0-9]+]]:_(<4 x s32>) = G_UNMERGE_VALUES [[DEF]](<32 x s32>) -- ; CHECK-NEXT: [[UV24:%[0-9]+]]:_(<4 x s32>), [[UV25:%[0-9]+]]:_(<4 x s32>), [[UV26:%[0-9]+]]:_(<4 x s32>), [[UV27:%[0-9]+]]:_(<4 x s32>), [[UV28:%[0-9]+]]:_(<4 x s32>), [[UV29:%[0-9]+]]:_(<4 x s32>), [[UV30:%[0-9]+]]:_(<4 x s32>), [[UV31:%[0-9]+]]:_(<4 x s32>) = G_UNMERGE_VALUES [[DEF]](<32 x s32>) -+ ; CHECK-NEXT: [[UV16:%[0-9]+]]:_(<4 x s32>), [[UV17:%[0-9]+]]:_(<4 x s32>), [[UV18:%[0-9]+]]:_(<4 x s32>), [[UV19:%[0-9]+]]:_(<4 x s32>) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) -+ ; CHECK-NEXT: [[UV20:%[0-9]+]]:_(<4 x s32>), [[UV21:%[0-9]+]]:_(<4 x s32>), [[UV22:%[0-9]+]]:_(<4 x s32>), [[UV23:%[0-9]+]]:_(<4 x s32>) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) -+ ; CHECK-NEXT: [[UV24:%[0-9]+]]:_(<4 x s32>), [[UV25:%[0-9]+]]:_(<4 x s32>), [[UV26:%[0-9]+]]:_(<4 x s32>), [[UV27:%[0-9]+]]:_(<4 x s32>) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) -+ ; CHECK-NEXT: [[UV28:%[0-9]+]]:_(<4 x s32>), [[UV29:%[0-9]+]]:_(<4 x s32>), [[UV30:%[0-9]+]]:_(<4 x s32>), [[UV31:%[0-9]+]]:_(<4 x s32>) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) - ; CHECK-NEXT: G_STORE [[UV16]](<4 x s32>), [[COPY1]](p1) :: (store (<4 x s32>), align 4, addrspace 1) - ; CHECK-NEXT: [[PTR_ADD15:%[0-9]+]]:_(p1) = G_PTR_ADD [[COPY1]], [[C]](s64) - ; CHECK-NEXT: G_STORE [[UV17]](<4 x s32>), [[PTR_ADD15]](p1) :: (store (<4 x s32>) into unknown-address + 16, align 4, addrspace 1) -diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-phi.mir b/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-phi.mir ---- a/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-phi.mir -+++ b/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-phi.mir -@@ -673,86 +673,88 @@ - ; CHECK-NEXT: successors: %bb.1(0x40000000), %bb.2(0x40000000) - ; CHECK-NEXT: liveins: $vgpr0_vgpr1_vgpr2_vgpr3, $vgpr4 - ; CHECK-NEXT: {{ $}} -- ; CHECK-NEXT: [[DEF:%[0-9]+]]:_(<32 x s32>) = G_IMPLICIT_DEF -+ ; CHECK-NEXT: [[DEF:%[0-9]+]]:_(<16 x s32>) = G_IMPLICIT_DEF - ; CHECK-NEXT: [[COPY:%[0-9]+]]:_(s32) = COPY $vgpr4 - ; CHECK-NEXT: [[C:%[0-9]+]]:_(s32) = G_CONSTANT i32 0 - ; CHECK-NEXT: [[ICMP:%[0-9]+]]:_(s1) = G_ICMP intpred(eq), [[COPY]](s32), [[C]] -- ; CHECK-NEXT: [[UV:%[0-9]+]]:_(<16 x s32>), [[UV1:%[0-9]+]]:_(<16 x s32>) = G_UNMERGE_VALUES [[DEF]](<32 x s32>) -- ; CHECK-NEXT: [[UV2:%[0-9]+]]:_(<16 x s32>), [[UV3:%[0-9]+]]:_(<16 x s32>) = G_UNMERGE_VALUES [[DEF]](<32 x s32>) - ; CHECK-NEXT: G_BRCOND [[ICMP]](s1), %bb.1 - ; CHECK-NEXT: G_BR %bb.2 - ; CHECK-NEXT: {{ $}} - ; CHECK-NEXT: bb.1: - ; CHECK-NEXT: successors: %bb.2(0x80000000) - ; CHECK-NEXT: {{ $}} -- ; CHECK-NEXT: [[UV4:%[0-9]+]]:_(s32), [[UV5:%[0-9]+]]:_(s32), [[UV6:%[0-9]+]]:_(s32), [[UV7:%[0-9]+]]:_(s32), [[UV8:%[0-9]+]]:_(s32), [[UV9:%[0-9]+]]:_(s32), [[UV10:%[0-9]+]]:_(s32), [[UV11:%[0-9]+]]:_(s32), [[UV12:%[0-9]+]]:_(s32), [[UV13:%[0-9]+]]:_(s32), [[UV14:%[0-9]+]]:_(s32), [[UV15:%[0-9]+]]:_(s32), [[UV16:%[0-9]+]]:_(s32), [[UV17:%[0-9]+]]:_(s32), [[UV18:%[0-9]+]]:_(s32), [[UV19:%[0-9]+]]:_(s32), [[UV20:%[0-9]+]]:_(s32), [[UV21:%[0-9]+]]:_(s32), [[UV22:%[0-9]+]]:_(s32), [[UV23:%[0-9]+]]:_(s32), [[UV24:%[0-9]+]]:_(s32), [[UV25:%[0-9]+]]:_(s32), [[UV26:%[0-9]+]]:_(s32), [[UV27:%[0-9]+]]:_(s32), [[UV28:%[0-9]+]]:_(s32), [[UV29:%[0-9]+]]:_(s32), [[UV30:%[0-9]+]]:_(s32), [[UV31:%[0-9]+]]:_(s32), [[UV32:%[0-9]+]]:_(s32), [[UV33:%[0-9]+]]:_(s32), [[UV34:%[0-9]+]]:_(s32), [[UV35:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<32 x s32>) -- ; CHECK-NEXT: [[UV36:%[0-9]+]]:_(s32), [[UV37:%[0-9]+]]:_(s32), [[UV38:%[0-9]+]]:_(s32), [[UV39:%[0-9]+]]:_(s32), [[UV40:%[0-9]+]]:_(s32), [[UV41:%[0-9]+]]:_(s32), [[UV42:%[0-9]+]]:_(s32), [[UV43:%[0-9]+]]:_(s32), [[UV44:%[0-9]+]]:_(s32), [[UV45:%[0-9]+]]:_(s32), [[UV46:%[0-9]+]]:_(s32), [[UV47:%[0-9]+]]:_(s32), [[UV48:%[0-9]+]]:_(s32), [[UV49:%[0-9]+]]:_(s32), [[UV50:%[0-9]+]]:_(s32), [[UV51:%[0-9]+]]:_(s32), [[UV52:%[0-9]+]]:_(s32), [[UV53:%[0-9]+]]:_(s32), [[UV54:%[0-9]+]]:_(s32), [[UV55:%[0-9]+]]:_(s32), [[UV56:%[0-9]+]]:_(s32), [[UV57:%[0-9]+]]:_(s32), [[UV58:%[0-9]+]]:_(s32), [[UV59:%[0-9]+]]:_(s32), [[UV60:%[0-9]+]]:_(s32), [[UV61:%[0-9]+]]:_(s32), [[UV62:%[0-9]+]]:_(s32), [[UV63:%[0-9]+]]:_(s32), [[UV64:%[0-9]+]]:_(s32), [[UV65:%[0-9]+]]:_(s32), [[UV66:%[0-9]+]]:_(s32), [[UV67:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<32 x s32>) -- ; CHECK-NEXT: [[UV68:%[0-9]+]]:_(s32), [[UV69:%[0-9]+]]:_(s32), [[UV70:%[0-9]+]]:_(s32), [[UV71:%[0-9]+]]:_(s32), [[UV72:%[0-9]+]]:_(s32), [[UV73:%[0-9]+]]:_(s32), [[UV74:%[0-9]+]]:_(s32), [[UV75:%[0-9]+]]:_(s32), [[UV76:%[0-9]+]]:_(s32), [[UV77:%[0-9]+]]:_(s32), [[UV78:%[0-9]+]]:_(s32), [[UV79:%[0-9]+]]:_(s32), [[UV80:%[0-9]+]]:_(s32), [[UV81:%[0-9]+]]:_(s32), [[UV82:%[0-9]+]]:_(s32), [[UV83:%[0-9]+]]:_(s32), [[UV84:%[0-9]+]]:_(s32), [[UV85:%[0-9]+]]:_(s32), [[UV86:%[0-9]+]]:_(s32), [[UV87:%[0-9]+]]:_(s32), [[UV88:%[0-9]+]]:_(s32), [[UV89:%[0-9]+]]:_(s32), [[UV90:%[0-9]+]]:_(s32), [[UV91:%[0-9]+]]:_(s32), [[UV92:%[0-9]+]]:_(s32), [[UV93:%[0-9]+]]:_(s32), [[UV94:%[0-9]+]]:_(s32), [[UV95:%[0-9]+]]:_(s32), [[UV96:%[0-9]+]]:_(s32), [[UV97:%[0-9]+]]:_(s32), [[UV98:%[0-9]+]]:_(s32), [[UV99:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<32 x s32>) -- ; CHECK-NEXT: [[UV100:%[0-9]+]]:_(s32), [[UV101:%[0-9]+]]:_(s32), [[UV102:%[0-9]+]]:_(s32), [[UV103:%[0-9]+]]:_(s32), [[UV104:%[0-9]+]]:_(s32), [[UV105:%[0-9]+]]:_(s32), [[UV106:%[0-9]+]]:_(s32), [[UV107:%[0-9]+]]:_(s32), [[UV108:%[0-9]+]]:_(s32), [[UV109:%[0-9]+]]:_(s32), [[UV110:%[0-9]+]]:_(s32), [[UV111:%[0-9]+]]:_(s32), [[UV112:%[0-9]+]]:_(s32), [[UV113:%[0-9]+]]:_(s32), [[UV114:%[0-9]+]]:_(s32), [[UV115:%[0-9]+]]:_(s32), [[UV116:%[0-9]+]]:_(s32), [[UV117:%[0-9]+]]:_(s32), [[UV118:%[0-9]+]]:_(s32), [[UV119:%[0-9]+]]:_(s32), [[UV120:%[0-9]+]]:_(s32), [[UV121:%[0-9]+]]:_(s32), [[UV122:%[0-9]+]]:_(s32), [[UV123:%[0-9]+]]:_(s32), [[UV124:%[0-9]+]]:_(s32), [[UV125:%[0-9]+]]:_(s32), [[UV126:%[0-9]+]]:_(s32), [[UV127:%[0-9]+]]:_(s32), [[UV128:%[0-9]+]]:_(s32), [[UV129:%[0-9]+]]:_(s32), [[UV130:%[0-9]+]]:_(s32), [[UV131:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<32 x s32>) -- ; CHECK-NEXT: [[ADD:%[0-9]+]]:_(s32) = G_ADD [[UV4]], [[UV68]] -- ; CHECK-NEXT: [[ADD1:%[0-9]+]]:_(s32) = G_ADD [[UV5]], [[UV69]] -- ; CHECK-NEXT: [[ADD2:%[0-9]+]]:_(s32) = G_ADD [[UV6]], [[UV70]] -- ; CHECK-NEXT: [[ADD3:%[0-9]+]]:_(s32) = G_ADD [[UV7]], [[UV71]] -- ; CHECK-NEXT: [[ADD4:%[0-9]+]]:_(s32) = G_ADD [[UV8]], [[UV72]] -- ; CHECK-NEXT: [[ADD5:%[0-9]+]]:_(s32) = G_ADD [[UV9]], [[UV73]] -- ; CHECK-NEXT: [[ADD6:%[0-9]+]]:_(s32) = G_ADD [[UV10]], [[UV74]] -- ; CHECK-NEXT: [[ADD7:%[0-9]+]]:_(s32) = G_ADD [[UV11]], [[UV75]] -- ; CHECK-NEXT: [[ADD8:%[0-9]+]]:_(s32) = G_ADD [[UV12]], [[UV76]] -- ; CHECK-NEXT: [[ADD9:%[0-9]+]]:_(s32) = G_ADD [[UV13]], [[UV77]] -- ; CHECK-NEXT: [[ADD10:%[0-9]+]]:_(s32) = G_ADD [[UV14]], [[UV78]] -- ; CHECK-NEXT: [[ADD11:%[0-9]+]]:_(s32) = G_ADD [[UV15]], [[UV79]] -- ; CHECK-NEXT: [[ADD12:%[0-9]+]]:_(s32) = G_ADD [[UV16]], [[UV80]] -- ; CHECK-NEXT: [[ADD13:%[0-9]+]]:_(s32) = G_ADD [[UV17]], [[UV81]] -- ; CHECK-NEXT: [[ADD14:%[0-9]+]]:_(s32) = G_ADD [[UV18]], [[UV82]] -- ; CHECK-NEXT: [[ADD15:%[0-9]+]]:_(s32) = G_ADD [[UV19]], [[UV83]] -- ; CHECK-NEXT: [[ADD16:%[0-9]+]]:_(s32) = G_ADD [[UV20]], [[UV84]] -- ; CHECK-NEXT: [[ADD17:%[0-9]+]]:_(s32) = G_ADD [[UV21]], [[UV85]] -- ; CHECK-NEXT: [[ADD18:%[0-9]+]]:_(s32) = G_ADD [[UV22]], [[UV86]] -- ; CHECK-NEXT: [[ADD19:%[0-9]+]]:_(s32) = G_ADD [[UV23]], [[UV87]] -- ; CHECK-NEXT: [[ADD20:%[0-9]+]]:_(s32) = G_ADD [[UV24]], [[UV88]] -- ; CHECK-NEXT: [[ADD21:%[0-9]+]]:_(s32) = G_ADD [[UV25]], [[UV89]] -- ; CHECK-NEXT: [[ADD22:%[0-9]+]]:_(s32) = G_ADD [[UV26]], [[UV90]] -- ; CHECK-NEXT: [[ADD23:%[0-9]+]]:_(s32) = G_ADD [[UV27]], [[UV91]] -- ; CHECK-NEXT: [[ADD24:%[0-9]+]]:_(s32) = G_ADD [[UV28]], [[UV92]] -- ; CHECK-NEXT: [[ADD25:%[0-9]+]]:_(s32) = G_ADD [[UV29]], [[UV93]] -- ; CHECK-NEXT: [[ADD26:%[0-9]+]]:_(s32) = G_ADD [[UV30]], [[UV94]] -- ; CHECK-NEXT: [[ADD27:%[0-9]+]]:_(s32) = G_ADD [[UV31]], [[UV95]] -- ; CHECK-NEXT: [[ADD28:%[0-9]+]]:_(s32) = G_ADD [[UV32]], [[UV96]] -- ; CHECK-NEXT: [[ADD29:%[0-9]+]]:_(s32) = G_ADD [[UV33]], [[UV97]] -- ; CHECK-NEXT: [[ADD30:%[0-9]+]]:_(s32) = G_ADD [[UV34]], [[UV98]] -- ; CHECK-NEXT: [[ADD31:%[0-9]+]]:_(s32) = G_ADD [[UV35]], [[UV99]] -- ; CHECK-NEXT: [[ADD32:%[0-9]+]]:_(s32) = G_ADD [[UV36]], [[UV100]] -- ; CHECK-NEXT: [[ADD33:%[0-9]+]]:_(s32) = G_ADD [[UV37]], [[UV101]] -- ; CHECK-NEXT: [[ADD34:%[0-9]+]]:_(s32) = G_ADD [[UV38]], [[UV102]] -- ; CHECK-NEXT: [[ADD35:%[0-9]+]]:_(s32) = G_ADD [[UV39]], [[UV103]] -- ; CHECK-NEXT: [[ADD36:%[0-9]+]]:_(s32) = G_ADD [[UV40]], [[UV104]] -- ; CHECK-NEXT: [[ADD37:%[0-9]+]]:_(s32) = G_ADD [[UV41]], [[UV105]] -- ; CHECK-NEXT: [[ADD38:%[0-9]+]]:_(s32) = G_ADD [[UV42]], [[UV106]] -- ; CHECK-NEXT: [[ADD39:%[0-9]+]]:_(s32) = G_ADD [[UV43]], [[UV107]] -- ; CHECK-NEXT: [[ADD40:%[0-9]+]]:_(s32) = G_ADD [[UV44]], [[UV108]] -- ; CHECK-NEXT: [[ADD41:%[0-9]+]]:_(s32) = G_ADD [[UV45]], [[UV109]] -- ; CHECK-NEXT: [[ADD42:%[0-9]+]]:_(s32) = G_ADD [[UV46]], [[UV110]] -- ; CHECK-NEXT: [[ADD43:%[0-9]+]]:_(s32) = G_ADD [[UV47]], [[UV111]] -- ; CHECK-NEXT: [[ADD44:%[0-9]+]]:_(s32) = G_ADD [[UV48]], [[UV112]] -- ; CHECK-NEXT: [[ADD45:%[0-9]+]]:_(s32) = G_ADD [[UV49]], [[UV113]] -- ; CHECK-NEXT: [[ADD46:%[0-9]+]]:_(s32) = G_ADD [[UV50]], [[UV114]] -- ; CHECK-NEXT: [[ADD47:%[0-9]+]]:_(s32) = G_ADD [[UV51]], [[UV115]] -- ; CHECK-NEXT: [[ADD48:%[0-9]+]]:_(s32) = G_ADD [[UV52]], [[UV116]] -- ; CHECK-NEXT: [[ADD49:%[0-9]+]]:_(s32) = G_ADD [[UV53]], [[UV117]] -- ; CHECK-NEXT: [[ADD50:%[0-9]+]]:_(s32) = G_ADD [[UV54]], [[UV118]] -- ; CHECK-NEXT: [[ADD51:%[0-9]+]]:_(s32) = G_ADD [[UV55]], [[UV119]] -- ; CHECK-NEXT: [[ADD52:%[0-9]+]]:_(s32) = G_ADD [[UV56]], [[UV120]] -- ; CHECK-NEXT: [[ADD53:%[0-9]+]]:_(s32) = G_ADD [[UV57]], [[UV121]] -- ; CHECK-NEXT: [[ADD54:%[0-9]+]]:_(s32) = G_ADD [[UV58]], [[UV122]] -- ; CHECK-NEXT: [[ADD55:%[0-9]+]]:_(s32) = G_ADD [[UV59]], [[UV123]] -- ; CHECK-NEXT: [[ADD56:%[0-9]+]]:_(s32) = G_ADD [[UV60]], [[UV124]] -- ; CHECK-NEXT: [[ADD57:%[0-9]+]]:_(s32) = G_ADD [[UV61]], [[UV125]] -- ; CHECK-NEXT: [[ADD58:%[0-9]+]]:_(s32) = G_ADD [[UV62]], [[UV126]] -- ; CHECK-NEXT: [[ADD59:%[0-9]+]]:_(s32) = G_ADD [[UV63]], [[UV127]] -- ; CHECK-NEXT: [[ADD60:%[0-9]+]]:_(s32) = G_ADD [[UV64]], [[UV128]] -- ; CHECK-NEXT: [[ADD61:%[0-9]+]]:_(s32) = G_ADD [[UV65]], [[UV129]] -- ; CHECK-NEXT: [[ADD62:%[0-9]+]]:_(s32) = G_ADD [[UV66]], [[UV130]] -- ; CHECK-NEXT: [[ADD63:%[0-9]+]]:_(s32) = G_ADD [[UV67]], [[UV131]] -+ ; CHECK-NEXT: [[UV:%[0-9]+]]:_(s32), [[UV1:%[0-9]+]]:_(s32), [[UV2:%[0-9]+]]:_(s32), [[UV3:%[0-9]+]]:_(s32), [[UV4:%[0-9]+]]:_(s32), [[UV5:%[0-9]+]]:_(s32), [[UV6:%[0-9]+]]:_(s32), [[UV7:%[0-9]+]]:_(s32), [[UV8:%[0-9]+]]:_(s32), [[UV9:%[0-9]+]]:_(s32), [[UV10:%[0-9]+]]:_(s32), [[UV11:%[0-9]+]]:_(s32), [[UV12:%[0-9]+]]:_(s32), [[UV13:%[0-9]+]]:_(s32), [[UV14:%[0-9]+]]:_(s32), [[UV15:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) -+ ; CHECK-NEXT: [[UV16:%[0-9]+]]:_(s32), [[UV17:%[0-9]+]]:_(s32), [[UV18:%[0-9]+]]:_(s32), [[UV19:%[0-9]+]]:_(s32), [[UV20:%[0-9]+]]:_(s32), [[UV21:%[0-9]+]]:_(s32), [[UV22:%[0-9]+]]:_(s32), [[UV23:%[0-9]+]]:_(s32), [[UV24:%[0-9]+]]:_(s32), [[UV25:%[0-9]+]]:_(s32), [[UV26:%[0-9]+]]:_(s32), [[UV27:%[0-9]+]]:_(s32), [[UV28:%[0-9]+]]:_(s32), [[UV29:%[0-9]+]]:_(s32), [[UV30:%[0-9]+]]:_(s32), [[UV31:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) -+ ; CHECK-NEXT: [[UV32:%[0-9]+]]:_(s32), [[UV33:%[0-9]+]]:_(s32), [[UV34:%[0-9]+]]:_(s32), [[UV35:%[0-9]+]]:_(s32), [[UV36:%[0-9]+]]:_(s32), [[UV37:%[0-9]+]]:_(s32), [[UV38:%[0-9]+]]:_(s32), [[UV39:%[0-9]+]]:_(s32), [[UV40:%[0-9]+]]:_(s32), [[UV41:%[0-9]+]]:_(s32), [[UV42:%[0-9]+]]:_(s32), [[UV43:%[0-9]+]]:_(s32), [[UV44:%[0-9]+]]:_(s32), [[UV45:%[0-9]+]]:_(s32), [[UV46:%[0-9]+]]:_(s32), [[UV47:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) -+ ; CHECK-NEXT: [[UV48:%[0-9]+]]:_(s32), [[UV49:%[0-9]+]]:_(s32), [[UV50:%[0-9]+]]:_(s32), [[UV51:%[0-9]+]]:_(s32), [[UV52:%[0-9]+]]:_(s32), [[UV53:%[0-9]+]]:_(s32), [[UV54:%[0-9]+]]:_(s32), [[UV55:%[0-9]+]]:_(s32), [[UV56:%[0-9]+]]:_(s32), [[UV57:%[0-9]+]]:_(s32), [[UV58:%[0-9]+]]:_(s32), [[UV59:%[0-9]+]]:_(s32), [[UV60:%[0-9]+]]:_(s32), [[UV61:%[0-9]+]]:_(s32), [[UV62:%[0-9]+]]:_(s32), [[UV63:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) -+ ; CHECK-NEXT: [[UV64:%[0-9]+]]:_(s32), [[UV65:%[0-9]+]]:_(s32), [[UV66:%[0-9]+]]:_(s32), [[UV67:%[0-9]+]]:_(s32), [[UV68:%[0-9]+]]:_(s32), [[UV69:%[0-9]+]]:_(s32), [[UV70:%[0-9]+]]:_(s32), [[UV71:%[0-9]+]]:_(s32), [[UV72:%[0-9]+]]:_(s32), [[UV73:%[0-9]+]]:_(s32), [[UV74:%[0-9]+]]:_(s32), [[UV75:%[0-9]+]]:_(s32), [[UV76:%[0-9]+]]:_(s32), [[UV77:%[0-9]+]]:_(s32), [[UV78:%[0-9]+]]:_(s32), [[UV79:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) -+ ; CHECK-NEXT: [[UV80:%[0-9]+]]:_(s32), [[UV81:%[0-9]+]]:_(s32), [[UV82:%[0-9]+]]:_(s32), [[UV83:%[0-9]+]]:_(s32), [[UV84:%[0-9]+]]:_(s32), [[UV85:%[0-9]+]]:_(s32), [[UV86:%[0-9]+]]:_(s32), [[UV87:%[0-9]+]]:_(s32), [[UV88:%[0-9]+]]:_(s32), [[UV89:%[0-9]+]]:_(s32), [[UV90:%[0-9]+]]:_(s32), [[UV91:%[0-9]+]]:_(s32), [[UV92:%[0-9]+]]:_(s32), [[UV93:%[0-9]+]]:_(s32), [[UV94:%[0-9]+]]:_(s32), [[UV95:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) -+ ; CHECK-NEXT: [[UV96:%[0-9]+]]:_(s32), [[UV97:%[0-9]+]]:_(s32), [[UV98:%[0-9]+]]:_(s32), [[UV99:%[0-9]+]]:_(s32), [[UV100:%[0-9]+]]:_(s32), [[UV101:%[0-9]+]]:_(s32), [[UV102:%[0-9]+]]:_(s32), [[UV103:%[0-9]+]]:_(s32), [[UV104:%[0-9]+]]:_(s32), [[UV105:%[0-9]+]]:_(s32), [[UV106:%[0-9]+]]:_(s32), [[UV107:%[0-9]+]]:_(s32), [[UV108:%[0-9]+]]:_(s32), [[UV109:%[0-9]+]]:_(s32), [[UV110:%[0-9]+]]:_(s32), [[UV111:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) -+ ; CHECK-NEXT: [[UV112:%[0-9]+]]:_(s32), [[UV113:%[0-9]+]]:_(s32), [[UV114:%[0-9]+]]:_(s32), [[UV115:%[0-9]+]]:_(s32), [[UV116:%[0-9]+]]:_(s32), [[UV117:%[0-9]+]]:_(s32), [[UV118:%[0-9]+]]:_(s32), [[UV119:%[0-9]+]]:_(s32), [[UV120:%[0-9]+]]:_(s32), [[UV121:%[0-9]+]]:_(s32), [[UV122:%[0-9]+]]:_(s32), [[UV123:%[0-9]+]]:_(s32), [[UV124:%[0-9]+]]:_(s32), [[UV125:%[0-9]+]]:_(s32), [[UV126:%[0-9]+]]:_(s32), [[UV127:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) -+ ; CHECK-NEXT: [[ADD:%[0-9]+]]:_(s32) = G_ADD [[UV]], [[UV64]] -+ ; CHECK-NEXT: [[ADD1:%[0-9]+]]:_(s32) = G_ADD [[UV1]], [[UV65]] -+ ; CHECK-NEXT: [[ADD2:%[0-9]+]]:_(s32) = G_ADD [[UV2]], [[UV66]] -+ ; CHECK-NEXT: [[ADD3:%[0-9]+]]:_(s32) = G_ADD [[UV3]], [[UV67]] -+ ; CHECK-NEXT: [[ADD4:%[0-9]+]]:_(s32) = G_ADD [[UV4]], [[UV68]] -+ ; CHECK-NEXT: [[ADD5:%[0-9]+]]:_(s32) = G_ADD [[UV5]], [[UV69]] -+ ; CHECK-NEXT: [[ADD6:%[0-9]+]]:_(s32) = G_ADD [[UV6]], [[UV70]] -+ ; CHECK-NEXT: [[ADD7:%[0-9]+]]:_(s32) = G_ADD [[UV7]], [[UV71]] -+ ; CHECK-NEXT: [[ADD8:%[0-9]+]]:_(s32) = G_ADD [[UV8]], [[UV72]] -+ ; CHECK-NEXT: [[ADD9:%[0-9]+]]:_(s32) = G_ADD [[UV9]], [[UV73]] -+ ; CHECK-NEXT: [[ADD10:%[0-9]+]]:_(s32) = G_ADD [[UV10]], [[UV74]] -+ ; CHECK-NEXT: [[ADD11:%[0-9]+]]:_(s32) = G_ADD [[UV11]], [[UV75]] -+ ; CHECK-NEXT: [[ADD12:%[0-9]+]]:_(s32) = G_ADD [[UV12]], [[UV76]] -+ ; CHECK-NEXT: [[ADD13:%[0-9]+]]:_(s32) = G_ADD [[UV13]], [[UV77]] -+ ; CHECK-NEXT: [[ADD14:%[0-9]+]]:_(s32) = G_ADD [[UV14]], [[UV78]] -+ ; CHECK-NEXT: [[ADD15:%[0-9]+]]:_(s32) = G_ADD [[UV15]], [[UV79]] -+ ; CHECK-NEXT: [[ADD16:%[0-9]+]]:_(s32) = G_ADD [[UV16]], [[UV80]] -+ ; CHECK-NEXT: [[ADD17:%[0-9]+]]:_(s32) = G_ADD [[UV17]], [[UV81]] -+ ; CHECK-NEXT: [[ADD18:%[0-9]+]]:_(s32) = G_ADD [[UV18]], [[UV82]] -+ ; CHECK-NEXT: [[ADD19:%[0-9]+]]:_(s32) = G_ADD [[UV19]], [[UV83]] -+ ; CHECK-NEXT: [[ADD20:%[0-9]+]]:_(s32) = G_ADD [[UV20]], [[UV84]] -+ ; CHECK-NEXT: [[ADD21:%[0-9]+]]:_(s32) = G_ADD [[UV21]], [[UV85]] -+ ; CHECK-NEXT: [[ADD22:%[0-9]+]]:_(s32) = G_ADD [[UV22]], [[UV86]] -+ ; CHECK-NEXT: [[ADD23:%[0-9]+]]:_(s32) = G_ADD [[UV23]], [[UV87]] -+ ; CHECK-NEXT: [[ADD24:%[0-9]+]]:_(s32) = G_ADD [[UV24]], [[UV88]] -+ ; CHECK-NEXT: [[ADD25:%[0-9]+]]:_(s32) = G_ADD [[UV25]], [[UV89]] -+ ; CHECK-NEXT: [[ADD26:%[0-9]+]]:_(s32) = G_ADD [[UV26]], [[UV90]] -+ ; CHECK-NEXT: [[ADD27:%[0-9]+]]:_(s32) = G_ADD [[UV27]], [[UV91]] -+ ; CHECK-NEXT: [[ADD28:%[0-9]+]]:_(s32) = G_ADD [[UV28]], [[UV92]] -+ ; CHECK-NEXT: [[ADD29:%[0-9]+]]:_(s32) = G_ADD [[UV29]], [[UV93]] -+ ; CHECK-NEXT: [[ADD30:%[0-9]+]]:_(s32) = G_ADD [[UV30]], [[UV94]] -+ ; CHECK-NEXT: [[ADD31:%[0-9]+]]:_(s32) = G_ADD [[UV31]], [[UV95]] -+ ; CHECK-NEXT: [[ADD32:%[0-9]+]]:_(s32) = G_ADD [[UV32]], [[UV96]] -+ ; CHECK-NEXT: [[ADD33:%[0-9]+]]:_(s32) = G_ADD [[UV33]], [[UV97]] -+ ; CHECK-NEXT: [[ADD34:%[0-9]+]]:_(s32) = G_ADD [[UV34]], [[UV98]] -+ ; CHECK-NEXT: [[ADD35:%[0-9]+]]:_(s32) = G_ADD [[UV35]], [[UV99]] -+ ; CHECK-NEXT: [[ADD36:%[0-9]+]]:_(s32) = G_ADD [[UV36]], [[UV100]] -+ ; CHECK-NEXT: [[ADD37:%[0-9]+]]:_(s32) = G_ADD [[UV37]], [[UV101]] -+ ; CHECK-NEXT: [[ADD38:%[0-9]+]]:_(s32) = G_ADD [[UV38]], [[UV102]] -+ ; CHECK-NEXT: [[ADD39:%[0-9]+]]:_(s32) = G_ADD [[UV39]], [[UV103]] -+ ; CHECK-NEXT: [[ADD40:%[0-9]+]]:_(s32) = G_ADD [[UV40]], [[UV104]] -+ ; CHECK-NEXT: [[ADD41:%[0-9]+]]:_(s32) = G_ADD [[UV41]], [[UV105]] -+ ; CHECK-NEXT: [[ADD42:%[0-9]+]]:_(s32) = G_ADD [[UV42]], [[UV106]] -+ ; CHECK-NEXT: [[ADD43:%[0-9]+]]:_(s32) = G_ADD [[UV43]], [[UV107]] -+ ; CHECK-NEXT: [[ADD44:%[0-9]+]]:_(s32) = G_ADD [[UV44]], [[UV108]] -+ ; CHECK-NEXT: [[ADD45:%[0-9]+]]:_(s32) = G_ADD [[UV45]], [[UV109]] -+ ; CHECK-NEXT: [[ADD46:%[0-9]+]]:_(s32) = G_ADD [[UV46]], [[UV110]] -+ ; CHECK-NEXT: [[ADD47:%[0-9]+]]:_(s32) = G_ADD [[UV47]], [[UV111]] -+ ; CHECK-NEXT: [[ADD48:%[0-9]+]]:_(s32) = G_ADD [[UV48]], [[UV112]] -+ ; CHECK-NEXT: [[ADD49:%[0-9]+]]:_(s32) = G_ADD [[UV49]], [[UV113]] -+ ; CHECK-NEXT: [[ADD50:%[0-9]+]]:_(s32) = G_ADD [[UV50]], [[UV114]] -+ ; CHECK-NEXT: [[ADD51:%[0-9]+]]:_(s32) = G_ADD [[UV51]], [[UV115]] -+ ; CHECK-NEXT: [[ADD52:%[0-9]+]]:_(s32) = G_ADD [[UV52]], [[UV116]] -+ ; CHECK-NEXT: [[ADD53:%[0-9]+]]:_(s32) = G_ADD [[UV53]], [[UV117]] -+ ; CHECK-NEXT: [[ADD54:%[0-9]+]]:_(s32) = G_ADD [[UV54]], [[UV118]] -+ ; CHECK-NEXT: [[ADD55:%[0-9]+]]:_(s32) = G_ADD [[UV55]], [[UV119]] -+ ; CHECK-NEXT: [[ADD56:%[0-9]+]]:_(s32) = G_ADD [[UV56]], [[UV120]] -+ ; CHECK-NEXT: [[ADD57:%[0-9]+]]:_(s32) = G_ADD [[UV57]], [[UV121]] -+ ; CHECK-NEXT: [[ADD58:%[0-9]+]]:_(s32) = G_ADD [[UV58]], [[UV122]] -+ ; CHECK-NEXT: [[ADD59:%[0-9]+]]:_(s32) = G_ADD [[UV59]], [[UV123]] -+ ; CHECK-NEXT: [[ADD60:%[0-9]+]]:_(s32) = G_ADD [[UV60]], [[UV124]] -+ ; CHECK-NEXT: [[ADD61:%[0-9]+]]:_(s32) = G_ADD [[UV61]], [[UV125]] -+ ; CHECK-NEXT: [[ADD62:%[0-9]+]]:_(s32) = G_ADD [[UV62]], [[UV126]] -+ ; CHECK-NEXT: [[ADD63:%[0-9]+]]:_(s32) = G_ADD [[UV63]], [[UV127]] - ; CHECK-NEXT: [[BUILD_VECTOR:%[0-9]+]]:_(<16 x s32>) = G_BUILD_VECTOR [[ADD]](s32), [[ADD1]](s32), [[ADD2]](s32), [[ADD3]](s32), [[ADD4]](s32), [[ADD5]](s32), [[ADD6]](s32), [[ADD7]](s32), [[ADD8]](s32), [[ADD9]](s32), [[ADD10]](s32), [[ADD11]](s32), [[ADD12]](s32), [[ADD13]](s32), [[ADD14]](s32), [[ADD15]](s32) - ; CHECK-NEXT: [[BUILD_VECTOR1:%[0-9]+]]:_(<16 x s32>) = G_BUILD_VECTOR [[ADD16]](s32), [[ADD17]](s32), [[ADD18]](s32), [[ADD19]](s32), [[ADD20]](s32), [[ADD21]](s32), [[ADD22]](s32), [[ADD23]](s32), [[ADD24]](s32), [[ADD25]](s32), [[ADD26]](s32), [[ADD27]](s32), [[ADD28]](s32), [[ADD29]](s32), [[ADD30]](s32), [[ADD31]](s32) - ; CHECK-NEXT: [[BUILD_VECTOR2:%[0-9]+]]:_(<16 x s32>) = G_BUILD_VECTOR [[ADD32]](s32), [[ADD33]](s32), [[ADD34]](s32), [[ADD35]](s32), [[ADD36]](s32), [[ADD37]](s32), [[ADD38]](s32), [[ADD39]](s32), [[ADD40]](s32), [[ADD41]](s32), [[ADD42]](s32), [[ADD43]](s32), [[ADD44]](s32), [[ADD45]](s32), [[ADD46]](s32), [[ADD47]](s32) -@@ -760,10 +762,10 @@ - ; CHECK-NEXT: G_BR %bb.2 - ; CHECK-NEXT: {{ $}} - ; CHECK-NEXT: bb.2: -- ; CHECK-NEXT: [[PHI:%[0-9]+]]:_(<16 x s32>) = G_PHI [[UV]](<16 x s32>), %bb.0, [[BUILD_VECTOR]](<16 x s32>), %bb.1 -- ; CHECK-NEXT: [[PHI1:%[0-9]+]]:_(<16 x s32>) = G_PHI [[UV1]](<16 x s32>), %bb.0, [[BUILD_VECTOR1]](<16 x s32>), %bb.1 -- ; CHECK-NEXT: [[PHI2:%[0-9]+]]:_(<16 x s32>) = G_PHI [[UV2]](<16 x s32>), %bb.0, [[BUILD_VECTOR2]](<16 x s32>), %bb.1 -- ; CHECK-NEXT: [[PHI3:%[0-9]+]]:_(<16 x s32>) = G_PHI [[UV3]](<16 x s32>), %bb.0, [[BUILD_VECTOR3]](<16 x s32>), %bb.1 -+ ; CHECK-NEXT: [[PHI:%[0-9]+]]:_(<16 x s32>) = G_PHI [[DEF]](<16 x s32>), %bb.0, [[BUILD_VECTOR]](<16 x s32>), %bb.1 -+ ; CHECK-NEXT: [[PHI1:%[0-9]+]]:_(<16 x s32>) = G_PHI [[DEF]](<16 x s32>), %bb.0, [[BUILD_VECTOR1]](<16 x s32>), %bb.1 -+ ; CHECK-NEXT: [[PHI2:%[0-9]+]]:_(<16 x s32>) = G_PHI [[DEF]](<16 x s32>), %bb.0, [[BUILD_VECTOR2]](<16 x s32>), %bb.1 -+ ; CHECK-NEXT: [[PHI3:%[0-9]+]]:_(<16 x s32>) = G_PHI [[DEF]](<16 x s32>), %bb.0, [[BUILD_VECTOR3]](<16 x s32>), %bb.1 - ; CHECK-NEXT: [[CONCAT_VECTORS:%[0-9]+]]:_(<64 x s32>) = G_CONCAT_VECTORS [[PHI]](<16 x s32>), [[PHI1]](<16 x s32>), [[PHI2]](<16 x s32>), [[PHI3]](<16 x s32>) - ; CHECK-NEXT: S_SETPC_B64 undef $sgpr30_sgpr31, implicit [[CONCAT_VECTORS]](<64 x s32>) - bb.0: -diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/AMDGPU/GlobalISel/regbankselect.mir b/llvm/test/CodeGen/AMDGPU/GlobalISel/regbankselect.mir ---- a/llvm/test/CodeGen/AMDGPU/GlobalISel/regbankselect.mir -+++ b/llvm/test/CodeGen/AMDGPU/GlobalISel/regbankselect.mir -@@ -42,6 +42,8 @@ - ret void - } - -+ define void @non_power_of_2() { ret void } -+ - define amdgpu_kernel void @load_constant_v4i16_from_8_align8(ptr addrspace(4) %ptr0) { - ret void - } -@@ -185,6 +187,23 @@ - ... - - --- -+name: non_power_of_2 -+legalized: true -+ -+body: | -+ bb.0: -+ ; CHECK-LABEL: name: non_power_of_2 -+ ; CHECK: [[DEF:%[0-9]+]]:sgpr(s448) = G_IMPLICIT_DEF -+ ; CHECK-NEXT: [[EXTRACT:%[0-9]+]]:sgpr(s32) = G_EXTRACT [[DEF]](s448), 0 -+ ; CHECK-NEXT: $sgpr0 = COPY [[EXTRACT]](s32) -+ ; CHECK-NEXT: SI_RETURN_TO_EPILOG $sgpr0 -+ %0:_(s448) = G_IMPLICIT_DEF -+ %1:_(s32) = G_EXTRACT %0:_(s448), 0 -+ $sgpr0 = COPY %1:_(s32) -+ SI_RETURN_TO_EPILOG $sgpr0 -+... -+ -+--- - name: load_constant_v4i16_from_8_align8 - legalized: true - -diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/NVPTX/intrin-nocapture.ll b/llvm/test/CodeGen/NVPTX/intrin-nocapture.ll ---- a/llvm/test/CodeGen/NVPTX/intrin-nocapture.ll -+++ b/llvm/test/CodeGen/NVPTX/intrin-nocapture.ll -@@ -0,0 +1,21 @@ -+; RUN: opt < %s -O3 -S | FileCheck %s -+ -+; Address space intrinsics were erroneously marked NoCapture, leading to bad -+; optimizations (such as the store below being eliminated as dead code). This -+; test makes sure we don't regress. -+ -+declare void @foo(ptr addrspace(1)) -+ -+declare ptr addrspace(1) @llvm.nvvm.ptr.gen.to.global.p1.p0(ptr) -+ -+; CHECK: @bar -+define void @bar() { -+ %t1 = alloca i32 -+; CHECK: call ptr addrspace(1) @llvm.nvvm.ptr.gen.to.global.p1.p0(ptr nonnull %t1) -+; CHECK-NEXT: store i32 10, ptr %t1 -+ %t2 = call ptr addrspace(1) @llvm.nvvm.ptr.gen.to.global.p1.p0(ptr %t1) -+ store i32 10, ptr %t1 -+ call void @foo(ptr addrspace(1) %t2) -+ ret void -+} -+ -diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/NVPTX/rotate_64.ll b/llvm/test/CodeGen/NVPTX/rotate_64.ll ---- a/llvm/test/CodeGen/NVPTX/rotate_64.ll -+++ b/llvm/test/CodeGen/NVPTX/rotate_64.ll -@@ -1,38 +1,25 @@ --; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5 - ; RUN: llc < %s -march=nvptx64 | FileCheck %s - ; RUN: %if ptxas %{ llc < %s -march=nvptx64 | %ptxas-verify %} - - declare i64 @llvm.nvvm.rotate.b64(i64, i32) - declare i64 @llvm.nvvm.rotate.right.b64(i64, i32) - -+; CHECK: rotate64 - define i64 @rotate64(i64 %a, i32 %b) { --; CHECK-LABEL: rotate64( --; CHECK: { --; CHECK-NEXT: .reg .b64 %rd<5>; --; CHECK-EMPTY: --; CHECK-NEXT: // %bb.0: --; CHECK-NEXT: ld.param.u64 %rd1, [rotate64_param_0]; --; CHECK-NEXT: shr.u64 %rd2, %rd1, 61; --; CHECK-NEXT: shl.b64 %rd3, %rd1, 3; --; CHECK-NEXT: or.b64 %rd4, %rd3, %rd2; --; CHECK-NEXT: st.param.b64 [func_retval0+0], %rd4; --; CHECK-NEXT: ret; -+; CHECK: shl.b64 [[LHS:%.*]], [[RD1:%.*]], 3; -+; CHECK: shr.b64 [[RHS:%.*]], [[RD1]], 61; -+; CHECK: add.u64 [[RD2:%.*]], [[LHS]], [[RHS]]; -+; CHECK: ret - %val = tail call i64 @llvm.nvvm.rotate.b64(i64 %a, i32 3) - ret i64 %val - } - -+; CHECK: rotateright64 - define i64 @rotateright64(i64 %a, i32 %b) { --; CHECK-LABEL: rotateright64( --; CHECK: { --; CHECK-NEXT: .reg .b64 %rd<5>; --; CHECK-EMPTY: --; CHECK-NEXT: // %bb.0: --; CHECK-NEXT: ld.param.u64 %rd1, [rotateright64_param_0]; --; CHECK-NEXT: shl.b64 %rd2, %rd1, 61; --; CHECK-NEXT: shr.u64 %rd3, %rd1, 3; --; CHECK-NEXT: or.b64 %rd4, %rd3, %rd2; --; CHECK-NEXT: st.param.b64 [func_retval0+0], %rd4; --; CHECK-NEXT: ret; -+; CHECK: shl.b64 [[LHS:%.*]], [[RD1:%.*]], 61; -+; CHECK: shr.b64 [[RHS:%.*]], [[RD1]], 3; -+; CHECK: add.u64 [[RD2:%.*]], [[LHS]], [[RHS]]; -+; CHECK: ret - %val = tail call i64 @llvm.nvvm.rotate.right.b64(i64 %a, i32 3) - ret i64 %val - } -diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/NVPTX/rotate.ll b/llvm/test/CodeGen/NVPTX/rotate.ll ---- a/llvm/test/CodeGen/NVPTX/rotate.ll -+++ b/llvm/test/CodeGen/NVPTX/rotate.ll -@@ -9,29 +9,26 @@ - declare i64 @llvm.nvvm.rotate.b64(i64, i32) - declare i64 @llvm.nvvm.rotate.right.b64(i64, i32) - --declare i64 @llvm.fshl.i64(i64, i64, i64) --declare i64 @llvm.fshr.i64(i64, i64, i64) --declare i32 @llvm.fshl.i32(i32, i32, i32) --declare i32 @llvm.fshr.i32(i32, i32, i32) -- -- - ; SM20: rotate32 - ; SM35: rotate32 - define i32 @rotate32(i32 %a, i32 %b) { - ; SM20-LABEL: rotate32( - ; SM20: { --; SM20-NEXT: .reg .b32 %r<9>; -+; SM20-NEXT: .reg .b32 %r<4>; - ; SM20-EMPTY: - ; SM20-NEXT: // %bb.0: - ; SM20-NEXT: ld.param.u32 %r1, [rotate32_param_0]; - ; SM20-NEXT: ld.param.u32 %r2, [rotate32_param_1]; --; SM20-NEXT: and.b32 %r3, %r2, 31; --; SM20-NEXT: shl.b32 %r4, %r1, %r3; --; SM20-NEXT: neg.s32 %r5, %r2; --; SM20-NEXT: and.b32 %r6, %r5, 31; --; SM20-NEXT: shr.u32 %r7, %r1, %r6; --; SM20-NEXT: or.b32 %r8, %r4, %r7; --; SM20-NEXT: st.param.b32 [func_retval0+0], %r8; -+; SM20-NEXT: { -+; SM20-NEXT: .reg .b32 %lhs; -+; SM20-NEXT: .reg .b32 %rhs; -+; SM20-NEXT: .reg .b32 %amt2; -+; SM20-NEXT: shl.b32 %lhs, %r1, %r2; -+; SM20-NEXT: sub.s32 %amt2, 32, %r2; -+; SM20-NEXT: shr.b32 %rhs, %r1, %amt2; -+; SM20-NEXT: add.u32 %r3, %lhs, %rhs; -+; SM20-NEXT: } -+; SM20-NEXT: st.param.b32 [func_retval0+0], %r3; - ; SM20-NEXT: ret; - ; - ; SM35-LABEL: rotate32( -@@ -53,36 +50,45 @@ - define i64 @rotate64(i64 %a, i32 %b) { - ; SM20-LABEL: rotate64( - ; SM20: { --; SM20-NEXT: .reg .b32 %r<5>; --; SM20-NEXT: .reg .b64 %rd<5>; -+; SM20-NEXT: .reg .b32 %r<2>; -+; SM20-NEXT: .reg .b64 %rd<3>; - ; SM20-EMPTY: - ; SM20-NEXT: // %bb.0: - ; SM20-NEXT: ld.param.u64 %rd1, [rotate64_param_0]; - ; SM20-NEXT: ld.param.u32 %r1, [rotate64_param_1]; --; SM20-NEXT: and.b32 %r2, %r1, 63; --; SM20-NEXT: shl.b64 %rd2, %rd1, %r2; --; SM20-NEXT: neg.s32 %r3, %r1; --; SM20-NEXT: and.b32 %r4, %r3, 63; --; SM20-NEXT: shr.u64 %rd3, %rd1, %r4; --; SM20-NEXT: or.b64 %rd4, %rd2, %rd3; --; SM20-NEXT: st.param.b64 [func_retval0+0], %rd4; -+; SM20-NEXT: { -+; SM20-NEXT: .reg .b64 %lhs; -+; SM20-NEXT: .reg .b64 %rhs; -+; SM20-NEXT: .reg .u32 %amt2; -+; SM20-NEXT: and.b32 %amt2, %r1, 63; -+; SM20-NEXT: shl.b64 %lhs, %rd1, %amt2; -+; SM20-NEXT: sub.u32 %amt2, 64, %amt2; -+; SM20-NEXT: shr.b64 %rhs, %rd1, %amt2; -+; SM20-NEXT: add.u64 %rd2, %lhs, %rhs; -+; SM20-NEXT: } -+; SM20-NEXT: st.param.b64 [func_retval0+0], %rd2; - ; SM20-NEXT: ret; - ; - ; SM35-LABEL: rotate64( - ; SM35: { --; SM35-NEXT: .reg .b32 %r<5>; --; SM35-NEXT: .reg .b64 %rd<5>; -+; SM35-NEXT: .reg .b32 %r<6>; -+; SM35-NEXT: .reg .b64 %rd<3>; - ; SM35-EMPTY: - ; SM35-NEXT: // %bb.0: - ; SM35-NEXT: ld.param.u64 %rd1, [rotate64_param_0]; --; SM35-NEXT: ld.param.u32 %r1, [rotate64_param_1]; --; SM35-NEXT: and.b32 %r2, %r1, 63; --; SM35-NEXT: shl.b64 %rd2, %rd1, %r2; --; SM35-NEXT: neg.s32 %r3, %r1; --; SM35-NEXT: and.b32 %r4, %r3, 63; --; SM35-NEXT: shr.u64 %rd3, %rd1, %r4; --; SM35-NEXT: or.b64 %rd4, %rd2, %rd3; --; SM35-NEXT: st.param.b64 [func_retval0+0], %rd4; -+; SM35-NEXT: { -+; SM35-NEXT: .reg .b32 %dummy; -+; SM35-NEXT: mov.b64 {%dummy,%r1}, %rd1; -+; SM35-NEXT: } -+; SM35-NEXT: { -+; SM35-NEXT: .reg .b32 %dummy; -+; SM35-NEXT: mov.b64 {%r2,%dummy}, %rd1; -+; SM35-NEXT: } -+; SM35-NEXT: ld.param.u32 %r3, [rotate64_param_1]; -+; SM35-NEXT: shf.l.wrap.b32 %r4, %r2, %r1, %r3; -+; SM35-NEXT: shf.l.wrap.b32 %r5, %r1, %r2, %r3; -+; SM35-NEXT: mov.b64 %rd2, {%r5, %r4}; -+; SM35-NEXT: st.param.b64 [func_retval0+0], %rd2; - ; SM35-NEXT: ret; - %val = tail call i64 @llvm.nvvm.rotate.b64(i64 %a, i32 %b) - ret i64 %val -@@ -93,36 +99,45 @@ - define i64 @rotateright64(i64 %a, i32 %b) { - ; SM20-LABEL: rotateright64( - ; SM20: { --; SM20-NEXT: .reg .b32 %r<5>; --; SM20-NEXT: .reg .b64 %rd<5>; -+; SM20-NEXT: .reg .b32 %r<2>; -+; SM20-NEXT: .reg .b64 %rd<3>; - ; SM20-EMPTY: - ; SM20-NEXT: // %bb.0: - ; SM20-NEXT: ld.param.u64 %rd1, [rotateright64_param_0]; - ; SM20-NEXT: ld.param.u32 %r1, [rotateright64_param_1]; --; SM20-NEXT: and.b32 %r2, %r1, 63; --; SM20-NEXT: shr.u64 %rd2, %rd1, %r2; --; SM20-NEXT: neg.s32 %r3, %r1; --; SM20-NEXT: and.b32 %r4, %r3, 63; --; SM20-NEXT: shl.b64 %rd3, %rd1, %r4; --; SM20-NEXT: or.b64 %rd4, %rd2, %rd3; --; SM20-NEXT: st.param.b64 [func_retval0+0], %rd4; -+; SM20-NEXT: { -+; SM20-NEXT: .reg .b64 %lhs; -+; SM20-NEXT: .reg .b64 %rhs; -+; SM20-NEXT: .reg .u32 %amt2; -+; SM20-NEXT: and.b32 %amt2, %r1, 63; -+; SM20-NEXT: shr.b64 %lhs, %rd1, %amt2; -+; SM20-NEXT: sub.u32 %amt2, 64, %amt2; -+; SM20-NEXT: shl.b64 %rhs, %rd1, %amt2; -+; SM20-NEXT: add.u64 %rd2, %lhs, %rhs; -+; SM20-NEXT: } -+; SM20-NEXT: st.param.b64 [func_retval0+0], %rd2; - ; SM20-NEXT: ret; - ; - ; SM35-LABEL: rotateright64( - ; SM35: { --; SM35-NEXT: .reg .b32 %r<5>; --; SM35-NEXT: .reg .b64 %rd<5>; -+; SM35-NEXT: .reg .b32 %r<6>; -+; SM35-NEXT: .reg .b64 %rd<3>; - ; SM35-EMPTY: - ; SM35-NEXT: // %bb.0: - ; SM35-NEXT: ld.param.u64 %rd1, [rotateright64_param_0]; --; SM35-NEXT: ld.param.u32 %r1, [rotateright64_param_1]; --; SM35-NEXT: and.b32 %r2, %r1, 63; --; SM35-NEXT: shr.u64 %rd2, %rd1, %r2; --; SM35-NEXT: neg.s32 %r3, %r1; --; SM35-NEXT: and.b32 %r4, %r3, 63; --; SM35-NEXT: shl.b64 %rd3, %rd1, %r4; --; SM35-NEXT: or.b64 %rd4, %rd2, %rd3; --; SM35-NEXT: st.param.b64 [func_retval0+0], %rd4; -+; SM35-NEXT: { -+; SM35-NEXT: .reg .b32 %dummy; -+; SM35-NEXT: mov.b64 {%r1,%dummy}, %rd1; -+; SM35-NEXT: } -+; SM35-NEXT: { -+; SM35-NEXT: .reg .b32 %dummy; -+; SM35-NEXT: mov.b64 {%dummy,%r2}, %rd1; -+; SM35-NEXT: } -+; SM35-NEXT: ld.param.u32 %r3, [rotateright64_param_1]; -+; SM35-NEXT: shf.r.wrap.b32 %r4, %r2, %r1, %r3; -+; SM35-NEXT: shf.r.wrap.b32 %r5, %r1, %r2, %r3; -+; SM35-NEXT: mov.b64 %rd2, {%r5, %r4}; -+; SM35-NEXT: st.param.b64 [func_retval0+0], %rd2; - ; SM35-NEXT: ret; - %val = tail call i64 @llvm.nvvm.rotate.right.b64(i64 %a, i32 %b) - ret i64 %val -@@ -133,14 +148,18 @@ - define i32 @rotl0(i32 %x) { - ; SM20-LABEL: rotl0( - ; SM20: { --; SM20-NEXT: .reg .b32 %r<5>; -+; SM20-NEXT: .reg .b32 %r<3>; - ; SM20-EMPTY: - ; SM20-NEXT: // %bb.0: - ; SM20-NEXT: ld.param.u32 %r1, [rotl0_param_0]; --; SM20-NEXT: shr.u32 %r2, %r1, 24; --; SM20-NEXT: shl.b32 %r3, %r1, 8; --; SM20-NEXT: or.b32 %r4, %r3, %r2; --; SM20-NEXT: st.param.b32 [func_retval0+0], %r4; -+; SM20-NEXT: { -+; SM20-NEXT: .reg .b32 %lhs; -+; SM20-NEXT: .reg .b32 %rhs; -+; SM20-NEXT: shl.b32 %lhs, %r1, 8; -+; SM20-NEXT: shr.b32 %rhs, %r1, 24; -+; SM20-NEXT: add.u32 %r2, %lhs, %rhs; -+; SM20-NEXT: } -+; SM20-NEXT: st.param.b32 [func_retval0+0], %r2; - ; SM20-NEXT: ret; - ; - ; SM35-LABEL: rotl0( -@@ -158,40 +177,51 @@ - ret i32 %t2 - } - -+declare i64 @llvm.fshl.i64(i64, i64, i64) -+declare i64 @llvm.fshr.i64(i64, i64, i64) -+ - ; SM35: rotl64 - define i64 @rotl64(i64 %a, i64 %n) { - ; SM20-LABEL: rotl64( - ; SM20: { --; SM20-NEXT: .reg .b32 %r<5>; --; SM20-NEXT: .reg .b64 %rd<5>; -+; SM20-NEXT: .reg .b32 %r<2>; -+; SM20-NEXT: .reg .b64 %rd<3>; - ; SM20-EMPTY: - ; SM20-NEXT: // %bb.0: - ; SM20-NEXT: ld.param.u64 %rd1, [rotl64_param_0]; - ; SM20-NEXT: ld.param.u32 %r1, [rotl64_param_1]; --; SM20-NEXT: and.b32 %r2, %r1, 63; --; SM20-NEXT: shl.b64 %rd2, %rd1, %r2; --; SM20-NEXT: neg.s32 %r3, %r1; --; SM20-NEXT: and.b32 %r4, %r3, 63; --; SM20-NEXT: shr.u64 %rd3, %rd1, %r4; --; SM20-NEXT: or.b64 %rd4, %rd2, %rd3; --; SM20-NEXT: st.param.b64 [func_retval0+0], %rd4; -+; SM20-NEXT: { -+; SM20-NEXT: .reg .b64 %lhs; -+; SM20-NEXT: .reg .b64 %rhs; -+; SM20-NEXT: .reg .u32 %amt2; -+; SM20-NEXT: and.b32 %amt2, %r1, 63; -+; SM20-NEXT: shl.b64 %lhs, %rd1, %amt2; -+; SM20-NEXT: sub.u32 %amt2, 64, %amt2; -+; SM20-NEXT: shr.b64 %rhs, %rd1, %amt2; -+; SM20-NEXT: add.u64 %rd2, %lhs, %rhs; -+; SM20-NEXT: } -+; SM20-NEXT: st.param.b64 [func_retval0+0], %rd2; - ; SM20-NEXT: ret; - ; - ; SM35-LABEL: rotl64( - ; SM35: { --; SM35-NEXT: .reg .b32 %r<5>; --; SM35-NEXT: .reg .b64 %rd<5>; -+; SM35-NEXT: .reg .b32 %r<2>; -+; SM35-NEXT: .reg .b64 %rd<3>; - ; SM35-EMPTY: - ; SM35-NEXT: // %bb.0: - ; SM35-NEXT: ld.param.u64 %rd1, [rotl64_param_0]; - ; SM35-NEXT: ld.param.u32 %r1, [rotl64_param_1]; --; SM35-NEXT: and.b32 %r2, %r1, 63; --; SM35-NEXT: shl.b64 %rd2, %rd1, %r2; --; SM35-NEXT: neg.s32 %r3, %r1; --; SM35-NEXT: and.b32 %r4, %r3, 63; --; SM35-NEXT: shr.u64 %rd3, %rd1, %r4; --; SM35-NEXT: or.b64 %rd4, %rd2, %rd3; --; SM35-NEXT: st.param.b64 [func_retval0+0], %rd4; -+; SM35-NEXT: { -+; SM35-NEXT: .reg .b64 %lhs; -+; SM35-NEXT: .reg .b64 %rhs; -+; SM35-NEXT: .reg .u32 %amt2; -+; SM35-NEXT: and.b32 %amt2, %r1, 63; -+; SM35-NEXT: shl.b64 %lhs, %rd1, %amt2; -+; SM35-NEXT: sub.u32 %amt2, 64, %amt2; -+; SM35-NEXT: shr.b64 %rhs, %rd1, %amt2; -+; SM35-NEXT: add.u64 %rd2, %lhs, %rhs; -+; SM35-NEXT: } -+; SM35-NEXT: st.param.b64 [func_retval0+0], %rd2; - ; SM35-NEXT: ret; - %val = tail call i64 @llvm.fshl.i64(i64 %a, i64 %a, i64 %n) - ret i64 %val -@@ -201,26 +231,34 @@ - define i64 @rotl64_imm(i64 %a) { - ; SM20-LABEL: rotl64_imm( - ; SM20: { --; SM20-NEXT: .reg .b64 %rd<5>; -+; SM20-NEXT: .reg .b64 %rd<3>; - ; SM20-EMPTY: - ; SM20-NEXT: // %bb.0: - ; SM20-NEXT: ld.param.u64 %rd1, [rotl64_imm_param_0]; --; SM20-NEXT: shr.u64 %rd2, %rd1, 62; --; SM20-NEXT: shl.b64 %rd3, %rd1, 2; --; SM20-NEXT: or.b64 %rd4, %rd3, %rd2; --; SM20-NEXT: st.param.b64 [func_retval0+0], %rd4; -+; SM20-NEXT: { -+; SM20-NEXT: .reg .b64 %lhs; -+; SM20-NEXT: .reg .b64 %rhs; -+; SM20-NEXT: shl.b64 %lhs, %rd1, 2; -+; SM20-NEXT: shr.b64 %rhs, %rd1, 62; -+; SM20-NEXT: add.u64 %rd2, %lhs, %rhs; -+; SM20-NEXT: } -+; SM20-NEXT: st.param.b64 [func_retval0+0], %rd2; - ; SM20-NEXT: ret; - ; - ; SM35-LABEL: rotl64_imm( - ; SM35: { --; SM35-NEXT: .reg .b64 %rd<5>; -+; SM35-NEXT: .reg .b64 %rd<3>; - ; SM35-EMPTY: - ; SM35-NEXT: // %bb.0: - ; SM35-NEXT: ld.param.u64 %rd1, [rotl64_imm_param_0]; --; SM35-NEXT: shr.u64 %rd2, %rd1, 62; --; SM35-NEXT: shl.b64 %rd3, %rd1, 2; --; SM35-NEXT: or.b64 %rd4, %rd3, %rd2; --; SM35-NEXT: st.param.b64 [func_retval0+0], %rd4; -+; SM35-NEXT: { -+; SM35-NEXT: .reg .b64 %lhs; -+; SM35-NEXT: .reg .b64 %rhs; -+; SM35-NEXT: shl.b64 %lhs, %rd1, 2; -+; SM35-NEXT: shr.b64 %rhs, %rd1, 62; -+; SM35-NEXT: add.u64 %rd2, %lhs, %rhs; -+; SM35-NEXT: } -+; SM35-NEXT: st.param.b64 [func_retval0+0], %rd2; - ; SM35-NEXT: ret; - %val = tail call i64 @llvm.fshl.i64(i64 %a, i64 %a, i64 66) - ret i64 %val -@@ -230,36 +268,44 @@ - define i64 @rotr64(i64 %a, i64 %n) { - ; SM20-LABEL: rotr64( - ; SM20: { --; SM20-NEXT: .reg .b32 %r<5>; --; SM20-NEXT: .reg .b64 %rd<5>; -+; SM20-NEXT: .reg .b32 %r<2>; -+; SM20-NEXT: .reg .b64 %rd<3>; - ; SM20-EMPTY: - ; SM20-NEXT: // %bb.0: - ; SM20-NEXT: ld.param.u64 %rd1, [rotr64_param_0]; - ; SM20-NEXT: ld.param.u32 %r1, [rotr64_param_1]; --; SM20-NEXT: and.b32 %r2, %r1, 63; --; SM20-NEXT: shr.u64 %rd2, %rd1, %r2; --; SM20-NEXT: neg.s32 %r3, %r1; --; SM20-NEXT: and.b32 %r4, %r3, 63; --; SM20-NEXT: shl.b64 %rd3, %rd1, %r4; --; SM20-NEXT: or.b64 %rd4, %rd2, %rd3; --; SM20-NEXT: st.param.b64 [func_retval0+0], %rd4; -+; SM20-NEXT: { -+; SM20-NEXT: .reg .b64 %lhs; -+; SM20-NEXT: .reg .b64 %rhs; -+; SM20-NEXT: .reg .u32 %amt2; -+; SM20-NEXT: and.b32 %amt2, %r1, 63; -+; SM20-NEXT: shr.b64 %lhs, %rd1, %amt2; -+; SM20-NEXT: sub.u32 %amt2, 64, %amt2; -+; SM20-NEXT: shl.b64 %rhs, %rd1, %amt2; -+; SM20-NEXT: add.u64 %rd2, %lhs, %rhs; -+; SM20-NEXT: } -+; SM20-NEXT: st.param.b64 [func_retval0+0], %rd2; - ; SM20-NEXT: ret; - ; - ; SM35-LABEL: rotr64( - ; SM35: { --; SM35-NEXT: .reg .b32 %r<5>; --; SM35-NEXT: .reg .b64 %rd<5>; -+; SM35-NEXT: .reg .b32 %r<2>; -+; SM35-NEXT: .reg .b64 %rd<3>; - ; SM35-EMPTY: - ; SM35-NEXT: // %bb.0: - ; SM35-NEXT: ld.param.u64 %rd1, [rotr64_param_0]; - ; SM35-NEXT: ld.param.u32 %r1, [rotr64_param_1]; --; SM35-NEXT: and.b32 %r2, %r1, 63; --; SM35-NEXT: shr.u64 %rd2, %rd1, %r2; --; SM35-NEXT: neg.s32 %r3, %r1; --; SM35-NEXT: and.b32 %r4, %r3, 63; --; SM35-NEXT: shl.b64 %rd3, %rd1, %r4; --; SM35-NEXT: or.b64 %rd4, %rd2, %rd3; --; SM35-NEXT: st.param.b64 [func_retval0+0], %rd4; -+; SM35-NEXT: { -+; SM35-NEXT: .reg .b64 %lhs; -+; SM35-NEXT: .reg .b64 %rhs; -+; SM35-NEXT: .reg .u32 %amt2; -+; SM35-NEXT: and.b32 %amt2, %r1, 63; -+; SM35-NEXT: shr.b64 %lhs, %rd1, %amt2; -+; SM35-NEXT: sub.u32 %amt2, 64, %amt2; -+; SM35-NEXT: shl.b64 %rhs, %rd1, %amt2; -+; SM35-NEXT: add.u64 %rd2, %lhs, %rhs; -+; SM35-NEXT: } -+; SM35-NEXT: st.param.b64 [func_retval0+0], %rd2; - ; SM35-NEXT: ret; - %val = tail call i64 @llvm.fshr.i64(i64 %a, i64 %a, i64 %n) - ret i64 %val -@@ -269,180 +315,35 @@ - define i64 @rotr64_imm(i64 %a) { - ; SM20-LABEL: rotr64_imm( - ; SM20: { --; SM20-NEXT: .reg .b64 %rd<5>; -+; SM20-NEXT: .reg .b64 %rd<3>; - ; SM20-EMPTY: - ; SM20-NEXT: // %bb.0: - ; SM20-NEXT: ld.param.u64 %rd1, [rotr64_imm_param_0]; --; SM20-NEXT: shl.b64 %rd2, %rd1, 62; --; SM20-NEXT: shr.u64 %rd3, %rd1, 2; --; SM20-NEXT: or.b64 %rd4, %rd3, %rd2; --; SM20-NEXT: st.param.b64 [func_retval0+0], %rd4; -+; SM20-NEXT: { -+; SM20-NEXT: .reg .b64 %lhs; -+; SM20-NEXT: .reg .b64 %rhs; -+; SM20-NEXT: shl.b64 %lhs, %rd1, 62; -+; SM20-NEXT: shr.b64 %rhs, %rd1, 2; -+; SM20-NEXT: add.u64 %rd2, %lhs, %rhs; -+; SM20-NEXT: } -+; SM20-NEXT: st.param.b64 [func_retval0+0], %rd2; - ; SM20-NEXT: ret; - ; - ; SM35-LABEL: rotr64_imm( - ; SM35: { --; SM35-NEXT: .reg .b64 %rd<5>; -+; SM35-NEXT: .reg .b64 %rd<3>; - ; SM35-EMPTY: - ; SM35-NEXT: // %bb.0: - ; SM35-NEXT: ld.param.u64 %rd1, [rotr64_imm_param_0]; --; SM35-NEXT: shl.b64 %rd2, %rd1, 62; --; SM35-NEXT: shr.u64 %rd3, %rd1, 2; --; SM35-NEXT: or.b64 %rd4, %rd3, %rd2; --; SM35-NEXT: st.param.b64 [func_retval0+0], %rd4; -+; SM35-NEXT: { -+; SM35-NEXT: .reg .b64 %lhs; -+; SM35-NEXT: .reg .b64 %rhs; -+; SM35-NEXT: shl.b64 %lhs, %rd1, 62; -+; SM35-NEXT: shr.b64 %rhs, %rd1, 2; -+; SM35-NEXT: add.u64 %rd2, %lhs, %rhs; -+; SM35-NEXT: } -+; SM35-NEXT: st.param.b64 [func_retval0+0], %rd2; - ; SM35-NEXT: ret; - %val = tail call i64 @llvm.fshr.i64(i64 %a, i64 %a, i64 66) - ret i64 %val - } -- --define i32 @funnel_shift_right_32(i32 %a, i32 %b, i32 %c) { --; SM20-LABEL: funnel_shift_right_32( --; SM20: { --; SM20-NEXT: .reg .b32 %r<11>; --; SM20-EMPTY: --; SM20-NEXT: // %bb.0: --; SM20-NEXT: ld.param.u32 %r1, [funnel_shift_right_32_param_0]; --; SM20-NEXT: ld.param.u32 %r2, [funnel_shift_right_32_param_2]; --; SM20-NEXT: and.b32 %r3, %r2, 31; --; SM20-NEXT: ld.param.u32 %r4, [funnel_shift_right_32_param_1]; --; SM20-NEXT: shr.u32 %r5, %r4, %r3; --; SM20-NEXT: shl.b32 %r6, %r1, 1; --; SM20-NEXT: not.b32 %r7, %r2; --; SM20-NEXT: and.b32 %r8, %r7, 31; --; SM20-NEXT: shl.b32 %r9, %r6, %r8; --; SM20-NEXT: or.b32 %r10, %r9, %r5; --; SM20-NEXT: st.param.b32 [func_retval0+0], %r10; --; SM20-NEXT: ret; --; --; SM35-LABEL: funnel_shift_right_32( --; SM35: { --; SM35-NEXT: .reg .b32 %r<5>; --; SM35-EMPTY: --; SM35-NEXT: // %bb.0: --; SM35-NEXT: ld.param.u32 %r1, [funnel_shift_right_32_param_0]; --; SM35-NEXT: ld.param.u32 %r2, [funnel_shift_right_32_param_1]; --; SM35-NEXT: ld.param.u32 %r3, [funnel_shift_right_32_param_2]; --; SM35-NEXT: shf.r.wrap.b32 %r4, %r1, %r2, %r3; --; SM35-NEXT: st.param.b32 [func_retval0+0], %r4; --; SM35-NEXT: ret; -- %val = call i32 @llvm.fshr.i32(i32 %a, i32 %b, i32 %c) -- ret i32 %val --} -- --define i32 @funnel_shift_left_32(i32 %a, i32 %b, i32 %c) { --; SM20-LABEL: funnel_shift_left_32( --; SM20: { --; SM20-NEXT: .reg .b32 %r<11>; --; SM20-EMPTY: --; SM20-NEXT: // %bb.0: --; SM20-NEXT: ld.param.u32 %r1, [funnel_shift_left_32_param_0]; --; SM20-NEXT: ld.param.u32 %r2, [funnel_shift_left_32_param_2]; --; SM20-NEXT: and.b32 %r3, %r2, 31; --; SM20-NEXT: shl.b32 %r4, %r1, %r3; --; SM20-NEXT: ld.param.u32 %r5, [funnel_shift_left_32_param_1]; --; SM20-NEXT: shr.u32 %r6, %r5, 1; --; SM20-NEXT: not.b32 %r7, %r2; --; SM20-NEXT: and.b32 %r8, %r7, 31; --; SM20-NEXT: shr.u32 %r9, %r6, %r8; --; SM20-NEXT: or.b32 %r10, %r4, %r9; --; SM20-NEXT: st.param.b32 [func_retval0+0], %r10; --; SM20-NEXT: ret; --; --; SM35-LABEL: funnel_shift_left_32( --; SM35: { --; SM35-NEXT: .reg .b32 %r<5>; --; SM35-EMPTY: --; SM35-NEXT: // %bb.0: --; SM35-NEXT: ld.param.u32 %r1, [funnel_shift_left_32_param_0]; --; SM35-NEXT: ld.param.u32 %r2, [funnel_shift_left_32_param_1]; --; SM35-NEXT: ld.param.u32 %r3, [funnel_shift_left_32_param_2]; --; SM35-NEXT: shf.l.wrap.b32 %r4, %r1, %r2, %r3; --; SM35-NEXT: st.param.b32 [func_retval0+0], %r4; --; SM35-NEXT: ret; -- %val = call i32 @llvm.fshl.i32(i32 %a, i32 %b, i32 %c) -- ret i32 %val --} -- --define i64 @funnel_shift_right_64(i64 %a, i64 %b, i64 %c) { --; SM20-LABEL: funnel_shift_right_64( --; SM20: { --; SM20-NEXT: .reg .b32 %r<5>; --; SM20-NEXT: .reg .b64 %rd<7>; --; SM20-EMPTY: --; SM20-NEXT: // %bb.0: --; SM20-NEXT: ld.param.u64 %rd1, [funnel_shift_right_64_param_0]; --; SM20-NEXT: ld.param.u32 %r1, [funnel_shift_right_64_param_2]; --; SM20-NEXT: and.b32 %r2, %r1, 63; --; SM20-NEXT: ld.param.u64 %rd2, [funnel_shift_right_64_param_1]; --; SM20-NEXT: shr.u64 %rd3, %rd2, %r2; --; SM20-NEXT: shl.b64 %rd4, %rd1, 1; --; SM20-NEXT: not.b32 %r3, %r1; --; SM20-NEXT: and.b32 %r4, %r3, 63; --; SM20-NEXT: shl.b64 %rd5, %rd4, %r4; --; SM20-NEXT: or.b64 %rd6, %rd5, %rd3; --; SM20-NEXT: st.param.b64 [func_retval0+0], %rd6; --; SM20-NEXT: ret; --; --; SM35-LABEL: funnel_shift_right_64( --; SM35: { --; SM35-NEXT: .reg .b32 %r<5>; --; SM35-NEXT: .reg .b64 %rd<7>; --; SM35-EMPTY: --; SM35-NEXT: // %bb.0: --; SM35-NEXT: ld.param.u64 %rd1, [funnel_shift_right_64_param_0]; --; SM35-NEXT: ld.param.u32 %r1, [funnel_shift_right_64_param_2]; --; SM35-NEXT: and.b32 %r2, %r1, 63; --; SM35-NEXT: ld.param.u64 %rd2, [funnel_shift_right_64_param_1]; --; SM35-NEXT: shr.u64 %rd3, %rd2, %r2; --; SM35-NEXT: shl.b64 %rd4, %rd1, 1; --; SM35-NEXT: not.b32 %r3, %r1; --; SM35-NEXT: and.b32 %r4, %r3, 63; --; SM35-NEXT: shl.b64 %rd5, %rd4, %r4; --; SM35-NEXT: or.b64 %rd6, %rd5, %rd3; --; SM35-NEXT: st.param.b64 [func_retval0+0], %rd6; --; SM35-NEXT: ret; -- %val = call i64 @llvm.fshr.i64(i64 %a, i64 %b, i64 %c) -- ret i64 %val --} -- --define i64 @funnel_shift_left_64(i64 %a, i64 %b, i64 %c) { --; SM20-LABEL: funnel_shift_left_64( --; SM20: { --; SM20-NEXT: .reg .b32 %r<5>; --; SM20-NEXT: .reg .b64 %rd<7>; --; SM20-EMPTY: --; SM20-NEXT: // %bb.0: --; SM20-NEXT: ld.param.u64 %rd1, [funnel_shift_left_64_param_0]; --; SM20-NEXT: ld.param.u32 %r1, [funnel_shift_left_64_param_2]; --; SM20-NEXT: and.b32 %r2, %r1, 63; --; SM20-NEXT: shl.b64 %rd2, %rd1, %r2; --; SM20-NEXT: ld.param.u64 %rd3, [funnel_shift_left_64_param_1]; --; SM20-NEXT: shr.u64 %rd4, %rd3, 1; --; SM20-NEXT: not.b32 %r3, %r1; --; SM20-NEXT: and.b32 %r4, %r3, 63; --; SM20-NEXT: shr.u64 %rd5, %rd4, %r4; --; SM20-NEXT: or.b64 %rd6, %rd2, %rd5; --; SM20-NEXT: st.param.b64 [func_retval0+0], %rd6; --; SM20-NEXT: ret; --; --; SM35-LABEL: funnel_shift_left_64( --; SM35: { --; SM35-NEXT: .reg .b32 %r<5>; --; SM35-NEXT: .reg .b64 %rd<7>; --; SM35-EMPTY: --; SM35-NEXT: // %bb.0: --; SM35-NEXT: ld.param.u64 %rd1, [funnel_shift_left_64_param_0]; --; SM35-NEXT: ld.param.u32 %r1, [funnel_shift_left_64_param_2]; --; SM35-NEXT: and.b32 %r2, %r1, 63; --; SM35-NEXT: shl.b64 %rd2, %rd1, %r2; --; SM35-NEXT: ld.param.u64 %rd3, [funnel_shift_left_64_param_1]; --; SM35-NEXT: shr.u64 %rd4, %rd3, 1; --; SM35-NEXT: not.b32 %r3, %r1; --; SM35-NEXT: and.b32 %r4, %r3, 63; --; SM35-NEXT: shr.u64 %rd5, %rd4, %r4; --; SM35-NEXT: or.b64 %rd6, %rd2, %rd5; --; SM35-NEXT: st.param.b64 [func_retval0+0], %rd6; --; SM35-NEXT: ret; -- %val = call i64 @llvm.fshl.i64(i64 %a, i64 %b, i64 %c) -- ret i64 %val --} -- -diff -ruN --strip-trailing-cr a/llvm/test/DebugInfo/NVPTX/debug-info.ll b/llvm/test/DebugInfo/NVPTX/debug-info.ll ---- a/llvm/test/DebugInfo/NVPTX/debug-info.ll -+++ b/llvm/test/DebugInfo/NVPTX/debug-info.ll -@@ -25,10 +25,6 @@ - ; CHECK-DAG: .reg .b64 %rd<8>; - ; CHECK: .loc [[DEBUG_INFO_CU:[0-9]+]] 5 0 - ; CHECK: ld.param.u32 %r{{.+}}, [{{.+}}]; --; CHECK: ld.param.u64 %rd{{.+}}, [{{.+}}]; --; CHECK: cvta.to.global.u64 %rd{{.+}}, %rd{{.+}}; --; CHECK: ld.param.u64 %rd{{.+}}, [{{.+}}]; --; CHECK: cvta.to.global.u64 %rd{{.+}}, %rd{{.+}}; - ; CHECK: .loc [[BUILTUIN_VARS_H:[0-9]+]] 78 180 - ; CHECK: mov.u32 %r{{.+}}, %ctaid.x; - ; CHECK: .loc [[BUILTUIN_VARS_H]] 89 180 -@@ -42,6 +38,10 @@ - ; CHECK: .loc [[DEBUG_INFO_CU]] 7 7 - ; CHECK: @%p{{.+}} bra [[BB:\$L__.+]]; - ; CHECK: ld.param.f32 %f{{.+}}, [{{.+}}]; -+; CHECK: ld.param.u64 %rd{{.+}}, [{{.+}}]; -+; CHECK: cvta.to.global.u64 %rd{{.+}}, %rd{{.+}}; -+; CHECK: ld.param.u64 %rd{{.+}}, [{{.+}}]; -+; CHECK: cvta.to.global.u64 %rd{{.+}}, %rd{{.+}}; - ; CHECK: .loc [[DEBUG_INFO_CU]] 8 13 - ; CHECK: mul.wide.u32 %rd{{.+}}, %r{{.+}}, 4; - ; CHECK: add.s64 %rd{{.+}}, %rd{{.+}}, %rd{{.+}}; -@@ -2661,22 +2661,22 @@ - ; CHECK-NEXT:.b32 4579 // DW_AT_type - ; CHECK-NEXT:.b8 25 // Abbrev [25] 0x8aa:0x18 DW_TAG_inlined_subroutine - ; CHECK-NEXT:.b32 707 // DW_AT_abstract_origin --; CHECK-NEXT:.b64 $L__tmp1 // DW_AT_low_pc --; CHECK-NEXT:.b64 $L__tmp2 // DW_AT_high_pc -+; CHECK-NEXT:.b64 $L__tmp0 // DW_AT_low_pc -+; CHECK-NEXT:.b64 $L__tmp1 // DW_AT_high_pc - ; CHECK-NEXT:.b8 1 // DW_AT_call_file - ; CHECK-NEXT:.b8 6 // DW_AT_call_line - ; CHECK-NEXT:.b8 11 // DW_AT_call_column - ; CHECK-NEXT:.b8 25 // Abbrev [25] 0x8c2:0x18 DW_TAG_inlined_subroutine - ; CHECK-NEXT:.b32 1466 // DW_AT_abstract_origin --; CHECK-NEXT:.b64 $L__tmp2 // DW_AT_low_pc --; CHECK-NEXT:.b64 $L__tmp3 // DW_AT_high_pc -+; CHECK-NEXT:.b64 $L__tmp1 // DW_AT_low_pc -+; CHECK-NEXT:.b64 $L__tmp2 // DW_AT_high_pc - ; CHECK-NEXT:.b8 1 // DW_AT_call_file - ; CHECK-NEXT:.b8 6 // DW_AT_call_line - ; CHECK-NEXT:.b8 24 // DW_AT_call_column - ; CHECK-NEXT:.b8 25 // Abbrev [25] 0x8da:0x18 DW_TAG_inlined_subroutine - ; CHECK-NEXT:.b32 2060 // DW_AT_abstract_origin --; CHECK-NEXT:.b64 $L__tmp3 // DW_AT_low_pc --; CHECK-NEXT:.b64 $L__tmp4 // DW_AT_high_pc -+; CHECK-NEXT:.b64 $L__tmp2 // DW_AT_low_pc -+; CHECK-NEXT:.b64 $L__tmp3 // DW_AT_high_pc - ; CHECK-NEXT:.b8 1 // DW_AT_call_file - ; CHECK-NEXT:.b8 6 // DW_AT_call_line - ; CHECK-NEXT:.b8 37 // DW_AT_call_column diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl index af35fe705c0b99..7b11086785b618 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "9830156f623c56062bf6df1b4c4b4bd8ab5bd57c" - LLVM_SHA256 = "85bb9a61cfdaf0d3386890dc7b4bbaa17eecf4b70b60c314307f2ca3919b9035" + LLVM_COMMIT = "29b92d07746fac26cd64c914bc9c5c3833974f6d" + LLVM_SHA256 = "3e8e93e3749454af4b64f7f34b792a4748b62fc533bca1703d33b2b04e34eb70" tf_http_archive( name = name, diff --git a/third_party/shardy/temporary.patch b/third_party/shardy/temporary.patch index d3fd21823cce19..014b81b4e7518b 100644 --- a/third_party/shardy/temporary.patch +++ b/third_party/shardy/temporary.patch @@ -1,4115 +1,4115 @@ diff --git a/third_party/llvm/generated.patch b/third_party/llvm/generated.patch -index 509398d..de92cb4 100644 +index de92cb4..509398d 100644 --- a/third_party/llvm/generated.patch +++ b/third_party/llvm/generated.patch -@@ -1 +1,4095 @@ +@@ -1,4095 +1 @@ Auto generated patch. Do not edit or delete it, even if empty. -+diff -ruN --strip-trailing-cr a/llvm/docs/NVPTXUsage.rst b/llvm/docs/NVPTXUsage.rst -+--- a/llvm/docs/NVPTXUsage.rst -++++ b/llvm/docs/NVPTXUsage.rst -+@@ -127,6 +127,69 @@ -+ NVPTX Intrinsics -+ ================ -+ -++Address Space Conversion -++------------------------ -++ -++'``llvm.nvvm.ptr.*.to.gen``' Intrinsics -++^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -++ -++Syntax: -++""""""" -++ -++These are overloaded intrinsics. You can use these on any pointer types. -++ -++.. code-block:: llvm -++ -++ declare ptr @llvm.nvvm.ptr.global.to.gen.p0.p1(ptr addrspace(1)) -++ declare ptr @llvm.nvvm.ptr.shared.to.gen.p0.p3(ptr addrspace(3)) -++ declare ptr @llvm.nvvm.ptr.constant.to.gen.p0.p4(ptr addrspace(4)) -++ declare ptr @llvm.nvvm.ptr.local.to.gen.p0.p5(ptr addrspace(5)) -++ -++Overview: -++""""""""" -++ -++The '``llvm.nvvm.ptr.*.to.gen``' intrinsics convert a pointer in a non-generic -++address space to a generic address space pointer. -++ -++Semantics: -++"""""""""" -++ -++These intrinsics modify the pointer value to be a valid generic address space -++pointer. -++ -++ -++'``llvm.nvvm.ptr.gen.to.*``' Intrinsics -++^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -++ -++Syntax: -++""""""" -++ -++These are overloaded intrinsics. You can use these on any pointer types. -++ -++.. code-block:: llvm -++ -++ declare ptr addrspace(1) @llvm.nvvm.ptr.gen.to.global.p1.p0(ptr) -++ declare ptr addrspace(3) @llvm.nvvm.ptr.gen.to.shared.p3.p0(ptr) -++ declare ptr addrspace(4) @llvm.nvvm.ptr.gen.to.constant.p4.p0(ptr) -++ declare ptr addrspace(5) @llvm.nvvm.ptr.gen.to.local.p5.p0(ptr) -++ -++Overview: -++""""""""" -++ -++The '``llvm.nvvm.ptr.gen.to.*``' intrinsics convert a pointer in the generic -++address space to a pointer in the target address space. Note that these -++intrinsics are only useful if the address space of the target address space of -++the pointer is known. It is not legal to use address space conversion -++intrinsics to convert a pointer from one non-generic address space to another -++non-generic address space. -++ -++Semantics: -++"""""""""" -++ -++These intrinsics modify the pointer value to be a valid pointer in the target -++non-generic address space. -++ -++ -+ Reading PTX Special Registers -+ ----------------------------- -+ -+diff -ruN --strip-trailing-cr a/llvm/docs/ReleaseNotes.rst b/llvm/docs/ReleaseNotes.rst -+--- a/llvm/docs/ReleaseNotes.rst -++++ b/llvm/docs/ReleaseNotes.rst -+@@ -63,24 +63,6 @@ -+ * ``llvm.nvvm.bitcast.d2ll`` -+ * ``llvm.nvvm.bitcast.ll2d`` -+ -+-* Remove the following intrinsics which can be replaced with a funnel-shift: -+- -+- * ``llvm.nvvm.rotate.b32`` -+- * ``llvm.nvvm.rotate.right.b64`` -+- * ``llvm.nvvm.rotate.b64`` -+- -+-* Remove the following intrinsics which can be replaced with an -+- ``addrspacecast``: -+- -+- * ``llvm.nvvm.ptr.gen.to.global`` -+- * ``llvm.nvvm.ptr.gen.to.shared`` -+- * ``llvm.nvvm.ptr.gen.to.constant`` -+- * ``llvm.nvvm.ptr.gen.to.local`` -+- * ``llvm.nvvm.ptr.global.to.gen`` -+- * ``llvm.nvvm.ptr.shared.to.gen`` -+- * ``llvm.nvvm.ptr.constant.to.gen`` -+- * ``llvm.nvvm.ptr.local.to.gen`` -+- -+ Changes to LLVM infrastructure -+ ------------------------------ -+ -+diff -ruN --strip-trailing-cr a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td -+--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td -++++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td -+@@ -30,18 +30,10 @@ -+ // * llvm.nvvm.max.ui --> select(x ule y, x, y) -+ // * llvm.nvvm.max.ull --> ibid. -+ // * llvm.nvvm.h2f --> llvm.convert.to.fp16.f32 -+-// * llvm.nvvm.bitcast.f2i --> bitcast -+-// * llvm.nvvm.bitcast.i2f --> ibid. -+-// * llvm.nvvm.bitcast.d2ll --> ibid. -+-// * llvm.nvvm.bitcast.ll2d --> ibid. -+-// * llvm.nvvm.ptr.gen.to.global --> addrspacecast -+-// * llvm.nvvm.ptr.gen.to.shared --> ibid. -+-// * llvm.nvvm.ptr.gen.to.constant --> ibid. -+-// * llvm.nvvm.ptr.gen.to.local --> ibid. -+-// * llvm.nvvm.ptr.global.to.gen --> ibid. -+-// * llvm.nvvm.ptr.shared.to.gen --> ibid. -+-// * llvm.nvvm.ptr.constant.to.gen --> ibid. -+-// * llvm.nvvm.ptr.local.to.gen --> ibid. -++// * llvm.nvvm.bitcast.f2i --> bitcast -++// * llvm.nvvm.bitcast.i2f --> ibid. -++// * llvm.nvvm.bitcast.d2ll --> ibid. -++// * llvm.nvvm.bitcast.ll2d --> ibid. -+ -+ def llvm_global_ptr_ty : LLVMQualPointerType<1>; // (global)ptr -+ def llvm_shared_ptr_ty : LLVMQualPointerType<3>; // (shared)ptr -+@@ -1610,6 +1602,40 @@ -+ [IntrReadMem, IntrArgMemOnly, IntrNoCallback, IntrWillReturn, NoCapture>], -+ "llvm.nvvm.ldg.global.p">; -+ -++// Use for generic pointers -++// - These intrinsics are used to convert address spaces. -++// - The input pointer and output pointer must have the same type, except for -++// the address-space. (This restriction is not enforced here as there is -++// currently no way to describe it). -++// - This complements the llvm bitcast, which can be used to cast one type -++// of pointer to another type of pointer, while the address space remains -++// the same. -++def int_nvvm_ptr_local_to_gen: DefaultAttrsIntrinsic<[llvm_anyptr_ty], -++ [llvm_anyptr_ty], [IntrNoMem, IntrSpeculatable], -++ "llvm.nvvm.ptr.local.to.gen">; -++def int_nvvm_ptr_shared_to_gen: DefaultAttrsIntrinsic<[llvm_anyptr_ty], -++ [llvm_anyptr_ty], [IntrNoMem, IntrSpeculatable], -++ "llvm.nvvm.ptr.shared.to.gen">; -++def int_nvvm_ptr_global_to_gen: DefaultAttrsIntrinsic<[llvm_anyptr_ty], -++ [llvm_anyptr_ty], [IntrNoMem, IntrSpeculatable], -++ "llvm.nvvm.ptr.global.to.gen">; -++def int_nvvm_ptr_constant_to_gen: DefaultAttrsIntrinsic<[llvm_anyptr_ty], -++ [llvm_anyptr_ty], [IntrNoMem, IntrSpeculatable], -++ "llvm.nvvm.ptr.constant.to.gen">; -++ -++def int_nvvm_ptr_gen_to_global: DefaultAttrsIntrinsic<[llvm_anyptr_ty], -++ [llvm_anyptr_ty], [IntrNoMem, IntrSpeculatable], -++ "llvm.nvvm.ptr.gen.to.global">; -++def int_nvvm_ptr_gen_to_shared: DefaultAttrsIntrinsic<[llvm_anyptr_ty], -++ [llvm_anyptr_ty], [IntrNoMem, IntrSpeculatable], -++ "llvm.nvvm.ptr.gen.to.shared">; -++def int_nvvm_ptr_gen_to_local: DefaultAttrsIntrinsic<[llvm_anyptr_ty], -++ [llvm_anyptr_ty], [IntrNoMem, IntrSpeculatable], -++ "llvm.nvvm.ptr.gen.to.local">; -++def int_nvvm_ptr_gen_to_constant: DefaultAttrsIntrinsic<[llvm_anyptr_ty], -++ [llvm_anyptr_ty], [IntrNoMem, IntrSpeculatable], -++ "llvm.nvvm.ptr.gen.to.constant">; -++ -+ // Used in nvvm internally to help address space opt and ptx code generation -+ // This is for params that are passed to kernel functions by pointer by-val. -+ def int_nvvm_ptr_gen_to_param: Intrinsic<[llvm_anyptr_ty], -+@@ -4453,6 +4479,22 @@ -+ "llvm.nvvm.sust.p.3d.v4i32.trap">, -+ ClangBuiltin<"__nvvm_sust_p_3d_v4i32_trap">; -+ -++ -++def int_nvvm_rotate_b32 -++ : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty], -++ [IntrNoMem, IntrSpeculatable], "llvm.nvvm.rotate.b32">, -++ ClangBuiltin<"__nvvm_rotate_b32">; -++ -++def int_nvvm_rotate_b64 -++ : DefaultAttrsIntrinsic<[llvm_i64_ty], [llvm_i64_ty, llvm_i32_ty], -++ [IntrNoMem, IntrSpeculatable], "llvm.nvvm.rotate.b64">, -++ ClangBuiltin<"__nvvm_rotate_b64">; -++ -++def int_nvvm_rotate_right_b64 -++ : DefaultAttrsIntrinsic<[llvm_i64_ty], [llvm_i64_ty, llvm_i32_ty], -++ [IntrNoMem, IntrSpeculatable], "llvm.nvvm.rotate.right.b64">, -++ ClangBuiltin<"__nvvm_rotate_right_b64">; -++ -+ def int_nvvm_swap_lo_hi_b64 -+ : DefaultAttrsIntrinsic<[llvm_i64_ty], [llvm_i64_ty], -+ [IntrNoMem, IntrSpeculatable], "llvm.nvvm.swap.lo.hi.b64">, -+diff -ruN --strip-trailing-cr a/llvm/lib/IR/AutoUpgrade.cpp b/llvm/lib/IR/AutoUpgrade.cpp -+--- a/llvm/lib/IR/AutoUpgrade.cpp -++++ b/llvm/lib/IR/AutoUpgrade.cpp -+@@ -1272,19 +1272,6 @@ -+ // nvvm.bitcast.{f2i,i2f,ll2d,d2ll} -+ Expand = -+ Name == "f2i" || Name == "i2f" || Name == "ll2d" || Name == "d2ll"; -+- else if (Name.consume_front("rotate.")) -+- // nvvm.rotate.{b32,b64,right.b64} -+- Expand = Name == "b32" || Name == "b64" || Name == "right.b64"; -+- else if (Name.consume_front("ptr.gen.to.")) -+- // nvvm.ptr.gen.to.{local,shared,global,constant} -+- Expand = Name.starts_with("local") || Name.starts_with("shared") || -+- Name.starts_with("global") || Name.starts_with("constant"); -+- else if (Name.consume_front("ptr.")) -+- // nvvm.ptr.{local,shared,global,constant}.to.gen -+- Expand = -+- (Name.consume_front("local") || Name.consume_front("shared") || -+- Name.consume_front("global") || Name.consume_front("constant")) && -+- Name.starts_with(".to.gen"); -+ else -+ Expand = false; -+ -+@@ -2271,117 +2258,6 @@ -+ } -+ } -+ -+-static Value *upgradeNVVMIntrinsicCall(StringRef Name, CallBase *CI, -+- Function *F, IRBuilder<> &Builder) { -+- Value *Rep = nullptr; -+- -+- if (Name == "abs.i" || Name == "abs.ll") { -+- Value *Arg = CI->getArgOperand(0); -+- Value *Neg = Builder.CreateNeg(Arg, "neg"); -+- Value *Cmp = Builder.CreateICmpSGE( -+- Arg, llvm::Constant::getNullValue(Arg->getType()), "abs.cond"); -+- Rep = Builder.CreateSelect(Cmp, Arg, Neg, "abs"); -+- } else if (Name.starts_with("atomic.load.add.f32.p") || -+- Name.starts_with("atomic.load.add.f64.p")) { -+- Value *Ptr = CI->getArgOperand(0); -+- Value *Val = CI->getArgOperand(1); -+- Rep = Builder.CreateAtomicRMW(AtomicRMWInst::FAdd, Ptr, Val, MaybeAlign(), -+- AtomicOrdering::SequentiallyConsistent); -+- } else if (Name.consume_front("max.") && -+- (Name == "s" || Name == "i" || Name == "ll" || Name == "us" || -+- Name == "ui" || Name == "ull")) { -+- Value *Arg0 = CI->getArgOperand(0); -+- Value *Arg1 = CI->getArgOperand(1); -+- Value *Cmp = Name.starts_with("u") -+- ? Builder.CreateICmpUGE(Arg0, Arg1, "max.cond") -+- : Builder.CreateICmpSGE(Arg0, Arg1, "max.cond"); -+- Rep = Builder.CreateSelect(Cmp, Arg0, Arg1, "max"); -+- } else if (Name.consume_front("min.") && -+- (Name == "s" || Name == "i" || Name == "ll" || Name == "us" || -+- Name == "ui" || Name == "ull")) { -+- Value *Arg0 = CI->getArgOperand(0); -+- Value *Arg1 = CI->getArgOperand(1); -+- Value *Cmp = Name.starts_with("u") -+- ? Builder.CreateICmpULE(Arg0, Arg1, "min.cond") -+- : Builder.CreateICmpSLE(Arg0, Arg1, "min.cond"); -+- Rep = Builder.CreateSelect(Cmp, Arg0, Arg1, "min"); -+- } else if (Name == "clz.ll") { -+- // llvm.nvvm.clz.ll returns an i32, but llvm.ctlz.i64 returns an i64. -+- Value *Arg = CI->getArgOperand(0); -+- Value *Ctlz = Builder.CreateCall( -+- Intrinsic::getDeclaration(F->getParent(), Intrinsic::ctlz, -+- {Arg->getType()}), -+- {Arg, Builder.getFalse()}, "ctlz"); -+- Rep = Builder.CreateTrunc(Ctlz, Builder.getInt32Ty(), "ctlz.trunc"); -+- } else if (Name == "popc.ll") { -+- // llvm.nvvm.popc.ll returns an i32, but llvm.ctpop.i64 returns an -+- // i64. -+- Value *Arg = CI->getArgOperand(0); -+- Value *Popc = Builder.CreateCall( -+- Intrinsic::getDeclaration(F->getParent(), Intrinsic::ctpop, -+- {Arg->getType()}), -+- Arg, "ctpop"); -+- Rep = Builder.CreateTrunc(Popc, Builder.getInt32Ty(), "ctpop.trunc"); -+- } else if (Name == "h2f") { -+- Rep = Builder.CreateCall( -+- Intrinsic::getDeclaration(F->getParent(), Intrinsic::convert_from_fp16, -+- {Builder.getFloatTy()}), -+- CI->getArgOperand(0), "h2f"); -+- } else if (Name.consume_front("bitcast.") && -+- (Name == "f2i" || Name == "i2f" || Name == "ll2d" || -+- Name == "d2ll")) { -+- Rep = Builder.CreateBitCast(CI->getArgOperand(0), CI->getType()); -+- } else if (Name == "rotate.b32") { -+- Value *Arg = CI->getOperand(0); -+- Value *ShiftAmt = CI->getOperand(1); -+- Rep = Builder.CreateIntrinsic(Builder.getInt32Ty(), Intrinsic::fshl, -+- {Arg, Arg, ShiftAmt}); -+- } else if (Name == "rotate.b64") { -+- Type *Int64Ty = Builder.getInt64Ty(); -+- Value *Arg = CI->getOperand(0); -+- Value *ZExtShiftAmt = Builder.CreateZExt(CI->getOperand(1), Int64Ty); -+- Rep = Builder.CreateIntrinsic(Int64Ty, Intrinsic::fshl, -+- {Arg, Arg, ZExtShiftAmt}); -+- } else if (Name == "rotate.right.b64") { -+- Type *Int64Ty = Builder.getInt64Ty(); -+- Value *Arg = CI->getOperand(0); -+- Value *ZExtShiftAmt = Builder.CreateZExt(CI->getOperand(1), Int64Ty); -+- Rep = Builder.CreateIntrinsic(Int64Ty, Intrinsic::fshr, -+- {Arg, Arg, ZExtShiftAmt}); -+- } else if ((Name.consume_front("ptr.gen.to.") && -+- (Name.starts_with("local") || Name.starts_with("shared") || -+- Name.starts_with("global") || Name.starts_with("constant"))) || -+- (Name.consume_front("ptr.") && -+- (Name.consume_front("local") || Name.consume_front("shared") || -+- Name.consume_front("global") || -+- Name.consume_front("constant")) && -+- Name.starts_with(".to.gen"))) { -+- Rep = Builder.CreateAddrSpaceCast(CI->getArgOperand(0), CI->getType()); -+- } else { -+- Intrinsic::ID IID = shouldUpgradeNVPTXBF16Intrinsic(Name); -+- if (IID != Intrinsic::not_intrinsic && -+- !F->getReturnType()->getScalarType()->isBFloatTy()) { -+- rename(F); -+- Function *NewFn = Intrinsic::getDeclaration(F->getParent(), IID); -+- SmallVector Args; -+- for (size_t I = 0; I < NewFn->arg_size(); ++I) { -+- Value *Arg = CI->getArgOperand(I); -+- Type *OldType = Arg->getType(); -+- Type *NewType = NewFn->getArg(I)->getType(); -+- Args.push_back( -+- (OldType->isIntegerTy() && NewType->getScalarType()->isBFloatTy()) -+- ? Builder.CreateBitCast(Arg, NewType) -+- : Arg); -+- } -+- Rep = Builder.CreateCall(NewFn, Args); -+- if (F->getReturnType()->isIntegerTy()) -+- Rep = Builder.CreateBitCast(Rep, F->getReturnType()); -+- } -+- } -+- -+- return Rep; -+-} -+- -+ static Value *upgradeX86IntrinsicCall(StringRef Name, CallBase *CI, Function *F, -+ IRBuilder<> &Builder) { -+ LLVMContext &C = F->getContext(); -+@@ -4332,8 +4208,85 @@ -+ -+ if (!IsX86 && Name == "stackprotectorcheck") { -+ Rep = nullptr; -++ } else if (IsNVVM && (Name == "abs.i" || Name == "abs.ll")) { -++ Value *Arg = CI->getArgOperand(0); -++ Value *Neg = Builder.CreateNeg(Arg, "neg"); -++ Value *Cmp = Builder.CreateICmpSGE( -++ Arg, llvm::Constant::getNullValue(Arg->getType()), "abs.cond"); -++ Rep = Builder.CreateSelect(Cmp, Arg, Neg, "abs"); -++ } else if (IsNVVM && (Name.starts_with("atomic.load.add.f32.p") || -++ Name.starts_with("atomic.load.add.f64.p"))) { -++ Value *Ptr = CI->getArgOperand(0); -++ Value *Val = CI->getArgOperand(1); -++ Rep = Builder.CreateAtomicRMW(AtomicRMWInst::FAdd, Ptr, Val, MaybeAlign(), -++ AtomicOrdering::SequentiallyConsistent); -++ } else if (IsNVVM && Name.consume_front("max.") && -++ (Name == "s" || Name == "i" || Name == "ll" || Name == "us" || -++ Name == "ui" || Name == "ull")) { -++ Value *Arg0 = CI->getArgOperand(0); -++ Value *Arg1 = CI->getArgOperand(1); -++ Value *Cmp = Name.starts_with("u") -++ ? Builder.CreateICmpUGE(Arg0, Arg1, "max.cond") -++ : Builder.CreateICmpSGE(Arg0, Arg1, "max.cond"); -++ Rep = Builder.CreateSelect(Cmp, Arg0, Arg1, "max"); -++ } else if (IsNVVM && Name.consume_front("min.") && -++ (Name == "s" || Name == "i" || Name == "ll" || Name == "us" || -++ Name == "ui" || Name == "ull")) { -++ Value *Arg0 = CI->getArgOperand(0); -++ Value *Arg1 = CI->getArgOperand(1); -++ Value *Cmp = Name.starts_with("u") -++ ? Builder.CreateICmpULE(Arg0, Arg1, "min.cond") -++ : Builder.CreateICmpSLE(Arg0, Arg1, "min.cond"); -++ Rep = Builder.CreateSelect(Cmp, Arg0, Arg1, "min"); -++ } else if (IsNVVM && Name == "clz.ll") { -++ // llvm.nvvm.clz.ll returns an i32, but llvm.ctlz.i64 returns an i64. -++ Value *Arg = CI->getArgOperand(0); -++ Value *Ctlz = Builder.CreateCall( -++ Intrinsic::getDeclaration(F->getParent(), Intrinsic::ctlz, -++ {Arg->getType()}), -++ {Arg, Builder.getFalse()}, "ctlz"); -++ Rep = Builder.CreateTrunc(Ctlz, Builder.getInt32Ty(), "ctlz.trunc"); -++ } else if (IsNVVM && Name == "popc.ll") { -++ // llvm.nvvm.popc.ll returns an i32, but llvm.ctpop.i64 returns an -++ // i64. -++ Value *Arg = CI->getArgOperand(0); -++ Value *Popc = Builder.CreateCall( -++ Intrinsic::getDeclaration(F->getParent(), Intrinsic::ctpop, -++ {Arg->getType()}), -++ Arg, "ctpop"); -++ Rep = Builder.CreateTrunc(Popc, Builder.getInt32Ty(), "ctpop.trunc"); -+ } else if (IsNVVM) { -+- Rep = upgradeNVVMIntrinsicCall(Name, CI, F, Builder); -++ if (Name == "h2f") { -++ Rep = -++ Builder.CreateCall(Intrinsic::getDeclaration( -++ F->getParent(), Intrinsic::convert_from_fp16, -++ {Builder.getFloatTy()}), -++ CI->getArgOperand(0), "h2f"); -++ } else if (Name.consume_front("bitcast.") && -++ (Name == "f2i" || Name == "i2f" || Name == "ll2d" || -++ Name == "d2ll")) { -++ Rep = Builder.CreateBitCast(CI->getArgOperand(0), CI->getType()); -++ } else { -++ Intrinsic::ID IID = shouldUpgradeNVPTXBF16Intrinsic(Name); -++ if (IID != Intrinsic::not_intrinsic && -++ !F->getReturnType()->getScalarType()->isBFloatTy()) { -++ rename(F); -++ NewFn = Intrinsic::getDeclaration(F->getParent(), IID); -++ SmallVector Args; -++ for (size_t I = 0; I < NewFn->arg_size(); ++I) { -++ Value *Arg = CI->getArgOperand(I); -++ Type *OldType = Arg->getType(); -++ Type *NewType = NewFn->getArg(I)->getType(); -++ Args.push_back((OldType->isIntegerTy() && -++ NewType->getScalarType()->isBFloatTy()) -++ ? Builder.CreateBitCast(Arg, NewType) -++ : Arg); -++ } -++ Rep = Builder.CreateCall(NewFn, Args); -++ if (F->getReturnType()->isIntegerTy()) -++ Rep = Builder.CreateBitCast(Rep, F->getReturnType()); -++ } -++ } -+ } else if (IsX86) { -+ Rep = upgradeX86IntrinsicCall(Name, CI, F, Builder); -+ } else if (IsARM) { -+diff -ruN --strip-trailing-cr a/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp b/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp -+--- a/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp -++++ b/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp -+@@ -292,7 +292,6 @@ -+ static const LLT S224 = LLT::scalar(224); -+ static const LLT S256 = LLT::scalar(256); -+ static const LLT S512 = LLT::scalar(512); -+-static const LLT S1024 = LLT::scalar(1024); -+ static const LLT MaxScalar = LLT::scalar(MaxRegisterSize); -+ -+ static const LLT V2S8 = LLT::fixed_vector(2, 8); -+@@ -333,8 +332,8 @@ -+ static const LLT V2S128 = LLT::fixed_vector(2, 128); -+ static const LLT V4S128 = LLT::fixed_vector(4, 128); -+ -+-static std::initializer_list AllScalarTypes = { -+- S32, S64, S96, S128, S160, S224, S256, S512, S1024}; -++static std::initializer_list AllScalarTypes = {S32, S64, S96, S128, -++ S160, S224, S256, S512}; -+ -+ static std::initializer_list AllS16Vectors{ -+ V2S16, V4S16, V6S16, V8S16, V10S16, V12S16, V16S16, V2S128, V4S128}; -+@@ -890,11 +889,10 @@ -+ .clampScalar(0, S16, S64); -+ -+ getActionDefinitionsBuilder({G_IMPLICIT_DEF, G_FREEZE}) -+- .legalIf(isRegisterClassType(0)) -++ .legalIf(isRegisterType(0)) -+ // s1 and s16 are special cases because they have legal operations on -+ // them, but don't really occupy registers in the normal way. -+ .legalFor({S1, S16}) -+- .clampNumElements(0, V16S32, V32S32) -+ .moreElementsIf(isSmallOddVector(0), oneMoreElement(0)) -+ .clampScalarOrElt(0, S32, MaxScalar) -+ .widenScalarToNextPow2(0, 32) -+diff -ruN --strip-trailing-cr a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td -+--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td -++++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td -+@@ -174,6 +174,10 @@ -+ def hasSHFL : Predicate<"!(Subtarget->getSmVersion() >= 70" -+ "&& Subtarget->getPTXVersion() >= 64)">; -+ -++def useShortPtrLocal : Predicate<"TM.is64Bit() && TM.getPointerSizeInBits(ADDRESS_SPACE_LOCAL) == 32">; -++def useShortPtrShared : Predicate<"TM.is64Bit() && TM.getPointerSizeInBits(ADDRESS_SPACE_SHARED) == 32">; -++def useShortPtrConst : Predicate<"TM.is64Bit() && TM.getPointerSizeInBits(ADDRESS_SPACE_CONST) == 32">; -++ -+ def useFP16Math: Predicate<"Subtarget->allowFP16Math()">; -+ def hasBF16Math: Predicate<"Subtarget->hasBF16Math()">; -+ -+@@ -1661,6 +1665,167 @@ -+ "brev.b64 \t$dst, $a;", -+ [(set Int64Regs:$dst, (bitreverse Int64Regs:$a))]>; -+ -++// -++// Rotate: Use ptx shf instruction if available. -++// -++ -++// 32 bit r2 = rotl r1, n -++// => -++// r2 = shf.l r1, r1, n -++def ROTL32imm_hw : -++ NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$src, i32imm:$amt), -++ "shf.l.wrap.b32 \t$dst, $src, $src, $amt;", -++ [(set Int32Regs:$dst, (rotl (i32 Int32Regs:$src), (i32 imm:$amt)))]>, -++ Requires<[hasHWROT32]>; -++ -++def ROTL32reg_hw : -++ NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$src, Int32Regs:$amt), -++ "shf.l.wrap.b32 \t$dst, $src, $src, $amt;", -++ [(set Int32Regs:$dst, (rotl (i32 Int32Regs:$src), (i32 Int32Regs:$amt)))]>, -++ Requires<[hasHWROT32]>; -++ -++// 32 bit r2 = rotr r1, n -++// => -++// r2 = shf.r r1, r1, n -++def ROTR32imm_hw : -++ NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$src, i32imm:$amt), -++ "shf.r.wrap.b32 \t$dst, $src, $src, $amt;", -++ [(set Int32Regs:$dst, (rotr (i32 Int32Regs:$src), (i32 imm:$amt)))]>, -++ Requires<[hasHWROT32]>; -++ -++def ROTR32reg_hw : -++ NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$src, Int32Regs:$amt), -++ "shf.r.wrap.b32 \t$dst, $src, $src, $amt;", -++ [(set Int32Regs:$dst, (rotr (i32 Int32Regs:$src), (i32 Int32Regs:$amt)))]>, -++ Requires<[hasHWROT32]>; -++ -++// 32-bit software rotate by immediate. $amt2 should equal 32 - $amt1. -++def ROT32imm_sw : -++ NVPTXInst<(outs Int32Regs:$dst), -++ (ins Int32Regs:$src, i32imm:$amt1, i32imm:$amt2), -++ "{{\n\t" -++ ".reg .b32 %lhs;\n\t" -++ ".reg .b32 %rhs;\n\t" -++ "shl.b32 \t%lhs, $src, $amt1;\n\t" -++ "shr.b32 \t%rhs, $src, $amt2;\n\t" -++ "add.u32 \t$dst, %lhs, %rhs;\n\t" -++ "}}", -++ []>; -++ -++def SUB_FRM_32 : SDNodeXFormgetTargetConstant(32 - N->getZExtValue(), SDLoc(N), MVT::i32); -++}]>; -++ -++def : Pat<(rotl (i32 Int32Regs:$src), (i32 imm:$amt)), -++ (ROT32imm_sw Int32Regs:$src, imm:$amt, (SUB_FRM_32 node:$amt))>, -++ Requires<[noHWROT32]>; -++def : Pat<(rotr (i32 Int32Regs:$src), (i32 imm:$amt)), -++ (ROT32imm_sw Int32Regs:$src, (SUB_FRM_32 node:$amt), imm:$amt)>, -++ Requires<[noHWROT32]>; -++ -++// 32-bit software rotate left by register. -++def ROTL32reg_sw : -++ NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$src, Int32Regs:$amt), -++ "{{\n\t" -++ ".reg .b32 %lhs;\n\t" -++ ".reg .b32 %rhs;\n\t" -++ ".reg .b32 %amt2;\n\t" -++ "shl.b32 \t%lhs, $src, $amt;\n\t" -++ "sub.s32 \t%amt2, 32, $amt;\n\t" -++ "shr.b32 \t%rhs, $src, %amt2;\n\t" -++ "add.u32 \t$dst, %lhs, %rhs;\n\t" -++ "}}", -++ [(set Int32Regs:$dst, (rotl (i32 Int32Regs:$src), (i32 Int32Regs:$amt)))]>, -++ Requires<[noHWROT32]>; -++ -++// 32-bit software rotate right by register. -++def ROTR32reg_sw : -++ NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$src, Int32Regs:$amt), -++ "{{\n\t" -++ ".reg .b32 %lhs;\n\t" -++ ".reg .b32 %rhs;\n\t" -++ ".reg .b32 %amt2;\n\t" -++ "shr.b32 \t%lhs, $src, $amt;\n\t" -++ "sub.s32 \t%amt2, 32, $amt;\n\t" -++ "shl.b32 \t%rhs, $src, %amt2;\n\t" -++ "add.u32 \t$dst, %lhs, %rhs;\n\t" -++ "}}", -++ [(set Int32Regs:$dst, (rotr (i32 Int32Regs:$src), (i32 Int32Regs:$amt)))]>, -++ Requires<[noHWROT32]>; -++ -++// 64-bit software rotate by immediate. $amt2 should equal 64 - $amt1. -++def ROT64imm_sw : -++ NVPTXInst<(outs Int64Regs:$dst), -++ (ins Int64Regs:$src, i32imm:$amt1, i32imm:$amt2), -++ "{{\n\t" -++ ".reg .b64 %lhs;\n\t" -++ ".reg .b64 %rhs;\n\t" -++ "shl.b64 \t%lhs, $src, $amt1;\n\t" -++ "shr.b64 \t%rhs, $src, $amt2;\n\t" -++ "add.u64 \t$dst, %lhs, %rhs;\n\t" -++ "}}", -++ []>; -++ -++def SUB_FRM_64 : SDNodeXFormgetTargetConstant(64-N->getZExtValue(), SDLoc(N), MVT::i32); -++}]>; -++ -++def : Pat<(rotl Int64Regs:$src, (i32 imm:$amt)), -++ (ROT64imm_sw Int64Regs:$src, imm:$amt, (SUB_FRM_64 node:$amt))>; -++def : Pat<(rotr Int64Regs:$src, (i32 imm:$amt)), -++ (ROT64imm_sw Int64Regs:$src, (SUB_FRM_64 node:$amt), imm:$amt)>; -++ -++// 64-bit software rotate left by register. -++def ROTL64reg_sw : -++ NVPTXInst<(outs Int64Regs:$dst), (ins Int64Regs:$src, Int32Regs:$amt), -++ "{{\n\t" -++ ".reg .b64 %lhs;\n\t" -++ ".reg .b64 %rhs;\n\t" -++ ".reg .u32 %amt2;\n\t" -++ "and.b32 \t%amt2, $amt, 63;\n\t" -++ "shl.b64 \t%lhs, $src, %amt2;\n\t" -++ "sub.u32 \t%amt2, 64, %amt2;\n\t" -++ "shr.b64 \t%rhs, $src, %amt2;\n\t" -++ "add.u64 \t$dst, %lhs, %rhs;\n\t" -++ "}}", -++ [(set Int64Regs:$dst, (rotl Int64Regs:$src, (i32 Int32Regs:$amt)))]>; -++ -++def ROTR64reg_sw : -++ NVPTXInst<(outs Int64Regs:$dst), (ins Int64Regs:$src, Int32Regs:$amt), -++ "{{\n\t" -++ ".reg .b64 %lhs;\n\t" -++ ".reg .b64 %rhs;\n\t" -++ ".reg .u32 %amt2;\n\t" -++ "and.b32 \t%amt2, $amt, 63;\n\t" -++ "shr.b64 \t%lhs, $src, %amt2;\n\t" -++ "sub.u32 \t%amt2, 64, %amt2;\n\t" -++ "shl.b64 \t%rhs, $src, %amt2;\n\t" -++ "add.u64 \t$dst, %lhs, %rhs;\n\t" -++ "}}", -++ [(set Int64Regs:$dst, (rotr Int64Regs:$src, (i32 Int32Regs:$amt)))]>; -++ -++// -++// Funnnel shift in clamp mode -++// -++ -++// Create SDNodes so they can be used in the DAG code, e.g. -++// NVPTXISelLowering (LowerShiftLeftParts and LowerShiftRightParts) -++def FUN_SHFL_CLAMP : SDNode<"NVPTXISD::FUN_SHFL_CLAMP", SDTIntShiftDOp, []>; -++def FUN_SHFR_CLAMP : SDNode<"NVPTXISD::FUN_SHFR_CLAMP", SDTIntShiftDOp, []>; -++ -++def FUNSHFLCLAMP : -++ NVPTXInst<(outs Int32Regs:$dst), -++ (ins Int32Regs:$lo, Int32Regs:$hi, Int32Regs:$amt), -++ "shf.l.clamp.b32 \t$dst, $lo, $hi, $amt;", -++ [(set Int32Regs:$dst, -++ (FUN_SHFL_CLAMP (i32 Int32Regs:$lo), (i32 Int32Regs:$hi), (i32 Int32Regs:$amt)))]>; -++ -++def FUNSHFRCLAMP : -++ NVPTXInst<(outs Int32Regs:$dst), -++ (ins Int32Regs:$lo, Int32Regs:$hi, Int32Regs:$amt), -++ "shf.r.clamp.b32 \t$dst, $lo, $hi, $amt;", -++ [(set Int32Regs:$dst, -++ (FUN_SHFR_CLAMP (i32 Int32Regs:$lo), (i32 Int32Regs:$hi), (i32 Int32Regs:$amt)))]>; -+ -+ // -+ // BFE - bit-field extract -+@@ -3492,42 +3657,6 @@ -+ def: Pat<(v2i16 (scalar_to_vector (i16 Int16Regs:$a))), -+ (CVT_u32_u16 Int16Regs:$a, CvtNONE)>; -+ -+-// -+-// Funnel-Shift -+-// -+- -+-// Create SDNodes so they can be used in the DAG code, e.g. -+-// NVPTXISelLowering (LowerShiftLeftParts and LowerShiftRightParts) -+-def fshl_clamp : SDNode<"NVPTXISD::FUN_SHFL_CLAMP", SDTIntShiftDOp, []>; -+-def fshr_clamp : SDNode<"NVPTXISD::FUN_SHFR_CLAMP", SDTIntShiftDOp, []>; -+- -+-// Funnel shift, requires >= sm_32. Does not trap if amt is out of range, so -+-// no side effects. -+-let hasSideEffects = false in { -+- multiclass ShfInst { -+- def _i -+- : NVPTXInst<(outs Int32Regs:$dst), -+- (ins Int32Regs:$lo, Int32Regs:$hi, i32imm:$amt), -+- "shf." # mode # ".b32 \t$dst, $lo, $hi, $amt;", -+- [(set Int32Regs:$dst, -+- (op (i32 Int32Regs:$lo), (i32 Int32Regs:$hi), (i32 imm:$amt)))]>, -+- Requires<[hasHWROT32]>; -+- -+- def _r -+- : NVPTXInst<(outs Int32Regs:$dst), -+- (ins Int32Regs:$lo, Int32Regs:$hi, Int32Regs:$amt), -+- "shf." # mode # ".b32 \t$dst, $lo, $hi, $amt;", -+- [(set Int32Regs:$dst, -+- (op (i32 Int32Regs:$lo), (i32 Int32Regs:$hi), (i32 Int32Regs:$amt)))]>, -+- Requires<[hasHWROT32]>; -+- } -+- -+- defm SHF_L_CLAMP : ShfInst<"l.clamp", fshl_clamp>; -+- defm SHF_R_CLAMP : ShfInst<"r.clamp", fshr_clamp>; -+- defm SHF_L_WRAP : ShfInst<"l.wrap", fshl>; -+- defm SHF_R_WRAP : ShfInst<"r.wrap", fshr>; -+-} -+- -+ // Count leading zeros -+ let hasSideEffects = false in { -+ def CLZr32 : NVPTXInst<(outs Int32Regs:$d), (ins Int32Regs:$a), -+diff -ruN --strip-trailing-cr a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td -+--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td -++++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td -+@@ -2537,45 +2537,59 @@ -+ : VLDG_G_ELE_V4<"v4.f32 \t{{$dst1, $dst2, $dst3, $dst4}}, [$src];", Float32Regs>; -+ -+ -+-multiclass NG_TO_G { -++multiclass NG_TO_G { -+ def "" : NVPTXInst<(outs Int32Regs:$result), (ins Int32Regs:$src), -+- "cvta." # Str # ".u32 \t$result, $src;", []>; -++ !strconcat("cvta.", Str, ".u32 \t$result, $src;"), -++ [(set Int32Regs:$result, (Intrin Int32Regs:$src))]>; -+ def _64 : NVPTXInst<(outs Int64Regs:$result), (ins Int64Regs:$src), -+- "cvta." # Str # ".u64 \t$result, $src;", []>; -++ !strconcat("cvta.", Str, ".u64 \t$result, $src;"), -++ [(set Int64Regs:$result, (Intrin Int64Regs:$src))]>; -++ def _6432 : NVPTXInst<(outs Int64Regs:$result), (ins Int32Regs:$src), -++ "{{ .reg .b64 %tmp;\n\t" -++ #" cvt.u64.u32 \t%tmp, $src;\n\t" -++ #" cvta." # Str # ".u64 \t$result, %tmp; }}", -++ [(set Int64Regs:$result, (Intrin Int32Regs:$src))]>, -++ Requires<[ShortPtr]>; -+ } -+ -+-multiclass G_TO_NG { -++multiclass G_TO_NG { -+ def "" : NVPTXInst<(outs Int32Regs:$result), (ins Int32Regs:$src), -+- "cvta.to." # Str # ".u32 \t$result, $src;", []>; -++ !strconcat("cvta.to.", Str, ".u32 \t$result, $src;"), -++ [(set Int32Regs:$result, (Intrin Int32Regs:$src))]>; -+ def _64 : NVPTXInst<(outs Int64Regs:$result), (ins Int64Regs:$src), -+- "cvta.to." # Str # ".u64 \t$result, $src;", []>; -++ !strconcat("cvta.to.", Str, ".u64 \t$result, $src;"), -++ [(set Int64Regs:$result, (Intrin Int64Regs:$src))]>; -++ def _3264 : NVPTXInst<(outs Int32Regs:$result), (ins Int64Regs:$src), -++ "{{ .reg .b64 %tmp;\n\t" -++ #" cvta.to." # Str # ".u64 \t%tmp, $src;\n\t" -++ #" cvt.u32.u64 \t$result, %tmp; }}", -++ [(set Int32Regs:$result, (Intrin Int64Regs:$src))]>, -++ Requires<[ShortPtr]>; -+ } -+ -+-defm cvta_local : NG_TO_G<"local">; -+-defm cvta_shared : NG_TO_G<"shared">; -+-defm cvta_global : NG_TO_G<"global">; -+-defm cvta_const : NG_TO_G<"const">; -+- -+-defm cvta_to_local : G_TO_NG<"local">; -+-defm cvta_to_shared : G_TO_NG<"shared">; -+-defm cvta_to_global : G_TO_NG<"global">; -+-defm cvta_to_const : G_TO_NG<"const">; -+- -+-// nvvm.ptr.param.to.gen -+-defm cvta_param : NG_TO_G<"param">; -+- -+-def : Pat<(int_nvvm_ptr_param_to_gen Int32Regs:$src), -+- (cvta_param Int32Regs:$src)>; -+- -+-def : Pat<(int_nvvm_ptr_param_to_gen Int64Regs:$src), -+- (cvta_param_64 Int64Regs:$src)>; -++defm cvta_local : NG_TO_G<"local", int_nvvm_ptr_local_to_gen, useShortPtrLocal>; -++defm cvta_shared : NG_TO_G<"shared", int_nvvm_ptr_shared_to_gen, useShortPtrShared>; -++defm cvta_global : NG_TO_G<"global", int_nvvm_ptr_global_to_gen, False>; -++defm cvta_const : NG_TO_G<"const", int_nvvm_ptr_constant_to_gen, useShortPtrConst>; -++defm cvta_param : NG_TO_G<"param", int_nvvm_ptr_param_to_gen, False>; -++ -++defm cvta_to_local : G_TO_NG<"local", int_nvvm_ptr_gen_to_local, useShortPtrLocal>; -++defm cvta_to_shared : G_TO_NG<"shared", int_nvvm_ptr_gen_to_shared, useShortPtrShared>; -++defm cvta_to_global : G_TO_NG<"global", int_nvvm_ptr_gen_to_global, False>; -++defm cvta_to_const : G_TO_NG<"const", int_nvvm_ptr_gen_to_constant, useShortPtrConst>; -+ -+ // nvvm.ptr.gen.to.param -+-def : Pat<(int_nvvm_ptr_gen_to_param Int32Regs:$src), -+- (IMOV32rr Int32Regs:$src)>; -++def nvvm_ptr_gen_to_param : NVPTXInst<(outs Int32Regs:$result), -++ (ins Int32Regs:$src), -++ "mov.u32 \t$result, $src;", -++ [(set Int32Regs:$result, -++ (int_nvvm_ptr_gen_to_param Int32Regs:$src))]>; -++def nvvm_ptr_gen_to_param_64 : NVPTXInst<(outs Int64Regs:$result), -++ (ins Int64Regs:$src), -++ "mov.u64 \t$result, $src;", -++ [(set Int64Regs:$result, -++ (int_nvvm_ptr_gen_to_param Int64Regs:$src))]>; -+ -+-def : Pat<(int_nvvm_ptr_gen_to_param Int64Regs:$src), -+- (IMOV64rr Int64Regs:$src)>; -+ -+ // nvvm.move intrinsicc -+ def nvvm_move_i16 : NVPTXInst<(outs Int16Regs:$r), (ins Int16Regs:$s), -+@@ -2618,6 +2632,24 @@ -+ [(set Int64Regs:$r, -+ (int_nvvm_move_ptr texternalsym:$s))]>;*/ -+ -++ -++// MoveParam %r1, param -++// ptr_local_to_gen %r2, %r1 -++// ptr_gen_to_local %r3, %r2 -++// -> -++// mov %r1, param -++ -++// @TODO: Revisit this. There is a type -++// contradiction between iPTRAny and iPTR for the addr defs, so the move_sym -++// instructions are not currently defined. However, we can use the ptr -++// variants and the asm printer will do the right thing. -++def : Pat<(i64 (int_nvvm_ptr_gen_to_local (int_nvvm_ptr_local_to_gen -++ (MoveParam texternalsym:$src)))), -++ (nvvm_move_ptr64 texternalsym:$src)>; -++def : Pat<(i32 (int_nvvm_ptr_gen_to_local (int_nvvm_ptr_local_to_gen -++ (MoveParam texternalsym:$src)))), -++ (nvvm_move_ptr32 texternalsym:$src)>; -++ -+ def texsurf_handles -+ : NVPTXInst<(outs Int64Regs:$result), (ins imem:$src), -+ "mov.u64 \t$result, $src;", []>; -+@@ -2701,9 +2733,134 @@ -+ def : Pat<(int_nvvm_read_ptx_sreg_envreg31), (MOV_SPECIAL ENVREG31)>; -+ -+ -++// rotate builtin support -++ -++def ROTATE_B32_HW_IMM -++ : NVPTXInst<(outs Int32Regs:$dst), -++ (ins Int32Regs:$src, i32imm:$amt), -++ "shf.l.wrap.b32 \t$dst, $src, $src, $amt;", -++ [(set Int32Regs:$dst, -++ (int_nvvm_rotate_b32 Int32Regs:$src, (i32 imm:$amt)))]>, -++ Requires<[hasHWROT32]> ; -++ -++def ROTATE_B32_HW_REG -++ : NVPTXInst<(outs Int32Regs:$dst), -++ (ins Int32Regs:$src, Int32Regs:$amt), -++ "shf.l.wrap.b32 \t$dst, $src, $src, $amt;", -++ [(set Int32Regs:$dst, -++ (int_nvvm_rotate_b32 Int32Regs:$src, Int32Regs:$amt))]>, -++ Requires<[hasHWROT32]> ; -++ -++def : Pat<(int_nvvm_rotate_b32 Int32Regs:$src, (i32 imm:$amt)), -++ (ROT32imm_sw Int32Regs:$src, imm:$amt, (SUB_FRM_32 node:$amt))>, -++ Requires<[noHWROT32]> ; -++ -++def : Pat<(int_nvvm_rotate_b32 Int32Regs:$src, Int32Regs:$amt), -++ (ROTL32reg_sw Int32Regs:$src, Int32Regs:$amt)>, -++ Requires<[noHWROT32]> ; -++ -++let hasSideEffects = false in { -++ def GET_LO_INT64 : NVPTXInst<(outs Int32Regs:$dst), (ins Int64Regs:$src), -++ !strconcat("{{\n\t", -++ ".reg .b32 %dummy;\n\t", -++ "mov.b64 \t{$dst,%dummy}, $src;\n\t", -++ "}}"), -++ []> ; -++ -++ def GET_HI_INT64 : NVPTXInst<(outs Int32Regs:$dst), (ins Int64Regs:$src), -++ !strconcat("{{\n\t", -++ ".reg .b32 %dummy;\n\t", -++ "mov.b64 \t{%dummy,$dst}, $src;\n\t", -++ "}}"), -++ []> ; -++} -++ -++let hasSideEffects = false in { -++ def PACK_TWO_INT32 -++ : NVPTXInst<(outs Int64Regs:$dst), (ins Int32Regs:$lo, Int32Regs:$hi), -++ "mov.b64 \t$dst, {{$lo, $hi}};", []> ; -++} -++ -+ def : Pat<(int_nvvm_swap_lo_hi_b64 Int64Regs:$src), -+- (V2I32toI64 (I64toI32H Int64Regs:$src), -+- (I64toI32L Int64Regs:$src))> ; -++ (PACK_TWO_INT32 (GET_HI_INT64 Int64Regs:$src), -++ (GET_LO_INT64 Int64Regs:$src))> ; -++ -++// Funnel shift, requires >= sm_32. Does not trap if amt is out of range, so -++// no side effects. -++let hasSideEffects = false in { -++ def SHF_L_WRAP_B32_IMM -++ : NVPTXInst<(outs Int32Regs:$dst), -++ (ins Int32Regs:$lo, Int32Regs:$hi, i32imm:$amt), -++ "shf.l.wrap.b32 \t$dst, $lo, $hi, $amt;",[]>, -++ Requires<[hasHWROT32]>; -++ -++ def SHF_L_WRAP_B32_REG -++ : NVPTXInst<(outs Int32Regs:$dst), -++ (ins Int32Regs:$lo, Int32Regs:$hi, Int32Regs:$amt), -++ "shf.l.wrap.b32 \t$dst, $lo, $hi, $amt;",[]>, -++ Requires<[hasHWROT32]>; -++ -++ def SHF_R_WRAP_B32_IMM -++ : NVPTXInst<(outs Int32Regs:$dst), -++ (ins Int32Regs:$lo, Int32Regs:$hi, i32imm:$amt), -++ "shf.r.wrap.b32 \t$dst, $lo, $hi, $amt;",[]>, -++ Requires<[hasHWROT32]>; -++ -++ def SHF_R_WRAP_B32_REG -++ : NVPTXInst<(outs Int32Regs:$dst), -++ (ins Int32Regs:$lo, Int32Regs:$hi, Int32Regs:$amt), -++ "shf.r.wrap.b32 \t$dst, $lo, $hi, $amt;",[]>, -++ Requires<[hasHWROT32]>; -++} -++ -++// HW version of rotate 64 -++def : Pat<(int_nvvm_rotate_b64 Int64Regs:$src, (i32 imm:$amt)), -++ (PACK_TWO_INT32 -++ (SHF_L_WRAP_B32_IMM (GET_HI_INT64 Int64Regs:$src), -++ (GET_LO_INT64 Int64Regs:$src), imm:$amt), -++ (SHF_L_WRAP_B32_IMM (GET_LO_INT64 Int64Regs:$src), -++ (GET_HI_INT64 Int64Regs:$src), imm:$amt))>, -++ Requires<[hasHWROT32]>; -++ -++def : Pat<(int_nvvm_rotate_b64 Int64Regs:$src, Int32Regs:$amt), -++ (PACK_TWO_INT32 -++ (SHF_L_WRAP_B32_REG (GET_HI_INT64 Int64Regs:$src), -++ (GET_LO_INT64 Int64Regs:$src), Int32Regs:$amt), -++ (SHF_L_WRAP_B32_REG (GET_LO_INT64 Int64Regs:$src), -++ (GET_HI_INT64 Int64Regs:$src), Int32Regs:$amt))>, -++ Requires<[hasHWROT32]>; -++ -++ -++def : Pat<(int_nvvm_rotate_right_b64 Int64Regs:$src, (i32 imm:$amt)), -++ (PACK_TWO_INT32 -++ (SHF_R_WRAP_B32_IMM (GET_LO_INT64 Int64Regs:$src), -++ (GET_HI_INT64 Int64Regs:$src), imm:$amt), -++ (SHF_R_WRAP_B32_IMM (GET_HI_INT64 Int64Regs:$src), -++ (GET_LO_INT64 Int64Regs:$src), imm:$amt))>, -++ Requires<[hasHWROT32]>; -++ -++def : Pat<(int_nvvm_rotate_right_b64 Int64Regs:$src, Int32Regs:$amt), -++ (PACK_TWO_INT32 -++ (SHF_R_WRAP_B32_REG (GET_LO_INT64 Int64Regs:$src), -++ (GET_HI_INT64 Int64Regs:$src), Int32Regs:$amt), -++ (SHF_R_WRAP_B32_REG (GET_HI_INT64 Int64Regs:$src), -++ (GET_LO_INT64 Int64Regs:$src), Int32Regs:$amt))>, -++ Requires<[hasHWROT32]>; -++ -++// SW version of rotate 64 -++def : Pat<(int_nvvm_rotate_b64 Int64Regs:$src, (i32 imm:$amt)), -++ (ROT64imm_sw Int64Regs:$src, imm:$amt, (SUB_FRM_64 node:$amt))>, -++ Requires<[noHWROT32]>; -++def : Pat<(int_nvvm_rotate_b64 Int64Regs:$src, Int32Regs:$amt), -++ (ROTL64reg_sw Int64Regs:$src, Int32Regs:$amt)>, -++ Requires<[noHWROT32]>; -++def : Pat<(int_nvvm_rotate_right_b64 Int64Regs:$src, (i32 imm:$amt)), -++ (ROT64imm_sw Int64Regs:$src, (SUB_FRM_64 node:$amt), imm:$amt)>, -++ Requires<[noHWROT32]>; -++def : Pat<(int_nvvm_rotate_right_b64 Int64Regs:$src, Int32Regs:$amt), -++ (ROTR64reg_sw Int64Regs:$src, Int32Regs:$amt)>, -++ Requires<[noHWROT32]>; -++ -+ -+ //----------------------------------- -+ // Texture Intrinsics -+diff -ruN --strip-trailing-cr a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp -+--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp -++++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp -+@@ -1109,21 +1109,11 @@ -+ AddrSpaceCastSDNode *CastN = cast(N); -+ unsigned SrcAddrSpace = CastN->getSrcAddressSpace(); -+ unsigned DstAddrSpace = CastN->getDestAddressSpace(); -+- SDLoc DL(N); -+ assert(SrcAddrSpace != DstAddrSpace && -+ "addrspacecast must be between different address spaces"); -+ -+ if (DstAddrSpace == ADDRESS_SPACE_GENERIC) { -+ // Specific to generic -+- -+- if (TM.is64Bit() && TM.getPointerSizeInBits(SrcAddrSpace) == 32) { -+- SDValue CvtNone = -+- CurDAG->getTargetConstant(NVPTX::PTXCvtMode::NONE, DL, MVT::i32); -+- SDNode *Cvt = CurDAG->getMachineNode(NVPTX::CVT_u64_u32, DL, MVT::i64, -+- Src, CvtNone); -+- Src = SDValue(Cvt, 0); -+- } -+- -+ unsigned Opc; -+ switch (SrcAddrSpace) { -+ default: report_fatal_error("Bad address space in addrspacecast"); -+@@ -1131,16 +1121,26 @@ -+ Opc = TM.is64Bit() ? NVPTX::cvta_global_64 : NVPTX::cvta_global; -+ break; -+ case ADDRESS_SPACE_SHARED: -+- Opc = TM.is64Bit() ? NVPTX::cvta_shared_64 : NVPTX::cvta_shared; -++ Opc = TM.is64Bit() ? (TM.getPointerSizeInBits(SrcAddrSpace) == 32 -++ ? NVPTX::cvta_shared_6432 -++ : NVPTX::cvta_shared_64) -++ : NVPTX::cvta_shared; -+ break; -+ case ADDRESS_SPACE_CONST: -+- Opc = TM.is64Bit() ? NVPTX::cvta_const_64 : NVPTX::cvta_const; -++ Opc = TM.is64Bit() ? (TM.getPointerSizeInBits(SrcAddrSpace) == 32 -++ ? NVPTX::cvta_const_6432 -++ : NVPTX::cvta_const_64) -++ : NVPTX::cvta_const; -+ break; -+ case ADDRESS_SPACE_LOCAL: -+- Opc = TM.is64Bit() ? NVPTX::cvta_local_64 : NVPTX::cvta_local; -++ Opc = TM.is64Bit() ? (TM.getPointerSizeInBits(SrcAddrSpace) == 32 -++ ? NVPTX::cvta_local_6432 -++ : NVPTX::cvta_local_64) -++ : NVPTX::cvta_local; -+ break; -+ } -+- ReplaceNode(N, CurDAG->getMachineNode(Opc, DL, N->getValueType(0), Src)); -++ ReplaceNode(N, CurDAG->getMachineNode(Opc, SDLoc(N), N->getValueType(0), -++ Src)); -+ return; -+ } else { -+ // Generic to specific -+@@ -1153,28 +1153,30 @@ -+ Opc = TM.is64Bit() ? NVPTX::cvta_to_global_64 : NVPTX::cvta_to_global; -+ break; -+ case ADDRESS_SPACE_SHARED: -+- Opc = TM.is64Bit() ? NVPTX::cvta_to_shared_64 : NVPTX::cvta_to_shared; -++ Opc = TM.is64Bit() ? (TM.getPointerSizeInBits(DstAddrSpace) == 32 -++ ? NVPTX::cvta_to_shared_3264 -++ : NVPTX::cvta_to_shared_64) -++ : NVPTX::cvta_to_shared; -+ break; -+ case ADDRESS_SPACE_CONST: -+- Opc = TM.is64Bit() ? NVPTX::cvta_to_const_64 : NVPTX::cvta_to_const; -++ Opc = TM.is64Bit() ? (TM.getPointerSizeInBits(DstAddrSpace) == 32 -++ ? NVPTX::cvta_to_const_3264 -++ : NVPTX::cvta_to_const_64) -++ : NVPTX::cvta_to_const; -+ break; -+ case ADDRESS_SPACE_LOCAL: -+- Opc = TM.is64Bit() ? NVPTX::cvta_to_local_64 : NVPTX::cvta_to_local; -++ Opc = TM.is64Bit() ? (TM.getPointerSizeInBits(DstAddrSpace) == 32 -++ ? NVPTX::cvta_to_local_3264 -++ : NVPTX::cvta_to_local_64) -++ : NVPTX::cvta_to_local; -+ break; -+ case ADDRESS_SPACE_PARAM: -+- Opc = TM.is64Bit() ? NVPTX::IMOV64rr : NVPTX::IMOV32rr; -++ Opc = TM.is64Bit() ? NVPTX::nvvm_ptr_gen_to_param_64 -++ : NVPTX::nvvm_ptr_gen_to_param; -+ break; -+ } -+- -+- SDNode *CVTA = CurDAG->getMachineNode(Opc, DL, N->getValueType(0), Src); -+- if (TM.is64Bit() && TM.getPointerSizeInBits(DstAddrSpace) == 32) { -+- SDValue CvtNone = -+- CurDAG->getTargetConstant(NVPTX::PTXCvtMode::NONE, DL, MVT::i32); -+- CVTA = CurDAG->getMachineNode(NVPTX::CVT_u32_u64, DL, MVT::i32, -+- SDValue(CVTA, 0), CvtNone); -+- } -+- -+- ReplaceNode(N, CVTA); -++ ReplaceNode(N, CurDAG->getMachineNode(Opc, SDLoc(N), N->getValueType(0), -++ Src)); -+ return; -+ } -+ } -+diff -ruN --strip-trailing-cr a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp -+--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp -++++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp -+@@ -594,13 +594,20 @@ -+ setOperationAction(ISD::BITREVERSE, MVT::i32, Legal); -+ setOperationAction(ISD::BITREVERSE, MVT::i64, Legal); -+ -+- setOperationAction({ISD::ROTL, ISD::ROTR}, -+- {MVT::i8, MVT::i16, MVT::v2i16, MVT::i32, MVT::i64}, -+- Expand); -+- -+- if (STI.hasHWROT32()) -+- setOperationAction({ISD::FSHL, ISD::FSHR}, MVT::i32, Legal); -++ // TODO: we may consider expanding ROTL/ROTR on older GPUs. Currently on GPUs -++ // that don't have h/w rotation we lower them to multi-instruction assembly. -++ // See ROT*_sw in NVPTXIntrInfo.td -++ setOperationAction(ISD::ROTL, MVT::i64, Legal); -++ setOperationAction(ISD::ROTR, MVT::i64, Legal); -++ setOperationAction(ISD::ROTL, MVT::i32, Legal); -++ setOperationAction(ISD::ROTR, MVT::i32, Legal); -+ -++ setOperationAction(ISD::ROTL, MVT::i16, Expand); -++ setOperationAction(ISD::ROTL, MVT::v2i16, Expand); -++ setOperationAction(ISD::ROTR, MVT::i16, Expand); -++ setOperationAction(ISD::ROTR, MVT::v2i16, Expand); -++ setOperationAction(ISD::ROTL, MVT::i8, Expand); -++ setOperationAction(ISD::ROTR, MVT::i8, Expand); -+ setOperationAction(ISD::BSWAP, MVT::i16, Expand); -+ -+ setOperationAction(ISD::BR_JT, MVT::Other, Custom); -+diff -ruN --strip-trailing-cr a/llvm/test/Assembler/auto_upgrade_nvvm_intrinsics.ll b/llvm/test/Assembler/auto_upgrade_nvvm_intrinsics.ll -+--- a/llvm/test/Assembler/auto_upgrade_nvvm_intrinsics.ll -++++ b/llvm/test/Assembler/auto_upgrade_nvvm_intrinsics.ll -+@@ -31,19 +31,6 @@ -+ declare i64 @llvm.nvvm.bitcast.d2ll(double) -+ declare double @llvm.nvvm.bitcast.ll2d(i64) -+ -+-declare i32 @llvm.nvvm.rotate.b32(i32, i32) -+-declare i64 @llvm.nvvm.rotate.right.b64(i64, i32) -+-declare i64 @llvm.nvvm.rotate.b64(i64, i32) -+- -+-declare ptr addrspace(1) @llvm.nvvm.ptr.gen.to.global.p1.p0(ptr) -+-declare ptr addrspace(3) @llvm.nvvm.ptr.gen.to.shared.p3.p0(ptr) -+-declare ptr addrspace(4) @llvm.nvvm.ptr.gen.to.constant.p4.p0(ptr) -+-declare ptr addrspace(5) @llvm.nvvm.ptr.gen.to.local.p5.p0(ptr) -+-declare ptr @llvm.nvvm.ptr.global.to.gen.p0.p1(ptr addrspace(1)) -+-declare ptr @llvm.nvvm.ptr.shared.to.gen.p0.p3(ptr addrspace(3)) -+-declare ptr @llvm.nvvm.ptr.constant.to.gen.p0.p4(ptr addrspace(4)) -+-declare ptr @llvm.nvvm.ptr.local.to.gen.p0.p5(ptr addrspace(5)) -+- -+ ; CHECK-LABEL: @simple_upgrade -+ define void @simple_upgrade(i32 %a, i64 %b, i16 %c) { -+ ; CHECK: call i32 @llvm.bitreverse.i32(i32 %a) -+@@ -152,42 +139,4 @@ -+ %r4 = call double @llvm.nvvm.bitcast.ll2d(i64 %b) -+ -+ ret void -+-} -+- -+-; CHECK-LABEL: @rotate -+-define void @rotate(i32 %a, i64 %b) { -+-; CHECK: call i32 @llvm.fshl.i32(i32 %a, i32 %a, i32 6) -+-; CHECK: call i64 @llvm.fshr.i64(i64 %b, i64 %b, i64 7) -+-; CHECK: call i64 @llvm.fshl.i64(i64 %b, i64 %b, i64 8) -+-; -+- %r1 = call i32 @llvm.nvvm.rotate.b32(i32 %a, i32 6) -+- %r2 = call i64 @llvm.nvvm.rotate.right.b64(i64 %b, i32 7) -+- %r3 = call i64 @llvm.nvvm.rotate.b64(i64 %b, i32 8) -+- ret void -+-} -+- -+-; CHECK-LABEL: @addrspacecast -+-define void @addrspacecast(ptr %p0) { -+-; CHECK: %1 = addrspacecast ptr %p0 to ptr addrspace(1) -+-; CHECK: %2 = addrspacecast ptr addrspace(1) %1 to ptr -+-; CHECK: %3 = addrspacecast ptr %2 to ptr addrspace(3) -+-; CHECK: %4 = addrspacecast ptr addrspace(3) %3 to ptr -+-; CHECK: %5 = addrspacecast ptr %4 to ptr addrspace(4) -+-; CHECK: %6 = addrspacecast ptr addrspace(4) %5 to ptr -+-; CHECK: %7 = addrspacecast ptr %6 to ptr addrspace(5) -+-; CHECK: %8 = addrspacecast ptr addrspace(5) %7 to ptr -+-; -+- %p1 = call ptr addrspace(1) @llvm.nvvm.ptr.gen.to.global.p1.p0(ptr %p0) -+- %p2 = call ptr @llvm.nvvm.ptr.global.to.gen.p0.p1(ptr addrspace(1) %p1) -+- -+- %p3 = call ptr addrspace(3) @llvm.nvvm.ptr.gen.to.shared.p3.p0(ptr %p2) -+- %p4 = call ptr @llvm.nvvm.ptr.shared.to.gen.p0.p3(ptr addrspace(3) %p3) -+- -+- %p5 = call ptr addrspace(4) @llvm.nvvm.ptr.gen.to.constant.p4.p0(ptr %p4) -+- %p6 = call ptr @llvm.nvvm.ptr.constant.to.gen.p0.p4(ptr addrspace(4) %p5) -+- -+- %p7 = call ptr addrspace(5) @llvm.nvvm.ptr.gen.to.local.p5.p0(ptr %p6) -+- %p8 = call ptr @llvm.nvvm.ptr.local.to.gen.p0.p5(ptr addrspace(5) %p7) -+- -+- ret void -+-} -++} -+\ No newline at end of file -+diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/AMDGPU/freeze.ll b/llvm/test/CodeGen/AMDGPU/freeze.ll -+--- a/llvm/test/CodeGen/AMDGPU/freeze.ll -++++ b/llvm/test/CodeGen/AMDGPU/freeze.ll -+@@ -1,1856 +0,0 @@ -+-; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py -+-; RUN: llc -global-isel=0 -mtriple=amdgcn-mesa-mesa3d -mcpu=gfx1010 < %s | FileCheck -check-prefixes=GFX10,GFX10-SDAG %s -+-; RUN: llc -global-isel=1 -mtriple=amdgcn-mesa-mesa3d -mcpu=gfx1010 < %s | FileCheck -check-prefixes=GFX10,GFX10-GISEL %s -+-; RUN: llc -global-isel=0 -mtriple=amdgcn-mesa-mesa3d -mcpu=gfx1100 -amdgpu-enable-delay-alu=0 < %s | FileCheck -check-prefixes=GFX11,GFX11-SDAG %s -+-; RUN: llc -global-isel=1 -mtriple=amdgcn-mesa-mesa3d -mcpu=gfx1100 -amdgpu-enable-delay-alu=0 < %s | FileCheck -check-prefixes=GFX11,GFX11-GISEL %s -+- -+-define void @freeze_v2i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { -+-; GFX10-LABEL: freeze_v2i32: -+-; GFX10: ; %bb.0: -+-; GFX10-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX10-NEXT: global_load_dwordx2 v[0:1], v[0:1], off -+-; GFX10-NEXT: s_waitcnt vmcnt(0) -+-; GFX10-NEXT: global_store_dwordx2 v[2:3], v[0:1], off -+-; GFX10-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX11-LABEL: freeze_v2i32: -+-; GFX11: ; %bb.0: -+-; GFX11-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX11-NEXT: global_load_b64 v[0:1], v[0:1], off -+-; GFX11-NEXT: s_waitcnt vmcnt(0) -+-; GFX11-NEXT: global_store_b64 v[2:3], v[0:1], off -+-; GFX11-NEXT: s_setpc_b64 s[30:31] -+- %a = load <2 x i32>, ptr addrspace(1) %ptra, align 4 -+- %freeze = freeze <2 x i32> %a -+- store <2 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 -+- ret void -+-} -+- -+-define void @freeze_v3i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { -+-; GFX10-LABEL: freeze_v3i32: -+-; GFX10: ; %bb.0: -+-; GFX10-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX10-NEXT: global_load_dwordx3 v[4:6], v[0:1], off -+-; GFX10-NEXT: s_waitcnt vmcnt(0) -+-; GFX10-NEXT: global_store_dwordx3 v[2:3], v[4:6], off -+-; GFX10-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX11-LABEL: freeze_v3i32: -+-; GFX11: ; %bb.0: -+-; GFX11-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX11-NEXT: global_load_b96 v[4:6], v[0:1], off -+-; GFX11-NEXT: s_waitcnt vmcnt(0) -+-; GFX11-NEXT: global_store_b96 v[2:3], v[4:6], off -+-; GFX11-NEXT: s_setpc_b64 s[30:31] -+- %a = load <3 x i32>, ptr addrspace(1) %ptra, align 4 -+- %freeze = freeze <3 x i32> %a -+- store <3 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 -+- ret void -+-} -+- -+-define void @freeze_v4i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { -+-; GFX10-LABEL: freeze_v4i32: -+-; GFX10: ; %bb.0: -+-; GFX10-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX10-NEXT: global_load_dwordx4 v[4:7], v[0:1], off -+-; GFX10-NEXT: s_waitcnt vmcnt(0) -+-; GFX10-NEXT: global_store_dwordx4 v[2:3], v[4:7], off -+-; GFX10-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX11-LABEL: freeze_v4i32: -+-; GFX11: ; %bb.0: -+-; GFX11-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX11-NEXT: global_load_b128 v[4:7], v[0:1], off -+-; GFX11-NEXT: s_waitcnt vmcnt(0) -+-; GFX11-NEXT: global_store_b128 v[2:3], v[4:7], off -+-; GFX11-NEXT: s_setpc_b64 s[30:31] -+- %a = load <4 x i32>, ptr addrspace(1) %ptra, align 4 -+- %freeze = freeze <4 x i32> %a -+- store <4 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 -+- ret void -+-} -+- -+-define void @freeze_v5i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { -+-; GFX10-SDAG-LABEL: freeze_v5i32: -+-; GFX10-SDAG: ; %bb.0: -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX10-SDAG-NEXT: s_clause 0x1 -+-; GFX10-SDAG-NEXT: global_load_dword v8, v[0:1], off offset:16 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) -+-; GFX10-SDAG-NEXT: global_store_dword v[2:3], v8, off offset:16 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off -+-; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX10-GISEL-LABEL: freeze_v5i32: -+-; GFX10-GISEL: ; %bb.0: -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX10-GISEL-NEXT: s_clause 0x1 -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off -+-; GFX10-GISEL-NEXT: global_load_dword v8, v[0:1], off offset:16 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) -+-; GFX10-GISEL-NEXT: global_store_dword v[2:3], v8, off offset:16 -+-; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX11-SDAG-LABEL: freeze_v5i32: -+-; GFX11-SDAG: ; %bb.0: -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX11-SDAG-NEXT: s_clause 0x1 -+-; GFX11-SDAG-NEXT: global_load_b32 v8, v[0:1], off offset:16 -+-; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) -+-; GFX11-SDAG-NEXT: global_store_b32 v[2:3], v8, off offset:16 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off -+-; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX11-GISEL-LABEL: freeze_v5i32: -+-; GFX11-GISEL: ; %bb.0: -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX11-GISEL-NEXT: s_clause 0x1 -+-; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off -+-; GFX11-GISEL-NEXT: global_load_b32 v0, v[0:1], off offset:16 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) -+-; GFX11-GISEL-NEXT: global_store_b32 v[2:3], v0, off offset:16 -+-; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] -+- %a = load <5 x i32>, ptr addrspace(1) %ptra, align 4 -+- %freeze = freeze <5 x i32> %a -+- store <5 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 -+- ret void -+-} -+- -+-define void @freeze_v6i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { -+-; GFX10-SDAG-LABEL: freeze_v6i32: -+-; GFX10-SDAG: ; %bb.0: -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX10-SDAG-NEXT: s_clause 0x1 -+-; GFX10-SDAG-NEXT: global_load_dwordx2 v[8:9], v[0:1], off offset:16 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) -+-; GFX10-SDAG-NEXT: global_store_dwordx2 v[2:3], v[8:9], off offset:16 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off -+-; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX10-GISEL-LABEL: freeze_v6i32: -+-; GFX10-GISEL: ; %bb.0: -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX10-GISEL-NEXT: s_clause 0x1 -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off -+-; GFX10-GISEL-NEXT: global_load_dwordx2 v[8:9], v[0:1], off offset:16 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) -+-; GFX10-GISEL-NEXT: global_store_dwordx2 v[2:3], v[8:9], off offset:16 -+-; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX11-SDAG-LABEL: freeze_v6i32: -+-; GFX11-SDAG: ; %bb.0: -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX11-SDAG-NEXT: s_clause 0x1 -+-; GFX11-SDAG-NEXT: global_load_b64 v[8:9], v[0:1], off offset:16 -+-; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) -+-; GFX11-SDAG-NEXT: global_store_b64 v[2:3], v[8:9], off offset:16 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off -+-; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX11-GISEL-LABEL: freeze_v6i32: -+-; GFX11-GISEL: ; %bb.0: -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX11-GISEL-NEXT: s_clause 0x1 -+-; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off -+-; GFX11-GISEL-NEXT: global_load_b64 v[0:1], v[0:1], off offset:16 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) -+-; GFX11-GISEL-NEXT: global_store_b64 v[2:3], v[0:1], off offset:16 -+-; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] -+- %a = load <6 x i32>, ptr addrspace(1) %ptra, align 4 -+- %freeze = freeze <6 x i32> %a -+- store <6 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 -+- ret void -+-} -+- -+-define void @freeze_v7i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { -+-; GFX10-SDAG-LABEL: freeze_v7i32: -+-; GFX10-SDAG: ; %bb.0: -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX10-SDAG-NEXT: s_clause 0x1 -+-; GFX10-SDAG-NEXT: global_load_dwordx3 v[8:10], v[0:1], off offset:16 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) -+-; GFX10-SDAG-NEXT: global_store_dwordx3 v[2:3], v[8:10], off offset:16 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off -+-; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX10-GISEL-LABEL: freeze_v7i32: -+-; GFX10-GISEL: ; %bb.0: -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX10-GISEL-NEXT: s_clause 0x1 -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off -+-; GFX10-GISEL-NEXT: global_load_dwordx3 v[8:10], v[0:1], off offset:16 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) -+-; GFX10-GISEL-NEXT: global_store_dwordx3 v[2:3], v[8:10], off offset:16 -+-; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX11-SDAG-LABEL: freeze_v7i32: -+-; GFX11-SDAG: ; %bb.0: -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX11-SDAG-NEXT: s_clause 0x1 -+-; GFX11-SDAG-NEXT: global_load_b96 v[8:10], v[0:1], off offset:16 -+-; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) -+-; GFX11-SDAG-NEXT: global_store_b96 v[2:3], v[8:10], off offset:16 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off -+-; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX11-GISEL-LABEL: freeze_v7i32: -+-; GFX11-GISEL: ; %bb.0: -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX11-GISEL-NEXT: s_clause 0x1 -+-; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off -+-; GFX11-GISEL-NEXT: global_load_b96 v[8:10], v[0:1], off offset:16 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) -+-; GFX11-GISEL-NEXT: global_store_b96 v[2:3], v[8:10], off offset:16 -+-; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] -+- %a = load <7 x i32>, ptr addrspace(1) %ptra, align 4 -+- %freeze = freeze <7 x i32> %a -+- store <7 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 -+- ret void -+-} -+- -+-define void @freeze_v8i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { -+-; GFX10-SDAG-LABEL: freeze_v8i32: -+-; GFX10-SDAG: ; %bb.0: -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX10-SDAG-NEXT: s_clause 0x1 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:16 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:16 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off -+-; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX10-GISEL-LABEL: freeze_v8i32: -+-; GFX10-GISEL: ; %bb.0: -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX10-GISEL-NEXT: s_clause 0x1 -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 -+-; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX11-SDAG-LABEL: freeze_v8i32: -+-; GFX11-SDAG: ; %bb.0: -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX11-SDAG-NEXT: s_clause 0x1 -+-; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:16 -+-; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:16 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off -+-; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX11-GISEL-LABEL: freeze_v8i32: -+-; GFX11-GISEL: ; %bb.0: -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX11-GISEL-NEXT: s_clause 0x1 -+-; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off -+-; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 -+-; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] -+- %a = load <8 x i32>, ptr addrspace(1) %ptra, align 4 -+- %freeze = freeze <8 x i32> %a -+- store <8 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 -+- ret void -+-} -+- -+-define void @freeze_v9i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { -+-; GFX10-SDAG-LABEL: freeze_v9i32: -+-; GFX10-SDAG: ; %bb.0: -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX10-SDAG-NEXT: s_clause 0x2 -+-; GFX10-SDAG-NEXT: global_load_dword v12, v[0:1], off offset:32 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) -+-; GFX10-SDAG-NEXT: global_store_dword v[2:3], v12, off offset:32 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 -+-; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX10-GISEL-LABEL: freeze_v9i32: -+-; GFX10-GISEL: ; %bb.0: -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX10-GISEL-NEXT: s_clause 0x2 -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 -+-; GFX10-GISEL-NEXT: global_load_dword v12, v[0:1], off offset:32 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) -+-; GFX10-GISEL-NEXT: global_store_dword v[2:3], v12, off offset:32 -+-; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX11-SDAG-LABEL: freeze_v9i32: -+-; GFX11-SDAG: ; %bb.0: -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX11-SDAG-NEXT: s_clause 0x2 -+-; GFX11-SDAG-NEXT: global_load_b32 v12, v[0:1], off offset:32 -+-; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off -+-; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) -+-; GFX11-SDAG-NEXT: global_store_b32 v[2:3], v12, off offset:32 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 -+-; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX11-GISEL-LABEL: freeze_v9i32: -+-; GFX11-GISEL: ; %bb.0: -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX11-GISEL-NEXT: s_clause 0x2 -+-; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off -+-; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 -+-; GFX11-GISEL-NEXT: global_load_b32 v0, v[0:1], off offset:32 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) -+-; GFX11-GISEL-NEXT: global_store_b32 v[2:3], v0, off offset:32 -+-; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] -+- %a = load <9 x i32>, ptr addrspace(1) %ptra, align 4 -+- %freeze = freeze <9 x i32> %a -+- store <9 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 -+- ret void -+-} -+- -+-define void @freeze_v10i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { -+-; GFX10-LABEL: freeze_v10i32: -+-; GFX10: ; %bb.0: -+-; GFX10-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX10-NEXT: s_clause 0x2 -+-; GFX10-NEXT: global_load_dwordx4 v[4:7], v[0:1], off -+-; GFX10-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 -+-; GFX10-NEXT: global_load_dwordx2 v[12:13], v[0:1], off offset:32 -+-; GFX10-NEXT: s_waitcnt vmcnt(2) -+-; GFX10-NEXT: global_store_dwordx4 v[2:3], v[4:7], off -+-; GFX10-NEXT: s_waitcnt vmcnt(1) -+-; GFX10-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 -+-; GFX10-NEXT: s_waitcnt vmcnt(0) -+-; GFX10-NEXT: global_store_dwordx2 v[2:3], v[12:13], off offset:32 -+-; GFX10-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX11-LABEL: freeze_v10i32: -+-; GFX11: ; %bb.0: -+-; GFX11-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX11-NEXT: s_clause 0x2 -+-; GFX11-NEXT: global_load_b128 v[4:7], v[0:1], off -+-; GFX11-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 -+-; GFX11-NEXT: global_load_b64 v[0:1], v[0:1], off offset:32 -+-; GFX11-NEXT: s_waitcnt vmcnt(2) -+-; GFX11-NEXT: global_store_b128 v[2:3], v[4:7], off -+-; GFX11-NEXT: s_waitcnt vmcnt(1) -+-; GFX11-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 -+-; GFX11-NEXT: s_waitcnt vmcnt(0) -+-; GFX11-NEXT: global_store_b64 v[2:3], v[0:1], off offset:32 -+-; GFX11-NEXT: s_setpc_b64 s[30:31] -+- %a = load <10 x i32>, ptr addrspace(1) %ptra, align 4 -+- %freeze = freeze <10 x i32> %a -+- store <10 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 -+- ret void -+-} -+- -+-define void @freeze_v11i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { -+-; GFX10-LABEL: freeze_v11i32: -+-; GFX10: ; %bb.0: -+-; GFX10-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX10-NEXT: s_clause 0x2 -+-; GFX10-NEXT: global_load_dwordx4 v[4:7], v[0:1], off -+-; GFX10-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 -+-; GFX10-NEXT: global_load_dwordx3 v[12:14], v[0:1], off offset:32 -+-; GFX10-NEXT: s_waitcnt vmcnt(2) -+-; GFX10-NEXT: global_store_dwordx4 v[2:3], v[4:7], off -+-; GFX10-NEXT: s_waitcnt vmcnt(1) -+-; GFX10-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 -+-; GFX10-NEXT: s_waitcnt vmcnt(0) -+-; GFX10-NEXT: global_store_dwordx3 v[2:3], v[12:14], off offset:32 -+-; GFX10-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX11-LABEL: freeze_v11i32: -+-; GFX11: ; %bb.0: -+-; GFX11-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX11-NEXT: s_clause 0x2 -+-; GFX11-NEXT: global_load_b128 v[4:7], v[0:1], off -+-; GFX11-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 -+-; GFX11-NEXT: global_load_b96 v[12:14], v[0:1], off offset:32 -+-; GFX11-NEXT: s_waitcnt vmcnt(2) -+-; GFX11-NEXT: global_store_b128 v[2:3], v[4:7], off -+-; GFX11-NEXT: s_waitcnt vmcnt(1) -+-; GFX11-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 -+-; GFX11-NEXT: s_waitcnt vmcnt(0) -+-; GFX11-NEXT: global_store_b96 v[2:3], v[12:14], off offset:32 -+-; GFX11-NEXT: s_setpc_b64 s[30:31] -+- %a = load <11 x i32>, ptr addrspace(1) %ptra, align 4 -+- %freeze = freeze <11 x i32> %a -+- store <11 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 -+- ret void -+-} -+- -+-define void @freeze_v12i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { -+-; GFX10-LABEL: freeze_v12i32: -+-; GFX10: ; %bb.0: -+-; GFX10-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX10-NEXT: s_clause 0x2 -+-; GFX10-NEXT: global_load_dwordx4 v[4:7], v[0:1], off -+-; GFX10-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 -+-; GFX10-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 -+-; GFX10-NEXT: s_waitcnt vmcnt(2) -+-; GFX10-NEXT: global_store_dwordx4 v[2:3], v[4:7], off -+-; GFX10-NEXT: s_waitcnt vmcnt(1) -+-; GFX10-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 -+-; GFX10-NEXT: s_waitcnt vmcnt(0) -+-; GFX10-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 -+-; GFX10-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX11-LABEL: freeze_v12i32: -+-; GFX11: ; %bb.0: -+-; GFX11-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX11-NEXT: s_clause 0x2 -+-; GFX11-NEXT: global_load_b128 v[4:7], v[0:1], off -+-; GFX11-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 -+-; GFX11-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 -+-; GFX11-NEXT: s_waitcnt vmcnt(2) -+-; GFX11-NEXT: global_store_b128 v[2:3], v[4:7], off -+-; GFX11-NEXT: s_waitcnt vmcnt(1) -+-; GFX11-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 -+-; GFX11-NEXT: s_waitcnt vmcnt(0) -+-; GFX11-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 -+-; GFX11-NEXT: s_setpc_b64 s[30:31] -+- %a = load <12 x i32>, ptr addrspace(1) %ptra, align 4 -+- %freeze = freeze <12 x i32> %a -+- store <12 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 -+- ret void -+-} -+-define void @freeze_v13i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { -+-; GFX10-SDAG-LABEL: freeze_v13i32: -+-; GFX10-SDAG: ; %bb.0: -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX10-SDAG-NEXT: s_clause 0x3 -+-; GFX10-SDAG-NEXT: global_load_dword v16, v[0:1], off offset:48 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:32 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:16 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) -+-; GFX10-SDAG-NEXT: global_store_dword v[2:3], v16, off offset:48 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:32 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:16 -+-; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX10-GISEL-LABEL: freeze_v13i32: -+-; GFX10-GISEL: ; %bb.0: -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX10-GISEL-NEXT: s_clause 0x3 -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 -+-; GFX10-GISEL-NEXT: global_load_dword v16, v[0:1], off offset:48 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) -+-; GFX10-GISEL-NEXT: global_store_dword v[2:3], v16, off offset:48 -+-; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX11-SDAG-LABEL: freeze_v13i32: -+-; GFX11-SDAG: ; %bb.0: -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX11-SDAG-NEXT: s_clause 0x3 -+-; GFX11-SDAG-NEXT: global_load_b32 v16, v[0:1], off offset:48 -+-; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:32 -+-; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off -+-; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off offset:16 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) -+-; GFX11-SDAG-NEXT: global_store_b32 v[2:3], v16, off offset:48 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:32 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off offset:16 -+-; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX11-GISEL-LABEL: freeze_v13i32: -+-; GFX11-GISEL: ; %bb.0: -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX11-GISEL-NEXT: s_clause 0x3 -+-; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off -+-; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 -+-; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 -+-; GFX11-GISEL-NEXT: global_load_b32 v0, v[0:1], off offset:48 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) -+-; GFX11-GISEL-NEXT: global_store_b32 v[2:3], v0, off offset:48 -+-; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] -+- %a = load <13 x i32>, ptr addrspace(1) %ptra, align 4 -+- %freeze = freeze <13 x i32> %a -+- store <13 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 -+- ret void -+-} -+- -+-define void @freeze_v14i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { -+-; GFX10-SDAG-LABEL: freeze_v14i32: -+-; GFX10-SDAG: ; %bb.0: -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX10-SDAG-NEXT: s_clause 0x3 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:32 -+-; GFX10-SDAG-NEXT: global_load_dwordx2 v[16:17], v[0:1], off offset:48 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:16 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:32 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) -+-; GFX10-SDAG-NEXT: global_store_dwordx2 v[2:3], v[16:17], off offset:48 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:16 -+-; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX10-GISEL-LABEL: freeze_v14i32: -+-; GFX10-GISEL: ; %bb.0: -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX10-GISEL-NEXT: s_clause 0x3 -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 -+-; GFX10-GISEL-NEXT: global_load_dwordx2 v[16:17], v[0:1], off offset:48 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) -+-; GFX10-GISEL-NEXT: global_store_dwordx2 v[2:3], v[16:17], off offset:48 -+-; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX11-SDAG-LABEL: freeze_v14i32: -+-; GFX11-SDAG: ; %bb.0: -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX11-SDAG-NEXT: s_clause 0x3 -+-; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:32 -+-; GFX11-SDAG-NEXT: global_load_b64 v[16:17], v[0:1], off offset:48 -+-; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off -+-; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off offset:16 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:32 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) -+-; GFX11-SDAG-NEXT: global_store_b64 v[2:3], v[16:17], off offset:48 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off offset:16 -+-; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX11-GISEL-LABEL: freeze_v14i32: -+-; GFX11-GISEL: ; %bb.0: -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX11-GISEL-NEXT: s_clause 0x3 -+-; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off -+-; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 -+-; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 -+-; GFX11-GISEL-NEXT: global_load_b64 v[0:1], v[0:1], off offset:48 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) -+-; GFX11-GISEL-NEXT: global_store_b64 v[2:3], v[0:1], off offset:48 -+-; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] -+- %a = load <14 x i32>, ptr addrspace(1) %ptra, align 4 -+- %freeze = freeze <14 x i32> %a -+- store <14 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 -+- ret void -+-} -+- -+-define void @freeze_v15i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { -+-; GFX10-SDAG-LABEL: freeze_v15i32: -+-; GFX10-SDAG: ; %bb.0: -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX10-SDAG-NEXT: s_clause 0x3 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:32 -+-; GFX10-SDAG-NEXT: global_load_dwordx3 v[16:18], v[0:1], off offset:48 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:16 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:32 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) -+-; GFX10-SDAG-NEXT: global_store_dwordx3 v[2:3], v[16:18], off offset:48 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:16 -+-; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX10-GISEL-LABEL: freeze_v15i32: -+-; GFX10-GISEL: ; %bb.0: -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX10-GISEL-NEXT: s_clause 0x3 -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 -+-; GFX10-GISEL-NEXT: global_load_dwordx3 v[16:18], v[0:1], off offset:48 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) -+-; GFX10-GISEL-NEXT: global_store_dwordx3 v[2:3], v[16:18], off offset:48 -+-; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX11-SDAG-LABEL: freeze_v15i32: -+-; GFX11-SDAG: ; %bb.0: -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX11-SDAG-NEXT: s_clause 0x3 -+-; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:32 -+-; GFX11-SDAG-NEXT: global_load_b96 v[16:18], v[0:1], off offset:48 -+-; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off -+-; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off offset:16 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:32 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) -+-; GFX11-SDAG-NEXT: global_store_b96 v[2:3], v[16:18], off offset:48 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off offset:16 -+-; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX11-GISEL-LABEL: freeze_v15i32: -+-; GFX11-GISEL: ; %bb.0: -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX11-GISEL-NEXT: s_clause 0x3 -+-; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off -+-; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 -+-; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 -+-; GFX11-GISEL-NEXT: global_load_b96 v[16:18], v[0:1], off offset:48 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) -+-; GFX11-GISEL-NEXT: global_store_b96 v[2:3], v[16:18], off offset:48 -+-; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] -+- %a = load <15 x i32>, ptr addrspace(1) %ptra, align 4 -+- %freeze = freeze <15 x i32> %a -+- store <15 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 -+- ret void -+-} -+- -+-define void @freeze_v16i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { -+-; GFX10-SDAG-LABEL: freeze_v16i32: -+-; GFX10-SDAG: ; %bb.0: -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX10-SDAG-NEXT: s_clause 0x3 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:32 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:48 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:16 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:32 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:48 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:16 -+-; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX10-GISEL-LABEL: freeze_v16i32: -+-; GFX10-GISEL: ; %bb.0: -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX10-GISEL-NEXT: s_clause 0x3 -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:48 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:48 -+-; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX11-SDAG-LABEL: freeze_v16i32: -+-; GFX11-SDAG: ; %bb.0: -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX11-SDAG-NEXT: s_clause 0x3 -+-; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:32 -+-; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off offset:48 -+-; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off -+-; GFX11-SDAG-NEXT: global_load_b128 v[16:19], v[0:1], off offset:16 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:32 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off offset:48 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[16:19], off offset:16 -+-; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX11-GISEL-LABEL: freeze_v16i32: -+-; GFX11-GISEL: ; %bb.0: -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX11-GISEL-NEXT: s_clause 0x3 -+-; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off -+-; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 -+-; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 -+-; GFX11-GISEL-NEXT: global_load_b128 v[16:19], v[0:1], off offset:48 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[16:19], off offset:48 -+-; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] -+- %a = load <16 x i32>, ptr addrspace(1) %ptra, align 4 -+- %freeze = freeze <16 x i32> %a -+- store <16 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 -+- ret void -+-} -+- -+-define void @freeze_v17i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { -+-; GFX10-SDAG-LABEL: freeze_v17i32: -+-; GFX10-SDAG: ; %bb.0: -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX10-SDAG-NEXT: s_clause 0x4 -+-; GFX10-SDAG-NEXT: global_load_dword v20, v[0:1], off offset:64 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:32 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:48 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:16 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(4) -+-; GFX10-SDAG-NEXT: global_store_dword v[2:3], v20, off offset:64 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:32 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:48 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:16 -+-; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX10-GISEL-LABEL: freeze_v17i32: -+-; GFX10-GISEL: ; %bb.0: -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX10-GISEL-NEXT: s_clause 0x4 -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:48 -+-; GFX10-GISEL-NEXT: global_load_dword v20, v[0:1], off offset:64 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(4) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:48 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) -+-; GFX10-GISEL-NEXT: global_store_dword v[2:3], v20, off offset:64 -+-; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX11-SDAG-LABEL: freeze_v17i32: -+-; GFX11-SDAG: ; %bb.0: -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX11-SDAG-NEXT: s_clause 0x4 -+-; GFX11-SDAG-NEXT: global_load_b32 v20, v[0:1], off offset:64 -+-; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:32 -+-; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off offset:48 -+-; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off -+-; GFX11-SDAG-NEXT: global_load_b128 v[16:19], v[0:1], off offset:16 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(4) -+-; GFX11-SDAG-NEXT: global_store_b32 v[2:3], v20, off offset:64 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:32 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off offset:48 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[16:19], off offset:16 -+-; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX11-GISEL-LABEL: freeze_v17i32: -+-; GFX11-GISEL: ; %bb.0: -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX11-GISEL-NEXT: s_clause 0x4 -+-; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off -+-; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 -+-; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 -+-; GFX11-GISEL-NEXT: global_load_b128 v[16:19], v[0:1], off offset:48 -+-; GFX11-GISEL-NEXT: global_load_b32 v0, v[0:1], off offset:64 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(4) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[16:19], off offset:48 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) -+-; GFX11-GISEL-NEXT: global_store_b32 v[2:3], v0, off offset:64 -+-; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] -+- %a = load <17 x i32>, ptr addrspace(1) %ptra, align 4 -+- %freeze = freeze <17 x i32> %a -+- store <17 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 -+- ret void -+-} -+- -+-define void @freeze_v18i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { -+-; GFX10-SDAG-LABEL: freeze_v18i32: -+-; GFX10-SDAG: ; %bb.0: -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX10-SDAG-NEXT: s_clause 0x4 -+-; GFX10-SDAG-NEXT: global_load_dwordx2 v[20:21], v[0:1], off offset:64 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:32 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:48 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:16 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(4) -+-; GFX10-SDAG-NEXT: global_store_dwordx2 v[2:3], v[20:21], off offset:64 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:32 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:48 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:16 -+-; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX10-GISEL-LABEL: freeze_v18i32: -+-; GFX10-GISEL: ; %bb.0: -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX10-GISEL-NEXT: s_clause 0x4 -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:48 -+-; GFX10-GISEL-NEXT: global_load_dwordx2 v[20:21], v[0:1], off offset:64 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(4) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:48 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) -+-; GFX10-GISEL-NEXT: global_store_dwordx2 v[2:3], v[20:21], off offset:64 -+-; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX11-SDAG-LABEL: freeze_v18i32: -+-; GFX11-SDAG: ; %bb.0: -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX11-SDAG-NEXT: s_clause 0x4 -+-; GFX11-SDAG-NEXT: global_load_b64 v[20:21], v[0:1], off offset:64 -+-; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:32 -+-; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off offset:48 -+-; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off -+-; GFX11-SDAG-NEXT: global_load_b128 v[16:19], v[0:1], off offset:16 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(4) -+-; GFX11-SDAG-NEXT: global_store_b64 v[2:3], v[20:21], off offset:64 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:32 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off offset:48 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[16:19], off offset:16 -+-; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX11-GISEL-LABEL: freeze_v18i32: -+-; GFX11-GISEL: ; %bb.0: -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX11-GISEL-NEXT: s_clause 0x4 -+-; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off -+-; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 -+-; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 -+-; GFX11-GISEL-NEXT: global_load_b128 v[16:19], v[0:1], off offset:48 -+-; GFX11-GISEL-NEXT: global_load_b64 v[0:1], v[0:1], off offset:64 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(4) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[16:19], off offset:48 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) -+-; GFX11-GISEL-NEXT: global_store_b64 v[2:3], v[0:1], off offset:64 -+-; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] -+- %a = load <18 x i32>, ptr addrspace(1) %ptra, align 4 -+- %freeze = freeze <18 x i32> %a -+- store <18 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 -+- ret void -+-} -+- -+-define void @freeze_v19i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { -+-; GFX10-SDAG-LABEL: freeze_v19i32: -+-; GFX10-SDAG: ; %bb.0: -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX10-SDAG-NEXT: s_clause 0x4 -+-; GFX10-SDAG-NEXT: global_load_dwordx3 v[20:22], v[0:1], off offset:64 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:32 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:48 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:16 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(4) -+-; GFX10-SDAG-NEXT: global_store_dwordx3 v[2:3], v[20:22], off offset:64 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:32 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:48 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:16 -+-; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX10-GISEL-LABEL: freeze_v19i32: -+-; GFX10-GISEL: ; %bb.0: -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX10-GISEL-NEXT: s_clause 0x4 -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:48 -+-; GFX10-GISEL-NEXT: global_load_dwordx3 v[20:22], v[0:1], off offset:64 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(4) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:48 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) -+-; GFX10-GISEL-NEXT: global_store_dwordx3 v[2:3], v[20:22], off offset:64 -+-; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX11-SDAG-LABEL: freeze_v19i32: -+-; GFX11-SDAG: ; %bb.0: -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX11-SDAG-NEXT: s_clause 0x4 -+-; GFX11-SDAG-NEXT: global_load_b96 v[20:22], v[0:1], off offset:64 -+-; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:32 -+-; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off offset:48 -+-; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off -+-; GFX11-SDAG-NEXT: global_load_b128 v[16:19], v[0:1], off offset:16 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(4) -+-; GFX11-SDAG-NEXT: global_store_b96 v[2:3], v[20:22], off offset:64 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:32 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off offset:48 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[16:19], off offset:16 -+-; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX11-GISEL-LABEL: freeze_v19i32: -+-; GFX11-GISEL: ; %bb.0: -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX11-GISEL-NEXT: s_clause 0x4 -+-; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off -+-; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 -+-; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 -+-; GFX11-GISEL-NEXT: global_load_b128 v[16:19], v[0:1], off offset:48 -+-; GFX11-GISEL-NEXT: global_load_b96 v[20:22], v[0:1], off offset:64 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(4) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[16:19], off offset:48 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) -+-; GFX11-GISEL-NEXT: global_store_b96 v[2:3], v[20:22], off offset:64 -+-; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] -+- %a = load <19 x i32>, ptr addrspace(1) %ptra, align 4 -+- %freeze = freeze <19 x i32> %a -+- store <19 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 -+- ret void -+-} -+- -+-define void @freeze_v20i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { -+-; GFX10-SDAG-LABEL: freeze_v20i32: -+-; GFX10-SDAG: ; %bb.0: -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX10-SDAG-NEXT: s_clause 0x4 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:64 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:32 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:48 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[16:19], v[0:1], off -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[20:23], v[0:1], off offset:16 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(4) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:64 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:32 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:48 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[16:19], off -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[20:23], off offset:16 -+-; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX10-GISEL-LABEL: freeze_v20i32: -+-; GFX10-GISEL: ; %bb.0: -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX10-GISEL-NEXT: s_clause 0x4 -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:48 -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[20:23], v[0:1], off offset:64 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(4) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:48 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[20:23], off offset:64 -+-; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX11-SDAG-LABEL: freeze_v20i32: -+-; GFX11-SDAG: ; %bb.0: -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX11-SDAG-NEXT: s_clause 0x4 -+-; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:64 -+-; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off offset:32 -+-; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off offset:48 -+-; GFX11-SDAG-NEXT: global_load_b128 v[16:19], v[0:1], off -+-; GFX11-SDAG-NEXT: global_load_b128 v[20:23], v[0:1], off offset:16 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(4) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:64 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off offset:32 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off offset:48 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[16:19], off -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[20:23], off offset:16 -+-; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX11-GISEL-LABEL: freeze_v20i32: -+-; GFX11-GISEL: ; %bb.0: -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX11-GISEL-NEXT: s_clause 0x4 -+-; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off -+-; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 -+-; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 -+-; GFX11-GISEL-NEXT: global_load_b128 v[16:19], v[0:1], off offset:48 -+-; GFX11-GISEL-NEXT: global_load_b128 v[20:23], v[0:1], off offset:64 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(4) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[16:19], off offset:48 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[20:23], off offset:64 -+-; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] -+- %a = load <20 x i32>, ptr addrspace(1) %ptra, align 4 -+- %freeze = freeze <20 x i32> %a -+- store <20 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 -+- ret void -+-} -+- -+-define void @freeze_v21i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { -+-; GFX10-SDAG-LABEL: freeze_v21i32: -+-; GFX10-SDAG: ; %bb.0: -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX10-SDAG-NEXT: s_clause 0x5 -+-; GFX10-SDAG-NEXT: global_load_dword v24, v[0:1], off offset:80 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:64 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:32 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:48 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[16:19], v[0:1], off -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[20:23], v[0:1], off offset:16 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(5) -+-; GFX10-SDAG-NEXT: global_store_dword v[2:3], v24, off offset:80 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(4) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:64 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:32 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:48 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[16:19], off -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[20:23], off offset:16 -+-; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX10-GISEL-LABEL: freeze_v21i32: -+-; GFX10-GISEL: ; %bb.0: -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX10-GISEL-NEXT: s_clause 0x5 -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:48 -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[20:23], v[0:1], off offset:64 -+-; GFX10-GISEL-NEXT: global_load_dword v24, v[0:1], off offset:80 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(5) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(4) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:48 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[20:23], off offset:64 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) -+-; GFX10-GISEL-NEXT: global_store_dword v[2:3], v24, off offset:80 -+-; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX11-SDAG-LABEL: freeze_v21i32: -+-; GFX11-SDAG: ; %bb.0: -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX11-SDAG-NEXT: s_clause 0x5 -+-; GFX11-SDAG-NEXT: global_load_b32 v24, v[0:1], off offset:80 -+-; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:64 -+-; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off offset:32 -+-; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off offset:48 -+-; GFX11-SDAG-NEXT: global_load_b128 v[16:19], v[0:1], off -+-; GFX11-SDAG-NEXT: global_load_b128 v[20:23], v[0:1], off offset:16 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(5) -+-; GFX11-SDAG-NEXT: global_store_b32 v[2:3], v24, off offset:80 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(4) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:64 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off offset:32 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off offset:48 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[16:19], off -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[20:23], off offset:16 -+-; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX11-GISEL-LABEL: freeze_v21i32: -+-; GFX11-GISEL: ; %bb.0: -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX11-GISEL-NEXT: s_clause 0x5 -+-; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off -+-; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 -+-; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 -+-; GFX11-GISEL-NEXT: global_load_b128 v[16:19], v[0:1], off offset:48 -+-; GFX11-GISEL-NEXT: global_load_b128 v[20:23], v[0:1], off offset:64 -+-; GFX11-GISEL-NEXT: global_load_b32 v0, v[0:1], off offset:80 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(5) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(4) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[16:19], off offset:48 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[20:23], off offset:64 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) -+-; GFX11-GISEL-NEXT: global_store_b32 v[2:3], v0, off offset:80 -+-; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] -+- %a = load <21 x i32>, ptr addrspace(1) %ptra, align 4 -+- %freeze = freeze <21 x i32> %a -+- store <21 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 -+- ret void -+-} -+- -+-define void @freeze_v22i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { -+-; GFX10-SDAG-LABEL: freeze_v22i32: -+-; GFX10-SDAG: ; %bb.0: -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX10-SDAG-NEXT: s_clause 0x5 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:64 -+-; GFX10-SDAG-NEXT: global_load_dwordx2 v[24:25], v[0:1], off offset:80 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:32 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:48 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[16:19], v[0:1], off -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[20:23], v[0:1], off offset:16 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(5) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:64 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(4) -+-; GFX10-SDAG-NEXT: global_store_dwordx2 v[2:3], v[24:25], off offset:80 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:32 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:48 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[16:19], off -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[20:23], off offset:16 -+-; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX10-GISEL-LABEL: freeze_v22i32: -+-; GFX10-GISEL: ; %bb.0: -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX10-GISEL-NEXT: s_clause 0x5 -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:48 -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[20:23], v[0:1], off offset:64 -+-; GFX10-GISEL-NEXT: global_load_dwordx2 v[24:25], v[0:1], off offset:80 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(5) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(4) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:48 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[20:23], off offset:64 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) -+-; GFX10-GISEL-NEXT: global_store_dwordx2 v[2:3], v[24:25], off offset:80 -+-; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX11-SDAG-LABEL: freeze_v22i32: -+-; GFX11-SDAG: ; %bb.0: -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX11-SDAG-NEXT: s_clause 0x5 -+-; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:64 -+-; GFX11-SDAG-NEXT: global_load_b64 v[24:25], v[0:1], off offset:80 -+-; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off offset:32 -+-; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off offset:48 -+-; GFX11-SDAG-NEXT: global_load_b128 v[16:19], v[0:1], off -+-; GFX11-SDAG-NEXT: global_load_b128 v[20:23], v[0:1], off offset:16 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(5) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:64 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(4) -+-; GFX11-SDAG-NEXT: global_store_b64 v[2:3], v[24:25], off offset:80 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off offset:32 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off offset:48 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[16:19], off -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[20:23], off offset:16 -+-; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX11-GISEL-LABEL: freeze_v22i32: -+-; GFX11-GISEL: ; %bb.0: -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX11-GISEL-NEXT: s_clause 0x5 -+-; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off -+-; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 -+-; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 -+-; GFX11-GISEL-NEXT: global_load_b128 v[16:19], v[0:1], off offset:48 -+-; GFX11-GISEL-NEXT: global_load_b128 v[20:23], v[0:1], off offset:64 -+-; GFX11-GISEL-NEXT: global_load_b64 v[0:1], v[0:1], off offset:80 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(5) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(4) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[16:19], off offset:48 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[20:23], off offset:64 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) -+-; GFX11-GISEL-NEXT: global_store_b64 v[2:3], v[0:1], off offset:80 -+-; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] -+- %a = load <22 x i32>, ptr addrspace(1) %ptra, align 4 -+- %freeze = freeze <22 x i32> %a -+- store <22 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 -+- ret void -+-} -+- -+-define void @freeze_v30i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { -+-; GFX10-SDAG-LABEL: freeze_v30i32: -+-; GFX10-SDAG: ; %bb.0: -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX10-SDAG-NEXT: s_clause 0x7 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:96 -+-; GFX10-SDAG-NEXT: global_load_dwordx2 v[32:33], v[0:1], off offset:112 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:64 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:80 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:32 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[20:23], v[0:1], off offset:48 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[24:27], v[0:1], off -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[28:31], v[0:1], off offset:16 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(7) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:96 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(6) -+-; GFX10-SDAG-NEXT: global_store_dwordx2 v[2:3], v[32:33], off offset:112 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(5) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:64 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(4) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:80 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:32 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[20:23], off offset:48 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[24:27], off -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[28:31], off offset:16 -+-; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX10-GISEL-LABEL: freeze_v30i32: -+-; GFX10-GISEL: ; %bb.0: -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX10-GISEL-NEXT: s_clause 0x7 -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:48 -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[20:23], v[0:1], off offset:64 -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[24:27], v[0:1], off offset:80 -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[28:31], v[0:1], off offset:96 -+-; GFX10-GISEL-NEXT: global_load_dwordx2 v[32:33], v[0:1], off offset:112 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(7) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(6) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(5) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(4) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:48 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[20:23], off offset:64 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[24:27], off offset:80 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[28:31], off offset:96 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) -+-; GFX10-GISEL-NEXT: global_store_dwordx2 v[2:3], v[32:33], off offset:112 -+-; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX11-SDAG-LABEL: freeze_v30i32: -+-; GFX11-SDAG: ; %bb.0: -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX11-SDAG-NEXT: s_clause 0x7 -+-; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:96 -+-; GFX11-SDAG-NEXT: global_load_b64 v[32:33], v[0:1], off offset:112 -+-; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off offset:64 -+-; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off offset:80 -+-; GFX11-SDAG-NEXT: global_load_b128 v[16:19], v[0:1], off offset:32 -+-; GFX11-SDAG-NEXT: global_load_b128 v[20:23], v[0:1], off offset:48 -+-; GFX11-SDAG-NEXT: global_load_b128 v[24:27], v[0:1], off -+-; GFX11-SDAG-NEXT: global_load_b128 v[28:31], v[0:1], off offset:16 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(7) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:96 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(6) -+-; GFX11-SDAG-NEXT: global_store_b64 v[2:3], v[32:33], off offset:112 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(5) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off offset:64 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(4) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off offset:80 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[16:19], off offset:32 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[20:23], off offset:48 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[24:27], off -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[28:31], off offset:16 -+-; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX11-GISEL-LABEL: freeze_v30i32: -+-; GFX11-GISEL: ; %bb.0: -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX11-GISEL-NEXT: s_clause 0x7 -+-; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off -+-; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 -+-; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 -+-; GFX11-GISEL-NEXT: global_load_b128 v[16:19], v[0:1], off offset:48 -+-; GFX11-GISEL-NEXT: global_load_b128 v[20:23], v[0:1], off offset:64 -+-; GFX11-GISEL-NEXT: global_load_b128 v[24:27], v[0:1], off offset:80 -+-; GFX11-GISEL-NEXT: global_load_b128 v[28:31], v[0:1], off offset:96 -+-; GFX11-GISEL-NEXT: global_load_b64 v[0:1], v[0:1], off offset:112 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(7) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(6) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(5) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(4) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[16:19], off offset:48 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[20:23], off offset:64 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[24:27], off offset:80 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[28:31], off offset:96 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) -+-; GFX11-GISEL-NEXT: global_store_b64 v[2:3], v[0:1], off offset:112 -+-; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] -+- %a = load <30 x i32>, ptr addrspace(1) %ptra, align 4 -+- %freeze = freeze <30 x i32> %a -+- store <30 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 -+- ret void -+-} -+- -+-define void @freeze_v31i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { -+-; GFX10-SDAG-LABEL: freeze_v31i32: -+-; GFX10-SDAG: ; %bb.0: -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX10-SDAG-NEXT: s_clause 0x7 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:96 -+-; GFX10-SDAG-NEXT: global_load_dwordx3 v[32:34], v[0:1], off offset:112 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:64 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:80 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:32 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[20:23], v[0:1], off offset:48 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[24:27], v[0:1], off -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[28:31], v[0:1], off offset:16 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(7) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:96 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(6) -+-; GFX10-SDAG-NEXT: global_store_dwordx3 v[2:3], v[32:34], off offset:112 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(5) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:64 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(4) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:80 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:32 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[20:23], off offset:48 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[24:27], off -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[28:31], off offset:16 -+-; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX10-GISEL-LABEL: freeze_v31i32: -+-; GFX10-GISEL: ; %bb.0: -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX10-GISEL-NEXT: s_clause 0x7 -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:48 -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[20:23], v[0:1], off offset:64 -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[24:27], v[0:1], off offset:80 -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[28:31], v[0:1], off offset:96 -+-; GFX10-GISEL-NEXT: global_load_dwordx3 v[32:34], v[0:1], off offset:112 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(7) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(6) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(5) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(4) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:48 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[20:23], off offset:64 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[24:27], off offset:80 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[28:31], off offset:96 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) -+-; GFX10-GISEL-NEXT: global_store_dwordx3 v[2:3], v[32:34], off offset:112 -+-; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX11-SDAG-LABEL: freeze_v31i32: -+-; GFX11-SDAG: ; %bb.0: -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX11-SDAG-NEXT: s_clause 0x7 -+-; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:96 -+-; GFX11-SDAG-NEXT: global_load_b96 v[32:34], v[0:1], off offset:112 -+-; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off offset:64 -+-; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off offset:80 -+-; GFX11-SDAG-NEXT: global_load_b128 v[16:19], v[0:1], off offset:32 -+-; GFX11-SDAG-NEXT: global_load_b128 v[20:23], v[0:1], off offset:48 -+-; GFX11-SDAG-NEXT: global_load_b128 v[24:27], v[0:1], off -+-; GFX11-SDAG-NEXT: global_load_b128 v[28:31], v[0:1], off offset:16 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(7) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:96 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(6) -+-; GFX11-SDAG-NEXT: global_store_b96 v[2:3], v[32:34], off offset:112 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(5) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off offset:64 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(4) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off offset:80 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[16:19], off offset:32 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[20:23], off offset:48 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[24:27], off -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[28:31], off offset:16 -+-; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX11-GISEL-LABEL: freeze_v31i32: -+-; GFX11-GISEL: ; %bb.0: -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX11-GISEL-NEXT: s_clause 0x7 -+-; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off -+-; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 -+-; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 -+-; GFX11-GISEL-NEXT: global_load_b128 v[16:19], v[0:1], off offset:48 -+-; GFX11-GISEL-NEXT: global_load_b128 v[20:23], v[0:1], off offset:64 -+-; GFX11-GISEL-NEXT: global_load_b128 v[24:27], v[0:1], off offset:80 -+-; GFX11-GISEL-NEXT: global_load_b128 v[28:31], v[0:1], off offset:96 -+-; GFX11-GISEL-NEXT: global_load_b96 v[32:34], v[0:1], off offset:112 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(7) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(6) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(5) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(4) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[16:19], off offset:48 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[20:23], off offset:64 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[24:27], off offset:80 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[28:31], off offset:96 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) -+-; GFX11-GISEL-NEXT: global_store_b96 v[2:3], v[32:34], off offset:112 -+-; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] -+- %a = load <31 x i32>, ptr addrspace(1) %ptra, align 4 -+- %freeze = freeze <31 x i32> %a -+- store <31 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 -+- ret void -+-} -+- -+-define void @freeze_v32i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { -+-; GFX10-SDAG-LABEL: freeze_v32i32: -+-; GFX10-SDAG: ; %bb.0: -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX10-SDAG-NEXT: s_clause 0x7 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:96 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:112 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:64 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:80 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[20:23], v[0:1], off offset:32 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[24:27], v[0:1], off offset:48 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[28:31], v[0:1], off -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[32:35], v[0:1], off offset:16 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(7) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:96 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(6) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:112 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(5) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:64 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(4) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:80 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[20:23], off offset:32 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[24:27], off offset:48 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[28:31], off -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[32:35], off offset:16 -+-; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX10-GISEL-LABEL: freeze_v32i32: -+-; GFX10-GISEL: ; %bb.0: -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX10-GISEL-NEXT: s_clause 0x7 -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:48 -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[20:23], v[0:1], off offset:64 -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[24:27], v[0:1], off offset:80 -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[28:31], v[0:1], off offset:96 -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[32:35], v[0:1], off offset:112 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(7) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(6) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(5) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(4) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:48 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[20:23], off offset:64 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[24:27], off offset:80 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[28:31], off offset:96 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[32:35], off offset:112 -+-; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX11-SDAG-LABEL: freeze_v32i32: -+-; GFX11-SDAG: ; %bb.0: -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX11-SDAG-NEXT: s_clause 0x7 -+-; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:96 -+-; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off offset:112 -+-; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off offset:64 -+-; GFX11-SDAG-NEXT: global_load_b128 v[16:19], v[0:1], off offset:80 -+-; GFX11-SDAG-NEXT: global_load_b128 v[20:23], v[0:1], off offset:32 -+-; GFX11-SDAG-NEXT: global_load_b128 v[24:27], v[0:1], off offset:48 -+-; GFX11-SDAG-NEXT: global_load_b128 v[28:31], v[0:1], off -+-; GFX11-SDAG-NEXT: global_load_b128 v[32:35], v[0:1], off offset:16 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(7) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:96 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(6) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off offset:112 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(5) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off offset:64 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(4) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[16:19], off offset:80 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[20:23], off offset:32 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[24:27], off offset:48 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[28:31], off -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[32:35], off offset:16 -+-; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX11-GISEL-LABEL: freeze_v32i32: -+-; GFX11-GISEL: ; %bb.0: -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX11-GISEL-NEXT: s_clause 0x7 -+-; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off -+-; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 -+-; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 -+-; GFX11-GISEL-NEXT: global_load_b128 v[16:19], v[0:1], off offset:48 -+-; GFX11-GISEL-NEXT: global_load_b128 v[20:23], v[0:1], off offset:64 -+-; GFX11-GISEL-NEXT: global_load_b128 v[24:27], v[0:1], off offset:80 -+-; GFX11-GISEL-NEXT: global_load_b128 v[28:31], v[0:1], off offset:96 -+-; GFX11-GISEL-NEXT: global_load_b128 v[32:35], v[0:1], off offset:112 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(7) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(6) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(5) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(4) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[16:19], off offset:48 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[20:23], off offset:64 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[24:27], off offset:80 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[28:31], off offset:96 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[32:35], off offset:112 -+-; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] -+- %a = load <32 x i32>, ptr addrspace(1) %ptra, align 4 -+- %freeze = freeze <32 x i32> %a -+- store <32 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 -+- ret void -+-} -+- -+-define void @freeze_i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { -+-; GFX10-LABEL: freeze_i32: -+-; GFX10: ; %bb.0: -+-; GFX10-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX10-NEXT: global_load_dword v0, v[0:1], off -+-; GFX10-NEXT: s_waitcnt vmcnt(0) -+-; GFX10-NEXT: global_store_dword v[2:3], v0, off -+-; GFX10-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX11-LABEL: freeze_i32: -+-; GFX11: ; %bb.0: -+-; GFX11-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX11-NEXT: global_load_b32 v0, v[0:1], off -+-; GFX11-NEXT: s_waitcnt vmcnt(0) -+-; GFX11-NEXT: global_store_b32 v[2:3], v0, off -+-; GFX11-NEXT: s_setpc_b64 s[30:31] -+- %a = load i32, ptr addrspace(1) %ptra, align 4 -+- %freeze = freeze i32 %a -+- store i32 %freeze, ptr addrspace(1) %ptrb, align 4 -+- ret void -+-} -+- -+-define void @freeze_i64(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { -+-; GFX10-LABEL: freeze_i64: -+-; GFX10: ; %bb.0: -+-; GFX10-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX10-NEXT: global_load_dwordx2 v[0:1], v[0:1], off -+-; GFX10-NEXT: s_waitcnt vmcnt(0) -+-; GFX10-NEXT: global_store_dwordx2 v[2:3], v[0:1], off -+-; GFX10-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX11-LABEL: freeze_i64: -+-; GFX11: ; %bb.0: -+-; GFX11-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX11-NEXT: global_load_b64 v[0:1], v[0:1], off -+-; GFX11-NEXT: s_waitcnt vmcnt(0) -+-; GFX11-NEXT: global_store_b64 v[2:3], v[0:1], off -+-; GFX11-NEXT: s_setpc_b64 s[30:31] -+- %a = load i64, ptr addrspace(1) %ptra, align 4 -+- %freeze = freeze i64 %a -+- store i64 %freeze, ptr addrspace(1) %ptrb, align 4 -+- ret void -+-} -+- -+-define void @freeze_float(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { -+-; GFX10-LABEL: freeze_float: -+-; GFX10: ; %bb.0: -+-; GFX10-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX10-NEXT: global_load_dword v0, v[0:1], off -+-; GFX10-NEXT: s_waitcnt vmcnt(0) -+-; GFX10-NEXT: global_store_dword v[2:3], v0, off -+-; GFX10-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX11-LABEL: freeze_float: -+-; GFX11: ; %bb.0: -+-; GFX11-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX11-NEXT: global_load_b32 v0, v[0:1], off -+-; GFX11-NEXT: s_waitcnt vmcnt(0) -+-; GFX11-NEXT: global_store_b32 v[2:3], v0, off -+-; GFX11-NEXT: s_setpc_b64 s[30:31] -+- %a = load float, ptr addrspace(1) %ptra, align 4 -+- %freeze = freeze float %a -+- store float %freeze, ptr addrspace(1) %ptrb, align 4 -+- ret void -+-} -+- -+-define void @freeze_i128(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { -+-; GFX10-LABEL: freeze_i128: -+-; GFX10: ; %bb.0: -+-; GFX10-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX10-NEXT: global_load_dwordx4 v[4:7], v[0:1], off -+-; GFX10-NEXT: s_waitcnt vmcnt(0) -+-; GFX10-NEXT: global_store_dwordx4 v[2:3], v[4:7], off -+-; GFX10-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX11-LABEL: freeze_i128: -+-; GFX11: ; %bb.0: -+-; GFX11-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX11-NEXT: global_load_b128 v[4:7], v[0:1], off -+-; GFX11-NEXT: s_waitcnt vmcnt(0) -+-; GFX11-NEXT: global_store_b128 v[2:3], v[4:7], off -+-; GFX11-NEXT: s_setpc_b64 s[30:31] -+- %a = load i128, ptr addrspace(1) %ptra, align 4 -+- %freeze = freeze i128 %a -+- store i128 %freeze, ptr addrspace(1) %ptrb, align 4 -+- ret void -+-} -+- -+-define void @freeze_i256(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { -+-; GFX10-SDAG-LABEL: freeze_i256: -+-; GFX10-SDAG: ; %bb.0: -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX10-SDAG-NEXT: s_clause 0x1 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:16 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:16 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off -+-; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX10-GISEL-LABEL: freeze_i256: -+-; GFX10-GISEL: ; %bb.0: -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX10-GISEL-NEXT: s_clause 0x1 -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 -+-; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX11-SDAG-LABEL: freeze_i256: -+-; GFX11-SDAG: ; %bb.0: -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX11-SDAG-NEXT: s_clause 0x1 -+-; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:16 -+-; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:16 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off -+-; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX11-GISEL-LABEL: freeze_i256: -+-; GFX11-GISEL: ; %bb.0: -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX11-GISEL-NEXT: s_clause 0x1 -+-; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off -+-; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 -+-; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] -+- %a = load i256, ptr addrspace(1) %ptra, align 4 -+- %freeze = freeze i256 %a -+- store i256 %freeze, ptr addrspace(1) %ptrb, align 4 -+- ret void -+-} -+diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/AMDGPU/GlobalISel/inst-select-unmerge-values.mir b/llvm/test/CodeGen/AMDGPU/GlobalISel/inst-select-unmerge-values.mir -+--- a/llvm/test/CodeGen/AMDGPU/GlobalISel/inst-select-unmerge-values.mir -++++ b/llvm/test/CodeGen/AMDGPU/GlobalISel/inst-select-unmerge-values.mir -+@@ -171,9 +171,11 @@ -+ ; GCN-LABEL: name: test_unmerge_values_s_s64_s_s64_s64_s_s192 -+ ; GCN: liveins: $sgpr0_sgpr1_sgpr2_sgpr3 -+ ; GCN-NEXT: {{ $}} -+- ; GCN-NEXT: [[DEF:%[0-9]+]]:sgpr(s192) = G_IMPLICIT_DEF -+- ; GCN-NEXT: [[UV:%[0-9]+]]:sgpr(s64), [[UV1:%[0-9]+]]:sgpr(s64), [[UV2:%[0-9]+]]:sgpr(s64) = G_UNMERGE_VALUES [[DEF]](s192) -+- ; GCN-NEXT: S_ENDPGM 0, implicit [[UV]](s64), implicit [[UV1]](s64), implicit [[UV2]](s64) -++ ; GCN-NEXT: [[DEF:%[0-9]+]]:sgpr_192 = IMPLICIT_DEF -++ ; GCN-NEXT: [[COPY:%[0-9]+]]:sreg_64 = COPY [[DEF]].sub0_sub1 -++ ; GCN-NEXT: [[COPY1:%[0-9]+]]:sreg_64 = COPY [[DEF]].sub2_sub3 -++ ; GCN-NEXT: [[COPY2:%[0-9]+]]:sreg_64 = COPY [[DEF]].sub4_sub5 -++ ; GCN-NEXT: S_ENDPGM 0, implicit [[COPY]], implicit [[COPY1]], implicit [[COPY2]] -+ %0:sgpr(s192) = G_IMPLICIT_DEF -+ %1:sgpr(s64), %2:sgpr(s64), %3:sgpr(s64) = G_UNMERGE_VALUES %0 -+ S_ENDPGM 0, implicit %1, implicit %2, implicit %3 -+@@ -292,11 +294,11 @@ -+ ; GCN-NEXT: [[CONCAT_VECTORS:%[0-9]+]]:sgpr_384(<12 x s32>) = G_CONCAT_VECTORS [[COPY]](<3 x s32>), [[COPY1]](<3 x s32>), [[COPY2]](<3 x s32>), [[COPY3]](<3 x s32>) -+ ; GCN-NEXT: [[COPY4:%[0-9]+]]:sgpr_96(<3 x s32>) = COPY [[CONCAT_VECTORS]].sub0_sub1_sub2(<12 x s32>) -+ ; GCN-NEXT: [[COPY5:%[0-9]+]]:sgpr_96(<3 x s32>) = COPY [[CONCAT_VECTORS]].sub3_sub4_sub5(<12 x s32>) -+- ; GCN-NEXT: [[COPY4:%[0-9]+]]:sgpr_96(<3 x s32>), [[COPY5:%[0-9]+]]:sgpr_96(<3 x s32>), [[UV:%[0-9]+]]:sgpr_96(<3 x s32>), [[UV1:%[0-9]+]]:sgpr_96(<3 x s32>) = G_UNMERGE_VALUES [[CONCAT_VECTORS]](<12 x s32>) -+- ; GCN-NEXT: $sgpr0_sgpr1_sgpr2 = COPY [[COPY4]](<3 x s32>) -+- ; GCN-NEXT: $sgpr4_sgpr5_sgpr6 = COPY [[COPY5]](<3 x s32>) -+- ; GCN-NEXT: $sgpr8_sgpr9_sgpr10 = COPY [[UV]](<3 x s32>) -+- ; GCN-NEXT: $sgpr12_sgpr13_sgpr14 = COPY [[UV1]](<3 x s32>) -++ ; GCN-NEXT: [[UV:%[0-9]+]]:sgpr_96(<3 x s32>), [[UV1:%[0-9]+]]:sgpr_96(<3 x s32>), [[UV2:%[0-9]+]]:sgpr_96(<3 x s32>), [[UV3:%[0-9]+]]:sgpr_96(<3 x s32>) = G_UNMERGE_VALUES [[CONCAT_VECTORS]](<12 x s32>) -++ ; GCN-NEXT: $sgpr0_sgpr1_sgpr2 = COPY [[UV]](<3 x s32>) -++ ; GCN-NEXT: $sgpr4_sgpr5_sgpr6 = COPY [[UV1]](<3 x s32>) -++ ; GCN-NEXT: $sgpr8_sgpr9_sgpr10 = COPY [[UV2]](<3 x s32>) -++ ; GCN-NEXT: $sgpr12_sgpr13_sgpr14 = COPY [[UV3]](<3 x s32>) -+ %0:sgpr(<3 x s32>) = COPY $sgpr0_sgpr1_sgpr2 -+ %1:sgpr(<3 x s32>) = COPY $sgpr4_sgpr5_sgpr6 -+ %2:sgpr(<3 x s32>) = COPY $sgpr8_sgpr9_sgpr10 -+diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-freeze.mir b/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-freeze.mir -+--- a/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-freeze.mir -++++ b/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-freeze.mir -+@@ -171,8 +171,12 @@ -+ -+ ; CHECK-LABEL: name: test_freeze_s448 -+ ; CHECK: [[COPY:%[0-9]+]]:_(s512) = COPY $vgpr0_vgpr1_vgpr2_vgpr3_vgpr4_vgpr5_vgpr6_vgpr7_vgpr8_vgpr9_vgpr10_vgpr11_vgpr12_vgpr13_vgpr14_vgpr15 -+- ; CHECK-NEXT: [[FREEZE:%[0-9]+]]:_(s512) = G_FREEZE [[COPY]] -+- ; CHECK-NEXT: $vgpr0_vgpr1_vgpr2_vgpr3_vgpr4_vgpr5_vgpr6_vgpr7_vgpr8_vgpr9_vgpr10_vgpr11_vgpr12_vgpr13_vgpr14_vgpr15 = COPY [[FREEZE]](s512) -++ ; CHECK-NEXT: [[TRUNC:%[0-9]+]]:_(s448) = G_TRUNC [[COPY]](s512) -++ ; CHECK-NEXT: [[FREEZE:%[0-9]+]]:_(s448) = G_FREEZE [[TRUNC]] -++ ; CHECK-NEXT: [[UV:%[0-9]+]]:_(s64), [[UV1:%[0-9]+]]:_(s64), [[UV2:%[0-9]+]]:_(s64), [[UV3:%[0-9]+]]:_(s64), [[UV4:%[0-9]+]]:_(s64), [[UV5:%[0-9]+]]:_(s64), [[UV6:%[0-9]+]]:_(s64) = G_UNMERGE_VALUES [[FREEZE]](s448) -++ ; CHECK-NEXT: [[DEF:%[0-9]+]]:_(s64) = G_IMPLICIT_DEF -++ ; CHECK-NEXT: [[MV:%[0-9]+]]:_(s512) = G_MERGE_VALUES [[UV]](s64), [[UV1]](s64), [[UV2]](s64), [[UV3]](s64), [[UV4]](s64), [[UV5]](s64), [[UV6]](s64), [[DEF]](s64) -++ ; CHECK-NEXT: $vgpr0_vgpr1_vgpr2_vgpr3_vgpr4_vgpr5_vgpr6_vgpr7_vgpr8_vgpr9_vgpr10_vgpr11_vgpr12_vgpr13_vgpr14_vgpr15 = COPY [[MV]](s512) -+ %0:_(s512) = COPY $vgpr0_vgpr1_vgpr2_vgpr3_vgpr4_vgpr5_vgpr6_vgpr7_vgpr8_vgpr9_vgpr10_vgpr11_vgpr12_vgpr13_vgpr14_vgpr15 -+ %1:_(s448) = G_TRUNC %0 -+ %2:_(s448) = G_FREEZE %1 -+@@ -395,12 +399,14 @@ -+ bb.0: -+ -+ ; CHECK-LABEL: name: test_freeze_v33s32 -+- ; CHECK: [[DEF:%[0-9]+]]:_(<32 x s32>) = G_IMPLICIT_DEF -++ ; CHECK: [[DEF:%[0-9]+]]:_(<16 x s32>) = G_IMPLICIT_DEF -+ ; CHECK-NEXT: [[DEF1:%[0-9]+]]:_(s32) = G_IMPLICIT_DEF -+- ; CHECK-NEXT: [[FREEZE:%[0-9]+]]:_(<32 x s32>) = G_FREEZE [[DEF]] -+- ; CHECK-NEXT: [[FREEZE1:%[0-9]+]]:_(s32) = G_FREEZE [[DEF1]] -+- ; CHECK-NEXT: [[UV:%[0-9]+]]:_(s32), [[UV1:%[0-9]+]]:_(s32), [[UV2:%[0-9]+]]:_(s32), [[UV3:%[0-9]+]]:_(s32), [[UV4:%[0-9]+]]:_(s32), [[UV5:%[0-9]+]]:_(s32), [[UV6:%[0-9]+]]:_(s32), [[UV7:%[0-9]+]]:_(s32), [[UV8:%[0-9]+]]:_(s32), [[UV9:%[0-9]+]]:_(s32), [[UV10:%[0-9]+]]:_(s32), [[UV11:%[0-9]+]]:_(s32), [[UV12:%[0-9]+]]:_(s32), [[UV13:%[0-9]+]]:_(s32), [[UV14:%[0-9]+]]:_(s32), [[UV15:%[0-9]+]]:_(s32), [[UV16:%[0-9]+]]:_(s32), [[UV17:%[0-9]+]]:_(s32), [[UV18:%[0-9]+]]:_(s32), [[UV19:%[0-9]+]]:_(s32), [[UV20:%[0-9]+]]:_(s32), [[UV21:%[0-9]+]]:_(s32), [[UV22:%[0-9]+]]:_(s32), [[UV23:%[0-9]+]]:_(s32), [[UV24:%[0-9]+]]:_(s32), [[UV25:%[0-9]+]]:_(s32), [[UV26:%[0-9]+]]:_(s32), [[UV27:%[0-9]+]]:_(s32), [[UV28:%[0-9]+]]:_(s32), [[UV29:%[0-9]+]]:_(s32), [[UV30:%[0-9]+]]:_(s32), [[UV31:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[FREEZE]](<32 x s32>) -+- ; CHECK-NEXT: [[BUILD_VECTOR:%[0-9]+]]:_(<33 x s32>) = G_BUILD_VECTOR [[UV]](s32), [[UV1]](s32), [[UV2]](s32), [[UV3]](s32), [[UV4]](s32), [[UV5]](s32), [[UV6]](s32), [[UV7]](s32), [[UV8]](s32), [[UV9]](s32), [[UV10]](s32), [[UV11]](s32), [[UV12]](s32), [[UV13]](s32), [[UV14]](s32), [[UV15]](s32), [[UV16]](s32), [[UV17]](s32), [[UV18]](s32), [[UV19]](s32), [[UV20]](s32), [[UV21]](s32), [[UV22]](s32), [[UV23]](s32), [[UV24]](s32), [[UV25]](s32), [[UV26]](s32), [[UV27]](s32), [[UV28]](s32), [[UV29]](s32), [[UV30]](s32), [[UV31]](s32), [[FREEZE1]](s32) -++ ; CHECK-NEXT: [[FREEZE:%[0-9]+]]:_(<16 x s32>) = G_FREEZE [[DEF]] -++ ; CHECK-NEXT: [[FREEZE1:%[0-9]+]]:_(<16 x s32>) = G_FREEZE [[DEF]] -++ ; CHECK-NEXT: [[FREEZE2:%[0-9]+]]:_(s32) = G_FREEZE [[DEF1]] -++ ; CHECK-NEXT: [[UV:%[0-9]+]]:_(s32), [[UV1:%[0-9]+]]:_(s32), [[UV2:%[0-9]+]]:_(s32), [[UV3:%[0-9]+]]:_(s32), [[UV4:%[0-9]+]]:_(s32), [[UV5:%[0-9]+]]:_(s32), [[UV6:%[0-9]+]]:_(s32), [[UV7:%[0-9]+]]:_(s32), [[UV8:%[0-9]+]]:_(s32), [[UV9:%[0-9]+]]:_(s32), [[UV10:%[0-9]+]]:_(s32), [[UV11:%[0-9]+]]:_(s32), [[UV12:%[0-9]+]]:_(s32), [[UV13:%[0-9]+]]:_(s32), [[UV14:%[0-9]+]]:_(s32), [[UV15:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[FREEZE]](<16 x s32>) -++ ; CHECK-NEXT: [[UV16:%[0-9]+]]:_(s32), [[UV17:%[0-9]+]]:_(s32), [[UV18:%[0-9]+]]:_(s32), [[UV19:%[0-9]+]]:_(s32), [[UV20:%[0-9]+]]:_(s32), [[UV21:%[0-9]+]]:_(s32), [[UV22:%[0-9]+]]:_(s32), [[UV23:%[0-9]+]]:_(s32), [[UV24:%[0-9]+]]:_(s32), [[UV25:%[0-9]+]]:_(s32), [[UV26:%[0-9]+]]:_(s32), [[UV27:%[0-9]+]]:_(s32), [[UV28:%[0-9]+]]:_(s32), [[UV29:%[0-9]+]]:_(s32), [[UV30:%[0-9]+]]:_(s32), [[UV31:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[FREEZE1]](<16 x s32>) -++ ; CHECK-NEXT: [[BUILD_VECTOR:%[0-9]+]]:_(<33 x s32>) = G_BUILD_VECTOR [[UV]](s32), [[UV1]](s32), [[UV2]](s32), [[UV3]](s32), [[UV4]](s32), [[UV5]](s32), [[UV6]](s32), [[UV7]](s32), [[UV8]](s32), [[UV9]](s32), [[UV10]](s32), [[UV11]](s32), [[UV12]](s32), [[UV13]](s32), [[UV14]](s32), [[UV15]](s32), [[UV16]](s32), [[UV17]](s32), [[UV18]](s32), [[UV19]](s32), [[UV20]](s32), [[UV21]](s32), [[UV22]](s32), [[UV23]](s32), [[UV24]](s32), [[UV25]](s32), [[UV26]](s32), [[UV27]](s32), [[UV28]](s32), [[UV29]](s32), [[UV30]](s32), [[UV31]](s32), [[FREEZE2]](s32) -+ ; CHECK-NEXT: S_NOP 0, implicit [[BUILD_VECTOR]](<33 x s32>) -+ %0:_(<33 x s32>) = G_IMPLICIT_DEF -+ %1:_(<33 x s32>) = G_FREEZE %0 -+@@ -413,10 +419,12 @@ -+ bb.0: -+ -+ ; CHECK-LABEL: name: test_freeze_v64s32 -+- ; CHECK: [[DEF:%[0-9]+]]:_(<32 x s32>) = G_IMPLICIT_DEF -+- ; CHECK-NEXT: [[FREEZE:%[0-9]+]]:_(<32 x s32>) = G_FREEZE [[DEF]] -+- ; CHECK-NEXT: [[FREEZE1:%[0-9]+]]:_(<32 x s32>) = G_FREEZE [[DEF]] -+- ; CHECK-NEXT: [[CONCAT_VECTORS:%[0-9]+]]:_(<64 x s32>) = G_CONCAT_VECTORS [[FREEZE]](<32 x s32>), [[FREEZE1]](<32 x s32>) -++ ; CHECK: [[DEF:%[0-9]+]]:_(<16 x s32>) = G_IMPLICIT_DEF -++ ; CHECK-NEXT: [[FREEZE:%[0-9]+]]:_(<16 x s32>) = G_FREEZE [[DEF]] -++ ; CHECK-NEXT: [[FREEZE1:%[0-9]+]]:_(<16 x s32>) = G_FREEZE [[DEF]] -++ ; CHECK-NEXT: [[FREEZE2:%[0-9]+]]:_(<16 x s32>) = G_FREEZE [[DEF]] -++ ; CHECK-NEXT: [[FREEZE3:%[0-9]+]]:_(<16 x s32>) = G_FREEZE [[DEF]] -++ ; CHECK-NEXT: [[CONCAT_VECTORS:%[0-9]+]]:_(<64 x s32>) = G_CONCAT_VECTORS [[FREEZE]](<16 x s32>), [[FREEZE1]](<16 x s32>), [[FREEZE2]](<16 x s32>), [[FREEZE3]](<16 x s32>) -+ ; CHECK-NEXT: S_NOP 0, implicit [[CONCAT_VECTORS]](<64 x s32>) -+ %0:_(<64 x s32>) = G_IMPLICIT_DEF -+ %1:_(<64 x s32>) = G_FREEZE %0 -+diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-implicit-def.mir b/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-implicit-def.mir -+--- a/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-implicit-def.mir -++++ b/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-implicit-def.mir -+@@ -135,9 +135,8 @@ -+ bb.0: -+ -+ ; CHECK-LABEL: name: test_implicit_def_s448 -+- ; CHECK: [[DEF:%[0-9]+]]:_(s512) = G_IMPLICIT_DEF -+- ; CHECK-NEXT: [[TRUNC:%[0-9]+]]:_(s448) = G_TRUNC [[DEF]](s512) -+- ; CHECK-NEXT: [[EXTRACT:%[0-9]+]]:_(s32) = G_EXTRACT [[TRUNC]](s448), 0 -++ ; CHECK: [[DEF:%[0-9]+]]:_(s448) = G_IMPLICIT_DEF -++ ; CHECK-NEXT: [[EXTRACT:%[0-9]+]]:_(s32) = G_EXTRACT [[DEF]](s448), 0 -+ ; CHECK-NEXT: $vgpr0 = COPY [[EXTRACT]](s32) -+ %0:_(s448) = G_IMPLICIT_DEF -+ %1:_(s32) = G_EXTRACT %0, 0 -+@@ -297,6 +296,18 @@ -+ ... -+ -+ --- -++name: test_implicit_def_v17s32 -++body: | -++ bb.0: -++ -++ ; CHECK-LABEL: name: test_implicit_def_v17s32 -++ ; CHECK: [[DEF:%[0-9]+]]:_(<17 x s32>) = G_IMPLICIT_DEF -++ ; CHECK-NEXT: S_NOP 0, implicit [[DEF]](<17 x s32>) -++ %0:_(<17 x s32>) = G_IMPLICIT_DEF -++ S_NOP 0, implicit %0 -++... -++ -++--- -+ name: test_implicit_def_v32s32 -+ body: | -+ bb.0: -+@@ -317,9 +328,9 @@ -+ ; CHECK-LABEL: name: test_implicit_def_v33s32 -+ ; CHECK: liveins: $vgpr0_vgpr1 -+ ; CHECK-NEXT: {{ $}} -+- ; CHECK-NEXT: [[DEF:%[0-9]+]]:_(<32 x s32>) = G_IMPLICIT_DEF -++ ; CHECK-NEXT: [[DEF:%[0-9]+]]:_(<16 x s32>) = G_IMPLICIT_DEF -+ ; CHECK-NEXT: [[DEF1:%[0-9]+]]:_(s32) = G_IMPLICIT_DEF -+- ; CHECK-NEXT: [[UV:%[0-9]+]]:_(s32), [[UV1:%[0-9]+]]:_(s32), [[UV2:%[0-9]+]]:_(s32), [[UV3:%[0-9]+]]:_(s32), [[UV4:%[0-9]+]]:_(s32), [[UV5:%[0-9]+]]:_(s32), [[UV6:%[0-9]+]]:_(s32), [[UV7:%[0-9]+]]:_(s32), [[UV8:%[0-9]+]]:_(s32), [[UV9:%[0-9]+]]:_(s32), [[UV10:%[0-9]+]]:_(s32), [[UV11:%[0-9]+]]:_(s32), [[UV12:%[0-9]+]]:_(s32), [[UV13:%[0-9]+]]:_(s32), [[UV14:%[0-9]+]]:_(s32), [[UV15:%[0-9]+]]:_(s32), [[UV16:%[0-9]+]]:_(s32), [[UV17:%[0-9]+]]:_(s32), [[UV18:%[0-9]+]]:_(s32), [[UV19:%[0-9]+]]:_(s32), [[UV20:%[0-9]+]]:_(s32), [[UV21:%[0-9]+]]:_(s32), [[UV22:%[0-9]+]]:_(s32), [[UV23:%[0-9]+]]:_(s32), [[UV24:%[0-9]+]]:_(s32), [[UV25:%[0-9]+]]:_(s32), [[UV26:%[0-9]+]]:_(s32), [[UV27:%[0-9]+]]:_(s32), [[UV28:%[0-9]+]]:_(s32), [[UV29:%[0-9]+]]:_(s32), [[UV30:%[0-9]+]]:_(s32), [[UV31:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<32 x s32>) -++ ; CHECK-NEXT: [[UV:%[0-9]+]]:_(s32), [[UV1:%[0-9]+]]:_(s32), [[UV2:%[0-9]+]]:_(s32), [[UV3:%[0-9]+]]:_(s32), [[UV4:%[0-9]+]]:_(s32), [[UV5:%[0-9]+]]:_(s32), [[UV6:%[0-9]+]]:_(s32), [[UV7:%[0-9]+]]:_(s32), [[UV8:%[0-9]+]]:_(s32), [[UV9:%[0-9]+]]:_(s32), [[UV10:%[0-9]+]]:_(s32), [[UV11:%[0-9]+]]:_(s32), [[UV12:%[0-9]+]]:_(s32), [[UV13:%[0-9]+]]:_(s32), [[UV14:%[0-9]+]]:_(s32), [[UV15:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) -+ ; CHECK-NEXT: [[COPY:%[0-9]+]]:_(p1) = COPY $vgpr0_vgpr1 -+ ; CHECK-NEXT: G_STORE [[UV]](s32), [[COPY]](p1) :: (volatile store (s32), addrspace 1) -+ ; CHECK-NEXT: G_STORE [[DEF1]](s32), [[COPY]](p1) :: (volatile store (s32), addrspace 1) -+@@ -337,9 +348,10 @@ -+ bb.0: -+ -+ ; CHECK-LABEL: name: test_implicit_def_v64s32 -+- ; CHECK: [[DEF:%[0-9]+]]:_(<32 x s32>) = G_IMPLICIT_DEF -+- ; CHECK-NEXT: [[CONCAT_VECTORS:%[0-9]+]]:_(<64 x s32>) = G_CONCAT_VECTORS [[DEF]](<32 x s32>), [[DEF]](<32 x s32>) -+- ; CHECK-NEXT: S_NOP 0, implicit [[CONCAT_VECTORS]](<64 x s32>), implicit [[DEF]](<32 x s32>) -++ ; CHECK: [[DEF:%[0-9]+]]:_(<16 x s32>) = G_IMPLICIT_DEF -++ ; CHECK-NEXT: [[CONCAT_VECTORS:%[0-9]+]]:_(<64 x s32>) = G_CONCAT_VECTORS [[DEF]](<16 x s32>), [[DEF]](<16 x s32>), [[DEF]](<16 x s32>), [[DEF]](<16 x s32>) -++ ; CHECK-NEXT: [[CONCAT_VECTORS1:%[0-9]+]]:_(<32 x s32>) = G_CONCAT_VECTORS [[DEF]](<16 x s32>), [[DEF]](<16 x s32>) -++ ; CHECK-NEXT: S_NOP 0, implicit [[CONCAT_VECTORS]](<64 x s32>), implicit [[CONCAT_VECTORS1]](<32 x s32>) -+ %0:_(<64 x s32>) = G_IMPLICIT_DEF -+ %1:_(<32 x s32>), %2:_(<32 x s32>) = G_UNMERGE_VALUES %0 -+ S_NOP 0, implicit %0, implicit %1 -+diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-insert-vector-elt.mir b/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-insert-vector-elt.mir -+--- a/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-insert-vector-elt.mir -++++ b/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-insert-vector-elt.mir -+@@ -190,11 +190,13 @@ -+ ; CHECK-LABEL: name: insert_vector_elt_64_65_v64s32 -+ ; CHECK: liveins: $sgpr0_sgpr1, $vgpr0_vgpr1, $vgpr2_vgpr3 -+ ; CHECK-NEXT: {{ $}} -+- ; CHECK-NEXT: [[DEF:%[0-9]+]]:_(<32 x s32>) = G_IMPLICIT_DEF -++ ; CHECK-NEXT: [[DEF:%[0-9]+]]:_(<16 x s32>) = G_IMPLICIT_DEF -+ ; CHECK-NEXT: [[COPY:%[0-9]+]]:_(p1) = COPY $vgpr0_vgpr1 -+ ; CHECK-NEXT: [[COPY1:%[0-9]+]]:_(p1) = COPY $vgpr2_vgpr3 -+- ; CHECK-NEXT: [[UV:%[0-9]+]]:_(<4 x s32>), [[UV1:%[0-9]+]]:_(<4 x s32>), [[UV2:%[0-9]+]]:_(<4 x s32>), [[UV3:%[0-9]+]]:_(<4 x s32>), [[UV4:%[0-9]+]]:_(<4 x s32>), [[UV5:%[0-9]+]]:_(<4 x s32>), [[UV6:%[0-9]+]]:_(<4 x s32>), [[UV7:%[0-9]+]]:_(<4 x s32>) = G_UNMERGE_VALUES [[DEF]](<32 x s32>) -+- ; CHECK-NEXT: [[UV8:%[0-9]+]]:_(<4 x s32>), [[UV9:%[0-9]+]]:_(<4 x s32>), [[UV10:%[0-9]+]]:_(<4 x s32>), [[UV11:%[0-9]+]]:_(<4 x s32>), [[UV12:%[0-9]+]]:_(<4 x s32>), [[UV13:%[0-9]+]]:_(<4 x s32>), [[UV14:%[0-9]+]]:_(<4 x s32>), [[UV15:%[0-9]+]]:_(<4 x s32>) = G_UNMERGE_VALUES [[DEF]](<32 x s32>) -++ ; CHECK-NEXT: [[UV:%[0-9]+]]:_(<4 x s32>), [[UV1:%[0-9]+]]:_(<4 x s32>), [[UV2:%[0-9]+]]:_(<4 x s32>), [[UV3:%[0-9]+]]:_(<4 x s32>) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) -++ ; CHECK-NEXT: [[UV4:%[0-9]+]]:_(<4 x s32>), [[UV5:%[0-9]+]]:_(<4 x s32>), [[UV6:%[0-9]+]]:_(<4 x s32>), [[UV7:%[0-9]+]]:_(<4 x s32>) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) -++ ; CHECK-NEXT: [[UV8:%[0-9]+]]:_(<4 x s32>), [[UV9:%[0-9]+]]:_(<4 x s32>), [[UV10:%[0-9]+]]:_(<4 x s32>), [[UV11:%[0-9]+]]:_(<4 x s32>) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) -++ ; CHECK-NEXT: [[UV12:%[0-9]+]]:_(<4 x s32>), [[UV13:%[0-9]+]]:_(<4 x s32>), [[UV14:%[0-9]+]]:_(<4 x s32>), [[UV15:%[0-9]+]]:_(<4 x s32>) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) -+ ; CHECK-NEXT: G_STORE [[UV]](<4 x s32>), [[COPY]](p1) :: (store (<4 x s32>), align 4, addrspace 1) -+ ; CHECK-NEXT: [[C:%[0-9]+]]:_(s64) = G_CONSTANT i64 16 -+ ; CHECK-NEXT: [[PTR_ADD:%[0-9]+]]:_(p1) = G_PTR_ADD [[COPY]], [[C]](s64) -+@@ -241,8 +243,10 @@ -+ ; CHECK-NEXT: [[C14:%[0-9]+]]:_(s64) = G_CONSTANT i64 240 -+ ; CHECK-NEXT: [[PTR_ADD14:%[0-9]+]]:_(p1) = G_PTR_ADD [[COPY]], [[C14]](s64) -+ ; CHECK-NEXT: G_STORE [[UV15]](<4 x s32>), [[PTR_ADD14]](p1) :: (store (<4 x s32>) into unknown-address + 240, align 4, addrspace 1) -+- ; CHECK-NEXT: [[UV16:%[0-9]+]]:_(<4 x s32>), [[UV17:%[0-9]+]]:_(<4 x s32>), [[UV18:%[0-9]+]]:_(<4 x s32>), [[UV19:%[0-9]+]]:_(<4 x s32>), [[UV20:%[0-9]+]]:_(<4 x s32>), [[UV21:%[0-9]+]]:_(<4 x s32>), [[UV22:%[0-9]+]]:_(<4 x s32>), [[UV23:%[0-9]+]]:_(<4 x s32>) = G_UNMERGE_VALUES [[DEF]](<32 x s32>) -+- ; CHECK-NEXT: [[UV24:%[0-9]+]]:_(<4 x s32>), [[UV25:%[0-9]+]]:_(<4 x s32>), [[UV26:%[0-9]+]]:_(<4 x s32>), [[UV27:%[0-9]+]]:_(<4 x s32>), [[UV28:%[0-9]+]]:_(<4 x s32>), [[UV29:%[0-9]+]]:_(<4 x s32>), [[UV30:%[0-9]+]]:_(<4 x s32>), [[UV31:%[0-9]+]]:_(<4 x s32>) = G_UNMERGE_VALUES [[DEF]](<32 x s32>) -++ ; CHECK-NEXT: [[UV16:%[0-9]+]]:_(<4 x s32>), [[UV17:%[0-9]+]]:_(<4 x s32>), [[UV18:%[0-9]+]]:_(<4 x s32>), [[UV19:%[0-9]+]]:_(<4 x s32>) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) -++ ; CHECK-NEXT: [[UV20:%[0-9]+]]:_(<4 x s32>), [[UV21:%[0-9]+]]:_(<4 x s32>), [[UV22:%[0-9]+]]:_(<4 x s32>), [[UV23:%[0-9]+]]:_(<4 x s32>) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) -++ ; CHECK-NEXT: [[UV24:%[0-9]+]]:_(<4 x s32>), [[UV25:%[0-9]+]]:_(<4 x s32>), [[UV26:%[0-9]+]]:_(<4 x s32>), [[UV27:%[0-9]+]]:_(<4 x s32>) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) -++ ; CHECK-NEXT: [[UV28:%[0-9]+]]:_(<4 x s32>), [[UV29:%[0-9]+]]:_(<4 x s32>), [[UV30:%[0-9]+]]:_(<4 x s32>), [[UV31:%[0-9]+]]:_(<4 x s32>) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) -+ ; CHECK-NEXT: G_STORE [[UV16]](<4 x s32>), [[COPY1]](p1) :: (store (<4 x s32>), align 4, addrspace 1) -+ ; CHECK-NEXT: [[PTR_ADD15:%[0-9]+]]:_(p1) = G_PTR_ADD [[COPY1]], [[C]](s64) -+ ; CHECK-NEXT: G_STORE [[UV17]](<4 x s32>), [[PTR_ADD15]](p1) :: (store (<4 x s32>) into unknown-address + 16, align 4, addrspace 1) -+diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-phi.mir b/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-phi.mir -+--- a/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-phi.mir -++++ b/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-phi.mir -+@@ -673,86 +673,88 @@ -+ ; CHECK-NEXT: successors: %bb.1(0x40000000), %bb.2(0x40000000) -+ ; CHECK-NEXT: liveins: $vgpr0_vgpr1_vgpr2_vgpr3, $vgpr4 -+ ; CHECK-NEXT: {{ $}} -+- ; CHECK-NEXT: [[DEF:%[0-9]+]]:_(<32 x s32>) = G_IMPLICIT_DEF -++ ; CHECK-NEXT: [[DEF:%[0-9]+]]:_(<16 x s32>) = G_IMPLICIT_DEF -+ ; CHECK-NEXT: [[COPY:%[0-9]+]]:_(s32) = COPY $vgpr4 -+ ; CHECK-NEXT: [[C:%[0-9]+]]:_(s32) = G_CONSTANT i32 0 -+ ; CHECK-NEXT: [[ICMP:%[0-9]+]]:_(s1) = G_ICMP intpred(eq), [[COPY]](s32), [[C]] -+- ; CHECK-NEXT: [[UV:%[0-9]+]]:_(<16 x s32>), [[UV1:%[0-9]+]]:_(<16 x s32>) = G_UNMERGE_VALUES [[DEF]](<32 x s32>) -+- ; CHECK-NEXT: [[UV2:%[0-9]+]]:_(<16 x s32>), [[UV3:%[0-9]+]]:_(<16 x s32>) = G_UNMERGE_VALUES [[DEF]](<32 x s32>) -+ ; CHECK-NEXT: G_BRCOND [[ICMP]](s1), %bb.1 -+ ; CHECK-NEXT: G_BR %bb.2 -+ ; CHECK-NEXT: {{ $}} -+ ; CHECK-NEXT: bb.1: -+ ; CHECK-NEXT: successors: %bb.2(0x80000000) -+ ; CHECK-NEXT: {{ $}} -+- ; CHECK-NEXT: [[UV4:%[0-9]+]]:_(s32), [[UV5:%[0-9]+]]:_(s32), [[UV6:%[0-9]+]]:_(s32), [[UV7:%[0-9]+]]:_(s32), [[UV8:%[0-9]+]]:_(s32), [[UV9:%[0-9]+]]:_(s32), [[UV10:%[0-9]+]]:_(s32), [[UV11:%[0-9]+]]:_(s32), [[UV12:%[0-9]+]]:_(s32), [[UV13:%[0-9]+]]:_(s32), [[UV14:%[0-9]+]]:_(s32), [[UV15:%[0-9]+]]:_(s32), [[UV16:%[0-9]+]]:_(s32), [[UV17:%[0-9]+]]:_(s32), [[UV18:%[0-9]+]]:_(s32), [[UV19:%[0-9]+]]:_(s32), [[UV20:%[0-9]+]]:_(s32), [[UV21:%[0-9]+]]:_(s32), [[UV22:%[0-9]+]]:_(s32), [[UV23:%[0-9]+]]:_(s32), [[UV24:%[0-9]+]]:_(s32), [[UV25:%[0-9]+]]:_(s32), [[UV26:%[0-9]+]]:_(s32), [[UV27:%[0-9]+]]:_(s32), [[UV28:%[0-9]+]]:_(s32), [[UV29:%[0-9]+]]:_(s32), [[UV30:%[0-9]+]]:_(s32), [[UV31:%[0-9]+]]:_(s32), [[UV32:%[0-9]+]]:_(s32), [[UV33:%[0-9]+]]:_(s32), [[UV34:%[0-9]+]]:_(s32), [[UV35:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<32 x s32>) -+- ; CHECK-NEXT: [[UV36:%[0-9]+]]:_(s32), [[UV37:%[0-9]+]]:_(s32), [[UV38:%[0-9]+]]:_(s32), [[UV39:%[0-9]+]]:_(s32), [[UV40:%[0-9]+]]:_(s32), [[UV41:%[0-9]+]]:_(s32), [[UV42:%[0-9]+]]:_(s32), [[UV43:%[0-9]+]]:_(s32), [[UV44:%[0-9]+]]:_(s32), [[UV45:%[0-9]+]]:_(s32), [[UV46:%[0-9]+]]:_(s32), [[UV47:%[0-9]+]]:_(s32), [[UV48:%[0-9]+]]:_(s32), [[UV49:%[0-9]+]]:_(s32), [[UV50:%[0-9]+]]:_(s32), [[UV51:%[0-9]+]]:_(s32), [[UV52:%[0-9]+]]:_(s32), [[UV53:%[0-9]+]]:_(s32), [[UV54:%[0-9]+]]:_(s32), [[UV55:%[0-9]+]]:_(s32), [[UV56:%[0-9]+]]:_(s32), [[UV57:%[0-9]+]]:_(s32), [[UV58:%[0-9]+]]:_(s32), [[UV59:%[0-9]+]]:_(s32), [[UV60:%[0-9]+]]:_(s32), [[UV61:%[0-9]+]]:_(s32), [[UV62:%[0-9]+]]:_(s32), [[UV63:%[0-9]+]]:_(s32), [[UV64:%[0-9]+]]:_(s32), [[UV65:%[0-9]+]]:_(s32), [[UV66:%[0-9]+]]:_(s32), [[UV67:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<32 x s32>) -+- ; CHECK-NEXT: [[UV68:%[0-9]+]]:_(s32), [[UV69:%[0-9]+]]:_(s32), [[UV70:%[0-9]+]]:_(s32), [[UV71:%[0-9]+]]:_(s32), [[UV72:%[0-9]+]]:_(s32), [[UV73:%[0-9]+]]:_(s32), [[UV74:%[0-9]+]]:_(s32), [[UV75:%[0-9]+]]:_(s32), [[UV76:%[0-9]+]]:_(s32), [[UV77:%[0-9]+]]:_(s32), [[UV78:%[0-9]+]]:_(s32), [[UV79:%[0-9]+]]:_(s32), [[UV80:%[0-9]+]]:_(s32), [[UV81:%[0-9]+]]:_(s32), [[UV82:%[0-9]+]]:_(s32), [[UV83:%[0-9]+]]:_(s32), [[UV84:%[0-9]+]]:_(s32), [[UV85:%[0-9]+]]:_(s32), [[UV86:%[0-9]+]]:_(s32), [[UV87:%[0-9]+]]:_(s32), [[UV88:%[0-9]+]]:_(s32), [[UV89:%[0-9]+]]:_(s32), [[UV90:%[0-9]+]]:_(s32), [[UV91:%[0-9]+]]:_(s32), [[UV92:%[0-9]+]]:_(s32), [[UV93:%[0-9]+]]:_(s32), [[UV94:%[0-9]+]]:_(s32), [[UV95:%[0-9]+]]:_(s32), [[UV96:%[0-9]+]]:_(s32), [[UV97:%[0-9]+]]:_(s32), [[UV98:%[0-9]+]]:_(s32), [[UV99:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<32 x s32>) -+- ; CHECK-NEXT: [[UV100:%[0-9]+]]:_(s32), [[UV101:%[0-9]+]]:_(s32), [[UV102:%[0-9]+]]:_(s32), [[UV103:%[0-9]+]]:_(s32), [[UV104:%[0-9]+]]:_(s32), [[UV105:%[0-9]+]]:_(s32), [[UV106:%[0-9]+]]:_(s32), [[UV107:%[0-9]+]]:_(s32), [[UV108:%[0-9]+]]:_(s32), [[UV109:%[0-9]+]]:_(s32), [[UV110:%[0-9]+]]:_(s32), [[UV111:%[0-9]+]]:_(s32), [[UV112:%[0-9]+]]:_(s32), [[UV113:%[0-9]+]]:_(s32), [[UV114:%[0-9]+]]:_(s32), [[UV115:%[0-9]+]]:_(s32), [[UV116:%[0-9]+]]:_(s32), [[UV117:%[0-9]+]]:_(s32), [[UV118:%[0-9]+]]:_(s32), [[UV119:%[0-9]+]]:_(s32), [[UV120:%[0-9]+]]:_(s32), [[UV121:%[0-9]+]]:_(s32), [[UV122:%[0-9]+]]:_(s32), [[UV123:%[0-9]+]]:_(s32), [[UV124:%[0-9]+]]:_(s32), [[UV125:%[0-9]+]]:_(s32), [[UV126:%[0-9]+]]:_(s32), [[UV127:%[0-9]+]]:_(s32), [[UV128:%[0-9]+]]:_(s32), [[UV129:%[0-9]+]]:_(s32), [[UV130:%[0-9]+]]:_(s32), [[UV131:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<32 x s32>) -+- ; CHECK-NEXT: [[ADD:%[0-9]+]]:_(s32) = G_ADD [[UV4]], [[UV68]] -+- ; CHECK-NEXT: [[ADD1:%[0-9]+]]:_(s32) = G_ADD [[UV5]], [[UV69]] -+- ; CHECK-NEXT: [[ADD2:%[0-9]+]]:_(s32) = G_ADD [[UV6]], [[UV70]] -+- ; CHECK-NEXT: [[ADD3:%[0-9]+]]:_(s32) = G_ADD [[UV7]], [[UV71]] -+- ; CHECK-NEXT: [[ADD4:%[0-9]+]]:_(s32) = G_ADD [[UV8]], [[UV72]] -+- ; CHECK-NEXT: [[ADD5:%[0-9]+]]:_(s32) = G_ADD [[UV9]], [[UV73]] -+- ; CHECK-NEXT: [[ADD6:%[0-9]+]]:_(s32) = G_ADD [[UV10]], [[UV74]] -+- ; CHECK-NEXT: [[ADD7:%[0-9]+]]:_(s32) = G_ADD [[UV11]], [[UV75]] -+- ; CHECK-NEXT: [[ADD8:%[0-9]+]]:_(s32) = G_ADD [[UV12]], [[UV76]] -+- ; CHECK-NEXT: [[ADD9:%[0-9]+]]:_(s32) = G_ADD [[UV13]], [[UV77]] -+- ; CHECK-NEXT: [[ADD10:%[0-9]+]]:_(s32) = G_ADD [[UV14]], [[UV78]] -+- ; CHECK-NEXT: [[ADD11:%[0-9]+]]:_(s32) = G_ADD [[UV15]], [[UV79]] -+- ; CHECK-NEXT: [[ADD12:%[0-9]+]]:_(s32) = G_ADD [[UV16]], [[UV80]] -+- ; CHECK-NEXT: [[ADD13:%[0-9]+]]:_(s32) = G_ADD [[UV17]], [[UV81]] -+- ; CHECK-NEXT: [[ADD14:%[0-9]+]]:_(s32) = G_ADD [[UV18]], [[UV82]] -+- ; CHECK-NEXT: [[ADD15:%[0-9]+]]:_(s32) = G_ADD [[UV19]], [[UV83]] -+- ; CHECK-NEXT: [[ADD16:%[0-9]+]]:_(s32) = G_ADD [[UV20]], [[UV84]] -+- ; CHECK-NEXT: [[ADD17:%[0-9]+]]:_(s32) = G_ADD [[UV21]], [[UV85]] -+- ; CHECK-NEXT: [[ADD18:%[0-9]+]]:_(s32) = G_ADD [[UV22]], [[UV86]] -+- ; CHECK-NEXT: [[ADD19:%[0-9]+]]:_(s32) = G_ADD [[UV23]], [[UV87]] -+- ; CHECK-NEXT: [[ADD20:%[0-9]+]]:_(s32) = G_ADD [[UV24]], [[UV88]] -+- ; CHECK-NEXT: [[ADD21:%[0-9]+]]:_(s32) = G_ADD [[UV25]], [[UV89]] -+- ; CHECK-NEXT: [[ADD22:%[0-9]+]]:_(s32) = G_ADD [[UV26]], [[UV90]] -+- ; CHECK-NEXT: [[ADD23:%[0-9]+]]:_(s32) = G_ADD [[UV27]], [[UV91]] -+- ; CHECK-NEXT: [[ADD24:%[0-9]+]]:_(s32) = G_ADD [[UV28]], [[UV92]] -+- ; CHECK-NEXT: [[ADD25:%[0-9]+]]:_(s32) = G_ADD [[UV29]], [[UV93]] -+- ; CHECK-NEXT: [[ADD26:%[0-9]+]]:_(s32) = G_ADD [[UV30]], [[UV94]] -+- ; CHECK-NEXT: [[ADD27:%[0-9]+]]:_(s32) = G_ADD [[UV31]], [[UV95]] -+- ; CHECK-NEXT: [[ADD28:%[0-9]+]]:_(s32) = G_ADD [[UV32]], [[UV96]] -+- ; CHECK-NEXT: [[ADD29:%[0-9]+]]:_(s32) = G_ADD [[UV33]], [[UV97]] -+- ; CHECK-NEXT: [[ADD30:%[0-9]+]]:_(s32) = G_ADD [[UV34]], [[UV98]] -+- ; CHECK-NEXT: [[ADD31:%[0-9]+]]:_(s32) = G_ADD [[UV35]], [[UV99]] -+- ; CHECK-NEXT: [[ADD32:%[0-9]+]]:_(s32) = G_ADD [[UV36]], [[UV100]] -+- ; CHECK-NEXT: [[ADD33:%[0-9]+]]:_(s32) = G_ADD [[UV37]], [[UV101]] -+- ; CHECK-NEXT: [[ADD34:%[0-9]+]]:_(s32) = G_ADD [[UV38]], [[UV102]] -+- ; CHECK-NEXT: [[ADD35:%[0-9]+]]:_(s32) = G_ADD [[UV39]], [[UV103]] -+- ; CHECK-NEXT: [[ADD36:%[0-9]+]]:_(s32) = G_ADD [[UV40]], [[UV104]] -+- ; CHECK-NEXT: [[ADD37:%[0-9]+]]:_(s32) = G_ADD [[UV41]], [[UV105]] -+- ; CHECK-NEXT: [[ADD38:%[0-9]+]]:_(s32) = G_ADD [[UV42]], [[UV106]] -+- ; CHECK-NEXT: [[ADD39:%[0-9]+]]:_(s32) = G_ADD [[UV43]], [[UV107]] -+- ; CHECK-NEXT: [[ADD40:%[0-9]+]]:_(s32) = G_ADD [[UV44]], [[UV108]] -+- ; CHECK-NEXT: [[ADD41:%[0-9]+]]:_(s32) = G_ADD [[UV45]], [[UV109]] -+- ; CHECK-NEXT: [[ADD42:%[0-9]+]]:_(s32) = G_ADD [[UV46]], [[UV110]] -+- ; CHECK-NEXT: [[ADD43:%[0-9]+]]:_(s32) = G_ADD [[UV47]], [[UV111]] -+- ; CHECK-NEXT: [[ADD44:%[0-9]+]]:_(s32) = G_ADD [[UV48]], [[UV112]] -+- ; CHECK-NEXT: [[ADD45:%[0-9]+]]:_(s32) = G_ADD [[UV49]], [[UV113]] -+- ; CHECK-NEXT: [[ADD46:%[0-9]+]]:_(s32) = G_ADD [[UV50]], [[UV114]] -+- ; CHECK-NEXT: [[ADD47:%[0-9]+]]:_(s32) = G_ADD [[UV51]], [[UV115]] -+- ; CHECK-NEXT: [[ADD48:%[0-9]+]]:_(s32) = G_ADD [[UV52]], [[UV116]] -+- ; CHECK-NEXT: [[ADD49:%[0-9]+]]:_(s32) = G_ADD [[UV53]], [[UV117]] -+- ; CHECK-NEXT: [[ADD50:%[0-9]+]]:_(s32) = G_ADD [[UV54]], [[UV118]] -+- ; CHECK-NEXT: [[ADD51:%[0-9]+]]:_(s32) = G_ADD [[UV55]], [[UV119]] -+- ; CHECK-NEXT: [[ADD52:%[0-9]+]]:_(s32) = G_ADD [[UV56]], [[UV120]] -+- ; CHECK-NEXT: [[ADD53:%[0-9]+]]:_(s32) = G_ADD [[UV57]], [[UV121]] -+- ; CHECK-NEXT: [[ADD54:%[0-9]+]]:_(s32) = G_ADD [[UV58]], [[UV122]] -+- ; CHECK-NEXT: [[ADD55:%[0-9]+]]:_(s32) = G_ADD [[UV59]], [[UV123]] -+- ; CHECK-NEXT: [[ADD56:%[0-9]+]]:_(s32) = G_ADD [[UV60]], [[UV124]] -+- ; CHECK-NEXT: [[ADD57:%[0-9]+]]:_(s32) = G_ADD [[UV61]], [[UV125]] -+- ; CHECK-NEXT: [[ADD58:%[0-9]+]]:_(s32) = G_ADD [[UV62]], [[UV126]] -+- ; CHECK-NEXT: [[ADD59:%[0-9]+]]:_(s32) = G_ADD [[UV63]], [[UV127]] -+- ; CHECK-NEXT: [[ADD60:%[0-9]+]]:_(s32) = G_ADD [[UV64]], [[UV128]] -+- ; CHECK-NEXT: [[ADD61:%[0-9]+]]:_(s32) = G_ADD [[UV65]], [[UV129]] -+- ; CHECK-NEXT: [[ADD62:%[0-9]+]]:_(s32) = G_ADD [[UV66]], [[UV130]] -+- ; CHECK-NEXT: [[ADD63:%[0-9]+]]:_(s32) = G_ADD [[UV67]], [[UV131]] -++ ; CHECK-NEXT: [[UV:%[0-9]+]]:_(s32), [[UV1:%[0-9]+]]:_(s32), [[UV2:%[0-9]+]]:_(s32), [[UV3:%[0-9]+]]:_(s32), [[UV4:%[0-9]+]]:_(s32), [[UV5:%[0-9]+]]:_(s32), [[UV6:%[0-9]+]]:_(s32), [[UV7:%[0-9]+]]:_(s32), [[UV8:%[0-9]+]]:_(s32), [[UV9:%[0-9]+]]:_(s32), [[UV10:%[0-9]+]]:_(s32), [[UV11:%[0-9]+]]:_(s32), [[UV12:%[0-9]+]]:_(s32), [[UV13:%[0-9]+]]:_(s32), [[UV14:%[0-9]+]]:_(s32), [[UV15:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) -++ ; CHECK-NEXT: [[UV16:%[0-9]+]]:_(s32), [[UV17:%[0-9]+]]:_(s32), [[UV18:%[0-9]+]]:_(s32), [[UV19:%[0-9]+]]:_(s32), [[UV20:%[0-9]+]]:_(s32), [[UV21:%[0-9]+]]:_(s32), [[UV22:%[0-9]+]]:_(s32), [[UV23:%[0-9]+]]:_(s32), [[UV24:%[0-9]+]]:_(s32), [[UV25:%[0-9]+]]:_(s32), [[UV26:%[0-9]+]]:_(s32), [[UV27:%[0-9]+]]:_(s32), [[UV28:%[0-9]+]]:_(s32), [[UV29:%[0-9]+]]:_(s32), [[UV30:%[0-9]+]]:_(s32), [[UV31:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) -++ ; CHECK-NEXT: [[UV32:%[0-9]+]]:_(s32), [[UV33:%[0-9]+]]:_(s32), [[UV34:%[0-9]+]]:_(s32), [[UV35:%[0-9]+]]:_(s32), [[UV36:%[0-9]+]]:_(s32), [[UV37:%[0-9]+]]:_(s32), [[UV38:%[0-9]+]]:_(s32), [[UV39:%[0-9]+]]:_(s32), [[UV40:%[0-9]+]]:_(s32), [[UV41:%[0-9]+]]:_(s32), [[UV42:%[0-9]+]]:_(s32), [[UV43:%[0-9]+]]:_(s32), [[UV44:%[0-9]+]]:_(s32), [[UV45:%[0-9]+]]:_(s32), [[UV46:%[0-9]+]]:_(s32), [[UV47:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) -++ ; CHECK-NEXT: [[UV48:%[0-9]+]]:_(s32), [[UV49:%[0-9]+]]:_(s32), [[UV50:%[0-9]+]]:_(s32), [[UV51:%[0-9]+]]:_(s32), [[UV52:%[0-9]+]]:_(s32), [[UV53:%[0-9]+]]:_(s32), [[UV54:%[0-9]+]]:_(s32), [[UV55:%[0-9]+]]:_(s32), [[UV56:%[0-9]+]]:_(s32), [[UV57:%[0-9]+]]:_(s32), [[UV58:%[0-9]+]]:_(s32), [[UV59:%[0-9]+]]:_(s32), [[UV60:%[0-9]+]]:_(s32), [[UV61:%[0-9]+]]:_(s32), [[UV62:%[0-9]+]]:_(s32), [[UV63:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) -++ ; CHECK-NEXT: [[UV64:%[0-9]+]]:_(s32), [[UV65:%[0-9]+]]:_(s32), [[UV66:%[0-9]+]]:_(s32), [[UV67:%[0-9]+]]:_(s32), [[UV68:%[0-9]+]]:_(s32), [[UV69:%[0-9]+]]:_(s32), [[UV70:%[0-9]+]]:_(s32), [[UV71:%[0-9]+]]:_(s32), [[UV72:%[0-9]+]]:_(s32), [[UV73:%[0-9]+]]:_(s32), [[UV74:%[0-9]+]]:_(s32), [[UV75:%[0-9]+]]:_(s32), [[UV76:%[0-9]+]]:_(s32), [[UV77:%[0-9]+]]:_(s32), [[UV78:%[0-9]+]]:_(s32), [[UV79:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) -++ ; CHECK-NEXT: [[UV80:%[0-9]+]]:_(s32), [[UV81:%[0-9]+]]:_(s32), [[UV82:%[0-9]+]]:_(s32), [[UV83:%[0-9]+]]:_(s32), [[UV84:%[0-9]+]]:_(s32), [[UV85:%[0-9]+]]:_(s32), [[UV86:%[0-9]+]]:_(s32), [[UV87:%[0-9]+]]:_(s32), [[UV88:%[0-9]+]]:_(s32), [[UV89:%[0-9]+]]:_(s32), [[UV90:%[0-9]+]]:_(s32), [[UV91:%[0-9]+]]:_(s32), [[UV92:%[0-9]+]]:_(s32), [[UV93:%[0-9]+]]:_(s32), [[UV94:%[0-9]+]]:_(s32), [[UV95:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) -++ ; CHECK-NEXT: [[UV96:%[0-9]+]]:_(s32), [[UV97:%[0-9]+]]:_(s32), [[UV98:%[0-9]+]]:_(s32), [[UV99:%[0-9]+]]:_(s32), [[UV100:%[0-9]+]]:_(s32), [[UV101:%[0-9]+]]:_(s32), [[UV102:%[0-9]+]]:_(s32), [[UV103:%[0-9]+]]:_(s32), [[UV104:%[0-9]+]]:_(s32), [[UV105:%[0-9]+]]:_(s32), [[UV106:%[0-9]+]]:_(s32), [[UV107:%[0-9]+]]:_(s32), [[UV108:%[0-9]+]]:_(s32), [[UV109:%[0-9]+]]:_(s32), [[UV110:%[0-9]+]]:_(s32), [[UV111:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) -++ ; CHECK-NEXT: [[UV112:%[0-9]+]]:_(s32), [[UV113:%[0-9]+]]:_(s32), [[UV114:%[0-9]+]]:_(s32), [[UV115:%[0-9]+]]:_(s32), [[UV116:%[0-9]+]]:_(s32), [[UV117:%[0-9]+]]:_(s32), [[UV118:%[0-9]+]]:_(s32), [[UV119:%[0-9]+]]:_(s32), [[UV120:%[0-9]+]]:_(s32), [[UV121:%[0-9]+]]:_(s32), [[UV122:%[0-9]+]]:_(s32), [[UV123:%[0-9]+]]:_(s32), [[UV124:%[0-9]+]]:_(s32), [[UV125:%[0-9]+]]:_(s32), [[UV126:%[0-9]+]]:_(s32), [[UV127:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) -++ ; CHECK-NEXT: [[ADD:%[0-9]+]]:_(s32) = G_ADD [[UV]], [[UV64]] -++ ; CHECK-NEXT: [[ADD1:%[0-9]+]]:_(s32) = G_ADD [[UV1]], [[UV65]] -++ ; CHECK-NEXT: [[ADD2:%[0-9]+]]:_(s32) = G_ADD [[UV2]], [[UV66]] -++ ; CHECK-NEXT: [[ADD3:%[0-9]+]]:_(s32) = G_ADD [[UV3]], [[UV67]] -++ ; CHECK-NEXT: [[ADD4:%[0-9]+]]:_(s32) = G_ADD [[UV4]], [[UV68]] -++ ; CHECK-NEXT: [[ADD5:%[0-9]+]]:_(s32) = G_ADD [[UV5]], [[UV69]] -++ ; CHECK-NEXT: [[ADD6:%[0-9]+]]:_(s32) = G_ADD [[UV6]], [[UV70]] -++ ; CHECK-NEXT: [[ADD7:%[0-9]+]]:_(s32) = G_ADD [[UV7]], [[UV71]] -++ ; CHECK-NEXT: [[ADD8:%[0-9]+]]:_(s32) = G_ADD [[UV8]], [[UV72]] -++ ; CHECK-NEXT: [[ADD9:%[0-9]+]]:_(s32) = G_ADD [[UV9]], [[UV73]] -++ ; CHECK-NEXT: [[ADD10:%[0-9]+]]:_(s32) = G_ADD [[UV10]], [[UV74]] -++ ; CHECK-NEXT: [[ADD11:%[0-9]+]]:_(s32) = G_ADD [[UV11]], [[UV75]] -++ ; CHECK-NEXT: [[ADD12:%[0-9]+]]:_(s32) = G_ADD [[UV12]], [[UV76]] -++ ; CHECK-NEXT: [[ADD13:%[0-9]+]]:_(s32) = G_ADD [[UV13]], [[UV77]] -++ ; CHECK-NEXT: [[ADD14:%[0-9]+]]:_(s32) = G_ADD [[UV14]], [[UV78]] -++ ; CHECK-NEXT: [[ADD15:%[0-9]+]]:_(s32) = G_ADD [[UV15]], [[UV79]] -++ ; CHECK-NEXT: [[ADD16:%[0-9]+]]:_(s32) = G_ADD [[UV16]], [[UV80]] -++ ; CHECK-NEXT: [[ADD17:%[0-9]+]]:_(s32) = G_ADD [[UV17]], [[UV81]] -++ ; CHECK-NEXT: [[ADD18:%[0-9]+]]:_(s32) = G_ADD [[UV18]], [[UV82]] -++ ; CHECK-NEXT: [[ADD19:%[0-9]+]]:_(s32) = G_ADD [[UV19]], [[UV83]] -++ ; CHECK-NEXT: [[ADD20:%[0-9]+]]:_(s32) = G_ADD [[UV20]], [[UV84]] -++ ; CHECK-NEXT: [[ADD21:%[0-9]+]]:_(s32) = G_ADD [[UV21]], [[UV85]] -++ ; CHECK-NEXT: [[ADD22:%[0-9]+]]:_(s32) = G_ADD [[UV22]], [[UV86]] -++ ; CHECK-NEXT: [[ADD23:%[0-9]+]]:_(s32) = G_ADD [[UV23]], [[UV87]] -++ ; CHECK-NEXT: [[ADD24:%[0-9]+]]:_(s32) = G_ADD [[UV24]], [[UV88]] -++ ; CHECK-NEXT: [[ADD25:%[0-9]+]]:_(s32) = G_ADD [[UV25]], [[UV89]] -++ ; CHECK-NEXT: [[ADD26:%[0-9]+]]:_(s32) = G_ADD [[UV26]], [[UV90]] -++ ; CHECK-NEXT: [[ADD27:%[0-9]+]]:_(s32) = G_ADD [[UV27]], [[UV91]] -++ ; CHECK-NEXT: [[ADD28:%[0-9]+]]:_(s32) = G_ADD [[UV28]], [[UV92]] -++ ; CHECK-NEXT: [[ADD29:%[0-9]+]]:_(s32) = G_ADD [[UV29]], [[UV93]] -++ ; CHECK-NEXT: [[ADD30:%[0-9]+]]:_(s32) = G_ADD [[UV30]], [[UV94]] -++ ; CHECK-NEXT: [[ADD31:%[0-9]+]]:_(s32) = G_ADD [[UV31]], [[UV95]] -++ ; CHECK-NEXT: [[ADD32:%[0-9]+]]:_(s32) = G_ADD [[UV32]], [[UV96]] -++ ; CHECK-NEXT: [[ADD33:%[0-9]+]]:_(s32) = G_ADD [[UV33]], [[UV97]] -++ ; CHECK-NEXT: [[ADD34:%[0-9]+]]:_(s32) = G_ADD [[UV34]], [[UV98]] -++ ; CHECK-NEXT: [[ADD35:%[0-9]+]]:_(s32) = G_ADD [[UV35]], [[UV99]] -++ ; CHECK-NEXT: [[ADD36:%[0-9]+]]:_(s32) = G_ADD [[UV36]], [[UV100]] -++ ; CHECK-NEXT: [[ADD37:%[0-9]+]]:_(s32) = G_ADD [[UV37]], [[UV101]] -++ ; CHECK-NEXT: [[ADD38:%[0-9]+]]:_(s32) = G_ADD [[UV38]], [[UV102]] -++ ; CHECK-NEXT: [[ADD39:%[0-9]+]]:_(s32) = G_ADD [[UV39]], [[UV103]] -++ ; CHECK-NEXT: [[ADD40:%[0-9]+]]:_(s32) = G_ADD [[UV40]], [[UV104]] -++ ; CHECK-NEXT: [[ADD41:%[0-9]+]]:_(s32) = G_ADD [[UV41]], [[UV105]] -++ ; CHECK-NEXT: [[ADD42:%[0-9]+]]:_(s32) = G_ADD [[UV42]], [[UV106]] -++ ; CHECK-NEXT: [[ADD43:%[0-9]+]]:_(s32) = G_ADD [[UV43]], [[UV107]] -++ ; CHECK-NEXT: [[ADD44:%[0-9]+]]:_(s32) = G_ADD [[UV44]], [[UV108]] -++ ; CHECK-NEXT: [[ADD45:%[0-9]+]]:_(s32) = G_ADD [[UV45]], [[UV109]] -++ ; CHECK-NEXT: [[ADD46:%[0-9]+]]:_(s32) = G_ADD [[UV46]], [[UV110]] -++ ; CHECK-NEXT: [[ADD47:%[0-9]+]]:_(s32) = G_ADD [[UV47]], [[UV111]] -++ ; CHECK-NEXT: [[ADD48:%[0-9]+]]:_(s32) = G_ADD [[UV48]], [[UV112]] -++ ; CHECK-NEXT: [[ADD49:%[0-9]+]]:_(s32) = G_ADD [[UV49]], [[UV113]] -++ ; CHECK-NEXT: [[ADD50:%[0-9]+]]:_(s32) = G_ADD [[UV50]], [[UV114]] -++ ; CHECK-NEXT: [[ADD51:%[0-9]+]]:_(s32) = G_ADD [[UV51]], [[UV115]] -++ ; CHECK-NEXT: [[ADD52:%[0-9]+]]:_(s32) = G_ADD [[UV52]], [[UV116]] -++ ; CHECK-NEXT: [[ADD53:%[0-9]+]]:_(s32) = G_ADD [[UV53]], [[UV117]] -++ ; CHECK-NEXT: [[ADD54:%[0-9]+]]:_(s32) = G_ADD [[UV54]], [[UV118]] -++ ; CHECK-NEXT: [[ADD55:%[0-9]+]]:_(s32) = G_ADD [[UV55]], [[UV119]] -++ ; CHECK-NEXT: [[ADD56:%[0-9]+]]:_(s32) = G_ADD [[UV56]], [[UV120]] -++ ; CHECK-NEXT: [[ADD57:%[0-9]+]]:_(s32) = G_ADD [[UV57]], [[UV121]] -++ ; CHECK-NEXT: [[ADD58:%[0-9]+]]:_(s32) = G_ADD [[UV58]], [[UV122]] -++ ; CHECK-NEXT: [[ADD59:%[0-9]+]]:_(s32) = G_ADD [[UV59]], [[UV123]] -++ ; CHECK-NEXT: [[ADD60:%[0-9]+]]:_(s32) = G_ADD [[UV60]], [[UV124]] -++ ; CHECK-NEXT: [[ADD61:%[0-9]+]]:_(s32) = G_ADD [[UV61]], [[UV125]] -++ ; CHECK-NEXT: [[ADD62:%[0-9]+]]:_(s32) = G_ADD [[UV62]], [[UV126]] -++ ; CHECK-NEXT: [[ADD63:%[0-9]+]]:_(s32) = G_ADD [[UV63]], [[UV127]] -+ ; CHECK-NEXT: [[BUILD_VECTOR:%[0-9]+]]:_(<16 x s32>) = G_BUILD_VECTOR [[ADD]](s32), [[ADD1]](s32), [[ADD2]](s32), [[ADD3]](s32), [[ADD4]](s32), [[ADD5]](s32), [[ADD6]](s32), [[ADD7]](s32), [[ADD8]](s32), [[ADD9]](s32), [[ADD10]](s32), [[ADD11]](s32), [[ADD12]](s32), [[ADD13]](s32), [[ADD14]](s32), [[ADD15]](s32) -+ ; CHECK-NEXT: [[BUILD_VECTOR1:%[0-9]+]]:_(<16 x s32>) = G_BUILD_VECTOR [[ADD16]](s32), [[ADD17]](s32), [[ADD18]](s32), [[ADD19]](s32), [[ADD20]](s32), [[ADD21]](s32), [[ADD22]](s32), [[ADD23]](s32), [[ADD24]](s32), [[ADD25]](s32), [[ADD26]](s32), [[ADD27]](s32), [[ADD28]](s32), [[ADD29]](s32), [[ADD30]](s32), [[ADD31]](s32) -+ ; CHECK-NEXT: [[BUILD_VECTOR2:%[0-9]+]]:_(<16 x s32>) = G_BUILD_VECTOR [[ADD32]](s32), [[ADD33]](s32), [[ADD34]](s32), [[ADD35]](s32), [[ADD36]](s32), [[ADD37]](s32), [[ADD38]](s32), [[ADD39]](s32), [[ADD40]](s32), [[ADD41]](s32), [[ADD42]](s32), [[ADD43]](s32), [[ADD44]](s32), [[ADD45]](s32), [[ADD46]](s32), [[ADD47]](s32) -+@@ -760,10 +762,10 @@ -+ ; CHECK-NEXT: G_BR %bb.2 -+ ; CHECK-NEXT: {{ $}} -+ ; CHECK-NEXT: bb.2: -+- ; CHECK-NEXT: [[PHI:%[0-9]+]]:_(<16 x s32>) = G_PHI [[UV]](<16 x s32>), %bb.0, [[BUILD_VECTOR]](<16 x s32>), %bb.1 -+- ; CHECK-NEXT: [[PHI1:%[0-9]+]]:_(<16 x s32>) = G_PHI [[UV1]](<16 x s32>), %bb.0, [[BUILD_VECTOR1]](<16 x s32>), %bb.1 -+- ; CHECK-NEXT: [[PHI2:%[0-9]+]]:_(<16 x s32>) = G_PHI [[UV2]](<16 x s32>), %bb.0, [[BUILD_VECTOR2]](<16 x s32>), %bb.1 -+- ; CHECK-NEXT: [[PHI3:%[0-9]+]]:_(<16 x s32>) = G_PHI [[UV3]](<16 x s32>), %bb.0, [[BUILD_VECTOR3]](<16 x s32>), %bb.1 -++ ; CHECK-NEXT: [[PHI:%[0-9]+]]:_(<16 x s32>) = G_PHI [[DEF]](<16 x s32>), %bb.0, [[BUILD_VECTOR]](<16 x s32>), %bb.1 -++ ; CHECK-NEXT: [[PHI1:%[0-9]+]]:_(<16 x s32>) = G_PHI [[DEF]](<16 x s32>), %bb.0, [[BUILD_VECTOR1]](<16 x s32>), %bb.1 -++ ; CHECK-NEXT: [[PHI2:%[0-9]+]]:_(<16 x s32>) = G_PHI [[DEF]](<16 x s32>), %bb.0, [[BUILD_VECTOR2]](<16 x s32>), %bb.1 -++ ; CHECK-NEXT: [[PHI3:%[0-9]+]]:_(<16 x s32>) = G_PHI [[DEF]](<16 x s32>), %bb.0, [[BUILD_VECTOR3]](<16 x s32>), %bb.1 -+ ; CHECK-NEXT: [[CONCAT_VECTORS:%[0-9]+]]:_(<64 x s32>) = G_CONCAT_VECTORS [[PHI]](<16 x s32>), [[PHI1]](<16 x s32>), [[PHI2]](<16 x s32>), [[PHI3]](<16 x s32>) -+ ; CHECK-NEXT: S_SETPC_B64 undef $sgpr30_sgpr31, implicit [[CONCAT_VECTORS]](<64 x s32>) -+ bb.0: -+diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/AMDGPU/GlobalISel/regbankselect.mir b/llvm/test/CodeGen/AMDGPU/GlobalISel/regbankselect.mir -+--- a/llvm/test/CodeGen/AMDGPU/GlobalISel/regbankselect.mir -++++ b/llvm/test/CodeGen/AMDGPU/GlobalISel/regbankselect.mir -+@@ -42,6 +42,8 @@ -+ ret void -+ } -+ -++ define void @non_power_of_2() { ret void } -++ -+ define amdgpu_kernel void @load_constant_v4i16_from_8_align8(ptr addrspace(4) %ptr0) { -+ ret void -+ } -+@@ -185,6 +187,23 @@ -+ ... -+ -+ --- -++name: non_power_of_2 -++legalized: true -++ -++body: | -++ bb.0: -++ ; CHECK-LABEL: name: non_power_of_2 -++ ; CHECK: [[DEF:%[0-9]+]]:sgpr(s448) = G_IMPLICIT_DEF -++ ; CHECK-NEXT: [[EXTRACT:%[0-9]+]]:sgpr(s32) = G_EXTRACT [[DEF]](s448), 0 -++ ; CHECK-NEXT: $sgpr0 = COPY [[EXTRACT]](s32) -++ ; CHECK-NEXT: SI_RETURN_TO_EPILOG $sgpr0 -++ %0:_(s448) = G_IMPLICIT_DEF -++ %1:_(s32) = G_EXTRACT %0:_(s448), 0 -++ $sgpr0 = COPY %1:_(s32) -++ SI_RETURN_TO_EPILOG $sgpr0 -++... -++ -++--- -+ name: load_constant_v4i16_from_8_align8 -+ legalized: true -+ -+diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/NVPTX/intrin-nocapture.ll b/llvm/test/CodeGen/NVPTX/intrin-nocapture.ll -+--- a/llvm/test/CodeGen/NVPTX/intrin-nocapture.ll -++++ b/llvm/test/CodeGen/NVPTX/intrin-nocapture.ll -+@@ -0,0 +1,21 @@ -++; RUN: opt < %s -O3 -S | FileCheck %s -++ -++; Address space intrinsics were erroneously marked NoCapture, leading to bad -++; optimizations (such as the store below being eliminated as dead code). This -++; test makes sure we don't regress. -++ -++declare void @foo(ptr addrspace(1)) -++ -++declare ptr addrspace(1) @llvm.nvvm.ptr.gen.to.global.p1.p0(ptr) -++ -++; CHECK: @bar -++define void @bar() { -++ %t1 = alloca i32 -++; CHECK: call ptr addrspace(1) @llvm.nvvm.ptr.gen.to.global.p1.p0(ptr nonnull %t1) -++; CHECK-NEXT: store i32 10, ptr %t1 -++ %t2 = call ptr addrspace(1) @llvm.nvvm.ptr.gen.to.global.p1.p0(ptr %t1) -++ store i32 10, ptr %t1 -++ call void @foo(ptr addrspace(1) %t2) -++ ret void -++} -++ -+diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/NVPTX/rotate_64.ll b/llvm/test/CodeGen/NVPTX/rotate_64.ll -+--- a/llvm/test/CodeGen/NVPTX/rotate_64.ll -++++ b/llvm/test/CodeGen/NVPTX/rotate_64.ll -+@@ -1,38 +1,25 @@ -+-; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5 -+ ; RUN: llc < %s -march=nvptx64 | FileCheck %s -+ ; RUN: %if ptxas %{ llc < %s -march=nvptx64 | %ptxas-verify %} -+ -+ declare i64 @llvm.nvvm.rotate.b64(i64, i32) -+ declare i64 @llvm.nvvm.rotate.right.b64(i64, i32) -+ -++; CHECK: rotate64 -+ define i64 @rotate64(i64 %a, i32 %b) { -+-; CHECK-LABEL: rotate64( -+-; CHECK: { -+-; CHECK-NEXT: .reg .b64 %rd<5>; -+-; CHECK-EMPTY: -+-; CHECK-NEXT: // %bb.0: -+-; CHECK-NEXT: ld.param.u64 %rd1, [rotate64_param_0]; -+-; CHECK-NEXT: shr.u64 %rd2, %rd1, 61; -+-; CHECK-NEXT: shl.b64 %rd3, %rd1, 3; -+-; CHECK-NEXT: or.b64 %rd4, %rd3, %rd2; -+-; CHECK-NEXT: st.param.b64 [func_retval0+0], %rd4; -+-; CHECK-NEXT: ret; -++; CHECK: shl.b64 [[LHS:%.*]], [[RD1:%.*]], 3; -++; CHECK: shr.b64 [[RHS:%.*]], [[RD1]], 61; -++; CHECK: add.u64 [[RD2:%.*]], [[LHS]], [[RHS]]; -++; CHECK: ret -+ %val = tail call i64 @llvm.nvvm.rotate.b64(i64 %a, i32 3) -+ ret i64 %val -+ } -+ -++; CHECK: rotateright64 -+ define i64 @rotateright64(i64 %a, i32 %b) { -+-; CHECK-LABEL: rotateright64( -+-; CHECK: { -+-; CHECK-NEXT: .reg .b64 %rd<5>; -+-; CHECK-EMPTY: -+-; CHECK-NEXT: // %bb.0: -+-; CHECK-NEXT: ld.param.u64 %rd1, [rotateright64_param_0]; -+-; CHECK-NEXT: shl.b64 %rd2, %rd1, 61; -+-; CHECK-NEXT: shr.u64 %rd3, %rd1, 3; -+-; CHECK-NEXT: or.b64 %rd4, %rd3, %rd2; -+-; CHECK-NEXT: st.param.b64 [func_retval0+0], %rd4; -+-; CHECK-NEXT: ret; -++; CHECK: shl.b64 [[LHS:%.*]], [[RD1:%.*]], 61; -++; CHECK: shr.b64 [[RHS:%.*]], [[RD1]], 3; -++; CHECK: add.u64 [[RD2:%.*]], [[LHS]], [[RHS]]; -++; CHECK: ret -+ %val = tail call i64 @llvm.nvvm.rotate.right.b64(i64 %a, i32 3) -+ ret i64 %val -+ } -+diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/NVPTX/rotate.ll b/llvm/test/CodeGen/NVPTX/rotate.ll -+--- a/llvm/test/CodeGen/NVPTX/rotate.ll -++++ b/llvm/test/CodeGen/NVPTX/rotate.ll -+@@ -9,29 +9,26 @@ -+ declare i64 @llvm.nvvm.rotate.b64(i64, i32) -+ declare i64 @llvm.nvvm.rotate.right.b64(i64, i32) -+ -+-declare i64 @llvm.fshl.i64(i64, i64, i64) -+-declare i64 @llvm.fshr.i64(i64, i64, i64) -+-declare i32 @llvm.fshl.i32(i32, i32, i32) -+-declare i32 @llvm.fshr.i32(i32, i32, i32) -+- -+- -+ ; SM20: rotate32 -+ ; SM35: rotate32 -+ define i32 @rotate32(i32 %a, i32 %b) { -+ ; SM20-LABEL: rotate32( -+ ; SM20: { -+-; SM20-NEXT: .reg .b32 %r<9>; -++; SM20-NEXT: .reg .b32 %r<4>; -+ ; SM20-EMPTY: -+ ; SM20-NEXT: // %bb.0: -+ ; SM20-NEXT: ld.param.u32 %r1, [rotate32_param_0]; -+ ; SM20-NEXT: ld.param.u32 %r2, [rotate32_param_1]; -+-; SM20-NEXT: and.b32 %r3, %r2, 31; -+-; SM20-NEXT: shl.b32 %r4, %r1, %r3; -+-; SM20-NEXT: neg.s32 %r5, %r2; -+-; SM20-NEXT: and.b32 %r6, %r5, 31; -+-; SM20-NEXT: shr.u32 %r7, %r1, %r6; -+-; SM20-NEXT: or.b32 %r8, %r4, %r7; -+-; SM20-NEXT: st.param.b32 [func_retval0+0], %r8; -++; SM20-NEXT: { -++; SM20-NEXT: .reg .b32 %lhs; -++; SM20-NEXT: .reg .b32 %rhs; -++; SM20-NEXT: .reg .b32 %amt2; -++; SM20-NEXT: shl.b32 %lhs, %r1, %r2; -++; SM20-NEXT: sub.s32 %amt2, 32, %r2; -++; SM20-NEXT: shr.b32 %rhs, %r1, %amt2; -++; SM20-NEXT: add.u32 %r3, %lhs, %rhs; -++; SM20-NEXT: } -++; SM20-NEXT: st.param.b32 [func_retval0+0], %r3; -+ ; SM20-NEXT: ret; -+ ; -+ ; SM35-LABEL: rotate32( -+@@ -53,36 +50,45 @@ -+ define i64 @rotate64(i64 %a, i32 %b) { -+ ; SM20-LABEL: rotate64( -+ ; SM20: { -+-; SM20-NEXT: .reg .b32 %r<5>; -+-; SM20-NEXT: .reg .b64 %rd<5>; -++; SM20-NEXT: .reg .b32 %r<2>; -++; SM20-NEXT: .reg .b64 %rd<3>; -+ ; SM20-EMPTY: -+ ; SM20-NEXT: // %bb.0: -+ ; SM20-NEXT: ld.param.u64 %rd1, [rotate64_param_0]; -+ ; SM20-NEXT: ld.param.u32 %r1, [rotate64_param_1]; -+-; SM20-NEXT: and.b32 %r2, %r1, 63; -+-; SM20-NEXT: shl.b64 %rd2, %rd1, %r2; -+-; SM20-NEXT: neg.s32 %r3, %r1; -+-; SM20-NEXT: and.b32 %r4, %r3, 63; -+-; SM20-NEXT: shr.u64 %rd3, %rd1, %r4; -+-; SM20-NEXT: or.b64 %rd4, %rd2, %rd3; -+-; SM20-NEXT: st.param.b64 [func_retval0+0], %rd4; -++; SM20-NEXT: { -++; SM20-NEXT: .reg .b64 %lhs; -++; SM20-NEXT: .reg .b64 %rhs; -++; SM20-NEXT: .reg .u32 %amt2; -++; SM20-NEXT: and.b32 %amt2, %r1, 63; -++; SM20-NEXT: shl.b64 %lhs, %rd1, %amt2; -++; SM20-NEXT: sub.u32 %amt2, 64, %amt2; -++; SM20-NEXT: shr.b64 %rhs, %rd1, %amt2; -++; SM20-NEXT: add.u64 %rd2, %lhs, %rhs; -++; SM20-NEXT: } -++; SM20-NEXT: st.param.b64 [func_retval0+0], %rd2; -+ ; SM20-NEXT: ret; -+ ; -+ ; SM35-LABEL: rotate64( -+ ; SM35: { -+-; SM35-NEXT: .reg .b32 %r<5>; -+-; SM35-NEXT: .reg .b64 %rd<5>; -++; SM35-NEXT: .reg .b32 %r<6>; -++; SM35-NEXT: .reg .b64 %rd<3>; -+ ; SM35-EMPTY: -+ ; SM35-NEXT: // %bb.0: -+ ; SM35-NEXT: ld.param.u64 %rd1, [rotate64_param_0]; -+-; SM35-NEXT: ld.param.u32 %r1, [rotate64_param_1]; -+-; SM35-NEXT: and.b32 %r2, %r1, 63; -+-; SM35-NEXT: shl.b64 %rd2, %rd1, %r2; -+-; SM35-NEXT: neg.s32 %r3, %r1; -+-; SM35-NEXT: and.b32 %r4, %r3, 63; -+-; SM35-NEXT: shr.u64 %rd3, %rd1, %r4; -+-; SM35-NEXT: or.b64 %rd4, %rd2, %rd3; -+-; SM35-NEXT: st.param.b64 [func_retval0+0], %rd4; -++; SM35-NEXT: { -++; SM35-NEXT: .reg .b32 %dummy; -++; SM35-NEXT: mov.b64 {%dummy,%r1}, %rd1; -++; SM35-NEXT: } -++; SM35-NEXT: { -++; SM35-NEXT: .reg .b32 %dummy; -++; SM35-NEXT: mov.b64 {%r2,%dummy}, %rd1; -++; SM35-NEXT: } -++; SM35-NEXT: ld.param.u32 %r3, [rotate64_param_1]; -++; SM35-NEXT: shf.l.wrap.b32 %r4, %r2, %r1, %r3; -++; SM35-NEXT: shf.l.wrap.b32 %r5, %r1, %r2, %r3; -++; SM35-NEXT: mov.b64 %rd2, {%r5, %r4}; -++; SM35-NEXT: st.param.b64 [func_retval0+0], %rd2; -+ ; SM35-NEXT: ret; -+ %val = tail call i64 @llvm.nvvm.rotate.b64(i64 %a, i32 %b) -+ ret i64 %val -+@@ -93,36 +99,45 @@ -+ define i64 @rotateright64(i64 %a, i32 %b) { -+ ; SM20-LABEL: rotateright64( -+ ; SM20: { -+-; SM20-NEXT: .reg .b32 %r<5>; -+-; SM20-NEXT: .reg .b64 %rd<5>; -++; SM20-NEXT: .reg .b32 %r<2>; -++; SM20-NEXT: .reg .b64 %rd<3>; -+ ; SM20-EMPTY: -+ ; SM20-NEXT: // %bb.0: -+ ; SM20-NEXT: ld.param.u64 %rd1, [rotateright64_param_0]; -+ ; SM20-NEXT: ld.param.u32 %r1, [rotateright64_param_1]; -+-; SM20-NEXT: and.b32 %r2, %r1, 63; -+-; SM20-NEXT: shr.u64 %rd2, %rd1, %r2; -+-; SM20-NEXT: neg.s32 %r3, %r1; -+-; SM20-NEXT: and.b32 %r4, %r3, 63; -+-; SM20-NEXT: shl.b64 %rd3, %rd1, %r4; -+-; SM20-NEXT: or.b64 %rd4, %rd2, %rd3; -+-; SM20-NEXT: st.param.b64 [func_retval0+0], %rd4; -++; SM20-NEXT: { -++; SM20-NEXT: .reg .b64 %lhs; -++; SM20-NEXT: .reg .b64 %rhs; -++; SM20-NEXT: .reg .u32 %amt2; -++; SM20-NEXT: and.b32 %amt2, %r1, 63; -++; SM20-NEXT: shr.b64 %lhs, %rd1, %amt2; -++; SM20-NEXT: sub.u32 %amt2, 64, %amt2; -++; SM20-NEXT: shl.b64 %rhs, %rd1, %amt2; -++; SM20-NEXT: add.u64 %rd2, %lhs, %rhs; -++; SM20-NEXT: } -++; SM20-NEXT: st.param.b64 [func_retval0+0], %rd2; -+ ; SM20-NEXT: ret; -+ ; -+ ; SM35-LABEL: rotateright64( -+ ; SM35: { -+-; SM35-NEXT: .reg .b32 %r<5>; -+-; SM35-NEXT: .reg .b64 %rd<5>; -++; SM35-NEXT: .reg .b32 %r<6>; -++; SM35-NEXT: .reg .b64 %rd<3>; -+ ; SM35-EMPTY: -+ ; SM35-NEXT: // %bb.0: -+ ; SM35-NEXT: ld.param.u64 %rd1, [rotateright64_param_0]; -+-; SM35-NEXT: ld.param.u32 %r1, [rotateright64_param_1]; -+-; SM35-NEXT: and.b32 %r2, %r1, 63; -+-; SM35-NEXT: shr.u64 %rd2, %rd1, %r2; -+-; SM35-NEXT: neg.s32 %r3, %r1; -+-; SM35-NEXT: and.b32 %r4, %r3, 63; -+-; SM35-NEXT: shl.b64 %rd3, %rd1, %r4; -+-; SM35-NEXT: or.b64 %rd4, %rd2, %rd3; -+-; SM35-NEXT: st.param.b64 [func_retval0+0], %rd4; -++; SM35-NEXT: { -++; SM35-NEXT: .reg .b32 %dummy; -++; SM35-NEXT: mov.b64 {%r1,%dummy}, %rd1; -++; SM35-NEXT: } -++; SM35-NEXT: { -++; SM35-NEXT: .reg .b32 %dummy; -++; SM35-NEXT: mov.b64 {%dummy,%r2}, %rd1; -++; SM35-NEXT: } -++; SM35-NEXT: ld.param.u32 %r3, [rotateright64_param_1]; -++; SM35-NEXT: shf.r.wrap.b32 %r4, %r2, %r1, %r3; -++; SM35-NEXT: shf.r.wrap.b32 %r5, %r1, %r2, %r3; -++; SM35-NEXT: mov.b64 %rd2, {%r5, %r4}; -++; SM35-NEXT: st.param.b64 [func_retval0+0], %rd2; -+ ; SM35-NEXT: ret; -+ %val = tail call i64 @llvm.nvvm.rotate.right.b64(i64 %a, i32 %b) -+ ret i64 %val -+@@ -133,14 +148,18 @@ -+ define i32 @rotl0(i32 %x) { -+ ; SM20-LABEL: rotl0( -+ ; SM20: { -+-; SM20-NEXT: .reg .b32 %r<5>; -++; SM20-NEXT: .reg .b32 %r<3>; -+ ; SM20-EMPTY: -+ ; SM20-NEXT: // %bb.0: -+ ; SM20-NEXT: ld.param.u32 %r1, [rotl0_param_0]; -+-; SM20-NEXT: shr.u32 %r2, %r1, 24; -+-; SM20-NEXT: shl.b32 %r3, %r1, 8; -+-; SM20-NEXT: or.b32 %r4, %r3, %r2; -+-; SM20-NEXT: st.param.b32 [func_retval0+0], %r4; -++; SM20-NEXT: { -++; SM20-NEXT: .reg .b32 %lhs; -++; SM20-NEXT: .reg .b32 %rhs; -++; SM20-NEXT: shl.b32 %lhs, %r1, 8; -++; SM20-NEXT: shr.b32 %rhs, %r1, 24; -++; SM20-NEXT: add.u32 %r2, %lhs, %rhs; -++; SM20-NEXT: } -++; SM20-NEXT: st.param.b32 [func_retval0+0], %r2; -+ ; SM20-NEXT: ret; -+ ; -+ ; SM35-LABEL: rotl0( -+@@ -158,40 +177,51 @@ -+ ret i32 %t2 -+ } -+ -++declare i64 @llvm.fshl.i64(i64, i64, i64) -++declare i64 @llvm.fshr.i64(i64, i64, i64) -++ -+ ; SM35: rotl64 -+ define i64 @rotl64(i64 %a, i64 %n) { -+ ; SM20-LABEL: rotl64( -+ ; SM20: { -+-; SM20-NEXT: .reg .b32 %r<5>; -+-; SM20-NEXT: .reg .b64 %rd<5>; -++; SM20-NEXT: .reg .b32 %r<2>; -++; SM20-NEXT: .reg .b64 %rd<3>; -+ ; SM20-EMPTY: -+ ; SM20-NEXT: // %bb.0: -+ ; SM20-NEXT: ld.param.u64 %rd1, [rotl64_param_0]; -+ ; SM20-NEXT: ld.param.u32 %r1, [rotl64_param_1]; -+-; SM20-NEXT: and.b32 %r2, %r1, 63; -+-; SM20-NEXT: shl.b64 %rd2, %rd1, %r2; -+-; SM20-NEXT: neg.s32 %r3, %r1; -+-; SM20-NEXT: and.b32 %r4, %r3, 63; -+-; SM20-NEXT: shr.u64 %rd3, %rd1, %r4; -+-; SM20-NEXT: or.b64 %rd4, %rd2, %rd3; -+-; SM20-NEXT: st.param.b64 [func_retval0+0], %rd4; -++; SM20-NEXT: { -++; SM20-NEXT: .reg .b64 %lhs; -++; SM20-NEXT: .reg .b64 %rhs; -++; SM20-NEXT: .reg .u32 %amt2; -++; SM20-NEXT: and.b32 %amt2, %r1, 63; -++; SM20-NEXT: shl.b64 %lhs, %rd1, %amt2; -++; SM20-NEXT: sub.u32 %amt2, 64, %amt2; -++; SM20-NEXT: shr.b64 %rhs, %rd1, %amt2; -++; SM20-NEXT: add.u64 %rd2, %lhs, %rhs; -++; SM20-NEXT: } -++; SM20-NEXT: st.param.b64 [func_retval0+0], %rd2; -+ ; SM20-NEXT: ret; -+ ; -+ ; SM35-LABEL: rotl64( -+ ; SM35: { -+-; SM35-NEXT: .reg .b32 %r<5>; -+-; SM35-NEXT: .reg .b64 %rd<5>; -++; SM35-NEXT: .reg .b32 %r<2>; -++; SM35-NEXT: .reg .b64 %rd<3>; -+ ; SM35-EMPTY: -+ ; SM35-NEXT: // %bb.0: -+ ; SM35-NEXT: ld.param.u64 %rd1, [rotl64_param_0]; -+ ; SM35-NEXT: ld.param.u32 %r1, [rotl64_param_1]; -+-; SM35-NEXT: and.b32 %r2, %r1, 63; -+-; SM35-NEXT: shl.b64 %rd2, %rd1, %r2; -+-; SM35-NEXT: neg.s32 %r3, %r1; -+-; SM35-NEXT: and.b32 %r4, %r3, 63; -+-; SM35-NEXT: shr.u64 %rd3, %rd1, %r4; -+-; SM35-NEXT: or.b64 %rd4, %rd2, %rd3; -+-; SM35-NEXT: st.param.b64 [func_retval0+0], %rd4; -++; SM35-NEXT: { -++; SM35-NEXT: .reg .b64 %lhs; -++; SM35-NEXT: .reg .b64 %rhs; -++; SM35-NEXT: .reg .u32 %amt2; -++; SM35-NEXT: and.b32 %amt2, %r1, 63; -++; SM35-NEXT: shl.b64 %lhs, %rd1, %amt2; -++; SM35-NEXT: sub.u32 %amt2, 64, %amt2; -++; SM35-NEXT: shr.b64 %rhs, %rd1, %amt2; -++; SM35-NEXT: add.u64 %rd2, %lhs, %rhs; -++; SM35-NEXT: } -++; SM35-NEXT: st.param.b64 [func_retval0+0], %rd2; -+ ; SM35-NEXT: ret; -+ %val = tail call i64 @llvm.fshl.i64(i64 %a, i64 %a, i64 %n) -+ ret i64 %val -+@@ -201,26 +231,34 @@ -+ define i64 @rotl64_imm(i64 %a) { -+ ; SM20-LABEL: rotl64_imm( -+ ; SM20: { -+-; SM20-NEXT: .reg .b64 %rd<5>; -++; SM20-NEXT: .reg .b64 %rd<3>; -+ ; SM20-EMPTY: -+ ; SM20-NEXT: // %bb.0: -+ ; SM20-NEXT: ld.param.u64 %rd1, [rotl64_imm_param_0]; -+-; SM20-NEXT: shr.u64 %rd2, %rd1, 62; -+-; SM20-NEXT: shl.b64 %rd3, %rd1, 2; -+-; SM20-NEXT: or.b64 %rd4, %rd3, %rd2; -+-; SM20-NEXT: st.param.b64 [func_retval0+0], %rd4; -++; SM20-NEXT: { -++; SM20-NEXT: .reg .b64 %lhs; -++; SM20-NEXT: .reg .b64 %rhs; -++; SM20-NEXT: shl.b64 %lhs, %rd1, 2; -++; SM20-NEXT: shr.b64 %rhs, %rd1, 62; -++; SM20-NEXT: add.u64 %rd2, %lhs, %rhs; -++; SM20-NEXT: } -++; SM20-NEXT: st.param.b64 [func_retval0+0], %rd2; -+ ; SM20-NEXT: ret; -+ ; -+ ; SM35-LABEL: rotl64_imm( -+ ; SM35: { -+-; SM35-NEXT: .reg .b64 %rd<5>; -++; SM35-NEXT: .reg .b64 %rd<3>; -+ ; SM35-EMPTY: -+ ; SM35-NEXT: // %bb.0: -+ ; SM35-NEXT: ld.param.u64 %rd1, [rotl64_imm_param_0]; -+-; SM35-NEXT: shr.u64 %rd2, %rd1, 62; -+-; SM35-NEXT: shl.b64 %rd3, %rd1, 2; -+-; SM35-NEXT: or.b64 %rd4, %rd3, %rd2; -+-; SM35-NEXT: st.param.b64 [func_retval0+0], %rd4; -++; SM35-NEXT: { -++; SM35-NEXT: .reg .b64 %lhs; -++; SM35-NEXT: .reg .b64 %rhs; -++; SM35-NEXT: shl.b64 %lhs, %rd1, 2; -++; SM35-NEXT: shr.b64 %rhs, %rd1, 62; -++; SM35-NEXT: add.u64 %rd2, %lhs, %rhs; -++; SM35-NEXT: } -++; SM35-NEXT: st.param.b64 [func_retval0+0], %rd2; -+ ; SM35-NEXT: ret; -+ %val = tail call i64 @llvm.fshl.i64(i64 %a, i64 %a, i64 66) -+ ret i64 %val -+@@ -230,36 +268,44 @@ -+ define i64 @rotr64(i64 %a, i64 %n) { -+ ; SM20-LABEL: rotr64( -+ ; SM20: { -+-; SM20-NEXT: .reg .b32 %r<5>; -+-; SM20-NEXT: .reg .b64 %rd<5>; -++; SM20-NEXT: .reg .b32 %r<2>; -++; SM20-NEXT: .reg .b64 %rd<3>; -+ ; SM20-EMPTY: -+ ; SM20-NEXT: // %bb.0: -+ ; SM20-NEXT: ld.param.u64 %rd1, [rotr64_param_0]; -+ ; SM20-NEXT: ld.param.u32 %r1, [rotr64_param_1]; -+-; SM20-NEXT: and.b32 %r2, %r1, 63; -+-; SM20-NEXT: shr.u64 %rd2, %rd1, %r2; -+-; SM20-NEXT: neg.s32 %r3, %r1; -+-; SM20-NEXT: and.b32 %r4, %r3, 63; -+-; SM20-NEXT: shl.b64 %rd3, %rd1, %r4; -+-; SM20-NEXT: or.b64 %rd4, %rd2, %rd3; -+-; SM20-NEXT: st.param.b64 [func_retval0+0], %rd4; -++; SM20-NEXT: { -++; SM20-NEXT: .reg .b64 %lhs; -++; SM20-NEXT: .reg .b64 %rhs; -++; SM20-NEXT: .reg .u32 %amt2; -++; SM20-NEXT: and.b32 %amt2, %r1, 63; -++; SM20-NEXT: shr.b64 %lhs, %rd1, %amt2; -++; SM20-NEXT: sub.u32 %amt2, 64, %amt2; -++; SM20-NEXT: shl.b64 %rhs, %rd1, %amt2; -++; SM20-NEXT: add.u64 %rd2, %lhs, %rhs; -++; SM20-NEXT: } -++; SM20-NEXT: st.param.b64 [func_retval0+0], %rd2; -+ ; SM20-NEXT: ret; -+ ; -+ ; SM35-LABEL: rotr64( -+ ; SM35: { -+-; SM35-NEXT: .reg .b32 %r<5>; -+-; SM35-NEXT: .reg .b64 %rd<5>; -++; SM35-NEXT: .reg .b32 %r<2>; -++; SM35-NEXT: .reg .b64 %rd<3>; -+ ; SM35-EMPTY: -+ ; SM35-NEXT: // %bb.0: -+ ; SM35-NEXT: ld.param.u64 %rd1, [rotr64_param_0]; -+ ; SM35-NEXT: ld.param.u32 %r1, [rotr64_param_1]; -+-; SM35-NEXT: and.b32 %r2, %r1, 63; -+-; SM35-NEXT: shr.u64 %rd2, %rd1, %r2; -+-; SM35-NEXT: neg.s32 %r3, %r1; -+-; SM35-NEXT: and.b32 %r4, %r3, 63; -+-; SM35-NEXT: shl.b64 %rd3, %rd1, %r4; -+-; SM35-NEXT: or.b64 %rd4, %rd2, %rd3; -+-; SM35-NEXT: st.param.b64 [func_retval0+0], %rd4; -++; SM35-NEXT: { -++; SM35-NEXT: .reg .b64 %lhs; -++; SM35-NEXT: .reg .b64 %rhs; -++; SM35-NEXT: .reg .u32 %amt2; -++; SM35-NEXT: and.b32 %amt2, %r1, 63; -++; SM35-NEXT: shr.b64 %lhs, %rd1, %amt2; -++; SM35-NEXT: sub.u32 %amt2, 64, %amt2; -++; SM35-NEXT: shl.b64 %rhs, %rd1, %amt2; -++; SM35-NEXT: add.u64 %rd2, %lhs, %rhs; -++; SM35-NEXT: } -++; SM35-NEXT: st.param.b64 [func_retval0+0], %rd2; -+ ; SM35-NEXT: ret; -+ %val = tail call i64 @llvm.fshr.i64(i64 %a, i64 %a, i64 %n) -+ ret i64 %val -+@@ -269,180 +315,35 @@ -+ define i64 @rotr64_imm(i64 %a) { -+ ; SM20-LABEL: rotr64_imm( -+ ; SM20: { -+-; SM20-NEXT: .reg .b64 %rd<5>; -++; SM20-NEXT: .reg .b64 %rd<3>; -+ ; SM20-EMPTY: -+ ; SM20-NEXT: // %bb.0: -+ ; SM20-NEXT: ld.param.u64 %rd1, [rotr64_imm_param_0]; -+-; SM20-NEXT: shl.b64 %rd2, %rd1, 62; -+-; SM20-NEXT: shr.u64 %rd3, %rd1, 2; -+-; SM20-NEXT: or.b64 %rd4, %rd3, %rd2; -+-; SM20-NEXT: st.param.b64 [func_retval0+0], %rd4; -++; SM20-NEXT: { -++; SM20-NEXT: .reg .b64 %lhs; -++; SM20-NEXT: .reg .b64 %rhs; -++; SM20-NEXT: shl.b64 %lhs, %rd1, 62; -++; SM20-NEXT: shr.b64 %rhs, %rd1, 2; -++; SM20-NEXT: add.u64 %rd2, %lhs, %rhs; -++; SM20-NEXT: } -++; SM20-NEXT: st.param.b64 [func_retval0+0], %rd2; -+ ; SM20-NEXT: ret; -+ ; -+ ; SM35-LABEL: rotr64_imm( -+ ; SM35: { -+-; SM35-NEXT: .reg .b64 %rd<5>; -++; SM35-NEXT: .reg .b64 %rd<3>; -+ ; SM35-EMPTY: -+ ; SM35-NEXT: // %bb.0: -+ ; SM35-NEXT: ld.param.u64 %rd1, [rotr64_imm_param_0]; -+-; SM35-NEXT: shl.b64 %rd2, %rd1, 62; -+-; SM35-NEXT: shr.u64 %rd3, %rd1, 2; -+-; SM35-NEXT: or.b64 %rd4, %rd3, %rd2; -+-; SM35-NEXT: st.param.b64 [func_retval0+0], %rd4; -++; SM35-NEXT: { -++; SM35-NEXT: .reg .b64 %lhs; -++; SM35-NEXT: .reg .b64 %rhs; -++; SM35-NEXT: shl.b64 %lhs, %rd1, 62; -++; SM35-NEXT: shr.b64 %rhs, %rd1, 2; -++; SM35-NEXT: add.u64 %rd2, %lhs, %rhs; -++; SM35-NEXT: } -++; SM35-NEXT: st.param.b64 [func_retval0+0], %rd2; -+ ; SM35-NEXT: ret; -+ %val = tail call i64 @llvm.fshr.i64(i64 %a, i64 %a, i64 66) -+ ret i64 %val -+ } -+- -+-define i32 @funnel_shift_right_32(i32 %a, i32 %b, i32 %c) { -+-; SM20-LABEL: funnel_shift_right_32( -+-; SM20: { -+-; SM20-NEXT: .reg .b32 %r<11>; -+-; SM20-EMPTY: -+-; SM20-NEXT: // %bb.0: -+-; SM20-NEXT: ld.param.u32 %r1, [funnel_shift_right_32_param_0]; -+-; SM20-NEXT: ld.param.u32 %r2, [funnel_shift_right_32_param_2]; -+-; SM20-NEXT: and.b32 %r3, %r2, 31; -+-; SM20-NEXT: ld.param.u32 %r4, [funnel_shift_right_32_param_1]; -+-; SM20-NEXT: shr.u32 %r5, %r4, %r3; -+-; SM20-NEXT: shl.b32 %r6, %r1, 1; -+-; SM20-NEXT: not.b32 %r7, %r2; -+-; SM20-NEXT: and.b32 %r8, %r7, 31; -+-; SM20-NEXT: shl.b32 %r9, %r6, %r8; -+-; SM20-NEXT: or.b32 %r10, %r9, %r5; -+-; SM20-NEXT: st.param.b32 [func_retval0+0], %r10; -+-; SM20-NEXT: ret; -+-; -+-; SM35-LABEL: funnel_shift_right_32( -+-; SM35: { -+-; SM35-NEXT: .reg .b32 %r<5>; -+-; SM35-EMPTY: -+-; SM35-NEXT: // %bb.0: -+-; SM35-NEXT: ld.param.u32 %r1, [funnel_shift_right_32_param_0]; -+-; SM35-NEXT: ld.param.u32 %r2, [funnel_shift_right_32_param_1]; -+-; SM35-NEXT: ld.param.u32 %r3, [funnel_shift_right_32_param_2]; -+-; SM35-NEXT: shf.r.wrap.b32 %r4, %r1, %r2, %r3; -+-; SM35-NEXT: st.param.b32 [func_retval0+0], %r4; -+-; SM35-NEXT: ret; -+- %val = call i32 @llvm.fshr.i32(i32 %a, i32 %b, i32 %c) -+- ret i32 %val -+-} -+- -+-define i32 @funnel_shift_left_32(i32 %a, i32 %b, i32 %c) { -+-; SM20-LABEL: funnel_shift_left_32( -+-; SM20: { -+-; SM20-NEXT: .reg .b32 %r<11>; -+-; SM20-EMPTY: -+-; SM20-NEXT: // %bb.0: -+-; SM20-NEXT: ld.param.u32 %r1, [funnel_shift_left_32_param_0]; -+-; SM20-NEXT: ld.param.u32 %r2, [funnel_shift_left_32_param_2]; -+-; SM20-NEXT: and.b32 %r3, %r2, 31; -+-; SM20-NEXT: shl.b32 %r4, %r1, %r3; -+-; SM20-NEXT: ld.param.u32 %r5, [funnel_shift_left_32_param_1]; -+-; SM20-NEXT: shr.u32 %r6, %r5, 1; -+-; SM20-NEXT: not.b32 %r7, %r2; -+-; SM20-NEXT: and.b32 %r8, %r7, 31; -+-; SM20-NEXT: shr.u32 %r9, %r6, %r8; -+-; SM20-NEXT: or.b32 %r10, %r4, %r9; -+-; SM20-NEXT: st.param.b32 [func_retval0+0], %r10; -+-; SM20-NEXT: ret; -+-; -+-; SM35-LABEL: funnel_shift_left_32( -+-; SM35: { -+-; SM35-NEXT: .reg .b32 %r<5>; -+-; SM35-EMPTY: -+-; SM35-NEXT: // %bb.0: -+-; SM35-NEXT: ld.param.u32 %r1, [funnel_shift_left_32_param_0]; -+-; SM35-NEXT: ld.param.u32 %r2, [funnel_shift_left_32_param_1]; -+-; SM35-NEXT: ld.param.u32 %r3, [funnel_shift_left_32_param_2]; -+-; SM35-NEXT: shf.l.wrap.b32 %r4, %r1, %r2, %r3; -+-; SM35-NEXT: st.param.b32 [func_retval0+0], %r4; -+-; SM35-NEXT: ret; -+- %val = call i32 @llvm.fshl.i32(i32 %a, i32 %b, i32 %c) -+- ret i32 %val -+-} -+- -+-define i64 @funnel_shift_right_64(i64 %a, i64 %b, i64 %c) { -+-; SM20-LABEL: funnel_shift_right_64( -+-; SM20: { -+-; SM20-NEXT: .reg .b32 %r<5>; -+-; SM20-NEXT: .reg .b64 %rd<7>; -+-; SM20-EMPTY: -+-; SM20-NEXT: // %bb.0: -+-; SM20-NEXT: ld.param.u64 %rd1, [funnel_shift_right_64_param_0]; -+-; SM20-NEXT: ld.param.u32 %r1, [funnel_shift_right_64_param_2]; -+-; SM20-NEXT: and.b32 %r2, %r1, 63; -+-; SM20-NEXT: ld.param.u64 %rd2, [funnel_shift_right_64_param_1]; -+-; SM20-NEXT: shr.u64 %rd3, %rd2, %r2; -+-; SM20-NEXT: shl.b64 %rd4, %rd1, 1; -+-; SM20-NEXT: not.b32 %r3, %r1; -+-; SM20-NEXT: and.b32 %r4, %r3, 63; -+-; SM20-NEXT: shl.b64 %rd5, %rd4, %r4; -+-; SM20-NEXT: or.b64 %rd6, %rd5, %rd3; -+-; SM20-NEXT: st.param.b64 [func_retval0+0], %rd6; -+-; SM20-NEXT: ret; -+-; -+-; SM35-LABEL: funnel_shift_right_64( -+-; SM35: { -+-; SM35-NEXT: .reg .b32 %r<5>; -+-; SM35-NEXT: .reg .b64 %rd<7>; -+-; SM35-EMPTY: -+-; SM35-NEXT: // %bb.0: -+-; SM35-NEXT: ld.param.u64 %rd1, [funnel_shift_right_64_param_0]; -+-; SM35-NEXT: ld.param.u32 %r1, [funnel_shift_right_64_param_2]; -+-; SM35-NEXT: and.b32 %r2, %r1, 63; -+-; SM35-NEXT: ld.param.u64 %rd2, [funnel_shift_right_64_param_1]; -+-; SM35-NEXT: shr.u64 %rd3, %rd2, %r2; -+-; SM35-NEXT: shl.b64 %rd4, %rd1, 1; -+-; SM35-NEXT: not.b32 %r3, %r1; -+-; SM35-NEXT: and.b32 %r4, %r3, 63; -+-; SM35-NEXT: shl.b64 %rd5, %rd4, %r4; -+-; SM35-NEXT: or.b64 %rd6, %rd5, %rd3; -+-; SM35-NEXT: st.param.b64 [func_retval0+0], %rd6; -+-; SM35-NEXT: ret; -+- %val = call i64 @llvm.fshr.i64(i64 %a, i64 %b, i64 %c) -+- ret i64 %val -+-} -+- -+-define i64 @funnel_shift_left_64(i64 %a, i64 %b, i64 %c) { -+-; SM20-LABEL: funnel_shift_left_64( -+-; SM20: { -+-; SM20-NEXT: .reg .b32 %r<5>; -+-; SM20-NEXT: .reg .b64 %rd<7>; -+-; SM20-EMPTY: -+-; SM20-NEXT: // %bb.0: -+-; SM20-NEXT: ld.param.u64 %rd1, [funnel_shift_left_64_param_0]; -+-; SM20-NEXT: ld.param.u32 %r1, [funnel_shift_left_64_param_2]; -+-; SM20-NEXT: and.b32 %r2, %r1, 63; -+-; SM20-NEXT: shl.b64 %rd2, %rd1, %r2; -+-; SM20-NEXT: ld.param.u64 %rd3, [funnel_shift_left_64_param_1]; -+-; SM20-NEXT: shr.u64 %rd4, %rd3, 1; -+-; SM20-NEXT: not.b32 %r3, %r1; -+-; SM20-NEXT: and.b32 %r4, %r3, 63; -+-; SM20-NEXT: shr.u64 %rd5, %rd4, %r4; -+-; SM20-NEXT: or.b64 %rd6, %rd2, %rd5; -+-; SM20-NEXT: st.param.b64 [func_retval0+0], %rd6; -+-; SM20-NEXT: ret; -+-; -+-; SM35-LABEL: funnel_shift_left_64( -+-; SM35: { -+-; SM35-NEXT: .reg .b32 %r<5>; -+-; SM35-NEXT: .reg .b64 %rd<7>; -+-; SM35-EMPTY: -+-; SM35-NEXT: // %bb.0: -+-; SM35-NEXT: ld.param.u64 %rd1, [funnel_shift_left_64_param_0]; -+-; SM35-NEXT: ld.param.u32 %r1, [funnel_shift_left_64_param_2]; -+-; SM35-NEXT: and.b32 %r2, %r1, 63; -+-; SM35-NEXT: shl.b64 %rd2, %rd1, %r2; -+-; SM35-NEXT: ld.param.u64 %rd3, [funnel_shift_left_64_param_1]; -+-; SM35-NEXT: shr.u64 %rd4, %rd3, 1; -+-; SM35-NEXT: not.b32 %r3, %r1; -+-; SM35-NEXT: and.b32 %r4, %r3, 63; -+-; SM35-NEXT: shr.u64 %rd5, %rd4, %r4; -+-; SM35-NEXT: or.b64 %rd6, %rd2, %rd5; -+-; SM35-NEXT: st.param.b64 [func_retval0+0], %rd6; -+-; SM35-NEXT: ret; -+- %val = call i64 @llvm.fshl.i64(i64 %a, i64 %b, i64 %c) -+- ret i64 %val -+-} -+- -+diff -ruN --strip-trailing-cr a/llvm/test/DebugInfo/NVPTX/debug-info.ll b/llvm/test/DebugInfo/NVPTX/debug-info.ll -+--- a/llvm/test/DebugInfo/NVPTX/debug-info.ll -++++ b/llvm/test/DebugInfo/NVPTX/debug-info.ll -+@@ -25,10 +25,6 @@ -+ ; CHECK-DAG: .reg .b64 %rd<8>; -+ ; CHECK: .loc [[DEBUG_INFO_CU:[0-9]+]] 5 0 -+ ; CHECK: ld.param.u32 %r{{.+}}, [{{.+}}]; -+-; CHECK: ld.param.u64 %rd{{.+}}, [{{.+}}]; -+-; CHECK: cvta.to.global.u64 %rd{{.+}}, %rd{{.+}}; -+-; CHECK: ld.param.u64 %rd{{.+}}, [{{.+}}]; -+-; CHECK: cvta.to.global.u64 %rd{{.+}}, %rd{{.+}}; -+ ; CHECK: .loc [[BUILTUIN_VARS_H:[0-9]+]] 78 180 -+ ; CHECK: mov.u32 %r{{.+}}, %ctaid.x; -+ ; CHECK: .loc [[BUILTUIN_VARS_H]] 89 180 -+@@ -42,6 +38,10 @@ -+ ; CHECK: .loc [[DEBUG_INFO_CU]] 7 7 -+ ; CHECK: @%p{{.+}} bra [[BB:\$L__.+]]; -+ ; CHECK: ld.param.f32 %f{{.+}}, [{{.+}}]; -++; CHECK: ld.param.u64 %rd{{.+}}, [{{.+}}]; -++; CHECK: cvta.to.global.u64 %rd{{.+}}, %rd{{.+}}; -++; CHECK: ld.param.u64 %rd{{.+}}, [{{.+}}]; -++; CHECK: cvta.to.global.u64 %rd{{.+}}, %rd{{.+}}; -+ ; CHECK: .loc [[DEBUG_INFO_CU]] 8 13 -+ ; CHECK: mul.wide.u32 %rd{{.+}}, %r{{.+}}, 4; -+ ; CHECK: add.s64 %rd{{.+}}, %rd{{.+}}, %rd{{.+}}; -+@@ -2661,22 +2661,22 @@ -+ ; CHECK-NEXT:.b32 4579 // DW_AT_type -+ ; CHECK-NEXT:.b8 25 // Abbrev [25] 0x8aa:0x18 DW_TAG_inlined_subroutine -+ ; CHECK-NEXT:.b32 707 // DW_AT_abstract_origin -+-; CHECK-NEXT:.b64 $L__tmp1 // DW_AT_low_pc -+-; CHECK-NEXT:.b64 $L__tmp2 // DW_AT_high_pc -++; CHECK-NEXT:.b64 $L__tmp0 // DW_AT_low_pc -++; CHECK-NEXT:.b64 $L__tmp1 // DW_AT_high_pc -+ ; CHECK-NEXT:.b8 1 // DW_AT_call_file -+ ; CHECK-NEXT:.b8 6 // DW_AT_call_line -+ ; CHECK-NEXT:.b8 11 // DW_AT_call_column -+ ; CHECK-NEXT:.b8 25 // Abbrev [25] 0x8c2:0x18 DW_TAG_inlined_subroutine -+ ; CHECK-NEXT:.b32 1466 // DW_AT_abstract_origin -+-; CHECK-NEXT:.b64 $L__tmp2 // DW_AT_low_pc -+-; CHECK-NEXT:.b64 $L__tmp3 // DW_AT_high_pc -++; CHECK-NEXT:.b64 $L__tmp1 // DW_AT_low_pc -++; CHECK-NEXT:.b64 $L__tmp2 // DW_AT_high_pc -+ ; CHECK-NEXT:.b8 1 // DW_AT_call_file -+ ; CHECK-NEXT:.b8 6 // DW_AT_call_line -+ ; CHECK-NEXT:.b8 24 // DW_AT_call_column -+ ; CHECK-NEXT:.b8 25 // Abbrev [25] 0x8da:0x18 DW_TAG_inlined_subroutine -+ ; CHECK-NEXT:.b32 2060 // DW_AT_abstract_origin -+-; CHECK-NEXT:.b64 $L__tmp3 // DW_AT_low_pc -+-; CHECK-NEXT:.b64 $L__tmp4 // DW_AT_high_pc -++; CHECK-NEXT:.b64 $L__tmp2 // DW_AT_low_pc -++; CHECK-NEXT:.b64 $L__tmp3 // DW_AT_high_pc -+ ; CHECK-NEXT:.b8 1 // DW_AT_call_file -+ ; CHECK-NEXT:.b8 6 // DW_AT_call_line -+ ; CHECK-NEXT:.b8 37 // DW_AT_call_column +-diff -ruN --strip-trailing-cr a/llvm/docs/NVPTXUsage.rst b/llvm/docs/NVPTXUsage.rst +---- a/llvm/docs/NVPTXUsage.rst +-+++ b/llvm/docs/NVPTXUsage.rst +-@@ -127,6 +127,69 @@ +- NVPTX Intrinsics +- ================ +- +-+Address Space Conversion +-+------------------------ +-+ +-+'``llvm.nvvm.ptr.*.to.gen``' Intrinsics +-+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +-+ +-+Syntax: +-+""""""" +-+ +-+These are overloaded intrinsics. You can use these on any pointer types. +-+ +-+.. code-block:: llvm +-+ +-+ declare ptr @llvm.nvvm.ptr.global.to.gen.p0.p1(ptr addrspace(1)) +-+ declare ptr @llvm.nvvm.ptr.shared.to.gen.p0.p3(ptr addrspace(3)) +-+ declare ptr @llvm.nvvm.ptr.constant.to.gen.p0.p4(ptr addrspace(4)) +-+ declare ptr @llvm.nvvm.ptr.local.to.gen.p0.p5(ptr addrspace(5)) +-+ +-+Overview: +-+""""""""" +-+ +-+The '``llvm.nvvm.ptr.*.to.gen``' intrinsics convert a pointer in a non-generic +-+address space to a generic address space pointer. +-+ +-+Semantics: +-+"""""""""" +-+ +-+These intrinsics modify the pointer value to be a valid generic address space +-+pointer. +-+ +-+ +-+'``llvm.nvvm.ptr.gen.to.*``' Intrinsics +-+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +-+ +-+Syntax: +-+""""""" +-+ +-+These are overloaded intrinsics. You can use these on any pointer types. +-+ +-+.. code-block:: llvm +-+ +-+ declare ptr addrspace(1) @llvm.nvvm.ptr.gen.to.global.p1.p0(ptr) +-+ declare ptr addrspace(3) @llvm.nvvm.ptr.gen.to.shared.p3.p0(ptr) +-+ declare ptr addrspace(4) @llvm.nvvm.ptr.gen.to.constant.p4.p0(ptr) +-+ declare ptr addrspace(5) @llvm.nvvm.ptr.gen.to.local.p5.p0(ptr) +-+ +-+Overview: +-+""""""""" +-+ +-+The '``llvm.nvvm.ptr.gen.to.*``' intrinsics convert a pointer in the generic +-+address space to a pointer in the target address space. Note that these +-+intrinsics are only useful if the address space of the target address space of +-+the pointer is known. It is not legal to use address space conversion +-+intrinsics to convert a pointer from one non-generic address space to another +-+non-generic address space. +-+ +-+Semantics: +-+"""""""""" +-+ +-+These intrinsics modify the pointer value to be a valid pointer in the target +-+non-generic address space. +-+ +-+ +- Reading PTX Special Registers +- ----------------------------- +- +-diff -ruN --strip-trailing-cr a/llvm/docs/ReleaseNotes.rst b/llvm/docs/ReleaseNotes.rst +---- a/llvm/docs/ReleaseNotes.rst +-+++ b/llvm/docs/ReleaseNotes.rst +-@@ -63,24 +63,6 @@ +- * ``llvm.nvvm.bitcast.d2ll`` +- * ``llvm.nvvm.bitcast.ll2d`` +- +--* Remove the following intrinsics which can be replaced with a funnel-shift: +-- +-- * ``llvm.nvvm.rotate.b32`` +-- * ``llvm.nvvm.rotate.right.b64`` +-- * ``llvm.nvvm.rotate.b64`` +-- +--* Remove the following intrinsics which can be replaced with an +-- ``addrspacecast``: +-- +-- * ``llvm.nvvm.ptr.gen.to.global`` +-- * ``llvm.nvvm.ptr.gen.to.shared`` +-- * ``llvm.nvvm.ptr.gen.to.constant`` +-- * ``llvm.nvvm.ptr.gen.to.local`` +-- * ``llvm.nvvm.ptr.global.to.gen`` +-- * ``llvm.nvvm.ptr.shared.to.gen`` +-- * ``llvm.nvvm.ptr.constant.to.gen`` +-- * ``llvm.nvvm.ptr.local.to.gen`` +-- +- Changes to LLVM infrastructure +- ------------------------------ +- +-diff -ruN --strip-trailing-cr a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td +---- a/llvm/include/llvm/IR/IntrinsicsNVVM.td +-+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td +-@@ -30,18 +30,10 @@ +- // * llvm.nvvm.max.ui --> select(x ule y, x, y) +- // * llvm.nvvm.max.ull --> ibid. +- // * llvm.nvvm.h2f --> llvm.convert.to.fp16.f32 +--// * llvm.nvvm.bitcast.f2i --> bitcast +--// * llvm.nvvm.bitcast.i2f --> ibid. +--// * llvm.nvvm.bitcast.d2ll --> ibid. +--// * llvm.nvvm.bitcast.ll2d --> ibid. +--// * llvm.nvvm.ptr.gen.to.global --> addrspacecast +--// * llvm.nvvm.ptr.gen.to.shared --> ibid. +--// * llvm.nvvm.ptr.gen.to.constant --> ibid. +--// * llvm.nvvm.ptr.gen.to.local --> ibid. +--// * llvm.nvvm.ptr.global.to.gen --> ibid. +--// * llvm.nvvm.ptr.shared.to.gen --> ibid. +--// * llvm.nvvm.ptr.constant.to.gen --> ibid. +--// * llvm.nvvm.ptr.local.to.gen --> ibid. +-+// * llvm.nvvm.bitcast.f2i --> bitcast +-+// * llvm.nvvm.bitcast.i2f --> ibid. +-+// * llvm.nvvm.bitcast.d2ll --> ibid. +-+// * llvm.nvvm.bitcast.ll2d --> ibid. +- +- def llvm_global_ptr_ty : LLVMQualPointerType<1>; // (global)ptr +- def llvm_shared_ptr_ty : LLVMQualPointerType<3>; // (shared)ptr +-@@ -1610,6 +1602,40 @@ +- [IntrReadMem, IntrArgMemOnly, IntrNoCallback, IntrWillReturn, NoCapture>], +- "llvm.nvvm.ldg.global.p">; +- +-+// Use for generic pointers +-+// - These intrinsics are used to convert address spaces. +-+// - The input pointer and output pointer must have the same type, except for +-+// the address-space. (This restriction is not enforced here as there is +-+// currently no way to describe it). +-+// - This complements the llvm bitcast, which can be used to cast one type +-+// of pointer to another type of pointer, while the address space remains +-+// the same. +-+def int_nvvm_ptr_local_to_gen: DefaultAttrsIntrinsic<[llvm_anyptr_ty], +-+ [llvm_anyptr_ty], [IntrNoMem, IntrSpeculatable], +-+ "llvm.nvvm.ptr.local.to.gen">; +-+def int_nvvm_ptr_shared_to_gen: DefaultAttrsIntrinsic<[llvm_anyptr_ty], +-+ [llvm_anyptr_ty], [IntrNoMem, IntrSpeculatable], +-+ "llvm.nvvm.ptr.shared.to.gen">; +-+def int_nvvm_ptr_global_to_gen: DefaultAttrsIntrinsic<[llvm_anyptr_ty], +-+ [llvm_anyptr_ty], [IntrNoMem, IntrSpeculatable], +-+ "llvm.nvvm.ptr.global.to.gen">; +-+def int_nvvm_ptr_constant_to_gen: DefaultAttrsIntrinsic<[llvm_anyptr_ty], +-+ [llvm_anyptr_ty], [IntrNoMem, IntrSpeculatable], +-+ "llvm.nvvm.ptr.constant.to.gen">; +-+ +-+def int_nvvm_ptr_gen_to_global: DefaultAttrsIntrinsic<[llvm_anyptr_ty], +-+ [llvm_anyptr_ty], [IntrNoMem, IntrSpeculatable], +-+ "llvm.nvvm.ptr.gen.to.global">; +-+def int_nvvm_ptr_gen_to_shared: DefaultAttrsIntrinsic<[llvm_anyptr_ty], +-+ [llvm_anyptr_ty], [IntrNoMem, IntrSpeculatable], +-+ "llvm.nvvm.ptr.gen.to.shared">; +-+def int_nvvm_ptr_gen_to_local: DefaultAttrsIntrinsic<[llvm_anyptr_ty], +-+ [llvm_anyptr_ty], [IntrNoMem, IntrSpeculatable], +-+ "llvm.nvvm.ptr.gen.to.local">; +-+def int_nvvm_ptr_gen_to_constant: DefaultAttrsIntrinsic<[llvm_anyptr_ty], +-+ [llvm_anyptr_ty], [IntrNoMem, IntrSpeculatable], +-+ "llvm.nvvm.ptr.gen.to.constant">; +-+ +- // Used in nvvm internally to help address space opt and ptx code generation +- // This is for params that are passed to kernel functions by pointer by-val. +- def int_nvvm_ptr_gen_to_param: Intrinsic<[llvm_anyptr_ty], +-@@ -4453,6 +4479,22 @@ +- "llvm.nvvm.sust.p.3d.v4i32.trap">, +- ClangBuiltin<"__nvvm_sust_p_3d_v4i32_trap">; +- +-+ +-+def int_nvvm_rotate_b32 +-+ : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty], +-+ [IntrNoMem, IntrSpeculatable], "llvm.nvvm.rotate.b32">, +-+ ClangBuiltin<"__nvvm_rotate_b32">; +-+ +-+def int_nvvm_rotate_b64 +-+ : DefaultAttrsIntrinsic<[llvm_i64_ty], [llvm_i64_ty, llvm_i32_ty], +-+ [IntrNoMem, IntrSpeculatable], "llvm.nvvm.rotate.b64">, +-+ ClangBuiltin<"__nvvm_rotate_b64">; +-+ +-+def int_nvvm_rotate_right_b64 +-+ : DefaultAttrsIntrinsic<[llvm_i64_ty], [llvm_i64_ty, llvm_i32_ty], +-+ [IntrNoMem, IntrSpeculatable], "llvm.nvvm.rotate.right.b64">, +-+ ClangBuiltin<"__nvvm_rotate_right_b64">; +-+ +- def int_nvvm_swap_lo_hi_b64 +- : DefaultAttrsIntrinsic<[llvm_i64_ty], [llvm_i64_ty], +- [IntrNoMem, IntrSpeculatable], "llvm.nvvm.swap.lo.hi.b64">, +-diff -ruN --strip-trailing-cr a/llvm/lib/IR/AutoUpgrade.cpp b/llvm/lib/IR/AutoUpgrade.cpp +---- a/llvm/lib/IR/AutoUpgrade.cpp +-+++ b/llvm/lib/IR/AutoUpgrade.cpp +-@@ -1272,19 +1272,6 @@ +- // nvvm.bitcast.{f2i,i2f,ll2d,d2ll} +- Expand = +- Name == "f2i" || Name == "i2f" || Name == "ll2d" || Name == "d2ll"; +-- else if (Name.consume_front("rotate.")) +-- // nvvm.rotate.{b32,b64,right.b64} +-- Expand = Name == "b32" || Name == "b64" || Name == "right.b64"; +-- else if (Name.consume_front("ptr.gen.to.")) +-- // nvvm.ptr.gen.to.{local,shared,global,constant} +-- Expand = Name.starts_with("local") || Name.starts_with("shared") || +-- Name.starts_with("global") || Name.starts_with("constant"); +-- else if (Name.consume_front("ptr.")) +-- // nvvm.ptr.{local,shared,global,constant}.to.gen +-- Expand = +-- (Name.consume_front("local") || Name.consume_front("shared") || +-- Name.consume_front("global") || Name.consume_front("constant")) && +-- Name.starts_with(".to.gen"); +- else +- Expand = false; +- +-@@ -2271,117 +2258,6 @@ +- } +- } +- +--static Value *upgradeNVVMIntrinsicCall(StringRef Name, CallBase *CI, +-- Function *F, IRBuilder<> &Builder) { +-- Value *Rep = nullptr; +-- +-- if (Name == "abs.i" || Name == "abs.ll") { +-- Value *Arg = CI->getArgOperand(0); +-- Value *Neg = Builder.CreateNeg(Arg, "neg"); +-- Value *Cmp = Builder.CreateICmpSGE( +-- Arg, llvm::Constant::getNullValue(Arg->getType()), "abs.cond"); +-- Rep = Builder.CreateSelect(Cmp, Arg, Neg, "abs"); +-- } else if (Name.starts_with("atomic.load.add.f32.p") || +-- Name.starts_with("atomic.load.add.f64.p")) { +-- Value *Ptr = CI->getArgOperand(0); +-- Value *Val = CI->getArgOperand(1); +-- Rep = Builder.CreateAtomicRMW(AtomicRMWInst::FAdd, Ptr, Val, MaybeAlign(), +-- AtomicOrdering::SequentiallyConsistent); +-- } else if (Name.consume_front("max.") && +-- (Name == "s" || Name == "i" || Name == "ll" || Name == "us" || +-- Name == "ui" || Name == "ull")) { +-- Value *Arg0 = CI->getArgOperand(0); +-- Value *Arg1 = CI->getArgOperand(1); +-- Value *Cmp = Name.starts_with("u") +-- ? Builder.CreateICmpUGE(Arg0, Arg1, "max.cond") +-- : Builder.CreateICmpSGE(Arg0, Arg1, "max.cond"); +-- Rep = Builder.CreateSelect(Cmp, Arg0, Arg1, "max"); +-- } else if (Name.consume_front("min.") && +-- (Name == "s" || Name == "i" || Name == "ll" || Name == "us" || +-- Name == "ui" || Name == "ull")) { +-- Value *Arg0 = CI->getArgOperand(0); +-- Value *Arg1 = CI->getArgOperand(1); +-- Value *Cmp = Name.starts_with("u") +-- ? Builder.CreateICmpULE(Arg0, Arg1, "min.cond") +-- : Builder.CreateICmpSLE(Arg0, Arg1, "min.cond"); +-- Rep = Builder.CreateSelect(Cmp, Arg0, Arg1, "min"); +-- } else if (Name == "clz.ll") { +-- // llvm.nvvm.clz.ll returns an i32, but llvm.ctlz.i64 returns an i64. +-- Value *Arg = CI->getArgOperand(0); +-- Value *Ctlz = Builder.CreateCall( +-- Intrinsic::getDeclaration(F->getParent(), Intrinsic::ctlz, +-- {Arg->getType()}), +-- {Arg, Builder.getFalse()}, "ctlz"); +-- Rep = Builder.CreateTrunc(Ctlz, Builder.getInt32Ty(), "ctlz.trunc"); +-- } else if (Name == "popc.ll") { +-- // llvm.nvvm.popc.ll returns an i32, but llvm.ctpop.i64 returns an +-- // i64. +-- Value *Arg = CI->getArgOperand(0); +-- Value *Popc = Builder.CreateCall( +-- Intrinsic::getDeclaration(F->getParent(), Intrinsic::ctpop, +-- {Arg->getType()}), +-- Arg, "ctpop"); +-- Rep = Builder.CreateTrunc(Popc, Builder.getInt32Ty(), "ctpop.trunc"); +-- } else if (Name == "h2f") { +-- Rep = Builder.CreateCall( +-- Intrinsic::getDeclaration(F->getParent(), Intrinsic::convert_from_fp16, +-- {Builder.getFloatTy()}), +-- CI->getArgOperand(0), "h2f"); +-- } else if (Name.consume_front("bitcast.") && +-- (Name == "f2i" || Name == "i2f" || Name == "ll2d" || +-- Name == "d2ll")) { +-- Rep = Builder.CreateBitCast(CI->getArgOperand(0), CI->getType()); +-- } else if (Name == "rotate.b32") { +-- Value *Arg = CI->getOperand(0); +-- Value *ShiftAmt = CI->getOperand(1); +-- Rep = Builder.CreateIntrinsic(Builder.getInt32Ty(), Intrinsic::fshl, +-- {Arg, Arg, ShiftAmt}); +-- } else if (Name == "rotate.b64") { +-- Type *Int64Ty = Builder.getInt64Ty(); +-- Value *Arg = CI->getOperand(0); +-- Value *ZExtShiftAmt = Builder.CreateZExt(CI->getOperand(1), Int64Ty); +-- Rep = Builder.CreateIntrinsic(Int64Ty, Intrinsic::fshl, +-- {Arg, Arg, ZExtShiftAmt}); +-- } else if (Name == "rotate.right.b64") { +-- Type *Int64Ty = Builder.getInt64Ty(); +-- Value *Arg = CI->getOperand(0); +-- Value *ZExtShiftAmt = Builder.CreateZExt(CI->getOperand(1), Int64Ty); +-- Rep = Builder.CreateIntrinsic(Int64Ty, Intrinsic::fshr, +-- {Arg, Arg, ZExtShiftAmt}); +-- } else if ((Name.consume_front("ptr.gen.to.") && +-- (Name.starts_with("local") || Name.starts_with("shared") || +-- Name.starts_with("global") || Name.starts_with("constant"))) || +-- (Name.consume_front("ptr.") && +-- (Name.consume_front("local") || Name.consume_front("shared") || +-- Name.consume_front("global") || +-- Name.consume_front("constant")) && +-- Name.starts_with(".to.gen"))) { +-- Rep = Builder.CreateAddrSpaceCast(CI->getArgOperand(0), CI->getType()); +-- } else { +-- Intrinsic::ID IID = shouldUpgradeNVPTXBF16Intrinsic(Name); +-- if (IID != Intrinsic::not_intrinsic && +-- !F->getReturnType()->getScalarType()->isBFloatTy()) { +-- rename(F); +-- Function *NewFn = Intrinsic::getDeclaration(F->getParent(), IID); +-- SmallVector Args; +-- for (size_t I = 0; I < NewFn->arg_size(); ++I) { +-- Value *Arg = CI->getArgOperand(I); +-- Type *OldType = Arg->getType(); +-- Type *NewType = NewFn->getArg(I)->getType(); +-- Args.push_back( +-- (OldType->isIntegerTy() && NewType->getScalarType()->isBFloatTy()) +-- ? Builder.CreateBitCast(Arg, NewType) +-- : Arg); +-- } +-- Rep = Builder.CreateCall(NewFn, Args); +-- if (F->getReturnType()->isIntegerTy()) +-- Rep = Builder.CreateBitCast(Rep, F->getReturnType()); +-- } +-- } +-- +-- return Rep; +--} +-- +- static Value *upgradeX86IntrinsicCall(StringRef Name, CallBase *CI, Function *F, +- IRBuilder<> &Builder) { +- LLVMContext &C = F->getContext(); +-@@ -4332,8 +4208,85 @@ +- +- if (!IsX86 && Name == "stackprotectorcheck") { +- Rep = nullptr; +-+ } else if (IsNVVM && (Name == "abs.i" || Name == "abs.ll")) { +-+ Value *Arg = CI->getArgOperand(0); +-+ Value *Neg = Builder.CreateNeg(Arg, "neg"); +-+ Value *Cmp = Builder.CreateICmpSGE( +-+ Arg, llvm::Constant::getNullValue(Arg->getType()), "abs.cond"); +-+ Rep = Builder.CreateSelect(Cmp, Arg, Neg, "abs"); +-+ } else if (IsNVVM && (Name.starts_with("atomic.load.add.f32.p") || +-+ Name.starts_with("atomic.load.add.f64.p"))) { +-+ Value *Ptr = CI->getArgOperand(0); +-+ Value *Val = CI->getArgOperand(1); +-+ Rep = Builder.CreateAtomicRMW(AtomicRMWInst::FAdd, Ptr, Val, MaybeAlign(), +-+ AtomicOrdering::SequentiallyConsistent); +-+ } else if (IsNVVM && Name.consume_front("max.") && +-+ (Name == "s" || Name == "i" || Name == "ll" || Name == "us" || +-+ Name == "ui" || Name == "ull")) { +-+ Value *Arg0 = CI->getArgOperand(0); +-+ Value *Arg1 = CI->getArgOperand(1); +-+ Value *Cmp = Name.starts_with("u") +-+ ? Builder.CreateICmpUGE(Arg0, Arg1, "max.cond") +-+ : Builder.CreateICmpSGE(Arg0, Arg1, "max.cond"); +-+ Rep = Builder.CreateSelect(Cmp, Arg0, Arg1, "max"); +-+ } else if (IsNVVM && Name.consume_front("min.") && +-+ (Name == "s" || Name == "i" || Name == "ll" || Name == "us" || +-+ Name == "ui" || Name == "ull")) { +-+ Value *Arg0 = CI->getArgOperand(0); +-+ Value *Arg1 = CI->getArgOperand(1); +-+ Value *Cmp = Name.starts_with("u") +-+ ? Builder.CreateICmpULE(Arg0, Arg1, "min.cond") +-+ : Builder.CreateICmpSLE(Arg0, Arg1, "min.cond"); +-+ Rep = Builder.CreateSelect(Cmp, Arg0, Arg1, "min"); +-+ } else if (IsNVVM && Name == "clz.ll") { +-+ // llvm.nvvm.clz.ll returns an i32, but llvm.ctlz.i64 returns an i64. +-+ Value *Arg = CI->getArgOperand(0); +-+ Value *Ctlz = Builder.CreateCall( +-+ Intrinsic::getDeclaration(F->getParent(), Intrinsic::ctlz, +-+ {Arg->getType()}), +-+ {Arg, Builder.getFalse()}, "ctlz"); +-+ Rep = Builder.CreateTrunc(Ctlz, Builder.getInt32Ty(), "ctlz.trunc"); +-+ } else if (IsNVVM && Name == "popc.ll") { +-+ // llvm.nvvm.popc.ll returns an i32, but llvm.ctpop.i64 returns an +-+ // i64. +-+ Value *Arg = CI->getArgOperand(0); +-+ Value *Popc = Builder.CreateCall( +-+ Intrinsic::getDeclaration(F->getParent(), Intrinsic::ctpop, +-+ {Arg->getType()}), +-+ Arg, "ctpop"); +-+ Rep = Builder.CreateTrunc(Popc, Builder.getInt32Ty(), "ctpop.trunc"); +- } else if (IsNVVM) { +-- Rep = upgradeNVVMIntrinsicCall(Name, CI, F, Builder); +-+ if (Name == "h2f") { +-+ Rep = +-+ Builder.CreateCall(Intrinsic::getDeclaration( +-+ F->getParent(), Intrinsic::convert_from_fp16, +-+ {Builder.getFloatTy()}), +-+ CI->getArgOperand(0), "h2f"); +-+ } else if (Name.consume_front("bitcast.") && +-+ (Name == "f2i" || Name == "i2f" || Name == "ll2d" || +-+ Name == "d2ll")) { +-+ Rep = Builder.CreateBitCast(CI->getArgOperand(0), CI->getType()); +-+ } else { +-+ Intrinsic::ID IID = shouldUpgradeNVPTXBF16Intrinsic(Name); +-+ if (IID != Intrinsic::not_intrinsic && +-+ !F->getReturnType()->getScalarType()->isBFloatTy()) { +-+ rename(F); +-+ NewFn = Intrinsic::getDeclaration(F->getParent(), IID); +-+ SmallVector Args; +-+ for (size_t I = 0; I < NewFn->arg_size(); ++I) { +-+ Value *Arg = CI->getArgOperand(I); +-+ Type *OldType = Arg->getType(); +-+ Type *NewType = NewFn->getArg(I)->getType(); +-+ Args.push_back((OldType->isIntegerTy() && +-+ NewType->getScalarType()->isBFloatTy()) +-+ ? Builder.CreateBitCast(Arg, NewType) +-+ : Arg); +-+ } +-+ Rep = Builder.CreateCall(NewFn, Args); +-+ if (F->getReturnType()->isIntegerTy()) +-+ Rep = Builder.CreateBitCast(Rep, F->getReturnType()); +-+ } +-+ } +- } else if (IsX86) { +- Rep = upgradeX86IntrinsicCall(Name, CI, F, Builder); +- } else if (IsARM) { +-diff -ruN --strip-trailing-cr a/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp b/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp +---- a/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp +-+++ b/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp +-@@ -292,7 +292,6 @@ +- static const LLT S224 = LLT::scalar(224); +- static const LLT S256 = LLT::scalar(256); +- static const LLT S512 = LLT::scalar(512); +--static const LLT S1024 = LLT::scalar(1024); +- static const LLT MaxScalar = LLT::scalar(MaxRegisterSize); +- +- static const LLT V2S8 = LLT::fixed_vector(2, 8); +-@@ -333,8 +332,8 @@ +- static const LLT V2S128 = LLT::fixed_vector(2, 128); +- static const LLT V4S128 = LLT::fixed_vector(4, 128); +- +--static std::initializer_list AllScalarTypes = { +-- S32, S64, S96, S128, S160, S224, S256, S512, S1024}; +-+static std::initializer_list AllScalarTypes = {S32, S64, S96, S128, +-+ S160, S224, S256, S512}; +- +- static std::initializer_list AllS16Vectors{ +- V2S16, V4S16, V6S16, V8S16, V10S16, V12S16, V16S16, V2S128, V4S128}; +-@@ -890,11 +889,10 @@ +- .clampScalar(0, S16, S64); +- +- getActionDefinitionsBuilder({G_IMPLICIT_DEF, G_FREEZE}) +-- .legalIf(isRegisterClassType(0)) +-+ .legalIf(isRegisterType(0)) +- // s1 and s16 are special cases because they have legal operations on +- // them, but don't really occupy registers in the normal way. +- .legalFor({S1, S16}) +-- .clampNumElements(0, V16S32, V32S32) +- .moreElementsIf(isSmallOddVector(0), oneMoreElement(0)) +- .clampScalarOrElt(0, S32, MaxScalar) +- .widenScalarToNextPow2(0, 32) +-diff -ruN --strip-trailing-cr a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td +---- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td +-+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td +-@@ -174,6 +174,10 @@ +- def hasSHFL : Predicate<"!(Subtarget->getSmVersion() >= 70" +- "&& Subtarget->getPTXVersion() >= 64)">; +- +-+def useShortPtrLocal : Predicate<"TM.is64Bit() && TM.getPointerSizeInBits(ADDRESS_SPACE_LOCAL) == 32">; +-+def useShortPtrShared : Predicate<"TM.is64Bit() && TM.getPointerSizeInBits(ADDRESS_SPACE_SHARED) == 32">; +-+def useShortPtrConst : Predicate<"TM.is64Bit() && TM.getPointerSizeInBits(ADDRESS_SPACE_CONST) == 32">; +-+ +- def useFP16Math: Predicate<"Subtarget->allowFP16Math()">; +- def hasBF16Math: Predicate<"Subtarget->hasBF16Math()">; +- +-@@ -1661,6 +1665,167 @@ +- "brev.b64 \t$dst, $a;", +- [(set Int64Regs:$dst, (bitreverse Int64Regs:$a))]>; +- +-+// +-+// Rotate: Use ptx shf instruction if available. +-+// +-+ +-+// 32 bit r2 = rotl r1, n +-+// => +-+// r2 = shf.l r1, r1, n +-+def ROTL32imm_hw : +-+ NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$src, i32imm:$amt), +-+ "shf.l.wrap.b32 \t$dst, $src, $src, $amt;", +-+ [(set Int32Regs:$dst, (rotl (i32 Int32Regs:$src), (i32 imm:$amt)))]>, +-+ Requires<[hasHWROT32]>; +-+ +-+def ROTL32reg_hw : +-+ NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$src, Int32Regs:$amt), +-+ "shf.l.wrap.b32 \t$dst, $src, $src, $amt;", +-+ [(set Int32Regs:$dst, (rotl (i32 Int32Regs:$src), (i32 Int32Regs:$amt)))]>, +-+ Requires<[hasHWROT32]>; +-+ +-+// 32 bit r2 = rotr r1, n +-+// => +-+// r2 = shf.r r1, r1, n +-+def ROTR32imm_hw : +-+ NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$src, i32imm:$amt), +-+ "shf.r.wrap.b32 \t$dst, $src, $src, $amt;", +-+ [(set Int32Regs:$dst, (rotr (i32 Int32Regs:$src), (i32 imm:$amt)))]>, +-+ Requires<[hasHWROT32]>; +-+ +-+def ROTR32reg_hw : +-+ NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$src, Int32Regs:$amt), +-+ "shf.r.wrap.b32 \t$dst, $src, $src, $amt;", +-+ [(set Int32Regs:$dst, (rotr (i32 Int32Regs:$src), (i32 Int32Regs:$amt)))]>, +-+ Requires<[hasHWROT32]>; +-+ +-+// 32-bit software rotate by immediate. $amt2 should equal 32 - $amt1. +-+def ROT32imm_sw : +-+ NVPTXInst<(outs Int32Regs:$dst), +-+ (ins Int32Regs:$src, i32imm:$amt1, i32imm:$amt2), +-+ "{{\n\t" +-+ ".reg .b32 %lhs;\n\t" +-+ ".reg .b32 %rhs;\n\t" +-+ "shl.b32 \t%lhs, $src, $amt1;\n\t" +-+ "shr.b32 \t%rhs, $src, $amt2;\n\t" +-+ "add.u32 \t$dst, %lhs, %rhs;\n\t" +-+ "}}", +-+ []>; +-+ +-+def SUB_FRM_32 : SDNodeXFormgetTargetConstant(32 - N->getZExtValue(), SDLoc(N), MVT::i32); +-+}]>; +-+ +-+def : Pat<(rotl (i32 Int32Regs:$src), (i32 imm:$amt)), +-+ (ROT32imm_sw Int32Regs:$src, imm:$amt, (SUB_FRM_32 node:$amt))>, +-+ Requires<[noHWROT32]>; +-+def : Pat<(rotr (i32 Int32Regs:$src), (i32 imm:$amt)), +-+ (ROT32imm_sw Int32Regs:$src, (SUB_FRM_32 node:$amt), imm:$amt)>, +-+ Requires<[noHWROT32]>; +-+ +-+// 32-bit software rotate left by register. +-+def ROTL32reg_sw : +-+ NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$src, Int32Regs:$amt), +-+ "{{\n\t" +-+ ".reg .b32 %lhs;\n\t" +-+ ".reg .b32 %rhs;\n\t" +-+ ".reg .b32 %amt2;\n\t" +-+ "shl.b32 \t%lhs, $src, $amt;\n\t" +-+ "sub.s32 \t%amt2, 32, $amt;\n\t" +-+ "shr.b32 \t%rhs, $src, %amt2;\n\t" +-+ "add.u32 \t$dst, %lhs, %rhs;\n\t" +-+ "}}", +-+ [(set Int32Regs:$dst, (rotl (i32 Int32Regs:$src), (i32 Int32Regs:$amt)))]>, +-+ Requires<[noHWROT32]>; +-+ +-+// 32-bit software rotate right by register. +-+def ROTR32reg_sw : +-+ NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$src, Int32Regs:$amt), +-+ "{{\n\t" +-+ ".reg .b32 %lhs;\n\t" +-+ ".reg .b32 %rhs;\n\t" +-+ ".reg .b32 %amt2;\n\t" +-+ "shr.b32 \t%lhs, $src, $amt;\n\t" +-+ "sub.s32 \t%amt2, 32, $amt;\n\t" +-+ "shl.b32 \t%rhs, $src, %amt2;\n\t" +-+ "add.u32 \t$dst, %lhs, %rhs;\n\t" +-+ "}}", +-+ [(set Int32Regs:$dst, (rotr (i32 Int32Regs:$src), (i32 Int32Regs:$amt)))]>, +-+ Requires<[noHWROT32]>; +-+ +-+// 64-bit software rotate by immediate. $amt2 should equal 64 - $amt1. +-+def ROT64imm_sw : +-+ NVPTXInst<(outs Int64Regs:$dst), +-+ (ins Int64Regs:$src, i32imm:$amt1, i32imm:$amt2), +-+ "{{\n\t" +-+ ".reg .b64 %lhs;\n\t" +-+ ".reg .b64 %rhs;\n\t" +-+ "shl.b64 \t%lhs, $src, $amt1;\n\t" +-+ "shr.b64 \t%rhs, $src, $amt2;\n\t" +-+ "add.u64 \t$dst, %lhs, %rhs;\n\t" +-+ "}}", +-+ []>; +-+ +-+def SUB_FRM_64 : SDNodeXFormgetTargetConstant(64-N->getZExtValue(), SDLoc(N), MVT::i32); +-+}]>; +-+ +-+def : Pat<(rotl Int64Regs:$src, (i32 imm:$amt)), +-+ (ROT64imm_sw Int64Regs:$src, imm:$amt, (SUB_FRM_64 node:$amt))>; +-+def : Pat<(rotr Int64Regs:$src, (i32 imm:$amt)), +-+ (ROT64imm_sw Int64Regs:$src, (SUB_FRM_64 node:$amt), imm:$amt)>; +-+ +-+// 64-bit software rotate left by register. +-+def ROTL64reg_sw : +-+ NVPTXInst<(outs Int64Regs:$dst), (ins Int64Regs:$src, Int32Regs:$amt), +-+ "{{\n\t" +-+ ".reg .b64 %lhs;\n\t" +-+ ".reg .b64 %rhs;\n\t" +-+ ".reg .u32 %amt2;\n\t" +-+ "and.b32 \t%amt2, $amt, 63;\n\t" +-+ "shl.b64 \t%lhs, $src, %amt2;\n\t" +-+ "sub.u32 \t%amt2, 64, %amt2;\n\t" +-+ "shr.b64 \t%rhs, $src, %amt2;\n\t" +-+ "add.u64 \t$dst, %lhs, %rhs;\n\t" +-+ "}}", +-+ [(set Int64Regs:$dst, (rotl Int64Regs:$src, (i32 Int32Regs:$amt)))]>; +-+ +-+def ROTR64reg_sw : +-+ NVPTXInst<(outs Int64Regs:$dst), (ins Int64Regs:$src, Int32Regs:$amt), +-+ "{{\n\t" +-+ ".reg .b64 %lhs;\n\t" +-+ ".reg .b64 %rhs;\n\t" +-+ ".reg .u32 %amt2;\n\t" +-+ "and.b32 \t%amt2, $amt, 63;\n\t" +-+ "shr.b64 \t%lhs, $src, %amt2;\n\t" +-+ "sub.u32 \t%amt2, 64, %amt2;\n\t" +-+ "shl.b64 \t%rhs, $src, %amt2;\n\t" +-+ "add.u64 \t$dst, %lhs, %rhs;\n\t" +-+ "}}", +-+ [(set Int64Regs:$dst, (rotr Int64Regs:$src, (i32 Int32Regs:$amt)))]>; +-+ +-+// +-+// Funnnel shift in clamp mode +-+// +-+ +-+// Create SDNodes so they can be used in the DAG code, e.g. +-+// NVPTXISelLowering (LowerShiftLeftParts and LowerShiftRightParts) +-+def FUN_SHFL_CLAMP : SDNode<"NVPTXISD::FUN_SHFL_CLAMP", SDTIntShiftDOp, []>; +-+def FUN_SHFR_CLAMP : SDNode<"NVPTXISD::FUN_SHFR_CLAMP", SDTIntShiftDOp, []>; +-+ +-+def FUNSHFLCLAMP : +-+ NVPTXInst<(outs Int32Regs:$dst), +-+ (ins Int32Regs:$lo, Int32Regs:$hi, Int32Regs:$amt), +-+ "shf.l.clamp.b32 \t$dst, $lo, $hi, $amt;", +-+ [(set Int32Regs:$dst, +-+ (FUN_SHFL_CLAMP (i32 Int32Regs:$lo), (i32 Int32Regs:$hi), (i32 Int32Regs:$amt)))]>; +-+ +-+def FUNSHFRCLAMP : +-+ NVPTXInst<(outs Int32Regs:$dst), +-+ (ins Int32Regs:$lo, Int32Regs:$hi, Int32Regs:$amt), +-+ "shf.r.clamp.b32 \t$dst, $lo, $hi, $amt;", +-+ [(set Int32Regs:$dst, +-+ (FUN_SHFR_CLAMP (i32 Int32Regs:$lo), (i32 Int32Regs:$hi), (i32 Int32Regs:$amt)))]>; +- +- // +- // BFE - bit-field extract +-@@ -3492,42 +3657,6 @@ +- def: Pat<(v2i16 (scalar_to_vector (i16 Int16Regs:$a))), +- (CVT_u32_u16 Int16Regs:$a, CvtNONE)>; +- +--// +--// Funnel-Shift +--// +-- +--// Create SDNodes so they can be used in the DAG code, e.g. +--// NVPTXISelLowering (LowerShiftLeftParts and LowerShiftRightParts) +--def fshl_clamp : SDNode<"NVPTXISD::FUN_SHFL_CLAMP", SDTIntShiftDOp, []>; +--def fshr_clamp : SDNode<"NVPTXISD::FUN_SHFR_CLAMP", SDTIntShiftDOp, []>; +-- +--// Funnel shift, requires >= sm_32. Does not trap if amt is out of range, so +--// no side effects. +--let hasSideEffects = false in { +-- multiclass ShfInst { +-- def _i +-- : NVPTXInst<(outs Int32Regs:$dst), +-- (ins Int32Regs:$lo, Int32Regs:$hi, i32imm:$amt), +-- "shf." # mode # ".b32 \t$dst, $lo, $hi, $amt;", +-- [(set Int32Regs:$dst, +-- (op (i32 Int32Regs:$lo), (i32 Int32Regs:$hi), (i32 imm:$amt)))]>, +-- Requires<[hasHWROT32]>; +-- +-- def _r +-- : NVPTXInst<(outs Int32Regs:$dst), +-- (ins Int32Regs:$lo, Int32Regs:$hi, Int32Regs:$amt), +-- "shf." # mode # ".b32 \t$dst, $lo, $hi, $amt;", +-- [(set Int32Regs:$dst, +-- (op (i32 Int32Regs:$lo), (i32 Int32Regs:$hi), (i32 Int32Regs:$amt)))]>, +-- Requires<[hasHWROT32]>; +-- } +-- +-- defm SHF_L_CLAMP : ShfInst<"l.clamp", fshl_clamp>; +-- defm SHF_R_CLAMP : ShfInst<"r.clamp", fshr_clamp>; +-- defm SHF_L_WRAP : ShfInst<"l.wrap", fshl>; +-- defm SHF_R_WRAP : ShfInst<"r.wrap", fshr>; +--} +-- +- // Count leading zeros +- let hasSideEffects = false in { +- def CLZr32 : NVPTXInst<(outs Int32Regs:$d), (ins Int32Regs:$a), +-diff -ruN --strip-trailing-cr a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td +---- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td +-+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td +-@@ -2537,45 +2537,59 @@ +- : VLDG_G_ELE_V4<"v4.f32 \t{{$dst1, $dst2, $dst3, $dst4}}, [$src];", Float32Regs>; +- +- +--multiclass NG_TO_G { +-+multiclass NG_TO_G { +- def "" : NVPTXInst<(outs Int32Regs:$result), (ins Int32Regs:$src), +-- "cvta." # Str # ".u32 \t$result, $src;", []>; +-+ !strconcat("cvta.", Str, ".u32 \t$result, $src;"), +-+ [(set Int32Regs:$result, (Intrin Int32Regs:$src))]>; +- def _64 : NVPTXInst<(outs Int64Regs:$result), (ins Int64Regs:$src), +-- "cvta." # Str # ".u64 \t$result, $src;", []>; +-+ !strconcat("cvta.", Str, ".u64 \t$result, $src;"), +-+ [(set Int64Regs:$result, (Intrin Int64Regs:$src))]>; +-+ def _6432 : NVPTXInst<(outs Int64Regs:$result), (ins Int32Regs:$src), +-+ "{{ .reg .b64 %tmp;\n\t" +-+ #" cvt.u64.u32 \t%tmp, $src;\n\t" +-+ #" cvta." # Str # ".u64 \t$result, %tmp; }}", +-+ [(set Int64Regs:$result, (Intrin Int32Regs:$src))]>, +-+ Requires<[ShortPtr]>; +- } +- +--multiclass G_TO_NG { +-+multiclass G_TO_NG { +- def "" : NVPTXInst<(outs Int32Regs:$result), (ins Int32Regs:$src), +-- "cvta.to." # Str # ".u32 \t$result, $src;", []>; +-+ !strconcat("cvta.to.", Str, ".u32 \t$result, $src;"), +-+ [(set Int32Regs:$result, (Intrin Int32Regs:$src))]>; +- def _64 : NVPTXInst<(outs Int64Regs:$result), (ins Int64Regs:$src), +-- "cvta.to." # Str # ".u64 \t$result, $src;", []>; +-+ !strconcat("cvta.to.", Str, ".u64 \t$result, $src;"), +-+ [(set Int64Regs:$result, (Intrin Int64Regs:$src))]>; +-+ def _3264 : NVPTXInst<(outs Int32Regs:$result), (ins Int64Regs:$src), +-+ "{{ .reg .b64 %tmp;\n\t" +-+ #" cvta.to." # Str # ".u64 \t%tmp, $src;\n\t" +-+ #" cvt.u32.u64 \t$result, %tmp; }}", +-+ [(set Int32Regs:$result, (Intrin Int64Regs:$src))]>, +-+ Requires<[ShortPtr]>; +- } +- +--defm cvta_local : NG_TO_G<"local">; +--defm cvta_shared : NG_TO_G<"shared">; +--defm cvta_global : NG_TO_G<"global">; +--defm cvta_const : NG_TO_G<"const">; +-- +--defm cvta_to_local : G_TO_NG<"local">; +--defm cvta_to_shared : G_TO_NG<"shared">; +--defm cvta_to_global : G_TO_NG<"global">; +--defm cvta_to_const : G_TO_NG<"const">; +-- +--// nvvm.ptr.param.to.gen +--defm cvta_param : NG_TO_G<"param">; +-- +--def : Pat<(int_nvvm_ptr_param_to_gen Int32Regs:$src), +-- (cvta_param Int32Regs:$src)>; +-- +--def : Pat<(int_nvvm_ptr_param_to_gen Int64Regs:$src), +-- (cvta_param_64 Int64Regs:$src)>; +-+defm cvta_local : NG_TO_G<"local", int_nvvm_ptr_local_to_gen, useShortPtrLocal>; +-+defm cvta_shared : NG_TO_G<"shared", int_nvvm_ptr_shared_to_gen, useShortPtrShared>; +-+defm cvta_global : NG_TO_G<"global", int_nvvm_ptr_global_to_gen, False>; +-+defm cvta_const : NG_TO_G<"const", int_nvvm_ptr_constant_to_gen, useShortPtrConst>; +-+defm cvta_param : NG_TO_G<"param", int_nvvm_ptr_param_to_gen, False>; +-+ +-+defm cvta_to_local : G_TO_NG<"local", int_nvvm_ptr_gen_to_local, useShortPtrLocal>; +-+defm cvta_to_shared : G_TO_NG<"shared", int_nvvm_ptr_gen_to_shared, useShortPtrShared>; +-+defm cvta_to_global : G_TO_NG<"global", int_nvvm_ptr_gen_to_global, False>; +-+defm cvta_to_const : G_TO_NG<"const", int_nvvm_ptr_gen_to_constant, useShortPtrConst>; +- +- // nvvm.ptr.gen.to.param +--def : Pat<(int_nvvm_ptr_gen_to_param Int32Regs:$src), +-- (IMOV32rr Int32Regs:$src)>; +-+def nvvm_ptr_gen_to_param : NVPTXInst<(outs Int32Regs:$result), +-+ (ins Int32Regs:$src), +-+ "mov.u32 \t$result, $src;", +-+ [(set Int32Regs:$result, +-+ (int_nvvm_ptr_gen_to_param Int32Regs:$src))]>; +-+def nvvm_ptr_gen_to_param_64 : NVPTXInst<(outs Int64Regs:$result), +-+ (ins Int64Regs:$src), +-+ "mov.u64 \t$result, $src;", +-+ [(set Int64Regs:$result, +-+ (int_nvvm_ptr_gen_to_param Int64Regs:$src))]>; +- +--def : Pat<(int_nvvm_ptr_gen_to_param Int64Regs:$src), +-- (IMOV64rr Int64Regs:$src)>; +- +- // nvvm.move intrinsicc +- def nvvm_move_i16 : NVPTXInst<(outs Int16Regs:$r), (ins Int16Regs:$s), +-@@ -2618,6 +2632,24 @@ +- [(set Int64Regs:$r, +- (int_nvvm_move_ptr texternalsym:$s))]>;*/ +- +-+ +-+// MoveParam %r1, param +-+// ptr_local_to_gen %r2, %r1 +-+// ptr_gen_to_local %r3, %r2 +-+// -> +-+// mov %r1, param +-+ +-+// @TODO: Revisit this. There is a type +-+// contradiction between iPTRAny and iPTR for the addr defs, so the move_sym +-+// instructions are not currently defined. However, we can use the ptr +-+// variants and the asm printer will do the right thing. +-+def : Pat<(i64 (int_nvvm_ptr_gen_to_local (int_nvvm_ptr_local_to_gen +-+ (MoveParam texternalsym:$src)))), +-+ (nvvm_move_ptr64 texternalsym:$src)>; +-+def : Pat<(i32 (int_nvvm_ptr_gen_to_local (int_nvvm_ptr_local_to_gen +-+ (MoveParam texternalsym:$src)))), +-+ (nvvm_move_ptr32 texternalsym:$src)>; +-+ +- def texsurf_handles +- : NVPTXInst<(outs Int64Regs:$result), (ins imem:$src), +- "mov.u64 \t$result, $src;", []>; +-@@ -2701,9 +2733,134 @@ +- def : Pat<(int_nvvm_read_ptx_sreg_envreg31), (MOV_SPECIAL ENVREG31)>; +- +- +-+// rotate builtin support +-+ +-+def ROTATE_B32_HW_IMM +-+ : NVPTXInst<(outs Int32Regs:$dst), +-+ (ins Int32Regs:$src, i32imm:$amt), +-+ "shf.l.wrap.b32 \t$dst, $src, $src, $amt;", +-+ [(set Int32Regs:$dst, +-+ (int_nvvm_rotate_b32 Int32Regs:$src, (i32 imm:$amt)))]>, +-+ Requires<[hasHWROT32]> ; +-+ +-+def ROTATE_B32_HW_REG +-+ : NVPTXInst<(outs Int32Regs:$dst), +-+ (ins Int32Regs:$src, Int32Regs:$amt), +-+ "shf.l.wrap.b32 \t$dst, $src, $src, $amt;", +-+ [(set Int32Regs:$dst, +-+ (int_nvvm_rotate_b32 Int32Regs:$src, Int32Regs:$amt))]>, +-+ Requires<[hasHWROT32]> ; +-+ +-+def : Pat<(int_nvvm_rotate_b32 Int32Regs:$src, (i32 imm:$amt)), +-+ (ROT32imm_sw Int32Regs:$src, imm:$amt, (SUB_FRM_32 node:$amt))>, +-+ Requires<[noHWROT32]> ; +-+ +-+def : Pat<(int_nvvm_rotate_b32 Int32Regs:$src, Int32Regs:$amt), +-+ (ROTL32reg_sw Int32Regs:$src, Int32Regs:$amt)>, +-+ Requires<[noHWROT32]> ; +-+ +-+let hasSideEffects = false in { +-+ def GET_LO_INT64 : NVPTXInst<(outs Int32Regs:$dst), (ins Int64Regs:$src), +-+ !strconcat("{{\n\t", +-+ ".reg .b32 %dummy;\n\t", +-+ "mov.b64 \t{$dst,%dummy}, $src;\n\t", +-+ "}}"), +-+ []> ; +-+ +-+ def GET_HI_INT64 : NVPTXInst<(outs Int32Regs:$dst), (ins Int64Regs:$src), +-+ !strconcat("{{\n\t", +-+ ".reg .b32 %dummy;\n\t", +-+ "mov.b64 \t{%dummy,$dst}, $src;\n\t", +-+ "}}"), +-+ []> ; +-+} +-+ +-+let hasSideEffects = false in { +-+ def PACK_TWO_INT32 +-+ : NVPTXInst<(outs Int64Regs:$dst), (ins Int32Regs:$lo, Int32Regs:$hi), +-+ "mov.b64 \t$dst, {{$lo, $hi}};", []> ; +-+} +-+ +- def : Pat<(int_nvvm_swap_lo_hi_b64 Int64Regs:$src), +-- (V2I32toI64 (I64toI32H Int64Regs:$src), +-- (I64toI32L Int64Regs:$src))> ; +-+ (PACK_TWO_INT32 (GET_HI_INT64 Int64Regs:$src), +-+ (GET_LO_INT64 Int64Regs:$src))> ; +-+ +-+// Funnel shift, requires >= sm_32. Does not trap if amt is out of range, so +-+// no side effects. +-+let hasSideEffects = false in { +-+ def SHF_L_WRAP_B32_IMM +-+ : NVPTXInst<(outs Int32Regs:$dst), +-+ (ins Int32Regs:$lo, Int32Regs:$hi, i32imm:$amt), +-+ "shf.l.wrap.b32 \t$dst, $lo, $hi, $amt;",[]>, +-+ Requires<[hasHWROT32]>; +-+ +-+ def SHF_L_WRAP_B32_REG +-+ : NVPTXInst<(outs Int32Regs:$dst), +-+ (ins Int32Regs:$lo, Int32Regs:$hi, Int32Regs:$amt), +-+ "shf.l.wrap.b32 \t$dst, $lo, $hi, $amt;",[]>, +-+ Requires<[hasHWROT32]>; +-+ +-+ def SHF_R_WRAP_B32_IMM +-+ : NVPTXInst<(outs Int32Regs:$dst), +-+ (ins Int32Regs:$lo, Int32Regs:$hi, i32imm:$amt), +-+ "shf.r.wrap.b32 \t$dst, $lo, $hi, $amt;",[]>, +-+ Requires<[hasHWROT32]>; +-+ +-+ def SHF_R_WRAP_B32_REG +-+ : NVPTXInst<(outs Int32Regs:$dst), +-+ (ins Int32Regs:$lo, Int32Regs:$hi, Int32Regs:$amt), +-+ "shf.r.wrap.b32 \t$dst, $lo, $hi, $amt;",[]>, +-+ Requires<[hasHWROT32]>; +-+} +-+ +-+// HW version of rotate 64 +-+def : Pat<(int_nvvm_rotate_b64 Int64Regs:$src, (i32 imm:$amt)), +-+ (PACK_TWO_INT32 +-+ (SHF_L_WRAP_B32_IMM (GET_HI_INT64 Int64Regs:$src), +-+ (GET_LO_INT64 Int64Regs:$src), imm:$amt), +-+ (SHF_L_WRAP_B32_IMM (GET_LO_INT64 Int64Regs:$src), +-+ (GET_HI_INT64 Int64Regs:$src), imm:$amt))>, +-+ Requires<[hasHWROT32]>; +-+ +-+def : Pat<(int_nvvm_rotate_b64 Int64Regs:$src, Int32Regs:$amt), +-+ (PACK_TWO_INT32 +-+ (SHF_L_WRAP_B32_REG (GET_HI_INT64 Int64Regs:$src), +-+ (GET_LO_INT64 Int64Regs:$src), Int32Regs:$amt), +-+ (SHF_L_WRAP_B32_REG (GET_LO_INT64 Int64Regs:$src), +-+ (GET_HI_INT64 Int64Regs:$src), Int32Regs:$amt))>, +-+ Requires<[hasHWROT32]>; +-+ +-+ +-+def : Pat<(int_nvvm_rotate_right_b64 Int64Regs:$src, (i32 imm:$amt)), +-+ (PACK_TWO_INT32 +-+ (SHF_R_WRAP_B32_IMM (GET_LO_INT64 Int64Regs:$src), +-+ (GET_HI_INT64 Int64Regs:$src), imm:$amt), +-+ (SHF_R_WRAP_B32_IMM (GET_HI_INT64 Int64Regs:$src), +-+ (GET_LO_INT64 Int64Regs:$src), imm:$amt))>, +-+ Requires<[hasHWROT32]>; +-+ +-+def : Pat<(int_nvvm_rotate_right_b64 Int64Regs:$src, Int32Regs:$amt), +-+ (PACK_TWO_INT32 +-+ (SHF_R_WRAP_B32_REG (GET_LO_INT64 Int64Regs:$src), +-+ (GET_HI_INT64 Int64Regs:$src), Int32Regs:$amt), +-+ (SHF_R_WRAP_B32_REG (GET_HI_INT64 Int64Regs:$src), +-+ (GET_LO_INT64 Int64Regs:$src), Int32Regs:$amt))>, +-+ Requires<[hasHWROT32]>; +-+ +-+// SW version of rotate 64 +-+def : Pat<(int_nvvm_rotate_b64 Int64Regs:$src, (i32 imm:$amt)), +-+ (ROT64imm_sw Int64Regs:$src, imm:$amt, (SUB_FRM_64 node:$amt))>, +-+ Requires<[noHWROT32]>; +-+def : Pat<(int_nvvm_rotate_b64 Int64Regs:$src, Int32Regs:$amt), +-+ (ROTL64reg_sw Int64Regs:$src, Int32Regs:$amt)>, +-+ Requires<[noHWROT32]>; +-+def : Pat<(int_nvvm_rotate_right_b64 Int64Regs:$src, (i32 imm:$amt)), +-+ (ROT64imm_sw Int64Regs:$src, (SUB_FRM_64 node:$amt), imm:$amt)>, +-+ Requires<[noHWROT32]>; +-+def : Pat<(int_nvvm_rotate_right_b64 Int64Regs:$src, Int32Regs:$amt), +-+ (ROTR64reg_sw Int64Regs:$src, Int32Regs:$amt)>, +-+ Requires<[noHWROT32]>; +-+ +- +- //----------------------------------- +- // Texture Intrinsics +-diff -ruN --strip-trailing-cr a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp +---- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp +-+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp +-@@ -1109,21 +1109,11 @@ +- AddrSpaceCastSDNode *CastN = cast(N); +- unsigned SrcAddrSpace = CastN->getSrcAddressSpace(); +- unsigned DstAddrSpace = CastN->getDestAddressSpace(); +-- SDLoc DL(N); +- assert(SrcAddrSpace != DstAddrSpace && +- "addrspacecast must be between different address spaces"); +- +- if (DstAddrSpace == ADDRESS_SPACE_GENERIC) { +- // Specific to generic +-- +-- if (TM.is64Bit() && TM.getPointerSizeInBits(SrcAddrSpace) == 32) { +-- SDValue CvtNone = +-- CurDAG->getTargetConstant(NVPTX::PTXCvtMode::NONE, DL, MVT::i32); +-- SDNode *Cvt = CurDAG->getMachineNode(NVPTX::CVT_u64_u32, DL, MVT::i64, +-- Src, CvtNone); +-- Src = SDValue(Cvt, 0); +-- } +-- +- unsigned Opc; +- switch (SrcAddrSpace) { +- default: report_fatal_error("Bad address space in addrspacecast"); +-@@ -1131,16 +1121,26 @@ +- Opc = TM.is64Bit() ? NVPTX::cvta_global_64 : NVPTX::cvta_global; +- break; +- case ADDRESS_SPACE_SHARED: +-- Opc = TM.is64Bit() ? NVPTX::cvta_shared_64 : NVPTX::cvta_shared; +-+ Opc = TM.is64Bit() ? (TM.getPointerSizeInBits(SrcAddrSpace) == 32 +-+ ? NVPTX::cvta_shared_6432 +-+ : NVPTX::cvta_shared_64) +-+ : NVPTX::cvta_shared; +- break; +- case ADDRESS_SPACE_CONST: +-- Opc = TM.is64Bit() ? NVPTX::cvta_const_64 : NVPTX::cvta_const; +-+ Opc = TM.is64Bit() ? (TM.getPointerSizeInBits(SrcAddrSpace) == 32 +-+ ? NVPTX::cvta_const_6432 +-+ : NVPTX::cvta_const_64) +-+ : NVPTX::cvta_const; +- break; +- case ADDRESS_SPACE_LOCAL: +-- Opc = TM.is64Bit() ? NVPTX::cvta_local_64 : NVPTX::cvta_local; +-+ Opc = TM.is64Bit() ? (TM.getPointerSizeInBits(SrcAddrSpace) == 32 +-+ ? NVPTX::cvta_local_6432 +-+ : NVPTX::cvta_local_64) +-+ : NVPTX::cvta_local; +- break; +- } +-- ReplaceNode(N, CurDAG->getMachineNode(Opc, DL, N->getValueType(0), Src)); +-+ ReplaceNode(N, CurDAG->getMachineNode(Opc, SDLoc(N), N->getValueType(0), +-+ Src)); +- return; +- } else { +- // Generic to specific +-@@ -1153,28 +1153,30 @@ +- Opc = TM.is64Bit() ? NVPTX::cvta_to_global_64 : NVPTX::cvta_to_global; +- break; +- case ADDRESS_SPACE_SHARED: +-- Opc = TM.is64Bit() ? NVPTX::cvta_to_shared_64 : NVPTX::cvta_to_shared; +-+ Opc = TM.is64Bit() ? (TM.getPointerSizeInBits(DstAddrSpace) == 32 +-+ ? NVPTX::cvta_to_shared_3264 +-+ : NVPTX::cvta_to_shared_64) +-+ : NVPTX::cvta_to_shared; +- break; +- case ADDRESS_SPACE_CONST: +-- Opc = TM.is64Bit() ? NVPTX::cvta_to_const_64 : NVPTX::cvta_to_const; +-+ Opc = TM.is64Bit() ? (TM.getPointerSizeInBits(DstAddrSpace) == 32 +-+ ? NVPTX::cvta_to_const_3264 +-+ : NVPTX::cvta_to_const_64) +-+ : NVPTX::cvta_to_const; +- break; +- case ADDRESS_SPACE_LOCAL: +-- Opc = TM.is64Bit() ? NVPTX::cvta_to_local_64 : NVPTX::cvta_to_local; +-+ Opc = TM.is64Bit() ? (TM.getPointerSizeInBits(DstAddrSpace) == 32 +-+ ? NVPTX::cvta_to_local_3264 +-+ : NVPTX::cvta_to_local_64) +-+ : NVPTX::cvta_to_local; +- break; +- case ADDRESS_SPACE_PARAM: +-- Opc = TM.is64Bit() ? NVPTX::IMOV64rr : NVPTX::IMOV32rr; +-+ Opc = TM.is64Bit() ? NVPTX::nvvm_ptr_gen_to_param_64 +-+ : NVPTX::nvvm_ptr_gen_to_param; +- break; +- } +-- +-- SDNode *CVTA = CurDAG->getMachineNode(Opc, DL, N->getValueType(0), Src); +-- if (TM.is64Bit() && TM.getPointerSizeInBits(DstAddrSpace) == 32) { +-- SDValue CvtNone = +-- CurDAG->getTargetConstant(NVPTX::PTXCvtMode::NONE, DL, MVT::i32); +-- CVTA = CurDAG->getMachineNode(NVPTX::CVT_u32_u64, DL, MVT::i32, +-- SDValue(CVTA, 0), CvtNone); +-- } +-- +-- ReplaceNode(N, CVTA); +-+ ReplaceNode(N, CurDAG->getMachineNode(Opc, SDLoc(N), N->getValueType(0), +-+ Src)); +- return; +- } +- } +-diff -ruN --strip-trailing-cr a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp +---- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp +-+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp +-@@ -594,13 +594,20 @@ +- setOperationAction(ISD::BITREVERSE, MVT::i32, Legal); +- setOperationAction(ISD::BITREVERSE, MVT::i64, Legal); +- +-- setOperationAction({ISD::ROTL, ISD::ROTR}, +-- {MVT::i8, MVT::i16, MVT::v2i16, MVT::i32, MVT::i64}, +-- Expand); +-- +-- if (STI.hasHWROT32()) +-- setOperationAction({ISD::FSHL, ISD::FSHR}, MVT::i32, Legal); +-+ // TODO: we may consider expanding ROTL/ROTR on older GPUs. Currently on GPUs +-+ // that don't have h/w rotation we lower them to multi-instruction assembly. +-+ // See ROT*_sw in NVPTXIntrInfo.td +-+ setOperationAction(ISD::ROTL, MVT::i64, Legal); +-+ setOperationAction(ISD::ROTR, MVT::i64, Legal); +-+ setOperationAction(ISD::ROTL, MVT::i32, Legal); +-+ setOperationAction(ISD::ROTR, MVT::i32, Legal); +- +-+ setOperationAction(ISD::ROTL, MVT::i16, Expand); +-+ setOperationAction(ISD::ROTL, MVT::v2i16, Expand); +-+ setOperationAction(ISD::ROTR, MVT::i16, Expand); +-+ setOperationAction(ISD::ROTR, MVT::v2i16, Expand); +-+ setOperationAction(ISD::ROTL, MVT::i8, Expand); +-+ setOperationAction(ISD::ROTR, MVT::i8, Expand); +- setOperationAction(ISD::BSWAP, MVT::i16, Expand); +- +- setOperationAction(ISD::BR_JT, MVT::Other, Custom); +-diff -ruN --strip-trailing-cr a/llvm/test/Assembler/auto_upgrade_nvvm_intrinsics.ll b/llvm/test/Assembler/auto_upgrade_nvvm_intrinsics.ll +---- a/llvm/test/Assembler/auto_upgrade_nvvm_intrinsics.ll +-+++ b/llvm/test/Assembler/auto_upgrade_nvvm_intrinsics.ll +-@@ -31,19 +31,6 @@ +- declare i64 @llvm.nvvm.bitcast.d2ll(double) +- declare double @llvm.nvvm.bitcast.ll2d(i64) +- +--declare i32 @llvm.nvvm.rotate.b32(i32, i32) +--declare i64 @llvm.nvvm.rotate.right.b64(i64, i32) +--declare i64 @llvm.nvvm.rotate.b64(i64, i32) +-- +--declare ptr addrspace(1) @llvm.nvvm.ptr.gen.to.global.p1.p0(ptr) +--declare ptr addrspace(3) @llvm.nvvm.ptr.gen.to.shared.p3.p0(ptr) +--declare ptr addrspace(4) @llvm.nvvm.ptr.gen.to.constant.p4.p0(ptr) +--declare ptr addrspace(5) @llvm.nvvm.ptr.gen.to.local.p5.p0(ptr) +--declare ptr @llvm.nvvm.ptr.global.to.gen.p0.p1(ptr addrspace(1)) +--declare ptr @llvm.nvvm.ptr.shared.to.gen.p0.p3(ptr addrspace(3)) +--declare ptr @llvm.nvvm.ptr.constant.to.gen.p0.p4(ptr addrspace(4)) +--declare ptr @llvm.nvvm.ptr.local.to.gen.p0.p5(ptr addrspace(5)) +-- +- ; CHECK-LABEL: @simple_upgrade +- define void @simple_upgrade(i32 %a, i64 %b, i16 %c) { +- ; CHECK: call i32 @llvm.bitreverse.i32(i32 %a) +-@@ -152,42 +139,4 @@ +- %r4 = call double @llvm.nvvm.bitcast.ll2d(i64 %b) +- +- ret void +--} +-- +--; CHECK-LABEL: @rotate +--define void @rotate(i32 %a, i64 %b) { +--; CHECK: call i32 @llvm.fshl.i32(i32 %a, i32 %a, i32 6) +--; CHECK: call i64 @llvm.fshr.i64(i64 %b, i64 %b, i64 7) +--; CHECK: call i64 @llvm.fshl.i64(i64 %b, i64 %b, i64 8) +--; +-- %r1 = call i32 @llvm.nvvm.rotate.b32(i32 %a, i32 6) +-- %r2 = call i64 @llvm.nvvm.rotate.right.b64(i64 %b, i32 7) +-- %r3 = call i64 @llvm.nvvm.rotate.b64(i64 %b, i32 8) +-- ret void +--} +-- +--; CHECK-LABEL: @addrspacecast +--define void @addrspacecast(ptr %p0) { +--; CHECK: %1 = addrspacecast ptr %p0 to ptr addrspace(1) +--; CHECK: %2 = addrspacecast ptr addrspace(1) %1 to ptr +--; CHECK: %3 = addrspacecast ptr %2 to ptr addrspace(3) +--; CHECK: %4 = addrspacecast ptr addrspace(3) %3 to ptr +--; CHECK: %5 = addrspacecast ptr %4 to ptr addrspace(4) +--; CHECK: %6 = addrspacecast ptr addrspace(4) %5 to ptr +--; CHECK: %7 = addrspacecast ptr %6 to ptr addrspace(5) +--; CHECK: %8 = addrspacecast ptr addrspace(5) %7 to ptr +--; +-- %p1 = call ptr addrspace(1) @llvm.nvvm.ptr.gen.to.global.p1.p0(ptr %p0) +-- %p2 = call ptr @llvm.nvvm.ptr.global.to.gen.p0.p1(ptr addrspace(1) %p1) +-- +-- %p3 = call ptr addrspace(3) @llvm.nvvm.ptr.gen.to.shared.p3.p0(ptr %p2) +-- %p4 = call ptr @llvm.nvvm.ptr.shared.to.gen.p0.p3(ptr addrspace(3) %p3) +-- +-- %p5 = call ptr addrspace(4) @llvm.nvvm.ptr.gen.to.constant.p4.p0(ptr %p4) +-- %p6 = call ptr @llvm.nvvm.ptr.constant.to.gen.p0.p4(ptr addrspace(4) %p5) +-- +-- %p7 = call ptr addrspace(5) @llvm.nvvm.ptr.gen.to.local.p5.p0(ptr %p6) +-- %p8 = call ptr @llvm.nvvm.ptr.local.to.gen.p0.p5(ptr addrspace(5) %p7) +-- +-- ret void +--} +-+} +-\ No newline at end of file +-diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/AMDGPU/freeze.ll b/llvm/test/CodeGen/AMDGPU/freeze.ll +---- a/llvm/test/CodeGen/AMDGPU/freeze.ll +-+++ b/llvm/test/CodeGen/AMDGPU/freeze.ll +-@@ -1,1856 +0,0 @@ +--; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py +--; RUN: llc -global-isel=0 -mtriple=amdgcn-mesa-mesa3d -mcpu=gfx1010 < %s | FileCheck -check-prefixes=GFX10,GFX10-SDAG %s +--; RUN: llc -global-isel=1 -mtriple=amdgcn-mesa-mesa3d -mcpu=gfx1010 < %s | FileCheck -check-prefixes=GFX10,GFX10-GISEL %s +--; RUN: llc -global-isel=0 -mtriple=amdgcn-mesa-mesa3d -mcpu=gfx1100 -amdgpu-enable-delay-alu=0 < %s | FileCheck -check-prefixes=GFX11,GFX11-SDAG %s +--; RUN: llc -global-isel=1 -mtriple=amdgcn-mesa-mesa3d -mcpu=gfx1100 -amdgpu-enable-delay-alu=0 < %s | FileCheck -check-prefixes=GFX11,GFX11-GISEL %s +-- +--define void @freeze_v2i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { +--; GFX10-LABEL: freeze_v2i32: +--; GFX10: ; %bb.0: +--; GFX10-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX10-NEXT: global_load_dwordx2 v[0:1], v[0:1], off +--; GFX10-NEXT: s_waitcnt vmcnt(0) +--; GFX10-NEXT: global_store_dwordx2 v[2:3], v[0:1], off +--; GFX10-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX11-LABEL: freeze_v2i32: +--; GFX11: ; %bb.0: +--; GFX11-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX11-NEXT: global_load_b64 v[0:1], v[0:1], off +--; GFX11-NEXT: s_waitcnt vmcnt(0) +--; GFX11-NEXT: global_store_b64 v[2:3], v[0:1], off +--; GFX11-NEXT: s_setpc_b64 s[30:31] +-- %a = load <2 x i32>, ptr addrspace(1) %ptra, align 4 +-- %freeze = freeze <2 x i32> %a +-- store <2 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 +-- ret void +--} +-- +--define void @freeze_v3i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { +--; GFX10-LABEL: freeze_v3i32: +--; GFX10: ; %bb.0: +--; GFX10-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX10-NEXT: global_load_dwordx3 v[4:6], v[0:1], off +--; GFX10-NEXT: s_waitcnt vmcnt(0) +--; GFX10-NEXT: global_store_dwordx3 v[2:3], v[4:6], off +--; GFX10-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX11-LABEL: freeze_v3i32: +--; GFX11: ; %bb.0: +--; GFX11-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX11-NEXT: global_load_b96 v[4:6], v[0:1], off +--; GFX11-NEXT: s_waitcnt vmcnt(0) +--; GFX11-NEXT: global_store_b96 v[2:3], v[4:6], off +--; GFX11-NEXT: s_setpc_b64 s[30:31] +-- %a = load <3 x i32>, ptr addrspace(1) %ptra, align 4 +-- %freeze = freeze <3 x i32> %a +-- store <3 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 +-- ret void +--} +-- +--define void @freeze_v4i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { +--; GFX10-LABEL: freeze_v4i32: +--; GFX10: ; %bb.0: +--; GFX10-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX10-NEXT: global_load_dwordx4 v[4:7], v[0:1], off +--; GFX10-NEXT: s_waitcnt vmcnt(0) +--; GFX10-NEXT: global_store_dwordx4 v[2:3], v[4:7], off +--; GFX10-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX11-LABEL: freeze_v4i32: +--; GFX11: ; %bb.0: +--; GFX11-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX11-NEXT: global_load_b128 v[4:7], v[0:1], off +--; GFX11-NEXT: s_waitcnt vmcnt(0) +--; GFX11-NEXT: global_store_b128 v[2:3], v[4:7], off +--; GFX11-NEXT: s_setpc_b64 s[30:31] +-- %a = load <4 x i32>, ptr addrspace(1) %ptra, align 4 +-- %freeze = freeze <4 x i32> %a +-- store <4 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 +-- ret void +--} +-- +--define void @freeze_v5i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { +--; GFX10-SDAG-LABEL: freeze_v5i32: +--; GFX10-SDAG: ; %bb.0: +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX10-SDAG-NEXT: s_clause 0x1 +--; GFX10-SDAG-NEXT: global_load_dword v8, v[0:1], off offset:16 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) +--; GFX10-SDAG-NEXT: global_store_dword v[2:3], v8, off offset:16 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off +--; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX10-GISEL-LABEL: freeze_v5i32: +--; GFX10-GISEL: ; %bb.0: +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX10-GISEL-NEXT: s_clause 0x1 +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off +--; GFX10-GISEL-NEXT: global_load_dword v8, v[0:1], off offset:16 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) +--; GFX10-GISEL-NEXT: global_store_dword v[2:3], v8, off offset:16 +--; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX11-SDAG-LABEL: freeze_v5i32: +--; GFX11-SDAG: ; %bb.0: +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX11-SDAG-NEXT: s_clause 0x1 +--; GFX11-SDAG-NEXT: global_load_b32 v8, v[0:1], off offset:16 +--; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) +--; GFX11-SDAG-NEXT: global_store_b32 v[2:3], v8, off offset:16 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off +--; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX11-GISEL-LABEL: freeze_v5i32: +--; GFX11-GISEL: ; %bb.0: +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX11-GISEL-NEXT: s_clause 0x1 +--; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off +--; GFX11-GISEL-NEXT: global_load_b32 v0, v[0:1], off offset:16 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) +--; GFX11-GISEL-NEXT: global_store_b32 v[2:3], v0, off offset:16 +--; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] +-- %a = load <5 x i32>, ptr addrspace(1) %ptra, align 4 +-- %freeze = freeze <5 x i32> %a +-- store <5 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 +-- ret void +--} +-- +--define void @freeze_v6i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { +--; GFX10-SDAG-LABEL: freeze_v6i32: +--; GFX10-SDAG: ; %bb.0: +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX10-SDAG-NEXT: s_clause 0x1 +--; GFX10-SDAG-NEXT: global_load_dwordx2 v[8:9], v[0:1], off offset:16 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) +--; GFX10-SDAG-NEXT: global_store_dwordx2 v[2:3], v[8:9], off offset:16 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off +--; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX10-GISEL-LABEL: freeze_v6i32: +--; GFX10-GISEL: ; %bb.0: +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX10-GISEL-NEXT: s_clause 0x1 +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off +--; GFX10-GISEL-NEXT: global_load_dwordx2 v[8:9], v[0:1], off offset:16 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) +--; GFX10-GISEL-NEXT: global_store_dwordx2 v[2:3], v[8:9], off offset:16 +--; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX11-SDAG-LABEL: freeze_v6i32: +--; GFX11-SDAG: ; %bb.0: +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX11-SDAG-NEXT: s_clause 0x1 +--; GFX11-SDAG-NEXT: global_load_b64 v[8:9], v[0:1], off offset:16 +--; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) +--; GFX11-SDAG-NEXT: global_store_b64 v[2:3], v[8:9], off offset:16 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off +--; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX11-GISEL-LABEL: freeze_v6i32: +--; GFX11-GISEL: ; %bb.0: +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX11-GISEL-NEXT: s_clause 0x1 +--; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off +--; GFX11-GISEL-NEXT: global_load_b64 v[0:1], v[0:1], off offset:16 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) +--; GFX11-GISEL-NEXT: global_store_b64 v[2:3], v[0:1], off offset:16 +--; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] +-- %a = load <6 x i32>, ptr addrspace(1) %ptra, align 4 +-- %freeze = freeze <6 x i32> %a +-- store <6 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 +-- ret void +--} +-- +--define void @freeze_v7i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { +--; GFX10-SDAG-LABEL: freeze_v7i32: +--; GFX10-SDAG: ; %bb.0: +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX10-SDAG-NEXT: s_clause 0x1 +--; GFX10-SDAG-NEXT: global_load_dwordx3 v[8:10], v[0:1], off offset:16 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) +--; GFX10-SDAG-NEXT: global_store_dwordx3 v[2:3], v[8:10], off offset:16 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off +--; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX10-GISEL-LABEL: freeze_v7i32: +--; GFX10-GISEL: ; %bb.0: +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX10-GISEL-NEXT: s_clause 0x1 +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off +--; GFX10-GISEL-NEXT: global_load_dwordx3 v[8:10], v[0:1], off offset:16 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) +--; GFX10-GISEL-NEXT: global_store_dwordx3 v[2:3], v[8:10], off offset:16 +--; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX11-SDAG-LABEL: freeze_v7i32: +--; GFX11-SDAG: ; %bb.0: +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX11-SDAG-NEXT: s_clause 0x1 +--; GFX11-SDAG-NEXT: global_load_b96 v[8:10], v[0:1], off offset:16 +--; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) +--; GFX11-SDAG-NEXT: global_store_b96 v[2:3], v[8:10], off offset:16 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off +--; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX11-GISEL-LABEL: freeze_v7i32: +--; GFX11-GISEL: ; %bb.0: +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX11-GISEL-NEXT: s_clause 0x1 +--; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off +--; GFX11-GISEL-NEXT: global_load_b96 v[8:10], v[0:1], off offset:16 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) +--; GFX11-GISEL-NEXT: global_store_b96 v[2:3], v[8:10], off offset:16 +--; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] +-- %a = load <7 x i32>, ptr addrspace(1) %ptra, align 4 +-- %freeze = freeze <7 x i32> %a +-- store <7 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 +-- ret void +--} +-- +--define void @freeze_v8i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { +--; GFX10-SDAG-LABEL: freeze_v8i32: +--; GFX10-SDAG: ; %bb.0: +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX10-SDAG-NEXT: s_clause 0x1 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:16 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:16 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off +--; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX10-GISEL-LABEL: freeze_v8i32: +--; GFX10-GISEL: ; %bb.0: +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX10-GISEL-NEXT: s_clause 0x1 +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 +--; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX11-SDAG-LABEL: freeze_v8i32: +--; GFX11-SDAG: ; %bb.0: +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX11-SDAG-NEXT: s_clause 0x1 +--; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:16 +--; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:16 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off +--; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX11-GISEL-LABEL: freeze_v8i32: +--; GFX11-GISEL: ; %bb.0: +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX11-GISEL-NEXT: s_clause 0x1 +--; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off +--; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 +--; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] +-- %a = load <8 x i32>, ptr addrspace(1) %ptra, align 4 +-- %freeze = freeze <8 x i32> %a +-- store <8 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 +-- ret void +--} +-- +--define void @freeze_v9i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { +--; GFX10-SDAG-LABEL: freeze_v9i32: +--; GFX10-SDAG: ; %bb.0: +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX10-SDAG-NEXT: s_clause 0x2 +--; GFX10-SDAG-NEXT: global_load_dword v12, v[0:1], off offset:32 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) +--; GFX10-SDAG-NEXT: global_store_dword v[2:3], v12, off offset:32 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 +--; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX10-GISEL-LABEL: freeze_v9i32: +--; GFX10-GISEL: ; %bb.0: +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX10-GISEL-NEXT: s_clause 0x2 +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 +--; GFX10-GISEL-NEXT: global_load_dword v12, v[0:1], off offset:32 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) +--; GFX10-GISEL-NEXT: global_store_dword v[2:3], v12, off offset:32 +--; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX11-SDAG-LABEL: freeze_v9i32: +--; GFX11-SDAG: ; %bb.0: +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX11-SDAG-NEXT: s_clause 0x2 +--; GFX11-SDAG-NEXT: global_load_b32 v12, v[0:1], off offset:32 +--; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off +--; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) +--; GFX11-SDAG-NEXT: global_store_b32 v[2:3], v12, off offset:32 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 +--; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX11-GISEL-LABEL: freeze_v9i32: +--; GFX11-GISEL: ; %bb.0: +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX11-GISEL-NEXT: s_clause 0x2 +--; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off +--; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 +--; GFX11-GISEL-NEXT: global_load_b32 v0, v[0:1], off offset:32 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) +--; GFX11-GISEL-NEXT: global_store_b32 v[2:3], v0, off offset:32 +--; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] +-- %a = load <9 x i32>, ptr addrspace(1) %ptra, align 4 +-- %freeze = freeze <9 x i32> %a +-- store <9 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 +-- ret void +--} +-- +--define void @freeze_v10i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { +--; GFX10-LABEL: freeze_v10i32: +--; GFX10: ; %bb.0: +--; GFX10-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX10-NEXT: s_clause 0x2 +--; GFX10-NEXT: global_load_dwordx4 v[4:7], v[0:1], off +--; GFX10-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 +--; GFX10-NEXT: global_load_dwordx2 v[12:13], v[0:1], off offset:32 +--; GFX10-NEXT: s_waitcnt vmcnt(2) +--; GFX10-NEXT: global_store_dwordx4 v[2:3], v[4:7], off +--; GFX10-NEXT: s_waitcnt vmcnt(1) +--; GFX10-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 +--; GFX10-NEXT: s_waitcnt vmcnt(0) +--; GFX10-NEXT: global_store_dwordx2 v[2:3], v[12:13], off offset:32 +--; GFX10-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX11-LABEL: freeze_v10i32: +--; GFX11: ; %bb.0: +--; GFX11-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX11-NEXT: s_clause 0x2 +--; GFX11-NEXT: global_load_b128 v[4:7], v[0:1], off +--; GFX11-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 +--; GFX11-NEXT: global_load_b64 v[0:1], v[0:1], off offset:32 +--; GFX11-NEXT: s_waitcnt vmcnt(2) +--; GFX11-NEXT: global_store_b128 v[2:3], v[4:7], off +--; GFX11-NEXT: s_waitcnt vmcnt(1) +--; GFX11-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 +--; GFX11-NEXT: s_waitcnt vmcnt(0) +--; GFX11-NEXT: global_store_b64 v[2:3], v[0:1], off offset:32 +--; GFX11-NEXT: s_setpc_b64 s[30:31] +-- %a = load <10 x i32>, ptr addrspace(1) %ptra, align 4 +-- %freeze = freeze <10 x i32> %a +-- store <10 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 +-- ret void +--} +-- +--define void @freeze_v11i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { +--; GFX10-LABEL: freeze_v11i32: +--; GFX10: ; %bb.0: +--; GFX10-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX10-NEXT: s_clause 0x2 +--; GFX10-NEXT: global_load_dwordx4 v[4:7], v[0:1], off +--; GFX10-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 +--; GFX10-NEXT: global_load_dwordx3 v[12:14], v[0:1], off offset:32 +--; GFX10-NEXT: s_waitcnt vmcnt(2) +--; GFX10-NEXT: global_store_dwordx4 v[2:3], v[4:7], off +--; GFX10-NEXT: s_waitcnt vmcnt(1) +--; GFX10-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 +--; GFX10-NEXT: s_waitcnt vmcnt(0) +--; GFX10-NEXT: global_store_dwordx3 v[2:3], v[12:14], off offset:32 +--; GFX10-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX11-LABEL: freeze_v11i32: +--; GFX11: ; %bb.0: +--; GFX11-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX11-NEXT: s_clause 0x2 +--; GFX11-NEXT: global_load_b128 v[4:7], v[0:1], off +--; GFX11-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 +--; GFX11-NEXT: global_load_b96 v[12:14], v[0:1], off offset:32 +--; GFX11-NEXT: s_waitcnt vmcnt(2) +--; GFX11-NEXT: global_store_b128 v[2:3], v[4:7], off +--; GFX11-NEXT: s_waitcnt vmcnt(1) +--; GFX11-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 +--; GFX11-NEXT: s_waitcnt vmcnt(0) +--; GFX11-NEXT: global_store_b96 v[2:3], v[12:14], off offset:32 +--; GFX11-NEXT: s_setpc_b64 s[30:31] +-- %a = load <11 x i32>, ptr addrspace(1) %ptra, align 4 +-- %freeze = freeze <11 x i32> %a +-- store <11 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 +-- ret void +--} +-- +--define void @freeze_v12i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { +--; GFX10-LABEL: freeze_v12i32: +--; GFX10: ; %bb.0: +--; GFX10-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX10-NEXT: s_clause 0x2 +--; GFX10-NEXT: global_load_dwordx4 v[4:7], v[0:1], off +--; GFX10-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 +--; GFX10-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 +--; GFX10-NEXT: s_waitcnt vmcnt(2) +--; GFX10-NEXT: global_store_dwordx4 v[2:3], v[4:7], off +--; GFX10-NEXT: s_waitcnt vmcnt(1) +--; GFX10-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 +--; GFX10-NEXT: s_waitcnt vmcnt(0) +--; GFX10-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 +--; GFX10-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX11-LABEL: freeze_v12i32: +--; GFX11: ; %bb.0: +--; GFX11-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX11-NEXT: s_clause 0x2 +--; GFX11-NEXT: global_load_b128 v[4:7], v[0:1], off +--; GFX11-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 +--; GFX11-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 +--; GFX11-NEXT: s_waitcnt vmcnt(2) +--; GFX11-NEXT: global_store_b128 v[2:3], v[4:7], off +--; GFX11-NEXT: s_waitcnt vmcnt(1) +--; GFX11-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 +--; GFX11-NEXT: s_waitcnt vmcnt(0) +--; GFX11-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 +--; GFX11-NEXT: s_setpc_b64 s[30:31] +-- %a = load <12 x i32>, ptr addrspace(1) %ptra, align 4 +-- %freeze = freeze <12 x i32> %a +-- store <12 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 +-- ret void +--} +--define void @freeze_v13i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { +--; GFX10-SDAG-LABEL: freeze_v13i32: +--; GFX10-SDAG: ; %bb.0: +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX10-SDAG-NEXT: s_clause 0x3 +--; GFX10-SDAG-NEXT: global_load_dword v16, v[0:1], off offset:48 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:32 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:16 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) +--; GFX10-SDAG-NEXT: global_store_dword v[2:3], v16, off offset:48 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:32 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:16 +--; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX10-GISEL-LABEL: freeze_v13i32: +--; GFX10-GISEL: ; %bb.0: +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX10-GISEL-NEXT: s_clause 0x3 +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 +--; GFX10-GISEL-NEXT: global_load_dword v16, v[0:1], off offset:48 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) +--; GFX10-GISEL-NEXT: global_store_dword v[2:3], v16, off offset:48 +--; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX11-SDAG-LABEL: freeze_v13i32: +--; GFX11-SDAG: ; %bb.0: +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX11-SDAG-NEXT: s_clause 0x3 +--; GFX11-SDAG-NEXT: global_load_b32 v16, v[0:1], off offset:48 +--; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:32 +--; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off +--; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off offset:16 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) +--; GFX11-SDAG-NEXT: global_store_b32 v[2:3], v16, off offset:48 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:32 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off offset:16 +--; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX11-GISEL-LABEL: freeze_v13i32: +--; GFX11-GISEL: ; %bb.0: +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX11-GISEL-NEXT: s_clause 0x3 +--; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off +--; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 +--; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 +--; GFX11-GISEL-NEXT: global_load_b32 v0, v[0:1], off offset:48 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) +--; GFX11-GISEL-NEXT: global_store_b32 v[2:3], v0, off offset:48 +--; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] +-- %a = load <13 x i32>, ptr addrspace(1) %ptra, align 4 +-- %freeze = freeze <13 x i32> %a +-- store <13 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 +-- ret void +--} +-- +--define void @freeze_v14i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { +--; GFX10-SDAG-LABEL: freeze_v14i32: +--; GFX10-SDAG: ; %bb.0: +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX10-SDAG-NEXT: s_clause 0x3 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:32 +--; GFX10-SDAG-NEXT: global_load_dwordx2 v[16:17], v[0:1], off offset:48 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:16 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:32 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) +--; GFX10-SDAG-NEXT: global_store_dwordx2 v[2:3], v[16:17], off offset:48 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:16 +--; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX10-GISEL-LABEL: freeze_v14i32: +--; GFX10-GISEL: ; %bb.0: +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX10-GISEL-NEXT: s_clause 0x3 +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 +--; GFX10-GISEL-NEXT: global_load_dwordx2 v[16:17], v[0:1], off offset:48 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) +--; GFX10-GISEL-NEXT: global_store_dwordx2 v[2:3], v[16:17], off offset:48 +--; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX11-SDAG-LABEL: freeze_v14i32: +--; GFX11-SDAG: ; %bb.0: +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX11-SDAG-NEXT: s_clause 0x3 +--; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:32 +--; GFX11-SDAG-NEXT: global_load_b64 v[16:17], v[0:1], off offset:48 +--; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off +--; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off offset:16 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:32 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) +--; GFX11-SDAG-NEXT: global_store_b64 v[2:3], v[16:17], off offset:48 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off offset:16 +--; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX11-GISEL-LABEL: freeze_v14i32: +--; GFX11-GISEL: ; %bb.0: +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX11-GISEL-NEXT: s_clause 0x3 +--; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off +--; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 +--; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 +--; GFX11-GISEL-NEXT: global_load_b64 v[0:1], v[0:1], off offset:48 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) +--; GFX11-GISEL-NEXT: global_store_b64 v[2:3], v[0:1], off offset:48 +--; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] +-- %a = load <14 x i32>, ptr addrspace(1) %ptra, align 4 +-- %freeze = freeze <14 x i32> %a +-- store <14 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 +-- ret void +--} +-- +--define void @freeze_v15i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { +--; GFX10-SDAG-LABEL: freeze_v15i32: +--; GFX10-SDAG: ; %bb.0: +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX10-SDAG-NEXT: s_clause 0x3 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:32 +--; GFX10-SDAG-NEXT: global_load_dwordx3 v[16:18], v[0:1], off offset:48 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:16 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:32 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) +--; GFX10-SDAG-NEXT: global_store_dwordx3 v[2:3], v[16:18], off offset:48 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:16 +--; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX10-GISEL-LABEL: freeze_v15i32: +--; GFX10-GISEL: ; %bb.0: +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX10-GISEL-NEXT: s_clause 0x3 +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 +--; GFX10-GISEL-NEXT: global_load_dwordx3 v[16:18], v[0:1], off offset:48 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) +--; GFX10-GISEL-NEXT: global_store_dwordx3 v[2:3], v[16:18], off offset:48 +--; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX11-SDAG-LABEL: freeze_v15i32: +--; GFX11-SDAG: ; %bb.0: +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX11-SDAG-NEXT: s_clause 0x3 +--; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:32 +--; GFX11-SDAG-NEXT: global_load_b96 v[16:18], v[0:1], off offset:48 +--; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off +--; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off offset:16 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:32 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) +--; GFX11-SDAG-NEXT: global_store_b96 v[2:3], v[16:18], off offset:48 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off offset:16 +--; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX11-GISEL-LABEL: freeze_v15i32: +--; GFX11-GISEL: ; %bb.0: +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX11-GISEL-NEXT: s_clause 0x3 +--; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off +--; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 +--; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 +--; GFX11-GISEL-NEXT: global_load_b96 v[16:18], v[0:1], off offset:48 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) +--; GFX11-GISEL-NEXT: global_store_b96 v[2:3], v[16:18], off offset:48 +--; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] +-- %a = load <15 x i32>, ptr addrspace(1) %ptra, align 4 +-- %freeze = freeze <15 x i32> %a +-- store <15 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 +-- ret void +--} +-- +--define void @freeze_v16i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { +--; GFX10-SDAG-LABEL: freeze_v16i32: +--; GFX10-SDAG: ; %bb.0: +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX10-SDAG-NEXT: s_clause 0x3 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:32 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:48 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:16 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:32 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:48 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:16 +--; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX10-GISEL-LABEL: freeze_v16i32: +--; GFX10-GISEL: ; %bb.0: +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX10-GISEL-NEXT: s_clause 0x3 +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:48 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:48 +--; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX11-SDAG-LABEL: freeze_v16i32: +--; GFX11-SDAG: ; %bb.0: +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX11-SDAG-NEXT: s_clause 0x3 +--; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:32 +--; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off offset:48 +--; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off +--; GFX11-SDAG-NEXT: global_load_b128 v[16:19], v[0:1], off offset:16 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:32 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off offset:48 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[16:19], off offset:16 +--; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX11-GISEL-LABEL: freeze_v16i32: +--; GFX11-GISEL: ; %bb.0: +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX11-GISEL-NEXT: s_clause 0x3 +--; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off +--; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 +--; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 +--; GFX11-GISEL-NEXT: global_load_b128 v[16:19], v[0:1], off offset:48 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[16:19], off offset:48 +--; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] +-- %a = load <16 x i32>, ptr addrspace(1) %ptra, align 4 +-- %freeze = freeze <16 x i32> %a +-- store <16 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 +-- ret void +--} +-- +--define void @freeze_v17i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { +--; GFX10-SDAG-LABEL: freeze_v17i32: +--; GFX10-SDAG: ; %bb.0: +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX10-SDAG-NEXT: s_clause 0x4 +--; GFX10-SDAG-NEXT: global_load_dword v20, v[0:1], off offset:64 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:32 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:48 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:16 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(4) +--; GFX10-SDAG-NEXT: global_store_dword v[2:3], v20, off offset:64 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:32 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:48 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:16 +--; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX10-GISEL-LABEL: freeze_v17i32: +--; GFX10-GISEL: ; %bb.0: +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX10-GISEL-NEXT: s_clause 0x4 +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:48 +--; GFX10-GISEL-NEXT: global_load_dword v20, v[0:1], off offset:64 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(4) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:48 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) +--; GFX10-GISEL-NEXT: global_store_dword v[2:3], v20, off offset:64 +--; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX11-SDAG-LABEL: freeze_v17i32: +--; GFX11-SDAG: ; %bb.0: +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX11-SDAG-NEXT: s_clause 0x4 +--; GFX11-SDAG-NEXT: global_load_b32 v20, v[0:1], off offset:64 +--; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:32 +--; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off offset:48 +--; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off +--; GFX11-SDAG-NEXT: global_load_b128 v[16:19], v[0:1], off offset:16 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(4) +--; GFX11-SDAG-NEXT: global_store_b32 v[2:3], v20, off offset:64 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:32 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off offset:48 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[16:19], off offset:16 +--; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX11-GISEL-LABEL: freeze_v17i32: +--; GFX11-GISEL: ; %bb.0: +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX11-GISEL-NEXT: s_clause 0x4 +--; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off +--; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 +--; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 +--; GFX11-GISEL-NEXT: global_load_b128 v[16:19], v[0:1], off offset:48 +--; GFX11-GISEL-NEXT: global_load_b32 v0, v[0:1], off offset:64 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(4) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[16:19], off offset:48 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) +--; GFX11-GISEL-NEXT: global_store_b32 v[2:3], v0, off offset:64 +--; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] +-- %a = load <17 x i32>, ptr addrspace(1) %ptra, align 4 +-- %freeze = freeze <17 x i32> %a +-- store <17 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 +-- ret void +--} +-- +--define void @freeze_v18i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { +--; GFX10-SDAG-LABEL: freeze_v18i32: +--; GFX10-SDAG: ; %bb.0: +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX10-SDAG-NEXT: s_clause 0x4 +--; GFX10-SDAG-NEXT: global_load_dwordx2 v[20:21], v[0:1], off offset:64 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:32 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:48 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:16 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(4) +--; GFX10-SDAG-NEXT: global_store_dwordx2 v[2:3], v[20:21], off offset:64 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:32 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:48 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:16 +--; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX10-GISEL-LABEL: freeze_v18i32: +--; GFX10-GISEL: ; %bb.0: +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX10-GISEL-NEXT: s_clause 0x4 +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:48 +--; GFX10-GISEL-NEXT: global_load_dwordx2 v[20:21], v[0:1], off offset:64 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(4) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:48 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) +--; GFX10-GISEL-NEXT: global_store_dwordx2 v[2:3], v[20:21], off offset:64 +--; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX11-SDAG-LABEL: freeze_v18i32: +--; GFX11-SDAG: ; %bb.0: +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX11-SDAG-NEXT: s_clause 0x4 +--; GFX11-SDAG-NEXT: global_load_b64 v[20:21], v[0:1], off offset:64 +--; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:32 +--; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off offset:48 +--; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off +--; GFX11-SDAG-NEXT: global_load_b128 v[16:19], v[0:1], off offset:16 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(4) +--; GFX11-SDAG-NEXT: global_store_b64 v[2:3], v[20:21], off offset:64 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:32 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off offset:48 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[16:19], off offset:16 +--; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX11-GISEL-LABEL: freeze_v18i32: +--; GFX11-GISEL: ; %bb.0: +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX11-GISEL-NEXT: s_clause 0x4 +--; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off +--; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 +--; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 +--; GFX11-GISEL-NEXT: global_load_b128 v[16:19], v[0:1], off offset:48 +--; GFX11-GISEL-NEXT: global_load_b64 v[0:1], v[0:1], off offset:64 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(4) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[16:19], off offset:48 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) +--; GFX11-GISEL-NEXT: global_store_b64 v[2:3], v[0:1], off offset:64 +--; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] +-- %a = load <18 x i32>, ptr addrspace(1) %ptra, align 4 +-- %freeze = freeze <18 x i32> %a +-- store <18 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 +-- ret void +--} +-- +--define void @freeze_v19i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { +--; GFX10-SDAG-LABEL: freeze_v19i32: +--; GFX10-SDAG: ; %bb.0: +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX10-SDAG-NEXT: s_clause 0x4 +--; GFX10-SDAG-NEXT: global_load_dwordx3 v[20:22], v[0:1], off offset:64 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:32 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:48 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:16 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(4) +--; GFX10-SDAG-NEXT: global_store_dwordx3 v[2:3], v[20:22], off offset:64 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:32 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:48 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:16 +--; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX10-GISEL-LABEL: freeze_v19i32: +--; GFX10-GISEL: ; %bb.0: +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX10-GISEL-NEXT: s_clause 0x4 +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:48 +--; GFX10-GISEL-NEXT: global_load_dwordx3 v[20:22], v[0:1], off offset:64 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(4) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:48 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) +--; GFX10-GISEL-NEXT: global_store_dwordx3 v[2:3], v[20:22], off offset:64 +--; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX11-SDAG-LABEL: freeze_v19i32: +--; GFX11-SDAG: ; %bb.0: +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX11-SDAG-NEXT: s_clause 0x4 +--; GFX11-SDAG-NEXT: global_load_b96 v[20:22], v[0:1], off offset:64 +--; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:32 +--; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off offset:48 +--; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off +--; GFX11-SDAG-NEXT: global_load_b128 v[16:19], v[0:1], off offset:16 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(4) +--; GFX11-SDAG-NEXT: global_store_b96 v[2:3], v[20:22], off offset:64 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:32 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off offset:48 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[16:19], off offset:16 +--; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX11-GISEL-LABEL: freeze_v19i32: +--; GFX11-GISEL: ; %bb.0: +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX11-GISEL-NEXT: s_clause 0x4 +--; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off +--; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 +--; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 +--; GFX11-GISEL-NEXT: global_load_b128 v[16:19], v[0:1], off offset:48 +--; GFX11-GISEL-NEXT: global_load_b96 v[20:22], v[0:1], off offset:64 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(4) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[16:19], off offset:48 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) +--; GFX11-GISEL-NEXT: global_store_b96 v[2:3], v[20:22], off offset:64 +--; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] +-- %a = load <19 x i32>, ptr addrspace(1) %ptra, align 4 +-- %freeze = freeze <19 x i32> %a +-- store <19 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 +-- ret void +--} +-- +--define void @freeze_v20i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { +--; GFX10-SDAG-LABEL: freeze_v20i32: +--; GFX10-SDAG: ; %bb.0: +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX10-SDAG-NEXT: s_clause 0x4 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:64 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:32 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:48 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[16:19], v[0:1], off +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[20:23], v[0:1], off offset:16 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(4) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:64 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:32 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:48 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[16:19], off +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[20:23], off offset:16 +--; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX10-GISEL-LABEL: freeze_v20i32: +--; GFX10-GISEL: ; %bb.0: +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX10-GISEL-NEXT: s_clause 0x4 +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:48 +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[20:23], v[0:1], off offset:64 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(4) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:48 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[20:23], off offset:64 +--; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX11-SDAG-LABEL: freeze_v20i32: +--; GFX11-SDAG: ; %bb.0: +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX11-SDAG-NEXT: s_clause 0x4 +--; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:64 +--; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off offset:32 +--; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off offset:48 +--; GFX11-SDAG-NEXT: global_load_b128 v[16:19], v[0:1], off +--; GFX11-SDAG-NEXT: global_load_b128 v[20:23], v[0:1], off offset:16 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(4) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:64 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off offset:32 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off offset:48 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[16:19], off +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[20:23], off offset:16 +--; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX11-GISEL-LABEL: freeze_v20i32: +--; GFX11-GISEL: ; %bb.0: +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX11-GISEL-NEXT: s_clause 0x4 +--; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off +--; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 +--; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 +--; GFX11-GISEL-NEXT: global_load_b128 v[16:19], v[0:1], off offset:48 +--; GFX11-GISEL-NEXT: global_load_b128 v[20:23], v[0:1], off offset:64 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(4) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[16:19], off offset:48 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[20:23], off offset:64 +--; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] +-- %a = load <20 x i32>, ptr addrspace(1) %ptra, align 4 +-- %freeze = freeze <20 x i32> %a +-- store <20 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 +-- ret void +--} +-- +--define void @freeze_v21i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { +--; GFX10-SDAG-LABEL: freeze_v21i32: +--; GFX10-SDAG: ; %bb.0: +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX10-SDAG-NEXT: s_clause 0x5 +--; GFX10-SDAG-NEXT: global_load_dword v24, v[0:1], off offset:80 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:64 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:32 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:48 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[16:19], v[0:1], off +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[20:23], v[0:1], off offset:16 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(5) +--; GFX10-SDAG-NEXT: global_store_dword v[2:3], v24, off offset:80 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(4) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:64 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:32 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:48 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[16:19], off +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[20:23], off offset:16 +--; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX10-GISEL-LABEL: freeze_v21i32: +--; GFX10-GISEL: ; %bb.0: +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX10-GISEL-NEXT: s_clause 0x5 +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:48 +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[20:23], v[0:1], off offset:64 +--; GFX10-GISEL-NEXT: global_load_dword v24, v[0:1], off offset:80 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(5) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(4) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:48 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[20:23], off offset:64 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) +--; GFX10-GISEL-NEXT: global_store_dword v[2:3], v24, off offset:80 +--; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX11-SDAG-LABEL: freeze_v21i32: +--; GFX11-SDAG: ; %bb.0: +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX11-SDAG-NEXT: s_clause 0x5 +--; GFX11-SDAG-NEXT: global_load_b32 v24, v[0:1], off offset:80 +--; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:64 +--; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off offset:32 +--; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off offset:48 +--; GFX11-SDAG-NEXT: global_load_b128 v[16:19], v[0:1], off +--; GFX11-SDAG-NEXT: global_load_b128 v[20:23], v[0:1], off offset:16 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(5) +--; GFX11-SDAG-NEXT: global_store_b32 v[2:3], v24, off offset:80 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(4) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:64 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off offset:32 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off offset:48 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[16:19], off +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[20:23], off offset:16 +--; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX11-GISEL-LABEL: freeze_v21i32: +--; GFX11-GISEL: ; %bb.0: +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX11-GISEL-NEXT: s_clause 0x5 +--; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off +--; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 +--; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 +--; GFX11-GISEL-NEXT: global_load_b128 v[16:19], v[0:1], off offset:48 +--; GFX11-GISEL-NEXT: global_load_b128 v[20:23], v[0:1], off offset:64 +--; GFX11-GISEL-NEXT: global_load_b32 v0, v[0:1], off offset:80 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(5) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(4) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[16:19], off offset:48 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[20:23], off offset:64 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) +--; GFX11-GISEL-NEXT: global_store_b32 v[2:3], v0, off offset:80 +--; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] +-- %a = load <21 x i32>, ptr addrspace(1) %ptra, align 4 +-- %freeze = freeze <21 x i32> %a +-- store <21 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 +-- ret void +--} +-- +--define void @freeze_v22i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { +--; GFX10-SDAG-LABEL: freeze_v22i32: +--; GFX10-SDAG: ; %bb.0: +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX10-SDAG-NEXT: s_clause 0x5 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:64 +--; GFX10-SDAG-NEXT: global_load_dwordx2 v[24:25], v[0:1], off offset:80 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:32 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:48 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[16:19], v[0:1], off +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[20:23], v[0:1], off offset:16 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(5) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:64 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(4) +--; GFX10-SDAG-NEXT: global_store_dwordx2 v[2:3], v[24:25], off offset:80 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:32 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:48 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[16:19], off +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[20:23], off offset:16 +--; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX10-GISEL-LABEL: freeze_v22i32: +--; GFX10-GISEL: ; %bb.0: +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX10-GISEL-NEXT: s_clause 0x5 +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:48 +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[20:23], v[0:1], off offset:64 +--; GFX10-GISEL-NEXT: global_load_dwordx2 v[24:25], v[0:1], off offset:80 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(5) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(4) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:48 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[20:23], off offset:64 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) +--; GFX10-GISEL-NEXT: global_store_dwordx2 v[2:3], v[24:25], off offset:80 +--; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX11-SDAG-LABEL: freeze_v22i32: +--; GFX11-SDAG: ; %bb.0: +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX11-SDAG-NEXT: s_clause 0x5 +--; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:64 +--; GFX11-SDAG-NEXT: global_load_b64 v[24:25], v[0:1], off offset:80 +--; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off offset:32 +--; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off offset:48 +--; GFX11-SDAG-NEXT: global_load_b128 v[16:19], v[0:1], off +--; GFX11-SDAG-NEXT: global_load_b128 v[20:23], v[0:1], off offset:16 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(5) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:64 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(4) +--; GFX11-SDAG-NEXT: global_store_b64 v[2:3], v[24:25], off offset:80 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off offset:32 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off offset:48 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[16:19], off +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[20:23], off offset:16 +--; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX11-GISEL-LABEL: freeze_v22i32: +--; GFX11-GISEL: ; %bb.0: +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX11-GISEL-NEXT: s_clause 0x5 +--; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off +--; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 +--; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 +--; GFX11-GISEL-NEXT: global_load_b128 v[16:19], v[0:1], off offset:48 +--; GFX11-GISEL-NEXT: global_load_b128 v[20:23], v[0:1], off offset:64 +--; GFX11-GISEL-NEXT: global_load_b64 v[0:1], v[0:1], off offset:80 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(5) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(4) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[16:19], off offset:48 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[20:23], off offset:64 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) +--; GFX11-GISEL-NEXT: global_store_b64 v[2:3], v[0:1], off offset:80 +--; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] +-- %a = load <22 x i32>, ptr addrspace(1) %ptra, align 4 +-- %freeze = freeze <22 x i32> %a +-- store <22 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 +-- ret void +--} +-- +--define void @freeze_v30i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { +--; GFX10-SDAG-LABEL: freeze_v30i32: +--; GFX10-SDAG: ; %bb.0: +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX10-SDAG-NEXT: s_clause 0x7 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:96 +--; GFX10-SDAG-NEXT: global_load_dwordx2 v[32:33], v[0:1], off offset:112 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:64 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:80 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:32 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[20:23], v[0:1], off offset:48 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[24:27], v[0:1], off +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[28:31], v[0:1], off offset:16 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(7) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:96 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(6) +--; GFX10-SDAG-NEXT: global_store_dwordx2 v[2:3], v[32:33], off offset:112 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(5) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:64 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(4) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:80 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:32 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[20:23], off offset:48 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[24:27], off +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[28:31], off offset:16 +--; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX10-GISEL-LABEL: freeze_v30i32: +--; GFX10-GISEL: ; %bb.0: +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX10-GISEL-NEXT: s_clause 0x7 +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:48 +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[20:23], v[0:1], off offset:64 +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[24:27], v[0:1], off offset:80 +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[28:31], v[0:1], off offset:96 +--; GFX10-GISEL-NEXT: global_load_dwordx2 v[32:33], v[0:1], off offset:112 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(7) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(6) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(5) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(4) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:48 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[20:23], off offset:64 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[24:27], off offset:80 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[28:31], off offset:96 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) +--; GFX10-GISEL-NEXT: global_store_dwordx2 v[2:3], v[32:33], off offset:112 +--; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX11-SDAG-LABEL: freeze_v30i32: +--; GFX11-SDAG: ; %bb.0: +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX11-SDAG-NEXT: s_clause 0x7 +--; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:96 +--; GFX11-SDAG-NEXT: global_load_b64 v[32:33], v[0:1], off offset:112 +--; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off offset:64 +--; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off offset:80 +--; GFX11-SDAG-NEXT: global_load_b128 v[16:19], v[0:1], off offset:32 +--; GFX11-SDAG-NEXT: global_load_b128 v[20:23], v[0:1], off offset:48 +--; GFX11-SDAG-NEXT: global_load_b128 v[24:27], v[0:1], off +--; GFX11-SDAG-NEXT: global_load_b128 v[28:31], v[0:1], off offset:16 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(7) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:96 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(6) +--; GFX11-SDAG-NEXT: global_store_b64 v[2:3], v[32:33], off offset:112 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(5) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off offset:64 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(4) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off offset:80 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[16:19], off offset:32 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[20:23], off offset:48 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[24:27], off +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[28:31], off offset:16 +--; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX11-GISEL-LABEL: freeze_v30i32: +--; GFX11-GISEL: ; %bb.0: +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX11-GISEL-NEXT: s_clause 0x7 +--; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off +--; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 +--; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 +--; GFX11-GISEL-NEXT: global_load_b128 v[16:19], v[0:1], off offset:48 +--; GFX11-GISEL-NEXT: global_load_b128 v[20:23], v[0:1], off offset:64 +--; GFX11-GISEL-NEXT: global_load_b128 v[24:27], v[0:1], off offset:80 +--; GFX11-GISEL-NEXT: global_load_b128 v[28:31], v[0:1], off offset:96 +--; GFX11-GISEL-NEXT: global_load_b64 v[0:1], v[0:1], off offset:112 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(7) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(6) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(5) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(4) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[16:19], off offset:48 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[20:23], off offset:64 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[24:27], off offset:80 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[28:31], off offset:96 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) +--; GFX11-GISEL-NEXT: global_store_b64 v[2:3], v[0:1], off offset:112 +--; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] +-- %a = load <30 x i32>, ptr addrspace(1) %ptra, align 4 +-- %freeze = freeze <30 x i32> %a +-- store <30 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 +-- ret void +--} +-- +--define void @freeze_v31i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { +--; GFX10-SDAG-LABEL: freeze_v31i32: +--; GFX10-SDAG: ; %bb.0: +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX10-SDAG-NEXT: s_clause 0x7 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:96 +--; GFX10-SDAG-NEXT: global_load_dwordx3 v[32:34], v[0:1], off offset:112 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:64 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:80 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:32 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[20:23], v[0:1], off offset:48 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[24:27], v[0:1], off +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[28:31], v[0:1], off offset:16 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(7) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:96 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(6) +--; GFX10-SDAG-NEXT: global_store_dwordx3 v[2:3], v[32:34], off offset:112 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(5) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:64 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(4) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:80 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:32 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[20:23], off offset:48 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[24:27], off +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[28:31], off offset:16 +--; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX10-GISEL-LABEL: freeze_v31i32: +--; GFX10-GISEL: ; %bb.0: +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX10-GISEL-NEXT: s_clause 0x7 +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:48 +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[20:23], v[0:1], off offset:64 +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[24:27], v[0:1], off offset:80 +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[28:31], v[0:1], off offset:96 +--; GFX10-GISEL-NEXT: global_load_dwordx3 v[32:34], v[0:1], off offset:112 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(7) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(6) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(5) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(4) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:48 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[20:23], off offset:64 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[24:27], off offset:80 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[28:31], off offset:96 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) +--; GFX10-GISEL-NEXT: global_store_dwordx3 v[2:3], v[32:34], off offset:112 +--; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX11-SDAG-LABEL: freeze_v31i32: +--; GFX11-SDAG: ; %bb.0: +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX11-SDAG-NEXT: s_clause 0x7 +--; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:96 +--; GFX11-SDAG-NEXT: global_load_b96 v[32:34], v[0:1], off offset:112 +--; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off offset:64 +--; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off offset:80 +--; GFX11-SDAG-NEXT: global_load_b128 v[16:19], v[0:1], off offset:32 +--; GFX11-SDAG-NEXT: global_load_b128 v[20:23], v[0:1], off offset:48 +--; GFX11-SDAG-NEXT: global_load_b128 v[24:27], v[0:1], off +--; GFX11-SDAG-NEXT: global_load_b128 v[28:31], v[0:1], off offset:16 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(7) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:96 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(6) +--; GFX11-SDAG-NEXT: global_store_b96 v[2:3], v[32:34], off offset:112 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(5) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off offset:64 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(4) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off offset:80 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[16:19], off offset:32 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[20:23], off offset:48 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[24:27], off +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[28:31], off offset:16 +--; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX11-GISEL-LABEL: freeze_v31i32: +--; GFX11-GISEL: ; %bb.0: +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX11-GISEL-NEXT: s_clause 0x7 +--; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off +--; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 +--; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 +--; GFX11-GISEL-NEXT: global_load_b128 v[16:19], v[0:1], off offset:48 +--; GFX11-GISEL-NEXT: global_load_b128 v[20:23], v[0:1], off offset:64 +--; GFX11-GISEL-NEXT: global_load_b128 v[24:27], v[0:1], off offset:80 +--; GFX11-GISEL-NEXT: global_load_b128 v[28:31], v[0:1], off offset:96 +--; GFX11-GISEL-NEXT: global_load_b96 v[32:34], v[0:1], off offset:112 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(7) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(6) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(5) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(4) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[16:19], off offset:48 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[20:23], off offset:64 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[24:27], off offset:80 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[28:31], off offset:96 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) +--; GFX11-GISEL-NEXT: global_store_b96 v[2:3], v[32:34], off offset:112 +--; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] +-- %a = load <31 x i32>, ptr addrspace(1) %ptra, align 4 +-- %freeze = freeze <31 x i32> %a +-- store <31 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 +-- ret void +--} +-- +--define void @freeze_v32i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { +--; GFX10-SDAG-LABEL: freeze_v32i32: +--; GFX10-SDAG: ; %bb.0: +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX10-SDAG-NEXT: s_clause 0x7 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:96 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:112 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:64 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:80 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[20:23], v[0:1], off offset:32 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[24:27], v[0:1], off offset:48 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[28:31], v[0:1], off +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[32:35], v[0:1], off offset:16 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(7) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:96 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(6) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:112 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(5) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:64 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(4) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:80 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[20:23], off offset:32 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[24:27], off offset:48 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[28:31], off +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[32:35], off offset:16 +--; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX10-GISEL-LABEL: freeze_v32i32: +--; GFX10-GISEL: ; %bb.0: +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX10-GISEL-NEXT: s_clause 0x7 +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:48 +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[20:23], v[0:1], off offset:64 +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[24:27], v[0:1], off offset:80 +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[28:31], v[0:1], off offset:96 +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[32:35], v[0:1], off offset:112 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(7) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(6) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(5) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(4) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:48 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[20:23], off offset:64 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[24:27], off offset:80 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[28:31], off offset:96 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[32:35], off offset:112 +--; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX11-SDAG-LABEL: freeze_v32i32: +--; GFX11-SDAG: ; %bb.0: +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX11-SDAG-NEXT: s_clause 0x7 +--; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:96 +--; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off offset:112 +--; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off offset:64 +--; GFX11-SDAG-NEXT: global_load_b128 v[16:19], v[0:1], off offset:80 +--; GFX11-SDAG-NEXT: global_load_b128 v[20:23], v[0:1], off offset:32 +--; GFX11-SDAG-NEXT: global_load_b128 v[24:27], v[0:1], off offset:48 +--; GFX11-SDAG-NEXT: global_load_b128 v[28:31], v[0:1], off +--; GFX11-SDAG-NEXT: global_load_b128 v[32:35], v[0:1], off offset:16 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(7) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:96 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(6) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off offset:112 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(5) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off offset:64 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(4) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[16:19], off offset:80 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[20:23], off offset:32 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[24:27], off offset:48 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[28:31], off +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[32:35], off offset:16 +--; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX11-GISEL-LABEL: freeze_v32i32: +--; GFX11-GISEL: ; %bb.0: +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX11-GISEL-NEXT: s_clause 0x7 +--; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off +--; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 +--; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 +--; GFX11-GISEL-NEXT: global_load_b128 v[16:19], v[0:1], off offset:48 +--; GFX11-GISEL-NEXT: global_load_b128 v[20:23], v[0:1], off offset:64 +--; GFX11-GISEL-NEXT: global_load_b128 v[24:27], v[0:1], off offset:80 +--; GFX11-GISEL-NEXT: global_load_b128 v[28:31], v[0:1], off offset:96 +--; GFX11-GISEL-NEXT: global_load_b128 v[32:35], v[0:1], off offset:112 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(7) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(6) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(5) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(4) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[16:19], off offset:48 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[20:23], off offset:64 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[24:27], off offset:80 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[28:31], off offset:96 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[32:35], off offset:112 +--; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] +-- %a = load <32 x i32>, ptr addrspace(1) %ptra, align 4 +-- %freeze = freeze <32 x i32> %a +-- store <32 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 +-- ret void +--} +-- +--define void @freeze_i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { +--; GFX10-LABEL: freeze_i32: +--; GFX10: ; %bb.0: +--; GFX10-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX10-NEXT: global_load_dword v0, v[0:1], off +--; GFX10-NEXT: s_waitcnt vmcnt(0) +--; GFX10-NEXT: global_store_dword v[2:3], v0, off +--; GFX10-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX11-LABEL: freeze_i32: +--; GFX11: ; %bb.0: +--; GFX11-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX11-NEXT: global_load_b32 v0, v[0:1], off +--; GFX11-NEXT: s_waitcnt vmcnt(0) +--; GFX11-NEXT: global_store_b32 v[2:3], v0, off +--; GFX11-NEXT: s_setpc_b64 s[30:31] +-- %a = load i32, ptr addrspace(1) %ptra, align 4 +-- %freeze = freeze i32 %a +-- store i32 %freeze, ptr addrspace(1) %ptrb, align 4 +-- ret void +--} +-- +--define void @freeze_i64(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { +--; GFX10-LABEL: freeze_i64: +--; GFX10: ; %bb.0: +--; GFX10-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX10-NEXT: global_load_dwordx2 v[0:1], v[0:1], off +--; GFX10-NEXT: s_waitcnt vmcnt(0) +--; GFX10-NEXT: global_store_dwordx2 v[2:3], v[0:1], off +--; GFX10-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX11-LABEL: freeze_i64: +--; GFX11: ; %bb.0: +--; GFX11-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX11-NEXT: global_load_b64 v[0:1], v[0:1], off +--; GFX11-NEXT: s_waitcnt vmcnt(0) +--; GFX11-NEXT: global_store_b64 v[2:3], v[0:1], off +--; GFX11-NEXT: s_setpc_b64 s[30:31] +-- %a = load i64, ptr addrspace(1) %ptra, align 4 +-- %freeze = freeze i64 %a +-- store i64 %freeze, ptr addrspace(1) %ptrb, align 4 +-- ret void +--} +-- +--define void @freeze_float(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { +--; GFX10-LABEL: freeze_float: +--; GFX10: ; %bb.0: +--; GFX10-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX10-NEXT: global_load_dword v0, v[0:1], off +--; GFX10-NEXT: s_waitcnt vmcnt(0) +--; GFX10-NEXT: global_store_dword v[2:3], v0, off +--; GFX10-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX11-LABEL: freeze_float: +--; GFX11: ; %bb.0: +--; GFX11-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX11-NEXT: global_load_b32 v0, v[0:1], off +--; GFX11-NEXT: s_waitcnt vmcnt(0) +--; GFX11-NEXT: global_store_b32 v[2:3], v0, off +--; GFX11-NEXT: s_setpc_b64 s[30:31] +-- %a = load float, ptr addrspace(1) %ptra, align 4 +-- %freeze = freeze float %a +-- store float %freeze, ptr addrspace(1) %ptrb, align 4 +-- ret void +--} +-- +--define void @freeze_i128(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { +--; GFX10-LABEL: freeze_i128: +--; GFX10: ; %bb.0: +--; GFX10-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX10-NEXT: global_load_dwordx4 v[4:7], v[0:1], off +--; GFX10-NEXT: s_waitcnt vmcnt(0) +--; GFX10-NEXT: global_store_dwordx4 v[2:3], v[4:7], off +--; GFX10-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX11-LABEL: freeze_i128: +--; GFX11: ; %bb.0: +--; GFX11-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX11-NEXT: global_load_b128 v[4:7], v[0:1], off +--; GFX11-NEXT: s_waitcnt vmcnt(0) +--; GFX11-NEXT: global_store_b128 v[2:3], v[4:7], off +--; GFX11-NEXT: s_setpc_b64 s[30:31] +-- %a = load i128, ptr addrspace(1) %ptra, align 4 +-- %freeze = freeze i128 %a +-- store i128 %freeze, ptr addrspace(1) %ptrb, align 4 +-- ret void +--} +-- +--define void @freeze_i256(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { +--; GFX10-SDAG-LABEL: freeze_i256: +--; GFX10-SDAG: ; %bb.0: +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX10-SDAG-NEXT: s_clause 0x1 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:16 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:16 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off +--; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX10-GISEL-LABEL: freeze_i256: +--; GFX10-GISEL: ; %bb.0: +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX10-GISEL-NEXT: s_clause 0x1 +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 +--; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX11-SDAG-LABEL: freeze_i256: +--; GFX11-SDAG: ; %bb.0: +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX11-SDAG-NEXT: s_clause 0x1 +--; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:16 +--; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:16 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off +--; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX11-GISEL-LABEL: freeze_i256: +--; GFX11-GISEL: ; %bb.0: +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX11-GISEL-NEXT: s_clause 0x1 +--; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off +--; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 +--; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] +-- %a = load i256, ptr addrspace(1) %ptra, align 4 +-- %freeze = freeze i256 %a +-- store i256 %freeze, ptr addrspace(1) %ptrb, align 4 +-- ret void +--} +-diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/AMDGPU/GlobalISel/inst-select-unmerge-values.mir b/llvm/test/CodeGen/AMDGPU/GlobalISel/inst-select-unmerge-values.mir +---- a/llvm/test/CodeGen/AMDGPU/GlobalISel/inst-select-unmerge-values.mir +-+++ b/llvm/test/CodeGen/AMDGPU/GlobalISel/inst-select-unmerge-values.mir +-@@ -171,9 +171,11 @@ +- ; GCN-LABEL: name: test_unmerge_values_s_s64_s_s64_s64_s_s192 +- ; GCN: liveins: $sgpr0_sgpr1_sgpr2_sgpr3 +- ; GCN-NEXT: {{ $}} +-- ; GCN-NEXT: [[DEF:%[0-9]+]]:sgpr(s192) = G_IMPLICIT_DEF +-- ; GCN-NEXT: [[UV:%[0-9]+]]:sgpr(s64), [[UV1:%[0-9]+]]:sgpr(s64), [[UV2:%[0-9]+]]:sgpr(s64) = G_UNMERGE_VALUES [[DEF]](s192) +-- ; GCN-NEXT: S_ENDPGM 0, implicit [[UV]](s64), implicit [[UV1]](s64), implicit [[UV2]](s64) +-+ ; GCN-NEXT: [[DEF:%[0-9]+]]:sgpr_192 = IMPLICIT_DEF +-+ ; GCN-NEXT: [[COPY:%[0-9]+]]:sreg_64 = COPY [[DEF]].sub0_sub1 +-+ ; GCN-NEXT: [[COPY1:%[0-9]+]]:sreg_64 = COPY [[DEF]].sub2_sub3 +-+ ; GCN-NEXT: [[COPY2:%[0-9]+]]:sreg_64 = COPY [[DEF]].sub4_sub5 +-+ ; GCN-NEXT: S_ENDPGM 0, implicit [[COPY]], implicit [[COPY1]], implicit [[COPY2]] +- %0:sgpr(s192) = G_IMPLICIT_DEF +- %1:sgpr(s64), %2:sgpr(s64), %3:sgpr(s64) = G_UNMERGE_VALUES %0 +- S_ENDPGM 0, implicit %1, implicit %2, implicit %3 +-@@ -292,11 +294,11 @@ +- ; GCN-NEXT: [[CONCAT_VECTORS:%[0-9]+]]:sgpr_384(<12 x s32>) = G_CONCAT_VECTORS [[COPY]](<3 x s32>), [[COPY1]](<3 x s32>), [[COPY2]](<3 x s32>), [[COPY3]](<3 x s32>) +- ; GCN-NEXT: [[COPY4:%[0-9]+]]:sgpr_96(<3 x s32>) = COPY [[CONCAT_VECTORS]].sub0_sub1_sub2(<12 x s32>) +- ; GCN-NEXT: [[COPY5:%[0-9]+]]:sgpr_96(<3 x s32>) = COPY [[CONCAT_VECTORS]].sub3_sub4_sub5(<12 x s32>) +-- ; GCN-NEXT: [[COPY4:%[0-9]+]]:sgpr_96(<3 x s32>), [[COPY5:%[0-9]+]]:sgpr_96(<3 x s32>), [[UV:%[0-9]+]]:sgpr_96(<3 x s32>), [[UV1:%[0-9]+]]:sgpr_96(<3 x s32>) = G_UNMERGE_VALUES [[CONCAT_VECTORS]](<12 x s32>) +-- ; GCN-NEXT: $sgpr0_sgpr1_sgpr2 = COPY [[COPY4]](<3 x s32>) +-- ; GCN-NEXT: $sgpr4_sgpr5_sgpr6 = COPY [[COPY5]](<3 x s32>) +-- ; GCN-NEXT: $sgpr8_sgpr9_sgpr10 = COPY [[UV]](<3 x s32>) +-- ; GCN-NEXT: $sgpr12_sgpr13_sgpr14 = COPY [[UV1]](<3 x s32>) +-+ ; GCN-NEXT: [[UV:%[0-9]+]]:sgpr_96(<3 x s32>), [[UV1:%[0-9]+]]:sgpr_96(<3 x s32>), [[UV2:%[0-9]+]]:sgpr_96(<3 x s32>), [[UV3:%[0-9]+]]:sgpr_96(<3 x s32>) = G_UNMERGE_VALUES [[CONCAT_VECTORS]](<12 x s32>) +-+ ; GCN-NEXT: $sgpr0_sgpr1_sgpr2 = COPY [[UV]](<3 x s32>) +-+ ; GCN-NEXT: $sgpr4_sgpr5_sgpr6 = COPY [[UV1]](<3 x s32>) +-+ ; GCN-NEXT: $sgpr8_sgpr9_sgpr10 = COPY [[UV2]](<3 x s32>) +-+ ; GCN-NEXT: $sgpr12_sgpr13_sgpr14 = COPY [[UV3]](<3 x s32>) +- %0:sgpr(<3 x s32>) = COPY $sgpr0_sgpr1_sgpr2 +- %1:sgpr(<3 x s32>) = COPY $sgpr4_sgpr5_sgpr6 +- %2:sgpr(<3 x s32>) = COPY $sgpr8_sgpr9_sgpr10 +-diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-freeze.mir b/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-freeze.mir +---- a/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-freeze.mir +-+++ b/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-freeze.mir +-@@ -171,8 +171,12 @@ +- +- ; CHECK-LABEL: name: test_freeze_s448 +- ; CHECK: [[COPY:%[0-9]+]]:_(s512) = COPY $vgpr0_vgpr1_vgpr2_vgpr3_vgpr4_vgpr5_vgpr6_vgpr7_vgpr8_vgpr9_vgpr10_vgpr11_vgpr12_vgpr13_vgpr14_vgpr15 +-- ; CHECK-NEXT: [[FREEZE:%[0-9]+]]:_(s512) = G_FREEZE [[COPY]] +-- ; CHECK-NEXT: $vgpr0_vgpr1_vgpr2_vgpr3_vgpr4_vgpr5_vgpr6_vgpr7_vgpr8_vgpr9_vgpr10_vgpr11_vgpr12_vgpr13_vgpr14_vgpr15 = COPY [[FREEZE]](s512) +-+ ; CHECK-NEXT: [[TRUNC:%[0-9]+]]:_(s448) = G_TRUNC [[COPY]](s512) +-+ ; CHECK-NEXT: [[FREEZE:%[0-9]+]]:_(s448) = G_FREEZE [[TRUNC]] +-+ ; CHECK-NEXT: [[UV:%[0-9]+]]:_(s64), [[UV1:%[0-9]+]]:_(s64), [[UV2:%[0-9]+]]:_(s64), [[UV3:%[0-9]+]]:_(s64), [[UV4:%[0-9]+]]:_(s64), [[UV5:%[0-9]+]]:_(s64), [[UV6:%[0-9]+]]:_(s64) = G_UNMERGE_VALUES [[FREEZE]](s448) +-+ ; CHECK-NEXT: [[DEF:%[0-9]+]]:_(s64) = G_IMPLICIT_DEF +-+ ; CHECK-NEXT: [[MV:%[0-9]+]]:_(s512) = G_MERGE_VALUES [[UV]](s64), [[UV1]](s64), [[UV2]](s64), [[UV3]](s64), [[UV4]](s64), [[UV5]](s64), [[UV6]](s64), [[DEF]](s64) +-+ ; CHECK-NEXT: $vgpr0_vgpr1_vgpr2_vgpr3_vgpr4_vgpr5_vgpr6_vgpr7_vgpr8_vgpr9_vgpr10_vgpr11_vgpr12_vgpr13_vgpr14_vgpr15 = COPY [[MV]](s512) +- %0:_(s512) = COPY $vgpr0_vgpr1_vgpr2_vgpr3_vgpr4_vgpr5_vgpr6_vgpr7_vgpr8_vgpr9_vgpr10_vgpr11_vgpr12_vgpr13_vgpr14_vgpr15 +- %1:_(s448) = G_TRUNC %0 +- %2:_(s448) = G_FREEZE %1 +-@@ -395,12 +399,14 @@ +- bb.0: +- +- ; CHECK-LABEL: name: test_freeze_v33s32 +-- ; CHECK: [[DEF:%[0-9]+]]:_(<32 x s32>) = G_IMPLICIT_DEF +-+ ; CHECK: [[DEF:%[0-9]+]]:_(<16 x s32>) = G_IMPLICIT_DEF +- ; CHECK-NEXT: [[DEF1:%[0-9]+]]:_(s32) = G_IMPLICIT_DEF +-- ; CHECK-NEXT: [[FREEZE:%[0-9]+]]:_(<32 x s32>) = G_FREEZE [[DEF]] +-- ; CHECK-NEXT: [[FREEZE1:%[0-9]+]]:_(s32) = G_FREEZE [[DEF1]] +-- ; CHECK-NEXT: [[UV:%[0-9]+]]:_(s32), [[UV1:%[0-9]+]]:_(s32), [[UV2:%[0-9]+]]:_(s32), [[UV3:%[0-9]+]]:_(s32), [[UV4:%[0-9]+]]:_(s32), [[UV5:%[0-9]+]]:_(s32), [[UV6:%[0-9]+]]:_(s32), [[UV7:%[0-9]+]]:_(s32), [[UV8:%[0-9]+]]:_(s32), [[UV9:%[0-9]+]]:_(s32), [[UV10:%[0-9]+]]:_(s32), [[UV11:%[0-9]+]]:_(s32), [[UV12:%[0-9]+]]:_(s32), [[UV13:%[0-9]+]]:_(s32), [[UV14:%[0-9]+]]:_(s32), [[UV15:%[0-9]+]]:_(s32), [[UV16:%[0-9]+]]:_(s32), [[UV17:%[0-9]+]]:_(s32), [[UV18:%[0-9]+]]:_(s32), [[UV19:%[0-9]+]]:_(s32), [[UV20:%[0-9]+]]:_(s32), [[UV21:%[0-9]+]]:_(s32), [[UV22:%[0-9]+]]:_(s32), [[UV23:%[0-9]+]]:_(s32), [[UV24:%[0-9]+]]:_(s32), [[UV25:%[0-9]+]]:_(s32), [[UV26:%[0-9]+]]:_(s32), [[UV27:%[0-9]+]]:_(s32), [[UV28:%[0-9]+]]:_(s32), [[UV29:%[0-9]+]]:_(s32), [[UV30:%[0-9]+]]:_(s32), [[UV31:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[FREEZE]](<32 x s32>) +-- ; CHECK-NEXT: [[BUILD_VECTOR:%[0-9]+]]:_(<33 x s32>) = G_BUILD_VECTOR [[UV]](s32), [[UV1]](s32), [[UV2]](s32), [[UV3]](s32), [[UV4]](s32), [[UV5]](s32), [[UV6]](s32), [[UV7]](s32), [[UV8]](s32), [[UV9]](s32), [[UV10]](s32), [[UV11]](s32), [[UV12]](s32), [[UV13]](s32), [[UV14]](s32), [[UV15]](s32), [[UV16]](s32), [[UV17]](s32), [[UV18]](s32), [[UV19]](s32), [[UV20]](s32), [[UV21]](s32), [[UV22]](s32), [[UV23]](s32), [[UV24]](s32), [[UV25]](s32), [[UV26]](s32), [[UV27]](s32), [[UV28]](s32), [[UV29]](s32), [[UV30]](s32), [[UV31]](s32), [[FREEZE1]](s32) +-+ ; CHECK-NEXT: [[FREEZE:%[0-9]+]]:_(<16 x s32>) = G_FREEZE [[DEF]] +-+ ; CHECK-NEXT: [[FREEZE1:%[0-9]+]]:_(<16 x s32>) = G_FREEZE [[DEF]] +-+ ; CHECK-NEXT: [[FREEZE2:%[0-9]+]]:_(s32) = G_FREEZE [[DEF1]] +-+ ; CHECK-NEXT: [[UV:%[0-9]+]]:_(s32), [[UV1:%[0-9]+]]:_(s32), [[UV2:%[0-9]+]]:_(s32), [[UV3:%[0-9]+]]:_(s32), [[UV4:%[0-9]+]]:_(s32), [[UV5:%[0-9]+]]:_(s32), [[UV6:%[0-9]+]]:_(s32), [[UV7:%[0-9]+]]:_(s32), [[UV8:%[0-9]+]]:_(s32), [[UV9:%[0-9]+]]:_(s32), [[UV10:%[0-9]+]]:_(s32), [[UV11:%[0-9]+]]:_(s32), [[UV12:%[0-9]+]]:_(s32), [[UV13:%[0-9]+]]:_(s32), [[UV14:%[0-9]+]]:_(s32), [[UV15:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[FREEZE]](<16 x s32>) +-+ ; CHECK-NEXT: [[UV16:%[0-9]+]]:_(s32), [[UV17:%[0-9]+]]:_(s32), [[UV18:%[0-9]+]]:_(s32), [[UV19:%[0-9]+]]:_(s32), [[UV20:%[0-9]+]]:_(s32), [[UV21:%[0-9]+]]:_(s32), [[UV22:%[0-9]+]]:_(s32), [[UV23:%[0-9]+]]:_(s32), [[UV24:%[0-9]+]]:_(s32), [[UV25:%[0-9]+]]:_(s32), [[UV26:%[0-9]+]]:_(s32), [[UV27:%[0-9]+]]:_(s32), [[UV28:%[0-9]+]]:_(s32), [[UV29:%[0-9]+]]:_(s32), [[UV30:%[0-9]+]]:_(s32), [[UV31:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[FREEZE1]](<16 x s32>) +-+ ; CHECK-NEXT: [[BUILD_VECTOR:%[0-9]+]]:_(<33 x s32>) = G_BUILD_VECTOR [[UV]](s32), [[UV1]](s32), [[UV2]](s32), [[UV3]](s32), [[UV4]](s32), [[UV5]](s32), [[UV6]](s32), [[UV7]](s32), [[UV8]](s32), [[UV9]](s32), [[UV10]](s32), [[UV11]](s32), [[UV12]](s32), [[UV13]](s32), [[UV14]](s32), [[UV15]](s32), [[UV16]](s32), [[UV17]](s32), [[UV18]](s32), [[UV19]](s32), [[UV20]](s32), [[UV21]](s32), [[UV22]](s32), [[UV23]](s32), [[UV24]](s32), [[UV25]](s32), [[UV26]](s32), [[UV27]](s32), [[UV28]](s32), [[UV29]](s32), [[UV30]](s32), [[UV31]](s32), [[FREEZE2]](s32) +- ; CHECK-NEXT: S_NOP 0, implicit [[BUILD_VECTOR]](<33 x s32>) +- %0:_(<33 x s32>) = G_IMPLICIT_DEF +- %1:_(<33 x s32>) = G_FREEZE %0 +-@@ -413,10 +419,12 @@ +- bb.0: +- +- ; CHECK-LABEL: name: test_freeze_v64s32 +-- ; CHECK: [[DEF:%[0-9]+]]:_(<32 x s32>) = G_IMPLICIT_DEF +-- ; CHECK-NEXT: [[FREEZE:%[0-9]+]]:_(<32 x s32>) = G_FREEZE [[DEF]] +-- ; CHECK-NEXT: [[FREEZE1:%[0-9]+]]:_(<32 x s32>) = G_FREEZE [[DEF]] +-- ; CHECK-NEXT: [[CONCAT_VECTORS:%[0-9]+]]:_(<64 x s32>) = G_CONCAT_VECTORS [[FREEZE]](<32 x s32>), [[FREEZE1]](<32 x s32>) +-+ ; CHECK: [[DEF:%[0-9]+]]:_(<16 x s32>) = G_IMPLICIT_DEF +-+ ; CHECK-NEXT: [[FREEZE:%[0-9]+]]:_(<16 x s32>) = G_FREEZE [[DEF]] +-+ ; CHECK-NEXT: [[FREEZE1:%[0-9]+]]:_(<16 x s32>) = G_FREEZE [[DEF]] +-+ ; CHECK-NEXT: [[FREEZE2:%[0-9]+]]:_(<16 x s32>) = G_FREEZE [[DEF]] +-+ ; CHECK-NEXT: [[FREEZE3:%[0-9]+]]:_(<16 x s32>) = G_FREEZE [[DEF]] +-+ ; CHECK-NEXT: [[CONCAT_VECTORS:%[0-9]+]]:_(<64 x s32>) = G_CONCAT_VECTORS [[FREEZE]](<16 x s32>), [[FREEZE1]](<16 x s32>), [[FREEZE2]](<16 x s32>), [[FREEZE3]](<16 x s32>) +- ; CHECK-NEXT: S_NOP 0, implicit [[CONCAT_VECTORS]](<64 x s32>) +- %0:_(<64 x s32>) = G_IMPLICIT_DEF +- %1:_(<64 x s32>) = G_FREEZE %0 +-diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-implicit-def.mir b/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-implicit-def.mir +---- a/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-implicit-def.mir +-+++ b/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-implicit-def.mir +-@@ -135,9 +135,8 @@ +- bb.0: +- +- ; CHECK-LABEL: name: test_implicit_def_s448 +-- ; CHECK: [[DEF:%[0-9]+]]:_(s512) = G_IMPLICIT_DEF +-- ; CHECK-NEXT: [[TRUNC:%[0-9]+]]:_(s448) = G_TRUNC [[DEF]](s512) +-- ; CHECK-NEXT: [[EXTRACT:%[0-9]+]]:_(s32) = G_EXTRACT [[TRUNC]](s448), 0 +-+ ; CHECK: [[DEF:%[0-9]+]]:_(s448) = G_IMPLICIT_DEF +-+ ; CHECK-NEXT: [[EXTRACT:%[0-9]+]]:_(s32) = G_EXTRACT [[DEF]](s448), 0 +- ; CHECK-NEXT: $vgpr0 = COPY [[EXTRACT]](s32) +- %0:_(s448) = G_IMPLICIT_DEF +- %1:_(s32) = G_EXTRACT %0, 0 +-@@ -297,6 +296,18 @@ +- ... +- +- --- +-+name: test_implicit_def_v17s32 +-+body: | +-+ bb.0: +-+ +-+ ; CHECK-LABEL: name: test_implicit_def_v17s32 +-+ ; CHECK: [[DEF:%[0-9]+]]:_(<17 x s32>) = G_IMPLICIT_DEF +-+ ; CHECK-NEXT: S_NOP 0, implicit [[DEF]](<17 x s32>) +-+ %0:_(<17 x s32>) = G_IMPLICIT_DEF +-+ S_NOP 0, implicit %0 +-+... +-+ +-+--- +- name: test_implicit_def_v32s32 +- body: | +- bb.0: +-@@ -317,9 +328,9 @@ +- ; CHECK-LABEL: name: test_implicit_def_v33s32 +- ; CHECK: liveins: $vgpr0_vgpr1 +- ; CHECK-NEXT: {{ $}} +-- ; CHECK-NEXT: [[DEF:%[0-9]+]]:_(<32 x s32>) = G_IMPLICIT_DEF +-+ ; CHECK-NEXT: [[DEF:%[0-9]+]]:_(<16 x s32>) = G_IMPLICIT_DEF +- ; CHECK-NEXT: [[DEF1:%[0-9]+]]:_(s32) = G_IMPLICIT_DEF +-- ; CHECK-NEXT: [[UV:%[0-9]+]]:_(s32), [[UV1:%[0-9]+]]:_(s32), [[UV2:%[0-9]+]]:_(s32), [[UV3:%[0-9]+]]:_(s32), [[UV4:%[0-9]+]]:_(s32), [[UV5:%[0-9]+]]:_(s32), [[UV6:%[0-9]+]]:_(s32), [[UV7:%[0-9]+]]:_(s32), [[UV8:%[0-9]+]]:_(s32), [[UV9:%[0-9]+]]:_(s32), [[UV10:%[0-9]+]]:_(s32), [[UV11:%[0-9]+]]:_(s32), [[UV12:%[0-9]+]]:_(s32), [[UV13:%[0-9]+]]:_(s32), [[UV14:%[0-9]+]]:_(s32), [[UV15:%[0-9]+]]:_(s32), [[UV16:%[0-9]+]]:_(s32), [[UV17:%[0-9]+]]:_(s32), [[UV18:%[0-9]+]]:_(s32), [[UV19:%[0-9]+]]:_(s32), [[UV20:%[0-9]+]]:_(s32), [[UV21:%[0-9]+]]:_(s32), [[UV22:%[0-9]+]]:_(s32), [[UV23:%[0-9]+]]:_(s32), [[UV24:%[0-9]+]]:_(s32), [[UV25:%[0-9]+]]:_(s32), [[UV26:%[0-9]+]]:_(s32), [[UV27:%[0-9]+]]:_(s32), [[UV28:%[0-9]+]]:_(s32), [[UV29:%[0-9]+]]:_(s32), [[UV30:%[0-9]+]]:_(s32), [[UV31:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<32 x s32>) +-+ ; CHECK-NEXT: [[UV:%[0-9]+]]:_(s32), [[UV1:%[0-9]+]]:_(s32), [[UV2:%[0-9]+]]:_(s32), [[UV3:%[0-9]+]]:_(s32), [[UV4:%[0-9]+]]:_(s32), [[UV5:%[0-9]+]]:_(s32), [[UV6:%[0-9]+]]:_(s32), [[UV7:%[0-9]+]]:_(s32), [[UV8:%[0-9]+]]:_(s32), [[UV9:%[0-9]+]]:_(s32), [[UV10:%[0-9]+]]:_(s32), [[UV11:%[0-9]+]]:_(s32), [[UV12:%[0-9]+]]:_(s32), [[UV13:%[0-9]+]]:_(s32), [[UV14:%[0-9]+]]:_(s32), [[UV15:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) +- ; CHECK-NEXT: [[COPY:%[0-9]+]]:_(p1) = COPY $vgpr0_vgpr1 +- ; CHECK-NEXT: G_STORE [[UV]](s32), [[COPY]](p1) :: (volatile store (s32), addrspace 1) +- ; CHECK-NEXT: G_STORE [[DEF1]](s32), [[COPY]](p1) :: (volatile store (s32), addrspace 1) +-@@ -337,9 +348,10 @@ +- bb.0: +- +- ; CHECK-LABEL: name: test_implicit_def_v64s32 +-- ; CHECK: [[DEF:%[0-9]+]]:_(<32 x s32>) = G_IMPLICIT_DEF +-- ; CHECK-NEXT: [[CONCAT_VECTORS:%[0-9]+]]:_(<64 x s32>) = G_CONCAT_VECTORS [[DEF]](<32 x s32>), [[DEF]](<32 x s32>) +-- ; CHECK-NEXT: S_NOP 0, implicit [[CONCAT_VECTORS]](<64 x s32>), implicit [[DEF]](<32 x s32>) +-+ ; CHECK: [[DEF:%[0-9]+]]:_(<16 x s32>) = G_IMPLICIT_DEF +-+ ; CHECK-NEXT: [[CONCAT_VECTORS:%[0-9]+]]:_(<64 x s32>) = G_CONCAT_VECTORS [[DEF]](<16 x s32>), [[DEF]](<16 x s32>), [[DEF]](<16 x s32>), [[DEF]](<16 x s32>) +-+ ; CHECK-NEXT: [[CONCAT_VECTORS1:%[0-9]+]]:_(<32 x s32>) = G_CONCAT_VECTORS [[DEF]](<16 x s32>), [[DEF]](<16 x s32>) +-+ ; CHECK-NEXT: S_NOP 0, implicit [[CONCAT_VECTORS]](<64 x s32>), implicit [[CONCAT_VECTORS1]](<32 x s32>) +- %0:_(<64 x s32>) = G_IMPLICIT_DEF +- %1:_(<32 x s32>), %2:_(<32 x s32>) = G_UNMERGE_VALUES %0 +- S_NOP 0, implicit %0, implicit %1 +-diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-insert-vector-elt.mir b/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-insert-vector-elt.mir +---- a/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-insert-vector-elt.mir +-+++ b/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-insert-vector-elt.mir +-@@ -190,11 +190,13 @@ +- ; CHECK-LABEL: name: insert_vector_elt_64_65_v64s32 +- ; CHECK: liveins: $sgpr0_sgpr1, $vgpr0_vgpr1, $vgpr2_vgpr3 +- ; CHECK-NEXT: {{ $}} +-- ; CHECK-NEXT: [[DEF:%[0-9]+]]:_(<32 x s32>) = G_IMPLICIT_DEF +-+ ; CHECK-NEXT: [[DEF:%[0-9]+]]:_(<16 x s32>) = G_IMPLICIT_DEF +- ; CHECK-NEXT: [[COPY:%[0-9]+]]:_(p1) = COPY $vgpr0_vgpr1 +- ; CHECK-NEXT: [[COPY1:%[0-9]+]]:_(p1) = COPY $vgpr2_vgpr3 +-- ; CHECK-NEXT: [[UV:%[0-9]+]]:_(<4 x s32>), [[UV1:%[0-9]+]]:_(<4 x s32>), [[UV2:%[0-9]+]]:_(<4 x s32>), [[UV3:%[0-9]+]]:_(<4 x s32>), [[UV4:%[0-9]+]]:_(<4 x s32>), [[UV5:%[0-9]+]]:_(<4 x s32>), [[UV6:%[0-9]+]]:_(<4 x s32>), [[UV7:%[0-9]+]]:_(<4 x s32>) = G_UNMERGE_VALUES [[DEF]](<32 x s32>) +-- ; CHECK-NEXT: [[UV8:%[0-9]+]]:_(<4 x s32>), [[UV9:%[0-9]+]]:_(<4 x s32>), [[UV10:%[0-9]+]]:_(<4 x s32>), [[UV11:%[0-9]+]]:_(<4 x s32>), [[UV12:%[0-9]+]]:_(<4 x s32>), [[UV13:%[0-9]+]]:_(<4 x s32>), [[UV14:%[0-9]+]]:_(<4 x s32>), [[UV15:%[0-9]+]]:_(<4 x s32>) = G_UNMERGE_VALUES [[DEF]](<32 x s32>) +-+ ; CHECK-NEXT: [[UV:%[0-9]+]]:_(<4 x s32>), [[UV1:%[0-9]+]]:_(<4 x s32>), [[UV2:%[0-9]+]]:_(<4 x s32>), [[UV3:%[0-9]+]]:_(<4 x s32>) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) +-+ ; CHECK-NEXT: [[UV4:%[0-9]+]]:_(<4 x s32>), [[UV5:%[0-9]+]]:_(<4 x s32>), [[UV6:%[0-9]+]]:_(<4 x s32>), [[UV7:%[0-9]+]]:_(<4 x s32>) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) +-+ ; CHECK-NEXT: [[UV8:%[0-9]+]]:_(<4 x s32>), [[UV9:%[0-9]+]]:_(<4 x s32>), [[UV10:%[0-9]+]]:_(<4 x s32>), [[UV11:%[0-9]+]]:_(<4 x s32>) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) +-+ ; CHECK-NEXT: [[UV12:%[0-9]+]]:_(<4 x s32>), [[UV13:%[0-9]+]]:_(<4 x s32>), [[UV14:%[0-9]+]]:_(<4 x s32>), [[UV15:%[0-9]+]]:_(<4 x s32>) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) +- ; CHECK-NEXT: G_STORE [[UV]](<4 x s32>), [[COPY]](p1) :: (store (<4 x s32>), align 4, addrspace 1) +- ; CHECK-NEXT: [[C:%[0-9]+]]:_(s64) = G_CONSTANT i64 16 +- ; CHECK-NEXT: [[PTR_ADD:%[0-9]+]]:_(p1) = G_PTR_ADD [[COPY]], [[C]](s64) +-@@ -241,8 +243,10 @@ +- ; CHECK-NEXT: [[C14:%[0-9]+]]:_(s64) = G_CONSTANT i64 240 +- ; CHECK-NEXT: [[PTR_ADD14:%[0-9]+]]:_(p1) = G_PTR_ADD [[COPY]], [[C14]](s64) +- ; CHECK-NEXT: G_STORE [[UV15]](<4 x s32>), [[PTR_ADD14]](p1) :: (store (<4 x s32>) into unknown-address + 240, align 4, addrspace 1) +-- ; CHECK-NEXT: [[UV16:%[0-9]+]]:_(<4 x s32>), [[UV17:%[0-9]+]]:_(<4 x s32>), [[UV18:%[0-9]+]]:_(<4 x s32>), [[UV19:%[0-9]+]]:_(<4 x s32>), [[UV20:%[0-9]+]]:_(<4 x s32>), [[UV21:%[0-9]+]]:_(<4 x s32>), [[UV22:%[0-9]+]]:_(<4 x s32>), [[UV23:%[0-9]+]]:_(<4 x s32>) = G_UNMERGE_VALUES [[DEF]](<32 x s32>) +-- ; CHECK-NEXT: [[UV24:%[0-9]+]]:_(<4 x s32>), [[UV25:%[0-9]+]]:_(<4 x s32>), [[UV26:%[0-9]+]]:_(<4 x s32>), [[UV27:%[0-9]+]]:_(<4 x s32>), [[UV28:%[0-9]+]]:_(<4 x s32>), [[UV29:%[0-9]+]]:_(<4 x s32>), [[UV30:%[0-9]+]]:_(<4 x s32>), [[UV31:%[0-9]+]]:_(<4 x s32>) = G_UNMERGE_VALUES [[DEF]](<32 x s32>) +-+ ; CHECK-NEXT: [[UV16:%[0-9]+]]:_(<4 x s32>), [[UV17:%[0-9]+]]:_(<4 x s32>), [[UV18:%[0-9]+]]:_(<4 x s32>), [[UV19:%[0-9]+]]:_(<4 x s32>) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) +-+ ; CHECK-NEXT: [[UV20:%[0-9]+]]:_(<4 x s32>), [[UV21:%[0-9]+]]:_(<4 x s32>), [[UV22:%[0-9]+]]:_(<4 x s32>), [[UV23:%[0-9]+]]:_(<4 x s32>) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) +-+ ; CHECK-NEXT: [[UV24:%[0-9]+]]:_(<4 x s32>), [[UV25:%[0-9]+]]:_(<4 x s32>), [[UV26:%[0-9]+]]:_(<4 x s32>), [[UV27:%[0-9]+]]:_(<4 x s32>) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) +-+ ; CHECK-NEXT: [[UV28:%[0-9]+]]:_(<4 x s32>), [[UV29:%[0-9]+]]:_(<4 x s32>), [[UV30:%[0-9]+]]:_(<4 x s32>), [[UV31:%[0-9]+]]:_(<4 x s32>) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) +- ; CHECK-NEXT: G_STORE [[UV16]](<4 x s32>), [[COPY1]](p1) :: (store (<4 x s32>), align 4, addrspace 1) +- ; CHECK-NEXT: [[PTR_ADD15:%[0-9]+]]:_(p1) = G_PTR_ADD [[COPY1]], [[C]](s64) +- ; CHECK-NEXT: G_STORE [[UV17]](<4 x s32>), [[PTR_ADD15]](p1) :: (store (<4 x s32>) into unknown-address + 16, align 4, addrspace 1) +-diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-phi.mir b/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-phi.mir +---- a/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-phi.mir +-+++ b/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-phi.mir +-@@ -673,86 +673,88 @@ +- ; CHECK-NEXT: successors: %bb.1(0x40000000), %bb.2(0x40000000) +- ; CHECK-NEXT: liveins: $vgpr0_vgpr1_vgpr2_vgpr3, $vgpr4 +- ; CHECK-NEXT: {{ $}} +-- ; CHECK-NEXT: [[DEF:%[0-9]+]]:_(<32 x s32>) = G_IMPLICIT_DEF +-+ ; CHECK-NEXT: [[DEF:%[0-9]+]]:_(<16 x s32>) = G_IMPLICIT_DEF +- ; CHECK-NEXT: [[COPY:%[0-9]+]]:_(s32) = COPY $vgpr4 +- ; CHECK-NEXT: [[C:%[0-9]+]]:_(s32) = G_CONSTANT i32 0 +- ; CHECK-NEXT: [[ICMP:%[0-9]+]]:_(s1) = G_ICMP intpred(eq), [[COPY]](s32), [[C]] +-- ; CHECK-NEXT: [[UV:%[0-9]+]]:_(<16 x s32>), [[UV1:%[0-9]+]]:_(<16 x s32>) = G_UNMERGE_VALUES [[DEF]](<32 x s32>) +-- ; CHECK-NEXT: [[UV2:%[0-9]+]]:_(<16 x s32>), [[UV3:%[0-9]+]]:_(<16 x s32>) = G_UNMERGE_VALUES [[DEF]](<32 x s32>) +- ; CHECK-NEXT: G_BRCOND [[ICMP]](s1), %bb.1 +- ; CHECK-NEXT: G_BR %bb.2 +- ; CHECK-NEXT: {{ $}} +- ; CHECK-NEXT: bb.1: +- ; CHECK-NEXT: successors: %bb.2(0x80000000) +- ; CHECK-NEXT: {{ $}} +-- ; CHECK-NEXT: [[UV4:%[0-9]+]]:_(s32), [[UV5:%[0-9]+]]:_(s32), [[UV6:%[0-9]+]]:_(s32), [[UV7:%[0-9]+]]:_(s32), [[UV8:%[0-9]+]]:_(s32), [[UV9:%[0-9]+]]:_(s32), [[UV10:%[0-9]+]]:_(s32), [[UV11:%[0-9]+]]:_(s32), [[UV12:%[0-9]+]]:_(s32), [[UV13:%[0-9]+]]:_(s32), [[UV14:%[0-9]+]]:_(s32), [[UV15:%[0-9]+]]:_(s32), [[UV16:%[0-9]+]]:_(s32), [[UV17:%[0-9]+]]:_(s32), [[UV18:%[0-9]+]]:_(s32), [[UV19:%[0-9]+]]:_(s32), [[UV20:%[0-9]+]]:_(s32), [[UV21:%[0-9]+]]:_(s32), [[UV22:%[0-9]+]]:_(s32), [[UV23:%[0-9]+]]:_(s32), [[UV24:%[0-9]+]]:_(s32), [[UV25:%[0-9]+]]:_(s32), [[UV26:%[0-9]+]]:_(s32), [[UV27:%[0-9]+]]:_(s32), [[UV28:%[0-9]+]]:_(s32), [[UV29:%[0-9]+]]:_(s32), [[UV30:%[0-9]+]]:_(s32), [[UV31:%[0-9]+]]:_(s32), [[UV32:%[0-9]+]]:_(s32), [[UV33:%[0-9]+]]:_(s32), [[UV34:%[0-9]+]]:_(s32), [[UV35:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<32 x s32>) +-- ; CHECK-NEXT: [[UV36:%[0-9]+]]:_(s32), [[UV37:%[0-9]+]]:_(s32), [[UV38:%[0-9]+]]:_(s32), [[UV39:%[0-9]+]]:_(s32), [[UV40:%[0-9]+]]:_(s32), [[UV41:%[0-9]+]]:_(s32), [[UV42:%[0-9]+]]:_(s32), [[UV43:%[0-9]+]]:_(s32), [[UV44:%[0-9]+]]:_(s32), [[UV45:%[0-9]+]]:_(s32), [[UV46:%[0-9]+]]:_(s32), [[UV47:%[0-9]+]]:_(s32), [[UV48:%[0-9]+]]:_(s32), [[UV49:%[0-9]+]]:_(s32), [[UV50:%[0-9]+]]:_(s32), [[UV51:%[0-9]+]]:_(s32), [[UV52:%[0-9]+]]:_(s32), [[UV53:%[0-9]+]]:_(s32), [[UV54:%[0-9]+]]:_(s32), [[UV55:%[0-9]+]]:_(s32), [[UV56:%[0-9]+]]:_(s32), [[UV57:%[0-9]+]]:_(s32), [[UV58:%[0-9]+]]:_(s32), [[UV59:%[0-9]+]]:_(s32), [[UV60:%[0-9]+]]:_(s32), [[UV61:%[0-9]+]]:_(s32), [[UV62:%[0-9]+]]:_(s32), [[UV63:%[0-9]+]]:_(s32), [[UV64:%[0-9]+]]:_(s32), [[UV65:%[0-9]+]]:_(s32), [[UV66:%[0-9]+]]:_(s32), [[UV67:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<32 x s32>) +-- ; CHECK-NEXT: [[UV68:%[0-9]+]]:_(s32), [[UV69:%[0-9]+]]:_(s32), [[UV70:%[0-9]+]]:_(s32), [[UV71:%[0-9]+]]:_(s32), [[UV72:%[0-9]+]]:_(s32), [[UV73:%[0-9]+]]:_(s32), [[UV74:%[0-9]+]]:_(s32), [[UV75:%[0-9]+]]:_(s32), [[UV76:%[0-9]+]]:_(s32), [[UV77:%[0-9]+]]:_(s32), [[UV78:%[0-9]+]]:_(s32), [[UV79:%[0-9]+]]:_(s32), [[UV80:%[0-9]+]]:_(s32), [[UV81:%[0-9]+]]:_(s32), [[UV82:%[0-9]+]]:_(s32), [[UV83:%[0-9]+]]:_(s32), [[UV84:%[0-9]+]]:_(s32), [[UV85:%[0-9]+]]:_(s32), [[UV86:%[0-9]+]]:_(s32), [[UV87:%[0-9]+]]:_(s32), [[UV88:%[0-9]+]]:_(s32), [[UV89:%[0-9]+]]:_(s32), [[UV90:%[0-9]+]]:_(s32), [[UV91:%[0-9]+]]:_(s32), [[UV92:%[0-9]+]]:_(s32), [[UV93:%[0-9]+]]:_(s32), [[UV94:%[0-9]+]]:_(s32), [[UV95:%[0-9]+]]:_(s32), [[UV96:%[0-9]+]]:_(s32), [[UV97:%[0-9]+]]:_(s32), [[UV98:%[0-9]+]]:_(s32), [[UV99:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<32 x s32>) +-- ; CHECK-NEXT: [[UV100:%[0-9]+]]:_(s32), [[UV101:%[0-9]+]]:_(s32), [[UV102:%[0-9]+]]:_(s32), [[UV103:%[0-9]+]]:_(s32), [[UV104:%[0-9]+]]:_(s32), [[UV105:%[0-9]+]]:_(s32), [[UV106:%[0-9]+]]:_(s32), [[UV107:%[0-9]+]]:_(s32), [[UV108:%[0-9]+]]:_(s32), [[UV109:%[0-9]+]]:_(s32), [[UV110:%[0-9]+]]:_(s32), [[UV111:%[0-9]+]]:_(s32), [[UV112:%[0-9]+]]:_(s32), [[UV113:%[0-9]+]]:_(s32), [[UV114:%[0-9]+]]:_(s32), [[UV115:%[0-9]+]]:_(s32), [[UV116:%[0-9]+]]:_(s32), [[UV117:%[0-9]+]]:_(s32), [[UV118:%[0-9]+]]:_(s32), [[UV119:%[0-9]+]]:_(s32), [[UV120:%[0-9]+]]:_(s32), [[UV121:%[0-9]+]]:_(s32), [[UV122:%[0-9]+]]:_(s32), [[UV123:%[0-9]+]]:_(s32), [[UV124:%[0-9]+]]:_(s32), [[UV125:%[0-9]+]]:_(s32), [[UV126:%[0-9]+]]:_(s32), [[UV127:%[0-9]+]]:_(s32), [[UV128:%[0-9]+]]:_(s32), [[UV129:%[0-9]+]]:_(s32), [[UV130:%[0-9]+]]:_(s32), [[UV131:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<32 x s32>) +-- ; CHECK-NEXT: [[ADD:%[0-9]+]]:_(s32) = G_ADD [[UV4]], [[UV68]] +-- ; CHECK-NEXT: [[ADD1:%[0-9]+]]:_(s32) = G_ADD [[UV5]], [[UV69]] +-- ; CHECK-NEXT: [[ADD2:%[0-9]+]]:_(s32) = G_ADD [[UV6]], [[UV70]] +-- ; CHECK-NEXT: [[ADD3:%[0-9]+]]:_(s32) = G_ADD [[UV7]], [[UV71]] +-- ; CHECK-NEXT: [[ADD4:%[0-9]+]]:_(s32) = G_ADD [[UV8]], [[UV72]] +-- ; CHECK-NEXT: [[ADD5:%[0-9]+]]:_(s32) = G_ADD [[UV9]], [[UV73]] +-- ; CHECK-NEXT: [[ADD6:%[0-9]+]]:_(s32) = G_ADD [[UV10]], [[UV74]] +-- ; CHECK-NEXT: [[ADD7:%[0-9]+]]:_(s32) = G_ADD [[UV11]], [[UV75]] +-- ; CHECK-NEXT: [[ADD8:%[0-9]+]]:_(s32) = G_ADD [[UV12]], [[UV76]] +-- ; CHECK-NEXT: [[ADD9:%[0-9]+]]:_(s32) = G_ADD [[UV13]], [[UV77]] +-- ; CHECK-NEXT: [[ADD10:%[0-9]+]]:_(s32) = G_ADD [[UV14]], [[UV78]] +-- ; CHECK-NEXT: [[ADD11:%[0-9]+]]:_(s32) = G_ADD [[UV15]], [[UV79]] +-- ; CHECK-NEXT: [[ADD12:%[0-9]+]]:_(s32) = G_ADD [[UV16]], [[UV80]] +-- ; CHECK-NEXT: [[ADD13:%[0-9]+]]:_(s32) = G_ADD [[UV17]], [[UV81]] +-- ; CHECK-NEXT: [[ADD14:%[0-9]+]]:_(s32) = G_ADD [[UV18]], [[UV82]] +-- ; CHECK-NEXT: [[ADD15:%[0-9]+]]:_(s32) = G_ADD [[UV19]], [[UV83]] +-- ; CHECK-NEXT: [[ADD16:%[0-9]+]]:_(s32) = G_ADD [[UV20]], [[UV84]] +-- ; CHECK-NEXT: [[ADD17:%[0-9]+]]:_(s32) = G_ADD [[UV21]], [[UV85]] +-- ; CHECK-NEXT: [[ADD18:%[0-9]+]]:_(s32) = G_ADD [[UV22]], [[UV86]] +-- ; CHECK-NEXT: [[ADD19:%[0-9]+]]:_(s32) = G_ADD [[UV23]], [[UV87]] +-- ; CHECK-NEXT: [[ADD20:%[0-9]+]]:_(s32) = G_ADD [[UV24]], [[UV88]] +-- ; CHECK-NEXT: [[ADD21:%[0-9]+]]:_(s32) = G_ADD [[UV25]], [[UV89]] +-- ; CHECK-NEXT: [[ADD22:%[0-9]+]]:_(s32) = G_ADD [[UV26]], [[UV90]] +-- ; CHECK-NEXT: [[ADD23:%[0-9]+]]:_(s32) = G_ADD [[UV27]], [[UV91]] +-- ; CHECK-NEXT: [[ADD24:%[0-9]+]]:_(s32) = G_ADD [[UV28]], [[UV92]] +-- ; CHECK-NEXT: [[ADD25:%[0-9]+]]:_(s32) = G_ADD [[UV29]], [[UV93]] +-- ; CHECK-NEXT: [[ADD26:%[0-9]+]]:_(s32) = G_ADD [[UV30]], [[UV94]] +-- ; CHECK-NEXT: [[ADD27:%[0-9]+]]:_(s32) = G_ADD [[UV31]], [[UV95]] +-- ; CHECK-NEXT: [[ADD28:%[0-9]+]]:_(s32) = G_ADD [[UV32]], [[UV96]] +-- ; CHECK-NEXT: [[ADD29:%[0-9]+]]:_(s32) = G_ADD [[UV33]], [[UV97]] +-- ; CHECK-NEXT: [[ADD30:%[0-9]+]]:_(s32) = G_ADD [[UV34]], [[UV98]] +-- ; CHECK-NEXT: [[ADD31:%[0-9]+]]:_(s32) = G_ADD [[UV35]], [[UV99]] +-- ; CHECK-NEXT: [[ADD32:%[0-9]+]]:_(s32) = G_ADD [[UV36]], [[UV100]] +-- ; CHECK-NEXT: [[ADD33:%[0-9]+]]:_(s32) = G_ADD [[UV37]], [[UV101]] +-- ; CHECK-NEXT: [[ADD34:%[0-9]+]]:_(s32) = G_ADD [[UV38]], [[UV102]] +-- ; CHECK-NEXT: [[ADD35:%[0-9]+]]:_(s32) = G_ADD [[UV39]], [[UV103]] +-- ; CHECK-NEXT: [[ADD36:%[0-9]+]]:_(s32) = G_ADD [[UV40]], [[UV104]] +-- ; CHECK-NEXT: [[ADD37:%[0-9]+]]:_(s32) = G_ADD [[UV41]], [[UV105]] +-- ; CHECK-NEXT: [[ADD38:%[0-9]+]]:_(s32) = G_ADD [[UV42]], [[UV106]] +-- ; CHECK-NEXT: [[ADD39:%[0-9]+]]:_(s32) = G_ADD [[UV43]], [[UV107]] +-- ; CHECK-NEXT: [[ADD40:%[0-9]+]]:_(s32) = G_ADD [[UV44]], [[UV108]] +-- ; CHECK-NEXT: [[ADD41:%[0-9]+]]:_(s32) = G_ADD [[UV45]], [[UV109]] +-- ; CHECK-NEXT: [[ADD42:%[0-9]+]]:_(s32) = G_ADD [[UV46]], [[UV110]] +-- ; CHECK-NEXT: [[ADD43:%[0-9]+]]:_(s32) = G_ADD [[UV47]], [[UV111]] +-- ; CHECK-NEXT: [[ADD44:%[0-9]+]]:_(s32) = G_ADD [[UV48]], [[UV112]] +-- ; CHECK-NEXT: [[ADD45:%[0-9]+]]:_(s32) = G_ADD [[UV49]], [[UV113]] +-- ; CHECK-NEXT: [[ADD46:%[0-9]+]]:_(s32) = G_ADD [[UV50]], [[UV114]] +-- ; CHECK-NEXT: [[ADD47:%[0-9]+]]:_(s32) = G_ADD [[UV51]], [[UV115]] +-- ; CHECK-NEXT: [[ADD48:%[0-9]+]]:_(s32) = G_ADD [[UV52]], [[UV116]] +-- ; CHECK-NEXT: [[ADD49:%[0-9]+]]:_(s32) = G_ADD [[UV53]], [[UV117]] +-- ; CHECK-NEXT: [[ADD50:%[0-9]+]]:_(s32) = G_ADD [[UV54]], [[UV118]] +-- ; CHECK-NEXT: [[ADD51:%[0-9]+]]:_(s32) = G_ADD [[UV55]], [[UV119]] +-- ; CHECK-NEXT: [[ADD52:%[0-9]+]]:_(s32) = G_ADD [[UV56]], [[UV120]] +-- ; CHECK-NEXT: [[ADD53:%[0-9]+]]:_(s32) = G_ADD [[UV57]], [[UV121]] +-- ; CHECK-NEXT: [[ADD54:%[0-9]+]]:_(s32) = G_ADD [[UV58]], [[UV122]] +-- ; CHECK-NEXT: [[ADD55:%[0-9]+]]:_(s32) = G_ADD [[UV59]], [[UV123]] +-- ; CHECK-NEXT: [[ADD56:%[0-9]+]]:_(s32) = G_ADD [[UV60]], [[UV124]] +-- ; CHECK-NEXT: [[ADD57:%[0-9]+]]:_(s32) = G_ADD [[UV61]], [[UV125]] +-- ; CHECK-NEXT: [[ADD58:%[0-9]+]]:_(s32) = G_ADD [[UV62]], [[UV126]] +-- ; CHECK-NEXT: [[ADD59:%[0-9]+]]:_(s32) = G_ADD [[UV63]], [[UV127]] +-- ; CHECK-NEXT: [[ADD60:%[0-9]+]]:_(s32) = G_ADD [[UV64]], [[UV128]] +-- ; CHECK-NEXT: [[ADD61:%[0-9]+]]:_(s32) = G_ADD [[UV65]], [[UV129]] +-- ; CHECK-NEXT: [[ADD62:%[0-9]+]]:_(s32) = G_ADD [[UV66]], [[UV130]] +-- ; CHECK-NEXT: [[ADD63:%[0-9]+]]:_(s32) = G_ADD [[UV67]], [[UV131]] +-+ ; CHECK-NEXT: [[UV:%[0-9]+]]:_(s32), [[UV1:%[0-9]+]]:_(s32), [[UV2:%[0-9]+]]:_(s32), [[UV3:%[0-9]+]]:_(s32), [[UV4:%[0-9]+]]:_(s32), [[UV5:%[0-9]+]]:_(s32), [[UV6:%[0-9]+]]:_(s32), [[UV7:%[0-9]+]]:_(s32), [[UV8:%[0-9]+]]:_(s32), [[UV9:%[0-9]+]]:_(s32), [[UV10:%[0-9]+]]:_(s32), [[UV11:%[0-9]+]]:_(s32), [[UV12:%[0-9]+]]:_(s32), [[UV13:%[0-9]+]]:_(s32), [[UV14:%[0-9]+]]:_(s32), [[UV15:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) +-+ ; CHECK-NEXT: [[UV16:%[0-9]+]]:_(s32), [[UV17:%[0-9]+]]:_(s32), [[UV18:%[0-9]+]]:_(s32), [[UV19:%[0-9]+]]:_(s32), [[UV20:%[0-9]+]]:_(s32), [[UV21:%[0-9]+]]:_(s32), [[UV22:%[0-9]+]]:_(s32), [[UV23:%[0-9]+]]:_(s32), [[UV24:%[0-9]+]]:_(s32), [[UV25:%[0-9]+]]:_(s32), [[UV26:%[0-9]+]]:_(s32), [[UV27:%[0-9]+]]:_(s32), [[UV28:%[0-9]+]]:_(s32), [[UV29:%[0-9]+]]:_(s32), [[UV30:%[0-9]+]]:_(s32), [[UV31:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) +-+ ; CHECK-NEXT: [[UV32:%[0-9]+]]:_(s32), [[UV33:%[0-9]+]]:_(s32), [[UV34:%[0-9]+]]:_(s32), [[UV35:%[0-9]+]]:_(s32), [[UV36:%[0-9]+]]:_(s32), [[UV37:%[0-9]+]]:_(s32), [[UV38:%[0-9]+]]:_(s32), [[UV39:%[0-9]+]]:_(s32), [[UV40:%[0-9]+]]:_(s32), [[UV41:%[0-9]+]]:_(s32), [[UV42:%[0-9]+]]:_(s32), [[UV43:%[0-9]+]]:_(s32), [[UV44:%[0-9]+]]:_(s32), [[UV45:%[0-9]+]]:_(s32), [[UV46:%[0-9]+]]:_(s32), [[UV47:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) +-+ ; CHECK-NEXT: [[UV48:%[0-9]+]]:_(s32), [[UV49:%[0-9]+]]:_(s32), [[UV50:%[0-9]+]]:_(s32), [[UV51:%[0-9]+]]:_(s32), [[UV52:%[0-9]+]]:_(s32), [[UV53:%[0-9]+]]:_(s32), [[UV54:%[0-9]+]]:_(s32), [[UV55:%[0-9]+]]:_(s32), [[UV56:%[0-9]+]]:_(s32), [[UV57:%[0-9]+]]:_(s32), [[UV58:%[0-9]+]]:_(s32), [[UV59:%[0-9]+]]:_(s32), [[UV60:%[0-9]+]]:_(s32), [[UV61:%[0-9]+]]:_(s32), [[UV62:%[0-9]+]]:_(s32), [[UV63:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) +-+ ; CHECK-NEXT: [[UV64:%[0-9]+]]:_(s32), [[UV65:%[0-9]+]]:_(s32), [[UV66:%[0-9]+]]:_(s32), [[UV67:%[0-9]+]]:_(s32), [[UV68:%[0-9]+]]:_(s32), [[UV69:%[0-9]+]]:_(s32), [[UV70:%[0-9]+]]:_(s32), [[UV71:%[0-9]+]]:_(s32), [[UV72:%[0-9]+]]:_(s32), [[UV73:%[0-9]+]]:_(s32), [[UV74:%[0-9]+]]:_(s32), [[UV75:%[0-9]+]]:_(s32), [[UV76:%[0-9]+]]:_(s32), [[UV77:%[0-9]+]]:_(s32), [[UV78:%[0-9]+]]:_(s32), [[UV79:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) +-+ ; CHECK-NEXT: [[UV80:%[0-9]+]]:_(s32), [[UV81:%[0-9]+]]:_(s32), [[UV82:%[0-9]+]]:_(s32), [[UV83:%[0-9]+]]:_(s32), [[UV84:%[0-9]+]]:_(s32), [[UV85:%[0-9]+]]:_(s32), [[UV86:%[0-9]+]]:_(s32), [[UV87:%[0-9]+]]:_(s32), [[UV88:%[0-9]+]]:_(s32), [[UV89:%[0-9]+]]:_(s32), [[UV90:%[0-9]+]]:_(s32), [[UV91:%[0-9]+]]:_(s32), [[UV92:%[0-9]+]]:_(s32), [[UV93:%[0-9]+]]:_(s32), [[UV94:%[0-9]+]]:_(s32), [[UV95:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) +-+ ; CHECK-NEXT: [[UV96:%[0-9]+]]:_(s32), [[UV97:%[0-9]+]]:_(s32), [[UV98:%[0-9]+]]:_(s32), [[UV99:%[0-9]+]]:_(s32), [[UV100:%[0-9]+]]:_(s32), [[UV101:%[0-9]+]]:_(s32), [[UV102:%[0-9]+]]:_(s32), [[UV103:%[0-9]+]]:_(s32), [[UV104:%[0-9]+]]:_(s32), [[UV105:%[0-9]+]]:_(s32), [[UV106:%[0-9]+]]:_(s32), [[UV107:%[0-9]+]]:_(s32), [[UV108:%[0-9]+]]:_(s32), [[UV109:%[0-9]+]]:_(s32), [[UV110:%[0-9]+]]:_(s32), [[UV111:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) +-+ ; CHECK-NEXT: [[UV112:%[0-9]+]]:_(s32), [[UV113:%[0-9]+]]:_(s32), [[UV114:%[0-9]+]]:_(s32), [[UV115:%[0-9]+]]:_(s32), [[UV116:%[0-9]+]]:_(s32), [[UV117:%[0-9]+]]:_(s32), [[UV118:%[0-9]+]]:_(s32), [[UV119:%[0-9]+]]:_(s32), [[UV120:%[0-9]+]]:_(s32), [[UV121:%[0-9]+]]:_(s32), [[UV122:%[0-9]+]]:_(s32), [[UV123:%[0-9]+]]:_(s32), [[UV124:%[0-9]+]]:_(s32), [[UV125:%[0-9]+]]:_(s32), [[UV126:%[0-9]+]]:_(s32), [[UV127:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) +-+ ; CHECK-NEXT: [[ADD:%[0-9]+]]:_(s32) = G_ADD [[UV]], [[UV64]] +-+ ; CHECK-NEXT: [[ADD1:%[0-9]+]]:_(s32) = G_ADD [[UV1]], [[UV65]] +-+ ; CHECK-NEXT: [[ADD2:%[0-9]+]]:_(s32) = G_ADD [[UV2]], [[UV66]] +-+ ; CHECK-NEXT: [[ADD3:%[0-9]+]]:_(s32) = G_ADD [[UV3]], [[UV67]] +-+ ; CHECK-NEXT: [[ADD4:%[0-9]+]]:_(s32) = G_ADD [[UV4]], [[UV68]] +-+ ; CHECK-NEXT: [[ADD5:%[0-9]+]]:_(s32) = G_ADD [[UV5]], [[UV69]] +-+ ; CHECK-NEXT: [[ADD6:%[0-9]+]]:_(s32) = G_ADD [[UV6]], [[UV70]] +-+ ; CHECK-NEXT: [[ADD7:%[0-9]+]]:_(s32) = G_ADD [[UV7]], [[UV71]] +-+ ; CHECK-NEXT: [[ADD8:%[0-9]+]]:_(s32) = G_ADD [[UV8]], [[UV72]] +-+ ; CHECK-NEXT: [[ADD9:%[0-9]+]]:_(s32) = G_ADD [[UV9]], [[UV73]] +-+ ; CHECK-NEXT: [[ADD10:%[0-9]+]]:_(s32) = G_ADD [[UV10]], [[UV74]] +-+ ; CHECK-NEXT: [[ADD11:%[0-9]+]]:_(s32) = G_ADD [[UV11]], [[UV75]] +-+ ; CHECK-NEXT: [[ADD12:%[0-9]+]]:_(s32) = G_ADD [[UV12]], [[UV76]] +-+ ; CHECK-NEXT: [[ADD13:%[0-9]+]]:_(s32) = G_ADD [[UV13]], [[UV77]] +-+ ; CHECK-NEXT: [[ADD14:%[0-9]+]]:_(s32) = G_ADD [[UV14]], [[UV78]] +-+ ; CHECK-NEXT: [[ADD15:%[0-9]+]]:_(s32) = G_ADD [[UV15]], [[UV79]] +-+ ; CHECK-NEXT: [[ADD16:%[0-9]+]]:_(s32) = G_ADD [[UV16]], [[UV80]] +-+ ; CHECK-NEXT: [[ADD17:%[0-9]+]]:_(s32) = G_ADD [[UV17]], [[UV81]] +-+ ; CHECK-NEXT: [[ADD18:%[0-9]+]]:_(s32) = G_ADD [[UV18]], [[UV82]] +-+ ; CHECK-NEXT: [[ADD19:%[0-9]+]]:_(s32) = G_ADD [[UV19]], [[UV83]] +-+ ; CHECK-NEXT: [[ADD20:%[0-9]+]]:_(s32) = G_ADD [[UV20]], [[UV84]] +-+ ; CHECK-NEXT: [[ADD21:%[0-9]+]]:_(s32) = G_ADD [[UV21]], [[UV85]] +-+ ; CHECK-NEXT: [[ADD22:%[0-9]+]]:_(s32) = G_ADD [[UV22]], [[UV86]] +-+ ; CHECK-NEXT: [[ADD23:%[0-9]+]]:_(s32) = G_ADD [[UV23]], [[UV87]] +-+ ; CHECK-NEXT: [[ADD24:%[0-9]+]]:_(s32) = G_ADD [[UV24]], [[UV88]] +-+ ; CHECK-NEXT: [[ADD25:%[0-9]+]]:_(s32) = G_ADD [[UV25]], [[UV89]] +-+ ; CHECK-NEXT: [[ADD26:%[0-9]+]]:_(s32) = G_ADD [[UV26]], [[UV90]] +-+ ; CHECK-NEXT: [[ADD27:%[0-9]+]]:_(s32) = G_ADD [[UV27]], [[UV91]] +-+ ; CHECK-NEXT: [[ADD28:%[0-9]+]]:_(s32) = G_ADD [[UV28]], [[UV92]] +-+ ; CHECK-NEXT: [[ADD29:%[0-9]+]]:_(s32) = G_ADD [[UV29]], [[UV93]] +-+ ; CHECK-NEXT: [[ADD30:%[0-9]+]]:_(s32) = G_ADD [[UV30]], [[UV94]] +-+ ; CHECK-NEXT: [[ADD31:%[0-9]+]]:_(s32) = G_ADD [[UV31]], [[UV95]] +-+ ; CHECK-NEXT: [[ADD32:%[0-9]+]]:_(s32) = G_ADD [[UV32]], [[UV96]] +-+ ; CHECK-NEXT: [[ADD33:%[0-9]+]]:_(s32) = G_ADD [[UV33]], [[UV97]] +-+ ; CHECK-NEXT: [[ADD34:%[0-9]+]]:_(s32) = G_ADD [[UV34]], [[UV98]] +-+ ; CHECK-NEXT: [[ADD35:%[0-9]+]]:_(s32) = G_ADD [[UV35]], [[UV99]] +-+ ; CHECK-NEXT: [[ADD36:%[0-9]+]]:_(s32) = G_ADD [[UV36]], [[UV100]] +-+ ; CHECK-NEXT: [[ADD37:%[0-9]+]]:_(s32) = G_ADD [[UV37]], [[UV101]] +-+ ; CHECK-NEXT: [[ADD38:%[0-9]+]]:_(s32) = G_ADD [[UV38]], [[UV102]] +-+ ; CHECK-NEXT: [[ADD39:%[0-9]+]]:_(s32) = G_ADD [[UV39]], [[UV103]] +-+ ; CHECK-NEXT: [[ADD40:%[0-9]+]]:_(s32) = G_ADD [[UV40]], [[UV104]] +-+ ; CHECK-NEXT: [[ADD41:%[0-9]+]]:_(s32) = G_ADD [[UV41]], [[UV105]] +-+ ; CHECK-NEXT: [[ADD42:%[0-9]+]]:_(s32) = G_ADD [[UV42]], [[UV106]] +-+ ; CHECK-NEXT: [[ADD43:%[0-9]+]]:_(s32) = G_ADD [[UV43]], [[UV107]] +-+ ; CHECK-NEXT: [[ADD44:%[0-9]+]]:_(s32) = G_ADD [[UV44]], [[UV108]] +-+ ; CHECK-NEXT: [[ADD45:%[0-9]+]]:_(s32) = G_ADD [[UV45]], [[UV109]] +-+ ; CHECK-NEXT: [[ADD46:%[0-9]+]]:_(s32) = G_ADD [[UV46]], [[UV110]] +-+ ; CHECK-NEXT: [[ADD47:%[0-9]+]]:_(s32) = G_ADD [[UV47]], [[UV111]] +-+ ; CHECK-NEXT: [[ADD48:%[0-9]+]]:_(s32) = G_ADD [[UV48]], [[UV112]] +-+ ; CHECK-NEXT: [[ADD49:%[0-9]+]]:_(s32) = G_ADD [[UV49]], [[UV113]] +-+ ; CHECK-NEXT: [[ADD50:%[0-9]+]]:_(s32) = G_ADD [[UV50]], [[UV114]] +-+ ; CHECK-NEXT: [[ADD51:%[0-9]+]]:_(s32) = G_ADD [[UV51]], [[UV115]] +-+ ; CHECK-NEXT: [[ADD52:%[0-9]+]]:_(s32) = G_ADD [[UV52]], [[UV116]] +-+ ; CHECK-NEXT: [[ADD53:%[0-9]+]]:_(s32) = G_ADD [[UV53]], [[UV117]] +-+ ; CHECK-NEXT: [[ADD54:%[0-9]+]]:_(s32) = G_ADD [[UV54]], [[UV118]] +-+ ; CHECK-NEXT: [[ADD55:%[0-9]+]]:_(s32) = G_ADD [[UV55]], [[UV119]] +-+ ; CHECK-NEXT: [[ADD56:%[0-9]+]]:_(s32) = G_ADD [[UV56]], [[UV120]] +-+ ; CHECK-NEXT: [[ADD57:%[0-9]+]]:_(s32) = G_ADD [[UV57]], [[UV121]] +-+ ; CHECK-NEXT: [[ADD58:%[0-9]+]]:_(s32) = G_ADD [[UV58]], [[UV122]] +-+ ; CHECK-NEXT: [[ADD59:%[0-9]+]]:_(s32) = G_ADD [[UV59]], [[UV123]] +-+ ; CHECK-NEXT: [[ADD60:%[0-9]+]]:_(s32) = G_ADD [[UV60]], [[UV124]] +-+ ; CHECK-NEXT: [[ADD61:%[0-9]+]]:_(s32) = G_ADD [[UV61]], [[UV125]] +-+ ; CHECK-NEXT: [[ADD62:%[0-9]+]]:_(s32) = G_ADD [[UV62]], [[UV126]] +-+ ; CHECK-NEXT: [[ADD63:%[0-9]+]]:_(s32) = G_ADD [[UV63]], [[UV127]] +- ; CHECK-NEXT: [[BUILD_VECTOR:%[0-9]+]]:_(<16 x s32>) = G_BUILD_VECTOR [[ADD]](s32), [[ADD1]](s32), [[ADD2]](s32), [[ADD3]](s32), [[ADD4]](s32), [[ADD5]](s32), [[ADD6]](s32), [[ADD7]](s32), [[ADD8]](s32), [[ADD9]](s32), [[ADD10]](s32), [[ADD11]](s32), [[ADD12]](s32), [[ADD13]](s32), [[ADD14]](s32), [[ADD15]](s32) +- ; CHECK-NEXT: [[BUILD_VECTOR1:%[0-9]+]]:_(<16 x s32>) = G_BUILD_VECTOR [[ADD16]](s32), [[ADD17]](s32), [[ADD18]](s32), [[ADD19]](s32), [[ADD20]](s32), [[ADD21]](s32), [[ADD22]](s32), [[ADD23]](s32), [[ADD24]](s32), [[ADD25]](s32), [[ADD26]](s32), [[ADD27]](s32), [[ADD28]](s32), [[ADD29]](s32), [[ADD30]](s32), [[ADD31]](s32) +- ; CHECK-NEXT: [[BUILD_VECTOR2:%[0-9]+]]:_(<16 x s32>) = G_BUILD_VECTOR [[ADD32]](s32), [[ADD33]](s32), [[ADD34]](s32), [[ADD35]](s32), [[ADD36]](s32), [[ADD37]](s32), [[ADD38]](s32), [[ADD39]](s32), [[ADD40]](s32), [[ADD41]](s32), [[ADD42]](s32), [[ADD43]](s32), [[ADD44]](s32), [[ADD45]](s32), [[ADD46]](s32), [[ADD47]](s32) +-@@ -760,10 +762,10 @@ +- ; CHECK-NEXT: G_BR %bb.2 +- ; CHECK-NEXT: {{ $}} +- ; CHECK-NEXT: bb.2: +-- ; CHECK-NEXT: [[PHI:%[0-9]+]]:_(<16 x s32>) = G_PHI [[UV]](<16 x s32>), %bb.0, [[BUILD_VECTOR]](<16 x s32>), %bb.1 +-- ; CHECK-NEXT: [[PHI1:%[0-9]+]]:_(<16 x s32>) = G_PHI [[UV1]](<16 x s32>), %bb.0, [[BUILD_VECTOR1]](<16 x s32>), %bb.1 +-- ; CHECK-NEXT: [[PHI2:%[0-9]+]]:_(<16 x s32>) = G_PHI [[UV2]](<16 x s32>), %bb.0, [[BUILD_VECTOR2]](<16 x s32>), %bb.1 +-- ; CHECK-NEXT: [[PHI3:%[0-9]+]]:_(<16 x s32>) = G_PHI [[UV3]](<16 x s32>), %bb.0, [[BUILD_VECTOR3]](<16 x s32>), %bb.1 +-+ ; CHECK-NEXT: [[PHI:%[0-9]+]]:_(<16 x s32>) = G_PHI [[DEF]](<16 x s32>), %bb.0, [[BUILD_VECTOR]](<16 x s32>), %bb.1 +-+ ; CHECK-NEXT: [[PHI1:%[0-9]+]]:_(<16 x s32>) = G_PHI [[DEF]](<16 x s32>), %bb.0, [[BUILD_VECTOR1]](<16 x s32>), %bb.1 +-+ ; CHECK-NEXT: [[PHI2:%[0-9]+]]:_(<16 x s32>) = G_PHI [[DEF]](<16 x s32>), %bb.0, [[BUILD_VECTOR2]](<16 x s32>), %bb.1 +-+ ; CHECK-NEXT: [[PHI3:%[0-9]+]]:_(<16 x s32>) = G_PHI [[DEF]](<16 x s32>), %bb.0, [[BUILD_VECTOR3]](<16 x s32>), %bb.1 +- ; CHECK-NEXT: [[CONCAT_VECTORS:%[0-9]+]]:_(<64 x s32>) = G_CONCAT_VECTORS [[PHI]](<16 x s32>), [[PHI1]](<16 x s32>), [[PHI2]](<16 x s32>), [[PHI3]](<16 x s32>) +- ; CHECK-NEXT: S_SETPC_B64 undef $sgpr30_sgpr31, implicit [[CONCAT_VECTORS]](<64 x s32>) +- bb.0: +-diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/AMDGPU/GlobalISel/regbankselect.mir b/llvm/test/CodeGen/AMDGPU/GlobalISel/regbankselect.mir +---- a/llvm/test/CodeGen/AMDGPU/GlobalISel/regbankselect.mir +-+++ b/llvm/test/CodeGen/AMDGPU/GlobalISel/regbankselect.mir +-@@ -42,6 +42,8 @@ +- ret void +- } +- +-+ define void @non_power_of_2() { ret void } +-+ +- define amdgpu_kernel void @load_constant_v4i16_from_8_align8(ptr addrspace(4) %ptr0) { +- ret void +- } +-@@ -185,6 +187,23 @@ +- ... +- +- --- +-+name: non_power_of_2 +-+legalized: true +-+ +-+body: | +-+ bb.0: +-+ ; CHECK-LABEL: name: non_power_of_2 +-+ ; CHECK: [[DEF:%[0-9]+]]:sgpr(s448) = G_IMPLICIT_DEF +-+ ; CHECK-NEXT: [[EXTRACT:%[0-9]+]]:sgpr(s32) = G_EXTRACT [[DEF]](s448), 0 +-+ ; CHECK-NEXT: $sgpr0 = COPY [[EXTRACT]](s32) +-+ ; CHECK-NEXT: SI_RETURN_TO_EPILOG $sgpr0 +-+ %0:_(s448) = G_IMPLICIT_DEF +-+ %1:_(s32) = G_EXTRACT %0:_(s448), 0 +-+ $sgpr0 = COPY %1:_(s32) +-+ SI_RETURN_TO_EPILOG $sgpr0 +-+... +-+ +-+--- +- name: load_constant_v4i16_from_8_align8 +- legalized: true +- +-diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/NVPTX/intrin-nocapture.ll b/llvm/test/CodeGen/NVPTX/intrin-nocapture.ll +---- a/llvm/test/CodeGen/NVPTX/intrin-nocapture.ll +-+++ b/llvm/test/CodeGen/NVPTX/intrin-nocapture.ll +-@@ -0,0 +1,21 @@ +-+; RUN: opt < %s -O3 -S | FileCheck %s +-+ +-+; Address space intrinsics were erroneously marked NoCapture, leading to bad +-+; optimizations (such as the store below being eliminated as dead code). This +-+; test makes sure we don't regress. +-+ +-+declare void @foo(ptr addrspace(1)) +-+ +-+declare ptr addrspace(1) @llvm.nvvm.ptr.gen.to.global.p1.p0(ptr) +-+ +-+; CHECK: @bar +-+define void @bar() { +-+ %t1 = alloca i32 +-+; CHECK: call ptr addrspace(1) @llvm.nvvm.ptr.gen.to.global.p1.p0(ptr nonnull %t1) +-+; CHECK-NEXT: store i32 10, ptr %t1 +-+ %t2 = call ptr addrspace(1) @llvm.nvvm.ptr.gen.to.global.p1.p0(ptr %t1) +-+ store i32 10, ptr %t1 +-+ call void @foo(ptr addrspace(1) %t2) +-+ ret void +-+} +-+ +-diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/NVPTX/rotate_64.ll b/llvm/test/CodeGen/NVPTX/rotate_64.ll +---- a/llvm/test/CodeGen/NVPTX/rotate_64.ll +-+++ b/llvm/test/CodeGen/NVPTX/rotate_64.ll +-@@ -1,38 +1,25 @@ +--; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5 +- ; RUN: llc < %s -march=nvptx64 | FileCheck %s +- ; RUN: %if ptxas %{ llc < %s -march=nvptx64 | %ptxas-verify %} +- +- declare i64 @llvm.nvvm.rotate.b64(i64, i32) +- declare i64 @llvm.nvvm.rotate.right.b64(i64, i32) +- +-+; CHECK: rotate64 +- define i64 @rotate64(i64 %a, i32 %b) { +--; CHECK-LABEL: rotate64( +--; CHECK: { +--; CHECK-NEXT: .reg .b64 %rd<5>; +--; CHECK-EMPTY: +--; CHECK-NEXT: // %bb.0: +--; CHECK-NEXT: ld.param.u64 %rd1, [rotate64_param_0]; +--; CHECK-NEXT: shr.u64 %rd2, %rd1, 61; +--; CHECK-NEXT: shl.b64 %rd3, %rd1, 3; +--; CHECK-NEXT: or.b64 %rd4, %rd3, %rd2; +--; CHECK-NEXT: st.param.b64 [func_retval0+0], %rd4; +--; CHECK-NEXT: ret; +-+; CHECK: shl.b64 [[LHS:%.*]], [[RD1:%.*]], 3; +-+; CHECK: shr.b64 [[RHS:%.*]], [[RD1]], 61; +-+; CHECK: add.u64 [[RD2:%.*]], [[LHS]], [[RHS]]; +-+; CHECK: ret +- %val = tail call i64 @llvm.nvvm.rotate.b64(i64 %a, i32 3) +- ret i64 %val +- } +- +-+; CHECK: rotateright64 +- define i64 @rotateright64(i64 %a, i32 %b) { +--; CHECK-LABEL: rotateright64( +--; CHECK: { +--; CHECK-NEXT: .reg .b64 %rd<5>; +--; CHECK-EMPTY: +--; CHECK-NEXT: // %bb.0: +--; CHECK-NEXT: ld.param.u64 %rd1, [rotateright64_param_0]; +--; CHECK-NEXT: shl.b64 %rd2, %rd1, 61; +--; CHECK-NEXT: shr.u64 %rd3, %rd1, 3; +--; CHECK-NEXT: or.b64 %rd4, %rd3, %rd2; +--; CHECK-NEXT: st.param.b64 [func_retval0+0], %rd4; +--; CHECK-NEXT: ret; +-+; CHECK: shl.b64 [[LHS:%.*]], [[RD1:%.*]], 61; +-+; CHECK: shr.b64 [[RHS:%.*]], [[RD1]], 3; +-+; CHECK: add.u64 [[RD2:%.*]], [[LHS]], [[RHS]]; +-+; CHECK: ret +- %val = tail call i64 @llvm.nvvm.rotate.right.b64(i64 %a, i32 3) +- ret i64 %val +- } +-diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/NVPTX/rotate.ll b/llvm/test/CodeGen/NVPTX/rotate.ll +---- a/llvm/test/CodeGen/NVPTX/rotate.ll +-+++ b/llvm/test/CodeGen/NVPTX/rotate.ll +-@@ -9,29 +9,26 @@ +- declare i64 @llvm.nvvm.rotate.b64(i64, i32) +- declare i64 @llvm.nvvm.rotate.right.b64(i64, i32) +- +--declare i64 @llvm.fshl.i64(i64, i64, i64) +--declare i64 @llvm.fshr.i64(i64, i64, i64) +--declare i32 @llvm.fshl.i32(i32, i32, i32) +--declare i32 @llvm.fshr.i32(i32, i32, i32) +-- +-- +- ; SM20: rotate32 +- ; SM35: rotate32 +- define i32 @rotate32(i32 %a, i32 %b) { +- ; SM20-LABEL: rotate32( +- ; SM20: { +--; SM20-NEXT: .reg .b32 %r<9>; +-+; SM20-NEXT: .reg .b32 %r<4>; +- ; SM20-EMPTY: +- ; SM20-NEXT: // %bb.0: +- ; SM20-NEXT: ld.param.u32 %r1, [rotate32_param_0]; +- ; SM20-NEXT: ld.param.u32 %r2, [rotate32_param_1]; +--; SM20-NEXT: and.b32 %r3, %r2, 31; +--; SM20-NEXT: shl.b32 %r4, %r1, %r3; +--; SM20-NEXT: neg.s32 %r5, %r2; +--; SM20-NEXT: and.b32 %r6, %r5, 31; +--; SM20-NEXT: shr.u32 %r7, %r1, %r6; +--; SM20-NEXT: or.b32 %r8, %r4, %r7; +--; SM20-NEXT: st.param.b32 [func_retval0+0], %r8; +-+; SM20-NEXT: { +-+; SM20-NEXT: .reg .b32 %lhs; +-+; SM20-NEXT: .reg .b32 %rhs; +-+; SM20-NEXT: .reg .b32 %amt2; +-+; SM20-NEXT: shl.b32 %lhs, %r1, %r2; +-+; SM20-NEXT: sub.s32 %amt2, 32, %r2; +-+; SM20-NEXT: shr.b32 %rhs, %r1, %amt2; +-+; SM20-NEXT: add.u32 %r3, %lhs, %rhs; +-+; SM20-NEXT: } +-+; SM20-NEXT: st.param.b32 [func_retval0+0], %r3; +- ; SM20-NEXT: ret; +- ; +- ; SM35-LABEL: rotate32( +-@@ -53,36 +50,45 @@ +- define i64 @rotate64(i64 %a, i32 %b) { +- ; SM20-LABEL: rotate64( +- ; SM20: { +--; SM20-NEXT: .reg .b32 %r<5>; +--; SM20-NEXT: .reg .b64 %rd<5>; +-+; SM20-NEXT: .reg .b32 %r<2>; +-+; SM20-NEXT: .reg .b64 %rd<3>; +- ; SM20-EMPTY: +- ; SM20-NEXT: // %bb.0: +- ; SM20-NEXT: ld.param.u64 %rd1, [rotate64_param_0]; +- ; SM20-NEXT: ld.param.u32 %r1, [rotate64_param_1]; +--; SM20-NEXT: and.b32 %r2, %r1, 63; +--; SM20-NEXT: shl.b64 %rd2, %rd1, %r2; +--; SM20-NEXT: neg.s32 %r3, %r1; +--; SM20-NEXT: and.b32 %r4, %r3, 63; +--; SM20-NEXT: shr.u64 %rd3, %rd1, %r4; +--; SM20-NEXT: or.b64 %rd4, %rd2, %rd3; +--; SM20-NEXT: st.param.b64 [func_retval0+0], %rd4; +-+; SM20-NEXT: { +-+; SM20-NEXT: .reg .b64 %lhs; +-+; SM20-NEXT: .reg .b64 %rhs; +-+; SM20-NEXT: .reg .u32 %amt2; +-+; SM20-NEXT: and.b32 %amt2, %r1, 63; +-+; SM20-NEXT: shl.b64 %lhs, %rd1, %amt2; +-+; SM20-NEXT: sub.u32 %amt2, 64, %amt2; +-+; SM20-NEXT: shr.b64 %rhs, %rd1, %amt2; +-+; SM20-NEXT: add.u64 %rd2, %lhs, %rhs; +-+; SM20-NEXT: } +-+; SM20-NEXT: st.param.b64 [func_retval0+0], %rd2; +- ; SM20-NEXT: ret; +- ; +- ; SM35-LABEL: rotate64( +- ; SM35: { +--; SM35-NEXT: .reg .b32 %r<5>; +--; SM35-NEXT: .reg .b64 %rd<5>; +-+; SM35-NEXT: .reg .b32 %r<6>; +-+; SM35-NEXT: .reg .b64 %rd<3>; +- ; SM35-EMPTY: +- ; SM35-NEXT: // %bb.0: +- ; SM35-NEXT: ld.param.u64 %rd1, [rotate64_param_0]; +--; SM35-NEXT: ld.param.u32 %r1, [rotate64_param_1]; +--; SM35-NEXT: and.b32 %r2, %r1, 63; +--; SM35-NEXT: shl.b64 %rd2, %rd1, %r2; +--; SM35-NEXT: neg.s32 %r3, %r1; +--; SM35-NEXT: and.b32 %r4, %r3, 63; +--; SM35-NEXT: shr.u64 %rd3, %rd1, %r4; +--; SM35-NEXT: or.b64 %rd4, %rd2, %rd3; +--; SM35-NEXT: st.param.b64 [func_retval0+0], %rd4; +-+; SM35-NEXT: { +-+; SM35-NEXT: .reg .b32 %dummy; +-+; SM35-NEXT: mov.b64 {%dummy,%r1}, %rd1; +-+; SM35-NEXT: } +-+; SM35-NEXT: { +-+; SM35-NEXT: .reg .b32 %dummy; +-+; SM35-NEXT: mov.b64 {%r2,%dummy}, %rd1; +-+; SM35-NEXT: } +-+; SM35-NEXT: ld.param.u32 %r3, [rotate64_param_1]; +-+; SM35-NEXT: shf.l.wrap.b32 %r4, %r2, %r1, %r3; +-+; SM35-NEXT: shf.l.wrap.b32 %r5, %r1, %r2, %r3; +-+; SM35-NEXT: mov.b64 %rd2, {%r5, %r4}; +-+; SM35-NEXT: st.param.b64 [func_retval0+0], %rd2; +- ; SM35-NEXT: ret; +- %val = tail call i64 @llvm.nvvm.rotate.b64(i64 %a, i32 %b) +- ret i64 %val +-@@ -93,36 +99,45 @@ +- define i64 @rotateright64(i64 %a, i32 %b) { +- ; SM20-LABEL: rotateright64( +- ; SM20: { +--; SM20-NEXT: .reg .b32 %r<5>; +--; SM20-NEXT: .reg .b64 %rd<5>; +-+; SM20-NEXT: .reg .b32 %r<2>; +-+; SM20-NEXT: .reg .b64 %rd<3>; +- ; SM20-EMPTY: +- ; SM20-NEXT: // %bb.0: +- ; SM20-NEXT: ld.param.u64 %rd1, [rotateright64_param_0]; +- ; SM20-NEXT: ld.param.u32 %r1, [rotateright64_param_1]; +--; SM20-NEXT: and.b32 %r2, %r1, 63; +--; SM20-NEXT: shr.u64 %rd2, %rd1, %r2; +--; SM20-NEXT: neg.s32 %r3, %r1; +--; SM20-NEXT: and.b32 %r4, %r3, 63; +--; SM20-NEXT: shl.b64 %rd3, %rd1, %r4; +--; SM20-NEXT: or.b64 %rd4, %rd2, %rd3; +--; SM20-NEXT: st.param.b64 [func_retval0+0], %rd4; +-+; SM20-NEXT: { +-+; SM20-NEXT: .reg .b64 %lhs; +-+; SM20-NEXT: .reg .b64 %rhs; +-+; SM20-NEXT: .reg .u32 %amt2; +-+; SM20-NEXT: and.b32 %amt2, %r1, 63; +-+; SM20-NEXT: shr.b64 %lhs, %rd1, %amt2; +-+; SM20-NEXT: sub.u32 %amt2, 64, %amt2; +-+; SM20-NEXT: shl.b64 %rhs, %rd1, %amt2; +-+; SM20-NEXT: add.u64 %rd2, %lhs, %rhs; +-+; SM20-NEXT: } +-+; SM20-NEXT: st.param.b64 [func_retval0+0], %rd2; +- ; SM20-NEXT: ret; +- ; +- ; SM35-LABEL: rotateright64( +- ; SM35: { +--; SM35-NEXT: .reg .b32 %r<5>; +--; SM35-NEXT: .reg .b64 %rd<5>; +-+; SM35-NEXT: .reg .b32 %r<6>; +-+; SM35-NEXT: .reg .b64 %rd<3>; +- ; SM35-EMPTY: +- ; SM35-NEXT: // %bb.0: +- ; SM35-NEXT: ld.param.u64 %rd1, [rotateright64_param_0]; +--; SM35-NEXT: ld.param.u32 %r1, [rotateright64_param_1]; +--; SM35-NEXT: and.b32 %r2, %r1, 63; +--; SM35-NEXT: shr.u64 %rd2, %rd1, %r2; +--; SM35-NEXT: neg.s32 %r3, %r1; +--; SM35-NEXT: and.b32 %r4, %r3, 63; +--; SM35-NEXT: shl.b64 %rd3, %rd1, %r4; +--; SM35-NEXT: or.b64 %rd4, %rd2, %rd3; +--; SM35-NEXT: st.param.b64 [func_retval0+0], %rd4; +-+; SM35-NEXT: { +-+; SM35-NEXT: .reg .b32 %dummy; +-+; SM35-NEXT: mov.b64 {%r1,%dummy}, %rd1; +-+; SM35-NEXT: } +-+; SM35-NEXT: { +-+; SM35-NEXT: .reg .b32 %dummy; +-+; SM35-NEXT: mov.b64 {%dummy,%r2}, %rd1; +-+; SM35-NEXT: } +-+; SM35-NEXT: ld.param.u32 %r3, [rotateright64_param_1]; +-+; SM35-NEXT: shf.r.wrap.b32 %r4, %r2, %r1, %r3; +-+; SM35-NEXT: shf.r.wrap.b32 %r5, %r1, %r2, %r3; +-+; SM35-NEXT: mov.b64 %rd2, {%r5, %r4}; +-+; SM35-NEXT: st.param.b64 [func_retval0+0], %rd2; +- ; SM35-NEXT: ret; +- %val = tail call i64 @llvm.nvvm.rotate.right.b64(i64 %a, i32 %b) +- ret i64 %val +-@@ -133,14 +148,18 @@ +- define i32 @rotl0(i32 %x) { +- ; SM20-LABEL: rotl0( +- ; SM20: { +--; SM20-NEXT: .reg .b32 %r<5>; +-+; SM20-NEXT: .reg .b32 %r<3>; +- ; SM20-EMPTY: +- ; SM20-NEXT: // %bb.0: +- ; SM20-NEXT: ld.param.u32 %r1, [rotl0_param_0]; +--; SM20-NEXT: shr.u32 %r2, %r1, 24; +--; SM20-NEXT: shl.b32 %r3, %r1, 8; +--; SM20-NEXT: or.b32 %r4, %r3, %r2; +--; SM20-NEXT: st.param.b32 [func_retval0+0], %r4; +-+; SM20-NEXT: { +-+; SM20-NEXT: .reg .b32 %lhs; +-+; SM20-NEXT: .reg .b32 %rhs; +-+; SM20-NEXT: shl.b32 %lhs, %r1, 8; +-+; SM20-NEXT: shr.b32 %rhs, %r1, 24; +-+; SM20-NEXT: add.u32 %r2, %lhs, %rhs; +-+; SM20-NEXT: } +-+; SM20-NEXT: st.param.b32 [func_retval0+0], %r2; +- ; SM20-NEXT: ret; +- ; +- ; SM35-LABEL: rotl0( +-@@ -158,40 +177,51 @@ +- ret i32 %t2 +- } +- +-+declare i64 @llvm.fshl.i64(i64, i64, i64) +-+declare i64 @llvm.fshr.i64(i64, i64, i64) +-+ +- ; SM35: rotl64 +- define i64 @rotl64(i64 %a, i64 %n) { +- ; SM20-LABEL: rotl64( +- ; SM20: { +--; SM20-NEXT: .reg .b32 %r<5>; +--; SM20-NEXT: .reg .b64 %rd<5>; +-+; SM20-NEXT: .reg .b32 %r<2>; +-+; SM20-NEXT: .reg .b64 %rd<3>; +- ; SM20-EMPTY: +- ; SM20-NEXT: // %bb.0: +- ; SM20-NEXT: ld.param.u64 %rd1, [rotl64_param_0]; +- ; SM20-NEXT: ld.param.u32 %r1, [rotl64_param_1]; +--; SM20-NEXT: and.b32 %r2, %r1, 63; +--; SM20-NEXT: shl.b64 %rd2, %rd1, %r2; +--; SM20-NEXT: neg.s32 %r3, %r1; +--; SM20-NEXT: and.b32 %r4, %r3, 63; +--; SM20-NEXT: shr.u64 %rd3, %rd1, %r4; +--; SM20-NEXT: or.b64 %rd4, %rd2, %rd3; +--; SM20-NEXT: st.param.b64 [func_retval0+0], %rd4; +-+; SM20-NEXT: { +-+; SM20-NEXT: .reg .b64 %lhs; +-+; SM20-NEXT: .reg .b64 %rhs; +-+; SM20-NEXT: .reg .u32 %amt2; +-+; SM20-NEXT: and.b32 %amt2, %r1, 63; +-+; SM20-NEXT: shl.b64 %lhs, %rd1, %amt2; +-+; SM20-NEXT: sub.u32 %amt2, 64, %amt2; +-+; SM20-NEXT: shr.b64 %rhs, %rd1, %amt2; +-+; SM20-NEXT: add.u64 %rd2, %lhs, %rhs; +-+; SM20-NEXT: } +-+; SM20-NEXT: st.param.b64 [func_retval0+0], %rd2; +- ; SM20-NEXT: ret; +- ; +- ; SM35-LABEL: rotl64( +- ; SM35: { +--; SM35-NEXT: .reg .b32 %r<5>; +--; SM35-NEXT: .reg .b64 %rd<5>; +-+; SM35-NEXT: .reg .b32 %r<2>; +-+; SM35-NEXT: .reg .b64 %rd<3>; +- ; SM35-EMPTY: +- ; SM35-NEXT: // %bb.0: +- ; SM35-NEXT: ld.param.u64 %rd1, [rotl64_param_0]; +- ; SM35-NEXT: ld.param.u32 %r1, [rotl64_param_1]; +--; SM35-NEXT: and.b32 %r2, %r1, 63; +--; SM35-NEXT: shl.b64 %rd2, %rd1, %r2; +--; SM35-NEXT: neg.s32 %r3, %r1; +--; SM35-NEXT: and.b32 %r4, %r3, 63; +--; SM35-NEXT: shr.u64 %rd3, %rd1, %r4; +--; SM35-NEXT: or.b64 %rd4, %rd2, %rd3; +--; SM35-NEXT: st.param.b64 [func_retval0+0], %rd4; +-+; SM35-NEXT: { +-+; SM35-NEXT: .reg .b64 %lhs; +-+; SM35-NEXT: .reg .b64 %rhs; +-+; SM35-NEXT: .reg .u32 %amt2; +-+; SM35-NEXT: and.b32 %amt2, %r1, 63; +-+; SM35-NEXT: shl.b64 %lhs, %rd1, %amt2; +-+; SM35-NEXT: sub.u32 %amt2, 64, %amt2; +-+; SM35-NEXT: shr.b64 %rhs, %rd1, %amt2; +-+; SM35-NEXT: add.u64 %rd2, %lhs, %rhs; +-+; SM35-NEXT: } +-+; SM35-NEXT: st.param.b64 [func_retval0+0], %rd2; +- ; SM35-NEXT: ret; +- %val = tail call i64 @llvm.fshl.i64(i64 %a, i64 %a, i64 %n) +- ret i64 %val +-@@ -201,26 +231,34 @@ +- define i64 @rotl64_imm(i64 %a) { +- ; SM20-LABEL: rotl64_imm( +- ; SM20: { +--; SM20-NEXT: .reg .b64 %rd<5>; +-+; SM20-NEXT: .reg .b64 %rd<3>; +- ; SM20-EMPTY: +- ; SM20-NEXT: // %bb.0: +- ; SM20-NEXT: ld.param.u64 %rd1, [rotl64_imm_param_0]; +--; SM20-NEXT: shr.u64 %rd2, %rd1, 62; +--; SM20-NEXT: shl.b64 %rd3, %rd1, 2; +--; SM20-NEXT: or.b64 %rd4, %rd3, %rd2; +--; SM20-NEXT: st.param.b64 [func_retval0+0], %rd4; +-+; SM20-NEXT: { +-+; SM20-NEXT: .reg .b64 %lhs; +-+; SM20-NEXT: .reg .b64 %rhs; +-+; SM20-NEXT: shl.b64 %lhs, %rd1, 2; +-+; SM20-NEXT: shr.b64 %rhs, %rd1, 62; +-+; SM20-NEXT: add.u64 %rd2, %lhs, %rhs; +-+; SM20-NEXT: } +-+; SM20-NEXT: st.param.b64 [func_retval0+0], %rd2; +- ; SM20-NEXT: ret; +- ; +- ; SM35-LABEL: rotl64_imm( +- ; SM35: { +--; SM35-NEXT: .reg .b64 %rd<5>; +-+; SM35-NEXT: .reg .b64 %rd<3>; +- ; SM35-EMPTY: +- ; SM35-NEXT: // %bb.0: +- ; SM35-NEXT: ld.param.u64 %rd1, [rotl64_imm_param_0]; +--; SM35-NEXT: shr.u64 %rd2, %rd1, 62; +--; SM35-NEXT: shl.b64 %rd3, %rd1, 2; +--; SM35-NEXT: or.b64 %rd4, %rd3, %rd2; +--; SM35-NEXT: st.param.b64 [func_retval0+0], %rd4; +-+; SM35-NEXT: { +-+; SM35-NEXT: .reg .b64 %lhs; +-+; SM35-NEXT: .reg .b64 %rhs; +-+; SM35-NEXT: shl.b64 %lhs, %rd1, 2; +-+; SM35-NEXT: shr.b64 %rhs, %rd1, 62; +-+; SM35-NEXT: add.u64 %rd2, %lhs, %rhs; +-+; SM35-NEXT: } +-+; SM35-NEXT: st.param.b64 [func_retval0+0], %rd2; +- ; SM35-NEXT: ret; +- %val = tail call i64 @llvm.fshl.i64(i64 %a, i64 %a, i64 66) +- ret i64 %val +-@@ -230,36 +268,44 @@ +- define i64 @rotr64(i64 %a, i64 %n) { +- ; SM20-LABEL: rotr64( +- ; SM20: { +--; SM20-NEXT: .reg .b32 %r<5>; +--; SM20-NEXT: .reg .b64 %rd<5>; +-+; SM20-NEXT: .reg .b32 %r<2>; +-+; SM20-NEXT: .reg .b64 %rd<3>; +- ; SM20-EMPTY: +- ; SM20-NEXT: // %bb.0: +- ; SM20-NEXT: ld.param.u64 %rd1, [rotr64_param_0]; +- ; SM20-NEXT: ld.param.u32 %r1, [rotr64_param_1]; +--; SM20-NEXT: and.b32 %r2, %r1, 63; +--; SM20-NEXT: shr.u64 %rd2, %rd1, %r2; +--; SM20-NEXT: neg.s32 %r3, %r1; +--; SM20-NEXT: and.b32 %r4, %r3, 63; +--; SM20-NEXT: shl.b64 %rd3, %rd1, %r4; +--; SM20-NEXT: or.b64 %rd4, %rd2, %rd3; +--; SM20-NEXT: st.param.b64 [func_retval0+0], %rd4; +-+; SM20-NEXT: { +-+; SM20-NEXT: .reg .b64 %lhs; +-+; SM20-NEXT: .reg .b64 %rhs; +-+; SM20-NEXT: .reg .u32 %amt2; +-+; SM20-NEXT: and.b32 %amt2, %r1, 63; +-+; SM20-NEXT: shr.b64 %lhs, %rd1, %amt2; +-+; SM20-NEXT: sub.u32 %amt2, 64, %amt2; +-+; SM20-NEXT: shl.b64 %rhs, %rd1, %amt2; +-+; SM20-NEXT: add.u64 %rd2, %lhs, %rhs; +-+; SM20-NEXT: } +-+; SM20-NEXT: st.param.b64 [func_retval0+0], %rd2; +- ; SM20-NEXT: ret; +- ; +- ; SM35-LABEL: rotr64( +- ; SM35: { +--; SM35-NEXT: .reg .b32 %r<5>; +--; SM35-NEXT: .reg .b64 %rd<5>; +-+; SM35-NEXT: .reg .b32 %r<2>; +-+; SM35-NEXT: .reg .b64 %rd<3>; +- ; SM35-EMPTY: +- ; SM35-NEXT: // %bb.0: +- ; SM35-NEXT: ld.param.u64 %rd1, [rotr64_param_0]; +- ; SM35-NEXT: ld.param.u32 %r1, [rotr64_param_1]; +--; SM35-NEXT: and.b32 %r2, %r1, 63; +--; SM35-NEXT: shr.u64 %rd2, %rd1, %r2; +--; SM35-NEXT: neg.s32 %r3, %r1; +--; SM35-NEXT: and.b32 %r4, %r3, 63; +--; SM35-NEXT: shl.b64 %rd3, %rd1, %r4; +--; SM35-NEXT: or.b64 %rd4, %rd2, %rd3; +--; SM35-NEXT: st.param.b64 [func_retval0+0], %rd4; +-+; SM35-NEXT: { +-+; SM35-NEXT: .reg .b64 %lhs; +-+; SM35-NEXT: .reg .b64 %rhs; +-+; SM35-NEXT: .reg .u32 %amt2; +-+; SM35-NEXT: and.b32 %amt2, %r1, 63; +-+; SM35-NEXT: shr.b64 %lhs, %rd1, %amt2; +-+; SM35-NEXT: sub.u32 %amt2, 64, %amt2; +-+; SM35-NEXT: shl.b64 %rhs, %rd1, %amt2; +-+; SM35-NEXT: add.u64 %rd2, %lhs, %rhs; +-+; SM35-NEXT: } +-+; SM35-NEXT: st.param.b64 [func_retval0+0], %rd2; +- ; SM35-NEXT: ret; +- %val = tail call i64 @llvm.fshr.i64(i64 %a, i64 %a, i64 %n) +- ret i64 %val +-@@ -269,180 +315,35 @@ +- define i64 @rotr64_imm(i64 %a) { +- ; SM20-LABEL: rotr64_imm( +- ; SM20: { +--; SM20-NEXT: .reg .b64 %rd<5>; +-+; SM20-NEXT: .reg .b64 %rd<3>; +- ; SM20-EMPTY: +- ; SM20-NEXT: // %bb.0: +- ; SM20-NEXT: ld.param.u64 %rd1, [rotr64_imm_param_0]; +--; SM20-NEXT: shl.b64 %rd2, %rd1, 62; +--; SM20-NEXT: shr.u64 %rd3, %rd1, 2; +--; SM20-NEXT: or.b64 %rd4, %rd3, %rd2; +--; SM20-NEXT: st.param.b64 [func_retval0+0], %rd4; +-+; SM20-NEXT: { +-+; SM20-NEXT: .reg .b64 %lhs; +-+; SM20-NEXT: .reg .b64 %rhs; +-+; SM20-NEXT: shl.b64 %lhs, %rd1, 62; +-+; SM20-NEXT: shr.b64 %rhs, %rd1, 2; +-+; SM20-NEXT: add.u64 %rd2, %lhs, %rhs; +-+; SM20-NEXT: } +-+; SM20-NEXT: st.param.b64 [func_retval0+0], %rd2; +- ; SM20-NEXT: ret; +- ; +- ; SM35-LABEL: rotr64_imm( +- ; SM35: { +--; SM35-NEXT: .reg .b64 %rd<5>; +-+; SM35-NEXT: .reg .b64 %rd<3>; +- ; SM35-EMPTY: +- ; SM35-NEXT: // %bb.0: +- ; SM35-NEXT: ld.param.u64 %rd1, [rotr64_imm_param_0]; +--; SM35-NEXT: shl.b64 %rd2, %rd1, 62; +--; SM35-NEXT: shr.u64 %rd3, %rd1, 2; +--; SM35-NEXT: or.b64 %rd4, %rd3, %rd2; +--; SM35-NEXT: st.param.b64 [func_retval0+0], %rd4; +-+; SM35-NEXT: { +-+; SM35-NEXT: .reg .b64 %lhs; +-+; SM35-NEXT: .reg .b64 %rhs; +-+; SM35-NEXT: shl.b64 %lhs, %rd1, 62; +-+; SM35-NEXT: shr.b64 %rhs, %rd1, 2; +-+; SM35-NEXT: add.u64 %rd2, %lhs, %rhs; +-+; SM35-NEXT: } +-+; SM35-NEXT: st.param.b64 [func_retval0+0], %rd2; +- ; SM35-NEXT: ret; +- %val = tail call i64 @llvm.fshr.i64(i64 %a, i64 %a, i64 66) +- ret i64 %val +- } +-- +--define i32 @funnel_shift_right_32(i32 %a, i32 %b, i32 %c) { +--; SM20-LABEL: funnel_shift_right_32( +--; SM20: { +--; SM20-NEXT: .reg .b32 %r<11>; +--; SM20-EMPTY: +--; SM20-NEXT: // %bb.0: +--; SM20-NEXT: ld.param.u32 %r1, [funnel_shift_right_32_param_0]; +--; SM20-NEXT: ld.param.u32 %r2, [funnel_shift_right_32_param_2]; +--; SM20-NEXT: and.b32 %r3, %r2, 31; +--; SM20-NEXT: ld.param.u32 %r4, [funnel_shift_right_32_param_1]; +--; SM20-NEXT: shr.u32 %r5, %r4, %r3; +--; SM20-NEXT: shl.b32 %r6, %r1, 1; +--; SM20-NEXT: not.b32 %r7, %r2; +--; SM20-NEXT: and.b32 %r8, %r7, 31; +--; SM20-NEXT: shl.b32 %r9, %r6, %r8; +--; SM20-NEXT: or.b32 %r10, %r9, %r5; +--; SM20-NEXT: st.param.b32 [func_retval0+0], %r10; +--; SM20-NEXT: ret; +--; +--; SM35-LABEL: funnel_shift_right_32( +--; SM35: { +--; SM35-NEXT: .reg .b32 %r<5>; +--; SM35-EMPTY: +--; SM35-NEXT: // %bb.0: +--; SM35-NEXT: ld.param.u32 %r1, [funnel_shift_right_32_param_0]; +--; SM35-NEXT: ld.param.u32 %r2, [funnel_shift_right_32_param_1]; +--; SM35-NEXT: ld.param.u32 %r3, [funnel_shift_right_32_param_2]; +--; SM35-NEXT: shf.r.wrap.b32 %r4, %r1, %r2, %r3; +--; SM35-NEXT: st.param.b32 [func_retval0+0], %r4; +--; SM35-NEXT: ret; +-- %val = call i32 @llvm.fshr.i32(i32 %a, i32 %b, i32 %c) +-- ret i32 %val +--} +-- +--define i32 @funnel_shift_left_32(i32 %a, i32 %b, i32 %c) { +--; SM20-LABEL: funnel_shift_left_32( +--; SM20: { +--; SM20-NEXT: .reg .b32 %r<11>; +--; SM20-EMPTY: +--; SM20-NEXT: // %bb.0: +--; SM20-NEXT: ld.param.u32 %r1, [funnel_shift_left_32_param_0]; +--; SM20-NEXT: ld.param.u32 %r2, [funnel_shift_left_32_param_2]; +--; SM20-NEXT: and.b32 %r3, %r2, 31; +--; SM20-NEXT: shl.b32 %r4, %r1, %r3; +--; SM20-NEXT: ld.param.u32 %r5, [funnel_shift_left_32_param_1]; +--; SM20-NEXT: shr.u32 %r6, %r5, 1; +--; SM20-NEXT: not.b32 %r7, %r2; +--; SM20-NEXT: and.b32 %r8, %r7, 31; +--; SM20-NEXT: shr.u32 %r9, %r6, %r8; +--; SM20-NEXT: or.b32 %r10, %r4, %r9; +--; SM20-NEXT: st.param.b32 [func_retval0+0], %r10; +--; SM20-NEXT: ret; +--; +--; SM35-LABEL: funnel_shift_left_32( +--; SM35: { +--; SM35-NEXT: .reg .b32 %r<5>; +--; SM35-EMPTY: +--; SM35-NEXT: // %bb.0: +--; SM35-NEXT: ld.param.u32 %r1, [funnel_shift_left_32_param_0]; +--; SM35-NEXT: ld.param.u32 %r2, [funnel_shift_left_32_param_1]; +--; SM35-NEXT: ld.param.u32 %r3, [funnel_shift_left_32_param_2]; +--; SM35-NEXT: shf.l.wrap.b32 %r4, %r1, %r2, %r3; +--; SM35-NEXT: st.param.b32 [func_retval0+0], %r4; +--; SM35-NEXT: ret; +-- %val = call i32 @llvm.fshl.i32(i32 %a, i32 %b, i32 %c) +-- ret i32 %val +--} +-- +--define i64 @funnel_shift_right_64(i64 %a, i64 %b, i64 %c) { +--; SM20-LABEL: funnel_shift_right_64( +--; SM20: { +--; SM20-NEXT: .reg .b32 %r<5>; +--; SM20-NEXT: .reg .b64 %rd<7>; +--; SM20-EMPTY: +--; SM20-NEXT: // %bb.0: +--; SM20-NEXT: ld.param.u64 %rd1, [funnel_shift_right_64_param_0]; +--; SM20-NEXT: ld.param.u32 %r1, [funnel_shift_right_64_param_2]; +--; SM20-NEXT: and.b32 %r2, %r1, 63; +--; SM20-NEXT: ld.param.u64 %rd2, [funnel_shift_right_64_param_1]; +--; SM20-NEXT: shr.u64 %rd3, %rd2, %r2; +--; SM20-NEXT: shl.b64 %rd4, %rd1, 1; +--; SM20-NEXT: not.b32 %r3, %r1; +--; SM20-NEXT: and.b32 %r4, %r3, 63; +--; SM20-NEXT: shl.b64 %rd5, %rd4, %r4; +--; SM20-NEXT: or.b64 %rd6, %rd5, %rd3; +--; SM20-NEXT: st.param.b64 [func_retval0+0], %rd6; +--; SM20-NEXT: ret; +--; +--; SM35-LABEL: funnel_shift_right_64( +--; SM35: { +--; SM35-NEXT: .reg .b32 %r<5>; +--; SM35-NEXT: .reg .b64 %rd<7>; +--; SM35-EMPTY: +--; SM35-NEXT: // %bb.0: +--; SM35-NEXT: ld.param.u64 %rd1, [funnel_shift_right_64_param_0]; +--; SM35-NEXT: ld.param.u32 %r1, [funnel_shift_right_64_param_2]; +--; SM35-NEXT: and.b32 %r2, %r1, 63; +--; SM35-NEXT: ld.param.u64 %rd2, [funnel_shift_right_64_param_1]; +--; SM35-NEXT: shr.u64 %rd3, %rd2, %r2; +--; SM35-NEXT: shl.b64 %rd4, %rd1, 1; +--; SM35-NEXT: not.b32 %r3, %r1; +--; SM35-NEXT: and.b32 %r4, %r3, 63; +--; SM35-NEXT: shl.b64 %rd5, %rd4, %r4; +--; SM35-NEXT: or.b64 %rd6, %rd5, %rd3; +--; SM35-NEXT: st.param.b64 [func_retval0+0], %rd6; +--; SM35-NEXT: ret; +-- %val = call i64 @llvm.fshr.i64(i64 %a, i64 %b, i64 %c) +-- ret i64 %val +--} +-- +--define i64 @funnel_shift_left_64(i64 %a, i64 %b, i64 %c) { +--; SM20-LABEL: funnel_shift_left_64( +--; SM20: { +--; SM20-NEXT: .reg .b32 %r<5>; +--; SM20-NEXT: .reg .b64 %rd<7>; +--; SM20-EMPTY: +--; SM20-NEXT: // %bb.0: +--; SM20-NEXT: ld.param.u64 %rd1, [funnel_shift_left_64_param_0]; +--; SM20-NEXT: ld.param.u32 %r1, [funnel_shift_left_64_param_2]; +--; SM20-NEXT: and.b32 %r2, %r1, 63; +--; SM20-NEXT: shl.b64 %rd2, %rd1, %r2; +--; SM20-NEXT: ld.param.u64 %rd3, [funnel_shift_left_64_param_1]; +--; SM20-NEXT: shr.u64 %rd4, %rd3, 1; +--; SM20-NEXT: not.b32 %r3, %r1; +--; SM20-NEXT: and.b32 %r4, %r3, 63; +--; SM20-NEXT: shr.u64 %rd5, %rd4, %r4; +--; SM20-NEXT: or.b64 %rd6, %rd2, %rd5; +--; SM20-NEXT: st.param.b64 [func_retval0+0], %rd6; +--; SM20-NEXT: ret; +--; +--; SM35-LABEL: funnel_shift_left_64( +--; SM35: { +--; SM35-NEXT: .reg .b32 %r<5>; +--; SM35-NEXT: .reg .b64 %rd<7>; +--; SM35-EMPTY: +--; SM35-NEXT: // %bb.0: +--; SM35-NEXT: ld.param.u64 %rd1, [funnel_shift_left_64_param_0]; +--; SM35-NEXT: ld.param.u32 %r1, [funnel_shift_left_64_param_2]; +--; SM35-NEXT: and.b32 %r2, %r1, 63; +--; SM35-NEXT: shl.b64 %rd2, %rd1, %r2; +--; SM35-NEXT: ld.param.u64 %rd3, [funnel_shift_left_64_param_1]; +--; SM35-NEXT: shr.u64 %rd4, %rd3, 1; +--; SM35-NEXT: not.b32 %r3, %r1; +--; SM35-NEXT: and.b32 %r4, %r3, 63; +--; SM35-NEXT: shr.u64 %rd5, %rd4, %r4; +--; SM35-NEXT: or.b64 %rd6, %rd2, %rd5; +--; SM35-NEXT: st.param.b64 [func_retval0+0], %rd6; +--; SM35-NEXT: ret; +-- %val = call i64 @llvm.fshl.i64(i64 %a, i64 %b, i64 %c) +-- ret i64 %val +--} +-- +-diff -ruN --strip-trailing-cr a/llvm/test/DebugInfo/NVPTX/debug-info.ll b/llvm/test/DebugInfo/NVPTX/debug-info.ll +---- a/llvm/test/DebugInfo/NVPTX/debug-info.ll +-+++ b/llvm/test/DebugInfo/NVPTX/debug-info.ll +-@@ -25,10 +25,6 @@ +- ; CHECK-DAG: .reg .b64 %rd<8>; +- ; CHECK: .loc [[DEBUG_INFO_CU:[0-9]+]] 5 0 +- ; CHECK: ld.param.u32 %r{{.+}}, [{{.+}}]; +--; CHECK: ld.param.u64 %rd{{.+}}, [{{.+}}]; +--; CHECK: cvta.to.global.u64 %rd{{.+}}, %rd{{.+}}; +--; CHECK: ld.param.u64 %rd{{.+}}, [{{.+}}]; +--; CHECK: cvta.to.global.u64 %rd{{.+}}, %rd{{.+}}; +- ; CHECK: .loc [[BUILTUIN_VARS_H:[0-9]+]] 78 180 +- ; CHECK: mov.u32 %r{{.+}}, %ctaid.x; +- ; CHECK: .loc [[BUILTUIN_VARS_H]] 89 180 +-@@ -42,6 +38,10 @@ +- ; CHECK: .loc [[DEBUG_INFO_CU]] 7 7 +- ; CHECK: @%p{{.+}} bra [[BB:\$L__.+]]; +- ; CHECK: ld.param.f32 %f{{.+}}, [{{.+}}]; +-+; CHECK: ld.param.u64 %rd{{.+}}, [{{.+}}]; +-+; CHECK: cvta.to.global.u64 %rd{{.+}}, %rd{{.+}}; +-+; CHECK: ld.param.u64 %rd{{.+}}, [{{.+}}]; +-+; CHECK: cvta.to.global.u64 %rd{{.+}}, %rd{{.+}}; +- ; CHECK: .loc [[DEBUG_INFO_CU]] 8 13 +- ; CHECK: mul.wide.u32 %rd{{.+}}, %r{{.+}}, 4; +- ; CHECK: add.s64 %rd{{.+}}, %rd{{.+}}, %rd{{.+}}; +-@@ -2661,22 +2661,22 @@ +- ; CHECK-NEXT:.b32 4579 // DW_AT_type +- ; CHECK-NEXT:.b8 25 // Abbrev [25] 0x8aa:0x18 DW_TAG_inlined_subroutine +- ; CHECK-NEXT:.b32 707 // DW_AT_abstract_origin +--; CHECK-NEXT:.b64 $L__tmp1 // DW_AT_low_pc +--; CHECK-NEXT:.b64 $L__tmp2 // DW_AT_high_pc +-+; CHECK-NEXT:.b64 $L__tmp0 // DW_AT_low_pc +-+; CHECK-NEXT:.b64 $L__tmp1 // DW_AT_high_pc +- ; CHECK-NEXT:.b8 1 // DW_AT_call_file +- ; CHECK-NEXT:.b8 6 // DW_AT_call_line +- ; CHECK-NEXT:.b8 11 // DW_AT_call_column +- ; CHECK-NEXT:.b8 25 // Abbrev [25] 0x8c2:0x18 DW_TAG_inlined_subroutine +- ; CHECK-NEXT:.b32 1466 // DW_AT_abstract_origin +--; CHECK-NEXT:.b64 $L__tmp2 // DW_AT_low_pc +--; CHECK-NEXT:.b64 $L__tmp3 // DW_AT_high_pc +-+; CHECK-NEXT:.b64 $L__tmp1 // DW_AT_low_pc +-+; CHECK-NEXT:.b64 $L__tmp2 // DW_AT_high_pc +- ; CHECK-NEXT:.b8 1 // DW_AT_call_file +- ; CHECK-NEXT:.b8 6 // DW_AT_call_line +- ; CHECK-NEXT:.b8 24 // DW_AT_call_column +- ; CHECK-NEXT:.b8 25 // Abbrev [25] 0x8da:0x18 DW_TAG_inlined_subroutine +- ; CHECK-NEXT:.b32 2060 // DW_AT_abstract_origin +--; CHECK-NEXT:.b64 $L__tmp3 // DW_AT_low_pc +--; CHECK-NEXT:.b64 $L__tmp4 // DW_AT_high_pc +-+; CHECK-NEXT:.b64 $L__tmp2 // DW_AT_low_pc +-+; CHECK-NEXT:.b64 $L__tmp3 // DW_AT_high_pc +- ; CHECK-NEXT:.b8 1 // DW_AT_call_file +- ; CHECK-NEXT:.b8 6 // DW_AT_call_line +- ; CHECK-NEXT:.b8 37 // DW_AT_call_column diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl -index abe15ef..af35fe7 100644 +index af35fe7..7b11086 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" -- LLVM_COMMIT = "df0864e761107b07e38f5503e0cbee0cebb4c5e8" -- LLVM_SHA256 = "5bfcb7306d9d40f420862ace1f7ad3f01979facfb16ffd1fc80b6d91e92019fa" -+ LLVM_COMMIT = "9830156f623c56062bf6df1b4c4b4bd8ab5bd57c" -+ LLVM_SHA256 = "85bb9a61cfdaf0d3386890dc7b4bbaa17eecf4b70b60c314307f2ca3919b9035" +- LLVM_COMMIT = "9830156f623c56062bf6df1b4c4b4bd8ab5bd57c" +- LLVM_SHA256 = "85bb9a61cfdaf0d3386890dc7b4bbaa17eecf4b70b60c314307f2ca3919b9035" ++ LLVM_COMMIT = "29b92d07746fac26cd64c914bc9c5c3833974f6d" ++ LLVM_SHA256 = "3e8e93e3749454af4b64f7f34b792a4748b62fc533bca1703d33b2b04e34eb70" tf_http_archive( name = name, diff --git a/third_party/shardy/workspace.bzl b/third_party/shardy/workspace.bzl index 4f6a0785270667..3ffc08a6fd8eb9 100644 --- a/third_party/shardy/workspace.bzl +++ b/third_party/shardy/workspace.bzl @@ -3,8 +3,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): - SHARDY_COMMIT = "22e68fa19cfb2d28434a75d4d20d0efc182b166a" - SHARDY_SHA256 = "2b47b0ee994feca2bd782e20aca7d709e29bc870c2ac435aca967f7664c9f949" + SHARDY_COMMIT = "c4642106cba935c06f437e542cb376bce8fbd16c" + SHARDY_SHA256 = "286661a749a4ed03dea624c15613bf615a357a720c5623e80f08479010cde42d" tf_http_archive( name = "shardy", diff --git a/third_party/xla/third_party/shardy/temporary.patch b/third_party/xla/third_party/shardy/temporary.patch index d3fd21823cce19..014b81b4e7518b 100644 --- a/third_party/xla/third_party/shardy/temporary.patch +++ b/third_party/xla/third_party/shardy/temporary.patch @@ -1,4115 +1,4115 @@ diff --git a/third_party/llvm/generated.patch b/third_party/llvm/generated.patch -index 509398d..de92cb4 100644 +index de92cb4..509398d 100644 --- a/third_party/llvm/generated.patch +++ b/third_party/llvm/generated.patch -@@ -1 +1,4095 @@ +@@ -1,4095 +1 @@ Auto generated patch. Do not edit or delete it, even if empty. -+diff -ruN --strip-trailing-cr a/llvm/docs/NVPTXUsage.rst b/llvm/docs/NVPTXUsage.rst -+--- a/llvm/docs/NVPTXUsage.rst -++++ b/llvm/docs/NVPTXUsage.rst -+@@ -127,6 +127,69 @@ -+ NVPTX Intrinsics -+ ================ -+ -++Address Space Conversion -++------------------------ -++ -++'``llvm.nvvm.ptr.*.to.gen``' Intrinsics -++^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -++ -++Syntax: -++""""""" -++ -++These are overloaded intrinsics. You can use these on any pointer types. -++ -++.. code-block:: llvm -++ -++ declare ptr @llvm.nvvm.ptr.global.to.gen.p0.p1(ptr addrspace(1)) -++ declare ptr @llvm.nvvm.ptr.shared.to.gen.p0.p3(ptr addrspace(3)) -++ declare ptr @llvm.nvvm.ptr.constant.to.gen.p0.p4(ptr addrspace(4)) -++ declare ptr @llvm.nvvm.ptr.local.to.gen.p0.p5(ptr addrspace(5)) -++ -++Overview: -++""""""""" -++ -++The '``llvm.nvvm.ptr.*.to.gen``' intrinsics convert a pointer in a non-generic -++address space to a generic address space pointer. -++ -++Semantics: -++"""""""""" -++ -++These intrinsics modify the pointer value to be a valid generic address space -++pointer. -++ -++ -++'``llvm.nvvm.ptr.gen.to.*``' Intrinsics -++^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -++ -++Syntax: -++""""""" -++ -++These are overloaded intrinsics. You can use these on any pointer types. -++ -++.. code-block:: llvm -++ -++ declare ptr addrspace(1) @llvm.nvvm.ptr.gen.to.global.p1.p0(ptr) -++ declare ptr addrspace(3) @llvm.nvvm.ptr.gen.to.shared.p3.p0(ptr) -++ declare ptr addrspace(4) @llvm.nvvm.ptr.gen.to.constant.p4.p0(ptr) -++ declare ptr addrspace(5) @llvm.nvvm.ptr.gen.to.local.p5.p0(ptr) -++ -++Overview: -++""""""""" -++ -++The '``llvm.nvvm.ptr.gen.to.*``' intrinsics convert a pointer in the generic -++address space to a pointer in the target address space. Note that these -++intrinsics are only useful if the address space of the target address space of -++the pointer is known. It is not legal to use address space conversion -++intrinsics to convert a pointer from one non-generic address space to another -++non-generic address space. -++ -++Semantics: -++"""""""""" -++ -++These intrinsics modify the pointer value to be a valid pointer in the target -++non-generic address space. -++ -++ -+ Reading PTX Special Registers -+ ----------------------------- -+ -+diff -ruN --strip-trailing-cr a/llvm/docs/ReleaseNotes.rst b/llvm/docs/ReleaseNotes.rst -+--- a/llvm/docs/ReleaseNotes.rst -++++ b/llvm/docs/ReleaseNotes.rst -+@@ -63,24 +63,6 @@ -+ * ``llvm.nvvm.bitcast.d2ll`` -+ * ``llvm.nvvm.bitcast.ll2d`` -+ -+-* Remove the following intrinsics which can be replaced with a funnel-shift: -+- -+- * ``llvm.nvvm.rotate.b32`` -+- * ``llvm.nvvm.rotate.right.b64`` -+- * ``llvm.nvvm.rotate.b64`` -+- -+-* Remove the following intrinsics which can be replaced with an -+- ``addrspacecast``: -+- -+- * ``llvm.nvvm.ptr.gen.to.global`` -+- * ``llvm.nvvm.ptr.gen.to.shared`` -+- * ``llvm.nvvm.ptr.gen.to.constant`` -+- * ``llvm.nvvm.ptr.gen.to.local`` -+- * ``llvm.nvvm.ptr.global.to.gen`` -+- * ``llvm.nvvm.ptr.shared.to.gen`` -+- * ``llvm.nvvm.ptr.constant.to.gen`` -+- * ``llvm.nvvm.ptr.local.to.gen`` -+- -+ Changes to LLVM infrastructure -+ ------------------------------ -+ -+diff -ruN --strip-trailing-cr a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td -+--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td -++++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td -+@@ -30,18 +30,10 @@ -+ // * llvm.nvvm.max.ui --> select(x ule y, x, y) -+ // * llvm.nvvm.max.ull --> ibid. -+ // * llvm.nvvm.h2f --> llvm.convert.to.fp16.f32 -+-// * llvm.nvvm.bitcast.f2i --> bitcast -+-// * llvm.nvvm.bitcast.i2f --> ibid. -+-// * llvm.nvvm.bitcast.d2ll --> ibid. -+-// * llvm.nvvm.bitcast.ll2d --> ibid. -+-// * llvm.nvvm.ptr.gen.to.global --> addrspacecast -+-// * llvm.nvvm.ptr.gen.to.shared --> ibid. -+-// * llvm.nvvm.ptr.gen.to.constant --> ibid. -+-// * llvm.nvvm.ptr.gen.to.local --> ibid. -+-// * llvm.nvvm.ptr.global.to.gen --> ibid. -+-// * llvm.nvvm.ptr.shared.to.gen --> ibid. -+-// * llvm.nvvm.ptr.constant.to.gen --> ibid. -+-// * llvm.nvvm.ptr.local.to.gen --> ibid. -++// * llvm.nvvm.bitcast.f2i --> bitcast -++// * llvm.nvvm.bitcast.i2f --> ibid. -++// * llvm.nvvm.bitcast.d2ll --> ibid. -++// * llvm.nvvm.bitcast.ll2d --> ibid. -+ -+ def llvm_global_ptr_ty : LLVMQualPointerType<1>; // (global)ptr -+ def llvm_shared_ptr_ty : LLVMQualPointerType<3>; // (shared)ptr -+@@ -1610,6 +1602,40 @@ -+ [IntrReadMem, IntrArgMemOnly, IntrNoCallback, IntrWillReturn, NoCapture>], -+ "llvm.nvvm.ldg.global.p">; -+ -++// Use for generic pointers -++// - These intrinsics are used to convert address spaces. -++// - The input pointer and output pointer must have the same type, except for -++// the address-space. (This restriction is not enforced here as there is -++// currently no way to describe it). -++// - This complements the llvm bitcast, which can be used to cast one type -++// of pointer to another type of pointer, while the address space remains -++// the same. -++def int_nvvm_ptr_local_to_gen: DefaultAttrsIntrinsic<[llvm_anyptr_ty], -++ [llvm_anyptr_ty], [IntrNoMem, IntrSpeculatable], -++ "llvm.nvvm.ptr.local.to.gen">; -++def int_nvvm_ptr_shared_to_gen: DefaultAttrsIntrinsic<[llvm_anyptr_ty], -++ [llvm_anyptr_ty], [IntrNoMem, IntrSpeculatable], -++ "llvm.nvvm.ptr.shared.to.gen">; -++def int_nvvm_ptr_global_to_gen: DefaultAttrsIntrinsic<[llvm_anyptr_ty], -++ [llvm_anyptr_ty], [IntrNoMem, IntrSpeculatable], -++ "llvm.nvvm.ptr.global.to.gen">; -++def int_nvvm_ptr_constant_to_gen: DefaultAttrsIntrinsic<[llvm_anyptr_ty], -++ [llvm_anyptr_ty], [IntrNoMem, IntrSpeculatable], -++ "llvm.nvvm.ptr.constant.to.gen">; -++ -++def int_nvvm_ptr_gen_to_global: DefaultAttrsIntrinsic<[llvm_anyptr_ty], -++ [llvm_anyptr_ty], [IntrNoMem, IntrSpeculatable], -++ "llvm.nvvm.ptr.gen.to.global">; -++def int_nvvm_ptr_gen_to_shared: DefaultAttrsIntrinsic<[llvm_anyptr_ty], -++ [llvm_anyptr_ty], [IntrNoMem, IntrSpeculatable], -++ "llvm.nvvm.ptr.gen.to.shared">; -++def int_nvvm_ptr_gen_to_local: DefaultAttrsIntrinsic<[llvm_anyptr_ty], -++ [llvm_anyptr_ty], [IntrNoMem, IntrSpeculatable], -++ "llvm.nvvm.ptr.gen.to.local">; -++def int_nvvm_ptr_gen_to_constant: DefaultAttrsIntrinsic<[llvm_anyptr_ty], -++ [llvm_anyptr_ty], [IntrNoMem, IntrSpeculatable], -++ "llvm.nvvm.ptr.gen.to.constant">; -++ -+ // Used in nvvm internally to help address space opt and ptx code generation -+ // This is for params that are passed to kernel functions by pointer by-val. -+ def int_nvvm_ptr_gen_to_param: Intrinsic<[llvm_anyptr_ty], -+@@ -4453,6 +4479,22 @@ -+ "llvm.nvvm.sust.p.3d.v4i32.trap">, -+ ClangBuiltin<"__nvvm_sust_p_3d_v4i32_trap">; -+ -++ -++def int_nvvm_rotate_b32 -++ : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty], -++ [IntrNoMem, IntrSpeculatable], "llvm.nvvm.rotate.b32">, -++ ClangBuiltin<"__nvvm_rotate_b32">; -++ -++def int_nvvm_rotate_b64 -++ : DefaultAttrsIntrinsic<[llvm_i64_ty], [llvm_i64_ty, llvm_i32_ty], -++ [IntrNoMem, IntrSpeculatable], "llvm.nvvm.rotate.b64">, -++ ClangBuiltin<"__nvvm_rotate_b64">; -++ -++def int_nvvm_rotate_right_b64 -++ : DefaultAttrsIntrinsic<[llvm_i64_ty], [llvm_i64_ty, llvm_i32_ty], -++ [IntrNoMem, IntrSpeculatable], "llvm.nvvm.rotate.right.b64">, -++ ClangBuiltin<"__nvvm_rotate_right_b64">; -++ -+ def int_nvvm_swap_lo_hi_b64 -+ : DefaultAttrsIntrinsic<[llvm_i64_ty], [llvm_i64_ty], -+ [IntrNoMem, IntrSpeculatable], "llvm.nvvm.swap.lo.hi.b64">, -+diff -ruN --strip-trailing-cr a/llvm/lib/IR/AutoUpgrade.cpp b/llvm/lib/IR/AutoUpgrade.cpp -+--- a/llvm/lib/IR/AutoUpgrade.cpp -++++ b/llvm/lib/IR/AutoUpgrade.cpp -+@@ -1272,19 +1272,6 @@ -+ // nvvm.bitcast.{f2i,i2f,ll2d,d2ll} -+ Expand = -+ Name == "f2i" || Name == "i2f" || Name == "ll2d" || Name == "d2ll"; -+- else if (Name.consume_front("rotate.")) -+- // nvvm.rotate.{b32,b64,right.b64} -+- Expand = Name == "b32" || Name == "b64" || Name == "right.b64"; -+- else if (Name.consume_front("ptr.gen.to.")) -+- // nvvm.ptr.gen.to.{local,shared,global,constant} -+- Expand = Name.starts_with("local") || Name.starts_with("shared") || -+- Name.starts_with("global") || Name.starts_with("constant"); -+- else if (Name.consume_front("ptr.")) -+- // nvvm.ptr.{local,shared,global,constant}.to.gen -+- Expand = -+- (Name.consume_front("local") || Name.consume_front("shared") || -+- Name.consume_front("global") || Name.consume_front("constant")) && -+- Name.starts_with(".to.gen"); -+ else -+ Expand = false; -+ -+@@ -2271,117 +2258,6 @@ -+ } -+ } -+ -+-static Value *upgradeNVVMIntrinsicCall(StringRef Name, CallBase *CI, -+- Function *F, IRBuilder<> &Builder) { -+- Value *Rep = nullptr; -+- -+- if (Name == "abs.i" || Name == "abs.ll") { -+- Value *Arg = CI->getArgOperand(0); -+- Value *Neg = Builder.CreateNeg(Arg, "neg"); -+- Value *Cmp = Builder.CreateICmpSGE( -+- Arg, llvm::Constant::getNullValue(Arg->getType()), "abs.cond"); -+- Rep = Builder.CreateSelect(Cmp, Arg, Neg, "abs"); -+- } else if (Name.starts_with("atomic.load.add.f32.p") || -+- Name.starts_with("atomic.load.add.f64.p")) { -+- Value *Ptr = CI->getArgOperand(0); -+- Value *Val = CI->getArgOperand(1); -+- Rep = Builder.CreateAtomicRMW(AtomicRMWInst::FAdd, Ptr, Val, MaybeAlign(), -+- AtomicOrdering::SequentiallyConsistent); -+- } else if (Name.consume_front("max.") && -+- (Name == "s" || Name == "i" || Name == "ll" || Name == "us" || -+- Name == "ui" || Name == "ull")) { -+- Value *Arg0 = CI->getArgOperand(0); -+- Value *Arg1 = CI->getArgOperand(1); -+- Value *Cmp = Name.starts_with("u") -+- ? Builder.CreateICmpUGE(Arg0, Arg1, "max.cond") -+- : Builder.CreateICmpSGE(Arg0, Arg1, "max.cond"); -+- Rep = Builder.CreateSelect(Cmp, Arg0, Arg1, "max"); -+- } else if (Name.consume_front("min.") && -+- (Name == "s" || Name == "i" || Name == "ll" || Name == "us" || -+- Name == "ui" || Name == "ull")) { -+- Value *Arg0 = CI->getArgOperand(0); -+- Value *Arg1 = CI->getArgOperand(1); -+- Value *Cmp = Name.starts_with("u") -+- ? Builder.CreateICmpULE(Arg0, Arg1, "min.cond") -+- : Builder.CreateICmpSLE(Arg0, Arg1, "min.cond"); -+- Rep = Builder.CreateSelect(Cmp, Arg0, Arg1, "min"); -+- } else if (Name == "clz.ll") { -+- // llvm.nvvm.clz.ll returns an i32, but llvm.ctlz.i64 returns an i64. -+- Value *Arg = CI->getArgOperand(0); -+- Value *Ctlz = Builder.CreateCall( -+- Intrinsic::getDeclaration(F->getParent(), Intrinsic::ctlz, -+- {Arg->getType()}), -+- {Arg, Builder.getFalse()}, "ctlz"); -+- Rep = Builder.CreateTrunc(Ctlz, Builder.getInt32Ty(), "ctlz.trunc"); -+- } else if (Name == "popc.ll") { -+- // llvm.nvvm.popc.ll returns an i32, but llvm.ctpop.i64 returns an -+- // i64. -+- Value *Arg = CI->getArgOperand(0); -+- Value *Popc = Builder.CreateCall( -+- Intrinsic::getDeclaration(F->getParent(), Intrinsic::ctpop, -+- {Arg->getType()}), -+- Arg, "ctpop"); -+- Rep = Builder.CreateTrunc(Popc, Builder.getInt32Ty(), "ctpop.trunc"); -+- } else if (Name == "h2f") { -+- Rep = Builder.CreateCall( -+- Intrinsic::getDeclaration(F->getParent(), Intrinsic::convert_from_fp16, -+- {Builder.getFloatTy()}), -+- CI->getArgOperand(0), "h2f"); -+- } else if (Name.consume_front("bitcast.") && -+- (Name == "f2i" || Name == "i2f" || Name == "ll2d" || -+- Name == "d2ll")) { -+- Rep = Builder.CreateBitCast(CI->getArgOperand(0), CI->getType()); -+- } else if (Name == "rotate.b32") { -+- Value *Arg = CI->getOperand(0); -+- Value *ShiftAmt = CI->getOperand(1); -+- Rep = Builder.CreateIntrinsic(Builder.getInt32Ty(), Intrinsic::fshl, -+- {Arg, Arg, ShiftAmt}); -+- } else if (Name == "rotate.b64") { -+- Type *Int64Ty = Builder.getInt64Ty(); -+- Value *Arg = CI->getOperand(0); -+- Value *ZExtShiftAmt = Builder.CreateZExt(CI->getOperand(1), Int64Ty); -+- Rep = Builder.CreateIntrinsic(Int64Ty, Intrinsic::fshl, -+- {Arg, Arg, ZExtShiftAmt}); -+- } else if (Name == "rotate.right.b64") { -+- Type *Int64Ty = Builder.getInt64Ty(); -+- Value *Arg = CI->getOperand(0); -+- Value *ZExtShiftAmt = Builder.CreateZExt(CI->getOperand(1), Int64Ty); -+- Rep = Builder.CreateIntrinsic(Int64Ty, Intrinsic::fshr, -+- {Arg, Arg, ZExtShiftAmt}); -+- } else if ((Name.consume_front("ptr.gen.to.") && -+- (Name.starts_with("local") || Name.starts_with("shared") || -+- Name.starts_with("global") || Name.starts_with("constant"))) || -+- (Name.consume_front("ptr.") && -+- (Name.consume_front("local") || Name.consume_front("shared") || -+- Name.consume_front("global") || -+- Name.consume_front("constant")) && -+- Name.starts_with(".to.gen"))) { -+- Rep = Builder.CreateAddrSpaceCast(CI->getArgOperand(0), CI->getType()); -+- } else { -+- Intrinsic::ID IID = shouldUpgradeNVPTXBF16Intrinsic(Name); -+- if (IID != Intrinsic::not_intrinsic && -+- !F->getReturnType()->getScalarType()->isBFloatTy()) { -+- rename(F); -+- Function *NewFn = Intrinsic::getDeclaration(F->getParent(), IID); -+- SmallVector Args; -+- for (size_t I = 0; I < NewFn->arg_size(); ++I) { -+- Value *Arg = CI->getArgOperand(I); -+- Type *OldType = Arg->getType(); -+- Type *NewType = NewFn->getArg(I)->getType(); -+- Args.push_back( -+- (OldType->isIntegerTy() && NewType->getScalarType()->isBFloatTy()) -+- ? Builder.CreateBitCast(Arg, NewType) -+- : Arg); -+- } -+- Rep = Builder.CreateCall(NewFn, Args); -+- if (F->getReturnType()->isIntegerTy()) -+- Rep = Builder.CreateBitCast(Rep, F->getReturnType()); -+- } -+- } -+- -+- return Rep; -+-} -+- -+ static Value *upgradeX86IntrinsicCall(StringRef Name, CallBase *CI, Function *F, -+ IRBuilder<> &Builder) { -+ LLVMContext &C = F->getContext(); -+@@ -4332,8 +4208,85 @@ -+ -+ if (!IsX86 && Name == "stackprotectorcheck") { -+ Rep = nullptr; -++ } else if (IsNVVM && (Name == "abs.i" || Name == "abs.ll")) { -++ Value *Arg = CI->getArgOperand(0); -++ Value *Neg = Builder.CreateNeg(Arg, "neg"); -++ Value *Cmp = Builder.CreateICmpSGE( -++ Arg, llvm::Constant::getNullValue(Arg->getType()), "abs.cond"); -++ Rep = Builder.CreateSelect(Cmp, Arg, Neg, "abs"); -++ } else if (IsNVVM && (Name.starts_with("atomic.load.add.f32.p") || -++ Name.starts_with("atomic.load.add.f64.p"))) { -++ Value *Ptr = CI->getArgOperand(0); -++ Value *Val = CI->getArgOperand(1); -++ Rep = Builder.CreateAtomicRMW(AtomicRMWInst::FAdd, Ptr, Val, MaybeAlign(), -++ AtomicOrdering::SequentiallyConsistent); -++ } else if (IsNVVM && Name.consume_front("max.") && -++ (Name == "s" || Name == "i" || Name == "ll" || Name == "us" || -++ Name == "ui" || Name == "ull")) { -++ Value *Arg0 = CI->getArgOperand(0); -++ Value *Arg1 = CI->getArgOperand(1); -++ Value *Cmp = Name.starts_with("u") -++ ? Builder.CreateICmpUGE(Arg0, Arg1, "max.cond") -++ : Builder.CreateICmpSGE(Arg0, Arg1, "max.cond"); -++ Rep = Builder.CreateSelect(Cmp, Arg0, Arg1, "max"); -++ } else if (IsNVVM && Name.consume_front("min.") && -++ (Name == "s" || Name == "i" || Name == "ll" || Name == "us" || -++ Name == "ui" || Name == "ull")) { -++ Value *Arg0 = CI->getArgOperand(0); -++ Value *Arg1 = CI->getArgOperand(1); -++ Value *Cmp = Name.starts_with("u") -++ ? Builder.CreateICmpULE(Arg0, Arg1, "min.cond") -++ : Builder.CreateICmpSLE(Arg0, Arg1, "min.cond"); -++ Rep = Builder.CreateSelect(Cmp, Arg0, Arg1, "min"); -++ } else if (IsNVVM && Name == "clz.ll") { -++ // llvm.nvvm.clz.ll returns an i32, but llvm.ctlz.i64 returns an i64. -++ Value *Arg = CI->getArgOperand(0); -++ Value *Ctlz = Builder.CreateCall( -++ Intrinsic::getDeclaration(F->getParent(), Intrinsic::ctlz, -++ {Arg->getType()}), -++ {Arg, Builder.getFalse()}, "ctlz"); -++ Rep = Builder.CreateTrunc(Ctlz, Builder.getInt32Ty(), "ctlz.trunc"); -++ } else if (IsNVVM && Name == "popc.ll") { -++ // llvm.nvvm.popc.ll returns an i32, but llvm.ctpop.i64 returns an -++ // i64. -++ Value *Arg = CI->getArgOperand(0); -++ Value *Popc = Builder.CreateCall( -++ Intrinsic::getDeclaration(F->getParent(), Intrinsic::ctpop, -++ {Arg->getType()}), -++ Arg, "ctpop"); -++ Rep = Builder.CreateTrunc(Popc, Builder.getInt32Ty(), "ctpop.trunc"); -+ } else if (IsNVVM) { -+- Rep = upgradeNVVMIntrinsicCall(Name, CI, F, Builder); -++ if (Name == "h2f") { -++ Rep = -++ Builder.CreateCall(Intrinsic::getDeclaration( -++ F->getParent(), Intrinsic::convert_from_fp16, -++ {Builder.getFloatTy()}), -++ CI->getArgOperand(0), "h2f"); -++ } else if (Name.consume_front("bitcast.") && -++ (Name == "f2i" || Name == "i2f" || Name == "ll2d" || -++ Name == "d2ll")) { -++ Rep = Builder.CreateBitCast(CI->getArgOperand(0), CI->getType()); -++ } else { -++ Intrinsic::ID IID = shouldUpgradeNVPTXBF16Intrinsic(Name); -++ if (IID != Intrinsic::not_intrinsic && -++ !F->getReturnType()->getScalarType()->isBFloatTy()) { -++ rename(F); -++ NewFn = Intrinsic::getDeclaration(F->getParent(), IID); -++ SmallVector Args; -++ for (size_t I = 0; I < NewFn->arg_size(); ++I) { -++ Value *Arg = CI->getArgOperand(I); -++ Type *OldType = Arg->getType(); -++ Type *NewType = NewFn->getArg(I)->getType(); -++ Args.push_back((OldType->isIntegerTy() && -++ NewType->getScalarType()->isBFloatTy()) -++ ? Builder.CreateBitCast(Arg, NewType) -++ : Arg); -++ } -++ Rep = Builder.CreateCall(NewFn, Args); -++ if (F->getReturnType()->isIntegerTy()) -++ Rep = Builder.CreateBitCast(Rep, F->getReturnType()); -++ } -++ } -+ } else if (IsX86) { -+ Rep = upgradeX86IntrinsicCall(Name, CI, F, Builder); -+ } else if (IsARM) { -+diff -ruN --strip-trailing-cr a/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp b/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp -+--- a/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp -++++ b/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp -+@@ -292,7 +292,6 @@ -+ static const LLT S224 = LLT::scalar(224); -+ static const LLT S256 = LLT::scalar(256); -+ static const LLT S512 = LLT::scalar(512); -+-static const LLT S1024 = LLT::scalar(1024); -+ static const LLT MaxScalar = LLT::scalar(MaxRegisterSize); -+ -+ static const LLT V2S8 = LLT::fixed_vector(2, 8); -+@@ -333,8 +332,8 @@ -+ static const LLT V2S128 = LLT::fixed_vector(2, 128); -+ static const LLT V4S128 = LLT::fixed_vector(4, 128); -+ -+-static std::initializer_list AllScalarTypes = { -+- S32, S64, S96, S128, S160, S224, S256, S512, S1024}; -++static std::initializer_list AllScalarTypes = {S32, S64, S96, S128, -++ S160, S224, S256, S512}; -+ -+ static std::initializer_list AllS16Vectors{ -+ V2S16, V4S16, V6S16, V8S16, V10S16, V12S16, V16S16, V2S128, V4S128}; -+@@ -890,11 +889,10 @@ -+ .clampScalar(0, S16, S64); -+ -+ getActionDefinitionsBuilder({G_IMPLICIT_DEF, G_FREEZE}) -+- .legalIf(isRegisterClassType(0)) -++ .legalIf(isRegisterType(0)) -+ // s1 and s16 are special cases because they have legal operations on -+ // them, but don't really occupy registers in the normal way. -+ .legalFor({S1, S16}) -+- .clampNumElements(0, V16S32, V32S32) -+ .moreElementsIf(isSmallOddVector(0), oneMoreElement(0)) -+ .clampScalarOrElt(0, S32, MaxScalar) -+ .widenScalarToNextPow2(0, 32) -+diff -ruN --strip-trailing-cr a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td -+--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td -++++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td -+@@ -174,6 +174,10 @@ -+ def hasSHFL : Predicate<"!(Subtarget->getSmVersion() >= 70" -+ "&& Subtarget->getPTXVersion() >= 64)">; -+ -++def useShortPtrLocal : Predicate<"TM.is64Bit() && TM.getPointerSizeInBits(ADDRESS_SPACE_LOCAL) == 32">; -++def useShortPtrShared : Predicate<"TM.is64Bit() && TM.getPointerSizeInBits(ADDRESS_SPACE_SHARED) == 32">; -++def useShortPtrConst : Predicate<"TM.is64Bit() && TM.getPointerSizeInBits(ADDRESS_SPACE_CONST) == 32">; -++ -+ def useFP16Math: Predicate<"Subtarget->allowFP16Math()">; -+ def hasBF16Math: Predicate<"Subtarget->hasBF16Math()">; -+ -+@@ -1661,6 +1665,167 @@ -+ "brev.b64 \t$dst, $a;", -+ [(set Int64Regs:$dst, (bitreverse Int64Regs:$a))]>; -+ -++// -++// Rotate: Use ptx shf instruction if available. -++// -++ -++// 32 bit r2 = rotl r1, n -++// => -++// r2 = shf.l r1, r1, n -++def ROTL32imm_hw : -++ NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$src, i32imm:$amt), -++ "shf.l.wrap.b32 \t$dst, $src, $src, $amt;", -++ [(set Int32Regs:$dst, (rotl (i32 Int32Regs:$src), (i32 imm:$amt)))]>, -++ Requires<[hasHWROT32]>; -++ -++def ROTL32reg_hw : -++ NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$src, Int32Regs:$amt), -++ "shf.l.wrap.b32 \t$dst, $src, $src, $amt;", -++ [(set Int32Regs:$dst, (rotl (i32 Int32Regs:$src), (i32 Int32Regs:$amt)))]>, -++ Requires<[hasHWROT32]>; -++ -++// 32 bit r2 = rotr r1, n -++// => -++// r2 = shf.r r1, r1, n -++def ROTR32imm_hw : -++ NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$src, i32imm:$amt), -++ "shf.r.wrap.b32 \t$dst, $src, $src, $amt;", -++ [(set Int32Regs:$dst, (rotr (i32 Int32Regs:$src), (i32 imm:$amt)))]>, -++ Requires<[hasHWROT32]>; -++ -++def ROTR32reg_hw : -++ NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$src, Int32Regs:$amt), -++ "shf.r.wrap.b32 \t$dst, $src, $src, $amt;", -++ [(set Int32Regs:$dst, (rotr (i32 Int32Regs:$src), (i32 Int32Regs:$amt)))]>, -++ Requires<[hasHWROT32]>; -++ -++// 32-bit software rotate by immediate. $amt2 should equal 32 - $amt1. -++def ROT32imm_sw : -++ NVPTXInst<(outs Int32Regs:$dst), -++ (ins Int32Regs:$src, i32imm:$amt1, i32imm:$amt2), -++ "{{\n\t" -++ ".reg .b32 %lhs;\n\t" -++ ".reg .b32 %rhs;\n\t" -++ "shl.b32 \t%lhs, $src, $amt1;\n\t" -++ "shr.b32 \t%rhs, $src, $amt2;\n\t" -++ "add.u32 \t$dst, %lhs, %rhs;\n\t" -++ "}}", -++ []>; -++ -++def SUB_FRM_32 : SDNodeXFormgetTargetConstant(32 - N->getZExtValue(), SDLoc(N), MVT::i32); -++}]>; -++ -++def : Pat<(rotl (i32 Int32Regs:$src), (i32 imm:$amt)), -++ (ROT32imm_sw Int32Regs:$src, imm:$amt, (SUB_FRM_32 node:$amt))>, -++ Requires<[noHWROT32]>; -++def : Pat<(rotr (i32 Int32Regs:$src), (i32 imm:$amt)), -++ (ROT32imm_sw Int32Regs:$src, (SUB_FRM_32 node:$amt), imm:$amt)>, -++ Requires<[noHWROT32]>; -++ -++// 32-bit software rotate left by register. -++def ROTL32reg_sw : -++ NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$src, Int32Regs:$amt), -++ "{{\n\t" -++ ".reg .b32 %lhs;\n\t" -++ ".reg .b32 %rhs;\n\t" -++ ".reg .b32 %amt2;\n\t" -++ "shl.b32 \t%lhs, $src, $amt;\n\t" -++ "sub.s32 \t%amt2, 32, $amt;\n\t" -++ "shr.b32 \t%rhs, $src, %amt2;\n\t" -++ "add.u32 \t$dst, %lhs, %rhs;\n\t" -++ "}}", -++ [(set Int32Regs:$dst, (rotl (i32 Int32Regs:$src), (i32 Int32Regs:$amt)))]>, -++ Requires<[noHWROT32]>; -++ -++// 32-bit software rotate right by register. -++def ROTR32reg_sw : -++ NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$src, Int32Regs:$amt), -++ "{{\n\t" -++ ".reg .b32 %lhs;\n\t" -++ ".reg .b32 %rhs;\n\t" -++ ".reg .b32 %amt2;\n\t" -++ "shr.b32 \t%lhs, $src, $amt;\n\t" -++ "sub.s32 \t%amt2, 32, $amt;\n\t" -++ "shl.b32 \t%rhs, $src, %amt2;\n\t" -++ "add.u32 \t$dst, %lhs, %rhs;\n\t" -++ "}}", -++ [(set Int32Regs:$dst, (rotr (i32 Int32Regs:$src), (i32 Int32Regs:$amt)))]>, -++ Requires<[noHWROT32]>; -++ -++// 64-bit software rotate by immediate. $amt2 should equal 64 - $amt1. -++def ROT64imm_sw : -++ NVPTXInst<(outs Int64Regs:$dst), -++ (ins Int64Regs:$src, i32imm:$amt1, i32imm:$amt2), -++ "{{\n\t" -++ ".reg .b64 %lhs;\n\t" -++ ".reg .b64 %rhs;\n\t" -++ "shl.b64 \t%lhs, $src, $amt1;\n\t" -++ "shr.b64 \t%rhs, $src, $amt2;\n\t" -++ "add.u64 \t$dst, %lhs, %rhs;\n\t" -++ "}}", -++ []>; -++ -++def SUB_FRM_64 : SDNodeXFormgetTargetConstant(64-N->getZExtValue(), SDLoc(N), MVT::i32); -++}]>; -++ -++def : Pat<(rotl Int64Regs:$src, (i32 imm:$amt)), -++ (ROT64imm_sw Int64Regs:$src, imm:$amt, (SUB_FRM_64 node:$amt))>; -++def : Pat<(rotr Int64Regs:$src, (i32 imm:$amt)), -++ (ROT64imm_sw Int64Regs:$src, (SUB_FRM_64 node:$amt), imm:$amt)>; -++ -++// 64-bit software rotate left by register. -++def ROTL64reg_sw : -++ NVPTXInst<(outs Int64Regs:$dst), (ins Int64Regs:$src, Int32Regs:$amt), -++ "{{\n\t" -++ ".reg .b64 %lhs;\n\t" -++ ".reg .b64 %rhs;\n\t" -++ ".reg .u32 %amt2;\n\t" -++ "and.b32 \t%amt2, $amt, 63;\n\t" -++ "shl.b64 \t%lhs, $src, %amt2;\n\t" -++ "sub.u32 \t%amt2, 64, %amt2;\n\t" -++ "shr.b64 \t%rhs, $src, %amt2;\n\t" -++ "add.u64 \t$dst, %lhs, %rhs;\n\t" -++ "}}", -++ [(set Int64Regs:$dst, (rotl Int64Regs:$src, (i32 Int32Regs:$amt)))]>; -++ -++def ROTR64reg_sw : -++ NVPTXInst<(outs Int64Regs:$dst), (ins Int64Regs:$src, Int32Regs:$amt), -++ "{{\n\t" -++ ".reg .b64 %lhs;\n\t" -++ ".reg .b64 %rhs;\n\t" -++ ".reg .u32 %amt2;\n\t" -++ "and.b32 \t%amt2, $amt, 63;\n\t" -++ "shr.b64 \t%lhs, $src, %amt2;\n\t" -++ "sub.u32 \t%amt2, 64, %amt2;\n\t" -++ "shl.b64 \t%rhs, $src, %amt2;\n\t" -++ "add.u64 \t$dst, %lhs, %rhs;\n\t" -++ "}}", -++ [(set Int64Regs:$dst, (rotr Int64Regs:$src, (i32 Int32Regs:$amt)))]>; -++ -++// -++// Funnnel shift in clamp mode -++// -++ -++// Create SDNodes so they can be used in the DAG code, e.g. -++// NVPTXISelLowering (LowerShiftLeftParts and LowerShiftRightParts) -++def FUN_SHFL_CLAMP : SDNode<"NVPTXISD::FUN_SHFL_CLAMP", SDTIntShiftDOp, []>; -++def FUN_SHFR_CLAMP : SDNode<"NVPTXISD::FUN_SHFR_CLAMP", SDTIntShiftDOp, []>; -++ -++def FUNSHFLCLAMP : -++ NVPTXInst<(outs Int32Regs:$dst), -++ (ins Int32Regs:$lo, Int32Regs:$hi, Int32Regs:$amt), -++ "shf.l.clamp.b32 \t$dst, $lo, $hi, $amt;", -++ [(set Int32Regs:$dst, -++ (FUN_SHFL_CLAMP (i32 Int32Regs:$lo), (i32 Int32Regs:$hi), (i32 Int32Regs:$amt)))]>; -++ -++def FUNSHFRCLAMP : -++ NVPTXInst<(outs Int32Regs:$dst), -++ (ins Int32Regs:$lo, Int32Regs:$hi, Int32Regs:$amt), -++ "shf.r.clamp.b32 \t$dst, $lo, $hi, $amt;", -++ [(set Int32Regs:$dst, -++ (FUN_SHFR_CLAMP (i32 Int32Regs:$lo), (i32 Int32Regs:$hi), (i32 Int32Regs:$amt)))]>; -+ -+ // -+ // BFE - bit-field extract -+@@ -3492,42 +3657,6 @@ -+ def: Pat<(v2i16 (scalar_to_vector (i16 Int16Regs:$a))), -+ (CVT_u32_u16 Int16Regs:$a, CvtNONE)>; -+ -+-// -+-// Funnel-Shift -+-// -+- -+-// Create SDNodes so they can be used in the DAG code, e.g. -+-// NVPTXISelLowering (LowerShiftLeftParts and LowerShiftRightParts) -+-def fshl_clamp : SDNode<"NVPTXISD::FUN_SHFL_CLAMP", SDTIntShiftDOp, []>; -+-def fshr_clamp : SDNode<"NVPTXISD::FUN_SHFR_CLAMP", SDTIntShiftDOp, []>; -+- -+-// Funnel shift, requires >= sm_32. Does not trap if amt is out of range, so -+-// no side effects. -+-let hasSideEffects = false in { -+- multiclass ShfInst { -+- def _i -+- : NVPTXInst<(outs Int32Regs:$dst), -+- (ins Int32Regs:$lo, Int32Regs:$hi, i32imm:$amt), -+- "shf." # mode # ".b32 \t$dst, $lo, $hi, $amt;", -+- [(set Int32Regs:$dst, -+- (op (i32 Int32Regs:$lo), (i32 Int32Regs:$hi), (i32 imm:$amt)))]>, -+- Requires<[hasHWROT32]>; -+- -+- def _r -+- : NVPTXInst<(outs Int32Regs:$dst), -+- (ins Int32Regs:$lo, Int32Regs:$hi, Int32Regs:$amt), -+- "shf." # mode # ".b32 \t$dst, $lo, $hi, $amt;", -+- [(set Int32Regs:$dst, -+- (op (i32 Int32Regs:$lo), (i32 Int32Regs:$hi), (i32 Int32Regs:$amt)))]>, -+- Requires<[hasHWROT32]>; -+- } -+- -+- defm SHF_L_CLAMP : ShfInst<"l.clamp", fshl_clamp>; -+- defm SHF_R_CLAMP : ShfInst<"r.clamp", fshr_clamp>; -+- defm SHF_L_WRAP : ShfInst<"l.wrap", fshl>; -+- defm SHF_R_WRAP : ShfInst<"r.wrap", fshr>; -+-} -+- -+ // Count leading zeros -+ let hasSideEffects = false in { -+ def CLZr32 : NVPTXInst<(outs Int32Regs:$d), (ins Int32Regs:$a), -+diff -ruN --strip-trailing-cr a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td -+--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td -++++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td -+@@ -2537,45 +2537,59 @@ -+ : VLDG_G_ELE_V4<"v4.f32 \t{{$dst1, $dst2, $dst3, $dst4}}, [$src];", Float32Regs>; -+ -+ -+-multiclass NG_TO_G { -++multiclass NG_TO_G { -+ def "" : NVPTXInst<(outs Int32Regs:$result), (ins Int32Regs:$src), -+- "cvta." # Str # ".u32 \t$result, $src;", []>; -++ !strconcat("cvta.", Str, ".u32 \t$result, $src;"), -++ [(set Int32Regs:$result, (Intrin Int32Regs:$src))]>; -+ def _64 : NVPTXInst<(outs Int64Regs:$result), (ins Int64Regs:$src), -+- "cvta." # Str # ".u64 \t$result, $src;", []>; -++ !strconcat("cvta.", Str, ".u64 \t$result, $src;"), -++ [(set Int64Regs:$result, (Intrin Int64Regs:$src))]>; -++ def _6432 : NVPTXInst<(outs Int64Regs:$result), (ins Int32Regs:$src), -++ "{{ .reg .b64 %tmp;\n\t" -++ #" cvt.u64.u32 \t%tmp, $src;\n\t" -++ #" cvta." # Str # ".u64 \t$result, %tmp; }}", -++ [(set Int64Regs:$result, (Intrin Int32Regs:$src))]>, -++ Requires<[ShortPtr]>; -+ } -+ -+-multiclass G_TO_NG { -++multiclass G_TO_NG { -+ def "" : NVPTXInst<(outs Int32Regs:$result), (ins Int32Regs:$src), -+- "cvta.to." # Str # ".u32 \t$result, $src;", []>; -++ !strconcat("cvta.to.", Str, ".u32 \t$result, $src;"), -++ [(set Int32Regs:$result, (Intrin Int32Regs:$src))]>; -+ def _64 : NVPTXInst<(outs Int64Regs:$result), (ins Int64Regs:$src), -+- "cvta.to." # Str # ".u64 \t$result, $src;", []>; -++ !strconcat("cvta.to.", Str, ".u64 \t$result, $src;"), -++ [(set Int64Regs:$result, (Intrin Int64Regs:$src))]>; -++ def _3264 : NVPTXInst<(outs Int32Regs:$result), (ins Int64Regs:$src), -++ "{{ .reg .b64 %tmp;\n\t" -++ #" cvta.to." # Str # ".u64 \t%tmp, $src;\n\t" -++ #" cvt.u32.u64 \t$result, %tmp; }}", -++ [(set Int32Regs:$result, (Intrin Int64Regs:$src))]>, -++ Requires<[ShortPtr]>; -+ } -+ -+-defm cvta_local : NG_TO_G<"local">; -+-defm cvta_shared : NG_TO_G<"shared">; -+-defm cvta_global : NG_TO_G<"global">; -+-defm cvta_const : NG_TO_G<"const">; -+- -+-defm cvta_to_local : G_TO_NG<"local">; -+-defm cvta_to_shared : G_TO_NG<"shared">; -+-defm cvta_to_global : G_TO_NG<"global">; -+-defm cvta_to_const : G_TO_NG<"const">; -+- -+-// nvvm.ptr.param.to.gen -+-defm cvta_param : NG_TO_G<"param">; -+- -+-def : Pat<(int_nvvm_ptr_param_to_gen Int32Regs:$src), -+- (cvta_param Int32Regs:$src)>; -+- -+-def : Pat<(int_nvvm_ptr_param_to_gen Int64Regs:$src), -+- (cvta_param_64 Int64Regs:$src)>; -++defm cvta_local : NG_TO_G<"local", int_nvvm_ptr_local_to_gen, useShortPtrLocal>; -++defm cvta_shared : NG_TO_G<"shared", int_nvvm_ptr_shared_to_gen, useShortPtrShared>; -++defm cvta_global : NG_TO_G<"global", int_nvvm_ptr_global_to_gen, False>; -++defm cvta_const : NG_TO_G<"const", int_nvvm_ptr_constant_to_gen, useShortPtrConst>; -++defm cvta_param : NG_TO_G<"param", int_nvvm_ptr_param_to_gen, False>; -++ -++defm cvta_to_local : G_TO_NG<"local", int_nvvm_ptr_gen_to_local, useShortPtrLocal>; -++defm cvta_to_shared : G_TO_NG<"shared", int_nvvm_ptr_gen_to_shared, useShortPtrShared>; -++defm cvta_to_global : G_TO_NG<"global", int_nvvm_ptr_gen_to_global, False>; -++defm cvta_to_const : G_TO_NG<"const", int_nvvm_ptr_gen_to_constant, useShortPtrConst>; -+ -+ // nvvm.ptr.gen.to.param -+-def : Pat<(int_nvvm_ptr_gen_to_param Int32Regs:$src), -+- (IMOV32rr Int32Regs:$src)>; -++def nvvm_ptr_gen_to_param : NVPTXInst<(outs Int32Regs:$result), -++ (ins Int32Regs:$src), -++ "mov.u32 \t$result, $src;", -++ [(set Int32Regs:$result, -++ (int_nvvm_ptr_gen_to_param Int32Regs:$src))]>; -++def nvvm_ptr_gen_to_param_64 : NVPTXInst<(outs Int64Regs:$result), -++ (ins Int64Regs:$src), -++ "mov.u64 \t$result, $src;", -++ [(set Int64Regs:$result, -++ (int_nvvm_ptr_gen_to_param Int64Regs:$src))]>; -+ -+-def : Pat<(int_nvvm_ptr_gen_to_param Int64Regs:$src), -+- (IMOV64rr Int64Regs:$src)>; -+ -+ // nvvm.move intrinsicc -+ def nvvm_move_i16 : NVPTXInst<(outs Int16Regs:$r), (ins Int16Regs:$s), -+@@ -2618,6 +2632,24 @@ -+ [(set Int64Regs:$r, -+ (int_nvvm_move_ptr texternalsym:$s))]>;*/ -+ -++ -++// MoveParam %r1, param -++// ptr_local_to_gen %r2, %r1 -++// ptr_gen_to_local %r3, %r2 -++// -> -++// mov %r1, param -++ -++// @TODO: Revisit this. There is a type -++// contradiction between iPTRAny and iPTR for the addr defs, so the move_sym -++// instructions are not currently defined. However, we can use the ptr -++// variants and the asm printer will do the right thing. -++def : Pat<(i64 (int_nvvm_ptr_gen_to_local (int_nvvm_ptr_local_to_gen -++ (MoveParam texternalsym:$src)))), -++ (nvvm_move_ptr64 texternalsym:$src)>; -++def : Pat<(i32 (int_nvvm_ptr_gen_to_local (int_nvvm_ptr_local_to_gen -++ (MoveParam texternalsym:$src)))), -++ (nvvm_move_ptr32 texternalsym:$src)>; -++ -+ def texsurf_handles -+ : NVPTXInst<(outs Int64Regs:$result), (ins imem:$src), -+ "mov.u64 \t$result, $src;", []>; -+@@ -2701,9 +2733,134 @@ -+ def : Pat<(int_nvvm_read_ptx_sreg_envreg31), (MOV_SPECIAL ENVREG31)>; -+ -+ -++// rotate builtin support -++ -++def ROTATE_B32_HW_IMM -++ : NVPTXInst<(outs Int32Regs:$dst), -++ (ins Int32Regs:$src, i32imm:$amt), -++ "shf.l.wrap.b32 \t$dst, $src, $src, $amt;", -++ [(set Int32Regs:$dst, -++ (int_nvvm_rotate_b32 Int32Regs:$src, (i32 imm:$amt)))]>, -++ Requires<[hasHWROT32]> ; -++ -++def ROTATE_B32_HW_REG -++ : NVPTXInst<(outs Int32Regs:$dst), -++ (ins Int32Regs:$src, Int32Regs:$amt), -++ "shf.l.wrap.b32 \t$dst, $src, $src, $amt;", -++ [(set Int32Regs:$dst, -++ (int_nvvm_rotate_b32 Int32Regs:$src, Int32Regs:$amt))]>, -++ Requires<[hasHWROT32]> ; -++ -++def : Pat<(int_nvvm_rotate_b32 Int32Regs:$src, (i32 imm:$amt)), -++ (ROT32imm_sw Int32Regs:$src, imm:$amt, (SUB_FRM_32 node:$amt))>, -++ Requires<[noHWROT32]> ; -++ -++def : Pat<(int_nvvm_rotate_b32 Int32Regs:$src, Int32Regs:$amt), -++ (ROTL32reg_sw Int32Regs:$src, Int32Regs:$amt)>, -++ Requires<[noHWROT32]> ; -++ -++let hasSideEffects = false in { -++ def GET_LO_INT64 : NVPTXInst<(outs Int32Regs:$dst), (ins Int64Regs:$src), -++ !strconcat("{{\n\t", -++ ".reg .b32 %dummy;\n\t", -++ "mov.b64 \t{$dst,%dummy}, $src;\n\t", -++ "}}"), -++ []> ; -++ -++ def GET_HI_INT64 : NVPTXInst<(outs Int32Regs:$dst), (ins Int64Regs:$src), -++ !strconcat("{{\n\t", -++ ".reg .b32 %dummy;\n\t", -++ "mov.b64 \t{%dummy,$dst}, $src;\n\t", -++ "}}"), -++ []> ; -++} -++ -++let hasSideEffects = false in { -++ def PACK_TWO_INT32 -++ : NVPTXInst<(outs Int64Regs:$dst), (ins Int32Regs:$lo, Int32Regs:$hi), -++ "mov.b64 \t$dst, {{$lo, $hi}};", []> ; -++} -++ -+ def : Pat<(int_nvvm_swap_lo_hi_b64 Int64Regs:$src), -+- (V2I32toI64 (I64toI32H Int64Regs:$src), -+- (I64toI32L Int64Regs:$src))> ; -++ (PACK_TWO_INT32 (GET_HI_INT64 Int64Regs:$src), -++ (GET_LO_INT64 Int64Regs:$src))> ; -++ -++// Funnel shift, requires >= sm_32. Does not trap if amt is out of range, so -++// no side effects. -++let hasSideEffects = false in { -++ def SHF_L_WRAP_B32_IMM -++ : NVPTXInst<(outs Int32Regs:$dst), -++ (ins Int32Regs:$lo, Int32Regs:$hi, i32imm:$amt), -++ "shf.l.wrap.b32 \t$dst, $lo, $hi, $amt;",[]>, -++ Requires<[hasHWROT32]>; -++ -++ def SHF_L_WRAP_B32_REG -++ : NVPTXInst<(outs Int32Regs:$dst), -++ (ins Int32Regs:$lo, Int32Regs:$hi, Int32Regs:$amt), -++ "shf.l.wrap.b32 \t$dst, $lo, $hi, $amt;",[]>, -++ Requires<[hasHWROT32]>; -++ -++ def SHF_R_WRAP_B32_IMM -++ : NVPTXInst<(outs Int32Regs:$dst), -++ (ins Int32Regs:$lo, Int32Regs:$hi, i32imm:$amt), -++ "shf.r.wrap.b32 \t$dst, $lo, $hi, $amt;",[]>, -++ Requires<[hasHWROT32]>; -++ -++ def SHF_R_WRAP_B32_REG -++ : NVPTXInst<(outs Int32Regs:$dst), -++ (ins Int32Regs:$lo, Int32Regs:$hi, Int32Regs:$amt), -++ "shf.r.wrap.b32 \t$dst, $lo, $hi, $amt;",[]>, -++ Requires<[hasHWROT32]>; -++} -++ -++// HW version of rotate 64 -++def : Pat<(int_nvvm_rotate_b64 Int64Regs:$src, (i32 imm:$amt)), -++ (PACK_TWO_INT32 -++ (SHF_L_WRAP_B32_IMM (GET_HI_INT64 Int64Regs:$src), -++ (GET_LO_INT64 Int64Regs:$src), imm:$amt), -++ (SHF_L_WRAP_B32_IMM (GET_LO_INT64 Int64Regs:$src), -++ (GET_HI_INT64 Int64Regs:$src), imm:$amt))>, -++ Requires<[hasHWROT32]>; -++ -++def : Pat<(int_nvvm_rotate_b64 Int64Regs:$src, Int32Regs:$amt), -++ (PACK_TWO_INT32 -++ (SHF_L_WRAP_B32_REG (GET_HI_INT64 Int64Regs:$src), -++ (GET_LO_INT64 Int64Regs:$src), Int32Regs:$amt), -++ (SHF_L_WRAP_B32_REG (GET_LO_INT64 Int64Regs:$src), -++ (GET_HI_INT64 Int64Regs:$src), Int32Regs:$amt))>, -++ Requires<[hasHWROT32]>; -++ -++ -++def : Pat<(int_nvvm_rotate_right_b64 Int64Regs:$src, (i32 imm:$amt)), -++ (PACK_TWO_INT32 -++ (SHF_R_WRAP_B32_IMM (GET_LO_INT64 Int64Regs:$src), -++ (GET_HI_INT64 Int64Regs:$src), imm:$amt), -++ (SHF_R_WRAP_B32_IMM (GET_HI_INT64 Int64Regs:$src), -++ (GET_LO_INT64 Int64Regs:$src), imm:$amt))>, -++ Requires<[hasHWROT32]>; -++ -++def : Pat<(int_nvvm_rotate_right_b64 Int64Regs:$src, Int32Regs:$amt), -++ (PACK_TWO_INT32 -++ (SHF_R_WRAP_B32_REG (GET_LO_INT64 Int64Regs:$src), -++ (GET_HI_INT64 Int64Regs:$src), Int32Regs:$amt), -++ (SHF_R_WRAP_B32_REG (GET_HI_INT64 Int64Regs:$src), -++ (GET_LO_INT64 Int64Regs:$src), Int32Regs:$amt))>, -++ Requires<[hasHWROT32]>; -++ -++// SW version of rotate 64 -++def : Pat<(int_nvvm_rotate_b64 Int64Regs:$src, (i32 imm:$amt)), -++ (ROT64imm_sw Int64Regs:$src, imm:$amt, (SUB_FRM_64 node:$amt))>, -++ Requires<[noHWROT32]>; -++def : Pat<(int_nvvm_rotate_b64 Int64Regs:$src, Int32Regs:$amt), -++ (ROTL64reg_sw Int64Regs:$src, Int32Regs:$amt)>, -++ Requires<[noHWROT32]>; -++def : Pat<(int_nvvm_rotate_right_b64 Int64Regs:$src, (i32 imm:$amt)), -++ (ROT64imm_sw Int64Regs:$src, (SUB_FRM_64 node:$amt), imm:$amt)>, -++ Requires<[noHWROT32]>; -++def : Pat<(int_nvvm_rotate_right_b64 Int64Regs:$src, Int32Regs:$amt), -++ (ROTR64reg_sw Int64Regs:$src, Int32Regs:$amt)>, -++ Requires<[noHWROT32]>; -++ -+ -+ //----------------------------------- -+ // Texture Intrinsics -+diff -ruN --strip-trailing-cr a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp -+--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp -++++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp -+@@ -1109,21 +1109,11 @@ -+ AddrSpaceCastSDNode *CastN = cast(N); -+ unsigned SrcAddrSpace = CastN->getSrcAddressSpace(); -+ unsigned DstAddrSpace = CastN->getDestAddressSpace(); -+- SDLoc DL(N); -+ assert(SrcAddrSpace != DstAddrSpace && -+ "addrspacecast must be between different address spaces"); -+ -+ if (DstAddrSpace == ADDRESS_SPACE_GENERIC) { -+ // Specific to generic -+- -+- if (TM.is64Bit() && TM.getPointerSizeInBits(SrcAddrSpace) == 32) { -+- SDValue CvtNone = -+- CurDAG->getTargetConstant(NVPTX::PTXCvtMode::NONE, DL, MVT::i32); -+- SDNode *Cvt = CurDAG->getMachineNode(NVPTX::CVT_u64_u32, DL, MVT::i64, -+- Src, CvtNone); -+- Src = SDValue(Cvt, 0); -+- } -+- -+ unsigned Opc; -+ switch (SrcAddrSpace) { -+ default: report_fatal_error("Bad address space in addrspacecast"); -+@@ -1131,16 +1121,26 @@ -+ Opc = TM.is64Bit() ? NVPTX::cvta_global_64 : NVPTX::cvta_global; -+ break; -+ case ADDRESS_SPACE_SHARED: -+- Opc = TM.is64Bit() ? NVPTX::cvta_shared_64 : NVPTX::cvta_shared; -++ Opc = TM.is64Bit() ? (TM.getPointerSizeInBits(SrcAddrSpace) == 32 -++ ? NVPTX::cvta_shared_6432 -++ : NVPTX::cvta_shared_64) -++ : NVPTX::cvta_shared; -+ break; -+ case ADDRESS_SPACE_CONST: -+- Opc = TM.is64Bit() ? NVPTX::cvta_const_64 : NVPTX::cvta_const; -++ Opc = TM.is64Bit() ? (TM.getPointerSizeInBits(SrcAddrSpace) == 32 -++ ? NVPTX::cvta_const_6432 -++ : NVPTX::cvta_const_64) -++ : NVPTX::cvta_const; -+ break; -+ case ADDRESS_SPACE_LOCAL: -+- Opc = TM.is64Bit() ? NVPTX::cvta_local_64 : NVPTX::cvta_local; -++ Opc = TM.is64Bit() ? (TM.getPointerSizeInBits(SrcAddrSpace) == 32 -++ ? NVPTX::cvta_local_6432 -++ : NVPTX::cvta_local_64) -++ : NVPTX::cvta_local; -+ break; -+ } -+- ReplaceNode(N, CurDAG->getMachineNode(Opc, DL, N->getValueType(0), Src)); -++ ReplaceNode(N, CurDAG->getMachineNode(Opc, SDLoc(N), N->getValueType(0), -++ Src)); -+ return; -+ } else { -+ // Generic to specific -+@@ -1153,28 +1153,30 @@ -+ Opc = TM.is64Bit() ? NVPTX::cvta_to_global_64 : NVPTX::cvta_to_global; -+ break; -+ case ADDRESS_SPACE_SHARED: -+- Opc = TM.is64Bit() ? NVPTX::cvta_to_shared_64 : NVPTX::cvta_to_shared; -++ Opc = TM.is64Bit() ? (TM.getPointerSizeInBits(DstAddrSpace) == 32 -++ ? NVPTX::cvta_to_shared_3264 -++ : NVPTX::cvta_to_shared_64) -++ : NVPTX::cvta_to_shared; -+ break; -+ case ADDRESS_SPACE_CONST: -+- Opc = TM.is64Bit() ? NVPTX::cvta_to_const_64 : NVPTX::cvta_to_const; -++ Opc = TM.is64Bit() ? (TM.getPointerSizeInBits(DstAddrSpace) == 32 -++ ? NVPTX::cvta_to_const_3264 -++ : NVPTX::cvta_to_const_64) -++ : NVPTX::cvta_to_const; -+ break; -+ case ADDRESS_SPACE_LOCAL: -+- Opc = TM.is64Bit() ? NVPTX::cvta_to_local_64 : NVPTX::cvta_to_local; -++ Opc = TM.is64Bit() ? (TM.getPointerSizeInBits(DstAddrSpace) == 32 -++ ? NVPTX::cvta_to_local_3264 -++ : NVPTX::cvta_to_local_64) -++ : NVPTX::cvta_to_local; -+ break; -+ case ADDRESS_SPACE_PARAM: -+- Opc = TM.is64Bit() ? NVPTX::IMOV64rr : NVPTX::IMOV32rr; -++ Opc = TM.is64Bit() ? NVPTX::nvvm_ptr_gen_to_param_64 -++ : NVPTX::nvvm_ptr_gen_to_param; -+ break; -+ } -+- -+- SDNode *CVTA = CurDAG->getMachineNode(Opc, DL, N->getValueType(0), Src); -+- if (TM.is64Bit() && TM.getPointerSizeInBits(DstAddrSpace) == 32) { -+- SDValue CvtNone = -+- CurDAG->getTargetConstant(NVPTX::PTXCvtMode::NONE, DL, MVT::i32); -+- CVTA = CurDAG->getMachineNode(NVPTX::CVT_u32_u64, DL, MVT::i32, -+- SDValue(CVTA, 0), CvtNone); -+- } -+- -+- ReplaceNode(N, CVTA); -++ ReplaceNode(N, CurDAG->getMachineNode(Opc, SDLoc(N), N->getValueType(0), -++ Src)); -+ return; -+ } -+ } -+diff -ruN --strip-trailing-cr a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp -+--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp -++++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp -+@@ -594,13 +594,20 @@ -+ setOperationAction(ISD::BITREVERSE, MVT::i32, Legal); -+ setOperationAction(ISD::BITREVERSE, MVT::i64, Legal); -+ -+- setOperationAction({ISD::ROTL, ISD::ROTR}, -+- {MVT::i8, MVT::i16, MVT::v2i16, MVT::i32, MVT::i64}, -+- Expand); -+- -+- if (STI.hasHWROT32()) -+- setOperationAction({ISD::FSHL, ISD::FSHR}, MVT::i32, Legal); -++ // TODO: we may consider expanding ROTL/ROTR on older GPUs. Currently on GPUs -++ // that don't have h/w rotation we lower them to multi-instruction assembly. -++ // See ROT*_sw in NVPTXIntrInfo.td -++ setOperationAction(ISD::ROTL, MVT::i64, Legal); -++ setOperationAction(ISD::ROTR, MVT::i64, Legal); -++ setOperationAction(ISD::ROTL, MVT::i32, Legal); -++ setOperationAction(ISD::ROTR, MVT::i32, Legal); -+ -++ setOperationAction(ISD::ROTL, MVT::i16, Expand); -++ setOperationAction(ISD::ROTL, MVT::v2i16, Expand); -++ setOperationAction(ISD::ROTR, MVT::i16, Expand); -++ setOperationAction(ISD::ROTR, MVT::v2i16, Expand); -++ setOperationAction(ISD::ROTL, MVT::i8, Expand); -++ setOperationAction(ISD::ROTR, MVT::i8, Expand); -+ setOperationAction(ISD::BSWAP, MVT::i16, Expand); -+ -+ setOperationAction(ISD::BR_JT, MVT::Other, Custom); -+diff -ruN --strip-trailing-cr a/llvm/test/Assembler/auto_upgrade_nvvm_intrinsics.ll b/llvm/test/Assembler/auto_upgrade_nvvm_intrinsics.ll -+--- a/llvm/test/Assembler/auto_upgrade_nvvm_intrinsics.ll -++++ b/llvm/test/Assembler/auto_upgrade_nvvm_intrinsics.ll -+@@ -31,19 +31,6 @@ -+ declare i64 @llvm.nvvm.bitcast.d2ll(double) -+ declare double @llvm.nvvm.bitcast.ll2d(i64) -+ -+-declare i32 @llvm.nvvm.rotate.b32(i32, i32) -+-declare i64 @llvm.nvvm.rotate.right.b64(i64, i32) -+-declare i64 @llvm.nvvm.rotate.b64(i64, i32) -+- -+-declare ptr addrspace(1) @llvm.nvvm.ptr.gen.to.global.p1.p0(ptr) -+-declare ptr addrspace(3) @llvm.nvvm.ptr.gen.to.shared.p3.p0(ptr) -+-declare ptr addrspace(4) @llvm.nvvm.ptr.gen.to.constant.p4.p0(ptr) -+-declare ptr addrspace(5) @llvm.nvvm.ptr.gen.to.local.p5.p0(ptr) -+-declare ptr @llvm.nvvm.ptr.global.to.gen.p0.p1(ptr addrspace(1)) -+-declare ptr @llvm.nvvm.ptr.shared.to.gen.p0.p3(ptr addrspace(3)) -+-declare ptr @llvm.nvvm.ptr.constant.to.gen.p0.p4(ptr addrspace(4)) -+-declare ptr @llvm.nvvm.ptr.local.to.gen.p0.p5(ptr addrspace(5)) -+- -+ ; CHECK-LABEL: @simple_upgrade -+ define void @simple_upgrade(i32 %a, i64 %b, i16 %c) { -+ ; CHECK: call i32 @llvm.bitreverse.i32(i32 %a) -+@@ -152,42 +139,4 @@ -+ %r4 = call double @llvm.nvvm.bitcast.ll2d(i64 %b) -+ -+ ret void -+-} -+- -+-; CHECK-LABEL: @rotate -+-define void @rotate(i32 %a, i64 %b) { -+-; CHECK: call i32 @llvm.fshl.i32(i32 %a, i32 %a, i32 6) -+-; CHECK: call i64 @llvm.fshr.i64(i64 %b, i64 %b, i64 7) -+-; CHECK: call i64 @llvm.fshl.i64(i64 %b, i64 %b, i64 8) -+-; -+- %r1 = call i32 @llvm.nvvm.rotate.b32(i32 %a, i32 6) -+- %r2 = call i64 @llvm.nvvm.rotate.right.b64(i64 %b, i32 7) -+- %r3 = call i64 @llvm.nvvm.rotate.b64(i64 %b, i32 8) -+- ret void -+-} -+- -+-; CHECK-LABEL: @addrspacecast -+-define void @addrspacecast(ptr %p0) { -+-; CHECK: %1 = addrspacecast ptr %p0 to ptr addrspace(1) -+-; CHECK: %2 = addrspacecast ptr addrspace(1) %1 to ptr -+-; CHECK: %3 = addrspacecast ptr %2 to ptr addrspace(3) -+-; CHECK: %4 = addrspacecast ptr addrspace(3) %3 to ptr -+-; CHECK: %5 = addrspacecast ptr %4 to ptr addrspace(4) -+-; CHECK: %6 = addrspacecast ptr addrspace(4) %5 to ptr -+-; CHECK: %7 = addrspacecast ptr %6 to ptr addrspace(5) -+-; CHECK: %8 = addrspacecast ptr addrspace(5) %7 to ptr -+-; -+- %p1 = call ptr addrspace(1) @llvm.nvvm.ptr.gen.to.global.p1.p0(ptr %p0) -+- %p2 = call ptr @llvm.nvvm.ptr.global.to.gen.p0.p1(ptr addrspace(1) %p1) -+- -+- %p3 = call ptr addrspace(3) @llvm.nvvm.ptr.gen.to.shared.p3.p0(ptr %p2) -+- %p4 = call ptr @llvm.nvvm.ptr.shared.to.gen.p0.p3(ptr addrspace(3) %p3) -+- -+- %p5 = call ptr addrspace(4) @llvm.nvvm.ptr.gen.to.constant.p4.p0(ptr %p4) -+- %p6 = call ptr @llvm.nvvm.ptr.constant.to.gen.p0.p4(ptr addrspace(4) %p5) -+- -+- %p7 = call ptr addrspace(5) @llvm.nvvm.ptr.gen.to.local.p5.p0(ptr %p6) -+- %p8 = call ptr @llvm.nvvm.ptr.local.to.gen.p0.p5(ptr addrspace(5) %p7) -+- -+- ret void -+-} -++} -+\ No newline at end of file -+diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/AMDGPU/freeze.ll b/llvm/test/CodeGen/AMDGPU/freeze.ll -+--- a/llvm/test/CodeGen/AMDGPU/freeze.ll -++++ b/llvm/test/CodeGen/AMDGPU/freeze.ll -+@@ -1,1856 +0,0 @@ -+-; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py -+-; RUN: llc -global-isel=0 -mtriple=amdgcn-mesa-mesa3d -mcpu=gfx1010 < %s | FileCheck -check-prefixes=GFX10,GFX10-SDAG %s -+-; RUN: llc -global-isel=1 -mtriple=amdgcn-mesa-mesa3d -mcpu=gfx1010 < %s | FileCheck -check-prefixes=GFX10,GFX10-GISEL %s -+-; RUN: llc -global-isel=0 -mtriple=amdgcn-mesa-mesa3d -mcpu=gfx1100 -amdgpu-enable-delay-alu=0 < %s | FileCheck -check-prefixes=GFX11,GFX11-SDAG %s -+-; RUN: llc -global-isel=1 -mtriple=amdgcn-mesa-mesa3d -mcpu=gfx1100 -amdgpu-enable-delay-alu=0 < %s | FileCheck -check-prefixes=GFX11,GFX11-GISEL %s -+- -+-define void @freeze_v2i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { -+-; GFX10-LABEL: freeze_v2i32: -+-; GFX10: ; %bb.0: -+-; GFX10-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX10-NEXT: global_load_dwordx2 v[0:1], v[0:1], off -+-; GFX10-NEXT: s_waitcnt vmcnt(0) -+-; GFX10-NEXT: global_store_dwordx2 v[2:3], v[0:1], off -+-; GFX10-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX11-LABEL: freeze_v2i32: -+-; GFX11: ; %bb.0: -+-; GFX11-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX11-NEXT: global_load_b64 v[0:1], v[0:1], off -+-; GFX11-NEXT: s_waitcnt vmcnt(0) -+-; GFX11-NEXT: global_store_b64 v[2:3], v[0:1], off -+-; GFX11-NEXT: s_setpc_b64 s[30:31] -+- %a = load <2 x i32>, ptr addrspace(1) %ptra, align 4 -+- %freeze = freeze <2 x i32> %a -+- store <2 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 -+- ret void -+-} -+- -+-define void @freeze_v3i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { -+-; GFX10-LABEL: freeze_v3i32: -+-; GFX10: ; %bb.0: -+-; GFX10-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX10-NEXT: global_load_dwordx3 v[4:6], v[0:1], off -+-; GFX10-NEXT: s_waitcnt vmcnt(0) -+-; GFX10-NEXT: global_store_dwordx3 v[2:3], v[4:6], off -+-; GFX10-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX11-LABEL: freeze_v3i32: -+-; GFX11: ; %bb.0: -+-; GFX11-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX11-NEXT: global_load_b96 v[4:6], v[0:1], off -+-; GFX11-NEXT: s_waitcnt vmcnt(0) -+-; GFX11-NEXT: global_store_b96 v[2:3], v[4:6], off -+-; GFX11-NEXT: s_setpc_b64 s[30:31] -+- %a = load <3 x i32>, ptr addrspace(1) %ptra, align 4 -+- %freeze = freeze <3 x i32> %a -+- store <3 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 -+- ret void -+-} -+- -+-define void @freeze_v4i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { -+-; GFX10-LABEL: freeze_v4i32: -+-; GFX10: ; %bb.0: -+-; GFX10-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX10-NEXT: global_load_dwordx4 v[4:7], v[0:1], off -+-; GFX10-NEXT: s_waitcnt vmcnt(0) -+-; GFX10-NEXT: global_store_dwordx4 v[2:3], v[4:7], off -+-; GFX10-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX11-LABEL: freeze_v4i32: -+-; GFX11: ; %bb.0: -+-; GFX11-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX11-NEXT: global_load_b128 v[4:7], v[0:1], off -+-; GFX11-NEXT: s_waitcnt vmcnt(0) -+-; GFX11-NEXT: global_store_b128 v[2:3], v[4:7], off -+-; GFX11-NEXT: s_setpc_b64 s[30:31] -+- %a = load <4 x i32>, ptr addrspace(1) %ptra, align 4 -+- %freeze = freeze <4 x i32> %a -+- store <4 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 -+- ret void -+-} -+- -+-define void @freeze_v5i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { -+-; GFX10-SDAG-LABEL: freeze_v5i32: -+-; GFX10-SDAG: ; %bb.0: -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX10-SDAG-NEXT: s_clause 0x1 -+-; GFX10-SDAG-NEXT: global_load_dword v8, v[0:1], off offset:16 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) -+-; GFX10-SDAG-NEXT: global_store_dword v[2:3], v8, off offset:16 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off -+-; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX10-GISEL-LABEL: freeze_v5i32: -+-; GFX10-GISEL: ; %bb.0: -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX10-GISEL-NEXT: s_clause 0x1 -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off -+-; GFX10-GISEL-NEXT: global_load_dword v8, v[0:1], off offset:16 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) -+-; GFX10-GISEL-NEXT: global_store_dword v[2:3], v8, off offset:16 -+-; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX11-SDAG-LABEL: freeze_v5i32: -+-; GFX11-SDAG: ; %bb.0: -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX11-SDAG-NEXT: s_clause 0x1 -+-; GFX11-SDAG-NEXT: global_load_b32 v8, v[0:1], off offset:16 -+-; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) -+-; GFX11-SDAG-NEXT: global_store_b32 v[2:3], v8, off offset:16 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off -+-; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX11-GISEL-LABEL: freeze_v5i32: -+-; GFX11-GISEL: ; %bb.0: -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX11-GISEL-NEXT: s_clause 0x1 -+-; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off -+-; GFX11-GISEL-NEXT: global_load_b32 v0, v[0:1], off offset:16 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) -+-; GFX11-GISEL-NEXT: global_store_b32 v[2:3], v0, off offset:16 -+-; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] -+- %a = load <5 x i32>, ptr addrspace(1) %ptra, align 4 -+- %freeze = freeze <5 x i32> %a -+- store <5 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 -+- ret void -+-} -+- -+-define void @freeze_v6i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { -+-; GFX10-SDAG-LABEL: freeze_v6i32: -+-; GFX10-SDAG: ; %bb.0: -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX10-SDAG-NEXT: s_clause 0x1 -+-; GFX10-SDAG-NEXT: global_load_dwordx2 v[8:9], v[0:1], off offset:16 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) -+-; GFX10-SDAG-NEXT: global_store_dwordx2 v[2:3], v[8:9], off offset:16 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off -+-; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX10-GISEL-LABEL: freeze_v6i32: -+-; GFX10-GISEL: ; %bb.0: -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX10-GISEL-NEXT: s_clause 0x1 -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off -+-; GFX10-GISEL-NEXT: global_load_dwordx2 v[8:9], v[0:1], off offset:16 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) -+-; GFX10-GISEL-NEXT: global_store_dwordx2 v[2:3], v[8:9], off offset:16 -+-; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX11-SDAG-LABEL: freeze_v6i32: -+-; GFX11-SDAG: ; %bb.0: -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX11-SDAG-NEXT: s_clause 0x1 -+-; GFX11-SDAG-NEXT: global_load_b64 v[8:9], v[0:1], off offset:16 -+-; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) -+-; GFX11-SDAG-NEXT: global_store_b64 v[2:3], v[8:9], off offset:16 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off -+-; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX11-GISEL-LABEL: freeze_v6i32: -+-; GFX11-GISEL: ; %bb.0: -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX11-GISEL-NEXT: s_clause 0x1 -+-; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off -+-; GFX11-GISEL-NEXT: global_load_b64 v[0:1], v[0:1], off offset:16 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) -+-; GFX11-GISEL-NEXT: global_store_b64 v[2:3], v[0:1], off offset:16 -+-; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] -+- %a = load <6 x i32>, ptr addrspace(1) %ptra, align 4 -+- %freeze = freeze <6 x i32> %a -+- store <6 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 -+- ret void -+-} -+- -+-define void @freeze_v7i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { -+-; GFX10-SDAG-LABEL: freeze_v7i32: -+-; GFX10-SDAG: ; %bb.0: -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX10-SDAG-NEXT: s_clause 0x1 -+-; GFX10-SDAG-NEXT: global_load_dwordx3 v[8:10], v[0:1], off offset:16 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) -+-; GFX10-SDAG-NEXT: global_store_dwordx3 v[2:3], v[8:10], off offset:16 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off -+-; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX10-GISEL-LABEL: freeze_v7i32: -+-; GFX10-GISEL: ; %bb.0: -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX10-GISEL-NEXT: s_clause 0x1 -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off -+-; GFX10-GISEL-NEXT: global_load_dwordx3 v[8:10], v[0:1], off offset:16 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) -+-; GFX10-GISEL-NEXT: global_store_dwordx3 v[2:3], v[8:10], off offset:16 -+-; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX11-SDAG-LABEL: freeze_v7i32: -+-; GFX11-SDAG: ; %bb.0: -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX11-SDAG-NEXT: s_clause 0x1 -+-; GFX11-SDAG-NEXT: global_load_b96 v[8:10], v[0:1], off offset:16 -+-; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) -+-; GFX11-SDAG-NEXT: global_store_b96 v[2:3], v[8:10], off offset:16 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off -+-; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX11-GISEL-LABEL: freeze_v7i32: -+-; GFX11-GISEL: ; %bb.0: -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX11-GISEL-NEXT: s_clause 0x1 -+-; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off -+-; GFX11-GISEL-NEXT: global_load_b96 v[8:10], v[0:1], off offset:16 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) -+-; GFX11-GISEL-NEXT: global_store_b96 v[2:3], v[8:10], off offset:16 -+-; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] -+- %a = load <7 x i32>, ptr addrspace(1) %ptra, align 4 -+- %freeze = freeze <7 x i32> %a -+- store <7 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 -+- ret void -+-} -+- -+-define void @freeze_v8i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { -+-; GFX10-SDAG-LABEL: freeze_v8i32: -+-; GFX10-SDAG: ; %bb.0: -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX10-SDAG-NEXT: s_clause 0x1 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:16 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:16 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off -+-; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX10-GISEL-LABEL: freeze_v8i32: -+-; GFX10-GISEL: ; %bb.0: -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX10-GISEL-NEXT: s_clause 0x1 -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 -+-; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX11-SDAG-LABEL: freeze_v8i32: -+-; GFX11-SDAG: ; %bb.0: -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX11-SDAG-NEXT: s_clause 0x1 -+-; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:16 -+-; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:16 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off -+-; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX11-GISEL-LABEL: freeze_v8i32: -+-; GFX11-GISEL: ; %bb.0: -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX11-GISEL-NEXT: s_clause 0x1 -+-; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off -+-; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 -+-; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] -+- %a = load <8 x i32>, ptr addrspace(1) %ptra, align 4 -+- %freeze = freeze <8 x i32> %a -+- store <8 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 -+- ret void -+-} -+- -+-define void @freeze_v9i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { -+-; GFX10-SDAG-LABEL: freeze_v9i32: -+-; GFX10-SDAG: ; %bb.0: -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX10-SDAG-NEXT: s_clause 0x2 -+-; GFX10-SDAG-NEXT: global_load_dword v12, v[0:1], off offset:32 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) -+-; GFX10-SDAG-NEXT: global_store_dword v[2:3], v12, off offset:32 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 -+-; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX10-GISEL-LABEL: freeze_v9i32: -+-; GFX10-GISEL: ; %bb.0: -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX10-GISEL-NEXT: s_clause 0x2 -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 -+-; GFX10-GISEL-NEXT: global_load_dword v12, v[0:1], off offset:32 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) -+-; GFX10-GISEL-NEXT: global_store_dword v[2:3], v12, off offset:32 -+-; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX11-SDAG-LABEL: freeze_v9i32: -+-; GFX11-SDAG: ; %bb.0: -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX11-SDAG-NEXT: s_clause 0x2 -+-; GFX11-SDAG-NEXT: global_load_b32 v12, v[0:1], off offset:32 -+-; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off -+-; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) -+-; GFX11-SDAG-NEXT: global_store_b32 v[2:3], v12, off offset:32 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 -+-; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX11-GISEL-LABEL: freeze_v9i32: -+-; GFX11-GISEL: ; %bb.0: -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX11-GISEL-NEXT: s_clause 0x2 -+-; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off -+-; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 -+-; GFX11-GISEL-NEXT: global_load_b32 v0, v[0:1], off offset:32 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) -+-; GFX11-GISEL-NEXT: global_store_b32 v[2:3], v0, off offset:32 -+-; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] -+- %a = load <9 x i32>, ptr addrspace(1) %ptra, align 4 -+- %freeze = freeze <9 x i32> %a -+- store <9 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 -+- ret void -+-} -+- -+-define void @freeze_v10i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { -+-; GFX10-LABEL: freeze_v10i32: -+-; GFX10: ; %bb.0: -+-; GFX10-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX10-NEXT: s_clause 0x2 -+-; GFX10-NEXT: global_load_dwordx4 v[4:7], v[0:1], off -+-; GFX10-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 -+-; GFX10-NEXT: global_load_dwordx2 v[12:13], v[0:1], off offset:32 -+-; GFX10-NEXT: s_waitcnt vmcnt(2) -+-; GFX10-NEXT: global_store_dwordx4 v[2:3], v[4:7], off -+-; GFX10-NEXT: s_waitcnt vmcnt(1) -+-; GFX10-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 -+-; GFX10-NEXT: s_waitcnt vmcnt(0) -+-; GFX10-NEXT: global_store_dwordx2 v[2:3], v[12:13], off offset:32 -+-; GFX10-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX11-LABEL: freeze_v10i32: -+-; GFX11: ; %bb.0: -+-; GFX11-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX11-NEXT: s_clause 0x2 -+-; GFX11-NEXT: global_load_b128 v[4:7], v[0:1], off -+-; GFX11-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 -+-; GFX11-NEXT: global_load_b64 v[0:1], v[0:1], off offset:32 -+-; GFX11-NEXT: s_waitcnt vmcnt(2) -+-; GFX11-NEXT: global_store_b128 v[2:3], v[4:7], off -+-; GFX11-NEXT: s_waitcnt vmcnt(1) -+-; GFX11-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 -+-; GFX11-NEXT: s_waitcnt vmcnt(0) -+-; GFX11-NEXT: global_store_b64 v[2:3], v[0:1], off offset:32 -+-; GFX11-NEXT: s_setpc_b64 s[30:31] -+- %a = load <10 x i32>, ptr addrspace(1) %ptra, align 4 -+- %freeze = freeze <10 x i32> %a -+- store <10 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 -+- ret void -+-} -+- -+-define void @freeze_v11i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { -+-; GFX10-LABEL: freeze_v11i32: -+-; GFX10: ; %bb.0: -+-; GFX10-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX10-NEXT: s_clause 0x2 -+-; GFX10-NEXT: global_load_dwordx4 v[4:7], v[0:1], off -+-; GFX10-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 -+-; GFX10-NEXT: global_load_dwordx3 v[12:14], v[0:1], off offset:32 -+-; GFX10-NEXT: s_waitcnt vmcnt(2) -+-; GFX10-NEXT: global_store_dwordx4 v[2:3], v[4:7], off -+-; GFX10-NEXT: s_waitcnt vmcnt(1) -+-; GFX10-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 -+-; GFX10-NEXT: s_waitcnt vmcnt(0) -+-; GFX10-NEXT: global_store_dwordx3 v[2:3], v[12:14], off offset:32 -+-; GFX10-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX11-LABEL: freeze_v11i32: -+-; GFX11: ; %bb.0: -+-; GFX11-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX11-NEXT: s_clause 0x2 -+-; GFX11-NEXT: global_load_b128 v[4:7], v[0:1], off -+-; GFX11-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 -+-; GFX11-NEXT: global_load_b96 v[12:14], v[0:1], off offset:32 -+-; GFX11-NEXT: s_waitcnt vmcnt(2) -+-; GFX11-NEXT: global_store_b128 v[2:3], v[4:7], off -+-; GFX11-NEXT: s_waitcnt vmcnt(1) -+-; GFX11-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 -+-; GFX11-NEXT: s_waitcnt vmcnt(0) -+-; GFX11-NEXT: global_store_b96 v[2:3], v[12:14], off offset:32 -+-; GFX11-NEXT: s_setpc_b64 s[30:31] -+- %a = load <11 x i32>, ptr addrspace(1) %ptra, align 4 -+- %freeze = freeze <11 x i32> %a -+- store <11 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 -+- ret void -+-} -+- -+-define void @freeze_v12i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { -+-; GFX10-LABEL: freeze_v12i32: -+-; GFX10: ; %bb.0: -+-; GFX10-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX10-NEXT: s_clause 0x2 -+-; GFX10-NEXT: global_load_dwordx4 v[4:7], v[0:1], off -+-; GFX10-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 -+-; GFX10-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 -+-; GFX10-NEXT: s_waitcnt vmcnt(2) -+-; GFX10-NEXT: global_store_dwordx4 v[2:3], v[4:7], off -+-; GFX10-NEXT: s_waitcnt vmcnt(1) -+-; GFX10-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 -+-; GFX10-NEXT: s_waitcnt vmcnt(0) -+-; GFX10-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 -+-; GFX10-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX11-LABEL: freeze_v12i32: -+-; GFX11: ; %bb.0: -+-; GFX11-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX11-NEXT: s_clause 0x2 -+-; GFX11-NEXT: global_load_b128 v[4:7], v[0:1], off -+-; GFX11-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 -+-; GFX11-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 -+-; GFX11-NEXT: s_waitcnt vmcnt(2) -+-; GFX11-NEXT: global_store_b128 v[2:3], v[4:7], off -+-; GFX11-NEXT: s_waitcnt vmcnt(1) -+-; GFX11-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 -+-; GFX11-NEXT: s_waitcnt vmcnt(0) -+-; GFX11-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 -+-; GFX11-NEXT: s_setpc_b64 s[30:31] -+- %a = load <12 x i32>, ptr addrspace(1) %ptra, align 4 -+- %freeze = freeze <12 x i32> %a -+- store <12 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 -+- ret void -+-} -+-define void @freeze_v13i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { -+-; GFX10-SDAG-LABEL: freeze_v13i32: -+-; GFX10-SDAG: ; %bb.0: -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX10-SDAG-NEXT: s_clause 0x3 -+-; GFX10-SDAG-NEXT: global_load_dword v16, v[0:1], off offset:48 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:32 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:16 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) -+-; GFX10-SDAG-NEXT: global_store_dword v[2:3], v16, off offset:48 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:32 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:16 -+-; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX10-GISEL-LABEL: freeze_v13i32: -+-; GFX10-GISEL: ; %bb.0: -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX10-GISEL-NEXT: s_clause 0x3 -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 -+-; GFX10-GISEL-NEXT: global_load_dword v16, v[0:1], off offset:48 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) -+-; GFX10-GISEL-NEXT: global_store_dword v[2:3], v16, off offset:48 -+-; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX11-SDAG-LABEL: freeze_v13i32: -+-; GFX11-SDAG: ; %bb.0: -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX11-SDAG-NEXT: s_clause 0x3 -+-; GFX11-SDAG-NEXT: global_load_b32 v16, v[0:1], off offset:48 -+-; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:32 -+-; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off -+-; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off offset:16 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) -+-; GFX11-SDAG-NEXT: global_store_b32 v[2:3], v16, off offset:48 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:32 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off offset:16 -+-; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX11-GISEL-LABEL: freeze_v13i32: -+-; GFX11-GISEL: ; %bb.0: -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX11-GISEL-NEXT: s_clause 0x3 -+-; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off -+-; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 -+-; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 -+-; GFX11-GISEL-NEXT: global_load_b32 v0, v[0:1], off offset:48 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) -+-; GFX11-GISEL-NEXT: global_store_b32 v[2:3], v0, off offset:48 -+-; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] -+- %a = load <13 x i32>, ptr addrspace(1) %ptra, align 4 -+- %freeze = freeze <13 x i32> %a -+- store <13 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 -+- ret void -+-} -+- -+-define void @freeze_v14i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { -+-; GFX10-SDAG-LABEL: freeze_v14i32: -+-; GFX10-SDAG: ; %bb.0: -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX10-SDAG-NEXT: s_clause 0x3 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:32 -+-; GFX10-SDAG-NEXT: global_load_dwordx2 v[16:17], v[0:1], off offset:48 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:16 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:32 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) -+-; GFX10-SDAG-NEXT: global_store_dwordx2 v[2:3], v[16:17], off offset:48 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:16 -+-; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX10-GISEL-LABEL: freeze_v14i32: -+-; GFX10-GISEL: ; %bb.0: -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX10-GISEL-NEXT: s_clause 0x3 -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 -+-; GFX10-GISEL-NEXT: global_load_dwordx2 v[16:17], v[0:1], off offset:48 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) -+-; GFX10-GISEL-NEXT: global_store_dwordx2 v[2:3], v[16:17], off offset:48 -+-; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX11-SDAG-LABEL: freeze_v14i32: -+-; GFX11-SDAG: ; %bb.0: -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX11-SDAG-NEXT: s_clause 0x3 -+-; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:32 -+-; GFX11-SDAG-NEXT: global_load_b64 v[16:17], v[0:1], off offset:48 -+-; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off -+-; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off offset:16 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:32 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) -+-; GFX11-SDAG-NEXT: global_store_b64 v[2:3], v[16:17], off offset:48 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off offset:16 -+-; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX11-GISEL-LABEL: freeze_v14i32: -+-; GFX11-GISEL: ; %bb.0: -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX11-GISEL-NEXT: s_clause 0x3 -+-; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off -+-; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 -+-; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 -+-; GFX11-GISEL-NEXT: global_load_b64 v[0:1], v[0:1], off offset:48 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) -+-; GFX11-GISEL-NEXT: global_store_b64 v[2:3], v[0:1], off offset:48 -+-; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] -+- %a = load <14 x i32>, ptr addrspace(1) %ptra, align 4 -+- %freeze = freeze <14 x i32> %a -+- store <14 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 -+- ret void -+-} -+- -+-define void @freeze_v15i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { -+-; GFX10-SDAG-LABEL: freeze_v15i32: -+-; GFX10-SDAG: ; %bb.0: -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX10-SDAG-NEXT: s_clause 0x3 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:32 -+-; GFX10-SDAG-NEXT: global_load_dwordx3 v[16:18], v[0:1], off offset:48 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:16 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:32 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) -+-; GFX10-SDAG-NEXT: global_store_dwordx3 v[2:3], v[16:18], off offset:48 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:16 -+-; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX10-GISEL-LABEL: freeze_v15i32: -+-; GFX10-GISEL: ; %bb.0: -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX10-GISEL-NEXT: s_clause 0x3 -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 -+-; GFX10-GISEL-NEXT: global_load_dwordx3 v[16:18], v[0:1], off offset:48 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) -+-; GFX10-GISEL-NEXT: global_store_dwordx3 v[2:3], v[16:18], off offset:48 -+-; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX11-SDAG-LABEL: freeze_v15i32: -+-; GFX11-SDAG: ; %bb.0: -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX11-SDAG-NEXT: s_clause 0x3 -+-; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:32 -+-; GFX11-SDAG-NEXT: global_load_b96 v[16:18], v[0:1], off offset:48 -+-; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off -+-; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off offset:16 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:32 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) -+-; GFX11-SDAG-NEXT: global_store_b96 v[2:3], v[16:18], off offset:48 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off offset:16 -+-; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX11-GISEL-LABEL: freeze_v15i32: -+-; GFX11-GISEL: ; %bb.0: -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX11-GISEL-NEXT: s_clause 0x3 -+-; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off -+-; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 -+-; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 -+-; GFX11-GISEL-NEXT: global_load_b96 v[16:18], v[0:1], off offset:48 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) -+-; GFX11-GISEL-NEXT: global_store_b96 v[2:3], v[16:18], off offset:48 -+-; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] -+- %a = load <15 x i32>, ptr addrspace(1) %ptra, align 4 -+- %freeze = freeze <15 x i32> %a -+- store <15 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 -+- ret void -+-} -+- -+-define void @freeze_v16i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { -+-; GFX10-SDAG-LABEL: freeze_v16i32: -+-; GFX10-SDAG: ; %bb.0: -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX10-SDAG-NEXT: s_clause 0x3 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:32 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:48 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:16 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:32 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:48 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:16 -+-; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX10-GISEL-LABEL: freeze_v16i32: -+-; GFX10-GISEL: ; %bb.0: -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX10-GISEL-NEXT: s_clause 0x3 -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:48 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:48 -+-; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX11-SDAG-LABEL: freeze_v16i32: -+-; GFX11-SDAG: ; %bb.0: -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX11-SDAG-NEXT: s_clause 0x3 -+-; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:32 -+-; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off offset:48 -+-; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off -+-; GFX11-SDAG-NEXT: global_load_b128 v[16:19], v[0:1], off offset:16 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:32 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off offset:48 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[16:19], off offset:16 -+-; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX11-GISEL-LABEL: freeze_v16i32: -+-; GFX11-GISEL: ; %bb.0: -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX11-GISEL-NEXT: s_clause 0x3 -+-; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off -+-; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 -+-; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 -+-; GFX11-GISEL-NEXT: global_load_b128 v[16:19], v[0:1], off offset:48 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[16:19], off offset:48 -+-; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] -+- %a = load <16 x i32>, ptr addrspace(1) %ptra, align 4 -+- %freeze = freeze <16 x i32> %a -+- store <16 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 -+- ret void -+-} -+- -+-define void @freeze_v17i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { -+-; GFX10-SDAG-LABEL: freeze_v17i32: -+-; GFX10-SDAG: ; %bb.0: -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX10-SDAG-NEXT: s_clause 0x4 -+-; GFX10-SDAG-NEXT: global_load_dword v20, v[0:1], off offset:64 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:32 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:48 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:16 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(4) -+-; GFX10-SDAG-NEXT: global_store_dword v[2:3], v20, off offset:64 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:32 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:48 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:16 -+-; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX10-GISEL-LABEL: freeze_v17i32: -+-; GFX10-GISEL: ; %bb.0: -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX10-GISEL-NEXT: s_clause 0x4 -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:48 -+-; GFX10-GISEL-NEXT: global_load_dword v20, v[0:1], off offset:64 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(4) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:48 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) -+-; GFX10-GISEL-NEXT: global_store_dword v[2:3], v20, off offset:64 -+-; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX11-SDAG-LABEL: freeze_v17i32: -+-; GFX11-SDAG: ; %bb.0: -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX11-SDAG-NEXT: s_clause 0x4 -+-; GFX11-SDAG-NEXT: global_load_b32 v20, v[0:1], off offset:64 -+-; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:32 -+-; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off offset:48 -+-; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off -+-; GFX11-SDAG-NEXT: global_load_b128 v[16:19], v[0:1], off offset:16 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(4) -+-; GFX11-SDAG-NEXT: global_store_b32 v[2:3], v20, off offset:64 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:32 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off offset:48 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[16:19], off offset:16 -+-; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX11-GISEL-LABEL: freeze_v17i32: -+-; GFX11-GISEL: ; %bb.0: -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX11-GISEL-NEXT: s_clause 0x4 -+-; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off -+-; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 -+-; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 -+-; GFX11-GISEL-NEXT: global_load_b128 v[16:19], v[0:1], off offset:48 -+-; GFX11-GISEL-NEXT: global_load_b32 v0, v[0:1], off offset:64 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(4) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[16:19], off offset:48 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) -+-; GFX11-GISEL-NEXT: global_store_b32 v[2:3], v0, off offset:64 -+-; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] -+- %a = load <17 x i32>, ptr addrspace(1) %ptra, align 4 -+- %freeze = freeze <17 x i32> %a -+- store <17 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 -+- ret void -+-} -+- -+-define void @freeze_v18i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { -+-; GFX10-SDAG-LABEL: freeze_v18i32: -+-; GFX10-SDAG: ; %bb.0: -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX10-SDAG-NEXT: s_clause 0x4 -+-; GFX10-SDAG-NEXT: global_load_dwordx2 v[20:21], v[0:1], off offset:64 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:32 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:48 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:16 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(4) -+-; GFX10-SDAG-NEXT: global_store_dwordx2 v[2:3], v[20:21], off offset:64 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:32 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:48 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:16 -+-; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX10-GISEL-LABEL: freeze_v18i32: -+-; GFX10-GISEL: ; %bb.0: -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX10-GISEL-NEXT: s_clause 0x4 -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:48 -+-; GFX10-GISEL-NEXT: global_load_dwordx2 v[20:21], v[0:1], off offset:64 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(4) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:48 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) -+-; GFX10-GISEL-NEXT: global_store_dwordx2 v[2:3], v[20:21], off offset:64 -+-; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX11-SDAG-LABEL: freeze_v18i32: -+-; GFX11-SDAG: ; %bb.0: -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX11-SDAG-NEXT: s_clause 0x4 -+-; GFX11-SDAG-NEXT: global_load_b64 v[20:21], v[0:1], off offset:64 -+-; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:32 -+-; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off offset:48 -+-; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off -+-; GFX11-SDAG-NEXT: global_load_b128 v[16:19], v[0:1], off offset:16 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(4) -+-; GFX11-SDAG-NEXT: global_store_b64 v[2:3], v[20:21], off offset:64 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:32 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off offset:48 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[16:19], off offset:16 -+-; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX11-GISEL-LABEL: freeze_v18i32: -+-; GFX11-GISEL: ; %bb.0: -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX11-GISEL-NEXT: s_clause 0x4 -+-; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off -+-; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 -+-; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 -+-; GFX11-GISEL-NEXT: global_load_b128 v[16:19], v[0:1], off offset:48 -+-; GFX11-GISEL-NEXT: global_load_b64 v[0:1], v[0:1], off offset:64 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(4) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[16:19], off offset:48 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) -+-; GFX11-GISEL-NEXT: global_store_b64 v[2:3], v[0:1], off offset:64 -+-; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] -+- %a = load <18 x i32>, ptr addrspace(1) %ptra, align 4 -+- %freeze = freeze <18 x i32> %a -+- store <18 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 -+- ret void -+-} -+- -+-define void @freeze_v19i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { -+-; GFX10-SDAG-LABEL: freeze_v19i32: -+-; GFX10-SDAG: ; %bb.0: -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX10-SDAG-NEXT: s_clause 0x4 -+-; GFX10-SDAG-NEXT: global_load_dwordx3 v[20:22], v[0:1], off offset:64 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:32 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:48 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:16 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(4) -+-; GFX10-SDAG-NEXT: global_store_dwordx3 v[2:3], v[20:22], off offset:64 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:32 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:48 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:16 -+-; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX10-GISEL-LABEL: freeze_v19i32: -+-; GFX10-GISEL: ; %bb.0: -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX10-GISEL-NEXT: s_clause 0x4 -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:48 -+-; GFX10-GISEL-NEXT: global_load_dwordx3 v[20:22], v[0:1], off offset:64 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(4) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:48 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) -+-; GFX10-GISEL-NEXT: global_store_dwordx3 v[2:3], v[20:22], off offset:64 -+-; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX11-SDAG-LABEL: freeze_v19i32: -+-; GFX11-SDAG: ; %bb.0: -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX11-SDAG-NEXT: s_clause 0x4 -+-; GFX11-SDAG-NEXT: global_load_b96 v[20:22], v[0:1], off offset:64 -+-; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:32 -+-; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off offset:48 -+-; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off -+-; GFX11-SDAG-NEXT: global_load_b128 v[16:19], v[0:1], off offset:16 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(4) -+-; GFX11-SDAG-NEXT: global_store_b96 v[2:3], v[20:22], off offset:64 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:32 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off offset:48 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[16:19], off offset:16 -+-; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX11-GISEL-LABEL: freeze_v19i32: -+-; GFX11-GISEL: ; %bb.0: -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX11-GISEL-NEXT: s_clause 0x4 -+-; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off -+-; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 -+-; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 -+-; GFX11-GISEL-NEXT: global_load_b128 v[16:19], v[0:1], off offset:48 -+-; GFX11-GISEL-NEXT: global_load_b96 v[20:22], v[0:1], off offset:64 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(4) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[16:19], off offset:48 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) -+-; GFX11-GISEL-NEXT: global_store_b96 v[2:3], v[20:22], off offset:64 -+-; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] -+- %a = load <19 x i32>, ptr addrspace(1) %ptra, align 4 -+- %freeze = freeze <19 x i32> %a -+- store <19 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 -+- ret void -+-} -+- -+-define void @freeze_v20i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { -+-; GFX10-SDAG-LABEL: freeze_v20i32: -+-; GFX10-SDAG: ; %bb.0: -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX10-SDAG-NEXT: s_clause 0x4 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:64 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:32 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:48 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[16:19], v[0:1], off -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[20:23], v[0:1], off offset:16 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(4) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:64 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:32 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:48 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[16:19], off -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[20:23], off offset:16 -+-; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX10-GISEL-LABEL: freeze_v20i32: -+-; GFX10-GISEL: ; %bb.0: -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX10-GISEL-NEXT: s_clause 0x4 -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:48 -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[20:23], v[0:1], off offset:64 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(4) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:48 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[20:23], off offset:64 -+-; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX11-SDAG-LABEL: freeze_v20i32: -+-; GFX11-SDAG: ; %bb.0: -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX11-SDAG-NEXT: s_clause 0x4 -+-; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:64 -+-; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off offset:32 -+-; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off offset:48 -+-; GFX11-SDAG-NEXT: global_load_b128 v[16:19], v[0:1], off -+-; GFX11-SDAG-NEXT: global_load_b128 v[20:23], v[0:1], off offset:16 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(4) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:64 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off offset:32 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off offset:48 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[16:19], off -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[20:23], off offset:16 -+-; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX11-GISEL-LABEL: freeze_v20i32: -+-; GFX11-GISEL: ; %bb.0: -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX11-GISEL-NEXT: s_clause 0x4 -+-; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off -+-; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 -+-; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 -+-; GFX11-GISEL-NEXT: global_load_b128 v[16:19], v[0:1], off offset:48 -+-; GFX11-GISEL-NEXT: global_load_b128 v[20:23], v[0:1], off offset:64 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(4) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[16:19], off offset:48 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[20:23], off offset:64 -+-; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] -+- %a = load <20 x i32>, ptr addrspace(1) %ptra, align 4 -+- %freeze = freeze <20 x i32> %a -+- store <20 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 -+- ret void -+-} -+- -+-define void @freeze_v21i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { -+-; GFX10-SDAG-LABEL: freeze_v21i32: -+-; GFX10-SDAG: ; %bb.0: -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX10-SDAG-NEXT: s_clause 0x5 -+-; GFX10-SDAG-NEXT: global_load_dword v24, v[0:1], off offset:80 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:64 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:32 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:48 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[16:19], v[0:1], off -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[20:23], v[0:1], off offset:16 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(5) -+-; GFX10-SDAG-NEXT: global_store_dword v[2:3], v24, off offset:80 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(4) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:64 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:32 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:48 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[16:19], off -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[20:23], off offset:16 -+-; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX10-GISEL-LABEL: freeze_v21i32: -+-; GFX10-GISEL: ; %bb.0: -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX10-GISEL-NEXT: s_clause 0x5 -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:48 -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[20:23], v[0:1], off offset:64 -+-; GFX10-GISEL-NEXT: global_load_dword v24, v[0:1], off offset:80 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(5) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(4) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:48 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[20:23], off offset:64 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) -+-; GFX10-GISEL-NEXT: global_store_dword v[2:3], v24, off offset:80 -+-; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX11-SDAG-LABEL: freeze_v21i32: -+-; GFX11-SDAG: ; %bb.0: -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX11-SDAG-NEXT: s_clause 0x5 -+-; GFX11-SDAG-NEXT: global_load_b32 v24, v[0:1], off offset:80 -+-; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:64 -+-; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off offset:32 -+-; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off offset:48 -+-; GFX11-SDAG-NEXT: global_load_b128 v[16:19], v[0:1], off -+-; GFX11-SDAG-NEXT: global_load_b128 v[20:23], v[0:1], off offset:16 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(5) -+-; GFX11-SDAG-NEXT: global_store_b32 v[2:3], v24, off offset:80 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(4) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:64 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off offset:32 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off offset:48 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[16:19], off -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[20:23], off offset:16 -+-; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX11-GISEL-LABEL: freeze_v21i32: -+-; GFX11-GISEL: ; %bb.0: -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX11-GISEL-NEXT: s_clause 0x5 -+-; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off -+-; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 -+-; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 -+-; GFX11-GISEL-NEXT: global_load_b128 v[16:19], v[0:1], off offset:48 -+-; GFX11-GISEL-NEXT: global_load_b128 v[20:23], v[0:1], off offset:64 -+-; GFX11-GISEL-NEXT: global_load_b32 v0, v[0:1], off offset:80 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(5) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(4) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[16:19], off offset:48 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[20:23], off offset:64 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) -+-; GFX11-GISEL-NEXT: global_store_b32 v[2:3], v0, off offset:80 -+-; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] -+- %a = load <21 x i32>, ptr addrspace(1) %ptra, align 4 -+- %freeze = freeze <21 x i32> %a -+- store <21 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 -+- ret void -+-} -+- -+-define void @freeze_v22i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { -+-; GFX10-SDAG-LABEL: freeze_v22i32: -+-; GFX10-SDAG: ; %bb.0: -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX10-SDAG-NEXT: s_clause 0x5 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:64 -+-; GFX10-SDAG-NEXT: global_load_dwordx2 v[24:25], v[0:1], off offset:80 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:32 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:48 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[16:19], v[0:1], off -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[20:23], v[0:1], off offset:16 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(5) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:64 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(4) -+-; GFX10-SDAG-NEXT: global_store_dwordx2 v[2:3], v[24:25], off offset:80 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:32 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:48 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[16:19], off -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[20:23], off offset:16 -+-; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX10-GISEL-LABEL: freeze_v22i32: -+-; GFX10-GISEL: ; %bb.0: -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX10-GISEL-NEXT: s_clause 0x5 -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:48 -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[20:23], v[0:1], off offset:64 -+-; GFX10-GISEL-NEXT: global_load_dwordx2 v[24:25], v[0:1], off offset:80 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(5) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(4) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:48 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[20:23], off offset:64 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) -+-; GFX10-GISEL-NEXT: global_store_dwordx2 v[2:3], v[24:25], off offset:80 -+-; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX11-SDAG-LABEL: freeze_v22i32: -+-; GFX11-SDAG: ; %bb.0: -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX11-SDAG-NEXT: s_clause 0x5 -+-; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:64 -+-; GFX11-SDAG-NEXT: global_load_b64 v[24:25], v[0:1], off offset:80 -+-; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off offset:32 -+-; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off offset:48 -+-; GFX11-SDAG-NEXT: global_load_b128 v[16:19], v[0:1], off -+-; GFX11-SDAG-NEXT: global_load_b128 v[20:23], v[0:1], off offset:16 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(5) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:64 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(4) -+-; GFX11-SDAG-NEXT: global_store_b64 v[2:3], v[24:25], off offset:80 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off offset:32 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off offset:48 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[16:19], off -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[20:23], off offset:16 -+-; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX11-GISEL-LABEL: freeze_v22i32: -+-; GFX11-GISEL: ; %bb.0: -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX11-GISEL-NEXT: s_clause 0x5 -+-; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off -+-; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 -+-; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 -+-; GFX11-GISEL-NEXT: global_load_b128 v[16:19], v[0:1], off offset:48 -+-; GFX11-GISEL-NEXT: global_load_b128 v[20:23], v[0:1], off offset:64 -+-; GFX11-GISEL-NEXT: global_load_b64 v[0:1], v[0:1], off offset:80 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(5) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(4) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[16:19], off offset:48 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[20:23], off offset:64 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) -+-; GFX11-GISEL-NEXT: global_store_b64 v[2:3], v[0:1], off offset:80 -+-; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] -+- %a = load <22 x i32>, ptr addrspace(1) %ptra, align 4 -+- %freeze = freeze <22 x i32> %a -+- store <22 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 -+- ret void -+-} -+- -+-define void @freeze_v30i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { -+-; GFX10-SDAG-LABEL: freeze_v30i32: -+-; GFX10-SDAG: ; %bb.0: -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX10-SDAG-NEXT: s_clause 0x7 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:96 -+-; GFX10-SDAG-NEXT: global_load_dwordx2 v[32:33], v[0:1], off offset:112 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:64 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:80 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:32 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[20:23], v[0:1], off offset:48 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[24:27], v[0:1], off -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[28:31], v[0:1], off offset:16 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(7) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:96 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(6) -+-; GFX10-SDAG-NEXT: global_store_dwordx2 v[2:3], v[32:33], off offset:112 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(5) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:64 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(4) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:80 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:32 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[20:23], off offset:48 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[24:27], off -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[28:31], off offset:16 -+-; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX10-GISEL-LABEL: freeze_v30i32: -+-; GFX10-GISEL: ; %bb.0: -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX10-GISEL-NEXT: s_clause 0x7 -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:48 -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[20:23], v[0:1], off offset:64 -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[24:27], v[0:1], off offset:80 -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[28:31], v[0:1], off offset:96 -+-; GFX10-GISEL-NEXT: global_load_dwordx2 v[32:33], v[0:1], off offset:112 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(7) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(6) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(5) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(4) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:48 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[20:23], off offset:64 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[24:27], off offset:80 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[28:31], off offset:96 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) -+-; GFX10-GISEL-NEXT: global_store_dwordx2 v[2:3], v[32:33], off offset:112 -+-; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX11-SDAG-LABEL: freeze_v30i32: -+-; GFX11-SDAG: ; %bb.0: -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX11-SDAG-NEXT: s_clause 0x7 -+-; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:96 -+-; GFX11-SDAG-NEXT: global_load_b64 v[32:33], v[0:1], off offset:112 -+-; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off offset:64 -+-; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off offset:80 -+-; GFX11-SDAG-NEXT: global_load_b128 v[16:19], v[0:1], off offset:32 -+-; GFX11-SDAG-NEXT: global_load_b128 v[20:23], v[0:1], off offset:48 -+-; GFX11-SDAG-NEXT: global_load_b128 v[24:27], v[0:1], off -+-; GFX11-SDAG-NEXT: global_load_b128 v[28:31], v[0:1], off offset:16 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(7) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:96 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(6) -+-; GFX11-SDAG-NEXT: global_store_b64 v[2:3], v[32:33], off offset:112 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(5) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off offset:64 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(4) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off offset:80 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[16:19], off offset:32 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[20:23], off offset:48 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[24:27], off -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[28:31], off offset:16 -+-; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX11-GISEL-LABEL: freeze_v30i32: -+-; GFX11-GISEL: ; %bb.0: -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX11-GISEL-NEXT: s_clause 0x7 -+-; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off -+-; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 -+-; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 -+-; GFX11-GISEL-NEXT: global_load_b128 v[16:19], v[0:1], off offset:48 -+-; GFX11-GISEL-NEXT: global_load_b128 v[20:23], v[0:1], off offset:64 -+-; GFX11-GISEL-NEXT: global_load_b128 v[24:27], v[0:1], off offset:80 -+-; GFX11-GISEL-NEXT: global_load_b128 v[28:31], v[0:1], off offset:96 -+-; GFX11-GISEL-NEXT: global_load_b64 v[0:1], v[0:1], off offset:112 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(7) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(6) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(5) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(4) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[16:19], off offset:48 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[20:23], off offset:64 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[24:27], off offset:80 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[28:31], off offset:96 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) -+-; GFX11-GISEL-NEXT: global_store_b64 v[2:3], v[0:1], off offset:112 -+-; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] -+- %a = load <30 x i32>, ptr addrspace(1) %ptra, align 4 -+- %freeze = freeze <30 x i32> %a -+- store <30 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 -+- ret void -+-} -+- -+-define void @freeze_v31i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { -+-; GFX10-SDAG-LABEL: freeze_v31i32: -+-; GFX10-SDAG: ; %bb.0: -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX10-SDAG-NEXT: s_clause 0x7 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:96 -+-; GFX10-SDAG-NEXT: global_load_dwordx3 v[32:34], v[0:1], off offset:112 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:64 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:80 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:32 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[20:23], v[0:1], off offset:48 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[24:27], v[0:1], off -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[28:31], v[0:1], off offset:16 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(7) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:96 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(6) -+-; GFX10-SDAG-NEXT: global_store_dwordx3 v[2:3], v[32:34], off offset:112 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(5) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:64 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(4) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:80 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:32 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[20:23], off offset:48 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[24:27], off -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[28:31], off offset:16 -+-; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX10-GISEL-LABEL: freeze_v31i32: -+-; GFX10-GISEL: ; %bb.0: -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX10-GISEL-NEXT: s_clause 0x7 -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:48 -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[20:23], v[0:1], off offset:64 -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[24:27], v[0:1], off offset:80 -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[28:31], v[0:1], off offset:96 -+-; GFX10-GISEL-NEXT: global_load_dwordx3 v[32:34], v[0:1], off offset:112 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(7) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(6) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(5) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(4) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:48 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[20:23], off offset:64 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[24:27], off offset:80 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[28:31], off offset:96 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) -+-; GFX10-GISEL-NEXT: global_store_dwordx3 v[2:3], v[32:34], off offset:112 -+-; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX11-SDAG-LABEL: freeze_v31i32: -+-; GFX11-SDAG: ; %bb.0: -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX11-SDAG-NEXT: s_clause 0x7 -+-; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:96 -+-; GFX11-SDAG-NEXT: global_load_b96 v[32:34], v[0:1], off offset:112 -+-; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off offset:64 -+-; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off offset:80 -+-; GFX11-SDAG-NEXT: global_load_b128 v[16:19], v[0:1], off offset:32 -+-; GFX11-SDAG-NEXT: global_load_b128 v[20:23], v[0:1], off offset:48 -+-; GFX11-SDAG-NEXT: global_load_b128 v[24:27], v[0:1], off -+-; GFX11-SDAG-NEXT: global_load_b128 v[28:31], v[0:1], off offset:16 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(7) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:96 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(6) -+-; GFX11-SDAG-NEXT: global_store_b96 v[2:3], v[32:34], off offset:112 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(5) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off offset:64 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(4) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off offset:80 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[16:19], off offset:32 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[20:23], off offset:48 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[24:27], off -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[28:31], off offset:16 -+-; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX11-GISEL-LABEL: freeze_v31i32: -+-; GFX11-GISEL: ; %bb.0: -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX11-GISEL-NEXT: s_clause 0x7 -+-; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off -+-; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 -+-; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 -+-; GFX11-GISEL-NEXT: global_load_b128 v[16:19], v[0:1], off offset:48 -+-; GFX11-GISEL-NEXT: global_load_b128 v[20:23], v[0:1], off offset:64 -+-; GFX11-GISEL-NEXT: global_load_b128 v[24:27], v[0:1], off offset:80 -+-; GFX11-GISEL-NEXT: global_load_b128 v[28:31], v[0:1], off offset:96 -+-; GFX11-GISEL-NEXT: global_load_b96 v[32:34], v[0:1], off offset:112 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(7) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(6) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(5) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(4) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[16:19], off offset:48 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[20:23], off offset:64 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[24:27], off offset:80 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[28:31], off offset:96 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) -+-; GFX11-GISEL-NEXT: global_store_b96 v[2:3], v[32:34], off offset:112 -+-; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] -+- %a = load <31 x i32>, ptr addrspace(1) %ptra, align 4 -+- %freeze = freeze <31 x i32> %a -+- store <31 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 -+- ret void -+-} -+- -+-define void @freeze_v32i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { -+-; GFX10-SDAG-LABEL: freeze_v32i32: -+-; GFX10-SDAG: ; %bb.0: -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX10-SDAG-NEXT: s_clause 0x7 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:96 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:112 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:64 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:80 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[20:23], v[0:1], off offset:32 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[24:27], v[0:1], off offset:48 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[28:31], v[0:1], off -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[32:35], v[0:1], off offset:16 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(7) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:96 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(6) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:112 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(5) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:64 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(4) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:80 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[20:23], off offset:32 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[24:27], off offset:48 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[28:31], off -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[32:35], off offset:16 -+-; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX10-GISEL-LABEL: freeze_v32i32: -+-; GFX10-GISEL: ; %bb.0: -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX10-GISEL-NEXT: s_clause 0x7 -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:48 -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[20:23], v[0:1], off offset:64 -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[24:27], v[0:1], off offset:80 -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[28:31], v[0:1], off offset:96 -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[32:35], v[0:1], off offset:112 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(7) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(6) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(5) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(4) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:48 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[20:23], off offset:64 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[24:27], off offset:80 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[28:31], off offset:96 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[32:35], off offset:112 -+-; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX11-SDAG-LABEL: freeze_v32i32: -+-; GFX11-SDAG: ; %bb.0: -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX11-SDAG-NEXT: s_clause 0x7 -+-; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:96 -+-; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off offset:112 -+-; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off offset:64 -+-; GFX11-SDAG-NEXT: global_load_b128 v[16:19], v[0:1], off offset:80 -+-; GFX11-SDAG-NEXT: global_load_b128 v[20:23], v[0:1], off offset:32 -+-; GFX11-SDAG-NEXT: global_load_b128 v[24:27], v[0:1], off offset:48 -+-; GFX11-SDAG-NEXT: global_load_b128 v[28:31], v[0:1], off -+-; GFX11-SDAG-NEXT: global_load_b128 v[32:35], v[0:1], off offset:16 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(7) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:96 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(6) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off offset:112 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(5) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off offset:64 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(4) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[16:19], off offset:80 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[20:23], off offset:32 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[24:27], off offset:48 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[28:31], off -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[32:35], off offset:16 -+-; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX11-GISEL-LABEL: freeze_v32i32: -+-; GFX11-GISEL: ; %bb.0: -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX11-GISEL-NEXT: s_clause 0x7 -+-; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off -+-; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 -+-; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 -+-; GFX11-GISEL-NEXT: global_load_b128 v[16:19], v[0:1], off offset:48 -+-; GFX11-GISEL-NEXT: global_load_b128 v[20:23], v[0:1], off offset:64 -+-; GFX11-GISEL-NEXT: global_load_b128 v[24:27], v[0:1], off offset:80 -+-; GFX11-GISEL-NEXT: global_load_b128 v[28:31], v[0:1], off offset:96 -+-; GFX11-GISEL-NEXT: global_load_b128 v[32:35], v[0:1], off offset:112 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(7) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(6) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(5) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(4) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[16:19], off offset:48 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[20:23], off offset:64 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[24:27], off offset:80 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[28:31], off offset:96 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[32:35], off offset:112 -+-; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] -+- %a = load <32 x i32>, ptr addrspace(1) %ptra, align 4 -+- %freeze = freeze <32 x i32> %a -+- store <32 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 -+- ret void -+-} -+- -+-define void @freeze_i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { -+-; GFX10-LABEL: freeze_i32: -+-; GFX10: ; %bb.0: -+-; GFX10-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX10-NEXT: global_load_dword v0, v[0:1], off -+-; GFX10-NEXT: s_waitcnt vmcnt(0) -+-; GFX10-NEXT: global_store_dword v[2:3], v0, off -+-; GFX10-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX11-LABEL: freeze_i32: -+-; GFX11: ; %bb.0: -+-; GFX11-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX11-NEXT: global_load_b32 v0, v[0:1], off -+-; GFX11-NEXT: s_waitcnt vmcnt(0) -+-; GFX11-NEXT: global_store_b32 v[2:3], v0, off -+-; GFX11-NEXT: s_setpc_b64 s[30:31] -+- %a = load i32, ptr addrspace(1) %ptra, align 4 -+- %freeze = freeze i32 %a -+- store i32 %freeze, ptr addrspace(1) %ptrb, align 4 -+- ret void -+-} -+- -+-define void @freeze_i64(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { -+-; GFX10-LABEL: freeze_i64: -+-; GFX10: ; %bb.0: -+-; GFX10-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX10-NEXT: global_load_dwordx2 v[0:1], v[0:1], off -+-; GFX10-NEXT: s_waitcnt vmcnt(0) -+-; GFX10-NEXT: global_store_dwordx2 v[2:3], v[0:1], off -+-; GFX10-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX11-LABEL: freeze_i64: -+-; GFX11: ; %bb.0: -+-; GFX11-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX11-NEXT: global_load_b64 v[0:1], v[0:1], off -+-; GFX11-NEXT: s_waitcnt vmcnt(0) -+-; GFX11-NEXT: global_store_b64 v[2:3], v[0:1], off -+-; GFX11-NEXT: s_setpc_b64 s[30:31] -+- %a = load i64, ptr addrspace(1) %ptra, align 4 -+- %freeze = freeze i64 %a -+- store i64 %freeze, ptr addrspace(1) %ptrb, align 4 -+- ret void -+-} -+- -+-define void @freeze_float(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { -+-; GFX10-LABEL: freeze_float: -+-; GFX10: ; %bb.0: -+-; GFX10-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX10-NEXT: global_load_dword v0, v[0:1], off -+-; GFX10-NEXT: s_waitcnt vmcnt(0) -+-; GFX10-NEXT: global_store_dword v[2:3], v0, off -+-; GFX10-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX11-LABEL: freeze_float: -+-; GFX11: ; %bb.0: -+-; GFX11-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX11-NEXT: global_load_b32 v0, v[0:1], off -+-; GFX11-NEXT: s_waitcnt vmcnt(0) -+-; GFX11-NEXT: global_store_b32 v[2:3], v0, off -+-; GFX11-NEXT: s_setpc_b64 s[30:31] -+- %a = load float, ptr addrspace(1) %ptra, align 4 -+- %freeze = freeze float %a -+- store float %freeze, ptr addrspace(1) %ptrb, align 4 -+- ret void -+-} -+- -+-define void @freeze_i128(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { -+-; GFX10-LABEL: freeze_i128: -+-; GFX10: ; %bb.0: -+-; GFX10-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX10-NEXT: global_load_dwordx4 v[4:7], v[0:1], off -+-; GFX10-NEXT: s_waitcnt vmcnt(0) -+-; GFX10-NEXT: global_store_dwordx4 v[2:3], v[4:7], off -+-; GFX10-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX11-LABEL: freeze_i128: -+-; GFX11: ; %bb.0: -+-; GFX11-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX11-NEXT: global_load_b128 v[4:7], v[0:1], off -+-; GFX11-NEXT: s_waitcnt vmcnt(0) -+-; GFX11-NEXT: global_store_b128 v[2:3], v[4:7], off -+-; GFX11-NEXT: s_setpc_b64 s[30:31] -+- %a = load i128, ptr addrspace(1) %ptra, align 4 -+- %freeze = freeze i128 %a -+- store i128 %freeze, ptr addrspace(1) %ptrb, align 4 -+- ret void -+-} -+- -+-define void @freeze_i256(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { -+-; GFX10-SDAG-LABEL: freeze_i256: -+-; GFX10-SDAG: ; %bb.0: -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX10-SDAG-NEXT: s_clause 0x1 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:16 -+-; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:16 -+-; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) -+-; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off -+-; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX10-GISEL-LABEL: freeze_i256: -+-; GFX10-GISEL: ; %bb.0: -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX10-GISEL-NEXT: s_clause 0x1 -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off -+-; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off -+-; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) -+-; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 -+-; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX11-SDAG-LABEL: freeze_i256: -+-; GFX11-SDAG: ; %bb.0: -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX11-SDAG-NEXT: s_clause 0x1 -+-; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:16 -+-; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:16 -+-; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) -+-; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off -+-; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] -+-; -+-; GFX11-GISEL-LABEL: freeze_i256: -+-; GFX11-GISEL: ; %bb.0: -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -+-; GFX11-GISEL-NEXT: s_clause 0x1 -+-; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off -+-; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off -+-; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) -+-; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 -+-; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] -+- %a = load i256, ptr addrspace(1) %ptra, align 4 -+- %freeze = freeze i256 %a -+- store i256 %freeze, ptr addrspace(1) %ptrb, align 4 -+- ret void -+-} -+diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/AMDGPU/GlobalISel/inst-select-unmerge-values.mir b/llvm/test/CodeGen/AMDGPU/GlobalISel/inst-select-unmerge-values.mir -+--- a/llvm/test/CodeGen/AMDGPU/GlobalISel/inst-select-unmerge-values.mir -++++ b/llvm/test/CodeGen/AMDGPU/GlobalISel/inst-select-unmerge-values.mir -+@@ -171,9 +171,11 @@ -+ ; GCN-LABEL: name: test_unmerge_values_s_s64_s_s64_s64_s_s192 -+ ; GCN: liveins: $sgpr0_sgpr1_sgpr2_sgpr3 -+ ; GCN-NEXT: {{ $}} -+- ; GCN-NEXT: [[DEF:%[0-9]+]]:sgpr(s192) = G_IMPLICIT_DEF -+- ; GCN-NEXT: [[UV:%[0-9]+]]:sgpr(s64), [[UV1:%[0-9]+]]:sgpr(s64), [[UV2:%[0-9]+]]:sgpr(s64) = G_UNMERGE_VALUES [[DEF]](s192) -+- ; GCN-NEXT: S_ENDPGM 0, implicit [[UV]](s64), implicit [[UV1]](s64), implicit [[UV2]](s64) -++ ; GCN-NEXT: [[DEF:%[0-9]+]]:sgpr_192 = IMPLICIT_DEF -++ ; GCN-NEXT: [[COPY:%[0-9]+]]:sreg_64 = COPY [[DEF]].sub0_sub1 -++ ; GCN-NEXT: [[COPY1:%[0-9]+]]:sreg_64 = COPY [[DEF]].sub2_sub3 -++ ; GCN-NEXT: [[COPY2:%[0-9]+]]:sreg_64 = COPY [[DEF]].sub4_sub5 -++ ; GCN-NEXT: S_ENDPGM 0, implicit [[COPY]], implicit [[COPY1]], implicit [[COPY2]] -+ %0:sgpr(s192) = G_IMPLICIT_DEF -+ %1:sgpr(s64), %2:sgpr(s64), %3:sgpr(s64) = G_UNMERGE_VALUES %0 -+ S_ENDPGM 0, implicit %1, implicit %2, implicit %3 -+@@ -292,11 +294,11 @@ -+ ; GCN-NEXT: [[CONCAT_VECTORS:%[0-9]+]]:sgpr_384(<12 x s32>) = G_CONCAT_VECTORS [[COPY]](<3 x s32>), [[COPY1]](<3 x s32>), [[COPY2]](<3 x s32>), [[COPY3]](<3 x s32>) -+ ; GCN-NEXT: [[COPY4:%[0-9]+]]:sgpr_96(<3 x s32>) = COPY [[CONCAT_VECTORS]].sub0_sub1_sub2(<12 x s32>) -+ ; GCN-NEXT: [[COPY5:%[0-9]+]]:sgpr_96(<3 x s32>) = COPY [[CONCAT_VECTORS]].sub3_sub4_sub5(<12 x s32>) -+- ; GCN-NEXT: [[COPY4:%[0-9]+]]:sgpr_96(<3 x s32>), [[COPY5:%[0-9]+]]:sgpr_96(<3 x s32>), [[UV:%[0-9]+]]:sgpr_96(<3 x s32>), [[UV1:%[0-9]+]]:sgpr_96(<3 x s32>) = G_UNMERGE_VALUES [[CONCAT_VECTORS]](<12 x s32>) -+- ; GCN-NEXT: $sgpr0_sgpr1_sgpr2 = COPY [[COPY4]](<3 x s32>) -+- ; GCN-NEXT: $sgpr4_sgpr5_sgpr6 = COPY [[COPY5]](<3 x s32>) -+- ; GCN-NEXT: $sgpr8_sgpr9_sgpr10 = COPY [[UV]](<3 x s32>) -+- ; GCN-NEXT: $sgpr12_sgpr13_sgpr14 = COPY [[UV1]](<3 x s32>) -++ ; GCN-NEXT: [[UV:%[0-9]+]]:sgpr_96(<3 x s32>), [[UV1:%[0-9]+]]:sgpr_96(<3 x s32>), [[UV2:%[0-9]+]]:sgpr_96(<3 x s32>), [[UV3:%[0-9]+]]:sgpr_96(<3 x s32>) = G_UNMERGE_VALUES [[CONCAT_VECTORS]](<12 x s32>) -++ ; GCN-NEXT: $sgpr0_sgpr1_sgpr2 = COPY [[UV]](<3 x s32>) -++ ; GCN-NEXT: $sgpr4_sgpr5_sgpr6 = COPY [[UV1]](<3 x s32>) -++ ; GCN-NEXT: $sgpr8_sgpr9_sgpr10 = COPY [[UV2]](<3 x s32>) -++ ; GCN-NEXT: $sgpr12_sgpr13_sgpr14 = COPY [[UV3]](<3 x s32>) -+ %0:sgpr(<3 x s32>) = COPY $sgpr0_sgpr1_sgpr2 -+ %1:sgpr(<3 x s32>) = COPY $sgpr4_sgpr5_sgpr6 -+ %2:sgpr(<3 x s32>) = COPY $sgpr8_sgpr9_sgpr10 -+diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-freeze.mir b/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-freeze.mir -+--- a/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-freeze.mir -++++ b/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-freeze.mir -+@@ -171,8 +171,12 @@ -+ -+ ; CHECK-LABEL: name: test_freeze_s448 -+ ; CHECK: [[COPY:%[0-9]+]]:_(s512) = COPY $vgpr0_vgpr1_vgpr2_vgpr3_vgpr4_vgpr5_vgpr6_vgpr7_vgpr8_vgpr9_vgpr10_vgpr11_vgpr12_vgpr13_vgpr14_vgpr15 -+- ; CHECK-NEXT: [[FREEZE:%[0-9]+]]:_(s512) = G_FREEZE [[COPY]] -+- ; CHECK-NEXT: $vgpr0_vgpr1_vgpr2_vgpr3_vgpr4_vgpr5_vgpr6_vgpr7_vgpr8_vgpr9_vgpr10_vgpr11_vgpr12_vgpr13_vgpr14_vgpr15 = COPY [[FREEZE]](s512) -++ ; CHECK-NEXT: [[TRUNC:%[0-9]+]]:_(s448) = G_TRUNC [[COPY]](s512) -++ ; CHECK-NEXT: [[FREEZE:%[0-9]+]]:_(s448) = G_FREEZE [[TRUNC]] -++ ; CHECK-NEXT: [[UV:%[0-9]+]]:_(s64), [[UV1:%[0-9]+]]:_(s64), [[UV2:%[0-9]+]]:_(s64), [[UV3:%[0-9]+]]:_(s64), [[UV4:%[0-9]+]]:_(s64), [[UV5:%[0-9]+]]:_(s64), [[UV6:%[0-9]+]]:_(s64) = G_UNMERGE_VALUES [[FREEZE]](s448) -++ ; CHECK-NEXT: [[DEF:%[0-9]+]]:_(s64) = G_IMPLICIT_DEF -++ ; CHECK-NEXT: [[MV:%[0-9]+]]:_(s512) = G_MERGE_VALUES [[UV]](s64), [[UV1]](s64), [[UV2]](s64), [[UV3]](s64), [[UV4]](s64), [[UV5]](s64), [[UV6]](s64), [[DEF]](s64) -++ ; CHECK-NEXT: $vgpr0_vgpr1_vgpr2_vgpr3_vgpr4_vgpr5_vgpr6_vgpr7_vgpr8_vgpr9_vgpr10_vgpr11_vgpr12_vgpr13_vgpr14_vgpr15 = COPY [[MV]](s512) -+ %0:_(s512) = COPY $vgpr0_vgpr1_vgpr2_vgpr3_vgpr4_vgpr5_vgpr6_vgpr7_vgpr8_vgpr9_vgpr10_vgpr11_vgpr12_vgpr13_vgpr14_vgpr15 -+ %1:_(s448) = G_TRUNC %0 -+ %2:_(s448) = G_FREEZE %1 -+@@ -395,12 +399,14 @@ -+ bb.0: -+ -+ ; CHECK-LABEL: name: test_freeze_v33s32 -+- ; CHECK: [[DEF:%[0-9]+]]:_(<32 x s32>) = G_IMPLICIT_DEF -++ ; CHECK: [[DEF:%[0-9]+]]:_(<16 x s32>) = G_IMPLICIT_DEF -+ ; CHECK-NEXT: [[DEF1:%[0-9]+]]:_(s32) = G_IMPLICIT_DEF -+- ; CHECK-NEXT: [[FREEZE:%[0-9]+]]:_(<32 x s32>) = G_FREEZE [[DEF]] -+- ; CHECK-NEXT: [[FREEZE1:%[0-9]+]]:_(s32) = G_FREEZE [[DEF1]] -+- ; CHECK-NEXT: [[UV:%[0-9]+]]:_(s32), [[UV1:%[0-9]+]]:_(s32), [[UV2:%[0-9]+]]:_(s32), [[UV3:%[0-9]+]]:_(s32), [[UV4:%[0-9]+]]:_(s32), [[UV5:%[0-9]+]]:_(s32), [[UV6:%[0-9]+]]:_(s32), [[UV7:%[0-9]+]]:_(s32), [[UV8:%[0-9]+]]:_(s32), [[UV9:%[0-9]+]]:_(s32), [[UV10:%[0-9]+]]:_(s32), [[UV11:%[0-9]+]]:_(s32), [[UV12:%[0-9]+]]:_(s32), [[UV13:%[0-9]+]]:_(s32), [[UV14:%[0-9]+]]:_(s32), [[UV15:%[0-9]+]]:_(s32), [[UV16:%[0-9]+]]:_(s32), [[UV17:%[0-9]+]]:_(s32), [[UV18:%[0-9]+]]:_(s32), [[UV19:%[0-9]+]]:_(s32), [[UV20:%[0-9]+]]:_(s32), [[UV21:%[0-9]+]]:_(s32), [[UV22:%[0-9]+]]:_(s32), [[UV23:%[0-9]+]]:_(s32), [[UV24:%[0-9]+]]:_(s32), [[UV25:%[0-9]+]]:_(s32), [[UV26:%[0-9]+]]:_(s32), [[UV27:%[0-9]+]]:_(s32), [[UV28:%[0-9]+]]:_(s32), [[UV29:%[0-9]+]]:_(s32), [[UV30:%[0-9]+]]:_(s32), [[UV31:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[FREEZE]](<32 x s32>) -+- ; CHECK-NEXT: [[BUILD_VECTOR:%[0-9]+]]:_(<33 x s32>) = G_BUILD_VECTOR [[UV]](s32), [[UV1]](s32), [[UV2]](s32), [[UV3]](s32), [[UV4]](s32), [[UV5]](s32), [[UV6]](s32), [[UV7]](s32), [[UV8]](s32), [[UV9]](s32), [[UV10]](s32), [[UV11]](s32), [[UV12]](s32), [[UV13]](s32), [[UV14]](s32), [[UV15]](s32), [[UV16]](s32), [[UV17]](s32), [[UV18]](s32), [[UV19]](s32), [[UV20]](s32), [[UV21]](s32), [[UV22]](s32), [[UV23]](s32), [[UV24]](s32), [[UV25]](s32), [[UV26]](s32), [[UV27]](s32), [[UV28]](s32), [[UV29]](s32), [[UV30]](s32), [[UV31]](s32), [[FREEZE1]](s32) -++ ; CHECK-NEXT: [[FREEZE:%[0-9]+]]:_(<16 x s32>) = G_FREEZE [[DEF]] -++ ; CHECK-NEXT: [[FREEZE1:%[0-9]+]]:_(<16 x s32>) = G_FREEZE [[DEF]] -++ ; CHECK-NEXT: [[FREEZE2:%[0-9]+]]:_(s32) = G_FREEZE [[DEF1]] -++ ; CHECK-NEXT: [[UV:%[0-9]+]]:_(s32), [[UV1:%[0-9]+]]:_(s32), [[UV2:%[0-9]+]]:_(s32), [[UV3:%[0-9]+]]:_(s32), [[UV4:%[0-9]+]]:_(s32), [[UV5:%[0-9]+]]:_(s32), [[UV6:%[0-9]+]]:_(s32), [[UV7:%[0-9]+]]:_(s32), [[UV8:%[0-9]+]]:_(s32), [[UV9:%[0-9]+]]:_(s32), [[UV10:%[0-9]+]]:_(s32), [[UV11:%[0-9]+]]:_(s32), [[UV12:%[0-9]+]]:_(s32), [[UV13:%[0-9]+]]:_(s32), [[UV14:%[0-9]+]]:_(s32), [[UV15:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[FREEZE]](<16 x s32>) -++ ; CHECK-NEXT: [[UV16:%[0-9]+]]:_(s32), [[UV17:%[0-9]+]]:_(s32), [[UV18:%[0-9]+]]:_(s32), [[UV19:%[0-9]+]]:_(s32), [[UV20:%[0-9]+]]:_(s32), [[UV21:%[0-9]+]]:_(s32), [[UV22:%[0-9]+]]:_(s32), [[UV23:%[0-9]+]]:_(s32), [[UV24:%[0-9]+]]:_(s32), [[UV25:%[0-9]+]]:_(s32), [[UV26:%[0-9]+]]:_(s32), [[UV27:%[0-9]+]]:_(s32), [[UV28:%[0-9]+]]:_(s32), [[UV29:%[0-9]+]]:_(s32), [[UV30:%[0-9]+]]:_(s32), [[UV31:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[FREEZE1]](<16 x s32>) -++ ; CHECK-NEXT: [[BUILD_VECTOR:%[0-9]+]]:_(<33 x s32>) = G_BUILD_VECTOR [[UV]](s32), [[UV1]](s32), [[UV2]](s32), [[UV3]](s32), [[UV4]](s32), [[UV5]](s32), [[UV6]](s32), [[UV7]](s32), [[UV8]](s32), [[UV9]](s32), [[UV10]](s32), [[UV11]](s32), [[UV12]](s32), [[UV13]](s32), [[UV14]](s32), [[UV15]](s32), [[UV16]](s32), [[UV17]](s32), [[UV18]](s32), [[UV19]](s32), [[UV20]](s32), [[UV21]](s32), [[UV22]](s32), [[UV23]](s32), [[UV24]](s32), [[UV25]](s32), [[UV26]](s32), [[UV27]](s32), [[UV28]](s32), [[UV29]](s32), [[UV30]](s32), [[UV31]](s32), [[FREEZE2]](s32) -+ ; CHECK-NEXT: S_NOP 0, implicit [[BUILD_VECTOR]](<33 x s32>) -+ %0:_(<33 x s32>) = G_IMPLICIT_DEF -+ %1:_(<33 x s32>) = G_FREEZE %0 -+@@ -413,10 +419,12 @@ -+ bb.0: -+ -+ ; CHECK-LABEL: name: test_freeze_v64s32 -+- ; CHECK: [[DEF:%[0-9]+]]:_(<32 x s32>) = G_IMPLICIT_DEF -+- ; CHECK-NEXT: [[FREEZE:%[0-9]+]]:_(<32 x s32>) = G_FREEZE [[DEF]] -+- ; CHECK-NEXT: [[FREEZE1:%[0-9]+]]:_(<32 x s32>) = G_FREEZE [[DEF]] -+- ; CHECK-NEXT: [[CONCAT_VECTORS:%[0-9]+]]:_(<64 x s32>) = G_CONCAT_VECTORS [[FREEZE]](<32 x s32>), [[FREEZE1]](<32 x s32>) -++ ; CHECK: [[DEF:%[0-9]+]]:_(<16 x s32>) = G_IMPLICIT_DEF -++ ; CHECK-NEXT: [[FREEZE:%[0-9]+]]:_(<16 x s32>) = G_FREEZE [[DEF]] -++ ; CHECK-NEXT: [[FREEZE1:%[0-9]+]]:_(<16 x s32>) = G_FREEZE [[DEF]] -++ ; CHECK-NEXT: [[FREEZE2:%[0-9]+]]:_(<16 x s32>) = G_FREEZE [[DEF]] -++ ; CHECK-NEXT: [[FREEZE3:%[0-9]+]]:_(<16 x s32>) = G_FREEZE [[DEF]] -++ ; CHECK-NEXT: [[CONCAT_VECTORS:%[0-9]+]]:_(<64 x s32>) = G_CONCAT_VECTORS [[FREEZE]](<16 x s32>), [[FREEZE1]](<16 x s32>), [[FREEZE2]](<16 x s32>), [[FREEZE3]](<16 x s32>) -+ ; CHECK-NEXT: S_NOP 0, implicit [[CONCAT_VECTORS]](<64 x s32>) -+ %0:_(<64 x s32>) = G_IMPLICIT_DEF -+ %1:_(<64 x s32>) = G_FREEZE %0 -+diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-implicit-def.mir b/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-implicit-def.mir -+--- a/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-implicit-def.mir -++++ b/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-implicit-def.mir -+@@ -135,9 +135,8 @@ -+ bb.0: -+ -+ ; CHECK-LABEL: name: test_implicit_def_s448 -+- ; CHECK: [[DEF:%[0-9]+]]:_(s512) = G_IMPLICIT_DEF -+- ; CHECK-NEXT: [[TRUNC:%[0-9]+]]:_(s448) = G_TRUNC [[DEF]](s512) -+- ; CHECK-NEXT: [[EXTRACT:%[0-9]+]]:_(s32) = G_EXTRACT [[TRUNC]](s448), 0 -++ ; CHECK: [[DEF:%[0-9]+]]:_(s448) = G_IMPLICIT_DEF -++ ; CHECK-NEXT: [[EXTRACT:%[0-9]+]]:_(s32) = G_EXTRACT [[DEF]](s448), 0 -+ ; CHECK-NEXT: $vgpr0 = COPY [[EXTRACT]](s32) -+ %0:_(s448) = G_IMPLICIT_DEF -+ %1:_(s32) = G_EXTRACT %0, 0 -+@@ -297,6 +296,18 @@ -+ ... -+ -+ --- -++name: test_implicit_def_v17s32 -++body: | -++ bb.0: -++ -++ ; CHECK-LABEL: name: test_implicit_def_v17s32 -++ ; CHECK: [[DEF:%[0-9]+]]:_(<17 x s32>) = G_IMPLICIT_DEF -++ ; CHECK-NEXT: S_NOP 0, implicit [[DEF]](<17 x s32>) -++ %0:_(<17 x s32>) = G_IMPLICIT_DEF -++ S_NOP 0, implicit %0 -++... -++ -++--- -+ name: test_implicit_def_v32s32 -+ body: | -+ bb.0: -+@@ -317,9 +328,9 @@ -+ ; CHECK-LABEL: name: test_implicit_def_v33s32 -+ ; CHECK: liveins: $vgpr0_vgpr1 -+ ; CHECK-NEXT: {{ $}} -+- ; CHECK-NEXT: [[DEF:%[0-9]+]]:_(<32 x s32>) = G_IMPLICIT_DEF -++ ; CHECK-NEXT: [[DEF:%[0-9]+]]:_(<16 x s32>) = G_IMPLICIT_DEF -+ ; CHECK-NEXT: [[DEF1:%[0-9]+]]:_(s32) = G_IMPLICIT_DEF -+- ; CHECK-NEXT: [[UV:%[0-9]+]]:_(s32), [[UV1:%[0-9]+]]:_(s32), [[UV2:%[0-9]+]]:_(s32), [[UV3:%[0-9]+]]:_(s32), [[UV4:%[0-9]+]]:_(s32), [[UV5:%[0-9]+]]:_(s32), [[UV6:%[0-9]+]]:_(s32), [[UV7:%[0-9]+]]:_(s32), [[UV8:%[0-9]+]]:_(s32), [[UV9:%[0-9]+]]:_(s32), [[UV10:%[0-9]+]]:_(s32), [[UV11:%[0-9]+]]:_(s32), [[UV12:%[0-9]+]]:_(s32), [[UV13:%[0-9]+]]:_(s32), [[UV14:%[0-9]+]]:_(s32), [[UV15:%[0-9]+]]:_(s32), [[UV16:%[0-9]+]]:_(s32), [[UV17:%[0-9]+]]:_(s32), [[UV18:%[0-9]+]]:_(s32), [[UV19:%[0-9]+]]:_(s32), [[UV20:%[0-9]+]]:_(s32), [[UV21:%[0-9]+]]:_(s32), [[UV22:%[0-9]+]]:_(s32), [[UV23:%[0-9]+]]:_(s32), [[UV24:%[0-9]+]]:_(s32), [[UV25:%[0-9]+]]:_(s32), [[UV26:%[0-9]+]]:_(s32), [[UV27:%[0-9]+]]:_(s32), [[UV28:%[0-9]+]]:_(s32), [[UV29:%[0-9]+]]:_(s32), [[UV30:%[0-9]+]]:_(s32), [[UV31:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<32 x s32>) -++ ; CHECK-NEXT: [[UV:%[0-9]+]]:_(s32), [[UV1:%[0-9]+]]:_(s32), [[UV2:%[0-9]+]]:_(s32), [[UV3:%[0-9]+]]:_(s32), [[UV4:%[0-9]+]]:_(s32), [[UV5:%[0-9]+]]:_(s32), [[UV6:%[0-9]+]]:_(s32), [[UV7:%[0-9]+]]:_(s32), [[UV8:%[0-9]+]]:_(s32), [[UV9:%[0-9]+]]:_(s32), [[UV10:%[0-9]+]]:_(s32), [[UV11:%[0-9]+]]:_(s32), [[UV12:%[0-9]+]]:_(s32), [[UV13:%[0-9]+]]:_(s32), [[UV14:%[0-9]+]]:_(s32), [[UV15:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) -+ ; CHECK-NEXT: [[COPY:%[0-9]+]]:_(p1) = COPY $vgpr0_vgpr1 -+ ; CHECK-NEXT: G_STORE [[UV]](s32), [[COPY]](p1) :: (volatile store (s32), addrspace 1) -+ ; CHECK-NEXT: G_STORE [[DEF1]](s32), [[COPY]](p1) :: (volatile store (s32), addrspace 1) -+@@ -337,9 +348,10 @@ -+ bb.0: -+ -+ ; CHECK-LABEL: name: test_implicit_def_v64s32 -+- ; CHECK: [[DEF:%[0-9]+]]:_(<32 x s32>) = G_IMPLICIT_DEF -+- ; CHECK-NEXT: [[CONCAT_VECTORS:%[0-9]+]]:_(<64 x s32>) = G_CONCAT_VECTORS [[DEF]](<32 x s32>), [[DEF]](<32 x s32>) -+- ; CHECK-NEXT: S_NOP 0, implicit [[CONCAT_VECTORS]](<64 x s32>), implicit [[DEF]](<32 x s32>) -++ ; CHECK: [[DEF:%[0-9]+]]:_(<16 x s32>) = G_IMPLICIT_DEF -++ ; CHECK-NEXT: [[CONCAT_VECTORS:%[0-9]+]]:_(<64 x s32>) = G_CONCAT_VECTORS [[DEF]](<16 x s32>), [[DEF]](<16 x s32>), [[DEF]](<16 x s32>), [[DEF]](<16 x s32>) -++ ; CHECK-NEXT: [[CONCAT_VECTORS1:%[0-9]+]]:_(<32 x s32>) = G_CONCAT_VECTORS [[DEF]](<16 x s32>), [[DEF]](<16 x s32>) -++ ; CHECK-NEXT: S_NOP 0, implicit [[CONCAT_VECTORS]](<64 x s32>), implicit [[CONCAT_VECTORS1]](<32 x s32>) -+ %0:_(<64 x s32>) = G_IMPLICIT_DEF -+ %1:_(<32 x s32>), %2:_(<32 x s32>) = G_UNMERGE_VALUES %0 -+ S_NOP 0, implicit %0, implicit %1 -+diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-insert-vector-elt.mir b/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-insert-vector-elt.mir -+--- a/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-insert-vector-elt.mir -++++ b/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-insert-vector-elt.mir -+@@ -190,11 +190,13 @@ -+ ; CHECK-LABEL: name: insert_vector_elt_64_65_v64s32 -+ ; CHECK: liveins: $sgpr0_sgpr1, $vgpr0_vgpr1, $vgpr2_vgpr3 -+ ; CHECK-NEXT: {{ $}} -+- ; CHECK-NEXT: [[DEF:%[0-9]+]]:_(<32 x s32>) = G_IMPLICIT_DEF -++ ; CHECK-NEXT: [[DEF:%[0-9]+]]:_(<16 x s32>) = G_IMPLICIT_DEF -+ ; CHECK-NEXT: [[COPY:%[0-9]+]]:_(p1) = COPY $vgpr0_vgpr1 -+ ; CHECK-NEXT: [[COPY1:%[0-9]+]]:_(p1) = COPY $vgpr2_vgpr3 -+- ; CHECK-NEXT: [[UV:%[0-9]+]]:_(<4 x s32>), [[UV1:%[0-9]+]]:_(<4 x s32>), [[UV2:%[0-9]+]]:_(<4 x s32>), [[UV3:%[0-9]+]]:_(<4 x s32>), [[UV4:%[0-9]+]]:_(<4 x s32>), [[UV5:%[0-9]+]]:_(<4 x s32>), [[UV6:%[0-9]+]]:_(<4 x s32>), [[UV7:%[0-9]+]]:_(<4 x s32>) = G_UNMERGE_VALUES [[DEF]](<32 x s32>) -+- ; CHECK-NEXT: [[UV8:%[0-9]+]]:_(<4 x s32>), [[UV9:%[0-9]+]]:_(<4 x s32>), [[UV10:%[0-9]+]]:_(<4 x s32>), [[UV11:%[0-9]+]]:_(<4 x s32>), [[UV12:%[0-9]+]]:_(<4 x s32>), [[UV13:%[0-9]+]]:_(<4 x s32>), [[UV14:%[0-9]+]]:_(<4 x s32>), [[UV15:%[0-9]+]]:_(<4 x s32>) = G_UNMERGE_VALUES [[DEF]](<32 x s32>) -++ ; CHECK-NEXT: [[UV:%[0-9]+]]:_(<4 x s32>), [[UV1:%[0-9]+]]:_(<4 x s32>), [[UV2:%[0-9]+]]:_(<4 x s32>), [[UV3:%[0-9]+]]:_(<4 x s32>) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) -++ ; CHECK-NEXT: [[UV4:%[0-9]+]]:_(<4 x s32>), [[UV5:%[0-9]+]]:_(<4 x s32>), [[UV6:%[0-9]+]]:_(<4 x s32>), [[UV7:%[0-9]+]]:_(<4 x s32>) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) -++ ; CHECK-NEXT: [[UV8:%[0-9]+]]:_(<4 x s32>), [[UV9:%[0-9]+]]:_(<4 x s32>), [[UV10:%[0-9]+]]:_(<4 x s32>), [[UV11:%[0-9]+]]:_(<4 x s32>) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) -++ ; CHECK-NEXT: [[UV12:%[0-9]+]]:_(<4 x s32>), [[UV13:%[0-9]+]]:_(<4 x s32>), [[UV14:%[0-9]+]]:_(<4 x s32>), [[UV15:%[0-9]+]]:_(<4 x s32>) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) -+ ; CHECK-NEXT: G_STORE [[UV]](<4 x s32>), [[COPY]](p1) :: (store (<4 x s32>), align 4, addrspace 1) -+ ; CHECK-NEXT: [[C:%[0-9]+]]:_(s64) = G_CONSTANT i64 16 -+ ; CHECK-NEXT: [[PTR_ADD:%[0-9]+]]:_(p1) = G_PTR_ADD [[COPY]], [[C]](s64) -+@@ -241,8 +243,10 @@ -+ ; CHECK-NEXT: [[C14:%[0-9]+]]:_(s64) = G_CONSTANT i64 240 -+ ; CHECK-NEXT: [[PTR_ADD14:%[0-9]+]]:_(p1) = G_PTR_ADD [[COPY]], [[C14]](s64) -+ ; CHECK-NEXT: G_STORE [[UV15]](<4 x s32>), [[PTR_ADD14]](p1) :: (store (<4 x s32>) into unknown-address + 240, align 4, addrspace 1) -+- ; CHECK-NEXT: [[UV16:%[0-9]+]]:_(<4 x s32>), [[UV17:%[0-9]+]]:_(<4 x s32>), [[UV18:%[0-9]+]]:_(<4 x s32>), [[UV19:%[0-9]+]]:_(<4 x s32>), [[UV20:%[0-9]+]]:_(<4 x s32>), [[UV21:%[0-9]+]]:_(<4 x s32>), [[UV22:%[0-9]+]]:_(<4 x s32>), [[UV23:%[0-9]+]]:_(<4 x s32>) = G_UNMERGE_VALUES [[DEF]](<32 x s32>) -+- ; CHECK-NEXT: [[UV24:%[0-9]+]]:_(<4 x s32>), [[UV25:%[0-9]+]]:_(<4 x s32>), [[UV26:%[0-9]+]]:_(<4 x s32>), [[UV27:%[0-9]+]]:_(<4 x s32>), [[UV28:%[0-9]+]]:_(<4 x s32>), [[UV29:%[0-9]+]]:_(<4 x s32>), [[UV30:%[0-9]+]]:_(<4 x s32>), [[UV31:%[0-9]+]]:_(<4 x s32>) = G_UNMERGE_VALUES [[DEF]](<32 x s32>) -++ ; CHECK-NEXT: [[UV16:%[0-9]+]]:_(<4 x s32>), [[UV17:%[0-9]+]]:_(<4 x s32>), [[UV18:%[0-9]+]]:_(<4 x s32>), [[UV19:%[0-9]+]]:_(<4 x s32>) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) -++ ; CHECK-NEXT: [[UV20:%[0-9]+]]:_(<4 x s32>), [[UV21:%[0-9]+]]:_(<4 x s32>), [[UV22:%[0-9]+]]:_(<4 x s32>), [[UV23:%[0-9]+]]:_(<4 x s32>) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) -++ ; CHECK-NEXT: [[UV24:%[0-9]+]]:_(<4 x s32>), [[UV25:%[0-9]+]]:_(<4 x s32>), [[UV26:%[0-9]+]]:_(<4 x s32>), [[UV27:%[0-9]+]]:_(<4 x s32>) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) -++ ; CHECK-NEXT: [[UV28:%[0-9]+]]:_(<4 x s32>), [[UV29:%[0-9]+]]:_(<4 x s32>), [[UV30:%[0-9]+]]:_(<4 x s32>), [[UV31:%[0-9]+]]:_(<4 x s32>) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) -+ ; CHECK-NEXT: G_STORE [[UV16]](<4 x s32>), [[COPY1]](p1) :: (store (<4 x s32>), align 4, addrspace 1) -+ ; CHECK-NEXT: [[PTR_ADD15:%[0-9]+]]:_(p1) = G_PTR_ADD [[COPY1]], [[C]](s64) -+ ; CHECK-NEXT: G_STORE [[UV17]](<4 x s32>), [[PTR_ADD15]](p1) :: (store (<4 x s32>) into unknown-address + 16, align 4, addrspace 1) -+diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-phi.mir b/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-phi.mir -+--- a/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-phi.mir -++++ b/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-phi.mir -+@@ -673,86 +673,88 @@ -+ ; CHECK-NEXT: successors: %bb.1(0x40000000), %bb.2(0x40000000) -+ ; CHECK-NEXT: liveins: $vgpr0_vgpr1_vgpr2_vgpr3, $vgpr4 -+ ; CHECK-NEXT: {{ $}} -+- ; CHECK-NEXT: [[DEF:%[0-9]+]]:_(<32 x s32>) = G_IMPLICIT_DEF -++ ; CHECK-NEXT: [[DEF:%[0-9]+]]:_(<16 x s32>) = G_IMPLICIT_DEF -+ ; CHECK-NEXT: [[COPY:%[0-9]+]]:_(s32) = COPY $vgpr4 -+ ; CHECK-NEXT: [[C:%[0-9]+]]:_(s32) = G_CONSTANT i32 0 -+ ; CHECK-NEXT: [[ICMP:%[0-9]+]]:_(s1) = G_ICMP intpred(eq), [[COPY]](s32), [[C]] -+- ; CHECK-NEXT: [[UV:%[0-9]+]]:_(<16 x s32>), [[UV1:%[0-9]+]]:_(<16 x s32>) = G_UNMERGE_VALUES [[DEF]](<32 x s32>) -+- ; CHECK-NEXT: [[UV2:%[0-9]+]]:_(<16 x s32>), [[UV3:%[0-9]+]]:_(<16 x s32>) = G_UNMERGE_VALUES [[DEF]](<32 x s32>) -+ ; CHECK-NEXT: G_BRCOND [[ICMP]](s1), %bb.1 -+ ; CHECK-NEXT: G_BR %bb.2 -+ ; CHECK-NEXT: {{ $}} -+ ; CHECK-NEXT: bb.1: -+ ; CHECK-NEXT: successors: %bb.2(0x80000000) -+ ; CHECK-NEXT: {{ $}} -+- ; CHECK-NEXT: [[UV4:%[0-9]+]]:_(s32), [[UV5:%[0-9]+]]:_(s32), [[UV6:%[0-9]+]]:_(s32), [[UV7:%[0-9]+]]:_(s32), [[UV8:%[0-9]+]]:_(s32), [[UV9:%[0-9]+]]:_(s32), [[UV10:%[0-9]+]]:_(s32), [[UV11:%[0-9]+]]:_(s32), [[UV12:%[0-9]+]]:_(s32), [[UV13:%[0-9]+]]:_(s32), [[UV14:%[0-9]+]]:_(s32), [[UV15:%[0-9]+]]:_(s32), [[UV16:%[0-9]+]]:_(s32), [[UV17:%[0-9]+]]:_(s32), [[UV18:%[0-9]+]]:_(s32), [[UV19:%[0-9]+]]:_(s32), [[UV20:%[0-9]+]]:_(s32), [[UV21:%[0-9]+]]:_(s32), [[UV22:%[0-9]+]]:_(s32), [[UV23:%[0-9]+]]:_(s32), [[UV24:%[0-9]+]]:_(s32), [[UV25:%[0-9]+]]:_(s32), [[UV26:%[0-9]+]]:_(s32), [[UV27:%[0-9]+]]:_(s32), [[UV28:%[0-9]+]]:_(s32), [[UV29:%[0-9]+]]:_(s32), [[UV30:%[0-9]+]]:_(s32), [[UV31:%[0-9]+]]:_(s32), [[UV32:%[0-9]+]]:_(s32), [[UV33:%[0-9]+]]:_(s32), [[UV34:%[0-9]+]]:_(s32), [[UV35:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<32 x s32>) -+- ; CHECK-NEXT: [[UV36:%[0-9]+]]:_(s32), [[UV37:%[0-9]+]]:_(s32), [[UV38:%[0-9]+]]:_(s32), [[UV39:%[0-9]+]]:_(s32), [[UV40:%[0-9]+]]:_(s32), [[UV41:%[0-9]+]]:_(s32), [[UV42:%[0-9]+]]:_(s32), [[UV43:%[0-9]+]]:_(s32), [[UV44:%[0-9]+]]:_(s32), [[UV45:%[0-9]+]]:_(s32), [[UV46:%[0-9]+]]:_(s32), [[UV47:%[0-9]+]]:_(s32), [[UV48:%[0-9]+]]:_(s32), [[UV49:%[0-9]+]]:_(s32), [[UV50:%[0-9]+]]:_(s32), [[UV51:%[0-9]+]]:_(s32), [[UV52:%[0-9]+]]:_(s32), [[UV53:%[0-9]+]]:_(s32), [[UV54:%[0-9]+]]:_(s32), [[UV55:%[0-9]+]]:_(s32), [[UV56:%[0-9]+]]:_(s32), [[UV57:%[0-9]+]]:_(s32), [[UV58:%[0-9]+]]:_(s32), [[UV59:%[0-9]+]]:_(s32), [[UV60:%[0-9]+]]:_(s32), [[UV61:%[0-9]+]]:_(s32), [[UV62:%[0-9]+]]:_(s32), [[UV63:%[0-9]+]]:_(s32), [[UV64:%[0-9]+]]:_(s32), [[UV65:%[0-9]+]]:_(s32), [[UV66:%[0-9]+]]:_(s32), [[UV67:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<32 x s32>) -+- ; CHECK-NEXT: [[UV68:%[0-9]+]]:_(s32), [[UV69:%[0-9]+]]:_(s32), [[UV70:%[0-9]+]]:_(s32), [[UV71:%[0-9]+]]:_(s32), [[UV72:%[0-9]+]]:_(s32), [[UV73:%[0-9]+]]:_(s32), [[UV74:%[0-9]+]]:_(s32), [[UV75:%[0-9]+]]:_(s32), [[UV76:%[0-9]+]]:_(s32), [[UV77:%[0-9]+]]:_(s32), [[UV78:%[0-9]+]]:_(s32), [[UV79:%[0-9]+]]:_(s32), [[UV80:%[0-9]+]]:_(s32), [[UV81:%[0-9]+]]:_(s32), [[UV82:%[0-9]+]]:_(s32), [[UV83:%[0-9]+]]:_(s32), [[UV84:%[0-9]+]]:_(s32), [[UV85:%[0-9]+]]:_(s32), [[UV86:%[0-9]+]]:_(s32), [[UV87:%[0-9]+]]:_(s32), [[UV88:%[0-9]+]]:_(s32), [[UV89:%[0-9]+]]:_(s32), [[UV90:%[0-9]+]]:_(s32), [[UV91:%[0-9]+]]:_(s32), [[UV92:%[0-9]+]]:_(s32), [[UV93:%[0-9]+]]:_(s32), [[UV94:%[0-9]+]]:_(s32), [[UV95:%[0-9]+]]:_(s32), [[UV96:%[0-9]+]]:_(s32), [[UV97:%[0-9]+]]:_(s32), [[UV98:%[0-9]+]]:_(s32), [[UV99:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<32 x s32>) -+- ; CHECK-NEXT: [[UV100:%[0-9]+]]:_(s32), [[UV101:%[0-9]+]]:_(s32), [[UV102:%[0-9]+]]:_(s32), [[UV103:%[0-9]+]]:_(s32), [[UV104:%[0-9]+]]:_(s32), [[UV105:%[0-9]+]]:_(s32), [[UV106:%[0-9]+]]:_(s32), [[UV107:%[0-9]+]]:_(s32), [[UV108:%[0-9]+]]:_(s32), [[UV109:%[0-9]+]]:_(s32), [[UV110:%[0-9]+]]:_(s32), [[UV111:%[0-9]+]]:_(s32), [[UV112:%[0-9]+]]:_(s32), [[UV113:%[0-9]+]]:_(s32), [[UV114:%[0-9]+]]:_(s32), [[UV115:%[0-9]+]]:_(s32), [[UV116:%[0-9]+]]:_(s32), [[UV117:%[0-9]+]]:_(s32), [[UV118:%[0-9]+]]:_(s32), [[UV119:%[0-9]+]]:_(s32), [[UV120:%[0-9]+]]:_(s32), [[UV121:%[0-9]+]]:_(s32), [[UV122:%[0-9]+]]:_(s32), [[UV123:%[0-9]+]]:_(s32), [[UV124:%[0-9]+]]:_(s32), [[UV125:%[0-9]+]]:_(s32), [[UV126:%[0-9]+]]:_(s32), [[UV127:%[0-9]+]]:_(s32), [[UV128:%[0-9]+]]:_(s32), [[UV129:%[0-9]+]]:_(s32), [[UV130:%[0-9]+]]:_(s32), [[UV131:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<32 x s32>) -+- ; CHECK-NEXT: [[ADD:%[0-9]+]]:_(s32) = G_ADD [[UV4]], [[UV68]] -+- ; CHECK-NEXT: [[ADD1:%[0-9]+]]:_(s32) = G_ADD [[UV5]], [[UV69]] -+- ; CHECK-NEXT: [[ADD2:%[0-9]+]]:_(s32) = G_ADD [[UV6]], [[UV70]] -+- ; CHECK-NEXT: [[ADD3:%[0-9]+]]:_(s32) = G_ADD [[UV7]], [[UV71]] -+- ; CHECK-NEXT: [[ADD4:%[0-9]+]]:_(s32) = G_ADD [[UV8]], [[UV72]] -+- ; CHECK-NEXT: [[ADD5:%[0-9]+]]:_(s32) = G_ADD [[UV9]], [[UV73]] -+- ; CHECK-NEXT: [[ADD6:%[0-9]+]]:_(s32) = G_ADD [[UV10]], [[UV74]] -+- ; CHECK-NEXT: [[ADD7:%[0-9]+]]:_(s32) = G_ADD [[UV11]], [[UV75]] -+- ; CHECK-NEXT: [[ADD8:%[0-9]+]]:_(s32) = G_ADD [[UV12]], [[UV76]] -+- ; CHECK-NEXT: [[ADD9:%[0-9]+]]:_(s32) = G_ADD [[UV13]], [[UV77]] -+- ; CHECK-NEXT: [[ADD10:%[0-9]+]]:_(s32) = G_ADD [[UV14]], [[UV78]] -+- ; CHECK-NEXT: [[ADD11:%[0-9]+]]:_(s32) = G_ADD [[UV15]], [[UV79]] -+- ; CHECK-NEXT: [[ADD12:%[0-9]+]]:_(s32) = G_ADD [[UV16]], [[UV80]] -+- ; CHECK-NEXT: [[ADD13:%[0-9]+]]:_(s32) = G_ADD [[UV17]], [[UV81]] -+- ; CHECK-NEXT: [[ADD14:%[0-9]+]]:_(s32) = G_ADD [[UV18]], [[UV82]] -+- ; CHECK-NEXT: [[ADD15:%[0-9]+]]:_(s32) = G_ADD [[UV19]], [[UV83]] -+- ; CHECK-NEXT: [[ADD16:%[0-9]+]]:_(s32) = G_ADD [[UV20]], [[UV84]] -+- ; CHECK-NEXT: [[ADD17:%[0-9]+]]:_(s32) = G_ADD [[UV21]], [[UV85]] -+- ; CHECK-NEXT: [[ADD18:%[0-9]+]]:_(s32) = G_ADD [[UV22]], [[UV86]] -+- ; CHECK-NEXT: [[ADD19:%[0-9]+]]:_(s32) = G_ADD [[UV23]], [[UV87]] -+- ; CHECK-NEXT: [[ADD20:%[0-9]+]]:_(s32) = G_ADD [[UV24]], [[UV88]] -+- ; CHECK-NEXT: [[ADD21:%[0-9]+]]:_(s32) = G_ADD [[UV25]], [[UV89]] -+- ; CHECK-NEXT: [[ADD22:%[0-9]+]]:_(s32) = G_ADD [[UV26]], [[UV90]] -+- ; CHECK-NEXT: [[ADD23:%[0-9]+]]:_(s32) = G_ADD [[UV27]], [[UV91]] -+- ; CHECK-NEXT: [[ADD24:%[0-9]+]]:_(s32) = G_ADD [[UV28]], [[UV92]] -+- ; CHECK-NEXT: [[ADD25:%[0-9]+]]:_(s32) = G_ADD [[UV29]], [[UV93]] -+- ; CHECK-NEXT: [[ADD26:%[0-9]+]]:_(s32) = G_ADD [[UV30]], [[UV94]] -+- ; CHECK-NEXT: [[ADD27:%[0-9]+]]:_(s32) = G_ADD [[UV31]], [[UV95]] -+- ; CHECK-NEXT: [[ADD28:%[0-9]+]]:_(s32) = G_ADD [[UV32]], [[UV96]] -+- ; CHECK-NEXT: [[ADD29:%[0-9]+]]:_(s32) = G_ADD [[UV33]], [[UV97]] -+- ; CHECK-NEXT: [[ADD30:%[0-9]+]]:_(s32) = G_ADD [[UV34]], [[UV98]] -+- ; CHECK-NEXT: [[ADD31:%[0-9]+]]:_(s32) = G_ADD [[UV35]], [[UV99]] -+- ; CHECK-NEXT: [[ADD32:%[0-9]+]]:_(s32) = G_ADD [[UV36]], [[UV100]] -+- ; CHECK-NEXT: [[ADD33:%[0-9]+]]:_(s32) = G_ADD [[UV37]], [[UV101]] -+- ; CHECK-NEXT: [[ADD34:%[0-9]+]]:_(s32) = G_ADD [[UV38]], [[UV102]] -+- ; CHECK-NEXT: [[ADD35:%[0-9]+]]:_(s32) = G_ADD [[UV39]], [[UV103]] -+- ; CHECK-NEXT: [[ADD36:%[0-9]+]]:_(s32) = G_ADD [[UV40]], [[UV104]] -+- ; CHECK-NEXT: [[ADD37:%[0-9]+]]:_(s32) = G_ADD [[UV41]], [[UV105]] -+- ; CHECK-NEXT: [[ADD38:%[0-9]+]]:_(s32) = G_ADD [[UV42]], [[UV106]] -+- ; CHECK-NEXT: [[ADD39:%[0-9]+]]:_(s32) = G_ADD [[UV43]], [[UV107]] -+- ; CHECK-NEXT: [[ADD40:%[0-9]+]]:_(s32) = G_ADD [[UV44]], [[UV108]] -+- ; CHECK-NEXT: [[ADD41:%[0-9]+]]:_(s32) = G_ADD [[UV45]], [[UV109]] -+- ; CHECK-NEXT: [[ADD42:%[0-9]+]]:_(s32) = G_ADD [[UV46]], [[UV110]] -+- ; CHECK-NEXT: [[ADD43:%[0-9]+]]:_(s32) = G_ADD [[UV47]], [[UV111]] -+- ; CHECK-NEXT: [[ADD44:%[0-9]+]]:_(s32) = G_ADD [[UV48]], [[UV112]] -+- ; CHECK-NEXT: [[ADD45:%[0-9]+]]:_(s32) = G_ADD [[UV49]], [[UV113]] -+- ; CHECK-NEXT: [[ADD46:%[0-9]+]]:_(s32) = G_ADD [[UV50]], [[UV114]] -+- ; CHECK-NEXT: [[ADD47:%[0-9]+]]:_(s32) = G_ADD [[UV51]], [[UV115]] -+- ; CHECK-NEXT: [[ADD48:%[0-9]+]]:_(s32) = G_ADD [[UV52]], [[UV116]] -+- ; CHECK-NEXT: [[ADD49:%[0-9]+]]:_(s32) = G_ADD [[UV53]], [[UV117]] -+- ; CHECK-NEXT: [[ADD50:%[0-9]+]]:_(s32) = G_ADD [[UV54]], [[UV118]] -+- ; CHECK-NEXT: [[ADD51:%[0-9]+]]:_(s32) = G_ADD [[UV55]], [[UV119]] -+- ; CHECK-NEXT: [[ADD52:%[0-9]+]]:_(s32) = G_ADD [[UV56]], [[UV120]] -+- ; CHECK-NEXT: [[ADD53:%[0-9]+]]:_(s32) = G_ADD [[UV57]], [[UV121]] -+- ; CHECK-NEXT: [[ADD54:%[0-9]+]]:_(s32) = G_ADD [[UV58]], [[UV122]] -+- ; CHECK-NEXT: [[ADD55:%[0-9]+]]:_(s32) = G_ADD [[UV59]], [[UV123]] -+- ; CHECK-NEXT: [[ADD56:%[0-9]+]]:_(s32) = G_ADD [[UV60]], [[UV124]] -+- ; CHECK-NEXT: [[ADD57:%[0-9]+]]:_(s32) = G_ADD [[UV61]], [[UV125]] -+- ; CHECK-NEXT: [[ADD58:%[0-9]+]]:_(s32) = G_ADD [[UV62]], [[UV126]] -+- ; CHECK-NEXT: [[ADD59:%[0-9]+]]:_(s32) = G_ADD [[UV63]], [[UV127]] -+- ; CHECK-NEXT: [[ADD60:%[0-9]+]]:_(s32) = G_ADD [[UV64]], [[UV128]] -+- ; CHECK-NEXT: [[ADD61:%[0-9]+]]:_(s32) = G_ADD [[UV65]], [[UV129]] -+- ; CHECK-NEXT: [[ADD62:%[0-9]+]]:_(s32) = G_ADD [[UV66]], [[UV130]] -+- ; CHECK-NEXT: [[ADD63:%[0-9]+]]:_(s32) = G_ADD [[UV67]], [[UV131]] -++ ; CHECK-NEXT: [[UV:%[0-9]+]]:_(s32), [[UV1:%[0-9]+]]:_(s32), [[UV2:%[0-9]+]]:_(s32), [[UV3:%[0-9]+]]:_(s32), [[UV4:%[0-9]+]]:_(s32), [[UV5:%[0-9]+]]:_(s32), [[UV6:%[0-9]+]]:_(s32), [[UV7:%[0-9]+]]:_(s32), [[UV8:%[0-9]+]]:_(s32), [[UV9:%[0-9]+]]:_(s32), [[UV10:%[0-9]+]]:_(s32), [[UV11:%[0-9]+]]:_(s32), [[UV12:%[0-9]+]]:_(s32), [[UV13:%[0-9]+]]:_(s32), [[UV14:%[0-9]+]]:_(s32), [[UV15:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) -++ ; CHECK-NEXT: [[UV16:%[0-9]+]]:_(s32), [[UV17:%[0-9]+]]:_(s32), [[UV18:%[0-9]+]]:_(s32), [[UV19:%[0-9]+]]:_(s32), [[UV20:%[0-9]+]]:_(s32), [[UV21:%[0-9]+]]:_(s32), [[UV22:%[0-9]+]]:_(s32), [[UV23:%[0-9]+]]:_(s32), [[UV24:%[0-9]+]]:_(s32), [[UV25:%[0-9]+]]:_(s32), [[UV26:%[0-9]+]]:_(s32), [[UV27:%[0-9]+]]:_(s32), [[UV28:%[0-9]+]]:_(s32), [[UV29:%[0-9]+]]:_(s32), [[UV30:%[0-9]+]]:_(s32), [[UV31:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) -++ ; CHECK-NEXT: [[UV32:%[0-9]+]]:_(s32), [[UV33:%[0-9]+]]:_(s32), [[UV34:%[0-9]+]]:_(s32), [[UV35:%[0-9]+]]:_(s32), [[UV36:%[0-9]+]]:_(s32), [[UV37:%[0-9]+]]:_(s32), [[UV38:%[0-9]+]]:_(s32), [[UV39:%[0-9]+]]:_(s32), [[UV40:%[0-9]+]]:_(s32), [[UV41:%[0-9]+]]:_(s32), [[UV42:%[0-9]+]]:_(s32), [[UV43:%[0-9]+]]:_(s32), [[UV44:%[0-9]+]]:_(s32), [[UV45:%[0-9]+]]:_(s32), [[UV46:%[0-9]+]]:_(s32), [[UV47:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) -++ ; CHECK-NEXT: [[UV48:%[0-9]+]]:_(s32), [[UV49:%[0-9]+]]:_(s32), [[UV50:%[0-9]+]]:_(s32), [[UV51:%[0-9]+]]:_(s32), [[UV52:%[0-9]+]]:_(s32), [[UV53:%[0-9]+]]:_(s32), [[UV54:%[0-9]+]]:_(s32), [[UV55:%[0-9]+]]:_(s32), [[UV56:%[0-9]+]]:_(s32), [[UV57:%[0-9]+]]:_(s32), [[UV58:%[0-9]+]]:_(s32), [[UV59:%[0-9]+]]:_(s32), [[UV60:%[0-9]+]]:_(s32), [[UV61:%[0-9]+]]:_(s32), [[UV62:%[0-9]+]]:_(s32), [[UV63:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) -++ ; CHECK-NEXT: [[UV64:%[0-9]+]]:_(s32), [[UV65:%[0-9]+]]:_(s32), [[UV66:%[0-9]+]]:_(s32), [[UV67:%[0-9]+]]:_(s32), [[UV68:%[0-9]+]]:_(s32), [[UV69:%[0-9]+]]:_(s32), [[UV70:%[0-9]+]]:_(s32), [[UV71:%[0-9]+]]:_(s32), [[UV72:%[0-9]+]]:_(s32), [[UV73:%[0-9]+]]:_(s32), [[UV74:%[0-9]+]]:_(s32), [[UV75:%[0-9]+]]:_(s32), [[UV76:%[0-9]+]]:_(s32), [[UV77:%[0-9]+]]:_(s32), [[UV78:%[0-9]+]]:_(s32), [[UV79:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) -++ ; CHECK-NEXT: [[UV80:%[0-9]+]]:_(s32), [[UV81:%[0-9]+]]:_(s32), [[UV82:%[0-9]+]]:_(s32), [[UV83:%[0-9]+]]:_(s32), [[UV84:%[0-9]+]]:_(s32), [[UV85:%[0-9]+]]:_(s32), [[UV86:%[0-9]+]]:_(s32), [[UV87:%[0-9]+]]:_(s32), [[UV88:%[0-9]+]]:_(s32), [[UV89:%[0-9]+]]:_(s32), [[UV90:%[0-9]+]]:_(s32), [[UV91:%[0-9]+]]:_(s32), [[UV92:%[0-9]+]]:_(s32), [[UV93:%[0-9]+]]:_(s32), [[UV94:%[0-9]+]]:_(s32), [[UV95:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) -++ ; CHECK-NEXT: [[UV96:%[0-9]+]]:_(s32), [[UV97:%[0-9]+]]:_(s32), [[UV98:%[0-9]+]]:_(s32), [[UV99:%[0-9]+]]:_(s32), [[UV100:%[0-9]+]]:_(s32), [[UV101:%[0-9]+]]:_(s32), [[UV102:%[0-9]+]]:_(s32), [[UV103:%[0-9]+]]:_(s32), [[UV104:%[0-9]+]]:_(s32), [[UV105:%[0-9]+]]:_(s32), [[UV106:%[0-9]+]]:_(s32), [[UV107:%[0-9]+]]:_(s32), [[UV108:%[0-9]+]]:_(s32), [[UV109:%[0-9]+]]:_(s32), [[UV110:%[0-9]+]]:_(s32), [[UV111:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) -++ ; CHECK-NEXT: [[UV112:%[0-9]+]]:_(s32), [[UV113:%[0-9]+]]:_(s32), [[UV114:%[0-9]+]]:_(s32), [[UV115:%[0-9]+]]:_(s32), [[UV116:%[0-9]+]]:_(s32), [[UV117:%[0-9]+]]:_(s32), [[UV118:%[0-9]+]]:_(s32), [[UV119:%[0-9]+]]:_(s32), [[UV120:%[0-9]+]]:_(s32), [[UV121:%[0-9]+]]:_(s32), [[UV122:%[0-9]+]]:_(s32), [[UV123:%[0-9]+]]:_(s32), [[UV124:%[0-9]+]]:_(s32), [[UV125:%[0-9]+]]:_(s32), [[UV126:%[0-9]+]]:_(s32), [[UV127:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) -++ ; CHECK-NEXT: [[ADD:%[0-9]+]]:_(s32) = G_ADD [[UV]], [[UV64]] -++ ; CHECK-NEXT: [[ADD1:%[0-9]+]]:_(s32) = G_ADD [[UV1]], [[UV65]] -++ ; CHECK-NEXT: [[ADD2:%[0-9]+]]:_(s32) = G_ADD [[UV2]], [[UV66]] -++ ; CHECK-NEXT: [[ADD3:%[0-9]+]]:_(s32) = G_ADD [[UV3]], [[UV67]] -++ ; CHECK-NEXT: [[ADD4:%[0-9]+]]:_(s32) = G_ADD [[UV4]], [[UV68]] -++ ; CHECK-NEXT: [[ADD5:%[0-9]+]]:_(s32) = G_ADD [[UV5]], [[UV69]] -++ ; CHECK-NEXT: [[ADD6:%[0-9]+]]:_(s32) = G_ADD [[UV6]], [[UV70]] -++ ; CHECK-NEXT: [[ADD7:%[0-9]+]]:_(s32) = G_ADD [[UV7]], [[UV71]] -++ ; CHECK-NEXT: [[ADD8:%[0-9]+]]:_(s32) = G_ADD [[UV8]], [[UV72]] -++ ; CHECK-NEXT: [[ADD9:%[0-9]+]]:_(s32) = G_ADD [[UV9]], [[UV73]] -++ ; CHECK-NEXT: [[ADD10:%[0-9]+]]:_(s32) = G_ADD [[UV10]], [[UV74]] -++ ; CHECK-NEXT: [[ADD11:%[0-9]+]]:_(s32) = G_ADD [[UV11]], [[UV75]] -++ ; CHECK-NEXT: [[ADD12:%[0-9]+]]:_(s32) = G_ADD [[UV12]], [[UV76]] -++ ; CHECK-NEXT: [[ADD13:%[0-9]+]]:_(s32) = G_ADD [[UV13]], [[UV77]] -++ ; CHECK-NEXT: [[ADD14:%[0-9]+]]:_(s32) = G_ADD [[UV14]], [[UV78]] -++ ; CHECK-NEXT: [[ADD15:%[0-9]+]]:_(s32) = G_ADD [[UV15]], [[UV79]] -++ ; CHECK-NEXT: [[ADD16:%[0-9]+]]:_(s32) = G_ADD [[UV16]], [[UV80]] -++ ; CHECK-NEXT: [[ADD17:%[0-9]+]]:_(s32) = G_ADD [[UV17]], [[UV81]] -++ ; CHECK-NEXT: [[ADD18:%[0-9]+]]:_(s32) = G_ADD [[UV18]], [[UV82]] -++ ; CHECK-NEXT: [[ADD19:%[0-9]+]]:_(s32) = G_ADD [[UV19]], [[UV83]] -++ ; CHECK-NEXT: [[ADD20:%[0-9]+]]:_(s32) = G_ADD [[UV20]], [[UV84]] -++ ; CHECK-NEXT: [[ADD21:%[0-9]+]]:_(s32) = G_ADD [[UV21]], [[UV85]] -++ ; CHECK-NEXT: [[ADD22:%[0-9]+]]:_(s32) = G_ADD [[UV22]], [[UV86]] -++ ; CHECK-NEXT: [[ADD23:%[0-9]+]]:_(s32) = G_ADD [[UV23]], [[UV87]] -++ ; CHECK-NEXT: [[ADD24:%[0-9]+]]:_(s32) = G_ADD [[UV24]], [[UV88]] -++ ; CHECK-NEXT: [[ADD25:%[0-9]+]]:_(s32) = G_ADD [[UV25]], [[UV89]] -++ ; CHECK-NEXT: [[ADD26:%[0-9]+]]:_(s32) = G_ADD [[UV26]], [[UV90]] -++ ; CHECK-NEXT: [[ADD27:%[0-9]+]]:_(s32) = G_ADD [[UV27]], [[UV91]] -++ ; CHECK-NEXT: [[ADD28:%[0-9]+]]:_(s32) = G_ADD [[UV28]], [[UV92]] -++ ; CHECK-NEXT: [[ADD29:%[0-9]+]]:_(s32) = G_ADD [[UV29]], [[UV93]] -++ ; CHECK-NEXT: [[ADD30:%[0-9]+]]:_(s32) = G_ADD [[UV30]], [[UV94]] -++ ; CHECK-NEXT: [[ADD31:%[0-9]+]]:_(s32) = G_ADD [[UV31]], [[UV95]] -++ ; CHECK-NEXT: [[ADD32:%[0-9]+]]:_(s32) = G_ADD [[UV32]], [[UV96]] -++ ; CHECK-NEXT: [[ADD33:%[0-9]+]]:_(s32) = G_ADD [[UV33]], [[UV97]] -++ ; CHECK-NEXT: [[ADD34:%[0-9]+]]:_(s32) = G_ADD [[UV34]], [[UV98]] -++ ; CHECK-NEXT: [[ADD35:%[0-9]+]]:_(s32) = G_ADD [[UV35]], [[UV99]] -++ ; CHECK-NEXT: [[ADD36:%[0-9]+]]:_(s32) = G_ADD [[UV36]], [[UV100]] -++ ; CHECK-NEXT: [[ADD37:%[0-9]+]]:_(s32) = G_ADD [[UV37]], [[UV101]] -++ ; CHECK-NEXT: [[ADD38:%[0-9]+]]:_(s32) = G_ADD [[UV38]], [[UV102]] -++ ; CHECK-NEXT: [[ADD39:%[0-9]+]]:_(s32) = G_ADD [[UV39]], [[UV103]] -++ ; CHECK-NEXT: [[ADD40:%[0-9]+]]:_(s32) = G_ADD [[UV40]], [[UV104]] -++ ; CHECK-NEXT: [[ADD41:%[0-9]+]]:_(s32) = G_ADD [[UV41]], [[UV105]] -++ ; CHECK-NEXT: [[ADD42:%[0-9]+]]:_(s32) = G_ADD [[UV42]], [[UV106]] -++ ; CHECK-NEXT: [[ADD43:%[0-9]+]]:_(s32) = G_ADD [[UV43]], [[UV107]] -++ ; CHECK-NEXT: [[ADD44:%[0-9]+]]:_(s32) = G_ADD [[UV44]], [[UV108]] -++ ; CHECK-NEXT: [[ADD45:%[0-9]+]]:_(s32) = G_ADD [[UV45]], [[UV109]] -++ ; CHECK-NEXT: [[ADD46:%[0-9]+]]:_(s32) = G_ADD [[UV46]], [[UV110]] -++ ; CHECK-NEXT: [[ADD47:%[0-9]+]]:_(s32) = G_ADD [[UV47]], [[UV111]] -++ ; CHECK-NEXT: [[ADD48:%[0-9]+]]:_(s32) = G_ADD [[UV48]], [[UV112]] -++ ; CHECK-NEXT: [[ADD49:%[0-9]+]]:_(s32) = G_ADD [[UV49]], [[UV113]] -++ ; CHECK-NEXT: [[ADD50:%[0-9]+]]:_(s32) = G_ADD [[UV50]], [[UV114]] -++ ; CHECK-NEXT: [[ADD51:%[0-9]+]]:_(s32) = G_ADD [[UV51]], [[UV115]] -++ ; CHECK-NEXT: [[ADD52:%[0-9]+]]:_(s32) = G_ADD [[UV52]], [[UV116]] -++ ; CHECK-NEXT: [[ADD53:%[0-9]+]]:_(s32) = G_ADD [[UV53]], [[UV117]] -++ ; CHECK-NEXT: [[ADD54:%[0-9]+]]:_(s32) = G_ADD [[UV54]], [[UV118]] -++ ; CHECK-NEXT: [[ADD55:%[0-9]+]]:_(s32) = G_ADD [[UV55]], [[UV119]] -++ ; CHECK-NEXT: [[ADD56:%[0-9]+]]:_(s32) = G_ADD [[UV56]], [[UV120]] -++ ; CHECK-NEXT: [[ADD57:%[0-9]+]]:_(s32) = G_ADD [[UV57]], [[UV121]] -++ ; CHECK-NEXT: [[ADD58:%[0-9]+]]:_(s32) = G_ADD [[UV58]], [[UV122]] -++ ; CHECK-NEXT: [[ADD59:%[0-9]+]]:_(s32) = G_ADD [[UV59]], [[UV123]] -++ ; CHECK-NEXT: [[ADD60:%[0-9]+]]:_(s32) = G_ADD [[UV60]], [[UV124]] -++ ; CHECK-NEXT: [[ADD61:%[0-9]+]]:_(s32) = G_ADD [[UV61]], [[UV125]] -++ ; CHECK-NEXT: [[ADD62:%[0-9]+]]:_(s32) = G_ADD [[UV62]], [[UV126]] -++ ; CHECK-NEXT: [[ADD63:%[0-9]+]]:_(s32) = G_ADD [[UV63]], [[UV127]] -+ ; CHECK-NEXT: [[BUILD_VECTOR:%[0-9]+]]:_(<16 x s32>) = G_BUILD_VECTOR [[ADD]](s32), [[ADD1]](s32), [[ADD2]](s32), [[ADD3]](s32), [[ADD4]](s32), [[ADD5]](s32), [[ADD6]](s32), [[ADD7]](s32), [[ADD8]](s32), [[ADD9]](s32), [[ADD10]](s32), [[ADD11]](s32), [[ADD12]](s32), [[ADD13]](s32), [[ADD14]](s32), [[ADD15]](s32) -+ ; CHECK-NEXT: [[BUILD_VECTOR1:%[0-9]+]]:_(<16 x s32>) = G_BUILD_VECTOR [[ADD16]](s32), [[ADD17]](s32), [[ADD18]](s32), [[ADD19]](s32), [[ADD20]](s32), [[ADD21]](s32), [[ADD22]](s32), [[ADD23]](s32), [[ADD24]](s32), [[ADD25]](s32), [[ADD26]](s32), [[ADD27]](s32), [[ADD28]](s32), [[ADD29]](s32), [[ADD30]](s32), [[ADD31]](s32) -+ ; CHECK-NEXT: [[BUILD_VECTOR2:%[0-9]+]]:_(<16 x s32>) = G_BUILD_VECTOR [[ADD32]](s32), [[ADD33]](s32), [[ADD34]](s32), [[ADD35]](s32), [[ADD36]](s32), [[ADD37]](s32), [[ADD38]](s32), [[ADD39]](s32), [[ADD40]](s32), [[ADD41]](s32), [[ADD42]](s32), [[ADD43]](s32), [[ADD44]](s32), [[ADD45]](s32), [[ADD46]](s32), [[ADD47]](s32) -+@@ -760,10 +762,10 @@ -+ ; CHECK-NEXT: G_BR %bb.2 -+ ; CHECK-NEXT: {{ $}} -+ ; CHECK-NEXT: bb.2: -+- ; CHECK-NEXT: [[PHI:%[0-9]+]]:_(<16 x s32>) = G_PHI [[UV]](<16 x s32>), %bb.0, [[BUILD_VECTOR]](<16 x s32>), %bb.1 -+- ; CHECK-NEXT: [[PHI1:%[0-9]+]]:_(<16 x s32>) = G_PHI [[UV1]](<16 x s32>), %bb.0, [[BUILD_VECTOR1]](<16 x s32>), %bb.1 -+- ; CHECK-NEXT: [[PHI2:%[0-9]+]]:_(<16 x s32>) = G_PHI [[UV2]](<16 x s32>), %bb.0, [[BUILD_VECTOR2]](<16 x s32>), %bb.1 -+- ; CHECK-NEXT: [[PHI3:%[0-9]+]]:_(<16 x s32>) = G_PHI [[UV3]](<16 x s32>), %bb.0, [[BUILD_VECTOR3]](<16 x s32>), %bb.1 -++ ; CHECK-NEXT: [[PHI:%[0-9]+]]:_(<16 x s32>) = G_PHI [[DEF]](<16 x s32>), %bb.0, [[BUILD_VECTOR]](<16 x s32>), %bb.1 -++ ; CHECK-NEXT: [[PHI1:%[0-9]+]]:_(<16 x s32>) = G_PHI [[DEF]](<16 x s32>), %bb.0, [[BUILD_VECTOR1]](<16 x s32>), %bb.1 -++ ; CHECK-NEXT: [[PHI2:%[0-9]+]]:_(<16 x s32>) = G_PHI [[DEF]](<16 x s32>), %bb.0, [[BUILD_VECTOR2]](<16 x s32>), %bb.1 -++ ; CHECK-NEXT: [[PHI3:%[0-9]+]]:_(<16 x s32>) = G_PHI [[DEF]](<16 x s32>), %bb.0, [[BUILD_VECTOR3]](<16 x s32>), %bb.1 -+ ; CHECK-NEXT: [[CONCAT_VECTORS:%[0-9]+]]:_(<64 x s32>) = G_CONCAT_VECTORS [[PHI]](<16 x s32>), [[PHI1]](<16 x s32>), [[PHI2]](<16 x s32>), [[PHI3]](<16 x s32>) -+ ; CHECK-NEXT: S_SETPC_B64 undef $sgpr30_sgpr31, implicit [[CONCAT_VECTORS]](<64 x s32>) -+ bb.0: -+diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/AMDGPU/GlobalISel/regbankselect.mir b/llvm/test/CodeGen/AMDGPU/GlobalISel/regbankselect.mir -+--- a/llvm/test/CodeGen/AMDGPU/GlobalISel/regbankselect.mir -++++ b/llvm/test/CodeGen/AMDGPU/GlobalISel/regbankselect.mir -+@@ -42,6 +42,8 @@ -+ ret void -+ } -+ -++ define void @non_power_of_2() { ret void } -++ -+ define amdgpu_kernel void @load_constant_v4i16_from_8_align8(ptr addrspace(4) %ptr0) { -+ ret void -+ } -+@@ -185,6 +187,23 @@ -+ ... -+ -+ --- -++name: non_power_of_2 -++legalized: true -++ -++body: | -++ bb.0: -++ ; CHECK-LABEL: name: non_power_of_2 -++ ; CHECK: [[DEF:%[0-9]+]]:sgpr(s448) = G_IMPLICIT_DEF -++ ; CHECK-NEXT: [[EXTRACT:%[0-9]+]]:sgpr(s32) = G_EXTRACT [[DEF]](s448), 0 -++ ; CHECK-NEXT: $sgpr0 = COPY [[EXTRACT]](s32) -++ ; CHECK-NEXT: SI_RETURN_TO_EPILOG $sgpr0 -++ %0:_(s448) = G_IMPLICIT_DEF -++ %1:_(s32) = G_EXTRACT %0:_(s448), 0 -++ $sgpr0 = COPY %1:_(s32) -++ SI_RETURN_TO_EPILOG $sgpr0 -++... -++ -++--- -+ name: load_constant_v4i16_from_8_align8 -+ legalized: true -+ -+diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/NVPTX/intrin-nocapture.ll b/llvm/test/CodeGen/NVPTX/intrin-nocapture.ll -+--- a/llvm/test/CodeGen/NVPTX/intrin-nocapture.ll -++++ b/llvm/test/CodeGen/NVPTX/intrin-nocapture.ll -+@@ -0,0 +1,21 @@ -++; RUN: opt < %s -O3 -S | FileCheck %s -++ -++; Address space intrinsics were erroneously marked NoCapture, leading to bad -++; optimizations (such as the store below being eliminated as dead code). This -++; test makes sure we don't regress. -++ -++declare void @foo(ptr addrspace(1)) -++ -++declare ptr addrspace(1) @llvm.nvvm.ptr.gen.to.global.p1.p0(ptr) -++ -++; CHECK: @bar -++define void @bar() { -++ %t1 = alloca i32 -++; CHECK: call ptr addrspace(1) @llvm.nvvm.ptr.gen.to.global.p1.p0(ptr nonnull %t1) -++; CHECK-NEXT: store i32 10, ptr %t1 -++ %t2 = call ptr addrspace(1) @llvm.nvvm.ptr.gen.to.global.p1.p0(ptr %t1) -++ store i32 10, ptr %t1 -++ call void @foo(ptr addrspace(1) %t2) -++ ret void -++} -++ -+diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/NVPTX/rotate_64.ll b/llvm/test/CodeGen/NVPTX/rotate_64.ll -+--- a/llvm/test/CodeGen/NVPTX/rotate_64.ll -++++ b/llvm/test/CodeGen/NVPTX/rotate_64.ll -+@@ -1,38 +1,25 @@ -+-; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5 -+ ; RUN: llc < %s -march=nvptx64 | FileCheck %s -+ ; RUN: %if ptxas %{ llc < %s -march=nvptx64 | %ptxas-verify %} -+ -+ declare i64 @llvm.nvvm.rotate.b64(i64, i32) -+ declare i64 @llvm.nvvm.rotate.right.b64(i64, i32) -+ -++; CHECK: rotate64 -+ define i64 @rotate64(i64 %a, i32 %b) { -+-; CHECK-LABEL: rotate64( -+-; CHECK: { -+-; CHECK-NEXT: .reg .b64 %rd<5>; -+-; CHECK-EMPTY: -+-; CHECK-NEXT: // %bb.0: -+-; CHECK-NEXT: ld.param.u64 %rd1, [rotate64_param_0]; -+-; CHECK-NEXT: shr.u64 %rd2, %rd1, 61; -+-; CHECK-NEXT: shl.b64 %rd3, %rd1, 3; -+-; CHECK-NEXT: or.b64 %rd4, %rd3, %rd2; -+-; CHECK-NEXT: st.param.b64 [func_retval0+0], %rd4; -+-; CHECK-NEXT: ret; -++; CHECK: shl.b64 [[LHS:%.*]], [[RD1:%.*]], 3; -++; CHECK: shr.b64 [[RHS:%.*]], [[RD1]], 61; -++; CHECK: add.u64 [[RD2:%.*]], [[LHS]], [[RHS]]; -++; CHECK: ret -+ %val = tail call i64 @llvm.nvvm.rotate.b64(i64 %a, i32 3) -+ ret i64 %val -+ } -+ -++; CHECK: rotateright64 -+ define i64 @rotateright64(i64 %a, i32 %b) { -+-; CHECK-LABEL: rotateright64( -+-; CHECK: { -+-; CHECK-NEXT: .reg .b64 %rd<5>; -+-; CHECK-EMPTY: -+-; CHECK-NEXT: // %bb.0: -+-; CHECK-NEXT: ld.param.u64 %rd1, [rotateright64_param_0]; -+-; CHECK-NEXT: shl.b64 %rd2, %rd1, 61; -+-; CHECK-NEXT: shr.u64 %rd3, %rd1, 3; -+-; CHECK-NEXT: or.b64 %rd4, %rd3, %rd2; -+-; CHECK-NEXT: st.param.b64 [func_retval0+0], %rd4; -+-; CHECK-NEXT: ret; -++; CHECK: shl.b64 [[LHS:%.*]], [[RD1:%.*]], 61; -++; CHECK: shr.b64 [[RHS:%.*]], [[RD1]], 3; -++; CHECK: add.u64 [[RD2:%.*]], [[LHS]], [[RHS]]; -++; CHECK: ret -+ %val = tail call i64 @llvm.nvvm.rotate.right.b64(i64 %a, i32 3) -+ ret i64 %val -+ } -+diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/NVPTX/rotate.ll b/llvm/test/CodeGen/NVPTX/rotate.ll -+--- a/llvm/test/CodeGen/NVPTX/rotate.ll -++++ b/llvm/test/CodeGen/NVPTX/rotate.ll -+@@ -9,29 +9,26 @@ -+ declare i64 @llvm.nvvm.rotate.b64(i64, i32) -+ declare i64 @llvm.nvvm.rotate.right.b64(i64, i32) -+ -+-declare i64 @llvm.fshl.i64(i64, i64, i64) -+-declare i64 @llvm.fshr.i64(i64, i64, i64) -+-declare i32 @llvm.fshl.i32(i32, i32, i32) -+-declare i32 @llvm.fshr.i32(i32, i32, i32) -+- -+- -+ ; SM20: rotate32 -+ ; SM35: rotate32 -+ define i32 @rotate32(i32 %a, i32 %b) { -+ ; SM20-LABEL: rotate32( -+ ; SM20: { -+-; SM20-NEXT: .reg .b32 %r<9>; -++; SM20-NEXT: .reg .b32 %r<4>; -+ ; SM20-EMPTY: -+ ; SM20-NEXT: // %bb.0: -+ ; SM20-NEXT: ld.param.u32 %r1, [rotate32_param_0]; -+ ; SM20-NEXT: ld.param.u32 %r2, [rotate32_param_1]; -+-; SM20-NEXT: and.b32 %r3, %r2, 31; -+-; SM20-NEXT: shl.b32 %r4, %r1, %r3; -+-; SM20-NEXT: neg.s32 %r5, %r2; -+-; SM20-NEXT: and.b32 %r6, %r5, 31; -+-; SM20-NEXT: shr.u32 %r7, %r1, %r6; -+-; SM20-NEXT: or.b32 %r8, %r4, %r7; -+-; SM20-NEXT: st.param.b32 [func_retval0+0], %r8; -++; SM20-NEXT: { -++; SM20-NEXT: .reg .b32 %lhs; -++; SM20-NEXT: .reg .b32 %rhs; -++; SM20-NEXT: .reg .b32 %amt2; -++; SM20-NEXT: shl.b32 %lhs, %r1, %r2; -++; SM20-NEXT: sub.s32 %amt2, 32, %r2; -++; SM20-NEXT: shr.b32 %rhs, %r1, %amt2; -++; SM20-NEXT: add.u32 %r3, %lhs, %rhs; -++; SM20-NEXT: } -++; SM20-NEXT: st.param.b32 [func_retval0+0], %r3; -+ ; SM20-NEXT: ret; -+ ; -+ ; SM35-LABEL: rotate32( -+@@ -53,36 +50,45 @@ -+ define i64 @rotate64(i64 %a, i32 %b) { -+ ; SM20-LABEL: rotate64( -+ ; SM20: { -+-; SM20-NEXT: .reg .b32 %r<5>; -+-; SM20-NEXT: .reg .b64 %rd<5>; -++; SM20-NEXT: .reg .b32 %r<2>; -++; SM20-NEXT: .reg .b64 %rd<3>; -+ ; SM20-EMPTY: -+ ; SM20-NEXT: // %bb.0: -+ ; SM20-NEXT: ld.param.u64 %rd1, [rotate64_param_0]; -+ ; SM20-NEXT: ld.param.u32 %r1, [rotate64_param_1]; -+-; SM20-NEXT: and.b32 %r2, %r1, 63; -+-; SM20-NEXT: shl.b64 %rd2, %rd1, %r2; -+-; SM20-NEXT: neg.s32 %r3, %r1; -+-; SM20-NEXT: and.b32 %r4, %r3, 63; -+-; SM20-NEXT: shr.u64 %rd3, %rd1, %r4; -+-; SM20-NEXT: or.b64 %rd4, %rd2, %rd3; -+-; SM20-NEXT: st.param.b64 [func_retval0+0], %rd4; -++; SM20-NEXT: { -++; SM20-NEXT: .reg .b64 %lhs; -++; SM20-NEXT: .reg .b64 %rhs; -++; SM20-NEXT: .reg .u32 %amt2; -++; SM20-NEXT: and.b32 %amt2, %r1, 63; -++; SM20-NEXT: shl.b64 %lhs, %rd1, %amt2; -++; SM20-NEXT: sub.u32 %amt2, 64, %amt2; -++; SM20-NEXT: shr.b64 %rhs, %rd1, %amt2; -++; SM20-NEXT: add.u64 %rd2, %lhs, %rhs; -++; SM20-NEXT: } -++; SM20-NEXT: st.param.b64 [func_retval0+0], %rd2; -+ ; SM20-NEXT: ret; -+ ; -+ ; SM35-LABEL: rotate64( -+ ; SM35: { -+-; SM35-NEXT: .reg .b32 %r<5>; -+-; SM35-NEXT: .reg .b64 %rd<5>; -++; SM35-NEXT: .reg .b32 %r<6>; -++; SM35-NEXT: .reg .b64 %rd<3>; -+ ; SM35-EMPTY: -+ ; SM35-NEXT: // %bb.0: -+ ; SM35-NEXT: ld.param.u64 %rd1, [rotate64_param_0]; -+-; SM35-NEXT: ld.param.u32 %r1, [rotate64_param_1]; -+-; SM35-NEXT: and.b32 %r2, %r1, 63; -+-; SM35-NEXT: shl.b64 %rd2, %rd1, %r2; -+-; SM35-NEXT: neg.s32 %r3, %r1; -+-; SM35-NEXT: and.b32 %r4, %r3, 63; -+-; SM35-NEXT: shr.u64 %rd3, %rd1, %r4; -+-; SM35-NEXT: or.b64 %rd4, %rd2, %rd3; -+-; SM35-NEXT: st.param.b64 [func_retval0+0], %rd4; -++; SM35-NEXT: { -++; SM35-NEXT: .reg .b32 %dummy; -++; SM35-NEXT: mov.b64 {%dummy,%r1}, %rd1; -++; SM35-NEXT: } -++; SM35-NEXT: { -++; SM35-NEXT: .reg .b32 %dummy; -++; SM35-NEXT: mov.b64 {%r2,%dummy}, %rd1; -++; SM35-NEXT: } -++; SM35-NEXT: ld.param.u32 %r3, [rotate64_param_1]; -++; SM35-NEXT: shf.l.wrap.b32 %r4, %r2, %r1, %r3; -++; SM35-NEXT: shf.l.wrap.b32 %r5, %r1, %r2, %r3; -++; SM35-NEXT: mov.b64 %rd2, {%r5, %r4}; -++; SM35-NEXT: st.param.b64 [func_retval0+0], %rd2; -+ ; SM35-NEXT: ret; -+ %val = tail call i64 @llvm.nvvm.rotate.b64(i64 %a, i32 %b) -+ ret i64 %val -+@@ -93,36 +99,45 @@ -+ define i64 @rotateright64(i64 %a, i32 %b) { -+ ; SM20-LABEL: rotateright64( -+ ; SM20: { -+-; SM20-NEXT: .reg .b32 %r<5>; -+-; SM20-NEXT: .reg .b64 %rd<5>; -++; SM20-NEXT: .reg .b32 %r<2>; -++; SM20-NEXT: .reg .b64 %rd<3>; -+ ; SM20-EMPTY: -+ ; SM20-NEXT: // %bb.0: -+ ; SM20-NEXT: ld.param.u64 %rd1, [rotateright64_param_0]; -+ ; SM20-NEXT: ld.param.u32 %r1, [rotateright64_param_1]; -+-; SM20-NEXT: and.b32 %r2, %r1, 63; -+-; SM20-NEXT: shr.u64 %rd2, %rd1, %r2; -+-; SM20-NEXT: neg.s32 %r3, %r1; -+-; SM20-NEXT: and.b32 %r4, %r3, 63; -+-; SM20-NEXT: shl.b64 %rd3, %rd1, %r4; -+-; SM20-NEXT: or.b64 %rd4, %rd2, %rd3; -+-; SM20-NEXT: st.param.b64 [func_retval0+0], %rd4; -++; SM20-NEXT: { -++; SM20-NEXT: .reg .b64 %lhs; -++; SM20-NEXT: .reg .b64 %rhs; -++; SM20-NEXT: .reg .u32 %amt2; -++; SM20-NEXT: and.b32 %amt2, %r1, 63; -++; SM20-NEXT: shr.b64 %lhs, %rd1, %amt2; -++; SM20-NEXT: sub.u32 %amt2, 64, %amt2; -++; SM20-NEXT: shl.b64 %rhs, %rd1, %amt2; -++; SM20-NEXT: add.u64 %rd2, %lhs, %rhs; -++; SM20-NEXT: } -++; SM20-NEXT: st.param.b64 [func_retval0+0], %rd2; -+ ; SM20-NEXT: ret; -+ ; -+ ; SM35-LABEL: rotateright64( -+ ; SM35: { -+-; SM35-NEXT: .reg .b32 %r<5>; -+-; SM35-NEXT: .reg .b64 %rd<5>; -++; SM35-NEXT: .reg .b32 %r<6>; -++; SM35-NEXT: .reg .b64 %rd<3>; -+ ; SM35-EMPTY: -+ ; SM35-NEXT: // %bb.0: -+ ; SM35-NEXT: ld.param.u64 %rd1, [rotateright64_param_0]; -+-; SM35-NEXT: ld.param.u32 %r1, [rotateright64_param_1]; -+-; SM35-NEXT: and.b32 %r2, %r1, 63; -+-; SM35-NEXT: shr.u64 %rd2, %rd1, %r2; -+-; SM35-NEXT: neg.s32 %r3, %r1; -+-; SM35-NEXT: and.b32 %r4, %r3, 63; -+-; SM35-NEXT: shl.b64 %rd3, %rd1, %r4; -+-; SM35-NEXT: or.b64 %rd4, %rd2, %rd3; -+-; SM35-NEXT: st.param.b64 [func_retval0+0], %rd4; -++; SM35-NEXT: { -++; SM35-NEXT: .reg .b32 %dummy; -++; SM35-NEXT: mov.b64 {%r1,%dummy}, %rd1; -++; SM35-NEXT: } -++; SM35-NEXT: { -++; SM35-NEXT: .reg .b32 %dummy; -++; SM35-NEXT: mov.b64 {%dummy,%r2}, %rd1; -++; SM35-NEXT: } -++; SM35-NEXT: ld.param.u32 %r3, [rotateright64_param_1]; -++; SM35-NEXT: shf.r.wrap.b32 %r4, %r2, %r1, %r3; -++; SM35-NEXT: shf.r.wrap.b32 %r5, %r1, %r2, %r3; -++; SM35-NEXT: mov.b64 %rd2, {%r5, %r4}; -++; SM35-NEXT: st.param.b64 [func_retval0+0], %rd2; -+ ; SM35-NEXT: ret; -+ %val = tail call i64 @llvm.nvvm.rotate.right.b64(i64 %a, i32 %b) -+ ret i64 %val -+@@ -133,14 +148,18 @@ -+ define i32 @rotl0(i32 %x) { -+ ; SM20-LABEL: rotl0( -+ ; SM20: { -+-; SM20-NEXT: .reg .b32 %r<5>; -++; SM20-NEXT: .reg .b32 %r<3>; -+ ; SM20-EMPTY: -+ ; SM20-NEXT: // %bb.0: -+ ; SM20-NEXT: ld.param.u32 %r1, [rotl0_param_0]; -+-; SM20-NEXT: shr.u32 %r2, %r1, 24; -+-; SM20-NEXT: shl.b32 %r3, %r1, 8; -+-; SM20-NEXT: or.b32 %r4, %r3, %r2; -+-; SM20-NEXT: st.param.b32 [func_retval0+0], %r4; -++; SM20-NEXT: { -++; SM20-NEXT: .reg .b32 %lhs; -++; SM20-NEXT: .reg .b32 %rhs; -++; SM20-NEXT: shl.b32 %lhs, %r1, 8; -++; SM20-NEXT: shr.b32 %rhs, %r1, 24; -++; SM20-NEXT: add.u32 %r2, %lhs, %rhs; -++; SM20-NEXT: } -++; SM20-NEXT: st.param.b32 [func_retval0+0], %r2; -+ ; SM20-NEXT: ret; -+ ; -+ ; SM35-LABEL: rotl0( -+@@ -158,40 +177,51 @@ -+ ret i32 %t2 -+ } -+ -++declare i64 @llvm.fshl.i64(i64, i64, i64) -++declare i64 @llvm.fshr.i64(i64, i64, i64) -++ -+ ; SM35: rotl64 -+ define i64 @rotl64(i64 %a, i64 %n) { -+ ; SM20-LABEL: rotl64( -+ ; SM20: { -+-; SM20-NEXT: .reg .b32 %r<5>; -+-; SM20-NEXT: .reg .b64 %rd<5>; -++; SM20-NEXT: .reg .b32 %r<2>; -++; SM20-NEXT: .reg .b64 %rd<3>; -+ ; SM20-EMPTY: -+ ; SM20-NEXT: // %bb.0: -+ ; SM20-NEXT: ld.param.u64 %rd1, [rotl64_param_0]; -+ ; SM20-NEXT: ld.param.u32 %r1, [rotl64_param_1]; -+-; SM20-NEXT: and.b32 %r2, %r1, 63; -+-; SM20-NEXT: shl.b64 %rd2, %rd1, %r2; -+-; SM20-NEXT: neg.s32 %r3, %r1; -+-; SM20-NEXT: and.b32 %r4, %r3, 63; -+-; SM20-NEXT: shr.u64 %rd3, %rd1, %r4; -+-; SM20-NEXT: or.b64 %rd4, %rd2, %rd3; -+-; SM20-NEXT: st.param.b64 [func_retval0+0], %rd4; -++; SM20-NEXT: { -++; SM20-NEXT: .reg .b64 %lhs; -++; SM20-NEXT: .reg .b64 %rhs; -++; SM20-NEXT: .reg .u32 %amt2; -++; SM20-NEXT: and.b32 %amt2, %r1, 63; -++; SM20-NEXT: shl.b64 %lhs, %rd1, %amt2; -++; SM20-NEXT: sub.u32 %amt2, 64, %amt2; -++; SM20-NEXT: shr.b64 %rhs, %rd1, %amt2; -++; SM20-NEXT: add.u64 %rd2, %lhs, %rhs; -++; SM20-NEXT: } -++; SM20-NEXT: st.param.b64 [func_retval0+0], %rd2; -+ ; SM20-NEXT: ret; -+ ; -+ ; SM35-LABEL: rotl64( -+ ; SM35: { -+-; SM35-NEXT: .reg .b32 %r<5>; -+-; SM35-NEXT: .reg .b64 %rd<5>; -++; SM35-NEXT: .reg .b32 %r<2>; -++; SM35-NEXT: .reg .b64 %rd<3>; -+ ; SM35-EMPTY: -+ ; SM35-NEXT: // %bb.0: -+ ; SM35-NEXT: ld.param.u64 %rd1, [rotl64_param_0]; -+ ; SM35-NEXT: ld.param.u32 %r1, [rotl64_param_1]; -+-; SM35-NEXT: and.b32 %r2, %r1, 63; -+-; SM35-NEXT: shl.b64 %rd2, %rd1, %r2; -+-; SM35-NEXT: neg.s32 %r3, %r1; -+-; SM35-NEXT: and.b32 %r4, %r3, 63; -+-; SM35-NEXT: shr.u64 %rd3, %rd1, %r4; -+-; SM35-NEXT: or.b64 %rd4, %rd2, %rd3; -+-; SM35-NEXT: st.param.b64 [func_retval0+0], %rd4; -++; SM35-NEXT: { -++; SM35-NEXT: .reg .b64 %lhs; -++; SM35-NEXT: .reg .b64 %rhs; -++; SM35-NEXT: .reg .u32 %amt2; -++; SM35-NEXT: and.b32 %amt2, %r1, 63; -++; SM35-NEXT: shl.b64 %lhs, %rd1, %amt2; -++; SM35-NEXT: sub.u32 %amt2, 64, %amt2; -++; SM35-NEXT: shr.b64 %rhs, %rd1, %amt2; -++; SM35-NEXT: add.u64 %rd2, %lhs, %rhs; -++; SM35-NEXT: } -++; SM35-NEXT: st.param.b64 [func_retval0+0], %rd2; -+ ; SM35-NEXT: ret; -+ %val = tail call i64 @llvm.fshl.i64(i64 %a, i64 %a, i64 %n) -+ ret i64 %val -+@@ -201,26 +231,34 @@ -+ define i64 @rotl64_imm(i64 %a) { -+ ; SM20-LABEL: rotl64_imm( -+ ; SM20: { -+-; SM20-NEXT: .reg .b64 %rd<5>; -++; SM20-NEXT: .reg .b64 %rd<3>; -+ ; SM20-EMPTY: -+ ; SM20-NEXT: // %bb.0: -+ ; SM20-NEXT: ld.param.u64 %rd1, [rotl64_imm_param_0]; -+-; SM20-NEXT: shr.u64 %rd2, %rd1, 62; -+-; SM20-NEXT: shl.b64 %rd3, %rd1, 2; -+-; SM20-NEXT: or.b64 %rd4, %rd3, %rd2; -+-; SM20-NEXT: st.param.b64 [func_retval0+0], %rd4; -++; SM20-NEXT: { -++; SM20-NEXT: .reg .b64 %lhs; -++; SM20-NEXT: .reg .b64 %rhs; -++; SM20-NEXT: shl.b64 %lhs, %rd1, 2; -++; SM20-NEXT: shr.b64 %rhs, %rd1, 62; -++; SM20-NEXT: add.u64 %rd2, %lhs, %rhs; -++; SM20-NEXT: } -++; SM20-NEXT: st.param.b64 [func_retval0+0], %rd2; -+ ; SM20-NEXT: ret; -+ ; -+ ; SM35-LABEL: rotl64_imm( -+ ; SM35: { -+-; SM35-NEXT: .reg .b64 %rd<5>; -++; SM35-NEXT: .reg .b64 %rd<3>; -+ ; SM35-EMPTY: -+ ; SM35-NEXT: // %bb.0: -+ ; SM35-NEXT: ld.param.u64 %rd1, [rotl64_imm_param_0]; -+-; SM35-NEXT: shr.u64 %rd2, %rd1, 62; -+-; SM35-NEXT: shl.b64 %rd3, %rd1, 2; -+-; SM35-NEXT: or.b64 %rd4, %rd3, %rd2; -+-; SM35-NEXT: st.param.b64 [func_retval0+0], %rd4; -++; SM35-NEXT: { -++; SM35-NEXT: .reg .b64 %lhs; -++; SM35-NEXT: .reg .b64 %rhs; -++; SM35-NEXT: shl.b64 %lhs, %rd1, 2; -++; SM35-NEXT: shr.b64 %rhs, %rd1, 62; -++; SM35-NEXT: add.u64 %rd2, %lhs, %rhs; -++; SM35-NEXT: } -++; SM35-NEXT: st.param.b64 [func_retval0+0], %rd2; -+ ; SM35-NEXT: ret; -+ %val = tail call i64 @llvm.fshl.i64(i64 %a, i64 %a, i64 66) -+ ret i64 %val -+@@ -230,36 +268,44 @@ -+ define i64 @rotr64(i64 %a, i64 %n) { -+ ; SM20-LABEL: rotr64( -+ ; SM20: { -+-; SM20-NEXT: .reg .b32 %r<5>; -+-; SM20-NEXT: .reg .b64 %rd<5>; -++; SM20-NEXT: .reg .b32 %r<2>; -++; SM20-NEXT: .reg .b64 %rd<3>; -+ ; SM20-EMPTY: -+ ; SM20-NEXT: // %bb.0: -+ ; SM20-NEXT: ld.param.u64 %rd1, [rotr64_param_0]; -+ ; SM20-NEXT: ld.param.u32 %r1, [rotr64_param_1]; -+-; SM20-NEXT: and.b32 %r2, %r1, 63; -+-; SM20-NEXT: shr.u64 %rd2, %rd1, %r2; -+-; SM20-NEXT: neg.s32 %r3, %r1; -+-; SM20-NEXT: and.b32 %r4, %r3, 63; -+-; SM20-NEXT: shl.b64 %rd3, %rd1, %r4; -+-; SM20-NEXT: or.b64 %rd4, %rd2, %rd3; -+-; SM20-NEXT: st.param.b64 [func_retval0+0], %rd4; -++; SM20-NEXT: { -++; SM20-NEXT: .reg .b64 %lhs; -++; SM20-NEXT: .reg .b64 %rhs; -++; SM20-NEXT: .reg .u32 %amt2; -++; SM20-NEXT: and.b32 %amt2, %r1, 63; -++; SM20-NEXT: shr.b64 %lhs, %rd1, %amt2; -++; SM20-NEXT: sub.u32 %amt2, 64, %amt2; -++; SM20-NEXT: shl.b64 %rhs, %rd1, %amt2; -++; SM20-NEXT: add.u64 %rd2, %lhs, %rhs; -++; SM20-NEXT: } -++; SM20-NEXT: st.param.b64 [func_retval0+0], %rd2; -+ ; SM20-NEXT: ret; -+ ; -+ ; SM35-LABEL: rotr64( -+ ; SM35: { -+-; SM35-NEXT: .reg .b32 %r<5>; -+-; SM35-NEXT: .reg .b64 %rd<5>; -++; SM35-NEXT: .reg .b32 %r<2>; -++; SM35-NEXT: .reg .b64 %rd<3>; -+ ; SM35-EMPTY: -+ ; SM35-NEXT: // %bb.0: -+ ; SM35-NEXT: ld.param.u64 %rd1, [rotr64_param_0]; -+ ; SM35-NEXT: ld.param.u32 %r1, [rotr64_param_1]; -+-; SM35-NEXT: and.b32 %r2, %r1, 63; -+-; SM35-NEXT: shr.u64 %rd2, %rd1, %r2; -+-; SM35-NEXT: neg.s32 %r3, %r1; -+-; SM35-NEXT: and.b32 %r4, %r3, 63; -+-; SM35-NEXT: shl.b64 %rd3, %rd1, %r4; -+-; SM35-NEXT: or.b64 %rd4, %rd2, %rd3; -+-; SM35-NEXT: st.param.b64 [func_retval0+0], %rd4; -++; SM35-NEXT: { -++; SM35-NEXT: .reg .b64 %lhs; -++; SM35-NEXT: .reg .b64 %rhs; -++; SM35-NEXT: .reg .u32 %amt2; -++; SM35-NEXT: and.b32 %amt2, %r1, 63; -++; SM35-NEXT: shr.b64 %lhs, %rd1, %amt2; -++; SM35-NEXT: sub.u32 %amt2, 64, %amt2; -++; SM35-NEXT: shl.b64 %rhs, %rd1, %amt2; -++; SM35-NEXT: add.u64 %rd2, %lhs, %rhs; -++; SM35-NEXT: } -++; SM35-NEXT: st.param.b64 [func_retval0+0], %rd2; -+ ; SM35-NEXT: ret; -+ %val = tail call i64 @llvm.fshr.i64(i64 %a, i64 %a, i64 %n) -+ ret i64 %val -+@@ -269,180 +315,35 @@ -+ define i64 @rotr64_imm(i64 %a) { -+ ; SM20-LABEL: rotr64_imm( -+ ; SM20: { -+-; SM20-NEXT: .reg .b64 %rd<5>; -++; SM20-NEXT: .reg .b64 %rd<3>; -+ ; SM20-EMPTY: -+ ; SM20-NEXT: // %bb.0: -+ ; SM20-NEXT: ld.param.u64 %rd1, [rotr64_imm_param_0]; -+-; SM20-NEXT: shl.b64 %rd2, %rd1, 62; -+-; SM20-NEXT: shr.u64 %rd3, %rd1, 2; -+-; SM20-NEXT: or.b64 %rd4, %rd3, %rd2; -+-; SM20-NEXT: st.param.b64 [func_retval0+0], %rd4; -++; SM20-NEXT: { -++; SM20-NEXT: .reg .b64 %lhs; -++; SM20-NEXT: .reg .b64 %rhs; -++; SM20-NEXT: shl.b64 %lhs, %rd1, 62; -++; SM20-NEXT: shr.b64 %rhs, %rd1, 2; -++; SM20-NEXT: add.u64 %rd2, %lhs, %rhs; -++; SM20-NEXT: } -++; SM20-NEXT: st.param.b64 [func_retval0+0], %rd2; -+ ; SM20-NEXT: ret; -+ ; -+ ; SM35-LABEL: rotr64_imm( -+ ; SM35: { -+-; SM35-NEXT: .reg .b64 %rd<5>; -++; SM35-NEXT: .reg .b64 %rd<3>; -+ ; SM35-EMPTY: -+ ; SM35-NEXT: // %bb.0: -+ ; SM35-NEXT: ld.param.u64 %rd1, [rotr64_imm_param_0]; -+-; SM35-NEXT: shl.b64 %rd2, %rd1, 62; -+-; SM35-NEXT: shr.u64 %rd3, %rd1, 2; -+-; SM35-NEXT: or.b64 %rd4, %rd3, %rd2; -+-; SM35-NEXT: st.param.b64 [func_retval0+0], %rd4; -++; SM35-NEXT: { -++; SM35-NEXT: .reg .b64 %lhs; -++; SM35-NEXT: .reg .b64 %rhs; -++; SM35-NEXT: shl.b64 %lhs, %rd1, 62; -++; SM35-NEXT: shr.b64 %rhs, %rd1, 2; -++; SM35-NEXT: add.u64 %rd2, %lhs, %rhs; -++; SM35-NEXT: } -++; SM35-NEXT: st.param.b64 [func_retval0+0], %rd2; -+ ; SM35-NEXT: ret; -+ %val = tail call i64 @llvm.fshr.i64(i64 %a, i64 %a, i64 66) -+ ret i64 %val -+ } -+- -+-define i32 @funnel_shift_right_32(i32 %a, i32 %b, i32 %c) { -+-; SM20-LABEL: funnel_shift_right_32( -+-; SM20: { -+-; SM20-NEXT: .reg .b32 %r<11>; -+-; SM20-EMPTY: -+-; SM20-NEXT: // %bb.0: -+-; SM20-NEXT: ld.param.u32 %r1, [funnel_shift_right_32_param_0]; -+-; SM20-NEXT: ld.param.u32 %r2, [funnel_shift_right_32_param_2]; -+-; SM20-NEXT: and.b32 %r3, %r2, 31; -+-; SM20-NEXT: ld.param.u32 %r4, [funnel_shift_right_32_param_1]; -+-; SM20-NEXT: shr.u32 %r5, %r4, %r3; -+-; SM20-NEXT: shl.b32 %r6, %r1, 1; -+-; SM20-NEXT: not.b32 %r7, %r2; -+-; SM20-NEXT: and.b32 %r8, %r7, 31; -+-; SM20-NEXT: shl.b32 %r9, %r6, %r8; -+-; SM20-NEXT: or.b32 %r10, %r9, %r5; -+-; SM20-NEXT: st.param.b32 [func_retval0+0], %r10; -+-; SM20-NEXT: ret; -+-; -+-; SM35-LABEL: funnel_shift_right_32( -+-; SM35: { -+-; SM35-NEXT: .reg .b32 %r<5>; -+-; SM35-EMPTY: -+-; SM35-NEXT: // %bb.0: -+-; SM35-NEXT: ld.param.u32 %r1, [funnel_shift_right_32_param_0]; -+-; SM35-NEXT: ld.param.u32 %r2, [funnel_shift_right_32_param_1]; -+-; SM35-NEXT: ld.param.u32 %r3, [funnel_shift_right_32_param_2]; -+-; SM35-NEXT: shf.r.wrap.b32 %r4, %r1, %r2, %r3; -+-; SM35-NEXT: st.param.b32 [func_retval0+0], %r4; -+-; SM35-NEXT: ret; -+- %val = call i32 @llvm.fshr.i32(i32 %a, i32 %b, i32 %c) -+- ret i32 %val -+-} -+- -+-define i32 @funnel_shift_left_32(i32 %a, i32 %b, i32 %c) { -+-; SM20-LABEL: funnel_shift_left_32( -+-; SM20: { -+-; SM20-NEXT: .reg .b32 %r<11>; -+-; SM20-EMPTY: -+-; SM20-NEXT: // %bb.0: -+-; SM20-NEXT: ld.param.u32 %r1, [funnel_shift_left_32_param_0]; -+-; SM20-NEXT: ld.param.u32 %r2, [funnel_shift_left_32_param_2]; -+-; SM20-NEXT: and.b32 %r3, %r2, 31; -+-; SM20-NEXT: shl.b32 %r4, %r1, %r3; -+-; SM20-NEXT: ld.param.u32 %r5, [funnel_shift_left_32_param_1]; -+-; SM20-NEXT: shr.u32 %r6, %r5, 1; -+-; SM20-NEXT: not.b32 %r7, %r2; -+-; SM20-NEXT: and.b32 %r8, %r7, 31; -+-; SM20-NEXT: shr.u32 %r9, %r6, %r8; -+-; SM20-NEXT: or.b32 %r10, %r4, %r9; -+-; SM20-NEXT: st.param.b32 [func_retval0+0], %r10; -+-; SM20-NEXT: ret; -+-; -+-; SM35-LABEL: funnel_shift_left_32( -+-; SM35: { -+-; SM35-NEXT: .reg .b32 %r<5>; -+-; SM35-EMPTY: -+-; SM35-NEXT: // %bb.0: -+-; SM35-NEXT: ld.param.u32 %r1, [funnel_shift_left_32_param_0]; -+-; SM35-NEXT: ld.param.u32 %r2, [funnel_shift_left_32_param_1]; -+-; SM35-NEXT: ld.param.u32 %r3, [funnel_shift_left_32_param_2]; -+-; SM35-NEXT: shf.l.wrap.b32 %r4, %r1, %r2, %r3; -+-; SM35-NEXT: st.param.b32 [func_retval0+0], %r4; -+-; SM35-NEXT: ret; -+- %val = call i32 @llvm.fshl.i32(i32 %a, i32 %b, i32 %c) -+- ret i32 %val -+-} -+- -+-define i64 @funnel_shift_right_64(i64 %a, i64 %b, i64 %c) { -+-; SM20-LABEL: funnel_shift_right_64( -+-; SM20: { -+-; SM20-NEXT: .reg .b32 %r<5>; -+-; SM20-NEXT: .reg .b64 %rd<7>; -+-; SM20-EMPTY: -+-; SM20-NEXT: // %bb.0: -+-; SM20-NEXT: ld.param.u64 %rd1, [funnel_shift_right_64_param_0]; -+-; SM20-NEXT: ld.param.u32 %r1, [funnel_shift_right_64_param_2]; -+-; SM20-NEXT: and.b32 %r2, %r1, 63; -+-; SM20-NEXT: ld.param.u64 %rd2, [funnel_shift_right_64_param_1]; -+-; SM20-NEXT: shr.u64 %rd3, %rd2, %r2; -+-; SM20-NEXT: shl.b64 %rd4, %rd1, 1; -+-; SM20-NEXT: not.b32 %r3, %r1; -+-; SM20-NEXT: and.b32 %r4, %r3, 63; -+-; SM20-NEXT: shl.b64 %rd5, %rd4, %r4; -+-; SM20-NEXT: or.b64 %rd6, %rd5, %rd3; -+-; SM20-NEXT: st.param.b64 [func_retval0+0], %rd6; -+-; SM20-NEXT: ret; -+-; -+-; SM35-LABEL: funnel_shift_right_64( -+-; SM35: { -+-; SM35-NEXT: .reg .b32 %r<5>; -+-; SM35-NEXT: .reg .b64 %rd<7>; -+-; SM35-EMPTY: -+-; SM35-NEXT: // %bb.0: -+-; SM35-NEXT: ld.param.u64 %rd1, [funnel_shift_right_64_param_0]; -+-; SM35-NEXT: ld.param.u32 %r1, [funnel_shift_right_64_param_2]; -+-; SM35-NEXT: and.b32 %r2, %r1, 63; -+-; SM35-NEXT: ld.param.u64 %rd2, [funnel_shift_right_64_param_1]; -+-; SM35-NEXT: shr.u64 %rd3, %rd2, %r2; -+-; SM35-NEXT: shl.b64 %rd4, %rd1, 1; -+-; SM35-NEXT: not.b32 %r3, %r1; -+-; SM35-NEXT: and.b32 %r4, %r3, 63; -+-; SM35-NEXT: shl.b64 %rd5, %rd4, %r4; -+-; SM35-NEXT: or.b64 %rd6, %rd5, %rd3; -+-; SM35-NEXT: st.param.b64 [func_retval0+0], %rd6; -+-; SM35-NEXT: ret; -+- %val = call i64 @llvm.fshr.i64(i64 %a, i64 %b, i64 %c) -+- ret i64 %val -+-} -+- -+-define i64 @funnel_shift_left_64(i64 %a, i64 %b, i64 %c) { -+-; SM20-LABEL: funnel_shift_left_64( -+-; SM20: { -+-; SM20-NEXT: .reg .b32 %r<5>; -+-; SM20-NEXT: .reg .b64 %rd<7>; -+-; SM20-EMPTY: -+-; SM20-NEXT: // %bb.0: -+-; SM20-NEXT: ld.param.u64 %rd1, [funnel_shift_left_64_param_0]; -+-; SM20-NEXT: ld.param.u32 %r1, [funnel_shift_left_64_param_2]; -+-; SM20-NEXT: and.b32 %r2, %r1, 63; -+-; SM20-NEXT: shl.b64 %rd2, %rd1, %r2; -+-; SM20-NEXT: ld.param.u64 %rd3, [funnel_shift_left_64_param_1]; -+-; SM20-NEXT: shr.u64 %rd4, %rd3, 1; -+-; SM20-NEXT: not.b32 %r3, %r1; -+-; SM20-NEXT: and.b32 %r4, %r3, 63; -+-; SM20-NEXT: shr.u64 %rd5, %rd4, %r4; -+-; SM20-NEXT: or.b64 %rd6, %rd2, %rd5; -+-; SM20-NEXT: st.param.b64 [func_retval0+0], %rd6; -+-; SM20-NEXT: ret; -+-; -+-; SM35-LABEL: funnel_shift_left_64( -+-; SM35: { -+-; SM35-NEXT: .reg .b32 %r<5>; -+-; SM35-NEXT: .reg .b64 %rd<7>; -+-; SM35-EMPTY: -+-; SM35-NEXT: // %bb.0: -+-; SM35-NEXT: ld.param.u64 %rd1, [funnel_shift_left_64_param_0]; -+-; SM35-NEXT: ld.param.u32 %r1, [funnel_shift_left_64_param_2]; -+-; SM35-NEXT: and.b32 %r2, %r1, 63; -+-; SM35-NEXT: shl.b64 %rd2, %rd1, %r2; -+-; SM35-NEXT: ld.param.u64 %rd3, [funnel_shift_left_64_param_1]; -+-; SM35-NEXT: shr.u64 %rd4, %rd3, 1; -+-; SM35-NEXT: not.b32 %r3, %r1; -+-; SM35-NEXT: and.b32 %r4, %r3, 63; -+-; SM35-NEXT: shr.u64 %rd5, %rd4, %r4; -+-; SM35-NEXT: or.b64 %rd6, %rd2, %rd5; -+-; SM35-NEXT: st.param.b64 [func_retval0+0], %rd6; -+-; SM35-NEXT: ret; -+- %val = call i64 @llvm.fshl.i64(i64 %a, i64 %b, i64 %c) -+- ret i64 %val -+-} -+- -+diff -ruN --strip-trailing-cr a/llvm/test/DebugInfo/NVPTX/debug-info.ll b/llvm/test/DebugInfo/NVPTX/debug-info.ll -+--- a/llvm/test/DebugInfo/NVPTX/debug-info.ll -++++ b/llvm/test/DebugInfo/NVPTX/debug-info.ll -+@@ -25,10 +25,6 @@ -+ ; CHECK-DAG: .reg .b64 %rd<8>; -+ ; CHECK: .loc [[DEBUG_INFO_CU:[0-9]+]] 5 0 -+ ; CHECK: ld.param.u32 %r{{.+}}, [{{.+}}]; -+-; CHECK: ld.param.u64 %rd{{.+}}, [{{.+}}]; -+-; CHECK: cvta.to.global.u64 %rd{{.+}}, %rd{{.+}}; -+-; CHECK: ld.param.u64 %rd{{.+}}, [{{.+}}]; -+-; CHECK: cvta.to.global.u64 %rd{{.+}}, %rd{{.+}}; -+ ; CHECK: .loc [[BUILTUIN_VARS_H:[0-9]+]] 78 180 -+ ; CHECK: mov.u32 %r{{.+}}, %ctaid.x; -+ ; CHECK: .loc [[BUILTUIN_VARS_H]] 89 180 -+@@ -42,6 +38,10 @@ -+ ; CHECK: .loc [[DEBUG_INFO_CU]] 7 7 -+ ; CHECK: @%p{{.+}} bra [[BB:\$L__.+]]; -+ ; CHECK: ld.param.f32 %f{{.+}}, [{{.+}}]; -++; CHECK: ld.param.u64 %rd{{.+}}, [{{.+}}]; -++; CHECK: cvta.to.global.u64 %rd{{.+}}, %rd{{.+}}; -++; CHECK: ld.param.u64 %rd{{.+}}, [{{.+}}]; -++; CHECK: cvta.to.global.u64 %rd{{.+}}, %rd{{.+}}; -+ ; CHECK: .loc [[DEBUG_INFO_CU]] 8 13 -+ ; CHECK: mul.wide.u32 %rd{{.+}}, %r{{.+}}, 4; -+ ; CHECK: add.s64 %rd{{.+}}, %rd{{.+}}, %rd{{.+}}; -+@@ -2661,22 +2661,22 @@ -+ ; CHECK-NEXT:.b32 4579 // DW_AT_type -+ ; CHECK-NEXT:.b8 25 // Abbrev [25] 0x8aa:0x18 DW_TAG_inlined_subroutine -+ ; CHECK-NEXT:.b32 707 // DW_AT_abstract_origin -+-; CHECK-NEXT:.b64 $L__tmp1 // DW_AT_low_pc -+-; CHECK-NEXT:.b64 $L__tmp2 // DW_AT_high_pc -++; CHECK-NEXT:.b64 $L__tmp0 // DW_AT_low_pc -++; CHECK-NEXT:.b64 $L__tmp1 // DW_AT_high_pc -+ ; CHECK-NEXT:.b8 1 // DW_AT_call_file -+ ; CHECK-NEXT:.b8 6 // DW_AT_call_line -+ ; CHECK-NEXT:.b8 11 // DW_AT_call_column -+ ; CHECK-NEXT:.b8 25 // Abbrev [25] 0x8c2:0x18 DW_TAG_inlined_subroutine -+ ; CHECK-NEXT:.b32 1466 // DW_AT_abstract_origin -+-; CHECK-NEXT:.b64 $L__tmp2 // DW_AT_low_pc -+-; CHECK-NEXT:.b64 $L__tmp3 // DW_AT_high_pc -++; CHECK-NEXT:.b64 $L__tmp1 // DW_AT_low_pc -++; CHECK-NEXT:.b64 $L__tmp2 // DW_AT_high_pc -+ ; CHECK-NEXT:.b8 1 // DW_AT_call_file -+ ; CHECK-NEXT:.b8 6 // DW_AT_call_line -+ ; CHECK-NEXT:.b8 24 // DW_AT_call_column -+ ; CHECK-NEXT:.b8 25 // Abbrev [25] 0x8da:0x18 DW_TAG_inlined_subroutine -+ ; CHECK-NEXT:.b32 2060 // DW_AT_abstract_origin -+-; CHECK-NEXT:.b64 $L__tmp3 // DW_AT_low_pc -+-; CHECK-NEXT:.b64 $L__tmp4 // DW_AT_high_pc -++; CHECK-NEXT:.b64 $L__tmp2 // DW_AT_low_pc -++; CHECK-NEXT:.b64 $L__tmp3 // DW_AT_high_pc -+ ; CHECK-NEXT:.b8 1 // DW_AT_call_file -+ ; CHECK-NEXT:.b8 6 // DW_AT_call_line -+ ; CHECK-NEXT:.b8 37 // DW_AT_call_column +-diff -ruN --strip-trailing-cr a/llvm/docs/NVPTXUsage.rst b/llvm/docs/NVPTXUsage.rst +---- a/llvm/docs/NVPTXUsage.rst +-+++ b/llvm/docs/NVPTXUsage.rst +-@@ -127,6 +127,69 @@ +- NVPTX Intrinsics +- ================ +- +-+Address Space Conversion +-+------------------------ +-+ +-+'``llvm.nvvm.ptr.*.to.gen``' Intrinsics +-+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +-+ +-+Syntax: +-+""""""" +-+ +-+These are overloaded intrinsics. You can use these on any pointer types. +-+ +-+.. code-block:: llvm +-+ +-+ declare ptr @llvm.nvvm.ptr.global.to.gen.p0.p1(ptr addrspace(1)) +-+ declare ptr @llvm.nvvm.ptr.shared.to.gen.p0.p3(ptr addrspace(3)) +-+ declare ptr @llvm.nvvm.ptr.constant.to.gen.p0.p4(ptr addrspace(4)) +-+ declare ptr @llvm.nvvm.ptr.local.to.gen.p0.p5(ptr addrspace(5)) +-+ +-+Overview: +-+""""""""" +-+ +-+The '``llvm.nvvm.ptr.*.to.gen``' intrinsics convert a pointer in a non-generic +-+address space to a generic address space pointer. +-+ +-+Semantics: +-+"""""""""" +-+ +-+These intrinsics modify the pointer value to be a valid generic address space +-+pointer. +-+ +-+ +-+'``llvm.nvvm.ptr.gen.to.*``' Intrinsics +-+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +-+ +-+Syntax: +-+""""""" +-+ +-+These are overloaded intrinsics. You can use these on any pointer types. +-+ +-+.. code-block:: llvm +-+ +-+ declare ptr addrspace(1) @llvm.nvvm.ptr.gen.to.global.p1.p0(ptr) +-+ declare ptr addrspace(3) @llvm.nvvm.ptr.gen.to.shared.p3.p0(ptr) +-+ declare ptr addrspace(4) @llvm.nvvm.ptr.gen.to.constant.p4.p0(ptr) +-+ declare ptr addrspace(5) @llvm.nvvm.ptr.gen.to.local.p5.p0(ptr) +-+ +-+Overview: +-+""""""""" +-+ +-+The '``llvm.nvvm.ptr.gen.to.*``' intrinsics convert a pointer in the generic +-+address space to a pointer in the target address space. Note that these +-+intrinsics are only useful if the address space of the target address space of +-+the pointer is known. It is not legal to use address space conversion +-+intrinsics to convert a pointer from one non-generic address space to another +-+non-generic address space. +-+ +-+Semantics: +-+"""""""""" +-+ +-+These intrinsics modify the pointer value to be a valid pointer in the target +-+non-generic address space. +-+ +-+ +- Reading PTX Special Registers +- ----------------------------- +- +-diff -ruN --strip-trailing-cr a/llvm/docs/ReleaseNotes.rst b/llvm/docs/ReleaseNotes.rst +---- a/llvm/docs/ReleaseNotes.rst +-+++ b/llvm/docs/ReleaseNotes.rst +-@@ -63,24 +63,6 @@ +- * ``llvm.nvvm.bitcast.d2ll`` +- * ``llvm.nvvm.bitcast.ll2d`` +- +--* Remove the following intrinsics which can be replaced with a funnel-shift: +-- +-- * ``llvm.nvvm.rotate.b32`` +-- * ``llvm.nvvm.rotate.right.b64`` +-- * ``llvm.nvvm.rotate.b64`` +-- +--* Remove the following intrinsics which can be replaced with an +-- ``addrspacecast``: +-- +-- * ``llvm.nvvm.ptr.gen.to.global`` +-- * ``llvm.nvvm.ptr.gen.to.shared`` +-- * ``llvm.nvvm.ptr.gen.to.constant`` +-- * ``llvm.nvvm.ptr.gen.to.local`` +-- * ``llvm.nvvm.ptr.global.to.gen`` +-- * ``llvm.nvvm.ptr.shared.to.gen`` +-- * ``llvm.nvvm.ptr.constant.to.gen`` +-- * ``llvm.nvvm.ptr.local.to.gen`` +-- +- Changes to LLVM infrastructure +- ------------------------------ +- +-diff -ruN --strip-trailing-cr a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td +---- a/llvm/include/llvm/IR/IntrinsicsNVVM.td +-+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td +-@@ -30,18 +30,10 @@ +- // * llvm.nvvm.max.ui --> select(x ule y, x, y) +- // * llvm.nvvm.max.ull --> ibid. +- // * llvm.nvvm.h2f --> llvm.convert.to.fp16.f32 +--// * llvm.nvvm.bitcast.f2i --> bitcast +--// * llvm.nvvm.bitcast.i2f --> ibid. +--// * llvm.nvvm.bitcast.d2ll --> ibid. +--// * llvm.nvvm.bitcast.ll2d --> ibid. +--// * llvm.nvvm.ptr.gen.to.global --> addrspacecast +--// * llvm.nvvm.ptr.gen.to.shared --> ibid. +--// * llvm.nvvm.ptr.gen.to.constant --> ibid. +--// * llvm.nvvm.ptr.gen.to.local --> ibid. +--// * llvm.nvvm.ptr.global.to.gen --> ibid. +--// * llvm.nvvm.ptr.shared.to.gen --> ibid. +--// * llvm.nvvm.ptr.constant.to.gen --> ibid. +--// * llvm.nvvm.ptr.local.to.gen --> ibid. +-+// * llvm.nvvm.bitcast.f2i --> bitcast +-+// * llvm.nvvm.bitcast.i2f --> ibid. +-+// * llvm.nvvm.bitcast.d2ll --> ibid. +-+// * llvm.nvvm.bitcast.ll2d --> ibid. +- +- def llvm_global_ptr_ty : LLVMQualPointerType<1>; // (global)ptr +- def llvm_shared_ptr_ty : LLVMQualPointerType<3>; // (shared)ptr +-@@ -1610,6 +1602,40 @@ +- [IntrReadMem, IntrArgMemOnly, IntrNoCallback, IntrWillReturn, NoCapture>], +- "llvm.nvvm.ldg.global.p">; +- +-+// Use for generic pointers +-+// - These intrinsics are used to convert address spaces. +-+// - The input pointer and output pointer must have the same type, except for +-+// the address-space. (This restriction is not enforced here as there is +-+// currently no way to describe it). +-+// - This complements the llvm bitcast, which can be used to cast one type +-+// of pointer to another type of pointer, while the address space remains +-+// the same. +-+def int_nvvm_ptr_local_to_gen: DefaultAttrsIntrinsic<[llvm_anyptr_ty], +-+ [llvm_anyptr_ty], [IntrNoMem, IntrSpeculatable], +-+ "llvm.nvvm.ptr.local.to.gen">; +-+def int_nvvm_ptr_shared_to_gen: DefaultAttrsIntrinsic<[llvm_anyptr_ty], +-+ [llvm_anyptr_ty], [IntrNoMem, IntrSpeculatable], +-+ "llvm.nvvm.ptr.shared.to.gen">; +-+def int_nvvm_ptr_global_to_gen: DefaultAttrsIntrinsic<[llvm_anyptr_ty], +-+ [llvm_anyptr_ty], [IntrNoMem, IntrSpeculatable], +-+ "llvm.nvvm.ptr.global.to.gen">; +-+def int_nvvm_ptr_constant_to_gen: DefaultAttrsIntrinsic<[llvm_anyptr_ty], +-+ [llvm_anyptr_ty], [IntrNoMem, IntrSpeculatable], +-+ "llvm.nvvm.ptr.constant.to.gen">; +-+ +-+def int_nvvm_ptr_gen_to_global: DefaultAttrsIntrinsic<[llvm_anyptr_ty], +-+ [llvm_anyptr_ty], [IntrNoMem, IntrSpeculatable], +-+ "llvm.nvvm.ptr.gen.to.global">; +-+def int_nvvm_ptr_gen_to_shared: DefaultAttrsIntrinsic<[llvm_anyptr_ty], +-+ [llvm_anyptr_ty], [IntrNoMem, IntrSpeculatable], +-+ "llvm.nvvm.ptr.gen.to.shared">; +-+def int_nvvm_ptr_gen_to_local: DefaultAttrsIntrinsic<[llvm_anyptr_ty], +-+ [llvm_anyptr_ty], [IntrNoMem, IntrSpeculatable], +-+ "llvm.nvvm.ptr.gen.to.local">; +-+def int_nvvm_ptr_gen_to_constant: DefaultAttrsIntrinsic<[llvm_anyptr_ty], +-+ [llvm_anyptr_ty], [IntrNoMem, IntrSpeculatable], +-+ "llvm.nvvm.ptr.gen.to.constant">; +-+ +- // Used in nvvm internally to help address space opt and ptx code generation +- // This is for params that are passed to kernel functions by pointer by-val. +- def int_nvvm_ptr_gen_to_param: Intrinsic<[llvm_anyptr_ty], +-@@ -4453,6 +4479,22 @@ +- "llvm.nvvm.sust.p.3d.v4i32.trap">, +- ClangBuiltin<"__nvvm_sust_p_3d_v4i32_trap">; +- +-+ +-+def int_nvvm_rotate_b32 +-+ : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty], +-+ [IntrNoMem, IntrSpeculatable], "llvm.nvvm.rotate.b32">, +-+ ClangBuiltin<"__nvvm_rotate_b32">; +-+ +-+def int_nvvm_rotate_b64 +-+ : DefaultAttrsIntrinsic<[llvm_i64_ty], [llvm_i64_ty, llvm_i32_ty], +-+ [IntrNoMem, IntrSpeculatable], "llvm.nvvm.rotate.b64">, +-+ ClangBuiltin<"__nvvm_rotate_b64">; +-+ +-+def int_nvvm_rotate_right_b64 +-+ : DefaultAttrsIntrinsic<[llvm_i64_ty], [llvm_i64_ty, llvm_i32_ty], +-+ [IntrNoMem, IntrSpeculatable], "llvm.nvvm.rotate.right.b64">, +-+ ClangBuiltin<"__nvvm_rotate_right_b64">; +-+ +- def int_nvvm_swap_lo_hi_b64 +- : DefaultAttrsIntrinsic<[llvm_i64_ty], [llvm_i64_ty], +- [IntrNoMem, IntrSpeculatable], "llvm.nvvm.swap.lo.hi.b64">, +-diff -ruN --strip-trailing-cr a/llvm/lib/IR/AutoUpgrade.cpp b/llvm/lib/IR/AutoUpgrade.cpp +---- a/llvm/lib/IR/AutoUpgrade.cpp +-+++ b/llvm/lib/IR/AutoUpgrade.cpp +-@@ -1272,19 +1272,6 @@ +- // nvvm.bitcast.{f2i,i2f,ll2d,d2ll} +- Expand = +- Name == "f2i" || Name == "i2f" || Name == "ll2d" || Name == "d2ll"; +-- else if (Name.consume_front("rotate.")) +-- // nvvm.rotate.{b32,b64,right.b64} +-- Expand = Name == "b32" || Name == "b64" || Name == "right.b64"; +-- else if (Name.consume_front("ptr.gen.to.")) +-- // nvvm.ptr.gen.to.{local,shared,global,constant} +-- Expand = Name.starts_with("local") || Name.starts_with("shared") || +-- Name.starts_with("global") || Name.starts_with("constant"); +-- else if (Name.consume_front("ptr.")) +-- // nvvm.ptr.{local,shared,global,constant}.to.gen +-- Expand = +-- (Name.consume_front("local") || Name.consume_front("shared") || +-- Name.consume_front("global") || Name.consume_front("constant")) && +-- Name.starts_with(".to.gen"); +- else +- Expand = false; +- +-@@ -2271,117 +2258,6 @@ +- } +- } +- +--static Value *upgradeNVVMIntrinsicCall(StringRef Name, CallBase *CI, +-- Function *F, IRBuilder<> &Builder) { +-- Value *Rep = nullptr; +-- +-- if (Name == "abs.i" || Name == "abs.ll") { +-- Value *Arg = CI->getArgOperand(0); +-- Value *Neg = Builder.CreateNeg(Arg, "neg"); +-- Value *Cmp = Builder.CreateICmpSGE( +-- Arg, llvm::Constant::getNullValue(Arg->getType()), "abs.cond"); +-- Rep = Builder.CreateSelect(Cmp, Arg, Neg, "abs"); +-- } else if (Name.starts_with("atomic.load.add.f32.p") || +-- Name.starts_with("atomic.load.add.f64.p")) { +-- Value *Ptr = CI->getArgOperand(0); +-- Value *Val = CI->getArgOperand(1); +-- Rep = Builder.CreateAtomicRMW(AtomicRMWInst::FAdd, Ptr, Val, MaybeAlign(), +-- AtomicOrdering::SequentiallyConsistent); +-- } else if (Name.consume_front("max.") && +-- (Name == "s" || Name == "i" || Name == "ll" || Name == "us" || +-- Name == "ui" || Name == "ull")) { +-- Value *Arg0 = CI->getArgOperand(0); +-- Value *Arg1 = CI->getArgOperand(1); +-- Value *Cmp = Name.starts_with("u") +-- ? Builder.CreateICmpUGE(Arg0, Arg1, "max.cond") +-- : Builder.CreateICmpSGE(Arg0, Arg1, "max.cond"); +-- Rep = Builder.CreateSelect(Cmp, Arg0, Arg1, "max"); +-- } else if (Name.consume_front("min.") && +-- (Name == "s" || Name == "i" || Name == "ll" || Name == "us" || +-- Name == "ui" || Name == "ull")) { +-- Value *Arg0 = CI->getArgOperand(0); +-- Value *Arg1 = CI->getArgOperand(1); +-- Value *Cmp = Name.starts_with("u") +-- ? Builder.CreateICmpULE(Arg0, Arg1, "min.cond") +-- : Builder.CreateICmpSLE(Arg0, Arg1, "min.cond"); +-- Rep = Builder.CreateSelect(Cmp, Arg0, Arg1, "min"); +-- } else if (Name == "clz.ll") { +-- // llvm.nvvm.clz.ll returns an i32, but llvm.ctlz.i64 returns an i64. +-- Value *Arg = CI->getArgOperand(0); +-- Value *Ctlz = Builder.CreateCall( +-- Intrinsic::getDeclaration(F->getParent(), Intrinsic::ctlz, +-- {Arg->getType()}), +-- {Arg, Builder.getFalse()}, "ctlz"); +-- Rep = Builder.CreateTrunc(Ctlz, Builder.getInt32Ty(), "ctlz.trunc"); +-- } else if (Name == "popc.ll") { +-- // llvm.nvvm.popc.ll returns an i32, but llvm.ctpop.i64 returns an +-- // i64. +-- Value *Arg = CI->getArgOperand(0); +-- Value *Popc = Builder.CreateCall( +-- Intrinsic::getDeclaration(F->getParent(), Intrinsic::ctpop, +-- {Arg->getType()}), +-- Arg, "ctpop"); +-- Rep = Builder.CreateTrunc(Popc, Builder.getInt32Ty(), "ctpop.trunc"); +-- } else if (Name == "h2f") { +-- Rep = Builder.CreateCall( +-- Intrinsic::getDeclaration(F->getParent(), Intrinsic::convert_from_fp16, +-- {Builder.getFloatTy()}), +-- CI->getArgOperand(0), "h2f"); +-- } else if (Name.consume_front("bitcast.") && +-- (Name == "f2i" || Name == "i2f" || Name == "ll2d" || +-- Name == "d2ll")) { +-- Rep = Builder.CreateBitCast(CI->getArgOperand(0), CI->getType()); +-- } else if (Name == "rotate.b32") { +-- Value *Arg = CI->getOperand(0); +-- Value *ShiftAmt = CI->getOperand(1); +-- Rep = Builder.CreateIntrinsic(Builder.getInt32Ty(), Intrinsic::fshl, +-- {Arg, Arg, ShiftAmt}); +-- } else if (Name == "rotate.b64") { +-- Type *Int64Ty = Builder.getInt64Ty(); +-- Value *Arg = CI->getOperand(0); +-- Value *ZExtShiftAmt = Builder.CreateZExt(CI->getOperand(1), Int64Ty); +-- Rep = Builder.CreateIntrinsic(Int64Ty, Intrinsic::fshl, +-- {Arg, Arg, ZExtShiftAmt}); +-- } else if (Name == "rotate.right.b64") { +-- Type *Int64Ty = Builder.getInt64Ty(); +-- Value *Arg = CI->getOperand(0); +-- Value *ZExtShiftAmt = Builder.CreateZExt(CI->getOperand(1), Int64Ty); +-- Rep = Builder.CreateIntrinsic(Int64Ty, Intrinsic::fshr, +-- {Arg, Arg, ZExtShiftAmt}); +-- } else if ((Name.consume_front("ptr.gen.to.") && +-- (Name.starts_with("local") || Name.starts_with("shared") || +-- Name.starts_with("global") || Name.starts_with("constant"))) || +-- (Name.consume_front("ptr.") && +-- (Name.consume_front("local") || Name.consume_front("shared") || +-- Name.consume_front("global") || +-- Name.consume_front("constant")) && +-- Name.starts_with(".to.gen"))) { +-- Rep = Builder.CreateAddrSpaceCast(CI->getArgOperand(0), CI->getType()); +-- } else { +-- Intrinsic::ID IID = shouldUpgradeNVPTXBF16Intrinsic(Name); +-- if (IID != Intrinsic::not_intrinsic && +-- !F->getReturnType()->getScalarType()->isBFloatTy()) { +-- rename(F); +-- Function *NewFn = Intrinsic::getDeclaration(F->getParent(), IID); +-- SmallVector Args; +-- for (size_t I = 0; I < NewFn->arg_size(); ++I) { +-- Value *Arg = CI->getArgOperand(I); +-- Type *OldType = Arg->getType(); +-- Type *NewType = NewFn->getArg(I)->getType(); +-- Args.push_back( +-- (OldType->isIntegerTy() && NewType->getScalarType()->isBFloatTy()) +-- ? Builder.CreateBitCast(Arg, NewType) +-- : Arg); +-- } +-- Rep = Builder.CreateCall(NewFn, Args); +-- if (F->getReturnType()->isIntegerTy()) +-- Rep = Builder.CreateBitCast(Rep, F->getReturnType()); +-- } +-- } +-- +-- return Rep; +--} +-- +- static Value *upgradeX86IntrinsicCall(StringRef Name, CallBase *CI, Function *F, +- IRBuilder<> &Builder) { +- LLVMContext &C = F->getContext(); +-@@ -4332,8 +4208,85 @@ +- +- if (!IsX86 && Name == "stackprotectorcheck") { +- Rep = nullptr; +-+ } else if (IsNVVM && (Name == "abs.i" || Name == "abs.ll")) { +-+ Value *Arg = CI->getArgOperand(0); +-+ Value *Neg = Builder.CreateNeg(Arg, "neg"); +-+ Value *Cmp = Builder.CreateICmpSGE( +-+ Arg, llvm::Constant::getNullValue(Arg->getType()), "abs.cond"); +-+ Rep = Builder.CreateSelect(Cmp, Arg, Neg, "abs"); +-+ } else if (IsNVVM && (Name.starts_with("atomic.load.add.f32.p") || +-+ Name.starts_with("atomic.load.add.f64.p"))) { +-+ Value *Ptr = CI->getArgOperand(0); +-+ Value *Val = CI->getArgOperand(1); +-+ Rep = Builder.CreateAtomicRMW(AtomicRMWInst::FAdd, Ptr, Val, MaybeAlign(), +-+ AtomicOrdering::SequentiallyConsistent); +-+ } else if (IsNVVM && Name.consume_front("max.") && +-+ (Name == "s" || Name == "i" || Name == "ll" || Name == "us" || +-+ Name == "ui" || Name == "ull")) { +-+ Value *Arg0 = CI->getArgOperand(0); +-+ Value *Arg1 = CI->getArgOperand(1); +-+ Value *Cmp = Name.starts_with("u") +-+ ? Builder.CreateICmpUGE(Arg0, Arg1, "max.cond") +-+ : Builder.CreateICmpSGE(Arg0, Arg1, "max.cond"); +-+ Rep = Builder.CreateSelect(Cmp, Arg0, Arg1, "max"); +-+ } else if (IsNVVM && Name.consume_front("min.") && +-+ (Name == "s" || Name == "i" || Name == "ll" || Name == "us" || +-+ Name == "ui" || Name == "ull")) { +-+ Value *Arg0 = CI->getArgOperand(0); +-+ Value *Arg1 = CI->getArgOperand(1); +-+ Value *Cmp = Name.starts_with("u") +-+ ? Builder.CreateICmpULE(Arg0, Arg1, "min.cond") +-+ : Builder.CreateICmpSLE(Arg0, Arg1, "min.cond"); +-+ Rep = Builder.CreateSelect(Cmp, Arg0, Arg1, "min"); +-+ } else if (IsNVVM && Name == "clz.ll") { +-+ // llvm.nvvm.clz.ll returns an i32, but llvm.ctlz.i64 returns an i64. +-+ Value *Arg = CI->getArgOperand(0); +-+ Value *Ctlz = Builder.CreateCall( +-+ Intrinsic::getDeclaration(F->getParent(), Intrinsic::ctlz, +-+ {Arg->getType()}), +-+ {Arg, Builder.getFalse()}, "ctlz"); +-+ Rep = Builder.CreateTrunc(Ctlz, Builder.getInt32Ty(), "ctlz.trunc"); +-+ } else if (IsNVVM && Name == "popc.ll") { +-+ // llvm.nvvm.popc.ll returns an i32, but llvm.ctpop.i64 returns an +-+ // i64. +-+ Value *Arg = CI->getArgOperand(0); +-+ Value *Popc = Builder.CreateCall( +-+ Intrinsic::getDeclaration(F->getParent(), Intrinsic::ctpop, +-+ {Arg->getType()}), +-+ Arg, "ctpop"); +-+ Rep = Builder.CreateTrunc(Popc, Builder.getInt32Ty(), "ctpop.trunc"); +- } else if (IsNVVM) { +-- Rep = upgradeNVVMIntrinsicCall(Name, CI, F, Builder); +-+ if (Name == "h2f") { +-+ Rep = +-+ Builder.CreateCall(Intrinsic::getDeclaration( +-+ F->getParent(), Intrinsic::convert_from_fp16, +-+ {Builder.getFloatTy()}), +-+ CI->getArgOperand(0), "h2f"); +-+ } else if (Name.consume_front("bitcast.") && +-+ (Name == "f2i" || Name == "i2f" || Name == "ll2d" || +-+ Name == "d2ll")) { +-+ Rep = Builder.CreateBitCast(CI->getArgOperand(0), CI->getType()); +-+ } else { +-+ Intrinsic::ID IID = shouldUpgradeNVPTXBF16Intrinsic(Name); +-+ if (IID != Intrinsic::not_intrinsic && +-+ !F->getReturnType()->getScalarType()->isBFloatTy()) { +-+ rename(F); +-+ NewFn = Intrinsic::getDeclaration(F->getParent(), IID); +-+ SmallVector Args; +-+ for (size_t I = 0; I < NewFn->arg_size(); ++I) { +-+ Value *Arg = CI->getArgOperand(I); +-+ Type *OldType = Arg->getType(); +-+ Type *NewType = NewFn->getArg(I)->getType(); +-+ Args.push_back((OldType->isIntegerTy() && +-+ NewType->getScalarType()->isBFloatTy()) +-+ ? Builder.CreateBitCast(Arg, NewType) +-+ : Arg); +-+ } +-+ Rep = Builder.CreateCall(NewFn, Args); +-+ if (F->getReturnType()->isIntegerTy()) +-+ Rep = Builder.CreateBitCast(Rep, F->getReturnType()); +-+ } +-+ } +- } else if (IsX86) { +- Rep = upgradeX86IntrinsicCall(Name, CI, F, Builder); +- } else if (IsARM) { +-diff -ruN --strip-trailing-cr a/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp b/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp +---- a/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp +-+++ b/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp +-@@ -292,7 +292,6 @@ +- static const LLT S224 = LLT::scalar(224); +- static const LLT S256 = LLT::scalar(256); +- static const LLT S512 = LLT::scalar(512); +--static const LLT S1024 = LLT::scalar(1024); +- static const LLT MaxScalar = LLT::scalar(MaxRegisterSize); +- +- static const LLT V2S8 = LLT::fixed_vector(2, 8); +-@@ -333,8 +332,8 @@ +- static const LLT V2S128 = LLT::fixed_vector(2, 128); +- static const LLT V4S128 = LLT::fixed_vector(4, 128); +- +--static std::initializer_list AllScalarTypes = { +-- S32, S64, S96, S128, S160, S224, S256, S512, S1024}; +-+static std::initializer_list AllScalarTypes = {S32, S64, S96, S128, +-+ S160, S224, S256, S512}; +- +- static std::initializer_list AllS16Vectors{ +- V2S16, V4S16, V6S16, V8S16, V10S16, V12S16, V16S16, V2S128, V4S128}; +-@@ -890,11 +889,10 @@ +- .clampScalar(0, S16, S64); +- +- getActionDefinitionsBuilder({G_IMPLICIT_DEF, G_FREEZE}) +-- .legalIf(isRegisterClassType(0)) +-+ .legalIf(isRegisterType(0)) +- // s1 and s16 are special cases because they have legal operations on +- // them, but don't really occupy registers in the normal way. +- .legalFor({S1, S16}) +-- .clampNumElements(0, V16S32, V32S32) +- .moreElementsIf(isSmallOddVector(0), oneMoreElement(0)) +- .clampScalarOrElt(0, S32, MaxScalar) +- .widenScalarToNextPow2(0, 32) +-diff -ruN --strip-trailing-cr a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td +---- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td +-+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td +-@@ -174,6 +174,10 @@ +- def hasSHFL : Predicate<"!(Subtarget->getSmVersion() >= 70" +- "&& Subtarget->getPTXVersion() >= 64)">; +- +-+def useShortPtrLocal : Predicate<"TM.is64Bit() && TM.getPointerSizeInBits(ADDRESS_SPACE_LOCAL) == 32">; +-+def useShortPtrShared : Predicate<"TM.is64Bit() && TM.getPointerSizeInBits(ADDRESS_SPACE_SHARED) == 32">; +-+def useShortPtrConst : Predicate<"TM.is64Bit() && TM.getPointerSizeInBits(ADDRESS_SPACE_CONST) == 32">; +-+ +- def useFP16Math: Predicate<"Subtarget->allowFP16Math()">; +- def hasBF16Math: Predicate<"Subtarget->hasBF16Math()">; +- +-@@ -1661,6 +1665,167 @@ +- "brev.b64 \t$dst, $a;", +- [(set Int64Regs:$dst, (bitreverse Int64Regs:$a))]>; +- +-+// +-+// Rotate: Use ptx shf instruction if available. +-+// +-+ +-+// 32 bit r2 = rotl r1, n +-+// => +-+// r2 = shf.l r1, r1, n +-+def ROTL32imm_hw : +-+ NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$src, i32imm:$amt), +-+ "shf.l.wrap.b32 \t$dst, $src, $src, $amt;", +-+ [(set Int32Regs:$dst, (rotl (i32 Int32Regs:$src), (i32 imm:$amt)))]>, +-+ Requires<[hasHWROT32]>; +-+ +-+def ROTL32reg_hw : +-+ NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$src, Int32Regs:$amt), +-+ "shf.l.wrap.b32 \t$dst, $src, $src, $amt;", +-+ [(set Int32Regs:$dst, (rotl (i32 Int32Regs:$src), (i32 Int32Regs:$amt)))]>, +-+ Requires<[hasHWROT32]>; +-+ +-+// 32 bit r2 = rotr r1, n +-+// => +-+// r2 = shf.r r1, r1, n +-+def ROTR32imm_hw : +-+ NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$src, i32imm:$amt), +-+ "shf.r.wrap.b32 \t$dst, $src, $src, $amt;", +-+ [(set Int32Regs:$dst, (rotr (i32 Int32Regs:$src), (i32 imm:$amt)))]>, +-+ Requires<[hasHWROT32]>; +-+ +-+def ROTR32reg_hw : +-+ NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$src, Int32Regs:$amt), +-+ "shf.r.wrap.b32 \t$dst, $src, $src, $amt;", +-+ [(set Int32Regs:$dst, (rotr (i32 Int32Regs:$src), (i32 Int32Regs:$amt)))]>, +-+ Requires<[hasHWROT32]>; +-+ +-+// 32-bit software rotate by immediate. $amt2 should equal 32 - $amt1. +-+def ROT32imm_sw : +-+ NVPTXInst<(outs Int32Regs:$dst), +-+ (ins Int32Regs:$src, i32imm:$amt1, i32imm:$amt2), +-+ "{{\n\t" +-+ ".reg .b32 %lhs;\n\t" +-+ ".reg .b32 %rhs;\n\t" +-+ "shl.b32 \t%lhs, $src, $amt1;\n\t" +-+ "shr.b32 \t%rhs, $src, $amt2;\n\t" +-+ "add.u32 \t$dst, %lhs, %rhs;\n\t" +-+ "}}", +-+ []>; +-+ +-+def SUB_FRM_32 : SDNodeXFormgetTargetConstant(32 - N->getZExtValue(), SDLoc(N), MVT::i32); +-+}]>; +-+ +-+def : Pat<(rotl (i32 Int32Regs:$src), (i32 imm:$amt)), +-+ (ROT32imm_sw Int32Regs:$src, imm:$amt, (SUB_FRM_32 node:$amt))>, +-+ Requires<[noHWROT32]>; +-+def : Pat<(rotr (i32 Int32Regs:$src), (i32 imm:$amt)), +-+ (ROT32imm_sw Int32Regs:$src, (SUB_FRM_32 node:$amt), imm:$amt)>, +-+ Requires<[noHWROT32]>; +-+ +-+// 32-bit software rotate left by register. +-+def ROTL32reg_sw : +-+ NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$src, Int32Regs:$amt), +-+ "{{\n\t" +-+ ".reg .b32 %lhs;\n\t" +-+ ".reg .b32 %rhs;\n\t" +-+ ".reg .b32 %amt2;\n\t" +-+ "shl.b32 \t%lhs, $src, $amt;\n\t" +-+ "sub.s32 \t%amt2, 32, $amt;\n\t" +-+ "shr.b32 \t%rhs, $src, %amt2;\n\t" +-+ "add.u32 \t$dst, %lhs, %rhs;\n\t" +-+ "}}", +-+ [(set Int32Regs:$dst, (rotl (i32 Int32Regs:$src), (i32 Int32Regs:$amt)))]>, +-+ Requires<[noHWROT32]>; +-+ +-+// 32-bit software rotate right by register. +-+def ROTR32reg_sw : +-+ NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$src, Int32Regs:$amt), +-+ "{{\n\t" +-+ ".reg .b32 %lhs;\n\t" +-+ ".reg .b32 %rhs;\n\t" +-+ ".reg .b32 %amt2;\n\t" +-+ "shr.b32 \t%lhs, $src, $amt;\n\t" +-+ "sub.s32 \t%amt2, 32, $amt;\n\t" +-+ "shl.b32 \t%rhs, $src, %amt2;\n\t" +-+ "add.u32 \t$dst, %lhs, %rhs;\n\t" +-+ "}}", +-+ [(set Int32Regs:$dst, (rotr (i32 Int32Regs:$src), (i32 Int32Regs:$amt)))]>, +-+ Requires<[noHWROT32]>; +-+ +-+// 64-bit software rotate by immediate. $amt2 should equal 64 - $amt1. +-+def ROT64imm_sw : +-+ NVPTXInst<(outs Int64Regs:$dst), +-+ (ins Int64Regs:$src, i32imm:$amt1, i32imm:$amt2), +-+ "{{\n\t" +-+ ".reg .b64 %lhs;\n\t" +-+ ".reg .b64 %rhs;\n\t" +-+ "shl.b64 \t%lhs, $src, $amt1;\n\t" +-+ "shr.b64 \t%rhs, $src, $amt2;\n\t" +-+ "add.u64 \t$dst, %lhs, %rhs;\n\t" +-+ "}}", +-+ []>; +-+ +-+def SUB_FRM_64 : SDNodeXFormgetTargetConstant(64-N->getZExtValue(), SDLoc(N), MVT::i32); +-+}]>; +-+ +-+def : Pat<(rotl Int64Regs:$src, (i32 imm:$amt)), +-+ (ROT64imm_sw Int64Regs:$src, imm:$amt, (SUB_FRM_64 node:$amt))>; +-+def : Pat<(rotr Int64Regs:$src, (i32 imm:$amt)), +-+ (ROT64imm_sw Int64Regs:$src, (SUB_FRM_64 node:$amt), imm:$amt)>; +-+ +-+// 64-bit software rotate left by register. +-+def ROTL64reg_sw : +-+ NVPTXInst<(outs Int64Regs:$dst), (ins Int64Regs:$src, Int32Regs:$amt), +-+ "{{\n\t" +-+ ".reg .b64 %lhs;\n\t" +-+ ".reg .b64 %rhs;\n\t" +-+ ".reg .u32 %amt2;\n\t" +-+ "and.b32 \t%amt2, $amt, 63;\n\t" +-+ "shl.b64 \t%lhs, $src, %amt2;\n\t" +-+ "sub.u32 \t%amt2, 64, %amt2;\n\t" +-+ "shr.b64 \t%rhs, $src, %amt2;\n\t" +-+ "add.u64 \t$dst, %lhs, %rhs;\n\t" +-+ "}}", +-+ [(set Int64Regs:$dst, (rotl Int64Regs:$src, (i32 Int32Regs:$amt)))]>; +-+ +-+def ROTR64reg_sw : +-+ NVPTXInst<(outs Int64Regs:$dst), (ins Int64Regs:$src, Int32Regs:$amt), +-+ "{{\n\t" +-+ ".reg .b64 %lhs;\n\t" +-+ ".reg .b64 %rhs;\n\t" +-+ ".reg .u32 %amt2;\n\t" +-+ "and.b32 \t%amt2, $amt, 63;\n\t" +-+ "shr.b64 \t%lhs, $src, %amt2;\n\t" +-+ "sub.u32 \t%amt2, 64, %amt2;\n\t" +-+ "shl.b64 \t%rhs, $src, %amt2;\n\t" +-+ "add.u64 \t$dst, %lhs, %rhs;\n\t" +-+ "}}", +-+ [(set Int64Regs:$dst, (rotr Int64Regs:$src, (i32 Int32Regs:$amt)))]>; +-+ +-+// +-+// Funnnel shift in clamp mode +-+// +-+ +-+// Create SDNodes so they can be used in the DAG code, e.g. +-+// NVPTXISelLowering (LowerShiftLeftParts and LowerShiftRightParts) +-+def FUN_SHFL_CLAMP : SDNode<"NVPTXISD::FUN_SHFL_CLAMP", SDTIntShiftDOp, []>; +-+def FUN_SHFR_CLAMP : SDNode<"NVPTXISD::FUN_SHFR_CLAMP", SDTIntShiftDOp, []>; +-+ +-+def FUNSHFLCLAMP : +-+ NVPTXInst<(outs Int32Regs:$dst), +-+ (ins Int32Regs:$lo, Int32Regs:$hi, Int32Regs:$amt), +-+ "shf.l.clamp.b32 \t$dst, $lo, $hi, $amt;", +-+ [(set Int32Regs:$dst, +-+ (FUN_SHFL_CLAMP (i32 Int32Regs:$lo), (i32 Int32Regs:$hi), (i32 Int32Regs:$amt)))]>; +-+ +-+def FUNSHFRCLAMP : +-+ NVPTXInst<(outs Int32Regs:$dst), +-+ (ins Int32Regs:$lo, Int32Regs:$hi, Int32Regs:$amt), +-+ "shf.r.clamp.b32 \t$dst, $lo, $hi, $amt;", +-+ [(set Int32Regs:$dst, +-+ (FUN_SHFR_CLAMP (i32 Int32Regs:$lo), (i32 Int32Regs:$hi), (i32 Int32Regs:$amt)))]>; +- +- // +- // BFE - bit-field extract +-@@ -3492,42 +3657,6 @@ +- def: Pat<(v2i16 (scalar_to_vector (i16 Int16Regs:$a))), +- (CVT_u32_u16 Int16Regs:$a, CvtNONE)>; +- +--// +--// Funnel-Shift +--// +-- +--// Create SDNodes so they can be used in the DAG code, e.g. +--// NVPTXISelLowering (LowerShiftLeftParts and LowerShiftRightParts) +--def fshl_clamp : SDNode<"NVPTXISD::FUN_SHFL_CLAMP", SDTIntShiftDOp, []>; +--def fshr_clamp : SDNode<"NVPTXISD::FUN_SHFR_CLAMP", SDTIntShiftDOp, []>; +-- +--// Funnel shift, requires >= sm_32. Does not trap if amt is out of range, so +--// no side effects. +--let hasSideEffects = false in { +-- multiclass ShfInst { +-- def _i +-- : NVPTXInst<(outs Int32Regs:$dst), +-- (ins Int32Regs:$lo, Int32Regs:$hi, i32imm:$amt), +-- "shf." # mode # ".b32 \t$dst, $lo, $hi, $amt;", +-- [(set Int32Regs:$dst, +-- (op (i32 Int32Regs:$lo), (i32 Int32Regs:$hi), (i32 imm:$amt)))]>, +-- Requires<[hasHWROT32]>; +-- +-- def _r +-- : NVPTXInst<(outs Int32Regs:$dst), +-- (ins Int32Regs:$lo, Int32Regs:$hi, Int32Regs:$amt), +-- "shf." # mode # ".b32 \t$dst, $lo, $hi, $amt;", +-- [(set Int32Regs:$dst, +-- (op (i32 Int32Regs:$lo), (i32 Int32Regs:$hi), (i32 Int32Regs:$amt)))]>, +-- Requires<[hasHWROT32]>; +-- } +-- +-- defm SHF_L_CLAMP : ShfInst<"l.clamp", fshl_clamp>; +-- defm SHF_R_CLAMP : ShfInst<"r.clamp", fshr_clamp>; +-- defm SHF_L_WRAP : ShfInst<"l.wrap", fshl>; +-- defm SHF_R_WRAP : ShfInst<"r.wrap", fshr>; +--} +-- +- // Count leading zeros +- let hasSideEffects = false in { +- def CLZr32 : NVPTXInst<(outs Int32Regs:$d), (ins Int32Regs:$a), +-diff -ruN --strip-trailing-cr a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td +---- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td +-+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td +-@@ -2537,45 +2537,59 @@ +- : VLDG_G_ELE_V4<"v4.f32 \t{{$dst1, $dst2, $dst3, $dst4}}, [$src];", Float32Regs>; +- +- +--multiclass NG_TO_G { +-+multiclass NG_TO_G { +- def "" : NVPTXInst<(outs Int32Regs:$result), (ins Int32Regs:$src), +-- "cvta." # Str # ".u32 \t$result, $src;", []>; +-+ !strconcat("cvta.", Str, ".u32 \t$result, $src;"), +-+ [(set Int32Regs:$result, (Intrin Int32Regs:$src))]>; +- def _64 : NVPTXInst<(outs Int64Regs:$result), (ins Int64Regs:$src), +-- "cvta." # Str # ".u64 \t$result, $src;", []>; +-+ !strconcat("cvta.", Str, ".u64 \t$result, $src;"), +-+ [(set Int64Regs:$result, (Intrin Int64Regs:$src))]>; +-+ def _6432 : NVPTXInst<(outs Int64Regs:$result), (ins Int32Regs:$src), +-+ "{{ .reg .b64 %tmp;\n\t" +-+ #" cvt.u64.u32 \t%tmp, $src;\n\t" +-+ #" cvta." # Str # ".u64 \t$result, %tmp; }}", +-+ [(set Int64Regs:$result, (Intrin Int32Regs:$src))]>, +-+ Requires<[ShortPtr]>; +- } +- +--multiclass G_TO_NG { +-+multiclass G_TO_NG { +- def "" : NVPTXInst<(outs Int32Regs:$result), (ins Int32Regs:$src), +-- "cvta.to." # Str # ".u32 \t$result, $src;", []>; +-+ !strconcat("cvta.to.", Str, ".u32 \t$result, $src;"), +-+ [(set Int32Regs:$result, (Intrin Int32Regs:$src))]>; +- def _64 : NVPTXInst<(outs Int64Regs:$result), (ins Int64Regs:$src), +-- "cvta.to." # Str # ".u64 \t$result, $src;", []>; +-+ !strconcat("cvta.to.", Str, ".u64 \t$result, $src;"), +-+ [(set Int64Regs:$result, (Intrin Int64Regs:$src))]>; +-+ def _3264 : NVPTXInst<(outs Int32Regs:$result), (ins Int64Regs:$src), +-+ "{{ .reg .b64 %tmp;\n\t" +-+ #" cvta.to." # Str # ".u64 \t%tmp, $src;\n\t" +-+ #" cvt.u32.u64 \t$result, %tmp; }}", +-+ [(set Int32Regs:$result, (Intrin Int64Regs:$src))]>, +-+ Requires<[ShortPtr]>; +- } +- +--defm cvta_local : NG_TO_G<"local">; +--defm cvta_shared : NG_TO_G<"shared">; +--defm cvta_global : NG_TO_G<"global">; +--defm cvta_const : NG_TO_G<"const">; +-- +--defm cvta_to_local : G_TO_NG<"local">; +--defm cvta_to_shared : G_TO_NG<"shared">; +--defm cvta_to_global : G_TO_NG<"global">; +--defm cvta_to_const : G_TO_NG<"const">; +-- +--// nvvm.ptr.param.to.gen +--defm cvta_param : NG_TO_G<"param">; +-- +--def : Pat<(int_nvvm_ptr_param_to_gen Int32Regs:$src), +-- (cvta_param Int32Regs:$src)>; +-- +--def : Pat<(int_nvvm_ptr_param_to_gen Int64Regs:$src), +-- (cvta_param_64 Int64Regs:$src)>; +-+defm cvta_local : NG_TO_G<"local", int_nvvm_ptr_local_to_gen, useShortPtrLocal>; +-+defm cvta_shared : NG_TO_G<"shared", int_nvvm_ptr_shared_to_gen, useShortPtrShared>; +-+defm cvta_global : NG_TO_G<"global", int_nvvm_ptr_global_to_gen, False>; +-+defm cvta_const : NG_TO_G<"const", int_nvvm_ptr_constant_to_gen, useShortPtrConst>; +-+defm cvta_param : NG_TO_G<"param", int_nvvm_ptr_param_to_gen, False>; +-+ +-+defm cvta_to_local : G_TO_NG<"local", int_nvvm_ptr_gen_to_local, useShortPtrLocal>; +-+defm cvta_to_shared : G_TO_NG<"shared", int_nvvm_ptr_gen_to_shared, useShortPtrShared>; +-+defm cvta_to_global : G_TO_NG<"global", int_nvvm_ptr_gen_to_global, False>; +-+defm cvta_to_const : G_TO_NG<"const", int_nvvm_ptr_gen_to_constant, useShortPtrConst>; +- +- // nvvm.ptr.gen.to.param +--def : Pat<(int_nvvm_ptr_gen_to_param Int32Regs:$src), +-- (IMOV32rr Int32Regs:$src)>; +-+def nvvm_ptr_gen_to_param : NVPTXInst<(outs Int32Regs:$result), +-+ (ins Int32Regs:$src), +-+ "mov.u32 \t$result, $src;", +-+ [(set Int32Regs:$result, +-+ (int_nvvm_ptr_gen_to_param Int32Regs:$src))]>; +-+def nvvm_ptr_gen_to_param_64 : NVPTXInst<(outs Int64Regs:$result), +-+ (ins Int64Regs:$src), +-+ "mov.u64 \t$result, $src;", +-+ [(set Int64Regs:$result, +-+ (int_nvvm_ptr_gen_to_param Int64Regs:$src))]>; +- +--def : Pat<(int_nvvm_ptr_gen_to_param Int64Regs:$src), +-- (IMOV64rr Int64Regs:$src)>; +- +- // nvvm.move intrinsicc +- def nvvm_move_i16 : NVPTXInst<(outs Int16Regs:$r), (ins Int16Regs:$s), +-@@ -2618,6 +2632,24 @@ +- [(set Int64Regs:$r, +- (int_nvvm_move_ptr texternalsym:$s))]>;*/ +- +-+ +-+// MoveParam %r1, param +-+// ptr_local_to_gen %r2, %r1 +-+// ptr_gen_to_local %r3, %r2 +-+// -> +-+// mov %r1, param +-+ +-+// @TODO: Revisit this. There is a type +-+// contradiction between iPTRAny and iPTR for the addr defs, so the move_sym +-+// instructions are not currently defined. However, we can use the ptr +-+// variants and the asm printer will do the right thing. +-+def : Pat<(i64 (int_nvvm_ptr_gen_to_local (int_nvvm_ptr_local_to_gen +-+ (MoveParam texternalsym:$src)))), +-+ (nvvm_move_ptr64 texternalsym:$src)>; +-+def : Pat<(i32 (int_nvvm_ptr_gen_to_local (int_nvvm_ptr_local_to_gen +-+ (MoveParam texternalsym:$src)))), +-+ (nvvm_move_ptr32 texternalsym:$src)>; +-+ +- def texsurf_handles +- : NVPTXInst<(outs Int64Regs:$result), (ins imem:$src), +- "mov.u64 \t$result, $src;", []>; +-@@ -2701,9 +2733,134 @@ +- def : Pat<(int_nvvm_read_ptx_sreg_envreg31), (MOV_SPECIAL ENVREG31)>; +- +- +-+// rotate builtin support +-+ +-+def ROTATE_B32_HW_IMM +-+ : NVPTXInst<(outs Int32Regs:$dst), +-+ (ins Int32Regs:$src, i32imm:$amt), +-+ "shf.l.wrap.b32 \t$dst, $src, $src, $amt;", +-+ [(set Int32Regs:$dst, +-+ (int_nvvm_rotate_b32 Int32Regs:$src, (i32 imm:$amt)))]>, +-+ Requires<[hasHWROT32]> ; +-+ +-+def ROTATE_B32_HW_REG +-+ : NVPTXInst<(outs Int32Regs:$dst), +-+ (ins Int32Regs:$src, Int32Regs:$amt), +-+ "shf.l.wrap.b32 \t$dst, $src, $src, $amt;", +-+ [(set Int32Regs:$dst, +-+ (int_nvvm_rotate_b32 Int32Regs:$src, Int32Regs:$amt))]>, +-+ Requires<[hasHWROT32]> ; +-+ +-+def : Pat<(int_nvvm_rotate_b32 Int32Regs:$src, (i32 imm:$amt)), +-+ (ROT32imm_sw Int32Regs:$src, imm:$amt, (SUB_FRM_32 node:$amt))>, +-+ Requires<[noHWROT32]> ; +-+ +-+def : Pat<(int_nvvm_rotate_b32 Int32Regs:$src, Int32Regs:$amt), +-+ (ROTL32reg_sw Int32Regs:$src, Int32Regs:$amt)>, +-+ Requires<[noHWROT32]> ; +-+ +-+let hasSideEffects = false in { +-+ def GET_LO_INT64 : NVPTXInst<(outs Int32Regs:$dst), (ins Int64Regs:$src), +-+ !strconcat("{{\n\t", +-+ ".reg .b32 %dummy;\n\t", +-+ "mov.b64 \t{$dst,%dummy}, $src;\n\t", +-+ "}}"), +-+ []> ; +-+ +-+ def GET_HI_INT64 : NVPTXInst<(outs Int32Regs:$dst), (ins Int64Regs:$src), +-+ !strconcat("{{\n\t", +-+ ".reg .b32 %dummy;\n\t", +-+ "mov.b64 \t{%dummy,$dst}, $src;\n\t", +-+ "}}"), +-+ []> ; +-+} +-+ +-+let hasSideEffects = false in { +-+ def PACK_TWO_INT32 +-+ : NVPTXInst<(outs Int64Regs:$dst), (ins Int32Regs:$lo, Int32Regs:$hi), +-+ "mov.b64 \t$dst, {{$lo, $hi}};", []> ; +-+} +-+ +- def : Pat<(int_nvvm_swap_lo_hi_b64 Int64Regs:$src), +-- (V2I32toI64 (I64toI32H Int64Regs:$src), +-- (I64toI32L Int64Regs:$src))> ; +-+ (PACK_TWO_INT32 (GET_HI_INT64 Int64Regs:$src), +-+ (GET_LO_INT64 Int64Regs:$src))> ; +-+ +-+// Funnel shift, requires >= sm_32. Does not trap if amt is out of range, so +-+// no side effects. +-+let hasSideEffects = false in { +-+ def SHF_L_WRAP_B32_IMM +-+ : NVPTXInst<(outs Int32Regs:$dst), +-+ (ins Int32Regs:$lo, Int32Regs:$hi, i32imm:$amt), +-+ "shf.l.wrap.b32 \t$dst, $lo, $hi, $amt;",[]>, +-+ Requires<[hasHWROT32]>; +-+ +-+ def SHF_L_WRAP_B32_REG +-+ : NVPTXInst<(outs Int32Regs:$dst), +-+ (ins Int32Regs:$lo, Int32Regs:$hi, Int32Regs:$amt), +-+ "shf.l.wrap.b32 \t$dst, $lo, $hi, $amt;",[]>, +-+ Requires<[hasHWROT32]>; +-+ +-+ def SHF_R_WRAP_B32_IMM +-+ : NVPTXInst<(outs Int32Regs:$dst), +-+ (ins Int32Regs:$lo, Int32Regs:$hi, i32imm:$amt), +-+ "shf.r.wrap.b32 \t$dst, $lo, $hi, $amt;",[]>, +-+ Requires<[hasHWROT32]>; +-+ +-+ def SHF_R_WRAP_B32_REG +-+ : NVPTXInst<(outs Int32Regs:$dst), +-+ (ins Int32Regs:$lo, Int32Regs:$hi, Int32Regs:$amt), +-+ "shf.r.wrap.b32 \t$dst, $lo, $hi, $amt;",[]>, +-+ Requires<[hasHWROT32]>; +-+} +-+ +-+// HW version of rotate 64 +-+def : Pat<(int_nvvm_rotate_b64 Int64Regs:$src, (i32 imm:$amt)), +-+ (PACK_TWO_INT32 +-+ (SHF_L_WRAP_B32_IMM (GET_HI_INT64 Int64Regs:$src), +-+ (GET_LO_INT64 Int64Regs:$src), imm:$amt), +-+ (SHF_L_WRAP_B32_IMM (GET_LO_INT64 Int64Regs:$src), +-+ (GET_HI_INT64 Int64Regs:$src), imm:$amt))>, +-+ Requires<[hasHWROT32]>; +-+ +-+def : Pat<(int_nvvm_rotate_b64 Int64Regs:$src, Int32Regs:$amt), +-+ (PACK_TWO_INT32 +-+ (SHF_L_WRAP_B32_REG (GET_HI_INT64 Int64Regs:$src), +-+ (GET_LO_INT64 Int64Regs:$src), Int32Regs:$amt), +-+ (SHF_L_WRAP_B32_REG (GET_LO_INT64 Int64Regs:$src), +-+ (GET_HI_INT64 Int64Regs:$src), Int32Regs:$amt))>, +-+ Requires<[hasHWROT32]>; +-+ +-+ +-+def : Pat<(int_nvvm_rotate_right_b64 Int64Regs:$src, (i32 imm:$amt)), +-+ (PACK_TWO_INT32 +-+ (SHF_R_WRAP_B32_IMM (GET_LO_INT64 Int64Regs:$src), +-+ (GET_HI_INT64 Int64Regs:$src), imm:$amt), +-+ (SHF_R_WRAP_B32_IMM (GET_HI_INT64 Int64Regs:$src), +-+ (GET_LO_INT64 Int64Regs:$src), imm:$amt))>, +-+ Requires<[hasHWROT32]>; +-+ +-+def : Pat<(int_nvvm_rotate_right_b64 Int64Regs:$src, Int32Regs:$amt), +-+ (PACK_TWO_INT32 +-+ (SHF_R_WRAP_B32_REG (GET_LO_INT64 Int64Regs:$src), +-+ (GET_HI_INT64 Int64Regs:$src), Int32Regs:$amt), +-+ (SHF_R_WRAP_B32_REG (GET_HI_INT64 Int64Regs:$src), +-+ (GET_LO_INT64 Int64Regs:$src), Int32Regs:$amt))>, +-+ Requires<[hasHWROT32]>; +-+ +-+// SW version of rotate 64 +-+def : Pat<(int_nvvm_rotate_b64 Int64Regs:$src, (i32 imm:$amt)), +-+ (ROT64imm_sw Int64Regs:$src, imm:$amt, (SUB_FRM_64 node:$amt))>, +-+ Requires<[noHWROT32]>; +-+def : Pat<(int_nvvm_rotate_b64 Int64Regs:$src, Int32Regs:$amt), +-+ (ROTL64reg_sw Int64Regs:$src, Int32Regs:$amt)>, +-+ Requires<[noHWROT32]>; +-+def : Pat<(int_nvvm_rotate_right_b64 Int64Regs:$src, (i32 imm:$amt)), +-+ (ROT64imm_sw Int64Regs:$src, (SUB_FRM_64 node:$amt), imm:$amt)>, +-+ Requires<[noHWROT32]>; +-+def : Pat<(int_nvvm_rotate_right_b64 Int64Regs:$src, Int32Regs:$amt), +-+ (ROTR64reg_sw Int64Regs:$src, Int32Regs:$amt)>, +-+ Requires<[noHWROT32]>; +-+ +- +- //----------------------------------- +- // Texture Intrinsics +-diff -ruN --strip-trailing-cr a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp +---- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp +-+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp +-@@ -1109,21 +1109,11 @@ +- AddrSpaceCastSDNode *CastN = cast(N); +- unsigned SrcAddrSpace = CastN->getSrcAddressSpace(); +- unsigned DstAddrSpace = CastN->getDestAddressSpace(); +-- SDLoc DL(N); +- assert(SrcAddrSpace != DstAddrSpace && +- "addrspacecast must be between different address spaces"); +- +- if (DstAddrSpace == ADDRESS_SPACE_GENERIC) { +- // Specific to generic +-- +-- if (TM.is64Bit() && TM.getPointerSizeInBits(SrcAddrSpace) == 32) { +-- SDValue CvtNone = +-- CurDAG->getTargetConstant(NVPTX::PTXCvtMode::NONE, DL, MVT::i32); +-- SDNode *Cvt = CurDAG->getMachineNode(NVPTX::CVT_u64_u32, DL, MVT::i64, +-- Src, CvtNone); +-- Src = SDValue(Cvt, 0); +-- } +-- +- unsigned Opc; +- switch (SrcAddrSpace) { +- default: report_fatal_error("Bad address space in addrspacecast"); +-@@ -1131,16 +1121,26 @@ +- Opc = TM.is64Bit() ? NVPTX::cvta_global_64 : NVPTX::cvta_global; +- break; +- case ADDRESS_SPACE_SHARED: +-- Opc = TM.is64Bit() ? NVPTX::cvta_shared_64 : NVPTX::cvta_shared; +-+ Opc = TM.is64Bit() ? (TM.getPointerSizeInBits(SrcAddrSpace) == 32 +-+ ? NVPTX::cvta_shared_6432 +-+ : NVPTX::cvta_shared_64) +-+ : NVPTX::cvta_shared; +- break; +- case ADDRESS_SPACE_CONST: +-- Opc = TM.is64Bit() ? NVPTX::cvta_const_64 : NVPTX::cvta_const; +-+ Opc = TM.is64Bit() ? (TM.getPointerSizeInBits(SrcAddrSpace) == 32 +-+ ? NVPTX::cvta_const_6432 +-+ : NVPTX::cvta_const_64) +-+ : NVPTX::cvta_const; +- break; +- case ADDRESS_SPACE_LOCAL: +-- Opc = TM.is64Bit() ? NVPTX::cvta_local_64 : NVPTX::cvta_local; +-+ Opc = TM.is64Bit() ? (TM.getPointerSizeInBits(SrcAddrSpace) == 32 +-+ ? NVPTX::cvta_local_6432 +-+ : NVPTX::cvta_local_64) +-+ : NVPTX::cvta_local; +- break; +- } +-- ReplaceNode(N, CurDAG->getMachineNode(Opc, DL, N->getValueType(0), Src)); +-+ ReplaceNode(N, CurDAG->getMachineNode(Opc, SDLoc(N), N->getValueType(0), +-+ Src)); +- return; +- } else { +- // Generic to specific +-@@ -1153,28 +1153,30 @@ +- Opc = TM.is64Bit() ? NVPTX::cvta_to_global_64 : NVPTX::cvta_to_global; +- break; +- case ADDRESS_SPACE_SHARED: +-- Opc = TM.is64Bit() ? NVPTX::cvta_to_shared_64 : NVPTX::cvta_to_shared; +-+ Opc = TM.is64Bit() ? (TM.getPointerSizeInBits(DstAddrSpace) == 32 +-+ ? NVPTX::cvta_to_shared_3264 +-+ : NVPTX::cvta_to_shared_64) +-+ : NVPTX::cvta_to_shared; +- break; +- case ADDRESS_SPACE_CONST: +-- Opc = TM.is64Bit() ? NVPTX::cvta_to_const_64 : NVPTX::cvta_to_const; +-+ Opc = TM.is64Bit() ? (TM.getPointerSizeInBits(DstAddrSpace) == 32 +-+ ? NVPTX::cvta_to_const_3264 +-+ : NVPTX::cvta_to_const_64) +-+ : NVPTX::cvta_to_const; +- break; +- case ADDRESS_SPACE_LOCAL: +-- Opc = TM.is64Bit() ? NVPTX::cvta_to_local_64 : NVPTX::cvta_to_local; +-+ Opc = TM.is64Bit() ? (TM.getPointerSizeInBits(DstAddrSpace) == 32 +-+ ? NVPTX::cvta_to_local_3264 +-+ : NVPTX::cvta_to_local_64) +-+ : NVPTX::cvta_to_local; +- break; +- case ADDRESS_SPACE_PARAM: +-- Opc = TM.is64Bit() ? NVPTX::IMOV64rr : NVPTX::IMOV32rr; +-+ Opc = TM.is64Bit() ? NVPTX::nvvm_ptr_gen_to_param_64 +-+ : NVPTX::nvvm_ptr_gen_to_param; +- break; +- } +-- +-- SDNode *CVTA = CurDAG->getMachineNode(Opc, DL, N->getValueType(0), Src); +-- if (TM.is64Bit() && TM.getPointerSizeInBits(DstAddrSpace) == 32) { +-- SDValue CvtNone = +-- CurDAG->getTargetConstant(NVPTX::PTXCvtMode::NONE, DL, MVT::i32); +-- CVTA = CurDAG->getMachineNode(NVPTX::CVT_u32_u64, DL, MVT::i32, +-- SDValue(CVTA, 0), CvtNone); +-- } +-- +-- ReplaceNode(N, CVTA); +-+ ReplaceNode(N, CurDAG->getMachineNode(Opc, SDLoc(N), N->getValueType(0), +-+ Src)); +- return; +- } +- } +-diff -ruN --strip-trailing-cr a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp +---- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp +-+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp +-@@ -594,13 +594,20 @@ +- setOperationAction(ISD::BITREVERSE, MVT::i32, Legal); +- setOperationAction(ISD::BITREVERSE, MVT::i64, Legal); +- +-- setOperationAction({ISD::ROTL, ISD::ROTR}, +-- {MVT::i8, MVT::i16, MVT::v2i16, MVT::i32, MVT::i64}, +-- Expand); +-- +-- if (STI.hasHWROT32()) +-- setOperationAction({ISD::FSHL, ISD::FSHR}, MVT::i32, Legal); +-+ // TODO: we may consider expanding ROTL/ROTR on older GPUs. Currently on GPUs +-+ // that don't have h/w rotation we lower them to multi-instruction assembly. +-+ // See ROT*_sw in NVPTXIntrInfo.td +-+ setOperationAction(ISD::ROTL, MVT::i64, Legal); +-+ setOperationAction(ISD::ROTR, MVT::i64, Legal); +-+ setOperationAction(ISD::ROTL, MVT::i32, Legal); +-+ setOperationAction(ISD::ROTR, MVT::i32, Legal); +- +-+ setOperationAction(ISD::ROTL, MVT::i16, Expand); +-+ setOperationAction(ISD::ROTL, MVT::v2i16, Expand); +-+ setOperationAction(ISD::ROTR, MVT::i16, Expand); +-+ setOperationAction(ISD::ROTR, MVT::v2i16, Expand); +-+ setOperationAction(ISD::ROTL, MVT::i8, Expand); +-+ setOperationAction(ISD::ROTR, MVT::i8, Expand); +- setOperationAction(ISD::BSWAP, MVT::i16, Expand); +- +- setOperationAction(ISD::BR_JT, MVT::Other, Custom); +-diff -ruN --strip-trailing-cr a/llvm/test/Assembler/auto_upgrade_nvvm_intrinsics.ll b/llvm/test/Assembler/auto_upgrade_nvvm_intrinsics.ll +---- a/llvm/test/Assembler/auto_upgrade_nvvm_intrinsics.ll +-+++ b/llvm/test/Assembler/auto_upgrade_nvvm_intrinsics.ll +-@@ -31,19 +31,6 @@ +- declare i64 @llvm.nvvm.bitcast.d2ll(double) +- declare double @llvm.nvvm.bitcast.ll2d(i64) +- +--declare i32 @llvm.nvvm.rotate.b32(i32, i32) +--declare i64 @llvm.nvvm.rotate.right.b64(i64, i32) +--declare i64 @llvm.nvvm.rotate.b64(i64, i32) +-- +--declare ptr addrspace(1) @llvm.nvvm.ptr.gen.to.global.p1.p0(ptr) +--declare ptr addrspace(3) @llvm.nvvm.ptr.gen.to.shared.p3.p0(ptr) +--declare ptr addrspace(4) @llvm.nvvm.ptr.gen.to.constant.p4.p0(ptr) +--declare ptr addrspace(5) @llvm.nvvm.ptr.gen.to.local.p5.p0(ptr) +--declare ptr @llvm.nvvm.ptr.global.to.gen.p0.p1(ptr addrspace(1)) +--declare ptr @llvm.nvvm.ptr.shared.to.gen.p0.p3(ptr addrspace(3)) +--declare ptr @llvm.nvvm.ptr.constant.to.gen.p0.p4(ptr addrspace(4)) +--declare ptr @llvm.nvvm.ptr.local.to.gen.p0.p5(ptr addrspace(5)) +-- +- ; CHECK-LABEL: @simple_upgrade +- define void @simple_upgrade(i32 %a, i64 %b, i16 %c) { +- ; CHECK: call i32 @llvm.bitreverse.i32(i32 %a) +-@@ -152,42 +139,4 @@ +- %r4 = call double @llvm.nvvm.bitcast.ll2d(i64 %b) +- +- ret void +--} +-- +--; CHECK-LABEL: @rotate +--define void @rotate(i32 %a, i64 %b) { +--; CHECK: call i32 @llvm.fshl.i32(i32 %a, i32 %a, i32 6) +--; CHECK: call i64 @llvm.fshr.i64(i64 %b, i64 %b, i64 7) +--; CHECK: call i64 @llvm.fshl.i64(i64 %b, i64 %b, i64 8) +--; +-- %r1 = call i32 @llvm.nvvm.rotate.b32(i32 %a, i32 6) +-- %r2 = call i64 @llvm.nvvm.rotate.right.b64(i64 %b, i32 7) +-- %r3 = call i64 @llvm.nvvm.rotate.b64(i64 %b, i32 8) +-- ret void +--} +-- +--; CHECK-LABEL: @addrspacecast +--define void @addrspacecast(ptr %p0) { +--; CHECK: %1 = addrspacecast ptr %p0 to ptr addrspace(1) +--; CHECK: %2 = addrspacecast ptr addrspace(1) %1 to ptr +--; CHECK: %3 = addrspacecast ptr %2 to ptr addrspace(3) +--; CHECK: %4 = addrspacecast ptr addrspace(3) %3 to ptr +--; CHECK: %5 = addrspacecast ptr %4 to ptr addrspace(4) +--; CHECK: %6 = addrspacecast ptr addrspace(4) %5 to ptr +--; CHECK: %7 = addrspacecast ptr %6 to ptr addrspace(5) +--; CHECK: %8 = addrspacecast ptr addrspace(5) %7 to ptr +--; +-- %p1 = call ptr addrspace(1) @llvm.nvvm.ptr.gen.to.global.p1.p0(ptr %p0) +-- %p2 = call ptr @llvm.nvvm.ptr.global.to.gen.p0.p1(ptr addrspace(1) %p1) +-- +-- %p3 = call ptr addrspace(3) @llvm.nvvm.ptr.gen.to.shared.p3.p0(ptr %p2) +-- %p4 = call ptr @llvm.nvvm.ptr.shared.to.gen.p0.p3(ptr addrspace(3) %p3) +-- +-- %p5 = call ptr addrspace(4) @llvm.nvvm.ptr.gen.to.constant.p4.p0(ptr %p4) +-- %p6 = call ptr @llvm.nvvm.ptr.constant.to.gen.p0.p4(ptr addrspace(4) %p5) +-- +-- %p7 = call ptr addrspace(5) @llvm.nvvm.ptr.gen.to.local.p5.p0(ptr %p6) +-- %p8 = call ptr @llvm.nvvm.ptr.local.to.gen.p0.p5(ptr addrspace(5) %p7) +-- +-- ret void +--} +-+} +-\ No newline at end of file +-diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/AMDGPU/freeze.ll b/llvm/test/CodeGen/AMDGPU/freeze.ll +---- a/llvm/test/CodeGen/AMDGPU/freeze.ll +-+++ b/llvm/test/CodeGen/AMDGPU/freeze.ll +-@@ -1,1856 +0,0 @@ +--; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py +--; RUN: llc -global-isel=0 -mtriple=amdgcn-mesa-mesa3d -mcpu=gfx1010 < %s | FileCheck -check-prefixes=GFX10,GFX10-SDAG %s +--; RUN: llc -global-isel=1 -mtriple=amdgcn-mesa-mesa3d -mcpu=gfx1010 < %s | FileCheck -check-prefixes=GFX10,GFX10-GISEL %s +--; RUN: llc -global-isel=0 -mtriple=amdgcn-mesa-mesa3d -mcpu=gfx1100 -amdgpu-enable-delay-alu=0 < %s | FileCheck -check-prefixes=GFX11,GFX11-SDAG %s +--; RUN: llc -global-isel=1 -mtriple=amdgcn-mesa-mesa3d -mcpu=gfx1100 -amdgpu-enable-delay-alu=0 < %s | FileCheck -check-prefixes=GFX11,GFX11-GISEL %s +-- +--define void @freeze_v2i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { +--; GFX10-LABEL: freeze_v2i32: +--; GFX10: ; %bb.0: +--; GFX10-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX10-NEXT: global_load_dwordx2 v[0:1], v[0:1], off +--; GFX10-NEXT: s_waitcnt vmcnt(0) +--; GFX10-NEXT: global_store_dwordx2 v[2:3], v[0:1], off +--; GFX10-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX11-LABEL: freeze_v2i32: +--; GFX11: ; %bb.0: +--; GFX11-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX11-NEXT: global_load_b64 v[0:1], v[0:1], off +--; GFX11-NEXT: s_waitcnt vmcnt(0) +--; GFX11-NEXT: global_store_b64 v[2:3], v[0:1], off +--; GFX11-NEXT: s_setpc_b64 s[30:31] +-- %a = load <2 x i32>, ptr addrspace(1) %ptra, align 4 +-- %freeze = freeze <2 x i32> %a +-- store <2 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 +-- ret void +--} +-- +--define void @freeze_v3i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { +--; GFX10-LABEL: freeze_v3i32: +--; GFX10: ; %bb.0: +--; GFX10-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX10-NEXT: global_load_dwordx3 v[4:6], v[0:1], off +--; GFX10-NEXT: s_waitcnt vmcnt(0) +--; GFX10-NEXT: global_store_dwordx3 v[2:3], v[4:6], off +--; GFX10-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX11-LABEL: freeze_v3i32: +--; GFX11: ; %bb.0: +--; GFX11-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX11-NEXT: global_load_b96 v[4:6], v[0:1], off +--; GFX11-NEXT: s_waitcnt vmcnt(0) +--; GFX11-NEXT: global_store_b96 v[2:3], v[4:6], off +--; GFX11-NEXT: s_setpc_b64 s[30:31] +-- %a = load <3 x i32>, ptr addrspace(1) %ptra, align 4 +-- %freeze = freeze <3 x i32> %a +-- store <3 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 +-- ret void +--} +-- +--define void @freeze_v4i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { +--; GFX10-LABEL: freeze_v4i32: +--; GFX10: ; %bb.0: +--; GFX10-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX10-NEXT: global_load_dwordx4 v[4:7], v[0:1], off +--; GFX10-NEXT: s_waitcnt vmcnt(0) +--; GFX10-NEXT: global_store_dwordx4 v[2:3], v[4:7], off +--; GFX10-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX11-LABEL: freeze_v4i32: +--; GFX11: ; %bb.0: +--; GFX11-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX11-NEXT: global_load_b128 v[4:7], v[0:1], off +--; GFX11-NEXT: s_waitcnt vmcnt(0) +--; GFX11-NEXT: global_store_b128 v[2:3], v[4:7], off +--; GFX11-NEXT: s_setpc_b64 s[30:31] +-- %a = load <4 x i32>, ptr addrspace(1) %ptra, align 4 +-- %freeze = freeze <4 x i32> %a +-- store <4 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 +-- ret void +--} +-- +--define void @freeze_v5i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { +--; GFX10-SDAG-LABEL: freeze_v5i32: +--; GFX10-SDAG: ; %bb.0: +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX10-SDAG-NEXT: s_clause 0x1 +--; GFX10-SDAG-NEXT: global_load_dword v8, v[0:1], off offset:16 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) +--; GFX10-SDAG-NEXT: global_store_dword v[2:3], v8, off offset:16 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off +--; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX10-GISEL-LABEL: freeze_v5i32: +--; GFX10-GISEL: ; %bb.0: +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX10-GISEL-NEXT: s_clause 0x1 +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off +--; GFX10-GISEL-NEXT: global_load_dword v8, v[0:1], off offset:16 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) +--; GFX10-GISEL-NEXT: global_store_dword v[2:3], v8, off offset:16 +--; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX11-SDAG-LABEL: freeze_v5i32: +--; GFX11-SDAG: ; %bb.0: +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX11-SDAG-NEXT: s_clause 0x1 +--; GFX11-SDAG-NEXT: global_load_b32 v8, v[0:1], off offset:16 +--; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) +--; GFX11-SDAG-NEXT: global_store_b32 v[2:3], v8, off offset:16 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off +--; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX11-GISEL-LABEL: freeze_v5i32: +--; GFX11-GISEL: ; %bb.0: +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX11-GISEL-NEXT: s_clause 0x1 +--; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off +--; GFX11-GISEL-NEXT: global_load_b32 v0, v[0:1], off offset:16 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) +--; GFX11-GISEL-NEXT: global_store_b32 v[2:3], v0, off offset:16 +--; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] +-- %a = load <5 x i32>, ptr addrspace(1) %ptra, align 4 +-- %freeze = freeze <5 x i32> %a +-- store <5 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 +-- ret void +--} +-- +--define void @freeze_v6i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { +--; GFX10-SDAG-LABEL: freeze_v6i32: +--; GFX10-SDAG: ; %bb.0: +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX10-SDAG-NEXT: s_clause 0x1 +--; GFX10-SDAG-NEXT: global_load_dwordx2 v[8:9], v[0:1], off offset:16 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) +--; GFX10-SDAG-NEXT: global_store_dwordx2 v[2:3], v[8:9], off offset:16 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off +--; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX10-GISEL-LABEL: freeze_v6i32: +--; GFX10-GISEL: ; %bb.0: +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX10-GISEL-NEXT: s_clause 0x1 +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off +--; GFX10-GISEL-NEXT: global_load_dwordx2 v[8:9], v[0:1], off offset:16 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) +--; GFX10-GISEL-NEXT: global_store_dwordx2 v[2:3], v[8:9], off offset:16 +--; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX11-SDAG-LABEL: freeze_v6i32: +--; GFX11-SDAG: ; %bb.0: +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX11-SDAG-NEXT: s_clause 0x1 +--; GFX11-SDAG-NEXT: global_load_b64 v[8:9], v[0:1], off offset:16 +--; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) +--; GFX11-SDAG-NEXT: global_store_b64 v[2:3], v[8:9], off offset:16 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off +--; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX11-GISEL-LABEL: freeze_v6i32: +--; GFX11-GISEL: ; %bb.0: +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX11-GISEL-NEXT: s_clause 0x1 +--; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off +--; GFX11-GISEL-NEXT: global_load_b64 v[0:1], v[0:1], off offset:16 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) +--; GFX11-GISEL-NEXT: global_store_b64 v[2:3], v[0:1], off offset:16 +--; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] +-- %a = load <6 x i32>, ptr addrspace(1) %ptra, align 4 +-- %freeze = freeze <6 x i32> %a +-- store <6 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 +-- ret void +--} +-- +--define void @freeze_v7i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { +--; GFX10-SDAG-LABEL: freeze_v7i32: +--; GFX10-SDAG: ; %bb.0: +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX10-SDAG-NEXT: s_clause 0x1 +--; GFX10-SDAG-NEXT: global_load_dwordx3 v[8:10], v[0:1], off offset:16 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) +--; GFX10-SDAG-NEXT: global_store_dwordx3 v[2:3], v[8:10], off offset:16 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off +--; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX10-GISEL-LABEL: freeze_v7i32: +--; GFX10-GISEL: ; %bb.0: +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX10-GISEL-NEXT: s_clause 0x1 +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off +--; GFX10-GISEL-NEXT: global_load_dwordx3 v[8:10], v[0:1], off offset:16 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) +--; GFX10-GISEL-NEXT: global_store_dwordx3 v[2:3], v[8:10], off offset:16 +--; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX11-SDAG-LABEL: freeze_v7i32: +--; GFX11-SDAG: ; %bb.0: +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX11-SDAG-NEXT: s_clause 0x1 +--; GFX11-SDAG-NEXT: global_load_b96 v[8:10], v[0:1], off offset:16 +--; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) +--; GFX11-SDAG-NEXT: global_store_b96 v[2:3], v[8:10], off offset:16 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off +--; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX11-GISEL-LABEL: freeze_v7i32: +--; GFX11-GISEL: ; %bb.0: +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX11-GISEL-NEXT: s_clause 0x1 +--; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off +--; GFX11-GISEL-NEXT: global_load_b96 v[8:10], v[0:1], off offset:16 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) +--; GFX11-GISEL-NEXT: global_store_b96 v[2:3], v[8:10], off offset:16 +--; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] +-- %a = load <7 x i32>, ptr addrspace(1) %ptra, align 4 +-- %freeze = freeze <7 x i32> %a +-- store <7 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 +-- ret void +--} +-- +--define void @freeze_v8i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { +--; GFX10-SDAG-LABEL: freeze_v8i32: +--; GFX10-SDAG: ; %bb.0: +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX10-SDAG-NEXT: s_clause 0x1 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:16 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:16 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off +--; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX10-GISEL-LABEL: freeze_v8i32: +--; GFX10-GISEL: ; %bb.0: +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX10-GISEL-NEXT: s_clause 0x1 +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 +--; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX11-SDAG-LABEL: freeze_v8i32: +--; GFX11-SDAG: ; %bb.0: +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX11-SDAG-NEXT: s_clause 0x1 +--; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:16 +--; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:16 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off +--; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX11-GISEL-LABEL: freeze_v8i32: +--; GFX11-GISEL: ; %bb.0: +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX11-GISEL-NEXT: s_clause 0x1 +--; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off +--; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 +--; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] +-- %a = load <8 x i32>, ptr addrspace(1) %ptra, align 4 +-- %freeze = freeze <8 x i32> %a +-- store <8 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 +-- ret void +--} +-- +--define void @freeze_v9i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { +--; GFX10-SDAG-LABEL: freeze_v9i32: +--; GFX10-SDAG: ; %bb.0: +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX10-SDAG-NEXT: s_clause 0x2 +--; GFX10-SDAG-NEXT: global_load_dword v12, v[0:1], off offset:32 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) +--; GFX10-SDAG-NEXT: global_store_dword v[2:3], v12, off offset:32 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 +--; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX10-GISEL-LABEL: freeze_v9i32: +--; GFX10-GISEL: ; %bb.0: +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX10-GISEL-NEXT: s_clause 0x2 +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 +--; GFX10-GISEL-NEXT: global_load_dword v12, v[0:1], off offset:32 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) +--; GFX10-GISEL-NEXT: global_store_dword v[2:3], v12, off offset:32 +--; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX11-SDAG-LABEL: freeze_v9i32: +--; GFX11-SDAG: ; %bb.0: +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX11-SDAG-NEXT: s_clause 0x2 +--; GFX11-SDAG-NEXT: global_load_b32 v12, v[0:1], off offset:32 +--; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off +--; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) +--; GFX11-SDAG-NEXT: global_store_b32 v[2:3], v12, off offset:32 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 +--; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX11-GISEL-LABEL: freeze_v9i32: +--; GFX11-GISEL: ; %bb.0: +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX11-GISEL-NEXT: s_clause 0x2 +--; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off +--; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 +--; GFX11-GISEL-NEXT: global_load_b32 v0, v[0:1], off offset:32 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) +--; GFX11-GISEL-NEXT: global_store_b32 v[2:3], v0, off offset:32 +--; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] +-- %a = load <9 x i32>, ptr addrspace(1) %ptra, align 4 +-- %freeze = freeze <9 x i32> %a +-- store <9 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 +-- ret void +--} +-- +--define void @freeze_v10i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { +--; GFX10-LABEL: freeze_v10i32: +--; GFX10: ; %bb.0: +--; GFX10-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX10-NEXT: s_clause 0x2 +--; GFX10-NEXT: global_load_dwordx4 v[4:7], v[0:1], off +--; GFX10-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 +--; GFX10-NEXT: global_load_dwordx2 v[12:13], v[0:1], off offset:32 +--; GFX10-NEXT: s_waitcnt vmcnt(2) +--; GFX10-NEXT: global_store_dwordx4 v[2:3], v[4:7], off +--; GFX10-NEXT: s_waitcnt vmcnt(1) +--; GFX10-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 +--; GFX10-NEXT: s_waitcnt vmcnt(0) +--; GFX10-NEXT: global_store_dwordx2 v[2:3], v[12:13], off offset:32 +--; GFX10-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX11-LABEL: freeze_v10i32: +--; GFX11: ; %bb.0: +--; GFX11-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX11-NEXT: s_clause 0x2 +--; GFX11-NEXT: global_load_b128 v[4:7], v[0:1], off +--; GFX11-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 +--; GFX11-NEXT: global_load_b64 v[0:1], v[0:1], off offset:32 +--; GFX11-NEXT: s_waitcnt vmcnt(2) +--; GFX11-NEXT: global_store_b128 v[2:3], v[4:7], off +--; GFX11-NEXT: s_waitcnt vmcnt(1) +--; GFX11-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 +--; GFX11-NEXT: s_waitcnt vmcnt(0) +--; GFX11-NEXT: global_store_b64 v[2:3], v[0:1], off offset:32 +--; GFX11-NEXT: s_setpc_b64 s[30:31] +-- %a = load <10 x i32>, ptr addrspace(1) %ptra, align 4 +-- %freeze = freeze <10 x i32> %a +-- store <10 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 +-- ret void +--} +-- +--define void @freeze_v11i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { +--; GFX10-LABEL: freeze_v11i32: +--; GFX10: ; %bb.0: +--; GFX10-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX10-NEXT: s_clause 0x2 +--; GFX10-NEXT: global_load_dwordx4 v[4:7], v[0:1], off +--; GFX10-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 +--; GFX10-NEXT: global_load_dwordx3 v[12:14], v[0:1], off offset:32 +--; GFX10-NEXT: s_waitcnt vmcnt(2) +--; GFX10-NEXT: global_store_dwordx4 v[2:3], v[4:7], off +--; GFX10-NEXT: s_waitcnt vmcnt(1) +--; GFX10-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 +--; GFX10-NEXT: s_waitcnt vmcnt(0) +--; GFX10-NEXT: global_store_dwordx3 v[2:3], v[12:14], off offset:32 +--; GFX10-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX11-LABEL: freeze_v11i32: +--; GFX11: ; %bb.0: +--; GFX11-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX11-NEXT: s_clause 0x2 +--; GFX11-NEXT: global_load_b128 v[4:7], v[0:1], off +--; GFX11-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 +--; GFX11-NEXT: global_load_b96 v[12:14], v[0:1], off offset:32 +--; GFX11-NEXT: s_waitcnt vmcnt(2) +--; GFX11-NEXT: global_store_b128 v[2:3], v[4:7], off +--; GFX11-NEXT: s_waitcnt vmcnt(1) +--; GFX11-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 +--; GFX11-NEXT: s_waitcnt vmcnt(0) +--; GFX11-NEXT: global_store_b96 v[2:3], v[12:14], off offset:32 +--; GFX11-NEXT: s_setpc_b64 s[30:31] +-- %a = load <11 x i32>, ptr addrspace(1) %ptra, align 4 +-- %freeze = freeze <11 x i32> %a +-- store <11 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 +-- ret void +--} +-- +--define void @freeze_v12i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { +--; GFX10-LABEL: freeze_v12i32: +--; GFX10: ; %bb.0: +--; GFX10-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX10-NEXT: s_clause 0x2 +--; GFX10-NEXT: global_load_dwordx4 v[4:7], v[0:1], off +--; GFX10-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 +--; GFX10-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 +--; GFX10-NEXT: s_waitcnt vmcnt(2) +--; GFX10-NEXT: global_store_dwordx4 v[2:3], v[4:7], off +--; GFX10-NEXT: s_waitcnt vmcnt(1) +--; GFX10-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 +--; GFX10-NEXT: s_waitcnt vmcnt(0) +--; GFX10-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 +--; GFX10-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX11-LABEL: freeze_v12i32: +--; GFX11: ; %bb.0: +--; GFX11-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX11-NEXT: s_clause 0x2 +--; GFX11-NEXT: global_load_b128 v[4:7], v[0:1], off +--; GFX11-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 +--; GFX11-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 +--; GFX11-NEXT: s_waitcnt vmcnt(2) +--; GFX11-NEXT: global_store_b128 v[2:3], v[4:7], off +--; GFX11-NEXT: s_waitcnt vmcnt(1) +--; GFX11-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 +--; GFX11-NEXT: s_waitcnt vmcnt(0) +--; GFX11-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 +--; GFX11-NEXT: s_setpc_b64 s[30:31] +-- %a = load <12 x i32>, ptr addrspace(1) %ptra, align 4 +-- %freeze = freeze <12 x i32> %a +-- store <12 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 +-- ret void +--} +--define void @freeze_v13i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { +--; GFX10-SDAG-LABEL: freeze_v13i32: +--; GFX10-SDAG: ; %bb.0: +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX10-SDAG-NEXT: s_clause 0x3 +--; GFX10-SDAG-NEXT: global_load_dword v16, v[0:1], off offset:48 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:32 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:16 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) +--; GFX10-SDAG-NEXT: global_store_dword v[2:3], v16, off offset:48 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:32 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:16 +--; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX10-GISEL-LABEL: freeze_v13i32: +--; GFX10-GISEL: ; %bb.0: +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX10-GISEL-NEXT: s_clause 0x3 +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 +--; GFX10-GISEL-NEXT: global_load_dword v16, v[0:1], off offset:48 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) +--; GFX10-GISEL-NEXT: global_store_dword v[2:3], v16, off offset:48 +--; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX11-SDAG-LABEL: freeze_v13i32: +--; GFX11-SDAG: ; %bb.0: +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX11-SDAG-NEXT: s_clause 0x3 +--; GFX11-SDAG-NEXT: global_load_b32 v16, v[0:1], off offset:48 +--; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:32 +--; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off +--; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off offset:16 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) +--; GFX11-SDAG-NEXT: global_store_b32 v[2:3], v16, off offset:48 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:32 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off offset:16 +--; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX11-GISEL-LABEL: freeze_v13i32: +--; GFX11-GISEL: ; %bb.0: +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX11-GISEL-NEXT: s_clause 0x3 +--; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off +--; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 +--; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 +--; GFX11-GISEL-NEXT: global_load_b32 v0, v[0:1], off offset:48 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) +--; GFX11-GISEL-NEXT: global_store_b32 v[2:3], v0, off offset:48 +--; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] +-- %a = load <13 x i32>, ptr addrspace(1) %ptra, align 4 +-- %freeze = freeze <13 x i32> %a +-- store <13 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 +-- ret void +--} +-- +--define void @freeze_v14i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { +--; GFX10-SDAG-LABEL: freeze_v14i32: +--; GFX10-SDAG: ; %bb.0: +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX10-SDAG-NEXT: s_clause 0x3 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:32 +--; GFX10-SDAG-NEXT: global_load_dwordx2 v[16:17], v[0:1], off offset:48 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:16 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:32 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) +--; GFX10-SDAG-NEXT: global_store_dwordx2 v[2:3], v[16:17], off offset:48 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:16 +--; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX10-GISEL-LABEL: freeze_v14i32: +--; GFX10-GISEL: ; %bb.0: +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX10-GISEL-NEXT: s_clause 0x3 +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 +--; GFX10-GISEL-NEXT: global_load_dwordx2 v[16:17], v[0:1], off offset:48 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) +--; GFX10-GISEL-NEXT: global_store_dwordx2 v[2:3], v[16:17], off offset:48 +--; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX11-SDAG-LABEL: freeze_v14i32: +--; GFX11-SDAG: ; %bb.0: +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX11-SDAG-NEXT: s_clause 0x3 +--; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:32 +--; GFX11-SDAG-NEXT: global_load_b64 v[16:17], v[0:1], off offset:48 +--; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off +--; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off offset:16 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:32 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) +--; GFX11-SDAG-NEXT: global_store_b64 v[2:3], v[16:17], off offset:48 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off offset:16 +--; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX11-GISEL-LABEL: freeze_v14i32: +--; GFX11-GISEL: ; %bb.0: +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX11-GISEL-NEXT: s_clause 0x3 +--; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off +--; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 +--; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 +--; GFX11-GISEL-NEXT: global_load_b64 v[0:1], v[0:1], off offset:48 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) +--; GFX11-GISEL-NEXT: global_store_b64 v[2:3], v[0:1], off offset:48 +--; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] +-- %a = load <14 x i32>, ptr addrspace(1) %ptra, align 4 +-- %freeze = freeze <14 x i32> %a +-- store <14 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 +-- ret void +--} +-- +--define void @freeze_v15i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { +--; GFX10-SDAG-LABEL: freeze_v15i32: +--; GFX10-SDAG: ; %bb.0: +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX10-SDAG-NEXT: s_clause 0x3 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:32 +--; GFX10-SDAG-NEXT: global_load_dwordx3 v[16:18], v[0:1], off offset:48 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:16 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:32 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) +--; GFX10-SDAG-NEXT: global_store_dwordx3 v[2:3], v[16:18], off offset:48 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:16 +--; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX10-GISEL-LABEL: freeze_v15i32: +--; GFX10-GISEL: ; %bb.0: +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX10-GISEL-NEXT: s_clause 0x3 +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 +--; GFX10-GISEL-NEXT: global_load_dwordx3 v[16:18], v[0:1], off offset:48 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) +--; GFX10-GISEL-NEXT: global_store_dwordx3 v[2:3], v[16:18], off offset:48 +--; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX11-SDAG-LABEL: freeze_v15i32: +--; GFX11-SDAG: ; %bb.0: +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX11-SDAG-NEXT: s_clause 0x3 +--; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:32 +--; GFX11-SDAG-NEXT: global_load_b96 v[16:18], v[0:1], off offset:48 +--; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off +--; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off offset:16 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:32 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) +--; GFX11-SDAG-NEXT: global_store_b96 v[2:3], v[16:18], off offset:48 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off offset:16 +--; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX11-GISEL-LABEL: freeze_v15i32: +--; GFX11-GISEL: ; %bb.0: +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX11-GISEL-NEXT: s_clause 0x3 +--; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off +--; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 +--; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 +--; GFX11-GISEL-NEXT: global_load_b96 v[16:18], v[0:1], off offset:48 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) +--; GFX11-GISEL-NEXT: global_store_b96 v[2:3], v[16:18], off offset:48 +--; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] +-- %a = load <15 x i32>, ptr addrspace(1) %ptra, align 4 +-- %freeze = freeze <15 x i32> %a +-- store <15 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 +-- ret void +--} +-- +--define void @freeze_v16i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { +--; GFX10-SDAG-LABEL: freeze_v16i32: +--; GFX10-SDAG: ; %bb.0: +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX10-SDAG-NEXT: s_clause 0x3 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:32 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:48 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:16 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:32 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:48 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:16 +--; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX10-GISEL-LABEL: freeze_v16i32: +--; GFX10-GISEL: ; %bb.0: +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX10-GISEL-NEXT: s_clause 0x3 +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:48 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:48 +--; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX11-SDAG-LABEL: freeze_v16i32: +--; GFX11-SDAG: ; %bb.0: +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX11-SDAG-NEXT: s_clause 0x3 +--; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:32 +--; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off offset:48 +--; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off +--; GFX11-SDAG-NEXT: global_load_b128 v[16:19], v[0:1], off offset:16 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:32 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off offset:48 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[16:19], off offset:16 +--; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX11-GISEL-LABEL: freeze_v16i32: +--; GFX11-GISEL: ; %bb.0: +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX11-GISEL-NEXT: s_clause 0x3 +--; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off +--; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 +--; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 +--; GFX11-GISEL-NEXT: global_load_b128 v[16:19], v[0:1], off offset:48 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[16:19], off offset:48 +--; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] +-- %a = load <16 x i32>, ptr addrspace(1) %ptra, align 4 +-- %freeze = freeze <16 x i32> %a +-- store <16 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 +-- ret void +--} +-- +--define void @freeze_v17i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { +--; GFX10-SDAG-LABEL: freeze_v17i32: +--; GFX10-SDAG: ; %bb.0: +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX10-SDAG-NEXT: s_clause 0x4 +--; GFX10-SDAG-NEXT: global_load_dword v20, v[0:1], off offset:64 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:32 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:48 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:16 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(4) +--; GFX10-SDAG-NEXT: global_store_dword v[2:3], v20, off offset:64 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:32 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:48 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:16 +--; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX10-GISEL-LABEL: freeze_v17i32: +--; GFX10-GISEL: ; %bb.0: +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX10-GISEL-NEXT: s_clause 0x4 +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:48 +--; GFX10-GISEL-NEXT: global_load_dword v20, v[0:1], off offset:64 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(4) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:48 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) +--; GFX10-GISEL-NEXT: global_store_dword v[2:3], v20, off offset:64 +--; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX11-SDAG-LABEL: freeze_v17i32: +--; GFX11-SDAG: ; %bb.0: +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX11-SDAG-NEXT: s_clause 0x4 +--; GFX11-SDAG-NEXT: global_load_b32 v20, v[0:1], off offset:64 +--; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:32 +--; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off offset:48 +--; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off +--; GFX11-SDAG-NEXT: global_load_b128 v[16:19], v[0:1], off offset:16 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(4) +--; GFX11-SDAG-NEXT: global_store_b32 v[2:3], v20, off offset:64 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:32 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off offset:48 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[16:19], off offset:16 +--; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX11-GISEL-LABEL: freeze_v17i32: +--; GFX11-GISEL: ; %bb.0: +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX11-GISEL-NEXT: s_clause 0x4 +--; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off +--; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 +--; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 +--; GFX11-GISEL-NEXT: global_load_b128 v[16:19], v[0:1], off offset:48 +--; GFX11-GISEL-NEXT: global_load_b32 v0, v[0:1], off offset:64 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(4) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[16:19], off offset:48 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) +--; GFX11-GISEL-NEXT: global_store_b32 v[2:3], v0, off offset:64 +--; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] +-- %a = load <17 x i32>, ptr addrspace(1) %ptra, align 4 +-- %freeze = freeze <17 x i32> %a +-- store <17 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 +-- ret void +--} +-- +--define void @freeze_v18i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { +--; GFX10-SDAG-LABEL: freeze_v18i32: +--; GFX10-SDAG: ; %bb.0: +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX10-SDAG-NEXT: s_clause 0x4 +--; GFX10-SDAG-NEXT: global_load_dwordx2 v[20:21], v[0:1], off offset:64 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:32 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:48 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:16 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(4) +--; GFX10-SDAG-NEXT: global_store_dwordx2 v[2:3], v[20:21], off offset:64 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:32 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:48 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:16 +--; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX10-GISEL-LABEL: freeze_v18i32: +--; GFX10-GISEL: ; %bb.0: +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX10-GISEL-NEXT: s_clause 0x4 +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:48 +--; GFX10-GISEL-NEXT: global_load_dwordx2 v[20:21], v[0:1], off offset:64 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(4) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:48 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) +--; GFX10-GISEL-NEXT: global_store_dwordx2 v[2:3], v[20:21], off offset:64 +--; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX11-SDAG-LABEL: freeze_v18i32: +--; GFX11-SDAG: ; %bb.0: +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX11-SDAG-NEXT: s_clause 0x4 +--; GFX11-SDAG-NEXT: global_load_b64 v[20:21], v[0:1], off offset:64 +--; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:32 +--; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off offset:48 +--; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off +--; GFX11-SDAG-NEXT: global_load_b128 v[16:19], v[0:1], off offset:16 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(4) +--; GFX11-SDAG-NEXT: global_store_b64 v[2:3], v[20:21], off offset:64 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:32 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off offset:48 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[16:19], off offset:16 +--; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX11-GISEL-LABEL: freeze_v18i32: +--; GFX11-GISEL: ; %bb.0: +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX11-GISEL-NEXT: s_clause 0x4 +--; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off +--; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 +--; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 +--; GFX11-GISEL-NEXT: global_load_b128 v[16:19], v[0:1], off offset:48 +--; GFX11-GISEL-NEXT: global_load_b64 v[0:1], v[0:1], off offset:64 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(4) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[16:19], off offset:48 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) +--; GFX11-GISEL-NEXT: global_store_b64 v[2:3], v[0:1], off offset:64 +--; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] +-- %a = load <18 x i32>, ptr addrspace(1) %ptra, align 4 +-- %freeze = freeze <18 x i32> %a +-- store <18 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 +-- ret void +--} +-- +--define void @freeze_v19i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { +--; GFX10-SDAG-LABEL: freeze_v19i32: +--; GFX10-SDAG: ; %bb.0: +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX10-SDAG-NEXT: s_clause 0x4 +--; GFX10-SDAG-NEXT: global_load_dwordx3 v[20:22], v[0:1], off offset:64 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:32 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:48 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:16 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(4) +--; GFX10-SDAG-NEXT: global_store_dwordx3 v[2:3], v[20:22], off offset:64 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:32 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:48 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:16 +--; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX10-GISEL-LABEL: freeze_v19i32: +--; GFX10-GISEL: ; %bb.0: +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX10-GISEL-NEXT: s_clause 0x4 +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:48 +--; GFX10-GISEL-NEXT: global_load_dwordx3 v[20:22], v[0:1], off offset:64 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(4) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:48 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) +--; GFX10-GISEL-NEXT: global_store_dwordx3 v[2:3], v[20:22], off offset:64 +--; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX11-SDAG-LABEL: freeze_v19i32: +--; GFX11-SDAG: ; %bb.0: +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX11-SDAG-NEXT: s_clause 0x4 +--; GFX11-SDAG-NEXT: global_load_b96 v[20:22], v[0:1], off offset:64 +--; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:32 +--; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off offset:48 +--; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off +--; GFX11-SDAG-NEXT: global_load_b128 v[16:19], v[0:1], off offset:16 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(4) +--; GFX11-SDAG-NEXT: global_store_b96 v[2:3], v[20:22], off offset:64 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:32 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off offset:48 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[16:19], off offset:16 +--; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX11-GISEL-LABEL: freeze_v19i32: +--; GFX11-GISEL: ; %bb.0: +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX11-GISEL-NEXT: s_clause 0x4 +--; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off +--; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 +--; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 +--; GFX11-GISEL-NEXT: global_load_b128 v[16:19], v[0:1], off offset:48 +--; GFX11-GISEL-NEXT: global_load_b96 v[20:22], v[0:1], off offset:64 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(4) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[16:19], off offset:48 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) +--; GFX11-GISEL-NEXT: global_store_b96 v[2:3], v[20:22], off offset:64 +--; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] +-- %a = load <19 x i32>, ptr addrspace(1) %ptra, align 4 +-- %freeze = freeze <19 x i32> %a +-- store <19 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 +-- ret void +--} +-- +--define void @freeze_v20i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { +--; GFX10-SDAG-LABEL: freeze_v20i32: +--; GFX10-SDAG: ; %bb.0: +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX10-SDAG-NEXT: s_clause 0x4 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:64 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:32 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:48 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[16:19], v[0:1], off +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[20:23], v[0:1], off offset:16 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(4) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:64 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:32 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:48 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[16:19], off +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[20:23], off offset:16 +--; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX10-GISEL-LABEL: freeze_v20i32: +--; GFX10-GISEL: ; %bb.0: +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX10-GISEL-NEXT: s_clause 0x4 +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:48 +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[20:23], v[0:1], off offset:64 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(4) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:48 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[20:23], off offset:64 +--; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX11-SDAG-LABEL: freeze_v20i32: +--; GFX11-SDAG: ; %bb.0: +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX11-SDAG-NEXT: s_clause 0x4 +--; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:64 +--; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off offset:32 +--; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off offset:48 +--; GFX11-SDAG-NEXT: global_load_b128 v[16:19], v[0:1], off +--; GFX11-SDAG-NEXT: global_load_b128 v[20:23], v[0:1], off offset:16 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(4) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:64 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off offset:32 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off offset:48 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[16:19], off +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[20:23], off offset:16 +--; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX11-GISEL-LABEL: freeze_v20i32: +--; GFX11-GISEL: ; %bb.0: +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX11-GISEL-NEXT: s_clause 0x4 +--; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off +--; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 +--; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 +--; GFX11-GISEL-NEXT: global_load_b128 v[16:19], v[0:1], off offset:48 +--; GFX11-GISEL-NEXT: global_load_b128 v[20:23], v[0:1], off offset:64 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(4) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[16:19], off offset:48 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[20:23], off offset:64 +--; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] +-- %a = load <20 x i32>, ptr addrspace(1) %ptra, align 4 +-- %freeze = freeze <20 x i32> %a +-- store <20 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 +-- ret void +--} +-- +--define void @freeze_v21i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { +--; GFX10-SDAG-LABEL: freeze_v21i32: +--; GFX10-SDAG: ; %bb.0: +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX10-SDAG-NEXT: s_clause 0x5 +--; GFX10-SDAG-NEXT: global_load_dword v24, v[0:1], off offset:80 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:64 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:32 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:48 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[16:19], v[0:1], off +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[20:23], v[0:1], off offset:16 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(5) +--; GFX10-SDAG-NEXT: global_store_dword v[2:3], v24, off offset:80 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(4) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:64 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:32 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:48 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[16:19], off +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[20:23], off offset:16 +--; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX10-GISEL-LABEL: freeze_v21i32: +--; GFX10-GISEL: ; %bb.0: +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX10-GISEL-NEXT: s_clause 0x5 +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:48 +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[20:23], v[0:1], off offset:64 +--; GFX10-GISEL-NEXT: global_load_dword v24, v[0:1], off offset:80 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(5) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(4) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:48 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[20:23], off offset:64 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) +--; GFX10-GISEL-NEXT: global_store_dword v[2:3], v24, off offset:80 +--; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX11-SDAG-LABEL: freeze_v21i32: +--; GFX11-SDAG: ; %bb.0: +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX11-SDAG-NEXT: s_clause 0x5 +--; GFX11-SDAG-NEXT: global_load_b32 v24, v[0:1], off offset:80 +--; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:64 +--; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off offset:32 +--; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off offset:48 +--; GFX11-SDAG-NEXT: global_load_b128 v[16:19], v[0:1], off +--; GFX11-SDAG-NEXT: global_load_b128 v[20:23], v[0:1], off offset:16 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(5) +--; GFX11-SDAG-NEXT: global_store_b32 v[2:3], v24, off offset:80 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(4) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:64 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off offset:32 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off offset:48 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[16:19], off +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[20:23], off offset:16 +--; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX11-GISEL-LABEL: freeze_v21i32: +--; GFX11-GISEL: ; %bb.0: +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX11-GISEL-NEXT: s_clause 0x5 +--; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off +--; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 +--; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 +--; GFX11-GISEL-NEXT: global_load_b128 v[16:19], v[0:1], off offset:48 +--; GFX11-GISEL-NEXT: global_load_b128 v[20:23], v[0:1], off offset:64 +--; GFX11-GISEL-NEXT: global_load_b32 v0, v[0:1], off offset:80 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(5) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(4) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[16:19], off offset:48 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[20:23], off offset:64 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) +--; GFX11-GISEL-NEXT: global_store_b32 v[2:3], v0, off offset:80 +--; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] +-- %a = load <21 x i32>, ptr addrspace(1) %ptra, align 4 +-- %freeze = freeze <21 x i32> %a +-- store <21 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 +-- ret void +--} +-- +--define void @freeze_v22i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { +--; GFX10-SDAG-LABEL: freeze_v22i32: +--; GFX10-SDAG: ; %bb.0: +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX10-SDAG-NEXT: s_clause 0x5 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:64 +--; GFX10-SDAG-NEXT: global_load_dwordx2 v[24:25], v[0:1], off offset:80 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:32 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:48 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[16:19], v[0:1], off +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[20:23], v[0:1], off offset:16 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(5) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:64 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(4) +--; GFX10-SDAG-NEXT: global_store_dwordx2 v[2:3], v[24:25], off offset:80 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:32 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:48 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[16:19], off +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[20:23], off offset:16 +--; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX10-GISEL-LABEL: freeze_v22i32: +--; GFX10-GISEL: ; %bb.0: +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX10-GISEL-NEXT: s_clause 0x5 +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:48 +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[20:23], v[0:1], off offset:64 +--; GFX10-GISEL-NEXT: global_load_dwordx2 v[24:25], v[0:1], off offset:80 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(5) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(4) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:48 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[20:23], off offset:64 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) +--; GFX10-GISEL-NEXT: global_store_dwordx2 v[2:3], v[24:25], off offset:80 +--; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX11-SDAG-LABEL: freeze_v22i32: +--; GFX11-SDAG: ; %bb.0: +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX11-SDAG-NEXT: s_clause 0x5 +--; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:64 +--; GFX11-SDAG-NEXT: global_load_b64 v[24:25], v[0:1], off offset:80 +--; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off offset:32 +--; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off offset:48 +--; GFX11-SDAG-NEXT: global_load_b128 v[16:19], v[0:1], off +--; GFX11-SDAG-NEXT: global_load_b128 v[20:23], v[0:1], off offset:16 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(5) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:64 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(4) +--; GFX11-SDAG-NEXT: global_store_b64 v[2:3], v[24:25], off offset:80 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off offset:32 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off offset:48 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[16:19], off +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[20:23], off offset:16 +--; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX11-GISEL-LABEL: freeze_v22i32: +--; GFX11-GISEL: ; %bb.0: +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX11-GISEL-NEXT: s_clause 0x5 +--; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off +--; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 +--; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 +--; GFX11-GISEL-NEXT: global_load_b128 v[16:19], v[0:1], off offset:48 +--; GFX11-GISEL-NEXT: global_load_b128 v[20:23], v[0:1], off offset:64 +--; GFX11-GISEL-NEXT: global_load_b64 v[0:1], v[0:1], off offset:80 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(5) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(4) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[16:19], off offset:48 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[20:23], off offset:64 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) +--; GFX11-GISEL-NEXT: global_store_b64 v[2:3], v[0:1], off offset:80 +--; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] +-- %a = load <22 x i32>, ptr addrspace(1) %ptra, align 4 +-- %freeze = freeze <22 x i32> %a +-- store <22 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 +-- ret void +--} +-- +--define void @freeze_v30i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { +--; GFX10-SDAG-LABEL: freeze_v30i32: +--; GFX10-SDAG: ; %bb.0: +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX10-SDAG-NEXT: s_clause 0x7 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:96 +--; GFX10-SDAG-NEXT: global_load_dwordx2 v[32:33], v[0:1], off offset:112 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:64 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:80 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:32 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[20:23], v[0:1], off offset:48 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[24:27], v[0:1], off +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[28:31], v[0:1], off offset:16 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(7) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:96 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(6) +--; GFX10-SDAG-NEXT: global_store_dwordx2 v[2:3], v[32:33], off offset:112 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(5) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:64 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(4) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:80 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:32 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[20:23], off offset:48 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[24:27], off +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[28:31], off offset:16 +--; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX10-GISEL-LABEL: freeze_v30i32: +--; GFX10-GISEL: ; %bb.0: +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX10-GISEL-NEXT: s_clause 0x7 +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:48 +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[20:23], v[0:1], off offset:64 +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[24:27], v[0:1], off offset:80 +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[28:31], v[0:1], off offset:96 +--; GFX10-GISEL-NEXT: global_load_dwordx2 v[32:33], v[0:1], off offset:112 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(7) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(6) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(5) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(4) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:48 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[20:23], off offset:64 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[24:27], off offset:80 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[28:31], off offset:96 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) +--; GFX10-GISEL-NEXT: global_store_dwordx2 v[2:3], v[32:33], off offset:112 +--; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX11-SDAG-LABEL: freeze_v30i32: +--; GFX11-SDAG: ; %bb.0: +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX11-SDAG-NEXT: s_clause 0x7 +--; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:96 +--; GFX11-SDAG-NEXT: global_load_b64 v[32:33], v[0:1], off offset:112 +--; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off offset:64 +--; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off offset:80 +--; GFX11-SDAG-NEXT: global_load_b128 v[16:19], v[0:1], off offset:32 +--; GFX11-SDAG-NEXT: global_load_b128 v[20:23], v[0:1], off offset:48 +--; GFX11-SDAG-NEXT: global_load_b128 v[24:27], v[0:1], off +--; GFX11-SDAG-NEXT: global_load_b128 v[28:31], v[0:1], off offset:16 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(7) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:96 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(6) +--; GFX11-SDAG-NEXT: global_store_b64 v[2:3], v[32:33], off offset:112 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(5) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off offset:64 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(4) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off offset:80 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[16:19], off offset:32 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[20:23], off offset:48 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[24:27], off +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[28:31], off offset:16 +--; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX11-GISEL-LABEL: freeze_v30i32: +--; GFX11-GISEL: ; %bb.0: +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX11-GISEL-NEXT: s_clause 0x7 +--; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off +--; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 +--; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 +--; GFX11-GISEL-NEXT: global_load_b128 v[16:19], v[0:1], off offset:48 +--; GFX11-GISEL-NEXT: global_load_b128 v[20:23], v[0:1], off offset:64 +--; GFX11-GISEL-NEXT: global_load_b128 v[24:27], v[0:1], off offset:80 +--; GFX11-GISEL-NEXT: global_load_b128 v[28:31], v[0:1], off offset:96 +--; GFX11-GISEL-NEXT: global_load_b64 v[0:1], v[0:1], off offset:112 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(7) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(6) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(5) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(4) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[16:19], off offset:48 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[20:23], off offset:64 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[24:27], off offset:80 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[28:31], off offset:96 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) +--; GFX11-GISEL-NEXT: global_store_b64 v[2:3], v[0:1], off offset:112 +--; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] +-- %a = load <30 x i32>, ptr addrspace(1) %ptra, align 4 +-- %freeze = freeze <30 x i32> %a +-- store <30 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 +-- ret void +--} +-- +--define void @freeze_v31i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { +--; GFX10-SDAG-LABEL: freeze_v31i32: +--; GFX10-SDAG: ; %bb.0: +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX10-SDAG-NEXT: s_clause 0x7 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:96 +--; GFX10-SDAG-NEXT: global_load_dwordx3 v[32:34], v[0:1], off offset:112 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:64 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:80 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:32 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[20:23], v[0:1], off offset:48 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[24:27], v[0:1], off +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[28:31], v[0:1], off offset:16 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(7) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:96 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(6) +--; GFX10-SDAG-NEXT: global_store_dwordx3 v[2:3], v[32:34], off offset:112 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(5) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:64 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(4) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:80 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:32 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[20:23], off offset:48 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[24:27], off +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[28:31], off offset:16 +--; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX10-GISEL-LABEL: freeze_v31i32: +--; GFX10-GISEL: ; %bb.0: +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX10-GISEL-NEXT: s_clause 0x7 +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:48 +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[20:23], v[0:1], off offset:64 +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[24:27], v[0:1], off offset:80 +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[28:31], v[0:1], off offset:96 +--; GFX10-GISEL-NEXT: global_load_dwordx3 v[32:34], v[0:1], off offset:112 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(7) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(6) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(5) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(4) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:48 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[20:23], off offset:64 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[24:27], off offset:80 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[28:31], off offset:96 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) +--; GFX10-GISEL-NEXT: global_store_dwordx3 v[2:3], v[32:34], off offset:112 +--; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX11-SDAG-LABEL: freeze_v31i32: +--; GFX11-SDAG: ; %bb.0: +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX11-SDAG-NEXT: s_clause 0x7 +--; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:96 +--; GFX11-SDAG-NEXT: global_load_b96 v[32:34], v[0:1], off offset:112 +--; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off offset:64 +--; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off offset:80 +--; GFX11-SDAG-NEXT: global_load_b128 v[16:19], v[0:1], off offset:32 +--; GFX11-SDAG-NEXT: global_load_b128 v[20:23], v[0:1], off offset:48 +--; GFX11-SDAG-NEXT: global_load_b128 v[24:27], v[0:1], off +--; GFX11-SDAG-NEXT: global_load_b128 v[28:31], v[0:1], off offset:16 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(7) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:96 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(6) +--; GFX11-SDAG-NEXT: global_store_b96 v[2:3], v[32:34], off offset:112 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(5) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off offset:64 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(4) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off offset:80 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[16:19], off offset:32 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[20:23], off offset:48 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[24:27], off +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[28:31], off offset:16 +--; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX11-GISEL-LABEL: freeze_v31i32: +--; GFX11-GISEL: ; %bb.0: +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX11-GISEL-NEXT: s_clause 0x7 +--; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off +--; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 +--; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 +--; GFX11-GISEL-NEXT: global_load_b128 v[16:19], v[0:1], off offset:48 +--; GFX11-GISEL-NEXT: global_load_b128 v[20:23], v[0:1], off offset:64 +--; GFX11-GISEL-NEXT: global_load_b128 v[24:27], v[0:1], off offset:80 +--; GFX11-GISEL-NEXT: global_load_b128 v[28:31], v[0:1], off offset:96 +--; GFX11-GISEL-NEXT: global_load_b96 v[32:34], v[0:1], off offset:112 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(7) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(6) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(5) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(4) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[16:19], off offset:48 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[20:23], off offset:64 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[24:27], off offset:80 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[28:31], off offset:96 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) +--; GFX11-GISEL-NEXT: global_store_b96 v[2:3], v[32:34], off offset:112 +--; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] +-- %a = load <31 x i32>, ptr addrspace(1) %ptra, align 4 +-- %freeze = freeze <31 x i32> %a +-- store <31 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 +-- ret void +--} +-- +--define void @freeze_v32i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { +--; GFX10-SDAG-LABEL: freeze_v32i32: +--; GFX10-SDAG: ; %bb.0: +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX10-SDAG-NEXT: s_clause 0x7 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:96 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:112 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:64 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:80 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[20:23], v[0:1], off offset:32 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[24:27], v[0:1], off offset:48 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[28:31], v[0:1], off +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[32:35], v[0:1], off offset:16 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(7) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:96 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(6) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:112 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(5) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:64 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(4) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:80 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[20:23], off offset:32 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[24:27], off offset:48 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[28:31], off +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[32:35], off offset:16 +--; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX10-GISEL-LABEL: freeze_v32i32: +--; GFX10-GISEL: ; %bb.0: +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX10-GISEL-NEXT: s_clause 0x7 +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:48 +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[20:23], v[0:1], off offset:64 +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[24:27], v[0:1], off offset:80 +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[28:31], v[0:1], off offset:96 +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[32:35], v[0:1], off offset:112 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(7) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(6) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(5) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(4) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:48 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[20:23], off offset:64 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[24:27], off offset:80 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[28:31], off offset:96 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[32:35], off offset:112 +--; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX11-SDAG-LABEL: freeze_v32i32: +--; GFX11-SDAG: ; %bb.0: +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX11-SDAG-NEXT: s_clause 0x7 +--; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:96 +--; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off offset:112 +--; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off offset:64 +--; GFX11-SDAG-NEXT: global_load_b128 v[16:19], v[0:1], off offset:80 +--; GFX11-SDAG-NEXT: global_load_b128 v[20:23], v[0:1], off offset:32 +--; GFX11-SDAG-NEXT: global_load_b128 v[24:27], v[0:1], off offset:48 +--; GFX11-SDAG-NEXT: global_load_b128 v[28:31], v[0:1], off +--; GFX11-SDAG-NEXT: global_load_b128 v[32:35], v[0:1], off offset:16 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(7) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:96 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(6) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off offset:112 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(5) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off offset:64 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(4) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[16:19], off offset:80 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[20:23], off offset:32 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[24:27], off offset:48 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[28:31], off +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[32:35], off offset:16 +--; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX11-GISEL-LABEL: freeze_v32i32: +--; GFX11-GISEL: ; %bb.0: +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX11-GISEL-NEXT: s_clause 0x7 +--; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off +--; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 +--; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 +--; GFX11-GISEL-NEXT: global_load_b128 v[16:19], v[0:1], off offset:48 +--; GFX11-GISEL-NEXT: global_load_b128 v[20:23], v[0:1], off offset:64 +--; GFX11-GISEL-NEXT: global_load_b128 v[24:27], v[0:1], off offset:80 +--; GFX11-GISEL-NEXT: global_load_b128 v[28:31], v[0:1], off offset:96 +--; GFX11-GISEL-NEXT: global_load_b128 v[32:35], v[0:1], off offset:112 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(7) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(6) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(5) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(4) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[16:19], off offset:48 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[20:23], off offset:64 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[24:27], off offset:80 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[28:31], off offset:96 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[32:35], off offset:112 +--; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] +-- %a = load <32 x i32>, ptr addrspace(1) %ptra, align 4 +-- %freeze = freeze <32 x i32> %a +-- store <32 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 +-- ret void +--} +-- +--define void @freeze_i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { +--; GFX10-LABEL: freeze_i32: +--; GFX10: ; %bb.0: +--; GFX10-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX10-NEXT: global_load_dword v0, v[0:1], off +--; GFX10-NEXT: s_waitcnt vmcnt(0) +--; GFX10-NEXT: global_store_dword v[2:3], v0, off +--; GFX10-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX11-LABEL: freeze_i32: +--; GFX11: ; %bb.0: +--; GFX11-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX11-NEXT: global_load_b32 v0, v[0:1], off +--; GFX11-NEXT: s_waitcnt vmcnt(0) +--; GFX11-NEXT: global_store_b32 v[2:3], v0, off +--; GFX11-NEXT: s_setpc_b64 s[30:31] +-- %a = load i32, ptr addrspace(1) %ptra, align 4 +-- %freeze = freeze i32 %a +-- store i32 %freeze, ptr addrspace(1) %ptrb, align 4 +-- ret void +--} +-- +--define void @freeze_i64(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { +--; GFX10-LABEL: freeze_i64: +--; GFX10: ; %bb.0: +--; GFX10-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX10-NEXT: global_load_dwordx2 v[0:1], v[0:1], off +--; GFX10-NEXT: s_waitcnt vmcnt(0) +--; GFX10-NEXT: global_store_dwordx2 v[2:3], v[0:1], off +--; GFX10-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX11-LABEL: freeze_i64: +--; GFX11: ; %bb.0: +--; GFX11-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX11-NEXT: global_load_b64 v[0:1], v[0:1], off +--; GFX11-NEXT: s_waitcnt vmcnt(0) +--; GFX11-NEXT: global_store_b64 v[2:3], v[0:1], off +--; GFX11-NEXT: s_setpc_b64 s[30:31] +-- %a = load i64, ptr addrspace(1) %ptra, align 4 +-- %freeze = freeze i64 %a +-- store i64 %freeze, ptr addrspace(1) %ptrb, align 4 +-- ret void +--} +-- +--define void @freeze_float(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { +--; GFX10-LABEL: freeze_float: +--; GFX10: ; %bb.0: +--; GFX10-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX10-NEXT: global_load_dword v0, v[0:1], off +--; GFX10-NEXT: s_waitcnt vmcnt(0) +--; GFX10-NEXT: global_store_dword v[2:3], v0, off +--; GFX10-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX11-LABEL: freeze_float: +--; GFX11: ; %bb.0: +--; GFX11-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX11-NEXT: global_load_b32 v0, v[0:1], off +--; GFX11-NEXT: s_waitcnt vmcnt(0) +--; GFX11-NEXT: global_store_b32 v[2:3], v0, off +--; GFX11-NEXT: s_setpc_b64 s[30:31] +-- %a = load float, ptr addrspace(1) %ptra, align 4 +-- %freeze = freeze float %a +-- store float %freeze, ptr addrspace(1) %ptrb, align 4 +-- ret void +--} +-- +--define void @freeze_i128(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { +--; GFX10-LABEL: freeze_i128: +--; GFX10: ; %bb.0: +--; GFX10-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX10-NEXT: global_load_dwordx4 v[4:7], v[0:1], off +--; GFX10-NEXT: s_waitcnt vmcnt(0) +--; GFX10-NEXT: global_store_dwordx4 v[2:3], v[4:7], off +--; GFX10-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX11-LABEL: freeze_i128: +--; GFX11: ; %bb.0: +--; GFX11-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX11-NEXT: global_load_b128 v[4:7], v[0:1], off +--; GFX11-NEXT: s_waitcnt vmcnt(0) +--; GFX11-NEXT: global_store_b128 v[2:3], v[4:7], off +--; GFX11-NEXT: s_setpc_b64 s[30:31] +-- %a = load i128, ptr addrspace(1) %ptra, align 4 +-- %freeze = freeze i128 %a +-- store i128 %freeze, ptr addrspace(1) %ptrb, align 4 +-- ret void +--} +-- +--define void @freeze_i256(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { +--; GFX10-SDAG-LABEL: freeze_i256: +--; GFX10-SDAG: ; %bb.0: +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX10-SDAG-NEXT: s_clause 0x1 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:16 +--; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:16 +--; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) +--; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off +--; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX10-GISEL-LABEL: freeze_i256: +--; GFX10-GISEL: ; %bb.0: +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX10-GISEL-NEXT: s_clause 0x1 +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off +--; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off +--; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) +--; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 +--; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX11-SDAG-LABEL: freeze_i256: +--; GFX11-SDAG: ; %bb.0: +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX11-SDAG-NEXT: s_clause 0x1 +--; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:16 +--; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:16 +--; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) +--; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off +--; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] +--; +--; GFX11-GISEL-LABEL: freeze_i256: +--; GFX11-GISEL: ; %bb.0: +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +--; GFX11-GISEL-NEXT: s_clause 0x1 +--; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off +--; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off +--; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) +--; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 +--; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] +-- %a = load i256, ptr addrspace(1) %ptra, align 4 +-- %freeze = freeze i256 %a +-- store i256 %freeze, ptr addrspace(1) %ptrb, align 4 +-- ret void +--} +-diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/AMDGPU/GlobalISel/inst-select-unmerge-values.mir b/llvm/test/CodeGen/AMDGPU/GlobalISel/inst-select-unmerge-values.mir +---- a/llvm/test/CodeGen/AMDGPU/GlobalISel/inst-select-unmerge-values.mir +-+++ b/llvm/test/CodeGen/AMDGPU/GlobalISel/inst-select-unmerge-values.mir +-@@ -171,9 +171,11 @@ +- ; GCN-LABEL: name: test_unmerge_values_s_s64_s_s64_s64_s_s192 +- ; GCN: liveins: $sgpr0_sgpr1_sgpr2_sgpr3 +- ; GCN-NEXT: {{ $}} +-- ; GCN-NEXT: [[DEF:%[0-9]+]]:sgpr(s192) = G_IMPLICIT_DEF +-- ; GCN-NEXT: [[UV:%[0-9]+]]:sgpr(s64), [[UV1:%[0-9]+]]:sgpr(s64), [[UV2:%[0-9]+]]:sgpr(s64) = G_UNMERGE_VALUES [[DEF]](s192) +-- ; GCN-NEXT: S_ENDPGM 0, implicit [[UV]](s64), implicit [[UV1]](s64), implicit [[UV2]](s64) +-+ ; GCN-NEXT: [[DEF:%[0-9]+]]:sgpr_192 = IMPLICIT_DEF +-+ ; GCN-NEXT: [[COPY:%[0-9]+]]:sreg_64 = COPY [[DEF]].sub0_sub1 +-+ ; GCN-NEXT: [[COPY1:%[0-9]+]]:sreg_64 = COPY [[DEF]].sub2_sub3 +-+ ; GCN-NEXT: [[COPY2:%[0-9]+]]:sreg_64 = COPY [[DEF]].sub4_sub5 +-+ ; GCN-NEXT: S_ENDPGM 0, implicit [[COPY]], implicit [[COPY1]], implicit [[COPY2]] +- %0:sgpr(s192) = G_IMPLICIT_DEF +- %1:sgpr(s64), %2:sgpr(s64), %3:sgpr(s64) = G_UNMERGE_VALUES %0 +- S_ENDPGM 0, implicit %1, implicit %2, implicit %3 +-@@ -292,11 +294,11 @@ +- ; GCN-NEXT: [[CONCAT_VECTORS:%[0-9]+]]:sgpr_384(<12 x s32>) = G_CONCAT_VECTORS [[COPY]](<3 x s32>), [[COPY1]](<3 x s32>), [[COPY2]](<3 x s32>), [[COPY3]](<3 x s32>) +- ; GCN-NEXT: [[COPY4:%[0-9]+]]:sgpr_96(<3 x s32>) = COPY [[CONCAT_VECTORS]].sub0_sub1_sub2(<12 x s32>) +- ; GCN-NEXT: [[COPY5:%[0-9]+]]:sgpr_96(<3 x s32>) = COPY [[CONCAT_VECTORS]].sub3_sub4_sub5(<12 x s32>) +-- ; GCN-NEXT: [[COPY4:%[0-9]+]]:sgpr_96(<3 x s32>), [[COPY5:%[0-9]+]]:sgpr_96(<3 x s32>), [[UV:%[0-9]+]]:sgpr_96(<3 x s32>), [[UV1:%[0-9]+]]:sgpr_96(<3 x s32>) = G_UNMERGE_VALUES [[CONCAT_VECTORS]](<12 x s32>) +-- ; GCN-NEXT: $sgpr0_sgpr1_sgpr2 = COPY [[COPY4]](<3 x s32>) +-- ; GCN-NEXT: $sgpr4_sgpr5_sgpr6 = COPY [[COPY5]](<3 x s32>) +-- ; GCN-NEXT: $sgpr8_sgpr9_sgpr10 = COPY [[UV]](<3 x s32>) +-- ; GCN-NEXT: $sgpr12_sgpr13_sgpr14 = COPY [[UV1]](<3 x s32>) +-+ ; GCN-NEXT: [[UV:%[0-9]+]]:sgpr_96(<3 x s32>), [[UV1:%[0-9]+]]:sgpr_96(<3 x s32>), [[UV2:%[0-9]+]]:sgpr_96(<3 x s32>), [[UV3:%[0-9]+]]:sgpr_96(<3 x s32>) = G_UNMERGE_VALUES [[CONCAT_VECTORS]](<12 x s32>) +-+ ; GCN-NEXT: $sgpr0_sgpr1_sgpr2 = COPY [[UV]](<3 x s32>) +-+ ; GCN-NEXT: $sgpr4_sgpr5_sgpr6 = COPY [[UV1]](<3 x s32>) +-+ ; GCN-NEXT: $sgpr8_sgpr9_sgpr10 = COPY [[UV2]](<3 x s32>) +-+ ; GCN-NEXT: $sgpr12_sgpr13_sgpr14 = COPY [[UV3]](<3 x s32>) +- %0:sgpr(<3 x s32>) = COPY $sgpr0_sgpr1_sgpr2 +- %1:sgpr(<3 x s32>) = COPY $sgpr4_sgpr5_sgpr6 +- %2:sgpr(<3 x s32>) = COPY $sgpr8_sgpr9_sgpr10 +-diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-freeze.mir b/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-freeze.mir +---- a/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-freeze.mir +-+++ b/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-freeze.mir +-@@ -171,8 +171,12 @@ +- +- ; CHECK-LABEL: name: test_freeze_s448 +- ; CHECK: [[COPY:%[0-9]+]]:_(s512) = COPY $vgpr0_vgpr1_vgpr2_vgpr3_vgpr4_vgpr5_vgpr6_vgpr7_vgpr8_vgpr9_vgpr10_vgpr11_vgpr12_vgpr13_vgpr14_vgpr15 +-- ; CHECK-NEXT: [[FREEZE:%[0-9]+]]:_(s512) = G_FREEZE [[COPY]] +-- ; CHECK-NEXT: $vgpr0_vgpr1_vgpr2_vgpr3_vgpr4_vgpr5_vgpr6_vgpr7_vgpr8_vgpr9_vgpr10_vgpr11_vgpr12_vgpr13_vgpr14_vgpr15 = COPY [[FREEZE]](s512) +-+ ; CHECK-NEXT: [[TRUNC:%[0-9]+]]:_(s448) = G_TRUNC [[COPY]](s512) +-+ ; CHECK-NEXT: [[FREEZE:%[0-9]+]]:_(s448) = G_FREEZE [[TRUNC]] +-+ ; CHECK-NEXT: [[UV:%[0-9]+]]:_(s64), [[UV1:%[0-9]+]]:_(s64), [[UV2:%[0-9]+]]:_(s64), [[UV3:%[0-9]+]]:_(s64), [[UV4:%[0-9]+]]:_(s64), [[UV5:%[0-9]+]]:_(s64), [[UV6:%[0-9]+]]:_(s64) = G_UNMERGE_VALUES [[FREEZE]](s448) +-+ ; CHECK-NEXT: [[DEF:%[0-9]+]]:_(s64) = G_IMPLICIT_DEF +-+ ; CHECK-NEXT: [[MV:%[0-9]+]]:_(s512) = G_MERGE_VALUES [[UV]](s64), [[UV1]](s64), [[UV2]](s64), [[UV3]](s64), [[UV4]](s64), [[UV5]](s64), [[UV6]](s64), [[DEF]](s64) +-+ ; CHECK-NEXT: $vgpr0_vgpr1_vgpr2_vgpr3_vgpr4_vgpr5_vgpr6_vgpr7_vgpr8_vgpr9_vgpr10_vgpr11_vgpr12_vgpr13_vgpr14_vgpr15 = COPY [[MV]](s512) +- %0:_(s512) = COPY $vgpr0_vgpr1_vgpr2_vgpr3_vgpr4_vgpr5_vgpr6_vgpr7_vgpr8_vgpr9_vgpr10_vgpr11_vgpr12_vgpr13_vgpr14_vgpr15 +- %1:_(s448) = G_TRUNC %0 +- %2:_(s448) = G_FREEZE %1 +-@@ -395,12 +399,14 @@ +- bb.0: +- +- ; CHECK-LABEL: name: test_freeze_v33s32 +-- ; CHECK: [[DEF:%[0-9]+]]:_(<32 x s32>) = G_IMPLICIT_DEF +-+ ; CHECK: [[DEF:%[0-9]+]]:_(<16 x s32>) = G_IMPLICIT_DEF +- ; CHECK-NEXT: [[DEF1:%[0-9]+]]:_(s32) = G_IMPLICIT_DEF +-- ; CHECK-NEXT: [[FREEZE:%[0-9]+]]:_(<32 x s32>) = G_FREEZE [[DEF]] +-- ; CHECK-NEXT: [[FREEZE1:%[0-9]+]]:_(s32) = G_FREEZE [[DEF1]] +-- ; CHECK-NEXT: [[UV:%[0-9]+]]:_(s32), [[UV1:%[0-9]+]]:_(s32), [[UV2:%[0-9]+]]:_(s32), [[UV3:%[0-9]+]]:_(s32), [[UV4:%[0-9]+]]:_(s32), [[UV5:%[0-9]+]]:_(s32), [[UV6:%[0-9]+]]:_(s32), [[UV7:%[0-9]+]]:_(s32), [[UV8:%[0-9]+]]:_(s32), [[UV9:%[0-9]+]]:_(s32), [[UV10:%[0-9]+]]:_(s32), [[UV11:%[0-9]+]]:_(s32), [[UV12:%[0-9]+]]:_(s32), [[UV13:%[0-9]+]]:_(s32), [[UV14:%[0-9]+]]:_(s32), [[UV15:%[0-9]+]]:_(s32), [[UV16:%[0-9]+]]:_(s32), [[UV17:%[0-9]+]]:_(s32), [[UV18:%[0-9]+]]:_(s32), [[UV19:%[0-9]+]]:_(s32), [[UV20:%[0-9]+]]:_(s32), [[UV21:%[0-9]+]]:_(s32), [[UV22:%[0-9]+]]:_(s32), [[UV23:%[0-9]+]]:_(s32), [[UV24:%[0-9]+]]:_(s32), [[UV25:%[0-9]+]]:_(s32), [[UV26:%[0-9]+]]:_(s32), [[UV27:%[0-9]+]]:_(s32), [[UV28:%[0-9]+]]:_(s32), [[UV29:%[0-9]+]]:_(s32), [[UV30:%[0-9]+]]:_(s32), [[UV31:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[FREEZE]](<32 x s32>) +-- ; CHECK-NEXT: [[BUILD_VECTOR:%[0-9]+]]:_(<33 x s32>) = G_BUILD_VECTOR [[UV]](s32), [[UV1]](s32), [[UV2]](s32), [[UV3]](s32), [[UV4]](s32), [[UV5]](s32), [[UV6]](s32), [[UV7]](s32), [[UV8]](s32), [[UV9]](s32), [[UV10]](s32), [[UV11]](s32), [[UV12]](s32), [[UV13]](s32), [[UV14]](s32), [[UV15]](s32), [[UV16]](s32), [[UV17]](s32), [[UV18]](s32), [[UV19]](s32), [[UV20]](s32), [[UV21]](s32), [[UV22]](s32), [[UV23]](s32), [[UV24]](s32), [[UV25]](s32), [[UV26]](s32), [[UV27]](s32), [[UV28]](s32), [[UV29]](s32), [[UV30]](s32), [[UV31]](s32), [[FREEZE1]](s32) +-+ ; CHECK-NEXT: [[FREEZE:%[0-9]+]]:_(<16 x s32>) = G_FREEZE [[DEF]] +-+ ; CHECK-NEXT: [[FREEZE1:%[0-9]+]]:_(<16 x s32>) = G_FREEZE [[DEF]] +-+ ; CHECK-NEXT: [[FREEZE2:%[0-9]+]]:_(s32) = G_FREEZE [[DEF1]] +-+ ; CHECK-NEXT: [[UV:%[0-9]+]]:_(s32), [[UV1:%[0-9]+]]:_(s32), [[UV2:%[0-9]+]]:_(s32), [[UV3:%[0-9]+]]:_(s32), [[UV4:%[0-9]+]]:_(s32), [[UV5:%[0-9]+]]:_(s32), [[UV6:%[0-9]+]]:_(s32), [[UV7:%[0-9]+]]:_(s32), [[UV8:%[0-9]+]]:_(s32), [[UV9:%[0-9]+]]:_(s32), [[UV10:%[0-9]+]]:_(s32), [[UV11:%[0-9]+]]:_(s32), [[UV12:%[0-9]+]]:_(s32), [[UV13:%[0-9]+]]:_(s32), [[UV14:%[0-9]+]]:_(s32), [[UV15:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[FREEZE]](<16 x s32>) +-+ ; CHECK-NEXT: [[UV16:%[0-9]+]]:_(s32), [[UV17:%[0-9]+]]:_(s32), [[UV18:%[0-9]+]]:_(s32), [[UV19:%[0-9]+]]:_(s32), [[UV20:%[0-9]+]]:_(s32), [[UV21:%[0-9]+]]:_(s32), [[UV22:%[0-9]+]]:_(s32), [[UV23:%[0-9]+]]:_(s32), [[UV24:%[0-9]+]]:_(s32), [[UV25:%[0-9]+]]:_(s32), [[UV26:%[0-9]+]]:_(s32), [[UV27:%[0-9]+]]:_(s32), [[UV28:%[0-9]+]]:_(s32), [[UV29:%[0-9]+]]:_(s32), [[UV30:%[0-9]+]]:_(s32), [[UV31:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[FREEZE1]](<16 x s32>) +-+ ; CHECK-NEXT: [[BUILD_VECTOR:%[0-9]+]]:_(<33 x s32>) = G_BUILD_VECTOR [[UV]](s32), [[UV1]](s32), [[UV2]](s32), [[UV3]](s32), [[UV4]](s32), [[UV5]](s32), [[UV6]](s32), [[UV7]](s32), [[UV8]](s32), [[UV9]](s32), [[UV10]](s32), [[UV11]](s32), [[UV12]](s32), [[UV13]](s32), [[UV14]](s32), [[UV15]](s32), [[UV16]](s32), [[UV17]](s32), [[UV18]](s32), [[UV19]](s32), [[UV20]](s32), [[UV21]](s32), [[UV22]](s32), [[UV23]](s32), [[UV24]](s32), [[UV25]](s32), [[UV26]](s32), [[UV27]](s32), [[UV28]](s32), [[UV29]](s32), [[UV30]](s32), [[UV31]](s32), [[FREEZE2]](s32) +- ; CHECK-NEXT: S_NOP 0, implicit [[BUILD_VECTOR]](<33 x s32>) +- %0:_(<33 x s32>) = G_IMPLICIT_DEF +- %1:_(<33 x s32>) = G_FREEZE %0 +-@@ -413,10 +419,12 @@ +- bb.0: +- +- ; CHECK-LABEL: name: test_freeze_v64s32 +-- ; CHECK: [[DEF:%[0-9]+]]:_(<32 x s32>) = G_IMPLICIT_DEF +-- ; CHECK-NEXT: [[FREEZE:%[0-9]+]]:_(<32 x s32>) = G_FREEZE [[DEF]] +-- ; CHECK-NEXT: [[FREEZE1:%[0-9]+]]:_(<32 x s32>) = G_FREEZE [[DEF]] +-- ; CHECK-NEXT: [[CONCAT_VECTORS:%[0-9]+]]:_(<64 x s32>) = G_CONCAT_VECTORS [[FREEZE]](<32 x s32>), [[FREEZE1]](<32 x s32>) +-+ ; CHECK: [[DEF:%[0-9]+]]:_(<16 x s32>) = G_IMPLICIT_DEF +-+ ; CHECK-NEXT: [[FREEZE:%[0-9]+]]:_(<16 x s32>) = G_FREEZE [[DEF]] +-+ ; CHECK-NEXT: [[FREEZE1:%[0-9]+]]:_(<16 x s32>) = G_FREEZE [[DEF]] +-+ ; CHECK-NEXT: [[FREEZE2:%[0-9]+]]:_(<16 x s32>) = G_FREEZE [[DEF]] +-+ ; CHECK-NEXT: [[FREEZE3:%[0-9]+]]:_(<16 x s32>) = G_FREEZE [[DEF]] +-+ ; CHECK-NEXT: [[CONCAT_VECTORS:%[0-9]+]]:_(<64 x s32>) = G_CONCAT_VECTORS [[FREEZE]](<16 x s32>), [[FREEZE1]](<16 x s32>), [[FREEZE2]](<16 x s32>), [[FREEZE3]](<16 x s32>) +- ; CHECK-NEXT: S_NOP 0, implicit [[CONCAT_VECTORS]](<64 x s32>) +- %0:_(<64 x s32>) = G_IMPLICIT_DEF +- %1:_(<64 x s32>) = G_FREEZE %0 +-diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-implicit-def.mir b/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-implicit-def.mir +---- a/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-implicit-def.mir +-+++ b/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-implicit-def.mir +-@@ -135,9 +135,8 @@ +- bb.0: +- +- ; CHECK-LABEL: name: test_implicit_def_s448 +-- ; CHECK: [[DEF:%[0-9]+]]:_(s512) = G_IMPLICIT_DEF +-- ; CHECK-NEXT: [[TRUNC:%[0-9]+]]:_(s448) = G_TRUNC [[DEF]](s512) +-- ; CHECK-NEXT: [[EXTRACT:%[0-9]+]]:_(s32) = G_EXTRACT [[TRUNC]](s448), 0 +-+ ; CHECK: [[DEF:%[0-9]+]]:_(s448) = G_IMPLICIT_DEF +-+ ; CHECK-NEXT: [[EXTRACT:%[0-9]+]]:_(s32) = G_EXTRACT [[DEF]](s448), 0 +- ; CHECK-NEXT: $vgpr0 = COPY [[EXTRACT]](s32) +- %0:_(s448) = G_IMPLICIT_DEF +- %1:_(s32) = G_EXTRACT %0, 0 +-@@ -297,6 +296,18 @@ +- ... +- +- --- +-+name: test_implicit_def_v17s32 +-+body: | +-+ bb.0: +-+ +-+ ; CHECK-LABEL: name: test_implicit_def_v17s32 +-+ ; CHECK: [[DEF:%[0-9]+]]:_(<17 x s32>) = G_IMPLICIT_DEF +-+ ; CHECK-NEXT: S_NOP 0, implicit [[DEF]](<17 x s32>) +-+ %0:_(<17 x s32>) = G_IMPLICIT_DEF +-+ S_NOP 0, implicit %0 +-+... +-+ +-+--- +- name: test_implicit_def_v32s32 +- body: | +- bb.0: +-@@ -317,9 +328,9 @@ +- ; CHECK-LABEL: name: test_implicit_def_v33s32 +- ; CHECK: liveins: $vgpr0_vgpr1 +- ; CHECK-NEXT: {{ $}} +-- ; CHECK-NEXT: [[DEF:%[0-9]+]]:_(<32 x s32>) = G_IMPLICIT_DEF +-+ ; CHECK-NEXT: [[DEF:%[0-9]+]]:_(<16 x s32>) = G_IMPLICIT_DEF +- ; CHECK-NEXT: [[DEF1:%[0-9]+]]:_(s32) = G_IMPLICIT_DEF +-- ; CHECK-NEXT: [[UV:%[0-9]+]]:_(s32), [[UV1:%[0-9]+]]:_(s32), [[UV2:%[0-9]+]]:_(s32), [[UV3:%[0-9]+]]:_(s32), [[UV4:%[0-9]+]]:_(s32), [[UV5:%[0-9]+]]:_(s32), [[UV6:%[0-9]+]]:_(s32), [[UV7:%[0-9]+]]:_(s32), [[UV8:%[0-9]+]]:_(s32), [[UV9:%[0-9]+]]:_(s32), [[UV10:%[0-9]+]]:_(s32), [[UV11:%[0-9]+]]:_(s32), [[UV12:%[0-9]+]]:_(s32), [[UV13:%[0-9]+]]:_(s32), [[UV14:%[0-9]+]]:_(s32), [[UV15:%[0-9]+]]:_(s32), [[UV16:%[0-9]+]]:_(s32), [[UV17:%[0-9]+]]:_(s32), [[UV18:%[0-9]+]]:_(s32), [[UV19:%[0-9]+]]:_(s32), [[UV20:%[0-9]+]]:_(s32), [[UV21:%[0-9]+]]:_(s32), [[UV22:%[0-9]+]]:_(s32), [[UV23:%[0-9]+]]:_(s32), [[UV24:%[0-9]+]]:_(s32), [[UV25:%[0-9]+]]:_(s32), [[UV26:%[0-9]+]]:_(s32), [[UV27:%[0-9]+]]:_(s32), [[UV28:%[0-9]+]]:_(s32), [[UV29:%[0-9]+]]:_(s32), [[UV30:%[0-9]+]]:_(s32), [[UV31:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<32 x s32>) +-+ ; CHECK-NEXT: [[UV:%[0-9]+]]:_(s32), [[UV1:%[0-9]+]]:_(s32), [[UV2:%[0-9]+]]:_(s32), [[UV3:%[0-9]+]]:_(s32), [[UV4:%[0-9]+]]:_(s32), [[UV5:%[0-9]+]]:_(s32), [[UV6:%[0-9]+]]:_(s32), [[UV7:%[0-9]+]]:_(s32), [[UV8:%[0-9]+]]:_(s32), [[UV9:%[0-9]+]]:_(s32), [[UV10:%[0-9]+]]:_(s32), [[UV11:%[0-9]+]]:_(s32), [[UV12:%[0-9]+]]:_(s32), [[UV13:%[0-9]+]]:_(s32), [[UV14:%[0-9]+]]:_(s32), [[UV15:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) +- ; CHECK-NEXT: [[COPY:%[0-9]+]]:_(p1) = COPY $vgpr0_vgpr1 +- ; CHECK-NEXT: G_STORE [[UV]](s32), [[COPY]](p1) :: (volatile store (s32), addrspace 1) +- ; CHECK-NEXT: G_STORE [[DEF1]](s32), [[COPY]](p1) :: (volatile store (s32), addrspace 1) +-@@ -337,9 +348,10 @@ +- bb.0: +- +- ; CHECK-LABEL: name: test_implicit_def_v64s32 +-- ; CHECK: [[DEF:%[0-9]+]]:_(<32 x s32>) = G_IMPLICIT_DEF +-- ; CHECK-NEXT: [[CONCAT_VECTORS:%[0-9]+]]:_(<64 x s32>) = G_CONCAT_VECTORS [[DEF]](<32 x s32>), [[DEF]](<32 x s32>) +-- ; CHECK-NEXT: S_NOP 0, implicit [[CONCAT_VECTORS]](<64 x s32>), implicit [[DEF]](<32 x s32>) +-+ ; CHECK: [[DEF:%[0-9]+]]:_(<16 x s32>) = G_IMPLICIT_DEF +-+ ; CHECK-NEXT: [[CONCAT_VECTORS:%[0-9]+]]:_(<64 x s32>) = G_CONCAT_VECTORS [[DEF]](<16 x s32>), [[DEF]](<16 x s32>), [[DEF]](<16 x s32>), [[DEF]](<16 x s32>) +-+ ; CHECK-NEXT: [[CONCAT_VECTORS1:%[0-9]+]]:_(<32 x s32>) = G_CONCAT_VECTORS [[DEF]](<16 x s32>), [[DEF]](<16 x s32>) +-+ ; CHECK-NEXT: S_NOP 0, implicit [[CONCAT_VECTORS]](<64 x s32>), implicit [[CONCAT_VECTORS1]](<32 x s32>) +- %0:_(<64 x s32>) = G_IMPLICIT_DEF +- %1:_(<32 x s32>), %2:_(<32 x s32>) = G_UNMERGE_VALUES %0 +- S_NOP 0, implicit %0, implicit %1 +-diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-insert-vector-elt.mir b/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-insert-vector-elt.mir +---- a/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-insert-vector-elt.mir +-+++ b/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-insert-vector-elt.mir +-@@ -190,11 +190,13 @@ +- ; CHECK-LABEL: name: insert_vector_elt_64_65_v64s32 +- ; CHECK: liveins: $sgpr0_sgpr1, $vgpr0_vgpr1, $vgpr2_vgpr3 +- ; CHECK-NEXT: {{ $}} +-- ; CHECK-NEXT: [[DEF:%[0-9]+]]:_(<32 x s32>) = G_IMPLICIT_DEF +-+ ; CHECK-NEXT: [[DEF:%[0-9]+]]:_(<16 x s32>) = G_IMPLICIT_DEF +- ; CHECK-NEXT: [[COPY:%[0-9]+]]:_(p1) = COPY $vgpr0_vgpr1 +- ; CHECK-NEXT: [[COPY1:%[0-9]+]]:_(p1) = COPY $vgpr2_vgpr3 +-- ; CHECK-NEXT: [[UV:%[0-9]+]]:_(<4 x s32>), [[UV1:%[0-9]+]]:_(<4 x s32>), [[UV2:%[0-9]+]]:_(<4 x s32>), [[UV3:%[0-9]+]]:_(<4 x s32>), [[UV4:%[0-9]+]]:_(<4 x s32>), [[UV5:%[0-9]+]]:_(<4 x s32>), [[UV6:%[0-9]+]]:_(<4 x s32>), [[UV7:%[0-9]+]]:_(<4 x s32>) = G_UNMERGE_VALUES [[DEF]](<32 x s32>) +-- ; CHECK-NEXT: [[UV8:%[0-9]+]]:_(<4 x s32>), [[UV9:%[0-9]+]]:_(<4 x s32>), [[UV10:%[0-9]+]]:_(<4 x s32>), [[UV11:%[0-9]+]]:_(<4 x s32>), [[UV12:%[0-9]+]]:_(<4 x s32>), [[UV13:%[0-9]+]]:_(<4 x s32>), [[UV14:%[0-9]+]]:_(<4 x s32>), [[UV15:%[0-9]+]]:_(<4 x s32>) = G_UNMERGE_VALUES [[DEF]](<32 x s32>) +-+ ; CHECK-NEXT: [[UV:%[0-9]+]]:_(<4 x s32>), [[UV1:%[0-9]+]]:_(<4 x s32>), [[UV2:%[0-9]+]]:_(<4 x s32>), [[UV3:%[0-9]+]]:_(<4 x s32>) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) +-+ ; CHECK-NEXT: [[UV4:%[0-9]+]]:_(<4 x s32>), [[UV5:%[0-9]+]]:_(<4 x s32>), [[UV6:%[0-9]+]]:_(<4 x s32>), [[UV7:%[0-9]+]]:_(<4 x s32>) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) +-+ ; CHECK-NEXT: [[UV8:%[0-9]+]]:_(<4 x s32>), [[UV9:%[0-9]+]]:_(<4 x s32>), [[UV10:%[0-9]+]]:_(<4 x s32>), [[UV11:%[0-9]+]]:_(<4 x s32>) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) +-+ ; CHECK-NEXT: [[UV12:%[0-9]+]]:_(<4 x s32>), [[UV13:%[0-9]+]]:_(<4 x s32>), [[UV14:%[0-9]+]]:_(<4 x s32>), [[UV15:%[0-9]+]]:_(<4 x s32>) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) +- ; CHECK-NEXT: G_STORE [[UV]](<4 x s32>), [[COPY]](p1) :: (store (<4 x s32>), align 4, addrspace 1) +- ; CHECK-NEXT: [[C:%[0-9]+]]:_(s64) = G_CONSTANT i64 16 +- ; CHECK-NEXT: [[PTR_ADD:%[0-9]+]]:_(p1) = G_PTR_ADD [[COPY]], [[C]](s64) +-@@ -241,8 +243,10 @@ +- ; CHECK-NEXT: [[C14:%[0-9]+]]:_(s64) = G_CONSTANT i64 240 +- ; CHECK-NEXT: [[PTR_ADD14:%[0-9]+]]:_(p1) = G_PTR_ADD [[COPY]], [[C14]](s64) +- ; CHECK-NEXT: G_STORE [[UV15]](<4 x s32>), [[PTR_ADD14]](p1) :: (store (<4 x s32>) into unknown-address + 240, align 4, addrspace 1) +-- ; CHECK-NEXT: [[UV16:%[0-9]+]]:_(<4 x s32>), [[UV17:%[0-9]+]]:_(<4 x s32>), [[UV18:%[0-9]+]]:_(<4 x s32>), [[UV19:%[0-9]+]]:_(<4 x s32>), [[UV20:%[0-9]+]]:_(<4 x s32>), [[UV21:%[0-9]+]]:_(<4 x s32>), [[UV22:%[0-9]+]]:_(<4 x s32>), [[UV23:%[0-9]+]]:_(<4 x s32>) = G_UNMERGE_VALUES [[DEF]](<32 x s32>) +-- ; CHECK-NEXT: [[UV24:%[0-9]+]]:_(<4 x s32>), [[UV25:%[0-9]+]]:_(<4 x s32>), [[UV26:%[0-9]+]]:_(<4 x s32>), [[UV27:%[0-9]+]]:_(<4 x s32>), [[UV28:%[0-9]+]]:_(<4 x s32>), [[UV29:%[0-9]+]]:_(<4 x s32>), [[UV30:%[0-9]+]]:_(<4 x s32>), [[UV31:%[0-9]+]]:_(<4 x s32>) = G_UNMERGE_VALUES [[DEF]](<32 x s32>) +-+ ; CHECK-NEXT: [[UV16:%[0-9]+]]:_(<4 x s32>), [[UV17:%[0-9]+]]:_(<4 x s32>), [[UV18:%[0-9]+]]:_(<4 x s32>), [[UV19:%[0-9]+]]:_(<4 x s32>) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) +-+ ; CHECK-NEXT: [[UV20:%[0-9]+]]:_(<4 x s32>), [[UV21:%[0-9]+]]:_(<4 x s32>), [[UV22:%[0-9]+]]:_(<4 x s32>), [[UV23:%[0-9]+]]:_(<4 x s32>) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) +-+ ; CHECK-NEXT: [[UV24:%[0-9]+]]:_(<4 x s32>), [[UV25:%[0-9]+]]:_(<4 x s32>), [[UV26:%[0-9]+]]:_(<4 x s32>), [[UV27:%[0-9]+]]:_(<4 x s32>) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) +-+ ; CHECK-NEXT: [[UV28:%[0-9]+]]:_(<4 x s32>), [[UV29:%[0-9]+]]:_(<4 x s32>), [[UV30:%[0-9]+]]:_(<4 x s32>), [[UV31:%[0-9]+]]:_(<4 x s32>) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) +- ; CHECK-NEXT: G_STORE [[UV16]](<4 x s32>), [[COPY1]](p1) :: (store (<4 x s32>), align 4, addrspace 1) +- ; CHECK-NEXT: [[PTR_ADD15:%[0-9]+]]:_(p1) = G_PTR_ADD [[COPY1]], [[C]](s64) +- ; CHECK-NEXT: G_STORE [[UV17]](<4 x s32>), [[PTR_ADD15]](p1) :: (store (<4 x s32>) into unknown-address + 16, align 4, addrspace 1) +-diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-phi.mir b/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-phi.mir +---- a/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-phi.mir +-+++ b/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-phi.mir +-@@ -673,86 +673,88 @@ +- ; CHECK-NEXT: successors: %bb.1(0x40000000), %bb.2(0x40000000) +- ; CHECK-NEXT: liveins: $vgpr0_vgpr1_vgpr2_vgpr3, $vgpr4 +- ; CHECK-NEXT: {{ $}} +-- ; CHECK-NEXT: [[DEF:%[0-9]+]]:_(<32 x s32>) = G_IMPLICIT_DEF +-+ ; CHECK-NEXT: [[DEF:%[0-9]+]]:_(<16 x s32>) = G_IMPLICIT_DEF +- ; CHECK-NEXT: [[COPY:%[0-9]+]]:_(s32) = COPY $vgpr4 +- ; CHECK-NEXT: [[C:%[0-9]+]]:_(s32) = G_CONSTANT i32 0 +- ; CHECK-NEXT: [[ICMP:%[0-9]+]]:_(s1) = G_ICMP intpred(eq), [[COPY]](s32), [[C]] +-- ; CHECK-NEXT: [[UV:%[0-9]+]]:_(<16 x s32>), [[UV1:%[0-9]+]]:_(<16 x s32>) = G_UNMERGE_VALUES [[DEF]](<32 x s32>) +-- ; CHECK-NEXT: [[UV2:%[0-9]+]]:_(<16 x s32>), [[UV3:%[0-9]+]]:_(<16 x s32>) = G_UNMERGE_VALUES [[DEF]](<32 x s32>) +- ; CHECK-NEXT: G_BRCOND [[ICMP]](s1), %bb.1 +- ; CHECK-NEXT: G_BR %bb.2 +- ; CHECK-NEXT: {{ $}} +- ; CHECK-NEXT: bb.1: +- ; CHECK-NEXT: successors: %bb.2(0x80000000) +- ; CHECK-NEXT: {{ $}} +-- ; CHECK-NEXT: [[UV4:%[0-9]+]]:_(s32), [[UV5:%[0-9]+]]:_(s32), [[UV6:%[0-9]+]]:_(s32), [[UV7:%[0-9]+]]:_(s32), [[UV8:%[0-9]+]]:_(s32), [[UV9:%[0-9]+]]:_(s32), [[UV10:%[0-9]+]]:_(s32), [[UV11:%[0-9]+]]:_(s32), [[UV12:%[0-9]+]]:_(s32), [[UV13:%[0-9]+]]:_(s32), [[UV14:%[0-9]+]]:_(s32), [[UV15:%[0-9]+]]:_(s32), [[UV16:%[0-9]+]]:_(s32), [[UV17:%[0-9]+]]:_(s32), [[UV18:%[0-9]+]]:_(s32), [[UV19:%[0-9]+]]:_(s32), [[UV20:%[0-9]+]]:_(s32), [[UV21:%[0-9]+]]:_(s32), [[UV22:%[0-9]+]]:_(s32), [[UV23:%[0-9]+]]:_(s32), [[UV24:%[0-9]+]]:_(s32), [[UV25:%[0-9]+]]:_(s32), [[UV26:%[0-9]+]]:_(s32), [[UV27:%[0-9]+]]:_(s32), [[UV28:%[0-9]+]]:_(s32), [[UV29:%[0-9]+]]:_(s32), [[UV30:%[0-9]+]]:_(s32), [[UV31:%[0-9]+]]:_(s32), [[UV32:%[0-9]+]]:_(s32), [[UV33:%[0-9]+]]:_(s32), [[UV34:%[0-9]+]]:_(s32), [[UV35:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<32 x s32>) +-- ; CHECK-NEXT: [[UV36:%[0-9]+]]:_(s32), [[UV37:%[0-9]+]]:_(s32), [[UV38:%[0-9]+]]:_(s32), [[UV39:%[0-9]+]]:_(s32), [[UV40:%[0-9]+]]:_(s32), [[UV41:%[0-9]+]]:_(s32), [[UV42:%[0-9]+]]:_(s32), [[UV43:%[0-9]+]]:_(s32), [[UV44:%[0-9]+]]:_(s32), [[UV45:%[0-9]+]]:_(s32), [[UV46:%[0-9]+]]:_(s32), [[UV47:%[0-9]+]]:_(s32), [[UV48:%[0-9]+]]:_(s32), [[UV49:%[0-9]+]]:_(s32), [[UV50:%[0-9]+]]:_(s32), [[UV51:%[0-9]+]]:_(s32), [[UV52:%[0-9]+]]:_(s32), [[UV53:%[0-9]+]]:_(s32), [[UV54:%[0-9]+]]:_(s32), [[UV55:%[0-9]+]]:_(s32), [[UV56:%[0-9]+]]:_(s32), [[UV57:%[0-9]+]]:_(s32), [[UV58:%[0-9]+]]:_(s32), [[UV59:%[0-9]+]]:_(s32), [[UV60:%[0-9]+]]:_(s32), [[UV61:%[0-9]+]]:_(s32), [[UV62:%[0-9]+]]:_(s32), [[UV63:%[0-9]+]]:_(s32), [[UV64:%[0-9]+]]:_(s32), [[UV65:%[0-9]+]]:_(s32), [[UV66:%[0-9]+]]:_(s32), [[UV67:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<32 x s32>) +-- ; CHECK-NEXT: [[UV68:%[0-9]+]]:_(s32), [[UV69:%[0-9]+]]:_(s32), [[UV70:%[0-9]+]]:_(s32), [[UV71:%[0-9]+]]:_(s32), [[UV72:%[0-9]+]]:_(s32), [[UV73:%[0-9]+]]:_(s32), [[UV74:%[0-9]+]]:_(s32), [[UV75:%[0-9]+]]:_(s32), [[UV76:%[0-9]+]]:_(s32), [[UV77:%[0-9]+]]:_(s32), [[UV78:%[0-9]+]]:_(s32), [[UV79:%[0-9]+]]:_(s32), [[UV80:%[0-9]+]]:_(s32), [[UV81:%[0-9]+]]:_(s32), [[UV82:%[0-9]+]]:_(s32), [[UV83:%[0-9]+]]:_(s32), [[UV84:%[0-9]+]]:_(s32), [[UV85:%[0-9]+]]:_(s32), [[UV86:%[0-9]+]]:_(s32), [[UV87:%[0-9]+]]:_(s32), [[UV88:%[0-9]+]]:_(s32), [[UV89:%[0-9]+]]:_(s32), [[UV90:%[0-9]+]]:_(s32), [[UV91:%[0-9]+]]:_(s32), [[UV92:%[0-9]+]]:_(s32), [[UV93:%[0-9]+]]:_(s32), [[UV94:%[0-9]+]]:_(s32), [[UV95:%[0-9]+]]:_(s32), [[UV96:%[0-9]+]]:_(s32), [[UV97:%[0-9]+]]:_(s32), [[UV98:%[0-9]+]]:_(s32), [[UV99:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<32 x s32>) +-- ; CHECK-NEXT: [[UV100:%[0-9]+]]:_(s32), [[UV101:%[0-9]+]]:_(s32), [[UV102:%[0-9]+]]:_(s32), [[UV103:%[0-9]+]]:_(s32), [[UV104:%[0-9]+]]:_(s32), [[UV105:%[0-9]+]]:_(s32), [[UV106:%[0-9]+]]:_(s32), [[UV107:%[0-9]+]]:_(s32), [[UV108:%[0-9]+]]:_(s32), [[UV109:%[0-9]+]]:_(s32), [[UV110:%[0-9]+]]:_(s32), [[UV111:%[0-9]+]]:_(s32), [[UV112:%[0-9]+]]:_(s32), [[UV113:%[0-9]+]]:_(s32), [[UV114:%[0-9]+]]:_(s32), [[UV115:%[0-9]+]]:_(s32), [[UV116:%[0-9]+]]:_(s32), [[UV117:%[0-9]+]]:_(s32), [[UV118:%[0-9]+]]:_(s32), [[UV119:%[0-9]+]]:_(s32), [[UV120:%[0-9]+]]:_(s32), [[UV121:%[0-9]+]]:_(s32), [[UV122:%[0-9]+]]:_(s32), [[UV123:%[0-9]+]]:_(s32), [[UV124:%[0-9]+]]:_(s32), [[UV125:%[0-9]+]]:_(s32), [[UV126:%[0-9]+]]:_(s32), [[UV127:%[0-9]+]]:_(s32), [[UV128:%[0-9]+]]:_(s32), [[UV129:%[0-9]+]]:_(s32), [[UV130:%[0-9]+]]:_(s32), [[UV131:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<32 x s32>) +-- ; CHECK-NEXT: [[ADD:%[0-9]+]]:_(s32) = G_ADD [[UV4]], [[UV68]] +-- ; CHECK-NEXT: [[ADD1:%[0-9]+]]:_(s32) = G_ADD [[UV5]], [[UV69]] +-- ; CHECK-NEXT: [[ADD2:%[0-9]+]]:_(s32) = G_ADD [[UV6]], [[UV70]] +-- ; CHECK-NEXT: [[ADD3:%[0-9]+]]:_(s32) = G_ADD [[UV7]], [[UV71]] +-- ; CHECK-NEXT: [[ADD4:%[0-9]+]]:_(s32) = G_ADD [[UV8]], [[UV72]] +-- ; CHECK-NEXT: [[ADD5:%[0-9]+]]:_(s32) = G_ADD [[UV9]], [[UV73]] +-- ; CHECK-NEXT: [[ADD6:%[0-9]+]]:_(s32) = G_ADD [[UV10]], [[UV74]] +-- ; CHECK-NEXT: [[ADD7:%[0-9]+]]:_(s32) = G_ADD [[UV11]], [[UV75]] +-- ; CHECK-NEXT: [[ADD8:%[0-9]+]]:_(s32) = G_ADD [[UV12]], [[UV76]] +-- ; CHECK-NEXT: [[ADD9:%[0-9]+]]:_(s32) = G_ADD [[UV13]], [[UV77]] +-- ; CHECK-NEXT: [[ADD10:%[0-9]+]]:_(s32) = G_ADD [[UV14]], [[UV78]] +-- ; CHECK-NEXT: [[ADD11:%[0-9]+]]:_(s32) = G_ADD [[UV15]], [[UV79]] +-- ; CHECK-NEXT: [[ADD12:%[0-9]+]]:_(s32) = G_ADD [[UV16]], [[UV80]] +-- ; CHECK-NEXT: [[ADD13:%[0-9]+]]:_(s32) = G_ADD [[UV17]], [[UV81]] +-- ; CHECK-NEXT: [[ADD14:%[0-9]+]]:_(s32) = G_ADD [[UV18]], [[UV82]] +-- ; CHECK-NEXT: [[ADD15:%[0-9]+]]:_(s32) = G_ADD [[UV19]], [[UV83]] +-- ; CHECK-NEXT: [[ADD16:%[0-9]+]]:_(s32) = G_ADD [[UV20]], [[UV84]] +-- ; CHECK-NEXT: [[ADD17:%[0-9]+]]:_(s32) = G_ADD [[UV21]], [[UV85]] +-- ; CHECK-NEXT: [[ADD18:%[0-9]+]]:_(s32) = G_ADD [[UV22]], [[UV86]] +-- ; CHECK-NEXT: [[ADD19:%[0-9]+]]:_(s32) = G_ADD [[UV23]], [[UV87]] +-- ; CHECK-NEXT: [[ADD20:%[0-9]+]]:_(s32) = G_ADD [[UV24]], [[UV88]] +-- ; CHECK-NEXT: [[ADD21:%[0-9]+]]:_(s32) = G_ADD [[UV25]], [[UV89]] +-- ; CHECK-NEXT: [[ADD22:%[0-9]+]]:_(s32) = G_ADD [[UV26]], [[UV90]] +-- ; CHECK-NEXT: [[ADD23:%[0-9]+]]:_(s32) = G_ADD [[UV27]], [[UV91]] +-- ; CHECK-NEXT: [[ADD24:%[0-9]+]]:_(s32) = G_ADD [[UV28]], [[UV92]] +-- ; CHECK-NEXT: [[ADD25:%[0-9]+]]:_(s32) = G_ADD [[UV29]], [[UV93]] +-- ; CHECK-NEXT: [[ADD26:%[0-9]+]]:_(s32) = G_ADD [[UV30]], [[UV94]] +-- ; CHECK-NEXT: [[ADD27:%[0-9]+]]:_(s32) = G_ADD [[UV31]], [[UV95]] +-- ; CHECK-NEXT: [[ADD28:%[0-9]+]]:_(s32) = G_ADD [[UV32]], [[UV96]] +-- ; CHECK-NEXT: [[ADD29:%[0-9]+]]:_(s32) = G_ADD [[UV33]], [[UV97]] +-- ; CHECK-NEXT: [[ADD30:%[0-9]+]]:_(s32) = G_ADD [[UV34]], [[UV98]] +-- ; CHECK-NEXT: [[ADD31:%[0-9]+]]:_(s32) = G_ADD [[UV35]], [[UV99]] +-- ; CHECK-NEXT: [[ADD32:%[0-9]+]]:_(s32) = G_ADD [[UV36]], [[UV100]] +-- ; CHECK-NEXT: [[ADD33:%[0-9]+]]:_(s32) = G_ADD [[UV37]], [[UV101]] +-- ; CHECK-NEXT: [[ADD34:%[0-9]+]]:_(s32) = G_ADD [[UV38]], [[UV102]] +-- ; CHECK-NEXT: [[ADD35:%[0-9]+]]:_(s32) = G_ADD [[UV39]], [[UV103]] +-- ; CHECK-NEXT: [[ADD36:%[0-9]+]]:_(s32) = G_ADD [[UV40]], [[UV104]] +-- ; CHECK-NEXT: [[ADD37:%[0-9]+]]:_(s32) = G_ADD [[UV41]], [[UV105]] +-- ; CHECK-NEXT: [[ADD38:%[0-9]+]]:_(s32) = G_ADD [[UV42]], [[UV106]] +-- ; CHECK-NEXT: [[ADD39:%[0-9]+]]:_(s32) = G_ADD [[UV43]], [[UV107]] +-- ; CHECK-NEXT: [[ADD40:%[0-9]+]]:_(s32) = G_ADD [[UV44]], [[UV108]] +-- ; CHECK-NEXT: [[ADD41:%[0-9]+]]:_(s32) = G_ADD [[UV45]], [[UV109]] +-- ; CHECK-NEXT: [[ADD42:%[0-9]+]]:_(s32) = G_ADD [[UV46]], [[UV110]] +-- ; CHECK-NEXT: [[ADD43:%[0-9]+]]:_(s32) = G_ADD [[UV47]], [[UV111]] +-- ; CHECK-NEXT: [[ADD44:%[0-9]+]]:_(s32) = G_ADD [[UV48]], [[UV112]] +-- ; CHECK-NEXT: [[ADD45:%[0-9]+]]:_(s32) = G_ADD [[UV49]], [[UV113]] +-- ; CHECK-NEXT: [[ADD46:%[0-9]+]]:_(s32) = G_ADD [[UV50]], [[UV114]] +-- ; CHECK-NEXT: [[ADD47:%[0-9]+]]:_(s32) = G_ADD [[UV51]], [[UV115]] +-- ; CHECK-NEXT: [[ADD48:%[0-9]+]]:_(s32) = G_ADD [[UV52]], [[UV116]] +-- ; CHECK-NEXT: [[ADD49:%[0-9]+]]:_(s32) = G_ADD [[UV53]], [[UV117]] +-- ; CHECK-NEXT: [[ADD50:%[0-9]+]]:_(s32) = G_ADD [[UV54]], [[UV118]] +-- ; CHECK-NEXT: [[ADD51:%[0-9]+]]:_(s32) = G_ADD [[UV55]], [[UV119]] +-- ; CHECK-NEXT: [[ADD52:%[0-9]+]]:_(s32) = G_ADD [[UV56]], [[UV120]] +-- ; CHECK-NEXT: [[ADD53:%[0-9]+]]:_(s32) = G_ADD [[UV57]], [[UV121]] +-- ; CHECK-NEXT: [[ADD54:%[0-9]+]]:_(s32) = G_ADD [[UV58]], [[UV122]] +-- ; CHECK-NEXT: [[ADD55:%[0-9]+]]:_(s32) = G_ADD [[UV59]], [[UV123]] +-- ; CHECK-NEXT: [[ADD56:%[0-9]+]]:_(s32) = G_ADD [[UV60]], [[UV124]] +-- ; CHECK-NEXT: [[ADD57:%[0-9]+]]:_(s32) = G_ADD [[UV61]], [[UV125]] +-- ; CHECK-NEXT: [[ADD58:%[0-9]+]]:_(s32) = G_ADD [[UV62]], [[UV126]] +-- ; CHECK-NEXT: [[ADD59:%[0-9]+]]:_(s32) = G_ADD [[UV63]], [[UV127]] +-- ; CHECK-NEXT: [[ADD60:%[0-9]+]]:_(s32) = G_ADD [[UV64]], [[UV128]] +-- ; CHECK-NEXT: [[ADD61:%[0-9]+]]:_(s32) = G_ADD [[UV65]], [[UV129]] +-- ; CHECK-NEXT: [[ADD62:%[0-9]+]]:_(s32) = G_ADD [[UV66]], [[UV130]] +-- ; CHECK-NEXT: [[ADD63:%[0-9]+]]:_(s32) = G_ADD [[UV67]], [[UV131]] +-+ ; CHECK-NEXT: [[UV:%[0-9]+]]:_(s32), [[UV1:%[0-9]+]]:_(s32), [[UV2:%[0-9]+]]:_(s32), [[UV3:%[0-9]+]]:_(s32), [[UV4:%[0-9]+]]:_(s32), [[UV5:%[0-9]+]]:_(s32), [[UV6:%[0-9]+]]:_(s32), [[UV7:%[0-9]+]]:_(s32), [[UV8:%[0-9]+]]:_(s32), [[UV9:%[0-9]+]]:_(s32), [[UV10:%[0-9]+]]:_(s32), [[UV11:%[0-9]+]]:_(s32), [[UV12:%[0-9]+]]:_(s32), [[UV13:%[0-9]+]]:_(s32), [[UV14:%[0-9]+]]:_(s32), [[UV15:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) +-+ ; CHECK-NEXT: [[UV16:%[0-9]+]]:_(s32), [[UV17:%[0-9]+]]:_(s32), [[UV18:%[0-9]+]]:_(s32), [[UV19:%[0-9]+]]:_(s32), [[UV20:%[0-9]+]]:_(s32), [[UV21:%[0-9]+]]:_(s32), [[UV22:%[0-9]+]]:_(s32), [[UV23:%[0-9]+]]:_(s32), [[UV24:%[0-9]+]]:_(s32), [[UV25:%[0-9]+]]:_(s32), [[UV26:%[0-9]+]]:_(s32), [[UV27:%[0-9]+]]:_(s32), [[UV28:%[0-9]+]]:_(s32), [[UV29:%[0-9]+]]:_(s32), [[UV30:%[0-9]+]]:_(s32), [[UV31:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) +-+ ; CHECK-NEXT: [[UV32:%[0-9]+]]:_(s32), [[UV33:%[0-9]+]]:_(s32), [[UV34:%[0-9]+]]:_(s32), [[UV35:%[0-9]+]]:_(s32), [[UV36:%[0-9]+]]:_(s32), [[UV37:%[0-9]+]]:_(s32), [[UV38:%[0-9]+]]:_(s32), [[UV39:%[0-9]+]]:_(s32), [[UV40:%[0-9]+]]:_(s32), [[UV41:%[0-9]+]]:_(s32), [[UV42:%[0-9]+]]:_(s32), [[UV43:%[0-9]+]]:_(s32), [[UV44:%[0-9]+]]:_(s32), [[UV45:%[0-9]+]]:_(s32), [[UV46:%[0-9]+]]:_(s32), [[UV47:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) +-+ ; CHECK-NEXT: [[UV48:%[0-9]+]]:_(s32), [[UV49:%[0-9]+]]:_(s32), [[UV50:%[0-9]+]]:_(s32), [[UV51:%[0-9]+]]:_(s32), [[UV52:%[0-9]+]]:_(s32), [[UV53:%[0-9]+]]:_(s32), [[UV54:%[0-9]+]]:_(s32), [[UV55:%[0-9]+]]:_(s32), [[UV56:%[0-9]+]]:_(s32), [[UV57:%[0-9]+]]:_(s32), [[UV58:%[0-9]+]]:_(s32), [[UV59:%[0-9]+]]:_(s32), [[UV60:%[0-9]+]]:_(s32), [[UV61:%[0-9]+]]:_(s32), [[UV62:%[0-9]+]]:_(s32), [[UV63:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) +-+ ; CHECK-NEXT: [[UV64:%[0-9]+]]:_(s32), [[UV65:%[0-9]+]]:_(s32), [[UV66:%[0-9]+]]:_(s32), [[UV67:%[0-9]+]]:_(s32), [[UV68:%[0-9]+]]:_(s32), [[UV69:%[0-9]+]]:_(s32), [[UV70:%[0-9]+]]:_(s32), [[UV71:%[0-9]+]]:_(s32), [[UV72:%[0-9]+]]:_(s32), [[UV73:%[0-9]+]]:_(s32), [[UV74:%[0-9]+]]:_(s32), [[UV75:%[0-9]+]]:_(s32), [[UV76:%[0-9]+]]:_(s32), [[UV77:%[0-9]+]]:_(s32), [[UV78:%[0-9]+]]:_(s32), [[UV79:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) +-+ ; CHECK-NEXT: [[UV80:%[0-9]+]]:_(s32), [[UV81:%[0-9]+]]:_(s32), [[UV82:%[0-9]+]]:_(s32), [[UV83:%[0-9]+]]:_(s32), [[UV84:%[0-9]+]]:_(s32), [[UV85:%[0-9]+]]:_(s32), [[UV86:%[0-9]+]]:_(s32), [[UV87:%[0-9]+]]:_(s32), [[UV88:%[0-9]+]]:_(s32), [[UV89:%[0-9]+]]:_(s32), [[UV90:%[0-9]+]]:_(s32), [[UV91:%[0-9]+]]:_(s32), [[UV92:%[0-9]+]]:_(s32), [[UV93:%[0-9]+]]:_(s32), [[UV94:%[0-9]+]]:_(s32), [[UV95:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) +-+ ; CHECK-NEXT: [[UV96:%[0-9]+]]:_(s32), [[UV97:%[0-9]+]]:_(s32), [[UV98:%[0-9]+]]:_(s32), [[UV99:%[0-9]+]]:_(s32), [[UV100:%[0-9]+]]:_(s32), [[UV101:%[0-9]+]]:_(s32), [[UV102:%[0-9]+]]:_(s32), [[UV103:%[0-9]+]]:_(s32), [[UV104:%[0-9]+]]:_(s32), [[UV105:%[0-9]+]]:_(s32), [[UV106:%[0-9]+]]:_(s32), [[UV107:%[0-9]+]]:_(s32), [[UV108:%[0-9]+]]:_(s32), [[UV109:%[0-9]+]]:_(s32), [[UV110:%[0-9]+]]:_(s32), [[UV111:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) +-+ ; CHECK-NEXT: [[UV112:%[0-9]+]]:_(s32), [[UV113:%[0-9]+]]:_(s32), [[UV114:%[0-9]+]]:_(s32), [[UV115:%[0-9]+]]:_(s32), [[UV116:%[0-9]+]]:_(s32), [[UV117:%[0-9]+]]:_(s32), [[UV118:%[0-9]+]]:_(s32), [[UV119:%[0-9]+]]:_(s32), [[UV120:%[0-9]+]]:_(s32), [[UV121:%[0-9]+]]:_(s32), [[UV122:%[0-9]+]]:_(s32), [[UV123:%[0-9]+]]:_(s32), [[UV124:%[0-9]+]]:_(s32), [[UV125:%[0-9]+]]:_(s32), [[UV126:%[0-9]+]]:_(s32), [[UV127:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) +-+ ; CHECK-NEXT: [[ADD:%[0-9]+]]:_(s32) = G_ADD [[UV]], [[UV64]] +-+ ; CHECK-NEXT: [[ADD1:%[0-9]+]]:_(s32) = G_ADD [[UV1]], [[UV65]] +-+ ; CHECK-NEXT: [[ADD2:%[0-9]+]]:_(s32) = G_ADD [[UV2]], [[UV66]] +-+ ; CHECK-NEXT: [[ADD3:%[0-9]+]]:_(s32) = G_ADD [[UV3]], [[UV67]] +-+ ; CHECK-NEXT: [[ADD4:%[0-9]+]]:_(s32) = G_ADD [[UV4]], [[UV68]] +-+ ; CHECK-NEXT: [[ADD5:%[0-9]+]]:_(s32) = G_ADD [[UV5]], [[UV69]] +-+ ; CHECK-NEXT: [[ADD6:%[0-9]+]]:_(s32) = G_ADD [[UV6]], [[UV70]] +-+ ; CHECK-NEXT: [[ADD7:%[0-9]+]]:_(s32) = G_ADD [[UV7]], [[UV71]] +-+ ; CHECK-NEXT: [[ADD8:%[0-9]+]]:_(s32) = G_ADD [[UV8]], [[UV72]] +-+ ; CHECK-NEXT: [[ADD9:%[0-9]+]]:_(s32) = G_ADD [[UV9]], [[UV73]] +-+ ; CHECK-NEXT: [[ADD10:%[0-9]+]]:_(s32) = G_ADD [[UV10]], [[UV74]] +-+ ; CHECK-NEXT: [[ADD11:%[0-9]+]]:_(s32) = G_ADD [[UV11]], [[UV75]] +-+ ; CHECK-NEXT: [[ADD12:%[0-9]+]]:_(s32) = G_ADD [[UV12]], [[UV76]] +-+ ; CHECK-NEXT: [[ADD13:%[0-9]+]]:_(s32) = G_ADD [[UV13]], [[UV77]] +-+ ; CHECK-NEXT: [[ADD14:%[0-9]+]]:_(s32) = G_ADD [[UV14]], [[UV78]] +-+ ; CHECK-NEXT: [[ADD15:%[0-9]+]]:_(s32) = G_ADD [[UV15]], [[UV79]] +-+ ; CHECK-NEXT: [[ADD16:%[0-9]+]]:_(s32) = G_ADD [[UV16]], [[UV80]] +-+ ; CHECK-NEXT: [[ADD17:%[0-9]+]]:_(s32) = G_ADD [[UV17]], [[UV81]] +-+ ; CHECK-NEXT: [[ADD18:%[0-9]+]]:_(s32) = G_ADD [[UV18]], [[UV82]] +-+ ; CHECK-NEXT: [[ADD19:%[0-9]+]]:_(s32) = G_ADD [[UV19]], [[UV83]] +-+ ; CHECK-NEXT: [[ADD20:%[0-9]+]]:_(s32) = G_ADD [[UV20]], [[UV84]] +-+ ; CHECK-NEXT: [[ADD21:%[0-9]+]]:_(s32) = G_ADD [[UV21]], [[UV85]] +-+ ; CHECK-NEXT: [[ADD22:%[0-9]+]]:_(s32) = G_ADD [[UV22]], [[UV86]] +-+ ; CHECK-NEXT: [[ADD23:%[0-9]+]]:_(s32) = G_ADD [[UV23]], [[UV87]] +-+ ; CHECK-NEXT: [[ADD24:%[0-9]+]]:_(s32) = G_ADD [[UV24]], [[UV88]] +-+ ; CHECK-NEXT: [[ADD25:%[0-9]+]]:_(s32) = G_ADD [[UV25]], [[UV89]] +-+ ; CHECK-NEXT: [[ADD26:%[0-9]+]]:_(s32) = G_ADD [[UV26]], [[UV90]] +-+ ; CHECK-NEXT: [[ADD27:%[0-9]+]]:_(s32) = G_ADD [[UV27]], [[UV91]] +-+ ; CHECK-NEXT: [[ADD28:%[0-9]+]]:_(s32) = G_ADD [[UV28]], [[UV92]] +-+ ; CHECK-NEXT: [[ADD29:%[0-9]+]]:_(s32) = G_ADD [[UV29]], [[UV93]] +-+ ; CHECK-NEXT: [[ADD30:%[0-9]+]]:_(s32) = G_ADD [[UV30]], [[UV94]] +-+ ; CHECK-NEXT: [[ADD31:%[0-9]+]]:_(s32) = G_ADD [[UV31]], [[UV95]] +-+ ; CHECK-NEXT: [[ADD32:%[0-9]+]]:_(s32) = G_ADD [[UV32]], [[UV96]] +-+ ; CHECK-NEXT: [[ADD33:%[0-9]+]]:_(s32) = G_ADD [[UV33]], [[UV97]] +-+ ; CHECK-NEXT: [[ADD34:%[0-9]+]]:_(s32) = G_ADD [[UV34]], [[UV98]] +-+ ; CHECK-NEXT: [[ADD35:%[0-9]+]]:_(s32) = G_ADD [[UV35]], [[UV99]] +-+ ; CHECK-NEXT: [[ADD36:%[0-9]+]]:_(s32) = G_ADD [[UV36]], [[UV100]] +-+ ; CHECK-NEXT: [[ADD37:%[0-9]+]]:_(s32) = G_ADD [[UV37]], [[UV101]] +-+ ; CHECK-NEXT: [[ADD38:%[0-9]+]]:_(s32) = G_ADD [[UV38]], [[UV102]] +-+ ; CHECK-NEXT: [[ADD39:%[0-9]+]]:_(s32) = G_ADD [[UV39]], [[UV103]] +-+ ; CHECK-NEXT: [[ADD40:%[0-9]+]]:_(s32) = G_ADD [[UV40]], [[UV104]] +-+ ; CHECK-NEXT: [[ADD41:%[0-9]+]]:_(s32) = G_ADD [[UV41]], [[UV105]] +-+ ; CHECK-NEXT: [[ADD42:%[0-9]+]]:_(s32) = G_ADD [[UV42]], [[UV106]] +-+ ; CHECK-NEXT: [[ADD43:%[0-9]+]]:_(s32) = G_ADD [[UV43]], [[UV107]] +-+ ; CHECK-NEXT: [[ADD44:%[0-9]+]]:_(s32) = G_ADD [[UV44]], [[UV108]] +-+ ; CHECK-NEXT: [[ADD45:%[0-9]+]]:_(s32) = G_ADD [[UV45]], [[UV109]] +-+ ; CHECK-NEXT: [[ADD46:%[0-9]+]]:_(s32) = G_ADD [[UV46]], [[UV110]] +-+ ; CHECK-NEXT: [[ADD47:%[0-9]+]]:_(s32) = G_ADD [[UV47]], [[UV111]] +-+ ; CHECK-NEXT: [[ADD48:%[0-9]+]]:_(s32) = G_ADD [[UV48]], [[UV112]] +-+ ; CHECK-NEXT: [[ADD49:%[0-9]+]]:_(s32) = G_ADD [[UV49]], [[UV113]] +-+ ; CHECK-NEXT: [[ADD50:%[0-9]+]]:_(s32) = G_ADD [[UV50]], [[UV114]] +-+ ; CHECK-NEXT: [[ADD51:%[0-9]+]]:_(s32) = G_ADD [[UV51]], [[UV115]] +-+ ; CHECK-NEXT: [[ADD52:%[0-9]+]]:_(s32) = G_ADD [[UV52]], [[UV116]] +-+ ; CHECK-NEXT: [[ADD53:%[0-9]+]]:_(s32) = G_ADD [[UV53]], [[UV117]] +-+ ; CHECK-NEXT: [[ADD54:%[0-9]+]]:_(s32) = G_ADD [[UV54]], [[UV118]] +-+ ; CHECK-NEXT: [[ADD55:%[0-9]+]]:_(s32) = G_ADD [[UV55]], [[UV119]] +-+ ; CHECK-NEXT: [[ADD56:%[0-9]+]]:_(s32) = G_ADD [[UV56]], [[UV120]] +-+ ; CHECK-NEXT: [[ADD57:%[0-9]+]]:_(s32) = G_ADD [[UV57]], [[UV121]] +-+ ; CHECK-NEXT: [[ADD58:%[0-9]+]]:_(s32) = G_ADD [[UV58]], [[UV122]] +-+ ; CHECK-NEXT: [[ADD59:%[0-9]+]]:_(s32) = G_ADD [[UV59]], [[UV123]] +-+ ; CHECK-NEXT: [[ADD60:%[0-9]+]]:_(s32) = G_ADD [[UV60]], [[UV124]] +-+ ; CHECK-NEXT: [[ADD61:%[0-9]+]]:_(s32) = G_ADD [[UV61]], [[UV125]] +-+ ; CHECK-NEXT: [[ADD62:%[0-9]+]]:_(s32) = G_ADD [[UV62]], [[UV126]] +-+ ; CHECK-NEXT: [[ADD63:%[0-9]+]]:_(s32) = G_ADD [[UV63]], [[UV127]] +- ; CHECK-NEXT: [[BUILD_VECTOR:%[0-9]+]]:_(<16 x s32>) = G_BUILD_VECTOR [[ADD]](s32), [[ADD1]](s32), [[ADD2]](s32), [[ADD3]](s32), [[ADD4]](s32), [[ADD5]](s32), [[ADD6]](s32), [[ADD7]](s32), [[ADD8]](s32), [[ADD9]](s32), [[ADD10]](s32), [[ADD11]](s32), [[ADD12]](s32), [[ADD13]](s32), [[ADD14]](s32), [[ADD15]](s32) +- ; CHECK-NEXT: [[BUILD_VECTOR1:%[0-9]+]]:_(<16 x s32>) = G_BUILD_VECTOR [[ADD16]](s32), [[ADD17]](s32), [[ADD18]](s32), [[ADD19]](s32), [[ADD20]](s32), [[ADD21]](s32), [[ADD22]](s32), [[ADD23]](s32), [[ADD24]](s32), [[ADD25]](s32), [[ADD26]](s32), [[ADD27]](s32), [[ADD28]](s32), [[ADD29]](s32), [[ADD30]](s32), [[ADD31]](s32) +- ; CHECK-NEXT: [[BUILD_VECTOR2:%[0-9]+]]:_(<16 x s32>) = G_BUILD_VECTOR [[ADD32]](s32), [[ADD33]](s32), [[ADD34]](s32), [[ADD35]](s32), [[ADD36]](s32), [[ADD37]](s32), [[ADD38]](s32), [[ADD39]](s32), [[ADD40]](s32), [[ADD41]](s32), [[ADD42]](s32), [[ADD43]](s32), [[ADD44]](s32), [[ADD45]](s32), [[ADD46]](s32), [[ADD47]](s32) +-@@ -760,10 +762,10 @@ +- ; CHECK-NEXT: G_BR %bb.2 +- ; CHECK-NEXT: {{ $}} +- ; CHECK-NEXT: bb.2: +-- ; CHECK-NEXT: [[PHI:%[0-9]+]]:_(<16 x s32>) = G_PHI [[UV]](<16 x s32>), %bb.0, [[BUILD_VECTOR]](<16 x s32>), %bb.1 +-- ; CHECK-NEXT: [[PHI1:%[0-9]+]]:_(<16 x s32>) = G_PHI [[UV1]](<16 x s32>), %bb.0, [[BUILD_VECTOR1]](<16 x s32>), %bb.1 +-- ; CHECK-NEXT: [[PHI2:%[0-9]+]]:_(<16 x s32>) = G_PHI [[UV2]](<16 x s32>), %bb.0, [[BUILD_VECTOR2]](<16 x s32>), %bb.1 +-- ; CHECK-NEXT: [[PHI3:%[0-9]+]]:_(<16 x s32>) = G_PHI [[UV3]](<16 x s32>), %bb.0, [[BUILD_VECTOR3]](<16 x s32>), %bb.1 +-+ ; CHECK-NEXT: [[PHI:%[0-9]+]]:_(<16 x s32>) = G_PHI [[DEF]](<16 x s32>), %bb.0, [[BUILD_VECTOR]](<16 x s32>), %bb.1 +-+ ; CHECK-NEXT: [[PHI1:%[0-9]+]]:_(<16 x s32>) = G_PHI [[DEF]](<16 x s32>), %bb.0, [[BUILD_VECTOR1]](<16 x s32>), %bb.1 +-+ ; CHECK-NEXT: [[PHI2:%[0-9]+]]:_(<16 x s32>) = G_PHI [[DEF]](<16 x s32>), %bb.0, [[BUILD_VECTOR2]](<16 x s32>), %bb.1 +-+ ; CHECK-NEXT: [[PHI3:%[0-9]+]]:_(<16 x s32>) = G_PHI [[DEF]](<16 x s32>), %bb.0, [[BUILD_VECTOR3]](<16 x s32>), %bb.1 +- ; CHECK-NEXT: [[CONCAT_VECTORS:%[0-9]+]]:_(<64 x s32>) = G_CONCAT_VECTORS [[PHI]](<16 x s32>), [[PHI1]](<16 x s32>), [[PHI2]](<16 x s32>), [[PHI3]](<16 x s32>) +- ; CHECK-NEXT: S_SETPC_B64 undef $sgpr30_sgpr31, implicit [[CONCAT_VECTORS]](<64 x s32>) +- bb.0: +-diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/AMDGPU/GlobalISel/regbankselect.mir b/llvm/test/CodeGen/AMDGPU/GlobalISel/regbankselect.mir +---- a/llvm/test/CodeGen/AMDGPU/GlobalISel/regbankselect.mir +-+++ b/llvm/test/CodeGen/AMDGPU/GlobalISel/regbankselect.mir +-@@ -42,6 +42,8 @@ +- ret void +- } +- +-+ define void @non_power_of_2() { ret void } +-+ +- define amdgpu_kernel void @load_constant_v4i16_from_8_align8(ptr addrspace(4) %ptr0) { +- ret void +- } +-@@ -185,6 +187,23 @@ +- ... +- +- --- +-+name: non_power_of_2 +-+legalized: true +-+ +-+body: | +-+ bb.0: +-+ ; CHECK-LABEL: name: non_power_of_2 +-+ ; CHECK: [[DEF:%[0-9]+]]:sgpr(s448) = G_IMPLICIT_DEF +-+ ; CHECK-NEXT: [[EXTRACT:%[0-9]+]]:sgpr(s32) = G_EXTRACT [[DEF]](s448), 0 +-+ ; CHECK-NEXT: $sgpr0 = COPY [[EXTRACT]](s32) +-+ ; CHECK-NEXT: SI_RETURN_TO_EPILOG $sgpr0 +-+ %0:_(s448) = G_IMPLICIT_DEF +-+ %1:_(s32) = G_EXTRACT %0:_(s448), 0 +-+ $sgpr0 = COPY %1:_(s32) +-+ SI_RETURN_TO_EPILOG $sgpr0 +-+... +-+ +-+--- +- name: load_constant_v4i16_from_8_align8 +- legalized: true +- +-diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/NVPTX/intrin-nocapture.ll b/llvm/test/CodeGen/NVPTX/intrin-nocapture.ll +---- a/llvm/test/CodeGen/NVPTX/intrin-nocapture.ll +-+++ b/llvm/test/CodeGen/NVPTX/intrin-nocapture.ll +-@@ -0,0 +1,21 @@ +-+; RUN: opt < %s -O3 -S | FileCheck %s +-+ +-+; Address space intrinsics were erroneously marked NoCapture, leading to bad +-+; optimizations (such as the store below being eliminated as dead code). This +-+; test makes sure we don't regress. +-+ +-+declare void @foo(ptr addrspace(1)) +-+ +-+declare ptr addrspace(1) @llvm.nvvm.ptr.gen.to.global.p1.p0(ptr) +-+ +-+; CHECK: @bar +-+define void @bar() { +-+ %t1 = alloca i32 +-+; CHECK: call ptr addrspace(1) @llvm.nvvm.ptr.gen.to.global.p1.p0(ptr nonnull %t1) +-+; CHECK-NEXT: store i32 10, ptr %t1 +-+ %t2 = call ptr addrspace(1) @llvm.nvvm.ptr.gen.to.global.p1.p0(ptr %t1) +-+ store i32 10, ptr %t1 +-+ call void @foo(ptr addrspace(1) %t2) +-+ ret void +-+} +-+ +-diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/NVPTX/rotate_64.ll b/llvm/test/CodeGen/NVPTX/rotate_64.ll +---- a/llvm/test/CodeGen/NVPTX/rotate_64.ll +-+++ b/llvm/test/CodeGen/NVPTX/rotate_64.ll +-@@ -1,38 +1,25 @@ +--; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5 +- ; RUN: llc < %s -march=nvptx64 | FileCheck %s +- ; RUN: %if ptxas %{ llc < %s -march=nvptx64 | %ptxas-verify %} +- +- declare i64 @llvm.nvvm.rotate.b64(i64, i32) +- declare i64 @llvm.nvvm.rotate.right.b64(i64, i32) +- +-+; CHECK: rotate64 +- define i64 @rotate64(i64 %a, i32 %b) { +--; CHECK-LABEL: rotate64( +--; CHECK: { +--; CHECK-NEXT: .reg .b64 %rd<5>; +--; CHECK-EMPTY: +--; CHECK-NEXT: // %bb.0: +--; CHECK-NEXT: ld.param.u64 %rd1, [rotate64_param_0]; +--; CHECK-NEXT: shr.u64 %rd2, %rd1, 61; +--; CHECK-NEXT: shl.b64 %rd3, %rd1, 3; +--; CHECK-NEXT: or.b64 %rd4, %rd3, %rd2; +--; CHECK-NEXT: st.param.b64 [func_retval0+0], %rd4; +--; CHECK-NEXT: ret; +-+; CHECK: shl.b64 [[LHS:%.*]], [[RD1:%.*]], 3; +-+; CHECK: shr.b64 [[RHS:%.*]], [[RD1]], 61; +-+; CHECK: add.u64 [[RD2:%.*]], [[LHS]], [[RHS]]; +-+; CHECK: ret +- %val = tail call i64 @llvm.nvvm.rotate.b64(i64 %a, i32 3) +- ret i64 %val +- } +- +-+; CHECK: rotateright64 +- define i64 @rotateright64(i64 %a, i32 %b) { +--; CHECK-LABEL: rotateright64( +--; CHECK: { +--; CHECK-NEXT: .reg .b64 %rd<5>; +--; CHECK-EMPTY: +--; CHECK-NEXT: // %bb.0: +--; CHECK-NEXT: ld.param.u64 %rd1, [rotateright64_param_0]; +--; CHECK-NEXT: shl.b64 %rd2, %rd1, 61; +--; CHECK-NEXT: shr.u64 %rd3, %rd1, 3; +--; CHECK-NEXT: or.b64 %rd4, %rd3, %rd2; +--; CHECK-NEXT: st.param.b64 [func_retval0+0], %rd4; +--; CHECK-NEXT: ret; +-+; CHECK: shl.b64 [[LHS:%.*]], [[RD1:%.*]], 61; +-+; CHECK: shr.b64 [[RHS:%.*]], [[RD1]], 3; +-+; CHECK: add.u64 [[RD2:%.*]], [[LHS]], [[RHS]]; +-+; CHECK: ret +- %val = tail call i64 @llvm.nvvm.rotate.right.b64(i64 %a, i32 3) +- ret i64 %val +- } +-diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/NVPTX/rotate.ll b/llvm/test/CodeGen/NVPTX/rotate.ll +---- a/llvm/test/CodeGen/NVPTX/rotate.ll +-+++ b/llvm/test/CodeGen/NVPTX/rotate.ll +-@@ -9,29 +9,26 @@ +- declare i64 @llvm.nvvm.rotate.b64(i64, i32) +- declare i64 @llvm.nvvm.rotate.right.b64(i64, i32) +- +--declare i64 @llvm.fshl.i64(i64, i64, i64) +--declare i64 @llvm.fshr.i64(i64, i64, i64) +--declare i32 @llvm.fshl.i32(i32, i32, i32) +--declare i32 @llvm.fshr.i32(i32, i32, i32) +-- +-- +- ; SM20: rotate32 +- ; SM35: rotate32 +- define i32 @rotate32(i32 %a, i32 %b) { +- ; SM20-LABEL: rotate32( +- ; SM20: { +--; SM20-NEXT: .reg .b32 %r<9>; +-+; SM20-NEXT: .reg .b32 %r<4>; +- ; SM20-EMPTY: +- ; SM20-NEXT: // %bb.0: +- ; SM20-NEXT: ld.param.u32 %r1, [rotate32_param_0]; +- ; SM20-NEXT: ld.param.u32 %r2, [rotate32_param_1]; +--; SM20-NEXT: and.b32 %r3, %r2, 31; +--; SM20-NEXT: shl.b32 %r4, %r1, %r3; +--; SM20-NEXT: neg.s32 %r5, %r2; +--; SM20-NEXT: and.b32 %r6, %r5, 31; +--; SM20-NEXT: shr.u32 %r7, %r1, %r6; +--; SM20-NEXT: or.b32 %r8, %r4, %r7; +--; SM20-NEXT: st.param.b32 [func_retval0+0], %r8; +-+; SM20-NEXT: { +-+; SM20-NEXT: .reg .b32 %lhs; +-+; SM20-NEXT: .reg .b32 %rhs; +-+; SM20-NEXT: .reg .b32 %amt2; +-+; SM20-NEXT: shl.b32 %lhs, %r1, %r2; +-+; SM20-NEXT: sub.s32 %amt2, 32, %r2; +-+; SM20-NEXT: shr.b32 %rhs, %r1, %amt2; +-+; SM20-NEXT: add.u32 %r3, %lhs, %rhs; +-+; SM20-NEXT: } +-+; SM20-NEXT: st.param.b32 [func_retval0+0], %r3; +- ; SM20-NEXT: ret; +- ; +- ; SM35-LABEL: rotate32( +-@@ -53,36 +50,45 @@ +- define i64 @rotate64(i64 %a, i32 %b) { +- ; SM20-LABEL: rotate64( +- ; SM20: { +--; SM20-NEXT: .reg .b32 %r<5>; +--; SM20-NEXT: .reg .b64 %rd<5>; +-+; SM20-NEXT: .reg .b32 %r<2>; +-+; SM20-NEXT: .reg .b64 %rd<3>; +- ; SM20-EMPTY: +- ; SM20-NEXT: // %bb.0: +- ; SM20-NEXT: ld.param.u64 %rd1, [rotate64_param_0]; +- ; SM20-NEXT: ld.param.u32 %r1, [rotate64_param_1]; +--; SM20-NEXT: and.b32 %r2, %r1, 63; +--; SM20-NEXT: shl.b64 %rd2, %rd1, %r2; +--; SM20-NEXT: neg.s32 %r3, %r1; +--; SM20-NEXT: and.b32 %r4, %r3, 63; +--; SM20-NEXT: shr.u64 %rd3, %rd1, %r4; +--; SM20-NEXT: or.b64 %rd4, %rd2, %rd3; +--; SM20-NEXT: st.param.b64 [func_retval0+0], %rd4; +-+; SM20-NEXT: { +-+; SM20-NEXT: .reg .b64 %lhs; +-+; SM20-NEXT: .reg .b64 %rhs; +-+; SM20-NEXT: .reg .u32 %amt2; +-+; SM20-NEXT: and.b32 %amt2, %r1, 63; +-+; SM20-NEXT: shl.b64 %lhs, %rd1, %amt2; +-+; SM20-NEXT: sub.u32 %amt2, 64, %amt2; +-+; SM20-NEXT: shr.b64 %rhs, %rd1, %amt2; +-+; SM20-NEXT: add.u64 %rd2, %lhs, %rhs; +-+; SM20-NEXT: } +-+; SM20-NEXT: st.param.b64 [func_retval0+0], %rd2; +- ; SM20-NEXT: ret; +- ; +- ; SM35-LABEL: rotate64( +- ; SM35: { +--; SM35-NEXT: .reg .b32 %r<5>; +--; SM35-NEXT: .reg .b64 %rd<5>; +-+; SM35-NEXT: .reg .b32 %r<6>; +-+; SM35-NEXT: .reg .b64 %rd<3>; +- ; SM35-EMPTY: +- ; SM35-NEXT: // %bb.0: +- ; SM35-NEXT: ld.param.u64 %rd1, [rotate64_param_0]; +--; SM35-NEXT: ld.param.u32 %r1, [rotate64_param_1]; +--; SM35-NEXT: and.b32 %r2, %r1, 63; +--; SM35-NEXT: shl.b64 %rd2, %rd1, %r2; +--; SM35-NEXT: neg.s32 %r3, %r1; +--; SM35-NEXT: and.b32 %r4, %r3, 63; +--; SM35-NEXT: shr.u64 %rd3, %rd1, %r4; +--; SM35-NEXT: or.b64 %rd4, %rd2, %rd3; +--; SM35-NEXT: st.param.b64 [func_retval0+0], %rd4; +-+; SM35-NEXT: { +-+; SM35-NEXT: .reg .b32 %dummy; +-+; SM35-NEXT: mov.b64 {%dummy,%r1}, %rd1; +-+; SM35-NEXT: } +-+; SM35-NEXT: { +-+; SM35-NEXT: .reg .b32 %dummy; +-+; SM35-NEXT: mov.b64 {%r2,%dummy}, %rd1; +-+; SM35-NEXT: } +-+; SM35-NEXT: ld.param.u32 %r3, [rotate64_param_1]; +-+; SM35-NEXT: shf.l.wrap.b32 %r4, %r2, %r1, %r3; +-+; SM35-NEXT: shf.l.wrap.b32 %r5, %r1, %r2, %r3; +-+; SM35-NEXT: mov.b64 %rd2, {%r5, %r4}; +-+; SM35-NEXT: st.param.b64 [func_retval0+0], %rd2; +- ; SM35-NEXT: ret; +- %val = tail call i64 @llvm.nvvm.rotate.b64(i64 %a, i32 %b) +- ret i64 %val +-@@ -93,36 +99,45 @@ +- define i64 @rotateright64(i64 %a, i32 %b) { +- ; SM20-LABEL: rotateright64( +- ; SM20: { +--; SM20-NEXT: .reg .b32 %r<5>; +--; SM20-NEXT: .reg .b64 %rd<5>; +-+; SM20-NEXT: .reg .b32 %r<2>; +-+; SM20-NEXT: .reg .b64 %rd<3>; +- ; SM20-EMPTY: +- ; SM20-NEXT: // %bb.0: +- ; SM20-NEXT: ld.param.u64 %rd1, [rotateright64_param_0]; +- ; SM20-NEXT: ld.param.u32 %r1, [rotateright64_param_1]; +--; SM20-NEXT: and.b32 %r2, %r1, 63; +--; SM20-NEXT: shr.u64 %rd2, %rd1, %r2; +--; SM20-NEXT: neg.s32 %r3, %r1; +--; SM20-NEXT: and.b32 %r4, %r3, 63; +--; SM20-NEXT: shl.b64 %rd3, %rd1, %r4; +--; SM20-NEXT: or.b64 %rd4, %rd2, %rd3; +--; SM20-NEXT: st.param.b64 [func_retval0+0], %rd4; +-+; SM20-NEXT: { +-+; SM20-NEXT: .reg .b64 %lhs; +-+; SM20-NEXT: .reg .b64 %rhs; +-+; SM20-NEXT: .reg .u32 %amt2; +-+; SM20-NEXT: and.b32 %amt2, %r1, 63; +-+; SM20-NEXT: shr.b64 %lhs, %rd1, %amt2; +-+; SM20-NEXT: sub.u32 %amt2, 64, %amt2; +-+; SM20-NEXT: shl.b64 %rhs, %rd1, %amt2; +-+; SM20-NEXT: add.u64 %rd2, %lhs, %rhs; +-+; SM20-NEXT: } +-+; SM20-NEXT: st.param.b64 [func_retval0+0], %rd2; +- ; SM20-NEXT: ret; +- ; +- ; SM35-LABEL: rotateright64( +- ; SM35: { +--; SM35-NEXT: .reg .b32 %r<5>; +--; SM35-NEXT: .reg .b64 %rd<5>; +-+; SM35-NEXT: .reg .b32 %r<6>; +-+; SM35-NEXT: .reg .b64 %rd<3>; +- ; SM35-EMPTY: +- ; SM35-NEXT: // %bb.0: +- ; SM35-NEXT: ld.param.u64 %rd1, [rotateright64_param_0]; +--; SM35-NEXT: ld.param.u32 %r1, [rotateright64_param_1]; +--; SM35-NEXT: and.b32 %r2, %r1, 63; +--; SM35-NEXT: shr.u64 %rd2, %rd1, %r2; +--; SM35-NEXT: neg.s32 %r3, %r1; +--; SM35-NEXT: and.b32 %r4, %r3, 63; +--; SM35-NEXT: shl.b64 %rd3, %rd1, %r4; +--; SM35-NEXT: or.b64 %rd4, %rd2, %rd3; +--; SM35-NEXT: st.param.b64 [func_retval0+0], %rd4; +-+; SM35-NEXT: { +-+; SM35-NEXT: .reg .b32 %dummy; +-+; SM35-NEXT: mov.b64 {%r1,%dummy}, %rd1; +-+; SM35-NEXT: } +-+; SM35-NEXT: { +-+; SM35-NEXT: .reg .b32 %dummy; +-+; SM35-NEXT: mov.b64 {%dummy,%r2}, %rd1; +-+; SM35-NEXT: } +-+; SM35-NEXT: ld.param.u32 %r3, [rotateright64_param_1]; +-+; SM35-NEXT: shf.r.wrap.b32 %r4, %r2, %r1, %r3; +-+; SM35-NEXT: shf.r.wrap.b32 %r5, %r1, %r2, %r3; +-+; SM35-NEXT: mov.b64 %rd2, {%r5, %r4}; +-+; SM35-NEXT: st.param.b64 [func_retval0+0], %rd2; +- ; SM35-NEXT: ret; +- %val = tail call i64 @llvm.nvvm.rotate.right.b64(i64 %a, i32 %b) +- ret i64 %val +-@@ -133,14 +148,18 @@ +- define i32 @rotl0(i32 %x) { +- ; SM20-LABEL: rotl0( +- ; SM20: { +--; SM20-NEXT: .reg .b32 %r<5>; +-+; SM20-NEXT: .reg .b32 %r<3>; +- ; SM20-EMPTY: +- ; SM20-NEXT: // %bb.0: +- ; SM20-NEXT: ld.param.u32 %r1, [rotl0_param_0]; +--; SM20-NEXT: shr.u32 %r2, %r1, 24; +--; SM20-NEXT: shl.b32 %r3, %r1, 8; +--; SM20-NEXT: or.b32 %r4, %r3, %r2; +--; SM20-NEXT: st.param.b32 [func_retval0+0], %r4; +-+; SM20-NEXT: { +-+; SM20-NEXT: .reg .b32 %lhs; +-+; SM20-NEXT: .reg .b32 %rhs; +-+; SM20-NEXT: shl.b32 %lhs, %r1, 8; +-+; SM20-NEXT: shr.b32 %rhs, %r1, 24; +-+; SM20-NEXT: add.u32 %r2, %lhs, %rhs; +-+; SM20-NEXT: } +-+; SM20-NEXT: st.param.b32 [func_retval0+0], %r2; +- ; SM20-NEXT: ret; +- ; +- ; SM35-LABEL: rotl0( +-@@ -158,40 +177,51 @@ +- ret i32 %t2 +- } +- +-+declare i64 @llvm.fshl.i64(i64, i64, i64) +-+declare i64 @llvm.fshr.i64(i64, i64, i64) +-+ +- ; SM35: rotl64 +- define i64 @rotl64(i64 %a, i64 %n) { +- ; SM20-LABEL: rotl64( +- ; SM20: { +--; SM20-NEXT: .reg .b32 %r<5>; +--; SM20-NEXT: .reg .b64 %rd<5>; +-+; SM20-NEXT: .reg .b32 %r<2>; +-+; SM20-NEXT: .reg .b64 %rd<3>; +- ; SM20-EMPTY: +- ; SM20-NEXT: // %bb.0: +- ; SM20-NEXT: ld.param.u64 %rd1, [rotl64_param_0]; +- ; SM20-NEXT: ld.param.u32 %r1, [rotl64_param_1]; +--; SM20-NEXT: and.b32 %r2, %r1, 63; +--; SM20-NEXT: shl.b64 %rd2, %rd1, %r2; +--; SM20-NEXT: neg.s32 %r3, %r1; +--; SM20-NEXT: and.b32 %r4, %r3, 63; +--; SM20-NEXT: shr.u64 %rd3, %rd1, %r4; +--; SM20-NEXT: or.b64 %rd4, %rd2, %rd3; +--; SM20-NEXT: st.param.b64 [func_retval0+0], %rd4; +-+; SM20-NEXT: { +-+; SM20-NEXT: .reg .b64 %lhs; +-+; SM20-NEXT: .reg .b64 %rhs; +-+; SM20-NEXT: .reg .u32 %amt2; +-+; SM20-NEXT: and.b32 %amt2, %r1, 63; +-+; SM20-NEXT: shl.b64 %lhs, %rd1, %amt2; +-+; SM20-NEXT: sub.u32 %amt2, 64, %amt2; +-+; SM20-NEXT: shr.b64 %rhs, %rd1, %amt2; +-+; SM20-NEXT: add.u64 %rd2, %lhs, %rhs; +-+; SM20-NEXT: } +-+; SM20-NEXT: st.param.b64 [func_retval0+0], %rd2; +- ; SM20-NEXT: ret; +- ; +- ; SM35-LABEL: rotl64( +- ; SM35: { +--; SM35-NEXT: .reg .b32 %r<5>; +--; SM35-NEXT: .reg .b64 %rd<5>; +-+; SM35-NEXT: .reg .b32 %r<2>; +-+; SM35-NEXT: .reg .b64 %rd<3>; +- ; SM35-EMPTY: +- ; SM35-NEXT: // %bb.0: +- ; SM35-NEXT: ld.param.u64 %rd1, [rotl64_param_0]; +- ; SM35-NEXT: ld.param.u32 %r1, [rotl64_param_1]; +--; SM35-NEXT: and.b32 %r2, %r1, 63; +--; SM35-NEXT: shl.b64 %rd2, %rd1, %r2; +--; SM35-NEXT: neg.s32 %r3, %r1; +--; SM35-NEXT: and.b32 %r4, %r3, 63; +--; SM35-NEXT: shr.u64 %rd3, %rd1, %r4; +--; SM35-NEXT: or.b64 %rd4, %rd2, %rd3; +--; SM35-NEXT: st.param.b64 [func_retval0+0], %rd4; +-+; SM35-NEXT: { +-+; SM35-NEXT: .reg .b64 %lhs; +-+; SM35-NEXT: .reg .b64 %rhs; +-+; SM35-NEXT: .reg .u32 %amt2; +-+; SM35-NEXT: and.b32 %amt2, %r1, 63; +-+; SM35-NEXT: shl.b64 %lhs, %rd1, %amt2; +-+; SM35-NEXT: sub.u32 %amt2, 64, %amt2; +-+; SM35-NEXT: shr.b64 %rhs, %rd1, %amt2; +-+; SM35-NEXT: add.u64 %rd2, %lhs, %rhs; +-+; SM35-NEXT: } +-+; SM35-NEXT: st.param.b64 [func_retval0+0], %rd2; +- ; SM35-NEXT: ret; +- %val = tail call i64 @llvm.fshl.i64(i64 %a, i64 %a, i64 %n) +- ret i64 %val +-@@ -201,26 +231,34 @@ +- define i64 @rotl64_imm(i64 %a) { +- ; SM20-LABEL: rotl64_imm( +- ; SM20: { +--; SM20-NEXT: .reg .b64 %rd<5>; +-+; SM20-NEXT: .reg .b64 %rd<3>; +- ; SM20-EMPTY: +- ; SM20-NEXT: // %bb.0: +- ; SM20-NEXT: ld.param.u64 %rd1, [rotl64_imm_param_0]; +--; SM20-NEXT: shr.u64 %rd2, %rd1, 62; +--; SM20-NEXT: shl.b64 %rd3, %rd1, 2; +--; SM20-NEXT: or.b64 %rd4, %rd3, %rd2; +--; SM20-NEXT: st.param.b64 [func_retval0+0], %rd4; +-+; SM20-NEXT: { +-+; SM20-NEXT: .reg .b64 %lhs; +-+; SM20-NEXT: .reg .b64 %rhs; +-+; SM20-NEXT: shl.b64 %lhs, %rd1, 2; +-+; SM20-NEXT: shr.b64 %rhs, %rd1, 62; +-+; SM20-NEXT: add.u64 %rd2, %lhs, %rhs; +-+; SM20-NEXT: } +-+; SM20-NEXT: st.param.b64 [func_retval0+0], %rd2; +- ; SM20-NEXT: ret; +- ; +- ; SM35-LABEL: rotl64_imm( +- ; SM35: { +--; SM35-NEXT: .reg .b64 %rd<5>; +-+; SM35-NEXT: .reg .b64 %rd<3>; +- ; SM35-EMPTY: +- ; SM35-NEXT: // %bb.0: +- ; SM35-NEXT: ld.param.u64 %rd1, [rotl64_imm_param_0]; +--; SM35-NEXT: shr.u64 %rd2, %rd1, 62; +--; SM35-NEXT: shl.b64 %rd3, %rd1, 2; +--; SM35-NEXT: or.b64 %rd4, %rd3, %rd2; +--; SM35-NEXT: st.param.b64 [func_retval0+0], %rd4; +-+; SM35-NEXT: { +-+; SM35-NEXT: .reg .b64 %lhs; +-+; SM35-NEXT: .reg .b64 %rhs; +-+; SM35-NEXT: shl.b64 %lhs, %rd1, 2; +-+; SM35-NEXT: shr.b64 %rhs, %rd1, 62; +-+; SM35-NEXT: add.u64 %rd2, %lhs, %rhs; +-+; SM35-NEXT: } +-+; SM35-NEXT: st.param.b64 [func_retval0+0], %rd2; +- ; SM35-NEXT: ret; +- %val = tail call i64 @llvm.fshl.i64(i64 %a, i64 %a, i64 66) +- ret i64 %val +-@@ -230,36 +268,44 @@ +- define i64 @rotr64(i64 %a, i64 %n) { +- ; SM20-LABEL: rotr64( +- ; SM20: { +--; SM20-NEXT: .reg .b32 %r<5>; +--; SM20-NEXT: .reg .b64 %rd<5>; +-+; SM20-NEXT: .reg .b32 %r<2>; +-+; SM20-NEXT: .reg .b64 %rd<3>; +- ; SM20-EMPTY: +- ; SM20-NEXT: // %bb.0: +- ; SM20-NEXT: ld.param.u64 %rd1, [rotr64_param_0]; +- ; SM20-NEXT: ld.param.u32 %r1, [rotr64_param_1]; +--; SM20-NEXT: and.b32 %r2, %r1, 63; +--; SM20-NEXT: shr.u64 %rd2, %rd1, %r2; +--; SM20-NEXT: neg.s32 %r3, %r1; +--; SM20-NEXT: and.b32 %r4, %r3, 63; +--; SM20-NEXT: shl.b64 %rd3, %rd1, %r4; +--; SM20-NEXT: or.b64 %rd4, %rd2, %rd3; +--; SM20-NEXT: st.param.b64 [func_retval0+0], %rd4; +-+; SM20-NEXT: { +-+; SM20-NEXT: .reg .b64 %lhs; +-+; SM20-NEXT: .reg .b64 %rhs; +-+; SM20-NEXT: .reg .u32 %amt2; +-+; SM20-NEXT: and.b32 %amt2, %r1, 63; +-+; SM20-NEXT: shr.b64 %lhs, %rd1, %amt2; +-+; SM20-NEXT: sub.u32 %amt2, 64, %amt2; +-+; SM20-NEXT: shl.b64 %rhs, %rd1, %amt2; +-+; SM20-NEXT: add.u64 %rd2, %lhs, %rhs; +-+; SM20-NEXT: } +-+; SM20-NEXT: st.param.b64 [func_retval0+0], %rd2; +- ; SM20-NEXT: ret; +- ; +- ; SM35-LABEL: rotr64( +- ; SM35: { +--; SM35-NEXT: .reg .b32 %r<5>; +--; SM35-NEXT: .reg .b64 %rd<5>; +-+; SM35-NEXT: .reg .b32 %r<2>; +-+; SM35-NEXT: .reg .b64 %rd<3>; +- ; SM35-EMPTY: +- ; SM35-NEXT: // %bb.0: +- ; SM35-NEXT: ld.param.u64 %rd1, [rotr64_param_0]; +- ; SM35-NEXT: ld.param.u32 %r1, [rotr64_param_1]; +--; SM35-NEXT: and.b32 %r2, %r1, 63; +--; SM35-NEXT: shr.u64 %rd2, %rd1, %r2; +--; SM35-NEXT: neg.s32 %r3, %r1; +--; SM35-NEXT: and.b32 %r4, %r3, 63; +--; SM35-NEXT: shl.b64 %rd3, %rd1, %r4; +--; SM35-NEXT: or.b64 %rd4, %rd2, %rd3; +--; SM35-NEXT: st.param.b64 [func_retval0+0], %rd4; +-+; SM35-NEXT: { +-+; SM35-NEXT: .reg .b64 %lhs; +-+; SM35-NEXT: .reg .b64 %rhs; +-+; SM35-NEXT: .reg .u32 %amt2; +-+; SM35-NEXT: and.b32 %amt2, %r1, 63; +-+; SM35-NEXT: shr.b64 %lhs, %rd1, %amt2; +-+; SM35-NEXT: sub.u32 %amt2, 64, %amt2; +-+; SM35-NEXT: shl.b64 %rhs, %rd1, %amt2; +-+; SM35-NEXT: add.u64 %rd2, %lhs, %rhs; +-+; SM35-NEXT: } +-+; SM35-NEXT: st.param.b64 [func_retval0+0], %rd2; +- ; SM35-NEXT: ret; +- %val = tail call i64 @llvm.fshr.i64(i64 %a, i64 %a, i64 %n) +- ret i64 %val +-@@ -269,180 +315,35 @@ +- define i64 @rotr64_imm(i64 %a) { +- ; SM20-LABEL: rotr64_imm( +- ; SM20: { +--; SM20-NEXT: .reg .b64 %rd<5>; +-+; SM20-NEXT: .reg .b64 %rd<3>; +- ; SM20-EMPTY: +- ; SM20-NEXT: // %bb.0: +- ; SM20-NEXT: ld.param.u64 %rd1, [rotr64_imm_param_0]; +--; SM20-NEXT: shl.b64 %rd2, %rd1, 62; +--; SM20-NEXT: shr.u64 %rd3, %rd1, 2; +--; SM20-NEXT: or.b64 %rd4, %rd3, %rd2; +--; SM20-NEXT: st.param.b64 [func_retval0+0], %rd4; +-+; SM20-NEXT: { +-+; SM20-NEXT: .reg .b64 %lhs; +-+; SM20-NEXT: .reg .b64 %rhs; +-+; SM20-NEXT: shl.b64 %lhs, %rd1, 62; +-+; SM20-NEXT: shr.b64 %rhs, %rd1, 2; +-+; SM20-NEXT: add.u64 %rd2, %lhs, %rhs; +-+; SM20-NEXT: } +-+; SM20-NEXT: st.param.b64 [func_retval0+0], %rd2; +- ; SM20-NEXT: ret; +- ; +- ; SM35-LABEL: rotr64_imm( +- ; SM35: { +--; SM35-NEXT: .reg .b64 %rd<5>; +-+; SM35-NEXT: .reg .b64 %rd<3>; +- ; SM35-EMPTY: +- ; SM35-NEXT: // %bb.0: +- ; SM35-NEXT: ld.param.u64 %rd1, [rotr64_imm_param_0]; +--; SM35-NEXT: shl.b64 %rd2, %rd1, 62; +--; SM35-NEXT: shr.u64 %rd3, %rd1, 2; +--; SM35-NEXT: or.b64 %rd4, %rd3, %rd2; +--; SM35-NEXT: st.param.b64 [func_retval0+0], %rd4; +-+; SM35-NEXT: { +-+; SM35-NEXT: .reg .b64 %lhs; +-+; SM35-NEXT: .reg .b64 %rhs; +-+; SM35-NEXT: shl.b64 %lhs, %rd1, 62; +-+; SM35-NEXT: shr.b64 %rhs, %rd1, 2; +-+; SM35-NEXT: add.u64 %rd2, %lhs, %rhs; +-+; SM35-NEXT: } +-+; SM35-NEXT: st.param.b64 [func_retval0+0], %rd2; +- ; SM35-NEXT: ret; +- %val = tail call i64 @llvm.fshr.i64(i64 %a, i64 %a, i64 66) +- ret i64 %val +- } +-- +--define i32 @funnel_shift_right_32(i32 %a, i32 %b, i32 %c) { +--; SM20-LABEL: funnel_shift_right_32( +--; SM20: { +--; SM20-NEXT: .reg .b32 %r<11>; +--; SM20-EMPTY: +--; SM20-NEXT: // %bb.0: +--; SM20-NEXT: ld.param.u32 %r1, [funnel_shift_right_32_param_0]; +--; SM20-NEXT: ld.param.u32 %r2, [funnel_shift_right_32_param_2]; +--; SM20-NEXT: and.b32 %r3, %r2, 31; +--; SM20-NEXT: ld.param.u32 %r4, [funnel_shift_right_32_param_1]; +--; SM20-NEXT: shr.u32 %r5, %r4, %r3; +--; SM20-NEXT: shl.b32 %r6, %r1, 1; +--; SM20-NEXT: not.b32 %r7, %r2; +--; SM20-NEXT: and.b32 %r8, %r7, 31; +--; SM20-NEXT: shl.b32 %r9, %r6, %r8; +--; SM20-NEXT: or.b32 %r10, %r9, %r5; +--; SM20-NEXT: st.param.b32 [func_retval0+0], %r10; +--; SM20-NEXT: ret; +--; +--; SM35-LABEL: funnel_shift_right_32( +--; SM35: { +--; SM35-NEXT: .reg .b32 %r<5>; +--; SM35-EMPTY: +--; SM35-NEXT: // %bb.0: +--; SM35-NEXT: ld.param.u32 %r1, [funnel_shift_right_32_param_0]; +--; SM35-NEXT: ld.param.u32 %r2, [funnel_shift_right_32_param_1]; +--; SM35-NEXT: ld.param.u32 %r3, [funnel_shift_right_32_param_2]; +--; SM35-NEXT: shf.r.wrap.b32 %r4, %r1, %r2, %r3; +--; SM35-NEXT: st.param.b32 [func_retval0+0], %r4; +--; SM35-NEXT: ret; +-- %val = call i32 @llvm.fshr.i32(i32 %a, i32 %b, i32 %c) +-- ret i32 %val +--} +-- +--define i32 @funnel_shift_left_32(i32 %a, i32 %b, i32 %c) { +--; SM20-LABEL: funnel_shift_left_32( +--; SM20: { +--; SM20-NEXT: .reg .b32 %r<11>; +--; SM20-EMPTY: +--; SM20-NEXT: // %bb.0: +--; SM20-NEXT: ld.param.u32 %r1, [funnel_shift_left_32_param_0]; +--; SM20-NEXT: ld.param.u32 %r2, [funnel_shift_left_32_param_2]; +--; SM20-NEXT: and.b32 %r3, %r2, 31; +--; SM20-NEXT: shl.b32 %r4, %r1, %r3; +--; SM20-NEXT: ld.param.u32 %r5, [funnel_shift_left_32_param_1]; +--; SM20-NEXT: shr.u32 %r6, %r5, 1; +--; SM20-NEXT: not.b32 %r7, %r2; +--; SM20-NEXT: and.b32 %r8, %r7, 31; +--; SM20-NEXT: shr.u32 %r9, %r6, %r8; +--; SM20-NEXT: or.b32 %r10, %r4, %r9; +--; SM20-NEXT: st.param.b32 [func_retval0+0], %r10; +--; SM20-NEXT: ret; +--; +--; SM35-LABEL: funnel_shift_left_32( +--; SM35: { +--; SM35-NEXT: .reg .b32 %r<5>; +--; SM35-EMPTY: +--; SM35-NEXT: // %bb.0: +--; SM35-NEXT: ld.param.u32 %r1, [funnel_shift_left_32_param_0]; +--; SM35-NEXT: ld.param.u32 %r2, [funnel_shift_left_32_param_1]; +--; SM35-NEXT: ld.param.u32 %r3, [funnel_shift_left_32_param_2]; +--; SM35-NEXT: shf.l.wrap.b32 %r4, %r1, %r2, %r3; +--; SM35-NEXT: st.param.b32 [func_retval0+0], %r4; +--; SM35-NEXT: ret; +-- %val = call i32 @llvm.fshl.i32(i32 %a, i32 %b, i32 %c) +-- ret i32 %val +--} +-- +--define i64 @funnel_shift_right_64(i64 %a, i64 %b, i64 %c) { +--; SM20-LABEL: funnel_shift_right_64( +--; SM20: { +--; SM20-NEXT: .reg .b32 %r<5>; +--; SM20-NEXT: .reg .b64 %rd<7>; +--; SM20-EMPTY: +--; SM20-NEXT: // %bb.0: +--; SM20-NEXT: ld.param.u64 %rd1, [funnel_shift_right_64_param_0]; +--; SM20-NEXT: ld.param.u32 %r1, [funnel_shift_right_64_param_2]; +--; SM20-NEXT: and.b32 %r2, %r1, 63; +--; SM20-NEXT: ld.param.u64 %rd2, [funnel_shift_right_64_param_1]; +--; SM20-NEXT: shr.u64 %rd3, %rd2, %r2; +--; SM20-NEXT: shl.b64 %rd4, %rd1, 1; +--; SM20-NEXT: not.b32 %r3, %r1; +--; SM20-NEXT: and.b32 %r4, %r3, 63; +--; SM20-NEXT: shl.b64 %rd5, %rd4, %r4; +--; SM20-NEXT: or.b64 %rd6, %rd5, %rd3; +--; SM20-NEXT: st.param.b64 [func_retval0+0], %rd6; +--; SM20-NEXT: ret; +--; +--; SM35-LABEL: funnel_shift_right_64( +--; SM35: { +--; SM35-NEXT: .reg .b32 %r<5>; +--; SM35-NEXT: .reg .b64 %rd<7>; +--; SM35-EMPTY: +--; SM35-NEXT: // %bb.0: +--; SM35-NEXT: ld.param.u64 %rd1, [funnel_shift_right_64_param_0]; +--; SM35-NEXT: ld.param.u32 %r1, [funnel_shift_right_64_param_2]; +--; SM35-NEXT: and.b32 %r2, %r1, 63; +--; SM35-NEXT: ld.param.u64 %rd2, [funnel_shift_right_64_param_1]; +--; SM35-NEXT: shr.u64 %rd3, %rd2, %r2; +--; SM35-NEXT: shl.b64 %rd4, %rd1, 1; +--; SM35-NEXT: not.b32 %r3, %r1; +--; SM35-NEXT: and.b32 %r4, %r3, 63; +--; SM35-NEXT: shl.b64 %rd5, %rd4, %r4; +--; SM35-NEXT: or.b64 %rd6, %rd5, %rd3; +--; SM35-NEXT: st.param.b64 [func_retval0+0], %rd6; +--; SM35-NEXT: ret; +-- %val = call i64 @llvm.fshr.i64(i64 %a, i64 %b, i64 %c) +-- ret i64 %val +--} +-- +--define i64 @funnel_shift_left_64(i64 %a, i64 %b, i64 %c) { +--; SM20-LABEL: funnel_shift_left_64( +--; SM20: { +--; SM20-NEXT: .reg .b32 %r<5>; +--; SM20-NEXT: .reg .b64 %rd<7>; +--; SM20-EMPTY: +--; SM20-NEXT: // %bb.0: +--; SM20-NEXT: ld.param.u64 %rd1, [funnel_shift_left_64_param_0]; +--; SM20-NEXT: ld.param.u32 %r1, [funnel_shift_left_64_param_2]; +--; SM20-NEXT: and.b32 %r2, %r1, 63; +--; SM20-NEXT: shl.b64 %rd2, %rd1, %r2; +--; SM20-NEXT: ld.param.u64 %rd3, [funnel_shift_left_64_param_1]; +--; SM20-NEXT: shr.u64 %rd4, %rd3, 1; +--; SM20-NEXT: not.b32 %r3, %r1; +--; SM20-NEXT: and.b32 %r4, %r3, 63; +--; SM20-NEXT: shr.u64 %rd5, %rd4, %r4; +--; SM20-NEXT: or.b64 %rd6, %rd2, %rd5; +--; SM20-NEXT: st.param.b64 [func_retval0+0], %rd6; +--; SM20-NEXT: ret; +--; +--; SM35-LABEL: funnel_shift_left_64( +--; SM35: { +--; SM35-NEXT: .reg .b32 %r<5>; +--; SM35-NEXT: .reg .b64 %rd<7>; +--; SM35-EMPTY: +--; SM35-NEXT: // %bb.0: +--; SM35-NEXT: ld.param.u64 %rd1, [funnel_shift_left_64_param_0]; +--; SM35-NEXT: ld.param.u32 %r1, [funnel_shift_left_64_param_2]; +--; SM35-NEXT: and.b32 %r2, %r1, 63; +--; SM35-NEXT: shl.b64 %rd2, %rd1, %r2; +--; SM35-NEXT: ld.param.u64 %rd3, [funnel_shift_left_64_param_1]; +--; SM35-NEXT: shr.u64 %rd4, %rd3, 1; +--; SM35-NEXT: not.b32 %r3, %r1; +--; SM35-NEXT: and.b32 %r4, %r3, 63; +--; SM35-NEXT: shr.u64 %rd5, %rd4, %r4; +--; SM35-NEXT: or.b64 %rd6, %rd2, %rd5; +--; SM35-NEXT: st.param.b64 [func_retval0+0], %rd6; +--; SM35-NEXT: ret; +-- %val = call i64 @llvm.fshl.i64(i64 %a, i64 %b, i64 %c) +-- ret i64 %val +--} +-- +-diff -ruN --strip-trailing-cr a/llvm/test/DebugInfo/NVPTX/debug-info.ll b/llvm/test/DebugInfo/NVPTX/debug-info.ll +---- a/llvm/test/DebugInfo/NVPTX/debug-info.ll +-+++ b/llvm/test/DebugInfo/NVPTX/debug-info.ll +-@@ -25,10 +25,6 @@ +- ; CHECK-DAG: .reg .b64 %rd<8>; +- ; CHECK: .loc [[DEBUG_INFO_CU:[0-9]+]] 5 0 +- ; CHECK: ld.param.u32 %r{{.+}}, [{{.+}}]; +--; CHECK: ld.param.u64 %rd{{.+}}, [{{.+}}]; +--; CHECK: cvta.to.global.u64 %rd{{.+}}, %rd{{.+}}; +--; CHECK: ld.param.u64 %rd{{.+}}, [{{.+}}]; +--; CHECK: cvta.to.global.u64 %rd{{.+}}, %rd{{.+}}; +- ; CHECK: .loc [[BUILTUIN_VARS_H:[0-9]+]] 78 180 +- ; CHECK: mov.u32 %r{{.+}}, %ctaid.x; +- ; CHECK: .loc [[BUILTUIN_VARS_H]] 89 180 +-@@ -42,6 +38,10 @@ +- ; CHECK: .loc [[DEBUG_INFO_CU]] 7 7 +- ; CHECK: @%p{{.+}} bra [[BB:\$L__.+]]; +- ; CHECK: ld.param.f32 %f{{.+}}, [{{.+}}]; +-+; CHECK: ld.param.u64 %rd{{.+}}, [{{.+}}]; +-+; CHECK: cvta.to.global.u64 %rd{{.+}}, %rd{{.+}}; +-+; CHECK: ld.param.u64 %rd{{.+}}, [{{.+}}]; +-+; CHECK: cvta.to.global.u64 %rd{{.+}}, %rd{{.+}}; +- ; CHECK: .loc [[DEBUG_INFO_CU]] 8 13 +- ; CHECK: mul.wide.u32 %rd{{.+}}, %r{{.+}}, 4; +- ; CHECK: add.s64 %rd{{.+}}, %rd{{.+}}, %rd{{.+}}; +-@@ -2661,22 +2661,22 @@ +- ; CHECK-NEXT:.b32 4579 // DW_AT_type +- ; CHECK-NEXT:.b8 25 // Abbrev [25] 0x8aa:0x18 DW_TAG_inlined_subroutine +- ; CHECK-NEXT:.b32 707 // DW_AT_abstract_origin +--; CHECK-NEXT:.b64 $L__tmp1 // DW_AT_low_pc +--; CHECK-NEXT:.b64 $L__tmp2 // DW_AT_high_pc +-+; CHECK-NEXT:.b64 $L__tmp0 // DW_AT_low_pc +-+; CHECK-NEXT:.b64 $L__tmp1 // DW_AT_high_pc +- ; CHECK-NEXT:.b8 1 // DW_AT_call_file +- ; CHECK-NEXT:.b8 6 // DW_AT_call_line +- ; CHECK-NEXT:.b8 11 // DW_AT_call_column +- ; CHECK-NEXT:.b8 25 // Abbrev [25] 0x8c2:0x18 DW_TAG_inlined_subroutine +- ; CHECK-NEXT:.b32 1466 // DW_AT_abstract_origin +--; CHECK-NEXT:.b64 $L__tmp2 // DW_AT_low_pc +--; CHECK-NEXT:.b64 $L__tmp3 // DW_AT_high_pc +-+; CHECK-NEXT:.b64 $L__tmp1 // DW_AT_low_pc +-+; CHECK-NEXT:.b64 $L__tmp2 // DW_AT_high_pc +- ; CHECK-NEXT:.b8 1 // DW_AT_call_file +- ; CHECK-NEXT:.b8 6 // DW_AT_call_line +- ; CHECK-NEXT:.b8 24 // DW_AT_call_column +- ; CHECK-NEXT:.b8 25 // Abbrev [25] 0x8da:0x18 DW_TAG_inlined_subroutine +- ; CHECK-NEXT:.b32 2060 // DW_AT_abstract_origin +--; CHECK-NEXT:.b64 $L__tmp3 // DW_AT_low_pc +--; CHECK-NEXT:.b64 $L__tmp4 // DW_AT_high_pc +-+; CHECK-NEXT:.b64 $L__tmp2 // DW_AT_low_pc +-+; CHECK-NEXT:.b64 $L__tmp3 // DW_AT_high_pc +- ; CHECK-NEXT:.b8 1 // DW_AT_call_file +- ; CHECK-NEXT:.b8 6 // DW_AT_call_line +- ; CHECK-NEXT:.b8 37 // DW_AT_call_column diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl -index abe15ef..af35fe7 100644 +index af35fe7..7b11086 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" -- LLVM_COMMIT = "df0864e761107b07e38f5503e0cbee0cebb4c5e8" -- LLVM_SHA256 = "5bfcb7306d9d40f420862ace1f7ad3f01979facfb16ffd1fc80b6d91e92019fa" -+ LLVM_COMMIT = "9830156f623c56062bf6df1b4c4b4bd8ab5bd57c" -+ LLVM_SHA256 = "85bb9a61cfdaf0d3386890dc7b4bbaa17eecf4b70b60c314307f2ca3919b9035" +- LLVM_COMMIT = "9830156f623c56062bf6df1b4c4b4bd8ab5bd57c" +- LLVM_SHA256 = "85bb9a61cfdaf0d3386890dc7b4bbaa17eecf4b70b60c314307f2ca3919b9035" ++ LLVM_COMMIT = "29b92d07746fac26cd64c914bc9c5c3833974f6d" ++ LLVM_SHA256 = "3e8e93e3749454af4b64f7f34b792a4748b62fc533bca1703d33b2b04e34eb70" tf_http_archive( name = name, diff --git a/third_party/xla/third_party/shardy/workspace.bzl b/third_party/xla/third_party/shardy/workspace.bzl index 4f6a0785270667..3ffc08a6fd8eb9 100644 --- a/third_party/xla/third_party/shardy/workspace.bzl +++ b/third_party/xla/third_party/shardy/workspace.bzl @@ -3,8 +3,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): - SHARDY_COMMIT = "22e68fa19cfb2d28434a75d4d20d0efc182b166a" - SHARDY_SHA256 = "2b47b0ee994feca2bd782e20aca7d709e29bc870c2ac435aca967f7664c9f949" + SHARDY_COMMIT = "c4642106cba935c06f437e542cb376bce8fbd16c" + SHARDY_SHA256 = "286661a749a4ed03dea624c15613bf615a357a720c5623e80f08479010cde42d" tf_http_archive( name = "shardy", From 53b6bb1684e9dea8768f44d3db9b53cafb87f3bf Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 26 Sep 2024 07:18:38 -0700 Subject: [PATCH 319/483] Automated Code Change PiperOrigin-RevId: 679137606 --- third_party/xla/xla/service/BUILD | 100 +++++++++++++++++- .../xla/xla/service/add_original_value.h | 3 + .../xla/xla/service/algebraic_simplifier.h | 1 + .../algebraic_simplifier_overflow_test.cc | 2 + .../service/all_gather_broadcast_reorder.h | 2 + .../all_gather_broadcast_reorder_test.cc | 3 + .../xla/xla/service/all_gather_combiner.cc | 2 + .../xla/xla/service/all_gather_combiner.h | 3 + .../xla/service/all_gather_combiner_test.cc | 3 + .../xla/xla/service/all_gather_decomposer.cc | 1 + .../xla/xla/service/all_reduce_combiner.cc | 8 +- .../xla/xla/service/all_reduce_combiner.h | 1 + .../xla/service/all_reduce_combiner_test.cc | 8 +- .../xla/xla/service/all_reduce_contiguous.cc | 7 ++ .../xla/xla/service/all_reduce_contiguous.h | 2 + .../xla/service/all_reduce_contiguous_test.cc | 5 +- .../xla/xla/service/all_reduce_folder.cc | 1 + .../xla/xla/service/all_reduce_folder.h | 2 + third_party/xla/xla/service/all_reduce_key.cc | 1 + third_party/xla/xla/service/all_reduce_key.h | 1 + .../xla/xla/service/all_reduce_promotion.cc | 13 +++ .../xla/xla/service/all_reduce_promotion.h | 8 ++ .../xla/service/all_reduce_promotion_test.cc | 6 ++ .../xla/xla/service/all_reduce_reassociate.cc | 1 + .../xla/xla/service/all_reduce_reassociate.h | 2 + .../service/all_reduce_reassociate_test.cc | 1 + .../xla/xla/service/all_reduce_simplifier.cc | 2 + .../xla/xla/service/all_reduce_simplifier.h | 2 + .../xla/service/all_reduce_simplifier_test.cc | 6 +- .../xla/xla/service/all_to_all_decomposer.cc | 2 + .../xla/xla/service/all_to_all_decomposer.h | 2 + .../xla/xla/service/allocation_tracker.cc | 1 + .../xla/xla/service/allocation_tracker.h | 6 ++ .../xla/xla/service/ar_crs_combiner.cc | 1 + .../xla/xla/service/ar_crs_combiner_test.cc | 1 + .../xla/service/async_collective_creator.cc | 6 ++ .../xla/service/async_collective_creator.h | 8 ++ .../service/async_collective_creator_test.cc | 4 + third_party/xla/xla/service/backend.cc | 9 ++ third_party/xla/xla/service/backend.h | 6 ++ .../xla/service/batch_dot_simplification.cc | 11 ++ .../xla/service/batch_dot_simplification.h | 4 + .../service/batch_dot_simplification_test.cc | 3 + .../batched_gather_scatter_normalizer.cc | 1 + .../xla/xla/service/batchnorm_expander.h | 3 + .../xla/service/batchnorm_expander_test.cc | 7 +- .../xla/service/bfloat16_conversion_folding.h | 5 + .../xla/xla/service/bfloat16_propagation.cc | 9 ++ .../xla/xla/service/bfloat16_propagation.h | 7 ++ .../xla/service/bitcast_dtypes_expander.cc | 16 +-- .../service/bitcast_dtypes_expander_test.cc | 5 +- 51 files changed, 291 insertions(+), 23 deletions(-) diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD index 93e581a80c78e6..e74dbda5c60e50 100644 --- a/third_party/xla/xla/service/BUILD +++ b/third_party/xla/xla/service/BUILD @@ -122,9 +122,16 @@ cc_library( "//xla:util", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", ], ) @@ -140,6 +147,9 @@ xla_cc_test( "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", "//xla/tsl/lib/core:status_test_util", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:statusor", ], ) @@ -149,6 +159,7 @@ cc_library( hdrs = ["all_reduce_key.h"], deps = [ ":hlo_domain_map", + "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "@com_google_absl//absl/log", ], @@ -158,7 +169,18 @@ cc_library( name = "all_reduce_promotion", srcs = ["all_reduce_promotion.cc"], hdrs = ["all_reduce_promotion.h"], - deps = [":change_op_data_type"], + deps = [ + ":change_op_data_type", + "//xla:shape_util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + ], ) xla_cc_test( @@ -168,7 +190,12 @@ xla_cc_test( ":all_reduce_promotion", ":pattern_matcher", ":pattern_matcher_gmock", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", "//xla/tests:hlo_test_base", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test_main", ], ) @@ -183,6 +210,7 @@ cc_library( ":pattern_matcher", "//xla:literal", "//xla:shape_util", + "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/hlo/utils:hlo_query", @@ -204,6 +232,7 @@ xla_cc_test( ":pattern_matcher", ":pattern_matcher_gmock", "//xla:util", + "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", @@ -222,6 +251,7 @@ cc_library( hdrs = ["all_reduce_folder.h"], deps = [ ":all_reduce_key", + "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/hlo/utils:hlo_query", @@ -306,7 +336,9 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", @@ -385,11 +417,13 @@ cc_library( ":float_support", ":hlo_dataflow_analysis", ":hlo_dce", + ":hlo_value", ":tuple_simplifier", "//xla:literal", "//xla:shape_tree", "//xla:shape_util", "//xla:util", + "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "@com_google_absl//absl/algorithm:container", @@ -397,7 +431,12 @@ cc_library( "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:statusor", ], ) @@ -1307,13 +1346,18 @@ cc_library( "//xla/stream_executor:stream_executor_h", "//xla/stream_executor:stream_executor_memory_allocator", "//xla/stream_executor/host:host_platform_id", + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", "@eigen_archive//:eigen3", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:platform_port", + "@local_tsl//tsl/platform:statusor", ], ) @@ -1860,6 +1904,7 @@ cc_library( "//xla:xla_data_proto_cc", "//xla/stream_executor:device_memory", "//xla/stream_executor:device_memory_allocator", + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -2716,6 +2761,7 @@ xla_test( deps = [ ":batchnorm_expander", ":hlo_parser", + "//xla:error_spec", "//xla:literal", "//xla:shape_util", "//xla:test", @@ -2725,6 +2771,7 @@ xla_test( "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "@local_tsl//tsl/platform:statusor", ], ) @@ -2831,8 +2878,10 @@ xla_test( name = "algebraic_simplifier_overflow_test", srcs = ["algebraic_simplifier_overflow_test.cc"], deps = [ + "//xla:error_spec", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "@com_google_googletest//:gtest_main", ], ) @@ -2995,12 +3044,15 @@ cc_library( srcs = ["bitcast_dtypes_expander.cc"], hdrs = ["bitcast_dtypes_expander.h"], deps = [ + ":hlo_module_config", ":op_expander_pass", "//xla:literal_util", "//xla:shape_util", "//xla:status_macros", "//xla:types", + "//xla:xla_data_proto_cc", "//xla/client:xla_builder", + "//xla/client:xla_computation", "//xla/client/lib:arithmetic", "//xla/client/lib:broadcast", "//xla/client/lib:constants", @@ -3009,7 +3061,9 @@ cc_library( "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:statusor", ], ) @@ -3018,11 +3072,14 @@ xla_cc_test( srcs = ["bitcast_dtypes_expander_test.cc"], deps = [ ":bitcast_dtypes_expander", + "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_matchers", "//xla/tests:filecheck", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", "//xla/tsl/lib/core:status_test_util", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", "@local_tsl//tsl/platform:statusor", ], ) @@ -3035,6 +3092,8 @@ xla_cc_test( "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", "@local_tsl//tsl/platform:statusor", ], ) @@ -3055,11 +3114,13 @@ cc_library( "//xla/hlo/utils:hlo_sharding_util", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/functional:function_ref", + "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", ], ) @@ -3072,6 +3133,8 @@ xla_cc_test( "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test_main", ], ) @@ -3094,9 +3157,13 @@ cc_library( "//xla/hlo/utils:hlo_sharding_util", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", ], ) @@ -3112,6 +3179,10 @@ xla_cc_test( "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test_main", ], ) @@ -3123,11 +3194,16 @@ cc_library( deps = [ "//xla:shape_util", "//xla:status_macros", + "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/hlo/utils:hlo_query", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:errors", ], ) @@ -3140,6 +3216,9 @@ xla_cc_test( "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", "//xla/tests:test_utils", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test_main", ], ) @@ -3193,6 +3272,7 @@ cc_library( ":hlo_replication_analysis", "//xla:literal_util", "//xla:shape_util", + "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "@com_google_absl//absl/algorithm:container", @@ -3222,6 +3302,7 @@ xla_cc_test( "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", # fixdeps: keep "//xla/tsl/lib/core:status_test_util", + "@local_tsl//tsl/platform:statusor", ], ) @@ -3291,9 +3372,18 @@ cc_library( hdrs = ["batch_dot_simplification.h"], deps = [ ":hlo_creation_utils", + "//xla:shape_util", + "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", ], ) @@ -3303,9 +3393,11 @@ xla_cc_test( deps = [ ":batch_dot_simplification", "//xla:test", + "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", # fixdeps: keep + "@local_tsl//tsl/platform:statusor", ], ) @@ -3980,6 +4072,8 @@ cc_library( "//xla:shape_util", "//xla:util", "//xla/hlo/ir:hlo", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", ], ) @@ -3993,6 +4087,7 @@ cc_library( "//xla:literal_util", "//xla:shape_util", "//xla:util", + "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "@com_google_absl//absl/container:flat_hash_set", @@ -7134,6 +7229,7 @@ cc_library( "//xla:literal_util", "//xla:shape_util", "//xla:status_macros", + "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/hlo/utils:hlo_query", @@ -7198,6 +7294,7 @@ xla_cc_test( srcs = ["ar_crs_combiner_test.cc"], deps = [ ":ar_crs_combiner", + "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", @@ -8048,6 +8145,7 @@ cc_library( ":op_expander_pass", "//xla:shape_util", "//xla:util", + "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/log:check", diff --git a/third_party/xla/xla/service/add_original_value.h b/third_party/xla/xla/service/add_original_value.h index 8dd1655e0eb0e8..dd3acf5501cf45 100644 --- a/third_party/xla/xla/service/add_original_value.h +++ b/third_party/xla/xla/service/add_original_value.h @@ -16,7 +16,10 @@ limitations under the License. #ifndef XLA_SERVICE_ADD_ORIGINAL_VALUE_H_ #define XLA_SERVICE_ADD_ORIGINAL_VALUE_H_ +#include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/pass/hlo_pass_interface.h" namespace xla { diff --git a/third_party/xla/xla/service/algebraic_simplifier.h b/third_party/xla/xla/service/algebraic_simplifier.h index 9ff96b248d398a..f791a324135fd9 100644 --- a/third_party/xla/xla/service/algebraic_simplifier.h +++ b/third_party/xla/xla/service/algebraic_simplifier.h @@ -40,6 +40,7 @@ limitations under the License. #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/util.h" +#include "xla/xla_data.pb.h" namespace xla { diff --git a/third_party/xla/xla/service/algebraic_simplifier_overflow_test.cc b/third_party/xla/xla/service/algebraic_simplifier_overflow_test.cc index 8e011d6d24edf7..071f9994b54a08 100644 --- a/third_party/xla/xla/service/algebraic_simplifier_overflow_test.cc +++ b/third_party/xla/xla/service/algebraic_simplifier_overflow_test.cc @@ -17,6 +17,8 @@ limitations under the License. #include #include +#include +#include "xla/error_spec.h" #include "xla/tests/hlo_test_base.h" namespace xla { diff --git a/third_party/xla/xla/service/all_gather_broadcast_reorder.h b/third_party/xla/xla/service/all_gather_broadcast_reorder.h index 0759f8ebfbbc79..78a81e6c255c0c 100644 --- a/third_party/xla/xla/service/all_gather_broadcast_reorder.h +++ b/third_party/xla/xla/service/all_gather_broadcast_reorder.h @@ -16,7 +16,9 @@ limitations under the License. #ifndef XLA_SERVICE_ALL_GATHER_BROADCAST_REORDER_H_ #define XLA_SERVICE_ALL_GATHER_BROADCAST_REORDER_H_ +#include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/pass/hlo_pass_interface.h" diff --git a/third_party/xla/xla/service/all_gather_broadcast_reorder_test.cc b/third_party/xla/xla/service/all_gather_broadcast_reorder_test.cc index 0c7eb62232d13a..f5f2406a89efe0 100644 --- a/third_party/xla/xla/service/all_gather_broadcast_reorder_test.cc +++ b/third_party/xla/xla/service/all_gather_broadcast_reorder_test.cc @@ -15,6 +15,9 @@ limitations under the License. #include "xla/service/all_gather_broadcast_reorder.h" +#include +#include +#include "absl/strings/string_view.h" #include "xla/hlo/utils/hlo_matchers.h" #include "xla/tests/hlo_test_base.h" #include "tsl/platform/statusor.h" diff --git a/third_party/xla/xla/service/all_gather_combiner.cc b/third_party/xla/xla/service/all_gather_combiner.cc index efd7a803f04a0c..a1fd7270d82fdf 100644 --- a/third_party/xla/xla/service/all_gather_combiner.cc +++ b/third_party/xla/xla/service/all_gather_combiner.cc @@ -30,6 +30,7 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/functional/function_ref.h" +#include "absl/log/log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" @@ -47,6 +48,7 @@ limitations under the License. #include "xla/status_macros.h" #include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" namespace xla { namespace { diff --git a/third_party/xla/xla/service/all_gather_combiner.h b/third_party/xla/xla/service/all_gather_combiner.h index 79bf388322081c..757b5ffd1a6eef 100644 --- a/third_party/xla/xla/service/all_gather_combiner.h +++ b/third_party/xla/xla/service/all_gather_combiner.h @@ -16,8 +16,11 @@ limitations under the License. #ifndef XLA_SERVICE_ALL_GATHER_COMBINER_H_ #define XLA_SERVICE_ALL_GATHER_COMBINER_H_ +#include "absl/container/flat_hash_set.h" +#include "absl/functional/function_ref.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/pass/hlo_pass_interface.h" #include "xla/service/hlo_domain_map.h" diff --git a/third_party/xla/xla/service/all_gather_combiner_test.cc b/third_party/xla/xla/service/all_gather_combiner_test.cc index 97a966b8815036..71ffe8a46e5995 100644 --- a/third_party/xla/xla/service/all_gather_combiner_test.cc +++ b/third_party/xla/xla/service/all_gather_combiner_test.cc @@ -19,6 +19,8 @@ limitations under the License. #include #include +#include +#include #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -28,6 +30,7 @@ limitations under the License. #include "xla/hlo/utils/hlo_matchers.h" #include "xla/tests/hlo_test_base.h" #include "xla/xla_data.pb.h" +#include "tsl/platform/statusor.h" namespace xla { namespace { diff --git a/third_party/xla/xla/service/all_gather_decomposer.cc b/third_party/xla/xla/service/all_gather_decomposer.cc index 98443b9113f976..ce3ed5f5f44026 100644 --- a/third_party/xla/xla/service/all_gather_decomposer.cc +++ b/third_party/xla/xla/service/all_gather_decomposer.cc @@ -33,6 +33,7 @@ limitations under the License. #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/util.h" +#include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" #include "tsl/platform/statusor.h" diff --git a/third_party/xla/xla/service/all_reduce_combiner.cc b/third_party/xla/xla/service/all_reduce_combiner.cc index a581b15d420dca..706a04b40fed7b 100644 --- a/third_party/xla/xla/service/all_reduce_combiner.cc +++ b/third_party/xla/xla/service/all_reduce_combiner.cc @@ -23,8 +23,12 @@ limitations under the License. #include #include -#include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" @@ -34,10 +38,12 @@ limitations under the License. #include "xla/service/all_reduce_key.h" #include "xla/service/collective_combiner_utils.h" #include "xla/service/hlo_domain_map.h" +#include "xla/shape.h" #include "xla/shape_util.h" #include "xla/status_macros.h" #include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" namespace xla { namespace { diff --git a/third_party/xla/xla/service/all_reduce_combiner.h b/third_party/xla/xla/service/all_reduce_combiner.h index bd1aa811f97160..c85c937154c55c 100644 --- a/third_party/xla/xla/service/all_reduce_combiner.h +++ b/third_party/xla/xla/service/all_reduce_combiner.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef XLA_SERVICE_ALL_REDUCE_COMBINER_H_ #define XLA_SERVICE_ALL_REDUCE_COMBINER_H_ +#include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "xla/array2d.h" diff --git a/third_party/xla/xla/service/all_reduce_combiner_test.cc b/third_party/xla/xla/service/all_reduce_combiner_test.cc index 188d7a99251bb0..0eab2b96274d90 100644 --- a/third_party/xla/xla/service/all_reduce_combiner_test.cc +++ b/third_party/xla/xla/service/all_reduce_combiner_test.cc @@ -18,16 +18,22 @@ limitations under the License. #include #include +#include +#include +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "xla/hlo/ir/collective_device_list.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/utils/hlo_matchers.h" -#include "xla/literal.h" #include "xla/literal_util.h" +#include "xla/shape.h" #include "xla/shape_util.h" #include "xla/tests/hlo_test_base.h" #include "xla/xla_data.pb.h" +#include "tsl/platform/statusor.h" namespace xla { namespace { diff --git a/third_party/xla/xla/service/all_reduce_contiguous.cc b/third_party/xla/xla/service/all_reduce_contiguous.cc index fa76de45facd59..84325793041688 100644 --- a/third_party/xla/xla/service/all_reduce_contiguous.cc +++ b/third_party/xla/xla/service/all_reduce_contiguous.cc @@ -17,14 +17,21 @@ limitations under the License. #include +#include "absl/container/flat_hash_set.h" +#include "absl/log/log.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/utils/hlo_query.h" +#include "xla/shape.h" #include "xla/shape_util.h" #include "xla/status_macros.h" +#include "xla/xla_data.pb.h" +#include "tsl/platform/errors.h" namespace xla { namespace { diff --git a/third_party/xla/xla/service/all_reduce_contiguous.h b/third_party/xla/xla/service/all_reduce_contiguous.h index 102245cd2ee36a..4e76be9b4fbc07 100644 --- a/third_party/xla/xla/service/all_reduce_contiguous.h +++ b/third_party/xla/xla/service/all_reduce_contiguous.h @@ -16,7 +16,9 @@ limitations under the License. #ifndef XLA_SERVICE_ALL_REDUCE_CONTIGUOUS_H_ #define XLA_SERVICE_ALL_REDUCE_CONTIGUOUS_H_ +#include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/pass/hlo_pass_interface.h" diff --git a/third_party/xla/xla/service/all_reduce_contiguous_test.cc b/third_party/xla/xla/service/all_reduce_contiguous_test.cc index ccd1effdbc6c30..1f5f970c5453a8 100644 --- a/third_party/xla/xla/service/all_reduce_contiguous_test.cc +++ b/third_party/xla/xla/service/all_reduce_contiguous_test.cc @@ -17,12 +17,15 @@ limitations under the License. #include +#include +#include +#include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/utils/hlo_matchers.h" #include "xla/tests/hlo_test_base.h" -#include "xla/tests/test_utils.h" +#include "tsl/platform/statusor.h" namespace xla { namespace { diff --git a/third_party/xla/xla/service/all_reduce_folder.cc b/third_party/xla/xla/service/all_reduce_folder.cc index d616cc411844f5..8bc18a0d73357f 100644 --- a/third_party/xla/xla/service/all_reduce_folder.cc +++ b/third_party/xla/xla/service/all_reduce_folder.cc @@ -35,6 +35,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/utils/hlo_query.h" #include "xla/service/all_reduce_key.h" +#include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" namespace xla { diff --git a/third_party/xla/xla/service/all_reduce_folder.h b/third_party/xla/xla/service/all_reduce_folder.h index 77706bbff34d26..3c84e2cae5df6d 100644 --- a/third_party/xla/xla/service/all_reduce_folder.h +++ b/third_party/xla/xla/service/all_reduce_folder.h @@ -16,7 +16,9 @@ limitations under the License. #ifndef XLA_SERVICE_ALL_REDUCE_FOLDER_H_ #define XLA_SERVICE_ALL_REDUCE_FOLDER_H_ +#include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/pass/hlo_pass_interface.h" diff --git a/third_party/xla/xla/service/all_reduce_key.cc b/third_party/xla/xla/service/all_reduce_key.cc index bd2fd49dc6be51..82319b09c3e0f8 100644 --- a/third_party/xla/xla/service/all_reduce_key.cc +++ b/third_party/xla/xla/service/all_reduce_key.cc @@ -25,6 +25,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/hlo_domain_map.h" +#include "xla/xla_data.pb.h" namespace xla { diff --git a/third_party/xla/xla/service/all_reduce_key.h b/third_party/xla/xla/service/all_reduce_key.h index 53a444d8a95c5b..fd72f7e4230bae 100644 --- a/third_party/xla/xla/service/all_reduce_key.h +++ b/third_party/xla/xla/service/all_reduce_key.h @@ -24,6 +24,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/hlo_domain_map.h" +#include "xla/xla_data.pb.h" namespace xla { diff --git a/third_party/xla/xla/service/all_reduce_promotion.cc b/third_party/xla/xla/service/all_reduce_promotion.cc index b0328759c7d310..0e60d59b6a24be 100644 --- a/third_party/xla/xla/service/all_reduce_promotion.cc +++ b/third_party/xla/xla/service/all_reduce_promotion.cc @@ -19,6 +19,19 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/xla_data.pb.h" + namespace xla { namespace { diff --git a/third_party/xla/xla/service/all_reduce_promotion.h b/third_party/xla/xla/service/all_reduce_promotion.h index a1ad33033187f1..e6459f82e00dc2 100644 --- a/third_party/xla/xla/service/all_reduce_promotion.h +++ b/third_party/xla/xla/service/all_reduce_promotion.h @@ -17,7 +17,15 @@ limitations under the License. #define XLA_SERVICE_ALL_REDUCE_PROMOTION_H_ #include + +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/pass/hlo_pass_interface.h" #include "xla/service/change_op_data_type.h" +#include "xla/xla_data.pb.h" namespace xla { diff --git a/third_party/xla/xla/service/all_reduce_promotion_test.cc b/third_party/xla/xla/service/all_reduce_promotion_test.cc index 380c1c3cf8e246..86d5fde6eb71c5 100644 --- a/third_party/xla/xla/service/all_reduce_promotion_test.cc +++ b/third_party/xla/xla/service/all_reduce_promotion_test.cc @@ -15,9 +15,15 @@ limitations under the License. #include "xla/service/all_reduce_promotion.h" +#include +#include +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_module.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" #include "xla/tests/hlo_test_base.h" +#include "xla/xla_data.pb.h" +#include "tsl/platform/statusor.h" namespace xla { namespace { diff --git a/third_party/xla/xla/service/all_reduce_reassociate.cc b/third_party/xla/xla/service/all_reduce_reassociate.cc index c7becb2c436c0b..6063eef7b6e6b0 100644 --- a/third_party/xla/xla/service/all_reduce_reassociate.cc +++ b/third_party/xla/xla/service/all_reduce_reassociate.cc @@ -38,6 +38,7 @@ limitations under the License. #include "xla/service/pattern_matcher.h" #include "xla/shape.h" #include "xla/shape_util.h" +#include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" namespace xla { diff --git a/third_party/xla/xla/service/all_reduce_reassociate.h b/third_party/xla/xla/service/all_reduce_reassociate.h index f2ff998b4b6f04..9fbeb32e6bf81f 100644 --- a/third_party/xla/xla/service/all_reduce_reassociate.h +++ b/third_party/xla/xla/service/all_reduce_reassociate.h @@ -16,7 +16,9 @@ limitations under the License. #ifndef XLA_SERVICE_ALL_REDUCE_REASSOCIATE_H_ #define XLA_SERVICE_ALL_REDUCE_REASSOCIATE_H_ +#include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/pass/hlo_pass_interface.h" diff --git a/third_party/xla/xla/service/all_reduce_reassociate_test.cc b/third_party/xla/xla/service/all_reduce_reassociate_test.cc index b7130508e878ea..c0a91a93be215c 100644 --- a/third_party/xla/xla/service/all_reduce_reassociate_test.cc +++ b/third_party/xla/xla/service/all_reduce_reassociate_test.cc @@ -32,6 +32,7 @@ limitations under the License. #include "xla/service/pattern_matcher_gmock.h" #include "xla/tests/hlo_test_base.h" #include "xla/util.h" +#include "xla/xla_data.pb.h" #include "tsl/platform/statusor.h" namespace xla { diff --git a/third_party/xla/xla/service/all_reduce_simplifier.cc b/third_party/xla/xla/service/all_reduce_simplifier.cc index 0760433bda4489..cc5c872e97b42c 100644 --- a/third_party/xla/xla/service/all_reduce_simplifier.cc +++ b/third_party/xla/xla/service/all_reduce_simplifier.cc @@ -27,12 +27,14 @@ limitations under the License. #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/literal_util.h" #include "xla/service/collective_ops_utils.h" #include "xla/service/hlo_module_config.h" #include "xla/service/hlo_replication_analysis.h" #include "xla/shape_util.h" +#include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" diff --git a/third_party/xla/xla/service/all_reduce_simplifier.h b/third_party/xla/xla/service/all_reduce_simplifier.h index 1c44b945bdf697..e670f2fd956eb5 100644 --- a/third_party/xla/xla/service/all_reduce_simplifier.h +++ b/third_party/xla/xla/service/all_reduce_simplifier.h @@ -16,7 +16,9 @@ limitations under the License. #ifndef XLA_SERVICE_ALL_REDUCE_SIMPLIFIER_H_ #define XLA_SERVICE_ALL_REDUCE_SIMPLIFIER_H_ +#include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/pass/hlo_pass_interface.h" diff --git a/third_party/xla/xla/service/all_reduce_simplifier_test.cc b/third_party/xla/xla/service/all_reduce_simplifier_test.cc index 35f5955076ad7e..5b07e61447ce3c 100644 --- a/third_party/xla/xla/service/all_reduce_simplifier_test.cc +++ b/third_party/xla/xla/service/all_reduce_simplifier_test.cc @@ -20,17 +20,13 @@ limitations under the License. #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/hlo_module_config.h" #include "xla/service/hlo_parser.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" -#include "xla/shape_util.h" #include "xla/test.h" #include "xla/tests/hlo_test_base.h" -#include "xla/tsl/lib/core/status_test_util.h" -#include "xla/types.h" -#include "xla/window_util.h" +#include "tsl/platform/statusor.h" namespace xla { namespace { diff --git a/third_party/xla/xla/service/all_to_all_decomposer.cc b/third_party/xla/xla/service/all_to_all_decomposer.cc index ecb08af7660382..dabea315b81c40 100644 --- a/third_party/xla/xla/service/all_to_all_decomposer.cc +++ b/third_party/xla/xla/service/all_to_all_decomposer.cc @@ -18,11 +18,13 @@ limitations under the License. #include #include +#include "absl/status/statusor.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/layout_util.h" +#include "xla/shape.h" #include "xla/shape_util.h" #include "xla/util.h" diff --git a/third_party/xla/xla/service/all_to_all_decomposer.h b/third_party/xla/xla/service/all_to_all_decomposer.h index 3ef1891a412665..cca93ddcba0ff3 100644 --- a/third_party/xla/xla/service/all_to_all_decomposer.h +++ b/third_party/xla/xla/service/all_to_all_decomposer.h @@ -16,6 +16,8 @@ limitations under the License. #ifndef XLA_SERVICE_ALL_TO_ALL_DECOMPOSER_H_ #define XLA_SERVICE_ALL_TO_ALL_DECOMPOSER_H_ +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/service/op_expander_pass.h" diff --git a/third_party/xla/xla/service/allocation_tracker.cc b/third_party/xla/xla/service/allocation_tracker.cc index 95168eba9c6c61..507107723093ab 100644 --- a/third_party/xla/xla/service/allocation_tracker.cc +++ b/third_party/xla/xla/service/allocation_tracker.cc @@ -31,6 +31,7 @@ limitations under the License. #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/device_memory_allocator.h" #include "xla/util.h" +#include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" #include "tsl/platform/statusor.h" diff --git a/third_party/xla/xla/service/allocation_tracker.h b/third_party/xla/xla/service/allocation_tracker.h index f7748d7162ace2..cea193eaea8568 100644 --- a/third_party/xla/xla/service/allocation_tracker.h +++ b/third_party/xla/xla/service/allocation_tracker.h @@ -22,9 +22,15 @@ limitations under the License. #include #include +#include "absl/base/thread_annotations.h" #include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/synchronization/mutex.h" #include "xla/service/backend.h" +#include "xla/service/shaped_buffer.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/device_memory_allocator.h" #include "xla/types.h" #include "xla/xla_data.pb.h" diff --git a/third_party/xla/xla/service/ar_crs_combiner.cc b/third_party/xla/xla/service/ar_crs_combiner.cc index a75acbc2b38498..ebe842b7ea9361 100644 --- a/third_party/xla/xla/service/ar_crs_combiner.cc +++ b/third_party/xla/xla/service/ar_crs_combiner.cc @@ -42,6 +42,7 @@ limitations under the License. #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/status_macros.h" +#include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" #include "tsl/platform/status.h" #include "tsl/platform/statusor.h" diff --git a/third_party/xla/xla/service/ar_crs_combiner_test.cc b/third_party/xla/xla/service/ar_crs_combiner_test.cc index e18d81d20fa93e..86dd6c397c6053 100644 --- a/third_party/xla/xla/service/ar_crs_combiner_test.cc +++ b/third_party/xla/xla/service/ar_crs_combiner_test.cc @@ -26,6 +26,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/utils/hlo_matchers.h" #include "xla/tests/hlo_test_base.h" +#include "xla/xla_data.pb.h" #include "tsl/platform/statusor.h" namespace xla { diff --git a/third_party/xla/xla/service/async_collective_creator.cc b/third_party/xla/xla/service/async_collective_creator.cc index 71c7eb820b61b1..f0a0e283b0717a 100644 --- a/third_party/xla/xla/service/async_collective_creator.cc +++ b/third_party/xla/xla/service/async_collective_creator.cc @@ -19,8 +19,13 @@ limitations under the License. #include #include +#include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/log/check.h" #include "absl/log/log.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "xla/frontend_attributes.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" @@ -33,6 +38,7 @@ limitations under the License. #include "xla/shape_util.h" #include "xla/util.h" #include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" namespace xla { namespace { diff --git a/third_party/xla/xla/service/async_collective_creator.h b/third_party/xla/xla/service/async_collective_creator.h index 5a542cf1e48c59..15263e3e4235f1 100644 --- a/third_party/xla/xla/service/async_collective_creator.h +++ b/third_party/xla/xla/service/async_collective_creator.h @@ -20,7 +20,15 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/pass/hlo_pass_interface.h" +#include "xla/shape.h" +#include "xla/util.h" namespace xla { diff --git a/third_party/xla/xla/service/async_collective_creator_test.cc b/third_party/xla/xla/service/async_collective_creator_test.cc index ad783ca23d0770..ba0777b07a1303 100644 --- a/third_party/xla/xla/service/async_collective_creator_test.cc +++ b/third_party/xla/xla/service/async_collective_creator_test.cc @@ -17,6 +17,9 @@ limitations under the License. #include +#include +#include +#include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_schedule.h" @@ -25,6 +28,7 @@ limitations under the License. #include "xla/tests/hlo_test_base.h" #include "xla/tsl/lib/core/status_test_util.h" #include "xla/util.h" +#include "tsl/platform/statusor.h" namespace xla { namespace { diff --git a/third_party/xla/xla/service/backend.cc b/third_party/xla/xla/service/backend.cc index eea05d78293e93..5ed8d66ddca365 100644 --- a/third_party/xla/xla/service/backend.cc +++ b/third_party/xla/xla/service/backend.cc @@ -13,6 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/synchronization/mutex.h" +#include "absl/types/span.h" +#include "xla/service/computation_placer.h" +#include "xla/service/stream_pool.h" +#include "xla/service/transfer_manager.h" +#include "xla/stream_executor/platform.h" +#include "tsl/platform/statusor.h" #define EIGEN_USE_THREADS #include "xla/service/backend.h" diff --git a/third_party/xla/xla/service/backend.h b/third_party/xla/xla/service/backend.h index ba54e008333989..cbbec594bc9020 100644 --- a/third_party/xla/xla/service/backend.h +++ b/third_party/xla/xla/service/backend.h @@ -23,17 +23,23 @@ limitations under the License. #include #include +#include "absl/base/thread_annotations.h" #include "absl/container/flat_hash_map.h" +#include "absl/log/check.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" +#include "absl/synchronization/mutex.h" #include "absl/types/span.h" #include "xla/service/compiler.h" #include "xla/service/computation_placer.h" #include "xla/service/stream_pool.h" #include "xla/service/transfer_manager.h" #include "xla/stream_executor/device_memory_allocator.h" +#include "xla/stream_executor/platform.h" #include "xla/stream_executor/stream_executor.h" #include "xla/stream_executor/stream_executor_memory_allocator.h" +#include "tsl/platform/threadpool.h" namespace Eigen { struct ThreadPoolDevice; diff --git a/third_party/xla/xla/service/batch_dot_simplification.cc b/third_party/xla/xla/service/batch_dot_simplification.cc index 3f22acf1930249..62751d090115b7 100644 --- a/third_party/xla/xla/service/batch_dot_simplification.cc +++ b/third_party/xla/xla/service/batch_dot_simplification.cc @@ -16,10 +16,21 @@ limitations under the License. #include "xla/service/batch_dot_simplification.h" #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_set.h" +#include "absl/log/log.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/hlo_creation_utils.h" +#include "xla/shape.h" +#include "xla/xla_data.pb.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" namespace xla { absl::StatusOr diff --git a/third_party/xla/xla/service/batch_dot_simplification.h b/third_party/xla/xla/service/batch_dot_simplification.h index 6ba3cf13e69f27..800c6439aed159 100644 --- a/third_party/xla/xla/service/batch_dot_simplification.h +++ b/third_party/xla/xla/service/batch_dot_simplification.h @@ -16,6 +16,10 @@ limitations under the License. #ifndef XLA_SERVICE_BATCH_DOT_SIMPLIFICATION_H_ #define XLA_SERVICE_BATCH_DOT_SIMPLIFICATION_H_ +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/pass/hlo_pass_interface.h" diff --git a/third_party/xla/xla/service/batch_dot_simplification_test.cc b/third_party/xla/xla/service/batch_dot_simplification_test.cc index fd60e8f2a3ade3..0d4101b4fd97f1 100644 --- a/third_party/xla/xla/service/batch_dot_simplification_test.cc +++ b/third_party/xla/xla/service/batch_dot_simplification_test.cc @@ -15,9 +15,12 @@ limitations under the License. #include "xla/service/batch_dot_simplification.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/utils/hlo_matchers.h" #include "xla/test.h" #include "xla/tests/hlo_test_base.h" +#include "tsl/platform/statusor.h" namespace xla { namespace { diff --git a/third_party/xla/xla/service/batched_gather_scatter_normalizer.cc b/third_party/xla/xla/service/batched_gather_scatter_normalizer.cc index c3b6c6d96a250a..c29a4cb65eb0cc 100644 --- a/third_party/xla/xla/service/batched_gather_scatter_normalizer.cc +++ b/third_party/xla/xla/service/batched_gather_scatter_normalizer.cc @@ -33,6 +33,7 @@ limitations under the License. #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/util.h" +#include "xla/xla_data.pb.h" namespace xla { diff --git a/third_party/xla/xla/service/batchnorm_expander.h b/third_party/xla/xla/service/batchnorm_expander.h index 0ae50afe13eb3c..15738efdc44158 100644 --- a/third_party/xla/xla/service/batchnorm_expander.h +++ b/third_party/xla/xla/service/batchnorm_expander.h @@ -18,6 +18,9 @@ limitations under the License. #include +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/pass/hlo_pass_interface.h" diff --git a/third_party/xla/xla/service/batchnorm_expander_test.cc b/third_party/xla/xla/service/batchnorm_expander_test.cc index e4bb01e9f486da..9f497942b55bcf 100644 --- a/third_party/xla/xla/service/batchnorm_expander_test.cc +++ b/third_party/xla/xla/service/batchnorm_expander_test.cc @@ -18,18 +18,17 @@ limitations under the License. #include #include +#include "xla/error_spec.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" -#include "xla/hlo/utils/hlo_matchers.h" -#include "xla/layout_util.h" -#include "xla/literal.h" #include "xla/service/hlo_parser.h" +#include "xla/shape.h" #include "xla/shape_util.h" #include "xla/test.h" #include "xla/tests/hlo_test_base.h" -#include "xla/types.h" #include "xla/xla_data.pb.h" +#include "tsl/platform/statusor.h" namespace xla { namespace { diff --git a/third_party/xla/xla/service/bfloat16_conversion_folding.h b/third_party/xla/xla/service/bfloat16_conversion_folding.h index c8bc39a98c4f74..1b71243e19027e 100644 --- a/third_party/xla/xla/service/bfloat16_conversion_folding.h +++ b/third_party/xla/xla/service/bfloat16_conversion_folding.h @@ -16,9 +16,14 @@ limitations under the License. #ifndef XLA_SERVICE_BFLOAT16_CONVERSION_FOLDING_H_ #define XLA_SERVICE_BFLOAT16_CONVERSION_FOLDING_H_ +#include "absl/container/flat_hash_set.h" +#include "absl/log/check.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/pass/hlo_pass_interface.h" #include "xla/service/float_support.h" +#include "xla/xla_data.pb.h" namespace xla { diff --git a/third_party/xla/xla/service/bfloat16_propagation.cc b/third_party/xla/xla/service/bfloat16_propagation.cc index bf3dfedf4a0cad..bd0650afa9ba8e 100644 --- a/third_party/xla/xla/service/bfloat16_propagation.cc +++ b/third_party/xla/xla/service/bfloat16_propagation.cc @@ -19,17 +19,26 @@ limitations under the License. #include "absl/cleanup/cleanup.h" #include "absl/container/flat_hash_set.h" #include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/literal.h" #include "xla/map_util.h" +#include "xla/service/float_support.h" +#include "xla/service/hlo_dataflow_analysis.h" #include "xla/service/hlo_dce.h" +#include "xla/service/hlo_value.h" #include "xla/service/tuple_simplifier.h" +#include "xla/shape.h" #include "xla/shape_tree.h" #include "xla/shape_util.h" +#include "xla/xla_data.pb.h" +#include "tsl/platform/errors.h" #include "tsl/platform/logging.h" +#include "tsl/platform/statusor.h" namespace xla { diff --git a/third_party/xla/xla/service/bfloat16_propagation.h b/third_party/xla/xla/service/bfloat16_propagation.h index 3f292823f6edee..a46a8091edc92b 100644 --- a/third_party/xla/xla/service/bfloat16_propagation.h +++ b/third_party/xla/xla/service/bfloat16_propagation.h @@ -21,11 +21,18 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/pass/hlo_pass_interface.h" #include "xla/service/float_support.h" #include "xla/service/hlo_dataflow_analysis.h" +#include "xla/service/hlo_value.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/xla_data.pb.h" namespace xla { diff --git a/third_party/xla/xla/service/bitcast_dtypes_expander.cc b/third_party/xla/xla/service/bitcast_dtypes_expander.cc index f4cc6809599cdd..6a5d12c25c32fa 100644 --- a/third_party/xla/xla/service/bitcast_dtypes_expander.cc +++ b/third_party/xla/xla/service/bitcast_dtypes_expander.cc @@ -15,23 +15,23 @@ limitations under the License. #include "xla/service/bitcast_dtypes_expander.h" -#include "absl/algorithm/container.h" -#include "absl/strings/str_join.h" +#include "absl/strings/str_format.h" #include "xla/client/lib/arithmetic.h" #include "xla/client/lib/broadcast.h" #include "xla/client/lib/constants.h" #include "xla/client/xla_builder.h" -#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" -#include "xla/hlo/ir/hlo_casting_utils.h" +#include "xla/client/xla_computation.h" +#include "xla/hlo/ir/hlo_clone_context.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" -#include "xla/literal_util.h" +#include "xla/primitive_util.h" +#include "xla/service/hlo_module_config.h" +#include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/status_macros.h" -#include "xla/types.h" +#include "xla/xla_data.pb.h" #include "tsl/platform/logging.h" +#include "tsl/platform/statusor.h" namespace xla { diff --git a/third_party/xla/xla/service/bitcast_dtypes_expander_test.cc b/third_party/xla/xla/service/bitcast_dtypes_expander_test.cc index a5dc3b882446cc..79734ee549dfc3 100644 --- a/third_party/xla/xla/service/bitcast_dtypes_expander_test.cc +++ b/third_party/xla/xla/service/bitcast_dtypes_expander_test.cc @@ -15,10 +15,11 @@ limitations under the License. #include "xla/service/bitcast_dtypes_expander.h" -#include "xla/hlo/utils/hlo_matchers.h" +#include +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_module.h" #include "xla/tests/filecheck.h" #include "xla/tests/hlo_test_base.h" -#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" namespace xla { From cdd45598542e0e371fa80004545b3017b5d122c6 Mon Sep 17 00:00:00 2001 From: akhilgoe <114951738+akhilgoe@users.noreply.github.com> Date: Thu, 26 Sep 2024 07:23:47 -0700 Subject: [PATCH 320/483] PR #15987: [XLA:CPU][oneDNN] Add Bias and Binary Add fusions for oneDNN Convolutions Imported from GitHub PR https://github.com/openxla/xla/pull/15987 This PR adds support for Bias Add and Binary Add fusions with oneDNN convolution. In addition, it adds tests to test the functionality. Copybara import of the project: -- 8d10441bcf6c2e23fb2e17e6321bc31e29f0772b by Akhil Goel : Add Bias Add and Residual Add fusions for oneDNN Convs Merging this change closes #15987 PiperOrigin-RevId: 679138801 --- .../xla/xla/service/change_op_data_type.cc | 7 +- third_party/xla/xla/service/cpu/BUILD | 1 + .../xla/xla/service/cpu/onednn_config.proto | 2 + .../cpu/onednn_contraction_rewriter.cc | 299 ++++++++++-------- .../service/cpu/onednn_contraction_rewriter.h | 18 ++ .../xla/xla/service/cpu/onednn_convolution.cc | 82 ++++- .../xla/xla/service/cpu/onednn_convolution.h | 9 + .../xla/xla/service/cpu/onednn_matmul.cc | 7 + .../xla/xla/service/cpu/onednn_matmul.h | 9 + third_party/xla/xla/service/cpu/onednn_util.h | 12 +- .../cpu/tests/onednn_convolution_test.cc | 85 ++++- 11 files changed, 368 insertions(+), 163 deletions(-) diff --git a/third_party/xla/xla/service/change_op_data_type.cc b/third_party/xla/xla/service/change_op_data_type.cc index 3c7875a2836ceb..365308ae24cd50 100644 --- a/third_party/xla/xla/service/change_op_data_type.cc +++ b/third_party/xla/xla/service/change_op_data_type.cc @@ -63,12 +63,7 @@ absl::StatusOr ChangeOpDataType::Run( continue; } #if defined(INTEL_MKL) && defined(ENABLE_ONEDNN_V3) - if (instr->opcode() == HloOpcode::kDot && - cpu::OneDnnContractionRewriter::ShouldRewriteDot(instr, true)) { - continue; - } - if (instr->opcode() == HloOpcode::kConvolution && - cpu::OneDnnContractionRewriter::ShouldRewriteConv(instr)) { + if (cpu::OneDnnContractionRewriter::ShouldRewriteInstr(instr, true)) { continue; } #endif // INTEL_MKL && ENABLE_ONEDNN_V3 diff --git a/third_party/xla/xla/service/cpu/BUILD b/third_party/xla/xla/service/cpu/BUILD index 08d8685feb374e..71779097649d7c 100644 --- a/third_party/xla/xla/service/cpu/BUILD +++ b/third_party/xla/xla/service/cpu/BUILD @@ -1761,6 +1761,7 @@ cc_library( copts = runtime_copts() + tsl_copts(), visibility = ["//visibility:public"], deps = [ + ":backend_config_proto_cc", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "@eigen_archive//:eigen3", diff --git a/third_party/xla/xla/service/cpu/onednn_config.proto b/third_party/xla/xla/service/cpu/onednn_config.proto index 9f38673eaacebd..44829a6857f1f9 100644 --- a/third_party/xla/xla/service/cpu/onednn_config.proto +++ b/third_party/xla/xla/service/cpu/onednn_config.proto @@ -113,4 +113,6 @@ message OneDnnConvolutionConfig { OneDnnFusionConfig fusions = 6; uint64 feature_groups = 7; + + OneDnnOptimizationConfig optimization_config = 8; } diff --git a/third_party/xla/xla/service/cpu/onednn_contraction_rewriter.cc b/third_party/xla/xla/service/cpu/onednn_contraction_rewriter.cc index 19122b393ce23b..01ffb340e07c1b 100644 --- a/third_party/xla/xla/service/cpu/onednn_contraction_rewriter.cc +++ b/third_party/xla/xla/service/cpu/onednn_contraction_rewriter.cc @@ -111,6 +111,16 @@ inline auto OneDnnMatmulInstr(HloInstruction** instr) { return m::CustomCall(instr, {"__onednn$matmul"}); } +inline auto OneDnnConvolutionInstr(HloInstruction** instr) { + return m::CustomCall(instr, {"__onednn$convolution"}); +} + +inline auto OneDnnFusibleInstr(HloInstruction** instr) { + return m::AnyOf( + m::CustomCall(instr, {"__onednn$matmul"}), + m::CustomCall(instr, {"__onednn$convolution"})); +} + inline auto ConvertBF16ToF32(HloInstruction** instr) { return m::Convert(m::Op(instr).WithElementType(PrimitiveType::BF16)) .WithElementType(PrimitiveType::F32); @@ -275,11 +285,12 @@ auto GELUActivation(HloInstruction* instr, HloInstruction** src) { return OneDnnFusionConfig::UNDEFINED; } -// OneDNN matmul can fuse add operation with automatic broadcasting along the -// addend's dimensions that are 1s. When compatible, Broadcast can be replaced -// by Bitcast, which is much cheaper. Compute new shape for the Bitcast. +// OneDNN matmul / convolution can fuse add operation with automatic +// broadcasting along the addend's dimensions that are 1s. When compatible, +// Broadcast can be replaced by Bitcast, which is much cheaper. Compute new +// shape for the Bitcast. absl::StatusOr AdjustBiasShape(const HloInstruction* broadcast_instr, - const Shape& dot_shape) { + const Shape& instr_shape) { if (broadcast_instr->opcode() != HloOpcode::kBroadcast) { return absl::InvalidArgumentError( "Hlo instruction is not a Broadcast insruction."); @@ -303,9 +314,9 @@ absl::StatusOr AdjustBiasShape(const HloInstruction* broadcast_instr, } } - // If rank(new_shape) > rank(dot), extra dimensions with value = 1 can be + // If rank(new_shape) > rank(instr), extra dimensions with value = 1 can be // deleted from the new_shape. - int64_t rank_difference = new_shape.rank() - dot_shape.rank(); + int64_t rank_difference = new_shape.rank() - instr_shape.rank(); auto new_dims = new_shape.dimensions(); std::vector dims_to_delete; for (int i = 0; i < rank_difference; ++i) { @@ -316,8 +327,8 @@ absl::StatusOr AdjustBiasShape(const HloInstruction* broadcast_instr, new_shape = ShapeUtil::DeleteDimensions(dims_to_delete, new_shape); // New shape for bias should satisfy the condition: - // rank(new_shape) <= rank(dot). - if (new_shape.rank() > dot_shape.rank()) { + // rank(new_shape) <= rank(instr). + if (new_shape.rank() > instr_shape.rank()) { return absl::CancelledError( "Bias shape could not be adjusted for a fusion."); } @@ -325,20 +336,20 @@ absl::StatusOr AdjustBiasShape(const HloInstruction* broadcast_instr, return new_shape; }; -inline bool IsOperandFusible(HloInstruction* operand, HloInstruction* dot) { - // Check if the operand's shape is compatible with matmul for fusion. +inline bool IsOperandFusible(HloInstruction* operand, HloInstruction* instr) { + // Check if the operand's shape is compatible for fusion. // An operand is fusable if - // 1. rank(operand) <= rank(dot) and + // 1. rank(operand) <= rank(instr) and // 2. Starting from the last dim in backward direction, the dimension // size of operand is either 1 or same to dot. auto operand_dims = operand->shape().dimensions(); - auto dot_dims = dot->shape().dimensions(); - if (operand_dims.size() > dot_dims.size()) return false; + auto instr_dims = instr->shape().dimensions(); + if (operand_dims.size() > instr_dims.size()) return false; int operand_idx = operand_dims.size() - 1; - int dot_idx = dot_dims.size() - 1; - for (; operand_idx >= 0; --operand_idx, --dot_idx) { + int instr_idx = instr_dims.size() - 1; + for (; operand_idx >= 0; --operand_idx, --instr_idx) { if (operand_dims[operand_idx] != 1 && - operand_dims[operand_idx] != dot_dims[dot_idx]) + operand_dims[operand_idx] != instr_dims[instr_idx]) return false; } return true; @@ -367,6 +378,7 @@ inline auto OptionalConvertAndBitcast(HloInstruction** optional_convert, bool OneDnnContractionRewriter::ShouldRewriteDot( const HloInstruction* dot_instr, bool before_layout_assignment) { + if (dot_instr->opcode() != HloOpcode::kDot) return false; // Currently, blocking control dependencies if (dot_instr->HasControlDependencies()) return false; if (!IsSupportedType(dot_instr->shape().element_type())) return false; @@ -429,6 +441,7 @@ bool OneDnnContractionRewriter::ShouldRewriteDot( bool OneDnnContractionRewriter::ShouldRewriteConv( const HloInstruction* conv_instr) { + if (conv_instr->opcode() != HloOpcode::kConvolution) return false; if (conv_instr->HasControlDependencies()) return false; if (!IsSupportedType(conv_instr->shape().element_type())) return false; if (conv_instr->batch_group_count() != 1) return false; @@ -566,14 +579,14 @@ class OneDnnContractionRewriteVisitor : public DfsHloRewriteVisitor { } absl::Status HandleAdd(HloInstruction* instr) override { - // Try to do a fusion for Dot(onednn-matmul) + Add. However, + // Try to fuse Add to the instr. However, // HLO Add instruction might receive the addends after additional // processing like Broadcast, Bitcast, Convert, etc. is applied to the raw // addends. Here, the following possible pattern is matched. // // clang-format off // - // Dot addend + // Dot / Conv addend // | | // v v // optional instructions optional instructions @@ -586,148 +599,154 @@ class OneDnnContractionRewriteVisitor : public DfsHloRewriteVisitor { // // clang-format on - HloInstruction *addend_intermediate, *dot; - HloInstruction* optional_dot_bitcast = nullptr; - HloInstruction* optional_dot_convert = nullptr; + HloInstruction *addend_intermediate, *contraction; + HloInstruction* optional_contraction_bitcast = nullptr; + HloInstruction* optional_contraction_convert = nullptr; auto pattern = m::AddAnyOrder( &instr, - OptionalConvertAndBitcast(&optional_dot_convert, &optional_dot_bitcast, - OneDnnMatmulInstr(&dot)) + OptionalConvertAndBitcast(&optional_contraction_convert, + &optional_contraction_bitcast, + OneDnnFusibleInstr(&contraction)) .WithOneUser(), m::Op(&addend_intermediate)); if (Match(instr, pattern)) { - if (!IsSupportedType(dot->shape().element_type())) - return absl::OkStatus(); - // TODO(intel-tf): Remove the condition below when the fusion Dot + - // Add(bias) + Add(e.g., residual) is enabled. - if (!dot->backend_config() - ->mutable_onednn_matmul_config() - ->mutable_fusions() - ->ops() - .empty() && - dot->backend_config() - ->mutable_onednn_matmul_config() - ->mutable_fusions() - ->ops(0) == OneDnnFusionConfig::BIAS) { - return absl::OkStatus(); - } - std::vector new_operands; - for (auto operand : dot->operands()) { - new_operands.push_back(operand); - } + HANDLE_OP_INTERNAL(HandleAddInternal, contraction, instr, + addend_intermediate, optional_contraction_convert, + optional_contraction_bitcast); + } - // At this point, the addend could have one of the following - // possiblities that the current fusion can handle: - // - // - addend -> Convert -> Broadcast -> Add - // - addend -> Broadcast -> Convert -> Add - // - addend -> Convert - // - addend -> Broadcast - // - addend - // - // Hunt for addend through possible sequences above and check the addend - // is compatible to onednn-matmul fusion. - HloInstruction* addend = nullptr; - HloInstruction* optional_addend_broadcast = nullptr; - auto addend_pattern = m::AnyOf( - m::Broadcast(&optional_addend_broadcast, - m::Convert(&addend, m::Op())), - m::Convert(m::Broadcast(&optional_addend_broadcast, m::Op(&addend))), - m::Convert(&addend, m::Op()), - m::Broadcast(&optional_addend_broadcast, m::Op(&addend)), - m::Op(&addend)); - if (!Match(addend_intermediate, addend_pattern)) return absl::OkStatus(); - - if (optional_addend_broadcast && addend->shape().rank() != 1) { - auto new_shape = - AdjustBiasShape(optional_addend_broadcast, dot->shape()); - if (new_shape.ok()) { - addend = addend->AddInstruction( - HloInstruction::CreateBitcast(new_shape.value(), addend)); - } else { - VLOG(2) << new_shape.status(); - return absl::OkStatus(); - } - } + return absl::OkStatus(); + } - // Validate addend for fusion. - if (IsSupportedType(addend->shape().element_type()) && - IsOperandFusible(addend, dot)) { - new_operands.push_back(addend); + template + absl::Status HandleAddInternal(HloInstruction* contraction, + HloInstruction* instr, + HloInstruction* addend_intermediate, + HloInstruction* optional_contraction_convert, + HloInstruction* optional_contraction_bitcast) { + if (!IsSupportedType(contraction->shape().element_type())) + return absl::OkStatus(); + // TODO(intel-tf): Remove the condition below when the fusion Contraction + + // Add(bias) + Add(e.g., residual) is enabled. + auto contraction_config = contraction->backend_config(); + if (!GetKernelConfig(&contraction_config) + ->mutable_fusions() + ->ops() + .empty() && + GetKernelConfig(&contraction_config) + ->mutable_fusions() + ->ops(0) == OneDnnFusionConfig::BIAS) { + return absl::OkStatus(); + } + std::vector new_operands; + for (auto operand : contraction->operands()) { + new_operands.push_back(operand); + } + + // At this point, the addend could have one of the following + // possiblities that the current fusion can handle: + // + // - addend -> Convert -> Broadcast -> Add + // - addend -> Broadcast -> Convert -> Add + // - addend -> Convert + // - addend -> Broadcast + // - addend + // + // Hunt for addend through possible sequences above and check the addend + // is compatible for onednn fusion. + HloInstruction* addend = nullptr; + HloInstruction* optional_addend_broadcast = nullptr; + auto addend_pattern = m::AnyOf( + m::Broadcast(&optional_addend_broadcast, m::Convert(&addend, m::Op())), + m::Convert(m::Broadcast(&optional_addend_broadcast, m::Op(&addend))), + m::Convert(&addend, m::Op()), + m::Broadcast(&optional_addend_broadcast, m::Op(&addend)), + m::Op(&addend)); + if (!Match(addend_intermediate, addend_pattern)) return absl::OkStatus(); + + if (optional_addend_broadcast && addend->shape().rank() != 1) { + auto new_shape = + AdjustBiasShape(optional_addend_broadcast, contraction->shape()); + if (new_shape.ok()) { + addend = addend->AddInstruction( + HloInstruction::CreateBitcast(new_shape.value(), addend)); } else { + VLOG(2) << new_shape.status(); return absl::OkStatus(); } + } - // TODO(intel-tf): Remove this restriction once oneDNN has an optimized - // implementation for broadcasted add across all dimensions. - OneDnnFusionConfig_FusionKind kind = OneDnnFusionConfig::UNDEFINED; - kind = (addend->shape().rank() == 1) - ? (dot->backend_config() - ->mutable_onednn_matmul_config() - ->fusions() - .ops() - .empty() - ? OneDnnFusionConfig::BIAS - : OneDnnFusionConfig::UNDEFINED) - : OneDnnFusionConfig::BINARY_ADD; - if (kind == OneDnnFusionConfig::UNDEFINED) return absl::OkStatus(); + // Validate addend for fusion. + if (IsSupportedType(addend->shape().element_type()) && + IsOperandFusible(addend, contraction)) { + new_operands.push_back(addend); + } else { + return absl::OkStatus(); + } - auto matmul_call = Cast(instr->AddInstruction( - dot->CloneWithNewOperands(dot->shape(), new_operands))); + auto custom_call = Cast(instr->AddInstruction( + contraction->CloneWithNewOperands(contraction->shape(), new_operands))); - auto backend_config = matmul_call->backend_config(); - backend_config->mutable_onednn_matmul_config() - ->mutable_fusions() - ->add_ops(kind); + auto backend_config = custom_call->backend_config(); - if (optional_addend_broadcast) { - backend_config->mutable_onednn_matmul_config() - ->mutable_optimization_config() - ->set_bias_broadcast(true); - } - TF_RETURN_IF_ERROR(matmul_call->set_backend_config(*backend_config)); + // TODO(intel-tf): Remove this restriction once oneDNN has an optimized + // implementation for broadcasted add across all dimensions. + OneDnnFusionConfig_FusionKind kind = OneDnnFusionConfig::UNDEFINED; + kind = + (addend->shape().rank() == 1) + ? (GetKernelConfig(&backend_config)->fusions().ops().empty() + ? OneDnnFusionConfig::BIAS + : OneDnnFusionConfig::UNDEFINED) + : OneDnnFusionConfig::BINARY_ADD; + if (kind == OneDnnFusionConfig::UNDEFINED) return absl::OkStatus(); - HloInstruction* new_instr; - // If matched pattern has custom-call -> bitcast -> add, then we need to - // insert bitcast after the new fusion to maintain the correct shape - // (new-custom-call -> bitcast). Also, this will optionally be followed - // by -> convert for bf16 case to avoid datatype mismatch. - if (optional_dot_bitcast != nullptr && - optional_dot_bitcast->opcode() == HloOpcode::kBitcast) { - if (optional_dot_convert != nullptr && - optional_dot_convert->opcode() == HloOpcode::kConvert) { - auto bitcast_call = - matmul_call->AddInstruction(HloInstruction::CreateBitcast( - ShapeUtil::ChangeElementType( - instr->shape(), matmul_call->shape().element_type()), - matmul_call)); - new_instr = - bitcast_call->AddInstruction(HloInstruction::CreateConvert( - ShapeUtil::ChangeElementType( - bitcast_call->shape(), - optional_dot_convert->shape().element_type()), - bitcast_call)); - } else { - new_instr = matmul_call->AddInstruction( - HloInstruction::CreateBitcast(instr->shape(), matmul_call)); - } + GetKernelConfig(&backend_config)->mutable_fusions()->add_ops(kind); + + if (optional_addend_broadcast) { + GetKernelConfig(&backend_config) + ->mutable_optimization_config() + ->set_bias_broadcast(true); + } + TF_RETURN_IF_ERROR(custom_call->set_backend_config(*backend_config)); + + HloInstruction* new_instr; + // If matched pattern has custom-call -> bitcast -> add, then we need to + // insert bitcast after the new fusion to maintain the correct shape + // (new-custom-call -> bitcast). Also, this will optionally be followed + // by -> convert for bf16 case to avoid datatype mismatch. + if (optional_contraction_bitcast != nullptr && + optional_contraction_bitcast->opcode() == HloOpcode::kBitcast) { + if (optional_contraction_convert != nullptr && + optional_contraction_convert->opcode() == HloOpcode::kConvert) { + auto bitcast_call = + custom_call->AddInstruction(HloInstruction::CreateBitcast( + ShapeUtil::ChangeElementType( + instr->shape(), custom_call->shape().element_type()), + custom_call)); + new_instr = bitcast_call->AddInstruction(HloInstruction::CreateConvert( + ShapeUtil::ChangeElementType( + bitcast_call->shape(), + optional_contraction_convert->shape().element_type()), + bitcast_call)); } else { - if (optional_dot_convert != nullptr && - optional_dot_convert->opcode() == HloOpcode::kConvert) { - new_instr = matmul_call->AddInstruction(HloInstruction::CreateConvert( - ShapeUtil::ChangeElementType( - matmul_call->shape(), - optional_dot_convert->shape().element_type()), - matmul_call)); - } else { - new_instr = matmul_call; - } + new_instr = custom_call->AddInstruction( + HloInstruction::CreateBitcast(instr->shape(), custom_call)); + } + } else { + if (optional_contraction_convert != nullptr && + optional_contraction_convert->opcode() == HloOpcode::kConvert) { + new_instr = custom_call->AddInstruction(HloInstruction::CreateConvert( + ShapeUtil::ChangeElementType( + custom_call->shape(), + optional_contraction_convert->shape().element_type()), + custom_call)); + } else { + new_instr = custom_call; } - TF_RETURN_IF_ERROR(ReplaceInstruction(instr, new_instr)); } - + TF_RETURN_IF_ERROR(ReplaceInstruction(instr, new_instr)); return absl::OkStatus(); } diff --git a/third_party/xla/xla/service/cpu/onednn_contraction_rewriter.h b/third_party/xla/xla/service/cpu/onednn_contraction_rewriter.h index 503d8a8ee25630..2706d05d1ef920 100644 --- a/third_party/xla/xla/service/cpu/onednn_contraction_rewriter.h +++ b/third_party/xla/xla/service/cpu/onednn_contraction_rewriter.h @@ -50,12 +50,30 @@ class OneDnnContractionRewriter : public HloModulePass { static bool ShouldRewriteDot(const HloInstruction* dot_instr, bool before_layout_assignment = false); static bool ShouldRewriteConv(const HloInstruction* conv_instr); + static bool ShouldRewriteInstr(const HloInstruction* instr, + bool before_layout_assignment = false) { + return ShouldRewriteDot(instr, before_layout_assignment) || + ShouldRewriteConv(instr); + } private: int intra_op_parallelism_; const tsl::thread::ThreadPool* compile_threadpool_; }; +#define HANDLE_OP_INTERNAL(internal_callee, contraction, ...) \ + switch (contraction->backend_config() \ + ->backend_config_oneof_case()) { \ + case BackendConfig::BackendConfigOneofCase::kOnednnMatmulConfig: \ + return internal_callee< \ + BackendConfig::BackendConfigOneofCase::kOnednnMatmulConfig>( \ + contraction, __VA_ARGS__); \ + default: \ + return internal_callee< \ + BackendConfig::BackendConfigOneofCase::kOnednnConvConfig>( \ + contraction, __VA_ARGS__); \ + } + } // namespace cpu } // namespace xla diff --git a/third_party/xla/xla/service/cpu/onednn_convolution.cc b/third_party/xla/xla/service/cpu/onednn_convolution.cc index 7ed7987137ad60..30e91fb4aae3e7 100644 --- a/third_party/xla/xla/service/cpu/onednn_convolution.cc +++ b/third_party/xla/xla/service/cpu/onednn_convolution.cc @@ -63,6 +63,13 @@ dnnl::memory::format_tag GetFormatTag(const int dims) { : dnnl::memory::format_tag::any; } +template <> +typename PrimitiveTrait::pointer_type +GetKernelConfig( + absl::StatusOr* backend_config) { + return (*backend_config)->mutable_onednn_conv_config(); +} + ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_OneDnnConvolution( void* result, void** args) { // args[0]: ptr to nargs @@ -154,7 +161,6 @@ ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_OneDnnConvolution( MemrefInfo ker_minfo(args[arg_indx++]); MemrefInfo res_minfo(result); - // Permute memory descriptors auto inp_md = inp_minfo.GetOneDnnMemDesc(); auto ker_md = ker_minfo.GetOneDnnMemDesc(); auto res_md = res_minfo.GetOneDnnMemDesc(); @@ -174,6 +180,50 @@ ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_OneDnnConvolution( new_ker_md = new_ker_md.reshape(corr_dims); } + const int64_t num_fused_operands = num_args - arg_indx; + std::vector fused_mds; + std::vector fused_bufs; + for (int64_t i = 0; i < num_fused_operands; ++i) { + MemrefInfo operand_minfo(args[arg_indx++]); + fused_mds.push_back(operand_minfo.GetOneDnnMemDesc()); + fused_bufs.push_back(operand_minfo.Data()); + } + + std::vector> postop_args; + + auto bias_md = memory::desc(); + + dnnl::post_ops post_ops; + int fused_operand_idx = 0; + for (auto& fused_op : conv_config.fusions().ops()) { + switch (fused_op) { + case OneDnnFusionConfig::BIAS: { + bias_md = fused_mds.at(fused_operand_idx); + postop_args.emplace_back( + DNNL_ARG_BIAS, + dnnl::memory(bias_md, cpu_engine, fused_bufs[fused_operand_idx])); + fused_operand_idx++; + } break; + case OneDnnFusionConfig::BINARY_ADD: { + auto binary_md = fused_mds.at(fused_operand_idx); + binary_md = binary_md.permute_axes(out_axes); + auto arg_idx = + DNNL_ARG_ATTR_MULTIPLE_POST_OP(post_ops.len()) | DNNL_ARG_SRC_1; + postop_args.emplace_back( + arg_idx, + dnnl::memory(binary_md, cpu_engine, fused_bufs[fused_operand_idx])); + post_ops.append_binary(dnnl::algorithm::binary_add, binary_md); + fused_operand_idx++; + } break; + default: + LOG(FATAL) + << __FILE__ << ":" << __LINE__ + << " Attempt to call OneDNN Convolution runtime library with " + "unsupported post op." + << std::endl; + } + } + auto any_ker_md = memory::desc(new_ker_md.get_dims(), new_ker_md.get_data_type(), dnnl::memory::format_tag::any); @@ -187,37 +237,41 @@ ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_OneDnnConvolution( XLA_LIGHTWEIGHT_CHECK(num_args == arg_indx); dnnl::primitive_attr attrs; + if (post_ops.len() > 0) { + attrs.set_post_ops(post_ops); + } + + auto conv_pd = std::make_unique( + cpu_engine, prop_kind::forward_inference, algorithm::convolution_direct, + any_inp_md, any_ker_md, bias_md, any_res_md, strides, rhs_dilations, + pad_left, pad_right, attrs); auto inp_mem = memory(new_inp_md, cpu_engine, inp_minfo.Data()); auto ker_mem = memory(new_ker_md, cpu_engine, ker_minfo.Data()); auto res_mem = memory(new_res_md, cpu_engine, res_minfo.Data()); - auto conv_pd = convolution_forward::primitive_desc( - cpu_engine, prop_kind::forward_inference, algorithm::convolution_direct, - any_inp_md, any_ker_md, any_res_md, strides, rhs_dilations, pad_left, - pad_right, attrs); - - auto new_inp_mem = (conv_pd.src_desc() == inp_mem.get_desc()) + auto new_inp_mem = (conv_pd->src_desc() == inp_mem.get_desc()) ? inp_mem - : ReorderMemory(cpu_engine, conv_pd.src_desc(), + : ReorderMemory(cpu_engine, conv_pd->src_desc(), inp_mem, onednn_stream); - auto new_ker_mem = (conv_pd.weights_desc() == ker_mem.get_desc()) + auto new_ker_mem = (conv_pd->weights_desc() == ker_mem.get_desc()) ? ker_mem - : ReorderMemory(cpu_engine, conv_pd.weights_desc(), + : ReorderMemory(cpu_engine, conv_pd->weights_desc(), ker_mem, onednn_stream); - auto new_res_mem = (conv_pd.dst_desc() == res_mem.get_desc()) + auto new_res_mem = (conv_pd->dst_desc() == res_mem.get_desc()) ? res_mem - : memory(conv_pd.dst_desc(), cpu_engine); + : memory(conv_pd->dst_desc(), cpu_engine); - auto conv_prim = convolution_forward(conv_pd); + auto conv_prim = convolution_forward(*conv_pd); std::unordered_map conv_args{{DNNL_ARG_SRC, new_inp_mem}, {DNNL_ARG_WEIGHTS, new_ker_mem}, {DNNL_ARG_DST, new_res_mem}}; + conv_args.insert(postop_args.begin(), postop_args.end()); conv_prim.execute(onednn_stream, conv_args); - if (conv_pd.dst_desc() == res_mem.get_desc()) { + if (conv_pd->dst_desc() == res_mem.get_desc()) { res_mem = new_res_mem; } else { dnnl::reorder(new_res_mem, res_mem) diff --git a/third_party/xla/xla/service/cpu/onednn_convolution.h b/third_party/xla/xla/service/cpu/onednn_convolution.h index 19cbbe2e2a371a..657cddffb21afd 100644 --- a/third_party/xla/xla/service/cpu/onednn_convolution.h +++ b/third_party/xla/xla/service/cpu/onednn_convolution.h @@ -17,13 +17,22 @@ limitations under the License. #define XLA_SERVICE_CPU_ONEDNN_CONVOLUTION_H_ #if defined(INTEL_MKL) && defined(ENABLE_ONEDNN_V3) +#include "xla/service/cpu/onednn_util.h" + namespace xla { namespace cpu { +constexpr auto kOnednnConvConfig = BackendConfigOneofCase::kOnednnConvConfig; + extern "C" { extern void __xla_cpu_runtime_OneDnnConvolution(void* result, void** args); } // extern "C" +template <> +struct PrimitiveTrait { + using pointer_type = xla::cpu::OneDnnConvolutionConfig*; +}; + } // namespace cpu } // namespace xla diff --git a/third_party/xla/xla/service/cpu/onednn_matmul.cc b/third_party/xla/xla/service/cpu/onednn_matmul.cc index 1b2dbee81c661b..b0b506d8fa2f55 100644 --- a/third_party/xla/xla/service/cpu/onednn_matmul.cc +++ b/third_party/xla/xla/service/cpu/onednn_matmul.cc @@ -230,6 +230,13 @@ std::unique_ptr CreateMatMulPrimDesc( weights_md, output_md, fused_mds, matmul_config); } +template <> +typename PrimitiveTrait::pointer_type +GetKernelConfig( + absl::StatusOr* backend_config) { + return (*backend_config)->mutable_onednn_matmul_config(); +} + template <> std::unique_ptr CreateOneDnnPrimDesc(HloInstruction* instr) { diff --git a/third_party/xla/xla/service/cpu/onednn_matmul.h b/third_party/xla/xla/service/cpu/onednn_matmul.h index 09a2d6752ec29b..bf452e9d9f0518 100644 --- a/third_party/xla/xla/service/cpu/onednn_matmul.h +++ b/third_party/xla/xla/service/cpu/onednn_matmul.h @@ -19,11 +19,15 @@ limitations under the License. #include "dnnl.hpp" #include "xla/service/cpu/backend_config.pb.h" +#include "xla/service/cpu/onednn_util.h" #include "xla/shape.h" namespace xla { namespace cpu { +constexpr auto kOnednnMatmulConfig = + BackendConfigOneofCase::kOnednnMatmulConfig; + Shape OneDnnMatMulOptWeightsShape(const Shape& input_shape, const Shape& weights_shape, const Shape& bias_shape, @@ -36,6 +40,11 @@ extern void __xla_cpu_runtime_OneDnnMatMul(void* result, void* scratch, extern void __xla_cpu_runtime_OneDnnMatMulReorder(void* result, void** args); } // extern "C" +template <> +struct PrimitiveTrait { + using pointer_type = xla::cpu::OneDnnMatMulConfig*; +}; + } // namespace cpu } // namespace xla diff --git a/third_party/xla/xla/service/cpu/onednn_util.h b/third_party/xla/xla/service/cpu/onednn_util.h index aaba304fc083fa..09bc5efb4b7574 100644 --- a/third_party/xla/xla/service/cpu/onednn_util.h +++ b/third_party/xla/xla/service/cpu/onednn_util.h @@ -23,6 +23,7 @@ limitations under the License. #include "unsupported/Eigen/CXX11/Tensor" #include "dnnl.hpp" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/service/cpu/backend_config.pb.h" #include "xla/tsl/util/onednn_threadpool.h" #include "xla/xla_data.pb.h" #include "tsl/platform/cpu_info.h" @@ -58,11 +59,20 @@ dnnl::stream MakeOneDnnStream( const dnnl::engine& cpu_engine, dnnl::threadpool_interop::threadpool_iface* thread_pool); -// This template function must have explicit specialization at the definition +typedef BackendConfig::BackendConfigOneofCase BackendConfigOneofCase; + +// These template functions must have explicit specialization at the definition // site. template std::unique_ptr CreateOneDnnPrimDesc(HloInstruction*); +template +struct PrimitiveTrait; + +template +typename PrimitiveTrait::pointer_type GetKernelConfig( + absl::StatusOr*); + } // namespace cpu } // namespace xla diff --git a/third_party/xla/xla/service/cpu/tests/onednn_convolution_test.cc b/third_party/xla/xla/service/cpu/tests/onednn_convolution_test.cc index 6bceebc7343c8e..d52d88c5bd062a 100644 --- a/third_party/xla/xla/service/cpu/tests/onednn_convolution_test.cc +++ b/third_party/xla/xla/service/cpu/tests/onednn_convolution_test.cc @@ -48,6 +48,30 @@ class ConvolutionTest : public HloTestBase { ; CHECK-DAG: } ; CHECK: } )"; + + const char* conv_rewrite_bias_str_ = R"( + ; CHECK: custom_call_target="__onednn$convolution", + ; CHECK: backend_config={ + ; CHECK-DAG: "outer_dimension_partitions":[], + ; CHECK-DAG: "onednn_conv_config":{ + ; CHECK-DAG: "fusions":{ + ; CHECK-DAG: "ops":["BIAS"] + ; CHECK-DAG: } + ; CHECK-DAG: } + ; CHECK: } + )"; + + const char* fused_convolution_binary_add_ = R"( + ; CHECK: custom_call_target="__onednn$convolution", + ; CHECK: backend_config={ + ; CHECK-DAG: "outer_dimension_partitions":[], + ; CHECK-DAG: "onednn_conv_config":{ + ; CHECK-DAG: "fusions":{ + ; CHECK-DAG: "ops":["BINARY_ADD"] + ; CHECK-DAG: } + ; CHECK-DAG: } + ; CHECK: } + )"; }; TEST_F(ConvolutionTest, Simple2DTestF32) { @@ -55,9 +79,9 @@ TEST_F(ConvolutionTest, Simple2DTestF32) { HloModule convolution.test.f32 ENTRY convolution.test.f32 { - arg.0 = f32[1,22,22,1] parameter(0), parameter_replication={false} + arg.0 = f32[1,22,22,1] parameter(0) reshape.0 = f32[1,22,22,1] reshape(arg.0) - arg.1 = f32[8,8,1,1] parameter(1), parameter_replication={false} + arg.1 = f32[8,8,1,1] parameter(1) reshape.1 = f32[8,8,1,1] reshape(arg.1) convolution.0 = f32[1,11,11,1] convolution(reshape.0, reshape.1), window={size=8x8 stride=2x2 pad=3_3x3_3}, dim_labels=b01f_01io->b01f reshape.2 = f32[1,11,11,1] reshape(convolution.0) @@ -94,6 +118,7 @@ TEST_F(ConvolutionTest, Simple2DTestF16) { const char* convolution_module_str = R"( HloModule convolution.test.f16 + ENTRY convolution.test.bf16 { p0 = f16[8,4,5,5,1] parameter(0) p1 = f16[3,3,3,1,32] parameter(1) @@ -104,6 +129,62 @@ TEST_F(ConvolutionTest, Simple2DTestF16) { MatchOptimizedHlo(convolution_module_str, conv_rewrite_str_); } +TEST_F(ConvolutionTest, Conv3DWithBiasBF16) { + const char* convolution_module_str = R"( + HloModule convolution.test.with.bias.relu.bf16.3D + + ENTRY TestComputation { + arg.0 = bf16[15,4,5,5,28] parameter(0) + arg.1 = bf16[3,3,3,28,64] parameter(1) + conv = bf16[15,4,5,5,64] convolution(arg.0, arg.1), window={size=3x3x3 pad=1_1x1_1x1_1}, dim_labels=b012f_012io->b012f + bias = bf16[64] parameter(2) + broadcasted_bias = bf16[15,4,5,5,64] broadcast(bias), dimensions={4} + ROOT add = bf16[15,4,5,5,64] add(conv, broadcasted_bias) +})"; + EXPECT_TRUE(RunAndCompare(convolution_module_str, ErrorSpec{0.01, 0.01})); + MatchOptimizedHlo(convolution_module_str, conv_rewrite_bias_str_); +} + +TEST_F(ConvolutionTest, SimpleTestF32WithBinaryAddFusion1) { + const char* convolution_module_str = R"( + HloModule conv.binaryadd.test.f32 + + ENTRY matmul.biasadd.test.f32 { + arg0.1 = f32[1,22,22,1] parameter(0) + constant.3 = f32[] constant(1) + broadcast.4 = f32[8,8,1,1] broadcast(constant.3), dimensions={} + convolution.0 = f32[1,11,11,1] convolution(arg0.1, broadcast.4), window={size=8x8 stride=2x2 pad=3_3x3_3}, dim_labels=b01f_01io->b01f + constant.5 = f32[] constant(15) + broadcast.6 = f32[1] broadcast(constant.5), dimensions={} + broadcast.9 = f32[1,11,11,1] broadcast(broadcast.6), dimensions={3} + ROOT add.10 = f32[1,11,11,1] add(convolution.0, broadcast.9) + })"; + + EXPECT_TRUE(RunAndCompare(convolution_module_str, ErrorSpec{1e-4, 1e-4})); + MatchOptimizedHlo(convolution_module_str, fused_convolution_binary_add_); +} + +// This test should match BIAS + Residual Add when the residual add fusion is +// re-enabled. +TEST_F(ConvolutionTest, SimpleTestBF16WithBiasAndAddFusion) { + const char* convolution_module_str = R"( + HloModule convolution.add.test.bf16 + + ENTRY convolution.add.test.bf16 { + arg0.1 = bf16[1,22,22,1] parameter(0) + arg0.2 = bf16[8,8,1,10] parameter(1) + convolution.0 = bf16[1,11,11,10] convolution(arg0.1, arg0.2), window={size=8x8 stride=2x2 pad=3_3x3_3}, dim_labels=b01f_01io->b01f + const.0 = bf16[10] constant(15) + bcast.1 = bf16[1,11,11,10] broadcast(const.0), dimensions={3} + add.0 = bf16[1,11,11,10] add(convolution.0, bcast.1) + const.1 = bf16[1,11,11,10] constant({...}) + ROOT add.1 = bf16[1,11,11,10] add(add.0, const.1) + })"; + + EXPECT_TRUE(RunAndCompare(convolution_module_str, ErrorSpec{1e-2, 1e-2})); + MatchOptimizedHlo(convolution_module_str, conv_rewrite_bias_str_); +} + } // namespace cpu } // namespace xla From 5cdf0f123915bb437cfe76b2000a544c6b17682d Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 26 Sep 2024 07:31:27 -0700 Subject: [PATCH 321/483] [XLA:GPU] Add fp8 layout support to assign contrasting dim to be minor most. This is important for performance both for Triton and cuBLASLT FP8 Gemms. Due to GPU kernel constraints, XLA inserts an additional expensive transpose operation before the quantized gemm. PiperOrigin-RevId: 679141198 --- .../gpu/transforms/layout_assignment.cc | 5 +++- .../gpu/transforms/layout_assignment_test.cc | 29 +++++++++++++++++++ 2 files changed, 33 insertions(+), 1 deletion(-) diff --git a/third_party/xla/xla/service/gpu/transforms/layout_assignment.cc b/third_party/xla/xla/service/gpu/transforms/layout_assignment.cc index 99f306cfc64241..d63b45fc550847 100644 --- a/third_party/xla/xla/service/gpu/transforms/layout_assignment.cc +++ b/third_party/xla/xla/service/gpu/transforms/layout_assignment.cc @@ -361,8 +361,11 @@ absl::Status GpuLayoutAssignment::AddBackendConstraints( output_shape.dimensions_size() == 2 && lhs_shape.dimensions_size() == 2 && rhs_shape.dimensions_size() == 2); + bool is_fp8_to_fp8 = + (lhs_shape.element_type() == PrimitiveType::F8E4M3FN && + rhs_shape.element_type() == PrimitiveType::F8E4M3FN); - if (is_s8_to_s32 || + if (is_s8_to_s32 || is_fp8_to_fp8 || (is_bf16_to_bf16 && debug_options.xla_gpu_ensure_minor_dot_contraction_dims())) { TF_RETURN_IF_ERROR(SetOperandMajorToMinorLayout( diff --git a/third_party/xla/xla/service/gpu/transforms/layout_assignment_test.cc b/third_party/xla/xla/service/gpu/transforms/layout_assignment_test.cc index 4dbd453e1d4850..a6b8f07e41065b 100644 --- a/third_party/xla/xla/service/gpu/transforms/layout_assignment_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/layout_assignment_test.cc @@ -655,6 +655,35 @@ ENTRY main { LayoutUtil::MakeLayout({1, 3, 2, 0}).minor_to_major()); } +TEST_F(LayoutAssignmentTest, AutoLayoutE4M3ContractingMinorFirst) { + const char* hlo = R"( + + HloModule jit_dot_general_f8e4m3fn + + ENTRY main { + p0 = f8e4m3fn[128,5120] parameter(0) + p1 = f8e4m3fn[5120,10240] parameter(1) + ROOT dot = f32[128,10240] dot(p0, p1), lhs_contracting_dims={1}, rhs_contracting_dims={0} + } + )"; + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr m, + ParseAndReturnUnverifiedModule( + hlo, {}, HloParserOptions().set_fill_missing_layouts(false))); + ComputationLayout computation_layout( + m->entry_computation()->ComputeProgramShape(), + /*ignore_layouts=*/false); + GpuLayoutAssignment layout_assignment( + &computation_layout, GetGpuComputeCapability(), GetDnnVersion()); + EXPECT_THAT(layout_assignment.Run(m.get()), IsOkAndHolds(true)); + EXPECT_THAT( + m->entry_computation()->root_instruction(), + GmockMatch( + m::Dot(m::Parameter(0).WithShape(F8E4M3FN, {128, 5120}, {1, 0}), + m::Parameter(1).WithShape(F8E4M3FN, {5120, 10240}, {0, 1})) + .WithShape(F32, {128, 10240}, {1, 0}))); +} + TEST_F(LayoutAssignmentTest, VariadicReduceSameOperandLayout) { const char* module_str = R"( HloModule variadic_reduce From 71eb0d54d61ca3b5f2a990ad2dfaad72e55c35d4 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 26 Sep 2024 07:44:25 -0700 Subject: [PATCH 322/483] Enable polling for error from coordination service at startup by default. PiperOrigin-RevId: 679145273 --- third_party/xla/xla/pjrt/distributed/client.h | 5 ++--- .../xla/pjrt/distributed/client_server_test.cc | 18 ++++++++---------- 2 files changed, 10 insertions(+), 13 deletions(-) diff --git a/third_party/xla/xla/pjrt/distributed/client.h b/third_party/xla/xla/pjrt/distributed/client.h index 2387fe6dd452f5..bb177716797659 100644 --- a/third_party/xla/xla/pjrt/distributed/client.h +++ b/third_party/xla/xla/pjrt/distributed/client.h @@ -104,9 +104,8 @@ class DistributedRuntimeClient { // Whether the client should send a request to wait for error from the // coordination service at the startup. - // TODO(b/355706798): Enable this by default once we confirm this works for - // all cases and eventually remove this option. - bool poll_for_error_from_service_at_startup = false; + // TODO(b/355706798): eventually remove this option. + bool poll_for_error_from_service_at_startup = true; }; virtual ~DistributedRuntimeClient() = default; diff --git a/third_party/xla/xla/pjrt/distributed/client_server_test.cc b/third_party/xla/xla/pjrt/distributed/client_server_test.cc index 1b55bab10e1caa..462680d5fc764a 100644 --- a/third_party/xla/xla/pjrt/distributed/client_server_test.cc +++ b/third_party/xla/xla/pjrt/distributed/client_server_test.cc @@ -379,7 +379,8 @@ TEST_F(ClientServerTest, ZeroInitTimeoutShouldStillWaitForOtherTasks) { } } -TEST_F(ClientServerTest, ClientsTerminateShutdownIfAnyClientGoesAway) { +TEST_F(ClientServerTest, + ClientsTerminateShutdownIfAnyClientGoesAway_WithoutErrorPolling) { int num_nodes = 3; StartService(num_nodes); @@ -425,8 +426,7 @@ TEST_F(ClientServerTest, ClientsTerminateShutdownIfAnyClientGoesAway) { } } -TEST_F(ClientServerTest, - ClientsTerminateShutdownIfAnyClientGoesAway_WithErrorPolling) { +TEST_F(ClientServerTest, ClientsTerminateShutdownIfAnyClientGoesAway) { int num_nodes = 3; StartService(num_nodes); @@ -435,7 +435,6 @@ TEST_F(ClientServerTest, client_options.shutdown_on_destruction = node_id != 0; client_options.missed_heartbeat_callback = [&](absl::Status status, bool coordinator_initiated) {}; - client_options.poll_for_error_from_service_at_startup = true; auto client = GetClient(node_id, client_options); TF_RETURN_IF_ERROR(client->Connect()); @@ -466,7 +465,7 @@ TEST_F(ClientServerTest, } } -TEST_F(ClientServerTest, ClientsShutdownSuccessfully_WithErrorPolling) { +TEST_F(ClientServerTest, ClientsShutdownSuccessfully) { int num_nodes = 3; StartService(num_nodes); @@ -475,7 +474,6 @@ TEST_F(ClientServerTest, ClientsShutdownSuccessfully_WithErrorPolling) { client_options.shutdown_on_destruction = true; client_options.missed_heartbeat_callback = [&](absl::Status status, bool coordinator_initiated) {}; - client_options.poll_for_error_from_service_at_startup = true; auto client = GetClient(node_id, client_options); TF_RETURN_IF_ERROR(client->Connect()); @@ -497,8 +495,7 @@ TEST_F(ClientServerTest, ClientsShutdownSuccessfully_WithErrorPolling) { } } -TEST_F(ClientServerTest, - MissedHeartbeatCallbackIsExecutedIfAnyClientGoesAway_WithErrorPolling) { +TEST_F(ClientServerTest, MissedHeartbeatCallbackIsExecutedIfAnyClientGoesAway) { int num_nodes = 3; StartService(num_nodes); @@ -510,7 +507,6 @@ TEST_F(ClientServerTest, bool coordinator_initiated) { shutdown.Notify(); }; - client_options.poll_for_error_from_service_at_startup = true; auto client = GetClient(node_id, client_options); TF_RETURN_IF_ERROR(client->Connect()); @@ -535,7 +531,8 @@ TEST_F(ClientServerTest, } } -TEST_F(ClientServerTest, ClientsReceiveMissedHeartbeatIfAnyClientGoesAway) { +TEST_F(ClientServerTest, + ClientsReceiveMissedHeartbeatIfAnyClientGoesAway_WithoutErrorPolling) { int num_nodes = 3; StartService(num_nodes); @@ -547,6 +544,7 @@ TEST_F(ClientServerTest, ClientsReceiveMissedHeartbeatIfAnyClientGoesAway) { bool coordinator_initiated) { shutdown.Notify(); }; + client_options.poll_for_error_from_service_at_startup = false; auto client = GetClient(node_id, client_options); TF_RETURN_IF_ERROR(client->Connect()); From 596234a4148b842926a9caccc67d70d5c967eab1 Mon Sep 17 00:00:00 2001 From: Greg Olechwierowicz Date: Thu, 26 Sep 2024 07:51:56 -0700 Subject: [PATCH 323/483] [XLA:GPU][NFC] Expose `ScheduleGpuModuleWithMemoryScheduler` in `gpu_hlo_schedule.h`. PiperOrigin-RevId: 679147367 --- .../xla/xla/service/gpu/gpu_hlo_schedule.cc | 22 +++++++++---------- .../xla/xla/service/gpu/gpu_hlo_schedule.h | 5 +++++ 2 files changed, 16 insertions(+), 11 deletions(-) diff --git a/third_party/xla/xla/service/gpu/gpu_hlo_schedule.cc b/third_party/xla/xla/service/gpu/gpu_hlo_schedule.cc index ee2e9874e358cf..6f59064b4b56eb 100644 --- a/third_party/xla/xla/service/gpu/gpu_hlo_schedule.cc +++ b/third_party/xla/xla/service/gpu/gpu_hlo_schedule.cc @@ -247,17 +247,6 @@ HloInstructionSequence PostprocessorToScheduleSyncCollectives( return result; } -absl::StatusOr ScheduleGpuModuleWithMemoryScheduler( - const HloModule* module, int64_t pointer_size) { - return ScheduleModule( - module, - [pointer_size](const BufferValue& buffer) { - return ShapeUtil::ByteSizeOf(buffer.shape(), pointer_size); - }, - ComputationSchedulerToModuleScheduler(DefaultMemoryScheduler, - PostProcessSchedule)); -} - // Latency hiding scheduler support. SchedulerConfig GetSchedulerConfig(int64_t memory_limit) { @@ -533,6 +522,17 @@ absl::StatusOr ScheduleGpuModule( return ScheduleMetadata{memory_limit}; } +absl::StatusOr ScheduleGpuModuleWithMemoryScheduler( + const HloModule* module, int64_t pointer_size) { + return ScheduleModule( + module, + [pointer_size](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape(), pointer_size); + }, + ComputationSchedulerToModuleScheduler(DefaultMemoryScheduler, + PostProcessSchedule)); +} + HloInstructionSequence PostProcessSchedule( const HloInstructionSequence& input) { HloInstructionSequence result = PostprocessorToScheduleSyncCollectives(input); diff --git a/third_party/xla/xla/service/gpu/gpu_hlo_schedule.h b/third_party/xla/xla/service/gpu/gpu_hlo_schedule.h index b71226c20710a9..608f8b2ac52a4e 100644 --- a/third_party/xla/xla/service/gpu/gpu_hlo_schedule.h +++ b/third_party/xla/xla/service/gpu/gpu_hlo_schedule.h @@ -37,6 +37,11 @@ absl::StatusOr ScheduleGpuModule( HloModule* module, int64_t pointer_size, const se::DeviceDescription& gpu_device_info); +// Schedules a GPU module with `DefaultMemoryScheduler` and +// `PostProcessSchedule` postprocessing. +absl::StatusOr ScheduleGpuModuleWithMemoryScheduler( + const HloModule* module, int64_t pointer_size); + HloInstructionSequence PostProcessSchedule(const HloInstructionSequence& input); constexpr absl::string_view kFingerprintBeforeLHS = "fingerprint_before_lhs"; From 76254692ee74b1f29956182fc6c599d5cf15cb69 Mon Sep 17 00:00:00 2001 From: Greg Olechwierowicz Date: Thu, 26 Sep 2024 08:34:32 -0700 Subject: [PATCH 324/483] [XLA:GPU] Get peak memory bytes from module scheduling. PiperOrigin-RevId: 679161714 --- third_party/xla/xla/service/gpu/BUILD | 1 + .../xla/xla/service/gpu/gpu_hlo_schedule.cc | 5 ++-- .../xla/xla/service/gpu/gpu_hlo_schedule.h | 6 +++-- .../xla/service/gpu/gpu_hlo_schedule_test.cc | 27 +++++++++++++++++++ 4 files changed, 35 insertions(+), 4 deletions(-) diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index edc11b7c230bf2..dbac4b1661ab61 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -2119,6 +2119,7 @@ xla_test( ], backends = ["gpu"], deps = [ + ":gpu_compiler", ":gpu_hlo_schedule", "//xla:shape_util", "//xla/hlo/ir:hlo", diff --git a/third_party/xla/xla/service/gpu/gpu_hlo_schedule.cc b/third_party/xla/xla/service/gpu/gpu_hlo_schedule.cc index 6f59064b4b56eb..6783ef2cb169b7 100644 --- a/third_party/xla/xla/service/gpu/gpu_hlo_schedule.cc +++ b/third_party/xla/xla/service/gpu/gpu_hlo_schedule.cc @@ -523,14 +523,15 @@ absl::StatusOr ScheduleGpuModule( } absl::StatusOr ScheduleGpuModuleWithMemoryScheduler( - const HloModule* module, int64_t pointer_size) { + const HloModule* module, int64_t pointer_size, int64_t* peak_memory_bytes) { return ScheduleModule( module, [pointer_size](const BufferValue& buffer) { return ShapeUtil::ByteSizeOf(buffer.shape(), pointer_size); }, ComputationSchedulerToModuleScheduler(DefaultMemoryScheduler, - PostProcessSchedule)); + PostProcessSchedule), + /*execution_threads=*/{}, /*peak_memory=*/peak_memory_bytes); } HloInstructionSequence PostProcessSchedule( diff --git a/third_party/xla/xla/service/gpu/gpu_hlo_schedule.h b/third_party/xla/xla/service/gpu/gpu_hlo_schedule.h index 608f8b2ac52a4e..af0e7dd3fa9afc 100644 --- a/third_party/xla/xla/service/gpu/gpu_hlo_schedule.h +++ b/third_party/xla/xla/service/gpu/gpu_hlo_schedule.h @@ -38,9 +38,11 @@ absl::StatusOr ScheduleGpuModule( const se::DeviceDescription& gpu_device_info); // Schedules a GPU module with `DefaultMemoryScheduler` and -// `PostProcessSchedule` postprocessing. +// `PostProcessSchedule` postprocessing. If `peak_memory_bytes` is not nullptr, +// then the it will be set to peak memory usage in bytes. absl::StatusOr ScheduleGpuModuleWithMemoryScheduler( - const HloModule* module, int64_t pointer_size); + const HloModule* module, int64_t pointer_size, + int64_t* peak_memory_bytes = nullptr); HloInstructionSequence PostProcessSchedule(const HloInstructionSequence& input); diff --git a/third_party/xla/xla/service/gpu/gpu_hlo_schedule_test.cc b/third_party/xla/xla/service/gpu/gpu_hlo_schedule_test.cc index 0f9c8412bcdfc1..7a9ca28cf3bb29 100644 --- a/third_party/xla/xla/service/gpu/gpu_hlo_schedule_test.cc +++ b/third_party/xla/xla/service/gpu/gpu_hlo_schedule_test.cc @@ -38,6 +38,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_schedule.h" #include "xla/hlo/utils/hlo_query.h" #include "xla/service/backend.h" +#include "xla/service/gpu/gpu_compiler.h" #include "xla/service/hlo_module_config.h" #include "xla/service/hlo_ordering.h" #include "xla/shape.h" @@ -347,6 +348,32 @@ TEST_F(GpuHloScheduleTest, LHSCostModel) { EXPECT_TRUE(HasValidFingerprint(module.get())); } +TEST_F(GpuHloScheduleTest, + ScheduleGpuModuleWithMemorySchedulerReturnsPeakMemoryBytes) { + absl::string_view kHloText = R"( + HloModule m + + ENTRY ar { + p0 = f32[32,32] parameter(0) + p1 = f32[32,32] parameter(1) + + ROOT _ = f32[32,32]{1,0} custom-call(p0, p1), + custom_call_target="__cublas$gemm" + })"; + + TF_ASSERT_OK_AND_ASSIGN( + auto module, + ParseAndReturnVerifiedModule( + kHloText, GetModuleConfig(/*enable_latency_hiding_scheduler=*/true))); + int64_t pointer_size = + dynamic_cast(backend().compiler())->GetPointerSize(); + int64_t peak_memory_bytes = -1; + TF_ASSERT_OK_AND_ASSIGN(auto schedule, + ScheduleGpuModuleWithMemoryScheduler( + module.get(), pointer_size, &peak_memory_bytes)); + EXPECT_GT(peak_memory_bytes, 0); +} + TEST_F(GpuHloScheduleTest, LHSCostModelCostlyAR) { const char* hlo_text = R"( HloModule AsyncAR From 98cef7a10c3abffbeba1433a6ec4e8f88f6ca1e2 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Thu, 26 Sep 2024 08:38:46 -0700 Subject: [PATCH 325/483] Add jax.errors.JaxRuntimeError as a public alias for the XlaRuntimeError class. Deprecate jax.lib.xla_client.XlaRuntimeError, which is not a public API. PiperOrigin-RevId: 679163106 --- third_party/xla/xla/python/xla.cc | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/third_party/xla/xla/python/xla.cc b/third_party/xla/xla/python/xla.cc index 987b06b9f8dd4a..868a3aa9d74016 100644 --- a/third_party/xla/xla/python/xla.cc +++ b/third_party/xla/xla/python/xla.cc @@ -182,6 +182,10 @@ NB_MODULE(xla_extension, m_nb) { // Exceptions nb::exception xla_runtime_error(m_nb, "XlaRuntimeError", PyExc_RuntimeError); + xla_runtime_error.attr("__doc__") = nb::str( + "Runtime errors thrown by the JAX runtime. While the JAX runtime may " + "raise other exceptions as well, most exceptions thrown by the runtime " + "are instances of this class."); // Types nb::enum_(m_nb, "PrimitiveType", nb::is_arithmetic()) From ec7e22e8e1d626e4d7c3ffa1b08822cb4fe01ce6 Mon Sep 17 00:00:00 2001 From: Bart Chrzaszcz Date: Thu, 26 Sep 2024 08:44:58 -0700 Subject: [PATCH 326/483] #sdy add JAX Shardy support for shard_map. For example the following JAX program: ```py devices = np.array(jax.devices()[:8]) mesh = Mesh(devices, axis_names=('x')) a = jax.device_put( jnp.arange(8 * 8).reshape((8, 8)), jax.sharding.NamedSharding(mesh, P('x', None))) @jax.jit @partial( shard_map, mesh=mesh, in_specs=(P('x', None),), out_specs=P('x', None) ) def fwd(a): axis_size = lax.psum(1, 'x') perm = [(j, (j + 1) % axis_size) for j in range(axis_size)] return lax.ppermute(a, 'x', perm=perm) print(jax.jit(fwd).lower(a).as_text()) ``` prints: ```cpp module @jit_fwd attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} { sdy.mesh @mesh = <["x"=8]> func.func public @main(%arg0: tensor<8x8xi32> {mhlo.layout_mode = "default", sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {}]>}) -> (tensor<8x8xi32> {jax.result_info = "", mhlo.layout_mode = "default"}) { %0 = call @fwd(%arg0) : (tensor<8x8xi32>) -> tensor<8x8xi32> return %0 : tensor<8x8xi32> } func.func private @fwd(%arg0: tensor<8x8xi32> {mhlo.layout_mode = "default"}) -> (tensor<8x8xi32> {mhlo.layout_mode = "default"}) { %0 = sdy.manual_computation(%arg0) in_shardings=[<@mesh, [{"x"}, {}]>] out_shardings=[<@mesh, [{"x"}, {}]>] manual_axes={"x"} (%arg1: tensor<1x8xi32>) { %1 = "stablehlo.collective_permute"(%arg1) <{channel_handle = #stablehlo.channel_handle, source_target_pairs = dense<[[0, 1], [1, 2], [2, 3], [3, 4], [4, 5], [5, 6], [6, 7], [7, 0]]> : tensor<8x2xi64>}> : (tensor<1x8xi32>) -> tensor<1x8xi32> sdy.return %1 : tensor<1x8xi32> } : (tensor<8x8xi32>) -> tensor<8x8xi32> return %0 : tensor<8x8xi32> } } ``` PiperOrigin-RevId: 679165100 --- third_party/xla/xla/service/spmd/shardy/BUILD | 1 + .../service/spmd/shardy/sdy_round_trip/BUILD | 5 +- .../spmd/shardy/sdy_round_trip/pipelines.cc | 13 +++-- .../shardy/sdy_round_trip/shard_map_import.cc | 20 ++++--- .../spmd/shardy/shardy_call_inliner.cc | 5 +- .../spmd/shardy/shardy_call_inliner_test.cc | 54 +++++++++++++++++++ ...ound_trip_pipeline_manual_computation.mlir | 18 +++---- .../test/sdy_round_trip_shard_map_import.mlir | 36 ++++++------- ...y_round_trip_shard_map_import_failure.mlir | 8 +-- 9 files changed, 109 insertions(+), 51 deletions(-) diff --git a/third_party/xla/xla/service/spmd/shardy/BUILD b/third_party/xla/xla/service/spmd/shardy/BUILD index fb0b39c4195ace..db4d32ecc78ea0 100644 --- a/third_party/xla/xla/service/spmd/shardy/BUILD +++ b/third_party/xla/xla/service/spmd/shardy/BUILD @@ -30,6 +30,7 @@ cc_library( deps = [ "//xla/hlo/ir:hlo", "//xla/service:call_inliner", + "//xla/service/spmd/shardy:constants", "@com_google_absl//absl/strings", ], ) diff --git a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/BUILD b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/BUILD index ebc36cdc54d2a9..e7a81caaa33155 100644 --- a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/BUILD +++ b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/BUILD @@ -107,7 +107,6 @@ cc_library( "@llvm-project//mlir:Support", "@llvm-project//mlir:TransformUtils", "@shardy//shardy/dialect/sdy/ir:dialect", - "@stablehlo//:stablehlo_ops", ], ) @@ -119,11 +118,15 @@ cc_library( ":export_ops", ":export_shardings", ":import_shardings", + ":shard_map_export", + ":shard_map_import", + "//xla/mlir_hlo:mhlo_passes", "//xla/service:hlo_proto_cc", "//xla/service/spmd/shardy/mhlo_round_trip:export_shardings", "//xla/service/spmd/shardy/mhlo_round_trip:shard_map_import", "//xla/service/spmd/shardy/round_trip_common:pipeline_passes", "@llvm-project//mlir:Pass", "@llvm-project//mlir:Support", + "@llvm-project//mlir:Transforms", ], ) diff --git a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/pipelines.cc b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/pipelines.cc index 14ad8133ab33f0..78fe16400e2299 100644 --- a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/pipelines.cc +++ b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/pipelines.cc @@ -20,6 +20,7 @@ limitations under the License. #include "mlir/Pass/PassManager.h" #include "mlir/Pass/PassRegistry.h" #include "mlir/Support/LLVM.h" +#include "mlir/Transforms/Passes.h" #include "xla/service/hlo.pb.h" #include "xla/service/spmd/shardy/mhlo_round_trip/export_shardings.h" #include "xla/service/spmd/shardy/mhlo_round_trip/shard_map_import.h" @@ -27,6 +28,8 @@ limitations under the License. #include "xla/service/spmd/shardy/sdy_round_trip/export_ops.h" #include "xla/service/spmd/shardy/sdy_round_trip/export_shardings.h" #include "xla/service/spmd/shardy/sdy_round_trip/import_shardings.h" +#include "xla/service/spmd/shardy/sdy_round_trip/shard_map_export.h" +#include "xla/service/spmd/shardy/sdy_round_trip/shard_map_import.h" namespace xla { namespace sdy { @@ -34,12 +37,13 @@ namespace sdy { using ::mlir::PassPipelineRegistration; void addSdyRoundTripExportPipeline(mlir::OpPassManager& pm) { - // NOTE: we don't do any exporting for ManualComputationOp, since during - // SDY round-trip we expect the same pattern of custom calls to continue to - // exist. We save `sdy.sharding`s on those custom calls during + // Run canonicalizer to simplify `ManualComputationOp`s. + pm.addPass(mlir::createCanonicalizerPass()); + // We save `sdy.sharding`s on those custom calls during // `createSdyRoundTripExportShardingsPass` and make use of // `createSdyRoundTripImportShardingsPass` to import them. pm.addPass(createSdyRoundTripExportOpsPass()); + pm.addPass(createSdyRoundTripShardMapExportPass()); // Preserve the SDY shardings for `createExportMhloShardingsPass` so that // we have both `mhlo.sharding`s and hidden `sdy.sharding`s on the module. We // want to have `mhlo.sharding`s for Pathways to read from. @@ -50,8 +54,7 @@ void addSdyRoundTripExportPipeline(mlir::OpPassManager& pm) { void addSdyRoundTripImportPipeline(mlir::OpPassManager& pm) { addCommonPreImportPasses(pm); pm.addPass(createSdyRoundTripImportShardingsPass()); - // TODO(bartchr): replace with an sdy round trip shard map pass. - pm.addPass(createMhloRoundTripShardMapImportPass()); + pm.addPass(createSdyRoundTripShardMapImportPass()); addCommonPostImportPasses(pm); } diff --git a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/shard_map_import.cc b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/shard_map_import.cc index d577a0820d0a94..7a0e1d018e0c2e 100644 --- a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/shard_map_import.cc +++ b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/shard_map_import.cc @@ -44,7 +44,7 @@ limitations under the License. #include "mlir/Transforms/DialectConversion.h" #include "shardy/dialect/sdy/ir/dialect.h" #include "shardy/dialect/sdy/ir/utils.h" -#include "stablehlo/dialect/StablehloOps.h" +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/service/spmd/shardy/constants.h" #include "xla/service/spmd/shardy/utils.h" @@ -60,8 +60,8 @@ using ::mlir::StringRef; using ::mlir::SymbolTable; using ::mlir::func::CallOp; using ::mlir::func::FuncOp; +using ::mlir::mhlo::CustomCallOp; -namespace stablehlo = ::mlir::stablehlo; namespace sdy = ::mlir::sdy; // Converts a CallOp calling a @local_xla.sdy.manual_computation_body func with in/out @@ -86,23 +86,22 @@ class ManualComputationPattern : public OpConversionPattern { // we have to take the operands/results of the newly created // `ManualComputationOp` differently depending on whether the original had // operands/results. - stablehlo::CustomCallOp fullToShard; + CustomCallOp fullToShard; mlir::ValueRange operands = callOp->getOperands(); if (!operands.empty()) { - fullToShard = - callOp->getOperand(0).getDefiningOp(); - operands = fullToShard->getOperands(); + fullToShard = callOp->getOperand(0).getDefiningOp(); CHECK(fullToShard); CHECK(fullToShard.getCallTargetName() == kGlobalToLocalShapeCallTargetName); + operands = fullToShard->getOperands(); } mlir::TypeRange resultTypes = callOp->getResultTypes(); - stablehlo::CustomCallOp shardToFull; + CustomCallOp shardToFull; if (!resultTypes.empty()) { CHECK(callOp->getResult(0).hasOneUse()) << "all CallOp results should be used by a single ShardToFull"; - shardToFull = mlir::cast( - *callOp->getResult(0).getUsers().begin()); + shardToFull = + mlir::cast(*callOp->getResult(0).getUsers().begin()); CHECK(shardToFull.getCallTargetName() == kLocalToGlobalShapeCallTargetName); resultTypes = shardToFull->getResultTypes(); @@ -161,8 +160,7 @@ class SdyRoundTripShardMapImportPass target.addDynamicallyLegalOp([](CallOp op) { return !absl::StartsWith(op.getCallee(), kManualComputationBodyFuncName); }); - target.addLegalOp(); + target.addLegalOp(); mlir::RewritePatternSet patterns(&context); patterns.add(&context, symbolTable); if (mlir::failed(mlir::applyPartialConversion(module, target, diff --git a/third_party/xla/xla/service/spmd/shardy/shardy_call_inliner.cc b/third_party/xla/xla/service/spmd/shardy/shardy_call_inliner.cc index c7564dbf3ed140..2de735c98ecbbc 100644 --- a/third_party/xla/xla/service/spmd/shardy/shardy_call_inliner.cc +++ b/third_party/xla/xla/service/spmd/shardy/shardy_call_inliner.cc @@ -18,13 +18,16 @@ limitations under the License. #include "absl/strings/match.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/service/call_inliner.h" +#include "xla/service/spmd/shardy/constants.h" namespace xla { bool ShardyCallInliner::IsInlineableCallOp(HloInstruction* instruction) const { return CallInliner::IsInlineableCallOp(instruction) && !(instruction->GetModule()->config().use_shardy_partitioner() && - absl::StrContains(instruction->to_apply()->name(), "shmap_body")); + (absl::StrContains(instruction->to_apply()->name(), "shmap_body") || + absl::StartsWith(instruction->to_apply()->name(), + sdy::kManualComputationBodyFuncName))); } } // namespace xla diff --git a/third_party/xla/xla/service/spmd/shardy/shardy_call_inliner_test.cc b/third_party/xla/xla/service/spmd/shardy/shardy_call_inliner_test.cc index 00d952b3b80461..b2055e59d75c34 100644 --- a/third_party/xla/xla/service/spmd/shardy/shardy_call_inliner_test.cc +++ b/third_party/xla/xla/service/spmd/shardy/shardy_call_inliner_test.cc @@ -57,5 +57,59 @@ TEST_F(ShardyCallInlinerTest, MhloToHloShmapBodyNotInlined) { EXPECT_EQ(call->to_apply()->name(), "prefix_shmap_body_suffix.4"); } +// Don't inline when the name starts with "xla.sdy.manual_computation_body". +TEST_F(ShardyCallInlinerTest, ManualComputationBodyNotInlined) { + const char* const hloString = R"( + HloModule jit_f, entry_computation_layout={(f32[8,8]{1,0})->f32[8,8]{1,0}} + + %xla.sdy.manual_computation_body.4 (Arg_0.5: f32[1,8]) -> f32[1,8] { + %Arg_0.5 = f32[1,8]{1,0} parameter(0) + ROOT %add.6 = f32[1,8]{1,0} add(f32[1,8]{1,0} %Arg_0.5, f32[1,8]{1,0} %Arg_0.5), metadata={source_file="-" source_line=11} + } + + ENTRY %main.10 (Arg_0.1: f32[8,8]) -> f32[8,8] { + %Arg_0.1 = f32[8,8]{1,0} parameter(0) + %custom-call.3 = f32[1,8]{1,0} custom-call(f32[8,8]{1,0} %Arg_0.1), custom_call_target="SPMDFullToShardShape", sharding={manual}, metadata={source_file="-" source_line=4} + %call.7 = f32[1,8]{1,0} call(f32[1,8]{1,0} %custom-call.3), to_apply=%xla.sdy.manual_computation_body.4 + ROOT %custom-call.9 = f32[8,8]{1,0} custom-call(f32[1,8]{1,0} %call.7), custom_call_target="SPMDShardToFullShape", sharding={devices=[8,1]<=[8]}, metadata={source_file="-" source_line=7} + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hloString)); + module->mutable_config().set_use_shardy_partitioner(true); + TF_ASSERT_OK_AND_ASSIGN(bool changed, ShardyCallInliner().Run(module.get())); + VLOG(1) << module->ToString(); + // The single call in the module is not inlined. + EXPECT_FALSE(changed); + + HloInstruction* call = FindInstruction(module.get(), xla::HloOpcode::kCall); + EXPECT_NE(call, nullptr); + EXPECT_TRUE(call->has_to_apply()); + EXPECT_EQ(call->to_apply()->name(), "xla.sdy.manual_computation_body.4"); +} + +// Inliner only checks if the name of the function has +// "xla.sdy.manual_computation_body" a prefix, not if it contains it. +TEST_F(ShardyCallInlinerTest, ManualComputationBodyInlined) { + const char* const hloString = R"( + HloModule jit_f, entry_computation_layout={(f32[8,8]{1,0})->f32[8,8]{1,0}} + + %prefix_xla.sdy.manual_computation_body.4 (Arg_0.5: f32[1,8]) -> f32[1,8] { + %Arg_0.5 = f32[1,8]{1,0} parameter(0) + ROOT %add.6 = f32[1,8]{1,0} add(f32[1,8]{1,0} %Arg_0.5, f32[1,8]{1,0} %Arg_0.5), metadata={source_file="-" source_line=11} + } + + ENTRY %main.10 (Arg_0.1: f32[8,8]) -> f32[8,8] { + %Arg_0.1 = f32[8,8]{1,0} parameter(0) + %custom-call.3 = f32[1,8]{1,0} custom-call(f32[8,8]{1,0} %Arg_0.1), custom_call_target="SPMDFullToShardShape", sharding={manual}, metadata={source_file="-" source_line=4} + %call.7 = f32[1,8]{1,0} call(f32[1,8]{1,0} %custom-call.3), to_apply=%prefix_xla.sdy.manual_computation_body.4 + ROOT %custom-call.9 = f32[8,8]{1,0} custom-call(f32[1,8]{1,0} %call.7), custom_call_target="SPMDShardToFullShape", sharding={devices=[8,1]<=[8]}, metadata={source_file="-" source_line=7} + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hloString)); + module->mutable_config().set_use_shardy_partitioner(true); + TF_ASSERT_OK_AND_ASSIGN(bool changed, ShardyCallInliner().Run(module.get())); + VLOG(1) << module->ToString(); + // Will be inlined. + EXPECT_TRUE(changed); +} + } // namespace sdy } // namespace xla diff --git a/third_party/xla/xla/service/spmd/shardy/test/round_trip_pipeline_manual_computation.mlir b/third_party/xla/xla/service/spmd/shardy/test/round_trip_pipeline_manual_computation.mlir index 37e94fbb510bb6..90754f8e9bf0a2 100644 --- a/third_party/xla/xla/service/spmd/shardy/test/round_trip_pipeline_manual_computation.mlir +++ b/third_party/xla/xla/service/spmd/shardy/test/round_trip_pipeline_manual_computation.mlir @@ -18,17 +18,13 @@ func.func @main(%arg0: tensor<16x32xf32>) -> tensor<128x32xf32> { // CHECK-NEXT: } : (tensor<16x32xf32>) -> (tensor<128x32xf32>, tensor<128x32xf32>) // CHECK-NEXT: %[[ADD:.*]] = mhlo.add %[[SHARD_MAP]]#0, %[[SHARD_MAP]]#1 : tensor<128x32xf32> // CHECK-NEXT: return %[[ADD]] : tensor<128x32xf32> - %0 = mhlo.custom_call @Sharding(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_1, [{}, {}], replicated={"a", "b"}>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> - %1 = mhlo.custom_call @SPMDFullToShardShape(%0) : (tensor<16x32xf32>) -> tensor<16x32xf32> - %2:2 = call @shmap_body_4(%1) : (tensor<16x32xf32>) -> (tensor<16x32xf32>, tensor<16x32xf32>) - %3 = mhlo.custom_call @Sharding(%2#0) : (tensor<16x32xf32>) -> tensor<16x32xf32> - %4 = mhlo.custom_call @SPMDShardToFullShape(%3) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_1, [{"a", "b"}, {}]>]>} : (tensor<16x32xf32>) -> tensor<128x32xf32> - %5 = mhlo.custom_call @Sharding(%2#1) : (tensor<16x32xf32>) -> tensor<16x32xf32> - %6 = mhlo.custom_call @SPMDShardToFullShape(%5) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_1, [{"b", "a"}, {}]>]>} : (tensor<16x32xf32>) -> tensor<128x32xf32> - %7 = mhlo.add %4, %6 : tensor<128x32xf32> - return %7 : tensor<128x32xf32> + %0 = mhlo.custom_call @local_xla.sdy.GlobalToLocalShape(%arg0) : (tensor<16x32xf32>) -> tensor<16x32xf32> + %1:2 = call @local_xla.sdy.manual_computation_body(%0) {mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{}, {}], replicated={\\\22a\\\22, \\\22b\\\22}>]>", xla.sdy.manual_axes = "#sdy", xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{\\\22a\\\22, \\\22b\\\22}, {}]>, <@mesh_1, [{\\\22b\\\22, \\\22a\\\22}, {}]>]>"}} : (tensor<16x32xf32>) -> (tensor<16x32xf32>, tensor<16x32xf32>) + %2:2 = mhlo.custom_call @local_xla.sdy.LocalToGlobalShape(%1#0, %1#1) : (tensor<16x32xf32>, tensor<16x32xf32>) -> (tensor<128x32xf32>, tensor<128x32xf32>) + %3 = mhlo.add %2#0, %2#1 : tensor<128x32xf32> + return %3 : tensor<128x32xf32> } -// CHECK-NOT: func.func private @shmap_body_4 -func.func private @shmap_body_4(%arg0: tensor<16x32xf32>) -> (tensor<16x32xf32>, tensor<16x32xf32>) { +// CHECK-NOT: func.func private @local_xla.sdy.manual_computation_body +func.func private @local_xla.sdy.manual_computation_body(%arg0: tensor<16x32xf32>) -> (tensor<16x32xf32>, tensor<16x32xf32>) { return %arg0, %arg0 : tensor<16x32xf32>, tensor<16x32xf32> } diff --git a/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_shard_map_import.mlir b/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_shard_map_import.mlir index c0ecce7e8b67de..0f55988e0f123b 100644 --- a/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_shard_map_import.mlir +++ b/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_shard_map_import.mlir @@ -20,9 +20,9 @@ func.func @single_manual_comp(%arg0: tensor<8x16xf32>, %arg1: tensor<16x32xf32>) // CHECK-NEXT: sdy.return %[[REDUCE]] : tensor<2x32xf32> // CHECK-NEXT: } : (tensor<8x16xf32>, tensor<16x32xf32>) -> tensor<8x32xf32> // CHECK-NEXT: return %[[MAN_COMP]] : tensor<8x32xf32> - %0:2 = stablehlo.custom_call @local_xla.sdy.GlobalToLocalShape(%arg0, %arg1) : (tensor<8x16xf32>, tensor<16x32xf32>) -> (tensor<2x8xf32>, tensor<8x32xf32>) + %0:2 = mhlo.custom_call @local_xla.sdy.GlobalToLocalShape(%arg0, %arg1) : (tensor<8x16xf32>, tensor<16x32xf32>) -> (tensor<2x8xf32>, tensor<8x32xf32>) %1 = call @local_xla.sdy.manual_computation_body(%0#0, %0#1) {mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{\\\22a\\\22}, {\\\22b\\\22}]>, <@mesh_0, [{\\\22b\\\22}, {}], replicated={\\\22a\\\22}>]>", xla.sdy.manual_axes = "#sdy", xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{\\\22a\\\22}, {}], replicated={\\\22b\\\22}>]>"}} : (tensor<2x8xf32>, tensor<8x32xf32>) -> tensor<2x32xf32> - %2 = stablehlo.custom_call @local_xla.sdy.LocalToGlobalShape(%1) : (tensor<2x32xf32>) -> tensor<8x32xf32> + %2 = mhlo.custom_call @local_xla.sdy.LocalToGlobalShape(%1) : (tensor<2x32xf32>) -> tensor<8x32xf32> return %2 : tensor<8x32xf32> } @@ -44,20 +44,20 @@ func.func @manual_comp_using_another(%arg0: tensor<8x8xf32>) -> tensor<8x8xf32> // CHECK-NEXT: sdy.return %arg1 : tensor<8x4xf32> // CHECK-NEXT: } : (tensor<8x8xf32>) -> tensor<8x8xf32> // CHECK-NEXT: return %[[MAN_COMP_1]] : tensor<8x8xf32> - %0 = stablehlo.custom_call @local_xla.sdy.GlobalToLocalShape(%arg0) : (tensor<8x8xf32>) -> tensor<2x8xf32> + %0 = mhlo.custom_call @local_xla.sdy.GlobalToLocalShape(%arg0) : (tensor<8x8xf32>) -> tensor<2x8xf32> %1 = call @local_xla.sdy.manual_computation_body_0(%0) {mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{\\\22a\\\22}, {}]>]>", xla.sdy.manual_axes = "#sdy", xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{\\\22a\\\22}, {}]>]>"}} : (tensor<2x8xf32>) -> tensor<2x8xf32> - %2 = stablehlo.custom_call @local_xla.sdy.LocalToGlobalShape(%1) : (tensor<2x8xf32>) -> tensor<8x8xf32> - %3 = stablehlo.custom_call @local_xla.sdy.GlobalToLocalShape(%2) : (tensor<8x8xf32>) -> tensor<8x4xf32> + %2 = mhlo.custom_call @local_xla.sdy.LocalToGlobalShape(%1) : (tensor<2x8xf32>) -> tensor<8x8xf32> + %3 = mhlo.custom_call @local_xla.sdy.GlobalToLocalShape(%2) : (tensor<8x8xf32>) -> tensor<8x4xf32> %4 = call @local_xla.sdy.manual_computation_body_1(%3) {mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{}, {\\\22b\\\22}]>]>", xla.sdy.manual_axes = "#sdy", xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{}, {\\\22b\\\22}]>]>"}} : (tensor<8x4xf32>) -> tensor<8x4xf32> - %5 = stablehlo.custom_call @local_xla.sdy.LocalToGlobalShape(%4) : (tensor<8x4xf32>) -> tensor<8x8xf32> + %5 = mhlo.custom_call @local_xla.sdy.LocalToGlobalShape(%4) : (tensor<8x4xf32>) -> tensor<8x8xf32> return %5 : tensor<8x8xf32> } // CHECK-NOT: func @local_xla.sdy.manual_computation_body_3( func.func @local_xla.sdy.manual_computation_body_3(%arg0: tensor<2x8xf32>) -> tensor<2x8xf32> { - %0 = stablehlo.custom_call @local_xla.sdy.GlobalToLocalShape(%arg0) : (tensor<2x8xf32>) -> tensor<2x4xf32> + %0 = mhlo.custom_call @local_xla.sdy.GlobalToLocalShape(%arg0) : (tensor<2x8xf32>) -> tensor<2x4xf32> %1 = call @local_xla.sdy.manual_computation_body_2(%0) {mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{}, {\\\22b\\\22}]>]>", xla.sdy.manual_axes = "#sdy", xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{}, {\\\22b\\\22}]>]>"}} : (tensor<2x4xf32>) -> tensor<2x4xf32> - %2 = stablehlo.custom_call @local_xla.sdy.LocalToGlobalShape(%1) : (tensor<2x4xf32>) -> tensor<2x8xf32> + %2 = mhlo.custom_call @local_xla.sdy.LocalToGlobalShape(%1) : (tensor<2x4xf32>) -> tensor<2x8xf32> return %2 : tensor<2x8xf32> } @@ -85,9 +85,9 @@ func.func @nested_shmaps(%arg0: tensor<4x8xf32>) -> tensor<4x8xf32> { // CHECK-NEXT: sdy.return %[[MAN_COMP_1]] : tensor<2x8xf32> // CHECK-NEXT: } : (tensor<4x8xf32>) -> tensor<4x8xf32> // CHECK-NEXT: return %[[MAN_COMP_0]] : tensor<4x8xf32> - %0 = stablehlo.custom_call @local_xla.sdy.GlobalToLocalShape(%arg0) : (tensor<4x8xf32>) -> tensor<2x8xf32> + %0 = mhlo.custom_call @local_xla.sdy.GlobalToLocalShape(%arg0) : (tensor<4x8xf32>) -> tensor<2x8xf32> %1 = call @local_xla.sdy.manual_computation_body_3(%0) {mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{\\\22a\\\22}, {}]>]>", xla.sdy.manual_axes = "#sdy", xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{\\\22a\\\22}, {}]>]>"}} : (tensor<2x8xf32>) -> tensor<2x8xf32> - %2 = stablehlo.custom_call @local_xla.sdy.LocalToGlobalShape(%1) : (tensor<2x8xf32>) -> tensor<4x8xf32> + %2 = mhlo.custom_call @local_xla.sdy.LocalToGlobalShape(%1) : (tensor<2x8xf32>) -> tensor<4x8xf32> return %2 : tensor<4x8xf32> } @@ -110,9 +110,9 @@ func.func @nested_shmaps_extra_op(%arg0: tensor<4x8xf32>) -> tensor<4x8xf32> { // CHECK-NEXT: sdy.return %[[ADD]] : tensor<2x8xf32> // CHECK-NEXT: } : (tensor<4x8xf32>) -> tensor<4x8xf32> // CHECK-NEXT: return %[[MAN_COMP_0]] : tensor<4x8xf32> - %0 = stablehlo.custom_call @local_xla.sdy.GlobalToLocalShape(%arg0) : (tensor<4x8xf32>) -> tensor<2x8xf32> + %0 = mhlo.custom_call @local_xla.sdy.GlobalToLocalShape(%arg0) : (tensor<4x8xf32>) -> tensor<2x8xf32> %1 = call @local_xla.sdy.manual_computation_body_5(%0) {mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{\\\22a\\\22}, {}]>]>", xla.sdy.manual_axes = "#sdy", xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{\\\22a\\\22}, {}]>]>"}} : (tensor<2x8xf32>) -> tensor<2x8xf32> - %2 = stablehlo.custom_call @local_xla.sdy.LocalToGlobalShape(%1) : (tensor<2x8xf32>) -> tensor<4x8xf32> + %2 = mhlo.custom_call @local_xla.sdy.LocalToGlobalShape(%1) : (tensor<2x8xf32>) -> tensor<4x8xf32> return %2 : tensor<4x8xf32> } @@ -128,7 +128,7 @@ func.func @manual_computation_no_inputs() -> tensor<4xi64> { // CHECK-NEXT: } : () -> tensor<4xi64> // CHECK-NEXT: return %[[SHMAP]] : tensor<4xi64> %0 = call @local_xla.sdy.manual_computation_body_6() {mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[]>", xla.sdy.manual_axes = "#sdy", xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{\\\22b\\\22}]>]>"}} : () -> tensor<2xi64> - %1 = stablehlo.custom_call @local_xla.sdy.LocalToGlobalShape(%0) : (tensor<2xi64>) -> tensor<4xi64> + %1 = mhlo.custom_call @local_xla.sdy.LocalToGlobalShape(%0) : (tensor<2xi64>) -> tensor<4xi64> return %1 : tensor<4xi64> } @@ -139,11 +139,11 @@ func.func @manual_computation_no_outputs(%arg0: tensor<4xi64>) { // CHECK-SAME{LITERAL}: out_shardings=[] // CHECK-SAME{LITERAL}: manual_axes={"b"} // CHECK-SAME{LITERAL}: (%arg1: tensor<2xi64>) { - // CHECK-NEXT: stablehlo.custom_call @sdy_testonly(%arg1) : (tensor<2xi64>) -> () + // CHECK-NEXT: mhlo.custom_call @sdy_testonly(%arg1) : (tensor<2xi64>) -> () // CHECK-NEXT: sdy.return // CHECK-NEXT: } : (tensor<4xi64>) -> () // CHECK-NEXT: return - %0 = stablehlo.custom_call @local_xla.sdy.GlobalToLocalShape(%arg0) : (tensor<4xi64>) -> tensor<2xi64> + %0 = mhlo.custom_call @local_xla.sdy.GlobalToLocalShape(%arg0) : (tensor<4xi64>) -> tensor<2xi64> call @local_xla.sdy.manual_computation_body_7(%0) {mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{\\\22b\\\22}]>]>", xla.sdy.manual_axes = "#sdy", xla.sdy.out_shardings = "#sdy.sharding_per_value<[]>"}} : (tensor<2xi64>) -> () return } @@ -178,9 +178,9 @@ func.func @local_xla.sdy.manual_computation_body_4(%arg0: tensor<2x4xf32>) -> te // CHECK-NOT: func @local_xla.sdy.manual_computation_body_5( func.func @local_xla.sdy.manual_computation_body_5(%arg0: tensor<2x8xf32>) -> tensor<2x8xf32> { - %0 = stablehlo.custom_call @local_xla.sdy.GlobalToLocalShape(%arg0) : (tensor<2x8xf32>) -> tensor<2x4xf32> + %0 = mhlo.custom_call @local_xla.sdy.GlobalToLocalShape(%arg0) : (tensor<2x8xf32>) -> tensor<2x4xf32> %1 = call @local_xla.sdy.manual_computation_body_4(%0) {mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{}, {\\\22b\\\22}]>]>", xla.sdy.manual_axes = "#sdy", xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{}, {\\\22b\\\22}]>]>"}} : (tensor<2x4xf32>) -> tensor<2x4xf32> - %2 = stablehlo.custom_call @local_xla.sdy.LocalToGlobalShape(%1) : (tensor<2x4xf32>) -> tensor<2x8xf32> + %2 = mhlo.custom_call @local_xla.sdy.LocalToGlobalShape(%1) : (tensor<2x4xf32>) -> tensor<2x8xf32> %3 = stablehlo.add %2, %2 : tensor<2x8xf32> return %3 : tensor<2x8xf32> } @@ -193,6 +193,6 @@ func.func @local_xla.sdy.manual_computation_body_6() -> tensor<2xi64> { // CHECK-NOT: func @local_xla.sdy.manual_computation_body_7( func.func @local_xla.sdy.manual_computation_body_7(%arg0: tensor<2xi64>) { - stablehlo.custom_call @sdy_testonly(%arg0) : (tensor<2xi64>) -> () + mhlo.custom_call @sdy_testonly(%arg0) : (tensor<2xi64>) -> () return } diff --git a/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_shard_map_import_failure.mlir b/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_shard_map_import_failure.mlir index ba5f28da7a7484..9f2a3a5740924d 100644 --- a/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_shard_map_import_failure.mlir +++ b/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_shard_map_import_failure.mlir @@ -3,14 +3,14 @@ sdy.mesh @mesh = <["a"=2]> func.func @using_same_body_func(%arg0: tensor<8x8xf32>) -> tensor<8x8xf32> { - %0 = stablehlo.custom_call @local_xla.sdy.GlobalToLocalShape(%arg0) : (tensor<8x8xf32>) -> (tensor<2x8xf32>) + %0 = mhlo.custom_call @local_xla.sdy.GlobalToLocalShape(%arg0) : (tensor<8x8xf32>) -> (tensor<2x8xf32>) %1 = call @local_xla.sdy.manual_computation_body(%0) {mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{\\\22a\\\22}, {\\\22b\\\22}]>]>", xla.sdy.manual_axes = "#sdy", xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{\\\22a\\\22}, {}], replicated={\\\22b\\\22}>]>"}} : (tensor<2x8xf32>) -> (tensor<2x8xf32>) - %2 = stablehlo.custom_call @local_xla.sdy.LocalToGlobalShape(%1) : (tensor<2x8xf32>) -> (tensor<8x8xf32>) - %3 = stablehlo.custom_call @local_xla.sdy.GlobalToLocalShape(%2) : (tensor<8x8xf32>) -> (tensor<2x8xf32>) + %2 = mhlo.custom_call @local_xla.sdy.LocalToGlobalShape(%1) : (tensor<2x8xf32>) -> (tensor<8x8xf32>) + %3 = mhlo.custom_call @local_xla.sdy.GlobalToLocalShape(%2) : (tensor<8x8xf32>) -> (tensor<2x8xf32>) // expected-error @+2 {{'func.call' op expected a unique FuncOp per @local_xla.sdy.manual_computation_body call}} // expected-error @+1 {{failed to legalize operation 'func.call'}} %4 = call @local_xla.sdy.manual_computation_body(%3) {mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{\\\22a\\\22}, {\\\22b\\\22}]>]>", xla.sdy.manual_axes = "#sdy", xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{\\\22a\\\22}, {}], replicated={\\\22b\\\22}>]>"}} : (tensor<2x8xf32>) -> (tensor<2x8xf32>) - %5 = stablehlo.custom_call @local_xla.sdy.LocalToGlobalShape(%4) : (tensor<2x8xf32>) -> (tensor<8x8xf32>) + %5 = mhlo.custom_call @local_xla.sdy.LocalToGlobalShape(%4) : (tensor<2x8xf32>) -> (tensor<8x8xf32>) return %5 : tensor<8x8xf32> } From ee313c5f350b65d858a7730ded6d26489ce9cf3a Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 26 Sep 2024 09:21:07 -0700 Subject: [PATCH 327/483] Remove -Wno-error=unused-but-set-variable from .bazelrc This flag is no longer available since clang 10.0.0. PiperOrigin-RevId: 679177375 --- .bazelrc | 2 -- third_party/xla/.bazelrc | 2 -- third_party/xla/third_party/tsl/.bazelrc | 2 -- 3 files changed, 6 deletions(-) diff --git a/.bazelrc b/.bazelrc index 8201ce4582a00f..bee4dc3e784a99 100644 --- a/.bazelrc +++ b/.bazelrc @@ -329,8 +329,6 @@ build:linux --copt="-Werror=unused-result" # Add switch as an error on Linux. build:linux --copt="-Wswitch" build:linux --copt="-Werror=switch" -# Required for building with clang -build:linux --copt="-Wno-error=unused-but-set-variable" # Linux ARM64 specific options build:linux_arm64 --copt="-mtune=generic" --copt="-march=armv8-a" --copt="-O3" diff --git a/third_party/xla/.bazelrc b/third_party/xla/.bazelrc index 8201ce4582a00f..bee4dc3e784a99 100644 --- a/third_party/xla/.bazelrc +++ b/third_party/xla/.bazelrc @@ -329,8 +329,6 @@ build:linux --copt="-Werror=unused-result" # Add switch as an error on Linux. build:linux --copt="-Wswitch" build:linux --copt="-Werror=switch" -# Required for building with clang -build:linux --copt="-Wno-error=unused-but-set-variable" # Linux ARM64 specific options build:linux_arm64 --copt="-mtune=generic" --copt="-march=armv8-a" --copt="-O3" diff --git a/third_party/xla/third_party/tsl/.bazelrc b/third_party/xla/third_party/tsl/.bazelrc index 8201ce4582a00f..bee4dc3e784a99 100644 --- a/third_party/xla/third_party/tsl/.bazelrc +++ b/third_party/xla/third_party/tsl/.bazelrc @@ -329,8 +329,6 @@ build:linux --copt="-Werror=unused-result" # Add switch as an error on Linux. build:linux --copt="-Wswitch" build:linux --copt="-Werror=switch" -# Required for building with clang -build:linux --copt="-Wno-error=unused-but-set-variable" # Linux ARM64 specific options build:linux_arm64 --copt="-mtune=generic" --copt="-march=armv8-a" --copt="-O3" From 96eaf3a595e20b76ede1d2360cc75605994bc8ba Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 26 Sep 2024 10:10:15 -0700 Subject: [PATCH 328/483] Adds batch size field in sparsecore layouts proto. This is the batch size for the unstacked table in a particular stacking setup. PiperOrigin-RevId: 679195588 --- tensorflow/core/tpu/kernels/sparse_core_layout.proto | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tensorflow/core/tpu/kernels/sparse_core_layout.proto b/tensorflow/core/tpu/kernels/sparse_core_layout.proto index 6b7bbd9ed5ebe5..78231c6e73cf41 100644 --- a/tensorflow/core/tpu/kernels/sparse_core_layout.proto +++ b/tensorflow/core/tpu/kernels/sparse_core_layout.proto @@ -94,6 +94,10 @@ message SparseCoreTableLayout { // partitions. // sparse_core_shard_rotation = table_index * sparse_cores_per_partition int64 sparse_core_shard_rotation = 9; + + // The batch size per sparsecore for this table. This combines the batch sizes + // of all the features pointing to this table. + int64 per_sparse_core_batch_size = 10; } message SparseCoreTableLayouts { From a78f8a5a4b953eaa59563e2f49bc3544d4f9959e Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 26 Sep 2024 10:17:04 -0700 Subject: [PATCH 329/483] Add sparse core step time breakdown to overview page. PiperOrigin-RevId: 679197985 --- tensorflow/core/profiler/convert/BUILD | 4 ++++ .../op_stats_to_input_pipeline_analysis.cc | 1 + .../core/profiler/convert/xplane_to_op_stats.cc | 15 ++++++++++----- .../profiler/convert/xplane_to_step_events.cc | 9 ++++++++- .../core/profiler/utils/op_metrics_db_utils.cc | 15 +++++++++++++++ .../core/profiler/utils/op_metrics_db_utils.h | 2 ++ .../tsl/tsl/profiler/utils/xplane_schema.cc | 2 ++ .../tsl/tsl/profiler/utils/xplane_schema.h | 4 +++- .../tsl/tsl/profiler/utils/xplane_utils.cc | 3 ++- 9 files changed, 47 insertions(+), 8 deletions(-) diff --git a/tensorflow/core/profiler/convert/BUILD b/tensorflow/core/profiler/convert/BUILD index 9233aded397147..82538b6392aa34 100644 --- a/tensorflow/core/profiler/convert/BUILD +++ b/tensorflow/core/profiler/convert/BUILD @@ -337,6 +337,7 @@ cc_library( "@local_tsl//tsl/profiler/utils:tf_xplane_visitor", "@local_tsl//tsl/profiler/utils:tpu_xplane_utils", "@local_tsl//tsl/profiler/utils:xplane_schema", + "@local_tsl//tsl/profiler/utils:xplane_utils", ], ) @@ -367,6 +368,7 @@ tf_cc_test( ":repository", ":step_events_to_steps_db", ":xplane_to_op_stats", + ":xplane_to_step_events", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:test", @@ -377,11 +379,13 @@ tf_cc_test( "//tensorflow/core/profiler/protobuf:steps_db_proto_cc", "//tensorflow/core/profiler/protobuf:tf_function_proto_cc", "//tensorflow/core/profiler/protobuf:xplane_proto_cc", + "//tensorflow/core/profiler/utils:op_metrics_db_utils", "//tensorflow/core/profiler/utils:xplane_builder", "//tensorflow/core/profiler/utils:xplane_schema", "//tensorflow/core/profiler/utils:xplane_test_utils", "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", "@local_tsl//tsl/profiler/utils:group_events", + "@local_tsl//tsl/profiler/utils:xplane_schema", ], ) diff --git a/tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.cc b/tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.cc index c02c3d0067e49a..39b0ef3aebfda6 100644 --- a/tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.cc +++ b/tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.cc @@ -595,6 +595,7 @@ StepSummary ComputeStepTimeSummaryInMs( // iterates over each core. for (const auto& coreid_and_stepinfo : coreid_stepinfo_map.step_info_per_core()) { + if (coreid_and_stepinfo.first >= kSparseCoreIndexStart) continue; const auto& step_info = coreid_and_stepinfo.second; max_per_step_stats_in_ms = std::max(step_info.duration_ps() / kNumPsPerMs, max_per_step_stats_in_ms); diff --git a/tensorflow/core/profiler/convert/xplane_to_op_stats.cc b/tensorflow/core/profiler/convert/xplane_to_op_stats.cc index 46f46fabe65f4b..21326dacbf98df 100644 --- a/tensorflow/core/profiler/convert/xplane_to_op_stats.cc +++ b/tensorflow/core/profiler/convert/xplane_to_op_stats.cc @@ -46,11 +46,13 @@ limitations under the License. #include "tsl/profiler/utils/tf_xplane_visitor.h" #include "tsl/profiler/utils/tpu_xplane_utils.h" #include "tsl/profiler/utils/xplane_schema.h" +#include "tsl/profiler/utils/xplane_utils.h" namespace tensorflow { namespace profiler { namespace { +using tsl::profiler::FindPlanesWithPrefix; using tsl::profiler::FindTensorCorePlanes; std::string Hostname(const XSpace& space) { @@ -179,11 +181,6 @@ void SetProgramIdToNameMap(const HloProtoMap& hlo_proto_map, OpStats ConvertXSpaceToOpStats(const XSpace& space, const OpStatsOptions& options) { - std::vector device_planes = FindTensorCorePlanes(space); - bool is_tpu = !device_planes.empty(); - if (!is_tpu) { - device_planes = FindPlanesWithPrefix(space, kGpuPlanePrefix); - } OpStats op_stats; StepEvents step_events; PropagateXSpaceDiagnosticsToOpStats(space, &op_stats); @@ -194,6 +191,14 @@ OpStats ConvertXSpaceToOpStats(const XSpace& space, KernelReportMap reports; + // Handle device planes first. device_planes will contain either GPU or TPU. + std::vector device_planes = + FindPlanesWithPrefix(space, kTpuPlanePrefix); + const bool is_gpu = device_planes.empty(); + if (is_gpu) { + device_planes = FindPlanesWithPrefix(space, kGpuPlanePrefix); + } + const bool is_tpu = !is_gpu; // TODO(b/161942993) parallelize XPlane processing per thread. for (const XPlane* device_trace : device_planes) { XPlane aggregated_xplane; diff --git a/tensorflow/core/profiler/convert/xplane_to_step_events.cc b/tensorflow/core/profiler/convert/xplane_to_step_events.cc index 8e50ed1ad18a93..e1591a43195e54 100644 --- a/tensorflow/core/profiler/convert/xplane_to_step_events.cc +++ b/tensorflow/core/profiler/convert/xplane_to_step_events.cc @@ -281,11 +281,14 @@ StepEvents ConvertDeviceTraceXPlaneToStepEvents(const XPlane& device_trace) { StepEvents device_step_events; XPlaneVisitor plane = tsl::profiler::CreateTfXPlaneVisitor(&device_trace); std::optional tpu_core_id = tsl::profiler::GetTensorCoreId(plane.Name()); + std::optional sc_core_id = tsl::profiler::GetSparseCoreId(plane.Name()); plane.ForEachLine([&](const XLineVisitor& line) { int64_t line_id = line.Id(); if (line_id == kThreadIdStepInfo || (tpu_core_id.has_value() && - line.Name() == tsl::profiler::kStepLineName)) { + line.Name() == tsl::profiler::kStepLineName) || + (sc_core_id.has_value() && + line.Name() == tsl::profiler::kSparseCoreStepLineName)) { StepEvents step_marker_events = ConvertDeviceStepInfoToStepMarkers(line); UnionCombineStepEvents(step_marker_events, &device_step_events); } else if (IsDerivedThreadId(line_id)) { @@ -300,6 +303,10 @@ StepEvents ConvertDeviceTraceXPlaneToStepEvents(const XPlane& device_trace) { stream_step_events = ConvertTpuDeviceTraceXLineToStepEvents(*tpu_core_id, line); IntersectCombineStepEvents(stream_step_events, &device_step_events); + } else if (sc_core_id.has_value()) { + stream_step_events = ConvertTpuDeviceTraceXLineToStepEvents( + kSparseCoreIndexStart + *sc_core_id, line); + IntersectCombineStepEvents(stream_step_events, &device_step_events); } else { stream_step_events = ConvertDeviceTraceXLineToStepEvents(plane.Id(), line); diff --git a/tensorflow/core/profiler/utils/op_metrics_db_utils.cc b/tensorflow/core/profiler/utils/op_metrics_db_utils.cc index c1206949c336f3..61e7fa812d7c51 100644 --- a/tensorflow/core/profiler/utils/op_metrics_db_utils.cc +++ b/tensorflow/core/profiler/utils/op_metrics_db_utils.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include +#include #include #include @@ -36,6 +37,7 @@ namespace tensorflow { namespace profiler { const absl::string_view kIdle = "IDLE"; +const uint32_t kSparseCoreIndexStart = 1000; namespace { @@ -226,6 +228,19 @@ OpMetrics* OpMetricsDbBuilder::LookupOrInsertNewOpMetrics( void XEventsOpMetricsDbBuilder::AddOpMetric( const tsl::profiler::XEventVisitor& event) { OpKey key = GetOpKeyFromHloEventMetadata(event.Metadata()); + std::optional stat = event.GetStat(StatType::kStepIdleTimePs); + if (stat.has_value()) { + uint64_t idle_time_ps = stat->IntOrUintValue(); + OpMetrics op_metrics; + op_metrics.set_self_time_ps(event.DurationPs() - idle_time_ps); + op_metrics.set_name("sparse_core_busy_ops"); + // TODO: Make it meaningful after SC stats are available. + op_metrics.set_category("sparse_core_busy_ops"); + constexpr uint64_t kMaxProgramId = std::numeric_limits::max(); + constexpr uint64_t kMaxSymbolId = std::numeric_limits::max(); + flat_op_metric_[kMaxProgramId][kMaxSymbolId] = op_metrics; + SetOpMetricsFromHloEvent(event, &op_metrics); + } if (!key.program_id.has_value() || !key.symbol_id.has_value()) return; OpMetricBySymbol& op_metric_by_symbol = flat_op_metric_[key.program_id.value()]; diff --git a/tensorflow/core/profiler/utils/op_metrics_db_utils.h b/tensorflow/core/profiler/utils/op_metrics_db_utils.h index 27cdfb61fa7800..a095a8e451cf0f 100644 --- a/tensorflow/core/profiler/utils/op_metrics_db_utils.h +++ b/tensorflow/core/profiler/utils/op_metrics_db_utils.h @@ -34,6 +34,8 @@ namespace profiler { // The name of OpMetrics to represent the idle time. TF_CONST_INIT extern const absl::string_view kIdle; +// The core index to add to sparse core index in op metrics. +TF_CONST_INIT extern const uint32_t kSparseCoreIndexStart; // Helps build an op metrics database (borrowed). // Enables fast lookup of existing ops and prevents the creation of duplicate diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_schema.cc b/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_schema.cc index deed680f1d8bb5..6c088dfb8f0192 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_schema.cc +++ b/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_schema.cc @@ -49,6 +49,7 @@ const absl::string_view kHostCpusPlaneName = "Host CPUs"; const absl::string_view kSyscallsPlaneName = "Syscalls"; const absl::string_view kStepLineName = "Steps"; +const absl::string_view kSparseCoreStepLineName = "Sparse Core Steps"; const absl::string_view kTensorFlowNameScopeLineName = "Framework Name Scope"; const absl::string_view kTensorFlowOpLineName = "Framework Ops"; const absl::string_view kXlaModuleLineName = "XLA Modules"; @@ -340,6 +341,7 @@ const StatTypeMap& GetStatTypeMap() { {"cuda_graph_id", kCudaGraphId}, {"cuda_graph_exec_id", kCudaGraphExecId}, {"cuda_graph_orig_id", kCudaGraphOrigId}, + {"step_idle_time_ps", kStepIdleTimePs}, }); DCHECK_EQ(stat_type_map->size(), kNumStatTypes); return *stat_type_map; diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_schema.h b/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_schema.h index 96ebf29d4fefff..c84c62d8c73996 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_schema.h +++ b/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_schema.h @@ -70,6 +70,7 @@ TF_CONST_INIT extern const absl::string_view kTensorFlowNameScopeLineName; TF_CONST_INIT extern const absl::string_view kTensorFlowOpLineName; TF_CONST_INIT extern const absl::string_view kXlaModuleLineName; TF_CONST_INIT extern const absl::string_view kXlaOpLineName; +TF_CONST_INIT extern const absl::string_view kSparseCoreStepLineName; TF_CONST_INIT extern const absl::string_view kXlaAsyncOpLineName; TF_CONST_INIT extern const absl::string_view kKernelLaunchLineName; TF_CONST_INIT extern const absl::string_view kSourceLineName; @@ -328,7 +329,8 @@ enum StatType { // on the GPU device when tracing is in graph level. kCudaGraphExecId, kCudaGraphOrigId, - kLastStatType = kCudaGraphOrigId, + kStepIdleTimePs, + kLastStatType = kStepIdleTimePs, }; enum MegaScaleStatType : uint8_t { diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_utils.cc b/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_utils.cc index 9e1632dd66b6b7..a327aab2d92c08 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_utils.cc +++ b/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_utils.cc @@ -548,7 +548,8 @@ void AggregateXPlane(const XPlane& full_trace, XPlane& aggregated_trace) { uint64_t last_op_end_ps = 0; plane.ForEachLine([&](const XLineVisitor& line) { - if (line.Name() == kStepLineName) { + if (line.Name() == kStepLineName || + line.Name() == kSparseCoreStepLineName) { XLineBuilder aggregated_line = aggregated_plane.GetOrCreateLine(line.Id()); aggregated_line.SetName(kStepLineName); From 86e1d9fcbc2d4437c6c418308e9c5dd27c48b38d Mon Sep 17 00:00:00 2001 From: Deqiang Chen Date: Thu, 26 Sep 2024 10:41:55 -0700 Subject: [PATCH 330/483] Avoid compilation and return failure if a model is frozen PiperOrigin-RevId: 679207753 --- .../compiler/mlir/tfrt/transforms/ifrt/BUILD | 2 + .../transforms/ifrt/ifrt_backend_compiler.cc | 6 ++ .../ifrt/ifrt_backend_compiler_test.cc | 65 +++++++++++++++++++ 3 files changed, 73 insertions(+) diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/BUILD b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/BUILD index 90f6d19218f51b..6339074d9a5ebe 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/BUILD +++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/BUILD @@ -267,12 +267,14 @@ tf_cc_test( "//tensorflow/core/tfrt/ifrt:ifrt_serving_core_selector", "//tensorflow/core/tfrt/runtime", "//tensorflow/core/tfrt/saved_model:saved_model_testutil", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_googletest//:gtest_main", "@llvm-project//mlir:AllPassesAndDialects", "@llvm-project//mlir:IR", "@llvm-project//mlir:Parser", "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:status_matchers", "@local_tsl//tsl/platform:statusor", "@local_xla//xla/python/ifrt", "@local_xla//xla/python/ifrt:test_util", diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_backend_compiler.cc b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_backend_compiler.cc index 514322b7cc1dd5..1def614f8a263a 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_backend_compiler.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_backend_compiler.cc @@ -148,6 +148,12 @@ absl::Status IfrtBackendCompiler::CompileTensorflow( "Failed to find model context for ifrt serving."); } + if ((*ifrt_model_context)->IsFrozen()) { + return absl::FailedPreconditionError( + "Cannot compile IFRT programs after the model is frozen. Please make " + "sure warmup covers all signatures by following go/tf-model-warmup."); + } + mlir::StatusScopedDiagnosticHandler diag_handler(module->getContext()); if (VLOG_IS_ON(1)) { tensorflow::DumpMlirOpToFile("ifrt_tpu_bct_conversion_before", module); diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_backend_compiler_test.cc b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_backend_compiler_test.cc index 3b190d326ce58f..0e33d3e74a68e0 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_backend_compiler_test.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_backend_compiler_test.cc @@ -16,9 +16,12 @@ limitations under the License. #include "tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_backend_compiler.h" #include +#include #include +#include #include +#include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "mlir/IR/BuiltinOps.h" // from @llvm-project @@ -40,6 +43,7 @@ limitations under the License. #include "tensorflow/core/tfrt/runtime/runtime.h" #include "tensorflow/core/tfrt/saved_model/saved_model_testutil.h" #include "tsl/platform/env.h" +#include "tsl/platform/status_matchers.h" #include "tsl/platform/statusor.h" #include "tsl/platform/threadpool.h" #include "tfrt/host_context/resource_context.h" // from @tf_runtime @@ -47,6 +51,8 @@ limitations under the License. namespace tensorflow { namespace ifrt_serving { namespace { +using ::testing::HasSubstr; +using ::tsl::testing::StatusIs; tsl::thread::ThreadPool& GetThreadPool() { constexpr int kMaxParallelism = 16; @@ -99,6 +105,65 @@ TEST(IfrtBackendCompilerTest, Basic) { TF_ASSERT_OK(compiler.CompileTensorflow(runtime_context, mlir_module.get())); } +TEST(IfrtBackendCompilerTest, CompileShallFailAfterModelIsFrozen) { + // Create test input module + constexpr absl::string_view kDataDirectory = + "tensorflow/compiler/mlir/tfrt/transforms/ifrt/testdata"; + std::string mlir_module_path = tensorflow::GetDataDependencyFilepath( + absl::StrCat(kDataDirectory, "/ifrt_cluster.mlir")); + + mlir::DialectRegistry registry; + mlir::registerAllDialects(registry); + mlir::RegisterAllTensorFlowDialects(registry); + + mlir::MLIRContext context(registry); + + mlir::OwningOpRef mlir_module = + mlir::parseSourceFile(mlir_module_path, &context); + + ASSERT_TRUE(mlir_module); + ASSERT_TRUE(mlir_module.get() != nullptr); + + // Create contexts required for the compiler execution. + TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr client, + xla::ifrt::test_util::GetClient()); + + std::unique_ptr runtime = + tensorflow::tfrt_stub::DefaultTfrtRuntime(/*num_threads=*/1); + tensorflow::tfrt_stub::GraphExecutionOptions graph_execution_options( + runtime.get()); + tfrt::ResourceContext resource_context; + tensorflow::tfrt_stub::ModelRuntimeContext runtime_context( + &graph_execution_options, /*export_dir=*/"", &resource_context); + + tsl::test_util::MockServingDeviceSelector mock_serving_device_selector; + IfrtServingCoreSelector core_selector(&mock_serving_device_selector, + client->addressable_device_count()); + + runtime_context.resource_context().CreateResource( + "IfrtModelContext", client, &core_selector, &GetThreadPool(), + /*compilation_environment_proto=*/nullptr); + + IfrtBackendCompiler compiler; + TF_ASSERT_OK(compiler.CompileTensorflow(runtime_context, mlir_module.get())); + + std::optional ifrt_model_context = + runtime_context.resource_context().GetResource( + "IfrtModelContext"); + ASSERT_TRUE(ifrt_model_context.has_value()); + + TF_ASSERT_OK((*ifrt_model_context)->Freeze()); + + mlir::OwningOpRef another_mlir_module = + mlir::parseSourceFile(mlir_module_path, &context); + + EXPECT_THAT( + compiler.CompileTensorflow(runtime_context, another_mlir_module.get()), + StatusIs( + absl::StatusCode::kFailedPrecondition, + HasSubstr("Cannot compile IFRT programs after the model is frozen"))); +} + } // namespace } // namespace ifrt_serving } // namespace tensorflow From 52a79643c7cc9e37453c71b6327a681dda02d398 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 26 Sep 2024 10:57:58 -0700 Subject: [PATCH 331/483] We should skip async-start for a call to host computation instead of returning an error in host_offload_legalize, during the copy collection phase. PiperOrigin-RevId: 679214185 --- third_party/xla/xla/service/BUILD | 1 + .../xla/xla/service/host_offload_legalize.cc | 37 +++++++++++++++---- .../xla/service/host_offload_legalize_test.cc | 30 +++++++++++++++ 3 files changed, 61 insertions(+), 7 deletions(-) diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD index e74dbda5c60e50..9b51d54582b12f 100644 --- a/third_party/xla/xla/service/BUILD +++ b/third_party/xla/xla/service/BUILD @@ -6510,6 +6510,7 @@ cc_library( "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:macros", "@local_tsl//tsl/platform:statusor", ], ) diff --git a/third_party/xla/xla/service/host_offload_legalize.cc b/third_party/xla/xla/service/host_offload_legalize.cc index e4f9a3a227cd44..e02a459e2bcb2b 100644 --- a/third_party/xla/xla/service/host_offload_legalize.cc +++ b/third_party/xla/xla/service/host_offload_legalize.cc @@ -245,7 +245,8 @@ absl::StatusOr WalkUpMemoryOffload( // instruction at a time, but returns multiple instructions for each conforming // user. absl::StatusOr> WalkDownMemoryOffload( - const InstructionAndIndex& current_value, const CallGraph& call_graph) { + const InstructionAndIndex& current_value, const CallGraph& call_graph, + bool for_move_copy_phase) { // TODO(maggioni): Verify that set of instructions supported in chain by // legalization is in sync with host_offloader. VLOG(6) << "Getting users of: \"" << current_value.instruction->ToString() @@ -348,8 +349,23 @@ absl::StatusOr> WalkDownMemoryOffload( results.emplace_back(user, current_value.index); break; } + case HloOpcode::kAsyncStart: { + if (user->async_execution_thread() == HloInstruction::kHostThread) { + // For move copy phase, we need to handle the copy even though we + // never move the tensor to device yet. For now just throw an error. + CHECK(!for_move_copy_phase) + << "Transpose copy going into host call is not supported yet."; + + // For first phase to collect copies to move, it's ok to ignore this + // path since we don't see copies along the path yet and it's ok to + // pass host tensor to the async host call. + break; + } + [[fallthrough]]; + } default: { - return absl::InvalidArgumentError("Unrecognized user opcode"); + return absl::InvalidArgumentError( + absl::StrFormat("Unrecognized user name: %s", user->name())); } } } @@ -423,11 +439,12 @@ absl::Status MoveCopy( current_instruction_and_shapes.instruction_and_index; stack.pop_back(); VLOG(5) << "Current top of stack: " - << current_instruction_and_index.instruction->ToString() << " " - << current_instruction_and_index.index; + << current_instruction_and_index.instruction->ToString() + << ", index: " << current_instruction_and_index.index; // Get the users of the current instruction. absl::StatusOr> current_value_down = - WalkDownMemoryOffload(current_instruction_and_index, *call_graph); + WalkDownMemoryOffload(current_instruction_and_index, *call_graph, + /*for_move_copy_phase=*/true); if (!current_value_down.ok()) { VLOG(5) << "WalkDownMemoryOffload failed: " << current_value_down.status(); @@ -677,7 +694,8 @@ absl::StatusOr ProcessAnnotationForCopyMovement( std::vector stack = {current_value}; while (!stack.empty()) { VLOG(5) << "Current value before down: " - << stack.back().instruction->ToString(); + << stack.back().instruction->ToString() << " " + << stack.back().index; if (absl::c_linear_search(kUsersOpcodes, stack.back().instruction->opcode()) || stack.back().instruction->IsCustomCall( @@ -737,7 +755,8 @@ absl::StatusOr ProcessAnnotationForCopyMovement( continue; } absl::StatusOr> current_value_down = - WalkDownMemoryOffload(stack.back(), *call_graph); + WalkDownMemoryOffload(stack.back(), *call_graph, + /*for_move_copy_phase=*/false); if (!current_value_down.ok()) { VLOG(5) << "Current value down failed: " << current_value_down.status(); break; @@ -758,6 +777,10 @@ absl::StatusOr ProcessAnnotationForCopyMovement( } } + if (copies_to_move.empty()) { + return false; + } + // Process all copies one at a time from the last to the first and push it to // its specific user. for (auto it = copies_to_move.rbegin(); it != copies_to_move.rend(); ++it) { diff --git a/third_party/xla/xla/service/host_offload_legalize_test.cc b/third_party/xla/xla/service/host_offload_legalize_test.cc index 55c36a5310f6db..2366a6cc97e041 100644 --- a/third_party/xla/xla/service/host_offload_legalize_test.cc +++ b/third_party/xla/xla/service/host_offload_legalize_test.cc @@ -74,6 +74,36 @@ class HostOffloadLegalizeTest : public HloTestBase { } }; +TEST_F(HostOffloadLegalizeTest, TestWithAsyncCall) { + const std::string& hlo_string = R"( +HloModule jit_update, entry_computation_layout={(f32[20,3,256,133]{2,3,1,0:T(8,128)S(5)})->(f32[20,3,256,133]{2,1,0,3:T(4,128)}, f32[4096]{0:T(1024)})} + +%async_computation { + %param_0 = f32[20,3,256,133] parameter(0) + ROOT %offloaded-custom-call = f32[4096] custom-call(%param_0), custom_call_target="HostExecute" +}, execution_thread="host" + +ENTRY main { + %param.246 = f32[20,3,256,133] parameter(0) + %async-start = ((f32[20,3,256,133]), f32[4096], u32[]) async-start(%param.246), async_execution_thread="host", calls=%async_computation + %async-done = f32[4096] custom-call-done(%async-start) + copy.16744 = f32[20,3,256,133]{2,1,0,3:T(4,128)} copy(param.246) + custom-call.7832 = f32[20,3,256,133]{2,1,0,3:T(4,128)} custom-call(copy.16744), custom_call_target="MoveToDevice" + ROOT tuple.16745 = (f32[20,3,256,133]{2,1,0,3:T(4,128)}, f32[4096]{0:T(1024)}) tuple(custom-call.7832, %async-done) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHostOffloadLegalize(module.get())); + EXPECT_TRUE(changed); + HloInstruction* custom_call = + FindInstruction(module.get(), "custom-call.7832"); + ASSERT_NE(custom_call, nullptr); + EXPECT_EQ(custom_call->users()[0]->opcode(), HloOpcode::kCopy); + XLA_VLOG_LINES(1, module->ToString()); +} + TEST_F(HostOffloadLegalizeTest, NoCopyWithOptBarrierMoreElaborate) { const std::string& hlo_string = R"( HloModule jit_f, entry_computation_layout={(f32[16,256]{0,1})->f32[16,256]{1,0}} From dc326974577afb565d3a40ede2ce839997b3805e Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 26 Sep 2024 11:04:03 -0700 Subject: [PATCH 332/483] [Refactor] Split huge BarrierAsync() method into a few helper methods. PiperOrigin-RevId: 679216991 --- .../coordination/coordination_service.cc | 173 +++++++++++------- 1 file changed, 109 insertions(+), 64 deletions(-) diff --git a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service.cc b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service.cc index bb10595516c73b..90c9f0ba73078b 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service.cc +++ b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service.cc @@ -192,6 +192,20 @@ class CoordinationServiceStandaloneImpl : public CoordinationServiceInterface { // barrier). CoordinatedTask initiating_task; }; + // Validates that the barrier is invoked with the right args. Returns false if + // the barrier should fail immediately. + bool ValidateBarrierArgs( + std::string_view barrier_id, absl::Duration timeout, + const CoordinatedTask& task, + const std::vector& participating_tasks, + StatusCallback done); + // Initializes a new barrier. Returns false if the barrier should fail + // immediately. + bool InitializeBarrier( + BarrierState* barrier, std::string_view barrier_id, + absl::Duration timeout, const CoordinatedTask& task, + const std::vector& participating_tasks, + StatusCallback done) ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_); void PassBarrier(std::string_view barrier_id, absl::Status result, BarrierState* barrier) ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_); @@ -1231,14 +1245,13 @@ void CoordinationServiceStandaloneImpl::PollForErrorAsync( error_polling_state_.AddTask(task, std::move(done)); } -void CoordinationServiceStandaloneImpl::BarrierAsync( +// Validates that the barrier is invoked with the right args. Returns false if +// the barrier should fail immediately. +bool CoordinationServiceStandaloneImpl::ValidateBarrierArgs( std::string_view barrier_id, absl::Duration timeout, const CoordinatedTask& task, const std::vector& participating_tasks, StatusCallback done) { - VLOG(3) << "Task " << GetTaskName(task) << " invoked BarrierAsync(" - << barrier_id << ")."; - // Check if caller task is participating in the barrier. If not, update // `barriers_` to cause subsequent calls from the same task and other tasks // that have already called this instance of the barrier to fail. @@ -1262,7 +1275,7 @@ void CoordinationServiceStandaloneImpl::BarrierAsync( if (ServiceHasStopped()) { done(MakeCoordinationError(absl::InternalError( "Barrier requested after coordination service has shut down."))); - return; + return false; } auto pair = barriers_.try_emplace(barrier_id); auto it = pair.first; @@ -1272,8 +1285,93 @@ void CoordinationServiceStandaloneImpl::BarrierAsync( PassBarrier(barrier_id, error, barrier); } done(error); - return; + return false; + } + return true; +}; + +// Initializes a new barrier. Returns false if the barrier should fail +// immediately. +bool CoordinationServiceStandaloneImpl::InitializeBarrier( + BarrierState* barrier, std::string_view barrier_id, absl::Duration timeout, + const CoordinatedTask& task, + const std::vector& participating_tasks, + StatusCallback done) { + // Initialize barrier state. + barrier->passed = false; + barrier->initiating_task = task; + // Assume barrier is for entire cluster if no tasks are specified. + if (participating_tasks.empty()) { + for (const auto& task_state : cluster_state_) { + std::string_view task_name = task_state.first; + barrier->tasks_at_barrier[GetTaskFromName(task_name)] = false; + } + } else { + for (const auto& task : participating_tasks) { + // Fail the barrier immediately if unexpected task is included in the + // barrier. + const std::string task_name = GetTaskName(task); + if (!cluster_state_.contains(task_name)) { + absl::Status error = MakeCoordinationError(absl::InvalidArgumentError( + absl::StrCat("Unexpected task (", task_name, + ") that is not in the cluster called the barrier. " + "Barrier Id: ", + barrier_id))); + PassBarrier(barrier_id, error, barrier); + done(error); + return false; + } + barrier->tasks_at_barrier[task] = false; + } + } + barrier->num_pending_tasks = barrier->tasks_at_barrier.size(); + + // Fail the barrier immediately if any tasks are already in error. + for (const auto& pending_task : barrier->tasks_at_barrier) { + const std::string task_name = GetTaskName(pending_task.first); + if (cluster_state_[task_name]->GetState() == + CoordinatedTaskState::TASKSTATE_ERROR) { + absl::Status error = MakeCoordinationError(absl::InternalError( + absl::StrCat("Task (", task_name, + ") is already in error before the barrier " + "was called. Barrier Id: ", + barrier_id))); + PassBarrier(barrier_id, error, barrier); + done(error); + return false; + } + } + barrier->deadline_in_micros = + Env::Default()->NowMicros() + (timeout / absl::Microseconds(1)); + + // Add ongoing barrier to cluster state. + ongoing_barriers_.emplace(barrier_id); + const size_t num_ongoing_barriers = ongoing_barriers_.size(); + if (num_ongoing_barriers > kOngoingBarriersSoftLimit) { + LOG(WARNING) << "There is a high number of ongoing barriers in " + "coordination service: " + << num_ongoing_barriers; + } + for (const auto& pending_task : barrier->tasks_at_barrier) { + const CoordinatedTask& task = pending_task.first; + cluster_state_[GetTaskName(task)]->JoinBarrier(barrier_id); } + return true; +} + +void CoordinationServiceStandaloneImpl::BarrierAsync( + std::string_view barrier_id, absl::Duration timeout, + const CoordinatedTask& task, + const std::vector& participating_tasks, + StatusCallback done) { + VLOG(3) << "Task " << GetTaskName(task) << " invoked BarrierAsync(" + << barrier_id << ")."; + + if (!ValidateBarrierArgs(barrier_id, timeout, task, participating_tasks, + done)) { + return; // Exit early if args are wrong. + } + absl::MutexLock l(&state_mu_); // Check if coordination service has stopped. If so, return an error // immediately. @@ -1282,70 +1380,17 @@ void CoordinationServiceStandaloneImpl::BarrierAsync( "Barrier requested after coordination service has shut down."))); return; } + auto pair = barriers_.try_emplace(barrier_id); auto it = pair.first; bool inserted = pair.second; auto* barrier = &it->second; + // Create barrier for the first time. if (inserted) { - // Initialize barrier state. - barrier->passed = false; - barrier->initiating_task = task; - // Assume barrier is for entire cluster if no tasks are specified. - if (participating_tasks.empty()) { - for (const auto& task_state : cluster_state_) { - std::string_view task_name = task_state.first; - barrier->tasks_at_barrier[GetTaskFromName(task_name)] = false; - } - } else { - for (const auto& task : participating_tasks) { - // Fail the barrier immediately if unexpected task is included in the - // barrier. - const std::string task_name = GetTaskName(task); - if (!cluster_state_.contains(task_name)) { - absl::Status error = MakeCoordinationError(absl::InvalidArgumentError( - absl::StrCat("Unexpected task (", task_name, - ") that is not in the cluster called the barrier. " - "Barrier Id: ", - barrier_id))); - PassBarrier(barrier_id, error, barrier); - done(error); - return; - } - barrier->tasks_at_barrier[task] = false; - } - } - barrier->num_pending_tasks = barrier->tasks_at_barrier.size(); - - // Fail the barrier immediately if any tasks are already in error. - for (const auto& pending_task : barrier->tasks_at_barrier) { - const std::string task_name = GetTaskName(pending_task.first); - if (cluster_state_[task_name]->GetState() == - CoordinatedTaskState::TASKSTATE_ERROR) { - absl::Status error = MakeCoordinationError(absl::InternalError( - absl::StrCat("Task (", task_name, - ") is already in error before the barrier " - "was called. Barrier Id: ", - barrier_id))); - PassBarrier(barrier_id, error, barrier); - done(error); - return; - } - } - barrier->deadline_in_micros = - Env::Default()->NowMicros() + (timeout / absl::Microseconds(1)); - - // Add ongoing barrier to cluster state. - ongoing_barriers_.emplace(barrier_id); - const size_t num_ongoing_barriers = ongoing_barriers_.size(); - if (num_ongoing_barriers > kOngoingBarriersSoftLimit) { - LOG(WARNING) << "There is a high number of ongoing barriers in " - "coordination service: " - << num_ongoing_barriers; - } - for (const auto& pending_task : barrier->tasks_at_barrier) { - const CoordinatedTask& task = pending_task.first; - cluster_state_[GetTaskName(task)]->JoinBarrier(barrier_id); + if (!InitializeBarrier(barrier, barrier_id, timeout, task, + participating_tasks, done)) { + return; // Exit early if barrier init failed. } } From 71d40cdddee0cbf87e0a9d030075a94f685ea53d Mon Sep 17 00:00:00 2001 From: Christian Sigg Date: Thu, 26 Sep 2024 11:56:09 -0700 Subject: [PATCH 333/483] Add a pass to nest gemm fusions. This pass takes a fusion with a single dot and creates nested fusions for the two dot operands. The nested fusions are annotated with block_level_fusion_config specifying the tile sizes as propagated from the output through the dot to the operands. This pass is not hooked up yet other than for one initial test. The next step is to add more complex test cases and extend the implementation to handle those. PiperOrigin-RevId: 679238519 --- .../xla/xla/service/gpu/transforms/BUILD | 57 +++ .../gpu/transforms/nest_gemm_fusion.cc | 363 ++++++++++++++++++ .../service/gpu/transforms/nest_gemm_fusion.h | 49 +++ .../gpu/transforms/nest_gemm_fusion_test.cc | 109 ++++++ 4 files changed, 578 insertions(+) create mode 100644 third_party/xla/xla/service/gpu/transforms/nest_gemm_fusion.cc create mode 100644 third_party/xla/xla/service/gpu/transforms/nest_gemm_fusion.h create mode 100644 third_party/xla/xla/service/gpu/transforms/nest_gemm_fusion_test.cc diff --git a/third_party/xla/xla/service/gpu/transforms/BUILD b/third_party/xla/xla/service/gpu/transforms/BUILD index 1a04da032cfce6..1deb2aa6dddb23 100644 --- a/third_party/xla/xla/service/gpu/transforms/BUILD +++ b/third_party/xla/xla/service/gpu/transforms/BUILD @@ -2095,6 +2095,63 @@ xla_cc_test( ], ) +cc_library( + name = "nest_gemm_fusion", + srcs = ["nest_gemm_fusion.cc"], + hdrs = ["nest_gemm_fusion.h"], + deps = [ + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "//xla/hlo/utils:hlo_query", + "//xla/service:hlo_dce", + "//xla/service:instruction_fusion", + "//xla/service/gpu:ir_emission_utils", + "//xla/service/gpu:matmul_utils", + "//xla/service/gpu/model:symbolic_tile_analysis", + "//xla/service/gpu/model:symbolic_tiled_hlo_instruction", + "//xla/service/gpu/model:tiled_hlo_instruction_or_computation", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "nest_gemm_fusion_test", + srcs = ["nest_gemm_fusion_test.cc"], + tags = [ + "nomsan", + ], + deps = [ + ":nest_gemm_fusion", + "//xla:shape_util", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_cost_analysis", + "//xla/service:pattern_matcher", + "//xla/service:pattern_matcher_gmock", + "//xla/service/gpu:backend_configs_cc", + "//xla/service/gpu:gpu_device_info_for_tests", + "//xla/service/gpu:gpu_fusible", + "//xla/tests:filecheck", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", + "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest_main", + "@llvm-project//llvm:Support", + "@local_tsl//tsl/platform:statusor", + ], +) + cc_library( name = "pipelined_p2p_rewriter", srcs = ["pipelined_p2p_rewriter.cc"], diff --git a/third_party/xla/xla/service/gpu/transforms/nest_gemm_fusion.cc b/third_party/xla/xla/service/gpu/transforms/nest_gemm_fusion.cc new file mode 100644 index 00000000000000..d0f4f5fe97c943 --- /dev/null +++ b/third_party/xla/xla/service/gpu/transforms/nest_gemm_fusion.cc @@ -0,0 +1,363 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/transforms/nest_gemm_fusion.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/types/span.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/IR/MLIRContext.h" +#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" +#include "xla/hlo/ir/hlo_casting_utils.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/utils/hlo_query.h" +#include "xla/service/gpu/ir_emission_utils.h" +#include "xla/service/gpu/matmul_utils.h" +#include "xla/service/gpu/model/symbolic_tile_analysis.h" +#include "xla/service/gpu/model/symbolic_tiled_hlo_instruction.h" +#include "xla/service/gpu/model/tiled_hlo_computation.h" +#include "xla/service/hlo_dce.h" +#include "xla/service/instruction_fusion.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" + +namespace xla::gpu { + +namespace { +// Fuses the given instructions together. The instructions are expected to be +// passed in def-before-use order. The resulting fusion has a single root +// instruction, which is the last instructions in the input span. We only +// replace the uses of the root in 'consumer', and leave other users alone. +absl::Status FuseInstructionsForConsumer( + const std::vector& instructions, + HloInstruction& consumer) { + HloComputation::Builder builder(instructions.back()->name()); + + absl::flat_hash_map + old_to_new_mapping; + std::vector parameters; + + auto add_parameter = [&](HloInstruction* instruction) -> void { + int param_index = parameters.size(); + old_to_new_mapping[instruction] = + builder.AddInstruction(HloInstruction::CreateParameter( + param_index, instruction->shape(), + absl::StrCat("parameter_", param_index))); + parameters.push_back(instruction); + }; + + for (HloInstruction* instruction : instructions) { + if (old_to_new_mapping.contains(instruction)) { + continue; + } + + if (instruction->opcode() == HloOpcode::kParameter) { + add_parameter(instruction); + continue; + } + std::vector new_operands; + for (HloInstruction* operand : instruction->mutable_operands()) { + if (!old_to_new_mapping.contains(operand)) { + add_parameter(operand); + } + new_operands.push_back(old_to_new_mapping[operand]); + } + old_to_new_mapping[instruction] = builder.AddInstruction( + instruction->CloneWithNewOperands(instruction->shape(), new_operands)); + } + + HloInstruction* old_root = instructions.back(); + old_to_new_mapping[old_root]->MarkAsRoot(); + + HloComputation* computation = + old_root->GetModule()->AddComputationAndUnifyNamesAndIds( + builder.Build(), /*is_entry=*/false); + HloInstruction* fusion = + old_root->parent()->AddInstruction(HloInstruction::CreateFusion( + old_root->shape(), HloInstruction::FusionKind::kCustom, parameters, + computation)); + fusion->GetModule()->SetAndUniquifyInstrName(fusion, "block_fusion"); + + TF_ASSIGN_OR_RETURN(auto gpu_config, + fusion->backend_config()); + FusionBackendConfig& backend_config = + *gpu_config.mutable_fusion_backend_config(); + backend_config.set_kind(std::string(kTritonFusionKind)); + TF_RETURN_IF_ERROR(fusion->set_backend_config(gpu_config)); + + for (int64_t operand_index : consumer.OperandIndices(old_root)) { + TF_RETURN_IF_ERROR(consumer.ReplaceOperandWith(operand_index, fusion)); + } + + return absl::OkStatus(); +} + +// Annotates the given nested fusion with the given tile sizes. +// Implementation for AnnotateDotLhs/RhsNestedFusion(). +absl::Status AnnotateDotOperandNestedFusionImpl( + HloFusionInstruction& nested_fusion, const HloDotInstruction& dot, + const TritonGemmConfig& config, + absl::Span contracting_dimensions, // Must be single element + absl::Span batch_dimensions, int64_t contracting_dim_size, + int64_t non_contracting_dim_size) { + if (contracting_dimensions.size() != 1) { + return absl::InternalError( + absl::StrCat("Expected a single lhs contracting dimension but got ", + contracting_dimensions.size())); + } + + TF_ASSIGN_OR_RETURN( + std::vector non_contracting_dimensions, + GetNonContractingDims(dot.operand(0)->shape(), batch_dimensions, + contracting_dimensions)); + + if (non_contracting_dimensions.size() != 1) { + return absl::InternalError( + absl::StrCat("Expected a single non-contracting dimension but got ", + non_contracting_dimensions.size())); + } + + // We have a single contracting dimension, and a single non-contracting + // dimension. All the other output tile sizes are set to 1. + std::vector output_tile_sizes(dot.operand(0)->shape().rank(), 1); + output_tile_sizes[contracting_dimensions[0]] = contracting_dim_size; + output_tile_sizes[non_contracting_dimensions[0]] = non_contracting_dim_size; + + BlockLevelParameters block_level_parameters; + block_level_parameters.output_tile_sizes = std::move(output_tile_sizes); + + TF_ASSIGN_OR_RETURN(auto backend_config, + nested_fusion.backend_config()); + *backend_config.mutable_fusion_backend_config() + ->mutable_block_level_fusion_config() = + block_level_parameters.ToBlockLevelFusionConfig(); + TF_RETURN_IF_ERROR(nested_fusion.set_backend_config(backend_config)); + + return absl::OkStatus(); +} + +absl::Status AnnotateDotLhsNestedFusion(HloFusionInstruction& nested_fusion, + const HloDotInstruction& dot, + const TritonGemmConfig& config) { + const DotDimensionNumbers& dimension_numbers = dot.dot_dimension_numbers(); + return AnnotateDotOperandNestedFusionImpl( + nested_fusion, dot, config, + dimension_numbers.lhs_contracting_dimensions(), + dimension_numbers.lhs_batch_dimensions(), config.block_k, config.block_m); +} + +absl::Status AnnotateDotRhsNestedFusion(HloFusionInstruction& nested_fusion, + const HloDotInstruction& dot, + const TritonGemmConfig& config) { + const DotDimensionNumbers& dimension_numbers = dot.dot_dimension_numbers(); + return AnnotateDotOperandNestedFusionImpl( + nested_fusion, dot, config, + dimension_numbers.rhs_contracting_dimensions(), + dimension_numbers.rhs_batch_dimensions(), config.block_k, config.block_n); +} + +// Finds tile sizes for the root of the analysis that satisfy the +// requirements of the dot. That is, the tile sizes need to satisfy the +// constraints of the analysis and map to the given config of the dot. +absl::StatusOr> FindOutputTileSizesForEpilogue( + const SymbolicTiledHloInstruction& tiled_dot, + const SymbolicTileAnalysis& analysis, const TritonGemmConfig& config) { + int64_t dot_rank = tiled_dot.symbolic_tile().tile_map().GetDimensionCount(); + llvm::SmallVector expected_dot_tile_sizes(dot_rank, 1); + // We always expect the shape of the dot to be [1, ..., block_m, block_n]. + expected_dot_tile_sizes[dot_rank - 2] = config.block_m; + expected_dot_tile_sizes[dot_rank - 1] = config.block_n; + + // Try all permutations of the dot tile sizes to see if any of them satisfy + // the constraints of the analysis and map to the given config of the dot. + llvm::SmallVector output_tile_sizes = expected_dot_tile_sizes; + std::sort(output_tile_sizes.begin(), output_tile_sizes.end()); + do { + TF_ASSIGN_OR_RETURN( + bool parameters_satisfy_constraints, + analysis.ParametersSatisfyConstraints(output_tile_sizes)); + if (!parameters_satisfy_constraints) { + continue; + } + auto mapped_dot_tile_sizes = tiled_dot.TileSizes(output_tile_sizes); + if (mapped_dot_tile_sizes == expected_dot_tile_sizes) { + return output_tile_sizes; + } + } while (std::next_permutation(output_tile_sizes.begin(), + output_tile_sizes.end())); + + return absl::InternalError(absl::StrCat( + "Couldn't find output tile sizes that satisfy ", tiled_dot.ToString())); +} + +// Extracts the TritonGemmConfig from the given fusion's backend config. +absl::StatusOr GetTritonGemmConfig( + const HloFusionInstruction& fusion) { + TF_ASSIGN_OR_RETURN(auto gpu_config, + fusion.backend_config()); + const FusionBackendConfig& backend_config = + gpu_config.fusion_backend_config(); + if (!backend_config.has_triton_gemm_config()) { + return absl::InternalError( + "The fusion's backend config doesn't have a triton_gemm_config."); + } + return TritonGemmConfig::FromProto(backend_config.triton_gemm_config()); +} + +// Transforms a fusion into an equivalent nested fusion if it has a single dot. +// Returns true if the transformation was successful. +absl::Status MakeNestedFusionFromGemmFusion( + HloFusionInstruction* fusion, const TritonGemmConfig& config, + const SymbolicTileAnalysis& analysis, + const SymbolicTiledHloInstruction& tiled_dot, HloDotInstruction* dot) { + DCHECK(GetTritonGemmConfig(*fusion).value() == config); + DCHECK_EQ(tiled_dot.hlo(), dot); + + HloComputation* computation = fusion->called_computation(); + + // Left-hand side of the dot. + TF_RETURN_IF_ERROR(FuseInstructionsForConsumer( + computation->MakeInstructionPostOrderFrom(*dot->mutable_operand(0)), + *dot)); + TF_RETURN_IF_ERROR(AnnotateDotLhsNestedFusion( + *::xla::Cast(dot->mutable_operand(0)), *dot, + config)); + + // Right-hand side of the dot. + TF_RETURN_IF_ERROR(FuseInstructionsForConsumer( + computation->MakeInstructionPostOrderFrom(*dot->mutable_operand(1)), + *dot)); + TF_RETURN_IF_ERROR(AnnotateDotRhsNestedFusion( + *::xla::Cast(dot->mutable_operand(1)), *dot, + config)); + + // Delete newly unused instructions, if any. + TF_ASSIGN_OR_RETURN([[maybe_unused]] bool changed, + HloDCE::RunOnComputation( + computation, + /*remove_cross_partition_collective_ops=*/false)); + + // Annotate the fusion itself. + TF_ASSIGN_OR_RETURN( + llvm::SmallVector output_tile_sizes, + FindOutputTileSizesForEpilogue(tiled_dot, analysis, config)); + + TF_ASSIGN_OR_RETURN(auto gpu_config, + fusion->backend_config()); + FusionBackendConfig& backend_config = + *gpu_config.mutable_fusion_backend_config(); + backend_config.set_kind(std::string(kTritonFusionKind)); + + BlockLevelParameters block_level_parameters; + block_level_parameters.output_tile_sizes.assign(output_tile_sizes.begin(), + output_tile_sizes.end()); + + *backend_config.mutable_block_level_fusion_config() = + block_level_parameters.ToBlockLevelFusionConfig(); + TF_RETURN_IF_ERROR(fusion->set_backend_config(gpu_config)); + + return absl::OkStatus(); +} + +size_t GetDotCount(HloComputation* computation) { + return absl::c_count_if(computation->instructions(), [](HloInstruction* hlo) { + return hlo->opcode() == HloOpcode::kDot; + }); +} + +class NestGemmFusionVisitor : public DfsHloRewriteVisitor { + public: + explicit NestGemmFusionVisitor(mlir::MLIRContext* ctx) : ctx_(ctx) {} + + absl::Status HandleFusion(HloInstruction* instruction) override { + HloFusionInstruction* fusion = Cast(instruction); + + absl::StatusOr config = GetTritonGemmConfig(*fusion); + if (!config.ok()) { + return absl::OkStatus(); // Skip because it's not a Triton gemm fusion. + } + + HloComputation* computation = fusion->called_computation(); + HloInstruction* dot = + hlo_query::GetFirstInstructionWithOpcode(*computation, HloOpcode::kDot); + if (dot == nullptr) { + return absl::OkStatus(); // Skip because fusion has no dot. + } + DCHECK_EQ(GetDotCount(computation), 1) << "Fusion has more than one dot."; + SymbolicTileAnalysisOrError analysis_or = + SymbolicTileAnalysis::AnalyzeComputation( + *fusion->called_computations()[0], ctx_); + + if (std::holds_alternative(analysis_or)) { + return absl::InternalError( + absl::StrCat("Failed to analyze the computation (", + std::get(analysis_or).Explain(), + "): ", fusion->called_computation()->ToString())); + } + + auto& analysis = std::get(analysis_or); + auto tiled_dot_it = absl::c_find_if( + analysis.GetSymbolicTiledHloComputation(), + [&](const auto& tiled_hlo) { return tiled_hlo->hlo() == dot; }); + if (tiled_dot_it == analysis.GetSymbolicTiledHloComputation().end()) { + return absl::InternalError(absl::StrCat( + "Couldn't find a symbolic tiled instruction for ", dot->ToString())); + } + + TF_RETURN_IF_ERROR(MakeNestedFusionFromGemmFusion( + fusion, config.value(), analysis, **tiled_dot_it, + Cast(dot))); + this->MarkAsChanged(); + return absl::OkStatus(); + } + + private: + mlir::MLIRContext* ctx_; +}; + +} // namespace + +absl::StatusOr NestGemmFusion::Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) { + bool changed = false; + mlir::MLIRContext ctx; + for (HloComputation* computation : + module->MakeNonfusionComputations(execution_threads)) { + NestGemmFusionVisitor visitor(&ctx); + TF_RETURN_IF_ERROR(computation->Accept(&visitor)); + changed |= visitor.changed(); + } + return changed; +} + +} // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/transforms/nest_gemm_fusion.h b/third_party/xla/xla/service/gpu/transforms/nest_gemm_fusion.h new file mode 100644 index 00000000000000..aee2ece23afd33 --- /dev/null +++ b/third_party/xla/xla/service/gpu/transforms/nest_gemm_fusion.h @@ -0,0 +1,49 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_TRANSFORMS_NEST_GEMM_FUSION_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_NEST_GEMM_FUSION_H_ + +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/pass/hlo_pass_interface.h" + +namespace xla::gpu { + +// Rewrites Triton GEMM fusions to generic Triton fusions. Any other fusions are +// left unchanged. +// +// The fusion's backend config is set to a BlockLevelFusionConfig, derived from +// a previously set TritonGemmConfig. +// +// The operands of the dot (including their prologues) are fused into two new +// nested fusions, each with their own BlockLevelFusionConfig. +class NestGemmFusion : public HloModulePass { + public: + absl::string_view name() const override { return "nest_gemm_fusion"; } + + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; + + private: +}; + +} // namespace xla::gpu + +#endif // XLA_SERVICE_GPU_TRANSFORMS_NEST_GEMM_FUSION_H_ diff --git a/third_party/xla/xla/service/gpu/transforms/nest_gemm_fusion_test.cc b/third_party/xla/xla/service/gpu/transforms/nest_gemm_fusion_test.cc new file mode 100644 index 00000000000000..fbe05e1a41ede6 --- /dev/null +++ b/third_party/xla/xla/service/gpu/transforms/nest_gemm_fusion_test.cc @@ -0,0 +1,109 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/transforms/nest_gemm_fusion.h" + +#include + +#include +#include +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/service/gpu/backend_configs.pb.h" +#include "xla/service/pattern_matcher.h" +#include "xla/service/pattern_matcher_gmock.h" +#include "xla/tests/hlo_test_base.h" +#include "xla/tsl/lib/core/status_test_util.h" +#include "tsl/platform/statusor.h" + +using ::testing::ElementsAre; + +namespace xla { + +// Gtest hook to pretty-print an HloInstruction. +static void PrintTo(const HloInstruction& hlo, std::ostream* os) { + *os << hlo.ToString(); +} + +namespace gpu { +namespace { + +// Wraps a matcher for a fusion instruction's output tile sizes. +// Proto matchers would be nice, but b/229726259 is P2. +MATCHER_P(OutputTileSizesIs, matcher, "") { + auto backend_config = arg.template backend_config(); + if (!backend_config.ok()) { + *result_listener << "failed to get backend config: " + << backend_config.status(); + return false; + } + FusionBackendConfig fusion_backend_config = + backend_config->fusion_backend_config(); + if (!fusion_backend_config.has_block_level_fusion_config()) { + *result_listener << "has no block level fusion config"; + return false; + } + auto output_tile_sizes = + fusion_backend_config.block_level_fusion_config().output_tile_sizes(); + return ExplainMatchResult(matcher, output_tile_sizes, result_listener); +} + +class NestGemmFusionTest : public HloTestBase {}; + +TEST_F(NestGemmFusionTest, BasicTest) { + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"( +HloModule module + +dot { + lhs = bf16[8192,512] parameter(0) + rhs = bf16[512,512] parameter(1) + ROOT %dot = bf16[8192,512] dot(lhs, rhs), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +} + +ENTRY entry { + p0 = bf16[8192,512] parameter(0) + p1 = bf16[512,512] parameter(1) + ROOT fusion = bf16[8192,512] fusion(p0, p1), + kind=kCustom, calls=dot, backend_config={ + "fusion_backend_config": { + "kind":"__triton_gemm", "triton_gemm_config": { + "block_m":"64", "block_n":"256", "block_k":"32", + "split_k":"1", "num_stages":"1", "num_warps":"1", "num_ctas":"1" + } + } + } +} +)")); + + TF_ASSERT_OK_AND_ASSIGN(bool changed, NestGemmFusion().Run(module.get())) + EXPECT_TRUE(changed); + TF_ASSERT_OK(verifier().Run(module.get()).status()); + + const HloInstruction* fusion = nullptr; + ASSERT_THAT(module->entry_computation()->root_instruction(), + GmockMatch(match::Fusion(&fusion))); + EXPECT_THAT(*fusion, OutputTileSizesIs(ElementsAre(64, 256))); + + const HloInstruction* lhs = nullptr; + const HloInstruction* rhs = nullptr; + EXPECT_THAT(fusion->fused_expression_root(), + GmockMatch(match::Dot(match::Fusion(&lhs), match::Fusion(&rhs)))); + EXPECT_THAT(*lhs, OutputTileSizesIs(ElementsAre(64, 32))); + EXPECT_THAT(*rhs, OutputTileSizesIs(ElementsAre(32, 256))); +} + +} // namespace +} // namespace gpu +} // namespace xla From caec1524d60111c39bf3f7280566d69b402ba74f Mon Sep 17 00:00:00 2001 From: Bixia Zheng Date: Thu, 26 Sep 2024 12:07:02 -0700 Subject: [PATCH 334/483] [xla:spmd:shardy:nfc] Make xla_dump_to=sponge work. PiperOrigin-RevId: 679243179 --- .../xla/xla/service/spmd/shardy/shardy_xla_pass.cc | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/third_party/xla/xla/service/spmd/shardy/shardy_xla_pass.cc b/third_party/xla/xla/service/spmd/shardy/shardy_xla_pass.cc index 19ed5d9a95292e..4bd800ff179fb3 100644 --- a/third_party/xla/xla/service/spmd/shardy/shardy_xla_pass.cc +++ b/third_party/xla/xla/service/spmd/shardy/shardy_xla_pass.cc @@ -16,6 +16,7 @@ limitations under the License. #include "xla/service/spmd/shardy/shardy_xla_pass.h" #include +#include #include #include #include @@ -308,11 +309,22 @@ absl::StatusOr ShardyXLA::Run( /*flatten_computation_args_result=*/true)); std::string shardyDir = hloModule->config().debug_options().xla_dump_to(); + + if (shardyDir == "sponge") { + shardyDir = getenv("TEST_UNDECLARED_OUTPUTS_DIR"); + if (shardyDir.empty()) { + LOG(WARNING) << "\"sponge\" specified as dump directory but " + "TEST_UNDECLARED_OUTPUTS_DIR is not set!"; + } + } + if (!shardyDir.empty()) { shardyDir = tsl::io::JoinPath(shardyDir, "shardy", std::string_view(mlirModule->getName().value_or(""))); + LOG(INFO) << "Using Shardy output directory: " << shardyDir; } + // MLIR pipeline: (1) import, (2) Shardy, and (3) export. bool enableVerifier = false; From e221e79d12af396c5bf76c64596af80370356744 Mon Sep 17 00:00:00 2001 From: Luke Boyer Date: Thu, 26 Sep 2024 12:32:15 -0700 Subject: [PATCH 335/483] QNN plugin scaffolding PiperOrigin-RevId: 679251880 --- .../lite/experimental/lrt/c/lite_rt_common.h | 4 +- tensorflow/lite/experimental/lrt/qnn/BUILD | 121 ++++++++++ .../lite/experimental/lrt/qnn/load_sdk.cc | 77 +++++++ .../lite/experimental/lrt/qnn/load_sdk.h | 76 +++++++ tensorflow/lite/experimental/lrt/qnn/log.cc | 89 ++++++++ tensorflow/lite/experimental/lrt/qnn/log.h | 40 ++++ tensorflow/lite/experimental/lrt/qnn/lrt.lds | 18 ++ .../lrt/qnn/qnn_compiler_plugin.cc | 205 +++++++++++++++++ .../qnn/qnn_compiler_plugin_manual_test.cc | 113 ++++++++++ .../lite/experimental/lrt/qnn/qnn_manager.cc | 210 ++++++++++++++++++ .../lite/experimental/lrt/qnn/qnn_manager.h | 130 +++++++++++ .../experimental/lrt/test_data/one_mul.mlir | 6 + 12 files changed, 1088 insertions(+), 1 deletion(-) create mode 100644 tensorflow/lite/experimental/lrt/qnn/BUILD create mode 100644 tensorflow/lite/experimental/lrt/qnn/load_sdk.cc create mode 100644 tensorflow/lite/experimental/lrt/qnn/load_sdk.h create mode 100644 tensorflow/lite/experimental/lrt/qnn/log.cc create mode 100644 tensorflow/lite/experimental/lrt/qnn/log.h create mode 100644 tensorflow/lite/experimental/lrt/qnn/lrt.lds create mode 100644 tensorflow/lite/experimental/lrt/qnn/qnn_compiler_plugin.cc create mode 100644 tensorflow/lite/experimental/lrt/qnn/qnn_compiler_plugin_manual_test.cc create mode 100644 tensorflow/lite/experimental/lrt/qnn/qnn_manager.cc create mode 100644 tensorflow/lite/experimental/lrt/qnn/qnn_manager.h create mode 100644 tensorflow/lite/experimental/lrt/test_data/one_mul.mlir diff --git a/tensorflow/lite/experimental/lrt/c/lite_rt_common.h b/tensorflow/lite/experimental/lrt/c/lite_rt_common.h index 423f208e90d7e0..cab430bdedc9fc 100644 --- a/tensorflow/lite/experimental/lrt/c/lite_rt_common.h +++ b/tensorflow/lite/experimental/lrt/c/lite_rt_common.h @@ -31,6 +31,7 @@ LITE_RT_DEFINE_HANDLE(LrtStatus); typedef enum { kLrtStatusOk = 0, + // Generic errors. kLrtStatusErrorInvalidArgument = 1, kLrtStatusErrorMemoryAllocationFailure = 2, kLrtStatusErrorRuntimeFailure = 3, @@ -38,9 +39,10 @@ typedef enum { kLrtStatusErrorUnsupported = 5, kLrtStatusErrorNotFound = 6, - // File related errors. + // File and loading related errors. kLrtStatusBadFileOp = 500, kLrtStatusFlatbufferFailedVerify = 501, + kLrtStatusDynamicLoadErr = 502, // IR related errors. kLrtParamIndexOOB = 1000, diff --git a/tensorflow/lite/experimental/lrt/qnn/BUILD b/tensorflow/lite/experimental/lrt/qnn/BUILD new file mode 100644 index 00000000000000..4ad95b544e01c2 --- /dev/null +++ b/tensorflow/lite/experimental/lrt/qnn/BUILD @@ -0,0 +1,121 @@ +# Copyright 2024 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = ["//tensorflow/lite/experimental/lrt:__subpackages__"], +) + +# TODO: Put common copts/linkopts behind custom macro/rule. +QNN_SDK_LINKOPTS = [ + "-Wl,--dynamic-linker=/lib64/ld-linux-x86-64.so.2", + "-Wl,--version-script=$(location {})".format(":lrt.lds"), +] + +# NOTE: Currently the user needs to supply shared libraries from their own +# SDK. + +# copybara:uncomment QNN_SDK_LIB_HTP_LINUX = "//third_party/qairt:lib/x86_64-linux-clang/libQnnHtp.so" + +cc_library( + name = "qnn_sdk_libcpp", + srcs = [ + # copybara:uncomment "//third_party/qairt:lib/x86_64-linux-clang/libc++.so.1", + # copybara:uncomment "//third_party/qairt:lib/x86_64-linux-clang/libc++abi.so.1", + ], + # TODO: Replace no builder with constraint on only linux systems. + tags = ["nobuilder"], +) + +cc_library( + name = "load_sdk", + srcs = ["load_sdk.cc"], + hdrs = ["load_sdk.h"], + # copybara:uncomment data = [QNN_SDK_LIB_HTP_LINUX], + # copybara:uncomment defines = ["QNN_SDK_LIB_HTP=\\\"$(location {})\\\"".format(QNN_SDK_LIB_HTP_LINUX)], + linkopts = [ + "-ldl", + ], + tags = ["nobuilder"], + deps = [ + ":qnn_sdk_libcpp", # buildcleaner: keep + "@com_google_absl//absl/strings:string_view", + # copybara:uncomment "//third_party/qairt:qnn_lib_headers", + ], +) + +cc_library( + name = "qnn_manager", + srcs = ["qnn_manager.cc"], + hdrs = ["qnn_manager.h"], + tags = ["nobuilder"], + deps = [ + ":load_sdk", + ":log", + ":qnn_sdk_libcpp", # buildcleaner: keep + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + # copybara:uncomment "//third_party/qairt:qnn_lib_headers", + "//tensorflow/lite/experimental/lrt/core:api_internal", + ], +) + +cc_library( + name = "log", + srcs = ["log.cc"], + hdrs = ["log.h"], + tags = ["nobuilder"], + deps = [ + ":qnn_sdk_libcpp", # buildcleaner: keep + # copybara:uncomment "//third_party/qairt:qnn_lib_headers", + ], +) + +cc_library( + name = "qnn_compiler_plugin", + srcs = ["qnn_compiler_plugin.cc"], + hdrs = ["//tensorflow/lite/experimental/lrt/c:lite_rt_compiler_plugin.h"], + linkopts = QNN_SDK_LINKOPTS, + tags = ["nobuilder"], + deps = [ + ":lrt.lds", + ":qnn_manager", + ":qnn_sdk_libcpp", # buildcleaner: keep + "//tensorflow/lite/experimental/lrt/core:api_internal", + "//tensorflow/lite/experimental/lrt/core:graph_tools", + ], +) + +# NOTE: cc_test realistically will never link properly with the extra qnn libc++ needed. Need to find +# a long term solution for testing and packaging everything. +# TODO: In the short term we can wrap this in an "sh_test". +cc_binary( + name = "qnn_compiler_pugin_manual_test", + testonly = 1, + srcs = [ + "qnn_compiler_plugin_manual_test.cc", + ], + linkopts = QNN_SDK_LINKOPTS, + tags = ["nobuilder"], + deps = [ + ":lrt.lds", + ":qnn_compiler_plugin", # buildcleaner: keep + ":qnn_sdk_libcpp", # buildcleaner: keep + "//tensorflow/lite/experimental/lrt/core:api_internal", + "//tensorflow/lite/experimental/lrt/core:graph_tools", + "//tensorflow/lite/experimental/lrt/core:model", + "//tensorflow/lite/experimental/lrt/test_data:test_data_util", + "@com_google_absl//absl/log:absl_check", + ], +) diff --git a/tensorflow/lite/experimental/lrt/qnn/load_sdk.cc b/tensorflow/lite/experimental/lrt/qnn/load_sdk.cc new file mode 100644 index 00000000000000..d4c2381ddeca39 --- /dev/null +++ b/tensorflow/lite/experimental/lrt/qnn/load_sdk.cc @@ -0,0 +1,77 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tensorflow/lite/experimental/lrt/qnn/load_sdk.h" + +#include +#include + +#include +#include + +#include "absl/strings/string_view.h" + +namespace qnn::load { + +void DumpDlInfo(void* handle) { + std::cerr << "--- Dyn Load Info ---\n"; + + Lmid_t dl_ns_idx; + if (0 != ::dlinfo(handle, RTLD_DI_LMID, &dl_ns_idx)) { + return; + } + + std::string dl_origin; + dl_origin.resize(245); + if (0 != ::dlinfo(handle, RTLD_DI_ORIGIN, dl_origin.data())) { + return; + } + + link_map* lm; + if (0 != ::dlinfo(handle, RTLD_DI_LINKMAP, &lm)) { + return; + } + + std::cerr << "DL namespace: " << dl_ns_idx << "\n"; + std::cerr << "DL origin: " << dl_origin << "\n"; + + std::cerr << "loaded objects:\n"; + + auto* forward = lm->l_next; + auto* backward = lm->l_prev; + + while (forward != nullptr) { + std::cerr << " " << forward->l_name << "\n"; + forward = forward->l_next; + } + + std::cerr << "*** " << lm->l_name << "\n"; + + while (backward != nullptr) { + std::cerr << " " << backward->l_name << "\n"; + backward = backward->l_prev; + } +} + +void* LoadSO(absl::string_view so_path) { + void* lib_handle = + ::dlopen(so_path.data(), RTLD_NOW | RTLD_LOCAL | RTLD_DEEPBIND); + if (lib_handle == nullptr) { + std::cerr << "Failed to load so at path: " << so_path + << " with err: " << ::dlerror() << "\n"; + } + return lib_handle; +} + +} // namespace qnn::load diff --git a/tensorflow/lite/experimental/lrt/qnn/load_sdk.h b/tensorflow/lite/experimental/lrt/qnn/load_sdk.h new file mode 100644 index 00000000000000..b29c975f5fbda4 --- /dev/null +++ b/tensorflow/lite/experimental/lrt/qnn/load_sdk.h @@ -0,0 +1,76 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LRT_QNN_LOAD_SDK_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LRT_QNN_LOAD_SDK_H_ + +#include +#include + +#include + +#include "absl/strings/string_view.h" +#include "third_party/qairt/include/QNN/QnnCommon.h" +#include "third_party/qairt/include/QNN/QnnInterface.h" + +#ifndef QNN_SDK_LIB_HTP + +// If path not provided, check current directory. +constexpr absl::string_view kLibQnnHtpSo = "libQnnHtp.so"; +#else + +constexpr absl::string_view kLibQnnHtpSo = QNN_SDK_LIB_HTP; +#endif + +namespace qnn::load { + +// +// QNN Specific Data and Types +// + +// This is one of two qnns symbol that needs sym. It is used to populate +// pointers to related available qnn functions. +constexpr char kLibQnnGetProvidersSymbol[] = "QnnInterface_getProviders"; + +// Type definition for the QnnInterface_getProviders symbol. +typedef Qnn_ErrorHandle_t (*QnnInterfaceGetProvidersFn_t)( + const QnnInterface_t*** provider_list, uint32_t* num_providers); + +// +// Wrappers for Dynamic Linking +// + +// Loads (qnn) shared library at given path, returning handle on success +// and nullptr on failure. +void* LoadSO(absl::string_view so_path); + +// Dumps info relavant to dynamic loading of given loaded so handle. +void DumpDlInfo(void* lib_handle); + +// Resolves a named symbol from given loaded so handle of type SymbolT. Returns +// nullptr on failure. +template +inline static SymbolT ResolveQnnSymbol(void* lib_handle, + absl::string_view symbol) { + SymbolT ptr = (SymbolT)::dlsym(lib_handle, symbol.data()); + if (ptr == nullptr) { + std::cerr << "Failed to resolve symbol: " << symbol << " with err " + << ::dlerror() << "\n"; + } + return ptr; +} + +} // namespace qnn::load + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LRT_QNN_LOAD_SDK_H_ diff --git a/tensorflow/lite/experimental/lrt/qnn/log.cc b/tensorflow/lite/experimental/lrt/qnn/log.cc new file mode 100644 index 00000000000000..47cde541ee77d0 --- /dev/null +++ b/tensorflow/lite/experimental/lrt/qnn/log.cc @@ -0,0 +1,89 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tensorflow/lite/experimental/lrt/qnn/log.h" + +#include +#include +#include +#include + +#include "third_party/qairt/include/QNN/QnnInterface.h" +#include "third_party/qairt/include/QNN/QnnLog.h" + +namespace qnn::log { +namespace { + +// NOLINTBEGIN +constexpr char kProviderDumpTemplate[] = + "\ +PROVIDER LOADED\n\ +name: %s\n\ +backend_id: %u\n\ +core_api_version: %u.%u.%u\n\ +backend_api_version: %u.%u.%u\n\ +has_createLog_function: %d\n\ +has_createBackend_function: %d\n"; +// NOLINTEND + +void DefaultStdOutLogger(const char* fmt, QnnLog_Level_t level, + uint64_t timestamp, va_list argp) { + const char* levelStr = ""; + switch (level) { + case QNN_LOG_LEVEL_ERROR: + levelStr = " ERROR "; + break; + case QNN_LOG_LEVEL_WARN: + levelStr = "WARNING"; + break; + case QNN_LOG_LEVEL_INFO: + levelStr = " INFO "; + break; + case QNN_LOG_LEVEL_DEBUG: + levelStr = " DEBUG "; + break; + case QNN_LOG_LEVEL_VERBOSE: + levelStr = "VERBOSE"; + break; + case QNN_LOG_LEVEL_MAX: + levelStr = "UNKNOWN"; + break; + } + char buffer1[256]; + char buffer2[256]; + double ms = timestamp; + snprintf(buffer1, sizeof(buffer1), "%8.1fms [%-7s] ", ms, levelStr); + buffer1[sizeof(buffer1) - 1] = 0; + vsnprintf(buffer2, sizeof(buffer2), fmt, argp); + buffer2[sizeof(buffer1) - 2] = 0; + std::cout << buffer1 << buffer2; +} + +} // namespace + +void DumpInterface(const QnnInterface_t* interface) { + const auto core_version = interface->apiVersion.coreApiVersion; + const auto backend_version = interface->apiVersion.backendApiVersion; + + fprintf(stderr, kProviderDumpTemplate, interface->providerName, + interface->backendId, core_version.major, core_version.minor, + core_version.patch, backend_version.major, backend_version.minor, + backend_version.patch, + interface->QNN_INTERFACE_VER_NAME.logCreate != nullptr, + interface->QNN_INTERFACE_VER_NAME.backendCreate != nullptr); +} + +QnnLog_Callback_t GetDefaultStdOutLogger() { return DefaultStdOutLogger; } + +} // namespace qnn::log diff --git a/tensorflow/lite/experimental/lrt/qnn/log.h b/tensorflow/lite/experimental/lrt/qnn/log.h new file mode 100644 index 00000000000000..a0e5fe1c7efdd0 --- /dev/null +++ b/tensorflow/lite/experimental/lrt/qnn/log.h @@ -0,0 +1,40 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LRT_QNN_LOG_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LRT_QNN_LOG_H_ + +#include "third_party/qairt/include/QNN/QnnInterface.h" +#include "third_party/qairt/include/QNN/QnnLog.h" + +namespace qnn::log { + +// +// Standalone Dump/Log Funcitonality +// + +// Prints details about this interface. +void DumpInterface(const QnnInterface_t* interface); + +// +// QNN SDK Usage +// + +// Gets a default logger implementation to stdout. +// This is used when initializing qnn logging. +QnnLog_Callback_t GetDefaultStdOutLogger(); + +} // namespace qnn::log + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LRT_QNN_LOG_H_ diff --git a/tensorflow/lite/experimental/lrt/qnn/lrt.lds b/tensorflow/lite/experimental/lrt/qnn/lrt.lds new file mode 100644 index 00000000000000..caeb91dbecf751 --- /dev/null +++ b/tensorflow/lite/experimental/lrt/qnn/lrt.lds @@ -0,0 +1,18 @@ +VERS_1.0 { + global: + Lrt*; + GetModelNumSubgraphs*; + GetModelSubgraph*; + GetOpCode*; + GetOpInputs*; + GetTensorTypeId*; + GetRankedTensorType*; + GetOpOutputs*; + GetSubgraphOps*; + PushOp*; + GetSubgraphInputs*; + GetSubgraphOutputs*; + + local: + *; +}; \ No newline at end of file diff --git a/tensorflow/lite/experimental/lrt/qnn/qnn_compiler_plugin.cc b/tensorflow/lite/experimental/lrt/qnn/qnn_compiler_plugin.cc new file mode 100644 index 00000000000000..ae1fc768bd50e0 --- /dev/null +++ b/tensorflow/lite/experimental/lrt/qnn/qnn_compiler_plugin.cc @@ -0,0 +1,205 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include +#include +#include +#include +#include + +#include "tensorflow/lite/experimental/lrt/c/lite_rt_common.h" +#include "tensorflow/lite/experimental/lrt/c/lite_rt_compiler_plugin.h" +#include "tensorflow/lite/experimental/lrt/c/lite_rt_model.h" +#include "tensorflow/lite/experimental/lrt/c/lite_rt_op_code.h" +#include "tensorflow/lite/experimental/lrt/c/lite_rt_support.h" +#include "tensorflow/lite/experimental/lrt/cc/lite_rt_support.h" +#include "tensorflow/lite/experimental/lrt/core/graph_tools.h" +#include "tensorflow/lite/experimental/lrt/qnn/qnn_manager.h" + +using ::qnn::QnnManager; + +// +// Configurations +// + +constexpr char kPluginMan[] = "QNN"; +constexpr char kPluginModel[] = "HTP_Reference"; + +const char* LrtPluginSocManufacturer() { return kPluginMan; } + +lrt_param_index_t LrtPluginNumSupportedSocModels( + LrtCompilerPlugin compiler_plugin) { + return 1; +} + +LrtStatus LrtPluginGetSupportedSocModelId(LrtCompilerPlugin compiler_plugin, + lrt_param_index_t config_idx, + const char** config_id) { + if (config_idx != 0) { + return StatusCreate(kLrtStatusErrorUnsupported); + } + *config_id = kPluginModel; + return StatusOk(); +} + +// +// Compiled Result Definition +// + +struct LrtCompiledResultT { + std::vector context_bin; + std::vector graph_names; +}; + +LrtStatus LrtCompiledResultGetByteCode(LrtCompiledResult compiled_result, + const void** byte_code, + size_t* byte_code_size) { + *byte_code = compiled_result->context_bin.data(); + *byte_code_size = compiled_result->context_bin.size(); + return StatusOk(); +} + +LrtStatus LrtCompiledResultGetCallInfo(LrtCompiledResult compiled_result, + lrt_param_index_t call_idx, + const void** call_info, + size_t* call_info_size) { + if (call_idx >= compiled_result->graph_names.size()) { + return StatusCreate(kLrtParamIndexOOB); + } + + *call_info = compiled_result->graph_names.at(call_idx).data(); + *call_info_size = compiled_result->graph_names.at(call_idx).size(); + + return StatusOk(); +} + +LrtStatus LrtCompiledResultGetNumCalls(LrtCompiledResult compiled_result, + lrt_param_index_t* num_calls) { + *num_calls = compiled_result->graph_names.size(); + return StatusOk(); +} + +void LrtCompiledResultDestroy(LrtCompiledResult compiled_result) { + delete compiled_result; +} + +// +// Plugin Definition +// + +// Plugins can hold state. +struct LrtCompilerPluginT { + QnnManager qnn; +}; + +LrtStatus LrtPluginInit(LrtCompilerPlugin* compiler_plugin) { + auto* plugin = new LrtCompilerPluginT; + LRT_RETURN_STATUS_IF_NOT_OK(qnn::SetupAll(plugin->qnn)); + *compiler_plugin = plugin; + return StatusOk(); +} + +void LrtPluginDestroy(LrtCompilerPlugin compiler_plugin) { + delete compiler_plugin; +} + +bool IsOpSupported(LrtOp op) { + using TyInfo = graph_tools::RankedTypeInfo; + + // NOTE: Currently we are demoing by just mapping simple f32 mul ops. + // In the limit this function withh want to leverage QNN SDK's getSuportedOps + // feature (along with our op/type mappings). + + static const TyInfo supported_op_type = {kLrtElementTypeFloat32, {2, 2}}; + return graph_tools::MatchOpType(op, {supported_op_type, supported_op_type}, + {supported_op_type}, kLrtOpCodeTflMul); +} + +LrtStatus LrtPluginPartitionModel(LrtCompilerPlugin compiler_plugin, + LrtModel model, LrtOpList selected_ops) { + LRT_ASSIGN_OR_RETURN_STATUS(auto subgraph, graph_tools::GetSubgraph(model)); + LRT_ASSIGN_OR_RETURN_STATUS(auto ops, graph_tools::GetSubgraphOps(subgraph)); + + for (auto op : ops) { + if (!IsOpSupported(op)) { + continue; + } + + LRT_RETURN_STATUS_IF_NOT_OK(PushOp(selected_ops, op)); + } + + return StatusOk(); +} + +// Composes a QNN graph with the context inside qnn from subgraph. On success, +// will write the QNN graph name (entry point) to output param. +LrtStatus ComposeGraph(QnnManager& qnn, LrtSubgraph subgraph, + std::string& qnn_graph_name) { + // TODO: Implement this. + qnn_graph_name = "Unimplemented_QNN_Graph"; + return StatusOk(); +} + +LrtStatus LrtPluginCompile(LrtCompilerPlugin compiler_plugin, + LrtSubgraphArray partitions, + lrt_param_index_t num_partitions, + LrtCompiledResult* compiled_result) { + // NOTE: Currently we are demoing by just handling a simple case where + // there is one partitions and the partitions is as follows: + + // func(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) + // %0 = tfl.mul(%arg0, %arg1) + // return %0 + + if (num_partitions != 1) { + std::cerr << "Only 1 partition currently supported.\n"; + return StatusCreate(kLrtStatusErrorUnsupported); + } + auto subgraph = partitions[0]; + + LRT_ASSIGN_OR_RETURN_STATUS(auto inputs, + graph_tools::GetSubgraphInputs(subgraph)); + if (inputs.size() != 2) { + std::cerr << "Only 2 inputs currently supported\n"; + return StatusCreate(kLrtStatusErrorUnsupported); + } + + LRT_ASSIGN_OR_RETURN_STATUS(auto outputs, + graph_tools::GetSubgraphOutputs(subgraph)); + if (outputs.size() != 1) { + std::cerr << "Only 1 output currently supported\n"; + return StatusCreate(kLrtStatusErrorUnsupported); + } + + LRT_ASSIGN_OR_RETURN_STATUS(auto ops, graph_tools::GetSubgraphOps(subgraph)); + if (ops.size() != 1) { + std::cerr << "Only one op subgraphs supported\n"; + return StatusCreate(kLrtStatusErrorUnsupported); + } + + LrtCompiledResult result = new LrtCompiledResultT; + result->graph_names.reserve(num_partitions); + + LRT_RETURN_STATUS_IF_NOT_OK(ComposeGraph(compiler_plugin->qnn, subgraph, + result->graph_names.emplace_back())); + + LRT_RETURN_STATUS_IF_NOT_OK( + compiler_plugin->qnn.GenerateContextBin(result->context_bin)); + + *compiled_result = result; + + return StatusOk(); +} diff --git a/tensorflow/lite/experimental/lrt/qnn/qnn_compiler_plugin_manual_test.cc b/tensorflow/lite/experimental/lrt/qnn/qnn_compiler_plugin_manual_test.cc new file mode 100644 index 00000000000000..55efd978138525 --- /dev/null +++ b/tensorflow/lite/experimental/lrt/qnn/qnn_compiler_plugin_manual_test.cc @@ -0,0 +1,113 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "absl/log/absl_check.h" +#include "tensorflow/lite/experimental/lrt/c/lite_rt_compiler_plugin.h" +#include "tensorflow/lite/experimental/lrt/c/lite_rt_model.h" +#include "tensorflow/lite/experimental/lrt/cc/lite_rt_support.h" +#include "tensorflow/lite/experimental/lrt/core/graph_tools.h" +#include "tensorflow/lite/experimental/lrt/core/model.h" +#include "tensorflow/lite/experimental/lrt/test_data/test_data_util.h" + +typedef void (*TestFunc)(); + +namespace { + +UniqueLrtCompilerPlugin GetQnnPlugin() { + LrtCompilerPlugin qnn_plugin; + LRT_CHECK_STATUS_OK(LrtPluginInit(&qnn_plugin)); + ABSL_CHECK_NE(qnn_plugin, nullptr); + return UniqueLrtCompilerPlugin(qnn_plugin); +} + +void TestQnnPlugin_GetConfigInfo() { + ABSL_CHECK_STREQ(LrtPluginSocManufacturer(), "QNN"); + + auto plugin = GetQnnPlugin(); + + ABSL_CHECK_EQ(1, LrtPluginNumSupportedSocModels(plugin.get())); + + const char* config_id; + LRT_CHECK_STATUS_OK( + LrtPluginGetSupportedSocModelId(plugin.get(), 0, &config_id)); + ABSL_CHECK_STREQ(config_id, "HTP_Reference"); +} + +void TestQnnPluginPartition_PartitionMulOps() { + auto plugin = GetQnnPlugin(); + auto model = LoadTestFileModel("one_mul.tflite"); + + LrtOpListT selected_ops; + LRT_CHECK_STATUS_OK( + LrtPluginPartitionModel(plugin.get(), model.get(), &selected_ops)); + + ABSL_CHECK_EQ(selected_ops.ops.size(), 1); +} + +void TestQnnPluginCompile_CompileMulSubgraph() { + auto plugin = GetQnnPlugin(); + auto model = LoadTestFileModel("one_mul.tflite"); + + auto result = ::graph_tools::GetSubgraph(model.get()); + ABSL_CHECK(result.HasValue()); + auto subgraph = result.Value(); + + LrtCompiledResult compiled; + LRT_CHECK_STATUS_OK(LrtPluginCompile(plugin.get(), &subgraph, 1, &compiled)); + + const void* byte_code; + size_t byte_code_size; + + LRT_CHECK_STATUS_OK( + LrtCompiledResultGetByteCode(compiled, &byte_code, &byte_code_size)); + + std::string byte_code_string(reinterpret_cast(byte_code), + byte_code_size); + ABSL_CHECK(!byte_code_string.empty()); + + const void* op_data; + size_t op_data_size; + + LRT_CHECK_STATUS_OK( + LrtCompiledResultGetCallInfo(compiled, 0, &op_data, &op_data_size)); + + std::string op_data_string(reinterpret_cast(op_data), + op_data_size); + ABSL_CHECK_EQ("Unimplemented_QNN_Graph", op_data_string); + + LrtCompiledResultDestroy(compiled); +} + +} // namespace + +void ExecuteSuite() { + static const TestFunc suite[] = {TestQnnPlugin_GetConfigInfo, + TestQnnPluginPartition_PartitionMulOps, + TestQnnPluginCompile_CompileMulSubgraph}; + + std::cerr << "RUNNING SUITE\n"; + for (const auto& t : suite) { + t(); + } + std::cerr << "SUCCESS\n"; +} + +int main(int argc, char* argv[]) { + ExecuteSuite(); + return 0; +} diff --git a/tensorflow/lite/experimental/lrt/qnn/qnn_manager.cc b/tensorflow/lite/experimental/lrt/qnn/qnn_manager.cc new file mode 100644 index 00000000000000..5fcf2ff53ab2e4 --- /dev/null +++ b/tensorflow/lite/experimental/lrt/qnn/qnn_manager.cc @@ -0,0 +1,210 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tensorflow/lite/experimental/lrt/qnn/qnn_manager.h" + +#include +#include +#include + +#include "absl/types/span.h" +#include "third_party/qairt/include/QNN/QnnCommon.h" +#include "third_party/qairt/include/QNN/QnnInterface.h" +#include "third_party/qairt/include/QNN/QnnLog.h" +#include "third_party/qairt/include/QNN/QnnTypes.h" +#include "tensorflow/lite/experimental/lrt/c/lite_rt_common.h" +#include "tensorflow/lite/experimental/lrt/cc/lite_rt_support.h" +#include "tensorflow/lite/experimental/lrt/qnn/load_sdk.h" +#include "tensorflow/lite/experimental/lrt/qnn/log.h" + +using ::qnn::load::QnnInterfaceGetProvidersFn_t; + +namespace qnn { + +namespace { + +absl::Span LoadProvidersFromLib(void* lib_so) { + load::QnnInterfaceGetProvidersFn_t get_providers = nullptr; + get_providers = load::ResolveQnnSymbol( + lib_so, load::kLibQnnGetProvidersSymbol); + if (get_providers == nullptr) { + std::cerr << "Failed to resolve get providers symbol\n"; + return {}; + } + + const QnnInterface_t** interface_providers = nullptr; + uint32_t num_providers = 0; + if (QNN_SUCCESS != get_providers(&interface_providers, &num_providers)) { + std::cerr << "Failed to get providers\n"; + return {}; + } + + return absl::MakeSpan(interface_providers, num_providers); +} + +} // namespace + +LrtStatus QnnManager::LoadLibSO(absl::string_view path) { + lib_so_ = load::LoadSO(path); + if (lib_so_ == nullptr) { + return StatusCreate(kLrtStatusDynamicLoadErr); + } + return StatusOk(); +} + +void QnnManager::DumpLibSODetails() const { + if (lib_so_ == nullptr) { + return; + } + load::DumpDlInfo(lib_so_); +} + +// TODO: Repace QnnManager::Funcs with indirection access operator. +const QnnFunctionPointers* QnnManager::API() const { + if (interface_ == nullptr) { + return nullptr; + } + return &interface_->QNN_INTERFACE_VER_NAME; +} + +LrtStatus QnnManager::ResolveFuncs() { + if (lib_so_ == nullptr) { + std::cerr << "Cannot resolve functions: libQnn*.so has not been loaded.\n"; + return StatusCreate(kLrtStatusDynamicLoadErr); + } + + auto providers = LoadProvidersFromLib(lib_so_); + for (const auto& prov : providers) { + const bool major = + prov->apiVersion.coreApiVersion.major == QNN_API_VERSION_MAJOR; + + const bool minor = + prov->apiVersion.coreApiVersion.minor == QNN_API_VERSION_MINOR; + + const bool patch = + prov->apiVersion.coreApiVersion.patch == QNN_API_VERSION_PATCH; + + if (major && minor && patch) { + interface_ = prov; + break; + } + } + + if (interface_ == nullptr) { + std::cerr << "No valid interface was provided\n"; + return StatusCreate(kLrtStatusDynamicLoadErr); + } + + return StatusOk(); +} + +void QnnManager::DumpProviderDetails() const { + if (interface_ == nullptr) { + return; + } + log::DumpInterface(interface_); +} + +LrtStatus QnnManager::FreeLogging() { + if (log_handle_ != nullptr) { + if (QNN_SUCCESS != API()->logFree(log_handle_)) { + std::cerr << "Failed to free logging\n"; + return StatusCreate(kLrtStatusErrorNotFound); + } + } + log_handle_ = nullptr; + return StatusOk(); +} + +LrtStatus QnnManager::FreeBackend() { + if (backend_handle_ != nullptr) { + if (QNN_SUCCESS != API()->backendFree(backend_handle_)) { + std::cerr << "Failed to free backend\n"; + return StatusCreate(kLrtStatusErrorNotFound); + } + } + backend_handle_ = nullptr; + return StatusOk(); +} + +LrtStatus QnnManager::FreeContext() { + if (context_handle_ != nullptr) { + if (QNN_SUCCESS != API()->contextFree(context_handle_, nullptr)) { + std::cerr << "Failed to free context\n"; + return StatusCreate(kLrtStatusErrorNotFound); + } + } + context_handle_ = nullptr; + return StatusOk(); +} + +LrtStatus QnnManager::GenerateContextBin(std::vector& buffer) { + Qnn_ContextBinarySize_t bin_size = 0; + if (QNN_SUCCESS != API()->contextGetBinarySize(ContextHandle(), &bin_size)) { + std::cerr << "Failed to get context bin size\n"; + return StatusCreate(kLrtStatusErrorNotFound); + } + buffer.clear(); + buffer.resize(bin_size); + + Qnn_ContextBinarySize_t written_bin_size = 0; + if (QNN_SUCCESS != API()->contextGetBinary(ContextHandle(), buffer.data(), + buffer.size(), + &written_bin_size)) { + std::cerr << "Failed to generated context binary \n"; + return StatusCreate(kLrtStatusErrorNotFound); + } + + std::cerr << "Serialized a context bin of size (bytes): " << written_bin_size + << "\n"; + + return StatusOk(); +} + +LrtStatus SetupAll(QnnManager& qnn) { + LRT_RETURN_STATUS_IF_NOT_OK(qnn.LoadLibSO(kLibQnnHtpSo)); + qnn.DumpLibSODetails(); + + LRT_RETURN_STATUS_IF_NOT_OK(qnn.ResolveFuncs()); + qnn.DumpProviderDetails(); + + { + if (QNN_SUCCESS != qnn.API()->logCreate(qnn::log::GetDefaultStdOutLogger(), + QNN_LOG_LEVEL_DEBUG, + &qnn.LogHandle())) { + return StatusCreate(kLrtStatusErrorNotFound); + } + } + + { + auto cfg = qnn::config::GetDefaultHtpConfigs(); + if (QNN_SUCCESS != qnn.API()->backendCreate(qnn.LogHandle(), cfg.data(), + &qnn.BackendHandle())) { + return StatusCreate(kLrtStatusErrorNotFound); + } + } + + { + auto cfg = qnn::config::GetDefaultContextConfigs(); + auto device = nullptr; + if (QNN_SUCCESS != qnn.API()->contextCreate(qnn.BackendHandle(), device, + cfg.data(), + &qnn.ContextHandle())) { + return StatusCreate(kLrtStatusErrorNotFound); + } + } + return StatusOk(); +} + +}; // namespace qnn diff --git a/tensorflow/lite/experimental/lrt/qnn/qnn_manager.h b/tensorflow/lite/experimental/lrt/qnn/qnn_manager.h new file mode 100644 index 00000000000000..83acaecffad044 --- /dev/null +++ b/tensorflow/lite/experimental/lrt/qnn/qnn_manager.h @@ -0,0 +1,130 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LRT_QNN_QNN_MANAGER_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LRT_QNN_QNN_MANAGER_H_ + +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "third_party/qairt/include/QNN/QnnBackend.h" +#include "third_party/qairt/include/QNN/QnnCommon.h" +#include "third_party/qairt/include/QNN/QnnInterface.h" +#include "tensorflow/lite/experimental/lrt/c/lite_rt_common.h" + +namespace qnn { + +typedef QNN_INTERFACE_VER_TYPE QnnFunctionPointers; + +// Wrapper to manage dynamic loading and lifetimes of QNN SDK objects. +class QnnManager { + public: + explicit QnnManager() = default; + + // + // Manage libQnn*.so Loading + // + + // Loads the libQnn*.so at given path. + LrtStatus LoadLibSO(absl::string_view path); + + // Dumps dynamic loading info about the loaded libQnn*.so. Does + // nothing if it has not been loaded yet. + void DumpLibSODetails() const; + + // + // Resolve and Access QNN SDK Functions + // + + // Resolve all available QNN SDK functions from (already) loaded so. If + // multiple providers are found, selects the first one with a suitable + // version. Fails if none can be found. + LrtStatus ResolveFuncs(); + + // Get resolved function pointers for qnn sdk calls. Nullptr if functions + // have not been resolved yet. + const QnnFunctionPointers* API() const; + + // Dumps information relevant to the loaded api provider. Does nothing if + // a successful ResolveFuncs hasn't occurred. + void DumpProviderDetails() const; + + // + // QNN SDK Objects. + // + + // Get qnn log handle. Nullptr if logCreate has not been successfully called. + Qnn_LogHandle_t& LogHandle() { return log_handle_; } + + // Signal QNN SDK to free any memory related to logging. Does nothing + // if logCreate has not been called. + LrtStatus FreeLogging(); + + // Get qnn backend handle. Nullptr if backendCreate has not been successfully + // called. + Qnn_BackendHandle_t& BackendHandle() { return backend_handle_; } + + // Signal QNN SDK to free any memory related to backend. Does nothing + // if backendCreate has not been called. + LrtStatus FreeBackend(); + + // Get qnn context handle. Nullptr if contextCreate has not been successfully + // called. + Qnn_ContextHandle_t& ContextHandle() { return context_handle_; } + + // Signal QNN SDK to free any memory related to context. Does nothing + // if contextCreate has not been called. + LrtStatus FreeContext(); + + // + // Context Binary + // + + // Generates QNN context binary from current context. Writes to given + // buffer. + LrtStatus GenerateContextBin(std::vector& buffer); + + private: + void* lib_so_ = nullptr; + + const QnnInterface_t* interface_ = nullptr; + + Qnn_LogHandle_t log_handle_ = nullptr; + + Qnn_BackendHandle_t backend_handle_ = nullptr; + + Qnn_ContextHandle_t context_handle_ = nullptr; +}; + +// Runs alls "setup" methods (LoadLibSO, ResolveFuncs) and aditionally +// instantiates the logging, backend and context. +LrtStatus SetupAll(QnnManager& qnn); + +// Default QNN Configurations. +namespace config { + +inline absl::Span GetDefaultHtpConfigs() { + static const QnnBackend_Config_t* configs[] = {nullptr}; + return absl::MakeSpan(configs); +} + +inline absl::Span GetDefaultContextConfigs() { + static const QnnContext_Config_t* configs[] = {nullptr}; + return absl::MakeSpan(configs); +} + +} // namespace config + +} // namespace qnn + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LRT_QNN_QNN_MANAGER_H_ diff --git a/tensorflow/lite/experimental/lrt/test_data/one_mul.mlir b/tensorflow/lite/experimental/lrt/test_data/one_mul.mlir new file mode 100644 index 00000000000000..afabf1903ee846 --- /dev/null +++ b/tensorflow/lite/experimental/lrt/test_data/one_mul.mlir @@ -0,0 +1,6 @@ +module { +func.func @main(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> { + %0 = tfl.mul %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<2x2xf32> + return %0 : tensor<2x2xf32> +} +} \ No newline at end of file From 701edc872f622506d09c1952737a98f392d0c784 Mon Sep 17 00:00:00 2001 From: Gunhyun Park Date: Thu, 26 Sep 2024 12:46:15 -0700 Subject: [PATCH 336/483] Integrate StableHLO at openxla/stablehlo@9d9290dc PiperOrigin-RevId: 679256562 --- third_party/stablehlo/workspace.bzl | 4 ++-- third_party/xla/third_party/stablehlo/workspace.bzl | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/third_party/stablehlo/workspace.bzl b/third_party/stablehlo/workspace.bzl index 2be9b211c51514..0bd9fb077ccc9a 100644 --- a/third_party/stablehlo/workspace.bzl +++ b/third_party/stablehlo/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): # LINT.IfChange - STABLEHLO_COMMIT = "ca13d31b5ed0b2053dde0a624480ad765e219ebf" - STABLEHLO_SHA256 = "123462093f087f2576bb6a6cc471370eed2d43c291f881ff359fd4ca812003db" + STABLEHLO_COMMIT = "9d9290dc2308c1850cea69ea05f8c94017e484ee" + STABLEHLO_SHA256 = "29803fc8a3a96f9e5469c7ab51f2ff4292dc2419c17bd0466f5d15a448cf6815" # LINT.ThenChange(Google-internal path) tf_http_archive( diff --git a/third_party/xla/third_party/stablehlo/workspace.bzl b/third_party/xla/third_party/stablehlo/workspace.bzl index 2be9b211c51514..0bd9fb077ccc9a 100644 --- a/third_party/xla/third_party/stablehlo/workspace.bzl +++ b/third_party/xla/third_party/stablehlo/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): # LINT.IfChange - STABLEHLO_COMMIT = "ca13d31b5ed0b2053dde0a624480ad765e219ebf" - STABLEHLO_SHA256 = "123462093f087f2576bb6a6cc471370eed2d43c291f881ff359fd4ca812003db" + STABLEHLO_COMMIT = "9d9290dc2308c1850cea69ea05f8c94017e484ee" + STABLEHLO_SHA256 = "29803fc8a3a96f9e5469c7ab51f2ff4292dc2419c17bd0466f5d15a448cf6815" # LINT.ThenChange(Google-internal path) tf_http_archive( From eb4d79df817e98c94a8b615565073d3868642885 Mon Sep 17 00:00:00 2001 From: Shaogang Wang Date: Thu, 26 Sep 2024 12:58:06 -0700 Subject: [PATCH 337/483] PR #17631: Add nccl AllToAllThunk support to command buffer Imported from GitHub PR https://github.com/openxla/xla/pull/17631 Copybara import of the project: -- 1380f8e22793ef21ab530e16cb65f869a925b7a4 by Shawn Wang : add command buffer alltoall support -- ddfab8f3ca025d2df9cbd1e3e543f811ed965e0b by Shawn Wang : add a space to vlog message -- 5ddb5f539d0c265ba5a1a6a6fb273a29d0ae7580 by Shawn Wang : clean code format -- 7b1917988a568afdb8feadfdf233f96f4a812be4 by Shawn Wang : sequentail thunk -- 6fe43dfbeeef87758569a0c59e26a4f1119a9a0e by Shawn Wang : add NcclAllToAllDone thunk support Merging this change closes #17631 PiperOrigin-RevId: 679260969 --- third_party/xla/xla/service/gpu/runtime/BUILD | 2 + .../service/gpu/runtime/command_buffer_cmd.cc | 71 +++++++++++++++++++ .../service/gpu/runtime/command_buffer_cmd.h | 27 +++++++ .../gpu/runtime/command_buffer_cmd_emitter.cc | 11 +++ .../gpu/runtime/nccl_all_to_all_thunk.h | 5 +- .../transforms/command_buffer_scheduling.cc | 9 ++- .../command_buffer_scheduling_test.cc | 29 ++++++++ 7 files changed, 150 insertions(+), 4 deletions(-) diff --git a/third_party/xla/xla/service/gpu/runtime/BUILD b/third_party/xla/xla/service/gpu/runtime/BUILD index 9427f585206df4..336df6502005bd 100644 --- a/third_party/xla/xla/service/gpu/runtime/BUILD +++ b/third_party/xla/xla/service/gpu/runtime/BUILD @@ -58,6 +58,7 @@ cc_library( ":custom_call_thunk", ":nccl_all_gather_thunk", ":nccl_all_reduce_thunk", + ":nccl_all_to_all_thunk", ":nccl_api", ":nccl_clique_key", ":nccl_collective_broadcast_thunk", @@ -126,6 +127,7 @@ cc_library( ":memset_thunk", ":nccl_all_gather_thunk", ":nccl_all_reduce_thunk", + ":nccl_all_to_all_thunk", ":nccl_collective_thunk", ":replica_id_thunk", ":sequential_thunk", diff --git a/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd.cc b/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd.cc index abfa6eb465cd99..3bdea089ff425a 100644 --- a/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd.cc +++ b/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd.cc @@ -53,6 +53,7 @@ limitations under the License. #include "xla/service/gpu/runtime/annotation.h" #include "xla/service/gpu/runtime/nccl_all_gather_thunk.h" #include "xla/service/gpu/runtime/nccl_all_reduce_thunk.h" +#include "xla/service/gpu/runtime/nccl_all_to_all_thunk.h" #include "xla/service/gpu/runtime/nccl_api.h" #include "xla/service/gpu/runtime/nccl_clique_key.h" #include "xla/service/gpu/runtime/nccl_collective_broadcast_thunk.h" @@ -1780,6 +1781,76 @@ CommandBufferCmd::BufferUsageVector ReduceScatterCmd::buffers() { return buffer_usage; } +//===----------------------------------------------------------------------===// +// AllToAllCmd +//===----------------------------------------------------------------------===// + +AllToAllCmd::AllToAllCmd(ExecutionStreamId execution_stream_id, + ExecutionStreamId async_from_stream_id, + NcclApi* nccl_api, NcclCollectiveConfig config, + bool has_split_dimension, + absl::Span buffers) + : CollectiveCmd(CommandBufferCmdType::kAllToAll, execution_stream_id, + async_from_stream_id, nccl_api, std::move(config)), + has_split_dimension_(has_split_dimension), + buffers_(buffers.begin(), buffers.end()) {} + +absl::Status AllToAllCmd::Record(const Thunk::ExecuteParams& execute_params, + const RecordParams& record_params, + se::CommandBuffer* command_buffer) { + TF_RETURN_IF_ERROR(BarrierIfAsync( + command_buffer, execute_params.stream->parent(), record_params)); + + TF_ASSIGN_OR_RETURN( + std::vector device_buffers, + ConvertToDeviceBuffers(execute_params.buffer_allocations, buffers_, + config().operand_element_type)); + + ExecutionScopeId execution_scope_id = GetExecutionScope(record_params); + VLOG(5) << "AllToAllCmd, has_split_dimension=" << has_split_dimension_ + << ", execution_scope_id=" << execution_scope_id.value(); + + for (size_t i = 0; i < device_buffers.size(); ++i) { + VLOG(5) << " Src: " << buffers_[i].source_buffer << " (" + << device_buffers[i].source_buffer.opaque() << ")"; + VLOG(5) << " Dst: " << buffers_[i].destination_buffer << " (" + << device_buffers[i].destination_buffer.opaque() << ")"; + } + + if (!execute_params.collective_params || !execute_params.collective_cliques) { + return absl::InvalidArgumentError( + "ReduceScatterCmd requires collective parameters and cliques"); + } + + TF_ASSIGN_OR_RETURN( + NcclCommHandleWrapper comm_handle, + GetNcclComm(*execute_params.collective_params, + *execute_params.collective_cliques, config().replica_groups, + config().group_mode, nccl_stream_id(), GetAsyncStreamKind())); + NcclApi::NcclCommHandle comm = comm_handle.comm_handle; + // Use custom allocator for persistent execution plans. + NcclApi::ScopedPersistentPlanAllocator scoped_allocator( + comm, tsl::MakeRef( + execute_params.buffer_allocations->device_ordinal(), + execute_params.buffer_allocations->memory_allocator(), + execute_params.stream)); + + return AddTracedCommandBuffer( + execute_params, record_params, command_buffer, [&](se::Stream* stream) { + return RunAllToAll(nccl_api(), has_split_dimension_, device_buffers, + *stream, comm); + }); +} + +CommandBufferCmd::BufferUsageVector AllToAllCmd::buffers() { + BufferUsageVector buffer_usage; + for (auto& buffer : buffers_) { + buffer_usage.emplace_back(buffer.source_buffer, MemoryAccess::kRead); + buffer_usage.emplace_back(buffer.destination_buffer, MemoryAccess::kWrite); + } + return buffer_usage; +} + //===----------------------------------------------------------------------===// // AllGatherCmd //===----------------------------------------------------------------------===// diff --git a/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd.h b/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd.h index 27e8fea0d86366..04b34e991029f2 100644 --- a/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd.h +++ b/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd.h @@ -78,6 +78,7 @@ namespace xla::gpu { V(kCollectiveCmd, "CollectiveCmd") \ V(kAllReduceCmd, "AllReduceCmd") \ V(kReduceScatter, "ReduceScatterCmd") \ + V(kAllToAll, "AllToAllCmd") \ V(kAllGatherCmd, "AllGatherCmd") \ V(kCollectiveBroadcastCmd, "CollectiveBroadcastCmd") \ V(kUnknownCmd, "UnknownCmd") \ @@ -1073,6 +1074,32 @@ class ReduceScatterCmd : public CollectiveCmd { std::vector buffers_; }; +//===----------------------------------------------------------------------===// +// AllToAllCmd +//===----------------------------------------------------------------------===// + +class AllToAllCmd : public CollectiveCmd { + public: + AllToAllCmd(ExecutionStreamId execution_stream_id, + ExecutionStreamId async_from_stream_id, NcclApi* nccl_api, + NcclCollectiveConfig config, bool has_split_dimension, + absl::Span buffers); + + absl::Status Record(const Thunk::ExecuteParams& execute_params, + const RecordParams& record_params, + se::CommandBuffer* command_buffer) override; + + BufferUsageVector buffers() override; + + AsyncStreamKind GetAsyncStreamKind() override { + return AsyncStreamKind::kCollective; + }; + + private: + bool has_split_dimension_; + std::vector buffers_; +}; + //===----------------------------------------------------------------------===// // AllGatherCmd //===----------------------------------------------------------------------===// diff --git a/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd_emitter.cc b/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd_emitter.cc index cb3801ef13c1da..d7e7b45fb38d98 100644 --- a/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd_emitter.cc +++ b/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd_emitter.cc @@ -35,6 +35,7 @@ limitations under the License. #include "xla/service/gpu/runtime/memset_thunk.h" #include "xla/service/gpu/runtime/nccl_all_gather_thunk.h" #include "xla/service/gpu/runtime/nccl_all_reduce_thunk.h" +#include "xla/service/gpu/runtime/nccl_all_to_all_thunk.h" #include "xla/service/gpu/runtime/nccl_collective_thunk.h" #include "xla/service/gpu/runtime/replica_id_thunk.h" #include "xla/service/gpu/runtime/sequential_thunk.h" @@ -173,6 +174,13 @@ static absl::StatusOr Convert( thunk.buffers()); } +static absl::StatusOr Convert(const NcclAllToAllStartThunk& thunk) { + return std::make_unique( + thunk.nccl_execution_stream_id(), thunk.execution_stream_id(), + thunk.nccl_api(), thunk.config(), thunk.has_split_dimension(), + thunk.buffers()); +} + static absl::StatusOr Convert(const NcclAllGatherStartThunk& thunk) { return std::make_unique( thunk.nccl_execution_stream_id(), thunk.execution_stream_id(), @@ -270,6 +278,8 @@ static absl::Status AppendCommands( return append(Convert(thunk)); case Thunk::Kind::kNcclReduceScatterStart: return append(Convert(thunk)); + case Thunk::Kind::kNcclAllToAllStart: + return append(Convert(thunk)); case Thunk::Kind::kPartitionId: return append(Convert(thunk)); case Thunk::Kind::kReplicaId: @@ -289,6 +299,7 @@ static absl::Status AppendCommands( case Thunk::Kind::kNcclAllGatherDone: case Thunk::Kind::kNcclAllReduceDone: case Thunk::Kind::kNcclReduceScatterDone: + case Thunk::Kind::kNcclAllToAllDone: return append(Convert(thunk)); case Thunk::Kind::kWaitForStreams: diff --git a/third_party/xla/xla/service/gpu/runtime/nccl_all_to_all_thunk.h b/third_party/xla/xla/service/gpu/runtime/nccl_all_to_all_thunk.h index ed3056ec646789..e7bb099fb7419d 100644 --- a/third_party/xla/xla/service/gpu/runtime/nccl_all_to_all_thunk.h +++ b/third_party/xla/xla/service/gpu/runtime/nccl_all_to_all_thunk.h @@ -59,8 +59,11 @@ class NcclAllToAllStartThunk : public NcclCollectiveThunk { static CollectiveOpGroupMode GetGroupMode( const HloAllToAllInstruction* instr); - protected: const NcclCollectiveConfig& config() const override { return config_.config; } + bool has_split_dimension() const { return config_.has_split_dimension; } + absl::Span buffers() const { return buffers_; } + + protected: absl::Status RunNcclCollective(const ExecuteParams& params, se::Stream& stream, NcclCommHandleWrapper comm_wrapper) override; diff --git a/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling.cc b/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling.cc index 641a37a9659d30..79e77022b0a6ff 100644 --- a/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling.cc +++ b/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling.cc @@ -112,12 +112,14 @@ static bool IsAsyncStartCommand(const HloInstruction* hlo, if (hlo->async_wrapped_opcode() == HloOpcode::kFusion) { return config.enabled_commands.contains(DebugOptions::FUSION); } - if (hlo->async_wrapped_opcode() == HloOpcode::kReduceScatter) { + if (hlo->async_wrapped_opcode() == HloOpcode::kReduceScatter || + hlo->async_wrapped_opcode() == HloOpcode::kAllToAll) { return config.enabled_commands.contains(DebugOptions::COLLECTIVES); } } - if (hlo->opcode() == HloOpcode::kReduceScatter) { + if (hlo->opcode() == HloOpcode::kReduceScatter || + hlo->opcode() == HloOpcode::kAllToAll) { return config.enabled_commands.contains(DebugOptions::COLLECTIVES); } @@ -138,7 +140,8 @@ static bool IsAsyncDoneCommand(const HloInstruction* hlo, if (hlo->async_wrapped_opcode() == HloOpcode::kFusion) { return config.enabled_commands.contains(DebugOptions::FUSION); } - if (hlo->async_wrapped_opcode() == HloOpcode::kReduceScatter) { + if (hlo->async_wrapped_opcode() == HloOpcode::kReduceScatter || + hlo->async_wrapped_opcode() == HloOpcode::kAllToAll) { return config.enabled_commands.contains(DebugOptions::COLLECTIVES); } } diff --git a/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling_test.cc b/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling_test.cc index be29a09897b54b..383a4f7e479f1d 100644 --- a/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling_test.cc @@ -1072,6 +1072,35 @@ TEST_F(CommandBufferSchedulingTest, AsyncFusion) { }); } +TEST_F(CommandBufferSchedulingTest, AsyncAlltoAll) { + const char* hlo = R"( + HloModule m, is_scheduled=true + + async_computation.1 { + param.1 = f32[4,8,128]{2,1,0} parameter(0) + ROOT all-to-all.1 = f32[4,8,128]{2,1,0} all-to-all(param.1), channel_id=1, dimensions={1} + } + + ENTRY main { + param.0 = f32[4,8,128]{2,1,0} parameter(0) + all-to-all-start = ((f32[4,8,128]{2,1,0}), f32[4,8,128]{2,1,0}) async-start(param.0), calls=async_computation.1 + ROOT all-to-all-done = f32[4,8,128]{2,1,0} async-done(all-to-all-start) + })"; + + const char* expected = R"( + CHECK: %command_buffer ([[P:.+]]: f32[4,8,128]) -> f32[4,8,128] { + CHECK: %[[P]] = f32[4,8,128]{2,1,0} parameter(0) + CHECK: %[[S1:.+]] = ((f32[4,8,128]{2,1,0}), f32[4,8,128]{2,1,0}) all-to-all-start(%[[P]]), channel_id=1, replica_groups={}, dimensions={1} + CHECK: ROOT {{.*}} = f32[4,8,128]{2,1,0} all-to-all-done(%[[S1]]) + CHECK: })"; + + RunAndFilecheckHloRewrite(hlo, CommandBufferScheduling(device_desc()), + expected, [](HloModule* module) { + EXPECT_TRUE(module->has_schedule()); + TF_CHECK_OK(module->schedule().Verify()); + }); +} + TEST_F(CommandBufferSchedulingTest, DynamicSliceFusionDynamicSlicing) { if (backend().platform()->Name() == "Host") { GTEST_SKIP() << "GPU support required for this test"; From d736421e0465b072962d1c309822c3f7deaccd48 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 26 Sep 2024 13:41:10 -0700 Subject: [PATCH 338/483] Add TFLite flatbuffer debug metadata deserialization logic. PiperOrigin-RevId: 679277364 --- tensorflow/compiler/mlir/lite/BUILD | 6 +- .../compiler/mlir/lite/flatbuffer_import.cc | 258 ++++++++++++++++-- .../tests/flatbuffer2mlir/debug_metadata.mlir | 36 +++ 3 files changed, 282 insertions(+), 18 deletions(-) create mode 100644 tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/debug_metadata.mlir diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD index 7bb70a19f4f116..2c26cf3cc5f166 100644 --- a/tensorflow/compiler/mlir/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/BUILD @@ -1373,15 +1373,16 @@ cc_library( ], deps = [ ":const_tensor_utils", + ":control_edges", ":convert_type", ":flatbuffer_tflite_operator_lib", ":offset_buffer", ":size_utils", ":tensorflow_lite", - "//tensorflow/compiler/mlir/lite:control_edges", "//tensorflow/compiler/mlir/lite/core:absl_error_model_builder", "//tensorflow/compiler/mlir/lite/experimental/remat:metadata_util", "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps", + "//tensorflow/compiler/mlir/lite/schema:debug_metadata_fbs_with_mutable", "//tensorflow/compiler/mlir/lite/schema:schema_fbs", "//tensorflow/compiler/mlir/lite/schema:schema_fbs_with_mutable", "//tensorflow/compiler/mlir/lite/schema:schema_utils", @@ -1398,7 +1399,9 @@ cc_library( "//tensorflow/core/platform:errors", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@llvm-project//llvm:Analysis", "@llvm-project//llvm:Support", @@ -1411,6 +1414,7 @@ cc_library( "@llvm-project//mlir:ReconcileUnrealizedCasts", "@llvm-project//mlir:Support", "@llvm-project//mlir:TranslateLib", + "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:statusor", "@stablehlo//:stablehlo_ops", diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc index a289126d26b6ca..97dfa6ce44a7ec 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc @@ -29,7 +29,9 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/log/log.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "llvm/ADT/APFloat.h" @@ -80,6 +82,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/offset_buffer.h" #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" +#include "tensorflow/compiler/mlir/lite/schema/mutable/debug_metadata_generated.h" #include "tensorflow/compiler/mlir/lite/schema/mutable/schema_generated.h" #include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" #include "tensorflow/compiler/mlir/lite/schema/schema_utils.h" @@ -99,6 +102,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/platform/errors.h" +#include "tsl/platform/errors.h" #include "tsl/platform/status.h" #include "tsl/platform/statusor.h" @@ -130,6 +134,17 @@ using ::mlir::tf_saved_model::kTfSavedModelExportedNamesAttr; using ::mlir::tf_saved_model::kTfSavedModelIndexPathAttr; using ::tflite::IsValidBufferOffset; +struct DebugMetadata { + // Debug metadata locations. + std::vector debug_metadata_locations; + + // Maps from operator (subgraph_debug_metadata_idx, + // operator_debug_metadata_idx) to its top-level location index in + // `debug_metadata_locations`, which is: + // <, location_idx>. + absl::flat_hash_map> operator_location_map; +}; + // Create the MLIR NamedLoc location corresponding to a given tensor Location TensorLoc(const TensorT& tensor, Builder builder, Location base) { if (tensor.name.empty()) { @@ -138,27 +153,223 @@ Location TensorLoc(const TensorT& tensor, Builder builder, Location base) { return mlir::NameLoc::get(builder.getStringAttr(tensor.name), base); } -// Create the MLIR Location corresponding to a given op. This is an -// experimental/debugging feature and production code should not rely on names -// of intermediate tensors since importer doesn't guarantee to preserve tensor -// names except output tensors. -Location OpLoc(const OperatorT& op, - const std::vector>& tensors, - Builder builder, Location base) { +// Build and return the MLIR location. +StatusOr BuildLocation( + Builder builder, const debug_metadata::Location& location, + const std::vector& debug_metadata_locations, + const absl::flat_hash_map& + attribute_location_idx_map) { + switch (location.location_type()) { + // FileLineColLoc. + case debug_metadata::LocationType_FileLineColLoc: { + auto file_line_col_loc = + static_cast( + location.location()); + return mlir::FileLineColLoc::get( + builder.getContext(), + builder.getStringAttr(file_line_col_loc->filename()->string_view()), + file_line_col_loc->line(), file_line_col_loc->column()); + } + // CallSiteLoc. + case debug_metadata::LocationType_CallSiteLoc: { + auto callsite_loc = + static_cast(location.location()); + if (!attribute_location_idx_map.contains(callsite_loc->callee_index()) || + !attribute_location_idx_map.contains(callsite_loc->caller_index())) { + return absl::InternalError( + "Invalid/corrupt DebugMetadata, expected invariant broken (callee " + "or caller index of a CallSiteLoc is not valid)"); + } + return mlir::CallSiteLoc::get( + debug_metadata_locations[attribute_location_idx_map.at( + callsite_loc->callee_index())], + debug_metadata_locations[attribute_location_idx_map.at( + callsite_loc->caller_index())]); + } + // NameLoc. + case debug_metadata::LocationType_NameLoc: { + auto name_loc = + static_cast(location.location()); + if (!attribute_location_idx_map.contains(name_loc->child_index())) { + return absl::InternalError( + "Invalid/corrupt DebugMetadata, expected invariant broken (child " + "index of a NameLoc is not valid)"); + } + return mlir::NameLoc::get( + builder.getStringAttr(name_loc->name()->string_view()), + debug_metadata_locations[attribute_location_idx_map.at( + name_loc->child_index())]); + } + // FusedLoc. + case debug_metadata::LocationType_FusedLoc: { + auto fused_loc = + static_cast(location.location()); + auto fused_location_indexes = fused_loc->location_indexes(); + std::vector fused_locations; + fused_locations.reserve(fused_location_indexes->size()); + for (int fused_loc_idx = 0; + fused_loc_idx < fused_location_indexes->size(); ++fused_loc_idx) { + if (!attribute_location_idx_map.contains( + fused_location_indexes->Get(fused_loc_idx))) { + return absl::InternalError( + "Invalid/corrupt DebugMetadata, expected invariant broken " + "(location index of a FusedLoc is not valid)"); + } + fused_locations.push_back( + debug_metadata_locations[attribute_location_idx_map.at( + fused_location_indexes->Get(fused_loc_idx))]); + } + return mlir::FusedLoc::get( + fused_locations, mlir::StringAttr::get(builder.getContext(), ""), + builder.getContext()); + } + default: { + return mlir::UnknownLoc::get(builder.getContext()); + } + } +} + +// Parses all locations in ConversionDebugMetadata, build the mlir::location +// counterparts, and put them inside debug_metadata_. Additionally, maintain a +// map that maps the top location index of each operator. +Status ParseAndBuildLocation( + Builder builder, + const debug_metadata::ConversionDebugMetadata* conversion_debug_metadata, + DebugMetadata& debug_metadata_var) { + auto attribute_types = conversion_debug_metadata->attributes_type(); + auto attributes = conversion_debug_metadata->attributes(); + + auto& debug_metadata_locations = debug_metadata_var.debug_metadata_locations; + debug_metadata_locations.reserve(attribute_types->size()); + + // Map index in the attribute_vector to the index in the data structure we + // are building: DebugMetadata::debug_metadata_locations. + absl::flat_hash_map attribute_location_idx_map; + + for (int i = 0; i < attribute_types->size(); ++i) { + if (attribute_types->Get(i) == debug_metadata::Attribute_Location) { + auto location = + static_cast(attributes->Get(i)); + TF_ASSIGN_OR_RETURN( + auto mlir_location, + BuildLocation(builder, *location, debug_metadata_locations, + attribute_location_idx_map)); + debug_metadata_locations.push_back(mlir_location); + + // Create index mapping. + attribute_location_idx_map[i] = debug_metadata_locations.size() - 1; + } + } + + // Collect the top location idx of each operator. + auto subgraphs_debug_metadata = + conversion_debug_metadata->subgraphs_debug_metadata(); + for (int subgraph_idx = 0; subgraph_idx < subgraphs_debug_metadata->size(); + ++subgraph_idx) { + const auto* subgraph_debug_metadata = + subgraphs_debug_metadata->Get(subgraph_idx); + auto operators_debug_metadata = + subgraph_debug_metadata->operators_debug_metadata(); + for (int operator_idx = 0; operator_idx < operators_debug_metadata->size(); + ++operator_idx) { + const auto* operator_debug_metadata = + operators_debug_metadata->Get(operator_idx); + // Find the location attribute of the operator. Note that there should + // be at most one idx pointing to location attribute for each operator. + std::vector location_attribute_idxs; + for (int i = 0; + i < operator_debug_metadata->attribute_metadata_indexes()->size(); + ++i) { + auto attribute_idx = + operator_debug_metadata->attribute_metadata_indexes()->Get(i); + if (attribute_types->Get(attribute_idx) == + debug_metadata::Attribute_Location) { + location_attribute_idxs.push_back(attribute_idx); + } + } + if (location_attribute_idxs.size() > 1) { + return absl::InternalError( + "Invalid/corrupt DebugMetadata, expected invariant broken (more " + "than one location attribute for an operator)"); + } + if (location_attribute_idxs.empty()) { + continue; + } + + if (!attribute_location_idx_map.contains(location_attribute_idxs[0])) { + return absl::InternalError( + "Invalid/corrupt DebugMetadata, expected invariant broken " + "(location attribute index of an operator is not valid)"); + } + debug_metadata_var.operator_location_map[subgraph_idx][operator_idx] = + attribute_location_idx_map[location_attribute_idxs[0]]; + } + } + + return absl::OkStatus(); +} + +// Parse the DebugMetadata flatbuffer and store debug metadata in struct +// `debug_metadata`. +Status ParseDebugMetadata(Builder builder, const char* data, size_t size, + DebugMetadata& debug_metadata_var) { + auto debug_metadata_fb = debug_metadata::GetDebugMetadata(data); + + if (debug_metadata_fb->debug_metadata_type()->size() != + debug_metadata_fb->debug_metadata()->size()) { + return absl::InternalError( + "Invalid/corrupt DebugMetadata, expected invariant broken (size of " + "debug_metadata_type and debug_metadata not equal)"); + } + + for (int i = 0; i < debug_metadata_fb->debug_metadata_type()->size(); ++i) { + if (debug_metadata_fb->debug_metadata_type()->Get(i) == + debug_metadata::DebugMetadataType_ConversionDebugMetadata) { + auto conversion_debug_metadata = + static_cast( + debug_metadata_fb->debug_metadata()->Get(i)); + TF_RETURN_IF_ERROR(ParseAndBuildLocation( + builder, conversion_debug_metadata, debug_metadata_var)); + } else { + LOG(WARNING) << "Unsupported DebugMetadataType: " + << debug_metadata_fb->debug_metadata_type()->Get(i); + } + } + + return absl::OkStatus(); +} + +// Return MLIR location if it exists in the debug metadata. Otherwise, create a +// MLIR location by fusing its output tensor names. +Location OpLoc(const OperatorT& op, Builder builder, + DebugMetadata& debug_metadata, const tflite::SubGraphT& subgraph, + Location base) { + const int subgraph_debug_metadata_idx = subgraph.debug_metadata_index; + if (debug_metadata.operator_location_map.contains( + subgraph_debug_metadata_idx) && + debug_metadata.operator_location_map[subgraph_debug_metadata_idx] + .contains(op.debug_metadata_index)) { + int location_idx = + debug_metadata.operator_location_map[subgraph_debug_metadata_idx] + [op.debug_metadata_index]; + return debug_metadata.debug_metadata_locations[location_idx]; + } + if (op.outputs.empty()) return base; llvm::SmallVector locations; locations.reserve(op.outputs.size()); for (auto tensor_index : op.outputs) { - locations.push_back(TensorLoc(*tensors[tensor_index], builder, base)); + locations.push_back( + TensorLoc(*subgraph.tensors[tensor_index], builder, base)); } return mlir::FusedLoc::get(builder.getContext(), locations); } // Extract the min max information in the tensor and create the quant stats op. -// If the input `tensor` has scale/zero_point, `res` should have quantized -// type, thus none stats op is required and nullptr is returned. -// If the min max information is invalid, nullptr is returned. +// If the input `tensor` has scale/zero_point, `res` should have quantized type, +// thus none stats op is required and nullptr is returned. If the min max +// information is invalid, nullptr is returned. mlir::Operation* ConvertMinMaxToStatsOp(const TensorT& tensor, OpBuilder b, Value res) { // If the `tensor` has scale/zero_point, it must have been quantized, then the @@ -678,8 +889,8 @@ StatusOr ConvertOp( } // While the last several tensors could be optional tensors for an tfl op, the - // number of input operands could vary. Gets the min/max number of - // operands from tflite op name. + // number of input operands could vary. Gets the min/max number of operands + // from tflite op name. // Also, since the above code special-handles the `tfl.reshape` op and add an // additional input, we put these function block here. llvm::MinMax input_min_max = mlir::OperandNumbersMinMax(op_name); @@ -1117,7 +1328,7 @@ StatusOr ConvertSubgraph( const tflite::SignatureDefT* signature, const tflite::ControlEdges& control_edges, const std::unique_ptr& model_ptr, - bool use_stablehlo_constant) { + bool use_stablehlo_constant, DebugMetadata& debug_metadata) { // Populate from metadata. ControlNodes control_nodes; for (const auto [from, to] : control_edges) { @@ -1301,11 +1512,12 @@ StatusOr ConvertSubgraph( TF_ASSIGN_OR_RETURN( mlir::TensorType type, tfl::GetTensorType(*subgraph.tensors[intermediate], builder, - /*is_constant=*/false, /*is_intermediate=*/true)); + /*is_constant=*/false, + /*is_intermediate=*/true)); intermediate_types.emplace_back(type); } - auto op_loc = OpLoc(*op, subgraph.tensors, builder, base_loc); + auto op_loc = OpLoc(*op, builder, debug_metadata, subgraph, base_loc); // If there's an optional argument, maybe_optional_arg_marker has been set // to a valid Value @@ -1535,6 +1747,7 @@ OwningOpRef tflite::FlatBufferToMlir( llvm::SmallVector metadata_attrs; mlir::StringSet<> seen_attr; + DebugMetadata debug_metadata; for (const auto& metadata : model->metadata) { if (metadata->name == tflite::kModelControlDependenciesMetadataKey) { const std::vector& data = model->buffers[metadata->buffer]->data; @@ -1559,6 +1772,17 @@ OwningOpRef tflite::FlatBufferToMlir( continue; } + if (metadata->name == "debug_metadata") { + const std::vector& data = model->buffers[metadata->buffer]->data; + auto status = ParseDebugMetadata( + builder, reinterpret_cast(data.data()), data.size(), + debug_metadata); + if (!status.ok()) { + return emitError(base_loc, std::string(status.message())), nullptr; + } + continue; + } + std::vector buffer = model->buffers[metadata->buffer]->data; metadata_attrs.emplace_back( builder.getStringAttr(metadata->name), @@ -1618,7 +1842,7 @@ OwningOpRef tflite::FlatBufferToMlir( ? subgraph_to_signature_map.at(subgraph_index) : nullptr, model_control_dependencies[subgraph_index], model_ptr, - use_stablehlo_constant); + use_stablehlo_constant, debug_metadata); if (!func_or_error.ok()) { return emitError(base_loc, "could not translate function ") << subgraph->name << ": " << func_or_error.status().message(), diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/debug_metadata.mlir b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/debug_metadata.mlir new file mode 100644 index 00000000000000..61df9ad531515a --- /dev/null +++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/debug_metadata.mlir @@ -0,0 +1,36 @@ +// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer --serialize-debug-metadata=true %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir --mlir-print-debuginfo -o - | FileCheck %s +// This test verifies that debug locations are round-trippable. + +module @jit_relu attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32, tfl._legalize_tfl_variables = true} { + func.func @main(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> { + %0 = "tfl.less"(%arg0, %arg1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xi1> loc(#loc) + // CHECK-DAG: {{.*}} = tfl.less(%arg0, %arg1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xi1> loc([[LOC:.+]]) + %1 = "tf.If"(%0, %arg0, %arg1) {else_branch = @cond_false, then_branch = @cond_true, is_stateless = false} : (tensor<1xi1>, tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> loc(#loc) + // CHECK-DAG: {{.*}} = "tf.If"(%0, %arg0, %arg1) {{.*}} -> tensor<1xf32> loc([[LOC]]) + func.return %1 : tensor<1xf32> loc(#loc) + } + + func.func @cond_true(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> { + %0 = tfl.add %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<*xf32> loc(#loc4) + // CHECK-DAG: {{.*}} = tfl.add %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<*xf32> loc([[LOC4:.+]]) + func.return %0 : tensor<*xf32> loc(#loc) + } + + func.func @cond_false(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> { + %0 = tfl.mul %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<*xf32> loc(#loc5) + // CHECK-DAG: {{.*}} = tfl.mul %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<*xf32> loc([[LOC5:.+]]) + func.return %0 : tensor<*xf32> loc(#loc) + } +} loc(#loc) +#loc = loc(unknown) +// CHECK-DAG: [[LOC]] = loc(unknown) +#loc1 = loc("":1:4) +// CHECK-DAG: [[LOC1:.+]] = loc("":1:4) +#loc2 = loc("third_party/py/IPython/v3_2_3/core/interactiveshell.py":3066:16) +// CHECK-DAG: [[LOC2:.+]] = loc("third_party/py/IPython/v3_2_3/core/interactiveshell.py":3066:16) +#loc3 = loc(callsite(#loc1 at #loc2)) +// CHECK-DAG: [[LOC3:.+]] = loc(callsite([[LOC1]] at [[LOC2]])) +#loc4 = loc("jit(relu)/jit(main)/max"(#loc3)) +// CHECK-DAG: [[LOC4]] = loc("jit(relu)/jit(main)/max"([[LOC3]])) +#loc5 = loc(fused<"">[#loc1, #loc2]) +// CHECK-DAG: [[LOC5]] = loc(fused<"">[[[LOC1]], [[LOC2]]]) \ No newline at end of file From 7d1b0006e9a4423daa4d9605dd21b60b524da61b Mon Sep 17 00:00:00 2001 From: Raviteja Gorijala Date: Thu, 26 Sep 2024 13:52:45 -0700 Subject: [PATCH 339/483] Update release notes for 2.17.1 PiperOrigin-RevId: 679281691 --- RELEASE.md | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/RELEASE.md b/RELEASE.md index 9820f420589ac5..ccef1e4da76327 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -109,6 +109,14 @@ This release contains contributions from many people at Google, as well as: Akhil Goel, akhilgoe, Alexander Pivovarov, Amir Samani, Andrew Goodbody, Andrey Portnoy, Anthony Platanios, bernardoArcari, Brett Taylor, buptzyb, Chao, Christian Clauss, Cocoa, Daniil Kutz, Darya Parygina, dependabot[bot], Dimitris Vardoulakis, Dragan Mladjenovic, Elfie Guo, eukub, Faijul Amin, flyingcat, Frédéric Bastien, ganyu.08, Georg Stefan Schmid, Grigory Reznikov, Harsha H S, Harshit Monish, Heiner, Ilia Sergachev, Jan, Jane Liu, Jaroslav Sevcik, Kaixi Hou, Kanvi Khanna, Kristof Maar, Kristóf Maár, LakshmiKalaKadali, Lbertho-Gpsw, lingzhi98, MarcoFalke, Masahiro Hiramori, Mmakevic-Amd, mraunak, Nobuo Tsukamoto, Notheisz57, Olli Lupton, Pearu Peterson, pemeliya, Peyara Nando, Philipp Hack, Phuong Nguyen, Pol Dellaiera, Rahul Batra, Ruturaj Vaidya, sachinmuradi, Sergey Kozub, Shanbin Ke, Sheng Yang, shengyu, Shraiysh, Shu Wang, Surya, sushreebarsa, Swatheesh-Mcw, syzygial, Tai Ly, terryysun, tilakrayal, Tj Xu, Trevor Morris, Tzung-Han Juang, wenchenvincent, wondertx, Xuefei Jiang, Ye Huang, Yimei Sun, Yunlong Liu, Zahid Iqbal, Zhan Lu, Zoranjovanovic-Ns, Zuri Obozuwa +# Release 2.17.1 + +### Bug Fixes and Other Changes + +* Add necessary header files in the aar library. These are needed if developers build apps with header files unpacked from tflite aar files from maven. +* Implement Name() for GCSWritableFile to fix the profiler trace viewer cache file generation. +* Fix `cstring.h` missing file issue with the Libtensorflow archive. + # Release 2.17.0 ## TensorFlow From 0b22d3247058eb970edb38348dd4638715f271d0 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Thu, 26 Sep 2024 14:29:10 -0700 Subject: [PATCH 340/483] [XLA:Python] Avoid copying an nb::detail::dict_iterator. Nanobind 2.2.0 makes dict iterators uncopyable. In addition, avoid a possible exception-safety problem where Python .equals() was called from an equality test used by an ABSL hash table. PiperOrigin-RevId: 679295293 --- third_party/xla/xla/python/BUILD | 2 - .../xla/xla/python/weakref_lru_cache.cc | 143 ++++++++---------- 2 files changed, 67 insertions(+), 78 deletions(-) diff --git a/third_party/xla/xla/python/BUILD b/third_party/xla/xla/python/BUILD index 1b73a0f2185af0..0c5bed03247840 100644 --- a/third_party/xla/xla/python/BUILD +++ b/third_party/xla/xla/python/BUILD @@ -1109,11 +1109,9 @@ cc_library( features = ["-use_header_modules"], visibility = ["//visibility:private"], deps = [ - ":nb_helpers", # placeholder for index annotation deps "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/cleanup", - "@com_google_absl//absl/container:node_hash_map", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", "@nanobind", diff --git a/third_party/xla/xla/python/weakref_lru_cache.cc b/third_party/xla/xla/python/weakref_lru_cache.cc index 2c2e5bfc222e2a..736bf437d106d6 100644 --- a/third_party/xla/xla/python/weakref_lru_cache.cc +++ b/third_party/xla/xla/python/weakref_lru_cache.cc @@ -19,14 +19,15 @@ limitations under the License. #include #include #include +#include #include #include // NOLINT +#include #include #include #include "absl/base/thread_annotations.h" #include "absl/cleanup/cleanup.h" -#include "absl/container/node_hash_map.h" #include "absl/strings/str_cat.h" #include "absl/synchronization/mutex.h" #include "absl/synchronization/notification.h" @@ -35,7 +36,6 @@ limitations under the License. #include "nanobind/stl/string.h" // IWYU pragma: keep #include "nanobind/stl/vector.h" // IWYU pragma: keep #include "xla/pjrt/lru_cache.h" -#include "xla/python/nb_helpers.h" namespace nb = nanobind; @@ -44,36 +44,38 @@ namespace { // Minimal wrapper to expose a nb::dict_iterator's value as something // hashable with Abseil. -class HashablePyDictValue { - protected: - using Iter = nb::detail::dict_iterator; +class HashablePyDictEntry { + public: + explicit HashablePyDictEntry(std::pair entry) + : entry_(entry) {} template - friend H AbslHashValue(H h, const HashablePyDictValue& value) { - auto kv = *value.iter_; - return H::combine(std::move(h), nb::hash(kv.first), nb::hash(kv.second)); + friend H AbslHashValue(H h, const HashablePyDictEntry& v) { + return H::combine(std::move(h), nb::hash(v.entry_.first), + nb::hash(v.entry_.second)); } - explicit HashablePyDictValue(const Iter& iter) : iter_(iter) {} - - Iter iter_; + std::pair entry_; }; // Similarly, a minimalist adaptor around the nb::detail::dict_iterator // itself. Note that the iterator "is" also a Value. Does not meet the full // standard iterator requirements, only enough to support H::combine_unordered. -class HashablePyDictIter : protected HashablePyDictValue { +class HashablePyDictIter { public: using iterator_category = std::input_iterator_tag; - explicit HashablePyDictIter(const Iter& iter) : HashablePyDictValue(iter) {} + explicit HashablePyDictIter(nb::detail::dict_iterator& iter) : iter_(iter) {} // Minimal set of iterator operations. - const HashablePyDictValue& operator*() const { return *this; } + HashablePyDictEntry operator*() const { return HashablePyDictEntry(*iter_); } bool operator!=(const HashablePyDictIter& rhs) const { return iter_ != rhs.iter_; } void operator++() { ++iter_; } + + private: + nb::detail::dict_iterator& iter_; }; } // namespace @@ -92,10 +94,15 @@ class WeakrefLRUCache : public std::enable_shared_from_this { template friend H AbslHashValue(H h, const Key& key) { + // Note: Despite the fact this is an ABSL hash function, it's safe to call + // functions that may throw exceptions such as nb::hash(), because it is + // used by an LRUCache, which uses a std::unordered_map, which is + // exception-safe. h = H::combine(std::move(h), nb::hash(key.context), nb::hash(key.args)); - h = H::combine_unordered(std::move(h), - HashablePyDictIter(key.kwargs.begin()), - HashablePyDictIter(key.kwargs.end())); + nb::detail::dict_iterator begin = key.kwargs.begin(); + nb::detail::dict_iterator end = key.kwargs.end(); + h = H::combine_unordered(std::move(h), HashablePyDictIter(begin), + HashablePyDictIter(end)); h = H::combine(std::move(h), key.kwargs.size()); return h; } @@ -115,82 +122,65 @@ class WeakrefLRUCache : public std::enable_shared_from_this { int64_t currsize; }; - struct UnboundWeakrefCacheEntry { + struct WeakrefCacheKey { nb::handle object; - WeakrefLRUCache* cache; size_t cached_hash; }; - struct WeakrefCacheEntry { - nb::weakref weakref; - size_t cached_hash; + using Cache = xla::LRUCache>; + + struct WeakrefCacheValue { + std::optional weakref; + std::shared_ptr cache; }; struct WeakrefKeyHash { - using is_transparent = void; - - size_t operator()(const UnboundWeakrefCacheEntry& v) const { - return v.cached_hash; - } - size_t operator()(const WeakrefCacheEntry& v) const { - return v.cached_hash; - } + size_t operator()(const WeakrefCacheKey& v) const { return v.cached_hash; } }; struct WeakrefKeyEq { - using is_transparent = void; - bool operator()(const WeakrefCacheEntry& lhs, - const WeakrefCacheEntry& rhs) const { - return lhs.weakref.equal(rhs.weakref); - } - bool operator()(const WeakrefCacheEntry& lhs, - const UnboundWeakrefCacheEntry& rhs) const { - PyObject* obj = PyWeakref_GET_OBJECT(lhs.weakref.ptr()); - if (obj == Py_None) { - return false; - } - return nb::borrow(obj).equal(rhs.object); + bool operator()(const WeakrefCacheKey& lhs, + const WeakrefCacheKey& rhs) const { + return lhs.object.equal(rhs.object); } }; - using Cache = xla::LRUCache>; WeakrefLRUCache(nb::callable cache_context_fn, nb::callable fn, int64_t maxsize) : cache_context_fn_(cache_context_fn), fn_(fn), lru_list_(maxsize) {} - std::shared_ptr GetCache(const UnboundWeakrefCacheEntry& key) { - auto it = entries_.find(key); - if (it != entries_.end()) { - return (it->second); + std::shared_ptr GetCache(WeakrefCacheKey key) { + auto [it, inserted] = entries_.emplace(key, WeakrefCacheValue()); + if (!inserted) { + return it->second.cache; } - nb::weakref weakref( - key.object, - nb::cpp_function([this_weak = weak_from_this(), - cached_hash = key.cached_hash](nb::handle weakref) { - auto cache = this_weak.lock(); - if (cache == nullptr) { - return; - } - auto it = cache->entries_.find( - WeakrefCacheEntry{nb::borrow(weakref), cached_hash}); - if (it == cache->entries_.end()) { - return; - } - // Create temp-var to avoid re-entrant erase. - auto tmp = std::move(it->second); - cache->entries_.erase(it); - })); - return (entries_ - .emplace(WeakrefCacheEntry{std::move(weakref), key.cached_hash}, - std::make_shared(&lru_list_)) - .first->second); + + auto& value = it->second; + + value.cache = std::make_shared(&lru_list_); + value.weakref = + nb::weakref(key.object, nb::cpp_function([this_weak = weak_from_this(), + key](nb::handle weakref) { + auto cache = this_weak.lock(); + if (cache == nullptr) { + return; + } + auto it = cache->entries_.find(key); + if (it == cache->entries_.end()) { + return; + } + // Create temp-var to avoid re-entrant erase. + auto tmp = std::move(it->second); + cache->entries_.erase(it); + })); + return value.cache; } nb::object Call(nb::object weakref_key, nb::args args, nb::kwargs kwargs) ABSL_NO_THREAD_SAFETY_ANALYSIS { nb::object context = cache_context_fn_(); - std::shared_ptr cache_ptr = GetCache(UnboundWeakrefCacheEntry{ - weakref_key, this, static_cast(nb::hash(weakref_key))}); + std::shared_ptr cache_ptr = GetCache(WeakrefCacheKey{ + weakref_key, static_cast(nb::hash(weakref_key))}); Cache& cache = *cache_ptr; ++total_queries_; @@ -246,10 +236,10 @@ class WeakrefLRUCache : public std::enable_shared_from_this { std::vector GetKeys() { std::vector results; mu_.Lock(); - for (const auto& wr_key : entries_) { - for (const auto& rest : *wr_key.second) { + for (const auto& wr_entry : entries_) { + for (const auto& rest : *wr_entry.second.cache) { nb::tuple result = - nb::make_tuple(wr_key.first.weakref, rest.first.context, + nb::make_tuple(*wr_entry.second.weakref, rest.first.context, rest.first.args, rest.first.kwargs); results.push_back(std::move(result)); } @@ -268,8 +258,9 @@ class WeakrefLRUCache : public std::enable_shared_from_this { void Clear() { total_queries_ = misses_ = 0; std::vector> deferred_deletes; + deferred_deletes.reserve(entries_.size()); for (auto& entry : entries_) { - deferred_deletes.push_back(std::move(entry.second)); + deferred_deletes.push_back(std::move(entry.second.cache)); } entries_.clear(); deferred_deletes.clear(); @@ -278,8 +269,8 @@ class WeakrefLRUCache : public std::enable_shared_from_this { nb::callable cache_context_fn_; nb::callable fn_; Cache::LRUList lru_list_; - absl::node_hash_map, WeakrefKeyHash, - WeakrefKeyEq> + std::unordered_map entries_; int64_t misses_ = 0; int64_t total_queries_ = 0; From 2f09837b028ed162d6f16ee5dea2d27869991b27 Mon Sep 17 00:00:00 2001 From: Toli Yevtushenko Date: Thu, 26 Sep 2024 15:08:23 -0700 Subject: [PATCH 341/483] Add a helper method to HloTestBase to run a pass on a parameterized HLO string. This is a common pattern in HLO transformation tests, and it's useful to have a helper method to reduce boilerplate. This CL also updates all_reduce_folder_test.cc to use the new helper method. PiperOrigin-RevId: 679309211 --- third_party/xla/xla/service/BUILD | 2 +- .../xla/xla/service/all_reduce_folder_test.cc | 284 ++++++++---------- third_party/xla/xla/tests/BUILD | 9 + third_party/xla/xla/tests/hlo_test_base.cc | 28 +- third_party/xla/xla/tests/hlo_test_base.h | 21 ++ 5 files changed, 174 insertions(+), 170 deletions(-) diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD index 9b51d54582b12f..7277e853f3a131 100644 --- a/third_party/xla/xla/service/BUILD +++ b/third_party/xla/xla/service/BUILD @@ -276,8 +276,8 @@ xla_cc_test( "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", "@local_tsl//tsl/platform:statusor", ], diff --git a/third_party/xla/xla/service/all_reduce_folder_test.cc b/third_party/xla/xla/service/all_reduce_folder_test.cc index e984d089adb196..f23d1f7bdf0972 100644 --- a/third_party/xla/xla/service/all_reduce_folder_test.cc +++ b/third_party/xla/xla/service/all_reduce_folder_test.cc @@ -16,12 +16,10 @@ limitations under the License. #include "xla/service/all_reduce_folder.h" #include -#include +#include #include -#include #include "absl/algorithm/container.h" -#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" @@ -29,220 +27,180 @@ limitations under the License. #include "xla/hlo/utils/hlo_matchers.h" #include "xla/test.h" #include "xla/tests/hlo_test_base.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" namespace xla { namespace { -namespace m = xla::testing::opcode_matchers; +namespace matcher = xla::testing::opcode_matchers; using ::testing::HasSubstr; -class AllReduceFolderTest : public HloTestBase { - public: - absl::StatusOr> RunPass( - absl::string_view hlo_module, bool expect_change) { - TF_ASSIGN_OR_RETURN(auto module, ParseAndReturnVerifiedModule(hlo_module)); - auto changed = AllReduceFolder().Run(module.get()); - if (!changed.ok()) { - return changed.status(); - } - EXPECT_EQ(changed.value(), expect_change); - return absl::StatusOr>(std::move(module)); - } +class AllReduceFolderTest : public HloTestBase {}; - size_t AllReduceCount(std::unique_ptr &module) { - return absl::c_count_if(module->entry_computation()->instructions(), - HloPredicateIsOp); - } -}; +const char *k2AllReduce = R"( + HloModule m -TEST_F(AllReduceFolderTest, Simple) { - absl::string_view hlo_string = R"( -HloModule m + sum { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT add.2 = f32[] add(a, b) + } -sum { - a = f32[] parameter(0) - b = f32[] parameter(1) - ROOT add.2 = f32[] add(a, b) -} + ENTRY main { + p0 = f32[8] parameter(0) + ar0 = f32[8] all-reduce(p0), replica_groups=$group_0, to_apply=sum + ROOT ar1 = f32[8] all-reduce(ar0), replica_groups=$group_1, to_apply=sum + } + )"; -ENTRY main { - p0 = f32[8] parameter(0) - ar0 = f32[8] all-reduce(p0), replica_groups={{0,1},{2,3}}, to_apply=sum - ROOT ar1 = f32[8] all-reduce(ar0), replica_groups={{0,2},{1,3}}, to_apply=sum +size_t AllReduceCount(HloModule *module) { + return absl::c_count_if(module->entry_computation()->instructions(), + HloPredicateIsOp); } -)"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - RunPass(hlo_string, /*expect_change=*/true)); + +void ExpectOneAllReduce(HloModule *module, + absl::string_view target_replica_groups) { EXPECT_EQ(AllReduceCount(module), 1); HloInstruction *root = module->entry_computation()->root_instruction(); - EXPECT_THAT(root, m::AllReduce(m::Parameter(0))); - EXPECT_THAT(root->ToString(), HasSubstr("replica_groups={{0,1,2,3}}")); + EXPECT_THAT(root, matcher::AllReduce(matcher::Parameter(0))); + EXPECT_THAT(root->ToString(), HasSubstr(target_replica_groups)); } -// Same as Simple, but groups for the 2 all-reduce's are swapped. -TEST_F(AllReduceFolderTest, SimpleSwap) { - absl::string_view hlo_string = R"( -HloModule m - -sum { - a = f32[] parameter(0) - b = f32[] parameter(1) - ROOT add.2 = f32[] add(a, b) +TEST_F(AllReduceFolderTest, Simple) { + TF_ASSERT_OK_AND_ASSIGN( + auto module, RunAndCheckHloRewrite(k2AllReduce, AllReduceFolder(), true, + {{"$group_0", "{{0,1},{2,3}}"}, + {"$group_1", "{{0,2},{1,3}}"}})); + ExpectOneAllReduce(module.get(), "replica_groups={{0,1,2,3}}"); } -ENTRY main { - p0 = f32[8] parameter(0) - ar0 = f32[8] all-reduce(p0), replica_groups={{0,2},{1,3}}, to_apply=sum - ROOT ar1 = f32[8] all-reduce(ar0), replica_groups={{0,1},{2,3}}, to_apply=sum -} -)"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - RunPass(hlo_string, /*expect_change=*/true)); - EXPECT_EQ(AllReduceCount(module), 1); - HloInstruction *root = module->entry_computation()->root_instruction(); - EXPECT_THAT(root, m::AllReduce(m::Parameter(0))); - EXPECT_THAT(root->ToString(), HasSubstr("replica_groups={{0,1,2,3}}")); +// Same as Simple, but groups for the 2 all-reduce's are swapped. +TEST_F(AllReduceFolderTest, SimpleSwap) { + TF_ASSERT_OK_AND_ASSIGN( + auto module, RunAndCheckHloRewrite(k2AllReduce, AllReduceFolder(), true, + {{"$group_1", "{{0,1},{2,3}}"}, + {"$group_0", "{{0,2},{1,3}}"}})); + ExpectOneAllReduce(module.get(), "replica_groups={{0,1,2,3}}"); } -TEST_F(AllReduceFolderTest, EmptyReplicaGroups) { - absl::string_view hlo_string = R"( -HloModule m - -sum { - a = f32[] parameter(0) - b = f32[] parameter(1) - ROOT add.2 = f32[] add(a, b) +TEST_F(AllReduceFolderTest, BothEmptyReplicaGroups_NotTransformed) { + TF_ASSERT_OK(RunAndCheckHloRewrite(k2AllReduce, AllReduceFolder(), false, + {{"$group_0", "{}"}, {"$group_1", "{}"}})); } -ENTRY main { - p0 = f32[8] parameter(0) - ar0 = f32[8] all-reduce(p0), replica_groups={}, to_apply=sum - ROOT ar1 = f32[8] all-reduce(ar0), replica_groups={{0,2},{1,3}}, to_apply=sum -} -)"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - RunPass(hlo_string, /*expect_change=*/false)); +TEST_F(AllReduceFolderTest, EmptyReplicaGroups_NotTransformed) { + TF_ASSERT_OK(RunAndCheckHloRewrite( + k2AllReduce, AllReduceFolder(), false, + {{"$group_0", "{}"}, {"$group_1", "{{0,2},{1,3}}"}})); } -TEST_F(AllReduceFolderTest, MismatchOtherProperties0) { +TEST_F(AllReduceFolderTest, MismatchOtherProperties0_NotTransformed) { absl::string_view hlo_string = R"( -HloModule m + HloModule m -sum { - a = f32[] parameter(0) - b = f32[] parameter(1) - ROOT add.2 = f32[] add(a, b) -} + sum { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT add.2 = f32[] add(a, b) + } -ENTRY main { - p0 = f32[8] parameter(0) - ar0 = f32[8] all-reduce(p0), replica_groups={{0,1},{2,3}}, channel_id=1, to_apply=sum - ROOT ar1 = f32[8] all-reduce(ar0), replica_groups={{0,2},{1,3}}, to_apply=sum -} -)"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - RunPass(hlo_string, /*expect_change=*/false)); + ENTRY main { + p0 = f32[8] parameter(0) + ar0 = f32[8] all-reduce(p0), replica_groups={{0,1},{2,3}}, channel_id=1, to_apply=sum + ROOT ar1 = f32[8] all-reduce(ar0), replica_groups={{0,2},{1,3}}, to_apply=sum + } + )"; + TF_ASSERT_OK(RunAndCheckHloRewrite(hlo_string, AllReduceFolder(), false)); } -TEST_F(AllReduceFolderTest, MismatchOtherProperties1) { +TEST_F(AllReduceFolderTest, MismatchOtherProperties1_NotTransformed) { absl::string_view hlo_string = R"( -HloModule m + HloModule m -sum { - a = f32[] parameter(0) - b = f32[] parameter(1) - ROOT add.2 = f32[] add(a, b) -} + sum { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT add.2 = f32[] add(a, b) + } -mul { - a = f32[] parameter(0) - b = f32[] parameter(1) - ROOT mul = f32[] multiply(a, b) -} + mul { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT mul = f32[] multiply(a, b) + } -ENTRY main { - p0 = f32[8] parameter(0) - ar0 = f32[8] all-reduce(p0), replica_groups={{0,1},{2,3}}, to_apply=sum - ROOT ar1 = f32[8] all-reduce(ar0), replica_groups={{0,2},{1,3}}, to_apply=mul -} -)"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - RunPass(hlo_string, /*expect_change=*/false)); + ENTRY main { + p0 = f32[8] parameter(0) + ar0 = f32[8] all-reduce(p0), replica_groups={{0,1},{2,3}}, to_apply=sum + ROOT ar1 = f32[8] all-reduce(ar0), replica_groups={{0,2},{1,3}}, to_apply=mul + } + )"; + TF_ASSERT_OK(RunAndCheckHloRewrite(hlo_string, AllReduceFolder(), false)); } -TEST_F(AllReduceFolderTest, NotFoldable) { +TEST_F(AllReduceFolderTest, NotFoldable_NotTransformed) { absl::string_view hlo_string = R"( -HloModule m + HloModule m -sum { - a = f32[] parameter(0) - b = f32[] parameter(1) - ROOT add.2 = f32[] add(a, b) -} + sum { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT add.2 = f32[] add(a, b) + } -ENTRY main { - p0 = f32[8] parameter(0) - ar0 = f32[8] all-reduce(p0), replica_groups={{0,1},{2,3}}, to_apply=sum - ROOT ar1 = f32[8] all-reduce(ar0), replica_groups={{0,1},{2,3}}, to_apply=sum -} -)"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - RunPass(hlo_string, /*expect_change=*/false)); + ENTRY main { + p0 = f32[8] parameter(0) + ar0 = f32[8] all-reduce(p0), replica_groups={{0,1},{2,3}}, to_apply=sum + ROOT ar1 = f32[8] all-reduce(ar0), replica_groups={{0,1},{2,3}}, to_apply=sum + } + )"; + TF_ASSERT_OK(RunAndCheckHloRewrite(hlo_string, AllReduceFolder(), false)); } TEST_F(AllReduceFolderTest, Foldable0) { absl::string_view hlo_string = R"( -HloModule m + HloModule m -sum { - a = f32[] parameter(0) - b = f32[] parameter(1) - ROOT add.2 = f32[] add(a, b) -} + sum { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT add.2 = f32[] add(a, b) + } -ENTRY main { - p0 = f32[8] parameter(0) - ar0 = f32[8] all-reduce(p0), replica_groups={{0,4},{1,5},{2,3},{6,7}}, to_apply=sum - ROOT ar1 = f32[8] all-reduce(ar0), replica_groups={{0,5},{4,1},{2,7},{3,6}}, to_apply=sum -} -)"; + ENTRY main { + p0 = f32[8] parameter(0) + ar0 = f32[8] all-reduce(p0), replica_groups={{0,4},{1,5},{2,3},{6,7}}, to_apply=sum + ROOT ar1 = f32[8] all-reduce(ar0), replica_groups={{0,5},{4,1},{2,7},{3,6}}, to_apply=sum + } + )"; TF_ASSERT_OK_AND_ASSIGN(auto module, - RunPass(hlo_string, /*expect_change=*/true)); - EXPECT_EQ(AllReduceCount(module), 1); - HloInstruction *root = module->entry_computation()->root_instruction(); - EXPECT_THAT(root, m::AllReduce(m::Parameter(0))); - EXPECT_THAT(root->ToString(), - HasSubstr("replica_groups={{0,1,4,5},{2,3,6,7}}")); + RunAndCheckHloRewrite(hlo_string, AllReduceFolder())); + ExpectOneAllReduce(module.get(), "replica_groups={{0,1,4,5},{2,3,6,7}}"); } // Verify that a chain of foldable all-reduce's folds in a single pass // invocation. TEST_F(AllReduceFolderTest, FoldableChain) { absl::string_view hlo_string = R"( -HloModule m + HloModule m -sum { - a = f32[] parameter(0) - b = f32[] parameter(1) - ROOT add.2 = f32[] add(a, b) -} + sum { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT add.2 = f32[] add(a, b) + } -ENTRY main { - p0 = f32[8] parameter(0) - ar0 = f32[8] all-reduce(p0), replica_groups={{0,1},{2,3},{4,5},{6,7}}, to_apply=sum - ar1 = f32[8] all-reduce(ar0), replica_groups={{0,2},{1,3},{4,6},{5,7}}, to_apply=sum - ROOT ar2 = f32[8] all-reduce(ar1), replica_groups={{0,4},{1,5},{2,6},{3,7}}, to_apply=sum -} -)"; + ENTRY main { + p0 = f32[8] parameter(0) + ar0 = f32[8] all-reduce(p0), replica_groups={{0,1},{2,3},{4,5},{6,7}}, to_apply=sum + ar1 = f32[8] all-reduce(ar0), replica_groups={{0,2},{1,3},{4,6},{5,7}}, to_apply=sum + ROOT ar2 = f32[8] all-reduce(ar1), replica_groups={{0,4},{1,5},{2,6},{3,7}}, to_apply=sum + } + )"; TF_ASSERT_OK_AND_ASSIGN(auto module, - RunPass(hlo_string, /*expect_change=*/true)); - std::cerr << module->ToString(); - EXPECT_EQ(AllReduceCount(module), 1); - HloInstruction *root = module->entry_computation()->root_instruction(); - EXPECT_THAT(root, m::AllReduce(m::Parameter(0))); - EXPECT_THAT(root->ToString(), - HasSubstr("replica_groups={{0,1,2,3,4,5,6,7}}")); + RunAndCheckHloRewrite(hlo_string, AllReduceFolder())); + ExpectOneAllReduce(module.get(), "replica_groups={{0,1,2,3,4,5,6,7}}"); } } // namespace diff --git a/third_party/xla/xla/tests/BUILD b/third_party/xla/xla/tests/BUILD index c7dbbee3efe6a7..f6f9a729bcbc62 100644 --- a/third_party/xla/xla/tests/BUILD +++ b/third_party/xla/xla/tests/BUILD @@ -179,15 +179,19 @@ cc_library( ":test_utils", ":verified_hlo_module", "//xla:debug_options_flags", + "//xla:error_spec", + "//xla:literal", "//xla:shape_layout", "//xla:shape_util", "//xla:types", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/ir:hlo_module_group", + "//xla/hlo/pass:hlo_pass", "//xla/hlo/utils:hlo_query", "//xla/service:backend", "//xla/service:computation_layout", + "//xla/service:hlo_module_config", "//xla/service:hlo_module_util", "//xla/service:hlo_parser", "//xla/service:hlo_runner", @@ -200,12 +204,17 @@ cc_library( "//xla/stream_executor:stream_executor_memory_allocator", "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", ], ) diff --git a/third_party/xla/xla/tests/hlo_test_base.cc b/third_party/xla/xla/tests/hlo_test_base.cc index dbbcd5f866e924..28985fd2ba33b7 100644 --- a/third_party/xla/xla/tests/hlo_test_base.cc +++ b/third_party/xla/xla/tests/hlo_test_base.cc @@ -19,8 +19,8 @@ limitations under the License. #include #include #include -#include #include +#include #include #include @@ -28,19 +28,19 @@ limitations under the License. #include "absl/log/check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/str_replace.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/debug_options_flags.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/pass/hlo_pass_interface.h" #include "xla/hlo/utils/hlo_query.h" -#include "xla/layout_util.h" #include "xla/service/hlo_module_util.h" -#include "xla/service/hlo_parser.h" #include "xla/service/hlo_runner_interface.h" #include "xla/service/hlo_runner_pjrt.h" #include "xla/service/platform_util.h" #include "xla/shape.h" -#include "xla/shape_util.h" #include "xla/stream_executor/device_memory_allocator.h" #include "xla/stream_executor/stream_executor_memory_allocator.h" #include "xla/tests/filecheck.h" @@ -49,9 +49,9 @@ limitations under the License. #include "xla/tests/test_utils.h" #include "xla/tests/verified_hlo_module.h" #include "xla/tsl/lib/core/status_test_util.h" -#include "xla/types.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" +#include "tsl/platform/statusor.h" #include "tsl/platform/test.h" namespace xla { @@ -323,6 +323,21 @@ void HloTestBase::RunAndFilecheckHloModuleGroupRewrite( } } +absl::StatusOr> HloTestBase::RunAndCheckHloRewrite( + absl::string_view hlo_template, HloPassInterface&& hlo_pass, + bool expect_change, FixedMapping params) { + std::string hlo_string = absl::StrReplaceAll(hlo_template, params); + SCOPED_TRACE("Input HLO: " + hlo_string); + VLOG(7) << "Input HLO: " << hlo_string; + TF_ASSIGN_OR_RETURN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSIGN_OR_RETURN(bool changed, RunHloPass(hlo_pass, module.get())); + VLOG(7) << "Output HLO: " + << module->ToString(HloPrintOptions::ShortParsable()); + EXPECT_EQ(changed, expect_change); + return module; +} + absl::StatusOr HloTestBase::Execute( std::unique_ptr module, absl::Span arguments, bool run_hlo_passes) { @@ -665,7 +680,8 @@ ::testing::AssertionResult HloTestBase::RunAndCompareTwoModulesReplicated( auto num_args = module_0->entry_computation()->num_parameters(); if (num_args != options.arguments.size()) { return ::testing::AssertionFailure() - << "Mismatch in number of arguments passed while running replicated " + << "Mismatch in number of arguments passed while running " + "replicated " "hlo module. Expected: " << num_args << ", actual: " << options.arguments.size(); } diff --git a/third_party/xla/xla/tests/hlo_test_base.h b/third_party/xla/xla/tests/hlo_test_base.h index e075d7fd7123a0..c312fc35090d7c 100644 --- a/third_party/xla/xla/tests/hlo_test_base.h +++ b/third_party/xla/xla/tests/hlo_test_base.h @@ -16,21 +16,31 @@ limitations under the License. #ifndef XLA_TESTS_HLO_TEST_BASE_H_ #define XLA_TESTS_HLO_TEST_BASE_H_ +#include #include +#include #include #include #include #include #include +#include "absl/base/attributes.h" +#include "absl/log/log.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "xla/error_spec.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_module_group.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/pass/hlo_pass_interface.h" +#include "xla/literal.h" #include "xla/service/backend.h" #include "xla/service/computation_layout.h" +#include "xla/service/hlo_module_config.h" #include "xla/service/hlo_runner.h" #include "xla/service/hlo_verifier.h" #include "xla/service/platform_util.h" @@ -187,6 +197,17 @@ class HloTestBase : public ::testing::Test { HloPassInterface&& hlo_pass, std::optional> expected); + using FixedMapping = + std::initializer_list>; + + // Creates an HLO module from a template and an optional replacement map and + // runs the given hlo_pass on the module. Validates whether the pass has + // changed the module or not based on expect_change flag. Returns unique_ptr + // to the HLO module for further inspection. + absl::StatusOr> RunAndCheckHloRewrite( + absl::string_view hlo_template, HloPassInterface&& hlo_pass, + bool expect_change = true, FixedMapping params = {}); + // Populates debug options from command-line flags and adjusts the options for // testing. It is recommended to use this when you need to pass in // DebugOptions, e.g. when creating a module from a string or a file. From 49b06638ed25282e9d8167aa027817d3ee9f893a Mon Sep 17 00:00:00 2001 From: Mehrdad Khani Date: Thu, 26 Sep 2024 15:16:41 -0700 Subject: [PATCH 342/483] [XLA:MSA] Enable MSA to check if the lowering for an aysnc version of a synchronous slice instruction is available. PiperOrigin-RevId: 679311831 --- .../xla/xla/service/memory_space_assignment/options.h | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/third_party/xla/xla/service/memory_space_assignment/options.h b/third_party/xla/xla/service/memory_space_assignment/options.h index fb9730ced90641..e7cb78a17bebf7 100644 --- a/third_party/xla/xla/service/memory_space_assignment/options.h +++ b/third_party/xla/xla/service/memory_space_assignment/options.h @@ -61,6 +61,8 @@ using WindowPrefetchDetailFunction = std::function; using WindowPrefetchNotifyOperandAppendedFunction = std::function; +using IsAsyncSliceImplementedFunction = + std::function; // The different options to be passed to the Run() API. struct Options { @@ -124,6 +126,9 @@ struct Options { WindowPrefetchNotifyOperandAppendedFunction notify_operand_appended_fn = [](HloInstruction*, int64_t, int64_t) {}; + IsAsyncSliceImplementedFunction is_async_slice_implemented_fn = + [](const HloInstruction*) { return false; }; + // If true, we will try to reduce scoped allocation buffer size for all // instructions if their operand/output has been allocated in alternate // memory. From 649cb45e83c458678126088317c1f81da436c8b5 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 26 Sep 2024 16:23:34 -0700 Subject: [PATCH 343/483] [PJRT] Don't include headers inside xla namespace. PiperOrigin-RevId: 679333386 --- third_party/xla/xla/pjrt/transpose_kernels.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/xla/pjrt/transpose_kernels.h b/third_party/xla/xla/pjrt/transpose_kernels.h index 18b79bdae2e3f4..cba611d67cc30b 100644 --- a/third_party/xla/xla/pjrt/transpose_kernels.h +++ b/third_party/xla/xla/pjrt/transpose_kernels.h @@ -24,8 +24,6 @@ limitations under the License. #include "xla/compiler_macros.h" -namespace xla { - #ifdef XLA_HAS_SSE2 #include // IWYU pragma: keep #endif @@ -38,6 +36,8 @@ namespace xla { #define XLA_HAS_VEC128 #endif // defined(XLA_HAS_SSE2) || defined(XLA_HAS_ARM_NEON) +namespace xla { + // The transpose microkernels use a general approach of zipping elements from // different rows together. We start zipping together elements of size 1, size 2 // and so-on until we have achieved our transpose. As we increase the number of From 23355777fb2fd3a4e04f1de7a006fbff1792adee Mon Sep 17 00:00:00 2001 From: Raviteja Gorijala Date: Thu, 26 Sep 2024 16:45:00 -0700 Subject: [PATCH 344/483] Update version to 2.19.0 PiperOrigin-RevId: 679340011 --- tensorflow/core/public/version.h | 2 +- tensorflow/tensorflow.bzl | 2 +- tensorflow/tools/pip_package/setup.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index ce31c7a2e1df87..6ca4bceb99ecd8 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -21,7 +21,7 @@ limitations under the License. // Also update tensorflow/tensorflow.bzl and // tensorflow/tools/pip_package/setup.py #define TF_MAJOR_VERSION 2 -#define TF_MINOR_VERSION 18 +#define TF_MINOR_VERSION 19 #define TF_PATCH_VERSION 0 // TF_VERSION_SUFFIX is non-empty for pre-releases (e.g. "-alpha", "-alpha.1", diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl index f76b13bef0aecd..d35fbedc4f4413 100644 --- a/tensorflow/tensorflow.bzl +++ b/tensorflow/tensorflow.bzl @@ -88,7 +88,7 @@ def register_extension_info(**kwargs): # not contain rc or alpha, only numbers. # Also update tensorflow/core/public/version.h # and tensorflow/tools/pip_package/setup.py -VERSION = "2.18.0" +VERSION = "2.19.0" VERSION_MAJOR = VERSION.split(".")[0] two_gpu_tags = ["requires-gpu-nvidia:2", "manual", "no_pip"] diff --git a/tensorflow/tools/pip_package/setup.py b/tensorflow/tools/pip_package/setup.py index 0d0ada7899a006..07a4935a9a43c6 100644 --- a/tensorflow/tools/pip_package/setup.py +++ b/tensorflow/tools/pip_package/setup.py @@ -48,7 +48,7 @@ # result for pip. # Also update tensorflow/tensorflow.bzl and # tensorflow/core/public/version.h -_VERSION = '2.18.0' +_VERSION = '2.19.0' # We use the same setup.py for all tensorflow_* packages and for the nightly From 7c87e5b6dbb3805b1e050edf51bdf91a2299e166 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 26 Sep 2024 17:01:12 -0700 Subject: [PATCH 345/483] Remove some conditional checks on mesh dimensions when generating reshape strategies. PiperOrigin-RevId: 679345175 --- .../auto_sharding/auto_sharding.cc | 90 +++++++++---------- 1 file changed, 40 insertions(+), 50 deletions(-) diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc index 4841c6a5c76e02..06141018c5b8a3 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc @@ -1691,57 +1691,48 @@ std::unique_ptr CreateReshapeStrategies( const bool only_allow_divisible, const double replicated_penalty, const AutoShardingOption& option, StrategyGroups& strategy_groups, const CallGraph& call_graph) { - const DeviceMesh& device_mesh = cluster_env.device_mesh_; - - int mesh_nn_dims = VectorGreaterThanOneElementCount(device_mesh.dimensions()); std::unique_ptr strategy_group = CreateLeafStrategyGroup( instruction_id, ins, strategy_map, strategy_groups); - if (mesh_nn_dims < 2 || !option.allow_mixed_mesh_shape) { - const HloInstruction* operand = ins->operand(0); - - // Create follow strategies - const StrategyGroup& src_strategy_group = *strategy_map.at(operand); - CHECK(!src_strategy_group.is_tuple); - strategy_group->following = &src_strategy_group; - - for (const auto& src_strategy : src_strategy_group.GetStrategies()) { - std::optional output_spec = - hlo_sharding_util::ReshapeSharding(operand->shape(), ins->shape(), - src_strategy.output_sharding); - - if (!output_spec.has_value()) { - continue; - } - - if (!IsValidTileAssignment(*output_spec)) { - continue; - } - - if (!TileAssignmentMatchesMesh(*output_spec, device_mesh)) { - continue; - } - const std::string name = ToStringSimple(*output_spec); - double compute_cost = 0, communication_cost = 0; - double memory_cost = - ByteSizeOfShapeWithSharding(ins->shape(), output_spec); - std::vector communication_resharding_costs = - CommunicationReshardingCostVector( - src_strategy_group, operand->shape(), - src_strategy.output_sharding, cluster_env); - std::vector memory_resharding_costs = - MemoryReshardingCostVector(src_strategy_group, operand->shape(), - src_strategy.output_sharding, cluster_env); - strategy_group->AddStrategy( - ShardingStrategy({name, - *output_spec, - compute_cost, - communication_cost, - memory_cost, - {communication_resharding_costs}, - {memory_resharding_costs}}), - {src_strategy.output_sharding}); + // Create strategies from operands, but do not follow the operand. We + // anecdotally observe that following the operands causes regressions. + const HloInstruction* operand = ins->operand(0); + const StrategyGroup& operand_strategy_group = *strategy_map.at(operand); + CHECK(!operand_strategy_group.is_tuple); + + for (const ShardingStrategy& operand_strategy : + operand_strategy_group.GetStrategies()) { + std::optional output_sharding = + hlo_sharding_util::ReshapeSharding(operand->shape(), ins->shape(), + operand_strategy.output_sharding); + + if (!output_sharding.has_value() || + !IsValidTileAssignment(*output_sharding) || + !TileAssignmentMatchesMesh(*output_sharding, + cluster_env.device_mesh_)) { + continue; } + + const std::string name = ToStringSimple(*output_sharding); + double compute_cost = 0, communication_cost = 0; + double memory_cost = + ByteSizeOfShapeWithSharding(ins->shape(), output_sharding); + std::vector communication_resharding_costs = + CommunicationReshardingCostVector( + operand_strategy_group, operand->shape(), + operand_strategy.output_sharding, cluster_env); + std::vector memory_resharding_costs = MemoryReshardingCostVector( + operand_strategy_group, operand->shape(), + operand_strategy.output_sharding, cluster_env); + strategy_group->AddStrategy( + ShardingStrategy({name, + *output_sharding, + compute_cost, + communication_cost, + memory_cost, + {communication_resharding_costs}, + {memory_resharding_costs}}), + {operand_strategy.output_sharding}); } if (strategy_group->GetStrategies().empty()) { @@ -1750,10 +1741,9 @@ std::unique_ptr CreateReshapeStrategies( FillAllStrategiesForArray( ins, ins->shape(), cluster_env, strategy_map, option, replicated_penalty, call_graph, only_allow_divisible, - /* create_replicated_strategies */ true, - /* create_partially_replicated_strategies */ true, *strategy_group); + /*create_replicated_strategies=*/true, + /*create_partially_replicated_strategies=*/true, *strategy_group); } - return strategy_group; } From d492bf7f8ca7ee73de15bc34e5b8fa2ce672ffc0 Mon Sep 17 00:00:00 2001 From: Vlad Sytchenko Date: Thu, 26 Sep 2024 17:01:50 -0700 Subject: [PATCH 346/483] [XLA] Don't use while/conditional back pointers Specifically HloComputation::WhileCallInstruction() and ConditionalCallInstruction(). These are broken because in too many places we don't update the references between the instruction and computation after cloning. CallGraph::GetComputationCallers() can be used for the same purposes. Refactor InfeedTokenPropagation to use it. PiperOrigin-RevId: 679345390 --- third_party/xla/xla/hlo/ir/hlo_computation.h | 12 + third_party/xla/xla/service/BUILD | 1 + .../xla/service/infeed_token_propagation.cc | 233 ++++++++---------- .../xla/service/infeed_token_propagation.h | 14 ++ 4 files changed, 136 insertions(+), 124 deletions(-) diff --git a/third_party/xla/xla/hlo/ir/hlo_computation.h b/third_party/xla/xla/hlo/ir/hlo_computation.h index e6463f774cf513..dcafd8480207d6 100644 --- a/third_party/xla/xla/hlo/ir/hlo_computation.h +++ b/third_party/xla/xla/hlo/ir/hlo_computation.h @@ -787,17 +787,23 @@ class HloComputation { } // Returns if this computation is a body computation of a while. + [[deprecated( + "This is broken. Use CallGraph::GetComputationCallers() instead")]] bool IsWhileBodyComputation() const { return instruction_type() == InstructionType::kWhile; } // Returns the owning while call instruction, or nullptr if this is not a // while call body computation. + [[deprecated( + "This is broken. Use CallGraph::GetComputationCallers() instead")]] HloInstruction* WhileCallInstruction() const { return instruction_type() == InstructionType::kWhile ? instruction() : nullptr; } + [[deprecated( + "This is broken. Use CallGraph::GetComputationCallers() instead")]] void SetWhileCallInstruction(HloInstruction* while_call_instruction) { CHECK(while_call_instruction != nullptr); CHECK(while_call_instruction->opcode() == HloOpcode::kWhile); @@ -805,17 +811,23 @@ class HloComputation { } // Returns if this computation is a branch computation of a conditional. + [[deprecated( + "This is broken. Use CallGraph::GetComputationCallers() instead")]] bool IsConditionalBranchComputation() const { return instruction_type() == InstructionType::kConditional; } // Returns the owning conditional call instruction, or nullptr if this is not // a conditional branch computation. + [[deprecated( + "This is broken. Use CallGraph::GetComputationCallers() instead")]] HloInstruction* ConditionalCallInstruction() const { return instruction_type() == InstructionType::kConditional ? instruction() : nullptr; } + [[deprecated( + "This is broken. Use CallGraph::GetComputationCallers() instead")]] void SetConditionalCallInstruction( HloInstruction* conditional_call_instruction) { CHECK(conditional_call_instruction != nullptr); diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD index 7277e853f3a131..35d58653848b72 100644 --- a/third_party/xla/xla/service/BUILD +++ b/third_party/xla/xla/service/BUILD @@ -8648,6 +8648,7 @@ cc_library( srcs = ["infeed_token_propagation.cc"], hdrs = ["infeed_token_propagation.h"], deps = [ + ":call_graph", ":hlo_dce", ":tuple_simplifier", "//xla:shape_util", diff --git a/third_party/xla/xla/service/infeed_token_propagation.cc b/third_party/xla/xla/service/infeed_token_propagation.cc index c14fe7e1824086..2959293a42d45e 100644 --- a/third_party/xla/xla/service/infeed_token_propagation.cc +++ b/third_party/xla/xla/service/infeed_token_propagation.cc @@ -17,7 +17,6 @@ limitations under the License. #include #include -#include #include #include "absl/container/flat_hash_set.h" @@ -30,6 +29,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/service/call_graph.h" #include "xla/service/hlo_dce.h" #include "xla/service/tuple_simplifier.h" #include "xla/shape.h" @@ -129,46 +129,47 @@ absl::StatusOr InsertTokenIntoTuple(HloInstruction* tuple, tuple, tuple->shape().tuple_shapes_size() - 1)); return input_token_gte; } +} // namespace -absl::Status CanonicalizeConditionalBranch(HloComputation* branch) { - CHECK(branch->IsConditionalBranchComputation()); - CHECK_EQ(branch->num_parameters(), 1); - - // Tuplify the branch parameter if needed. - HloInstruction* parameter = branch->parameter_instruction(0); - if (!parameter->shape().IsTuple()) { - *parameter->mutable_shape() = - ShapeUtil::MakeTupleShape({parameter->shape()}); - HloInstruction* original = branch->AddInstruction( - HloInstruction::CreateGetTupleElement(parameter, 0)); - TF_RETURN_IF_ERROR(parameter->ReplaceAllUsesWithDifferentShape(original)); - } +absl::Status CanonicalizeConditionalInstruction(HloInstruction* conditional) { + CHECK_EQ(conditional->opcode(), HloOpcode::kConditional); + + for (HloComputation* branch : conditional->branch_computations()) { + // Tuplify the branch parameter if needed. + HloInstruction* parameter = branch->parameter_instruction(0); + if (!parameter->shape().IsTuple()) { + *parameter->mutable_shape() = + ShapeUtil::MakeTupleShape({parameter->shape()}); + HloInstruction* original = branch->AddInstruction( + HloInstruction::CreateGetTupleElement(parameter, 0)); + TF_RETURN_IF_ERROR(parameter->ReplaceAllUsesWithDifferentShape(original)); + } - // Tuplify the branch tuple if needed. - HloInstruction* conditional = branch->ConditionalCallInstruction(); - int64_t branch_operand_idx = conditional->branch_index(branch) + 1; - HloInstruction* branch_tuple = - conditional->mutable_operand(branch_operand_idx); - if (!branch_tuple->shape().IsTuple()) { - branch_tuple = conditional->parent()->AddInstruction( - HloInstruction::CreateTuple({branch_tuple})); - TF_RETURN_IF_ERROR(conditional->ReplaceOperandWithDifferentShape( - branch_operand_idx, branch_tuple)); - } + // Tuplify the branch tuple if needed. + int64_t branch_operand_idx = conditional->branch_index(branch) + 1; + HloInstruction* branch_tuple = + conditional->mutable_operand(branch_operand_idx); + if (!branch_tuple->shape().IsTuple()) { + branch_tuple = conditional->parent()->AddInstruction( + HloInstruction::CreateTuple({branch_tuple})); + TF_RETURN_IF_ERROR(conditional->ReplaceOperandWithDifferentShape( + branch_operand_idx, branch_tuple)); + } - // Explicitly disjoin computation parameters from branch inputs, so we can - // insert tokens into the input tuple. - if (branch_tuple->opcode() == HloOpcode::kParameter) { - branch_tuple = ReconstructTuple(branch_tuple); - TF_RETURN_IF_ERROR( - conditional->ReplaceOperandWith(branch_operand_idx, branch_tuple)); - } + // Explicitly disjoin computation parameters from branch inputs, so we can + // insert tokens into the input tuple. + if (branch_tuple->opcode() == HloOpcode::kParameter) { + branch_tuple = ReconstructTuple(branch_tuple); + TF_RETURN_IF_ERROR( + conditional->ReplaceOperandWith(branch_operand_idx, branch_tuple)); + } - // Explicitly make the root of the branch a tuple. - HloInstruction* root = branch->root_instruction(); - if (root->opcode() != HloOpcode::kTuple) { - root = ReconstructTuple(root); - branch->set_root_instruction(root); + // Explicitly make the root of the branch a tuple. + HloInstruction* root = branch->root_instruction(); + if (root->opcode() != HloOpcode::kTuple) { + root = ReconstructTuple(root); + branch->set_root_instruction(root); + } } // ConditionalCanonicalizer should have already turned the conditional output @@ -185,18 +186,20 @@ absl::Status CanonicalizeConditionalBranch(HloComputation* branch) { return absl::OkStatus(); } -absl::Status CanonicalizeWhileBody(HloComputation* body) { - CHECK(body->IsWhileBodyComputation()); - CHECK_EQ(body->num_parameters(), 1); +absl::Status CanonicalizeWhileInstruction(HloInstruction* loop) { + CHECK_EQ(loop->opcode(), HloOpcode::kWhile); + HloComputation* body = loop->while_body(); + HloComputation* cond = loop->while_condition(); // Tuplify the body parameter if needed. - HloInstruction* parameter = body->parameter_instruction(0); - if (!parameter->shape().IsTuple()) { - *parameter->mutable_shape() = - ShapeUtil::MakeTupleShape({parameter->shape()}); + HloInstruction* body_parameter = body->parameter_instruction(0); + if (!body_parameter->shape().IsTuple()) { + *body_parameter->mutable_shape() = + ShapeUtil::MakeTupleShape({body_parameter->shape()}); HloInstruction* original = body->AddInstruction( - HloInstruction::CreateGetTupleElement(parameter, 0)); - TF_RETURN_IF_ERROR(parameter->ReplaceAllUsesWithDifferentShape(original)); + HloInstruction::CreateGetTupleElement(body_parameter, 0)); + TF_RETURN_IF_ERROR( + body_parameter->ReplaceAllUsesWithDifferentShape(original)); } // Tuplify the body root if needed. @@ -207,8 +210,6 @@ absl::Status CanonicalizeWhileBody(HloComputation* body) { } // Tuplify the condition parameter if needed. - HloInstruction* loop = body->WhileCallInstruction(); - HloComputation* cond = loop->while_condition(); HloInstruction* cond_parameter = cond->parameter_instruction(0); if (!cond_parameter->shape().IsTuple()) { *cond_parameter->mutable_shape() = @@ -258,27 +259,20 @@ absl::Status CanonicalizeWhileBody(HloComputation* body) { return absl::OkStatus(); } -absl::StatusOr> -PropagateTokenThroughConditionalBranch(HloInstruction* instruction, - HloInstruction* input_token, - HloInstruction* output_token) { +absl::Status InfeedTokenPropagation::PropagateTokenThroughConditionalBranch() { // Conditional branches can diverge in inputs, but must converge on outputs. - // Fixup every branch of the conditional, since we have to insert a token - // into each branches root. - HloComputation* comp = instruction->parent(); - HloInstruction* next_instruction = comp->ConditionalCallInstruction(); - for (HloComputation* branch : next_instruction->branch_computations()) { - TF_RETURN_IF_ERROR(CanonicalizeConditionalBranch(branch)); - } + HloComputation* comp = dangling_instruction_->parent(); + dangling_instruction_ = call_graph_->GetComputationCallers(comp)[0]; + CHECK_EQ(dangling_instruction_->opcode(), HloOpcode::kConditional); // Insert the output token into each branch. - for (HloComputation* branch : next_instruction->branch_computations()) { + for (HloComputation* branch : dangling_instruction_->branch_computations()) { HloInstruction* root = branch->root_instruction(); if (branch == comp) { TF_RETURN_IF_ERROR( InsertTokenIntoTuple(root, /*add_token_operand=*/false).status()); - root->AppendOperand(output_token); + root->AppendOperand(output_token_); } else { TF_RETURN_IF_ERROR( InsertTokenIntoTuple(root, /*add_token_operand=*/true).status()); @@ -290,103 +284,90 @@ PropagateTokenThroughConditionalBranch(HloInstruction* instruction, TF_ASSIGN_OR_RETURN( HloInstruction * input_token_gte, InsertTokenIntoTuple(parameter, /*add_token_operand=*/false)); - TF_RETURN_IF_ERROR(input_token->ReplaceAllUsesWith(input_token_gte)); + TF_RETURN_IF_ERROR(input_token_->ReplaceAllUsesWith(input_token_gte)); // Insert the input token into the branch tuple. - int64_t branch_operand_idx = next_instruction->branch_index(comp) + 1; + int64_t branch_operand_idx = dangling_instruction_->branch_index(comp) + 1; HloInstruction* branch_tuple = - next_instruction->mutable_operand(branch_operand_idx); + dangling_instruction_->mutable_operand(branch_operand_idx); TF_ASSIGN_OR_RETURN( HloInstruction * next_input_token_gte, InsertTokenIntoTuple(branch_tuple, /*add_token_operand=*/true)); - TF_RETURN_IF_ERROR(next_instruction->ReplaceOperandWithDifferentShape( + TF_RETURN_IF_ERROR(dangling_instruction_->ReplaceOperandWithDifferentShape( branch_operand_idx, branch_tuple)); - HloInstruction* next_input_token = + input_token_ = branch_tuple->mutable_operand(next_input_token_gte->tuple_index()); // Insert the output token into conditional instruction. TF_ASSIGN_OR_RETURN( - HloInstruction * next_output_token, - InsertTokenIntoTuple(next_instruction, /*add_token_operand=*/false)); + output_token_, + InsertTokenIntoTuple(dangling_instruction_, /*add_token_operand=*/false)); - return std::make_tuple(next_instruction, next_input_token, next_output_token); + return absl::OkStatus(); } -absl::StatusOr> -PropagateTokenThroughWhileBody(HloInstruction* instruction, - HloInstruction* input_token, - HloInstruction* output_token) { +absl::Status InfeedTokenPropagation::PropagateTokenThroughWhileBody() { // While loops need to converge on input and output. - // Fixup the while body. - HloComputation* comp = instruction->parent(); - TF_RETURN_IF_ERROR(CanonicalizeWhileBody(comp)); - HloInstruction* next_instruction = comp->WhileCallInstruction(); + HloComputation* comp = dangling_instruction_->parent(); + dangling_instruction_ = call_graph_->GetComputationCallers(comp)[0]; + CHECK_EQ(dangling_instruction_->opcode(), HloOpcode::kWhile); // Insert the output token into the body root. HloInstruction* root = comp->root_instruction(); TF_RETURN_IF_ERROR( InsertTokenIntoTuple(root, /*add_token_operand=*/false).status()); - root->AppendOperand(output_token); + root->AppendOperand(output_token_); // Insert the input token into the body parameter. HloInstruction* body_parameter = comp->parameter_instruction(0); TF_ASSIGN_OR_RETURN( HloInstruction * input_token_gte, InsertTokenIntoTuple(body_parameter, /*add_token_operand=*/false)); - TF_RETURN_IF_ERROR(input_token->ReplaceAllUsesWith(input_token_gte)); + TF_RETURN_IF_ERROR(input_token_->ReplaceAllUsesWith(input_token_gte)); // Insert the input token into the condition parameter. - HloComputation* cond = next_instruction->while_condition(); + HloComputation* cond = dangling_instruction_->while_condition(); HloInstruction* cond_parameter = cond->parameter_instruction(0); TF_RETURN_IF_ERROR( InsertTokenIntoTuple(cond_parameter, /*add_token_operand=*/false) .status()); // Insert the input token into the while tuple. - HloInstruction* while_tuple = next_instruction->mutable_operand(0); + HloInstruction* while_tuple = dangling_instruction_->mutable_operand(0); TF_ASSIGN_OR_RETURN( - HloInstruction * next_input_token, + input_token_, InsertTokenIntoTuple(while_tuple, /*add_token_operand=*/true)); TF_RETURN_IF_ERROR( - next_instruction->ReplaceOperandWithDifferentShape(0, while_tuple)); + dangling_instruction_->ReplaceOperandWithDifferentShape(0, while_tuple)); // Insert the input token into the while instruction. TF_ASSIGN_OR_RETURN( - HloInstruction * next_output_token, - InsertTokenIntoTuple(next_instruction, /*add_token_operand=*/false)); + output_token_, + InsertTokenIntoTuple(dangling_instruction_, /*add_token_operand=*/false)); - return std::make_tuple(next_instruction, next_input_token, next_output_token); + return absl::OkStatus(); } -absl::Status PropagateToken(HloInstruction* instruction, - HloInstruction* input_token, - HloInstruction* output_token) { - HloComputation* comp = instruction->parent(); +absl::Status InfeedTokenPropagation::PropagateToken() { + HloComputation* comp = dangling_instruction_->parent(); if (comp->IsEntryComputation()) { return absl::OkStatus(); } + VLOG(2) << "Propagating tokens for: " << dangling_instruction_->name(); - HloInstruction* next_instruction = nullptr; - HloInstruction* next_input_token = nullptr; - HloInstruction* next_output_token = nullptr; - if (comp->IsConditionalBranchComputation()) { - // TODO: b/368327832 - Skip handling sharding until it is removed. - if (comp->ConditionalCallInstruction()->has_sharding()) { - return absl::OkStatus(); - } - TF_ASSIGN_OR_RETURN( - std::tie(next_instruction, next_input_token, next_output_token), - PropagateTokenThroughConditionalBranch(instruction, input_token, - output_token)); - } else if (comp->IsWhileBodyComputation()) { - // TODO: b/368327832 - Skip handling sharding until it is removed. - if (comp->WhileCallInstruction()->has_sharding()) { - return absl::OkStatus(); - } - TF_ASSIGN_OR_RETURN( - std::tie(next_instruction, next_input_token, next_output_token), - PropagateTokenThroughWhileBody(instruction, input_token, output_token)); + HloInstruction* caller = call_graph_->GetComputationCallers(comp)[0]; + // TODO: b/368327832 - Skip handling sharding until it is removed. + if (caller->has_sharding()) { + return absl::OkStatus(); + } + if (caller->opcode() == HloOpcode::kConditional) { + TF_RETURN_IF_ERROR(CanonicalizeConditionalInstruction(caller)); + TF_RETURN_IF_ERROR(PropagateTokenThroughConditionalBranch()); + } else if (caller->opcode() == HloOpcode::kWhile && + comp == caller->while_body()) { + TF_RETURN_IF_ERROR(CanonicalizeWhileInstruction(caller)); + TF_RETURN_IF_ERROR(PropagateTokenThroughWhileBody()); } else { // We only expect to encounter computations behind while and conditional // instructions. In the case of it being behind a while condition, there is @@ -396,13 +377,9 @@ absl::Status PropagateToken(HloInstruction* instruction, VLOG(2) << "Unhandled computation: " << comp->name(); return absl::OkStatus(); } - CHECK_NE(next_instruction, nullptr); - CHECK_NE(next_input_token, nullptr); - CHECK_NE(next_output_token, nullptr); - return PropagateToken(next_instruction, next_input_token, next_output_token); + return PropagateToken(); } -} // namespace absl::StatusOr InfeedTokenPropagation::Run( HloModule* module, @@ -428,22 +405,30 @@ absl::StatusOr InfeedTokenPropagation::Run( } } } + bool changed = !dangling_infeeds.empty() || !dangling_outfeeds.empty(); + + if (changed) { + call_graph_ = CallGraph::Build(module); + if (!call_graph_->IsFlattened()) { + return FailedPrecondition( + "Call graph must be flattened before infeed token propagation."); + } + } for (HloInstruction* dangling_infeed : dangling_infeeds) { - HloInstruction* input_token = dangling_infeed->mutable_operand(0); - HloInstruction* output_token = dangling_infeed->AddInstruction( + dangling_instruction_ = dangling_infeed; + input_token_ = dangling_infeed->mutable_operand(0); + output_token_ = dangling_infeed->AddInstruction( HloInstruction::CreateGetTupleElement(dangling_infeed, 1)); - TF_RETURN_IF_ERROR( - PropagateToken(dangling_infeed, input_token, output_token)); + TF_RETURN_IF_ERROR(PropagateToken()); } for (HloInstruction* dangling_outfeed : dangling_outfeeds) { - HloInstruction* input_token = dangling_outfeed->mutable_operand(1); - HloInstruction* output_token = dangling_outfeed; - TF_RETURN_IF_ERROR( - PropagateToken(dangling_outfeed, input_token, output_token)); + dangling_instruction_ = dangling_outfeed; + input_token_ = dangling_outfeed->mutable_operand(1); + output_token_ = dangling_outfeed; + TF_RETURN_IF_ERROR(PropagateToken()); } - bool changed = !dangling_infeeds.empty() || !dangling_outfeeds.empty(); if (changed) { TF_RETURN_IF_ERROR( TupleSimplifier().Run(module, execution_threads).status()); diff --git a/third_party/xla/xla/service/infeed_token_propagation.h b/third_party/xla/xla/service/infeed_token_propagation.h index cc6994a62a98a9..eeed81d4da9477 100644 --- a/third_party/xla/xla/service/infeed_token_propagation.h +++ b/third_party/xla/xla/service/infeed_token_propagation.h @@ -16,13 +16,17 @@ limitations under the License. #ifndef XLA_SERVICE_INFEED_TOKEN_PROPAGATION_H_ #define XLA_SERVICE_INFEED_TOKEN_PROPAGATION_H_ +#include #include +#include #include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/pass/hlo_pass_interface.h" +#include "xla/service/call_graph.h" namespace xla { // Finds dangling infeed/outfeed tokens inside nested computations and bubbles @@ -39,6 +43,16 @@ class InfeedTokenPropagation : public HloModulePass { absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; + + private: + absl::Status PropagateToken(); + absl::Status PropagateTokenThroughWhileBody(); + absl::Status PropagateTokenThroughConditionalBranch(); + + std::unique_ptr call_graph_; + HloInstruction* dangling_instruction_ = nullptr; + HloInstruction* input_token_ = nullptr; + HloInstruction* output_token_ = nullptr; }; } // namespace xla From 46e0c13201d91a61d8e179bc9426436fb89b1b53 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 26 Sep 2024 17:13:04 -0700 Subject: [PATCH 347/483] #tf-data-service Fix null pointer access. PiperOrigin-RevId: 679348546 --- .../core/kernels/data/experimental/data_service_dataset_op.cc | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tensorflow/core/kernels/data/experimental/data_service_dataset_op.cc b/tensorflow/core/kernels/data/experimental/data_service_dataset_op.cc index a46aa2937b6f1e..24d4fa8899d9e9 100644 --- a/tensorflow/core/kernels/data/experimental/data_service_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/data_service_dataset_op.cc @@ -342,7 +342,9 @@ class DataServiceDatasetOp::Dataset : public DatasetBase { ctx->cancellation_manager(), [this]() { data_service_client_.Cancel(); }, &deregister_fn_)); tsl::AllocatorAttributes attrs; - attrs.set_gpu_compatible(ctx->options()->service_options().pinned()); + if (ctx->options() != nullptr) { + attrs.set_gpu_compatible(ctx->options()->service_options().pinned()); + } return data_service_client_.Initialize(ctx->accelerator_device_info(), ctx->allocator(attrs)); } From 6a4f1d7f888deb3e3d6e3bc80cfbd66231ea4638 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 26 Sep 2024 17:15:36 -0700 Subject: [PATCH 348/483] Better naming for loop fusion. PiperOrigin-RevId: 679349330 --- third_party/xla/xla/hlo/ir/hlo_instruction.cc | 6 ++++-- third_party/xla/xla/hlo/ir/hlo_instruction.h | 3 ++- third_party/xla/xla/hlo/ir/hlo_instructions.cc | 6 ++++-- third_party/xla/xla/hlo/ir/hlo_instructions.h | 3 ++- third_party/xla/xla/service/instruction_fusion.cc | 7 +++++-- .../xla/xla/service/propagate_original_value_test.cc | 2 +- 6 files changed, 18 insertions(+), 9 deletions(-) diff --git a/third_party/xla/xla/hlo/ir/hlo_instruction.cc b/third_party/xla/xla/hlo/ir/hlo_instruction.cc index 4b200a2c831c41..c0fec2434de508 100644 --- a/third_party/xla/xla/hlo/ir/hlo_instruction.cc +++ b/third_party/xla/xla/hlo/ir/hlo_instruction.cc @@ -2174,8 +2174,10 @@ HloInstruction::CreateDynamicReshape( } /* static */ std::unique_ptr HloInstruction::CreateFusion( - const Shape& shape, FusionKind fusion_kind, HloInstruction* fused_root) { - return std::make_unique(shape, fusion_kind, fused_root); + const Shape& shape, FusionKind fusion_kind, HloInstruction* fused_root, + absl::string_view prefix) { + return std::make_unique(shape, fusion_kind, fused_root, + prefix); } /* static */ std::unique_ptr HloInstruction::CreateFusion( diff --git a/third_party/xla/xla/hlo/ir/hlo_instruction.h b/third_party/xla/xla/hlo/ir/hlo_instruction.h index 3ef42bfc41adc6..49bd9698605941 100644 --- a/third_party/xla/xla/hlo/ir/hlo_instruction.h +++ b/third_party/xla/xla/hlo/ir/hlo_instruction.h @@ -1347,7 +1347,8 @@ class HloInstruction { // "fused_root". Additional instructions can be added to the fusion // instruction with the method FuseInstruction. static std::unique_ptr CreateFusion( - const Shape& shape, FusionKind fusion_kind, HloInstruction* fused_root); + const Shape& shape, FusionKind fusion_kind, HloInstruction* fused_root, + absl::string_view prefix = ""); static std::unique_ptr CreateFusion( const Shape& shape, FusionKind fusion_kind, diff --git a/third_party/xla/xla/hlo/ir/hlo_instructions.cc b/third_party/xla/xla/hlo/ir/hlo_instructions.cc index 3909ced6f00f92..3960414f1ac73d 100644 --- a/third_party/xla/xla/hlo/ir/hlo_instructions.cc +++ b/third_party/xla/xla/hlo/ir/hlo_instructions.cc @@ -2172,11 +2172,13 @@ void HloCallableInstruction::RecursivelySetComputationsThreadName( HloFusionInstruction::HloFusionInstruction(const Shape& shape, FusionKind fusion_kind, - HloInstruction* fused_root) + HloInstruction* fused_root, + absl::string_view prefix) : HloCallableInstruction(HloOpcode::kFusion, shape), fusion_kind_(fusion_kind) { CHECK(fused_root != nullptr); - SetAndSanitizeName(HloOpcodeString(opcode())); + SetAndSanitizeName(absl::StrCat(prefix, HloOpcodeString(opcode()))); + set_parent(fused_root->parent()); set_metadata(fused_root->metadata()); set_frontend_attributes(fused_root->frontend_attributes()); diff --git a/third_party/xla/xla/hlo/ir/hlo_instructions.h b/third_party/xla/xla/hlo/ir/hlo_instructions.h index 8a0fca0bade44d..8768150f793edd 100644 --- a/third_party/xla/xla/hlo/ir/hlo_instructions.h +++ b/third_party/xla/xla/hlo/ir/hlo_instructions.h @@ -1439,7 +1439,8 @@ class HloCallableInstruction : public HloInstruction { class HloFusionInstruction : public HloCallableInstruction { public: explicit HloFusionInstruction(const Shape& shape, FusionKind fusion_kind, - HloInstruction* fused_root); + HloInstruction* fused_root, + absl::string_view prefix = ""); explicit HloFusionInstruction(const Shape& shape, FusionKind fusion_kind, absl::Span operands, diff --git a/third_party/xla/xla/service/instruction_fusion.cc b/third_party/xla/xla/service/instruction_fusion.cc index 9fb259061e5de6..3e9240c9c205aa 100644 --- a/third_party/xla/xla/service/instruction_fusion.cc +++ b/third_party/xla/xla/service/instruction_fusion.cc @@ -718,8 +718,11 @@ HloInstruction* InstructionFusion::AddFusionInstruction( fusion_instruction->set_fusion_kind(kind); } } else { - fusion_instruction = computation->AddInstruction( - HloInstruction::CreateFusion(consumer->shape(), kind, consumer)); + fusion_instruction = + computation->AddInstruction(HloInstruction::CreateFusion( + consumer->shape(), kind, consumer, + absl::StrCat(HloOpcodeString(producer->opcode()), "_", + HloOpcodeString(consumer->opcode()), "_"))); TF_CHECK_OK(computation->ReplaceInstruction(consumer, fusion_instruction)); } fusion_instruction->set_called_computations_execution_thread( diff --git a/third_party/xla/xla/service/propagate_original_value_test.cc b/third_party/xla/xla/service/propagate_original_value_test.cc index 764786ceef23dd..e1077695c3848c 100644 --- a/third_party/xla/xla/service/propagate_original_value_test.cc +++ b/third_party/xla/xla/service/propagate_original_value_test.cc @@ -62,7 +62,7 @@ CHECK: ROOT %[[ADD:.*]] = u32[2]{0:T(256)} add(%[[PAD]], %[[PAD1]]), origin={{ CHECK: ENTRY %test CHECK: %Arg_0 = s32[]{:T(256)} parameter(0), origin={{[{]}}{"Arg_0"} -CHECK: ROOT %fusion = u32[2]{0:T(256)} fusion(%Arg_0), kind=kLoop, calls=%fused_computation +CHECK: ROOT %pad_add_fusion = u32[2]{0:T(256)} fusion(%Arg_0), kind=kLoop, calls=%fused_computation )"); } From 75379dcaf9d2211e389f7cfd5795b6d32475889c Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 26 Sep 2024 17:27:32 -0700 Subject: [PATCH 349/483] Calculate different flops for different nvidia gpu. Apply device flop adjustment for hlo op profiles. Calculate nvidia gpu shared memory bandwidth to be used in roofline analysis. PiperOrigin-RevId: 679353085 --- tensorflow/core/profiler/convert/BUILD | 2 +- .../profiler/convert/xplane_to_op_stats.cc | 24 +- .../convert/xplane_to_op_stats_test.cc | 8 +- tensorflow/core/profiler/utils/BUILD | 48 +++ .../core/profiler/utils/device_caps_utils.cc | 1 - .../profiler/utils/hardware_type_utils.cc | 318 ++++++++++++++---- .../core/profiler/utils/hardware_type_utils.h | 38 +++ .../utils/hardware_type_utils_test.cc | 66 ++++ .../profiler/utils/xprof_gpu_cost_analysis.cc | 98 ++++++ .../profiler/utils/xprof_gpu_cost_analysis.h | 53 +++ .../utils/xprof_gpu_cost_analysis_test.cc | 138 ++++++++ .../tsl/tsl/profiler/utils/xplane_schema.cc | 1 + .../tsl/tsl/profiler/utils/xplane_schema.h | 4 +- .../backends/profiler/gpu/cupti_collector.cc | 18 + 14 files changed, 736 insertions(+), 81 deletions(-) create mode 100644 tensorflow/core/profiler/utils/hardware_type_utils_test.cc create mode 100644 tensorflow/core/profiler/utils/xprof_gpu_cost_analysis.cc create mode 100644 tensorflow/core/profiler/utils/xprof_gpu_cost_analysis.h create mode 100644 tensorflow/core/profiler/utils/xprof_gpu_cost_analysis_test.cc diff --git a/tensorflow/core/profiler/convert/BUILD b/tensorflow/core/profiler/convert/BUILD index 82538b6392aa34..b1cafe647a323e 100644 --- a/tensorflow/core/profiler/convert/BUILD +++ b/tensorflow/core/profiler/convert/BUILD @@ -326,7 +326,6 @@ cc_library( "//tensorflow/core/profiler/utils:hardware_type_utils", "//tensorflow/core/profiler/utils:hlo_proto_map", "//tensorflow/core/profiler/utils:kernel_stats_utils", - "//tensorflow/core/profiler/utils:math_utils", "//tensorflow/core/profiler/utils:xplane_schema", "//tensorflow/core/profiler/utils:xplane_utils", "//tensorflow/core/profiler/utils:xplane_visitor", @@ -334,6 +333,7 @@ cc_library( "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings", "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", + "@local_tsl//tsl/profiler/utils:math_utils", "@local_tsl//tsl/profiler/utils:tf_xplane_visitor", "@local_tsl//tsl/profiler/utils:tpu_xplane_utils", "@local_tsl//tsl/profiler/utils:xplane_schema", diff --git a/tensorflow/core/profiler/convert/xplane_to_op_stats.cc b/tensorflow/core/profiler/convert/xplane_to_op_stats.cc index 21326dacbf98df..5774010604b4d0 100644 --- a/tensorflow/core/profiler/convert/xplane_to_op_stats.cc +++ b/tensorflow/core/profiler/convert/xplane_to_op_stats.cc @@ -38,11 +38,11 @@ limitations under the License. #include "tensorflow/core/profiler/utils/hardware_type_utils.h" #include "tensorflow/core/profiler/utils/hlo_proto_map.h" #include "tensorflow/core/profiler/utils/kernel_stats_utils.h" -#include "tensorflow/core/profiler/utils/math_utils.h" #include "tensorflow/core/profiler/utils/xplane_schema.h" #include "tensorflow/core/profiler/utils/xplane_utils.h" #include "tensorflow/core/profiler/utils/xplane_visitor.h" #include "tsl/profiler/protobuf/xplane.pb.h" +#include "tsl/profiler/utils/math_utils.h" #include "tsl/profiler/utils/tf_xplane_visitor.h" #include "tsl/profiler/utils/tpu_xplane_utils.h" #include "tsl/profiler/utils/xplane_schema.h" @@ -80,14 +80,20 @@ PerfEnv MakePerfEnv(double peak_tera_flops_per_second, PerfEnv GetPerfEnvFromXPlane(const XPlane& device_plane) { DeviceCapabilities cap = GetDeviceCaps(device_plane); if (!absl::StartsWith(device_plane.name(), kTpuPlanePrefix)) { - return MakePerfEnv( - tsl::profiler::GigaToTera(GetFlopMaxThroughputPerSM(cap)) * - cap.num_cores(), - // Ideally, the cap should report separate hbm BW, for now set to same. - {tsl::profiler::UniToGiga(cap.memory_bandwidth()), - tsl::profiler::UniToGiga(cap.memory_bandwidth()), - tsl::profiler::UniToGiga(cap.memory_bandwidth()), - tsl::profiler::UniToGiga(cap.memory_bandwidth())}); + double peak_tera_flops_per_second = + cap.num_cores() * + tsl::profiler::GigaToTera(GetFlopMaxThroughputPerSM(cap)); + double hbm_bw_giga_bytes_per_second = + tsl::profiler::UniToGiga(cap.memory_bandwidth()); + double shm_giga_bytes_per_second = + cap.num_cores() * + tsl::profiler::UniToGiga(GetSharedMemoryBandwidthPerSM(cap)); + // Note that treat SRAM_RD and SRAM_WR as the same. So in future, we could + // only use one for shared memory / L1 cache, one for another like L2. + return MakePerfEnv(peak_tera_flops_per_second, + {/*HBM_RW=*/hbm_bw_giga_bytes_per_second, + /*SRAM_RD=*/shm_giga_bytes_per_second, + /*SRAM_WR=*/shm_giga_bytes_per_second}); } else { XPlaneVisitor visitor = tsl::profiler::CreateTfXPlaneVisitor(&device_plane); auto peak_tera_flops_per_second = diff --git a/tensorflow/core/profiler/convert/xplane_to_op_stats_test.cc b/tensorflow/core/profiler/convert/xplane_to_op_stats_test.cc index 64971ef99cbc25..01bfb6c9c2f575 100644 --- a/tensorflow/core/profiler/convert/xplane_to_op_stats_test.cc +++ b/tensorflow/core/profiler/convert/xplane_to_op_stats_test.cc @@ -86,12 +86,16 @@ TEST(ConvertXPlaneToOpStats, GpuPerfEnv) { TF_CHECK_OK(ConvertMultiXSpacesToCombinedOpStats(session_snapshot_or.value(), options, &op_stats)); const PerfEnv& perf_env = op_stats.perf_env(); - EXPECT_NEAR(141, perf_env.peak_tera_flops_per_second(), kMaxError); + // Change to lower flops number that we do not use sum of the tensor core peak + // flops and the cuda core peak flops together as peak flops. Only use the + // tensor core peak flops as all those white papers are using. + EXPECT_NEAR(125.34, perf_env.peak_tera_flops_per_second(), kMaxError); EXPECT_NEAR( 900, perf_env.peak_bws_giga_bytes_per_second(MemBwType::MEM_BW_TYPE_HBM_RW), kMaxError); - EXPECT_NEAR(156.67, perf_env.ridge_point(), kMaxError); + // Ridge point changed accordingly from above peak flops change. + EXPECT_NEAR(139.26, perf_env.ridge_point(), kMaxError); } TEST(ConvertXPlaneToOpStats, GpuRunEnvironment) { diff --git a/tensorflow/core/profiler/utils/BUILD b/tensorflow/core/profiler/utils/BUILD index 3a0164bee093c6..258c83545fe55e 100644 --- a/tensorflow/core/profiler/utils/BUILD +++ b/tensorflow/core/profiler/utils/BUILD @@ -56,7 +56,21 @@ cc_library( ":xplane_schema", "//tensorflow/core:lib", "//tensorflow/core/profiler/protobuf:hardware_types_proto_cc", + "@com_google_absl//absl/container:btree", "@com_google_absl//absl/strings", + "@local_tsl//tsl/profiler/utils:math_utils", + ], +) + +tf_cc_test( + name = "hardware_type_utils_test", + srcs = ["hardware_type_utils_test.cc"], + deps = [ + ":hardware_type_utils", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "@local_tsl//tsl/profiler/utils:math_utils", ], ) @@ -450,3 +464,37 @@ cc_library( hdrs = ["hlo_module_utils.h"], deps = ["@local_xla//xla/hlo/ir:hlo"], ) + +cc_library( + name = "xprof_gpu_cost_analysis", + srcs = ["xprof_gpu_cost_analysis.cc"], + hdrs = ["xprof_gpu_cost_analysis.h"], + visibility = [":friends"], + deps = [ + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:string_view", + "@local_xla//xla:shape_util", + "@local_xla//xla/hlo/ir:hlo", + "@local_xla//xla/service:hlo_cost_analysis", + "@local_xla//xla/service/gpu/model:gpu_hlo_cost_analysis", + ], +) + +tf_cc_test( + name = "xprof_gpu_cost_analysis_test", + srcs = ["xprof_gpu_cost_analysis_test.cc"], + deps = [ + ":xprof_gpu_cost_analysis", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest", + "@local_tsl//tsl/platform:statusor", + "@local_xla//xla:shape_util", + "@local_xla//xla:test_helpers", + "@local_xla//xla:xla_data_proto_cc", + "@local_xla//xla/hlo/ir:hlo", + "@local_xla//xla/service:hlo_cost_analysis", + "@local_xla//xla/service/gpu/model:hlo_op_profiles", + "@local_xla//xla/tests:hlo_test_base", + "@local_xla//xla/tests:xla_internal_test_main", + ], +) diff --git a/tensorflow/core/profiler/utils/device_caps_utils.cc b/tensorflow/core/profiler/utils/device_caps_utils.cc index 5e8edea62493f8..e795081311f28c 100644 --- a/tensorflow/core/profiler/utils/device_caps_utils.cc +++ b/tensorflow/core/profiler/utils/device_caps_utils.cc @@ -81,7 +81,6 @@ DeviceCapabilities GetDeviceCaps(const XPlane& plane) { break; } }); - return caps; } diff --git a/tensorflow/core/profiler/utils/hardware_type_utils.cc b/tensorflow/core/profiler/utils/hardware_type_utils.cc index ad3682ad99289c..35bc8a9def667b 100644 --- a/tensorflow/core/profiler/utils/hardware_type_utils.cc +++ b/tensorflow/core/profiler/utils/hardware_type_utils.cc @@ -15,92 +15,276 @@ limitations under the License. #include "tensorflow/core/profiler/utils/hardware_type_utils.h" +#include + +#include "absl/container/btree_map.h" #include "absl/strings/match.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/profiler/protobuf/hardware_types.pb.h" #include "tensorflow/core/profiler/utils/xplane_schema.h" +#include "tsl/profiler/utils/math_utils.h" namespace tensorflow { namespace profiler { namespace { -// Get theoretical upperbound of single precision FMA throughput of the GPU per -// cycle per streaming multiprocessor. -// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#arithmetic-instructions__throughput-native-arithmetic-instructions -uint32 GetFmaMaxThroughputPerSMPerCycle(const DeviceCapabilities& device_cap) { - if (device_cap.device_vendor() == kDeviceVendorNvidia) { - uint32 n_fp32_cores = 0; - uint32 n_tc_cores = 0; - switch (device_cap.compute_capability().major()) { - case 2: - // Fermi - n_fp32_cores = 32; - break; - case 3: - // Kepler - n_fp32_cores = 192; - break; - case 5: - // Maxwell - n_fp32_cores = 128; - break; - case 6: - // Pascal - if (device_cap.compute_capability().minor() > 0) { - // Pascal SM61/62 - n_fp32_cores = 128; - } else { - // Pascal SM60 - n_fp32_cores = 64; - } - break; - case 7: - // Volta and Turing - n_fp32_cores = 64; - n_tc_cores = 8; - break; - case 8: - // Ampere - if (device_cap.compute_capability().minor() >= 6) { - // Ampere SM86 - n_fp32_cores = 128; - } else { - // Ampere SM80 - n_fp32_cores = 64; - } - n_tc_cores = 4; - break; - default: - LOG(ERROR) << "Invalid GPU compute capability."; - break; - } - // GPU TensorCore can execute 64 FMAs per cycle. - // https://devblogs.nvidia.com/programming-tensor-cores-cuda-9/ - return n_fp32_cores + n_tc_cores * 64; - } else if (device_cap.device_vendor() == kDeviceVendorAMD) { - uint32_t n_xdlops = 0; - uint32_t n_fp32_cores = 0; +// The calculation methods is referred from Nvidia developer forum: +// https://forums.developer.nvidia.com/t/how-to-calculate-the-tensor-core-fp16-performance-of-h100/244727 +// Below data are calculated from the various NVidia whitepapers/specs. + +// https://resources.nvidia.com/en-us-tensor-core/gtc22-whitepaper-hopper +const GpuFlopCapabilities kComputeCap_PerSM_PerCycle_9_0 = { + .cuda_core = + { + .fp64_tflops = 128, + .fp32_tflops = 256, + .bf16_tflops = 512, + .fp16_tflops = 512, + .int8_tops = 1024, + }, + .tensor_core = + { + .fp64_tflops = 256, + .fp32_tflops = 2048, + .bf16_tflops = 4096, + .fp16_tflops = 4096, + .fp8_tflops = 8192, + .int8_tops = 8192, + }, + .has_tensor_core_sparsity_support = true, +}; + +// https://images.nvidia.com/aem-dam/Solutions/geforce/ada/nvidia-ada-gpu-architecture.pdf +const GpuFlopCapabilities kComputeCap_PerSM_PerCycle_8_9 = { + .cuda_core = + { + .fp64_tflops = 128, + .fp32_tflops = 256, + .bf16_tflops = 256, + .fp16_tflops = 256, + .int8_tops = 512, + }, + .tensor_core = + { + .fp32_tflops = 512, + .bf16_tflops = 1024, + .fp16_tflops = 1024, + .fp8_tflops = 2048, + .int8_tops = 2048, + .int4_tops = 4096, + }, + .has_tensor_core_sparsity_support = true, +}; + +// https://www.nvidia.com/content/PDF/nvidia-ampere-ga-102-gpu-architecture-whitepaper-v2.1.pdf +const GpuFlopCapabilities kComputeCap_PerSM_PerCycle_8_6 = { + .cuda_core = + { + .fp64_tflops = 128, + .fp32_tflops = 256, + .bf16_tflops = 256, + .fp16_tflops = 256, + .int8_tops = 512, + }, + .tensor_core = + { + .fp32_tflops = 256, + .bf16_tflops = 512, + .fp16_tflops = 1024, + .int8_tops = 2048, + .int4_tops = 4096, + }, + .has_tensor_core_sparsity_support = true, +}; - if (device_cap.compute_capability().major() <= 9) { - n_fp32_cores = 64; - } else { - n_fp32_cores = 32; +// https://www.nvidia.com/content/PDF/nvidia-ampere-ga-102-gpu-architecture-whitepaper-v2.1.pdf +const GpuFlopCapabilities kComputeCap_PerSM_PerCycle_8_0 = { + .cuda_core = + { + .fp64_tflops = 64, + .fp32_tflops = 128, + .bf16_tflops = 256, + .fp16_tflops = 512, + .int8_tops = 512, + }, + .tensor_core = + { + .fp64_tflops = 128, + .fp32_tflops = 1024, + .bf16_tflops = 2048, + .fp16_tflops = 2048, + .int8_tops = 4096, + }, + .has_tensor_core_sparsity_support = true, +}; + +// https://images.nvidia.com/aem-dam/en-zz/Solutions/design-visualization/technologies/turing-architecture/NVIDIA-Turing-Architecture-Whitepaper.pdf +const GpuFlopCapabilities kComputeCap_PerSM_PerCycle_7_5 = { + .cuda_core = + { + .fp64_tflops = 64, + .fp32_tflops = 128, + .fp16_tflops = 256, + .int8_tops = 512, + }, + .tensor_core = + { + .fp16_tflops = 1024, + .int8_tops = 2048, + .int4_tops = 4096, + }, + .has_tensor_core_sparsity_support = false, +}; + +// https://images.nvidia.com/content/volta-architecture/pdf/volta-architecture-whitepaper.pdf +const GpuFlopCapabilities kComputeCap_PerSM_PerCycle_7_0 = { + .cuda_core = + { + .fp64_tflops = 64, + .fp32_tflops = 128, + .bf16_tflops = 0.0, + .fp16_tflops = 256, + .int8_tops = 512, + }, + .tensor_core = + { + .fp16_tflops = 1024, + }, + .has_tensor_core_sparsity_support = false, +}; + +// https://images.nvidia.com/content/pdf/tesla/whitepaper/pascal-architecture-whitepaper.pdf +const GpuFlopCapabilities kComputeCap_PerSM_PerCycle_6_1 = { + .cuda_core = + { + .fp64_tflops = 8, + .fp32_tflops = 256, + .fp16_tflops = 4, + .int8_tops = 1024, + }, + .tensor_core = {}, + .has_tensor_core_sparsity_support = false, +}; + +// https://images.nvidia.com/content/pdf/tesla/whitepaper/pascal-architecture-whitepaper.pdf +const GpuFlopCapabilities kComputeCap_PerSM_PerCycle_6_0 = { + .cuda_core = + { + .fp64_tflops = 64, + .fp32_tflops = 128, + .fp16_tflops = 256, + .int8_tops = 512, + }, + .tensor_core = {}, + .has_tensor_core_sparsity_support = false, +}; + +// https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/tesla-product-literature/NVIDIA-Kepler-GK110-GK210-Architecture-Whitepaper.pdf +const GpuFlopCapabilities kComputeCap_PerSM_PerCycle_5_0 = { + .cuda_core = + { + .fp64_tflops = 4, + .fp32_tflops = 256, + }, + .tensor_core = {}, + .has_tensor_core_sparsity_support = false, +}; + +// https://www.nvidia.com/content/PDF/product-specifications/GeForce_GTX_680_Whitepaper_FINAL.pdf +const GpuFlopCapabilities kComputeCap_PerSM_PerCycle_3_0 = { + .cuda_core = + { + .fp64_tflops = 128, + .fp32_tflops = 384, + }, + .tensor_core = {}, + .has_tensor_core_sparsity_support = false, +}; + +const GpuFlopCapabilities kComputeCap_PerSM_PerCycle_2_0 = { + .cuda_core = + { + .fp64_tflops = 8, + .fp32_tflops = 64, + }, + .tensor_core = {}, + .has_tensor_core_sparsity_support = false, +}; + +GpuFlopCapabilities GetNvidiaFlopCapsPerSMPerCycle(int major_comp_cap, + int minor_comp_cap) { + static const auto& kPerSMFlopCapsTable = + *new absl::btree_map{ + // TODO: Add incoming blackwell, and other old GPUS + {9000, &kComputeCap_PerSM_PerCycle_9_0}, + {8090, &kComputeCap_PerSM_PerCycle_8_9}, + {8060, &kComputeCap_PerSM_PerCycle_8_6}, + {8000, &kComputeCap_PerSM_PerCycle_8_0}, + {7050, &kComputeCap_PerSM_PerCycle_7_5}, + {7000, &kComputeCap_PerSM_PerCycle_7_0}, + {6010, &kComputeCap_PerSM_PerCycle_6_1}, + {6000, &kComputeCap_PerSM_PerCycle_6_0}, + {5000, &kComputeCap_PerSM_PerCycle_5_0}, + {3000, &kComputeCap_PerSM_PerCycle_3_0}, + {2000, &kComputeCap_PerSM_PerCycle_2_0}, + }; + + const int normalized_compute_cap = + major_comp_cap * 1000 + minor_comp_cap * 10; + GpuFlopCapabilities flops_cap{}; + auto it = kPerSMFlopCapsTable.lower_bound(normalized_compute_cap); + if (it == kPerSMFlopCapsTable.end()) { + LOG(WARNING) << "GPU compute capability " << major_comp_cap << "." + << minor_comp_cap << " is too old to support."; + } else { + flops_cap = *it->second; + if (it->first != normalized_compute_cap) { + LOG(WARNING) << "GPU compute capability " << major_comp_cap << "." + << minor_comp_cap + << " is not found. Use the highest compute cap known " + << (it->first / 1000) << "." << ((it->first % 1000) / 10) + << " instead."; } - // TODO(rocm-profiler): verify with new devices - return n_fp32_cores + n_xdlops * 1; + } + return flops_cap; +} + +GpuFlopCapabilities GetGpuFlopCapabilitiesPerSM( + const DeviceCapabilities& device_cap) { + GpuFlopCapabilities flops_cap{}; + if (device_cap.device_vendor() == kDeviceVendorNvidia) { + flops_cap = + GetNvidiaFlopCapsPerSMPerCycle(device_cap.compute_capability().major(), + device_cap.compute_capability().minor()); } else { - LOG(ERROR) << "Unknown device vendor " << device_cap.device_vendor(); - return 0; + LOG(WARNING) << "Unsupported device vendor " << device_cap.device_vendor(); } + + flops_cap.ScaleWith(device_cap.clock_rate_in_ghz()); + return flops_cap; } } // namespace double GetFlopMaxThroughputPerSM(const DeviceCapabilities& device_cap) { - // One FMA = 2 floating point operations, one multiply and one add. - return GetFmaMaxThroughputPerSMPerCycle(device_cap) * 2 * - device_cap.clock_rate_in_ghz(); + GpuFlopCapabilities sm_flops = GetGpuFlopCapabilitiesPerSM(device_cap); + double result = std::max( + {sm_flops.cuda_core.fp32_tflops, sm_flops.cuda_core.fp16_tflops, + sm_flops.tensor_core.fp32_tflops, sm_flops.tensor_core.fp16_tflops}); + VLOG(3) << "GetFlopMaxThroughputPerSM get result: " << result << " GFLOPs"; + return result; +} + +double GetSharedMemoryBandwidthPerSM(const DeviceCapabilities& device_cap) { + // https://docs.nvidia.com/gameworks/content/developertools/desktop/analysis/report/cudaexperiments/kernellevel/memorystatisticsshared.htm + // Compute capability 2.0, each bank has bandwidth of 4 bytes per 2 cycles. + // For compute capability 3.0 and above, each bank has bandwidth 8 bytes per + // cycle. Each SM has 32 banks. + double transaction_byts_per_cycle = + device_cap.compute_capability().major() <= 2 ? (32 * 4 / 2) : (32 * 8); + double GiBPS = transaction_byts_per_cycle * device_cap.clock_rate_in_ghz(); + return tsl::profiler::GigaToUni(GiBPS); } absl::string_view GpuModelName(const DeviceCapabilities& device_cap) { diff --git a/tensorflow/core/profiler/utils/hardware_type_utils.h b/tensorflow/core/profiler/utils/hardware_type_utils.h index 894b8c5753805e..41b1bd4b65471c 100644 --- a/tensorflow/core/profiler/utils/hardware_type_utils.h +++ b/tensorflow/core/profiler/utils/hardware_type_utils.h @@ -22,10 +22,48 @@ limitations under the License. namespace tensorflow { namespace profiler { +struct GpuFlopCapabilities { + struct FlopCapabilityOnPrecisions { + double fp64_tflops = 0; + double fp32_tflops = 0; // also for tf32 for nvidia tensor core + double bf16_tflops = 0; + double fp16_tflops = 0; + double fp8_tflops = 0; + double int8_tops = 0; + double fp4_tflops = 0; + double int4_tops = 0; + + void ScaleWith(double scale) { + fp64_tflops *= scale; + fp32_tflops *= scale; + bf16_tflops *= scale; + fp16_tflops *= scale; + fp8_tflops *= scale; + int8_tops *= scale; + fp4_tflops *= scale; + int4_tops *= scale; + } + }; + + FlopCapabilityOnPrecisions cuda_core; + FlopCapabilityOnPrecisions tensor_core; + bool has_tensor_core_sparsity_support = false; + + void ScaleWith(double scale) { + cuda_core.ScaleWith(scale); + tensor_core.ScaleWith(scale); + } +}; + // Get peak single precision throughput of the GPU in GFLOPS per // streaming multiprocessor. +// TODO: Need design on how to use the sparsity capability of FLOPs. double GetFlopMaxThroughputPerSM(const DeviceCapabilities& device_cap); +// for Nvidia GPU, return shared memory bandwidth in Bytes Per Second on +// one single SM given the GPU core freq in device_cap. +double GetSharedMemoryBandwidthPerSM(const DeviceCapabilities& device_cap); + // Returns the GPU model name from the given DeviceCapabilities. // For nvidia GPUs, the name is like "Nvidia GPU (Kepler)" or "Nvidia GPU // (Turing)". For AMD GPUs, the name is like "AMD GPU - gfx-10XX series". diff --git a/tensorflow/core/profiler/utils/hardware_type_utils_test.cc b/tensorflow/core/profiler/utils/hardware_type_utils_test.cc new file mode 100644 index 00000000000000..f97ccc6fecd40a --- /dev/null +++ b/tensorflow/core/profiler/utils/hardware_type_utils_test.cc @@ -0,0 +1,66 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/profiler/utils/hardware_type_utils.h" + +#include "tensorflow/core/platform/test.h" +#include "tsl/profiler/utils/math_utils.h" + +namespace tensorflow { +namespace profiler { +namespace { + +TEST(HardwareTypeUtilsTest, H100PeakComputTFlops) { + DeviceCapabilities device_cap; + // For NVIDIA H100 PCIe 80 GB, according to + // https://resources.nvidia.com/en-us-data-center-overview/gtc22-whitepaper-hopper + // https://www.techpowerup.com/gpu-specs/h100-pcie-80-gb.c3899 + device_cap.set_clock_rate_in_ghz(1.620); + device_cap.set_num_cores(114); + device_cap.set_memory_size_in_bytes( + tsl::profiler::GibiToGiga(tsl::profiler::GigaToUni(80))); + device_cap.set_memory_bandwidth(tsl::profiler::GigaToUni(2.04 * 1024)); + device_cap.set_device_vendor("Nvidia"); + device_cap.mutable_compute_capability()->set_major(9); + device_cap.mutable_compute_capability()->set_minor(0); + + // Get target TFLOPS per SM and check. + double peak_tflops = + GetFlopMaxThroughputPerSM(device_cap) * device_cap.num_cores() / 1000.0; + EXPECT_NEAR(peak_tflops, 756, /*abs_error=*/1.0); +} + +TEST(HardwareTypeUtilsTest, A100PeakComputTFlops) { + DeviceCapabilities device_cap; + // For NVIDIA A100 SXM4 80 GB, according to: + // https://images.nvidia.com/aem-dam/en-zz/Solutions/data-center/nvidia-ampere-architecture-whitepaper.pdf + // https://www.techpowerup.com/gpu-specs/a100-sxm4-80-gb.c3746 + device_cap.set_clock_rate_in_ghz(1.410); + device_cap.set_num_cores(108); + device_cap.set_memory_size_in_bytes( + tsl::profiler::GibiToGiga(tsl::profiler::GigaToUni(80))); + device_cap.set_memory_bandwidth(tsl::profiler::GigaToUni(2.04 * 1024)); + device_cap.set_device_vendor("Nvidia"); + device_cap.mutable_compute_capability()->set_major(8); + device_cap.mutable_compute_capability()->set_minor(0); + + double peak_tflops = + GetFlopMaxThroughputPerSM(device_cap) * device_cap.num_cores() / 1000.0; + EXPECT_NEAR(peak_tflops, 312, /*abs_error=*/1.0); +} + +} // namespace +} // namespace profiler +} // namespace tensorflow diff --git a/tensorflow/core/profiler/utils/xprof_gpu_cost_analysis.cc b/tensorflow/core/profiler/utils/xprof_gpu_cost_analysis.cc new file mode 100644 index 00000000000000..9df196901de411 --- /dev/null +++ b/tensorflow/core/profiler/utils/xprof_gpu_cost_analysis.cc @@ -0,0 +1,98 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/profiler/utils/xprof_gpu_cost_analysis.h" + +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/primitive_util.h" +#include "xla/service/gpu/model/gpu_hlo_cost_analysis.h" +#include "xla/service/hlo_cost_analysis.h" + +namespace tensorflow { +namespace profiler { + +namespace { + +std::vector GetInputBitwidths(const xla::HloInstruction& hlo) { + std::vector input_bitwidths; + for (const auto& operand : hlo.operands()) { + switch (operand->shape().element_type()) { + case xla::PRIMITIVE_TYPE_INVALID: + case xla::TUPLE: + case xla::OPAQUE_TYPE: + case xla::TOKEN: + break; + default: + input_bitwidths.push_back( + xla::primitive_util::BitWidth(operand->shape().element_type())); + } + } + return input_bitwidths; +} + +} // namespace + +absl::Status XProfGpuCostAnalysis::Postprocess(const xla::HloInstruction* hlo) { + if (hlo == nullptr) { + return absl::OkStatus(); + } + + uint32_t flop_rate_adjustment = 1; + float model_flops = current_properties_[kFlopsKey]; + // Calculate adjustment of device flops based on input bit widths. + // This provide most general adjustment for all ops, and for all gpus. + // TODO: Add adjustment for specific GPUs. + std::vector input_bitwidths = GetInputBitwidths(*hlo); + if (!input_bitwidths.empty()) { + int max_input_bitwidth = + *std::max_element(input_bitwidths.begin(), input_bitwidths.end()); + if (model_flops) { + // for int8/fp8, 2x flops assumed comparing with fp16 flops(most of + // recent GPU models); for int4, 4x of model flops assumed comparing + // with fp16 flops. (like Nvidia T4, 3090). It will be more precise + // after adjustment based on specific GPUs mentioned above. + switch (max_input_bitwidth) { + case 8: + flop_rate_adjustment = 2; + break; + case 4: + flop_rate_adjustment = 4; + break; + } + } + } + current_properties_[kDeviceFlopsAdjustment] = + model_flops - model_flops / flop_rate_adjustment; + return xla::gpu::GpuHloCostAnalysis::Postprocess(hlo); +} + +std::unique_ptr +XProfGpuCostAnalysis::CreateNestedCostAnalysis() { + return std::make_unique(options_); +} + +int64_t XProfGpuCostAnalysis::GetDeviceFlopsAdjustment( + const xla::HloInstruction& hlo) { + return GetPropertyForHlo(hlo, kDeviceFlopsAdjustment, hlo_properties_); +} + +} // namespace profiler +} // namespace tensorflow diff --git a/tensorflow/core/profiler/utils/xprof_gpu_cost_analysis.h b/tensorflow/core/profiler/utils/xprof_gpu_cost_analysis.h new file mode 100644 index 00000000000000..6977295c76939b --- /dev/null +++ b/tensorflow/core/profiler/utils/xprof_gpu_cost_analysis.h @@ -0,0 +1,53 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PROFILER_UTILS_XPROF_GPU_COST_ANALYSIS_H_ +#define TENSORFLOW_CORE_PROFILER_UTILS_XPROF_GPU_COST_ANALYSIS_H_ + +#include +#include + +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/service/gpu/model/gpu_hlo_cost_analysis.h" +#include "xla/service/hlo_cost_analysis.h" + +namespace tensorflow { +namespace profiler { + +// XProfGpuCostAnalysis provides additional cost analysis for XProf, which +// normalizes the flops to the device flops based on input bit widths. +class XProfGpuCostAnalysis : public xla::gpu::GpuHloCostAnalysis { + public: + explicit XProfGpuCostAnalysis(const xla::HloCostAnalysis::Options& options) + : xla::gpu::GpuHloCostAnalysis(options) {} + + absl::Status Postprocess(const xla::HloInstruction* hlo) override; + + int64_t GetDeviceFlopsAdjustment(const xla::HloInstruction& hlo); + + protected: + std::unique_ptr CreateNestedCostAnalysis() override; + + private: + static inline constexpr absl::string_view kDeviceFlopsAdjustment = + "device_flops_adjustment"; +}; + +} // namespace profiler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_UTILS_XPROF_GPU_COST_ANALYSIS_H_ diff --git a/tensorflow/core/profiler/utils/xprof_gpu_cost_analysis_test.cc b/tensorflow/core/profiler/utils/xprof_gpu_cost_analysis_test.cc new file mode 100644 index 00000000000000..2586e131b53a44 --- /dev/null +++ b/tensorflow/core/profiler/utils/xprof_gpu_cost_analysis_test.cc @@ -0,0 +1,138 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/profiler/utils/xprof_gpu_cost_analysis.h" + +#include + +#include +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/service/hlo_cost_analysis.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/test_helpers.h" +#include "xla/tests/hlo_test_base.h" +#include "xla/xla_data.pb.h" +#include "tsl/platform/statusor.h" + +namespace tensorflow { +namespace profiler { + +class XprofGpuHloCostAnalysisTest : public xla::HloTestBase { + xla::HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const { + return [&](const xla::Shape& shape) { + constexpr int64_t kPointerSize = 8; + return xla::ShapeUtil::ByteSizeOf(shape, kPointerSize); + }; + } + + public: + xla::HloCostAnalysis::Options options_{ + ShapeSizeBytesFunction(), + /*per_second_rates=*/{}, + /*min_latencies_seconds=*/{}, + /*count_multiple_input_accesses=*/true}; + XProfGpuCostAnalysis analysis_{options_}; + XprofGpuHloCostAnalysisTest() : xla::HloTestBase() {} +}; + +TEST_F(XprofGpuHloCostAnalysisTest, Fp16GemmNoAdjustment) { + absl::string_view hlo_string = R"( +HloModule r + +ENTRY e { + arg0 = f16[65536,32800] parameter(0) + arg1 = f16[32800,32] parameter(1) + gemm = (f16[65536,32], s8[0]) custom-call(arg0, arg1), + custom_call_target="__cublas$gemm", + backend_config="{ + \"gemm_backend_config\": { + \"alpha_real\":1, + \"beta\":0, + \"dot_dimension_numbers\":{ + \"lhs_contracting_dimensions\":[\"1\"], + \"rhs_contracting_dimensions\":[\"0\"], + \"lhs_batch_dimensions\":[], + \"rhs_batch_dimensions\":[] + }, + \"alpha_imag\":0, + \"precision_config\":{ + \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] + }, + \"epilogue\":\"DEFAULT\" + } + }" + ROOT get-tuple-element = f16[65536,32] + get-tuple-element((f16[65536,32], s8[0]) gemm), index=0 +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + ASSERT_IS_OK(module->entry_computation()->Accept(&analysis_)); + xla::HloComputation* comp = module->entry_computation(); + const xla::HloInstruction* fp16gemm = comp->GetInstructionWithName("gemm"); + // flops of gemm A * B = rows(A) * cols(B) * cols(A) * 2 + // where 2 is for the add and multiply + int64_t gold_flops = 65536LL * 32800 * 32 * 2; + EXPECT_EQ(analysis_.flop_count(*fp16gemm), gold_flops); + EXPECT_EQ(analysis_.GetDeviceFlopsAdjustment(*fp16gemm), 0); +} + +TEST_F(XprofGpuHloCostAnalysisTest, S8GemmAdjustment) { + absl::string_view hlo_string = R"( +HloModule r + +ENTRY e { + arg0 = s8[65536,32800] parameter(0) + arg1 = s8[32800,32] parameter(1) + gemm = (s32[65536,32], s8[0]) custom-call(arg0, arg1), + custom_call_target="__cublas$gemm", + backend_config="{ + \"gemm_backend_config\": { + \"alpha_real\":1, + \"beta\":0, + \"dot_dimension_numbers\":{ + \"lhs_contracting_dimensions\":[\"1\"], + \"rhs_contracting_dimensions\":[\"0\"], + \"lhs_batch_dimensions\":[], + \"rhs_batch_dimensions\":[] + }, + \"alpha_imag\":0, + \"precision_config\":{ + \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] + }, + \"epilogue\":\"DEFAULT\" + } + }" + ROOT get-tuple-element = s32[65536,32] + get-tuple-element((s32[65536,32], s8[0]) gemm), index=0 +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + ASSERT_IS_OK(module->entry_computation()->Accept(&analysis_)); + xla::HloComputation* comp = module->entry_computation(); + const xla::HloInstruction* s8gemm = comp->GetInstructionWithName("gemm"); + int64_t gold_flops = 65536LL * 32800 * 32 * 2; + EXPECT_EQ(analysis_.flop_count(*s8gemm), gold_flops); + // Matmul of int8 * int8 -> int32, normalized it to equivalent fp16 flops by + // dividing by 2 as all inputs are 8 bits + EXPECT_EQ(analysis_.GetDeviceFlopsAdjustment(*s8gemm), gold_flops / 2); +} + +} // namespace profiler +} // namespace tensorflow diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_schema.cc b/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_schema.cc index 6c088dfb8f0192..81e370f3cc0c7f 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_schema.cc +++ b/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_schema.cc @@ -342,6 +342,7 @@ const StatTypeMap& GetStatTypeMap() { {"cuda_graph_exec_id", kCudaGraphExecId}, {"cuda_graph_orig_id", kCudaGraphOrigId}, {"step_idle_time_ps", kStepIdleTimePs}, + {"gpu_device_name", kGpuDeviceName}, }); DCHECK_EQ(stat_type_map->size(), kNumStatTypes); return *stat_type_map; diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_schema.h b/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_schema.h index c84c62d8c73996..0d51de4d2905a5 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_schema.h +++ b/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_schema.h @@ -278,6 +278,7 @@ enum StatType { kHloProto, // Device capability related. kDevCapClockRateKHz, + // For GPU, this is the number of SMs. kDevCapCoreCount, kDevCapMemoryBandwidth, kDevCapMemorySize, @@ -330,7 +331,8 @@ enum StatType { kCudaGraphExecId, kCudaGraphOrigId, kStepIdleTimePs, - kLastStatType = kStepIdleTimePs, + kGpuDeviceName, + kLastStatType = kGpuDeviceName, }; enum MegaScaleStatType : uint8_t { diff --git a/third_party/xla/xla/backends/profiler/gpu/cupti_collector.cc b/third_party/xla/xla/backends/profiler/gpu/cupti_collector.cc index 937b85f9a6c8fa..49c9c86a49a89a 100644 --- a/third_party/xla/xla/backends/profiler/gpu/cupti_collector.cc +++ b/third_party/xla/xla/backends/profiler/gpu/cupti_collector.cc @@ -15,7 +15,9 @@ limitations under the License. #include "xla/backends/profiler/gpu/cupti_collector.h" +#include #include +#include #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" @@ -322,6 +324,15 @@ class PerDeviceCollector { return ret_val; } + std::optional GetDeviceName(CUdevice device) { + char device_name[512]; + if (cuDeviceGetName(device_name, sizeof(device_name), device) != + CUDA_SUCCESS) { + return std::nullopt; + } + return std::string(device_name); + } + std::string GetDeviceXLineName( int64_t stream_id, absl::flat_hash_set& event_types) { @@ -390,6 +401,13 @@ class PerDeviceCollector { CUdevice device; if (cuDeviceGet(&device, device_ordinal) != CUDA_SUCCESS) return; + std::optional device_name = GetDeviceName(device); + if (device_name.has_value()) { + device_plane->AddStatValue(*device_plane->GetOrCreateStatMetadata( + GetStatTypeStr(StatType::kGpuDeviceName)), + *device_name); + } + auto clock_rate_in_khz = GetDeviceAttribute(device, CU_DEVICE_ATTRIBUTE_CLOCK_RATE); if (clock_rate_in_khz) { From 56b6f3e482dedd0b307fad5c3600c64e56d88f34 Mon Sep 17 00:00:00 2001 From: Jane Liu Date: Thu, 26 Sep 2024 18:11:37 -0700 Subject: [PATCH 350/483] PR #23853: Enable the activation offloading test Imported from GitHub PR https://github.com/jax-ml/jax/pull/23853 This test ActivationOffloadingTest.test_remat_scan_layout_change_offloadable can be enabled after [XLA PR 17500](https://github.com/openxla/xla/pull/17500) is in. This test is also a small reproducer for MaxText activation offloading --remat-policy=qkv_proj_offloaded. Copybara import of the project: -- adaf54a4bbe10ce05edcfeb29039c6948444c641 by Jane Liu : enable the activation offloading test Merging this change closes #23853 PiperOrigin-RevId: 679365040 --- third_party/xla/xla/python/xla_client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/xla/xla/python/xla_client.py b/third_party/xla/xla/python/xla_client.py index a692e90f5813c8..5885155dcd887a 100644 --- a/third_party/xla/xla/python/xla_client.py +++ b/third_party/xla/xla/python/xla_client.py @@ -50,7 +50,7 @@ # Just an internal arbitrary increasing number to help with backward-compatible # changes. In JAX, reference this via jax._src.lib.xla_extension_version. -_version = 288 +_version = 289 # Version number for MLIR:Python components. mlir_api_version = 57 From 230c98e4f7af7c75d9e07682b6cddb66c45981c2 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 26 Sep 2024 21:08:25 -0700 Subject: [PATCH 351/483] Updates the solver to use "deterministic mode" exclusively. PiperOrigin-RevId: 679417582 --- .../auto_sharding/auto_sharding_solver.cc | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc index 073b2b8ac5221b..c8321348ab7c26 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc @@ -420,13 +420,17 @@ AutoShardingSolverResult CallORToolsSolver( // Set random_seed, interleave_search and share_binary_clauses for // determinism, mip_max_bound (to handle large costs), and num_workers for // parallelism. - solver_parameter_str = - request.deterministic_mode() - ? absl::StrCat( - "share_binary_clauses:false,random_seed:1,interleave_" - "search:true,num_workers:", - num_workers) - : absl::StrCat("num_workers:", num_workers); + solver_parameter_str = absl::StrCat("num_workers:", num_workers); + if (request.deterministic_mode()) { + absl::StrAppend( + &solver_parameter_str, + ",share_binary_clauses:false,random_seed:1,interleave_search:true"); + } + if (request.has_solver_timeout()) { + absl::StrAppend( + &solver_parameter_str, ",max_deterministic_time:", + 0.1 * request.solver_timeout().solver_timeout_in_seconds()); + } solver->SetSolverSpecificParametersAsString(solver_parameter_str); } // Create variables @@ -755,10 +759,6 @@ AutoShardingSolverResult CallORToolsSolver( } } #endif - if (request.has_solver_timeout()) { - solver->SetTimeLimit( - absl::Seconds(request.solver_timeout().solver_timeout_in_seconds())); - } if (request.enable_output()) { solver->EnableOutput(); } From 8ebd83cbabe4868dd37a27572e5c78fefdeb22a5 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 27 Sep 2024 00:24:03 -0700 Subject: [PATCH 352/483] Automated Code Change PiperOrigin-RevId: 679469195 --- tensorflow/compiler/mlir/BUILD | 2 ++ tensorflow/compiler/mlir/register_common_dialects.cc | 4 +++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/tensorflow/compiler/mlir/BUILD b/tensorflow/compiler/mlir/BUILD index ef068f28a999e6..e960b8d35cd268 100644 --- a/tensorflow/compiler/mlir/BUILD +++ b/tensorflow/compiler/mlir/BUILD @@ -195,9 +195,11 @@ cc_library( hdrs = ["register_common_dialects.h"], deps = [ "//tensorflow/compiler/mlir/lite:tensorflow_lite", + "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:mlprogram_util", "//tensorflow/compiler/mlir/tools/kernel_gen/ir:tf_framework_ops", + "//tensorflow/core/ir/types:Dialect", "@llvm-project//mlir:AllExtensions", "@llvm-project//mlir:AllPassesAndDialects", "@llvm-project//mlir:IR", diff --git a/tensorflow/compiler/mlir/register_common_dialects.cc b/tensorflow/compiler/mlir/register_common_dialects.cc index fe626375a8ee8f..cd492c58fcb0fe 100644 --- a/tensorflow/compiler/mlir/register_common_dialects.cc +++ b/tensorflow/compiler/mlir/register_common_dialects.cc @@ -23,11 +23,13 @@ limitations under the License. #include "mlir/InitAllExtensions.h" // from @llvm-project #include "stablehlo/dialect/Register.h" // from @stablehlo #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" #include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" -#include "tensorflow/compiler/mlir/tensorflow/utils/mlprogram_util.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" #include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h" #include "xla/mlir/framework/ir/xla_framework.h" #include "xla/mlir_hlo/mhlo/IR/register.h" +#include "tensorflow/core/ir/types/dialect.h" namespace mlir { From a1955ad687222648f138f90f1f9c7e8cd0a29c6d Mon Sep 17 00:00:00 2001 From: Philipp Hack Date: Fri, 27 Sep 2024 00:28:23 -0700 Subject: [PATCH 353/483] PR #16893: Unary Ops in FP8 Windowed Einsums Imported from GitHub PR https://github.com/openxla/xla/pull/16893 Adds support for unary ops between dequantization and windowed einsum loop. Copybara import of the project: -- fffc93fbab5a85609c0c6feb7b5bb259b47a7627 by Philipp Hack : Adds support for unary ops between dequantization and windowed einsum loop. Merging this change closes #16893 PiperOrigin-RevId: 679470392 --- .../gpu/transforms/windowed_einsum_handler.cc | 40 +++- .../windowed_einsum_handler_test.cc | 191 ++++++++++-------- .../xla/xla/tests/collective_ops_e2e_test.cc | 170 ++++++++++------ 3 files changed, 251 insertions(+), 150 deletions(-) diff --git a/third_party/xla/xla/service/gpu/transforms/windowed_einsum_handler.cc b/third_party/xla/xla/service/gpu/transforms/windowed_einsum_handler.cc index b7ac16438ecb5d..0abe00005316ae 100644 --- a/third_party/xla/xla/service/gpu/transforms/windowed_einsum_handler.cc +++ b/third_party/xla/xla/service/gpu/transforms/windowed_einsum_handler.cc @@ -56,12 +56,15 @@ namespace m = match; // and type conversions of FP8 operands into the bodies of their while loops, // i.e. rewrites // -// inputs --> dequant --> while loop {collective-permute/dot/etc} +// inputs --> dequant --> (unary) --> while loop {collective-permute/dot/etc} // // into // -// inputs --> while loop {dequant --> collective-permute/dot/etc}. -// Returns whether the input computation has been changed. +// inputs --> (unary) --> while loop {dequant --> collective-permute/dot/etc}. +// +// Unary bitcast, broadcast, copy, reshape and transpose ops are allowed between +// dequantization and while loop. Returns whether the input computation has been +// changed. absl::StatusOr ShiftDequantizationF8(HloComputation* while_body) { HloInstruction* while_instr = while_body->WhileCallInstruction(); // The input of the while loop will be modified and must have no other users. @@ -73,8 +76,21 @@ absl::StatusOr ShiftDequantizationF8(HloComputation* while_body) { // while loop. HloInstruction* param_tuple = while_instr->mutable_operand(0); std::array binaries, operands, scales; + std::array, 2> unaries; for (int k = 0; k < 2; ++k) { - if (!Match(param_tuple->mutable_operand(k), + HloInstruction* operand = param_tuple->mutable_operand(k); + // Capture bitcast, broadcast, copy, reshape and transpose ops between + // dequantization and the loop. + while (operand->opcode() == HloOpcode::kBitcast || + operand->opcode() == HloOpcode::kBroadcast || + operand->opcode() == HloOpcode::kCopy || + operand->opcode() == HloOpcode::kReshape || + operand->opcode() == HloOpcode::kTranspose) { + unaries[k].emplace_back(operand); + operand = operand->mutable_operand(0); + } + std::reverse(unaries[k].begin(), unaries[k].end()); + if (!Match(operand, m::AnyOf( m::Divide(&binaries[k], m::Convert(m::Op(&operands[k])), m::Broadcast(m::Op(&scales[k]))), @@ -156,6 +172,22 @@ absl::StatusOr ShiftDequantizationF8(HloComputation* while_body) { return false; } + // Replace any dequantized bitcast, broadcast, copy, reshape and transpose ops + // before the while loop with FP8 unary ops. + for (int k = 0; k < 2; ++k) { + for (HloInstruction* unary : unaries[k]) { + Shape new_shape = ShapeUtil::MakeShapeWithDenseLayout( + operands[k]->shape().element_type(), unary->shape().dimensions(), + unary->shape().layout().minor_to_major()); + + operands[k] = unary->AddInstruction(unary->CloneWithNewOperands( + ShapeUtil::MakeShapeWithDenseLayout( + operands[k]->shape().element_type(), unary->shape().dimensions(), + unary->shape().layout().minor_to_major()), + {operands[k]})); + } + } + // Replace the dequantized dot operands in the parameter tuple used by while // with FP8 operands. for (int k = 0; k < 2; ++k) { diff --git a/third_party/xla/xla/service/gpu/transforms/windowed_einsum_handler_test.cc b/third_party/xla/xla/service/gpu/transforms/windowed_einsum_handler_test.cc index e5f1e57f306306..4f736ef8614979 100644 --- a/third_party/xla/xla/service/gpu/transforms/windowed_einsum_handler_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/windowed_einsum_handler_test.cc @@ -634,127 +634,142 @@ CHECK: ROOT {{.*}} = bf16[1,4,1,1,2048,8192]{5,4,3,2,1,0} reshape(bf16[1,4,1,204 TEST_F(WindowedEinsumHandlerTest, AllGatherF8) { constexpr absl::string_view kHloString = R"( -HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(f8e4m3fn[2,512,24576]{2,1,0}, f8e4m3fn[24576,24576]{1,0}, f32[], f32[])->f32[2,2048,24576]{2,1,0}}, num_partitions=4 +HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(f8e4m3fn[2,512,24576]{2,1,0}, f8e4m3fn[1536,24576]{1,0}, f32[], f32[])->f32[2,2048,24576]{2,1,0}}, num_partitions=4 windowed_dot_general_body_ag { - param.1 = (f32[2,512,24576]{2,1,0}, f32[24576,24576]{1,0}, f32[2,2048,24576]{2,1,0}, f32[2,2048,24576]{2,1,0}, u32[]) parameter(0) - get-tuple-element.lhs = f32[2,512,24576]{2,1,0} get-tuple-element(param.1), index=0 - collective-permute.send_first_lhs_shard = f32[2,512,24576]{2,1,0} collective-permute(get-tuple-element.lhs), channel_id=4, source_target_pairs={{0,3},{1,0},{2,1},{3,2}} - collective-permute.send_second_lhs_shard = f32[2,512,24576]{2,1,0} collective-permute(collective-permute.send_first_lhs_shard), channel_id=5, source_target_pairs={{0,3},{1,0},{2,1},{3,2}} - get-tuple-element.rhs = f32[24576,24576]{1,0} get-tuple-element(param.1), index=1 - get-tuple-element.3 = f32[2,2048,24576]{2,1,0} get-tuple-element(param.1), index=2 - dot.first_shard_dot = f32[2,512,24576]{2,1,0} dot(get-tuple-element.lhs, get-tuple-element.rhs), lhs_contracting_dims={2}, rhs_contracting_dims={0} - constant.12 = s32[] constant(0) - constant.13 = s32[4]{0} constant({0, 512, 1024, 1536}) - get-tuple-element.5 = u32[] get-tuple-element(param.1), index=4 - partition-id = u32[] partition-id() - add = u32[] add(get-tuple-element.5, partition-id) - constant.11 = u32[] constant(4) - remainder = u32[] remainder(add, constant.11) - dynamic-slice = s32[1]{0} dynamic-slice(constant.13, remainder), dynamic_slice_sizes={1} - reshape = s32[] reshape(dynamic-slice) - dynamic-update-slice.update_first_shard_result = f32[2,2048,24576]{2,1,0} dynamic-update-slice(get-tuple-element.3, dot.first_shard_dot, constant.12, reshape, constant.12) - dot.second_shard_dot = f32[2,512,24576]{2,1,0} dot(collective-permute.send_first_lhs_shard, get-tuple-element.rhs), lhs_contracting_dims={2}, rhs_contracting_dims={0} - constant.15 = u32[] constant(1) - add.1 = u32[] add(get-tuple-element.5, constant.15) - add.2 = u32[] add(add.1, partition-id) - remainder.1 = u32[] remainder(add.2, constant.11) - dynamic-slice.1 = s32[1]{0} dynamic-slice(constant.13, remainder.1), dynamic_slice_sizes={1} - reshape.1 = s32[] reshape(dynamic-slice.1) - dynamic-update-slice.update_second_shard_result = f32[2,2048,24576]{2,1,0} dynamic-update-slice(dynamic-update-slice.update_first_shard_result, dot.second_shard_dot, constant.12, reshape.1, constant.12) - get-tuple-element.4 = f32[2,2048,24576]{2,1,0} get-tuple-element(param.1), index=3 - add.3 = u32[] add(add.1, constant.15) - ROOT tuple = (f32[2,512,24576]{2,1,0}, f32[24576,24576]{1,0}, f32[2,2048,24576]{2,1,0}, f32[2,2048,24576]{2,1,0}, u32[]) tuple(collective-permute.send_second_lhs_shard, get-tuple-element.rhs, dynamic-update-slice.update_second_shard_result, get-tuple-element.4, add.3) + input = (f32[2,512,24576]{2,1,0}, f32[24576,24576]{1,0}, f32[2,2048,24576]{2,1,0}, f32[2,2048,24576]{2,1,0}, u32[]) parameter(0) + lhs = f32[2,512,24576]{2,1,0} get-tuple-element(input), index=0 + permuted_lhs0 = f32[2,512,24576]{2,1,0} collective-permute(lhs), channel_id=4, source_target_pairs={{0,3},{1,0},{2,1},{3,2}} + permuted_lhs1 = f32[2,512,24576]{2,1,0} collective-permute(permuted_lhs0), channel_id=5, source_target_pairs={{0,3},{1,0},{2,1},{3,2}} + rhs = f32[24576,24576]{1,0} get-tuple-element(input), index=1 + partial_dot_output = f32[2,2048,24576]{2,1,0} get-tuple-element(input), index=2 + dot0 = f32[2,512,24576]{2,1,0} dot(lhs, rhs), lhs_contracting_dims={2}, rhs_contracting_dims={0} + c0 = s32[] constant(0) + dot_update_slice_offsets = s32[4]{0} constant({0, 512, 1024, 1536}) + loop_counter = u32[] get-tuple-element(input), index=4 + partition_id = u32[] partition-id() + loop_counter_plus_partition_id = u32[] add(loop_counter, partition_id) + c4 = u32[] constant(4) + dot_update_slice_offsets_index0 = u32[] remainder(loop_counter_plus_partition_id, c4) + dot_update_slice_offset0 = s32[1]{0} dynamic-slice(dot_update_slice_offsets, dot_update_slice_offsets_index0), dynamic_slice_sizes={1} + dot_update_slice_offset_scalar0 = s32[] reshape(dot_update_slice_offset0) + updated_dot_output0 = f32[2,2048,24576]{2,1,0} dynamic-update-slice(partial_dot_output, dot0, c0, dot_update_slice_offset_scalar0, c0) + dot1 = f32[2,512,24576]{2,1,0} dot(permuted_lhs0, rhs), lhs_contracting_dims={2}, rhs_contracting_dims={0} + c1 = u32[] constant(1) + loop_counter_plus_one = u32[] add(loop_counter, c1) + loop_counter_plus_partiion_id_plus_one = u32[] add(loop_counter_plus_one, partition_id) + dot_update_slice_offsets_index1 = u32[] remainder(loop_counter_plus_partiion_id_plus_one, c4) + dot_update_slice_offset1 = s32[1]{0} dynamic-slice(dot_update_slice_offsets, dot_update_slice_offsets_index1), dynamic_slice_sizes={1} + dot_update_slice_offset1_scalar = s32[] reshape(dot_update_slice_offset1) + updated_dot_output1 = f32[2,2048,24576]{2,1,0} dynamic-update-slice(updated_dot_output0, dot1, c0, dot_update_slice_offset1_scalar, c0) + pass_through = f32[2,2048,24576]{2,1,0} get-tuple-element(input), index=3 + next_loop_counter = u32[] add(loop_counter_plus_one, c1) + ROOT tuple = (f32[2,512,24576]{2,1,0}, f32[24576,24576]{1,0}, f32[2,2048,24576]{2,1,0}, f32[2,2048,24576]{2,1,0}, u32[]) tuple(permuted_lhs1, rhs, updated_dot_output1, pass_through, next_loop_counter) } // windowed_dot_general_body_ag windowed_dot_general_cond_ag { - param = (f32[2,512,24576]{2,1,0}, f32[24576,24576]{1,0}, f32[2,2048,24576]{2,1,0}, f32[2,2048,24576]{2,1,0}, u32[]) parameter(0) - get-tuple-element = u32[] get-tuple-element(param), index=4 - constant.10 = u32[] constant(4) - ROOT compare = pred[] compare(get-tuple-element, constant.10), direction=LT + input = (f32[2,512,24576]{2,1,0}, f32[24576,24576]{1,0}, f32[2,2048,24576]{2,1,0}, f32[2,2048,24576]{2,1,0}, u32[]) parameter(0) + loop_counter = u32[] get-tuple-element(input), index=4 + loop_limit = u32[] constant(4) + ROOT compare = pred[] compare(loop_counter, loop_limit), direction=LT } -ENTRY test_main { - param.4 = f8e4m3fn[2,512,24576]{2,1,0} parameter(0), sharding={devices=[1,4,1]<=[4]} - reshape.8 = f8e4m3fn[2,512,24576]{2,1,0} reshape(param.4) - param.5 = f8e4m3fn[24576,24576]{1,0} parameter(1), sharding={devices=[1,4]<=[4]} - constant.18 = f32[] constant(0) - broadcast = f32[2,2048,24576]{2,1,0} broadcast(constant.18), dimensions={} - constant.20 = u32[] constant(0) +ENTRY main { + lhs = f8e4m3fn[2,512,24576]{2,1,0} parameter(0), sharding={devices=[1,4,1]<=[4]} + rhs = f8e4m3fn[1536,24576]{1,0} parameter(1), sharding={devices=[1,4]<=[4]} + c0_f32 = f32[] constant(0) + c0_f32_bcast = f32[2,2048,24576]{2,1,0} broadcast(c0_f32), dimensions={} + c0_u32 = u32[] constant(0) scale_lhs = f32[] parameter(2) scale_lhs_bcast = f32[2,512,24576]{2,1,0} broadcast(scale_lhs), dimensions={} - lhs_bf32 = f32[2,512,24576]{2,1,0} convert(reshape.8) - lhs_scaled = f32[2,512,24576]{2,1,0} multiply(lhs_bf32, scale_lhs_bcast) + lhs_f32 = f32[2,512,24576]{2,1,0} convert(lhs) + lhs_scaled = f32[2,512,24576]{2,1,0} multiply(lhs_f32, scale_lhs_bcast) scale_rhs = f32[] parameter(3) - scale_rhs_bcast = f32[24576,24576]{1,0} broadcast(scale_rhs), dimensions={} - rhs_bf32 = f32[24576,24576]{1,0} convert(param.5) - rhs_scaled = f32[24576,24576]{1,0} multiply(rhs_bf32, scale_rhs_bcast) - tuple.2 = (f32[2,512,24576]{2,1,0}, f32[24576,24576]{1,0}, f32[2,2048,24576]{2,1,0}, f32[2,2048,24576]{2,1,0}, u32[]) tuple(lhs_scaled, rhs_scaled, broadcast, broadcast, constant.20) - while = (f32[2,512,24576]{2,1,0}, f32[24576,24576]{1,0}, f32[2,2048,24576]{2,1,0}, f32[2,2048,24576]{2,1,0}, u32[]) while(tuple.2), condition=windowed_dot_general_cond_ag, body=windowed_dot_general_body_ag + scale_rhs_bcast = f32[1536,24576]{1,0} broadcast(scale_rhs), dimensions={} + rhs_f32 = f32[1536,24576]{1,0} convert(rhs) + rhs_scaled = f32[1536,24576]{1,0} multiply(rhs_f32, scale_rhs_bcast) + rhs_bcast = f32[16,1536,24576]{2,1,0} broadcast(rhs_scaled), dimensions={1,2} + rhs_reshaped = f32[24576,24576]{1,0} reshape(rhs_bcast) + while_input = (f32[2,512,24576]{2,1,0}, f32[24576,24576]{1,0}, f32[2,2048,24576]{2,1,0}, f32[2,2048,24576]{2,1,0}, u32[]) tuple(lhs_scaled, rhs_reshaped, c0_f32_bcast, c0_f32_bcast, c0_u32) + while = (f32[2,512,24576]{2,1,0}, f32[24576,24576]{1,0}, f32[2,2048,24576]{2,1,0}, f32[2,2048,24576]{2,1,0}, u32[]) while(while_input), condition=windowed_dot_general_cond_ag, body=windowed_dot_general_body_ag ROOT get-tuple-element.13 = f32[2,2048,24576]{2,1,0} get-tuple-element(while), index=2 } )"; RunAndFilecheckHloRewrite(kHloString, WindowedEinsumHandler(), R"( -; CHECK-LABEL: unrolled_windowed_dot_general_body_ag -; CHECK-NEXT: [[P0:%[^ ]+]] = (f8e4m3fn[2,512,24576]{2,1,0}, f8e4m3fn[24576,24576]{1,0}, f32[2,2048,24576]{2,1,0}, f32[2,2048,24576]{2,1,0}, u32[], /*index=5*/f32[], f32[]) parameter(0) -; CHECK-NEXT: [[GTE0:%[^ ]+]] = f8e4m3fn[2,512,24576]{2,1,0} get-tuple-element([[P0]]), index=0 -; CHECK-NEXT: [[CP0:%[^ ]+]] = f8e4m3fn[2,512,24576]{2,1,0} collective-permute([[GTE0]]), channel_id=6 -; CHECK-NEXT: [[CP1:%[^ ]+]] = f8e4m3fn[2,512,24576]{2,1,0} collective-permute([[CP0]]), channel_id=7 -; CHECK-NEXT: [[GTE1:%[^ ]+]] = f8e4m3fn[24576,24576]{1,0} get-tuple-element([[P0]]), index=1 -; CHECK-NEXT: [[GTE2:%[^ ]+]] = f32[2,2048,24576]{2,1,0} get-tuple-element([[P0]]), index=2 -; CHECK-NEXT: [[CONVERT0:%[^ ]+]] = f32[2,512,24576]{2,1,0} convert([[GTE0]]) -; CHECK-NEXT: [[GTE3:%[^ ]+]] = f32[] get-tuple-element([[P0]]), index=5 -; CHECK-NEXT: [[BCAST0:%[^ ]+]] = f32[2,512,24576]{2,1,0} broadcast([[GTE3]]), dimensions={} -; CHECK-NEXT: [[MUL0:%[^ ]+]] = f32[2,512,24576]{2,1,0} multiply([[CONVERT0]], [[BCAST0]]) -; CHECK-NEXT: [[CONVERT1:%[^ ]+]] = f32[24576,24576]{1,0} convert([[GTE1]]) -; CHECK-NEXT: [[GTE4:%[^ ]+]] = f32[] get-tuple-element([[P0]]), index=6 -; CHECK-NEXT: [[BCAST1:%[^ ]+]] = f32[24576,24576]{1,0} broadcast([[GTE4]]), dimensions={} -; CHECK-NEXT: [[MUL1:%[^ ]+]] = f32[24576,24576]{1,0} multiply([[CONVERT1]], [[BCAST1]]) -; CHECK-NEXT: [[DOT0:%[^ ]+]] = f32[2,512,24576]{2,1,0} dot([[MUL0]], [[MUL1]]), +; CHECK-LABEL: %unrolled_windowed_dot_general_body_ag +; CHECK-NEXT: [[INPUT:%[^ ]+]] = (f8e4m3fn[2,512,24576]{2,1,0}, f8e4m3fn[24576,24576]{1,0}, f32[2,2048,24576]{2,1,0}, f32[2,2048,24576]{2,1,0}, u32[], /*index=5*/f32[], f32[]) parameter(0) +; CHECK-NEXT: [[LHS:%[^ ]+]] = f8e4m3fn[2,512,24576]{2,1,0} get-tuple-element([[INPUT]]), index=0 +; CHECK-NEXT: [[PERMUTED_LHS0:%[^ ]+]] = f8e4m3fn[2,512,24576]{2,1,0} collective-permute([[LHS]]), channel_id=6 +; CHECK-NEXT: [[PERMUTED_LHS1:%[^ ]+]] = f8e4m3fn[2,512,24576]{2,1,0} collective-permute([[PERMUTED_LHS0]]), channel_id=7 +; CHECK-NEXT: [[RHS:%[^ ]+]] = f8e4m3fn[24576,24576]{1,0} get-tuple-element([[INPUT]]), index=1 +; CHECK-NEXT: [[PARTIAL_DOT_OUTPUT:%[^ ]+]] = f32[2,2048,24576]{2,1,0} get-tuple-element([[INPUT]]), index=2 +; CHECK-NEXT: [[LHS_F32:%[^ ]+]] = f32[2,512,24576]{2,1,0} convert([[LHS]]) +; CHECK-NEXT: [[SCALE_LHS:%[^ ]+]] = f32[] get-tuple-element([[INPUT]]), index=5 +; CHECK-NEXT: [[SCALE_LHS_BCAST:%[^ ]+]] = f32[2,512,24576]{2,1,0} broadcast([[SCALE_LHS]]), dimensions={} +; CHECK-NEXT: [[LHS_SCALED:%[^ ]+]] = f32[2,512,24576]{2,1,0} multiply([[LHS_F32]], [[SCALE_LHS_BCAST]]) +; CHECK-NEXT: [[RHS_F32:%[^ ]+]] = f32[24576,24576]{1,0} convert([[RHS]]) +; CHECK-NEXT: [[SCALE_RHS:%[^ ]+]] = f32[] get-tuple-element([[INPUT]]), index=6 +; CHECK-NEXT: [[SCALE_RHS_BCAST:%[^ ]+]] = f32[24576,24576]{1,0} broadcast([[SCALE_RHS]]), dimensions={} +; CHECK-NEXT: [[RHS_SCALED:%[^ ]+]] = f32[24576,24576]{1,0} multiply([[RHS_F32]], [[SCALE_RHS_BCAST]]) +; CHECK-NEXT: [[DOT0:%[^ ]+]] = f32[2,512,24576]{2,1,0} dot([[LHS_SCALED]], [[RHS_SCALED]]), ; CHECK-DAG: lhs_contracting_dims={2}, ; CHECK-DAG: rhs_contracting_dims={0}, ; CHECK-DAG: backend_config={ ; CHECK-DAG: "operation_queue_id":"[[OPQUEUEID:[0-9]+]]", ; CHECK-DAG: "wait_on_operation_queues":[], ; CHECK-DAG: "force_earliest_schedule":false} -; CHECK-NEXT: [[C0:%[^ ]+]] = s32[] constant(0) -; CHECK-NEXT: [[C4:%[^ ]+]] = u32[] constant(0) +; CHECK-NEXT: [[C0_S32:%[^ ]+]] = s32[] constant(0) +; CHECK-NEXT: [[C0_U32:%[^ ]+]] = u32[] constant(0) ; CHECK-NEXT: [[C5:%[^ ]+]] = u32[] constant(0) -; CHECK-NEXT: [[PID:%[^ ]+]] = u32[] partition-id() -; CHECK-NEXT: [[ADD0:%[^ ]+]] = u32[] add([[C5]], [[PID]]) -; CHECK-NEXT: [[C2:%[^ ]+]] = u32[] constant(3) -; CHECK-NEXT: [[AND0:%[^ ]+]] = u32[] and([[ADD0]], [[C2]]) -; CHECK-NEXT: [[CLAMP0:%[^ ]+]] = u32[] clamp([[C4]], [[AND0]], [[C2]]) +; CHECK-NEXT: [[PARTITION_ID:%[^ ]+]] = u32[] partition-id() +; CHECK-NEXT: [[ADD0:%[^ ]+]] = u32[] add([[C5]], [[PARTITION_ID]]) +; CHECK-NEXT: [[C3:%[^ ]+]] = u32[] constant(3) +; CHECK-NEXT: [[AND0:%[^ ]+]] = u32[] and([[ADD0]], [[C3]]) +; CHECK-NEXT: [[CLAMP0:%[^ ]+]] = u32[] clamp([[C0_U32]], [[AND0]], [[C3]]) ; CHECK-NEXT: [[CONVERT3:%[^ ]+]] = s32[] convert([[CLAMP0]]) -; CHECK-NEXT: [[C6:%[^ ]+]] = s32[] constant(512) -; CHECK-NEXT: [[MUL3:%[^ ]+]] = s32[] multiply([[CONVERT3]], [[C6]]) +; CHECK-NEXT: [[C512:%[^ ]+]] = s32[] constant(512) +; CHECK-NEXT: [[MUL3:%[^ ]+]] = s32[] multiply([[CONVERT3]], [[C512]]) ; CHECK-NEXT: [[RESHAPE0:%[^ ]+]] = s32[] reshape([[MUL3]]) -; CHECK-NEXT: [[DUPDATESLICE0:%[^ ]+]] = f32[2,2048,24576]{2,1,0} dynamic-update-slice([[GTE2]], [[DOT0]], [[C0]], [[RESHAPE0]], [[C0]]), +; CHECK-NEXT: [[UPDATED_DOT_OUTPUT0:%[^ ]+]] = f32[2,2048,24576]{2,1,0} dynamic-update-slice([[PARTIAL_DOT_OUTPUT]], [[DOT0]], [[C0_S32]], [[RESHAPE0]], [[C0_S32]]), ; CHECK-DAG: backend_config={ ; CHECK-DAG: "operation_queue_id":"0", ; CHECK-DAG: "wait_on_operation_queues":["[[OPQUEUEID]]"], ; CHECK-DAG: "force_earliest_schedule":false} -; CHECK-NEXT: [[CONVERT2:%[^ ]+]] = f32[2,512,24576]{2,1,0} convert([[CP0]]) -; CHECK-NEXT: [[MUL2:%[^ ]+]] = f32[2,512,24576]{2,1,0} multiply([[CONVERT2]], [[BCAST0]]) -; CHECK-NEXT: [[DOT1:%[^ ]+]] = f32[2,512,24576]{2,1,0} dot([[MUL2]], [[MUL1]]), +; CHECK-NEXT: [[PERMUTED_LHS0_F32:%[^ ]+]] = f32[2,512,24576]{2,1,0} convert([[PERMUTED_LHS0]]) +; CHECK-NEXT: [[PERMUTED_LHS_SCALED:%[^ ]+]] = f32[2,512,24576]{2,1,0} multiply([[PERMUTED_LHS0_F32]], [[SCALE_LHS_BCAST]]) +; CHECK-NEXT: [[DOT1:%[^ ]+]] = f32[2,512,24576]{2,1,0} dot([[PERMUTED_LHS_SCALED]], [[RHS_SCALED]]), ; CHECK-DAG: lhs_contracting_dims={2}, ; CHECK-DAG: rhs_contracting_dims={0} -; CHECK-NEXT: [[GTE7:%[^ ]+]] = u32[] get-tuple-element([[P0]]), index=4 -; CHECK-NEXT: [[C3:%[^ ]+]] = u32[] constant(1) -; CHECK-NEXT: [[ADD1:%[^ ]+]] = u32[] add([[GTE7]], [[C3]]) -; CHECK-NEXT: [[ADD2:%[^ ]+]] = u32[] add([[ADD1]], [[PID]]) -; CHECK-NEXT: [[AND1:%[^ ]+]] = u32[] and([[ADD2]], [[C2]]) -; CHECK-NEXT: [[CLAMP1:%[^ ]+]] = u32[] clamp([[C4]], [[AND1]], [[C2]]) +; CHECK-NEXT: [[LOOP_COUNTER:%[^ ]+]] = u32[] get-tuple-element([[INPUT]]), index=4 +; CHECK-NEXT: [[C1:%[^ ]+]] = u32[] constant(1) +; CHECK-NEXT: [[LOOP_COUNTER_PLUS_ONE:%[^ ]+]] = u32[] add([[LOOP_COUNTER]], [[C1]]) +; CHECK-NEXT: [[LOOP_COUNTER_PLUS_ONE_PLUS_PARTITION_ID:%[^ ]+]] = u32[] add([[LOOP_COUNTER_PLUS_ONE]], [[PARTITION_ID]]) +; CHECK-NEXT: [[AND1:%[^ ]+]] = u32[] and([[LOOP_COUNTER_PLUS_ONE_PLUS_PARTITION_ID]], [[C3]]) +; CHECK-NEXT: [[CLAMP1:%[^ ]+]] = u32[] clamp([[C0_U32]], [[AND1]], [[C3]]) ; CHECK-NEXT: [[CONVERT4:%[^ ]+]] = s32[] convert([[CLAMP1]]) -; CHECK-NEXT: [[MUL4:%[^ ]+]] = s32[] multiply([[CONVERT4]], [[C6]]) +; CHECK-NEXT: [[MUL4:%[^ ]+]] = s32[] multiply([[CONVERT4]], [[C512]]) ; CHECK-NEXT: [[RESHAPE1:%[^ ]+]] = s32[] reshape([[MUL4]]) -; CHECK-NEXT: [[DUPDATESLICE1:%[^ ]+]] = f32[2,2048,24576]{2,1,0} dynamic-update-slice([[DUPDATESLICE0]], [[DOT1]], [[C0]], [[RESHAPE1]], [[C0]]) -; CHECK-NEXT: [[GTE6:%[^ ]+]] = f32[2,2048,24576]{2,1,0} get-tuple-element([[P0]]), index=3 -; CHECK-NEXT: [[C7:%[^ ]+]] = u32[] constant(2) -; CHECK-NEXT: [[ADD3:%[^ ]+]] = u32[] add([[GTE7]], [[C7]]) -; CHECK-NEXT: [[TUPLE0:%[^ ]+]] = (f8e4m3fn[2,512,24576]{2,1,0}, f8e4m3fn[24576,24576]{1,0}, f32[2,2048,24576]{2,1,0}, f32[2,2048,24576]{2,1,0}, u32[], /*index=5*/f32[], f32[]) tuple([[CP1]], [[GTE1]], [[DUPDATESLICE1]], [[GTE6]], [[ADD3]], /*index=5*/[[GTE3]], [[GTE4]]) +; CHECK-NEXT: [[UPDATED_DOT_OUTPUT1:%[^ ]+]] = f32[2,2048,24576]{2,1,0} dynamic-update-slice([[UPDATED_DOT_OUTPUT0]], [[DOT1]], [[C0_S32]], [[RESHAPE1]], [[C0_S32]]) +; CHECK-NEXT: [[PASS_THROUGH:%[^ ]+]] = f32[2,2048,24576]{2,1,0} get-tuple-element([[INPUT]]), index=3 +; CHECK-NEXT: [[C2:%[^ ]+]] = u32[] constant(2) +; CHECK-NEXT: [[NEXT_LOOP_COUNTER:%[^ ]+]] = u32[] add([[LOOP_COUNTER]], [[C2]]) +; CHECK-NEXT: [[TUPLE:%[^ ]+]] = (f8e4m3fn[2,512,24576]{2,1,0}, f8e4m3fn[24576,24576]{1,0}, f32[2,2048,24576]{2,1,0}, f32[2,2048,24576]{2,1,0}, u32[], /*index=5*/f32[], f32[]) tuple([[PERMUTED_LHS1]], [[RHS]], [[UPDATED_DOT_OUTPUT1]], [[PASS_THROUGH]], [[NEXT_LOOP_COUNTER]], /*index=5*/[[SCALE_LHS]], [[SCALE_RHS]]) +; CHECK-LABEL: ENTRY %main +; CHECK: [[LHS:%[^ ]+]] = f8e4m3fn[2,512,24576]{2,1,0} parameter(0), sharding={devices=[1,4,1]<=[4]} +; CHECK-NEXT: [[RHS:%[^ ]+]] = f8e4m3fn[1536,24576]{1,0} parameter(1), sharding={devices=[1,4]<=[4]} +; CHECK-NEXT: [[RHS_BCAST:%[^ ]+]] = f8e4m3fn[16,1536,24576]{2,1,0} broadcast([[RHS]]), dimensions={1,2} +; CHECK-NEXT: [[RHS_RESHAPED:%[^ ]+]] = f8e4m3fn[24576,24576]{1,0} reshape([[RHS_BCAST]]) +; CHECK-NEXT: [[C0:%[^ ]+]] = f32[] constant(0) +; CHECK-NEXT: [[C0_BCAST:%[^ ]+]] = f32[2,2048,24576]{2,1,0} broadcast([[C0]]), dimensions={} +; CHECK-NEXT: [[C0_U32:%[^ ]+]] = u32[] constant(0) +; CHECK-NEXT: [[SCALE_LHS:%[^ ]+]] = f32[] parameter(2) +; CHECK-NEXT: [[SCALE_RHS:%[^ ]+]] = f32[] parameter(3) +; CHECK-NEXT: [[WHILE_INPUT:%[^ ]+]] = (f8e4m3fn[2,512,24576]{2,1,0}, f8e4m3fn[24576,24576]{1,0}, f32[2,2048,24576]{2,1,0}, f32[2,2048,24576]{2,1,0}, u32[], /*index=5*/f32[], f32[]) tuple([[LHS]], [[RHS_RESHAPED]], [[C0_BCAST]], [[C0_BCAST]], [[C0_U32]], /*index=5*/[[SCALE_LHS]], [[SCALE_RHS]]) +; CHECK: [[WHILE:%[^ ]+]] = (f8e4m3fn[2,512,24576]{2,1,0}, f8e4m3fn[24576,24576]{1,0}, f32[2,2048,24576]{2,1,0}, f32[2,2048,24576]{2,1,0}, u32[], /*index=5*/f32[], f32[]) while([[WHILE_INPUT]]), +; CHECK-DAG: condition=%unrolled_windowed_dot_general_cond_ag, +; CHECK-DAG: body=%unrolled_windowed_dot_general_body_ag )"); } diff --git a/third_party/xla/xla/tests/collective_ops_e2e_test.cc b/third_party/xla/xla/tests/collective_ops_e2e_test.cc index 479935f7d01f66..442e07a71d6ef1 100644 --- a/third_party/xla/xla/tests/collective_ops_e2e_test.cc +++ b/third_party/xla/xla/tests/collective_ops_e2e_test.cc @@ -101,10 +101,11 @@ class CollectiveOpsTestE2E : public HloTestBase { CreateExecutable(std::move(module), /*run_hlo_passes=*/true)); EXPECT_TRUE(executable->has_module()); - HloInstruction* gemm_op = - FindInstruction(&executable->module(), HloOpcode::kCustomCall); - EXPECT_THAT(gemm_op, NotNull()); - EXPECT_EQ(gemm_op->custom_call_target(), "__cublas$lt$matmul$f8"); + std::vector gemm_ops = + FindInstructions(&executable->module(), HloOpcode::kCustomCall); + for (HloInstruction* gemm_op : gemm_ops) { + EXPECT_EQ(gemm_op->custom_call_target(), "__cublas$lt$matmul$f8"); + } } absl::StatusOr> ExecuteReplicated(Executable* executable, @@ -867,46 +868,62 @@ ENTRY main.12 { CollectiveOpsCompareWindowedNonWindowed(kModuleReplicatedStr); } +TEST_F(CollectiveOpsTestE2EWindowedNonWindowed, WindowedEinsumE2EAllGatherF8) { + absl::string_view kModuleReplicatedStr = R"( +HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(f8e4m3fn[2,16,48]{2,1,0}, f8e4m3fn[48,192]{1,0}, bf16[], bf16[])->bf16[2,16,192]{2,1,0}}, allow_spmd_sharding_propagation_to_parameters={false,false,false,false}, num_partitions=4 + +ENTRY main { + lhs = f8e4m3fn[2,16,48]{2,1,0} parameter(0), sharding={devices=[1,4,1]<=[4]} + rhs = f8e4m3fn[48,192]{1,0} parameter(1), sharding={devices=[1,4]<=[4]} + scale_lhs = bf16[] parameter(2) + scale_rhs = bf16[] parameter(3) + scale_lhs_bcast = bf16[2,16,48]{2,1,0} broadcast(scale_lhs), dimensions={} + scale_rhs_bcast = bf16[48,192]{1,0} broadcast(scale_rhs), dimensions={} + lhs_bf16 = bf16[2,16,48]{2,1,0} convert(lhs) + rhs_bf16 = bf16[48,192]{1,0} convert(rhs) + lhs_scaled = bf16[2,16,48]{2,1,0} multiply(scale_lhs_bcast, lhs_bf16) + rhs_scaled = bf16[48,192]{1,0} multiply(scale_rhs_bcast, rhs_bf16) + dot = bf16[2,16,192]{2,1,0} dot(lhs_scaled, rhs_scaled), lhs_contracting_dims={2}, rhs_contracting_dims={0} + ROOT custom-call = bf16[2,16,192]{2,1,0} custom-call(dot), custom_call_target="Sharding", sharding={devices=[1,1,4]<=[4]} +} // main +)"; + + // Disable the dot merger pass which can prevent the creation of FP8 GEMM + // Custom Calls. + CollectiveOpsCompareWindowedNonWindowed(kModuleReplicatedStr, + /*disable_dot_merger=*/true); + + // Verify the creation of FP8 GEMM Custom Calls on Hopper and newer + // architectures. + DebugOptions opts = GetDebugOptionsForTest(); + opts.set_xla_gpu_threshold_for_windowed_einsum_mib(0); + opts.set_xla_gpu_multi_streamed_windowed_einsum(true); + opts.set_xla_gpu_graph_min_graph_size(200); + opts.set_xla_gpu_enable_triton_gemm(false); + opts.add_xla_disable_hlo_passes("dot-merger"); + CollectiveOpsVerifyF8Matmul(kModuleReplicatedStr, opts); +} + TEST_F(CollectiveOpsTestE2EWindowedNonWindowed, - WindowedEinsumE2EAllGatherAndReduceScatterF8) { + WindowedEinsumE2EAllGatherReshapeF8) { absl::string_view kModuleReplicatedStr = R"( -HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(<>[2,16,48]{2,1,0}, <>[48,192]{1,0}, <>[192,48]{1,0}, bf16[], bf16[], bf16[], bf16[], bf16[])->bf16[2,16,48]{2,1,0}}, allow_spmd_sharding_propagation_to_parameters={false,false,false,false}, num_partitions=4 +HloModule windowed_einsum_e2e_all_gather_multi_consumer_f8, entry_computation_layout={(f8e4m3fn[2,16,48]{2,1,0}, f8e4m3fn[2,24,192]{2,1,0}, bf16[], bf16[])->bf16[2,16,192]{2,1,0}}, allow_spmd_sharding_propagation_to_parameters={false,false,false,false}, num_partitions=4 -ENTRY main.12 { - Arg_0.1 = <>[2,16,48]{2,1,0} parameter(0), sharding={devices=[1,4,1]<=[4]} - Arg_1.2 = <>[48,192]{1,0} parameter(1), sharding={devices=[1,4]<=[4]} - Arg_2.3 = bf16[] parameter(3) - Arg_3.4 = bf16[] parameter(4) - broadcast = bf16[2,16,48]{2,1,0} broadcast(Arg_2.3), dimensions={} - broadcast.1 = bf16[48,192]{1,0} broadcast(Arg_3.4), dimensions={} - convert = bf16[2,16,48]{2,1,0} convert(Arg_0.1) - convert.1 = bf16[48,192]{1,0} convert(Arg_1.2) - multiply = bf16[2,16,48]{2,1,0} multiply(broadcast, convert) - multiply.1 = bf16[48,192]{1,0} multiply(broadcast.1, convert.1) - dot.5 = bf16[2,16,192]{2,1,0} dot(multiply, multiply.1), lhs_contracting_dims={2}, rhs_contracting_dims={0} - custom-call.7 = bf16[2,16,192]{2,1,0} custom-call(dot.5), custom_call_target="Sharding", sharding={devices=[1,1,4]<=[4]} - Arg_4.5 = bf16[] parameter(5) - broadcast.2 = bf16[2,16,192]{2,1,0} broadcast(Arg_4.5), dimensions={} - divide = bf16[2,16,192]{2,1,0} divide(custom-call.7, broadcast.2) - constant = bf16[] constant(-448.) - broadcast.3 = bf16[2,16,192]{2,1,0} broadcast(constant), dimensions={} - constant.1 = bf16[] constant(448.) - broadcast.4 = bf16[2,16,192]{2,1,0} broadcast(constant.1), dimensions={} - clamp = bf16[2,16,192]{2,1,0} clamp(broadcast.3, divide, broadcast.4) - convert.2 = <>[2,16,192]{2,1,0} convert(clamp) - Arg_5.6 = bf16[] parameter(6) - broadcast.5 = bf16[2,16,192]{2,1,0} broadcast(Arg_5.6), dimensions={} - convert.3 = bf16[2,16,192]{2,1,0} convert(convert.2) - multiply.2 = bf16[2,16,192]{2,1,0} multiply(convert.3, broadcast.5) - Arg_6.7 = <>[192,48]{1,0} parameter(2), sharding={devices=[4,1]<=[4]} - Arg_7.8 = bf16[] parameter(7) - broadcast.6 = bf16[192,48]{1,0} broadcast(Arg_7.8), dimensions={} - convert.4 = bf16[192,48]{1,0} convert(Arg_6.7) - multiply.3 = bf16[192,48]{1,0} multiply(convert.4, broadcast.6) - dot.6 = bf16[2,16,48]{2,1,0} dot(multiply.2, multiply.3), lhs_contracting_dims={2}, rhs_contracting_dims={0} - tuple.10 = (bf16[2,16,48]{2,1,0}) tuple(dot.6) - ROOT get-tuple-element.11 = bf16[2,16,48]{2,1,0} get-tuple-element(tuple.10), index=0, sharding={devices=[1,4,1]<=[4]} -} // main.12 +ENTRY main { + lhs = f8e4m3fn[2,16,48]{2,1,0} parameter(0), sharding={devices=[1,4,1]<=[4]} + rhs = f8e4m3fn[2,24,192]{2,1,0} parameter(1), sharding={devices=[1,1,4]<=[4]} + scale_lhs = bf16[] parameter(2) + scale_rhs = bf16[] parameter(3) + scale_lhs_bcast = bf16[2,16,48]{2,1,0} broadcast(scale_rhs), dimensions={} + scale_rhs_bcast = bf16[2,24,192]{2,1,0} broadcast(scale_lhs), dimensions={} + lhs_bf16 = bf16[2,16,48]{2,1,0} convert(lhs) + rhs_bf16 = bf16[2,24,192]{2,1,0} convert(rhs) + lhs_scaled = bf16[2,16,48]{2,1,0} multiply(scale_lhs_bcast, lhs_bf16) + rhs_scaled = bf16[2,24,192]{2,1,0} multiply(scale_rhs_bcast, rhs_bf16) + rhs_reshaped = bf16[48,192]{1,0} reshape(rhs_scaled) + dot = bf16[2,16,192]{2,1,0} dot(lhs_scaled, rhs_reshaped), lhs_contracting_dims={2}, rhs_contracting_dims={0} + ROOT custom-call = bf16[2,16,192]{2,1,0} custom-call(dot), custom_call_target="Sharding", sharding={devices=[1,1,4]<=[4]} +} // main )"; // Disable the dot merger pass which can prevent the creation of FP8 GEMM @@ -933,24 +950,61 @@ TEST_F(CollectiveOpsTestE2EWindowedNonWindowed, HloModule windowed_einsum_e2e_all_gather_multi_consumer_f8, entry_computation_layout={(f8e4m3fn[2,16,48]{2,1,0}, f8e4m3fn[48,192]{1,0}, f8e4m3fn[48,192]{1,0}, bf16[], bf16[], bf16[])->bf16[2,16,192]{2,1,0}}, allow_spmd_sharding_propagation_to_parameters={false,false,false,false}, num_partitions=4 ENTRY main { - rhs = f8e4m3fn[2,16,48]{2,1,0} parameter(0), sharding={devices=[1,4,1]<=[4]} - lhs0 = f8e4m3fn[48,192]{1,0} parameter(1), sharding={devices=[1,4]<=[4]} + lhs = f8e4m3fn[2,16,48]{2,1,0} parameter(0), sharding={devices=[1,4,1]<=[4]} + rhs0 = f8e4m3fn[48,192]{1,0} parameter(1), sharding={devices=[1,4]<=[4]} + scale_lhs = bf16[] parameter(3) + scale_rhs0 = bf16[] parameter(4) + scale_lhs_bcast = bf16[2,16,48]{2,1,0} broadcast(scale_lhs), dimensions={} + scale_rhs0_bcast = bf16[48,192]{1,0} broadcast(scale_rhs0), dimensions={} + lhs_bf16 = bf16[2,16,48]{2,1,0} convert(lhs) + rhs0_bf16 = bf16[48,192]{1,0} convert(rhs0) + lhs_scaled = bf16[2,16,48]{2,1,0} multiply(scale_lhs_bcast, lhs_bf16) + rhs0_scaled = bf16[48,192]{1,0} multiply(scale_rhs0_bcast, rhs0_bf16) + dot0 = bf16[2,16,192]{2,1,0} dot(lhs_scaled, rhs0_scaled), lhs_contracting_dims={2}, rhs_contracting_dims={0} + rhs1 = f8e4m3fn[48,192]{1,0} parameter(2), sharding={devices=[1,4]<=[4]} + scale_rhs1 = bf16[] parameter(5) + scale_rhs1_bcast = bf16[48,192]{1,0} broadcast(scale_rhs1), dimensions={} + rhs1_bf16 = bf16[48,192]{1,0} convert(rhs1) + rhs1_scaled = bf16[48,192]{1,0} multiply(scale_rhs1_bcast, rhs1_bf16) + dot1 = bf16[2,16,192]{2,1,0} dot(lhs_scaled, rhs1_scaled), lhs_contracting_dims={2}, rhs_contracting_dims={0} + ROOT add = bf16[2,16,192]{2,1,0} add(dot0, dot1) +} // main +)"; + + // Disable the dot merger pass which can prevent the creation of FP8 GEMM + // Custom Calls. + CollectiveOpsCompareWindowedNonWindowed(kModuleReplicatedStr, + /*disable_dot_merger=*/true); + + // Verify the creation of FP8 GEMM Custom Calls on Hopper and newer + // architectures. + DebugOptions opts = GetDebugOptionsForTest(); + opts.set_xla_gpu_threshold_for_windowed_einsum_mib(0); + opts.set_xla_gpu_multi_streamed_windowed_einsum(true); + opts.set_xla_gpu_graph_min_graph_size(200); + opts.set_xla_gpu_enable_triton_gemm(false); + opts.add_xla_disable_hlo_passes("dot-merger"); + CollectiveOpsVerifyF8Matmul(kModuleReplicatedStr, opts); +} + +TEST_F(CollectiveOpsTestE2EWindowedNonWindowed, + WindowedEinsumE2EReduceScatterF8) { + absl::string_view kModuleReplicatedStr = R"( +HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(f8e4m3fn[2,16,192]{2,1,0}, f8e4m3fn[192,48]{1,0}, bf16[], bf16[])->bf16[2,16,48]{2,1,0}}, allow_spmd_sharding_propagation_to_parameters={false,false,false,false}, num_partitions=4 + +ENTRY main { + lhs = f8e4m3fn[2,16,192]{2,1,0} parameter(0), sharding={devices=[1,1,4]<=[4]} + rhs = f8e4m3fn[192,48]{1,0} parameter(1), sharding={devices=[4,1]<=[4]} + scale_lhs = bf16[] parameter(2) scale_rhs = bf16[] parameter(3) - scale_lhs0 = bf16[] parameter(4) - scale_rhs_bcast = bf16[2,16,48]{2,1,0} broadcast(scale_rhs), dimensions={} - scale_lhs0_bcast = bf16[48,192]{1,0} broadcast(scale_lhs0), dimensions={} - rhs_bf16 = bf16[2,16,48]{2,1,0} convert(rhs) - lhs0_bf16 = bf16[48,192]{1,0} convert(lhs0) - rhs_scaled = bf16[2,16,48]{2,1,0} multiply(scale_rhs_bcast, rhs_bf16) - lhs0_scaled = bf16[48,192]{1,0} multiply(scale_lhs0_bcast, lhs0_bf16) - dot0 = bf16[2,16,192]{2,1,0} dot(rhs_scaled, lhs0_scaled), lhs_contracting_dims={2}, rhs_contracting_dims={0} - lhs1 = f8e4m3fn[48,192]{1,0} parameter(2), sharding={devices=[1,4]<=[4]} - scale_lhs1 = bf16[] parameter(5) - scale_lhs1_bcast = bf16[48,192]{1,0} broadcast(scale_lhs1), dimensions={} - lhs1_bf16 = bf16[48,192]{1,0} convert(lhs1) - lhs1_scaled = bf16[48,192]{1,0} multiply(scale_lhs1_bcast, lhs1_bf16) - dot1 = bf16[2,16,192]{2,1,0} dot(rhs_scaled, lhs1_scaled), lhs_contracting_dims={2}, rhs_contracting_dims={0} - ROOT add.8 = bf16[2,16,192]{2,1,0} add(dot0, dot1) + scale_lhs_bcast = bf16[2,16,192]{2,1,0} broadcast(scale_lhs), dimensions={} + scale_rhs_bcast = bf16[192,48]{1,0} broadcast(scale_rhs), dimensions={} + lhs_bf16 = bf16[2,16,192]{2,1,0} convert(lhs) + rhs_bf16 = bf16[192,48]{1,0} convert(rhs) + lhs_scaled = bf16[2,16,192]{2,1,0} multiply(scale_lhs_bcast, lhs_bf16) + rhs_scaled = bf16[192,48]{1,0} multiply(scale_rhs_bcast, rhs_bf16) + dot = bf16[2,16,48]{2,1,0} dot(lhs_scaled, rhs_scaled), lhs_contracting_dims={2}, rhs_contracting_dims={0} + ROOT custom-call = bf16[2,16,48]{2,1,0} custom-call(dot), custom_call_target="Sharding", sharding={devices=[1,4,1]<=[4]} } // main )"; From 5d5e0be377835a06941a4291a3c8b221724ad200 Mon Sep 17 00:00:00 2001 From: Shu Wang Date: Fri, 27 Sep 2024 00:41:47 -0700 Subject: [PATCH 354/483] PR #17330: Add stride for amax_o/s for fp8 cudnn fused attention Imported from GitHub PR https://github.com/openxla/xla/pull/17330 As per requirement of cudnn graph API, the amax_s and amax_o has to be set stride. Otherwise, the following error will be hit. ``` xla/service/gpu/tests/gpu_fused_mha_test.cc:1348 Value of: RunAndCompareTwoModules(hlo_string, hlo_string_ref, ErrorSpec{1e-2, 1e-2}) Actual: false (INTERNAL: Tensor 'sdpa_fp8::Amax_O' strides not set. in xla/stream_executor/cuda/cuda_dnn.cc(8232): 'graph_.validate()' ) Copybara import of the project: -- 01c0ede92cfba4bc80263ae51cdcb7880b381daf by shuw : Add strides for amax_o/s at graph building which is required by cudnn-fe. Add tests for bnth and btnh layouts. -- 16b83a2c7a85f0a0371f1ef4edbec2f1a2f27b9b by Shu Wang : Split into multiple lines. -- 77a8e91e7edd339a6935c5772752a5166e585118 by shuw : Improve after review 1 Merging this change closes #17330 PiperOrigin-RevId: 679474160 --- .../service/gpu/tests/gpu_fused_mha_test.cc | 519 +++++++++++++++--- .../xla/xla/stream_executor/cuda/cuda_dnn.cc | 2 + 2 files changed, 430 insertions(+), 91 deletions(-) diff --git a/third_party/xla/xla/service/gpu/tests/gpu_fused_mha_test.cc b/third_party/xla/xla/service/gpu/tests/gpu_fused_mha_test.cc index b0e2d9c86c95a9..666cc1a041e8eb 100644 --- a/third_party/xla/xla/service/gpu/tests/gpu_fused_mha_test.cc +++ b/third_party/xla/xla/service/gpu/tests/gpu_fused_mha_test.cc @@ -1263,94 +1263,7 @@ class FlashAttentionBMMScaleSlidingWindowMaskSoftmaxBMM } }; -class FlashAttentionBMMScalePaddingMaskSoftmaxBMMF8 - : public MultiHeadedAttentionTest { - protected: - void TestImpl_Flash_Attention_Inference_BMM1_NoMask_Softmax_BMM2_F8() { - if (skip_reason_) GTEST_SKIP() << *skip_reason_; - if (GetDnnVersionInfoOrDefault(backend().default_stream_executor()) < - se::dnn::VersionInfo(9, 1, 0)) { - GTEST_SKIP() << "Flash Attention requires cuDNN >= 9.1.0."; - } - XlaBuilder builder(TestName()); - std::string hlo_string_ref = - R"( - HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(bf16[4,16,4,16]{3,2,1,0}, bf16[4,16,4,16]{3,2,1,0}, bf16[4,16,4,16]{3,2,1,0})->bf16[4,4,16,16]{3,1,2,0}}, allow_spmd_sharding_propagation_to_parameters={true,true,true}, allow_spmd_sharding_propagation_to_output={true} - clip.33 { - Arg_2.36 = bf16[] parameter(2) - broadcast.39 = bf16[4,16,4,16]{3,2,1,0} broadcast(Arg_2.36), dimensions={} - Arg_1.35 = bf16[] parameter(1) - broadcast.37 = bf16[4,16,4,16]{3,2,1,0} broadcast(Arg_1.35), dimensions={} - Arg_0.34 = bf16[4,16,4,16]{3,2,1,0} parameter(0) - maximum.38 = bf16[4,16,4,16]{3,2,1,0} maximum(broadcast.37, Arg_0.34) - ROOT minimum.40 = bf16[4,16,4,16]{3,2,1,0} minimum(broadcast.39, maximum.38) - } // clip.33 - ENTRY main.106 { - Arg_0.1 = bf16[4,16,4,16]{3,2,1,0} parameter(0) - constant.6 = bf16[] constant(1) - broadcast.7 = bf16[4,16,4,16]{3,2,1,0} broadcast(constant.6), dimensions={} - divide.8 = bf16[4,16,4,16]{3,2,1,0} divide(Arg_0.1, broadcast.7) - constant.5 = bf16[] constant(-448) - constant.4 = bf16[] constant(448) - call.17 = bf16[4,16,4,16]{3,2,1,0} call(divide.8, constant.5, constant.4), to_apply=clip.33 - convert.18 = f8e4m3fn[4,16,4,16]{3,2,1,0} convert(call.17) - convert.19 = bf16[4,16,4,16]{3,2,1,0} convert(convert.18) - Arg_1.2 = bf16[4,16,4,16]{3,2,1,0} parameter(1) - divide.20 = bf16[4,16,4,16]{3,2,1,0} divide(Arg_1.2, broadcast.7) - call.29 = bf16[4,16,4,16]{3,2,1,0} call(divide.20, constant.5, constant.4), to_apply=clip.33 - convert.30 = f8e4m3fn[4,16,4,16]{3,2,1,0} convert(call.29) - convert.31 = bf16[4,16,4,16]{3,2,1,0} convert(convert.30) - Arg_2.3 = bf16[4,16,4,16]{3,2,1,0} parameter(2) - divide.32 = bf16[4,16,4,16]{3,2,1,0} divide(Arg_2.3, broadcast.7) - call.41 = bf16[4,16,4,16]{3,2,1,0} call(divide.32, constant.5, constant.4), to_apply=clip.33 - convert.42 = f8e4m3fn[4,16,4,16]{3,2,1,0} convert(call.41) - convert.43 = bf16[4,16,4,16]{3,2,1,0} convert(convert.42) - custom-call.4.0 = (bf16[4,4,16,16]{3,1,2,0}, u8[16]{0}) custom-call(convert.19, convert.31, convert.43), custom_call_target="__cudnn$fmhaSoftmax", operand_layout_constraints={bf16[4,16,4,16]{3,2,1,0}, bf16[4,16,4,16]{3,2,1,0}, bf16[4,16,4,16]{3,2,1,0}}, api_version=API_VERSION_STATUS_RETURNING, backend_config={"operation_queue_id": "0", "wait_on_operation_queues": [], "cudnn_fmha_backend_config": {"algorithm": {"algo_id": "0", "math_type": "TENSOR_OP_MATH", "tuning_knobs": {"17": "1", "24": "0"}, "is_cudnn_frontend": true, "workspace_size": "0"}, "fmha_scale": 1.0, "dropout_rate": 0.0, "intermediate_tensor_shape": {"element_type": "BF16", "dimensions": ["4", "4", "16", "16"], "tuple_shapes": [], "layout": {"dim_level_types": [], "dim_unique": [], "dim_ordered": [], "minor_to_major": ["3", "2", "1", "0"], "tiles": [], "element_size_in_bits": "0", "memory_space": "0", "index_primitive_type": "PRIMITIVE_TYPE_INVALID", "pointer_primitive_type": "PRIMITIVE_TYPE_INVALID", "dynamic_shape_metadata_prefix_bytes": "0"}, "is_dynamic_dimension": [false, false, false, false]}, "seed": 42, "is_flash_attention": true, "mask_type": "NO_MASK", "bmm1_dot_dimension_numbers": {"lhs_contracting_dimensions": ["3"], "rhs_contracting_dimensions": ["3"], "lhs_batch_dimensions": ["0", "2"], "rhs_batch_dimensions": ["0", "2"]}, "bmm2_dot_dimension_numbers": {"lhs_contracting_dimensions": ["3"], "rhs_contracting_dimensions": ["1"], "lhs_batch_dimensions": ["0", "1"], "rhs_batch_dimensions": ["0", "2"]}}} - ROOT get-tuple-element.5.0 = bf16[4,4,16,16]{3,1,2,0} get-tuple-element(custom-call.4.0), index=0 - } // main.106 - )"; // NOLINT - std::string hlo_string = R"( - HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(bf16[4,16,4,16]{3,2,1,0}, bf16[4,16,4,16]{3,2,1,0}, bf16[4,16,4,16]{3,2,1,0})->bf16[4,4,16,16]{3,1,2,0}}, allow_spmd_sharding_propagation_to_parameters={true,true,true}, allow_spmd_sharding_propagation_to_output={true} - clip.33 { - Arg_2.36 = bf16[] parameter(2) - broadcast.39 = bf16[4,16,4,16]{3,2,1,0} broadcast(Arg_2.36), dimensions={} - Arg_1.35 = bf16[] parameter(1) - broadcast.37 = bf16[4,16,4,16]{3,2,1,0} broadcast(Arg_1.35), dimensions={} - Arg_0.34 = bf16[4,16,4,16]{3,2,1,0} parameter(0) - maximum.38 = bf16[4,16,4,16]{3,2,1,0} maximum(broadcast.37, Arg_0.34) - ROOT minimum.40 = bf16[4,16,4,16]{3,2,1,0} minimum(broadcast.39, maximum.38) - } // clip.33 - ENTRY main.106 { - constant.99 = f32[] constant(1) - broadcast.99 = f32[1,1,1,1]{3,2,1,0} broadcast(constant.99), dimensions={} - Arg_0.1 = bf16[4,16,4,16]{3,2,1,0} parameter(0) - constant.6 = bf16[] constant(1) - broadcast.7 = bf16[4,16,4,16]{3,2,1,0} broadcast(constant.6), dimensions={} - divide.8 = bf16[4,16,4,16]{3,2,1,0} divide(Arg_0.1, broadcast.7) - constant.5 = bf16[] constant(-448) - constant.4 = bf16[] constant(448) - call.17 = bf16[4,16,4,16]{3,2,1,0} call(divide.8, constant.5, constant.4), to_apply=clip.33 - convert.18 = f8e4m3fn[4,16,4,16]{3,2,1,0} convert(call.17) - convert.19 = bf16[4,16,4,16]{3,2,1,0} convert(convert.18) - Arg_1.2 = bf16[4,16,4,16]{3,2,1,0} parameter(1) - divide.20 = bf16[4,16,4,16]{3,2,1,0} divide(Arg_1.2, broadcast.7) - call.29 = bf16[4,16,4,16]{3,2,1,0} call(divide.20, constant.5, constant.4), to_apply=clip.33 - convert.30 = f8e4m3fn[4,16,4,16]{3,2,1,0} convert(call.29) - convert.31 = bf16[4,16,4,16]{3,2,1,0} convert(convert.30) - Arg_2.3 = bf16[4,16,4,16]{3,2,1,0} parameter(2) - divide.32 = bf16[4,16,4,16]{3,2,1,0} divide(Arg_2.3, broadcast.7) - call.41 = bf16[4,16,4,16]{3,2,1,0} call(divide.32, constant.5, constant.4), to_apply=clip.33 - convert.42 = f8e4m3fn[4,16,4,16]{3,2,1,0} convert(call.41) - convert.43 = bf16[4,16,4,16]{3,2,1,0} convert(convert.42) - custom-call.21.0 = (f8e4m3fn[4,4,16,16]{3,1,2,0}, f32[1,1,1,1]{3,2,1,0}, f32[1,1,1,1]{3,2,1,0}, u8[16]{0}) custom-call(convert.18, convert.30, convert.42, broadcast.99, broadcast.99, /*index=5*/broadcast.99, broadcast.99, broadcast.99, broadcast.99), custom_call_target="__cudnn$fmhaSoftmaxF8", operand_layout_constraints={f8e4m3fn[4,16,4,16]{3,2,1,0}, f8e4m3fn[4,16,4,16]{3,2,1,0}, f8e4m3fn[4,16,4,16]{3,2,1,0}, f32[1,1,1,1]{3,2,1,0}, f32[1,1,1,1]{3,2,1,0}, f32[1,1,1,1]{3,2,1,0}, f32[1,1,1,1]{3,2,1,0}, f32[1,1,1,1]{3,2,1,0}, f32[1,1,1,1]{3,2,1,0}} - get-tuple-element.5.0 = f8e4m3fn[4,4,16,16]{3,1,2,0} get-tuple-element(custom-call.21.0), index=0 - ROOT out = bf16[4,4,16,16]{3,1,2,0} convert(get-tuple-element.5.0) - } // main.106 - )"; // NOLINT - EXPECT_TRUE(RunAndCompareTwoModules(hlo_string, hlo_string_ref, - ErrorSpec{1e-2, 1e-2})); - } -}; +class FlashAttentionBMMScaleSoftmaxBMMF8 : public MultiHeadedAttentionTest {}; class FlashAttentionBMMScaleSoftmaxDropoutBMM : public MultiHeadedAttentionTest { @@ -1465,10 +1378,434 @@ XLA_TEST_F(FlashAttentionBMMScaleSlidingWindowMaskSoftmaxBMM, bfloat16>(); // NOLINT } +absl::string_view GetModuleFlashAttentionBMMScaleSoftmaxBMMCommonRef() { + static constexpr absl::string_view hlo_text = + R"( + HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(bf16[4,16,4,16]{3,2,1,0}, bf16[4,16,4,16]{3,2,1,0}, bf16[4,16,4,16]{3,2,1,0})->bf16[4,16,4,16]{3,2,1,0}}, allow_spmd_sharding_propagation_to_parameters={true,true,true} + + clip.33 { + Arg_2.36 = bf16[] parameter(2) + broadcast.39 = bf16[4,16,4,16]{3,2,1,0} broadcast(Arg_2.36), dimensions={} + Arg_1.35 = bf16[] parameter(1) + broadcast.37 = bf16[4,16,4,16]{3,2,1,0} broadcast(Arg_1.35), dimensions={} + Arg_0.34 = bf16[4,16,4,16]{3,2,1,0} parameter(0) + maximum.38 = bf16[4,16,4,16]{3,2,1,0} maximum(broadcast.37, Arg_0.34) + ROOT minimum.40 = bf16[4,16,4,16]{3,2,1,0} minimum(broadcast.39, maximum.38) + } + + ENTRY main.106 { + Arg_0.1 = bf16[4,16,4,16]{3,2,1,0} parameter(0) + Arg_1.2 = bf16[4,16,4,16]{3,2,1,0} parameter(1) + Arg_2.3 = bf16[4,16,4,16]{3,2,1,0} parameter(2) + + constant.6 = bf16[] constant(1) + broadcast.7 = bf16[4,16,4,16]{3,2,1,0} broadcast(constant.6), dimensions={} + + divide.8 = bf16[4,16,4,16]{3,2,1,0} divide(Arg_0.1, broadcast.7) + call.17 = bf16[4,16,4,16]{3,2,1,0} call(divide.8, bf16[] constant(-448), bf16[] constant(448)), to_apply=clip.33 + convert.18 = f8e4m3fn[4,16,4,16]{3,2,1,0} convert(call.17) + convert.19 = bf16[4,16,4,16]{3,2,1,0} convert(convert.18) + + divide.20 = bf16[4,16,4,16]{3,2,1,0} divide(Arg_1.2, broadcast.7) + call.29 = bf16[4,16,4,16]{3,2,1,0} call(divide.20, bf16[] constant(-448), bf16[] constant(448)), to_apply=clip.33 + convert.30 = f8e4m3fn[4,16,4,16]{3,2,1,0} convert(call.29) + convert.31 = bf16[4,16,4,16]{3,2,1,0} convert(convert.30) + + divide.32 = bf16[4,16,4,16]{3,2,1,0} divide(Arg_2.3, broadcast.7) + call.41 = bf16[4,16,4,16]{3,2,1,0} call(divide.32, bf16[] constant(-448), bf16[] constant(448)), to_apply=clip.33 + convert.42 = f8e4m3fn[4,16,4,16]{3,2,1,0} convert(call.41) + convert.43 = bf16[4,16,4,16]{3,2,1,0} convert(convert.42) + )"; + return hlo_text; +} + +absl::string_view GetModuleFlashAttentionBMMScaleSoftmaxBMMCommonF8() { + static constexpr absl::string_view hlo_text = R"( + HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(bf16[4,16,4,16]{3,2,1,0}, bf16[4,16,4,16]{3,2,1,0}, bf16[4,16,4,16]{3,2,1,0})->bf16[4,16,4,16]{3,2,1,0}}, allow_spmd_sharding_propagation_to_parameters={true,true,true} + clip.33 { + Arg_2.36 = bf16[] parameter(2) + broadcast.39 = bf16[4,16,4,16]{3,2,1,0} broadcast(Arg_2.36), dimensions={} + Arg_1.35 = bf16[] parameter(1) + broadcast.37 = bf16[4,16,4,16]{3,2,1,0} broadcast(Arg_1.35), dimensions={} + Arg_0.34 = bf16[4,16,4,16]{3,2,1,0} parameter(0) + maximum.38 = bf16[4,16,4,16]{3,2,1,0} maximum(broadcast.37, Arg_0.34) + ROOT minimum.40 = bf16[4,16,4,16]{3,2,1,0} minimum(broadcast.39, maximum.38) + } // clip.33 + ENTRY main.106 { + constant.99 = f32[] constant(1) + broadcast.99 = f32[1,1,1,1]{3,2,1,0} broadcast(constant.99), dimensions={} + Arg_0.1 = bf16[4,16,4,16]{3,2,1,0} parameter(0) + constant.6 = bf16[] constant(1) + broadcast.7 = bf16[4,16,4,16]{3,2,1,0} broadcast(constant.6), dimensions={} + divide.8 = bf16[4,16,4,16]{3,2,1,0} divide(Arg_0.1, broadcast.7) + constant.5 = bf16[] constant(-448) + constant.4 = bf16[] constant(448) + call.17 = bf16[4,16,4,16]{3,2,1,0} call(divide.8, constant.5, constant.4), to_apply=clip.33 + convert.18 = f8e4m3fn[4,16,4,16]{3,2,1,0} convert(call.17) + convert.19 = bf16[4,16,4,16]{3,2,1,0} convert(convert.18) + Arg_1.2 = bf16[4,16,4,16]{3,2,1,0} parameter(1) + divide.20 = bf16[4,16,4,16]{3,2,1,0} divide(Arg_1.2, broadcast.7) + call.29 = bf16[4,16,4,16]{3,2,1,0} call(divide.20, constant.5, constant.4), to_apply=clip.33 + convert.30 = f8e4m3fn[4,16,4,16]{3,2,1,0} convert(call.29) + convert.31 = bf16[4,16,4,16]{3,2,1,0} convert(convert.30) + Arg_2.3 = bf16[4,16,4,16]{3,2,1,0} parameter(2) + divide.32 = bf16[4,16,4,16]{3,2,1,0} divide(Arg_2.3, broadcast.7) + call.41 = bf16[4,16,4,16]{3,2,1,0} call(divide.32, constant.5, constant.4), to_apply=clip.33 + convert.42 = f8e4m3fn[4,16,4,16]{3,2,1,0} convert(call.41) + convert.43 = bf16[4,16,4,16]{3,2,1,0} convert(convert.42) + )"; + return hlo_text; +} // BMM1 - Scale - Softmax - BMM2 fp8 -XLA_TEST_F(FlashAttentionBMMScalePaddingMaskSoftmaxBMMF8, - Flash_Attention_Inference_BMM1_NoMask_Softmax_BMM2_F8) { - TestImpl_Flash_Attention_Inference_BMM1_NoMask_Softmax_BMM2_F8(); +XLA_TEST_F(FlashAttentionBMMScaleSoftmaxBMMF8, + Flash_Attention_Inference_BMM1_NoMask_Softmax_BMM2_BNTH_F8) { + if (skip_reason_) GTEST_SKIP() << *skip_reason_; + if (GetDnnVersionInfoOrDefault(backend().default_stream_executor()) < + se::dnn::VersionInfo(9, 1, 0)) { + GTEST_SKIP() << "Flash Attention requires cuDNN >= 9.1.0."; + } + XlaBuilder builder(TestName()); + std::string ref_bnth = R"( + custom-call.4.0 = ( + bf16[4,4,16,16]{3,1,2,0}, + u8[0]{0} + ) custom-call( + convert.19, + convert.31, + convert.43 + ), + custom_call_target="__cudnn$fmhaSoftmax", + operand_layout_constraints={ + bf16[4,16,4,16]{3,2,1,0}, + bf16[4,16,4,16]{3,2,1,0}, + bf16[4,16,4,16]{3,2,1,0} + }, + api_version=API_VERSION_STATUS_RETURNING, + backend_config={ + "operation_queue_id": "0", + "wait_on_operation_queues": [], + "cudnn_fmha_backend_config": { + "algorithm": { + "algo_id": "0", + "math_type": "TENSOR_OP_MATH", + "tuning_knobs": { + "17": "1", + "24": "0" + }, + "is_cudnn_frontend": true, + "workspace_size": "0" + }, + "fmha_scale": 0.75, + "dropout_rate": 0.0, + "intermediate_tensor_shape": { + "element_type": "BF16", + "dimensions": ["4", "4", "16", "16"], + "tuple_shapes": [], + "layout": { + "dim_level_types": [], + "dim_unique": [], + "dim_ordered": [], + "minor_to_major": ["3", "2", "1", "0"], + "tiles": [], + "element_size_in_bits": "0", + "memory_space": "0", + "index_primitive_type": "PRIMITIVE_TYPE_INVALID", + "pointer_primitive_type": "PRIMITIVE_TYPE_INVALID", + "dynamic_shape_metadata_prefix_bytes": "0" + }, + "is_dynamic_dimension": [false, false, false, false] + }, + "seed": 42, + "is_flash_attention": true, + "mask_type": "NO_MASK", + "sliding_window_length": 0, + "bmm1_dot_dimension_numbers": { + "lhs_contracting_dimensions": ["3"], + "rhs_contracting_dimensions": ["3"], + "lhs_batch_dimensions": ["0", "2"], + "rhs_batch_dimensions": ["0", "2"] + }, + "bmm2_dot_dimension_numbers": { + "lhs_contracting_dimensions": ["3"], + "rhs_contracting_dimensions": ["1"], + "lhs_batch_dimensions": ["0", "1"], + "rhs_batch_dimensions": ["0", "2"] + } + } + } + get-tuple-element.5.0 = bf16[4,4,16,16]{3,1,2,0} get-tuple-element(custom-call.4.0), index=0 + ROOT transpose.7 = bf16[4,16,4,16]{3,2,1,0} transpose(get-tuple-element.5.0), dimensions={0,2,1,3} + } +)"; + + std::string fp8_bnth = R"( + custom-call.21.0 = ( + f8e4m3fn[4,4,16,16]{3,1,2,0}, + f32[1,1,1,1]{3,2,1,0}, + f32[1,1,1,1]{3,2,1,0}, + u8[16]{0} + ) custom-call( + convert.18, + convert.30, + convert.42, + broadcast.99, + broadcast.99, + /*index=5*/broadcast.99, + broadcast.99, + broadcast.99, + broadcast.99 + ), + custom_call_target="__cudnn$fmhaSoftmaxF8", + operand_layout_constraints={ + f8e4m3fn[4,16,4,16]{3,2,1,0}, + f8e4m3fn[4,16,4,16]{3,2,1,0}, + f8e4m3fn[4,16,4,16]{3,2,1,0}, + f32[1,1,1,1]{3,2,1,0}, + f32[1,1,1,1]{3,2,1,0}, + f32[1,1,1,1]{3,2,1,0}, + f32[1,1,1,1]{3,2,1,0}, + f32[1,1,1,1]{3,2,1,0}, + f32[1,1,1,1]{3,2,1,0} + }, + api_version=API_VERSION_STATUS_RETURNING, + backend_config={ + "operation_queue_id": "0", + "wait_on_operation_queues": [], + "cudnn_fmha_backend_config": { + "algorithm": { + "algo_id": "0", + "math_type": "TENSOR_OP_MATH", + "tuning_knobs": { + "17": "1", + "24": "0" + }, + "is_cudnn_frontend": true, + "workspace_size": "0" + }, + "fmha_scale": 0.75, + "intermediate_tensor_shape": { + "element_type": "BF16", + "dimensions": ["4", "4", "16", "16"], + "tuple_shapes": [], + "layout": { + "dim_level_types": [], + "dim_unique": [], + "dim_ordered": [], + "minor_to_major": ["3", "2", "1", "0"], + "tiles": [], + "element_size_in_bits": "0", + "memory_space": "0", + "index_primitive_type": "PRIMITIVE_TYPE_INVALID", + "pointer_primitive_type": "PRIMITIVE_TYPE_INVALID", + "dynamic_shape_metadata_prefix_bytes": "0" + }, + "is_dynamic_dimension": [false, false, false, false] + }, + "is_flash_attention": true, + "mask_type": "NO_MASK", + "bmm1_dot_dimension_numbers": { + "lhs_contracting_dimensions": ["3"], + "rhs_contracting_dimensions": ["3"], + "lhs_batch_dimensions": ["0", "2"], + "rhs_batch_dimensions": ["0", "2"] + }, + "bmm2_dot_dimension_numbers": { + "lhs_contracting_dimensions": ["3"], + "rhs_contracting_dimensions": ["1"], + "lhs_batch_dimensions": ["0", "1"], + "rhs_batch_dimensions": ["0", "2"] + } + } + } + get-tuple-element.5.0 = f8e4m3fn[4,4,16,16]{3,1,2,0} get-tuple-element(custom-call.21.0), index=0 + transpose.26 = f8e4m3fn[4,16,4,16]{3,2,1,0} transpose(get-tuple-element.5.0), dimensions={0,2,1,3} + ROOT out = bf16[4,16,4,16]{3,2,1,0} convert(transpose.26) + } + )"; + + std::string hlo_string = + std::string(GetModuleFlashAttentionBMMScaleSoftmaxBMMCommonF8()) + + fp8_bnth; + std::string hlo_string_ref = + std::string(GetModuleFlashAttentionBMMScaleSoftmaxBMMCommonRef()) + + ref_bnth; + EXPECT_TRUE(RunAndCompareTwoModules(hlo_string, hlo_string_ref, + ErrorSpec{5e-2, 5e-2})); +} + +XLA_TEST_F(FlashAttentionBMMScaleSoftmaxBMMF8, + Flash_Attention_Inference_BMM1_NoMask_Softmax_BMM2_BTNH_F8) { + if (skip_reason_) GTEST_SKIP() << *skip_reason_; + if (GetDnnVersionInfoOrDefault(backend().default_stream_executor()) < + se::dnn::VersionInfo(9, 1, 0)) { + GTEST_SKIP() << "Flash Attention requires cuDNN >= 9.1.0."; + } + XlaBuilder builder(TestName()); + + std::string ref_btnh = R"( + custom-call.4.0 = ( + bf16[4,16,4,16]{3,2,1,0}, + u8[0]{0} + ) custom-call( + convert.19, + convert.31, + convert.43 + ), + custom_call_target="__cudnn$fmhaSoftmax", + operand_layout_constraints={ + bf16[4,16,4,16]{3,2,1,0}, + bf16[4,16,4,16]{3,2,1,0}, + bf16[4,16,4,16]{3,2,1,0} + }, + api_version=API_VERSION_STATUS_RETURNING, + backend_config={ + "operation_queue_id": "0", + "wait_on_operation_queues": [], + "cudnn_fmha_backend_config": { + "algorithm": { + "algo_id": "0", + "math_type": "TENSOR_OP_MATH", + "tuning_knobs": { + "17": "1", + "24": "0" + }, + "is_cudnn_frontend": true, + "workspace_size": "0" + }, + "fmha_scale": 0.75, + "dropout_rate": 0.0, + "intermediate_tensor_shape": { + "element_type": "BF16", + "dimensions": ["4", "16", "4", "4"], + "tuple_shapes": [], + "layout": { + "dim_level_types": [], + "dim_unique": [], + "dim_ordered": [], + "minor_to_major": ["3", "2", "1", "0"], + "tiles": [], + "element_size_in_bits": "0", + "memory_space": "0", + "index_primitive_type": "PRIMITIVE_TYPE_INVALID", + "pointer_primitive_type": "PRIMITIVE_TYPE_INVALID", + "dynamic_shape_metadata_prefix_bytes": "0" + }, + "is_dynamic_dimension": [false, false, false, false] + }, + "seed": 42, + "is_flash_attention": true, + "mask_type": "NO_MASK", + "sliding_window_length": 0, + "bmm1_dot_dimension_numbers": { + "lhs_contracting_dimensions": ["3"], + "rhs_contracting_dimensions": ["3"], + "lhs_batch_dimensions": ["0", "1"], + "rhs_batch_dimensions": ["0", "1"] + }, + "bmm2_dot_dimension_numbers": { + "lhs_contracting_dimensions": ["3"], + "rhs_contracting_dimensions": ["2"], + "lhs_batch_dimensions": ["0", "1"], + "rhs_batch_dimensions": ["0", "1"] + } + } + } + ROOT get-tuple-element.5.0 = bf16[4,16,4,16]{3,2,1,0} get-tuple-element(custom-call.4.0), index=0 + } +)"; + + std::string fp8_btnh = R"( + custom-call.21.0 = ( + f8e4m3fn[4,16,4,16]{3,2,1,0}, + f32[1,1,1,1]{3,2,1,0}, + f32[1,1,1,1]{3,2,1,0}, + u8[16]{0} + ) custom-call( + convert.18, + convert.30, + convert.42, + broadcast.99, + broadcast.99, + /*index=5*/broadcast.99, + broadcast.99, + broadcast.99, + broadcast.99 + ), + custom_call_target="__cudnn$fmhaSoftmaxF8", + operand_layout_constraints={ + f8e4m3fn[4,16,4,16]{3,2,1,0}, + f8e4m3fn[4,16,4,16]{3,2,1,0}, + f8e4m3fn[4,16,4,16]{3,2,1,0}, + f32[1,1,1,1]{3,2,1,0}, + f32[1,1,1,1]{3,2,1,0}, + f32[1,1,1,1]{3,2,1,0}, + f32[1,1,1,1]{3,2,1,0}, + f32[1,1,1,1]{3,2,1,0}, + f32[1,1,1,1]{3,2,1,0} + }, + api_version=API_VERSION_STATUS_RETURNING, + backend_config={ + "operation_queue_id": "0", + "wait_on_operation_queues": [], + "cudnn_fmha_backend_config": { + "algorithm": { + "algo_id": "0", + "math_type": "TENSOR_OP_MATH", + "tuning_knobs": { + "17": "1", + "24": "0" + }, + "is_cudnn_frontend": true, + "workspace_size": "0" + }, + "fmha_scale": 0.75, + "intermediate_tensor_shape": { + "element_type": "BF16", + "dimensions": ["4", "16", "4", "4"], + "tuple_shapes": [], + "layout": { + "dim_level_types": [], + "dim_unique": [], + "dim_ordered": [], + "minor_to_major": ["3", "2", "1", "0"], + "tiles": [], + "element_size_in_bits": "0", + "memory_space": "0", + "index_primitive_type": "PRIMITIVE_TYPE_INVALID", + "pointer_primitive_type": "PRIMITIVE_TYPE_INVALID", + "dynamic_shape_metadata_prefix_bytes": "0" + }, + "is_dynamic_dimension": [false, false, false, false] + }, + "is_flash_attention": true, + "mask_type": "NO_MASK", + "bmm1_dot_dimension_numbers": { + "lhs_contracting_dimensions": ["3"], + "rhs_contracting_dimensions": ["3"], + "lhs_batch_dimensions": ["0", "1"], + "rhs_batch_dimensions": ["0", "1"] + }, + "bmm2_dot_dimension_numbers": { + "lhs_contracting_dimensions": ["3"], + "rhs_contracting_dimensions": ["2"], + "lhs_batch_dimensions": ["0", "1"], + "rhs_batch_dimensions": ["0", "1"] + } + } + } + get-tuple-element.5.0 = f8e4m3fn[4,16,4,16]{3,2,1,0} get-tuple-element(custom-call.21.0), index=0 + ROOT out = bf16[4,16,4,16]{3,2,1,0} convert(get-tuple-element.5.0) + } + )"; + + std::string hlo_string = + std::string(GetModuleFlashAttentionBMMScaleSoftmaxBMMCommonF8()) + + fp8_btnh; + std::string hlo_string_ref = + std::string(GetModuleFlashAttentionBMMScaleSoftmaxBMMCommonRef()) + + ref_btnh; + EXPECT_TRUE(RunAndCompareTwoModules(hlo_string, hlo_string_ref, + ErrorSpec{5e-2, 5e-2})); } // BMM1 - Scale - Softmax - BMM2 fp8 diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_dnn.cc b/third_party/xla/xla/stream_executor/cuda/cuda_dnn.cc index 29f7ddd1754df2..93bb5e46b14f38 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_dnn.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_dnn.cc @@ -5240,10 +5240,12 @@ absl::StatusOr GetCudnnFlashAttentionF8OperationGraph( .set_uid(next_uid()); amax_s->set_output(true) .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) .set_data_type(cudnn_frontend::DataType_t::FLOAT) .set_uid(next_uid()); amax_o->set_output(true) .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) .set_data_type(cudnn_frontend::DataType_t::FLOAT) .set_uid(next_uid()); From 5b1b7029573b20ce8a79d3f9b9fe6914ebb0219f Mon Sep 17 00:00:00 2001 From: Alan Kelly Date: Fri, 27 Sep 2024 01:36:35 -0700 Subject: [PATCH 355/483] In place dynamic reshape PiperOrigin-RevId: 679488996 --- tensorflow/lite/core/c/c_api_types.h | 7 +++++++ tensorflow/lite/core/subgraph.cc | 6 ++++-- tensorflow/lite/kernels/reshape.cc | 18 +++++++++++++++--- 3 files changed, 26 insertions(+), 5 deletions(-) diff --git a/tensorflow/lite/core/c/c_api_types.h b/tensorflow/lite/core/c/c_api_types.h index f0b76bde0258cb..79a00319709300 100644 --- a/tensorflow/lite/core/c/c_api_types.h +++ b/tensorflow/lite/core/c/c_api_types.h @@ -110,6 +110,13 @@ typedef enum TfLiteStatus { // TODO(b/250636993): Cancellation triggered by `SetCancellationFunction` // should also return this status code. kTfLiteCancelled = 8, + + // This status is returned by Prepare when the output shape cannot be + // determined but the size of the output tensor is known. For example, the + // output of reshape is always the same size as the input. This means that + // such ops may be + // done in place. + kTfLiteOutputShapeNotKnown = 9, } TfLiteStatus; /// Types supported by tensor diff --git a/tensorflow/lite/core/subgraph.cc b/tensorflow/lite/core/subgraph.cc index 0c5d20d0c02d34..dbd250364a3d82 100644 --- a/tensorflow/lite/core/subgraph.cc +++ b/tensorflow/lite/core/subgraph.cc @@ -1506,7 +1506,8 @@ TfLiteStatus Subgraph::PrepareOpsStartingAt( node_index); #endif // TF_LITE_TENSORFLOW_PROFILER const TfLiteStatus op_prepare_status = OpPrepare(registration, &node); - if (op_prepare_status != kTfLiteOk) { + if (op_prepare_status != kTfLiteOk && + op_prepare_status != kTfLiteOutputShapeNotKnown) { ReportOpError(&context_, node, registration, node_index, "failed to prepare"); return op_prepare_status; @@ -1517,7 +1518,8 @@ TfLiteStatus Subgraph::PrepareOpsStartingAt( // Discontinue if the node has dynamic outputs. Note that we don't // stop for dynamic temporary tensors since they won't affect the // sizes of other tensors in the graph. - if (HasDynamicTensor(context_, node.outputs, &dynamic_tensor_index_)) { + if (HasDynamicTensor(context_, node.outputs, &dynamic_tensor_index_) || + op_prepare_status == kTfLiteOutputShapeNotKnown) { has_dynamic_tensors_ = true; return kTfLiteOk; } diff --git a/tensorflow/lite/kernels/reshape.cc b/tensorflow/lite/kernels/reshape.cc index ff53ddb85be876..0c72d480fe1eaf 100644 --- a/tensorflow/lite/kernels/reshape.cc +++ b/tensorflow/lite/kernels/reshape.cc @@ -37,6 +37,7 @@ struct OpData { // This is to prevent incorrect results when mischievous users overwrite // output pointers with their own. const void* output_ptr; + bool output_shape_known = true; }; TfLiteIntArray* GetOutputShape(TfLiteContext*, TfLiteNode*); @@ -169,9 +170,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_OK(context, ResizeOutput(context, node)); } } else { - SetTensorToDynamic(output); + op_data->output_shape_known = false; + return kTfLiteOutputShapeNotKnown; } } + op_data->output_shape_known = true; return kTfLiteOk; } @@ -186,8 +189,15 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { // There are two ways in which the 'output' can be made dynamic: it could be // a string tensor, or its shape cannot be calculated during Prepare(). In // either case, we now have all the information to calculate its shape. - if (IsDynamicTensor(output)) { - TF_LITE_ENSURE_OK(context, ResizeOutput(context, node)); + if (output->type != kTfLiteString) { + if (!op_data->output_shape_known) { + if (output->data.data != input->data.data) { + // If the otuput cannot overwrite the input, then we have to set the + // tensor to dyanmic. + SetTensorToDynamic(output); + } + TF_LITE_ENSURE_OK(context, ResizeOutput(context, node)); + } } // Note that string tensors are always "dynamic" in the sense that their size @@ -197,6 +207,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { // reshape doesn't change the data, the output tensor needs exactly as many // bytes as the input tensor. if (output->type == kTfLiteString) { + SetTensorToDynamic(output); + TF_LITE_ENSURE_OK(context, ResizeOutput(context, node)); auto bytes_required = input->bytes; TfLiteTensorRealloc(bytes_required, output); output->bytes = bytes_required; From 945815149458cae54cd49b8324851601e6b9bb17 Mon Sep 17 00:00:00 2001 From: Adrian Kuegel Date: Fri, 27 Sep 2024 01:45:14 -0700 Subject: [PATCH 356/483] Avoid triggering of static_assert on MacOS. We have seen this issue before, and the fix was to explicitly check again for the condition in the "else" branch. Also fix the error message, it had a typo and still referenced the old name of the function. PiperOrigin-RevId: 679491609 --- .../xla/xla/tests/exhaustive/exhaustive_op_test_utils.cc | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/third_party/xla/xla/tests/exhaustive/exhaustive_op_test_utils.cc b/third_party/xla/xla/tests/exhaustive/exhaustive_op_test_utils.cc index c8e34d4ba8a880..c6dafe39b1f2ff 100644 --- a/third_party/xla/xla/tests/exhaustive/exhaustive_op_test_utils.cc +++ b/third_party/xla/xla/tests/exhaustive/exhaustive_op_test_utils.cc @@ -31,9 +31,10 @@ ExhaustiveOpTestTraits::FallbackErrorSpecGen() { } else if constexpr (N == 2) { return +[](NativeT, NativeT) { return ErrorSpec{}; }; } else { - static_assert(false, - "ExhaustieOpTestTraits::DefaultErrorSpecGen() is only " - "implemented for N == 1 and N == 2."); + static_assert( + N == 1 || N == 2, + "ExhaustiveOpTestTraits::FallbackErrorSpecGen() is only " + "implemented for N == 1 and N == 2."); } } From 99f101a69e17395d888d4f71806e6d25ae6c2cd4 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 27 Sep 2024 02:04:06 -0700 Subject: [PATCH 357/483] Update GraphDef version to 1998. PiperOrigin-RevId: 679497320 --- tensorflow/core/public/version.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index 6ca4bceb99ecd8..13944c256422d0 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -108,7 +108,7 @@ limitations under the License. #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0 -#define TF_GRAPH_DEF_VERSION 1997 // Updated: 2024/9/26 +#define TF_GRAPH_DEF_VERSION 1998 // Updated: 2024/9/27 // Checkpoint compatibility versions (the versions field in SavedSliceMeta). // From fdcab4fb476b3b810406897b5981cf00100e79b1 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 27 Sep 2024 02:04:13 -0700 Subject: [PATCH 358/483] compat: Update forward compatibility horizon to 2024-09-27 PiperOrigin-RevId: 679497361 --- tensorflow/python/compat/compat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index f8107f5d868557..b65e9c57c9c560 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -29,7 +29,7 @@ # This value changes every day with an automatic CL. It can be modified in code # via `forward_compatibility_horizon()` or with the environment variable # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date. -_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2024, 9, 26) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2024, 9, 27) _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS" _FORWARD_COMPATIBILITY_DATE_NUMBER = None From 856ae320e90055c6ab97894826c0bbc7118cbd94 Mon Sep 17 00:00:00 2001 From: Alexander Belyaev Date: Fri, 27 Sep 2024 02:30:33 -0700 Subject: [PATCH 359/483] [XLA:GPU] Add a verifier to IndexingMap and reuse it in IndexingMapAttr. PiperOrigin-RevId: 679504735 --- .../service/gpu/fusions/ir/xla_gpu_attrs.cc | 15 ++++++----- .../xla/xla/service/gpu/model/indexing_map.cc | 17 +++++++++++++ .../xla/xla/service/gpu/model/indexing_map.h | 3 +++ .../service/gpu/model/indexing_map_test.cc | 25 +++++++++++++++++++ 4 files changed, 52 insertions(+), 8 deletions(-) diff --git a/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_attrs.cc b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_attrs.cc index 5a2cd2b9a29584..8a0380b0706f75 100644 --- a/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_attrs.cc +++ b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_attrs.cc @@ -15,6 +15,7 @@ limitations under the License. #include #include +#include #include #include @@ -92,15 +93,13 @@ mlir::LogicalResult IndexingMapAttr::verify( mlir::AffineMap map, ArrayRef dim_vars, ArrayRef range_vars, ArrayRef> constraints, bool is_simplified) { - if (map.getNumDims() != dim_vars.size()) { - return emitError() << "dim size must match the number of dimensions in " - "the affine map"; + auto indexing_map = IndexingMap(map, dim_vars, range_vars, /*rt_vars=*/{}, + constraints, is_simplified); + std::stringstream ss; + if (!indexing_map.Verify(ss)) { + return emitError() << ss.str(); } - if (map.getNumSymbols() != range_vars.size()) { - return emitError() - << "range size must match the number of symbols in the affine map"; - } - return mlir::success(); + return success(); } IndexingMap IndexingMapAttr::getIndexingMap() const { diff --git a/third_party/xla/xla/service/gpu/model/indexing_map.cc b/third_party/xla/xla/service/gpu/model/indexing_map.cc index c1964b351b35bc..8e431976467734 100644 --- a/third_party/xla/xla/service/gpu/model/indexing_map.cc +++ b/third_party/xla/xla/service/gpu/model/indexing_map.cc @@ -1274,6 +1274,23 @@ IndexingMap operator*(const IndexingMap& lhs, const IndexingMap& rhs) { return ComposeIndexingMaps(lhs, rhs); } +bool IndexingMap::Verify(std::ostream& out) const { + if (IsUndefined()) { + return true; + } + if (affine_map_.getNumDims() != dim_vars_.size()) { + out << "dim size must match the number of dimensions in " + "the affine map"; + return false; + } + if (affine_map_.getNumSymbols() != range_vars_.size() + rt_vars_.size()) { + out << "range vars size + rt var size must match the number of " + "symbols in the affine map"; + return false; + } + return true; +} + // Simplification of IndexingMap has two main parts. // At first we optimized constraints to make the domain as small and simple as // possible. And only then we simplify the affine_map, because its diff --git a/third_party/xla/xla/service/gpu/model/indexing_map.h b/third_party/xla/xla/service/gpu/model/indexing_map.h index 81ca0e9d03588c..36780ddd1841e2 100644 --- a/third_party/xla/xla/service/gpu/model/indexing_map.h +++ b/third_party/xla/xla/service/gpu/model/indexing_map.h @@ -317,6 +317,9 @@ class IndexingMap { absl::Span symbol_upper_bounds, bool is_simplified = false); + // Returns true if the indexing map is valid. + bool Verify(std::ostream& out) const; + // Returns true if the map was simplified. bool Simplify(); diff --git a/third_party/xla/xla/service/gpu/model/indexing_map_test.cc b/third_party/xla/xla/service/gpu/model/indexing_map_test.cc index e5fe84b49bfe6c..8fd369b5b6e596 100644 --- a/third_party/xla/xla/service/gpu/model/indexing_map_test.cc +++ b/third_party/xla/xla/service/gpu/model/indexing_map_test.cc @@ -19,6 +19,8 @@ limitations under the License. #include #include #include +#include +#include #include #include #include @@ -86,6 +88,29 @@ TEST_F(IndexingMapTest, VariableKind) { EXPECT_EQ(ToString(VariableKind::kBlockZ), "block_z"); } +TEST_F(IndexingMapTest, VerifyDimensions) { + auto indexing_map = IndexingMap::FromTensorSizes( + ParseAffineMap("(d0) -> (d0)", &mlir_context_), + /*dim_upper_bounds=*/{10, 10}, /*symbol_upper_bounds=*/{}); + + std::stringstream ss; + EXPECT_FALSE(indexing_map.Verify(ss)); + EXPECT_EQ(ss.str(), + "dim size must match the number of dimensions in the affine map"); +} + +TEST_F(IndexingMapTest, VerifySymbols) { + auto indexing_map = IndexingMap::FromTensorSizes( + ParseAffineMap("(d0) -> (d0)", &mlir_context_), + /*dim_upper_bounds=*/{10}, /*symbol_upper_bounds=*/{10}); + + std::stringstream ss; + EXPECT_FALSE(indexing_map.Verify(ss)); + EXPECT_EQ(ss.str(), + "range vars size + rt var size must match the number of symbols in " + "the affine map"); +} + TEST_F(IndexingMapTest, RTVar) { auto zero_dim_map = AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0, &mlir_context_); From cc97551632b6f9c21923a6cfa34371d676c239e1 Mon Sep 17 00:00:00 2001 From: Henning Becker Date: Fri, 27 Sep 2024 03:14:15 -0700 Subject: [PATCH 360/483] Remove cuda_only_cc_library Since now we can exclude targets from building using tags, we won't need the `cuda_only_cc_library` rule anymore. This also required me to remove some wrongly added dependencies, notably I found several targets depending on `cublas_plugin`, even though those targets were not CUDA specific and shouldn't directly depend on CUDA-specific targets. I also found out that the `tsl_gpu_library` macro is not handling its `cuda_deps` attribute correctly. It was adding those dependencies both for ROCm and for CUDA. So this change is fixing that as well. PiperOrigin-RevId: 679516808 --- third_party/xla/xla/BUILD | 1 + third_party/xla/xla/lit.bzl | 7 +- third_party/xla/xla/pjrt/gpu/BUILD | 1 - .../xla/xla/service/gpu/fusions/triton/BUILD | 3 - third_party/xla/xla/service/gpu/kernels/BUILD | 1 - .../xla/xla/stream_executor/build_defs.bzl | 30 ----- .../xla/xla/stream_executor/cuda/BUILD | 125 +++++++++++++----- third_party/xla/xla/stream_executor/gpu/BUILD | 15 ++- third_party/xla/xla/tsl/tsl.bzl | 6 +- 9 files changed, 116 insertions(+), 73 deletions(-) diff --git a/third_party/xla/xla/BUILD b/third_party/xla/xla/BUILD index 64553ebfafb0ee..a4376ed422b6c8 100644 --- a/third_party/xla/xla/BUILD +++ b/third_party/xla/xla/BUILD @@ -1347,6 +1347,7 @@ bzl_library( deps = [ "//xla/tsl:tsl_bzl", "@bazel_skylib//lib:paths", + "@local_tsl//tsl/platform/default:cuda_build_defs_bzl", ], ) diff --git a/third_party/xla/xla/lit.bzl b/third_party/xla/xla/lit.bzl index 5837c54ad81eab..5ac1cde98f1d8c 100644 --- a/third_party/xla/xla/lit.bzl +++ b/third_party/xla/xla/lit.bzl @@ -1,6 +1,7 @@ """Helper rules for writing LIT tests.""" load("@bazel_skylib//lib:paths.bzl", "paths") +load("@local_tsl//tsl/platform/default:cuda_build_defs.bzl", "if_cuda_is_configured") load("//xla/tsl:tsl.bzl", "if_cuda_tools", "if_google", "if_oss") def enforce_glob(files, **kwargs): @@ -209,7 +210,11 @@ def lit_test( srcs = tools, bin_dir = bin_dir, lib_dir = lib_dir, - deps = ["//xla/stream_executor/cuda:all_runtime"], + deps = if_cuda_is_configured( + [ + "//xla/stream_executor/cuda:all_runtime", + ], + ), visibility = ["//visibility:private"], **kwargs ) diff --git a/third_party/xla/xla/pjrt/gpu/BUILD b/third_party/xla/xla/pjrt/gpu/BUILD index 9c6949f79661f8..8bb23efbab7c82 100644 --- a/third_party/xla/xla/pjrt/gpu/BUILD +++ b/third_party/xla/xla/pjrt/gpu/BUILD @@ -315,7 +315,6 @@ xla_test( "//xla/pjrt:pjrt_client", "//xla/pjrt:pjrt_compiler", "//xla/service:hlo_parser", - "//xla/stream_executor/cuda:cublas_plugin", "//xla/tests:literal_test_util", "@com_google_absl//absl/status", "@com_google_googletest//:gtest", diff --git a/third_party/xla/xla/service/gpu/fusions/triton/BUILD b/third_party/xla/xla/service/gpu/fusions/triton/BUILD index 4a16dfb0fe8dca..bcc4cf88c3b7f2 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/BUILD +++ b/third_party/xla/xla/service/gpu/fusions/triton/BUILD @@ -221,7 +221,6 @@ xla_test( "//xla/service/gpu/model:tiled_hlo_instruction_or_computation", "//xla/service/gpu/tests:gpu_codegen_test", "//xla/stream_executor:device_description", - "//xla/stream_executor/cuda:cublas_plugin", "//xla/tests:filecheck", "//xla/tests:verified_hlo_module", "//xla/tests:xla_internal_test_main", # fixdeps: keep @@ -269,7 +268,6 @@ xla_test( "//xla/service/gpu/model:tiled_hlo_instruction_or_computation", "//xla/service/gpu/tests:gpu_codegen_test", "//xla/stream_executor:device_description", - "//xla/stream_executor/cuda:cublas_plugin", "//xla/tests:verified_hlo_module", "//xla/tests:xla_internal_test_main", # fixdeps: keep "//xla/tsl/lib/core:status_test_util", @@ -401,7 +399,6 @@ xla_test( "//xla/hlo/ir:hlo", "//xla/service/gpu/tests:gpu_codegen_test", "//xla/stream_executor:device_description", - "//xla/stream_executor/cuda:cublas_plugin", "//xla/tests:xla_internal_test_main", # fixdeps: keep "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/strings", diff --git a/third_party/xla/xla/service/gpu/kernels/BUILD b/third_party/xla/xla/service/gpu/kernels/BUILD index 5c517476917932..1e5bc9d4847126 100644 --- a/third_party/xla/xla/service/gpu/kernels/BUILD +++ b/third_party/xla/xla/service/gpu/kernels/BUILD @@ -242,7 +242,6 @@ xla_test( "//xla/stream_executor", "//xla/stream_executor:platform", "//xla/stream_executor:platform_manager", - "//xla/stream_executor/cuda:cuda_platform", "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/random", "@com_google_absl//absl/strings", diff --git a/third_party/xla/xla/stream_executor/build_defs.bzl b/third_party/xla/xla/stream_executor/build_defs.bzl index 109872e2a0f0df..3204b886c651ff 100644 --- a/third_party/xla/xla/stream_executor/build_defs.bzl +++ b/third_party/xla/xla/stream_executor/build_defs.bzl @@ -1,6 +1,5 @@ """Configurations for StreamExecutor builds""" -load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda_is_configured") load( "@local_config_rocm//rocm:build_defs.bzl", _if_cuda_or_rocm = "if_cuda_or_rocm", @@ -64,34 +63,5 @@ def gpu_only_cc_library(name, tags = [], **kwargs): target_compatible_with = kwargs.get("target_compatible_with"), ) -def cuda_only_cc_library(name, tags = [], **kwargs): - """A library that only gets compiled when CUDA is configured, otherwise it's an empty target. - - Args: - name: Name of the target - tags: Tags being applied to the implementation target - **kwargs: Accepts all arguments that a `cc_library` would also accept - """ - if not native.package_name().startswith("xla/stream_executor"): - fail("cuda_only_cc_library may only be used in `xla/stream_executor/...`.") - - cc_library( - name = "%s_non_cuda" % name, - tags = ["manual"], - ) - cc_library( - name = "%s_cuda_only" % name, - tags = tags + ["manual", "cuda-only"], - **kwargs - ) - native.alias( - name = name, - actual = if_cuda_is_configured(":%s_cuda_only" % name, ":%s_non_cuda" % name), - visibility = kwargs.get("visibility"), - compatible_with = kwargs.get("compatible_with"), - restricted_to = kwargs.get("restricted_to"), - target_compatible_with = kwargs.get("target_compatible_with"), - ) - def stream_executor_build_defs_bzl_deps(): return [] diff --git a/third_party/xla/xla/stream_executor/cuda/BUILD b/third_party/xla/xla/stream_executor/cuda/BUILD index 529a0b6197935f..26c394a144300c 100644 --- a/third_party/xla/xla/stream_executor/cuda/BUILD +++ b/third_party/xla/xla/stream_executor/cuda/BUILD @@ -10,20 +10,14 @@ load( ) load( "@local_tsl//tsl/platform/default:cuda_build_defs.bzl", - "if_cuda_is_configured", "if_cuda_newer_than", ) load( "//xla:xla.bzl", "xla_cc_test", ) -load( - "//xla/service/gpu:build_defs.bzl", - "gpu_kernel_library", -) load( "//xla/stream_executor:build_defs.bzl", - "cuda_only_cc_library", "stream_executor_friends", "tf_additional_cuda_platform_deps", "tf_additional_cudnn_plugin_copts", @@ -87,10 +81,14 @@ cc_library( deps = ["//xla/stream_executor:platform"], ) -cuda_only_cc_library( +cc_library( name = "cuda_platform", srcs = ["cuda_platform.cc"], hdrs = ["cuda_platform.h"], + tags = [ + "cuda-only", + "gpu", + ], visibility = ["//visibility:public"], deps = [ @@ -123,10 +121,14 @@ cuda_only_cc_library( alwayslink = True, # Registers itself with the PlatformManager. ) -cuda_only_cc_library( +cc_library( name = "cuda_diagnostics", srcs = ["cuda_diagnostics.cc"], hdrs = ["cuda_diagnostics.h"], + tags = [ + "cuda-only", + "gpu", + ], deps = [ "//xla/stream_executor/gpu:gpu_diagnostics_header", "@com_google_absl//absl/container:inlined_vector", @@ -157,10 +159,14 @@ cc_library( ), ) -cuda_only_cc_library( +cc_library( name = "cuda_driver", srcs = ["cuda_driver.cc"], hdrs = ["cuda_driver.h"], + tags = [ + "cuda-only", + "gpu", + ], deps = [ ":cuda_diagnostics", # buildcleaner: keep ":cuda_status", @@ -198,10 +204,14 @@ cuda_only_cc_library( ], ) -cuda_only_cc_library( +cc_library( name = "cuda_status", srcs = ["cuda_status.cc"], hdrs = ["cuda_status.h"], + tags = [ + "cuda-only", + "gpu", + ], deps = [ "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/status", @@ -210,10 +220,14 @@ cuda_only_cc_library( ], ) -cuda_only_cc_library( +cc_library( name = "cuda_runtime", srcs = ["cuda_runtime.cc"], hdrs = ["cuda_runtime.h"], + tags = [ + "cuda-only", + "gpu", + ], deps = [ "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/log", @@ -225,10 +239,13 @@ cuda_only_cc_library( ], ) -cuda_only_cc_library( +cc_library( name = "cuda_collectives", hdrs = ["cuda_collectives.h"], - tags = ["gpu"], + tags = [ + "cuda-only", + "gpu", + ], deps = if_nccl( [":cuda_collectives_impl"], [":cuda_collectives_stub"], @@ -246,6 +263,7 @@ cc_library( "cuda_collectives.h", ], tags = [ + "cuda-only", "gpu", "manual", ], @@ -318,12 +336,16 @@ xla_test( ], ) -cuda_only_cc_library( +cc_library( name = "cublas_lt_header", hdrs = [ "cuda_blas_lt.h", "cuda_blas_utils.h", ], + tags = [ + "cuda-only", + "gpu", + ], visibility = ["//visibility:public"], deps = [ "//xla:types", @@ -340,7 +362,7 @@ cuda_only_cc_library( ], ) -cuda_only_cc_library( +cc_library( name = "cublas_plugin", srcs = [ "cuda_blas.cc", @@ -350,6 +372,10 @@ cuda_only_cc_library( "cuda_blas.h", "cuda_blas_lt.h", ], + tags = [ + "cuda-only", + "gpu", + ], visibility = ["//visibility:public"], deps = [ ":cuda_blas_utils", @@ -402,10 +428,14 @@ cuda_only_cc_library( alwayslink = True, ) -cuda_only_cc_library( +cc_library( name = "cuda_blas_utils", srcs = ["cuda_blas_utils.cc"], hdrs = ["cuda_blas_utils.h"], + tags = [ + "cuda-only", + "gpu", + ], deps = [ "//xla/stream_executor", "//xla/stream_executor:blas", @@ -418,10 +448,14 @@ cuda_only_cc_library( ], ) -cuda_only_cc_library( +cc_library( name = "cufft_plugin", srcs = ["cuda_fft.cc"], hdrs = ["cuda_fft.h"], + tags = [ + "cuda-only", + "gpu", + ], visibility = ["//visibility:public"], deps = [ ":cuda_helpers", @@ -447,13 +481,17 @@ cuda_only_cc_library( alwayslink = True, ) -gpu_kernel_library( +cuda_library( name = "delay_kernel_cuda", srcs = [ "delay_kernel.h", "delay_kernel_cuda.cu.cc", ], - tags = ["manual"], + # copybara:uncomment compatible_with = ["//buildenv/target:non_prod"], + tags = [ + "cuda-only", + "gpu", + ], visibility = internal_visibility([ "//xla/stream_executor:__subpackages__", ]), @@ -467,11 +505,15 @@ gpu_kernel_library( ], ) -cuda_only_cc_library( +cc_library( name = "cudnn_plugin", srcs = ["cuda_dnn.cc"], hdrs = ["cuda_dnn.h"], copts = tf_additional_cudnn_plugin_copts(), + tags = [ + "cuda-only", + "gpu", + ], visibility = ["//visibility:public"], deps = [ ":cuda_diagnostics", @@ -524,10 +566,14 @@ cuda_only_cc_library( alwayslink = True, ) -cuda_only_cc_library( +cc_library( name = "cuda_kernel", srcs = ["cuda_kernel.cc"], hdrs = ["cuda_kernel.h"], + tags = [ + "cuda-only", + "gpu", + ], deps = [ "//xla/stream_executor", "//xla/stream_executor/gpu:gpu_driver_header", @@ -574,22 +620,28 @@ cuda_library( ], ) -# TODO(leary) we likely need to canonicalize/eliminate this. cc_library( name = "cuda_helpers", - textual_hdrs = if_cuda_is_configured(["cuda_helpers.h"]), - deps = if_cuda_is_configured([ + hdrs = ["cuda_helpers.h"], + tags = [ + "cuda-only", + "gpu", + ], + deps = [ "//xla/stream_executor/gpu:gpu_helpers_header", - "@local_config_cuda//cuda:cuda_headers", - ]) + [ "@com_google_absl//absl/log:check", + "@local_config_cuda//cuda:cuda_headers", ], ) -cuda_only_cc_library( +cc_library( name = "cuda_event", srcs = ["cuda_event.cc"], hdrs = ["cuda_event.h"], + tags = [ + "cuda-only", + "gpu", + ], deps = [ ":cuda_driver", "//xla/stream_executor:event", @@ -825,7 +877,7 @@ xla_cc_test( ], ) -cuda_only_cc_library( +cc_library( name = "cuda_asm_compiler", srcs = ["cuda_asm_compiler.cc"], hdrs = ["cuda_asm_compiler.h"], @@ -844,6 +896,10 @@ cuda_only_cc_library( "@cuda_nvcc//:ptxas", ]), # copybara:comment_end + tags = [ + "cuda-only", + "gpu", + ], visibility = internal_visibility([ "//third_party/py/jax:__subpackages__", "//tensorflow/compiler/mlir/tools/kernel_gen:__subpackages__", @@ -889,7 +945,7 @@ cuda_only_cc_library( ], ) -cuda_only_cc_library( +cc_library( name = "cuda_executor", srcs = [ "cuda_executor.cc", @@ -898,6 +954,10 @@ cuda_only_cc_library( hdrs = [ "cuda_executor.h", ], + tags = [ + "cuda-only", + "gpu", + ], deps = [ ":cuda_collectives", ":cuda_diagnostics", @@ -908,6 +968,7 @@ cuda_only_cc_library( ":cuda_runtime", ":cuda_status", ":cuda_version_parser", + ":delay_kernel_cuda", "//xla/stream_executor", "//xla/stream_executor:blas", "//xla/stream_executor:command_buffer", @@ -954,13 +1015,17 @@ cuda_only_cc_library( "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:platform_port", "@local_tsl//tsl/platform:statusor", - ] + if_cuda_is_configured([":delay_kernel_cuda"]), + ], alwayslink = True, ) cc_library( name = "all_runtime", copts = tsl_copts(), + tags = [ + "cuda-only", + "gpu", + ], visibility = ["//visibility:public"], deps = [ ":cublas_plugin", diff --git a/third_party/xla/xla/stream_executor/gpu/BUILD b/third_party/xla/xla/stream_executor/gpu/BUILD index 7c0bdb7d3c0361..df510e57075b35 100644 --- a/third_party/xla/xla/stream_executor/gpu/BUILD +++ b/third_party/xla/xla/stream_executor/gpu/BUILD @@ -498,7 +498,11 @@ cc_library( "redzone_allocator_kernel.h", "redzone_allocator_kernel_cuda.cc", ], - tags = ["manual"], + tags = [ + "cuda-only", + "gpu", + "manual", + ], deps = [ ":gpu_asm_opts", "//xla/stream_executor", @@ -596,9 +600,12 @@ xla_test( cc_library( name = "gpu_cudamallocasync_allocator", - srcs = if_cuda_is_configured(["gpu_cudamallocasync_allocator.cc"]), - hdrs = if_cuda_is_configured(["gpu_cudamallocasync_allocator.h"]), - tags = ["gpu"], + srcs = ["gpu_cudamallocasync_allocator.cc"], + hdrs = ["gpu_cudamallocasync_allocator.h"], + tags = [ + "cuda-only", + "gpu", + ], deps = [ ":gpu_init_impl", "//xla/stream_executor:stream_executor_h", diff --git a/third_party/xla/xla/tsl/tsl.bzl b/third_party/xla/xla/tsl/tsl.bzl index 0cf769ddf4eadb..cca0a3001a307d 100644 --- a/third_party/xla/xla/tsl/tsl.bzl +++ b/third_party/xla/xla/tsl/tsl.bzl @@ -21,7 +21,6 @@ load( load( "@local_config_rocm//rocm:build_defs.bzl", "if_rocm", - "if_rocm_is_configured", ) load( "@local_tsl//tsl/platform:rules_cc.bzl", @@ -367,7 +366,7 @@ def tsl_gpu_library(deps = None, cuda_deps = None, copts = tsl_copts(), **kwargs cuda_deps = [] kwargs["features"] = kwargs.get("features", []) + ["-use_header_modules"] - deps = deps + if_cuda_or_rocm(cuda_deps) + deps = deps + if_cuda(cuda_deps) if "default_copts" in kwargs: copts = kwargs["default_copts"] + copts kwargs.pop("default_copts", None) @@ -375,7 +374,8 @@ def tsl_gpu_library(deps = None, cuda_deps = None, copts = tsl_copts(), **kwargs deps = deps + if_cuda([ clean_dep("//xla/tsl/cuda:cudart"), "@local_config_cuda//cuda:cuda_headers", - ]) + if_rocm_is_configured([ + ]) + if_rocm([ + "@local_config_rocm//rocm:hip", "@local_config_rocm//rocm:rocm_headers", ]), copts = (copts + if_cuda(["-DGOOGLE_CUDA=1", "-DNV_CUDNN_DISABLE_EXCEPTION"]) + if_rocm(["-DTENSORFLOW_USE_ROCM=1"]) + if_xla_available(["-DTENSORFLOW_USE_XLA=1"]) + if_mkl(["-DINTEL_MKL=1"]) + if_enable_mkl(["-DENABLE_MKL"]) + if_tensorrt(["-DGOOGLE_TENSORRT=1"])), From 3bc6f0fc345c9a6d1b925ff12e4bb48fd2790c5b Mon Sep 17 00:00:00 2001 From: Henning Becker Date: Fri, 27 Sep 2024 03:30:18 -0700 Subject: [PATCH 361/483] Remove unnecessary forward declaration PiperOrigin-RevId: 679521113 --- third_party/xla/xla/stream_executor/gpu/gpu_executor.h | 2 -- 1 file changed, 2 deletions(-) diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_executor.h b/third_party/xla/xla/stream_executor/gpu/gpu_executor.h index d515e1099488fb..f149e242e45fc0 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_executor.h +++ b/third_party/xla/xla/stream_executor/gpu/gpu_executor.h @@ -38,8 +38,6 @@ limitations under the License. namespace stream_executor { -class StreamExecutor; - namespace gpu { class GpuStream; From 2c672a09aafcbc5235a99270d8727552ac5343bd Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 27 Sep 2024 05:06:49 -0700 Subject: [PATCH 362/483] Automated Code Change PiperOrigin-RevId: 679545364 --- third_party/xla/xla/mlir_hlo/transforms/bufferize_pass.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/xla/xla/mlir_hlo/transforms/bufferize_pass.cc b/third_party/xla/xla/mlir_hlo/transforms/bufferize_pass.cc index 1e810cff21a555..318834a93e1e59 100644 --- a/third_party/xla/xla/mlir_hlo/transforms/bufferize_pass.cc +++ b/third_party/xla/xla/mlir_hlo/transforms/bufferize_pass.cc @@ -236,7 +236,7 @@ struct OneShotBufferizePass opts.bufferizeFunctionBoundaries = true; opts.functionArgTypeConverterFn = [=](TensorType tensorType, Attribute memorySpace, func::FuncOp funcOp, - const bufferization::BufferizationOptions& options) { + const bufferization::BufferizationOptions& /*options*/) { // Functions created by fusion outlining should have fully dynamic // layout. All other functions (for now only "main") gets static // layout. From 6f02ab87364abc5a672f46e88db651a8873fc190 Mon Sep 17 00:00:00 2001 From: Johannes Reifferscheid Date: Fri, 27 Sep 2024 05:21:06 -0700 Subject: [PATCH 363/483] Reland fix to multi-row reduction triggering. Apparently there was no actual breakage, just a numerically unstable model. Reverts c5589c74cd20582d49fc8a6e34a06e77360aba38 PiperOrigin-RevId: 679548584 --- .../xla/service/gpu/fusions/reduction_mlir.cc | 160 +++++++++++------- .../xla/service/gpu/fusions/reduction_mlir.h | 17 +- .../fusions/tests/reduce_multirow/f16_v4.hlo | 22 +++ 3 files changed, 140 insertions(+), 59 deletions(-) create mode 100644 third_party/xla/xla/service/gpu/fusions/tests/reduce_multirow/f16_v4.hlo diff --git a/third_party/xla/xla/service/gpu/fusions/reduction_mlir.cc b/third_party/xla/xla/service/gpu/fusions/reduction_mlir.cc index b2db8fa5cd1730..d195d40fe7e0bd 100644 --- a/third_party/xla/xla/service/gpu/fusions/reduction_mlir.cc +++ b/third_party/xla/xla/service/gpu/fusions/reduction_mlir.cc @@ -770,32 +770,11 @@ llvm::SmallVector MlirSmallColumnReductionFusion::EmitReduction( shared_rows_ / 2); } -std::unique_ptr CreateMlirReductionFusion( - const HloFusionAnalysis& analysis) { - auto* hero_reduction = analysis.FindHeroReduction(); - CHECK_NE(hero_reduction, nullptr); - ReductionDimensions reduction_dimensions = - GetReductionKindAndContiguousComponents(*hero_reduction); - if (reduction_dimensions.is_row_reduction) { - if (RowReductionGetRowsPerWarp( - reduction_dimensions.dimensions[kRowMinorReduced]) > 1) { - return std::make_unique(analysis); - } - return std::make_unique(analysis); - } - - if (WarpSize() % reduction_dimensions.dimensions[kColMinorKept] == 0) { - return std::make_unique(analysis); - } - return std::make_unique(analysis); -} - MlirRowReductionFusion::MlirRowReductionFusion( const HloFusionAnalysis& analysis) : MlirReductionFusion(analysis) { CHECK(reduction_dimensions_.is_row_reduction); Vector3 shape = reduction_dimensions_.dimensions; - CHECK_EQ(RowReductionGetRowsPerWarp(shape[kRowMinorReduced]), 1); constexpr int64_t kMinorReducedElementsPerThread = 16; int64_t num_threads_kept = 1; @@ -931,33 +910,28 @@ llvm::SmallVector MlirRowReductionFusion::EmitReduction( } MlirMultiRowReductionFusion::MlirMultiRowReductionFusion( - const HloFusionAnalysis& analysis) + const HloFusionAnalysis& analysis, int vector_size) : MlirReductionFusion(analysis) { CHECK(reduction_dimensions_.is_row_reduction); Vector3 shape = reduction_dimensions_.dimensions; - int64_t rows_per_warp = RowReductionGetRowsPerWarp(shape[kRowMinorReduced]); input_shape_ = {shape[0], shape[1], shape[2]}; - CHECK_GT(rows_per_warp, 1); - - auto compute_block_size = [&](int vector_size) { - int64_t num_threads_reduced = shape[kRowMinorReduced] / vector_size; - - constexpr int64_t kThreadsPerBlockTarget = 256; - int64_t kept_size = reduction_dimensions_.dimensions[kRowKept]; - int64_t num_threads_kept = 1; - if (kept_size * num_threads_reduced <= kThreadsPerBlockTarget) { - num_threads_kept = kept_size; - } else { - num_threads_kept = kThreadsPerBlockTarget / num_threads_reduced; - } - num_threads_ = {num_threads_kept, num_threads_reduced}; - tile_sizes_per_thread_ = {shape[0], vector_size}; - num_blocks_ = {CeilOfRatio(input_shape_[kRowKept], num_threads_kept)}; - }; + num_threads_ = GetNumThreads(reduction_dimensions_, vector_size); + num_blocks_ = {GetNumBlocks(reduction_dimensions_, num_threads_)}; + tile_sizes_per_thread_ = {shape[0], vector_size}; +} - // Compute the launch grid without vectorization. We use the results to - // compute the vectorized launch grid. - compute_block_size(1); +std::unique_ptr MlirMultiRowReductionFusion::TryCreate( + const HloFusionAnalysis& analysis) { + auto* hero_reduction = analysis.FindHeroReduction(); + CHECK_NE(hero_reduction, nullptr); + auto reduction_dimensions = + GetReductionKindAndContiguousComponents(*hero_reduction); + auto shape = reduction_dimensions.dimensions; + // This emitter only supports reductions where the reduced dimension is a + // power of 2. + if (shape[kRowMinorReduced] & (shape[kRowMinorReduced] - 1)) { + return nullptr; + } // Normally, we only consider input types for vectorization. However, in // multi-row reductions, the input:output ratio is much higher, so we consider @@ -965,24 +939,75 @@ MlirMultiRowReductionFusion::MlirMultiRowReductionFusion( int smallest_input_or_output_bits = std::min(analysis.input_output_info().smallest_input_dtype_bits, analysis.input_output_info().smallest_output_dtype_bits); + int largest_input_or_output_bits = + std::max(analysis.input_output_info().smallest_input_dtype_bits, + analysis.input_output_info().smallest_output_dtype_bits); - // This vector size is always valid: we know that the reduced dimension is a - // power of 2, since otherwise RowReductionGetRowsPerWarp would have - // returned 1. // Our codegen can't currently deal with vectorization across rows, so we // limit the vector size to the size of the row. Note that this emitter // essentially reverts to the loop emitter in this case, except for side // outputs. - int vector_size = std::min(static_cast(input_shape_[kRowMinorReduced]), - 32 / smallest_input_or_output_bits); - - // We target 8 warps per block, which means there could be up to 8 blocks per - // SM, but we have no good way of knowing. In practice, enabling vectorization - // for decently sized reductions at least does not hurt. - if (num_blocks_.front() > analysis.device_info().core_count() && - vector_size > 1) { - compute_block_size(vector_size); + int vector_size = std::min(static_cast(shape[kRowMinorReduced]), + 64 / smallest_input_or_output_bits); + + // Very large vector sizes for f32 can be detrimental, so we limit the vector + // size to 16 bytes if we have some >= 32 bit inputs or outputs. This is still + // a bit on the high side, but remember that we also have very small inputs + // or outputs. + if (largest_input_or_output_bits >= 32) { + vector_size = std::min(128 / largest_input_or_output_bits, vector_size); + } + + // The reduced dimension must fit into a single warp. + if (shape[kRowMinorReduced] > WarpSize() * vector_size) { + return nullptr; + } + + // At the very least, we want to have work for every SM. + // TODO(jreiffers): This limit is probably too low: if we have as many blocks + // as SMs, we'll only run about 8 warps per SM, so occupancy will be very low. + // Further measurements are needed to refine this heuristic. + int64_t min_desired_blocks = analysis.device_info().core_count(); + while (vector_size > 1 && + GetNumBlocks(reduction_dimensions, + GetNumThreads(reduction_dimensions, vector_size)) < + min_desired_blocks) { + vector_size /= 2; } + + // Check again that the reduced dimension fits after potentially reducing the + // vector size. + if (shape[kRowMinorReduced] > WarpSize() * vector_size) { + return nullptr; + } + + return std::make_unique(analysis, vector_size); +} + +absl::InlinedVector MlirMultiRowReductionFusion::GetNumThreads( + const ReductionDimensions& reduction_dimensions, int vector_size) { + int64_t num_threads_reduced = + reduction_dimensions.dimensions[kRowMinorReduced] / vector_size; + + constexpr int64_t kThreadsPerBlockTarget = 256; + int64_t kept_size = reduction_dimensions.dimensions[kRowKept]; + int64_t num_threads_kept = 1; + if (kept_size * num_threads_reduced <= kThreadsPerBlockTarget) { + num_threads_kept = kept_size; + } else { + num_threads_kept = kThreadsPerBlockTarget / num_threads_reduced; + } + return {num_threads_kept, num_threads_reduced}; +} + +int64_t MlirMultiRowReductionFusion::GetNumBlocks( + const ReductionDimensions& reduction_dimensions, + const absl::InlinedVector& num_threads) { + CHECK_EQ(num_threads.size(), 2) + << "Expected num_threads to contain the number of threads in the {kept, " + "reduced} dimensions."; + return CeilOfRatio(reduction_dimensions.dimensions[kRowKept], + num_threads.front()); } IndexingMap MlirMultiRowReductionFusion::ComputeReductionInputIndexing( @@ -1013,8 +1038,7 @@ IndexingMap MlirMultiRowReductionFusion::ComputeReductionOutputIndexing( : mlir::getAffineDimExpr(3, ctx); IndexingMap projected_index = GetIndexingMap(block_id * num_threads_[0] + thread_id[0]); - projected_index.AddConstraint(thread_id[1] % (WarpSize() / GetRowsPerWarp()), - {0, 0}); + projected_index.AddConstraint(thread_id[1] % num_threads_[1], {0, 0}); // We don't need a constraint on the loop dimensions, because they are removed // by GetIndexingMap (since they don't show up in the output index // computation). @@ -1034,10 +1058,30 @@ llvm::SmallVector MlirMultiRowReductionFusion::EmitReduction( auto per_thread = state.EmitPerThreadElements(group_id, inits, state.FusionOutputs()); auto reduced = state.ShuffleReduce(reductions, per_thread.reduction_scalars, - WarpSize() / 2 / GetRowsPerWarp()); + num_threads_[1] / 2); return EvaluateEpilogue(reduced, std::move(per_thread.outputs), state, group_id, /*symbol_values=*/{}); } +std::unique_ptr CreateMlirReductionFusion( + const HloFusionAnalysis& analysis) { + auto* hero_reduction = analysis.FindHeroReduction(); + CHECK_NE(hero_reduction, nullptr); + ReductionDimensions reduction_dimensions = + GetReductionKindAndContiguousComponents(*hero_reduction); + if (reduction_dimensions.is_row_reduction) { + auto multi_row_emitter = MlirMultiRowReductionFusion::TryCreate(analysis); + if (multi_row_emitter != nullptr) { + return multi_row_emitter; + } + return std::make_unique(analysis); + } + + if (WarpSize() % reduction_dimensions.dimensions[kColMinorKept] == 0) { + return std::make_unique(analysis); + } + return std::make_unique(analysis); +} + } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/fusions/reduction_mlir.h b/third_party/xla/xla/service/gpu/fusions/reduction_mlir.h index 838729254070ac..db0fbd2b45c31c 100644 --- a/third_party/xla/xla/service/gpu/fusions/reduction_mlir.h +++ b/third_party/xla/xla/service/gpu/fusions/reduction_mlir.h @@ -16,6 +16,7 @@ limitations under the License. #define XLA_SERVICE_GPU_FUSIONS_REDUCTION_MLIR_H_ #include +#include #include #include #include @@ -168,9 +169,23 @@ class MlirRowReductionFusion : public MlirReductionFusion { class MlirMultiRowReductionFusion : public MlirReductionFusion { public: - explicit MlirMultiRowReductionFusion(const HloFusionAnalysis& analysis); + MlirMultiRowReductionFusion(const HloFusionAnalysis& analysis, + int vector_size); + + // Attempts to create a multi-row reduction emitter for the given analysis. + // Returns nullptr if the fusion is not supported. + static std::unique_ptr TryCreate( + const HloFusionAnalysis& analysis); protected: + // Returns the number of {kept, reduced} threads for the given reduction and + // vector size. + static absl::InlinedVector GetNumThreads( + const ReductionDimensions& reduction_dimensions, int vector_size); + static int64_t GetNumBlocks( + const ReductionDimensions& reduction_dimensions, + const absl::InlinedVector& num_threads); + int GetRowsPerWarp() const; llvm::SmallVector EmitReduction( int group_id, EmitterState& state) const override; diff --git a/third_party/xla/xla/service/gpu/fusions/tests/reduce_multirow/f16_v4.hlo b/third_party/xla/xla/service/gpu/fusions/tests/reduce_multirow/f16_v4.hlo new file mode 100644 index 00000000000000..d2e3928bdfd564 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/tests/reduce_multirow/f16_v4.hlo @@ -0,0 +1,22 @@ +// RUN: fusion_to_mlir %s | mlir_fusions_opt -xla-gpu-test-optimize \ +// RUN: -xla-gpu-test-transform-loops | FileCheck %s + +// The reference implementation reduces in f64, so we need a larger tolerance. +// RUN: test_correctness %s --bijection_inputs=reduce:0 \ +// RUN: --bijection_outputs=reduce --abs_error_bound=0.005 --rel_error_bound=0.005 + +add { + lhs = f16[] parameter(0) + rhs = f16[] parameter(1) + ROOT add = f16[] add(lhs, rhs) +} + +fusion { + param_0 = f16[2048,64] parameter(0) + c = f16[] constant(0) + ROOT reduce = f16[2048] reduce(param_0, c), dimensions={1}, to_apply=add +} + +// If unvectorized, this would be a regular row reduction. However, since we can +// vectorize to size four, we can emit this as a multi-row reduction. +// CHECK: vector.transfer_read {{.*}} vector<4xf16> From 6498610f10f2f32b897cc9bc61141fa92d080287 Mon Sep 17 00:00:00 2001 From: Alan Kelly Date: Fri, 27 Sep 2024 05:37:45 -0700 Subject: [PATCH 364/483] Reset output pointer to input if ResizeOutput has been called PiperOrigin-RevId: 679553116 --- tensorflow/lite/kernels/reshape.cc | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tensorflow/lite/kernels/reshape.cc b/tensorflow/lite/kernels/reshape.cc index 0c72d480fe1eaf..83ce8727c03e8e 100644 --- a/tensorflow/lite/kernels/reshape.cc +++ b/tensorflow/lite/kernels/reshape.cc @@ -195,8 +195,13 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { // If the otuput cannot overwrite the input, then we have to set the // tensor to dyanmic. SetTensorToDynamic(output); + TF_LITE_ENSURE_OK(context, ResizeOutput(context, node)); + } else { + TF_LITE_ENSURE_OK(context, ResizeOutput(context, node)); + // The output pointer was set to zero during the call to ResizeTensor. + // Since the output aliases the input, set it back. + output->data.data = input->data.data; } - TF_LITE_ENSURE_OK(context, ResizeOutput(context, node)); } } From f8aa34037b4bf4cb51f06caf29b9712d1c3cba8c Mon Sep 17 00:00:00 2001 From: Henning Becker Date: Fri, 27 Sep 2024 06:59:15 -0700 Subject: [PATCH 365/483] Reverts cc97551632b6f9c21923a6cfa34371d676c239e1 PiperOrigin-RevId: 679573999 --- third_party/xla/xla/BUILD | 1 - third_party/xla/xla/lit.bzl | 7 +- third_party/xla/xla/pjrt/gpu/BUILD | 1 + .../xla/xla/service/gpu/fusions/triton/BUILD | 3 + third_party/xla/xla/service/gpu/kernels/BUILD | 1 + .../xla/xla/stream_executor/build_defs.bzl | 30 +++++ .../xla/xla/stream_executor/cuda/BUILD | 125 +++++------------- third_party/xla/xla/stream_executor/gpu/BUILD | 15 +-- third_party/xla/xla/tsl/tsl.bzl | 6 +- 9 files changed, 73 insertions(+), 116 deletions(-) diff --git a/third_party/xla/xla/BUILD b/third_party/xla/xla/BUILD index a4376ed422b6c8..64553ebfafb0ee 100644 --- a/third_party/xla/xla/BUILD +++ b/third_party/xla/xla/BUILD @@ -1347,7 +1347,6 @@ bzl_library( deps = [ "//xla/tsl:tsl_bzl", "@bazel_skylib//lib:paths", - "@local_tsl//tsl/platform/default:cuda_build_defs_bzl", ], ) diff --git a/third_party/xla/xla/lit.bzl b/third_party/xla/xla/lit.bzl index 5ac1cde98f1d8c..5837c54ad81eab 100644 --- a/third_party/xla/xla/lit.bzl +++ b/third_party/xla/xla/lit.bzl @@ -1,7 +1,6 @@ """Helper rules for writing LIT tests.""" load("@bazel_skylib//lib:paths.bzl", "paths") -load("@local_tsl//tsl/platform/default:cuda_build_defs.bzl", "if_cuda_is_configured") load("//xla/tsl:tsl.bzl", "if_cuda_tools", "if_google", "if_oss") def enforce_glob(files, **kwargs): @@ -210,11 +209,7 @@ def lit_test( srcs = tools, bin_dir = bin_dir, lib_dir = lib_dir, - deps = if_cuda_is_configured( - [ - "//xla/stream_executor/cuda:all_runtime", - ], - ), + deps = ["//xla/stream_executor/cuda:all_runtime"], visibility = ["//visibility:private"], **kwargs ) diff --git a/third_party/xla/xla/pjrt/gpu/BUILD b/third_party/xla/xla/pjrt/gpu/BUILD index 8bb23efbab7c82..9c6949f79661f8 100644 --- a/third_party/xla/xla/pjrt/gpu/BUILD +++ b/third_party/xla/xla/pjrt/gpu/BUILD @@ -315,6 +315,7 @@ xla_test( "//xla/pjrt:pjrt_client", "//xla/pjrt:pjrt_compiler", "//xla/service:hlo_parser", + "//xla/stream_executor/cuda:cublas_plugin", "//xla/tests:literal_test_util", "@com_google_absl//absl/status", "@com_google_googletest//:gtest", diff --git a/third_party/xla/xla/service/gpu/fusions/triton/BUILD b/third_party/xla/xla/service/gpu/fusions/triton/BUILD index bcc4cf88c3b7f2..4a16dfb0fe8dca 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/BUILD +++ b/third_party/xla/xla/service/gpu/fusions/triton/BUILD @@ -221,6 +221,7 @@ xla_test( "//xla/service/gpu/model:tiled_hlo_instruction_or_computation", "//xla/service/gpu/tests:gpu_codegen_test", "//xla/stream_executor:device_description", + "//xla/stream_executor/cuda:cublas_plugin", "//xla/tests:filecheck", "//xla/tests:verified_hlo_module", "//xla/tests:xla_internal_test_main", # fixdeps: keep @@ -268,6 +269,7 @@ xla_test( "//xla/service/gpu/model:tiled_hlo_instruction_or_computation", "//xla/service/gpu/tests:gpu_codegen_test", "//xla/stream_executor:device_description", + "//xla/stream_executor/cuda:cublas_plugin", "//xla/tests:verified_hlo_module", "//xla/tests:xla_internal_test_main", # fixdeps: keep "//xla/tsl/lib/core:status_test_util", @@ -399,6 +401,7 @@ xla_test( "//xla/hlo/ir:hlo", "//xla/service/gpu/tests:gpu_codegen_test", "//xla/stream_executor:device_description", + "//xla/stream_executor/cuda:cublas_plugin", "//xla/tests:xla_internal_test_main", # fixdeps: keep "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/strings", diff --git a/third_party/xla/xla/service/gpu/kernels/BUILD b/third_party/xla/xla/service/gpu/kernels/BUILD index 1e5bc9d4847126..5c517476917932 100644 --- a/third_party/xla/xla/service/gpu/kernels/BUILD +++ b/third_party/xla/xla/service/gpu/kernels/BUILD @@ -242,6 +242,7 @@ xla_test( "//xla/stream_executor", "//xla/stream_executor:platform", "//xla/stream_executor:platform_manager", + "//xla/stream_executor/cuda:cuda_platform", "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/random", "@com_google_absl//absl/strings", diff --git a/third_party/xla/xla/stream_executor/build_defs.bzl b/third_party/xla/xla/stream_executor/build_defs.bzl index 3204b886c651ff..109872e2a0f0df 100644 --- a/third_party/xla/xla/stream_executor/build_defs.bzl +++ b/third_party/xla/xla/stream_executor/build_defs.bzl @@ -1,5 +1,6 @@ """Configurations for StreamExecutor builds""" +load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda_is_configured") load( "@local_config_rocm//rocm:build_defs.bzl", _if_cuda_or_rocm = "if_cuda_or_rocm", @@ -63,5 +64,34 @@ def gpu_only_cc_library(name, tags = [], **kwargs): target_compatible_with = kwargs.get("target_compatible_with"), ) +def cuda_only_cc_library(name, tags = [], **kwargs): + """A library that only gets compiled when CUDA is configured, otherwise it's an empty target. + + Args: + name: Name of the target + tags: Tags being applied to the implementation target + **kwargs: Accepts all arguments that a `cc_library` would also accept + """ + if not native.package_name().startswith("xla/stream_executor"): + fail("cuda_only_cc_library may only be used in `xla/stream_executor/...`.") + + cc_library( + name = "%s_non_cuda" % name, + tags = ["manual"], + ) + cc_library( + name = "%s_cuda_only" % name, + tags = tags + ["manual", "cuda-only"], + **kwargs + ) + native.alias( + name = name, + actual = if_cuda_is_configured(":%s_cuda_only" % name, ":%s_non_cuda" % name), + visibility = kwargs.get("visibility"), + compatible_with = kwargs.get("compatible_with"), + restricted_to = kwargs.get("restricted_to"), + target_compatible_with = kwargs.get("target_compatible_with"), + ) + def stream_executor_build_defs_bzl_deps(): return [] diff --git a/third_party/xla/xla/stream_executor/cuda/BUILD b/third_party/xla/xla/stream_executor/cuda/BUILD index 26c394a144300c..529a0b6197935f 100644 --- a/third_party/xla/xla/stream_executor/cuda/BUILD +++ b/third_party/xla/xla/stream_executor/cuda/BUILD @@ -10,14 +10,20 @@ load( ) load( "@local_tsl//tsl/platform/default:cuda_build_defs.bzl", + "if_cuda_is_configured", "if_cuda_newer_than", ) load( "//xla:xla.bzl", "xla_cc_test", ) +load( + "//xla/service/gpu:build_defs.bzl", + "gpu_kernel_library", +) load( "//xla/stream_executor:build_defs.bzl", + "cuda_only_cc_library", "stream_executor_friends", "tf_additional_cuda_platform_deps", "tf_additional_cudnn_plugin_copts", @@ -81,14 +87,10 @@ cc_library( deps = ["//xla/stream_executor:platform"], ) -cc_library( +cuda_only_cc_library( name = "cuda_platform", srcs = ["cuda_platform.cc"], hdrs = ["cuda_platform.h"], - tags = [ - "cuda-only", - "gpu", - ], visibility = ["//visibility:public"], deps = [ @@ -121,14 +123,10 @@ cc_library( alwayslink = True, # Registers itself with the PlatformManager. ) -cc_library( +cuda_only_cc_library( name = "cuda_diagnostics", srcs = ["cuda_diagnostics.cc"], hdrs = ["cuda_diagnostics.h"], - tags = [ - "cuda-only", - "gpu", - ], deps = [ "//xla/stream_executor/gpu:gpu_diagnostics_header", "@com_google_absl//absl/container:inlined_vector", @@ -159,14 +157,10 @@ cc_library( ), ) -cc_library( +cuda_only_cc_library( name = "cuda_driver", srcs = ["cuda_driver.cc"], hdrs = ["cuda_driver.h"], - tags = [ - "cuda-only", - "gpu", - ], deps = [ ":cuda_diagnostics", # buildcleaner: keep ":cuda_status", @@ -204,14 +198,10 @@ cc_library( ], ) -cc_library( +cuda_only_cc_library( name = "cuda_status", srcs = ["cuda_status.cc"], hdrs = ["cuda_status.h"], - tags = [ - "cuda-only", - "gpu", - ], deps = [ "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/status", @@ -220,14 +210,10 @@ cc_library( ], ) -cc_library( +cuda_only_cc_library( name = "cuda_runtime", srcs = ["cuda_runtime.cc"], hdrs = ["cuda_runtime.h"], - tags = [ - "cuda-only", - "gpu", - ], deps = [ "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/log", @@ -239,13 +225,10 @@ cc_library( ], ) -cc_library( +cuda_only_cc_library( name = "cuda_collectives", hdrs = ["cuda_collectives.h"], - tags = [ - "cuda-only", - "gpu", - ], + tags = ["gpu"], deps = if_nccl( [":cuda_collectives_impl"], [":cuda_collectives_stub"], @@ -263,7 +246,6 @@ cc_library( "cuda_collectives.h", ], tags = [ - "cuda-only", "gpu", "manual", ], @@ -336,16 +318,12 @@ xla_test( ], ) -cc_library( +cuda_only_cc_library( name = "cublas_lt_header", hdrs = [ "cuda_blas_lt.h", "cuda_blas_utils.h", ], - tags = [ - "cuda-only", - "gpu", - ], visibility = ["//visibility:public"], deps = [ "//xla:types", @@ -362,7 +340,7 @@ cc_library( ], ) -cc_library( +cuda_only_cc_library( name = "cublas_plugin", srcs = [ "cuda_blas.cc", @@ -372,10 +350,6 @@ cc_library( "cuda_blas.h", "cuda_blas_lt.h", ], - tags = [ - "cuda-only", - "gpu", - ], visibility = ["//visibility:public"], deps = [ ":cuda_blas_utils", @@ -428,14 +402,10 @@ cc_library( alwayslink = True, ) -cc_library( +cuda_only_cc_library( name = "cuda_blas_utils", srcs = ["cuda_blas_utils.cc"], hdrs = ["cuda_blas_utils.h"], - tags = [ - "cuda-only", - "gpu", - ], deps = [ "//xla/stream_executor", "//xla/stream_executor:blas", @@ -448,14 +418,10 @@ cc_library( ], ) -cc_library( +cuda_only_cc_library( name = "cufft_plugin", srcs = ["cuda_fft.cc"], hdrs = ["cuda_fft.h"], - tags = [ - "cuda-only", - "gpu", - ], visibility = ["//visibility:public"], deps = [ ":cuda_helpers", @@ -481,17 +447,13 @@ cc_library( alwayslink = True, ) -cuda_library( +gpu_kernel_library( name = "delay_kernel_cuda", srcs = [ "delay_kernel.h", "delay_kernel_cuda.cu.cc", ], - # copybara:uncomment compatible_with = ["//buildenv/target:non_prod"], - tags = [ - "cuda-only", - "gpu", - ], + tags = ["manual"], visibility = internal_visibility([ "//xla/stream_executor:__subpackages__", ]), @@ -505,15 +467,11 @@ cuda_library( ], ) -cc_library( +cuda_only_cc_library( name = "cudnn_plugin", srcs = ["cuda_dnn.cc"], hdrs = ["cuda_dnn.h"], copts = tf_additional_cudnn_plugin_copts(), - tags = [ - "cuda-only", - "gpu", - ], visibility = ["//visibility:public"], deps = [ ":cuda_diagnostics", @@ -566,14 +524,10 @@ cc_library( alwayslink = True, ) -cc_library( +cuda_only_cc_library( name = "cuda_kernel", srcs = ["cuda_kernel.cc"], hdrs = ["cuda_kernel.h"], - tags = [ - "cuda-only", - "gpu", - ], deps = [ "//xla/stream_executor", "//xla/stream_executor/gpu:gpu_driver_header", @@ -620,28 +574,22 @@ cuda_library( ], ) +# TODO(leary) we likely need to canonicalize/eliminate this. cc_library( name = "cuda_helpers", - hdrs = ["cuda_helpers.h"], - tags = [ - "cuda-only", - "gpu", - ], - deps = [ + textual_hdrs = if_cuda_is_configured(["cuda_helpers.h"]), + deps = if_cuda_is_configured([ "//xla/stream_executor/gpu:gpu_helpers_header", - "@com_google_absl//absl/log:check", "@local_config_cuda//cuda:cuda_headers", + ]) + [ + "@com_google_absl//absl/log:check", ], ) -cc_library( +cuda_only_cc_library( name = "cuda_event", srcs = ["cuda_event.cc"], hdrs = ["cuda_event.h"], - tags = [ - "cuda-only", - "gpu", - ], deps = [ ":cuda_driver", "//xla/stream_executor:event", @@ -877,7 +825,7 @@ xla_cc_test( ], ) -cc_library( +cuda_only_cc_library( name = "cuda_asm_compiler", srcs = ["cuda_asm_compiler.cc"], hdrs = ["cuda_asm_compiler.h"], @@ -896,10 +844,6 @@ cc_library( "@cuda_nvcc//:ptxas", ]), # copybara:comment_end - tags = [ - "cuda-only", - "gpu", - ], visibility = internal_visibility([ "//third_party/py/jax:__subpackages__", "//tensorflow/compiler/mlir/tools/kernel_gen:__subpackages__", @@ -945,7 +889,7 @@ cc_library( ], ) -cc_library( +cuda_only_cc_library( name = "cuda_executor", srcs = [ "cuda_executor.cc", @@ -954,10 +898,6 @@ cc_library( hdrs = [ "cuda_executor.h", ], - tags = [ - "cuda-only", - "gpu", - ], deps = [ ":cuda_collectives", ":cuda_diagnostics", @@ -968,7 +908,6 @@ cc_library( ":cuda_runtime", ":cuda_status", ":cuda_version_parser", - ":delay_kernel_cuda", "//xla/stream_executor", "//xla/stream_executor:blas", "//xla/stream_executor:command_buffer", @@ -1015,17 +954,13 @@ cc_library( "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:platform_port", "@local_tsl//tsl/platform:statusor", - ], + ] + if_cuda_is_configured([":delay_kernel_cuda"]), alwayslink = True, ) cc_library( name = "all_runtime", copts = tsl_copts(), - tags = [ - "cuda-only", - "gpu", - ], visibility = ["//visibility:public"], deps = [ ":cublas_plugin", diff --git a/third_party/xla/xla/stream_executor/gpu/BUILD b/third_party/xla/xla/stream_executor/gpu/BUILD index df510e57075b35..7c0bdb7d3c0361 100644 --- a/third_party/xla/xla/stream_executor/gpu/BUILD +++ b/third_party/xla/xla/stream_executor/gpu/BUILD @@ -498,11 +498,7 @@ cc_library( "redzone_allocator_kernel.h", "redzone_allocator_kernel_cuda.cc", ], - tags = [ - "cuda-only", - "gpu", - "manual", - ], + tags = ["manual"], deps = [ ":gpu_asm_opts", "//xla/stream_executor", @@ -600,12 +596,9 @@ xla_test( cc_library( name = "gpu_cudamallocasync_allocator", - srcs = ["gpu_cudamallocasync_allocator.cc"], - hdrs = ["gpu_cudamallocasync_allocator.h"], - tags = [ - "cuda-only", - "gpu", - ], + srcs = if_cuda_is_configured(["gpu_cudamallocasync_allocator.cc"]), + hdrs = if_cuda_is_configured(["gpu_cudamallocasync_allocator.h"]), + tags = ["gpu"], deps = [ ":gpu_init_impl", "//xla/stream_executor:stream_executor_h", diff --git a/third_party/xla/xla/tsl/tsl.bzl b/third_party/xla/xla/tsl/tsl.bzl index cca0a3001a307d..0cf769ddf4eadb 100644 --- a/third_party/xla/xla/tsl/tsl.bzl +++ b/third_party/xla/xla/tsl/tsl.bzl @@ -21,6 +21,7 @@ load( load( "@local_config_rocm//rocm:build_defs.bzl", "if_rocm", + "if_rocm_is_configured", ) load( "@local_tsl//tsl/platform:rules_cc.bzl", @@ -366,7 +367,7 @@ def tsl_gpu_library(deps = None, cuda_deps = None, copts = tsl_copts(), **kwargs cuda_deps = [] kwargs["features"] = kwargs.get("features", []) + ["-use_header_modules"] - deps = deps + if_cuda(cuda_deps) + deps = deps + if_cuda_or_rocm(cuda_deps) if "default_copts" in kwargs: copts = kwargs["default_copts"] + copts kwargs.pop("default_copts", None) @@ -374,8 +375,7 @@ def tsl_gpu_library(deps = None, cuda_deps = None, copts = tsl_copts(), **kwargs deps = deps + if_cuda([ clean_dep("//xla/tsl/cuda:cudart"), "@local_config_cuda//cuda:cuda_headers", - ]) + if_rocm([ - "@local_config_rocm//rocm:hip", + ]) + if_rocm_is_configured([ "@local_config_rocm//rocm:rocm_headers", ]), copts = (copts + if_cuda(["-DGOOGLE_CUDA=1", "-DNV_CUDNN_DISABLE_EXCEPTION"]) + if_rocm(["-DTENSORFLOW_USE_ROCM=1"]) + if_xla_available(["-DTENSORFLOW_USE_XLA=1"]) + if_mkl(["-DINTEL_MKL=1"]) + if_enable_mkl(["-DENABLE_MKL"]) + if_tensorrt(["-DGOOGLE_TENSORRT=1"])), From 345ba0cc964ff889605be9a1e3712ed070f073a3 Mon Sep 17 00:00:00 2001 From: Bixia Zheng Date: Fri, 27 Sep 2024 07:55:39 -0700 Subject: [PATCH 366/483] #sdy Support OpShardingRule in SDY round trip import. PiperOrigin-RevId: 679590102 --- .../shardy/sdy_round_trip/import_shardings.cc | 18 ++++++++++--- .../shardy/sdy_round_trip/import_shardings.h | 8 ++++-- .../test/sdy_round_trip_import_pipeline.mlir | 21 +++++++++++++++- .../xla/xla/service/spmd/shardy/utils.h | 25 ++++++++----------- 4 files changed, 50 insertions(+), 22 deletions(-) diff --git a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/import_shardings.cc b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/import_shardings.cc index eb11cc53f1c456..fc6e55b2036781 100644 --- a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/import_shardings.cc +++ b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/import_shardings.cc @@ -71,15 +71,17 @@ using ::mlir::func::FuncOp; using ::mlir::mhlo::CustomCallOp; using ::mlir::sdy::kShardingAttr; +using ::mlir::sdy::kShardingRuleAttr; using ::mlir::sdy::MeshAttr; +using ::mlir::sdy::OpShardingRuleAttr; using ::mlir::sdy::TensorShardingAttr; using ::mlir::sdy::TensorShardingPerValueAttr; -// Builds the shardings coming from Shardy previously. This means +// Builds the shardy attributes coming from Shardy previously. This means // the module was exported from Shardy and we are now round-tripping back. // This should happen after the meshes were created from the `ModuleOp` attrs // (see `SdyRoundTripImportShardingsPass`). -void convertShardings(FuncOp funcOp) { +void convertShardyAttrs(FuncOp funcOp) { // Copy over the argument shardings, but not the result shardings yet. // We need to wait until after we've converted all the Operations before // copying the result shardings. @@ -102,7 +104,7 @@ void convertShardings(FuncOp funcOp) { resNum, StringAttr::get(funcOp.getContext(), kXlaShardingAttr)); } - // Extract the round-tripped SDY shardings from the operations. + // Extract the round-tripped SDY shardy attributes from the operations. funcOp.front().walk([&](Operation* op) { op->removeAttr(kXlaShardingAttr); if (DictionaryAttr dictAttr = getFrontendAttrs(op)) { @@ -141,6 +143,13 @@ void convertShardings(FuncOp funcOp) { } } removeFrontendAttribute(op, kShardingRoundTripAttr); + + // Import sharding rules. + if (auto shardingRuleAttr = parseStringAttr( + dictAttr, kShardingRuleRoundTripAttr)) { + op->setAttr(kShardingRuleAttr, shardingRuleAttr); + removeFrontendAttribute(op, kShardingRuleRoundTripAttr); + } } }); } @@ -176,6 +185,7 @@ class SdyRoundTripImportShardingsPass // Insert the meshes before any functions. builder.setInsertionPointToStart(moduleOp.getBody()); for (NamedAttribute mesh : sdyMeshes) { + mesh.getValue().dump(); auto meshAttr = mlir::cast(mesh.getValue()); symbolTable.insert(builder.create( moduleOp.getLoc(), mesh.getName(), meshAttr)); @@ -183,7 +193,7 @@ class SdyRoundTripImportShardingsPass removeFrontendAttribute(moduleOp, kMeshesRoundTripAttr); for (auto funcOp : moduleOp.getOps()) { - convertShardings(funcOp); + convertShardyAttrs(funcOp); } } diff --git a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/import_shardings.h b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/import_shardings.h index 2f77466af87626..c750ef28705551 100644 --- a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/import_shardings.h +++ b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/import_shardings.h @@ -23,8 +23,12 @@ limitations under the License. namespace xla { namespace sdy { -// Creates the pass that converts the shardings from strings in MHLO frontend -// attributes to SDY meshes and shardings. +// Creates the pass to convert frontend attributes to SDY attributes: +// +// - Converts shardings from `kShardingRoundTripAttr` to `kShardingAttr` +// - Converts sharding rules from `kShardingRuleRoundTripAttr` to +// `kShardingRuleAttr` +// - Converts meshes from `kMeshesRoundTripAttr` to sdy.mesh symbols std::unique_ptr createSdyRoundTripImportShardingsPass(); // Registers the xla-sdy-round-trip-import-shardings pass. diff --git a/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_import_pipeline.mlir b/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_import_pipeline.mlir index 82edbe3f82c7a7..85cbde38d43f52 100644 --- a/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_import_pipeline.mlir +++ b/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_import_pipeline.mlir @@ -1,4 +1,4 @@ -// RUN: sdy_opt %s -xla-sdy-round-trip-import-pipeline 2>&1 | FileCheck %s +// RUN: sdy_opt %s --split-input-file -xla-sdy-round-trip-import-pipeline 2>&1 | FileCheck %s // CHECK-LABEL: module @multiple_func_result_shardings module @multiple_func_result_shardings attributes {mhlo.frontend_attributes = {xla.sdy.meshes = "{mesh = #sdy.mesh<[\\\22a\\\22=8, \\\22b\\\22=8, \\\22c\\\22=8]>}"}} { @@ -110,3 +110,22 @@ module @multiple_func_result_shardings attributes {mhlo.frontend_attributes = {x return %3 : tensor<32xi32> } } + +// ----- + +module @no_mesh_module attributes {mhlo.frontend_attributes = {xla.sdy.meshes = "{}"}} { + // CHECK-LABEL: func @no_sharding_rule + func.func @no_sharding_rule(%arg0: tensor<8x2xf32>, %arg1: tensor<8x2xf32>) -> tensor<8x2xf64> { + // CHECK-NEXT: stablehlo.custom_call @foo(%arg0, %arg1) : (tensor<8x2xf32>, tensor<8x2xf32>) -> tensor<8x2xf64> + %0 = stablehlo.custom_call @foo(%arg0, %arg1) : (tensor<8x2xf32>, tensor<8x2xf32>) -> tensor<8x2xf64> + return %0 : tensor<8x2xf64> + } + + // CHECK-LABEL: func @op_sharding_rule + func.func @op_sharding_rule(%arg0: tensor<8x2xf32>, %arg1: tensor<8x2xf32>) -> tensor<8x2xf64> { + // CHECK-NEXT: stablehlo.custom_call @foo(%arg0, %arg1) {sdy.sharding_rule = #sdy.op_sharding_rule<([i, j], [i, j])->([i, j]) {i=8, j=2}>} + %0 = stablehlo.custom_call @foo(%arg0, %arg1) + {mhlo.frontend_attributes = {xla.sdy.sharding_rule = "#sdy.op_sharding_rule<([i, j], [i, j])->([i, j]) {i=8, j=2}>"}} : (tensor<8x2xf32>, tensor<8x2xf32>) -> tensor<8x2xf64> + return %0 : tensor<8x2xf64> + } +} diff --git a/third_party/xla/xla/service/spmd/shardy/utils.h b/third_party/xla/xla/service/spmd/shardy/utils.h index 80194b3ca04c40..394b6c48e0bfd4 100644 --- a/third_party/xla/xla/service/spmd/shardy/utils.h +++ b/third_party/xla/xla/service/spmd/shardy/utils.h @@ -61,25 +61,20 @@ void removeFrontendAttribute(mlir::func::FuncOp funcOp, void loadAllRequiredDialects(mlir::MLIRContext* context); -// Parses `stringAttr` to an attribute of type `AttrTy`. -// -// NOTE: assumes `stringAttr` is of type `StringAttr`. -template -AttrTy parseStringAttr(mlir::Attribute stringAttr) { - std::string value; - std::string error; - CHECK(absl::CUnescape(mlir::cast(stringAttr).getValue(), - &value, &error)) - << error; - return mlir::cast( - mlir::parseAttribute(value, stringAttr.getContext())); -} - // Parses `attrName` from `dictAttr` to an attribute of type `AttrTy`. template AttrTy parseStringAttr(mlir::DictionaryAttr dictAttr, llvm::StringRef attrName) { - return parseStringAttr(dictAttr.get(attrName)); + if (mlir::Attribute stringAttr = dictAttr.get(attrName)) { + std::string value; + std::string error; + CHECK(absl::CUnescape(mlir::cast(stringAttr).getValue(), + &value, &error)) + << error; + return mlir::cast( + mlir::parseAttribute(value, stringAttr.getContext())); + } + return nullptr; } } // namespace sdy From 69ee80b4e16681dd3ab92912a5dbb64c9f7444c9 Mon Sep 17 00:00:00 2001 From: vfdev Date: Fri, 27 Sep 2024 08:07:01 -0700 Subject: [PATCH 367/483] PR #17319: Fixes XLA build with numpy>=2.1.0 Imported from GitHub PR https://github.com/openxla/xla/pull/17319 When building XLA using the command from dev guide: docs/developer_guide.md ```bash ./configure.py --backend=CPU bazel build --test_output=all --spawn_strategy=sandboxed //xla/... ``` Using numpy==2.1.0 we can have the following linking error: ``` ERROR: /xla/xla/python/BUILD:1453:11: Linking xla/python/libnb_numpy.so failed: (Exit 1): clang failed: error executing command (from target //xla/python:nb_numpy) /usr/lib/llvm-14/bin/clang @bazel-out/k8-opt/bin/xla/python/libnb_numpy.so-2.params Use --sandbox_debug to see verbose messages from the sandbox and retain the sandbox build root for debugging ld.lld: error: undefined hidden symbol: _xla_numpy_api >>> referenced by nb_numpy.cc >>> bazel-out/k8-opt/bin/xla/python/_objs/nb_numpy/nb_numpy.pic.o:(xla::nb_dtype::from_args(nanobind::object const&)) >>> referenced by nb_numpy.cc >>> bazel-out/k8-opt/bin/xla/python/_objs/nb_numpy/nb_numpy.pic.o:(xla::nb_numpy_ndarray::nb_numpy_ndarray(xla::nb_dtype, absl::lts_20230802::Span, std::optional >, void const*, nanobind::handle)) >>> referenced by nb_numpy.cc >>> bazel-out/k8-opt/bin/xla/python/_objs/nb_numpy/nb_numpy.pic.o:(xla::nb_numpy_ndarray::nb_numpy_ndarray(xla::nb_dtype, absl::lts_20230802::Span, std::optional >, void const*, nanobind::handle)) >>> referenced 4 more times ld.lld: error: undefined hidden symbol: _xla_numpy_apiPyArray_RUNTIME_VERSION >>> referenced by nb_numpy.cc >>> bazel-out/k8-opt/bin/xla/python/_objs/nb_numpy/nb_numpy.pic.o:(xla::nb_numpy_ndarray::itemsize() const) clang: error: linker command failed with exit code 1 (use -v to see invocation) ``` Which should be related to https://github.com/numpy/numpy/blob/main/doc/source/release/2.1.0-notes.rst#api-symbols-now-hidden-but-customizable This PR fixes the build issue Copybara import of the project: -- 2f6e1b3e7e1bb189a1b9b5a9e4a94e60bd116a9d by vfdev-5 : Fixes XLA build with numpy>=2.1.0 When building XLA using the command from dev guide: docs/developer_guide.md ```bash ./configure.py --backend=CPU bazel build --test_output=all --spawn_strategy=sandboxed //xla/... ``` Using numpy==2.1.0 we can have the following linking error: ``` ERROR: /xla/xla/python/BUILD:1453:11: Linking xla/python/libnb_numpy.so failed: (Exit 1): clang failed: error executing command (from target //xla/python:nb_numpy) /usr/lib/llvm-14/bin/clang @bazel-out/k8-opt/bin/xla/python/libnb_numpy.so-2.params Use --sandbox_debug to see verbose messages from the sandbox and retain the sandbox build root for debugging ld.lld: error: undefined hidden symbol: _xla_numpy_api >>> referenced by nb_numpy.cc >>> bazel-out/k8-opt/bin/xla/python/_objs/nb_numpy/nb_numpy.pic.o:(xla::nb_dtype::from_args(nanobind::object const&)) >>> referenced by nb_numpy.cc >>> bazel-out/k8-opt/bin/xla/python/_objs/nb_numpy/nb_numpy.pic.o:(xla::nb_numpy_ndarray::nb_numpy_ndarray(xla::nb_dtype, absl::lts_20230802::Span, std::optional >, void const*, nanobind::handle)) >>> referenced by nb_numpy.cc >>> bazel-out/k8-opt/bin/xla/python/_objs/nb_numpy/nb_numpy.pic.o:(xla::nb_numpy_ndarray::nb_numpy_ndarray(xla::nb_dtype, absl::lts_20230802::Span, std::optional >, void const*, nanobind::handle)) >>> referenced 4 more times ld.lld: error: undefined hidden symbol: _xla_numpy_apiPyArray_RUNTIME_VERSION >>> referenced by nb_numpy.cc >>> bazel-out/k8-opt/bin/xla/python/_objs/nb_numpy/nb_numpy.pic.o:(xla::nb_numpy_ndarray::itemsize() const) clang: error: linker command failed with exit code 1 (use -v to see invocation) ``` Which should be related to https://github.com/numpy/numpy/blob/main/doc/source/release/2.1.0-notes.rst#api-symbols-now-hidden-but-customizable This PR fixes the build issue Merging this change closes #17319 PiperOrigin-RevId: 679594145 --- third_party/xla/xla/tsl/python/lib/core/numpy.h | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/third_party/xla/xla/tsl/python/lib/core/numpy.h b/third_party/xla/xla/tsl/python/lib/core/numpy.h index 6a5a6a6486ccf7..ca57a0370548ed 100644 --- a/third_party/xla/xla/tsl/python/lib/core/numpy.h +++ b/third_party/xla/xla/tsl/python/lib/core/numpy.h @@ -29,6 +29,13 @@ limitations under the License. #define NO_IMPORT_ARRAY #endif +// Prevent linking error with numpy>=2.1.0 +// error: undefined hidden symbol: _xla_numpy_apiPyArray_RUNTIME_VERSION +// Without this define, Numpy's API symbols will have hidden symbol visibility, +// which may break things if Bazel chooses to build a cc_library target into +// its own .so file. Bazel typically does this for debug builds. +#define NPY_API_SYMBOL_ATTRIBUTE + // clang-format off // Place `` before to avoid build failure in macOS. #include From 644aaddd9350b666cafc616ea1bf85a7c3992e9f Mon Sep 17 00:00:00 2001 From: Ilya Tikhonovskiy Date: Fri, 27 Sep 2024 08:09:50 -0700 Subject: [PATCH 368/483] [XLA:GPU] propagate the algorithm flag of dot op to cublasGemm custom call. we have the algorithm flag of dot op. we handle it in triton emitter, now let's push it to cublas via gemm_rewriter. Otherwise the cublas call uses the default f32_f32_f32 algorithm and loses the competition with triton. As a result of this change it get clear that only Ampere ran bf16 version of cublas kernel. Hopper uses tf32 for that because it does not have the b16 version for this case. DotBF16ForBf16Bf16F32Tests was removed because the algorithm BF16_BF16_F32 expects F32 input and F32 output with the BF16 arithmetics inside cublas. PiperOrigin-RevId: 679595014 --- third_party/xla/xla/service/algorithm_util.cc | 3 +- .../service/gpu/dot_algorithm_support_test.cc | 9 --- .../xla/xla/service/gpu/fusions/triton/BUILD | 39 ++++++++-- .../gpu/fusions/triton/kernel_name_tracer.h | 39 ++++++++++ .../fusions/triton/kernel_name_tracer_cuda.cc | 72 +++++++++++++++++++ .../fusions/triton/kernel_name_tracer_noop.cc | 33 +++++++++ ...riton_fusion_emitter_device_legacy_test.cc | 72 ++++++++++++++++++- .../gpu/fusions/triton/triton_test_utils.cc | 7 +- .../service/gpu/transforms/gemm_rewriter.cc | 11 ++- 9 files changed, 261 insertions(+), 24 deletions(-) create mode 100644 third_party/xla/xla/service/gpu/fusions/triton/kernel_name_tracer.h create mode 100644 third_party/xla/xla/service/gpu/fusions/triton/kernel_name_tracer_cuda.cc create mode 100644 third_party/xla/xla/service/gpu/fusions/triton/kernel_name_tracer_noop.cc diff --git a/third_party/xla/xla/service/algorithm_util.cc b/third_party/xla/xla/service/algorithm_util.cc index 8eec061e84e13a..85380d2b2ef72c 100644 --- a/third_party/xla/xla/service/algorithm_util.cc +++ b/third_party/xla/xla/service/algorithm_util.cc @@ -41,10 +41,11 @@ absl::StatusOr GetBlasComputationType( switch (algorithm) { case PrecisionConfig::ALG_DOT_F16_F16_F16: return se::blas::ComputationType::kF16; + case PrecisionConfig::ALG_DOT_BF16_BF16_F32: + return se::blas::ComputationType::kBF16AsF32; case PrecisionConfig::ALG_DOT_ANY_F8_ANY_F8_F32: case PrecisionConfig::ALG_DOT_ANY_F8_ANY_F8_F32_FAST_ACCUM: case PrecisionConfig::ALG_DOT_F16_F16_F32: - case PrecisionConfig::ALG_DOT_BF16_BF16_F32: case PrecisionConfig::ALG_DOT_F32_F32_F32: return se::blas::ComputationType::kF32; case PrecisionConfig::ALG_DOT_TF32_TF32_F32: diff --git a/third_party/xla/xla/service/gpu/dot_algorithm_support_test.cc b/third_party/xla/xla/service/gpu/dot_algorithm_support_test.cc index 23d21d9fbd1b56..2a34e418dd30a8 100644 --- a/third_party/xla/xla/service/gpu/dot_algorithm_support_test.cc +++ b/third_party/xla/xla/service/gpu/dot_algorithm_support_test.cc @@ -215,15 +215,6 @@ INSTANTIATE_TEST_SUITE_P(DotF16F16F32Tests, DotAlgorithmSupportTest, Values(Sizes{32, 32}, Sizes{16, 2})), TestParamsToString); -INSTANTIATE_TEST_SUITE_P(DotBF16ForBf16Bf16F32Tests, DotAlgorithmSupportTest, - Combine(Values(PC::ALG_DOT_BF16_BF16_F32), - Values(BF16), Values(BF16, F32), - Values(CC(8, 0)), - Values(SemanticVersion{6, 0, 0}), - Values(BackendRestriction::kNoRestriction), - Values(Sizes{32, 32}, Sizes{16, 2})), - TestParamsToString); - INSTANTIATE_TEST_SUITE_P(DotF32ForBf16Bf16F32Tests, DotAlgorithmSupportTest, Combine(Values(PC::ALG_DOT_BF16_BF16_F32), Values(F32), Values(F32), Values(CC(8, 0)), diff --git a/third_party/xla/xla/service/gpu/fusions/triton/BUILD b/third_party/xla/xla/service/gpu/fusions/triton/BUILD index 4a16dfb0fe8dca..99489d8e6d41b4 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/BUILD +++ b/third_party/xla/xla/service/gpu/fusions/triton/BUILD @@ -1,4 +1,5 @@ load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") +load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm_is_configured") load("@local_tsl//tsl/platform/default:cuda_build_defs.bzl", "if_cuda_is_configured") load("//xla:xla.bzl", "xla_cc_test") @@ -206,6 +207,7 @@ xla_test( "no_mac", ], deps = [ + ":kernel_name_tracer", ":triton_fusion_emitter", ":triton_test_utils", "//xla:autotuning_proto_cc", @@ -223,13 +225,10 @@ xla_test( "//xla/stream_executor:device_description", "//xla/stream_executor/cuda:cublas_plugin", "//xla/tests:filecheck", - "//xla/tests:verified_hlo_module", "//xla/tests:xla_internal_test_main", # fixdeps: keep "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/log", "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest", "@llvm-project//llvm:ir_headers", "@llvm-project//mlir:IR", @@ -237,7 +236,6 @@ xla_test( "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:path", - "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:status_matchers", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", @@ -286,6 +284,37 @@ xla_test( ], ) +cc_library( + name = "kernel_name_tracer_cuda", + testonly = True, + srcs = if_cuda(["kernel_name_tracer_cuda.cc"]), + hdrs = ["kernel_name_tracer.h"], + tags = ["manual"], # Need to exclude this from wildcard builds + deps = [ + "//xla/backends/profiler/gpu:cupti_collector", + "//xla/backends/profiler/gpu:cupti_tracer", + "@local_tsl//tsl/profiler/utils:time_utils", + ], +) + +cc_library( + name = "kernel_name_tracer_noop", + testonly = True, + srcs = ["kernel_name_tracer_noop.cc"], + hdrs = ["kernel_name_tracer.h"], + tags = ["manual"], # Need to exclude this from wildcard builds +) + +cc_library( + name = "kernel_name_tracer", + testonly = True, + hdrs = ["kernel_name_tracer.h"], + deps = if_cuda( + [":kernel_name_tracer_cuda"], + [":kernel_name_tracer_noop"], + ), +) + cc_library( name = "triton_test_utils", testonly = True, @@ -321,6 +350,7 @@ cc_library( "@llvm-project//mlir:IR", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/profiler/utils:time_utils", ], ) @@ -479,6 +509,7 @@ xla_test( ], tags = ["no_mac"], deps = [ + ":kernel_name_tracer", ":triton_fusion_emitter", ":triton_support", ":triton_test_utils", diff --git a/third_party/xla/xla/service/gpu/fusions/triton/kernel_name_tracer.h b/third_party/xla/xla/service/gpu/fusions/triton/kernel_name_tracer.h new file mode 100644 index 00000000000000..73d348367ff874 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/triton/kernel_name_tracer.h @@ -0,0 +1,39 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_FUSIONS_TRITON_KERNEL_NAME_TRACER_H_ +#define XLA_SERVICE_GPU_FUSIONS_TRITON_KERNEL_NAME_TRACER_H_ + +#include +#include + +namespace xla::gpu { + +// In some cases we need to know what exact kernel was used. It happens when we +// have no direct way to get this information from the HLO. For example, when we +// have a fusion with a custom call to cuBLAS or another third party library. +// This class allows to get the name of the kernel that was used. +class KernelNameTracer { + public: + static std::unique_ptr Create(); + + virtual void start() = 0; + virtual std::string stop() = 0; + virtual ~KernelNameTracer() = default; +}; + +} // namespace xla::gpu + +#endif // XLA_SERVICE_GPU_FUSIONS_TRITON_KERNEL_NAME_TRACER_H_ diff --git a/third_party/xla/xla/service/gpu/fusions/triton/kernel_name_tracer_cuda.cc b/third_party/xla/xla/service/gpu/fusions/triton/kernel_name_tracer_cuda.cc new file mode 100644 index 00000000000000..cd8830542cfec5 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/triton/kernel_name_tracer_cuda.cc @@ -0,0 +1,72 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "xla/backends/profiler/gpu/cupti_collector.h" +#include "xla/backends/profiler/gpu/cupti_tracer.h" +#include "xla/service/gpu/fusions/triton/kernel_name_tracer.h" +#include "tsl/profiler/utils/time_utils.h" + +namespace xla::gpu { + +// This class allows to get the name of the kernel that was used. +// It works only on CUDA. It uses CuptiTracer to get the kernel name. +class KernelNameTracerCuda : public KernelNameTracer { + public: + KernelNameTracerCuda() + : cupti_tracer_(profiler::CuptiTracer::GetCuptiTracerSingleton()) {} + + void start() override; + + // As of now it returns the name of the first kernel that was executed on + // GPU:0. + std::string stop() override; + + private: + std::unique_ptr cupti_tracer_; + std::unique_ptr cupti_collector_; +}; + +std::unique_ptr KernelNameTracer::Create() { + return std::make_unique(); +} + +void KernelNameTracerCuda::start() { + profiler::CuptiTracerCollectorOptions collector_options; + collector_options.num_gpus = profiler::CuptiTracer::NumGpus(); + auto start_gputime_ns = profiler::CuptiTracer::GetTimestamp(); + auto start_walltime_ns = tsl::profiler::GetCurrentTimeNanos(); + cupti_collector_ = profiler::CreateCuptiCollector( + collector_options, start_walltime_ns, start_gputime_ns); + profiler::CuptiTracerOptions options; + options.activities_selected = {CUPTI_ACTIVITY_KIND_KERNEL}; + cupti_tracer_->Enable(options, cupti_collector_.get()); +} + +std::string KernelNameTracerCuda::stop() { + cupti_tracer_->Disable(); + uint64_t end_gpu_ns = cupti_collector_->GetTracingEndTimeNs(); + auto space = std::make_unique(); + cupti_collector_->Export(space.get(), end_gpu_ns); + for (const auto& plane : space->planes()) { + if (plane.name() == "/device:GPU:0") { + return plane.event_metadata().at(1).name(); + } + } + return ""; +} + +} // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/fusions/triton/kernel_name_tracer_noop.cc b/third_party/xla/xla/service/gpu/fusions/triton/kernel_name_tracer_noop.cc new file mode 100644 index 00000000000000..f8b0c2f2d8f186 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/triton/kernel_name_tracer_noop.cc @@ -0,0 +1,33 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "xla/service/gpu/fusions/triton/kernel_name_tracer.h" + +namespace xla::gpu { + +class KernelNameTracerNoop : public KernelNameTracer { + public: + void start() override {}; + std::string stop() override { return "kernel_name_tracer_not_implemented"; }; +}; + +std::unique_ptr KernelNameTracer::Create() { + return std::make_unique(); +} + +} // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_legacy_test.cc b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_legacy_test.cc index da369139994aaf..2d5c4232038891 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_legacy_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_legacy_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include #include +#include #include #include #include @@ -37,6 +38,7 @@ limitations under the License. #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/service/gpu/backend_configs.pb.h" +#include "xla/service/gpu/fusions/triton/kernel_name_tracer.h" #include "xla/service/gpu/fusions/triton/triton_fusion_emitter.h" #include "xla/service/gpu/fusions/triton/triton_test_utils.h" #include "xla/service/gpu/gpu_device_info_for_tests.h" @@ -46,7 +48,6 @@ limitations under the License. #include "xla/service/pattern_matcher_gmock.h" #include "xla/stream_executor/device_description.h" #include "xla/tests/filecheck.h" -#include "xla/tests/verified_hlo_module.h" #include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla.pb.h" #include "tsl/platform/env.h" @@ -147,6 +148,74 @@ class TritonGemmTestWithoutTritonGemmAny : public TritonGemmTest { } }; +class TritonBF16BF16F32BlasTest : public TritonTest { + public: + DebugOptions GetDebugOptionsForTest() override { + DebugOptions debug_options = TritonTest::GetDebugOptionsForTest(); + // Do not autotune split-k by default, since this prevents deterministically + // matching the optimized HLO. + debug_options.set_xla_gpu_enable_split_k_autotuning(false); + debug_options.set_xla_gpu_enable_triton_gemm(false); + return debug_options; + } + + protected: + void SetUp() override { + if (!SupportsBF16(GpuComputeComp())) { + GTEST_SKIP() << "BF16 not supported."; + } + } +}; + +TEST_F(TritonBF16BF16F32BlasTest, PropagateAlgorithmToBlas) { + // We check that the algorithm is propagated to the BLAS call. + // We also check that the kernel name matches the algorithm for Ampere. + // The algorithm for Hopper is not the one we expect because it uses TF32. + + constexpr std::string_view kHloText = R"( + HloModule t + + ENTRY main { + lhs = f32[8512,256]{1,0} parameter(0) + rhs = f32[256,8512]{1,0} parameter(1) + ROOT dot = f32[8512,8512]{1,0} dot(lhs, rhs), + algorithm=dot_bf16_bf16_f32, + lhs_contracting_dims={1}, + rhs_contracting_dims={0} + } + )"; + const std::string pattern = R"(CHECK: "algorithm":"ALG_DOT_BF16_BF16_F32")"; + TF_ASSERT_OK_AND_ASSIGN(auto module, GetOptimizedModule(kHloText)); + + auto tracer = KernelNameTracer::Create(); + tracer->start(); + EXPECT_TRUE(Run(std::move(module), /*run_hlo_passes=*/false)); + auto kernel_name = tracer->stop(); + + if (kernel_name == "kernel_name_tracer_not_implemented") return; + + auto cc = GetCudaComputeCapability(); + using CudaComputeCapabilities = + stream_executor::CudaComputeCapability::CudaComputeCapabilities; + switch (cc.major) { + case CudaComputeCapabilities::BLACKWELL: + GTEST_SKIP() << "CudaComputeCapabilities::BLACKWELL has the kernel name: " + << kernel_name; + break; + case CudaComputeCapabilities::AMPERE: + EXPECT_THAT(kernel_name, ::testing::HasSubstr("bf16gemm_")); + break; + case CudaComputeCapabilities::HOPPER: + // Hopper does not have bf16 kernels for ALG_DOT_BF16_BF16_F32 algorithm. + // As a result it uses TF32. + EXPECT_THAT(kernel_name, ::testing::HasSubstr("gemm_f32f32_tf32f32_f32")); + break; + default: + GTEST_SKIP() << "Unsupported compute capability: " << cc.major + << " has the kernel name: " << kernel_name; + } +} + TEST_F(TritonGemmTest, RejectDotInt4HLO) { constexpr std::string_view kHloText = R"( HloModule t @@ -200,6 +269,7 @@ TEST_F(TritonGemmTest, RejectTritonFusionForInt4WithMinorBatchDim) { rhs_batch_dims={0} } )"; + const std::string pattern = R"(CHECK-NOT: "kind":"__triton_gemm","triton_gemm_config")"; TF_ASSERT_OK_AND_ASSIGN(auto module, GetOptimizedModule(kHloText)); diff --git a/third_party/xla/xla/service/gpu/fusions/triton/triton_test_utils.cc b/third_party/xla/xla/service/gpu/fusions/triton/triton_test_utils.cc index 6fa9635663999e..564f44da73c42b 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/triton_test_utils.cc +++ b/third_party/xla/xla/service/gpu/fusions/triton/triton_test_utils.cc @@ -70,9 +70,10 @@ bool SupportsBF16(const stream_executor::GpuComputeCapability& cc) { CHECK(false); } -absl::Status CreateTritonIrAndFileCheck( - HloTestBase* test, absl::string_view hlo_text, - absl::string_view triton_fusion_name, absl::string_view filecheck_pattern) { +absl::Status CreateTritonIrAndFileCheck(HloTestBase* test, + absl::string_view hlo_text, + absl::string_view triton_fusion_name, + absl::string_view filecheck_pattern) { TF_ASSIGN_OR_RETURN(std::unique_ptr verified_module, test->ParseAndReturnVerifiedModule(hlo_text)); auto* comp = verified_module->GetComputationWithName(triton_fusion_name); diff --git a/third_party/xla/xla/service/gpu/transforms/gemm_rewriter.cc b/third_party/xla/xla/service/gpu/transforms/gemm_rewriter.cc index 9ba39cf977c0bb..dea1f704c5801e 100644 --- a/third_party/xla/xla/service/gpu/transforms/gemm_rewriter.cc +++ b/third_party/xla/xla/service/gpu/transforms/gemm_rewriter.cc @@ -1887,12 +1887,11 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { if (!absl::c_linear_search(supported_type, output_type)) return false; TF_ASSIGN_OR_RETURN(const se::blas::DataType output_dtype, se::gpu::AsBlasDataType(output_type)); - // TODO(tdanyluk): Investigate why don't we use the actual precision (and - // algorithm) here? Why do we use the default? - TF_ASSIGN_OR_RETURN(const se::blas::ComputationType compute_type, - se::gpu::GetBlasComputationType( - PrecisionConfig::ALG_UNSET, a_dtype, output_type, - stream_executor::blas::kDefaultComputePrecision)); + TF_ASSIGN_OR_RETURN( + const se::blas::ComputationType compute_type, + se::gpu::GetBlasComputationType( + instr.precision_config().algorithm(), a_dtype, output_type, + stream_executor::blas::kDefaultComputePrecision)); se::blas::DataType scale_type = se::gpu::GetScaleType(output_dtype, compute_type); From 27840bcfa397f4e36c66b8cdf044d69956896660 Mon Sep 17 00:00:00 2001 From: Sandeep Dasgupta Date: Fri, 27 Sep 2024 08:48:10 -0700 Subject: [PATCH 369/483] [HLO Componentization] Create hlo/builder sub-component (Phase I). This CL takes care of 1. Migrating xla/client --> xla/hlo/builder 2. Setting up build aliases in xla/client ensuring external dependencies are still satisfied. A following phase II will take care of migration of external projects dependencies from xla/client --> xla/hlo/builder PiperOrigin-RevId: 679606580 --- tensorflow/core/ops/nn_ops.cc | 2 +- third_party/xla/xla/client/BUILD | 143 +- third_party/xla/xla/client/lib/BUILD | 676 +--- third_party/xla/xla/client/lib/approx_topk.h | 54 +- .../xla/xla/client/lib/approx_topk_shape.h | 32 +- third_party/xla/xla/client/lib/arithmetic.h | 72 +- third_party/xla/xla/client/lib/broadcast.h | 17 +- third_party/xla/xla/client/lib/comparators.h | 42 +- third_party/xla/xla/client/lib/constants.h | 122 +- .../xla/xla/client/lib/conv_grad_size_util.h | 26 +- .../xla/xla/client/lib/dynamic_shaped_ops.h | 41 +- third_party/xla/xla/client/lib/loops.h | 56 +- .../xla/xla/client/lib/lu_decomposition.h | 43 +- third_party/xla/xla/client/lib/math.h | 109 +- third_party/xla/xla/client/lib/matrix.h | 141 +- third_party/xla/xla/client/lib/pooling.h | 65 +- third_party/xla/xla/client/lib/prng.h | 83 +- third_party/xla/xla/client/lib/qr.h | 34 +- third_party/xla/xla/client/lib/quantize.h | 166 +- .../xla/xla/client/lib/self_adjoint_eig.h | 23 +- third_party/xla/xla/client/lib/slicing.h | 65 +- third_party/xla/xla/client/lib/sorting.h | 20 +- third_party/xla/xla/client/lib/svd.h | 31 +- third_party/xla/xla/client/lib/testing.cc | 4 +- third_party/xla/xla/client/lib/testing.h | 2 +- third_party/xla/xla/client/lib/tridiagonal.h | 25 +- third_party/xla/xla/client/lib/tuple.h | 18 +- third_party/xla/xla/client/padding.h | 49 +- third_party/xla/xla/client/sharding_builder.h | 45 +- third_party/xla/xla/client/value_inference.h | 100 +- third_party/xla/xla/client/xla_builder.h | 3068 +--------------- third_party/xla/xla/client/xla_computation.h | 55 +- third_party/xla/xla/hlo/builder/BUILD | 190 + third_party/xla/xla/hlo/builder/lib/BUILD | 787 +++++ .../builder}/lib/approx_topk.cc | 8 +- .../xla/xla/hlo/builder/lib/approx_topk.h | 72 + .../builder}/lib/approx_topk_shape.cc | 2 +- .../xla/hlo/builder/lib/approx_topk_shape.h | 50 + .../{client => hlo/builder}/lib/arithmetic.cc | 8 +- .../xla/xla/hlo/builder/lib/arithmetic.h | 90 + .../builder}/lib/arithmetic_test.cc | 4 +- .../{client => hlo/builder}/lib/broadcast.cc | 4 +- .../xla/xla/hlo/builder/lib/broadcast.h | 35 + .../builder}/lib/comparators.cc | 6 +- .../xla/xla/hlo/builder/lib/comparators.h | 60 + .../builder}/lib/comparators_test.cc | 8 +- .../{client => hlo/builder}/lib/constants.cc | 4 +- .../xla/xla/hlo/builder/lib/constants.h | 140 + .../builder}/lib/constants_test.cc | 4 +- .../builder}/lib/conv_grad_size_util.cc | 4 +- .../xla/hlo/builder/lib/conv_grad_size_util.h | 44 + .../builder}/lib/dynamic_shaped_ops.cc | 10 +- .../xla/hlo/builder/lib/dynamic_shaped_ops.h | 59 + .../builder}/lib/generate_math_impl.py | 0 .../xla/{client => hlo/builder}/lib/logdet.cc | 14 +- .../xla/{client => hlo/builder}/lib/logdet.h | 8 +- .../builder}/lib/logdet_test.cc | 4 +- .../xla/{client => hlo/builder}/lib/loops.cc | 6 +- third_party/xla/xla/hlo/builder/lib/loops.h | 74 + .../builder}/lib/lu_decomposition.cc | 4 +- .../xla/hlo/builder/lib/lu_decomposition.h | 61 + .../xla/{client => hlo/builder}/lib/math.cc | 12 +- third_party/xla/xla/hlo/builder/lib/math.h | 127 + .../{client => hlo/builder}/lib/math_impl.h | 12 +- .../{client => hlo/builder}/lib/math_test.cc | 6 +- .../xla/{client => hlo/builder}/lib/matrix.cc | 10 +- third_party/xla/xla/hlo/builder/lib/matrix.h | 159 + .../builder}/lib/matrix_test.cc | 8 +- .../{client => hlo/builder}/lib/pooling.cc | 12 +- third_party/xla/xla/hlo/builder/lib/pooling.h | 83 + .../builder}/lib/pooling_test.cc | 6 +- .../xla/{client => hlo/builder}/lib/prng.cc | 6 +- third_party/xla/xla/hlo/builder/lib/prng.h | 101 + .../{client => hlo/builder}/lib/prng_test.cc | 6 +- .../xla/xla/{client => hlo/builder}/lib/qr.cc | 10 +- third_party/xla/xla/hlo/builder/lib/qr.h | 52 + .../{client => hlo/builder}/lib/qr_test.cc | 6 +- .../xla/xla/hlo/builder/lib/quantize.h | 184 + .../builder}/lib/quantize_test.cc | 4 +- .../builder}/lib/self_adjoint_eig.cc | 6 +- .../xla/hlo/builder/lib/self_adjoint_eig.h | 41 + .../builder}/lib/self_adjoint_eig_test.cc | 12 +- .../{client => hlo/builder}/lib/slicing.cc | 8 +- third_party/xla/xla/hlo/builder/lib/slicing.h | 83 + .../builder}/lib/slicing_test.cc | 4 +- .../{client => hlo/builder}/lib/sorting.cc | 12 +- third_party/xla/xla/hlo/builder/lib/sorting.h | 38 + .../builder}/lib/sorting_test.cc | 4 +- .../xla/{client => hlo/builder}/lib/svd.cc | 18 +- third_party/xla/xla/hlo/builder/lib/svd.h | 49 + .../{client => hlo/builder}/lib/svd_test.cc | 12 +- .../builder}/lib/tridiagonal.cc | 10 +- .../xla/xla/hlo/builder/lib/tridiagonal.h | 43 + .../builder}/lib/tridiagonal_test.cc | 6 +- .../xla/{client => hlo/builder}/lib/tuple.cc | 4 +- third_party/xla/xla/hlo/builder/lib/tuple.h | 36 + .../{client => hlo/builder}/lib/tuple_test.cc | 4 +- .../xla/{client => hlo/builder}/padding.cc | 2 +- third_party/xla/xla/hlo/builder/padding.h | 66 + .../{client => hlo/builder}/padding_test.cc | 2 +- .../builder}/sharding_builder.cc | 2 +- .../xla/xla/hlo/builder/sharding_builder.h | 60 + .../builder}/value_inference.cc | 4 +- .../xla/xla/hlo/builder/value_inference.h | 117 + .../{client => hlo/builder}/xla_builder.cc | 8 +- third_party/xla/xla/hlo/builder/xla_builder.h | 3086 +++++++++++++++++ .../builder}/xla_builder_test.cc | 10 +- .../builder}/xla_computation.cc | 2 +- .../xla/xla/hlo/builder/xla_computation.h | 73 + 109 files changed, 6368 insertions(+), 5447 deletions(-) create mode 100644 third_party/xla/xla/hlo/builder/BUILD create mode 100644 third_party/xla/xla/hlo/builder/lib/BUILD rename third_party/xla/xla/{client => hlo/builder}/lib/approx_topk.cc (98%) create mode 100644 third_party/xla/xla/hlo/builder/lib/approx_topk.h rename third_party/xla/xla/{client => hlo/builder}/lib/approx_topk_shape.cc (98%) create mode 100644 third_party/xla/xla/hlo/builder/lib/approx_topk_shape.h rename third_party/xla/xla/{client => hlo/builder}/lib/arithmetic.cc (97%) create mode 100644 third_party/xla/xla/hlo/builder/lib/arithmetic.h rename third_party/xla/xla/{client => hlo/builder}/lib/arithmetic_test.cc (97%) rename third_party/xla/xla/{client => hlo/builder}/lib/broadcast.cc (97%) create mode 100644 third_party/xla/xla/hlo/builder/lib/broadcast.h rename third_party/xla/xla/{client => hlo/builder}/lib/comparators.cc (97%) create mode 100644 third_party/xla/xla/hlo/builder/lib/comparators.h rename third_party/xla/xla/{client => hlo/builder}/lib/comparators_test.cc (98%) rename third_party/xla/xla/{client => hlo/builder}/lib/constants.cc (98%) create mode 100644 third_party/xla/xla/hlo/builder/lib/constants.h rename third_party/xla/xla/{client => hlo/builder}/lib/constants_test.cc (98%) rename third_party/xla/xla/{client => hlo/builder}/lib/conv_grad_size_util.cc (97%) create mode 100644 third_party/xla/xla/hlo/builder/lib/conv_grad_size_util.h rename third_party/xla/xla/{client => hlo/builder}/lib/dynamic_shaped_ops.cc (98%) create mode 100644 third_party/xla/xla/hlo/builder/lib/dynamic_shaped_ops.h rename third_party/xla/xla/{client => hlo/builder}/lib/generate_math_impl.py (100%) rename third_party/xla/xla/{client => hlo/builder}/lib/logdet.cc (90%) rename third_party/xla/xla/{client => hlo/builder}/lib/logdet.h (87%) rename third_party/xla/xla/{client => hlo/builder}/lib/logdet_test.cc (98%) rename third_party/xla/xla/{client => hlo/builder}/lib/loops.cc (97%) create mode 100644 third_party/xla/xla/hlo/builder/lib/loops.h rename third_party/xla/xla/{client => hlo/builder}/lib/lu_decomposition.cc (96%) create mode 100644 third_party/xla/xla/hlo/builder/lib/lu_decomposition.h rename third_party/xla/xla/{client => hlo/builder}/lib/math.cc (99%) create mode 100644 third_party/xla/xla/hlo/builder/lib/math.h rename third_party/xla/xla/{client => hlo/builder}/lib/math_impl.h (97%) rename third_party/xla/xla/{client => hlo/builder}/lib/math_test.cc (99%) rename third_party/xla/xla/{client => hlo/builder}/lib/matrix.cc (99%) create mode 100644 third_party/xla/xla/hlo/builder/lib/matrix.h rename third_party/xla/xla/{client => hlo/builder}/lib/matrix_test.cc (98%) rename third_party/xla/xla/{client => hlo/builder}/lib/pooling.cc (98%) create mode 100644 third_party/xla/xla/hlo/builder/lib/pooling.h rename third_party/xla/xla/{client => hlo/builder}/lib/pooling_test.cc (99%) rename third_party/xla/xla/{client => hlo/builder}/lib/prng.cc (99%) create mode 100644 third_party/xla/xla/hlo/builder/lib/prng.h rename third_party/xla/xla/{client => hlo/builder}/lib/prng_test.cc (97%) rename third_party/xla/xla/{client => hlo/builder}/lib/qr.cc (96%) create mode 100644 third_party/xla/xla/hlo/builder/lib/qr.h rename third_party/xla/xla/{client => hlo/builder}/lib/qr_test.cc (98%) create mode 100644 third_party/xla/xla/hlo/builder/lib/quantize.h rename third_party/xla/xla/{client => hlo/builder}/lib/quantize_test.cc (99%) rename third_party/xla/xla/{client => hlo/builder}/lib/self_adjoint_eig.cc (95%) create mode 100644 third_party/xla/xla/hlo/builder/lib/self_adjoint_eig.h rename third_party/xla/xla/{client => hlo/builder}/lib/self_adjoint_eig_test.cc (97%) rename third_party/xla/xla/{client => hlo/builder}/lib/slicing.cc (98%) create mode 100644 third_party/xla/xla/hlo/builder/lib/slicing.h rename third_party/xla/xla/{client => hlo/builder}/lib/slicing_test.cc (99%) rename third_party/xla/xla/{client => hlo/builder}/lib/sorting.cc (97%) create mode 100644 third_party/xla/xla/hlo/builder/lib/sorting.h rename third_party/xla/xla/{client => hlo/builder}/lib/sorting_test.cc (98%) rename third_party/xla/xla/{client => hlo/builder}/lib/svd.cc (98%) create mode 100644 third_party/xla/xla/hlo/builder/lib/svd.h rename third_party/xla/xla/{client => hlo/builder}/lib/svd_test.cc (97%) rename third_party/xla/xla/{client => hlo/builder}/lib/tridiagonal.cc (99%) create mode 100644 third_party/xla/xla/hlo/builder/lib/tridiagonal.h rename third_party/xla/xla/{client => hlo/builder}/lib/tridiagonal_test.cc (98%) rename third_party/xla/xla/{client => hlo/builder}/lib/tuple.cc (96%) create mode 100644 third_party/xla/xla/hlo/builder/lib/tuple.h rename third_party/xla/xla/{client => hlo/builder}/lib/tuple_test.cc (97%) rename third_party/xla/xla/{client => hlo/builder}/padding.cc (99%) create mode 100644 third_party/xla/xla/hlo/builder/padding.h rename third_party/xla/xla/{client => hlo/builder}/padding_test.cc (98%) rename third_party/xla/xla/{client => hlo/builder}/sharding_builder.cc (98%) create mode 100644 third_party/xla/xla/hlo/builder/sharding_builder.h rename third_party/xla/xla/{client => hlo/builder}/value_inference.cc (99%) create mode 100644 third_party/xla/xla/hlo/builder/value_inference.h rename third_party/xla/xla/{client => hlo/builder}/xla_builder.cc (99%) create mode 100644 third_party/xla/xla/hlo/builder/xla_builder.h rename third_party/xla/xla/{client => hlo/builder}/xla_builder_test.cc (99%) rename third_party/xla/xla/{client => hlo/builder}/xla_computation.cc (96%) create mode 100644 third_party/xla/xla/hlo/builder/xla_computation.h diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc index a15cf79ca9a5b5..4dc51f4bb9fe97 100644 --- a/tensorflow/core/ops/nn_ops.cc +++ b/tensorflow/core/ops/nn_ops.cc @@ -1522,7 +1522,7 @@ Status ApproxTopKShape(shape_inference::InferenceContext* c) { c->set_output(1, output_shape); return absl::OkStatus(); } -// LINT.ThenChange(//tensorflow/compiler/xla/client/lib/approx_topk_shape.cc) +// LINT.ThenChange(//tensorflow/compiler/xla/hlo/builder/lib/approx_topk_shape.cc) } // namespace diff --git a/third_party/xla/xla/client/BUILD b/third_party/xla/xla/client/BUILD index 115e56f66fbcfb..4cdbfb3d4ef5a2 100644 --- a/third_party/xla/xla/client/BUILD +++ b/third_party/xla/xla/client/BUILD @@ -2,7 +2,6 @@ # XLA client libraries. load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") -load("//xla:xla.bzl", "xla_cc_test") load("//xla/tsl:tsl.default.bzl", "filegroup") package( @@ -41,29 +40,13 @@ cc_library( ], ) +# Deprecated, use //third_party/tensorflow/compiler/xla/hlo/builder:padding +# instead. cc_library( name = "padding", - srcs = ["padding.cc"], hdrs = ["padding.h"], deps = [ - "//xla:types", - "//xla:util", - "//xla/tsl/lib/math:math_util", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:status", - ], -) - -xla_cc_test( - name = "padding_test", - srcs = ["padding_test.cc"], - deps = [ - ":padding", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_main", + "//xla/hlo/builder:padding", ], ) @@ -206,141 +189,45 @@ cc_library( ], ) +# Deprecated, use //third_party/tensorflow/compiler/xla/hlo/builder:sharding_builder +# instead. cc_library( name = "sharding_builder", - srcs = ["sharding_builder.cc"], hdrs = ["sharding_builder.h"], deps = [ - "//xla:array", - "//xla:shape_tree", - "//xla:shape_util", - "//xla:types", - "//xla:util", - "//xla:xla_data_proto_cc", - "@com_google_absl//absl/log:check", + "//xla/hlo/builder:sharding_builder", ], ) +# Deprecated, use //third_party/tensorflow/compiler/xla/hlo/builder:xla_computation +# instead. cc_library( name = "xla_computation", - srcs = ["xla_computation.cc"], hdrs = ["xla_computation.h"], visibility = ["//visibility:public"], deps = [ - "//xla:shape_util", - "//xla:status_macros", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/service:hlo_proto_cc", - "@com_google_absl//absl/status:statusor", + "//xla/hlo/builder:xla_computation", ], ) +# Deprecated, use //third_party/tensorflow/compiler/xla/hlo/builder:value_inference +# instead. cc_library( name = "value_inference", - srcs = ["value_inference.cc"], hdrs = ["value_inference.h"], visibility = ["//visibility:public"], deps = [ - ":xla_builder", - "//xla:comparison_util", - "//xla:literal", - "//xla:literal_util", - "//xla:shape_util", - "//xla:status_macros", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/hlo/evaluator:hlo_evaluator", - "//xla/hlo/ir:hlo", - "//xla/service:hlo_module_config", - "//xla/service:hlo_proto_cc", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/hash", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", + "//xla/hlo/builder:value_inference", ], ) +# Deprecated, use //third_party/tensorflow/compiler/xla/hlo/builder:xla_builder +# instead. cc_library( name = "xla_builder", - srcs = ["xla_builder.cc"], hdrs = ["xla_builder.h"], visibility = ["//visibility:public"], deps = [ - ":padding", - ":sharding_builder", - ":xla_computation", - "//xla:array", - "//xla:array2d", - "//xla:array3d", - "//xla:array4d", - "//xla:comparison_util", - "//xla:literal", - "//xla:literal_util", - "//xla:permutation_util", - "//xla:shape_util", - "//xla:sharding_op_util", - "//xla:status_macros", - "//xla:util", - "//xla:window_util", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/service:hlo_proto_cc", - "//xla/service:shape_inference", - "//xla/tsl/lib/core:bitmap", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/functional:function_ref", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:stacktrace", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_cc_test( - name = "xla_builder_test", - srcs = ["xla_builder_test.cc"], - deps = [ - ":padding", - ":sharding_builder", - ":value_inference", - ":xla_builder", - ":xla_computation", - "//xla:comparison_util", - "//xla:debug_options_flags", - "//xla:shape_util", - "//xla:test", - "//xla:test_helpers", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/service:hlo_parser", - "//xla/service:hlo_proto_cc", - "//xla/service:pattern_matcher", - "//xla/service:pattern_matcher_gmock", - "//xla/tests:xla_internal_test_main", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - "@com_google_googletest//:gtest", - "@local_tsl//tsl/platform:status_matchers", - "@local_tsl//tsl/platform:statusor", + "//xla/hlo/builder:xla_builder", ], ) diff --git a/third_party/xla/xla/client/lib/BUILD b/third_party/xla/xla/client/lib/BUILD index 44461b75262794..384c3564a95aec 100644 --- a/third_party/xla/xla/client/lib/BUILD +++ b/third_party/xla/xla/client/lib/BUILD @@ -1,7 +1,7 @@ # Common computation builders for XLA. load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") -load("//xla/tests:build_defs.bzl", "generate_backend_suites", "xla_test") +load("//xla/tests:build_defs.bzl", "generate_backend_suites") load("//xla/tsl:tsl.bzl", "internal_visibility") load("//xla/tsl:tsl.default.bzl", "filegroup") @@ -23,555 +23,185 @@ filegroup( # Generate test_suites for all backends, named "${backend}_tests". generate_backend_suites() +# Deprecated, use //third_party/tensorflow/compiler/xla/hlo/builder/lib:arithmetic +# instead. cc_library( name = "arithmetic", - srcs = ["arithmetic.cc"], hdrs = ["arithmetic.h"], deps = [ - ":constants", - "//xla:shape_util", - "//xla:xla_data_proto_cc", - "//xla/client:xla_builder", - "//xla/client:xla_computation", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_test( - name = "arithmetic_test", - srcs = ["arithmetic_test.cc"], - deps = [ - ":arithmetic", - "//xla:shape_util", - "//xla:xla_data_proto_cc", - "//xla/client:xla_builder", - "//xla/tests:client_library_test_base", - "//xla/tests:test_macros_header", - "//xla/tests:xla_internal_test_main", - "@com_google_absl//absl/types:span", + "//xla/hlo/builder/lib:arithmetic", ], ) +# Deprecated, use //third_party/tensorflow/compiler/xla/hlo/builder/lib:comparators +# instead. cc_library( name = "comparators", - srcs = ["comparators.cc"], hdrs = [ "comparators.h", ], deps = [ - "//xla:shape_util", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/client:xla_builder", - "//xla/client:xla_computation", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - ], -) - -xla_test( - name = "comparators_test", - srcs = ["comparators_test.cc"], - deps = [ - ":comparators", - ":constants", - "//xla:shape_util", - "//xla:test", - "//xla:xla_data_proto_cc", - "//xla/client:xla_builder", - "//xla/client:xla_computation", - "//xla/hlo/ir:hlo", - "//xla/service:hlo_proto_cc", - "//xla/tests:client_library_test_base", - "//xla/tests:test_macros_header", - "//xla/tests:xla_internal_test_main", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/strings:string_view", - "@local_tsl//tsl/platform:protobuf", + "//xla/hlo/builder/lib:comparators", ], ) +# Deprecated, use //third_party/tensorflow/compiler/xla/hlo/builder/lib:constants +# instead. cc_library( name = "constants", - srcs = ["constants.cc"], hdrs = ["constants.h"], deps = [ - "//xla:literal_util", - "//xla:shape_util", - "//xla:types", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/client:xla_builder", - "@com_google_absl//absl/status:statusor", - "@local_tsl//tsl/platform:ml_dtypes", - "@local_tsl//tsl/platform:statusor", + "//xla/hlo/builder/lib:constants", ], ) +# Deprecated, use //third_party/tensorflow/compiler/xla/hlo/builder/lib:broadcast +# instead. cc_library( name = "broadcast", - srcs = ["broadcast.cc"], hdrs = ["broadcast.h"], deps = [ - "//xla:shape_util", - "//xla:status_macros", - "//xla:types", - "//xla:xla_data_proto_cc", - "//xla/client:xla_builder", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_test( - name = "constants_test", - srcs = ["constants_test.cc"], - deps = [ - ":constants", - "//xla:shape_util", - "//xla:test", - "//xla:xla_data_proto_cc", - "//xla/client:xla_builder", - "//xla/tests:client_library_test_base", - "//xla/tests:test_macros_header", - "//xla/tests:xla_internal_test_main", + "//xla/hlo/builder/lib:broadcast", ], ) +# Deprecated, use //third_party/tensorflow/compiler/xla/hlo/builder/lib:conv_grad_size_util +# instead. cc_library( name = "conv_grad_size_util", - srcs = ["conv_grad_size_util.cc"], hdrs = ["conv_grad_size_util.h"], deps = [ - "//xla/client:padding", - "@com_google_absl//absl/log", - "@com_google_absl//absl/status:statusor", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", + "//xla/hlo/builder/lib:conv_grad_size_util", ], ) +# Deprecated, use //third_party/tensorflow/compiler/xla/hlo/builder/lib:dynamic_shaped_ops +# instead. cc_library( name = "dynamic_shaped_ops", - srcs = ["dynamic_shaped_ops.cc"], hdrs = ["dynamic_shaped_ops.h"], deps = [ - ":constants", - "//xla:shape_util", - "//xla:status_macros", - "//xla:types", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/client:value_inference", - "//xla/client:xla_builder", - "//xla/client:xla_computation", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", + "//xla/hlo/builder/lib:dynamic_shaped_ops", ], ) +# Deprecated, use //third_party/tensorflow/compiler/xla/hlo/builder/lib:loops +# instead. cc_library( name = "loops", - srcs = ["loops.cc"], hdrs = ["loops.h"], deps = [ - ":constants", - "//xla:literal_util", - "//xla:shape_util", - "//xla:status_macros", - "//xla:xla_data_proto_cc", - "//xla/client:xla_builder", - "//xla/client:xla_computation", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", + "//xla/hlo/builder/lib:loops", ], ) +# Deprecated, use //third_party/tensorflow/compiler/xla/hlo/builder/lib:math +# instead. cc_library( name = "math", - srcs = ["math.cc"], hdrs = [ "math.h", - "math_impl.h", - ], - deps = [ - ":arithmetic", - ":constants", - ":loops", - "//xla:shape_util", - "//xla:status_macros", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/client:xla_builder", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/log", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", ], -) - -xla_test( - name = "math_test", - timeout = "long", - srcs = ["math_test.cc"], - backend_tags = { - # Times out. - "ghostfish_iss": ["noasan"], - }, deps = [ - ":constants", - ":math", - "//xla:array3d", - "//xla:error_spec", - "//xla:literal", - "//xla:literal_util", - "//xla:shape_util", - "//xla:test", - "//xla:types", - "//xla:xla_data_proto_cc", - "//xla/client:xla_builder", - "//xla/service", - "//xla/tests:client_library_test_base", - "//xla/tests:test_macros_header", - "//xla/tests:xla_internal_test_main", - "//xla/tsl/lib/core:status_test_util", - "@com_google_googletest//:gtest_main", + "//xla/hlo/builder/lib:math", ], ) +# Deprecated, use //third_party/tensorflow/compiler/xla/hlo/builder/lib:matrix +# instead. cc_library( name = "matrix", - srcs = ["matrix.cc"], hdrs = ["matrix.h"], deps = [ - ":arithmetic", - ":constants", - ":slicing", - "//xla:literal", - "//xla:shape_util", - "//xla:status_macros", - "//xla:types", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/client:xla_builder", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_test( - name = "matrix_test", - srcs = ["matrix_test.cc"], - deps = [ - ":constants", - ":matrix", - ":slicing", - "//xla:array", - "//xla:array2d", - "//xla:array3d", - "//xla:array4d", - "//xla:test", - "//xla:types", - "//xla/client:xla_builder", - "//xla/tests:client_library_test_base", - "//xla/tests:test_macros_header", - "//xla/tests:xla_internal_test_main", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", + "//xla/hlo/builder/lib:matrix", ], ) +# Deprecated, use //third_party/tensorflow/compiler/xla/hlo/builder/lib:pooling +# instead. cc_library( name = "pooling", - srcs = ["pooling.cc"], hdrs = ["pooling.h"], deps = [ - ":arithmetic", - ":constants", - ":conv_grad_size_util", - "//xla:shape_util", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/client:padding", - "//xla/client:xla_builder", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_test( - name = "pooling_test", - srcs = ["pooling_test.cc"], - deps = [ - ":pooling", - "//xla:error_spec", - "//xla:shape_util", - "//xla/client:padding", - "//xla/client:xla_builder", - "//xla/tests:client_library_test_base", - "//xla/tests:test_macros_header", - "//xla/tests:xla_internal_test_main", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/types:span", + "//xla/hlo/builder/lib:pooling", ], ) +# Deprecated, use //third_party/tensorflow/compiler/xla/hlo/builder/lib:prng +# instead. cc_library( name = "prng", - srcs = ["prng.cc"], hdrs = ["prng.h"], deps = [ - ":constants", - "//xla:shape_util", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/client:xla_builder", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/log", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_test( - name = "prng_test", - srcs = ["prng_test.cc"], - deps = [ - ":constants", - ":prng", - "//xla:shape_util", - "//xla:test", - "//xla:xla_data_proto_cc", - "//xla/client:xla_builder", - "//xla/tests:client_library_test_base", - "//xla/tests:test_macros_header", - "//xla/tests:xla_internal_test_main", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", + "//xla/hlo/builder/lib:prng", ], ) +# Deprecated, use //third_party/tensorflow/compiler/xla/hlo/builder/lib:qr +# instead. cc_library( name = "qr", - srcs = ["qr.cc"], hdrs = ["qr.h"], deps = [ - ":constants", - ":matrix", - ":slicing", - "//xla:shape_util", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/client:xla_builder", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_test( - name = "qr_test", - srcs = ["qr_test.cc"], - tags = ["optonly"], - deps = [ - ":matrix", - ":qr", - "//xla:array", - "//xla:array2d", - "//xla:array3d", - "//xla:error_spec", - "//xla:shape_util", - "//xla:test", - "//xla:types", - "//xla:xla_data_proto_cc", - "//xla/client:xla_builder", - "//xla/tests:client_library_test_base", - "//xla/tests:test_macros_header", - "//xla/tests:xla_internal_test_main", - "@local_tsl//tsl/platform:statusor", + "//xla/hlo/builder/lib:qr", ], ) +# Deprecated, use //third_party/tensorflow/compiler/xla/hlo/builder/lib:lu_decomposition +# instead. cc_library( name = "lu_decomposition", - srcs = ["lu_decomposition.cc"], hdrs = ["lu_decomposition.h"], deps = [ - "//xla:shape_util", - "//xla:status_macros", - "//xla:xla_data_proto_cc", - "//xla/client:xla_builder", - "@com_google_absl//absl/status:statusor", - "@local_tsl//tsl/platform:statusor", + "//xla/hlo/builder/lib:lu_decomposition", ], ) +# Deprecated, use //third_party/tensorflow/compiler/xla/hlo/builder/lib:approx_topk +# instead. cc_library( name = "approx_topk", - srcs = ["approx_topk.cc"], hdrs = ["approx_topk.h"], deps = [ - ":approx_topk_shape", - "//xla:shape_util", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/client:xla_builder", - "//xla/client:xla_computation", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/types:span", + "//xla/hlo/builder/lib:approx_topk", ], ) +# Deprecated, use //third_party/tensorflow/compiler/xla/hlo/builder/lib:approx_topk_shape +# instead. cc_library( name = "approx_topk_shape", - srcs = ["approx_topk_shape.cc"], hdrs = ["approx_topk_shape.h"], - deps = [ - "//xla:util", - "@com_google_absl//absl/status:statusor", - ], + deps = ["//xla/hlo/builder/lib:approx_topk_shape"], ) +# Deprecated, use //third_party/tensorflow/compiler/xla/hlo/builder/lib:slicing +# instead. cc_library( name = "slicing", - srcs = ["slicing.cc"], hdrs = ["slicing.h"], deps = [ - ":arithmetic", - ":constants", - "//xla:shape_util", - "//xla:status_macros", - "//xla:types", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/client:xla_builder", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_test( - name = "slicing_test", - srcs = ["slicing_test.cc"], - deps = [ - ":slicing", - "//xla:array2d", - "//xla:array3d", - "//xla:error_spec", - "//xla:literal", - "//xla:shape_util", - "//xla:xla_data_proto_cc", - "//xla/client:xla_builder", - "//xla/tests:client_library_test_base", - "//xla/tests:test_macros_header", - "//xla/tests:xla_internal_test_main", - "@local_tsl//tsl/platform:statusor", + "//xla/hlo/builder/lib:slicing", ], ) +# Deprecated, use //third_party/tensorflow/compiler/xla/hlo/builder/lib:sorting +# instead. cc_library( name = "sorting", - srcs = ["sorting.cc"], hdrs = ["sorting.h"], deps = [ - ":comparators", - ":constants", - ":loops", - ":slicing", - "//xla:shape_util", - "//xla:types", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/client:xla_builder", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_test( - name = "sorting_test", - srcs = ["sorting_test.cc"], - deps = [ - ":sorting", - "//xla:array", - "//xla:array2d", - "//xla:error_spec", - "//xla:literal_util", - "//xla:xla_data_proto_cc", - "//xla/client:xla_builder", - "//xla/tests:client_library_test_base", - "//xla/tests:test_macros_header", - "//xla/tests:xla_internal_test_main", - "@com_google_absl//absl/algorithm:container", + "//xla/hlo/builder/lib:sorting", ], ) +# Deprecated, use //third_party/tensorflow/compiler/xla/hlo/builder/lib:quantize +# instead. cc_library( name = "quantize", hdrs = ["quantize.h"], deps = [ - ":constants", - "//xla:types", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/client:xla_builder", - "@local_tsl//tsl/platform:bfloat16", - ], -) - -xla_test( - name = "quantize_test", - srcs = ["quantize_test.cc"], - # TODO(b/122119490): re-enable TAP after fixing. - tags = [ - "manual", - "notap", - ], - deps = [ - ":quantize", - "//xla:array2d", - "//xla:test", - "//xla:types", - "//xla:util", - "//xla/client:xla_builder", - "//xla/tests:client_library_test_base", - "//xla/tests:test_macros_header", - "//xla/tests:xla_internal_test_main", - "@local_tsl//tsl/platform:bfloat16", + "//xla/hlo/builder/lib:quantize", ], ) @@ -588,8 +218,8 @@ cc_library( "//xla:xla_proto_cc", "//xla/client", "//xla/client:global_data", - "//xla/client:xla_builder", - "//xla/client:xla_computation", + "//xla/hlo/builder:xla_builder", + "//xla/hlo/builder:xla_computation", "//xla/service", "//xla/tests:test_utils", "@com_google_absl//absl/log:check", @@ -599,213 +229,51 @@ cc_library( ], ) +# Deprecated, use //third_party/tensorflow/compiler/xla/hlo/builder/lib:self_adjoint_eig +# instead. cc_library( name = "self_adjoint_eig", - srcs = ["self_adjoint_eig.cc"], hdrs = ["self_adjoint_eig.h"], deps = [ - ":slicing", - "//xla:shape_util", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/client:xla_builder", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:str_format", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_test( - name = "self_adjoint_eig_test", - srcs = ["self_adjoint_eig_test.cc"], - real_hardware_only = True, - shard_count = 5, - tags = ["optonly"], - deps = [ - ":arithmetic", - ":constants", - ":math", - ":matrix", - ":self_adjoint_eig", - "//xla:array", - "//xla:array2d", - "//xla:array3d", - "//xla:error_spec", - "//xla:shape_util", - "//xla:test", - "//xla:types", - "//xla:xla_data_proto_cc", - "//xla/client:xla_builder", - "//xla/tests:client_library_test_base", - "//xla/tests:test_macros_header", - "//xla/tests:xla_internal_test_main", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", + "//xla/hlo/builder/lib:self_adjoint_eig", ], ) +# Deprecated, use //third_party/tensorflow/compiler/xla/hlo/builder/lib:svd +# instead. cc_library( name = "svd", - srcs = ["svd.cc"], hdrs = ["svd.h"], deps = [ - ":arithmetic", - ":comparators", - ":constants", - ":loops", - ":math", - ":matrix", - ":slicing", - "//xla:shape_util", - "//xla:xla_data_proto_cc", - "//xla/client:xla_builder", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_test( - name = "svd_test", - srcs = ["svd_test.cc"], - real_hardware_only = True, - shard_count = 10, - tags = ["optonly"], - deps = [ - ":arithmetic", - ":constants", - ":matrix", - ":slicing", - ":svd", - "//xla:array2d", - "//xla:array3d", - "//xla:error_spec", - "//xla:shape_util", - "//xla:xla_data_proto_cc", - "//xla/client:xla_builder", - "//xla/tests:client_library_test_base", - "//xla/tests:test_macros_header", - "//xla/tests:xla_internal_test_main", - "@com_google_absl//absl/status:statusor", + "//xla/hlo/builder/lib:svd", ], ) +# Deprecated, use //third_party/tensorflow/compiler/xla/hlo/builder/lib:tridiagonal +# instead. cc_library( name = "tridiagonal", - srcs = ["tridiagonal.cc"], hdrs = ["tridiagonal.h"], deps = [ - ":constants", - ":loops", - ":slicing", - "//xla:shape_util", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/client:xla_builder", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_test( - name = "tridiagonal_test", - srcs = ["tridiagonal_test.cc"], - real_hardware_only = True, - shard_count = 10, - tags = ["optonly"], - deps = [ - ":slicing", - ":tridiagonal", - "//xla:array", - "//xla:array3d", - "//xla:literal", - "//xla:shape_util", - "//xla:test", - "//xla:util", - "//xla/client:xla_builder", - "//xla/tests:client_library_test_base", - "//xla/tests:test_macros_header", - "//xla/tests:xla_internal_test_main", - "@com_google_absl//absl/status", - "@local_tsl//tsl/platform:statusor", + "//xla/hlo/builder/lib:tridiagonal", ], ) +# Deprecated, use //third_party/tensorflow/compiler/xla/hlo/builder/lib:logdet +# instead. cc_library( name = "logdet", - srcs = ["logdet.cc"], - hdrs = ["logdet.h"], - deps = [ - ":arithmetic", - ":constants", - ":matrix", - ":qr", - ":slicing", - "//xla:shape_util", - "//xla:util", - "//xla/client:xla_builder", - "@com_google_absl//absl/status:statusor", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_test( - name = "logdet_test", - srcs = ["logdet_test.cc"], - tags = [ - "optonly", - ], deps = [ - ":logdet", - "//xla:array", - "//xla:array2d", - "//xla:array3d", - "//xla:error_spec", - "//xla:literal", - "//xla:literal_util", - "//xla/client:xla_builder", - "//xla/tests:client_library_test_base", - "//xla/tests:test_macros_header", - "//xla/tests:xla_internal_test_main", + "//xla/hlo/builder/lib:logdet", ], ) +# Deprecated, use //third_party/tensorflow/compiler/xla/hlo/builder/lib:tuple +# instead. cc_library( name = "tuple", - srcs = ["tuple.cc"], hdrs = ["tuple.h"], deps = [ - "//xla:shape_tree", - "//xla:shape_util", - "//xla/client:xla_builder", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/status:statusor", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_test( - name = "tuple_test", - srcs = ["tuple_test.cc"], - deps = [ - ":tuple", - "//xla:error_spec", - "//xla:literal", - "//xla:literal_util", - "//xla:shape_tree", - "//xla:shape_util", - "//xla:xla_data_proto_cc", - "//xla/client:xla_builder", - "//xla/service", - "//xla/tests:client_library_test_base", - "//xla/tests:xla_internal_test_main", - "@local_tsl//tsl/platform:statusor", + "//xla/hlo/builder/lib:tuple", ], ) diff --git a/third_party/xla/xla/client/lib/approx_topk.h b/third_party/xla/xla/client/lib/approx_topk.h index ccad3dc79175fa..175a12cad0e94a 100644 --- a/third_party/xla/xla/client/lib/approx_topk.h +++ b/third_party/xla/xla/client/lib/approx_topk.h @@ -16,57 +16,7 @@ limitations under the License. #ifndef XLA_CLIENT_LIB_APPROX_TOPK_H_ #define XLA_CLIENT_LIB_APPROX_TOPK_H_ -#include "absl/types/span.h" -#include "xla/client/xla_builder.h" -#include "xla/client/xla_computation.h" -#include "xla/xla_data.pb.h" - -namespace xla { - -// Computes approximate top-ks by aggregating top-1s in equal-sized windows. -// The number and the size of the windows are determined by the `recall_target`. -// -// operand: A sequence of multi-dimensional arrays of type T_0, ..., T_{N-1} -// init_values: N starting values for top-1 reductions -// top_k: Determines the k in top-k operation. -// reduction_dim: Determines the dimension to compute top-k. -// comparator: The comparator computation to use, which should have function -// signatore of (T_0, T_0, T_1, T_1, ..., T_{N-1}, T_{N-1}) -> bool. -// recall_target: Valid range (0, 1]. User can trade-off quality and performance -// with this knob. -// aggregate_to_topk: When true, sorts the set of approximate top-k elements and -// only keep the final k elements on TPU. This option is useful when user -// wanted to forward the approximate results to host and aggregate the results -// on CPU for better throughput. -// reduction_input_size_override: When set to a positive value, it overrides the -// size determined by operands[reduction_dim] for evaluating the recall. This -// option is useful when the given operand is only a subset of the overall -// computation in SPMD or distributed pipelines, where the true input size -// cannot be deferred by the operand shape. -// -// Returns a sequence of multidimensional arrays of type T_0, ..., T_{N-1}, -// which contains the approximate top-ks from the input operands. When -// `aggregate_to_topk` is set to true, the output size is just top_k. When -// `aggregate_to_topk` is set to false, the output size varied by the target -// recall. For target recall = 0.9, the output size is roughly 10 * top_k. For -// target recall = 0.99, the output size is roughly 100 * top_k. -// -// TODO(fchern): Support other hardware platforms. -XlaOp ApproxTopK(XlaBuilder* builder, absl::Span operands, - absl::Span init_values, int64_t top_k, - int64_t reduction_dim, const XlaComputation& comparator, - float recall_target = 0.9, bool aggregate_to_topk = true, - int64_t reduction_input_size_override = -1); - -// Fallback for platforms that haven't been optimized. -XlaOp ApproxTopKFallback(XlaBuilder* builder, absl::Span operands, - absl::Span init_values, int64_t top_k, - int64_t reduction_dim, - const XlaComputation& comparator, - float recall_target = 0.9, - bool aggregate_to_topk = true, - int64_t reduction_input_size_override = -1); - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/builder/lib/approx_topk.h" #endif // XLA_CLIENT_LIB_APPROX_TOPK_H_ diff --git a/third_party/xla/xla/client/lib/approx_topk_shape.h b/third_party/xla/xla/client/lib/approx_topk_shape.h index ef59a604adb7f2..eef1e296f36fd3 100644 --- a/third_party/xla/xla/client/lib/approx_topk_shape.h +++ b/third_party/xla/xla/client/lib/approx_topk_shape.h @@ -16,35 +16,7 @@ limitations under the License. #ifndef XLA_CLIENT_LIB_APPROX_TOPK_SHAPE_H_ #define XLA_CLIENT_LIB_APPROX_TOPK_SHAPE_H_ -#include - -#include "absl/status/statusor.h" - -namespace xla { - -// Determine the output size of the reduction dimension. This is useful for jax -// abstract eval to determine the output size. -// -// input_size: Input size of the reduction dimension. -// rank: Rank of the input operand. -// top_k: Determines the k in top-k operation. -// recall_target: Valid range (0, 1]. User can trade-off quality and performance -// with this knob. -// aggregate_to_topk: When true, sorts the set of approximate top-k elements and -// only keep the final k elements on TPU. This option is useful when user -// wanted to forward the approximate results to host and aggregate the results -// on CPU for better throughput. -// -// Returns a pair of -// 1. Reduction output size -// 2. Reduction amount in log2 form. -// -// 2. is invalid and set to -1 when the approximate output is disabled, i.e. -// top_k = 1 or aggregate_to_topk = true. -absl::StatusOr> ApproxTopKReductionOutputSize( - int64_t input_size, int64_t rank, int64_t top_k, float recall_target, - bool aggregate_to_topk, int64_t input_size_override = -1); - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/builder/lib/approx_topk_shape.h" #endif // XLA_CLIENT_LIB_APPROX_TOPK_SHAPE_H_ diff --git a/third_party/xla/xla/client/lib/arithmetic.h b/third_party/xla/xla/client/lib/arithmetic.h index c434ca7ecc430a..0b8e000a2f276b 100644 --- a/third_party/xla/xla/client/lib/arithmetic.h +++ b/third_party/xla/xla/client/lib/arithmetic.h @@ -16,75 +16,7 @@ limitations under the License. #ifndef XLA_CLIENT_LIB_ARITHMETIC_H_ #define XLA_CLIENT_LIB_ARITHMETIC_H_ -#include -#include -#include - -#include "xla/client/xla_builder.h" -#include "xla/client/xla_computation.h" -#include "xla/xla_data.pb.h" - -namespace xla { - -using XlaOpGenerator = std::function; - -// Creates a scalar computation based on a lambda and returns it. -XlaComputation CreateScalarComputation(const std::string& name, - PrimitiveType type, XlaBuilder* builder, - XlaOpGenerator generator); - -// Creates a scalar add computation and returns it. -XlaComputation CreateScalarAddComputation(PrimitiveType type, - XlaBuilder* builder); - -// Creates a scalar multiply computation and returns it. -XlaComputation CreateScalarMultiplyComputation(PrimitiveType type, - XlaBuilder* builder); - -// Creates a scalar ge computation and returns it. -XlaComputation CreateScalarGeComputation(PrimitiveType type, - XlaBuilder* builder); - -// Creates a scalar max computation and returns it. -XlaComputation CreateScalarMaxComputation(PrimitiveType type, - XlaBuilder* builder); - -// Creates a scalar min computation and returns it. -XlaComputation CreateScalarMinComputation(PrimitiveType type, - XlaBuilder* builder); - -// Creates a scalar logical AND computation and returns it. -XlaComputation CreateScalarAndComputation(PrimitiveType type, - XlaBuilder* builder); - -// Creates a scalar logical OR computation and returns it. -XlaComputation CreateScalarOrComputation(PrimitiveType type, - XlaBuilder* builder); - -// This is to be used for general purpose "identity" like reductions with zero -// for any type (ie. boolean operations for PRED and Add for real numbers). -// As an example, this operation can be used for a situation of: -// x_type = type(x) -// op = CreateScalarIdentityWithZeroComputation(x_type) -// ASSERT_TRUE(op(x, 0) == x) -// -// This functionality is used for operations that are similar to a slice, -// gather, or broadcast, but are created through a reduction. -XlaComputation CreateScalarIdentityWithZeroComputation(PrimitiveType type, - XlaBuilder* builder); - -// Returns whether any predicate in "predicates" is set. -// -// Note: if predicates is zero-sized, Any() vacuously returns false. -XlaOp Any(XlaOp predicates); - -// Returns the argmax of `input` along `axis`. `output_type` is the type to -// use for the output. In case of ties always prefers smaller index. -XlaOp ArgMax(XlaOp input, PrimitiveType output_type, int axis); - -// Dispatch to ArgMin or ArgMax above, depending on bool. -XlaOp ArgMinMax(XlaOp input, PrimitiveType output_type, int axis, bool is_min); - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/builder/lib/arithmetic.h" #endif // XLA_CLIENT_LIB_ARITHMETIC_H_ diff --git a/third_party/xla/xla/client/lib/broadcast.h b/third_party/xla/xla/client/lib/broadcast.h index d28b28133a7b15..deb85ae9ab8585 100644 --- a/third_party/xla/xla/client/lib/broadcast.h +++ b/third_party/xla/xla/client/lib/broadcast.h @@ -16,20 +16,7 @@ limitations under the License. #ifndef XLA_CLIENT_LIB_BROADCAST_H_ #define XLA_CLIENT_LIB_BROADCAST_H_ -#include "absl/status/statusor.h" -#include "absl/types/span.h" -#include "xla/client/xla_builder.h" -#include "xla/primitive_util.h" -#include "xla/types.h" -#include "xla/xla_data.pb.h" - -namespace xla { - -// Broadcasts 'input' up to shape 'output_dims', using TensorFlow broadcasting -// rules. Supports broadcasting a dimension of size x to size x*y, i.e., tiling. -absl::StatusOr BroadcastTo(XlaOp input, - absl::Span output_dims); - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/builder/lib/broadcast.h" #endif // XLA_CLIENT_LIB_BROADCAST_H_ diff --git a/third_party/xla/xla/client/lib/comparators.h b/third_party/xla/xla/client/lib/comparators.h index e5d3de12ca2df1..ad9b37d716d717 100644 --- a/third_party/xla/xla/client/lib/comparators.h +++ b/third_party/xla/xla/client/lib/comparators.h @@ -16,45 +16,7 @@ limitations under the License. #ifndef XLA_CLIENT_LIB_COMPARATORS_H_ #define XLA_CLIENT_LIB_COMPARATORS_H_ -#include -#include -#include - -#include "absl/types/span.h" -#include "xla/client/xla_builder.h" -#include "xla/client/xla_computation.h" -#include "xla/xla_data.pb.h" - -namespace xla { - -// Creates a scalar less-than computation and returns it. The created -// computation has 2 * 'operand_types.size()' many parameters, where parameters -// 2 * i and 2 * i + 1 are a scalar with primitive type 'operand_types[i]'. The -// computation compares the first two parameters. For floating point types, a -// total order is created where -// -NaN < -infinity < ... < -0 < 0 < ... < infinity < NaN -XlaComputation CreateScalarLtComputation( - const std::vector& operand_types, XlaBuilder* builder); - -// Creates a scalar greater-than computation and returns it. The created -// computation has 2 * 'operand_types.size()' many parameters, where parameters -// 2 * i and 2 * i + 1 are a scalar with primitive type 'operand_types[i]'. The -// computation compares the first two parameters. For floating point types, a -// total order is created where -// NaN > infinity > ... > 0 > -0 > ... > -infinity > -NaN -XlaComputation CreateScalarGtComputation( - const std::vector& operand_types, XlaBuilder* builder); - -// Creates a scalar comparison computation and returns it. This function takes -// a vector of comparator functions to compare the operands where the function -// isn't nullopt with the specified comparator at that location. -XlaComputation CreateScalarComparisonComputation( - const std::string& name, const std::vector& operand_types, - const std::vector< - std::optional)>>& - generators, - XlaBuilder* builder); - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/builder/lib/comparators.h" #endif // XLA_CLIENT_LIB_COMPARATORS_H_ diff --git a/third_party/xla/xla/client/lib/constants.h b/third_party/xla/xla/client/lib/constants.h index 6f25b82d077cb9..2135f481977396 100644 --- a/third_party/xla/xla/client/lib/constants.h +++ b/third_party/xla/xla/client/lib/constants.h @@ -16,125 +16,7 @@ limitations under the License. #ifndef XLA_CLIENT_LIB_CONSTANTS_H_ #define XLA_CLIENT_LIB_CONSTANTS_H_ -#include - -#include "absl/status/statusor.h" -#include "xla/client/xla_builder.h" -#include "xla/primitive_util.h" -#include "xla/shape.h" -#include "xla/shape_util.h" -#include "xla/types.h" -#include "xla/util.h" -#include "xla/xla_data.pb.h" -#include "tsl/platform/ml_dtypes.h" -#include "tsl/platform/statusor.h" - -namespace xla { - -// Returns scalar 'value' as a scalar of 'type'. Unlike ConstantR0, 'type' is -// determined at C++ run-time, rather than C++ compile-time. -// If 'value' is floating point but 'type' is not, or if 'value' is complex but -// 'type' is not, an error will be returned. This is to catch accidental -// truncation; in such cases, use an explicit cast. -template -XlaOp ConstantR0WithType(XlaBuilder* builder, PrimitiveType type, T value) { - if (std::is_floating_point::value && - !(primitive_util::IsFloatingPointType(type) || - primitive_util::IsComplexType(type))) { - return builder->ReportError(InvalidArgument( - "Invalid cast from floating point type to %s in ConstantR0WithType.", - PrimitiveType_Name(type))); - } - if (std::is_same::value && - !primitive_util::IsComplexType(type)) { - return builder->ReportError(InvalidArgument( - "Invalid cast from complex type to %s in ConstantR0WithType.", - PrimitiveType_Name(type))); - } - return primitive_util::PrimitiveTypeSwitch( - [&](auto primitive_type_constant) -> XlaOp { - if constexpr (primitive_util::IsArrayType(primitive_type_constant)) { - using NativeT = primitive_util::NativeTypeOf; - return ConstantR0(builder, static_cast(value)); - } - return builder->ReportError( - InvalidArgument("Invalid type for ConstantR0WithType (%s).", - PrimitiveType_Name(type))); - }, - type); -} - -// Returns a scalar containing 'value' cast to the same run-time type as -// 'prototype'. -// If 'value' is floating point but 'prototype' is not, or if 'value' is complex -// 'prototype' is not, an error will be returned. -template -XlaOp ScalarLike(XlaOp prototype, T value) { - XlaBuilder* builder = prototype.builder(); - return builder->ReportErrorOrReturn([&]() -> absl::StatusOr { - TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(prototype)); - return ConstantR0WithType(builder, shape.element_type(), value); - }); -} - -// Returns an array or scalar containing copies of `value` cast to the same -// run-type type as `prototype` and broadcast to the same dimensions as -// `prototype`. -// -// If `prototype` is not a scalar or array, returns an error. -template -XlaOp FullLike(XlaOp prototype, T value) { - XlaBuilder* builder = prototype.builder(); - return builder->ReportErrorOrReturn([&]() -> absl::StatusOr { - TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(prototype)); - if (ShapeUtil::IsScalar(shape) || shape.IsArray()) { - return Broadcast(ScalarLike(prototype, value), shape.dimensions()); - } else { - return InvalidArgument( - "Prototype shape for BroadcastConstantLike must be a scalar or " - "array, but was %s", - shape.ToString()); - } - }); -} - -// Returns a scalar with value '0' of 'type'. -XlaOp Zero(XlaBuilder* builder, PrimitiveType type); - -// Returns a zero-filled tensor with shape `shape`. -XlaOp Zeros(XlaBuilder* builder, const Shape& shape); - -// Returns a zero-filled tensor with the same shape as `prototype`. -XlaOp ZerosLike(XlaOp prototype); - -// Returns a scalar with value '1' of 'type'. -XlaOp One(XlaBuilder* builder, PrimitiveType type); - -// Returns the machine epsilon for floating-point type `type`, i.e., -// the difference between 1.0 and the next representable value. -XlaOp Epsilon(XlaBuilder* builder, PrimitiveType type); - -// Returns the minimum representable finite or infinite value for 'type'. -// Returns '-inf' for floating-point types. -XlaOp MinValue(XlaBuilder* builder, PrimitiveType type); - -// Returns the minimum representable finite value for 'type'. For a floating -// point type, this is equal to -MaxFiniteValue(). -XlaOp MinFiniteValue(XlaBuilder* builder, PrimitiveType type); - -// Returns the minimum positive normal value for floating-point type `type`. -XlaOp MinPositiveNormalValue(XlaBuilder* builder, PrimitiveType type); - -// Returns the maximum representable finite or infinite value for 'type'. -// Returns 'inf' for floating-point types. -XlaOp MaxValue(XlaBuilder* builder, PrimitiveType type); - -// Returns the maximum representable finite value for 'type'. -XlaOp MaxFiniteValue(XlaBuilder* builder, PrimitiveType type); - -// Returns a nan for the given type. Only valid for real-valued fp types. -XlaOp NanValue(XlaBuilder* builder, PrimitiveType type); - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/builder/lib/constants.h" #endif // XLA_CLIENT_LIB_CONSTANTS_H_ diff --git a/third_party/xla/xla/client/lib/conv_grad_size_util.h b/third_party/xla/xla/client/lib/conv_grad_size_util.h index ca56ada8b55f25..e991982968da9e 100644 --- a/third_party/xla/xla/client/lib/conv_grad_size_util.h +++ b/third_party/xla/xla/client/lib/conv_grad_size_util.h @@ -16,29 +16,7 @@ limitations under the License. #ifndef XLA_CLIENT_LIB_CONV_GRAD_SIZE_UTIL_H_ #define XLA_CLIENT_LIB_CONV_GRAD_SIZE_UTIL_H_ -#include "absl/status/statusor.h" -#include "xla/client/padding.h" - -namespace xla { - -// Information about a single spatial dimension for a convolution gradients and -// windowed operations. -struct SpatialDimensionOutputSizeAndPadding { - // Effective size of the operation output (potentially expanded). - int64_t output_size; - // Number of padding elements to be added before/after this dimension of - // the input when computing the input gradient. - int64_t pad_before; - int64_t pad_after; -}; - -// Verifies that the dimensions all match, and computes the size and padding of -// a spatial dimension for convolution gradient operations. -absl::StatusOr -ConvGradExtractAndVerifyDimension(int64_t input_size, int64_t filter_size, - int64_t output_size, int64_t dilation, - int64_t stride, Padding padding); - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/builder/lib/conv_grad_size_util.h" #endif // XLA_CLIENT_LIB_CONV_GRAD_SIZE_UTIL_H_ diff --git a/third_party/xla/xla/client/lib/dynamic_shaped_ops.h b/third_party/xla/xla/client/lib/dynamic_shaped_ops.h index 31305bd90a7b58..cf62a37d6f920e 100644 --- a/third_party/xla/xla/client/lib/dynamic_shaped_ops.h +++ b/third_party/xla/xla/client/lib/dynamic_shaped_ops.h @@ -16,44 +16,7 @@ limitations under the License. #ifndef XLA_CLIENT_LIB_DYNAMIC_SHAPED_OPS_H_ #define XLA_CLIENT_LIB_DYNAMIC_SHAPED_OPS_H_ -#include "absl/status/statusor.h" -#include "absl/types/span.h" -#include "xla/client/lib/constants.h" -#include "xla/client/value_inference.h" -#include "xla/client/xla_builder.h" -#include "xla/client/xla_computation.h" -#include "xla/primitive_util.h" -#include "xla/types.h" -#include "xla/xla_data.pb.h" - -namespace xla { - -// Similar to static shaped conditional, but allows true_computation and -// false_computation to have different dimension sizes (ranks still have to be -// the same). Fall back to static conditional if dynamism is not presented. -XlaOp DynamicConditional(XlaBuilder* builder, XlaOp predicate, - XlaOp true_operand, - const XlaComputation& true_computation, - XlaOp false_operand, - const XlaComputation& false_computation); - -// Similar to DynamicConditional, but support multiple branches. -XlaOp DynamicConditional( - XlaBuilder* builder, XlaOp branch_index, - absl::Span branch_computations, - absl::Span branch_operands); - -// Similar to SetDimensionSize, but automatically adjust the bound of output if -// a tighter one can be inferred by `value_inference`. -absl::StatusOr SetDimensionSizeWithRebound( - ValueInference* value_inference, XlaOp operand, XlaOp dimension_size, - int64_t dimension); - -// Take a `operand` tensor and a R1 tensor `size_vector` representing the sizes -// of `operand`, Call SetDimensionSize if for each dimension whose size is -// dynamic. -absl::StatusOr SetAllDimensionSizes(ValueInference* value_inference, - XlaOp operand, XlaOp size_vector); -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/builder/lib/dynamic_shaped_ops.h" #endif // XLA_CLIENT_LIB_DYNAMIC_SHAPED_OPS_H_ diff --git a/third_party/xla/xla/client/lib/loops.h b/third_party/xla/xla/client/lib/loops.h index 3b9855e58cc3dd..d714efeaa415f1 100644 --- a/third_party/xla/xla/client/lib/loops.h +++ b/third_party/xla/xla/client/lib/loops.h @@ -16,59 +16,7 @@ limitations under the License. #ifndef XLA_CLIENT_LIB_LOOPS_H_ #define XLA_CLIENT_LIB_LOOPS_H_ -#include -#include - -#include "absl/status/statusor.h" -#include "absl/strings/string_view.h" -#include "absl/types/span.h" -#include "xla/client/xla_builder.h" -#include "xla/client/xla_computation.h" -#include "xla/xla_data.pb.h" - -namespace xla { - -// Function that builds a loop condition. Takes as input a sequence of input -// values, and returns a boolean value representing if the condition succeeds. -typedef std::function(absl::Span, - XlaBuilder*)> - WhileLoopHelperConditionFunction; - -// Function that builds a loop body. Takes as input a sequence of input values -// and returns a sequence of output values. -typedef std::function>( - absl::Span, XlaBuilder*)> - WhileLoopHelperBodyFunction; - -// Helper function for building an XLA while loop, where the values carried by -// the loop are a tuple of values, e.g., (a, b, c): -// while( -// condition: (a, b, c) -> bool, -// body: (a, b, c) -> (a, b, c) -// init: (a, b, c) -// ) -// 'name' is a descriptive name for the loop. -absl::StatusOr> WhileLoopHelper( - const WhileLoopHelperConditionFunction& condition_function, - const WhileLoopHelperBodyFunction& body_function, - absl::Span initial_values, absl::string_view name, - XlaBuilder* builder); - -// Builds an XLA loop that repeats a computation `num_iterations` times. -// -// The body function (ForEachIndexBodyFunction) takes as input a pair of -// (current iteration number, loop-carried values), and returns an updated -// vector of the loop-carried values. -typedef std::function>( - XlaOp, absl::Span, XlaBuilder*)> - ForEachIndexBodyFunction; - -absl::StatusOr> ForEachIndex( - int64_t num_iterations, PrimitiveType num_iterations_type, - const ForEachIndexBodyFunction& body_function, - absl::Span initial_values, absl::string_view name, - XlaBuilder* builder); - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/builder/lib/loops.h" #endif // XLA_CLIENT_LIB_LOOPS_H_ diff --git a/third_party/xla/xla/client/lib/lu_decomposition.h b/third_party/xla/xla/client/lib/lu_decomposition.h index a2d26e02f4e635..752e84c9d2b12f 100644 --- a/third_party/xla/xla/client/lib/lu_decomposition.h +++ b/third_party/xla/xla/client/lib/lu_decomposition.h @@ -16,46 +16,7 @@ limitations under the License. #ifndef XLA_CLIENT_LIB_LU_DECOMPOSITION_H_ #define XLA_CLIENT_LIB_LU_DECOMPOSITION_H_ -#include "xla/client/xla_builder.h" -#include "xla/xla_data.pb.h" - -namespace xla { - -// Computes the LU decomposition with partial pivoting of a batch of matrices. -// -// Given a (batched) matrix a with shape [..., m, n], computes the matrix -// decomposition A = P @ L @ U where P is a permutation matrix, L is a -// lower-triangular matrix with unit diagonal entries, and U is an -// upper-triangular matrix. -// -// L and U are returned as a single matrix [..., m, n] containing both L and U -// packed in the same array. The unit diagonal of L is not represented -// explicitly. -// -// The permutation matrix P is returned in two forms, both as `pivots`, which is -// an s32[..., min(m, n)] array that describes a sequence of row-swaps in the -// style of LAPACK's xGETRF API, and `permutation`, which is a s32[..., m] array -// which gives the permutation to apply to the rows. We return both -// representations because they are each useful for different purposes; `pivots` -// is useful for computing the sign of a determinant, whereas `permutation` can -// be used via a Gather operation to permute the rows of a matrix. -// -// This method is only implemented on TPU at the moment. -// TODO(b/168208200): the implementation only supports F32 arrays. Handle the -// complex case. -struct LuDecompositionResult { - // The LU decomposition, with both L and U packed into an array with shape - // [..., m, n]. - XlaOp lu; - // An array of shape s32[..., min(m, n)] containing the pivot rows. - XlaOp pivots; - // An array of shape s32[..., m], containing an another representation of the - // pivots as a permutation. - XlaOp permutation; -}; - -LuDecompositionResult LuDecomposition(XlaOp a); - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/builder/lib/lu_decomposition.h" #endif // XLA_CLIENT_LIB_LU_DECOMPOSITION_H_ diff --git a/third_party/xla/xla/client/lib/math.h b/third_party/xla/xla/client/lib/math.h index 74b8a387a416de..9956776ee87d1a 100644 --- a/third_party/xla/xla/client/lib/math.h +++ b/third_party/xla/xla/client/lib/math.h @@ -16,112 +16,7 @@ limitations under the License. #ifndef XLA_CLIENT_LIB_MATH_H_ #define XLA_CLIENT_LIB_MATH_H_ -#include "xla/client/xla_builder.h" - -namespace xla { - -// Determines whether operand is +/-inf or nan. -// -// Raises an error if called on integral or complex values. -XlaOp IsPosInf(XlaOp operand); -XlaOp IsNegInf(XlaOp operand); -XlaOp IsInf(XlaOp operand); -XlaOp IsNan(XlaOp operand); - -// Determines whether operand is equal to -0. -// -// Raises an error for integral or complex values. -XlaOp IsNegZero(XlaOp operand); - -// Returns the next number after 'from' in the direction of 'to' the same way -// std::nextafter(from, to) would. -XlaOp NextAfter(XlaOp from, XlaOp to); - -// Computes the square of 'operand'. -XlaOp Square(XlaOp operand); - -// Computes the reciprocal of 'operand'. -XlaOp Reciprocal(XlaOp operand); - -// Computes an approximation of the error function complement (1 - erf(x)). -XlaOp Erfc(XlaOp x); - -// Computes an approximation of the inverse of the error function. -XlaOp ErfInv(XlaOp x); - -// Computes an approximation of the lgamma function. -XlaOp Lgamma(XlaOp input); - -// Computes an approximation of the digamma function. -XlaOp Digamma(XlaOp input); - -// Computes an approximation of the incomplete gamma function. -XlaOp Igamma(XlaOp a, XlaOp x); - -// Computes an approximation of the derivative of the incomplete gamma function -// with respect to a. -XlaOp IgammaGradA(XlaOp a, XlaOp x); - -// Computes an approximation of the derivative of a sample `x` from a `Gamma(a, -// 1)` distribution with respect to a. -XlaOp RandomGammaGrad(XlaOp a, XlaOp x); - -// Computes an approximation of the complementary incomplete gamma function. -XlaOp Igammac(XlaOp a, XlaOp x); - -// Computes the Polygamma of two arguments. -XlaOp Polygamma(XlaOp n, XlaOp x); - -// Computes the Riemann zeta function of two arguments. -XlaOp Zeta(XlaOp x, XlaOp q); - -// Rounds the given number to even when the number is equidistant between two -// integers. -XlaOp RoundToEven(XlaOp x); - -// Trigonometric functions - -// Computes the arc cosine of 'x'. -XlaOp Acos(XlaOp x); - -// Computes the arc sine of 'x'. -XlaOp Asin(XlaOp x); - -// Computes the arc tangent of 'x'. -XlaOp Atan(XlaOp x); - -// Hyperbolic trigonometric functions - -// Computes the inverse hyperbolic cosine of 'x'. -XlaOp Acosh(XlaOp x); - -// Computes the inverse hyperbolic sine of 'x'. -XlaOp Asinh(XlaOp x); - -// Computes the inverse hyperbolic tangent of 'x'. -XlaOp Atanh(XlaOp x); - -// Computes the hyperbolic cosine of 'x'. -XlaOp Cosh(XlaOp x); - -// Computes the hyperbolic sine of 'x'. -XlaOp Sinh(XlaOp x); - -// Applies a complex conjugation operation if 'a' is complex and 'conjugate' -// is true, otherwise returns its argument. -xla::XlaOp MaybeConjugate(xla::XlaOp x, bool conjugate); - -// Computes the Modified Bessel function of the first kind of the zeroth order -// at x. -XlaOp BesselI0e(XlaOp x); - -// Computes the Modified Bessel function of the first kind of the first order -// at x. -XlaOp BesselI1e(XlaOp x); - -// Computes the Regularized Incomplete Beta function. -XlaOp RegularizedIncompleteBeta(XlaOp a, XlaOp b, XlaOp x); - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/builder/lib/math.h" #endif // XLA_CLIENT_LIB_MATH_H_ diff --git a/third_party/xla/xla/client/lib/matrix.h b/third_party/xla/xla/client/lib/matrix.h index df3a2e878d88a7..aaf938786fc020 100644 --- a/third_party/xla/xla/client/lib/matrix.h +++ b/third_party/xla/xla/client/lib/matrix.h @@ -16,144 +16,7 @@ limitations under the License. #ifndef XLA_CLIENT_LIB_MATRIX_H_ #define XLA_CLIENT_LIB_MATRIX_H_ -#include -#include -#include -#include - -#include "absl/status/statusor.h" -#include "absl/strings/string_view.h" -#include "absl/types/span.h" -#include "xla/client/xla_builder.h" -#include "xla/types.h" -#include "xla/xla_data.pb.h" - -namespace xla { - -// Returns an m x n matrix with 1s on the diagonal elements, zeros everywhere -// else. -XlaOp IdentityMatrix(XlaBuilder* builder, PrimitiveType type, int64_t m, - int64_t n); - -// Returns a mask where the 'diagonal'-th diagonal is true and everything else -// is false. -XlaOp GetDiagonalMask(XlaOp x, int diagonal = 0); - -// Get the diagonals of the last two dimensions. Use k>0 for diagonals above the -// main diagonal, and k<0 for diagonals below the main diagonal. -// -// If 'x' has shape [..., M, N] -// If k >= 0: then the output has shape [..., min(M, N - k)], containing the -// diagonal elements (i.e., with indices [..., i, i + k]). -// If k < 0: then the output has shape [..., min(M + k, N)], containing the -// diagonal elements (i.e., with indices [..., i - k, i]). -XlaOp GetMatrixDiagonal(XlaOp x, int k = 0); -XlaOp GetMatrixDiagonalViaGather(XlaOp x, int k = 0); - -// Places diag along the kth diagonal of target. -XlaOp SetMatrixDiagonal(XlaOp matrix, XlaOp diag, int k = 0); - -// Returns a lower-triangular mask, i.e., true below and including the -// `diagonal`-th diagonal and false above that diagonal. -XlaOp TriangleMask(XlaOp x, int diagonal); - -// Get the upper or lower triangle part of the last two dimensions -XlaOp Triangle(XlaOp x, bool lower); - -// Get the upper triangle part of the last two dimensions -XlaOp UpperTriangle(XlaOp x); - -// Get the lower triangle part of the last two dimensions -XlaOp LowerTriangle(XlaOp x); - -// If x is an array of shape [..., n, n], symmetrizes the matrix by replacing -// the upper triangle with the transpose of the lower triangle (if lower is -// True, vice-versa otherwise). If the type of `x` is complex, makes the matrix -// Hermitian by taking the conjugate of the complex part and setting the -// complex diagonal to zero. -XlaOp Symmetrize(XlaOp x, bool lower); - -// Multiplies slices of two tensors in batches. - -// Multiplies all slices of `Tensor` `x` and `y` (each slice can be -// viewed as an element of a batch), and arranges the individual results -// in a single output tensor of the same batch size. -// -// The input tensors `x` and `y` are 2-D or higher with shape `[..., r_x, c_x]` -// and `[..., r_y, c_y]`. -// -// The output tensor is 2-D or higher with shape `[..., r_o, c_o]`, where: -// -// r_o = c_x if transpose_x else r_x -// c_o = r_y if transpose_y else c_y -// -// It is computed as: -// -// output[..., :, :] = matrix(x[..., :, :]) * matrix(y[..., :, :]) -xla::XlaOp BatchDot( - xla::XlaOp x, xla::XlaOp y, - xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::DEFAULT, - std::optional preferred_element_type = std::nullopt); -xla::XlaOp BatchDot( - xla::XlaOp x, bool transpose_x, xla::XlaOp y, bool transpose_y, - xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::DEFAULT, - std::optional preferred_element_type = std::nullopt, - bool grad_x = false, bool grad_y = false); - -// Parse an einsum string into dimension numbers: -// "ab,cb->ac" -// becomes: -// {{0, 1},{2, 1},{0, 2}} -// -// Each occurrence of ellipsis ("...") occurring in the input is replaced with -// the same numeric dimensions. The number of such dimensions is inferred from -// x_rank and y_rank. For example: -// einsum_config: "...ab,...bcd->...acd" -// x_rank: 4 -// y_rank: 5 -// becomes: -// {{0, 1, 2, 3},{0, 1, 3, 4, 5},{0, 1, 2, 4, 5}} -// -// NOTE: This function is meant for testing, there is no need to call it -// directly. - -absl::StatusOr, 3>> ParseEinsumString( - absl::string_view einsum_config, int64_t x_rank, int64_t y_rank); - -// If an einsum config does not contain an -> one will be added and the output -// config will be the sorted characters with any ellipsis at the beginning. -// Returns an empty string if the einsum string already has an ->. -std::string NormalizeEinsumString(absl::string_view einsum_config); - -// Supports two operand einsum notation like "ab,cb->ac". -xla::XlaOp Einsum( - xla::XlaOp x, xla::XlaOp y, absl::string_view einsum_config, - xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::DEFAULT, - std::optional preferred_element_type = std::nullopt, - bool grad_x = false, bool grad_y = false); -xla::XlaOp Einsum( - xla::XlaOp x, absl::string_view einsum_config, - xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::DEFAULT); - -// Same as above but supporting numeric labels on dimensions. So "ab,cb->ac" -// becomes: -// x_config = {0, 1} -// y_config = {2, 1} -// output_config = {0, 2} -xla::XlaOp Einsum( - xla::XlaOp x, absl::Span x_config, xla::XlaOp y, - absl::Span y_config, absl::Span output_config, - xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::DEFAULT, - std::optional preferred_element_type = std::nullopt, - bool grad_x = false, bool grad_y = false); - -// Transposes a stack of matrices `x` by swapping the last two dimensions. -xla::XlaOp TransposeInMinorDims(xla::XlaOp x); - -// Transposes `x` in its minor dimensions if `transpose` is true, otherwise -// returns `x` unchanged. -xla::XlaOp MaybeTransposeInMinorDims(xla::XlaOp x, bool transpose); - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/builder/lib/matrix.h" #endif // XLA_CLIENT_LIB_MATRIX_H_ diff --git a/third_party/xla/xla/client/lib/pooling.h b/third_party/xla/xla/client/lib/pooling.h index eb0a43029b359d..22f3d2f0b07b9c 100644 --- a/third_party/xla/xla/client/lib/pooling.h +++ b/third_party/xla/xla/client/lib/pooling.h @@ -16,68 +16,7 @@ limitations under the License. #ifndef XLA_CLIENT_LIB_POOLING_H_ #define XLA_CLIENT_LIB_POOLING_H_ -#include -#include - -#include "absl/container/inlined_vector.h" -#include "absl/types/span.h" -#include "xla/client/padding.h" -#include "xla/client/xla_builder.h" - -namespace xla { - -// Tensor format for reduce window operations. -class TensorFormat { - public: - TensorFormat(int batch_dimension, int feature_dimension, - absl::Span spatial_dimensions) - : batch_dimension_(batch_dimension), - feature_dimension_(feature_dimension), - spatial_dimensions_(spatial_dimensions.begin(), - spatial_dimensions.end()) {} - - int batch_dimension() const { return batch_dimension_; } - - int feature_dimension() const { return feature_dimension_; } - - int spatial_dimension(int dim) const { return spatial_dimensions_[dim]; } - - int num_spatial_dims() const { return spatial_dimensions_.size(); } - - private: - // The number of the dimension that represents the batch. - int batch_dimension_; - // The number of the dimension that represents the features. - int feature_dimension_; - // The dimension numbers for the spatial dimensions. - absl::InlinedVector spatial_dimensions_; -}; - -// Computes the max pool of 'operand'. -XlaOp MaxPool(XlaOp operand, absl::Span kernel_size, - absl::Span stride, Padding padding, - const TensorFormat& data_format); - -// Computes the average pool of 'operand'. -XlaOp AvgPool(XlaOp operand, absl::Span kernel_size, - absl::Span stride, - absl::Span> padding, - const TensorFormat& data_format, bool counts_include_padding); - -// Returns the list of low and high padding elements in each spatial dimension -// for the given 'padding' specification. -std::vector> MakeSpatialPadding( - absl::Span input_size, absl::Span kernel_size, - absl::Span stride, Padding padding, - const TensorFormat& data_format); - -// Computes the average pool gradient. -XlaOp AvgPoolGrad(XlaOp out_backprop, absl::Span gradients_size, - absl::Span kernel_size, - absl::Span stride, - absl::Span> spatial_padding, - const TensorFormat& data_format, bool counts_include_padding); - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/builder/lib/pooling.h" #endif // XLA_CLIENT_LIB_POOLING_H_ diff --git a/third_party/xla/xla/client/lib/prng.h b/third_party/xla/xla/client/lib/prng.h index ef78a881c19460..0c9e460ba10cbb 100644 --- a/third_party/xla/xla/client/lib/prng.h +++ b/third_party/xla/xla/client/lib/prng.h @@ -16,86 +16,7 @@ limitations under the License. #ifndef XLA_CLIENT_LIB_PRNG_H_ #define XLA_CLIENT_LIB_PRNG_H_ -#include -#include -#include - -#include "absl/types/span.h" -#include "xla/client/xla_builder.h" -#include "xla/shape.h" -#include "xla/xla_data.pb.h" - -namespace xla { - -// Records the bits and state generated by a random number generator. -struct RngOutput { - XlaOp value; - XlaOp state; -}; - -// A BitGenerator returns random bits and updated random bit generator state. -// -// key: is a value input to a random number generator that can affect the -// sequence of number it will generate. A random number generator constructs -// its seed using the key and the initial state. The tf2xla bridge passes the -// seed operand of a tensorflow random operation as a key to the random bit -// generator, for example. -// initial_state: initial_state is the initial state of the current random -// number generation. It could be 0 for a stateless random operation, and -// the returned state from a previous execution for a stateful random -// operation. -// shape: the shape of the random bits. -using BitGeneratorTy = std::function; - -// Implements the ThreeFry counter-based PRNG algorithm. -// Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3. -// http://www.thesalmons.org/john/random123/papers/random123sc11.pdf -RngOutput ThreeFryBitGenerator(XlaOp key, XlaOp initial_state, - const xla::Shape& shape); - -// Implements the Philox algorithm to generate random numbers in parallel. -// Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3. -// http://www.thesalmons.org/john/random123/papers/random123sc11.pdf -// -// The paper presents a few variants of the Philox algorithm, we picked the -// 4x32_10 version of the algorithm for the following reasons: -// . 4x32 uses 32-bit multiplication which is fast on GPUs. -// . The authors recommend the 10-round variant, and TensorFlow also uses it. -RngOutput PhiloxBitGenerator(XlaOp key, XlaOp initial_state, - const Shape& shape); -// Returns a scrambled pair of (state, key) from a single key. -std::pair ScramblePhiloxKey(XlaOp key); - -// Uses the given bit generator to generate random bits and then converts the -// random bits to random numbers of uniform distribution in the given range. -// Returns the random numbers and the state of the random number generator. -// This function is for shape with floating point element types. -RngOutput UniformFloatingPointDistribution(XlaOp key, XlaOp initial_state, - BitGeneratorTy bit_generator, - XlaOp minval, XlaOp maxval, - const xla::Shape& shape); - -// Similar to UniformFloatingPointDistribution but for shape with integer -// element types. -RngOutput UniformIntDistribution(XlaOp key, XlaOp initial_state, - BitGeneratorTy bit_generator, XlaOp minval, - XlaOp maxval, const xla::Shape& shape); - -// Uses the given bit generator to generate random bits and then converts the -// random bits to random numbers of normal distribution. -// Returns the random numbers and the state of the random number generator. -RngOutput NormalFloatingPointDistribution(XlaOp key, XlaOp initial_state, - BitGeneratorTy bit_generator, - const xla::Shape& shape); - -// Concatenates scalars into a vector. -xla::XlaOp ConcatScalars(xla::XlaBuilder* builder, - absl::Span scalars); - -// Increases Philox counter (an uint128_t) by a delta (an uint64_t). -xla::XlaOp PhiloxIncreaseCounter(xla::XlaOp counter, xla::XlaOp delta); - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/builder/lib/prng.h" #endif // XLA_CLIENT_LIB_PRNG_H_ diff --git a/third_party/xla/xla/client/lib/qr.h b/third_party/xla/xla/client/lib/qr.h index ce51ab342bb39b..743b36503b6175 100644 --- a/third_party/xla/xla/client/lib/qr.h +++ b/third_party/xla/xla/client/lib/qr.h @@ -16,37 +16,7 @@ limitations under the License. #ifndef XLA_CLIENT_LIB_QR_H_ #define XLA_CLIENT_LIB_QR_H_ -#include "xla/client/xla_builder.h" -#include "xla/xla_data.pb.h" - -namespace xla { - -// Computes the QR decompositions of a batch of matrices. That is, -// given a (batched) matrix a, computes an orthonormal matrix Q and an -// upper-triangular matrix R such that a = QR. -// `a` must be a (batched) matrix of size [..., m, n]. -struct QrDecomposition { - // A matrix with the same shape as the input matrix `a`, whose upper triangle - // (inclusive of the diagonal) is the matrix R, and whose lower triangle - // (exclusive of the diagonal) contains the elementary Householder reflectors. - // This is the same output format as used by LAPACK's xGEQRF routine. - XlaOp q_and_r; - // A vector of shape [..., min(m, n)] containing the scalar factors of the - // elementary Householder reflectors. - XlaOp taus; -}; - -QrDecomposition Qr(XlaOp a); - -// Given `a` and `taus` as returned by `QRDecomposition`, compute the product of -// the elementary Householder reflectors (i.e., the matrix Q of the QR -// decomposition). The equivalent LAPACK routine is xORGQR/xUNGQR. -XlaOp ProductOfElementaryHouseholderReflectors(XlaOp a, XlaOp taus); - -// Helper that combines `Qr` and `ProductOfElementaryHouseholderReflectors` to -// compute explicit matrices `q` and `r`. -void QrExplicit(XlaOp a, bool full_matrices, XlaOp& q, XlaOp& r); - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/builder/lib/qr.h" #endif // XLA_CLIENT_LIB_QR_H_ diff --git a/third_party/xla/xla/client/lib/quantize.h b/third_party/xla/xla/client/lib/quantize.h index f9835c42642d32..459716b36b54db 100644 --- a/third_party/xla/xla/client/lib/quantize.h +++ b/third_party/xla/xla/client/lib/quantize.h @@ -16,169 +16,7 @@ limitations under the License. #ifndef XLA_CLIENT_LIB_QUANTIZE_H_ #define XLA_CLIENT_LIB_QUANTIZE_H_ -#include -#include -#include -#include - -#include "xla/client/lib/constants.h" -#include "xla/client/xla_builder.h" -#include "xla/types.h" -#include "xla/util.h" -#include "xla/xla_data.pb.h" -#include "tsl/platform/bfloat16.h" - -namespace xla { - -// Represents the range used for quantization -struct QuantizedRange { - QuantizedRange() = default; - QuantizedRange(float min_in, float max_in) : min(min_in), max(max_in) {} - - bool operator==(const QuantizedRange& rhs) const { - return this->min == rhs.min && this->max == rhs.max; - } - - bool operator!=(const QuantizedRange& rhs) const { return !(*this == rhs); } - - tsl::bfloat16 min = tsl::bfloat16(0.0f); - tsl::bfloat16 max = tsl::bfloat16(0.0f); -}; - -template -inline std::vector PackToUint32(absl::Span input) { - const int64_t kElementsPerPack = sizeof(uint32_t) / sizeof(T); - const int64_t input_size = input.size(); - const int64_t output_size = CeilOfRatio(input_size, kElementsPerPack); - - std::vector output_vec; - constexpr int64_t kShiftBits = sizeof(T) / sizeof(uint8_t) * CHAR_BIT; - - for (int64_t i = 0; i < output_size; i++) { - uint32_t result = 0; - for (int64_t p = 0; p < kElementsPerPack; p++) { - int64_t index = i * kElementsPerPack + p; - if (index < input_size) { - int64_t total_shift_bits = kShiftBits * (kElementsPerPack - p - 1); - result |= (input[index] << total_shift_bits); - } - } - output_vec.push_back(result); - } - - return output_vec; -} - -// Dequantize the quantized input of packed uint32_t to bfloat16. -// Only uint8_t or uint16_t is supported for the original unpacked input. -// Returns a tensor of shape [d0,..., dn * unpack_size] if -// input shape is [d0, ..., dn], where unpack_size = sizeof(unit32) / sizeof(T). -// If transpose_output is true, will return a tensor of shape -// [dn * unpack_size, dn-1, ..., d1, d0]. transpose_output is faster when -// input's rank higher than 1. The input needs to be transposed to use -// transpose_output feature. -template -inline XlaOp Dequantize(XlaOp input, const QuantizedRange& range, - absl::string_view mode_string = "MIN_COMBINED", - bool transpose_output = false) { - XlaBuilder* const builder = input.builder(); - return builder->ReportErrorOrReturn([&]() -> absl::StatusOr { - float half_range = - !std::is_signed::value - ? 0.0f - : (static_cast(std::numeric_limits::max()) - - std::numeric_limits::min() + 1) / - 2.0f; - const int64_t unpack_size = sizeof(uint32_t) / sizeof(T); - TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(input)); - - auto element_type = shape.element_type(); - if (element_type != U32) { - return InvalidArgument( - "Only U32 is supported for input type of xla::Dequantize Op."); - } - - // Broadcast the input to [unpack_size, d0, ..., dn] if input size is - // [d0, ..., dn]. - auto broadcast_input = Broadcast(input, {unpack_size}); - - XlaOp iota_r1 = Iota(builder, U32, unpack_size); - // Highest significant bytes needs to shift more bytes than lower - // significant bytes. - XlaOp shift_bytes = - xla::ConstantR0(builder, unpack_size - 1) - iota_r1; - - const int bytes_of_type = sizeof(T) / sizeof(uint8_t); - std::vector shift_vec(unpack_size, CHAR_BIT * bytes_of_type); - XlaOp shift_bits = - shift_bytes * xla::ConstantR1(builder, shift_vec); - - // Make bit_mask for different data type T. - uint32_t bit_mask = 0x00000000; - for (int i = 0; i < bytes_of_type; i++) { - bit_mask <<= CHAR_BIT; - bit_mask |= 0x000000ff; - } - - std::vector shift_transpose_dimensions(shape.dimensions_size()); - std::iota(shift_transpose_dimensions.begin(), - shift_transpose_dimensions.end(), 0); - shift_transpose_dimensions.insert(shift_transpose_dimensions.begin(), 1, - shape.dimensions_size()); - - // Shift the input by sizeof(T) bytes and apply bit_mask to unpack. - XlaOp shifted_input = ShiftRightLogical( - broadcast_input, Transpose(Broadcast(shift_bits, shape.dimensions()), - shift_transpose_dimensions)); - XlaOp unpack_input = - And(shifted_input, xla::ConstantR0(builder, bit_mask)); - - XlaOp result; - - if (mode_string == "MIN_COMBINED") { - const tsl::bfloat16 scale_factor = - (range.max - range.min) / - (static_cast(std::numeric_limits::max() - - std::numeric_limits::min())); - // result = bfloat16(input + half_range) * scale_factor + range.min - XlaOp unpack_input_bf16 = ConvertElementType(unpack_input, BF16); - XlaOp half_range_bf16 = xla::ConstantR0( - builder, static_cast(half_range)); - XlaOp sum = unpack_input_bf16 + half_range_bf16; - - result = sum * xla::ConstantR0(builder, scale_factor) + - xla::ConstantR0(builder, range.min); - } else { - // TODO(wangtao): support other modes. - return InvalidArgument( - "Only MIN_COMBINED mode is supported in xla::Dequantize Op."); - } - - std::vector transpose_dimensions(shape.dimensions_size()); - std::iota(transpose_dimensions.begin(), transpose_dimensions.end(), 1); - std::reverse(transpose_dimensions.begin(), transpose_dimensions.end()); - transpose_dimensions.insert(transpose_dimensions.begin() + 1, 1, 0); - - // Transpose the result to be [dn, unpack_size, dn-1, ..., d1, d0]. - XlaOp transposed_result = Transpose(result, transpose_dimensions); - - // Reshape to be [dn * unpack_size, dn-1, ..., d1, d0]. - XlaOp reshaped_result = Collapse(transposed_result, {0, 1}); - - // Return the transpose result if transpose_output is true. - if (transpose_output) { - return reshaped_result; - } - - // Transpose the result to be [d0, d1, ..., dn-1, dn * unpack_size]. - std::vector result_dimensions(shape.dimensions_size()); - std::iota(result_dimensions.begin(), result_dimensions.end(), 0); - std::reverse(result_dimensions.begin(), result_dimensions.end()); - - return Transpose(reshaped_result, result_dimensions); - }); -} - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/builder/lib/quantize.h" #endif // XLA_CLIENT_LIB_QUANTIZE_H_ diff --git a/third_party/xla/xla/client/lib/self_adjoint_eig.h b/third_party/xla/xla/client/lib/self_adjoint_eig.h index f375f192e71f0e..ae81dbc0baf5a0 100644 --- a/third_party/xla/xla/client/lib/self_adjoint_eig.h +++ b/third_party/xla/xla/client/lib/self_adjoint_eig.h @@ -16,26 +16,7 @@ limitations under the License. #ifndef XLA_CLIENT_LIB_SELF_ADJOINT_EIG_H_ #define XLA_CLIENT_LIB_SELF_ADJOINT_EIG_H_ -#include "xla/client/xla_builder.h" -#include "xla/xla_data.pb.h" - -namespace xla { - -// The eigenvalue decomposition of a symmetric matrix, the original matrix is -// recovered by v * w * v_t. -struct SelfAdjointEigResult { - // The i-th column is the normalized eigenvector corresponding to the - // eigenvalue w[i]. Will return a matrix object if a is a matrix object. - XlaOp v; - // The eigenvalues in ascending order, each repeated according to its - // multiplicity. - XlaOp w; -}; - -SelfAdjointEigResult SelfAdjointEig(XlaOp a, bool lower = true, - int64_t max_iter = 15, float tol = 1e-5, - bool sort_eigenvalues = true); - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/builder/lib/self_adjoint_eig.h" #endif // XLA_CLIENT_LIB_SELF_ADJOINT_EIG_H_ diff --git a/third_party/xla/xla/client/lib/slicing.h b/third_party/xla/xla/client/lib/slicing.h index 329f299e40a896..c2ea243ae2c937 100644 --- a/third_party/xla/xla/client/lib/slicing.h +++ b/third_party/xla/xla/client/lib/slicing.h @@ -13,71 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include - -#include "absl/types/span.h" -#include "xla/client/xla_builder.h" -#include "xla/types.h" - #ifndef XLA_CLIENT_LIB_SLICING_H_ #define XLA_CLIENT_LIB_SLICING_H_ -namespace xla { - -// Updates a slice of 'x', i.e., -// x[start[0], ..., start[n]] = update -XlaOp UpdateSlice(XlaOp x, XlaOp update, absl::Span start); - -// Performs a slice in the minor dimensions of a tensor. -// x[..., start[0]:end[0], ..., start[n]:end[n]] -XlaOp SliceInMinorDims(XlaOp x, absl::Span start, - absl::Span end); - -// Updates a slice of 'x', where 'start' contains a list of minor dimensions: -// x[..., start[0]:..., ..., start[n]:...] = update -XlaOp UpdateSliceInMinorDims(XlaOp x, XlaOp update, - absl::Span start); - -// Performs a dynamic slice in the minor dimensions of a tensor. -XlaOp DynamicSliceInMinorDims(XlaOp x, absl::Span starts, - absl::Span sizes); - -XlaOp DynamicUpdateSliceInMinorDims(XlaOp x, XlaOp update, - absl::Span starts); - -// Gathers values along an axis specified by dim. -// -// For a 3-D tensor the output is specified by: -// -// out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0 -// out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1 -// out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2 -// -// If `input` is an n-dimensional tensor with size -// [X0,X1,X2,..XN] and dim = i `index` must be an n-dimensional tensor with size -// [X0,X1,...Y,Xi+1,...,X[N] where y >= 1 and `out` will have the same sizes as -// `index`. -XlaOp TorchGather(XlaOp input, XlaOp index, int64_t dim, bool sparse = true); - -// idx = index[i][j][k] -// output[idx][j][k] = combiner(input[idx][j][k], src[i][j][k]) # if dim == 0 -// output[i][idx][k] = combiner(input[i][idx][k], src[i][j][k]) # if dim == 1 -// output[i][j][idx] = combiner(input[i][j][idx], src[i][j][k]) # if dim == 2 -XlaOp TorchScatterDense(XlaOp input, XlaOp index, XlaOp src, int64_t dim, - const std::function& combiner); - -// Returns a new tensor which indexes the input tensor along dimension dim using -// the entries in index. -// -// The returned tensor has the same number of dimensions as the original tensor -// (input). The dimth dimension has the same size as the length of index; other -// dimensions have the same size as in the original tensor. -// -// This operation supports 0 or more major batch dimensions that act like a -// multidimensional loop over both the input and the index. -XlaOp TorchIndexSelect(XlaOp input, XlaOp index, int64_t dim, - int64_t batch_dims = 0); - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/builder/lib/slicing.h" #endif // XLA_CLIENT_LIB_SLICING_H_ diff --git a/third_party/xla/xla/client/lib/sorting.h b/third_party/xla/xla/client/lib/sorting.h index 4af4f8caaf977e..5cb81a43c11f36 100644 --- a/third_party/xla/xla/client/lib/sorting.h +++ b/third_party/xla/xla/client/lib/sorting.h @@ -16,23 +16,7 @@ limitations under the License. #ifndef XLA_CLIENT_LIB_SORTING_H_ #define XLA_CLIENT_LIB_SORTING_H_ -#include "xla/client/xla_builder.h" -#include "xla/types.h" -#include "xla/xla_data.pb.h" - -namespace xla { - -// Returns a tuple composed of the top `k` values and corresponding indices in -// `input`. Output values are in descending order, from largest to smallest. -XlaOp TopK(XlaOp input, int64_t k, - PrimitiveType index_type = PrimitiveType::S32); - -// Split sort in TopK into smaller sorts. -// Returns a tuple composed of the top `k` values and corresponding indices in -// `input`. Output values are in descending order, from largest to smallest. -XlaOp TopKWithPartitions(XlaOp input, int64_t k, int64_t num_partitions = 1, - PrimitiveType index_type = PrimitiveType::S32); - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/builder/lib/sorting.h" #endif // XLA_CLIENT_LIB_SORTING_H_ diff --git a/third_party/xla/xla/client/lib/svd.h b/third_party/xla/xla/client/lib/svd.h index 07f361f73b3a3f..54893697c5fced 100644 --- a/third_party/xla/xla/client/lib/svd.h +++ b/third_party/xla/xla/client/lib/svd.h @@ -16,34 +16,7 @@ limitations under the License. #ifndef XLA_CLIENT_LIB_SVD_H_ #define XLA_CLIENT_LIB_SVD_H_ -#include "xla/client/xla_builder.h" -#include "xla/xla_data.pb.h" - -namespace xla { - -// The singular value decomposition of a given matrix A[..., M, N], the original -// matrix is recovered by u * diag(d) * v_t, where the first dims(A) - 2 -// dimensions are batch dimensions. -struct SVDResult { - // The columns of U are the left-singular vectors, e.g., - // U[..., :, :]_T * U[..., :, :] = I. - XlaOp u; - // Vector(s) with the singular values, within each vector sorted in descending - // order. The first dims(D) - 1 dimensions have the same size as the batch - // dimensions of A. And U[..., :, i] * D[..., i] = A[..., :, :] * V[..., :, - // i]. - XlaOp d; - // The columns of V are the right-singular vectors. e.g., - // V[..., :, :]_T * V[..., :, :] = I. - XlaOp v; -}; - -// TODO(kuny): Add a bool flag that supports SVD with economy (reduced) -// representation, which is more memory efficient, especially in the case of -// tall-skinny matrices. -SVDResult SVD(XlaOp a, int64_t max_iter = 100, float epsilon = 1e-6, - PrecisionConfig::Precision precision = PrecisionConfig::HIGHEST); - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/builder/lib/svd.h" #endif // XLA_CLIENT_LIB_SVD_H_ diff --git a/third_party/xla/xla/client/lib/testing.cc b/third_party/xla/xla/client/lib/testing.cc index dfda52163ebf1f..61f19b8305a348 100644 --- a/third_party/xla/xla/client/lib/testing.cc +++ b/third_party/xla/xla/client/lib/testing.cc @@ -22,9 +22,9 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "xla/client/client.h" -#include "xla/client/xla_builder.h" -#include "xla/client/xla_computation.h" #include "xla/execution_options_util.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/service/service.h" diff --git a/third_party/xla/xla/client/lib/testing.h b/third_party/xla/xla/client/lib/testing.h index 76e268b36ebb97..a9b566c6635b3f 100644 --- a/third_party/xla/xla/client/lib/testing.h +++ b/third_party/xla/xla/client/lib/testing.h @@ -21,7 +21,7 @@ limitations under the License. #include "xla/client/client.h" #include "xla/client/global_data.h" -#include "xla/client/xla_computation.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/service/service.h" #include "xla/shape.h" #include "xla/xla.pb.h" diff --git a/third_party/xla/xla/client/lib/tridiagonal.h b/third_party/xla/xla/client/lib/tridiagonal.h index b24ef6a3d4b71b..5cc51c5e98262e 100644 --- a/third_party/xla/xla/client/lib/tridiagonal.h +++ b/third_party/xla/xla/client/lib/tridiagonal.h @@ -16,28 +16,7 @@ limitations under the License. #ifndef XLA_CLIENT_LIB_TRIDIAGONAL_H_ #define XLA_CLIENT_LIB_TRIDIAGONAL_H_ -#include "absl/status/statusor.h" -#include "xla/client/xla_builder.h" -#include "xla/xla_data.pb.h" - -namespace xla { -namespace tridiagonal { - -enum SolverAlgorithm { kThomas }; - -absl::StatusOr TridiagonalSolver(SolverAlgorithm algo, - XlaOp lower_diagonal, - XlaOp main_diagonal, - XlaOp upper_diagonal, XlaOp rhs); - -absl::StatusOr TridiagonalSolver(SolverAlgorithm algo, XlaOp diagonals, - XlaOp rhs); - -absl::StatusOr TridiagonalMatMul(XlaOp upper_diagonal, - XlaOp main_diagonal, - XlaOp lower_diagonal, XlaOp rhs); - -} // namespace tridiagonal -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/builder/lib/tridiagonal.h" #endif // XLA_CLIENT_LIB_TRIDIAGONAL_H_ diff --git a/third_party/xla/xla/client/lib/tuple.h b/third_party/xla/xla/client/lib/tuple.h index dd8fb3c6ec82bf..c1dc9de027a50f 100644 --- a/third_party/xla/xla/client/lib/tuple.h +++ b/third_party/xla/xla/client/lib/tuple.h @@ -16,21 +16,7 @@ limitations under the License. #ifndef XLA_CLIENT_LIB_TUPLE_H_ #define XLA_CLIENT_LIB_TUPLE_H_ -#include "absl/status/statusor.h" -#include "xla/client/xla_builder.h" -#include "xla/shape_tree.h" - -namespace xla { - -// Returns a ShapeTree where each index is a GetTupleElement instruction for -// that subshape of the tuple. The root index is the original argument. -absl::StatusOr> DisassembleTuple(XlaOp tuple); - -// Assembles a tuple from a ShapeTree that contains the leaves of the tuple. -// Non-leaf elements of the ShapeTree are ignored. DisassembleTuple and -// AssembleTuple are essentially inverse operations. -XlaOp AssembleTuple(XlaBuilder* builder, ShapeTree elements); - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/builder/lib/tuple.h" #endif // XLA_CLIENT_LIB_TUPLE_H_ diff --git a/third_party/xla/xla/client/padding.h b/third_party/xla/xla/client/padding.h index e717183ce2d6c8..a9e928d865da0e 100644 --- a/third_party/xla/xla/client/padding.h +++ b/third_party/xla/xla/client/padding.h @@ -16,52 +16,7 @@ limitations under the License. #ifndef XLA_CLIENT_PADDING_H_ #define XLA_CLIENT_PADDING_H_ -#include -#include - -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "absl/types/span.h" -#include "xla/types.h" - -namespace xla { - -// Describes the padding applied for a windowed operation like -// convolution, where a window is placed inside a base area. -enum class Padding { - // Make the output have the same dimensions as the base area. For - // example, for a 3x3 base area and a 2x2 window, the output will be - // 3x3, so that requires padding the 3x3 base area to 4x4. - kSame, - - // Use no padding. For example, for a 4x4 base area and a 2x2 - // window, the output will be 3x3. - kValid, -}; - -// Validates that the slices are acceptable for determining padding -- this can -// be used to check the preconditions of MakePadding below to produce an error -// message that can be returned to the user. -absl::Status ValidatePaddingValues(absl::Span input_dimensions, - absl::Span window_dimensions, - absl::Span window_strides); - -// Returns the padding needed for the base area, given the base area dimensions, -// window dimensions, strides, and the type of padding. -// -// If v is the returned vector, then for each dimension number i, -// v[i].first is the padding to the left (i.e. in the direction of -// lower indices) and v[i].second is the padding to the right (i.e. in -// the direction of higher indices). -// -// Precondition: The number of dimensions (i.e., rank) in input_dimensions, -// window_dimensions, and strides must match, which is equal to the number -// of elements in the result vector. -std::vector> MakePadding( - absl::Span input_dimensions, - absl::Span window_dimensions, - absl::Span window_strides, Padding padding); - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/builder/padding.h" #endif // XLA_CLIENT_PADDING_H_ diff --git a/third_party/xla/xla/client/sharding_builder.h b/third_party/xla/xla/client/sharding_builder.h index eef395e0b46368..995978b165f885 100644 --- a/third_party/xla/xla/client/sharding_builder.h +++ b/third_party/xla/xla/client/sharding_builder.h @@ -16,48 +16,7 @@ limitations under the License. #ifndef XLA_CLIENT_SHARDING_BUILDER_H_ #define XLA_CLIENT_SHARDING_BUILDER_H_ -#include - -#include "xla/array.h" -#include "xla/shape.h" -#include "xla/shape_tree.h" -#include "xla/shape_util.h" -#include "xla/types.h" -#include "xla/util.h" -#include "xla/xla_data.pb.h" - -namespace xla { -namespace sharding_builder { -// A shaped array used to describe the assignment of tiles to devices. -using TileAssignment = Array; - -// Creates a replicated sharding - replicate a tensor on every device. -OpSharding Replicate(); - -// Creates a manual sharding - the partitioner will not change the shape. -OpSharding Manual(); - -// Creates a sharding that assigns a tensor to just one device. -OpSharding AssignDevice(int device); - -// Creates a tiled sharding with the given tile shape and assignment of tiles -// to devices. -// -// If tile_shape is not evenly divisible by the number of devices in -// tile_assignment, operations behave as if implicit padding had been inserted. -// The value of this padding is undefined. -OpSharding Tile(const Shape& tile_shape, const TileAssignment& tile_assignment); - -// Creates a sharding in one dimension, with the given tile shape which must -// be rank 1 and using devices [0..num_tiles). -// -// This is simply a convenience wrapper for Tile(). -OpSharding Tile1D(const Shape& tile_shape, int64_t num_tiles); - -// Creates a tuple sharding from the given ShapeTree of element shardings. -OpSharding Tuple(const ShapeTree& shardings); - -} // namespace sharding_builder -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/builder/sharding_builder.h" #endif // XLA_CLIENT_SHARDING_BUILDER_H_ diff --git a/third_party/xla/xla/client/value_inference.h b/third_party/xla/xla/client/value_inference.h index 84c1c99f53fd4d..f717cc703b2502 100644 --- a/third_party/xla/xla/client/value_inference.h +++ b/third_party/xla/xla/client/value_inference.h @@ -15,103 +15,7 @@ limitations under the License. #ifndef XLA_CLIENT_VALUE_INFERENCE_H_ #define XLA_CLIENT_VALUE_INFERENCE_H_ -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "absl/log/check.h" -#include "absl/status/statusor.h" -#include "absl/types/span.h" -#include "xla/client/xla_builder.h" -#include "xla/hlo/evaluator/hlo_evaluator.h" -#include "xla/hlo/ir/dfs_hlo_visitor.h" -#include "xla/hlo/ir/hlo_computation.h" -#include "xla/hlo/ir/hlo_opcode.h" -#include "xla/literal.h" -#include "xla/literal_util.h" -#include "xla/shape_util.h" -#include "xla/util.h" -#include "xla/xla_data.pb.h" - -namespace xla { -// OptionalLiteral is an augmented literal class which returns optional -// values for each index (the value can be either valid or invalid). The -// implementation keeps two literals, a value literal, holding both the valid -// and garabage value, and a masking literal representing if a value is valid or -// garbage. -class OptionalLiteral { - public: - explicit OptionalLiteral(Literal value, Literal mask) - : value_(std::move(value)), mask_(std::move(mask)) {} - - template - std::optional Get(absl::Span element_index, - ShapeIndex shape_index = {}) const { - if (mask_.Get(element_index, shape_index)) { - return std::nullopt; - } else { - return value_.Get(element_index, shape_index); - } - } - - // Returns true if all values in this literal slice are value. - bool AllValid() { return mask_.IsAll(0); } - - // Get value out of this slice if all values are valid. Otherwise returns - // nullopt. - std::optional GetValue() { - if (!AllValid()) { - return std::nullopt; - } - return LiteralSlice(value_); - } - - private: - Literal value_; - Literal mask_; -}; - -enum ValueInferenceMode { - // Inference the constant value itself. - kValue = 0, - // Inference upper-bound and lower-bound of the value. Bounds are inclusive. - kUpperBound, - kLowerBound, -}; - -class ValueInference { - public: - // ValueInference analyzes values in XlaOp answers following questions: - // - What's the upper-bound of each value in a tensor. - // - What's the lower-bound of each value in a tensor. - // - What's the constant value of each tensor. - // - Whether or not each value in a tensor is dynamic. - explicit ValueInference(XlaBuilder* builder) : builder_(builder) { - CHECK(builder_); - } - absl::StatusOr AnalyzeIsDynamic(XlaOp op); - // Returns an OptionalLiteral. Each individual value of the literal is - // the concrete constant value if it can be inferred, otherwise a nullopt. - absl::StatusOr AnalyzeConstant(XlaOp op, - ValueInferenceMode mode); - - // Returns underlying xla builder. - XlaBuilder* builder() { return builder_; } - - private: - // Given an op handle, returns a simplified version of the handle inside a - // int64_t Literal. If the a -1 value for the handle means invalid - // simplification and the result shouldn't be used. - absl::StatusOr SimplifyOp(int64_t handle); - - // Perform CSE on a given handle, and return an equivalent handle if seen - // before. Otherwise, returns nullopt. - absl::StatusOr> CseOpHandle(int64_t handle); - XlaBuilder* builder_; - HloEvaluator evaluator_; - // A map from instruction_hash to handle that helps perform CSE. - absl::flat_hash_map cse_map_; -}; -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/builder/value_inference.h" #endif // XLA_CLIENT_VALUE_INFERENCE_H_ diff --git a/third_party/xla/xla/client/xla_builder.h b/third_party/xla/xla/client/xla_builder.h index dd222f1d82095b..1599160a713014 100644 --- a/third_party/xla/xla/client/xla_builder.h +++ b/third_party/xla/xla/client/xla_builder.h @@ -16,3071 +16,7 @@ limitations under the License. #ifndef XLA_CLIENT_XLA_BUILDER_H_ #define XLA_CLIENT_XLA_BUILDER_H_ -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "absl/container/flat_hash_set.h" -#include "absl/functional/function_ref.h" -#include "absl/log/check.h" -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "absl/strings/string_view.h" -#include "absl/types/span.h" -#include "xla/array.h" -#include "xla/array2d.h" -#include "xla/array3d.h" -#include "xla/array4d.h" -#include "xla/client/padding.h" -#include "xla/client/xla_computation.h" -#include "xla/comparison_util.h" -#include "xla/hlo/ir/dynamic_parameter_binding.h" -#include "xla/hlo/ir/hlo_input_output_alias_config.h" -#include "xla/hlo/ir/hlo_opcode.h" -#include "xla/layout.h" -#include "xla/literal.h" -#include "xla/literal_util.h" -#include "xla/service/hlo.pb.h" -#include "xla/shape.h" -#include "xla/shape_util.h" -#include "xla/tsl/lib/core/bitmap.h" -#include "xla/util.h" -#include "xla/xla_data.pb.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/stacktrace.h" - -namespace xla { - -class XlaBuilder; -class XlaOp; -class HloInstruction; - -namespace internal { - -struct XlaBuilderFriend { - static XlaOp BuildAddDependency(XlaBuilder* builder, XlaOp operand, - XlaOp token, const Shape& shape); - - static std::pair BuildAsyncStart( - XlaBuilder* builder, absl::Span operands, - std::string execution_thread, const XlaComputation& called_computation, - const Shape& shape); - static XlaOp BuildAsyncUpdate(XlaBuilder* builder, XlaOp operands, - const Shape& shape); - static XlaOp BuildAsyncDone(XlaBuilder* builder, XlaOp operands, - const Shape& shape); - - static XlaOp BuildAllGatherStart( - XlaBuilder* builder, XlaOp operand, int64_t all_gather_dimension, - int64_t shard_count, absl::Span replica_groups = {}, - const std::optional& channel_id = std::nullopt, - const std::optional& layout = std::nullopt, - std::optional use_global_device_ids = std::nullopt); - static XlaOp BuildAllGatherDone(XlaBuilder* builder, XlaOp operands, - const Shape& shape); - - static XlaOp BuildAllReduceStart( - XlaBuilder* builder, XlaOp operand, const XlaComputation& computation, - absl::Span replica_groups = {}, - const std::optional& channel_id = std::nullopt, - const std::optional& layout = std::nullopt, - std::optional use_global_device_ids = std::nullopt); - static XlaOp BuildAllReduceDone(XlaBuilder* builder, XlaOp operands, - const Shape& shape); - - static XlaOp BuildCollectivePermuteStart( - XlaBuilder* builder, XlaOp operand, - const std::vector>& source_target_pairs, - const std::optional& channel_id = std::nullopt); - static XlaOp BuildCollectivePermuteDone(XlaBuilder* builder, XlaOp operands, - const Shape& shape); - - static XlaOp BuildCopyStart( - XlaBuilder* builder, XlaOp operand, - std::optional cross_program_prefetch_index = std::nullopt); - static XlaOp BuildCopyDone(XlaBuilder* builder, XlaOp operand, - const Shape& shape); - - static XlaOp BuildFusion( - XlaBuilder* builder, absl::Span operands, - absl::string_view fusion_kind, const XlaComputation& fused_computation, - absl::Span>> - output_operand_aliasing = {}); - - static XlaOp BuildBitcast(XlaBuilder* builder, XlaOp operand, - const Shape& shape); - - static XlaOp BuildPartitionId(XlaBuilder* builder, const Shape& shape); - - static XlaOp BuildSend(XlaBuilder* builder, XlaOp operand, XlaOp token, - const ChannelHandle& handle, bool is_host_transfer); - static XlaOp BuildSendDone(XlaBuilder* builder, XlaOp operand, - const ChannelHandle& handle, - bool is_host_transfer); - - static XlaOp BuildRecv(XlaBuilder* builder, XlaOp token, const Shape& shape, - const ChannelHandle& handle, bool is_host_transfer); - static XlaOp BuildRecvDone(XlaBuilder* builder, XlaOp token, - const Shape& shape, const ChannelHandle& handle, - bool is_host_transfer); - - static XlaOp BuildDomain(XlaBuilder* builder, XlaOp operand, OpSharding entry, - OpSharding exit, const Shape& shape); - - static XlaOp BuildRngGetAndUpdateState(XlaBuilder* builder, int64_t delta, - const Shape& shape); - - static HloInstructionProto* GetInstruction(XlaOp op); - static HloInstructionProto* GetInstructionByHandle(XlaBuilder* builder, - int64_t handle); -}; - -} // namespace internal - -// This represents an instruction that has been enqueued using the XlaBuilder. -// This is used to pass to subsequent computations that depends upon the -// instruction as an operand. -class XlaOp { - public: - XlaOp() : handle_(-1), builder_(nullptr) { - static_assert(std::is_trivially_destructible::value, - "XlaOp should be trivially destructible"); - } - ~XlaOp() = default; - - XlaOp(const XlaOp& other) = default; - XlaOp& operator=(const XlaOp& other) = default; - - // Precondition: !IsUninitialized(). - // - // It's very common to do foo.builder()->bar(). Without this precondition, if - // foo.builder() is null, the call to bar will segfault at some point possibly - // deep in the callstack when we finally dereference `this`. The precondition - // lets us avoid this tricky-to-debug problem. - XlaBuilder* builder() const { - CHECK(builder_ != nullptr); - return builder_; - } - - // Returns true if the XlaOp represents valid, non-erroneous value. - bool valid() const { return handle_ >= 0; } - - // Returns true if the XlaOp was created by the XlaOp() constructor and - // not returned by a builder. - bool IsUninitialized() const { return builder_ == nullptr; } - - bool IsIdenticalTo(XlaOp rhs) const { - return handle_ == rhs.handle_ && builder_ == rhs.builder_; - } - - friend std::ostream& operator<<(std::ostream& out, XlaOp op) { - out << op.handle(); - return out; - } - - private: - explicit XlaOp(XlaBuilder* builder) : handle_(-1), builder_(builder) {} - XlaOp(int64_t handle, XlaBuilder* builder) - : handle_(handle), builder_(builder) {} - - int64_t handle() const { return handle_; } - - friend class XlaBuilder; - friend class ValueInference; - friend struct internal::XlaBuilderFriend; - - // < 0 means "invalid handle". - int64_t handle_; - - // Not owned. Non-null for any handle returned by XlaBuilder, even if the - // handle is invalid. - XlaBuilder* builder_; -}; - -// Arithmetic operator overloads for the XlaOp type. -XlaOp operator-(XlaOp x); -XlaOp operator+(XlaOp x, XlaOp y); -XlaOp operator-(XlaOp x, XlaOp y); -XlaOp operator*(XlaOp x, XlaOp y); -XlaOp operator/(XlaOp x, XlaOp y); -XlaOp operator%(XlaOp x, XlaOp y); - -// Bitwise operator overloads for the XlaOp type. -XlaOp operator~(XlaOp x); -XlaOp operator&(XlaOp x, XlaOp y); -XlaOp operator|(XlaOp x, XlaOp y); -XlaOp operator^(XlaOp x, XlaOp y); -XlaOp operator<<(XlaOp x, XlaOp y); -// Performs a right arithmetic shift if 'x' is a signed type, otherwise performs -// a right logical shift. -XlaOp operator>>(XlaOp x, XlaOp y); - -// We don't overload the relational operators (==, !=, <, <=, >, >=) because the -// semantics might be surprising since their result types are usually 'bool'. -// Further programmers may expect == to be a structural equality. -// We also choose not to overload any of the mutating operators (e.g., +=, -=) -// because the semantics might be misleading — XLA computations are immutable. - -// A convenient interface for building up computations. -// -// Thread-compatible. -class XlaBuilder { - public: - // computation_name: name to use for the built computation. - explicit XlaBuilder(const std::string& computation_name); - - XlaBuilder(const XlaBuilder&) = delete; - XlaBuilder& operator=(const XlaBuilder&) = delete; - - virtual ~XlaBuilder(); - - // Returns the computation name. - const std::string& name() const { return name_; } - - // Sets OpMetadata that will be added to all instructions until cleared. - // - // OpMetadata is often applied to a series of XLA HLO instructions. As a - // result, OpMetadata is set on the computation builder. All subsequent - // instructions generated via this computation builder will have the same - // OpMetadata attached until a call to ClearOpMetadata. - void SetOpMetadata(OpMetadata metadata) { metadata_ = std::move(metadata); } - - // Swaps the passed op metadata with the ones currently set. - // - // Returns the old op metadata. - OpMetadata SwapOpMetadata(OpMetadata metadata) { - OpMetadata old_metadata = std::move(metadata_); - metadata_ = std::move(metadata); - return old_metadata; - } - - // Similar to SetOpMetadata, but only set the metadata for the next op. - void SetOneShotOpMetadata(OpMetadata metadata) { - one_shot_metadata_ = std::move(metadata); - } - - // Clears the HloMetadata state. - void ClearOpMetadata() { metadata_.Clear(); } - - // Sets an OpSharding that will be attached to all instructions until cleared. - void SetSharding(const OpSharding& sharding) { sharding_ = sharding; } - - // Sets the FrontendAttributes that will be added to all instructions until - // cleared. - // - // FrontendAttributes are often applied to a series of XLA HLO instructions. - // As a result they are set on the computation builder and all the - // instructions generated via the computation builder will have the same - // frontend attributes attached to them. - virtual void SetFrontendAttributes( - const FrontendAttributes& frontend_attributes) { - frontend_attributes_ = frontend_attributes; - } - - // Swap the passed FrontendAttributes with the ones currently set. - // - // Return the old attributes. - FrontendAttributes SwapFrontendAttributes( - const FrontendAttributes& frontend_attributes) { - FrontendAttributes old_attributes = std::move(frontend_attributes_); - frontend_attributes_ = frontend_attributes; - return old_attributes; - } - - // Returns the FrontendAttributes that will be attached to all instructions. - const FrontendAttributes& frontend_attributes() const { - return frontend_attributes_; - } - - // Clears all the frontend attributes. - void ClearFrontendAttributes() { frontend_attributes_.Clear(); } - - // Clears the sharding. Ops will be sharded according to the default placement - // policy. - void ClearSharding() { sharding_ = std::nullopt; } - - // Returns the OpSharding that will be attached to all instructions. - const std::optional& sharding() const { return sharding_; } - - // Sets the builder to a mode where it will die immediately when an error is - // encountered, rather than producing it in a deferred fashion when Build() is - // called (which is the default). - void set_die_immediately_on_error(bool enabled) { - die_immediately_on_error_ = enabled; - } - - // Default dimension numbers used for a 2D convolution. - static constexpr int64_t kConvBatchDimension = 0; - static constexpr int64_t kConvFeatureDimension = 1; - static constexpr int64_t kConvFirstSpatialDimension = 2; - static constexpr int64_t kConvSecondSpatialDimension = 3; - static constexpr int64_t kConvKernelOutputDimension = 0; - static constexpr int64_t kConvKernelInputDimension = 1; - static constexpr int64_t kConvKernelFirstSpatialDimension = 2; - static constexpr int64_t kConvKernelSecondSpatialDimension = 3; - - // Creates a default ConvolutionDimensionNumbers. For a 2D convolution, for - // the input operand {batch, feature, height, width} = {0, 1, 2, 3} and for - // the kernel operand - // {output_feature, input_feature, height, width} = {0, 1, 2, 3}. - static ConvolutionDimensionNumbers CreateDefaultConvDimensionNumbers( - int num_spatial_dims = 2); - - // Returns an error if the convolution dimension numbers have conflicts. - static absl::Status Validate(const ConvolutionDimensionNumbers& dnum); - - // Returns a new XlaBuilder whose resultant Computation is used only by this - // XlaBuilder. The sub-XlaBuilder has the same die_immediately_on_error - // behavior as the parent. - std::unique_ptr CreateSubBuilder( - const std::string& computation_name); - - // Builds the computation with the requested operations, or returns a non-ok - // status. Note that all ops that have been enqueued will be moved to the - // computation being returned. The root of the computation will be the last - // added operation. - // - // `remove_dynamic_dimensions` tells the builder whether to remove the - // dynamic dimensions information in all ops. - // - // TODO(b/121223198): Delete `remove_dynamic_dimensions` and keeps the - // dynamic dimensions information when XLA backend can handle dynamic - // dimensions. - absl::StatusOr Build(bool remove_dynamic_dimensions = false); - - // Overload of Build which specifies a particular root instruction for the - // computation. - absl::StatusOr Build(XlaOp root, - bool remove_dynamic_dimensions = false); - - // Builds the computation with the requested operations, or notes an error in - // the parent XlaBuilder and returns an empty computation if building failed. - // This function is intended to be used where the returned XlaComputation is - // only used by the parent XlaBuilder and hence further operation on the - // returned XlaComputation will simply be error'ed out if an error occurred - // while building this computation. If the built computation is to be used by - // a XlaBuilder other than the parent XlaBuilder then Build() should be used - // instead. - XlaComputation BuildAndNoteError(); - - // Returns a subgraph that roots on the given root. If the root is not a - // compile-time constant (see `IsConstant`), returns an error. - // - // This will copy the needed ops/computations to the subgraph. - absl::StatusOr BuildConstantSubGraph( - XlaOp root_op, bool dynamic_dimension_is_minus_one = false); - - // Returns the first error that was encountered while building the - // computation. When an error is encountered, by default we return a vacuous - // XlaOp and inform the user of the error that occurred while - // building the computation when they make a final call to Build(). - // - // See also set_die_immediately_on_error(). - absl::Status first_error() const { return first_error_; } - - // Returns the current status of the builder, complete with the stack trace - // information. - absl::Status GetCurrentStatus() const; - - // Returns the shape of the given op. - absl::StatusOr GetShape(XlaOp op) const; - - // Returns the shape of the given op. - virtual absl::StatusOr GetShapePtr(XlaOp op) const; - - // Returns the OpSharding of the given op. If "op" has no sharding, return - // std::nullopt. - absl::StatusOr> GetOpSharding(XlaOp op) const; - - // Returns the (inferred) result for the current computation's shape. This - // assumes the root instruction is the last added instruction. - absl::StatusOr GetProgramShape() const; - - // Returns the (inferred) result for the current computation's shape using the - // given operation as the root. - absl::StatusOr GetProgramShape(XlaOp root) const; - - // Reports an error to the builder, by - // * storing it internally and capturing a backtrace if it's the first error - // (this deferred value will be produced on the call to - // Build()/GetShape()/...) - // * dying if die_immediately_on_error_ is true. - // Returns an XlaOp with an invalid handle but a valid builder. This value can - // be returned in place of a value in APIs that return an XlaOp. - XlaOp ReportError(const absl::Status& error); - - // A helper function that converts a absl::StatusOr into an XlaOp. - // If the absl::Status was an error, reports the error to builder and returns - // an invalid XlaOp handle. - XlaOp ReportErrorOrReturn(const absl::StatusOr& op); - - // A helper function that runs a function that returns a absl::StatusOr - // and returns an XlaOp. - XlaOp ReportErrorOrReturn( - absl::FunctionRef()> op_creator); - - // Returns true if 'operand' is a compile-time constant. A compile-time - // constant does not depend on any parameters, or on stateful operators such - // as `RngNormal` or `Infeed`. - // - // This tests whether a computation is a compile-time constant without - // evaluating the computation. - absl::StatusOr IsConstant(XlaOp operand) const; - - // Adds a new input/output alias. Since the input/output shape information are - // not available until the computation is built, any eventual error in the - // arguments of this API will be detected only at computation Build() time. - // - // Note: Except when 'must-alias' is true, alias is assumed to be 'may-alias' - // and only donated buffer at runtime will be aliased with output. If a buffer - // is not donated at runtime, a copy will be inserted by XLA to prevent buffer - // clobbering. - void SetUpAlias(const ShapeIndex& output_index, int64_t param_number, - const ShapeIndex& param_index, - HloInputOutputAliasConfig::AliasKind kind = - HloInputOutputAliasConfig::AliasKind::kMayAlias) { - input_output_aliases_.push_back( - {output_index, param_number, param_index, kind}); - } - - // Describes an input/output alias as inserted by the SetUpAlias() API. - struct InputOutputAlias { - // Specifies the index of the aliased buffer in the result tuple. - ShapeIndex output_index; - // Specifies the parameter containing the buffer to be aliased. - int64_t param_number; - // Specifies the index of the aliased buffer in the parameter. - ShapeIndex param_index; - // Specifies if the alias is a must alias or may alias. - HloInputOutputAliasConfig::AliasKind kind; - }; - - // Adds a new buffer donor. The donated buffer may be paired with any valid - // output. On the contrary, the buffer aliasing bonds the input output pair. - // The input can only donate the buffer to the paired output. - void AddBufferDonor(int64_t param_number, const ShapeIndex& param_index) { - buffer_donors_.insert({param_number, param_index}); - } - - // Looks up the HloInstruction and sets the frontend attribute "attribute" to - // "value". If the attribute already existed, then its value is updated. - // - // The attribute is only added to the HloInstruction, not to the builder. - absl::Status SetInstructionFrontendAttribute(XlaOp op, std::string attribute, - std::string value); - - // Looks up the HloInstruction and sets the sharding. If the sharding already - // existed, then its value is updated. - // - // The sharding is only added to the HloInstruction, not to the builder. - absl::Status SetInstructionSharding( - XlaOp op, const std::optional& sharding); - - // Returns shapes for the operands. - absl::StatusOr> GetOperandShapes( - absl::Span operands) const; - - // Converts the op to string for the ease of debugging. - std::string OpToString(XlaOp op) const; - - private: - void ToStringHelper(std::string* out, int ident, int64_t op_handle) const; - - // Build helper which takes the id of the root operation.. - absl::StatusOr Build(int64_t root_id, - bool remove_dynamic_dimensions); - - // Description for the methods below can be found in the corresponding public - // functions section in this file. - - XlaOp Parameter(int64_t parameter_number, const Shape& shape, - const std::string& name, - const std::vector& replicated_at_leaf_buffers); - XlaOp Parameter(int64_t parameter_number, const Shape& shape, - const std::string& name) { - std::vector empty_bools; - return Parameter(parameter_number, shape, name, empty_bools); - } - - virtual XlaOp ConstantLiteral(const LiteralSlice& literal); - - XlaOp Broadcast(XlaOp operand, absl::Span broadcast_sizes); - - XlaOp BroadcastInDim(XlaOp operand, absl::Span out_dim_size, - absl::Span broadcast_dimensions); - - // This is an experimental API for creating the mhlo.dynamic_broadcast_in_dim - // op from the XlaBuilder. This is only intended for export to MHLO or - // StableHLO, and cannot be compiled. Only static output_dimensions are - // allowed, and broadcast_dimensions is verified. - XlaOp MhloDynamicBroadcastInDim( - XlaOp operand, XlaOp output_dimensions, - absl::Span broadcast_dimensions, - const Shape& output_shape); - - XlaOp Pad(XlaOp operand, XlaOp padding_value, - const PaddingConfig& padding_config); - XlaOp PadInDim(XlaOp operand, XlaOp padding_value, int64_t dimno, - int64_t pad_lo, int64_t pad_hi); - - virtual absl::StatusOr PadInternal( - const Shape& shape, XlaOp operand, XlaOp padding_value, - const PaddingConfig& padding_config); - - XlaOp Reshape(XlaOp operand, absl::Span dimensions, - absl::Span new_sizes, - int64_t inferred_dimension = -1); - - XlaOp Reshape(XlaOp operand, absl::Span new_sizes, - int64_t inferred_dimension = -1); - - XlaOp Reshape(const Shape& shape, XlaOp operand, - int64_t inferred_dimension = -1); - - XlaOp DynamicReshape(XlaOp operand, absl::Span dim_sizes, - absl::Span new_size_bounds, - const std::vector& dims_are_dynamic); - - XlaOp MhloDynamicReshape(XlaOp operand, XlaOp output_shape, - const Shape& shape); - - XlaOp Collapse(XlaOp operand, absl::Span dimensions); - - XlaOp Slice(XlaOp operand, absl::Span start_indices, - absl::Span limit_indices, - absl::Span strides); - virtual absl::StatusOr SliceInternal( - const Shape& shape, XlaOp operand, - absl::Span start_indices, - absl::Span limit_indices, - absl::Span strides); - virtual XlaOp SliceInDim(XlaOp operand, int64_t start_index, - int64_t limit_index, int64_t stride, int64_t dimno); - - XlaOp DynamicSlice(XlaOp operand, absl::Span start_indices, - absl::Span slice_sizes); - virtual absl::StatusOr DynamicSliceInternal( - const Shape& shape, XlaOp operand, absl::Span start_indices, - absl::Span slice_sizes); - - XlaOp DynamicUpdateSlice(XlaOp operand, XlaOp update, - absl::Span start_indices); - virtual absl::StatusOr DynamicUpdateSliceInternal( - const Shape& shape, XlaOp operand, XlaOp update, - absl::Span start_indices); - - XlaOp ConcatInDim(absl::Span operands, int64_t dimension); - virtual absl::StatusOr ConcatInDimInternal( - const Shape& shape, absl::Span operands, int64_t dimension); - - XlaOp Select(XlaOp pred, XlaOp on_true, XlaOp on_false); - - XlaOp Tuple(absl::Span elements); - virtual absl::StatusOr TupleInternal(const Shape& shape, - absl::Span elements); - - XlaOp GetTupleElement(XlaOp tuple_data, int64_t index); - virtual absl::StatusOr GetTupleElementInternal(const Shape& shape, - XlaOp tuple_data, - int64_t index); - - XlaOp Dot(XlaOp lhs, XlaOp rhs, - const PrecisionConfig* precision_config = nullptr, - std::optional preferred_element_type = std::nullopt); - - XlaOp DotGeneral( - XlaOp lhs, XlaOp rhs, const DotDimensionNumbers& dimension_numbers, - const PrecisionConfig* precision_config = nullptr, - std::optional preferred_element_type = std::nullopt); - - XlaOp SparseDot( - XlaOp lhs, XlaOp rhs, absl::Span sparse_meta, - absl::Span sparsity, - const DotDimensionNumbers& dimension_numbers, - const PrecisionConfig* precision_config = nullptr, - std::optional preferred_element_type = std::nullopt); - - XlaOp Conv( - XlaOp lhs, XlaOp rhs, absl::Span window_strides, - Padding padding, int64_t feature_group_count = 1, - int64_t batch_group_count = 1, - const PrecisionConfig* precision_config = nullptr, - std::optional preferred_element_type = std::nullopt); - - XlaOp ConvWithGeneralPadding( - XlaOp lhs, XlaOp rhs, absl::Span window_strides, - absl::Span> padding, - int64_t feature_group_count = 1, int64_t batch_group_count = 1, - const PrecisionConfig* precision_config = nullptr, - std::optional preferred_element_type = std::nullopt); - - XlaOp ConvWithGeneralDimensions( - XlaOp lhs, XlaOp rhs, absl::Span window_strides, - Padding padding, const ConvolutionDimensionNumbers& dimension_numbers, - int64_t feature_group_count = 1, int64_t batch_group_count = 1, - const PrecisionConfig* precision_config = nullptr, - std::optional preferred_element_type = std::nullopt); - - XlaOp ConvGeneral( - XlaOp lhs, XlaOp rhs, absl::Span window_strides, - absl::Span> padding, - const ConvolutionDimensionNumbers& dimension_numbers, - int64_t feature_group_count = 1, int64_t batch_group_count = 1, - const PrecisionConfig* precision_config = nullptr, - std::optional preferred_element_type = std::nullopt); - - XlaOp ConvGeneralDilated( - XlaOp lhs, XlaOp rhs, absl::Span window_strides, - absl::Span> padding, - absl::Span lhs_dilation, - absl::Span rhs_dilation, - const ConvolutionDimensionNumbers& dimension_numbers, - int64_t feature_group_count = 1, int64_t batch_group_count = 1, - const PrecisionConfig* precision_config = nullptr, - std::optional preferred_element_type = std::nullopt, - std::optional> window_reversal = std::nullopt); - - XlaOp DynamicConvForward( - XlaOp lhs, XlaOp rhs, absl::Span window_strides, - absl::Span> padding, - absl::Span lhs_dilation, - absl::Span rhs_dilation, - const ConvolutionDimensionNumbers& dimension_numbers, - int64_t feature_group_count, int64_t batch_group_count, - const PrecisionConfig* precision_config, PaddingType padding_type, - std::optional preferred_element_type = std::nullopt); - - XlaOp DynamicConvInputGrad( - XlaOp input_sizes, XlaOp lhs, XlaOp rhs, - absl::Span window_strides, - absl::Span> padding, - absl::Span lhs_dilation, - absl::Span rhs_dilation, - const ConvolutionDimensionNumbers& dimension_numbers, - int64_t feature_group_count, int64_t batch_group_count, - const PrecisionConfig* precision_config, PaddingType padding_type, - std::optional preferred_element_type = std::nullopt); - - XlaOp DynamicConvKernelGrad( - XlaOp activations, XlaOp gradients, - absl::Span window_strides, - absl::Span> padding, - absl::Span lhs_dilation, - absl::Span rhs_dilation, - const ConvolutionDimensionNumbers& dimension_numbers, - int64_t feature_group_count, int64_t batch_group_count, - const PrecisionConfig* precision_config, PaddingType padding_type, - std::optional preferred_element_type = std::nullopt); - - absl::StatusOr DynamicConvInstruction( - XlaOp lhs, XlaOp rhs, absl::Span window_strides, - absl::Span> padding, - absl::Span lhs_dilation, - absl::Span rhs_dilation, - const ConvolutionDimensionNumbers& dimension_numbers, - int64_t feature_group_count, int64_t batch_group_count, - const PrecisionConfig* precision_config, PaddingType padding_type, - std::optional preferred_element_type = std::nullopt); - - virtual absl::StatusOr ConvGeneralDilatedInternal( - const Shape& shape, XlaOp lhs, XlaOp rhs, const Window& window, - absl::Span window_strides, - absl::Span> padding, - absl::Span lhs_dilation, - absl::Span rhs_dilation, - const ConvolutionDimensionNumbers& dimension_numbers, - int64_t feature_group_count, int64_t batch_group_count, - const PrecisionConfig* precision_config); - - XlaOp Fft(XlaOp operand, FftType fft_type, - absl::Span fft_length); - virtual absl::StatusOr FftInternal( - const Shape& shape, XlaOp operand, FftType fft_type, - absl::Span fft_length); - - virtual absl::StatusOr TriangularSolveInternal( - const Shape& shape, XlaOp a, XlaOp b, TriangularSolveOptions options); - - virtual absl::StatusOr CholeskyInternal(const Shape& shape, XlaOp a, - bool lower); - - XlaOp Infeed(const Shape& shape, const std::string& config = ""); - XlaOp InfeedWithToken(XlaOp token, const Shape& shape, - const std::string& config); - virtual absl::StatusOr InfeedWithTokenInternal( - const Shape& infeed_instruction_shape, XlaOp token, - const std::string& config); - - void Outfeed(XlaOp operand, const Shape& shape_with_layout, - const std::string& outfeed_config); - XlaOp OutfeedWithToken(XlaOp operand, XlaOp token, - const Shape& shape_with_layout, - const std::string& outfeed_config); - virtual absl::StatusOr OutfeedWithTokenInternal( - XlaOp operand, XlaOp token, const Shape& shape_with_layout, - const std::string& outfeed_config); - XlaOp Call(const XlaComputation& computation, - absl::Span operands); - - XlaOp CompositeCall( - const XlaComputation& computation, absl::Span operands, - const std::string& name, - std::optional attributes = std::nullopt, - std::optional version = std::nullopt); - - XlaOp CustomCall( - const std::string& call_target_name, absl::Span operands, - const Shape& shape_with_layout, const std::string& opaque, - std::optional> operand_shapes_with_layout, - bool has_side_effect, - absl::Span>> - output_operand_aliasing, - const Literal* literal, std::optional window, - std::optional dnums, - CustomCallSchedule schedule, CustomCallApiVersion api_version); - - // Internal version of CustomCall without computation that doesn't do op - // specific error handling and expects arguments to be legal. CustomCall - // method above calls this method after error handling. - virtual absl::StatusOr CustomCallInternal( - const std::string& call_target_name, absl::Span operands, - const XlaComputation* computation, const Shape& shape_with_layout, - const std::string& opaque, - std::optional> operand_shapes_with_layout, - bool has_side_effect, - absl::Span>> - output_operand_aliasing, - const Literal* literal, std::optional window, - std::optional dnums, - CustomCallSchedule schedule, CustomCallApiVersion api_version); - - // TODO(b/239474321) Remove this overload as it has simply led to code - // duplication. - XlaOp CustomCall( - const std::string& call_target_name, absl::Span operands, - const XlaComputation& computation, const Shape& shape_with_layout, - const std::string& opaque, - std::optional> operand_shapes_with_layout, - bool has_side_effect, - absl::Span>> - output_operand_aliasing, - const Literal* literal, CustomCallSchedule schedule, - CustomCallApiVersion api_version); - - XlaOp OptimizationBarrier(XlaOp operand); - - XlaOp Reduce(XlaOp operand, XlaOp init_value, - const XlaComputation& computation, - absl::Span dimensions_to_reduce); - - XlaOp Reduce(absl::Span operands, - absl::Span init_values, - const XlaComputation& computation, - absl::Span dimensions_to_reduce); - - virtual absl::StatusOr ReduceInternal( - const Shape& shape, absl::Span all_operands, - const XlaComputation& computation, - absl::Span dimensions_to_reduce); - - XlaOp ReduceAll(XlaOp operand, XlaOp init_value, - const XlaComputation& computation); - - XlaOp ReduceWindow(XlaOp operand, XlaOp init_value, - const XlaComputation& computation, - absl::Span window_dimensions, - absl::Span window_strides, Padding padding); - - XlaOp ReduceWindow(absl::Span operands, - absl::Span init_values, - const XlaComputation& computation, - absl::Span window_dimensions, - absl::Span window_strides, Padding padding); - - XlaOp ReduceWindowWithGeneralPadding( - absl::Span operands, absl::Span init_values, - const XlaComputation& computation, - absl::Span window_dimensions, - absl::Span window_strides, - absl::Span base_dilations, - absl::Span window_dilations, - absl::Span> padding); - absl::StatusOr ReduceWindowInternal( - absl::Span operands, absl::Span init_values, - const XlaComputation& computation, - absl::Span window_dimensions, - absl::Span window_strides, - absl::Span base_dilations, - absl::Span window_dilations, - absl::Span> padding); - virtual absl::StatusOr ReduceWindowInternal( - const Shape& shape, XlaOp operand, XlaOp init_value, - const XlaComputation& computation, Window window); - XlaOp CrossReplicaSum(XlaOp operand, - absl::Span replica_groups = {}); - - XlaOp AllGather(XlaOp operand, int64_t all_gather_dimension, - int64_t shard_count, - absl::Span replica_groups = {}, - const std::optional& channel_id = std::nullopt, - const std::optional& layout = std::nullopt, - std::optional use_global_device_ids = std::nullopt); - - XlaOp AllReduce(XlaOp operand, const XlaComputation& computation, - absl::Span replica_groups = {}, - const std::optional& channel_id = std::nullopt, - const std::optional& shape_with_layout = std::nullopt, - std::optional use_global_device_ids = std::nullopt); - - XlaOp ReduceScatter( - XlaOp operand, const XlaComputation& computation, - int64_t scatter_dimension, int64_t shard_count, - absl::Span replica_groups = {}, - const std::optional& channel_id = std::nullopt, - const std::optional& layout = std::nullopt, - std::optional use_global_device_ids = std::nullopt); - - XlaOp AllToAll(XlaOp operand, int64_t split_dimension, - int64_t concat_dimension, int64_t split_count, - absl::Span replica_groups, - const std::optional& layout = std::nullopt, - const std::optional& channel_id = std::nullopt); - - XlaOp AllToAllTuple( - absl::Span operands, - absl::Span replica_groups, - const std::optional& layout, - const std::optional& channel_id = std::nullopt); - - XlaOp AllToAllTuple( - XlaOp operand, int64_t split_dimension, int64_t concat_dimension, - int64_t split_count, absl::Span replica_groups, - const std::optional& layout, - const std::optional& channel_id = std::nullopt); - - XlaOp CollectiveBroadcast( - XlaOp operand, absl::Span replica_groups, - const std::optional& channel_id = std::nullopt); - - XlaOp CollectivePermute( - XlaOp operand, - const std::vector>& source_target_pairs, - const std::optional& channel_id = std::nullopt); - - XlaOp ReplicaId(); - - XlaOp SelectAndScatter(XlaOp operand, const XlaComputation& select, - absl::Span window_dimensions, - absl::Span window_strides, - Padding padding, XlaOp source, XlaOp init_value, - const XlaComputation& scatter); - - XlaOp SelectAndScatterWithGeneralPadding( - XlaOp operand, const XlaComputation& select, - absl::Span window_dimensions, - absl::Span window_strides, - absl::Span> padding, XlaOp source, - XlaOp init_value, const XlaComputation& scatter); - - absl::StatusOr SelectAndScatterInternal( - XlaOp operand, const XlaComputation& select, - absl::Span window_dimensions, - absl::Span window_strides, - absl::Span> padding, XlaOp source, - XlaOp init_value, const XlaComputation& scatter); - - virtual XlaOp Iota(const Shape& shape, int64_t iota_dimension); - - XlaOp Iota(PrimitiveType type, int64_t size); - - XlaOp ConvertElementType(XlaOp operand, PrimitiveType new_element_type); - - XlaOp BitcastConvertType(XlaOp operand, PrimitiveType new_element_type); - virtual absl::StatusOr BitcastConvertTypeInternal(const Shape& shape, - XlaOp operand); - - XlaOp StochasticConvertType(XlaOp operand, XlaOp random, - PrimitiveType new_element_type); - - XlaOp Transpose(XlaOp operand, absl::Span permutation); - virtual absl::StatusOr TransposeInternal( - const Shape& shape, XlaOp operand, absl::Span permutation); - - XlaOp Rev(XlaOp operand, absl::Span dimensions); - virtual absl::StatusOr RevInternal( - const Shape& shape, XlaOp operand, absl::Span dimensions); - - XlaOp Sort(absl::Span operands, const XlaComputation& comparator, - int64_t dimension = -1, bool is_stable = false); - virtual absl::StatusOr SortInternal(const Shape& shape, - absl::Span operands, - const XlaComputation& comparator, - int64_t dimension, bool is_stable); - - XlaOp TopK(XlaOp operand, int64_t k, bool largest); - virtual absl::StatusOr TopKInternal(const Shape& shape, XlaOp operand, - int64_t k, bool largest); - - XlaOp Clamp(XlaOp min, XlaOp operand, XlaOp max); - - XlaOp Map(absl::Span operands, const XlaComputation& computation, - absl::Span dimensions, - absl::Span static_operands = {}); - - XlaOp RngNormal(XlaOp mu, XlaOp sigma, const Shape& shape); - - XlaOp RngUniform(XlaOp a, XlaOp b, const Shape& shape); - - XlaOp RngBitGenerator(RandomAlgorithm algorithm, XlaOp initial_state, - const Shape& shape); - // Internal variant for the op with the full result shape containing both data - // and state shape as a tuple. - virtual absl::StatusOr RngBitGeneratorInternal( - const Shape& full_result_shape, RandomAlgorithm algorithm, - XlaOp initial_state); - - XlaOp While(const XlaComputation& condition, const XlaComputation& body, - XlaOp init); - virtual absl::StatusOr WhileInternal(const Shape& shape, - const XlaComputation& condition, - const XlaComputation& body, - XlaOp init); - - XlaOp Conditional(XlaOp predicate, XlaOp true_operand, - const XlaComputation& true_computation, XlaOp false_operand, - const XlaComputation& false_computation); - - XlaOp Conditional(XlaOp branch_index, - absl::Span branch_computations, - absl::Span branch_operands); - - XlaOp ReducePrecision(XlaOp operand, int exponent_bits, int mantissa_bits); - virtual absl::StatusOr ReducePrecisionInternal(const Shape& shape, - XlaOp operand, - int exponent_bits, - int mantissa_bits); - - XlaOp Gather(XlaOp input, XlaOp start_indices, - const GatherDimensionNumbers& dimension_numbers, - absl::Span slice_sizes, - bool indices_are_sorted = false); - - virtual absl::StatusOr GatherInternal( - const Shape& shape, XlaOp input, XlaOp start_indices, - const GatherDimensionNumbers& dimension_numbers, - absl::Span slice_sizes, bool indices_are_sorted); - - XlaOp Scatter(XlaOp input, XlaOp scatter_indices, XlaOp updates, - const XlaComputation& update_computation, - const ScatterDimensionNumbers& dimension_numbers, - bool indices_are_sorted = false, bool unique_indices = false); - XlaOp Scatter(absl::Span inputs, XlaOp scatter_indices, - absl::Span updates, - const XlaComputation& update_computation, - const ScatterDimensionNumbers& dimension_numbers, - bool indices_are_sorted = false, bool unique_indices = false); - - virtual absl::StatusOr ScatterInternal( - const Shape& shape, absl::Span inputs, XlaOp scatter_indices, - absl::Span updates, const XlaComputation& update_computation, - const ScatterDimensionNumbers& dimension_numbers, bool indices_are_sorted, - bool unique_indices); - - void Send(XlaOp operand, const ChannelHandle& handle); - XlaOp SendWithToken(XlaOp operand, XlaOp token, const ChannelHandle& handle); - - XlaOp SendToHost(XlaOp operand, XlaOp token, const Shape& shape_with_layout, - const ChannelHandle& handle); - - XlaOp RecvFromHost(XlaOp token, const Shape& shape, - const ChannelHandle& handle); - - virtual XlaOp CreateToken(); - - XlaOp AfterAll(absl::Span tokens); - - XlaOp Recv(const Shape& shape, const ChannelHandle& handle); - XlaOp RecvWithToken(XlaOp token, const Shape& shape, - const ChannelHandle& handle); - - XlaOp BatchNormTraining(XlaOp operand, XlaOp scale, XlaOp offset, - float epsilon, int64_t feature_index); - - XlaOp BatchNormInference(XlaOp operand, XlaOp scale, XlaOp offset, XlaOp mean, - XlaOp variance, float epsilon, - int64_t feature_index); - - XlaOp BatchNormGrad(XlaOp operand, XlaOp scale, XlaOp batch_mean, - XlaOp batch_var, XlaOp grad_output, float epsilon, - int64_t feature_index); - - XlaOp GetDimensionSize(XlaOp operand, int64_t dimension); - - XlaOp SetDimensionSize(XlaOp operand, XlaOp val, int64_t dimension); - - virtual absl::StatusOr SetDimensionSizeInternal(const Shape& shape, - XlaOp operand, - XlaOp val, - int64_t dimension); - - XlaOp RemoveDynamicDimension(XlaOp operand, int64_t dimension); - - virtual absl::StatusOr AddInstruction( - HloInstructionProto&& instr, HloOpcode opcode, - absl::Span operands); - absl::StatusOr AddInstruction(HloInstructionProto&& instr, - HloOpcode opcode) { - return AddInstruction(std::move(instr), opcode, /*operands=*/{}); - } - - void AddCalledComputation(const XlaComputation& computation, - HloInstructionProto* instr); - - absl::StatusOr LookUpInstruction(XlaOp op) const; - absl::StatusOr LookUpInstructionByHandle( - int64_t handle) const; - absl::StatusOr LookUpMutableInstruction(XlaOp op); - absl::StatusOr LookUpMutableInstructionByHandle( - int64_t handle); - - // Internal helper method that does the building for an arbitrary unary op. - virtual XlaOp UnaryOp(HloOpcode unop, XlaOp operand); - - // Internal helper method that does the building for an arbitrary binary op. - // broadcast_dimensions specifies which dimensions to use for broadcasting - // when the operation is between tensors of different ranks. The direction is - // only used if opcode is kCompare. - XlaOp BinaryOp(HloOpcode binop, XlaOp lhs, XlaOp rhs, - absl::Span broadcast_dimensions, - std::optional direction = std::nullopt, - std::optional type = std::nullopt); - - absl::StatusOr Compare(const Shape& shape, XlaOp lhs, XlaOp rhs, - ComparisonDirection direction); - - // Internal helper method for binary op compare without broadcast dimensions. - virtual absl::StatusOr Compare(const Shape& shape, XlaOp lhs, - XlaOp rhs, - ComparisonDirection direction, - Comparison::Type type); - - // Internal helper method that does the building for an arbitrary binary op - // with same ranked operands that doesn't broadcast. - virtual XlaOp BinaryOpNoBroadcast(HloOpcode binop, const Shape& shape, - XlaOp lhs, XlaOp rhs); - - // Internal helper method that does the building for an arbitrary ternary op. - XlaOp TernaryOp(HloOpcode triop, XlaOp lhs, XlaOp rhs, XlaOp ehs); - - XlaOp RngOp(RandomDistribution distribution, - absl::Span parameters, const Shape& shape); - - virtual absl::StatusOr RngOpInternal( - RandomDistribution distribution, absl::Span parameters, - const Shape& shape); - - virtual absl::StatusOr InDimBroadcast( - const Shape& shape, XlaOp operand, - absl::Span broadcast_dimensions); - - // Internal helper method that creates a sequence of instructions that - // performs an explicit broadcast of the operand to the target shape. - // All dimensions of the operand must either be equal to the corresponding - // output shape dimension, or be exactly 1. (Such dimensions are the - // degenerate dimensions.) - absl::StatusOr AddBroadcastSequence(const Shape& output_shape, - XlaOp operand); - - // Internal helper method that broadcasts a scalar to the shape of the output. - absl::StatusOr BroadcastScalarToOutputShape(XlaOp scalar, - XlaOp output); - - // Internal helper method for creating a Reshape op with the already inferred - // shape. - virtual absl::StatusOr ReshapeInternal(const Shape& shape, - XlaOp operand, - int64_t inferred_dimension); - - // Returns the (inferred) result for the program shape using the given root. - absl::StatusOr GetProgramShape(int64_t root_id) const; - - // A visitor which checks whether an operation is a compile-time constant, - // meaning that it doesn't depend on any parameters, or on any stateful - // operation such as `RngNormal` or `Infeed`. The visitor walks the - // computation starting at a given operation and sets is_constant to false iff - // a parameter or stateful operation is encountered. - void IsConstantVisitor(int64_t op_handle, int depth, - absl::flat_hash_set* visited, - bool* is_constant) const; - - // Checks bounds for convolution parameters. - absl::Status VerifyConvolution( - const Shape& lhs_shape, const Shape& rhs_shape, - const ConvolutionDimensionNumbers& dimension_numbers) const; - - int64_t GetNextId() { return ++next_id_; } - - // Populates the module with the input/output alias information stored within - // the input_output_aliases vector. - static absl::Status PopulateInputOutputAliasAndBufferDonor( - HloModuleProto* module, const ProgramShape& program_shape, - const std::vector& input_output_aliases, - const absl::flat_hash_set& - buffer_donors); - - std::string name_; // Name to use for the built computation. - - // The next sequential ID for every instruction/computation contained within - // this computation. - int64_t next_id_ = 0; - - // The first error encountered while building the computation. - // This is OK until the first error is encountered. - absl::Status first_error_; - - // The saved stack trace from the point at which the first error occurred. - tsl::SavedStackTrace first_error_backtrace_; - - // The instructions of this computation. - // Use a deque so pointers into this are stable, for example the return - // value of LookUpInstructionByHandle(). - std::deque instructions_; - // A cache for the HloInstructionProto shapes, to avoid recreating Shape - // objects from protos and to support the GetShapePtr() API. - std::vector> instruction_shapes_; - - // Dynamic parameter configuration of this computation. - DynamicParameterBinding dynamic_parameter_binding_; - - // Holds the input/output alias information populated by the SetUpAlias() API. - std::vector input_output_aliases_; - - // Holds the buffer donor information populated by the AddBufferDonor() API. - absl::flat_hash_set buffer_donors_; - - // A map from XlaOp::Handle to the index in the instructions_ vector where the - // instruction is held. - absl::flat_hash_map handle_to_index_; - - // Track imported instructions by their computation id and the position in - // their computation's instruction list. - struct ImportedInstruction { - int64_t computation_id; - int64_t instruction_index; - }; - - absl::flat_hash_map handle_to_imported_index_; - - // The embedded computations used by this computation. Each computation was - // the entry computation of some XlaComputation, the key is the unique id of - // that XlaComputation. - std::map embedded_; - - // The unique parameter numbers. - absl::flat_hash_set parameter_numbers_; - - // The metadata to attach to each op. This is structured as a "modal"-like - // operation, in order to simplify client code (and not sprinkle this metadata - // throughout the TensorFlow op kernel implementations). - OpMetadata metadata_; - - // A temporary metadata that will only be applied to the next op created. - std::optional one_shot_metadata_; - - // Sharding for this operator. This is structured as a "model"-like operation, - // in order to simplify client code, similar to metadata_. - std::optional sharding_; - - // Mode bit that indicates whether to die when a first error is encountered. - bool die_immediately_on_error_ = false; - - XlaBuilder* parent_builder_{nullptr}; - - FrontendAttributes frontend_attributes_; - - friend XlaOp Parameter(XlaBuilder* builder, int64_t parameter_number, - const Shape& shape, const std::string& name, - const std::vector& replicated_at_leaf_buffers); - friend XlaOp ConstantLiteral(XlaBuilder* builder, - const LiteralSlice& literal); - - friend XlaOp Broadcast(XlaOp operand, - absl::Span broadcast_sizes); - - friend XlaOp BroadcastInDim(XlaOp operand, - absl::Span out_dim_size, - absl::Span broadcast_dimensions); - - friend XlaOp MhloDynamicBroadcastInDim( - XlaOp operand, XlaOp output_dimensions, - absl::Span broadcast_dimensions, - const Shape& output_shape); - - friend XlaOp Copy(XlaOp operand); - - friend XlaOp Pad(XlaOp operand, XlaOp padding_value, - const PaddingConfig& padding_config); - - friend XlaOp PadInDim(XlaOp operand, XlaOp padding_value, int64_t dimno, - int64_t pad_lo, int64_t pad_hi); - - friend XlaOp Reshape(XlaOp operand, absl::Span dimensions, - absl::Span new_sizes); - - friend XlaOp Reshape(XlaOp operand, absl::Span new_sizes); - - friend XlaOp Reshape(const Shape& shape, XlaOp operand); - - friend XlaOp DynamicReshape(XlaOp operand, absl::Span dim_sizes, - absl::Span new_size_bounds, - const std::vector& dims_are_dynamic); - - friend XlaOp MhloDynamicReshape(XlaOp operand, XlaOp output_shape, - const Shape& shape); - - friend XlaOp ReshapeWithInferredDimension(XlaOp operand, - absl::Span new_sizes, - int64_t inferred_dimension); - - friend XlaOp Collapse(XlaOp operand, absl::Span dimensions); - - friend XlaOp Slice(XlaOp operand, absl::Span start_indices, - absl::Span limit_indices, - absl::Span strides); - - friend XlaOp SliceInDim(XlaOp operand, int64_t start_index, - int64_t limit_index, int64_t stride, int64_t dimno); - - friend XlaOp DynamicSlice(XlaOp operand, - absl::Span start_indices, - absl::Span slice_sizes); - - friend XlaOp DynamicUpdateSlice(XlaOp operand, XlaOp update, - absl::Span start_indices); - - friend XlaOp ConcatInDim(XlaBuilder* builder, - absl::Span operands, int64_t dimension); - - friend XlaOp Select(XlaOp pred, XlaOp on_true, XlaOp on_false); - friend XlaOp Tuple(XlaBuilder* builder, absl::Span elements); - friend XlaOp GetTupleElement(XlaOp tuple_data, int64_t index); - friend XlaOp Compare(XlaOp lhs, XlaOp rhs, - absl::Span broadcast_dimensions, - ComparisonDirection direction); - friend XlaOp Compare(XlaOp lhs, XlaOp rhs, - absl::Span broadcast_dimensions, - ComparisonDirection direction, - Comparison::Type compare_type); - friend XlaOp Dot(XlaOp lhs, XlaOp rhs, - const PrecisionConfig* precision_config, - std::optional preferred_element_type); - friend XlaOp DotGeneral(XlaOp lhs, XlaOp rhs, - const DotDimensionNumbers& dimension_number, - const PrecisionConfig* precision_config, - std::optional preferred_element_type); - virtual absl::StatusOr DotGeneralInternal( - const Shape& shape, XlaOp lhs, XlaOp rhs, - const DotDimensionNumbers& dimension_number, - const PrecisionConfig* precision_config); - friend XlaOp SparseDot(XlaOp lhs, XlaOp rhs, - absl::Span sparse_meta, - absl::Span sparsity, - const DotDimensionNumbers& dimension_number, - const PrecisionConfig* precision_config, - std::optional preferred_element_type); - friend XlaOp Conv(XlaOp lhs, XlaOp rhs, - absl::Span window_strides, Padding padding, - int64_t feature_group_count, int64_t batch_group_count, - const PrecisionConfig* precision_config, - std::optional preferred_element_type); - friend XlaOp ConvWithGeneralPadding( - XlaOp lhs, XlaOp rhs, absl::Span window_strides, - absl::Span> padding, - int64_t feature_group_count, int64_t batch_group_count, - const PrecisionConfig* precision_config, - std::optional preferred_element_type); - friend XlaOp ConvWithGeneralDimensions( - XlaOp lhs, XlaOp rhs, absl::Span window_strides, - Padding padding, const ConvolutionDimensionNumbers& dimension_numbers, - int64_t feature_group_count, int64_t batch_group_count, - const PrecisionConfig* precision_config, - std::optional preferred_element_type); - friend XlaOp ConvGeneral( - XlaOp lhs, XlaOp rhs, absl::Span window_strides, - absl::Span> padding, - const ConvolutionDimensionNumbers& dimension_numbers, - int64_t feature_group_count, int64_t batch_group_count, - const PrecisionConfig* precision_config, - std::optional preferred_element_type); - friend XlaOp DynamicConvForward( - XlaOp lhs, XlaOp rhs, absl::Span window_strides, - absl::Span> padding, - absl::Span lhs_dilation, - absl::Span rhs_dilation, - const ConvolutionDimensionNumbers& dimension_numbers, - int64_t feature_group_count, int64_t batch_group_count, - const PrecisionConfig* precision_config, PaddingType padding_type, - std::optional preferred_element_type); - friend XlaOp DynamicConvKernelGrad( - XlaOp activations, XlaOp gradients, - absl::Span window_strides, - absl::Span> padding, - absl::Span lhs_dilation, - absl::Span rhs_dilation, - const ConvolutionDimensionNumbers& dimension_numbers, - int64_t feature_group_count, int64_t batch_group_count, - const PrecisionConfig* precision_config, PaddingType padding_type, - std::optional preferred_element_type); - friend XlaOp DynamicConvInputGrad( - XlaOp input_sizes, XlaOp lhs, XlaOp rhs, - absl::Span window_strides, - absl::Span> padding, - absl::Span lhs_dilation, - absl::Span rhs_dilation, - const ConvolutionDimensionNumbers& dimension_numbers, - int64_t feature_group_count, int64_t batch_group_count, - const PrecisionConfig* precision_config, PaddingType padding_type, - std::optional preferred_element_type); - - friend XlaOp ConvKernelGrad( - XlaOp lhs, XlaOp rhs, absl::Span window_strides, - absl::Span> padding, - absl::Span lhs_dilation, - absl::Span rhs_dilation, - const ConvolutionDimensionNumbers& dimension_numbers, - int64_t feature_group_count, int64_t batch_group_count, - const PrecisionConfig* precision_config, - std::optional preferred_element_type); - - friend XlaOp ConvGeneralDilated( - XlaOp lhs, XlaOp rhs, absl::Span window_strides, - absl::Span> padding, - absl::Span lhs_dilation, - absl::Span rhs_dilation, - const ConvolutionDimensionNumbers& dimension_numbers, - int64_t feature_group_count, int64_t batch_group_count, - const PrecisionConfig* precision_config, - std::optional preferred_element_type, - std::optional> window_reversal); - - friend XlaOp Fft(XlaOp operand, FftType fft_type, - absl::Span fft_length); - friend XlaOp TriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower, - bool unit_diagonal, - TriangularSolveOptions::Transpose transpose_a); - friend XlaOp Cholesky(XlaOp a, bool lower); - friend XlaOp Infeed(XlaBuilder* builder, const Shape& shape, - const std::string& config); - friend void Outfeed(XlaOp operand, const Shape& shape_with_layout, - const std::string& outfeed_config); - friend XlaOp Call(XlaBuilder* builder, const XlaComputation& computation, - absl::Span operands); - - friend XlaOp CompositeCall(XlaBuilder* builder, - const XlaComputation& computation, - absl::Span operands, - const std::string& name, - std::optional attributes, - std::optional version); - - friend XlaOp CustomCall( - XlaBuilder* builder, const std::string& call_target_name, - absl::Span operands, const Shape& shape, - const std::string& opaque, bool has_side_effect, - absl::Span>> - output_operand_aliasing, - const Literal* literal, CustomCallSchedule schedule, - CustomCallApiVersion api_version); - friend XlaOp CustomCallWithComputation( - XlaBuilder* builder, const std::string& call_target_name, - absl::Span operands, const XlaComputation& computation, - const Shape& shape, const std::string& opaque, bool has_side_effect, - absl::Span>> - output_operand_aliasing, - const Literal* literal, CustomCallSchedule schedule, - CustomCallApiVersion api_version); - friend XlaOp CustomCallWithLayout( - XlaBuilder* builder, const std::string& call_target_name, - absl::Span operands, const Shape& shape_with_layout, - absl::Span operand_shapes_with_layout, - const std::string& opaque, bool has_side_effect, - absl::Span>> - output_operand_aliasing, - const Literal* literal, CustomCallSchedule schedule, - CustomCallApiVersion api_version); - friend XlaOp CustomCallWithConvDnums( - XlaBuilder* builder, const std::string& call_target_name, - absl::Span operands, const Shape& shape, - absl::Span operand_shapes_with_layout, - const std::string& opaque, bool has_side_effect, - absl::Span>> - output_operand_aliasing, - const Literal* literal, Window window, ConvolutionDimensionNumbers dnums, - CustomCallSchedule schedule, CustomCallApiVersion api_version); - friend XlaOp OptimizationBarrier(XlaOp operand); - friend XlaOp Complex(XlaOp real, XlaOp imag, - absl::Span broadcast_dimensions); - friend XlaOp Conj(XlaOp operand); - friend XlaOp Add(XlaOp lhs, XlaOp rhs, - absl::Span broadcast_dimensions); - friend XlaOp Sub(XlaOp lhs, XlaOp rhs, - absl::Span broadcast_dimensions); - friend XlaOp Mul(XlaOp lhs, XlaOp rhs, - absl::Span broadcast_dimensions); - friend XlaOp Div(XlaOp lhs, XlaOp rhs, - absl::Span broadcast_dimensions); - friend XlaOp Rem(XlaOp lhs, XlaOp rhs, - absl::Span broadcast_dimensions); - friend XlaOp Max(XlaOp lhs, XlaOp rhs, - absl::Span broadcast_dimensions); - friend XlaOp Min(XlaOp lhs, XlaOp rhs, - absl::Span broadcast_dimensions); - friend XlaOp And(XlaOp lhs, XlaOp rhs, - absl::Span broadcast_dimensions); - friend XlaOp Or(XlaOp lhs, XlaOp rhs, - absl::Span broadcast_dimensions); - friend XlaOp Xor(XlaOp lhs, XlaOp rhs, - absl::Span broadcast_dimensions); - friend XlaOp Not(XlaOp operand); - friend XlaOp PopulationCount(XlaOp operand); - friend XlaOp ShiftLeft(XlaOp lhs, XlaOp rhs, - absl::Span broadcast_dimensions); - friend XlaOp ShiftRightArithmetic( - XlaOp lhs, XlaOp rhs, absl::Span broadcast_dimensions); - friend XlaOp ShiftRightLogical( - XlaOp lhs, XlaOp rhs, absl::Span broadcast_dimensions); - friend XlaOp Reduce(XlaOp operand, XlaOp init_value, - const XlaComputation& computation, - absl::Span dimensions_to_reduce); - friend XlaOp Reduce(XlaBuilder* builder, absl::Span operands, - absl::Span init_values, - const XlaComputation& computation, - absl::Span dimensions_to_reduce); - friend XlaOp ReduceAll(XlaOp operand, XlaOp init_value, - const XlaComputation& computation); - friend XlaOp ReduceWindow(XlaOp operand, XlaOp init_value, - const XlaComputation& computation, - absl::Span window_dimensions, - absl::Span window_strides, - Padding padding); - friend XlaOp ReduceWindow(absl::Span operands, - absl::Span init_values, - const XlaComputation& computation, - absl::Span window_dimensions, - absl::Span window_strides, - Padding padding); - friend XlaOp ReduceWindowWithGeneralPadding( - XlaOp operand, XlaOp init_value, const XlaComputation& computation, - absl::Span window_dimensions, - absl::Span window_strides, - absl::Span base_dilations, - absl::Span window_dilations, - absl::Span> padding); - friend XlaOp ReduceWindowWithGeneralPadding( - absl::Span operands, absl::Span init_values, - const XlaComputation& computation, - absl::Span window_dimensions, - absl::Span window_strides, - absl::Span base_dilations, - absl::Span window_dilations, - absl::Span> padding); - - friend XlaOp CrossReplicaSum(XlaOp operand, - absl::Span replica_groups); - friend XlaOp AllGather(XlaOp operand, int64_t all_gather_dimension, - int64_t shard_count, - absl::Span replica_groups, - const std::optional& channel_id, - const std::optional& layout, - std::optional use_global_device_ids); - friend XlaOp AllGatherTuple(absl::Span operands, - int64_t all_gather_dimension, int64_t shard_count, - absl::Span replica_groups, - const std::optional& channel_id, - const std::optional& layout, - std::optional use_global_device_ids); - friend XlaOp AllReduce(XlaOp operand, const XlaComputation& computation, - absl::Span replica_groups, - const std::optional& channel_id, - const std::optional& shape_with_layout, - std::optional use_global_device_ids); - friend XlaOp AllReduceTuple(absl::Span operand, - const XlaComputation& computation, - absl::Span replica_groups, - const std::optional& channel_id, - const std::optional& shape_with_layout, - std::optional use_global_device_ids); - friend XlaOp ReduceScatter(XlaOp operand, const XlaComputation& computation, - int64_t scatter_dimension, int64_t shard_count, - absl::Span replica_groups, - const std::optional& channel_id, - const std::optional& layout, - std::optional use_global_device_ids); - - friend XlaOp AllToAll(XlaOp operand, int64_t split_dimension, - int64_t concat_dimension, int64_t split_count, - absl::Span replica_groups, - const std::optional& layout, - const std::optional& channel_id); - friend XlaOp AllToAllTuple(absl::Span operands, - absl::Span replica_groups, - const std::optional& layout, - const std::optional& channel_id); - friend XlaOp AllToAllTuple(XlaOp operand, int64_t split_dimension, - int64_t concat_dimension, int64_t split_count, - absl::Span replica_groups, - const std::optional& layout, - const std::optional& channel_id); - friend XlaOp CollectiveBroadcast( - XlaOp operand, absl::Span replica_groups, - const std::optional& channel_id); - friend XlaOp CollectivePermute( - XlaOp operand, - const std::vector>& source_target_pairs, - const std::optional& channel_id); - friend XlaOp ReplicaId(XlaBuilder* builder); - friend XlaOp SelectAndScatter(XlaOp operand, const XlaComputation& select, - absl::Span window_dimensions, - absl::Span window_strides, - Padding padding, XlaOp source, XlaOp init_value, - const XlaComputation& scatter); - friend XlaOp SelectAndScatterWithGeneralPadding( - XlaOp operand, const XlaComputation& select, - absl::Span window_dimensions, - absl::Span window_strides, - absl::Span> padding, XlaOp source, - XlaOp init_value, const XlaComputation& scatter); - friend XlaOp Abs(XlaOp operand); - friend XlaOp Atan2(XlaOp y, XlaOp x, - absl::Span broadcast_dimensions); - friend XlaOp Erf(XlaOp operand); - friend XlaOp Exp(XlaOp operand); - friend XlaOp Expm1(XlaOp operand); - friend XlaOp Floor(XlaOp operand); - friend XlaOp Ceil(XlaOp operand); - friend XlaOp Round(XlaOp operand); - friend XlaOp RoundNearestEven(XlaOp operand); - friend XlaOp Log(XlaOp operand); - friend XlaOp Log1p(XlaOp operand); - friend XlaOp Logistic(XlaOp operand); - friend XlaOp Sign(XlaOp operand); - friend XlaOp Clz(XlaOp operand); - friend XlaOp Cos(XlaOp operand); - friend XlaOp Sin(XlaOp operand); - friend XlaOp Tan(XlaOp operand); - friend XlaOp Tanh(XlaOp operand); - friend XlaOp Real(XlaOp operand); - friend XlaOp Imag(XlaOp operand); - friend XlaOp Sqrt(XlaOp operand); - friend XlaOp Rsqrt(XlaOp operand); - friend XlaOp Cbrt(XlaOp operand); - friend XlaOp Pow(XlaOp lhs, XlaOp rhs, - absl::Span broadcast_dimensions); - friend XlaOp IsFinite(XlaOp operand); - friend XlaOp Iota(XlaBuilder* builder, const Shape& shape, - int64_t iota_dimension); - friend XlaOp Iota(XlaBuilder* builder, PrimitiveType type, int64_t size); - friend XlaOp ConvertElementType(XlaOp operand, - PrimitiveType new_element_type); - friend XlaOp BitcastConvertType(XlaOp operand, - PrimitiveType new_element_type); - friend XlaOp StochasticConvertType(XlaOp operand, XlaOp random, - PrimitiveType new_element_type); - friend XlaOp Neg(XlaOp operand); - friend XlaOp Transpose(XlaOp operand, absl::Span permutation); - friend XlaOp Rev(XlaOp operand, absl::Span dimensions); - friend XlaOp Sort(absl::Span operands, - const XlaComputation& comparator, int64_t dimension, - bool is_stable); - friend XlaOp TopK(XlaOp operand, int64_t k, bool largest); - friend XlaOp Clamp(XlaOp min, XlaOp operand, XlaOp max); - friend XlaOp Map(XlaBuilder* builder, absl::Span operands, - const XlaComputation& computation, - absl::Span dimensions, - absl::Span static_operands); - friend XlaOp RngNormal(XlaOp mu, XlaOp sigma, const Shape& shape); - friend XlaOp RngUniform(XlaOp a, XlaOp b, const Shape& shape); - friend XlaOp RngBitGenerator(RandomAlgorithm algorithm, XlaOp initial_state, - const Shape& shape); - friend XlaOp While(const XlaComputation& condition, - const XlaComputation& body, XlaOp init); - friend XlaOp Conditional(XlaOp predicate, XlaOp true_operand, - const XlaComputation& true_computation, - XlaOp false_operand, - const XlaComputation& false_computation); - friend XlaOp Conditional( - XlaOp branch_index, - absl::Span branch_computations, - absl::Span branch_operands); - friend XlaOp ConditionalImpl( - XlaOp branch_index, - absl::Span branch_computations, - absl::Span branch_operands); - friend XlaOp ReducePrecision(XlaOp operand, int exponent_bits, - int mantissa_bits); - friend XlaOp Gather(XlaOp input, XlaOp start_indices, - const GatherDimensionNumbers& dimension_numbers, - absl::Span slice_sizes, - bool indices_are_sorted); - friend XlaOp Scatter(XlaOp input, XlaOp scatter_indices, XlaOp updates, - const XlaComputation& update_computation, - const ScatterDimensionNumbers& dimension_numbers, - bool indices_are_sorted, bool unique_indices); - friend XlaOp Scatter(absl::Span inputs, XlaOp scatter_indices, - absl::Span updates, - const XlaComputation& update_computation, - const ScatterDimensionNumbers& dimension_numbers, - bool indices_are_sorted, bool unique_indices); - friend void Send(XlaOp operand, const ChannelHandle& handle); - friend XlaOp Recv(XlaBuilder* builder, const Shape& shape, - const ChannelHandle& handle); - friend XlaOp BatchNormTraining(XlaOp operand, XlaOp scale, XlaOp offset, - float epsilon, int64_t feature_index); - friend XlaOp BatchNormInference(XlaOp operand, XlaOp scale, XlaOp offset, - XlaOp mean, XlaOp variance, float epsilon, - int64_t feature_index); - friend XlaOp BatchNormGrad(XlaOp operand, XlaOp scale, XlaOp batch_mean, - XlaOp batch_var, XlaOp grad_output, float epsilon, - int64_t feature_index); - friend XlaOp SendWithToken(XlaOp operand, XlaOp token, - const ChannelHandle& handle); - friend XlaOp RecvWithToken(XlaOp token, const Shape& shape, - const ChannelHandle& handle); - friend XlaOp SendToHost(XlaOp operand, XlaOp token, - const Shape& shape_with_layout, - const ChannelHandle& handle); - friend XlaOp RecvFromHost(XlaOp token, const Shape& shape, - const ChannelHandle& handle); - friend XlaOp InfeedWithToken(XlaOp token, const Shape& shape, - const std::string& config); - friend XlaOp OutfeedWithToken(XlaOp operand, XlaOp token, - const Shape& shape_with_layout, - const std::string& outfeed_config); - friend XlaOp CreateToken(XlaBuilder* builder); - friend XlaOp AfterAll(XlaBuilder* builder, absl::Span tokens); - - friend XlaOp GetDimensionSize(XlaOp operand, int64_t dimension); - friend XlaOp SetDimensionSize(XlaOp operand, XlaOp val, int64_t dimension); - friend XlaOp RemoveDynamicDimension(XlaOp operand, int64_t dimension); - - protected: - // Returns OK status if the given op was built using this builder. Otherwise, - // returns an error. - absl::Status CheckOpBuilder(XlaOp op) const; - - private: - XlaOp AllGatherImpl(XlaOp operand, int64_t all_gather_dimension, - int64_t shard_count, - absl::Span replica_groups, - const std::optional& channel_id, - const std::optional& layout, - std::optional use_global_device_ids, bool async); - - XlaOp AllReduceImpl(XlaOp operand, const XlaComputation& computation, - absl::Span replica_groups, - const std::optional& channel_id, - const std::optional& layout, - std::optional use_global_device_ids, bool async); - - XlaOp CollectiveBroadcastImpl(XlaOp operand, - absl::Span replica_groups, - const std::optional& channel_id); - - XlaOp CollectivePermuteImpl( - XlaOp operand, - const std::vector>& source_target_pairs, - const std::optional& channel_id, bool async); - - XlaOp ConditionalImpl( - XlaOp branch_index, - absl::Span branch_computations, - absl::Span branch_operands); - - XlaOp AllToAllArray( - XlaOp operand, int64_t split_dimension, int64_t concat_dimension, - int64_t split_count, absl::Span replica_groups, - const std::optional& channel_id = std::nullopt); - - // Creates an op with the given opcode and the output shape. - virtual absl::StatusOr AddOpWithShape( - HloOpcode opcode, const Shape& shape, absl::Span operands); - - // Here, InstructionType is either const HloInstructionProto* or non-const - // HloInstructionProto*. - template - absl::StatusOr LookUpInstructionByHandleInternal( - int64_t handle) const { - auto it = handle_to_index_.find(handle); - if (it == handle_to_index_.end()) { - // Try look for the instruction in the imported instructions. - auto imported_it = handle_to_imported_index_.find(handle); - if (imported_it != handle_to_imported_index_.end()) { - ImportedInstruction imported = imported_it->second; - return const_cast( - &embedded_.at(imported.computation_id) - .instructions(imported.instruction_index)); - } - return InvalidArgument("No XlaOp with handle %d", handle); - } - return const_cast(&instructions_.at(it->second)); - } - - // Here, InstructionType is either const HloInstructionProto* or non-const - // HloInstructionProto*. - // - // TODO(hinsu): Return const pointer within absl::StatusOr and use - // absl::implicit_cast at callsites. This requires implicit_cast support in - // absl::StatusOr similar to absl::StatusOr. - template - absl::StatusOr LookUpInstructionInternal(XlaOp op) const { - TF_RETURN_IF_ERROR(CheckOpBuilder(op)); - return LookUpInstructionByHandleInternal(op.handle()); - } - - friend struct internal::XlaBuilderFriend; - - friend class ValueInference; -}; - -// RAII-style object: sets the current sharding assignment in builder on -// construction, and sets back to the previous assignment on destruction. -class XlaScopedShardingAssignment { - public: - XlaScopedShardingAssignment(xla::XlaBuilder* builder, - std::optional sharding) - : builder_(builder), prev_sharding_(builder->sharding()) { - SetSharding(sharding); - } - - XlaScopedShardingAssignment(const XlaScopedShardingAssignment&) = delete; - XlaScopedShardingAssignment& operator=(const XlaScopedShardingAssignment&) = - delete; - - ~XlaScopedShardingAssignment() { SetSharding(prev_sharding_); } - - private: - void SetSharding(const std::optional& sharding) { - if (sharding.has_value()) { - builder_->SetSharding(sharding.value()); - } else { - builder_->ClearSharding(); - } - } - - xla::XlaBuilder* const builder_; - std::optional prev_sharding_; -}; - -// RAII-style object: save the current builder's frontend attributes, and merge -// them with the new ones on construction. -// Restore the original attributes on destruction. -class XlaScopedFrontendAttributesAssignment { - public: - XlaScopedFrontendAttributesAssignment(xla::XlaBuilder* builder, - FrontendAttributes attributes) - : builder_(builder) { - saved_ = builder_->SwapFrontendAttributes(attributes); - } - - ~XlaScopedFrontendAttributesAssignment() { - builder_->SetFrontendAttributes(saved_); - } - - private: - xla::XlaBuilder* const builder_; - FrontendAttributes saved_; - - XlaScopedFrontendAttributesAssignment( - const XlaScopedFrontendAttributesAssignment&) = delete; - XlaScopedFrontendAttributesAssignment& operator=( - const XlaScopedFrontendAttributesAssignment&) = delete; -}; - -// RAII-style object: sets the current op metadata in builder on construction, -// and sets back to the previous assignment on destruction. -class XlaScopedOpMetadataAssignment { - public: - XlaScopedOpMetadataAssignment(xla::XlaBuilder* builder, OpMetadata metadata) - : builder_(builder) { - saved_ = builder_->SwapOpMetadata(metadata); - } - - ~XlaScopedOpMetadataAssignment() { builder_->SwapOpMetadata(saved_); } - - private: - xla::XlaBuilder* const builder_; - OpMetadata saved_; - - XlaScopedOpMetadataAssignment(const XlaScopedOpMetadataAssignment&) = delete; - XlaScopedOpMetadataAssignment& operator=( - const XlaScopedOpMetadataAssignment&) = delete; -}; - -// Free functions for building XlaOps. The intention is that these will -// become the public API for building XlaOps rather than calling methods on -// XlaBuilder directly. -// - -// Enqueues a "retrieve parameter value" instruction for a parameter that was -// passed to the computation. -XlaOp Parameter(XlaBuilder* builder, int64_t parameter_number, - const Shape& shape, const std::string& name); - -// Same as above, but with leaf buffer replication annotation. -XlaOp Parameter(XlaBuilder* builder, int64_t parameter_number, - const Shape& shape, const std::string& name, - const std::vector& replicated_at_leaf_buffers); - -// Enqueues a constant with the value of the given literal onto the -// computation. -XlaOp ConstantLiteral(XlaBuilder* builder, const LiteralSlice& literal); - -// Enqueues a constant onto the computation. Methods are templated on the -// native host type (NativeT) which corresponds to a specific XLA -// PrimitiveType as given in the following table: -// -// Native Type PrimitiveType -// ----------------------------- -// bool PRED -// int32_t S32 -// int64_t S64 -// uint32_t U32 -// uint64_t U64 -// float F32 -// double F64 -// -// Note: not all primitive types defined in xla_data.proto have a -// corresponding native type yet. -template -XlaOp ConstantR0(XlaBuilder* builder, NativeT value); -template -XlaOp ConstantR1(XlaBuilder* builder, absl::Span values); -XlaOp ConstantR1(XlaBuilder* builder, const tsl::core::Bitmap& values); -template -XlaOp ConstantR2(XlaBuilder* builder, - std::initializer_list> values); -template -XlaOp ConstantFromArrayWithLayout(XlaBuilder* builder, - const Array& values, - const Layout& layout); -template -XlaOp ConstantFromArray(XlaBuilder* builder, const Array& values); -template -XlaOp ConstantR2FromArray2DWithLayout(XlaBuilder* builder, - const Array2D& values, - const Layout& layout); -template -XlaOp ConstantR2FromArray2D(XlaBuilder* builder, - const Array2D& values); -template -XlaOp ConstantR3FromArray3DWithLayout(XlaBuilder* builder, - const Array3D& values, - const Layout& layout); -template -XlaOp ConstantR3FromArray3D(XlaBuilder* builder, - const Array3D& values); -template -XlaOp ConstantR4FromArray4DWithLayout(XlaBuilder* builder, - const Array4D& values, - const Layout& layout); -template -XlaOp ConstantR4FromArray4D(XlaBuilder* builder, - const Array4D& values); - -// Enqueues a rank one constant (XlaBuilder* builder, vector) onto the -// computation. The vector has size 'length' and every element has the value -// 'value'. -template -XlaOp ConstantR1(XlaBuilder* builder, int64_t length, NativeT value); - -// Adds dimensions to an array by duplicating the data in the array. -// -// The new dimensions are inserted on the left, i.e. if -// broadcast_sizes has values {a0, ..., aN} and the operand shape -// has dimensions {b0, ..., bM} then the shape of the output has -// dimensions {a0, ..., aN, b0, ..., bM}. -// -// The new dimensions index into copies of the operand, i.e. -// -// output[i0, ..., iN, j0, ..., jM] = operand[j0, ..., jM] -XlaOp Broadcast(XlaOp operand, absl::Span broadcast_sizes); - -// This op broadcasts the `operand` to an output with the given `shape`. -// `broadcast_dimensions` are the dimensions to be broadcasting into, i.e., the -// i'th dimension of the operand is mapped to the broadcast_dimensions[i]'th -// dimension of the output. This also requires that the i'th input dimension is -// either 1 or is the same as the output dimension it's broadcasting into. -// -// For example, say operand = {1, 2}, i.e., a 1D tensor in shape s32[2]; the -// output shape is s32[2,2]: -// - Specifying {1} as broadcast_dimension will generate output -// {{1, 2}, -// {1, 2}} -// - On the other hand, specifying {0} as broadcast_dimension -// will generate output -// {{1 , 1}, -// {2 , 2}} -XlaOp BroadcastInDim(XlaOp operand, absl::Span out_dim_size, - absl::Span broadcast_dimensions); - -// This is an experimental API for creating the mhlo.dynamic_broadcast_in_dim -// op from the XlaBuilder. This is only intended for export to MHLO or -// StableHLO, and cannot be compiled. See -// https://www.tensorflow.org/mlir/hlo_ops#mhlodynamic_broadcast_in_dim_mhlodynamicbroadcastindimop. -// for the op semantics. -XlaOp MhloDynamicBroadcastInDim(XlaOp operand, XlaOp output_dimensions, - absl::Span broadcast_dimensions, - const Shape& output_shape); - -// Copies the input operand to the output. This operation is for internal -// purpose and is only used by the compiler for optimization purposes or to -// ensure correctness. The XLA client should never have to generate this -// instruction. -// -// Copy has two potential use cases: -// -// * Create a copy of the operand with a new layout. -// -// * Create a copy of the operand in a separately allocated buffer. This is -// necessary for some backends if the operand is a parameter or constant and -// the operand is returned within a tuple. In this case, the lifetime of the -// operand buffer must be the same as the lifetime of the output result. -// However, the lifetimes of parameters and constants are managed separately -// from the lifetime of the output result. Creating a separate copy of the -// parameter or constant buffer resolves this issue. -XlaOp Copy(XlaOp operand); - -// Enqueues a pad operation onto the computation that pads the given value on -// the edges as well as between the elements of the input. padding_config -// specifies the padding amount for each dimension. -XlaOp Pad(XlaOp operand, XlaOp padding_value, - const PaddingConfig& padding_config); - -// Enqueues a pad operation in a given dimension, taking all other -// dimensions as they are. -XlaOp PadInDim(XlaOp operand, XlaOp padding_value, int64_t dimno, - int64_t pad_lo, int64_t pad_hi); - -// Enqueues an operation onto the computation that flattens the operand based -// on the dimension order (major/slowest-varying to minor/fastest-varying) -// given, followed by reshaping it into the shape with the given dimension -// sizes (also major to minor). Conceptually, this is a limited form of -// "shape casting". -XlaOp Reshape(XlaOp operand, absl::Span dimensions, - absl::Span new_sizes); - -// Enqueues a dynamic reshape operation. The dynamic reshape takes additional -// XlaOps as sizes for the result dimension. The result dim i is a dynamic -// dimension dimension if dims_are_dynamic[i] is true. -XlaOp DynamicReshape(XlaOp operand, absl::Span dim_sizes, - absl::Span new_size_bounds, - const std::vector& dims_are_dynamic); - -// This is an experimental API for creating the mhlo.dynamic_reshape op from the -// XlaBuilder. This is only intended for export to MHLO or StableHLO, and cannot -// be compiled. -XlaOp MhloDynamicReshape(XlaOp operand, XlaOp output_shape, const Shape& shape); - -// Enqueues an operation onto the computation that collapses the operand, -// from first to last dimension (C order), then reshapes it to the given -// dimension sizes. Conceptually, this is a limited form of "shape casting". -XlaOp Reshape(XlaOp operand, absl::Span new_sizes); - -// Enqueues a Reshape op that uses an explicit target shape. -XlaOp Reshape(const Shape& shape, XlaOp operand); - -// `inferred_dimension` represents the output dimension that's inferred by -// upper-level framework by dividing the input element count by the known -// output element count. While an inferred_dimension can be static, if there -// is a dynamic dimension in the output, it must be the inferred dimension. -XlaOp ReshapeWithInferredDimension(XlaOp operand, - absl::Span new_sizes, - int64_t inferred_dimension); - -// Wrapper for Reshape. -// Enqueues an operation to collapse the provided dimensions; e.g. an -// operand with dimensions {x=256, y=2, z=2, p=32} can be collapsed to -// {x=1024, y=32} by collapsing dims {0, 1, 2}. Collapsing dimensions must -// be a consecutive, in-order subsequence of the operand dimensions. -// -// Note that collapsing a single dimension does nothing: -// -// {256} collapsing {0} => {256} -// {1} collapsing {0} => {1} -// -// Collapsing multiple dimensions produces a single result dimension: -// -// {256, 2} collapsing {0,1} => {512} -// {256, 2, 3} collapsing {0,1} => {512, 3} -// -// This could potentially cause data to be moved -- it provides a more -// structured form of reshaping than an arbitrary Reshape operation. -XlaOp Collapse(XlaOp operand, absl::Span dimensions); - -// Enqueues a slice operation onto the computation that slices the operand -// from the start indices to the limit indices; e.g. -// -// x -// [ 0 1 2 3 ] -// y [ 4 5 6 7 ] => slice(start={1, 1}, limit={2, 3}) => [ 5 6 ] -// [ 8 9 a b ] -// -// Note that "limit" means up-to-but-not-including; i.e. [start, limit) in 1D -// range notation. -// The strides parameter determines the stride over the slice -XlaOp Slice(XlaOp operand, absl::Span start_indices, - absl::Span limit_indices, - absl::Span strides); - -// Enqueues a slice operation in a given dimension, taking all other -// dimensions as they are; e.g. if dimno is 1 from start_index 2 to -// limit_index 4 by 1, and the shape is f32[7,8,9], this call is short-hand -// for: -// -// array[:, 2:4:1, :] -XlaOp SliceInDim(XlaOp operand, int64_t start_index, int64_t limit_index, - int64_t stride, int64_t dimno); - -// Enqueues a slice operation onto the computation that slices the 'operand' -// from dynamic start indices which are passed in 'start_indices'. -// The size of the slice in each dimension is passed in 'slice_sizes', -// which specify the end point of exclusive slice intervals in each -// dimension [start, start + size). -// The shape of each element of 'start_indices' must be scalar, with the span -// size equal to the rank of the 'operand'. All elements of 'start_indices' must -// have the same shape. -// Slice index calculations are computed modulo input dimension sizes to -// prevent dynamic start indices from generating out-of-bound array accesses. -XlaOp DynamicSlice(XlaOp operand, absl::Span start_indices, - absl::Span slice_sizes); - -// Enqueues a dynamic update slice operation onto the computation, which -// updates a slice of 'operand' with 'update' at dynamic 'start_indices'. -// The shape of 'update' determines the shape of the slice of 'operand' -// which is updated. -// The indices specified in 'start_indices' specify the offset of the slice -// of 'operand' which is updated. -// -// update = {10, 11} // calculated at runtime. -// [1 2 3] start = {1, 1} // calculated at runtime. [1 2 3 ] -// [4 5 6] => DynamicUpdateslice(data, update, start) => [4 10 11] -// [7 8 9] [7 8 9 ] -// -// The shape of each element of 'start_indices' must be scalar, with the span -// size equal to the rank of the 'operand'. All elements of 'start_indices' must -// have the same shape. -// Slice index calculations are computed modulo update dimension sizes to -// prevent dynamic start indices from generating out-of-bound array accesses. -XlaOp DynamicUpdateSlice(XlaOp operand, XlaOp update, - absl::Span start_indices); - -// Enqueues a concatenate instruction onto the computation. 'operands' must -// have >= 1 entry. -XlaOp ConcatInDim(XlaBuilder* builder, absl::Span operands, - int64_t dimension); - -// Enqueues a conditional-move-like select operation onto the computation; -// predicated on pred, selects between on_true and on_false. -XlaOp Select(XlaOp pred, XlaOp on_true, XlaOp on_false); - -// Enqueues a tuple-creation instruction onto the computation. -XlaOp Tuple(XlaBuilder* builder, absl::Span elements); - -// Enqueues a tuple-element-get instruction onto the computation. -XlaOp GetTupleElement(XlaOp tuple_data, int64_t index); - -// Enqueues an equal-to comparison instruction onto the computation. -XlaOp Eq(XlaOp lhs, XlaOp rhs, - absl::Span broadcast_dimensions = {}); -XlaOp EqTotalOrder(XlaOp lhs, XlaOp rhs, - absl::Span broadcast_dimensions = {}); - -// Enqueues a not-equal comparison instruction onto the computation. -XlaOp Ne(XlaOp lhs, XlaOp rhs, - absl::Span broadcast_dimensions = {}); -XlaOp NeTotalOrder(XlaOp lhs, XlaOp rhs, - absl::Span broadcast_dimensions = {}); - -// Enqueues a greater-or-equal comparison instruction onto the computation. -XlaOp Ge(XlaOp lhs, XlaOp rhs, - absl::Span broadcast_dimensions = {}); -XlaOp GeTotalOrder(XlaOp lhs, XlaOp rhs, - absl::Span broadcast_dimensions = {}); - -// Enqueues a greater-than comparison instruction onto the computation. -XlaOp Gt(XlaOp lhs, XlaOp rhs, - absl::Span broadcast_dimensions = {}); -XlaOp GtTotalOrder(XlaOp lhs, XlaOp rhs, - absl::Span broadcast_dimensions = {}); - -// Enqueues a less-than comparison instruction onto the computation. -XlaOp Lt(XlaOp lhs, XlaOp rhs, - absl::Span broadcast_dimensions = {}); -XlaOp LtTotalOrder(XlaOp lhs, XlaOp rhs, - absl::Span broadcast_dimensions = {}); - -// Enqueues a less-or-equal comparison instruction onto the computation. -XlaOp Le(XlaOp lhs, XlaOp rhs, - absl::Span broadcast_dimensions = {}); -XlaOp LeTotalOrder(XlaOp lhs, XlaOp rhs, - absl::Span broadcast_dimensions = {}); - -// Enqueues a comparison instruction onto the computation (optionally without -// broadcast_dimensions for consistency with others). -XlaOp Compare(XlaOp lhs, XlaOp rhs, - absl::Span broadcast_dimensions, - ComparisonDirection direction, Comparison::Type compare_type); -XlaOp Compare(XlaOp lhs, XlaOp rhs, - absl::Span broadcast_dimensions, - ComparisonDirection direction); -XlaOp Compare(XlaOp lhs, XlaOp rhs, ComparisonDirection direction); - -// Enqueues a dot instruction onto the computation. -XlaOp Dot(XlaOp lhs, XlaOp rhs, - const PrecisionConfig* precision_config = nullptr, - std::optional preferred_element_type = std::nullopt); - -// Enqueues a general dot instruction onto the computation. -XlaOp DotGeneral( - XlaOp lhs, XlaOp rhs, const DotDimensionNumbers& dimension_numbers, - const PrecisionConfig* precision_config = nullptr, - std::optional preferred_element_type = std::nullopt); - -// Enqueues a sparse dot instruction onto the computation. -XlaOp SparseDot( - XlaOp lhs, XlaOp rhs, absl::Span sparse_meta, - absl::Span sparsity, - const DotDimensionNumbers& dimension_numbers, - const PrecisionConfig* precision_config = nullptr, - std::optional preferred_element_type = std::nullopt); - -// Enqueues a convolution instruction onto the computation, which uses the -// default convolution dimension numbers. -XlaOp Conv(XlaOp lhs, XlaOp rhs, absl::Span window_strides, - Padding padding, int64_t feature_group_count = 1, - int64_t batch_group_count = 1, - const PrecisionConfig* precision_config = nullptr, - std::optional preferred_element_type = std::nullopt); - -// Enqueues a convolution instruction onto the computation, with the caller -// provided padding configuration in the format returned by MakePadding(). -XlaOp ConvWithGeneralPadding( - XlaOp lhs, XlaOp rhs, absl::Span window_strides, - absl::Span> padding, - int64_t feature_group_count = 1, int64_t batch_group_count = 1, - const PrecisionConfig* precision_config = nullptr, - std::optional preferred_element_type = std::nullopt); - -// Enqueues a convolution instruction onto the computation, with the caller -// provided dimension numbers configuration. -XlaOp ConvWithGeneralDimensions( - XlaOp lhs, XlaOp rhs, absl::Span window_strides, - Padding padding, const ConvolutionDimensionNumbers& dimension_numbers, - int64_t feature_group_count = 1, int64_t batch_group_count = 1, - const PrecisionConfig* precision_config = nullptr, - std::optional preferred_element_type = std::nullopt); - -// Enqueues a convolution instruction onto the computation, with the caller -// provided padding configuration as well as the dimension numbers. -XlaOp ConvGeneral( - XlaOp lhs, XlaOp rhs, absl::Span window_strides, - absl::Span> padding, - const ConvolutionDimensionNumbers& dimension_numbers, - int64_t feature_group_count = 1, int64_t batch_group_count = 1, - const PrecisionConfig* precision_config = nullptr, - std::optional preferred_element_type = std::nullopt); - -// Enqueues a convolution instruction onto the computation, with the caller -// provided padding configuration, dilation factors and dimension numbers. -XlaOp ConvGeneralDilated( - XlaOp lhs, XlaOp rhs, absl::Span window_strides, - absl::Span> padding, - absl::Span lhs_dilation, - absl::Span rhs_dilation, - const ConvolutionDimensionNumbers& dimension_numbers, - int64_t feature_group_count = 1, int64_t batch_group_count = 1, - const PrecisionConfig* precision_config = nullptr, - std::optional preferred_element_type = std::nullopt, - std::optional> window_reversal = std::nullopt); - -XlaOp DynamicConvForward( - XlaOp lhs, XlaOp rhs, absl::Span window_strides, - absl::Span> padding, - absl::Span lhs_dilation, - absl::Span rhs_dilation, - const ConvolutionDimensionNumbers& dimension_numbers, - int64_t feature_group_count, int64_t batch_group_count, - const PrecisionConfig* precision_config, PaddingType padding_type, - std::optional preferred_element_type = std::nullopt); - -XlaOp DynamicConvInputGrad( - XlaOp input_sizes, XlaOp lhs, XlaOp rhs, - absl::Span window_strides, - absl::Span> padding, - absl::Span lhs_dilation, - absl::Span rhs_dilation, - const ConvolutionDimensionNumbers& dimension_numbers, - int64_t feature_group_count, int64_t batch_group_count, - const PrecisionConfig* precision_config, PaddingType padding_type, - std::optional preferred_element_type = std::nullopt); - -XlaOp DynamicConvKernelGrad( - XlaOp activations, XlaOp gradients, - absl::Span window_strides, - absl::Span> padding, - absl::Span lhs_dilation, - absl::Span rhs_dilation, - const ConvolutionDimensionNumbers& dimension_numbers, - int64_t feature_group_count, int64_t batch_group_count, - const PrecisionConfig* precision_config, PaddingType padding_type, - std::optional preferred_element_type = std::nullopt); - -// Enqueues an FFT instruction onto the computation, of the given type and -// with the given FFT length. -XlaOp Fft(XlaOp operand, FftType fft_type, - absl::Span fft_length); - -// Solves systems of linear equations with lower or upper triangular coefficient -// matrices by forward- or back-substitution. Broadcasting along leading -// dimensions, this routine solves for x in one of the matrix systems -// `op(a) * x = b`, or `x * op(a) = b`, -// for the variable `x` given `a` and `b`, where `op(a)` is either -// `op(a) = a`, or `op(a) = transpose(a)`, or `op(a) = conj(transpose(a))`. -// -// * `a` is a tensor of shape `[..., M, M]` whose innermost 2 dimensions form -// square matrices. If `lower` is true (false), then the strictly upper -// (lower) triangular part of each innermost matrix in `a` is assumed to be -// zero and is not accessed. -// * `b` is a tensor of shape `[..., M, K]` if `left_side` is true, otherwise a -// tensor of shape `[..., K, M]`. -// * `left_side` is a boolean, indicating whether to solve a system of the form -// op(a) * x = b (true) or x * op(a) = b (false). -// * `lower` is a boolean, indicating whether the argument `a` is -// lower-triangular (true) or upper-triangular (false). -// * If `unit_diagonal` is true, the diagonal elements of `a` are assumed to be -// 1 and not accessed. -// * `transpose_a` indicates which function `op` we use to transform the tensor -// `a`: the identity function, transpose(a), or conjugate(transpose(a)) -XlaOp TriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower, - bool unit_diagonal, - TriangularSolveOptions::Transpose transpose_a); - -// Computes the Cholesky decompositions of a batch of symmetric (Hermitian) -// positive definite matrices. -// `a` must be a (batched) square matrix; i.e., it must have rank >= 2 with the -// two minor dimensions equal. -// If `lower` is true, the data from the lower triangle is used; if false, the -// upper triangle is used. The input data in the other triangle of the input -// does not affect the output. Returns the output in the same lower/upper -// triangle. The data returned in the other output triangle is arbitrary and -// implementation-defined. -// -// If `a` is not Hermitian positive definite, returns an array full of NaNs. -XlaOp Cholesky(XlaOp a, bool lower); - -// Enqueues an infeed instruction onto the computation, which writes data of -// the given shape to the infeed buffer of the device. -XlaOp Infeed(XlaBuilder* builder, const Shape& shape, - const std::string& config = ""); - -// Variant of Infeed which takes a token-shaped operand and produces a -// two-element tuple containing the data value and a token-shaped value. -// Tokens are used for ordering side-effecting operations. -// TODO(b/110532604): Replace all uses of the non-token form with this variant. -XlaOp InfeedWithToken(XlaOp token, const Shape& shape, - const std::string& config = ""); - -// Enqueues an outfeed instruction onto the computation. This instruction -// generates outgoing data transfers for the given data. -// -// shape_with_layout communicates the laid out shape that we want to outfeed -// -- if !ShapeUtil::Compatible(GetShape(operand), shape_with_layout) an error -// will occur. -void Outfeed(XlaOp operand, const Shape& shape_with_layout, - const std::string& outfeed_config); - -// Variant of Outfeed which takes a token-shaped operand and produces a -// token-shaped value. Tokens are used for ordering side-effecting operations. -// TODO(b/110532604): Replace all uses of the non-token form with this variant. -XlaOp OutfeedWithToken(XlaOp operand, XlaOp token, - const Shape& shape_with_layout, - const std::string& outfeed_config); - -// Enqueues a call instruction onto the computation. -XlaOp Call(XlaBuilder* builder, const XlaComputation& computation, - absl::Span operands); - -// Enqueues a composite call instruction onto the computation. -XlaOp CompositeCall(XlaBuilder* builder, const XlaComputation& computation, - absl::Span operands, const std::string& name, - std::optional attributes = std::nullopt, - std::optional version = std::nullopt); - -// Enqueues a custom call instruction onto the computation. A custom call -// invokes code external to XLA. The |operands| are passed to the external code, -// and the external code is expected to produce a result of the given -// |shape|. The exact mechanism is backend-specific. For example, in the CPU -// backend, a call instruction is emitted which targets a symbol with the name -// |call_target_name|. |call_target_name| and |opaque| can arbitrary strings, -// but |call_target_name| should be short as it may be used in labels. |opaque| -// can encode arbitrarily large amounts of information. |has_side_effect| -// specifies whether the instruction can have side effects. -// |output_operand_aliasing| specifies a list of output/operand buffer pairs -// that alias each other, where the output buffer is represented as a -// ShapeIndex, and the operand buffer is represented as the operand index and -// the ShapeIndex. -XlaOp CustomCall( - XlaBuilder* builder, const std::string& call_target_name, - absl::Span operands, const Shape& shape, - const std::string& opaque = "", bool has_side_effect = false, - absl::Span>> - output_operand_aliasing = {}, - const Literal* literal = nullptr, - CustomCallSchedule schedule = CustomCallSchedule::SCHEDULE_NONE, - CustomCallApiVersion api_version = API_VERSION_ORIGINAL); - -// Overload which constructs a custom call that applies an Xla computation. -XlaOp CustomCallWithComputation( - XlaBuilder* builder, const std::string& call_target_name, - absl::Span operands, const XlaComputation& computation, - const Shape& shape, const std::string& opaque = "", - bool has_side_effect = false, - absl::Span>> - output_operand_aliasing = {}, - const Literal* literal = nullptr, - CustomCallSchedule schedule = CustomCallSchedule::SCHEDULE_NONE, - CustomCallApiVersion api_version = API_VERSION_ORIGINAL); - -// Overload which constructs a custom call with fixed layouts. The operands will -// have the layouts specified by |operand_shapes_with_layout| when provided to -// external code, and the external code is expected to produce a result with the -// layout specified by |shape_with_layout|. All shapes in |shape_with_layout| -// and |operand_shapes_with_layout| must have layouts. -XlaOp CustomCallWithLayout( - XlaBuilder* builder, const std::string& call_target_name, - absl::Span operands, const Shape& shape_with_layout, - absl::Span operand_shapes_with_layout, - const std::string& opaque = "", bool has_side_effect = false, - absl::Span>> - output_operand_aliasing = {}, - const Literal* literal = nullptr, - CustomCallSchedule schedule = CustomCallSchedule::SCHEDULE_NONE, - CustomCallApiVersion api_version = API_VERSION_ORIGINAL); - -// Overload which annotates a custom call with the given Window and -// ConvolutionDimensionNumbers. Useful for custom-calls which represent -// convolutions. -// -// This sets the layout of its operands if operand_shapes_with_layout is -// nonempty, and it sets the layout of its result if `shape` has a layout. -XlaOp CustomCallWithConvDnums( - XlaBuilder* builder, const std::string& call_target_name, - absl::Span operands, const Shape& shape, - absl::Span operand_shapes_with_layout, - const std::string& opaque, bool has_side_effect, - absl::Span>> - output_operand_aliasing, - const Literal* literal, Window window, ConvolutionDimensionNumbers dnums, - CustomCallSchedule schedule = CustomCallSchedule::SCHEDULE_NONE, - CustomCallApiVersion api_version = API_VERSION_ORIGINAL); - -// Enqueues an optimization barrier onto the computation. -XlaOp OptimizationBarrier(XlaOp operand); - -// The following methods enqueue element-wise binary arithmetic operations -// onto the computation. The shapes of the operands have to match unless one -// of the operands is a scalar, or an explicit broadcast dimension is given -// (see g3doc for more details). - -// Enqueues a complex compose instruction onto the computation. -XlaOp Complex(XlaOp real, XlaOp imag, - absl::Span broadcast_dimensions = {}); - -// Enqueues a complex conjugate instruction onto the computation. -XlaOp Conj(XlaOp operand); - -// Enqueues an add instruction onto the computation. -XlaOp Add(XlaOp lhs, XlaOp rhs, - absl::Span broadcast_dimensions = {}); - -// Enqueues a subtract instruction onto the computation. -XlaOp Sub(XlaOp lhs, XlaOp rhs, - absl::Span broadcast_dimensions = {}); - -// Enqueues a multiply instruction onto the computation. -XlaOp Mul(XlaOp lhs, XlaOp rhs, - absl::Span broadcast_dimensions = {}); - -// Enqueues a divide instruction onto the computation. -XlaOp Div(XlaOp lhs, XlaOp rhs, - absl::Span broadcast_dimensions = {}); - -// Enqueues a remainder instruction onto the computation. -XlaOp Rem(XlaOp lhs, XlaOp rhs, - absl::Span broadcast_dimensions = {}); - -// Enqueues a max instruction onto the computation. -XlaOp Max(XlaOp lhs, XlaOp rhs, - absl::Span broadcast_dimensions = {}); - -// Enqueues a min instruction onto the computation. -XlaOp Min(XlaOp lhs, XlaOp rhs, - absl::Span broadcast_dimensions = {}); - -// Element-wise logical operators -XlaOp And(XlaOp lhs, XlaOp rhs, - absl::Span broadcast_dimensions = {}); - -// Overload to call And with 3 or more operands. We need the following somewhat -// convoluted overload set to disambiguate with the overload that takes the -// `broadcast_dimensions` optional param. -inline XlaOp And(XlaOp op1, XlaOp op2, XlaOp op3) { - return And(op1, And(op2, op3)); -} -template -XlaOp And(XlaOp op1, XlaOp op2, XlaOp op3, const XlaOpTs&... operands) { - return And(op1, And(op2, And(op3, operands...))); -} - -XlaOp Or(XlaOp lhs, XlaOp rhs, - absl::Span broadcast_dimensions = {}); - -// Overload to call Or with 3 or more operands. As with `And`, we need the -// following complicated overload set to handle the default arg in the `Or` -// overload above. -inline XlaOp Or(XlaOp op1, XlaOp op2, XlaOp op3) { - return Or(op1, Or(op2, op3)); -} -template -XlaOp Or(XlaOp op1, XlaOp op2, XlaOp op3, const XlaOpTs&... operands) { - return Or(op1, Or(op2, Or(op3, operands...))); -} - -XlaOp Xor(XlaOp lhs, XlaOp rhs, - absl::Span broadcast_dimensions = {}); - -XlaOp Not(XlaOp operand); - -XlaOp PopulationCount(XlaOp operand); - -XlaOp ShiftLeft(XlaOp lhs, XlaOp rhs, - absl::Span broadcast_dimensions = {}); -XlaOp ShiftRightArithmetic(XlaOp lhs, XlaOp rhs, - absl::Span broadcast_dimensions = {}); -XlaOp ShiftRightLogical(XlaOp lhs, XlaOp rhs, - absl::Span broadcast_dimensions = {}); -// Reduces an array among the provided dimensions, given "computation" as a -// reduction operator. -XlaOp Reduce(XlaOp operand, XlaOp init_value, const XlaComputation& computation, - absl::Span dimensions_to_reduce); - -// Reduces several arrays simultaneously among the provided dimensions, given -// "computation" as a reduction operator. -XlaOp Reduce(XlaBuilder* builder, absl::Span operands, - absl::Span init_values, - const XlaComputation& computation, - absl::Span dimensions_to_reduce); - -// Convenience wrapper around the above that reduces all the dimensions in the -// operand shape. -XlaOp ReduceAll(XlaOp operand, XlaOp init_value, - const XlaComputation& computation); - -// Enqueues a windowed reduce instruction onto the computation. -XlaOp ReduceWindow(XlaOp operand, XlaOp init_value, - const XlaComputation& computation, - absl::Span window_dimensions, - absl::Span window_strides, Padding padding); - -XlaOp ReduceWindow(absl::Span operands, - absl::Span init_values, - const XlaComputation& computation, - absl::Span window_dimensions, - absl::Span window_strides, Padding padding); - -// As ReduceWindow(), but the padding is given in the format -// returned by MakePadding(). -XlaOp ReduceWindowWithGeneralPadding( - XlaOp operand, XlaOp init_value, const XlaComputation& computation, - absl::Span window_dimensions, - absl::Span window_strides, - absl::Span base_dilations, - absl::Span window_dilations, - absl::Span> padding); -XlaOp ReduceWindowWithGeneralPadding( - absl::Span operands, absl::Span init_values, - const XlaComputation& computation, - absl::Span window_dimensions, - absl::Span window_strides, - absl::Span base_dilations, - absl::Span window_dilations, - absl::Span> padding); - -// Returns the sum of the operand value within each subgroup of replicas. All -// replicas supply one input to the sum and all replicas receive the resulting -// sum for each subgroup. -XlaOp CrossReplicaSum(XlaOp operand, - absl::Span replica_groups = {}); - -XlaOp AllGather(XlaOp operand, int64_t all_gather_dimension, - int64_t shard_count, - absl::Span replica_groups = {}, - const std::optional& channel_id = std::nullopt, - const std::optional& layout = std::nullopt, - std::optional use_global_device_ids = std::nullopt); - -XlaOp AllGatherTuple( - absl::Span operands, int64_t all_gather_dimension, - int64_t shard_count, absl::Span replica_groups = {}, - const std::optional& channel_id = std::nullopt, - const std::optional& layout = std::nullopt, - std::optional use_global_device_ids = std::nullopt); - -// Enqueues an operation that do an AllReduce of the operand cross cores. Here -// AllReduce means doing a reduction on the input operand cross cores and then -// broadcasting the reduction result to those cores. The reduction function is -// defined by `computation`, which should be a commutative computation on -// scalars, e.g., add, min, or max. The way that AllReduce is applied is -// configured by: -// -// - `replica_groups`: each ReplicaGroup contains a list of replica id. If -// empty, all replicas belong to one group. Allreduce will be applied within -// subgroups. For example, we have 4 replicas, then replica_groups={{0,2},{1,3}} -// means, replica 0 and 2 are in subgroup 0, replica 1 and 3 are in subgroup 1. -// -// - `channel_id`: for Allreduce nodes from different modules, if they have the -// same channel_id, they will be 'AllReduce'd. If empty, AllReduce will not be -// applied cross modules. -// -// - `shape_with_layout`: forces the layout of the AllReduce to the given -// layout. This is used to guarantee the same layout for a group of AllReduce -// ops compiled separately. -XlaOp AllReduce(XlaOp operand, const XlaComputation& computation, - absl::Span replica_groups = {}, - const std::optional& channel_id = std::nullopt, - const std::optional& shape_with_layout = std::nullopt, - std::optional use_global_device_ids = std::nullopt); - -XlaOp AllReduceTuple( - absl::Span operand, const XlaComputation& computation, - absl::Span replica_groups = {}, - const std::optional& channel_id = std::nullopt, - const std::optional& shape_with_layout = std::nullopt, - std::optional use_global_device_ids = std::nullopt); - -XlaOp ReduceScatter( - XlaOp operand, const XlaComputation& computation, int64_t scatter_dimension, - int64_t shard_count, absl::Span replica_groups = {}, - const std::optional& channel_id = std::nullopt, - const std::optional& layout = std::nullopt, - std::optional use_global_device_ids = std::nullopt); - -// Enqueues an operation that do an AllToAll of the operand cross cores. -// This involves AllToAll, followed by Reshape, Transpose, and another Reshape -// to get proper codegen. See implementation for additional details. -// -// An optional `layout` can be specified to force the layout of the instruction. -// This is used to guarantee the same layout for a group of AllToAll ops -// compiled separately. -XlaOp AllToAll(XlaOp operand, int64_t split_dimension, int64_t concat_dimension, - int64_t split_count, - absl::Span replica_groups = {}, - const std::optional& layout = std::nullopt, - const std::optional& channel_id = std::nullopt); - -XlaOp AllToAllTuple( - absl::Span operand, - absl::Span replica_groups = {}, - const std::optional& layout = std::nullopt, - const std::optional& channel_id = std::nullopt); - -XlaOp AllToAllTuple( - XlaOp operand, int64_t split_dimension, int64_t concat_dimension, - int64_t split_count, absl::Span replica_groups = {}, - const std::optional& layout = std::nullopt, - const std::optional& channel_id = std::nullopt); - -XlaOp CollectiveBroadcast( - XlaOp operand, absl::Span replica_groups, - const std::optional& channel_id = std::nullopt); - -// Enqueues an collective operation that sends and receives data cross replicas. -// -// - `source_target_pair`: a list of (source_replica_id, target_replica_id) -// pairs. For each pair, the operand is sent from source replica to target -// replica. Note that, 1) any two pairs should not have the same target replica -// id, and they should not have the same source replica id; 2) if a replica id -// is not a target in any pair, then the output on that replica is a tensor -// consists of 0(s) with the same shape as the input. -XlaOp CollectivePermute( - XlaOp operand, - const std::vector>& source_target_pairs, - const std::optional& channel_id = std::nullopt); - -// Enqueues an operation that returns the replica ID. -XlaOp ReplicaId(XlaBuilder* builder); - -// Enqueues an operation that scatters the `source` array to the selected -// indices of each window. -XlaOp SelectAndScatter(XlaOp operand, const XlaComputation& select, - absl::Span window_dimensions, - absl::Span window_strides, - Padding padding, XlaOp source, XlaOp init_value, - const XlaComputation& scatter); - -// As SelectAndScatter(), but the padding is given in the format -// returned by MakePadding(). -XlaOp SelectAndScatterWithGeneralPadding( - XlaOp operand, const XlaComputation& select, - absl::Span window_dimensions, - absl::Span window_strides, - absl::Span> padding, XlaOp source, - XlaOp init_value, const XlaComputation& scatter); - -// Enqueues an abs instruction onto the computation. -XlaOp Abs(XlaOp operand); - -// Enqueues a atan2 instruction onto the computation. -XlaOp Atan2(XlaOp y, XlaOp x, - absl::Span broadcast_dimensions = {}); - -// Enqueues an erf instruction onto the computation. -XlaOp Erf(XlaOp operand); - -// Enqueues an exp instruction onto the computation. -XlaOp Exp(XlaOp operand); - -// Enqueues an expm1 instruction onto the computation. -XlaOp Expm1(XlaOp operand); - -// Enqueues a floor instruction onto the computation. -XlaOp Floor(XlaOp operand); - -// Enqueues a ceil instruction onto the computation. -XlaOp Ceil(XlaOp operand); - -// Enqueues a round instruction onto the computation, -// with half-way cases rounding away from zero. -XlaOp Round(XlaOp operand); - -// Enqueues a round instruction onto the computation, rounding to nearest even -XlaOp RoundNearestEven(XlaOp operand); - -// Enqueues an log instruction (natural logarithm) onto the computation. -XlaOp Log(XlaOp operand); - -// Enqueues an log1p instruction (log(x+1)) onto the computation. -XlaOp Log1p(XlaOp operand); - -// Enqueues a logistic instruction onto the computation. -XlaOp Logistic(XlaOp operand); - -// Enqueues a sign instruction onto the computation. -XlaOp Sign(XlaOp operand); - -// Enqueues a count leading zeros instruction onto the computation. -XlaOp Clz(XlaOp operand); - -// Enqueues a cosine instruction onto the computation. -XlaOp Cos(XlaOp operand); - -// Enqueues a sine instruction onto the computation. -XlaOp Sin(XlaOp operand); - -// Enqueues a tan instruction onto the computation. -XlaOp Tan(XlaOp operand); - -// Enqueues a tanh instruction onto the computation. -XlaOp Tanh(XlaOp operand); - -// Enqueues a real-part instruction onto the computation. -XlaOp Real(XlaOp operand); - -// Enqueues an imaginary-part instruction onto the computation. -XlaOp Imag(XlaOp operand); - -// Enqueues a sqrt computation onto the computation. -XlaOp Sqrt(XlaOp operand); - -// Enqueues a cbrt computation onto the computation. -XlaOp Cbrt(XlaOp operand); - -// Enqueues a rsqrt computation onto the computation. -XlaOp Rsqrt(XlaOp operand); - -// Enqueues a lhs^rhs computation onto the computation. -XlaOp Pow(XlaOp lhs, XlaOp rhs, - absl::Span broadcast_dimensions = {}); - -// Enqueues an operator that tests if the operand's values are finite, i.e., not -// +/-Inf or NaN. Returns an array of booleans with the same shape where -// entries are true iff the corresponding entry was not infinite or NaN. -// -// Defined only for real-valued (i.e. not complex) floating-point types; raises -// an error for other types. -// -// See also IsInf, IsPosInf, IsNegInf, and IsNan in lib/math.h. -XlaOp IsFinite(XlaOp operand); - -// Enqueues an iota operation onto the computation. -XlaOp Iota(XlaBuilder* builder, const Shape& shape, int64_t iota_dimension); - -// Enqueues a rank-1 iota operation onto the computation. -XlaOp Iota(XlaBuilder* builder, PrimitiveType type, int64_t size); - -// Enqueues a convert instruction onto the computation that changes the -// element type of the operand array to primitive_type. -XlaOp ConvertElementType(XlaOp operand, PrimitiveType new_element_type); - -// Enqueues a no-op instruction onto the computation that changes -// the element type of the operand array to primitive_type. The -// bit-widths of the source and destination element types must be -// identical. -XlaOp BitcastConvertType(XlaOp operand, PrimitiveType new_element_type); - -// Enqueues a stochastic convert instruction onto the computation that changes -// the element type of the operand array with stochastic rounding to -// primitive_type. -XlaOp StochasticConvertType(XlaOp operand, XlaOp random, - PrimitiveType new_element_type); - -// Enqueues a negate instruction onto the computation. -XlaOp Neg(XlaOp operand); - -// Enqueues a transpose instruction onto the computation. -XlaOp Transpose(XlaOp operand, absl::Span permutation); - -// Enqueues a reverse instruction onto the computation. The order of the -// elements in the given dimensions is reversed (i.e., the element at index i -// is moved to index dimension_size - 1 - i). -XlaOp Rev(XlaOp operand, absl::Span dimensions); - -// Enqueues a sort instruction onto the computation, using 'comparator' for -// comparisons. 'comparator' needs to define a strict weak order. 'is_stable' -// determines whether the stable sorting should be used. -// If only one operand is provided: -// * If the operand is a rank-1 tensor (an array), the result is a sorted array. -// The resulting sorting order has the property that for all index positions -// i, j with i < j, either -// comparator(value[i], value[j]) = comparator(value[j], value[i]) = false or -// comparator(value[i], value[j]) = true. -// * If the operand has higher rank, the operand is sorted along the provided -// dimension. For example, for a rank-2 tensor (a matrix), a dimension value -// of 0 will independently sort every column, and a dimension value of 1 will -// independently sort each row. If no dimension number is provided, then the -// last dimension is chosen by default. For the dimension which is sorted, the -// same sorting order applies as in the rank-1 case. -// -// If more than one operand is provided: -// * All operands must be tensors with the same dimensions. The element types of -// the tensors may be different. -// * The result is a tuple that consists of the operands in sorted order (along -// the provided dimension, as above). The same permutation as implied by the -// comparison computation is applied to all operand tensors. When comparing -// two index positions, 'comparator' is called with 2 * n scalar parameters, -// where parameter 2 * i and 2 * i + 1 correspond to the value of operand i at -// two index positions. -// Default comparator computations can be found in lib/comparators.h -XlaOp Sort(absl::Span operands, const XlaComputation& comparator, - int64_t dimension = -1, bool is_stable = false); - -// Enqueues a topk instruction onto the computation. TopK returns the largest -// 'k' values and their indices along the last dimension of the 'operand' if -// `lagest=true` or the smallest `k` values if `largest=false`. -// -// * If the operand is a rank-1 tensor (an array), the result is a tuple that -// consists of: -// * a sorted array with the top 'k' elements. -// * an array containing the indices of the k elements. -// For example, if the input is [0.1, 0.3, 0.2] and k == 2, the output tuple -// is ([0.3, 0.2], [1, 2]). -// * If the operand has higher rank, the result is a tuple that consists of: -// * a tensor equivalent to one produced by sorting the operand along the last -// dimension and slicing that dimension to only the top 'k' values. The last -// dimension is sorted as in the rank-1 case. -// * a tensor containing the indices of the top 'k' values along the last -// dimension. -// For example, if the input is [0.1, 0.3, 0.2][0.5, 0.4, 0.6] and k == 1, the -// output tuple is ([0.3][0.6], [1][2]). -XlaOp TopK(XlaOp operand, int64_t k, bool largest); - -// Enqueues a clamp instruction onto the computation. -XlaOp Clamp(XlaOp min, XlaOp operand, XlaOp max); - -// Enqueues a map instruction onto the computation. -XlaOp Map(XlaBuilder* builder, absl::Span operands, - const XlaComputation& computation, - absl::Span dimensions, - absl::Span static_operands = {}); - -// Enqueues a N(mu, sigma) random number generation instruction onto the -// computation. -XlaOp RngNormal(XlaOp mu, XlaOp sigma, const Shape& shape); - -// Enqueues a U(a, b) random number generation instruction onto the -// computation. Returns values in the semi-open interval [a, b). -XlaOp RngUniform(XlaOp a, XlaOp b, const Shape& shape); - -// Enqueues a B(initial_state) random bit generation instruction onto the -// computation. Returns the new key and random bits with the specified shape. -XlaOp RngBitGenerator(RandomAlgorithm algorithm, XlaOp initial_state, - const Shape& shape); - -// Enqueues a while node onto the computation. -XlaOp While(const XlaComputation& condition, const XlaComputation& body, - XlaOp init); - -// Enqueues a conditional node onto the computation. -XlaOp Conditional(XlaOp predicate, XlaOp true_operand, - const XlaComputation& true_computation, XlaOp false_operand, - const XlaComputation& false_computation); - -// Enqueues either a predicated (if/else) or indexed (switch/case/default) -// conditional node onto the computation. N >= 1 branch_computations and -// branch_operands are matched by index. branch_index selects the branch that -// will be executed. Out of range branch_index uses the N-1'th -// branch_computation as default. -XlaOp Conditional(XlaOp branch_index, - absl::Span branch_computations, - absl::Span branch_operands); - -// Enqueues a ReducePrecision node onto the computation. -XlaOp ReducePrecision(XlaOp operand, int exponent_bits, int mantissa_bits); - -// Enqueues a Gather node onto the computation. -XlaOp Gather(XlaOp input, XlaOp start_indices, - const GatherDimensionNumbers& dimension_numbers, - absl::Span slice_sizes, - bool indices_are_sorted = false); - -// Enqueues a Scatter node onto the computation. -XlaOp Scatter(XlaOp input, XlaOp scatter_indices, XlaOp updates, - const XlaComputation& update_computation, - const ScatterDimensionNumbers& dimension_numbers, - bool indices_are_sorted = false, bool unique_indices = false); -XlaOp Scatter(absl::Span inputs, XlaOp scatter_indices, - absl::Span updates, - const XlaComputation& update_computation, - const ScatterDimensionNumbers& dimension_numbers, - bool indices_are_sorted = false, bool unique_indices = false); - -// Enqueues a Send node onto the computation for device-to-device -// communication. This operation sends the given operand to -// a Recv instruction in a different computation that shares the same channel -// handle. -void Send(XlaOp operand, const ChannelHandle& handle); - -// Variant of Send which takes a token-shaped operand and produces a -// token-shaped value. Tokens are used for ordering side-effecting operations. -// TODO(b/110532604): Replace all uses of the non-token form with this variant. -XlaOp SendWithToken(XlaOp operand, XlaOp token, const ChannelHandle& handle); - -// Enqueues a Recv node onto the computation for device-to-device -// communication. The data comes from a Send instruction in a different -// computation that shares the same channel handle and its shape must be the -// same as the given shape. -XlaOp Recv(XlaBuilder* builder, const Shape& shape, - const ChannelHandle& handle); - -// Variant of Recv which takes a token-shaped operand and produces a two-element -// tuple containing the data value and a token-shaped value. Tokens are used -// for ordering side-effecting operations. -// TODO(b/110532604): Replace all uses of the non-token form with this variant. -XlaOp RecvWithToken(XlaOp token, const Shape& shape, - const ChannelHandle& handle); - -// Enqueues a Send node which transfers data from the device to the host. The -// 'shape_with_layout' argument defines the layout of the data transferred; its -// shape must be compatible with the shape of the operand. The operand must be -// array-shaped. -// TODO(b/111544877): Support tuple shapes. -XlaOp SendToHost(XlaOp operand, XlaOp token, const Shape& shape_with_layout, - const ChannelHandle& handle); - -// Enqueues a Recv node which transfers data from the host to the device. The -// given shape must contain a layout and must be an array. -// TODO(b/111544877): Support tuple shapes. -XlaOp RecvFromHost(XlaOp token, const Shape& shape, - const ChannelHandle& handle); - -// Enqueues an operation (AfterAll) with no operands that produces a -// token-shaped value. Tokens are used for ordering side-effecting operations. -// This is a separate method from AfterAll to facility the removal of -// operand-less AfterAll instructions. -// TODO(b/110532604): Remove this function when all tokens are derived from a -// single token generated or passed into the entry computation. -XlaOp CreateToken(XlaBuilder* builder); - -// Enqueues an AfterAll instruction which produces a token-shaped value and -// takes a variadic number of token-shaped operands. The number of operands must -// be greater than zero. Used for joining tokens. -XlaOp AfterAll(XlaBuilder* builder, absl::Span tokens); - -// Normalizes operand across spatial and batch dimensions for each feature. -// -// Returns a tuple (normalized, batch_mean, batch_var) where `normalized` -// is the normalized result and batch_mean and batch_var are the mean and -// variance, respectively, across batch for the operand. -XlaOp BatchNormTraining(XlaOp operand, XlaOp scale, XlaOp offset, float epsilon, - int64_t feature_index); - -// Normalizes operand across spatial and batch dimensions for each feature. -// -// `BatchNormInference` is equivalent to calling `BatchNormTraining` without -// computing `mean` and `variance` for each batch inside the operation. It -// uses the input `mean` and `variance` instead as estimated values. The -// purpose of this op is to reduce latency in inference, hence the name -// `BatchNormInference`. -// -// The output has the same shape as `operand`, and contains the normalized -// values for each batch. -XlaOp BatchNormInference(XlaOp operand, XlaOp scale, XlaOp offset, XlaOp mean, - XlaOp variance, float epsilon, int64_t feature_index); - -// Calculates the gradients of a batch norm op. -// -// The inputs `batch_mean` and `batch_var` represent the mean and variance -// across the batch. -// -// Returns a tuple of three elements: -// - grad_operand: Gradient with respect to input `operand` -// - grad_offset: Gradient with respect to input `offset` -// - grad_scale: Gradient with respect to input `scale` -XlaOp BatchNormGrad(XlaOp operand, XlaOp scale, XlaOp batch_mean, - XlaOp batch_var, XlaOp grad_output, float epsilon, - int64_t feature_index); - -// Returns the size of the given dimension of the operand. The operand must be -// array shaped. -XlaOp GetDimensionSize(XlaOp operand, int64_t dimension); - -// Sets the size of the given dimension of the operand. The operand must be -// array shaped. The result will have the same shape as the operand, but the -// given dimension will be dynamic (if not already). -XlaOp SetDimensionSize(XlaOp operand, XlaOp val, int64_t dimension); - -// Returns the same op but with dynamic dimension removed. -XlaOp RemoveDynamicDimension(XlaOp operand, int64_t dimension); - -// Implementation details below this point. -// - -// Free function template implementations. - -template -XlaOp ConstantR0(XlaBuilder* builder, NativeT value) { - return ConstantLiteral(builder, LiteralUtil::CreateR0(value)); -} - -template -XlaOp ConstantR1(XlaBuilder* builder, absl::Span values) { - BorrowingLiteral literal( - reinterpret_cast(values.begin()), - ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType(), - {static_cast(values.size())})); - return ConstantLiteral(builder, literal); -} - -template -XlaOp ConstantR1(XlaBuilder* builder, int64_t length, NativeT value) { - Literal literal(ShapeUtil::MakeShape( - primitive_util::NativeToPrimitiveType(), {length})); - literal.PopulateWithValue(value); - return ConstantLiteral(builder, literal); -} - -inline XlaOp ConstantR1(XlaBuilder* builder, const tsl::core::Bitmap& values) { - return ConstantLiteral(builder, LiteralUtil::CreateR1(values)); -} - -template -XlaOp ConstantR2(XlaBuilder* builder, - std::initializer_list> values) { - return ConstantLiteral(builder, LiteralUtil::CreateR2(values)); -} - -template -XlaOp ConstantFromArrayWithLayout(XlaBuilder* builder, - const Array& values, - const Layout& layout) { - return ConstantLiteral( - builder, LiteralUtil::CreateFromArrayWithLayout(values, layout)); -} - -template -XlaOp ConstantFromArray(XlaBuilder* builder, const Array& values) { - return ConstantLiteral(builder, - LiteralUtil::CreateFromArray(values)); -} - -template -XlaOp ConstantR2FromArray2DWithLayout(XlaBuilder* builder, - const Array2D& values, - const Layout& layout) { - return ConstantLiteral( - builder, LiteralUtil::CreateFromArrayWithLayout(values, layout)); -} - -template -XlaOp ConstantR2FromArray2D(XlaBuilder* builder, - const Array2D& values) { - return ConstantLiteral(builder, - LiteralUtil::CreateR2FromArray2D(values)); -} - -template -XlaOp ConstantR3FromArray3DWithLayout(XlaBuilder* builder, - const Array3D& values, - const Layout& layout) { - return ConstantLiteral( - builder, - LiteralUtil::CreateR3FromArray3DWithLayout(values, layout)); -} - -template -XlaOp ConstantR3FromArray3D(XlaBuilder* builder, - const Array3D& values) { - return ConstantFromArray(builder, values); -} - -template -XlaOp ConstantR4FromArray4DWithLayout(XlaBuilder* builder, - const Array4D& values, - const Layout& layout) { - return ConstantFromArrayWithLayout(builder, values, layout); -} - -template -XlaOp ConstantR4FromArray4D(XlaBuilder* builder, - const Array4D& values) { - return ConstantFromArray(builder, values); -} - -// Switches from automatic SPMD partitioning to manual partitioning. Converts a -// full-shaped tensor (to be automatically partitioned by SPMD partitioner) to a -// shard-shaped tensor to be consumed by manually partitioned ops. -absl::StatusOr ConvertSpmdFullToShardShape( - xla::XlaBuilder* builder, xla::XlaOp input, int single_dim, - const xla::OpSharding& manual_sharding, - absl::Span unspecified_dims); - -// Switches from manual partitioning to automatic SPMD partitioning. Converts a -// shard-shaped tensor (manually partitioned in SPMD-style) to a full-shaped -// tensor to be partitioned automatically by the SPMD partitioner. -absl::StatusOr ConvertSpmdShardToFullShape( - xla::XlaBuilder* builder, xla::XlaOp input, const xla::Shape& output_shape, - int single_dim, const xla::OpSharding& manual_sharding, - absl::Span unspecified_dims); - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/builder/xla_builder.h" #endif // XLA_CLIENT_XLA_BUILDER_H_ diff --git a/third_party/xla/xla/client/xla_computation.h b/third_party/xla/xla/client/xla_computation.h index 52a54aa113b178..685fcfecb0b093 100644 --- a/third_party/xla/xla/client/xla_computation.h +++ b/third_party/xla/xla/client/xla_computation.h @@ -16,58 +16,7 @@ limitations under the License. #ifndef XLA_CLIENT_XLA_COMPUTATION_H_ #define XLA_CLIENT_XLA_COMPUTATION_H_ -#include -#include -#include - -#include "absl/status/statusor.h" -#include "xla/service/hlo.pb.h" -#include "xla/shape.h" -#include "xla/status_macros.h" -#include "xla/xla_data.pb.h" - -namespace xla { - -// The computation graph that the user builds up with the XlaBuilder. -class XlaComputation { - public: - XlaComputation() : unique_id_(-1) {} - XlaComputation(HloModuleProto proto) - : unique_id_(proto.id()), proto_(std::move(proto)) {} - - ~XlaComputation() = default; - - XlaComputation(const XlaComputation&) = delete; - XlaComputation& operator=(const XlaComputation&) = delete; - - XlaComputation(XlaComputation&& from) = default; - - XlaComputation& operator=(XlaComputation&& from) = default; - - // Returns the "program shape" (parameter and return shapes) for this - // computation. - absl::StatusOr GetProgramShape() const; - - const std::string& name() const { return proto().name(); } - - const HloModuleProto& proto() const { return proto_; } - HloModuleProto* mutable_proto() { return &proto_; } - - // Requests that we snapshot the computation into a serializable protocol - // buffer form. - absl::StatusOr> Snapshot() const; - - // Returns true if this object is a null Computation. - bool IsNull() const { return unique_id_ == -1; } - - private: - XlaComputation(const int64_t unique_id) : unique_id_(unique_id) {} - friend class XlaBuilder; - - int64_t unique_id_; - HloModuleProto proto_; -}; - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/builder/xla_computation.h" #endif // XLA_CLIENT_XLA_COMPUTATION_H_ diff --git a/third_party/xla/xla/hlo/builder/BUILD b/third_party/xla/xla/hlo/builder/BUILD new file mode 100644 index 00000000000000..783e358f4d42d7 --- /dev/null +++ b/third_party/xla/xla/hlo/builder/BUILD @@ -0,0 +1,190 @@ +# Description: +# XLA builder libraries. + +load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") +load("//xla:xla.bzl", "xla_cc_test") +load("//xla/tsl:tsl.default.bzl", "filegroup") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = ["//visibility:public"], + licenses = ["notice"], +) + +package_group( + name = "friends", + includes = [ + "//xla:friends", + ], +) + +# Filegroup used to collect source files for dependency checking. +filegroup( + name = "c_srcs", + data = glob([ + "**/*.cc", + "**/*.h", + ]), +) + +cc_library( + name = "padding", + srcs = ["padding.cc"], + hdrs = ["padding.h"], + deps = [ + "//xla:util", + "//xla/tsl/lib/math:math_util", + "@com_google_absl//absl/status", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:status", + ], +) + +xla_cc_test( + name = "padding_test", + srcs = ["padding_test.cc"], + deps = [ + ":padding", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "sharding_builder", + srcs = ["sharding_builder.cc"], + hdrs = ["sharding_builder.h"], + deps = [ + "//xla:array", + "//xla:shape_tree", + "//xla:shape_util", + "//xla:util", + "//xla:xla_data_proto_cc", + "@com_google_absl//absl/log:check", + ], +) + +cc_library( + name = "xla_computation", + srcs = ["xla_computation.cc"], + hdrs = ["xla_computation.h"], + visibility = ["//visibility:public"], + deps = [ + "//xla:shape_util", + "//xla:status_macros", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/service:hlo_proto_cc", + "@com_google_absl//absl/status:statusor", + ], +) + +cc_library( + name = "value_inference", + srcs = ["value_inference.cc"], + hdrs = ["value_inference.h"], + visibility = ["//visibility:public"], + deps = [ + ":xla_builder", + "//xla:comparison_util", + "//xla:literal", + "//xla:literal_util", + "//xla:shape_util", + "//xla:status_macros", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/evaluator:hlo_evaluator", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_module_config", + "//xla/service:hlo_proto_cc", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], +) + +cc_library( + name = "xla_builder", + srcs = ["xla_builder.cc"], + hdrs = ["xla_builder.h"], + visibility = ["//visibility:public"], + deps = [ + ":padding", + ":sharding_builder", + ":xla_computation", + "//xla:array", + "//xla:array2d", + "//xla:array3d", + "//xla:array4d", + "//xla:comparison_util", + "//xla:literal", + "//xla:literal_util", + "//xla:permutation_util", + "//xla:shape_util", + "//xla:sharding_op_util", + "//xla:status_macros", + "//xla:util", + "//xla:window_util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_proto_cc", + "//xla/service:shape_inference", + "//xla/tsl/lib/core:bitmap", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/functional:function_ref", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:stacktrace", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "xla_builder_test", + srcs = ["xla_builder_test.cc"], + deps = [ + ":padding", + ":sharding_builder", + ":value_inference", + ":xla_builder", + ":xla_computation", + "//xla:comparison_util", + "//xla:debug_options_flags", + "//xla:shape_util", + "//xla:test", + "//xla:test_helpers", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_parser", + "//xla/service:hlo_proto_cc", + "//xla/service:pattern_matcher", + "//xla/service:pattern_matcher_gmock", + "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest", + "@local_tsl//tsl/platform:status_matchers", + "@local_tsl//tsl/platform:statusor", + ], +) diff --git a/third_party/xla/xla/hlo/builder/lib/BUILD b/third_party/xla/xla/hlo/builder/lib/BUILD new file mode 100644 index 00000000000000..d523c4a8af2ec7 --- /dev/null +++ b/third_party/xla/xla/hlo/builder/lib/BUILD @@ -0,0 +1,787 @@ +# Common computation builders for XLA. + +load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") +load("//xla/tests:build_defs.bzl", "generate_backend_suites", "xla_test") +load("//xla/tsl:tsl.bzl", "internal_visibility") +load("//xla/tsl:tsl.default.bzl", "filegroup") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = internal_visibility(["//xla/hlo/builder:friends"]), + licenses = ["notice"], +) + +# Filegroup used to collect source files for dependency checking. +filegroup( + name = "c_srcs", + data = glob([ + "**/*.cc", + "**/*.h", + ]), +) + +# Generate test_suites for all backends, named "${backend}_tests". +generate_backend_suites() + +cc_library( + name = "arithmetic", + srcs = ["arithmetic.cc"], + hdrs = ["arithmetic.h"], + deps = [ + ":constants", + "//xla:shape_util", + "//xla:xla_data_proto_cc", + "//xla/hlo/builder:xla_builder", + "//xla/hlo/builder:xla_computation", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_test( + name = "arithmetic_test", + srcs = ["arithmetic_test.cc"], + deps = [ + ":arithmetic", + "//xla:shape_util", + "//xla:xla_data_proto_cc", + "//xla/hlo/builder:xla_builder", + "//xla/tests:client_library_test_base", + "//xla/tests:test_macros_header", + "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "comparators", + srcs = ["comparators.cc"], + hdrs = [ + "comparators.h", + ], + deps = [ + "//xla:shape_util", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/builder:xla_builder", + "//xla/hlo/builder:xla_computation", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + ], +) + +xla_test( + name = "comparators_test", + srcs = ["comparators_test.cc"], + deps = [ + ":comparators", + ":constants", + "//xla:shape_util", + "//xla:test", + "//xla:xla_data_proto_cc", + "//xla/hlo/builder:xla_builder", + "//xla/hlo/builder:xla_computation", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_proto_cc", + "//xla/tests:client_library_test_base", + "//xla/tests:test_macros_header", + "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:protobuf", + ], +) + +cc_library( + name = "constants", + srcs = ["constants.cc"], + hdrs = ["constants.h"], + deps = [ + "//xla:literal_util", + "//xla:shape_util", + "//xla:types", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/builder:xla_builder", + "@com_google_absl//absl/status:statusor", + "@local_tsl//tsl/platform:ml_dtypes", + "@local_tsl//tsl/platform:statusor", + ], +) + +cc_library( + name = "broadcast", + srcs = ["broadcast.cc"], + hdrs = ["broadcast.h"], + deps = [ + "//xla:shape_util", + "//xla:status_macros", + "//xla:types", + "//xla:xla_data_proto_cc", + "//xla/hlo/builder:xla_builder", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_test( + name = "constants_test", + srcs = ["constants_test.cc"], + deps = [ + ":constants", + "//xla:shape_util", + "//xla:test", + "//xla:xla_data_proto_cc", + "//xla/hlo/builder:xla_builder", + "//xla/tests:client_library_test_base", + "//xla/tests:test_macros_header", + "//xla/tests:xla_internal_test_main", + ], +) + +cc_library( + name = "conv_grad_size_util", + srcs = ["conv_grad_size_util.cc"], + hdrs = ["conv_grad_size_util.h"], + deps = [ + "//xla/hlo/builder:padding", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status:statusor", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], +) + +cc_library( + name = "dynamic_shaped_ops", + srcs = ["dynamic_shaped_ops.cc"], + hdrs = ["dynamic_shaped_ops.h"], + deps = [ + ":constants", + "//xla:shape_util", + "//xla:status_macros", + "//xla:types", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/builder:value_inference", + "//xla/hlo/builder:xla_builder", + "//xla/hlo/builder:xla_computation", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], +) + +cc_library( + name = "loops", + srcs = ["loops.cc"], + hdrs = ["loops.h"], + deps = [ + ":constants", + "//xla:literal_util", + "//xla:shape_util", + "//xla:status_macros", + "//xla:xla_data_proto_cc", + "//xla/hlo/builder:xla_builder", + "//xla/hlo/builder:xla_computation", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], +) + +cc_library( + name = "math", + srcs = ["math.cc"], + hdrs = [ + "math.h", + "math_impl.h", + ], + deps = [ + ":arithmetic", + ":constants", + ":loops", + "//xla:shape_util", + "//xla:status_macros", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/builder:xla_builder", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_test( + name = "math_test", + timeout = "long", + srcs = ["math_test.cc"], + backend_tags = { + # Times out. + "ghostfish_iss": ["noasan"], + }, + deps = [ + ":constants", + ":math", + "//xla:array3d", + "//xla:error_spec", + "//xla:literal", + "//xla:literal_util", + "//xla:shape_util", + "//xla:test", + "//xla:types", + "//xla:xla_data_proto_cc", + "//xla/hlo/builder:xla_builder", + "//xla/service", + "//xla/tests:client_library_test_base", + "//xla/tests:test_macros_header", + "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "matrix", + srcs = ["matrix.cc"], + hdrs = ["matrix.h"], + deps = [ + ":arithmetic", + ":constants", + ":slicing", + "//xla:literal", + "//xla:shape_util", + "//xla:status_macros", + "//xla:types", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/builder:xla_builder", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_test( + name = "matrix_test", + srcs = ["matrix_test.cc"], + deps = [ + ":constants", + ":matrix", + ":slicing", + "//xla:array", + "//xla:array2d", + "//xla:array3d", + "//xla:array4d", + "//xla:test", + "//xla:types", + "//xla/hlo/builder:xla_builder", + "//xla/tests:client_library_test_base", + "//xla/tests:test_macros_header", + "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "pooling", + srcs = ["pooling.cc"], + hdrs = ["pooling.h"], + deps = [ + ":arithmetic", + ":constants", + ":conv_grad_size_util", + "//xla:shape_util", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/builder:padding", + "//xla/hlo/builder:xla_builder", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_test( + name = "pooling_test", + srcs = ["pooling_test.cc"], + deps = [ + ":pooling", + "//xla:error_spec", + "//xla:shape_util", + "//xla/hlo/builder:padding", + "//xla/hlo/builder:xla_builder", + "//xla/tests:client_library_test_base", + "//xla/tests:test_macros_header", + "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "prng", + srcs = ["prng.cc"], + hdrs = ["prng.h"], + deps = [ + ":constants", + "//xla:shape_util", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/builder:xla_builder", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_test( + name = "prng_test", + srcs = ["prng_test.cc"], + deps = [ + ":constants", + ":prng", + "//xla:shape_util", + "//xla:test", + "//xla:xla_data_proto_cc", + "//xla/hlo/builder:xla_builder", + "//xla/tests:client_library_test_base", + "//xla/tests:test_macros_header", + "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + ], +) + +cc_library( + name = "qr", + srcs = ["qr.cc"], + hdrs = ["qr.h"], + deps = [ + ":constants", + ":matrix", + ":slicing", + "//xla:shape_util", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/builder:xla_builder", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_test( + name = "qr_test", + srcs = ["qr_test.cc"], + tags = ["optonly"], + deps = [ + ":matrix", + ":qr", + "//xla:array", + "//xla:array2d", + "//xla:array3d", + "//xla:error_spec", + "//xla:shape_util", + "//xla:test", + "//xla:types", + "//xla:xla_data_proto_cc", + "//xla/hlo/builder:xla_builder", + "//xla/tests:client_library_test_base", + "//xla/tests:test_macros_header", + "//xla/tests:xla_internal_test_main", + "@local_tsl//tsl/platform:statusor", + ], +) + +cc_library( + name = "lu_decomposition", + srcs = ["lu_decomposition.cc"], + hdrs = ["lu_decomposition.h"], + deps = [ + "//xla:shape_util", + "//xla:status_macros", + "//xla:xla_data_proto_cc", + "//xla/hlo/builder:xla_builder", + "@com_google_absl//absl/status:statusor", + "@local_tsl//tsl/platform:statusor", + ], +) + +cc_library( + name = "approx_topk", + srcs = ["approx_topk.cc"], + hdrs = ["approx_topk.h"], + deps = [ + ":approx_topk_shape", + "//xla:shape_util", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/builder:xla_builder", + "//xla/hlo/builder:xla_computation", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "approx_topk_shape", + srcs = ["approx_topk_shape.cc"], + hdrs = ["approx_topk_shape.h"], + deps = [ + "//xla:util", + "@com_google_absl//absl/status:statusor", + ], +) + +cc_library( + name = "slicing", + srcs = ["slicing.cc"], + hdrs = ["slicing.h"], + deps = [ + ":arithmetic", + ":constants", + "//xla:shape_util", + "//xla:status_macros", + "//xla:types", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/builder:xla_builder", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_test( + name = "slicing_test", + srcs = ["slicing_test.cc"], + deps = [ + ":slicing", + "//xla:array2d", + "//xla:array3d", + "//xla:error_spec", + "//xla:literal", + "//xla:shape_util", + "//xla:xla_data_proto_cc", + "//xla/hlo/builder:xla_builder", + "//xla/tests:client_library_test_base", + "//xla/tests:test_macros_header", + "//xla/tests:xla_internal_test_main", + "@local_tsl//tsl/platform:statusor", + ], +) + +cc_library( + name = "sorting", + srcs = ["sorting.cc"], + hdrs = ["sorting.h"], + deps = [ + ":comparators", + ":constants", + ":loops", + ":slicing", + "//xla:shape_util", + "//xla:types", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/builder:xla_builder", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_test( + name = "sorting_test", + srcs = ["sorting_test.cc"], + deps = [ + ":sorting", + "//xla:array", + "//xla:array2d", + "//xla:error_spec", + "//xla:literal_util", + "//xla:xla_data_proto_cc", + "//xla/hlo/builder:xla_builder", + "//xla/tests:client_library_test_base", + "//xla/tests:test_macros_header", + "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/algorithm:container", + ], +) + +cc_library( + name = "quantize", + hdrs = ["quantize.h"], + deps = [ + ":constants", + "//xla:types", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/builder:xla_builder", + "@local_tsl//tsl/platform:bfloat16", + ], +) + +xla_test( + name = "quantize_test", + srcs = ["quantize_test.cc"], + # TODO(b/122119490): re-enable TAP after fixing. + tags = [ + "manual", + "notap", + ], + deps = [ + ":quantize", + "//xla:array2d", + "//xla:test", + "//xla:types", + "//xla:util", + "//xla/hlo/builder:xla_builder", + "//xla/tests:client_library_test_base", + "//xla/tests:test_macros_header", + "//xla/tests:xla_internal_test_main", + "@local_tsl//tsl/platform:bfloat16", + ], +) + +cc_library( + name = "self_adjoint_eig", + srcs = ["self_adjoint_eig.cc"], + hdrs = ["self_adjoint_eig.h"], + deps = [ + ":slicing", + "//xla:shape_util", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/builder:xla_builder", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_test( + name = "self_adjoint_eig_test", + srcs = ["self_adjoint_eig_test.cc"], + real_hardware_only = True, + shard_count = 5, + tags = ["optonly"], + deps = [ + ":arithmetic", + ":constants", + ":math", + ":matrix", + ":self_adjoint_eig", + "//xla:array", + "//xla:array2d", + "//xla:array3d", + "//xla:error_spec", + "//xla:shape_util", + "//xla:test", + "//xla:types", + "//xla:xla_data_proto_cc", + "//xla/hlo/builder:xla_builder", + "//xla/tests:client_library_test_base", + "//xla/tests:test_macros_header", + "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "svd", + srcs = ["svd.cc"], + hdrs = ["svd.h"], + deps = [ + ":arithmetic", + ":comparators", + ":constants", + ":loops", + ":math", + ":matrix", + ":slicing", + "//xla:shape_util", + "//xla:xla_data_proto_cc", + "//xla/hlo/builder:xla_builder", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_test( + name = "svd_test", + srcs = ["svd_test.cc"], + real_hardware_only = True, + shard_count = 10, + tags = ["optonly"], + deps = [ + ":arithmetic", + ":constants", + ":matrix", + ":slicing", + ":svd", + "//xla:array2d", + "//xla:array3d", + "//xla:error_spec", + "//xla:shape_util", + "//xla:xla_data_proto_cc", + "//xla/hlo/builder:xla_builder", + "//xla/tests:client_library_test_base", + "//xla/tests:test_macros_header", + "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/status:statusor", + ], +) + +cc_library( + name = "tridiagonal", + srcs = ["tridiagonal.cc"], + hdrs = ["tridiagonal.h"], + deps = [ + ":constants", + ":loops", + ":slicing", + "//xla:shape_util", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/builder:xla_builder", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_test( + name = "tridiagonal_test", + srcs = ["tridiagonal_test.cc"], + real_hardware_only = True, + shard_count = 10, + tags = ["optonly"], + deps = [ + ":slicing", + ":tridiagonal", + "//xla:array", + "//xla:array3d", + "//xla:literal", + "//xla:shape_util", + "//xla:test", + "//xla:util", + "//xla/hlo/builder:xla_builder", + "//xla/tests:client_library_test_base", + "//xla/tests:test_macros_header", + "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/status", + "@local_tsl//tsl/platform:statusor", + ], +) + +cc_library( + name = "logdet", + srcs = ["logdet.cc"], + hdrs = ["logdet.h"], + deps = [ + ":arithmetic", + ":constants", + ":matrix", + ":qr", + ":slicing", + "//xla:shape_util", + "//xla:util", + "//xla/hlo/builder:xla_builder", + "@com_google_absl//absl/status:statusor", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_test( + name = "logdet_test", + srcs = ["logdet_test.cc"], + tags = [ + "optonly", + ], + deps = [ + ":logdet", + "//xla:array", + "//xla:array2d", + "//xla:array3d", + "//xla:error_spec", + "//xla:literal", + "//xla:literal_util", + "//xla/hlo/builder:xla_builder", + "//xla/tests:client_library_test_base", + "//xla/tests:test_macros_header", + "//xla/tests:xla_internal_test_main", + ], +) + +cc_library( + name = "tuple", + srcs = ["tuple.cc"], + hdrs = ["tuple.h"], + deps = [ + "//xla:shape_tree", + "//xla:shape_util", + "//xla/hlo/builder:xla_builder", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/status:statusor", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_test( + name = "tuple_test", + srcs = ["tuple_test.cc"], + deps = [ + ":tuple", + "//xla:error_spec", + "//xla:literal", + "//xla:literal_util", + "//xla:shape_tree", + "//xla:shape_util", + "//xla:xla_data_proto_cc", + "//xla/hlo/builder:xla_builder", + "//xla/service", + "//xla/tests:client_library_test_base", + "//xla/tests:xla_internal_test_main", + "@local_tsl//tsl/platform:statusor", + ], +) diff --git a/third_party/xla/xla/client/lib/approx_topk.cc b/third_party/xla/xla/hlo/builder/lib/approx_topk.cc similarity index 98% rename from third_party/xla/xla/client/lib/approx_topk.cc rename to third_party/xla/xla/hlo/builder/lib/approx_topk.cc index 7a5c7bd379cb82..16e9c090e9dd3b 100644 --- a/third_party/xla/xla/client/lib/approx_topk.cc +++ b/third_party/xla/xla/hlo/builder/lib/approx_topk.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/client/lib/approx_topk.h" +#include "xla/hlo/builder/lib/approx_topk.h" #include #include @@ -23,9 +23,9 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_format.h" #include "absl/types/span.h" -#include "xla/client/lib/approx_topk_shape.h" -#include "xla/client/xla_builder.h" -#include "xla/client/xla_computation.h" +#include "xla/hlo/builder/lib/approx_topk_shape.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/util.h" diff --git a/third_party/xla/xla/hlo/builder/lib/approx_topk.h b/third_party/xla/xla/hlo/builder/lib/approx_topk.h new file mode 100644 index 00000000000000..f940d26967cc76 --- /dev/null +++ b/third_party/xla/xla/hlo/builder/lib/approx_topk.h @@ -0,0 +1,72 @@ +/* Copyright 2021 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_BUILDER_LIB_APPROX_TOPK_H_ +#define XLA_HLO_BUILDER_LIB_APPROX_TOPK_H_ + +#include "absl/types/span.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/builder/xla_computation.h" +#include "xla/xla_data.pb.h" + +namespace xla { + +// Computes approximate top-ks by aggregating top-1s in equal-sized windows. +// The number and the size of the windows are determined by the `recall_target`. +// +// operand: A sequence of multi-dimensional arrays of type T_0, ..., T_{N-1} +// init_values: N starting values for top-1 reductions +// top_k: Determines the k in top-k operation. +// reduction_dim: Determines the dimension to compute top-k. +// comparator: The comparator computation to use, which should have function +// signatore of (T_0, T_0, T_1, T_1, ..., T_{N-1}, T_{N-1}) -> bool. +// recall_target: Valid range (0, 1]. User can trade-off quality and performance +// with this knob. +// aggregate_to_topk: When true, sorts the set of approximate top-k elements and +// only keep the final k elements on TPU. This option is useful when user +// wanted to forward the approximate results to host and aggregate the results +// on CPU for better throughput. +// reduction_input_size_override: When set to a positive value, it overrides the +// size determined by operands[reduction_dim] for evaluating the recall. This +// option is useful when the given operand is only a subset of the overall +// computation in SPMD or distributed pipelines, where the true input size +// cannot be deferred by the operand shape. +// +// Returns a sequence of multidimensional arrays of type T_0, ..., T_{N-1}, +// which contains the approximate top-ks from the input operands. When +// `aggregate_to_topk` is set to true, the output size is just top_k. When +// `aggregate_to_topk` is set to false, the output size varied by the target +// recall. For target recall = 0.9, the output size is roughly 10 * top_k. For +// target recall = 0.99, the output size is roughly 100 * top_k. +// +// TODO(fchern): Support other hardware platforms. +XlaOp ApproxTopK(XlaBuilder* builder, absl::Span operands, + absl::Span init_values, int64_t top_k, + int64_t reduction_dim, const XlaComputation& comparator, + float recall_target = 0.9, bool aggregate_to_topk = true, + int64_t reduction_input_size_override = -1); + +// Fallback for platforms that haven't been optimized. +XlaOp ApproxTopKFallback(XlaBuilder* builder, absl::Span operands, + absl::Span init_values, int64_t top_k, + int64_t reduction_dim, + const XlaComputation& comparator, + float recall_target = 0.9, + bool aggregate_to_topk = true, + int64_t reduction_input_size_override = -1); + +} // namespace xla + +#endif // XLA_HLO_BUILDER_LIB_APPROX_TOPK_H_ diff --git a/third_party/xla/xla/client/lib/approx_topk_shape.cc b/third_party/xla/xla/hlo/builder/lib/approx_topk_shape.cc similarity index 98% rename from third_party/xla/xla/client/lib/approx_topk_shape.cc rename to third_party/xla/xla/hlo/builder/lib/approx_topk_shape.cc index 374aa01830fccf..f6925f330c1267 100644 --- a/third_party/xla/xla/client/lib/approx_topk_shape.cc +++ b/third_party/xla/xla/hlo/builder/lib/approx_topk_shape.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/client/lib/approx_topk_shape.h" +#include "xla/hlo/builder/lib/approx_topk_shape.h" #include #include diff --git a/third_party/xla/xla/hlo/builder/lib/approx_topk_shape.h b/third_party/xla/xla/hlo/builder/lib/approx_topk_shape.h new file mode 100644 index 00000000000000..83b2b71d1054e5 --- /dev/null +++ b/third_party/xla/xla/hlo/builder/lib/approx_topk_shape.h @@ -0,0 +1,50 @@ +/* Copyright 2022 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_BUILDER_LIB_APPROX_TOPK_SHAPE_H_ +#define XLA_HLO_BUILDER_LIB_APPROX_TOPK_SHAPE_H_ + +#include + +#include "absl/status/statusor.h" + +namespace xla { + +// Determine the output size of the reduction dimension. This is useful for jax +// abstract eval to determine the output size. +// +// input_size: Input size of the reduction dimension. +// rank: Rank of the input operand. +// top_k: Determines the k in top-k operation. +// recall_target: Valid range (0, 1]. User can trade-off quality and performance +// with this knob. +// aggregate_to_topk: When true, sorts the set of approximate top-k elements and +// only keep the final k elements on TPU. This option is useful when user +// wanted to forward the approximate results to host and aggregate the results +// on CPU for better throughput. +// +// Returns a pair of +// 1. Reduction output size +// 2. Reduction amount in log2 form. +// +// 2. is invalid and set to -1 when the approximate output is disabled, i.e. +// top_k = 1 or aggregate_to_topk = true. +absl::StatusOr> ApproxTopKReductionOutputSize( + int64_t input_size, int64_t rank, int64_t top_k, float recall_target, + bool aggregate_to_topk, int64_t input_size_override = -1); + +} // namespace xla + +#endif // XLA_HLO_BUILDER_LIB_APPROX_TOPK_SHAPE_H_ diff --git a/third_party/xla/xla/client/lib/arithmetic.cc b/third_party/xla/xla/hlo/builder/lib/arithmetic.cc similarity index 97% rename from third_party/xla/xla/client/lib/arithmetic.cc rename to third_party/xla/xla/hlo/builder/lib/arithmetic.cc index e14bd9118def05..6ec14f7dd31d43 100644 --- a/third_party/xla/xla/client/lib/arithmetic.cc +++ b/third_party/xla/xla/hlo/builder/lib/arithmetic.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/client/lib/arithmetic.h" +#include "xla/hlo/builder/lib/arithmetic.h" #include #include @@ -23,9 +23,9 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" -#include "xla/client/lib/constants.h" -#include "xla/client/xla_builder.h" -#include "xla/client/xla_computation.h" +#include "xla/hlo/builder/lib/constants.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/primitive_util.h" #include "xla/shape.h" #include "xla/shape_util.h" diff --git a/third_party/xla/xla/hlo/builder/lib/arithmetic.h b/third_party/xla/xla/hlo/builder/lib/arithmetic.h new file mode 100644 index 00000000000000..fda730573f37f8 --- /dev/null +++ b/third_party/xla/xla/hlo/builder/lib/arithmetic.h @@ -0,0 +1,90 @@ +/* Copyright 2017 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_BUILDER_LIB_ARITHMETIC_H_ +#define XLA_HLO_BUILDER_LIB_ARITHMETIC_H_ + +#include +#include +#include + +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/builder/xla_computation.h" +#include "xla/xla_data.pb.h" + +namespace xla { + +using XlaOpGenerator = std::function; + +// Creates a scalar computation based on a lambda and returns it. +XlaComputation CreateScalarComputation(const std::string& name, + PrimitiveType type, XlaBuilder* builder, + XlaOpGenerator generator); + +// Creates a scalar add computation and returns it. +XlaComputation CreateScalarAddComputation(PrimitiveType type, + XlaBuilder* builder); + +// Creates a scalar multiply computation and returns it. +XlaComputation CreateScalarMultiplyComputation(PrimitiveType type, + XlaBuilder* builder); + +// Creates a scalar ge computation and returns it. +XlaComputation CreateScalarGeComputation(PrimitiveType type, + XlaBuilder* builder); + +// Creates a scalar max computation and returns it. +XlaComputation CreateScalarMaxComputation(PrimitiveType type, + XlaBuilder* builder); + +// Creates a scalar min computation and returns it. +XlaComputation CreateScalarMinComputation(PrimitiveType type, + XlaBuilder* builder); + +// Creates a scalar logical AND computation and returns it. +XlaComputation CreateScalarAndComputation(PrimitiveType type, + XlaBuilder* builder); + +// Creates a scalar logical OR computation and returns it. +XlaComputation CreateScalarOrComputation(PrimitiveType type, + XlaBuilder* builder); + +// This is to be used for general purpose "identity" like reductions with zero +// for any type (ie. boolean operations for PRED and Add for real numbers). +// As an example, this operation can be used for a situation of: +// x_type = type(x) +// op = CreateScalarIdentityWithZeroComputation(x_type) +// ASSERT_TRUE(op(x, 0) == x) +// +// This functionality is used for operations that are similar to a slice, +// gather, or broadcast, but are created through a reduction. +XlaComputation CreateScalarIdentityWithZeroComputation(PrimitiveType type, + XlaBuilder* builder); + +// Returns whether any predicate in "predicates" is set. +// +// Note: if predicates is zero-sized, Any() vacuously returns false. +XlaOp Any(XlaOp predicates); + +// Returns the argmax of `input` along `axis`. `output_type` is the type to +// use for the output. In case of ties always prefers smaller index. +XlaOp ArgMax(XlaOp input, PrimitiveType output_type, int axis); + +// Dispatch to ArgMin or ArgMax above, depending on bool. +XlaOp ArgMinMax(XlaOp input, PrimitiveType output_type, int axis, bool is_min); + +} // namespace xla + +#endif // XLA_HLO_BUILDER_LIB_ARITHMETIC_H_ diff --git a/third_party/xla/xla/client/lib/arithmetic_test.cc b/third_party/xla/xla/hlo/builder/lib/arithmetic_test.cc similarity index 97% rename from third_party/xla/xla/client/lib/arithmetic_test.cc rename to third_party/xla/xla/hlo/builder/lib/arithmetic_test.cc index abbdf06fb8b731..3cde6bf0f4e5c3 100644 --- a/third_party/xla/xla/client/lib/arithmetic_test.cc +++ b/third_party/xla/xla/hlo/builder/lib/arithmetic_test.cc @@ -13,13 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/client/lib/arithmetic.h" +#include "xla/hlo/builder/lib/arithmetic.h" #include #include #include "absl/types/span.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/primitive_util.h" #include "xla/tests/client_library_test_base.h" #include "xla/tests/test_macros.h" diff --git a/third_party/xla/xla/client/lib/broadcast.cc b/third_party/xla/xla/hlo/builder/lib/broadcast.cc similarity index 97% rename from third_party/xla/xla/client/lib/broadcast.cc rename to third_party/xla/xla/hlo/builder/lib/broadcast.cc index 8c3336ef9e312e..aaabe046cebb02 100644 --- a/third_party/xla/xla/client/lib/broadcast.cc +++ b/third_party/xla/xla/hlo/builder/lib/broadcast.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/client/lib/broadcast.h" +#include "xla/hlo/builder/lib/broadcast.h" #include @@ -21,7 +21,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_join.h" #include "absl/types/span.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/status_macros.h" diff --git a/third_party/xla/xla/hlo/builder/lib/broadcast.h b/third_party/xla/xla/hlo/builder/lib/broadcast.h new file mode 100644 index 00000000000000..86cf39f64ddc82 --- /dev/null +++ b/third_party/xla/xla/hlo/builder/lib/broadcast.h @@ -0,0 +1,35 @@ +/* Copyright 2021 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_BUILDER_LIB_BROADCAST_H_ +#define XLA_HLO_BUILDER_LIB_BROADCAST_H_ + +#include "absl/status/statusor.h" +#include "absl/types/span.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/primitive_util.h" +#include "xla/types.h" +#include "xla/xla_data.pb.h" + +namespace xla { + +// Broadcasts 'input' up to shape 'output_dims', using TensorFlow broadcasting +// rules. Supports broadcasting a dimension of size x to size x*y, i.e., tiling. +absl::StatusOr BroadcastTo(XlaOp input, + absl::Span output_dims); + +} // namespace xla + +#endif // XLA_HLO_BUILDER_LIB_BROADCAST_H_ diff --git a/third_party/xla/xla/client/lib/comparators.cc b/third_party/xla/xla/hlo/builder/lib/comparators.cc similarity index 97% rename from third_party/xla/xla/client/lib/comparators.cc rename to third_party/xla/xla/hlo/builder/lib/comparators.cc index 771d5331803d49..fec1874a0373d4 100644 --- a/third_party/xla/xla/client/lib/comparators.cc +++ b/third_party/xla/xla/hlo/builder/lib/comparators.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/client/lib/comparators.h" +#include "xla/hlo/builder/lib/comparators.h" #include #include @@ -23,8 +23,8 @@ limitations under the License. #include "absl/log/check.h" #include "absl/strings/str_cat.h" #include "absl/types/span.h" -#include "xla/client/xla_builder.h" -#include "xla/client/xla_computation.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/shape_util.h" #include "xla/util.h" #include "xla/xla_data.pb.h" diff --git a/third_party/xla/xla/hlo/builder/lib/comparators.h b/third_party/xla/xla/hlo/builder/lib/comparators.h new file mode 100644 index 00000000000000..8dd3e47e07eb48 --- /dev/null +++ b/third_party/xla/xla/hlo/builder/lib/comparators.h @@ -0,0 +1,60 @@ +/* Copyright 2019 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_BUILDER_LIB_COMPARATORS_H_ +#define XLA_HLO_BUILDER_LIB_COMPARATORS_H_ + +#include +#include +#include + +#include "absl/types/span.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/builder/xla_computation.h" +#include "xla/xla_data.pb.h" + +namespace xla { + +// Creates a scalar less-than computation and returns it. The created +// computation has 2 * 'operand_types.size()' many parameters, where parameters +// 2 * i and 2 * i + 1 are a scalar with primitive type 'operand_types[i]'. The +// computation compares the first two parameters. For floating point types, a +// total order is created where +// -NaN < -infinity < ... < -0 < 0 < ... < infinity < NaN +XlaComputation CreateScalarLtComputation( + const std::vector& operand_types, XlaBuilder* builder); + +// Creates a scalar greater-than computation and returns it. The created +// computation has 2 * 'operand_types.size()' many parameters, where parameters +// 2 * i and 2 * i + 1 are a scalar with primitive type 'operand_types[i]'. The +// computation compares the first two parameters. For floating point types, a +// total order is created where +// NaN > infinity > ... > 0 > -0 > ... > -infinity > -NaN +XlaComputation CreateScalarGtComputation( + const std::vector& operand_types, XlaBuilder* builder); + +// Creates a scalar comparison computation and returns it. This function takes +// a vector of comparator functions to compare the operands where the function +// isn't nullopt with the specified comparator at that location. +XlaComputation CreateScalarComparisonComputation( + const std::string& name, const std::vector& operand_types, + const std::vector< + std::optional)>>& + generators, + XlaBuilder* builder); + +} // namespace xla + +#endif // XLA_HLO_BUILDER_LIB_COMPARATORS_H_ diff --git a/third_party/xla/xla/client/lib/comparators_test.cc b/third_party/xla/xla/hlo/builder/lib/comparators_test.cc similarity index 98% rename from third_party/xla/xla/client/lib/comparators_test.cc rename to third_party/xla/xla/hlo/builder/lib/comparators_test.cc index acaf2f19985276..39bf073171a86b 100644 --- a/third_party/xla/xla/client/lib/comparators_test.cc +++ b/third_party/xla/xla/hlo/builder/lib/comparators_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/client/lib/comparators.h" +#include "xla/hlo/builder/lib/comparators.h" #include #include @@ -21,9 +21,9 @@ limitations under the License. #include "absl/container/inlined_vector.h" #include "absl/strings/string_view.h" -#include "xla/client/lib/constants.h" -#include "xla/client/xla_builder.h" -#include "xla/client/xla_computation.h" +#include "xla/hlo/builder/lib/constants.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/primitive_util.h" #include "xla/service/hlo.pb.h" diff --git a/third_party/xla/xla/client/lib/constants.cc b/third_party/xla/xla/hlo/builder/lib/constants.cc similarity index 98% rename from third_party/xla/xla/client/lib/constants.cc rename to third_party/xla/xla/hlo/builder/lib/constants.cc index 1e5a7fae4c9c10..acfa2fe0b66e2c 100644 --- a/third_party/xla/xla/client/lib/constants.cc +++ b/third_party/xla/xla/hlo/builder/lib/constants.cc @@ -13,12 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/client/lib/constants.h" +#include "xla/hlo/builder/lib/constants.h" #include #include "absl/status/statusor.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/literal_util.h" #include "xla/primitive_util.h" #include "xla/shape.h" diff --git a/third_party/xla/xla/hlo/builder/lib/constants.h b/third_party/xla/xla/hlo/builder/lib/constants.h new file mode 100644 index 00000000000000..ce695736d1e49c --- /dev/null +++ b/third_party/xla/xla/hlo/builder/lib/constants.h @@ -0,0 +1,140 @@ +/* Copyright 2018 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_BUILDER_LIB_CONSTANTS_H_ +#define XLA_HLO_BUILDER_LIB_CONSTANTS_H_ + +#include + +#include "absl/status/statusor.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/primitive_util.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/types.h" +#include "xla/util.h" +#include "xla/xla_data.pb.h" +#include "tsl/platform/ml_dtypes.h" +#include "tsl/platform/statusor.h" + +namespace xla { + +// Returns scalar 'value' as a scalar of 'type'. Unlike ConstantR0, 'type' is +// determined at C++ run-time, rather than C++ compile-time. +// If 'value' is floating point but 'type' is not, or if 'value' is complex but +// 'type' is not, an error will be returned. This is to catch accidental +// truncation; in such cases, use an explicit cast. +template +XlaOp ConstantR0WithType(XlaBuilder* builder, PrimitiveType type, T value) { + if (std::is_floating_point::value && + !(primitive_util::IsFloatingPointType(type) || + primitive_util::IsComplexType(type))) { + return builder->ReportError(InvalidArgument( + "Invalid cast from floating point type to %s in ConstantR0WithType.", + PrimitiveType_Name(type))); + } + if (std::is_same::value && + !primitive_util::IsComplexType(type)) { + return builder->ReportError(InvalidArgument( + "Invalid cast from complex type to %s in ConstantR0WithType.", + PrimitiveType_Name(type))); + } + return primitive_util::PrimitiveTypeSwitch( + [&](auto primitive_type_constant) -> XlaOp { + if constexpr (primitive_util::IsArrayType(primitive_type_constant)) { + using NativeT = primitive_util::NativeTypeOf; + return ConstantR0(builder, static_cast(value)); + } + return builder->ReportError( + InvalidArgument("Invalid type for ConstantR0WithType (%s).", + PrimitiveType_Name(type))); + }, + type); +} + +// Returns a scalar containing 'value' cast to the same run-time type as +// 'prototype'. +// If 'value' is floating point but 'prototype' is not, or if 'value' is complex +// 'prototype' is not, an error will be returned. +template +XlaOp ScalarLike(XlaOp prototype, T value) { + XlaBuilder* builder = prototype.builder(); + return builder->ReportErrorOrReturn([&]() -> absl::StatusOr { + TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(prototype)); + return ConstantR0WithType(builder, shape.element_type(), value); + }); +} + +// Returns an array or scalar containing copies of `value` cast to the same +// run-type type as `prototype` and broadcast to the same dimensions as +// `prototype`. +// +// If `prototype` is not a scalar or array, returns an error. +template +XlaOp FullLike(XlaOp prototype, T value) { + XlaBuilder* builder = prototype.builder(); + return builder->ReportErrorOrReturn([&]() -> absl::StatusOr { + TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(prototype)); + if (ShapeUtil::IsScalar(shape) || shape.IsArray()) { + return Broadcast(ScalarLike(prototype, value), shape.dimensions()); + } else { + return InvalidArgument( + "Prototype shape for BroadcastConstantLike must be a scalar or " + "array, but was %s", + shape.ToString()); + } + }); +} + +// Returns a scalar with value '0' of 'type'. +XlaOp Zero(XlaBuilder* builder, PrimitiveType type); + +// Returns a zero-filled tensor with shape `shape`. +XlaOp Zeros(XlaBuilder* builder, const Shape& shape); + +// Returns a zero-filled tensor with the same shape as `prototype`. +XlaOp ZerosLike(XlaOp prototype); + +// Returns a scalar with value '1' of 'type'. +XlaOp One(XlaBuilder* builder, PrimitiveType type); + +// Returns the machine epsilon for floating-point type `type`, i.e., +// the difference between 1.0 and the next representable value. +XlaOp Epsilon(XlaBuilder* builder, PrimitiveType type); + +// Returns the minimum representable finite or infinite value for 'type'. +// Returns '-inf' for floating-point types. +XlaOp MinValue(XlaBuilder* builder, PrimitiveType type); + +// Returns the minimum representable finite value for 'type'. For a floating +// point type, this is equal to -MaxFiniteValue(). +XlaOp MinFiniteValue(XlaBuilder* builder, PrimitiveType type); + +// Returns the minimum positive normal value for floating-point type `type`. +XlaOp MinPositiveNormalValue(XlaBuilder* builder, PrimitiveType type); + +// Returns the maximum representable finite or infinite value for 'type'. +// Returns 'inf' for floating-point types. +XlaOp MaxValue(XlaBuilder* builder, PrimitiveType type); + +// Returns the maximum representable finite value for 'type'. +XlaOp MaxFiniteValue(XlaBuilder* builder, PrimitiveType type); + +// Returns a nan for the given type. Only valid for real-valued fp types. +XlaOp NanValue(XlaBuilder* builder, PrimitiveType type); + +} // namespace xla + +#endif // XLA_HLO_BUILDER_LIB_CONSTANTS_H_ diff --git a/third_party/xla/xla/client/lib/constants_test.cc b/third_party/xla/xla/hlo/builder/lib/constants_test.cc similarity index 98% rename from third_party/xla/xla/client/lib/constants_test.cc rename to third_party/xla/xla/hlo/builder/lib/constants_test.cc index 2ae344f2e6cf9e..61aa0ae71dee5b 100644 --- a/third_party/xla/xla/client/lib/constants_test.cc +++ b/third_party/xla/xla/hlo/builder/lib/constants_test.cc @@ -13,11 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/client/lib/constants.h" +#include "xla/hlo/builder/lib/constants.h" #include -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/shape_util.h" #include "xla/test.h" #include "xla/tests/client_library_test_base.h" diff --git a/third_party/xla/xla/client/lib/conv_grad_size_util.cc b/third_party/xla/xla/hlo/builder/lib/conv_grad_size_util.cc similarity index 97% rename from third_party/xla/xla/client/lib/conv_grad_size_util.cc rename to third_party/xla/xla/hlo/builder/lib/conv_grad_size_util.cc index f08328c9086b4f..9bbe184a9d6140 100644 --- a/third_party/xla/xla/client/lib/conv_grad_size_util.cc +++ b/third_party/xla/xla/hlo/builder/lib/conv_grad_size_util.cc @@ -13,13 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/client/lib/conv_grad_size_util.h" +#include "xla/hlo/builder/lib/conv_grad_size_util.h" #include #include "absl/log/log.h" #include "absl/status/statusor.h" -#include "xla/client/padding.h" +#include "xla/hlo/builder/padding.h" #include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" diff --git a/third_party/xla/xla/hlo/builder/lib/conv_grad_size_util.h b/third_party/xla/xla/hlo/builder/lib/conv_grad_size_util.h new file mode 100644 index 00000000000000..91e43d226c180b --- /dev/null +++ b/third_party/xla/xla/hlo/builder/lib/conv_grad_size_util.h @@ -0,0 +1,44 @@ +/* Copyright 2017 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_BUILDER_LIB_CONV_GRAD_SIZE_UTIL_H_ +#define XLA_HLO_BUILDER_LIB_CONV_GRAD_SIZE_UTIL_H_ + +#include "absl/status/statusor.h" +#include "xla/hlo/builder/padding.h" + +namespace xla { + +// Information about a single spatial dimension for a convolution gradients and +// windowed operations. +struct SpatialDimensionOutputSizeAndPadding { + // Effective size of the operation output (potentially expanded). + int64_t output_size; + // Number of padding elements to be added before/after this dimension of + // the input when computing the input gradient. + int64_t pad_before; + int64_t pad_after; +}; + +// Verifies that the dimensions all match, and computes the size and padding of +// a spatial dimension for convolution gradient operations. +absl::StatusOr +ConvGradExtractAndVerifyDimension(int64_t input_size, int64_t filter_size, + int64_t output_size, int64_t dilation, + int64_t stride, Padding padding); + +} // namespace xla + +#endif // XLA_HLO_BUILDER_LIB_CONV_GRAD_SIZE_UTIL_H_ diff --git a/third_party/xla/xla/client/lib/dynamic_shaped_ops.cc b/third_party/xla/xla/hlo/builder/lib/dynamic_shaped_ops.cc similarity index 98% rename from third_party/xla/xla/client/lib/dynamic_shaped_ops.cc rename to third_party/xla/xla/hlo/builder/lib/dynamic_shaped_ops.cc index c263d31badcdf5..ba82ec343ce55a 100644 --- a/third_party/xla/xla/client/lib/dynamic_shaped_ops.cc +++ b/third_party/xla/xla/hlo/builder/lib/dynamic_shaped_ops.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/client/lib/dynamic_shaped_ops.h" +#include "xla/hlo/builder/lib/dynamic_shaped_ops.h" #include #include @@ -22,10 +22,10 @@ limitations under the License. #include "absl/log/check.h" #include "absl/status/statusor.h" #include "absl/types/span.h" -#include "xla/client/lib/constants.h" -#include "xla/client/value_inference.h" -#include "xla/client/xla_builder.h" -#include "xla/client/xla_computation.h" +#include "xla/hlo/builder/lib/constants.h" +#include "xla/hlo/builder/value_inference.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/status_macros.h" diff --git a/third_party/xla/xla/hlo/builder/lib/dynamic_shaped_ops.h b/third_party/xla/xla/hlo/builder/lib/dynamic_shaped_ops.h new file mode 100644 index 00000000000000..71188b8fb80a22 --- /dev/null +++ b/third_party/xla/xla/hlo/builder/lib/dynamic_shaped_ops.h @@ -0,0 +1,59 @@ +/* Copyright 2021 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_BUILDER_LIB_DYNAMIC_SHAPED_OPS_H_ +#define XLA_HLO_BUILDER_LIB_DYNAMIC_SHAPED_OPS_H_ + +#include "absl/status/statusor.h" +#include "absl/types/span.h" +#include "xla/hlo/builder/lib/constants.h" +#include "xla/hlo/builder/value_inference.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/builder/xla_computation.h" +#include "xla/primitive_util.h" +#include "xla/types.h" +#include "xla/xla_data.pb.h" + +namespace xla { + +// Similar to static shaped conditional, but allows true_computation and +// false_computation to have different dimension sizes (ranks still have to be +// the same). Fall back to static conditional if dynamism is not presented. +XlaOp DynamicConditional(XlaBuilder* builder, XlaOp predicate, + XlaOp true_operand, + const XlaComputation& true_computation, + XlaOp false_operand, + const XlaComputation& false_computation); + +// Similar to DynamicConditional, but support multiple branches. +XlaOp DynamicConditional( + XlaBuilder* builder, XlaOp branch_index, + absl::Span branch_computations, + absl::Span branch_operands); + +// Similar to SetDimensionSize, but automatically adjust the bound of output if +// a tighter one can be inferred by `value_inference`. +absl::StatusOr SetDimensionSizeWithRebound( + ValueInference* value_inference, XlaOp operand, XlaOp dimension_size, + int64_t dimension); + +// Take a `operand` tensor and a R1 tensor `size_vector` representing the sizes +// of `operand`, Call SetDimensionSize if for each dimension whose size is +// dynamic. +absl::StatusOr SetAllDimensionSizes(ValueInference* value_inference, + XlaOp operand, XlaOp size_vector); +} // namespace xla + +#endif // XLA_HLO_BUILDER_LIB_DYNAMIC_SHAPED_OPS_H_ diff --git a/third_party/xla/xla/client/lib/generate_math_impl.py b/third_party/xla/xla/hlo/builder/lib/generate_math_impl.py similarity index 100% rename from third_party/xla/xla/client/lib/generate_math_impl.py rename to third_party/xla/xla/hlo/builder/lib/generate_math_impl.py diff --git a/third_party/xla/xla/client/lib/logdet.cc b/third_party/xla/xla/hlo/builder/lib/logdet.cc similarity index 90% rename from third_party/xla/xla/client/lib/logdet.cc rename to third_party/xla/xla/hlo/builder/lib/logdet.cc index 96063a5c72431f..cc17d0ec26ffe6 100644 --- a/third_party/xla/xla/client/lib/logdet.cc +++ b/third_party/xla/xla/hlo/builder/lib/logdet.cc @@ -13,19 +13,19 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/client/lib/logdet.h" +#include "xla/hlo/builder/lib/logdet.h" #include #include #include #include "absl/status/statusor.h" -#include "xla/client/lib/arithmetic.h" -#include "xla/client/lib/constants.h" -#include "xla/client/lib/matrix.h" -#include "xla/client/lib/qr.h" -#include "xla/client/lib/slicing.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/lib/arithmetic.h" +#include "xla/hlo/builder/lib/constants.h" +#include "xla/hlo/builder/lib/matrix.h" +#include "xla/hlo/builder/lib/qr.h" +#include "xla/hlo/builder/lib/slicing.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/util.h" diff --git a/third_party/xla/xla/client/lib/logdet.h b/third_party/xla/xla/hlo/builder/lib/logdet.h similarity index 87% rename from third_party/xla/xla/client/lib/logdet.h rename to third_party/xla/xla/hlo/builder/lib/logdet.h index ee3d984fa69319..8c02d72de9940b 100644 --- a/third_party/xla/xla/client/lib/logdet.h +++ b/third_party/xla/xla/hlo/builder/lib/logdet.h @@ -13,10 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_CLIENT_LIB_LOGDET_H_ -#define XLA_CLIENT_LIB_LOGDET_H_ +#ifndef XLA_HLO_BUILDER_LIB_LOGDET_H_ +#define XLA_HLO_BUILDER_LIB_LOGDET_H_ -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" namespace xla { @@ -34,4 +34,4 @@ XlaOp LogDet(XlaOp a); } // namespace xla -#endif // XLA_CLIENT_LIB_LOGDET_H_ +#endif // XLA_HLO_BUILDER_LIB_LOGDET_H_ diff --git a/third_party/xla/xla/client/lib/logdet_test.cc b/third_party/xla/xla/hlo/builder/lib/logdet_test.cc similarity index 98% rename from third_party/xla/xla/client/lib/logdet_test.cc rename to third_party/xla/xla/hlo/builder/lib/logdet_test.cc index b2600ed7f7ea23..8618aab2aa833d 100644 --- a/third_party/xla/xla/client/lib/logdet_test.cc +++ b/third_party/xla/xla/hlo/builder/lib/logdet_test.cc @@ -13,15 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/client/lib/logdet.h" +#include "xla/hlo/builder/lib/logdet.h" #include #include "xla/array.h" #include "xla/array2d.h" #include "xla/array3d.h" -#include "xla/client/xla_builder.h" #include "xla/error_spec.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/tests/client_library_test_base.h" diff --git a/third_party/xla/xla/client/lib/loops.cc b/third_party/xla/xla/hlo/builder/lib/loops.cc similarity index 97% rename from third_party/xla/xla/client/lib/loops.cc rename to third_party/xla/xla/hlo/builder/lib/loops.cc index 5785e9969dee8f..e7dbad01163d93 100644 --- a/third_party/xla/xla/client/lib/loops.cc +++ b/third_party/xla/xla/hlo/builder/lib/loops.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/client/lib/loops.h" +#include "xla/hlo/builder/lib/loops.h" #include #include @@ -23,8 +23,8 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" -#include "xla/client/lib/constants.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/lib/constants.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/literal_util.h" #include "xla/shape.h" #include "xla/shape_util.h" diff --git a/third_party/xla/xla/hlo/builder/lib/loops.h b/third_party/xla/xla/hlo/builder/lib/loops.h new file mode 100644 index 00000000000000..540ab784f34684 --- /dev/null +++ b/third_party/xla/xla/hlo/builder/lib/loops.h @@ -0,0 +1,74 @@ +/* Copyright 2018 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_BUILDER_LIB_LOOPS_H_ +#define XLA_HLO_BUILDER_LIB_LOOPS_H_ + +#include +#include + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/builder/xla_computation.h" +#include "xla/xla_data.pb.h" + +namespace xla { + +// Function that builds a loop condition. Takes as input a sequence of input +// values, and returns a boolean value representing if the condition succeeds. +typedef std::function(absl::Span, + XlaBuilder*)> + WhileLoopHelperConditionFunction; + +// Function that builds a loop body. Takes as input a sequence of input values +// and returns a sequence of output values. +typedef std::function>( + absl::Span, XlaBuilder*)> + WhileLoopHelperBodyFunction; + +// Helper function for building an XLA while loop, where the values carried by +// the loop are a tuple of values, e.g., (a, b, c): +// while( +// condition: (a, b, c) -> bool, +// body: (a, b, c) -> (a, b, c) +// init: (a, b, c) +// ) +// 'name' is a descriptive name for the loop. +absl::StatusOr> WhileLoopHelper( + const WhileLoopHelperConditionFunction& condition_function, + const WhileLoopHelperBodyFunction& body_function, + absl::Span initial_values, absl::string_view name, + XlaBuilder* builder); + +// Builds an XLA loop that repeats a computation `num_iterations` times. +// +// The body function (ForEachIndexBodyFunction) takes as input a pair of +// (current iteration number, loop-carried values), and returns an updated +// vector of the loop-carried values. +typedef std::function>( + XlaOp, absl::Span, XlaBuilder*)> + ForEachIndexBodyFunction; + +absl::StatusOr> ForEachIndex( + int64_t num_iterations, PrimitiveType num_iterations_type, + const ForEachIndexBodyFunction& body_function, + absl::Span initial_values, absl::string_view name, + XlaBuilder* builder); + +} // namespace xla + +#endif // XLA_HLO_BUILDER_LIB_LOOPS_H_ diff --git a/third_party/xla/xla/client/lib/lu_decomposition.cc b/third_party/xla/xla/hlo/builder/lib/lu_decomposition.cc similarity index 96% rename from third_party/xla/xla/client/lib/lu_decomposition.cc rename to third_party/xla/xla/hlo/builder/lib/lu_decomposition.cc index b4f00876ce36a8..78e9c00e07ca1a 100644 --- a/third_party/xla/xla/client/lib/lu_decomposition.cc +++ b/third_party/xla/xla/hlo/builder/lib/lu_decomposition.cc @@ -13,13 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/client/lib/lu_decomposition.h" +#include "xla/hlo/builder/lib/lu_decomposition.h" #include #include #include "absl/status/statusor.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/status_macros.h" diff --git a/third_party/xla/xla/hlo/builder/lib/lu_decomposition.h b/third_party/xla/xla/hlo/builder/lib/lu_decomposition.h new file mode 100644 index 00000000000000..d233dab04f50e2 --- /dev/null +++ b/third_party/xla/xla/hlo/builder/lib/lu_decomposition.h @@ -0,0 +1,61 @@ +/* Copyright 2020 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_BUILDER_LIB_LU_DECOMPOSITION_H_ +#define XLA_HLO_BUILDER_LIB_LU_DECOMPOSITION_H_ + +#include "xla/hlo/builder/xla_builder.h" +#include "xla/xla_data.pb.h" + +namespace xla { + +// Computes the LU decomposition with partial pivoting of a batch of matrices. +// +// Given a (batched) matrix a with shape [..., m, n], computes the matrix +// decomposition A = P @ L @ U where P is a permutation matrix, L is a +// lower-triangular matrix with unit diagonal entries, and U is an +// upper-triangular matrix. +// +// L and U are returned as a single matrix [..., m, n] containing both L and U +// packed in the same array. The unit diagonal of L is not represented +// explicitly. +// +// The permutation matrix P is returned in two forms, both as `pivots`, which is +// an s32[..., min(m, n)] array that describes a sequence of row-swaps in the +// style of LAPACK's xGETRF API, and `permutation`, which is a s32[..., m] array +// which gives the permutation to apply to the rows. We return both +// representations because they are each useful for different purposes; `pivots` +// is useful for computing the sign of a determinant, whereas `permutation` can +// be used via a Gather operation to permute the rows of a matrix. +// +// This method is only implemented on TPU at the moment. +// TODO(b/168208200): the implementation only supports F32 arrays. Handle the +// complex case. +struct LuDecompositionResult { + // The LU decomposition, with both L and U packed into an array with shape + // [..., m, n]. + XlaOp lu; + // An array of shape s32[..., min(m, n)] containing the pivot rows. + XlaOp pivots; + // An array of shape s32[..., m], containing an another representation of the + // pivots as a permutation. + XlaOp permutation; +}; + +LuDecompositionResult LuDecomposition(XlaOp a); + +} // namespace xla + +#endif // XLA_HLO_BUILDER_LIB_LU_DECOMPOSITION_H_ diff --git a/third_party/xla/xla/client/lib/math.cc b/third_party/xla/xla/hlo/builder/lib/math.cc similarity index 99% rename from third_party/xla/xla/client/lib/math.cc rename to third_party/xla/xla/hlo/builder/lib/math.cc index c3f27638bdd2e9..f7c00aece14d0e 100644 --- a/third_party/xla/xla/client/lib/math.cc +++ b/third_party/xla/xla/hlo/builder/lib/math.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/client/lib/math.h" +#include "xla/hlo/builder/lib/math.h" #include #include @@ -28,11 +28,11 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" -#include "xla/client/lib/arithmetic.h" -#include "xla/client/lib/constants.h" -#include "xla/client/lib/loops.h" -#include "xla/client/lib/math_impl.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/lib/arithmetic.h" +#include "xla/hlo/builder/lib/constants.h" +#include "xla/hlo/builder/lib/loops.h" +#include "xla/hlo/builder/lib/math_impl.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/primitive_util.h" #include "xla/shape.h" #include "xla/status_macros.h" diff --git a/third_party/xla/xla/hlo/builder/lib/math.h b/third_party/xla/xla/hlo/builder/lib/math.h new file mode 100644 index 00000000000000..6c26ec20410c64 --- /dev/null +++ b/third_party/xla/xla/hlo/builder/lib/math.h @@ -0,0 +1,127 @@ +/* Copyright 2018 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_BUILDER_LIB_MATH_H_ +#define XLA_HLO_BUILDER_LIB_MATH_H_ + +#include "xla/hlo/builder/xla_builder.h" + +namespace xla { + +// Determines whether operand is +/-inf or nan. +// +// Raises an error if called on integral or complex values. +XlaOp IsPosInf(XlaOp operand); +XlaOp IsNegInf(XlaOp operand); +XlaOp IsInf(XlaOp operand); +XlaOp IsNan(XlaOp operand); + +// Determines whether operand is equal to -0. +// +// Raises an error for integral or complex values. +XlaOp IsNegZero(XlaOp operand); + +// Returns the next number after 'from' in the direction of 'to' the same way +// std::nextafter(from, to) would. +XlaOp NextAfter(XlaOp from, XlaOp to); + +// Computes the square of 'operand'. +XlaOp Square(XlaOp operand); + +// Computes the reciprocal of 'operand'. +XlaOp Reciprocal(XlaOp operand); + +// Computes an approximation of the error function complement (1 - erf(x)). +XlaOp Erfc(XlaOp x); + +// Computes an approximation of the inverse of the error function. +XlaOp ErfInv(XlaOp x); + +// Computes an approximation of the lgamma function. +XlaOp Lgamma(XlaOp input); + +// Computes an approximation of the digamma function. +XlaOp Digamma(XlaOp input); + +// Computes an approximation of the incomplete gamma function. +XlaOp Igamma(XlaOp a, XlaOp x); + +// Computes an approximation of the derivative of the incomplete gamma function +// with respect to a. +XlaOp IgammaGradA(XlaOp a, XlaOp x); + +// Computes an approximation of the derivative of a sample `x` from a `Gamma(a, +// 1)` distribution with respect to a. +XlaOp RandomGammaGrad(XlaOp a, XlaOp x); + +// Computes an approximation of the complementary incomplete gamma function. +XlaOp Igammac(XlaOp a, XlaOp x); + +// Computes the Polygamma of two arguments. +XlaOp Polygamma(XlaOp n, XlaOp x); + +// Computes the Riemann zeta function of two arguments. +XlaOp Zeta(XlaOp x, XlaOp q); + +// Rounds the given number to even when the number is equidistant between two +// integers. +XlaOp RoundToEven(XlaOp x); + +// Trigonometric functions + +// Computes the arc cosine of 'x'. +XlaOp Acos(XlaOp x); + +// Computes the arc sine of 'x'. +XlaOp Asin(XlaOp x); + +// Computes the arc tangent of 'x'. +XlaOp Atan(XlaOp x); + +// Hyperbolic trigonometric functions + +// Computes the inverse hyperbolic cosine of 'x'. +XlaOp Acosh(XlaOp x); + +// Computes the inverse hyperbolic sine of 'x'. +XlaOp Asinh(XlaOp x); + +// Computes the inverse hyperbolic tangent of 'x'. +XlaOp Atanh(XlaOp x); + +// Computes the hyperbolic cosine of 'x'. +XlaOp Cosh(XlaOp x); + +// Computes the hyperbolic sine of 'x'. +XlaOp Sinh(XlaOp x); + +// Applies a complex conjugation operation if 'a' is complex and 'conjugate' +// is true, otherwise returns its argument. +xla::XlaOp MaybeConjugate(xla::XlaOp x, bool conjugate); + +// Computes the Modified Bessel function of the first kind of the zeroth order +// at x. +XlaOp BesselI0e(XlaOp x); + +// Computes the Modified Bessel function of the first kind of the first order +// at x. +XlaOp BesselI1e(XlaOp x); + +// Computes the Regularized Incomplete Beta function. +XlaOp RegularizedIncompleteBeta(XlaOp a, XlaOp b, XlaOp x); + +} // namespace xla + +#endif // XLA_HLO_BUILDER_LIB_MATH_H_ diff --git a/third_party/xla/xla/client/lib/math_impl.h b/third_party/xla/xla/hlo/builder/lib/math_impl.h similarity index 97% rename from third_party/xla/xla/client/lib/math_impl.h rename to third_party/xla/xla/hlo/builder/lib/math_impl.h index f89851ad9366c2..262856d08c712f 100644 --- a/third_party/xla/xla/client/lib/math_impl.h +++ b/third_party/xla/xla/hlo/builder/lib/math_impl.h @@ -17,12 +17,12 @@ limitations under the License. // https://github.com/pearu/functional_algorithms // for more information. -#ifndef XLA_CLIENT_LIB_MATH_IMPL_H_ -#define XLA_CLIENT_LIB_MATH_IMPL_H_ +#ifndef XLA_HLO_BUILDER_LIB_MATH_IMPL_H_ +#define XLA_HLO_BUILDER_LIB_MATH_IMPL_H_ -#include "xla/client/lib/constants.h" -#include "xla/client/lib/math.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/lib/constants.h" +#include "xla/hlo/builder/lib/math.h" +#include "xla/hlo/builder/xla_builder.h" namespace xla { namespace math_impl { @@ -256,4 +256,4 @@ XlaOp AsinReal(XlaOp x) { } // namespace math_impl } // namespace xla -#endif // XLA_CLIENT_LIB_MATH_IMPL_H_ +#endif // XLA_HLO_BUILDER_LIB_MATH_IMPL_H_ diff --git a/third_party/xla/xla/client/lib/math_test.cc b/third_party/xla/xla/hlo/builder/lib/math_test.cc similarity index 99% rename from third_party/xla/xla/client/lib/math_test.cc rename to third_party/xla/xla/hlo/builder/lib/math_test.cc index 0c5776f4bea333..ab6d5f1585f8cc 100644 --- a/third_party/xla/xla/client/lib/math_test.cc +++ b/third_party/xla/xla/hlo/builder/lib/math_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/client/lib/math.h" +#include "xla/hlo/builder/lib/math.h" #include #include @@ -26,9 +26,9 @@ limitations under the License. #include #include "xla/array3d.h" -#include "xla/client/lib/constants.h" -#include "xla/client/xla_builder.h" #include "xla/error_spec.h" +#include "xla/hlo/builder/lib/constants.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/primitive_util.h" diff --git a/third_party/xla/xla/client/lib/matrix.cc b/third_party/xla/xla/hlo/builder/lib/matrix.cc similarity index 99% rename from third_party/xla/xla/client/lib/matrix.cc rename to third_party/xla/xla/hlo/builder/lib/matrix.cc index 38a1a67efde17f..7c189b762a49f3 100644 --- a/third_party/xla/xla/client/lib/matrix.cc +++ b/third_party/xla/xla/hlo/builder/lib/matrix.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/client/lib/matrix.h" +#include "xla/hlo/builder/lib/matrix.h" #include #include @@ -36,10 +36,10 @@ limitations under the License. #include "absl/strings/str_split.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" -#include "xla/client/lib/arithmetic.h" -#include "xla/client/lib/constants.h" -#include "xla/client/lib/slicing.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/lib/arithmetic.h" +#include "xla/hlo/builder/lib/constants.h" +#include "xla/hlo/builder/lib/slicing.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/literal.h" #include "xla/primitive_util.h" #include "xla/shape.h" diff --git a/third_party/xla/xla/hlo/builder/lib/matrix.h b/third_party/xla/xla/hlo/builder/lib/matrix.h new file mode 100644 index 00000000000000..8fdf01d438d7a1 --- /dev/null +++ b/third_party/xla/xla/hlo/builder/lib/matrix.h @@ -0,0 +1,159 @@ +/* Copyright 2018 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_BUILDER_LIB_MATRIX_H_ +#define XLA_HLO_BUILDER_LIB_MATRIX_H_ + +#include +#include +#include +#include + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/types.h" +#include "xla/xla_data.pb.h" + +namespace xla { + +// Returns an m x n matrix with 1s on the diagonal elements, zeros everywhere +// else. +XlaOp IdentityMatrix(XlaBuilder* builder, PrimitiveType type, int64_t m, + int64_t n); + +// Returns a mask where the 'diagonal'-th diagonal is true and everything else +// is false. +XlaOp GetDiagonalMask(XlaOp x, int diagonal = 0); + +// Get the diagonals of the last two dimensions. Use k>0 for diagonals above the +// main diagonal, and k<0 for diagonals below the main diagonal. +// +// If 'x' has shape [..., M, N] +// If k >= 0: then the output has shape [..., min(M, N - k)], containing the +// diagonal elements (i.e., with indices [..., i, i + k]). +// If k < 0: then the output has shape [..., min(M + k, N)], containing the +// diagonal elements (i.e., with indices [..., i - k, i]). +XlaOp GetMatrixDiagonal(XlaOp x, int k = 0); +XlaOp GetMatrixDiagonalViaGather(XlaOp x, int k = 0); + +// Places diag along the kth diagonal of target. +XlaOp SetMatrixDiagonal(XlaOp matrix, XlaOp diag, int k = 0); + +// Returns a lower-triangular mask, i.e., true below and including the +// `diagonal`-th diagonal and false above that diagonal. +XlaOp TriangleMask(XlaOp x, int diagonal); + +// Get the upper or lower triangle part of the last two dimensions +XlaOp Triangle(XlaOp x, bool lower); + +// Get the upper triangle part of the last two dimensions +XlaOp UpperTriangle(XlaOp x); + +// Get the lower triangle part of the last two dimensions +XlaOp LowerTriangle(XlaOp x); + +// If x is an array of shape [..., n, n], symmetrizes the matrix by replacing +// the upper triangle with the transpose of the lower triangle (if lower is +// True, vice-versa otherwise). If the type of `x` is complex, makes the matrix +// Hermitian by taking the conjugate of the complex part and setting the +// complex diagonal to zero. +XlaOp Symmetrize(XlaOp x, bool lower); + +// Multiplies slices of two tensors in batches. + +// Multiplies all slices of `Tensor` `x` and `y` (each slice can be +// viewed as an element of a batch), and arranges the individual results +// in a single output tensor of the same batch size. +// +// The input tensors `x` and `y` are 2-D or higher with shape `[..., r_x, c_x]` +// and `[..., r_y, c_y]`. +// +// The output tensor is 2-D or higher with shape `[..., r_o, c_o]`, where: +// +// r_o = c_x if transpose_x else r_x +// c_o = r_y if transpose_y else c_y +// +// It is computed as: +// +// output[..., :, :] = matrix(x[..., :, :]) * matrix(y[..., :, :]) +xla::XlaOp BatchDot( + xla::XlaOp x, xla::XlaOp y, + xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::DEFAULT, + std::optional preferred_element_type = std::nullopt); +xla::XlaOp BatchDot( + xla::XlaOp x, bool transpose_x, xla::XlaOp y, bool transpose_y, + xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::DEFAULT, + std::optional preferred_element_type = std::nullopt, + bool grad_x = false, bool grad_y = false); + +// Parse an einsum string into dimension numbers: +// "ab,cb->ac" +// becomes: +// {{0, 1},{2, 1},{0, 2}} +// +// Each occurrence of ellipsis ("...") occurring in the input is replaced with +// the same numeric dimensions. The number of such dimensions is inferred from +// x_rank and y_rank. For example: +// einsum_config: "...ab,...bcd->...acd" +// x_rank: 4 +// y_rank: 5 +// becomes: +// {{0, 1, 2, 3},{0, 1, 3, 4, 5},{0, 1, 2, 4, 5}} +// +// NOTE: This function is meant for testing, there is no need to call it +// directly. + +absl::StatusOr, 3>> ParseEinsumString( + absl::string_view einsum_config, int64_t x_rank, int64_t y_rank); + +// If an einsum config does not contain an -> one will be added and the output +// config will be the sorted characters with any ellipsis at the beginning. +// Returns an empty string if the einsum string already has an ->. +std::string NormalizeEinsumString(absl::string_view einsum_config); + +// Supports two operand einsum notation like "ab,cb->ac". +xla::XlaOp Einsum( + xla::XlaOp x, xla::XlaOp y, absl::string_view einsum_config, + xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::DEFAULT, + std::optional preferred_element_type = std::nullopt, + bool grad_x = false, bool grad_y = false); +xla::XlaOp Einsum( + xla::XlaOp x, absl::string_view einsum_config, + xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::DEFAULT); + +// Same as above but supporting numeric labels on dimensions. So "ab,cb->ac" +// becomes: +// x_config = {0, 1} +// y_config = {2, 1} +// output_config = {0, 2} +xla::XlaOp Einsum( + xla::XlaOp x, absl::Span x_config, xla::XlaOp y, + absl::Span y_config, absl::Span output_config, + xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::DEFAULT, + std::optional preferred_element_type = std::nullopt, + bool grad_x = false, bool grad_y = false); + +// Transposes a stack of matrices `x` by swapping the last two dimensions. +xla::XlaOp TransposeInMinorDims(xla::XlaOp x); + +// Transposes `x` in its minor dimensions if `transpose` is true, otherwise +// returns `x` unchanged. +xla::XlaOp MaybeTransposeInMinorDims(xla::XlaOp x, bool transpose); + +} // namespace xla + +#endif // XLA_HLO_BUILDER_LIB_MATRIX_H_ diff --git a/third_party/xla/xla/client/lib/matrix_test.cc b/third_party/xla/xla/hlo/builder/lib/matrix_test.cc similarity index 98% rename from third_party/xla/xla/client/lib/matrix_test.cc rename to third_party/xla/xla/hlo/builder/lib/matrix_test.cc index caa313b4ab8923..debb6e20ae0108 100644 --- a/third_party/xla/xla/client/lib/matrix_test.cc +++ b/third_party/xla/xla/hlo/builder/lib/matrix_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/client/lib/matrix.h" +#include "xla/hlo/builder/lib/matrix.h" #include #include @@ -28,9 +28,9 @@ limitations under the License. #include "xla/array2d.h" #include "xla/array3d.h" #include "xla/array4d.h" -#include "xla/client/lib/constants.h" -#include "xla/client/lib/slicing.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/lib/constants.h" +#include "xla/hlo/builder/lib/slicing.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/test.h" #include "xla/tests/client_library_test_base.h" #include "xla/tests/test_macros.h" diff --git a/third_party/xla/xla/client/lib/pooling.cc b/third_party/xla/xla/hlo/builder/lib/pooling.cc similarity index 98% rename from third_party/xla/xla/client/lib/pooling.cc rename to third_party/xla/xla/hlo/builder/lib/pooling.cc index 5f03ad45afb0fd..81dd1a7c4c0f95 100644 --- a/third_party/xla/xla/client/lib/pooling.cc +++ b/third_party/xla/xla/hlo/builder/lib/pooling.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/client/lib/pooling.h" +#include "xla/hlo/builder/lib/pooling.h" #include #include @@ -22,11 +22,11 @@ limitations under the License. #include "absl/log/check.h" #include "absl/status/statusor.h" #include "absl/types/span.h" -#include "xla/client/lib/arithmetic.h" -#include "xla/client/lib/constants.h" -#include "xla/client/lib/conv_grad_size_util.h" -#include "xla/client/padding.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/lib/arithmetic.h" +#include "xla/hlo/builder/lib/constants.h" +#include "xla/hlo/builder/lib/conv_grad_size_util.h" +#include "xla/hlo/builder/padding.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/shape.h" #include "xla/util.h" #include "xla/xla_data.pb.h" diff --git a/third_party/xla/xla/hlo/builder/lib/pooling.h b/third_party/xla/xla/hlo/builder/lib/pooling.h new file mode 100644 index 00000000000000..15176888939c04 --- /dev/null +++ b/third_party/xla/xla/hlo/builder/lib/pooling.h @@ -0,0 +1,83 @@ +/* Copyright 2017 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_BUILDER_LIB_POOLING_H_ +#define XLA_HLO_BUILDER_LIB_POOLING_H_ + +#include +#include + +#include "absl/container/inlined_vector.h" +#include "absl/types/span.h" +#include "xla/hlo/builder/padding.h" +#include "xla/hlo/builder/xla_builder.h" + +namespace xla { + +// Tensor format for reduce window operations. +class TensorFormat { + public: + TensorFormat(int batch_dimension, int feature_dimension, + absl::Span spatial_dimensions) + : batch_dimension_(batch_dimension), + feature_dimension_(feature_dimension), + spatial_dimensions_(spatial_dimensions.begin(), + spatial_dimensions.end()) {} + + int batch_dimension() const { return batch_dimension_; } + + int feature_dimension() const { return feature_dimension_; } + + int spatial_dimension(int dim) const { return spatial_dimensions_[dim]; } + + int num_spatial_dims() const { return spatial_dimensions_.size(); } + + private: + // The number of the dimension that represents the batch. + int batch_dimension_; + // The number of the dimension that represents the features. + int feature_dimension_; + // The dimension numbers for the spatial dimensions. + absl::InlinedVector spatial_dimensions_; +}; + +// Computes the max pool of 'operand'. +XlaOp MaxPool(XlaOp operand, absl::Span kernel_size, + absl::Span stride, Padding padding, + const TensorFormat& data_format); + +// Computes the average pool of 'operand'. +XlaOp AvgPool(XlaOp operand, absl::Span kernel_size, + absl::Span stride, + absl::Span> padding, + const TensorFormat& data_format, bool counts_include_padding); + +// Returns the list of low and high padding elements in each spatial dimension +// for the given 'padding' specification. +std::vector> MakeSpatialPadding( + absl::Span input_size, absl::Span kernel_size, + absl::Span stride, Padding padding, + const TensorFormat& data_format); + +// Computes the average pool gradient. +XlaOp AvgPoolGrad(XlaOp out_backprop, absl::Span gradients_size, + absl::Span kernel_size, + absl::Span stride, + absl::Span> spatial_padding, + const TensorFormat& data_format, bool counts_include_padding); + +} // namespace xla + +#endif // XLA_HLO_BUILDER_LIB_POOLING_H_ diff --git a/third_party/xla/xla/client/lib/pooling_test.cc b/third_party/xla/xla/hlo/builder/lib/pooling_test.cc similarity index 99% rename from third_party/xla/xla/client/lib/pooling_test.cc rename to third_party/xla/xla/hlo/builder/lib/pooling_test.cc index 54ef5f43b49f94..97b874d81c04ce 100644 --- a/third_party/xla/xla/client/lib/pooling_test.cc +++ b/third_party/xla/xla/hlo/builder/lib/pooling_test.cc @@ -13,16 +13,16 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/client/lib/pooling.h" +#include "xla/hlo/builder/lib/pooling.h" #include #include #include "absl/container/inlined_vector.h" #include "absl/types/span.h" -#include "xla/client/padding.h" -#include "xla/client/xla_builder.h" #include "xla/error_spec.h" +#include "xla/hlo/builder/padding.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/shape.h" #include "xla/tests/client_library_test_base.h" #include "xla/tests/test_macros.h" diff --git a/third_party/xla/xla/client/lib/prng.cc b/third_party/xla/xla/hlo/builder/lib/prng.cc similarity index 99% rename from third_party/xla/xla/client/lib/prng.cc rename to third_party/xla/xla/hlo/builder/lib/prng.cc index 370382238adf4a..7bafd7bf5b8e22 100644 --- a/third_party/xla/xla/client/lib/prng.cc +++ b/third_party/xla/xla/hlo/builder/lib/prng.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/client/lib/prng.h" +#include "xla/hlo/builder/lib/prng.h" #include #include @@ -27,8 +27,8 @@ limitations under the License. #include "absl/log/log.h" #include "absl/status/statusor.h" #include "absl/types/span.h" -#include "xla/client/lib/constants.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/lib/constants.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/primitive_util.h" #include "xla/shape.h" #include "xla/shape_util.h" diff --git a/third_party/xla/xla/hlo/builder/lib/prng.h b/third_party/xla/xla/hlo/builder/lib/prng.h new file mode 100644 index 00000000000000..89b4dd62bbcd14 --- /dev/null +++ b/third_party/xla/xla/hlo/builder/lib/prng.h @@ -0,0 +1,101 @@ +/* Copyright 2018 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_BUILDER_LIB_PRNG_H_ +#define XLA_HLO_BUILDER_LIB_PRNG_H_ + +#include +#include +#include + +#include "absl/types/span.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/shape.h" +#include "xla/xla_data.pb.h" + +namespace xla { + +// Records the bits and state generated by a random number generator. +struct RngOutput { + XlaOp value; + XlaOp state; +}; + +// A BitGenerator returns random bits and updated random bit generator state. +// +// key: is a value input to a random number generator that can affect the +// sequence of number it will generate. A random number generator constructs +// its seed using the key and the initial state. The tf2xla bridge passes the +// seed operand of a tensorflow random operation as a key to the random bit +// generator, for example. +// initial_state: initial_state is the initial state of the current random +// number generation. It could be 0 for a stateless random operation, and +// the returned state from a previous execution for a stateful random +// operation. +// shape: the shape of the random bits. +using BitGeneratorTy = std::function; + +// Implements the ThreeFry counter-based PRNG algorithm. +// Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3. +// http://www.thesalmons.org/john/random123/papers/random123sc11.pdf +RngOutput ThreeFryBitGenerator(XlaOp key, XlaOp initial_state, + const xla::Shape& shape); + +// Implements the Philox algorithm to generate random numbers in parallel. +// Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3. +// http://www.thesalmons.org/john/random123/papers/random123sc11.pdf +// +// The paper presents a few variants of the Philox algorithm, we picked the +// 4x32_10 version of the algorithm for the following reasons: +// . 4x32 uses 32-bit multiplication which is fast on GPUs. +// . The authors recommend the 10-round variant, and TensorFlow also uses it. +RngOutput PhiloxBitGenerator(XlaOp key, XlaOp initial_state, + const Shape& shape); +// Returns a scrambled pair of (state, key) from a single key. +std::pair ScramblePhiloxKey(XlaOp key); + +// Uses the given bit generator to generate random bits and then converts the +// random bits to random numbers of uniform distribution in the given range. +// Returns the random numbers and the state of the random number generator. +// This function is for shape with floating point element types. +RngOutput UniformFloatingPointDistribution(XlaOp key, XlaOp initial_state, + BitGeneratorTy bit_generator, + XlaOp minval, XlaOp maxval, + const xla::Shape& shape); + +// Similar to UniformFloatingPointDistribution but for shape with integer +// element types. +RngOutput UniformIntDistribution(XlaOp key, XlaOp initial_state, + BitGeneratorTy bit_generator, XlaOp minval, + XlaOp maxval, const xla::Shape& shape); + +// Uses the given bit generator to generate random bits and then converts the +// random bits to random numbers of normal distribution. +// Returns the random numbers and the state of the random number generator. +RngOutput NormalFloatingPointDistribution(XlaOp key, XlaOp initial_state, + BitGeneratorTy bit_generator, + const xla::Shape& shape); + +// Concatenates scalars into a vector. +xla::XlaOp ConcatScalars(xla::XlaBuilder* builder, + absl::Span scalars); + +// Increases Philox counter (an uint128_t) by a delta (an uint64_t). +xla::XlaOp PhiloxIncreaseCounter(xla::XlaOp counter, xla::XlaOp delta); + +} // namespace xla + +#endif // XLA_HLO_BUILDER_LIB_PRNG_H_ diff --git a/third_party/xla/xla/client/lib/prng_test.cc b/third_party/xla/xla/hlo/builder/lib/prng_test.cc similarity index 97% rename from third_party/xla/xla/client/lib/prng_test.cc rename to third_party/xla/xla/hlo/builder/lib/prng_test.cc index 22241e9fab1da9..0e5f9772c35d26 100644 --- a/third_party/xla/xla/client/lib/prng_test.cc +++ b/third_party/xla/xla/hlo/builder/lib/prng_test.cc @@ -13,14 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/client/lib/prng.h" +#include "xla/hlo/builder/lib/prng.h" #include #include "absl/status/status.h" #include "absl/status/statusor.h" -#include "xla/client/lib/constants.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/lib/constants.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/primitive_util.h" #include "xla/shape.h" #include "xla/test.h" diff --git a/third_party/xla/xla/client/lib/qr.cc b/third_party/xla/xla/hlo/builder/lib/qr.cc similarity index 96% rename from third_party/xla/xla/client/lib/qr.cc rename to third_party/xla/xla/hlo/builder/lib/qr.cc index 794b2e4887f8b2..699e13b4c2e181 100644 --- a/third_party/xla/xla/client/lib/qr.cc +++ b/third_party/xla/xla/hlo/builder/lib/qr.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/client/lib/qr.h" +#include "xla/hlo/builder/lib/qr.h" #include #include @@ -21,10 +21,10 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/types/span.h" -#include "xla/client/lib/constants.h" -#include "xla/client/lib/matrix.h" -#include "xla/client/lib/slicing.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/lib/constants.h" +#include "xla/hlo/builder/lib/matrix.h" +#include "xla/hlo/builder/lib/slicing.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/util.h" diff --git a/third_party/xla/xla/hlo/builder/lib/qr.h b/third_party/xla/xla/hlo/builder/lib/qr.h new file mode 100644 index 00000000000000..6e4f3cc15fa4ec --- /dev/null +++ b/third_party/xla/xla/hlo/builder/lib/qr.h @@ -0,0 +1,52 @@ +/* Copyright 2018 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_BUILDER_LIB_QR_H_ +#define XLA_HLO_BUILDER_LIB_QR_H_ + +#include "xla/hlo/builder/xla_builder.h" +#include "xla/xla_data.pb.h" + +namespace xla { + +// Computes the QR decompositions of a batch of matrices. That is, +// given a (batched) matrix a, computes an orthonormal matrix Q and an +// upper-triangular matrix R such that a = QR. +// `a` must be a (batched) matrix of size [..., m, n]. +struct QrDecomposition { + // A matrix with the same shape as the input matrix `a`, whose upper triangle + // (inclusive of the diagonal) is the matrix R, and whose lower triangle + // (exclusive of the diagonal) contains the elementary Householder reflectors. + // This is the same output format as used by LAPACK's xGEQRF routine. + XlaOp q_and_r; + // A vector of shape [..., min(m, n)] containing the scalar factors of the + // elementary Householder reflectors. + XlaOp taus; +}; + +QrDecomposition Qr(XlaOp a); + +// Given `a` and `taus` as returned by `QRDecomposition`, compute the product of +// the elementary Householder reflectors (i.e., the matrix Q of the QR +// decomposition). The equivalent LAPACK routine is xORGQR/xUNGQR. +XlaOp ProductOfElementaryHouseholderReflectors(XlaOp a, XlaOp taus); + +// Helper that combines `Qr` and `ProductOfElementaryHouseholderReflectors` to +// compute explicit matrices `q` and `r`. +void QrExplicit(XlaOp a, bool full_matrices, XlaOp& q, XlaOp& r); + +} // namespace xla + +#endif // XLA_HLO_BUILDER_LIB_QR_H_ diff --git a/third_party/xla/xla/client/lib/qr_test.cc b/third_party/xla/xla/hlo/builder/lib/qr_test.cc similarity index 98% rename from third_party/xla/xla/client/lib/qr_test.cc rename to third_party/xla/xla/hlo/builder/lib/qr_test.cc index fc9e583ab9ad12..9f8e28e53cef66 100644 --- a/third_party/xla/xla/client/lib/qr_test.cc +++ b/third_party/xla/xla/hlo/builder/lib/qr_test.cc @@ -13,14 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/client/lib/qr.h" +#include "xla/hlo/builder/lib/qr.h" #include "xla/array.h" #include "xla/array2d.h" #include "xla/array3d.h" -#include "xla/client/lib/matrix.h" -#include "xla/client/xla_builder.h" #include "xla/error_spec.h" +#include "xla/hlo/builder/lib/matrix.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/test.h" diff --git a/third_party/xla/xla/hlo/builder/lib/quantize.h b/third_party/xla/xla/hlo/builder/lib/quantize.h new file mode 100644 index 00000000000000..d0126f0c021b2f --- /dev/null +++ b/third_party/xla/xla/hlo/builder/lib/quantize.h @@ -0,0 +1,184 @@ +/* Copyright 2018 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_BUILDER_LIB_QUANTIZE_H_ +#define XLA_HLO_BUILDER_LIB_QUANTIZE_H_ + +#include +#include +#include +#include + +#include "xla/hlo/builder/lib/constants.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/types.h" +#include "xla/util.h" +#include "xla/xla_data.pb.h" +#include "tsl/platform/bfloat16.h" + +namespace xla { + +// Represents the range used for quantization +struct QuantizedRange { + QuantizedRange() = default; + QuantizedRange(float min_in, float max_in) : min(min_in), max(max_in) {} + + bool operator==(const QuantizedRange& rhs) const { + return this->min == rhs.min && this->max == rhs.max; + } + + bool operator!=(const QuantizedRange& rhs) const { return !(*this == rhs); } + + tsl::bfloat16 min = tsl::bfloat16(0.0f); + tsl::bfloat16 max = tsl::bfloat16(0.0f); +}; + +template +inline std::vector PackToUint32(absl::Span input) { + const int64_t kElementsPerPack = sizeof(uint32_t) / sizeof(T); + const int64_t input_size = input.size(); + const int64_t output_size = CeilOfRatio(input_size, kElementsPerPack); + + std::vector output_vec; + constexpr int64_t kShiftBits = sizeof(T) / sizeof(uint8_t) * CHAR_BIT; + + for (int64_t i = 0; i < output_size; i++) { + uint32_t result = 0; + for (int64_t p = 0; p < kElementsPerPack; p++) { + int64_t index = i * kElementsPerPack + p; + if (index < input_size) { + int64_t total_shift_bits = kShiftBits * (kElementsPerPack - p - 1); + result |= (input[index] << total_shift_bits); + } + } + output_vec.push_back(result); + } + + return output_vec; +} + +// Dequantize the quantized input of packed uint32_t to bfloat16. +// Only uint8_t or uint16_t is supported for the original unpacked input. +// Returns a tensor of shape [d0,..., dn * unpack_size] if +// input shape is [d0, ..., dn], where unpack_size = sizeof(unit32) / sizeof(T). +// If transpose_output is true, will return a tensor of shape +// [dn * unpack_size, dn-1, ..., d1, d0]. transpose_output is faster when +// input's rank higher than 1. The input needs to be transposed to use +// transpose_output feature. +template +inline XlaOp Dequantize(XlaOp input, const QuantizedRange& range, + absl::string_view mode_string = "MIN_COMBINED", + bool transpose_output = false) { + XlaBuilder* const builder = input.builder(); + return builder->ReportErrorOrReturn([&]() -> absl::StatusOr { + float half_range = + !std::is_signed::value + ? 0.0f + : (static_cast(std::numeric_limits::max()) - + std::numeric_limits::min() + 1) / + 2.0f; + const int64_t unpack_size = sizeof(uint32_t) / sizeof(T); + TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(input)); + + auto element_type = shape.element_type(); + if (element_type != U32) { + return InvalidArgument( + "Only U32 is supported for input type of xla::Dequantize Op."); + } + + // Broadcast the input to [unpack_size, d0, ..., dn] if input size is + // [d0, ..., dn]. + auto broadcast_input = Broadcast(input, {unpack_size}); + + XlaOp iota_r1 = Iota(builder, U32, unpack_size); + // Highest significant bytes needs to shift more bytes than lower + // significant bytes. + XlaOp shift_bytes = + xla::ConstantR0(builder, unpack_size - 1) - iota_r1; + + const int bytes_of_type = sizeof(T) / sizeof(uint8_t); + std::vector shift_vec(unpack_size, CHAR_BIT * bytes_of_type); + XlaOp shift_bits = + shift_bytes * xla::ConstantR1(builder, shift_vec); + + // Make bit_mask for different data type T. + uint32_t bit_mask = 0x00000000; + for (int i = 0; i < bytes_of_type; i++) { + bit_mask <<= CHAR_BIT; + bit_mask |= 0x000000ff; + } + + std::vector shift_transpose_dimensions(shape.dimensions_size()); + std::iota(shift_transpose_dimensions.begin(), + shift_transpose_dimensions.end(), 0); + shift_transpose_dimensions.insert(shift_transpose_dimensions.begin(), 1, + shape.dimensions_size()); + + // Shift the input by sizeof(T) bytes and apply bit_mask to unpack. + XlaOp shifted_input = ShiftRightLogical( + broadcast_input, Transpose(Broadcast(shift_bits, shape.dimensions()), + shift_transpose_dimensions)); + XlaOp unpack_input = + And(shifted_input, xla::ConstantR0(builder, bit_mask)); + + XlaOp result; + + if (mode_string == "MIN_COMBINED") { + const tsl::bfloat16 scale_factor = + (range.max - range.min) / + (static_cast(std::numeric_limits::max() - + std::numeric_limits::min())); + // result = bfloat16(input + half_range) * scale_factor + range.min + XlaOp unpack_input_bf16 = ConvertElementType(unpack_input, BF16); + XlaOp half_range_bf16 = xla::ConstantR0( + builder, static_cast(half_range)); + XlaOp sum = unpack_input_bf16 + half_range_bf16; + + result = sum * xla::ConstantR0(builder, scale_factor) + + xla::ConstantR0(builder, range.min); + } else { + // TODO(wangtao): support other modes. + return InvalidArgument( + "Only MIN_COMBINED mode is supported in xla::Dequantize Op."); + } + + std::vector transpose_dimensions(shape.dimensions_size()); + std::iota(transpose_dimensions.begin(), transpose_dimensions.end(), 1); + std::reverse(transpose_dimensions.begin(), transpose_dimensions.end()); + transpose_dimensions.insert(transpose_dimensions.begin() + 1, 1, 0); + + // Transpose the result to be [dn, unpack_size, dn-1, ..., d1, d0]. + XlaOp transposed_result = Transpose(result, transpose_dimensions); + + // Reshape to be [dn * unpack_size, dn-1, ..., d1, d0]. + XlaOp reshaped_result = Collapse(transposed_result, {0, 1}); + + // Return the transpose result if transpose_output is true. + if (transpose_output) { + return reshaped_result; + } + + // Transpose the result to be [d0, d1, ..., dn-1, dn * unpack_size]. + std::vector result_dimensions(shape.dimensions_size()); + std::iota(result_dimensions.begin(), result_dimensions.end(), 0); + std::reverse(result_dimensions.begin(), result_dimensions.end()); + + return Transpose(reshaped_result, result_dimensions); + }); +} + +} // namespace xla + +#endif // XLA_HLO_BUILDER_LIB_QUANTIZE_H_ diff --git a/third_party/xla/xla/client/lib/quantize_test.cc b/third_party/xla/xla/hlo/builder/lib/quantize_test.cc similarity index 99% rename from third_party/xla/xla/client/lib/quantize_test.cc rename to third_party/xla/xla/hlo/builder/lib/quantize_test.cc index 6f371404f12869..6520bb4a07fef1 100644 --- a/third_party/xla/xla/client/lib/quantize_test.cc +++ b/third_party/xla/xla/hlo/builder/lib/quantize_test.cc @@ -13,13 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/client/lib/quantize.h" +#include "xla/hlo/builder/lib/quantize.h" #include #include #include "xla/array2d.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/test.h" #include "xla/tests/client_library_test_base.h" #include "xla/tests/test_macros.h" diff --git a/third_party/xla/xla/client/lib/self_adjoint_eig.cc b/third_party/xla/xla/hlo/builder/lib/self_adjoint_eig.cc similarity index 95% rename from third_party/xla/xla/client/lib/self_adjoint_eig.cc rename to third_party/xla/xla/hlo/builder/lib/self_adjoint_eig.cc index 05ba43f9fabcef..a7f3a3c00b6933 100644 --- a/third_party/xla/xla/client/lib/self_adjoint_eig.cc +++ b/third_party/xla/xla/hlo/builder/lib/self_adjoint_eig.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/client/lib/self_adjoint_eig.h" +#include "xla/hlo/builder/lib/self_adjoint_eig.h" #include #include @@ -21,8 +21,8 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_format.h" -#include "xla/client/lib/slicing.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/lib/slicing.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/primitive_util.h" #include "xla/shape.h" #include "xla/shape_util.h" diff --git a/third_party/xla/xla/hlo/builder/lib/self_adjoint_eig.h b/third_party/xla/xla/hlo/builder/lib/self_adjoint_eig.h new file mode 100644 index 00000000000000..f0dffdc41218bf --- /dev/null +++ b/third_party/xla/xla/hlo/builder/lib/self_adjoint_eig.h @@ -0,0 +1,41 @@ +/* Copyright 2019 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_BUILDER_LIB_SELF_ADJOINT_EIG_H_ +#define XLA_HLO_BUILDER_LIB_SELF_ADJOINT_EIG_H_ + +#include "xla/hlo/builder/xla_builder.h" +#include "xla/xla_data.pb.h" + +namespace xla { + +// The eigenvalue decomposition of a symmetric matrix, the original matrix is +// recovered by v * w * v_t. +struct SelfAdjointEigResult { + // The i-th column is the normalized eigenvector corresponding to the + // eigenvalue w[i]. Will return a matrix object if a is a matrix object. + XlaOp v; + // The eigenvalues in ascending order, each repeated according to its + // multiplicity. + XlaOp w; +}; + +SelfAdjointEigResult SelfAdjointEig(XlaOp a, bool lower = true, + int64_t max_iter = 15, float tol = 1e-5, + bool sort_eigenvalues = true); + +} // namespace xla + +#endif // XLA_HLO_BUILDER_LIB_SELF_ADJOINT_EIG_H_ diff --git a/third_party/xla/xla/client/lib/self_adjoint_eig_test.cc b/third_party/xla/xla/hlo/builder/lib/self_adjoint_eig_test.cc similarity index 97% rename from third_party/xla/xla/client/lib/self_adjoint_eig_test.cc rename to third_party/xla/xla/hlo/builder/lib/self_adjoint_eig_test.cc index 4265c844d3848f..cb4645aea60bbf 100644 --- a/third_party/xla/xla/client/lib/self_adjoint_eig_test.cc +++ b/third_party/xla/xla/hlo/builder/lib/self_adjoint_eig_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/client/lib/self_adjoint_eig.h" +#include "xla/hlo/builder/lib/self_adjoint_eig.h" #include #include @@ -25,12 +25,12 @@ limitations under the License. #include "xla/array.h" #include "xla/array2d.h" #include "xla/array3d.h" -#include "xla/client/lib/arithmetic.h" -#include "xla/client/lib/constants.h" -#include "xla/client/lib/math.h" -#include "xla/client/lib/matrix.h" -#include "xla/client/xla_builder.h" #include "xla/error_spec.h" +#include "xla/hlo/builder/lib/arithmetic.h" +#include "xla/hlo/builder/lib/constants.h" +#include "xla/hlo/builder/lib/math.h" +#include "xla/hlo/builder/lib/matrix.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/test.h" diff --git a/third_party/xla/xla/client/lib/slicing.cc b/third_party/xla/xla/hlo/builder/lib/slicing.cc similarity index 98% rename from third_party/xla/xla/client/lib/slicing.cc rename to third_party/xla/xla/hlo/builder/lib/slicing.cc index 26c5ea59ff1931..42dd4c8a82d188 100644 --- a/third_party/xla/xla/client/lib/slicing.cc +++ b/third_party/xla/xla/hlo/builder/lib/slicing.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/client/lib/slicing.h" +#include "xla/hlo/builder/lib/slicing.h" #include #include @@ -23,9 +23,9 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/types/span.h" -#include "xla/client/lib/arithmetic.h" -#include "xla/client/lib/constants.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/lib/arithmetic.h" +#include "xla/hlo/builder/lib/constants.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/status_macros.h" diff --git a/third_party/xla/xla/hlo/builder/lib/slicing.h b/third_party/xla/xla/hlo/builder/lib/slicing.h new file mode 100644 index 00000000000000..dfb880805d2153 --- /dev/null +++ b/third_party/xla/xla/hlo/builder/lib/slicing.h @@ -0,0 +1,83 @@ +/* Copyright 2018 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "absl/types/span.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/types.h" + +#ifndef XLA_HLO_BUILDER_LIB_SLICING_H_ +#define XLA_HLO_BUILDER_LIB_SLICING_H_ + +namespace xla { + +// Updates a slice of 'x', i.e., +// x[start[0], ..., start[n]] = update +XlaOp UpdateSlice(XlaOp x, XlaOp update, absl::Span start); + +// Performs a slice in the minor dimensions of a tensor. +// x[..., start[0]:end[0], ..., start[n]:end[n]] +XlaOp SliceInMinorDims(XlaOp x, absl::Span start, + absl::Span end); + +// Updates a slice of 'x', where 'start' contains a list of minor dimensions: +// x[..., start[0]:..., ..., start[n]:...] = update +XlaOp UpdateSliceInMinorDims(XlaOp x, XlaOp update, + absl::Span start); + +// Performs a dynamic slice in the minor dimensions of a tensor. +XlaOp DynamicSliceInMinorDims(XlaOp x, absl::Span starts, + absl::Span sizes); + +XlaOp DynamicUpdateSliceInMinorDims(XlaOp x, XlaOp update, + absl::Span starts); + +// Gathers values along an axis specified by dim. +// +// For a 3-D tensor the output is specified by: +// +// out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0 +// out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1 +// out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2 +// +// If `input` is an n-dimensional tensor with size +// [X0,X1,X2,..XN] and dim = i `index` must be an n-dimensional tensor with size +// [X0,X1,...Y,Xi+1,...,X[N] where y >= 1 and `out` will have the same sizes as +// `index`. +XlaOp TorchGather(XlaOp input, XlaOp index, int64_t dim, bool sparse = true); + +// idx = index[i][j][k] +// output[idx][j][k] = combiner(input[idx][j][k], src[i][j][k]) # if dim == 0 +// output[i][idx][k] = combiner(input[i][idx][k], src[i][j][k]) # if dim == 1 +// output[i][j][idx] = combiner(input[i][j][idx], src[i][j][k]) # if dim == 2 +XlaOp TorchScatterDense(XlaOp input, XlaOp index, XlaOp src, int64_t dim, + const std::function& combiner); + +// Returns a new tensor which indexes the input tensor along dimension dim using +// the entries in index. +// +// The returned tensor has the same number of dimensions as the original tensor +// (input). The dimth dimension has the same size as the length of index; other +// dimensions have the same size as in the original tensor. +// +// This operation supports 0 or more major batch dimensions that act like a +// multidimensional loop over both the input and the index. +XlaOp TorchIndexSelect(XlaOp input, XlaOp index, int64_t dim, + int64_t batch_dims = 0); + +} // namespace xla + +#endif // XLA_HLO_BUILDER_LIB_SLICING_H_ diff --git a/third_party/xla/xla/client/lib/slicing_test.cc b/third_party/xla/xla/hlo/builder/lib/slicing_test.cc similarity index 99% rename from third_party/xla/xla/client/lib/slicing_test.cc rename to third_party/xla/xla/hlo/builder/lib/slicing_test.cc index 8dfc55f521f089..72e8e1ca7026d8 100644 --- a/third_party/xla/xla/client/lib/slicing_test.cc +++ b/third_party/xla/xla/hlo/builder/lib/slicing_test.cc @@ -13,12 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/client/lib/slicing.h" +#include "xla/hlo/builder/lib/slicing.h" #include "xla/array2d.h" #include "xla/array3d.h" -#include "xla/client/xla_builder.h" #include "xla/error_spec.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/literal.h" #include "xla/shape_util.h" #include "xla/tests/client_library_test_base.h" diff --git a/third_party/xla/xla/client/lib/sorting.cc b/third_party/xla/xla/hlo/builder/lib/sorting.cc similarity index 97% rename from third_party/xla/xla/client/lib/sorting.cc rename to third_party/xla/xla/hlo/builder/lib/sorting.cc index 48eec5d5ff2f7c..456accc515e111 100644 --- a/third_party/xla/xla/client/lib/sorting.cc +++ b/third_party/xla/xla/hlo/builder/lib/sorting.cc @@ -13,17 +13,17 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/client/lib/sorting.h" +#include "xla/hlo/builder/lib/sorting.h" #include #include "absl/status/statusor.h" #include "absl/types/span.h" -#include "xla/client/lib/comparators.h" -#include "xla/client/lib/constants.h" -#include "xla/client/lib/loops.h" -#include "xla/client/lib/slicing.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/lib/comparators.h" +#include "xla/hlo/builder/lib/constants.h" +#include "xla/hlo/builder/lib/loops.h" +#include "xla/hlo/builder/lib/slicing.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/util.h" diff --git a/third_party/xla/xla/hlo/builder/lib/sorting.h b/third_party/xla/xla/hlo/builder/lib/sorting.h new file mode 100644 index 00000000000000..b951f26b97b043 --- /dev/null +++ b/third_party/xla/xla/hlo/builder/lib/sorting.h @@ -0,0 +1,38 @@ +/* Copyright 2018 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_BUILDER_LIB_SORTING_H_ +#define XLA_HLO_BUILDER_LIB_SORTING_H_ + +#include "xla/hlo/builder/xla_builder.h" +#include "xla/types.h" +#include "xla/xla_data.pb.h" + +namespace xla { + +// Returns a tuple composed of the top `k` values and corresponding indices in +// `input`. Output values are in descending order, from largest to smallest. +XlaOp TopK(XlaOp input, int64_t k, + PrimitiveType index_type = PrimitiveType::S32); + +// Split sort in TopK into smaller sorts. +// Returns a tuple composed of the top `k` values and corresponding indices in +// `input`. Output values are in descending order, from largest to smallest. +XlaOp TopKWithPartitions(XlaOp input, int64_t k, int64_t num_partitions = 1, + PrimitiveType index_type = PrimitiveType::S32); + +} // namespace xla + +#endif // XLA_HLO_BUILDER_LIB_SORTING_H_ diff --git a/third_party/xla/xla/client/lib/sorting_test.cc b/third_party/xla/xla/hlo/builder/lib/sorting_test.cc similarity index 98% rename from third_party/xla/xla/client/lib/sorting_test.cc rename to third_party/xla/xla/hlo/builder/lib/sorting_test.cc index 02eeff7ad80f22..2230eb73ecc4fb 100644 --- a/third_party/xla/xla/client/lib/sorting_test.cc +++ b/third_party/xla/xla/hlo/builder/lib/sorting_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/client/lib/sorting.h" +#include "xla/hlo/builder/lib/sorting.h" #include #include @@ -24,8 +24,8 @@ limitations under the License. #include "absl/algorithm/container.h" #include "xla/array.h" #include "xla/array2d.h" -#include "xla/client/xla_builder.h" #include "xla/error_spec.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/literal_util.h" #include "xla/tests/client_library_test_base.h" #include "xla/tests/test_macros.h" diff --git a/third_party/xla/xla/client/lib/svd.cc b/third_party/xla/xla/hlo/builder/lib/svd.cc similarity index 98% rename from third_party/xla/xla/client/lib/svd.cc rename to third_party/xla/xla/hlo/builder/lib/svd.cc index 88afe31e2ed0c3..22e4ab8d039bdc 100644 --- a/third_party/xla/xla/client/lib/svd.cc +++ b/third_party/xla/xla/hlo/builder/lib/svd.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/client/lib/svd.h" +#include "xla/hlo/builder/lib/svd.h" #include #include @@ -24,14 +24,14 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" -#include "xla/client/lib/arithmetic.h" -#include "xla/client/lib/comparators.h" -#include "xla/client/lib/constants.h" -#include "xla/client/lib/loops.h" -#include "xla/client/lib/math.h" -#include "xla/client/lib/matrix.h" -#include "xla/client/lib/slicing.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/lib/arithmetic.h" +#include "xla/hlo/builder/lib/comparators.h" +#include "xla/hlo/builder/lib/constants.h" +#include "xla/hlo/builder/lib/loops.h" +#include "xla/hlo/builder/lib/math.h" +#include "xla/hlo/builder/lib/matrix.h" +#include "xla/hlo/builder/lib/slicing.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/xla_data.pb.h" diff --git a/third_party/xla/xla/hlo/builder/lib/svd.h b/third_party/xla/xla/hlo/builder/lib/svd.h new file mode 100644 index 00000000000000..42d165f766ab43 --- /dev/null +++ b/third_party/xla/xla/hlo/builder/lib/svd.h @@ -0,0 +1,49 @@ +/* Copyright 2019 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_BUILDER_LIB_SVD_H_ +#define XLA_HLO_BUILDER_LIB_SVD_H_ + +#include "xla/hlo/builder/xla_builder.h" +#include "xla/xla_data.pb.h" + +namespace xla { + +// The singular value decomposition of a given matrix A[..., M, N], the original +// matrix is recovered by u * diag(d) * v_t, where the first dims(A) - 2 +// dimensions are batch dimensions. +struct SVDResult { + // The columns of U are the left-singular vectors, e.g., + // U[..., :, :]_T * U[..., :, :] = I. + XlaOp u; + // Vector(s) with the singular values, within each vector sorted in descending + // order. The first dims(D) - 1 dimensions have the same size as the batch + // dimensions of A. And U[..., :, i] * D[..., i] = A[..., :, :] * V[..., :, + // i]. + XlaOp d; + // The columns of V are the right-singular vectors. e.g., + // V[..., :, :]_T * V[..., :, :] = I. + XlaOp v; +}; + +// TODO(kuny): Add a bool flag that supports SVD with economy (reduced) +// representation, which is more memory efficient, especially in the case of +// tall-skinny matrices. +SVDResult SVD(XlaOp a, int64_t max_iter = 100, float epsilon = 1e-6, + PrecisionConfig::Precision precision = PrecisionConfig::HIGHEST); + +} // namespace xla + +#endif // XLA_HLO_BUILDER_LIB_SVD_H_ diff --git a/third_party/xla/xla/client/lib/svd_test.cc b/third_party/xla/xla/hlo/builder/lib/svd_test.cc similarity index 97% rename from third_party/xla/xla/client/lib/svd_test.cc rename to third_party/xla/xla/hlo/builder/lib/svd_test.cc index f1a7fc62a1c2e4..7266cde21684fe 100644 --- a/third_party/xla/xla/client/lib/svd_test.cc +++ b/third_party/xla/xla/hlo/builder/lib/svd_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/client/lib/svd.h" +#include "xla/hlo/builder/lib/svd.h" #include #include @@ -22,12 +22,12 @@ limitations under the License. #include "absl/status/statusor.h" #include "xla/array2d.h" #include "xla/array3d.h" -#include "xla/client/lib/arithmetic.h" -#include "xla/client/lib/constants.h" -#include "xla/client/lib/matrix.h" -#include "xla/client/lib/slicing.h" -#include "xla/client/xla_builder.h" #include "xla/error_spec.h" +#include "xla/hlo/builder/lib/arithmetic.h" +#include "xla/hlo/builder/lib/constants.h" +#include "xla/hlo/builder/lib/matrix.h" +#include "xla/hlo/builder/lib/slicing.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/tests/client_library_test_base.h" diff --git a/third_party/xla/xla/client/lib/tridiagonal.cc b/third_party/xla/xla/hlo/builder/lib/tridiagonal.cc similarity index 99% rename from third_party/xla/xla/client/lib/tridiagonal.cc rename to third_party/xla/xla/hlo/builder/lib/tridiagonal.cc index 4d4a4604e5ce23..9538a742e4cfce 100644 --- a/third_party/xla/xla/client/lib/tridiagonal.cc +++ b/third_party/xla/xla/hlo/builder/lib/tridiagonal.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/client/lib/tridiagonal.h" +#include "xla/hlo/builder/lib/tridiagonal.h" #include #include @@ -24,10 +24,10 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/types/span.h" -#include "xla/client/lib/constants.h" -#include "xla/client/lib/loops.h" -#include "xla/client/lib/slicing.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/lib/constants.h" +#include "xla/hlo/builder/lib/loops.h" +#include "xla/hlo/builder/lib/slicing.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/util.h" diff --git a/third_party/xla/xla/hlo/builder/lib/tridiagonal.h b/third_party/xla/xla/hlo/builder/lib/tridiagonal.h new file mode 100644 index 00000000000000..d6bf56c009c2a7 --- /dev/null +++ b/third_party/xla/xla/hlo/builder/lib/tridiagonal.h @@ -0,0 +1,43 @@ +/* Copyright 2019 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_BUILDER_LIB_TRIDIAGONAL_H_ +#define XLA_HLO_BUILDER_LIB_TRIDIAGONAL_H_ + +#include "absl/status/statusor.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/xla_data.pb.h" + +namespace xla { +namespace tridiagonal { + +enum SolverAlgorithm { kThomas }; + +absl::StatusOr TridiagonalSolver(SolverAlgorithm algo, + XlaOp lower_diagonal, + XlaOp main_diagonal, + XlaOp upper_diagonal, XlaOp rhs); + +absl::StatusOr TridiagonalSolver(SolverAlgorithm algo, XlaOp diagonals, + XlaOp rhs); + +absl::StatusOr TridiagonalMatMul(XlaOp upper_diagonal, + XlaOp main_diagonal, + XlaOp lower_diagonal, XlaOp rhs); + +} // namespace tridiagonal +} // namespace xla + +#endif // XLA_HLO_BUILDER_LIB_TRIDIAGONAL_H_ diff --git a/third_party/xla/xla/client/lib/tridiagonal_test.cc b/third_party/xla/xla/hlo/builder/lib/tridiagonal_test.cc similarity index 98% rename from third_party/xla/xla/client/lib/tridiagonal_test.cc rename to third_party/xla/xla/hlo/builder/lib/tridiagonal_test.cc index 280e4dd8ec17ae..5948c8840303e1 100644 --- a/third_party/xla/xla/client/lib/tridiagonal_test.cc +++ b/third_party/xla/xla/hlo/builder/lib/tridiagonal_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/client/lib/tridiagonal.h" +#include "xla/hlo/builder/lib/tridiagonal.h" #include #include @@ -22,8 +22,8 @@ limitations under the License. #include "absl/status/status.h" #include "xla/array.h" #include "xla/array3d.h" -#include "xla/client/lib/slicing.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/lib/slicing.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/literal.h" #include "xla/shape_util.h" #include "xla/test.h" diff --git a/third_party/xla/xla/client/lib/tuple.cc b/third_party/xla/xla/hlo/builder/lib/tuple.cc similarity index 96% rename from third_party/xla/xla/client/lib/tuple.cc rename to third_party/xla/xla/hlo/builder/lib/tuple.cc index 4cefa748bc8d04..6a0145addefbde 100644 --- a/third_party/xla/xla/client/lib/tuple.cc +++ b/third_party/xla/xla/hlo/builder/lib/tuple.cc @@ -13,13 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/client/lib/tuple.h" +#include "xla/hlo/builder/lib/tuple.h" #include #include "absl/container/inlined_vector.h" #include "absl/status/statusor.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/shape.h" #include "xla/shape_tree.h" #include "xla/shape_util.h" diff --git a/third_party/xla/xla/hlo/builder/lib/tuple.h b/third_party/xla/xla/hlo/builder/lib/tuple.h new file mode 100644 index 00000000000000..11d7d022806aef --- /dev/null +++ b/third_party/xla/xla/hlo/builder/lib/tuple.h @@ -0,0 +1,36 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_BUILDER_LIB_TUPLE_H_ +#define XLA_HLO_BUILDER_LIB_TUPLE_H_ + +#include "absl/status/statusor.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/shape_tree.h" + +namespace xla { + +// Returns a ShapeTree where each index is a GetTupleElement instruction for +// that subshape of the tuple. The root index is the original argument. +absl::StatusOr> DisassembleTuple(XlaOp tuple); + +// Assembles a tuple from a ShapeTree that contains the leaves of the tuple. +// Non-leaf elements of the ShapeTree are ignored. DisassembleTuple and +// AssembleTuple are essentially inverse operations. +XlaOp AssembleTuple(XlaBuilder* builder, ShapeTree elements); + +} // namespace xla + +#endif // XLA_HLO_BUILDER_LIB_TUPLE_H_ diff --git a/third_party/xla/xla/client/lib/tuple_test.cc b/third_party/xla/xla/hlo/builder/lib/tuple_test.cc similarity index 97% rename from third_party/xla/xla/client/lib/tuple_test.cc rename to third_party/xla/xla/hlo/builder/lib/tuple_test.cc index cb2cab8abd0bed..67f270300acce4 100644 --- a/third_party/xla/xla/client/lib/tuple_test.cc +++ b/third_party/xla/xla/hlo/builder/lib/tuple_test.cc @@ -13,14 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/client/lib/tuple.h" +#include "xla/hlo/builder/lib/tuple.h" #include #include #include -#include "xla/client/xla_builder.h" #include "xla/error_spec.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/service/service.h" diff --git a/third_party/xla/xla/client/padding.cc b/third_party/xla/xla/hlo/builder/padding.cc similarity index 99% rename from third_party/xla/xla/client/padding.cc rename to third_party/xla/xla/hlo/builder/padding.cc index 8f4a536c0805e4..b8951735619e92 100644 --- a/third_party/xla/xla/client/padding.cc +++ b/third_party/xla/xla/hlo/builder/padding.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/client/padding.h" +#include "xla/hlo/builder/padding.h" #include #include diff --git a/third_party/xla/xla/hlo/builder/padding.h b/third_party/xla/xla/hlo/builder/padding.h new file mode 100644 index 00000000000000..b0c83b7587a1ef --- /dev/null +++ b/third_party/xla/xla/hlo/builder/padding.h @@ -0,0 +1,66 @@ +/* Copyright 2017 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_BUILDER_PADDING_H_ +#define XLA_HLO_BUILDER_PADDING_H_ + +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/types/span.h" + +namespace xla { + +// Describes the padding applied for a windowed operation like +// convolution, where a window is placed inside a base area. +enum class Padding { + // Make the output have the same dimensions as the base area. For + // example, for a 3x3 base area and a 2x2 window, the output will be + // 3x3, so that requires padding the 3x3 base area to 4x4. + kSame, + + // Use no padding. For example, for a 4x4 base area and a 2x2 + // window, the output will be 3x3. + kValid, +}; + +// Validates that the slices are acceptable for determining padding -- this can +// be used to check the preconditions of MakePadding below to produce an error +// message that can be returned to the user. +absl::Status ValidatePaddingValues(absl::Span input_dimensions, + absl::Span window_dimensions, + absl::Span window_strides); + +// Returns the padding needed for the base area, given the base area dimensions, +// window dimensions, strides, and the type of padding. +// +// If v is the returned vector, then for each dimension number i, +// v[i].first is the padding to the left (i.e. in the direction of +// lower indices) and v[i].second is the padding to the right (i.e. in +// the direction of higher indices). +// +// Precondition: The number of dimensions (i.e., rank) in input_dimensions, +// window_dimensions, and strides must match, which is equal to the number +// of elements in the result vector. +std::vector> MakePadding( + absl::Span input_dimensions, + absl::Span window_dimensions, + absl::Span window_strides, Padding padding); + +} // namespace xla + +#endif // XLA_HLO_BUILDER_PADDING_H_ diff --git a/third_party/xla/xla/client/padding_test.cc b/third_party/xla/xla/hlo/builder/padding_test.cc similarity index 98% rename from third_party/xla/xla/client/padding_test.cc rename to third_party/xla/xla/hlo/builder/padding_test.cc index 0d183d0e16ede9..2d06a84cd3da4e 100644 --- a/third_party/xla/xla/client/padding_test.cc +++ b/third_party/xla/xla/hlo/builder/padding_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/client/padding.h" +#include "xla/hlo/builder/padding.h" #include diff --git a/third_party/xla/xla/client/sharding_builder.cc b/third_party/xla/xla/hlo/builder/sharding_builder.cc similarity index 98% rename from third_party/xla/xla/client/sharding_builder.cc rename to third_party/xla/xla/hlo/builder/sharding_builder.cc index 7b179b8c91ee4a..2c01cb16203c2b 100644 --- a/third_party/xla/xla/client/sharding_builder.cc +++ b/third_party/xla/xla/hlo/builder/sharding_builder.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/client/sharding_builder.h" +#include "xla/hlo/builder/sharding_builder.h" #include diff --git a/third_party/xla/xla/hlo/builder/sharding_builder.h b/third_party/xla/xla/hlo/builder/sharding_builder.h new file mode 100644 index 00000000000000..245ab9e2c004d7 --- /dev/null +++ b/third_party/xla/xla/hlo/builder/sharding_builder.h @@ -0,0 +1,60 @@ +/* Copyright 2017 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_BUILDER_SHARDING_BUILDER_H_ +#define XLA_HLO_BUILDER_SHARDING_BUILDER_H_ + +#include + +#include "xla/array.h" +#include "xla/shape.h" +#include "xla/shape_tree.h" +#include "xla/xla_data.pb.h" + +namespace xla { +namespace sharding_builder { +// A shaped array used to describe the assignment of tiles to devices. +using TileAssignment = Array; + +// Creates a replicated sharding - replicate a tensor on every device. +OpSharding Replicate(); + +// Creates a manual sharding - the partitioner will not change the shape. +OpSharding Manual(); + +// Creates a sharding that assigns a tensor to just one device. +OpSharding AssignDevice(int device); + +// Creates a tiled sharding with the given tile shape and assignment of tiles +// to devices. +// +// If tile_shape is not evenly divisible by the number of devices in +// tile_assignment, operations behave as if implicit padding had been inserted. +// The value of this padding is undefined. +OpSharding Tile(const Shape& tile_shape, const TileAssignment& tile_assignment); + +// Creates a sharding in one dimension, with the given tile shape which must +// be rank 1 and using devices [0..num_tiles). +// +// This is simply a convenience wrapper for Tile(). +OpSharding Tile1D(const Shape& tile_shape, int64_t num_tiles); + +// Creates a tuple sharding from the given ShapeTree of element shardings. +OpSharding Tuple(const ShapeTree& shardings); + +} // namespace sharding_builder +} // namespace xla + +#endif // XLA_HLO_BUILDER_SHARDING_BUILDER_H_ diff --git a/third_party/xla/xla/client/value_inference.cc b/third_party/xla/xla/hlo/builder/value_inference.cc similarity index 99% rename from third_party/xla/xla/client/value_inference.cc rename to third_party/xla/xla/hlo/builder/value_inference.cc index 2f0b6e20756bff..165ee203042443 100644 --- a/third_party/xla/xla/client/value_inference.cc +++ b/third_party/xla/xla/hlo/builder/value_inference.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/client/value_inference.h" +#include "xla/hlo/builder/value_inference.h" #include #include @@ -29,8 +29,8 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_format.h" #include "absl/types/span.h" -#include "xla/client/xla_builder.h" #include "xla/comparison_util.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/hlo/evaluator/hlo_evaluator.h" #include "xla/hlo/ir/dfs_hlo_visitor.h" #include "xla/hlo/ir/hlo_computation.h" diff --git a/third_party/xla/xla/hlo/builder/value_inference.h b/third_party/xla/xla/hlo/builder/value_inference.h new file mode 100644 index 00000000000000..7f69a5979553dc --- /dev/null +++ b/third_party/xla/xla/hlo/builder/value_inference.h @@ -0,0 +1,117 @@ +/* Copyright 2021 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef XLA_HLO_BUILDER_VALUE_INFERENCE_H_ +#define XLA_HLO_BUILDER_VALUE_INFERENCE_H_ + +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/log/check.h" +#include "absl/status/statusor.h" +#include "absl/types/span.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/evaluator/hlo_evaluator.h" +#include "xla/hlo/ir/dfs_hlo_visitor.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/literal.h" +#include "xla/literal_util.h" +#include "xla/shape_util.h" +#include "xla/util.h" +#include "xla/xla_data.pb.h" + +namespace xla { +// OptionalLiteral is an augmented literal class which returns optional +// values for each index (the value can be either valid or invalid). The +// implementation keeps two literals, a value literal, holding both the valid +// and garabage value, and a masking literal representing if a value is valid or +// garbage. +class OptionalLiteral { + public: + explicit OptionalLiteral(Literal value, Literal mask) + : value_(std::move(value)), mask_(std::move(mask)) {} + + template + std::optional Get(absl::Span element_index, + ShapeIndex shape_index = {}) const { + if (mask_.Get(element_index, shape_index)) { + return std::nullopt; + } else { + return value_.Get(element_index, shape_index); + } + } + + // Returns true if all values in this literal slice are value. + bool AllValid() { return mask_.IsAll(0); } + + // Get value out of this slice if all values are valid. Otherwise returns + // nullopt. + std::optional GetValue() { + if (!AllValid()) { + return std::nullopt; + } + return LiteralSlice(value_); + } + + private: + Literal value_; + Literal mask_; +}; + +enum ValueInferenceMode { + // Inference the constant value itself. + kValue = 0, + // Inference upper-bound and lower-bound of the value. Bounds are inclusive. + kUpperBound, + kLowerBound, +}; + +class ValueInference { + public: + // ValueInference analyzes values in XlaOp answers following questions: + // - What's the upper-bound of each value in a tensor. + // - What's the lower-bound of each value in a tensor. + // - What's the constant value of each tensor. + // - Whether or not each value in a tensor is dynamic. + explicit ValueInference(XlaBuilder* builder) : builder_(builder) { + CHECK(builder_); + } + absl::StatusOr AnalyzeIsDynamic(XlaOp op); + // Returns an OptionalLiteral. Each individual value of the literal is + // the concrete constant value if it can be inferred, otherwise a nullopt. + absl::StatusOr AnalyzeConstant(XlaOp op, + ValueInferenceMode mode); + + // Returns underlying xla builder. + XlaBuilder* builder() { return builder_; } + + private: + // Given an op handle, returns a simplified version of the handle inside a + // int64_t Literal. If the a -1 value for the handle means invalid + // simplification and the result shouldn't be used. + absl::StatusOr SimplifyOp(int64_t handle); + + // Perform CSE on a given handle, and return an equivalent handle if seen + // before. Otherwise, returns nullopt. + absl::StatusOr> CseOpHandle(int64_t handle); + XlaBuilder* builder_; + HloEvaluator evaluator_; + // A map from instruction_hash to handle that helps perform CSE. + absl::flat_hash_map cse_map_; +}; +} // namespace xla + +#endif // XLA_HLO_BUILDER_VALUE_INFERENCE_H_ diff --git a/third_party/xla/xla/client/xla_builder.cc b/third_party/xla/xla/hlo/builder/xla_builder.cc similarity index 99% rename from third_party/xla/xla/client/xla_builder.cc rename to third_party/xla/xla/hlo/builder/xla_builder.cc index 98e7dada978400..140addbcc026a4 100644 --- a/third_party/xla/xla/client/xla_builder.cc +++ b/third_party/xla/xla/hlo/builder/xla_builder.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include #include @@ -44,10 +44,10 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/array.h" -#include "xla/client/padding.h" -#include "xla/client/sharding_builder.h" -#include "xla/client/xla_computation.h" #include "xla/comparison_util.h" +#include "xla/hlo/builder/padding.h" +#include "xla/hlo/builder/sharding_builder.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/hlo/ir/hlo_input_output_alias_config.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_sharding.h" diff --git a/third_party/xla/xla/hlo/builder/xla_builder.h b/third_party/xla/xla/hlo/builder/xla_builder.h new file mode 100644 index 00000000000000..891d4fec725c69 --- /dev/null +++ b/third_party/xla/xla/hlo/builder/xla_builder.h @@ -0,0 +1,3086 @@ +/* Copyright 2018 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_BUILDER_XLA_BUILDER_H_ +#define XLA_HLO_BUILDER_XLA_BUILDER_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/functional/function_ref.h" +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "xla/array.h" +#include "xla/array2d.h" +#include "xla/array3d.h" +#include "xla/array4d.h" +#include "xla/comparison_util.h" +#include "xla/hlo/builder/padding.h" +#include "xla/hlo/builder/xla_computation.h" +#include "xla/hlo/ir/dynamic_parameter_binding.h" +#include "xla/hlo/ir/hlo_input_output_alias_config.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/layout.h" +#include "xla/literal.h" +#include "xla/literal_util.h" +#include "xla/service/hlo.pb.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/tsl/lib/core/bitmap.h" +#include "xla/util.h" +#include "xla/xla_data.pb.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/stacktrace.h" + +namespace xla { + +class XlaBuilder; +class XlaOp; +class HloInstruction; + +namespace internal { + +struct XlaBuilderFriend { + static XlaOp BuildAddDependency(XlaBuilder* builder, XlaOp operand, + XlaOp token, const Shape& shape); + + static std::pair BuildAsyncStart( + XlaBuilder* builder, absl::Span operands, + std::string execution_thread, const XlaComputation& called_computation, + const Shape& shape); + static XlaOp BuildAsyncUpdate(XlaBuilder* builder, XlaOp operands, + const Shape& shape); + static XlaOp BuildAsyncDone(XlaBuilder* builder, XlaOp operands, + const Shape& shape); + + static XlaOp BuildAllGatherStart( + XlaBuilder* builder, XlaOp operand, int64_t all_gather_dimension, + int64_t shard_count, absl::Span replica_groups = {}, + const std::optional& channel_id = std::nullopt, + const std::optional& layout = std::nullopt, + std::optional use_global_device_ids = std::nullopt); + static XlaOp BuildAllGatherDone(XlaBuilder* builder, XlaOp operands, + const Shape& shape); + + static XlaOp BuildAllReduceStart( + XlaBuilder* builder, XlaOp operand, const XlaComputation& computation, + absl::Span replica_groups = {}, + const std::optional& channel_id = std::nullopt, + const std::optional& layout = std::nullopt, + std::optional use_global_device_ids = std::nullopt); + static XlaOp BuildAllReduceDone(XlaBuilder* builder, XlaOp operands, + const Shape& shape); + + static XlaOp BuildCollectivePermuteStart( + XlaBuilder* builder, XlaOp operand, + const std::vector>& source_target_pairs, + const std::optional& channel_id = std::nullopt); + static XlaOp BuildCollectivePermuteDone(XlaBuilder* builder, XlaOp operands, + const Shape& shape); + + static XlaOp BuildCopyStart( + XlaBuilder* builder, XlaOp operand, + std::optional cross_program_prefetch_index = std::nullopt); + static XlaOp BuildCopyDone(XlaBuilder* builder, XlaOp operand, + const Shape& shape); + + static XlaOp BuildFusion( + XlaBuilder* builder, absl::Span operands, + absl::string_view fusion_kind, const XlaComputation& fused_computation, + absl::Span>> + output_operand_aliasing = {}); + + static XlaOp BuildBitcast(XlaBuilder* builder, XlaOp operand, + const Shape& shape); + + static XlaOp BuildPartitionId(XlaBuilder* builder, const Shape& shape); + + static XlaOp BuildSend(XlaBuilder* builder, XlaOp operand, XlaOp token, + const ChannelHandle& handle, bool is_host_transfer); + static XlaOp BuildSendDone(XlaBuilder* builder, XlaOp operand, + const ChannelHandle& handle, + bool is_host_transfer); + + static XlaOp BuildRecv(XlaBuilder* builder, XlaOp token, const Shape& shape, + const ChannelHandle& handle, bool is_host_transfer); + static XlaOp BuildRecvDone(XlaBuilder* builder, XlaOp token, + const Shape& shape, const ChannelHandle& handle, + bool is_host_transfer); + + static XlaOp BuildDomain(XlaBuilder* builder, XlaOp operand, OpSharding entry, + OpSharding exit, const Shape& shape); + + static XlaOp BuildRngGetAndUpdateState(XlaBuilder* builder, int64_t delta, + const Shape& shape); + + static HloInstructionProto* GetInstruction(XlaOp op); + static HloInstructionProto* GetInstructionByHandle(XlaBuilder* builder, + int64_t handle); +}; + +} // namespace internal + +// This represents an instruction that has been enqueued using the XlaBuilder. +// This is used to pass to subsequent computations that depends upon the +// instruction as an operand. +class XlaOp { + public: + XlaOp() : handle_(-1), builder_(nullptr) { + static_assert(std::is_trivially_destructible::value, + "XlaOp should be trivially destructible"); + } + ~XlaOp() = default; + + XlaOp(const XlaOp& other) = default; + XlaOp& operator=(const XlaOp& other) = default; + + // Precondition: !IsUninitialized(). + // + // It's very common to do foo.builder()->bar(). Without this precondition, if + // foo.builder() is null, the call to bar will segfault at some point possibly + // deep in the callstack when we finally dereference `this`. The precondition + // lets us avoid this tricky-to-debug problem. + XlaBuilder* builder() const { + CHECK(builder_ != nullptr); + return builder_; + } + + // Returns true if the XlaOp represents valid, non-erroneous value. + bool valid() const { return handle_ >= 0; } + + // Returns true if the XlaOp was created by the XlaOp() constructor and + // not returned by a builder. + bool IsUninitialized() const { return builder_ == nullptr; } + + bool IsIdenticalTo(XlaOp rhs) const { + return handle_ == rhs.handle_ && builder_ == rhs.builder_; + } + + friend std::ostream& operator<<(std::ostream& out, XlaOp op) { + out << op.handle(); + return out; + } + + private: + explicit XlaOp(XlaBuilder* builder) : handle_(-1), builder_(builder) {} + XlaOp(int64_t handle, XlaBuilder* builder) + : handle_(handle), builder_(builder) {} + + int64_t handle() const { return handle_; } + + friend class XlaBuilder; + friend class ValueInference; + friend struct internal::XlaBuilderFriend; + + // < 0 means "invalid handle". + int64_t handle_; + + // Not owned. Non-null for any handle returned by XlaBuilder, even if the + // handle is invalid. + XlaBuilder* builder_; +}; + +// Arithmetic operator overloads for the XlaOp type. +XlaOp operator-(XlaOp x); +XlaOp operator+(XlaOp x, XlaOp y); +XlaOp operator-(XlaOp x, XlaOp y); +XlaOp operator*(XlaOp x, XlaOp y); +XlaOp operator/(XlaOp x, XlaOp y); +XlaOp operator%(XlaOp x, XlaOp y); + +// Bitwise operator overloads for the XlaOp type. +XlaOp operator~(XlaOp x); +XlaOp operator&(XlaOp x, XlaOp y); +XlaOp operator|(XlaOp x, XlaOp y); +XlaOp operator^(XlaOp x, XlaOp y); +XlaOp operator<<(XlaOp x, XlaOp y); +// Performs a right arithmetic shift if 'x' is a signed type, otherwise performs +// a right logical shift. +XlaOp operator>>(XlaOp x, XlaOp y); + +// We don't overload the relational operators (==, !=, <, <=, >, >=) because the +// semantics might be surprising since their result types are usually 'bool'. +// Further programmers may expect == to be a structural equality. +// We also choose not to overload any of the mutating operators (e.g., +=, -=) +// because the semantics might be misleading — XLA computations are immutable. + +// A convenient interface for building up computations. +// +// Thread-compatible. +class XlaBuilder { + public: + // computation_name: name to use for the built computation. + explicit XlaBuilder(const std::string& computation_name); + + XlaBuilder(const XlaBuilder&) = delete; + XlaBuilder& operator=(const XlaBuilder&) = delete; + + virtual ~XlaBuilder(); + + // Returns the computation name. + const std::string& name() const { return name_; } + + // Sets OpMetadata that will be added to all instructions until cleared. + // + // OpMetadata is often applied to a series of XLA HLO instructions. As a + // result, OpMetadata is set on the computation builder. All subsequent + // instructions generated via this computation builder will have the same + // OpMetadata attached until a call to ClearOpMetadata. + void SetOpMetadata(OpMetadata metadata) { metadata_ = std::move(metadata); } + + // Swaps the passed op metadata with the ones currently set. + // + // Returns the old op metadata. + OpMetadata SwapOpMetadata(OpMetadata metadata) { + OpMetadata old_metadata = std::move(metadata_); + metadata_ = std::move(metadata); + return old_metadata; + } + + // Similar to SetOpMetadata, but only set the metadata for the next op. + void SetOneShotOpMetadata(OpMetadata metadata) { + one_shot_metadata_ = std::move(metadata); + } + + // Clears the HloMetadata state. + void ClearOpMetadata() { metadata_.Clear(); } + + // Sets an OpSharding that will be attached to all instructions until cleared. + void SetSharding(const OpSharding& sharding) { sharding_ = sharding; } + + // Sets the FrontendAttributes that will be added to all instructions until + // cleared. + // + // FrontendAttributes are often applied to a series of XLA HLO instructions. + // As a result they are set on the computation builder and all the + // instructions generated via the computation builder will have the same + // frontend attributes attached to them. + virtual void SetFrontendAttributes( + const FrontendAttributes& frontend_attributes) { + frontend_attributes_ = frontend_attributes; + } + + // Swap the passed FrontendAttributes with the ones currently set. + // + // Return the old attributes. + FrontendAttributes SwapFrontendAttributes( + const FrontendAttributes& frontend_attributes) { + FrontendAttributes old_attributes = std::move(frontend_attributes_); + frontend_attributes_ = frontend_attributes; + return old_attributes; + } + + // Returns the FrontendAttributes that will be attached to all instructions. + const FrontendAttributes& frontend_attributes() const { + return frontend_attributes_; + } + + // Clears all the frontend attributes. + void ClearFrontendAttributes() { frontend_attributes_.Clear(); } + + // Clears the sharding. Ops will be sharded according to the default placement + // policy. + void ClearSharding() { sharding_ = std::nullopt; } + + // Returns the OpSharding that will be attached to all instructions. + const std::optional& sharding() const { return sharding_; } + + // Sets the builder to a mode where it will die immediately when an error is + // encountered, rather than producing it in a deferred fashion when Build() is + // called (which is the default). + void set_die_immediately_on_error(bool enabled) { + die_immediately_on_error_ = enabled; + } + + // Default dimension numbers used for a 2D convolution. + static constexpr int64_t kConvBatchDimension = 0; + static constexpr int64_t kConvFeatureDimension = 1; + static constexpr int64_t kConvFirstSpatialDimension = 2; + static constexpr int64_t kConvSecondSpatialDimension = 3; + static constexpr int64_t kConvKernelOutputDimension = 0; + static constexpr int64_t kConvKernelInputDimension = 1; + static constexpr int64_t kConvKernelFirstSpatialDimension = 2; + static constexpr int64_t kConvKernelSecondSpatialDimension = 3; + + // Creates a default ConvolutionDimensionNumbers. For a 2D convolution, for + // the input operand {batch, feature, height, width} = {0, 1, 2, 3} and for + // the kernel operand + // {output_feature, input_feature, height, width} = {0, 1, 2, 3}. + static ConvolutionDimensionNumbers CreateDefaultConvDimensionNumbers( + int num_spatial_dims = 2); + + // Returns an error if the convolution dimension numbers have conflicts. + static absl::Status Validate(const ConvolutionDimensionNumbers& dnum); + + // Returns a new XlaBuilder whose resultant Computation is used only by this + // XlaBuilder. The sub-XlaBuilder has the same die_immediately_on_error + // behavior as the parent. + std::unique_ptr CreateSubBuilder( + const std::string& computation_name); + + // Builds the computation with the requested operations, or returns a non-ok + // status. Note that all ops that have been enqueued will be moved to the + // computation being returned. The root of the computation will be the last + // added operation. + // + // `remove_dynamic_dimensions` tells the builder whether to remove the + // dynamic dimensions information in all ops. + // + // TODO(b/121223198): Delete `remove_dynamic_dimensions` and keeps the + // dynamic dimensions information when XLA backend can handle dynamic + // dimensions. + absl::StatusOr Build(bool remove_dynamic_dimensions = false); + + // Overload of Build which specifies a particular root instruction for the + // computation. + absl::StatusOr Build(XlaOp root, + bool remove_dynamic_dimensions = false); + + // Builds the computation with the requested operations, or notes an error in + // the parent XlaBuilder and returns an empty computation if building failed. + // This function is intended to be used where the returned XlaComputation is + // only used by the parent XlaBuilder and hence further operation on the + // returned XlaComputation will simply be error'ed out if an error occurred + // while building this computation. If the built computation is to be used by + // a XlaBuilder other than the parent XlaBuilder then Build() should be used + // instead. + XlaComputation BuildAndNoteError(); + + // Returns a subgraph that roots on the given root. If the root is not a + // compile-time constant (see `IsConstant`), returns an error. + // + // This will copy the needed ops/computations to the subgraph. + absl::StatusOr BuildConstantSubGraph( + XlaOp root_op, bool dynamic_dimension_is_minus_one = false); + + // Returns the first error that was encountered while building the + // computation. When an error is encountered, by default we return a vacuous + // XlaOp and inform the user of the error that occurred while + // building the computation when they make a final call to Build(). + // + // See also set_die_immediately_on_error(). + absl::Status first_error() const { return first_error_; } + + // Returns the current status of the builder, complete with the stack trace + // information. + absl::Status GetCurrentStatus() const; + + // Returns the shape of the given op. + absl::StatusOr GetShape(XlaOp op) const; + + // Returns the shape of the given op. + virtual absl::StatusOr GetShapePtr(XlaOp op) const; + + // Returns the OpSharding of the given op. If "op" has no sharding, return + // std::nullopt. + absl::StatusOr> GetOpSharding(XlaOp op) const; + + // Returns the (inferred) result for the current computation's shape. This + // assumes the root instruction is the last added instruction. + absl::StatusOr GetProgramShape() const; + + // Returns the (inferred) result for the current computation's shape using the + // given operation as the root. + absl::StatusOr GetProgramShape(XlaOp root) const; + + // Reports an error to the builder, by + // * storing it internally and capturing a backtrace if it's the first error + // (this deferred value will be produced on the call to + // Build()/GetShape()/...) + // * dying if die_immediately_on_error_ is true. + // Returns an XlaOp with an invalid handle but a valid builder. This value can + // be returned in place of a value in APIs that return an XlaOp. + XlaOp ReportError(const absl::Status& error); + + // A helper function that converts a absl::StatusOr into an XlaOp. + // If the absl::Status was an error, reports the error to builder and returns + // an invalid XlaOp handle. + XlaOp ReportErrorOrReturn(const absl::StatusOr& op); + + // A helper function that runs a function that returns a absl::StatusOr + // and returns an XlaOp. + XlaOp ReportErrorOrReturn( + absl::FunctionRef()> op_creator); + + // Returns true if 'operand' is a compile-time constant. A compile-time + // constant does not depend on any parameters, or on stateful operators such + // as `RngNormal` or `Infeed`. + // + // This tests whether a computation is a compile-time constant without + // evaluating the computation. + absl::StatusOr IsConstant(XlaOp operand) const; + + // Adds a new input/output alias. Since the input/output shape information are + // not available until the computation is built, any eventual error in the + // arguments of this API will be detected only at computation Build() time. + // + // Note: Except when 'must-alias' is true, alias is assumed to be 'may-alias' + // and only donated buffer at runtime will be aliased with output. If a buffer + // is not donated at runtime, a copy will be inserted by XLA to prevent buffer + // clobbering. + void SetUpAlias(const ShapeIndex& output_index, int64_t param_number, + const ShapeIndex& param_index, + HloInputOutputAliasConfig::AliasKind kind = + HloInputOutputAliasConfig::AliasKind::kMayAlias) { + input_output_aliases_.push_back( + {output_index, param_number, param_index, kind}); + } + + // Describes an input/output alias as inserted by the SetUpAlias() API. + struct InputOutputAlias { + // Specifies the index of the aliased buffer in the result tuple. + ShapeIndex output_index; + // Specifies the parameter containing the buffer to be aliased. + int64_t param_number; + // Specifies the index of the aliased buffer in the parameter. + ShapeIndex param_index; + // Specifies if the alias is a must alias or may alias. + HloInputOutputAliasConfig::AliasKind kind; + }; + + // Adds a new buffer donor. The donated buffer may be paired with any valid + // output. On the contrary, the buffer aliasing bonds the input output pair. + // The input can only donate the buffer to the paired output. + void AddBufferDonor(int64_t param_number, const ShapeIndex& param_index) { + buffer_donors_.insert({param_number, param_index}); + } + + // Looks up the HloInstruction and sets the frontend attribute "attribute" to + // "value". If the attribute already existed, then its value is updated. + // + // The attribute is only added to the HloInstruction, not to the builder. + absl::Status SetInstructionFrontendAttribute(XlaOp op, std::string attribute, + std::string value); + + // Looks up the HloInstruction and sets the sharding. If the sharding already + // existed, then its value is updated. + // + // The sharding is only added to the HloInstruction, not to the builder. + absl::Status SetInstructionSharding( + XlaOp op, const std::optional& sharding); + + // Returns shapes for the operands. + absl::StatusOr> GetOperandShapes( + absl::Span operands) const; + + // Converts the op to string for the ease of debugging. + std::string OpToString(XlaOp op) const; + + private: + void ToStringHelper(std::string* out, int ident, int64_t op_handle) const; + + // Build helper which takes the id of the root operation.. + absl::StatusOr Build(int64_t root_id, + bool remove_dynamic_dimensions); + + // Description for the methods below can be found in the corresponding public + // functions section in this file. + + XlaOp Parameter(int64_t parameter_number, const Shape& shape, + const std::string& name, + const std::vector& replicated_at_leaf_buffers); + XlaOp Parameter(int64_t parameter_number, const Shape& shape, + const std::string& name) { + std::vector empty_bools; + return Parameter(parameter_number, shape, name, empty_bools); + } + + virtual XlaOp ConstantLiteral(const LiteralSlice& literal); + + XlaOp Broadcast(XlaOp operand, absl::Span broadcast_sizes); + + XlaOp BroadcastInDim(XlaOp operand, absl::Span out_dim_size, + absl::Span broadcast_dimensions); + + // This is an experimental API for creating the mhlo.dynamic_broadcast_in_dim + // op from the XlaBuilder. This is only intended for export to MHLO or + // StableHLO, and cannot be compiled. Only static output_dimensions are + // allowed, and broadcast_dimensions is verified. + XlaOp MhloDynamicBroadcastInDim( + XlaOp operand, XlaOp output_dimensions, + absl::Span broadcast_dimensions, + const Shape& output_shape); + + XlaOp Pad(XlaOp operand, XlaOp padding_value, + const PaddingConfig& padding_config); + XlaOp PadInDim(XlaOp operand, XlaOp padding_value, int64_t dimno, + int64_t pad_lo, int64_t pad_hi); + + virtual absl::StatusOr PadInternal( + const Shape& shape, XlaOp operand, XlaOp padding_value, + const PaddingConfig& padding_config); + + XlaOp Reshape(XlaOp operand, absl::Span dimensions, + absl::Span new_sizes, + int64_t inferred_dimension = -1); + + XlaOp Reshape(XlaOp operand, absl::Span new_sizes, + int64_t inferred_dimension = -1); + + XlaOp Reshape(const Shape& shape, XlaOp operand, + int64_t inferred_dimension = -1); + + XlaOp DynamicReshape(XlaOp operand, absl::Span dim_sizes, + absl::Span new_size_bounds, + const std::vector& dims_are_dynamic); + + XlaOp MhloDynamicReshape(XlaOp operand, XlaOp output_shape, + const Shape& shape); + + XlaOp Collapse(XlaOp operand, absl::Span dimensions); + + XlaOp Slice(XlaOp operand, absl::Span start_indices, + absl::Span limit_indices, + absl::Span strides); + virtual absl::StatusOr SliceInternal( + const Shape& shape, XlaOp operand, + absl::Span start_indices, + absl::Span limit_indices, + absl::Span strides); + virtual XlaOp SliceInDim(XlaOp operand, int64_t start_index, + int64_t limit_index, int64_t stride, int64_t dimno); + + XlaOp DynamicSlice(XlaOp operand, absl::Span start_indices, + absl::Span slice_sizes); + virtual absl::StatusOr DynamicSliceInternal( + const Shape& shape, XlaOp operand, absl::Span start_indices, + absl::Span slice_sizes); + + XlaOp DynamicUpdateSlice(XlaOp operand, XlaOp update, + absl::Span start_indices); + virtual absl::StatusOr DynamicUpdateSliceInternal( + const Shape& shape, XlaOp operand, XlaOp update, + absl::Span start_indices); + + XlaOp ConcatInDim(absl::Span operands, int64_t dimension); + virtual absl::StatusOr ConcatInDimInternal( + const Shape& shape, absl::Span operands, int64_t dimension); + + XlaOp Select(XlaOp pred, XlaOp on_true, XlaOp on_false); + + XlaOp Tuple(absl::Span elements); + virtual absl::StatusOr TupleInternal(const Shape& shape, + absl::Span elements); + + XlaOp GetTupleElement(XlaOp tuple_data, int64_t index); + virtual absl::StatusOr GetTupleElementInternal(const Shape& shape, + XlaOp tuple_data, + int64_t index); + + XlaOp Dot(XlaOp lhs, XlaOp rhs, + const PrecisionConfig* precision_config = nullptr, + std::optional preferred_element_type = std::nullopt); + + XlaOp DotGeneral( + XlaOp lhs, XlaOp rhs, const DotDimensionNumbers& dimension_numbers, + const PrecisionConfig* precision_config = nullptr, + std::optional preferred_element_type = std::nullopt); + + XlaOp SparseDot( + XlaOp lhs, XlaOp rhs, absl::Span sparse_meta, + absl::Span sparsity, + const DotDimensionNumbers& dimension_numbers, + const PrecisionConfig* precision_config = nullptr, + std::optional preferred_element_type = std::nullopt); + + XlaOp Conv( + XlaOp lhs, XlaOp rhs, absl::Span window_strides, + Padding padding, int64_t feature_group_count = 1, + int64_t batch_group_count = 1, + const PrecisionConfig* precision_config = nullptr, + std::optional preferred_element_type = std::nullopt); + + XlaOp ConvWithGeneralPadding( + XlaOp lhs, XlaOp rhs, absl::Span window_strides, + absl::Span> padding, + int64_t feature_group_count = 1, int64_t batch_group_count = 1, + const PrecisionConfig* precision_config = nullptr, + std::optional preferred_element_type = std::nullopt); + + XlaOp ConvWithGeneralDimensions( + XlaOp lhs, XlaOp rhs, absl::Span window_strides, + Padding padding, const ConvolutionDimensionNumbers& dimension_numbers, + int64_t feature_group_count = 1, int64_t batch_group_count = 1, + const PrecisionConfig* precision_config = nullptr, + std::optional preferred_element_type = std::nullopt); + + XlaOp ConvGeneral( + XlaOp lhs, XlaOp rhs, absl::Span window_strides, + absl::Span> padding, + const ConvolutionDimensionNumbers& dimension_numbers, + int64_t feature_group_count = 1, int64_t batch_group_count = 1, + const PrecisionConfig* precision_config = nullptr, + std::optional preferred_element_type = std::nullopt); + + XlaOp ConvGeneralDilated( + XlaOp lhs, XlaOp rhs, absl::Span window_strides, + absl::Span> padding, + absl::Span lhs_dilation, + absl::Span rhs_dilation, + const ConvolutionDimensionNumbers& dimension_numbers, + int64_t feature_group_count = 1, int64_t batch_group_count = 1, + const PrecisionConfig* precision_config = nullptr, + std::optional preferred_element_type = std::nullopt, + std::optional> window_reversal = std::nullopt); + + XlaOp DynamicConvForward( + XlaOp lhs, XlaOp rhs, absl::Span window_strides, + absl::Span> padding, + absl::Span lhs_dilation, + absl::Span rhs_dilation, + const ConvolutionDimensionNumbers& dimension_numbers, + int64_t feature_group_count, int64_t batch_group_count, + const PrecisionConfig* precision_config, PaddingType padding_type, + std::optional preferred_element_type = std::nullopt); + + XlaOp DynamicConvInputGrad( + XlaOp input_sizes, XlaOp lhs, XlaOp rhs, + absl::Span window_strides, + absl::Span> padding, + absl::Span lhs_dilation, + absl::Span rhs_dilation, + const ConvolutionDimensionNumbers& dimension_numbers, + int64_t feature_group_count, int64_t batch_group_count, + const PrecisionConfig* precision_config, PaddingType padding_type, + std::optional preferred_element_type = std::nullopt); + + XlaOp DynamicConvKernelGrad( + XlaOp activations, XlaOp gradients, + absl::Span window_strides, + absl::Span> padding, + absl::Span lhs_dilation, + absl::Span rhs_dilation, + const ConvolutionDimensionNumbers& dimension_numbers, + int64_t feature_group_count, int64_t batch_group_count, + const PrecisionConfig* precision_config, PaddingType padding_type, + std::optional preferred_element_type = std::nullopt); + + absl::StatusOr DynamicConvInstruction( + XlaOp lhs, XlaOp rhs, absl::Span window_strides, + absl::Span> padding, + absl::Span lhs_dilation, + absl::Span rhs_dilation, + const ConvolutionDimensionNumbers& dimension_numbers, + int64_t feature_group_count, int64_t batch_group_count, + const PrecisionConfig* precision_config, PaddingType padding_type, + std::optional preferred_element_type = std::nullopt); + + virtual absl::StatusOr ConvGeneralDilatedInternal( + const Shape& shape, XlaOp lhs, XlaOp rhs, const Window& window, + absl::Span window_strides, + absl::Span> padding, + absl::Span lhs_dilation, + absl::Span rhs_dilation, + const ConvolutionDimensionNumbers& dimension_numbers, + int64_t feature_group_count, int64_t batch_group_count, + const PrecisionConfig* precision_config); + + XlaOp Fft(XlaOp operand, FftType fft_type, + absl::Span fft_length); + virtual absl::StatusOr FftInternal( + const Shape& shape, XlaOp operand, FftType fft_type, + absl::Span fft_length); + + virtual absl::StatusOr TriangularSolveInternal( + const Shape& shape, XlaOp a, XlaOp b, TriangularSolveOptions options); + + virtual absl::StatusOr CholeskyInternal(const Shape& shape, XlaOp a, + bool lower); + + XlaOp Infeed(const Shape& shape, const std::string& config = ""); + XlaOp InfeedWithToken(XlaOp token, const Shape& shape, + const std::string& config); + virtual absl::StatusOr InfeedWithTokenInternal( + const Shape& infeed_instruction_shape, XlaOp token, + const std::string& config); + + void Outfeed(XlaOp operand, const Shape& shape_with_layout, + const std::string& outfeed_config); + XlaOp OutfeedWithToken(XlaOp operand, XlaOp token, + const Shape& shape_with_layout, + const std::string& outfeed_config); + virtual absl::StatusOr OutfeedWithTokenInternal( + XlaOp operand, XlaOp token, const Shape& shape_with_layout, + const std::string& outfeed_config); + XlaOp Call(const XlaComputation& computation, + absl::Span operands); + + XlaOp CompositeCall( + const XlaComputation& computation, absl::Span operands, + const std::string& name, + std::optional attributes = std::nullopt, + std::optional version = std::nullopt); + + XlaOp CustomCall( + const std::string& call_target_name, absl::Span operands, + const Shape& shape_with_layout, const std::string& opaque, + std::optional> operand_shapes_with_layout, + bool has_side_effect, + absl::Span>> + output_operand_aliasing, + const Literal* literal, std::optional window, + std::optional dnums, + CustomCallSchedule schedule, CustomCallApiVersion api_version); + + // Internal version of CustomCall without computation that doesn't do op + // specific error handling and expects arguments to be legal. CustomCall + // method above calls this method after error handling. + virtual absl::StatusOr CustomCallInternal( + const std::string& call_target_name, absl::Span operands, + const XlaComputation* computation, const Shape& shape_with_layout, + const std::string& opaque, + std::optional> operand_shapes_with_layout, + bool has_side_effect, + absl::Span>> + output_operand_aliasing, + const Literal* literal, std::optional window, + std::optional dnums, + CustomCallSchedule schedule, CustomCallApiVersion api_version); + + // TODO(b/239474321) Remove this overload as it has simply led to code + // duplication. + XlaOp CustomCall( + const std::string& call_target_name, absl::Span operands, + const XlaComputation& computation, const Shape& shape_with_layout, + const std::string& opaque, + std::optional> operand_shapes_with_layout, + bool has_side_effect, + absl::Span>> + output_operand_aliasing, + const Literal* literal, CustomCallSchedule schedule, + CustomCallApiVersion api_version); + + XlaOp OptimizationBarrier(XlaOp operand); + + XlaOp Reduce(XlaOp operand, XlaOp init_value, + const XlaComputation& computation, + absl::Span dimensions_to_reduce); + + XlaOp Reduce(absl::Span operands, + absl::Span init_values, + const XlaComputation& computation, + absl::Span dimensions_to_reduce); + + virtual absl::StatusOr ReduceInternal( + const Shape& shape, absl::Span all_operands, + const XlaComputation& computation, + absl::Span dimensions_to_reduce); + + XlaOp ReduceAll(XlaOp operand, XlaOp init_value, + const XlaComputation& computation); + + XlaOp ReduceWindow(XlaOp operand, XlaOp init_value, + const XlaComputation& computation, + absl::Span window_dimensions, + absl::Span window_strides, Padding padding); + + XlaOp ReduceWindow(absl::Span operands, + absl::Span init_values, + const XlaComputation& computation, + absl::Span window_dimensions, + absl::Span window_strides, Padding padding); + + XlaOp ReduceWindowWithGeneralPadding( + absl::Span operands, absl::Span init_values, + const XlaComputation& computation, + absl::Span window_dimensions, + absl::Span window_strides, + absl::Span base_dilations, + absl::Span window_dilations, + absl::Span> padding); + absl::StatusOr ReduceWindowInternal( + absl::Span operands, absl::Span init_values, + const XlaComputation& computation, + absl::Span window_dimensions, + absl::Span window_strides, + absl::Span base_dilations, + absl::Span window_dilations, + absl::Span> padding); + virtual absl::StatusOr ReduceWindowInternal( + const Shape& shape, XlaOp operand, XlaOp init_value, + const XlaComputation& computation, Window window); + XlaOp CrossReplicaSum(XlaOp operand, + absl::Span replica_groups = {}); + + XlaOp AllGather(XlaOp operand, int64_t all_gather_dimension, + int64_t shard_count, + absl::Span replica_groups = {}, + const std::optional& channel_id = std::nullopt, + const std::optional& layout = std::nullopt, + std::optional use_global_device_ids = std::nullopt); + + XlaOp AllReduce(XlaOp operand, const XlaComputation& computation, + absl::Span replica_groups = {}, + const std::optional& channel_id = std::nullopt, + const std::optional& shape_with_layout = std::nullopt, + std::optional use_global_device_ids = std::nullopt); + + XlaOp ReduceScatter( + XlaOp operand, const XlaComputation& computation, + int64_t scatter_dimension, int64_t shard_count, + absl::Span replica_groups = {}, + const std::optional& channel_id = std::nullopt, + const std::optional& layout = std::nullopt, + std::optional use_global_device_ids = std::nullopt); + + XlaOp AllToAll(XlaOp operand, int64_t split_dimension, + int64_t concat_dimension, int64_t split_count, + absl::Span replica_groups, + const std::optional& layout = std::nullopt, + const std::optional& channel_id = std::nullopt); + + XlaOp AllToAllTuple( + absl::Span operands, + absl::Span replica_groups, + const std::optional& layout, + const std::optional& channel_id = std::nullopt); + + XlaOp AllToAllTuple( + XlaOp operand, int64_t split_dimension, int64_t concat_dimension, + int64_t split_count, absl::Span replica_groups, + const std::optional& layout, + const std::optional& channel_id = std::nullopt); + + XlaOp CollectiveBroadcast( + XlaOp operand, absl::Span replica_groups, + const std::optional& channel_id = std::nullopt); + + XlaOp CollectivePermute( + XlaOp operand, + const std::vector>& source_target_pairs, + const std::optional& channel_id = std::nullopt); + + XlaOp ReplicaId(); + + XlaOp SelectAndScatter(XlaOp operand, const XlaComputation& select, + absl::Span window_dimensions, + absl::Span window_strides, + Padding padding, XlaOp source, XlaOp init_value, + const XlaComputation& scatter); + + XlaOp SelectAndScatterWithGeneralPadding( + XlaOp operand, const XlaComputation& select, + absl::Span window_dimensions, + absl::Span window_strides, + absl::Span> padding, XlaOp source, + XlaOp init_value, const XlaComputation& scatter); + + absl::StatusOr SelectAndScatterInternal( + XlaOp operand, const XlaComputation& select, + absl::Span window_dimensions, + absl::Span window_strides, + absl::Span> padding, XlaOp source, + XlaOp init_value, const XlaComputation& scatter); + + virtual XlaOp Iota(const Shape& shape, int64_t iota_dimension); + + XlaOp Iota(PrimitiveType type, int64_t size); + + XlaOp ConvertElementType(XlaOp operand, PrimitiveType new_element_type); + + XlaOp BitcastConvertType(XlaOp operand, PrimitiveType new_element_type); + virtual absl::StatusOr BitcastConvertTypeInternal(const Shape& shape, + XlaOp operand); + + XlaOp StochasticConvertType(XlaOp operand, XlaOp random, + PrimitiveType new_element_type); + + XlaOp Transpose(XlaOp operand, absl::Span permutation); + virtual absl::StatusOr TransposeInternal( + const Shape& shape, XlaOp operand, absl::Span permutation); + + XlaOp Rev(XlaOp operand, absl::Span dimensions); + virtual absl::StatusOr RevInternal( + const Shape& shape, XlaOp operand, absl::Span dimensions); + + XlaOp Sort(absl::Span operands, const XlaComputation& comparator, + int64_t dimension = -1, bool is_stable = false); + virtual absl::StatusOr SortInternal(const Shape& shape, + absl::Span operands, + const XlaComputation& comparator, + int64_t dimension, bool is_stable); + + XlaOp TopK(XlaOp operand, int64_t k, bool largest); + virtual absl::StatusOr TopKInternal(const Shape& shape, XlaOp operand, + int64_t k, bool largest); + + XlaOp Clamp(XlaOp min, XlaOp operand, XlaOp max); + + XlaOp Map(absl::Span operands, const XlaComputation& computation, + absl::Span dimensions, + absl::Span static_operands = {}); + + XlaOp RngNormal(XlaOp mu, XlaOp sigma, const Shape& shape); + + XlaOp RngUniform(XlaOp a, XlaOp b, const Shape& shape); + + XlaOp RngBitGenerator(RandomAlgorithm algorithm, XlaOp initial_state, + const Shape& shape); + // Internal variant for the op with the full result shape containing both data + // and state shape as a tuple. + virtual absl::StatusOr RngBitGeneratorInternal( + const Shape& full_result_shape, RandomAlgorithm algorithm, + XlaOp initial_state); + + XlaOp While(const XlaComputation& condition, const XlaComputation& body, + XlaOp init); + virtual absl::StatusOr WhileInternal(const Shape& shape, + const XlaComputation& condition, + const XlaComputation& body, + XlaOp init); + + XlaOp Conditional(XlaOp predicate, XlaOp true_operand, + const XlaComputation& true_computation, XlaOp false_operand, + const XlaComputation& false_computation); + + XlaOp Conditional(XlaOp branch_index, + absl::Span branch_computations, + absl::Span branch_operands); + + XlaOp ReducePrecision(XlaOp operand, int exponent_bits, int mantissa_bits); + virtual absl::StatusOr ReducePrecisionInternal(const Shape& shape, + XlaOp operand, + int exponent_bits, + int mantissa_bits); + + XlaOp Gather(XlaOp input, XlaOp start_indices, + const GatherDimensionNumbers& dimension_numbers, + absl::Span slice_sizes, + bool indices_are_sorted = false); + + virtual absl::StatusOr GatherInternal( + const Shape& shape, XlaOp input, XlaOp start_indices, + const GatherDimensionNumbers& dimension_numbers, + absl::Span slice_sizes, bool indices_are_sorted); + + XlaOp Scatter(XlaOp input, XlaOp scatter_indices, XlaOp updates, + const XlaComputation& update_computation, + const ScatterDimensionNumbers& dimension_numbers, + bool indices_are_sorted = false, bool unique_indices = false); + XlaOp Scatter(absl::Span inputs, XlaOp scatter_indices, + absl::Span updates, + const XlaComputation& update_computation, + const ScatterDimensionNumbers& dimension_numbers, + bool indices_are_sorted = false, bool unique_indices = false); + + virtual absl::StatusOr ScatterInternal( + const Shape& shape, absl::Span inputs, XlaOp scatter_indices, + absl::Span updates, const XlaComputation& update_computation, + const ScatterDimensionNumbers& dimension_numbers, bool indices_are_sorted, + bool unique_indices); + + void Send(XlaOp operand, const ChannelHandle& handle); + XlaOp SendWithToken(XlaOp operand, XlaOp token, const ChannelHandle& handle); + + XlaOp SendToHost(XlaOp operand, XlaOp token, const Shape& shape_with_layout, + const ChannelHandle& handle); + + XlaOp RecvFromHost(XlaOp token, const Shape& shape, + const ChannelHandle& handle); + + virtual XlaOp CreateToken(); + + XlaOp AfterAll(absl::Span tokens); + + XlaOp Recv(const Shape& shape, const ChannelHandle& handle); + XlaOp RecvWithToken(XlaOp token, const Shape& shape, + const ChannelHandle& handle); + + XlaOp BatchNormTraining(XlaOp operand, XlaOp scale, XlaOp offset, + float epsilon, int64_t feature_index); + + XlaOp BatchNormInference(XlaOp operand, XlaOp scale, XlaOp offset, XlaOp mean, + XlaOp variance, float epsilon, + int64_t feature_index); + + XlaOp BatchNormGrad(XlaOp operand, XlaOp scale, XlaOp batch_mean, + XlaOp batch_var, XlaOp grad_output, float epsilon, + int64_t feature_index); + + XlaOp GetDimensionSize(XlaOp operand, int64_t dimension); + + XlaOp SetDimensionSize(XlaOp operand, XlaOp val, int64_t dimension); + + virtual absl::StatusOr SetDimensionSizeInternal(const Shape& shape, + XlaOp operand, + XlaOp val, + int64_t dimension); + + XlaOp RemoveDynamicDimension(XlaOp operand, int64_t dimension); + + virtual absl::StatusOr AddInstruction( + HloInstructionProto&& instr, HloOpcode opcode, + absl::Span operands); + absl::StatusOr AddInstruction(HloInstructionProto&& instr, + HloOpcode opcode) { + return AddInstruction(std::move(instr), opcode, /*operands=*/{}); + } + + void AddCalledComputation(const XlaComputation& computation, + HloInstructionProto* instr); + + absl::StatusOr LookUpInstruction(XlaOp op) const; + absl::StatusOr LookUpInstructionByHandle( + int64_t handle) const; + absl::StatusOr LookUpMutableInstruction(XlaOp op); + absl::StatusOr LookUpMutableInstructionByHandle( + int64_t handle); + + // Internal helper method that does the building for an arbitrary unary op. + virtual XlaOp UnaryOp(HloOpcode unop, XlaOp operand); + + // Internal helper method that does the building for an arbitrary binary op. + // broadcast_dimensions specifies which dimensions to use for broadcasting + // when the operation is between tensors of different ranks. The direction is + // only used if opcode is kCompare. + XlaOp BinaryOp(HloOpcode binop, XlaOp lhs, XlaOp rhs, + absl::Span broadcast_dimensions, + std::optional direction = std::nullopt, + std::optional type = std::nullopt); + + absl::StatusOr Compare(const Shape& shape, XlaOp lhs, XlaOp rhs, + ComparisonDirection direction); + + // Internal helper method for binary op compare without broadcast dimensions. + virtual absl::StatusOr Compare(const Shape& shape, XlaOp lhs, + XlaOp rhs, + ComparisonDirection direction, + Comparison::Type type); + + // Internal helper method that does the building for an arbitrary binary op + // with same ranked operands that doesn't broadcast. + virtual XlaOp BinaryOpNoBroadcast(HloOpcode binop, const Shape& shape, + XlaOp lhs, XlaOp rhs); + + // Internal helper method that does the building for an arbitrary ternary op. + XlaOp TernaryOp(HloOpcode triop, XlaOp lhs, XlaOp rhs, XlaOp ehs); + + XlaOp RngOp(RandomDistribution distribution, + absl::Span parameters, const Shape& shape); + + virtual absl::StatusOr RngOpInternal( + RandomDistribution distribution, absl::Span parameters, + const Shape& shape); + + virtual absl::StatusOr InDimBroadcast( + const Shape& shape, XlaOp operand, + absl::Span broadcast_dimensions); + + // Internal helper method that creates a sequence of instructions that + // performs an explicit broadcast of the operand to the target shape. + // All dimensions of the operand must either be equal to the corresponding + // output shape dimension, or be exactly 1. (Such dimensions are the + // degenerate dimensions.) + absl::StatusOr AddBroadcastSequence(const Shape& output_shape, + XlaOp operand); + + // Internal helper method that broadcasts a scalar to the shape of the output. + absl::StatusOr BroadcastScalarToOutputShape(XlaOp scalar, + XlaOp output); + + // Internal helper method for creating a Reshape op with the already inferred + // shape. + virtual absl::StatusOr ReshapeInternal(const Shape& shape, + XlaOp operand, + int64_t inferred_dimension); + + // Returns the (inferred) result for the program shape using the given root. + absl::StatusOr GetProgramShape(int64_t root_id) const; + + // A visitor which checks whether an operation is a compile-time constant, + // meaning that it doesn't depend on any parameters, or on any stateful + // operation such as `RngNormal` or `Infeed`. The visitor walks the + // computation starting at a given operation and sets is_constant to false iff + // a parameter or stateful operation is encountered. + void IsConstantVisitor(int64_t op_handle, int depth, + absl::flat_hash_set* visited, + bool* is_constant) const; + + // Checks bounds for convolution parameters. + absl::Status VerifyConvolution( + const Shape& lhs_shape, const Shape& rhs_shape, + const ConvolutionDimensionNumbers& dimension_numbers) const; + + int64_t GetNextId() { return ++next_id_; } + + // Populates the module with the input/output alias information stored within + // the input_output_aliases vector. + static absl::Status PopulateInputOutputAliasAndBufferDonor( + HloModuleProto* module, const ProgramShape& program_shape, + const std::vector& input_output_aliases, + const absl::flat_hash_set& + buffer_donors); + + std::string name_; // Name to use for the built computation. + + // The next sequential ID for every instruction/computation contained within + // this computation. + int64_t next_id_ = 0; + + // The first error encountered while building the computation. + // This is OK until the first error is encountered. + absl::Status first_error_; + + // The saved stack trace from the point at which the first error occurred. + tsl::SavedStackTrace first_error_backtrace_; + + // The instructions of this computation. + // Use a deque so pointers into this are stable, for example the return + // value of LookUpInstructionByHandle(). + std::deque instructions_; + // A cache for the HloInstructionProto shapes, to avoid recreating Shape + // objects from protos and to support the GetShapePtr() API. + std::vector> instruction_shapes_; + + // Dynamic parameter configuration of this computation. + DynamicParameterBinding dynamic_parameter_binding_; + + // Holds the input/output alias information populated by the SetUpAlias() API. + std::vector input_output_aliases_; + + // Holds the buffer donor information populated by the AddBufferDonor() API. + absl::flat_hash_set buffer_donors_; + + // A map from XlaOp::Handle to the index in the instructions_ vector where the + // instruction is held. + absl::flat_hash_map handle_to_index_; + + // Track imported instructions by their computation id and the position in + // their computation's instruction list. + struct ImportedInstruction { + int64_t computation_id; + int64_t instruction_index; + }; + + absl::flat_hash_map handle_to_imported_index_; + + // The embedded computations used by this computation. Each computation was + // the entry computation of some XlaComputation, the key is the unique id of + // that XlaComputation. + std::map embedded_; + + // The unique parameter numbers. + absl::flat_hash_set parameter_numbers_; + + // The metadata to attach to each op. This is structured as a "modal"-like + // operation, in order to simplify client code (and not sprinkle this metadata + // throughout the TensorFlow op kernel implementations). + OpMetadata metadata_; + + // A temporary metadata that will only be applied to the next op created. + std::optional one_shot_metadata_; + + // Sharding for this operator. This is structured as a "model"-like operation, + // in order to simplify client code, similar to metadata_. + std::optional sharding_; + + // Mode bit that indicates whether to die when a first error is encountered. + bool die_immediately_on_error_ = false; + + XlaBuilder* parent_builder_{nullptr}; + + FrontendAttributes frontend_attributes_; + + friend XlaOp Parameter(XlaBuilder* builder, int64_t parameter_number, + const Shape& shape, const std::string& name, + const std::vector& replicated_at_leaf_buffers); + friend XlaOp ConstantLiteral(XlaBuilder* builder, + const LiteralSlice& literal); + + friend XlaOp Broadcast(XlaOp operand, + absl::Span broadcast_sizes); + + friend XlaOp BroadcastInDim(XlaOp operand, + absl::Span out_dim_size, + absl::Span broadcast_dimensions); + + friend XlaOp MhloDynamicBroadcastInDim( + XlaOp operand, XlaOp output_dimensions, + absl::Span broadcast_dimensions, + const Shape& output_shape); + + friend XlaOp Copy(XlaOp operand); + + friend XlaOp Pad(XlaOp operand, XlaOp padding_value, + const PaddingConfig& padding_config); + + friend XlaOp PadInDim(XlaOp operand, XlaOp padding_value, int64_t dimno, + int64_t pad_lo, int64_t pad_hi); + + friend XlaOp Reshape(XlaOp operand, absl::Span dimensions, + absl::Span new_sizes); + + friend XlaOp Reshape(XlaOp operand, absl::Span new_sizes); + + friend XlaOp Reshape(const Shape& shape, XlaOp operand); + + friend XlaOp DynamicReshape(XlaOp operand, absl::Span dim_sizes, + absl::Span new_size_bounds, + const std::vector& dims_are_dynamic); + + friend XlaOp MhloDynamicReshape(XlaOp operand, XlaOp output_shape, + const Shape& shape); + + friend XlaOp ReshapeWithInferredDimension(XlaOp operand, + absl::Span new_sizes, + int64_t inferred_dimension); + + friend XlaOp Collapse(XlaOp operand, absl::Span dimensions); + + friend XlaOp Slice(XlaOp operand, absl::Span start_indices, + absl::Span limit_indices, + absl::Span strides); + + friend XlaOp SliceInDim(XlaOp operand, int64_t start_index, + int64_t limit_index, int64_t stride, int64_t dimno); + + friend XlaOp DynamicSlice(XlaOp operand, + absl::Span start_indices, + absl::Span slice_sizes); + + friend XlaOp DynamicUpdateSlice(XlaOp operand, XlaOp update, + absl::Span start_indices); + + friend XlaOp ConcatInDim(XlaBuilder* builder, + absl::Span operands, int64_t dimension); + + friend XlaOp Select(XlaOp pred, XlaOp on_true, XlaOp on_false); + friend XlaOp Tuple(XlaBuilder* builder, absl::Span elements); + friend XlaOp GetTupleElement(XlaOp tuple_data, int64_t index); + friend XlaOp Compare(XlaOp lhs, XlaOp rhs, + absl::Span broadcast_dimensions, + ComparisonDirection direction); + friend XlaOp Compare(XlaOp lhs, XlaOp rhs, + absl::Span broadcast_dimensions, + ComparisonDirection direction, + Comparison::Type compare_type); + friend XlaOp Dot(XlaOp lhs, XlaOp rhs, + const PrecisionConfig* precision_config, + std::optional preferred_element_type); + friend XlaOp DotGeneral(XlaOp lhs, XlaOp rhs, + const DotDimensionNumbers& dimension_number, + const PrecisionConfig* precision_config, + std::optional preferred_element_type); + virtual absl::StatusOr DotGeneralInternal( + const Shape& shape, XlaOp lhs, XlaOp rhs, + const DotDimensionNumbers& dimension_number, + const PrecisionConfig* precision_config); + friend XlaOp SparseDot(XlaOp lhs, XlaOp rhs, + absl::Span sparse_meta, + absl::Span sparsity, + const DotDimensionNumbers& dimension_number, + const PrecisionConfig* precision_config, + std::optional preferred_element_type); + friend XlaOp Conv(XlaOp lhs, XlaOp rhs, + absl::Span window_strides, Padding padding, + int64_t feature_group_count, int64_t batch_group_count, + const PrecisionConfig* precision_config, + std::optional preferred_element_type); + friend XlaOp ConvWithGeneralPadding( + XlaOp lhs, XlaOp rhs, absl::Span window_strides, + absl::Span> padding, + int64_t feature_group_count, int64_t batch_group_count, + const PrecisionConfig* precision_config, + std::optional preferred_element_type); + friend XlaOp ConvWithGeneralDimensions( + XlaOp lhs, XlaOp rhs, absl::Span window_strides, + Padding padding, const ConvolutionDimensionNumbers& dimension_numbers, + int64_t feature_group_count, int64_t batch_group_count, + const PrecisionConfig* precision_config, + std::optional preferred_element_type); + friend XlaOp ConvGeneral( + XlaOp lhs, XlaOp rhs, absl::Span window_strides, + absl::Span> padding, + const ConvolutionDimensionNumbers& dimension_numbers, + int64_t feature_group_count, int64_t batch_group_count, + const PrecisionConfig* precision_config, + std::optional preferred_element_type); + friend XlaOp DynamicConvForward( + XlaOp lhs, XlaOp rhs, absl::Span window_strides, + absl::Span> padding, + absl::Span lhs_dilation, + absl::Span rhs_dilation, + const ConvolutionDimensionNumbers& dimension_numbers, + int64_t feature_group_count, int64_t batch_group_count, + const PrecisionConfig* precision_config, PaddingType padding_type, + std::optional preferred_element_type); + friend XlaOp DynamicConvKernelGrad( + XlaOp activations, XlaOp gradients, + absl::Span window_strides, + absl::Span> padding, + absl::Span lhs_dilation, + absl::Span rhs_dilation, + const ConvolutionDimensionNumbers& dimension_numbers, + int64_t feature_group_count, int64_t batch_group_count, + const PrecisionConfig* precision_config, PaddingType padding_type, + std::optional preferred_element_type); + friend XlaOp DynamicConvInputGrad( + XlaOp input_sizes, XlaOp lhs, XlaOp rhs, + absl::Span window_strides, + absl::Span> padding, + absl::Span lhs_dilation, + absl::Span rhs_dilation, + const ConvolutionDimensionNumbers& dimension_numbers, + int64_t feature_group_count, int64_t batch_group_count, + const PrecisionConfig* precision_config, PaddingType padding_type, + std::optional preferred_element_type); + + friend XlaOp ConvKernelGrad( + XlaOp lhs, XlaOp rhs, absl::Span window_strides, + absl::Span> padding, + absl::Span lhs_dilation, + absl::Span rhs_dilation, + const ConvolutionDimensionNumbers& dimension_numbers, + int64_t feature_group_count, int64_t batch_group_count, + const PrecisionConfig* precision_config, + std::optional preferred_element_type); + + friend XlaOp ConvGeneralDilated( + XlaOp lhs, XlaOp rhs, absl::Span window_strides, + absl::Span> padding, + absl::Span lhs_dilation, + absl::Span rhs_dilation, + const ConvolutionDimensionNumbers& dimension_numbers, + int64_t feature_group_count, int64_t batch_group_count, + const PrecisionConfig* precision_config, + std::optional preferred_element_type, + std::optional> window_reversal); + + friend XlaOp Fft(XlaOp operand, FftType fft_type, + absl::Span fft_length); + friend XlaOp TriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower, + bool unit_diagonal, + TriangularSolveOptions::Transpose transpose_a); + friend XlaOp Cholesky(XlaOp a, bool lower); + friend XlaOp Infeed(XlaBuilder* builder, const Shape& shape, + const std::string& config); + friend void Outfeed(XlaOp operand, const Shape& shape_with_layout, + const std::string& outfeed_config); + friend XlaOp Call(XlaBuilder* builder, const XlaComputation& computation, + absl::Span operands); + + friend XlaOp CompositeCall(XlaBuilder* builder, + const XlaComputation& computation, + absl::Span operands, + const std::string& name, + std::optional attributes, + std::optional version); + + friend XlaOp CustomCall( + XlaBuilder* builder, const std::string& call_target_name, + absl::Span operands, const Shape& shape, + const std::string& opaque, bool has_side_effect, + absl::Span>> + output_operand_aliasing, + const Literal* literal, CustomCallSchedule schedule, + CustomCallApiVersion api_version); + friend XlaOp CustomCallWithComputation( + XlaBuilder* builder, const std::string& call_target_name, + absl::Span operands, const XlaComputation& computation, + const Shape& shape, const std::string& opaque, bool has_side_effect, + absl::Span>> + output_operand_aliasing, + const Literal* literal, CustomCallSchedule schedule, + CustomCallApiVersion api_version); + friend XlaOp CustomCallWithLayout( + XlaBuilder* builder, const std::string& call_target_name, + absl::Span operands, const Shape& shape_with_layout, + absl::Span operand_shapes_with_layout, + const std::string& opaque, bool has_side_effect, + absl::Span>> + output_operand_aliasing, + const Literal* literal, CustomCallSchedule schedule, + CustomCallApiVersion api_version); + friend XlaOp CustomCallWithConvDnums( + XlaBuilder* builder, const std::string& call_target_name, + absl::Span operands, const Shape& shape, + absl::Span operand_shapes_with_layout, + const std::string& opaque, bool has_side_effect, + absl::Span>> + output_operand_aliasing, + const Literal* literal, Window window, ConvolutionDimensionNumbers dnums, + CustomCallSchedule schedule, CustomCallApiVersion api_version); + friend XlaOp OptimizationBarrier(XlaOp operand); + friend XlaOp Complex(XlaOp real, XlaOp imag, + absl::Span broadcast_dimensions); + friend XlaOp Conj(XlaOp operand); + friend XlaOp Add(XlaOp lhs, XlaOp rhs, + absl::Span broadcast_dimensions); + friend XlaOp Sub(XlaOp lhs, XlaOp rhs, + absl::Span broadcast_dimensions); + friend XlaOp Mul(XlaOp lhs, XlaOp rhs, + absl::Span broadcast_dimensions); + friend XlaOp Div(XlaOp lhs, XlaOp rhs, + absl::Span broadcast_dimensions); + friend XlaOp Rem(XlaOp lhs, XlaOp rhs, + absl::Span broadcast_dimensions); + friend XlaOp Max(XlaOp lhs, XlaOp rhs, + absl::Span broadcast_dimensions); + friend XlaOp Min(XlaOp lhs, XlaOp rhs, + absl::Span broadcast_dimensions); + friend XlaOp And(XlaOp lhs, XlaOp rhs, + absl::Span broadcast_dimensions); + friend XlaOp Or(XlaOp lhs, XlaOp rhs, + absl::Span broadcast_dimensions); + friend XlaOp Xor(XlaOp lhs, XlaOp rhs, + absl::Span broadcast_dimensions); + friend XlaOp Not(XlaOp operand); + friend XlaOp PopulationCount(XlaOp operand); + friend XlaOp ShiftLeft(XlaOp lhs, XlaOp rhs, + absl::Span broadcast_dimensions); + friend XlaOp ShiftRightArithmetic( + XlaOp lhs, XlaOp rhs, absl::Span broadcast_dimensions); + friend XlaOp ShiftRightLogical( + XlaOp lhs, XlaOp rhs, absl::Span broadcast_dimensions); + friend XlaOp Reduce(XlaOp operand, XlaOp init_value, + const XlaComputation& computation, + absl::Span dimensions_to_reduce); + friend XlaOp Reduce(XlaBuilder* builder, absl::Span operands, + absl::Span init_values, + const XlaComputation& computation, + absl::Span dimensions_to_reduce); + friend XlaOp ReduceAll(XlaOp operand, XlaOp init_value, + const XlaComputation& computation); + friend XlaOp ReduceWindow(XlaOp operand, XlaOp init_value, + const XlaComputation& computation, + absl::Span window_dimensions, + absl::Span window_strides, + Padding padding); + friend XlaOp ReduceWindow(absl::Span operands, + absl::Span init_values, + const XlaComputation& computation, + absl::Span window_dimensions, + absl::Span window_strides, + Padding padding); + friend XlaOp ReduceWindowWithGeneralPadding( + XlaOp operand, XlaOp init_value, const XlaComputation& computation, + absl::Span window_dimensions, + absl::Span window_strides, + absl::Span base_dilations, + absl::Span window_dilations, + absl::Span> padding); + friend XlaOp ReduceWindowWithGeneralPadding( + absl::Span operands, absl::Span init_values, + const XlaComputation& computation, + absl::Span window_dimensions, + absl::Span window_strides, + absl::Span base_dilations, + absl::Span window_dilations, + absl::Span> padding); + + friend XlaOp CrossReplicaSum(XlaOp operand, + absl::Span replica_groups); + friend XlaOp AllGather(XlaOp operand, int64_t all_gather_dimension, + int64_t shard_count, + absl::Span replica_groups, + const std::optional& channel_id, + const std::optional& layout, + std::optional use_global_device_ids); + friend XlaOp AllGatherTuple(absl::Span operands, + int64_t all_gather_dimension, int64_t shard_count, + absl::Span replica_groups, + const std::optional& channel_id, + const std::optional& layout, + std::optional use_global_device_ids); + friend XlaOp AllReduce(XlaOp operand, const XlaComputation& computation, + absl::Span replica_groups, + const std::optional& channel_id, + const std::optional& shape_with_layout, + std::optional use_global_device_ids); + friend XlaOp AllReduceTuple(absl::Span operand, + const XlaComputation& computation, + absl::Span replica_groups, + const std::optional& channel_id, + const std::optional& shape_with_layout, + std::optional use_global_device_ids); + friend XlaOp ReduceScatter(XlaOp operand, const XlaComputation& computation, + int64_t scatter_dimension, int64_t shard_count, + absl::Span replica_groups, + const std::optional& channel_id, + const std::optional& layout, + std::optional use_global_device_ids); + + friend XlaOp AllToAll(XlaOp operand, int64_t split_dimension, + int64_t concat_dimension, int64_t split_count, + absl::Span replica_groups, + const std::optional& layout, + const std::optional& channel_id); + friend XlaOp AllToAllTuple(absl::Span operands, + absl::Span replica_groups, + const std::optional& layout, + const std::optional& channel_id); + friend XlaOp AllToAllTuple(XlaOp operand, int64_t split_dimension, + int64_t concat_dimension, int64_t split_count, + absl::Span replica_groups, + const std::optional& layout, + const std::optional& channel_id); + friend XlaOp CollectiveBroadcast( + XlaOp operand, absl::Span replica_groups, + const std::optional& channel_id); + friend XlaOp CollectivePermute( + XlaOp operand, + const std::vector>& source_target_pairs, + const std::optional& channel_id); + friend XlaOp ReplicaId(XlaBuilder* builder); + friend XlaOp SelectAndScatter(XlaOp operand, const XlaComputation& select, + absl::Span window_dimensions, + absl::Span window_strides, + Padding padding, XlaOp source, XlaOp init_value, + const XlaComputation& scatter); + friend XlaOp SelectAndScatterWithGeneralPadding( + XlaOp operand, const XlaComputation& select, + absl::Span window_dimensions, + absl::Span window_strides, + absl::Span> padding, XlaOp source, + XlaOp init_value, const XlaComputation& scatter); + friend XlaOp Abs(XlaOp operand); + friend XlaOp Atan2(XlaOp y, XlaOp x, + absl::Span broadcast_dimensions); + friend XlaOp Erf(XlaOp operand); + friend XlaOp Exp(XlaOp operand); + friend XlaOp Expm1(XlaOp operand); + friend XlaOp Floor(XlaOp operand); + friend XlaOp Ceil(XlaOp operand); + friend XlaOp Round(XlaOp operand); + friend XlaOp RoundNearestEven(XlaOp operand); + friend XlaOp Log(XlaOp operand); + friend XlaOp Log1p(XlaOp operand); + friend XlaOp Logistic(XlaOp operand); + friend XlaOp Sign(XlaOp operand); + friend XlaOp Clz(XlaOp operand); + friend XlaOp Cos(XlaOp operand); + friend XlaOp Sin(XlaOp operand); + friend XlaOp Tan(XlaOp operand); + friend XlaOp Tanh(XlaOp operand); + friend XlaOp Real(XlaOp operand); + friend XlaOp Imag(XlaOp operand); + friend XlaOp Sqrt(XlaOp operand); + friend XlaOp Rsqrt(XlaOp operand); + friend XlaOp Cbrt(XlaOp operand); + friend XlaOp Pow(XlaOp lhs, XlaOp rhs, + absl::Span broadcast_dimensions); + friend XlaOp IsFinite(XlaOp operand); + friend XlaOp Iota(XlaBuilder* builder, const Shape& shape, + int64_t iota_dimension); + friend XlaOp Iota(XlaBuilder* builder, PrimitiveType type, int64_t size); + friend XlaOp ConvertElementType(XlaOp operand, + PrimitiveType new_element_type); + friend XlaOp BitcastConvertType(XlaOp operand, + PrimitiveType new_element_type); + friend XlaOp StochasticConvertType(XlaOp operand, XlaOp random, + PrimitiveType new_element_type); + friend XlaOp Neg(XlaOp operand); + friend XlaOp Transpose(XlaOp operand, absl::Span permutation); + friend XlaOp Rev(XlaOp operand, absl::Span dimensions); + friend XlaOp Sort(absl::Span operands, + const XlaComputation& comparator, int64_t dimension, + bool is_stable); + friend XlaOp TopK(XlaOp operand, int64_t k, bool largest); + friend XlaOp Clamp(XlaOp min, XlaOp operand, XlaOp max); + friend XlaOp Map(XlaBuilder* builder, absl::Span operands, + const XlaComputation& computation, + absl::Span dimensions, + absl::Span static_operands); + friend XlaOp RngNormal(XlaOp mu, XlaOp sigma, const Shape& shape); + friend XlaOp RngUniform(XlaOp a, XlaOp b, const Shape& shape); + friend XlaOp RngBitGenerator(RandomAlgorithm algorithm, XlaOp initial_state, + const Shape& shape); + friend XlaOp While(const XlaComputation& condition, + const XlaComputation& body, XlaOp init); + friend XlaOp Conditional(XlaOp predicate, XlaOp true_operand, + const XlaComputation& true_computation, + XlaOp false_operand, + const XlaComputation& false_computation); + friend XlaOp Conditional( + XlaOp branch_index, + absl::Span branch_computations, + absl::Span branch_operands); + friend XlaOp ConditionalImpl( + XlaOp branch_index, + absl::Span branch_computations, + absl::Span branch_operands); + friend XlaOp ReducePrecision(XlaOp operand, int exponent_bits, + int mantissa_bits); + friend XlaOp Gather(XlaOp input, XlaOp start_indices, + const GatherDimensionNumbers& dimension_numbers, + absl::Span slice_sizes, + bool indices_are_sorted); + friend XlaOp Scatter(XlaOp input, XlaOp scatter_indices, XlaOp updates, + const XlaComputation& update_computation, + const ScatterDimensionNumbers& dimension_numbers, + bool indices_are_sorted, bool unique_indices); + friend XlaOp Scatter(absl::Span inputs, XlaOp scatter_indices, + absl::Span updates, + const XlaComputation& update_computation, + const ScatterDimensionNumbers& dimension_numbers, + bool indices_are_sorted, bool unique_indices); + friend void Send(XlaOp operand, const ChannelHandle& handle); + friend XlaOp Recv(XlaBuilder* builder, const Shape& shape, + const ChannelHandle& handle); + friend XlaOp BatchNormTraining(XlaOp operand, XlaOp scale, XlaOp offset, + float epsilon, int64_t feature_index); + friend XlaOp BatchNormInference(XlaOp operand, XlaOp scale, XlaOp offset, + XlaOp mean, XlaOp variance, float epsilon, + int64_t feature_index); + friend XlaOp BatchNormGrad(XlaOp operand, XlaOp scale, XlaOp batch_mean, + XlaOp batch_var, XlaOp grad_output, float epsilon, + int64_t feature_index); + friend XlaOp SendWithToken(XlaOp operand, XlaOp token, + const ChannelHandle& handle); + friend XlaOp RecvWithToken(XlaOp token, const Shape& shape, + const ChannelHandle& handle); + friend XlaOp SendToHost(XlaOp operand, XlaOp token, + const Shape& shape_with_layout, + const ChannelHandle& handle); + friend XlaOp RecvFromHost(XlaOp token, const Shape& shape, + const ChannelHandle& handle); + friend XlaOp InfeedWithToken(XlaOp token, const Shape& shape, + const std::string& config); + friend XlaOp OutfeedWithToken(XlaOp operand, XlaOp token, + const Shape& shape_with_layout, + const std::string& outfeed_config); + friend XlaOp CreateToken(XlaBuilder* builder); + friend XlaOp AfterAll(XlaBuilder* builder, absl::Span tokens); + + friend XlaOp GetDimensionSize(XlaOp operand, int64_t dimension); + friend XlaOp SetDimensionSize(XlaOp operand, XlaOp val, int64_t dimension); + friend XlaOp RemoveDynamicDimension(XlaOp operand, int64_t dimension); + + protected: + // Returns OK status if the given op was built using this builder. Otherwise, + // returns an error. + absl::Status CheckOpBuilder(XlaOp op) const; + + private: + XlaOp AllGatherImpl(XlaOp operand, int64_t all_gather_dimension, + int64_t shard_count, + absl::Span replica_groups, + const std::optional& channel_id, + const std::optional& layout, + std::optional use_global_device_ids, bool async); + + XlaOp AllReduceImpl(XlaOp operand, const XlaComputation& computation, + absl::Span replica_groups, + const std::optional& channel_id, + const std::optional& layout, + std::optional use_global_device_ids, bool async); + + XlaOp CollectiveBroadcastImpl(XlaOp operand, + absl::Span replica_groups, + const std::optional& channel_id); + + XlaOp CollectivePermuteImpl( + XlaOp operand, + const std::vector>& source_target_pairs, + const std::optional& channel_id, bool async); + + XlaOp ConditionalImpl( + XlaOp branch_index, + absl::Span branch_computations, + absl::Span branch_operands); + + XlaOp AllToAllArray( + XlaOp operand, int64_t split_dimension, int64_t concat_dimension, + int64_t split_count, absl::Span replica_groups, + const std::optional& channel_id = std::nullopt); + + // Creates an op with the given opcode and the output shape. + virtual absl::StatusOr AddOpWithShape( + HloOpcode opcode, const Shape& shape, absl::Span operands); + + // Here, InstructionType is either const HloInstructionProto* or non-const + // HloInstructionProto*. + template + absl::StatusOr LookUpInstructionByHandleInternal( + int64_t handle) const { + auto it = handle_to_index_.find(handle); + if (it == handle_to_index_.end()) { + // Try look for the instruction in the imported instructions. + auto imported_it = handle_to_imported_index_.find(handle); + if (imported_it != handle_to_imported_index_.end()) { + ImportedInstruction imported = imported_it->second; + return const_cast( + &embedded_.at(imported.computation_id) + .instructions(imported.instruction_index)); + } + return InvalidArgument("No XlaOp with handle %d", handle); + } + return const_cast(&instructions_.at(it->second)); + } + + // Here, InstructionType is either const HloInstructionProto* or non-const + // HloInstructionProto*. + // + // TODO(hinsu): Return const pointer within absl::StatusOr and use + // absl::implicit_cast at callsites. This requires implicit_cast support in + // absl::StatusOr similar to absl::StatusOr. + template + absl::StatusOr LookUpInstructionInternal(XlaOp op) const { + TF_RETURN_IF_ERROR(CheckOpBuilder(op)); + return LookUpInstructionByHandleInternal(op.handle()); + } + + friend struct internal::XlaBuilderFriend; + + friend class ValueInference; +}; + +// RAII-style object: sets the current sharding assignment in builder on +// construction, and sets back to the previous assignment on destruction. +class XlaScopedShardingAssignment { + public: + XlaScopedShardingAssignment(xla::XlaBuilder* builder, + std::optional sharding) + : builder_(builder), prev_sharding_(builder->sharding()) { + SetSharding(sharding); + } + + XlaScopedShardingAssignment(const XlaScopedShardingAssignment&) = delete; + XlaScopedShardingAssignment& operator=(const XlaScopedShardingAssignment&) = + delete; + + ~XlaScopedShardingAssignment() { SetSharding(prev_sharding_); } + + private: + void SetSharding(const std::optional& sharding) { + if (sharding.has_value()) { + builder_->SetSharding(sharding.value()); + } else { + builder_->ClearSharding(); + } + } + + xla::XlaBuilder* const builder_; + std::optional prev_sharding_; +}; + +// RAII-style object: save the current builder's frontend attributes, and merge +// them with the new ones on construction. +// Restore the original attributes on destruction. +class XlaScopedFrontendAttributesAssignment { + public: + XlaScopedFrontendAttributesAssignment(xla::XlaBuilder* builder, + FrontendAttributes attributes) + : builder_(builder) { + saved_ = builder_->SwapFrontendAttributes(attributes); + } + + ~XlaScopedFrontendAttributesAssignment() { + builder_->SetFrontendAttributes(saved_); + } + + private: + xla::XlaBuilder* const builder_; + FrontendAttributes saved_; + + XlaScopedFrontendAttributesAssignment( + const XlaScopedFrontendAttributesAssignment&) = delete; + XlaScopedFrontendAttributesAssignment& operator=( + const XlaScopedFrontendAttributesAssignment&) = delete; +}; + +// RAII-style object: sets the current op metadata in builder on construction, +// and sets back to the previous assignment on destruction. +class XlaScopedOpMetadataAssignment { + public: + XlaScopedOpMetadataAssignment(xla::XlaBuilder* builder, OpMetadata metadata) + : builder_(builder) { + saved_ = builder_->SwapOpMetadata(metadata); + } + + ~XlaScopedOpMetadataAssignment() { builder_->SwapOpMetadata(saved_); } + + private: + xla::XlaBuilder* const builder_; + OpMetadata saved_; + + XlaScopedOpMetadataAssignment(const XlaScopedOpMetadataAssignment&) = delete; + XlaScopedOpMetadataAssignment& operator=( + const XlaScopedOpMetadataAssignment&) = delete; +}; + +// Free functions for building XlaOps. The intention is that these will +// become the public API for building XlaOps rather than calling methods on +// XlaBuilder directly. +// + +// Enqueues a "retrieve parameter value" instruction for a parameter that was +// passed to the computation. +XlaOp Parameter(XlaBuilder* builder, int64_t parameter_number, + const Shape& shape, const std::string& name); + +// Same as above, but with leaf buffer replication annotation. +XlaOp Parameter(XlaBuilder* builder, int64_t parameter_number, + const Shape& shape, const std::string& name, + const std::vector& replicated_at_leaf_buffers); + +// Enqueues a constant with the value of the given literal onto the +// computation. +XlaOp ConstantLiteral(XlaBuilder* builder, const LiteralSlice& literal); + +// Enqueues a constant onto the computation. Methods are templated on the +// native host type (NativeT) which corresponds to a specific XLA +// PrimitiveType as given in the following table: +// +// Native Type PrimitiveType +// ----------------------------- +// bool PRED +// int32_t S32 +// int64_t S64 +// uint32_t U32 +// uint64_t U64 +// float F32 +// double F64 +// +// Note: not all primitive types defined in xla_data.proto have a +// corresponding native type yet. +template +XlaOp ConstantR0(XlaBuilder* builder, NativeT value); +template +XlaOp ConstantR1(XlaBuilder* builder, absl::Span values); +XlaOp ConstantR1(XlaBuilder* builder, const tsl::core::Bitmap& values); +template +XlaOp ConstantR2(XlaBuilder* builder, + std::initializer_list> values); +template +XlaOp ConstantFromArrayWithLayout(XlaBuilder* builder, + const Array& values, + const Layout& layout); +template +XlaOp ConstantFromArray(XlaBuilder* builder, const Array& values); +template +XlaOp ConstantR2FromArray2DWithLayout(XlaBuilder* builder, + const Array2D& values, + const Layout& layout); +template +XlaOp ConstantR2FromArray2D(XlaBuilder* builder, + const Array2D& values); +template +XlaOp ConstantR3FromArray3DWithLayout(XlaBuilder* builder, + const Array3D& values, + const Layout& layout); +template +XlaOp ConstantR3FromArray3D(XlaBuilder* builder, + const Array3D& values); +template +XlaOp ConstantR4FromArray4DWithLayout(XlaBuilder* builder, + const Array4D& values, + const Layout& layout); +template +XlaOp ConstantR4FromArray4D(XlaBuilder* builder, + const Array4D& values); + +// Enqueues a rank one constant (XlaBuilder* builder, vector) onto the +// computation. The vector has size 'length' and every element has the value +// 'value'. +template +XlaOp ConstantR1(XlaBuilder* builder, int64_t length, NativeT value); + +// Adds dimensions to an array by duplicating the data in the array. +// +// The new dimensions are inserted on the left, i.e. if +// broadcast_sizes has values {a0, ..., aN} and the operand shape +// has dimensions {b0, ..., bM} then the shape of the output has +// dimensions {a0, ..., aN, b0, ..., bM}. +// +// The new dimensions index into copies of the operand, i.e. +// +// output[i0, ..., iN, j0, ..., jM] = operand[j0, ..., jM] +XlaOp Broadcast(XlaOp operand, absl::Span broadcast_sizes); + +// This op broadcasts the `operand` to an output with the given `shape`. +// `broadcast_dimensions` are the dimensions to be broadcasting into, i.e., the +// i'th dimension of the operand is mapped to the broadcast_dimensions[i]'th +// dimension of the output. This also requires that the i'th input dimension is +// either 1 or is the same as the output dimension it's broadcasting into. +// +// For example, say operand = {1, 2}, i.e., a 1D tensor in shape s32[2]; the +// output shape is s32[2,2]: +// - Specifying {1} as broadcast_dimension will generate output +// {{1, 2}, +// {1, 2}} +// - On the other hand, specifying {0} as broadcast_dimension +// will generate output +// {{1 , 1}, +// {2 , 2}} +XlaOp BroadcastInDim(XlaOp operand, absl::Span out_dim_size, + absl::Span broadcast_dimensions); + +// This is an experimental API for creating the mhlo.dynamic_broadcast_in_dim +// op from the XlaBuilder. This is only intended for export to MHLO or +// StableHLO, and cannot be compiled. See +// https://www.tensorflow.org/mlir/hlo_ops#mhlodynamic_broadcast_in_dim_mhlodynamicbroadcastindimop. +// for the op semantics. +XlaOp MhloDynamicBroadcastInDim(XlaOp operand, XlaOp output_dimensions, + absl::Span broadcast_dimensions, + const Shape& output_shape); + +// Copies the input operand to the output. This operation is for internal +// purpose and is only used by the compiler for optimization purposes or to +// ensure correctness. The XLA client should never have to generate this +// instruction. +// +// Copy has two potential use cases: +// +// * Create a copy of the operand with a new layout. +// +// * Create a copy of the operand in a separately allocated buffer. This is +// necessary for some backends if the operand is a parameter or constant and +// the operand is returned within a tuple. In this case, the lifetime of the +// operand buffer must be the same as the lifetime of the output result. +// However, the lifetimes of parameters and constants are managed separately +// from the lifetime of the output result. Creating a separate copy of the +// parameter or constant buffer resolves this issue. +XlaOp Copy(XlaOp operand); + +// Enqueues a pad operation onto the computation that pads the given value on +// the edges as well as between the elements of the input. padding_config +// specifies the padding amount for each dimension. +XlaOp Pad(XlaOp operand, XlaOp padding_value, + const PaddingConfig& padding_config); + +// Enqueues a pad operation in a given dimension, taking all other +// dimensions as they are. +XlaOp PadInDim(XlaOp operand, XlaOp padding_value, int64_t dimno, + int64_t pad_lo, int64_t pad_hi); + +// Enqueues an operation onto the computation that flattens the operand based +// on the dimension order (major/slowest-varying to minor/fastest-varying) +// given, followed by reshaping it into the shape with the given dimension +// sizes (also major to minor). Conceptually, this is a limited form of +// "shape casting". +XlaOp Reshape(XlaOp operand, absl::Span dimensions, + absl::Span new_sizes); + +// Enqueues a dynamic reshape operation. The dynamic reshape takes additional +// XlaOps as sizes for the result dimension. The result dim i is a dynamic +// dimension dimension if dims_are_dynamic[i] is true. +XlaOp DynamicReshape(XlaOp operand, absl::Span dim_sizes, + absl::Span new_size_bounds, + const std::vector& dims_are_dynamic); + +// This is an experimental API for creating the mhlo.dynamic_reshape op from the +// XlaBuilder. This is only intended for export to MHLO or StableHLO, and cannot +// be compiled. +XlaOp MhloDynamicReshape(XlaOp operand, XlaOp output_shape, const Shape& shape); + +// Enqueues an operation onto the computation that collapses the operand, +// from first to last dimension (C order), then reshapes it to the given +// dimension sizes. Conceptually, this is a limited form of "shape casting". +XlaOp Reshape(XlaOp operand, absl::Span new_sizes); + +// Enqueues a Reshape op that uses an explicit target shape. +XlaOp Reshape(const Shape& shape, XlaOp operand); + +// `inferred_dimension` represents the output dimension that's inferred by +// upper-level framework by dividing the input element count by the known +// output element count. While an inferred_dimension can be static, if there +// is a dynamic dimension in the output, it must be the inferred dimension. +XlaOp ReshapeWithInferredDimension(XlaOp operand, + absl::Span new_sizes, + int64_t inferred_dimension); + +// Wrapper for Reshape. +// Enqueues an operation to collapse the provided dimensions; e.g. an +// operand with dimensions {x=256, y=2, z=2, p=32} can be collapsed to +// {x=1024, y=32} by collapsing dims {0, 1, 2}. Collapsing dimensions must +// be a consecutive, in-order subsequence of the operand dimensions. +// +// Note that collapsing a single dimension does nothing: +// +// {256} collapsing {0} => {256} +// {1} collapsing {0} => {1} +// +// Collapsing multiple dimensions produces a single result dimension: +// +// {256, 2} collapsing {0,1} => {512} +// {256, 2, 3} collapsing {0,1} => {512, 3} +// +// This could potentially cause data to be moved -- it provides a more +// structured form of reshaping than an arbitrary Reshape operation. +XlaOp Collapse(XlaOp operand, absl::Span dimensions); + +// Enqueues a slice operation onto the computation that slices the operand +// from the start indices to the limit indices; e.g. +// +// x +// [ 0 1 2 3 ] +// y [ 4 5 6 7 ] => slice(start={1, 1}, limit={2, 3}) => [ 5 6 ] +// [ 8 9 a b ] +// +// Note that "limit" means up-to-but-not-including; i.e. [start, limit) in 1D +// range notation. +// The strides parameter determines the stride over the slice +XlaOp Slice(XlaOp operand, absl::Span start_indices, + absl::Span limit_indices, + absl::Span strides); + +// Enqueues a slice operation in a given dimension, taking all other +// dimensions as they are; e.g. if dimno is 1 from start_index 2 to +// limit_index 4 by 1, and the shape is f32[7,8,9], this call is short-hand +// for: +// +// array[:, 2:4:1, :] +XlaOp SliceInDim(XlaOp operand, int64_t start_index, int64_t limit_index, + int64_t stride, int64_t dimno); + +// Enqueues a slice operation onto the computation that slices the 'operand' +// from dynamic start indices which are passed in 'start_indices'. +// The size of the slice in each dimension is passed in 'slice_sizes', +// which specify the end point of exclusive slice intervals in each +// dimension [start, start + size). +// The shape of each element of 'start_indices' must be scalar, with the span +// size equal to the rank of the 'operand'. All elements of 'start_indices' must +// have the same shape. +// Slice index calculations are computed modulo input dimension sizes to +// prevent dynamic start indices from generating out-of-bound array accesses. +XlaOp DynamicSlice(XlaOp operand, absl::Span start_indices, + absl::Span slice_sizes); + +// Enqueues a dynamic update slice operation onto the computation, which +// updates a slice of 'operand' with 'update' at dynamic 'start_indices'. +// The shape of 'update' determines the shape of the slice of 'operand' +// which is updated. +// The indices specified in 'start_indices' specify the offset of the slice +// of 'operand' which is updated. +// +// update = {10, 11} // calculated at runtime. +// [1 2 3] start = {1, 1} // calculated at runtime. [1 2 3 ] +// [4 5 6] => DynamicUpdateslice(data, update, start) => [4 10 11] +// [7 8 9] [7 8 9 ] +// +// The shape of each element of 'start_indices' must be scalar, with the span +// size equal to the rank of the 'operand'. All elements of 'start_indices' must +// have the same shape. +// Slice index calculations are computed modulo update dimension sizes to +// prevent dynamic start indices from generating out-of-bound array accesses. +XlaOp DynamicUpdateSlice(XlaOp operand, XlaOp update, + absl::Span start_indices); + +// Enqueues a concatenate instruction onto the computation. 'operands' must +// have >= 1 entry. +XlaOp ConcatInDim(XlaBuilder* builder, absl::Span operands, + int64_t dimension); + +// Enqueues a conditional-move-like select operation onto the computation; +// predicated on pred, selects between on_true and on_false. +XlaOp Select(XlaOp pred, XlaOp on_true, XlaOp on_false); + +// Enqueues a tuple-creation instruction onto the computation. +XlaOp Tuple(XlaBuilder* builder, absl::Span elements); + +// Enqueues a tuple-element-get instruction onto the computation. +XlaOp GetTupleElement(XlaOp tuple_data, int64_t index); + +// Enqueues an equal-to comparison instruction onto the computation. +XlaOp Eq(XlaOp lhs, XlaOp rhs, + absl::Span broadcast_dimensions = {}); +XlaOp EqTotalOrder(XlaOp lhs, XlaOp rhs, + absl::Span broadcast_dimensions = {}); + +// Enqueues a not-equal comparison instruction onto the computation. +XlaOp Ne(XlaOp lhs, XlaOp rhs, + absl::Span broadcast_dimensions = {}); +XlaOp NeTotalOrder(XlaOp lhs, XlaOp rhs, + absl::Span broadcast_dimensions = {}); + +// Enqueues a greater-or-equal comparison instruction onto the computation. +XlaOp Ge(XlaOp lhs, XlaOp rhs, + absl::Span broadcast_dimensions = {}); +XlaOp GeTotalOrder(XlaOp lhs, XlaOp rhs, + absl::Span broadcast_dimensions = {}); + +// Enqueues a greater-than comparison instruction onto the computation. +XlaOp Gt(XlaOp lhs, XlaOp rhs, + absl::Span broadcast_dimensions = {}); +XlaOp GtTotalOrder(XlaOp lhs, XlaOp rhs, + absl::Span broadcast_dimensions = {}); + +// Enqueues a less-than comparison instruction onto the computation. +XlaOp Lt(XlaOp lhs, XlaOp rhs, + absl::Span broadcast_dimensions = {}); +XlaOp LtTotalOrder(XlaOp lhs, XlaOp rhs, + absl::Span broadcast_dimensions = {}); + +// Enqueues a less-or-equal comparison instruction onto the computation. +XlaOp Le(XlaOp lhs, XlaOp rhs, + absl::Span broadcast_dimensions = {}); +XlaOp LeTotalOrder(XlaOp lhs, XlaOp rhs, + absl::Span broadcast_dimensions = {}); + +// Enqueues a comparison instruction onto the computation (optionally without +// broadcast_dimensions for consistency with others). +XlaOp Compare(XlaOp lhs, XlaOp rhs, + absl::Span broadcast_dimensions, + ComparisonDirection direction, Comparison::Type compare_type); +XlaOp Compare(XlaOp lhs, XlaOp rhs, + absl::Span broadcast_dimensions, + ComparisonDirection direction); +XlaOp Compare(XlaOp lhs, XlaOp rhs, ComparisonDirection direction); + +// Enqueues a dot instruction onto the computation. +XlaOp Dot(XlaOp lhs, XlaOp rhs, + const PrecisionConfig* precision_config = nullptr, + std::optional preferred_element_type = std::nullopt); + +// Enqueues a general dot instruction onto the computation. +XlaOp DotGeneral( + XlaOp lhs, XlaOp rhs, const DotDimensionNumbers& dimension_numbers, + const PrecisionConfig* precision_config = nullptr, + std::optional preferred_element_type = std::nullopt); + +// Enqueues a sparse dot instruction onto the computation. +XlaOp SparseDot( + XlaOp lhs, XlaOp rhs, absl::Span sparse_meta, + absl::Span sparsity, + const DotDimensionNumbers& dimension_numbers, + const PrecisionConfig* precision_config = nullptr, + std::optional preferred_element_type = std::nullopt); + +// Enqueues a convolution instruction onto the computation, which uses the +// default convolution dimension numbers. +XlaOp Conv(XlaOp lhs, XlaOp rhs, absl::Span window_strides, + Padding padding, int64_t feature_group_count = 1, + int64_t batch_group_count = 1, + const PrecisionConfig* precision_config = nullptr, + std::optional preferred_element_type = std::nullopt); + +// Enqueues a convolution instruction onto the computation, with the caller +// provided padding configuration in the format returned by MakePadding(). +XlaOp ConvWithGeneralPadding( + XlaOp lhs, XlaOp rhs, absl::Span window_strides, + absl::Span> padding, + int64_t feature_group_count = 1, int64_t batch_group_count = 1, + const PrecisionConfig* precision_config = nullptr, + std::optional preferred_element_type = std::nullopt); + +// Enqueues a convolution instruction onto the computation, with the caller +// provided dimension numbers configuration. +XlaOp ConvWithGeneralDimensions( + XlaOp lhs, XlaOp rhs, absl::Span window_strides, + Padding padding, const ConvolutionDimensionNumbers& dimension_numbers, + int64_t feature_group_count = 1, int64_t batch_group_count = 1, + const PrecisionConfig* precision_config = nullptr, + std::optional preferred_element_type = std::nullopt); + +// Enqueues a convolution instruction onto the computation, with the caller +// provided padding configuration as well as the dimension numbers. +XlaOp ConvGeneral( + XlaOp lhs, XlaOp rhs, absl::Span window_strides, + absl::Span> padding, + const ConvolutionDimensionNumbers& dimension_numbers, + int64_t feature_group_count = 1, int64_t batch_group_count = 1, + const PrecisionConfig* precision_config = nullptr, + std::optional preferred_element_type = std::nullopt); + +// Enqueues a convolution instruction onto the computation, with the caller +// provided padding configuration, dilation factors and dimension numbers. +XlaOp ConvGeneralDilated( + XlaOp lhs, XlaOp rhs, absl::Span window_strides, + absl::Span> padding, + absl::Span lhs_dilation, + absl::Span rhs_dilation, + const ConvolutionDimensionNumbers& dimension_numbers, + int64_t feature_group_count = 1, int64_t batch_group_count = 1, + const PrecisionConfig* precision_config = nullptr, + std::optional preferred_element_type = std::nullopt, + std::optional> window_reversal = std::nullopt); + +XlaOp DynamicConvForward( + XlaOp lhs, XlaOp rhs, absl::Span window_strides, + absl::Span> padding, + absl::Span lhs_dilation, + absl::Span rhs_dilation, + const ConvolutionDimensionNumbers& dimension_numbers, + int64_t feature_group_count, int64_t batch_group_count, + const PrecisionConfig* precision_config, PaddingType padding_type, + std::optional preferred_element_type = std::nullopt); + +XlaOp DynamicConvInputGrad( + XlaOp input_sizes, XlaOp lhs, XlaOp rhs, + absl::Span window_strides, + absl::Span> padding, + absl::Span lhs_dilation, + absl::Span rhs_dilation, + const ConvolutionDimensionNumbers& dimension_numbers, + int64_t feature_group_count, int64_t batch_group_count, + const PrecisionConfig* precision_config, PaddingType padding_type, + std::optional preferred_element_type = std::nullopt); + +XlaOp DynamicConvKernelGrad( + XlaOp activations, XlaOp gradients, + absl::Span window_strides, + absl::Span> padding, + absl::Span lhs_dilation, + absl::Span rhs_dilation, + const ConvolutionDimensionNumbers& dimension_numbers, + int64_t feature_group_count, int64_t batch_group_count, + const PrecisionConfig* precision_config, PaddingType padding_type, + std::optional preferred_element_type = std::nullopt); + +// Enqueues an FFT instruction onto the computation, of the given type and +// with the given FFT length. +XlaOp Fft(XlaOp operand, FftType fft_type, + absl::Span fft_length); + +// Solves systems of linear equations with lower or upper triangular coefficient +// matrices by forward- or back-substitution. Broadcasting along leading +// dimensions, this routine solves for x in one of the matrix systems +// `op(a) * x = b`, or `x * op(a) = b`, +// for the variable `x` given `a` and `b`, where `op(a)` is either +// `op(a) = a`, or `op(a) = transpose(a)`, or `op(a) = conj(transpose(a))`. +// +// * `a` is a tensor of shape `[..., M, M]` whose innermost 2 dimensions form +// square matrices. If `lower` is true (false), then the strictly upper +// (lower) triangular part of each innermost matrix in `a` is assumed to be +// zero and is not accessed. +// * `b` is a tensor of shape `[..., M, K]` if `left_side` is true, otherwise a +// tensor of shape `[..., K, M]`. +// * `left_side` is a boolean, indicating whether to solve a system of the form +// op(a) * x = b (true) or x * op(a) = b (false). +// * `lower` is a boolean, indicating whether the argument `a` is +// lower-triangular (true) or upper-triangular (false). +// * If `unit_diagonal` is true, the diagonal elements of `a` are assumed to be +// 1 and not accessed. +// * `transpose_a` indicates which function `op` we use to transform the tensor +// `a`: the identity function, transpose(a), or conjugate(transpose(a)) +XlaOp TriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower, + bool unit_diagonal, + TriangularSolveOptions::Transpose transpose_a); + +// Computes the Cholesky decompositions of a batch of symmetric (Hermitian) +// positive definite matrices. +// `a` must be a (batched) square matrix; i.e., it must have rank >= 2 with the +// two minor dimensions equal. +// If `lower` is true, the data from the lower triangle is used; if false, the +// upper triangle is used. The input data in the other triangle of the input +// does not affect the output. Returns the output in the same lower/upper +// triangle. The data returned in the other output triangle is arbitrary and +// implementation-defined. +// +// If `a` is not Hermitian positive definite, returns an array full of NaNs. +XlaOp Cholesky(XlaOp a, bool lower); + +// Enqueues an infeed instruction onto the computation, which writes data of +// the given shape to the infeed buffer of the device. +XlaOp Infeed(XlaBuilder* builder, const Shape& shape, + const std::string& config = ""); + +// Variant of Infeed which takes a token-shaped operand and produces a +// two-element tuple containing the data value and a token-shaped value. +// Tokens are used for ordering side-effecting operations. +// TODO(b/110532604): Replace all uses of the non-token form with this variant. +XlaOp InfeedWithToken(XlaOp token, const Shape& shape, + const std::string& config = ""); + +// Enqueues an outfeed instruction onto the computation. This instruction +// generates outgoing data transfers for the given data. +// +// shape_with_layout communicates the laid out shape that we want to outfeed +// -- if !ShapeUtil::Compatible(GetShape(operand), shape_with_layout) an error +// will occur. +void Outfeed(XlaOp operand, const Shape& shape_with_layout, + const std::string& outfeed_config); + +// Variant of Outfeed which takes a token-shaped operand and produces a +// token-shaped value. Tokens are used for ordering side-effecting operations. +// TODO(b/110532604): Replace all uses of the non-token form with this variant. +XlaOp OutfeedWithToken(XlaOp operand, XlaOp token, + const Shape& shape_with_layout, + const std::string& outfeed_config); + +// Enqueues a call instruction onto the computation. +XlaOp Call(XlaBuilder* builder, const XlaComputation& computation, + absl::Span operands); + +// Enqueues a composite call instruction onto the computation. +XlaOp CompositeCall(XlaBuilder* builder, const XlaComputation& computation, + absl::Span operands, const std::string& name, + std::optional attributes = std::nullopt, + std::optional version = std::nullopt); + +// Enqueues a custom call instruction onto the computation. A custom call +// invokes code external to XLA. The |operands| are passed to the external code, +// and the external code is expected to produce a result of the given +// |shape|. The exact mechanism is backend-specific. For example, in the CPU +// backend, a call instruction is emitted which targets a symbol with the name +// |call_target_name|. |call_target_name| and |opaque| can arbitrary strings, +// but |call_target_name| should be short as it may be used in labels. |opaque| +// can encode arbitrarily large amounts of information. |has_side_effect| +// specifies whether the instruction can have side effects. +// |output_operand_aliasing| specifies a list of output/operand buffer pairs +// that alias each other, where the output buffer is represented as a +// ShapeIndex, and the operand buffer is represented as the operand index and +// the ShapeIndex. +XlaOp CustomCall( + XlaBuilder* builder, const std::string& call_target_name, + absl::Span operands, const Shape& shape, + const std::string& opaque = "", bool has_side_effect = false, + absl::Span>> + output_operand_aliasing = {}, + const Literal* literal = nullptr, + CustomCallSchedule schedule = CustomCallSchedule::SCHEDULE_NONE, + CustomCallApiVersion api_version = API_VERSION_ORIGINAL); + +// Overload which constructs a custom call that applies an Xla computation. +XlaOp CustomCallWithComputation( + XlaBuilder* builder, const std::string& call_target_name, + absl::Span operands, const XlaComputation& computation, + const Shape& shape, const std::string& opaque = "", + bool has_side_effect = false, + absl::Span>> + output_operand_aliasing = {}, + const Literal* literal = nullptr, + CustomCallSchedule schedule = CustomCallSchedule::SCHEDULE_NONE, + CustomCallApiVersion api_version = API_VERSION_ORIGINAL); + +// Overload which constructs a custom call with fixed layouts. The operands will +// have the layouts specified by |operand_shapes_with_layout| when provided to +// external code, and the external code is expected to produce a result with the +// layout specified by |shape_with_layout|. All shapes in |shape_with_layout| +// and |operand_shapes_with_layout| must have layouts. +XlaOp CustomCallWithLayout( + XlaBuilder* builder, const std::string& call_target_name, + absl::Span operands, const Shape& shape_with_layout, + absl::Span operand_shapes_with_layout, + const std::string& opaque = "", bool has_side_effect = false, + absl::Span>> + output_operand_aliasing = {}, + const Literal* literal = nullptr, + CustomCallSchedule schedule = CustomCallSchedule::SCHEDULE_NONE, + CustomCallApiVersion api_version = API_VERSION_ORIGINAL); + +// Overload which annotates a custom call with the given Window and +// ConvolutionDimensionNumbers. Useful for custom-calls which represent +// convolutions. +// +// This sets the layout of its operands if operand_shapes_with_layout is +// nonempty, and it sets the layout of its result if `shape` has a layout. +XlaOp CustomCallWithConvDnums( + XlaBuilder* builder, const std::string& call_target_name, + absl::Span operands, const Shape& shape, + absl::Span operand_shapes_with_layout, + const std::string& opaque, bool has_side_effect, + absl::Span>> + output_operand_aliasing, + const Literal* literal, Window window, ConvolutionDimensionNumbers dnums, + CustomCallSchedule schedule = CustomCallSchedule::SCHEDULE_NONE, + CustomCallApiVersion api_version = API_VERSION_ORIGINAL); + +// Enqueues an optimization barrier onto the computation. +XlaOp OptimizationBarrier(XlaOp operand); + +// The following methods enqueue element-wise binary arithmetic operations +// onto the computation. The shapes of the operands have to match unless one +// of the operands is a scalar, or an explicit broadcast dimension is given +// (see g3doc for more details). + +// Enqueues a complex compose instruction onto the computation. +XlaOp Complex(XlaOp real, XlaOp imag, + absl::Span broadcast_dimensions = {}); + +// Enqueues a complex conjugate instruction onto the computation. +XlaOp Conj(XlaOp operand); + +// Enqueues an add instruction onto the computation. +XlaOp Add(XlaOp lhs, XlaOp rhs, + absl::Span broadcast_dimensions = {}); + +// Enqueues a subtract instruction onto the computation. +XlaOp Sub(XlaOp lhs, XlaOp rhs, + absl::Span broadcast_dimensions = {}); + +// Enqueues a multiply instruction onto the computation. +XlaOp Mul(XlaOp lhs, XlaOp rhs, + absl::Span broadcast_dimensions = {}); + +// Enqueues a divide instruction onto the computation. +XlaOp Div(XlaOp lhs, XlaOp rhs, + absl::Span broadcast_dimensions = {}); + +// Enqueues a remainder instruction onto the computation. +XlaOp Rem(XlaOp lhs, XlaOp rhs, + absl::Span broadcast_dimensions = {}); + +// Enqueues a max instruction onto the computation. +XlaOp Max(XlaOp lhs, XlaOp rhs, + absl::Span broadcast_dimensions = {}); + +// Enqueues a min instruction onto the computation. +XlaOp Min(XlaOp lhs, XlaOp rhs, + absl::Span broadcast_dimensions = {}); + +// Element-wise logical operators +XlaOp And(XlaOp lhs, XlaOp rhs, + absl::Span broadcast_dimensions = {}); + +// Overload to call And with 3 or more operands. We need the following somewhat +// convoluted overload set to disambiguate with the overload that takes the +// `broadcast_dimensions` optional param. +inline XlaOp And(XlaOp op1, XlaOp op2, XlaOp op3) { + return And(op1, And(op2, op3)); +} +template +XlaOp And(XlaOp op1, XlaOp op2, XlaOp op3, const XlaOpTs&... operands) { + return And(op1, And(op2, And(op3, operands...))); +} + +XlaOp Or(XlaOp lhs, XlaOp rhs, + absl::Span broadcast_dimensions = {}); + +// Overload to call Or with 3 or more operands. As with `And`, we need the +// following complicated overload set to handle the default arg in the `Or` +// overload above. +inline XlaOp Or(XlaOp op1, XlaOp op2, XlaOp op3) { + return Or(op1, Or(op2, op3)); +} +template +XlaOp Or(XlaOp op1, XlaOp op2, XlaOp op3, const XlaOpTs&... operands) { + return Or(op1, Or(op2, Or(op3, operands...))); +} + +XlaOp Xor(XlaOp lhs, XlaOp rhs, + absl::Span broadcast_dimensions = {}); + +XlaOp Not(XlaOp operand); + +XlaOp PopulationCount(XlaOp operand); + +XlaOp ShiftLeft(XlaOp lhs, XlaOp rhs, + absl::Span broadcast_dimensions = {}); +XlaOp ShiftRightArithmetic(XlaOp lhs, XlaOp rhs, + absl::Span broadcast_dimensions = {}); +XlaOp ShiftRightLogical(XlaOp lhs, XlaOp rhs, + absl::Span broadcast_dimensions = {}); +// Reduces an array among the provided dimensions, given "computation" as a +// reduction operator. +XlaOp Reduce(XlaOp operand, XlaOp init_value, const XlaComputation& computation, + absl::Span dimensions_to_reduce); + +// Reduces several arrays simultaneously among the provided dimensions, given +// "computation" as a reduction operator. +XlaOp Reduce(XlaBuilder* builder, absl::Span operands, + absl::Span init_values, + const XlaComputation& computation, + absl::Span dimensions_to_reduce); + +// Convenience wrapper around the above that reduces all the dimensions in the +// operand shape. +XlaOp ReduceAll(XlaOp operand, XlaOp init_value, + const XlaComputation& computation); + +// Enqueues a windowed reduce instruction onto the computation. +XlaOp ReduceWindow(XlaOp operand, XlaOp init_value, + const XlaComputation& computation, + absl::Span window_dimensions, + absl::Span window_strides, Padding padding); + +XlaOp ReduceWindow(absl::Span operands, + absl::Span init_values, + const XlaComputation& computation, + absl::Span window_dimensions, + absl::Span window_strides, Padding padding); + +// As ReduceWindow(), but the padding is given in the format +// returned by MakePadding(). +XlaOp ReduceWindowWithGeneralPadding( + XlaOp operand, XlaOp init_value, const XlaComputation& computation, + absl::Span window_dimensions, + absl::Span window_strides, + absl::Span base_dilations, + absl::Span window_dilations, + absl::Span> padding); +XlaOp ReduceWindowWithGeneralPadding( + absl::Span operands, absl::Span init_values, + const XlaComputation& computation, + absl::Span window_dimensions, + absl::Span window_strides, + absl::Span base_dilations, + absl::Span window_dilations, + absl::Span> padding); + +// Returns the sum of the operand value within each subgroup of replicas. All +// replicas supply one input to the sum and all replicas receive the resulting +// sum for each subgroup. +XlaOp CrossReplicaSum(XlaOp operand, + absl::Span replica_groups = {}); + +XlaOp AllGather(XlaOp operand, int64_t all_gather_dimension, + int64_t shard_count, + absl::Span replica_groups = {}, + const std::optional& channel_id = std::nullopt, + const std::optional& layout = std::nullopt, + std::optional use_global_device_ids = std::nullopt); + +XlaOp AllGatherTuple( + absl::Span operands, int64_t all_gather_dimension, + int64_t shard_count, absl::Span replica_groups = {}, + const std::optional& channel_id = std::nullopt, + const std::optional& layout = std::nullopt, + std::optional use_global_device_ids = std::nullopt); + +// Enqueues an operation that do an AllReduce of the operand cross cores. Here +// AllReduce means doing a reduction on the input operand cross cores and then +// broadcasting the reduction result to those cores. The reduction function is +// defined by `computation`, which should be a commutative computation on +// scalars, e.g., add, min, or max. The way that AllReduce is applied is +// configured by: +// +// - `replica_groups`: each ReplicaGroup contains a list of replica id. If +// empty, all replicas belong to one group. Allreduce will be applied within +// subgroups. For example, we have 4 replicas, then replica_groups={{0,2},{1,3}} +// means, replica 0 and 2 are in subgroup 0, replica 1 and 3 are in subgroup 1. +// +// - `channel_id`: for Allreduce nodes from different modules, if they have the +// same channel_id, they will be 'AllReduce'd. If empty, AllReduce will not be +// applied cross modules. +// +// - `shape_with_layout`: forces the layout of the AllReduce to the given +// layout. This is used to guarantee the same layout for a group of AllReduce +// ops compiled separately. +XlaOp AllReduce(XlaOp operand, const XlaComputation& computation, + absl::Span replica_groups = {}, + const std::optional& channel_id = std::nullopt, + const std::optional& shape_with_layout = std::nullopt, + std::optional use_global_device_ids = std::nullopt); + +XlaOp AllReduceTuple( + absl::Span operand, const XlaComputation& computation, + absl::Span replica_groups = {}, + const std::optional& channel_id = std::nullopt, + const std::optional& shape_with_layout = std::nullopt, + std::optional use_global_device_ids = std::nullopt); + +XlaOp ReduceScatter( + XlaOp operand, const XlaComputation& computation, int64_t scatter_dimension, + int64_t shard_count, absl::Span replica_groups = {}, + const std::optional& channel_id = std::nullopt, + const std::optional& layout = std::nullopt, + std::optional use_global_device_ids = std::nullopt); + +// Enqueues an operation that do an AllToAll of the operand cross cores. +// This involves AllToAll, followed by Reshape, Transpose, and another Reshape +// to get proper codegen. See implementation for additional details. +// +// An optional `layout` can be specified to force the layout of the instruction. +// This is used to guarantee the same layout for a group of AllToAll ops +// compiled separately. +XlaOp AllToAll(XlaOp operand, int64_t split_dimension, int64_t concat_dimension, + int64_t split_count, + absl::Span replica_groups = {}, + const std::optional& layout = std::nullopt, + const std::optional& channel_id = std::nullopt); + +XlaOp AllToAllTuple( + absl::Span operand, + absl::Span replica_groups = {}, + const std::optional& layout = std::nullopt, + const std::optional& channel_id = std::nullopt); + +XlaOp AllToAllTuple( + XlaOp operand, int64_t split_dimension, int64_t concat_dimension, + int64_t split_count, absl::Span replica_groups = {}, + const std::optional& layout = std::nullopt, + const std::optional& channel_id = std::nullopt); + +XlaOp CollectiveBroadcast( + XlaOp operand, absl::Span replica_groups, + const std::optional& channel_id = std::nullopt); + +// Enqueues an collective operation that sends and receives data cross replicas. +// +// - `source_target_pair`: a list of (source_replica_id, target_replica_id) +// pairs. For each pair, the operand is sent from source replica to target +// replica. Note that, 1) any two pairs should not have the same target replica +// id, and they should not have the same source replica id; 2) if a replica id +// is not a target in any pair, then the output on that replica is a tensor +// consists of 0(s) with the same shape as the input. +XlaOp CollectivePermute( + XlaOp operand, + const std::vector>& source_target_pairs, + const std::optional& channel_id = std::nullopt); + +// Enqueues an operation that returns the replica ID. +XlaOp ReplicaId(XlaBuilder* builder); + +// Enqueues an operation that scatters the `source` array to the selected +// indices of each window. +XlaOp SelectAndScatter(XlaOp operand, const XlaComputation& select, + absl::Span window_dimensions, + absl::Span window_strides, + Padding padding, XlaOp source, XlaOp init_value, + const XlaComputation& scatter); + +// As SelectAndScatter(), but the padding is given in the format +// returned by MakePadding(). +XlaOp SelectAndScatterWithGeneralPadding( + XlaOp operand, const XlaComputation& select, + absl::Span window_dimensions, + absl::Span window_strides, + absl::Span> padding, XlaOp source, + XlaOp init_value, const XlaComputation& scatter); + +// Enqueues an abs instruction onto the computation. +XlaOp Abs(XlaOp operand); + +// Enqueues a atan2 instruction onto the computation. +XlaOp Atan2(XlaOp y, XlaOp x, + absl::Span broadcast_dimensions = {}); + +// Enqueues an erf instruction onto the computation. +XlaOp Erf(XlaOp operand); + +// Enqueues an exp instruction onto the computation. +XlaOp Exp(XlaOp operand); + +// Enqueues an expm1 instruction onto the computation. +XlaOp Expm1(XlaOp operand); + +// Enqueues a floor instruction onto the computation. +XlaOp Floor(XlaOp operand); + +// Enqueues a ceil instruction onto the computation. +XlaOp Ceil(XlaOp operand); + +// Enqueues a round instruction onto the computation, +// with half-way cases rounding away from zero. +XlaOp Round(XlaOp operand); + +// Enqueues a round instruction onto the computation, rounding to nearest even +XlaOp RoundNearestEven(XlaOp operand); + +// Enqueues an log instruction (natural logarithm) onto the computation. +XlaOp Log(XlaOp operand); + +// Enqueues an log1p instruction (log(x+1)) onto the computation. +XlaOp Log1p(XlaOp operand); + +// Enqueues a logistic instruction onto the computation. +XlaOp Logistic(XlaOp operand); + +// Enqueues a sign instruction onto the computation. +XlaOp Sign(XlaOp operand); + +// Enqueues a count leading zeros instruction onto the computation. +XlaOp Clz(XlaOp operand); + +// Enqueues a cosine instruction onto the computation. +XlaOp Cos(XlaOp operand); + +// Enqueues a sine instruction onto the computation. +XlaOp Sin(XlaOp operand); + +// Enqueues a tan instruction onto the computation. +XlaOp Tan(XlaOp operand); + +// Enqueues a tanh instruction onto the computation. +XlaOp Tanh(XlaOp operand); + +// Enqueues a real-part instruction onto the computation. +XlaOp Real(XlaOp operand); + +// Enqueues an imaginary-part instruction onto the computation. +XlaOp Imag(XlaOp operand); + +// Enqueues a sqrt computation onto the computation. +XlaOp Sqrt(XlaOp operand); + +// Enqueues a cbrt computation onto the computation. +XlaOp Cbrt(XlaOp operand); + +// Enqueues a rsqrt computation onto the computation. +XlaOp Rsqrt(XlaOp operand); + +// Enqueues a lhs^rhs computation onto the computation. +XlaOp Pow(XlaOp lhs, XlaOp rhs, + absl::Span broadcast_dimensions = {}); + +// Enqueues an operator that tests if the operand's values are finite, i.e., not +// +/-Inf or NaN. Returns an array of booleans with the same shape where +// entries are true iff the corresponding entry was not infinite or NaN. +// +// Defined only for real-valued (i.e. not complex) floating-point types; raises +// an error for other types. +// +// See also IsInf, IsPosInf, IsNegInf, and IsNan in lib/math.h. +XlaOp IsFinite(XlaOp operand); + +// Enqueues an iota operation onto the computation. +XlaOp Iota(XlaBuilder* builder, const Shape& shape, int64_t iota_dimension); + +// Enqueues a rank-1 iota operation onto the computation. +XlaOp Iota(XlaBuilder* builder, PrimitiveType type, int64_t size); + +// Enqueues a convert instruction onto the computation that changes the +// element type of the operand array to primitive_type. +XlaOp ConvertElementType(XlaOp operand, PrimitiveType new_element_type); + +// Enqueues a no-op instruction onto the computation that changes +// the element type of the operand array to primitive_type. The +// bit-widths of the source and destination element types must be +// identical. +XlaOp BitcastConvertType(XlaOp operand, PrimitiveType new_element_type); + +// Enqueues a stochastic convert instruction onto the computation that changes +// the element type of the operand array with stochastic rounding to +// primitive_type. +XlaOp StochasticConvertType(XlaOp operand, XlaOp random, + PrimitiveType new_element_type); + +// Enqueues a negate instruction onto the computation. +XlaOp Neg(XlaOp operand); + +// Enqueues a transpose instruction onto the computation. +XlaOp Transpose(XlaOp operand, absl::Span permutation); + +// Enqueues a reverse instruction onto the computation. The order of the +// elements in the given dimensions is reversed (i.e., the element at index i +// is moved to index dimension_size - 1 - i). +XlaOp Rev(XlaOp operand, absl::Span dimensions); + +// Enqueues a sort instruction onto the computation, using 'comparator' for +// comparisons. 'comparator' needs to define a strict weak order. 'is_stable' +// determines whether the stable sorting should be used. +// If only one operand is provided: +// * If the operand is a rank-1 tensor (an array), the result is a sorted array. +// The resulting sorting order has the property that for all index positions +// i, j with i < j, either +// comparator(value[i], value[j]) = comparator(value[j], value[i]) = false or +// comparator(value[i], value[j]) = true. +// * If the operand has higher rank, the operand is sorted along the provided +// dimension. For example, for a rank-2 tensor (a matrix), a dimension value +// of 0 will independently sort every column, and a dimension value of 1 will +// independently sort each row. If no dimension number is provided, then the +// last dimension is chosen by default. For the dimension which is sorted, the +// same sorting order applies as in the rank-1 case. +// +// If more than one operand is provided: +// * All operands must be tensors with the same dimensions. The element types of +// the tensors may be different. +// * The result is a tuple that consists of the operands in sorted order (along +// the provided dimension, as above). The same permutation as implied by the +// comparison computation is applied to all operand tensors. When comparing +// two index positions, 'comparator' is called with 2 * n scalar parameters, +// where parameter 2 * i and 2 * i + 1 correspond to the value of operand i at +// two index positions. +// Default comparator computations can be found in lib/comparators.h +XlaOp Sort(absl::Span operands, const XlaComputation& comparator, + int64_t dimension = -1, bool is_stable = false); + +// Enqueues a topk instruction onto the computation. TopK returns the largest +// 'k' values and their indices along the last dimension of the 'operand' if +// `lagest=true` or the smallest `k` values if `largest=false`. +// +// * If the operand is a rank-1 tensor (an array), the result is a tuple that +// consists of: +// * a sorted array with the top 'k' elements. +// * an array containing the indices of the k elements. +// For example, if the input is [0.1, 0.3, 0.2] and k == 2, the output tuple +// is ([0.3, 0.2], [1, 2]). +// * If the operand has higher rank, the result is a tuple that consists of: +// * a tensor equivalent to one produced by sorting the operand along the last +// dimension and slicing that dimension to only the top 'k' values. The last +// dimension is sorted as in the rank-1 case. +// * a tensor containing the indices of the top 'k' values along the last +// dimension. +// For example, if the input is [0.1, 0.3, 0.2][0.5, 0.4, 0.6] and k == 1, the +// output tuple is ([0.3][0.6], [1][2]). +XlaOp TopK(XlaOp operand, int64_t k, bool largest); + +// Enqueues a clamp instruction onto the computation. +XlaOp Clamp(XlaOp min, XlaOp operand, XlaOp max); + +// Enqueues a map instruction onto the computation. +XlaOp Map(XlaBuilder* builder, absl::Span operands, + const XlaComputation& computation, + absl::Span dimensions, + absl::Span static_operands = {}); + +// Enqueues a N(mu, sigma) random number generation instruction onto the +// computation. +XlaOp RngNormal(XlaOp mu, XlaOp sigma, const Shape& shape); + +// Enqueues a U(a, b) random number generation instruction onto the +// computation. Returns values in the semi-open interval [a, b). +XlaOp RngUniform(XlaOp a, XlaOp b, const Shape& shape); + +// Enqueues a B(initial_state) random bit generation instruction onto the +// computation. Returns the new key and random bits with the specified shape. +XlaOp RngBitGenerator(RandomAlgorithm algorithm, XlaOp initial_state, + const Shape& shape); + +// Enqueues a while node onto the computation. +XlaOp While(const XlaComputation& condition, const XlaComputation& body, + XlaOp init); + +// Enqueues a conditional node onto the computation. +XlaOp Conditional(XlaOp predicate, XlaOp true_operand, + const XlaComputation& true_computation, XlaOp false_operand, + const XlaComputation& false_computation); + +// Enqueues either a predicated (if/else) or indexed (switch/case/default) +// conditional node onto the computation. N >= 1 branch_computations and +// branch_operands are matched by index. branch_index selects the branch that +// will be executed. Out of range branch_index uses the N-1'th +// branch_computation as default. +XlaOp Conditional(XlaOp branch_index, + absl::Span branch_computations, + absl::Span branch_operands); + +// Enqueues a ReducePrecision node onto the computation. +XlaOp ReducePrecision(XlaOp operand, int exponent_bits, int mantissa_bits); + +// Enqueues a Gather node onto the computation. +XlaOp Gather(XlaOp input, XlaOp start_indices, + const GatherDimensionNumbers& dimension_numbers, + absl::Span slice_sizes, + bool indices_are_sorted = false); + +// Enqueues a Scatter node onto the computation. +XlaOp Scatter(XlaOp input, XlaOp scatter_indices, XlaOp updates, + const XlaComputation& update_computation, + const ScatterDimensionNumbers& dimension_numbers, + bool indices_are_sorted = false, bool unique_indices = false); +XlaOp Scatter(absl::Span inputs, XlaOp scatter_indices, + absl::Span updates, + const XlaComputation& update_computation, + const ScatterDimensionNumbers& dimension_numbers, + bool indices_are_sorted = false, bool unique_indices = false); + +// Enqueues a Send node onto the computation for device-to-device +// communication. This operation sends the given operand to +// a Recv instruction in a different computation that shares the same channel +// handle. +void Send(XlaOp operand, const ChannelHandle& handle); + +// Variant of Send which takes a token-shaped operand and produces a +// token-shaped value. Tokens are used for ordering side-effecting operations. +// TODO(b/110532604): Replace all uses of the non-token form with this variant. +XlaOp SendWithToken(XlaOp operand, XlaOp token, const ChannelHandle& handle); + +// Enqueues a Recv node onto the computation for device-to-device +// communication. The data comes from a Send instruction in a different +// computation that shares the same channel handle and its shape must be the +// same as the given shape. +XlaOp Recv(XlaBuilder* builder, const Shape& shape, + const ChannelHandle& handle); + +// Variant of Recv which takes a token-shaped operand and produces a two-element +// tuple containing the data value and a token-shaped value. Tokens are used +// for ordering side-effecting operations. +// TODO(b/110532604): Replace all uses of the non-token form with this variant. +XlaOp RecvWithToken(XlaOp token, const Shape& shape, + const ChannelHandle& handle); + +// Enqueues a Send node which transfers data from the device to the host. The +// 'shape_with_layout' argument defines the layout of the data transferred; its +// shape must be compatible with the shape of the operand. The operand must be +// array-shaped. +// TODO(b/111544877): Support tuple shapes. +XlaOp SendToHost(XlaOp operand, XlaOp token, const Shape& shape_with_layout, + const ChannelHandle& handle); + +// Enqueues a Recv node which transfers data from the host to the device. The +// given shape must contain a layout and must be an array. +// TODO(b/111544877): Support tuple shapes. +XlaOp RecvFromHost(XlaOp token, const Shape& shape, + const ChannelHandle& handle); + +// Enqueues an operation (AfterAll) with no operands that produces a +// token-shaped value. Tokens are used for ordering side-effecting operations. +// This is a separate method from AfterAll to facility the removal of +// operand-less AfterAll instructions. +// TODO(b/110532604): Remove this function when all tokens are derived from a +// single token generated or passed into the entry computation. +XlaOp CreateToken(XlaBuilder* builder); + +// Enqueues an AfterAll instruction which produces a token-shaped value and +// takes a variadic number of token-shaped operands. The number of operands must +// be greater than zero. Used for joining tokens. +XlaOp AfterAll(XlaBuilder* builder, absl::Span tokens); + +// Normalizes operand across spatial and batch dimensions for each feature. +// +// Returns a tuple (normalized, batch_mean, batch_var) where `normalized` +// is the normalized result and batch_mean and batch_var are the mean and +// variance, respectively, across batch for the operand. +XlaOp BatchNormTraining(XlaOp operand, XlaOp scale, XlaOp offset, float epsilon, + int64_t feature_index); + +// Normalizes operand across spatial and batch dimensions for each feature. +// +// `BatchNormInference` is equivalent to calling `BatchNormTraining` without +// computing `mean` and `variance` for each batch inside the operation. It +// uses the input `mean` and `variance` instead as estimated values. The +// purpose of this op is to reduce latency in inference, hence the name +// `BatchNormInference`. +// +// The output has the same shape as `operand`, and contains the normalized +// values for each batch. +XlaOp BatchNormInference(XlaOp operand, XlaOp scale, XlaOp offset, XlaOp mean, + XlaOp variance, float epsilon, int64_t feature_index); + +// Calculates the gradients of a batch norm op. +// +// The inputs `batch_mean` and `batch_var` represent the mean and variance +// across the batch. +// +// Returns a tuple of three elements: +// - grad_operand: Gradient with respect to input `operand` +// - grad_offset: Gradient with respect to input `offset` +// - grad_scale: Gradient with respect to input `scale` +XlaOp BatchNormGrad(XlaOp operand, XlaOp scale, XlaOp batch_mean, + XlaOp batch_var, XlaOp grad_output, float epsilon, + int64_t feature_index); + +// Returns the size of the given dimension of the operand. The operand must be +// array shaped. +XlaOp GetDimensionSize(XlaOp operand, int64_t dimension); + +// Sets the size of the given dimension of the operand. The operand must be +// array shaped. The result will have the same shape as the operand, but the +// given dimension will be dynamic (if not already). +XlaOp SetDimensionSize(XlaOp operand, XlaOp val, int64_t dimension); + +// Returns the same op but with dynamic dimension removed. +XlaOp RemoveDynamicDimension(XlaOp operand, int64_t dimension); + +// Implementation details below this point. +// + +// Free function template implementations. + +template +XlaOp ConstantR0(XlaBuilder* builder, NativeT value) { + return ConstantLiteral(builder, LiteralUtil::CreateR0(value)); +} + +template +XlaOp ConstantR1(XlaBuilder* builder, absl::Span values) { + BorrowingLiteral literal( + reinterpret_cast(values.begin()), + ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType(), + {static_cast(values.size())})); + return ConstantLiteral(builder, literal); +} + +template +XlaOp ConstantR1(XlaBuilder* builder, int64_t length, NativeT value) { + Literal literal(ShapeUtil::MakeShape( + primitive_util::NativeToPrimitiveType(), {length})); + literal.PopulateWithValue(value); + return ConstantLiteral(builder, literal); +} + +inline XlaOp ConstantR1(XlaBuilder* builder, const tsl::core::Bitmap& values) { + return ConstantLiteral(builder, LiteralUtil::CreateR1(values)); +} + +template +XlaOp ConstantR2(XlaBuilder* builder, + std::initializer_list> values) { + return ConstantLiteral(builder, LiteralUtil::CreateR2(values)); +} + +template +XlaOp ConstantFromArrayWithLayout(XlaBuilder* builder, + const Array& values, + const Layout& layout) { + return ConstantLiteral( + builder, LiteralUtil::CreateFromArrayWithLayout(values, layout)); +} + +template +XlaOp ConstantFromArray(XlaBuilder* builder, const Array& values) { + return ConstantLiteral(builder, + LiteralUtil::CreateFromArray(values)); +} + +template +XlaOp ConstantR2FromArray2DWithLayout(XlaBuilder* builder, + const Array2D& values, + const Layout& layout) { + return ConstantLiteral( + builder, LiteralUtil::CreateFromArrayWithLayout(values, layout)); +} + +template +XlaOp ConstantR2FromArray2D(XlaBuilder* builder, + const Array2D& values) { + return ConstantLiteral(builder, + LiteralUtil::CreateR2FromArray2D(values)); +} + +template +XlaOp ConstantR3FromArray3DWithLayout(XlaBuilder* builder, + const Array3D& values, + const Layout& layout) { + return ConstantLiteral( + builder, + LiteralUtil::CreateR3FromArray3DWithLayout(values, layout)); +} + +template +XlaOp ConstantR3FromArray3D(XlaBuilder* builder, + const Array3D& values) { + return ConstantFromArray(builder, values); +} + +template +XlaOp ConstantR4FromArray4DWithLayout(XlaBuilder* builder, + const Array4D& values, + const Layout& layout) { + return ConstantFromArrayWithLayout(builder, values, layout); +} + +template +XlaOp ConstantR4FromArray4D(XlaBuilder* builder, + const Array4D& values) { + return ConstantFromArray(builder, values); +} + +// Switches from automatic SPMD partitioning to manual partitioning. Converts a +// full-shaped tensor (to be automatically partitioned by SPMD partitioner) to a +// shard-shaped tensor to be consumed by manually partitioned ops. +absl::StatusOr ConvertSpmdFullToShardShape( + xla::XlaBuilder* builder, xla::XlaOp input, int single_dim, + const xla::OpSharding& manual_sharding, + absl::Span unspecified_dims); + +// Switches from manual partitioning to automatic SPMD partitioning. Converts a +// shard-shaped tensor (manually partitioned in SPMD-style) to a full-shaped +// tensor to be partitioned automatically by the SPMD partitioner. +absl::StatusOr ConvertSpmdShardToFullShape( + xla::XlaBuilder* builder, xla::XlaOp input, const xla::Shape& output_shape, + int single_dim, const xla::OpSharding& manual_sharding, + absl::Span unspecified_dims); + +} // namespace xla + +#endif // XLA_HLO_BUILDER_XLA_BUILDER_H_ diff --git a/third_party/xla/xla/client/xla_builder_test.cc b/third_party/xla/xla/hlo/builder/xla_builder_test.cc similarity index 99% rename from third_party/xla/xla/client/xla_builder_test.cc rename to third_party/xla/xla/hlo/builder/xla_builder_test.cc index 5aa2ef9fb13c19..293f24d634f67a 100644 --- a/third_party/xla/xla/client/xla_builder_test.cc +++ b/third_party/xla/xla/hlo/builder/xla_builder_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include #include @@ -34,12 +34,12 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" -#include "xla/client/padding.h" -#include "xla/client/sharding_builder.h" -#include "xla/client/value_inference.h" -#include "xla/client/xla_computation.h" #include "xla/comparison_util.h" #include "xla/debug_options_flags.h" +#include "xla/hlo/builder/padding.h" +#include "xla/hlo/builder/sharding_builder.h" +#include "xla/hlo/builder/value_inference.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_input_output_alias_config.h" #include "xla/hlo/ir/hlo_instruction.h" diff --git a/third_party/xla/xla/client/xla_computation.cc b/third_party/xla/xla/hlo/builder/xla_computation.cc similarity index 96% rename from third_party/xla/xla/client/xla_computation.cc rename to third_party/xla/xla/hlo/builder/xla_computation.cc index fc558462d1a576..1d01870f1d85c9 100644 --- a/third_party/xla/xla/client/xla_computation.cc +++ b/third_party/xla/xla/hlo/builder/xla_computation.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/client/xla_computation.h" +#include "xla/hlo/builder/xla_computation.h" #include #include diff --git a/third_party/xla/xla/hlo/builder/xla_computation.h b/third_party/xla/xla/hlo/builder/xla_computation.h new file mode 100644 index 00000000000000..379d386e4b7908 --- /dev/null +++ b/third_party/xla/xla/hlo/builder/xla_computation.h @@ -0,0 +1,73 @@ +/* Copyright 2018 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_BUILDER_XLA_COMPUTATION_H_ +#define XLA_HLO_BUILDER_XLA_COMPUTATION_H_ + +#include +#include +#include +#include + +#include "absl/status/statusor.h" +#include "xla/service/hlo.pb.h" +#include "xla/shape.h" +#include "xla/xla_data.pb.h" + +namespace xla { + +// The computation graph that the user builds up with the XlaBuilder. +class XlaComputation { + public: + XlaComputation() : unique_id_(-1) {} + XlaComputation(HloModuleProto proto) + : unique_id_(proto.id()), proto_(std::move(proto)) {} + + ~XlaComputation() = default; + + XlaComputation(const XlaComputation&) = delete; + XlaComputation& operator=(const XlaComputation&) = delete; + + XlaComputation(XlaComputation&& from) = default; + + XlaComputation& operator=(XlaComputation&& from) = default; + + // Returns the "program shape" (parameter and return shapes) for this + // computation. + absl::StatusOr GetProgramShape() const; + + const std::string& name() const { return proto().name(); } + + const HloModuleProto& proto() const { return proto_; } + HloModuleProto* mutable_proto() { return &proto_; } + + // Requests that we snapshot the computation into a serializable protocol + // buffer form. + absl::StatusOr> Snapshot() const; + + // Returns true if this object is a null Computation. + bool IsNull() const { return unique_id_ == -1; } + + private: + XlaComputation(const int64_t unique_id) : unique_id_(unique_id) {} + friend class XlaBuilder; + + int64_t unique_id_; + HloModuleProto proto_; +}; + +} // namespace xla + +#endif // XLA_HLO_BUILDER_XLA_COMPUTATION_H_ From a947deb1e412f984c69186e60c04ee48795b3f28 Mon Sep 17 00:00:00 2001 From: Tom Natan Date: Fri, 27 Sep 2024 09:06:05 -0700 Subject: [PATCH 370/483] #sdy remove size one axes from all shardings and meshes in the module. PiperOrigin-RevId: 679612495 --- third_party/xla/xla/service/spmd/shardy/BUILD | 1 + .../xla/service/spmd/shardy/sdy_opt_main.cc | 2 + .../service/spmd/shardy/sdy_round_trip/BUILD | 21 +- .../shardy/sdy_round_trip/import_shardings.cc | 3 - .../spmd/shardy/sdy_round_trip/pipelines.cc | 3 +- .../sdy_round_trip/remove_size_one_axes.cc | 238 ++++++++++++++++++ .../sdy_round_trip/remove_size_one_axes.h | 37 +++ .../test/sdy_round_trip_import_pipeline.mlir | 73 +++++- .../sdy_round_trip_remove_size_one_axes.mlir | 81 ++++++ 9 files changed, 441 insertions(+), 18 deletions(-) create mode 100644 third_party/xla/xla/service/spmd/shardy/sdy_round_trip/remove_size_one_axes.cc create mode 100644 third_party/xla/xla/service/spmd/shardy/sdy_round_trip/remove_size_one_axes.h create mode 100644 third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_remove_size_one_axes.mlir diff --git a/third_party/xla/xla/service/spmd/shardy/BUILD b/third_party/xla/xla/service/spmd/shardy/BUILD index db4d32ecc78ea0..3689706e3a1c14 100644 --- a/third_party/xla/xla/service/spmd/shardy/BUILD +++ b/third_party/xla/xla/service/spmd/shardy/BUILD @@ -154,6 +154,7 @@ xla_cc_binary( "//xla/service/spmd/shardy/sdy_round_trip:export_shardings", "//xla/service/spmd/shardy/sdy_round_trip:import_shardings", "//xla/service/spmd/shardy/sdy_round_trip:pipelines", + "//xla/service/spmd/shardy/sdy_round_trip:remove_size_one_axes", "//xla/service/spmd/shardy/sdy_round_trip:shard_map_export", "//xla/service/spmd/shardy/sdy_round_trip:shard_map_import", "//xla/service/spmd/shardy/sdy_round_trip/test_utils:mhlo_to_hlo_to_mhlo", diff --git a/third_party/xla/xla/service/spmd/shardy/sdy_opt_main.cc b/third_party/xla/xla/service/spmd/shardy/sdy_opt_main.cc index 892dd9b66a0859..3b4689c4a60d54 100644 --- a/third_party/xla/xla/service/spmd/shardy/sdy_opt_main.cc +++ b/third_party/xla/xla/service/spmd/shardy/sdy_opt_main.cc @@ -36,6 +36,7 @@ limitations under the License. #include "xla/service/spmd/shardy/sdy_round_trip/export_shardings.h" #include "xla/service/spmd/shardy/sdy_round_trip/import_shardings.h" #include "xla/service/spmd/shardy/sdy_round_trip/pipelines.h" +#include "xla/service/spmd/shardy/sdy_round_trip/remove_size_one_axes.h" #include "xla/service/spmd/shardy/sdy_round_trip/shard_map_export.h" #include "xla/service/spmd/shardy/sdy_round_trip/shard_map_import.h" #include "xla/service/spmd/shardy/sdy_round_trip/test_utils/mhlo_to_hlo_to_mhlo.h" @@ -68,6 +69,7 @@ int main(int argc, char** argv) { xla::sdy::registerSdyRoundTripMhloToHloToMhloPass(); xla::sdy::registerSdyRoundTripExportShardingsPass(); xla::sdy::registerSdyRoundTripImportShardingsPass(); + xla::sdy::registerSdyRoundTripRemoveSizeOneAxesPass(); xla::sdy::registerSdyRoundTripExportOpsPass(); xla::sdy::registerSdyRoundTripExportPipeline(); xla::sdy::registerSdyRoundTripShardMapExportPass(); diff --git a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/BUILD b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/BUILD index e7a81caaa33155..ab72795f3959be 100644 --- a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/BUILD +++ b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/BUILD @@ -58,8 +58,6 @@ cc_library( "//xla/mlir_hlo:mhlo_passes", "//xla/service/spmd/shardy:constants", "//xla/service/spmd/shardy:utils", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", "@llvm-project//mlir:AsmParser", "@llvm-project//mlir:FuncDialect", @@ -110,6 +108,22 @@ cc_library( ], ) +cc_library( + name = "remove_size_one_axes", + srcs = ["remove_size_one_axes.cc"], + hdrs = ["remove_size_one_axes.h"], + deps = [ + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@shardy//shardy/dialect/sdy/ir:dialect", + ], +) + cc_library( name = "pipelines", srcs = ["pipelines.cc"], @@ -118,12 +132,11 @@ cc_library( ":export_ops", ":export_shardings", ":import_shardings", + ":remove_size_one_axes", ":shard_map_export", ":shard_map_import", - "//xla/mlir_hlo:mhlo_passes", "//xla/service:hlo_proto_cc", "//xla/service/spmd/shardy/mhlo_round_trip:export_shardings", - "//xla/service/spmd/shardy/mhlo_round_trip:shard_map_import", "//xla/service/spmd/shardy/round_trip_common:pipeline_passes", "@llvm-project//mlir:Pass", "@llvm-project//mlir:Support", diff --git a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/import_shardings.cc b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/import_shardings.cc index fc6e55b2036781..7e4fbe31f7a4ed 100644 --- a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/import_shardings.cc +++ b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/import_shardings.cc @@ -18,10 +18,7 @@ limitations under the License. #include #include #include -#include -#include "absl/log/check.h" -#include "absl/strings/escaping.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringRef.h" #include "mlir/AsmParser/AsmParser.h" diff --git a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/pipelines.cc b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/pipelines.cc index 78fe16400e2299..31915ffe830237 100644 --- a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/pipelines.cc +++ b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/pipelines.cc @@ -23,11 +23,11 @@ limitations under the License. #include "mlir/Transforms/Passes.h" #include "xla/service/hlo.pb.h" #include "xla/service/spmd/shardy/mhlo_round_trip/export_shardings.h" -#include "xla/service/spmd/shardy/mhlo_round_trip/shard_map_import.h" #include "xla/service/spmd/shardy/round_trip_common/pipeline_passes.h" #include "xla/service/spmd/shardy/sdy_round_trip/export_ops.h" #include "xla/service/spmd/shardy/sdy_round_trip/export_shardings.h" #include "xla/service/spmd/shardy/sdy_round_trip/import_shardings.h" +#include "xla/service/spmd/shardy/sdy_round_trip/remove_size_one_axes.h" #include "xla/service/spmd/shardy/sdy_round_trip/shard_map_export.h" #include "xla/service/spmd/shardy/sdy_round_trip/shard_map_import.h" @@ -55,6 +55,7 @@ void addSdyRoundTripImportPipeline(mlir::OpPassManager& pm) { addCommonPreImportPasses(pm); pm.addPass(createSdyRoundTripImportShardingsPass()); pm.addPass(createSdyRoundTripShardMapImportPass()); + pm.addPass(createSdyRoundTripRemoveSizeOneAxesPass()); addCommonPostImportPasses(pm); } diff --git a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/remove_size_one_axes.cc b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/remove_size_one_axes.cc new file mode 100644 index 00000000000000..446fd49756a851 --- /dev/null +++ b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/remove_size_one_axes.cc @@ -0,0 +1,238 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/spmd/shardy/sdy_round_trip/remove_size_one_axes.h" + +#include +#include +#include +#include +#include + +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/DialectRegistry.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/IR/Value.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassRegistry.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/TypeID.h" +#include "shardy/dialect/sdy/ir/constants.h" +#include "shardy/dialect/sdy/ir/dialect.h" +#include "shardy/dialect/sdy/ir/utils.h" + +namespace xla { +namespace sdy { + +namespace { + +using ::mlir::ModuleOp; +using ::mlir::Operation; +using ::mlir::SmallVector; +using ::mlir::StringAttr; +using ::mlir::StringRef; +using ::mlir::SymbolTable; +using ::mlir::func::FuncOp; +using ::mlir::sdy::AxisRefAttr; +using ::mlir::sdy::DimensionShardingAttr; +using ::mlir::sdy::getMeshAttr; +using ::mlir::sdy::kShardingAttr; +using ::mlir::sdy::ManualAxesAttr; +using ::mlir::sdy::ManualComputationOp; +using ::mlir::sdy::MeshAttr; +using ::mlir::sdy::MeshAxisAttr; +using ::mlir::sdy::MeshOp; +using ::mlir::sdy::TensorShardingAttr; +using ::mlir::sdy::TensorShardingPerValueAttr; + +bool hasSizeOneAxes(MeshOp meshOp) { + return llvm::any_of(meshOp.getMesh().getAxes(), + [](MeshAxisAttr axis) { return axis.getSize() == 1; }); +} + +MeshAttr removeSizeOneAxes(MeshAttr mesh) { + SmallVector axes; + llvm::copy_if(mesh.getAxes(), std::back_inserter(axes), + [](MeshAxisAttr axis) { return axis.getSize() != 1; }); + return MeshAttr::get(mesh.getContext(), axes); +} + +TensorShardingAttr removeSizeOneAxes(TensorShardingAttr sharding, + const SymbolTable& symbolTable) { + MeshAttr mesh = getMeshAttr(symbolTable, sharding.getMeshName()); + CHECK(mesh) << "unknown mesh: " << std::string_view(sharding.getMeshName()); + + auto isNotSizeOne = [&](AxisRefAttr axis) { return axis.getSize(mesh) != 1; }; + + // Remove from dimension shardings. + SmallVector dimShardings; + dimShardings.reserve(sharding.getRank()); + for (DimensionShardingAttr dimSharding : sharding.getDimShardings()) { + SmallVector newAxes; + newAxes.reserve(dimSharding.getAxes().size()); + llvm::copy_if(dimSharding.getAxes(), std::back_inserter(newAxes), + isNotSizeOne); + // Remove priority if there are no sharding axes and the dimension is + // closed, since this isn't allowed by verification (would have no effect on + // propagation). + std::optional priority = + newAxes.empty() && dimSharding.getIsClosed() + ? std::nullopt + : dimSharding.getPriority(); + dimShardings.push_back( + DimensionShardingAttr::get(dimSharding.getContext(), newAxes, + dimSharding.getIsClosed(), priority)); + } + + // Remove from replicated axes. + SmallVector replicatedAxes; + llvm::copy_if(sharding.getReplicatedAxes(), + std::back_inserter(replicatedAxes), isNotSizeOne); + + return TensorShardingAttr::get(sharding.getContext(), sharding.getMeshName(), + dimShardings, replicatedAxes); +} + +TensorShardingPerValueAttr removeSizeOneAxes( + TensorShardingPerValueAttr shardings, const SymbolTable& symbolTable) { + SmallVector newShardings; + newShardings.reserve(shardings.size()); + for (TensorShardingAttr sharding : shardings.getShardings()) { + newShardings.push_back(removeSizeOneAxes(sharding, symbolTable)); + } + return TensorShardingPerValueAttr::get(shardings.getContext(), newShardings); +} + +ManualAxesAttr removeSizeOneAxes(ManualAxesAttr manualAxes, MeshAttr mesh) { + SmallVector newAxes; + llvm::copy_if( + manualAxes.getValue(), std::back_inserter(newAxes), + [&](StringAttr axisName) { return mesh.getAxisSize(axisName) != 1; }); + return ManualAxesAttr::get(manualAxes.getContext(), newAxes); +} + +void removeSizeOneAxes(ManualComputationOp manualComputationOp, + const SymbolTable& symbolTable) { + CHECK(!manualComputationOp->getOperands().empty() && + !manualComputationOp->getResults().empty()) + << "ManualComputationOp must have at least one operand or one result"; + std::optional meshName = mlir::sdy::getCommonMeshName( + manualComputationOp.getInShardings().getShardings(), + manualComputationOp.getOutShardings().getShardings()); + CHECK(meshName) << "all in/out shardings must have the same mesh"; + MeshAttr mesh = getMeshAttr(symbolTable, *meshName); + CHECK(mesh) << "unknown mesh: " << std::string_view(*meshName); + + manualComputationOp.setInShardingsAttr( + removeSizeOneAxes(manualComputationOp.getInShardingsAttr(), symbolTable)); + manualComputationOp.setOutShardingsAttr(removeSizeOneAxes( + manualComputationOp.getOutShardingsAttr(), symbolTable)); + manualComputationOp.setManualAxesAttr( + removeSizeOneAxes(manualComputationOp.getManualAxesAttr(), mesh)); +} + +void removeSizeOneAxes(FuncOp funcOp, const SymbolTable& symbolTable) { + for (mlir::BlockArgument arg : funcOp.getArguments()) { + if (auto sharding = mlir::sdy::getSharding(arg)) { + mlir::sdy::setSharding(arg, removeSizeOneAxes(sharding, symbolTable)); + } + } + + for (int64_t resNum = 0; resNum < funcOp.getNumResults(); ++resNum) { + if (auto sharding = funcOp.getResultAttrOfType( + resNum, kShardingAttr)) { + funcOp.setResultAttr(resNum, kShardingAttr, + removeSizeOneAxes(sharding, symbolTable)); + } + } + + funcOp.front().walk([&](Operation* op) { + return mlir::TypeSwitch(op) + .Case( + [&](ManualComputationOp manualComputationOp) { + removeSizeOneAxes(manualComputationOp, symbolTable); + }) + .Default([&](Operation* op) { + if (auto sharding = op->getAttrOfType( + kShardingAttr)) { + op->setAttr(kShardingAttr, + removeSizeOneAxes(sharding, symbolTable)); + } + }); + }); +} + +class SdyRoundTripRemoveSizeOneAxesPass + : public mlir::PassWrapper> { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( + SdyRoundTripRemoveSizeOneAxesPass) + + void runOnOperation() final { + ModuleOp moduleOp = getOperation(); + mlir::SymbolTableCollection symbolTableCollection; + SymbolTable& symbolTable = symbolTableCollection.getSymbolTable(moduleOp); + + if (llvm::none_of(moduleOp.getOps(), hasSizeOneAxes)) { + // Nothing to do. + return; + } + + LOG(INFO) << "[Shardy] removing axes of size one."; + + for (auto funcOp : moduleOp.getOps()) { + removeSizeOneAxes(funcOp, symbolTable); + } + + for (auto meshOp : moduleOp.getOps()) { + meshOp.setMeshAttr(removeSizeOneAxes(meshOp.getMesh())); + } + } + + StringRef getArgument() const override { + return "xla-sdy-round-trip-remove-size-one-axes"; + } + + StringRef getDescription() const override { + return "Removes axes of size one from all meshes, shardings, and manual " + "computation ops, to avoid conflict during propagation that are due " + "to such axes."; + } + + void getDependentDialects(mlir::DialectRegistry& registry) const final { + registry.insert(); + } +}; + +} // namespace + +std::unique_ptr createSdyRoundTripRemoveSizeOneAxesPass() { + return std::make_unique(); +} + +void registerSdyRoundTripRemoveSizeOneAxesPass() { + mlir::registerPass(createSdyRoundTripRemoveSizeOneAxesPass); +} + +} // namespace sdy +} // namespace xla diff --git a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/remove_size_one_axes.h b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/remove_size_one_axes.h new file mode 100644 index 00000000000000..04d280e5d91178 --- /dev/null +++ b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/remove_size_one_axes.h @@ -0,0 +1,37 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_SPMD_SHARDY_SDY_ROUND_TRIP_REMOVE_SIZE_ONE_AXES_H_ +#define XLA_SERVICE_SPMD_SHARDY_SDY_ROUND_TRIP_REMOVE_SIZE_ONE_AXES_H_ + +#include + +#include "mlir/Pass/Pass.h" + +namespace xla { +namespace sdy { + +// Creates the pass that removes axes of size one from all meshes, shardings, +// and manual computation ops, to avoid conflict during propagation that are due +// to such axes. +std::unique_ptr createSdyRoundTripRemoveSizeOneAxesPass(); + +// Registers the xla-sdy-round-trip-remove-size-one-axes pass. +void registerSdyRoundTripRemoveSizeOneAxesPass(); + +} // namespace sdy +} // namespace xla + +#endif // XLA_SERVICE_SPMD_SHARDY_SDY_ROUND_TRIP_REMOVE_SIZE_ONE_AXES_H_ diff --git a/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_import_pipeline.mlir b/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_import_pipeline.mlir index 85cbde38d43f52..067125d48fbe8d 100644 --- a/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_import_pipeline.mlir +++ b/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_import_pipeline.mlir @@ -1,19 +1,22 @@ // RUN: sdy_opt %s --split-input-file -xla-sdy-round-trip-import-pipeline 2>&1 | FileCheck %s // CHECK-LABEL: module @multiple_func_result_shardings -module @multiple_func_result_shardings attributes {mhlo.frontend_attributes = {xla.sdy.meshes = "{mesh = #sdy.mesh<[\\\22a\\\22=8, \\\22b\\\22=8, \\\22c\\\22=8]>}"}} { +module @multiple_func_result_shardings attributes {mhlo.frontend_attributes = {xla.sdy.meshes = + "{mesh = #sdy.mesh<[\\\22a\\\22=8, \\\22b\\\22=8, \\\22c\\\22=8]>, mesh2 = #sdy.mesh<[\\\22a\\\22=1, \\\22b\\\22=4, \\\22c\\\22=1]>}"}} { // CHECK: sdy.mesh @mesh = <["a"=8, "b"=8, "c"=8]> + // CHECK: sdy.mesh @mesh2 = <["b"=4]> + // CHECK-LABEL: func @func_results_with_sharding - // CHECK-SAME: %arg0: tensor<32xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"b"}p2]>}, - // CHECK-SAME: %arg1: tensor<32xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}p1]>}, - // CHECK-SAME: %arg2: tensor<32xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"c"}p0]>} - // CHECK-SAME: ) -> ( - // CHECK-SAME: tensor<32xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}p0]>}, - // CHECK-SAME: tensor<32xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"b"}p2]>}, - // CHECK-SAME: tensor<32xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}p1]>}, - // CHECK-SAME: tensor<32xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"c"}p0]>}, - // CHECK-SAME: tensor<32xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}p3]>}) { + // CHECK-SAME: %arg0: tensor<32xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"b"}p2]>}, + // CHECK-SAME: %arg1: tensor<32xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}p1]>}, + // CHECK-SAME: %arg2: tensor<32xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"c"}p0]>} + // CHECK-SAME: ) -> ( + // CHECK-SAME: tensor<32xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}p0]>}, + // CHECK-SAME: tensor<32xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"b"}p2]>}, + // CHECK-SAME: tensor<32xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}p1]>}, + // CHECK-SAME: tensor<32xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"c"}p0]>}, + // CHECK-SAME: tensor<32xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}p3]>}) { // CHECK-NEXT: return %arg0, %arg1, %arg0, %arg1, %arg2 : tensor<32xi32>, tensor<32xi32>, tensor<32xi32>, tensor<32xi32>, tensor<32xi32> // CHECK-NEXT: } func.func @func_results_with_sharding( @@ -109,6 +112,56 @@ module @multiple_func_result_shardings attributes {mhlo.frontend_attributes = {x %3 = mhlo.custom_call @local_xla.sdy.FuncResultSharding(%2) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22a\\\22}p4]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> return %3 : tensor<32xi32> } + + // CHECK-LABEL: func @shardings_with_size_one_axes + // CHECK-SAME: %arg0: tensor<32xi32> {sdy.sharding = #sdy.sharding<@mesh2, [{"b"}p1]>}, + // CHECK-SAME: %arg1: tensor<32xi32> {sdy.sharding = #sdy.sharding<@mesh2, [{}], replicated={"b"}>}, + // CHECK-SAME: %arg2: tensor<32xi32> {sdy.sharding = #sdy.sharding<@mesh2, [{"b", ?}p0]>} + // CHECK-SAME: ) -> ( + // CHECK-SAME: tensor<32xi32> {sdy.sharding = #sdy.sharding<@mesh2, [{}]>}, + // CHECK-SAME: tensor<32xi32> {sdy.sharding = #sdy.sharding<@mesh2, [{"b"}]>}) { + func.func @shardings_with_size_one_axes( + %arg0: tensor<32xi32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding<@mesh2, [{\\\22b\\\22}p1], replicated={\\\22c\\\22}>"}}, + %arg1: tensor<32xi32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding<@mesh2, [{\\\22a\\\22}p2], replicated={\\\22b\\\22}>"}}, + %arg2: tensor<32xi32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding<@mesh2, [{\\\22c\\\22, \\\22b\\\22, ?}p0]>"}} + ) -> (tensor<32xi32>, tensor<32xi32>) { + // CHECK-NEXT: %[[SC1:.*]] = sdy.sharding_constraint %arg0 <@mesh2, [{"b", ?}]> + // CHECK-NEXT: %[[ADD:.*]] = mhlo.add %[[SC1]], %[[SC1]] + // CHECK-NOT: sdy.sharding + // CHECK-NEXT: %[[SC2:.*]] = sdy.sharding_constraint %arg1 <@mesh2, [{}]> + // CHECK-NEXT: return %[[ADD]], %[[SC2]] + // CHECK-NEXT: } + %0 = mhlo.custom_call @Sharding(%arg0) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh2, [{\\\22a\\\22, \\\22b\\\22, ?}]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> + %1 = mhlo.add %0, %0 : tensor<32xi32> + %2 = mhlo.custom_call @Sharding(%arg1) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh2, [{\\\22c\\\22, \\\22a\\\22}]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> + %3 = mhlo.custom_call @local_xla.sdy.FuncResultSharding(%1) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh2, [{\\\22a\\\22}]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> + %4 = mhlo.custom_call @local_xla.sdy.FuncResultSharding(%2) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh2, [{\\\22b\\\22}]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> + return %3, %4 : tensor<32xi32>, tensor<32xi32> + } + + // CHECK-LABEL: func @manual_computation_with_size_one_axes + func.func @manual_computation_with_size_one_axes(%arg0: tensor<16x32xf32>, %arg1: tensor<16x32xf32>) -> (tensor<16x32xf32>) { + // CHECK-NOT: call @local_xla.sdy.manual_computation_body + // CHECK: %[[MAN_COMP:.*]] = sdy.manual_computation(%arg0, %arg1) + // CHECK-SAME{LITERAL}: in_shardings=[<@mesh2, [{}, {"b"}]>, <@mesh2, [{}, {"b"}]>] + // CHECK-SAME{LITERAL}: out_shardings=[<@mesh2, [{}, {"b"}]>] + // CHECK-SAME{LITERAL}: manual_axes={"b"} + // CHECK-SAME: (%arg2: tensor<16x8xf32>, %arg3: tensor<16x8xf32>) { + // CHECK-NEXT: %[[ADD:.*]] = mhlo.add %arg2, %arg3 + // CHECK-NEXT: sdy.return %[[ADD]] + // CHECK-NEXT: } : (tensor<16x32xf32>, tensor<16x32xf32>) -> tensor<16x32xf32> + // CHECK-NEXT: return %[[MAN_COMP]] + %0:2 = mhlo.custom_call @local_xla.sdy.GlobalToLocalShape(%arg0, %arg1) : (tensor<16x32xf32>, tensor<16x32xf32>) -> (tensor<16x8xf32>, tensor<16x8xf32>) + %1 = call @local_xla.sdy.manual_computation_body(%0#0, %0#1) {mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh2, [{\\\22a\\\22}, {\\\22b\\\22}]>, <@mesh2, [{}, {\\\22b\\\22}], replicated={\\\22a\\\22}>]>", xla.sdy.manual_axes = "#sdy", xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh2, [{}, {\\\22b\\\22, \\\22a\\\22}]>]>"}} : (tensor<16x8xf32>, tensor<16x8xf32>) -> tensor<16x8xf32> + %2 = mhlo.custom_call @local_xla.sdy.LocalToGlobalShape(%1) : (tensor<16x8xf32>) -> tensor<16x32xf32> + return %2 : tensor<16x32xf32> + } + + // CHECK-NOT: func @local_xla.sdy.manual_computation_body( + func.func @local_xla.sdy.manual_computation_body(%arg0: tensor<16x8xf32>, %arg1: tensor<16x8xf32>) -> tensor<16x8xf32> { + %0 = mhlo.add %arg0, %arg1 : tensor<16x8xf32> + return %0 : tensor<16x8xf32> + } } // ----- diff --git a/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_remove_size_one_axes.mlir b/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_remove_size_one_axes.mlir new file mode 100644 index 00000000000000..c83395f2e36eff --- /dev/null +++ b/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_remove_size_one_axes.mlir @@ -0,0 +1,81 @@ +// RUN: sdy_opt %s -xla-sdy-round-trip-remove-size-one-axes 2>&1 | FileCheck %s + +sdy.mesh @mesh1 = <["a"=1, "b"=2, "c"=1, "d"=4, "e"=1]> +sdy.mesh @mesh2 = <["a"=4, "b"=2]> +sdy.mesh @mesh3 = <["x"=1, "y"=1]> + +// CHECK: sdy.mesh @mesh1 = <["b"=2, "d"=4]> +// CHECK: sdy.mesh @mesh2 = <["a"=4, "b"=2]> +// CHECK: sdy.mesh @mesh3 = <[]> + +// CHECK-LABEL: func @func_and_op_shardings +// CHECK-SAME: %arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh1, [{"b"}, {?}]>}, +// CHECK-SAME: %arg1: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh1, [{"d", ?}, {}], replicated={"b"}>}, +// CHECK-SAME: %arg2: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh2, [{"a"}, {"b"}]>} +// CHECK-SAME: ) -> ( +// CHECK-SAME: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh1, [{}, {?}]>}, +// CHECK-SAME: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh1, [{"b"}, {}]>}) { +func.func @func_and_op_shardings( + %arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh1, [{"a", "b"}, {"c", ?}]>}, + %arg1: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh1, [{"d", "e", ?}, {}], replicated={"b", "c"}>}, + %arg2: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh2, [{"a"}, {"b"}]>} +) -> (tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh1, [{"e"}, {"c", ?}]>}, + tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh1, [{"a", "b", "c"}, {}], replicated={"e"}>}) { + // CHECK-NEXT: %[[ADD1:.*]] = mhlo.add %arg0, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh1, [{"d", ?}, {?}]>]>} + // CHECK-NEXT: %[[ADD2:.*]] = mhlo.add %arg2, %arg2 + // CHECK-NOT: sdy.sharding + // CHECK-NEXT: %[[ADD3:.*]] = mhlo.add %[[ADD2]], %[[ADD2]] {sdy.sharding = #sdy.sharding_per_value<[<@mesh1, [{}, {}], replicated={"d"}>]>} + // CHECK-NEXT: return %[[ADD1]], %[[ADD3]] + // CHECK-NEXT: } + %0 = mhlo.add %arg0, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh1, [{"d", ?}, {"e", ?}]>]>} : tensor<8x8xf32> + %1 = mhlo.add %arg2, %arg2 : tensor<8x8xf32> + %2 = mhlo.add %1, %1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh1, [{"c"}, {}], replicated={"d"}>]>} : tensor<8x8xf32> + return %0, %2 : tensor<8x8xf32>, tensor<8x8xf32> +} + +// CHECK-LABEL: func @shardings_with_priorities +// CHECK-SAME: %arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh1, [{"b"}p0, {?}p3], replicated={"d"}>}, +// CHECK-SAME: %arg1: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh2, [{"a", ?}p2, {}]>} +// CHECK-SAME: ) -> ( +// CHECK-SAME: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh3, [{}, {?}p2]>}) { +func.func @shardings_with_priorities( + %arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh1, [{"a", "b"}p0, {"c", ?}p3], replicated={"d", "e"}>}, + %arg1: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh2, [{"a", ?}p2, {}]>} +) -> (tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh3, [{"x"}p1, {"y", ?}p2]>}) { + // CHECK-NEXT: %[[ADD1:.*]] = mhlo.add %arg0, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh1, [{"d", ?}p1, {?}]>]>} + // CHECK-NEXT: %[[ADD2:.*]] = mhlo.add %[[ADD1]], %[[ADD1]] + // CHECK-NOT: sdy.sharding + // CHECK-NEXT: return %[[ADD2]] + // CHECK-NEXT: } + %0 = mhlo.add %arg0, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh1, [{"d", ?}p1, {"e", ?}]>]>} : tensor<8x8xf32> + %1 = mhlo.add %0, %0 : tensor<8x8xf32> + return %1 : tensor<8x8xf32> +} + +// CHECK-LABEL: func @manual_computation +func.func @manual_computation(%arg0: tensor<8x16xf32>, %arg1: tensor<16x32xf32>) -> tensor<8x32xf32> { + // CHECK-NEXT: %[[MAN_COMP:.*]] = sdy.manual_computation(%arg0, %arg1) + // CHECK-SAME{LITERAL}: in_shardings=[<@mesh1, [{"d"}, {"b"}]>, <@mesh1, [{"b"}, {}], replicated={"d"}>] + // CHECK-SAME{LITERAL}: out_shardings=[<@mesh1, [{"d"}, {}], replicated={"b"}>] + // CHECK-SAME{LITERAL}: manual_axes={"b", "d"} + // CHECK-SAME: (%arg2: tensor<2x8xf32>, %arg3: tensor<8x32xf32>) { + // CHECK-NEXT: stablehlo.add %arg2, %arg2 {sdy.sharding = #sdy.sharding_per_value<[<@mesh1, [{?}, {?}]>]>} + // CHECK: } : (tensor<8x16xf32>, tensor<16x32xf32>) -> tensor<8x32xf32> + // CHECK-NEXT: return %[[MAN_COMP]] + %0 = sdy.manual_computation(%arg0, %arg1) + in_shardings=[<@mesh1, [{"d", "a"}, {"b"}]>, <@mesh1, [{"b"}, {"c", "a"}], replicated={"d"}>] + out_shardings=[<@mesh1, [{"d"}, {}], replicated={"b", "c"}>] + manual_axes={"a", "b", "c", "d"} (%arg2: tensor<2x8xf32>, %arg3: tensor<8x32xf32>) { + %1 = stablehlo.add %arg2, %arg2 {sdy.sharding = #sdy.sharding_per_value<[<@mesh1, [{?}, {"e", ?}]>]>} : tensor<2x8xf32> + %2 = stablehlo.dot %1, %arg3 : (tensor<2x8xf32>, tensor<8x32xf32>) -> tensor<2x32xf32> + %3 = "stablehlo.all_reduce"(%2) ({ + ^bb0(%arg4: tensor, %arg5: tensor): + %4 = stablehlo.add %arg4, %arg5 : tensor + stablehlo.return %4 : tensor + }) { + replica_groups = dense<[[0], [1]]> : tensor<2x1xi64> + } : (tensor<2x32xf32>) -> tensor<2x32xf32> + sdy.return %3 : tensor<2x32xf32> + } : (tensor<8x16xf32>, tensor<16x32xf32>) -> tensor<8x32xf32> + return %0 : tensor<8x32xf32> +} From c540d1a3d42552276266686e413acaa209335c53 Mon Sep 17 00:00:00 2001 From: Gregory Pataky Date: Fri, 27 Sep 2024 10:27:44 -0700 Subject: [PATCH 371/483] Add ARM tolerances to exhaustive tests PiperOrigin-RevId: 679641677 --- .../exhaustive_unary_complex_test.cc | 76 ++++++++++++------- .../exhaustive_unary_test_functions.cc | 28 +++++++ 2 files changed, 78 insertions(+), 26 deletions(-) diff --git a/third_party/xla/xla/tests/exhaustive/exhaustive_unary_complex_test.cc b/third_party/xla/xla/tests/exhaustive/exhaustive_unary_complex_test.cc index f93f5e31f544b0..d652e56bd1b6f4 100644 --- a/third_party/xla/xla/tests/exhaustive/exhaustive_unary_complex_test.cc +++ b/third_party/xla/xla/tests/exhaustive/exhaustive_unary_complex_test.cc @@ -138,16 +138,19 @@ UNARY_TEST_COMPLEX_64(Sqrt, { Run(Sqrt, [](complex64 x) { return std::sqrt(x); }, error_spec_gen); }) -double RsqrtCpuGpuAbsErr(complex64 x) { - return std::sqrt(std::numeric_limits::min()); +template +double RsqrtCpuGpuAbsErr(NativeT x) { + return std::sqrt(std::numeric_limits::min()); } -double RsqrtCpuGpuRelErr(complex64 x) { +template +double RsqrtCpuGpuRelErr(NativeT x) { // As noted above for Sqrt, the accuracy of sqrt degrades severely for // inputs with inputs with subnormals entries. - constexpr double eps = std::numeric_limits::epsilon(); - constexpr double norm_min = std::numeric_limits::min(); - constexpr double denorm_min = std::numeric_limits::denorm_min(); + constexpr double eps = std::numeric_limits::epsilon(); + constexpr double norm_min = std::numeric_limits::min(); + constexpr double denorm_min = + std::numeric_limits::denorm_min(); if (std::abs(x) < norm_min) { // Gradually loosen the relative tolerance as abs(x) becomes smaller // than norm_min, letting it reach 100% when abs(x) = 10 * denorm_min. @@ -164,9 +167,16 @@ UNARY_TEST_COMPLEX_64(Rsqrt, { if (IsCpu()) { error_spec_gen = +[](complex64 x) { return ErrorSpec::Builder() - .abs_err(RsqrtCpuGpuAbsErr(x)) - .rel_err(RsqrtCpuGpuRelErr(x)) + .abs_err(RsqrtCpuGpuAbsErr(x)) + .rel_err(RsqrtCpuGpuRelErr(x)) +#ifdef __aarch64__ + // TODO(b/365620546): ARM and x86 handle complex(inf, nan) + // differently. + .skip_comparison(x.real() == 0.0f || + (std::isinf(x.real()) && std::isnan(x.imag()))) +#else .skip_comparison(x.real() == 0.0f) +#endif .strict_signed_zeros(false) .build(); }; @@ -175,8 +185,8 @@ UNARY_TEST_COMPLEX_64(Rsqrt, { if (IsGpu()) { error_spec_gen = +[](complex64 x) { return ErrorSpec::Builder() - .abs_err(RsqrtCpuGpuAbsErr(x)) - .rel_err(RsqrtCpuGpuRelErr(x)) + .abs_err(RsqrtCpuGpuAbsErr(x)) + .rel_err(RsqrtCpuGpuRelErr(x)) .strict_signed_zeros(false) .build(); }; @@ -286,24 +296,38 @@ UNARY_TEST_COMPLEX_128(Sqrt, { }) UNARY_TEST_COMPLEX_128(Rsqrt, { - ErrorSpecGen error_spec_gen = +[](complex128 x) { - // As noted above for Sqrt, the accuracy of sqrt degrades severely for - // inputs with inputs with subnormals entries. - constexpr double norm_min = std::numeric_limits::min(); - constexpr double denorm_min = std::numeric_limits::denorm_min(); - if (std::abs(x) < norm_min) { - // Gradually loosen the relative tolerance as abs(x) becomes smaller - // than norm_min, letting it reach 100% when abs(x) = 10 * denorm_min. + ErrorSpecGen error_spec_gen = +[](complex128) { + return ErrorSpec::Builder().strict_signed_zeros().build(); + }; + + if (IsCpu()) { + error_spec_gen = +[](complex128 x) { return ErrorSpec::Builder() - .abs_err(std::sqrt(std::numeric_limits::min())) - .rel_err(10 * denorm_min / std::abs(x)) + .abs_err(RsqrtCpuGpuAbsErr(x)) + .rel_err(RsqrtCpuGpuRelErr(x)) +#ifdef __aarch64__ + // TODO(b/365620546): ARM and x86 handle complex(inf, nan) + // differently. + .skip_comparison(x.real() == 0.0f || + (std::isinf(x.real()) && std::isnan(x.imag()))) +#else + .skip_comparison(x.real() == 0.0f) +#endif + .strict_signed_zeros(false) .build(); - } - return ErrorSpec::Builder() - .abs_err(std::sqrt(std::numeric_limits::min())) - .rel_err(50 * std::numeric_limits::epsilon()) - .build(); - }; + }; + } + + if (IsGpu()) { + error_spec_gen = +[](complex128 x) { + return ErrorSpec::Builder() + .abs_err(RsqrtCpuGpuAbsErr(x)) + .rel_err(RsqrtCpuGpuRelErr(x)) + .strict_signed_zeros(false) + .build(); + }; + } + Run( Rsqrt, [](complex128 x) { return complex128(1, 0) / std::sqrt(x); }, error_spec_gen); diff --git a/third_party/xla/xla/tests/exhaustive/exhaustive_unary_test_functions.cc b/third_party/xla/xla/tests/exhaustive/exhaustive_unary_test_functions.cc index 9320e0c9ccedbc..5baa12f15ca455 100644 --- a/third_party/xla/xla/tests/exhaustive/exhaustive_unary_test_functions.cc +++ b/third_party/xla/xla/tests/exhaustive/exhaustive_unary_test_functions.cc @@ -211,6 +211,20 @@ UNARY_TEST(Sin, { NativeT eps = std::numeric_limits::epsilon(); return ErrorSpec::Builder().abs_err(0).rel_err(2 * eps).build(); }) + .CpuArmError(+[](NativeT val) { + // Flushes subnormals and minimum positive output to 0. + NativeT output = static_cast(std::sin(val)); + // TODO(b/365622116): Understand why ARM flushes these but x86 doesn't. + if (IsSubnormalOrMinNormal(output)) { + return ErrorSpec::Builder() + .abs_err(std::numeric_limits::min()) + .build(); + } + + // This error spec corresponds to a maximum relative error of 2 ULP. + NativeT eps = std::numeric_limits::epsilon(); + return ErrorSpec::Builder().abs_err(0).rel_err(2 * eps).build(); + }) .OutputRangeCheck( +[](NativeInputs in, NativeT out) { return !(out < -1 || out > 1); }) .Run(); @@ -222,6 +236,20 @@ UNARY_TEST(Tan, { NativeT eps = std::numeric_limits::epsilon(); return ErrorSpec::Builder().abs_err(0).rel_err(4 * eps).build(); }) + .CpuArmError(+[](NativeT val) { + // Flushes positive subnormals and minimum positive output to 0. + NativeT output = static_cast(std::tan(val)); + // TODO(b/365622116): Understand why ARM flushes these but x86 doesn't. + if (IsSubnormalOrMinNormal(output)) { + return ErrorSpec::Builder() + .abs_err(std::numeric_limits::min()) + .build(); + } + + // This error spec corresponds to a maximum relative error of 4 ULP. + NativeT eps = std::numeric_limits::epsilon(); + return ErrorSpec::Builder().abs_err(0).rel_err(4 * eps).build(); + }) .Run(); }) From 398037c847b4631b9dd8ca79e4d0657ea185b16a Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 27 Sep 2024 10:44:58 -0700 Subject: [PATCH 372/483] Move expensive variables on their last use to avoid copies. PiperOrigin-RevId: 679648192 --- tensorflow/core/kernels/linalg/einsum_op_impl.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/kernels/linalg/einsum_op_impl.h b/tensorflow/core/kernels/linalg/einsum_op_impl.h index 79c7e3f1729f2f..156bb80a02e1a9 100644 --- a/tensorflow/core/kernels/linalg/einsum_op_impl.h +++ b/tensorflow/core/kernels/linalg/einsum_op_impl.h @@ -603,7 +603,7 @@ class EinsumOp : public OpKernel { Tensor output; OP_REQUIRES_OK(ctx, EinsumHelper::TransposeOperand( ctx, output_inflated, output_permutation, &output)); - ctx->set_output(0, output); + ctx->set_output(0, std::move(output)); } string TraceString(const OpKernelContext& ctx, bool verbose) const override { From c88d62485bd2ead5c96fad16b9759fdcfdc02074 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 27 Sep 2024 11:13:52 -0700 Subject: [PATCH 373/483] Removes the scaling coefficient for our solver-specific parameter `max_deterministic_time`. PiperOrigin-RevId: 679659247 --- .../hlo/experimental/auto_sharding/auto_sharding_solver.cc | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc index c8321348ab7c26..8fca1bc7b81ab7 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc @@ -427,9 +427,8 @@ AutoShardingSolverResult CallORToolsSolver( ",share_binary_clauses:false,random_seed:1,interleave_search:true"); } if (request.has_solver_timeout()) { - absl::StrAppend( - &solver_parameter_str, ",max_deterministic_time:", - 0.1 * request.solver_timeout().solver_timeout_in_seconds()); + absl::StrAppend(&solver_parameter_str, ",max_deterministic_time:", + request.solver_timeout().solver_timeout_in_seconds()); } solver->SetSolverSpecificParametersAsString(solver_parameter_str); } From 91bb94cdfa6c22dcb6a4ce460717b9450307e561 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 27 Sep 2024 11:26:43 -0700 Subject: [PATCH 374/483] Reverts d736421e0465b072962d1c309822c3f7deaccd48 PiperOrigin-RevId: 679664358 --- tensorflow/compiler/mlir/lite/BUILD | 6 +- .../compiler/mlir/lite/flatbuffer_import.cc | 258 ++---------------- .../tests/flatbuffer2mlir/debug_metadata.mlir | 36 --- 3 files changed, 18 insertions(+), 282 deletions(-) delete mode 100644 tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/debug_metadata.mlir diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD index 2c26cf3cc5f166..7bb70a19f4f116 100644 --- a/tensorflow/compiler/mlir/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/BUILD @@ -1373,16 +1373,15 @@ cc_library( ], deps = [ ":const_tensor_utils", - ":control_edges", ":convert_type", ":flatbuffer_tflite_operator_lib", ":offset_buffer", ":size_utils", ":tensorflow_lite", + "//tensorflow/compiler/mlir/lite:control_edges", "//tensorflow/compiler/mlir/lite/core:absl_error_model_builder", "//tensorflow/compiler/mlir/lite/experimental/remat:metadata_util", "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps", - "//tensorflow/compiler/mlir/lite/schema:debug_metadata_fbs_with_mutable", "//tensorflow/compiler/mlir/lite/schema:schema_fbs", "//tensorflow/compiler/mlir/lite/schema:schema_fbs_with_mutable", "//tensorflow/compiler/mlir/lite/schema:schema_utils", @@ -1399,9 +1398,7 @@ cc_library( "//tensorflow/core/platform:errors", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log", "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@llvm-project//llvm:Analysis", "@llvm-project//llvm:Support", @@ -1414,7 +1411,6 @@ cc_library( "@llvm-project//mlir:ReconcileUnrealizedCasts", "@llvm-project//mlir:Support", "@llvm-project//mlir:TranslateLib", - "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:statusor", "@stablehlo//:stablehlo_ops", diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc index 97dfa6ce44a7ec..a289126d26b6ca 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc @@ -29,9 +29,7 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" -#include "absl/log/log.h" #include "absl/status/status.h" -#include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "llvm/ADT/APFloat.h" @@ -82,7 +80,6 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/offset_buffer.h" #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" -#include "tensorflow/compiler/mlir/lite/schema/mutable/debug_metadata_generated.h" #include "tensorflow/compiler/mlir/lite/schema/mutable/schema_generated.h" #include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" #include "tensorflow/compiler/mlir/lite/schema/schema_utils.h" @@ -102,7 +99,6 @@ limitations under the License. #include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/platform/errors.h" -#include "tsl/platform/errors.h" #include "tsl/platform/status.h" #include "tsl/platform/statusor.h" @@ -134,17 +130,6 @@ using ::mlir::tf_saved_model::kTfSavedModelExportedNamesAttr; using ::mlir::tf_saved_model::kTfSavedModelIndexPathAttr; using ::tflite::IsValidBufferOffset; -struct DebugMetadata { - // Debug metadata locations. - std::vector debug_metadata_locations; - - // Maps from operator (subgraph_debug_metadata_idx, - // operator_debug_metadata_idx) to its top-level location index in - // `debug_metadata_locations`, which is: - // <, location_idx>. - absl::flat_hash_map> operator_location_map; -}; - // Create the MLIR NamedLoc location corresponding to a given tensor Location TensorLoc(const TensorT& tensor, Builder builder, Location base) { if (tensor.name.empty()) { @@ -153,223 +138,27 @@ Location TensorLoc(const TensorT& tensor, Builder builder, Location base) { return mlir::NameLoc::get(builder.getStringAttr(tensor.name), base); } -// Build and return the MLIR location. -StatusOr BuildLocation( - Builder builder, const debug_metadata::Location& location, - const std::vector& debug_metadata_locations, - const absl::flat_hash_map& - attribute_location_idx_map) { - switch (location.location_type()) { - // FileLineColLoc. - case debug_metadata::LocationType_FileLineColLoc: { - auto file_line_col_loc = - static_cast( - location.location()); - return mlir::FileLineColLoc::get( - builder.getContext(), - builder.getStringAttr(file_line_col_loc->filename()->string_view()), - file_line_col_loc->line(), file_line_col_loc->column()); - } - // CallSiteLoc. - case debug_metadata::LocationType_CallSiteLoc: { - auto callsite_loc = - static_cast(location.location()); - if (!attribute_location_idx_map.contains(callsite_loc->callee_index()) || - !attribute_location_idx_map.contains(callsite_loc->caller_index())) { - return absl::InternalError( - "Invalid/corrupt DebugMetadata, expected invariant broken (callee " - "or caller index of a CallSiteLoc is not valid)"); - } - return mlir::CallSiteLoc::get( - debug_metadata_locations[attribute_location_idx_map.at( - callsite_loc->callee_index())], - debug_metadata_locations[attribute_location_idx_map.at( - callsite_loc->caller_index())]); - } - // NameLoc. - case debug_metadata::LocationType_NameLoc: { - auto name_loc = - static_cast(location.location()); - if (!attribute_location_idx_map.contains(name_loc->child_index())) { - return absl::InternalError( - "Invalid/corrupt DebugMetadata, expected invariant broken (child " - "index of a NameLoc is not valid)"); - } - return mlir::NameLoc::get( - builder.getStringAttr(name_loc->name()->string_view()), - debug_metadata_locations[attribute_location_idx_map.at( - name_loc->child_index())]); - } - // FusedLoc. - case debug_metadata::LocationType_FusedLoc: { - auto fused_loc = - static_cast(location.location()); - auto fused_location_indexes = fused_loc->location_indexes(); - std::vector fused_locations; - fused_locations.reserve(fused_location_indexes->size()); - for (int fused_loc_idx = 0; - fused_loc_idx < fused_location_indexes->size(); ++fused_loc_idx) { - if (!attribute_location_idx_map.contains( - fused_location_indexes->Get(fused_loc_idx))) { - return absl::InternalError( - "Invalid/corrupt DebugMetadata, expected invariant broken " - "(location index of a FusedLoc is not valid)"); - } - fused_locations.push_back( - debug_metadata_locations[attribute_location_idx_map.at( - fused_location_indexes->Get(fused_loc_idx))]); - } - return mlir::FusedLoc::get( - fused_locations, mlir::StringAttr::get(builder.getContext(), ""), - builder.getContext()); - } - default: { - return mlir::UnknownLoc::get(builder.getContext()); - } - } -} - -// Parses all locations in ConversionDebugMetadata, build the mlir::location -// counterparts, and put them inside debug_metadata_. Additionally, maintain a -// map that maps the top location index of each operator. -Status ParseAndBuildLocation( - Builder builder, - const debug_metadata::ConversionDebugMetadata* conversion_debug_metadata, - DebugMetadata& debug_metadata_var) { - auto attribute_types = conversion_debug_metadata->attributes_type(); - auto attributes = conversion_debug_metadata->attributes(); - - auto& debug_metadata_locations = debug_metadata_var.debug_metadata_locations; - debug_metadata_locations.reserve(attribute_types->size()); - - // Map index in the attribute_vector to the index in the data structure we - // are building: DebugMetadata::debug_metadata_locations. - absl::flat_hash_map attribute_location_idx_map; - - for (int i = 0; i < attribute_types->size(); ++i) { - if (attribute_types->Get(i) == debug_metadata::Attribute_Location) { - auto location = - static_cast(attributes->Get(i)); - TF_ASSIGN_OR_RETURN( - auto mlir_location, - BuildLocation(builder, *location, debug_metadata_locations, - attribute_location_idx_map)); - debug_metadata_locations.push_back(mlir_location); - - // Create index mapping. - attribute_location_idx_map[i] = debug_metadata_locations.size() - 1; - } - } - - // Collect the top location idx of each operator. - auto subgraphs_debug_metadata = - conversion_debug_metadata->subgraphs_debug_metadata(); - for (int subgraph_idx = 0; subgraph_idx < subgraphs_debug_metadata->size(); - ++subgraph_idx) { - const auto* subgraph_debug_metadata = - subgraphs_debug_metadata->Get(subgraph_idx); - auto operators_debug_metadata = - subgraph_debug_metadata->operators_debug_metadata(); - for (int operator_idx = 0; operator_idx < operators_debug_metadata->size(); - ++operator_idx) { - const auto* operator_debug_metadata = - operators_debug_metadata->Get(operator_idx); - // Find the location attribute of the operator. Note that there should - // be at most one idx pointing to location attribute for each operator. - std::vector location_attribute_idxs; - for (int i = 0; - i < operator_debug_metadata->attribute_metadata_indexes()->size(); - ++i) { - auto attribute_idx = - operator_debug_metadata->attribute_metadata_indexes()->Get(i); - if (attribute_types->Get(attribute_idx) == - debug_metadata::Attribute_Location) { - location_attribute_idxs.push_back(attribute_idx); - } - } - if (location_attribute_idxs.size() > 1) { - return absl::InternalError( - "Invalid/corrupt DebugMetadata, expected invariant broken (more " - "than one location attribute for an operator)"); - } - if (location_attribute_idxs.empty()) { - continue; - } - - if (!attribute_location_idx_map.contains(location_attribute_idxs[0])) { - return absl::InternalError( - "Invalid/corrupt DebugMetadata, expected invariant broken " - "(location attribute index of an operator is not valid)"); - } - debug_metadata_var.operator_location_map[subgraph_idx][operator_idx] = - attribute_location_idx_map[location_attribute_idxs[0]]; - } - } - - return absl::OkStatus(); -} - -// Parse the DebugMetadata flatbuffer and store debug metadata in struct -// `debug_metadata`. -Status ParseDebugMetadata(Builder builder, const char* data, size_t size, - DebugMetadata& debug_metadata_var) { - auto debug_metadata_fb = debug_metadata::GetDebugMetadata(data); - - if (debug_metadata_fb->debug_metadata_type()->size() != - debug_metadata_fb->debug_metadata()->size()) { - return absl::InternalError( - "Invalid/corrupt DebugMetadata, expected invariant broken (size of " - "debug_metadata_type and debug_metadata not equal)"); - } - - for (int i = 0; i < debug_metadata_fb->debug_metadata_type()->size(); ++i) { - if (debug_metadata_fb->debug_metadata_type()->Get(i) == - debug_metadata::DebugMetadataType_ConversionDebugMetadata) { - auto conversion_debug_metadata = - static_cast( - debug_metadata_fb->debug_metadata()->Get(i)); - TF_RETURN_IF_ERROR(ParseAndBuildLocation( - builder, conversion_debug_metadata, debug_metadata_var)); - } else { - LOG(WARNING) << "Unsupported DebugMetadataType: " - << debug_metadata_fb->debug_metadata_type()->Get(i); - } - } - - return absl::OkStatus(); -} - -// Return MLIR location if it exists in the debug metadata. Otherwise, create a -// MLIR location by fusing its output tensor names. -Location OpLoc(const OperatorT& op, Builder builder, - DebugMetadata& debug_metadata, const tflite::SubGraphT& subgraph, - Location base) { - const int subgraph_debug_metadata_idx = subgraph.debug_metadata_index; - if (debug_metadata.operator_location_map.contains( - subgraph_debug_metadata_idx) && - debug_metadata.operator_location_map[subgraph_debug_metadata_idx] - .contains(op.debug_metadata_index)) { - int location_idx = - debug_metadata.operator_location_map[subgraph_debug_metadata_idx] - [op.debug_metadata_index]; - return debug_metadata.debug_metadata_locations[location_idx]; - } - +// Create the MLIR Location corresponding to a given op. This is an +// experimental/debugging feature and production code should not rely on names +// of intermediate tensors since importer doesn't guarantee to preserve tensor +// names except output tensors. +Location OpLoc(const OperatorT& op, + const std::vector>& tensors, + Builder builder, Location base) { if (op.outputs.empty()) return base; llvm::SmallVector locations; locations.reserve(op.outputs.size()); for (auto tensor_index : op.outputs) { - locations.push_back( - TensorLoc(*subgraph.tensors[tensor_index], builder, base)); + locations.push_back(TensorLoc(*tensors[tensor_index], builder, base)); } return mlir::FusedLoc::get(builder.getContext(), locations); } // Extract the min max information in the tensor and create the quant stats op. -// If the input `tensor` has scale/zero_point, `res` should have quantized type, -// thus none stats op is required and nullptr is returned. If the min max -// information is invalid, nullptr is returned. +// If the input `tensor` has scale/zero_point, `res` should have quantized +// type, thus none stats op is required and nullptr is returned. +// If the min max information is invalid, nullptr is returned. mlir::Operation* ConvertMinMaxToStatsOp(const TensorT& tensor, OpBuilder b, Value res) { // If the `tensor` has scale/zero_point, it must have been quantized, then the @@ -889,8 +678,8 @@ StatusOr ConvertOp( } // While the last several tensors could be optional tensors for an tfl op, the - // number of input operands could vary. Gets the min/max number of operands - // from tflite op name. + // number of input operands could vary. Gets the min/max number of + // operands from tflite op name. // Also, since the above code special-handles the `tfl.reshape` op and add an // additional input, we put these function block here. llvm::MinMax input_min_max = mlir::OperandNumbersMinMax(op_name); @@ -1328,7 +1117,7 @@ StatusOr ConvertSubgraph( const tflite::SignatureDefT* signature, const tflite::ControlEdges& control_edges, const std::unique_ptr& model_ptr, - bool use_stablehlo_constant, DebugMetadata& debug_metadata) { + bool use_stablehlo_constant) { // Populate from metadata. ControlNodes control_nodes; for (const auto [from, to] : control_edges) { @@ -1512,12 +1301,11 @@ StatusOr ConvertSubgraph( TF_ASSIGN_OR_RETURN( mlir::TensorType type, tfl::GetTensorType(*subgraph.tensors[intermediate], builder, - /*is_constant=*/false, - /*is_intermediate=*/true)); + /*is_constant=*/false, /*is_intermediate=*/true)); intermediate_types.emplace_back(type); } - auto op_loc = OpLoc(*op, builder, debug_metadata, subgraph, base_loc); + auto op_loc = OpLoc(*op, subgraph.tensors, builder, base_loc); // If there's an optional argument, maybe_optional_arg_marker has been set // to a valid Value @@ -1747,7 +1535,6 @@ OwningOpRef tflite::FlatBufferToMlir( llvm::SmallVector metadata_attrs; mlir::StringSet<> seen_attr; - DebugMetadata debug_metadata; for (const auto& metadata : model->metadata) { if (metadata->name == tflite::kModelControlDependenciesMetadataKey) { const std::vector& data = model->buffers[metadata->buffer]->data; @@ -1772,17 +1559,6 @@ OwningOpRef tflite::FlatBufferToMlir( continue; } - if (metadata->name == "debug_metadata") { - const std::vector& data = model->buffers[metadata->buffer]->data; - auto status = ParseDebugMetadata( - builder, reinterpret_cast(data.data()), data.size(), - debug_metadata); - if (!status.ok()) { - return emitError(base_loc, std::string(status.message())), nullptr; - } - continue; - } - std::vector buffer = model->buffers[metadata->buffer]->data; metadata_attrs.emplace_back( builder.getStringAttr(metadata->name), @@ -1842,7 +1618,7 @@ OwningOpRef tflite::FlatBufferToMlir( ? subgraph_to_signature_map.at(subgraph_index) : nullptr, model_control_dependencies[subgraph_index], model_ptr, - use_stablehlo_constant, debug_metadata); + use_stablehlo_constant); if (!func_or_error.ok()) { return emitError(base_loc, "could not translate function ") << subgraph->name << ": " << func_or_error.status().message(), diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/debug_metadata.mlir b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/debug_metadata.mlir deleted file mode 100644 index 61df9ad531515a..00000000000000 --- a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/debug_metadata.mlir +++ /dev/null @@ -1,36 +0,0 @@ -// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer --serialize-debug-metadata=true %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir --mlir-print-debuginfo -o - | FileCheck %s -// This test verifies that debug locations are round-trippable. - -module @jit_relu attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32, tfl._legalize_tfl_variables = true} { - func.func @main(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> { - %0 = "tfl.less"(%arg0, %arg1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xi1> loc(#loc) - // CHECK-DAG: {{.*}} = tfl.less(%arg0, %arg1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xi1> loc([[LOC:.+]]) - %1 = "tf.If"(%0, %arg0, %arg1) {else_branch = @cond_false, then_branch = @cond_true, is_stateless = false} : (tensor<1xi1>, tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> loc(#loc) - // CHECK-DAG: {{.*}} = "tf.If"(%0, %arg0, %arg1) {{.*}} -> tensor<1xf32> loc([[LOC]]) - func.return %1 : tensor<1xf32> loc(#loc) - } - - func.func @cond_true(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> { - %0 = tfl.add %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<*xf32> loc(#loc4) - // CHECK-DAG: {{.*}} = tfl.add %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<*xf32> loc([[LOC4:.+]]) - func.return %0 : tensor<*xf32> loc(#loc) - } - - func.func @cond_false(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> { - %0 = tfl.mul %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<*xf32> loc(#loc5) - // CHECK-DAG: {{.*}} = tfl.mul %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<*xf32> loc([[LOC5:.+]]) - func.return %0 : tensor<*xf32> loc(#loc) - } -} loc(#loc) -#loc = loc(unknown) -// CHECK-DAG: [[LOC]] = loc(unknown) -#loc1 = loc("":1:4) -// CHECK-DAG: [[LOC1:.+]] = loc("":1:4) -#loc2 = loc("third_party/py/IPython/v3_2_3/core/interactiveshell.py":3066:16) -// CHECK-DAG: [[LOC2:.+]] = loc("third_party/py/IPython/v3_2_3/core/interactiveshell.py":3066:16) -#loc3 = loc(callsite(#loc1 at #loc2)) -// CHECK-DAG: [[LOC3:.+]] = loc(callsite([[LOC1]] at [[LOC2]])) -#loc4 = loc("jit(relu)/jit(main)/max"(#loc3)) -// CHECK-DAG: [[LOC4]] = loc("jit(relu)/jit(main)/max"([[LOC3]])) -#loc5 = loc(fused<"">[#loc1, #loc2]) -// CHECK-DAG: [[LOC5]] = loc(fused<"">[[[LOC1]], [[LOC2]]]) \ No newline at end of file From dd300644722249493ae762b1e2fffab498532388 Mon Sep 17 00:00:00 2001 From: Bart Chrzaszcz Date: Fri, 27 Sep 2024 11:30:42 -0700 Subject: [PATCH 375/483] Use `ShardyCallInliner` in XLA GPU pipeline. PiperOrigin-RevId: 679665714 --- third_party/xla/xla/service/gpu/BUILD | 1 + third_party/xla/xla/service/gpu/gpu_compiler.cc | 7 ++++--- .../xla/xla/service/spmd/shardy/shardy_call_inliner.h | 4 ++++ 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index dbac4b1661ab61..90e0374e4b47ed 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -1465,6 +1465,7 @@ cc_library( "//xla/service/gpu/transforms:triton_fusion_numerics_verifier", "//xla/service/gpu/transforms:windowed_einsum_handler", "//xla/service/llvm_ir:llvm_util", + "//xla/service/spmd/shardy:shardy_call_inliner", "//xla/service/spmd:collective_permute_motion", "//xla/service:algebraic_simplifier", "//xla/service:all_gather_broadcast_reorder", diff --git a/third_party/xla/xla/service/gpu/gpu_compiler.cc b/third_party/xla/xla/service/gpu/gpu_compiler.cc index 0cb1e5161bd363..1f0b2b82509fde 100644 --- a/third_party/xla/xla/service/gpu/gpu_compiler.cc +++ b/third_party/xla/xla/service/gpu/gpu_compiler.cc @@ -224,6 +224,7 @@ limitations under the License. #include "xla/service/slice_sinker.h" #include "xla/service/slow_operation_alarm.h" #include "xla/service/sort_simplifier.h" +#include "xla/service/spmd/shardy/shardy_call_inliner.h" #include "xla/service/stable_sort_expander.h" #include "xla/service/stochastic_convert_decomposer.h" #include "xla/service/sub_byte_normalization.h" @@ -551,7 +552,7 @@ absl::Status RunPreSPMDPartitionerPasses(HloModule* hlo_module) { // passes. pre_spmd_pipeline.AddPass(); pre_spmd_pipeline.AddPass(); - pre_spmd_pipeline.AddPass(); + pre_spmd_pipeline.AddPass(); pre_spmd_pipeline.AddPass(); pre_spmd_pipeline.AddPass(); @@ -708,7 +709,7 @@ absl::Status RunOptimizationPasses( pipeline.AddPass(); // TODO(b/64094172): make Call work on GPU instead of inlining. - pipeline.AddPass(); + pipeline.AddPass(); pipeline.AddPass(); @@ -1585,7 +1586,7 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment( options.key_value_store, gpu_target_config.device_description.runtime_version())); // Inline back the calls which have better performance with cuBLAS. - pipeline.AddPass(); + pipeline.AddPass(); // TODO(tdanyluk): Apply CublasPadForGemms to the cuBLAS GEMMs generated // here for possibly better cuBLAS performance. AddGemmRewriterPasses(pipeline, debug_options, gpu_version, diff --git a/third_party/xla/xla/service/spmd/shardy/shardy_call_inliner.h b/third_party/xla/xla/service/spmd/shardy/shardy_call_inliner.h index 666e168322b5ab..9dbc52682a60a2 100644 --- a/third_party/xla/xla/service/spmd/shardy/shardy_call_inliner.h +++ b/third_party/xla/xla/service/spmd/shardy/shardy_call_inliner.h @@ -46,6 +46,10 @@ namespace xla { // sure we inline all functions except for the shmap_body's when using // Shardy. When Shardy is disabled, then we have the same behavior as // CallInliner. +// +// TODO(bartchr): Move the logic in here into the regular XLA `CallInliner`. +// Shardy is now proven out so we should have the parent `CallInliner` handle +// this. class ShardyCallInliner : public CallInliner { public: using CallInliner::CallInliner; From 19f3badefe887cbc98de85529267cb110ac518f5 Mon Sep 17 00:00:00 2001 From: Eric Salo Date: Fri, 27 Sep 2024 11:34:51 -0700 Subject: [PATCH 376/483] cleanup: remove api_version from BUILD files PiperOrigin-RevId: 679667421 --- tensorflow/core/protobuf/tpu/BUILD | 6 ------ tensorflow/tools/proto_splitter/BUILD | 2 -- tensorflow/tools/proto_splitter/testdata/BUILD | 1 - 3 files changed, 9 deletions(-) diff --git a/tensorflow/core/protobuf/tpu/BUILD b/tensorflow/core/protobuf/tpu/BUILD index 78d8761223b60b..f714172ae2b490 100644 --- a/tensorflow/core/protobuf/tpu/BUILD +++ b/tensorflow/core/protobuf/tpu/BUILD @@ -83,42 +83,36 @@ tf_pyclif_proto_library( # copybara:uncomment_begin(google-only) # py_proto_library( # name = "tpu_embedding_configuration_py_pb2", -# api_version = 2, # visibility = ["//visibility:public"], # deps = [":tpu_embedding_configuration_proto"], # ) # # py_proto_library( # name = "optimization_parameters_py_pb2", -# api_version = 2, # visibility = ["//visibility:public"], # deps = [":optimization_parameters_proto"], # ) # # py_proto_library( # name = "topology_py_pb2", -# api_version = 2, # visibility = ["//visibility:public"], # deps = [":topology_proto"], # ) # # py_proto_library( # name = "dynamic_padding_py_pb2", -# api_version = 2, # visibility = ["//visibility:public"], # deps = [":dynamic_padding_proto"], # ) # # py_proto_library( # name = "compilation_result_py_pb2", -# api_version = 2, # visibility = ["//visibility:public"], # deps = [":compilation_result_proto"], # ) # # py_proto_library( # name = "compile_metadata_py_pb2", -# api_version = 2, # visibility = ["//visibility:public"], # deps = [":compile_metadata_proto"], # ) diff --git a/tensorflow/tools/proto_splitter/BUILD b/tensorflow/tools/proto_splitter/BUILD index b05a7c3688ff58..5aa26d8e9aae33 100644 --- a/tensorflow/tools/proto_splitter/BUILD +++ b/tensorflow/tools/proto_splitter/BUILD @@ -53,7 +53,6 @@ cc_library( # # py_proto_library( # name = "versions_proto_py_pb2", -# api_version = 2, # deps = [ # ":versions_proto", # ], @@ -61,7 +60,6 @@ cc_library( # # py_proto_library( # name = "chunk_proto_py_pb2", -# api_version = 2, # deps = [ # ":chunk_proto", # ], diff --git a/tensorflow/tools/proto_splitter/testdata/BUILD b/tensorflow/tools/proto_splitter/testdata/BUILD index 5ab95c1313e7a4..11c508c8783fbb 100644 --- a/tensorflow/tools/proto_splitter/testdata/BUILD +++ b/tensorflow/tools/proto_splitter/testdata/BUILD @@ -53,7 +53,6 @@ tf_proto_library( # # py_proto_library( # name = "test_message_proto_py_pb2", -# api_version = 2, # deps = [ # ":test_message_proto", # ], From 51bb43dcc09307e1f13ef939c0bc7c849943831f Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Fri, 27 Sep 2024 11:41:49 -0700 Subject: [PATCH 377/483] [xla:cpu] Prefer sequential execution from small thunk sequences PiperOrigin-RevId: 679670532 --- .../xla/xla/backends/cpu/runtime/thunk_executor.cc | 13 ++++++++++--- .../xla/xla/backends/cpu/runtime/thunk_executor.h | 11 ++++++++--- .../xla/backends/cpu/runtime/thunk_executor_test.cc | 6 ++---- 3 files changed, 20 insertions(+), 10 deletions(-) diff --git a/third_party/xla/xla/backends/cpu/runtime/thunk_executor.cc b/third_party/xla/xla/backends/cpu/runtime/thunk_executor.cc index 62df8d8f5b12a7..028f97e3f3e887 100644 --- a/third_party/xla/xla/backends/cpu/runtime/thunk_executor.cc +++ b/third_party/xla/xla/backends/cpu/runtime/thunk_executor.cc @@ -60,6 +60,7 @@ ThunkExecutor::ThunkExecutor(ThunkSequence thunk_sequence, sink_.push_back(i); } } + // Erase redundant edges between nodes. int64_t num_erased_edges = RunTransitiveReductionAndUpdatePriorities(); @@ -69,7 +70,7 @@ ThunkExecutor::ThunkExecutor(ThunkSequence thunk_sequence, is_sequential_ &= (absl::c_count(nodes_defs_[i].in_edges, i - 1) != 0); } - // Maybe mark execution as sequential if all thunks use small buffers. + // Prefer sequential execution if all thunks use small buffers. auto uses_small_buffers = [&](const std::unique_ptr& thunk) { return absl::c_all_of(thunk->buffer_uses(), [&](const BufferUse& use) { return use.slice().size() <= options.execute_sequential_buffer_threshold; @@ -79,6 +80,10 @@ ThunkExecutor::ThunkExecutor(ThunkSequence thunk_sequence, bool small_buffers = absl::c_all_of(thunk_sequence_, uses_small_buffers); is_sequential_ |= small_buffers; + // Prefer sequential execution for small thunk sequences. + is_sequential_ |= + thunk_sequence_.size() <= options.execute_sequential_num_thunks_threshold; + VLOG(2) << absl::StreamFormat( "Constructed ThunkExecutor with %d nodes: #source_nodes=%d " "#sink_nodes=%d, #erased_edges=%d, is_sequential=%v, small_buffers=%v", @@ -159,8 +164,10 @@ tsl::AsyncValueRef ThunkExecutor::Execute( return thunk_sequence_[0]->Execute(params); } - // If thunk sequence dependencies form a sequential execution graph, we skip - // expensive async execution and simply run thunks one by one. + // When we choose sequential execution strategy (we rely on heuristics and + // a cost model to make the decision), we skip expensive async execution and + // simply run thunks one by one. This minimizes runtime overheads from small + // XLA programs with many cheap operations. if (is_sequential_) { return ExecuteSequential(params); } diff --git a/third_party/xla/xla/backends/cpu/runtime/thunk_executor.h b/third_party/xla/xla/backends/cpu/runtime/thunk_executor.h index 5ba15b0432b504..c85fe39fdc9b15 100644 --- a/third_party/xla/xla/backends/cpu/runtime/thunk_executor.h +++ b/third_party/xla/xla/backends/cpu/runtime/thunk_executor.h @@ -42,11 +42,16 @@ namespace internal { // Clang does not allow defining a nested struct with member initializer, as // a workaround we define a struct in internal namespace and create an alias. struct ThunkExecutorOptions { - // If all thunks in a sequence use buffers of size less than or equal to - // `execute_sequential_buffer_threshold`, we mark execution as sequential, as - // concurrency overheads will likely dominate the overall execution time. + // If all thunks in a sequence use buffers of size less than or equal to the + // given threshold, we mark execution as sequential, as concurrency overheads + // will likely dominate the overall execution time. size_t execute_sequential_buffer_threshold = 512; + // If thunk sequence length is less than or equal to the given threshold, we + // mark execution as sequential, as concurrency overheads will likely dominate + // the overall execution time. + size_t execute_sequential_num_thunks_threshold = 8; + // Use priority ready queue to execute nodes according to their priority. By // default we use FIFO ready queue. bool use_priority_ready_queue = false; diff --git a/third_party/xla/xla/backends/cpu/runtime/thunk_executor_test.cc b/third_party/xla/xla/backends/cpu/runtime/thunk_executor_test.cc index d0deefe5d4880d..2a97d0dd0c48ce 100644 --- a/third_party/xla/xla/backends/cpu/runtime/thunk_executor_test.cc +++ b/third_party/xla/xla/backends/cpu/runtime/thunk_executor_test.cc @@ -213,10 +213,8 @@ AddI32Thunk::ResourceUses AddI32Thunk::resource_uses() const { } static ThunkExecutor::Options OptionsForTest() { - // Override small buffers threshold to make sure that we test all execution - // paths, because in test we always use small buffers below the default - // threshold of `512`. - return ThunkExecutor::Options{/*execute_sequential_buffer_threshold=*/0}; + return ThunkExecutor::Options{/*execute_sequential_buffer_threshold=*/0, + /*execute_sequential_num_thunks_threshold=*/0}; } TEST(ThunkExecutorTest, FifoReadyQueueTest) { From 7946c19b622ad44f84a5ee3b631a85e5376346de Mon Sep 17 00:00:00 2001 From: Eric Salo Date: Fri, 27 Sep 2024 12:05:39 -0700 Subject: [PATCH 378/483] cleanup: remove api_version from BUILD files PiperOrigin-RevId: 679679636 --- tensorflow/lite/profiling/proto/BUILD | 2 -- tensorflow/lite/toco/BUILD | 3 --- tensorflow/lite/toco/logging/BUILD | 1 - tensorflow/lite/tools/BUILD | 1 - tensorflow/lite/tools/evaluation/proto/BUILD | 1 - 5 files changed, 8 deletions(-) diff --git a/tensorflow/lite/profiling/proto/BUILD b/tensorflow/lite/profiling/proto/BUILD index 907ac4df3e0e3b..4ce67e6947d0bf 100644 --- a/tensorflow/lite/profiling/proto/BUILD +++ b/tensorflow/lite/profiling/proto/BUILD @@ -42,14 +42,12 @@ tf_proto_library( # copybara:uncomment_begin(google-only) # py_proto_library( # name = "profiling_info_py_pb2", -# api_version = 2, # compatible_with = get_compatible_with_portable(), # deps = [":profiling_info_proto"], # ) # # py_proto_library( # name = "model_runtime_info_py_pb2", -# api_version = 2, # compatible_with = get_compatible_with_portable(), # deps = [":model_runtime_info_proto"], # ) diff --git a/tensorflow/lite/toco/BUILD b/tensorflow/lite/toco/BUILD index 19644b6678110b..2c2e5e41081a9c 100644 --- a/tensorflow/lite/toco/BUILD +++ b/tensorflow/lite/toco/BUILD @@ -519,21 +519,18 @@ tf_cc_test( # copybara:uncomment_begin(google-only) # py_proto_library( # name = "model_flags_proto_py", -# api_version = 2, # visibility = ["//visibility:public"], # deps = [":model_flags_proto"], # ) # # py_proto_library( # name = "toco_flags_proto_py", -# api_version = 2, # visibility = ["//visibility:public"], # deps = [":toco_flags_proto"], # ) # # py_proto_library( # name = "types_proto_py", -# api_version = 2, # visibility = ["//visibility:public"], # deps = [":toco_flags_proto"], # ) diff --git a/tensorflow/lite/toco/logging/BUILD b/tensorflow/lite/toco/logging/BUILD index 2b9f9205f86022..06c83facb5f977 100644 --- a/tensorflow/lite/toco/logging/BUILD +++ b/tensorflow/lite/toco/logging/BUILD @@ -107,7 +107,6 @@ py_strict_test( # copybara:uncomment_begin(google-only) # py_proto_library( # name = "toco_conversion_log_proto_py", -# api_version = 2, # visibility = ["//visibility:public"], # deps = [":toco_conversion_log_proto"], # ) diff --git a/tensorflow/lite/tools/BUILD b/tensorflow/lite/tools/BUILD index b9260be0b9eac3..350e7d1a33ad6e 100644 --- a/tensorflow/lite/tools/BUILD +++ b/tensorflow/lite/tools/BUILD @@ -465,7 +465,6 @@ tflite_portable_test_suite() # # py_proto_library( # name = "op_kernel_set_py_pb2", -# api_version = 2, # deps = [":op_kernel_set_proto"], # ) # copybara:uncomment_end diff --git a/tensorflow/lite/tools/evaluation/proto/BUILD b/tensorflow/lite/tools/evaluation/proto/BUILD index 696876aae1f1ac..524cf5962d4d6f 100644 --- a/tensorflow/lite/tools/evaluation/proto/BUILD +++ b/tensorflow/lite/tools/evaluation/proto/BUILD @@ -92,7 +92,6 @@ cc_proto_library( # copybara:uncomment_begin(google-only) # py_proto_library( # name = "evaluation_stages_py_pb2", -# api_version = 2, # visibility = ["//visibility:public"], # deps = [":evaluation_stages_proto"], # ) From 773c296448128d156f872207b521a8597fc5ffaf Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Fri, 27 Sep 2024 12:19:48 -0700 Subject: [PATCH 379/483] [xla:ffi] Add support for encoding mlir::DictionaryAttr PiperOrigin-RevId: 679684896 --- third_party/xla/docs/custom_call.md | 4 +- .../backends/cpu/runtime/custom_call_thunk.cc | 2 +- third_party/xla/xla/ffi/BUILD | 2 + third_party/xla/xla/ffi/api/ffi_test.cc | 8 +- third_party/xla/xla/ffi/attribute_map.cc | 203 +++++++++--------- third_party/xla/xla/ffi/attribute_map.h | 2 +- third_party/xla/xla/ffi/call_frame.cc | 15 +- third_party/xla/xla/ffi/call_frame.h | 15 +- third_party/xla/xla/ffi/call_frame_test.cc | 6 +- third_party/xla/xla/ffi/ffi_test.cc | 6 +- .../service/cpu/runtime_handle_ffi_call.cc | 2 +- .../xla/xla/service/gpu/custom_call_test.cc | 28 ++- .../service/gpu/runtime/custom_call_thunk.h | 4 +- 13 files changed, 161 insertions(+), 136 deletions(-) diff --git a/third_party/xla/docs/custom_call.md b/third_party/xla/docs/custom_call.md index 2471df68331057..1bd39c0e070405 100644 --- a/third_party/xla/docs/custom_call.md +++ b/third_party/xla/docs/custom_call.md @@ -267,8 +267,8 @@ struct Range { int64_t hi; }; -XLA_FFI_REGISTER_STRUCT_ATTR_DECODING(Range, StructMember("i64"), - StructMember("i64")); +XLA_FFI_REGISTER_STRUCT_ATTR_DECODING(Range, StructMember("lo"), + StructMember("hi")); auto handler = Ffi::Bind().Attr("range").To([](Range range) -> Error{ return Error::Success(); diff --git a/third_party/xla/xla/backends/cpu/runtime/custom_call_thunk.cc b/third_party/xla/xla/backends/cpu/runtime/custom_call_thunk.cc index 8ce3106213b07f..8f693a1e3c5378 100644 --- a/third_party/xla/xla/backends/cpu/runtime/custom_call_thunk.cc +++ b/third_party/xla/xla/backends/cpu/runtime/custom_call_thunk.cc @@ -60,7 +60,7 @@ limitations under the License. namespace xla::cpu { namespace { -using AttributesMap = ffi::CallFrameBuilder::FlatAttributesMap; +using AttributesMap = ffi::CallFrameBuilder::AttributesMap; absl::StatusOr ParseAttributes( absl::string_view backend_config) { diff --git a/third_party/xla/xla/ffi/BUILD b/third_party/xla/xla/ffi/BUILD index f8105e33fffe54..2e163a8fcbc6ee 100644 --- a/third_party/xla/xla/ffi/BUILD +++ b/third_party/xla/xla/ffi/BUILD @@ -180,6 +180,7 @@ cc_library( hdrs = ["attribute_map.h"], deps = [ ":call_frame", + "//xla:util", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -187,6 +188,7 @@ cc_library( "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", ], ) diff --git a/third_party/xla/xla/ffi/api/ffi_test.cc b/third_party/xla/xla/ffi/api/ffi_test.cc index bea5176a560b6e..74837790c8449c 100644 --- a/third_party/xla/xla/ffi/api/ffi_test.cc +++ b/third_party/xla/xla/ffi/api/ffi_test.cc @@ -822,10 +822,10 @@ TEST(FfiTest, AttrsAsDictionary) { } TEST(FfiTest, DictionaryAttr) { - CallFrameBuilder::FlatAttributesMap dict0; + CallFrameBuilder::AttributesMap dict0; dict0.try_emplace("i32", 42); - CallFrameBuilder::FlatAttributesMap dict1; + CallFrameBuilder::AttributesMap dict1; dict1.try_emplace("f32", 42.0f); CallFrameBuilder::AttributesBuilder attrs; @@ -864,7 +864,7 @@ TEST(FfiTest, DictionaryAttr) { } TEST(FfiTest, StructAttr) { - CallFrameBuilder::FlatAttributesMap dict; + CallFrameBuilder::AttributesMap dict; dict.try_emplace("i32", 42); dict.try_emplace("f32", 42.0f); @@ -977,7 +977,7 @@ TEST(FfiTest, EnumAttr) { } TEST(FfiTest, WrongEnumAttrType) { - CallFrameBuilder::FlatAttributesMap dict; + CallFrameBuilder::AttributesMap dict; dict.try_emplace("i32", 42); CallFrameBuilder::AttributesBuilder attrs; diff --git a/third_party/xla/xla/ffi/attribute_map.cc b/third_party/xla/xla/ffi/attribute_map.cc index 33d756f6fcce2b..79275d11a19b62 100644 --- a/third_party/xla/xla/ffi/attribute_map.cc +++ b/third_party/xla/xla/ffi/attribute_map.cc @@ -16,7 +16,9 @@ limitations under the License. #include "xla/ffi/attribute_map.h" #include +#include #include +#include #include "absl/status/status.h" #include "absl/status/statusor.h" @@ -27,122 +29,123 @@ limitations under the License. #include "mlir/Support/LLVM.h" #include "xla/ffi/call_frame.h" #include "tsl/platform/errors.h" - -using FlatAttribute = xla::ffi::CallFrameBuilder::FlatAttribute; -using FlatAttributesMap = xla::ffi::CallFrameBuilder::FlatAttributesMap; +#include "tsl/platform/statusor.h" namespace xla::ffi { -absl::StatusOr BuildAttributesMap( - mlir::DictionaryAttr dict) { - FlatAttributesMap attributes; - for (auto& kv : dict) { - std::string_view name = kv.getName().strref(); +static absl::StatusOr ConvertBoolAttr( + std::string_view name, mlir::BoolAttr boolean) { + return static_cast(boolean.getValue()); +} - auto boolean = [&](mlir::BoolAttr boolean) { - attributes[name] = static_cast(boolean.getValue()); - return absl::OkStatus(); - }; +static absl::StatusOr ConvertStringAttr( + std::string_view name, mlir::StringAttr str) { + return str.getValue().str(); +} - auto integer = [&](mlir::IntegerAttr integer) { - if (integer.getType().isUnsignedInteger()) { - switch (integer.getType().getIntOrFloatBitWidth()) { - case 8: - attributes[name] = static_cast(integer.getUInt()); - return absl::OkStatus(); - case 16: - attributes[name] = static_cast(integer.getUInt()); - return absl::OkStatus(); - case 32: - attributes[name] = static_cast(integer.getUInt()); - return absl::OkStatus(); - case 64: - attributes[name] = static_cast(integer.getUInt()); - return absl::OkStatus(); - default: - return absl::InvalidArgumentError(absl::StrCat( - "Unsupported integer attribute bit width for attribute: ", - name)); - } - } else { - switch (integer.getType().getIntOrFloatBitWidth()) { - case 8: - attributes[name] = static_cast(integer.getInt()); - return absl::OkStatus(); - case 16: - attributes[name] = static_cast(integer.getInt()); - return absl::OkStatus(); - case 32: - attributes[name] = static_cast(integer.getInt()); - return absl::OkStatus(); - case 64: - attributes[name] = static_cast(integer.getInt()); - return absl::OkStatus(); - default: - return absl::InvalidArgumentError(absl::StrCat( - "Unsupported integer attribute bit width for attribute: ", - name)); - } - } - }; +static absl::StatusOr ConvertIntegerAttr( + std::string_view name, mlir::IntegerAttr integer) { + if (integer.getType().isUnsignedInteger()) { + switch (integer.getType().getIntOrFloatBitWidth()) { + case 8: + return static_cast(integer.getUInt()); + case 16: + return static_cast(integer.getUInt()); + case 32: + return static_cast(integer.getUInt()); + case 64: + return static_cast(integer.getUInt()); + default: + return absl::InvalidArgumentError(absl::StrCat( + "Unsupported integer attribute bit width for attribute: ", name)); + } + } else { + switch (integer.getType().getIntOrFloatBitWidth()) { + case 8: + return static_cast(integer.getInt()); + case 16: + return static_cast(integer.getInt()); + case 32: + return static_cast(integer.getInt()); + case 64: + return static_cast(integer.getInt()); + default: + return absl::InvalidArgumentError(absl::StrCat( + "Unsupported integer attribute bit width for attribute: ", name)); + } + } +} - auto fp = [&](mlir::FloatAttr fp) { - switch (fp.getType().getIntOrFloatBitWidth()) { - case 32: - attributes[name] = static_cast(fp.getValue().convertToFloat()); - return absl::OkStatus(); - case 64: - attributes[name] = - static_cast(fp.getValue().convertToDouble()); - return absl::OkStatus(); - default: - return absl::InvalidArgumentError(absl::StrCat( - "Unsupported float attribute bit width for attribute: ", name)); - } - }; +static absl::StatusOr ConvertFloatAttr( + std::string_view name, mlir::FloatAttr fp) { + switch (fp.getType().getIntOrFloatBitWidth()) { + case 32: + return static_cast(fp.getValue().convertToFloat()); + case 64: + return static_cast(fp.getValue().convertToDouble()); + default: + return absl::InvalidArgumentError(absl::StrCat( + "Unsupported float attribute bit width for attribute: ", name)); + } +} - auto arr = [&](mlir::DenseArrayAttr arr) { - if (auto dense = mlir::dyn_cast(arr)) { - attributes[name] = dense.asArrayRef().vec(); - return absl::OkStatus(); - } else if (auto dense = mlir::dyn_cast(arr)) { - attributes[name] = dense.asArrayRef().vec(); - return absl::OkStatus(); - } else if (auto dense = mlir::dyn_cast(arr)) { - attributes[name] = dense.asArrayRef().vec(); - return absl::OkStatus(); - } else if (auto dense = mlir::dyn_cast(arr)) { - attributes[name] = dense.asArrayRef().vec(); - return absl::OkStatus(); - } else if (auto dense = mlir::dyn_cast(arr)) { - attributes[name] = dense.asArrayRef().vec(); - return absl::OkStatus(); - } else if (auto dense = mlir::dyn_cast(arr)) { - attributes[name] = dense.asArrayRef().vec(); - return absl::OkStatus(); - } else { - return absl::InvalidArgumentError(absl::StrCat( - "Unsupported array element type for attribute: ", name)); - } - }; +static absl::StatusOr ConvertArrayAttr( + std::string_view name, mlir::DenseArrayAttr arr) { + if (auto dense = mlir::dyn_cast(arr)) { + return dense.asArrayRef().vec(); + } else if (auto dense = mlir::dyn_cast(arr)) { + return dense.asArrayRef().vec(); + } else if (auto dense = mlir::dyn_cast(arr)) { + return dense.asArrayRef().vec(); + } else if (auto dense = mlir::dyn_cast(arr)) { + return dense.asArrayRef().vec(); + } else if (auto dense = mlir::dyn_cast(arr)) { + return dense.asArrayRef().vec(); + } else if (auto dense = mlir::dyn_cast(arr)) { + return dense.asArrayRef().vec(); + } else { + return absl::InvalidArgumentError( + absl::StrCat("Unsupported array element type for attribute: ", name)); + } +} - auto str = [&](mlir::StringAttr str) { - attributes[name] = str.getValue().str(); - return absl::OkStatus(); +static absl::StatusOr ConvertDictionaryAttr( + std::string_view name, mlir::DictionaryAttr dict) { + TF_ASSIGN_OR_RETURN(auto attrs, BuildAttributesMap(dict)); + return CallFrameBuilder::Dictionary{ + std::make_shared(std::move(attrs))}; +} + +absl::StatusOr BuildAttributesMap( + mlir::DictionaryAttr dict) { + CallFrameBuilder::AttributesMap attributes; + for (auto& kv : dict) { + std::string_view name = kv.getName().strref(); + mlir::Attribute value = kv.getValue(); + + // Wraps attribute conversion function into callable object. + auto convert_with = [&](auto converter_fn) { + return [&, fn = converter_fn](auto attr) -> absl::Status { + TF_ASSIGN_OR_RETURN(attributes[name], fn(name, attr)); + return absl::OkStatus(); + }; }; TF_RETURN_IF_ERROR( - llvm::TypeSwitch(kv.getValue()) - .Case(boolean) - .Case(integer) - .Case(fp) - .Case(arr) - .Case(str) + llvm::TypeSwitch(value) + .Case(convert_with(ConvertBoolAttr)) + .Case(convert_with(ConvertIntegerAttr)) + .Case(convert_with(ConvertFloatAttr)) + .Case(convert_with(ConvertArrayAttr)) + .Case(convert_with(ConvertStringAttr)) + .Case(convert_with(ConvertDictionaryAttr)) .Default([&](mlir::Attribute) { return absl::InvalidArgumentError(absl::StrCat( "Unsupported attribute type for attribute: ", name)); })); } + return attributes; } + } // namespace xla::ffi diff --git a/third_party/xla/xla/ffi/attribute_map.h b/third_party/xla/xla/ffi/attribute_map.h index cb9415ff3eb09b..43ad41888772cb 100644 --- a/third_party/xla/xla/ffi/attribute_map.h +++ b/third_party/xla/xla/ffi/attribute_map.h @@ -24,7 +24,7 @@ namespace xla::ffi { // Converts MLIR dictionary attribute attached to a custom call operation to a // custom call handler attributes that are forwarded to the FFI handler. -absl::StatusOr BuildAttributesMap( +absl::StatusOr BuildAttributesMap( mlir::DictionaryAttr dict); } // namespace xla::ffi diff --git a/third_party/xla/xla/ffi/call_frame.cc b/third_party/xla/xla/ffi/call_frame.cc index 655aa6a02f69a2..12fed1ba745440 100644 --- a/third_party/xla/xla/ffi/call_frame.cc +++ b/third_party/xla/xla/ffi/call_frame.cc @@ -65,20 +65,17 @@ CallFrameBuilder::AttributesBuilder::AttributesBuilder() = default; CallFrameBuilder::AttributesBuilder::~AttributesBuilder() = default; void CallFrameBuilder::AttributesBuilder::Insert(std::string name, - FlatAttribute attr) { - attrs_.try_emplace(std::move(name), FromFlatAttribute(std::move(attr))); + Attribute attr) { + attrs_.try_emplace(std::move(name), std::move(attr)); } void CallFrameBuilder::AttributesBuilder::Insert(std::string name, - FlatAttributesMap attrs) { - AttributesBuilder builder; - for (auto& [name, attr] : attrs) builder.Insert(name, std::move(attr)); - - auto attrs_map = std::make_unique(builder.Build()); - attrs_.try_emplace(std::move(name), Dictionary{std::move(attrs_map)}); + AttributesMap attrs) { + attrs_.try_emplace(std::move(name), + Dictionary{std::make_shared(attrs)}); } -void CallFrameBuilder::AttributesBuilder::Append(FlatAttributesMap attrs) { +void CallFrameBuilder::AttributesBuilder::Append(AttributesMap attrs) { for (auto& [name, attr] : attrs) Insert(name, std::move(attr)); } diff --git a/third_party/xla/xla/ffi/call_frame.h b/third_party/xla/xla/ffi/call_frame.h index 526723b3a92d80..0614bd750fd29e 100644 --- a/third_party/xla/xla/ffi/call_frame.h +++ b/third_party/xla/xla/ffi/call_frame.h @@ -81,9 +81,10 @@ class CallFrameBuilder { using AttributesMap = absl::flat_hash_map; // Dictionary is just a wrapper around AttributesMap. We need an indirection - // through `std::unique_ptr` to be able to define recursive `std::variant`. + // through `std::shared_ptr` to be able to define recursive `std::variant`. We + // use shared pointer to keep `AttributesMap` copyable. struct Dictionary { - std::unique_ptr attrs; + std::shared_ptr attrs; }; // A helper class to build call frame attributes. @@ -92,14 +93,14 @@ class CallFrameBuilder { AttributesBuilder(); ~AttributesBuilder(); + void Insert(std::string name, Attribute attr); + void Insert(std::string name, AttributesMap attrs); + void Append(AttributesMap attrs); + // This overload is only necessary to support older GCC versions. void Insert(std::string name, const char* attr) { - Insert(std::move(name), std::string(attr)); + Insert(std::move(name), Attribute{std::string(attr)}); } - void Insert(std::string name, FlatAttribute attr); - void Insert(std::string name, FlatAttributesMap attrs); - - void Append(FlatAttributesMap attrs); AttributesMap Build(); diff --git a/third_party/xla/xla/ffi/call_frame_test.cc b/third_party/xla/xla/ffi/call_frame_test.cc index 7b767bfb841af8..89d306455e6a19 100644 --- a/third_party/xla/xla/ffi/call_frame_test.cc +++ b/third_party/xla/xla/ffi/call_frame_test.cc @@ -130,14 +130,14 @@ void BM_AddBufferArg(benchmark::State& state) { void BM_AddAttributes(benchmark::State& state) { size_t num_attrs = state.range(0); - CallFrameBuilder::FlatAttributesMap flat_attrs; + CallFrameBuilder::AttributesMap attrs; for (size_t i = 0; i < num_attrs; ++i) { - flat_attrs.try_emplace(absl::StrCat("attr_", i), 42); + attrs.try_emplace(absl::StrCat("attr_", i), 42); } for (auto _ : state) { CallFrameBuilder::AttributesBuilder attrs_builder; - attrs_builder.Append(flat_attrs); + attrs_builder.Append(attrs); CallFrameBuilder builder(/*num_args=*/0, /*num_rets=*/0); builder.AddAttributes(attrs_builder.Build()); diff --git a/third_party/xla/xla/ffi/ffi_test.cc b/third_party/xla/xla/ffi/ffi_test.cc index 128bf997ae263f..f7a310c5b8e61d 100644 --- a/third_party/xla/xla/ffi/ffi_test.cc +++ b/third_party/xla/xla/ffi/ffi_test.cc @@ -379,10 +379,10 @@ TEST(FfiTest, AttrsAsDictionary) { } TEST(FfiTest, DictionaryAttr) { - CallFrameBuilder::FlatAttributesMap dict0; + CallFrameBuilder::AttributesMap dict0; dict0.try_emplace("i32", 42); - CallFrameBuilder::FlatAttributesMap dict1; + CallFrameBuilder::AttributesMap dict1; dict1.try_emplace("f32", 42.0f); CallFrameBuilder::AttributesBuilder attrs; @@ -421,7 +421,7 @@ TEST(FfiTest, DictionaryAttr) { } TEST(FfiTest, StructAttr) { - CallFrameBuilder::FlatAttributesMap dict; + CallFrameBuilder::AttributesMap dict; dict.try_emplace("i32", 42); dict.try_emplace("f32", 42.0f); diff --git a/third_party/xla/xla/service/cpu/runtime_handle_ffi_call.cc b/third_party/xla/xla/service/cpu/runtime_handle_ffi_call.cc index 6b07b41aad00cc..874d9b3fe1b508 100644 --- a/third_party/xla/xla/service/cpu/runtime_handle_ffi_call.cc +++ b/third_party/xla/xla/service/cpu/runtime_handle_ffi_call.cc @@ -114,7 +114,7 @@ static absl::Status BuildAndCallFfi( } // For FFI handlers backend config must be a compatible MLIR dictionary. - ffi::CallFrameBuilder::FlatAttributesMap attributes; + ffi::CallFrameBuilder::AttributesMap attributes; if (!backend_config.empty() && backend_config != "{}") { // Backend config not empty, so proceed to parse it into an MLIR attribute // and build an MLIR compatible map of attributes out of it. diff --git a/third_party/xla/xla/service/gpu/custom_call_test.cc b/third_party/xla/xla/service/gpu/custom_call_test.cc index f377c0e0ae3dd8..5877ab9e03f4ce 100644 --- a/third_party/xla/xla/service/gpu/custom_call_test.cc +++ b/third_party/xla/xla/service/gpu/custom_call_test.cc @@ -78,6 +78,20 @@ limitations under the License. #define gpuMemcpyHostToDevice hipMemcpyHostToDevice #endif +namespace xla { + +struct Range { + int64_t lo; + int64_t hi; +}; + +} // namespace xla + +// Register struct types with XLA:FFI to enable automatic decoding from +// dictionary attributes to structs. +XLA_FFI_REGISTER_STRUCT_ATTR_DECODING(::xla::Range, StructMember("lo"), + StructMember("hi")); + namespace xla { namespace { @@ -614,20 +628,26 @@ TEST_F(CustomCallTest, ExportedFfiWithStatusSucceeded) { //===----------------------------------------------------------------------===// static absl::Status FfiAttributes(ffi::Result, - absl::Span i32_arr) { + absl::Span i32_arr, + Range range) { if (i32_arr.size() != 4) return absl::InternalError("i32_arr size does not match"); if (i32_arr[0] != 1 || i32_arr[1] != 2 || i32_arr[2] != 3 || i32_arr[3] != 4) return absl::InternalError("i32_arr values do not match"); + if (range.lo != 0 || range.hi != 42) { + return absl::InternalError("range values do not match"); + } + return absl::OkStatus(); } XLA_FFI_DEFINE_HANDLER(kFfiAttributes, FfiAttributes, ffi::Ffi::Bind() .Ret() - .Attr>("i32_arr")); + .Attr>("i32_arr") + .Attr("range")); XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "xla.gpu.ffi_attributes", PLATFORM, kFfiAttributes); @@ -636,7 +656,9 @@ TEST_F(CustomCallTest, FfiAttributes) { XlaBuilder b(TestName()); CustomCall(&b, "xla.gpu.ffi_attributes", /*operands=*/{}, ShapeUtil::MakeShape(F32, {}), - /*opaque=*/"{ i32_arr = array }", + /*opaque=*/ + "{ i32_arr = array," + " range = { lo = 0 : i64, hi = 42 : i64 } }", /*has_side_effect=*/false, /*output_operand_aliasing=*/{}, /*literal=*/nullptr, /*schedule=*/CustomCallSchedule::SCHEDULE_NONE, diff --git a/third_party/xla/xla/service/gpu/runtime/custom_call_thunk.h b/third_party/xla/xla/service/gpu/runtime/custom_call_thunk.h index e67b9e89d3a867..925d2c04abc27b 100644 --- a/third_party/xla/xla/service/gpu/runtime/custom_call_thunk.h +++ b/third_party/xla/xla/service/gpu/runtime/custom_call_thunk.h @@ -77,8 +77,8 @@ class CustomCallThunk : public Thunk { Shape shape; }; - using Attribute = ffi::CallFrameBuilder::FlatAttribute; - using AttributesMap = ffi::CallFrameBuilder::FlatAttributesMap; + using Attribute = ffi::CallFrameBuilder::Attribute; + using AttributesMap = ffi::CallFrameBuilder::AttributesMap; static absl::StatusOr> Create( ThunkInfo thunk_info, CustomCallTarget call_target, From 606ebefc9fde643ca8f7a78a175749f6fec31ee4 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 27 Sep 2024 12:27:53 -0700 Subject: [PATCH 380/483] Rename CallSolver --> CreateAutoShardingSolverRequestAndCallSolver and CallORToolsSolver --> FormulateAndSolveMIPFromAutoShardingSolverRequest to better capture the function implementation. PiperOrigin-RevId: 679688096 --- .../auto_sharding/auto_sharding.cc | 4 +- .../auto_sharding/auto_sharding_impl.cc | 14 +-- .../auto_sharding/auto_sharding_solver.cc | 2 +- .../auto_sharding/auto_sharding_solver.h | 2 +- .../auto_sharding_solver_test.cc | 88 ++++++++++++------- .../auto_sharding/auto_sharding_wrapper.h | 2 +- 6 files changed, 66 insertions(+), 46 deletions(-) diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc index 06141018c5b8a3..d69764435604cc 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc @@ -1747,7 +1747,7 @@ std::unique_ptr CreateReshapeStrategies( return strategy_group; } -AutoShardingSolverResult CallSolver( +AutoShardingSolverResult CreateAutoShardingSolverRequestAndCallSolver( const HloModule& hlo_module, const HloLiveRange& hlo_live_range, const StrategyMap& strategy_map, const StrategyGroups& strategy_groups, const CostGraph& cost_graph, const AliasSet& alias_set, @@ -1969,7 +1969,7 @@ AutoShardingSolverResult CallSolver( PopulateTemporalValues(cost_graph, request); - return CallORToolsSolver(request); + return FormulateAndSolveMIPFromSolverRequest(request); } void CheckHloSharding( diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_impl.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_impl.cc index e0fdd6ad71bbf8..7a92ac5715039a 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_impl.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_impl.cc @@ -48,13 +48,13 @@ AutoShardingSolverResult Solve( const AutoShardingOption& option, absl::string_view request_prefix, const absl::flat_hash_map& sharding_propagation_solution) { - return CallSolver(hlo_module, hlo_live_range, strategy_map, strategy_groups, - cost_graph, alias_set, node_intervals, edge_intervals, - node_groups, edge_groups, /*s_hint*/ {}, - /*compute_iis*/ true, option.solver_timeout_in_seconds, - option, /*max_cost*/ std::nullopt, request_prefix, - sharding_propagation_solution, - /*deterministic mode*/ true); + return CreateAutoShardingSolverRequestAndCallSolver( + hlo_module, hlo_live_range, strategy_map, strategy_groups, cost_graph, + alias_set, node_intervals, edge_intervals, node_groups, edge_groups, + /*s_hint*/ {}, + /*compute_iis*/ true, option.solver_timeout_in_seconds, option, + /*max_cost*/ std::nullopt, request_prefix, sharding_propagation_solution, + /*deterministic mode*/ true); } void PopulateTemporalValues(const CostGraph& cost_graph, diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc index 8fca1bc7b81ab7..cf18e7a5c56c6e 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc @@ -399,7 +399,7 @@ void AddMemoryTerms( // can be a few (usually < 10) edges in the problem with negative costs. This // is guaranteed to never produce a negative overall cost for the graph, // however. -AutoShardingSolverResult CallORToolsSolver( +AutoShardingSolverResult FormulateAndSolveMIPFromSolverRequest( const AutoShardingSolverRequest& unscaled_request) { const absl::Time start_time = absl::Now(); const AutoShardingSolverRequest& request = ScaleRequest(unscaled_request); diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.h b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.h index cb051f7718fd44..88884f7286d0b6 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.h +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.h @@ -47,7 +47,7 @@ struct AutoShardingSolverResult { bool skip_auto_sharding; }; -AutoShardingSolverResult CallORToolsSolver( +AutoShardingSolverResult FormulateAndSolveMIPFromSolverRequest( const AutoShardingSolverRequest& request); enum AutoShardingViolationCode { diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver_test.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver_test.cc index 81c02acd354bd5..3e0c82d3b75510 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver_test.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver_test.cc @@ -250,10 +250,11 @@ AutoShardingSolverRequest AutoShardingSolverRequestWithEquivalences() { return request; } -TEST(CallORToolsSolverTest, SolvesOptimally) { +TEST(FormulateAndSolveMIPFromSolverRequestTest, SolvesOptimally) { const AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest(); - const AutoShardingSolverResult result = CallORToolsSolver(request); + const AutoShardingSolverResult result = + FormulateAndSolveMIPFromSolverRequest(request); const std::vector s_val = {0, 0, 0, 0, 0}; const double objective_value = 7650.0; @@ -262,12 +263,13 @@ TEST(CallORToolsSolverTest, SolvesOptimally) { EXPECT_EQ(result, expected_result); } -TEST(CallORToolsSolverTest, SolvesOverbudget) { +TEST(FormulateAndSolveMIPFromSolverRequestTest, SolvesOverbudget) { AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest(); request.set_memory_budget(100000); request.mutable_overbudget_coeff()->set_coeff(10.0); - const AutoShardingSolverResult result = CallORToolsSolver(request); + const AutoShardingSolverResult result = + FormulateAndSolveMIPFromSolverRequest(request); const std::vector s_val = {0, 0, 0, 0, 0}; const double objective_value = 9007650.0; @@ -276,11 +278,12 @@ TEST(CallORToolsSolverTest, SolvesOverbudget) { EXPECT_EQ(result, expected_result); } -TEST(CallORToolsSolverTest, SolvesMaxDepartures) { +TEST(FormulateAndSolveMIPFromSolverRequestTest, SolvesMaxDepartures) { AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest(); request.mutable_max_departures()->set_coeff(3.0); - const AutoShardingSolverResult result = CallORToolsSolver(request); + const AutoShardingSolverResult result = + FormulateAndSolveMIPFromSolverRequest(request); const std::vector s_val = {0, 0, 1, 1, 0}; const double objective_value = 7872.0; @@ -289,11 +292,12 @@ TEST(CallORToolsSolverTest, SolvesMaxDepartures) { EXPECT_EQ(result, expected_result); } -TEST(CallORToolsSolverTest, MinimizesDepartures) { +TEST(FormulateAndSolveMIPFromSolverRequestTest, MinimizesDepartures) { AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest(); request.set_minimize_departures(true); - const AutoShardingSolverResult result = CallORToolsSolver(request); + const AutoShardingSolverResult result = + FormulateAndSolveMIPFromSolverRequest(request); const std::vector s_val = {0, 1, 0, 0, 1}; const double objective_value = 3.0; @@ -302,13 +306,14 @@ TEST(CallORToolsSolverTest, MinimizesDepartures) { EXPECT_EQ(result, expected_result); } -TEST(CallORToolsSolverTest, AvoidsInfiniteNodeCosts) { +TEST(FormulateAndSolveMIPFromSolverRequestTest, AvoidsInfiniteNodeCosts) { AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest(); request.mutable_computation_costs(0)->set_costs(0, kInfinityCost); request.mutable_computation_costs(0)->set_costs(1, kInfinityCost); request.mutable_computation_costs(0)->set_costs(2, kInfinityCost); - const AutoShardingSolverResult result = CallORToolsSolver(request); + const AutoShardingSolverResult result = + FormulateAndSolveMIPFromSolverRequest(request); const std::vector s_val = {3, 0, 0, 0, 0}; const double objective_value = 10683.0; @@ -317,11 +322,12 @@ TEST(CallORToolsSolverTest, AvoidsInfiniteNodeCosts) { EXPECT_EQ(result, expected_result); } -TEST(CallORToolsSolverTest, AvoidsInfiniteEdgeCosts) { +TEST(FormulateAndSolveMIPFromSolverRequestTest, AvoidsInfiniteEdgeCosts) { AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest(); request.mutable_resharding_costs(0)->set_costs(0, kInfinityCost); - const AutoShardingSolverResult result = CallORToolsSolver(request); + const AutoShardingSolverResult result = + FormulateAndSolveMIPFromSolverRequest(request); const std::vector s_val = {0, 0, 1, 1, 0}; const double objective_value = 7872.0; @@ -330,7 +336,7 @@ TEST(CallORToolsSolverTest, AvoidsInfiniteEdgeCosts) { EXPECT_EQ(result, expected_result); } -TEST(CallORToolsSolverTest, HandlesFollowedEdges) { +TEST(FormulateAndSolveMIPFromSolverRequestTest, HandlesFollowedEdges) { AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest(); AutoShardingSolverRequest_Pair edge; edge.set_first(1); @@ -346,7 +352,8 @@ TEST(CallORToolsSolverTest, HandlesFollowedEdges) { 70000, 71000, 72000, 73000}}; AddCosts(request.mutable_duration_costs(), t); - const AutoShardingSolverResult result = CallORToolsSolver(request); + const AutoShardingSolverResult result = + FormulateAndSolveMIPFromSolverRequest(request); const std::vector s_val = {0, 0, 0, 0, 0}; const double objective_value = 12650.0; @@ -355,7 +362,7 @@ TEST(CallORToolsSolverTest, HandlesFollowedEdges) { EXPECT_EQ(result, expected_result); } -TEST(CallORToolsSolverTest, HandlesCollapsedEdge) { +TEST(FormulateAndSolveMIPFromSolverRequestTest, HandlesCollapsedEdge) { AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest(); AutoShardingSolverRequest_Pair edge; edge.set_first(2); @@ -373,7 +380,8 @@ TEST(CallORToolsSolverTest, HandlesCollapsedEdge) { 80000, 81000, 82000, 83000}}; AddCosts(request.mutable_duration_costs(), t); - const AutoShardingSolverResult result = CallORToolsSolver(request); + const AutoShardingSolverResult result = + FormulateAndSolveMIPFromSolverRequest(request); const std::vector s_val = {0, 0, 1, 1, 0}; const double objective_value = 13972.0; @@ -382,12 +390,13 @@ TEST(CallORToolsSolverTest, HandlesCollapsedEdge) { EXPECT_EQ(result, expected_result); } -TEST(CallORToolsSolverTest, UsesHint) { +TEST(FormulateAndSolveMIPFromSolverRequestTest, UsesHint) { AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest(); const auto s_hint = {1, 0, 0, 0, 0}; // Not optimal, but close. request.mutable_s_hint()->Add(s_hint.begin(), s_hint.end()); - const AutoShardingSolverResult result = CallORToolsSolver(request); + const AutoShardingSolverResult result = + FormulateAndSolveMIPFromSolverRequest(request); const std::vector s_val = {0, 0, 0, 0, 0}; const double objective_value = 7650.0; @@ -396,20 +405,22 @@ TEST(CallORToolsSolverTest, UsesHint) { EXPECT_EQ(result, expected_result); } -TEST(CallORToolsSolverTest, HonorsMaxCost) { +TEST(FormulateAndSolveMIPFromSolverRequestTest, HonorsMaxCost) { AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest(); request.mutable_max_cost()->set_coeff(7600.0); // Best possible is 7650.0 - const AutoShardingSolverResult result = CallORToolsSolver(request); + const AutoShardingSolverResult result = + FormulateAndSolveMIPFromSolverRequest(request); EXPECT_TRUE(absl::IsInternal(result.status.status())); } -TEST(CallORToolsSolverTest, HandlesExtremelyHighMaxCost) { +TEST(FormulateAndSolveMIPFromSolverRequestTest, HandlesExtremelyHighMaxCost) { AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest(); request.mutable_max_cost()->set_coeff(1e19); - const AutoShardingSolverResult result = CallORToolsSolver(request); + const AutoShardingSolverResult result = + FormulateAndSolveMIPFromSolverRequest(request); const std::vector s_val = {0, 0, 0, 0, 0}; const double objective_value = 7650.0; @@ -418,7 +429,7 @@ TEST(CallORToolsSolverTest, HandlesExtremelyHighMaxCost) { EXPECT_EQ(result, expected_result); } -TEST(CallORToolsSolverTest, HandlesMemoryEdgeCosts) { +TEST(FormulateAndSolveMIPFromSolverRequestTest, HandlesMemoryEdgeCosts) { AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest(); const EdgeMatrix live_edges = {{}, {0}, {0, 1}, {1}, {}}; const CostMatrix memory_edge_costs = {{1000000, 1100, 1200, 1300, @@ -432,7 +443,8 @@ TEST(CallORToolsSolverTest, HandlesMemoryEdgeCosts) { AddCosts(request.mutable_memory_edge_costs(), memory_edge_costs); request.set_enable_memory_edge_costs(true); - const AutoShardingSolverResult result = CallORToolsSolver(request); + const AutoShardingSolverResult result = + FormulateAndSolveMIPFromSolverRequest(request); const std::vector s_val = {0, 0, 1, 1, 0}; const double objective_value = 7872.0; @@ -441,7 +453,7 @@ TEST(CallORToolsSolverTest, HandlesMemoryEdgeCosts) { EXPECT_EQ(result, expected_result); } -TEST(CallORToolsSolverTest, HandlesIntervals) { +TEST(FormulateAndSolveMIPFromSolverRequestTest, HandlesIntervals) { AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest(); const std::vector> node_intervals = {{0, 4}, {0, 4}, {2, 3}, {3, 4}, {100, -1}}; @@ -460,7 +472,8 @@ TEST(CallORToolsSolverTest, HandlesIntervals) { AddCosts(request.mutable_memory_edge_costs(), memory_edge_costs); request.set_enable_memory_edge_costs(true); - const AutoShardingSolverResult result = CallORToolsSolver(request); + const AutoShardingSolverResult result = + FormulateAndSolveMIPFromSolverRequest(request); const std::vector s_val = {0, 0, 1, 1, 0}; const double objective_value = 7872.0; @@ -469,7 +482,8 @@ TEST(CallORToolsSolverTest, HandlesIntervals) { EXPECT_EQ(result, expected_result); } -TEST(CallORToolsSolverTest, HandlesReducedIntervalsAndGroups) { +TEST(FormulateAndSolveMIPFromSolverRequestTest, + HandlesReducedIntervalsAndGroups) { AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest(); const std::vector> node_intervals = {{5, -1}, {5, -1}, {2, 3}, {3, 4}, {100, -1}, {0, 4}}; @@ -492,7 +506,8 @@ TEST(CallORToolsSolverTest, HandlesReducedIntervalsAndGroups) { AddCosts(request.mutable_memory_edge_costs(), memory_edge_costs); request.set_enable_memory_edge_costs(true); - const AutoShardingSolverResult result = CallORToolsSolver(request); + const AutoShardingSolverResult result = + FormulateAndSolveMIPFromSolverRequest(request); const std::vector s_val = {0, 0, 1, 1, 0}; const double objective_value = 7872.0; @@ -501,7 +516,8 @@ TEST(CallORToolsSolverTest, HandlesReducedIntervalsAndGroups) { EXPECT_EQ(result, expected_result); } -TEST(CallORToolsSolverTest, HandlesReducedIntervalsAndGroupsNoMemoryEdgeCosts) { +TEST(FormulateAndSolveMIPFromSolverRequestTest, + HandlesReducedIntervalsAndGroupsNoMemoryEdgeCosts) { AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest(); const std::vector> node_intervals = {{5, -1}, {5, -1}, {2, 3}, {3, 4}, {100, -1}, {0, 4}}; @@ -511,7 +527,8 @@ TEST(CallORToolsSolverTest, HandlesReducedIntervalsAndGroupsNoMemoryEdgeCosts) { AddGroups(request.mutable_node_groups(), node_groups); request.set_enable_memory_edge_costs(false); - const AutoShardingSolverResult result = CallORToolsSolver(request); + const AutoShardingSolverResult result = + FormulateAndSolveMIPFromSolverRequest(request); const std::vector s_val = {0, 0, 0, 0, 0}; const double objective_value = 7650.0; @@ -520,7 +537,8 @@ TEST(CallORToolsSolverTest, HandlesReducedIntervalsAndGroupsNoMemoryEdgeCosts) { EXPECT_EQ(result, expected_result); } -TEST(CallORToolsSolverTest, HandlesGroupsWithTinyMemoryCosts) { +TEST(FormulateAndSolveMIPFromSolverRequestTest, + HandlesGroupsWithTinyMemoryCosts) { AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest(); const std::vector> node_intervals = {{5, -1}, {5, -1}, {2, 3}, {3, 4}, {100, -1}, {0, 4}}; @@ -551,7 +569,8 @@ TEST(CallORToolsSolverTest, HandlesGroupsWithTinyMemoryCosts) { request.set_enable_memory_edge_costs(true); request.set_memory_budget(4321); - const AutoShardingSolverResult result = CallORToolsSolver(request); + const AutoShardingSolverResult result = + FormulateAndSolveMIPFromSolverRequest(request); const std::vector s_val = {0, 0, 0, 0, 0}; const double objective_value = 7650.0; @@ -560,11 +579,12 @@ TEST(CallORToolsSolverTest, HandlesGroupsWithTinyMemoryCosts) { EXPECT_EQ(result, expected_result); } -TEST(CallORToolsSolverTest, SolvesWithEquivalences) { +TEST(FormulateAndSolveMIPFromSolverRequestTest, SolvesWithEquivalences) { const AutoShardingSolverRequest request = AutoShardingSolverRequestWithEquivalences(); - const AutoShardingSolverResult result = CallORToolsSolver(request); + const AutoShardingSolverResult result = + FormulateAndSolveMIPFromSolverRequest(request); const std::vector s_val = {0, 0, 5, 5, 1}; const double objective_value = 7650.0; diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_wrapper.h b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_wrapper.h index 069fde4e14c580..f9058802eea52d 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_wrapper.h +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_wrapper.h @@ -41,7 +41,7 @@ namespace spmd { // A wrapper around the solver that converts the given objects into a // combinatorial optimization problem & solves it. -AutoShardingSolverResult CallSolver( +AutoShardingSolverResult CreateAutoShardingSolverRequestAndCallSolver( const HloModule& hlo_module, const HloLiveRange& hlo_live_range, const StrategyMap& strategy_map, const StrategyGroups& strategy_groups, const CostGraph& cost_graph, const AliasSet& alias_set, From 5705625cf84fee2d6b6fdeb150265ce26aa2f0dd Mon Sep 17 00:00:00 2001 From: Alexander Belyaev Date: Fri, 27 Sep 2024 12:43:00 -0700 Subject: [PATCH 381/483] [XLA:GPU][IndexAnalysis] Remove is_simplified flag. The benchmarks don't show a lot of improvements in compile time. PiperOrigin-RevId: 679693188 --- .../service/gpu/fusions/ir/tests/attrs.mlir | 30 +- .../gpu/fusions/ir/tests/canonicalize.mlir | 60 +-- .../service/gpu/fusions/ir/tests/invalid.mlir | 74 +-- .../xla/service/gpu/fusions/ir/tests/ops.mlir | 20 +- .../service/gpu/fusions/ir/xla_gpu_attrs.cc | 14 +- .../service/gpu/fusions/ir/xla_gpu_attrs.td | 5 +- .../xla/service/gpu/fusions/ir/xla_gpu_ops.cc | 15 +- .../gpu/fusions/legacy/concatenate_test.cc | 3 +- .../in_place_dynamic_update_slice_test.cc | 3 +- .../gpu/fusions/legacy/input_slices_test.cc | 3 +- .../service/gpu/fusions/legacy/loop_test.cc | 15 +- .../gpu/fusions/legacy/reduction_test.cc | 6 +- .../gpu/fusions/legacy/scatter_test.cc | 6 +- .../gpu/fusions/legacy/transpose_test.cc | 24 +- .../mlir/elemental_hlo_to_mlir_test.cc | 60 +-- .../transforms/tests/flatten_tensors.mlir | 10 +- .../fusions/transforms/tests/fuse_loops.mlir | 36 +- .../tests/lower_xla_gpu_loops_to_scf.mlir | 6 +- .../tests/lower_xla_gpu_to_scf.mlir | 22 +- .../transforms/tests/optimize_loops.mlir | 7 +- .../fusions/transforms/tests/peel_loops.mlir | 14 +- .../transforms/tests/simplify_affine.mlir | 8 +- .../transforms/tests/simplify_arith.mlir | 9 +- .../tests/vectorize_loads_stores.mlir | 37 +- .../triton_fusion_emitter_device_test.cc | 14 +- .../service/gpu/model/indexing_analysis.cc | 8 +- .../gpu/model/indexing_analysis_test.cc | 489 ++++++------------ .../xla/xla/service/gpu/model/indexing_map.cc | 23 +- .../xla/xla/service/gpu/model/indexing_map.h | 12 +- .../gpu/model/indexing_map_serialization.cc | 53 +- .../model/indexing_map_serialization_test.cc | 21 +- .../service/gpu/model/indexing_map_test.cc | 404 +++++---------- .../gpu/model/symbolic_tile_analysis_test.cc | 45 +- 33 files changed, 570 insertions(+), 986 deletions(-) diff --git a/third_party/xla/xla/service/gpu/fusions/ir/tests/attrs.mlir b/third_party/xla/xla/service/gpu/fusions/ir/tests/attrs.mlir index b990103ea2cfab..6a199f5f024241 100644 --- a/third_party/xla/xla/service/gpu/fusions/ir/tests/attrs.mlir +++ b/third_party/xla/xla/service/gpu/fusions/ir/tests/attrs.mlir @@ -8,8 +8,7 @@ // CHECK-SAME: d2 in [10, 12], // CHECK-SAME: s0 in [0, 32], // CHECK-SAME: d0 + s0 in [1, 10], -// CHECK-SAME: d0 mod 2 in [0, 1], -// CHECK-SAME: is_simplified: true" +// CHECK-SAME: d0 mod 2 in [0, 1] // CHECK-SAME: > #map = #xla_gpu.indexing_map<"(d0, d1, d2)[s0] -> (d0)," "domain:" @@ -18,8 +17,7 @@ "d2 in [10, 12]," "s0 in [0, 32]," "d0 mod 2 in [0, 1]," - "d0 + s0 in [1, 10]," - "is_simplified: true" + "d0 + s0 in [1, 10]" > func.func private @indexing_map_attr(!xla_gpu.indexed_vector<64x64x32xf64, #map>) @@ -39,7 +37,6 @@ func.func private @indexing_map_attr(!xla_gpu.indexed_vector<64x64x32xf64, #map> // CHECK-SAME: d0 + s0 in [1, 10] // CHECK-SAME: d0 mod 2 in [0, 1] // CHECK-SAME: d1 + s1 + s2 in [1, 32] -// CHECK-SAME: is_simplified: false" // CHECK-SAME: > #map = #xla_gpu.indexing_map< "(d0, d1)[s0, s1, s2] -> (d0 + s0, d1 + s1, d1 + s2)," @@ -51,8 +48,7 @@ func.func private @indexing_map_attr(!xla_gpu.indexed_vector<64x64x32xf64, #map> "s2 in [0, 32]," "d0 mod 2 in [0, 1]," "d0 + s0 in [1, 10]," - "d1 + s1 + s2 in [1, 32]," - "is_simplified: false" + "d1 + s1 + s2 in [1, 32]" > func.func private @more_range_vars(!xla_gpu.indexed_vector<100x32xf64, #map>) // CHECK-LABEL: @more_range_vars @@ -65,13 +61,11 @@ func.func private @more_range_vars(!xla_gpu.indexed_vector<100x32xf64, #map>) // CHECK-SAME: domain: // CHECK-SAME: d0 in [0, 100] // CHECK-SAME: s0 in [-3, -1] -// CHECK-SAME: is_simplified: false" // CHECK-SAME: > #map = #xla_gpu.indexing_map<"(d0)[s0] -> (d0)," "domain:" "d0 in [0, 100]," - "s0 in [-3, -1]," - "is_simplified: false" + "s0 in [-3, -1]" > func.func private @indexing_map_small(!xla_gpu.indexed_vector<100xf64, #map>) // CHECK-LABEL: @indexing_map_small @@ -86,15 +80,13 @@ func.func private @indexing_map_small(!xla_gpu.indexed_vector<100xf64, #map>) // CHECK-SAME: d1 in [5, 8] // CHECK-SAME: d2 in [10, 12] // CHECK-SAME: s0 in [0, 32] -// CHECK-SAME: is_simplified: false" // CHECK-SAME: > #map = #xla_gpu.indexing_map<"(d0, d1, d2)[s0] -> (d0)," "domain:" "d0 in [1, 2]," "d1 in [5, 8]," "d2 in [10, 12]," - "s0 in [0, 32]," - "is_simplified: false" + "s0 in [0, 32]" > func.func private @no_constraints(!xla_gpu.indexed_vector<32xf64, #map>) // CHECK-LABEL: @no_constraints @@ -107,13 +99,11 @@ func.func private @no_constraints(!xla_gpu.indexed_vector<32xf64, #map>) // CHECK-SAME: domain: // CHECK-SAME: s0 in [3, 5] // CHECK-SAME: s0 mod 2 in [0, 1] -// CHECK-SAME: is_simplified: false" // CHECK-SAME: > #map = #xla_gpu.indexing_map<"()[s0] -> (s0)," "domain:" "s0 in [3, 5]," - "s0 mod 2 in [0, 1]," - "is_simplified: false" + "s0 mod 2 in [0, 1]" > func.func private @no_dimensions(!xla_gpu.indexed_vector<100xf64, #map>) // CHECK-LABEL: @no_dimensions @@ -126,13 +116,11 @@ func.func private @no_dimensions(!xla_gpu.indexed_vector<100xf64, #map>) // CHECK-SAME: domain: // CHECK-SAME: d0 in [3, 5] // CHECK-SAME: d0 mod 2 in [0, 1] -// CHECK-SAME: is_simplified: false" // CHECK-SAME: > #map = #xla_gpu.indexing_map<"(d0) -> (d0)," "domain:" "d0 in [3, 5]," "d0 mod 2 in [0, 1]," - "is_simplified: false" > func.func private @no_symbols(!xla_gpu.indexed_vector<100xf64, #map>) // CHECK-LABEL: @no_symbols @@ -152,8 +140,6 @@ func.func private @empty(!xla_gpu.indexed_vector<100xf64, #map>) func.func private @tensor_layout( %in0: tensor<42xf32, #xla_gpu.layout<"shmem", - "(d0) -> ()," - "domain: d0 in [0, 42], is_simplified: true">>) -// CHECK: #layout = #xla_gpu.layout<"shmem", "(d0) -> (), -// CHECK-SAME: domain: d0 in [0, 42], is_simplified: true"> + "(d0) -> ()," "domain: d0 in [0, 42]">>) +// CHECK: #layout = #xla_gpu.layout<"shmem", "(d0) -> (), domain: // CHECK: tensor<42xf32, #layout> diff --git a/third_party/xla/xla/service/gpu/fusions/ir/tests/canonicalize.mlir b/third_party/xla/xla/service/gpu/fusions/ir/tests/canonicalize.mlir index bfca90e5c64f53..08086e34f60b05 100644 --- a/third_party/xla/xla/service/gpu/fusions/ir/tests/canonicalize.mlir +++ b/third_party/xla/xla/service/gpu/fusions/ir/tests/canonicalize.mlir @@ -1,13 +1,12 @@ // RUN: mlir_fusions_opt %s --split-input-file -canonicalize | FileCheck %s -#map0 = #xla_gpu.indexing_map<"()[s0, s1] -> (1 + s0 + s1 mod 3 - s1, s0 mod 2), domain: s0 in [-10, 10], s1 in [0, 2], is_simplified: false"> +#map0 = #xla_gpu.indexing_map<"()[s0, s1] -> (1 + s0 + s1 mod 3 - s1, s0 mod 2), domain: s0 in [-10, 10], s1 in [0, 2]"> func.func @simplify_apply_indexing(%s0: index, %s1: index) -> (index, index) { %0:2 = xla_gpu.apply_indexing #map0 [%s0, %s1] func.return %0#0, %0#1 : index, index } // CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<"(d0) -> (d0 + 1, d0 mod 2), -// CHECK-SAME: domain: d0 in [-10, 10] -// CHECK-SAME: is_simplified: true"> +// CHECK-SAME: domain: d0 in [-10, 10]"> // CHECK-LABEL: func.func @simplify_apply_indexing // CHECK-SAME: %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index) @@ -15,7 +14,7 @@ func.func @simplify_apply_indexing(%s0: index, %s1: index) -> (index, index) { // ----- -#map0 = #xla_gpu.indexing_map<"(d0, d1, d2)[s0, s1] -> (1 + s0 + s1 mod 4 - s1, s0 mod 2, d0 + d2), domain: d0 in [0, 1], d1 in [0, 2], d2 in [0, 3], s0 in [-11, 11], s1 in [0, 3], is_simplified: false"> +#map0 = #xla_gpu.indexing_map<"(d0, d1, d2)[s0, s1] -> (1 + s0 + s1 mod 4 - s1, s0 mod 2, d0 + d2), domain: d0 in [0, 1], d1 in [0, 2], d2 in [0, 3], s0 in [-11, 11], s1 in [0, 3]"> func.func @simplify_apply_indexing_remove_dims(%d0: index, %d1: index, %d2: index, %s0: index, %s1: index) -> (index, index, index) { %0:3 = xla_gpu.apply_indexing #map0(%d0, %d1, %d2)[%s0, %s1] @@ -35,16 +34,7 @@ func.func @simplify_apply_indexing_remove_dims(%d0: index, %d1: index, // ----- -#map0 = #xla_gpu.indexing_map<"(d0) -> (d0 mod 10), domain: d0 in [0, 9], is_simplified: true"> -func.func @do_not_simplify_if_is_simplified_is_true(%d0: index) -> (index) { - %0 = xla_gpu.apply_indexing #map0(%d0) - func.return %0 : index -} -// CHECK: #xla_gpu.indexing_map<"(d0) -> (d0 mod 10) - -// ----- - -#map0 = #xla_gpu.indexing_map<"(d0, d1)[s0] -> (d0 + s0, 4, d1, 1, s0), domain: d0 in [-10, 10], d1 in [0, 2], s0 in [-1, 1], is_simplified: false"> +#map0 = #xla_gpu.indexing_map<"(d0, d1)[s0] -> (d0 + s0, 4, d1, 1, s0), domain: d0 in [-10, 10], d1 in [0, 2], s0 in [-1, 1]"> func.func @fold_indexing_map_results(%d0: index, %d1: index, %s0: index) -> (index, index, index, index, index) { %0:5 = xla_gpu.apply_indexing #map0 (%d0, %d1)[%s0] @@ -64,7 +54,7 @@ func.func @fold_indexing_map_results(%d0: index, %d1: index, %s0: index) // ----- #map0 = #xla_gpu.indexing_map<"(d0, d1)[s0] -> (d0 + s0, s0 + 4, d1 mod 2, 1 + d1, s0)," - "domain: d0 in [-10, 10], d1 in [0, 2], s0 in [-1, 1], is_simplified: false"> + "domain: d0 in [-10, 10], d1 in [0, 2], s0 in [-1, 1]"> func.func @remove_unused_results(%d0: index, %d1: index, %s0: index) -> (index) { %0:5 = xla_gpu.apply_indexing #map0 (%d0, %d1)[%s0] func.return %0#2 : index @@ -81,8 +71,7 @@ func.func @remove_unused_results(%d0: index, %d1: index, %s0: index) -> (index) // ----- #map0 = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d0 + d1 + s0 + s1 mod 3)," - "domain: d0 in [0, 10], d1 in [0, 5], s0 in [-10, 10], s1 in [0, 4]," - "is_simplified: false"> + "domain: d0 in [0, 10], d1 in [0, 5], s0 in [-10, 10], s1 in [0, 4]"> func.func @fold_operands(%d0: index) -> index { %d1 = arith.constant 1 : index %s0 = arith.constant 2 : index @@ -102,7 +91,7 @@ func.func @fold_operands(%d0: index) -> index { func.func @fold_operands_and_results(%arg0: index, %arg1: index) -> (index, index) { %0:2 = xla_gpu.apply_indexing #xla_gpu.indexing_map<"(d0, d1) -> (0, d1)," - "domain: d0 in [0, 4], d1 in [0, 5], is_simplified: false">(%arg0, %arg1) + "domain: d0 in [0, 4], d1 in [0, 5]">(%arg0, %arg1) return %0#0, %0#1 : index, index } @@ -115,10 +104,9 @@ func.func @fold_operands_and_results(%arg0: index, %arg1: index) func.func @fold_sequence(%arg0: index, %arg1: index) -> index { %0 = xla_gpu.apply_indexing #xla_gpu.indexing_map< - "(d0, d1) -> (d0 + d1), domain: d0 in [0, 5], d1 in [0, 4]," - "is_simplified: false">(%arg0, %arg1) + "(d0, d1) -> (d0 + d1), domain: d0 in [0, 5], d1 in [0, 4]">(%arg0, %arg1) %1 = xla_gpu.apply_indexing #xla_gpu.indexing_map<"(d0) -> (d0 mod 100 + 42)," - "domain: d0 in [0, 10000], is_simplified: false">(%0) + "domain: d0 in [0, 10000]">(%0) func.return %1 : index } @@ -133,10 +121,9 @@ func.func @fold_sequence(%arg0: index, %arg1: index) -> index { func.func @fold_sequence_sym(%arg0: index, %arg1: index) -> index { %0 = xla_gpu.apply_indexing #xla_gpu.indexing_map<"(d0, d1) -> (d0 + d1), " - "domain: d0 in [0, 5], d1 in [0, 4], is_simplified: false">(%arg0, %arg1) + "domain: d0 in [0, 5], d1 in [0, 4]">(%arg0, %arg1) %1 = xla_gpu.apply_indexing #xla_gpu.indexing_map< - "()[s0] -> (s0 mod 100 + 42), domain: s0 in [0, 10000]," - "is_simplified: false">(%0) + "()[s0] -> (s0 mod 100 + 42), domain: s0 in [0, 10000]">(%0) func.return %1 : index } @@ -150,10 +137,10 @@ func.func @fold_sequence_sym(%arg0: index, %arg1: index) -> index { // ----- #indexing_map1 = #xla_gpu.indexing_map<"(d0, d1) -> (d1 * 2 + d0 + 8512)," - "domain: d0 in [0, 1], d1 in [0, 607], is_simplified: false"> + "domain: d0 in [0, 1], d1 in [0, 607]"> #indexing_map2 = #xla_gpu.indexing_map<"(d0, d1, d2) -> (" "((d1 floordiv 32 + 1) mod 3) * 64 + (d1 mod 32) * 2 + (d0 floordiv 192) * 192 + d2)," - "domain: d0 in [0, 9407], d1 in [0, 607], d2 in [0, 1], is_simplified: false"> + "domain: d0 in [0, 9407], d1 in [0, 607], d2 in [0, 1]"> func.func @fold_sequence_no_simplification_needed(%i: index) -> index { %thread_id_x = gpu.thread_id x {xla.range = [0 : index, 607 : index]} @@ -167,11 +154,11 @@ func.func @fold_sequence_no_simplification_needed(%i: index) -> index { // ----- #indexing_map1 = #xla_gpu.indexing_map< - "(d0) -> (3 * d0), domain: d0 in [0, 9407], is_simplified: false"> + "(d0) -> (3 * d0), domain: d0 in [0, 9407]"> #indexing_map2 = #xla_gpu.indexing_map<"(d0, d1, d2) -> (d0 floordiv 32 + 1)," - "domain: d0 in [0, 9407], d1 in [0, 607], d2 in [0, 1], is_simplified: false"> + "domain: d0 in [0, 9407], d1 in [0, 607], d2 in [0, 1]"> #indexing_map3 = #xla_gpu.indexing_map<"(d0, d1, d2) -> (d0 floordiv 32 + 2)," - "domain: d0 in [0, 9407], d1 in [0, 607], d2 in [0, 1], is_simplified: false"> + "domain: d0 in [0, 9407], d1 in [0, 607], d2 in [0, 1]"> func.func @no_fold_when_producer_has_two_users(%i: index) -> (index, index) { %thread_id_x = gpu.thread_id x {xla.range = [0 : index, 607 : index]} @@ -186,9 +173,9 @@ func.func @no_fold_when_producer_has_two_users(%i: index) -> (index, index) { func.func @fold_sequence_shared_operands(%arg0: index, %arg1: index) -> index { %0 = xla_gpu.apply_indexing #xla_gpu.indexing_map<"(d0, d1) -> (d0 + d1)," - "domain: d0 in [0, 5], d1 in [0, 4], is_simplified: false">(%arg0, %arg1) + "domain: d0 in [0, 5], d1 in [0, 4]">(%arg0, %arg1) %1 = xla_gpu.apply_indexing #xla_gpu.indexing_map<"(d0, d1) -> (d0 + d1)," - "domain: d0 in [0, 4], d1 in [0, 10000], is_simplified: false">(%arg1, %0) + "domain: d0 in [0, 4], d1 in [0, 10000]">(%arg1, %0) func.return %1 : index } @@ -234,7 +221,7 @@ func.func @atomic_rmw_cst(%in: tensor<2x3xf32>, %i: index, %j: index) // ----- #map0 = #xla_gpu.indexing_map<"(d0)[s0] -> (2 * d0 * s0)," - "domain: d0 in [0, 3], s0 in [0, 2], is_simplified: false"> + "domain: d0 in [0, 3], s0 in [0, 2]"> func.func @apply_indexing_move_syms_to_dims(%dim0: index, %sym0: index) -> index { %0 = xla_gpu.apply_indexing #map0(%dim0)[%sym0] @@ -249,10 +236,9 @@ func.func @apply_indexing_move_syms_to_dims(%dim0: index, %sym0: index) // // ----- -#map0 = #xla_gpu.indexing_map<"(d0) -> (4 * d0), domain: d0 in [0, 3]," - "is_simplified: false"> +#map0 = #xla_gpu.indexing_map<"(d0) -> (4 * d0), domain: d0 in [0, 3]"> #map1 = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1)," - "domain: d0 in [0, 12], s0 in [0, 1024], s1 in [0, 32], is_simplified: false"> + "domain: d0 in [0, 12], s0 in [0, 1024], s1 in [0, 32]"> func.func @loop_of_apply_indexing(%input: tensor<1024x32xf32>, %init: f32, %dim: index) -> (f32) { %idx = xla_gpu.apply_indexing #map0(%dim) %sum = xla_gpu.loop (%idx)[%i, %j] -> (%r0, %r1) in #map1 iter_args(%sum_ = %init) -> (f32) { @@ -273,9 +259,9 @@ func.func @loop_of_apply_indexing(%input: tensor<1024x32xf32>, %init: f32, %dim: // ----- #map0 = #xla_gpu.indexing_map<"(d0)[s0] -> (2 * d0 * s0)," - "domain: d0 in [0, 3], s0 in [0, 2], is_simplified: false"> + "domain: d0 in [0, 3], s0 in [0, 2]"> #map1 = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0 + s1)," - "domain: d0 in [0, 12], s0 in [0, 1024], s1 in [0, 32], is_simplified: false"> + "domain: d0 in [0, 12], s0 in [0, 1024], s1 in [0, 32]"> func.func @loop_of_apply_indexing_with_syms(%dim0: index, %sym0: index, %input: tensor<1024x32xf32>, %init: f32) -> (f32) { %0 = xla_gpu.apply_indexing #map0(%dim0)[%sym0] %sum = xla_gpu.loop (%0)[%i, %j] -> (%r0) in #map1 iter_args(%sum_ = %init) -> (f32) { diff --git a/third_party/xla/xla/service/gpu/fusions/ir/tests/invalid.mlir b/third_party/xla/xla/service/gpu/fusions/ir/tests/invalid.mlir index 3c50b5afcd8068..35064858b23150 100644 --- a/third_party/xla/xla/service/gpu/fusions/ir/tests/invalid.mlir +++ b/third_party/xla/xla/service/gpu/fusions/ir/tests/invalid.mlir @@ -1,6 +1,6 @@ // RUN: mlir_fusions_opt %s -split-input-file -verify-diagnostics -#map0 = #xla_gpu.indexing_map<"(d0, d1)[s0] -> (d0, d1 + s0), domain: d0 in [1, 2], d1 in [5, 8], s0 in [0, 32], is_simplified: false"> +#map0 = #xla_gpu.indexing_map<"(d0, d1)[s0] -> (d0, d1 + s0), domain: d0 in [1, 2], d1 in [5, 8], s0 in [0, 32]"> func.func @apply_indexing(%d0: index, %d1: index, %s0: index) -> (index, index) { // expected-error @+1 {{operand count must match the number of dimensions and symbols in the affine map}} %0:2 = xla_gpu.apply_indexing #map0 (%d0) @@ -9,7 +9,7 @@ func.func @apply_indexing(%d0: index, %d1: index, %s0: index) -> (index, index) // ----- -#map0 = #xla_gpu.indexing_map<"(d0, d1)[s0] -> (d0, d1 + s0), domain: d0 in [1, 2], d1 in [5, 8], s0 in [0, 32], d0 mod 2 in [0, 1], d0 + s0 in [1, 10], is_simplified: false"> +#map0 = #xla_gpu.indexing_map<"(d0, d1)[s0] -> (d0, d1 + s0), domain: d0 in [1, 2], d1 in [5, 8], s0 in [0, 32], d0 mod 2 in [0, 1], d0 + s0 in [1, 10]"> func.func @cannot_have_constraints(%d0: index, %d1: index, %s0: index) -> (index, index) { // expected-error @+1 {{apply indexing op cannot have any constraints}} %0:2 = xla_gpu.apply_indexing #map0 (%d0, %d1)[%s0] @@ -18,7 +18,7 @@ func.func @cannot_have_constraints(%d0: index, %d1: index, %s0: index) -> (index // ----- -#map = #xla_gpu.indexing_map<"()[s0, s1] -> (s0, s1), domain: s0 in [0, 1024], s1 in [0, 32], is_simplified: false"> +#map = #xla_gpu.indexing_map<"()[s0, s1] -> (s0, s1), domain: s0 in [0, 1024], s1 in [0, 32]"> func.func @loop_result_num_mismatch(%input: tensor<1024x32xf32>, %init: f32) -> (f32) { // expected-error @+1 {{mismatch in number of loop-carried values and results}} @@ -36,7 +36,7 @@ func.func @loop_result_num_mismatch(%input: tensor<1024x32xf32>, // ----- -#map = #xla_gpu.indexing_map<"()[s0] -> (s0, s0), domain: s0 in [0, 1024], is_simplified: false"> +#map = #xla_gpu.indexing_map<"()[s0] -> (s0, s0), domain: s0 in [0, 1024]"> func.func @loop_iv_num_mismatch(%input: tensor<1024x32xf32>, %init: f32) -> (f32) { // expected-error @+1 {{mismatch in number of induction variables 2 and RangeVars}} @@ -54,7 +54,7 @@ func.func @loop_iv_num_mismatch(%input: tensor<1024x32xf32>, // ----- -#map = #xla_gpu.indexing_map<"()[s0, s1] -> (s0, s1), domain: s0 in [0, 1024], s1 in [0, 32], is_simplified: false"> +#map = #xla_gpu.indexing_map<"()[s0, s1] -> (s0, s1), domain: s0 in [0, 1024], s1 in [0, 32]"> func.func @loop_types_mismatch(%input: tensor<1024x32xf32>, %init: f32) -> (i32) { // expected-error @+1 {{block iter arg type = 'f32', result type = 'i32' and init operand type = 'f32' should match}} @@ -72,7 +72,7 @@ func.func @loop_types_mismatch(%input: tensor<1024x32xf32>, %init: f32) -> (i32) // ----- -#map = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (s0, s1), domain: d0 in [0, 3], s0 in [0, 1024], s1 in [0, 32], is_simplified: false"> +#map = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (s0, s1), domain: d0 in [0, 3], s0 in [0, 1024], s1 in [0, 32]"> func.func @loop_op(%input: tensor<1024x32xf32>, %init: f32, %dim: index) -> (f32) { // expected-error @+1 {{mismatch in number of dims operands 0 and DimVars in the indexing map}} @@ -87,7 +87,7 @@ func.func @loop_op(%input: tensor<1024x32xf32>, %init: f32, %dim: index) -> (f32 // ----- -#map = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], d1 in [0, 2], s0 in [0, 1024], s1 in [0, 32], is_simplified: false"> +#map = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], d1 in [0, 2], s0 in [0, 1024], s1 in [0, 32]"> func.func @indicies_mismatch(%input: tensor<32x64xf32>, %thread_id: index, %output: tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map> { @@ -99,8 +99,8 @@ func.func @indicies_mismatch(%input: tensor<32x64xf32>, %thread_id: index, // ----- -#map = #xla_gpu.indexing_map<"()[s0, s1] -> (s0, s1), domain: s0 in [0, 1024], s1 in [0, 32], is_simplified: false"> -#map1 = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 32], is_simplified: false"> +#map = #xla_gpu.indexing_map<"()[s0, s1] -> (s0, s1), domain: s0 in [0, 1024], s1 in [0, 32]"> +#map1 = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 32]"> func.func @no_thread_id_in(%input: tensor<32x64xf32>, %output: tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1> { @@ -112,8 +112,8 @@ func.func @no_thread_id_in(%input: tensor<32x64xf32>, // ----- -#map = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 32], is_simplified: false"> -#map1 = #xla_gpu.indexing_map<"()[s0, s1] -> (s0, s1), domain: s0 in [0, 1024], s1 in [0, 32], is_simplified: false"> +#map = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 32]"> +#map1 = #xla_gpu.indexing_map<"()[s0, s1] -> (s0, s1), domain: s0 in [0, 1024], s1 in [0, 32]"> func.func @no_thread_id_out(%input: tensor<32x64xf32>, %thread_id: index, %output: tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1> { @@ -125,8 +125,8 @@ func.func @no_thread_id_out(%input: tensor<32x64xf32>, %thread_id: index, // ----- -#map = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 32], is_simplified: false"> -#map1 = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 64], s0 in [0, 1024], s1 in [0, 32], is_simplified: false"> +#map = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 32]"> +#map1 = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 64], s0 in [0, 1024], s1 in [0, 32]"> func.func @thread_id_bounds_mismatch(%input: tensor<32x64xf32>, %thread_id: index, %output: tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1> { // expected-error @+1 {{thread_id dimension must have the same bounds in both indexing maps}} %0 = xla_gpu.materialize @exp(%input) at #map(%thread_id) : (tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1> @@ -135,8 +135,8 @@ func.func @thread_id_bounds_mismatch(%input: tensor<32x64xf32>, %thread_id: inde // ----- -#map = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 32], d0 + s0 in [0, 1024], is_simplified: false"> -#map1 = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 32], is_simplified: false"> +#map = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 32], d0 + s0 in [0, 1024]"> +#map1 = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 32]"> func.func @thread_id_constraints_mismatch(%input: tensor<32x64xf32>, %thread_id: index, %output: tensor<32x64xf32>) @@ -149,8 +149,8 @@ func.func @thread_id_constraints_mismatch(%input: tensor<32x64xf32>, // ----- -#map = #xla_gpu.indexing_map<"(d0)[s0] -> (d0 + s0, s0), domain: d0 in [0, 32], s0 in [0, 1024], is_simplified: false"> -#map1 = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 32], is_simplified: false"> +#map = #xla_gpu.indexing_map<"(d0)[s0] -> (d0 + s0, s0), domain: d0 in [0, 32], s0 in [0, 1024]"> +#map1 = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 32]"> func.func @symbol_count_mismatch(%input: tensor<32x64xf32>, %thread_id: index, %output: tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1> { // expected-error @+1 {{number of symbols in both indexing_maps must match}} %0 = xla_gpu.materialize @exp(%input) at #map(%thread_id) : (tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1> @@ -159,8 +159,8 @@ func.func @symbol_count_mismatch(%input: tensor<32x64xf32>, %thread_id: index, % // ----- -#map = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64], is_simplified: false"> -#map1 = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 32], is_simplified: false"> +#map = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64]"> +#map1 = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 32]"> func.func @symbol_domain_mismatch(%input: tensor<32x64xf32>, %thread_id: index, %output: tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1> { // expected-error @+1 {{domain of symbols of indexing_maps must match}} %0 = xla_gpu.materialize @exp(%input) at #map(%thread_id) : (tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1> @@ -169,8 +169,8 @@ func.func @symbol_domain_mismatch(%input: tensor<32x64xf32>, %thread_id: index, // ----- -#map = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64], s0 + s1 in [0, 1024], is_simplified: false"> -#map1 = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64], s0 + s1 in [0, 32], is_simplified: false"> +#map = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64], s0 + s1 in [0, 1024]"> +#map1 = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64], s0 + s1 in [0, 32]"> func.func @symbol_constraints_mismatch(%input: tensor<32x64xf32>, %thread_id: index, %output: tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1> { @@ -182,8 +182,8 @@ func.func @symbol_constraints_mismatch(%input: tensor<32x64xf32>, // ----- -#map = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64], s0 mod 2 in [0, 0], is_simplified: false"> -#map1 = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64], s0 + s1 in [0, 32], is_simplified: false"> +#map = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64], s0 mod 2 in [0, 0]"> +#map1 = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64], s0 + s1 in [0, 32]"> func.func @symbol_constraint_mismatch(%input: tensor<32x64xf32>, %thread_id: index, %output: tensor<32x64xf32>) @@ -195,8 +195,8 @@ func.func @symbol_constraint_mismatch(%input: tensor<32x64xf32>, // ----- -#map = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64], s0 + s1 in [0, 1024], is_simplified: false"> -#map1 = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64], s0 + s1 in [0, 32], is_simplified: false"> +#map = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64], s0 + s1 in [0, 1024]"> +#map1 = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64], s0 + s1 in [0, 32]"> func.func @symbol_constraint_interval_mismatch(%input: tensor<32x64xf32>, %thread_id: index, %output: tensor<32x64xf32>) @@ -209,8 +209,8 @@ func.func @symbol_constraint_interval_mismatch(%input: tensor<32x64xf32>, // ----- -#map = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64], is_simplified: false"> -#map1 = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d0 + s0, d1 + s1), domain: d0 in [0, 32], d1 in [0, 64], s0 in [0, 1024], s1 in [0, 64], is_simplified: false"> +#map = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64]"> +#map1 = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d0 + s0, d1 + s1), domain: d0 in [0, 32], d1 in [0, 64], s0 in [0, 1024], s1 in [0, 64]"> func.func @vector_mapping_depends_on_block_id(%input: tensor<32x64xf32>, %thread_id: index, %output: tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1> { @@ -222,8 +222,8 @@ func.func @vector_mapping_depends_on_block_id(%input: tensor<32x64xf32>, // ----- -#map = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], d1 in [0, 64], s0 in [0, 1024], s1 in [0, 64], d1 mod 2 in [0, 0], is_simplified: false"> -#map1 = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64], is_simplified: false"> +#map = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], d1 in [0, 64], s0 in [0, 1024], s1 in [0, 64], d1 mod 2 in [0, 0]"> +#map1 = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64]"> func.func @block_id_constraints_mismatch(%input: tensor<32x64xf32>, %thread_id: index, %block_id: index, %output: tensor<32x64xf32>) @@ -236,8 +236,8 @@ func.func @block_id_constraints_mismatch(%input: tensor<32x64xf32>, // ----- -#map = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64], is_simplified: false"> -#map1 = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], d1 in [0, 64], s0 in [0, 1024], s1 in [0, 64], d1 mod 2 in [0, 0], is_simplified: false"> +#map = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64]"> +#map1 = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], d1 in [0, 64], s0 in [0, 1024], s1 in [0, 64], d1 mod 2 in [0, 0]"> func.func @block_id_constraints_mismatch(%input: tensor<32x64xf32>, %thread_id: index, %block_id: index, %output: tensor<32x64xf32>) @@ -250,8 +250,8 @@ func.func @block_id_constraints_mismatch(%input: tensor<32x64xf32>, // ----- -#map = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], d1 in [0, 64], s0 in [0, 1024], s1 in [0, 64], d1 mod 2 in [0, 0], is_simplified: false"> -#map1 = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], d1 in [0, 64], s0 in [0, 1024], s1 in [0, 64], d1 mod 4 in [0, 0], is_simplified: false"> +#map = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], d1 in [0, 64], s0 in [0, 1024], s1 in [0, 64], d1 mod 2 in [0, 0]"> +#map1 = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], d1 in [0, 64], s0 in [0, 1024], s1 in [0, 64], d1 mod 4 in [0, 0]"> func.func @block_id_constraints_mismatch(%input: tensor<32x64xf32>, %thread_id: index, %block_id: index, %output: tensor<32x64xf32>) @@ -264,8 +264,8 @@ func.func @block_id_constraints_mismatch(%input: tensor<32x64xf32>, // ----- -#map = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d1*32+d0*2+s0, s1), domain: d0 in [0, 32], d1 in [0, 8], s0 in [0, 1], s1 in [0, 1], is_simplified: false"> -#map1 = #xla_gpu.indexing_map<"(d0, d1)[s0] -> (d0 mod 16 + s0, d1), domain: d0 in [0, 32], d1 in [0, 2], s0 in [0, 1], is_simplified: false"> +#map = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d1*32+d0*2+s0, s1), domain: d0 in [0, 32], d1 in [0, 8], s0 in [0, 1], s1 in [0, 1]"> +#map1 = #xla_gpu.indexing_map<"(d0, d1)[s0] -> (d0 mod 16 + s0, d1), domain: d0 in [0, 32], d1 in [0, 2], s0 in [0, 1]"> func.func @insert(%input: !xla_gpu.indexed_vector<32x64xf32, #map>, %i: index, %j: index, %output: tensor<32x64xf32>) -> tensor<32x64xf32> { @@ -277,8 +277,8 @@ func.func @insert(%input: !xla_gpu.indexed_vector<32x64xf32, #map>, // ----- -#map = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d1*32+d0*2+s0, s1), domain: d0 in [0, 32], d1 in [0, 8], s0 in [0, 1], s1 in [0, 1], is_simplified: false"> -#map1 = #xla_gpu.indexing_map<"(d0, d1, d2) -> (d0 mod 16, d1, d2), domain: d0 in [0, 32], d1 in [0, 2], d2 in [0, 5], is_simplified: false"> +#map = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d1*32+d0*2+s0, s1), domain: d0 in [0, 32], d1 in [0, 8], s0 in [0, 1], s1 in [0, 1]"> +#map1 = #xla_gpu.indexing_map<"(d0, d1, d2) -> (d0 mod 16, d1, d2), domain: d0 in [0, 32], d1 in [0, 2], d2 in [0, 5]"> func.func @insert(%input: !xla_gpu.indexed_vector<32x64xf32, #map>, %i: index, %j: index, %output: tensor<32x64xf32>) -> tensor<32x64xf32> { diff --git a/third_party/xla/xla/service/gpu/fusions/ir/tests/ops.mlir b/third_party/xla/xla/service/gpu/fusions/ir/tests/ops.mlir index 81e08968db7590..f6fd03d8f1ed24 100644 --- a/third_party/xla/xla/service/gpu/fusions/ir/tests/ops.mlir +++ b/third_party/xla/xla/service/gpu/fusions/ir/tests/ops.mlir @@ -57,7 +57,7 @@ func.func @caller(%a: f32, %b: f32) -> f32 { // ----- #map0 = #xla_gpu.indexing_map<"(d0, d1)[s0] -> (d0, d1 + s0)," - "domain: d0 in [1, 2], d1 in [5, 8], s0 in [0, 32], is_simplified: false"> + "domain: d0 in [1, 2], d1 in [5, 8], s0 in [0, 32]"> func.func @apply_indexing(%d0: index, %d1: index, %s0: index) -> (index, index) { %0:2 = xla_gpu.apply_indexing #map0 (%d0, %d1)[%s0] func.return %0#0, %0#1 : index, index @@ -78,7 +78,7 @@ func.func @apply_indexing(%d0: index, %d1: index, %s0: index) -> (index, index) // ----- #map0 = #xla_gpu.indexing_map<"(d0, d1) -> (d0, d1)," - "domain: d0 in [0, 2], d1 in [1, 3], is_simplified: false"> + "domain: d0 in [0, 2], d1 in [1, 3]"> func.func @apply_indexing_no_symbols(%d0: index, %d1: index) -> (index, index) { %0:2 = xla_gpu.apply_indexing #map0 (%d0, %d1) func.return %0#0, %0#1 : index, index @@ -98,7 +98,7 @@ func.func @apply_indexing_no_symbols(%d0: index, %d1: index) -> (index, index) { // ----- #map0 = #xla_gpu.indexing_map<"()[s0] -> (s0, s0)," - "domain: s0 in [2, 4], is_simplified: false"> + "domain: s0 in [2, 4]"> func.func @apply_indexing_no_dims(%s0: index) -> (index, index) { %0:2 = xla_gpu.apply_indexing #map0 [%s0] func.return %0#0, %0#1 : index, index @@ -116,7 +116,7 @@ func.func @apply_indexing_no_dims(%s0: index) -> (index, index) { // ----- #map = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (s0, s1), " - "domain: d0 in [0, 3], s0 in [0, 1024], s1 in [0, 32], is_simplified: false"> + "domain: d0 in [0, 3], s0 in [0, 1024], s1 in [0, 32]"> func.func @loop_op(%input: tensor<1024x32xf32>, %init: f32, %dim: index) -> (f32) { %sum = xla_gpu.loop (%dim)[%i, %j] -> (%r0, %r1) @@ -141,11 +141,11 @@ func.func @loop_op(%input: tensor<1024x32xf32>, %init: f32, func.func private @exp(%p0: tensor<32x64xf32>, %i: index, %j: index) -> f32 #map = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d0 + s0, d1 + s1)," - "domain: d0 in [0, 32], d1 in [0, 2], s0 in [0, 1024], s1 in [0, 32], is_simplified: false"> + "domain: d0 in [0, 32], d1 in [0, 2], s0 in [0, 1024], s1 in [0, 32]"> #map1 = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (s0, s1)," - "domain: d0 in [0, 32], d1 in [0, 2], s0 in [0, 1024], s1 in [0, 32], is_simplified: false"> + "domain: d0 in [0, 32], d1 in [0, 2], s0 in [0, 1024], s1 in [0, 32]"> #map2 = #xla_gpu.indexing_map<"(d0, d1) -> (d0, d1)," - "domain: d0 in [0, 32], d1 in [0, 2], is_simplified: false"> + "domain: d0 in [0, 32], d1 in [0, 2]"> func.func @materialize_and_insert(%input: tensor<32x64xf32>, %i: index, %j: index, %output: tensor<32x64xf32>) -> tensor<32x64xf32> { @@ -161,7 +161,7 @@ func.func @materialize_and_insert(%input: tensor<32x64xf32>, %i: index, // CHECK: #[[$MAP1:.*]] = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (s0, s1) // CHECK-SAME: d0 in [0, 32], d1 in [0, 2], s0 in [0, 1024], s1 in [0, 32] // CHECK: #[[$MAP2:.*]] = #xla_gpu.indexing_map<"(d0, d1) -> (d0, d1) -// CHECK-SAME: d0 in [0, 32], d1 in [0, 2], +// CHECK-SAME: d0 in [0, 32], d1 in [0, 2]"> // CHECK-LABEL: @materialize_and_insert // CHECK: %[[MATERIALIZED:.*]] = xla_gpu.materialize @exp(%{{.*}}) at // CHECK-SAME: #[[$MAP]](%{{.*}}, %{{.*}}) @@ -216,7 +216,7 @@ func.func @reduce_middle_dim(%in: tensor<16x8x4xf32>, %init: f32) // ----- #map = #xla_gpu.indexing_map<"(d0, d1) -> (d0 * 64 + d1)," - "domain: d0 in [0, 15], d1 in [0, 63], is_simplified: false"> + "domain: d0 in [0, 15], d1 in [0, 63]"> func.func @reindex(%in0: tensor<1024xf32>) -> tensor<16x64xf32> { %0 = xla_gpu.reindex %in0 at #map : tensor<1024xf32> -> tensor<16x64xf32> func.return %0 : tensor<16x64xf32> @@ -231,7 +231,7 @@ func.func @reindex(%in0: tensor<1024xf32>) -> tensor<16x64xf32> { // ----- #map = #xla_gpu.indexing_map<"(d0, d1) -> (d0 * 64 + d1)," - "domain: d0 in [0, 15], d1 in [0, 63], is_simplified: false"> + "domain: d0 in [0, 15], d1 in [0, 63]"> func.func @reindex_pad(%in0: tensor<1022xf32>) -> tensor<16x64xf32> { %c0 = arith.constant 0.0 : f32 %0 = xla_gpu.reindex %in0 at #map default %c0 diff --git a/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_attrs.cc b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_attrs.cc index 8a0380b0706f75..535a24fd55788a 100644 --- a/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_attrs.cc +++ b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_attrs.cc @@ -48,8 +48,7 @@ using mlir::success; // Parses a chain of string attributes into an indexing map. // Example: // "()[s0, s1] -> (1 + s0 + s1 mod 3 - s1, s0 mod 2)," -// " domain: s0 in [-10, 10], s1 in [0, 2]," -// " is_simplified: false" +// " domain: s0 in [-10, 10], s1 in [0, 2]" // will be parsed as 3 StringAttrs, concatenated into a single string, and then // parsed into an IndexingMap. std::optional parseChainOfStringsAsIndexingMap( @@ -84,17 +83,16 @@ IndexingMapAttr IndexingMapAttr::get(mlir::MLIRContext* context, constraints.push_back({constraint.first, constraint.second}); } return get(context, indexing_map.GetAffineMap(), indexing_map.GetDimVars(), - indexing_map.GetRangeVars(), constraints, - indexing_map.IsSimplified()); + indexing_map.GetRangeVars(), constraints); } mlir::LogicalResult IndexingMapAttr::verify( mlir::function_ref emitError, mlir::AffineMap map, ArrayRef dim_vars, ArrayRef range_vars, - ArrayRef> constraints, bool is_simplified) { - auto indexing_map = IndexingMap(map, dim_vars, range_vars, /*rt_vars=*/{}, - constraints, is_simplified); + ArrayRef> constraints) { + auto indexing_map = + IndexingMap(map, dim_vars, range_vars, /*rt_vars=*/{}, constraints); std::stringstream ss; if (!indexing_map.Verify(ss)) { return emitError() << ss.str(); @@ -104,7 +102,7 @@ mlir::LogicalResult IndexingMapAttr::verify( IndexingMap IndexingMapAttr::getIndexingMap() const { return IndexingMap(getMap(), getDimVars(), getRangeVars(), /*rt_vars=*/{}, - getConstraints(), getIsSimplified()); + getConstraints()); } int64_t IndexingMapAttr::getNumResults() const { diff --git a/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_attrs.td b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_attrs.td index 44e8dd4353a5b6..f42a2254558724 100644 --- a/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_attrs.td +++ b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_attrs.td @@ -36,8 +36,6 @@ def XLAGPU_RangeVarsParameter : ArrayRefParameter<"::xla::gpu::RangeVar", "RangeVarArray"> { } -def XLAGPU_BoolParameter : AttrOrTypeParameter<"bool", ""> {} - def XLAGPU_ConstraintsParameter : ArrayRefParameter<"::std::pair<::mlir::AffineExpr, ::xla::gpu::Interval>", "ContraintsArray"> { @@ -52,8 +50,7 @@ def XLAGPU_IndexingMapAttr : XLAGPU_Attr<"IndexingMap"> { let parameters = (ins XLAGPU_AffineMapParameter:$map, XLAGPU_DimVarsParameter:$dim_vars, XLAGPU_RangeVarsParameter:$range_vars, - XLAGPU_ConstraintsParameter:$constraints, - XLAGPU_BoolParameter:$is_simplified); + XLAGPU_ConstraintsParameter:$constraints); let hasCustomAssemblyFormat = 1; let builders = [ AttrBuilder<(ins "const ::xla::gpu::IndexingMap&":$indexing_map)>, diff --git a/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops.cc b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops.cc index a4724eb8b5c9f6..2aa00180e326b1 100644 --- a/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops.cc +++ b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops.cc @@ -183,13 +183,13 @@ ParseResult ApplyIndexingOp::parse(OpAsmParser& parser, parser.parseOptionalAttrDict(result.attributes)) { return failure(); } - auto map = indexing_map_attr.getMap(); + auto map = indexing_map_attr.getIndexingMap().GetAffineMap(); result.addTypes(SmallVector(map.getNumResults(), index_type)); return success(); } void ApplyIndexingOp::print(OpAsmPrinter& p) { - AffineMap affine_map = getIndexingMapAttr().getMap(); + AffineMap affine_map = getIndexingMapAttr().getIndexingMap().GetAffineMap(); p << " " << getIndexingMapAttr(); auto operands = getOperands(); @@ -214,14 +214,14 @@ void ApplyIndexingOp::print(OpAsmPrinter& p) { } LogicalResult ApplyIndexingOp::verify() { - auto affine_map = getIndexingMapAttr().getMap(); + auto affine_map = getIndexingMapAttr().getIndexingMap().GetAffineMap(); unsigned num_variables = affine_map.getNumDims() + affine_map.getNumSymbols(); if (getOperands().size() != num_variables) { return emitOpError( "operand count must match the number of dimensions and symbols in the " "affine map"); } - if (!getIndexingMapAttr().getConstraints().empty()) { + if (!getIndexingMap().GetConstraints().empty()) { return emitOpError("apply indexing op cannot have any constraints"); } return success(); @@ -310,11 +310,10 @@ struct SimplifyIndexingMap : public mlir::OpRewritePattern { LogicalResult matchAndRewrite(ApplyIndexingOp indexing_op, PatternRewriter& rewriter) const override { IndexingMap indexing_map = indexing_op.getIndexingMap(); - if (indexing_map.IsSimplified()) { + if (!indexing_map.Simplify()) { return rewriter.notifyMatchFailure(indexing_op, "IndexingMap is already simplified"); } - indexing_map.Simplify(); rewriter.replaceOpWithNewOp( indexing_op, indexing_op.getOperands(), indexing_map); return success(); @@ -1046,12 +1045,12 @@ LogicalResult MaterializeOp::verify() { //===----------------------------------------------------------------------===// LogicalResult InsertOp::verify() { - if (!getMap().getRangeVars().empty()) { + if (!getMap().getIndexingMap().GetRangeVars().empty()) { return emitOpError() << "insert_op map must not have any symbols"; } int64_t vector_map_num_results = getSource().getType().getIndexingMapAttr().getNumResults(); - if (vector_map_num_results != getMap().getDimVars().size()) { + if (vector_map_num_results != getMap().getIndexingMap().GetDimVars().size()) { return emitOpError() << "source map result count must equal insert_op's " "map's dimension count"; } diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/concatenate_test.cc b/third_party/xla/xla/service/gpu/fusions/legacy/concatenate_test.cc index 32437d5bca3772..ce7da7bcb22485 100644 --- a/third_party/xla/xla/service/gpu/fusions/legacy/concatenate_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/legacy/concatenate_test.cc @@ -82,8 +82,7 @@ TEST_F(ConcatenateTest, ThreadIndexing) { bl_z in [0, 0], chunk_id in [0, 0], unroll_id in [0, 0], - bl_x * 128 + th_x in [0, 399], - is_simplified: true + bl_x * 128 + th_x in [0, 399] )"; EXPECT_THAT( ToString(*fusion->ComputeThreadIdToInputIndexing( diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/in_place_dynamic_update_slice_test.cc b/third_party/xla/xla/service/gpu/fusions/legacy/in_place_dynamic_update_slice_test.cc index 6bf9ea865e1c45..27d3aa2170be3f 100644 --- a/third_party/xla/xla/service/gpu/fusions/legacy/in_place_dynamic_update_slice_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/legacy/in_place_dynamic_update_slice_test.cc @@ -88,8 +88,7 @@ TEST_F(InPlaceDynamicUpdateSliceFusionTest, ThreadIndexing) { bl_y in [0, 0], bl_z in [0, 0], chunk_id in [0, 0], - unroll_id in [0, 0], - is_simplified: true + unroll_id in [0, 0] )")); auto thread_id_dst_indexing = fusion->ComputeThreadIdToInputIndexing( /*root_index=*/0, /*hero_operand_index=*/0, &mlir_context_); diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/input_slices_test.cc b/third_party/xla/xla/service/gpu/fusions/legacy/input_slices_test.cc index 08fcc0d387c777..9de13b8bd7df5c 100644 --- a/third_party/xla/xla/service/gpu/fusions/legacy/input_slices_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/legacy/input_slices_test.cc @@ -88,8 +88,7 @@ TEST_F(InputSlicesTest, ThreadIndexing) { bl_z in [0, 0], chunk_id in [0, 0], unroll_id in [0, 0], - bl_x * 128 + th_x in [0, 29], - is_simplified: true + bl_x * 128 + th_x in [0, 29] )")); } diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/loop_test.cc b/third_party/xla/xla/service/gpu/fusions/legacy/loop_test.cc index 60ae18e5cc6a17..b23e9b2b19a213 100644 --- a/third_party/xla/xla/service/gpu/fusions/legacy/loop_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/legacy/loop_test.cc @@ -93,8 +93,7 @@ TEST_F(LoopTest, ThreadIndexingUnrolled) { bl_z in [0, 0], chunk_id in [0, 0], unroll_id in [0, 3], - bl_x * 128 + th_x in [0, 1499999], - is_simplified: true + bl_x * 128 + th_x in [0, 1499999] )")); } @@ -133,8 +132,7 @@ TEST_F(LoopTest, ThreadIndexingNotUnrolled) { bl_y in [0, 0], bl_z in [0, 0], chunk_id in [0, 0], - unroll_id in [0, 0], - is_simplified: true + unroll_id in [0, 0] )")); auto thread_id_to_input_indexing = loop_fusion->ComputeThreadIdToInputIndexing( @@ -152,8 +150,7 @@ TEST_F(LoopTest, ThreadIndexingNotUnrolled) { bl_y in [0, 0], bl_z in [0, 0], chunk_id in [0, 0], - unroll_id in [0, 0], - is_simplified: true + unroll_id in [0, 0] )")); } @@ -196,8 +193,7 @@ TEST_F(LoopTest, Broadcast) { bl_z in [0, 0], chunk_id in [0, 0], unroll_id in [0, 0], - bl_x * 128 + th_x in [0, 5999], - is_simplified: true + bl_x * 128 + th_x in [0, 5999] )")); auto thread_id_to_input_indexing = loop_fusion->ComputeThreadIdToInputIndexing( @@ -217,8 +213,7 @@ TEST_F(LoopTest, Broadcast) { bl_z in [0, 0], chunk_id in [0, 0], unroll_id in [0, 0], - bl_x * 128 + th_x in [0, 5999], - is_simplified: true + bl_x * 128 + th_x in [0, 5999] )")); } diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/reduction_test.cc b/third_party/xla/xla/service/gpu/fusions/legacy/reduction_test.cc index 46c7a26970e538..54fff94a6ed775 100644 --- a/third_party/xla/xla/service/gpu/fusions/legacy/reduction_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/legacy/reduction_test.cc @@ -86,8 +86,7 @@ TEST_F(ReductionTest, ThreadIndexingRowReduction) { s0 in [0, 0], s1 in [0, 0], s2 in [0, 7], - s3 in [0, 1], - is_simplified: true + s3 in [0, 1] )")); EXPECT_THAT( ToString(*fusion.ComputeThreadIdToOutputIndexing(0, &mlir_context_)), @@ -103,8 +102,7 @@ TEST_F(ReductionTest, ThreadIndexingRowReduction) { d3 in [0, 799], d4 in [0, 0], d5 in [0, 0], - d0 mod 32 in [0, 0], - is_simplified: true + d0 mod 32 in [0, 0] )")); } diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/scatter_test.cc b/third_party/xla/xla/service/gpu/fusions/legacy/scatter_test.cc index 7381d375645660..e7d1d8eae303c9 100644 --- a/third_party/xla/xla/service/gpu/fusions/legacy/scatter_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/legacy/scatter_test.cc @@ -156,8 +156,7 @@ TEST_F(ScatterFusionTest, ThreadIdIndexing) { bl_z in [0, 0], chunk_id in [0, 0], unroll_id in [0, 0], - bl_x * 128 + th_x in [0, 8399], - is_simplified: true + bl_x * 128 + th_x in [0, 8399] )"; mlir::SmallVector dim_names = {"th_x", "th_y", "th_z", "bl_x", "bl_y", "bl_z"}; @@ -197,8 +196,7 @@ TEST_F(ScatterFusionTest, ThreadIdIndexing) { chunk_id in [0, 0], unroll_id in [0, 0], index_id in [0, 0], - bl_x * 128 + th_x in [0, 8399], - is_simplified: true + bl_x * 128 + th_x in [0, 8399] )"; EXPECT_THAT( ToString(*fusion->ComputeThreadIdToInputIndexing( diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/transpose_test.cc b/third_party/xla/xla/service/gpu/fusions/legacy/transpose_test.cc index c66094061e6366..1e503025d889d3 100644 --- a/third_party/xla/xla/service/gpu/fusions/legacy/transpose_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/legacy/transpose_test.cc @@ -95,8 +95,7 @@ TEST_F(TransposeTest, ThreadIndexing021) { s0 in [0, 0], s1 in [0, 7], - s2 in [0, 0], - is_simplified: true + s2 in [0, 0] )")); EXPECT_THAT( ToString(*fusion->ComputeThreadIdToOutputIndexing(0, &mlir_context)), @@ -116,8 +115,7 @@ TEST_F(TransposeTest, ThreadIndexing021) { s0 in [0, 0], s1 in [0, 7], - s2 in [0, 0], - is_simplified: true + s2 in [0, 0] )")); } @@ -159,8 +157,7 @@ TEST_F(TransposeTest, ThreadIndexing201_SimplifiedTo021) { s0 in [0, 0], s1 in [0, 7], - s2 in [0, 0], - is_simplified: true + s2 in [0, 0] )")); EXPECT_THAT( ToString(*fusion->ComputeThreadIdToOutputIndexing(0, &mlir_context)), @@ -180,8 +177,7 @@ TEST_F(TransposeTest, ThreadIndexing201_SimplifiedTo021) { s0 in [0, 0], s1 in [0, 7], - s2 in [0, 0], - is_simplified: true + s2 in [0, 0] )")); } @@ -225,8 +221,7 @@ TEST_F(TransposeTest, ThreadIndexingPartialBlock) { s0 in [0, 5], s1 in [0, 0], s2 in [0, 0], - d0 mod 32 in [0, 23], - is_simplified: true + d0 mod 32 in [0, 23] )")); EXPECT_THAT( ToString(*fusion->ComputeThreadIdToOutputIndexing(0, &mlir_context)), @@ -246,8 +241,7 @@ TEST_F(TransposeTest, ThreadIndexingPartialBlock) { s0 in [0, 5], s1 in [0, 0], s2 in [0, 0], - d0 mod 32 in [0, 23], - is_simplified: true + d0 mod 32 in [0, 23] )")); } @@ -322,8 +316,7 @@ TEST_F(TransposeTest, ThreadIndexingSideOutput) { s0 in [0, 0], s1 in [0, 7], - s2 in [0, 0], - is_simplified: true + s2 in [0, 0] )")); EXPECT_THAT( ToString(*fusion->ComputeThreadIdToOutputIndexing(1, &mlir_context)), @@ -343,8 +336,7 @@ TEST_F(TransposeTest, ThreadIndexingSideOutput) { s0 in [0, 0], s1 in [0, 7], - s2 in [0, 0], - is_simplified: true + s2 in [0, 0] )")); } diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir_test.cc b/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir_test.cc index 5c87db0045dac0..522d2653153292 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir_test.cc @@ -234,10 +234,10 @@ TEST_F(ElementalHloToMlirTest, ReduceWindow) { // CHECK: %[[INIT:.*]] = tensor.extract %[[ARG1]][] // CHECK: %[[RET:.*]] = scf.for %[[I:.*]] = %[[C0]] to %[[C7]] // CHECK-SAME: step %[[C1]] iter_args(%[[ACC:.*]] = %[[INIT]]) - // CHECK: %[[J0:.*]] = xla_gpu.apply_indexing #xla_gpu.indexing_map<"(d0) -> (d0 * 4), domain: d0 in [0, 2], is_simplified: true">(%[[Y]]) + // CHECK: %[[J0:.*]] = xla_gpu.apply_indexing #xla_gpu.indexing_map<"(d0) -> (d0 * 4), domain: d0 in [0, 2]">(%[[Y]]) // CHECK: %[[J1:.*]] = xla_gpu.apply_indexing // CHECK-SAME: #xla_gpu.indexing_map<"(d0, d1) -> (d0 + d1 - 3), - // CHECK-SAME: d0 in [0, 7], d1 in [0, 6], is_simplified: true">(%[[Z]], %[[I]]) + // CHECK-SAME: d0 in [0, 7], d1 in [0, 6]">(%[[Z]], %[[I]]) // CHECK: %[[VAL:.*]] = tensor.extract %[[ARG0]] // CHECK-SAME: [%[[X]], %[[J0]], %[[J1]]] // CHECK: %[[UPD:.*]] = func.call @add_sum(%[[ACC]], @@ -285,7 +285,7 @@ TEST_F(ElementalHloToMlirTest, ReduceWindowWithRescaling) { // `d1 floordiv ` in the map: // CHECK: %[[K:.*]] = xla_gpu.apply_indexing // CHECK-SAME: #xla_gpu.indexing_map<"(d0, d1) -> (d0 * 2 + d1), - // CHECK-SAME: d0 in [0, 18], d1 in [0, 3], is_simplified: true">(%[[X]], %[[I]]) + // CHECK-SAME: d0 in [0, 18], d1 in [0, 3]">(%[[X]], %[[I]]) // CHECK: tensor.extract %[[ARG0]][%[[K]], %[[Y]], %[[Z]]] )")); @@ -505,7 +505,7 @@ TEST_F(ElementalHloToMlirTest, Pad) { // CHECK-DAG: %[[C4:.*]] = arith.constant 4 // CHECK-DAG: %[[C7:.*]] = arith.constant 7 // CHECK: %[[CONSTRAINT_VAL:.*]] = xla_gpu.apply_indexing - // CHECK-SAME: <"(d0) -> ((d0 - 1) mod 2), domain: d0 in [1, 7], is_simplified: true">(%[[X]]) + // CHECK-SAME: <"(d0) -> ((d0 - 1) mod 2), domain: d0 in [1, 7]">(%[[X]]) // CHECK: %[[CONSTRAINT:.*]] = arith.cmpi eq, %[[CONSTRAINT_VAL]], %[[C0]] // CHECK-DAG: %[[X_L:.*]] = arith.cmpi sge, %[[X]], %[[C1]] // CHECK-DAG: %[[X_H:.*]] = arith.cmpi sle, %[[X]], %[[C7]] @@ -517,9 +517,9 @@ TEST_F(ElementalHloToMlirTest, Pad) { // CHECK: %[[FROM_INPUT:.*]] = arith.andi %[[X_AND_CONSTRAINT]], %[[Y_BOUNDS]] // CHECK: %[[RET:.*]] = scf.if %[[FROM_INPUT]] // CHECK: %[[IN0:.*]] = xla_gpu.apply_indexing - // CHECK-SAME: <"(d0) -> ((d0 - 1) floordiv 2), domain: d0 in [1, 7], is_simplified: true">(%[[X]]) + // CHECK-SAME: <"(d0) -> ((d0 - 1) floordiv 2), domain: d0 in [1, 7]">(%[[X]]) // CHECK: %[[IN1:.*]] = xla_gpu.apply_indexing - // CHECK-SAME: <"(d0) -> (d0 - 4), domain: d0 in [4, 7], is_simplified: true">(%[[Y]]) + // CHECK-SAME: <"(d0) -> (d0 - 4), domain: d0 in [4, 7]">(%[[Y]]) // CHECK: %[[VAL:.*]] = tensor.extract %[[ARG0]][%[[IN0]], %[[IN1]]] // CHECK: scf.yield %[[VAL]] // CHECK: } else { @@ -547,7 +547,7 @@ TEST_F(ElementalHloToMlirTest, PadUnsigned) { // CHECK-DAG: %[[C4:.*]] = arith.constant 4 // CHECK-DAG: %[[C7:.*]] = arith.constant 7 // CHECK: %[[CONSTRAINT_VAL:.*]] = xla_gpu.apply_indexing - // CHECK-SAME: <"(d0) -> ((d0 - 1) mod 2), domain: d0 in [1, 7], is_simplified: true">(%[[X]]) + // CHECK-SAME: <"(d0) -> ((d0 - 1) mod 2), domain: d0 in [1, 7]">(%[[X]]) // CHECK: %[[CONSTRAINT:.*]] = arith.cmpi eq, %[[CONSTRAINT_VAL]], %[[C0]] // CHECK-DAG: %[[X_L:.*]] = arith.cmpi sge, %[[X]], %[[C1]] // CHECK-DAG: %[[X_H:.*]] = arith.cmpi sle, %[[X]], %[[C7]] @@ -559,9 +559,9 @@ TEST_F(ElementalHloToMlirTest, PadUnsigned) { // CHECK: %[[FROM_INPUT:.*]] = arith.andi %[[X_AND_CONSTRAINT]], %[[Y_BOUNDS]] // CHECK: %[[RET:.*]] = scf.if %[[FROM_INPUT]] // CHECK: %[[IN0:.*]] = xla_gpu.apply_indexing - // CHECK-SAME: <"(d0) -> ((d0 - 1) floordiv 2), domain: d0 in [1, 7], is_simplified: true">(%[[X]]) + // CHECK-SAME: <"(d0) -> ((d0 - 1) floordiv 2), domain: d0 in [1, 7]">(%[[X]]) // CHECK: %[[IN1:.*]] = xla_gpu.apply_indexing - // CHECK-SAME: <"(d0) -> (d0 - 4), domain: d0 in [4, 7], is_simplified: true">(%[[Y]]) + // CHECK-SAME: <"(d0) -> (d0 - 4), domain: d0 in [4, 7]">(%[[Y]]) // CHECK: %[[VAL:.*]] = tensor.extract %[[ARG0]][%[[IN0]], %[[IN1]]] // CHECK: scf.yield %[[VAL]] // CHECK: } else { @@ -879,10 +879,10 @@ TEST_F(ElementalHloToMlirTest, ConvolutionSimple) { // CHECK: %[[R3:.+]] = scf.if {{.+}} -> (f32) { // CHECK: %[[XX0:.+]] = xla_gpu.apply_indexing // CHECK-SAME: #xla_gpu.indexing_map<"(d0, d1) -> (d0 + d1), - // CHECK-SAME: d0 in [0, 5], d1 in [0, 2], is_simplified: true">(%[[W]], %[[X]]) + // CHECK-SAME: d0 in [0, 5], d1 in [0, 2]">(%[[W]], %[[X]]) // CHECK: %[[XX1:.+]] = xla_gpu.apply_indexing // CHECK-SAME: #xla_gpu.indexing_map<"(d0, d1) -> (d0 + d1), - // CHECK-SAME: d0 in [0, 7], d1 in [0, 4], is_simplified: true">(%[[H]], %[[Y]]) + // CHECK-SAME: d0 in [0, 7], d1 in [0, 4]">(%[[H]], %[[Y]]) // CHECK-DAG: %[[VL:.+]] = tensor.extract %[[LHS]][%[[B]], %[[XX0]], %[[XX1]], %[[I]]] : tensor<2x8x12x4xf32> // CHECK-DAG: %[[VR:.+]] = tensor.extract %[[RHS]][%[[I]], %[[X]], %[[Y]], %[[O]]] : tensor<4x3x5x16xf32> // CHECK: %[[MUL:.+]] = arith.mulf %[[VL]], %[[VR]] : f32 @@ -925,10 +925,10 @@ TEST_F(ElementalHloToMlirTest, ConvolutionWithWindowStrides) { // CHECK: %[[R3:.+]] = scf.if {{.+}} -> (f32) { // CHECK: %[[XX0:.+]] = xla_gpu.apply_indexing // CHECK-SAME: #xla_gpu.indexing_map<"(d0, d1) -> (d0 * 2 + d1), - // CHECK-SAME: d0 in [0, 2], d1 in [0, 2], is_simplified: true">(%[[W]], %[[X]]) + // CHECK-SAME: d0 in [0, 2], d1 in [0, 2]">(%[[W]], %[[X]]) // CHECK: %[[XX1:.+]] = xla_gpu.apply_indexing // CHECK-SAME: #xla_gpu.indexing_map<"(d0, d1) -> (d0 * 2 + d1), - // CHECK-SAME: d0 in [0, 3], d1 in [0, 4], is_simplified: true">(%[[H]], %[[Y]]) + // CHECK-SAME: d0 in [0, 3], d1 in [0, 4]">(%[[H]], %[[Y]]) // CHECK-DAG: %[[VL:.+]] = tensor.extract %[[LHS]][%[[B]], %[[XX0]], %[[XX1]], %[[I]]] : tensor<2x8x12x4xf32> // CHECK-DAG: %[[VR:.+]] = tensor.extract %[[RHS]][%[[I]], %[[X]], %[[Y]], %[[O]]] : tensor<4x3x5x16xf32> // CHECK: %[[MUL:.+]] = arith.mulf %[[VL]], %[[VR]] : f32 @@ -971,21 +971,21 @@ TEST_F(ElementalHloToMlirTest, ConvolutionWithPadding) { // CHECK: %[[R0:.+]] = scf.for %[[X:.+]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[A0:.+]] = %[[INIT]]) -> (f32) { // CHECK-NEXT: %[[R1:.+]] = scf.for %[[Y:.+]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[A1:.+]] = %[[A0]]) -> (f32) { // CHECK-NEXT: %[[R2:.+]] = scf.for %[[I:.+]] = %[[C0]] to %[[C4]] step %[[C1]] iter_args(%[[ACC:.+]] = %[[A1]]) -> (f32) { - // CHECK-DAG: %[[TESTX:.+]] = xla_gpu.apply_indexing #xla_gpu.indexing_map<"(d0, d1) -> (d0 + d1), domain: d0 in [0, 7], d1 in [0, 2], is_simplified: true">(%[[W]], %[[X]]) + // CHECK-DAG: %[[TESTX:.+]] = xla_gpu.apply_indexing #xla_gpu.indexing_map<"(d0, d1) -> (d0 + d1), domain: d0 in [0, 7], d1 in [0, 2]">(%[[W]], %[[X]]) // CHECK-DAG: %[[TXGE:.+]] = arith.cmpi sge, %[[TESTX]], %[[C1]] : index // CHECK-DAG: %[[TXLE:.+]] = arith.cmpi sle, %[[TESTX]], %[[C8]] : index // CHECK-DAG: %[[TX:.+]] = arith.andi %[[TXGE]], %[[TXLE]] : i1 - // CHECK-DAG: %[[TESTY:.+]] = xla_gpu.apply_indexing #xla_gpu.indexing_map<"(d0, d1) -> (d0 + d1), domain: d0 in [0, 11], d1 in [0, 4], is_simplified: true">(%[[H]], %[[Y]]) + // CHECK-DAG: %[[TESTY:.+]] = xla_gpu.apply_indexing #xla_gpu.indexing_map<"(d0, d1) -> (d0 + d1), domain: d0 in [0, 11], d1 in [0, 4]">(%[[H]], %[[Y]]) // CHECK-DAG: %[[TYGE:.+]] = arith.cmpi sge, %[[TESTY]], %[[C2]] : index // CHECK-DAG: %[[TYLE:.+]] = arith.cmpi sle, %[[TESTY]], %[[C13]] : index // CHECK-DAG: %[[TY:.+]] = arith.andi %[[TYGE]], %[[TYLE]] : i1 // CHECK: %[[R3:.+]] = scf.if {{.+}} -> (f32) { // CHECK: %[[XX0:.+]] = xla_gpu.apply_indexing // CHECK-SAME: #xla_gpu.indexing_map<"(d0, d1) -> (d0 + d1 - 1), - // CHECK-SAME: d0 in [0, 7], d1 in [0, 2], is_simplified: true">(%[[W]], %[[X]]) + // CHECK-SAME: d0 in [0, 7], d1 in [0, 2]">(%[[W]], %[[X]]) // CHECK: %[[XX1:.+]] = xla_gpu.apply_indexing // CHECK-SAME: #xla_gpu.indexing_map<"(d0, d1) -> (d0 + d1 - 2), - // CHECK-SAME: d0 in [0, 11], d1 in [0, 4], is_simplified: true">(%[[H]], %[[Y]]) + // CHECK-SAME: d0 in [0, 11], d1 in [0, 4]">(%[[H]], %[[Y]]) // CHECK-DAG: %[[VL:.+]] = tensor.extract %[[LHS]][%[[B]], %[[XX0]], %[[XX1]], %[[I]]] : tensor<2x8x12x4xf32> // CHECK-DAG: %[[VR:.+]] = tensor.extract %[[RHS]][%[[I]], %[[X]], %[[Y]], %[[O]]] : tensor<4x3x5x16xf32> // CHECK: %[[MUL:.+]] = arith.mulf %[[VL]], %[[VR]] : f32 @@ -1025,17 +1025,17 @@ TEST_F(ElementalHloToMlirTest, ConvolutionWithLhsDilation) { // CHECK: %[[R0:.+]] = scf.for %[[X:.+]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[A0:.+]] = %[[INIT]]) -> (f32) { // CHECK-NEXT: %[[R1:.+]] = scf.for %[[Y:.+]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[A1:.+]] = %[[A0]]) -> (f32) { // CHECK-NEXT: %[[R2:.+]] = scf.for %[[I:.+]] = %[[C0]] to %[[C4]] step %[[C1]] iter_args(%[[ACC:.+]] = %[[A1]]) -> (f32) { - // CHECK-DAG: %[[TESTX:.+]] = xla_gpu.apply_indexing #xla_gpu.indexing_map<"(d0, d1) -> ((d0 + d1) mod 2), domain: d0 in [0, 12], d1 in [0, 2], is_simplified: true">(%[[W]], %[[X]]) + // CHECK-DAG: %[[TESTX:.+]] = xla_gpu.apply_indexing #xla_gpu.indexing_map<"(d0, d1) -> ((d0 + d1) mod 2), domain: d0 in [0, 12], d1 in [0, 2]">(%[[W]], %[[X]]) // CHECK-DAG: %[[TX:.+]] = arith.cmpi eq, %[[TESTX]], %[[C0]] : index - // CHECK-DAG: %[[TESTY:.+]] = xla_gpu.apply_indexing #xla_gpu.indexing_map<"(d0, d1) -> ((d0 + d1) mod 2), domain: d0 in [0, 18], d1 in [0, 4], is_simplified: true">(%[[H]], %[[Y]]) + // CHECK-DAG: %[[TESTY:.+]] = xla_gpu.apply_indexing #xla_gpu.indexing_map<"(d0, d1) -> ((d0 + d1) mod 2), domain: d0 in [0, 18], d1 in [0, 4]">(%[[H]], %[[Y]]) // CHECK-DAG: %[[TY:.+]] = arith.cmpi eq, %[[TESTY]], %[[C0]] : index // CHECK: %[[R3:.+]] = scf.if {{.+}} -> (f32) { // CHECK: %[[XX0:.+]] = xla_gpu.apply_indexing // CHECK-SAME: #xla_gpu.indexing_map<"(d0, d1) -> ((d0 + d1) floordiv 2), - // CHECK-SAME: d0 in [0, 12], d1 in [0, 2], is_simplified: true">(%[[W]], %[[X]]) + // CHECK-SAME: d0 in [0, 12], d1 in [0, 2]">(%[[W]], %[[X]]) // CHECK: %[[XX1:.+]] = xla_gpu.apply_indexing // CHECK-SAME: #xla_gpu.indexing_map<"(d0, d1) -> ((d0 + d1) floordiv 2), - // CHECK-SAME: d0 in [0, 18], d1 in [0, 4], is_simplified: true">(%[[H]], %[[Y]]) + // CHECK-SAME: d0 in [0, 18], d1 in [0, 4]">(%[[H]], %[[Y]]) // CHECK-DAG: %[[VL:.+]] = tensor.extract %[[LHS]][%[[B]], %[[XX0]], %[[XX1]], %[[I]]] : tensor<2x8x12x4xf32> // CHECK-DAG: %[[VR:.+]] = tensor.extract %[[RHS]][%[[I]], %[[X]], %[[Y]], %[[O]]] : tensor<4x3x5x16xf32> // CHECK: %[[MUL:.+]] = arith.mulf %[[VL]], %[[VR]] : f32 @@ -1078,10 +1078,10 @@ TEST_F(ElementalHloToMlirTest, ConvolutionWithRhsDilation) { // CHECK: %[[R3:.+]] = scf.if {{.+}} -> (f32) { // CHECK: %[[XX0:.+]] = xla_gpu.apply_indexing // CHECK-SAME: #xla_gpu.indexing_map<"(d0, d1) -> (d1 * 2 + d0), - // CHECK-SAME: d0 in [0, 3], d1 in [0, 2], is_simplified: true">(%[[W]], %[[X]]) + // CHECK-SAME: d0 in [0, 3], d1 in [0, 2]">(%[[W]], %[[X]]) // CHECK: %[[XX1:.+]] = xla_gpu.apply_indexing // CHECK-SAME: #xla_gpu.indexing_map<"(d0, d1) -> (d1 * 2 + d0), - // CHECK-SAME: d0 in [0, 3], d1 in [0, 4], is_simplified: true">(%[[H]], %[[Y]]) + // CHECK-SAME: d0 in [0, 3], d1 in [0, 4]">(%[[H]], %[[Y]]) // CHECK-DAG: %[[VL:.+]] = tensor.extract %[[LHS]][%[[B]], %[[XX0]], %[[XX1]], %[[I]]] : tensor<2x8x12x4xf32> // CHECK-DAG: %[[VR:.+]] = tensor.extract %[[RHS]][%[[I]], %[[X]], %[[Y]], %[[O]]] : tensor<4x3x5x16xf32> // CHECK: %[[MUL:.+]] = arith.mulf %[[VL]], %[[VR]] : f32 @@ -1124,13 +1124,13 @@ TEST_F(ElementalHloToMlirTest, ConvolutionWithFeatureGroupCount) { // CHECK: %[[R3:.+]] = scf.if {{.+}} -> (f32) { // CHECK: %[[XX0:.+]] = xla_gpu.apply_indexing // CHECK-SAME: #xla_gpu.indexing_map<"(d0, d1) -> (d0 + d1), - // CHECK-SAME: d0 in [0, 5], d1 in [0, 2], is_simplified: true">(%[[W]], %[[X]]) + // CHECK-SAME: d0 in [0, 5], d1 in [0, 2]">(%[[W]], %[[X]]) // CHECK: %[[XX1:.+]] = xla_gpu.apply_indexing // CHECK-SAME: #xla_gpu.indexing_map<"(d0, d1) -> (d0 + d1), - // CHECK-SAME: d0 in [0, 7], d1 in [0, 4], is_simplified: true">(%[[H]], %[[Y]]) + // CHECK-SAME: d0 in [0, 7], d1 in [0, 4]">(%[[H]], %[[Y]]) // CHECK: %[[XX2:.+]] = xla_gpu.apply_indexing // CHECK-SAME: #xla_gpu.indexing_map<"(d0, d1) -> ((d0 floordiv 8) * 2 + d1), - // CHECK-SAME: d0 in [0, 15], d1 in [0, 1], is_simplified: true">(%[[O]], %[[I]]) + // CHECK-SAME: d0 in [0, 15], d1 in [0, 1]">(%[[O]], %[[I]]) // CHECK-DAG: %[[VL:.+]] = tensor.extract %[[LHS]][%[[B]], %[[XX0]], %[[XX1]], %[[XX2]]] : tensor<2x8x12x4xf32> // CHECK-DAG: %[[VR:.+]] = tensor.extract %[[RHS]][%[[I]], %[[X]], %[[Y]], %[[O]]] : tensor<2x3x5x16xf32> // CHECK: %[[MUL:.+]] = arith.mulf %[[VL]], %[[VR]] : f32 @@ -1175,10 +1175,10 @@ TEST_F(ElementalHloToMlirTest, ConvolutionWithBatchGroupCount) { // CHECK: %[[R4:.+]] = scf.if {{.+}} -> (f32) { // CHECK: %[[XX0:.+]] = xla_gpu.apply_indexing // CHECK-SAME: #xla_gpu.indexing_map<"(d0, d1) -> (d0 + d1), - // CHECK-SAME: d0 in [0, 5], d1 in [0, 2], is_simplified: true">(%[[W]], %[[X]]) + // CHECK-SAME: d0 in [0, 5], d1 in [0, 2]">(%[[W]], %[[X]]) // CHECK: %[[XX1:.+]] = xla_gpu.apply_indexing // CHECK-SAME: #xla_gpu.indexing_map<"(d0, d1) -> (d0 + d1), - // CHECK-SAME: d0 in [0, 7], d1 in [0, 4], is_simplified: true">(%[[H]], %[[Y]]) + // CHECK-SAME: d0 in [0, 7], d1 in [0, 4]">(%[[H]], %[[Y]]) // CHECK-DAG: %[[VL:.+]] = tensor.extract %[[LHS]][%[[G]], %[[XX0]], %[[XX1]], %[[I]]] : tensor<2x8x12x4xf32> // CHECK-DAG: %[[VR:.+]] = tensor.extract %[[RHS]][%[[I]], %[[X]], %[[Y]], %[[O]]] : tensor<4x3x5x16xf32> // CHECK: %[[MUL:.+]] = arith.mulf %[[VL]], %[[VR]] : f32 @@ -1645,7 +1645,7 @@ TEST_F(ElementalHloToMlirTest, MixedIndexingTuple) { // CHECK: %[[A:.*]] = tensor.extract %[[P0]][%[[X]], %[[Y]]] // CHECK: %[[IDX:.*]] = xla_gpu.apply_indexing // CHECK-SAME: #xla_gpu.indexing_map<"(d0, d1) -> (d0 * 10 + d1), - // CHECK-SAME: d0 in [0, 9], d1 in [0, 9], is_simplified: true">(%[[X]], %[[Y]]) + // CHECK-SAME: d0 in [0, 9], d1 in [0, 9]">(%[[X]], %[[Y]]) // CHECK: %[[B:.*]] = tensor.extract %[[P1]][%[[IDX]]] // CHECK: return %[[A]], %[[B]] )")); @@ -1669,7 +1669,7 @@ TEST_F(ElementalHloToMlirTest, NestedTuple) { // CHECK: %[[P0_V:.*]] = xla_gpu.pure_call @main_p0 // CHECK: %[[IDX:.*]] = // CHECK-SAME: #xla_gpu.indexing_map<"(d0, d1) -> (d0 * 10 + d1), - // CHECK-SAME: d0 in [0, 9], d1 in [0, 9], is_simplified: true">(%[[X]], %[[Y]]) + // CHECK-SAME: d0 in [0, 9], d1 in [0, 9]">(%[[X]], %[[Y]]) // CHECK: %[[P1_V:.*]] = xla_gpu.pure_call @main_p1 // CHECK-SAME: (%[[P0]], %[[P1]], %[[IDX]]) // CHECK: return %[[P0_V]], %[[P1_V]], %[[P1_V]], %[[P1_V]], %[[P0_V]] diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/tests/flatten_tensors.mlir b/third_party/xla/xla/service/gpu/fusions/transforms/tests/flatten_tensors.mlir index e88324f698d489..d35dc71ddad023 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/tests/flatten_tensors.mlir +++ b/third_party/xla/xla/service/gpu/fusions/transforms/tests/flatten_tensors.mlir @@ -8,7 +8,7 @@ func.func @tensor_extract( : tensor<2x3xf32, dense<[0, 1]> : tensor<2xi64>> func.return %v : f32 } -// CHECK: #[[$MAP:.+]] = #xla_gpu.indexing_map<"(d0, d1) -> (d1 * 2 + d0), domain: d0 in [0, 1], d1 in [0, 2], is_simplified: true"> +// CHECK: #[[$MAP:.+]] = #xla_gpu.indexing_map<"(d0, d1) -> (d1 * 2 + d0), domain: d0 in [0, 1], d1 in [0, 2]"> // CHECK-LABEL: func.func @tensor_extract( // CHECK-SAME: %[[SRC:.*]]: tensor<6xf32>, @@ -67,7 +67,7 @@ func.func @atomic_rmw(%in: tensor<2x4xf32>, %i: index, %j: index) } return %ret : tensor<2x4xf32> } -// CHECK: #[[$MAP:.+]] = #xla_gpu.indexing_map<"(d0, d1) -> (d0 * 4 + d1), domain: d0 in [0, 1], d1 in [0, 3], is_simplified: true"> +// CHECK: #[[$MAP:.+]] = #xla_gpu.indexing_map<"(d0, d1) -> (d0 * 4 + d1), domain: d0 in [0, 1], d1 in [0, 3]"> // CHECK-LABEL: func.func @atomic_rmw( // CHECK-SAME: %[[TENSOR:.*]]: tensor<8xf32>, %[[I:.*]]: index, // CHECK-SAME: %[[J:.*]]: index) -> tensor<8xf32> { @@ -114,9 +114,9 @@ func.func @for_loop(%t0: tensor<32x1024xf32>, %t1: tensor<64x8x4xf32>) // ----- -#map = #xla_gpu.indexing_map<"(d0, d1) -> ((d1 * 128 + d0) floordiv 36), domain: d0 in [0, 127], d1 in [0, 393749], is_simplified: true"> -#map1 = #xla_gpu.indexing_map<"(d0, d1) -> (((d1 * 128 + d0) floordiv 9) mod 4), domain: d0 in [0, 127], d1 in [0, 393749], is_simplified: true"> -#map2 = #xla_gpu.indexing_map<"(d0, d1) -> ((d1 * 128 + d0) mod 9), domain: d0 in [0, 127], d1 in [0, 393749], is_simplified: true"> +#map = #xla_gpu.indexing_map<"(d0, d1) -> ((d1 * 128 + d0) floordiv 36), domain: d0 in [0, 127], d1 in [0, 393749]"> +#map1 = #xla_gpu.indexing_map<"(d0, d1) -> (((d1 * 128 + d0) floordiv 9) mod 4), domain: d0 in [0, 127], d1 in [0, 393749]"> +#map2 = #xla_gpu.indexing_map<"(d0, d1) -> ((d1 * 128 + d0) mod 9), domain: d0 in [0, 127], d1 in [0, 393749]"> func.func @if_op(%arg0: tensor<4000x4x9xf32>, %arg1: tensor<1400x1xi32>, %arg2: tensor<1400x1x4x9xf32>, %arg3: tensor<4000x4x9xf32>) -> tensor<4000x4x9xf32> { diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/tests/fuse_loops.mlir b/third_party/xla/xla/service/gpu/fusions/transforms/tests/fuse_loops.mlir index 594c8e1deec7d2..1287b8fc3e91a5 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/tests/fuse_loops.mlir +++ b/third_party/xla/xla/service/gpu/fusions/transforms/tests/fuse_loops.mlir @@ -8,8 +8,7 @@ " domain:" " d0 in [0, 127], d1 in [0, 599]," " s0 in [0, 7], s1 in [0, 0]," -" (d1 mod 6) * 32 + d0 mod 32 in [0, 169]," -" is_simplified: true"> +" (d1 mod 6) * 32 + d0 mod 32 in [0, 169]"> #indexing_map1 = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] ->" " (0," " d0 mod 32," @@ -17,8 +16,7 @@ " domain:" " d0 in [0, 127], d1 in [0, 599]," " s0 in [0, 7], s1 in [0, 0]," -" (d1 mod 6) * 32 + d0 mod 32 in [0, 169]," -" is_simplified: true"> +" (d1 mod 6) * 32 + d0 mod 32 in [0, 169]"> func.func @fuse_loops(%arg0: tensor<20x160x170xf32>) -> tensor<1x32x33xf32> { %cst = arith.constant dense<0.000000e+00> : vector<8x1xf32> %c0 = arith.constant 0 : index @@ -67,8 +65,7 @@ func.func @fuse_loops(%arg0: tensor<20x160x170xf32>) -> tensor<1x32x33xf32> { " domain:" " d0 in [0, 127], d1 in [0, 599]," " s0 in [0, 7], s1 in [0, 0]," -" (d1 mod 6) * 32 + d0 mod 32 in [0, 169]," -" is_simplified: true"> +" (d1 mod 6) * 32 + d0 mod 32 in [0, 169]"> #indexing_map1 = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] ->" " (0," " d0 mod 32," @@ -76,8 +73,7 @@ func.func @fuse_loops(%arg0: tensor<20x160x170xf32>) -> tensor<1x32x33xf32> { " domain:" " d0 in [0, 127], d1 in [0, 599]," " s0 in [0, 7], s1 in [0, 0]," -" (d1 mod 6) * 32 + d0 mod 32 in [0, 169]," -" is_simplified: true"> +" (d1 mod 6) * 32 + d0 mod 32 in [0, 169]"> func.func @do_not_fuse_index_mismatch(%arg0: tensor<20x160x170xf32>) -> tensor<1x32x33xf32> { %cst = arith.constant dense<0.000000e+00> : vector<8x1xf32> %c0 = arith.constant 0 : index @@ -115,8 +111,7 @@ func.func @do_not_fuse_index_mismatch(%arg0: tensor<20x160x170xf32>) -> tensor<1 " domain:" " d0 in [0, 127], d1 in [0, 599]," " s0 in [0, 7], s1 in [0, 0]," -" (d1 mod 6) * 32 + d0 mod 32 in [0, 169]," -" is_simplified: true"> +" (d1 mod 6) * 32 + d0 mod 32 in [0, 169]"> #indexing_map1 = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] ->" " (0," " d0 mod 32," @@ -124,8 +119,7 @@ func.func @do_not_fuse_index_mismatch(%arg0: tensor<20x160x170xf32>) -> tensor<1 " domain:" " d0 in [0, 127], d1 in [0, 599]," " s0 in [0, 7], s1 in [0, 0]," -" (d1 mod 6) * 32 + d0 mod 32 in [0, 169]," -" is_simplified: true"> +" (d1 mod 6) * 32 + d0 mod 32 in [0, 169]"> func.func @do_not_fuse_multiple_uses(%arg0: tensor<20x160x170xf32>) -> tensor<1x32x33xf32> { %cst = arith.constant dense<0.000000e+00> : vector<8x1xf32> %c0 = arith.constant 0 : index @@ -165,8 +159,7 @@ func.func @do_not_fuse_multiple_uses(%arg0: tensor<20x160x170xf32>) -> tensor<1x " domain:" " d0 in [0, 127], d1 in [0, 599]," " s0 in [0, 7], s1 in [0, 0]," -" (d1 mod 6) * 32 + d0 mod 32 in [0, 169]," -" is_simplified: true"> +" (d1 mod 6) * 32 + d0 mod 32 in [0, 169]"> #indexing_map1 = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] ->" " (0," " d0 mod 32," @@ -174,8 +167,7 @@ func.func @do_not_fuse_multiple_uses(%arg0: tensor<20x160x170xf32>) -> tensor<1x " domain:" " d0 in [0, 127], d1 in [0, 599]," " s0 in [0, 5], s1 in [0, 0]," -" (d1 mod 6) * 32 + d0 mod 32 in [0, 169]," -" is_simplified: true"> +" (d1 mod 6) * 32 + d0 mod 32 in [0, 169]"> func.func @do_not_fuse_map_domain_mismatch(%arg0: tensor<20x160x170xf32>) -> tensor<1x32x33xf32> { %cst = arith.constant dense<0.000000e+00> : vector<8x1xf32> %c0 = arith.constant 0 : index @@ -214,8 +206,7 @@ func.func @do_not_fuse_map_domain_mismatch(%arg0: tensor<20x160x170xf32>) -> ten " domain:" " d0 in [0, 127], d1 in [0, 599]," " s0 in [0, 7], s1 in [0, 0]," -" (d1 mod 6) * 32 + d0 mod 32 in [0, 169]," -" is_simplified: true"> +" (d1 mod 6) * 32 + d0 mod 32 in [0, 169]"> #indexing_map1 = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] ->" " (0," " d0 mod 32," @@ -223,8 +214,7 @@ func.func @do_not_fuse_map_domain_mismatch(%arg0: tensor<20x160x170xf32>) -> ten " domain:" " d0 in [0, 127], d1 in [0, 599]," " s0 in [0, 7], s1 in [0, 0]," -" (d1 mod 5) * 32 + d0 mod 32 in [0, 169]," -" is_simplified: true"> +" (d1 mod 5) * 32 + d0 mod 32 in [0, 169]"> func.func @do_not_fuse_map_constraint_mismatch(%arg0: tensor<20x160x170xf32>) -> tensor<1x32x33xf32> { %cst = arith.constant dense<0.000000e+00> : vector<8x1xf32> %c0 = arith.constant 0 : index @@ -263,8 +253,7 @@ func.func @do_not_fuse_map_constraint_mismatch(%arg0: tensor<20x160x170xf32>) -> " domain:" " d0 in [0, 127], d1 in [0, 599]," " s0 in [0, 7], s1 in [0, 0], s2 in [0, 1]," -" (d1 mod 6) * 32 + d0 mod 32 in [0, 169]," -" is_simplified: true"> +" (d1 mod 6) * 32 + d0 mod 32 in [0, 169]"> #indexing_map1 = #xla_gpu.indexing_map<"(d0, d1)[s0, s1, s2] ->" " (0," " d0 mod 32," @@ -272,8 +261,7 @@ func.func @do_not_fuse_map_constraint_mismatch(%arg0: tensor<20x160x170xf32>) -> " domain:" " d0 in [0, 127], d1 in [0, 599]," " s0 in [0, 7], s1 in [0, 0], s2 in [0, 1]," -" (d1 mod 6) * 32 + d0 mod 32 in [0, 169]," -" is_simplified: true"> +" (d1 mod 6) * 32 + d0 mod 32 in [0, 169]"> func.func @do_not_fuse_unused_loop_iv(%arg0: tensor<20x160x170xf32>) -> tensor<1x32x33xf32> { %cst = arith.constant dense<0.000000e+00> : vector<8x1xf32> %c0 = arith.constant 0 : index diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/tests/lower_xla_gpu_loops_to_scf.mlir b/third_party/xla/xla/service/gpu/fusions/transforms/tests/lower_xla_gpu_loops_to_scf.mlir index 427e764d12b914..f981cef83029d8 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/tests/lower_xla_gpu_loops_to_scf.mlir +++ b/third_party/xla/xla/service/gpu/fusions/transforms/tests/lower_xla_gpu_loops_to_scf.mlir @@ -2,8 +2,7 @@ // RUN: --split-input-file | FileCheck %s #map = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (s0 + 1, s1 - 1)," - "domain: d0 in [0, 3], s0 in [0, 1024], s1 in [0, 32], s0 + s1 in [0, 90]," - "is_simplified: false"> + "domain: d0 in [0, 3], s0 in [0, 1024], s1 in [0, 32], s0 + s1 in [0, 90]"> func.func @loop_op(%input: tensor<1024x32xf32>, %init: f32, %dim: index) -> (f32) { %sum = xla_gpu.loop (%dim)[%i, %j] -> (%ra, %rb) @@ -61,8 +60,7 @@ func.func @loop_op(%input: tensor<1024x32xf32>, %init: f32, %dim: index) -> (f32 // ----- #map = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (s0 + 1, s1 - 1)," - "domain: d0 in [0, 3], s0 in [0, 1024], s1 in [0, 32], s0 + s1 in [0, 90]," - "is_simplified: false"> + "domain: d0 in [0, 3], s0 in [0, 1024], s1 in [0, 32], s0 + s1 in [0, 90]"> func.func @loop_yields_value_from_above(%input: tensor<1024x32xf32>, %init: f32, %dim: index) -> (f32) { diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/tests/lower_xla_gpu_to_scf.mlir b/third_party/xla/xla/service/gpu/fusions/transforms/tests/lower_xla_gpu_to_scf.mlir index 347ed9a943ef82..f53ccc1e8ae54f 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/tests/lower_xla_gpu_to_scf.mlir +++ b/third_party/xla/xla/service/gpu/fusions/transforms/tests/lower_xla_gpu_to_scf.mlir @@ -124,8 +124,8 @@ func.func @predicated_extract( func.func private @exp(%p0: tensor<32x64xf32>, %i: index, %j: index) -> f32 -#map = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d1*32+d0*2+s0, s1), domain: d0 in [0, 32], d1 in [0, 8], s0 in [0, 1], s1 in [0, 1], is_simplified: false"> -#map1 = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d0*2+s0, s1), domain: d0 in [0, 32], d1 in [0, 2], s0 in [0, 1], s1 in [0, 1], is_simplified: false"> +#map = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d1*32+d0*2+s0, s1), domain: d0 in [0, 32], d1 in [0, 8], s0 in [0, 1], s1 in [0, 1]"> +#map1 = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d0*2+s0, s1), domain: d0 in [0, 32], d1 in [0, 2], s0 in [0, 1], s1 in [0, 1]"> func.func @materialize(%input: tensor<32x64xf32>, %i: index, %j: index) -> !xla_gpu.indexed_vector<32x2x2xf32, #map1> { @@ -149,8 +149,8 @@ func.func @materialize(%input: tensor<32x64xf32>, %i: index, %j: index) // ----- -#map = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d1*32+d0*2+s0, s1), domain: d0 in [0, 32], d1 in [0, 8], s0 in [0, 1], s1 in [0, 1], is_simplified: false"> -#map1 = #xla_gpu.indexing_map<"(d0, d1) -> (d0 mod 16, d1), domain: d0 in [0, 32], d1 in [0, 2], is_simplified: false"> +#map = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d1*32+d0*2+s0, s1), domain: d0 in [0, 32], d1 in [0, 8], s0 in [0, 1], s1 in [0, 1]"> +#map1 = #xla_gpu.indexing_map<"(d0, d1) -> (d0 mod 16, d1), domain: d0 in [0, 32], d1 in [0, 2]"> func.func @insert(%input: !xla_gpu.indexed_vector<32x64xf32, #map>, %i: index, %j: index, %output: tensor<32x64xf32>) -> tensor<32x64xf32> { @@ -181,9 +181,9 @@ func.func @insert(%input: !xla_gpu.indexed_vector<32x64xf32, #map>, func.func private @exp(%p0: tensor<32x64xf32>, %i: index, %j: index) -> f32 -#map = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d1*32+d0*2+s0, s1), domain: d0 in [0, 32], d1 in [0, 8], s0 in [0, 1], s1 in [0, 1], is_simplified: false"> -#map1 = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d0*2+s0, s1), domain: d0 in [0, 32], d1 in [0, 2], s0 in [0, 1], s1 in [0, 1], is_simplified: false"> -#map2 = #xla_gpu.indexing_map<"(d0, d1) -> (d0, d1), domain: d0 in [0, 32], d1 in [0, 2], is_simplified: false"> +#map = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d1*32+d0*2+s0, s1), domain: d0 in [0, 32], d1 in [0, 8], s0 in [0, 1], s1 in [0, 1]"> +#map1 = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d0*2+s0, s1), domain: d0 in [0, 32], d1 in [0, 2], s0 in [0, 1], s1 in [0, 1]"> +#map2 = #xla_gpu.indexing_map<"(d0, d1) -> (d0, d1), domain: d0 in [0, 32], d1 in [0, 2]"> func.func @materialize_and_insert(%input: tensor<32x64xf32>, %i: index, %j: index, %output: tensor<32x64xf32>) -> tensor<32x64xf32> { @@ -199,8 +199,8 @@ func.func @materialize_and_insert(%input: tensor<32x64xf32>, %i: index, func.func private @exp(%p0: tensor<32x64xcomplex>, %i: index, %j: index) -> complex -#map = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d1*32+d0*2+s0, s1), domain: d0 in [0, 32], d1 in [0, 8], s0 in [0, 2], s1 in [0, 3], is_simplified: false"> -#map1 = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d0*2+s0, s1), domain: d0 in [0, 32], d1 in [0, 2], s0 in [0, 2], s1 in [0, 3], is_simplified: false"> +#map = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d1*32+d0*2+s0, s1), domain: d0 in [0, 32], d1 in [0, 8], s0 in [0, 2], s1 in [0, 3]"> +#map1 = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d0*2+s0, s1), domain: d0 in [0, 32], d1 in [0, 2], s0 in [0, 2], s1 in [0, 3]"> func.func @materialize_complex( %input: tensor<32x64xcomplex>, %output: tensor<32x64xcomplex>, @@ -227,8 +227,8 @@ func.func @materialize_complex( // ----- -#map1 = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d0*2+s0, s1), domain: d0 in [0, 32], d1 in [0, 2], s0 in [0, 2], s1 in [0, 3], is_simplified: false"> -#map2 = #xla_gpu.indexing_map<"(d0, d1) -> (d0, d1), domain: d0 in [0, 32], d1 in [0, 2], is_simplified: false"> +#map1 = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d0*2+s0, s1), domain: d0 in [0, 32], d1 in [0, 2], s0 in [0, 2], s1 in [0, 3]"> +#map2 = #xla_gpu.indexing_map<"(d0, d1) -> (d0, d1), domain: d0 in [0, 32], d1 in [0, 2]"> func.func @insert_complex( %input: !xla_gpu.indexed_vector<32x3x4xcomplex, #map1>, %output: tensor<32x64xcomplex>, diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/tests/optimize_loops.mlir b/third_party/xla/xla/service/gpu/fusions/transforms/tests/optimize_loops.mlir index 17f478b2838dde..1094b51a2a6841 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/tests/optimize_loops.mlir +++ b/third_party/xla/xla/service/gpu/fusions/transforms/tests/optimize_loops.mlir @@ -1,7 +1,8 @@ // RUN: mlir_fusions_opt %s -split-input-file -xla-gpu-optimize-loops | FileCheck %s -#map = #xla_gpu.indexing_map<"(d0) -> (d0 floordiv 8), domain: d0 in [0, 31], is_simplified: false"> #map1 = #xla_gpu.indexing_map<"(d0) -> (d0 mod 8), domain: d0 in [0, 31], is_simplified: false"> -#map2 = #xla_gpu.indexing_map<"(d0, d1)[s0] -> (d1 * 2 + d0 + s0 * 512), domain: d0 in [0, 1], d1 in [0, 255], s0 in [0, 7], is_simplified: false"> +#map = #xla_gpu.indexing_map<"(d0) -> (d0 floordiv 8), domain: d0 in [0, 31]"> +#map1 = #xla_gpu.indexing_map<"(d0) -> (d0 mod 8), domain: d0 in [0, 31]"> +#map2 = #xla_gpu.indexing_map<"(d0, d1)[s0] -> (d1 * 2 + d0 + s0 * 512), domain: d0 in [0, 1], d1 in [0, 255], s0 in [0, 7]"> module { func.func @fully_unroll(%arg0: tensor<4x8x4096xf32>, %arg1: tensor<4096xbf16>, %arg2: tensor<4x8xf32>, %arg3: tensor<4096xbf16>, @@ -150,7 +151,7 @@ module { %cst = arith.constant dense<[0.0, 0.0]> : vector<2xf32> %cst0 = arith.constant 0.0 : f32 %ret = scf.for %i = %c0 to %c17 step %c1 iter_args (%iter = %cst) -> (vector<2xf32>) { - %base = xla_gpu.apply_indexing #xla_gpu.indexing_map<"(d0) -> (d0 * 2), domain: d0 in [0, 15], is_simplified: false">(%i) + %base = xla_gpu.apply_indexing #xla_gpu.indexing_map<"(d0) -> (d0 * 2), domain: d0 in [0, 15]">(%i) %val = vector.transfer_read %arg[%base], %cst0 : tensor<34xf32>, vector<2xf32> %log = math.log %val : vector<2xf32> %add = arith.addf %log, %iter : vector<2xf32> diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/tests/peel_loops.mlir b/third_party/xla/xla/service/gpu/fusions/transforms/tests/peel_loops.mlir index f965b069a772cc..9ffd7bdc0fbfd1 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/tests/peel_loops.mlir +++ b/third_party/xla/xla/service/gpu/fusions/transforms/tests/peel_loops.mlir @@ -3,7 +3,7 @@ #map = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (s0, s1), domain:" "d0 in [0, 3], s0 in [0, 7], s1 in [0, 10], d0 + s0 in [0, 9]," - "d0 + s1 in [0, 12], is_simplified: false"> + "d0 + s1 in [0, 12]"> func.func @peel_both_loops(%input: tensor<16x32xf32>, %init: f32, %dim: index) -> (f32) { %sum = xla_gpu.loop (%dim)[%i, %j] -> (%r0, %r1) @@ -14,9 +14,9 @@ func.func @peel_both_loops(%input: tensor<16x32xf32>, } func.return %sum : f32 } -// CHECK: #[[$PEELED_MAP:.*]] = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (s0, s1), domain: d0 in [0, 3], s0 in [0, 6], s1 in [0, 9], is_simplified: true"> -// CHECK: #[[$TAIL_MAP0:.*]] = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (7, s1), domain: d0 in [0, 2], s0 in [7, 7], s1 in [0, 9], is_simplified: true"> -// CHECK: #[[$TAIL_MAP1:.*]] = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (s0, 10), domain: d0 in [0, 2], s0 in [0, 7], s1 in [10, 10], is_simplified: true"> +// CHECK: #[[$PEELED_MAP:.*]] = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (s0, s1), domain: d0 in [0, 3], s0 in [0, 6], s1 in [0, 9]"> +// CHECK: #[[$TAIL_MAP0:.*]] = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (7, s1), domain: d0 in [0, 2], s0 in [7, 7], s1 in [0, 9]"> +// CHECK: #[[$TAIL_MAP1:.*]] = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (s0, 10), domain: d0 in [0, 2], s0 in [0, 7], s1 in [10, 10]"> // CHECK-LABEL: func.func @peel_both_loops( // CHECK-SAME: %[[INPUT:.*]]: tensor<16x32xf32>, @@ -42,7 +42,7 @@ func.func @peel_both_loops(%input: tensor<16x32xf32>, // ----- #map = #xla_gpu.indexing_map<"(d0)[s0] -> (s0)," - "domain: d0 in [0, 3], s0 in [0, 7], is_simplified: false"> + "domain: d0 in [0, 3], s0 in [0, 7]"> func.func @not_constrained_symbol(%input: tensor<16xf32>, %init: f32, %dim: index) -> (f32) { %sum = xla_gpu.loop (%dim)[%i] -> (%r0) @@ -64,9 +64,7 @@ func.func @not_constrained_symbol(%input: tensor<16xf32>, %init: f32, " domain:" " d0 in [0, 3]," " s0 in [0, 7]," -" s0 mod 5 in [0, 1]," -" is_simplified: false" -> +" s0 mod 5 in [0, 1]"> func.func @constraint_exists_after_peeling(%input: tensor<16xf32>, %init: f32, %dim: index) -> (f32) { %sum = xla_gpu.loop (%dim)[%i] -> (%r0) diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/tests/simplify_affine.mlir b/third_party/xla/xla/service/gpu/fusions/transforms/tests/simplify_affine.mlir index bfddbd60e2bde7..e62a530de0e7db 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/tests/simplify_affine.mlir +++ b/third_party/xla/xla/service/gpu/fusions/transforms/tests/simplify_affine.mlir @@ -65,7 +65,7 @@ func.func @op_and_for_ranges(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: !llvm.pt %2 = xla_gpu.apply_indexing #xla_gpu.indexing_map< "()[s0, s1, s2] -> (s0 * 512 + s1 * 4 + s2 + (s1 floordiv 128) + (s2 floordiv 4))," - "domain: s0 in [0, 3071], s1 in [0, 127], s2 in [0, 3], is_simplified: false">[%1, %0, %i] + "domain: s0 in [0, 3071], s1 in [0, 127], s2 in [0, 3]">[%1, %0, %i] %3 = arith.index_castui %2 : index to i64 %4 = llvm.getelementptr %arg0[%3] : (!llvm.ptr, i64) -> !llvm.ptr, f32 %5 = llvm.load %4 invariant : !llvm.ptr -> f32 @@ -95,7 +95,7 @@ func.func @arg_ranges(%arg0: index, %arg1: index) -> index { %0 = xla_gpu.apply_indexing #xla_gpu.indexing_map< "()[s0, s1] -> (s0 floordiv 100 + s1 floordiv 100)," - "domain: s0 in [0, 42], s1 in [0, 1000], is_simplified: false">[%arg0, %arg1] + "domain: s0 in [0, 42], s1 in [0, 1000]">[%arg0, %arg1] return %0 : index } @@ -109,7 +109,7 @@ func.func @arg_ranges(%arg0: index, %arg1: index) -> index { func.func @cant_lower(%arg0: index, %arg1: index) -> (index, index) { %0:2 = xla_gpu.apply_indexing #xla_gpu.indexing_map<"()[s0, s1] -> (s0 floordiv 100 + s1 floordiv 100, s0 + s1)," - "domain: s0 in [-10, 42], s1 in [0, 1000], is_simplified: false">[%arg0, %arg1] + "domain: s0 in [-10, 42], s1 in [0, 1000]">[%arg0, %arg1] return %0#0, %0#1 : index, index } @@ -128,7 +128,7 @@ func.func @order_summands(%arg1: index) { %0 = xla_gpu.apply_indexing #xla_gpu.indexing_map< "()[s0, s1, s2] -> ((s0 + s1) floordiv 3 + s0 * 512 + s1 * 4 + s2 * 10)," - "domain: s0 in [0, 3], s1 in [0, 3], s2 in [0, 3], is_simplified: false">[%arg2, %arg1, %arg3] + "domain: s0 in [0, 3], s1 in [0, 3], s2 in [0, 3]">[%arg2, %arg1, %arg3] "dummy.op"(%0) : (index) -> () } } diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/tests/simplify_arith.mlir b/third_party/xla/xla/service/gpu/fusions/transforms/tests/simplify_arith.mlir index b301a3bbc93a74..e6fea946e6e827 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/tests/simplify_arith.mlir +++ b/third_party/xla/xla/service/gpu/fusions/transforms/tests/simplify_arith.mlir @@ -249,7 +249,7 @@ func.func @refine_constraints(%tensor: tensor<100xf32>) -> tensor<100xf32> { %loop = scf.for %i = %c0 to %c3 step %c1 iter_args(%in_ = %tensor) -> (tensor<100xf32>) { %0 = xla_gpu.apply_indexing #xla_gpu.indexing_map<"(d0) -> (d0 mod 4)," - "domain: d0 in [0, 9], is_simplified: false">(%i) + "domain: d0 in [0, 9]">(%i) %updated = tensor.insert %c42_f32 into %in_[%0] : tensor<100xf32> scf.yield %updated :tensor<100xf32> } @@ -265,9 +265,9 @@ func.func @refine_constraints(%tensor: tensor<100xf32>) -> tensor<100xf32> { #map = #xla_gpu.indexing_map< "(d0, d1)[s0, s1] -> (((d0 * 4 + d1 * 512 + s1) floordiv 9 + s0 * 32768) mod 2400000)," - "domain: d0 in [0, 127], d1 in [0, 575], s0 in [0, 73], s1 in [0, 3], is_simplified: false"> + "domain: d0 in [0, 127], d1 in [0, 575], s0 in [0, 73], s1 in [0, 3]"> #map1 = #xla_gpu.indexing_map<"(d0, d1)[s0] -> ((d0 * 4 + d1 * 512 + s0) mod 9)," - "domain: d0 in [0, 127], d1 in [0, 575], s0 in [0, 3], is_simplified: false"> + "domain: d0 in [0, 127], d1 in [0, 575], s0 in [0, 3]"> func.func @refine_constraints_for_symbol(%arg0: tensor<2400000x9xf32>, %arg1: tensor<2400000x9xf32>) -> tensor<2400000x9xf32> { %c0 = arith.constant 0 : index @@ -306,8 +306,7 @@ func.func @refine_constraints_for_symbol(%arg0: tensor<2400000x9xf32>, "d4 in [0, 0]," "d5 in [0, 0]," "s0 in [0, 3]," - "d0 * 4 + s0 in [0, 29]," - "is_simplified: false"> + "d0 * 4 + s0 in [0, 29]"> func.func @dus(%arg0: tensor<20x30xf32>, %arg1: tensor<5x6xf32>, %arg2: i32, %arg3: i32, %arg4: tensor<20x30xf32>) -> tensor<20x30xf32> { %c24 = arith.constant 24 : index %c15 = arith.constant 15 : index diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/tests/vectorize_loads_stores.mlir b/third_party/xla/xla/service/gpu/fusions/transforms/tests/vectorize_loads_stores.mlir index 0c734ca19882e5..ceaa3a0748cbff 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/tests/vectorize_loads_stores.mlir +++ b/third_party/xla/xla/service/gpu/fusions/transforms/tests/vectorize_loads_stores.mlir @@ -2,7 +2,7 @@ // RUN: -xla-gpu-vectorize-loads-stores -cse -canonicalize | FileCheck %s #map = #xla_gpu.indexing_map<"(d0)[s0] -> (d0 * 2 + s0)," - "domain: d0 in [0, 63], s0 in [0, 1], is_simplified: true"> + "domain: d0 in [0, 63], s0 in [0, 1]"> func.func @simple_read(%arg0: tensor<128xf32>) -> (f32) { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index @@ -20,7 +20,7 @@ func.func @simple_read(%arg0: tensor<128xf32>) -> (f32) { } return %outer : f32 } -// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<"(d0) -> (d0 * 2), domain: d0 in [0, 63], is_simplified: true"> +// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<"(d0) -> (d0 * 2), domain: d0 in [0, 63]"> // CHECK-LABEL: @simple_read // CHECK-SAME: (%[[ARG0:.*]]: tensor // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index @@ -37,7 +37,7 @@ func.func @simple_read(%arg0: tensor<128xf32>) -> (f32) { // ----- #map = #xla_gpu.indexing_map<"(d0)[s0] -> (d0 * 2 + s0 + 1)," - "domain: d0 in [0, 63], s0 in [0, 1], is_simplified: true"> + "domain: d0 in [0, 63], s0 in [0, 1]"> func.func @misaligned_indexing_map(%arg0: tensor<128xf32>) -> (f32) { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index @@ -61,7 +61,7 @@ func.func @misaligned_indexing_map(%arg0: tensor<128xf32>) -> (f32) { // ----- #map = #xla_gpu.indexing_map<"(d0)[s0] -> (d0 * 3 + s0)," - "domain: d0 in [0, 63], s0 in [0, 1], is_simplified: true"> + "domain: d0 in [0, 63], s0 in [0, 1]"> func.func @misaligned_indexing_map_2(%arg0: tensor<128xf32>) -> (f32) { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index @@ -85,7 +85,7 @@ func.func @misaligned_indexing_map_2(%arg0: tensor<128xf32>) -> (f32) { // ----- #map = #xla_gpu.indexing_map<"(d0)[s0] -> (3 * d0 + s0)," - "domain: d0 in [0, 63], s0 in [0, 1], is_simplified: true"> + "domain: d0 in [0, 63], s0 in [0, 1]"> func.func @misaligned_shape(%arg0: tensor<192xf32>) -> (f32) { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index @@ -109,7 +109,7 @@ func.func @misaligned_shape(%arg0: tensor<192xf32>) -> (f32) { // ----- #map = #xla_gpu.indexing_map<"(d0)[s0] -> (d0 + s0 * 2)," - "domain: d0 in [0, 63], s0 in [0, 1], is_simplified: true"> + "domain: d0 in [0, 63], s0 in [0, 1]"> func.func @wrong_stride(%arg0: tensor<128xf32>) -> (f32) { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index @@ -135,7 +135,7 @@ func.func @wrong_stride(%arg0: tensor<128xf32>) -> (f32) { // We could vectorize this as a float vector load of double the size, but we // don't currently. #map = #xla_gpu.indexing_map<"(d0)[s0] -> (2 * d0 + s0)," - "domain: d0 in [0, 127], s0 in [0, 1], is_simplified: true"> + "domain: d0 in [0, 127], s0 in [0, 1]"> func.func @simple_read_complex(%arg0: tensor<128xcomplex>, %i: index) -> (complex) { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index @@ -251,11 +251,10 @@ func.func @write_not_yielded(%arg0: tensor<64xf32>) -> tensor<64xf32> { // ----- #map = #xla_gpu.indexing_map<"(d0, d1)[s0] -> (d1 * 2 + d0 + s0 * 512)," - "domain: d0 in [0, 7], d1 in [0, 255], s0 in [0, 7], is_simplified: true"> + "domain: d0 in [0, 7], d1 in [0, 255], s0 in [0, 7]"> #map1 = #xla_gpu.indexing_map< "(d0, d1, d2)[s0] -> (d0 * 32 + d2 * 2 + d1 + s0 * 512)," - "domain: d0 in [0, 7], d1 in [0, 1], d2 in [0, 255], s0 in [0, 7]," - "is_simplified: true"> + "domain: d0 in [0, 7], d1 in [0, 1], d2 in [0, 255], s0 in [0, 7]"> func.func @multiple(%arg0: tensor<131072xf32>, %arg1: tensor<4096xbf16>, %arg2: tensor<32xf32>, %arg3: tensor<131072xf32>, %arg4: index) -> (tensor<131072xf32>, f32) { @@ -282,8 +281,8 @@ func.func @multiple(%arg0: tensor<131072xf32>, %arg1: tensor<4096xbf16>, } return %0#0, %0#1 : tensor<131072xf32>, f32 } -// CHECK-DAG: #[[$MAP:.*]] = #xla_gpu.indexing_map<"(d0, d1) -> (d0 * 2 + d1 * 512), domain: d0 in [0, 255], d1 in [0, 7], is_simplified: true"> -// CHECK-DAG: #[[$MAP1:.*]] = #xla_gpu.indexing_map<"(d0, d1, d2) -> (d0 * 32 + d1 * 2 + d2 * 512), domain: d0 in [0, 7], d1 in [0, 255], d2 in [0, 7], is_simplified: true"> +// CHECK-DAG: #[[$MAP:.*]] = #xla_gpu.indexing_map<"(d0, d1) -> (d0 * 2 + d1 * 512), domain: d0 in [0, 255], d1 in [0, 7]"> +// CHECK-DAG: #[[$MAP1:.*]] = #xla_gpu.indexing_map<"(d0, d1, d2) -> (d0 * 32 + d1 * 2 + d2 * 512), domain: d0 in [0, 7], d1 in [0, 255], d2 in [0, 7]"> // CHECK-LABEL: @multiple // CHECK-SAME: (%[[ARG0:.*]]: tensor{{.*}}, %[[ARG1:.*]]: tensor{{.*}}, %[[ARG2:.*]]: tensor{{.*}}, %[[ARG3:.*]]: tensor{{.*}}, %[[ARG4:.*]]: index) // CHECK: %[[C0:.*]] = arith.constant 0 : index @@ -307,7 +306,7 @@ func.func @multiple(%arg0: tensor<131072xf32>, %arg1: tensor<4096xbf16>, // ----- #map = #xla_gpu.indexing_map<"(d0)[s0] -> ((d0 * 4) mod 64 + s0)," - "domain: d0 in [0, 63], s0 in [0, 1], is_simplified: true"> + "domain: d0 in [0, 63], s0 in [0, 1]"> func.func @remainder_with_modulo(%arg0: tensor<128xf32>) -> (f32) { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index @@ -335,7 +334,7 @@ func.func @remainder_with_modulo(%arg0: tensor<128xf32>) -> (f32) { // ----- #map = #xla_gpu.indexing_map<"(d0)[s0] -> ((d0 * 4) mod 65 + s0)," - "domain: d0 in [0, 63], s0 in [0, 1], is_simplified: true"> + "domain: d0 in [0, 63], s0 in [0, 1]"> func.func @remainder_with_modulo_misaligned(%arg0: tensor<128xf32>) -> (f32) { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index @@ -359,9 +358,9 @@ func.func @remainder_with_modulo_misaligned(%arg0: tensor<128xf32>) -> (f32) { // ----- #map0 = #xla_gpu.indexing_map<"(d0) -> (d0 + 5)," - "domain: d0 in [0, 63], is_simplified: true"> + "domain: d0 in [0, 63]"> #map1 = #xla_gpu.indexing_map<"(d0)[s0] -> (d0 * 2 + s0)," - "domain: d0 in [0, 63], s0 in [0, 1], is_simplified: true"> + "domain: d0 in [0, 63], s0 in [0, 1]"> module { func.func @apply_indexing_sequence(%arg0: tensor<128xf32>) -> (f32) { %c0 = arith.constant 0 : index @@ -384,7 +383,7 @@ module { } // CHECK: #[[$MAP0:.*]] = #xla_gpu.indexing_map<"(d0) -> (d0 * 2 + 10), -// CHECK-SAME: domain: d0 in [0, 63], is_simplified: true"> +// CHECK-SAME: domain: d0 in [0, 63]"> // CHECK-LABEL: @apply_indexing_sequence // CHECK: %[[BASE:.*]] = xla_gpu.apply_indexing #[[$MAP0]] // CHECK: vector.transfer_read {{.*}}[%[[BASE]]] @@ -393,9 +392,9 @@ module { #map0 = #xla_gpu.indexing_map<"(d0) -> (d0 + 5)," - "domain: d0 in [0, 63], is_simplified: true"> + "domain: d0 in [0, 63]"> #map1 = #xla_gpu.indexing_map<"(d0)[s0] -> (d0 * 2 + s0)," - "domain: d0 in [0, 63], s0 in [0, 1], is_simplified: true"> + "domain: d0 in [0, 63], s0 in [0, 1]"> module { func.func @apply_indexing_sequence_same_block(%arg0: tensor<128xf32>) -> (f32) { %c0 = arith.constant 0 : index diff --git a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_test.cc b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_test.cc index d79a54f4615681..1cb92de9436e49 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_test.cc @@ -216,7 +216,7 @@ ENTRY main { "num_warps":"1"}}}})"; TF_EXPECT_OK(CreateTritonIrAndFileCheck(this, kHloText, "triton_softmax_computation", R"( -CHECK: #indexing_map = #xla_gpu.indexing_map<"(d0) -> (d0 * 127), domain: d0 in [0, 124], is_simplified: true"> +CHECK: #indexing_map = #xla_gpu.indexing_map<"(d0) -> (d0 * 127), domain: d0 in [0, 124]"> CHECK: tt.func @triton_fn(%[[P0:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[P1:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}) { CHECK-DAG: %[[ZERO:.*]] = arith.constant 0 : i32 CHECK-DAG: %[[C125:.*]] = arith.constant 125 : i64 @@ -281,7 +281,7 @@ ENTRY main { "num_warps":"1"}}}})"; TF_EXPECT_OK(CreateTritonIrAndFileCheck(this, kHloText, "triton_softmax_computation", R"( -CHECK: #indexing_map = #xla_gpu.indexing_map<"(d0) -> (d0 * 127), domain: d0 in [0, 124], is_simplified: true"> +CHECK: #indexing_map = #xla_gpu.indexing_map<"(d0) -> (d0 * 127), domain: d0 in [0, 124]"> CHECK: tt.func @triton_fn( CHECK-SAME: %[[P0:[A-Za-z0-9_]*]]: !tt.ptr CHECK-SAME: %[[P1:[A-Za-z0-9_]*]]: !tt.ptr @@ -352,9 +352,9 @@ ENTRY main { TF_EXPECT_OK(CreateTritonIrAndFileCheck(this, kHloText, "triton_softmax_computation", R"( -CHECK: #[[MAP:.*]] = #xla_gpu.indexing_map<"(d0) -> (d0 floordiv 125), domain: d0 in [0, 1249], is_simplified: true"> -CHECK: #[[MAP1:.*]] = #xla_gpu.indexing_map<"(d0) -> (d0 mod 125), domain: d0 in [0, 1249], is_simplified: true"> -CHECK: #[[MAP2:.*]] = #xla_gpu.indexing_map<"(d0) -> (d0 * 127), domain: d0 in [0, 1249], is_simplified: true"> +CHECK: #[[MAP:.*]] = #xla_gpu.indexing_map<"(d0) -> (d0 floordiv 125), domain: d0 in [0, 1249]"> +CHECK: #[[MAP1:.*]] = #xla_gpu.indexing_map<"(d0) -> (d0 mod 125), domain: d0 in [0, 1249]"> +CHECK: #[[MAP2:.*]] = #xla_gpu.indexing_map<"(d0) -> (d0 * 127), domain: d0 in [0, 1249]"> CHECK: tt.func @triton_fn(%[[P0:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[P1:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[P2:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[P3:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}) { CHECK-DAG: %[[ZERO:.*]] = arith.constant 0 : i32 CHECK-DAG: %[[ZERO_64:.*]] = arith.constant 0 : i64 @@ -545,8 +545,8 @@ ENTRY main { TF_ASSERT_OK(CreateTritonIrAndFileCheck(this, kHloText, "triton_softmax_computation", R"( -// CHECK: #xla_gpu.indexing_map<"(d0) -> (d0 floordiv 32), domain: d0 in [0, 2047], is_simplified: true"> -// CHECK: #xla_gpu.indexing_map<"(d0) -> (d0 mod 32), domain: d0 in [0, 2047], is_simplified: true"> +// CHECK: #xla_gpu.indexing_map<"(d0) -> (d0 floordiv 32), domain: d0 in [0, 2047]"> +// CHECK: #xla_gpu.indexing_map<"(d0) -> (d0 mod 32), domain: d0 in [0, 2047]"> // CHECK-LABEL: tt.func @triton_fn( // CHECK-SAME: %[[P0:[A-Za-z0-9_]*]]: !tt.ptr // CHECK-SAME: %[[P1:[A-Za-z0-9_]*]]: !tt.ptr diff --git a/third_party/xla/xla/service/gpu/model/indexing_analysis.cc b/third_party/xla/xla/service/gpu/model/indexing_analysis.cc index 9477826b0f801f..a8f70449e10f9b 100644 --- a/third_party/xla/xla/service/gpu/model/indexing_analysis.cc +++ b/third_party/xla/xla/service/gpu/model/indexing_analysis.cc @@ -528,7 +528,7 @@ HloInstructionIndexing ComputeOutputToInputReduceOpIndexing( output_shape.dimensions(), parallel_dims_sizes); IndexingMap inits_indexing_map = IndexingMap::FromTensorSizes( AffineMap::get(output_shape.rank(), /*symbolCount=*/0, {}, mlir_context), - output_shape.dimensions(), {}, /*is_simplified=*/true); + output_shape.dimensions(), {}); HloInstructionIndexing instr_indexing; instr_indexing.indexing_maps.resize(reduce->operand_count()); @@ -661,8 +661,7 @@ HloInstructionIndexing ComputeOutputToInputReduceWindowOpIndexing( // Indexing map for the init value. IndexingMap inits_indexing_map = IndexingMap::FromTensorSizes( AffineMap::get(output_shape.rank(), /*symbolCount=*/0, {}, mlir_context), - output_shape.dimensions(), /*symbol_upper_bounds=*/{}, - /*is_simplified=*/true); + output_shape.dimensions(), /*symbol_upper_bounds=*/{}); HloInstructionIndexing instr_indexing; instr_indexing.indexing_maps.resize(reduce_window->operand_count()); @@ -1154,8 +1153,7 @@ IndexingMap CreateIdentityMap(absl::Span dimensions, mlir::MLIRContext* mlir_context) { return IndexingMap::FromTensorSizes( AffineMap::getMultiDimIdentityMap(dimensions.size(), mlir_context), - /*dim_upper_bounds=*/dimensions, /*symbol_upper_bounds=*/{}, - /*is_simplified=*/dimensions.empty()); + /*dim_upper_bounds=*/dimensions, /*symbol_upper_bounds=*/{}); } IndexingMap CreateIdentityMap(const Shape& shape, MLIRContext* mlir_context) { diff --git a/third_party/xla/xla/service/gpu/model/indexing_analysis_test.cc b/third_party/xla/xla/service/gpu/model/indexing_analysis_test.cc index b3f4043d73825f..be577ffa434b02 100644 --- a/third_party/xla/xla/service/gpu/model/indexing_analysis_test.cc +++ b/third_party/xla/xla/service/gpu/model/indexing_analysis_test.cc @@ -64,15 +64,13 @@ TEST_F(IndexingAnalysisTest, FuseProducerConsumerOutputToInputIndexing) { (d0, d1) -> (d0, d1), domain: d0 in [0, 999], - d1 in [0, 999], - is_simplified: false + d1 in [0, 999] )"))), Pair(transpose, ElementsAre(MatchIndexingMap(R"( (d0, d1) -> (d0, d1), domain: d0 in [0, 999], - d1 in [0, 999], - is_simplified: false + d1 in [0, 999] )"))))); } @@ -98,29 +96,25 @@ TEST_F(IndexingAnalysisTest, ComputeGroupedOutputToInputIndexing) { (d0, d1) -> (d0, d1), domain: d0 in [0, 999], - d1 in [0, 999], - is_simplified: false + d1 in [0, 999] )"))), Pair(transpose, ElementsAre(MatchIndexingMap(R"( (d0, d1) -> (d0, d1), domain: d0 in [0, 999], - d1 in [0, 999], - is_simplified: true + d1 in [0, 999] )"))), Pair(parameter, UnorderedElementsAre(MatchIndexingMap(R"( (d0, d1) -> (d0, d1), domain: d0 in [0, 999], - d1 in [0, 999], - is_simplified: true + d1 in [0, 999] )"), MatchIndexingMap(R"( (d0, d1) -> (d1, d0), domain: d0 in [0, 999], - d1 in [0, 999], - is_simplified: true + d1 in [0, 999] )"))))); } @@ -159,34 +153,29 @@ TEST_F(IndexingAnalysisTest, Pair(root, ElementsAre(MatchIndexingMap(R"( (d0) -> (d0), domain: - d0 in [0, 31], - is_simplified: false + d0 in [0, 31] )"))), Pair(root->operand(0), ElementsAre(MatchIndexingMap(R"( (d0)[s0] -> (d0, s0), domain: d0 in [0, 31], - s0 in [0, 39], - is_simplified: true + s0 in [0, 39] )"))), Pair(root->operand(1), ElementsAre(MatchIndexingMap(R"( (d0)[s0] -> (d0, s0), domain: d0 in [0, 31], - s0 in [0, 39], - is_simplified: true + s0 in [0, 39] )"))), Pair(root->operand(2), ElementsAre(MatchIndexingMap(R"( (d0) -> (), domain: - d0 in [0, 31], - is_simplified: true + d0 in [0, 31] )"))), Pair(root->operand(3), ElementsAre(MatchIndexingMap(R"( (d0) -> (), domain: - d0 in [0, 31], - is_simplified: true + d0 in [0, 31] )"))))); } @@ -216,8 +205,7 @@ TEST_F(IndexingAnalysisTest, ComputeGroupedOutputToInputIndexing_SingleOp) { (d0, d1) -> (d0, d1), domain: d0 in [0, 999], - d1 in [0, 999], - is_simplified: false + d1 in [0, 999] )"))))); } @@ -261,8 +249,7 @@ TEST_F(IndexingAnalysisTest, d0 in [0, 14], d1 in [0, 31], d2 in [0, 19], - d3 in [0, 63], - is_simplified: false + d3 in [0, 63] )"))), Pair(¶meter_0.instruction(), ElementsAre(MatchIndexingMap(R"( (d0, d1, d2, d3) -> (d0, d2), @@ -270,8 +257,7 @@ TEST_F(IndexingAnalysisTest, d0 in [0, 14], d1 in [0, 31], d2 in [0, 19], - d3 in [0, 63], - is_simplified: true + d3 in [0, 63] )"))))); } @@ -291,8 +277,7 @@ TEST_F(IndexingAnalysisTest, PhysicalLayoutTestOutputPermutation) { domain: d0 in [0, 29], d1 in [0, 9], - d2 in [0, 19], - is_simplified: false + d2 in [0, 19] )")); auto output_indexing = GetInputToOutputIndexing(root, /*input_id=*/0, @@ -303,8 +288,7 @@ TEST_F(IndexingAnalysisTest, PhysicalLayoutTestOutputPermutation) { domain: d0 in [0, 9], d1 in [0, 19], - d2 in [0, 29], - is_simplified: false + d2 in [0, 29] )")); } @@ -367,8 +351,7 @@ TEST_F(IndexingAnalysisTest, PhysicalLayoutTestInputPermutation) { domain: d0 in [0, 9], d1 in [0, 19], - d2 in [0, 29], - is_simplified: false + d2 in [0, 29] )")); auto output_indexing = GetInputToOutputIndexing(root, /*input_id=*/0, @@ -379,8 +362,7 @@ TEST_F(IndexingAnalysisTest, PhysicalLayoutTestInputPermutation) { domain: d0 in [0, 29], d1 in [0, 9], - d2 in [0, 19], - is_simplified: false + d2 in [0, 19] )")); } @@ -400,8 +382,7 @@ TEST_F(IndexingAnalysisTest, PhysicalLayoutTestInputAndOutputPermutation) { domain: d0 in [0, 29], d1 in [0, 9], - d2 in [0, 19], - is_simplified: false + d2 in [0, 19] )")); auto output_indexing = GetInputToOutputIndexing(root, /*input_id=*/0, @@ -412,8 +393,7 @@ TEST_F(IndexingAnalysisTest, PhysicalLayoutTestInputAndOutputPermutation) { domain: d0 in [0, 29], d1 in [0, 9], - d2 in [0, 19], - is_simplified: false + d2 in [0, 19] )")); } @@ -432,14 +412,12 @@ TEST_F(IndexingAnalysisTest, ElementwiseOp) { (d0, d1) -> (d0, d1), domain: d0 in [0, 9], - d1 in [0, 19], - is_simplified: false + d1 in [0, 19] operand id = 1 (d0, d1) -> (d0, d1), domain: d0 in [0, 9], - d1 in [0, 19], - is_simplified: false + d1 in [0, 19] )")); auto output_indexing_0 = GetInputToOutputIndexing(root, /*input_id=*/0); @@ -448,8 +426,7 @@ TEST_F(IndexingAnalysisTest, ElementwiseOp) { (d0, d1) -> (d0, d1), domain: d0 in [0, 9], - d1 in [0, 19], - is_simplified: false + d1 in [0, 19] )")); auto output_indexing_1 = GetInputToOutputIndexing(root, /*input_id=*/1); @@ -458,8 +435,7 @@ TEST_F(IndexingAnalysisTest, ElementwiseOp) { (d0, d1) -> (d0, d1), domain: d0 in [0, 9], - d1 in [0, 19], - is_simplified: false + d1 in [0, 19] )")); } @@ -483,14 +459,12 @@ TEST_F(IndexingAnalysisTest, Map) { (d0, d1) -> (d0, d1), domain: d0 in [0, 9], - d1 in [0, 19], - is_simplified: false + d1 in [0, 19] operand id = 1 (d0, d1) -> (d0, d1), domain: d0 in [0, 9], - d1 in [0, 19], - is_simplified: false + d1 in [0, 19] )")); auto output_indexing_0 = GetInputToOutputIndexing(root, /*input_id=*/0); @@ -499,8 +473,7 @@ TEST_F(IndexingAnalysisTest, Map) { (d0, d1) -> (d0, d1), domain: d0 in [0, 9], - d1 in [0, 19], - is_simplified: false + d1 in [0, 19] )")); auto output_indexing_1 = GetInputToOutputIndexing(root, /*input_id=*/1); @@ -509,8 +482,7 @@ TEST_F(IndexingAnalysisTest, Map) { (d0, d1) -> (d0, d1), domain: d0 in [0, 9], - d1 in [0, 19], - is_simplified: false + d1 in [0, 19] )")); } @@ -528,8 +500,7 @@ TEST_F(IndexingAnalysisTest, BitcastIsReshape) { domain: d0 in [0, 3], d1 in [0, 7], - d2 in [0, 3], - is_simplified: true + d2 in [0, 3] )")); } @@ -548,8 +519,7 @@ TEST_F(IndexingAnalysisTest, BitcastIsTranspose) { d0 in [0, 2], d1 in [0, 5], d2 in [0, 127], - d3 in [0, 12287], - is_simplified: true + d3 in [0, 12287] )")); } @@ -567,8 +537,7 @@ TEST_F(IndexingAnalysisTest, BitcastIsTransposeReshapeTranspose) { (d0, d1) -> (d1, d0 floordiv 3, d0 mod 3), domain: d0 in [0, 50], - d1 in [0, 15], - is_simplified: true + d1 in [0, 15] )")); auto output_indexing = GetInputToOutputIndexing(root); EXPECT_THAT(output_indexing.ToString(), MatchIndexingString(R"( @@ -577,8 +546,7 @@ TEST_F(IndexingAnalysisTest, BitcastIsTransposeReshapeTranspose) { domain: d0 in [0, 15], d1 in [0, 16], - d2 in [0, 2], - is_simplified: true + d2 in [0, 2] )")); } @@ -597,8 +565,7 @@ TEST_F(IndexingAnalysisTest, BroadcastOp) { domain: d0 in [0, 9], d1 in [0, 19], - d2 in [0, 29], - is_simplified: false + d2 in [0, 29] )")); auto output_indexing = GetInputToOutputIndexing(root); EXPECT_THAT(output_indexing.ToString(), MatchIndexingString(R"( @@ -607,8 +574,7 @@ TEST_F(IndexingAnalysisTest, BroadcastOp) { domain: d0 in [0, 19], s0 in [0, 9], - s1 in [0, 29], - is_simplified: false + s1 in [0, 29] )")); } @@ -641,22 +607,19 @@ TEST_F(IndexingAnalysisTest, ConcatenateOp) { domain: d0 in [0, 1], d1 in [0, 4], - d2 in [0, 6], - is_simplified: false + d2 in [0, 6] operand id = 1 (d0, d1, d2) -> (d0, d1 - 5, d2), domain: d0 in [0, 1], d1 in [5, 15], - d2 in [0, 6], - is_simplified: false + d2 in [0, 6] operand id = 2 (d0, d1, d2) -> (d0, d1 - 16, d2), domain: d0 in [0, 1], d1 in [16, 32], - d2 in [0, 6], - is_simplified: false + d2 in [0, 6] )")); auto output_indexing_0 = GetInputToOutputIndexing(root, /*input_id=*/0); @@ -666,8 +629,7 @@ TEST_F(IndexingAnalysisTest, ConcatenateOp) { domain: d0 in [0, 1], d1 in [0, 4], - d2 in [0, 6], - is_simplified: false + d2 in [0, 6] )")); auto output_indexing_1 = GetInputToOutputIndexing(root, /*input_id=*/1); @@ -677,8 +639,7 @@ TEST_F(IndexingAnalysisTest, ConcatenateOp) { domain: d0 in [0, 1], d1 in [0, 10], - d2 in [0, 6], - is_simplified: false + d2 in [0, 6] )")); auto output_indexing_2 = GetInputToOutputIndexing(root, /*input_id=*/2); @@ -688,8 +649,7 @@ TEST_F(IndexingAnalysisTest, ConcatenateOp) { domain: d0 in [0, 1], d1 in [0, 16], - d2 in [0, 6], - is_simplified: false + d2 in [0, 6] )")); } @@ -721,29 +681,25 @@ TEST_F(IndexingAnalysisTest, DynamicSliceOp) { (d0, d1, d2) -> (), s2 in [0, 226], hlo: %of3 = s32[] parameter(3), - (d0, d1, d2) -> (), - is_simplified: false + (d0, d1, d2) -> () operand id = 1 (d0, d1, d2) -> (), domain: d0 in [0, 0], d1 in [0, 1], - d2 in [0, 31], - is_simplified: false + d2 in [0, 31] operand id = 2 (d0, d1, d2) -> (), domain: d0 in [0, 0], d1 in [0, 1], - d2 in [0, 31], - is_simplified: false + d2 in [0, 31] operand id = 3 (d0, d1, d2) -> (), domain: d0 in [0, 0], d1 in [0, 1], - d2 in [0, 31], - is_simplified: false + d2 in [0, 31] )")); } @@ -764,8 +720,7 @@ TEST_F(IndexingAnalysisTest, DynamicUpdateSliceOp) { (d0, d1) -> (d0, d1), domain: d0 in [0, 19], - d1 in [0, 29], - is_simplified: false + d1 in [0, 29] operand id = 1 (d0, d1)[s0, s1] -> (d0 - s0, d1 - s1), domain: @@ -776,20 +731,17 @@ TEST_F(IndexingAnalysisTest, DynamicUpdateSliceOp) { (d0, d1) -> (), s1 in [0, 20], hlo: %of2 = s32[] parameter(3), - (d0, d1) -> (), - is_simplified: false + (d0, d1) -> () operand id = 2 (d0, d1) -> (), domain: d0 in [0, 19], - d1 in [0, 29], - is_simplified: false + d1 in [0, 29] operand id = 3 (d0, d1) -> (), domain: d0 in [0, 19], - d1 in [0, 29], - is_simplified: false + d1 in [0, 29] )")); } @@ -811,13 +763,11 @@ TEST_F(IndexingAnalysisTest, FusionOpWithSingleBinaryOp) { operand id = 0 (d0) -> (d0), domain: - d0 in [0, 99], - is_simplified: true + d0 in [0, 99] operand id = 1 (d0) -> (d0), domain: - d0 in [0, 99], - is_simplified: true + d0 in [0, 99] )")); } @@ -891,8 +841,7 @@ TEST_F(IndexingAnalysisTest, FusionOpWithDot) { d3 in [0, 0], d4 in [0, 5], d5 in [0, 127], - s0 in [0, 767], - is_simplified: true + s0 in [0, 767] operand id = 1 (d0, d1, d2, d3, d4, d5)[s0] -> (d0 * 768 + s0), domain: @@ -902,8 +851,7 @@ TEST_F(IndexingAnalysisTest, FusionOpWithDot) { d3 in [0, 0], d4 in [0, 5], d5 in [0, 127], - s0 in [0, 767], - is_simplified: true + s0 in [0, 767] operand id = 2 (d0, d1, d2, d3, d4, d5) -> (d1), domain: @@ -912,8 +860,7 @@ TEST_F(IndexingAnalysisTest, FusionOpWithDot) { d2 in [0, 2], d3 in [0, 0], d4 in [0, 5], - d5 in [0, 127], - is_simplified: true + d5 in [0, 127] operand id = 3 (d0, d1, d2, d3, d4, d5)[s0] -> (d1, d0 * 768 + s0), domain: @@ -923,8 +870,7 @@ TEST_F(IndexingAnalysisTest, FusionOpWithDot) { d3 in [0, 0], d4 in [0, 5], d5 in [0, 127], - s0 in [0, 767], - is_simplified: true + s0 in [0, 767] operand id = 4 (d0, d1, d2, d3, d4, d5)[s0] -> (d1, d0 * 768 + s0), domain: @@ -934,8 +880,7 @@ TEST_F(IndexingAnalysisTest, FusionOpWithDot) { d3 in [0, 0], d4 in [0, 5], d5 in [0, 127], - s0 in [0, 767], - is_simplified: true + s0 in [0, 767] operand id = 5 (d0, d1, d2, d3, d4, d5) -> (d2, d4, d5), domain: @@ -944,8 +889,7 @@ TEST_F(IndexingAnalysisTest, FusionOpWithDot) { d2 in [0, 2], d3 in [0, 0], d4 in [0, 5], - d5 in [0, 127], - is_simplified: true + d5 in [0, 127] )")); } @@ -1002,16 +946,14 @@ TEST_F(IndexingAnalysisTest, FusionOpWithSoftmax) { d0 in [0, 1], d1 in [0, 64], d2 in [0, 124], - s0 in [0, 124], - is_simplified: true + s0 in [0, 124] )"), MatchIndexingMap(R"( (d0, d1, d2) -> (d0, d1, d2), domain: d0 in [0, 1], d1 in [0, 64], - d2 in [0, 124], - is_simplified: true + d2 in [0, 124] )")))); } @@ -1033,15 +975,13 @@ TEST_F(IndexingAnalysisTest, FusionOpTensorPlusTransposedTensor) { (d0, d1) -> (d0, d1), domain: d0 in [0, 999], - d1 in [0, 999], - is_simplified: true + d1 in [0, 999] )"), MatchIndexingMap(R"( (d0, d1) -> (d1, d0), domain: d0 in [0, 999], - d1 in [0, 999], - is_simplified: true + d1 in [0, 999] )")))); } @@ -1071,38 +1011,32 @@ TEST_F(IndexingAnalysisTest, FusionExponentialDuplication) { ElementsAre(UnorderedElementsAre(MatchIndexingMap(R"( (d0) -> (d0 + 1), domain: - d0 in [0, 1], - is_simplified: true + d0 in [0, 1] )"), MatchIndexingMap(R"( (d0) -> (d0), domain: - d0 in [0, 1], - is_simplified: true + d0 in [0, 1] )"), MatchIndexingMap(R"( (d0) -> (d0 + 2), domain: - d0 in [0, 1], - is_simplified: true + d0 in [0, 1] )")), UnorderedElementsAre(MatchIndexingMap(R"( (d0) -> (d0 + 2), domain: - d0 in [0, 1], - is_simplified: true + d0 in [0, 1] )"), MatchIndexingMap(R"( (d0) -> (d0 + 1), domain: - d0 in [0, 1], - is_simplified: true + d0 in [0, 1] )"), MatchIndexingMap(R"( (d0) -> (d0), domain: - d0 in [0, 1], - is_simplified: true + d0 in [0, 1] )")))); } @@ -1130,8 +1064,7 @@ TEST_F(IndexingAnalysisTest, GatherOp) { (d0, d1, d2, d3) -> (d0, 0), s1 in [0, 68], hlo: %indices = s32[1806,2]{1,0} parameter(1), - (d0, d1, d2, d3) -> (d0, 1), - is_simplified: false + (d0, d1, d2, d3) -> (d0, 1) operand id = 1 (d0, d1, d2, d3)[s0] -> (d0, s0), domain: @@ -1139,8 +1072,7 @@ TEST_F(IndexingAnalysisTest, GatherOp) { d1 in [0, 6], d2 in [0, 7], d3 in [0, 3], - s0 in [0, 1], - is_simplified: false + s0 in [0, 1] )")); } @@ -1173,13 +1105,11 @@ TEST_F(IndexingAnalysisTest, FusionOpWithReduceOfReduce) { d0 in [0, 9], s0 in [0, 149], s1 in [0, 49], - s2 in [0, 19], - is_simplified: true + s2 in [0, 19] operand id = 1 (d0) -> (), domain: - d0 in [0, 9], - is_simplified: true + d0 in [0, 9] )")); } @@ -1211,14 +1141,12 @@ TEST_F(IndexingAnalysisTest, FusionOpWithReduceOfBroadcast) { domain: d0 in [0, 14], d1 in [0, 63], - s0 in [0, 19], - is_simplified: true + s0 in [0, 19] operand id = 1 (d0, d1) -> (), domain: d0 in [0, 14], - d1 in [0, 63], - is_simplified: true + d1 in [0, 63] )")); } @@ -1253,8 +1181,7 @@ TEST_F(IndexingAnalysisTest, FusionOpWithTransposeOfTranspose) { domain: d0 in [0, 9], d1 in [0, 49], - d2 in [0, 19], - is_simplified: true + d2 in [0, 19] )")); } @@ -1286,13 +1213,11 @@ TEST_F(IndexingAnalysisTest, FusionOpWithReducedSlice) { domain: d0 in [0, 31], s0 in [0, 15], - s1 in [0, 127], - is_simplified: true + s1 in [0, 127] operand id = 1 (d0) -> (), domain: - d0 in [0, 31], - is_simplified: true + d0 in [0, 31] )")); } @@ -1313,8 +1238,7 @@ TEST_F(IndexingAnalysisTest, FusionOpWithReshape_CollapseOfExpand) { operand id = 0 (d0) -> (d0), domain: - d0 in [0, 127], - is_simplified: true + d0 in [0, 127] )")); } @@ -1336,8 +1260,7 @@ TEST_F(IndexingAnalysisTest, FusionOpWithReshape_ExpandOfCollapse) { (d0, d1) -> (d0, d1), domain: d0 in [0, 7], - d1 in [0, 15], - is_simplified: true + d1 in [0, 15] )")); } @@ -1360,8 +1283,7 @@ TEST_F(IndexingAnalysisTest, FusionOpWithReshape_ChainedGenericReshapes) { domain: d0 in [0, 9], d1 in [0, 9], - d2 in [0, 9], - is_simplified: true + d2 in [0, 9] )")); } @@ -1386,8 +1308,7 @@ TEST_F(IndexingAnalysisTest, FusionOpWithSliceOfSlice) { domain: d0 in [0, 6], d1 in [0, 8], - d2 in [0, 23], - is_simplified: true + d2 in [0, 23] )")); } @@ -1434,32 +1355,27 @@ TEST_F(IndexingAnalysisTest, FusionOpWithDynSliceOfDynSlice) { (d0, d1) -> (), s3 in [0, 16], hlo: %of22 = s32[] parameter(4), - (d0, d1) -> (), - is_simplified: true + (d0, d1) -> () operand id = 1 (d0, d1) -> (), domain: d0 in [0, 24], - d1 in [0, 15], - is_simplified: true + d1 in [0, 15] operand id = 2 (d0, d1) -> (), domain: d0 in [0, 24], - d1 in [0, 15], - is_simplified: true + d1 in [0, 15] operand id = 3 (d0, d1) -> (), domain: d0 in [0, 24], - d1 in [0, 15], - is_simplified: true + d1 in [0, 15] operand id = 4 (d0, d1) -> (), domain: d0 in [0, 24], - d1 in [0, 15], - is_simplified: true + d1 in [0, 15] )")); } @@ -1488,22 +1404,19 @@ TEST_F(IndexingAnalysisTest, FusionOpSliceOfAllConcatenateOpInputs) { domain: d0 in [0, 1], d1 in [0, 1], - d2 in [0, 6], - is_simplified: true + d2 in [0, 6] operand id = 1 (d0, d1, d2) -> (d0, d1 * 3 - 5, d2), domain: d0 in [0, 1], d1 in [2, 5], - d2 in [0, 6], - is_simplified: true + d2 in [0, 6] operand id = 2 (d0, d1, d2) -> (d0, d1 * 3 - 16, d2), domain: d0 in [0, 1], d1 in [6, 10], - d2 in [0, 6], - is_simplified: true + d2 in [0, 6] )")); } @@ -1532,8 +1445,7 @@ TEST_F(IndexingAnalysisTest, FusionOpSliceOfOneOfConcatenateOpInputs) { domain: d0 in [0, 1], d1 in [0, 2], - d2 in [0, 6], - is_simplified: true + d2 in [0, 6] operand id = 1 KNOWN EMPTY operand id = 2 @@ -1562,15 +1474,13 @@ TEST_F(IndexingAnalysisTest, FusionOpReshapeOfConcat) { domain: d0 in [0, 3], d1 in [0, 7], - d0 * 8 + d1 in [0, 1], - is_simplified: true + d0 * 8 + d1 in [0, 1] operand id = 1 (d0, d1) -> (d0 * 8 + d1 - 2), domain: d0 in [0, 3], d1 in [0, 7], - d0 * 8 + d1 in [2, 31], - is_simplified: true + d0 * 8 + d1 in [2, 31] )")); } @@ -1597,8 +1507,7 @@ TEST_F(IndexingAnalysisTest, ReshapeOpCollapseShape) { operand id = 0 (d0) -> (d0 floordiv 8, d0 mod 8), domain: - d0 in [0, 31], - is_simplified: true + d0 in [0, 31] )")); } @@ -1615,8 +1524,7 @@ TEST_F(IndexingAnalysisTest, ReshapeOpExpandShape) { (d0, d1) -> (d0 * 8 + d1), domain: d0 in [0, 3], - d1 in [0, 7], - is_simplified: true + d1 in [0, 7] )")); } @@ -1635,8 +1543,7 @@ TEST_F(IndexingAnalysisTest, ReshapeOpExpandAndCollapseShape) { domain: d0 in [0, 31], d1 in [0, 2], - d2 in [0, 3], - is_simplified: true + d2 in [0, 3] )")); auto output_indexing = GetInputToOutputIndexing(root); @@ -1646,8 +1553,7 @@ TEST_F(IndexingAnalysisTest, ReshapeOpExpandAndCollapseShape) { domain: d0 in [0, 3], d1 in [0, 7], - d2 in [0, 11], - is_simplified: true + d2 in [0, 11] )")); } @@ -1665,8 +1571,7 @@ TEST_F(IndexingAnalysisTest, ReshapeOpExpandSubshapeOnly) { domain: d0 in [0, 3], d1 in [0, 3], - d2 in [0, 7], - is_simplified: true + d2 in [0, 7] )")); } @@ -1684,8 +1589,7 @@ TEST_F(IndexingAnalysisTest, ReshapeOpGenericReshape2DTo3D) { domain: d0 in [0, 1], d1 in [0, 3], - d2 in [0, 3], - is_simplified: true + d2 in [0, 3] )")); } @@ -1704,8 +1608,7 @@ TEST_F(IndexingAnalysisTest, ReshapeOpGenericReshape3DTo2D) { d1 mod 4), domain: d0 in [0, 3], - d1 in [0, 7], - is_simplified: true + d1 in [0, 7] )")); } @@ -1724,14 +1627,12 @@ TEST_F(IndexingAnalysisTest, PadOp) { domain: d0 in [1, 7], d1 in [4, 7], - (d0 - 1) mod 2 in [0, 0], - is_simplified: false + (d0 - 1) mod 2 in [0, 0] operand id = 1 (d0, d1) -> (), domain: d0 in [0, 11], - d1 in [0, 15], - is_simplified: false + d1 in [0, 15] )")); } @@ -1749,14 +1650,12 @@ TEST_F(IndexingAnalysisTest, PadOpNoInterior) { (d0, d1) -> (d0 - 1, d1), domain: d0 in [1, 2], - d1 in [0, 7], - is_simplified: false + d1 in [0, 7] operand id = 1 (d0, d1) -> (), domain: d0 in [0, 9], - d1 in [0, 7], - is_simplified: false + d1 in [0, 7] )")); } @@ -1779,13 +1678,11 @@ TEST_F(IndexingAnalysisTest, PadOpNegativePadding) { (d0) -> ((d0 + 3) floordiv 2), domain: d0 in [0, 4], - (d0 + 3) mod 2 in [0, 0], - is_simplified: false + (d0 + 3) mod 2 in [0, 0] operand id = 1 (d0) -> (), domain: - d0 in [0, 4], - is_simplified: false + d0 in [0, 4] )")); } @@ -1812,14 +1709,12 @@ TEST_F(IndexingAnalysisTest, ReduceOp) { d0 in [0, 149], d1 in [0, 9], s0 in [0, 19], - s1 in [0, 49], - is_simplified: false + s1 in [0, 49] operand id = 1 (d0, d1) -> (), domain: d0 in [0, 149], - d1 in [0, 9], - is_simplified: true + d1 in [0, 9] )")); auto output_indexing_0 = GetInputToOutputIndexing(root, 0); @@ -1830,8 +1725,7 @@ TEST_F(IndexingAnalysisTest, ReduceOp) { d0 in [0, 149], d1 in [0, 19], d2 in [0, 9], - d3 in [0, 49], - is_simplified: false + d3 in [0, 49] )")); auto output_indexing_1 = GetInputToOutputIndexing(root, 1); EXPECT_THAT(output_indexing_1.ToString(), MatchIndexingString(R"( @@ -1839,8 +1733,7 @@ TEST_F(IndexingAnalysisTest, ReduceOp) { ()[s0, s1] -> (s0, s1), domain: s0 in [0, 149], - s1 in [0, 9], - is_simplified: false + s1 in [0, 9] )")); } @@ -1873,24 +1766,20 @@ TEST_F(IndexingAnalysisTest, VariadicReduceOp) { (d0)[s0] -> (s0, d0), domain: d0 in [0, 9], - s0 in [0, 255], - is_simplified: false + s0 in [0, 255] operand id = 1 (d0)[s0] -> (s0, d0), domain: d0 in [0, 9], - s0 in [0, 255], - is_simplified: false + s0 in [0, 255] operand id = 2 (d0) -> (), domain: - d0 in [0, 9], - is_simplified: true + d0 in [0, 9] operand id = 3 (d0) -> (), domain: - d0 in [0, 9], - is_simplified: true + d0 in [0, 9] )")); auto output_indexing_1 = GetOutputToInputIndexing(root, /*output_id=*/1); @@ -1899,32 +1788,27 @@ TEST_F(IndexingAnalysisTest, VariadicReduceOp) { (d0)[s0] -> (s0, d0), domain: d0 in [0, 9], - s0 in [0, 255], - is_simplified: false + s0 in [0, 255] operand id = 1 (d0)[s0] -> (s0, d0), domain: d0 in [0, 9], - s0 in [0, 255], - is_simplified: false + s0 in [0, 255] operand id = 2 (d0) -> (), domain: - d0 in [0, 9], - is_simplified: true + d0 in [0, 9] operand id = 3 (d0) -> (), domain: - d0 in [0, 9], - is_simplified: true + d0 in [0, 9] )")); constexpr std::string_view kInputToOutputIndexing = R"( (d0, d1) -> (d1), domain: d0 in [0, 255], - d1 in [0, 9], - is_simplified: false + d1 in [0, 9] )"; auto input_indexing_0 = GetInputToOutputIndexing(root, /*input_id=*/0); EXPECT_THAT( @@ -1941,8 +1825,7 @@ TEST_F(IndexingAnalysisTest, VariadicReduceOp) { constexpr std::string_view kInitToOutputIndexing = R"( ()[s0] -> (s0), domain: - s0 in [0, 9], - is_simplified: false + s0 in [0, 9] )"; auto input_indexing_2 = GetInputToOutputIndexing(root, /*input_id=*/2); EXPECT_THAT( @@ -1978,14 +1861,12 @@ TEST_F(IndexingAnalysisTest, ReduceWindowOp_NoPadding) { domain: d0 in [0, 1023], d1 in [0, 2], - s0 in [0, 511], - is_simplified: true + s0 in [0, 511] operand id = 1 (d0, d1) -> (), domain: d0 in [0, 1023], - d1 in [0, 2], - is_simplified: true + d1 in [0, 2] )")); } @@ -2014,14 +1895,12 @@ TEST_F(IndexingAnalysisTest, ReduceWindowOp_PaddingAndWindowStride) { s0 in [0, 2], s1 in [0, 1], d0 * 2 + s0 in [1, 13], - d1 + s1 in [0, 16], - is_simplified: true + d1 + s1 in [0, 16] operand id = 1 (d0, d1) -> (), domain: d0 in [0, 6], - d1 in [0, 16], - is_simplified: true + d1 in [0, 16] )")); } @@ -2048,14 +1927,12 @@ TEST_F(IndexingAnalysisTest, ReduceWindowOp_BaseDilation) { d0 in [0, 2], d1 in [0, 4], d0 mod 2 in [0, 0], - d1 mod 2 in [0, 0], - is_simplified: true + d1 mod 2 in [0, 0] operand id = 1 (d0, d1) -> (), domain: d0 in [0, 2], - d1 in [0, 4], - is_simplified: true + d1 in [0, 4] )")); } @@ -2081,14 +1958,12 @@ TEST_F(IndexingAnalysisTest, ReduceWindowOp_WindowDilation) { domain: d0 in [0, 3], d1 in [0, 2], - s0 in [0, 1], - is_simplified: true + s0 in [0, 1] operand id = 1 (d0, d1) -> (), domain: d0 in [0, 3], - d1 in [0, 2], - is_simplified: true + d1 in [0, 2] )")); } @@ -2122,28 +1997,24 @@ TEST_F(IndexingAnalysisTest, ReduceWindowOp_Variadic) { d0 in [0, 0], d1 in [0, 1], s0 in [0, 1], - s1 in [0, 1], - is_simplified: true + s1 in [0, 1] operand id = 1 (d0, d1)[s0, s1] -> (s0, d1 + s1), domain: d0 in [0, 0], d1 in [0, 1], s0 in [0, 1], - s1 in [0, 1], - is_simplified: true + s1 in [0, 1] operand id = 2 (d0, d1) -> (), domain: d0 in [0, 0], - d1 in [0, 1], - is_simplified: true + d1 in [0, 1] operand id = 3 (d0, d1) -> (), domain: d0 in [0, 0], - d1 in [0, 1], - is_simplified: true + d1 in [0, 1] )")); auto input_indexing_1 = GetOutputToInputIndexing(root, /*output_id=*/1); EXPECT_THAT(input_indexing_1.ToString(), MatchIndexingString(R"( @@ -2153,28 +2024,24 @@ TEST_F(IndexingAnalysisTest, ReduceWindowOp_Variadic) { d0 in [0, 0], d1 in [0, 1], s0 in [0, 1], - s1 in [0, 1], - is_simplified: true + s1 in [0, 1] operand id = 1 (d0, d1)[s0, s1] -> (s0, d1 + s1), domain: d0 in [0, 0], d1 in [0, 1], s0 in [0, 1], - s1 in [0, 1], - is_simplified: true + s1 in [0, 1] operand id = 2 (d0, d1) -> (), domain: d0 in [0, 0], - d1 in [0, 1], - is_simplified: true + d1 in [0, 1] operand id = 3 (d0, d1) -> (), domain: d0 in [0, 0], - d1 in [0, 1], - is_simplified: true + d1 in [0, 1] )")); } @@ -2199,8 +2066,7 @@ TEST_F(IndexingAnalysisTest, ConvolutionOp_NoPadding) { d3 in [0, 7], s0 in [0, 2], s1 in [0, 4], - s2 in [0, 3], - is_simplified: false + s2 in [0, 3] operand id = 1 (d0, d1, d2, d3)[s0, s1, s2] -> (s2, s0, s1, d3), domain: @@ -2210,8 +2076,7 @@ TEST_F(IndexingAnalysisTest, ConvolutionOp_NoPadding) { d3 in [0, 7], s0 in [0, 2], s1 in [0, 4], - s2 in [0, 3], - is_simplified: false + s2 in [0, 3] )")); } @@ -2238,8 +2103,7 @@ TEST_F(IndexingAnalysisTest, ConvolutionOp_PaddingAndWindowStride) { s1 in [0, 4], s2 in [0, 3], d1 * 2 + s0 in [1, 12], - d2 * 2 + s1 in [2, 11], - is_simplified: false + d2 * 2 + s1 in [2, 11] operand id = 1 (d0, d1, d2, d3)[s0, s1, s2] -> (s2, s0, s1, d3), domain: @@ -2249,8 +2113,7 @@ TEST_F(IndexingAnalysisTest, ConvolutionOp_PaddingAndWindowStride) { d3 in [0, 7], s0 in [0, 2], s1 in [0, 4], - s2 in [0, 3], - is_simplified: false + s2 in [0, 3] )")); } @@ -2277,8 +2140,7 @@ TEST_F(IndexingAnalysisTest, ConvolutionOp_LhsDilation) { s1 in [0, 4], s2 in [0, 3], (d1 + s0) mod 2 in [0, 0], - (d2 + s1) mod 2 in [0, 0], - is_simplified: false + (d2 + s1) mod 2 in [0, 0] operand id = 1 (d0, d1, d2, d3)[s0, s1, s2] -> (s2, s0, s1, d3), domain: @@ -2288,8 +2150,7 @@ TEST_F(IndexingAnalysisTest, ConvolutionOp_LhsDilation) { d3 in [0, 7], s0 in [0, 2], s1 in [0, 4], - s2 in [0, 3], - is_simplified: false + s2 in [0, 3] )")); } @@ -2314,8 +2175,7 @@ TEST_F(IndexingAnalysisTest, ConvolutionOp_RhsDilation) { d3 in [0, 7], s0 in [0, 2], s1 in [0, 4], - s2 in [0, 3], - is_simplified: false + s2 in [0, 3] operand id = 1 (d0, d1, d2, d3)[s0, s1, s2] -> (s2, s0, s1, d3), domain: @@ -2325,8 +2185,7 @@ TEST_F(IndexingAnalysisTest, ConvolutionOp_RhsDilation) { d3 in [0, 7], s0 in [0, 2], s1 in [0, 4], - s2 in [0, 3], - is_simplified: false + s2 in [0, 3] )")); } @@ -2351,8 +2210,7 @@ TEST_F(IndexingAnalysisTest, ConvolutionOp_FeatureGroups) { d3 in [0, 47], s0 in [0, 2], s1 in [0, 4], - s2 in [0, 3], - is_simplified: false + s2 in [0, 3] operand id = 1 (d0, d1, d2, d3)[s0, s1, s2] -> (s2, s0, s1, d3), domain: @@ -2362,8 +2220,7 @@ TEST_F(IndexingAnalysisTest, ConvolutionOp_FeatureGroups) { d3 in [0, 47], s0 in [0, 2], s1 in [0, 4], - s2 in [0, 3], - is_simplified: false + s2 in [0, 3] )")); } @@ -2389,8 +2246,7 @@ TEST_F(IndexingAnalysisTest, ConvolutionOp_BatchGroups) { s0 in [0, 2], s1 in [0, 4], s2 in [0, 3], - s3 in [0, 6], - is_simplified: false + s3 in [0, 6] operand id = 1 (d0, d1, d2, d3)[s0, s1, s2] -> (s2, s0, s1, d3), domain: @@ -2400,8 +2256,7 @@ TEST_F(IndexingAnalysisTest, ConvolutionOp_BatchGroups) { d3 in [0, 20], s0 in [0, 2], s1 in [0, 4], - s2 in [0, 3], - is_simplified: false + s2 in [0, 3] )")); } @@ -2421,8 +2276,7 @@ TEST_F(IndexingAnalysisTest, ReverseOp) { d0 in [0, 0], d1 in [0, 16], d2 in [0, 8], - d3 in [0, 8], - is_simplified: false + d3 in [0, 8] )")); auto output_indexing = GetInputToOutputIndexing(root); @@ -2433,8 +2287,7 @@ TEST_F(IndexingAnalysisTest, ReverseOp) { d0 in [0, 0], d1 in [0, 16], d2 in [0, 8], - d3 in [0, 8], - is_simplified: false + d3 in [0, 8] )")); } @@ -2459,8 +2312,7 @@ TEST_F(IndexingAnalysisTest, ReverseReshape) { (d0, d1) -> (d0, d1), domain: d0 in [0, 9], - d1 in [0, 10], - is_simplified: true + d1 in [0, 10] )")); } @@ -2480,8 +2332,7 @@ TEST_F(IndexingAnalysisTest, SliceOp) { domain: d0 in [0, 4], d1 in [0, 2], - d2 in [0, 24], - is_simplified: false + d2 in [0, 24] )")); auto output_indexing = GetInputToOutputIndexing(root); EXPECT_THAT(output_indexing.ToString(), MatchIndexingString(R"( @@ -2496,8 +2347,7 @@ TEST_F(IndexingAnalysisTest, SliceOp) { d1 in [3, 17], d2 in [0, 48], (d1 - 3) mod 7 in [0, 0], - d2 mod 2 in [0, 0], - is_simplified: false + d2 mod 2 in [0, 0] )")); } @@ -2517,8 +2367,7 @@ TEST_F(IndexingAnalysisTest, TransposeOp) { d0 in [0, 2], d1 in [0, 5], d2 in [0, 127], - d3 in [0, 12287], - is_simplified: false + d3 in [0, 12287] )")); EXPECT_THAT(GetInputToOutputIndexing(root).ToString(), MatchIndexingString(R"( operand id = 0 @@ -2527,8 +2376,7 @@ TEST_F(IndexingAnalysisTest, TransposeOp) { d0 in [0, 2], d1 in [0, 12287], d2 in [0, 5], - d3 in [0, 127], - is_simplified: false + d3 in [0, 127] )")); } @@ -2547,8 +2395,7 @@ TEST_F(IndexingAnalysisTest, TransposeOp4D) { d0 in [0, 2], d1 in [0, 5], d2 in [0, 127], - d3 in [0, 12287], - is_simplified: true + d3 in [0, 12287] )")); } @@ -2574,8 +2421,7 @@ TEST_F(IndexingAnalysisTest, DotOp) { d4 in [0, 15], d5 in [0, 21], s0 in [0, 17], - s1 in [0, 16], - is_simplified: false + s1 in [0, 16] operand id = 1 (d0, d1, d2, d3, d4, d5)[s0, s1] -> (s1, d0, d4, s0, d5, d1), domain: @@ -2586,8 +2432,7 @@ TEST_F(IndexingAnalysisTest, DotOp) { d4 in [0, 15], d5 in [0, 21], s0 in [0, 17], - s1 in [0, 16], - is_simplified: false + s1 in [0, 16] )")); } @@ -2648,8 +2493,7 @@ TEST_F(IndexingAnalysisTest, FusionWithUnsupportedOp) { (d0, d1) -> (d0 * 6, d1 * 2), domain: d0 in [0, 3], - d1 in [0, 2], - is_simplified: true + d1 in [0, 2] operand id = 1 unknown indexing operand id = 2 @@ -2686,8 +2530,7 @@ TEST_F(IndexingAnalysisTest, EpilogueIndexing) { (d0, d1) -> (d1 * 1000 + d0), domain: d0 in [0, 999], - d1 in [0, 999], - is_simplified: true + d1 in [0, 999] )")); } @@ -2716,8 +2559,7 @@ TEST_F(IndexingAnalysisTest, EpilogueIndexing_NoEpilogue) { (d0, d1) -> (d0, d1), domain: d0 in [0, 999], - d1 in [0, 999], - is_simplified: false + d1 in [0, 999] )")); } @@ -2735,18 +2577,15 @@ TEST_F(IndexingAnalysisTest, BroadcastingElementwise) { operand id = 0 (d0, d1) -> (), domain: d0 in [0, 999], - d1 in [0, 999], - is_simplified: false + d1 in [0, 999] operand id = 1 (d0, d1) -> (d0, d1), domain: d0 in [0, 999], - d1 in [0, 999], - is_simplified: false + d1 in [0, 999] operand id = 2 (d0, d1) -> (d0, d1), domain: d0 in [0, 999], - d1 in [0, 999], - is_simplified: false + d1 in [0, 999] )")); } @@ -2778,14 +2617,12 @@ TEST_F(IndexingAnalysisTest, FusionOpWithDUS) { s0 in [0, 4096], hlo: %slice = s32[1]{0} parameter(1), (d0, d1) -> (0), - d1 + s0 in [4096, 8191], - is_simplified: true + d1 + s0 in [4096, 8191] operand id = 1 (d0, d1) -> (0), domain: d0 in [0, 0], - d1 in [0, 4095], - is_simplified: true + d1 in [0, 4095] )")); } diff --git a/third_party/xla/xla/service/gpu/model/indexing_map.cc b/third_party/xla/xla/service/gpu/model/indexing_map.cc index 8e431976467734..f7ec1f1f83dd76 100644 --- a/third_party/xla/xla/service/gpu/model/indexing_map.cc +++ b/third_party/xla/xla/service/gpu/model/indexing_map.cc @@ -1001,8 +1001,7 @@ std::vector RangeVarsFromTensorSizes( IndexingMap::IndexingMap( AffineMap affine_map, std::vector dimensions, std::vector range_vars, std::vector rt_vars, - absl::Span const> constraints, - bool is_simplified) + absl::Span const> constraints) : affine_map_(affine_map), dim_vars_(std::move(dimensions)), range_vars_(std::move(range_vars)), @@ -1014,7 +1013,6 @@ IndexingMap::IndexingMap( for (const auto& [expr, range] : constraints) { AddConstraint(expr, range); } - is_simplified_ = is_simplified; } IndexingMap::IndexingMap( @@ -1034,13 +1032,10 @@ IndexingMap::IndexingMap( IndexingMap IndexingMap::FromTensorSizes( AffineMap affine_map, absl::Span dim_upper_bounds, - absl::Span symbol_upper_bounds, bool is_simplified) { - return IndexingMap{affine_map, - DimVarsFromTensorSizes(dim_upper_bounds), + absl::Span symbol_upper_bounds) { + return IndexingMap{affine_map, DimVarsFromTensorSizes(dim_upper_bounds), RangeVarsFromTensorSizes(symbol_upper_bounds), - /*rt_vars=*/{}, - /*constraints=*/{}, - is_simplified}; + /*rt_vars=*/{}}; } RangeEvaluator IndexingMap::GetRangeEvaluator() const { @@ -1052,7 +1047,6 @@ const Interval& IndexingMap::GetDimensionBound(int64_t dim_id) const { } Interval& IndexingMap::GetMutableDimensionBound(int64_t dim_id) { - is_simplified_ = false; return dim_vars_[dim_id].bounds; } @@ -1075,7 +1069,6 @@ const Interval& IndexingMap::GetSymbolBound(int64_t symbol_id) const { } Interval& IndexingMap::GetMutableSymbolBound(int64_t symbol_id) { - is_simplified_ = false; // Because affine map symbols are packed like [range_vars, rt_vars], // we have to pick the correct bounds. int64_t range_var_count = GetRangeVarsCount(); @@ -1131,7 +1124,6 @@ void IndexingMap::AddConstraint(mlir::AffineExpr expr, Interval range) { ResetToKnownEmpty(); } } - is_simplified_ = false; } void IndexingMap::EraseConstraint(mlir::AffineExpr expr) { @@ -1305,7 +1297,7 @@ bool IndexingMap::Verify(std::ostream& out) const { // simplification, because the ranges of constraints were already optimized once // when IndexingMap was constructed. bool IndexingMap::Simplify() { - if (IsSimplified() || IsUndefined() || IsKnownEmpty()) return false; + if (IsUndefined() || IsKnownEmpty()) return false; bool rtvars_were_eliminated = ReplaceConstantRTVars(); @@ -1336,7 +1328,6 @@ bool IndexingMap::Simplify() { if (affine_map_was_simplified) { affine_map_ = simplified_affine_map; } - is_simplified_ = true; return affine_map_was_simplified || constraints_were_simplified || rtvars_were_eliminated; } @@ -1639,7 +1630,6 @@ void IndexingMap::ResetToKnownEmpty() { } constraints_.clear(); is_known_empty_ = true; - is_simplified_ = true; } bool IndexingMap::VerifyVariableIntervals() { @@ -2124,8 +2114,7 @@ IndexingMap IndexingMap::ConvertSymbolsToDimensions() const { AffineMap canonical_map = affine_map_.replaceDimsAndSymbols({}, syms_replacements, num_vars, 0); IndexingMap new_indexing_map(canonical_map, new_dim_vars, /*range_vars=*/{}, - /*rt_vars=*/{}, new_constraints, - /*is_simplified=*/false); + /*rt_vars=*/{}, new_constraints); return new_indexing_map; } diff --git a/third_party/xla/xla/service/gpu/model/indexing_map.h b/third_party/xla/xla/service/gpu/model/indexing_map.h index 36780ddd1841e2..5751cb4c886d10 100644 --- a/third_party/xla/xla/service/gpu/model/indexing_map.h +++ b/third_party/xla/xla/service/gpu/model/indexing_map.h @@ -297,8 +297,7 @@ class IndexingMap { IndexingMap( mlir::AffineMap affine_map, std::vector dimensions, std::vector range_vars, std::vector rt_vars, - absl::Span const> constraints = {}, - bool is_simplified = false); + absl::Span const> constraints = {}); IndexingMap(mlir::AffineMap affine_map, std::vector dimensions, std::vector range_vars, std::vector rt_vars, @@ -314,8 +313,7 @@ class IndexingMap { static IndexingMap FromTensorSizes( mlir::AffineMap affine_map, absl::Span dim_upper_bounds, - absl::Span symbol_upper_bounds, - bool is_simplified = false); + absl::Span symbol_upper_bounds); // Returns true if the indexing map is valid. bool Verify(std::ostream& out) const; @@ -397,10 +395,6 @@ class IndexingMap { // satisfies both constraints. bool IsKnownEmpty() const { return is_known_empty_; } - // Returns true if the indexing map is simplified. - void SetIsSimplified(bool is_simplified) { is_simplified_ = is_simplified; } - bool IsSimplified() const { return is_simplified_; } - bool IsUndefined() const { return affine_map_ == mlir::AffineMap(); } // Removes unused symbols from the `affine_map_` and constraints. @@ -474,8 +468,6 @@ class IndexingMap { llvm::DenseMap constraints_; // Flag to indicate that the domain is empty. bool is_known_empty_ = false; - // Flag to indicate that the indexing map is simplified. - bool is_simplified_ = false; }; std::ostream& operator<<(std::ostream& out, const IndexingMap& indexing_map); bool operator==(const IndexingMap& lhs, const IndexingMap& rhs); diff --git a/third_party/xla/xla/service/gpu/model/indexing_map_serialization.cc b/third_party/xla/xla/service/gpu/model/indexing_map_serialization.cc index 4e72a4b56dd94f..3d6eb9bf1b1b23 100644 --- a/third_party/xla/xla/service/gpu/model/indexing_map_serialization.cc +++ b/third_party/xla/xla/service/gpu/model/indexing_map_serialization.cc @@ -293,9 +293,6 @@ Token Parser::GetNextTokenImpl() { if (spelling == "domain") { return Token{spelling, Token::Kind::kKeywordDomain}; } - if (spelling == "is_simplified") { - return Token{spelling, Token::Kind::kKeywordIsSimplified}; - } if (spelling == "in") { return Token{spelling, Token::Kind::kKeywordIn}; } @@ -599,7 +596,8 @@ std::optional ParseIndexingMap(llvm::StringRef input, if (!parser.ParseVarName(&var_name) || !parser.ConsumeToken(Token::Kind::kKeywordIn) || !parser.ParseInterval(&interval) || - !parser.ConsumeToken(Token::Kind::kComma)) { + (parser.GetCurrentToken().kind != Token::Kind::kEOF && + !parser.ConsumeToken(Token::Kind::kComma))) { llvm::errs() << "Failed to parse DimVar\n"; return std::nullopt; } @@ -617,7 +615,8 @@ std::optional ParseIndexingMap(llvm::StringRef input, if (!parser.ParseVarName(&var_name) || !parser.ConsumeToken(Token::Kind::kKeywordIn) || !parser.ParseInterval(&interval) || - !parser.ConsumeToken(Token::Kind::kComma)) { + (parser.GetCurrentToken().kind != Token::Kind::kEOF && + !parser.ConsumeToken(Token::Kind::kComma))) { llvm::errs() << "Failed to parse RangeVar\n"; return std::nullopt; } @@ -629,31 +628,20 @@ std::optional ParseIndexingMap(llvm::StringRef input, } // Parse constraints. SmallVector constraint_bounds; - while (!parser.ConsumeToken(Token::Kind::kKeywordIsSimplified)) { + while (!parser.ConsumeToken(Token::Kind::kEOF)) { std::string affine_expr_str; Interval interval; if (!parser.ParseAffineExprString(&affine_expr_str) || !parser.ConsumeToken(Token::Kind::kKeywordIn) || !parser.ParseInterval(&interval) || - !parser.ConsumeToken(Token::Kind::kComma)) { + (parser.GetCurrentToken().kind != Token::Kind::kEOF && + !parser.ConsumeToken(Token::Kind::kComma))) { llvm::errs() << "Failed to parse constraint\n"; return std::nullopt; } affine_expr_strs.push_back(affine_expr_str); constraint_bounds.push_back(interval); } - // Parse is_simplified. - bool is_simplified; - if (!parser.ConsumeToken(Token::Kind::kColon) || - !parser.ParseBool(&is_simplified)) { - llvm::errs() << "Failed to parse is_simplified\n"; - return std::nullopt; - } - // Check that the input is consumed. - if (!parser.ConsumeToken(Token::Kind::kEOF)) { - return std::nullopt; - } - // Parse affine expressions. SmallVector affine_exprs; if (!ParseAffineExprsWithMLIR(dim_var_names, symbol_var_names, @@ -674,9 +662,8 @@ std::optional ParseIndexingMap(llvm::StringRef input, } auto map = AffineMap::get(dim_vars.size(), range_vars.size(), affine_map_results, context); - return IndexingMap{ - map, std::move(dim_vars), std::move(range_vars), /*rt_vars=*/{}, - constraints, is_simplified}; + return IndexingMap{map, std::move(dim_vars), std::move(range_vars), + /*rt_vars=*/{}, constraints}; } std::string ToString(AffineExpr affine_expr) { @@ -782,18 +769,29 @@ std::string ToString(const IndexingMap& indexing_map, return ss.str(); } ss << ", domain: "; + int64_t remaining_vars_to_print = + dim_vars.size() + range_vars.size() + rt_vars.size(); for (const auto& [index, dim_var] : llvm::enumerate(dim_vars)) { - ss << dim_names[index] << " in " << dim_var.bounds << ", "; + ss << dim_names[index] << " in " << dim_var.bounds; + if (--remaining_vars_to_print > 0) { + ss << ", "; + } } for (const auto& [index, range_var] : llvm::enumerate(range_vars)) { - ss << symbol_names[index] << " in " << range_var.range << ", "; + ss << symbol_names[index] << " in " << range_var.range; + if (--remaining_vars_to_print > 0) { + ss << ", "; + } } int64_t num_range_vars = range_vars.size(); for (const auto& [index, rt_var] : llvm::enumerate(rt_vars)) { ss << GetSymbolName(num_range_vars + index, symbol_names) << " in " << rt_var.feasible_values << ", hlo: " << (rt_var.hlo == nullptr ? "NULL" : rt_var.hlo->ToString()) << ", " - << ToString(rt_var.map) << ", "; + << ToString(rt_var.map); + if (--remaining_vars_to_print > 0) { + ss << ", "; + } } std::vector expr_range_strings; const auto& constraints = indexing_map.GetConstraints(); @@ -803,10 +801,9 @@ std::string ToString(const IndexingMap& indexing_map, ToString(expr, dim_names, symbol_names), " in ", range.ToString())); } std::sort(expr_range_strings.begin(), expr_range_strings.end()); - for (const auto& expr_range_string : expr_range_strings) { - ss << expr_range_string << ", "; + if (!expr_range_strings.empty()) { + ss << ", " << absl::StrJoin(expr_range_strings, ", "); } - ss << "is_simplified: " << (indexing_map.IsSimplified() ? "true" : "false"); return ss.str(); } diff --git a/third_party/xla/xla/service/gpu/model/indexing_map_serialization_test.cc b/third_party/xla/xla/service/gpu/model/indexing_map_serialization_test.cc index 28a7f7b60b4ac8..98fff83cc277fd 100644 --- a/third_party/xla/xla/service/gpu/model/indexing_map_serialization_test.cc +++ b/third_party/xla/xla/service/gpu/model/indexing_map_serialization_test.cc @@ -48,8 +48,7 @@ TEST_F(IndexingMapSerializationTest, DimsOnly) { (d0, d1) -> (d0 mod 2 + d1), domain: d0 in [0, 3], - d1 in [-4, 4], - is_simplified: true + d1 in [-4, 4] )"); } @@ -58,8 +57,7 @@ TEST_F(IndexingMapSerializationTest, SymbolsOnly) { ()[s0, s1] -> (s0 floordiv s1), domain: s0 in [0, 3], - s1 in [0, 4], - is_simplified: true + s1 in [0, 4] )"); } @@ -71,8 +69,7 @@ TEST_F(IndexingMapSerializationTest, DimsAndSymbolsNoConstraints) { d1 in [0, 4], s0 in [0, 1], s1 in [0, 1], - s2 in [0, 3], - is_simplified: false + s2 in [0, 3] )"); } @@ -86,8 +83,7 @@ TEST_F(IndexingMapSerializationTest, DimsAndSymbolsAndConstraints) { s1 in [0, 1], s2 in [0, 3], d0 mod 4 in [0, 0], - d1 + s0 in [0, 45], - is_simplified: false + d1 + s0 in [0, 45] )"); } @@ -99,8 +95,7 @@ TEST_F(IndexingMapSerializationTest, AffineExprsWithParens) { d0 in [0, 9], d1 in [0, 19], s0 in [0, 29], - s1 in [0, 39], - is_simplified: false + s1 in [0, 39] )"); } @@ -116,8 +111,7 @@ TEST_F(IndexingMapSerializationTest, CustomNames) { reduced_dim in [0, 1], contracted_dim in [0, 3], th_x mod 4 in [0, 0], - bl_x + vector_elem in [0, 45], - is_simplified: false + bl_x + vector_elem in [0, 45] )"; auto indexing_map_golden = R"( (d0, d1)[s0, s1, s2] -> (s2, d0 + d1, s1, s0), @@ -128,8 +122,7 @@ TEST_F(IndexingMapSerializationTest, CustomNames) { s1 in [0, 1], s2 in [0, 3], d0 mod 4 in [0, 0], - d1 + s0 in [0, 45], - is_simplified: false + d1 + s0 in [0, 45] )"; auto indexing_map = ParseIndexingMap(indexing_map_str, &mlir_context_); ASSERT_TRUE(indexing_map.has_value()); diff --git a/third_party/xla/xla/service/gpu/model/indexing_map_test.cc b/third_party/xla/xla/service/gpu/model/indexing_map_test.cc index 8fd369b5b6e596..b7e45e141b6af6 100644 --- a/third_party/xla/xla/service/gpu/model/indexing_map_test.cc +++ b/third_party/xla/xla/service/gpu/model/indexing_map_test.cc @@ -136,8 +136,7 @@ TEST_F(IndexingMapTest, RTVar) { (d0, d1) -> (), rt_1 in [0, 7], hlo: NULL, - (d0, d1) -> (), - is_simplified: false + (d0, d1) -> () )")); } @@ -148,8 +147,7 @@ TEST_F(IndexingMapTest, Evaluation) { d0 in [0, 3], d1 in [0, 3], s0 in [0, 1], - s1 in [0, 1], - is_simplified: false + s1 in [0, 1] )"); auto results = indexing_map.Evaluate( mlir::getAffineConstantExprs({1, 2}, &mlir_context_), @@ -177,16 +175,14 @@ TEST_F(IndexingMapTest, Composition_Permutation) { d0 in [0, 3], d1 in [0, 3], s0 in [0, 1], - s1 in [0, 1], - is_simplified: false + s1 in [0, 1] )"); IndexingMap consumer = Parse(R"( (d0)[s0] -> (d0, s0), domain: d0 in [0, 3], - s0 in [0, 3], - is_simplified: false + s0 in [0, 3] )"); auto composed = ComposeIndexingMaps(consumer, producer); @@ -196,8 +192,7 @@ TEST_F(IndexingMapTest, Composition_Permutation) { d0 in [0, 3], s0 in [0, 1], s1 in [0, 1], - s2 in [0, 3], - is_simplified: false + s2 in [0, 3] )")); } @@ -208,16 +203,14 @@ TEST_F(IndexingMapTest, Composition_RestrictedInterval) { d0 in [0, 4], d1 in [0, 5], s0 in [0, 6], - s1 in [0, 1], - is_simplified: false + s1 in [0, 1] )"); IndexingMap consumer = Parse(R"( (d0)[s0] -> (d0, s0), domain: d0 in [0, 9], - s0 in [0, 7], - is_simplified: false + s0 in [0, 7] )"); auto composed = ComposeIndexingMaps(consumer, producer); @@ -227,8 +220,7 @@ TEST_F(IndexingMapTest, Composition_RestrictedInterval) { d0 in [0, 4], s0 in [0, 6], s1 in [0, 1], - s2 in [0, 5], - is_simplified: false + s2 in [0, 5] )")); } @@ -241,8 +233,7 @@ TEST_F(IndexingMapTest, Composition_ProducerAndConsumerHaveConstraints) { s0 in [0, 69], s1 in [0, 19], d0 mod 8 in [0, 0], - s0 mod 3 in [1, 1], - is_simplified: false + s0 mod 3 in [1, 1] )"); IndexingMap consumer = Parse(R"( @@ -251,8 +242,7 @@ TEST_F(IndexingMapTest, Composition_ProducerAndConsumerHaveConstraints) { d0 in [0, 9], s0 in [0, 7], d0 + s0 in [0, 20], - s0 mod 4 in [0, 0], - is_simplified: false + s0 mod 4 in [0, 0] )"); auto composed = ComposeIndexingMaps(consumer, producer); @@ -266,8 +256,7 @@ TEST_F(IndexingMapTest, Composition_ProducerAndConsumerHaveConstraints) { d0 + s2 in [0, 20], d0 mod 8 in [0, 0], s0 mod 3 in [1, 1], - s2 mod 4 in [0, 0], - is_simplified: false + s2 mod 4 in [0, 0] )")); EXPECT_TRUE(composed.Simplify()); EXPECT_THAT(composed, MatchIndexingMap(R"( @@ -279,8 +268,7 @@ TEST_F(IndexingMapTest, Composition_ProducerAndConsumerHaveConstraints) { s2 in [0, 4], d0 mod 8 in [0, 0], s0 mod 3 in [1, 1], - s2 mod 4 in [0, 0], - is_simplified: true + s2 mod 4 in [0, 0] )")); } @@ -319,8 +307,7 @@ TEST_F(IndexingMapTest, Composition_RTVar) { (d0, d1) -> (), rt_2 in [0, 226], hlo: NULL, - (d0, d1) -> (), - is_simplified: false + (d0, d1) -> () )")); } @@ -365,8 +352,7 @@ TEST_F(IndexingMapTest, Composition_OnlyRTVars) { hlo: NULL, (d0, d1) -> (), d0 + cs_0 * 2 in [0, 24], - d1 + cs_1 * 3 in [0, 15], - is_simplified: false + d1 + cs_1 * 3 in [0, 15] )")); } @@ -380,8 +366,7 @@ TEST_F(IndexingMapTest, RemoveUnusedVars_ConstraintUsesDim) { s0 in [0, 69], s1 in [0, 19], d0 + s0 in [1, 100], - s0 mod 3 in [0, 0], - is_simplified: false + s0 mod 3 in [0, 0] )"); indexing_map.RemoveUnusedVars(); EXPECT_THAT(indexing_map, MatchIndexingMap(R"( @@ -392,8 +377,7 @@ TEST_F(IndexingMapTest, RemoveUnusedVars_ConstraintUsesDim) { s0 in [0, 69], s1 in [0, 19], d0 + s0 in [1, 100], - s0 mod 3 in [0, 0], - is_simplified: false + s0 mod 3 in [0, 0] )")); } @@ -406,8 +390,7 @@ TEST_F(IndexingMapTest, RemoveUnusedVars_ConstraintUsesUnusedDim) { d1 in [0, 59], s0 in [0, 69], s1 in [0, 19], - d0 mod 3 in [0, 0], - is_simplified: false + d0 mod 3 in [0, 0] )"); indexing_map.RemoveUnusedVars(); EXPECT_THAT(indexing_map, MatchIndexingMap(R"( @@ -415,8 +398,7 @@ TEST_F(IndexingMapTest, RemoveUnusedVars_ConstraintUsesUnusedDim) { domain: d0 in [0, 59], s0 in [0, 69], - s1 in [0, 19], - is_simplified: false + s1 in [0, 19] )")); } @@ -429,8 +411,7 @@ TEST_F(IndexingMapTest, RemoveUnusedSymbols_ConstraintUsesOnlyUnusedSym) { d1 in [0, 59], s0 in [0, 69], s1 in [0, 19], - s0 mod 3 in [0, 0], - is_simplified: false + s0 mod 3 in [0, 0] )"); indexing_map.RemoveUnusedSymbols(); EXPECT_THAT(indexing_map, MatchIndexingMap(R"( @@ -438,8 +419,7 @@ TEST_F(IndexingMapTest, RemoveUnusedSymbols_ConstraintUsesOnlyUnusedSym) { domain: d0 in [0, 49], d1 in [0, 59], - s0 in [0, 19], - is_simplified: false + s0 in [0, 19] )")); } @@ -456,8 +436,7 @@ TEST_F(IndexingMapTest, RemoveUnusedVars_ConstraintsWithManyDims) { s1 in [0, 63], s2 in [0, 95], s0 * 4 + d1 + d3 in [24, 459], - s0 + s2 in [0, 512], - is_simplified: false + s0 + s2 in [0, 512] )"); // dimensions d0, d2, d4 and symbol s1 will be removed. auto unused_vars = indexing_map.RemoveUnusedVars(); @@ -469,8 +448,7 @@ TEST_F(IndexingMapTest, RemoveUnusedVars_ConstraintsWithManyDims) { s0 in [0, 31], s1 in [0, 95], d0 + s0 * 4 + d1 in [24, 459], - s0 + s1 in [0, 512], - is_simplified: false + s0 + s1 in [0, 512] )")); EXPECT_THAT(ConvertToSTL(unused_vars), ::testing::ElementsAreArray( @@ -486,8 +464,7 @@ TEST_F(IndexingMapTest, RemoveUnusedSymbols_ConstraintUsesSymbol) { s0 in [0, 69], s1 in [0, 19], s0 + s1 in [1, 100], - s0 mod 3 in [0, 0], - is_simplified: false + s0 mod 3 in [0, 0] )"); // This constraint cannot be removed, because it contains a "used symbol". indexing_map.RemoveUnusedSymbols(); @@ -499,8 +476,7 @@ TEST_F(IndexingMapTest, RemoveUnusedSymbols_ConstraintUsesSymbol) { s0 in [0, 69], s1 in [0, 19], s0 + s1 in [1, 100], - s0 mod 3 in [0, 0], - is_simplified: false + s0 mod 3 in [0, 0] )")); } @@ -512,8 +488,7 @@ TEST_F(IndexingMapTest, RemoveUnusedSymbols_ConstraintUsesOnlyUnusedSymbols) { d1 in [0, 59], s0 in [0, 69], s1 in [0, 19], - s0 mod 3 in [0, 0], - is_simplified: false + s0 mod 3 in [0, 0] )"); // This constraint can be removed, because it contains only the unused symbol. indexing_map.RemoveUnusedSymbols(); @@ -522,8 +497,7 @@ TEST_F(IndexingMapTest, RemoveUnusedSymbols_ConstraintUsesOnlyUnusedSymbols) { domain: d0 in [0, 49], d1 in [0, 59], - s0 in [0, 19], - is_simplified: false + s0 in [0, 19] )")); } @@ -532,14 +506,12 @@ TEST_F(IndexingMapTest, RemoveUnusedSymbols_ConstraintIsAConstantWithinRange) { (d0) -> (d0), domain: d0 in [0, 49], - 0 in [-10, 5], - is_simplified: false + 0 in [-10, 5] )"); EXPECT_THAT(indexing_map, MatchIndexingMap(R"( (d0) -> (d0), domain: - d0 in [0, 49], - is_simplified: false + d0 in [0, 49] )")); } @@ -547,8 +519,7 @@ TEST_F(IndexingMapTest, KnownEmpty_CreatingIndexingMapWithInfeasibleRange) { auto indexing_map = Parse(R"( (d0) -> (d0), domain: - d0 in [0, -2], - is_simplified: false + d0 in [0, -2] )"); EXPECT_THAT(indexing_map, MatchIndexingMap("KNOWN EMPTY")); } @@ -558,20 +529,15 @@ TEST_F(IndexingMapTest, KnownEmpty_AddingConstraintOutOfRange) { (d0) -> (d0), domain: d0 in [0, 49], - 0 in [10, 15], - is_simplified: false + 0 in [10, 15] )"); // Addition of this constraint makes the domain empty. EXPECT_THAT(indexing_map, MatchIndexingMap("KNOWN EMPTY")); } TEST_F(IndexingMapTest, KnownEmpty_Composition) { - auto indexing_map = Parse(R"( - (d0) -> (d0), domain: d0 in [0, 49], is_simplified: false - )"); - auto known_empty = Parse(R"( - (d0) -> (d0), domain: d0 in [0, -1], is_simplified: false - )"); + auto indexing_map = Parse("(d0) -> (d0), domain: d0 in [0, 49]"); + auto known_empty = Parse("(d0) -> (d0), domain: d0 in [0, -1]"); EXPECT_THAT(known_empty, MatchIndexingMap("KNOWN EMPTY")); EXPECT_THAT(indexing_map * known_empty, MatchIndexingMap("KNOWN EMPTY")); EXPECT_THAT(known_empty * indexing_map, MatchIndexingMap("KNOWN EMPTY")); @@ -588,8 +554,7 @@ TEST_F(IndexingMapTest, d1 in [0, 59], s0 in [0, 69], s1 in [0, 19], - s1 floordiv 20 in [2, 2], - is_simplified: false + s1 floordiv 20 in [2, 2] )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(indexing_map, MatchIndexingMap("KNOWN EMPTY")); @@ -605,8 +570,7 @@ TEST_F(IndexingMapTest, RemoveUnusedSymbols_ConstraintsWithManySymbols) { s2 in [0, 2], s3 in [0, 3], s4 in [0, 4], - d0 * 4 + s1 + s3 in [24, 459], - is_simplified: false + d0 * 4 + s1 + s3 in [24, 459] )"); indexing_map.RemoveUnusedSymbols(); // Symbols s0, s2, s4 will be removed and s1 and s3 will become s0 and s1. @@ -616,8 +580,7 @@ TEST_F(IndexingMapTest, RemoveUnusedSymbols_ConstraintsWithManySymbols) { d0 in [0, 31], s0 in [0, 1], s1 in [0, 3], - d0 * 4 + s0 + s1 in [24, 459], - is_simplified: false + d0 * 4 + s0 + s1 in [24, 459] )")); } @@ -644,8 +607,7 @@ TEST_F(IndexingMapTest, RemoveUnusedSymbols_ConstraintsWithRTVars) { s1 in [0, 3], hlo: NULL, (d0) -> (), - d0 * 4 + s0 + s1 in [24, 459], - is_simplified: false + d0 * 4 + s0 + s1 in [24, 459] )")); }; @@ -669,8 +631,7 @@ TEST_F(IndexingMapTest, ConvertSymbolsToDimensions) { d2 in [0, 1], d3 in [0, 3], d4 in [0, 4], - d0 * 4 + d1 + d3 * 2 in [24, 459], - is_simplified: false + d0 * 4 + d1 + d3 * 2 in [24, 459] )")); } @@ -679,16 +640,14 @@ TEST_F(IndexingMapTest, ConstraintIntervalSimplification_Sum) { (d0) -> (d0), domain: d0 in [0, 99], - d0 mod 8 + 5 in [50, 54], - is_simplified: false + d0 mod 8 + 5 in [50, 54] )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0) -> (d0), domain: d0 in [0, 99], - d0 mod 8 in [45, 49], - is_simplified: true + d0 mod 8 in [45, 49] )")); } @@ -700,8 +659,7 @@ TEST_F(IndexingMapTest, d0 in [0, 1999], s0 in [0, 1], s1 in [0, 2], - d0 * 6 + s0 * 3 + s1 in [0, 599], - is_simplified: false + d0 * 6 + s0 * 3 + s1 in [0, 599] )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( @@ -709,8 +667,7 @@ TEST_F(IndexingMapTest, domain: d0 in [0, 99], s0 in [0, 1], - s1 in [0, 2], - is_simplified: true + s1 in [0, 2] )")); } @@ -722,8 +679,7 @@ TEST_F(IndexingMapTest, d0 in [0, 1999], s0 in [0, 1], s1 in [0, 2], - d0 * 6 + s0 * 3 + s1 in [0, 598], - is_simplified: false + d0 * 6 + s0 * 3 + s1 in [0, 598] )"); EXPECT_FALSE(indexing_map.Simplify()); } @@ -734,16 +690,14 @@ TEST_F(IndexingMapTest, ConstraintIntervalSimplification_Sum_GcdGreaterOne) { domain: d0 in [0, 1999], s0 in [0, 1], - d0 * 6 + s0 * 3 in [0, 599], - is_simplified: false + d0 * 6 + s0 * 3 in [0, 599] )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0)[s0] -> (d0 * 6 + s0 * 3), domain: d0 in [0, 99], - s0 in [0, 1], - is_simplified: true + s0 in [0, 1] )")); } @@ -753,15 +707,13 @@ TEST_F(IndexingMapTest, (d0) -> (d0), domain: d0 in [0, 99], - d0 floordiv 8 in [5, 11], - is_simplified: false + d0 floordiv 8 in [5, 11] )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0) -> (d0), domain: - d0 in [40, 95], - is_simplified: true + d0 in [40, 95] )")); } @@ -772,16 +724,14 @@ TEST_F(IndexingMapTest, domain: d0 in [0, 99], s0 in [-99, 99], - s0 floordiv 3 in [-11, -5], - is_simplified: false + s0 floordiv 3 in [-11, -5] )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0)[s0] -> (d0), domain: d0 in [0, 99], - s0 in [-33, -13], - is_simplified: true + s0 in [-33, -13] )")); } @@ -792,16 +742,14 @@ TEST_F(IndexingMapTest, domain: d0 in [0, 99], s0 in [-99, 99], - s0 floordiv -3 in [-11, -5], - is_simplified: false + s0 floordiv -3 in [-11, -5] )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0)[s0] -> (d0), domain: d0 in [0, 99], - s0 in [15, 35], - is_simplified: true + s0 in [15, 35] )")); } @@ -811,15 +759,13 @@ TEST_F(IndexingMapTest, (d0) -> (d0), domain: d0 in [0, 99], - d0 * 8 in [14, 33], - is_simplified: false + d0 * 8 in [14, 33] )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0) -> (d0), domain: - d0 in [2, 4], - is_simplified: true + d0 in [2, 4] )")); } @@ -830,16 +776,14 @@ TEST_F(IndexingMapTest, domain: d0 in [0, 99], s0 in [-99, 99], - s0 * 3 in [-11, -5], - is_simplified: false + s0 * 3 in [-11, -5] )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0)[s0] -> (d0), domain: d0 in [0, 99], - s0 in [-3, -2], - is_simplified: true + s0 in [-3, -2] )")); } @@ -850,16 +794,14 @@ TEST_F(IndexingMapTest, domain: d0 in [0, 99], s0 in [-99, 99], - s0 * -3 in [-11, -5], - is_simplified: false + s0 * -3 in [-11, -5] )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0)[s0] -> (d0), domain: d0 in [0, 99], - s0 in [2, 3], - is_simplified: true + s0 in [2, 3] )")); } @@ -873,8 +815,7 @@ TEST_F(IndexingMapTest, ConstraintMerge_Mod) { d0 mod 3 in [0, 0], s0 mod 2 in [0, 0], s0 mod 3 in [0, 0], - s1 mod 5 in [1, 1], - is_simplified: false + s1 mod 5 in [1, 1] )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( @@ -885,8 +826,7 @@ TEST_F(IndexingMapTest, ConstraintMerge_Mod) { s1 in [1, 6], d0 mod 3 in [0, 0], s0 mod 6 in [0, 0], - s1 mod 5 in [1, 1], - is_simplified: true + s1 mod 5 in [1, 1] )")); } @@ -894,15 +834,13 @@ TEST_F(IndexingMapTest, AffineMapSimplification_ConstantDims) { auto indexing_map = Parse(R"( (d0) -> (d0), domain: - d0 in [5, 5], - is_simplified: false + d0 in [5, 5] )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0) -> (5), domain: - d0 in [5, 5], - is_simplified: true + d0 in [5, 5] )")); } @@ -916,8 +854,7 @@ TEST_F(IndexingMapTest, AffineMapSimplification_SumOrderRegression) { d0 in [0, 9], d1 in [0, 19], s0 in [0, 29], - s1 in [0, 39], - is_simplified: false + s1 in [0, 39] )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_FALSE(indexing_map.Simplify()); @@ -930,8 +867,7 @@ TEST_F(IndexingMapTest, AffineMapSimplification_SumOrderRegression2) { (d0)[s0] -> ((((s0 + d0) + d0) floordiv 2)), domain: d0 in [0, 9], - s0 in [0, 19], - is_simplified: false + s0 in [0, 19] )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_FALSE(indexing_map.Simplify()); @@ -942,16 +878,14 @@ TEST_F(IndexingMapTest, AffineMapSimplification_FloorDivRegression) { (d0, d1) -> (((d0 floordiv 3) * 3 + d1 floordiv 2) floordiv 6), domain: d0 in [0, 11], - d1 in [0, 5], - is_simplified: false + d1 in [0, 5] )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0, d1) -> (d0 floordiv 6), domain: d0 in [0, 11], - d1 in [0, 5], - is_simplified: true + d1 in [0, 5] )")); } @@ -959,15 +893,13 @@ TEST_F(IndexingMapTest, AffineMapSimplification_ModIsSub) { auto indexing_map = Parse(R"( (d0) -> (d0 mod 42), domain: - d0 in [53, 71], - is_simplified: false + d0 in [53, 71] )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0) -> (d0 - 42), domain: - d0 in [53, 71], - is_simplified: true + d0 in [53, 71] )")); } @@ -975,24 +907,20 @@ TEST_F(IndexingMapTest, AffineMapSimplification_ModIsAdd) { auto indexing_map = Parse(R"( (d0) -> (d0 mod 5), domain: - d0 in [-5, -1], - is_simplified: false + d0 in [-5, -1] )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0) -> (d0 + 5), domain: - d0 in [-5, -1], - is_simplified: true + d0 in [-5, -1] )")); } TEST_F(IndexingMapTest, AffineMapSimplification_ModIsNotAdd) { - auto indexing_map1 = - Parse("(d0) -> (d0 mod 5), domain: d0 in [-4, 0], is_simplified: false"); + auto indexing_map1 = Parse("(d0) -> (d0 mod 5), domain: d0 in [-4, 0]"); EXPECT_FALSE(indexing_map1.Simplify()); - auto indexing_map2 = - Parse("(d0) -> (d0 mod 5), domain: d0 in [-6, -1], is_simplified: false"); + auto indexing_map2 = Parse("(d0) -> (d0 mod 5), domain: d0 in [-6, -1]"); EXPECT_FALSE(indexing_map2.Simplify()); } @@ -1001,16 +929,14 @@ TEST_F(IndexingMapTest, AffineMapSimplification_SubIsMod) { (d0)[s0] -> (d0 - (s0 floordiv 3) * 3 + s0), domain: d0 in [0, 1], - s0 in [0, 3], - is_simplified: false + s0 in [0, 3] )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0)[s0] -> (d0 + s0 mod 3), domain: d0 in [0, 1], - s0 in [0, 3], - is_simplified: true + s0 in [0, 3] )")); } @@ -1019,16 +945,14 @@ TEST_F(IndexingMapTest, AffineMapSimplification_SubIsModMultiplied) { (d0)[s0] -> (d0 - (s0 floordiv 3) * 12 + s0 * 7), domain: d0 in [0, 1], - s0 in [0, 3], - is_simplified: false + s0 in [0, 3] )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0)[s0] -> (d0 + (s0 mod 3) * 4 + s0 * 3), domain: d0 in [0, 1], - s0 in [0, 3], - is_simplified: true + s0 in [0, 3] )")); } @@ -1037,16 +961,14 @@ TEST_F(IndexingMapTest, AffineMapSimplification_SubIsModSum) { (d0)[s0] -> (1 + d0 - ((s0 + 1) floordiv 3) * 3 + s0), domain: d0 in [0, 1], - s0 in [0, 3], - is_simplified: false + s0 in [0, 3] )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0)[s0] -> (d0 + (s0 + 1) mod 3), domain: d0 in [0, 1], - s0 in [0, 3], - is_simplified: true + s0 in [0, 3] )")); } @@ -1056,16 +978,14 @@ TEST_F(IndexingMapTest, (d0, d1) -> (d0 + d1 floordiv 16, d1 mod 16), domain: d0 in [0, 7], - d1 in [0, 15], - is_simplified: false + d1 in [0, 15] )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0, d1) -> (d0, d1), domain: d0 in [0, 7], - d1 in [0, 15], - is_simplified: true + d1 in [0, 15] )")); } @@ -1077,8 +997,7 @@ TEST_F(IndexingMapTest, AffineMapSimplification_DivsAndModsWithMultipliers) { domain: d0 in [0, 8], d1 in [0, 8], - d2 in [0, 8], - is_simplified: false + d2 in [0, 8] )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( @@ -1086,8 +1005,7 @@ TEST_F(IndexingMapTest, AffineMapSimplification_DivsAndModsWithMultipliers) { domain: d0 in [0, 8], d1 in [0, 8], - d2 in [0, 8], - is_simplified: true + d2 in [0, 8] )")); } @@ -1099,8 +1017,7 @@ TEST_F(IndexingMapTest, domain: d0 in [0, 9], d1 in [0, 9], - d2 in [0, 9], - is_simplified: false + d2 in [0, 9] )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( @@ -1109,8 +1026,7 @@ TEST_F(IndexingMapTest, domain: d0 in [0, 9], d1 in [0, 9], - d2 in [0, 9], - is_simplified: true + d2 in [0, 9] )")); } @@ -1120,16 +1036,14 @@ TEST_F(IndexingMapTest, AffineMapSimplification_DivsAndModsWithReverse) { d0 * 11 + d1 + ((d0 * -11 - d1 + 109) floordiv 11) * 11 - 99), domain: d0 in [0, 7], - d1 in [0, 8], - is_simplified: false + d1 in [0, 8] )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0, d1) -> (d0, d1), domain: d0 in [0, 7], - d1 in [0, 8], - is_simplified: true + d1 in [0, 8] )")); } @@ -1137,15 +1051,13 @@ TEST_F(IndexingMapTest, AffineMapSimplification_SimplifyReshape) { auto indexing_map = Parse(R"( ()[s0] -> ((s0 * 128) mod 715 + ((s0 * 128) floordiv 715) * 715), domain: - s0 in [0, 127], - is_simplified: false + s0 in [0, 127] )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( ()[s0] -> (s0 * 128), domain: - s0 in [0, 127], - is_simplified: true + s0 in [0, 127] )")); } @@ -1154,8 +1066,7 @@ TEST_F(IndexingMapTest, AffineMapSimplification_SimplifyReshape2) { (d0, d1) -> ((d0 mod 8) * 128 + d1 + (d0 floordiv 8) * 1024), domain: d0 in [0, 1023], - d1 in [0, 127], - is_simplified: false + d1 in [0, 127] )"); ; EXPECT_TRUE(indexing_map.Simplify()); @@ -1163,8 +1074,7 @@ TEST_F(IndexingMapTest, AffineMapSimplification_SimplifyReshape2) { (d0, d1) -> (d0 * 128 + d1), domain: d0 in [0, 1023], - d1 in [0, 127], - is_simplified: true + d1 in [0, 127] )")); } @@ -1174,16 +1084,14 @@ TEST_F(IndexingMapTest, AffineMapSimplification_SimplifyReshape3) { + ((d1 * 128 + d0) floordiv 192) * 768), domain: d0 in [0, 127], - d1 in [0, 3071], - is_simplified: false + d1 in [0, 3071] )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0, d1) -> (d0 * 4 + d1 * 512), domain: d0 in [0, 127], - d1 in [0, 3071], - is_simplified: true + d1 in [0, 3071] )")); } @@ -1192,15 +1100,13 @@ TEST_F(IndexingMapTest, auto indexing_map = Parse(R"( (d0) -> ((-d0) mod 2), domain: - d0 in [0, 127], - is_simplified: false + d0 in [0, 127] )"); EXPECT_FALSE(indexing_map.Simplify()); EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0) -> ((-d0) mod 2), domain: - d0 in [0, 127], - is_simplified: true + d0 in [0, 127] )")); } @@ -1215,16 +1121,14 @@ TEST_F(IndexingMapTest, AffineMapSimplification_SimplifyBitcastAndBack) { + ((d0 * 2 + d1 floordiv 64) mod 3) * 256 + (d1 mod 64) * 4), domain: d0 in [0, 3071], - d1 in [0, 127], - is_simplified: false + d1 in [0, 127] )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0, d1) -> (d0 * 512 + d1 * 4), domain: d0 in [0, 3071], - d1 in [0, 127], - is_simplified: true + d1 in [0, 127] )")); } @@ -1233,15 +1137,13 @@ TEST_F(IndexingMapTest, AffineMapSimplification_SimplifyReshape_Regression) { auto indexing_map = Parse(R"( ()[s0] -> ((s0 * 128) mod 715 + ((s0 * 64) floordiv 715) * 715), domain: - s0 in [0, 127], - is_simplified: false + s0 in [0, 127] )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( ()[s0] -> (((s0 * 64) floordiv 715) * 715 + (s0 * 128) mod 715), domain: - s0 in [0, 127], - is_simplified: true + s0 in [0, 127] )")); } @@ -1249,15 +1151,13 @@ TEST_F(IndexingMapTest, AffineMapSimplification_DivsInSequence) { auto indexing_map = Parse(R"( ()[s0] -> (s0 - ((s0 floordiv 2) floordiv 7) * 14 + (s0 floordiv 14) * 14), domain: - s0 in [0, 1233], - is_simplified: false + s0 in [0, 1233] )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( ()[s0] -> (s0), domain: - s0 in [0, 1233], - is_simplified: true + s0 in [0, 1233] )")); } @@ -1266,16 +1166,14 @@ TEST_F(IndexingMapTest, AffineMapSimplification_DivDiv) { ()[s0, s1] -> ((s0 * 2 + s1 floordiv 64) floordiv 3), domain: s0 in [0, 1233], - s1 in [0, 127], - is_simplified: false + s1 in [0, 127] )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( ()[s0, s1] -> ((s0 * 128 + s1) floordiv 192), domain: s0 in [0, 1233], - s1 in [0, 127], - is_simplified: true + s1 in [0, 127] )")); } @@ -1283,15 +1181,13 @@ TEST_F(IndexingMapTest, AffineMapSimplification_DivSumConstant) { auto indexing_map = Parse(R"( ()[s0] -> ((s0 * 6 + 9) floordiv 18), domain: - s0 in [0, 1233], - is_simplified: false + s0 in [0, 1233] )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( ()[s0] -> ((s0 * 2 + 3) floordiv 6), domain: - s0 in [0, 1233], - is_simplified: true + s0 in [0, 1233] )")); } @@ -1300,8 +1196,7 @@ TEST_F(IndexingMapTest, AffineMapSimplification_DivSumDiv) { ()[s0, s1] -> ((s0 floordiv 3 + s1 floordiv 3) floordiv 6), domain: s0 in [0, 1233], - s1 in [0, 127], - is_simplified: false + s1 in [0, 127] )"); // The rewrite tested in AffineMapSimplification_DivDiv must not trigger here. EXPECT_FALSE(indexing_map.Simplify()); @@ -1314,8 +1209,7 @@ TEST_F(IndexingMapTest, AffineMapSimplification_NegativeDiv) { auto indexing_map = Parse(R"( ()[s0] -> ((s0 floordiv 2) floordiv -7), domain: - s0 in [0, 1233], - is_simplified: false + s0 in [0, 1233] )"); EXPECT_FALSE(indexing_map.Simplify()); } @@ -1327,8 +1221,7 @@ TEST_F(IndexingMapTest, AffineMapSimplification_ExtractFromMod) { s0 in [0, 871], s1 in [0, 3], s2 in [0, 127], - s3 in [0, 895], - is_simplified: false + s3 in [0, 895] )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( @@ -1339,8 +1232,7 @@ TEST_F(IndexingMapTest, AffineMapSimplification_ExtractFromMod) { s0 in [0, 871], s1 in [0, 3], s2 in [0, 127], - s3 in [0, 895], - is_simplified: true + s3 in [0, 895] )")); } @@ -1351,8 +1243,7 @@ TEST_F(IndexingMapTest, floordiv 4), domain: s0 in [0, 1], - s1 in [0, 127], - is_simplified: false + s1 in [0, 127] )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( @@ -1361,8 +1252,7 @@ TEST_F(IndexingMapTest, ), domain: s0 in [0, 1], - s1 in [0, 127], - is_simplified: true + s1 in [0, 127] )")); } @@ -1374,8 +1264,7 @@ TEST_F(IndexingMapTest, RescaleSymbols_Simple) { s0 in [0, 6], s1 in [0, 1], s2 in [0, 5], - s0 mod 6 in [0, 0], - is_simplified: false + s0 mod 6 in [0, 0] )"); EXPECT_TRUE(indexing_map.RescaleSymbols()); EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( @@ -1384,8 +1273,7 @@ TEST_F(IndexingMapTest, RescaleSymbols_Simple) { d0 in [0, 3], s0 in [0, 1], s1 in [0, 1], - s2 in [0, 5], - is_simplified: false + s2 in [0, 5] )")); } @@ -1397,8 +1285,7 @@ TEST_F(IndexingMapTest, RescaleSymbols_WithShift) { s0 in [0, 41], s1 in [0, 1], s2 in [0, 5], - s0 mod 6 in [3, 3], - is_simplified: false + s0 mod 6 in [3, 3] )"); // [BEFORE] Allowed values for s0: 3, 9, 15, ..., 39 = (6 * 6 + 3) // [AFTER] Allowed values for s0: 0, 1, 2, ..., 6 @@ -1409,8 +1296,7 @@ TEST_F(IndexingMapTest, RescaleSymbols_WithShift) { d0 in [0, 3], s0 in [0, 6], s1 in [0, 1], - s2 in [0, 5], - is_simplified: false + s2 in [0, 5] )")); } @@ -1423,8 +1309,7 @@ TEST_F(IndexingMapTest, RescaleSymbols_TwoModConstraints) { s1 in [0, 1], s2 in [0, 5], s0 mod 2 in [0, 0], - s0 mod 3 in [0, 0], - is_simplified: false + s0 mod 3 in [0, 0] )"); EXPECT_TRUE(indexing_map.RescaleSymbols()); EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( @@ -1433,8 +1318,7 @@ TEST_F(IndexingMapTest, RescaleSymbols_TwoModConstraints) { d0 in [0, 3], s0 in [0, 1], s1 in [0, 1], - s2 in [0, 5], - is_simplified: false + s2 in [0, 5] )")); } @@ -1447,8 +1331,7 @@ TEST_F(IndexingMapTest, RescaleSymbols_RescaledSymbolInOtherNonModConstraint) { s1 in [0, 1], s2 in [0, 5], s0 * s2 in [0, 28], - s0 mod 6 in [3, 3], - is_simplified: false + s0 mod 6 in [3, 3] )"); EXPECT_TRUE(indexing_map.RescaleSymbols()); EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( @@ -1458,8 +1341,7 @@ TEST_F(IndexingMapTest, RescaleSymbols_RescaledSymbolInOtherNonModConstraint) { s0 in [0, 1], s1 in [0, 1], s2 in [0, 5], - (s0 * 6 + 3) * s2 in [0, 28], - is_simplified: false + (s0 * 6 + 3) * s2 in [0, 28] )")); } @@ -1473,8 +1355,7 @@ TEST_F(IndexingMapTest, s1 in [0, 1], s2 in [0, 5], s0 mod 6 in [3, 3], - s0 mod 7 in [5, 5], - is_simplified: false + s0 mod 7 in [5, 5] )"); EXPECT_TRUE(indexing_map.RescaleSymbols()); @@ -1510,8 +1391,7 @@ TEST_F(IndexingMapTest, RescaleSymbolsKeepsHashmapConsistent) { s1 in [0, 1], s2 in [0, 5], s0 mod 6 in [0, 0], - s0 * s1 in [0, 100], - is_simplified: false + s0 * s1 in [0, 100] )"); EXPECT_TRUE(indexing_map.RescaleSymbols()); @@ -1528,8 +1408,7 @@ TEST_F(IndexingMapTest, RangeEvaluatorTest) { d0 in [0, 9], d1 in [-10, -1], d2 in [-1, 2], - d3 in [0, 0], - is_simplified: false + d3 in [0, 0] )"); RangeEvaluator range_evaluator(indexing_map, &mlir_context_); mlir::AffineExpr d0, d1, d2, d3; @@ -1784,8 +1663,7 @@ TEST_F(IndexingMapTest, ReplaceConstantRTVars_Iota) { EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0) -> (d0, d0), domain: - d0 in [0, 255], - is_simplified: true + d0 in [0, 255] )")); } @@ -1815,8 +1693,7 @@ TEST_F(IndexingMapTest, ReplaceConstantRTVars_IotaAsConstant) { EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0) -> (d0, 7), domain: - d0 in [0, 255], - is_simplified: true + d0 in [0, 255] )")); } @@ -1849,8 +1726,7 @@ TEST_F(IndexingMapTest, ReplaceConstantRTVars_ConstraintsGetUpdated) { (d0) -> (d0, d0), domain: d0 in [0, 254], - d0 mod 2 in [0, 0], - is_simplified: true + d0 mod 2 in [0, 0] )")); } @@ -1883,8 +1759,7 @@ TEST_F(IndexingMapTest, ReplaceConstantRTVars_Broadcast) { EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0) -> (d0, 11), domain: - d0 in [0, 31], - is_simplified: true + d0 in [0, 31] )")); } @@ -1926,8 +1801,7 @@ TEST_F(IndexingMapTest, ReplaceConstantRTVars_ChainedNoncomputeOps) { EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0) -> (d0, (d0 floordiv 12) * -4 + 8), domain: - d0 in [0, 35], - is_simplified: true + d0 in [0, 35] )")); } @@ -1963,8 +1837,7 @@ TEST_F(IndexingMapTest, ReplaceConstantRTVars_PartialRTVarRemoval) { d0 in [0, 23], s0 in [0, 512], hlo: %constant = s64[12]{0} constant({...}), - (d0) -> (d0 floordiv 2), - is_simplified: true + (d0) -> (d0 floordiv 2) )")); } @@ -1999,8 +1872,7 @@ TEST_F(IndexingMapTest, ReplaceConstantRTVars_Add) { EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0) -> (d0, d0 * 2 + 42), domain: - d0 in [0, 11], - is_simplified: true + d0 in [0, 11] )")); } @@ -2040,8 +1912,7 @@ TEST_F(IndexingMapTest, ReplaceConstantRTVars_Multiply) { EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0) -> (d0, (-d0 + 11) * d0), domain: - d0 in [0, 11], - is_simplified: true + d0 in [0, 11] )")); } @@ -2080,8 +1951,7 @@ TEST_F(IndexingMapTest, ReplaceConstantRTVars_PartiallyOptimizableAdd) { d0 in [0, 11], s0 in [0, 11], hlo: %constant = s64[12]{0} constant({...}), - (d0) -> (d0), - is_simplified: true + (d0) -> (d0) )")); } @@ -2173,8 +2043,7 @@ TEST_F(IndexingMapTest, IndexingMapSupportsAbslHashAndEqAndNe) { d0 in [0, 49], d1 in [0, 59], s0 in [0, 69], - s1 in [0, 79], - is_simplified: false + s1 in [0, 79] )"), Parse(R"( (d0, d1)[s0, s1] -> (d1 * 2, d0, s1, s0), @@ -2182,8 +2051,7 @@ TEST_F(IndexingMapTest, IndexingMapSupportsAbslHashAndEqAndNe) { d0 in [0, 49], d1 in [0, 59], s0 in [0, 69], - s1 in [0, 79], - is_simplified: false + s1 in [0, 79] )"), Parse(R"( (d0, d1)[s0, s1] -> (d1, d0, s1, s0), @@ -2191,8 +2059,7 @@ TEST_F(IndexingMapTest, IndexingMapSupportsAbslHashAndEqAndNe) { d0 in [0, 50], d1 in [0, 59], s0 in [0, 69], - s1 in [0, 79], - is_simplified: false + s1 in [0, 79] )"), Parse(R"( (d0, d1)[s0, s1] -> (d1, d0, s1, s0), @@ -2200,8 +2067,7 @@ TEST_F(IndexingMapTest, IndexingMapSupportsAbslHashAndEqAndNe) { d0 in [0, 49], d1 in [0, 59], s0 in [0, 69], - s1 in [0, 79], - is_simplified: false + s1 in [0, 79] )"), Parse(R"( (d0, d1)[s0, s1] -> (d1, d0, s1, s0), @@ -2211,8 +2077,7 @@ TEST_F(IndexingMapTest, IndexingMapSupportsAbslHashAndEqAndNe) { s0 in [0, 69], s1 in [0, 79], d0 mod 8 in [0, 0], - d0 mod 16 in [0, 0], - is_simplified: false + d0 mod 16 in [0, 0] )"), Parse(R"( (d0, d1)[s0, s1] -> (d1, d0, s1, s0), @@ -2222,8 +2087,7 @@ TEST_F(IndexingMapTest, IndexingMapSupportsAbslHashAndEqAndNe) { s0 in [0, 69], s1 in [0, 79], d0 mod 8 in [0, 0], - d0 mod 32 in [0, 0], - is_simplified: false + d0 mod 32 in [0, 0] )"), IndexingMap( ParseAffineMap("(d0)[s0, s1, s2, s3, s4] -> (d0 * 4 + s1 + s3 - 42)", diff --git a/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis_test.cc b/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis_test.cc index db62388f89f099..d9607223fae319 100644 --- a/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis_test.cc +++ b/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis_test.cc @@ -169,8 +169,7 @@ ENTRY main { (d0, d1) -> (d0, d1 * 10), domain: d0 in [0, 1], - d1 in [0, 9], - is_simplified: true + d1 in [0, 9] )")); auto p0_from_subtract0 = root->operand(0); @@ -183,8 +182,7 @@ ENTRY main { (d0, d1) -> (d0, d1 * 10), domain: d0 in [0, 1], - d1 in [0, 9], - is_simplified: true + d1 in [0, 9] )")); EXPECT_THAT(*p0_from_subtract1, MatchTiledHloInstruction( @@ -194,8 +192,7 @@ ENTRY main { (d0, d1) -> (d0, 0), domain: d0 in [0, 1], - d1 in [0, 9], - is_simplified: true + d1 in [0, 9] )")); } @@ -287,8 +284,7 @@ ENTRY main { (d0, d1) -> (d0, 0), domain: d0 in [0, 1], - d1 in [0, 0], - is_simplified: true + d1 in [0, 0] )")); } @@ -322,8 +318,7 @@ ENTRY main { domain: d0 in [0, 1], d1 in [0, 1], - d2 in [0, 7], - is_simplified: true + d2 in [0, 7] )")); EXPECT_THAT(*root->operand(0), @@ -334,8 +329,7 @@ ENTRY main { domain: d0 in [0, 1], d1 in [0, 1], - d2 in [0, 7], - is_simplified: true + d2 in [0, 7] )")); } @@ -372,8 +366,7 @@ ENTRY main { (d0, d1) -> (d0 * 2, d1 * 2), domain: d0 in [0, 1], - d1 in [0, 3], - is_simplified: true + d1 in [0, 3] )")); EXPECT_THAT(*p0_from_slice0, @@ -383,8 +376,7 @@ ENTRY main { (d0, d1) -> (d0 * 2, d1 * 2 + 2), domain: d0 in [0, 1], - d1 in [0, 3], - is_simplified: true + d1 in [0, 3] )")); EXPECT_THAT(*p0_from_slice1, @@ -394,8 +386,7 @@ ENTRY main { (d0, d1) -> (d0 * 2 + 3, d1 * 2 + 4), domain: d0 in [0, 1], - d1 in [0, 3], - is_simplified: true + d1 in [0, 3] )")); } @@ -430,8 +421,7 @@ ENTRY main { (d0, d1) -> (d0 * 2, d1 * 2), domain: d0 in [0, 1], - d1 in [0, 7], - is_simplified: true + d1 in [0, 7] )")); const TiledHloInstruction* lhs = dot->operand(0); @@ -441,8 +431,7 @@ ENTRY main { (d0, d1) -> (d0 * 2, 0), domain: d0 in [0, 1], - d1 in [0, 7], - is_simplified: true + d1 in [0, 7] )")); const TiledHloInstruction* rhs = dot->operand(1); @@ -452,8 +441,7 @@ ENTRY main { (d0, d1) -> (0, d1 * 2), domain: d0 in [0, 1], - d1 in [0, 7], - is_simplified: true + d1 in [0, 7] )")); } @@ -911,8 +899,7 @@ ENTRY main { (d0, d1) -> (d0, d1), domain: d0 in [0, 65537], - d1 in [0, 32767], - is_simplified: true + d1 in [0, 32767] )")); } @@ -967,8 +954,7 @@ ENTRY main { (d0, d1) -> (0, d1, 0), domain: d0 in [0, 0], - d1 in [0, 1], - is_simplified: true + d1 in [0, 1] )")); EXPECT_THAT(*param_0_tile, MatchTiledHloInstruction( @@ -984,8 +970,7 @@ ENTRY main { (d0, d1, d2) -> (), s1 in [0, 226], hlo: %of3 = s32[] parameter(3), - (d0, d1, d2) -> (), - is_simplified: true + (d0, d1, d2) -> () )")); } From 8d5add834893f008310afc66c4ad96b4db7bc4fc Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 27 Sep 2024 13:18:25 -0700 Subject: [PATCH 382/483] Add a missing include PiperOrigin-RevId: 679705653 --- tensorflow/compiler/mlir/lite/kernels/internal/runtime_shape.h | 1 + tensorflow/lite/kernels/internal/runtime_shape.h | 1 + 2 files changed, 2 insertions(+) diff --git a/tensorflow/compiler/mlir/lite/kernels/internal/runtime_shape.h b/tensorflow/compiler/mlir/lite/kernels/internal/runtime_shape.h index 1c3f6ce789dcc4..3a602ba9c51687 100644 --- a/tensorflow/compiler/mlir/lite/kernels/internal/runtime_shape.h +++ b/tensorflow/compiler/mlir/lite/kernels/internal/runtime_shape.h @@ -20,6 +20,7 @@ limitations under the License. // LINT.IfChange #include +#include #include #include #include diff --git a/tensorflow/lite/kernels/internal/runtime_shape.h b/tensorflow/lite/kernels/internal/runtime_shape.h index e266bb85477ad6..8982cb1732f018 100644 --- a/tensorflow/lite/kernels/internal/runtime_shape.h +++ b/tensorflow/lite/kernels/internal/runtime_shape.h @@ -19,6 +19,7 @@ limitations under the License. // LINT.IfChange #include +#include #include #include #include From 92e621342b5bec9eccddd113ec8427dff0a174f3 Mon Sep 17 00:00:00 2001 From: Gregory Pataky Date: Fri, 27 Sep 2024 13:37:39 -0700 Subject: [PATCH 383/483] Add FP8 support to the exhaustive tests Adds new tests for two FP8 variants to both unary and binary exhaustive tests. PiperOrigin-RevId: 679712240 --- third_party/xla/xla/tests/exhaustive/BUILD | 1 + .../exhaustive_binary_test_definitions.inc | 114 ++++++++++++++- ...ary_test_f16_and_smaller_instantiation.inc | 18 ++- ...haustive_binary_test_f32_instantiation.inc | 4 + ...haustive_binary_test_f64_instantiation.inc | 4 + .../exhaustive_binary_test_functions.cc | 24 +++- .../exhaustive/exhaustive_op_test_base.cc | 5 + .../exhaustive/exhaustive_op_test_utils.cc | 4 + .../exhaustive/exhaustive_op_test_utils.h | 33 ++++- .../exhaustive_unary_test_definitions.inc | 32 ++++- ...ary_test_f32_and_smaller_instantiation.inc | 16 ++- ...xhaustive_unary_test_f64_instantiation.inc | 4 + .../exhaustive_unary_test_functions.cc | 135 ++++++++++++++++-- 13 files changed, 366 insertions(+), 28 deletions(-) diff --git a/third_party/xla/xla/tests/exhaustive/BUILD b/third_party/xla/xla/tests/exhaustive/BUILD index 412e0af5c89a49..7acbcd74c249c8 100644 --- a/third_party/xla/xla/tests/exhaustive/BUILD +++ b/third_party/xla/xla/tests/exhaustive/BUILD @@ -142,6 +142,7 @@ exhaustive_xla_test( ":exhaustive_op_test_utils", ":exhaustive_unary_test_textual_hdrs", "//xla:literal", + "//xla:types", "//xla/client:xla_builder", "//xla/client/lib:constants", "//xla/client/lib:math", diff --git a/third_party/xla/xla/tests/exhaustive/exhaustive_binary_test_definitions.inc b/third_party/xla/xla/tests/exhaustive/exhaustive_binary_test_definitions.inc index 8fe0a71d45277d..e4feb10c9918cd 100644 --- a/third_party/xla/xla/tests/exhaustive/exhaustive_binary_test_definitions.inc +++ b/third_party/xla/xla/tests/exhaustive/exhaustive_binary_test_definitions.inc @@ -13,6 +13,92 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +// Exhaustive test for binary operations for 8-bit floating point types, +// including float16 and bfloat. +// +// Test parameter is a pair of (begin, end) for range under test. +template +class Exhaustive8BitBinaryTest + : public ExhaustiveBinaryTest, + public ::testing::WithParamInterface> { + public: + int64_t GetInputSize() override { + int64_t begin, end; + std::tie(begin, end) = GetParam(); + return end - begin; + } + + // Given a range of uint64_t representation, uses bits 7..0 and bits 15..8 + // for the values of src0 and src1 (see below for ordering) for the 8-bit + // binary operation being tested, and generate the cartesian product of the + // two sets as the two inputs for the test. + // + // If `kLeftToRightPacking == true`, then bits 15..8 are interpreted as src0 + // and bits 7..0 are interpreted as src1. If `kLeftToRightPacking == false`, + // then bits 15..8 are interpreted as src1 and 7..0 are interpreted as src0. + void FillInput(std::array* input_literals) override { + int64_t input_size = GetInputSize(); + CHECK_EQ(input_size, (*input_literals)[0].element_count()); + CHECK_EQ(input_size, (*input_literals)[1].element_count()); + + int64_t begin, end; + std::tie(begin, end) = GetParam(); + + if (VLOG_IS_ON(2)) { + uint8_t left_begin, left_end, right_begin, right_end; + if constexpr (kLeftToRightPacking) { + left_begin = std::bit_cast(static_cast(begin >> 8)); + left_end = std::bit_cast(static_cast(end >> 8)); + right_begin = std::bit_cast(static_cast(begin)); + right_end = std::bit_cast(static_cast(end)); + } else { + left_begin = std::bit_cast(static_cast(begin)); + left_end = std::bit_cast(static_cast(end)); + right_begin = std::bit_cast(static_cast(begin >> 8)); + right_end = std::bit_cast(static_cast(end >> 8)); + } + + LOG(INFO) << this->SuiteName() << this->TestName() << " Range:"; + // N.B.: Cast to u32 to avoid printing values as char. + LOG(INFO) << "\tfrom=(" << static_cast(left_begin) << ", " + << static_cast(right_begin) << "); hex=(" << std::hex + << static_cast(left_begin) << ", " + << static_cast(right_begin) << "); float=(" + << std::bit_cast(left_begin) << ", " + << std::bit_cast(right_begin) + << ") (inclusive)"; + LOG(INFO) << "\tto=(" << static_cast(left_end) << ", " + << static_cast(right_end) << "); hex=(" << std::hex + << static_cast(left_end) << ", " + << static_cast(right_end) << "); float=(" + << std::bit_cast(left_end) << ", " + << std::bit_cast(right_end) + << ") (exclusive)"; + LOG(INFO) << "\ttotal values to test=" << (end - begin); + } + + absl::Span input_arr_0 = (*input_literals)[0].data(); + absl::Span input_arr_1 = (*input_literals)[1].data(); + for (int64_t i = 0; i < input_size; i++) { + uint32_t input_val = i + begin; + // Convert the packed bits to a pair of NativeT and replace known + // incorrect input values with 0. + // + // In either case, we only use 16 bits out of the 64 bits possible. + if constexpr (kLeftToRightPacking) { + input_arr_0[i] = this->ConvertValue(input_val >> 8); + input_arr_1[i] = this->ConvertValue(input_val); + } else { + input_arr_0[i] = this->ConvertValue(input_val); + input_arr_1[i] = this->ConvertValue(input_val >> 8); + } + } + } + + protected: + using typename ExhaustiveBinaryTest::NativeT; +}; + // Exhaustive test for binary operations for 16 bit floating point types, // including float16 and bfloat. // @@ -147,11 +233,29 @@ class Exhaustive32BitOrMoreBinaryTest } }; +using ExhaustiveF8E4M3FNBinaryTest = Exhaustive8BitBinaryTest; +using ExhaustiveF8E5M2BinaryTest = Exhaustive8BitBinaryTest; using ExhaustiveF16BinaryTest = Exhaustive16BitBinaryTest; using ExhaustiveBF16BinaryTest = Exhaustive16BitBinaryTest; using ExhaustiveF32BinaryTest = Exhaustive32BitOrMoreBinaryTest; using ExhaustiveF64BinaryTest = Exhaustive32BitOrMoreBinaryTest; +#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_F8E4M3FN) +#define BINARY_TEST_F8E4M3FN(test_name, ...) \ + XLA_TEST_P(ExhaustiveF8E4M3FNBinaryTest, test_name) \ + __VA_ARGS__ +#else +#define BINARY_TEST_E4M3FN(test_name, ...) +#endif + +#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_F8E5M2) +#define BINARY_TEST_F8E5M2(test_name, ...) \ + XLA_TEST_P(ExhaustiveF8E5M2BinaryTest, test_name) \ + __VA_ARGS__ +#else +#define BINARY_TEST_E5M2(test_name, ...) +#endif + #if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16) #define BINARY_TEST_F16(test_name, ...) \ XLA_TEST_P(ExhaustiveF16BinaryTest, test_name) \ @@ -180,10 +284,12 @@ using ExhaustiveF64BinaryTest = Exhaustive32BitOrMoreBinaryTest; #define BINARY_TEST_F64(test_name, ...) #endif -#define BINARY_TEST(test_name, ...) \ - BINARY_TEST_F16(test_name, __VA_ARGS__) \ - BINARY_TEST_BF16(test_name, __VA_ARGS__) \ - BINARY_TEST_F32(test_name, __VA_ARGS__) \ +#define BINARY_TEST(test_name, ...) \ + BINARY_TEST_F8E4M3FN(test_name, __VA_ARGS__) \ + BINARY_TEST_F8E5M2(test_name, __VA_ARGS__) \ + BINARY_TEST_F16(test_name, __VA_ARGS__) \ + BINARY_TEST_BF16(test_name, __VA_ARGS__) \ + BINARY_TEST_F32(test_name, __VA_ARGS__) \ BINARY_TEST_F64(test_name, __VA_ARGS__) #define BINARY_TEST_COMPLEX(test_name, ...) \ diff --git a/third_party/xla/xla/tests/exhaustive/exhaustive_binary_test_f16_and_smaller_instantiation.inc b/third_party/xla/xla/tests/exhaustive/exhaustive_binary_test_f16_and_smaller_instantiation.inc index 1e88061028a65d..04456e6f3a8eaa 100644 --- a/third_party/xla/xla/tests/exhaustive/exhaustive_binary_test_f16_and_smaller_instantiation.inc +++ b/third_party/xla/xla/tests/exhaustive/exhaustive_binary_test_f16_and_smaller_instantiation.inc @@ -13,16 +13,30 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_F8E4M3FN) +INSTANTIATE_TEST_SUITE_P(F8E4M3FN, ExhaustiveF8E4M3FNBinaryTest, + ::testing::ValuesIn(CreateExhaustiveU16Ranges())); +#else +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF8E4M3FNBinaryTest); +#endif + +#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_F8E5M2) +INSTANTIATE_TEST_SUITE_P(F8E5M2, ExhaustiveF8E5M2BinaryTest, + ::testing::ValuesIn(CreateExhaustiveU16Ranges())); +#else +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF8E5M2BinaryTest); +#endif + #if defined(XLA_BACKEND_SUPPORTS_BFLOAT16) INSTANTIATE_TEST_SUITE_P(BF16, ExhaustiveBF16BinaryTest, - ::testing::ValuesIn(CreateExhaustiveF32Ranges())); + ::testing::ValuesIn(CreateExhaustiveU32Ranges())); #else GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveBF16BinaryTest); #endif #if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16) INSTANTIATE_TEST_SUITE_P(F16, ExhaustiveF16BinaryTest, - ::testing::ValuesIn(CreateExhaustiveF32Ranges())); + ::testing::ValuesIn(CreateExhaustiveU32Ranges())); #else GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF16BinaryTest); #endif diff --git a/third_party/xla/xla/tests/exhaustive/exhaustive_binary_test_f32_instantiation.inc b/third_party/xla/xla/tests/exhaustive/exhaustive_binary_test_f32_instantiation.inc index 1c8e97d1d5d41e..ba62061d7437fc 100644 --- a/third_party/xla/xla/tests/exhaustive/exhaustive_binary_test_f32_instantiation.inc +++ b/third_party/xla/xla/tests/exhaustive/exhaustive_binary_test_f32_instantiation.inc @@ -13,6 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF8E4M3FNBinaryTest); + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF8E5M2BinaryTest); + GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveBF16BinaryTest); GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF16BinaryTest); diff --git a/third_party/xla/xla/tests/exhaustive/exhaustive_binary_test_f64_instantiation.inc b/third_party/xla/xla/tests/exhaustive/exhaustive_binary_test_f64_instantiation.inc index 0de1e1242d6b7c..a91f93ee155d45 100644 --- a/third_party/xla/xla/tests/exhaustive/exhaustive_binary_test_f64_instantiation.inc +++ b/third_party/xla/xla/tests/exhaustive/exhaustive_binary_test_f64_instantiation.inc @@ -13,6 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF8E4M3FNBinaryTest); + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF8E5M2BinaryTest); + GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveBF16BinaryTest); GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF16BinaryTest); diff --git a/third_party/xla/xla/tests/exhaustive/exhaustive_binary_test_functions.cc b/third_party/xla/xla/tests/exhaustive/exhaustive_binary_test_functions.cc index 07485354cf2427..217c16ab7af9c8 100644 --- a/third_party/xla/xla/tests/exhaustive/exhaustive_binary_test_functions.cc +++ b/third_party/xla/xla/tests/exhaustive/exhaustive_binary_test_functions.cc @@ -299,7 +299,13 @@ bool PowCpuGpuF16Skip(NativeT left, NativeT right) { BINARY_TEST(Pow, { PowOp(this) .CpuError(+[](NativeT left, NativeT right) { - if constexpr (std::is_same_v) { + if constexpr (std::is_same_v || + std::is_same_v) { + return ErrorSpec::Builder() + .distance_err(1) + .strict_signed_zeros() + .build(); + } else if constexpr (std::is_same_v) { return ErrorSpec::Builder() .strict_signed_zeros() .skip_comparison(PowCpuGpuF16Skip(left, right)) @@ -357,7 +363,14 @@ bool Atan2CpuBf16F32Skip(NativeT left, NativeT right) { BINARY_TEST(Atan2, { Atan2Op(this) .CpuError([](NativeT left, NativeT right) { - if constexpr (std::is_same_v) { + if constexpr (std::is_same_v || + std::is_same_v) { + return ErrorSpec::Builder() + .distance_err(1) + .strict_signed_zeros() + .build(); + + } else if constexpr (std::is_same_v) { return ErrorSpec::Builder() .abs_err( Atan2CpuBf16F32F64AbsErr(left, right)) @@ -383,6 +396,13 @@ BINARY_TEST(Atan2, { return ErrorSpec::Builder().strict_signed_zeros().build(); }) .GpuError(+[](NativeT, NativeT) { + if constexpr (std::is_same_v || + std::is_same_v) { + return ErrorSpec::Builder() + .distance_err(1) + .strict_signed_zeros() + .build(); + } if constexpr (std::is_same_v || std::is_same_v) { return ErrorSpec::Builder() diff --git a/third_party/xla/xla/tests/exhaustive/exhaustive_op_test_base.cc b/third_party/xla/xla/tests/exhaustive/exhaustive_op_test_base.cc index 1e393f20078e32..2d243ad7f90993 100644 --- a/third_party/xla/xla/tests/exhaustive/exhaustive_op_test_base.cc +++ b/third_party/xla/xla/tests/exhaustive/exhaustive_op_test_base.cc @@ -43,6 +43,7 @@ limitations under the License. #include "xla/client/xla_builder.h" #include "xla/client/xla_computation.h" #include "xla/executable_run_options.h" +#include "xla/fp_util.h" #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/service/shaped_buffer.h" @@ -864,11 +865,15 @@ template class ExhaustiveOpTestBase; template class ExhaustiveOpTestBase; template class ExhaustiveOpTestBase; template class ExhaustiveOpTestBase; +template class ExhaustiveOpTestBase; +template class ExhaustiveOpTestBase; template class ExhaustiveOpTestBase; template class ExhaustiveOpTestBase; template class ExhaustiveOpTestBase; template class ExhaustiveOpTestBase; +template class ExhaustiveOpTestBase; +template class ExhaustiveOpTestBase; } // namespace exhaustive_op_test } // namespace xla diff --git a/third_party/xla/xla/tests/exhaustive/exhaustive_op_test_utils.cc b/third_party/xla/xla/tests/exhaustive/exhaustive_op_test_utils.cc index c6dafe39b1f2ff..ae339f82c70b2c 100644 --- a/third_party/xla/xla/tests/exhaustive/exhaustive_op_test_utils.cc +++ b/third_party/xla/xla/tests/exhaustive/exhaustive_op_test_utils.cc @@ -44,11 +44,15 @@ template class ExhaustiveOpTestTraits; template class ExhaustiveOpTestTraits; template class ExhaustiveOpTestTraits; template class ExhaustiveOpTestTraits; +template class ExhaustiveOpTestTraits; +template class ExhaustiveOpTestTraits; template class ExhaustiveOpTestTraits; template class ExhaustiveOpTestTraits; template class ExhaustiveOpTestTraits; template class ExhaustiveOpTestTraits; +template class ExhaustiveOpTestTraits; +template class ExhaustiveOpTestTraits; bool IsSubnormalReal(xla::complex64 value) { return IsSubnormal(value.real()); } diff --git a/third_party/xla/xla/tests/exhaustive/exhaustive_op_test_utils.h b/third_party/xla/xla/tests/exhaustive/exhaustive_op_test_utils.h index 62bf6786330ec4..1f4985a4c9e8c4 100644 --- a/third_party/xla/xla/tests/exhaustive/exhaustive_op_test_utils.h +++ b/third_party/xla/xla/tests/exhaustive/exhaustive_op_test_utils.h @@ -39,7 +39,6 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/client/xla_builder.h" -#include "xla/fp_util.h" #include "xla/literal.h" #include "xla/primitive_util.h" #include "xla/tests/exhaustive/error_spec.h" @@ -51,7 +50,7 @@ namespace exhaustive_op_test { // The primitive type used to compute the reference output. constexpr PrimitiveType Ref(PrimitiveType T) { - return !primitive_util::IsFloatingPointType(T) || T == F64 ? T : F32; + return (!primitive_util::IsFloatingPointType(T) || T == F64) ? T : F32; } // The primitive type of the component of T. If T is not complex, then @@ -195,6 +194,16 @@ inline ErrorSpec DefaultSpecGenerator(xla::bfloat16) { return ErrorSpec::Builder().abs_err(atol).rel_err(rtol).build(); } +template <> +inline ErrorSpec DefaultSpecGenerator(tsl::float8_e4m3fn) { + return ErrorSpec::Builder().strict_signed_zeros().build(); +} + +template <> +inline ErrorSpec DefaultSpecGenerator(tsl::float8_e5m2) { + return ErrorSpec::Builder().strict_signed_zeros().build(); +} + template <> inline ErrorSpec DefaultSpecGenerator(double, double) { double atol = kDefaultAbsoluteToleranceSlackFactor * @@ -231,6 +240,18 @@ inline ErrorSpec DefaultSpecGenerator(bfloat16, bfloat16) { return ErrorSpec::Builder().abs_err(atol).rel_err(rtol).build(); } +template <> +inline ErrorSpec DefaultSpecGenerator(tsl::float8_e4m3fn, + tsl::float8_e4m3fn) { + return ErrorSpec::Builder().strict_signed_zeros().build(); +} + +template <> +inline ErrorSpec DefaultSpecGenerator(tsl::float8_e5m2, + tsl::float8_e5m2) { + return ErrorSpec::Builder().strict_signed_zeros().build(); +} + template typename ExhaustiveOpTestTraits::ErrorSpecGen GetDefaultSpecGenerator() { // Select overload by casting to fn ptr type. @@ -782,7 +803,13 @@ CreateSubnormalExhaustiveRanges() { return ret; } -inline std::vector> CreateExhaustiveF32Ranges() { +inline std::vector> CreateExhaustiveU16Ranges() { + // The entire U16 range is small enough that we don't need to do any + // partitioning. + return {{0, std::numeric_limits::max()}}; +} + +inline std::vector> CreateExhaustiveU32Ranges() { // We break up the 2^32-element space into small-ish chunks to keep peak // memory usage low. std::vector> result; diff --git a/third_party/xla/xla/tests/exhaustive/exhaustive_unary_test_definitions.inc b/third_party/xla/xla/tests/exhaustive/exhaustive_unary_test_definitions.inc index 64f491fbb7bcb7..dc160bac741954 100644 --- a/third_party/xla/xla/tests/exhaustive/exhaustive_unary_test_definitions.inc +++ b/third_party/xla/xla/tests/exhaustive/exhaustive_unary_test_definitions.inc @@ -67,9 +67,11 @@ class Exhaustive32BitOrLessUnaryTest } }; -using ExhaustiveF32UnaryTest = Exhaustive32BitOrLessUnaryTest; -using ExhaustiveF16UnaryTest = Exhaustive32BitOrLessUnaryTest; +using ExhaustiveF8E4M3FNUnaryTest = Exhaustive32BitOrLessUnaryTest; +using ExhaustiveF8E5M2UnaryTest = Exhaustive32BitOrLessUnaryTest; using ExhaustiveBF16UnaryTest = Exhaustive32BitOrLessUnaryTest; +using ExhaustiveF16UnaryTest = Exhaustive32BitOrLessUnaryTest; +using ExhaustiveF32UnaryTest = Exhaustive32BitOrLessUnaryTest; // Exhaustive test for unary operations for double. // @@ -105,6 +107,22 @@ class ExhaustiveF64UnaryTest : public ExhaustiveUnaryTest, } }; +#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_F8E4M3FN) +#define UNARY_TEST_F8E4M3FN(test_name, ...) \ + XLA_TEST_P(ExhaustiveF8E4M3FNUnaryTest, test_name) \ + __VA_ARGS__ +#else +#define UNARY_TEST_E4M3FN(test_name, ...) +#endif + +#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_F8E5M2) +#define UNARY_TEST_F8E5M2(test_name, ...) \ + XLA_TEST_P(ExhaustiveF8E5M2UnaryTest, test_name) \ + __VA_ARGS__ +#else +#define UNARY_TEST_E5M2(test_name, ...) +#endif + #ifdef XLA_BACKEND_SUPPORTS_BFLOAT16 #define UNARY_TEST_BF16(test_name, ...) \ XLA_TEST_P(ExhaustiveBF16UnaryTest, test_name) \ @@ -133,8 +151,10 @@ class ExhaustiveF64UnaryTest : public ExhaustiveUnaryTest, #define UNARY_TEST_F64(test_name, ...) #endif -#define UNARY_TEST(test_name, ...) \ - UNARY_TEST_BF16(test_name, __VA_ARGS__) \ - UNARY_TEST_F16(test_name, __VA_ARGS__) \ - UNARY_TEST_F32(test_name, __VA_ARGS__) \ +#define UNARY_TEST(test_name, ...) \ + UNARY_TEST_F8E4M3FN(test_name, __VA_ARGS__) \ + UNARY_TEST_F8E5M2(test_name, __VA_ARGS__) \ + UNARY_TEST_BF16(test_name, __VA_ARGS__) \ + UNARY_TEST_F16(test_name, __VA_ARGS__) \ + UNARY_TEST_F32(test_name, __VA_ARGS__) \ UNARY_TEST_F64(test_name, __VA_ARGS__) diff --git a/third_party/xla/xla/tests/exhaustive/exhaustive_unary_test_f32_and_smaller_instantiation.inc b/third_party/xla/xla/tests/exhaustive/exhaustive_unary_test_f32_and_smaller_instantiation.inc index a958e2bbc88c74..b0c1a087b9283b 100644 --- a/third_party/xla/xla/tests/exhaustive/exhaustive_unary_test_f32_and_smaller_instantiation.inc +++ b/third_party/xla/xla/tests/exhaustive/exhaustive_unary_test_f32_and_smaller_instantiation.inc @@ -13,6 +13,20 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_F8E4M3FN) +INSTANTIATE_TEST_SUITE_P(F8E4M3FN, ExhaustiveF8E4M3FNUnaryTest, + ::testing::Values(std::make_pair(0, 1 << 8))); +#else +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF8E4M3FNUnaryTest); +#endif + +#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_F8E5M2) +INSTANTIATE_TEST_SUITE_P(F8E5M2, ExhaustiveF8E5M2UnaryTest, + ::testing::Values(std::make_pair(0, 1 << 8))); +#else +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF8E5M2UnaryTest); +#endif + #ifdef XLA_BACKEND_SUPPORTS_BFLOAT16 INSTANTIATE_TEST_SUITE_P(BF16, ExhaustiveBF16UnaryTest, ::testing::Values(std::make_pair(0, 1 << 16))); @@ -28,6 +42,6 @@ GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF16UnaryTest); #endif INSTANTIATE_TEST_SUITE_P(F32, ExhaustiveF32UnaryTest, - ::testing::ValuesIn(CreateExhaustiveF32Ranges())); + ::testing::ValuesIn(CreateExhaustiveU32Ranges())); GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF64UnaryTest); diff --git a/third_party/xla/xla/tests/exhaustive/exhaustive_unary_test_f64_instantiation.inc b/third_party/xla/xla/tests/exhaustive/exhaustive_unary_test_f64_instantiation.inc index b558fb85f3f8e8..a2e67ff4f8fb0c 100644 --- a/third_party/xla/xla/tests/exhaustive/exhaustive_unary_test_f64_instantiation.inc +++ b/third_party/xla/xla/tests/exhaustive/exhaustive_unary_test_f64_instantiation.inc @@ -13,6 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF8E4M3FNUnaryTest); + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF8E5M2UnaryTest); + GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveBF16UnaryTest); GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF16UnaryTest); diff --git a/third_party/xla/xla/tests/exhaustive/exhaustive_unary_test_functions.cc b/third_party/xla/xla/tests/exhaustive/exhaustive_unary_test_functions.cc index 5baa12f15ca455..a8af6bcc915bc4 100644 --- a/third_party/xla/xla/tests/exhaustive/exhaustive_unary_test_functions.cc +++ b/third_party/xla/xla/tests/exhaustive/exhaustive_unary_test_functions.cc @@ -29,6 +29,7 @@ limitations under the License. #include "xla/tests/exhaustive/exhaustive_op_test_utils.h" #include "xla/tests/exhaustive/exhaustive_unary_test_definitions.h" #include "xla/tests/exhaustive/test_op.h" // IWYU pragma: keep, exhaustive_unary_test_ops.inc +#include "xla/types.h" #ifdef __FAST_MATH__ #error "Can't be compiled with fast math on" @@ -41,9 +42,39 @@ namespace { #include "xla/tests/exhaustive/exhaustive_unary_test_ops.inc" UNARY_TEST(Log, { LogOp(this).Error(GetDefaultSpecGenerator()).Run(); }) -UNARY_TEST(Log1p, { Log1pOp(this).Error(GetDefaultSpecGenerator()).Run(); }) +UNARY_TEST(Log1p, { + Log1pOp(this) + .CpuError(+[](NativeT x) { + if constexpr (std::is_same_v) { + return ErrorSpec::Builder().distance_err(1).build(); + } + return GetDefaultSpecGenerator()(x); + }) + .GpuError(+[](NativeT x) { + if constexpr (std::is_same_v) { + return ErrorSpec::Builder().distance_err(1).build(); + } + return GetDefaultSpecGenerator()(x); + }) + .Run(); +}) -UNARY_TEST(Exp, { ExpOp(this).Error(GetDefaultSpecGenerator()).Run(); }) +UNARY_TEST(Exp, { + ExpOp(this) + .CpuError(+[](NativeT x) { + if constexpr (std::is_same_v) { + return ErrorSpec::Builder().distance_err(1).build(); + } + return GetDefaultSpecGenerator()(x); + }) + .GpuError(+[](NativeT x) { + if constexpr (std::is_same_v) { + return ErrorSpec::Builder().distance_err(1).build(); + } + return GetDefaultSpecGenerator()(x); + }) + .Run(); +}) UNARY_TEST(Expm1, { Expm1Op(this).Error(GetDefaultSpecGenerator()).Run(); }) UNARY_TEST(Logistic, { @@ -54,8 +85,20 @@ UNARY_TEST(Logistic, { } return std::abs(out) <= 1.0f; }) - // FIXME(rmlarsen): Break into region around zero and everything else. - .Error(GetDefaultSpecGenerator()) + .CpuError(+[](NativeT x) { + if constexpr (std::is_same_v) { + return ErrorSpec::Builder().distance_err(1).build(); + } + // FIXME(rmlarsen): Break into region around zero and everything else. + return GetDefaultSpecGenerator()(x); + }) + .GpuError(+[](NativeT x) { + if constexpr (std::is_same_v) { + return ErrorSpec::Builder().distance_err(1).build(); + } + // FIXME(rmlarsen): Break into region around zero and everything else. + return GetDefaultSpecGenerator()(x); + }) .Run(); }) @@ -135,13 +178,46 @@ UNARY_TEST(Acosh, { }) .Run(); }) -UNARY_TEST(Asinh, { AsinhOp(this).Error(GetDefaultSpecGenerator()).Run(); }) -UNARY_TEST(Atanh, { AtanhOp(this).Error(GetDefaultSpecGenerator()).Run(); }) +UNARY_TEST(Asinh, { + AsinhOp(this) + .CpuError(+[](NativeT x) { + if constexpr (std::is_same_v || + std::is_same_v) { + return ErrorSpec::Builder().distance_err(1).build(); + } + return GetDefaultSpecGenerator()(x); + }) + .GpuError(+[](NativeT x) { + if constexpr (std::is_same_v || + std::is_same_v) { + return ErrorSpec::Builder().distance_err(1).build(); + } + return GetDefaultSpecGenerator()(x); + }) + .Run(); +}) +UNARY_TEST(Atanh, { + AtanhOp(this) + .Error(GetDefaultSpecGenerator()) + .CpuError(+[](NativeT x) { + if constexpr (std::is_same_v) { + return ErrorSpec::Builder().distance_err(1).build(); + } + return GetDefaultSpecGenerator()(x); + }) + .Run(); +}) // Tests for inverse trigonometric functions. UNARY_TEST(Acos, { AcosOp(this) - .Error(GetDefaultSpecGenerator()) + .CpuError(+[](NativeT x) { + if constexpr (std::is_same_v || + std::is_same_v) { + return ErrorSpec::Builder().distance_err(1).build(); + } + return GetDefaultSpecGenerator()(x); + }) .GpuError(+[](NativeT x) { NativeT eps = std::numeric_limits::epsilon(); return ErrorSpec::Builder().abs_err(1e-6).rel_err(10 * eps).build(); @@ -175,12 +251,44 @@ UNARY_TEST(Atan, { UNARY_TEST(Cosh, { CoshOp(this) - .Error(GetDefaultSpecGenerator()) + .CpuError(+[](NativeT x) { + if constexpr (std::is_same_v) { + return ErrorSpec::Builder().distance_err(3).build(); + } else if constexpr (std::is_same_v) { + return ErrorSpec::Builder().distance_err(4).build(); + } + return GetDefaultSpecGenerator()(x); + }) + .GpuError(+[](NativeT x) { + if constexpr (std::is_same_v || + std::is_same_v) { + return ErrorSpec::Builder().distance_err(1).build(); + } + return GetDefaultSpecGenerator()(x); + }) .OutputRangeCheck( +[](NativeInputs in, NativeT actual) { return !(actual < 1); }) .Run(); }) -UNARY_TEST(Sinh, { SinhOp(this).Error(GetDefaultSpecGenerator()).Run(); }) +UNARY_TEST(Sinh, { + SinhOp(this) + .Error(GetDefaultSpecGenerator()) + .CpuError(+[](NativeT x) { + if constexpr (std::is_same_v) { + return ErrorSpec::Builder().distance_err(3).build(); + } else if constexpr (std::is_same_v) { + return ErrorSpec::Builder().distance_err(4).build(); + } + return GetDefaultSpecGenerator()(x); + }) + .GpuError(+[](NativeT x) { + if constexpr (std::is_same_v) { + return ErrorSpec::Builder().distance_err(1).build(); + } + return GetDefaultSpecGenerator()(x); + }) + .Run(); +}) UNARY_TEST(Tanh, { TanhOp(this) .Error(GetDefaultSpecGenerator()) @@ -275,7 +383,14 @@ UNARY_TEST(ErfInv, { UNARY_TEST(Digamma, { DigammaOp(this) - .Error(+[](NativeT x) { + .CpuError(+[](NativeT x) { + if constexpr (std::is_same_v) { + return ErrorSpec::Builder().distance_err(1).build(); + } + NativeT eps = std::numeric_limits::epsilon(); + return ErrorSpec::Builder().abs_err(2e-5).rel_err(10 * eps).build(); + }) + .GpuError(+[](NativeT x) { NativeT eps = std::numeric_limits::epsilon(); return ErrorSpec::Builder().abs_err(2e-5).rel_err(10 * eps).build(); }) From 6678d914676e18b5a2e2884ffab6de3d59736c2a Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 27 Sep 2024 14:42:50 -0700 Subject: [PATCH 384/483] In a previous change, we throw an error if we encounter an unknown sharding when saving shardings for instructions. However, this ingored the fact that we deliberately replace some module parameter/root shardings with unknown sharding objects. This CL makes the condition tigher so we only throw an error when we encounter unknown sharding objects intended for shard_as or shard_like annotations, which was the original intention anyway. PiperOrigin-RevId: 679734540 --- .../xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc index d69764435604cc..29ad335d888e23 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc @@ -2572,7 +2572,8 @@ absl::Status SaveShardingForInstruction( if (!inst->has_sharding()) { return absl::OkStatus(); } - if (inst->sharding().IsUnknown()) { + if (inst->sharding().IsUnknown() && + (inst->sharding().IsShardLike() || inst->sharding().IsShardAs())) { return absl::UnimplementedError( "Auto-sharding currently does not support shard_as/shard_like " "sharding annotations"); From 7a036756f0590b0e490f721bca1ef9eb88cbd81d Mon Sep 17 00:00:00 2001 From: Arturo Schmidt Date: Fri, 27 Sep 2024 14:56:06 -0700 Subject: [PATCH 385/483] Replace usage of SavedModelObjectGraphImporter for formalized API. Decouples SavedModelObjectGraphImporter from importer base which is used behind the scenes in the defined API of ConvertGraphToMlir (eventually ConvertGraphToTfExecutor). PiperOrigin-RevId: 679738623 --- .../mlir/tensorflow/translate/import_model.cc | 29 ++++++++++++------- .../translate/mlir_roundtrip_flags.h | 5 ++++ 2 files changed, 23 insertions(+), 11 deletions(-) diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc index 3a6572dbfa7f7c..b6c704a4431fd4 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc @@ -2564,6 +2564,14 @@ absl::StatusOr> GraphDefImporter::Convert( TF_RETURN_IF_ERROR(importer.ImporterBase::Convert(graph_func_name, func_type, arg_nodes, ret_nodes, control_ret_nodes, attrs)); + // TODO(b/370078030): add tests once migration and code decoupling is + // complete. + if (specs.convert_all_functions_to_mlir) { + auto fn_names = graph.flib_def().ListFunctionNames(); + for (const auto& fn_name : fn_names) { + TF_RETURN_IF_ERROR(importer.ConvertLibFunction(fn_name)); + } + } TF_RETURN_IF_ERROR(importer.ImporterBase::ConvertDeferredFunctions()); // Mark main function public, others private. @@ -3612,17 +3620,15 @@ SavedModelObjectGraphImporter::Convert(SavedModelV2Bundle* saved_model, options, std::move(preprocessed_graphdef), &graph)); NameUniquifier function_name_uniquifier(graph.flib_def()); - SavedModelObjectGraphImporter importer(graph.flib_def(), debug_info, specs, - module.get(), &tf_name_to_mlir_name, - &function_name_uniquifier); - - TF_RETURN_IF_ERROR(importer.PrepareConvert(graph)); - - auto fn_names = graph.flib_def().ListFunctionNames(); - for (const auto& fn_name : fn_names) { - TF_RETURN_IF_ERROR(importer.ConvertLibFunction(fn_name)); + for (const auto& fn_name : graph.flib_def().ListFunctionNames()) { + std::string mlir_func_name(function_name_uniquifier.GetUniqueName(fn_name)); + (tf_name_to_mlir_name)[std::string(fn_name)] = mlir_func_name; } - TF_RETURN_IF_ERROR(importer.ConvertDeferredFunctions()); + + specs.convert_all_functions_to_mlir = true; + TF_ASSIGN_OR_RETURN( + module, ConvertGraphToMlir(graph, debug_info, graph.flib_def(), specs, + module->getContext())); if (!saved_model->meta_graph_def().has_object_graph_def()) { return errors::InvalidArgument( @@ -3639,7 +3645,8 @@ SavedModelObjectGraphImporter::Convert(SavedModelV2Bundle* saved_model, llvm::make_early_inc_range(module->getOps())) { if (func.getName().starts_with("__inference__traced_save_") || func.getName().starts_with("__inference__traced_restore_") || - func.getName().starts_with("__inference_signature_wrapper_")) { + func.getName().starts_with("__inference_signature_wrapper_") || + func.getName().starts_with("main")) { func.erase(); } } diff --git a/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h b/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h index fca039c2601636..8873b0928b028f 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h +++ b/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h @@ -99,6 +99,11 @@ struct GraphImportConfig { // If true, a function attribute, `tf._original_func_name`, will be set in // functions which contains the corresponding original TF function name. bool set_original_tf_func_name = false; + + // If true, all functions in the graph will be converted to MLIR regardless of + // whether the functions are referenced by the nodes. This is needed if + // aliases and saved model object graph function matching is needed. + bool convert_all_functions_to_mlir = false; }; struct GraphExportConfig { From bc4b8d10d6582b26c6914d2392f80a7953024dd2 Mon Sep 17 00:00:00 2001 From: Mehrdad Khani Date: Fri, 27 Sep 2024 15:18:46 -0700 Subject: [PATCH 386/483] [XLA:MSA] Updates the CopyAllocation to support the allocation for sliced data movements. PiperOrigin-RevId: 679746040 --- .../xla/service/memory_space_assignment/BUILD | 17 ++ .../memory_space_assignment/allocation.cc | 33 ++-- .../memory_space_assignment/allocation.h | 10 +- .../allocation_test.cc | 153 ++++++++++++++++++ 4 files changed, 203 insertions(+), 10 deletions(-) create mode 100644 third_party/xla/xla/service/memory_space_assignment/allocation_test.cc diff --git a/third_party/xla/xla/service/memory_space_assignment/BUILD b/third_party/xla/xla/service/memory_space_assignment/BUILD index f3f989d083f8ea..01ce6a6f722904 100644 --- a/third_party/xla/xla/service/memory_space_assignment/BUILD +++ b/third_party/xla/xla/service/memory_space_assignment/BUILD @@ -237,6 +237,23 @@ cc_library( ], ) +xla_cc_test( + name = "allocation_test", + srcs = ["allocation_test.cc"], + deps = [ + ":allocation", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_value", + "//xla/service/heap_simulator", + "//xla/tests:hlo_test_base", + "//xla/tsl/lib/core:status_test_util", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:statusor", + ], +) + cc_library( name = "tuning_utils", srcs = ["tuning_utils.cc"], diff --git a/third_party/xla/xla/service/memory_space_assignment/allocation.cc b/third_party/xla/xla/service/memory_space_assignment/allocation.cc index 54edfba58e5276..8e65e283aab149 100644 --- a/third_party/xla/xla/service/memory_space_assignment/allocation.cc +++ b/third_party/xla/xla/service/memory_space_assignment/allocation.cc @@ -303,7 +303,8 @@ CopyAllocation::CopyAllocation( std::optional chunk, int64_t copy_start_schedule_after_time, int64_t copy_done_schedule_before_time, int64_t end_time, - std::optional cross_program_prefetch_index) + std::optional cross_program_prefetch_index, + HloInstruction* sync_instruction) : Allocation( /*defining_position=*/{nullptr, {}}, memory_space, chunk, // Allocation uses an inclusive start time @@ -312,7 +313,8 @@ CopyAllocation::CopyAllocation( /*is_scoped_allocation=*/false, cross_program_prefetch_index), prev_allocation_(prev_allocation), copy_start_schedule_after_(copy_start_schedule_after_time), - copy_done_schedule_before_(copy_done_schedule_before_time) {} + copy_done_schedule_before_(copy_done_schedule_before_time), + sync_instruction_(sync_instruction) {} int64_t CopyAllocation::earliest_available_time() const { return copy_done_schedule_before_; @@ -323,11 +325,23 @@ absl::Status CopyAllocation::Process() { Shape shape = defining_position().shape(); HloInstruction* producing_instruction = AddGetTupleElements(); HloComputation* computation = producing_instruction->parent(); - copy_start_ = computation->AddInstruction(HloInstruction::CreateCopyStart( - ShapeUtil::MakeTupleShape({shape, shape, ShapeUtil::MakeShape(U32, {})}), - producing_instruction, cross_program_prefetch_index())); - copy_done_ = computation->AddInstruction( - HloInstruction::CreateUnary(shape, HloOpcode::kCopyDone, copy_start_)); + if (sync_instruction_ != nullptr && + sync_instruction_->opcode() != HloOpcode::kCopy) { + TF_ASSIGN_OR_RETURN(copy_done_, + computation->CreateAsyncInstructions( + sync_instruction_, {ShapeUtil::MakeShape(S32, {})}, + HloInstruction::kMainExecutionThread, false)); + copy_start_ = copy_done_->mutable_operand(0); + TF_RETURN_IF_ERROR( + copy_start_->ReplaceOperandWith(0, producing_instruction)); + } else { + copy_start_ = computation->AddInstruction(HloInstruction::CreateCopyStart( + ShapeUtil::MakeTupleShape( + {shape, shape, ShapeUtil::MakeShape(U32, {})}), + producing_instruction, cross_program_prefetch_index())); + copy_done_ = computation->AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kCopyDone, copy_start_)); + } VLOG(4) << "Created " << copy_start_->name() << " for copy allocation: " << ToString(); @@ -359,8 +373,9 @@ std::string CopyAllocation::ToString() const { ", start_time:", start_time(), ", end_time:", end_time(), ", copy_start_after_time: ", copy_start_schedule_after(), ", copy_done_before_time: ", copy_done_schedule_before(), - ", uses: ", UsesToString(uses()), ", from ", - prev_allocation_.ToString()); + ", uses: ", UsesToString(uses()), ", sync_instruction: ", + sync_instruction_ ? sync_instruction_->name() : "none", + ", from ", prev_allocation_.ToString()); } HloPosition CopyAllocation::defining_position() const { diff --git a/third_party/xla/xla/service/memory_space_assignment/allocation.h b/third_party/xla/xla/service/memory_space_assignment/allocation.h index b576ade9fcb34e..4fd74a0a3b2066 100644 --- a/third_party/xla/xla/service/memory_space_assignment/allocation.h +++ b/third_party/xla/xla/service/memory_space_assignment/allocation.h @@ -16,6 +16,8 @@ limitations under the License. #ifndef XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_ALLOCATION_H_ #define XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_ALLOCATION_H_ +#include + #include #include #include @@ -27,10 +29,12 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/functional/function_ref.h" +#include "absl/log/log.h" #include "absl/status/status.h" #include "absl/types/span.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/heap_simulator/allocation_block.h" #include "xla/service/heap_simulator/heap_simulator.h" #include "xla/service/hlo_value.h" @@ -240,7 +244,8 @@ class CopyAllocation final : public Allocation { std::optional chunk, int64_t copy_start_schedule_after_time, int64_t copy_done_schedule_before_time, int64_t end_time, - std::optional cross_program_prefetch_index = std::nullopt); + std::optional cross_program_prefetch_index = std::nullopt, + HloInstruction* sync_instruction = nullptr); // Overridden methods // @@ -263,6 +268,7 @@ class CopyAllocation final : public Allocation { bool operator==(const Allocation& other) const override; // New non-virtual methods + const HloInstruction* sync_instruction() const { return sync_instruction_; } bool operator==(const CopyAllocation& other) const; const Allocation& prev_allocation() { return prev_allocation_; } @@ -286,6 +292,8 @@ class CopyAllocation final : public Allocation { int64_t copy_done_schedule_before_; HloInstruction* copy_start_ = nullptr; HloInstruction* copy_done_ = nullptr; + // The sync data movement instruction that this copy is associated with. + HloInstruction* sync_instruction_ = nullptr; }; // This class represents an allocation resulting from asynchronous sliced diff --git a/third_party/xla/xla/service/memory_space_assignment/allocation_test.cc b/third_party/xla/xla/service/memory_space_assignment/allocation_test.cc new file mode 100644 index 00000000000000..bba0f400507f4c --- /dev/null +++ b/third_party/xla/xla/service/memory_space_assignment/allocation_test.cc @@ -0,0 +1,153 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/memory_space_assignment/allocation.h" + +#include +#include + +#include +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/service/heap_simulator/heap_simulator.h" +#include "xla/service/hlo_value.h" +#include "xla/tests/hlo_test_base.h" +#include "xla/tsl/lib/core/status_test_util.h" +#include "xla/xla_data.pb.h" +#include "tsl/platform/statusor.h" + +namespace xla::memory_space_assignment { +namespace { + +class AllocationTest : public HloTestBase {}; + +TEST_F(AllocationTest, CopyAllocationProcessSimple) { + absl::string_view hlo_string = R"( +HloModule module + +ENTRY entry { + p0 = f32[2,3]{1,0} parameter(0) + p1 = f32[2,3]{1,0} parameter(1) + p1_negate = f32[2,3]{1,0} negate(p1) + add = f32[2,3]{1,0} add(p0, p1_negate) + ROOT tuple = tuple(add, p0) +} + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + // HloComputation* computation = module->entry_computation(); + HloInstruction* add = FindInstruction(module.get(), "add"); + HloInstruction* p1_negate = FindInstruction(module.get(), "p1_negate"); + + HeapSimulator::Chunk p1_negate_chunk = + HeapSimulator::Chunk::FromOffsetSize(0, 24); + + PinnedAllocation p1_negate_pinned( + HloPosition{p1_negate, {}}, MemorySpace::kAlternate, p1_negate_chunk, + /*start_time=*/0, + /*end_time=*/5, /*is_scoped_allocation=*/false); + CopyAllocation copy_allocation(p1_negate_pinned, MemorySpace::kAlternate, + std::nullopt, + /*copy_start_schedule_after_time=*/2, + /*copy_done_schedule_before_time=*/3, + /*end_time=*/5, std::nullopt, + /*sync_instruction=*/nullptr); + + // Use the correct instruction and operand numbers for the add instruction + copy_allocation.AddUse(HloUse{add, 1}); // Use of p1_negate in add + + TF_ASSERT_OK(copy_allocation.Process()); + + // Check copy_start and copy_done instructions. + HloInstruction* copy_start = copy_allocation.copy_start(); + ASSERT_NE(copy_start, nullptr); + EXPECT_EQ(copy_start->opcode(), HloOpcode::kCopyStart); + EXPECT_EQ(copy_start->operand(0), p1_negate); + + HloInstruction* copy_done = copy_allocation.copy_done(); + ASSERT_NE(copy_done, nullptr); + EXPECT_EQ(copy_done->opcode(), HloOpcode::kCopyDone); + EXPECT_EQ(copy_done->operand(0), copy_start); + + // Check that uses are updated. + EXPECT_EQ(add->operand(1), copy_done); + + // Check defining position + EXPECT_EQ(copy_allocation.defining_position().instruction, copy_done); +} + +TEST_F(AllocationTest, CopyAllocationProcessReplaceSyncSlice) { + absl::string_view hlo_string = R"( +HloModule module + +ENTRY entry { + p0 = f32[1,3]{1,0} parameter(0) + p1 = f32[2,3]{1,0} parameter(1) + p1_negate = f32[2,3]{1,0} negate(p1) + slice = f32[1,3]{1,0} slice(p1_negate), slice={[0:1], [0:3]} + add = f32[1,3]{1,0} add(p0, slice) + ROOT tuple = tuple(add, p0) +} + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + // HloComputation* computation = module->entry_computation(); + HloInstruction* add = FindInstruction(module.get(), "add"); + HloInstruction* p1_negate = FindInstruction(module.get(), "p1_negate"); + HloInstruction* slice = FindInstruction(module.get(), "slice"); + + HeapSimulator::Chunk p1_negate_chunk = + HeapSimulator::Chunk::FromOffsetSize(0, 24); + + PinnedAllocation p1_negate_pinned( + HloPosition{p1_negate, {}}, MemorySpace::kAlternate, p1_negate_chunk, + /*start_time=*/0, + /*end_time=*/5, /*is_scoped_allocation=*/false); + CopyAllocation copy_allocation(p1_negate_pinned, MemorySpace::kAlternate, + std::nullopt, + /*copy_start_schedule_after_time=*/2, + /*copy_done_schedule_before_time=*/3, + /*end_time=*/5, std::nullopt, + /*sync_instruction=*/slice); + + // Use the correct instruction and operand numbers for the add instruction + copy_allocation.AddUse(HloUse{add, 1}); // Use of p1_negate in add + + TF_ASSERT_OK(copy_allocation.Process()); + + // Check copy_start and copy_done instructions. + HloInstruction* slice_start = copy_allocation.copy_start(); + ASSERT_NE(slice_start, nullptr); + EXPECT_EQ(slice_start->opcode(), HloOpcode::kAsyncStart); + EXPECT_EQ(slice_start->operand(0), p1_negate); + + HloInstruction* slice_done = copy_allocation.copy_done(); + ASSERT_NE(slice_done, nullptr); + EXPECT_EQ(slice_done->opcode(), HloOpcode::kAsyncDone); + EXPECT_EQ(slice_done->operand(0), slice_start); + + // Check the shapes. + EXPECT_EQ(slice_done->shape(), slice->shape()); + + // Check that uses are updated. + EXPECT_EQ(add->operand(1), slice_done); + + // Check defining position + EXPECT_EQ(copy_allocation.defining_position().instruction, slice_done); +} + +} // namespace +} // namespace xla::memory_space_assignment From 542378cdbc3573331d925b7bd562deb23b427d85 Mon Sep 17 00:00:00 2001 From: Vladyslav Tsilytskyi Date: Fri, 27 Sep 2024 15:37:57 -0700 Subject: [PATCH 387/483] [XLA:CPU] Add a generic sort kernel to SortThunk This PR duplicates templated code as non-templated with "D" prefix. It is done for every function, which uses "n" directly, like loops. Thus the only unified class is the SortIterator, which operates on Value, Ref and Ptr abstractions, which, in turn, differ. PiperOrigin-RevId: 679752366 --- .../xla/xla/backends/cpu/runtime/BUILD | 4 +- .../xla/backends/cpu/runtime/sort_thunk.cc | 180 ++++++++++++++++-- .../backends/cpu/runtime/sort_thunk_test.cc | 173 +++++++++++++++++ 3 files changed, 341 insertions(+), 16 deletions(-) diff --git a/third_party/xla/xla/backends/cpu/runtime/BUILD b/third_party/xla/xla/backends/cpu/runtime/BUILD index 8e60bd7b30582a..86ee9f010a70bc 100644 --- a/third_party/xla/xla/backends/cpu/runtime/BUILD +++ b/third_party/xla/xla/backends/cpu/runtime/BUILD @@ -950,14 +950,12 @@ xla_cc_test( "//xla/service:buffer_assignment", "//xla/service:maybe_owning_device_memory", "//xla/stream_executor", - "//xla/stream_executor/host:host_kernel_c_api", "//xla/tsl/concurrency:async_value", - "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_benchmark", "@local_tsl//tsl/platform:test_main", ], ) diff --git a/third_party/xla/xla/backends/cpu/runtime/sort_thunk.cc b/third_party/xla/xla/backends/cpu/runtime/sort_thunk.cc index 30c5e1a1b34897..990bd523cae461 100644 --- a/third_party/xla/xla/backends/cpu/runtime/sort_thunk.cc +++ b/third_party/xla/xla/backends/cpu/runtime/sort_thunk.cc @@ -26,6 +26,7 @@ limitations under the License. #include #include #include +#include #include "absl/algorithm/container.h" #include "absl/base/dynamic_annotations.h" @@ -131,6 +132,7 @@ static constexpr size_t kMaxElementSize = 16; // Forward declare reference type defined below. template struct Ref; +struct DRef; // Value type to store values loaded from the input buffers. template @@ -145,6 +147,18 @@ struct Value { std::array value_sizes; }; +struct DValue { + DValue(const DRef& ref); // NOLINT + + const void* compared_value(size_t i) const { return value[i].data(); } + + // Use properly aligned byte array to store primitive values. + using ValueStorage = std::array; + std::vector value; + std::vector value_sizes; + size_t n; +}; + // Reference to values stored in the input buffers. template struct Ref { @@ -160,6 +174,20 @@ struct Ref { std::array ptr_sizes; }; +struct DRef { + DRef(std::vector ptr, std::vector ptr_sizes) + : ptr(ptr), ptr_sizes(ptr_sizes), n(ptr.size()) {} + + DRef& operator=(const DValue& value); + DRef& operator=(const DRef& other); + + const void* compared_value(size_t i) const { return ptr[i]; } + + std::vector ptr; + std::vector ptr_sizes; + const size_t n; +}; + template Value::Value(const Ref& ref) : value_sizes(ref.ptr_sizes) { for (size_t i = 0; i < n; ++i) { @@ -167,6 +195,15 @@ Value::Value(const Ref& ref) : value_sizes(ref.ptr_sizes) { } } +DValue::DValue(const DRef& ref) + : value_sizes(ref.ptr_sizes), n(ref.ptr.size()) { + value.reserve(n); + for (size_t i = 0; i < n; ++i) { + value.emplace_back(); + std::memcpy(value[i].data(), ref.ptr[i], ref.ptr_sizes[i]); + } +} + template Ref& Ref::operator=(const Value& value) { DCHECK(ptr_sizes == value.value_sizes); @@ -176,6 +213,14 @@ Ref& Ref::operator=(const Value& value) { return *this; } +DRef& DRef::operator=(const DValue& value) { + DCHECK(ptr_sizes == value.value_sizes); + for (size_t i = 0; i < n; ++i) { + std::memcpy(ptr[i], value.value[i].data(), value.value_sizes[i]); + } + return *this; +} + template Ref& Ref::operator=(const Ref& other) { DCHECK(ptr_sizes == other.ptr_sizes); @@ -185,6 +230,15 @@ Ref& Ref::operator=(const Ref& other) { return *this; } +DRef& DRef::operator=(const DRef& other) { + DCHECK(ptr_sizes == other.ptr_sizes); + const size_t n = other.ptr.size(); + for (size_t i = 0; i < n; ++i) { + std::memcpy(ptr[i], other.ptr[i], other.ptr_sizes[i]); + } + return *this; +} + // Swap function required by `std::sort` and `std::stable_sort` implementations. template void swap(const Ref& lhs, const Ref& rhs) { @@ -196,6 +250,17 @@ void swap(const Ref& lhs, const Ref& rhs) { } } +void swap(const DRef& lhs, const DRef& rhs) { + DCHECK(lhs.ptr_sizes == rhs.ptr_sizes); + const size_t n = lhs.ptr.size(); + for (size_t i = 0; i < n; ++i) { + std::array tmp; + std::memcpy(tmp.data(), lhs.ptr[i], lhs.ptr_sizes[i]); + std::memcpy(lhs.ptr[i], rhs.ptr[i], rhs.ptr_sizes[i]); + std::memcpy(rhs.ptr[i], tmp.data(), lhs.ptr_sizes[i]); + } +} + // An array of pointers to the input data. template struct Ptr { @@ -250,19 +315,72 @@ struct Ptr { std::array ptr_sizes; // pointers sizes in bytes }; +struct DPtr { + using difference_type = std::ptrdiff_t; + + DPtr() = default; + + DPtr(std::vector ptr, std::vector ptr_sizes) + : ptr(ptr), ptr_sizes(ptr_sizes), n(ptr.size()) {} + + DRef operator*() const { return DRef{ptr, ptr_sizes}; } + + DPtr& operator+=(difference_type diff) { + for (size_t i = 0; i < n; ++i) ptr[i] += diff * ptr_sizes[i]; + return *this; + } + + DPtr& operator-=(difference_type diff) { + for (size_t i = 0; i < n; ++i) ptr[i] -= diff * ptr_sizes[i]; + return *this; + } + + DPtr operator+(difference_type diff) const { + std::vector upd(n); + for (size_t i = 0; i < n; ++i) upd[i] = ptr[i] + diff * ptr_sizes[i]; + return DPtr{upd, ptr_sizes}; + } + + DPtr operator-(difference_type diff) const { + std::vector upd(n); + for (size_t i = 0; i < n; ++i) upd[i] = ptr[i] - diff * ptr_sizes[i]; + return DPtr{upd, ptr_sizes}; + } + + // In all comparison operators defined below we use only the ptr at index 0, + // because we know that all pointers change together and this is an + // implementation detail of sort iterator. + + difference_type operator-(const DPtr& rhs) const { + DCHECK(ptr_sizes == rhs.ptr_sizes); + return (ptr[0] - rhs.ptr[0]) / ptr_sizes[0]; + } + + bool operator==(const DPtr& rhs) const { return ptr[0] == rhs.ptr[0]; } + bool operator!=(const DPtr& rhs) const { return ptr[0] != rhs.ptr[0]; } + bool operator>(const DPtr& rhs) const { return ptr[0] > rhs.ptr[0]; } + bool operator<(const DPtr& rhs) const { return ptr[0] < rhs.ptr[0]; } + bool operator>=(const DPtr& rhs) const { return ptr[0] >= rhs.ptr[0]; } + bool operator<=(const DPtr& rhs) const { return ptr[0] <= rhs.ptr[0]; } + + std::vector ptr; // pointers into the input buffers + std::vector ptr_sizes; // pointers sizes in bytes + size_t n; +}; + // We rely on `std::sort` and `std::stable_sort` to sort the raw data. We sort // multiple input buffers together using the same comparator function, so we // need to provide a custom iterator that can access the data of all input // buffers at the same time and swap elements in them. -template +template class SortIterator { public: using iterator_category = std::random_access_iterator_tag; using difference_type = std::ptrdiff_t; - using value_type = Value; - using reference = Ref; - using pointer = Ptr; + using value_type = Value; + using reference = Ref; + using pointer = Ptr; SortIterator() = default; SortIterator(pointer ptr, difference_type stride) @@ -388,8 +506,40 @@ static void SortInplace(const SortDims& sort_dims, int64_t offset, return (*less_than)(data.data()); }; - SortIterator begin(Ptr(ptr, ptr_sizes), - /*stride=*/sort_dims.inner_dim_size); + SortIterator, Ref, Ptr> begin( + Ptr(ptr, ptr_sizes), + /*stride=*/sort_dims.inner_dim_size); + if (is_stable) { + std::stable_sort(begin, begin + sort_dims.sort_dim_size, compare); + } else { + std::sort(begin, begin + sort_dims.sort_dim_size, compare); + } +} + +static void DSortInplace(const SortDims& sort_dims, int64_t offset, + absl::Span data, + absl::Span shapes, bool is_stable, + SortThunk::LessThan* less_than, size_t n) { + std::vector ptr(n); + std::vector ptr_sizes(n); + + for (size_t i = 0; i < n; ++i) { + std::byte* base = reinterpret_cast(data[i].opaque()); + ptr_sizes[i] = primitive_util::ByteWidth(shapes[i].element_type()); + ptr[i] = base + offset * ptr_sizes[i]; + } + + auto compare = [&](const auto& a, const auto& b) { + std::vector data(2 * n); + for (size_t i = 0, j = 0; i < n; i += 1, j += 2) { + data[j] = a.compared_value(i); + data[j + 1] = b.compared_value(i); + } + return (*less_than)(data.data()); + }; + + SortIterator begin(DPtr(ptr, ptr_sizes), + /*stride=*/sort_dims.inner_dim_size); if (is_stable) { std::stable_sort(begin, begin + sort_dims.sort_dim_size, compare); } else { @@ -416,9 +566,15 @@ static absl::Status SortInplace(absl::Span data, is_stable, less_than); }; - // TODO(ezhulenev): We can replace statically known number of sorted inputs - // with a dynamic value, however statically known number of inputs allows - // compiler to generate better code. Benchmark if it really matters. + auto dsort = [&](size_t num_inputs) { + DSortInplace(sort_dims, offset, data, shapes, is_stable, less_than, + num_inputs); + }; + + // use "sort" for statically known number of sorted inputs (expected to be + // faster) and "dsort" for dynamically known number of sorted inputs. + // for 100 elements stable sort is 1.5 times faster than stable dsort. + // for 100 elements unstable sort is 2.47 times faster than unstable dsort. switch (data.size()) { case 1: sort(std::integral_constant{}); @@ -495,11 +651,9 @@ static absl::Status SortInplace(absl::Span data, case 25: sort(std::integral_constant{}); break; - case 29: - sort(std::integral_constant{}); - break; default: - return Internal("Unsupported number of sorted inputs: %d", data.size()); + dsort(data.size()); + break; } } diff --git a/third_party/xla/xla/backends/cpu/runtime/sort_thunk_test.cc b/third_party/xla/xla/backends/cpu/runtime/sort_thunk_test.cc index 1f450f77548d70..6c8dfae0b65a8e 100644 --- a/third_party/xla/xla/backends/cpu/runtime/sort_thunk_test.cc +++ b/third_party/xla/xla/backends/cpu/runtime/sort_thunk_test.cc @@ -15,8 +15,10 @@ limitations under the License. #include "xla/backends/cpu/runtime/sort_thunk.h" +#include #include #include +#include #include #include @@ -34,6 +36,7 @@ limitations under the License. #include "tsl/platform/logging.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" +#include "tsl/platform/test_benchmark.h" namespace xla::cpu { namespace { @@ -100,6 +103,83 @@ TEST_P(SortThunkTest, Sort1D) { EXPECT_EQ(indices, expected_indices); } +TEST_P(SortThunkTest, DynamicSort1D) { + bool is_stable = GetParam(); + + // 33 empty slices + 2 slices with data = 35 slices + // This amount of slices will call the dynamic sort implementation. + constexpr int num_of_empty_slices = 33; + constexpr int total_num_of_slices = num_of_empty_slices + 2; + + // size of each of 33 data buffers + constexpr int data_size = 31; + + // values range will be [5.0, 35.0] + constexpr float starting_value = 5.0f; + + std::array data{ + 17.0f, 16.0f, 5.0f, 10.0f, 30.0f, 8.0f, 9.0f, 21.0f, + 14.0f, 32.0f, 29.0f, 28.0f, 19.0f, 12.0f, 25.0f, 22.0f, + 18.0f, 35.0f, 34.0f, 23.0f, 7.0f, 13.0f, 26.0f, 33.0f, + 15.0f, 24.0f, 20.0f, 31.0f, 6.0f, 27.0f, 11.0f}; + std::array indices; + std::iota(indices.begin(), indices.end(), 0); + + // This is a container for the rest of the buffers. + std::array empty; + + const size_t data_size_in_bytes = data.size() * sizeof(float); + const size_t ind_size_in_bytes = indices.size() * sizeof(int32_t); + const size_t empty_size_in_bytes = empty.size() * sizeof(uint32_t); + + const BufferAllocation alloc0(0, data_size_in_bytes, 0); + const BufferAllocation alloc1(1, ind_size_in_bytes, 0); + const BufferAllocation rest(2, empty_size_in_bytes, 0); + + const BufferAllocation::Slice slice0(&alloc0, 0, data_size_in_bytes); + const BufferAllocation::Slice slice1(&alloc1, 0, ind_size_in_bytes); + + const Shape data_shape = ShapeUtil::MakeShape(F32, {data_size}); + const Shape indices_shape = ShapeUtil::MakeShape(S32, {data_size}); + const Shape rest_shape = ShapeUtil::MakeShape(U32, {data_size}); + + std::vector buffers; + buffers.emplace_back(se::DeviceMemoryBase(data.data(), data_size_in_bytes)); + buffers.emplace_back(se::DeviceMemoryBase(indices.data(), ind_size_in_bytes)); + buffers.emplace_back(se::DeviceMemoryBase(empty.data(), empty_size_in_bytes)); + + BufferAllocations allocations(buffers); + + std::array inputs{ + {{slice0, data_shape}, {slice1, indices_shape}}}; + for (int i = 0; i < num_of_empty_slices; ++i) { + constexpr size_t empty_slice_in_bytes = data_size * sizeof(uint32_t); + inputs[i + 2].slice = BufferAllocation::Slice( + &rest, i * empty_slice_in_bytes, empty_slice_in_bytes); + inputs[i + 2].shape = rest_shape; + } + + TF_ASSERT_OK_AND_ASSIGN( + auto thunk, SortThunk::Create({"sort"}, inputs, + /*dimension=*/0, is_stable, LessThan)); + + Thunk::ExecuteParams params; + params.buffer_allocations = &allocations; + + auto execute_event = thunk->Execute(params); + tsl::BlockUntilReady(execute_event); + ASSERT_FALSE(execute_event.IsError()); + + std::array expected_data; + std::iota(expected_data.begin(), expected_data.end(), starting_value); + const std::array expected_indices{ + 2, 28, 20, 5, 6, 3, 30, 13, 21, 8, 24, 1, 0, 16, 12, 26, + 7, 15, 19, 25, 14, 22, 29, 11, 10, 4, 27, 9, 23, 18, 17}; + + EXPECT_EQ(data, expected_data); + EXPECT_EQ(indices, expected_indices); +} + TEST_P(SortThunkTest, Sort2D) { bool is_stable = GetParam(); @@ -237,6 +317,99 @@ TEST_P(SortThunkTest, Sort2DWithLayout) { EXPECT_EQ(indices, expected_indices); } +void BM_DynamicSort1D(::testing::benchmark::State& state, bool is_stable) { + const int total_num_of_slices = state.range(0); + const int num_of_empty_slices = total_num_of_slices - 2; + + // size of each of data buffers + constexpr int data_size = 31; + + const std::array data{ + 17.0f, 16.0f, 5.0f, 10.0f, 30.0f, 8.0f, 9.0f, 21.0f, + 14.0f, 32.0f, 29.0f, 28.0f, 19.0f, 12.0f, 25.0f, 22.0f, + 18.0f, 35.0f, 34.0f, 23.0f, 7.0f, 13.0f, 26.0f, 33.0f, + 15.0f, 24.0f, 20.0f, 31.0f, 6.0f, 27.0f, 11.0f}; + std::array indices; + std::iota(indices.begin(), indices.end(), 0); + + // This is the container for the rest of the buffers. + std::vector empty(data_size * num_of_empty_slices); + + const size_t data_size_in_bytes = data.size() * sizeof(float); + const size_t ind_size_in_bytes = indices.size() * sizeof(int32_t); + const size_t empty_size_in_bytes = empty.size() * sizeof(uint32_t); + + const BufferAllocation alloc0(0, data_size_in_bytes, 0); + const BufferAllocation alloc1(1, ind_size_in_bytes, 0); + const BufferAllocation rest(2, empty_size_in_bytes, 0); + + const BufferAllocation::Slice slice0(&alloc0, 0, data_size_in_bytes); + const BufferAllocation::Slice slice1(&alloc1, 0, ind_size_in_bytes); + + const Shape data_shape = ShapeUtil::MakeShape(F32, {data_size}); + const Shape indices_shape = ShapeUtil::MakeShape(S32, {data_size}); + const Shape rest_shape = ShapeUtil::MakeShape(U32, {data_size}); + + for (auto s : state) { + // Pause timing to avoid counting the time spent in the setup. + state.PauseTiming(); + auto data_clone(data); + auto indices_clone(indices); + + std::vector buffers; + buffers.emplace_back( + se::DeviceMemoryBase(data_clone.data(), data_size_in_bytes)); + buffers.emplace_back( + se::DeviceMemoryBase(indices_clone.data(), ind_size_in_bytes)); + buffers.emplace_back( + se::DeviceMemoryBase(empty.data(), empty_size_in_bytes)); + + BufferAllocations allocations(buffers); + + std::vector inputs(total_num_of_slices); + inputs[0] = {slice0, data_shape}; + inputs[1] = {slice1, indices_shape}; + for (int i = 0; i < num_of_empty_slices; ++i) { + constexpr size_t empty_slice_in_bytes = data_size * sizeof(uint32_t); + inputs[i + 2].slice = BufferAllocation::Slice( + &rest, i * empty_slice_in_bytes, empty_slice_in_bytes); + inputs[i + 2].shape = rest_shape; + } + + Thunk::ExecuteParams params; + params.buffer_allocations = &allocations; + + state.ResumeTiming(); + TF_ASSERT_OK_AND_ASSIGN( + auto thunk, SortThunk::Create({"sort"}, inputs, + /*dimension=*/0, is_stable, LessThan)); + + auto execute_event = thunk->Execute(params); + tsl::BlockUntilReady(execute_event); + ASSERT_FALSE(execute_event.IsError()); + } +} + +void BM_StableDynamicSort1D(::testing::benchmark::State& state) { + BM_DynamicSort1D(state, /*is_stable=*/true); +} + +void BM_UnstableDynamicSort1D(::testing::benchmark::State& state) { + BM_DynamicSort1D(state, /*is_stable=*/false); +} + +BENCHMARK(BM_StableDynamicSort1D) + ->MeasureProcessCPUTime() + ->Arg(35) + ->Arg(50) + ->Arg(100); + +BENCHMARK(BM_UnstableDynamicSort1D) + ->MeasureProcessCPUTime() + ->Arg(35) + ->Arg(50) + ->Arg(100); + INSTANTIATE_TEST_SUITE_P(SortThunk, SortThunkTest, testing::Bool(), testing::PrintToStringParamName()); From 4e9171c74a58f2cbd8bc565f47b0e570ae2f218d Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 27 Sep 2024 16:01:47 -0700 Subject: [PATCH 388/483] Enable Runtime Uptime Telemetry in TensorFlow-2.18.0. PiperOrigin-RevId: 679758959 --- tensorflow/api_template.__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tensorflow/api_template.__init__.py b/tensorflow/api_template.__init__.py index 5ea9ef248d55e2..86515dfaa27602 100644 --- a/tensorflow/api_template.__init__.py +++ b/tensorflow/api_template.__init__.py @@ -34,6 +34,8 @@ import site as _site import sys as _sys +_os.environ.setdefault("ENABLE_RUNTIME_UPTIME_TELEMETRY", "1") + # Do not remove this line; See https://github.com/tensorflow/tensorflow/issues/42596 from tensorflow.python import pywrap_tensorflow as _pywrap_tensorflow # pylint: disable=unused-import from tensorflow.python.tools import module_util as _module_util From 3f4b2fda6ffe7dfe03c1663ef37f54fc4432cc8b Mon Sep 17 00:00:00 2001 From: tchatow Date: Fri, 27 Sep 2024 16:18:36 -0700 Subject: [PATCH 389/483] PR #16882: Symlink hermetic cuda headers to permit clang cuda version detection Imported from GitHub PR https://github.com/openxla/xla/pull/16882 Fixes #16877 Copybara import of the project: -- 1ff356ac0870002b369c3ec09547aae2a62c70e2 by tchatow : Symlink hermetic cuda headers to permit clang cuda version detection Fixes #16877 Merging this change closes #16882 PiperOrigin-RevId: 679764212 --- .bazelrc | 2 ++ .../gpus/cuda/hermetic/cuda_redist_init_repositories.bzl | 5 +++++ third_party/xla/.bazelrc | 2 ++ third_party/xla/third_party/tsl/.bazelrc | 2 ++ .../gpus/cuda/hermetic/cuda_redist_init_repositories.bzl | 5 +++++ 5 files changed, 16 insertions(+) diff --git a/.bazelrc b/.bazelrc index bee4dc3e784a99..d6e47d9e6d279a 100644 --- a/.bazelrc +++ b/.bazelrc @@ -242,6 +242,8 @@ build:cuda_clang --copt=-Qunused-arguments # major release. Example: sm_80 kernels can run on sm_89 GPUs but # not on sm_90 GPUs. compute_80 kernels though can also run on sm_90 GPUs. build:cuda_clang --repo_env=HERMETIC_CUDA_COMPUTE_CAPABILITIES="sm_60,sm_70,sm_80,sm_89,compute_90" +# Permit newer CUDA versions than Clang is aware of +build:cuda_clang --copt="-Wno-unknown-cuda-version" # Set lld as the linker. build:cuda_clang --host_linkopt="-fuse-ld=lld" build:cuda_clang --host_linkopt="-lm" diff --git a/third_party/gpus/cuda/hermetic/cuda_redist_init_repositories.bzl b/third_party/gpus/cuda/hermetic/cuda_redist_init_repositories.bzl index 11b32cdbb71c56..ecc99f06455614 100644 --- a/third_party/gpus/cuda/hermetic/cuda_redist_init_repositories.bzl +++ b/third_party/gpus/cuda/hermetic/cuda_redist_init_repositories.bzl @@ -219,6 +219,10 @@ def _create_libcuda_symlinks( repository_ctx.symlink(nvidia_driver_path, "lib/libcuda.so.1") repository_ctx.symlink("lib/libcuda.so.1", "lib/libcuda.so") +def _create_cuda_header_symlinks(repository_ctx): + if repository_ctx.name == "cuda_nvcc": + repository_ctx.symlink("../cuda_cudart/include/cuda.h", "include/cuda.h") + def use_local_path(repository_ctx, local_path, dirs): # buildifier: disable=function-docstring-args """Creates repository using local redistribution paths.""" @@ -339,6 +343,7 @@ def _use_downloaded_cuda_redistribution(repository_ctx): repository_ctx, lib_name_to_version_dict, ) + _create_cuda_header_symlinks(repository_ctx) repository_ctx.file("version.txt", major_version) def _cuda_repo_impl(repository_ctx): diff --git a/third_party/xla/.bazelrc b/third_party/xla/.bazelrc index bee4dc3e784a99..d6e47d9e6d279a 100644 --- a/third_party/xla/.bazelrc +++ b/third_party/xla/.bazelrc @@ -242,6 +242,8 @@ build:cuda_clang --copt=-Qunused-arguments # major release. Example: sm_80 kernels can run on sm_89 GPUs but # not on sm_90 GPUs. compute_80 kernels though can also run on sm_90 GPUs. build:cuda_clang --repo_env=HERMETIC_CUDA_COMPUTE_CAPABILITIES="sm_60,sm_70,sm_80,sm_89,compute_90" +# Permit newer CUDA versions than Clang is aware of +build:cuda_clang --copt="-Wno-unknown-cuda-version" # Set lld as the linker. build:cuda_clang --host_linkopt="-fuse-ld=lld" build:cuda_clang --host_linkopt="-lm" diff --git a/third_party/xla/third_party/tsl/.bazelrc b/third_party/xla/third_party/tsl/.bazelrc index bee4dc3e784a99..d6e47d9e6d279a 100644 --- a/third_party/xla/third_party/tsl/.bazelrc +++ b/third_party/xla/third_party/tsl/.bazelrc @@ -242,6 +242,8 @@ build:cuda_clang --copt=-Qunused-arguments # major release. Example: sm_80 kernels can run on sm_89 GPUs but # not on sm_90 GPUs. compute_80 kernels though can also run on sm_90 GPUs. build:cuda_clang --repo_env=HERMETIC_CUDA_COMPUTE_CAPABILITIES="sm_60,sm_70,sm_80,sm_89,compute_90" +# Permit newer CUDA versions than Clang is aware of +build:cuda_clang --copt="-Wno-unknown-cuda-version" # Set lld as the linker. build:cuda_clang --host_linkopt="-fuse-ld=lld" build:cuda_clang --host_linkopt="-lm" diff --git a/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_redist_init_repositories.bzl b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_redist_init_repositories.bzl index 11b32cdbb71c56..ecc99f06455614 100644 --- a/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_redist_init_repositories.bzl +++ b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_redist_init_repositories.bzl @@ -219,6 +219,10 @@ def _create_libcuda_symlinks( repository_ctx.symlink(nvidia_driver_path, "lib/libcuda.so.1") repository_ctx.symlink("lib/libcuda.so.1", "lib/libcuda.so") +def _create_cuda_header_symlinks(repository_ctx): + if repository_ctx.name == "cuda_nvcc": + repository_ctx.symlink("../cuda_cudart/include/cuda.h", "include/cuda.h") + def use_local_path(repository_ctx, local_path, dirs): # buildifier: disable=function-docstring-args """Creates repository using local redistribution paths.""" @@ -339,6 +343,7 @@ def _use_downloaded_cuda_redistribution(repository_ctx): repository_ctx, lib_name_to_version_dict, ) + _create_cuda_header_symlinks(repository_ctx) repository_ctx.file("version.txt", major_version) def _cuda_repo_impl(repository_ctx): From a00ef0f7caeb2197f5b0eaa14c3cd6dedf24ddd4 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 27 Sep 2024 16:49:22 -0700 Subject: [PATCH 390/483] Implementation of simplify_ici_dummy_variables_pass PiperOrigin-RevId: 679773625 --- tensorflow/core/common_runtime/BUILD | 14 +- .../simplify_ici_dummy_variables_pass.cc | 288 +++++- .../simplify_ici_dummy_variables_pass_test.cc | 103 +- ...lify_ici_dummy_variables_pass_before.pbtxt | 919 ++++++++++++++++++ tensorflow/tensorflow.bzl | 2 + 5 files changed, 1317 insertions(+), 9 deletions(-) create mode 100644 tensorflow/core/common_runtime/testdata/simplify_ici_dummy_variables_pass_before.pbtxt diff --git a/tensorflow/core/common_runtime/BUILD b/tensorflow/core/common_runtime/BUILD index 4fb8679e6e8458..aaaaba3cf74192 100644 --- a/tensorflow/core/common_runtime/BUILD +++ b/tensorflow/core/common_runtime/BUILD @@ -1214,13 +1214,21 @@ cc_library( hdrs = ["simplify_ici_dummy_variables_pass.h"], copts = tf_copts(), deps = [ - ":colocate_predecessor_trees_pass", ":optimization_registry", "//tensorflow/core:core_cpu_base", + "//tensorflow/core:framework", + "//tensorflow/core/config:flag_defs", + "//tensorflow/core/config:flags", + "//tensorflow/core/framework:node_def_util", "//tensorflow/core/framework:tensor_proto_cc", "//tensorflow/core/framework:tensor_shape_proto_cc", + "//tensorflow/core/platform:bfloat16", + "//tensorflow/core/platform:logging", "//tensorflow/core/platform:status", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@local_tsl//tsl/platform:errors", ], alwayslink = 1, ) @@ -2424,6 +2432,9 @@ tf_cc_tests( "threadpool_device_test.cc", ], create_named_test_suite = True, + data = [ + "testdata/simplify_ici_dummy_variables_pass_before.pbtxt", + ], linkopts = select({ "//tensorflow:macos": ["-headerpad_max_install_names"], "//conditions:default": [], @@ -2459,6 +2470,7 @@ tf_cc_tests( "//tensorflow/core/util:protos_test_cc", "@com_google_absl//absl/base", "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@eigen_archive//:eigen3", ], diff --git a/tensorflow/core/common_runtime/simplify_ici_dummy_variables_pass.cc b/tensorflow/core/common_runtime/simplify_ici_dummy_variables_pass.cc index 3ac8bec6145dc3..e65a25690dc4a6 100644 --- a/tensorflow/core/common_runtime/simplify_ici_dummy_variables_pass.cc +++ b/tensorflow/core/common_runtime/simplify_ici_dummy_variables_pass.cc @@ -15,20 +15,304 @@ limitations under the License. #include "tensorflow/core/common_runtime/simplify_ici_dummy_variables_pass.h" +#include +#include +#include + +#include "absl/container/flat_hash_map.h" #include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" +#include "xla/tsl/util/device_name_utils.h" +#include "tensorflow/core/common_runtime/function_utils.h" #include "tensorflow/core/common_runtime/optimization_registry.h" +#include "tensorflow/core/config/flag_defs.h" +#include "tensorflow/core/config/flags.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_shape.pb.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/graph_def_builder.h" +#include "tensorflow/core/platform/bfloat16.h" +#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/status.h" +#include "tensorflow/core/util/device_name_utils.h" +#include "tensorflow/core/util/dump_graph.h" +#include "tsl/platform/errors.h" namespace tensorflow { +namespace { + +constexpr absl::string_view kTpuExecute = "TPUExecute"; +constexpr absl::string_view kParallelExecuteIds = "_parallel_execution_ids"; +const char kICIWeightDistributionMlirBridgeMarker[] = + "ici_weight_distribution_mlir_bridge_marker"; + +// Get the new op name which is used to replace the old op, the new op name +// contains the index of the input and the task id of the TPUExecute node. +std::string GetNewOpName(std::string op_name, int index, int task_id) { + return absl::StrCat(op_name, "_ici_specific_index_", std::to_string(index), + "_task_id_", std::to_string(task_id)); +} + +// Find all the TPUExecute nodes that is not on replica 0. In addition, return +// an empty vector if there is a parallel execute id that is not 0, which +// indicates SPMD case. In the meantime, we check if this is a SPMD case. +std::vector GetNonMainReplicaIciTPUExecuteNodes(Graph* graph, + bool& is_spmd) { + std::vector tpu_nodes; + for (Node* node : graph->nodes()) { + if (node->type_string() == kTpuExecute && + HasNodeAttr(node->def(), kParallelExecuteIds)) { + auto parallel_exec_ids = node->attrs().Find(kParallelExecuteIds)->s(); + std::vector group_vec = + absl::StrSplit(parallel_exec_ids, ','); + if (group_vec.empty()) return tpu_nodes; + std::vector replica_vec = absl::StrSplit(group_vec[0], ':'); + int replica_id = std::stoi(replica_vec[1]); + if (replica_id != 0) tpu_nodes.push_back(node); + if (group_vec.size() > 1) { + std::vector parallel_vec = + absl::StrSplit(group_vec[1], ':'); + int parallel_id = std::stoi(parallel_vec[1]); + if (parallel_id != 0) is_spmd = true; + } + } + } + return tpu_nodes; +} + +// Remove the edge from old_src_node to dst_node, and add the edge from +// new_src_node to dst_node. +void RedirectEdge(Graph* graph, Node* old_src_node, Node* dst_node, + Node* new_src_node, int input_index) { + const Edge* delete_edge; + for (auto edge : dst_node->in_edges()) { + if (edge->src() == old_src_node) { + delete_edge = edge; + break; + } + } + if (delete_edge == nullptr) return; + + graph->RemoveEdge(delete_edge); + graph->AddEdge(new_src_node, 0, dst_node, input_index); +} + +// Find the corresponding host device name from the TPU device name. +string GetHostDeviceName(Node* tpu_node) { + auto device_name = tpu_node->requested_device(); + if (device_name.empty()) device_name = tpu_node->assigned_device_name(); + DeviceNameUtils::ParsedName parsed_device_name; + DeviceNameUtils::ParseFullName(device_name, &parsed_device_name); + string host_device_name = DeviceNameUtils::FullName( + parsed_device_name.job, parsed_device_name.replica, + parsed_device_name.task, /*type=*/"CPU", /*id=*/0); + return host_device_name; +} + +std::optional> GetOutputShapeVec(Node* node) { + auto output_shapes = node->attrs().Find("_output_shapes"); + if (output_shapes == nullptr) return std::nullopt; + auto output_shape = output_shapes->list().shape()[0]; + std::vector output_shape_vec; + output_shape_vec.reserve(output_shape.dim_size()); + for (auto i = 0; i < output_shape.dim_size(); i++) { + output_shape_vec.push_back(output_shape.dim()[i].size()); + } + return output_shape_vec; +} + +int GetTPUTaskId(Node* tpu_node) { + auto device_name = tpu_node->requested_device(); + if (device_name.empty()) device_name = tpu_node->assigned_device_name(); + DeviceNameUtils::ParsedName parsed_device_name; + DeviceNameUtils::ParseFullName(device_name, &parsed_device_name); + return parsed_device_name.task; +} + +// Build the fill op. Its value is 0 and the fill op is put on the host device +// with the same task id as the TPUExecute node. +Node* BuildFillOp(GraphDefBuilder::Options& bopts, Node* tpu_node, + Node* in_node, int input_index, string host_device_name) { + // Find the output_shape vector + auto output_shape_vec = GetOutputShapeVec(in_node); + if (!output_shape_vec.has_value()) return nullptr; + + // Find the element type + auto dtype = in_node->attrs().Find("T")->type(); + + // Get TPU task id. + int tpu_task_id = GetTPUTaskId(tpu_node); + + TensorShape tensor_shape; + tensor_shape.AddDim(output_shape_vec.value().size()); + Tensor const_op_shape_tensor(DT_INT32, tensor_shape); + for (int i = 0; i < output_shape_vec.value().size(); i++) { + const_op_shape_tensor.flat()(i) = output_shape_vec.value()[i]; + } + + // Build dim of fill op + std::string const_1_name = GetNewOpName("const_1", input_index, tpu_task_id); + Node* fill_dim_input = + ops::SourceOp("Const", bopts.WithName(const_1_name) + .WithAttr("dtype", DT_INT32) + .WithAttr("value", const_op_shape_tensor)); + TensorShape fill_dim_output_shape; + fill_dim_output_shape.AddDim(output_shape_vec.value().size()); + fill_dim_input->AddAttr("_output_shapes", + std::vector{fill_dim_output_shape}); + + // Build value of fill op + std::string const_2_name = GetNewOpName("const_2", input_index, tpu_task_id); + auto scalar_tensor = Tensor(dtype, {}); + + if (dtype == DT_FLOAT) { + scalar_tensor.scalar()() = 0; + } else if (dtype == DT_BFLOAT16) { + scalar_tensor.scalar()() = bfloat16(0); + } else { + LOG(ERROR) << "Unsupported data type: ", DataTypeString(dtype); + return nullptr; + } + Node* fill_value_input = + ops::SourceOp("Const", bopts.WithName(const_2_name) + .WithAttr("dtype", dtype) + .WithAttr("value", scalar_tensor)); + TensorShape fill_value_output_shape; + fill_value_input->AddAttr("_output_shapes", + std::vector{fill_value_output_shape}); + + // Build fill op + std::string fill_name = GetNewOpName("fill", input_index, tpu_task_id); + Node* new_fill = + ops::BinaryOp("Fill", fill_dim_input, fill_value_input, + bopts.WithName(fill_name).WithAttr("T", dtype)); + + TensorShape new_output_shape; + for (auto output_shape : output_shape_vec.value()) { + new_output_shape.AddDim(output_shape); + } + new_fill->AddAttr("_output_shapes", + std::vector{new_output_shape}); + new_fill->AddAttr("_xla_inferred_shapes", + std::vector{new_output_shape}); + + // Set the device to each node. + fill_dim_input->set_requested_device(host_device_name); + fill_value_input->set_requested_device(host_device_name); + new_fill->set_requested_device(host_device_name); + + return new_fill; +} + +// Replace the ici dummy variable with one on the right task id. +absl::Status ReplaceIciDummyVariables(Graph* graph, int input_index, + std::vector tpu_nodes, + GraphDefBuilder::Options& bopts) { + absl::flat_hash_map device_to_node_map; + for (Node* tpu_node : tpu_nodes) { + Node* in_node; + TF_RETURN_IF_ERROR(tpu_node->input_node(input_index, &in_node)); + + if (!in_node->attrs().Find(kICIWeightDistributionMlirBridgeMarker)) { + continue; + } + + string host_device_name = GetHostDeviceName(tpu_node); + + // If the node corresponding to host_device_name is already in the graph, + // replace the edge from in_node to tpu_node with the edge from + // device_to_node_map[host_device_name] to tpu_node. + if (device_to_node_map.contains(host_device_name)) { + RedirectEdge(graph, in_node, tpu_node, + device_to_node_map[host_device_name], input_index); + continue; + } + + Node* new_fill = + BuildFillOp(bopts, tpu_node, in_node, input_index, host_device_name); + if (new_fill == nullptr) continue; + + device_to_node_map[host_device_name] = new_fill; + RedirectEdge(graph, in_node, tpu_node, device_to_node_map[host_device_name], + input_index); + } + return absl::OkStatus(); +} + +} // namespace + +bool ShouldRunPass(const GraphOptimizationPassOptions& options) { + if (!flags::Global().enable_tf2min_ici_weight.value()) { + VLOG(1) << "SimplifyIciDummyVariablesPass is disabled."; + return false; + } + VLOG(1) << "SimplifyIciDummyVariablesPass is enabled."; + + // find all potential nodes. + if (options.graph == nullptr) { + LOG(INFO) << "No graph in simplify_ici_dummy_variables_pass."; + return false; + } + return true; +} Status SimplifyIciDummyVariablesPass::Run( const GraphOptimizationPassOptions& options) { + if (!ShouldRunPass(options)) { + return absl::OkStatus(); + } + + Graph* graph = options.graph->get(); + VLOG(1) << DumpGraphToFile("before_simplify_ici_dummy_variables_pass", *graph, + options.flib_def); + + absl::Status status; + GraphDefBuilder::Options bopts(graph, &status); + if (!status.ok()) { + LOG(ERROR) << "GraphDefBuilder::Option failed to initialize."; + return status; + } + + bool is_spmd = false; + + // Find all the qualified tpu_execute nodes which is not on replica 0. + std::vector tpu_nodes = + GetNonMainReplicaIciTPUExecuteNodes(graph, is_spmd); + + if (!is_spmd) { + VLOG(1) << "Not SPMD case, skip SimplifyIciDummyVariablesPass."; + return absl::OkStatus(); + } + + if (tpu_nodes.empty()) { + VLOG(1) << "tpu_nodes is empty, skip SimplifyIciDummyVariablesPass."; + return absl::OkStatus(); + } + + for (int i = 0; i < tpu_nodes[0]->num_inputs(); ++i) { + auto replace_status = ReplaceIciDummyVariables(graph, i, tpu_nodes, bopts); + if (!replace_status.ok()) { + LOG(ERROR) << "Replace ici dummy variables failed."; + return replace_status; + } + } + + // Remove the dead nodes that previously connected to the TPUExecute node. + RemoveDeadNodes(graph); + + VLOG(1) << DumpGraphToFile("after_simplify_ici_dummy_variables_pass", *graph, + options.flib_def); + return absl::OkStatus(); } -// REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 49, -// SimplifyIciDummyVariablesPass); +REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 49, + SimplifyIciDummyVariablesPass); } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/simplify_ici_dummy_variables_pass_test.cc b/tensorflow/core/common_runtime/simplify_ici_dummy_variables_pass_test.cc index 6319eb336c7dd7..867fb0aeaf0754 100644 --- a/tensorflow/core/common_runtime/simplify_ici_dummy_variables_pass_test.cc +++ b/tensorflow/core/common_runtime/simplify_ici_dummy_variables_pass_test.cc @@ -16,29 +16,120 @@ limitations under the License. #include "tensorflow/core/common_runtime/simplify_ici_dummy_variables_pass.h" #include +#include +#include "absl/status/status.h" #include "tensorflow/cc/framework/scope.h" #include "xla/tsl/lib/core/status_test_util.h" +#include "tensorflow/core/common_runtime/graph_constructor.h" #include "tensorflow/core/common_runtime/graph_def_builder_util.h" #include "tensorflow/core/common_runtime/optimization_registry.h" +#include "tensorflow/core/config/flag_defs.h" +#include "tensorflow/core/config/flags.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/graph/graph.h" -#include "tensorflow/core/graph/graph_def_builder.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/resource_loader.h" +#include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/test.h" #include "tsl/platform/test.h" namespace tensorflow { -TEST(SimplifyIciDummyVariablesPassTest, SimplifyIciDummyVariables) { +// Return the node with the specified name +Node* GetNode(const Graph& graph, const std::string& name) { + for (Node* node : graph.nodes()) { + if (node->name() == name) return node; + } + return nullptr; +} + +std::string TestDataPath() { + return tensorflow::GetDataDependencyFilepath( + "tensorflow/core/common_runtime/testdata/"); +} + +// Test the case enable_tf2min_ici_weight is false. +TEST(SimplifyIciDummyVariablesPassTest, flag_is_false) { + flags::Global().enable_tf2min_ici_weight.reset(false); auto graph = std::make_unique(OpRegistry::Global()); - GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); - TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); - GraphDef before; - graph->ToGraphDef(&before); + std::string graph_path = + TestDataPath() + "simplify_ici_dummy_variables_pass_before.pbtxt"; + tensorflow::GraphDef graph_def; + absl::Status load_graph_status = + ReadTextProto(tensorflow::Env::Default(), graph_path, &graph_def); + EXPECT_EQ(load_graph_status.ok(), true); + TF_EXPECT_OK(ConvertGraphDefToGraph(GraphConstructorOptions(), graph_def, + graph.get())); + GraphOptimizationPassOptions options; options.graph = &graph; SimplifyIciDummyVariablesPass pass; TF_ASSERT_OK(pass.Run(options)); + + Node* fill_1_dim = GetNode(*graph, "const_1_ici_specific_index_0_task_id_2"); + Node* fill_1_value = + GetNode(*graph, "const_2_ici_specific_index_0_task_id_2"); + Node* fill_1 = GetNode(*graph, "fill_ici_specific_index_0_task_id_2"); + EXPECT_EQ(fill_1_dim, nullptr); + EXPECT_EQ(fill_1_value, nullptr); + EXPECT_EQ(fill_1, nullptr); + + Node* fill_2_dim = GetNode(*graph, "const_1_ici_specific_index_1_task_id_2"); + Node* fill_2_value = + GetNode(*graph, "const_2_ici_specific_index_1_task_id_2"); + Node* fill_2 = GetNode(*graph, "fill_ici_specific_index_1_task_id_2"); + EXPECT_EQ(fill_2_dim, nullptr); + EXPECT_EQ(fill_2_value, nullptr); + EXPECT_EQ(fill_2, nullptr); +} + +// Test the case enable_tf2min_ici_weight is true, graph after pass will have +// dummy variables on task 2. +TEST(SimplifyIciDummyVariablesPassTest, replace_dummy_variable) { + flags::Global().enable_tf2min_ici_weight.reset(true); + auto graph = std::make_unique(OpRegistry::Global()); + std::string graph_path = + TestDataPath() + "simplify_ici_dummy_variables_pass_before.pbtxt"; + tensorflow::GraphDef graph_def; + tensorflow::Status load_graph_status = + ReadTextProto(tensorflow::Env::Default(), graph_path, &graph_def); + EXPECT_EQ(load_graph_status.ok(), true); + TF_EXPECT_OK(ConvertGraphDefToGraph(GraphConstructorOptions(), graph_def, + graph.get())); + + GraphOptimizationPassOptions options; + options.graph = &graph; + SimplifyIciDummyVariablesPass pass; + TF_ASSERT_OK(pass.Run(options)); + + Node* fill_1_dim = GetNode(*graph, "const_1_ici_specific_index_0_task_id_2"); + Node* fill_1_value = + GetNode(*graph, "const_2_ici_specific_index_0_task_id_2"); + Node* fill_1 = GetNode(*graph, "fill_ici_specific_index_0_task_id_2"); + EXPECT_NE(fill_1_dim, nullptr); + EXPECT_NE(fill_1_value, nullptr); + EXPECT_NE(fill_1, nullptr); + EXPECT_EQ(fill_1_dim->requested_device(), + "/job:tpu_host_worker/replica:0/task:2/device:CPU:0"); + EXPECT_EQ(fill_1_value->requested_device(), + "/job:tpu_host_worker/replica:0/task:2/device:CPU:0"); + EXPECT_EQ(fill_1->requested_device(), + "/job:tpu_host_worker/replica:0/task:2/device:CPU:0"); + + Node* fill_2_dim = GetNode(*graph, "const_1_ici_specific_index_1_task_id_2"); + Node* fill_2_value = + GetNode(*graph, "const_2_ici_specific_index_1_task_id_2"); + Node* fill_2 = GetNode(*graph, "fill_ici_specific_index_1_task_id_2"); + EXPECT_NE(fill_2_dim, nullptr); + EXPECT_NE(fill_2_value, nullptr); + EXPECT_NE(fill_2, nullptr); + EXPECT_EQ(fill_2_dim->requested_device(), + "/job:tpu_host_worker/replica:0/task:2/device:CPU:0"); + EXPECT_EQ(fill_2_value->requested_device(), + "/job:tpu_host_worker/replica:0/task:2/device:CPU:0"); + EXPECT_EQ(fill_2->requested_device(), + "/job:tpu_host_worker/replica:0/task:2/device:CPU:0"); } } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/testdata/simplify_ici_dummy_variables_pass_before.pbtxt b/tensorflow/core/common_runtime/testdata/simplify_ici_dummy_variables_pass_before.pbtxt new file mode 100644 index 00000000000000..a8018d4ee85ea7 --- /dev/null +++ b/tensorflow/core/common_runtime/testdata/simplify_ici_dummy_variables_pass_before.pbtxt @@ -0,0 +1,919 @@ +# proto-file: third_party/tensorflow/core/framework/graph.proto +# proto-message: GraphDef +node { + name: "unknown_2" + op: "_Arg" + device: "/job:tpu_host_worker/replica:0/task:0/device:CPU:0" + attr { + key: "T" + value { + type: DT_RESOURCE + } + } + attr { + key: "_handle_dtypes" + value { + list { + type: DT_FLOAT + } + } + } + attr { + key: "_handle_shapes" + value { + list { + shape { + dim { + size: 1024 + } + } + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + unknown_rank: true + } + } + } + } + attr { + key: "_user_specified_name" + value { + s: "905" + } + } + attr { + key: "index" + value { + i: 4 + } + } +} + +node { + name: "unknown_17" + op: "_Arg" + device: "/job:tpu_host_worker/replica:0/task:0/device:CPU:0" + attr { + key: "T" + value { + type: DT_RESOURCE + } + } + attr { + key: "_handle_dtypes" + value { + list { type: DT_FLOAT } + } + } + attr { + key: "_handle_shapes" + value { + list { + shape { + dim { size: 128 } + dim { size: 1024 } + } + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { unknown_rank: true } + } + } + } + attr { + key: "_user_specified_name" + value { s: "935" } + } + attr { + key: "index" + value { i: 19 } + } +} + +node { + name: "tpu_compile_mlir" + op: "_TPUCompileMlir" + device: "/job:tpu_host_worker/replica:0/task:0/device:CPU:0" + attr { + key: "NumDynamicShapes" + value { i: 0 } + } + attr { + key: "_output_shapes" + value { + list { + shape {} + shape { dim { size: 3 } } + shape { dim { size: 3 } } + shape { dim { size: 3 } } + shape { dim { size: 3 } } + } + } + } + attr { + key: "metadata" + value { s: "" } + } + attr { + key: "mlir_module" + value { s: "" } + } + attr { + key: "num_computations" + value { i: 4 } + } +} + +node { + name: "readvariableop_1" + op: "ReadVariableOp" + input: "unknown_17" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { size: 128 } + dim { size: 1024 } + } + } + } + } + attr { + key: "dtype" + value { type: DT_FLOAT } + } +} + +node { + name: "identity_1" + op: "Identity" + input: "readvariableop_1" + device: "/job:tpu_host_worker/replica:0/task:0/device:CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { size: 128 } + dim { size: 1024 } + } + } + } + } + attr { + key: "_parallel_execution_ids" + value { + s: "r0:0" + } + } + attr { + key: "ici_weight_distribution_mlir_bridge_marker" + value { b: true } + } +} + +node { + name: "const_1" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape {} + } + } + } + attr { + key: "_parallel_execution_ids" + value { s: "r0:0" } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "ici_weight_distribution_mlir_bridge_marker" + value { b: true } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape {} + } + } + } +} + +node { + name: "split_1" + op: "Split" + input: "const_1" + input: "identity_1" + attr { + key: "T" + value { type: DT_FLOAT } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { size: 32 } + dim { size: 1024 } + } + shape { + dim { size: 32 } + dim { size: 1024 } + } + shape { + dim { size: 32 } + dim { size: 1024 } + } + shape { + dim { size: 32 } + dim { size: 1024 } + } + } + } + } + attr { + key: "_parallel_execution_ids" + value { + s: "r0:0" + } + } + attr { + key: "ici_weight_distribution_mlir_bridge_marker" + value { b: true } + } + attr { + key: "num_split" + value { i: 4 } + } +} + +node { + name: "readvariableop_2" + op: "ReadVariableOp" + input: "unknown_2" + attr { + key: "_output_shapes" + value { + list { + shape { dim { size: 1024 } } + } + } + } + attr { + key: "dtype" + value { type: DT_FLOAT } + } +} + +node { + name: "identity_2" + op: "Identity" + input: "readvariableop_2" + device: "/job:tpu_host_worker/replica:0/task:0/device:CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { dim { size: 1024 } } + } + } + } + attr { + key: "_parallel_execution_ids" + value { + s: "r0:0" + } + } + attr { + key: "ici_weight_distribution_mlir_bridge_marker" + value { b: true } + } +} + +node { + name: "tpu_execute_1" + op: "TPUExecute" + input: "split_1" + input: "identity_2" + input: "tpu_compile_mlir:1" + device: "/job:tpu_host_worker/replica:0/task:0/device:TPU:0" + attr { + key: "Targs" + value { + list { + type: DT_FLOAT + type: DT_FLOAT + } + } + } + attr { + key: "Tresults" + value { + list { + type: DT_INT32 + type: DT_FLOAT + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape {} + shape { + dim { size: 32 } + dim { size: 1024 } + } + } + } + } + attr { + key: "_parallel_execution_ids" + value { + s: "r0:0,p0:0" + } + } +} + +node { + name: "tpu_execute_2" + op: "TPUExecute" + input: "split_1:1" + input: "identity_2" + input: "tpu_compile_mlir:2" + device: "/job:tpu_host_worker/replica:0/task:0/device:TPU:1" + attr { + key: "Targs" + value { + list { + type: DT_FLOAT + type: DT_FLOAT + } + } + } + attr { + key: "Tresults" + value { + list { + type: DT_INT32 + type: DT_FLOAT + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape {} + shape { + dim { size: 32 } + dim { size: 1024 } + } + } + } + } + attr { + key: "_parallel_execution_ids" + value { + s: "r0:0,p0:1" + } + } +} + +node { + name: "const_3" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "ici_weight_distribution_mlir_bridge_marker" + value { + b: true + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\200\000\000\000\000\000\000\000\000\004\000\000\000\000\000\000" + } + } + } +} + +node { + name: "const_4" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "ici_weight_distribution_mlir_bridge_marker" + value { + b: true + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + } + } + } +} + +node { + name: "fill_1" + op: "Fill" + input: "const_3" + input: "const_4" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 128 + } + dim { + size: 1024 + } + } + } + } + } + attr { + key: "ici_weight_distribution_mlir_bridge_marker" + value { + b: true + } + } + attr { + key: "index_type" + value { + type: DT_INT64 + } + } +} + +node { + name: "identity_3" + op: "Identity" + input: "fill_1" + device: "/job:tpu_host_worker/replica:0/task:2/device:CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 128 + } + dim { + size: 1024 + } + } + } + } + } + attr { + key: "_parallel_execution_ids" + value { + s: "r0:1" + } + } + attr { + key: "ici_weight_distribution_mlir_bridge_marker" + value { + b: true + } + } +} + +node { + name: "const_2" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "_parallel_execution_ids" + value { + s: "r0:1" + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "ici_weight_distribution_mlir_bridge_marker" + value { + b: true + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + } + } + } +} + +node { + name: "split_2" + op: "Split" + input: "const_2" + input: "identity_3" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { size: 32 } + dim { size: 1024 } + } + shape { + dim { size: 32 } + dim { size: 1024 } + } + shape { + dim { size: 32 } + dim { size: 1024 } + } + shape { + dim { size: 32 } + dim { size: 1024 } + } + } + } + } + attr { + key: "_parallel_execution_ids" + value { s: "r0:1" } + } + attr { + key: "ici_weight_distribution_mlir_bridge_marker" + value { b: true } + } + attr { + key: "num_split" + value { i: 4 } + } +} + +node { + name: "const_5" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "ici_weight_distribution_mlir_bridge_marker" + value { + b: true + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + dim { + size: 1 + } + } + int64_val: 1024 + } + } + } +} + +node { + name: "const_6" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "bcast_id" + value { + i: 4 + } + } + attr { + key: "dtype" + value { + type: DT_BFLOAT16 + } + } + attr { + key: "ici_weight_distribution_mlir_bridge_marker" + value { + b: true + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_BFLOAT16 + tensor_shape { + } + } + } + } + experimental_debug_info { + original_node_names: "Identity" + original_func_names: "__inference__train_helper_851" + } +} + +node { + name: "fill_2" + op: "Fill" + input: "const_5" + input: "const_6" + attr { + key: "T" + value { + type: DT_BFLOAT16 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1024 + } + } + } + } + } + attr { + key: "ici_weight_distribution_mlir_bridge_marker" + value { + b: true + } + } + attr { + key: "index_type" + value { + type: DT_INT64 + } + } +} + +node { + name: "identity_4" + op: "Identity" + input: "fill_2" + device: "/job:tpu_host_worker/replica:0/task:2/device:CPU:0" + attr { + key: "T" + value { + type: DT_BFLOAT16 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1024 + } + } + } + } + } + attr { + key: "_parallel_execution_ids" + value { + s: "r0:1" + } + } + attr { + key: "ici_weight_distribution_mlir_bridge_marker" + value { + b: true + } + } +} + +node { + name: "tpu_execute_3" + op: "TPUExecute" + input: "split_2" + input: "identity_4" + input: "tpu_compile_mlir:1" + device: "/job:tpu_host_worker/replica:0/task:2/device:TPU:0" + attr { + key: "Targs" + value { + list { + type: DT_FLOAT + type: DT_BFLOAT16 + } + } + } + attr { + key: "Tresults" + value { + list { + type: DT_INT32 + type: DT_FLOAT + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape {} + shape { + dim { size: 32 } + dim { size: 1024 } + } + } + } + } + attr { + key: "_parallel_execution_ids" + value { + s: "r0:1,p0:0" + } + } +} + +node { + name: "tpu_execute_4" + op: "TPUExecute" + input: "split_2:1" + input: "identity_4" + input: "tpu_compile_mlir:2" + device: "/job:tpu_host_worker/replica:0/task:2/device:TPU:1" + attr { + key: "Targs" + value { + list { + type: DT_FLOAT + type: DT_BFLOAT16 + } + } + } + attr { + key: "Tresults" + value { + list { + type: DT_INT32 + type: DT_FLOAT + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + shape { + dim { + size: 32 + } + dim { + size: 1024 + } + } + } + } + } + attr { + key: "_parallel_execution_ids" + value { + s: "r0:1,p0:1" + } + } +} diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl index d35fbedc4f4413..aa40b33739a7e3 100644 --- a/tensorflow/tensorflow.bzl +++ b/tensorflow/tensorflow.bzl @@ -1762,6 +1762,7 @@ def tf_cc_tests( srcs, deps, name = "", + data = [], linkstatic = 0, tags = [], size = "medium", @@ -1779,6 +1780,7 @@ def tf_cc_tests( size = size, srcs = [src], args = args, + data = data, kernels = kernels, linkopts = linkopts, linkstatic = linkstatic, From a5508d4f7feec3230441dd2b747d944b5dac5fbe Mon Sep 17 00:00:00 2001 From: Ionel Gog Date: Fri, 27 Sep 2024 16:55:28 -0700 Subject: [PATCH 391/483] [IFRT] Add donated_input_indices attribute to CallOp to distinguish between donation and aliasing. PiperOrigin-RevId: 679775172 --- .../xla/xla/python/ifrt/ir/ifrt_ops.cc | 37 ++-- .../xla/xla/python/ifrt/ir/ifrt_ops.td | 22 ++- .../ifrt_populate_atom_program_metadata.mlir | 28 +-- .../ifrt/ir/tests/ifrt_verify_donation.mlir | 42 ++++- .../xla/python/ifrt/ir/tests/verify_call.mlir | 55 +++++- .../tests/verify_call_loaded_executable.mlir | 46 ++++- ...frt_populate_atom_program_metadata_pass.cc | 3 + .../transforms/ifrt_verify_donation_pass.cc | 165 ++++++++++-------- .../xla/python/ifrt/ir/transforms/passes.td | 3 + 9 files changed, 296 insertions(+), 105 deletions(-) diff --git a/third_party/xla/xla/python/ifrt/ir/ifrt_ops.cc b/third_party/xla/xla/python/ifrt/ir/ifrt_ops.cc index 080f0faf76e725..cc37e5010e0840 100644 --- a/third_party/xla/xla/python/ifrt/ir/ifrt_ops.cc +++ b/third_party/xla/xla/python/ifrt/ir/ifrt_ops.cc @@ -248,12 +248,24 @@ mlir::LogicalResult VerifyIoAlias(mlir::Operation* op, IoAlias io_alias, return mlir::success(); } -mlir::LogicalResult VerifyIoAliases(mlir::Operation* op, - mlir::ArrayAttr io_aliases, - llvm::ArrayRef inputs, - llvm::ArrayRef outputs) { - llvm::SmallSet aliased_inputs; +mlir::LogicalResult VerifyIoAliasesAndDonations( + mlir::Operation* op, mlir::ArrayAttr io_aliases, + llvm::ArrayRef donated_input_indices, + llvm::ArrayRef inputs, + llvm::ArrayRef outputs) { + llvm::SmallSet aliased_or_donated_inputs; llvm::SmallSet aliased_outputs; + for (const int32_t donated_input_index : donated_input_indices) { + if (donated_input_index < 0 || donated_input_index >= inputs.size()) { + return op->emitOpError() + << "can't donate input #" << donated_input_index + << " as only having " << inputs.size() << " inputs"; + } + if (!aliased_or_donated_inputs.insert(donated_input_index).second) { + return op->emitOpError() << "can't donate input #" << donated_input_index + << " more than once"; + } + } for (const auto& raw_io_alias : io_aliases.getAsRange()) { llvm::ArrayRef io_alias_as_array = raw_io_alias.asArrayRef(); @@ -263,9 +275,9 @@ mlir::LogicalResult VerifyIoAliases(mlir::Operation* op, inputs, outputs))) { return mlir::failure(); } - if (!aliased_inputs.insert(aliased_input).second) { - return op->emitOpError() - << "can't alias input #" << aliased_input << " more than once"; + if (!aliased_or_donated_inputs.insert(aliased_input).second) { + return op->emitOpError() << "can't alias or donate input #" + << aliased_input << " more than once"; } if (!aliased_outputs.insert(aliased_output).second) { return op->emitOpError() @@ -618,8 +630,9 @@ mlir::LogicalResult CallOp::verify() { if (mlir::failed(VerifyDevicePlacement(*this, getDevices(), input_arrays, output_arrays)) || - mlir::failed(VerifyIoAliases(*this, getIoAliases(), input_arrays, - output_arrays))) { + mlir::failed(VerifyIoAliasesAndDonations(*this, getIoAliases(), + getDonatedInputIndices(), + input_arrays, output_arrays))) { return mlir::failure(); } return mlir::success(); @@ -680,7 +693,9 @@ mlir::LogicalResult CallLoadedExecutableOp::verify() { output_arrays.push_back(mlir::cast(output.getType())); } - return VerifyIoAliases(*this, getIoAliases(), input_arrays, output_arrays); + return VerifyIoAliasesAndDonations(*this, getIoAliases(), + getDonatedInputIndices(), input_arrays, + output_arrays); } mlir::LogicalResult LoadedExecutableOp::verify() { diff --git a/third_party/xla/xla/python/ifrt/ir/ifrt_ops.td b/third_party/xla/xla/python/ifrt/ir/ifrt_ops.td index 937cdf96ca6e30..1f4bde2c710152 100644 --- a/third_party/xla/xla/python/ifrt/ir/ifrt_ops.td +++ b/third_party/xla/xla/python/ifrt/ir/ifrt_ops.td @@ -182,9 +182,11 @@ def Ifrt_CallOp : Ifrt_Op<"Call", a subset of these devices. `io_aliases` represents pairs of inputs and outputs, where the input buffer - may be donated and used as the output buffer. The aliased pair must have the - same Ifrt_ArrayType. It's up to IFRT implementations whether to respect this - hint or not. + may be aliased and used as the output buffer. The aliased pair must have the + same byte size. It's up to IFRT implementations whether to respect this + hint or not. Alternatively, if the index of an input is In + `donated_input_indices` then the input buffer might be donated to the + callee if an output with the same byte size is found. }]; let arguments = (ins @@ -192,7 +194,8 @@ def Ifrt_CallOp : Ifrt_Op<"Call", Variadic:$control_inputs, SymbolRefAttr:$callee, Ifrt_DevicesAttr:$devices, - DefaultValuedAttr:$io_aliases); + DefaultValuedAttr:$io_aliases, + DefaultValuedAttr:$donated_input_indices); let results = (outs Variadic:$outputs, Ifrt_ControlType:$control_output); @@ -220,16 +223,19 @@ def Ifrt_CallLoadedExecutableOp : Ifrt_Op<"CallLoadedExecutable", be placed on a subset of these devices. `io_aliases` represents pairs of inputs and outputs, where the input buffer - may be donated and used as the output buffer. The aliased pair must have the - same Ifrt_ArrayType. It's up to IFRT implementations whether to respect this - hint or not. + may be aliased and used as the output buffer. The aliased pair must have the + same byte size. It's up to IFRT implementations whether to respect this + hint or not. Alternatively, if the index of an input is In + `donated_input_indices` then the input buffer might be donated to the + callee if an output with the same byte size is found. }]; let arguments = (ins Variadic:$inputs, Variadic:$control_inputs, SymbolRefAttr:$callee, - DefaultValuedAttr:$io_aliases); + DefaultValuedAttr:$io_aliases, + DefaultValuedAttr:$donated_input_indices); let results = (outs Variadic:$outputs, Ifrt_ControlType:$control_output); diff --git a/third_party/xla/xla/python/ifrt/ir/tests/ifrt_populate_atom_program_metadata.mlir b/third_party/xla/xla/python/ifrt/ir/tests/ifrt_populate_atom_program_metadata.mlir index 783ebb26bfdca1..3e3728d864ed35 100644 --- a/third_party/xla/xla/python/ifrt/ir/tests/ifrt_populate_atom_program_metadata.mlir +++ b/third_party/xla/xla/python/ifrt/ir/tests/ifrt_populate_atom_program_metadata.mlir @@ -153,16 +153,17 @@ module @call_twice_with_different_sharding { !array = !ifrt.array, #ifrt.sharding_param<2x1 to [0] on 2>, [0,1]> -// CHECK-LABEL: @populate_io_alias -module @populate_io_alias { - func.func @main(%arg0: !array) attributes {ifrt.function} { - // CHECK: ifrt.Call @[[CALLEE_0:.+]]::@main(%arg0) - %0, %ctrl_0 = ifrt.Call @callee::@main(%arg0) on devices [0,1] - {io_aliases=[array]} : (!array) -> !array +// CHECK-LABEL: @populate_io_alias_and_donation +module @populate_io_alias_and_donation { + func.func @main(%arg0: !array, %arg1: !array) attributes {ifrt.function} { + // CHECK: ifrt.Call @[[CALLEE_0:.+]]::@main(%arg0, %arg1) + %0, %ctrl_0 = ifrt.Call @callee::@main(%arg0, %arg1) on devices [0,1] + {io_aliases=[array], donated_input_indices=array} + : (!array, !array) -> !array // Verify that the module is cloned if io_aliases differ. - // CHECK: ifrt.Call @[[CALLEE_1:.+]]::@main(%arg0) - %1, %ctrl_1 = ifrt.Call @callee::@main(%arg0) on devices [0,1] - : (!array) -> !array + // CHECK: ifrt.Call @[[CALLEE_1:.+]]::@main(%arg0, %arg1) + %1, %ctrl_1 = ifrt.Call @callee::@main(%arg0, %arg1) on devices [0,1] + : (!array, !array) -> !array return } @@ -188,8 +189,15 @@ module @populate_io_alias { // CHECK-DAG: ifrt.devices = #ifrt // CHECK-DAG: tf.aliasing_output = 0 : i32 // CHECK-SAME: } + // CHECK: %arg1: tensor<2x2xi32> + // CHECK-SAME: { + // CHECK-DAG: ifrt.sharding = #ifrt.sharding_param<2x1 to [0] on 2> + // CHECK-DAG: ifrt.devices = #ifrt + // CHECK-DAG: jax.buffer_donor = true + // CHECK-SAME: } module @callee attributes {sym_visibility = "private"} { - func.func private @main(%arg0: tensor<2x2xi32>) -> tensor<2x2xi32> { + func.func private @main(%arg0: tensor<2x2xi32>, %arg1: tensor<2x2xi32>) + -> tensor<2x2xi32> { return %arg0: tensor<2x2xi32> } } diff --git a/third_party/xla/xla/python/ifrt/ir/tests/ifrt_verify_donation.mlir b/third_party/xla/xla/python/ifrt/ir/tests/ifrt_verify_donation.mlir index 8c70318c03598c..92bed2748c2188 100644 --- a/third_party/xla/xla/python/ifrt/ir/tests/ifrt_verify_donation.mlir +++ b/third_party/xla/xla/python/ifrt/ir/tests/ifrt_verify_donation.mlir @@ -41,7 +41,7 @@ module @donate_to_reshard_duplicated_arg { // ----- !array = !ifrt.array, #ifrt.sharding_param<2 to [0] on 2>, [0, 1]> -module @donate_to_two_calls_error { +module @alias_to_two_calls_error { func.func @main(%arg0: !array {ifrt.donated}) -> (!array, !array) attributes {ifrt.function} { %0, %ctrl_0 = ifrt.Call @identity(%arg0) on devices [0,1] @@ -59,13 +59,49 @@ module @donate_to_two_calls_error { // ----- +!array = !ifrt.array, #ifrt.sharding_param<2 to [0] on 2>, [0, 1]> +module @donate_to_two_calls_error { + func.func @main(%arg0: !array {ifrt.donated}) -> (!array, !array) + attributes {ifrt.function} { + %0, %ctrl_0 = ifrt.Call @identity(%arg0) on devices [0,1] + {donated_input_indices=array} : (!array) -> !array + // expected-error @+1 {{'ifrt.Call' op input #0 of @identity was already donated}} + %1, %ctrl_1 = ifrt.Call @identity(%arg0) on devices [0,1] + {donated_input_indices=array} : (!array) -> !array + return %0, %1 : !array, !array + } + + func.func private @identity(%arg0: tensor<2xi32>) -> tensor<2xi32> { + return %arg0 : tensor<2xi32> + } +} + +// ----- + +!array = !ifrt.array, #ifrt.sharding_param<2 to [0] on 2>, [0, 1]> +module @arg_donated_to_call_not_donated_to_program { + func.func @main(%arg0: !array) -> (!array) + attributes {ifrt.function} { + // expected-error @+1 {{'ifrt.Call' op input #0 has not been donated to the program.}} + %0, %ctrl_0 = ifrt.Call @identity(%arg0) on devices [0,1] + {donated_input_indices=array} : (!array) -> !array + return %0 : !array + } + + func.func private @identity(%arg0: tensor<2xi32>) -> tensor<2xi32> { + return %arg0 : tensor<2xi32> + } +} + +// ----- + !array0 = !ifrt.array, #ifrt.sharding_param<2 to [0] on 2>, [0, 1]> !array1 = !ifrt.array, #ifrt.sharding_param<2 to [0] on 2>, [2, 3]> module @program_arg_not_donated_error { func.func @main(%arg0: !array0) -> (!array1) attributes {ifrt.function} { - // expected-error @+1 {{'ifrt.Reshard' op input has not been donated to the program.}} + // expected-error @+1 {{'ifrt.Reshard' op input #0 has not been donated to the program.}} %0, %ctrl_0 = ifrt.Reshard(%arg0) {donated=true} : (!array0) -> !array1 return %0 : !array1 } @@ -167,7 +203,7 @@ module @donate_to_two_copy_arrays_error { module @program_arg_not_donated_to_remap_error { func.func @main(%arg0: !array {ifrt.donated}, %arg1: !array) -> (!array) attributes {ifrt.function} { - // expected-error @+1 {{'ifrt.RemapArrays' op input has not been donated to the program.}} + // expected-error @+1 {{'ifrt.RemapArrays' op input #1 has not been donated to the program.}} %0 = ifrt.RemapArrays(%arg0, %arg1) mappings=[#ifrt.array_mapping<0, 0, [#ifrt.mapping<[0:1:1] to [0:1:1]>]>, #ifrt.array_mapping<1, 0, [#ifrt.mapping<[0:1:1] to [1:2:1]>]>] diff --git a/third_party/xla/xla/python/ifrt/ir/tests/verify_call.mlir b/third_party/xla/xla/python/ifrt/ir/tests/verify_call.mlir index e512b260600e73..202724e44496a0 100644 --- a/third_party/xla/xla/python/ifrt/ir/tests/verify_call.mlir +++ b/third_party/xla/xla/python/ifrt/ir/tests/verify_call.mlir @@ -293,7 +293,7 @@ func.func @io_aliases_should_only_alias_input_once( %arg0: !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, [0,1]>) attributes {ifrt.function} { - // expected-error@+1 {{'ifrt.Call' op can't alias input #0 more than once}} + // expected-error@+1 {{'ifrt.Call' op can't alias or donate input #0 more than once}} %0, %1, %ctrl_0 = ifrt.Call @callee(%arg0) on devices [0,1] {io_aliases=[array, array]} : (!ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, @@ -429,4 +429,57 @@ func.func @call_local_view_should_have_valid_shape( func.func @callee(%arg0: tensor<4x4xi32>) -> tensor<4x4xi32> { return %arg0 : tensor<4x4xi32> +} + +// ----- + +!array = !ifrt.array, + #ifrt.sharding_param<1x1 to [0] on 2>, [0,1]> +func.func @donate_an_arg_and_alias_another(%arg0: !array, %arg1: !array) + attributes {ifrt.function} { + %0, %ctrl_0 = ifrt.Call @callee(%arg0, %arg1) on devices [0,1] + {donated_input_indices=array, io_aliases=[array]} + : (!array, !array) -> !array + return +} + +func.func @callee(%arg0: tensor<2x2xi32>, %arg1: tensor<2x2xi32>) + -> tensor<2x2xi32> { + return %arg0 : tensor<2x2xi32> +} + +// ----- + +!array = !ifrt.array, + #ifrt.sharding_param<1x1 to [0] on 2>, [0,1]> +func.func @should_only_donate_once(%arg0: !array, %arg1: !array) + attributes {ifrt.function} { + // expected-error@+1 {{'ifrt.Call' op can't donate input #0 more than once}} + %0, %ctrl_0 = ifrt.Call @callee(%arg0, %arg1) on devices [0,1] + {donated_input_indices=array} + : (!array, !array) -> !array + return +} + +func.func @callee(%arg0: tensor<2x2xi32>, %arg1: tensor<2x2xi32>) + -> tensor<2x2xi32> { + return %arg0 : tensor<2x2xi32> +} + +// ----- + +!array = !ifrt.array, + #ifrt.sharding_param<1x1 to [0] on 2>, [0,1]> +func.func @should_not_both_donate_and_alias_the_same_arg( + %arg0: !array, %arg1: !array) attributes {ifrt.function} { + // expected-error@+1 {{'ifrt.Call' op can't alias or donate input #0 more than once}} + %0, %ctrl_0 = ifrt.Call @callee(%arg0, %arg1) on devices [0,1] + {donated_input_indices=array, io_aliases=[array]} + : (!array, !array) -> !array + return +} + +func.func @callee(%arg0: tensor<2x2xi32>, %arg1: tensor<2x2xi32>) + -> tensor<2x2xi32> { + return %arg0 : tensor<2x2xi32> } \ No newline at end of file diff --git a/third_party/xla/xla/python/ifrt/ir/tests/verify_call_loaded_executable.mlir b/third_party/xla/xla/python/ifrt/ir/tests/verify_call_loaded_executable.mlir index 14485f4c86a4e0..e41add06877c60 100644 --- a/third_party/xla/xla/python/ifrt/ir/tests/verify_call_loaded_executable.mlir +++ b/third_party/xla/xla/python/ifrt/ir/tests/verify_call_loaded_executable.mlir @@ -145,7 +145,7 @@ func.func @io_aliases_should_only_alias_input_once( %arg0: !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, [0,1]>) attributes {ifrt.function} { - // expected-error@+1 {{'ifrt.CallLoadedExecutable' op can't alias input #0 more than once}} + // expected-error@+1 {{'ifrt.CallLoadedExecutable' op can't alias or donate input #0 more than once}} %0, %1, %ctrl_0 = ifrt.CallLoadedExecutable @callee(%arg0) {io_aliases=[array, array]} : (!ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, @@ -230,3 +230,47 @@ ifrt.LoadedExecutable @callee on devices [0,1] [0,1]>) -> !ifrt.array, #ifrt.sharding_param<2x1 to [0] on 2>, [0,1]> + + +// ----- + +!array = !ifrt.array, + #ifrt.sharding_param<1x1 to [0] on 2>, [0,1]> +func.func @donate_one_arg_and_alias_another_arg(%arg0: !array, %arg1: !array) + attributes {ifrt.function} { + %0, %ctrl_0 = ifrt.CallLoadedExecutable @callee(%arg0, %arg1) + {donated_input_indices=array, io_aliases=[array]} + : (!array, !array) -> !array + return +} + +ifrt.LoadedExecutable @callee on devices [0,1] : (!array, !array) -> !array + +// ----- + +!array = !ifrt.array, + #ifrt.sharding_param<1x1 to [0] on 2>, [0,1]> +func.func @should_only_donate_once(%arg0: !array, %arg1: !array) + attributes {ifrt.function} { + // expected-error@+1 {{'ifrt.CallLoadedExecutable' op can't donate input #0 more than once}} + %0, %ctrl_0 = ifrt.CallLoadedExecutable @callee(%arg0, %arg1) + {donated_input_indices=array} : (!array, !array) -> !array + return +} + +ifrt.LoadedExecutable @callee on devices [0,1] : (!array, !array) -> !array + +// ----- + +!array = !ifrt.array, + #ifrt.sharding_param<1x1 to [0] on 2>, [0,1]> +func.func @should_not_both_donate_and_alias_the_same_arg( + %arg0: !array, %arg1: !array) attributes {ifrt.function} { + // expected-error@+1 {{'ifrt.CallLoadedExecutable' op can't alias or donate input #0 more than once}} + %0, %ctrl_0 = ifrt.CallLoadedExecutable @callee(%arg0, %arg1) + {donated_input_indices=array, io_aliases=[array]} + : (!array, !array) -> !array + return +} + +ifrt.LoadedExecutable @callee on devices [0,1] : (!array, !array) -> !array diff --git a/third_party/xla/xla/python/ifrt/ir/transforms/ifrt_populate_atom_program_metadata_pass.cc b/third_party/xla/xla/python/ifrt/ir/transforms/ifrt_populate_atom_program_metadata_pass.cc index 5b6e8268ac0902..833658f923c929 100644 --- a/third_party/xla/xla/python/ifrt/ir/transforms/ifrt_populate_atom_program_metadata_pass.cc +++ b/third_party/xla/xla/python/ifrt/ir/transforms/ifrt_populate_atom_program_metadata_pass.cc @@ -159,6 +159,9 @@ mlir::LogicalResult PopulateMetadata(xla::ifrt::CallOp call_op, callee_op.setArgAttr(io_alias_as_array[0], "tf.aliasing_output", builder.getI32IntegerAttr(io_alias_as_array[1])); } + for (const auto idx : call_op.getDonatedInputIndices()) { + callee_op.setArgAttr(idx, "jax.buffer_donor", builder.getBoolAttr(true)); + } return mlir::success(); } diff --git a/third_party/xla/xla/python/ifrt/ir/transforms/ifrt_verify_donation_pass.cc b/third_party/xla/xla/python/ifrt/ir/transforms/ifrt_verify_donation_pass.cc index 7e3492147e1665..7a4fcfdf16be96 100644 --- a/third_party/xla/xla/python/ifrt/ir/transforms/ifrt_verify_donation_pass.cc +++ b/third_party/xla/xla/python/ifrt/ir/transforms/ifrt_verify_donation_pass.cc @@ -38,17 +38,100 @@ namespace { #include "xla/python/ifrt/ir/transforms/passes.h.inc" // Verifies that if the value is an input to the IR, then it has been donated. -mlir::LogicalResult VerifyIfInputAndDonated(mlir::Operation* op, +mlir::LogicalResult VerifyIfInputAndDonated(mlir::Operation* op, int idx, mlir::Value arg) { auto block_arg = mlir::dyn_cast(arg); mlir::func::FuncOp func_op = block_arg ? mlir::dyn_cast( block_arg.getOwner()->getParentOp()) : nullptr; - if (func_op && - func_op.getArgAttr(block_arg.getArgNumber(), - xla::ifrt::kIfrtDonatedArgAttrName) == nullptr) { - return op->emitOpError() << "input has not been donated to the program."; + if (func_op && func_op.getArgAttr(block_arg.getArgNumber(), + kIfrtDonatedArgAttrName) == nullptr) { + return op->emitOpError() + << "input #" << idx << " has not been donated to the program."; + } + return mlir::success(); +} + +template +mlir::LogicalResult verifyCallOpAliasesAndDonations( + T op, llvm::DenseMap& donated_value_to_op) { + llvm::DenseSet donated_input_idxs; + // Verify if a donated input is an argument of the main func, then it has + // also been donated by the user. + for (const auto idx : op.getDonatedInputIndices()) { + donated_input_idxs.insert(idx); + auto donated_value = op.getInputs()[idx]; + auto donated_it = donated_value_to_op.try_emplace(donated_value, op); + if (!donated_it.second) { + op.emitOpError() << "input #" << idx << " of " << op.getCalleeAttr() + << " was already donated or aliased to the op at " + << donated_it.first->second->getLoc(); + return mlir::failure(); + } + if (mlir::failed(VerifyIfInputAndDonated(op, idx, donated_value))) { + return mlir::failure(); + } + } + + for (const auto& io_alias : + op.getIoAliases().template getAsRange()) { + mlir::ArrayRef io_alias_as_array = io_alias.asArrayRef(); + donated_input_idxs.insert(io_alias_as_array[0]); + auto aliased_value = op.getInputs()[io_alias_as_array[0]]; + auto donated_it = donated_value_to_op.try_emplace(aliased_value, op); + if (!donated_it.second) { + op.emitOpError() << "input #" << io_alias_as_array[0] << " of " + << op.getCalleeAttr() + << " was already donated or aliased to the op at " + << donated_it.first->second->getLoc(); + return mlir::failure(); + } + if (mlir::failed( + VerifyIfInputAndDonated(op, io_alias_as_array[0], aliased_value))) { + return mlir::failure(); + } + } + + // Verify non-donated inputs after donated inputs have been + // added to also catch instances such as + // `ifrt.Call(%arg0 {ifrt.donated}, %arg0})`. + for (const auto [idx, input] : llvm::enumerate(op.getInputs())) { + if (!donated_input_idxs.contains(idx)) { + auto donated_it = donated_value_to_op.find(input); + if (donated_it != donated_value_to_op.end()) { + op.emitOpError() << "input #" << idx << " of " << op.getCalleeAttr() + << " was already donated to the op at " + << donated_it->second->getLoc(); + return mlir::failure(); + } + } + } + return mlir::success(); +} + +template +mlir::LogicalResult verifyCopyRemapAndReshardOpsDonation( + T op, llvm::DenseMap& donated_value_to_op) { + // Verify that no inputs have already been donated. + for (const auto [idx, input] : llvm::enumerate(op.getInputs())) { + auto donated_it = donated_value_to_op.find(input); + if (donated_it != donated_value_to_op.end()) { + op.emitOpError() << "input #" << idx << " of op at " << op.getLoc() + << " was already donated to the op at " + << donated_it->second->getLoc(); + return mlir::failure(); + } + } + if (op.getDonated()) { + // Add the donated inputs to the map and verify that all the + // donated inputs are also donated to the main func. + for (const auto [idx, input] : llvm::enumerate(op.getInputs())) { + donated_value_to_op.try_emplace(input, op); + if (mlir::failed(VerifyIfInputAndDonated(op, idx, input))) { + return mlir::failure(); + } + } } return mlir::success(); } @@ -74,72 +157,12 @@ void IfrtVerifyDonationPass::runOnOperation() { -> mlir::WalkResult { auto result = llvm::TypeSwitch(op) - .Case( - [&](auto& op) { - llvm::DenseSet donated_input_idxs; - for (const auto& io_alias : - op.getIoAliases() - .template getAsRange()) { - mlir::ArrayRef io_alias_as_array = - io_alias.asArrayRef(); - donated_input_idxs.insert(io_alias_as_array[0]); - auto donated_value = op.getInputs()[io_alias_as_array[0]]; - auto donated_it = - donated_value_to_op.try_emplace(donated_value, op); - if (!donated_it.second) { - op.emitOpError() << "input #" << io_alias_as_array[0] - << " of " << op.getCalleeAttr() - << " was already donated to the op at " - << donated_it.first->second->getLoc(); - return mlir::failure(); - } - if (mlir::failed( - VerifyIfInputAndDonated(op, donated_value))) { - return mlir::failure(); - } - } - // Verify non-donated inputs after donated inputs have been - // added to also catch instances such as - // `ifrt.Call(%arg0 {ifrt.donated}, %arg0})`. - for (const auto [idx, input] : - llvm::enumerate(op.getInputs())) { - if (!donated_input_idxs.contains(idx)) { - auto donated_it = donated_value_to_op.find(input); - if (donated_it != donated_value_to_op.end()) { - op.emitOpError() - << "input #" << idx << " of " << op.getCalleeAttr() - << " was already donated to the op at " - << donated_it->second->getLoc(); - return mlir::failure(); - } - } - } - return mlir::success(); - }) - .Case([&](auto& op) { - // Verify that no inputs have already been donated. - for (const auto [idx, input] : llvm::enumerate(op.getInputs())) { - auto donated_it = donated_value_to_op.find(input); - if (donated_it != donated_value_to_op.end()) { - op.emitOpError() - << "input #" << idx << " of op at " << op.getLoc() - << " was already donated to the op at " - << donated_it->second->getLoc(); - return mlir::failure(); - } - } - if (op.getDonated()) { - // Add the donated inputs to the map and verify that all the - // donated inputs are also donated to the main func. - for (const auto input : op.getInputs()) { - donated_value_to_op.try_emplace(input, op); - if (mlir::failed(VerifyIfInputAndDonated(op, input))) { - return mlir::failure(); - } - } - } - return mlir::success(); + .Case([&](auto& op) { + return verifyCallOpAliasesAndDonations(op, donated_value_to_op); + }) + .Case([&](auto& op) { + return verifyCopyRemapAndReshardOpsDonation(op, + donated_value_to_op); }) .Case([&](mlir::func::ReturnOp return_op) { for (const auto& [idx, result] : diff --git a/third_party/xla/xla/python/ifrt/ir/transforms/passes.td b/third_party/xla/xla/python/ifrt/ir/transforms/passes.td index a75fc059cb8e64..6acb62e678de29 100644 --- a/third_party/xla/xla/python/ifrt/ir/transforms/passes.td +++ b/third_party/xla/xla/python/ifrt/ir/transforms/passes.td @@ -214,6 +214,9 @@ For every CallOp, this pass main FuncOp 3. attaches `tf.aliasing_output` attr to the callee main FuncOp's inputs according to `io_aliases` + 4. attaches `jax.buffer_donor` attr to the callee main FuncOp's inputs + according to `donated_input_indices` + For CallOps with the same callee, a different clone will be created for each CallOp, even if the populated metadata are the same. User may want to run `ifrt-duplicated-callee-elimination` pass to dedup the clones. From 3f384035758d56aad6a937664f44f79367602815 Mon Sep 17 00:00:00 2001 From: Arturo Schmidt Date: Fri, 27 Sep 2024 17:08:32 -0700 Subject: [PATCH 392/483] Decouple SavedModelObjectGraphImporter from ImporterBase. PiperOrigin-RevId: 679778589 --- .../mlir/tensorflow/translate/import_model.cc | 32 +++---------------- 1 file changed, 5 insertions(+), 27 deletions(-) diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc index b6c704a4431fd4..848921c96704dd 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc @@ -2810,26 +2810,6 @@ Status GraphDefImporter::GetControlRetsFromGraph( return absl::OkStatus(); } -// Stateful helper class to import a TensorFlow model expressed in SavedModel -// into an MLIR Module. -class SavedModelObjectGraphImporter : public ImporterBase { - public: - // Main entry point: converts all functions in the given meta graph to an MLIR - // Module. - static absl::StatusOr> Convert( - SavedModelV2Bundle* saved_model, absl::Span exported_names, - mlir::MLIRContext* context, MLIRImportOptions options); - - private: - explicit SavedModelObjectGraphImporter( - const FunctionLibraryDefinition& flib, const GraphDebugInfo& debug_info, - const GraphImportConfig& specs, mlir::ModuleOp module, - std::unordered_map* tf_name_to_mlir_name, - NameUniquifier* function_name_uniquifier) - : ImporterBase(flib, debug_info, specs, module, tf_name_to_mlir_name, - function_name_uniquifier) {} -}; - // Determines the names used to reference objects in the SavedObjectGraph. class ObjectNames { public: @@ -3585,11 +3565,9 @@ Status CreateSavedModelIR( return absl::OkStatus(); } -absl::StatusOr> -SavedModelObjectGraphImporter::Convert(SavedModelV2Bundle* saved_model, - absl::Span exported_names, - mlir::MLIRContext* context, - MLIRImportOptions import_options) { +absl::StatusOr> ConvertSavedModelObjectGraph( + SavedModelV2Bundle* saved_model, absl::Span exported_names, + mlir::MLIRContext* context, MLIRImportOptions import_options) { LoadImporterDialects(*context); GraphDebugInfo dummy_debug_info; const GraphDebugInfo& debug_info = @@ -4375,8 +4353,8 @@ absl::StatusOr> ConvertFunctionToMlir( absl::StatusOr> ConvertSavedModelToMlir( SavedModelV2Bundle* saved_model, mlir::MLIRContext* context, absl::Span exported_names, MLIRImportOptions options) { - return SavedModelObjectGraphImporter::Convert(saved_model, exported_names, - context, options); + return ConvertSavedModelObjectGraph(saved_model, exported_names, context, + options); } absl::StatusOr> ConvertSavedModelV1ToMlir( From c28b27925ff2775e8a4de7be21abde1abff02863 Mon Sep 17 00:00:00 2001 From: Philipp Hack Date: Fri, 27 Sep 2024 17:32:26 -0700 Subject: [PATCH 393/483] PR #16841: Delete FP8 Scaling Factors in GEMM Rewriter Imported from GitHub PR https://github.com/openxla/xla/pull/16841 Removes the scaling factors of C and D (matrix bias and result) from FP8 Custom Calls created in the GEMM rewriter when their data types are not FP8. See https://github.com/openxla/xla/pull/15795. Copybara import of the project: -- fd9750fa8474fe72fe641c7b3bc005ff30396e0a by Philipp Hack : Removes superfluous FP8 scaling factors in GEMM rewriter. Merging this change closes #16841 PiperOrigin-RevId: 679784586 --- .../xla/service/gpu/ir_emitter_unnested.cc | 23 ++- .../xla/xla/service/gpu/matmul_utils.cc | 4 +- .../service/gpu/transforms/gemm_rewriter.cc | 27 ++-- .../gpu/transforms/gemm_rewriter_test.cc | 137 ++++++++---------- .../xla/stream_executor/cuda/cuda_blas_lt.cc | 7 +- 5 files changed, 90 insertions(+), 108 deletions(-) diff --git a/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc b/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc index b7da6271befc9f..ea0577e05077f5 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc +++ b/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc @@ -740,8 +740,7 @@ absl::Status IrEmitterUnnested::EmitCublasLtMatmulThunk( absl::Status IrEmitterUnnested::EmitCublasLtMatmulThunkF8( const HloCustomCallInstruction* instr) { - TF_RET_CHECK(instr->operand_count() == 6 || instr->operand_count() == 7 || - instr->operand_count() == 8); + TF_RET_CHECK(instr->operand_count() > 3 && instr->operand_count() < 8); TF_ASSIGN_OR_RETURN(const auto gpu_config, instr->backend_config()); const xla::gpu::GemmBackendConfig& config = gpu_config.gemm_backend_config(); @@ -777,22 +776,22 @@ absl::Status IrEmitterUnnested::EmitCublasLtMatmulThunkF8( TF_ASSIGN_OR_RETURN( BufferAllocation::Slice b_scale, GetAllocationSliceForHlo(instr->operand(a_scale_index + 1))); + + // cublasLT requires c_scale/d_scale to be null when C/D is not FP8. + // Currently, C cannot be FP8. + BufferAllocation::Slice c_scale, d_scale; #if GOOGLE_CUDA - TF_ASSIGN_OR_RETURN( - BufferAllocation::Slice c_scale, - GetAllocationSliceForHlo(instr->operand(a_scale_index + 2))); - TF_ASSIGN_OR_RETURN( - BufferAllocation::Slice d_scale, - GetAllocationSliceForHlo(instr->operand(a_scale_index + 3))); -#else // TENSORFLOW_USE_ROCM - BufferAllocation::Slice c_scale; - BufferAllocation::Slice d_scale; + if (instr->shape().tuple_shapes(0).element_type() == F8E4M3FN || + instr->shape().tuple_shapes(0).element_type() == F8E5M2) { + TF_ASSIGN_OR_RETURN(d_scale, + GetAllocationSliceForHlo(instr->operands().back())); + } #endif BufferAllocation::Slice bias; if (has_vector_bias) { TF_ASSIGN_OR_RETURN( - bias, GetAllocationSliceForHlo(instr->operand(a_scale_index + 4))); + bias, GetAllocationSliceForHlo(instr->operand(a_scale_index + 2))); } BufferAllocation::Slice d_amax; diff --git a/third_party/xla/xla/service/gpu/matmul_utils.cc b/third_party/xla/xla/service/gpu/matmul_utils.cc index 49270de65ecd3f..401832e2d17e0e 100644 --- a/third_party/xla/xla/service/gpu/matmul_utils.cc +++ b/third_party/xla/xla/service/gpu/matmul_utils.cc @@ -484,8 +484,8 @@ bool IsTf32Allowed(PrecisionConfig::Algorithm algorithm, if (has_vector_bias) { int vector_bias_index = has_matrix_bias ? 3 : 2; if (primitive_util::IsF8Type(lhs_shape.element_type())) { - // FP8 gemms have 4 scales as inputs which come before the vector bias. - vector_bias_index += 4; + // FP8 gemms have 2 scales as inputs which come before the vector bias. + vector_bias_index += 2; } vector_bias_shape = gemm->operand(vector_bias_index)->shape(); } diff --git a/third_party/xla/xla/service/gpu/transforms/gemm_rewriter.cc b/third_party/xla/xla/service/gpu/transforms/gemm_rewriter.cc index dea1f704c5801e..d7674efc15e943 100644 --- a/third_party/xla/xla/service/gpu/transforms/gemm_rewriter.cc +++ b/third_party/xla/xla/service/gpu/transforms/gemm_rewriter.cc @@ -1083,12 +1083,18 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { // cuBLASLt FP8 GEMM kernels require the scaling factors to be in F32 // format. Set the factors to one when no scaling factors were captured. - Literal one_literal = LiteralUtil::One(F32); - HloInstruction *one = instr->AddInstruction( - HloInstruction::CreateConstant(one_literal.Clone())); std::array mult_scale{a.mult_scale, b.mult_scale}; std::array scales{a.scale, b.scale}, inv_scales, scales_f32; + HloInstruction *one_constant = nullptr; + auto one = [&one_constant, instr]() -> HloInstruction * { + if (!one_constant) { + one_constant = instr->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::One(F32))); + } + return one_constant; + }; + for (int i = 0; i < scales.size(); ++i) { if (scales[i]) { if (!ShapeUtil::IsScalar(scales[i]->shape())) { @@ -1099,7 +1105,7 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { } if (!mult_scale[i]) { inv_scales[i] = instr->AddInstruction(HloInstruction::CreateBinary( - scales[i]->shape(), HloOpcode::kDivide, one, scales[i])); + scales[i]->shape(), HloOpcode::kDivide, one(), scales[i])); } scales_f32[i] = mult_scale[i] ? scales[i] : inv_scales[i]; if (scales_f32[i]->shape().element_type() != F32) { @@ -1107,7 +1113,7 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { ShapeUtil::MakeScalarShape(F32), scales_f32[i])); } } else { - scales_f32[i] = one; + scales_f32[i] = one(); } } @@ -1249,7 +1255,7 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { PadShapeToMultipleOf16(instr->shape(), out_batch_dims); std::vector operands_list = { - a.fp8_input, b.fp8_input, scales_f32[0], scales_f32[1], one, one}; + a.fp8_input, b.fp8_input, scales_f32[0], scales_f32[1]}; HloInstruction *new_custom_call = instr->AddInstruction(HloInstruction::CreateCustomCall( @@ -1415,13 +1421,16 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { } } - // If necessary, invert the scaling factor of D and convert to F32. + // If necessary, invert the scaling factor of D and convert to F32. When no + // scaling factor was captured, set the factor to one. if (d_scale) { TF_ASSIGN_OR_RETURN(d_scale, InvertAndConvertScalar(d_scale, !mult_scale)); - TF_RETURN_IF_ERROR(existing_gemm->ReplaceOperandWith( - gemm_backend_config.beta() == 0.0 ? 5 : 6, d_scale)); + } else { + d_scale = instr->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::One(F32))); } + existing_gemm->AppendOperand(d_scale); // If present, elide the calculation of the maximum of the absolute values // of the result of the GEMM. diff --git a/third_party/xla/xla/service/gpu/transforms/gemm_rewriter_test.cc b/third_party/xla/xla/service/gpu/transforms/gemm_rewriter_test.cc index 3df393a0c89d77..140787413d0f67 100644 --- a/third_party/xla/xla/service/gpu/transforms/gemm_rewriter_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/gemm_rewriter_test.cc @@ -4950,11 +4950,11 @@ TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABUnscaledDF8) { )"; if (IsRocm() && GetToolkitVersion() < se::SemanticVersion{6, 2, 0}) { checks.append( - R"(; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[C1]], [[C1]], [[C1]], /*index=5*/[[C1]]), + R"(; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[C1]], [[C1]]), )"); } else { checks.append( - R"(; CHECK-NEXT: [[OUT:%[^ ]+]] = (<>[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[C1]], [[C1]], [[C1]], /*index=5*/[[C1]]), + R"(; CHECK-NEXT: [[OUT:%[^ ]+]] = (<>[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[C1]], [[C1]]), )"); } checks.append( @@ -5009,7 +5009,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABUnscaledDMatrixBiasF8) { ; CHECK-NEXT: [[P1:%[^ ]+]] = <>[32,16]{1,0} parameter(1) ; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[16,32]{1,0} transpose([[P1]]), dimensions={1,0} ; CHECK-NEXT: [[C1:[^ ]+]] = f32[] constant(1) -; CHECK-NEXT: [[DOT_TUPLE:%[^ ]+]] = (<>[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[C1]], [[C1]], [[C1]], /*index=5*/[[C1]]), +; CHECK-NEXT: [[DOT_TUPLE:%[^ ]+]] = (<>[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[C1]], [[C1]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -5064,8 +5064,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDF8) { ; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[16,32]{1,0} transpose([[P1]]), dimensions={1,0} ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2) ; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3) -; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1) -; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[C1]]), +; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -5121,8 +5120,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDPaddedF8) { ; CHECK-NEXT: [[P1_TRANSPOSE_PADDED:%[^ ]+]] = <>[32,32]{1,0} pad([[P1_TRANSPOSE]], [[C1]]) ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2) ; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3) -; CHECK-NEXT: [[C4:%[^ ]+]] = f32[] constant(1) -; CHECK-NEXT: [[DOT_TUPLE:%[^ ]+]] = (f32[16,32]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0_PADDED]], [[P1_TRANSPOSE_PADDED]], [[P2]], [[P3]], [[C4]], /*index=5*/[[C4]]), +; CHECK-NEXT: [[DOT_TUPLE:%[^ ]+]] = (f32[16,32]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0_PADDED]], [[P1_TRANSPOSE_PADDED]], [[P2]], [[P3]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -5205,7 +5203,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABUnscaledDWithConvertF8) { ; CHECK-NEXT: [[P1:%[^ ]+]] = <>[32,16]{1,0} parameter(1) ; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[16,32]{1,0} transpose([[P1]]), dimensions={1,0} ; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1) -; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[C1]], [[C1]], [[C1]], /*index=5*/[[C1]]), +; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[C1]], [[C1]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -5269,8 +5267,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDUnaryOpsF8) { ; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[16,32]{1,0} transpose([[P1]]), dimensions={1,0} ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2) ; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3) -; CHECK-NEXT: [[C2:%[^ ]+]] = f32[] constant(1) -; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0_U3]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C2]], /*index=5*/[[C2]]), +; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0_U3]], [[P1_TRANSPOSE]], [[P2]], [[P3]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -5328,7 +5325,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ; CHECK-NEXT: [[P1:%[^ ]+]] = <>[32,16]{1,0} parameter(1) ; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[16,32]{1,0} transpose([[P1]]), dimensions={1,0} ; CHECK-NEXT: [[C2:%[^ ]+]] = f32[] constant(1) -; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0_U3]], [[P1_TRANSPOSE]], [[C2]], [[C2]], [[C2]], /*index=5*/[[C2]]), +; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0_U3]], [[P1_TRANSPOSE]], [[C2]], [[C2]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -5389,8 +5386,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDDynamicSliceF8) { ; CHECK-NEXT: [[P1:%[^ ]+]] = <>[16,32]{1,0} parameter(1) ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2) ; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3) -; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1) -; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[DYN_SLICE]], [[P1]], [[P2]], [[P3]], [[C1]], /*index=5*/[[C1]]), +; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[DYN_SLICE]], [[P1]], [[P2]], [[P3]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -5456,8 +5452,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDSelectF8) { ; CHECK-NEXT: [[SELECT:%[^ ]+]] = <>[16,32]{1,0} select([[P4]], [[P1]], [[C0_CONVERT]]) ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2) ; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3) -; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1) -; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[SELECT]], [[P2]], [[P3]], [[C1]], /*index=5*/[[C1]]), +; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[SELECT]], [[P2]], [[P3]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -5541,8 +5536,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, BatchedScaledABUnscaledDF8) { ; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[10,16,32]{2,1,0} transpose([[P1]]), dimensions={0,2,1} ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2) ; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3) -; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1) -; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[10,16,16]{2,1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[C1]]), +; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[10,16,16]{2,1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -5598,8 +5592,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABAlphaDF8) { ; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[16,32]{1,0} transpose([[P1]]), dimensions={1,0} ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2) ; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3) -; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1) -; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[C1]]), +; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":3 @@ -5655,8 +5648,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDReluActivationF8) { ; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[16,32]{1,0} transpose([[P1]]), dimensions={1,0} ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2) ; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3) -; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1) -; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[C1]]), +; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -5731,15 +5723,14 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ; CHECK-NEXT: [[XS:%[^ ]+]] = f32[] convert([[P2]]) ; CHECK-NEXT: [[P3:%[^ ]+]] = bf16[] parameter(3) ; CHECK-NEXT: [[XS1:%[^ ]+]] = f32[] convert([[P3]]) -; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1) )"; if (IsRocm() && GetToolkitVersion() < se::SemanticVersion{6, 2, 0}) { checks += - R"(; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[XS]], [[XS1]], [[C1]], /*index=5*/[[C1]]), + R"(; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[XS]], [[XS1]]), )"; } else { checks += R"(; CHECK-NEXT: [[B:%[^ ]+]] = bf16[16]{0} parameter(4) -; CHECK-NEXT: [[OUT:%[^ ]+]] = (bf16[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[XS]], [[XS1]], [[C1]], /*index=5*/[[C1]], [[B]]), +; CHECK-NEXT: [[OUT:%[^ ]+]] = (bf16[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[XS]], [[XS1]], [[B]]), )"; } checks += R"(; CHECK: custom_call_target="__cublas$lt$matmul$f8", @@ -5831,15 +5822,14 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ; CHECK-NEXT: [[XS:%[^ ]+]] = f32[] convert([[P2]]) ; CHECK-NEXT: [[P3:%[^ ]+]] = bf16[] parameter(3) ; CHECK-NEXT: [[XS1:%[^ ]+]] = f32[] convert([[P3]]) -; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1) )"; if (IsRocm() && GetToolkitVersion() < se::SemanticVersion{6, 2, 0}) { checks += - R"(; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[XS]], [[XS1]], [[C1]], /*index=5*/[[C1]]), + R"(; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[XS]], [[XS1]]), )"; } else { checks += - R"(; CHECK-NEXT: [[OUT:%[^ ]+]] = (bf16[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[XS]], [[XS1]], [[C1]], /*index=5*/[[C1]]), + R"(; CHECK-NEXT: [[OUT:%[^ ]+]] = (bf16[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[XS]], [[XS1]]), )"; } checks += R"(; CHECK: custom_call_target="__cublas$lt$matmul$f8", @@ -5943,8 +5933,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDMatrixBiasF8) { ; CHECK: [[C0:%[^ ]+]] = f32[16,16]{1,0} add({{.*}}) ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(3) ; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(4) -; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1) -; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[C0]], [[P2]], [[P3]], /*index=5*/[[C1]], [[C1]]), +; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[C0]], [[P2]], [[P3]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: output_to_operand_aliasing={ ; CHECK-SAME: {0}: (2, {}) @@ -6009,8 +5998,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDMatrixBiasPaddedF8) { ; CHECK-NEXT: [[P2_PADDED:%[^ ]+]] = f32[16,16]{1,0} pad([[P2]], [[C2]]), padding=0_2x0_2 ; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3) ; CHECK-NEXT: [[P4:%[^ ]+]] = f32[] parameter(4) -; CHECK-NEXT: [[C3:%[^ ]+]] = f32[] constant(1) -; CHECK-NEXT: [[DOT_TUPLE:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0_PADDED]], [[P1_TRANSPOSE_PADDED]], [[P2_PADDED]], [[P3]], [[P4]], /*index=5*/[[C3]], [[C3]]), +; CHECK-NEXT: [[DOT_TUPLE:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0_PADDED]], [[P1_TRANSPOSE_PADDED]], [[P2_PADDED]], [[P3]], [[P4]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -6067,8 +6055,9 @@ TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABScaledDF8) { ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2) ; CHECK-NEXT: [[P2_INV:%[^ ]+]] = f32[] divide([[C0]], [[P2]]) ; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1) -; CHECK-PTX-NEXT: [[OUT:%[^ ]+]] = (<>[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2_INV]], [[C1]], [[C1]], /*index=5*/[[C1]]), -; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2_INV]], [[C1]], [[C1]], /*index=5*/[[C1]]), +; CHECK-NEXT: [[C2:%[^ ]+]] = f32[] constant(1) +; CHECK-PTX-NEXT: [[OUT:%[^ ]+]] = (<>[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2_INV]], [[C1]], [[C2]]), +; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2_INV]], [[C1]], [[C2]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -6117,7 +6106,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABScaledF32DF8) { ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2) ; CHECK-NEXT: [[P2_INV:%[^ ]+]] = f32[] divide([[C0]], [[P2]]) ; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1) -; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2_INV]], [[C1]], [[C1]], /*index=5*/[[C1]]), +; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2_INV]], [[C1]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -6164,7 +6153,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABInvScaledF32DF8) { ; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[16,32]{1,0} transpose([[P1]]), dimensions={1,0} ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2) ; CHECK-NEXT: [[C0:%[^ ]+]] = f32[] constant(1) -; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[C0]], [[C0]], /*index=5*/[[C0]]), +; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[C0]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -6215,7 +6204,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABScaledF32DMatrixBiasF8) { ; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[16,32]{1,0} transpose([[P1]]), dimensions={1,0} ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[16,16]{1,0} parameter(2) ; CHECK-NEXT: [[C0:%[^ ]+]] = f32[] constant(1) -; CHECK-NEXT: [[GEMM_TUPLE:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[C0]], [[C0]], /*index=5*/[[C0]], [[C0]]), +; CHECK-NEXT: [[GEMM_TUPLE:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[C0]], [[C0]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -6280,12 +6269,12 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDF8) { ; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[16,32]{1,0} transpose([[P1]]), dimensions={1,0} ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2) ; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3) -; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1) ; CHECK-PTX-NEXT: [[C2:%[^ ]+]] = f32[] constant(1) ; CHECK-PTX-NEXT: [[P4:%[^ ]+]] = f32[] parameter(4) ; CHECK-PTX-NEXT: [[P4_INV:%[^ ]+]] = f32[] divide([[C2]], [[P4]]) -; CHECK-PTX-NEXT: [[OUT:%[^ ]+]] = (<>[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[P4_INV]]), -; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[C1]]), +; CHECK-PTX-NEXT: [[OUT:%[^ ]+]] = (<>[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[P4_INV]]), +; CHECK-GCN-NEXT: [[C1:%[^ ]+]] = f32[] constant(1) +; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -6390,12 +6379,12 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDReluActivationF8) { ; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[16,32]{1,0} transpose([[P1]]), dimensions={1,0} ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2) ; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3) -; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1) ; CHECK-PTX-NEXT: [[C2:%[^ ]+]] = f32[] constant(1) ; CHECK-PTX-NEXT: [[P4:%[^ ]+]] = f32[] parameter(4) ; CHECK-PTX-NEXT: [[P4_INV:%[^ ]+]] = f32[] divide([[C2]], [[P4]]) -; CHECK-PTX-NEXT: [[OUT:%[^ ]+]] = (<>[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[P4_INV]]), -; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[C1]]), +; CHECK-PTX-NEXT: [[OUT:%[^ ]+]] = (<>[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[P4_INV]]), +; CHECK-CGN-NEXT: [[C1:%[^ ]+]] = f32[] constant(1) +; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -6472,11 +6461,10 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDMatrixBiasWithDAmaxF8) { ; CHECK: [[C0:%[^ ]+]] = f16[16,16]{1,0} add({{.*}}) ; CHECK-NEXT: [[P2:%[^ ]+]] = f16[] parameter(3) ; CHECK: [[P3:%[^ ]+]] = f16[] parameter(4) -; CHECK: [[C1:%[^ ]+]] = f32[] constant(1) ; CHECK-PTX: [[P4:%[^ ]+]] = f16[] parameter(5) -; CHECK-PTX: [[OUT:%[^ ]+]] = (<>[16,16]{1,0}, f32[], s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[C0]], [[DUMMY0:%[^ ]+]], [[DUMMY1:%[^ ]+]], /*index=5*/[[C1]], [[DUMMY2:%[^ ]+]]), +; CHECK-PTX: [[OUT:%[^ ]+]] = (<>[16,16]{1,0}, f32[], s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[C0]], [[DUMMY0:%[^ ]+]], [[DUMMY1:%[^ ]+]], /*index=5*/[[DUMMY2:%[^ ]+]]), ; CHECK-NOT: output_to_operand_aliasing -; CHECK-GCN: [[OUT:%[^ ]+]] = (f16[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[C0]], [[DUMMY0:%[^ ]+]], [[DUMMY1:%[^ ]+]], /*index=5*/[[C1]], [[DUMMY2:%[^ ]+]]), +; CHECK-GCN: [[OUT:%[^ ]+]] = (f16[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[C0]], [[DUMMY0:%[^ ]+]], [[DUMMY1:%[^ ]+]], /*index=5*/[[DUMMY2:%[^ ]+]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -6543,14 +6531,14 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDVectorBiasF8) { ; CHECK-NEXT: [[CV:%[^ ]+]] = f32[] convert([[P2]]) ; CHECK-NEXT: [[P3:%[^ ]+]] = f16[] parameter(4) ; CHECK-NEXT: [[CV1:%[^ ]+]] = f32[] convert([[P3]]) -; CHECK-NEXT: [[C:%[^ ]+]] = f32[] constant(1) +; CHECK-NEXT: [[VB:%[^ ]+]] = f16[16]{0} parameter(2) ; CHECK-PTX-NEXT: [[C2:%[^ ]+]] = f16[] constant(1) ; CHECK-PTX-NEXT: [[P4:%[^ ]+]] = f16[] parameter(5) ; CHECK-PTX-NEXT: [[DV:%[^ ]+]] = f16[] divide([[C2]], [[P4]]) ; CHECK-PTX-NEXT: [[CV2:%[^ ]+]] = f32[] convert([[DV]]) -; CHECK-NEXT: [[VB:%[^ ]+]] = f16[16]{0} parameter(2) -; CHECK-PTX: [[OUT:%[^ ]+]] = (<>[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[CV]], [[CV1]], [[C]], /*index=5*/[[CV2]], [[VB]]), -; CHECK-GCN: [[OUT:%[^ ]+]] = (f16[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[CV]], [[CV1]], [[C]], /*index=5*/[[C]], [[VB]]), +; CHECK-PTX: [[OUT:%[^ ]+]] = (<>[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[CV]], [[CV1]], [[VB]], /*index=5*/[[CV2]]), +; CHECK-GCN: [[C:%[^ ]+]] = f32[] constant(1) +; CHECK-GCN: [[OUT:%[^ ]+]] = (f16[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[CV]], [[CV1]], [[C]], /*index=5*/[[C]], [[VB]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -6607,10 +6595,9 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDF32VectorBiasF8) { ; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[16,32]{1,0} transpose([[P1]]), dimensions={1,0} ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(3) ; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(4) -; CHECK-NEXT: [[C:%[^ ]+]] = f32[] constant(1) ; CHECK-NEXT: [[VB:%[^ ]+]] = f32[16]{0} parameter(2) ; CHECK-NEXT: [[VBC:%[^ ]+]] = bf16[16]{0} convert([[VB]]) -; CHECK: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C]], /*index=5*/[[C]], [[VBC]]), +; CHECK: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[VBC]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -6670,9 +6657,8 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ; CHECK-NEXT: [[CV:%[^ ]+]] = f32[] convert([[P2]]) ; CHECK-NEXT: [[P3:%[^ ]+]] = f16[] parameter(4) ; CHECK-NEXT: [[CV1:%[^ ]+]] = f32[] convert([[P3]]) -; CHECK-NEXT: [[C:%[^ ]+]] = f32[] constant(1) ; CHECK-NEXT: [[VB:%[^ ]+]] = f16[16]{0} parameter(2) -; CHECK : ROOT [[OUT:%[^ ]+]] = f16[16,16]{1,0} custom-call([[P0]], [[P1_TRANSPOSE]], [[CV]], [[CV1]], [[C]], /*index=5*/[[C]], [[VB]]), +; CHECK : ROOT [[OUT:%[^ ]+]] = f16[16,16]{1,0} custom-call([[P0]], [[P1_TRANSPOSE]], [[CV]], [[CV1]], [[VB]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -6744,10 +6730,9 @@ TEST_P(ParameterizedFp8GemmRewriteTest, Rank3ScaledABUnscaledDVectorBiasF8) { ; CHECK-NEXT: [[P2_CV:%[^ ]+]] = f32[] convert([[P2]]) ; CHECK-NEXT: [[P3:%[^ ]+]] = f16[] parameter(4) ; CHECK-NEXT: [[P3_CV:%[^ ]+]] = f32[] convert([[P3]]) -; CHECK-NEXT: [[C:%[^ ]+]] = f32[] constant(1) ; CHECK-NEXT: [[B:%[^ ]+]] = f32[32]{0} parameter(2) ; CHECK-NEXT: [[B_F16:%[^ ]+]] = f16[32]{0} convert([[B]]) -; CHECK-NEXT: [[GEMM_TUPLE:%[^ ]+]] = (f16[64,32]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0_BITCAST]], [[P1_TRANSPOSE]], [[P2_CV]], [[P3_CV]], [[C]], /*index=5*/[[C]], [[B_F16]]), +; CHECK-NEXT: [[GEMM_TUPLE:%[^ ]+]] = (f16[64,32]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0_BITCAST]], [[P1_TRANSPOSE]], [[P2_CV]], [[P3_CV]], [[B_F16]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -6828,12 +6813,11 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ; CHECK-NEXT: [[P2_CV:%[^ ]+]] = f32[] convert([[P2]]) ; CHECK-NEXT: [[P3:%[^ ]+]] = f16[] parameter(4) ; CHECK-NEXT: [[P3_CV:%[^ ]+]] = f32[] convert([[P3]]) -; CHECK-NEXT: [[C:%[^ ]+]] = f32[] constant(1) ; CHECK-NEXT: [[B:%[^ ]+]] = f32[31]{0} parameter(2) ; CHECK-NEXT: [[B_F16:%[^ ]+]] = f16[31]{0} convert([[B]]) ; CHECK-NEXT: [[C3:%[^ ]+]] = f16[] constant(0) ; CHECK-NEXT: [[P2_PAD:%[^ ]+]] = f16[32]{0} pad([[B_F16]], [[C3]]), padding=0_1 -; CHECK-NEXT: [[GEMM_TUPLE:%[^ ]+]] = (f16[64,32]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0_PAD]], [[P1_PAD]], [[P2_CV]], [[P3_CV]], [[C]], /*index=5*/[[C]], [[P2_PAD]]), +; CHECK-NEXT: [[GEMM_TUPLE:%[^ ]+]] = (f16[64,32]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0_PAD]], [[P1_PAD]], [[P2_CV]], [[P3_CV]], [[P2_PAD]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -6906,8 +6890,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, Rank3ScaledABUnscaledDMatrixBiasF8) { ; CHECK-NEXT: [[B_BITCAST:%[^ ]+]] = f32[64,32]{1,0} bitcast([[B]]) ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(3) ; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(4) -; CHECK-NEXT: [[C:%[^ ]+]] = f32[] constant(1) -; CHECK-NEXT: [[GEMM_TUPLE:%[^ ]+]] = (f32[64,32]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0_BITCAST]], [[P1_TRANSPOSE]], [[B_BITCAST]], [[P2]], [[P3]], /*index=5*/[[C]], [[C]]), +; CHECK-NEXT: [[GEMM_TUPLE:%[^ ]+]] = (f32[64,32]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0_BITCAST]], [[P1_TRANSPOSE]], [[B_BITCAST]], [[P2]], [[P3]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -6988,8 +6971,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ; CHECK-NEXT: [[P2_PADDED:%[^ ]+]] = f32[48,32]{1,0} pad([[B_BITCAST]], [[C3]]), padding=0_3x0_1 ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(3) ; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(4) -; CHECK-NEXT: [[C:%[^ ]+]] = f32[] constant(1) -; CHECK-NEXT: [[GEMM_TUPLE:%[^ ]+]] = (f32[48,32]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0_PADDED]], [[P1_PADDED]], [[P2_PADDED]], [[P2]], [[P3]], /*index=5*/[[C]], [[C]]), +; CHECK-NEXT: [[GEMM_TUPLE:%[^ ]+]] = (f32[48,32]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0_PADDED]], [[P1_PADDED]], [[P2_PADDED]], [[P2]], [[P3]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -7054,8 +7036,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[32,16]{1,0} transpose([[P1]]), dimensions={1,0} ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(3) ; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(4) -; CHECK-NEXT: [[C:%[^ ]+]] = f32[] constant(1) -; CHECK-NEXT: [[GEMM_TUPLE:%[^ ]+]] = (f32[48,32]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C]], /*index=5*/[[C]]), +; CHECK-NEXT: [[GEMM_TUPLE:%[^ ]+]] = (f32[48,32]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -7117,8 +7098,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDWithAllGatherF8) { ; CHECK: [[P1_TRANSPOSE:%[^ ]+]] = <>[32,64]{1,0} transpose([[AG1]]), dimensions={1,0} ; CHECK: [[P2:%[^ ]+]] = f32[] parameter(2) ; CHECK: [[P3:%[^ ]+]] = f32[] parameter(3) -; CHECK: [[C:%[^ ]+]] = f32[] constant(1) -; CHECK: [[GEMM_TUPLE:%[^ ]+]] = (f32[16,32]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[AG]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C]], /*index=5*/[[C]]), +; CHECK: [[GEMM_TUPLE:%[^ ]+]] = (f32[16,32]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[AG]], [[P1_TRANSPOSE]], [[P2]], [[P3]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -7175,8 +7155,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDWithAllToAllF8) { ; CHECK: [[P1:%[^ ]+]] = <>[16,32]{1,0} parameter(1) ; CHECK: [[P2:%[^ ]+]] = f32[] parameter(2) ; CHECK: [[P3:%[^ ]+]] = f32[] parameter(3) -; CHECK: [[C:%[^ ]+]] = f32[] constant(1) -; CHECK: [[GEMM:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[AA]], [[P1]], [[P2]], [[P3]], [[C]], /*index=5*/[[C]]), +; CHECK: [[GEMM:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[AA]], [[P1]], [[P2]], [[P3]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -7233,8 +7212,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ; CHECK: [[P1:%[^ ]+]] = <>[16,32]{1,0} parameter(1) ; CHECK: [[P2:%[^ ]+]] = f32[] parameter(2) ; CHECK: [[P3:%[^ ]+]] = f32[] parameter(3) -; CHECK: [[C:%[^ ]+]] = f32[] constant(1) -; CHECK: [[GEMM:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[AA]], [[P1]], [[P2]], [[P3]], [[C]], /*index=5*/[[C]]), +; CHECK: [[GEMM:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[AA]], [[P1]], [[P2]], [[P3]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -7296,8 +7274,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ; CHECK-NEXT: [[CV0:%[^ ]+]] = f32[] convert([[P2]]) ; CHECK-NEXT: [[P3:%[^ ]+]] = f16[] parameter(5) ; CHECK-NEXT: [[CV1:%[^ ]+]] = f32[] convert([[P3]]) -; CHECK: [[C1:%[^ ]+]] = f32[] constant(1) -; CHECK: [[GEMMOUT_TUPLE:%[^ ]+]] = (f16[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[MB]], [[CV0]], [[CV1]], /*index=5*/[[C1]], [[C1]]), +; CHECK: [[GEMMOUT_TUPLE:%[^ ]+]] = (f16[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[MB]], [[CV0]], [[CV1]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -7372,12 +7349,12 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDWithDAmaxF8) { ; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[16,32]{1,0} transpose([[P1]]) ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2) ; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3) -; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1) ; CHECK-PTX-NEXT: [[C2:%[^ ]+]] = f32[] constant(1) ; CHECK-PTX-NEXT: [[P4:%[^ ]+]] = f32[] parameter(4) ; CHECK-PTX-NEXT: [[P4_INV:%[^ ]+]] = f32[] divide([[C2]], [[P4]]) -; CHECK-PTX-NEXT: [[OUT:%[^ ]+]] = (<>[16,16]{1,0}, f32[], s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[P4_INV]]), -; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[C1]]), +; CHECK-PTX-NEXT: [[OUT:%[^ ]+]] = (<>[16,16]{1,0}, f32[], s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[P4_INV]]), +; CHECK-GCN-NEXT: [[C1:%[^ ]+]] = f32[] constant(1) +; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -7453,13 +7430,13 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ; CHECK-NEXT: [[P2_CONVERT:%[^ ]+]] = f32[] convert([[P2]]) ; CHECK-NEXT: [[P3:%[^ ]+]] = f16[] parameter(3) ; CHECK-NEXT: [[P3_CONVERT:%[^ ]+]] = f32[] convert([[P3]]) -; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1) ; CHECK-PTX-NEXT: [[C2:%[^ ]+]] = f16[] constant(1) ; CHECK-PTX-NEXT: [[P4:%[^ ]+]] = f16[] parameter(4) ; CHECK-PTX-NEXT: [[P4_INV:%[^ ]+]] = f16[] divide([[C2]], [[P4]]) ; CHECK-PTX-NEXT: [[P4_INV_CONVERT:%[^ ]+]] = f32[] convert([[P4_INV]]) -; CHECK-PTX-NEXT: [[OUT:%[^ ]+]] = (<>[16,16]{1,0}, f32[], s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2_CONVERT]], [[P3_CONVERT]], [[C1]], /*index=5*/[[P4_INV_CONVERT]]), -; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f16[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2_CONVERT]], [[P3_CONVERT]], [[C1]], /*index=5*/[[C1]]), +; CHECK-PTX-NEXT: [[OUT:%[^ ]+]] = (<>[16,16]{1,0}, f32[], s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2_CONVERT]], [[P3_CONVERT]], [[P4_INV_CONVERT]]), +; CHECK-CGN-NEXT: [[C1:%[^ ]+]] = f32[] constant(1) +; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f16[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2_CONVERT]], [[P3_CONVERT]], [[C1]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -7533,12 +7510,12 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[16,32]{1,0} transpose([[P1]]) ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2) ; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3) -; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1) ; CHECK-PTX-NEXT: [[C2:%[^ ]+]] = f32[] constant(1) ; CHECK-PTX-NEXT: [[P4:%[^ ]+]] = f32[] parameter(4) ; CHECK-PTX-NEXT: [[P4_INV:%[^ ]+]] = f32[] divide([[C2]], [[P4]]) -; CHECK-PTX-NEXT: [[OUT:%[^ ]+]] = (<>[16,16]{1,0}, f32[], s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[P4_INV]]), -; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[C1]]), +; CHECK-PTX-NEXT: [[OUT:%[^ ]+]] = (<>[16,16]{1,0}, f32[], s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[P4_INV]]), +; CHECK-CGN-NEXT: [[C1:%[^ ]+]] = f32[] constant(1) +; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_blas_lt.cc b/third_party/xla/xla/stream_executor/cuda/cuda_blas_lt.cc index fbbcfad52fb3ae..c1ba88f3d61bf1 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_blas_lt.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_blas_lt.cc @@ -449,15 +449,12 @@ absl::Status BlasLt::MatmulPlan::DoMatmul( CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, b_scale.opaque())); } - auto isF8Input = [](const auto& desc) { - return desc.type() == CUDA_R_8F_E4M3 || desc.type() == CUDA_R_8F_E5M2; - }; - if (c_scale != nullptr && isF8Input(c_desc_)) { + if (c_scale != nullptr) { TF_RETURN_IF_ERROR(SetAttr(op_desc_.get(), CUBLASLT_MATMUL_DESC_C_SCALE_POINTER, c_scale.opaque())); } - if (d_scale != nullptr && isF8Input(d_desc_)) { + if (d_scale != nullptr) { TF_RETURN_IF_ERROR(SetAttr(op_desc_.get(), CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, d_scale.opaque())); From d56090d0fcebd339bf589e163c5291ff9a108a65 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 27 Sep 2024 21:18:26 -0700 Subject: [PATCH 394/483] Automated Code Change PiperOrigin-RevId: 679837839 --- tensorflow/cc/experimental/libexport/load.cc | 7 +++---- tensorflow/cc/experimental/libexport/load.h | 7 +++---- tensorflow/cc/experimental/libexport/load_test.cc | 2 +- 3 files changed, 7 insertions(+), 9 deletions(-) diff --git a/tensorflow/cc/experimental/libexport/load.cc b/tensorflow/cc/experimental/libexport/load.cc index be9319b066d74d..cec8af507a7fbd 100644 --- a/tensorflow/cc/experimental/libexport/load.cc +++ b/tensorflow/cc/experimental/libexport/load.cc @@ -28,7 +28,7 @@ namespace libexport { using protobuf::RepeatedPtrField; -tensorflow::StatusOr TFPackage::Load(const std::string& path) { +absl::StatusOr TFPackage::Load(const std::string& path) { // Load the proto TFPackage tf_package; const string saved_model_pb_path = io::JoinPath(path, kSavedModelFilenamePb); @@ -83,8 +83,7 @@ tensorflow::StatusOr TFPackage::Load(const std::string& path) { return tf_package; } -tensorflow::StatusOr TFPackage::GetVariableCheckpointKey( - int index) { +absl::StatusOr TFPackage::GetVariableCheckpointKey(int index) { // TODO(danielellis): make sure valid index const auto& trackable_object = trackable_object_graph_.nodes(index); const TrackableObjectGraph::TrackableObject::SerializedTensor* @@ -105,7 +104,7 @@ const SavedObjectGraph& TFPackage::GetObjectGraph() { return saved_model_proto_.mutable_meta_graphs(0)->object_graph_def(); } -tensorflow::StatusOr TFPackage::GetGraphDefNode( +absl::StatusOr TFPackage::GetGraphDefNode( std::string name) { const auto& iter = graph_def_nodes_by_name_.find(name); if (iter == graph_def_nodes_by_name_.end()) { diff --git a/tensorflow/cc/experimental/libexport/load.h b/tensorflow/cc/experimental/libexport/load.h index 8ab5019eba45fe..6775f73b5ab8fb 100644 --- a/tensorflow/cc/experimental/libexport/load.h +++ b/tensorflow/cc/experimental/libexport/load.h @@ -42,7 +42,7 @@ namespace libexport { class TFPackage { public: // Load a SavedModel, parsing the associated protobuf for later access. - static tensorflow::StatusOr Load(const std::string& path); + static absl::StatusOr Load(const std::string& path); // Reads and returns a checkpoint key associated with a variable. // @@ -53,7 +53,7 @@ class TFPackage { // checkpoint files by "checkpoint keys". These keys along with dtype and // shape / slice information allow RestoreV2 to look up a variable's value in // the SavedModel and restore it into a tensor. - tensorflow::StatusOr GetVariableCheckpointKey(int index); + absl::StatusOr GetVariableCheckpointKey(int index); // Retrieves the object graph from the SavedModel. // @@ -74,8 +74,7 @@ class TFPackage { // Since we may need to load many constants, we create a hash map of these // names to their corresponding nodes at load time in order to look them up // in constant time. - tensorflow::StatusOr GetGraphDefNode( - std::string name); + absl::StatusOr GetGraphDefNode(std::string name); // Returns a list of function defs in the SavedModel. const protobuf::RepeatedPtrField& GetFunctionDefs(); diff --git a/tensorflow/cc/experimental/libexport/load_test.cc b/tensorflow/cc/experimental/libexport/load_test.cc index 0b1565be4355fa..a8ad6e211718d7 100644 --- a/tensorflow/cc/experimental/libexport/load_test.cc +++ b/tensorflow/cc/experimental/libexport/load_test.cc @@ -24,7 +24,7 @@ namespace libexport { namespace { TEST(LoadTest, TestDiskSavedModelLoad) { - StatusOr result = TFPackage::Load("test"); + absl::StatusOr result = TFPackage::Load("test"); EXPECT_FALSE(result.status().ok()); } From 55c21108ec48334dc21bb8b50a89068e67c98417 Mon Sep 17 00:00:00 2001 From: Blake Hechtman Date: Fri, 27 Sep 2024 21:34:49 -0700 Subject: [PATCH 395/483] [XLA] Ensure that the operands of rng bit generator are replicated since the spmd partitioner will replicate it anyway. PiperOrigin-RevId: 679841854 --- .../xla/xla/service/sharding_propagation.cc | 4 +++ .../xla/service/sharding_propagation_test.cc | 33 +++++++++++++++++++ 2 files changed, 37 insertions(+) diff --git a/third_party/xla/xla/service/sharding_propagation.cc b/third_party/xla/xla/service/sharding_propagation.cc index bf44306a1cccd0..2c58d92a9d2a07 100644 --- a/third_party/xla/xla/service/sharding_propagation.cc +++ b/third_party/xla/xla/service/sharding_propagation.cc @@ -2714,6 +2714,10 @@ bool ShardingPropagation::InferShardingFromUsers( bool improved_sharding = false; const bool may_combine_partial_sharding = is_spmd && aggressiveness > 0; for (const HloInstruction* user : instruction->users()) { + if (user->opcode() == HloOpcode::kRngBitGenerator) { + instruction->set_sharding(HloSharding::Replicate()); + return true; + } std::optional user_sharding = ShardingPropagation::GetShardingFromUser(*instruction, *user, aggressiveness, is_spmd, diff --git a/third_party/xla/xla/service/sharding_propagation_test.cc b/third_party/xla/xla/service/sharding_propagation_test.cc index 5ca4b47d8ea15c..96303a8d1b6880 100644 --- a/third_party/xla/xla/service/sharding_propagation_test.cc +++ b/third_party/xla/xla/service/sharding_propagation_test.cc @@ -12195,5 +12195,38 @@ ENTRY main { op::Sharding("{devices=[2,1,2]<=[4] last_tile_dim_replicate}")); } +TEST_F(ShardingPropagationTest, ReplicateRngBitGeneratorSeed) { + const char* const hlo_string = R"( +HloModule module +apply_or { + x = u64[] parameter(0) + y = u64[] parameter(1) + ROOT x_or_y = or(x, y) +} +ENTRY main { + p = s32[2,2]{1,0} parameter(0), sharding={devices=[2,2]<=[4]} + up = u64[2,2] convert(p) + i = u64[] constant(0) + seed = u64[2] reduce(up, i), dimensions={1}, to_apply=apply_or + rbg = u32[2048,4096] rng-bit-generator(seed), algorithm=rng_default + ROOT s = u32[2048,4096]{1,0} custom-call(rbg), custom_call_target="Sharding", sharding={devices=[2,2]<=[4]} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, + ShardingPropagation( + /*is_spmd=*/true, /*propagate_metadata=*/true, + /*allow_spmd_sharding_propagation_to_output=*/{true}, + /*allow_spmd_sharding_propagation_to_parameters=*/{true}) + .Run(module.get())); + EXPECT_TRUE(changed); + + XLA_VLOG_LINES(1, module->ToString()); + auto* instruction = FindInstruction(module.get(), "seed"); + // Check sharding is correctly propagated. + EXPECT_TRUE(instruction->sharding().IsReplicated()); +} + } // namespace } // namespace xla From 986463e32adcfbb47e0896819ea28c575bc50fa9 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sat, 28 Sep 2024 02:03:14 -0700 Subject: [PATCH 396/483] Update GraphDef version to 1999. PiperOrigin-RevId: 679903483 --- tensorflow/core/public/version.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index 13944c256422d0..c85290421d343e 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -108,7 +108,7 @@ limitations under the License. #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0 -#define TF_GRAPH_DEF_VERSION 1998 // Updated: 2024/9/27 +#define TF_GRAPH_DEF_VERSION 1999 // Updated: 2024/9/28 // Checkpoint compatibility versions (the versions field in SavedSliceMeta). // From ce012757026a2dd9d36cb465fd288114970a8367 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sat, 28 Sep 2024 02:03:18 -0700 Subject: [PATCH 397/483] compat: Update forward compatibility horizon to 2024-09-28 PiperOrigin-RevId: 679903512 --- tensorflow/python/compat/compat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index b65e9c57c9c560..3e7177c7bbdca4 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -29,7 +29,7 @@ # This value changes every day with an automatic CL. It can be modified in code # via `forward_compatibility_horizon()` or with the environment variable # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date. -_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2024, 9, 27) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2024, 9, 28) _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS" _FORWARD_COMPATIBILITY_DATE_NUMBER = None From 3307e6189a3bec9f3eedf736f473fd3718a140f5 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sat, 28 Sep 2024 02:22:33 -0700 Subject: [PATCH 398/483] Automated Code Change PiperOrigin-RevId: 679908031 --- tensorflow/core/lib/gtl/edit_distance.h | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tensorflow/core/lib/gtl/edit_distance.h b/tensorflow/core/lib/gtl/edit_distance.h index 818ec69fd96fd7..94a5ad687d678a 100644 --- a/tensorflow/core/lib/gtl/edit_distance.h +++ b/tensorflow/core/lib/gtl/edit_distance.h @@ -44,9 +44,8 @@ namespace gtl { // int64 dist = LevenshteinDistance("hi", "bye", std::equal_to()); // template -inline int64_t LevenshteinDistance(const gtl::ArraySlice& s, - const gtl::ArraySlice& t, - const Cmp& cmp) { +inline int64_t LevenshteinDistance(const gtl::ArraySlice s, + const gtl::ArraySlice t, const Cmp& cmp) { const int64_t s_size = s.size(); const int64_t t_size = t.size(); From 50e812ae4161a3aa8bb70d00b6533c82fee39f94 Mon Sep 17 00:00:00 2001 From: Ilya Tikhonovskiy Date: Sat, 28 Sep 2024 09:51:37 -0700 Subject: [PATCH 399/483] [XLA:GPU] Remove xla_gpu_enable_triton_gemm_int4 flag which is on by default. This flag has been enabled by default for a while now, and there is no reason to keep it around. PiperOrigin-RevId: 679992341 --- third_party/xla/xla/debug_options_flags.cc | 10 +++------- .../fusions/triton/triton_fusion_emitter.cc | 18 ------------------ ...triton_fusion_emitter_device_legacy_test.cc | 1 - .../fusions/triton/triton_support_legacy.cc | 9 +-------- .../service/gpu/transforms/gemm_fusion_test.cc | 12 ------------ third_party/xla/xla/xla.proto | 6 ++---- 6 files changed, 6 insertions(+), 50 deletions(-) diff --git a/third_party/xla/xla/debug_options_flags.cc b/third_party/xla/xla/debug_options_flags.cc index abedfc370dd83f..267029c0630674 100644 --- a/third_party/xla/xla/debug_options_flags.cc +++ b/third_party/xla/xla/debug_options_flags.cc @@ -285,8 +285,6 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_gpu_cudnn_gemm_max_plans(5); - opts.set_xla_gpu_enable_triton_gemm_int4(true); - opts.set_xla_gpu_enable_pgle_accuracy_checker(false); opts.set_xla_gpu_executable_warn_stuck_timeout_seconds(10); @@ -1923,11 +1921,9 @@ void MakeDebugOptionsFlags(std::vector* flag_list, "Limit for the number of kernel configurations (plans) to use during " "autotuning of cuDNN GEMM fusions.")); - flag_list->push_back(tsl::Flag( - "xla_gpu_enable_triton_gemm_int4", - bool_setter_for(&DebugOptions::set_xla_gpu_enable_triton_gemm_int4), - debug_options->xla_gpu_enable_triton_gemm_int4(), - "Experimental: Enable Triton gemm for int4 inputs.")); + flag_list->push_back(tsl::Flag("xla_gpu_enable_triton_gemm_int4", + noop_flag_setter, true, + "[Deprecated, do not use]")); flag_list->push_back( tsl::Flag("xla_gpu_async_dot", bool_setter_for(&DebugOptions::set_xla_gpu_async_dot), diff --git a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter.cc b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter.cc index a105789fa86f01..ccaad403bd954d 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter.cc +++ b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter.cc @@ -1216,14 +1216,6 @@ absl::StatusOr EmitScope( Value result; if (hlo->opcode() == HloOpcode::kConvert && hlo->operand(0)->shape().element_type() == S4) { - if (!hlo->GetModule() - ->config() - .debug_options() - .xla_gpu_enable_triton_gemm_int4()) { - return absl::UnimplementedError( - "Int4 support is not enabled in the debug options."); - } - TF_ASSIGN_OR_RETURN( auto unpacked, EmitUnpackInt4(b, hlo, side, values[hlo->operand(0)])); std::vector operands({unpacked}); @@ -3058,15 +3050,6 @@ absl::Status CreateInternalError(std::string_view message, return absl::InternalError(err); } -absl::Status DoSupportType(const DebugOptions& debug_options, - PrimitiveType type) { - if (type == S4 && !debug_options.xla_gpu_enable_triton_gemm_int4()) { - return absl::FailedPreconditionError( - "Int4 support is not enabled in the debug options."); - } - return absl::OkStatus(); -} - absl::StatusOr> CreateTritonModule( absl::string_view fn_name, const HloFusionInstruction* fusion, const se::DeviceDescription& device_info, @@ -3088,7 +3071,6 @@ absl::StatusOr> CreateTritonModule( SmallVector fn_arg_types; for (HloInstruction* p : hlo_computation->parameter_instructions()) { PrimitiveType type = p->shape().element_type(); - TF_RETURN_IF_ERROR(DoSupportType(debug_options, type)); Type ir_type; if (type == U16) { ir_type = b.getI16Type(); diff --git a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_legacy_test.cc b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_legacy_test.cc index 2d5c4232038891..036e3221c71bcb 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_legacy_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_legacy_test.cc @@ -119,7 +119,6 @@ class TritonGemmTest : public TritonTest { debug_options.set_xla_gpu_enable_split_k_autotuning(false); // Always rewrite Gemms with Triton regardless of size. debug_options.set_xla_gpu_gemm_rewrite_size_threshold(0); - debug_options.set_xla_gpu_enable_triton_gemm_int4(true); return debug_options; } diff --git a/third_party/xla/xla/service/gpu/fusions/triton/triton_support_legacy.cc b/third_party/xla/xla/service/gpu/fusions/triton/triton_support_legacy.cc index 802fed51f4d200..97c4891441d98c 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/triton_support_legacy.cc +++ b/third_party/xla/xla/service/gpu/fusions/triton/triton_support_legacy.cc @@ -122,14 +122,7 @@ CodegenDecision IsInstructionSupportsDataTypes( const auto operand_type = operand->shape().element_type(); switch (instr.opcode()) { case HloOpcode::kConvert: - // TODO(b/358580281): remove DebugOptions from this function after - // enabling int4 in Triton GEMM. - if (operand_type == S4 && instr.GetModule() - ->config() - .debug_options() - .xla_gpu_enable_triton_gemm_int4()) { - continue; - } + if (operand_type == S4) continue; [[fallthrough]]; default: if (!IsTritonSupportedDataType(operand_type, gpu_version)) { diff --git a/third_party/xla/xla/service/gpu/transforms/gemm_fusion_test.cc b/third_party/xla/xla/service/gpu/transforms/gemm_fusion_test.cc index af97932ddf3a89..54985658367d4d 100644 --- a/third_party/xla/xla/service/gpu/transforms/gemm_fusion_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/gemm_fusion_test.cc @@ -1364,9 +1364,6 @@ TEST_F(SmallDotGemmFusionTest, Int4DotIsRewritten) { )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(kInt4Dot)); - module->mutable_config() - .mutable_debug_options() - .set_xla_gpu_enable_triton_gemm_int4(true); EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value()); } @@ -1384,9 +1381,6 @@ TEST_F(SmallDotGemmFusionTest, Int4ConcatPlusConvertIsRewritten) { )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(kInt4Dot)); - module->mutable_config() - .mutable_debug_options() - .set_xla_gpu_enable_triton_gemm_int4(true); EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value()); // Check that the fusion is present and that the lhs is not converted. @@ -1411,9 +1405,6 @@ TEST_F(SmallDotGemmFusionTest, Int4ConvertPlusNegateIsRewritten) { )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(kInt4Dot)); - module->mutable_config() - .mutable_debug_options() - .set_xla_gpu_enable_triton_gemm_int4(true); EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value()); // Check that the fusion is present and that convert and negation is fused in // it. @@ -1440,9 +1431,6 @@ TEST_F(SmallDotGemmFusionTest, Int4WithMinorBatchDimIsNotRewritten) { )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(kInt4Dot)); - module->mutable_config() - .mutable_debug_options() - .set_xla_gpu_enable_triton_gemm_int4(true); TF_ASSERT_OK_AND_ASSIGN(auto result, GemmFusion(gpu_version_).Run(module.get())); EXPECT_FALSE(result); diff --git a/third_party/xla/xla/xla.proto b/third_party/xla/xla/xla.proto index c8900a65753b32..f4ef137fce057e 100644 --- a/third_party/xla/xla/xla.proto +++ b/third_party/xla/xla/xla.proto @@ -949,9 +949,6 @@ message DebugOptions { // If enabled, uses the libnvjitlink library for PTX compilation and linking bool xla_gpu_enable_libnvjitlink = 319; - // If enabled, generates triton gemm kernels for int4 inputs. - bool xla_gpu_enable_triton_gemm_int4 = 320; - // If true, XLA will wrap `dot` operations into async computations in an // effort to parallelize matrix operations. bool xla_gpu_async_dot = 321; @@ -1005,7 +1002,8 @@ message DebugOptions { // xla_gpu_graph_level // xla_gpu_single_wave_autotuning // xla_gpu_enable_persistent_temp_buffers - reserved 5, 117, 133, 139, 176, 178, 180, 193, 214, 194, 242, 206; + // xla_gpu_enable_triton_gemm_int4 + reserved 5, 117, 133, 139, 176, 178, 180, 193, 214, 194, 242, 206, 320; } // Contains flags which affects the GPU compilation result. From 99c6339c36ca9ef4940e5eb2771867d4853f695a Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sat, 28 Sep 2024 19:14:12 -0700 Subject: [PATCH 400/483] Updates the solver to accept feasible -- albeit suboptimal -- solutions. PiperOrigin-RevId: 680092346 --- .../hlo/experimental/auto_sharding/auto_sharding_solver.cc | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc index cf18e7a5c56c6e..114cca321a0509 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc @@ -875,11 +875,14 @@ AutoShardingSolverResult SolveAndExtractSolution( } else if (status == operations_research::MPSolver::MODEL_INVALID) { LOG(FATAL) << "Solver says that the input MIP is invalid. This is most " "likely a bug and should be reported."; - return AutoShardingSolverResult(absl::InternalError("Solver timed out."), + return AutoShardingSolverResult(absl::InternalError("Invalid MIP."), /*skip_auto_sharding=*/false); - } else if (status != operations_research::MPSolver::OPTIMAL) { + } else if (status == operations_research::MPSolver::NOT_SOLVED) { + LOG(WARNING) << "Solver timeout; no solution was produced"; return AutoShardingSolverResult(absl::InternalError("Solver timed out."), /*skip_auto_sharding=*/true); + } else if (status != operations_research::MPSolver::OPTIMAL) { + LOG(WARNING) << "Solver timeout; moving forward with a suboptimal solution"; } // Fingerprint the model & solution (useful when checking for determinism). From 603c0169401b5b0adaac04167949773b3f652534 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sun, 29 Sep 2024 02:02:40 -0700 Subject: [PATCH 401/483] Update GraphDef version to 2000. PiperOrigin-RevId: 680183460 --- tensorflow/core/public/version.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index c85290421d343e..cea3f2c3853654 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -108,7 +108,7 @@ limitations under the License. #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0 -#define TF_GRAPH_DEF_VERSION 1999 // Updated: 2024/9/28 +#define TF_GRAPH_DEF_VERSION 2000 // Updated: 2024/9/29 // Checkpoint compatibility versions (the versions field in SavedSliceMeta). // From 6e86a6956240077dff42675f93f15eb8e2f8e391 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sun, 29 Sep 2024 02:02:54 -0700 Subject: [PATCH 402/483] compat: Update forward compatibility horizon to 2024-09-29 PiperOrigin-RevId: 680183534 --- tensorflow/python/compat/compat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index 3e7177c7bbdca4..80b1f00456289f 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -29,7 +29,7 @@ # This value changes every day with an automatic CL. It can be modified in code # via `forward_compatibility_horizon()` or with the environment variable # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date. -_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2024, 9, 28) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2024, 9, 29) _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS" _FORWARD_COMPATIBILITY_DATE_NUMBER = None From 3dd6917640a2e519dc0d300adfcd3ca287457e8d Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sun, 29 Sep 2024 06:36:32 -0700 Subject: [PATCH 403/483] Automated Code Change PiperOrigin-RevId: 680233139 --- tensorflow/compiler/mlir/tfr/resources/composite_ops.cc | 3 --- 1 file changed, 3 deletions(-) diff --git a/tensorflow/compiler/mlir/tfr/resources/composite_ops.cc b/tensorflow/compiler/mlir/tfr/resources/composite_ops.cc index 8120625bc89e27..3523f295ee8291 100644 --- a/tensorflow/compiler/mlir/tfr/resources/composite_ops.cc +++ b/tensorflow/compiler/mlir/tfr/resources/composite_ops.cc @@ -13,10 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/framework/common_shape_fns.h" -#include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/framework/op.h" -#include "tensorflow/core/framework/shape_inference.h" namespace tensorflow { From df40eac4714033ad3574597125c75549b65d8bc8 Mon Sep 17 00:00:00 2001 From: pizzud Date: Sun, 29 Sep 2024 16:42:41 -0700 Subject: [PATCH 404/483] gemm_fusion_autotuner_test: Properly delete the verified module. `std::unique_ptr::release` doesn't destroy the held contents, it just gives up on them. To actually destroy them, we need to assign the `unique_ptr` itself to to `nullptr`. PiperOrigin-RevId: 680342745 --- .../xla/service/gpu/autotuning/gemm_fusion_autotuner_test.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner_test.cc b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner_test.cc index bb70c963a9c450..8a43da38f52e4a 100644 --- a/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner_test.cc +++ b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner_test.cc @@ -105,7 +105,7 @@ ENTRY entry { // Destroy the original module to be sure that the extracted one has no // dependency on it. - module.release(); + module = nullptr; EXPECT_THAT(extracted_module->entry_computation()->root_instruction(), GmockMatch(m::Fusion(m::Parameter(), m::Parameter()))); @@ -145,7 +145,7 @@ ENTRY entry { // Destroy the original module to be sure that the extracted one has no // dependency on it. - module.release(); + module = nullptr; EXPECT_THAT(extracted_module->entry_computation()->root_instruction(), GmockMatch(m::Dot(m::Convert(m::Parameter()), m::Parameter()))); From a9a64784cba6da221692ab4d0b7ba7bd3b0c4f8c Mon Sep 17 00:00:00 2001 From: pizzud Date: Sun, 29 Sep 2024 16:43:09 -0700 Subject: [PATCH 405/483] hlo_runner_pjrt: Have PjRtWrappedExecutable own the underlying executable. Avoids a memory leak in various unit tests currently masked by not actually looking for leaks. PiperOrigin-RevId: 680342833 --- third_party/xla/xla/service/hlo_runner_pjrt.cc | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/third_party/xla/xla/service/hlo_runner_pjrt.cc b/third_party/xla/xla/service/hlo_runner_pjrt.cc index 3965bf61870f3a..0047ccd4288193 100644 --- a/third_party/xla/xla/service/hlo_runner_pjrt.cc +++ b/third_party/xla/xla/service/hlo_runner_pjrt.cc @@ -114,6 +114,7 @@ absl::StatusOr GenerateExecuteOptions(const HloModule& module) { // TODO(b/245550554): Remove the use of PjRtWrappedExecutable. class PjRtWrappedExecutable : public Executable { public: + // Takes ownership of the provided executable. explicit PjRtWrappedExecutable(std::shared_ptr hlo_module, PjRtLoadedExecutable* pjrt_loaded_executable) : Executable(hlo_module), @@ -125,11 +126,11 @@ class PjRtWrappedExecutable : public Executable { HloExecutionProfile* hlo_execution_profile) override; PjRtLoadedExecutable* GetPjRtLoadedExecutable() const { - return pjrt_loaded_executable_; + return pjrt_loaded_executable_.get(); } private: - PjRtLoadedExecutable* pjrt_loaded_executable_; + std::unique_ptr pjrt_loaded_executable_; }; absl::StatusOr PjRtWrappedExecutable::ExecuteAsyncOnStream( @@ -373,9 +374,7 @@ absl::StatusOr> HloRunnerPjRt::CreateExecutable( std::move(pjrt_executable->GetHloModules().value()[0])), pjrt_executable.release()); - std::unique_ptr exec = - static_cast>(executable.release()); - return exec; + return executable; } absl::StatusOr> HloRunnerPjRt::ExecuteReplicated( From 62fb230776055f1a3ce3148844c9937d814c5f3d Mon Sep 17 00:00:00 2001 From: Eric Salo Date: Sun, 29 Sep 2024 19:27:14 -0700 Subject: [PATCH 406/483] cleanup: remove api_version from BUILD files PiperOrigin-RevId: 680374028 --- tensorflow/dtensor/proto/BUILD | 1 - 1 file changed, 1 deletion(-) diff --git a/tensorflow/dtensor/proto/BUILD b/tensorflow/dtensor/proto/BUILD index da9cd8002c8d87..8694130645b933 100644 --- a/tensorflow/dtensor/proto/BUILD +++ b/tensorflow/dtensor/proto/BUILD @@ -27,7 +27,6 @@ alias( # copybara:uncomment_begin(google-only) # py_proto_library( # name = "layout_proto_py_pb2", -# api_version = 2, # deps = [":layout_proto"], # ) # copybara:uncomment_end From b715cadc779763333a76a07448ecaed351b3da95 Mon Sep 17 00:00:00 2001 From: Toli Yevtushenko Date: Sun, 29 Sep 2024 20:26:22 -0700 Subject: [PATCH 407/483] Add explicit includes to fix Kokoro compile issues. PiperOrigin-RevId: 680391929 --- third_party/xla/xla/hlo/utils/BUILD | 7 ++-- .../xla/xla/hlo/utils/hlo_live_range_test.cc | 7 ++++ third_party/xla/xla/service/BUILD | 36 ++++++++++++++++++- .../xla/xla/service/buffer_assignment_test.cc | 17 ++++++--- .../xla/service/collective_pipeliner_test.cc | 3 ++ third_party/xla/xla/service/defuser_test.cc | 8 +++++ .../xla/service/hlo_alias_analysis_test.cc | 18 +++++++--- .../xla/service/hlo_dfs_reachability_test.cc | 11 ++++-- .../xla/xla/service/hlo_reachability_test.cc | 7 ++++ .../hlo_rematerialization_test_utils.h | 3 ++ .../hlo_rematerialization_test_utils_test.cc | 2 ++ third_party/xla/xla/tests/BUILD | 24 +++++++++++-- third_party/xla/xla/tests/hlo_test_base.cc | 14 ++++++++ third_party/xla/xla/tests/hlo_test_base.h | 35 +++++++++--------- 14 files changed, 158 insertions(+), 34 deletions(-) diff --git a/third_party/xla/xla/hlo/utils/BUILD b/third_party/xla/xla/hlo/utils/BUILD index 814f54fa613550..fc3b0ed42a8bd1 100644 --- a/third_party/xla/xla/hlo/utils/BUILD +++ b/third_party/xla/xla/hlo/utils/BUILD @@ -46,16 +46,17 @@ xla_cc_test( srcs = ["hlo_live_range_test.cc"], deps = [ ":hlo_live_range", - "//xla:literal", - "//xla:status_macros", + "//xla:comparison_util", + "//xla:literal_util", + "//xla:shape_util", "//xla/hlo/ir:hlo", "//xla/service:hlo_alias_analysis", - "//xla/service:hlo_ordering", "//xla/service:hlo_value", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_googletest//:gtest_main", "@local_tsl//tsl/platform:statusor", ], ) diff --git a/third_party/xla/xla/hlo/utils/hlo_live_range_test.cc b/third_party/xla/xla/hlo/utils/hlo_live_range_test.cc index 64e4ab5ee37d62..eba6cb81abc4bc 100644 --- a/third_party/xla/xla/hlo/utils/hlo_live_range_test.cc +++ b/third_party/xla/xla/hlo/utils/hlo_live_range_test.cc @@ -14,19 +14,26 @@ limitations under the License. ==============================================================================*/ #include "xla/hlo/utils/hlo_live_range.h" +#include #include #include #include #include #include +#include #include "absl/container/flat_hash_map.h" +#include "xla/comparison_util.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_schedule.h" +#include "xla/literal_util.h" #include "xla/service/hlo_alias_analysis.h" #include "xla/service/hlo_value.h" +#include "xla/shape.h" +#include "xla/shape_util.h" #include "xla/tests/hlo_test_base.h" #include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD index 35d58653848b72..3b4784da8c133f 100644 --- a/third_party/xla/xla/service/BUILD +++ b/third_party/xla/xla/service/BUILD @@ -630,8 +630,11 @@ xla_cc_test( srcs = ["collective_pipeliner_test.cc"], deps = [ ":collective_pipeliner", + ":hlo_module_config", ":hlo_parser", + ":hlo_verifier", ":host_memory_offload_annotations_hdr", + "//xla:test_helpers", "//xla:util", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass_pipeline", @@ -1062,12 +1065,16 @@ xla_cc_test( name = "hlo_dfs_reachability_test", srcs = ["hlo_dfs_reachability_test.cc"], deps = [ + ":computation_placer_hdr", + ":hlo_module_config", + "//xla:literal_util", + "//xla:shape_util", "//xla:test", "//xla/hlo/ir:hlo", "//xla/hlo/ir:hlo_dfs_reachability", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", - "@com_google_absl//absl/random", + "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:test_benchmark", ], ) @@ -1077,6 +1084,9 @@ xla_cc_test( srcs = ["hlo_reachability_test.cc"], deps = [ ":computation_placer", + ":hlo_module_config", + "//xla:literal_util", + "//xla:shape_util", "//xla:test", "//xla:test_helpers", "//xla/hlo/ir:hlo", @@ -1084,6 +1094,7 @@ xla_cc_test( "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", "@com_google_absl//absl/random", + "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:test_benchmark", ], ) @@ -2022,24 +2033,36 @@ xla_cc_test( ":cpu_plugin", ":flatten_call_graph", ":hlo_alias_analysis", + ":hlo_buffer", ":hlo_dce", ":hlo_memory_scheduler", ":hlo_ordering", ":hlo_parser", ":hlo_proto_cc", ":hlo_proto_util", + ":hlo_value", + ":logical_buffer", + "//xla:comparison_util", "//xla:literal", + "//xla:literal_util", "//xla:shape_util", "//xla:test", "//xla:test_helpers", "//xla:types", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", + "//xla/service/memory_space_assignment", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", "//xla/tsl/lib/core:status_test_util", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:statusor", ], ) @@ -3887,10 +3910,13 @@ xla_cc_test( deps = [ ":defuser", "//xla:literal", + "//xla:literal_util", "//xla:shape_util", + "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "@com_google_googletest//:gtest_main", ], ) @@ -5029,10 +5055,13 @@ xla_cc_test( deps = [ ":flatten_call_graph", ":hlo_alias_analysis", + ":hlo_buffer", ":hlo_graph_dumper", ":hlo_ordering", + ":hlo_value", ":instruction_fusion", "//xla:literal", + "//xla:literal_util", "//xla:shape_util", "//xla:test", "//xla:test_helpers", @@ -5042,7 +5071,10 @@ xla_cc_test( "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", "//xla/tsl/lib/core:status_test_util", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/strings:string_view", "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", ], ) @@ -5486,6 +5518,7 @@ cc_library( testonly = 1, hdrs = ["hlo_rematerialization_test_utils.h"], deps = [ + "//xla:literal_util", "//xla:shape_util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", @@ -5501,6 +5534,7 @@ xla_cc_test( ":hlo_rematerialization_test_utils", "//xla/hlo/ir:hlo", "//xla/tests:xla_internal_test_main", + "@com_google_googletest//:gtest_main", ], ) diff --git a/third_party/xla/xla/service/buffer_assignment_test.cc b/third_party/xla/xla/service/buffer_assignment_test.cc index 04238c4fd39f5a..353c4c6bac1c19 100644 --- a/third_party/xla/xla/service/buffer_assignment_test.cc +++ b/third_party/xla/xla/service/buffer_assignment_test.cc @@ -15,15 +15,21 @@ limitations under the License. #include "xla/service/buffer_assignment.h" +#include #include #include -#include #include #include #include +#include "absl/algorithm/container.h" #include "absl/container/flat_hash_set.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "xla/comparison_util.h" #include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -31,24 +37,27 @@ limitations under the License. #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_schedule.h" #include "xla/literal.h" +#include "xla/literal_util.h" #include "xla/service/buffer_value.h" #include "xla/service/call_graph.h" #include "xla/service/copy_insertion.h" #include "xla/service/flatten_call_graph.h" #include "xla/service/hlo.pb.h" #include "xla/service/hlo_alias_analysis.h" -#include "xla/service/hlo_dce.h" +#include "xla/service/hlo_buffer.h" #include "xla/service/hlo_memory_scheduler.h" #include "xla/service/hlo_ordering.h" #include "xla/service/hlo_parser.h" -#include "xla/service/hlo_proto_util.h" +#include "xla/service/hlo_value.h" +#include "xla/service/logical_buffer.h" +#include "xla/service/memory_space_assignment/memory_space_assignment.h" #include "xla/shape_util.h" #include "xla/test.h" #include "xla/test_helpers.h" #include "xla/tests/hlo_test_base.h" #include "xla/tsl/lib/core/status_test_util.h" -#include "xla/types.h" #include "xla/xla_data.pb.h" +#include "tsl/platform/status.h" #include "tsl/platform/statusor.h" namespace xla { diff --git a/third_party/xla/xla/service/collective_pipeliner_test.cc b/third_party/xla/xla/service/collective_pipeliner_test.cc index 6c48208f65d8fd..9b3f9c9a54233a 100644 --- a/third_party/xla/xla/service/collective_pipeliner_test.cc +++ b/third_party/xla/xla/service/collective_pipeliner_test.cc @@ -38,8 +38,11 @@ limitations under the License. #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/pass/hlo_pass_pipeline.h" #include "xla/hlo/utils/hlo_matchers.h" +#include "xla/service/hlo_module_config.h" #include "xla/service/hlo_parser.h" +#include "xla/service/hlo_verifier.h" #include "xla/service/host_memory_offload_annotations.h" +#include "xla/test_helpers.h" #include "xla/tests/filecheck.h" #include "xla/tests/hlo_test_base.h" #include "xla/util.h" diff --git a/third_party/xla/xla/service/defuser_test.cc b/third_party/xla/xla/service/defuser_test.cc index ad70f7998c66a4..d2feaaf172f237 100644 --- a/third_party/xla/xla/service/defuser_test.cc +++ b/third_party/xla/xla/service/defuser_test.cc @@ -15,8 +15,16 @@ limitations under the License. #include "xla/service/defuser.h" +#include +#include +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/utils/hlo_matchers.h" #include "xla/literal.h" +#include "xla/literal_util.h" +#include "xla/shape.h" #include "xla/shape_util.h" #include "xla/tests/hlo_test_base.h" diff --git a/third_party/xla/xla/service/hlo_alias_analysis_test.cc b/third_party/xla/xla/service/hlo_alias_analysis_test.cc index ea687b640a58ef..fe126db410bda3 100644 --- a/third_party/xla/xla/service/hlo_alias_analysis_test.cc +++ b/third_party/xla/xla/service/hlo_alias_analysis_test.cc @@ -15,15 +15,22 @@ limitations under the License. #include "xla/service/hlo_alias_analysis.h" -#include #include - +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" -#include "xla/hlo/utils/hlo_matchers.h" -#include "xla/literal.h" +#include "xla/literal_util.h" #include "xla/service/flatten_call_graph.h" +#include "xla/service/hlo_buffer.h" #include "xla/service/hlo_ordering.h" -#include "xla/service/instruction_fusion.h" +#include "xla/service/hlo_value.h" +#include "xla/shape.h" #include "xla/shape_util.h" #include "xla/test.h" #include "xla/test_helpers.h" @@ -31,6 +38,7 @@ limitations under the License. #include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla_data.pb.h" #include "tsl/platform/logging.h" +#include "tsl/platform/statusor.h" #include "tsl/platform/test.h" namespace xla { diff --git a/third_party/xla/xla/service/hlo_dfs_reachability_test.cc b/third_party/xla/xla/service/hlo_dfs_reachability_test.cc index 9bc77c75f42ea1..0b596b080e5a8e 100644 --- a/third_party/xla/xla/service/hlo_dfs_reachability_test.cc +++ b/third_party/xla/xla/service/hlo_dfs_reachability_test.cc @@ -16,14 +16,21 @@ limitations under the License. #include "xla/hlo/ir/hlo_dfs_reachability.h" #include -#include +#include +#include #include -#include "absl/random/random.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/literal_util.h" +#include "xla/service/computation_placer.h" +#include "xla/service/hlo_module_config.h" +#include "xla/shape.h" +#include "xla/shape_util.h" #include "xla/test.h" #include "xla/tests/hlo_test_base.h" +#include "tsl/platform/status.h" #include "tsl/platform/test_benchmark.h" namespace xla { diff --git a/third_party/xla/xla/service/hlo_reachability_test.cc b/third_party/xla/xla/service/hlo_reachability_test.cc index bc0d2b7293b47d..42d9ec789a5645 100644 --- a/third_party/xla/xla/service/hlo_reachability_test.cc +++ b/third_party/xla/xla/service/hlo_reachability_test.cc @@ -17,14 +17,21 @@ limitations under the License. #include #include +#include #include #include "absl/random/random.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/literal_util.h" #include "xla/service/computation_placer.h" +#include "xla/service/hlo_module_config.h" +#include "xla/shape.h" +#include "xla/shape_util.h" #include "xla/test.h" #include "xla/test_helpers.h" #include "xla/tests/hlo_test_base.h" +#include "tsl/platform/status.h" #include "tsl/platform/test_benchmark.h" namespace xla { diff --git a/third_party/xla/xla/service/hlo_rematerialization_test_utils.h b/third_party/xla/xla/service/hlo_rematerialization_test_utils.h index 069494536f2637..37bece3354b888 100644 --- a/third_party/xla/xla/service/hlo_rematerialization_test_utils.h +++ b/third_party/xla/xla/service/hlo_rematerialization_test_utils.h @@ -18,12 +18,15 @@ limitations under the License. #ifndef XLA_SERVICE_HLO_REMATERIALIZATION_TEST_UTILS_H_ #define XLA_SERVICE_HLO_REMATERIALIZATION_TEST_UTILS_H_ +#include #include #include #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/literal_util.h" +#include "xla/shape.h" #include "xla/shape_util.h" #include "xla/tests/hlo_test_base.h" #include "xla/xla_data.pb.h" diff --git a/third_party/xla/xla/service/hlo_rematerialization_test_utils_test.cc b/third_party/xla/xla/service/hlo_rematerialization_test_utils_test.cc index 803a0704fde839..7ab66fa50892ba 100644 --- a/third_party/xla/xla/service/hlo_rematerialization_test_utils_test.cc +++ b/third_party/xla/xla/service/hlo_rematerialization_test_utils_test.cc @@ -17,6 +17,8 @@ limitations under the License. #include #include +#include +#include #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" diff --git a/third_party/xla/xla/tests/BUILD b/third_party/xla/xla/tests/BUILD index f6f9a729bcbc62..3da7cbf0818787 100644 --- a/third_party/xla/xla/tests/BUILD +++ b/third_party/xla/xla/tests/BUILD @@ -183,7 +183,7 @@ cc_library( "//xla:literal", "//xla:shape_layout", "//xla:shape_util", - "//xla:types", + "//xla:util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/ir:hlo_module_group", @@ -191,9 +191,10 @@ cc_library( "//xla/hlo/utils:hlo_query", "//xla/service:backend", "//xla/service:computation_layout", + "//xla/service:computation_placer_hdr", + "//xla/service:executable", "//xla/service:hlo_module_config", "//xla/service:hlo_module_util", - "//xla/service:hlo_parser", "//xla/service:hlo_runner", "//xla/service:hlo_runner_interface", "//xla/service:hlo_runner_pjrt", @@ -214,6 +215,7 @@ cc_library( "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", ], @@ -936,17 +938,35 @@ xla_test( ":xla_internal_test_main", "//xla:array2d", "//xla:array3d", + "//xla:array4d", "//xla:error_spec", + "//xla:executable_run_options", + "//xla:literal", + "//xla:literal_util", "//xla:reference_util", "//xla:shape_util", + "//xla:test_helpers", + "//xla:types", + "//xla/client:client_library", + "//xla/client:executable_build_options", "//xla/client:local_client", "//xla/client:xla_builder", "//xla/client/lib:arithmetic", "//xla/client/lib:matrix", + "//xla/service", "//xla/service:hlo_parser", + "//xla/service:platform_util", + "//xla/service:shaped_buffer", + "//xla/stream_executor:device_description", + "//xla/stream_executor:platform", "//xla/stream_executor:stream_executor_memory_allocator", + "@com_google_absl//absl/log", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest_main", + "@eigen_archive//:eigen3", "@local_tsl//tsl/platform:ml_dtypes", + "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_benchmark", ] + if_rocm_is_configured([ diff --git a/third_party/xla/xla/tests/hlo_test_base.cc b/third_party/xla/xla/tests/hlo_test_base.cc index 28985fd2ba33b7..ee8d4653b97089 100644 --- a/third_party/xla/xla/tests/hlo_test_base.cc +++ b/third_party/xla/xla/tests/hlo_test_base.cc @@ -15,6 +15,7 @@ limitations under the License. #include "xla/tests/hlo_test_base.h" +#include #include #include #include @@ -28,20 +29,31 @@ limitations under the License. #include "absl/log/check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/str_join.h" #include "absl/strings/str_replace.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/debug_options_flags.h" +#include "xla/error_spec.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module_group.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/pass/hlo_pass_interface.h" #include "xla/hlo/utils/hlo_query.h" +#include "xla/literal.h" +#include "xla/service/backend.h" +#include "xla/service/computation_placer.h" +#include "xla/service/executable.h" +#include "xla/service/hlo_module_config.h" #include "xla/service/hlo_module_util.h" +#include "xla/service/hlo_runner.h" #include "xla/service/hlo_runner_interface.h" #include "xla/service/hlo_runner_pjrt.h" +#include "xla/service/hlo_verifier.h" #include "xla/service/platform_util.h" #include "xla/shape.h" #include "xla/stream_executor/device_memory_allocator.h" +#include "xla/stream_executor/platform.h" #include "xla/stream_executor/stream_executor_memory_allocator.h" #include "xla/tests/filecheck.h" #include "xla/tests/literal_test_util.h" @@ -49,8 +61,10 @@ limitations under the License. #include "xla/tests/test_utils.h" #include "xla/tests/verified_hlo_module.h" #include "xla/tsl/lib/core/status_test_util.h" +#include "xla/util.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" +#include "tsl/platform/status.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" diff --git a/third_party/xla/xla/tests/hlo_test_base.h b/third_party/xla/xla/tests/hlo_test_base.h index c312fc35090d7c..bd0a628bac45f7 100644 --- a/third_party/xla/xla/tests/hlo_test_base.h +++ b/third_party/xla/xla/tests/hlo_test_base.h @@ -37,19 +37,23 @@ limitations under the License. #include "xla/hlo/ir/hlo_module_group.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/pass/hlo_pass_interface.h" +#include "xla/layout.h" #include "xla/literal.h" #include "xla/service/backend.h" #include "xla/service/computation_layout.h" +#include "xla/service/computation_placer.h" +#include "xla/service/executable.h" #include "xla/service/hlo_module_config.h" #include "xla/service/hlo_runner.h" +#include "xla/service/hlo_runner_interface.h" #include "xla/service/hlo_verifier.h" #include "xla/service/platform_util.h" #include "xla/shape_layout.h" +#include "xla/shape_util.h" #include "xla/stream_executor/device_memory_allocator.h" -#include "xla/stream_executor/stream_executor.h" #include "xla/tests/literal_test_util.h" #include "xla/tests/verified_hlo_module.h" -#include "xla/types.h" +#include "xla/util.h" #include "xla/xla_data.pb.h" #include "tsl/platform/test.h" @@ -147,7 +151,7 @@ class HloTestBase : public ::testing::Test { void MatchOptimizedHlo(absl::string_view hlo, absl::string_view pattern, bool print_operand_shape = false); - // LikeMatchOptimizedHlo, but checks operand shapes as well. + // Like MatchOptimizedHlo, but checks operand shapes as well. void MatchOptimizedHloWithShapes(absl::string_view hlo, absl::string_view pattern) { MatchOptimizedHlo(hlo, pattern, /*print_operand_shape=*/true); @@ -177,7 +181,7 @@ class HloTestBase : public ::testing::Test { bool allow_mixed_precision_in_hlo_verifier = true, HloPredicate instruction_can_change_layout_func = {}); - ~HloTestBase() override {} + ~HloTestBase() override = default; // Runs pass `hlo_pass` on input HLO module `hlo` with optional config, and // FileChecks the result against `expected`. @@ -287,16 +291,14 @@ class HloTestBase : public ::testing::Test { // reference backend. Note that the program shape of the module must not be // modified. [[nodiscard]] ::testing::AssertionResult RunAndCompare( - std::unique_ptr module, - const absl::Span arguments, + std::unique_ptr module, absl::Span arguments, const std::optional& error, const std::function& reference_preprocessor = nullptr); // Same as above, except that the module will be executed without Hlo // optimization. [[nodiscard]] ::testing::AssertionResult RunAndCompareNoHloPasses( - std::unique_ptr module, - const absl::Span arguments, + std::unique_ptr module, absl::Span arguments, const std::optional& error, const std::function& reference_preprocessor = nullptr, const std::function& test_preprocessor = nullptr); @@ -323,11 +325,11 @@ class HloTestBase : public ::testing::Test { // input. Module can be passed in directly, or parsed from an hlo_string, // or loaded from a file. [[nodiscard]] ::testing::AssertionResult RunAndCompare( - const absl::string_view hlo_string, const std::optional& error, + absl::string_view hlo_string, const std::optional& error, const std::function& reference_preprocessor = nullptr, std::optional args_max_bits_of_precision = std::nullopt); [[nodiscard]] ::testing::AssertionResult Run( - const absl::string_view hlo_string, bool run_hlo_passes = true, + absl::string_view hlo_string, bool run_hlo_passes = true, ExecutionProfile* profile = nullptr, const tsl::protobuf::Message* backend_config = nullptr, bool use_random_data = true); @@ -355,7 +357,7 @@ class HloTestBase : public ::testing::Test { // Same as below, except requires passing fake arguments. ::testing::AssertionResult RunAndCompareTwoModules( std::unique_ptr module_0, std::unique_ptr module_1, - const absl::Span arguments, + absl::Span arguments, const std::optional& error, bool run_hlo_passes = true); // Same as below, except requires passing the modules. @@ -390,14 +392,14 @@ class HloTestBase : public ::testing::Test { // Executes an hlo module with fake inputs on multiple replicas. [[nodiscard]] ::testing::AssertionResult RunReplicated( - const absl::string_view hlo_string, bool run_hlo_passes = true, + absl::string_view hlo_string, bool run_hlo_passes = true, int64_t num_replicas = 1, const tsl::protobuf::Message* backend_config = nullptr); // If assert_determinism is true, the assertion will fail unless all runs // produce exactly the same output. [[nodiscard]] ::testing::AssertionResult RunMultipleTimes( - const absl::string_view hlo_string, bool run_hlo_passes, + absl::string_view hlo_string, bool run_hlo_passes, std::vector* profiles, const tsl::protobuf::Message* backend_config = nullptr, bool assert_determinism = false); @@ -405,7 +407,7 @@ class HloTestBase : public ::testing::Test { const std::string& filename, const std::optional& error, const std::function& reference_preprocessor = nullptr); [[nodiscard]] ::testing::AssertionResult RunAndCompareNoHloPasses( - const absl::string_view hlo_string, const std::optional& error, + absl::string_view hlo_string, const std::optional& error, const std::function& reference_preprocessor = nullptr, const std::function& test_preprocessor = nullptr); [[nodiscard]] ::testing::AssertionResult RunAndCompareNoHloPassesFromFile( @@ -525,8 +527,7 @@ class HloTestBase : public ::testing::Test { // compares the results. Returns whether the results are near or equal. If any // error happens before the results are computed, returns the error status. absl::StatusOr<::testing::AssertionResult> RunAndCompareInternal( - std::unique_ptr module, - const absl::Span arguments, + std::unique_ptr module, absl::Span arguments, const std::optional& error, bool run_hlo_passes, const std::function& reference_preprocessor, const std::function& test_preprocessor = nullptr); @@ -545,7 +546,7 @@ class HloTestBase : public ::testing::Test { // error happens before the results are computed, returns the error status. absl::StatusOr<::testing::AssertionResult> RunAndCompareTwoModulesInternal( std::unique_ptr module_0, std::unique_ptr module_1, - const absl::Span arguments, + absl::Span arguments, const std::optional& error, bool run_hlo_passes); // Returns either an HloRunner or HloRunnerPjRt implementation depending if From e7d602494e94e36b36afb162a06a9310048d82e8 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sun, 29 Sep 2024 23:48:54 -0700 Subject: [PATCH 408/483] Automated Code Change PiperOrigin-RevId: 680438926 --- third_party/xla/xla/BUILD | 4 ++++ third_party/xla/xla/text_literal_reader.cc | 11 ++++++++--- third_party/xla/xla/text_literal_reader.h | 1 + third_party/xla/xla/text_literal_reader_test.cc | 1 - third_party/xla/xla/text_literal_writer.cc | 5 +++-- third_party/xla/xla/text_literal_writer.h | 1 + third_party/xla/xla/text_literal_writer_test.cc | 2 -- 7 files changed, 17 insertions(+), 8 deletions(-) diff --git a/third_party/xla/xla/BUILD b/third_party/xla/xla/BUILD index 64553ebfafb0ee..00bd2c61d3520c 100644 --- a/third_party/xla/xla/BUILD +++ b/third_party/xla/xla/BUILD @@ -913,10 +913,13 @@ cc_library( "//xla/service:hlo_parser", "//xla/tsl/lib/io:buffered_inputstream", "//xla/tsl/lib/io:random_inputstream", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:protobuf", + "@local_tsl//tsl/platform:statusor", ], ) @@ -946,6 +949,7 @@ cc_library( ":status_macros", ":types", ":xla_data_proto_cc", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:env", diff --git a/third_party/xla/xla/text_literal_reader.cc b/third_party/xla/xla/text_literal_reader.cc index 3aaa23a8f958a0..bcbba1b9ca354c 100644 --- a/third_party/xla/xla/text_literal_reader.cc +++ b/third_party/xla/xla/text_literal_reader.cc @@ -21,6 +21,10 @@ limitations under the License. #include #include +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/ascii.h" #include "absl/strings/match.h" #include "absl/strings/numbers.h" #include "absl/strings/str_split.h" @@ -28,14 +32,15 @@ limitations under the License. #include "absl/strings/strip.h" #include "xla/literal.h" #include "xla/service/hlo_parser.h" +#include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/status_macros.h" #include "xla/tsl/lib/io/buffered_inputstream.h" #include "xla/tsl/lib/io/random_inputstream.h" -#include "xla/types.h" #include "xla/util.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/protobuf.h" +#include "tsl/platform/env.h" +#include "tsl/platform/file_system.h" +#include "tsl/platform/statusor.h" namespace xla { diff --git a/third_party/xla/xla/text_literal_reader.h b/third_party/xla/xla/text_literal_reader.h index 397229e74d81cf..20684755cae91d 100644 --- a/third_party/xla/xla/text_literal_reader.h +++ b/third_party/xla/xla/text_literal_reader.h @@ -24,6 +24,7 @@ limitations under the License. #include "xla/types.h" #include "xla/xla_data.pb.h" #include "tsl/platform/env.h" +#include "tsl/platform/file_system.h" namespace xla { diff --git a/third_party/xla/xla/text_literal_reader_test.cc b/third_party/xla/xla/text_literal_reader_test.cc index afeed461c61be2..11d76f224f4c9a 100644 --- a/third_party/xla/xla/text_literal_reader_test.cc +++ b/third_party/xla/xla/text_literal_reader_test.cc @@ -20,7 +20,6 @@ limitations under the License. #include "xla/literal.h" #include "xla/shape_util.h" #include "xla/test.h" -#include "xla/types.h" #include "xla/xla_data.pb.h" #include "tsl/platform/env.h" diff --git a/third_party/xla/xla/text_literal_writer.cc b/third_party/xla/xla/text_literal_writer.cc index 050eb5fe835adc..83833dacbf5924 100644 --- a/third_party/xla/xla/text_literal_writer.cc +++ b/third_party/xla/xla/text_literal_writer.cc @@ -18,14 +18,15 @@ limitations under the License. #include #include +#include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/literal.h" #include "xla/shape_util.h" -#include "xla/status_macros.h" -#include "xla/types.h" #include "tsl/platform/env.h" +#include "tsl/platform/file_system.h" namespace xla { diff --git a/third_party/xla/xla/text_literal_writer.h b/third_party/xla/xla/text_literal_writer.h index 2ce5b368773d34..a11205c905f626 100644 --- a/third_party/xla/xla/text_literal_writer.h +++ b/third_party/xla/xla/text_literal_writer.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef XLA_TEXT_LITERAL_WRITER_H_ #define XLA_TEXT_LITERAL_WRITER_H_ +#include "absl/status/status.h" #include "absl/strings/string_view.h" #include "xla/literal.h" #include "xla/types.h" diff --git a/third_party/xla/xla/text_literal_writer_test.cc b/third_party/xla/xla/text_literal_writer_test.cc index e517279a4c447d..657937f749fa32 100644 --- a/third_party/xla/xla/text_literal_writer_test.cc +++ b/third_party/xla/xla/text_literal_writer_test.cc @@ -18,12 +18,10 @@ limitations under the License. #include #include -#include "xla/literal.h" #include "xla/literal_util.h" #include "xla/test.h" #include "xla/test_helpers.h" #include "xla/tsl/lib/core/status_test_util.h" -#include "xla/types.h" #include "tsl/platform/env.h" namespace xla { From 0927f90ce5d3c7a5dc26ee464ff86dbbd76e758d Mon Sep 17 00:00:00 2001 From: Tori Baker Date: Mon, 30 Sep 2024 00:59:09 -0700 Subject: [PATCH 409/483] [XLA:GPU] Use MaterializeOp for side outputs in transpose fusion emitter PiperOrigin-RevId: 680458248 --- .../xla/service/gpu/fusions/transpose_mlir.cc | 79 ++++++++----------- 1 file changed, 31 insertions(+), 48 deletions(-) diff --git a/third_party/xla/xla/service/gpu/fusions/transpose_mlir.cc b/third_party/xla/xla/service/gpu/fusions/transpose_mlir.cc index 57e7d4fa104f5d..b7e409151541c9 100644 --- a/third_party/xla/xla/service/gpu/fusions/transpose_mlir.cc +++ b/third_party/xla/xla/service/gpu/fusions/transpose_mlir.cc @@ -239,7 +239,6 @@ MlirTransposeFusion::WriteResult MlirTransposeFusion::EmitWriteToShMemMlir( int num_inputs = fusion.fused_instructions_computation()->num_parameters(); SmallVector callee_operands( entry_function.getArguments().take_front(num_inputs)); - auto tids_and_bids = EmitThreadAndBlockIds(builder); auto identity_map = IndexingMapAttr::get(ctx, CreateIdentityMap(shmem_tensor_size, ctx)); @@ -260,8 +259,7 @@ MlirTransposeFusion::WriteResult MlirTransposeFusion::EmitWriteToShMemMlir( write_indexing.GetMutableSymbolBound(index) = bound; } write_indexing.Simplify(); - auto dimensions = SmallVector(operand_shape.dimensions().begin(), - operand_shape.dimensions().end()); + SmallVector shmem_tensors; for (auto* transpose : shmem_transposes_) { auto elem_type = mlir_converter::PrimitiveTypeToMlirType( @@ -277,69 +275,54 @@ MlirTransposeFusion::WriteResult MlirTransposeFusion::EmitWriteToShMemMlir( auto materialized = builder.create( /* result_type=*/indexed_vector, /*input=*/callee_operands, - /*indices(dimensions)=*/tids_and_bids, + /*indices(dimensions)=*/thread_and_block_ids, /*callee=*/callee, /*map=*/IndexingMapAttr::get(ctx, indexing)); auto insert = builder.create( /*result_type=*/shmem.getType(), /*source=*/materialized.getResult(), - /*indices(dimensions)=*/tids_and_bids, + /*indices(dimensions)=*/thread_and_block_ids, /*dest=*/shmem, /*map=*/identity_map); shmem_tensors.push_back(insert.getResult()); } - // Produce all side outputs and then write them. - SmallVector side_output_inits; - for (int index : side_output_root_indices_) { - side_output_inits.push_back(entry_function.getArgument(num_inputs + index)); - } - auto body_builder = [&](ValueRange symbol_values, ValueRange map_results, - ValueRange output_tensors) -> SmallVector { - auto input_indices = [&](const HloInstruction* instr) { - return ApplyIndexing(GetIndexing(/*input=*/true, instr->shape(), ctx), - thread_and_block_ids, symbol_values, builder); - }; - - SmallVector side_outputs; - SmallVector> side_output_indices; - auto* root_tuple = fusion.fused_expression_root(); - for (auto root : side_output_roots_) { - side_output_indices.push_back(input_indices(root)); - ValueRange param_values = mlir_converter::ProvideParameter( - root_computation, root_tuple, root_tuple->operand_index(root), - side_output_indices.back(), call_target_provider, entry_function, - builder); - side_outputs.append(param_values.begin(), param_values.end()); - } - - SmallVector result_tensors; - for (const auto& [value, indices, output] : - llvm::zip(side_outputs, side_output_indices, output_tensors)) { - result_tensors.push_back( - builder.create(value, output, indices)); - } + WriteResult result; + result.updated_outputs = output_args; + for (auto [index, root] : + llvm::zip(side_output_root_indices_, side_output_roots_)) { + auto elem_type = mlir_converter::PrimitiveTypeToMlirType( + root->shape().element_type(), builder); + auto callee = mlir::SymbolRefAttr::get(call_target_provider(root)); + auto side_indexing = GetIndexing(/*input=*/true, root->shape(), ctx); + auto side_dims = SmallVector(root->shape().dimensions().begin(), + root->shape().dimensions().end()); + auto indexed_vector = IndexedVectorType::get( + ctx, side_dims, elem_type, IndexingMapAttr::get(ctx, side_indexing)); + auto materialize = builder.create( + /* result_type=*/indexed_vector, + /*input=*/callee_operands, + /*indices(dimensions)=*/thread_and_block_ids, + /*callee=*/callee, + /*map=*/IndexingMapAttr::get(ctx, side_indexing)); - return result_tensors; - }; - mlir::ValueRange side_output_vector; - if (!side_output_inits.empty()) { - side_output_vector = mlir_converter::EmitXlaLoopOp( - builder, thread_and_block_ids, side_output_inits, indexing, - body_builder); + auto init = entry_function.getArgument(num_inputs + index); + auto side_identity_map = + IndexingMapAttr::get(ctx, CreateIdentityMap(side_dims, ctx)); + auto insert = builder.create( + /*result_type=*/init.getType(), + /*source=*/materialize.getResult(), + /*indices(dimensions)=*/thread_and_block_ids, + /*dest=*/init, + /*map=*/side_identity_map); + result.updated_outputs[index] = insert.getResult(); } - WriteResult result; result.shmem_tensors = builder .create(mlir::TypeRange(shmem_tensors), shmem_tensors) .getResults(); - result.updated_outputs = output_args; - for (auto [index, side_output_result] : - llvm::zip(side_output_root_indices_, side_output_vector)) { - result.updated_outputs[index] = side_output_result; - } return result; } From 70665f1572c58e4a08c5a496145cdb4b3d26166b Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 30 Sep 2024 01:03:29 -0700 Subject: [PATCH 410/483] Automated Code Change PiperOrigin-RevId: 680459451 --- tensorflow/compiler/mlir/tools/kernel_gen/BUILD | 2 ++ tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.h | 1 + tensorflow/compiler/mlir/tools/kernel_gen/tf_jit_cache.cc | 1 + 3 files changed, 4 insertions(+) diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/BUILD b/tensorflow/compiler/mlir/tools/kernel_gen/BUILD index 1dcd829833b985..7e319de3d8217b 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/BUILD +++ b/tensorflow/compiler/mlir/tools/kernel_gen/BUILD @@ -56,6 +56,7 @@ cc_library( "//tensorflow/compiler/mlir/tools/kernel_gen/transforms:passes", "//tensorflow/core:lib", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@llvm-project//llvm:Support", "@llvm-project//mlir:AffineToStandard", "@llvm-project//mlir:ArithDialect", @@ -211,6 +212,7 @@ cc_library( "//tensorflow/core/framework:resource_base", "//tensorflow/core/platform:status", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", "@llvm-project//llvm:Support", "@llvm-project//mlir:ExecutionEngine", "@local_tsl//tsl/platform:thread_annotations", diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.h b/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.h index 380fb8c52448cf..6b224e91bfb5eb 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.h +++ b/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.h @@ -24,6 +24,7 @@ limitations under the License. #include +#include "absl/status/statusor.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/StringRef.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tf_jit_cache.cc b/tensorflow/compiler/mlir/tools/kernel_gen/tf_jit_cache.cc index 51a505dc8de1e2..79884a02769785 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/tf_jit_cache.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tf_jit_cache.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "absl/status/status.h" #include "llvm/Support/Error.h" #include "mlir/ExecutionEngine/ExecutionEngine.h" // from @llvm-project #include "tensorflow/core/platform/mutex.h" From efba8bc51e23e2c9b184edb72149bebc3291cf13 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 30 Sep 2024 01:52:16 -0700 Subject: [PATCH 411/483] Automated Code Change PiperOrigin-RevId: 680473564 --- .../core/common_runtime/pluggable_device/BUILD | 9 +++++++++ .../pluggable_device/pluggable_device.cc | 18 ++++++++++-------- .../pluggable_device/pluggable_device.h | 4 ++++ .../pluggable_device_bfc_allocator.cc | 7 ++++++- .../pluggable_device_context.cc | 8 +++++--- .../pluggable_device_context.h | 3 +++ .../pluggable_device_factory.h | 3 +++ .../pluggable_device/pluggable_device_init.cc | 8 +------- .../pluggable_device/pluggable_device_init.h | 1 + .../pluggable_device_plugin_init.cc | 6 +++++- .../pluggable_device_process_state.cc | 14 +++++++++----- .../pluggable_device_simple_allocator.cc | 3 ++- .../pluggable_device_simple_allocator.h | 1 + .../pluggable_device/pluggable_device_util.cc | 17 ++++++----------- .../pluggable_device/pluggable_device_util.h | 3 +++ 15 files changed, 68 insertions(+), 37 deletions(-) diff --git a/tensorflow/core/common_runtime/pluggable_device/BUILD b/tensorflow/core/common_runtime/pluggable_device/BUILD index 5121950354e86e..cc0ee4e16d5418 100644 --- a/tensorflow/core/common_runtime/pluggable_device/BUILD +++ b/tensorflow/core/common_runtime/pluggable_device/BUILD @@ -55,6 +55,7 @@ cc_library( "//tensorflow/core/common_runtime/device:device_event_mgr", "//tensorflow/core/platform:stream_executor", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/memory", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@local_xla//xla/stream_executor", @@ -86,6 +87,7 @@ cc_library( "//tensorflow/compiler/jit:pjrt_device_context", "//tensorflow/compiler/jit:xla_device_no_jit_rewrite_registration", "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_base", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", @@ -99,6 +101,8 @@ cc_library( "//tensorflow/core/common_runtime/next_pluggable_device:next_pluggable_device_factory", "//tensorflow/core/platform:stream_executor", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@local_xla//xla/pjrt:pjrt_api", "@local_xla//xla/pjrt/c:pjrt_c_api_hdrs", @@ -115,6 +119,7 @@ cc_library( linkstatic = 1, deps = [ "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_base", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", @@ -151,6 +156,9 @@ cc_library( "//tensorflow/core/common_runtime:bfc_allocator", "//tensorflow/core/common_runtime/device:device_id", "//tensorflow/core/common_runtime/device:device_mem_allocator", + "@com_google_absl//absl/log", + "@com_google_absl//absl/memory", + "@local_xla//xla/tsl/framework:bfc_allocator", ], ) @@ -168,6 +176,7 @@ cc_library( "//tensorflow/core:protos_all_cc", "//tensorflow/core/common_runtime/device:device_id", "//tensorflow/core/common_runtime/device:device_mem_allocator", + "//tensorflow/core/framework:allocator", ], ) diff --git a/tensorflow/core/common_runtime/pluggable_device/pluggable_device.cc b/tensorflow/core/common_runtime/pluggable_device/pluggable_device.cc index 161a21c5d44615..d200fdb2dec4e6 100644 --- a/tensorflow/core/common_runtime/pluggable_device/pluggable_device.cc +++ b/tensorflow/core/common_runtime/pluggable_device/pluggable_device.cc @@ -27,38 +27,40 @@ limitations under the License. #include #include +#include "absl/status/status.h" +#include "absl/strings/ascii.h" +#include "xla/stream_executor/stream_executor.h" #include "tensorflow/core/common_runtime/device/device_event_mgr.h" #include "tensorflow/core/common_runtime/device/device_id.h" #include "tensorflow/core/common_runtime/device/device_id_manager.h" -#include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/device_id_utils.h" #include "tensorflow/core/common_runtime/local_device.h" #include "tensorflow/core/common_runtime/pluggable_device/pluggable_device_context.h" -#include "tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.h" #include "tensorflow/core/common_runtime/pluggable_device/pluggable_device_init.h" #include "tensorflow/core/common_runtime/pluggable_device/pluggable_device_process_state.h" #include "tensorflow/core/common_runtime/pluggable_device/pluggable_device_util.h" #include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/device_attributes.pb.h" #include "tensorflow/core/framework/device_base.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/framework/variant_op_registry.h" #include "tensorflow/core/graph/types.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/strings/numbers.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/macros.h" -#include "tensorflow/core/platform/stream_executor.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/notification.h" +#include "tensorflow/core/platform/threadpool.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/public/session_options.h" -#include "tensorflow/core/util/device_name_utils.h" #include "tensorflow/core/util/env_var.h" -#include "tensorflow/core/util/stream_executor_util.h" +#include "tsl/platform/errors.h" namespace tensorflow { diff --git a/tensorflow/core/common_runtime/pluggable_device/pluggable_device.h b/tensorflow/core/common_runtime/pluggable_device/pluggable_device.h index 67aa658a3fd9d8..a6ccf62283c032 100644 --- a/tensorflow/core/common_runtime/pluggable_device/pluggable_device.h +++ b/tensorflow/core/common_runtime/pluggable_device/pluggable_device.h @@ -21,6 +21,7 @@ limitations under the License. #include #include +#include "xla/stream_executor/stream_executor.h" #include "tensorflow/core/common_runtime/device/device_event_mgr.h" #include "tensorflow/core/common_runtime/device/device_id.h" #include "tensorflow/core/common_runtime/device/device_id_manager.h" @@ -28,13 +29,16 @@ limitations under the License. #include "tensorflow/core/common_runtime/pluggable_device/pluggable_device_context.h" #include "tensorflow/core/common_runtime/shared_counter.h" #include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/device_attributes.pb.h" #include "tensorflow/core/framework/device_base.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/graph/types.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/inlined_vector.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/stream_executor.h" +#include "tensorflow/core/platform/threadpool.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/public/session_options.h" diff --git a/tensorflow/core/common_runtime/pluggable_device/pluggable_device_bfc_allocator.cc b/tensorflow/core/common_runtime/pluggable_device/pluggable_device_bfc_allocator.cc index eba49bb8a1aff9..1523c64a1ae4ed 100644 --- a/tensorflow/core/common_runtime/pluggable_device/pluggable_device_bfc_allocator.cc +++ b/tensorflow/core/common_runtime/pluggable_device/pluggable_device_bfc_allocator.cc @@ -17,7 +17,12 @@ limitations under the License. #include -#include "tensorflow/core/lib/strings/strcat.h" +#include "absl/log/log.h" +#include "absl/memory/memory.h" +#include "xla/tsl/framework/bfc_allocator.h" +#include "tensorflow/core/common_runtime/device/device_mem_allocator.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/protobuf/config.pb.h" namespace tensorflow { diff --git a/tensorflow/core/common_runtime/pluggable_device/pluggable_device_context.cc b/tensorflow/core/common_runtime/pluggable_device/pluggable_device_context.cc index ec3faf2d6329ca..bae16315cbdcad 100644 --- a/tensorflow/core/common_runtime/pluggable_device/pluggable_device_context.cc +++ b/tensorflow/core/common_runtime/pluggable_device/pluggable_device_context.cc @@ -17,12 +17,14 @@ limitations under the License. #include -#include "tensorflow/core/common_runtime/device.h" +#include "absl/status/status.h" #include "tensorflow/core/common_runtime/device/device_event_mgr.h" #include "tensorflow/core/common_runtime/pluggable_device/pluggable_device_util.h" +#include "tensorflow/core/framework/device.h" +#include "tensorflow/core/framework/device_base.h" #include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/framework/types.h" -#include "tensorflow/core/platform/stream_executor.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/stringpiece.h" namespace tensorflow { diff --git a/tensorflow/core/common_runtime/pluggable_device/pluggable_device_context.h b/tensorflow/core/common_runtime/pluggable_device/pluggable_device_context.h index 3baef39192487d..afe6d9b6d6904c 100644 --- a/tensorflow/core/common_runtime/pluggable_device/pluggable_device_context.h +++ b/tensorflow/core/common_runtime/pluggable_device/pluggable_device_context.h @@ -20,7 +20,10 @@ limitations under the License. #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/framework/device_base.h" +#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/stringpiece.h" namespace stream_executor { class Stream; diff --git a/tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.h b/tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.h index c423dd2a1fc13a..a3e9b3a781cc76 100644 --- a/tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.h +++ b/tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.h @@ -24,6 +24,9 @@ limitations under the License. #include "tensorflow/core/common_runtime/device/device_id.h" #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/framework/device_attributes.pb.h" +#include "tensorflow/core/framework/device_factory.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/types.h" #include "tensorflow/core/public/session_options.h" namespace tensorflow { diff --git a/tensorflow/core/common_runtime/pluggable_device/pluggable_device_init.cc b/tensorflow/core/common_runtime/pluggable_device/pluggable_device_init.cc index 0b7279d0098ac1..2c5db2a4f0b176 100644 --- a/tensorflow/core/common_runtime/pluggable_device/pluggable_device_init.cc +++ b/tensorflow/core/common_runtime/pluggable_device/pluggable_device_init.cc @@ -18,15 +18,9 @@ limitations under the License. #include #include "xla/stream_executor/platform_manager.h" -#include "tensorflow/core/common_runtime/device_factory.h" -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/strings/numbers.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/stream_executor.h" +#include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/stream_executor_util.h" namespace tensorflow { diff --git a/tensorflow/core/common_runtime/pluggable_device/pluggable_device_init.h b/tensorflow/core/common_runtime/pluggable_device/pluggable_device_init.h index 6362de9856ae38..9410d507807d70 100644 --- a/tensorflow/core/common_runtime/pluggable_device/pluggable_device_init.h +++ b/tensorflow/core/common_runtime/pluggable_device/pluggable_device_init.h @@ -19,6 +19,7 @@ limitations under the License. #include #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/types.h" namespace stream_executor { class Platform; diff --git a/tensorflow/core/common_runtime/pluggable_device/pluggable_device_plugin_init.cc b/tensorflow/core/common_runtime/pluggable_device/pluggable_device_plugin_init.cc index a3879acd5daa08..2b9fc947f3834b 100644 --- a/tensorflow/core/common_runtime/pluggable_device/pluggable_device_plugin_init.cc +++ b/tensorflow/core/common_runtime/pluggable_device/pluggable_device_plugin_init.cc @@ -17,6 +17,8 @@ limitations under the License. #include #include +#include "absl/log/check.h" +#include "absl/log/log.h" #include "absl/status/status.h" #include "tensorflow/c/experimental/grappler/grappler_internal.h" #include "tensorflow/c/experimental/pluggable_profiler/pluggable_profiler_internal.h" @@ -31,9 +33,11 @@ limitations under the License. #include "tensorflow/core/common_runtime/next_pluggable_device/next_pluggable_device_factory.h" #include "tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.h" #include "tensorflow/core/common_runtime/pluggable_device/pluggable_device_util.h" +#include "tensorflow/core/framework/device_factory.h" +#include "tensorflow/core/framework/types.h" #include "tensorflow/core/platform/env.h" -#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/types.h" #include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" diff --git a/tensorflow/core/common_runtime/pluggable_device/pluggable_device_process_state.cc b/tensorflow/core/common_runtime/pluggable_device/pluggable_device_process_state.cc index f9f0ad68977516..bea6e122185faa 100644 --- a/tensorflow/core/common_runtime/pluggable_device/pluggable_device_process_state.cc +++ b/tensorflow/core/common_runtime/pluggable_device/pluggable_device_process_state.cc @@ -20,29 +20,33 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/c/experimental/stream_executor/stream_executor_internal.h" -#include "xla/stream_executor/integrations/device_mem_allocator.h" +#include "xla/stream_executor/stream_executor.h" #include "xla/tsl/framework/device_id_utils.h" +#include "tensorflow/core/common_runtime/bfc_allocator.h" #include "tensorflow/core/common_runtime/device/device_host_allocator.h" #include "tensorflow/core/common_runtime/device/device_id.h" #include "tensorflow/core/common_runtime/device/device_id_manager.h" #include "tensorflow/core/common_runtime/device/device_mem_allocator.h" -#include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/device_id_utils.h" #include "tensorflow/core/common_runtime/pluggable_device/pluggable_device_bfc_allocator.h" #include "tensorflow/core/common_runtime/pluggable_device/pluggable_device_init.h" #include "tensorflow/core/common_runtime/pluggable_device/pluggable_device_simple_allocator.h" -#include "tensorflow/core/common_runtime/pool_allocator.h" -#include "tensorflow/core/common_runtime/shared_counter.h" +#include "tensorflow/core/common_runtime/process_state.h" #include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/framework/log_memory.h" #include "tensorflow/core/framework/tracking_allocator.h" +#include "tensorflow/core/framework/types.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/mutex.h" -#include "tensorflow/core/platform/stream_executor.h" +#include "tensorflow/core/platform/numa.h" +#include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/types.h" +#include "tensorflow/core/protobuf/config.pb.h" #include "tensorflow/core/util/env_var.h" +#include "tsl/platform/status.h" namespace tensorflow { diff --git a/tensorflow/core/common_runtime/pluggable_device/pluggable_device_simple_allocator.cc b/tensorflow/core/common_runtime/pluggable_device/pluggable_device_simple_allocator.cc index c8fee4b0ef6ec9..4fdc86fa045b50 100644 --- a/tensorflow/core/common_runtime/pluggable_device/pluggable_device_simple_allocator.cc +++ b/tensorflow/core/common_runtime/pluggable_device/pluggable_device_simple_allocator.cc @@ -16,7 +16,8 @@ limitations under the License. #include -#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/common_runtime/device/device_mem_allocator.h" +#include "tensorflow/core/framework/allocator.h" namespace tensorflow { diff --git a/tensorflow/core/common_runtime/pluggable_device/pluggable_device_simple_allocator.h b/tensorflow/core/common_runtime/pluggable_device/pluggable_device_simple_allocator.h index 3a46f766809449..dccb2548868f1e 100644 --- a/tensorflow/core/common_runtime/pluggable_device/pluggable_device_simple_allocator.h +++ b/tensorflow/core/common_runtime/pluggable_device/pluggable_device_simple_allocator.h @@ -22,6 +22,7 @@ limitations under the License. #include #include "tensorflow/core/common_runtime/device/device_mem_allocator.h" +#include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/platform/thread_annotations.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/protobuf/config.pb.h" diff --git a/tensorflow/core/common_runtime/pluggable_device/pluggable_device_util.cc b/tensorflow/core/common_runtime/pluggable_device/pluggable_device_util.cc index ecb08eec856dc9..33a3b6351678b6 100644 --- a/tensorflow/core/common_runtime/pluggable_device/pluggable_device_util.cc +++ b/tensorflow/core/common_runtime/pluggable_device/pluggable_device_util.cc @@ -15,26 +15,21 @@ limitations under the License. #include "tensorflow/core/common_runtime/pluggable_device/pluggable_device_util.h" -#include "tensorflow/core/common_runtime/copy_tensor.h" -#include "tensorflow/core/common_runtime/device.h" +#include "absl/status/status.h" +#include "xla/stream_executor/device_memory.h" #include "tensorflow/core/common_runtime/device/device_event_mgr.h" -#include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/dma_helper.h" #include "tensorflow/core/common_runtime/pluggable_device/pluggable_device_context.h" -#include "tensorflow/core/common_runtime/pluggable_device/pluggable_device_process_state.h" +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/device.h" +#include "tensorflow/core/framework/device_base.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/tensor_reference.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/core/refcount.h" -#include "tensorflow/core/lib/hash/hash.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/stream_executor.h" -#include "tensorflow/core/platform/tensor_coding.h" -#include "tensorflow/core/util/util.h" +#include "tensorflow/core/platform/status.h" // IMPLEMENTATION NOTE: // diff --git a/tensorflow/core/common_runtime/pluggable_device/pluggable_device_util.h b/tensorflow/core/common_runtime/pluggable_device/pluggable_device_util.h index 8cff5449c853f5..9aad40dd82e829 100644 --- a/tensorflow/core/common_runtime/pluggable_device/pluggable_device_util.h +++ b/tensorflow/core/common_runtime/pluggable_device/pluggable_device_util.h @@ -18,6 +18,9 @@ limitations under the License. #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/dma_helper.h" +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/device.h" +#include "tensorflow/core/framework/device_base.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/stream_executor.h" From 4c33565dfcbf9922cec3f193562aabe991551a29 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 30 Sep 2024 01:56:13 -0700 Subject: [PATCH 412/483] Automated Code Change PiperOrigin-RevId: 680474574 --- third_party/xla/xla/service/gpu/fusions/BUILD | 4 ++++ third_party/xla/xla/service/gpu/fusions/reduction_base.cc | 5 ----- third_party/xla/xla/service/gpu/fusions/reduction_base.h | 1 + third_party/xla/xla/service/gpu/fusions/reduction_mlir.cc | 4 ++++ third_party/xla/xla/service/gpu/fusions/reduction_mlir.h | 1 + 5 files changed, 10 insertions(+), 5 deletions(-) diff --git a/third_party/xla/xla/service/gpu/fusions/BUILD b/third_party/xla/xla/service/gpu/fusions/BUILD index 9b813219939ef2..d5ef2439b59304 100644 --- a/third_party/xla/xla/service/gpu/fusions/BUILD +++ b/third_party/xla/xla/service/gpu/fusions/BUILD @@ -519,6 +519,7 @@ cc_library( ":reduction_base", "//xla:shape_util", "//xla:util", + "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/service/gpu:hlo_fusion_analysis", "//xla/service/gpu:hlo_traversal", @@ -531,10 +532,13 @@ cc_library( "//xla/service/gpu/fusions/mlir:mlir_fusion_emitter", "//xla/service/gpu/fusions/mlir:type_util", "//xla/service/gpu/model:indexing_analysis", + "//xla/stream_executor:launch_dim", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", diff --git a/third_party/xla/xla/service/gpu/fusions/reduction_base.cc b/third_party/xla/xla/service/gpu/fusions/reduction_base.cc index b7f62e3c7d1d54..dfac23affc34e3 100644 --- a/third_party/xla/xla/service/gpu/fusions/reduction_base.cc +++ b/third_party/xla/xla/service/gpu/fusions/reduction_base.cc @@ -31,7 +31,6 @@ limitations under the License. #include "absl/types/span.h" #include "llvm/ADT/STLExtras.h" #include "mlir/IR/AffineExpr.h" -#include "mlir/IR/MLIRContext.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/utils/hlo_query.h" #include "xla/primitive_util.h" @@ -40,14 +39,10 @@ limitations under the License. #include "xla/service/gpu/hlo_fusion_analysis.h" #include "xla/service/gpu/hlo_traversal.h" #include "xla/service/gpu/ir_emission_utils.h" -#include "xla/service/gpu/launch_dimensions.h" -#include "xla/service/gpu/model/indexing_analysis.h" #include "xla/service/gpu/model/indexing_map.h" #include "xla/service/gpu/reduction_utils.h" #include "xla/shape.h" -#include "xla/shape_util.h" #include "xla/stream_executor/device_description.h" -#include "xla/stream_executor/launch_dim.h" #include "xla/union_find.h" #include "xla/util.h" #include "xla/xla_data.pb.h" diff --git a/third_party/xla/xla/service/gpu/fusions/reduction_base.h b/third_party/xla/xla/service/gpu/fusions/reduction_base.h index ad99e6f40140c9..0239838c2b8bb3 100644 --- a/third_party/xla/xla/service/gpu/fusions/reduction_base.h +++ b/third_party/xla/xla/service/gpu/fusions/reduction_base.h @@ -22,6 +22,7 @@ limitations under the License. #include "xla/service/gpu/hlo_fusion_analysis.h" #include "xla/service/gpu/model/indexing_map.h" #include "xla/service/gpu/reduction_utils.h" +#include "xla/util.h" namespace xla { namespace gpu { diff --git a/third_party/xla/xla/service/gpu/fusions/reduction_mlir.cc b/third_party/xla/xla/service/gpu/fusions/reduction_mlir.cc index d195d40fe7e0bd..7897966a91409a 100644 --- a/third_party/xla/xla/service/gpu/fusions/reduction_mlir.cc +++ b/third_party/xla/xla/service/gpu/fusions/reduction_mlir.cc @@ -27,6 +27,8 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" +#include "absl/log/check.h" +#include "absl/log/log.h" #include "absl/status/status.h" #include "absl/types/span.h" #include "llvm/ADT/STLExtras.h" @@ -61,7 +63,9 @@ limitations under the License. #include "xla/service/gpu/reduction_utils.h" #include "xla/shape.h" #include "xla/shape_util.h" +#include "xla/stream_executor/launch_dim.h" #include "xla/util.h" +#include "xla/xla_data.pb.h" namespace xla { namespace gpu { diff --git a/third_party/xla/xla/service/gpu/fusions/reduction_mlir.h b/third_party/xla/xla/service/gpu/fusions/reduction_mlir.h index db0fbd2b45c31c..b4deb0ee862b9e 100644 --- a/third_party/xla/xla/service/gpu/fusions/reduction_mlir.h +++ b/third_party/xla/xla/service/gpu/fusions/reduction_mlir.h @@ -21,6 +21,7 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" #include "absl/container/inlined_vector.h" #include "absl/status/status.h" #include "absl/types/span.h" From 713920b56df2cab244278be5524e2874b698f907 Mon Sep 17 00:00:00 2001 From: Henning Becker Date: Mon, 30 Sep 2024 02:00:09 -0700 Subject: [PATCH 413/483] Remove cuda_only_cc_library Since now we can exclude targets from building using tags, we won't need the `cuda_only_cc_library` rule anymore. This also required me to remove some wrongly added dependencies, notably I found several targets depending on cublas_plugin, even though those targets were not CUDA specific and shouldn't directly depend on CUDA-specific targets. I also found out that the `tsl_gpu_library` macro is not handling its `cuda_deps` attribute correctly. It was adding those dependencies both for ROCm and for CUDA. So this change is fixing that as well. PiperOrigin-RevId: 680475594 --- third_party/xla/xla/BUILD | 1 + third_party/xla/xla/lit.bzl | 7 +- third_party/xla/xla/pjrt/gpu/BUILD | 1 - .../xla/xla/service/gpu/fusions/triton/BUILD | 3 - third_party/xla/xla/service/gpu/kernels/BUILD | 1 - .../xla/xla/stream_executor/build_defs.bzl | 30 ----- .../xla/xla/stream_executor/cuda/BUILD | 125 +++++++++++++----- third_party/xla/xla/stream_executor/gpu/BUILD | 15 ++- third_party/xla/xla/tsl/tsl.bzl | 6 +- 9 files changed, 116 insertions(+), 73 deletions(-) diff --git a/third_party/xla/xla/BUILD b/third_party/xla/xla/BUILD index 00bd2c61d3520c..bf0b97f0ec72aa 100644 --- a/third_party/xla/xla/BUILD +++ b/third_party/xla/xla/BUILD @@ -1351,6 +1351,7 @@ bzl_library( deps = [ "//xla/tsl:tsl_bzl", "@bazel_skylib//lib:paths", + "@local_tsl//tsl/platform/default:cuda_build_defs_bzl", ], ) diff --git a/third_party/xla/xla/lit.bzl b/third_party/xla/xla/lit.bzl index 5837c54ad81eab..5ac1cde98f1d8c 100644 --- a/third_party/xla/xla/lit.bzl +++ b/third_party/xla/xla/lit.bzl @@ -1,6 +1,7 @@ """Helper rules for writing LIT tests.""" load("@bazel_skylib//lib:paths.bzl", "paths") +load("@local_tsl//tsl/platform/default:cuda_build_defs.bzl", "if_cuda_is_configured") load("//xla/tsl:tsl.bzl", "if_cuda_tools", "if_google", "if_oss") def enforce_glob(files, **kwargs): @@ -209,7 +210,11 @@ def lit_test( srcs = tools, bin_dir = bin_dir, lib_dir = lib_dir, - deps = ["//xla/stream_executor/cuda:all_runtime"], + deps = if_cuda_is_configured( + [ + "//xla/stream_executor/cuda:all_runtime", + ], + ), visibility = ["//visibility:private"], **kwargs ) diff --git a/third_party/xla/xla/pjrt/gpu/BUILD b/third_party/xla/xla/pjrt/gpu/BUILD index 9c6949f79661f8..8bb23efbab7c82 100644 --- a/third_party/xla/xla/pjrt/gpu/BUILD +++ b/third_party/xla/xla/pjrt/gpu/BUILD @@ -315,7 +315,6 @@ xla_test( "//xla/pjrt:pjrt_client", "//xla/pjrt:pjrt_compiler", "//xla/service:hlo_parser", - "//xla/stream_executor/cuda:cublas_plugin", "//xla/tests:literal_test_util", "@com_google_absl//absl/status", "@com_google_googletest//:gtest", diff --git a/third_party/xla/xla/service/gpu/fusions/triton/BUILD b/third_party/xla/xla/service/gpu/fusions/triton/BUILD index 99489d8e6d41b4..6a8c9ff80aec95 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/BUILD +++ b/third_party/xla/xla/service/gpu/fusions/triton/BUILD @@ -223,7 +223,6 @@ xla_test( "//xla/service/gpu/model:tiled_hlo_instruction_or_computation", "//xla/service/gpu/tests:gpu_codegen_test", "//xla/stream_executor:device_description", - "//xla/stream_executor/cuda:cublas_plugin", "//xla/tests:filecheck", "//xla/tests:xla_internal_test_main", # fixdeps: keep "//xla/tsl/lib/core:status_test_util", @@ -267,7 +266,6 @@ xla_test( "//xla/service/gpu/model:tiled_hlo_instruction_or_computation", "//xla/service/gpu/tests:gpu_codegen_test", "//xla/stream_executor:device_description", - "//xla/stream_executor/cuda:cublas_plugin", "//xla/tests:verified_hlo_module", "//xla/tests:xla_internal_test_main", # fixdeps: keep "//xla/tsl/lib/core:status_test_util", @@ -431,7 +429,6 @@ xla_test( "//xla/hlo/ir:hlo", "//xla/service/gpu/tests:gpu_codegen_test", "//xla/stream_executor:device_description", - "//xla/stream_executor/cuda:cublas_plugin", "//xla/tests:xla_internal_test_main", # fixdeps: keep "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/strings", diff --git a/third_party/xla/xla/service/gpu/kernels/BUILD b/third_party/xla/xla/service/gpu/kernels/BUILD index 5c517476917932..1e5bc9d4847126 100644 --- a/third_party/xla/xla/service/gpu/kernels/BUILD +++ b/third_party/xla/xla/service/gpu/kernels/BUILD @@ -242,7 +242,6 @@ xla_test( "//xla/stream_executor", "//xla/stream_executor:platform", "//xla/stream_executor:platform_manager", - "//xla/stream_executor/cuda:cuda_platform", "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/random", "@com_google_absl//absl/strings", diff --git a/third_party/xla/xla/stream_executor/build_defs.bzl b/third_party/xla/xla/stream_executor/build_defs.bzl index 109872e2a0f0df..3204b886c651ff 100644 --- a/third_party/xla/xla/stream_executor/build_defs.bzl +++ b/third_party/xla/xla/stream_executor/build_defs.bzl @@ -1,6 +1,5 @@ """Configurations for StreamExecutor builds""" -load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda_is_configured") load( "@local_config_rocm//rocm:build_defs.bzl", _if_cuda_or_rocm = "if_cuda_or_rocm", @@ -64,34 +63,5 @@ def gpu_only_cc_library(name, tags = [], **kwargs): target_compatible_with = kwargs.get("target_compatible_with"), ) -def cuda_only_cc_library(name, tags = [], **kwargs): - """A library that only gets compiled when CUDA is configured, otherwise it's an empty target. - - Args: - name: Name of the target - tags: Tags being applied to the implementation target - **kwargs: Accepts all arguments that a `cc_library` would also accept - """ - if not native.package_name().startswith("xla/stream_executor"): - fail("cuda_only_cc_library may only be used in `xla/stream_executor/...`.") - - cc_library( - name = "%s_non_cuda" % name, - tags = ["manual"], - ) - cc_library( - name = "%s_cuda_only" % name, - tags = tags + ["manual", "cuda-only"], - **kwargs - ) - native.alias( - name = name, - actual = if_cuda_is_configured(":%s_cuda_only" % name, ":%s_non_cuda" % name), - visibility = kwargs.get("visibility"), - compatible_with = kwargs.get("compatible_with"), - restricted_to = kwargs.get("restricted_to"), - target_compatible_with = kwargs.get("target_compatible_with"), - ) - def stream_executor_build_defs_bzl_deps(): return [] diff --git a/third_party/xla/xla/stream_executor/cuda/BUILD b/third_party/xla/xla/stream_executor/cuda/BUILD index 529a0b6197935f..26c394a144300c 100644 --- a/third_party/xla/xla/stream_executor/cuda/BUILD +++ b/third_party/xla/xla/stream_executor/cuda/BUILD @@ -10,20 +10,14 @@ load( ) load( "@local_tsl//tsl/platform/default:cuda_build_defs.bzl", - "if_cuda_is_configured", "if_cuda_newer_than", ) load( "//xla:xla.bzl", "xla_cc_test", ) -load( - "//xla/service/gpu:build_defs.bzl", - "gpu_kernel_library", -) load( "//xla/stream_executor:build_defs.bzl", - "cuda_only_cc_library", "stream_executor_friends", "tf_additional_cuda_platform_deps", "tf_additional_cudnn_plugin_copts", @@ -87,10 +81,14 @@ cc_library( deps = ["//xla/stream_executor:platform"], ) -cuda_only_cc_library( +cc_library( name = "cuda_platform", srcs = ["cuda_platform.cc"], hdrs = ["cuda_platform.h"], + tags = [ + "cuda-only", + "gpu", + ], visibility = ["//visibility:public"], deps = [ @@ -123,10 +121,14 @@ cuda_only_cc_library( alwayslink = True, # Registers itself with the PlatformManager. ) -cuda_only_cc_library( +cc_library( name = "cuda_diagnostics", srcs = ["cuda_diagnostics.cc"], hdrs = ["cuda_diagnostics.h"], + tags = [ + "cuda-only", + "gpu", + ], deps = [ "//xla/stream_executor/gpu:gpu_diagnostics_header", "@com_google_absl//absl/container:inlined_vector", @@ -157,10 +159,14 @@ cc_library( ), ) -cuda_only_cc_library( +cc_library( name = "cuda_driver", srcs = ["cuda_driver.cc"], hdrs = ["cuda_driver.h"], + tags = [ + "cuda-only", + "gpu", + ], deps = [ ":cuda_diagnostics", # buildcleaner: keep ":cuda_status", @@ -198,10 +204,14 @@ cuda_only_cc_library( ], ) -cuda_only_cc_library( +cc_library( name = "cuda_status", srcs = ["cuda_status.cc"], hdrs = ["cuda_status.h"], + tags = [ + "cuda-only", + "gpu", + ], deps = [ "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/status", @@ -210,10 +220,14 @@ cuda_only_cc_library( ], ) -cuda_only_cc_library( +cc_library( name = "cuda_runtime", srcs = ["cuda_runtime.cc"], hdrs = ["cuda_runtime.h"], + tags = [ + "cuda-only", + "gpu", + ], deps = [ "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/log", @@ -225,10 +239,13 @@ cuda_only_cc_library( ], ) -cuda_only_cc_library( +cc_library( name = "cuda_collectives", hdrs = ["cuda_collectives.h"], - tags = ["gpu"], + tags = [ + "cuda-only", + "gpu", + ], deps = if_nccl( [":cuda_collectives_impl"], [":cuda_collectives_stub"], @@ -246,6 +263,7 @@ cc_library( "cuda_collectives.h", ], tags = [ + "cuda-only", "gpu", "manual", ], @@ -318,12 +336,16 @@ xla_test( ], ) -cuda_only_cc_library( +cc_library( name = "cublas_lt_header", hdrs = [ "cuda_blas_lt.h", "cuda_blas_utils.h", ], + tags = [ + "cuda-only", + "gpu", + ], visibility = ["//visibility:public"], deps = [ "//xla:types", @@ -340,7 +362,7 @@ cuda_only_cc_library( ], ) -cuda_only_cc_library( +cc_library( name = "cublas_plugin", srcs = [ "cuda_blas.cc", @@ -350,6 +372,10 @@ cuda_only_cc_library( "cuda_blas.h", "cuda_blas_lt.h", ], + tags = [ + "cuda-only", + "gpu", + ], visibility = ["//visibility:public"], deps = [ ":cuda_blas_utils", @@ -402,10 +428,14 @@ cuda_only_cc_library( alwayslink = True, ) -cuda_only_cc_library( +cc_library( name = "cuda_blas_utils", srcs = ["cuda_blas_utils.cc"], hdrs = ["cuda_blas_utils.h"], + tags = [ + "cuda-only", + "gpu", + ], deps = [ "//xla/stream_executor", "//xla/stream_executor:blas", @@ -418,10 +448,14 @@ cuda_only_cc_library( ], ) -cuda_only_cc_library( +cc_library( name = "cufft_plugin", srcs = ["cuda_fft.cc"], hdrs = ["cuda_fft.h"], + tags = [ + "cuda-only", + "gpu", + ], visibility = ["//visibility:public"], deps = [ ":cuda_helpers", @@ -447,13 +481,17 @@ cuda_only_cc_library( alwayslink = True, ) -gpu_kernel_library( +cuda_library( name = "delay_kernel_cuda", srcs = [ "delay_kernel.h", "delay_kernel_cuda.cu.cc", ], - tags = ["manual"], + # copybara:uncomment compatible_with = ["//buildenv/target:non_prod"], + tags = [ + "cuda-only", + "gpu", + ], visibility = internal_visibility([ "//xla/stream_executor:__subpackages__", ]), @@ -467,11 +505,15 @@ gpu_kernel_library( ], ) -cuda_only_cc_library( +cc_library( name = "cudnn_plugin", srcs = ["cuda_dnn.cc"], hdrs = ["cuda_dnn.h"], copts = tf_additional_cudnn_plugin_copts(), + tags = [ + "cuda-only", + "gpu", + ], visibility = ["//visibility:public"], deps = [ ":cuda_diagnostics", @@ -524,10 +566,14 @@ cuda_only_cc_library( alwayslink = True, ) -cuda_only_cc_library( +cc_library( name = "cuda_kernel", srcs = ["cuda_kernel.cc"], hdrs = ["cuda_kernel.h"], + tags = [ + "cuda-only", + "gpu", + ], deps = [ "//xla/stream_executor", "//xla/stream_executor/gpu:gpu_driver_header", @@ -574,22 +620,28 @@ cuda_library( ], ) -# TODO(leary) we likely need to canonicalize/eliminate this. cc_library( name = "cuda_helpers", - textual_hdrs = if_cuda_is_configured(["cuda_helpers.h"]), - deps = if_cuda_is_configured([ + hdrs = ["cuda_helpers.h"], + tags = [ + "cuda-only", + "gpu", + ], + deps = [ "//xla/stream_executor/gpu:gpu_helpers_header", - "@local_config_cuda//cuda:cuda_headers", - ]) + [ "@com_google_absl//absl/log:check", + "@local_config_cuda//cuda:cuda_headers", ], ) -cuda_only_cc_library( +cc_library( name = "cuda_event", srcs = ["cuda_event.cc"], hdrs = ["cuda_event.h"], + tags = [ + "cuda-only", + "gpu", + ], deps = [ ":cuda_driver", "//xla/stream_executor:event", @@ -825,7 +877,7 @@ xla_cc_test( ], ) -cuda_only_cc_library( +cc_library( name = "cuda_asm_compiler", srcs = ["cuda_asm_compiler.cc"], hdrs = ["cuda_asm_compiler.h"], @@ -844,6 +896,10 @@ cuda_only_cc_library( "@cuda_nvcc//:ptxas", ]), # copybara:comment_end + tags = [ + "cuda-only", + "gpu", + ], visibility = internal_visibility([ "//third_party/py/jax:__subpackages__", "//tensorflow/compiler/mlir/tools/kernel_gen:__subpackages__", @@ -889,7 +945,7 @@ cuda_only_cc_library( ], ) -cuda_only_cc_library( +cc_library( name = "cuda_executor", srcs = [ "cuda_executor.cc", @@ -898,6 +954,10 @@ cuda_only_cc_library( hdrs = [ "cuda_executor.h", ], + tags = [ + "cuda-only", + "gpu", + ], deps = [ ":cuda_collectives", ":cuda_diagnostics", @@ -908,6 +968,7 @@ cuda_only_cc_library( ":cuda_runtime", ":cuda_status", ":cuda_version_parser", + ":delay_kernel_cuda", "//xla/stream_executor", "//xla/stream_executor:blas", "//xla/stream_executor:command_buffer", @@ -954,13 +1015,17 @@ cuda_only_cc_library( "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:platform_port", "@local_tsl//tsl/platform:statusor", - ] + if_cuda_is_configured([":delay_kernel_cuda"]), + ], alwayslink = True, ) cc_library( name = "all_runtime", copts = tsl_copts(), + tags = [ + "cuda-only", + "gpu", + ], visibility = ["//visibility:public"], deps = [ ":cublas_plugin", diff --git a/third_party/xla/xla/stream_executor/gpu/BUILD b/third_party/xla/xla/stream_executor/gpu/BUILD index 7c0bdb7d3c0361..df510e57075b35 100644 --- a/third_party/xla/xla/stream_executor/gpu/BUILD +++ b/third_party/xla/xla/stream_executor/gpu/BUILD @@ -498,7 +498,11 @@ cc_library( "redzone_allocator_kernel.h", "redzone_allocator_kernel_cuda.cc", ], - tags = ["manual"], + tags = [ + "cuda-only", + "gpu", + "manual", + ], deps = [ ":gpu_asm_opts", "//xla/stream_executor", @@ -596,9 +600,12 @@ xla_test( cc_library( name = "gpu_cudamallocasync_allocator", - srcs = if_cuda_is_configured(["gpu_cudamallocasync_allocator.cc"]), - hdrs = if_cuda_is_configured(["gpu_cudamallocasync_allocator.h"]), - tags = ["gpu"], + srcs = ["gpu_cudamallocasync_allocator.cc"], + hdrs = ["gpu_cudamallocasync_allocator.h"], + tags = [ + "cuda-only", + "gpu", + ], deps = [ ":gpu_init_impl", "//xla/stream_executor:stream_executor_h", diff --git a/third_party/xla/xla/tsl/tsl.bzl b/third_party/xla/xla/tsl/tsl.bzl index 0cf769ddf4eadb..cca0a3001a307d 100644 --- a/third_party/xla/xla/tsl/tsl.bzl +++ b/third_party/xla/xla/tsl/tsl.bzl @@ -21,7 +21,6 @@ load( load( "@local_config_rocm//rocm:build_defs.bzl", "if_rocm", - "if_rocm_is_configured", ) load( "@local_tsl//tsl/platform:rules_cc.bzl", @@ -367,7 +366,7 @@ def tsl_gpu_library(deps = None, cuda_deps = None, copts = tsl_copts(), **kwargs cuda_deps = [] kwargs["features"] = kwargs.get("features", []) + ["-use_header_modules"] - deps = deps + if_cuda_or_rocm(cuda_deps) + deps = deps + if_cuda(cuda_deps) if "default_copts" in kwargs: copts = kwargs["default_copts"] + copts kwargs.pop("default_copts", None) @@ -375,7 +374,8 @@ def tsl_gpu_library(deps = None, cuda_deps = None, copts = tsl_copts(), **kwargs deps = deps + if_cuda([ clean_dep("//xla/tsl/cuda:cudart"), "@local_config_cuda//cuda:cuda_headers", - ]) + if_rocm_is_configured([ + ]) + if_rocm([ + "@local_config_rocm//rocm:hip", "@local_config_rocm//rocm:rocm_headers", ]), copts = (copts + if_cuda(["-DGOOGLE_CUDA=1", "-DNV_CUDNN_DISABLE_EXCEPTION"]) + if_rocm(["-DTENSORFLOW_USE_ROCM=1"]) + if_xla_available(["-DTENSORFLOW_USE_XLA=1"]) + if_mkl(["-DINTEL_MKL=1"]) + if_enable_mkl(["-DENABLE_MKL"]) + if_tensorrt(["-DGOOGLE_TENSORRT=1"])), From d68fedba42a997b105c3e4dd11fe21e2b52c09a2 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 30 Sep 2024 02:05:47 -0700 Subject: [PATCH 414/483] Automated Code Change PiperOrigin-RevId: 680477674 --- .../core/kernels/linalg/einsum_op_impl.h | 24 +++++++++---------- .../core/kernels/linalg/linalg_ops_common.h | 6 ++--- tensorflow/core/kernels/linalg/lu_op.cc | 4 ++-- 3 files changed, 17 insertions(+), 17 deletions(-) diff --git a/tensorflow/core/kernels/linalg/einsum_op_impl.h b/tensorflow/core/kernels/linalg/einsum_op_impl.h index 156bb80a02e1a9..3d44394752bc88 100644 --- a/tensorflow/core/kernels/linalg/einsum_op_impl.h +++ b/tensorflow/core/kernels/linalg/einsum_op_impl.h @@ -52,12 +52,12 @@ namespace tensorflow { using CPUDevice = Eigen::ThreadPoolDevice; using GPUDevice = Eigen::GpuDevice; -using ShapeVec = gtl::InlinedVector; -using Labels = gtl::InlinedVector; -using OperandLabels = gtl::InlinedVector; -using LabelCounts = gtl::InlinedVector; -using OperandLabelCounts = gtl::InlinedVector; -using LabelToDimSizes = gtl::InlinedVector; +using ShapeVec = absl::InlinedVector; +using Labels = absl::InlinedVector; +using OperandLabels = absl::InlinedVector; +using LabelCounts = absl::InlinedVector; +using OperandLabelCounts = absl::InlinedVector; +using LabelToDimSizes = absl::InlinedVector; struct EinsumHelper { // Insert new (unnamed) broadcasting labels at the location of ellipsis. @@ -97,7 +97,7 @@ struct EinsumHelper { // counts. static Status ProcessDimensions( const OpInputList& inputs, - const gtl::InlinedVector& input_has_ellipsis, + const absl::InlinedVector& input_has_ellipsis, const bool output_has_ellipsis, OperandLabels* input_labels, Labels* output_labels, std::vector* label_types, OperandLabelCounts* input_label_counts, LabelCounts* output_label_counts, @@ -321,7 +321,7 @@ struct EinsumHelper { const std::vector& label_types) { // Check that ordering is according to dimension type, with the role of // free and contract dimensions swapped. - gtl::InlinedVector remap = {0, 1, 3, 2, 4}; + absl::InlinedVector remap = {0, 1, 3, 2, 4}; for (int i = 0; i + 1 < labels.size(); ++i) { const int dimtype_a = remap[label_types[labels[i]]]; const int dimtype_b = remap[label_types[labels[i + 1]]]; @@ -372,7 +372,7 @@ struct EinsumHelper { // Reshape denotes the rank-5 shape [broadcast, batch, free, contract, // reduce] where we've compacted the dimensions of each EinsumDimensionType. - gtl::InlinedVector reshape(5, 1); + absl::InlinedVector reshape(5, 1); // The output shape is [batch shape] + [free size, contract size] // That is, the batch shape is preserved (for broadcasting while // contracting) while the free dims and contract dims are compressed to one @@ -513,8 +513,8 @@ class EinsumOp : public OpKernel { // dimensions, respectively. const int num_inputs = inputs.size(); OperandLabels free_labels(num_inputs); - gtl::InlinedVector inputs_reduced(num_inputs); - gtl::InlinedVector swap_free_and_contract(num_inputs); + absl::InlinedVector inputs_reduced(num_inputs); + absl::InlinedVector swap_free_and_contract(num_inputs); for (int i = 0; i < num_inputs; ++i) { OP_REQUIRES_OK(ctx, EinsumHelper::ReduceOperand( @@ -627,7 +627,7 @@ class EinsumOp : public OpKernel { std::vector label_types_; OperandLabelCounts input_label_counts_; LabelCounts output_label_counts_; - gtl::InlinedVector input_has_ellipsis_; + absl::InlinedVector input_has_ellipsis_; bool output_has_ellipsis_ = false; }; diff --git a/tensorflow/core/kernels/linalg/linalg_ops_common.h b/tensorflow/core/kernels/linalg/linalg_ops_common.h index d4d66bd4c8f809..d774ad9e6bced5 100644 --- a/tensorflow/core/kernels/linalg/linalg_ops_common.h +++ b/tensorflow/core/kernels/linalg/linalg_ops_common.h @@ -43,7 +43,7 @@ class LinearAlgebraOp : public OpKernel { void Compute(OpKernelContext* context) override; protected: - using TensorShapes = gtl::InlinedVector; + using TensorShapes = absl::InlinedVector; // Returns the number of leading inputs that are to be treated as matrix // inputs. By default this is all the inputs. Derived classes can override // this to tell the base class to ignore one or more trailing inputs. @@ -152,8 +152,8 @@ class LinearAlgebraOp : public OpKernel { OutputMatrixMaps* outputs) = 0; private: - using TensorInputs = gtl::InlinedVector; - using TensorOutputs = gtl::InlinedVector; + using TensorInputs = absl::InlinedVector; + using TensorOutputs = absl::InlinedVector; // This function maps 2-d slices (matrices) of the input and output tensors // using Eigen::Map and calls ComputeMatrix implemented in terms of the // Eigen::MatrixBase API by the derived class. diff --git a/tensorflow/core/kernels/linalg/lu_op.cc b/tensorflow/core/kernels/linalg/lu_op.cc index 770c5d8fe6c67c..e1525bf5937eb6 100644 --- a/tensorflow/core/kernels/linalg/lu_op.cc +++ b/tensorflow/core/kernels/linalg/lu_op.cc @@ -32,8 +32,8 @@ class LuOp : public OpKernel { explicit LuOp(OpKernelConstruction* context) : OpKernel(context) {} protected: - using TensorShapes = gtl::InlinedVector; - using TensorOutputs = gtl::InlinedVector; + using TensorShapes = absl::InlinedVector; + using TensorOutputs = absl::InlinedVector; using Matrix = Eigen::Matrix; From 828eb257a858f5eafe9084ab9c60dad979dbb4b4 Mon Sep 17 00:00:00 2001 From: Bart Chrzaszcz Date: Mon, 30 Sep 2024 02:06:49 -0700 Subject: [PATCH 415/483] #sdy remove debug print. PiperOrigin-RevId: 680477952 --- .../xla/service/spmd/shardy/sdy_round_trip/import_shardings.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/import_shardings.cc b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/import_shardings.cc index 7e4fbe31f7a4ed..056292a06a3c47 100644 --- a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/import_shardings.cc +++ b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/import_shardings.cc @@ -182,7 +182,6 @@ class SdyRoundTripImportShardingsPass // Insert the meshes before any functions. builder.setInsertionPointToStart(moduleOp.getBody()); for (NamedAttribute mesh : sdyMeshes) { - mesh.getValue().dump(); auto meshAttr = mlir::cast(mesh.getValue()); symbolTable.insert(builder.create( moduleOp.getLoc(), mesh.getName(), meshAttr)); From f083ae294eb0733bff4ec31665a30925824e2d5d Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 30 Sep 2024 02:45:13 -0700 Subject: [PATCH 416/483] Automated Code Change PiperOrigin-RevId: 680488776 --- third_party/xla/xla/service/gpu/fusions/BUILD | 5 +++++ third_party/xla/xla/service/gpu/fusions/concatenate_mlir.cc | 2 ++ third_party/xla/xla/service/gpu/fusions/concatenate_mlir.h | 1 + third_party/xla/xla/service/gpu/fusions/cudnn_test.cc | 2 ++ third_party/xla/xla/service/gpu/fusions/custom.cc | 2 ++ third_party/xla/xla/service/gpu/fusions/fusions.cc | 3 --- third_party/xla/xla/service/gpu/fusions/fusions.h | 1 + .../service/gpu/fusions/in_place_dynamic_update_slice_mlir.h | 1 + 8 files changed, 14 insertions(+), 3 deletions(-) diff --git a/third_party/xla/xla/service/gpu/fusions/BUILD b/third_party/xla/xla/service/gpu/fusions/BUILD index d5ef2439b59304..7c6a48a9efb03c 100644 --- a/third_party/xla/xla/service/gpu/fusions/BUILD +++ b/third_party/xla/xla/service/gpu/fusions/BUILD @@ -18,6 +18,7 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/service/gpu:gpu_fusible", "//xla/service/gpu:hlo_fusion_analysis", + "//xla/service/gpu:hlo_traversal", "//xla/service/gpu:ir_emission_utils", "//xla/service/gpu:launch_dimensions", "//xla/service/gpu/fusions/mlir:computation_partitioner", @@ -62,6 +63,7 @@ cc_library( local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), deps = [ ":fusion_emitter", + "//xla:literal", "//xla:shape_util", "//xla:status_macros", "//xla:util", @@ -449,6 +451,7 @@ xla_test( "@com_google_absl//absl/strings", "@com_google_googletest//:gtest_main", "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:path", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", @@ -570,6 +573,7 @@ cc_library( srcs = ["concatenate_mlir.cc"], hdrs = ["concatenate_mlir.h"], deps = [ + "//xla:shape_util", "//xla/hlo/ir:hlo", "//xla/service/gpu:gpu_fusible", "//xla/service/gpu:hlo_fusion_analysis", @@ -578,6 +582,7 @@ cc_library( "//xla/service/gpu/fusions/mlir:elemental_hlo_to_mlir", "//xla/service/gpu/fusions/mlir:mlir_fusion_emitter", "//xla/service/gpu/model:indexing_analysis", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", diff --git a/third_party/xla/xla/service/gpu/fusions/concatenate_mlir.cc b/third_party/xla/xla/service/gpu/fusions/concatenate_mlir.cc index d4bbe647f152d2..e977408fde2a8f 100644 --- a/third_party/xla/xla/service/gpu/fusions/concatenate_mlir.cc +++ b/third_party/xla/xla/service/gpu/fusions/concatenate_mlir.cc @@ -21,6 +21,7 @@ limitations under the License. #include #include +#include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" #include "absl/log/check.h" #include "absl/log/log.h" @@ -42,6 +43,7 @@ limitations under the License. #include "xla/service/gpu/launch_dimensions.h" #include "xla/service/gpu/model/indexing_analysis.h" #include "xla/service/gpu/model/indexing_map.h" +#include "xla/shape.h" namespace xla { namespace gpu { diff --git a/third_party/xla/xla/service/gpu/fusions/concatenate_mlir.h b/third_party/xla/xla/service/gpu/fusions/concatenate_mlir.h index b98db45690389c..c6223d8fc38a4f 100644 --- a/third_party/xla/xla/service/gpu/fusions/concatenate_mlir.h +++ b/third_party/xla/xla/service/gpu/fusions/concatenate_mlir.h @@ -31,6 +31,7 @@ limitations under the License. #include "xla/service/gpu/hlo_fusion_analysis.h" #include "xla/service/gpu/launch_dimensions.h" #include "xla/service/gpu/model/indexing_map.h" +#include "xla/shape.h" namespace xla { namespace gpu { diff --git a/third_party/xla/xla/service/gpu/fusions/cudnn_test.cc b/third_party/xla/xla/service/gpu/fusions/cudnn_test.cc index 9b513ba16871fa..3df1e18b6652a8 100644 --- a/third_party/xla/xla/service/gpu/fusions/cudnn_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/cudnn_test.cc @@ -23,6 +23,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_replace.h" +#include "absl/strings/string_view.h" #include "absl/strings/substitute.h" #include "xla/comparison_util.h" #include "xla/debug_options_flags.h" @@ -46,6 +47,7 @@ limitations under the License. #include "xla/xla.pb.h" #include "xla/xla_data.pb.h" #include "tsl/platform/env.h" +#include "tsl/platform/errors.h" #include "tsl/platform/path.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" diff --git a/third_party/xla/xla/service/gpu/fusions/custom.cc b/third_party/xla/xla/service/gpu/fusions/custom.cc index c4270462ccdc16..dd6126dc5bc64f 100644 --- a/third_party/xla/xla/service/gpu/fusions/custom.cc +++ b/third_party/xla/xla/service/gpu/fusions/custom.cc @@ -40,6 +40,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/literal.h" #include "xla/service/buffer_assignment.h" #include "xla/service/custom_call_status.h" #include "xla/service/custom_call_target_registry.h" @@ -61,6 +62,7 @@ limitations under the License. #include "xla/service/gpu/runtime/kernel_thunk.h" #include "xla/service/gpu/runtime/nccl_all_reduce_thunk.h" #include "xla/service/gpu/runtime/nccl_api.h" +#include "xla/service/gpu/runtime/nccl_clique_key.h" #include "xla/service/gpu/runtime/nccl_collective_thunk.h" #include "xla/service/gpu/runtime/thunk.h" #include "xla/service/gpu/stream_executor_util.h" diff --git a/third_party/xla/xla/service/gpu/fusions/fusions.cc b/third_party/xla/xla/service/gpu/fusions/fusions.cc index 200f06f8461db5..ff15b29724c79c 100644 --- a/third_party/xla/xla/service/gpu/fusions/fusions.cc +++ b/third_party/xla/xla/service/gpu/fusions/fusions.cc @@ -20,8 +20,6 @@ limitations under the License. #include #include "absl/algorithm/container.h" -#include "absl/log/check.h" -#include "absl/log/log.h" #include "absl/strings/match.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" @@ -42,7 +40,6 @@ limitations under the License. #include "xla/service/gpu/fusions/legacy/scatter.h" #include "xla/service/gpu/fusions/legacy/transpose.h" #include "xla/service/gpu/fusions/loop_mlir.h" -#include "xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h" #include "xla/service/gpu/fusions/reduction_mlir.h" #include "xla/service/gpu/fusions/scatter_mlir.h" #include "xla/service/gpu/fusions/transpose_mlir.h" diff --git a/third_party/xla/xla/service/gpu/fusions/fusions.h b/third_party/xla/xla/service/gpu/fusions/fusions.h index f7406b463b9117..6e9f16aa1deca8 100644 --- a/third_party/xla/xla/service/gpu/fusions/fusions.h +++ b/third_party/xla/xla/service/gpu/fusions/fusions.h @@ -23,6 +23,7 @@ limitations under the License. #include "xla/service/buffer_assignment.h" #include "xla/service/gpu/fusions/fusion_emitter.h" #include "xla/service/gpu/hlo_fusion_analysis.h" +#include "xla/service/gpu/ir_emission_utils.h" namespace xla { namespace gpu { diff --git a/third_party/xla/xla/service/gpu/fusions/in_place_dynamic_update_slice_mlir.h b/third_party/xla/xla/service/gpu/fusions/in_place_dynamic_update_slice_mlir.h index 7d94f74ced0e46..ab1a71f1c6c8a4 100644 --- a/third_party/xla/xla/service/gpu/fusions/in_place_dynamic_update_slice_mlir.h +++ b/third_party/xla/xla/service/gpu/fusions/in_place_dynamic_update_slice_mlir.h @@ -28,6 +28,7 @@ limitations under the License. #include "xla/service/gpu/fusions/mlir/mlir_fusion_emitter.h" #include "xla/service/gpu/gpu_fusible.h" #include "xla/service/gpu/hlo_fusion_analysis.h" +#include "xla/service/gpu/hlo_traversal.h" #include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/gpu/launch_dimensions.h" #include "xla/service/gpu/model/indexing_map.h" From 9036e73624748589d6fb7f336d1cc1244b7d5490 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 30 Sep 2024 03:13:26 -0700 Subject: [PATCH 417/483] Update GraphDef version to 2001. PiperOrigin-RevId: 680496247 --- tensorflow/core/public/version.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index cea3f2c3853654..641712a9d171ce 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -108,7 +108,7 @@ limitations under the License. #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0 -#define TF_GRAPH_DEF_VERSION 2000 // Updated: 2024/9/29 +#define TF_GRAPH_DEF_VERSION 2001 // Updated: 2024/9/30 // Checkpoint compatibility versions (the versions field in SavedSliceMeta). // From 5bb2d8348aa0d5e655a5086b669bdd038454920c Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 30 Sep 2024 03:13:56 -0700 Subject: [PATCH 418/483] compat: Update forward compatibility horizon to 2024-09-30 PiperOrigin-RevId: 680496364 --- tensorflow/python/compat/compat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index 80b1f00456289f..cd736eb9b0e74b 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -29,7 +29,7 @@ # This value changes every day with an automatic CL. It can be modified in code # via `forward_compatibility_horizon()` or with the environment variable # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date. -_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2024, 9, 29) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2024, 9, 30) _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS" _FORWARD_COMPATIBILITY_DATE_NUMBER = None From a108b9f605354eb1de505e36cdc37f7783d5b652 Mon Sep 17 00:00:00 2001 From: Fergus Henderson Date: Mon, 30 Sep 2024 03:58:20 -0700 Subject: [PATCH 419/483] Include-what-you-use fixes. Also, remove superfluous `const` on return type. PiperOrigin-RevId: 680507645 --- tensorflow/lite/core/tools/verifier.cc | 5 ++++- tensorflow/lite/core/tools/verifier_internal.cc | 3 +++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/tensorflow/lite/core/tools/verifier.cc b/tensorflow/lite/core/tools/verifier.cc index d2e110a1358f81..6dc5647ada468c 100644 --- a/tensorflow/lite/core/tools/verifier.cc +++ b/tensorflow/lite/core/tools/verifier.cc @@ -15,11 +15,14 @@ limitations under the License. #include "tensorflow/lite/core/tools/verifier.h" +#include + #include #include #include #include #include +#include #include "absl/container/flat_hash_set.h" #include "absl/types/optional.h" @@ -57,7 +60,7 @@ void ReportError(ErrorReporter* error_reporter, const char* format, ...) { } // Returns the int32_t value pointed by ptr. -const uint32_t GetIntPtr(const char* ptr) { +uint32_t GetIntPtr(const char* ptr) { #if defined(__BYTE_ORDER__) && defined(__ORDER_BIG_ENDIAN__) && \ __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ return flatbuffers::EndianScalar(*reinterpret_cast(ptr)); diff --git a/tensorflow/lite/core/tools/verifier_internal.cc b/tensorflow/lite/core/tools/verifier_internal.cc index 706d534d6320bf..1f0b537acc5001 100644 --- a/tensorflow/lite/core/tools/verifier_internal.cc +++ b/tensorflow/lite/core/tools/verifier_internal.cc @@ -15,6 +15,9 @@ limitations under the License. #include "tensorflow/lite/core/tools/verifier_internal.h" +#include +#include + #include "flatbuffers/verifier.h" // from @flatbuffers #include "tensorflow/lite/schema/schema_generated.h" From b906740591d383dc8d2205db3d39a15359428166 Mon Sep 17 00:00:00 2001 From: Henning Becker Date: Mon, 30 Sep 2024 05:29:37 -0700 Subject: [PATCH 420/483] Move CreateDeviceDescription to CudaExecutor and RocmExecutor and add tests The function is only being called from backend-specific code (from `{cuda|rocm}_executor.cc` and `{cuda|rocm}_platform.cc`, so there is no need for it to be present in `GpuExecutor`. I also had to complete the comparison operators for `CudaComputeCapability` and added tests for that. PiperOrigin-RevId: 680533311 --- .../xla/xla/stream_executor/cuda/BUILD | 17 ++++++ .../xla/stream_executor/cuda/cuda_executor.cc | 2 +- .../xla/stream_executor/cuda/cuda_executor.h | 5 +- .../cuda/cuda_executor_test.cc | 58 +++++++++++++++++++ .../xla/stream_executor/cuda/cuda_platform.cc | 4 +- .../xla/stream_executor/cuda/cuda_platform.h | 1 - .../xla/stream_executor/device_description.h | 12 ++++ .../device_description_test.cc | 30 ++++++++++ .../xla/stream_executor/gpu/gpu_executor.h | 2 - .../xla/xla/stream_executor/rocm/BUILD | 8 +++ .../xla/stream_executor/rocm/rocm_executor.cc | 2 +- .../xla/stream_executor/rocm/rocm_executor.h | 5 +- .../rocm/rocm_executor_test.cc | 54 +++++++++++++++++ .../xla/stream_executor/rocm/rocm_platform.cc | 4 +- .../xla/stream_executor/rocm/rocm_platform.h | 1 - 15 files changed, 191 insertions(+), 14 deletions(-) create mode 100644 third_party/xla/xla/stream_executor/cuda/cuda_executor_test.cc create mode 100644 third_party/xla/xla/stream_executor/rocm/rocm_executor_test.cc diff --git a/third_party/xla/xla/stream_executor/cuda/BUILD b/third_party/xla/xla/stream_executor/cuda/BUILD index 26c394a144300c..4255d6566b0a66 100644 --- a/third_party/xla/xla/stream_executor/cuda/BUILD +++ b/third_party/xla/xla/stream_executor/cuda/BUILD @@ -1019,6 +1019,23 @@ cc_library( alwayslink = True, ) +xla_test( + name = "cuda_executor_test", + srcs = ["cuda_executor_test.cc"], + backends = ["gpu"], + tags = ["cuda-only"], + deps = [ + ":cuda_executor", + "//xla/stream_executor:device_description", + "//xla/stream_executor:semantic_version", + "//xla/stream_executor/gpu:gpu_driver_header", + "//xla/tsl/lib/core:status_test_util", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test", + ], +) + cc_library( name = "all_runtime", copts = tsl_copts(), diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc b/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc index 5dc8838aa02a44..7161102179adfa 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc @@ -678,7 +678,7 @@ absl::Status CudaExecutor::TrimGraphMemory() { } absl::StatusOr> -GpuExecutor::CreateDeviceDescription(int device_ordinal) { +CudaExecutor::CreateDeviceDescription(int device_ordinal) { GpuDeviceHandle device; TF_RETURN_IF_ERROR(GpuDriver::GetDevice(device_ordinal, &device)); diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_executor.h b/third_party/xla/xla/stream_executor/cuda/cuda_executor.h index 18d3fff5c5d976..97276c650abe61 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_executor.h +++ b/third_party/xla/xla/stream_executor/cuda/cuda_executor.h @@ -114,7 +114,7 @@ class CudaExecutor : public GpuExecutor { absl::StatusOr> CreateDeviceDescription() const override { - return GpuExecutor::CreateDeviceDescription(device_ordinal()); + return CudaExecutor::CreateDeviceDescription(device_ordinal()); } void* UnifiedMemoryAllocate(uint64_t size) override { return GpuDriver::UnifiedMemoryAllocate(gpu_context(), size); @@ -154,6 +154,9 @@ class CudaExecutor : public GpuExecutor { return it->second; } + static absl::StatusOr> + CreateDeviceDescription(int device_ordinal); + private: // Collects metadata for the specified kernel. absl::Status GetKernelMetadata(GpuKernel* cuda_kernel, diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_executor_test.cc b/third_party/xla/xla/stream_executor/cuda/cuda_executor_test.cc new file mode 100644 index 00000000000000..05f1000818cce9 --- /dev/null +++ b/third_party/xla/xla/stream_executor/cuda/cuda_executor_test.cc @@ -0,0 +1,58 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/stream_executor/cuda/cuda_executor.h" + +#include + +#include +#include +#include "xla/stream_executor/device_description.h" +#include "xla/stream_executor/gpu/gpu_driver.h" +#include "xla/stream_executor/semantic_version.h" +#include "xla/tsl/lib/core/status_test_util.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/test.h" + +namespace stream_executor::gpu { +namespace { +using testing::Ge; +using testing::IsEmpty; +using testing::Not; +using testing::VariantWith; + +TEST(CudaExecutorTest, CreateDeviceDescription) { + TF_ASSERT_OK(GpuDriver::Init()); + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, + CudaExecutor::CreateDeviceDescription(0)); + + constexpr SemanticVersion kNullVersion{0, 0, 0}; + EXPECT_NE(result->runtime_version(), kNullVersion); + EXPECT_NE(result->driver_version(), kNullVersion); + EXPECT_NE(result->compile_time_toolkit_version(), kNullVersion); + + EXPECT_THAT(result->platform_version(), Not(IsEmpty())); + EXPECT_THAT(result->name(), Not(IsEmpty())); + EXPECT_THAT(result->model_str(), Not(IsEmpty())); + EXPECT_THAT(result->device_vendor(), "NVIDIA Corporation"); + + EXPECT_THAT( + result->gpu_compute_capability(), + VariantWith(Ge(CudaComputeCapability{1, 0}))); +} + +} // namespace +} // namespace stream_executor::gpu diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_platform.cc b/third_party/xla/xla/stream_executor/cuda/cuda_platform.cc index 42c2808dc1f23e..d0de944b070301 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_platform.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_platform.cc @@ -38,8 +38,6 @@ namespace gpu { CudaPlatform::CudaPlatform() : name_("CUDA") {} -CudaPlatform::~CudaPlatform() {} - Platform::Id CudaPlatform::id() const { return cuda::kCudaPlatformId; } int CudaPlatform::VisibleDeviceCount() const { @@ -55,7 +53,7 @@ const std::string& CudaPlatform::Name() const { return name_; } absl::StatusOr> CudaPlatform::DescriptionForDevice(int ordinal) const { - return GpuExecutor::CreateDeviceDescription(ordinal); + return CudaExecutor::CreateDeviceDescription(ordinal); } absl::StatusOr CudaPlatform::ExecutorForDevice(int ordinal) { diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_platform.h b/third_party/xla/xla/stream_executor/cuda/cuda_platform.h index e4ba806343f091..b6aa9e3a4448f4 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_platform.h +++ b/third_party/xla/xla/stream_executor/cuda/cuda_platform.h @@ -38,7 +38,6 @@ namespace gpu { class CudaPlatform : public Platform { public: CudaPlatform(); - ~CudaPlatform() override; // Platform interface implementation: // Returns the same value as kCudaPlatform above. diff --git a/third_party/xla/xla/stream_executor/device_description.h b/third_party/xla/xla/stream_executor/device_description.h index 99d7f1ce5d83c8..195dd058aa64da 100644 --- a/third_party/xla/xla/stream_executor/device_description.h +++ b/third_party/xla/xla/stream_executor/device_description.h @@ -123,6 +123,18 @@ struct CudaComputeCapability { return !(*this == other); } + bool operator>(const CudaComputeCapability &other) const { + return ToPair() > other.ToPair(); + } + + bool operator>=(const CudaComputeCapability &other) const { + return ToPair() >= other.ToPair(); + } + + bool operator<=(const CudaComputeCapability &other) const { + return ToPair() <= other.ToPair(); + } + std::string ToString() const { return absl::StrCat(major, ".", minor); } std::pair ToPair() const { return std::make_pair(major, minor); } diff --git a/third_party/xla/xla/stream_executor/device_description_test.cc b/third_party/xla/xla/stream_executor/device_description_test.cc index 4600a7a04e97d5..ba65c78f8460c9 100644 --- a/third_party/xla/xla/stream_executor/device_description_test.cc +++ b/third_party/xla/xla/stream_executor/device_description_test.cc @@ -47,5 +47,35 @@ TEST(CudaComputeCapability, GenerationLiteralTest) { EXPECT_TRUE(CudaComputeCapability::Blackwell().IsAtLeast(10)); } +TEST(CudaComputeCapability, ComparisonTest) { + CudaComputeCapability lower{1, 0}; + CudaComputeCapability slightly_higher{1, 1}; + CudaComputeCapability higher{2, 0}; + + EXPECT_TRUE(lower == lower); + EXPECT_FALSE(lower == slightly_higher); + EXPECT_FALSE(lower == higher); + + EXPECT_TRUE(lower <= lower); + EXPECT_TRUE(lower < slightly_higher); + EXPECT_TRUE(lower <= slightly_higher); + + EXPECT_FALSE(lower < lower); + EXPECT_FALSE(slightly_higher <= lower); + EXPECT_FALSE(slightly_higher < lower); + + EXPECT_TRUE(slightly_higher >= slightly_higher); + EXPECT_TRUE(slightly_higher > lower); + EXPECT_TRUE(slightly_higher >= lower); + + EXPECT_FALSE(slightly_higher > slightly_higher); + EXPECT_FALSE(lower > slightly_higher); + EXPECT_FALSE(lower >= slightly_higher); + + EXPECT_TRUE(higher > slightly_higher); + EXPECT_TRUE(higher >= slightly_higher); + EXPECT_TRUE(higher >= higher); +} + } // namespace } // namespace stream_executor diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_executor.h b/third_party/xla/xla/stream_executor/gpu/gpu_executor.h index f149e242e45fc0..0fdde6096ef5ad 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_executor.h +++ b/third_party/xla/xla/stream_executor/gpu/gpu_executor.h @@ -58,8 +58,6 @@ class GpuExecutor : public StreamExecutorCommon { // Creates an EventBasedTimer for the given stream. virtual absl::StatusOr> CreateEventBasedTimer(GpuStream* stream, bool use_delay_kernel) = 0; - static absl::StatusOr> - CreateDeviceDescription(int device_ordinal); // Frees unused memory cached on the device for use with graphs back to the // OS. diff --git a/third_party/xla/xla/stream_executor/rocm/BUILD b/third_party/xla/xla/stream_executor/rocm/BUILD index 8d6c8b565f4af7..467ff4fa6e73cc 100644 --- a/third_party/xla/xla/stream_executor/rocm/BUILD +++ b/third_party/xla/xla/stream_executor/rocm/BUILD @@ -217,6 +217,14 @@ cc_library( alwayslink = True, ) +xla_test( + name = "rocm_executor_test", + srcs = ["rocm_executor_test.cc"], + backends = ["gpu_amd_any"], + tags = ["rocm-only"], + deps = [":rocm_executor"], +) + cc_library( name = "rocm_kernel", srcs = ["rocm_kernel.cc"], diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_executor.cc b/third_party/xla/xla/stream_executor/rocm/rocm_executor.cc index ac8f253a955ef4..df7e55ab984b34 100644 --- a/third_party/xla/xla/stream_executor/rocm/rocm_executor.cc +++ b/third_party/xla/xla/stream_executor/rocm/rocm_executor.cc @@ -602,7 +602,7 @@ absl::Status RocmExecutor::TrimGraphMemory() { } absl::StatusOr> -GpuExecutor::CreateDeviceDescription(int device_ordinal) { +RocmExecutor::CreateDeviceDescription(int device_ordinal) { GpuDeviceHandle device; auto status = GpuDriver::GetDevice(device_ordinal, &device); if (!status.ok()) { diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_executor.h b/third_party/xla/xla/stream_executor/rocm/rocm_executor.h index 840ce7b88a36bf..1ad30af30db4b8 100644 --- a/third_party/xla/xla/stream_executor/rocm/rocm_executor.h +++ b/third_party/xla/xla/stream_executor/rocm/rocm_executor.h @@ -104,7 +104,7 @@ class RocmExecutor : public GpuExecutor { absl::StatusOr> CreateDeviceDescription() const override { - return GpuExecutor::CreateDeviceDescription(device_ordinal()); + return RocmExecutor::CreateDeviceDescription(device_ordinal()); } void* UnifiedMemoryAllocate(uint64_t size) override { return GpuDriver::UnifiedMemoryAllocate(gpu_context(), size); @@ -141,6 +141,9 @@ class RocmExecutor : public GpuExecutor { return it->second; } + static absl::StatusOr> + CreateDeviceDescription(int device_ordinal); + private: // Collects metadata for the specified kernel. absl::Status GetKernelMetadata(GpuKernel* rocm_kernel, diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_executor_test.cc b/third_party/xla/xla/stream_executor/rocm/rocm_executor_test.cc new file mode 100644 index 00000000000000..0716b5c3d0ee17 --- /dev/null +++ b/third_party/xla/xla/stream_executor/rocm/rocm_executor_test.cc @@ -0,0 +1,54 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/stream_executor/rocm/rocm_executor.h" + +#include +#include +#include "tsl/platform/status_matchers.h" +#include "tsl/platform/test.h" + +namespace stream_executor::gpu { +namespace { +using testing::Field; +using testing::Ge; +using testing::IsEmpty; +using testing::Not; +using testing::VariantWith; + +TEST(RocmExecutorTest, CreateDeviceDescription) { + TF_ASSERT_OK(GpuDriver::Init()); + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, + CudaExecutor::CreateDeviceDescription(0)); + + constexpr SemanticVersion kNullVersion{0, 0, 0}; + EXPECT_NE(result->runtime_version(), kNullVersion); + EXPECT_NE(result->driver_version(), kNullVersion); + EXPECT_NE(result->compile_time_toolkit_version(), kNullVersion); + + EXPECT_THAT(result->platform_version(), Not(IsEmpty())); + EXPECT_THAT(result->name(), Not(IsEmpty())); + EXPECT_THAT(result->model_str(), Not(IsEmpty())); + EXPECT_THAT(result->device_vendor(), "Advanced Micro Devices, Inc"); + + EXPECT_THAT(result->gpu_compute_capability(), + VariantWith( + Field("gcn_arch_name", &RocmComputeCapability::gcn_arch_name, + Not(IsEmpty())))); +} + +} // namespace +} // namespace stream_executor::gpu diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_platform.cc b/third_party/xla/xla/stream_executor/rocm/rocm_platform.cc index a8414a142115cb..0284430523386f 100644 --- a/third_party/xla/xla/stream_executor/rocm/rocm_platform.cc +++ b/third_party/xla/xla/stream_executor/rocm/rocm_platform.cc @@ -30,8 +30,6 @@ namespace gpu { ROCmPlatform::ROCmPlatform() : name_("ROCM") {} -ROCmPlatform::~ROCmPlatform() {} - Platform::Id ROCmPlatform::id() const { return rocm::kROCmPlatformId; } int ROCmPlatform::VisibleDeviceCount() const { @@ -49,7 +47,7 @@ const std::string& ROCmPlatform::Name() const { return name_; } absl::StatusOr> ROCmPlatform::DescriptionForDevice(int ordinal) const { - return GpuExecutor::CreateDeviceDescription(ordinal); + return RocmExecutor::CreateDeviceDescription(ordinal); } absl::StatusOr ROCmPlatform::ExecutorForDevice(int ordinal) { diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_platform.h b/third_party/xla/xla/stream_executor/rocm/rocm_platform.h index 6888b64532c0dd..408e7fa1a78e9c 100644 --- a/third_party/xla/xla/stream_executor/rocm/rocm_platform.h +++ b/third_party/xla/xla/stream_executor/rocm/rocm_platform.h @@ -39,7 +39,6 @@ extern const Platform::Id kROCmPlatformId; class ROCmPlatform : public Platform { public: ROCmPlatform(); - ~ROCmPlatform() override; // Platform interface implementation: // Returns the same value as kROCmPlatform above. From 60ddb99198949b561a48aa2cae2f82592bda88ee Mon Sep 17 00:00:00 2001 From: Dirk Hornung Date: Mon, 30 Sep 2024 05:55:23 -0700 Subject: [PATCH 421/483] Refactor gemm_fusion_autotuner fusion rewriter nested if-else to use early return pattern. PiperOrigin-RevId: 680540454 --- .../gpu/autotuning/gemm_fusion_autotuner.cc | 115 +++++++++++------- 1 file changed, 70 insertions(+), 45 deletions(-) diff --git a/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc index f18850be8ab3ee..51f68200fb2da9 100644 --- a/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc +++ b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc @@ -583,63 +583,88 @@ absl::Status RewriteGemmFusionToCustomKernelFusion( return pipeline.Run(hlo_module).status(); } +absl::Status HandleTritonGemm(HloInstruction* fusion_instr, + FusionBackendConfig& fusion_backend_config) { + TF_ASSIGN_OR_RETURN( + const TritonGemmConfig config, + TritonGemmConfig::FromProto(fusion_backend_config.triton_gemm_config())); + if (config.split_k > 1) { + TF_RETURN_IF_ERROR(MakeDotSplitKBatch(fusion_instr, config)); + } + return absl::OkStatus(); +} + absl::Status GemmFusionAutotunerRewriterVisitor::HandleFusion( HloInstruction* fusion_instr) { TF_ASSIGN_OR_RETURN(auto gpu_config, fusion_instr->backend_config()); - FusionBackendConfig& backend_config = + FusionBackendConfig& fusion_backend_config = *gpu_config.mutable_fusion_backend_config(); - if (backend_config.kind() != kTritonGemmFusionKind && - backend_config.kind() != kCuDnnFusionKind && - backend_config.kind() != kCustomFusionKind) { + + // Only autotune Triton, cuDNN, and custom kernel fusions. + if (fusion_backend_config.kind() != kTritonGemmFusionKind && + fusion_backend_config.kind() != kCuDnnFusionKind && + fusion_backend_config.kind() != kCustomFusionKind) { return absl::OkStatus(); } - VLOG(4) << "Processing " << fusion_instr->ToString(); - if (!backend_config.has_triton_gemm_config() && - !backend_config.has_cudnn_fusion_config() && - !backend_config.has_custom_fusion_config()) { - TF_ASSIGN_OR_RETURN( - AutotuneResult autotune_result, - AutotunerUtil::Autotune( - fusion_instr, config_, [&]() -> absl::StatusOr { - if (config_.IsDeviceless()) { - return absl::InternalError(absl::StrCat( - "Expect autotune result cache hit for deviceless " - "compilation (HLO: ", - fusion_instr->ToString(), ")")); - } - return absl::InternalError("Expect autotune result cache hit."); - })); - VLOG(4) << "Result: " << autotune_result.ShortDebugString(); - - if (autotune_result.has_triton()) { - *backend_config.mutable_triton_gemm_config() = autotune_result.triton(); - TF_RETURN_IF_ERROR(fusion_instr->set_backend_config(gpu_config)); - } else if (autotune_result.has_gemm()) { - TF_RETURN_IF_ERROR(RewriteGemmFusionToCall(fusion_instr)); - } else if (autotune_result.has_custom_kernel_fusion()) { - TF_RETURN_IF_ERROR(RewriteGemmFusionToCustomKernelFusion( - fusion_instr, config_.GetExecutor()->GetDeviceDescription(), - autotune_result.custom_kernel_fusion().kernel_index())); - } else { - CHECK(autotune_result.has_algorithm()); - backend_config.set_kind(std::string(kCuDnnFusionKind)); - backend_config.mutable_cudnn_fusion_config()->set_plan_id( - autotune_result.algorithm().algo_id()); - TF_RETURN_IF_ERROR(fusion_instr->set_backend_config(gpu_config)); - } + // Do not autotune if the backend config has already assigned tiling config. + if (fusion_backend_config.has_triton_gemm_config()) { + TF_RETURN_IF_ERROR(HandleTritonGemm(fusion_instr, fusion_backend_config)); + MarkAsChanged(); + return absl::OkStatus(); } - if (backend_config.has_triton_gemm_config()) { - TF_ASSIGN_OR_RETURN( - const TritonGemmConfig config, - TritonGemmConfig::FromProto(backend_config.triton_gemm_config())); - if (config.split_k > 1) { - TF_RETURN_IF_ERROR(MakeDotSplitKBatch(fusion_instr, config)); - } + // Do not autotune if the backend config has valid config. + if (fusion_backend_config.has_cudnn_fusion_config() || + fusion_backend_config.has_custom_fusion_config()) { + return absl::OkStatus(); + } + + VLOG(4) << "Autotuning fusion instruction: " << fusion_instr->ToString(); + TF_ASSIGN_OR_RETURN( + AutotuneResult autotune_result, + AutotunerUtil::Autotune( + fusion_instr, config_, [&]() -> absl::StatusOr { + if (config_.IsDeviceless()) { + return absl::InternalError(absl::StrCat( + "Expect autotune result cache hit for deviceless " + "compilation (HLO: ", + fusion_instr->ToString(), ")")); + } + return absl::InternalError("Expect autotune result cache hit."); + })); + VLOG(4) << "Autotuning result: " << autotune_result.ShortDebugString(); + + if (autotune_result.has_triton()) { + *fusion_backend_config.mutable_triton_gemm_config() = + autotune_result.triton(); + TF_RETURN_IF_ERROR(fusion_instr->set_backend_config(gpu_config)); + TF_RETURN_IF_ERROR(HandleTritonGemm(fusion_instr, fusion_backend_config)); + MarkAsChanged(); + return absl::OkStatus(); + } + + if (autotune_result.has_gemm()) { + TF_RETURN_IF_ERROR(RewriteGemmFusionToCall(fusion_instr)); + MarkAsChanged(); + return absl::OkStatus(); + } + + if (autotune_result.has_custom_kernel_fusion()) { + TF_RETURN_IF_ERROR(RewriteGemmFusionToCustomKernelFusion( + fusion_instr, config_.GetExecutor()->GetDeviceDescription(), + autotune_result.custom_kernel_fusion().kernel_index())); + MarkAsChanged(); + return absl::OkStatus(); } + // Autotune result has a cuDNN fusion. + CHECK(autotune_result.has_algorithm()); + fusion_backend_config.set_kind(std::string(kCuDnnFusionKind)); + fusion_backend_config.mutable_cudnn_fusion_config()->set_plan_id( + autotune_result.algorithm().algo_id()); + TF_RETURN_IF_ERROR(fusion_instr->set_backend_config(gpu_config)); MarkAsChanged(); return absl::OkStatus(); } From 304e5e4c4ee893dc466bbe7e430a7fadc4304b13 Mon Sep 17 00:00:00 2001 From: Alan Kelly Date: Mon, 30 Sep 2024 05:58:11 -0700 Subject: [PATCH 422/483] Dynamic update slice optimizations: If the update tensor is the entirety of the output, then simply copy it and return. Replace the per element index calculation with a recursive function which memcpys the contiguous dimension. PiperOrigin-RevId: 680541144 --- .../lite/kernels/dynamic_update_slice.cc | 50 ++++++++++++++----- 1 file changed, 38 insertions(+), 12 deletions(-) diff --git a/tensorflow/lite/kernels/dynamic_update_slice.cc b/tensorflow/lite/kernels/dynamic_update_slice.cc index 61a3f3d680df24..1bbf84e7804a81 100644 --- a/tensorflow/lite/kernels/dynamic_update_slice.cc +++ b/tensorflow/lite/kernels/dynamic_update_slice.cc @@ -105,6 +105,25 @@ std::vector ClampStartIndices(int input_dims, const int64_t* indices_data, return clamped_start_indices; } +template +void update_slice(int current_dim, int max_dim, const int32_t* output_stride, + const int32_t* update_stride, const int32_t* update_shape, + const T* update, const int32_t* indices_data, T* output) { + if (current_dim == max_dim) return; + if (current_dim == max_dim - 1) { + output += indices_data[current_dim] * output_stride[current_dim]; + memcpy(output, update, update_shape[max_dim - 1] * sizeof(T)); + } else { + output += indices_data[current_dim] * output_stride[current_dim]; + for (int i = 0; i < update_shape[current_dim]; ++i) { + update_slice(current_dim + 1, max_dim, output_stride, update_stride, + update_shape, update, indices_data, output); + output += output_stride[current_dim]; + update += update_stride[current_dim]; + } + } +} + template void DynamicUpdateSlice(const TfLiteTensor* input, const TfLiteTensor* update, const int64_t* indices_data, TfLiteTensor* output) { @@ -114,6 +133,12 @@ void DynamicUpdateSlice(const TfLiteTensor* input, const TfLiteTensor* update, T* output_data = GetTensorData(output); const int input_dims = input_shape.DimensionsCount(); + // If the update is the entirety of the output, then simply copy it and + // return. + if (input_shape.FlatSize() == update_shape.FlatSize()) { + memcpy(output_data, update_data, input_shape.FlatSize() * sizeof(T)); + return; + } // Computes the effective slice indices. // The clamped indices are gauranteed to >= 0 since update is less than or // equal to the operand size for each dimension. @@ -130,18 +155,19 @@ void DynamicUpdateSlice(const TfLiteTensor* input, const TfLiteTensor* update, return; } - std::vector current_dim(input_dims, 0); - // Overwrites update to output. - do { - int flat_update_index = - TensorIndexToFlat(current_dim.data(), input_dims, update_shape); - int flat_input_index = - TensorIndexToFlat(current_dim.data(), input_dims, input_shape, - clamped_start_indices.data()); - output_data[flat_input_index] = update_data[flat_update_index]; - } while (NextIndex(input_dims, - reinterpret_cast(update_shape.DimsData()), - current_dim.data())); + std::vector output_stride(input_dims); + std::vector update_stride(input_dims); + output_stride[input_dims - 1] = 1; + update_stride[input_dims - 1] = 1; + const int32_t* input_shape_data = input_shape.DimsData(); + const int32_t* update_shape_data = update_shape.DimsData(); + for (int i = input_dims - 2; i >= 0; --i) { + output_stride[i] = output_stride[i + 1] * input_shape_data[i + 1]; + update_stride[i] = update_stride[i + 1] * update_shape_data[i + 1]; + } + update_slice(0, input_dims, output_stride.data(), update_stride.data(), + update_shape.DimsData(), update_data, + clamped_start_indices.data(), output_data); } TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { From 4aff90f25d8627088350af2012bd7f9f3f6bdc5d Mon Sep 17 00:00:00 2001 From: Ilya Tikhonovskiy Date: Mon, 30 Sep 2024 07:00:04 -0700 Subject: [PATCH 423/483] [XLA:GPU] Add support for TF32_TF32_F32_X3 algorithm Triton supports this algorithm directly. cuBLAS does not support it. The Triton version is slower than the cublas f32, but we don't use cublas because we need to keep the precision guarantees. PiperOrigin-RevId: 680558821 --- third_party/xla/xla/service/algorithm_util.cc | 2 + .../gpu/autotuning/gemm_fusion_autotuner.cc | 4 +- .../autotuning/gemm_fusion_autotuner_test.cc | 115 +++++++++++++----- .../fusions/triton/kernel_name_tracer_cuda.cc | 2 +- .../fusions/triton/triton_fusion_emitter.cc | 39 +++--- ...riton_fusion_emitter_device_legacy_test.cc | 105 +++++++++++++--- .../fusions/triton/triton_support_legacy.cc | 1 + .../xla/service/gpu/transforms/gemm_fusion.cc | 1 + 8 files changed, 199 insertions(+), 70 deletions(-) diff --git a/third_party/xla/xla/service/algorithm_util.cc b/third_party/xla/xla/service/algorithm_util.cc index 85380d2b2ef72c..0f405f5aad0372 100644 --- a/third_party/xla/xla/service/algorithm_util.cc +++ b/third_party/xla/xla/service/algorithm_util.cc @@ -47,6 +47,7 @@ absl::StatusOr GetBlasComputationType( case PrecisionConfig::ALG_DOT_ANY_F8_ANY_F8_F32_FAST_ACCUM: case PrecisionConfig::ALG_DOT_F16_F16_F32: case PrecisionConfig::ALG_DOT_F32_F32_F32: + case PrecisionConfig::ALG_DOT_TF32_TF32_F32_X3: return se::blas::ComputationType::kF32; case PrecisionConfig::ALG_DOT_TF32_TF32_F32: return se::blas::ComputationType::kTF32AsF32; @@ -188,6 +189,7 @@ bool IsSupportedDotAlgorithmOnGpu( case PrecisionConfig::ALG_DOT_BF16_BF16_F32_X6: return (is_cuda_ge_ampere || is_rocm_mi100_and_above) && input_storage_type == F32 && output_storage_type == F32; + case PrecisionConfig::ALG_DOT_TF32_TF32_F32_X3: case PrecisionConfig::ALG_DOT_TF32_TF32_F32: return (is_cuda_ge_ampere || is_rocm_mi100_and_above) && input_storage_type == F32 && output_storage_type == F32; diff --git a/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc index 51f68200fb2da9..688f7090265b2a 100644 --- a/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc +++ b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc @@ -337,7 +337,9 @@ absl::StatusOr> CublasGemmAutotuneExtractor( if (dot->precision_config().algorithm() == PrecisionConfig::ALG_DOT_BF16_BF16_F32_X3 || dot->precision_config().algorithm() == - PrecisionConfig::ALG_DOT_BF16_BF16_F32_X6) { + PrecisionConfig::ALG_DOT_BF16_BF16_F32_X6 || + dot->precision_config().algorithm() == + PrecisionConfig::ALG_DOT_TF32_TF32_F32_X3) { dot->mutable_precision_config()->set_algorithm( PrecisionConfig::ALG_DOT_F32_F32_F32); } diff --git a/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner_test.cc b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner_test.cc index 8a43da38f52e4a..b13da144e32a56 100644 --- a/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner_test.cc +++ b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner_test.cc @@ -50,7 +50,6 @@ limitations under the License. #include "xla/service/pattern_matcher_gmock.h" #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/device_description.pb.h" -#include "xla/stream_executor/gpu/gpu_executor.h" #include "xla/stream_executor/semantic_version.h" #include "xla/stream_executor/stream_executor.h" #include "xla/tests/filecheck.h" @@ -177,8 +176,82 @@ class StatelessAutotunerTest : public HloTestBase { AutotunerUtil::ClearAutotuneResults(); HloTestBase::TearDown(); } + + absl::StatusOr> + GetPossibleMatmulAutotuneConfigs( + const HloModule& module, + const se::CudaComputeCapability& compute_capability, + const se::SemanticVersion& toolkit_version, + const DebugOptions& debug_options) { + const HloFusionInstruction& fusion = *Cast( + module.entry_computation()->root_instruction()); + se::GpuDeviceInfoProto deviceless_proto; + auto ccc = deviceless_proto.mutable_cuda_compute_capability(); + ccc->set_major(compute_capability.major); + ccc->set_minor(compute_capability.minor); + + DeviceConfig test_config{backend().default_stream_executor(), + backend().memory_allocator()}; + AutotuneConfig autotune_config{test_config, debug_options}; + GemmFusionAutotunerImpl autotuner(autotune_config, toolkit_version, + debug_options, nullptr); + return autotuner.GenerateConfigs(fusion); + } }; +TEST_F(StatelessAutotunerTest, NoCublasFallbackForTf32Tf32F32X3Algorithm) { + // There is no cublas implementation for dot_tf32_tf32_f32_x3 at the moment. + // At the same time cublas f32 is faster than triton for this algorithm. + // But we don't want to fallback to cuBLAS in this case because we lose the + // precision guarantees. + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"( + HloModule module + + computation { + p0 = f32[1024,1024] parameter(0) + p1 = f32[1024,1024] parameter(1) + ROOT r = f32[1024,1024] dot(p0, p1), + algorithm=dot_tf32_tf32_f32_x3, + lhs_contracting_dims={1}, + rhs_contracting_dims={0} + } + + ENTRY main { + p0 = f32[1024,1024] parameter(0) + p1 = f32[1024,1024] parameter(1) + ROOT computation = f32[1024,1024] fusion(f32[1024,1024] p0,f32[1024,1024] p1), + kind=kCustom, + calls=computation + } + )")); + + const se::CudaComputeCapability ampere{se::CudaComputeCapability::AMPERE, + /*minor=*/0}; + TF_ASSERT_OK_AND_ASSIGN( + auto ampere_configs, + GetPossibleMatmulAutotuneConfigs(*module, ampere, GetToolkitVersion(), + GetDebugOptionsForTest())); + EXPECT_FALSE(std::any_of( + ampere_configs.begin(), ampere_configs.end(), + [](const GemmFusionAutotunerImpl::BackendConfig& config) { + return std::holds_alternative( + config); + })); + + const se::CudaComputeCapability hopper{se::CudaComputeCapability::HOPPER, + /*minor=*/0}; + TF_ASSERT_OK_AND_ASSIGN( + auto hopper_configs, + GetPossibleMatmulAutotuneConfigs(*module, hopper, GetToolkitVersion(), + GetDebugOptionsForTest())); + EXPECT_FALSE(std::any_of( + hopper_configs.begin(), hopper_configs.end(), + [](const GemmFusionAutotunerImpl::BackendConfig& config) { + return std::holds_alternative( + config); + })); +} + class GemmFusionAutotunerTest : public StatelessAutotunerTest { public: DebugOptions GetDebugOptionsForTest() override { @@ -197,25 +270,6 @@ class GemmFusionAutotunerTest : public StatelessAutotunerTest { .cuda_compute_capability(); } - absl::StatusOr> - GetPossibleMatmulAutotuneConfigs( - const HloFusionInstruction& fusion, - const se::CudaComputeCapability& compute_capability, - const se::SemanticVersion& toolkit_version, - const DebugOptions& debug_options) { - se::GpuDeviceInfoProto deviceless_proto; - auto ccc = deviceless_proto.mutable_cuda_compute_capability(); - ccc->set_major(compute_capability.major); - ccc->set_minor(compute_capability.minor); - - DeviceConfig test_config{backend().default_stream_executor(), - backend().memory_allocator()}; - AutotuneConfig autotune_config{test_config, debug_options}; - GemmFusionAutotunerImpl autotuner(autotune_config, toolkit_version, - debug_options, nullptr); - return autotuner.GenerateConfigs(fusion); - } - void CheckTritonAutotuning(absl::string_view hlo, absl::string_view expected) { HloPassPipeline pipeline("gemm_rewrite"); @@ -1046,10 +1100,9 @@ TEST_F(GemmFusionAutotunerTest, CreatesCustomKernelFusionConfigs) { TF_ASSERT_OK_AND_ASSIGN( const std::vector configs, - GetPossibleMatmulAutotuneConfigs( - *Cast( - module->entry_computation()->root_instruction()), - compute_capability, GetToolkitVersion(), GetDebugOptionsForTest())); + GetPossibleMatmulAutotuneConfigs(*module, compute_capability, + GetToolkitVersion(), + GetDebugOptionsForTest())); EXPECT_TRUE(std::any_of( configs.begin(), configs.end(), [](const GemmFusionAutotunerImpl::BackendConfig& config) { @@ -1087,10 +1140,9 @@ TEST_F(GemmFusionAutotunerTest, GeneratesConfigForUpcastGemmWithPrologue) { TF_ASSERT_OK_AND_ASSIGN( const std::vector configs, - GetPossibleMatmulAutotuneConfigs( - *Cast( - module->entry_computation()->root_instruction()), - compute_capability, GetToolkitVersion(), GetDebugOptionsForTest())); + GetPossibleMatmulAutotuneConfigs(*module, compute_capability, + GetToolkitVersion(), + GetDebugOptionsForTest())); EXPECT_TRUE(std::any_of( configs.begin(), configs.end(), [](const GemmFusionAutotunerImpl::BackendConfig& config) { @@ -1133,10 +1185,9 @@ TEST_F(GemmFusionAutotunerTest, TF_ASSERT_OK_AND_ASSIGN( const std::vector configs, - GetPossibleMatmulAutotuneConfigs( - *Cast( - module->entry_computation()->root_instruction()), - compute_capability, GetToolkitVersion(), GetDebugOptionsForTest())); + GetPossibleMatmulAutotuneConfigs(*module, compute_capability, + GetToolkitVersion(), + GetDebugOptionsForTest())); EXPECT_TRUE(std::any_of( configs.begin(), configs.end(), [](const GemmFusionAutotunerImpl::BackendConfig& config) { diff --git a/third_party/xla/xla/service/gpu/fusions/triton/kernel_name_tracer_cuda.cc b/third_party/xla/xla/service/gpu/fusions/triton/kernel_name_tracer_cuda.cc index cd8830542cfec5..5074c8c44f0be0 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/kernel_name_tracer_cuda.cc +++ b/third_party/xla/xla/service/gpu/fusions/triton/kernel_name_tracer_cuda.cc @@ -36,7 +36,7 @@ class KernelNameTracerCuda : public KernelNameTracer { std::string stop() override; private: - std::unique_ptr cupti_tracer_; + profiler::CuptiTracer* cupti_tracer_; // Not owned. std::unique_ptr cupti_collector_; }; diff --git a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter.cc b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter.cc index ccaad403bd954d..69bd809d345b55 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter.cc +++ b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter.cc @@ -2208,6 +2208,27 @@ bool IsTf32Allowed(const HloDotInstruction* dot_instr) { return algorithm_util::HasTf32InputType(algorithm); } +mt::InputPrecision InferDotPrecision(const HloDotInstruction* dot_instr) { + auto algorithm = dot_instr->precision_config().algorithm(); + if (algorithm == PrecisionConfig::ALG_DOT_TF32_TF32_F32_X3) { + return mt::InputPrecision::TF32x3; + } + // TODO(b/320659359) Allow TF32 for 8-bit or less types with F32. + bool is_unsupported_bitwidth = + HloBfsAnyOf({dot_instr}, [&](const HloInstruction* node) { + if (node->opcode() != HloOpcode::kConvert) { + return false; + } + int in_width = + primitive_util::BitWidth(node->operand(0)->shape().element_type()); + return in_width <= 8 && node->shape().element_type() == F32; + }); + + return IsTf32Allowed(dot_instr) && !is_unsupported_bitwidth + ? mt::InputPrecision::TF32 + : mt::InputPrecision::IEEE; +} + bool Is6xBfloat16MatMul(const HloDotInstruction* dot_instr, mlir::OpBuilder& builder, Value dot_input_lhs, Value dot_input_rhs, @@ -2537,17 +2558,6 @@ absl::Status EmitMatMul(mlir::OpBuilder builder, const HloInstruction* root = dot_instr->parent()->root_instruction(); TF_RET_CHECK(!root->shape().IsTuple()); - // TODO(b/320659359) Allow TF32 for 8-bit or less types with F32. - bool is_unsupported_bitwidth = - HloBfsAnyOf({dot_instr}, [&](const HloInstruction* node) { - if (node->opcode() != HloOpcode::kConvert) { - return false; - } - int in_width = - primitive_util::BitWidth(node->operand(0)->shape().element_type()); - return in_width <= 8 && node->shape().element_type() == F32; - }); - // We'll be creating a lot of instructions from a single dot, use an // implicit loc builder so we don't have to pass around the location all the // time. @@ -2700,10 +2710,7 @@ absl::Status EmitMatMul(mlir::OpBuilder builder, // maxNumImpreciseAcc flag was introduced for Hopper to accumulate in a // lower precision than the output type. The change was introduced here: // https://github.com/openai/triton/commit/31b0c521427109a8eda609b58d756c380b21599a - auto input_precision = - IsTf32Allowed(dot_instr) && !is_unsupported_bitwidth - ? mt::InputPrecision::TF32 - : mt::InputPrecision::IEEE; + auto dot_precision = InferDotPrecision(dot_instr); // Cast F32 inputs to BF16 if the algorithm is BF16_BF16_F32. if (dot_instr->precision_config().algorithm() == @@ -2723,7 +2730,7 @@ absl::Status EmitMatMul(mlir::OpBuilder builder, IsFp8Matmul(dot_instr) ? std::numeric_limits::max() : 0; accumulator_next = b.create(dot_input_lhs, dot_input_rhs, iter_args.back(), - /*inputPrecision=*/input_precision, + /*inputPrecision=*/dot_precision, /*maxNumImpreciseAcc=*/max_num_imprecise_acc); } iter_args_next.push_back(accumulator_next); diff --git a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_legacy_test.cc b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_legacy_test.cc index 036e3221c71bcb..983c5db3ba56c5 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_legacy_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_legacy_test.cc @@ -147,7 +147,7 @@ class TritonGemmTestWithoutTritonGemmAny : public TritonGemmTest { } }; -class TritonBF16BF16F32BlasTest : public TritonTest { +class BlasAlgorithmTest : public TritonTest { public: DebugOptions GetDebugOptionsForTest() override { DebugOptions debug_options = TritonTest::GetDebugOptionsForTest(); @@ -157,20 +157,16 @@ class TritonBF16BF16F32BlasTest : public TritonTest { debug_options.set_xla_gpu_enable_triton_gemm(false); return debug_options; } - - protected: - void SetUp() override { - if (!SupportsBF16(GpuComputeComp())) { - GTEST_SKIP() << "BF16 not supported."; - } - } }; -TEST_F(TritonBF16BF16F32BlasTest, PropagateAlgorithmToBlas) { +TEST_F(BlasAlgorithmTest, Algorithm_BF16_BF16_F32) { // We check that the algorithm is propagated to the BLAS call. // We also check that the kernel name matches the algorithm for Ampere. // The algorithm for Hopper is not the one we expect because it uses TF32. + if (!SupportsBF16(GpuComputeComp())) { + GTEST_SKIP() << "BF16 not supported."; + } constexpr std::string_view kHloText = R"( HloModule t @@ -185,6 +181,8 @@ TEST_F(TritonBF16BF16F32BlasTest, PropagateAlgorithmToBlas) { )"; const std::string pattern = R"(CHECK: "algorithm":"ALG_DOT_BF16_BF16_F32")"; TF_ASSERT_OK_AND_ASSIGN(auto module, GetOptimizedModule(kHloText)); + TF_ASSERT_OK_AND_ASSIGN(auto ok, RunFileCheck(module->ToString(), pattern)); + ASSERT_TRUE(ok); auto tracer = KernelNameTracer::Create(); tracer->start(); @@ -215,6 +213,57 @@ TEST_F(TritonBF16BF16F32BlasTest, PropagateAlgorithmToBlas) { } } +TEST_F(BlasAlgorithmTest, Algorithm_TF32_TF32_F32_X3) { + // We check that the algorithm is propagated to the BLAS call. + // We also check that the kernel name matches the algorithm for Ampere. + + constexpr std::string_view kHloText = R"( + HloModule t + + ENTRY main { + lhs = f32[8512,256]{1,0} parameter(0) + rhs = f32[256,8512]{1,0} parameter(1) + ROOT dot = f32[8512,8512]{1,0} dot(lhs, rhs), + algorithm=dot_tf32_tf32_f32_x3, + lhs_contracting_dims={1}, + rhs_contracting_dims={0} + } + )"; + const std::string pattern = + R"(CHECK: "algorithm":"ALG_DOT_TF32_TF32_F32_X3")"; + TF_ASSERT_OK_AND_ASSIGN(auto module, GetOptimizedModule(kHloText)); + TF_ASSERT_OK_AND_ASSIGN(auto ok, RunFileCheck(module->ToString(), pattern)); + ASSERT_TRUE(ok); + + auto tracer = KernelNameTracer::Create(); + tracer->start(); + EXPECT_TRUE(Run(std::move(module), /*run_hlo_passes=*/false)); + auto kernel_name = tracer->stop(); + + if (kernel_name == "kernel_name_tracer_not_implemented") return; + + auto cc = GetCudaComputeCapability(); + using CudaComputeCapabilities = + stream_executor::CudaComputeCapability::CudaComputeCapabilities; + switch (cc.major) { + case CudaComputeCapabilities::BLACKWELL: + GTEST_SKIP() << "CudaComputeCapabilities::BLACKWELL has the kernel name: " + << kernel_name; + break; + case CudaComputeCapabilities::AMPERE: + // There is no support for TF32_TF32_F32_X3 on Ampere. We use F32_F32_F32. + EXPECT_THAT(kernel_name, ::testing::HasSubstr("ampere_sgemm_128x64_nn")); + break; + case CudaComputeCapabilities::HOPPER: + // There is no support for TF32_TF32_F32_X3 on Hopper. We use F32_F32_F32. + EXPECT_THAT(kernel_name, ::testing::HasSubstr("gemm_f32f32_f32f32_f32")); + break; + default: + GTEST_SKIP() << "Unsupported compute capability: " << cc.major + << " has the kernel name: " << kernel_name; + } +} + TEST_F(TritonGemmTest, RejectDotInt4HLO) { constexpr std::string_view kHloText = R"( HloModule t @@ -5234,7 +5283,7 @@ CHECK-NOT: mma.sync.aligned.{{.*}}.row.col.f32.tf32.tf32.f32 /*arel=*/1e-5})); } -class TritonBF16BF16F32GemmTest : public TritonTest { +class TritonAlgorithmTest : public TritonTest { public: DebugOptions GetDebugOptionsForTest() override { DebugOptions debug_options = TritonTest::GetDebugOptionsForTest(); @@ -5245,23 +5294,39 @@ class TritonBF16BF16F32GemmTest : public TritonTest { debug_options.set_xla_gpu_enable_split_k_autotuning(false); return debug_options; } +}; - protected: - void SetUp() override { - if (!SupportsBF16(GpuComputeComp())) { - GTEST_SKIP() << "BF16 not supported."; +TEST_F(TritonAlgorithmTest, Algorithm_TF32_TF32_F32_X3) { + const std::string kHloText = R"( + HloModule t + + ENTRY main { + lhs = f32[8512,64]{1,0} parameter(0) + rhs = f32[64,8512]{1,0} parameter(1) + ROOT dot = f32[8512,8512]{1,0} dot(lhs, rhs), + algorithm=dot_tf32_tf32_f32_x3, + lhs_contracting_dims={1}, + rhs_contracting_dims={0} } - } -}; + )"; + const std::string pattern = + R"(CHECK: "kind":"__triton_gemm","triton_gemm_config")"; + TF_ASSERT_OK_AND_ASSIGN(auto module, GetOptimizedModule(kHloText)); + TF_ASSERT_OK_AND_ASSIGN(auto ok, RunFileCheck(module->ToString(), pattern)); + EXPECT_TRUE(ok); +} -TEST_F(TritonBF16BF16F32GemmTest, WorkWithF32InputAndAlgorithm_BF16_BF16_F32) { +TEST_F(TritonAlgorithmTest, Algorithm_BF16_BF16_F32) { + if (!SupportsBF16(GpuComputeComp())) { + GTEST_SKIP() << "BF16 not supported."; + } const std::string kHloText = R"( HloModule t ENTRY main { - lhs = f32[32,64]{1,0} parameter(0) - rhs = f32[64,16]{1,0} parameter(1) - ROOT dot = f32[32,16]{1,0} dot(lhs, rhs), + lhs = f32[8512,64]{1,0} parameter(0) + rhs = f32[64,8512]{1,0} parameter(1) + ROOT dot = f32[8512,8512]{1,0} dot(lhs, rhs), algorithm=dot_bf16_bf16_f32, lhs_contracting_dims={1}, rhs_contracting_dims={0} diff --git a/third_party/xla/xla/service/gpu/fusions/triton/triton_support_legacy.cc b/third_party/xla/xla/service/gpu/fusions/triton/triton_support_legacy.cc index 97c4891441d98c..dc8740807e2fc5 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/triton_support_legacy.cc +++ b/third_party/xla/xla/service/gpu/fusions/triton/triton_support_legacy.cc @@ -220,6 +220,7 @@ bool IsDotAlgorithmSupportedByTriton( auto rocm_compute_capability = std::get_if(&gpu_version); switch (algorithm) { + case PrecisionConfig::ALG_DOT_TF32_TF32_F32_X3: case PrecisionConfig::ALG_DOT_TF32_TF32_F32: if (cuda_compute_capability) { return true; diff --git a/third_party/xla/xla/service/gpu/transforms/gemm_fusion.cc b/third_party/xla/xla/service/gpu/transforms/gemm_fusion.cc index 1934f59c48e24f..c52c44ab514c32 100644 --- a/third_party/xla/xla/service/gpu/transforms/gemm_fusion.cc +++ b/third_party/xla/xla/service/gpu/transforms/gemm_fusion.cc @@ -741,6 +741,7 @@ absl::StatusOr CreateDotFusion( if (algorithm == PrecisionConfig::ALG_DOT_BF16_BF16_F32_X6 || algorithm == PrecisionConfig::ALG_DOT_BF16_BF16_F32_X3 || algorithm == PrecisionConfig::ALG_DOT_BF16_BF16_F32 || + algorithm == PrecisionConfig::ALG_DOT_TF32_TF32_F32_X3 || dot.GetModule()->config().debug_options().xla_gpu_triton_gemm_any() || dot.sparse_operands()) { return Decision::Allow(); From 19546c965903451ad4d30af282fdd264a7add2b9 Mon Sep 17 00:00:00 2001 From: Alexander Belyaev Date: Mon, 30 Sep 2024 07:13:30 -0700 Subject: [PATCH 424/483] [XLA:GPU] Use metadata to print and parse indexing maps. PiperOrigin-RevId: 680562614 --- .../xla/service/gpu/fusions/fusion_emitter.cc | 25 ++- .../xla/service/gpu/fusions/ir/xla_gpu_ops.cc | 2 +- .../gpu/fusions/ir/xla_gpu_ops_test.cc | 8 +- .../xla/xla/service/gpu/fusions/legacy/BUILD | 5 +- .../gpu/fusions/legacy/concatenate_test.cc | 12 +- .../in_place_dynamic_update_slice_test.cc | 2 +- .../gpu/fusions/legacy/input_slices_test.cc | 2 +- .../service/gpu/fusions/legacy/loop_test.cc | 44 ++--- .../service/gpu/fusions/legacy/reduction.cc | 12 +- .../gpu/fusions/legacy/reduction_test.cc | 40 ++--- .../gpu/fusions/legacy/scatter_test.cc | 16 +- .../service/gpu/fusions/legacy/tiling_util.cc | 8 +- .../gpu/fusions/legacy/transpose_test.cc | 158 +++++++++--------- .../xla/service/gpu/fusions/reduction_mlir.cc | 17 +- .../fusions/tests/concatenate/concat_1d.hlo | 8 +- .../fusions/tests/scatter/unique_indices.hlo | 2 +- .../xla/service/gpu/fusions/transpose_mlir.cc | 2 +- .../service/gpu/model/indexing_analysis.cc | 15 +- .../gpu/model/indexing_analysis_test.cc | 36 ++-- .../xla/xla/service/gpu/model/indexing_map.cc | 70 +++++--- .../xla/xla/service/gpu/model/indexing_map.h | 25 ++- .../gpu/model/indexing_map_serialization.cc | 155 +++++++++++------ .../gpu/model/indexing_map_serialization.h | 3 +- .../model/indexing_map_serialization_test.cc | 28 +--- .../service/gpu/model/indexing_map_test.cc | 102 ++++++----- .../service/gpu/model/indexing_test_utils.cc | 4 +- .../gpu/model/symbolic_tile_analysis_test.cc | 6 +- 27 files changed, 443 insertions(+), 364 deletions(-) diff --git a/third_party/xla/xla/service/gpu/fusions/fusion_emitter.cc b/third_party/xla/xla/service/gpu/fusions/fusion_emitter.cc index 432d600701d1ab..4691c9753b361f 100644 --- a/third_party/xla/xla/service/gpu/fusions/fusion_emitter.cc +++ b/third_party/xla/xla/service/gpu/fusions/fusion_emitter.cc @@ -168,21 +168,20 @@ IndexingMap KernelFusionInterface::GetDefaultThreadIdIndexingMap( divisor *= shape.dimensions(dimension); } - std::vector dim_vars = { - {{0, static_cast(launch_dims.thread_counts_per_block().x) - 1}}, - {{0, static_cast(launch_dims.thread_counts_per_block().y) - 1}}, - {{0, static_cast(launch_dims.thread_counts_per_block().z) - 1}}, - {{0, static_cast(launch_dims.block_counts().x) - 1}}, - {{0, static_cast(launch_dims.block_counts().y) - 1}}, - {{0, static_cast(launch_dims.block_counts().z) - 1}}, - }; + std::vector dim_vars = DimVarsFromGPUGrid( + {static_cast(launch_dims.thread_counts_per_block().x), + static_cast(launch_dims.thread_counts_per_block().y), + static_cast(launch_dims.thread_counts_per_block().z), + static_cast(launch_dims.block_counts().x), + static_cast(launch_dims.block_counts().y), + static_cast(launch_dims.block_counts().z)}); std::vector range_vars; int64_t num_elements = ShapeUtil::ElementsIn(shape); - range_vars.push_back( - {{0, CeilOfRatio(num_elements, - static_cast(launch_dims.launch_bound()) * - unroll_factor) - - 1}}); + range_vars.push_back(RangeVar{ + {0, CeilOfRatio(num_elements, + static_cast(launch_dims.launch_bound()) * + unroll_factor) - + 1}}); range_vars.push_back({0, unroll_factor - 1}); IndexingMap indexing_map( mlir::AffineMap::get(/*dimCount=*/6, diff --git a/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops.cc b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops.cc index 2aa00180e326b1..0692df6bc2c5eb 100644 --- a/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops.cc +++ b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops.cc @@ -980,7 +980,7 @@ LogicalResult MaterializeOp::verify() { return emitOpError() << "must have thread_id dimension in both indexing maps"; } - if (map_in.GetDimVars(0) != map_out.GetDimVars(0)) { + if (map_in.GetDimVars(0).bounds != map_out.GetDimVars(0).bounds) { return emitOpError() << "thread_id dimension must have the same bounds in " "both indexing maps"; } diff --git a/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops_test.cc b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops_test.cc index 2d9076d7803280..d33680ade36f40 100644 --- a/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops_test.cc @@ -36,8 +36,8 @@ class XLAGPUOpsTest : public HloTestBase { TEST_F(XLAGPUOpsTest, GetConstraintsForVariables) { auto map = IndexingMap( ParseAffineMap("(d0, d1)[s0, s1] -> (d0+s0, d1+s1)", &mlir_context_), - /*dimensions=*/{{Interval{0, 5}}, {Interval{0, 2}}}, - /*range_vars=*/{{Interval{0, 32}}, {Interval{0, 1024}}}, /*rt_vars=*/{}); + /*dimensions=*/{{DimVar{0, 5}}, {DimVar{0, 2}}}, + /*range_vars=*/{{RangeVar{0, 32}}, {RangeVar{0, 1024}}}, /*rt_vars=*/{}); map.AddConstraint(ParseAffineExpr("s0 mod 4", &mlir_context_), Interval{0, 1}); map.AddConstraint(ParseAffineExpr("s1 mod 4", &mlir_context_), @@ -71,8 +71,8 @@ TEST_F(XLAGPUOpsTest, GetConstraintsForVariables) { TEST_F(XLAGPUOpsTest, GetConstraintsForVariablesEmpty) { auto map = IndexingMap( ParseAffineMap("(d0, d1)[s0, s1] -> (d0+s0, d1+s1)", &mlir_context_), - /*dimensions=*/{{Interval{0, 5}}, {Interval{0, 2}}}, - /*range_vars=*/{{Interval{0, 32}}, {Interval{0, 1024}}}, /*rt_vars=*/{}); + /*dimensions=*/{{DimVar{0, 5}}, {DimVar{0, 2}}}, + /*range_vars=*/{{RangeVar{0, 32}}, {RangeVar{0, 1024}}}, /*rt_vars=*/{}); auto constraints_for_variables = GetConstraintsForVariables(map); EXPECT_THAT(constraints_for_variables.constraints_for_dims, ElementsAre(IsEmpty(), IsEmpty())); diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/BUILD b/third_party/xla/xla/service/gpu/fusions/legacy/BUILD index 8b9f3a34073441..ff0eca7cf0f8d0 100644 --- a/third_party/xla/xla/service/gpu/fusions/legacy/BUILD +++ b/third_party/xla/xla/service/gpu/fusions/legacy/BUILD @@ -205,6 +205,7 @@ cc_library( "//xla/service/gpu/fusions:fusion_emitter", "//xla/service/gpu/fusions:reduction_base", "//xla/service/gpu/fusions:thunk_util", + "//xla/service/gpu/model:indexing_analysis", "//xla/service/gpu/runtime:kernel_thunk", "//xla/service/gpu/runtime:thunk", "//xla/service/llvm_ir:fused_ir_emitter", @@ -237,17 +238,13 @@ xla_cc_test( srcs = ["reduction_test.cc"], deps = [ ":reduction", - "//xla/hlo/ir:hlo", "//xla/service/gpu:gpu_device_info_for_tests", "//xla/service/gpu:hlo_fusion_analysis", - "//xla/service/gpu:ir_emitter_context", - "//xla/service/gpu/fusions:fusion_emitter", "//xla/service/gpu/model:indexing_analysis", "//xla/service/gpu/model:indexing_test_utils", "//xla/stream_executor:device_description", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", - "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_googletest//:gtest", "@llvm-project//mlir:IR", diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/concatenate_test.cc b/third_party/xla/xla/service/gpu/fusions/legacy/concatenate_test.cc index ce7da7bcb22485..a253c3bc07a9e9 100644 --- a/third_party/xla/xla/service/gpu/fusions/legacy/concatenate_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/legacy/concatenate_test.cc @@ -84,23 +84,23 @@ TEST_F(ConcatenateTest, ThreadIndexing) { unroll_id in [0, 0], bl_x * 128 + th_x in [0, 399] )"; + mlir::SmallVector dim_names = {"th_x", "th_y", "th_z", + "bl_x", "bl_y", "bl_z"}; + mlir::SmallVector range_names = {"chunk_id", "unroll_id"}; EXPECT_THAT( ToString(*fusion->ComputeThreadIdToInputIndexing( /*root_index=*/0, /*hero_operand_index=*/0, &mlir_context_), - {"th_x", "th_y", "th_z", "bl_x", "bl_y", "bl_z"}, - {"chunk_id", "unroll_id"}), + dim_names, range_names, {}), MatchIndexingString(kIndexing)); EXPECT_THAT( ToString(*fusion->ComputeThreadIdToInputIndexing( /*root_index=*/0, /*hero_operand_index=*/1, &mlir_context_), - {"th_x", "th_y", "th_z", "bl_x", "bl_y", "bl_z"}, - {"chunk_id", "unroll_id"}), + dim_names, range_names, {}), MatchIndexingString(kIndexing)); EXPECT_THAT( ToString(*fusion->ComputeThreadIdToInputIndexing( /*root_index=*/0, /*hero_operand_index=*/2, &mlir_context_), - {"th_x", "th_y", "th_z", "bl_x", "bl_y", "bl_z"}, - {"chunk_id", "unroll_id"}), + dim_names, range_names, {}), MatchIndexingString(kIndexing)); } diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/in_place_dynamic_update_slice_test.cc b/third_party/xla/xla/service/gpu/fusions/legacy/in_place_dynamic_update_slice_test.cc index 27d3aa2170be3f..2c5e64bd84e5ce 100644 --- a/third_party/xla/xla/service/gpu/fusions/legacy/in_place_dynamic_update_slice_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/legacy/in_place_dynamic_update_slice_test.cc @@ -76,7 +76,7 @@ TEST_F(InPlaceDynamicUpdateSliceFusionTest, ThreadIndexing) { /*root_index=*/0, /*hero_operand_index=*/1, &mlir_context_); EXPECT_THAT(ToString(*thread_id_update_indexing, {"th_x", "th_y", "th_z", "bl_x", "bl_y", "bl_z"}, - {"chunk_id", "unroll_id"}), + {"chunk_id", "unroll_id"}, {}), MatchIndexingString(R"( (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> ( th_x floordiv 6, th_x mod 6), diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/input_slices_test.cc b/third_party/xla/xla/service/gpu/fusions/legacy/input_slices_test.cc index 9de13b8bd7df5c..45a917ef9b1355 100644 --- a/third_party/xla/xla/service/gpu/fusions/legacy/input_slices_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/legacy/input_slices_test.cc @@ -73,7 +73,7 @@ TEST_F(InputSlicesTest, ThreadIndexing) { fusion->ComputeThreadIdToOutputIndexing(0, &mlir_context_); EXPECT_THAT(ToString(*thread_id_to_output_indexing, {"th_x", "th_y", "th_z", "bl_x", "bl_y", "bl_z"}, - {"chunk_id", "unroll_id"}), + {"chunk_id", "unroll_id"}, {}), MatchIndexingString(R"( (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> (0, ((bl_x * 128 + th_x) floordiv 3) mod 2, diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/loop_test.cc b/third_party/xla/xla/service/gpu/fusions/legacy/loop_test.cc index b23e9b2b19a213..b167353f93f157 100644 --- a/third_party/xla/xla/service/gpu/fusions/legacy/loop_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/legacy/loop_test.cc @@ -75,10 +75,12 @@ TEST_F(LoopTest, ThreadIndexingUnrolled) { loop_fusion->ComputeThreadIdToOutputIndexing(/*root_index=*/0, &mlir_context_); - EXPECT_THAT(ToString(*thread_id_to_output_indexing, - {"th_x", "th_y", "th_z", "bl_x", "bl_y", "bl_z"}, - {"chunk_id", "unroll_id"}), - MatchIndexingString(R"( + mlir::SmallVector dim_names = {"th_x", "th_y", "th_z", + "bl_x", "bl_y", "bl_z"}; + mlir::SmallVector range_names = {"chunk_id", "unroll_id"}; + EXPECT_THAT( + ToString(*thread_id_to_output_indexing, dim_names, range_names, {}), + MatchIndexingString(R"( (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> ( (bl_x * 128 + th_x) floordiv 15000, ((bl_x * 128 + th_x) floordiv 75) mod 200, @@ -119,10 +121,12 @@ TEST_F(LoopTest, ThreadIndexingNotUnrolled) { auto thread_id_to_output_indexing = loop_fusion->ComputeThreadIdToOutputIndexing(/*root_index=*/0, &mlir_context_); - EXPECT_THAT(ToString(*thread_id_to_output_indexing, - {"th_x", "th_y", "th_z", "bl_x", "bl_y", "bl_z"}, - {"chunk_id", "unroll_id"}), - MatchIndexingString(R"( + mlir::SmallVector dim_names = {"th_x", "th_y", "th_z", + "bl_x", "bl_y", "bl_z"}; + mlir::SmallVector range_names = {"chunk_id", "unroll_id"}; + EXPECT_THAT( + ToString(*thread_id_to_output_indexing, dim_names, range_names, {}), + MatchIndexingString(R"( (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> (th_x), domain: th_x in [0, 19], @@ -137,10 +141,9 @@ TEST_F(LoopTest, ThreadIndexingNotUnrolled) { auto thread_id_to_input_indexing = loop_fusion->ComputeThreadIdToInputIndexing( /*root_index=*/0, /*hero_operand_index=*/0, &mlir_context_); - EXPECT_THAT(ToString(*thread_id_to_input_indexing, - {"th_x", "th_y", "th_z", "bl_x", "bl_y", "bl_z"}, - {"chunk_id", "unroll_id"}), - MatchIndexingString(R"( + EXPECT_THAT( + ToString(*thread_id_to_input_indexing, dim_names, range_names, {}), + MatchIndexingString(R"( (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> (th_x), domain: th_x in [0, 19], @@ -176,10 +179,12 @@ TEST_F(LoopTest, Broadcast) { auto thread_id_to_output_indexing = loop_fusion->ComputeThreadIdToOutputIndexing(/*root_index=*/0, &mlir_context_); - EXPECT_THAT(ToString(*thread_id_to_output_indexing, - {"th_x", "th_y", "th_z", "bl_x", "bl_y", "bl_z"}, - {"chunk_id", "unroll_id"}), - MatchIndexingString(R"( + mlir::SmallVector dim_names = {"th_x", "th_y", "th_z", + "bl_x", "bl_y", "bl_z"}; + mlir::SmallVector range_names = {"chunk_id", "unroll_id"}; + EXPECT_THAT( + ToString(*thread_id_to_output_indexing, dim_names, range_names, {}), + MatchIndexingString(R"( (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> ( (bl_x * 128 + th_x) floordiv 600, ((bl_x * 128 + th_x) floordiv 30) mod 20, @@ -198,10 +203,9 @@ TEST_F(LoopTest, Broadcast) { auto thread_id_to_input_indexing = loop_fusion->ComputeThreadIdToInputIndexing( /*root_index=*/0, /*hero_operand_index=*/0, &mlir_context_); - EXPECT_THAT(ToString(*thread_id_to_input_indexing, - {"th_x", "th_y", "th_z", "bl_x", "bl_y", "bl_z"}, - {"chunk_id", "unroll_id"}), - MatchIndexingString(R"( + EXPECT_THAT( + ToString(*thread_id_to_input_indexing, dim_names, range_names, {}), + MatchIndexingString(R"( (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> (((bl_x * 128 + th_x) floordiv 30) mod 20), domain: diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/reduction.cc b/third_party/xla/xla/service/gpu/fusions/legacy/reduction.cc index e009ea18e0b48c..c0d862f004b79f 100644 --- a/third_party/xla/xla/service/gpu/fusions/legacy/reduction.cc +++ b/third_party/xla/xla/service/gpu/fusions/legacy/reduction.cc @@ -63,6 +63,7 @@ limitations under the License. #include "xla/service/gpu/kernel_arguments.h" #include "xla/service/gpu/kernel_reuse_cache.h" #include "xla/service/gpu/launch_dimensions.h" +#include "xla/service/gpu/model/indexing_map.h" #include "xla/service/gpu/parallel_loop_emitter.h" #include "xla/service/gpu/reduction_utils.h" #include "xla/service/gpu/runtime/kernel_thunk.h" @@ -1223,14 +1224,9 @@ std::optional ReductionInfo::ComputeThreadIdToOutputIndexing( auto physical_shape = ShapeUtil::DeleteDimensions(hero.dimensions(), hero.operand(0)->shape()); - std::vector dimension_ranges{ - {{0, tiling_.GetNumThreadsPerBlock() - 1}}, - {}, - {}, - {{0, tiling_.GetNumBlocks() - 1}}, - {{0, static_cast(groups_.grouped_roots.size() - 1)}}, - {}, - }; + std::vector dimension_ranges = DimVarsFromGPUGrid( + {tiling_.GetNumThreadsPerBlock(), 1, 1, tiling_.GetNumBlocks(), + static_cast(groups_.grouped_roots.size()), 1}); constexpr int kRowKept = ReductionDimensions::kRowKeptDimension; constexpr int kRowMinorReduced = diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/reduction_test.cc b/third_party/xla/xla/service/gpu/fusions/legacy/reduction_test.cc index 54fff94a6ed775..8109a7b1c4068d 100644 --- a/third_party/xla/xla/service/gpu/fusions/legacy/reduction_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/legacy/reduction_test.cc @@ -71,18 +71,18 @@ TEST_F(ReductionTest, ThreadIndexingRowReduction) { EXPECT_THAT( ToString(*fusion.ComputeThreadIdToInputIndexing(0, 0, &mlir_context_)), MatchIndexingString(R"( - (d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3] -> ( - d3 floordiv 8, - (d3 mod 8) * 8 + d0 floordiv 32, - (d0 mod 32) * 2 + s2 * 64 + s3 + (th_x, th_y, th_z, bl_x, bl_y, bl_z)[s0, s1, s2, s3] -> ( + bl_x floordiv 8, + (bl_x mod 8) * 8 + th_x floordiv 32, + (th_x mod 32) * 2 + s2 * 64 + s3 ), domain: - d0 in [0, 255], - d1 in [0, 0], - d2 in [0, 0], - d3 in [0, 799], - d4 in [0, 0], - d5 in [0, 0], + th_x in [0, 255], + th_y in [0, 0], + th_z in [0, 0], + bl_x in [0, 799], + bl_y in [0, 0], + bl_z in [0, 0], s0 in [0, 0], s1 in [0, 0], s2 in [0, 7], @@ -91,18 +91,18 @@ TEST_F(ReductionTest, ThreadIndexingRowReduction) { EXPECT_THAT( ToString(*fusion.ComputeThreadIdToOutputIndexing(0, &mlir_context_)), MatchIndexingString(R"( - (d0, d1, d2, d3, d4, d5) -> ( - d3 floordiv 8, - (d3 mod 8) * 8 + d0 floordiv 32 + (th_x, th_y, th_z, bl_x, bl_y, bl_z) -> ( + bl_x floordiv 8, + (bl_x mod 8) * 8 + th_x floordiv 32 ), domain: - d0 in [0, 224], - d1 in [0, 0], - d2 in [0, 0], - d3 in [0, 799], - d4 in [0, 0], - d5 in [0, 0], - d0 mod 32 in [0, 0] + th_x in [0, 224], + th_y in [0, 0], + th_z in [0, 0], + bl_x in [0, 799], + bl_y in [0, 0], + bl_z in [0, 0], + th_x mod 32 in [0, 0] )")); } diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/scatter_test.cc b/third_party/xla/xla/service/gpu/fusions/legacy/scatter_test.cc index e7d1d8eae303c9..9c6587131d93e4 100644 --- a/third_party/xla/xla/service/gpu/fusions/legacy/scatter_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/legacy/scatter_test.cc @@ -160,29 +160,29 @@ TEST_F(ScatterFusionTest, ThreadIdIndexing) { )"; mlir::SmallVector dim_names = {"th_x", "th_y", "th_z", "bl_x", "bl_y", "bl_z"}; - mlir::SmallVector symbol_names = {"chunk_id", "unroll_id"}; + mlir::SmallVector range_names = {"chunk_id", "unroll_id"}; EXPECT_THAT( ToString(*fusion->ComputeThreadIdToInputIndexing( /*root_index=*/0, /*hero_operand_index=*/3, &mlir_context_), - dim_names, symbol_names), + dim_names, range_names, {}), MatchIndexingString(kUpdatesIndexing)); EXPECT_THAT( ToString(*fusion->ComputeThreadIdToInputIndexing( /*root_index=*/0, /*hero_operand_index=*/4, &mlir_context_), - dim_names, symbol_names), + dim_names, range_names, {}), MatchIndexingString(kUpdatesIndexing)); EXPECT_THAT( ToString(*fusion->ComputeThreadIdToInputIndexing( /*root_index=*/1, /*hero_operand_index=*/3, &mlir_context_), - dim_names, symbol_names), + dim_names, range_names, {}), MatchIndexingString(kUpdatesIndexing)); EXPECT_THAT( ToString(*fusion->ComputeThreadIdToInputIndexing( /*root_index=*/1, /*hero_operand_index=*/4, &mlir_context_), - dim_names, symbol_names), + dim_names, range_names, {}), MatchIndexingString(kUpdatesIndexing)); - symbol_names.push_back("index_id"); + range_names.push_back("index_id"); constexpr auto kIndicesIndexing = R"( (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id, index_id] -> ((bl_x * 128 + th_x) floordiv 200, 0), @@ -201,12 +201,12 @@ TEST_F(ScatterFusionTest, ThreadIdIndexing) { EXPECT_THAT( ToString(*fusion->ComputeThreadIdToInputIndexing( /*root_index=*/0, /*hero_operand_index=*/2, &mlir_context_), - dim_names, symbol_names), + dim_names, range_names, {}), MatchIndexingString(kIndicesIndexing)); EXPECT_THAT( ToString(*fusion->ComputeThreadIdToInputIndexing( /*root_index=*/1, /*hero_operand_index=*/2, &mlir_context_), - dim_names, symbol_names), + dim_names, range_names, {}), MatchIndexingString(kIndicesIndexing)); } diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/tiling_util.cc b/third_party/xla/xla/service/gpu/fusions/legacy/tiling_util.cc index a1a7acb58388a7..1de211fea1f52e 100644 --- a/third_party/xla/xla/service/gpu/fusions/legacy/tiling_util.cc +++ b/third_party/xla/xla/service/gpu/fusions/legacy/tiling_util.cc @@ -36,6 +36,7 @@ limitations under the License. #include "mlir/IR/MLIRContext.h" #include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/gpu/model/indexing_analysis.h" +#include "xla/service/gpu/model/indexing_map.h" #include "xla/service/gpu/target_util.h" #include "xla/service/llvm_ir/ir_array.h" #include "xla/service/llvm_ir/kernel_support_library.h" @@ -334,7 +335,12 @@ IndexingMap GetIndexingMapForTiling(AffineMap block_offsets, offsets.push_back(block + thread); } std::vector dimension_ranges{ - {{0, threads_per_block - 1}}, {}, {}, {{0, num_blocks - 1}}, {}, {}, + DimVar{0, threads_per_block - 1, ToVariableName(VariableKind::kThreadX)}, + DimVar{0, 0, ToVariableName(VariableKind::kThreadY)}, + DimVar{0, 0, ToVariableName(VariableKind::kThreadZ)}, + DimVar{0, num_blocks - 1, ToVariableName(VariableKind::kBlockX)}, + DimVar{0, 0, ToVariableName(VariableKind::kBlockY)}, + DimVar{0, 0, ToVariableName(VariableKind::kBlockZ)}, }; auto affine_map = mlir::AffineMap::get(block_offsets.getNumDims(), block_offsets.getNumSymbols(), offsets, diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/transpose_test.cc b/third_party/xla/xla/service/gpu/fusions/legacy/transpose_test.cc index 1e503025d889d3..318abc04628779 100644 --- a/third_party/xla/xla/service/gpu/fusions/legacy/transpose_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/legacy/transpose_test.cc @@ -80,18 +80,18 @@ TEST_F(TransposeTest, ThreadIndexing021) { EXPECT_THAT( ToString(*fusion->ComputeThreadIdToInputIndexing(0, 0, &mlir_context)), MatchIndexingString(R"( - (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> ( - d3 floordiv 2, - d0 floordiv 32 + s1 * 4, - (d3 mod 2) * 32 + d0 mod 32 + (th_x, th_y, th_z, bl_x, bl_y, bl_z)[s0, s1, s2] -> ( + bl_x floordiv 2, + th_x floordiv 32 + s1 * 4, + (bl_x mod 2) * 32 + th_x mod 32 ), domain: - d0 in [0, 127], - d1 in [0, 0], - d2 in [0, 0], - d3 in [0, 199], - d4 in [0, 0], - d5 in [0, 0], + th_x in [0, 127], + th_y in [0, 0], + th_z in [0, 0], + bl_x in [0, 199], + bl_y in [0, 0], + bl_z in [0, 0], s0 in [0, 0], s1 in [0, 7], @@ -100,18 +100,18 @@ TEST_F(TransposeTest, ThreadIndexing021) { EXPECT_THAT( ToString(*fusion->ComputeThreadIdToOutputIndexing(0, &mlir_context)), MatchIndexingString(R"( - (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> ( - d3 floordiv 2, - (d3 mod 2) * 32 + s1 * 4 + d0 floordiv 32, - d0 mod 32 + (th_x, th_y, th_z, bl_x, bl_y, bl_z)[s0, s1, s2] -> ( + bl_x floordiv 2, + (bl_x mod 2) * 32 + s1 * 4 + th_x floordiv 32, + th_x mod 32 ), domain: - d0 in [0, 127], - d1 in [0, 0], - d2 in [0, 0], - d3 in [0, 199], - d4 in [0, 0], - d5 in [0, 0], + th_x in [0, 127], + th_y in [0, 0], + th_z in [0, 0], + bl_x in [0, 199], + bl_y in [0, 0], + bl_z in [0, 0], s0 in [0, 0], s1 in [0, 7], @@ -142,18 +142,18 @@ TEST_F(TransposeTest, ThreadIndexing201_SimplifiedTo021) { EXPECT_THAT( ToString(*fusion->ComputeThreadIdToInputIndexing(0, 0, &mlir_context)), MatchIndexingString(R"( - (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> ( + (th_x, th_y, th_z, bl_x, bl_y, bl_z)[s0, s1, s2] -> ( 0, - d3 * 32 + s1 * 4 + d0 floordiv 32, - d0 mod 32 + bl_x * 32 + s1 * 4 + th_x floordiv 32, + th_x mod 32 ), domain: - d0 in [0, 127], - d1 in [0, 0], - d2 in [0, 0], - d3 in [0, 199], - d4 in [0, 0], - d5 in [0, 0], + th_x in [0, 127], + th_y in [0, 0], + th_z in [0, 0], + bl_x in [0, 199], + bl_y in [0, 0], + bl_z in [0, 0], s0 in [0, 0], s1 in [0, 7], @@ -162,18 +162,18 @@ TEST_F(TransposeTest, ThreadIndexing201_SimplifiedTo021) { EXPECT_THAT( ToString(*fusion->ComputeThreadIdToOutputIndexing(0, &mlir_context)), MatchIndexingString(R"( - (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> ( + (th_x, th_y, th_z, bl_x, bl_y, bl_z)[s0, s1, s2] -> ( 0, - d0 floordiv 32 + s1 * 4, - d3 * 32 + d0 mod 32 + th_x floordiv 32 + s1 * 4, + bl_x * 32 + th_x mod 32 ), domain: - d0 in [0, 127], - d1 in [0, 0], - d2 in [0, 0], - d3 in [0, 199], - d4 in [0, 0], - d5 in [0, 0], + th_x in [0, 127], + th_y in [0, 0], + th_z in [0, 0], + bl_x in [0, 199], + bl_y in [0, 0], + bl_z in [0, 0], s0 in [0, 0], s1 in [0, 7], @@ -206,42 +206,42 @@ TEST_F(TransposeTest, ThreadIndexingPartialBlock) { EXPECT_THAT( ToString(*fusion->ComputeThreadIdToInputIndexing(0, 0, &mlir_context)), MatchIndexingString(R"( - (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> ( - d0 floordiv 32 + s0 * 4, - d3, - d0 mod 32 + (th_x, th_y, th_z, bl_x, bl_y, bl_z)[s0, s1, s2] -> ( + th_x floordiv 32 + s0 * 4, + bl_x, + th_x mod 32 ), domain: - d0 in [0, 127], - d1 in [0, 0], - d2 in [0, 0], - d3 in [0, 1], - d4 in [0, 0], - d5 in [0, 0], + th_x in [0, 127], + th_y in [0, 0], + th_z in [0, 0], + bl_x in [0, 1], + bl_y in [0, 0], + bl_z in [0, 0], s0 in [0, 5], s1 in [0, 0], s2 in [0, 0], - d0 mod 32 in [0, 23] + th_x mod 32 in [0, 23] )")); EXPECT_THAT( ToString(*fusion->ComputeThreadIdToOutputIndexing(0, &mlir_context)), MatchIndexingString(R"( - (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> ( - d0 floordiv 32 + s0 * 4, - d3, - d0 mod 32 + (th_x, th_y, th_z, bl_x, bl_y, bl_z)[s0, s1, s2] -> ( + th_x floordiv 32 + s0 * 4, + bl_x, + th_x mod 32 ), domain: - d0 in [0, 127], - d1 in [0, 0], - d2 in [0, 0], - d3 in [0, 1], - d4 in [0, 0], - d5 in [0, 0], + th_x in [0, 127], + th_y in [0, 0], + th_z in [0, 0], + bl_x in [0, 1], + bl_y in [0, 0], + bl_z in [0, 0], s0 in [0, 5], s1 in [0, 0], s2 in [0, 0], - d0 mod 32 in [0, 23] + th_x mod 32 in [0, 23] )")); } @@ -302,17 +302,17 @@ TEST_F(TransposeTest, ThreadIndexingSideOutput) { EXPECT_THAT( ToString(*fusion->ComputeThreadIdToInputIndexing(1, 0, &mlir_context)), MatchIndexingString(R"( - (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> ( - d3 floordiv 2, - d0 floordiv 32 + s1 * 4 + (th_x, th_y, th_z, bl_x, bl_y, bl_z)[s0, s1, s2] -> ( + bl_x floordiv 2, + th_x floordiv 32 + s1 * 4 ), domain: - d0 in [0, 127], - d1 in [0, 0], - d2 in [0, 0], - d3 in [0, 199], - d4 in [0, 0], - d5 in [0, 0], + th_x in [0, 127], + th_y in [0, 0], + th_z in [0, 0], + bl_x in [0, 199], + bl_y in [0, 0], + bl_z in [0, 0], s0 in [0, 0], s1 in [0, 7], @@ -321,18 +321,18 @@ TEST_F(TransposeTest, ThreadIndexingSideOutput) { EXPECT_THAT( ToString(*fusion->ComputeThreadIdToOutputIndexing(1, &mlir_context)), MatchIndexingString(R"( - (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> ( - d3 floordiv 2, - d0 floordiv 32 + s1 * 4, - (d3 mod 2) * 32 + d0 mod 32 + (th_x, th_y, th_z, bl_x, bl_y, bl_z)[s0, s1, s2] -> ( + bl_x floordiv 2, + th_x floordiv 32 + s1 * 4, + (bl_x mod 2) * 32 + th_x mod 32 ), domain: - d0 in [0, 127], - d1 in [0, 0], - d2 in [0, 0], - d3 in [0, 199], - d4 in [0, 0], - d5 in [0, 0], + th_x in [0, 127], + th_y in [0, 0], + th_z in [0, 0], + bl_x in [0, 199], + bl_y in [0, 0], + bl_z in [0, 0], s0 in [0, 0], s1 in [0, 7], diff --git a/third_party/xla/xla/service/gpu/fusions/reduction_mlir.cc b/third_party/xla/xla/service/gpu/fusions/reduction_mlir.cc index 7897966a91409a..f25caaf155c225 100644 --- a/third_party/xla/xla/service/gpu/fusions/reduction_mlir.cc +++ b/third_party/xla/xla/service/gpu/fusions/reduction_mlir.cc @@ -404,12 +404,11 @@ IndexingMap MlirReductionFusion::GetIndexingMap( absl::Span symbol_sizes) const { auto* ctx = results.front().getContext(); auto num_groups = static_cast(reduction_heroes_.size()); - return IndexingMap{ - AffineMap::get(6, symbol_sizes.size(), results, ctx), - DimVarsFromTensorSizes( - {Product(num_threads_), 1, 1, Product(num_blocks_), num_groups, 1}), - RangeVarsFromTensorSizes(symbol_sizes), - /*rt_vars=*/{}}; + return IndexingMap{AffineMap::get(6, symbol_sizes.size(), results, ctx), + DimVarsFromGPUGrid({Product(num_threads_), 1, 1, + Product(num_blocks_), num_groups, 1}), + RangeVarsFromTensorSizes(symbol_sizes), + /*rt_vars=*/{}}; } IndexingMap MlirReductionFusion::GetThreadIndexingMap( @@ -419,9 +418,11 @@ IndexingMap MlirReductionFusion::GetThreadIndexingMap( auto affine_map = AffineMap::get(1, symbol_sizes.size(), results, results.front().getContext()); return IndexingMap{affine_map, - DimVarsFromTensorSizes({Product(num_threads_)}), + {DimVar{0, Product(num_threads_) - 1, + ToVariableName(VariableKind::kThreadX)}}, RangeVarsFromTensorSizes(symbol_sizes), - /*rt_vars=*/{}, constraints}; + /*rt_vars=*/{}, + constraints}; } LaunchDimensions MlirReductionFusion::launch_dimensions() const { diff --git a/third_party/xla/xla/service/gpu/fusions/tests/concatenate/concat_1d.hlo b/third_party/xla/xla/service/gpu/fusions/tests/concatenate/concat_1d.hlo index 5ac91b201c6168..875bb871d287e3 100644 --- a/third_party/xla/xla/service/gpu/fusions/tests/concatenate/concat_1d.hlo +++ b/third_party/xla/xla/service/gpu/fusions/tests/concatenate/concat_1d.hlo @@ -8,10 +8,10 @@ fusion { param2 = f32[300] parameter(2) ROOT concat = f32[900] concatenate(param0, param1, param2), dimensions={0} } -// CHECK-DAG: #[[MAP:.*]] = #xla_gpu.indexing_map<"(d0, d1) -> (d1 * 128 + d0) -// CHECK-DAG: #[[LOOPMAP_1:.*]] = #xla_gpu.indexing_map<"(d0, d1, d2, d3, d4, d5)[s0, s1] -> (d3 * 128 + d0) -// CHECK-DAG: #[[LOOPMAP_2:.*]] = #xla_gpu.indexing_map<"(d0, d1, d2, d3, d4, d5)[s0, s1] -> (d3 * 128 + d0 + 200) -// CHECK-DAG: #[[LOOPMAP_3:.*]] = #xla_gpu.indexing_map<"(d0, d1, d2, d3, d4, d5)[s0, s1] -> (d3 * 128 + d0 + 600) +// CHECK-DAG: #[[MAP:.*]] = #xla_gpu.indexing_map<"(th_x, bl_x) -> (bl_x * 128 + th_x) +// CHECK-DAG: #[[LOOPMAP_1:.*]] = #xla_gpu.indexing_map<"(th_x, th_y, th_z, bl_x, bl_y, bl_z)[s0, s1] -> (bl_x * 128 + th_x) +// CHECK-DAG: #[[LOOPMAP_2:.*]] = #xla_gpu.indexing_map<"(th_x, th_y, th_z, bl_x, bl_y, bl_z)[s0, s1] -> (bl_x * 128 + th_x + 200) +// CHECK-DAG: #[[LOOPMAP_3:.*]] = #xla_gpu.indexing_map<"(th_x, th_y, th_z, bl_x, bl_y, bl_z)[s0, s1] -> (bl_x * 128 + th_x + 600) // CHECK: func.func @main // CHECK-SAME: %[[ARG_0:[a-zA-Z0-9]*]]: {{[^,]*}}, diff --git a/third_party/xla/xla/service/gpu/fusions/tests/scatter/unique_indices.hlo b/third_party/xla/xla/service/gpu/fusions/tests/scatter/unique_indices.hlo index 88043829ebc8f2..8abb0d548d1c06 100644 --- a/third_party/xla/xla/service/gpu/fusions/tests/scatter/unique_indices.hlo +++ b/third_party/xla/xla/service/gpu/fusions/tests/scatter/unique_indices.hlo @@ -24,7 +24,7 @@ scatter { unique_indices=true, to_apply=add } -// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<"(d0) -> (d0 floordiv 2) +// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<"(th_x) -> (th_x floordiv 2) // CHECK-LABEL: func.func @main( // CHECK-SAME: %[[OPERAND:[a-zA-Z0-9]*]]: tensor<10x5xf32> diff --git a/third_party/xla/xla/service/gpu/fusions/transpose_mlir.cc b/third_party/xla/xla/service/gpu/fusions/transpose_mlir.cc index b7e409151541c9..cea59e095e356d 100644 --- a/third_party/xla/xla/service/gpu/fusions/transpose_mlir.cc +++ b/third_party/xla/xla/service/gpu/fusions/transpose_mlir.cc @@ -215,7 +215,7 @@ IndexingMap MlirTransposeFusion::GetSharedMemoryIndexing( dim_var_sizes[KernelFusionInterface::kIndexingMapBlockIdxDims[0]] = Product(block_counts_); return {mlir::AffineMap::get(6, 2, thread_offsets, ctx), - DimVarsFromTensorSizes(dim_var_sizes), + DimVarsFromGPUGrid(dim_var_sizes), RangeVarsFromTensorSizes({block_size_ / kNumRows, vector_size_}), {}}; } diff --git a/third_party/xla/xla/service/gpu/model/indexing_analysis.cc b/third_party/xla/xla/service/gpu/model/indexing_analysis.cc index a8f70449e10f9b..20721312b3facd 100644 --- a/third_party/xla/xla/service/gpu/model/indexing_analysis.cc +++ b/third_party/xla/xla/service/gpu/model/indexing_analysis.cc @@ -458,8 +458,8 @@ IndexingMap ComputeOutputToInputPadOpIndexingImpl( llvm::zip(output_dims, padding_low, padding_high, padding_interior)) { AffineExpr dim_expr = getAffineDimExpr(output_dim_id, mlir_context); dim_vars.push_back( - {Interval{std::max(int64_t{0}, pad_low), - std::min(output_dim - 1, output_dim - 1 - pad_high)}}); + {DimVar{std::max(int64_t{0}, pad_low), + std::min(output_dim - 1, output_dim - 1 - pad_high)}}); if (pad_interior == 0) { exprs.push_back(dim_expr - pad_low); } else { @@ -624,8 +624,8 @@ IndexingMap ComposeIndexingMapsForWindow( exprs.push_back(symbol_expr * window_config.window_dilation() + window_config.stride() * dim_expr); - dim_vars.push_back({Interval{0, output_dimensions[dim_id] - 1}}); - range_vars.push_back({Interval{0, window_config.size() - 1}}); + dim_vars.push_back({DimVar{0, output_dimensions[dim_id] - 1}}); + range_vars.push_back({RangeVar{0, window_config.size() - 1}}); } // Indexing map for pad op that pads the input. IndexingMap padded_input_indexing = ComputeOutputToInputPadOpIndexingImpl( @@ -746,8 +746,8 @@ HloInstructionIndexing ComputeOutputToInputConvolutionOpIndexing( int64_t input_group_size = kernel_shape.dimensions(dnums.kernel_input_feature_dimension()); Interval input_feature_range{0, input_group_size - 1}; - input_symbols.push_back({input_feature_range}); - kernel_symbols.push_back({input_feature_range}); + input_symbols.push_back(RangeVar{input_feature_range}); + kernel_symbols.push_back(RangeVar{input_feature_range}); // With multiple feature groups, the input feature dimension is equally split. if (convolution->feature_group_count() > 1) { @@ -768,7 +768,8 @@ HloInstructionIndexing ComputeOutputToInputConvolutionOpIndexing( output_shape.dimensions(dnums.output_batch_dimension()); AffineExpr batch_group_expr = getAffineSymbolExpr(input_symbols.size(), mlir_context); - input_symbols.push_back({{0, convolution->batch_group_count() - 1}}); + input_symbols.push_back( + RangeVar{{0, convolution->batch_group_count() - 1}}); input_exprs[dnums.input_batch_dimension()] = batch_group_expr * batch_group_size + batch_dim_expr; } else { diff --git a/third_party/xla/xla/service/gpu/model/indexing_analysis_test.cc b/third_party/xla/xla/service/gpu/model/indexing_analysis_test.cc index be577ffa434b02..0b55ac570b2a7d 100644 --- a/third_party/xla/xla/service/gpu/model/indexing_analysis_test.cc +++ b/third_party/xla/xla/service/gpu/model/indexing_analysis_test.cc @@ -668,18 +668,18 @@ TEST_F(IndexingAnalysisTest, DynamicSliceOp) { )")); EXPECT_THAT(input_indexing.ToString(), MatchIndexingString(R"( operand id = 0 - (d0, d1, d2)[s0, s1, s2] -> (d0 + s0, d1 + s1, d2 + s2), + (d0, d1, d2)[rt0, rt1, rt2] -> (d0 + rt0, d1 + rt1, d2 + rt2), domain: d0 in [0, 0], d1 in [0, 1], d2 in [0, 31], - s0 in [0, 1], + rt0 in [0, 1], hlo: %of1 = s32[] parameter(1), (d0, d1, d2) -> (), - s1 in [0, 0], + rt1 in [0, 0], hlo: %of2 = s32[] parameter(2), (d0, d1, d2) -> (), - s2 in [0, 226], + rt2 in [0, 226], hlo: %of3 = s32[] parameter(3), (d0, d1, d2) -> () operand id = 1 @@ -722,14 +722,14 @@ TEST_F(IndexingAnalysisTest, DynamicUpdateSliceOp) { d0 in [0, 19], d1 in [0, 29] operand id = 1 - (d0, d1)[s0, s1] -> (d0 - s0, d1 - s1), + (d0, d1)[rt0, rt1] -> (d0 - rt0, d1 - rt1), domain: d0 in [0, 19], d1 in [0, 29], - s0 in [0, 15], + rt0 in [0, 15], hlo: %of1 = s32[] parameter(2), (d0, d1) -> (), - s1 in [0, 20], + rt1 in [0, 20], hlo: %of2 = s32[] parameter(3), (d0, d1) -> () operand id = 2 @@ -1053,16 +1053,16 @@ TEST_F(IndexingAnalysisTest, GatherOp) { )")); EXPECT_THAT(input_indexing.ToString(), MatchIndexingString(R"( operand id = 0 - (d0, d1, d2, d3)[s0, s1] -> (d1 + s0, d2 + s1, d3), + (d0, d1, d2, d3)[rt0, rt1] -> (d1 + rt0, d2 + rt1, d3), domain: d0 in [0, 1805], d1 in [0, 6], d2 in [0, 7], d3 in [0, 3], - s0 in [0, 26], + rt0 in [0, 26], hlo: %indices = s32[1806,2]{1,0} parameter(1), (d0, d1, d2, d3) -> (d0, 0), - s1 in [0, 68], + rt1 in [0, 68], hlo: %indices = s32[1806,2]{1,0} parameter(1), (d0, d1, d2, d3) -> (d0, 1) operand id = 1 @@ -1340,20 +1340,20 @@ TEST_F(IndexingAnalysisTest, FusionOpWithDynSliceOfDynSlice) { )")); EXPECT_THAT(input_indexing.ToString(), MatchIndexingString(R"( operand id = 0 - (d0, d1)[s0, s1, s2, s3] -> (d0 + s0 + s2, d1 + s1 + s3), + (d0, d1)[rt0, rt1, rt2, rt3] -> (d0 + rt0 + rt2, d1 + rt1 + rt3), domain: d0 in [0, 24], d1 in [0, 15], - s0 in [0, 100], + rt0 in [0, 100], hlo: %of11 = s32[] parameter(1), (d0, d1) -> (), - s1 in [0, 32], + rt1 in [0, 32], hlo: %of12 = s32[] parameter(2), (d0, d1) -> (), - s2 in [0, 25], + rt2 in [0, 25], hlo: %of21 = s32[] parameter(3), (d0, d1) -> (), - s3 in [0, 16], + rt3 in [0, 16], hlo: %of22 = s32[] parameter(4), (d0, d1) -> () operand id = 1 @@ -2610,14 +2610,14 @@ TEST_F(IndexingAnalysisTest, FusionOpWithDUS) { )hlo")); EXPECT_THAT(input_indexing.ToString(), MatchIndexingString(R"( operand id = 0 - (d0, d1)[s0] -> (0, d1 + s0 - 4096), + (d0, d1)[rt0] -> (0, d1 + rt0 - 4096), domain: d0 in [0, 0], d1 in [0, 4095], - s0 in [0, 4096], + rt0 in [0, 4096], hlo: %slice = s32[1]{0} parameter(1), (d0, d1) -> (0), - d1 + s0 in [4096, 8191] + d1 + rt0 in [4096, 8191] operand id = 1 (d0, d1) -> (0), domain: diff --git a/third_party/xla/xla/service/gpu/model/indexing_map.cc b/third_party/xla/xla/service/gpu/model/indexing_map.cc index f7ec1f1f83dd76..fceebbbfc3a6d0 100644 --- a/third_party/xla/xla/service/gpu/model/indexing_map.cc +++ b/third_party/xla/xla/service/gpu/model/indexing_map.cc @@ -30,7 +30,9 @@ limitations under the License. #include #include "absl/base/optimization.h" +#include "absl/log/check.h" #include "absl/numeric/int128.h" +#include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/types/span.h" #include "llvm/ADT/DenseMap.h" @@ -783,15 +785,17 @@ SmallVector MapSymbolsToComposedSymbolsList( } // namespace static constexpr std::string_view kVarKindDefault = "default"; -static constexpr std::string_view kVarKindThreadX = "thread_x"; -static constexpr std::string_view kVarKindThreadY = "thread_y"; -static constexpr std::string_view kVarKindThreadZ = "thread_z"; -static constexpr std::string_view kVarKindBlockX = "block_x"; -static constexpr std::string_view kVarKindBlockY = "block_y"; -static constexpr std::string_view kVarKindBlockZ = "block_z"; - -std::string_view ToString(VariableKind type) { - switch (type) { +static constexpr std::string_view kVarKindThreadX = "th_x"; +static constexpr std::string_view kVarKindThreadY = "th_y"; +static constexpr std::string_view kVarKindThreadZ = "th_z"; +static constexpr std::string_view kVarKindBlockX = "bl_x"; +static constexpr std::string_view kVarKindBlockY = "bl_y"; +static constexpr std::string_view kVarKindBlockZ = "bl_z"; +static constexpr std::string_view kVarKindWarp = "warp"; +static constexpr std::string_view kVarKindWarpThread = "th_w"; + +std::string_view ToVariableName(VariableKind var_kind) { + switch (var_kind) { case VariableKind::kDefault: return kVarKindDefault; case VariableKind::kThreadX: @@ -806,23 +810,28 @@ std::string_view ToString(VariableKind type) { return kVarKindBlockY; case VariableKind::kBlockZ: return kVarKindBlockZ; + case VariableKind::kWarp: + return kVarKindWarp; + case VariableKind::kWarpThread: + return kVarKindWarpThread; } llvm_unreachable("Unknown VariableType"); } -VariableKind ToVariableType(std::string_view type_name) { - if (type_name == kVarKindDefault) return VariableKind::kDefault; - if (type_name == kVarKindThreadX) return VariableKind::kThreadX; - if (type_name == kVarKindThreadY) return VariableKind::kThreadY; - if (type_name == kVarKindThreadZ) return VariableKind::kThreadZ; - if (type_name == kVarKindBlockX) return VariableKind::kBlockX; - if (type_name == kVarKindBlockY) return VariableKind::kBlockY; - if (type_name == kVarKindBlockZ) return VariableKind::kBlockZ; - llvm_unreachable("Unknown VariableType name"); +VariableKind ToVariableType(std::string_view var_name) { + if (var_name == kVarKindThreadX) return VariableKind::kThreadX; + if (var_name == kVarKindThreadY) return VariableKind::kThreadY; + if (var_name == kVarKindThreadZ) return VariableKind::kThreadZ; + if (var_name == kVarKindBlockX) return VariableKind::kBlockX; + if (var_name == kVarKindBlockY) return VariableKind::kBlockY; + if (var_name == kVarKindBlockZ) return VariableKind::kBlockZ; + if (var_name == kVarKindWarp) return VariableKind::kWarp; + if (var_name == kVarKindWarpThread) return VariableKind::kWarpThread; + return VariableKind::kDefault; } std::ostream& operator<<(std::ostream& out, VariableKind var_type) { - out << ToString(var_type); + out << ToVariableName(var_type); return out; } @@ -966,16 +975,16 @@ Interval Interval::FloorDiv(int64_t rhs) const { } bool operator==(const DimVar& lhs, const DimVar& rhs) { - return lhs.bounds == rhs.bounds; + return lhs.bounds == rhs.bounds && lhs.name == rhs.name; } bool operator==(const RangeVar& lhs, const RangeVar& rhs) { - return lhs.range == rhs.range; + return lhs.range == rhs.range && lhs.name == rhs.name; } bool operator==(const RTVar& lhs, const RTVar& rhs) { return lhs.feasible_values == rhs.feasible_values && lhs.hlo == rhs.hlo && - lhs.map == rhs.map; + lhs.map == rhs.map && lhs.name == rhs.name; } std::vector DimVarsFromTensorSizes( @@ -983,17 +992,29 @@ std::vector DimVarsFromTensorSizes( std::vector ranges; ranges.reserve(tensor_sizes.size()); for (int64_t size : tensor_sizes) { - ranges.push_back({Interval{0, size - 1}}); + ranges.push_back(DimVar{0, size - 1}); } return ranges; } +std::vector DimVarsFromGPUGrid(absl::Span grid_sizes) { + CHECK_EQ(grid_sizes.size(), 6) + << "Grid must be 6-dimensional (th_x, th_y, th_z, bl_x, bl_y, bl_z)"; + return { + DimVar{0, grid_sizes[0] - 1, kVarKindThreadX}, + DimVar{0, grid_sizes[1] - 1, kVarKindThreadY}, + DimVar{0, grid_sizes[2] - 1, kVarKindThreadZ}, + DimVar{0, grid_sizes[3] - 1, kVarKindBlockX}, + DimVar{0, grid_sizes[4] - 1, kVarKindBlockY}, + DimVar{0, grid_sizes[5] - 1, kVarKindBlockZ}, + }; +} std::vector RangeVarsFromTensorSizes( absl::Span tensor_sizes) { std::vector ranges; ranges.reserve(tensor_sizes.size()); for (int64_t size : tensor_sizes) { - ranges.push_back({Interval{0, size - 1}}); + ranges.push_back({RangeVar{0, size - 1}}); } return ranges; } @@ -1633,6 +1654,7 @@ void IndexingMap::ResetToKnownEmpty() { } bool IndexingMap::VerifyVariableIntervals() { + // TODO: Check if the variable names are unique. return llvm::all_of(dim_vars_, [](const DimVar& dim_var) { return dim_var.bounds.IsFeasible(); diff --git a/third_party/xla/xla/service/gpu/model/indexing_map.h b/third_party/xla/xla/service/gpu/model/indexing_map.h index 5751cb4c886d10..1522a42e63f242 100644 --- a/third_party/xla/xla/service/gpu/model/indexing_map.h +++ b/third_party/xla/xla/service/gpu/model/indexing_map.h @@ -50,10 +50,14 @@ enum class VariableKind : char { kThreadX, kThreadY, kThreadZ, + // GPU warp ID. + kWarp, + // GPU thread ID in the warp. + kWarpThread }; -std::string_view ToString(VariableKind type); -VariableKind ToVariableType(std::string_view type_name); +std::string_view ToVariableName(VariableKind var_kind); +VariableKind ToVariableType(std::string_view var_name); std::ostream& operator<<(std::ostream& out, VariableKind var_type); // Interval represents a closed interval [lower_bound, upper_bound]. @@ -201,7 +205,14 @@ class RangeEvaluator { // Dimension variable represents a dimension of a tensor or a GPU grid. // Dimensions correspond to the dimension parameter of `affine_map_`. struct DimVar { + DimVar() = default; + explicit DimVar(Interval bounds, llvm::StringRef name = "") + : bounds(bounds), name(name) {} + DimVar(int64_t lb, int64_t ub, llvm::StringRef name = "") + : DimVar(Interval{lb, ub}, name) {} + Interval bounds; + std::string name = ""; }; bool operator==(const DimVar& lhs, const DimVar& rhs); inline bool operator!=(const DimVar& lhs, const DimVar& rhs) { @@ -222,7 +233,14 @@ inline size_t hash_value(const DimVar& dim_var) { // tensor. RangeSymbol variables correspond to the front portion of the // symbols in `affine_map_`. struct RangeVar { + RangeVar() = default; + explicit RangeVar(Interval range, llvm::StringRef name = "") + : range(range), name(name) {} + RangeVar(int64_t lb, int64_t ub, llvm::StringRef name = "") + : RangeVar(Interval{lb, ub}, name) {} + Interval range; + std::string name = ""; }; bool operator==(const RangeVar& lhs, const RangeVar& rhs); inline bool operator!=(const RangeVar& lhs, const RangeVar& rhs) { @@ -248,6 +266,7 @@ struct RTVar { // the iteration space of `hlo`. It shows what element of `hlo` we need to // extract to get the runtime value for the RTVar. mlir::AffineMap map; + std::string name = ""; }; bool operator==(const RTVar& lhs, const RTVar& rhs); inline bool operator!=(const RTVar& lhs, const RTVar& rhs) { @@ -264,6 +283,8 @@ H AbslHashValue(H h, const RTVar& rt_var) { std::vector DimVarsFromTensorSizes( absl::Span tensor_sizes); +std::vector DimVarsFromGPUGrid(absl::Span grid_sizes); + std::vector RangeVarsFromTensorSizes( absl::Span tensor_sizes); diff --git a/third_party/xla/xla/service/gpu/model/indexing_map_serialization.cc b/third_party/xla/xla/service/gpu/model/indexing_map_serialization.cc index 3d6eb9bf1b1b23..bdf848a462fb17 100644 --- a/third_party/xla/xla/service/gpu/model/indexing_map_serialization.cc +++ b/third_party/xla/xla/service/gpu/model/indexing_map_serialization.cc @@ -22,11 +22,13 @@ limitations under the License. #include #include #include +#include #include #include #include "absl/log/check.h" #include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "absl/types/span.h" #include "llvm/ADT/STLExtras.h" @@ -389,20 +391,47 @@ bool ParseAffineExprsWithMLIR(ArrayRef dim_var_names, return true; } -std::string GetSymbolName(int64_t symbol_id, - absl::Span symbol_names = {}) { - if (symbol_names.empty()) { - return absl::StrCat("s", symbol_id); +std::string GetVarName(int64_t id, std::string_view name, + std::string_view prefix) { + if (!name.empty()) { + return std::string(name); } - return symbol_names.at(symbol_id); + return absl::StrFormat("%s%d", prefix, id); } -std::string GetDimensionName(int64_t dim_id, - absl::Span dim_names = {}) { - if (dim_names.empty()) { - return absl::StrCat("d", dim_id); +std::string GetDimVarName(int64_t dim_id, std::string_view dim_name = "") { + return GetVarName(dim_id, dim_name, "d"); +} + +std::string GetRangeVarName(int64_t range_id, + std::string_view range_name = "") { + return GetVarName(range_id, range_name, "s"); +} + +std::string GetRTVarName(int64_t rt_id, std::string_view rt_name = "") { + return GetVarName(rt_id, rt_name, "rt"); +} + +std::string GetAffineSymbolName( + int64_t id, absl::Span symbol_names = {}) { + if (id < symbol_names.size()) { + const auto& name = symbol_names[id]; + if (!name.empty()) { + return name; + } } - return dim_names.at(dim_id); + return absl::StrFormat("%s%d", "s", id); +} + +std::string GetAffineDimensionName( + int64_t id, absl::Span dim_names = {}) { + if (id < dim_names.size()) { + const auto& name = dim_names[id]; + if (!name.empty()) { + return name; + } + } + return absl::StrFormat("%s%d", "d", id); } void PrintAffineExprImpl(const AffineExpr affine_expr, @@ -414,12 +443,12 @@ void PrintAffineExprImpl(const AffineExpr affine_expr, case AffineExprKind::SymbolId: { unsigned symbol_id = mlir::cast(affine_expr).getPosition(); - os << GetSymbolName(symbol_id, symbol_names); + os << GetAffineSymbolName(symbol_id, symbol_names); return; } case AffineExprKind::DimId: { unsigned dim_id = mlir::cast(affine_expr).getPosition(); - os << GetDimensionName(dim_id, dim_names); + os << GetAffineDimensionName(dim_id, dim_names); return; } case AffineExprKind::Constant: @@ -590,7 +619,7 @@ std::optional ParseIndexingMap(llvm::StringRef input, } // Parse dimension variables. std::vector dim_vars; - for (auto& dim_name : dim_var_names) { + for (const auto& [dim_id, dim_name] : llvm::enumerate(dim_var_names)) { std::string var_name; Interval interval; if (!parser.ParseVarName(&var_name) || @@ -605,11 +634,14 @@ std::optional ParseIndexingMap(llvm::StringRef input, llvm::errs() << "Dimension name mismatch\n"; return std::nullopt; } - dim_vars.push_back(DimVar{interval}); + if (var_name == GetDimVarName(dim_id)) { + var_name = ""; + } + dim_vars.push_back(DimVar{interval, var_name}); } // Parse range variables. std::vector range_vars; - for (auto& symbol_var : symbol_var_names) { + for (const auto& [index, range_name] : llvm::enumerate(symbol_var_names)) { std::string var_name; Interval interval; if (!parser.ParseVarName(&var_name) || @@ -620,11 +652,14 @@ std::optional ParseIndexingMap(llvm::StringRef input, llvm::errs() << "Failed to parse RangeVar\n"; return std::nullopt; } - if (var_name != symbol_var) { + if (var_name != range_name) { llvm::errs() << "Symbol name mismatch\n"; return std::nullopt; } - range_vars.push_back(RangeVar{interval}); + if (var_name == GetRangeVarName(index)) { + var_name = ""; + } + range_vars.push_back(RangeVar{interval, var_name}); } // Parse constraints. SmallVector constraint_bounds; @@ -666,15 +701,6 @@ std::optional ParseIndexingMap(llvm::StringRef input, /*rt_vars=*/{}, constraints}; } -std::string ToString(AffineExpr affine_expr) { - return ToString(affine_expr, /*dim_names=*/{}, /*symbol_names=*/{}); -} - -std::ostream& operator<<(std::ostream& out, AffineExpr affine_expr) { - out << ToString(affine_expr); - return out; -} - std::string ToString(AffineExpr affine_expr, absl::Span dim_names, absl::Span symbol_names) { @@ -685,24 +711,12 @@ std::string ToString(AffineExpr affine_expr, return s; } -std::string ToString(AffineMap affine_map) { - int dim_count = affine_map.getNumDims(); - SmallVector dim_names; - dim_names.reserve(affine_map.getNumDims()); - for (int64_t dim_id = 0; dim_id < dim_count; ++dim_id) { - dim_names.push_back(GetDimensionName(dim_id)); - } - int symbol_count = affine_map.getNumSymbols(); - SmallVector symbol_names; - symbol_names.reserve(affine_map.getNumSymbols()); - for (int64_t symbol_id = 0; symbol_id < symbol_count; ++symbol_id) { - symbol_names.push_back(GetSymbolName(symbol_id)); - } - return ToString(affine_map, dim_names, symbol_names); +std::string ToString(AffineExpr affine_expr) { + return ToString(affine_expr, /*dim_names=*/{}, /*symbol_names=*/{}); } -std::ostream& operator<<(std::ostream& out, AffineMap affine_map) { - out << ToString(affine_map); +std::ostream& operator<<(std::ostream& out, AffineExpr affine_expr) { + out << ToString(affine_expr); return out; } @@ -731,39 +745,46 @@ std::string ToString(AffineMap affine_map, return s; } -std::string ToString(const IndexingMap& indexing_map) { - const auto& affine_map = indexing_map.GetAffineMap(); +std::string ToString(AffineMap affine_map) { int dim_count = affine_map.getNumDims(); SmallVector dim_names; dim_names.reserve(affine_map.getNumDims()); for (int64_t dim_id = 0; dim_id < dim_count; ++dim_id) { - dim_names.push_back(GetDimensionName(dim_id)); + dim_names.push_back(GetAffineDimensionName(dim_id)); } int symbol_count = affine_map.getNumSymbols(); SmallVector symbol_names; symbol_names.reserve(affine_map.getNumSymbols()); for (int64_t symbol_id = 0; symbol_id < symbol_count; ++symbol_id) { - symbol_names.push_back(GetSymbolName(symbol_id)); + symbol_names.push_back(GetAffineSymbolName(symbol_id)); } - return ToString(indexing_map, dim_names, symbol_names); + return ToString(affine_map, dim_names, symbol_names); } -std::ostream& operator<<(std::ostream& out, const IndexingMap& indexing_map) { - out << ToString(indexing_map); +std::ostream& operator<<(std::ostream& out, AffineMap affine_map) { + out << ToString(affine_map); return out; } std::string ToString(const IndexingMap& indexing_map, absl::Span dim_names, - absl::Span symbol_names) { + absl::Span range_names, + absl::Span rt_names) { std::stringstream ss; if (indexing_map.IsKnownEmpty()) { ss << "KNOWN EMPTY\n"; return ss.str(); } const auto& dim_vars = indexing_map.GetDimVars(); + CHECK_EQ(dim_names.size(), dim_vars.size()); const auto& range_vars = indexing_map.GetRangeVars(); + CHECK_EQ(range_names.size(), range_vars.size()); const auto& rt_vars = indexing_map.GetRTVars(); + CHECK_EQ(rt_names.size(), rt_vars.size()); + SmallVector symbol_names; + symbol_names.reserve(range_names.size() + rt_names.size()); + symbol_names.append(range_names.begin(), range_names.end()); + symbol_names.append(rt_names.begin(), rt_names.end()); ss << ToString(indexing_map.GetAffineMap(), dim_names, symbol_names); if (dim_vars.empty() && range_vars.empty() && rt_vars.empty()) { return ss.str(); @@ -783,10 +804,8 @@ std::string ToString(const IndexingMap& indexing_map, ss << ", "; } } - int64_t num_range_vars = range_vars.size(); for (const auto& [index, rt_var] : llvm::enumerate(rt_vars)) { - ss << GetSymbolName(num_range_vars + index, symbol_names) << " in " - << rt_var.feasible_values << ", hlo: " + ss << rt_names[index] << " in " << rt_var.feasible_values << ", hlo: " << (rt_var.hlo == nullptr ? "NULL" : rt_var.hlo->ToString()) << ", " << ToString(rt_var.map); if (--remaining_vars_to_print > 0) { @@ -807,5 +826,35 @@ std::string ToString(const IndexingMap& indexing_map, return ss.str(); } +std::string ToString(const IndexingMap& indexing_map) { + // Get variable names for DimVars. + SmallVector dim_names; + dim_names.reserve(indexing_map.GetDimensionCount()); + for (const auto& [index, dim_var] : + llvm::enumerate(indexing_map.GetDimVars())) { + dim_names.push_back(GetDimVarName(index, dim_var.name)); + } + // Get variable names for RangeVars. + SmallVector range_names; + range_names.reserve(indexing_map.GetRangeVarsCount()); + for (const auto& [index, range_var] : + llvm::enumerate(indexing_map.GetRangeVars())) { + range_names.push_back(GetRangeVarName(index, range_var.name)); + } + // Get variable names for RTVars. + SmallVector rt_names; + rt_names.reserve(indexing_map.GetRTVarsCount()); + for (const auto& [index, rt_var] : + llvm::enumerate(indexing_map.GetRTVars())) { + rt_names.push_back(GetRTVarName(index, rt_var.name)); + } + return ToString(indexing_map, dim_names, range_names, rt_names); +} + +std::ostream& operator<<(std::ostream& out, const IndexingMap& indexing_map) { + out << ToString(indexing_map); + return out; +} + } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/model/indexing_map_serialization.h b/third_party/xla/xla/service/gpu/model/indexing_map_serialization.h index f308bb16182862..14420173174d0e 100644 --- a/third_party/xla/xla/service/gpu/model/indexing_map_serialization.h +++ b/third_party/xla/xla/service/gpu/model/indexing_map_serialization.h @@ -62,7 +62,8 @@ std::string ToString(const IndexingMap& indexing_map); // Prints IndexingMap using the provided variable names. std::string ToString(const IndexingMap& indexing_map, absl::Span dim_names, - absl::Span symbol_names); + absl::Span range_names, + absl::Span rt_names); std::ostream& operator<<(std::ostream& out, const IndexingMap& indexing_map); diff --git a/third_party/xla/xla/service/gpu/model/indexing_map_serialization_test.cc b/third_party/xla/xla/service/gpu/model/indexing_map_serialization_test.cc index 98fff83cc277fd..78d573fdaf45b8 100644 --- a/third_party/xla/xla/service/gpu/model/indexing_map_serialization_test.cc +++ b/third_party/xla/xla/service/gpu/model/indexing_map_serialization_test.cc @@ -101,33 +101,17 @@ TEST_F(IndexingMapSerializationTest, AffineExprsWithParens) { // This test will be updated when the printing uses types of variables. TEST_F(IndexingMapSerializationTest, CustomNames) { - auto indexing_map_str = R"( - (th_x, bl_x)[vector_elem, reduced_dim, contracted_dim] - -> (contracted_dim, th_x + bl_x, reduced_dim, vector_elem), + ParseAndCheck(R"( + (th_x, bl_x)[s0, vector_elem, s2] -> (s2, th_x + bl_x, vector_elem, s0), domain: th_x in [0, 3], bl_x in [0, 4], - vector_elem in [0, 1], - reduced_dim in [0, 1], - contracted_dim in [0, 3], - th_x mod 4 in [0, 0], - bl_x + vector_elem in [0, 45] - )"; - auto indexing_map_golden = R"( - (d0, d1)[s0, s1, s2] -> (s2, d0 + d1, s1, s0), - domain: - d0 in [0, 3], - d1 in [0, 4], s0 in [0, 1], - s1 in [0, 1], + vector_elem in [0, 1], s2 in [0, 3], - d0 mod 4 in [0, 0], - d1 + s0 in [0, 45] - )"; - auto indexing_map = ParseIndexingMap(indexing_map_str, &mlir_context_); - ASSERT_TRUE(indexing_map.has_value()); - EXPECT_THAT(ToString(*indexing_map), - MatchIndexingString(indexing_map_golden)); + bl_x + s0 in [0, 45], + th_x mod 4 in [0, 0] + )"); } TEST_F(IndexingMapSerializationTest, AffineMapPrinterTest) { diff --git a/third_party/xla/xla/service/gpu/model/indexing_map_test.cc b/third_party/xla/xla/service/gpu/model/indexing_map_test.cc index b7e45e141b6af6..b55a319d83380d 100644 --- a/third_party/xla/xla/service/gpu/model/indexing_map_test.cc +++ b/third_party/xla/xla/service/gpu/model/indexing_map_test.cc @@ -72,20 +72,24 @@ std::vector ConvertToSTL(const llvm::SmallBitVector& bit_vector) { TEST_F(IndexingMapTest, VariableKind) { EXPECT_EQ(ToVariableType("default"), VariableKind::kDefault); - EXPECT_EQ(ToVariableType("thread_x"), VariableKind::kThreadX); - EXPECT_EQ(ToVariableType("thread_y"), VariableKind::kThreadY); - EXPECT_EQ(ToVariableType("thread_z"), VariableKind::kThreadZ); - EXPECT_EQ(ToVariableType("block_x"), VariableKind::kBlockX); - EXPECT_EQ(ToVariableType("block_y"), VariableKind::kBlockY); - EXPECT_EQ(ToVariableType("block_z"), VariableKind::kBlockZ); - - EXPECT_EQ(ToString(VariableKind::kDefault), "default"); - EXPECT_EQ(ToString(VariableKind::kThreadX), "thread_x"); - EXPECT_EQ(ToString(VariableKind::kThreadY), "thread_y"); - EXPECT_EQ(ToString(VariableKind::kThreadZ), "thread_z"); - EXPECT_EQ(ToString(VariableKind::kBlockX), "block_x"); - EXPECT_EQ(ToString(VariableKind::kBlockY), "block_y"); - EXPECT_EQ(ToString(VariableKind::kBlockZ), "block_z"); + EXPECT_EQ(ToVariableType("th_x"), VariableKind::kThreadX); + EXPECT_EQ(ToVariableType("th_y"), VariableKind::kThreadY); + EXPECT_EQ(ToVariableType("th_z"), VariableKind::kThreadZ); + EXPECT_EQ(ToVariableType("bl_x"), VariableKind::kBlockX); + EXPECT_EQ(ToVariableType("bl_y"), VariableKind::kBlockY); + EXPECT_EQ(ToVariableType("bl_z"), VariableKind::kBlockZ); + EXPECT_EQ(ToVariableType("warp"), VariableKind::kWarp); + EXPECT_EQ(ToVariableType("th_w"), VariableKind::kWarpThread); + + EXPECT_EQ(ToVariableName(VariableKind::kDefault), "default"); + EXPECT_EQ(ToVariableName(VariableKind::kThreadX), "th_x"); + EXPECT_EQ(ToVariableName(VariableKind::kThreadY), "th_y"); + EXPECT_EQ(ToVariableName(VariableKind::kThreadZ), "th_z"); + EXPECT_EQ(ToVariableName(VariableKind::kBlockX), "bl_x"); + EXPECT_EQ(ToVariableName(VariableKind::kBlockY), "bl_y"); + EXPECT_EQ(ToVariableName(VariableKind::kBlockZ), "bl_z"); + EXPECT_EQ(ToVariableName(VariableKind::kWarp), "warp"); + EXPECT_EQ(ToVariableName(VariableKind::kWarpThread), "th_w"); } TEST_F(IndexingMapTest, VerifyDimensions) { @@ -120,21 +124,20 @@ TEST_F(IndexingMapTest, RTVar) { /*instr=*/nullptr, zero_dim_map})}; IndexingMap indexing_map( - ParseAffineMap("(d0, d1)[s0, s1, s2] -> (d1, d0, s0 + s1, s1)", + ParseAffineMap("(d0, d1)[range, rt0, rt1] -> (d1, d0, range + rt0, rt1)", &mlir_context_), - {DimVar{{0, 99}}, DimVar{{0, 43}}}, {RangeVar{{-99, 99}}}, + {DimVar{0, 99, "d0"}, DimVar{0, 43, "d1"}}, {RangeVar{-99, 99, "range"}}, std::move(rt_vars)); - EXPECT_THAT(ToString(indexing_map, {"d0", "d1"}, {"range", "rt_0", "rt_1"}), - MatchIndexingString(R"( - (d0, d1)[range, rt_0, rt_1] -> (d1, d0, range + rt_0, rt_0), + EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( + (d0, d1)[range, rt0, rt1] -> (d1, d0, range + rt0, rt1), domain: d0 in [0, 99], d1 in [0, 43], range in [-99, 99], - rt_0 in [0, 2], + rt0 in [0, 2], hlo: NULL, (d0, d1) -> (), - rt_1 in [0, 7], + rt1 in [0, 7], hlo: NULL, (d0, d1) -> () )")); @@ -177,14 +180,12 @@ TEST_F(IndexingMapTest, Composition_Permutation) { s0 in [0, 1], s1 in [0, 1] )"); - IndexingMap consumer = Parse(R"( (d0)[s0] -> (d0, s0), domain: d0 in [0, 3], s0 in [0, 3] )"); - auto composed = ComposeIndexingMaps(consumer, producer); EXPECT_THAT(composed, MatchIndexingMap(R"( (d0)[s0, s1, s2] -> (s2, d0, s1, s0), @@ -276,36 +277,35 @@ TEST_F(IndexingMapTest, Composition_RTVar) { auto zero_dim_map = AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0, &mlir_context_); std::vector rt_vars{ - RTVar{Interval{0, 0}, - /*instr=*/nullptr, zero_dim_map}, + RTVar{Interval{0, 0}, /*instr=*/nullptr, zero_dim_map}, RTVar({Interval{0, 1}, /*instr=*/nullptr, zero_dim_map}), RTVar({Interval{0, 226}, /*instr=*/nullptr, zero_dim_map})}; IndexingMap producer( - ParseAffineMap("(d0, d1, d2)[s0, s1, s2] -> (d0 + s0, d1 + s1, d2 + s2)", - &mlir_context_), + ParseAffineMap( + "(d0, d1, d2)[rt0, rt1, rt2] -> (d0 + rt0, d1 + rt1, d2 + rt2)", + &mlir_context_), {DimVar{{0, 0}}, DimVar{{0, 1}}, DimVar{{0, 226}}}, {}, std::move(rt_vars)); IndexingMap consumer( - ParseAffineMap("(d0, d1)[s0] -> (0, d1, s0)", &mlir_context_), - {DimVar{{0, 0}}, DimVar{{0, 1}}}, {RangeVar{0, 31}}, {}); + ParseAffineMap("(d0, d1)[s] -> (0, d1, s)", &mlir_context_), + {DimVar{0, 0}, DimVar{0, 1}}, {RangeVar{0, 31, "s"}}, {}); auto composed = ComposeIndexingMaps(consumer, producer); - EXPECT_THAT(ToString(composed, {"d0", "d1"}, {"s", "rt_0", "rt_1", "rt_2"}), - MatchIndexingString(R"( - (d0, d1)[s, rt_0, rt_1, rt_2] -> (rt_0, d1 + rt_1, s + rt_2), + EXPECT_THAT(ToString(composed), MatchIndexingString(R"( + (d0, d1)[s, rt0, rt1, rt2] -> (rt0, d1 + rt1, s + rt2), domain: d0 in [0, 0], d1 in [0, 1], s in [0, 31], - rt_0 in [0, 0], + rt0 in [0, 0], hlo: NULL, (d0, d1) -> (), - rt_1 in [0, 1], + rt1 in [0, 1], hlo: NULL, (d0, d1) -> (), - rt_2 in [0, 226], + rt2 in [0, 226], hlo: NULL, (d0, d1) -> () )")); @@ -318,22 +318,20 @@ TEST_F(IndexingMapTest, Composition_OnlyRTVars) { IndexingMap producer( ParseAffineMap("(d0, d1)[s0, s1] -> (d0 + s0, d1 + 4 * s1)", &mlir_context_), - {DimVar{{0, 24}}, DimVar{{0, 15}}}, {}, - {RTVar({Interval{0, 2}, /*instr=*/nullptr, zero_dim_map}), - RTVar({Interval{0, 1}, /*instr=*/nullptr, zero_dim_map})}); + {DimVar{0, 24}, DimVar{0, 15}}, {}, + {RTVar{Interval{0, 2}, /*instr=*/nullptr, zero_dim_map, "ps_0"}, + RTVar{Interval{0, 1}, /*instr=*/nullptr, zero_dim_map, "ps_1"}}); std::vector consumer_rt_vars; IndexingMap consumer( ParseAffineMap("(d0, d1)[s0, s1] -> (d0 + 2 * s0, d1 + 3 * s1)", &mlir_context_), - {DimVar{{0, 24}}, DimVar{{0, 15}}}, {}, - {RTVar({Interval{0, 25}, /*instr=*/nullptr, zero_dim_map}), - RTVar({Interval{0, 16}, /*instr=*/nullptr, zero_dim_map})}); + {DimVar{0, 24}, DimVar{0, 15}}, {}, + {RTVar{Interval{0, 25}, /*instr=*/nullptr, zero_dim_map, "cs_0"}, + RTVar{Interval{0, 16}, /*instr=*/nullptr, zero_dim_map, "cs_1"}}); auto composed = ComposeIndexingMaps(consumer, producer); - EXPECT_THAT( - ToString(composed, {"d0", "d1"}, {"ps_0", "ps_1", "cs_0", "cs_1"}), - MatchIndexingString(R"( + EXPECT_THAT(ToString(composed), MatchIndexingString(R"( (d0, d1)[ps_0, ps_1, cs_0, cs_1] -> (d0 + cs_0 * 2 + ps_0, d1 + cs_1 * 3 + ps_1 * 4), domain: @@ -600,14 +598,14 @@ TEST_F(IndexingMapTest, RemoveUnusedSymbols_ConstraintsWithRTVars) { indexing_map.RemoveUnusedSymbols(); // Symbols s0, s2, s4 will be removed and s1 and s3 will become s0 and s1. EXPECT_THAT(indexing_map, MatchIndexingMap(R"( - (d0)[s0, s1] -> (d0 * 4 + s0 + s1 - 42), + (d0)[s0, rt0] -> (d0 * 4 + s0 + rt0 - 42), domain: d0 in [0, 31], s0 in [0, 1], - s1 in [0, 3], + rt0 in [0, 3], hlo: NULL, (d0) -> (), - d0 * 4 + s0 + s1 in [24, 459] + d0 * 4 + s0 + rt0 in [24, 459] )")); }; @@ -1832,10 +1830,10 @@ TEST_F(IndexingMapTest, ReplaceConstantRTVars_PartialRTVarRemoval) { EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( - (d0)[s0] -> (d0, s0), + (d0)[rt0] -> (d0, rt0), domain: d0 in [0, 23], - s0 in [0, 512], + rt0 in [0, 512], hlo: %constant = s64[12]{0} constant({...}), (d0) -> (d0 floordiv 2) )")); @@ -1936,7 +1934,7 @@ TEST_F(IndexingMapTest, ReplaceConstantRTVars_PartiallyOptimizableAdd) { // arbitrary values that cannot be represent as an affine expression, hence // the RTVar remains in-place. IndexingMap indexing_map( - ParseAffineMap("(d0)[s0] -> (d0, s0)", &mlir_context_), + ParseAffineMap("(d0)[rt0] -> (d0, rt0)", &mlir_context_), /*dimensions=*/{{0, 11}}, /*range_vars=*/{}, {RTVar{Interval{0, 11}, @@ -1946,10 +1944,10 @@ TEST_F(IndexingMapTest, ReplaceConstantRTVars_PartiallyOptimizableAdd) { EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( - (d0)[s0] -> (d0, d0 * 2 + s0), + (d0)[rt0] -> (d0, d0 * 2 + rt0), domain: d0 in [0, 11], - s0 in [0, 11], + rt0 in [0, 11], hlo: %constant = s64[12]{0} constant({...}), (d0) -> (d0) )")); diff --git a/third_party/xla/xla/service/gpu/model/indexing_test_utils.cc b/third_party/xla/xla/service/gpu/model/indexing_test_utils.cc index d014b597dba2b6..6e3fd105be784e 100644 --- a/third_party/xla/xla/service/gpu/model/indexing_test_utils.cc +++ b/third_party/xla/xla/service/gpu/model/indexing_test_utils.cc @@ -348,13 +348,13 @@ absl::Status VerifyExprsAreIdentical( std::vector dims; dims.reserve(dimension_ranges.size()); for (const auto& interval : dimension_ranges) { - dims.push_back({interval}); + dims.push_back(DimVar{interval}); } std::vector symbols; symbols.reserve(symbol_ranges.size()); for (const auto& interval : symbol_ranges) { - symbols.push_back({interval}); + symbols.push_back(RangeVar{interval}); } IndexingMap map(mlir::AffineMap::get(dimension_ranges.size(), diff --git a/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis_test.cc b/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis_test.cc index d9607223fae319..baadbc05f0c496 100644 --- a/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis_test.cc +++ b/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis_test.cc @@ -961,14 +961,14 @@ ENTRY main { /*tile_sizes=*/{1, 1, 32}, /*tile_strides=*/{0, 1, 1}, /*tile_offsets_indexing=*/R"( - (d0, d1)[s0, s1] -> (s0, d1, s1), + (d0, d1)[rt0, rt1] -> (rt0, d1, rt1), domain: d0 in [0, 0], d1 in [0, 1], - s0 in [0, 1], + rt0 in [0, 1], hlo: %of1 = s32[] parameter(1), (d0, d1, d2) -> (), - s1 in [0, 226], + rt1 in [0, 226], hlo: %of3 = s32[] parameter(3), (d0, d1, d2) -> () )")); From ac67966aa0f838fb9b255ab24bb1b247fcdd7272 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 30 Sep 2024 08:18:14 -0700 Subject: [PATCH 425/483] Support SPMD in XlaBroadcast pass PiperOrigin-RevId: 680583117 --- ...stribution_spmd_clustering_end_to_end.mlir | 41 ++++++++++++------ ...ght_distribution_spmd_mlir_end_to_end.mlir | 43 ++++++++++++------- .../mlir/tensorflow/tests/xla_broadcast.mlir | 33 +++++++++++++- .../tf2xla/internal/passes/xla_broadcast.cc | 5 --- 4 files changed, 87 insertions(+), 35 deletions(-) diff --git a/tensorflow/compiler/mlir/tensorflow/tests/ici_weight_distribution_spmd_clustering_end_to_end.mlir b/tensorflow/compiler/mlir/tensorflow/tests/ici_weight_distribution_spmd_clustering_end_to_end.mlir index 3db8828faa3dcb..92520a85d73886 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/ici_weight_distribution_spmd_clustering_end_to_end.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/ici_weight_distribution_spmd_clustering_end_to_end.mlir @@ -4,20 +4,37 @@ // CHECK %0 = "tf.TPUCompilationResult"() {_tpu_compilation_status = "cluster__train_helper", device = ""} : () -> tensor // CHECK %1 = "tf.ReadVariableOp"(%arg3) : (tensor<*x!tf_type.resource>>) -> tensor<128x1024xf32> // CHECK %2 = "tf.ReadVariableOp"(%arg4) : (tensor<*x!tf_type.resource>>) -> tensor<1024xf32> -// CHECK %3:2 = tf_device.replicate {n = 2 : i32} { -// CHECK %6 = "tf_device.cluster_func"(%1, %2) <{func = @_func}> {_dynamic_arg_index = [], _has_manual_control_dependencies = true, _replication_info = "cluster__train_helper", _xla_compile_device_type = "TPU", allow_soft_placement = false, computation_shape = [], device = "", device_assignment = [0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 1, 0, 1, 0, 1, 1, 1, 0], host_compute_core = [], input_sharding_configuration = ["\08\03\1A\01\04\22\04\00\01\02\03", ""], num_cores_per_replica = 4 : i64, output_sharding_configuration = [""], padding_map = [], step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", topology = "\0A\04\02\02\02\01\10\04\18\02\22 \00\00\00\00\00\01\00\00\01\00\00\00\01\01\00\00\00\00\01\00\00\01\01\00\01\00\01\00\01\01\01\00*\02\08\01", tpu_compile_options_proto = "", use_spmd_for_xla_partitioning = true, use_tpu = true} : (tensor<128x1024xf32>, tensor<1024xf32>) -> tensor<*xf32> -// CHECK tf_device.return %6 : tensor<*xf32> +// CHECK %cst = "tf.Const"() <{value = dense<0.000000e+00> : tensor}> {ici_weight_distribution_mlir_bridge_marker = true} : () -> tensor +// CHECK %cst_0 = "tf.Const"() <{value = dense<[128, 1024]> : tensor<2xi64>}> {ici_weight_distribution_mlir_bridge_marker = true} : () -> tensor<2xi64> +// CHECK %3 = "tf.Fill"(%cst_0, %cst) {ici_weight_distribution_mlir_bridge_marker = true} : (tensor<2xi64>, tensor) -> tensor<128x1024xf32> +// CHECK %cst_1 = "tf.Const"() <{value = dense<0.000000e+00> : tensor}> {ici_weight_distribution_mlir_bridge_marker = true} : () -> tensor +// CHECK %cst_2 = "tf.Const"() <{value = dense<1024> : tensor<1xi64>}> {ici_weight_distribution_mlir_bridge_marker = true} : () -> tensor<1xi64> +// CHECK %4 = "tf.Fill"(%cst_2, %cst_1) {ici_weight_distribution_mlir_bridge_marker = true} : (tensor<1xi64>, tensor) -> tensor<1024xf32> +// CHECK %5:2 = tf_device.replicate([%1, %3] as %arg22: tensor<128x1024xf32>, [%2, %4] as %arg23: tensor<1024xf32>) {n = 2 : i32} { +// CHECK %8 = "tf_device.launch"() <{device = "TPU_REPLICATED_HOST_0"}> ({ +// CHECK %11 = "tf.Identity"(%arg22) {ici_weight_distribution_mlir_bridge_marker = true} : (tensor<128x1024xf32>) -> tensor<128x1024xf32> +// CHECK tf_device.return %11 : tensor<128x1024xf32> +// CHECK }) {ici_weight_distribution_mlir_bridge_marker = true} : () -> tensor<128x1024xf32> +// CHECK %9 = "tf_device.launch"() <{device = "TPU_REPLICATED_HOST_0"}> ({ +// CHECK %11 = "tf.Identity"(%arg23) {ici_weight_distribution_mlir_bridge_marker = true} : (tensor<1024xf32>) -> tensor<1024xf32> +// CHECK tf_device.return %11 : tensor<1024xf32> +// CHECK }) {ici_weight_distribution_mlir_bridge_marker = true} : () -> tensor<1024xf32> +// CHECK %10 = "tf_device.cluster_func"(%8, %9) <{func = @_func}> {_dynamic_arg_index = [], _has_manual_control_dependencies = true, _replication_info = "cluster__train_helper", _xla_compile_device_type = "TPU", allow_soft_placement = false, computation_shape = [], device = "", device_assignment = [0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 1, 0, 1, 0, 1, 1, 1, 0], host_compute_core = [], input_sharding_configuration = ["\08\03\1A\01\04\22\04\00\01\02\03", ""], num_cores_per_replica = 4 : i64, output_sharding_configuration = [""], padding_map = [], step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", topology = "\0A\04\02\02\02\01\10\04\18\02\22 \00\00\00\00\00\01\00\00\01\00\00\00\01\01\00\00\00\00\01\00\00\01\01\00\01\00\01\00\01\01\01\00*\02\08\01", tpu_compile_options_proto = "", use_spmd_for_xla_partitioning = true, use_tpu = true} : (tensor<128x1024xf32>, tensor<1024xf32>) -> tensor<*xf32> +// CHECK tf_device.return %10 : tensor<*xf32> // CHECK } -// CHECK %4 = "tf.ReadVariableOp"(%arg2) {device = ""} : (tensor<*x!tf_type.resource>>) -> tensor -// CHECK %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor -// CHECK return %5 : tensor +// CHECK %6 = "tf.ReadVariableOp"(%arg2) {device = ""} : (tensor<*x!tf_type.resource>>) -> tensor +// CHECK %7 = "tf.Identity"(%6) {device = ""} : (tensor) -> tensor +// CHECK return %7 : tensor - -// CHECK-LABEL: func.func private @_func(%arg0: tensor<128x1024xf32> {mhlo.is_same_data_across_replicas = true, mhlo.sharding = "\08\03\1A\01\04\22\04\00\01\02\03"}, %arg1: tensor<1024xf32> {mhlo.is_same_data_across_replicas = true, mhlo.sharding = ""}) -> (tensor<*xf32> {mhlo.sharding = ""}) { -// CHECK %0 = "tf.XlaSharding"(%arg0) <{_XlaSharding = "\08\03\1A\01\04\22\04\00\01\02\03", sharding = "\08\03\1A\01\04\22\04\00\01\02\03"}> {unspecified_dims = []} : (tensor<128x1024xf32>) -> tensor<128x1024xf32> -// CHECK %1 = "tf.MatMul"(%0, %arg1) : (tensor<128x1024xf32>, tensor<1024xf32>) -> tensor<*xf32> -// CHECK return %1 : tensor<*xf32> -// CHECK } +// CHECK-LABEL: func.func private @_func(%arg0: tensor<128x1024xf32> {mhlo.sharding = "\08\03\1A\01\04\22\04\00\01\02\03"}, %arg1: tensor<1024xf32> {mhlo.sharding = ""}) -> (tensor<*xf32> {mhlo.sharding = ""}) { +// CHECK %cst = "tf.Const"() <{value = dense<[[0, 1]]> : tensor<1x2xi32>}> : () -> tensor<1x2xi32> +// CHECK %0 = "tf.XlaAllReduce"(%arg0, %cst) <{mode = "CrossReplica", reduce_op = "Add"}> : (tensor<128x1024xf32>, tensor<1x2xi32>) -> tensor<128x1024xf32> +// CHECK %cst_0 = "tf.Const"() <{value = dense<[[0, 1]]> : tensor<1x2xi32>}> : () -> tensor<1x2xi32> +// CHECK %1 = "tf.XlaAllReduce"(%arg1, %cst_0) <{mode = "CrossReplica", reduce_op = "Add"}> : (tensor<1024xf32>, tensor<1x2xi32>) -> tensor<1024xf32> +// CHECK %2 = "tf.XlaSharding"(%0) <{_XlaSharding = "\08\03\1A\01\04\22\04\00\01\02\03", sharding = "\08\03\1A\01\04\22\04\00\01\02\03"}> {unspecified_dims = []} : (tensor<128x1024xf32>) -> tensor<128x1024xf32> +// CHECK %3 = "tf.MatMul"(%2, %1) : (tensor<128x1024xf32>, tensor<1024xf32>) -> tensor<*xf32> +// CHECK return %3 : tensor<*xf32> +// CHECK } module attributes {tf.devices = {"/job:tpu_host_worker/replica:0/task:0/device:CPU:0", "/job:tpu_host_worker/replica:0/task:0/device:TPU:0", "/job:tpu_host_worker/replica:0/task:0/device:TPU:1", "/job:tpu_host_worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:tpu_host_worker/replica:0/task:1/device:CPU:0", "/job:tpu_host_worker/replica:0/task:1/device:TPU:0", "/job:tpu_host_worker/replica:0/task:1/device:TPU:1", "/job:tpu_host_worker/replica:0/task:1/device:TPU_SYSTEM:0", "/job:tpu_host_worker/replica:0/task:2/device:CPU:0", "/job:tpu_host_worker/replica:0/task:2/device:TPU:0", "/job:tpu_host_worker/replica:0/task:2/device:TPU:1", "/job:tpu_host_worker/replica:0/task:2/device:TPU_SYSTEM:0", "/job:tpu_host_worker/replica:0/task:3/device:CPU:0", "/job:tpu_host_worker/replica:0/task:3/device:TPU:0", "/job:tpu_host_worker/replica:0/task:3/device:TPU:1", "/job:tpu_host_worker/replica:0/task:3/device:TPU_SYSTEM:0"}, tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 1857 : i32}} { func.func @main(%arg0: tensor {tf._user_specified_name = "steps", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}, %arg1: tensor<*x!tf_type.resource>> {tf._user_specified_name = "899", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}, %arg2: tensor<*x!tf_type.resource>> {tf._user_specified_name = "901", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}, %arg3: tensor<*x!tf_type.resource>> {tf._user_specified_name = "903", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}, %arg4: tensor<*x!tf_type.resource>> {tf._user_specified_name = "905", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}, %arg5: tensor<*x!tf_type.resource>> {tf._user_specified_name = "907", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}, %arg6: tensor<*x!tf_type.resource>> {tf._user_specified_name = "909", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}, %arg7: tensor<*x!tf_type.resource>> {tf._user_specified_name = "911", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}, %arg8: tensor<*x!tf_type.resource>> {tf._user_specified_name = "913", tf.device = "/job:tpu_host_worker/replica:0/task:1/device:CPU:0"}, %arg9: tensor<*x!tf_type.resource>> {tf._user_specified_name = "915", tf.device = "/job:tpu_host_worker/replica:0/task:2/device:CPU:0"}, %arg10: tensor<*x!tf_type.resource>> {tf._user_specified_name = "917", tf.device = "/job:tpu_host_worker/replica:0/task:3/device:CPU:0"}, %arg11: tensor<*x!tf_type.resource>> {tf._user_specified_name = "919", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}, %arg12: tensor<*x!tf_type.resource>> {tf._user_specified_name = "921", tf.device = "/job:tpu_host_worker/replica:0/task:1/device:CPU:0"}, %arg13: tensor<*x!tf_type.resource>> {tf._user_specified_name = "923", tf.device = "/job:tpu_host_worker/replica:0/task:2/device:CPU:0"}, %arg14: tensor<*x!tf_type.resource>> {tf._user_specified_name = "925", tf.device = "/job:tpu_host_worker/replica:0/task:3/device:CPU:0"}, %arg15: tensor<*x!tf_type.resource>> {tf._user_specified_name = "927", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}, %arg16: tensor<*x!tf_type.resource>> {tf._user_specified_name = "929", tf.device = "/job:tpu_host_worker/replica:0/task:1/device:CPU:0"}, %arg17: tensor<*x!tf_type.resource>> {tf._user_specified_name = "931", tf.device = "/job:tpu_host_worker/replica:0/task:2/device:CPU:0"}, %arg18: tensor<*x!tf_type.resource>> {tf._user_specified_name = "933", tf.device = "/job:tpu_host_worker/replica:0/task:3/device:CPU:0"}, %arg19: tensor<*x!tf_type.resource>> {tf._user_specified_name = "935", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}, %arg20: tensor<*x!tf_type.resource>> {tf._user_specified_name = "937", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}, %arg21: tensor<*x!tf_type.resource>> {tf._user_specified_name = "939", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}) -> tensor<*xi64> attributes {allow_soft_placement = false, tf.entry_function = {control_outputs = "", inputs = "steps,unknown,unknown_0,unknown_1,unknown_2,unknown_3,unknown_4,unknown_5,unknown_6,unknown_7,unknown_8,unknown_9,unknown_10,unknown_11,unknown_12,unknown_13,unknown_14,unknown_15,unknown_16,unknown_17,unknown_18,unknown_19", outputs = "statefulpartitionedcall_RetVal"}} { diff --git a/tensorflow/compiler/mlir/tensorflow/tests/ici_weight_distribution_spmd_mlir_end_to_end.mlir b/tensorflow/compiler/mlir/tensorflow/tests/ici_weight_distribution_spmd_mlir_end_to_end.mlir index ce7a4f02f96619..dbdad39a35eabf 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/ici_weight_distribution_spmd_mlir_end_to_end.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/ici_weight_distribution_spmd_mlir_end_to_end.mlir @@ -3,22 +3,33 @@ // CHECK-LABEL: func.func @main // CHECK: %outputs, %control = tf_executor.island wraps "tf.ReadVariableOp"(%arg3) : (tensor<*x!tf_type.resource>>) -> tensor<128x1024xf32> // CHECK: %outputs_0, %control_1 = tf_executor.island wraps "tf.ReadVariableOp"(%arg4) : (tensor<*x!tf_type.resource>>) -> tensor<1024xf32> -// CHECK: %outputs_2:5, %control_3 = tf_executor.island wraps "tf._TPUCompileMlir"() -// CHECK: %outputs_4, %control_5 = tf_executor.island wraps "tf.Identity"(%outputs_2#0) {device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"} : (tensor) -> tensor -// CHECK: %control_6 = tf_executor.island wraps "tf.TPUCompileSucceededAssert"(%outputs_4) {device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"} : (tensor) -> () -// CHECK: %outputs_7, %control_8 = tf_executor.island wraps "tf.Const"() <{value = dense<0> : tensor}> : () -> tensor -// CHECK: %outputs_9:4, %control_10 = tf_executor.island wraps "tf.Split"(%outputs_7, %outputs) {num_split = 4 : i32} : (tensor, tensor<128x1024xf32>) -> (tensor<32x1024xf32>, tensor<32x1024xf32>, tensor<32x1024xf32>, tensor<32x1024xf32>) -// CHECK: %outputs_11, %control_12 = tf_executor.island wraps "tf.TPUExecute"(%outputs_9#0, %outputs_0, %outputs_2#1) {_parallel_execution_ids = "r0:0,p0:0", device = "/job:tpu_host_worker/replica:0/task:0/device:TPU:0"} : (tensor<32x1024xf32>, tensor<1024xf32>, tensor<3x!tf_type.string>) -> tensor<*xf32> -// CHECK: %outputs_13, %control_14 = tf_executor.island wraps "tf.TPUExecute"(%outputs_9#1, %outputs_0, %outputs_2#2) {_parallel_execution_ids = "r0:0,p0:1", device = "/job:tpu_host_worker/replica:0/task:0/device:TPU:1"} : (tensor<32x1024xf32>, tensor<1024xf32>, tensor<3x!tf_type.string>) -> tensor<*xf32> -// CHECK: %outputs_15, %control_16 = tf_executor.island wraps "tf.TPUExecute"(%outputs_9#2, %outputs_0, %outputs_2#3) {_parallel_execution_ids = "r0:0,p0:2", device = "/job:tpu_host_worker/replica:0/task:1/device:TPU:0"} : (tensor<32x1024xf32>, tensor<1024xf32>, tensor<3x!tf_type.string>) -> tensor<*xf32> -// CHECK: %outputs_17, %control_18 = tf_executor.island wraps "tf.TPUExecute"(%outputs_9#3, %outputs_0, %outputs_2#4) {_parallel_execution_ids = "r0:0,p0:3", device = "/job:tpu_host_worker/replica:0/task:1/device:TPU:1"} : (tensor<32x1024xf32>, tensor<1024xf32>, tensor<3x!tf_type.string>) -> tensor<*xf32> -// CHECK: %outputs_19, %control_20 = tf_executor.island wraps "tf.TPUExecute"(%outputs_9#0, %outputs_0, %outputs_2#1) {_parallel_execution_ids = "r0:1,p0:0", device = "/job:tpu_host_worker/replica:0/task:2/device:TPU:0"} : (tensor<32x1024xf32>, tensor<1024xf32>, tensor<3x!tf_type.string>) -> tensor<*xf32> -// CHECK: %outputs_21, %control_22 = tf_executor.island wraps "tf.TPUExecute"(%outputs_9#1, %outputs_0, %outputs_2#2) {_parallel_execution_ids = "r0:1,p0:1", device = "/job:tpu_host_worker/replica:0/task:2/device:TPU:1"} : (tensor<32x1024xf32>, tensor<1024xf32>, tensor<3x!tf_type.string>) -> tensor<*xf32> -// CHECK: %outputs_23, %control_24 = tf_executor.island wraps "tf.TPUExecute"(%outputs_9#2, %outputs_0, %outputs_2#3) {_parallel_execution_ids = "r0:1,p0:2", device = "/job:tpu_host_worker/replica:0/task:3/device:TPU:0"} : (tensor<32x1024xf32>, tensor<1024xf32>, tensor<3x!tf_type.string>) -> tensor<*xf32> -// CHECK: %outputs_25, %control_26 = tf_executor.island wraps "tf.TPUExecute"(%outputs_9#3, %outputs_0, %outputs_2#4) {_parallel_execution_ids = "r0:1,p0:3", device = "/job:tpu_host_worker/replica:0/task:3/device:TPU:1"} : (tensor<32x1024xf32>, tensor<1024xf32>, tensor<3x!tf_type.string>) -> tensor<*xf32> -// CHECK: %outputs_27, %control_28 = tf_executor.island wraps "tf.ReadVariableOp"(%arg2) {device = ""} : (tensor<*x!tf_type.resource>>) -> tensor -// CHECK: %outputs_29, %control_30 = tf_executor.island wraps "tf.Identity"(%outputs_27) {device = ""} : (tensor) -> tensor -// CHECK: tf_executor.fetch %outputs_29, %control, %control_1, %control_12, %control_14, %control_16, %control_18, %control_20, %control_22, %control_24, %control_26, %control_28 : tensor, !tf_executor.control, !tf_executor.control, !tf_executor.control, !tf_executor.control, !tf_executor.control, !tf_executor.control, !tf_executor.control, !tf_executor.control, !tf_executor.control, !tf_executor.control, !tf_executor.control +// CHECK: %outputs_2, %control_3 = tf_executor.island wraps "tf.Const"() <{value = dense<0.000000e+00> : tensor}> {ici_weight_distribution_mlir_bridge_marker = true} : () -> tensor +// CHECK: %outputs_4, %control_5 = tf_executor.island wraps "tf.Const"() <{value = dense<[128, 1024]> : tensor<2xi64>}> {ici_weight_distribution_mlir_bridge_marker = true} : () -> tensor<2xi64> +// CHECK: %outputs_6, %control_7 = tf_executor.island wraps "tf.Fill"(%outputs_4, %outputs_2) {ici_weight_distribution_mlir_bridge_marker = true} : (tensor<2xi64>, tensor) -> tensor<128x1024xf32> +// CHECK: %outputs_8, %control_9 = tf_executor.island wraps "tf.Const"() <{value = dense<0.000000e+00> : tensor}> {ici_weight_distribution_mlir_bridge_marker = true} : () -> tensor +// CHECK: %outputs_10, %control_11 = tf_executor.island wraps "tf.Const"() <{value = dense<1024> : tensor<1xi64>}> {ici_weight_distribution_mlir_bridge_marker = true} : () -> tensor<1xi64> +// CHECK: %outputs_12, %control_13 = tf_executor.island wraps "tf.Fill"(%outputs_10, %outputs_8) {ici_weight_distribution_mlir_bridge_marker = true} : (tensor<1xi64>, tensor) -> tensor<1024xf32> +// CHECK: %outputs_14:5, %control_15 = tf_executor.island wraps "tf._TPUCompileMlir"() +// CHECK: %outputs_16, %control_17 = tf_executor.island wraps "tf.Identity"(%outputs_14#0) {device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"} : (tensor) -> tensor +// CHECK: %control_18 = tf_executor.island wraps "tf.TPUCompileSucceededAssert"(%outputs_16) {device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"} : (tensor) -> () +// CHECK: %outputs_19, %control_20 = tf_executor.island wraps "tf.Const"() <{value = dense<0> : tensor}> {ici_weight_distribution_mlir_bridge_marker = true} : () -> tensor +// CHECK: %outputs_21, %control_22 = tf_executor.island wraps "tf.Identity"(%outputs) {_parallel_execution_ids = "r0:0", device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0", ici_weight_distribution_mlir_bridge_marker = true} : (tensor<128x1024xf32>) -> tensor<128x1024xf32> +// CHECK: %outputs_23, %control_24 = tf_executor.island wraps "tf.Identity"(%outputs_0) {_parallel_execution_ids = "r0:0", device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0", ici_weight_distribution_mlir_bridge_marker = true} : (tensor<1024xf32>) -> tensor<1024xf32> +// CHECK: %outputs_25:4, %control_26 = tf_executor.island wraps "tf.Split"(%outputs_19, %outputs_21) {_parallel_execution_ids = "r0:0", ici_weight_distribution_mlir_bridge_marker = true, num_split = 4 : i32} : (tensor, tensor<128x1024xf32>) -> (tensor<32x1024xf32>, tensor<32x1024xf32>, tensor<32x1024xf32>, tensor<32x1024xf32>) +// CHECK: %outputs_27, %control_28 = tf_executor.island wraps "tf.TPUExecute"(%outputs_25#0, %outputs_23, %outputs_14#1) {_parallel_execution_ids = "r0:0,p0:0", device = "/job:tpu_host_worker/replica:0/task:0/device:TPU:0"} : (tensor<32x1024xf32>, tensor<1024xf32>, tensor<3x!tf_type.string>) -> tensor<*xf32> +// CHECK: %outputs_29, %control_30 = tf_executor.island wraps "tf.TPUExecute"(%outputs_25#1, %outputs_23, %outputs_14#2) {_parallel_execution_ids = "r0:0,p0:1", device = "/job:tpu_host_worker/replica:0/task:0/device:TPU:1"} : (tensor<32x1024xf32>, tensor<1024xf32>, tensor<3x!tf_type.string>) -> tensor<*xf32> +// CHECK: %outputs_31, %control_32 = tf_executor.island wraps "tf.TPUExecute"(%outputs_25#2, %outputs_23, %outputs_14#3) {_parallel_execution_ids = "r0:0,p0:2", device = "/job:tpu_host_worker/replica:0/task:1/device:TPU:0"} : (tensor<32x1024xf32>, tensor<1024xf32>, tensor<3x!tf_type.string>) -> tensor<*xf32> +// CHECK: %outputs_33, %control_34 = tf_executor.island wraps "tf.TPUExecute"(%outputs_25#3, %outputs_23, %outputs_14#4) {_parallel_execution_ids = "r0:0,p0:3", device = "/job:tpu_host_worker/replica:0/task:1/device:TPU:1"} : (tensor<32x1024xf32>, tensor<1024xf32>, tensor<3x!tf_type.string>) -> tensor<*xf32> +// CHECK: %outputs_35, %control_36 = tf_executor.island wraps "tf.Identity"(%outputs_6) {_parallel_execution_ids = "r0:1", device = "/job:tpu_host_worker/replica:0/task:2/device:CPU:0", ici_weight_distribution_mlir_bridge_marker = true} : (tensor<128x1024xf32>) -> tensor<128x1024xf32> +// CHECK: %outputs_37, %control_38 = tf_executor.island wraps "tf.Identity"(%outputs_12) {_parallel_execution_ids = "r0:1", device = "/job:tpu_host_worker/replica:0/task:2/device:CPU:0", ici_weight_distribution_mlir_bridge_marker = true} : (tensor<1024xf32>) -> tensor<1024xf32> +// CHECK: %outputs_39:4, %control_40 = tf_executor.island wraps "tf.Split"(%outputs_19, %outputs_35) {_parallel_execution_ids = "r0:1", ici_weight_distribution_mlir_bridge_marker = true, num_split = 4 : i32} : (tensor, tensor<128x1024xf32>) -> (tensor<32x1024xf32>, tensor<32x1024xf32>, tensor<32x1024xf32>, tensor<32x1024xf32>) +// CHECK: %outputs_41, %control_42 = tf_executor.island wraps "tf.TPUExecute"(%outputs_39#0, %outputs_37, %outputs_14#1) {_parallel_execution_ids = "r0:1,p0:0", device = "/job:tpu_host_worker/replica:0/task:2/device:TPU:0"} : (tensor<32x1024xf32>, tensor<1024xf32>, tensor<3x!tf_type.string>) -> tensor<*xf32> +// CHECK: %outputs_43, %control_44 = tf_executor.island wraps "tf.TPUExecute"(%outputs_39#1, %outputs_37, %outputs_14#2) {_parallel_execution_ids = "r0:1,p0:1", device = "/job:tpu_host_worker/replica:0/task:2/device:TPU:1"} : (tensor<32x1024xf32>, tensor<1024xf32>, tensor<3x!tf_type.string>) -> tensor<*xf32> +// CHECK: %outputs_45, %control_46 = tf_executor.island wraps "tf.TPUExecute"(%outputs_39#2, %outputs_37, %outputs_14#3) {_parallel_execution_ids = "r0:1,p0:2", device = "/job:tpu_host_worker/replica:0/task:3/device:TPU:0"} : (tensor<32x1024xf32>, tensor<1024xf32>, tensor<3x!tf_type.string>) -> tensor<*xf32> +// CHECK: %outputs_47, %control_48 = tf_executor.island wraps "tf.TPUExecute"(%outputs_39#3, %outputs_37, %outputs_14#4) {_parallel_execution_ids = "r0:1,p0:3", device = "/job:tpu_host_worker/replica:0/task:3/device:TPU:1"} : (tensor<32x1024xf32>, tensor<1024xf32>, tensor<3x!tf_type.string>) -> tensor<*xf32> +// CHECK: %outputs_49, %control_50 = tf_executor.island wraps "tf.ReadVariableOp"(%arg2) {device = ""} : (tensor<*x!tf_type.resource>>) -> tensor +// CHECK: %outputs_51, %control_52 = tf_executor.island wraps "tf.Identity"(%outputs_49) {device = ""} : (tensor) -> tensor +// CHECK: tf_executor.fetch %outputs_51, %control, %control_1, %control_28, %control_30, %control_32, %control_34, %control_42, %control_44, %control_46, %control_48, %control_50 : tensor, !tf_executor.control, !tf_executor.control, !tf_executor.control, !tf_executor.control, !tf_executor.control, !tf_executor.control, !tf_executor.control, !tf_executor.control, !tf_executor.control, !tf_executor.control, !tf_executor.control module attributes {tf.devices = {"/job:tpu_host_worker/replica:0/task:0/device:CPU:0", "/job:tpu_host_worker/replica:0/task:0/device:TPU:0", "/job:tpu_host_worker/replica:0/task:0/device:TPU:1", "/job:tpu_host_worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:tpu_host_worker/replica:0/task:1/device:CPU:0", "/job:tpu_host_worker/replica:0/task:1/device:TPU:0", "/job:tpu_host_worker/replica:0/task:1/device:TPU:1", "/job:tpu_host_worker/replica:0/task:1/device:TPU_SYSTEM:0", "/job:tpu_host_worker/replica:0/task:2/device:CPU:0", "/job:tpu_host_worker/replica:0/task:2/device:TPU:0", "/job:tpu_host_worker/replica:0/task:2/device:TPU:1", "/job:tpu_host_worker/replica:0/task:2/device:TPU_SYSTEM:0", "/job:tpu_host_worker/replica:0/task:3/device:CPU:0", "/job:tpu_host_worker/replica:0/task:3/device:TPU:0", "/job:tpu_host_worker/replica:0/task:3/device:TPU:1", "/job:tpu_host_worker/replica:0/task:3/device:TPU_SYSTEM:0"}, tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 1857 : i32}} { func.func @main(%arg0: tensor {tf._user_specified_name = "steps", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}, %arg1: tensor<*x!tf_type.resource>> {tf._user_specified_name = "899", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}, %arg2: tensor<*x!tf_type.resource>> {tf._user_specified_name = "901", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}, %arg3: tensor<*x!tf_type.resource>> {tf._user_specified_name = "903", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}, %arg4: tensor<*x!tf_type.resource>> {tf._user_specified_name = "905", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}, %arg5: tensor<*x!tf_type.resource>> {tf._user_specified_name = "907", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}, %arg6: tensor<*x!tf_type.resource>> {tf._user_specified_name = "909", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}, %arg7: tensor<*x!tf_type.resource>> {tf._user_specified_name = "911", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}, %arg8: tensor<*x!tf_type.resource>> {tf._user_specified_name = "913", tf.device = "/job:tpu_host_worker/replica:0/task:1/device:CPU:0"}, %arg9: tensor<*x!tf_type.resource>> {tf._user_specified_name = "915", tf.device = "/job:tpu_host_worker/replica:0/task:2/device:CPU:0"}, %arg10: tensor<*x!tf_type.resource>> {tf._user_specified_name = "917", tf.device = "/job:tpu_host_worker/replica:0/task:3/device:CPU:0"}, %arg11: tensor<*x!tf_type.resource>> {tf._user_specified_name = "919", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}, %arg12: tensor<*x!tf_type.resource>> {tf._user_specified_name = "921", tf.device = "/job:tpu_host_worker/replica:0/task:1/device:CPU:0"}, %arg13: tensor<*x!tf_type.resource>> {tf._user_specified_name = "923", tf.device = "/job:tpu_host_worker/replica:0/task:2/device:CPU:0"}, %arg14: tensor<*x!tf_type.resource>> {tf._user_specified_name = "925", tf.device = "/job:tpu_host_worker/replica:0/task:3/device:CPU:0"}, %arg15: tensor<*x!tf_type.resource>> {tf._user_specified_name = "927", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}, %arg16: tensor<*x!tf_type.resource>> {tf._user_specified_name = "929", tf.device = "/job:tpu_host_worker/replica:0/task:1/device:CPU:0"}, %arg17: tensor<*x!tf_type.resource>> {tf._user_specified_name = "931", tf.device = "/job:tpu_host_worker/replica:0/task:2/device:CPU:0"}, %arg18: tensor<*x!tf_type.resource>> {tf._user_specified_name = "933", tf.device = "/job:tpu_host_worker/replica:0/task:3/device:CPU:0"}, %arg19: tensor<*x!tf_type.resource>> {tf._user_specified_name = "935", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}, %arg20: tensor<*x!tf_type.resource>> {tf._user_specified_name = "937", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}, %arg21: tensor<*x!tf_type.resource>> {tf._user_specified_name = "939", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}) -> tensor<*xi64> attributes {allow_soft_placement = false, tf.entry_function = {control_outputs = "", inputs = "steps,unknown,unknown_0,unknown_1,unknown_2,unknown_3,unknown_4,unknown_5,unknown_6,unknown_7,unknown_8,unknown_9,unknown_10,unknown_11,unknown_12,unknown_13,unknown_14,unknown_15,unknown_16,unknown_17,unknown_18,unknown_19", outputs = "statefulpartitionedcall_RetVal"}} { diff --git a/tensorflow/compiler/mlir/tensorflow/tests/xla_broadcast.mlir b/tensorflow/compiler/mlir/tensorflow/tests/xla_broadcast.mlir index a28ea04aec9308..2cc0f99e29e31f 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/xla_broadcast.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/xla_broadcast.mlir @@ -1,7 +1,7 @@ // RUN: tf-opt %s -split-input-file -tf-xla-broadcast | FileCheck %s module attributes {tf.devices = {"/job:tpu_host_worker/replica:0/task:0/device:CPU:0", "/job:tpu_host_worker/replica:0/task:0/device:TPU:0", "/job:tpu_host_worker/replica:0/task:0/device:TPU:1", "/job:tpu_host_worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:tpu_host_worker/replica:0/task:1/device:CPU:0", "/job:tpu_host_worker/replica:0/task:1/device:TPU:0", "/job:tpu_host_worker/replica:0/task:1/device:TPU:1", "/job:tpu_host_worker/replica:0/task:1/device:TPU_SYSTEM:0"}, tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 1850 : i32}} { -// CHECK-LABEL: func @move_broadcast -func.func @move_broadcast(%arg0: tensor) -> () { +// CHECK-LABEL: func @move_broadcast_non_spmd +func.func @move_broadcast_non_spmd(%arg0: tensor) -> () { // CHECK: %[[ELEM_0:.*]] = "tf.Const"() // CHECK: {ici_weight_distribution_mlir_bridge_marker = true} // CHECK-NEXT: %[[SHAPE_0:.*]] = "tf.Const"() @@ -31,3 +31,32 @@ func.func @move_broadcast(%arg0: tensor) -> () { func.return } } + +// ----- +module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0", "/job:worker/replica:0/task:0/device:TPU:1", "/job:worker/replica:0/task:0/device:TPU:2", "/job:worker/replica:0/task:0/device:TPU:3", "/job:worker/replica:0/task:0/device:TPU:4", "/job:worker/replica:0/task:0/device:TPU:5", "/job:worker/replica:0/task:0/device:TPU:6", "/job:worker/replica:0/task:0/device:TPU:7"]} { +// CHECK-LABEL: func @move_broadcast_spmd +func.func @move_broadcast_spmd(%arg0: tensor) -> () { + // CHECK: %[[ELEM_0:.*]] = "tf.Const"() + // CHECK: {ici_weight_distribution_mlir_bridge_marker = true} + // CHECK-NEXT: %[[SHAPE_0:.*]] = "tf.Const"() + // CHECK: {ici_weight_distribution_mlir_bridge_marker = true} + // CHECK-NEXT: %[[FULL_0:.*]] = "tf.Fill"(%[[SHAPE_0]], %[[ELEM_0]]) {ici_weight_distribution_mlir_bridge_marker = true} + // CHECK-NEXT: tf_device.replicate([%arg0, %[[FULL_0]], %[[FULL_0]], %[[FULL_0]]] as %[[REPVAR:.*]]: tensor) {n = 4 : i32} { + // CHECK-NEXT: %[[ID:.*]] = "tf_device.launch"() <{device = "TPU_REPLICATED_HOST_0"}> ({ + // CHECK-NEXT: %[[IDINSIDE:.*]] = "tf.Identity"(%[[REPVAR]]) {ici_weight_distribution_mlir_bridge_marker = true} : (tensor) -> tensor + // CHECK-NEXT: tf_device.return %[[IDINSIDE]] : tensor + // CHECK-NEXT: }) {ici_weight_distribution_mlir_bridge_marker = true} : () -> tensor + // CHECK-NEXT: "tf_device.cluster"() ({ + // CHECK-NEXT: %[[GROUP:.*]] = "tf.Const"() + // CHECK-SAME: [0, 1, 2, 3] + // CHECK-NEXT: %[[REDUCED:.*]] = "tf.XlaAllReduce"(%[[ID]], %[[GROUP]]) <{mode = "CrossReplica", reduce_op = "Add"}> : (tensor, tensor<1x4xi32>) -> tensor + // CHECK-NEXT: "tf.OpA"(%[[REDUCED]]) : (tensor) -> () + tf_device.replicate {n = 4 : i32} { + "tf_device.cluster"() ({ + "tf.OpA"(%arg0) : (tensor) -> () + tf_device.return + }) {allow_soft_placement = false, computation_shape = [], device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1], host_compute_core = [], num_cores_per_replica = 2 : i64, num_replicas = 4 : i64, padding_map = [], step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", topology = "\0A\04\02\02\01\02\10\01\18\08\22 \00\00\00\00\00\00\00\01\01\00\00\00\01\00\00\01\00\01\00\00\00\01\00\01\01\01\00\00\01\01\00\01*\02\08\01", tpu_compile_options_proto = "", use_spmd_for_xla_partitioning = true, use_tpu = true} : () -> () + } + func.return +} +} diff --git a/tensorflow/compiler/mlir/tf2xla/internal/passes/xla_broadcast.cc b/tensorflow/compiler/mlir/tf2xla/internal/passes/xla_broadcast.cc index 5e0ec8a50da525..80ba50fcd898e8 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/passes/xla_broadcast.cc +++ b/tensorflow/compiler/mlir/tf2xla/internal/passes/xla_broadcast.cc @@ -302,11 +302,6 @@ LogicalResult MoveAllBroadcastsToCluster(ClusterOp cluster, if (!num_cores_per_replica_attr) return cluster.emitOpError( CreateMissingAttributeMsg(tensorflow::kNumCoresPerReplicaAttr)); - int num_cores_per_replica = num_cores_per_replica_attr.getInt(); - - // TODO(b/329483850): Support spmd ICI weight distribution so when num of core - // per replica > 1, it does not need to be skipped. - if (num_cores_per_replica != 1) return success(); llvm::SetVector bcasts; cluster->walk([&](Operation* op) { From 3244e6cffeee9bc93659c610c31e670c5c0e53c6 Mon Sep 17 00:00:00 2001 From: Penporn Koanantakool Date: Mon, 30 Sep 2024 08:31:34 -0700 Subject: [PATCH 426/483] [tsl] Add functions to check if a CPU is x86/aarch64. PiperOrigin-RevId: 680587338 --- .../xla/third_party/tsl/tsl/platform/BUILD | 11 ++++++ .../third_party/tsl/tsl/platform/cpu_info.h | 19 ++++++++++ .../tsl/tsl/platform/cpu_info_test.cc | 36 +++++++++++++++++++ 3 files changed, 66 insertions(+) create mode 100644 third_party/xla/third_party/tsl/tsl/platform/cpu_info_test.cc diff --git a/third_party/xla/third_party/tsl/tsl/platform/BUILD b/third_party/xla/third_party/tsl/tsl/platform/BUILD index dfe09b5a36e1b9..f7d995e3e4065f 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/BUILD +++ b/third_party/xla/third_party/tsl/tsl/platform/BUILD @@ -100,6 +100,17 @@ cc_library( ], ) +tsl_cc_test( + name = "cpu_info_test", + size = "small", + srcs = ["cpu_info_test.cc"], + deps = [ + ":platform_port", + ":test", + ":test_main", + ], +) + cc_library( name = "criticality", compatible_with = get_compatible_with_portable(), diff --git a/third_party/xla/third_party/tsl/tsl/platform/cpu_info.h b/third_party/xla/third_party/tsl/tsl/platform/cpu_info.h index 68506b1d34ae8e..c8d3903ffa6f33 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/cpu_info.h +++ b/third_party/xla/third_party/tsl/tsl/platform/cpu_info.h @@ -21,6 +21,7 @@ limitations under the License. // TODO(ahentz): This is not strictly required here but, for historical // reasons, many people depend on cpu_info.h in order to use kLittleEndian. #include "tsl/platform/byte_order.h" +#include "tsl/platform/platform.h" #if defined(_MSC_VER) // included so __cpuidex function is available for GETCPUID on Windows @@ -150,6 +151,24 @@ bool TestAarch64CPU(Aarch64CPU cpu); // Checks CPU registers to return hardware capabilities. bool TestCPUFeature(CPUFeature feature); +// Checks whether the current processor is x86. +constexpr bool IsX86CPU() { +#ifdef PLATFORM_IS_X86 + return true; +#else + return false; +#endif +} + +// Checks whether the current processor is aarch64. +constexpr bool IsAarch64CPU() { +#if defined(PLATFORM_IS_ARM64) && !defined(__APPLE__) && !defined(__OpenBSD__) + return true; +#else + return false; +#endif +} + // Returns CPU Vendor string (i.e. 'GenuineIntel', 'AuthenticAMD', etc.) std::string CPUVendorIDString(); diff --git a/third_party/xla/third_party/tsl/tsl/platform/cpu_info_test.cc b/third_party/xla/third_party/tsl/tsl/platform/cpu_info_test.cc new file mode 100644 index 00000000000000..dbef5a57f47397 --- /dev/null +++ b/third_party/xla/third_party/tsl/tsl/platform/cpu_info_test.cc @@ -0,0 +1,36 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software + +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tsl/platform/cpu_info.h" + +#include "tsl/platform/test.h" + +namespace tsl { + +TEST(CPUInfo, CommonX86CPU) { + // CPUs from 1999 onwards support SSE. + if (port::TestCPUFeature(port::CPUFeature::SSE)) { + EXPECT_TRUE(port::IsX86CPU()); + } +} + +TEST(CPUInfo, Aarch64NeoverseV1CPU) { + if (port::TestAarch64CPU(port::Aarch64CPU::ARM_NEOVERSE_V1)) { + EXPECT_TRUE(port::IsAarch64CPU()); + } +} + +} // namespace tsl From 968519d7a48921be9f07deca6100670cf2f41e1b Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 30 Sep 2024 09:00:07 -0700 Subject: [PATCH 427/483] Support multiple floating point types in client library test base PiperOrigin-RevId: 680596554 --- third_party/xla/xla/literal_util.cc | 10 +++ third_party/xla/xla/literal_util.h | 2 + third_party/xla/xla/tests/BUILD | 1 + .../xla/xla/tests/client_library_test_base.cc | 90 +++++++++---------- .../xla/xla/tests/client_library_test_base.h | 56 ++++++------ .../xla/xla/tests/reduce_window_test.cc | 2 +- 6 files changed, 84 insertions(+), 77 deletions(-) diff --git a/third_party/xla/xla/literal_util.cc b/third_party/xla/xla/literal_util.cc index 2330aca215483b..503a746bb9ac6f 100644 --- a/third_party/xla/xla/literal_util.cc +++ b/third_party/xla/xla/literal_util.cc @@ -254,6 +254,16 @@ void SetScalarAtIndexImpl(MutableLiteralBase& literal, return ConvertType(f32_literal); } +/* static */ Literal LiteralUtil::ConvertF32ToF8E5M2( + const LiteralSlice& f32_literal) { + return ConvertType(f32_literal); +} + +/* static */ Literal LiteralUtil::ConvertF32ToF8E4M3FN( + const LiteralSlice& f32_literal) { + return ConvertType(f32_literal); +} + /* static */ Literal LiteralUtil::ConvertF32ToBF16( const LiteralSlice& f32_literal) { return ConvertType(f32_literal); diff --git a/third_party/xla/xla/literal_util.h b/third_party/xla/xla/literal_util.h index 1048682e2d5f4e..f43d1dbeddffa7 100644 --- a/third_party/xla/xla/literal_util.h +++ b/third_party/xla/xla/literal_util.h @@ -244,6 +244,8 @@ class LiteralUtil { static Literal ConvertBF16ToF64(const LiteralSlice& bf16_literal); static Literal ConvertF32ToF8E4M3FNUZ(const LiteralSlice& f32_literal); static Literal ConvertF32ToF8E5M2FNUZ(const LiteralSlice& f32_literal); + static Literal ConvertF32ToF8E5M2(const LiteralSlice& f32_literal); + static Literal ConvertF32ToF8E4M3FN(const LiteralSlice& f32_literal); static Literal ConvertF32ToBF16(const LiteralSlice& f32_literal); static Literal ConvertF32ToS8(const LiteralSlice& f32_literal); static Literal ConvertF32ToF64(const LiteralSlice& f32_literal); diff --git a/third_party/xla/xla/tests/BUILD b/third_party/xla/xla/tests/BUILD index 3da7cbf0818787..a48ff93d498731 100644 --- a/third_party/xla/xla/tests/BUILD +++ b/third_party/xla/xla/tests/BUILD @@ -272,6 +272,7 @@ cc_library( "//xla/client:local_client", "//xla/client:xla_builder", "//xla/client:xla_computation", + "//xla/hlo/builder:xla_builder", "//xla/service:interpreter_plugin", # reference backend "//xla/service:platform_util", "//xla/stream_executor", diff --git a/third_party/xla/xla/tests/client_library_test_base.cc b/third_party/xla/xla/tests/client_library_test_base.cc index 71b6f9bc175a80..e0599f0f35e151 100644 --- a/third_party/xla/xla/tests/client_library_test_base.cc +++ b/third_party/xla/xla/tests/client_library_test_base.cc @@ -26,11 +26,13 @@ limitations under the License. #include "xla/client/local_client.h" #include "xla/client/xla_builder.h" #include "xla/execution_options_util.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/literal_util.h" #include "xla/service/platform_util.h" #include "xla/shape_util.h" #include "xla/status_macros.h" #include "xla/test_helpers.h" +#include "xla/xla_data.pb.h" #include "tsl/platform/logging.h" namespace xla { @@ -291,7 +293,7 @@ absl::StatusOr ClientLibraryTestBase::ComputeAndTransfer( for (const auto& argument : arguments_) { TF_ASSIGN_OR_RETURN( std::unique_ptr owned_argument, - client_->TransferToServer(MaybeConvertLiteralToBfloat16(argument))); + client_->TransferToServer(MaybeConvertLiteralToTestType(argument))); owning_arguments.push_back(std::move(owned_argument)); arguments.push_back(owning_arguments.back().get()); } @@ -315,7 +317,7 @@ absl::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( for (const auto& argument : arguments_) { TF_ASSIGN_OR_RETURN( std::unique_ptr owned_argument, - client_->TransferToServer(MaybeConvertLiteralToBfloat16(argument))); + client_->TransferToServer(MaybeConvertLiteralToTestType(argument))); owning_arguments.push_back(std::move(owned_argument)); arguments.push_back(owning_arguments.back().get()); } @@ -326,20 +328,20 @@ absl::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( ShapeUtil::ElementIsComplex(expected.shape())) { LOG(WARNING) << "performing exact comparison of floating point numbers"; } - // We allow using a float expected literal for a bfloat16 output. In this - // case, we need to convert the expected literal to bfloat16. + // We allow using a float expected literal for non float outputs. In this + // case, we need to convert the expected literal to test_type_. const Literal* expected_ptr = &expected; Literal converted_expected; Shape layout_shape; - if (use_bfloat16()) { - converted_expected = LiteralUtil::ConvertF32ToBF16(expected); + if (test_type_ != F32) { + converted_expected = MaybeConvertLiteralToTestType(expected); expected_ptr = &converted_expected; if (shape_with_layout != nullptr) { layout_shape = *shape_with_layout; ShapeUtil::ForEachMutableSubshape( &layout_shape, [&](Shape* subshape, const ShapeIndex& /*index*/) { if (subshape->element_type() == F32) { - subshape->set_element_type(BF16); + subshape->set_element_type(test_type_); } }); shape_with_layout = &layout_shape; @@ -377,27 +379,27 @@ absl::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( for (const auto& argument : arguments_) { TF_ASSIGN_OR_RETURN( std::unique_ptr owned_argument, - client_->TransferToServer(MaybeConvertLiteralToBfloat16(argument))); + client_->TransferToServer(MaybeConvertLiteralToTestType(argument))); owning_arguments.push_back(std::move(owned_argument)); arguments.push_back(owning_arguments.back().get()); } } TF_ASSIGN_OR_RETURN(auto computation, builder->Build()); - // We allow using a float expected literal for a bfloat16 output. In this - // case, we need to convert the expected literal to bfloat16. + // We allow using a float expected literal for a non float outputs. In this + // case, we need to convert the expected literal to type_test_. const Literal* expected_ptr = &expected; Literal converted_expected; Shape layout_shape; - if (use_bfloat16()) { - converted_expected = LiteralUtil::ConvertF32ToBF16(expected); + if (test_type_ != F32) { + converted_expected = MaybeConvertLiteralToTestType(expected); expected_ptr = &converted_expected; if (shape_with_layout != nullptr) { layout_shape = *shape_with_layout; ShapeUtil::ForEachMutableSubshape( &layout_shape, [&](Shape* subshape, const ShapeIndex& /*index*/) { if (subshape->element_type() == F32) { - subshape->set_element_type(BF16); + subshape->set_element_type(test_type_); } }); shape_with_layout = &layout_shape; @@ -535,13 +537,11 @@ ClientLibraryTestBase::ComputeValueAndReference( return std::make_pair(std::move(reference), std::move(result)); } -XlaComputation ClientLibraryTestBase::CreateScalarRelu() { +XlaComputation ClientLibraryTestBase::CreateScalarReluF32() { XlaBuilder builder("relu"); - auto shape = ShapeUtil::MakeShape(use_bfloat16() ? BF16 : F32, {}); + auto shape = ShapeUtil::MakeShape(F32, {}); auto z_value = Parameter(&builder, 0, shape, "z_value"); - auto zero = use_bfloat16() - ? ConstantR0(&builder, static_cast(0.0f)) - : ConstantR0(&builder, 0.0f); + auto zero = ConstantR0(&builder, 0.0f); Max(z_value, zero); auto computation_status = builder.Build(); TF_CHECK_OK(computation_status.status()); @@ -550,7 +550,7 @@ XlaComputation ClientLibraryTestBase::CreateScalarRelu() { XlaComputation ClientLibraryTestBase::CreateScalarMax() { XlaBuilder builder("max"); - auto shape = ShapeUtil::MakeShape(use_bfloat16() ? BF16 : F32, {}); + auto shape = ShapeUtil::MakeShape(test_type_, {}); auto x = Parameter(&builder, 0, shape, "x"); auto y = Parameter(&builder, 1, shape, "y"); Max(x, y); @@ -559,22 +559,6 @@ XlaComputation ClientLibraryTestBase::CreateScalarMax() { return std::move(computation_status).value(); } -XlaComputation ClientLibraryTestBase::CreateScalarReluSensitivity() { - XlaBuilder builder("relu_sensitivity"); - auto shape = ShapeUtil::MakeShape(use_bfloat16() ? BF16 : F32, {}); - auto activation = Parameter(&builder, 0, shape, "activation"); - auto backprop = Parameter(&builder, 1, shape, "backprop"); - auto zero = use_bfloat16() - ? ConstantR0(&builder, static_cast(0.0f)) - : ConstantR0(&builder, 0.0f); - auto activation_gtz = Gt(activation, zero); - Select(activation_gtz, /*on_true=*/backprop, /*on_false=*/zero); - - auto computation_status = builder.Build(); - TF_CHECK_OK(computation_status.status()); - return std::move(computation_status).value(); -} - std::unique_ptr> ClientLibraryTestBase::CreatePatternedMatrix( int rows, int cols, float offset) { auto array = std::make_unique>(rows, cols); @@ -605,7 +589,7 @@ XlaOp ClientLibraryTestBase::AddParam(const Literal& argument, XlaBuilder* builder) { arguments_.push_back(argument.Clone()); return Parameter(builder, /*parameter_number=*/arguments_.size() - 1, - MaybeConvertShapeToBfloat16(argument.shape()), ""); + MaybeConvertShapeToTestType(argument.shape()), ""); } XlaOp ClientLibraryTestBase::CreateConstantFromLiteral(const Literal& literal, @@ -623,26 +607,34 @@ ClientLibraryTestBase::CreateParameterAndTransferLiteral( nullptr, builder, data_handle); } -Shape ClientLibraryTestBase::MaybeConvertShapeToBfloat16(const Shape& shape) { - if (!use_bfloat16()) { +Shape ClientLibraryTestBase::MaybeConvertShapeToTestType(const Shape& shape) { + if (test_type_ == F32) { return shape; } Shape new_shape = shape; - ShapeUtil::ForEachMutableSubshape(&new_shape, - [](Shape* subshape, const ShapeIndex&) { - if (subshape->element_type() == F32) { - subshape->set_element_type(BF16); - } - }); + ShapeUtil::ForEachMutableSubshape( + &new_shape, [test_type = test_type_](Shape* subshape, const ShapeIndex&) { + if (subshape->element_type() == F32) { + subshape->set_element_type(test_type); + } + }); return new_shape; } -Literal ClientLibraryTestBase::MaybeConvertLiteralToBfloat16( +Literal ClientLibraryTestBase::MaybeConvertLiteralToTestType( const Literal& literal) { - if (use_bfloat16()) { - return LiteralUtil::ConvertF32ToBF16(literal); + switch (test_type_) { + case BF16: + return LiteralUtil::ConvertF32ToBF16(literal); + case F32: + return literal.Clone(); + case F8E5M2: + return LiteralUtil::ConvertF32ToF8E5M2(literal); + case F8E4M3FN: + return LiteralUtil::ConvertF32ToF8E4M3FN(literal); + default: + LOG(FATAL) << "Unsupported test type: " << test_type_; } - return literal.Clone(); } absl::StatusOr> @@ -650,7 +642,7 @@ ClientLibraryTestBase::CreateParameterAndTransferLiteral( int64_t parameter_number, const Literal& literal, const std::string& name, const DeviceHandle* device_handle, XlaBuilder* builder, XlaOp* data_handle) { - Literal param_literal = MaybeConvertLiteralToBfloat16(literal); + Literal param_literal = MaybeConvertLiteralToTestType(literal); TF_ASSIGN_OR_RETURN(auto data, client_->TransferToServer(param_literal, device_handle)); *data_handle = diff --git a/third_party/xla/xla/tests/client_library_test_base.h b/third_party/xla/xla/tests/client_library_test_base.h index 8610dd6e5ae3cb..2814c7032ea425 100644 --- a/third_party/xla/xla/tests/client_library_test_base.h +++ b/third_party/xla/xla/tests/client_library_test_base.h @@ -61,6 +61,20 @@ std::vector ExpandUseBfloat16( return expanded; } +template +std::vector ExpandTestType( + absl::Span test_type_params, + absl::Span specs) { + std::vector expanded; + for (const PrimitiveType test_type : test_type_params) { + for (const auto& spec : specs) { + expanded.push_back(spec); + expanded.back().test_type = test_type; + } + } + return expanded; +} + // A client library test establishes an in-process XLA client connection. class ClientLibraryTestBase : public ::testing::Test { protected: @@ -236,9 +250,8 @@ class ClientLibraryTestBase : public ::testing::Test { absl::Span arguments, ErrorSpec error); // Create scalar operations for use in reductions. - XlaComputation CreateScalarRelu(); + XlaComputation CreateScalarReluF32(); XlaComputation CreateScalarMax(); - XlaComputation CreateScalarReluSensitivity(); // Special case convenience functions for creating filled arrays. @@ -277,7 +290,7 @@ class ClientLibraryTestBase : public ::testing::Test { // Creates a parameter instruction, transfers the literal for the parameter to // server, then stores into "data_handle" the global handle for that // parameter. When the test_type is bfloat16 but the literal has F32 elements, - // the literal will be converted to BF16 before being transferred. + // the literal will be converted to test_type_ before being transferred. absl::StatusOr> CreateParameterAndTransferLiteral( int64_t parameter_number, const Literal& literal, const std::string& name, XlaBuilder* builder, XlaOp* data_handle); @@ -304,7 +317,7 @@ class ClientLibraryTestBase : public ::testing::Test { // Creates a constant instruction with the given literal. When the test_type // is bfloat16 but the literal has F32 elements, the literal will be converted - // to BF16 before being transferred. + // to test_type_ before being transferred. XlaOp CreateConstantFromLiteral(const Literal& literal, XlaBuilder* builder); // Creates a constant instruction with the given array. When the test_type is @@ -417,8 +430,8 @@ class ClientLibraryTestBase : public ::testing::Test { absl::StatusOr> ComputeValueAndReference( XlaBuilder* builder, absl::Span arguments); - // Converts an f32 literal to bf16 if test_type is BF16. - Literal MaybeConvertLiteralToBfloat16(const Literal& literal); + // Converts a literal to the test_type if the literal's type is F32. + Literal MaybeConvertLiteralToTestType(const Literal& literal); LocalClient* client_; LocalClient* ref_client_; // To compute reference result. @@ -439,10 +452,11 @@ class ClientLibraryTestBase : public ::testing::Test { verify_output, const Shape* output_with_layout = nullptr); - // Converts an f32 shape to bf16 if use_bfloat16_ is true. - Shape MaybeConvertShapeToBfloat16(const Shape& shape); + // Converts an f32 shape to test_type_. + Shape MaybeConvertShapeToTestType(const Shape& shape); - // Type to use when running tests. + // Type to use when running tests. By default, we use F32 for historical + // reasons and we rely on the underlying tests to change it. PrimitiveType test_type_ = F32; // Arguments to be passed to the computation when it runs. @@ -584,9 +598,7 @@ std::unique_ptr ClientLibraryTestBase::CreateR0Parameter( NativeT value, int64_t parameter_number, const std::string& name, XlaBuilder* builder, XlaOp* data_handle) { Literal literal = LiteralUtil::CreateR0(value); - if (use_bfloat16() && literal.shape().element_type() == F32) { - literal = LiteralUtil::ConvertF32ToBF16(literal); - } + literal = MaybeConvertLiteralToTestType(literal); std::unique_ptr data = client_->TransferToServer(literal).value(); *data_handle = Parameter(builder, parameter_number, literal.shape(), name); return data; @@ -597,9 +609,7 @@ std::unique_ptr ClientLibraryTestBase::CreateR1Parameter( absl::Span values, int64_t parameter_number, const std::string& name, XlaBuilder* builder, XlaOp* data_handle) { Literal literal = LiteralUtil::CreateR1(values); - if (use_bfloat16() && literal.shape().element_type() == F32) { - literal = LiteralUtil::ConvertF32ToBF16(literal); - } + literal = MaybeConvertLiteralToTestType(literal); std::unique_ptr data = client_->TransferToServer(literal).value(); *data_handle = Parameter(builder, parameter_number, literal.shape(), name); return data; @@ -610,9 +620,7 @@ std::unique_ptr ClientLibraryTestBase::CreateR2Parameter( const Array2D& array_2d, int64_t parameter_number, const std::string& name, XlaBuilder* builder, XlaOp* data_handle) { Literal literal = LiteralUtil::CreateR2FromArray2D(array_2d); - if (use_bfloat16() && literal.shape().element_type() == F32) { - literal = LiteralUtil::ConvertF32ToBF16(literal); - } + literal = MaybeConvertLiteralToTestType(literal); std::unique_ptr data = client_->TransferToServer(literal).value(); *data_handle = Parameter(builder, parameter_number, literal.shape(), name); return data; @@ -623,9 +631,7 @@ std::unique_ptr ClientLibraryTestBase::CreateR3Parameter( const Array3D& array_3d, int64_t parameter_number, const std::string& name, XlaBuilder* builder, XlaOp* data_handle) { Literal literal = LiteralUtil::CreateR3FromArray3D(array_3d); - if (use_bfloat16() && literal.shape().element_type() == F32) { - literal = LiteralUtil::ConvertF32ToBF16(literal); - } + literal = MaybeConvertLiteralToTestType(literal); std::unique_ptr data = client_->TransferToServer(literal).value(); *data_handle = Parameter(builder, parameter_number, literal.shape(), name); return data; @@ -636,9 +642,7 @@ std::unique_ptr ClientLibraryTestBase::CreateR4Parameter( const Array4D& array_4d, int64_t parameter_number, const std::string& name, XlaBuilder* builder, XlaOp* data_handle) { Literal literal = LiteralUtil::CreateR4FromArray4D(array_4d); - if (use_bfloat16() && literal.shape().element_type() == F32) { - literal = LiteralUtil::ConvertF32ToBF16(literal); - } + literal = MaybeConvertLiteralToTestType(literal); std::unique_ptr data = client_->TransferToServer(literal).value(); *data_handle = Parameter(builder, parameter_number, literal.shape(), name); return data; @@ -649,9 +653,7 @@ std::unique_ptr ClientLibraryTestBase::CreateParameter( const Array& array, int64_t parameter_number, const std::string& name, XlaBuilder* builder, XlaOp* data_handle) { Literal literal = LiteralUtil::CreateFromArray(array); - if (use_bfloat16() && literal.shape().element_type() == F32) { - literal = LiteralUtil::ConvertF32ToBF16(literal); - } + literal = MaybeConvertLiteralToTestType(literal); std::unique_ptr data = client_->TransferToServer(literal).value(); *data_handle = Parameter(builder, parameter_number, literal.shape(), name); return data; diff --git a/third_party/xla/xla/tests/reduce_window_test.cc b/third_party/xla/xla/tests/reduce_window_test.cc index c65cd9c9af1969..4417ded2499353 100644 --- a/third_party/xla/xla/tests/reduce_window_test.cc +++ b/third_party/xla/xla/tests/reduce_window_test.cc @@ -1339,7 +1339,7 @@ class R2ReduceWindowTest : public ReduceWindowTestBase, /*window_dilations=*/param.window_dilation, /*padding=*/padding); - ComputeAndCompare(&b, {MaybeConvertLiteralToBfloat16(input_literal)}, + ComputeAndCompare(&b, {MaybeConvertLiteralToTestType(input_literal)}, DefaultErrorSpec()); } }; From c8a8b4cb52295b2fd83e261c8ceb184aebbf53e4 Mon Sep 17 00:00:00 2001 From: Alexander Belyaev Date: Mon, 30 Sep 2024 09:09:22 -0700 Subject: [PATCH 428/483] [XLA:GPU][IndexAnalysis] Move RTVars folding tests to indexing_analysis_test. Next step will be moving the folding to indexing_analysis.cc and removing `hlo` and `map` fields from RTVar struct. PiperOrigin-RevId: 680600028 --- .../gpu/model/indexing_analysis_test.cc | 233 +++++++++++ .../service/gpu/model/indexing_map_test.cc | 393 ------------------ 2 files changed, 233 insertions(+), 393 deletions(-) diff --git a/third_party/xla/xla/service/gpu/model/indexing_analysis_test.cc b/third_party/xla/xla/service/gpu/model/indexing_analysis_test.cc index 0b55ac570b2a7d..aa53db68f98812 100644 --- a/third_party/xla/xla/service/gpu/model/indexing_analysis_test.cc +++ b/third_party/xla/xla/service/gpu/model/indexing_analysis_test.cc @@ -2589,6 +2589,239 @@ TEST_F(IndexingAnalysisTest, BroadcastingElementwise) { )")); } +TEST_F(IndexingAnalysisTest, FusionWithRTVarsSimplification_ScalarConstant) { + auto input_indexing = GetOutputToInputIndexing(ParseAndGetRoot(R"hlo( + HloModule m + fused_computation { + p0 = s32[4096] parameter(0) + offset = s64[] constant(42) + ROOT dynamic-slice = s32[10] + dynamic-slice(p0, offset), dynamic_slice_sizes={10} + } + ENTRY main { + p0 = s32[4096] parameter(0) + ROOT fusion = s32[10] fusion(p0), kind=kInput, calls=fused_computation + } + )hlo")); + + EXPECT_THAT(input_indexing.ToString(), MatchIndexingString(R"( + operand id = 0 + (d0) -> (d0 + 42), + domain: + d0 in [0, 9] + )")); +} + +TEST_F(IndexingAnalysisTest, FusionWithRTVarsSimplification_Iota) { + auto input_indexing = GetOutputToInputIndexing(ParseAndGetRoot(R"hlo( + HloModule m + fused_computation { + p0 = f32[33,76] parameter(0) + iota = s64[42,1] iota(), iota_dimension=0 + ROOT gather = f32[42,1,1] gather(p0, iota), + offset_dims={1,2}, + collapsed_slice_dims={}, + start_index_map={0}, + index_vector_dim=1, + slice_sizes={1,1} + } + ENTRY main { + p0 = f32[33,76] parameter(0) + ROOT fusion = f32[42,1,1] fusion(p0), kind=kInput, calls=fused_computation + } + )hlo")); + EXPECT_THAT(input_indexing.ToString(), MatchIndexingString(R"( + operand id = 0 + (d0, d1, d2) -> (d0, 0), + domain: + d0 in [0, 41], + d1 in [0, 0], + d2 in [0, 0] + )")); +} + +TEST_F(IndexingAnalysisTest, FusionWithRTVarsSimplification_IotaAsConstant) { + auto input_indexing = GetOutputToInputIndexing(ParseAndGetRoot(R"hlo( + HloModule m + fused_computation { + p0 = f32[33,76] parameter(0) + iota = s64[42,1] iota(), iota_dimension=1 + ROOT gather = f32[42,1,1] gather(p0, iota), + offset_dims={1,2}, + collapsed_slice_dims={}, + start_index_map={0}, + index_vector_dim=1, + slice_sizes={1,1} + } + ENTRY main { + p0 = f32[33,76] parameter(0) + ROOT fusion = f32[42,1,1] fusion(p0), kind=kInput, calls=fused_computation + } + )hlo")); + EXPECT_THAT(input_indexing.ToString(), MatchIndexingString(R"( + operand id = 0 + (d0, d1, d2) -> (0, 0), + domain: + d0 in [0, 41], + d1 in [0, 0], + d2 in [0, 0] + )")); +} + +TEST_F(IndexingAnalysisTest, FusionWithRTVarsSimplification_Broadcast) { + auto input_indexing = GetOutputToInputIndexing(ParseAndGetRoot(R"hlo( + HloModule m + fused_computation { + p0 = f32[33,76] parameter(0) + c42 = s64[] constant(42) + bcast = s64[42, 1] broadcast(s64[] c42), dimensions={} + ROOT gather = f32[42,1,1] gather(p0, bcast), + offset_dims={1,2}, + collapsed_slice_dims={}, + start_index_map={0}, + index_vector_dim=1, + slice_sizes={1,1} + } + ENTRY main { + p0 = f32[33,76] parameter(0) + ROOT fusion = f32[42,1,1] fusion(p0), kind=kInput, calls=fused_computation + } + )hlo")); + EXPECT_THAT(input_indexing.ToString(), MatchIndexingString(R"( + operand id = 0 + (d0, d1, d2) -> (42, 0), + domain: + d0 in [0, 41], + d1 in [0, 0], + d2 in [0, 0] + )")); +} + +TEST_F(IndexingAnalysisTest, FusionWithRTVarsSimplification_Reverse) { + auto input_indexing = GetOutputToInputIndexing(ParseAndGetRoot(R"hlo( + HloModule m + fused_computation { + p0 = f32[33,76] parameter(0) + iota = s64[42,1] iota(), iota_dimension=0 + reverse = s64[42,1] reverse(iota), dimensions={0} + ROOT gather = f32[42,1,1] gather(p0, reverse), + offset_dims={1,2}, + collapsed_slice_dims={}, + start_index_map={0}, + index_vector_dim=1, + slice_sizes={1,1} + } + ENTRY main { + p0 = f32[33,76] parameter(0) + ROOT fusion = f32[42,1,1] fusion(p0), kind=kInput, calls=fused_computation + } + )hlo")); + EXPECT_THAT(input_indexing.ToString(), MatchIndexingString(R"( + operand id = 0 + (d0, d1, d2) -> (-d0 + 41, 0), + domain: + d0 in [0, 41], + d1 in [0, 0], + d2 in [0, 0] + )")); +} + +TEST_F(IndexingAnalysisTest, FusionWithRTVarsSimplification_Add) { + auto input_indexing = GetOutputToInputIndexing(ParseAndGetRoot(R"hlo( + HloModule m + fused_computation { + p0 = s32[4096] parameter(0) + p1 = s64[] parameter(1) + c42 = s64[] constant(42) + add = s64[] add(c42, p1) + ROOT dynamic-slice = s32[10] + dynamic-slice(p0, add), dynamic_slice_sizes={10} + } + ENTRY main { + p0 = s32[4096] parameter(0) + p1 = s64[] parameter(1) + ROOT fusion = s32[10] fusion(p0, p1), kind=kInput, calls=fused_computation + } + )hlo")); + EXPECT_THAT(input_indexing.ToString(), MatchIndexingString(R"( + operand id = 0 (d0)[rt0] -> (d0 + rt0 + 42), + domain: + d0 in [0, 9], + rt0 in [0, 4086], + hlo: %p1 = s64[] parameter(1), + (d0) -> () + operand id = 1 + (d0) -> (), + domain: + d0 in [0, 9] + )")); +} + +TEST_F(IndexingAnalysisTest, FusionWithRTVarsSimplification_Multiply) { + auto input_indexing = GetOutputToInputIndexing(ParseAndGetRoot(R"hlo( + HloModule m + fused_computation { + p0 = s32[4096] parameter(0) + p1 = s64[] parameter(1) + c42 = s64[] constant(42) + add = s64[] multiply(c42, p1) + ROOT dynamic-slice = s32[10] + dynamic-slice(p0, add), dynamic_slice_sizes={10} + } + ENTRY main { + p0 = s32[4096] parameter(0) + p1 = s64[] parameter(1) + ROOT fusion = s32[10] fusion(p0, p1), kind=kInput, calls=fused_computation + } + )hlo")); + // TODO: Figure out why the bounds are not updated. + EXPECT_THAT(input_indexing.ToString(), MatchIndexingString(R"( + operand id = 0 (d0)[rt0] -> (d0 + rt0 * 42), + domain: + d0 in [0, 9], + rt0 in [0, 4086], + hlo: %p1 = s64[] parameter(1), + (d0) -> () + operand id = 1 + (d0) -> (), + domain: + d0 in [0, 9] + )")); +} + +TEST_F(IndexingAnalysisTest, FusionWithRTVarsSimplification_ChainedOps) { + auto input_indexing = GetOutputToInputIndexing(ParseAndGetRoot(R"hlo( + HloModule m + fused_computation { + p0 = s32[4096] parameter(0) + p1 = s64[] parameter(1) + c42 = s64[] constant(42) + c2 = s64[] constant(2) + add = s64[] add(c42, p1) + multiply = s64[] multiply(c2, add) + ROOT dynamic-slice = s32[10] + dynamic-slice(p0, multiply), dynamic_slice_sizes={10} + } + ENTRY main { + p0 = s32[4096] parameter(0) + p1 = s64[] parameter(1) + ROOT fusion = s32[10] fusion(p0, p1), kind=kInput, calls=fused_computation + } + )hlo")); + EXPECT_THAT(input_indexing.ToString(), MatchIndexingString(R"( + operand id = 0 + (d0)[rt0] -> (d0 + rt0 * 2 + 84), + domain: d0 in [0, 9], + rt0 in [0, 4086], + hlo: %p1 = s64[] parameter(1), + (d0) -> () + operand id = 1 + (d0) -> (), + domain: + d0 in [0, 9] + )")); +} + TEST_F(IndexingAnalysisTest, FusionOpWithDUS) { auto input_indexing = GetOutputToInputIndexing(ParseAndGetRoot(R"hlo( HloModule m diff --git a/third_party/xla/xla/service/gpu/model/indexing_map_test.cc b/third_party/xla/xla/service/gpu/model/indexing_map_test.cc index b55a319d83380d..6e3e2b2cb796b2 100644 --- a/third_party/xla/xla/service/gpu/model/indexing_map_test.cc +++ b/third_party/xla/xla/service/gpu/model/indexing_map_test.cc @@ -1560,399 +1560,6 @@ TEST(IntervalMathTest, MultiplicationSaturating) { EXPECT_THAT(any * neg_one, IntervalIs(any)); } -TEST_F(IndexingMapTest, ReplaceConstantRTVars_ScalarConstant) { - absl::StatusOr> hlo_module = - ParseAndReturnVerifiedModule(R"hlo( - HloModule m - - ENTRY e { - ROOT %constant = s64[] constant(42) - } - )hlo"); - - ASSERT_TRUE(hlo_module.ok()); - - IndexingMap indexing_map( - ParseAffineMap("()[s0] -> (s0)", &mlir_context_), - /*dimensions=*/{}, - /*range_vars=*/{}, - {RTVar{Interval{42, 42}, - hlo_module.value()->entry_computation()->root_instruction(), - AffineMap::get(0, 0, {}, &mlir_context_)}}); - - EXPECT_TRUE(indexing_map.Simplify()); - indexing_map.RemoveUnusedSymbols(); - - EXPECT_THAT(ToString(indexing_map), MatchIndexingString("() -> (42)")); -} - -TEST_F(IndexingMapTest, ReplaceConstantRTVars_StaticIndexIntoTensorConstant) { - absl::StatusOr> hlo_module = - ParseAndReturnVerifiedModule(R"hlo( - HloModule m - - ENTRY e { - ROOT %constant = s64[2, 4]{1,0} constant({{1, 2, 3, 4}, {11, 12, 13, 14}}) - } - )hlo"); - - ASSERT_TRUE(hlo_module.ok()); - - IndexingMap indexing_map( - ParseAffineMap("()[s0] -> (s0)", &mlir_context_), - /*dimensions=*/{}, - /*range_vars=*/{}, - {RTVar{Interval{1, 14}, - hlo_module.value()->entry_computation()->root_instruction(), - ParseAffineMap("() -> (1,2)", &mlir_context_)}}); - - EXPECT_TRUE(indexing_map.Simplify()); - indexing_map.RemoveUnusedSymbols(); - - EXPECT_THAT(ToString(indexing_map), MatchIndexingString("() -> (13)")); -} - -TEST_F(IndexingMapTest, ReplaceConstantRTVars_NonFoldableTensor) { - absl::StatusOr> hlo_module = - ParseAndReturnVerifiedModule(R"hlo( - HloModule m - - ENTRY e { - ROOT %constant = s64[2, 4]{1,0} constant({{1, 2, 3, 4}, {11, 12, 13, 14}}) - } - )hlo"); - - ASSERT_TRUE(hlo_module.ok()); - - IndexingMap indexing_map( - ParseAffineMap("(d0)[s0] -> (s0)", &mlir_context_), - /*dimensions=*/{}, - /*range_vars=*/{}, - {RTVar{Interval{1, 14}, - hlo_module.value()->entry_computation()->root_instruction(), - ParseAffineMap("(d0) -> (1, d0)", &mlir_context_)}}); - - EXPECT_FALSE(indexing_map.Simplify()); -} - -TEST_F(IndexingMapTest, ReplaceConstantRTVars_Iota) { - absl::StatusOr> hlo_module = - ParseAndReturnVerifiedModule(R"hlo( - HloModule m - - ENTRY e { - ROOT %iota = s64[10, 10]{1,0} iota(), iota_dimension=0 - } - )hlo"); - - ASSERT_TRUE(hlo_module.ok()); - - IndexingMap indexing_map( - ParseAffineMap("(d0)[s0] -> (d0, s0)", &mlir_context_), - /*dimensions=*/{{0, 255}}, - /*range_vars=*/{}, - {RTVar{Interval{0, 9}, - hlo_module.value()->entry_computation()->root_instruction(), - ParseAffineMap("(d0) -> (d0, 7)", &mlir_context_)}}); - - EXPECT_TRUE(indexing_map.Simplify()); - indexing_map.RemoveUnusedSymbols(); - - EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( - (d0) -> (d0, d0), - domain: - d0 in [0, 255] - )")); -} - -TEST_F(IndexingMapTest, ReplaceConstantRTVars_IotaAsConstant) { - absl::StatusOr> hlo_module = - ParseAndReturnVerifiedModule(R"hlo( - HloModule m - - ENTRY e { - ROOT %iota = s64[10, 10]{1,0} iota(), iota_dimension=1 - } - )hlo"); - - ASSERT_TRUE(hlo_module.ok()); - - IndexingMap indexing_map( - ParseAffineMap("(d0)[s0] -> (d0, s0)", &mlir_context_), - /*dimensions=*/{{0, 255}}, - /*range_vars=*/{}, - {RTVar{Interval{0, 9}, - hlo_module.value()->entry_computation()->root_instruction(), - ParseAffineMap("(d0) -> (d0, 7)", &mlir_context_)}}); - - EXPECT_TRUE(indexing_map.Simplify()); - indexing_map.RemoveUnusedSymbols(); - - EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( - (d0) -> (d0, 7), - domain: - d0 in [0, 255] - )")); -} - -TEST_F(IndexingMapTest, ReplaceConstantRTVars_ConstraintsGetUpdated) { - absl::StatusOr> hlo_module = - ParseAndReturnVerifiedModule(R"hlo( - HloModule m - - ENTRY e { - ROOT %iota = s64[10, 10]{1,0} iota(), iota_dimension=0 - } - )hlo"); - - ASSERT_TRUE(hlo_module.ok()); - - IndexingMap indexing_map( - ParseAffineMap("(d0)[s0] -> (d0, s0)", &mlir_context_), - /*dimensions=*/{{0, 255}}, - /*range_vars=*/{}, - {RTVar{Interval{0, 9}, - hlo_module.value()->entry_computation()->root_instruction(), - ParseAffineMap("(d0) -> (d0, 7)", &mlir_context_)}}); - indexing_map.AddConstraint(ParseAffineExpr("s0 mod 2", &mlir_context_), - Interval{0, 0}); - - EXPECT_TRUE(indexing_map.Simplify()); - indexing_map.RemoveUnusedSymbols(); - - EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( - (d0) -> (d0, d0), - domain: - d0 in [0, 254], - d0 mod 2 in [0, 0] - )")); -} - -TEST_F(IndexingMapTest, ReplaceConstantRTVars_Broadcast) { - absl::StatusOr> hlo_module = - ParseAndReturnVerifiedModule(R"hlo( - HloModule m - - ENTRY e { - %iota = s64[12]{0} iota(), iota_dimension=0 - ROOT %broadcast = s64[32, 12]{1,0} broadcast(s64[12]{0} %iota), dimensions={1} - } - )hlo"); - - ASSERT_TRUE(hlo_module.ok()); - - // (d0, 11): d0 maps into the broadcasted dimension, so it doesn't matter - // and 11 maps to 11 in iota. - IndexingMap indexing_map( - ParseAffineMap("(d0)[s0] -> (d0, s0)", &mlir_context_), - /*dimensions=*/{{0, 31}}, - /*range_vars=*/{}, - {RTVar{Interval{0, 11}, - hlo_module.value()->entry_computation()->root_instruction(), - ParseAffineMap("(d0) -> (d0, 11)", &mlir_context_)}}); - - EXPECT_TRUE(indexing_map.Simplify()); - indexing_map.RemoveUnusedSymbols(); - - EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( - (d0) -> (d0, 11), - domain: - d0 in [0, 31] - )")); -} - -TEST_F(IndexingMapTest, ReplaceConstantRTVars_ChainedNoncomputeOps) { - absl::StatusOr> hlo_module = - ParseAndReturnVerifiedModule(R"hlo( - HloModule m - - ENTRY e { - %iota = s64[12]{0} iota(), iota_dimension=0 - %reverse = s64[12]{0} reverse(s64[12]{0} %iota), dimensions={0} - %reshape = s64[3,4]{1,0} reshape(s64[12]{0} %reverse) - ROOT %broadcast = s64[36,3,4]{2,1,0} broadcast(s64[3,4]{1,0} %reshape), dimensions={1,2} - } - )hlo"); - - ASSERT_TRUE(hlo_module.ok()); - - // - Iota: [0, 1, ,,,, 11] - // - Reverse: [11, 10, ..., 0] - // - Reshape: [[11, 10, 9, 8], [7, 6, 5, 4], [3, 2, 1, 0]] - // - Coordinates: (d0 floordiv 12, 3) - // - y-coordinate=3 means we index into [8, 4, 0] - // - x-coordinate=(d0 floordiv 12) means our constant looks like this: - // [8, ..., 8, 4, ..., 4, 0, ..., 0] - // - Hence our final expression: (d0 floordiv 12) * -4 + 8 - IndexingMap indexing_map( - ParseAffineMap("(d0)[s0] -> (d0, s0)", &mlir_context_), - /*dimensions=*/{{0, 35}}, - /*range_vars=*/{}, - {RTVar{ - Interval{0, 11}, - hlo_module.value()->entry_computation()->root_instruction(), - ParseAffineMap("(d0) -> (d0, d0 floordiv 12, 3)", &mlir_context_)}}); - - EXPECT_TRUE(indexing_map.Simplify()); - indexing_map.RemoveUnusedSymbols(); - - EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( - (d0) -> (d0, (d0 floordiv 12) * -4 + 8), - domain: - d0 in [0, 35] - )")); -} - -TEST_F(IndexingMapTest, ReplaceConstantRTVars_PartialRTVarRemoval) { - absl::StatusOr> hlo_module = - ParseAndReturnVerifiedModule(R"hlo( - HloModule m - - ENTRY e { - %constant = s64[12]{0} constant({...}) - ROOT %broadcast = s64[24,12]{1,0} broadcast(s64[12]{0} %constant), dimensions={1} - } - )hlo"); - - ASSERT_TRUE(hlo_module.ok()); - - // (d0, d0 floordiv 2): d0 maps into the broadcasted dimension, so it can't be - // removed, but d0 floordiv 2 doesn't yield an affine expression so we need to - // keep the RTVar, but can optimize it by removing the broadcast. - IndexingMap indexing_map( - ParseAffineMap("(d0)[s0] -> (d0, s0)", &mlir_context_), - /*dimensions=*/{{0, 23}}, - /*range_vars=*/{}, - {RTVar{Interval{0, 512}, - hlo_module.value()->entry_computation()->root_instruction(), - ParseAffineMap("(d0) -> (d0, d0 floordiv 2)", &mlir_context_)}}); - - EXPECT_TRUE(indexing_map.Simplify()); - - EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( - (d0)[rt0] -> (d0, rt0), - domain: - d0 in [0, 23], - rt0 in [0, 512], - hlo: %constant = s64[12]{0} constant({...}), - (d0) -> (d0 floordiv 2) - )")); -} - -TEST_F(IndexingMapTest, ReplaceConstantRTVars_Add) { - absl::StatusOr> hlo_module = - ParseAndReturnVerifiedModule(R"hlo( - HloModule m - - ENTRY e { - %constant = s64[] constant(42) - %broadcast = s64[12,13,24]{2,1,0} broadcast(s64[] %constant), dimensions={} - %iota = s64[12,13,24]{2,1,0} iota(), iota_dimension=2 - ROOT %add = s64[12,13,24]{2,1,0} add(s64[12,13,24]{2,1,0} %broadcast, s64[12,13,24]{2,1,0} %iota) - } - )hlo"); - - ASSERT_TRUE(hlo_module.ok()); - - // The iota dimension is the last dimension in (d0, 7, 2 * d0), hence this - // composes to 42 + 2 * d0 - IndexingMap indexing_map( - ParseAffineMap("(d0)[s0] -> (d0, s0)", &mlir_context_), - /*dimensions=*/{{0, 11}}, - /*range_vars=*/{}, - {RTVar{Interval{0, 11}, - hlo_module.value()->entry_computation()->root_instruction(), - ParseAffineMap("(d0) -> (d0, 7, 2 * d0)", &mlir_context_)}}); - - EXPECT_TRUE(indexing_map.Simplify()); - indexing_map.RemoveUnusedSymbols(); - - EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( - (d0) -> (d0, d0 * 2 + 42), - domain: - d0 in [0, 11] - )")); -} - -TEST_F(IndexingMapTest, ReplaceConstantRTVars_Multiply) { - absl::StatusOr> hlo_module = - ParseAndReturnVerifiedModule(R"hlo( - HloModule m - - ENTRY e { - %iota0 = s64[12,12]{1,0} iota(), iota_dimension=0 - %iota1 = s64[12]{0} iota(), iota_dimension=0 - %broadcast = s64[12,12]{1,0} broadcast(s64[12]{0} %iota1), dimensions={1} - %multiply = s64[12,12]{1,0} multiply(s64[12,12]{1,0} %iota0, s64[12,12]{1,0} %broadcast) - ROOT %reverse = s64[12,12]{1,0} reverse(s64[12,12]{1,0} %multiply), dimensions={0} - } - )hlo"); - - ASSERT_TRUE(hlo_module.ok()); - - // Iota0: [[0, ..., 0], [1, ..., 1], ..., [11, ..., 11]] - // Iota1: [0, ..., 11] - // Broadcast1: [[0, 1, ..., 11], [0, 1, ..., 11], ..., [0, 1, ..., 11]] - // Mul: [[0, .., 0], [0, 1, ..., 11], [0, 2, ..., 22], ..., [0, 11, ..., 121]] - // Reverse: [[0, 11, ..., 121], [0, 10, ..., 110], ..., [0, ..., 0]] - // Therefore (d0, d0) evaluates to: (11 - d0) * d0. - IndexingMap indexing_map( - ParseAffineMap("(d0)[s0] -> (d0, s0)", &mlir_context_), - /*dimensions=*/{{0, 11}}, - /*range_vars=*/{}, - {RTVar{Interval{0, 11}, - hlo_module.value()->entry_computation()->root_instruction(), - ParseAffineMap("(d0) -> (d0, d0)", &mlir_context_)}}); - - EXPECT_TRUE(indexing_map.Simplify()); - indexing_map.RemoveUnusedSymbols(); - - EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( - (d0) -> (d0, (-d0 + 11) * d0), - domain: - d0 in [0, 11] - )")); -} - -TEST_F(IndexingMapTest, ReplaceConstantRTVars_PartiallyOptimizableAdd) { - absl::StatusOr> hlo_module = - ParseAndReturnVerifiedModule(R"hlo( - HloModule m - - ENTRY e { - %constant = s64[12]{0} constant({...}) - %broadcast = s64[12,13,24]{2,1,0} broadcast(s64[12]{0} %constant), dimensions={0} - %iota = s64[12,13,24]{2,1,0} iota(), iota_dimension=2 - ROOT %add = s64[12,13,24]{2,1,0} add(s64[12,13,24]{2,1,0} %broadcast, s64[12,13,24]{2,1,0} %iota) - } - )hlo"); - - ASSERT_TRUE(hlo_module.ok()); - - // The iota dimension is the last dimension in (d0, 7, 2 * d0), the constant - // only depends on the first dimension. The constant consists of some - // arbitrary values that cannot be represent as an affine expression, hence - // the RTVar remains in-place. - IndexingMap indexing_map( - ParseAffineMap("(d0)[rt0] -> (d0, rt0)", &mlir_context_), - /*dimensions=*/{{0, 11}}, - /*range_vars=*/{}, - {RTVar{Interval{0, 11}, - hlo_module.value()->entry_computation()->root_instruction(), - ParseAffineMap("(d0) -> (d0, 7, 2 * d0)", &mlir_context_)}}); - - EXPECT_TRUE(indexing_map.Simplify()); - - EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( - (d0)[rt0] -> (d0, d0 * 2 + rt0), - domain: - d0 in [0, 11], - rt0 in [0, 11], - hlo: %constant = s64[12]{0} constant({...}), - (d0) -> (d0) - )")); -} - template void ExpectSupportsAbslHashAndEqAndNe(absl::Span values) { EXPECT_TRUE(absl::VerifyTypeImplementsAbslHashCorrectly(values)); From eec244439a4f0555deeda8a07af3ac5a35fb698f Mon Sep 17 00:00:00 2001 From: Junwhan Ahn Date: Mon, 30 Sep 2024 10:04:43 -0700 Subject: [PATCH 429/483] Fork `xla::ExecuteOptions` into `xla::ifrt::ExecuteOptions` Only the fields that are currently being used by any IFRT implementation are copied to the fork. Two fields that are currently being used as a passthrough in PjRt-IFRT but are removed are: * `untuple_result`: Since IFRT does not have XLA tuple types, all PjRt-IFRT invocations must have already been setting this to true. So this CL changes PjRt-IFRT to unconditionally set `untuple_result` to true when invoking PjRt and gets rid of this field from `xla.ifrt.ExecuteOptions`. * `use_major_to_minor_data_layout_for_callbacks`: The meaning of this field is very specific to PjRt. Since this field is set to true in every IFRT invocation that cares about this field, this CL instead changes PjRt-IFRT to internally always set this field to true and avoid exposing this to IFRT. In order to not break the IFRT Proxy's version compatibility, the forked `ExecuteOptionsProto` uses the same field tags as the original proto. PiperOrigin-RevId: 680620535 --- .../core/tfrt/ifrt/ifrt_serving_executable.cc | 12 +++------ third_party/xla/xla/python/ifrt/BUILD | 7 +++++ third_party/xla/xla/python/ifrt/executable.cc | 25 +++++++++++++++++ third_party/xla/xla/python/ifrt/executable.h | 27 +++++++++++++++++-- .../xla/xla/python/ifrt/execute_options.proto | 26 ++++++++++++++++++ .../xla/xla/python/ifrt_proxy/common/BUILD | 1 + .../ifrt_proxy/common/ifrt_service.proto | 4 +-- .../xla/python/pjrt_ifrt/pjrt_executable.cc | 10 ++++--- third_party/xla/xla/python/py_executable.cc | 4 +-- third_party/xla/xla/python/py_executable.h | 4 +-- 10 files changed, 99 insertions(+), 21 deletions(-) create mode 100644 third_party/xla/xla/python/ifrt/execute_options.proto diff --git a/tensorflow/core/tfrt/ifrt/ifrt_serving_executable.cc b/tensorflow/core/tfrt/ifrt/ifrt_serving_executable.cc index 4bf833133849ee..4472947052abc1 100644 --- a/tensorflow/core/tfrt/ifrt/ifrt_serving_executable.cc +++ b/tensorflow/core/tfrt/ifrt/ifrt_serving_executable.cc @@ -636,14 +636,10 @@ absl::StatusOr> IfrtServingExecutable::Execute( if (UsePortableExecution(compile_metadata)) { execution_device_list = device_list; } - TF_ASSIGN_OR_RETURN( - auto execution_result, - executable_bundle->ifrt_executable->Execute( - absl::MakeSpan(args), - /*options=*/ - {.untuple_result = true, - .use_major_to_minor_data_layout_for_callbacks = true}, - std::move(execution_device_list))); + TF_ASSIGN_OR_RETURN(auto execution_result, + executable_bundle->ifrt_executable->Execute( + absl::MakeSpan(args), /*options=*/{}, + std::move(execution_device_list))); auto status = execution_result.status.Await(); TF_RETURN_IF_ERROR(status); diff --git a/third_party/xla/xla/python/ifrt/BUILD b/third_party/xla/xla/python/ifrt/BUILD index 51795c0f3b32ba..df149adb101a30 100644 --- a/third_party/xla/xla/python/ifrt/BUILD +++ b/third_party/xla/xla/python/ifrt/BUILD @@ -87,6 +87,7 @@ cc_library( ":attribute_map", ":device_proto_cc", ":dtype_proto_cc", + ":execute_options_proto_cc", ":remap_plan_proto_cc", ":serdes", ":shape_proto_cc", @@ -109,6 +110,7 @@ cc_library( "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/container:node_hash_set", "@com_google_absl//absl/functional:function_ref", @@ -165,6 +167,11 @@ xla_cc_test( ], ) +tf_proto_library( + name = "execute_options_proto", + srcs = ["execute_options.proto"], +) + xla_cc_test( name = "future_test", size = "small", diff --git a/third_party/xla/xla/python/ifrt/executable.cc b/third_party/xla/xla/python/ifrt/executable.cc index 77cabe7f6a9389..31f5aee05ebaef 100644 --- a/third_party/xla/xla/python/ifrt/executable.cc +++ b/third_party/xla/xla/python/ifrt/executable.cc @@ -15,11 +15,36 @@ limitations under the License. #include "xla/python/ifrt/executable.h" +#include "absl/status/statusor.h" +#include "xla/python/ifrt/execute_options.pb.h" + namespace xla { namespace ifrt { char Executable::ID = 0; char LoadedExecutable::ID = 0; +absl::StatusOr ExecuteOptions::ToProto() const { + ExecuteOptionsProto proto; + + proto.set_launch_id(launch_id); + proto.mutable_non_donatable_input_indices()->Add( + non_donatable_input_indices.begin(), non_donatable_input_indices.end()); + + return proto; +} + +absl::StatusOr ExecuteOptions::FromProto( + const xla::ifrt::ExecuteOptionsProto& proto) { + ExecuteOptions options; + + options.launch_id = proto.launch_id(); + options.non_donatable_input_indices.insert( + proto.non_donatable_input_indices().begin(), + proto.non_donatable_input_indices().end()); + + return options; +} + } // namespace ifrt } // namespace xla diff --git a/third_party/xla/xla/python/ifrt/executable.h b/third_party/xla/xla/python/ifrt/executable.h index 6b642bd5d178d6..adcf1e6ebe9409 100644 --- a/third_party/xla/xla/python/ifrt/executable.h +++ b/third_party/xla/xla/python/ifrt/executable.h @@ -22,6 +22,7 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" @@ -33,6 +34,7 @@ limitations under the License. #include "xla/python/ifrt/attribute_map.h" #include "xla/python/ifrt/device.h" #include "xla/python/ifrt/device_list.h" +#include "xla/python/ifrt/execute_options.pb.h" #include "xla/python/ifrt/future.h" #include "xla/tsl/concurrency/ref_count.h" @@ -104,6 +106,28 @@ class Executable : public llvm::RTTIExtends { static char ID; // NOLINT }; +struct ExecuteOptions { + // If non-zero, identifies this execution as part of a potentially + // multi-device launch. This can be used to detect scheduling errors, e.g. if + // multi-host programs are launched in different orders on different hosts, + // the launch IDs may be used by the runtime to detect the mismatch. + int32_t launch_id = 0; + + // A set of indices denoting the input arrays that should not be donated. An + // input array may be non-donable, for example, if it is referenced more than + // once. Since such runtime information is not available at compile time, the + // compiler might mark the input as `may-alias`, which could lead IFRT to + // donate the input array when it should not. By defining this set of indices, + // a higher-level IFRT caller can instruct IFRT client not to donate specific + // input arrays. + absl::flat_hash_set non_donatable_input_indices; + + absl::StatusOr ToProto() const; + + static absl::StatusOr FromProto( + const ExecuteOptionsProto& proto); +}; + // Wraps a computation that has been fully compiled and loaded for execution. class LoadedExecutable : public llvm::RTTIExtends { @@ -176,8 +200,7 @@ class LoadedExecutable // `LoadedExecutable` methods. - // Short-term alias. - using ExecuteOptions = ::xla::ExecuteOptions; + using ExecuteOptions = xla::ifrt::ExecuteOptions; // Result from an execution. struct ExecuteResult { diff --git a/third_party/xla/xla/python/ifrt/execute_options.proto b/third_party/xla/xla/python/ifrt/execute_options.proto new file mode 100644 index 00000000000000..6cc20c0996f4f6 --- /dev/null +++ b/third_party/xla/xla/python/ifrt/execute_options.proto @@ -0,0 +1,26 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto3"; + +package xla.ifrt; + +message ExecuteOptionsProto { + bool untuple_result = 2; + int32 launch_id = 3; + repeated int32 non_donatable_input_indices = 7; + + reserved 1, 4 to 6, 8; +} diff --git a/third_party/xla/xla/python/ifrt_proxy/common/BUILD b/third_party/xla/xla/python/ifrt_proxy/common/BUILD index 23859f59e1ad8f..d8c5feefe6d439 100644 --- a/third_party/xla/xla/python/ifrt_proxy/common/BUILD +++ b/third_party/xla/xla/python/ifrt_proxy/common/BUILD @@ -72,6 +72,7 @@ tf_proto_library( "//xla/pjrt:execute_options_proto", "//xla/python/ifrt:attribute_map_proto", "//xla/python/ifrt:dtype_proto", + "//xla/python/ifrt:execute_options_proto", "//xla/python/ifrt:remap_plan_proto", "//xla/python/ifrt:serdes_proto", "//xla/python/ifrt:shape_proto", diff --git a/third_party/xla/xla/python/ifrt_proxy/common/ifrt_service.proto b/third_party/xla/xla/python/ifrt_proxy/common/ifrt_service.proto index 4a342c253af9cc..3f17ee69abbd57 100644 --- a/third_party/xla/xla/python/ifrt_proxy/common/ifrt_service.proto +++ b/third_party/xla/xla/python/ifrt_proxy/common/ifrt_service.proto @@ -17,9 +17,9 @@ syntax = "proto3"; package xla.ifrt.proxy; import "google/protobuf/any.proto"; -import "xla/pjrt/execute_options.proto"; import "xla/python/ifrt/attribute_map.proto"; import "xla/python/ifrt/dtype.proto"; +import "xla/python/ifrt/execute_options.proto"; import "xla/python/ifrt/remap_plan.proto"; import "xla/python/ifrt/serdes.proto"; import "xla/python/ifrt/shape.proto"; @@ -428,7 +428,7 @@ message LoadedExecutableMetadataResponse { message LoadedExecutableExecuteRequest { fixed64 loaded_executable_handle = 1; repeated fixed64 args_handles = 2; - xla.ExecuteOptionsProto execute_options = 3; + xla.ifrt.ExecuteOptionsProto execute_options = 3; repeated int32 device_ids = 4; } message LoadedExecutableExecuteResponse { diff --git a/third_party/xla/xla/python/pjrt_ifrt/pjrt_executable.cc b/third_party/xla/xla/python/pjrt_ifrt/pjrt_executable.cc index 19c441cdea448d..7225a3a40db8fd 100644 --- a/third_party/xla/xla/python/pjrt_ifrt/pjrt_executable.cc +++ b/third_party/xla/xla/python/pjrt_ifrt/pjrt_executable.cc @@ -542,7 +542,11 @@ PjRtLoadedExecutable::Execute( const bool returned_future_supported = pjrt_loaded_executable_->IsReturnedFutureSupported(); - auto opts = options; + xla::ExecuteOptions opts; + opts.untuple_result = true; + opts.launch_id = options.launch_id; + opts.use_major_to_minor_data_layout_for_callbacks = true; + opts.non_donatable_input_indices = options.non_donatable_input_indices; if (!all_loaded_host_callbacks_->empty() && !returned_future_supported) { return Internal( @@ -565,9 +569,7 @@ PjRtLoadedExecutable::Execute( contexts.push_back(CreateHostCallbackStateAndAppendSendRecvCallbacks( host_send_recv_callback->host_callback(), /*host_memory_for_device_manager=*/nullptr, send_callbacks, - recv_callbacks, - /*use_major_to_minor_data_layout_for_callbacks=*/ - options.use_major_to_minor_data_layout_for_callbacks)); + recv_callbacks, opts.use_major_to_minor_data_layout_for_callbacks)); } } opts.send_callbacks = host_callback_states->send_callbacks; diff --git a/third_party/xla/xla/python/py_executable.cc b/third_party/xla/xla/python/py_executable.cc index b7395ad7793050..7e68c944f94386 100644 --- a/third_party/xla/xla/python/py_executable.cc +++ b/third_party/xla/xla/python/py_executable.cc @@ -93,13 +93,11 @@ PyLoadedExecutable::PyLoadedExecutable( if (next_) { next_->prev_ = this; } - options_.untuple_result = true; if (fingerprint_) { options_.launch_id = tsl::Fingerprint32(*fingerprint_); VLOG(1) << "Fingerprint for executable " << ifrt_loaded_executable_->name() << ": " << *fingerprint_; } - options_.use_major_to_minor_data_layout_for_callbacks = true; } PyLoadedExecutable::~PyLoadedExecutable() { @@ -203,7 +201,7 @@ void PopulateExecuteShardedResults( template > absl::StatusOr ExecuteShardedOnLocalDevicesInternal( - const ExecuteOptions& options, const nb_class_ptr& client, + const ifrt::ExecuteOptions& options, const nb_class_ptr& client, ifrt::LoadedExecutable* ifrt_loaded_executable, absl::Span args, std::optional>>& returned_futures, bool attach_status_to_results) { diff --git a/third_party/xla/xla/python/py_executable.h b/third_party/xla/xla/python/py_executable.h index ed34ce99ef1a89..e032ee7b4acdda 100644 --- a/third_party/xla/xla/python/py_executable.h +++ b/third_party/xla/xla/python/py_executable.h @@ -227,7 +227,7 @@ class PyLoadedExecutable { return exec->shared_ptr_pjrt_loaded_executable(); } - const ExecuteOptions& options() const { return options_; } + const ifrt::ExecuteOptions& options() const { return options_; } const std::optional& fingerprint() const { return fingerprint_; } // Keep `obj` alive as long as PyLoadedExecutable. @@ -246,7 +246,7 @@ class PyLoadedExecutable { std::optional fingerprint_; // The options to pass to `executable_.Execute`. - ExecuteOptions options_; + ifrt::ExecuteOptions options_; // Python objects to keep alive as requested by user. std::vector keepalives_; From bb67f9de6af3d547f8d2aa79f51da7e0aa5c481d Mon Sep 17 00:00:00 2001 From: Alexander Lyashuk Date: Mon, 30 Sep 2024 10:04:46 -0700 Subject: [PATCH 430/483] [XLA:GPU] Add sub-byte normalization after TransposeDimensionGrouper TransposeDimensionGrouper inserts bitcasts, and XLA requires subbyte types (int4 in this case) to have explicit bit witdth. PiperOrigin-RevId: 680620568 --- third_party/xla/xla/service/gpu/gpu_compiler.cc | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/third_party/xla/xla/service/gpu/gpu_compiler.cc b/third_party/xla/xla/service/gpu/gpu_compiler.cc index 1f0b2b82509fde..f574d64a2290c4 100644 --- a/third_party/xla/xla/service/gpu/gpu_compiler.cc +++ b/third_party/xla/xla/service/gpu/gpu_compiler.cc @@ -1564,6 +1564,10 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment( !debug_options.xla_gpu_enable_priority_fusion(); pipeline.AddPass>(ignore_small_reduce_dims); pipeline.AddPass>(gpu_version); + // Normalization passes might have introduced s4 tensors without bit width + // annotations, this pass will add the annotations. + pipeline.AddPass( + SubByteNormalization::SET_ELEMENT_SIZE); TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status()); } From 4009c1de914ca5267d0845eccfa6e6dd53c02d44 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 30 Sep 2024 10:17:57 -0700 Subject: [PATCH 431/483] Updates GenerateScatterShardingFromOperands() to avoid hash sets (whose iteration order is nondeterministic). PiperOrigin-RevId: 680626186 --- .../auto_sharding/auto_sharding_strategy.cc | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc index 5e05087c2da1cd..08c926c5268c02 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc @@ -145,7 +145,12 @@ void GenerateScatterShardingFromOperands( const HloSharding& update_sharding, const HloSharding& scatter_sharding)> yield_sharding) { - absl::flat_hash_set scatter_shardings; + std::vector scatter_shardings; + auto scatter_shardings_insert = [&](const HloSharding& sharding) { + const auto it = + std::find(scatter_shardings.begin(), scatter_shardings.end(), sharding); + if (it == scatter_shardings.end()) scatter_shardings.push_back(sharding); + }; CHECK_EQ(scatter->scatter_operand_count(), 1); const HloInstruction* scatter_data = scatter->scatter_operands()[0]; const HloInstruction* scatter_indices = scatter->scatter_indices(); @@ -155,11 +160,11 @@ void GenerateScatterShardingFromOperands( ScatterIndexShardingFromUpdateIndexPassthroughDimensions(update_sharding, scatter); - scatter_shardings.insert(data_sharding); + scatter_shardings_insert(data_sharding); if (std::optional maybe_from_update = hlo_sharding_util::ScatterOutputShardingFromUpdate(update_sharding, *scatter)) { - scatter_shardings.insert(*maybe_from_update); + scatter_shardings_insert(*maybe_from_update); } std::optional @@ -182,21 +187,21 @@ void GenerateScatterShardingFromOperands( aligned_operand_parallel_dims; // Infer output sharding from scatter operand sharding. const Shape& shape = scatter->shape(); - scatter_shardings.insert( + scatter_shardings_insert( hlo_sharding_util::InferGatherScatterParallelShardingFromOperandSharding( data_sharding, scatter_data->shape(), shape, absl::MakeConstSpan(aligned_operand_parallel_dims), absl::MakeConstSpan(output_parallel_dims))); // Infer output sharding from scatter indices sharding. - scatter_shardings.insert( + scatter_shardings_insert( hlo_sharding_util::InferGatherScatterParallelShardingFromOperandSharding( indices_sharding, scatter_indices->shape(), shape, absl::MakeConstSpan(scatter_parallel_dims->indices_parallel_dims), absl::MakeConstSpan(output_parallel_dims))); // Infer output sharding from scatter update sharding. - scatter_shardings.insert( + scatter_shardings_insert( hlo_sharding_util::InferGatherScatterParallelShardingFromOperandSharding( update_sharding, scatter_update->shape(), shape, absl::MakeConstSpan(update_parallel_dims), From 21fc27f2a5ddb8ac35c236f38b49708d2ece605a Mon Sep 17 00:00:00 2001 From: Sergey Kozub Date: Mon, 30 Sep 2024 10:20:52 -0700 Subject: [PATCH 432/483] PR #17719: Allow compare/select on int4 data Imported from GitHub PR https://github.com/openxla/xla/pull/17719 A simple HLO with kCompare and kSelect opcodes doesn't work for int4 types. This PR fixes the issue. Copybara import of the project: -- 467bb68c4d2a1d1776c7254666a93bba90156bf1 by Sergey Kozub : Allow compare/select on int4 data Merging this change closes #17719 PiperOrigin-RevId: 680627161 --- third_party/xla/xla/service/hlo_verifier.cc | 8 ++++++-- third_party/xla/xla/service/hlo_verifier_test.cc | 15 +++++++++++++++ .../xla/xla/service/layout_normalization.cc | 6 +++++- .../xla/xla/service/layout_normalization_test.cc | 16 ++++++++++++++++ third_party/xla/xla/service/shape_inference.cc | 4 ++++ .../xla/xla/service/shape_inference_test.cc | 12 ++++++++++++ third_party/xla/xla/shape_util.cc | 3 +++ third_party/xla/xla/shape_util_test.cc | 6 ++++++ 8 files changed, 67 insertions(+), 3 deletions(-) diff --git a/third_party/xla/xla/service/hlo_verifier.cc b/third_party/xla/xla/service/hlo_verifier.cc index 6578b4ff765ec9..b3205232fafd98 100644 --- a/third_party/xla/xla/service/hlo_verifier.cc +++ b/third_party/xla/xla/service/hlo_verifier.cc @@ -2911,8 +2911,12 @@ class InstructionVerifier : public DfsHloVisitorWithDefault { const Layout& operand_layout = operand_shape.layout(); Layout::Equal equal_predicate = Layout::Equal().IgnoreTiles().IgnoreMemorySpace(); - if (instruction->opcode() == HloOpcode::kConvert) { - // Convert instructions can change element_size_in_bits + if (instruction->opcode() == HloOpcode::kConvert || + instruction->opcode() == HloOpcode::kCompare || + (instruction->opcode() == HloOpcode::kSelect && + operand_shape.element_type() == PRED)) { + // Convert and Compare instructions can change element_size_in_bits + // Select instructions ignore element_size_in_bits for predicate equal_predicate.IgnoreElementSize(); } else if (instruction->opcode() == HloOpcode::kDynamicSlice || instruction->opcode() == HloOpcode::kDynamicUpdateSlice || diff --git a/third_party/xla/xla/service/hlo_verifier_test.cc b/third_party/xla/xla/service/hlo_verifier_test.cc index 1737bea0eca27b..cfa7a48eaf1d0e 100644 --- a/third_party/xla/xla/service/hlo_verifier_test.cc +++ b/third_party/xla/xla/service/hlo_verifier_test.cc @@ -3488,5 +3488,20 @@ TEST_F(HloVerifierTest, NoErrorOnDuplicateChannelId) { ASSERT_IS_OK(verifier.Run(module.get()).status()); } +TEST_F(HloVerifierTestLayoutSensitive, Int4CompareSelect) { + const char* const kModuleStr = R"( + HloModule test + + ENTRY main { + a = s4[10]{0:E(4)} parameter(0) + b = s4[10]{0:E(4)} parameter(1) + less = pred[10] compare(a, b), direction=LT + ROOT result = select(less, a, b) + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnUnverifiedModule(kModuleStr)); + TF_ASSERT_OK(verifier().Run(module.get())); +} + } // namespace } // namespace xla diff --git a/third_party/xla/xla/service/layout_normalization.cc b/third_party/xla/xla/service/layout_normalization.cc index 16781509e22c60..74100a62e20111 100644 --- a/third_party/xla/xla/service/layout_normalization.cc +++ b/third_party/xla/xla/service/layout_normalization.cc @@ -347,7 +347,11 @@ class LayoutNormalizationVisitor : public DfsHloRewriteVisitor { auto s = hlo->shape(); auto a = hlo->mutable_operand(0); auto b = hlo->mutable_operand(1); - TF_RET_CHECK(a->shape().layout() == s.layout()); + auto layout_equal = Layout::Equal(); + if (hlo->opcode() == HloOpcode::kCompare) { + layout_equal.IgnoreElementSize(); + } + TF_RET_CHECK(layout_equal(a->shape().layout(), s.layout())); TF_ASSIGN_OR_RETURN(auto a0, GetNormalizedInput(a)); TF_ASSIGN_OR_RETURN(auto b0, GetNormalizedInput(b)); diff --git a/third_party/xla/xla/service/layout_normalization_test.cc b/third_party/xla/xla/service/layout_normalization_test.cc index 88ea4828ec597a..6fcf848ea46be8 100644 --- a/third_party/xla/xla/service/layout_normalization_test.cc +++ b/third_party/xla/xla/service/layout_normalization_test.cc @@ -922,5 +922,21 @@ ENTRY main.17 { }); } +TEST_F(LayoutNormalizationTest, CompareInt4) { + const char* hlo = R"( +HloModule module + +ENTRY main { + a = s4[10]{0:E(4)} parameter(0) + b = s4[10]{0:E(4)} parameter(1) + ROOT out = compare(a, b), direction=EQ +} +)"; + + CheckLayoutNormalization(hlo, R"( +// CHECK: pred[10]{0} compare({{.*}}) +)"); +} + } // namespace } // namespace xla diff --git a/third_party/xla/xla/service/shape_inference.cc b/third_party/xla/xla/service/shape_inference.cc index 4271cc897f41d7..cbcbf97e2e4f88 100644 --- a/third_party/xla/xla/service/shape_inference.cc +++ b/third_party/xla/xla/service/shape_inference.cc @@ -3755,6 +3755,10 @@ ShapeInference::InferCollectivePermuteDoneShape(const Shape& operand_shape) { on_false.is_dynamic_dimension(dimension)); } } + if (result.has_layout()) { + result.mutable_layout()->set_element_size_in_bits( + on_true.layout().element_size_in_bits()); + } return std::move(result); } diff --git a/third_party/xla/xla/service/shape_inference_test.cc b/third_party/xla/xla/service/shape_inference_test.cc index 29ae32add358e3..6c2cf78ab0245c 100644 --- a/third_party/xla/xla/service/shape_inference_test.cc +++ b/third_party/xla/xla/service/shape_inference_test.cc @@ -239,6 +239,18 @@ TEST_F(ShapeInferenceTest, SelectBadShapes) { HasSubstr("Expected array argument for select pred")); } +TEST_F(ShapeInferenceTest, SelectPreservesElementSize) { + Shape pred_shape = ShapeUtil::MakeShape(PRED, {10}); + Shape int4_shape = ShapeUtil::MakeShape(S4, {10}); + int4_shape.mutable_layout()->set_element_size_in_bits(4); + + const absl::StatusOr inferred_shape = + ShapeInference::InferTernaryOpShape(HloOpcode::kSelect, pred_shape, + int4_shape, int4_shape); + ASSERT_IS_OK(inferred_shape.status()); + ASSERT_TRUE(ShapeUtil::Equal(*inferred_shape, int4_shape)); +} + TEST_F(ShapeInferenceTest, ClampAllMatrix) { const absl::StatusOr inferred_shape = ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, matrix_64_48_, diff --git a/third_party/xla/xla/shape_util.cc b/third_party/xla/xla/shape_util.cc index 01f7cacfc9b441..9def58503a0854 100644 --- a/third_party/xla/xla/shape_util.cc +++ b/third_party/xla/xla/shape_util.cc @@ -1067,6 +1067,9 @@ Shape ShapeUtil::PrependMajorDimension(int64_t bound, Shape shape) { } else { Shape new_shape = original; new_shape.set_element_type(type); + if (new_shape.has_layout() && type == PRED) { + new_shape.mutable_layout()->set_element_size_in_bits(0); + } return new_shape; } } diff --git a/third_party/xla/xla/shape_util_test.cc b/third_party/xla/xla/shape_util_test.cc index e239a96ce6aa02..2ed50604569880 100644 --- a/third_party/xla/xla/shape_util_test.cc +++ b/third_party/xla/xla/shape_util_test.cc @@ -1224,6 +1224,12 @@ TEST(ShapeUtilTest, Int4ShapeSize) { layout->set_element_size_in_bits(4); EXPECT_EQ(ShapeUtil::ArrayDataSize(int4_shape2), 9216 * 6144 / 2); EXPECT_EQ(ShapeUtil::ArraySize(int4_shape2), 9216 * 6144 / 2); + + // Changing the type to PRED should clear element_size_in_bits. + Shape pred_shape = ShapeUtil::ChangeElementType(int4_shape, PRED); + EXPECT_EQ(pred_shape.layout().element_size_in_bits(), 0); + Shape u4_shape = ShapeUtil::ChangeElementType(int4_shape, U4); + EXPECT_EQ(u4_shape.layout().element_size_in_bits(), 4); } TEST(XlaShapeUtilTest, ZeroSize) { From 132f85b3ab002da12b6d798f7b23f17957955c6d Mon Sep 17 00:00:00 2001 From: Zixuan Jiang Date: Mon, 30 Sep 2024 10:42:42 -0700 Subject: [PATCH 433/483] [XLA:SPMD] Propagate shardings forward along explicit batch dims in gather/scatter instructions. We modify and use `InferScatterParallelShardingFromOperands` to propagate shardings along the explicit batch dims in the forward direction (operand -> result). PiperOrigin-RevId: 680635514 --- .../xla/xla/hlo/utils/hlo_sharding_util.cc | 61 +++++-- .../xla/xla/service/sharding_propagation.cc | 57 +++++-- .../xla/service/sharding_propagation_test.cc | 151 ++++++++++++++++++ 3 files changed, 238 insertions(+), 31 deletions(-) diff --git a/third_party/xla/xla/hlo/utils/hlo_sharding_util.cc b/third_party/xla/xla/hlo/utils/hlo_sharding_util.cc index 6d01b8fe6424fa..b4add60edcba55 100644 --- a/third_party/xla/xla/hlo/utils/hlo_sharding_util.cc +++ b/third_party/xla/xla/hlo/utils/hlo_sharding_util.cc @@ -2458,6 +2458,15 @@ GetGatherScatterIndexPassthroughOutputOrUpdateDims( return passthrough_dims; } +template +std::vector argsort(absl::Span data) { + std::vector indices(data.size()); + std::iota(indices.begin(), indices.end(), 0); + std::sort(indices.begin(), indices.end(), + [&data](int64_t i1, int64_t i2) { return data[i1] < data[i2]; }); + return indices; +} + HloSharding InferGatherScatterParallelShardingFromOperandSharding( const HloSharding& operand_sharding, const Shape& operand_shape, const Shape& shape, @@ -2466,36 +2475,56 @@ HloSharding InferGatherScatterParallelShardingFromOperandSharding( if (operand_sharding.IsTileMaximal()) { return operand_sharding; } - std::vector output_tile_dims(shape.rank(), 1); - std::vector operand_non_parallel_dims; - operand_non_parallel_dims.reserve(operand_shape.rank()); - // Detect non parallel dimensions in the operand. - for (int i = 0; i < operand_shape.rank(); ++i) { - if (!absl::c_linear_search(output_aligned_operand_parallel_dims, i)) { - operand_non_parallel_dims.push_back(i); + + HloSharding replicate_non_parallel_dims = + PartiallyReplicateTiledShardingOnAllDimsExcept( + operand_sharding, output_aligned_operand_parallel_dims); + if (replicate_non_parallel_dims.IsTileMaximal()) { + return replicate_non_parallel_dims; + } + + // output_aligned_operand_parallel_dims and output_parallel_dims may not be + // in the same order. We need to transpose the sharding accordingly. For + // example, if output_aligned_operand_parallel_dims = [2, 4, 1] and + // output_parallel_dims = [2, 1, 3], the sharding needs to be transposed with + // perm = [3, 2, 1, 4, 0] to adjust the order of devices. + std::vector argsort_output_aligned_operand_parallel_dims = + argsort(output_aligned_operand_parallel_dims); + std::vector argsort_output_parallel_dims = + argsort(output_parallel_dims); + if (argsort_output_aligned_operand_parallel_dims != + argsort_output_parallel_dims) { + std::vector perm( + replicate_non_parallel_dims.tile_assignment().num_dimensions(), -1); + for (int64_t i = 0; i < output_aligned_operand_parallel_dims.size(); ++i) { + perm[output_aligned_operand_parallel_dims + [argsort_output_parallel_dims[i]]] = i; + } + int64_t i = output_aligned_operand_parallel_dims.size(); + for (int64_t& perm_element : perm) { + if (perm_element == -1) { + perm_element = i++; + } } + replicate_non_parallel_dims = + TransposeSharding(replicate_non_parallel_dims, perm); } - // Collect tile dimensions in the operand. The order of the parallel - // dimensions in output_aligned_operand_parallel_dims is the same as that of - // the output + + // Collect tile dimensions in the operand. + std::vector output_tile_dims(shape.rank(), 1); for (int i = 0; i < output_aligned_operand_parallel_dims.size(); ++i) { const int64_t operand_idx = output_aligned_operand_parallel_dims[i]; const int64_t output_idx = output_parallel_dims[i]; output_tile_dims[output_idx] = operand_sharding.tile_assignment().dim(operand_idx); } - HloSharding replicate_non_parallel_dims = - PartiallyReplicateTiledShardingOnDims(operand_sharding, - operand_non_parallel_dims); - if (replicate_non_parallel_dims.IsTileMaximal()) { - return replicate_non_parallel_dims; - } for (int64_t i = replicate_non_parallel_dims.TiledDataRank(); i < replicate_non_parallel_dims.tile_assignment().num_dimensions(); ++i) { output_tile_dims.push_back( replicate_non_parallel_dims.tile_assignment().dim(i)); } + auto output_tile_assignment = replicate_non_parallel_dims.tile_assignment().Reshape(output_tile_dims); return replicate_non_parallel_dims.ReplicateOnLastTileDim() diff --git a/third_party/xla/xla/service/sharding_propagation.cc b/third_party/xla/xla/service/sharding_propagation.cc index 2c58d92a9d2a07..52e84f633a05d2 100644 --- a/third_party/xla/xla/service/sharding_propagation.cc +++ b/third_party/xla/xla/service/sharding_propagation.cc @@ -473,21 +473,20 @@ bool InferGatherParallelShardingFromOperands( bool may_combine_partial_sharding) { CHECK(DynCast(instruction)); bool changed = false; - auto aligned_operand_parallel_dims = parallel_dims.operand_parallel_dims; auto output_parallel_dims = hlo_sharding_util::GetGatherParallelOutputDims( *instruction, parallel_dims); - // Infer output sharding from scatter operand sharding. + // Infer output sharding from gather operand sharding. if (hlo_sharding_util::IsSpatiallyPartitioned(instruction->operand(0))) { changed |= MaybeImproveInstructionSharding( hlo_sharding_util:: InferGatherScatterParallelShardingFromOperandSharding( instruction->operand(0)->sharding(), instruction->operand(0)->shape(), instruction->shape(), - absl::MakeConstSpan(aligned_operand_parallel_dims), + absl::MakeConstSpan(parallel_dims.operand_parallel_dims), absl::MakeConstSpan(output_parallel_dims)), instruction, may_combine_partial_sharding); } - // Infer output sharding from scatter indices sharding. + // Infer output sharding from gather indices sharding. if (hlo_sharding_util::IsSpatiallyPartitioned(instruction->operand(1))) { changed |= MaybeImproveInstructionSharding( hlo_sharding_util:: @@ -514,10 +513,8 @@ bool InferScatterParallelShardingFromOperands( auto scatter_indices = scatter->scatter_indices(); auto scatter_updates = scatter->scatter_updates(); bool changed = false; - auto aligned_operand_parallel_dims = parallel_dims.operand_parallel_dims; auto update_parallel_dims = hlo_sharding_util::GetScatterParallelUpdateDims( *instruction, parallel_dims); - auto output_parallel_dims = aligned_operand_parallel_dims; // Infer output sharding from scatter operand sharding. Shape shape = operand_count == 1 ? instruction->shape() @@ -528,8 +525,9 @@ bool InferScatterParallelShardingFromOperands( hlo_sharding_util:: InferGatherScatterParallelShardingFromOperandSharding( scatter_operands[i]->sharding(), scatter_operands[i]->shape(), - shape, absl::MakeConstSpan(aligned_operand_parallel_dims), - absl::MakeConstSpan(output_parallel_dims)), + shape, + absl::MakeConstSpan(parallel_dims.operand_parallel_dims), + absl::MakeConstSpan(parallel_dims.operand_parallel_dims)), instruction, {i}, may_combine_partial_sharding); } } @@ -539,7 +537,7 @@ bool InferScatterParallelShardingFromOperands( InferGatherScatterParallelShardingFromOperandSharding( scatter_indices->sharding(), scatter_indices->shape(), shape, absl::MakeConstSpan(parallel_dims.indices_parallel_dims), - absl::MakeConstSpan(output_parallel_dims)); + absl::MakeConstSpan(parallel_dims.operand_parallel_dims)); for (int64_t i = 0; i != operand_count; ++i) { changed |= MaybeImproveInstructionSubSharding( parallel_sharding_from_indices, instruction, {i}, @@ -554,7 +552,7 @@ bool InferScatterParallelShardingFromOperands( InferGatherScatterParallelShardingFromOperandSharding( scatter_updates[i]->sharding(), scatter_updates[i]->shape(), shape, absl::MakeConstSpan(update_parallel_dims), - absl::MakeConstSpan(output_parallel_dims)), + absl::MakeConstSpan(parallel_dims.operand_parallel_dims)), instruction, {i}, may_combine_partial_sharding); } } @@ -2501,6 +2499,21 @@ bool ShardingPropagation::InferShardingFromOperands( } case HloOpcode::kGather: { bool changed = false; + + const GatherDimensionNumbers& dnums = + instruction->gather_dimension_numbers(); + if (!dnums.operand_batching_dims().empty()) { + hlo_sharding_util::GatherScatterParallelDims explict_batch_dims; + explict_batch_dims.operand_parallel_dims.assign( + dnums.operand_batching_dims().begin(), + dnums.operand_batching_dims().end()); + explict_batch_dims.indices_parallel_dims.assign( + dnums.start_indices_batching_dims().begin(), + dnums.start_indices_batching_dims().end()); + changed |= InferGatherParallelShardingFromOperands( + instruction, explict_batch_dims, may_combine_partial_sharding); + } + if (hlo_sharding_util::IsSpatiallyPartitioned(instruction->operand(1))) { HloSharding new_sharding = hlo_sharding_util:: GatherOutputShardingFromIndexIndexPassthroughDimensions( @@ -2540,11 +2553,26 @@ bool ShardingPropagation::InferShardingFromOperands( } case HloOpcode::kScatter: { auto& scatter = *Cast(instruction); + bool changed = false; + + const ScatterDimensionNumbers& dnums = + instruction->scatter_dimension_numbers(); + if (!dnums.input_batching_dims().empty()) { + hlo_sharding_util::GatherScatterParallelDims explict_batch_dims; + explict_batch_dims.operand_parallel_dims.assign( + dnums.input_batching_dims().begin(), + dnums.input_batching_dims().end()); + explict_batch_dims.indices_parallel_dims.assign( + dnums.scatter_indices_batching_dims().begin(), + dnums.scatter_indices_batching_dims().end()); + changed |= InferScatterParallelShardingFromOperands( + instruction, explict_batch_dims, may_combine_partial_sharding); + } + const int64_t operand_count = scatter.scatter_operand_count(); auto scatter_operands = scatter.scatter_operands(); auto scatter_indices = scatter.scatter_indices(); auto scatter_updates = scatter.scatter_updates(); - bool changed = false; if (is_spmd_) { for (int64_t i = 0; i != operand_count; ++i) { if (hlo_sharding_util::IsSpatiallyPartitioned(scatter_operands[i])) { @@ -2559,10 +2587,9 @@ bool ShardingPropagation::InferShardingFromOperands( })) { return changed; } - auto scatter_parallel_dims = - hlo_sharding_util::GetScatterParallelBatchDims(*instruction, - call_graph); - if (scatter_parallel_dims) { + if (auto scatter_parallel_dims = + hlo_sharding_util::GetScatterParallelBatchDims(*instruction, + call_graph)) { changed |= InferScatterParallelShardingFromOperands( instruction, *scatter_parallel_dims, may_combine_partial_sharding); diff --git a/third_party/xla/xla/service/sharding_propagation_test.cc b/third_party/xla/xla/service/sharding_propagation_test.cc index 96303a8d1b6880..f46e4e1a04c1c6 100644 --- a/third_party/xla/xla/service/sharding_propagation_test.cc +++ b/third_party/xla/xla/service/sharding_propagation_test.cc @@ -6487,6 +6487,157 @@ ENTRY %module { EXPECT_THAT(copy_p, op::Sharding("{replicated}")); } +TEST_F(ShardingPropagationTest, GatherExplicitBatchDimsFromOperandToResult) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %input = f32[10,3,14,4] parameter(0), sharding={devices=[2,2,2,2]<=[16]} + %indices = s32[14,10,6,2] parameter(1) + ROOT %gather = f32[14,10,6,4] gather(%input, %indices), offset_dims={3}, + collapsed_slice_dims={1}, operand_batching_dims={0,2}, + start_indices_batching_dims={1,0}, start_index_map={1,3}, + index_vector_dim=3, slice_sizes={1,1,1,4} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, + ShardingPropagation(/*is_spmd=*/true, /*propagate_metadata=*/true, + /*allow_spmd_sharding_propagation_to_output=*/{true}) + .Run(module.get())); + XLA_VLOG_LINES(1, module->ToString()); + EXPECT_TRUE(changed); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Sharding("{devices=[2,2,1,2,2]<=[2,2,2,2]T(2,0," + "3,1) last_tile_dim_replicate}")); +} + +TEST_F(ShardingPropagationTest, GatherExplicitBatchDimsFromIndicesToResult) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %input = f32[10,3,14,4] parameter(0) + %indices = s32[14,10,6,2] parameter(1), sharding={devices=[2,2,2,2]<=[16]} + ROOT %gather = f32[14,10,6,4] gather(%input, %indices), offset_dims={3}, + collapsed_slice_dims={1}, operand_batching_dims={0,2}, + start_indices_batching_dims={1,0}, start_index_map={1,3}, + index_vector_dim=3, slice_sizes={1,1,1,4} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, + ShardingPropagation(/*is_spmd=*/true, /*propagate_metadata=*/true, + /*allow_spmd_sharding_propagation_to_output=*/{true}) + .Run(module.get())); + XLA_VLOG_LINES(1, module->ToString()); + EXPECT_TRUE(changed); + EXPECT_THAT( + module->entry_computation()->root_instruction(), + op::Sharding("{devices=[2,2,2,1,2]<=[16] last_tile_dim_replicate}")); +} + +TEST_F(ShardingPropagationTest, ScatterExplicitBatchDimsFromOperandToResult) { + const char* const hlo_string = R"( +HloModule module + +min (lhs: f32[], rhs: f32[]) -> f32[] { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT min = f32[] minimum(lhs, rhs) +} + +ENTRY entry { + %input = f32[10,6,14,4] parameter(0), sharding={devices=[2,2,2,2]<=[16]} + %indices = s32[14,10,6,2] parameter(1) + %updates = f32[14,10,6,2] parameter(2) + ROOT %scatter = f32[10,6,14,4] scatter(%input, %indices, %updates), + to_apply=min, update_window_dims={3}, inserted_window_dims={1}, + scatter_dims_to_operand_dims={1,3}, input_batching_dims={0,2}, + scatter_indices_batching_dims={1,0}, index_vector_dim=3 +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, + ShardingPropagation(/*is_spmd=*/true, /*propagate_metadata=*/true, + /*allow_spmd_sharding_propagation_to_output=*/{true}) + .Run(module.get())); + XLA_VLOG_LINES(1, module->ToString()); + EXPECT_TRUE(changed); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Sharding("{devices=[2,2,2,2]<=[16]}")); +} + +TEST_F(ShardingPropagationTest, ScatterExplicitBatchDimsFromIndicesToResult) { + const char* const hlo_string = R"( +HloModule module + +min (lhs: f32[], rhs: f32[]) -> f32[] { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT min = f32[] minimum(lhs, rhs) +} + +ENTRY entry { + %input = f32[10,6,14,4] parameter(0) + %indices = s32[14,10,6,2] parameter(1), sharding={devices=[2,2,2,2]<=[16]} + %updates = f32[14,10,6,2] parameter(2) + ROOT %scatter = f32[10,6,14,4] scatter(%input, %indices, %updates), + to_apply=min, update_window_dims={3}, inserted_window_dims={1}, + scatter_dims_to_operand_dims={1,3}, input_batching_dims={0,2}, + scatter_indices_batching_dims={1,0}, index_vector_dim=3 +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, + ShardingPropagation(/*is_spmd=*/true, /*propagate_metadata=*/true, + /*allow_spmd_sharding_propagation_to_output=*/{true}) + .Run(module.get())); + XLA_VLOG_LINES(1, module->ToString()); + EXPECT_TRUE(changed); + EXPECT_THAT( + module->entry_computation()->root_instruction(), + op::Sharding( + "{devices=[2,1,2,1,4]<=[2,2,4]T(1,0,2) last_tile_dim_replicate}")); +} + +TEST_F(ShardingPropagationTest, ScatterExplicitBatchDimsFromUpdatesToResult) { + const char* const hlo_string = R"( +HloModule module + +min (lhs: f32[], rhs: f32[]) -> f32[] { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT min = f32[] minimum(lhs, rhs) +} + +ENTRY entry { + %input = f32[10,6,14,4] parameter(0) + %indices = s32[14,10,6,2] parameter(1) + %updates = f32[14,10,6,4] parameter(2), sharding={devices=[2,2,2,2]<=[16]} + ROOT %scatter = f32[10,6,14,4] scatter(%input, %indices, %updates), + to_apply=min, update_window_dims={3}, inserted_window_dims={1}, + scatter_dims_to_operand_dims={1,3}, input_batching_dims={0,2}, + scatter_indices_batching_dims={1,0}, index_vector_dim=3 +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, + ShardingPropagation(/*is_spmd=*/true, /*propagate_metadata=*/true, + /*allow_spmd_sharding_propagation_to_output=*/{true}) + .Run(module.get())); + XLA_VLOG_LINES(1, module->ToString()); + EXPECT_TRUE(changed); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Sharding("{devices=[2,1,2,2,2]<=[2,2,2,2]T(1,0,3,2) " + "last_tile_dim_replicate}")); +} + TEST_P(ParameterizedMetadataTest, ParallelGatherFromOperandForwardPass) { const char* const hlo_string = R"( HloModule module From 852478e1b1812b30c33b6e69475feb23b1aeb504 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 30 Sep 2024 11:34:53 -0700 Subject: [PATCH 434/483] Get the initial value for `useNewBackend` from the server. PiperOrigin-RevId: 680656732 --- .../core/profiler/convert/trace_viewer/trace_events_to_json.h | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tensorflow/core/profiler/convert/trace_viewer/trace_events_to_json.h b/tensorflow/core/profiler/convert/trace_viewer/trace_events_to_json.h index 66de83fe1991b4..f5585cb3eb9b08 100644 --- a/tensorflow/core/profiler/convert/trace_viewer/trace_events_to_json.h +++ b/tensorflow/core/profiler/convert/trace_viewer/trace_events_to_json.h @@ -68,6 +68,7 @@ struct JsonTraceOptions { TraceEventsColorerInterface* colorer = nullptr; bool generate_stack_frames = true; + bool use_new_backend = false; }; // Counts generated JSON events by type. @@ -516,6 +517,8 @@ void TraceEventsToJson(const JsonTraceOptions& options, output->Append( R"({"displayTimeUnit":"ns","metadata":{"highres-ticks":true},)"); + output->Append(absl::StrFormat(R"("useNewBackend": %s,)", + options.use_new_backend ? "true" : "false")); WriteDetails(options.details, output); WriteSelectedDeviceIds(options.selected_device_ids, output); WriteReturnedEventsSize(events.NumEvents(), output); From a926311f66e2f8f41d5cf718c0989fbfef21e10a Mon Sep 17 00:00:00 2001 From: Twice Date: Mon, 30 Sep 2024 11:35:29 -0700 Subject: [PATCH 435/483] PR #17737: [XLA:GPU] Rename `uint32_count` to `uint8_count` in `GPUDriver::AsynchronousMemsetUint8` Imported from GitHub PR https://github.com/openxla/xla/pull/17737 In `GpuDriver::AsynchronousMemsetUint8`, the count should be in bytes, instead of in uint32. Copybara import of the project: -- 17575b22ef2fa43d9409c2b29ea408e3fc958155 by PragmaTwice : [XLA:GPU] Rename uint32_count to uint8_count in GPUDriver::AsynchronousMemsetUint8 Merging this change closes #17737 PiperOrigin-RevId: 680656988 --- third_party/xla/xla/stream_executor/cuda/cuda_driver.cc | 4 ++-- third_party/xla/xla/stream_executor/gpu/gpu_driver.h | 3 +-- third_party/xla/xla/stream_executor/rocm/rocm_driver.cc | 5 ++--- 3 files changed, 5 insertions(+), 7 deletions(-) diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_driver.cc b/third_party/xla/xla/stream_executor/cuda/cuda_driver.cc index 39f0d211e1611a..40e55a9a07c5d6 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_driver.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_driver.cc @@ -1079,10 +1079,10 @@ absl::Status GpuDriver::SynchronousMemsetUint32(Context* context, absl::Status GpuDriver::AsynchronousMemsetUint8(Context* context, CUdeviceptr location, uint8_t value, - size_t uint32_count, + size_t uint8_count, CUstream stream) { ScopedActivateContext activation(context); - return cuda::ToStatus(cuMemsetD8Async(location, value, uint32_count, stream), + return cuda::ToStatus(cuMemsetD8Async(location, value, uint8_count, stream), "Failed to enqueue async memset operation"); } diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_driver.h b/third_party/xla/xla/stream_executor/gpu/gpu_driver.h index 2b299e544b307e..5c21b8ea1bfd44 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_driver.h +++ b/third_party/xla/xla/stream_executor/gpu/gpu_driver.h @@ -498,8 +498,7 @@ class GpuDriver { // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1gaef08a7ccd61112f94e82f2b30d43627 static absl::Status AsynchronousMemsetUint8(Context* context, GpuDevicePtr location, - uint8_t value, - size_t uint32_count, + uint8_t value, size_t uint8_count, GpuStreamHandle stream); // Performs an asynchronous memset of the device memory segment via diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_driver.cc b/third_party/xla/xla/stream_executor/rocm/rocm_driver.cc index 81f7f3d76fbd79..79af3605f3aeb8 100644 --- a/third_party/xla/xla/stream_executor/rocm/rocm_driver.cc +++ b/third_party/xla/xla/stream_executor/rocm/rocm_driver.cc @@ -955,12 +955,11 @@ absl::Status GpuDriver::SynchronousMemsetUint32(Context* context, absl::Status GpuDriver::AsynchronousMemsetUint8(Context* context, hipDeviceptr_t location, - uint8 value, - size_t uint32_count, + uint8 value, size_t uint8_count, GpuStreamHandle stream) { ScopedActivateContext activation{context}; RETURN_IF_ROCM_ERROR( - wrap::hipMemsetAsync(location, value, uint32_count, stream), + wrap::hipMemsetAsync(location, value, uint8_count, stream), "Failed to enqueue async memset operation"); return absl::OkStatus(); } From 1b240fae66605d39e14d4459450664c0ac20e97f Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 30 Sep 2024 11:47:09 -0700 Subject: [PATCH 436/483] copy the legalize tf passes under lite/stablehlo PiperOrigin-RevId: 680660984 --- tensorflow/compiler/mlir/lite/BUILD | 1 + tensorflow/compiler/mlir/lite/stablehlo/BUILD | 93 +- .../mlir/lite/stablehlo/odml_to_stablehlo.cc | 1 + .../lite/stablehlo/tests/legalize-tf.mlir | 2532 ++++++ .../lite/stablehlo/transforms/legalize_tf.cc | 6911 +++++++++++++++++ .../stablehlo/transforms/legalize_tf_passes.h | 51 + .../transforms/legalize_tf_patterns.td | 802 ++ .../stablehlo/transforms/tf_stablehlo_pass.cc | 4 +- .../mlir/lite/stablehlo/transforms/utils.cc | 55 + .../mlir/lite/stablehlo/transforms/utils.h | 61 + .../lite/stablehlo/transforms/utils_test.cc | 83 + .../mlir/lite/transforms/prepare_tf.cc | 1 + .../mlir/tf2xla/tests/legalize-tf.mlir | 3810 --------- 13 files changed, 10591 insertions(+), 3814 deletions(-) create mode 100644 tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tf.mlir create mode 100644 tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf.cc create mode 100644 tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf_passes.h create mode 100644 tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf_patterns.td create mode 100644 tensorflow/compiler/mlir/lite/stablehlo/transforms/utils.cc create mode 100644 tensorflow/compiler/mlir/lite/stablehlo/transforms/utils.h create mode 100644 tensorflow/compiler/mlir/lite/stablehlo/transforms/utils_test.cc diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD index 7bb70a19f4f116..3684b75dd13bbe 100644 --- a/tensorflow/compiler/mlir/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/BUILD @@ -906,6 +906,7 @@ cc_library( "//tensorflow/compiler/mlir:op_or_arg_name_mapper", "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps", "//tensorflow/compiler/mlir/lite/schema:schema_fbs", + "//tensorflow/compiler/mlir/lite/stablehlo:legalize_tf", "//tensorflow/compiler/mlir/lite/stablehlo:optimize_layout", "//tensorflow/compiler/mlir/lite/stablehlo:prepare_hlo", "//tensorflow/compiler/mlir/lite/stablehlo:tf_legalize_hlo", diff --git a/tensorflow/compiler/mlir/lite/stablehlo/BUILD b/tensorflow/compiler/mlir/lite/stablehlo/BUILD index a0c3febeead92f..53d08196547f42 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/BUILD +++ b/tensorflow/compiler/mlir/lite/stablehlo/BUILD @@ -1,5 +1,6 @@ load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") -load("//tensorflow:tensorflow.bzl", "tf_cc_binary") +load("@local_tsl//tsl/platform:build_config_root.bzl", "if_static") +load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test") load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") @@ -119,6 +120,92 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "legalize_utils", + srcs = ["transforms/utils.cc"], + hdrs = ["transforms/utils.h"], + deps = [ + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@local_xla//xla/mlir_hlo", + ], +) + +tf_cc_test( + name = "legalize_utils_test", + srcs = ["transforms/utils_test.cc"], + deps = [ + ":legalize_utils", + "@com_google_googletest//:gtest_main", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + "@local_xla//xla/mlir_hlo", + ], +) + +gentbl_cc_library( + name = "legalize_tf_patterns_inc_gen", + compatible_with = get_compatible_with_portable(), + tbl_outs = [ + ( + ["-gen-rewriters"], + "transforms/generated_legalize_tf.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "transforms/legalize_tf_patterns.td", + deps = [ + "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_td_files", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncTdFiles", + "@llvm-project//mlir:TensorOpsTdFiles", + "@local_xla//xla/mlir_hlo:hlo_ops_td_files", + ], +) + +cc_library( + name = "legalize_tf", + srcs = [ + "transforms/generated_legalize_tf.inc", + "transforms/legalize_tf.cc", + ], + hdrs = [ + "transforms/legalize_tf_passes.h", + ], + deps = [ + ":legalize_tf_patterns_inc_gen", + ":legalize_utils", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:dynamic_shape_utils", + "//tensorflow/compiler/mlir/tensorflow:xla_sharding_util", + "//tensorflow/core:framework", + "//tensorflow/core/kernels:conv_grad_shape_utils", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:Dialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:MemRefDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:ShapeDialect", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:TransformUtils", + "@local_tsl//tsl/platform:bfloat16", + "@local_tsl//tsl/platform:status", + "@local_tsl//tsl/platform:tensor_float_32_hdr_lib", + "@local_xla//xla:xla_data_proto_cc", + "@local_xla//xla/client:padding", + "@local_xla//xla/client:sharding_builder", + "@local_xla//xla/client/lib:conv_grad_size_util", + "@local_xla//xla/hlo/translate/hlo_to_mhlo:attribute_importer", + "@local_xla//xla/mlir_hlo", + "@local_xla//xla/mlir_hlo:convert_op_folder", + "@local_xla//xla/translate/hlo_to_mhlo:attribute_importer", + "@stablehlo//:chlo_ops", + ] + if_static(["@local_tsl//tsl/platform:tensor_float_32_utils"]), +) + cc_library( name = "tf_stablehlo", srcs = [ @@ -131,6 +218,7 @@ cc_library( "-Ithird_party", ], deps = [ + ":legalize_tf", ":stablehlo_util", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow/transforms:lower_tf_lib", @@ -953,6 +1041,7 @@ tf_cc_binary( " [tf.lite.OpsSet.EXPERIMENTAL_STABLEHLO_OPS]", deps = [ ":check_accepted_ops_pass", + ":legalize_tf", ":op_stat_pass", ":stablehlo_util", ":transforms", @@ -969,7 +1058,7 @@ tf_cc_binary( "//tensorflow/compiler/mlir/tensorflow/transforms:tensorflow_passes", "//tensorflow/compiler/mlir/tensorflow/transforms:tf_graph_optimization_pass", "//tensorflow/compiler/mlir/tf2xla:compile_mlir_util", - "//tensorflow/compiler/mlir/tf2xla/transforms:legalize_tf", + "//tensorflow/compiler/mlir/tf2xla/transforms:xla_legalize_tf", "//tensorflow/core:core_cpu_base", "//tensorflow/core:lib", "//tensorflow/core/ir/types:Dialect", diff --git a/tensorflow/compiler/mlir/lite/stablehlo/odml_to_stablehlo.cc b/tensorflow/compiler/mlir/lite/stablehlo/odml_to_stablehlo.cc index c2579fb3619911..bd18a351bd86e8 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/odml_to_stablehlo.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/odml_to_stablehlo.cc @@ -56,6 +56,7 @@ limitations under the License. #include "tensorflow/cc/saved_model/loader.h" #include "tensorflow/compiler/mlir/init_mlir.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/check_accepted_ops_pass.h" +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf_passes.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/op_stat_pass.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_util.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/transforms.h" diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tf.mlir new file mode 100644 index 00000000000000..bc2ce85d20f9f2 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tf.mlir @@ -0,0 +1,2532 @@ +// RUN: odml-to-stablehlo-opt --tf-stablehlo \ +// RUN: %s | FILECHECK_OPTS="" FileCheck %s + +//===----------------------------------------------------------------------===// +// BatchNorm op legalizations. +//===----------------------------------------------------------------------===// + +// ----- + +// fusedBatchNormV2 is almost identical to fusedBatchNormV3 (and uses the same +// code), so only do a couple of basic checks. + +// CHECK-LABEL: fusedBatchNormV2_noTraining +func.func @fusedBatchNormV2_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) { + // CHECK: "stablehlo.batch_norm_inference"({{.*}}, %arg1, %arg2, %arg3, %arg4) <{epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}> : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32> + %0:5 = "tf.FusedBatchNormV2"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) + func.return %0#0 : tensor<8x8x8x8xf32> +} + +// ----- + +// CHECK-LABEL: fusedBatchNormV2_training +func.func @fusedBatchNormV2_training(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) { + // CHECK: %[[OUT:.*]], %[[MEAN:.*]], %[[VAR:.*]] = "stablehlo.batch_norm_training"({{.*}}, %arg1, %arg2) <{epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}> : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) + %0:5 = "tf.FusedBatchNormV2"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, exponential_avg_factor = 1.0 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) + func.return %0#0 : tensor<8x8x8x8xf32> +} + +// ----- + +// CHECK-LABEL: fusedBatchNormV3_noTraining +func.func @fusedBatchNormV3_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) { + // CHECK: "stablehlo.batch_norm_inference"({{.*}}, %arg1, %arg2, %arg3, %arg4) <{epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}> : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32> + %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) + func.return %0#0 : tensor<8x8x8x8xf32> +} + +// ----- + +// CHECK-LABEL: fusedBatchNormV3_noTraining_mixedPrecision +// CHECK-SAME: ([[X:%.*]]: tensor<8x8x8x8xbf16>, [[SCALE:%.*]]: tensor<8xf32>, [[OFFSET:%.*]]: tensor<8xf32>, [[MEAN:%.*]]: tensor<8xf32>, [[VARIANCE:%.*]]: tensor<8xf32>) +func.func @fusedBatchNormV3_noTraining_mixedPrecision(%arg0: tensor<8x8x8x8xbf16>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<*xf32>) { + // CHECK: [[DUMMY:%.*]] = stablehlo.constant dense<0.000000e+00> : tensor<0xf32> + // CHECK: [[CONVERT_X:%.*]] = stablehlo.convert [[X]] : (tensor<8x8x8x8xbf16>) -> tensor<8x8x8x8xf32> + // CHECK: [[Y:%.*]] = "stablehlo.batch_norm_inference"([[CONVERT_X]], [[SCALE]], [[OFFSET]], [[MEAN]], [[VARIANCE]]) <{epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}> + %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<*xf32>) + // CHECK: [[Y_CONVERT:%.*]] = stablehlo.convert [[Y]] : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xbf16> + // CHECK: [[DUMMY_CAST:%.*]] = tensor.cast [[DUMMY]] : tensor<0xf32> to tensor<*xf32> + // CHECK: return [[Y_CONVERT]], [[MEAN]], [[VARIANCE]], [[MEAN]], [[VARIANCE]], [[DUMMY_CAST]] + func.return %0#0, %0#1, %0#2, %0#3, %0#4, %0#5 : tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<*xf32> +} + +// ----- + +// CHECK-LABEL: fusedBatchNormV3_training +func.func @fusedBatchNormV3_training(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) { + // CHECK: %[[OUT:.*]], %[[MEAN:.*]], %[[VAR:.*]] = "stablehlo.batch_norm_training"({{.*}}, %arg1, %arg2) <{epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}> : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) + %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, exponential_avg_factor = 1.0 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) + func.return %0#0 : tensor<8x8x8x8xf32> +} + +// ----- + +// CHECK-LABEL: func @fusedBatchNormV3_training_batchVariance +func.func @fusedBatchNormV3_training_batchVariance(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> tensor<8xf32> { + // CHECK: %[[OUT:.*]], %[[MEAN:.*]], %[[VAR:.*]] = "stablehlo.batch_norm_training"({{.*}}, %arg1, %arg2) <{epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}> : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) + %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, exponential_avg_factor = 1.0 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) + // CHECK: return %[[VAR]] + func.return %0#4 : tensor<8xf32> +} + +// ----- + +// CHECK-LABEL: fusedBatchNormV3_training_exponentialAvgFactor +func.func @fusedBatchNormV3_training_exponentialAvgFactor(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) { + // CHECK-DAG: %[[ALPHA:.*]] = stablehlo.constant dense<0.199999988> + // CHECK-DAG: %[[BETA:.*]] = stablehlo.constant dense<8.000000e-01> + // CHECK-DAG: %[[FACTOR:.*]] = stablehlo.constant dense<1.00195694> + // CHECK: %[[OUT:.*]], %[[MEAN:.*]], %[[VAR:.*]] = "stablehlo.batch_norm_training"({{.*}}, %arg1, %arg2) <{epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}> : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) + %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, exponential_avg_factor = 0.8 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) + // CHECK: %[[CORRECTED_VAR:.*]] = stablehlo.multiply %[[VAR]], %[[FACTOR]] + + // CHECK: %[[ALPHA_MUL_OLD_MEAN:.*]] = stablehlo.multiply %arg3, %[[ALPHA]] + // CHECK: %[[BETA_MUL_BATCH_MEAN:.*]] = stablehlo.multiply %[[MEAN]], %[[BETA]] + // CHECK: %[[NEW_BATCH_MEAN:.*]] = stablehlo.add %[[ALPHA_MUL_OLD_MEAN]], %[[BETA_MUL_BATCH_MEAN]] + + // CHECK: %[[ALPHA_MUL_OLD_VAR:.*]] = stablehlo.multiply %arg4, %[[ALPHA]] + // CHECK: %[[BETA_MUL_CORRECTED_VAR:.*]] = stablehlo.multiply %[[CORRECTED_VAR]], %[[BETA]] + // CHECK: %[[NEW_BATCH_VAR:.*]] = stablehlo.add %[[ALPHA_MUL_OLD_VAR]], %[[BETA_MUL_CORRECTED_VAR]] + + // CHECK: return %[[NEW_BATCH_MEAN]], %[[NEW_BATCH_VAR]], %[[MEAN]], %[[VAR]] + func.return %0#1, %0#2, %0#3, %0#4 : tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32> +} + +// ----- + +// CHECK-LABEL: fusedBatchNormV3_training_mixedPrecision +func.func @fusedBatchNormV3_training_mixedPrecision(%arg0: tensor<8x8x8x8xbf16>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xbf16>) { + // CHECK: stablehlo.convert %arg0 : (tensor<8x8x8x8xbf16>) -> tensor<8x8x8x8xf32> + %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, exponential_avg_factor = 1.0 : f32, is_training = true} : (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) + // CHECK: stablehlo.convert {{.*}} : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xbf16> + func.return %0#0 : tensor<8x8x8x8xbf16> +} + +// ----- + +// CHECK-LABEL: fusedBatchNormV3_NCHW +func.func @fusedBatchNormV3_NCHW(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) { + // CHECK: "stablehlo.batch_norm_training"({{.*}}, %arg1, %arg2) <{epsilon = 1.000000e-03 : f32, feature_index = 1 : i64}> : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) + %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NCHW", epsilon = 0.001 : f32, exponential_avg_factor = 1.0 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) + func.return %0#0 : tensor<8x8x8x8xf32> +} + +// ----- + +// CHECK-LABEL: fusedBatchNormV3_NDHWC +func.func @fusedBatchNormV3_NDHWC(%arg0: tensor<8x8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8x8xf32>) { + // CHECK: feature_index = 4 : i64 + %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NDHWC", epsilon = 0.001 : f32, exponential_avg_factor = 1.0 : f32, is_training = true} : (tensor<8x8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) + func.return %0#0 : tensor<8x8x8x8x8xf32> +} + +// ----- + +// CHECK-LABEL: fusedBatchNormV3_noTraining_dynamic_supported +func.func @fusedBatchNormV3_noTraining_dynamic_supported(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor) -> (tensor) { + // CHECK: "stablehlo.batch_norm_inference"({{.*}}, %arg1, %arg2, %arg3, %arg4) <{epsilon = 1.000000e-03 : f32, feature_index = 1 : i64}> : (tensor, tensor, tensor, tensor, tensor) -> tensor + %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NCHW", epsilon = 0.001 : f32, exponential_avg_factor = 1.0 : f32, is_training = false} : (tensor, tensor, tensor, tensor, tensor) -> (tensor, tensor, tensor, tensor, tensor, tensor) + func.return %0#0 : tensor +} + +// ----- + +// CHECK-LABEL: fusedBatchNormV3_training_dynamic_unsupported1 +func.func @fusedBatchNormV3_training_dynamic_unsupported1(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor) -> (tensor) { + // CHECK: tf.FusedBatchNormV3 + %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NCHW", epsilon = 0.001 : f32, exponential_avg_factor = 1.0 : f32, is_training = true} : (tensor, tensor, tensor, tensor, tensor) -> (tensor, tensor, tensor, tensor, tensor, tensor) + func.return %0#0 : tensor +} + +// ----- + +// CHECK-LABEL: fusedBatchNormV3_training_dynamic_unsupported2 +func.func @fusedBatchNormV3_training_dynamic_unsupported2(%arg0: tensor, %arg1: tensor<6xf32>, %arg2: tensor<6xf32>, %arg3: tensor<6xf32>, %arg4: tensor<6xf32>) -> (tensor) { + // CHECK: tf.FusedBatchNormV3 + %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NCHW", epsilon = 0.001 : f32, exponential_avg_factor = 1.0 : f32, is_training = true} : (tensor, tensor<6xf32>, tensor<6xf32>, tensor<6xf32>, tensor<6xf32>) -> (tensor, tensor<6xf32>, tensor<6xf32>, tensor<6xf32>, tensor<6xf32>, tensor<6xf32>) + func.return %0#0 : tensor +} + +// ----- + +// CHECK-LABEL: fusedBatchNormGrad_noTraining +func.func @fusedBatchNormGrad_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) { + // CHECK-NEXT: %[[EPS:.*]] = stablehlo.constant dense<1.000000e-03> : tensor<8xf32> + + // CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %arg4, %[[EPS]] : tensor<8xf32> + // CHECK-NEXT: %[[RSQRT:.*]] = stablehlo.rsqrt %[[ADD]] : tensor<8xf32> + + // CHECK-NEXT: %[[MUL2:.*]] = stablehlo.multiply %arg2, %[[RSQRT]] : tensor<8xf32> + // CHECK-NEXT: %[[BCAST_MUL2:.+]] = stablehlo.broadcast_in_dim %[[MUL2]], {{.*}} : (tensor<8xf32>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[MUL3:.*]] = stablehlo.multiply %arg0, %[[BCAST_MUL2]] : tensor<8x8x8x8xf32> + // CHECK-NEXT: return %[[MUL3]] : tensor<8x8x8x8xf32> + + %0:5 = "tf.FusedBatchNormGrad"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) + func.return %0#0 : tensor<8x8x8x8xf32> +} + +// ----- + +// CHECK-LABEL: fusedBatchNormGrad_Training +func.func @fusedBatchNormGrad_Training(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) { + // CHECK-NEXT: %[[GRAD_OPERAND:.*]], %[[GRAD_SCALE:.*]], %[[GRAD_OFFSET:.*]] = "stablehlo.batch_norm_grad"(%arg1, %arg2, %arg3, %arg4, %arg0) {{.*}} + // CHECK-NEXT: return %[[GRAD_OPERAND]] : tensor<8x8x8x8xf32> + + %0:5 = "tf.FusedBatchNormGrad"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) + func.return %0#0 : tensor<8x8x8x8xf32> +} + +// ----- + +// CHECK-LABEL: fusedBatchNormGradV2_noTraining +func.func @fusedBatchNormGradV2_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) { + // CHECK-NEXT: %[[EPS:.*]] = stablehlo.constant dense<1.000000e-03> : tensor<8xf32> + // CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %arg4, %[[EPS]] : tensor<8xf32> + // CHECK-NEXT: %[[RSQRT:.*]] = stablehlo.rsqrt %[[ADD]] : tensor<8xf32> + // CHECK-NEXT: %[[MUL:.*]] = stablehlo.multiply %arg2, %[[RSQRT]] : tensor<8xf32> + // CHECK-NEXT: %[[BCAST:.*]] = stablehlo.broadcast_in_dim %[[MUL]], dims = [3] : (tensor<8xf32>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[MUL2:.*]] = stablehlo.multiply %arg0, %[[BCAST]] : tensor<8x8x8x8xf32> + // CHECK-NEXT: return %[[MUL2:.*]] : tensor<8x8x8x8xf32> + + %0:5 = "tf.FusedBatchNormGradV2"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) + func.return %0#0 : tensor<8x8x8x8xf32> +} + +// ----- + +// CHECK-LABEL: fusedBatchNormGradV2_Training +func.func @fusedBatchNormGradV2_Training(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) { + // CHECK-NEXT: %[[GRAD_OPERAND:.*]], %[[GRAD_SCALE:.*]], %[[GRAD_OFFSET:.*]] = "stablehlo.batch_norm_grad"(%arg1, %arg2, %arg3, %arg4, %arg0) {{.*}} + // CHECK-NEXT: return %[[GRAD_OPERAND]] : tensor<8x8x8x8xf32> + + %0:5 = "tf.FusedBatchNormGradV2"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) + func.return %0#0 : tensor<8x8x8x8xf32> +} + +// ----- + +// CHECK-LABEL: fusedBatchNormGradV2_noTraining_mixed_precision +func.func @fusedBatchNormGradV2_noTraining_mixed_precision(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xbf16>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xbf16>) { + // CHECK-NEST: %[[CST:.*]] = stablehlo.constant dense<1.000000e-03> : tensor<8xf32> + // CHECK-NEST: %[[ADD:.*]] = stablehlo.add %arg4, %[[CST]] : tensor<8xf32> + // CHECK-NEST: %[[RSQRT:.*]] = stablehlo.rsqrt %[[ADD]] : tensor<8xf32> + // CHECK-NEST: %[[MUL:.*]] = stablehlo.multiply %arg2, %[[RSQRT]] : tensor<8xf32> + // CHECK-NEST: %[[BCAST:.*]] = stablehlo.broadcast_in_dim %[[MUL]], dims = [3] : (tensor<8xf32>) -> tensor<8x8x8x8xf32> + // CHECK-NEST: %[[MUL2:.*]] = stablehlo.multiply %arg0, %[[BCAST]] : tensor<8x8x8x8xf32> + // CHECK-NEST: %[[CONVERT:.*]] = stablehlo.convert %[[MUL2]] : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xbf16> + // CHECK-NEST: return %[[CONVERT]] : tensor<8x8x8x8xbf16> + + %0:5 = "tf.FusedBatchNormGradV2"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) + func.return %0#0 : tensor<8x8x8x8xbf16> +} + +// ----- + +// CHECK-LABEL: fusedBatchNormGradV2_Training_mixed_precision +func.func @fusedBatchNormGradV2_Training_mixed_precision(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xbf16>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xbf16>) { + // CHECK-NEXT: %[[CONVERT:.*]] = stablehlo.convert %arg1 : (tensor<8x8x8x8xbf16>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[GRAD_OPERAND:.*]], %[[GRAD_SCALE:.*]], %[[GRAD_OFFSET:.*]] = "stablehlo.batch_norm_grad"(%[[CONVERT]], %arg2, %arg3, %arg4, %arg0) <{epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}> : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8x8x8x8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) + // CHECK-NEXT: %[[CONVERT:.*]] = stablehlo.convert %[[GRAD_OPERAND]] : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xbf16> + // CHECK-NEXT: return %[[CONVERT]] : tensor<8x8x8x8xbf16> + + %0:5 = "tf.FusedBatchNormGradV2"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) + func.return %0#0 : tensor<8x8x8x8xbf16> +} + +// ----- + +// CHECK-LABEL: fusedBatchNormGradV3_noTraining +func.func @fusedBatchNormGradV3_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>, %arg5: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) { +// CHECK-NEXT: %[[EPS:.*]] = stablehlo.constant dense<1.000000e-03> : tensor<8xf32> +// CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %arg4, %[[EPS]] : tensor<8xf32> +// CHECK-NEXT: %[[RSQRT:.*]] = stablehlo.rsqrt %[[ADD]] : tensor<8xf32> +// CHECK-NEXT: %[[MUL:.*]] = stablehlo.multiply %arg2, %[[RSQRT]] : tensor<8xf32> +// CHECK-NEXT: %[[BCAST:.*]] = stablehlo.broadcast_in_dim %[[MUL]], dims = [3] : (tensor<8xf32>) -> tensor<8x8x8x8xf32> +// CHECK-NEXT: %[[MUL2:.*]] = stablehlo.multiply %arg0, %[[BCAST]] : tensor<8x8x8x8xf32> +// CHECK-NEXT: return %[[MUL2]] : tensor<8x8x8x8xf32> + + %0:5 = "tf.FusedBatchNormGradV3"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) + func.return %0#0 : tensor<8x8x8x8xf32> +} + +// ----- + +// CHECK-LABEL: fusedBatchNormGradV3_Training +func.func @fusedBatchNormGradV3_Training(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>, %arg5: tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<0xf32>, tensor<*xf32>) { + // CHECK-NEXT: %[[EPS:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<0xf32> + // CHECK-NEXT: %[[GRAD_OPERAND:.*]], %[[GRAD_SCALE:.*]], %[[GRAD_OFFSET:.*]] = "stablehlo.batch_norm_grad"(%arg1, %arg2, %arg3, %arg4, %arg0) <{epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}> : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8x8x8x8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) + // CHECK-NEXT: %[[CAST:.*]] = tensor.cast %[[EPS]] : tensor<0xf32> to tensor<*xf32> + // CHECK-NEXT: return %[[GRAD_OPERAND]], %[[EPS]], %[[CAST]] : tensor<8x8x8x8xf32>, tensor<0xf32>, tensor<*xf32> + + %0:5 = "tf.FusedBatchNormGradV3"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<0xf32>, tensor<*xf32>) + func.return %0#0, %0#3, %0#4 : tensor<8x8x8x8xf32>, tensor<0xf32>, tensor<*xf32> +} + +// ----- + +// CHECK-LABEL: fusedBatchNormGradV3_noTraining_mixed_precision +func.func @fusedBatchNormGradV3_noTraining_mixed_precision(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xbf16>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>, %arg5: tensor<8xf32>) -> (tensor<8x8x8x8xbf16>) { + // CHECK-NEXT: %[[EPS:.*]] = stablehlo.constant dense<1.000000e-03> : tensor<8xf32> + // CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %arg4, %[[EPS]] : tensor<8xf32> + // CHECK-NEXT: %[[RSQRT:.*]] = stablehlo.rsqrt %[[ADD]] : tensor<8xf32> + // CHECK-NEXT: %[[MUL:.*]] = stablehlo.multiply %arg2, %[[RSQRT]] : tensor<8xf32> + // CHECK-NEXT: %[[BCAST:.*]] = stablehlo.broadcast_in_dim %[[MUL]], dims = [3] : (tensor<8xf32>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[MUL2:.*]] = stablehlo.multiply %arg0, %[[BCAST]] : tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[CONVERT:.*]] = stablehlo.convert %[[MUL2]] : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xbf16> + // CHECK-NEXT: return %[[CONVERT]] : tensor<8x8x8x8xbf16> + + %0:5 = "tf.FusedBatchNormGradV3"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) + func.return %0#0 : tensor<8x8x8x8xbf16> +} + +// ----- + +// CHECK-LABEL: fusedBatchNormGradV3_Training_mixed_precision +func.func @fusedBatchNormGradV3_Training_mixed_precision(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xbf16>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>, %arg5: tensor<8xf32>) -> (tensor<8x8x8x8xbf16>) { + // CHECK-NEXT: %[[CONVERT:.*]] = stablehlo.convert %arg1 : (tensor<8x8x8x8xbf16>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[GRAD_OPERAND:.*]], %[[GRAD_SCALE:.*]], %[[GRAD_OFFSET:.*]] = "stablehlo.batch_norm_grad"(%[[CONVERT]], %arg2, %arg3, %arg4, %arg0) <{epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}> : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8x8x8x8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) + // CHECK-NEXT: %[[CONVERT2:.*]] = stablehlo.convert %[[GRAD_OPERAND]] : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xbf16> + // CHECK-NEXT: return %[[CONVERT2]] : tensor<8x8x8x8xbf16> + + %0:5 = "tf.FusedBatchNormGradV3"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) + func.return %0#0 : tensor<8x8x8x8xbf16> +} + +// ----- + +// CHECK-LABEL: fusedBatchNormGradV3_noTraining_NCHW +func.func @fusedBatchNormGradV3_noTraining_NCHW(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>, %arg5: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) { + // CHECK-NEXT: %[[EPS:.*]] = stablehlo.constant dense<1.000000e-03> : tensor<8xf32> + // CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %arg4, %[[EPS]] : tensor<8xf32> + // CHECK-NEXT: %[[RSQRT:.*]] = stablehlo.rsqrt %[[ADD]] : tensor<8xf32> + // CHECK-NEXT: %[[MUL:.*]] = stablehlo.multiply %arg2, %[[RSQRT]] : tensor<8xf32> + // CHECK-NEXT: %[[BCAST:.*]] = stablehlo.broadcast_in_dim %[[MUL]], dims = [1] : (tensor<8xf32>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[MUL2:.*]] = stablehlo.multiply %arg0, %[[BCAST]] : tensor<8x8x8x8xf32> + // CHECK-NEXT: return %[[MUL2]] : tensor<8x8x8x8xf32> + + %0:5 = "tf.FusedBatchNormGradV3"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {T = "tfdtype$DT_FLOAT", data_format = "NCHW", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) + func.return %0#0 : tensor<8x8x8x8xf32> +} + +// ----- + +// CHECK-LABEL: fusedBatchNormGradV3_Training_NCHW +func.func @fusedBatchNormGradV3_Training_NCHW(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>, %arg5: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) { + // CHECK-NEXT: %[[GRAD_OPERAND:.*]], %[[GRAD_SCALE:.*]], %[[GRAD_OFFSET:.*]] = "stablehlo.batch_norm_grad"(%arg1, %arg2, %arg3, %arg4, %arg0) <{epsilon = 1.000000e-03 : f32, feature_index = 1 : i64}> : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8x8x8x8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) + // CHECK-NEXT: return %[[GRAD_OPERAND]] : tensor<8x8x8x8xf32> + %0:5 = "tf.FusedBatchNormGradV3"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {T = "tfdtype$DT_FLOAT", data_format = "NCHW", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) + func.return %0#0 : tensor<8x8x8x8xf32> +} + +//===----------------------------------------------------------------------===// +// Bias op legalizations. +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: func @biasAdd_default +func.func @biasAdd_default(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> { + // CHECK-NEXT: %[[BCAST:.*]] = stablehlo.broadcast_in_dim %arg1, dims = [3] : (tensor<32xi32>) -> tensor<1x32x10x32xi32> + // CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %arg0, %[[BCAST]] : tensor<1x32x10x32xi32> + // CHECK-NEXT: return %[[ADD]] : tensor<1x32x10x32xi32> + %0 = "tf.BiasAdd"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT"} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32> + func.return %0 : tensor<1x32x10x32xi32> +} + +// ----- + +// CHECK-LABEL: func @biasAdd_NHWC +func.func @biasAdd_NHWC(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> { + // CHECK-NEXT: %[[BCAST:.*]] = stablehlo.broadcast_in_dim %arg1, dims = [3] : (tensor<32xi32>) -> tensor<1x32x10x32xi32> + // CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %arg0, %[[BCAST]] : tensor<1x32x10x32xi32> + // CHECK-NEXT: return %[[ADD]] : tensor<1x32x10x32xi32> + %0 = "tf.BiasAdd"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", data_format = "NHWC"} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32> + func.return %0 : tensor<1x32x10x32xi32> +} + +// ----- + +// CHECK-LABEL: func @biasAdd_NCHW +func.func @biasAdd_NCHW(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> { + // CHECK-NEXT: %[[BCAST:.*]] = stablehlo.broadcast_in_dim %arg1, dims = [1] : (tensor<32xi32>) -> tensor<1x32x10x32xi32> + // CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %arg0, %[[BCAST]] : tensor<1x32x10x32xi32> + // CHECK-NEXT: return %[[ADD]] : tensor<1x32x10x32xi32> + %0 = "tf.BiasAdd"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", data_format = "NCHW"} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32> + func.return %0 : tensor<1x32x10x32xi32> +} + +// ----- + +// CHECK-LABEL: func @biasAdd_dynamic +func.func @biasAdd_dynamic(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK-NEXT: %[[SHAPE:.*]] = shape.shape_of %arg0 : tensor -> tensor<4xindex> + // CHECK-NEXT: %[[BCAST:.*]] = stablehlo.dynamic_broadcast_in_dim %arg1, %[[SHAPE]], dims = [1] : (tensor, tensor<4xindex>) -> tensor + // CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %arg0, %[[BCAST]] : tensor + // CHECK-NEXT: return %[[ADD]] : tensor + %0 = "tf.BiasAdd"(%arg0, %arg1) {data_format = "NCHW"} : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// ----- + +// CHECK-LABEL: func @biasAdd_partial_dynamic +func.func @biasAdd_partial_dynamic(%arg0: tensor, %arg1: tensor<512xi32>) -> tensor { + // CHECK-NEXT: %[[SHAPE:.*]] = shape.shape_of %arg0 : tensor -> tensor<4xindex> + // CHECK-NEXT: %[[BCAST:.*]] = stablehlo.dynamic_broadcast_in_dim %arg1, %[[SHAPE]], dims = [3] : (tensor<512xi32>, tensor<4xindex>) -> tensor + // CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %arg0, %[[BCAST]] : tensor + // CHECK-NEXT: %[[CAST:.*]] = tensor.cast %[[ADD]] : tensor to tensor + // CHECK-NEXT: return %[[CAST]] : tensor + %0 = "tf.BiasAdd"(%arg0, %arg1) {data_format = "NHWC"} : (tensor, tensor<512xi32>) -> tensor + func.return %0 : tensor +} + + +//===----------------------------------------------------------------------===// +// ClipByValue +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: @clip +func.func @clip(%arg0 : tensor, %arg1 : tensor, %arg2 : tensor) -> tensor { + // CHECK: [[VAL:%.+]] = stablehlo.clamp %arg1, %arg0, %arg2 + + %0 = "tf.ClipByValue"(%arg0, %arg1, %arg2) : (tensor, tensor, tensor) -> tensor + // CHECK: return [[VAL]] + func.return %0 : tensor +} + +// ----- + +// CHECK-LABEL: @clip_dynamic +func.func @clip_dynamic(%arg0 : tensor, %arg1 : tensor, %arg2 : tensor) -> tensor { + // CHECK-DAG: [[CLAMP:%.+]] = stablehlo.clamp %arg1, %arg0, %arg2 + %0 = "tf.ClipByValue"(%arg0, %arg1, %arg2) : (tensor, tensor, tensor) -> tensor + + // CHECK: return [[CLAMP]] + func.return %0 : tensor +} + +// ----- + +// CHECK-LABEL: @clip_static_broadcast +func.func @clip_static_broadcast(%arg0 : tensor<5xf32>, %arg1 : tensor, %arg2 : tensor) -> tensor<5xf32> { + // CHECK-DAG: [[BROADCAST_MIN:%.+]] = stablehlo.broadcast_in_dim %arg1, dims = [] : (tensor) -> tensor<5xf32> + // CHECK-DAG: [[BROADCAST_MAX:%.+]] = stablehlo.broadcast_in_dim %arg2, dims = [] : (tensor) -> tensor<5xf32> + // CHECK-DAG: [[CLAMP:%.+]] = stablehlo.clamp [[BROADCAST_MIN]], %arg0, [[BROADCAST_MAX]] + %0 = "tf.ClipByValue"(%arg0, %arg1, %arg2) : (tensor<5xf32>, tensor, tensor) -> tensor<5xf32> + + // CHECK: return [[CLAMP]] + func.return %0 : tensor<5xf32> +} + + +// CHECK-LABEL: @clip_dynamic_broadcast +func.func @clip_dynamic_broadcast(%arg0 : tensor, %arg1 : tensor, %arg2 : tensor) -> tensor { + // CHECK: [[SHP:%.+]] = shape.shape_of %arg0 + // CHECK: [[SHPIDX:%.+]] = arith.index_cast [[SHP]] : tensor<1xindex> to tensor<1xi32> + // CHECK-DAG: [[BROADCAST_MIN:%.+]] = stablehlo.dynamic_broadcast_in_dim %arg1, [[SHPIDX]], dims = [] : (tensor, tensor<1xi32>) -> tensor + // CHECK-DAG: [[BROADCAST_MAX:%.+]] = stablehlo.dynamic_broadcast_in_dim %arg2, [[SHPIDX]], dims = [] : (tensor, tensor<1xi32>) -> tensor + // CHECK-DAG: [[CLAMP:%.+]] = stablehlo.clamp [[BROADCAST_MIN]], %arg0, [[BROADCAST_MAX]] + %0 = "tf.ClipByValue"(%arg0, %arg1, %arg2) : (tensor, tensor, tensor) -> tensor + + // CHECK: return [[CLAMP]] + func.return %0 : tensor +} + +//===----------------------------------------------------------------------===// +// DiagPart +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: func @diag_part +// CHECK-SAME: %[[ARG:.*]]: tensor<4x3x4x3xf32> +func.func @diag_part(%arg0: tensor<4x3x4x3xf32>) -> tensor<4x3xf32> { + // CHECK-NEXT: %[[CST0:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<12x12xf32> + // CHECK-NEXT: %[[CST1:.*]] = stablehlo.constant dense<0.000000e+00> : tensor + // CHECK-NEXT: %[[RESHAPE:.*]] = stablehlo.reshape %arg0 : (tensor<4x3x4x3xf32>) -> tensor<12x12xf32> + // CHECK-NEXT: %[[IOTA:.*]] = stablehlo.iota dim = 0 : tensor<12xi32> + // CHECK-NEXT: %[[BCAST:.*]] = stablehlo.broadcast_in_dim %[[IOTA]], dims = [0] : (tensor<12xi32>) -> tensor<12x12xi32> + // CHECK-NEXT: %[[IOTA2:.*]] = stablehlo.iota dim = 0 : tensor<12xi32> + // CHECK-NEXT: %[[BCAST2:.*]] = stablehlo.broadcast_in_dim %[[IOTA2]], dims = [1] : (tensor<12xi32>) -> tensor<12x12xi32> + // CHECK-NEXT: %[[CMP:.*]] = stablehlo.compare EQ, %[[BCAST]], %[[BCAST2]], NOTYPE : (tensor<12x12xi32>, tensor<12x12xi32>) -> tensor<12x12xi1> + // CHECK-NEXT: %[[SEL:.*]] = stablehlo.select %[[CMP]], %[[RESHAPE]], %[[CST0]] : tensor<12x12xi1>, tensor<12x12xf32> + // CHECK-NEXT: %[[REDUCE:.*]] = stablehlo.reduce(%[[SEL]] init: %[[CST1]]) applies stablehlo.add across dimensions = [0] : (tensor<12x12xf32>, tensor) -> tensor<12xf32> + // CHECK-NEXT: %[[RESHAPE2:.*]] = stablehlo.reshape %[[REDUCE]] : (tensor<12xf32>) -> tensor<4x3xf32> + // CHECK-NEXT: return %[[RESHAPE2]] : tensor<4x3xf32> + + %0 = "tf.DiagPart"(%arg0) : (tensor<4x3x4x3xf32>) -> tensor<4x3xf32> + func.return %0: tensor<4x3xf32> +} + +//===----------------------------------------------------------------------===// +// MatrixDiagPart +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: func @matrix_diag_part +// CHECK-SAME: %[[ARG:.*]]: tensor<7x140x128xi32> +func.func @matrix_diag_part(%arg0: tensor<7x140x128xi32>) -> tensor<7x22x128xi32> { + // CHECK-NEXT: %[[CST0:.*]] = stablehlo.constant dense<42> : tensor<7x22x128xi32> + // CHECK-NEXT: %[[CST1:.*]] = stablehlo.constant dense<128> : tensor<1x22x128xi32> + // CHECK-NEXT: %[[CST2:.*]] = stablehlo.constant dense<140> : tensor<1x22x128xi32> + // CHECK-NEXT: %[[CST3:.*]] = stablehlo.constant dense<11> : tensor<1x22x128xi32> + // CHECK-NEXT: %[[CST4:.*]] = stablehlo.constant dense<0> : tensor<1x22x128xi32> + // CHECK-NEXT: %[[IOTA0:.*]] = stablehlo.iota dim = 0 : tensor<22xi32> + // CHECK-NEXT: %[[BCAST0:.*]] = stablehlo.broadcast_in_dim %[[IOTA0]], dims = [1] : (tensor<22xi32>) -> tensor<1x22x128xi32> + // CHECK-NEXT: %[[IOTA1:.*]] = stablehlo.iota dim = 0 : tensor<128xi32> + // CHECK-NEXT: %[[BCAST1:.*]] = stablehlo.broadcast_in_dim %[[IOTA1]], dims = [2] : (tensor<128xi32>) -> tensor<1x22x128xi32> + // CHECK-NEXT: %[[SUB0:.*]] = stablehlo.subtract %[[CST3]], %[[BCAST0]] : tensor<1x22x128xi32> + // CHECK-NEXT: %[[NEG0:.*]] = stablehlo.negate %[[SUB0]] : tensor<1x22x128xi32> + // CHECK-NEXT: %[[MIN0:.*]] = stablehlo.minimum %[[SUB0]], %[[CST4]] : tensor<1x22x128xi32> + // CHECK-NEXT: %[[ADD0:.*]] = stablehlo.add %[[MIN0]], %[[CST2]] : tensor<1x22x128xi32> + // CHECK-NEXT: %[[MAX0:.*]] = stablehlo.maximum %[[SUB0]], %[[CST4]] : tensor<1x22x128xi32> + // CHECK-NEXT: %[[SUB1:.*]] = stablehlo.subtract %[[CST1]], %[[MAX0]] : tensor<1x22x128xi32> + // CHECK-NEXT: %[[MIN1:.*]] = stablehlo.minimum %[[ADD0]], %[[SUB1]] : tensor<1x22x128xi32> + // CHECK-NEXT: %[[CMP0:.*]] = stablehlo.compare GE, %[[SUB0]], %[[CST4]] : (tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi1> + // CHECK-NEXT: %[[SUB2:.*]] = stablehlo.subtract %[[CST1]], %[[MIN1]] : tensor<1x22x128xi32> + // CHECK-NEXT: %[[SELECT0:.*]] = stablehlo.select %[[CMP0]], %[[SUB2]], %[[CST4]] : tensor<1x22x128xi1>, tensor<1x22x128xi32> + // CHECK-NEXT: %[[MAX1:.*]] = stablehlo.maximum %[[SUB0]], %[[CST4]] : tensor<1x22x128xi32> + // CHECK-NEXT: %[[SUB2:.*]] = stablehlo.subtract %[[MAX1]], %[[SELECT0]] : tensor<1x22x128xi32> + // CHECK-NEXT: %[[MAX2:.*]] = stablehlo.maximum %[[NEG0]], %[[CST4]] : tensor<1x22x128xi32> + // CHECK-NEXT: %[[SUB3:.*]] = stablehlo.subtract %[[MAX2]], %[[SELECT0]] : tensor<1x22x128xi32> + // CHECK-NEXT: %[[ADD1:.*]] = stablehlo.add %[[BCAST1]], %[[SUB2]] : tensor<1x22x128xi32> + // CHECK-NEXT: %[[ADD2:.*]] = stablehlo.add %[[BCAST1]], %[[SUB3]] : tensor<1x22x128xi32> + // CHECK-NEXT: %[[CMP1:.*]] = stablehlo.compare GE, %[[ADD1]], %[[CST4]] : (tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi1> + // CHECK-NEXT: %[[CMP2:.*]] = stablehlo.compare LT, %[[ADD1]], %[[CST1]] : (tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi1> + // CHECK-NEXT: %[[AND0:.*]] = stablehlo.and %[[CMP1]], %[[CMP2]] : tensor<1x22x128xi1> + // CHECK-NEXT: %[[CMP3:.*]] = stablehlo.compare GE, %[[ADD2]], %[[CST4]] : (tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi1> + // CHECK-NEXT: %[[CMP4:.*]] = stablehlo.compare LT, %[[ADD2]], %[[CST2]] : (tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi1> + // CHECK-NEXT: %[[AND1:.*]] = stablehlo.and %[[CMP3]], %[[CMP4]] : tensor<1x22x128xi1> + // CHECK-NEXT: %[[AND2:.*]] = stablehlo.and %[[AND0]], %[[AND1]] : tensor<1x22x128xi1> + // CHECK-NEXT: %[[RESHAPE0:.*]] = stablehlo.reshape %[[AND2]] : (tensor<1x22x128xi1>) -> tensor<22x128xi1> + // CHECK-NEXT: %[[CONCAT0:.*]] = stablehlo.concatenate %[[ADD2]], %[[ADD1]], dim = 0 : (tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<2x22x128xi32> + // CHECK-NEXT: %[[GATHER0:.*]] = "stablehlo.gather"(%arg0, %[[CONCAT0]]) <{dimension_numbers = #{{.*}}, indices_are_sorted = false, slice_sizes = array}> : (tensor<7x140x128xi32>, tensor<2x22x128xi32>) -> tensor<7x22x128xi32> + // CHECK-NEXT: %[[BCAST1:.*]] = stablehlo.broadcast %[[RESHAPE0]], sizes = [7] : (tensor<22x128xi1>) -> tensor<7x22x128xi1> + // CHECK-NEXT: %[[SELECT1:.*]] = stablehlo.select %[[BCAST1]], %[[GATHER0]], %[[CST0]] : tensor<7x22x128xi1>, tensor<7x22x128xi32> + // CHECK-NEXT: return %[[SELECT1]] : tensor<7x22x128xi32> + + %0 = mhlo.constant dense<42> : tensor // padding value + %1 = mhlo.constant dense<[-10, 11]> : tensor<2xi32> // k + %2 = "tf.MatrixDiagPartV3"(%arg0, %1, %0) { + T = i32, align = "RIGHT_LEFT" + } : (tensor<7x140x128xi32>, tensor<2xi32>, tensor) -> tensor<7x22x128xi32> + func.return %2: tensor<7x22x128xi32> +} + +// ----- + +// CHECK-LABEL: func @matrix_diag_part_zero_dim_complex +func.func @matrix_diag_part_zero_dim_complex(%arg0: tensor<4x0xcomplex>) -> tensor<0xcomplex> { + %cst = "tf.Const"() {value = dense<-3> : tensor} : () -> tensor + %cst_0 = "tf.Const"() {value = dense<(0.000000e+00,0.000000e+00)> : tensor>} : () -> tensor> + %0 = "tf.MatrixDiagPartV3"(%arg0, %cst, %cst_0) {align = "RIGHT_LEFT", device = ""} : (tensor<4x0xcomplex>, tensor, tensor>) -> tensor<0xcomplex> + // CHECK: return %{{[0-9]*}} : tensor<0xcomplex> + return %0 : tensor<0xcomplex> +} + +// ----- + +// CHECK-LABEL: func @matrix_diag_part_single_diagonal +func.func @matrix_diag_part_single_diagonal(%arg0: tensor<7x140x128xi32>) -> tensor<7x128xi32> { + // CHECK-NEXT: %[[CST0:.*]] = stablehlo.constant dense<42> : tensor<7x1x128xi32> + // CHECK-NEXT: %[[CST1:.*]] = stablehlo.constant dense<128> : tensor<1x1x128xi32> + // CHECK-NEXT: %[[CST2:.*]] = stablehlo.constant dense<140> : tensor<1x1x128xi32> + // CHECK-NEXT: %[[FALSE:.*]] = stablehlo.constant dense<0> : tensor<1x1x128xi32> + // CHECK-NEXT: %[[IOTA0:.*]] = stablehlo.iota dim = 0 : tensor<128xi32> + // CHECK-NEXT: %[[RESHAPE0:.*]] = stablehlo.reshape %[[IOTA0]] : (tensor<128xi32>) -> tensor<1x1x128xi32> + // CHECK-NEXT: %[[CMP0:.*]] = stablehlo.compare GE, %[[RESHAPE0]], %[[FALSE]] : (tensor<1x1x128xi32>, tensor<1x1x128xi32>) -> tensor<1x1x128xi1> + // CHECK-NEXT: %[[CMP1:.*]] = stablehlo.compare LT, %[[RESHAPE0]], %[[CST1]] : (tensor<1x1x128xi32>, tensor<1x1x128xi32>) -> tensor<1x1x128xi1> + // CHECK-NEXT: %[[AND0:.*]] = stablehlo.and %[[CMP0]], %[[CMP1]] : tensor<1x1x128xi1> + // CHECK-NEXT: %[[CMP2:.*]] = stablehlo.compare GE, %[[RESHAPE0]], %[[FALSE]] : (tensor<1x1x128xi32>, tensor<1x1x128xi32>) -> tensor<1x1x128xi1> + // CHECK-NEXT: %[[CMP3:.*]] = stablehlo.compare LT, %[[RESHAPE0]], %[[CST2]] : (tensor<1x1x128xi32>, tensor<1x1x128xi32>) -> tensor<1x1x128xi1> + // CHECK-NEXT: %[[AND1:.*]] = stablehlo.and %[[CMP2]], %[[CMP3]] : tensor<1x1x128xi1> + // CHECK-NEXT: %[[AND2:.*]] = stablehlo.and %[[AND0]], %[[AND1]] : tensor<1x1x128xi1> + // CHECK-NEXT: %[[RESHAPE1:.*]] = stablehlo.reshape %[[AND2]] : (tensor<1x1x128xi1>) -> tensor<1x128xi1> + // CHECK-NEXT: %[[CONCAT:.*]] = stablehlo.concatenate %[[RESHAPE0]], %[[RESHAPE0]], dim = 0 : (tensor<1x1x128xi32>, tensor<1x1x128xi32>) -> tensor<2x1x128xi32> + // CHECK-NEXT: %[[GATHER:.*]] = "stablehlo.gather"(%arg0, %[[CONCAT]]) <{dimension_numbers = #{{.*}}, indices_are_sorted = false, slice_sizes = array}> : (tensor<7x140x128xi32>, tensor<2x1x128xi32>) -> tensor<7x1x128xi32> + // CHECK-NEXT: %[[BCAST:.*]] = stablehlo.broadcast %[[RESHAPE1]], sizes = [7] : (tensor<1x128xi1>) -> tensor<7x1x128xi1> + // CHECK-NEXT: %[[SELECT0:.*]] = stablehlo.select %[[BCAST]], %[[GATHER]], %[[CST0]] : tensor<7x1x128xi1>, tensor<7x1x128xi32> + // CHECK-NEXT: %[[RESHAPE2:.*]] = stablehlo.reshape %[[SELECT0]] : (tensor<7x1x128xi32>) -> tensor<7x128xi32> + // CHECK-NEXT: return %[[RESHAPE2]] : tensor<7x128xi32> + + %0 = mhlo.constant dense<42> : tensor // padding value + %1 = mhlo.constant dense<0> : tensor<2xi32> // k + %2 = "tf.MatrixDiagPartV3"(%arg0, %1, %0) { + T = i32, align = "RIGHT_LEFT" + } : (tensor<7x140x128xi32>, tensor<2xi32>, tensor) -> tensor<7x128xi32> + func.return %2: tensor<7x128xi32> +} + +// ----- + +// CHECK-LABEL: func @matrix_diag_part_align_ll +func.func @matrix_diag_part_align_ll(%arg0: tensor<7x140x128xi32>) -> tensor<7x22x128xi32> { + // CHECK-NEXT: %[[CST0:.*]] = stablehlo.constant dense<42> : tensor<7x22x128xi32> + // CHECK-NEXT: %[[CST1:.*]] = stablehlo.constant dense<128> : tensor<1x22x128xi32> + // CHECK-NEXT: %[[CST2:.*]] = stablehlo.constant dense<140> : tensor<1x22x128xi32> + // CHECK-NEXT: %[[CST3:.*]] = stablehlo.constant dense<11> : tensor<1x22x128xi32> + // CHECK-NEXT: %[[FALSE:.*]] = stablehlo.constant dense<0> : tensor<1x22x128xi32> + // CHECK-NEXT: %[[IOTA0:.*]] = stablehlo.iota dim = 0 : tensor<22xi32> + // CHECK-NEXT: %[[BCAST0:.*]] = stablehlo.broadcast_in_dim %[[IOTA0]], dims = [1] : (tensor<22xi32>) -> tensor<1x22x128xi32> + // CHECK-NEXT: %[[IOTA1:.*]] = stablehlo.iota dim = 0 : tensor<128xi32> + // CHECK-NEXT: %[[BCAST1:.*]] = stablehlo.broadcast_in_dim %[[IOTA1]], dims = [2] : (tensor<128xi32>) -> tensor<1x22x128xi32> + // CHECK-NEXT: %[[SUB0:.*]] = stablehlo.subtract %[[CST3]], %[[BCAST0]] : tensor<1x22x128xi32> + // CHECK-NEXT: %[[NEG0:.*]] = stablehlo.negate %[[SUB0]] : tensor<1x22x128xi32> + // CHECK-NEXT: %[[MAX0:.*]] = stablehlo.maximum %[[SUB0]], %[[FALSE]] : tensor<1x22x128xi32> + // CHECK-NEXT: %[[SUB1:.*]] = stablehlo.subtract %[[MAX0]], %[[FALSE]] : tensor<1x22x128xi32> + // CHECK-NEXT: %[[MAX1:.*]] = stablehlo.maximum %[[NEG0]], %[[FALSE]] : tensor<1x22x128xi32> + // CHECK-NEXT: %[[SUB2:.*]] = stablehlo.subtract %[[MAX1]], %[[FALSE]] : tensor<1x22x128xi32> + // CHECK-NEXT: %[[ADD0:.*]] = stablehlo.add %[[BCAST1]], %[[SUB1]] : tensor<1x22x128xi32> + // CHECK-NEXT: %[[ADD1:.*]] = stablehlo.add %[[BCAST1]], %[[SUB2]] : tensor<1x22x128xi32> + // CHECK-NEXT: %[[CMP0:.*]] = stablehlo.compare GE, %[[ADD0]], %[[FALSE]] : (tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi1> + // CHECK-NEXT: %[[CMP1:.*]] = stablehlo.compare LT, %[[ADD0]], %[[CST1]] : (tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi1> + // CHECK-NEXT: %[[AND0:.*]] = stablehlo.and %[[CMP0]], %[[CMP1]] : tensor<1x22x128xi1> + // CHECK-NEXT: %[[CMP2:.*]] = stablehlo.compare GE, %[[ADD1]], %[[FALSE]] : (tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi1> + // CHECK-NEXT: %[[CMP3:.*]] = stablehlo.compare LT, %[[ADD1]], %[[CST2]] : (tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi1> + // CHECK-NEXT: %[[AND1:.*]] = stablehlo.and %[[CMP2]], %[[CMP3]] : tensor<1x22x128xi1> + // CHECK-NEXT: %[[AND2:.*]] = stablehlo.and %[[AND0]], %[[AND1]] : tensor<1x22x128xi1> + // CHECK-NEXT: %[[RESHAPE0:.*]] = stablehlo.reshape %[[AND2]] : (tensor<1x22x128xi1>) -> tensor<22x128xi1> + // CHECK-NEXT: %[[CONCAT0:.*]] = stablehlo.concatenate %[[ADD1]], %[[ADD0]], dim = 0 : (tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<2x22x128xi32> + // CHECK-NEXT: %[[GATHER0:.*]] = "stablehlo.gather"(%arg0, %[[CONCAT0]]) <{dimension_numbers = #{{.*}}, indices_are_sorted = false, slice_sizes = array}> : (tensor<7x140x128xi32>, tensor<2x22x128xi32>) -> tensor<7x22x128xi32> + // CHECK-NEXT: %[[BCAST2:.*]] = stablehlo.broadcast %[[RESHAPE0]], sizes = [7] : (tensor<22x128xi1>) -> tensor<7x22x128xi1> + // CHECK-NEXT: %[[SELECT0:.*]] = stablehlo.select %[[BCAST2]], %[[GATHER0]], %[[CST0]] : tensor<7x22x128xi1>, tensor<7x22x128xi32> + // CHECK-NEXT: return %[[SELECT0]] : tensor<7x22x128xi32> + + %0 = mhlo.constant dense<42> : tensor // padding value + %1 = mhlo.constant dense<[-10, 11]> : tensor<2xi32> // k + %2 = "tf.MatrixDiagPartV3"(%arg0, %1, %0) { + T = i32, align = "LEFT_LEFT" + } : (tensor<7x140x128xi32>, tensor<2xi32>, tensor) -> tensor<7x22x128xi32> + func.return %2: tensor<7x22x128xi32> +} + +// ----- + +// CHECK-LABEL: func @matrix_diag_part_align_lr +func.func @matrix_diag_part_align_lr(%arg0: tensor<7x140x128xi32>) -> tensor<7x22x128xi32> { + %0 = mhlo.constant dense<42> : tensor // padding value + %1 = mhlo.constant dense<[-10, 11]> : tensor<2xi32> // k + %2 = "tf.MatrixDiagPartV3"(%arg0, %1, %0) { + T = i32, align = "LEFT_RIGHT" + } : (tensor<7x140x128xi32>, tensor<2xi32>, tensor) -> tensor<7x22x128xi32> + // CHECK: %[[LE:.*]] = stablehlo.compare LE, %{{.*}}, %{{.*}} : (tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi1> + // CHECK: %{{.*}} = stablehlo.select %[[LE]], %{{.*}}, %{{.*}} : tensor<1x22x128xi1>, tensor<1x22x128xi32> + func.return %2: tensor<7x22x128xi32> +} + +// ----- + +// CHECK-LABEL: func @matrix_diag_part_align_rl +func.func @matrix_diag_part_align_rl(%arg0: tensor<7x140x128xi32>) -> tensor<7x22x128xi32> { + %0 = mhlo.constant dense<42> : tensor // padding value + %1 = mhlo.constant dense<[-10, 11]> : tensor<2xi32> // k + %2 = "tf.MatrixDiagPartV3"(%arg0, %1, %0) { + T = i32, align = "RIGHT_LEFT" + } : (tensor<7x140x128xi32>, tensor<2xi32>, tensor) -> tensor<7x22x128xi32> + // CHECK: %[[GE:.*]] = stablehlo.compare GE, %{{.*}}, %{{.*}} : (tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi1> + // CHECK: %{{.*}} = stablehlo.select %[[GE]], %{{.*}}, %{{.*}} : tensor<1x22x128xi1>, tensor<1x22x128xi32> + + func.return %2: tensor<7x22x128xi32> +} + +// ----- + +// CHECK-LABEL: func @matrix_diag_part_align_rr +func.func @matrix_diag_part_align_rr(%arg0: tensor<7x140x128xi32>) -> tensor<7x22x128xi32> { + %0 = mhlo.constant dense<42> : tensor // padding value + %1 = mhlo.constant dense<[-10, 11]> : tensor<2xi32> // k + %2 = "tf.MatrixDiagPartV3"(%arg0, %1, %0) { + T = i32, align = "RIGHT_RIGHT" + } : (tensor<7x140x128xi32>, tensor<2xi32>, tensor) -> tensor<7x22x128xi32> + // CHECK-NOT: MatrixDiagPartV3 + func.return %2: tensor<7x22x128xi32> +} + +// ----- + +// CHECK-LABEL: func @matrix_diag_part_align_7d +// CHECK: (%arg0: tensor<3x5x7x9x11x13x17xf32>) -> tensor<3x5x7x9x11x4x10xf32> +func.func @matrix_diag_part_align_7d(%arg0: tensor<3x5x7x9x11x13x17xf32>) -> tensor<3x5x7x9x11x4x10xf32> { + %0 = mhlo.constant dense<-1.> : tensor // padding value + %1 = mhlo.constant dense<[-6, -3]> : tensor<2xi32> // k + %2 = "tf.MatrixDiagPartV3"(%arg0, %1, %0) { + T = f32, align = "LEFT_RIGHT" + } : (tensor<3x5x7x9x11x13x17xf32>, tensor<2xi32>, tensor) -> tensor<3x5x7x9x11x4x10xf32> + func.return %2: tensor<3x5x7x9x11x4x10xf32> +} + +//===----------------------------------------------------------------------===// +// Erf +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: func @erf +func.func @erf(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> { + // CHECK: mhlo.erf(%arg0) {{.*}} : (tensor<2x3xf32>) -> tensor<2x3xf32> + %0 = "tf.Erf"(%arg0) : (tensor<2x3xf32>) -> tensor<2x3xf32> + func.return %0 : tensor<2x3xf32> +} + +//===----------------------------------------------------------------------===// +// Erfc +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: func @erfc +func.func @erfc(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> { + // CHECK-NOT: tf.Erfc + %0 = "tf.Erfc"(%arg0) : (tensor<2x3xf32>) -> tensor<2x3xf32> + func.return %0 : tensor<2x3xf32> +} + +//===----------------------------------------------------------------------===// +// Einsum. +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: func @einsum +func.func @einsum(%arg0: tensor<2x3xf32>, %arg1: tensor<3x4xf32>) -> tensor<2x4xf32> { + // CHECK: stablehlo.einsum + %0 = "tf.Einsum"(%arg0, %arg1) {equation = "ab,bc->ac"} : (tensor<2x3xf32>, tensor<3x4xf32>) -> tensor<2x4xf32> + func.return %0: tensor<2x4xf32> +} + +// ----- + +// CHECK-LABEL: func @unary_einsum +func.func @unary_einsum(%arg0: tensor<2x3xf32>) -> tensor<2x2xf32> { + // CHECK: stablehlo.constant{{.*}}1.000000e+00 + // CHECK: stablehlo.einsum{{.*}}",ab->aa" + %0 = "tf.Einsum"(%arg0) {equation = "ab->aa"} : (tensor<2x3xf32>) -> tensor<2x2xf32> + func.return %0: tensor<2x2xf32> +} + +//===----------------------------------------------------------------------===// +// FloorDiv and FloorMod. +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: func @floordiv_broadcast_i32 +func.func @floordiv_broadcast_i32(%arg0: tensor<2x3xi32>, %arg1: tensor<3xi32>) -> tensor<2x3xi32> { + // CHECK-NEXT: %[[ONES:.*]] = stablehlo.constant dense<1> : tensor<2x3xi32> + // CHECK-NEXT: %[[ZEROS0:.*]] = stablehlo.constant dense<0> : tensor<3xi32> + // CHECK-NEXT: %[[ZEROS1:.*]] = stablehlo.constant dense<0> : tensor<2x3xi32> + // CHECK-NEXT: %[[BCAST0:.*]] = stablehlo.broadcast_in_dim %arg1, dims = [1] : (tensor<3xi32>) -> tensor<2x3xi32> + // CHECK-NEXT: %[[DIV0:.*]] = stablehlo.divide %arg0, %[[BCAST0]] : tensor<2x3xi32> + // CHECK-NEXT: %[[BCAST1:.*]] = stablehlo.broadcast_in_dim %arg1, dims = [1] : (tensor<3xi32>) -> tensor<2x3xi32> + // CHECK-NEXT: %[[MUL0:.*]] = stablehlo.multiply %[[DIV0]], %[[BCAST1]] : tensor<2x3xi32> + // CHECK-NEXT: %[[CMP0:.*]] = stablehlo.compare NE, %[[MUL0]], %arg0 : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi1> + // CHECK-NEXT: %[[CMP1:.*]] = stablehlo.compare LT, %arg0, %[[ZEROS1]] : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi1> + // CHECK-NEXT: %[[CMP2:.*]] = stablehlo.compare LT, %arg1, %[[ZEROS0]] : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1> + // CHECK-NEXT: %[[BCAST2:.*]] = stablehlo.broadcast_in_dim %[[CMP2]], dims = [1] : (tensor<3xi1>) -> tensor<2x3xi1> + // CHECK-NEXT: %[[CMP3:.*]] = stablehlo.compare NE, %[[CMP1]], %[[BCAST2]] : (tensor<2x3xi1>, tensor<2x3xi1>) -> tensor<2x3xi1> + // CHECK-NEXT: %[[AND0:.*]] = stablehlo.and %[[CMP0]], %[[CMP3]] : tensor<2x3xi1> + // CHECK-NEXT: %[[SUB0:.*]] = stablehlo.subtract %[[DIV0]], %[[ONES]] : tensor<2x3xi32> + // CHECK-NEXT: %[[SELECT0:.*]] = stablehlo.select %[[AND0]], %[[SUB0]], %[[DIV0]] : tensor<2x3xi1>, tensor<2x3xi32> + // CHECK-NEXT: return %[[SELECT0]] : tensor<2x3xi32> + + %0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32> + func.return %0: tensor<2x3xi32> +} + +// ----- + +// CHECK-LABEL: func @floordiv_reverse_broadcast_i32 +func.func @floordiv_reverse_broadcast_i32(%arg0: tensor<3xi32>, %arg1: tensor<2x3xi32>) -> tensor<2x3xi32> { + // CHECK-NEXT: %[[ONES:.*]] = stablehlo.constant dense<1> : tensor<2x3xi32> + // CHECK-NEXT: %[[ZEROS0:.*]] = stablehlo.constant dense<0> : tensor<2x3xi32> + // CHECK-NEXT: %[[ZEROS1:.*]] = stablehlo.constant dense<0> : tensor<3xi32> + // CHECK-NEXT: %[[BCAST0:.*]] = stablehlo.broadcast_in_dim %arg0, dims = [1] : (tensor<3xi32>) -> tensor<2x3xi32> + // CHECK-NEXT: %[[DIV0:.*]] = stablehlo.divide %[[BCAST0]], %arg1 : tensor<2x3xi32> + // CHECK-NEXT: %[[MUL0:.*]] = stablehlo.multiply %[[DIV0]], %arg1 : tensor<2x3xi32> + // CHECK-NEXT: %[[BCAST1:.*]] = stablehlo.broadcast_in_dim %arg0, dims = [1] : (tensor<3xi32>) -> tensor<2x3xi32> + // CHECK-NEXT: %[[CMP0:.*]] = stablehlo.compare NE, %[[MUL0]], %[[BCAST1]] : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi1> + // CHECK-NEXT: %[[CMP1:.*]] = stablehlo.compare LT, %arg0, %[[ZEROS1]] : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1> + // CHECK-NEXT: %[[CMP2:.*]] = stablehlo.compare LT, %arg1, %[[ZEROS0]] : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi1> + // CHECK-NEXT: %[[BCAST2:.*]] = stablehlo.broadcast_in_dim %[[CMP1]], dims = [1] : (tensor<3xi1>) -> tensor<2x3xi1> + // CHECK-NEXT: %[[CMP3:.*]] = stablehlo.compare NE, %[[BCAST2]], %[[CMP2]] : (tensor<2x3xi1>, tensor<2x3xi1>) -> tensor<2x3xi1> + // CHECK-NEXT: %[[AND0:.*]] = stablehlo.and %[[CMP0]], %[[CMP3]] : tensor<2x3xi1> + // CHECK-NEXT: %[[SUB0:.*]] = stablehlo.subtract %[[DIV0]], %[[ONES]] : tensor<2x3xi32> + // CHECK-NEXT: %[[SELECT0:.*]] = stablehlo.select %[[AND0]], %[[SUB0]], %[[DIV0]] : tensor<2x3xi1>, tensor<2x3xi32> + // CHECK-NEXT: return %[[SELECT0]] : tensor<2x3xi32> + + %0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> + func.return %0: tensor<2x3xi32> +} + +// ----- + +// CHECK-LABEL: func @floordiv_f32 +func.func @floordiv_f32(%arg0: tensor<2xf32>) -> tensor<2xf32> { + // CHECK-NEXT: %[[DIV:.*]] = stablehlo.divide %arg0, %arg0 + // CHECK-NEXT: %[[FLOOR:.*]] = stablehlo.floor %[[DIV]] + // CHECK-NEXT: return %[[FLOOR]] : tensor<2xf32> + %0 = "tf.FloorDiv"(%arg0, %arg0) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> + func.return %0: tensor<2xf32> +} + +// ----- + +// CHECK-LABEL: func @floordiv_bf16 +func.func @floordiv_bf16(%arg0: tensor<2xbf16>) -> tensor<2xbf16> { + // CHECK-NEXT: stablehlo.convert + // CHECK-NEXT: stablehlo.convert + // CHECK-NEXT: stablehlo.divide + // CHECK-NEXT: stablehlo.floor + // CHECK-NEXT: stablehlo.convert + // CHECK-NEXT: return + %0 = "tf.FloorDiv"(%arg0, %arg0) : (tensor<2xbf16>, tensor<2xbf16>) -> tensor<2xbf16> + func.return %0: tensor<2xbf16> +} + +// ----- + +// CHECK-LABEL: func @floordiv_f16_broadcast +func.func @floordiv_f16_broadcast(%arg0: tensor<2x3xf16>, %arg1: tensor<3xf16>) -> tensor<2x3xf16> { + // CHECK-NEXT: stablehlo.broadcast_in_dim + // CHECK-NEXT: stablehlo.divide + // CHECK-NEXT: stablehlo.floor + // CHECK-NEXT: return + %0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16> + func.return %0: tensor<2x3xf16> +} + +// ----- + +// CHECK-LABEL: func @floordiv_dynamic +func.func @floordiv_dynamic(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: shape.assuming + // CHECK: stablehlo.dynamic_broadcast_in_dim + // CHECK: stablehlo.divide + // CHECK: shape.assuming + // CHECK: stablehlo.dynamic_broadcast_in_dim + // CHECK: stablehlo.multiply + // CHECK: shape.assuming + // CHECK: stablehlo.dynamic_broadcast_in_dim + // CHECK: stablehlo.compare + // CHECK: shape.assuming + // CHECK: stablehlo.dynamic_broadcast_in_dim + // CHECK: stablehlo.and + // + // CHECK: %[[SELECT:.*]] = stablehlo.select + // CHECK: return %[[SELECT]] + %0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0: tensor +} + +// ----- + +// CHECK-LABEL: func @floordiv_unsigned +func.func @floordiv_unsigned(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: %[[RESULT:.*]] = shape.assuming + // CHECK: %[[BCAST0:.*]] = stablehlo.dynamic_broadcast_in_dim %arg0, + // CHECK: %[[BCAST1:.*]] = stablehlo.dynamic_broadcast_in_dim %arg1, + // CHECK: %[[DIV:.*]] = stablehlo.divide %[[BCAST0]], %[[BCAST1]] + // CHECK: shape.assuming_yield %[[DIV]] + // CHECK: return %[[RESULT]] + %0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0: tensor +} + +// ----- + +// CHECK-LABEL: func @floordiv_int +func.func @floordiv_int(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: shape.assuming + // CHECK: stablehlo.dynamic_broadcast_in_dim + // CHECK: stablehlo.divide + // CHECK: shape.assuming + // CHECK: stablehlo.dynamic_broadcast_in_dim + // CHECK: stablehlo.multiply + // CHECK: shape.assuming + // CHECK: stablehlo.dynamic_broadcast_in_dim + // CHECK: stablehlo.compare + // CHECK: shape.assuming + // CHECK: stablehlo.dynamic_broadcast_in_dim + // CHECK: stablehlo.compare + // CHECK: shape.assuming + // CHECK: stablehlo.dynamic_broadcast_in_dim + // CHECK: stablehlo.and + // + // CHECK: %[[SELECT:.*]] = stablehlo.select + // CHECK: return %[[SELECT]] + %0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0: tensor +} + +// ----- + +// CHECK-LABEL: func @floormod_broadcast_numerator +func.func @floormod_broadcast_numerator(%arg0: tensor<3xi32>, %arg1: tensor<2x3xi32>) -> tensor<2x3xi32> { + // CHECK: %[[BCAST0:.*]] = stablehlo.broadcast_in_dim %arg0, dims = [1] : (tensor<3xi32>) -> tensor<2x3xi32> + // CHECK: %[[REM:.*]] = stablehlo.remainder %[[BCAST0]], %arg1 : tensor<2x3xi32> + // CHECK: %[[AND:.*]] = stablehlo.and + // CHECK: %[[ADD:.*]] = stablehlo.add + // CHECK: %[[SELECT:.*]] = stablehlo.select %[[AND]], %[[ADD]], %[[REM]] + // CHECK-NEXT: return %[[SELECT]] + %0 = "tf.FloorMod"(%arg0, %arg1) : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> + func.return %0: tensor<2x3xi32> +} + +// ----- + +// CHECK-LABEL: func @floormod_broadcast_denominator +func.func @floormod_broadcast_denominator(%arg0: tensor<2x3xi32>, %arg1: tensor<3xi32>) -> tensor<2x3xi32> { + // CHECK: %[[BCAST0:.*]] = stablehlo.broadcast_in_dim %arg1, dims = [1] : (tensor<3xi32>) -> tensor<2x3xi32> + // CHECK: %[[REM:.*]] = stablehlo.remainder %arg0, %[[BCAST0]] + // CHECK: %[[AND:.*]] = stablehlo.and + // CHECK: %[[BCAST1:.*]] = stablehlo.broadcast_in_dim %arg1, dims = [1] : (tensor<3xi32>) -> tensor<2x3xi32> + // CHECK: %[[ADD:.*]] = stablehlo.add %[[BCAST1]], %[[REM]] + // CHECK: %[[SELECT:.*]] = stablehlo.select %[[AND]], %[[ADD]], %[[REM]] + // CHECK-NEXT: return %[[SELECT]] + %0 = "tf.FloorMod"(%arg0, %arg1) : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32> + func.return %0: tensor<2x3xi32> +} + +// ----- + +// CHECK-LABEL: func @floormod_unsigned_broadcast_denominator +func.func @floormod_unsigned_broadcast_denominator(%arg0: tensor<2x3xui32>, %arg1: tensor<3xui32>) -> tensor<2x3xui32> { + // CHECK-NEXT: %[[BCAST0:.*]] = stablehlo.broadcast_in_dim %arg1, dims = [1] : (tensor<3xui32>) -> tensor<2x3xui32> + // CHECK-NEXT: %[[REM:.*]] = stablehlo.remainder %arg0, %[[BCAST0]] : tensor<2x3xui32> + // CHECK-NEXT: return %[[REM]] : tensor<2x3xui32> + %0 = "tf.FloorMod"(%arg0, %arg1) : (tensor<2x3xui32>, tensor<3xui32>) -> tensor<2x3xui32> + func.return %0: tensor<2x3xui32> +} + +// ----- + +// CHECK-LABEL: func @floormod_dynamic_broadcast_numerator +func.func @floormod_dynamic_broadcast_numerator_(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: %[[REM:.*]] = shape.assuming {{.*}} { + // CHECK: stablehlo.remainder + // CHECK: shape.assuming {{.*}} { + // CHECK: stablehlo.compare + // CHECK: %[[AND:.*]] = shape.assuming {{.*}} { + // CHECK: stablehlo.and + // CHECK: %[[ADD:.*]] = shape.assuming {{.*}} { + // CHECK: stablehlo.add + // CHECK: %[[SELECT:.*]] = stablehlo.select %[[AND]], %[[ADD]], %[[REM]] + // CHECK-NEXT: return %[[SELECT]] + %0 = "tf.FloorMod"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0: tensor +} + +// ----- + +// CHECK-LABEL: func @floormod_dynamic_broadcast_denominator +func.func @floormod_dynamic_broadcast_denominator_(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK-NOT: tf.FloorMod + // CHECK: %[[REM:.*]] = shape.assuming {{.*}} { + // CHECK: stablehlo.remainder + // CHECK: shape.assuming {{.*}} { + // CHECK: stablehlo.compare + // CHECK: %[[AND:.*]] = shape.assuming {{.*}} { + // CHECK: stablehlo.and + // CHECK: %[[ADD:.*]] = shape.assuming {{.*}} { + // CHECK: stablehlo.add + // CHECK: %[[SELECT:.*]] = stablehlo.select %[[AND]], %[[ADD]], %[[REM]] + // CHECK-NEXT: return %[[SELECT]] + %0 = "tf.FloorMod"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0: tensor +} + +//===----------------------------------------------------------------------===// +// OnesLike +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: @ones_like +// CHECK-SAME: (%[[ARG:.*]]: tensor<2x?xf32>) +func.func @ones_like(%arg0: tensor<2x?xf32>) -> tensor<2x?xf32> { + // CHECK-NEXT: %[[ONES:.*]] = stablehlo.constant dense<1.000000e+00> : tensor + // CHECK-NEXT: %[[SHAPE:.*]] = shape.shape_of %arg0 : tensor<2x?xf32> -> tensor<2xindex> + // CHECK-NEXT: %[[RESULT:.*]] = stablehlo.dynamic_broadcast_in_dim %[[ONES]], %[[SHAPE]], dims = [] : (tensor, tensor<2xindex>) -> tensor<2x?xf32> + // CHECK-NEXT: return %[[RESULT]] : tensor<2x?xf32> + %0 = "tf.OnesLike"(%arg0) : (tensor<2x?xf32>) -> tensor<2x?xf32> + func.return %0 : tensor<2x?xf32> +} + +//===----------------------------------------------------------------------===// +// ZerosLike +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: @zeros_like +// CHECK-SAME: (%[[ARG:.*]]: tensor<2x?xf32>) +func.func @zeros_like(%arg0: tensor<2x?xf32>) -> tensor<2x?xf32> { + // CHECK-NEXT: %[[ZEROS:.*]] = stablehlo.constant dense<0.000000e+00> : tensor + // CHECK-NEXT: %[[SHAPE:.*]] = shape.shape_of %arg0 : tensor<2x?xf32> -> tensor<2xindex> + // CHECK-NEXT: %[[RESULT:.*]] = stablehlo.dynamic_broadcast_in_dim %[[ZEROS]], %[[SHAPE]], dims = [] : (tensor, tensor<2xindex>) -> tensor<2x?xf32> + // CHECK-NEXT: return %[[RESULT]] : tensor<2x?xf32> + %0 = "tf.ZerosLike"(%arg0) : (tensor<2x?xf32>) -> tensor<2x?xf32> + func.return %0 : tensor<2x?xf32> +} + +//===----------------------------------------------------------------------===// +// BroadcastTo. +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: func @broadcast_to +func.func @broadcast_to(%arg0: tensor<16xf32>) -> tensor<16x16x16x16xf32> { + %cst = "tf.Const"() { value = dense<16> : tensor<4xi32> } : () -> tensor<4xi32> + // CHECK: stablehlo.broadcast_in_dim %arg0 + %0 = "tf.BroadcastTo"(%arg0, %cst) : (tensor<16xf32>, tensor<4xi32>) -> tensor<16x16x16x16xf32> + func.return %0 : tensor<16x16x16x16xf32> +} + +//===----------------------------------------------------------------------===// +// Complex op legalizations. +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: func @complex +func.func @complex(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xcomplex> { + // CHECK: stablehlo.complex + %1 = "tf.Complex"(%arg0, %arg1) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xcomplex> + func.return %1 : tensor<3xcomplex> +} + +// ----- + +// CHECK-LABEL: func @imag +func.func @imag(%arg0: tensor<3xcomplex>) -> tensor<3xf32> { + // CHECK: stablehlo.imag + %1 = "tf.Imag"(%arg0) : (tensor<3xcomplex>) -> tensor<3xf32> + func.return %1 : tensor<3xf32> +} + +// ----- + +// CHECK-LABEL: func @real +func.func @real(%arg0: tensor<3xcomplex>) -> tensor<3xf32> { + // CHECK: stablehlo.real + %1 = "tf.Real"(%arg0) : (tensor<3xcomplex>) -> tensor<3xf32> + func.return %1 : tensor<3xf32> +} + +//===----------------------------------------------------------------------===// +// Concat op legalizations. +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: func @concat_v2 +func.func @concat_v2(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>) -> tensor<6x3xf32> { + // CHECK: stablehlo.concatenate %arg0, %arg1, dim = 0 + %axis = "tf.Const"() { value = dense<0> : tensor } : () -> tensor + %1 = "tf.ConcatV2"(%arg0, %arg1, %axis) : (tensor<3x3xf32>, tensor<3x3xf32>, tensor) -> tensor<6x3xf32> + func.return %1 : tensor<6x3xf32> +} + +// ----- + +// CHECK-LABEL: func @concat_v2_neg_axis +func.func @concat_v2_neg_axis(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>) -> tensor<6x3xf32> { + // CHECK: stablehlo.concatenate %arg0, %arg1, dim = 0 + + %axis = "tf.Const"() { value = dense<-2> : tensor } : () -> tensor + %1 = "tf.ConcatV2"(%arg0, %arg1, %axis) : (tensor<3x3xf32>, tensor<3x3xf32>, tensor) -> tensor<6x3xf32> + func.return %1 : tensor<6x3xf32> +} + +// ----- + +// CHECK-LABEL: func @concat_v2_1d_axis +func.func @concat_v2_1d_axis(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>) -> tensor<3x6xf32> { + // CHECK: stablehlo.concatenate %arg0, %arg1, dim = 1 + + %axis = "tf.Const"() { value = dense<[1]> : tensor<1xi64> } : () -> tensor<1xi64> + %1 = "tf.ConcatV2"(%arg0, %arg1, %axis) : (tensor<3x3xf32>, tensor<3x3xf32>, tensor<1xi64>) -> tensor<3x6xf32> + func.return %1 : tensor<3x6xf32> +} + +// ----- + +// CHECK-LABEL: func @concat_v2_non_const_axis +module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 12 : i32}} { +func.func @concat_v2_non_const_axis(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>, %axis: tensor) -> tensor<3x6xf32> { + // CHECK: "tf.ConcatV2" + %1 = "tf.ConcatV2"(%arg0, %arg1, %axis) : (tensor<3x3xf32>, tensor<3x3xf32>, tensor) -> tensor<3x6xf32> + func.return %1 : tensor<3x6xf32> +} +} + +//===----------------------------------------------------------------------===// +// Pad op legalizations. +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: func @padv2_1D +func.func @padv2_1D(%arg0: tensor<3xf32>, %arg1: tensor) -> tensor<6xf32> { + %padding = "tf.Const"() { value = dense<[[1, 2]]> : tensor<1x2xi64> } : () -> tensor<1x2xi64> + // CHECK: stablehlo.pad %arg0, %arg1, low = [1], high = [2], interior = [0] + %1 = "tf.PadV2"(%arg0, %padding, %arg1) : (tensor<3xf32>, tensor<1x2xi64>, tensor) -> tensor<6xf32> + func.return %1 : tensor<6xf32> +} + +// ----- + +// CHECK-LABEL: func @padv2_2D +func.func @padv2_2D(%arg0: tensor<3x2xf32>, %arg1: tensor) -> tensor<6x9xf32> { + %padding = "tf.Const"() { value = dense<[[1,2],[3,4]]> : tensor<2x2xi64> } : () -> tensor<2x2xi64> + // CHECK: stablehlo.pad %arg0, %arg1, low = [1, 3], high = [2, 4], interior = [0, 0] + %1 = "tf.PadV2"(%arg0, %padding, %arg1) : (tensor<3x2xf32>, tensor<2x2xi64>, tensor) -> tensor<6x9xf32> + func.return %1 : tensor<6x9xf32> +} + +// ----- + +// CHECK-LABEL: func @padv2_i32_paddings +func.func @padv2_i32_paddings(%arg0: tensor<3x2xf32>, %arg1: tensor) -> tensor<6x9xf32> { + %padding = "tf.Const"() { value = dense<[[1,2],[3,4]]> : tensor<2x2xi32> } : () -> tensor<2x2xi32> + // CHECK: stablehlo.pad %arg0, %arg1, low = [1, 3], high = [2, 4], interior = [0, 0] + %1 = "tf.PadV2"(%arg0, %padding, %arg1) : (tensor<3x2xf32>, tensor<2x2xi32>, tensor) -> tensor<6x9xf32> + func.return %1 : tensor<6x9xf32> +} + +// ----- + +// CHECK-LABEL: func @padv2_dynamic +func.func @padv2_dynamic(%arg0: tensor, %arg1: tensor, %arg2: tensor<1x2xi64>) -> tensor { + // CHECK-NEXT: %[[ZEROS:.*]] = stablehlo.constant dense<0> : tensor<1xi64> + // CHECK-NEXT: %[[RESHAPE:.*]] = stablehlo.reshape %arg2 : (tensor<1x2xi64>) -> tensor<2xi64> + // CHECK-NEXT: %[[SLICE0:.*]] = stablehlo.slice %[[RESHAPE]] [0:1] : (tensor<2xi64>) -> tensor<1xi64> + // CHECK-NEXT: %[[SLICE1:.*]] = stablehlo.slice %[[RESHAPE]] [1:2] : (tensor<2xi64>) -> tensor<1xi64> + // CHECK-NEXT: %[[RESULT:.*]] = stablehlo.dynamic_pad %arg0, %arg1, %[[SLICE0]], %[[SLICE1]], %[[ZEROS]] : (tensor, tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor + // CHECK-NEXT: return %[[RESULT]] : tensor + + %1 = "tf.PadV2"(%arg0, %arg2, %arg1) : (tensor, tensor<1x2xi64>, tensor) -> tensor + func.return %1 : tensor +} + +//===----------------------------------------------------------------------===// +// Identity op legalizations. +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: func @identity +func.func @identity(%arg0: tensor<1xi32>) -> tensor<1xi32> { + // CHECK-NEXT: return %arg0 : tensor<1xi32> + %0 = "tf.Identity"(%arg0) : (tensor<1xi32>) -> tensor<1xi32> + func.return %0: tensor<1xi32> +} + +// ----- + +// CHECK-LABEL: func @identityN +func.func @identityN(%arg0: tensor<1xi32>, %arg1: tensor<1xf32>) -> (tensor<1xi32>, tensor<1xf32>) { + // CHECK-NEXT: return %arg0, %arg1 : tensor<1xi32>, tensor<1xf32> + %0:2 = "tf.IdentityN"(%arg0, %arg1) : (tensor<1xi32>, tensor<1xf32>) -> (tensor<1xi32>, tensor<1xf32>) + func.return %0#0, %0#1: tensor<1xi32>, tensor<1xf32> +} + +// ----- + +// CHECK-LABEL: func @stopgradient +func.func @stopgradient(%arg0: tensor<1xi32>) -> tensor<1xi32> { + // CHECK-NEXT: return %arg0 : tensor<1xi32> + %0 = "tf.StopGradient"(%arg0) : (tensor<1xi32>) -> tensor<1xi32> + func.return %0: tensor<1xi32> +} + +// ----- + +// CHECK-LABEL: func @preventgradient +func.func @preventgradient(%arg0: tensor<1xi32>) -> tensor<1xi32> { + // CHECK-NEXT: return %arg0 : tensor<1xi32> + %0 = "tf.PreventGradient"(%arg0) {message = "fin gradients"} : (tensor<1xi32>) -> tensor<1xi32> + func.return %0: tensor<1xi32> +} + +// ----- + +// CHECK-LABEL: func @checkNumerics +func.func @checkNumerics(%arg0: tensor<1xf32>) -> tensor<1xf32> { + // CHECK-NEXT: return %arg0 : tensor<1xf32> + %0 = "tf.CheckNumerics"(%arg0) {message = "check numerics"} : (tensor<1xf32>) -> tensor<1xf32> + func.return %0: tensor<1xf32> +} + +//===----------------------------------------------------------------------===// +// InfeedDequeueTuple legalization +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: func @infeed_dequeue_tuple +func.func @infeed_dequeue_tuple() -> (tensor<1x8x4x4xi32>, tensor<1x100x1xf32>) { + // CHECK: [[TOKEN:%.*]] = stablehlo.create_token : !stablehlo.token + // CHECK: [[INFEED:%.*]]:3 = "stablehlo.infeed"([[TOKEN]]) <{infeed_config = ""{{.*}}}> : (!stablehlo.token) -> (tensor<1x8x4x4xi32>, tensor<1x100x1xf32>, !stablehlo.token) + // CHECK: return [[INFEED]]#0, [[INFEED]]#1 + %0:2 = "tf.InfeedDequeueTuple"() : () -> (tensor<1x8x4x4xi32>, tensor<1x100x1xf32>) + func.return %0#0, %0#1 : tensor<1x8x4x4xi32>, tensor<1x100x1xf32> +} + +// ----- + +// CHECK-LABEL: func @infeed_dequeue_tuple_dynamic_error +func.func @infeed_dequeue_tuple_dynamic_error() -> (tensor<3x3xf32>, tensor<4x?xf32>) { + // We expect legalization to fail for dynamic shapes: + // CHECK: [[INFEED:%.*]] = "tf.InfeedDequeueTuple"{{.*}} + %0:2 = "tf.InfeedDequeueTuple"() : () -> (tensor<3x3xf32>, tensor<4x?xf32>) + func.return %0#0, %0#1 : tensor<3x3xf32>, tensor<4x?xf32> +} + +// The following op sharding is used: +// Proto debug string: +// type: TUPLE +// tuple_shardings { +// type: MAXIMAL +// tile_assignment_dimensions: 1 +// tile_assignment_devices: 0 +// } +// Serialized string: +// "\08\02*\08\08\01\1A\01\01\22\01\00" + +// CHECK-LABEL: infeed_dequeue_tuple_sharding +func.func @infeed_dequeue_tuple_sharding() -> tensor<8xi32> { + // CHECK: "stablehlo.infeed" + // An additional sharding is added at the end to account for token result. + // Proto debug string: + // type: TUPLE + // tuple_shardings { + // type: MAXIMAL + // tile_assignment_dimensions: 1 + // tile_assignment_devices: 0 + // } + // tuple_shardings { + // type: MAXIMAL + // tile_assignment_dimensions: 1 + // tile_assignment_devices: 0 + // } + // CHECK-SAME: mhlo.sharding = "\08\02*\08\08\01\1A\01\01\22\01\00*\08\08\01\1A\01\01\22\01\00" + %0 = "tf.InfeedDequeueTuple"() {_XlaSharding = "\08\02*\08\08\01\1A\01\01\22\01\00"} : () -> tensor<8xi32> + func.return %0 : tensor<8xi32> +} + +//===----------------------------------------------------------------------===// +// Nullary op legalizations. +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: @const +func.func @const() -> tensor<2xi32> { + // CHECK: stablehlo.constant dense<0> : tensor<2xi32> + %0 = "tf.Const"() {device = "", name = "", dtype = "tfdtype$DT_INT32", value = dense<0> : tensor<2xi32>} : () -> (tensor<2xi32>) + func.return %0: tensor<2xi32> +} + +// ----- + +// CHECK-LABEL: @const_dynamic_output +func.func @const_dynamic_output() -> tensor<*xi32> { + // CHECK: [[CONST:%.*]] = stablehlo.constant dense<0> : tensor<2xi32> + // CHECK: [[CAST:%.*]] = tensor.cast [[CONST]] : tensor<2xi32> to tensor<*xi32> + %0 = "tf.Const"() {value = dense<0> : tensor<2xi32>} : () -> (tensor<*xi32>) + // CHECK: return [[CAST]] + func.return %0: tensor<*xi32> +} + +// ----- + +// CHECK-LABEL: @opaque_const +func.func @opaque_const() -> tensor>> { + // CHECK-NOT: stablehlo.constant + %0 = "tf.Const"() {device = "", name = "", dtype = "tfdtype$DT_INT32", value = #tf_type : tensor} : () -> tensor>> + func.return %0 : tensor>> +} + +//===----------------------------------------------------------------------===// +// Matmul op legalizations. +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: matmul_notranspose +// CHECK-SAME: (%[[A:.*]]: tensor<5x7xf32>, %[[B:.*]]: tensor<7x11xf32>) +func.func @matmul_notranspose(%a: tensor<5x7xf32>, %b: tensor<7x11xf32>) -> tensor<5x11xf32> { + // CHECK: stablehlo.dot %[[A]], %[[B]] + %0 = "tf.MatMul"(%a, %b) {transpose_a = false, transpose_b = false} : (tensor<5x7xf32>, tensor<7x11xf32>) -> tensor<5x11xf32> + + func.return %0 : tensor<5x11xf32> +} + +// ----- + +// CHECK-LABEL: matmul_transpose_b +// CHECK-SAME: (%[[A:.*]]: tensor<5x7xf32>, %[[B:.*]]: tensor<11x7xf32>) +func.func @matmul_transpose_b(%a: tensor<5x7xf32>, %b: tensor<11x7xf32>) -> tensor<5x11xf32> { + // CHECK: %[[UPDATED_B:.*]] = stablehlo.transpose %[[B]], dims = [1, 0] + // CHECK: stablehlo.dot %[[A]], %[[UPDATED_B]] + %0 = "tf.MatMul"(%a, %b) {transpose_a = false, transpose_b = true} : (tensor<5x7xf32>, tensor<11x7xf32>) -> tensor<5x11xf32> + + func.return %0 : tensor<5x11xf32> +} + +// ----- + +// CHECK-LABEL: matmul_transpose_both +// CHECK-SAME: (%[[A:.*]]: tensor<7x5xf32>, %[[B:.*]]: tensor<11x7xf32>) +func.func @matmul_transpose_both(%a: tensor<7x5xf32>, %b: tensor<11x7xf32>) -> tensor<5x11xf32> { + // CHECK: %[[UPDATED_A:.*]] = stablehlo.transpose %[[A]] + // CHECK: %[[UPDATED_B:.*]] = stablehlo.transpose %[[B]] + // CHECK: stablehlo.dot %[[UPDATED_A]], %[[UPDATED_B]] + %0 = "tf.MatMul"(%a, %b) {transpose_a = true, transpose_b = true} : (tensor<7x5xf32>, tensor<11x7xf32>) -> tensor<5x11xf32> + + func.return %0 : tensor<5x11xf32> +} + +// Verify that MatMul with ranked inputs are lowered to HLO. +// CHECK-LABEL: matmul_ranked +func.func @matmul_ranked(%a: tensor, %b: tensor<7x?xf32>) -> tensor { + // CHECK: stablehlo.dot + %0 = "tf.MatMul"(%a, %b) {transpose_a = false, transpose_b = false} : (tensor, tensor<7x?xf32>) -> tensor + + func.return %0 : tensor +} + +// Verify SparseMatMul is legalized to dot. +// CHECK-LABEL: test_sparse_mat_mul +func.func @test_sparse_mat_mul(%arg0: tensor<3x4xf32>, %arg1: tensor<4x5xf32>) -> tensor<3x5xf32> { + // CHECK: stablehlo.dot + %0 = "tf.SparseMatMul"(%arg0, %arg1) {a_is_sparse = true, b_is_sparse = false, transpose_a = false, transpose_b = false} : (tensor<3x4xf32>, tensor<4x5xf32>) -> tensor<3x5xf32> + func.return %0: tensor<3x5xf32> +} + +// SparseMatMul where one operand needs to be transposed and the other one not. +// +// CHECK-LABEL: @test_sparse_mat_mul_with_transpose + // CHECK-SAME: %[[ARG0:.*]]: tensor<3x4xf32> + // CHECK-SAME: %[[ARG1:.*]]: tensor<5x4xf32> + // CHECK-SAME: -> tensor<3x5xf32> + // CHECK: %[[TRANSPOSE:.*]] = stablehlo.transpose %[[ARG1]] + // CHECK-SAME: dims = [1, 0] + // CHECK-SAME: -> tensor<4x5xf32> + // CHECK: %[[RESULT:.*]] = stablehlo.dot %[[ARG0]], %[[TRANSPOSE]] + // CHECK-SAME: -> tensor<3x5xf32> + // CHECK: return %[[RESULT]] +func.func @test_sparse_mat_mul_with_transpose(%arg0: tensor<3x4xf32>, %arg1: tensor<5x4xf32>) -> tensor<3x5xf32> { + %0 = "tf.SparseMatMul"(%arg0, %arg1) {a_is_sparse = true, b_is_sparse = false, transpose_a = false, transpose_b = true} : (tensor<3x4xf32>, tensor<5x4xf32>) -> tensor<3x5xf32> + func.return %0: tensor<3x5xf32> +} + +// SparseMatMul where one operand needs to be casted and the other one not. +// +// CHECK-LABEL: @test_sparse_mat_mul_with_cast + // CHECK-SAME: %[[ARG0:.*]]: tensor<3x4xf32> + // CHECK-SAME: %[[ARG1:.*]]: tensor<4x5xbf16> + // CHECK-SAME: -> tensor<3x5xf32> + // CHECK: %[[CAST:.*]] = stablehlo.convert %[[ARG1]] + // CHECK-SAME: -> tensor<4x5xf32> + // CHECK: %[[RESULT:.*]] = stablehlo.dot %[[ARG0]], %[[CAST]] + // CHECK-SAME: -> tensor<3x5xf32> + // CHECK: return %[[RESULT]] +func.func @test_sparse_mat_mul_with_cast(%arg0: tensor<3x4xf32>, %arg1: tensor<4x5xbf16>) -> tensor<3x5xf32> { + %0 = "tf.SparseMatMul"(%arg0, %arg1) {a_is_sparse = true, b_is_sparse = false, transpose_a = false, transpose_b = false} : (tensor<3x4xf32>, tensor<4x5xbf16>) -> tensor<3x5xf32> + func.return %0: tensor<3x5xf32> +} + +//===----------------------------------------------------------------------===// +// MaxPool op legalizations. +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: maxpool_valid_padding +// CHECK-SAME: %[[ARG:.*]]: tensor +func.func @maxpool_valid_padding(%arg0: tensor<2x12x20x7xi32>) -> tensor<2x3x5x7xi32> { + // CHECK: %[[INIT:.*]] = stablehlo.constant dense<-2147483648> : tensor + // CHECK: "stablehlo.reduce_window"(%[[ARG]], %[[INIT]]) + // CHECK-SAME: <{window_dimensions = array, window_strides = array}> + // CHECK: stablehlo.maximum + // CHECK: stablehlo.return + + %0 = "tf.MaxPool"(%arg0) {data_format = "NHWC", ksize = [1, 2, 2, 1], padding = "VALID", strides = [1, 4, 4, 1]} : (tensor<2x12x20x7xi32>) -> tensor<2x3x5x7xi32> + func.return %0 : tensor<2x3x5x7xi32> +} + +// ----- + +// CHECK-LABEL: maxpool_same_padding +// CHECK-SAME: %[[ARG:.*]]: tensor +func.func @maxpool_same_padding(%arg0: tensor<2x13x25x7xi32>) -> tensor<2x4x7x7xi32> { + // CHECK: padding = dense<{{\[\[}}0, 0], [0, 1], [1, 1], [0, 0]]> : tensor<4x2xi64> + + %0 = "tf.MaxPool"(%arg0) {data_format = "NHWC", ksize = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 4, 1]} : (tensor<2x13x25x7xi32>) -> tensor<2x4x7x7xi32> + func.return %0 : tensor<2x4x7x7xi32> +} + +// ----- + +// CHECK-LABEL: maxpool_3d_valid_padding +// CHECK-SAME: %[[ARG:.*]]: tensor +func.func @maxpool_3d_valid_padding(%arg0: tensor<2x8x12x20x7xf32>) -> tensor<2x8x3x5x7xf32> { + // CHECK: %[[INIT:.*]] = stablehlo.constant dense<0xFF800000> : tensor + // CHECK: "stablehlo.reduce_window"(%[[ARG]], %[[INIT]]) + // CHECK-SAME: <{window_dimensions = array, window_strides = array}> + // CHECK: stablehlo.maximum + // CHECK: stablehlo.return + + %0 = "tf.MaxPool3D"(%arg0) {data_format = "NDHWC", ksize = [1, 1, 2, 2, 1], padding = "VALID", strides = [1, 1, 4, 4, 1]} : (tensor<2x8x12x20x7xf32>) -> tensor<2x8x3x5x7xf32> + func.return %0 : tensor<2x8x3x5x7xf32> +} + +// ----- + +// CHECK-LABEL: maxpool_3d_same_padding +// CHECK-SAME: %[[ARG:.*]]: tensor +func.func @maxpool_3d_same_padding(%arg0: tensor<2x8x13x25x7xf32>) -> tensor<2x8x4x7x7xf32> { + // CHECK: padding = dense<{{\[\[}}0, 0], [0, 0], [0, 1], [1, 1], [0, 0]]> : tensor<5x2xi64> + + %0 = "tf.MaxPool3D"(%arg0) {data_format = "NDHWC", ksize = [1, 1, 2, 3, 1], padding = "SAME", strides = [1, 1, 4, 4, 1]} : (tensor<2x8x13x25x7xf32>) -> tensor<2x8x4x7x7xf32> + func.return %0 : tensor<2x8x4x7x7xf32> +} + +// ----- + +// CHECK-LABEL: maxpool_explicit_padding +func.func @maxpool_explicit_padding(%arg0: tensor<2x12x20x7xi32>) -> tensor<2x3x5x7xi32> { + // CHECK: tf.MaxPool + // TODO(b/165938852): need to support explicit padding in max_pool. + + %0 = "tf.MaxPool"(%arg0) {data_format = "NHWC", ksize = [1, 2, 2, 1], padding = "EXPLICIT", strides = [1, 4, 4, 1]} : (tensor<2x12x20x7xi32>) -> tensor<2x3x5x7xi32> + func.return %0 : tensor<2x3x5x7xi32> +} + +//===----------------------------------------------------------------------===// +// MaxPoolGrad op legalizations. +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: @max_pool_grad_valid +// CHECK-SAME: %[[INPUT:.*]]: tensor<10x24x24x64xf32>, %arg1: tensor<10x12x12x64xf32>, %[[GRAD:.*]]: tensor<10x12x12x64xf32> +func.func @max_pool_grad_valid(%orig_input: tensor<10x24x24x64xf32>, %orig_output: tensor<10x12x12x64xf32>, %grad: tensor<10x12x12x64xf32>) -> tensor<10x24x24x64xf32> { + // CHECK: %[[ZERO:.*]] = stablehlo.constant dense<0.000000e+00> : tensor + // CHECK: %[[RESULT:.*]] = "stablehlo.select_and_scatter"(%[[INPUT]], %[[GRAD]], %[[ZERO]]) + // CHECK-SAME: <{window_dimensions = array, window_strides = array}> ({ + // CHECK: ^bb0(%[[VALUE_A:.*]]: tensor, %[[VALUE_B:.*]]: tensor): + // CHECK: %[[SELECT_RESULT:.*]] = stablehlo.compare GE, %[[VALUE_A]], %[[VALUE_B]], NOTYPE : (tensor, tensor) -> tensor + // CHECK: stablehlo.return %[[SELECT_RESULT]] : tensor + // CHECK: }, { + // CHECK: ^bb0(%[[VALUE_A:.*]]: tensor, %[[VALUE_B:.*]]: tensor): + // CHECK: %[[SELECT_RESULT:.*]] = stablehlo.add %[[VALUE_A]], %[[VALUE_B]] : tensor + // CHECK: stablehlo.return %[[SELECT_RESULT]] : tensor + // CHECK: }) : (tensor<10x24x24x64xf32>, tensor<10x12x12x64xf32>, tensor) -> tensor<10x24x24x64xf32> + // CHECK: return %[[RESULT]] : tensor<10x24x24x64xf32> + %result = "tf.MaxPoolGrad"(%orig_input, %orig_output, %grad) { + data_format = "NHWC", + ksize = [1, 2, 2, 1], + padding = "VALID", + strides = [1, 2, 2, 1] + } : (tensor<10x24x24x64xf32>, tensor<10x12x12x64xf32>, tensor<10x12x12x64xf32>) -> tensor<10x24x24x64xf32> + func.return %result : tensor<10x24x24x64xf32> +} + +// ----- + +// CHECK-LABEL: @max_pool_3d_grad_valid +// CHECK-SAME: %[[INPUT:.*]]: tensor<10x8x24x24x64xf32>, %arg1: tensor<10x8x12x12x64xf32>, %[[GRAD:.*]]: tensor<10x8x12x12x64xf32> +func.func @max_pool_3d_grad_valid(%orig_input: tensor<10x8x24x24x64xf32>, %orig_output: tensor<10x8x12x12x64xf32>, %grad: tensor<10x8x12x12x64xf32>) -> tensor<10x8x24x24x64xf32> { + // CHECK: %[[ZERO:.*]] = stablehlo.constant dense<0.000000e+00> : tensor + // CHECK: %[[RESULT:.*]] = "stablehlo.select_and_scatter"(%[[INPUT]], %[[GRAD]], %[[ZERO]]) + // CHECK-SAME: <{window_dimensions = array, window_strides = array}> ({ + // CHECK: ^bb0(%[[VALUE_A:.*]]: tensor, %[[VALUE_B:.*]]: tensor): + // CHECK: %[[SELECT_RESULT:.*]] = stablehlo.compare GE, %[[VALUE_A]], %[[VALUE_B]], NOTYPE : (tensor, tensor) -> tensor + // CHECK: stablehlo.return %[[SELECT_RESULT]] : tensor + // CHECK: }, { + // CHECK: ^bb0(%[[VALUE_A:.*]]: tensor, %[[VALUE_B:.*]]: tensor): + // CHECK: %[[SELECT_RESULT:.*]] = stablehlo.add %[[VALUE_A]], %[[VALUE_B]] : tensor + // CHECK: stablehlo.return %[[SELECT_RESULT]] : tensor + // CHECK: }) : (tensor<10x8x24x24x64xf32>, tensor<10x8x12x12x64xf32>, tensor) -> tensor<10x8x24x24x64xf32> + // CHECK: return %[[RESULT]] : tensor<10x8x24x24x64xf32> + %result = "tf.MaxPool3DGrad"(%orig_input, %orig_output, %grad) {data_format = "NDHWC", ksize = [1, 1, 2, 2, 1], padding = "VALID", strides = [1, 1, 2, 2, 1]} : (tensor<10x8x24x24x64xf32>, tensor<10x8x12x12x64xf32>, tensor<10x8x12x12x64xf32>) -> tensor<10x8x24x24x64xf32> + func.return %result : tensor<10x8x24x24x64xf32> +} + +// ----- + +// CHECK-LABEL: @max_pool_grad_same +func.func @max_pool_grad_same(%orig_input: tensor<2x13x25x7xf32>, %orig_output: tensor<2x4x7x7xf32>, %grad: tensor<2x4x7x7xf32>) -> tensor<2x13x25x7xf32> { + // CHECK: padding = dense<{{\[\[}}0, 0], [0, 1], [1, 1], [0, 0]]> : tensor<4x2xi64> + %result = "tf.MaxPoolGrad"(%orig_input, %orig_output, %grad) { + data_format = "NHWC", + ksize = [1, 2, 3, 1], + padding = "SAME", + strides = [1, 4, 4, 1] + } : (tensor<2x13x25x7xf32>, tensor<2x4x7x7xf32>, tensor<2x4x7x7xf32>) -> tensor<2x13x25x7xf32> + func.return %result : tensor<2x13x25x7xf32> +} + +// ----- + +// CHECK-LABEL: @max_pool_3d_grad_same +func.func @max_pool_3d_grad_same(%orig_input: tensor<2x8x13x25x7xf32>, %orig_output: tensor<2x8x4x7x7xf32>, %grad: tensor<2x8x4x7x7xf32>) -> tensor<2x8x13x25x7xf32> { + // CHECK: padding = dense<{{\[\[}}0, 0], [0, 0], [0, 1], [1, 1], [0, 0]]> : tensor<5x2xi64> + %result = "tf.MaxPool3DGrad"(%orig_input, %orig_output, %grad) {data_format = "NDHWC", ksize = [1, 1, 2, 3, 1], padding = "SAME", strides = [1, 1, 4, 4, 1]} : (tensor<2x8x13x25x7xf32>, tensor<2x8x4x7x7xf32>, tensor<2x8x4x7x7xf32>) -> tensor<2x8x13x25x7xf32> + func.return %result : tensor<2x8x13x25x7xf32> +} + +//===----------------------------------------------------------------------===// +// OneHot op legalizations. +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL:one_hot +func.func @one_hot(%indices: tensor<3xi32>, %on_value: tensor, %off_value: tensor) -> tensor<3x5xf32> { + // CHECK-NEXT: %[[IOTA0:.*]] = stablehlo.iota dim = 0 : tensor<5xi32> + // CHECK-NEXT: %[[BCAST0:.*]] = stablehlo.broadcast_in_dim %[[IOTA0]], dims = [1] : (tensor<5xi32>) -> tensor<3x5xi32> + // CHECK-NEXT: %[[BCAST1:.*]] = stablehlo.broadcast_in_dim %arg0, dims = [0] : (tensor<3xi32>) -> tensor<3x5xi32> + // CHECK-NEXT: %[[CMP0:.*]] = stablehlo.compare EQ, %[[BCAST1]], %[[BCAST0]], NOTYPE : (tensor<3x5xi32>, tensor<3x5xi32>) -> tensor<3x5xi1> + // CHECK-NEXT: %[[BCAST2:.*]] = stablehlo.broadcast %arg1, sizes = [3, 5] : (tensor) -> tensor<3x5xf32> + // CHECK-NEXT: %[[BCAST3:.*]] = stablehlo.broadcast %arg2, sizes = [3, 5] : (tensor) -> tensor<3x5xf32> + // CHECK-NEXT: %[[RESULT:.*]] = stablehlo.select %[[CMP0]], %[[BCAST2]], %[[BCAST3]] : tensor<3x5xi1>, tensor<3x5xf32> + // CHECK-NEXT: return %[[RESULT]] : tensor<3x5xf32> + %depth = "tf.Const"() { value = dense<5> : tensor } : () -> tensor + %result = "tf.OneHot"(%indices, %depth, %on_value, %off_value) {axis = -1 : i64} : (tensor<3xi32>, tensor, tensor, tensor) -> tensor<3x5xf32> + func.return %result : tensor<3x5xf32> +} + +//===----------------------------------------------------------------------===// +// tf.OutfeedEnqueueTuple legalization +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: func @outfeed_enqueue_tuple +// CHECK-SAME: [[VAL_0:%.*]]: tensor<3xi32>, [[VAL_1:%.*]]: tensor<4xf32>) +func.func @outfeed_enqueue_tuple(%data_1: tensor<3xi32>, %data_2: tensor<4xf32>) -> () { + // CHECK: [[TOKEN:%.*]] = stablehlo.create_token : !stablehlo.token + // CHECK: "stablehlo.outfeed"([[VAL_0]], [[VAL_1]], [[TOKEN]]) <{outfeed_config = ""}> : (tensor<3xi32>, tensor<4xf32>, !stablehlo.token) -> !stablehlo.token + "tf.OutfeedEnqueueTuple"(%data_1, %data_2) : (tensor<3xi32>, tensor<4xf32>) -> () + func.return +} + +//===----------------------------------------------------------------------===// +// Pack op legalizations. +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: func @pack +func.func @pack(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2x2xi32> { + // CHECK: stablehlo.reshape {{.*}} : (tensor<2xi32>) -> tensor<1x2xi32> + // CHECK: stablehlo.reshape {{.*}} : (tensor<2xi32>) -> tensor<1x2xi32> + // CHECK: stablehlo.concatenate {{.*}}, {{.*}}, dim = 0 + + %0 = "tf.Pack"(%arg0, %arg1) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2x2xi32> + func.return %0 : tensor<2x2xi32> +} + +//===----------------------------------------------------------------------===// +// PartitionedCall op legalization. +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: func @partitioned_call +func.func @partitioned_call(%arg0: tensor) -> tensor { + // CHECK: call @pcall_func(%arg0) : (tensor) -> tensor + %0 = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @pcall_func} : (tensor) -> (tensor) + func.return %0 : tensor +} + + +func.func @pcall_func(%arg0: tensor) -> tensor { + func.return %arg0 : tensor +} + +// ----- + +// CHECK-LABEL: func @partitioned_call_multi_input +func.func @partitioned_call_multi_input(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: call @pcall_multi_input(%arg0, %arg1) : (tensor, tensor) -> tensor + %0 = "tf.PartitionedCall"(%arg0, %arg1) {config = "", config_proto = "", executor_type = "", f = @pcall_multi_input} : (tensor, tensor) -> (tensor) + func.return %0 : tensor +} + + +func.func @pcall_multi_input(%arg0: tensor, %arg1: tensor) -> tensor { + func.return %arg0 : tensor +} + +// ----- + +// CHECK-LABEL: func @partitioned_call_multi_in_out +func.func @partitioned_call_multi_in_out(%arg0: tensor, %arg1: tensor) -> (tensor, tensor) { + // CHECK: call @pcall_multi_in_out(%arg0, %arg1) : (tensor, tensor) -> (tensor, tensor) + %0, %1 = "tf.PartitionedCall"(%arg0, %arg1) {config = "", config_proto = "", executor_type = "", f = @pcall_multi_in_out} : (tensor, tensor) -> (tensor, tensor) + func.return %0, %1 : tensor, tensor +} + + +func.func @pcall_multi_in_out(%arg0: tensor, %arg1: tensor) -> (tensor, tensor) { + func.return %arg1, %arg0 : tensor, tensor +} + +// CHECK-LABEL: func @unhandled_partitioned_call +func.func @unhandled_partitioned_call(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> (tensor, tensor) { + // The argument types don't match the parameter types for the + // pcall_multi_in_out function. That's fine for a PartitionedCallOp but not + // for a standard CallOp, so this op can't be lowered. + // CHECK: "tf.PartitionedCall" + %0, %1 = "tf.PartitionedCall"(%arg0, %arg1) {config = "", config_proto = "", executor_type = "", f = @pcall_multi_in_out} : (tensor<*xi32>, tensor<*xi32>) -> (tensor, tensor) + func.return %0, %1 : tensor, tensor +} + + +// CHECK-LABEL: func @unhandled_partitioned_call_2 +func.func @unhandled_partitioned_call_2(%arg0: tensor, %arg1: tensor<*xi32>) -> (tensor, tensor) { + // CHECK: "tf.PartitionedCall" + %0, %1 = "tf.PartitionedCall"(%arg0, %arg1) {config = "", config_proto = "", executor_type = "", f = @pcall_multi_in_out} : (tensor, tensor<*xi32>) -> (tensor, tensor) + func.return %0, %1 : tensor, tensor +} + +// ----- + +// CHECK-LABEL: func @no_args_and_results +func.func @no_args_and_results() { + // CHECK: call @callee() : () -> () + // CHECK: call @callee() : () -> () + // CHECK: call @callee() : () -> () + "tf.PartitionedCall"() {config = "", config_proto = "", executor_type = "", f = @callee} : () -> () + "tf.StatefulPartitionedCall"() {config = "", config_proto = "", executor_type = "", f = @callee} : () -> () + "tf.LegacyCall"() {config = "", config_proto = "", executor_type = "", f = @callee} : () -> () + func.return +} + +func.func @callee() { + func.return +} + +//===----------------------------------------------------------------------===// +// ReverseV2 op legalization. +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: @reverse_func_32 +func.func @reverse_func_32(%arg0: tensor<5xi32>) -> tensor<5xi32> { + %axis = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> (tensor<1xi32>) + + // CHECK: [[VAL:%.+]] = stablehlo.reverse %arg0, dims = [0] : tensor<5xi32> + %reversed = "tf.ReverseV2"(%arg0, %axis) : (tensor<5xi32>, tensor<1xi32>) -> tensor<5xi32> + + // CHECK: return [[VAL]] : tensor<5xi32> + func.return %reversed : tensor<5xi32> +} + +// ----- + +// CHECK-LABEL: @reverse_func_64 +func.func @reverse_func_64(%arg0: tensor<5xi32>) -> tensor<5xi32> { + %axis = "tf.Const"() {value = dense<0> : tensor<1xi64>} : () -> (tensor<1xi64>) + + // CHECK: [[VAL:%.+]] = stablehlo.reverse %arg0, dims = [0] : tensor<5xi32> + %reversed = "tf.ReverseV2"(%arg0, %axis) : (tensor<5xi32>, tensor<1xi64>) -> tensor<5xi32> + + // CHECK: return [[VAL]] : tensor<5xi32> + func.return %reversed : tensor<5xi32> +} + +// ----- + +// CHECK-LABEL: @reverse_func_neg +func.func @reverse_func_neg(%arg0: tensor<5x5xi32>) -> tensor<5x5xi32> { + %axis = "tf.Const"() {value = dense<[-1]> : tensor<1xi32>} : () -> (tensor<1xi32>) + + // CHECK: [[VAL:%.+]] = stablehlo.reverse %arg0, dims = [1] : tensor<5x5xi32> + %reversed = "tf.ReverseV2"(%arg0, %axis) : (tensor<5x5xi32>, tensor<1xi32>) -> tensor<5x5xi32> + + // CHECK: return [[VAL]] : tensor<5x5xi32> + func.return %reversed : tensor<5x5xi32> +} + +//===----------------------------------------------------------------------===// +// StatefulPartitionedCall op legalization. +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: func @stateful_partitioned_call +// CHECK-SAME: [[ARG:%.+]]: tensor +func.func @stateful_partitioned_call(%arg0: tensor) -> tensor { + // CHECK: call @stateful_pcall_func([[ARG]]) : (tensor) -> tensor + %0 = "tf.StatefulPartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @stateful_pcall_func} : (tensor) -> (tensor) + func.return %0 : tensor +} + +func.func @stateful_pcall_func(%arg0: tensor) -> tensor { + func.return %arg0 : tensor +} + +// ----- + +// CHECK-LABEL: func @stateful_partitioned_call_multi_in_out +// CHECK-SAME: ([[ARG0:%.+]]: tensor, [[ARG1:%.+]]: tensor) +func.func @stateful_partitioned_call_multi_in_out(%arg0: tensor, %arg1: tensor) -> (tensor, tensor) { + // CHECK: call @stateful_pcall_multi_in_out([[ARG0]], [[ARG1]]) : (tensor, tensor) -> (tensor, tensor) + %0, %1 = "tf.StatefulPartitionedCall"(%arg0, %arg1) {config = "", config_proto = "", executor_type = "", f = @stateful_pcall_multi_in_out} : (tensor, tensor) -> (tensor, tensor) + func.return %0, %1 : tensor, tensor +} + +func.func @stateful_pcall_multi_in_out(%arg0: tensor, %arg1: tensor) -> (tensor, tensor) { + func.return %arg1, %arg0 : tensor, tensor +} + +//===----------------------------------------------------------------------===// +// Elu op legalizations. +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: func @elu +func.func @elu(%arg0: tensor<1xf32>) -> tensor<1xf32> { + // CHECK-DAG: %[[ZERO:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<1xf32> + // CHECK-DAG: %[[PRED:.*]] = stablehlo.compare GT, %arg0, %[[ZERO]] + // CHECK-DAG: %[[EXP:.*]] = stablehlo.exponential_minus_one %arg0 + // CHECK: %[[RESULT:.*]] = stablehlo.select %[[PRED]], %arg0, %[[EXP]] + // CHECK: return %[[RESULT]] + %0 = "tf.Elu"(%arg0) : (tensor<1xf32>) -> tensor<1xf32> + func.return %0: tensor<1xf32> +} + +// ----- + +// CHECK-LABEL: func @elu_grad +// CHECK-SAME: (%[[GRADIENTS:.*]]: tensor<4x8xf32>, %[[FEATURES:.*]]: tensor) +func.func @elu_grad(%gradients: tensor<4x8xf32>, %features: tensor) -> tensor<4x8xf32> { + // CHECK-DAG: %[[ZERO:.*]] = stablehlo.constant dense<0.000000e+00> : tensor + // CHECK-DAG: %[[ONE:.*]] = stablehlo.constant dense<1.000000e+00> : tensor + // CHECK-DAG: %[[BCAST0:.*]] = stablehlo.dynamic_broadcast_in_dim %[[ZERO]], {{.*}}, dims = [] : (tensor, tensor<2xindex>) -> tensor + // CHECK-DAG: %[[PRED:.*]] = stablehlo.compare GT, %[[FEATURES]], %[[BCAST0]] + // CHECK-DAG: %[[BCAST1:.*]] = stablehlo.dynamic_broadcast_in_dim %[[ONE]], {{.*}}, dims = [] : (tensor, tensor<2xindex>) -> tensor + // CHECK-DAG: %[[ADD1:.*]] = stablehlo.add %[[FEATURES]], %[[BCAST1]] + // CHECK-DAG: %[[MULGRAD:.*]] = stablehlo.multiply %[[GRADIENTS]], %[[ADD1]] : (tensor<4x8xf32>, tensor) -> tensor<4x8xf32> + // CHECK: %[[RESULT:.*]] = stablehlo.select %[[PRED]], %[[GRADIENTS]], %[[MULGRAD]] + // CHECK: return %[[RESULT]] + %2 = "tf.EluGrad"(%gradients, %features) : (tensor<4x8xf32>, tensor) -> tensor<4x8xf32> + func.return %2 : tensor<4x8xf32> +} + +//===----------------------------------------------------------------------===// +// Relu op legalizations. +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: func @relu +func.func @relu(%arg0: tensor<1xi32>) -> tensor<1xi32> { + // CHECK: %[[ZERO:.*]] = stablehlo.constant dense<0> : tensor<1xi32> + // CHECK: stablehlo.maximum %arg0, %[[ZERO]] + %0 = "tf.Relu"(%arg0) : (tensor<1xi32>) -> tensor<1xi32> + func.return %0: tensor<1xi32> +} + +// ----- + +// CHECK-LABEL: func @relu_unsigned +func.func @relu_unsigned(%arg0: tensor) -> tensor { + // CHECK: %[[ZERO:.*]] = stablehlo.constant dense<0> : tensor + // CHECK: %[[BCAST0:.*]] = stablehlo.dynamic_broadcast_in_dim %[[ZERO]], {{.*}}, dims = [] + // CHECK: stablehlo.maximum %arg0, %[[BCAST0]] + %0 = "tf.Relu"(%arg0) : (tensor) -> tensor + func.return %0: tensor +} + +// ----- + +// CHECK-LABEL: func @relu6 +func.func @relu6(%arg0: tensor<1xi32>) -> tensor<1xi32> { + // CHECK-DAG: %[[ZERO:.*]] = stablehlo.constant dense<0> : tensor + // CHECK-DAG: %[[SIX:.*]] = stablehlo.constant dense<6> : tensor + // CHECK: stablehlo.clamp %[[ZERO]], %arg0, %[[SIX]] + %0 = "tf.Relu6"(%arg0) : (tensor<1xi32>) -> tensor<1xi32> + func.return %0: tensor<1xi32> +} + +// ----- + +// CHECK-LABEL: func @relu6_unsigned +func.func @relu6_unsigned(%arg0: tensor) -> tensor { + // CHECK-DAG: %[[ZERO:.*]] = stablehlo.constant dense<0> : tensor + // CHECK-DAG: %[[SIX:.*]] = stablehlo.constant dense<6> : tensor + // CHECK: stablehlo.clamp %[[ZERO]], %arg0, %[[SIX]] + %0 = "tf.Relu6"(%arg0) : (tensor) -> tensor + func.return %0: tensor +} + +// ----- + +// CHECK-LABEL: func @leaky_relu +func.func @leaky_relu(%arg0: tensor<1x4x4x3xf32>) -> tensor<1x4x4x3xf32> attributes {tf.entry_function = {}} { + // CHECK-NEXT: %[[ALPHA:.*]] = stablehlo.constant dense<2.000000e-01> + // CHECK-NEXT: %[[ZERO:.*]] = stablehlo.constant dense<0.000000e+00> + // CHECK-NEXT: %[[LEAKY:.*]] = stablehlo.multiply %arg0, %[[ALPHA]] + // CHECK-NEXT: %[[CMP:.*]] = stablehlo.compare GT, %arg0, %[[ZERO]] + // CHECK-NEXT: %[[RES:.*]] = stablehlo.select %[[CMP]], %arg0, %[[LEAKY]] + // CHECK-NEXT: return %[[RES]] : tensor<1x4x4x3xf32> + %0 = "tf.LeakyRelu"(%arg0) {alpha = 2.000000e-01 : f32, device = ""} : (tensor<1x4x4x3xf32>) -> tensor<1x4x4x3xf32> + func.return %0 : tensor<1x4x4x3xf32> +} + +// ----- + +// CHECK-LABEL: func @leaky_relu_grad +func.func @leaky_relu_grad(%arg0: tensor<1x4x4xf32>, %arg1: tensor<1x4x4xf32>) -> tensor<1x4x4xf32> attributes {tf.entry_function = {}} { + // CHECK-NEXT: %[[ALPHA:.*]] = stablehlo.constant dense<2.000000e-01> + // CHECK-NEXT: %[[ZERO:.*]] = stablehlo.constant dense<0.000000e+00> + // CHECK-NEXT: %[[LEAKYGRAD:.*]] = stablehlo.multiply %[[GRADIENT:.*]], %[[ALPHA]] + // CHECK-NEXT: %[[CMP:.*]] = stablehlo.compare GT, %[[INP:.*]], %[[ZERO]], NOTYPE + // CHECK-NEXT: %[[RES:.*]] = stablehlo.select %[[CMP]], %[[GRADIENT]], %[[LEAKYGRAD]] + // CHECK-NEXT: return %[[RES]] : tensor<1x4x4xf32> + %0 = "tf.LeakyReluGrad"(%arg0, %arg1) {alpha = 2.000000e-01 : f32, device = ""} : (tensor<1x4x4xf32>, tensor<1x4x4xf32>) -> tensor<1x4x4xf32> + func.return %0 : tensor<1x4x4xf32> +} + +// ----- + +// CHECK-LABEL: func @softsign +func.func @softsign(%arg0: tensor<4x10xf32>) -> tensor<4x10xf32> { + // CHECK-NEXT: %[[ONE:.*]] = stablehlo.constant dense<1.000000e+00> + // CHECK-NEXT: %[[ABS:.*]] = stablehlo.abs %{{.*}} + // CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %[[ABS]], %[[ONE]] + // CHECK-NEXT: %[[DIV:.*]] = stablehlo.divide %{{.*}}, %[[ADD]] + // CHECK-NEXT: return %[[DIV]] : tensor<4x10xf32> + %0 = "tf.Softsign"(%arg0) : (tensor<4x10xf32>) -> tensor<4x10xf32> + func.return %0 : tensor<4x10xf32> +} + +// ----- + +// CHECK-LABEL: func @softsign_grad +func.func @softsign_grad(%arg0: tensor<4x10xf32>, %arg1: tensor<4x10xf32>) -> tensor<4x10xf32> { + + // CHECK-NEXT: %[[ONE:.*]] = stablehlo.constant dense<1.000000e+00> + // CHECK-NEXT: %[[ABS:.*]] = stablehlo.abs %{{.*}} : tensor<4x10xf32> + // CHECK-NEXT: %[[BROADCAST_ADD:.*]] = stablehlo.add %[[ABS]], %[[ONE]] + // CHECK-NEXT: %[[MUL:.*]] = stablehlo.multiply %[[BROADCAST_ADD]], %[[BROADCAST_ADD]] + // CHECK-NEXT: %[[BROADCAST_DIV:.*]] = stablehlo.divide %{{.*}}, %[[MUL]] + // CHECK-NEXT: return %[[BROADCAST_DIV]] : tensor<4x10xf32> + %0 = "tf.SoftsignGrad"(%arg0, %arg1) : (tensor<4x10xf32>, tensor<4x10xf32>) -> tensor<4x10xf32> + func.return %0 : tensor<4x10xf32> +} + +//===----------------------------------------------------------------------===// +// Roll op legalizations. +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: func @Roll_0D +func.func @Roll_0D(%arg0: tensor<512xi32>, %shift: tensor) -> tensor<512xi32> { + %axis = "tf.Const"() {value = dense<0> : tensor} : () -> (tensor) + // CHECK: %[[AXIS_SIZE:.*]] = stablehlo.constant dense<512> : tensor + // CHECK: %[[T1:.+]] = stablehlo.remainder %arg1, %[[AXIS_SIZE]] : tensor + // CHECK: %[[T2:.+]] = stablehlo.add %[[T1]], %[[AXIS_SIZE]] : tensor + // CHECK: %[[T3:.+]] = stablehlo.remainder %[[T2]], %[[AXIS_SIZE]] : tensor + // CHECK: %[[CONCAT:.+]] = stablehlo.concatenate %arg0, %arg0, dim = 0 + // CHECK: %[[OFFSET:.+]] = stablehlo.subtract %[[AXIS_SIZE]], %[[T3]] : tensor + // CHECK: stablehlo.dynamic_slice %[[CONCAT]], %[[OFFSET]], sizes = [512] + %0 = "tf.Roll"(%arg0, %shift, %axis) {device = ""} : (tensor<512xi32>, tensor, tensor) -> tensor<512xi32> + func.return %0 : tensor<512xi32> +} + +//===----------------------------------------------------------------------===// +// Select op legalizations. +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: func @select_batch_static +func.func @select_batch_static(%arg0: tensor<2xi1>, %arg1: tensor<2x6x8xi32>, %arg2: tensor<2x6x8xi32>) -> tensor<2x6x8xi32> { + // CHECK: %[[BCAST:.*]] = stablehlo.broadcast_in_dim %arg0, dims = [0] + // CHECK: stablehlo.select %[[BCAST]], %arg1, %arg2 + %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<2xi1>, tensor<2x6x8xi32>, tensor<2x6x8xi32>) -> tensor<2x6x8xi32> + func.return %0: tensor<2x6x8xi32> +} + +// ----- + +// CHECK-LABEL: func @select_batch_static_r1 +func.func @select_batch_static_r1(%arg0: tensor, %arg1: tensor<2x6x8xi32>, %arg2: tensor<2x6x8xi32>) -> tensor<2x6x8xi32> { + // CHECK: stablehlo.select %arg0, %arg1, %arg2 + %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor, tensor<2x6x8xi32>, tensor<2x6x8xi32>) -> tensor<2x6x8xi32> + func.return %0: tensor<2x6x8xi32> +} + +// ----- + +// CHECK-LABEL: func @select_batch_static_all_same +func.func @select_batch_static_all_same(%arg0: tensor<2x6x8xi1>, %arg1: tensor<2x6x8xi32>, %arg2: tensor<2x6x8xi32>) -> tensor<2x6x8xi32> { + // CHECK: stablehlo.select %arg0, %arg1, %arg2 + %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<2x6x8xi1>, tensor<2x6x8xi32>, tensor<2x6x8xi32>) -> tensor<2x6x8xi32> + func.return %0: tensor<2x6x8xi32> +} + +// ----- + +// CHECK-LABEL: func @select_batch_dynamic_r1 +func.func @select_batch_dynamic_r1(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { + // CHECK-NEXT: %[[C1:.*]] = arith.constant 1 : index + // CHECK-NEXT: %[[SHAPE0:.*]] = shape.shape_of %arg0 : tensor -> tensor<1xindex> + // CHECK-NEXT: %[[SHAPE1:.*]] = shape.shape_of %arg1 : tensor -> tensor<3xindex> + // CHECK-NEXT: %[[SHAPE2:.*]] = shape.shape_of %arg2 : tensor -> tensor<3xindex> + // CHECK-NEXT: %[[SHAPEEQ1:.*]] = shape.cstr_eq %[[SHAPE1]], %[[SHAPE2]] : tensor<3xindex>, tensor<3xindex> + // CHECK-NEXT: %[[HEAD:.*]], %[[TAIL:.*]] = "shape.split_at"(%[[SHAPE1]], %[[C1]]) : (tensor<3xindex>, index) -> (tensor<1xindex>, tensor<2xindex>) + // CHECK-NEXT: %[[SHAPEEQ2:.*]] = shape.cstr_eq %[[SHAPE0]], %[[HEAD]] : tensor<1xindex>, tensor<1xindex> + // CHECK-NEXT: %[[SHAPEEQ:.*]] = shape.assuming_all %[[SHAPEEQ1]], %[[SHAPEEQ2]] + // CHECK-NEXT: %[[ASSUMING:.*]] = shape.assuming %[[SHAPEEQ]] -> (tensor) { + // CHECK-NEXT: %[[BCAST:.*]] = stablehlo.dynamic_broadcast_in_dim %arg0, %[[SHAPE1]], dims = [0] + // CHECK-NEXT: %[[SELECT:.*]] = stablehlo.select %[[BCAST]], %arg1, %arg2 : tensor, tensor + // CHECK-NEXT: shape.assuming_yield %[[SELECT]] : tensor + %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor, tensor, tensor) -> tensor + func.return %0: tensor +} + +// ----- + +// CHECK-LABEL: func @select_batch_dynamic +func.func @select_batch_dynamic(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { + // CHECK-NEXT: %[[SHAPE0:.*]] = shape.shape_of %arg0 : tensor -> tensor<3xindex> + // CHECK-NEXT: %[[SHAPE1:.*]] = shape.shape_of %arg1 : tensor -> tensor<3xindex> + // CHECK-NEXT: %[[SHAPE2:.*]] = shape.shape_of %arg2 : tensor -> tensor<3xindex> + // CHECK-NEXT: %[[SHAPEEQ1:.*]] = shape.cstr_eq %[[SHAPE1]], %[[SHAPE2]] : tensor<3xindex>, tensor<3xindex> + // CHECK-NEXT: %[[SHAPEEQ2:.*]] = shape.cstr_eq %[[SHAPE0]], %[[SHAPE1]] : tensor<3xindex>, tensor<3xindex> + // CHECK-NEXT: %[[SHAPEEQ3:.*]] = shape.cstr_eq %[[SHAPE1]], %[[SHAPE2]], %[[SHAPE0]], %[[SHAPE1]] : tensor<3xindex>, tensor<3xindex>, tensor<3xindex>, tensor<3xindex> + // CHECK-NEXT: %[[SHAPEEQ:.*]] = shape.assuming %[[SHAPEEQ3]] + // CHECK-NEXT: %[[SELECT:.*]] = stablehlo.select %arg0, %arg1, %arg2 : tensor, tensor + // CHECK-NEXT: shape.assuming_yield %[[SELECT]] : tensor + %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor, tensor, tensor) -> tensor + func.return %0: tensor +} + +// ----- + +// CHECK-LABEL: testSelectInvalidUnranked +func.func @testSelectInvalidUnranked(%arg0: tensor<6x7xi1>, %arg1: tensor<*xf16>, %arg2: tensor<*xf16>) -> tensor<*xf16> { + // CHECK-NEXT: tf.Select + %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<6x7xi1>, tensor<*xf16>, tensor<*xf16>) -> tensor<*xf16> + func.return %0: tensor<*xf16> +} + +// ----- + +// CHECK-LABEL: testSelectThenUnranked +func.func @testSelectThenUnranked(%arg0: tensor<3xi1>, %arg1: tensor<*xf16>, %arg2: tensor<3x2xf16>) -> tensor<*xf16> { + // CHECK-NEXT: tf.Select + %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<3xi1>, tensor<*xf16>, tensor<3x2xf16>) -> tensor<*xf16> + func.return %0: tensor<*xf16> +} + +// ----- + +// CHECK-LABEL: testSelectElseUnranked +func.func @testSelectElseUnranked(%arg0: tensor<3xi1>, %arg1: tensor<3x2xf16>, %arg2: tensor<*xf16>) -> tensor<*xf16> { + // CHECK-NEXT: tf.Select + %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<3xi1>, tensor<3x2xf16>, tensor<*xf16>) -> tensor<*xf16> + func.return %0: tensor<*xf16> +} + +// ----- + +// CHECK-LABEL: func @selectv2_dynamic_ranked +func.func @selectv2_dynamic_ranked(%arg0: tensor<1xi1>, %arg1: tensor<2x?x8xi32>, %arg2: tensor<2x8x8xi32>) -> tensor<2x?x8xi32> { + // CHECK: stablehlo.select + %0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<1xi1>, tensor<2x?x8xi32>, tensor<2x8x8xi32>) -> tensor<2x?x8xi32> + func.return %0: tensor<2x?x8xi32> +} + +//===----------------------------------------------------------------------===// +// Fast Fourier Transform op legalization. +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: func @fft_1D +func.func @fft_1D(%arg0: tensor<8xcomplex>) -> tensor<8xcomplex> { + // CHECK: stablehlo.fft %arg0, type = FFT, length = [8] + %0 = "tf.FFT"(%arg0) : (tensor<8xcomplex>) -> tensor<8xcomplex> + func.return %0 : tensor<8xcomplex> +} + +// ----- + +// CHECK-LABEL: func @ifft_1D +func.func @ifft_1D(%arg0: tensor<8xcomplex>) -> tensor<8xcomplex> { + // CHECK: stablehlo.fft %arg0, type = IFFT, length = [8] + %0 = "tf.IFFT"(%arg0) : (tensor<8xcomplex>) -> tensor<8xcomplex> + func.return %0 : tensor<8xcomplex> +} + +// ----- + +// CHECK-LABEL: func @rfft_1D +func.func @rfft_1D(%arg0: tensor<8xf32>) -> tensor<5xcomplex> { + %fftlength = "tf.Const"() {value = dense<[8]> : tensor<1xi32>} : () -> (tensor<1xi32>) + // CHECK: stablehlo.fft %arg0, type = RFFT, length = [8] + %0 = "tf.RFFT"(%arg0, %fftlength) : (tensor<8xf32>, tensor<1xi32>) -> tensor<5xcomplex> + func.return %0 : tensor<5xcomplex> +} + +// ----- + +// CHECK-LABEL: func @rfft_1D_padded +func.func @rfft_1D_padded(%arg0: tensor<7xf32>) -> tensor<5xcomplex> { + %fftlength = "tf.Const"() {value = dense<[8]> : tensor<1xi32>} : () -> (tensor<1xi32>) + // CHECK: %[[PADDED:.*]] = stablehlo.pad %arg0, %{{.*}}, low = [0], high = [1], interior = [0] + // CHECK: stablehlo.fft %[[PADDED]], type = RFFT, length = [8] + %0 = "tf.RFFT"(%arg0, %fftlength) : (tensor<7xf32>, tensor<1xi32>) -> tensor<5xcomplex> + func.return %0 : tensor<5xcomplex> +} + +// ----- + +// CHECK-LABEL: func @rfft_1D_sliced +func.func @rfft_1D_sliced(%arg0: tensor<2x9xf32>) -> tensor<2x5xcomplex> { + %fftlength = "tf.Const"() {value = dense<[8]> : tensor<1xi32>} : () -> (tensor<1xi32>) + // CHECK: %[[SLICED:.*]] = stablehlo.slice %arg0 [0:2, 0:8] + // CHECK: stablehlo.fft %[[SLICED]], type = RFFT, length = [8] + %0 = "tf.RFFT"(%arg0, %fftlength) : (tensor<2x9xf32>, tensor<1xi32>) -> tensor<2x5xcomplex> + func.return %0 : tensor<2x5xcomplex> +} + +// ----- + +// CHECK-LABEL: func @irfft_1D +func.func @irfft_1D(%arg0: tensor<8xcomplex>) -> tensor<8xf32> { + %fftlength = "tf.Const"() {value = dense<[8]> : tensor<1xi32>} : () -> (tensor<1xi32>) + // CHECK: %[[SLICED:.*]] = stablehlo.slice %arg0 [0:5] + // CHECK: stablehlo.fft %[[SLICED]], type = IRFFT, length = [8] + %0 = "tf.IRFFT"(%arg0, %fftlength) : (tensor<8xcomplex>, tensor<1xi32>) -> tensor<8xf32> + func.return %0 : tensor<8xf32> +} + +// ----- + +// CHECK-LABEL: fft_1D_dynamic +func.func @fft_1D_dynamic(%arg0: tensor>) -> tensor<8xcomplex> { + // CHECK: "tf.FFT" + %0 = "tf.FFT"(%arg0) : (tensor>) -> tensor<8xcomplex> + func.return %0 : tensor<8xcomplex> +} + +// ----- + +// CHECK-LABEL: rfft_1D_dynamic +func.func @rfft_1D_dynamic(%arg0: tensor) -> tensor<8xcomplex> { + %fftlength = "tf.Const"() {value = dense<[8]> : tensor<1xi32>} : () -> (tensor<1xi32>) + // CHECK: "tf.RFFT" + %0 = "tf.RFFT"(%arg0, %fftlength) : (tensor, tensor<1xi32>) -> tensor<8xcomplex> + func.return %0 : tensor<8xcomplex> +} + +//===----------------------------------------------------------------------===// +// Shape op legalization. +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: func @shape_1D +func.func @shape_1D(%arg0: tensor) -> tensor<1xi32> { + // CHECK: [[SHAPE:%.+]] = shape.shape_of %arg0 + // CHECK: [[TENSOR:%.+]] = arith.index_cast [[SHAPE]] : tensor<1xindex> to tensor<1xi32> + %0 = "tf.Shape"(%arg0) : (tensor) -> tensor<1xi32> + + // CHECK: return [[TENSOR]] + func.return %0 : tensor<1xi32> +} + +// ----- + +// CHECK-LABEL: func @shape_2D +func.func @shape_2D(%arg0: tensor) -> tensor<2xi32> { + // CHECK: [[SHAPE:%.+]] = shape.shape_of %arg0 + // CHECK: [[TENSOR:%.+]] = arith.index_cast [[SHAPE]] : tensor<2xindex> to tensor<2xi32> + %0 = "tf.Shape"(%arg0) : (tensor) -> tensor<2xi32> + + // CHECK: return [[TENSOR]] + func.return %0 : tensor<2xi32> +} + +// ----- + +// CHECK-LABEL: func @shape_rankless +func.func @shape_rankless(%arg0: tensor<*xf32>) -> tensor { + // CHECK: [[SHAPE:%.+]] = shape.shape_of %arg0 + // CHECK: [[TENSOR:%.+]] = arith.index_cast [[SHAPE]] : tensor to tensor + %0 = "tf.Shape"(%arg0) : (tensor<*xf32>) -> tensor + + // CHECK: return [[TENSOR]] + func.return %0 : tensor +} + +//===----------------------------------------------------------------------===// +// Transpose op legalization. +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: @transpose_noop +func.func @transpose_noop(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> { + %permutation = "tf.Const"() {value = dense<[0, 1]> : tensor<2xi64>} : () -> (tensor<2xi64>) + // CHECK: return %arg0 + %0 = "tf.Transpose"(%arg0, %permutation) : (tensor<2x3xf32>, tensor<2xi64>) -> tensor<2x3xf32> + func.return %0 : tensor<2x3xf32> +} + +// ----- + +// CHECK-LABEL: @transpose_2d +func.func @transpose_2d(%arg0: tensor<2x3xf32>) -> tensor<3x2xf32> { + %permutation = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> (tensor<2xi64>) + // CHECK: stablehlo.transpose + %0 = "tf.Transpose"(%arg0, %permutation) : (tensor<2x3xf32>, tensor<2xi64>) -> tensor<3x2xf32> + func.return %0 : tensor<3x2xf32> +} + +// ----- + +// CHECK-LABEL: @transpose_3d_int32 +func.func @transpose_3d_int32(%arg0: tensor<1x2x3xf32>) -> tensor<3x2x1xf32> { + %permutation = "tf.Const"() {value = dense<[2, 1, 0]> : tensor<3xi32>} : () -> (tensor<3xi32>) + // CHECK: stablehlo.transpose + %0 = "tf.Transpose"(%arg0, %permutation) : (tensor<1x2x3xf32>, tensor<3xi32>) -> tensor<3x2x1xf32> + func.return %0 : tensor<3x2x1xf32> +} + +// ----- + +// CHECK-LABEL: @transpose_3d +func.func @transpose_3d(%arg0: tensor<1x2x3xf32>) -> tensor<3x2x1xf32> { + %permutation = "tf.Const"() {value = dense<[2, 1, 0]> : tensor<3xi64>} : () -> (tensor<3xi64>) + // CHECK: stablehlo.transpose + %0 = "tf.Transpose"(%arg0, %permutation) : (tensor<1x2x3xf32>, tensor<3xi64>) -> tensor<3x2x1xf32> + func.return %0 : tensor<3x2x1xf32> +} + +// ----- + +// CHECK-LABEL: @transpose_dynamic_2d +func.func @transpose_dynamic_2d(%arg0: tensor) -> tensor<4x?xf32> { + %permutation = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> (tensor<2xi64>) + // CHECK: stablehlo.transpose + %0 = "tf.Transpose"(%arg0, %permutation) : (tensor, tensor<2xi64>) -> tensor<4x?xf32> + func.return %0 : tensor<4x?xf32> +} + +//===----------------------------------------------------------------------===// +// Unary op legalizations. +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: @abs +func.func @abs(%arg0: tensor<2xf32>) -> tensor<2xf32> { + // CHECK: stablehlo.abs %arg0 : tensor<2xf32> + %0 = "tf.Abs"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + func.return %0 : tensor<2xf32> +} + +// ----- + +// CHECK-LABEL: func @abs_dynamic +func.func @abs_dynamic(%arg0: tensor) -> tensor { + // CHECK: stablehlo.abs %arg0 : tensor + %0 = "tf.Abs"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// ----- + +// CHECK-LABEL: @acos +func.func @acos(%arg0: tensor<2xf32>) -> tensor<2xf32> { + // CHECK: %[[TEMP_0:.*]] = stablehlo.constant dense<1.000000e+00> : tensor<2xf32> + // CHECK: %[[TEMP_1:.*]] = stablehlo.subtract %[[TEMP_0]], %arg0 : tensor<2xf32> + // CHECK: %[[TEMP_2:.*]] = stablehlo.add %arg0, %[[TEMP_0]] : tensor<2xf32> + // CHECK: %[[TEMP_3:.*]] = stablehlo.multiply %[[TEMP_1]], %[[TEMP_2]] : tensor<2xf32> + // CHECK: %[[TEMP_4:.*]] = stablehlo.sqrt %[[TEMP_3]] : tensor<2xf32> + // CHECK: %[[TEMP_5:.*]] = stablehlo.atan2 %[[TEMP_4]], %arg0 : tensor<2xf32> + // CHECK: return %[[TEMP_5]] : tensor<2xf32> + %0 = "tf.Acos"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + func.return %0 : tensor<2xf32> +} + +// ----- + +// CHECK-LABEL: @acos_complex +func.func @acos_complex(%arg0: tensor<2xcomplex>) -> tensor<2xcomplex> { +// CHECK-NEXT: %[[TEMP_1:.*]] = stablehlo.constant dense<4.33680869E-19> : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_2:.*]] = stablehlo.constant dense<0.693147182> : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_3:.*]] = stablehlo.constant dense<2.30584283E+20> : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_4:.*]] = stablehlo.constant dense<2.30584274E+12> : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_5:.*]] = stablehlo.constant dense<2.30584285E+30> : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_6:.*]] = stablehlo.constant dense<1.41421354> : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_7:.*]] = stablehlo.constant dense<2.30584287E+18> : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_8:.*]] = stablehlo.constant dense<1.500000e+00> : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_9:.*]] = stablehlo.constant dense<0x7F800000> : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_10:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_11:.*]] = stablehlo.constant dense<2.000000e+00> : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_12:.*]] = stablehlo.constant dense<5.000000e-01> : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_13:.*]] = stablehlo.constant dense<1.000000e+00> : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_14:.*]] = stablehlo.real %arg0 : (tensor<2xcomplex>) -> tensor<2xf32> +// CHECK-NEXT: %[[TEMP_15:.*]] = stablehlo.abs %[[TEMP_14]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_16:.*]] = stablehlo.imag %arg0 : (tensor<2xcomplex>) -> tensor<2xf32> +// CHECK-NEXT: %[[TEMP_17:.*]] = stablehlo.abs %[[TEMP_16]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_18:.*]] = stablehlo.maximum %[[TEMP_15]], %[[TEMP_17]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_19:.*]] = stablehlo.compare GE, %[[TEMP_18]], %[[TEMP_7]] : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1> +// CHECK-NEXT: %[[TEMP_20:.*]] = stablehlo.compare LE, %[[TEMP_15]], %[[TEMP_13]] : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1> +// CHECK-NEXT: %[[TEMP_21:.*]] = stablehlo.add %[[TEMP_15]], %[[TEMP_13]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_22:.*]] = stablehlo.abs %[[TEMP_21]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_23:.*]] = stablehlo.maximum %[[TEMP_22]], %[[TEMP_17]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_24:.*]] = stablehlo.minimum %[[TEMP_22]], %[[TEMP_17]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_25:.*]] = stablehlo.compare EQ, %[[TEMP_23]], %[[TEMP_24]] : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1> +// CHECK-NEXT: %[[TEMP_26:.*]] = stablehlo.multiply %[[TEMP_23]], %[[TEMP_6]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_27:.*]] = stablehlo.divide %[[TEMP_24]], %[[TEMP_23]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_28:.*]] = stablehlo.multiply %[[TEMP_27]], %[[TEMP_27]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_29:.*]] = stablehlo.add %[[TEMP_28]], %[[TEMP_13]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_30:.*]] = stablehlo.sqrt %[[TEMP_29]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_31:.*]] = stablehlo.compare EQ, %[[TEMP_30]], %[[TEMP_13]] : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1> +// CHECK-NEXT: %[[TEMP_32:.*]] = stablehlo.compare GT, %[[TEMP_28]], %[[TEMP_10]] : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1> +// CHECK-NEXT: %[[TEMP_33:.*]] = stablehlo.and %[[TEMP_31]], %[[TEMP_32]] : tensor<2xi1> +// CHECK-NEXT: %[[TEMP_34:.*]] = stablehlo.multiply %[[TEMP_23]], %[[TEMP_28]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_35:.*]] = stablehlo.divide %[[TEMP_34]], %[[TEMP_11]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_36:.*]] = stablehlo.add %[[TEMP_23]], %[[TEMP_35]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_37:.*]] = stablehlo.multiply %[[TEMP_23]], %[[TEMP_30]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_38:.*]] = stablehlo.select %[[TEMP_33]], %[[TEMP_36]], %[[TEMP_37]] : tensor<2xi1>, tensor<2xf32> +// CHECK-NEXT: %[[TEMP_39:.*]] = stablehlo.select %[[TEMP_25]], %[[TEMP_26]], %[[TEMP_38]] : tensor<2xi1>, tensor<2xf32> +// CHECK-NEXT: %[[TEMP_40:.*]] = stablehlo.subtract %[[TEMP_15]], %[[TEMP_13]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_41:.*]] = stablehlo.abs %[[TEMP_40]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_42:.*]] = stablehlo.maximum %[[TEMP_41]], %[[TEMP_17]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_43:.*]] = stablehlo.minimum %[[TEMP_41]], %[[TEMP_17]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_44:.*]] = stablehlo.compare EQ, %[[TEMP_42]], %[[TEMP_43]] : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1> +// CHECK-NEXT: %[[TEMP_45:.*]] = stablehlo.multiply %[[TEMP_42]], %[[TEMP_6]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_46:.*]] = stablehlo.divide %[[TEMP_43]], %[[TEMP_42]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_47:.*]] = stablehlo.multiply %[[TEMP_46]], %[[TEMP_46]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_48:.*]] = stablehlo.add %[[TEMP_47]], %[[TEMP_13]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_49:.*]] = stablehlo.sqrt %[[TEMP_48]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_50:.*]] = stablehlo.compare EQ, %[[TEMP_49]], %[[TEMP_13]] : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1> +// CHECK-NEXT: %[[TEMP_51:.*]] = stablehlo.compare GT, %[[TEMP_47]], %[[TEMP_10]] : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1> +// CHECK-NEXT: %[[TEMP_52:.*]] = stablehlo.and %[[TEMP_50]], %[[TEMP_51]] : tensor<2xi1> +// CHECK-NEXT: %[[TEMP_53:.*]] = stablehlo.multiply %[[TEMP_42]], %[[TEMP_47]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_54:.*]] = stablehlo.divide %[[TEMP_53]], %[[TEMP_11]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_55:.*]] = stablehlo.add %[[TEMP_42]], %[[TEMP_54]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_56:.*]] = stablehlo.multiply %[[TEMP_42]], %[[TEMP_49]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_57:.*]] = stablehlo.select %[[TEMP_52]], %[[TEMP_55]], %[[TEMP_56]] : tensor<2xi1>, tensor<2xf32> +// CHECK-NEXT: %[[TEMP_58:.*]] = stablehlo.select %[[TEMP_44]], %[[TEMP_45]], %[[TEMP_57]] : tensor<2xi1>, tensor<2xf32> +// CHECK-NEXT: %[[TEMP_59:.*]] = stablehlo.add %[[TEMP_39]], %[[TEMP_58]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_60:.*]] = stablehlo.multiply %[[TEMP_59]], %[[TEMP_12]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_61:.*]] = stablehlo.add %[[TEMP_60]], %[[TEMP_15]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_62:.*]] = stablehlo.multiply %[[TEMP_61]], %[[TEMP_12]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_63:.*]] = stablehlo.multiply %[[TEMP_17]], %[[TEMP_17]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_64:.*]] = stablehlo.add %[[TEMP_39]], %[[TEMP_21]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_65:.*]] = stablehlo.divide %[[TEMP_63]], %[[TEMP_64]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_66:.*]] = stablehlo.subtract %[[TEMP_58]], %[[TEMP_40]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_67:.*]] = stablehlo.add %[[TEMP_65]], %[[TEMP_66]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_68:.*]] = stablehlo.multiply %[[TEMP_62]], %[[TEMP_67]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_69:.*]] = stablehlo.sqrt %[[TEMP_68]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_70:.*]] = stablehlo.divide %[[TEMP_62]], %[[TEMP_64]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_71:.*]] = stablehlo.add %[[TEMP_58]], %[[TEMP_40]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_72:.*]] = stablehlo.divide %[[TEMP_62]], %[[TEMP_71]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_73:.*]] = stablehlo.add %[[TEMP_70]], %[[TEMP_72]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_74:.*]] = stablehlo.sqrt %[[TEMP_73]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_75:.*]] = stablehlo.multiply %[[TEMP_17]], %[[TEMP_74]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_76:.*]] = stablehlo.select %[[TEMP_20]], %[[TEMP_69]], %[[TEMP_75]] : tensor<2xi1>, tensor<2xf32> +// CHECK-NEXT: %[[TEMP_77:.*]] = stablehlo.select %[[TEMP_19]], %[[TEMP_17]], %[[TEMP_76]] : tensor<2xi1>, tensor<2xf32> +// CHECK-NEXT: %[[TEMP_78:.*]] = stablehlo.compare LT, %[[TEMP_15]], %[[TEMP_5]] : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1> +// CHECK-NEXT: %[[TEMP_79:.*]] = stablehlo.select %[[TEMP_78]], %[[TEMP_4]], %[[TEMP_3]] : tensor<2xi1>, tensor<2xf32> +// CHECK-NEXT: %[[TEMP_80:.*]] = stablehlo.compare GE, %[[TEMP_17]], %[[TEMP_79]] : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1> +// CHECK-NEXT: %[[TEMP_81:.*]] = stablehlo.select %[[TEMP_80]], %[[TEMP_17]], %[[TEMP_15]] : tensor<2xi1>, tensor<2xf32> +// CHECK-NEXT: %[[TEMP_82:.*]] = stablehlo.select %[[TEMP_80]], %[[TEMP_79]], %[[TEMP_7]] : tensor<2xi1>, tensor<2xf32> +// CHECK-NEXT: %[[TEMP_83:.*]] = stablehlo.compare GE, %[[TEMP_81]], %[[TEMP_82]] : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1> +// CHECK-NEXT: %[[TEMP_84:.*]] = stablehlo.log %[[TEMP_81]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_85:.*]] = stablehlo.add %[[TEMP_84]], %[[TEMP_2]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_86:.*]] = stablehlo.compare EQ, %[[TEMP_17]], %[[TEMP_9]] : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1> +// CHECK-NEXT: %[[TEMP_87:.*]] = stablehlo.not %[[TEMP_86]] : tensor<2xi1> +// CHECK-NEXT: %[[TEMP_88:.*]] = stablehlo.and %[[TEMP_80]], %[[TEMP_87]] : tensor<2xi1> +// CHECK-NEXT: %[[TEMP_89:.*]] = stablehlo.divide %[[TEMP_15]], %[[TEMP_17]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_90:.*]] = stablehlo.select %[[TEMP_88]], %[[TEMP_89]], %[[TEMP_10]] : tensor<2xi1>, tensor<2xf32> +// CHECK-NEXT: %[[TEMP_91:.*]] = stablehlo.multiply %[[TEMP_90]], %[[TEMP_90]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_92:.*]] = stablehlo.log_plus_one %[[TEMP_91]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_93:.*]] = stablehlo.multiply %[[TEMP_92]], %[[TEMP_12]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_94:.*]] = stablehlo.add %[[TEMP_85]], %[[TEMP_93]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_95:.*]] = stablehlo.compare LT, %[[TEMP_17]], %[[TEMP_1]] : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1> +// CHECK-NEXT: %[[TEMP_96:.*]] = stablehlo.compare LT, %[[TEMP_15]], %[[TEMP_13]] : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1> +// CHECK-NEXT: %[[TEMP_97:.*]] = stablehlo.and %[[TEMP_95]], %[[TEMP_96]] : tensor<2xi1> +// CHECK-NEXT: %[[TEMP_98:.*]] = stablehlo.multiply %[[TEMP_21]], %[[TEMP_40]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_99:.*]] = stablehlo.add %[[TEMP_60]], %[[TEMP_13]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_100:.*]] = stablehlo.divide %[[TEMP_98]], %[[TEMP_99]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_101:.*]] = stablehlo.negate %[[TEMP_100]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_102:.*]] = stablehlo.compare GE, %[[TEMP_15]], %[[TEMP_13]] : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1> +// CHECK-NEXT: %[[TEMP_103:.*]] = stablehlo.multiply %[[TEMP_63]], %[[TEMP_12]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_104:.*]] = stablehlo.divide %[[TEMP_103]], %[[TEMP_64]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_105:.*]] = stablehlo.multiply %[[TEMP_71]], %[[TEMP_12]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_106:.*]] = stablehlo.add %[[TEMP_104]], %[[TEMP_105]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_107:.*]] = stablehlo.compare LE, %[[TEMP_60]], %[[TEMP_8]] : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1> +// CHECK-NEXT: %[[TEMP_108:.*]] = stablehlo.divide %[[TEMP_103]], %[[TEMP_66]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_109:.*]] = stablehlo.add %[[TEMP_104]], %[[TEMP_108]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_110:.*]] = stablehlo.subtract %[[TEMP_60]], %[[TEMP_13]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_111:.*]] = stablehlo.select %[[TEMP_107]], %[[TEMP_109]], %[[TEMP_110]] : tensor<2xi1>, tensor<2xf32> +// CHECK-NEXT: %[[TEMP_112:.*]] = stablehlo.select %[[TEMP_102]], %[[TEMP_106]], %[[TEMP_111]] : tensor<2xi1>, tensor<2xf32> +// CHECK-NEXT: %[[TEMP_113:.*]] = stablehlo.select %[[TEMP_97]], %[[TEMP_101]], %[[TEMP_112]] : tensor<2xi1>, tensor<2xf32> +// CHECK-NEXT: %[[TEMP_114:.*]] = stablehlo.multiply %[[TEMP_113]], %[[TEMP_99]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_115:.*]] = stablehlo.sqrt %[[TEMP_114]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_116:.*]] = stablehlo.divide %[[TEMP_17]], %[[TEMP_115]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_117:.*]] = stablehlo.add %[[TEMP_113]], %[[TEMP_115]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_118:.*]] = stablehlo.log_plus_one %[[TEMP_117]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_119:.*]] = stablehlo.select %[[TEMP_97]], %[[TEMP_116]], %[[TEMP_118]] : tensor<2xi1>, tensor<2xf32> +// CHECK-NEXT: %[[TEMP_120:.*]] = stablehlo.select %[[TEMP_83]], %[[TEMP_94]], %[[TEMP_119]] : tensor<2xi1>, tensor<2xf32> +// CHECK-NEXT: %[[TEMP_121:.*]] = stablehlo.real %arg0 : (tensor<2xcomplex>) -> tensor<2xf32> +// CHECK-NEXT: %[[TEMP_122:.*]] = stablehlo.atan2 %[[TEMP_77]], %[[TEMP_121]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_123:.*]] = stablehlo.imag %arg0 : (tensor<2xcomplex>) -> tensor<2xf32> +// CHECK-NEXT: %[[TEMP_124:.*]] = stablehlo.compare LT, %[[TEMP_123]], %[[TEMP_10]] : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1> +// CHECK-NEXT: %[[TEMP_125:.*]] = stablehlo.negate %[[TEMP_120]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_126:.*]] = stablehlo.select %[[TEMP_124]], %[[TEMP_120]], %[[TEMP_125]] : tensor<2xi1>, tensor<2xf32> +// CHECK-NEXT: %[[TEMP_127:.*]] = stablehlo.complex %[[TEMP_122]], %[[TEMP_126]] : tensor<2xcomplex> +// CHECK-NEXT: return %[[TEMP_127]] : tensor<2xcomplex> + + %0 = "tf.Acos"(%arg0) : (tensor<2xcomplex>) -> tensor<2xcomplex> + func.return %0 : tensor<2xcomplex> +} + +// ----- + +// CHECK-LABEL: @acos_dynamic +func.func @acos_dynamic(%arg0: tensor<*xf32>) -> tensor<*xf32> { + // CHECK: "tf.Acos" + %0 = "tf.Acos"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> + func.return %0 : tensor<*xf32> +} + +// ----- + +// CHECK-LABEL: func @cast_dynamic_i2f +func.func @cast_dynamic_i2f(%arg0: tensor) -> tensor { + // CHECK: stablehlo.convert %arg0 : (tensor) -> tensor + %0 = "tf.Cast"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// ----- + +// CHECK-LABEL: func @cast_i2f +func.func @cast_i2f(%arg0: tensor<2xi32>) -> tensor<2xf32> { + // CHECK: stablehlo.convert %arg0 : (tensor<2xi32>) -> tensor<2xf32> + %0 = "tf.Cast"(%arg0) : (tensor<2xi32>) -> tensor<2xf32> + func.return %0 : tensor<2xf32> +} + +// ----- + +// CHECK-LABEL: func @cast_c2f +func.func @cast_c2f(%arg0: tensor<2xcomplex>) -> tensor<2xf32> { + // CHECK: stablehlo.convert %arg0 : (tensor<2xcomplex>) -> tensor<2xf32> + %0 = "tf.Cast"(%arg0) : (tensor<2xcomplex>) -> tensor<2xf32> + func.return %0 : tensor<2xf32> +} + +// ----- + +// CHECK-LABEL: @ceil +func.func @ceil(%arg0: tensor<2xf32>) -> tensor<2xf32> { + // CHECK: stablehlo.ceil %arg0 : tensor<2xf32> + %0 = "tf.Ceil"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + func.return %0 : tensor<2xf32> +} + +// ----- + +// CHECK-LABEL: func @ceil_dynamic +func.func @ceil_dynamic(%arg0: tensor) -> tensor { + // CHECK: stablehlo.ceil %arg0 : tensor + %0 = "tf.Ceil"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// ----- + +// CHECK-LABEL: @complex_abs +func.func @complex_abs(%arg0: tensor<2xcomplex>) -> tensor<2xf32> { + // CHECK: stablehlo.abs %arg0 : (tensor<2xcomplex>) -> tensor<2xf32> + %0 = "tf.ComplexAbs"(%arg0) : (tensor<2xcomplex>) -> tensor<2xf32> + func.return %0 : tensor<2xf32> +} + +// ----- + +// CHECK-LABEL: @cos +func.func @cos(%arg0: tensor<2xf32>) -> tensor<2xf32> { + // CHECK: stablehlo.cosine %arg0 : tensor<2xf32> + %0 = "tf.Cos"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + func.return %0 : tensor<2xf32> +} + +// ----- + +// CHECK-LABEL: @tan +func.func @tan(%arg0: tensor<2xf32>) -> tensor<2xf32> { + // CHECK: stablehlo.tan %arg0 : tensor<2xf32> + %0 = "tf.Tan"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + func.return %0 : tensor<2xf32> +} + +// ----- + +// CHECK-LABEL: func @cos_dynamic +func.func @cos_dynamic(%arg0: tensor) -> tensor { + // CHECK: stablehlo.cosine %arg0 : tensor + %0 = "tf.Cos"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// ----- + +// CHECK-LABEL: @exp +func.func @exp(%arg0: tensor<2xf32>) -> tensor<2xf32> { + // CHECK: stablehlo.exponential %arg0 : tensor<2xf32> + %0 = "tf.Exp"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + func.return %0 : tensor<2xf32> +} + +// ----- + +// CHECK-LABEL: @expm1 +func.func @expm1(%arg0: tensor<2xf32>) -> tensor<2xf32> { + // CHECK: stablehlo.exponential_minus_one %arg0 : tensor<2xf32> + %0 = "tf.Expm1"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + func.return %0 : tensor<2xf32> +} + +// ----- + +// CHECK-LABEL: func @exp_dynamic +func.func @exp_dynamic(%arg0: tensor) -> tensor { + // CHECK: stablehlo.exponential %arg0 : tensor + %0 = "tf.Exp"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// ----- + +// CHECK-LABEL: @floor +func.func @floor(%arg0: tensor<2xf32>) -> tensor<2xf32> { + // CHECK: stablehlo.floor %arg0 : tensor<2xf32> + %0 = "tf.Floor"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + func.return %0 : tensor<2xf32> +} + +// ----- + +// CHECK-LABEL: func @floor_dynamic +func.func @floor_dynamic(%arg0: tensor) -> tensor { + // CHECK: stablehlo.floor %arg0 : tensor + %0 = "tf.Floor"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// ----- + +// CHECK-LABEL: @is_finite +func.func @is_finite(%arg0: tensor<2xf32>) -> tensor<2xi1> { + // CHECK: stablehlo.is_finite %arg0 : (tensor<2xf32>) -> tensor<2xi1> + %0 = "tf.IsFinite"(%arg0) : (tensor<2xf32>) -> tensor<2xi1> + func.return %0 : tensor<2xi1> +} + +// ----- + +// CHECK-LABEL: func @is_finite_dynamic +func.func @is_finite_dynamic(%arg0: tensor) -> tensor { + // CHECK: stablehlo.is_finite %arg0 : (tensor) -> tensor + %0 = "tf.IsFinite"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// ----- + +// CHECK-LABEL: @log +func.func @log(%arg0: tensor<2xf32>) -> tensor<2xf32> { + // CHECK: stablehlo.log %arg0 : tensor<2xf32> + %0 = "tf.Log"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + func.return %0 : tensor<2xf32> +} + +// ----- + +// CHECK-LABEL: func @log_dynamic +func.func @log_dynamic(%arg0: tensor) -> tensor { + // CHECK: stablehlo.log %arg0 : tensor + %0 = "tf.Log"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// ----- + +// CHECK-LABEL: @log1p +func.func @log1p(%arg0: tensor<2xf32>) -> tensor<2xf32> { + // CHECK: stablehlo.log_plus_one %arg0 : tensor<2xf32> + %0 = "tf.Log1p"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + func.return %0 : tensor<2xf32> +} + +// ----- + +// CHECK-LABEL: func @log1p_dynamic +func.func @log1p_dynamic(%arg0: tensor) -> tensor { + // CHECK: stablehlo.log_plus_one %arg0 : tensor + %0 = "tf.Log1p"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// ----- + +// CHECK-LABEL: @neg +func.func @neg(%arg0: tensor<2xf32>) -> tensor<2xf32> { + // CHECK: stablehlo.negate %arg0 : tensor<2xf32> + %0 = "tf.Neg"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + func.return %0 : tensor<2xf32> +} + +// ----- + +// CHECK-LABEL: func @neg_dynamic +func.func @neg_dynamic(%arg0: tensor) -> tensor { + // CHECK: stablehlo.negate %arg0 : tensor + %0 = "tf.Neg"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// ----- + +// CHECK-LABEL: @sigmoid +func.func @sigmoid(%arg0: tensor<2xf32>) -> tensor<2xf32> { + // CHECK: stablehlo.logistic + %0 = "tf.Sigmoid"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + func.return %0 : tensor<2xf32> +} + +// ----- + +// CHECK-LABEL: @sigmoid_complex +func.func @sigmoid_complex(%arg0: tensor<2xcomplex>) -> tensor<2xcomplex> { + // CHECK: stablehlo.logistic + %0 = "tf.Sigmoid"(%arg0) : (tensor<2xcomplex>) -> tensor<2xcomplex> + func.return %0 : tensor<2xcomplex> +} diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf.cc new file mode 100644 index 00000000000000..0e7f1744d5fb63 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf.cc @@ -0,0 +1,6911 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file implements logic for lowering TensorFlow dialect to XLA dialect. +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/Sequence.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project +#include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project +#include "mlir/Dialect/Traits.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Diagnostics.h" // from @llvm-project +#include "mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Matchers.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "stablehlo/dialect/ChloOps.h" // from @stablehlo +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf_passes.h" +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/utils.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/dynamic_shape_utils.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h" +#include "xla/client/lib/conv_grad_size_util.h" +#include "xla/client/padding.h" +#include "xla/client/sharding_builder.h" +#include "xla/hlo/translate/hlo_to_mhlo/attribute_importer.h" +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" +#include "xla/mlir_hlo/utils/convert_op_folder.h" +#include "xla/mlir_hlo/utils/hlo_utils.h" +#include "xla/xla_data.pb.h" +#include "tensorflow/core/framework/kernel_shape_util.h" +#include "tensorflow/core/framework/rng_alg.h" +#include "tensorflow/core/kernels/conv_grad_shape_utils.h" +#include "tensorflow/core/util/padding.h" +#include "tensorflow/core/util/tensor_format.h" +#include "tsl/platform/bfloat16.h" +#include "tsl/platform/status.h" +#include "tsl/platform/tensor_float_32_utils.h" + +namespace mlir { +// Keep this in the mlir namespace to allow the use of the mhlo ops. +namespace mhlo { +namespace { + +// The utils are copied into the odml namespace to avoid duplicate names and +// they are imported here to avoid having to change the code below. +using ::mlir::odml::BuildReduceBody; +using ::mlir::odml::GetI64ElementsAttr; +using ::mlir::odml::GetScalarConstOfType; +using ::mlir::odml::GetScalarNegZeroOfType; + +constexpr char kShardingAttr[] = "mhlo.sharding"; + +/// Returns the feature dimension for the given format and input type. +static size_t GetFeatureDimension(tensorflow::TensorFormat format, + RankedTensorType input_ty) { + return GetTensorFeatureDimIndex(input_ty.getRank(), format); +} + +// Gets all integer values from the given attribute and push them to `values`. +void GetI64ArrayAttrValues(Attribute attr, SmallVectorImpl *values) { + auto array_attr = mlir::cast(attr); + values->reserve(array_attr.getValue().size()); + for (Attribute val : array_attr.getValue()) + values->push_back(mlir::cast(val).getValue().getSExtValue()); +} + +// Returns 1D 32-bit dense elements attribute with the given values. +static DenseIntElementsAttr GetI32ElementsAttr(ArrayRef values, + Builder *builder) { + RankedTensorType ty = tensorflow::GetTypeFromTFTensorShape( + {static_cast(values.size())}, builder->getIntegerType(32)); + return DenseIntElementsAttr::get(ty, values); +} + +// Returns a 1-d i64 elements attribute populated with numbers from start to +// end, excluding. +static DenseIntElementsAttr GetI64ElementsAttrForSeq(int start, int end, + Builder *builder) { + int size = end - start; + + SmallVector vals; + vals.resize(size); + std::iota(vals.begin(), vals.end(), start); + + TensorType ty = + tensorflow::GetTypeFromTFTensorShape({size}, builder->getIntegerType(64)); + return DenseIntElementsAttr::get(ty, vals); +} + +// Returns a 1-d i64 elements attribute populated with `val` repeated `size` +// times. +static DenseIntElementsAttr GetI64ElementsAttrForValue(int size, int64_t val, + Builder *builder) { + TensorType ty = + tensorflow::GetTypeFromTFTensorShape({size}, builder->getIntegerType(64)); + return DenseIntElementsAttr::get(ty, val); +} + +// Returns the corresponding type that should be used for performing sum +// accumulation over the given input type. +Type GetSumAccumulationType(Type input_type) { + MLIRContext *ctx = input_type.getContext(); + if (input_type.isBF16() || input_type.isF16()) return FloatType::getF32(ctx); + if (input_type.isSignlessInteger(8) || input_type.isSignlessInteger(16)) + return IntegerType::get(ctx, 32); + return input_type; +} + +// Returns axis in HLO format from TF elements attr with exactly one element or +// is an IntegerAttr, containing axis in the TensorFlow format. TensorFlow +// format supports negative indexing unlike HLO. +static IntegerAttr GetHLOAxisFromTFAxis(Attribute attr, int64_t rank, + Builder *b) { + IntegerAttr intAttr = mlir::dyn_cast_or_null(attr); + if (auto elementAttr = mlir::dyn_cast_or_null(attr)) { + SmallVector index(elementAttr.getShapedType().getRank(), 0); + intAttr = elementAttr.getValues()[index]; + } + + assert(intAttr && "Invalid attribute passed to GetHLOAxisFromTFAxis"); + + int64_t axis = intAttr.getInt(); + if (axis < 0) { + axis += rank; + } + return b->getI64IntegerAttr(axis); +} + +// Returns a PrecisionConfig as an array attribute based on whether TF32 +// execution is enabled +static ArrayAttr GetPrecisionConfig(Builder *builder) { + mlir::mhlo::Precision precision = tsl::tensor_float_32_execution_enabled() + ? mhlo::Precision::DEFAULT + : mlir::mhlo::Precision::HIGHEST; + llvm::SmallVector attr_vec; + const int num_inputs = 2; + for (int i = 0; i < num_inputs; i++) { + attr_vec.push_back( + mlir::mhlo::PrecisionAttr::get(builder->getContext(), precision)); + } + return builder->getArrayAttr(attr_vec); +} + +// If `value` is an IntegerAttr, returns the integer value for the HLO axis +// corresponding to the tensorflow axis. In particular, the tensorflow axis can +// be negative, in which case, the corresponding HLO axis is +// (axis + rank-of-the-tensor). +static std::optional GetIntegerHLOAxisFromTFAxis(Value value, + int64_t rank) { + DenseIntElementsAttr attrs; + if (!matchPattern(value, m_Constant(&attrs)) || + attrs.getType().getRank() != 0) { + return std::nullopt; + } + int64_t axis = attrs.getValues()[0].getInt(); + return axis < 0 ? axis + rank : axis; +} + +/// Returns a `ConvertOp` that casts the elements to a i64 type while retaining +/// the shape of the input value. +static ConvertOp CastValueToI64(Location loc, Value value, + PatternRewriter *rewriter) { + return rewriter->create(loc, value, rewriter->getIntegerType(64)); +} + +// Creates an unpack op along the 0th dimension of the tensor. The `value` input +// must be a ranked tensor. +static TF::UnpackOp UnpackTensorAlongZeroDim(Location loc, Value value, + PatternRewriter *rewriter) { + auto indices_type = mlir::cast(value.getType()); + int num_outputs = indices_type.getShape().front(); + SmallVector unpacked_indices_type( + num_outputs, + tensorflow::GetTypeFromTFTensorShape({}, indices_type.getElementType())); + auto unpacked_indices = rewriter->create( + loc, unpacked_indices_type, value, + IntegerAttr::get(rewriter->getIntegerType(64), 0)); + return unpacked_indices; +} + +// Returns size of dimension at the specified index, if ranked tensor. +// Otherwise, returns -1. +// +// Aborts if the type is ranked but doesn't have the dimension. +int64_t GetDimSize(Type ty, int64_t index) { + RankedTensorType ranked_ty = mlir::dyn_cast(ty); + if (!ranked_ty) return -1; + + return ranked_ty.getDimSize(index); +} + +template +tensorflow::TensorShape ToTensorShape(llvm::ArrayRef sizes) { + return tensorflow::TensorShape( + llvm::SmallVector(sizes.begin(), sizes.end())); +} + +template +tensorflow::TensorShape ToTensorShape( + llvm::iterator_range> sizes) { + return tensorflow::TensorShape( + llvm::SmallVector(sizes.begin(), sizes.end())); +} + +// Returns a limit scalar const op for the given type. +// Requires FloatType or IntegerType +static ConstantOp GetScalarLimitConstOfType(Type ty, Location loc, + hlo::ScalarLimit limit, + OpBuilder *builder) { + return builder->create(loc, hlo::getScalarLimitOfType(ty, limit)); +} + +// Deprecated: This is maintained to aid in porting old code that is not yet +// dynamic shape aware and uses broadcasting modes that CHLO does not support. +// Gets the resulting type from a broadcast between two types for statically +// shaped types. This is to be used for legacy lowerings that both use non +// left-padded broadcasting and static shapes. Its use should not be permitted +// in new code. +// May return nullptr on invalid static broadcast dimensions. +// ABSL_DEPRECATED() +static RankedTensorType GetStaticBroadcastType( + RankedTensorType x, RankedTensorType y, + DenseIntElementsAttr broadcast_dimensions_attr) { + auto element_type = x.getElementType(); + auto shape_x = x.getShape(); + auto shape_y = y.getShape(); + + if (shape_x.size() == shape_y.size()) { + llvm::SmallVector out_shape(shape_x.size()); + for (int i = 0; i < shape_x.size(); i++) { + auto x_val = shape_x[i]; + auto y_val = shape_y[i]; + out_shape[i] = std::max(x_val, y_val); + } + return tensorflow::GetTypeFromTFTensorShape(out_shape, element_type); + } + + auto shape_large = shape_x.size() > shape_y.size() ? shape_x : shape_y; + auto shape_small = shape_x.size() <= shape_y.size() ? shape_x : shape_y; + + llvm::SmallVector broadcast_dimensions; + // Explicit broadcast dimensions. + for (const APInt &int_value : broadcast_dimensions_attr) { + broadcast_dimensions.push_back(int_value.getSExtValue()); + } + if (broadcast_dimensions.size() != shape_small.size()) { + return nullptr; + } + llvm::SmallVector out_shape(shape_large.begin(), + shape_large.end()); + + // Update according to the broadcast dimensions. + for (const auto &index_pair : llvm::enumerate(broadcast_dimensions)) { + auto old_value = out_shape[index_pair.value()]; + auto new_value = shape_small[index_pair.index()]; + out_shape[index_pair.value()] = std::max(old_value, new_value); + } + return tensorflow::GetTypeFromTFTensorShape(out_shape, element_type); +} + +// Deprecated: This is maintained to aid in porting old code that is not yet +// dynamic shape aware and uses broadcasting modes that CHLO does not support. +// Applies static binary broadcasting to a binary elementwise op. +// This is a legacy helper to provide general broadcasting support in legacy, +// static shaped code that relies on non-left-padded broadcasting semantics. +template +static Value StaticBinaryBroadcast(Location loc, Value x, Value y, + DenseIntElementsAttr broadcast_dims, + OpBuilder &builder) { + auto x_type = mlir::cast(x.getType()); + auto y_type = mlir::cast(y.getType()); + auto result_type = GetStaticBroadcastType(x_type, y_type, broadcast_dims); + if (!result_type) { + emitError(loc) << "could not binary broadcast " << x_type << ", " << y_type + << " with broadcast_dims = " << broadcast_dims; + return nullptr; + } + auto larger_broadcast_dims = + GetI64ElementsAttrForSeq(0, result_type.getRank(), &builder); + if (x_type.getRank() < y_type.getRank()) { + if (x_type != result_type) { + x = builder.create(loc, result_type, x, broadcast_dims); + } + if (y_type != result_type) { + y = builder.create(loc, result_type, y, + larger_broadcast_dims); + } + } else { + if (x_type != result_type) { + x = builder.create(loc, result_type, x, + larger_broadcast_dims); + } + if (y_type != result_type) { + y = builder.create(loc, result_type, y, broadcast_dims); + } + } + return builder.create(loc, x, y); +} + +// Gets a 1D tensor type suitable for expressing extents of the given tensor +// value type. If the value type is ranked, the result will be statically +// shaped. Otherwise, it will have a dynamic dimension. +static RankedTensorType GetExtentsTensorTypeFor(TensorType value_type) { + Builder b(value_type.getContext()); + int64_t dim = value_type.hasRank() ? value_type.getRank() : -1; + return tensorflow::GetTypeFromTFTensorShape({dim}, b.getIndexType()); +} + +// Given a value (broadcast_to) and a feature dimension, broadcasts a 1D +// value (broadcast_from) along that feature dimension. This is a shortcut +// for the cases where a 1D tensor must be broadcast along a specific feature +// dimension, which can vary based on data layout, etc. +// +// The extent of `broadcast_from` dim0 must be equal to the extent of the +// feature_dim of `broadcast_to`. +// +// Example: +// [1x2x3x4], [2], 1 -> [1x2x3x4] +// TODO(laurenzo): Swap the order of broadcast_to and broadcast_from for +// consistency. Possibly also rename for clarity. +static Value Broadcast1DToFeatureDim(Location loc, Value broadcast_to, + Value broadcast_from, int64_t feature_dim, + OpBuilder &builder) { + auto broadcast_dims = GetI64ElementsAttr({feature_dim}, &builder); + auto to_type = mlir::cast(broadcast_to.getType()); + auto result_shape = builder.create(loc, broadcast_to); + auto result_extents_type = GetExtentsTensorTypeFor(to_type); + auto result_extents = builder.create( + loc, result_extents_type, result_shape); + return builder.create( + loc, to_type, broadcast_from, result_extents, broadcast_dims); +} + +// Broadcasts `input` to the shape of `broadcast_to` value following +// TF::BroadcastTo semantics. +// +// Requires that input is a ranked tensor. +// +// TODO(hinsu): Utilize TF::ShapeOp followed by TF::BroadcastTo once ShapeOp +// supports unranked inputs in the lowering. +static Value BroadcastToShapeOf(Location loc, Value input, Value broadcast_to, + OpBuilder &builder) { + auto result_shape = builder.create(loc, broadcast_to); + auto to_type = mlir::cast(broadcast_to.getType()); + auto result_extents_type = GetExtentsTensorTypeFor(to_type); + auto result_extents = builder.create( + loc, result_extents_type, result_shape); + int64_t rank = mlir::cast(input.getType()).getRank(); + auto broadcast_dims = GetI64ElementsAttrForSeq(0, rank, &builder); + return builder.create( + loc, to_type, input, result_extents, broadcast_dims); +} + +// Builds a set of operations for applying reduction on the input value. A +// tf.sum op is created and will be legalized to tfl ops automatically. +static Value ApplyReduction(Location loc, Value input, + DenseIntElementsAttr reduce_dims, + OpBuilder *builder) { + auto reduce_dims_op = builder->create(loc, reduce_dims); + return builder->create(loc, input, reduce_dims_op, + builder->getBoolAttr(false)); +} + +// Creates a mhlo.rng_uniform op with `builder` to generate `num_elements` +// 32-bit integer numbers in the range of [`lower_limit`, `upper_limit`). +static mhlo::RngOp CreateRngUniform32(Location loc, int num_elements, + int lower_limit, int upper_limit, + OpBuilder *builder) { + auto shape_tensor = builder->create( + loc, GetI64ElementsAttr({num_elements}, builder)); + + auto lower = builder->create( + loc, builder->getI32IntegerAttr(lower_limit)); + auto upper = builder->create( + loc, builder->getI32IntegerAttr(upper_limit)); + + return builder->create(loc, lower, upper, shape_tensor, + ::mlir::mhlo::RngDistribution::UNIFORM); +} + +using WhileBodyFnType = llvm::function_ref old_values, + SmallVectorImpl *new_values, OpBuilder *builder)>; + +// Creates a mhlo.while op with `builder` to loop `num_interations` times, +// each time calling the given `body_fn` on a set of values to generate a new +// set of values. Returns the final set of values via `final_values`. The +// initial set of values is passed in via `init_values`. +// +// This effectively does: +// +// ```c++ +// SmallVector old_values = init_values; +// SmallVector new_values; +// for (int i = 0; i < num_iterations; ++i) { +// body_fn(old_values, &new_values, ...); +// old_values = new_values; +// } +// ``` +// +// Under the hood an induction variable is prepended to values to control the +// number of iterations, but that is transparent to `body_fn`, which does not +// need to care about that. +static void CreateWhile32(Location loc, int num_iterations, + WhileBodyFnType body_fn, ArrayRef init_values, + SmallVectorImpl *final_values, + OpBuilder *builder) { + int value_count = init_values.size() + 1; + + // Prepend a loop induction variable to the initial values. + SmallVector init_values_with_loop_iv; + SmallVector init_types_with_loop_iv; + init_values_with_loop_iv.reserve(value_count); + init_types_with_loop_iv.reserve(value_count); + + // The initial value for the loop induction variable is 0. + init_values_with_loop_iv.push_back( + builder->create(loc, builder->getI32IntegerAttr(0))); + init_values_with_loop_iv.append(init_values.begin(), init_values.end()); + + // Accumulate types of all the init values. + for (const auto &init_value_with_loop_iv : init_values_with_loop_iv) + init_types_with_loop_iv.push_back(init_value_with_loop_iv.getType()); + + // Create the while op. + auto while_op = builder->create(loc, init_types_with_loop_iv, + init_values_with_loop_iv); + auto ivs_count = init_types_with_loop_iv.size(); + + { + OpBuilder::InsertionGuard guard(*builder); + + // Build up the only block in the condition region. + Region &condition = while_op.getCond(); + Block *block = builder->createBlock(&condition); + block->addArguments(init_types_with_loop_iv, + SmallVector(ivs_count, loc)); + + // Get the loop induction variable and compare it against the upper limit. + auto loop_iv = block->getArgument(0); + auto upper_limit = builder->create( + loc, builder->getI32IntegerAttr(num_iterations)); + Value compare = builder->create(loc, loop_iv, upper_limit, + ComparisonDirection::LT); + + builder->create(loc, compare); + } + + { + OpBuilder::InsertionGuard guard(*builder); + + // Build up the only block in the body region. + Region &body = while_op.getBody(); + Block *block = builder->createBlock(&body); + block->addArguments(init_types_with_loop_iv, + SmallVector(ivs_count, loc)); + + SmallVector new_values; // Generated by this iteration + new_values.reserve(value_count); + + // Feed all values excluding the loop induction variable to body_fn. + body_fn(loc, block->getArgument(0), + ArrayRef(block->getArguments().begin() + 1, + block->getArguments().end()), + &new_values, builder); + + // Increment the loop induction variable by one. + auto one = + builder->create(loc, builder->getI32IntegerAttr(1)); + auto scalar_broadcast_dims = builder->getDenseI64ArrayAttr({}); + auto plus_one = builder->create( + loc, block->getArgument(0), one, scalar_broadcast_dims); + // Prepend with the updated loop induction variable. + new_values.insert(new_values.begin(), plus_one); + + builder->create(loc, new_values); + } + + // TODO(jpienaar): Support multi-operand while op. + final_values->reserve(init_values.size()); + for (int i = 0, e = init_values.size(); i < e; ++i) + final_values->push_back(while_op.getResult(i + 1)); +} + +//===----------------------------------------------------------------------===// +// BatchNorm op utilities. +//===----------------------------------------------------------------------===// + +static IntegerAttr getFeatureDimensionAttr(Builder &b, + tensorflow::TensorFormat format, + Value input) { + return b.getI64IntegerAttr(GetFeatureDimension( + format, mlir::cast(input.getType()))); +} + +//===----------------------------------------------------------------------===// +// FFT op utilities. +//===----------------------------------------------------------------------===// + +// Returns the 1D i64 elements attribute populated with the inner-most dim of +// the value. +static DenseIntElementsAttr GetInnerDimFromValue(ShapedType type, + Builder *builder) { + if (type.getRank() == 0) { + return builder->getI64TensorAttr({}); + } + return builder->getI64TensorAttr(type.getShape().back()); +} + +// Returns True if the inner-most dim is static. +bool CheckInnerDimStatic(ShapedType type, Builder *builder) { + if (!type.hasRank()) { + return false; + } + return !type.isDynamicDim(type.getShape().size() - 1); +} + +//===----------------------------------------------------------------------===// +// MatMul op utilities. +//===----------------------------------------------------------------------===// + +// If the 'transpose' attribute is true returns ElementsAttr to transpose 2D +// matrix. Otherwise, returns ElementsAttr for identity transpose. +static DenseIntElementsAttr Get2DTransposePerm(BoolAttr transpose, Builder *b) { + if (transpose.getValue()) return GetI64ElementsAttr({1, 0}, b); + return GetI64ElementsAttr({0, 1}, b); +} + +//===----------------------------------------------------------------------===// +// Pad op utilities. +//===----------------------------------------------------------------------===// + +// Slices input attribute of rank two and returns the specified column. +// +// Always returns 64 bit integer attribute regardless of bitwidth of the input +// attribute. +static DenseIntElementsAttr SliceDenseIntElementsAttrColumn2D( + ElementsAttr input, int column) { + auto int_attr = mlir::cast(input); + auto shaped_type = int_attr.getType(); + auto shape = shaped_type.getShape(); + + if (shape.size() != 2) return DenseIntElementsAttr(); + + llvm::SmallVector values; + values.reserve(shaped_type.getNumElements() / shape[1]); + + for (const auto &it : llvm::enumerate(int_attr.getValues())) { + if (static_cast(it.index() % shape[1]) == column) { + values.push_back(it.value().getSExtValue()); + } + } + + auto element_type = IntegerType::get(input.getContext(), 64); + return DenseIntElementsAttr::get( + tensorflow::GetTypeFromTFTensorShape({shape[0]}, element_type), values); +} + +// Returns interior padding to use in HLO Pad op based on the TensorFlow padding +// in TensorFlow PadV2 op. +static DenseIntElementsAttr GetInteriorPadding(ElementsAttr tf_padding) { + auto length = tf_padding.getShapedType().getShape()[0]; + auto element_type = IntegerType::get(tf_padding.getContext(), 64); + return DenseIntElementsAttr::get( + tensorflow::GetTypeFromTFTensorShape({length}, element_type), 0); +} + +//===----------------------------------------------------------------------===// +// Binary op utilities. +//===----------------------------------------------------------------------===// + +// Returns whether the two values are guaranteed to be broadcastable to the +// same shape, this broadcasts size 1 tensors up to any rank. Dynamic dimensions +// must be broadcasted with a size 1 tensor or another dynamic dimension. +// Returns false on rankless. +static bool AreBroadcastCompatible(Value x, Value y) { + auto x_rankless = mlir::dyn_cast(x.getType()); + auto y_rankless = mlir::dyn_cast(y.getType()); + if (!x_rankless || !y_rankless) { + return false; + } + + // Check that the shapes can be broadcasted. + auto shape_x = x_rankless.getShape(); + auto shape_y = y_rankless.getShape(); + + int rank_diff = shape_x.size() - shape_y.size(); + int offset_x = rank_diff > 0 ? rank_diff : 0; + int offset_y = rank_diff < 0 ? -rank_diff : 0; + for (int i = 0, s = std::min(shape_x.size(), shape_y.size()); i < s; i++) { + int index_x = i + offset_x; + int index_y = i + offset_y; + if ((shape_x[index_x] == -1 && shape_y[index_y] != 1) || + (shape_y[index_y] == -1 && shape_x[index_x] != 1)) { + return false; + } + } + + return true; +} + +// Return a new TensorType the same rank and dimensions as the input with an +// updated element type. +static Type ChangeTensorElementType(Builder *b, Type tensor_type, + Type element_type) { + RankedTensorType ranked_type = mlir::dyn_cast(tensor_type); + if (ranked_type) { + return tensorflow::GetTypeFromTFTensorShape(ranked_type.getShape(), + element_type); + } + + return UnrankedTensorType::get(element_type); +} + +//===----------------------------------------------------------------------===// +// Softmax op utilities. +//===----------------------------------------------------------------------===// + +// Returns the type to use for accumulating the given type. +static Type GetAccumulationType(Type ty) { + // Upcast 16 bit sum reductions to 32 bit to reduce the precision loss from + // repeated floating point additions. + return (ty.isF16() || ty.isBF16()) ? FloatType::getF32(ty.getContext()) : ty; +} + +//===----------------------------------------------------------------------===// +// Softplus op utilities. +//===----------------------------------------------------------------------===// + +static DenseElementsAttr GetEpsilonValue(Type ty) { + auto element_ty = mlir::cast(ty).getElementType(); + auto scalar_ty = tensorflow::GetTypeFromTFTensorShape({}, element_ty); + if (element_ty.isF16()) { + uint16_t raw_epsilon = Eigen::numext::bit_cast( + Eigen::NumTraits::epsilon()); + auto value = APFloat(APFloat::IEEEhalf(), APInt(16, raw_epsilon)); + return DenseElementsAttr::get(scalar_ty, value); + } else if (element_ty.isBF16()) { + uint16_t raw_epsilon = Eigen::numext::bit_cast( + Eigen::NumTraits::epsilon()); + auto value = APFloat(APFloat::BFloat(), APInt(16, raw_epsilon)); + return DenseElementsAttr::get(scalar_ty, value); + } else if (element_ty.isF32()) { + auto value = APFloat(std::numeric_limits::epsilon()); + return DenseElementsAttr::get(scalar_ty, value); + } else if (element_ty.isF64()) { + auto value = APFloat(std::numeric_limits::epsilon()); + return DenseElementsAttr::get(scalar_ty, value); + } + llvm_unreachable("unsupported element type for tf.SoftPlus"); +} + +//===----------------------------------------------------------------------===// +// ArgMax/ArgMin op utilities. +//===----------------------------------------------------------------------===// + +static void BuildArgMinMaxReductionBody(Type input_element_type, + Type index_element_type, + ComparisonDirection direction, + Region *body, OpBuilder *builder) { + OpBuilder::InsertionGuard insertion_point_gurad(*builder); + + Type input_type = + tensorflow::GetTypeFromTFTensorShape(/*shape=*/{}, input_element_type); + Type index_type = + tensorflow::GetTypeFromTFTensorShape(/*shape=*/{}, index_element_type); + Block *block = builder->createBlock(body); + Location loc = body->getLoc(); + block->addArguments({input_type, index_type, input_type, index_type}, + SmallVector(4, loc)); + + Value lhs_val = block->getArgument(0); + Value lhs_index = block->getArgument(1); + Value rhs_val = block->getArgument(2); + Value rhs_index = block->getArgument(3); + + ImplicitLocOpBuilder b(loc, *builder); + Value compare_dt = b.create(lhs_val, rhs_val, direction); + Value selected_input = + b.create(input_type, compare_dt, lhs_val, rhs_val); + + Value compare_eq = + b.create(lhs_val, rhs_val, ComparisonDirection::EQ); + Value min_index = b.create(lhs_index, rhs_index); + Value min_val_index = + b.create(index_type, compare_dt, lhs_index, rhs_index); + Value selected_index = + b.create(index_type, compare_eq, min_index, min_val_index); + + Value return_values[] = {selected_input, selected_index}; + b.create(return_values); +} + +//===----------------------------------------------------------------------===// +// PartitionedCall op utilities. +//===----------------------------------------------------------------------===// + +// Verify that the arguments to be passed into the function are the same types +// as the function paramter types. +static bool ArgTypesMatchCallee(mlir::Operation *op, OperandRange args, + SymbolRefAttr func) { + auto module = op->getParentOfType(); + auto function = + dyn_cast_or_null(SymbolTable::lookupSymbolIn(module, func)); + FunctionType function_ty = function.getFunctionType(); + + for (auto arg_in : llvm::zip(args, function_ty.getInputs())) { + if (std::get<0>(arg_in).getType() != std::get<1>(arg_in)) { + // Argument type and input type mismatch. + return false; + } + } + return true; +} + +//===----------------------------------------------------------------------===// +// Slice op utilities. +//===----------------------------------------------------------------------===// + +static bool CanBeTranslatedToDynamicSlice(Value input, Value start_indices, + DenseIntElementsAttr slice_sizes) { + auto input_ty = mlir::dyn_cast(input.getType()); + if (!input_ty) return false; + auto start_indices_ty = + mlir::dyn_cast(start_indices.getType()); + if (!start_indices_ty) return false; + + int64_t input_rank = input_ty.getRank(); + ArrayRef input_shape = input_ty.getShape(); + DenseIntElementsAttr constant_start_indices; + bool is_constant_start = + matchPattern(start_indices, m_Constant(&constant_start_indices)); + + for (int64_t i = 0; i < input_rank; ++i) { + int64_t input_size = input_shape[i]; + int64_t slice_size = slice_sizes.getValues()[i].getInt(); + // A slice_size of -1 means "all elements from start_index to the end". + // In order to support these semantics, we need to know both the start index + // and the shape of the input dimension. + if (slice_size < 0 && (!is_constant_start || input_size < 0)) return false; + } + return true; +} + +// TF slice size can be -1, which represents all elements from start_index to +// the end. HLO slice size can't be -1. As such, we need to translate TF slice +// size -1 to HLO slice size. +static DenseIntElementsAttr TFSliceSizes2HLOSliceSizes( + Value input, Value start_indices, DenseIntElementsAttr slice_sizes, + Builder *builder) { + DenseIntElementsAttr constant_start_indices; + if (!matchPattern(start_indices, m_Constant(&constant_start_indices))) { + return mlir::cast( + hlo::convertElementsAttr(slice_sizes, builder->getIntegerType(64))); + } + + auto input_ty = mlir::dyn_cast(input.getType()); + int64_t input_rank = input_ty.getRank(); + ArrayRef input_shape = input_ty.getShape(); + SmallVector normalized_sizes; + + for (int64_t i = 0; i < input_rank; ++i) { + int64_t input_size = input_shape[i]; + int64_t start_index = + constant_start_indices.getValues()[i].getInt(); + int64_t slice_size = slice_sizes.getValues()[i].getInt(); + normalized_sizes.push_back(slice_size == -1 ? input_size - start_index + : slice_size); + } + + return GetI64ElementsAttr(normalized_sizes, builder); +} + +//===----------------------------------------------------------------------===// +// XlaGather op utilities. +//===----------------------------------------------------------------------===// + +bool HasValidGatherDims(StringAttr attr) { + ::xla::GatherDimensionNumbers dims; + return dims.ParseFromString(attr.getValue().str()); +} + +GatherDimensionNumbersAttr GetGatherDimNumsAttr(StringAttr attr, + Builder *builder) { + ::xla::GatherDimensionNumbers dims; + if (!dims.ParseFromString(attr.getValue().str())) return {}; + return ::xla::ConvertGatherDimensionNumbers(dims, builder); +} + +//===----------------------------------------------------------------------===// +// XlaDot op utilities. +//===----------------------------------------------------------------------===// + +bool HasValidDotDims(StringAttr attr) { + ::xla::DotDimensionNumbers dims; + return dims.ParseFromString(attr.getValue().str()); +} + +DotDimensionNumbersAttr GetDotDimNumsAttr(StringAttr attr, Builder *builder) { + ::xla::DotDimensionNumbers dims; + if (!dims.ParseFromString(attr.getValue().str())) return {}; + return ::xla::ConvertDotDimensionNumbers(dims, builder); +} + +bool HasValidPrecisionConfig(StringAttr attr) { + ::xla::PrecisionConfig precision; + return precision.ParseFromString(attr.getValue().str()); +} + +mlir::ArrayAttr GetPrecisionConfigAttr(StringAttr attr, Builder *builder) { + ::xla::PrecisionConfig precision; + if (!precision.ParseFromString(attr.getValue().str())) return {}; + return ::xla::ConvertPrecisionConfig(&precision, builder); +} + +//===----------------------------------------------------------------------===// +// XlaVariadicReduceV2 op utilities. +//===----------------------------------------------------------------------===// + +static void BuildBodyWithCall(PatternRewriter &rewriter, const Location &loc, + mlir::SymbolRefAttr func, + mlir::FunctionType func_ty, Region *body) { + OpBuilder::InsertionGuard guard(rewriter); + + Block *block = rewriter.createBlock(body); + auto inputs = func_ty.getInputs(); + block->addArguments(inputs, SmallVector(inputs.size(), loc)); + mlir::func::CallOp call_op = rewriter.create( + loc, func, func_ty.getResults(), block->getArguments()); + rewriter.create(loc, call_op.getResults()); +} + +//===----------------------------------------------------------------------===// +// Op converters. +//===----------------------------------------------------------------------===// + +NamedAttribute GetConvDimensionNumbersAttr(ArrayRef spatial_dims, + tensorflow::TensorFormat format, + Builder *builder) { + int64_t num_spatial_dims = spatial_dims.size(); + int64_t num_dims = num_spatial_dims + 2; + + int64_t batch_dim = GetTensorBatchDimIndex(num_dims, format); + int64_t feature_dim = GetTensorFeatureDimIndex(num_dims, format); + + // Filters data_format is always HWIO so input channels dimension is after + // all spatial dimensions. + int64_t kernel_input_feature_dim = num_spatial_dims; + int64_t kernel_output_feature_dim = num_spatial_dims + 1; + SmallVector kernel_spatial_dimensions; + kernel_spatial_dimensions.resize(num_spatial_dims); + std::iota(kernel_spatial_dimensions.begin(), kernel_spatial_dimensions.end(), + 0); + + return builder->getNamedAttr( + "dimension_numbers", + ConvDimensionNumbersAttr::get( + builder->getContext(), batch_dim, feature_dim, spatial_dims, + kernel_input_feature_dim, kernel_output_feature_dim, + kernel_spatial_dimensions, batch_dim, feature_dim, spatial_dims)); +} + +// Converts a TF::BiasAddOp to HLO. +// This differs from a normal TF::AddOp with respect to how the data_format +// is handled, which can optionally require a general broadcast of the +// 'bias' term in a way that is not compatible with the standard left-padded +// broadcast semantics (i.e. NCHW will broadcast into dimension 1). +// The correct 'bias' broadcast will be synthesized manually. +class ConvertBiasAddOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(TF::BiasAddOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + tensorflow::TensorFormat data_format; + if (!FormatFromString(op.getDataFormat().str(), &data_format)) + return op.emitOpError("invalid data format"); + + auto value_type = mlir::dyn_cast(op.getValue().getType()); + if (!value_type) return failure(); + auto feature_dim = GetFeatureDimension(data_format, value_type); + auto bias_broadcast = Broadcast1DToFeatureDim( + loc, op.getValue(), op.getBias(), feature_dim, rewriter); + Value add = rewriter.create(loc, op.getValue(), bias_broadcast); + if (add.getType() != op.getType()) { + add = rewriter.create(loc, op.getType(), add); + } + rewriter.replaceOp(op, {add}); + return success(); + } +}; + +// Conterts tf.Conv2D to mhlo.dynamic_conv. +// TODO(disc): To recover static special case's performance with adding folding, +// canonicalization func and removing ConvertConvOp. +template +class ConvertConvDynamic : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + bool GetPaddingValues(OpT &op, PatternRewriter &rewriter, Value input_size, + Value filter_size, int64_t dilation_rate, + int64_t stride, tensorflow::Padding padding_type, + Type shape_scalar_type, Value *padding_low, + Value *padding_high) const { + // Stride must be > 0 + if (stride <= 0) return false; + // Dilation rate must be >= 1 + if (dilation_rate < 1) return false; + + Location loc = op.getLoc(); + switch (padding_type) { + case tensorflow::Padding::VALID: { + auto zero = + rewriter.create(loc, 0, shape_scalar_type); + *padding_low = *padding_high = zero; + break; + } + case tensorflow::Padding::EXPLICIT: + break; + case tensorflow::Padding::SAME: { + auto zero = + rewriter.create(loc, 0, shape_scalar_type); + auto one = + rewriter.create(loc, 1, shape_scalar_type); + auto two = + rewriter.create(loc, 2, shape_scalar_type); + // See also the parallel implementation in + // GetWindowedOutputSizeFromDimsV2. effective_filter_size = (filter_size + // - 1) * dilation_rate + 1 + Value stride_value = rewriter.create( + loc, stride, shape_scalar_type); + Value dilation_rate_value = rewriter.create( + loc, dilation_rate, shape_scalar_type); + Value effective_filter_size_op = rewriter.create( + loc, one, + rewriter.create( + loc, dilation_rate_value, + rewriter.create(loc, filter_size, one))); + // output_size = (input_size + stride - 1) / stride; + Value output_size = rewriter.create( + loc, + rewriter.create( + loc, input_size, + rewriter.create(loc, stride_value, one)), + stride_value); + // std::max(int64{0}, (output_size - 1) * stride + + // effective_filter_size - input_size); + Value padding_needed = rewriter.create( + loc, + rewriter.create( + loc, effective_filter_size_op, + rewriter.create( + loc, stride_value, + rewriter.create(loc, output_size, one))), + input_size); + Value cond = rewriter.create( + loc, arith::CmpIPredicate::sge, padding_needed, zero); + padding_needed = rewriter.create( + loc, padding_needed.getType(), cond, padding_needed, zero); + *padding_low = + rewriter.create(loc, padding_needed, two); + *padding_high = + rewriter.create(loc, padding_needed, *padding_low); + break; + } + } + return true; + } + + LogicalResult matchAndRewriteDynamicConv(OpT op, + PatternRewriter &rewriter) const { + tensorflow::TensorFormat data_format; + if (!FormatFromString(op.getDataFormat().str(), &data_format)) + return op.emitOpError("invalid data format"); + + tensorflow::Padding padding; + if (!GetPaddingFromString(op.getPadding().str(), &padding).ok()) + return failure(); + + auto input_ty = mlir::dyn_cast(op.getInput().getType()); + auto filter_ty = mlir::dyn_cast(op.getFilter().getType()); + auto result_ty = mlir::dyn_cast(op.getType()); + if (!input_ty || !filter_ty || !result_ty) return failure(); + // TODO(disc): Remove this constraint once fold and canonicalization + // implemented. + if (input_ty.hasStaticShape() && filter_ty.hasStaticShape()) + return failure(); + + ArrayRef dilations = op.getDilations().getValue(); + ArrayRef strides = op.getStrides().getValue(); + ArrayRef explicit_paddings; + if (padding == tensorflow::Padding::EXPLICIT) { + // EXPLICIT padding mode and the associated attribute is attached to + // Conv2D. + explicit_paddings = + op->template getAttrOfType("explicit_paddings").getValue(); + } + + SmallVector spatial_dim_indices; + SmallVector rhs_dilations; + SmallVector window_strides; + SmallVector paddings; + + auto get_int = [](Attribute attr) { + return mlir::cast(attr).getInt(); + }; + + constexpr int num_dims = num_spatial_dims + 2; + + Location loc = op.getLoc(); + auto shape_scalar_type = rewriter.getIntegerType(32); + + auto get_const = [&](int64_t val) { + return rewriter.create(loc, val, + shape_scalar_type); + }; + auto get_dim_value = [&](Value val, int64_t dim) { + Value dim_value = rewriter.create(loc, val, dim); + return rewriter.create(loc, shape_scalar_type, + dim_value); + }; + + for (auto i : llvm::seq(0, num_spatial_dims)) { + const int64_t dim = GetTensorSpatialDimIndex(num_dims, data_format, i); + spatial_dim_indices.push_back(dim); + + const int64_t dilation = get_int(dilations[dim]); + rhs_dilations.push_back(dilation); + const int64_t stride = get_int(strides[dim]); + window_strides.push_back(stride); + + Value pad_low, pad_high; + if (padding == tensorflow::Padding::EXPLICIT) { + pad_low = get_const(get_int(explicit_paddings[2 * dim])); + pad_high = get_const(get_int(explicit_paddings[2 * dim + 1])); + } else { + auto input_size = get_dim_value(op.getInput(), dim); + auto filter_size = get_dim_value(op.getFilter(), i); + if (!GetPaddingValues(op, rewriter, input_size, filter_size, dilation, + stride, padding, shape_scalar_type, &pad_low, + &pad_high)) { + return failure(); + } + } + paddings.push_back(pad_low); + paddings.push_back(pad_high); + } + auto rhs_dilations_attr = rewriter.getNamedAttr( + "rhs_dilation", GetI64ElementsAttr(rhs_dilations, &rewriter)); + + auto window_strides_attr = rewriter.getNamedAttr( + "window_strides", GetI64ElementsAttr(window_strides, &rewriter)); + + auto dimension_numbers_attr = GetConvDimensionNumbersAttr( + spatial_dim_indices, data_format, &rewriter); + + const int64_t input_channels = + GetDimSize(input_ty, GetTensorFeatureDimIndex(num_dims, data_format)); + // Filters data_format is always HWIO so input channels dimension is after + // all spatial dimensions. + const int64_t filter_channels = GetDimSize(filter_ty, num_spatial_dims); + // TensorFlow convolution op verifies that the number of input channels is + // divisible by the number of filter channels. + // For depthwise convolution the feature_group_count argument would be set + // to the input feature dimension. + const int64_t feature_group_count = + depthwise_conv ? input_channels : input_channels / filter_channels; + auto feature_group_count_attr = rewriter.getNamedAttr( + "feature_group_count", rewriter.getI64IntegerAttr(feature_group_count)); + + auto batch_group_count_attr = rewriter.getNamedAttr( + "batch_group_count", rewriter.getI64IntegerAttr(1)); + + auto precision_config_attr = rewriter.getNamedAttr( + "precision_config", GetPrecisionConfig(&rewriter)); + + Value paddings_op = rewriter.create( + op.getLoc(), + tensorflow::GetTypeFromTFTensorShape(2 * num_spatial_dims, + rewriter.getI32Type()), + paddings); + + SmallVector operands(op.getOperands()); + operands.push_back(paddings_op); + // Reshape the filter to {spatial_dims...., 1,in_channels * + // channel_multiplier} + if (depthwise_conv) { + ArrayRef filter_shape = filter_ty.getShape(); + llvm::SmallVector new_shape( + filter_shape.begin(), filter_shape.begin() + num_spatial_dims); + new_shape.push_back(1); + new_shape.push_back(filter_shape[num_spatial_dims] * + filter_shape[num_spatial_dims + 1]); + operands[1] = rewriter.create( + op.getLoc(), + tensorflow::GetTypeFromTFTensorShape(new_shape, + filter_ty.getElementType()), + operands[1]); + } + NamedAttribute attrs[] = {rhs_dilations_attr, window_strides_attr, + dimension_numbers_attr, feature_group_count_attr, + batch_group_count_attr, precision_config_attr}; + rewriter.replaceOpWithNewOp(op, op.getType(), operands, + llvm::ArrayRef(attrs)); + return success(); + } + + LogicalResult matchAndRewrite(OpT op, + PatternRewriter &rewriter) const override { + return matchAndRewriteDynamicConv(op, rewriter); + } +}; + +using ConvertConv2DDynamic = + ConvertConvDynamic; + +// Converts the TensorFlow conv op in template to the generic HLO conv op by +// converting TensorFlow op attributes to HLO op attributes. +// +// Sample result for Conv2D: +// +// %conv = "mhlo.convolution"(%input, %filter) { +// strides = [1, 2], +// paddings = [[1, 0], [1, 1]], +// ... +// } +// +// This pattern is not defined using declarative rewrite rules as computation of +// the paddings attribute anyway requires multiple source op attributes and +// result op attributes. Defining it as declarative rewrite rule will introduce +// some duplication in the C++ helper methods. +template +class ConvertConvOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(OpTy op, + PatternRewriter &rewriter) const override { + tensorflow::TensorFormat data_format; + if (!FormatFromString(op.getDataFormat().str(), &data_format)) + return op.emitOpError("invalid data format"); + + tensorflow::Padding padding; + if (!GetPaddingFromString(op.getPadding().str(), &padding).ok()) + return failure(); + + auto input_ty = mlir::dyn_cast(op.getInput().getType()); + auto filter_ty = mlir::dyn_cast(op.getFilter().getType()); + + // With the exception of input's batch dimension, input and filter need to + // have static shape for calculation of HLO paddings and feature group count + // attributes. Filter is validated here, input is mostly validated at use. + if (!input_ty || !filter_ty || !filter_ty.hasStaticShape()) + return failure(); + + ArrayRef dilations = op.getDilations().getValue(); + ArrayRef strides = op.getStrides().getValue(); + ArrayRef explicit_paddings; + if (padding == tensorflow::Padding::EXPLICIT) { + // EXPLICIT padding mode and the associated attribute is limited to + // Conv2D. So, fetch attribute by identifier instead of the + // op.explicit_paddings() attribute getter. + explicit_paddings = + op->template getAttrOfType("explicit_paddings").getValue(); + } + + SmallVector spatial_dim_indices; + SmallVector rhs_dilations; + SmallVector window_strides; + SmallVector paddings; + + auto get_int = [](Attribute attr) { + return mlir::cast(attr).getInt(); + }; + + constexpr int num_dims = num_spatial_dims + 2; + for (auto i : llvm::seq(0, num_spatial_dims)) { + const int64_t dim = GetTensorSpatialDimIndex(num_dims, data_format, i); + spatial_dim_indices.push_back(dim); + + const int64_t dilation = get_int(dilations[dim]); + rhs_dilations.push_back(dilation); + const int64_t stride = get_int(strides[dim]); + window_strides.push_back(stride); + + int64_t pad_low, pad_high; + if (padding == tensorflow::Padding::EXPLICIT) { + pad_low = get_int(explicit_paddings[2 * dim]); + pad_high = get_int(explicit_paddings[2 * dim + 1]); + } else { + int64_t output_size; + int64_t pad_low_int64; + int64_t pad_high_int64; + int64_t input_size = input_ty.getDimSize(dim); + if (input_size == ShapedType::kDynamic) return failure(); + absl::Status status = tensorflow::GetWindowedOutputSizeVerbose( + input_size, filter_ty.getDimSize(i), dilation, stride, padding, + &output_size, &pad_low_int64, &pad_high_int64); + if (!status.ok()) return failure(); + pad_low = pad_low_int64; + pad_high = pad_high_int64; + } + paddings.push_back(pad_low); + paddings.push_back(pad_high); + } + + auto rhs_dilations_attr = rewriter.getNamedAttr( + "rhs_dilation", GetI64ElementsAttr(rhs_dilations, &rewriter)); + + auto window_strides_attr = rewriter.getNamedAttr( + "window_strides", GetI64ElementsAttr(window_strides, &rewriter)); + + auto dimension_numbers_attr = GetConvDimensionNumbersAttr( + spatial_dim_indices, data_format, &rewriter); + + const int64_t input_channels = + GetDimSize(input_ty, GetTensorFeatureDimIndex(num_dims, data_format)); + if (input_channels == ShapedType::kDynamic) return failure(); + // Filters data_format is always HWIO so input channels dimension is after + // all spatial dimensions. + const int64_t filter_channels = GetDimSize(filter_ty, num_spatial_dims); + // TensorFlow convolution op verifies that the number of input channels is + // divisible by the number of filter channels. + // For depthwise convolution the feature_group_count argument would be set + // to the input feature dimension. + const int64_t feature_group_count = + depthwise_conv ? input_channels : input_channels / filter_channels; + auto feature_group_count_attr = rewriter.getNamedAttr( + "feature_group_count", rewriter.getI64IntegerAttr(feature_group_count)); + + auto batch_group_count_attr = rewriter.getNamedAttr( + "batch_group_count", rewriter.getI64IntegerAttr(1)); + + RankedTensorType paddings_ty = tensorflow::GetTypeFromTFTensorShape( + {num_spatial_dims, 2}, rewriter.getIntegerType(64)); + auto paddings_attr = rewriter.getNamedAttr( + "padding", DenseElementsAttr::get(paddings_ty, paddings)); + + auto precision_config_attr = rewriter.getNamedAttr( + "precision_config", GetPrecisionConfig(&rewriter)); + + SmallVector operands(op.getOperands()); + // Reshape the filter to {spatial_dims...., 1,in_channels * + // channel_multiplier} + if (depthwise_conv) { + ArrayRef filter_shape = filter_ty.getShape(); + llvm::SmallVector new_shape( + filter_shape.begin(), filter_shape.begin() + num_spatial_dims); + new_shape.push_back(1); + new_shape.push_back(filter_shape[num_spatial_dims] * + filter_shape[num_spatial_dims + 1]); + operands[1] = rewriter.create( + op.getLoc(), + tensorflow::GetTypeFromTFTensorShape(new_shape, + filter_ty.getElementType()), + operands[1]); + } + NamedAttribute attrs[] = {rhs_dilations_attr, window_strides_attr, + dimension_numbers_attr, feature_group_count_attr, + batch_group_count_attr, paddings_attr, + precision_config_attr}; + rewriter.replaceOpWithNewOp(op, op.getType(), operands, + llvm::ArrayRef(attrs)); + return success(); + } +}; + +using ConvertConv2DOp = ConvertConvOp; +using ConvertConv3DOp = ConvertConvOp; +using ConvertDepthConv2DOp = + ConvertConvOp; + +// Converts tf.PadV2Op to mhlo.DynamicPadOp. Padding values must be const. +class ConvertPadOpDynamic : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + // TODO(disc): To recover static special case's performance with folding and + // canonicalization. + LogicalResult matchAndRewrite(TF::PadV2Op op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + auto input = op.getInput(); + auto paddings = op.getPaddings(); + auto constant_values = op.getConstantValues(); + auto input_type = mlir::dyn_cast(input.getType()); + auto paddings_type = mlir::dyn_cast(paddings.getType()); + if (!input_type || !paddings_type || !paddings_type.hasStaticShape()) + return failure(); + + // TODO(disc): Remove this constraint once fold and canonicalization is + // implemented. + if (input_type.hasStaticShape()) return failure(); + + int input_rank = input_type.getRank(); + // interior padding + std::vector interior_values(input_rank, 0); + auto interior_attr = GetI64ElementsAttr(interior_values, &rewriter); + + Value interior_padding_tensor = + rewriter.create(loc, interior_attr); + Type paddings_elem_ty = paddings_type.getElementType(); + if (!paddings_elem_ty.isInteger(64)) { + interior_padding_tensor = rewriter.create( + loc, interior_padding_tensor, paddings_elem_ty); + } + llvm::SmallVector transposed_shape = {2, input_rank}; + auto transpose_attr = GetI64ElementsAttr({1, 0}, &rewriter); + Value transposed_paddings = + rewriter.create(loc, paddings, transpose_attr); + Value reshaped_paddings = rewriter.create( + loc, + tensorflow::GetTypeFromTFTensorShape({input_rank * 2}, + paddings_elem_ty), + transposed_paddings); + + auto left_padding_start_attr = GetI64ElementsAttr({0}, &rewriter); + auto left_padding_limit_attr = GetI64ElementsAttr({input_rank}, &rewriter); + auto left_padding_stride_attr = GetI64ElementsAttr({1}, &rewriter); + Value left_padding_tensor = rewriter.create( + loc, reshaped_paddings, left_padding_start_attr, + left_padding_limit_attr, left_padding_stride_attr); + + auto right_padding_start_attr = GetI64ElementsAttr({input_rank}, &rewriter); + auto right_padding_limit_attr = + GetI64ElementsAttr({2 * input_rank}, &rewriter); + auto right_padding_stride_attr = GetI64ElementsAttr({1}, &rewriter); + Value right_padding_tensor = rewriter.create( + loc, reshaped_paddings, right_padding_start_attr, + right_padding_limit_attr, right_padding_stride_attr); + + rewriter.replaceOpWithNewOp( + op, op.getType(), input, constant_values, left_padding_tensor, + right_padding_tensor, interior_padding_tensor); + + return success(); + } +}; + +class ConvertGatherNdOpDynamic : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + // Converts tf.GatherNdOp to mhlo.DynamicGatherOp. + // Here we leave 'slice_sizes' as an Attr, without defining a new + // DynamicGatherOp, since GatherDimensionNumbers has already provide enough + // information for shape inference and code generation of mhlo::GatherOp. '?' + // will be filled into slice_sizes for dimensions that are dynamic sized. + // TODO(disc): To recover static special case's performance with folding and + // canonicalization. + LogicalResult matchAndRewrite(TF::GatherNdOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + auto params = op.getParams(); + auto params_ty = mlir::dyn_cast(params.getType()); + auto indices = op.getIndices(); + auto indices_ty = mlir::dyn_cast(indices.getType()); + auto params_rank = params_ty.getRank(); + auto indices_rank = indices_ty.getRank(); + int64_t num_index_dims = indices_ty.getDimSize(indices_rank - 1); + if (!params_ty || !indices_ty) return failure(); + // the last dim of indices of GatherNdOp must be fixed shaped + if (num_index_dims == ShapedType::kDynamic) return failure(); + + SmallVector slice_sizes; + slice_sizes.reserve(params_rank); + for (int64_t i = 0; i < params_rank; ++i) { + if (i < num_index_dims) { + slice_sizes.push_back(1); + } else { + // potentially dynamic + int64_t dim_size = params_ty.getDimSize(i); + slice_sizes.push_back(dim_size); + } + } + SmallVector slice_sizes_vals; + Value slice_sizes_value = nullptr; + for (int64_t i = 0; i < params_rank; ++i) { + if (i < num_index_dims) { + slice_sizes_vals.push_back(rewriter.create( + loc, rewriter.getIntegerAttr(indices_ty.getElementType(), 1))); + } else { + int64_t dim_size = params_ty.getDimSize(i); + if (dim_size != ShapedType::kDynamic) { + slice_sizes_vals.push_back(rewriter.create( + loc, + rewriter.getIntegerAttr(indices_ty.getElementType(), dim_size))); + } else { + slice_sizes_vals.push_back(rewriter.create( + loc, indices_ty.getElementType(), + rewriter.create(loc, params, i))); + } + } + } + slice_sizes_value = + rewriter.create(loc, slice_sizes_vals); + + // collapsed_slice_dims + SmallVector collapsed_slice_dims; + collapsed_slice_dims.reserve(num_index_dims); + for (int64_t i = 0; i < num_index_dims; ++i) { + collapsed_slice_dims.push_back(i); + } + // offset_dims + SmallVector offset_dims; + offset_dims.reserve(params_rank - num_index_dims); + for (int64_t i = num_index_dims; i < params_rank; i++) { + offset_dims.push_back(i + indices_rank - 1 - num_index_dims); + } + // start_index_map + SmallVector start_index_map; + offset_dims.reserve(num_index_dims); + for (int64_t i = 0; i < num_index_dims; i++) { + start_index_map.push_back(i); + } + // index_vector_dim + int64_t index_vector_dim = indices_rank - 1; + + auto dims_attr = GatherDimensionNumbersAttr::get( + rewriter.getContext(), offset_dims, collapsed_slice_dims, + /*operandBatchingDims=*/{}, + /*startIndicesBatchingDims=*/{}, start_index_map, index_vector_dim); + // TODO(disc): Remove this if-statement once fold and canonicalization is + // implemented. + if (params_ty.hasStaticShape() && indices_ty.hasStaticShape()) { + rewriter.replaceOpWithNewOp( + op, op.getType(), op.getParams(), op.getIndices(), dims_attr, + GetI64ElementsAttr(slice_sizes, &rewriter)); + } else { + rewriter.replaceOpWithNewOp( + op, op.getType(), op.getParams(), op.getIndices(), slice_sizes_value, + dims_attr); + } + return success(); + } +}; + +// Converts BF16 FloorDiv op to have casting operators on either end as BF16 +// division can result in strange behavior. +// +// floordiv = cast(floordiv(cast(left), cast(right)))) +// +// %left_cast = cast(%left) +// %right_cast = cast(%right) +// %div = div(%left, %left) +// %floored = floor(%div) +// %floored_cast = cast(%floored) +// +// Required to manually specify the intermediate types. +class ConvertBF16FloorDivOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TF::FloorDivOp op, + PatternRewriter &rewriter) const override { + auto l = mlir::dyn_cast>(op.getX()); + auto r = mlir::dyn_cast>(op.getY()); + if (!l || !r) return failure(); + + auto element_type = getElementTypeOrSelf(l.getType()); + if (!element_type.isBF16()) return failure(); + + auto out_type = op.getZ().getType(); + + l = rewriter.create(op.getLoc(), l, rewriter.getF32Type()); + r = rewriter.create(op.getLoc(), r, rewriter.getF32Type()); + + auto intermediate = rewriter.create( + op.getLoc(), + ChangeTensorElementType(&rewriter, out_type, rewriter.getF32Type()), l, + r); + + auto floor_op = + rewriter.create(op.getLoc(), out_type, intermediate); + rewriter.replaceOp(op, floor_op.getResult()); + return success(); + } +}; + +class ConvertBroadcastToOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TF::BroadcastToOp op, + PatternRewriter &rewriter) const override { + auto input_type = mlir::dyn_cast(op.getInput().getType()); + auto output_type = op.getOutput().getType(); + if (!input_type) { + return rewriter.notifyMatchFailure(op, "requires ranked input shape"); + } + llvm::SmallVector broadcast_dimensions; + if (input_type.getRank() > 0) { + auto ranked_output_type = mlir::dyn_cast(output_type); + if (!ranked_output_type) { + return rewriter.notifyMatchFailure(op, "requires ranked output shape"); + } + auto rank_diff = ranked_output_type.getRank() - input_type.getRank(); + // The tf.BroadcastTo op performs "right-aligned" numpy-style + // broadcasting. + broadcast_dimensions = llvm::to_vector<4>( + llvm::seq(rank_diff, ranked_output_type.getRank())); + } + rewriter.replaceOpWithNewOp( + op, output_type, op.getInput(), op.getShape(), + rewriter.getI64TensorAttr(broadcast_dimensions)); + return success(); + } +}; + +/// Converts a TF::RollOp to HLO. Only support 0D axis and shift case, and axis +/// have to be a constant. +class ConvertRollOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(TF::RollOp op, + PatternRewriter &rewriter) const override { + auto shift_ty = mlir::dyn_cast(op.getShift().getType()); + if (!shift_ty || shift_ty.getRank() != 0) { + return rewriter.notifyMatchFailure( + op, "require the type of shift to be 0D tensor"); + } + + APInt val; + if (!matchPattern(op.getAxis(), m_ConstantInt(&val))) { + return rewriter.notifyMatchFailure(op, "require axis to be constant"); + } + int axis = val.getSExtValue(); + + auto input_ty = mlir::dyn_cast(op.getInput().getType()); + if (!input_ty || !input_ty.hasStaticShape()) { + return rewriter.notifyMatchFailure( + op, "require the type of input to have static shapes"); + } + ArrayRef input_shape = input_ty.getShape(); + int input_rank = input_ty.getRank(); + if (axis < 0) axis += input_rank; + + // Adjust large offsets into [0, axis_size). This also makes negative + // offsets positive. + // offset = ((offset % axis_size) + axis_size) % axis_size + ImplicitLocOpBuilder b(op.getLoc(), rewriter); + Value offset = op.getShift(); + auto axis_size = b.create(b.getIntegerAttr( + getElementTypeOrSelf(offset.getType()), input_shape[axis])); + offset = b.create( + b.create(b.create(offset, axis_size), axis_size), + axis_size); + + // Stack two copies of the dimension, then slice from the calculated + // offset. This also works if shift is not constant. + // DynamicSliceOp requires the sizes being integer, and we can get the + // information from input shape. + auto concat = b.create( + ValueRange{op.getInput(), op.getInput()}, b.getI64IntegerAttr(axis)); + Value zero = b.create( + b.getIntegerAttr(getElementTypeOrSelf(offset.getType()), 0)); + SmallVector slice_begin_indices(input_rank, zero); + slice_begin_indices[axis] = b.create(axis_size, offset); + rewriter.replaceOpWithNewOp( + op, input_ty, concat, slice_begin_indices, + rewriter.getI64TensorAttr(input_shape)); + return success(); + } +}; + +/// Converts a TF::LeakyReluOp to HLO. +/// LeakyRelu(x) = alpha * x if x < 0 else x. +class ConvertLeakyReluOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(TF::LeakyReluOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value features = op.getFeatures(); + + // Use ConstantLike for `alpha` to match the shape of feature. + auto alphaVal = chlo::getConstantLike( + rewriter, loc, op.getAlpha().convertToFloat(), features); + Value zeroVal = chlo::getConstantLike(rewriter, loc, 0.0, features); + + Value leakyActivationVal = + rewriter.create(loc, features, alphaVal); + + Value compareGtZero = rewriter.create( + loc, features, zeroVal, ComparisonDirection::GT); + + rewriter.replaceOpWithNewOp(op, compareGtZero, features, + leakyActivationVal); + return success(); + } +}; + +/// Converts a TF::LeakyReluGradOp to HLO. +/// LeakyReluGrad(gradient, inputs) = gradient if input > 0 +/// else alpha * gradient. +class ConvertLeakyReluGradOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(TF::LeakyReluGradOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value gradients = op.getGradients(); + Value features = op.getFeatures(); + auto featureType = features.getType(); + + // Use ConstantLike for `alpha` to match the shape of feature. + auto alphaVal = chlo::getConstantLike( + rewriter, loc, op.getAlpha().convertToFloat(), features); + Value zeroVal = chlo::getConstantLike(rewriter, loc, 0.0, features); + + Value leakyGradientVal = + rewriter.create(loc, gradients, alphaVal); + + Value compareGtZero = rewriter.create( + loc, features, zeroVal, ComparisonDirection::GT); + + rewriter.replaceOpWithNewOp(op, featureType, compareGtZero, + gradients, leakyGradientVal); + return success(); + } +}; + +// Converts TensorFlow DiagPartOp to HLO ops using reduction on masked matrix. +// For a Rank-2 input, it creates the following ops: +// %1 = "mhlo.iota"() {iota_dimension = 0 : i64} +// %2 = "mhlo.iota"() {iota_dimension = 1 : i64} +// %3 = "mhlo.compare"(%1, %2) {comparison_direction = "EQ"} +// %4 = mhlo.constant dense<0.000000e+00> : tensor +// %5 = "mhlo.broadcast"(%4) +// %6 = "mhlo.select"(%3, %input, %5) +// %7 = "mhlo.reduce"(%6, %4) ({ +// ^bb0(%arg1: tensor, %arg2: tensor): +// %9 = mhlo.add %arg1, %arg2 : tensor +// "mhlo.return"(%9) : (tensor) -> () +// }) {dimensions = dense<0> : tensor<1xi64>} +// +// If the input's rank N is greater than 2, we will reshape it to R2 first and +// create the above ops, then reshape it back to rank N/2. +class ConvertDiagPartOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TF::DiagPartOp op, + PatternRewriter &rewriter) const override { + auto input_type = mlir::dyn_cast(op.getInput().getType()); + if (!input_type || !input_type.hasStaticShape()) return failure(); + int64_t num_dims = input_type.getRank(); + if (num_dims < 2 || num_dims % 2 != 0) return failure(); + const int64_t out_dims = num_dims / 2; + + int64_t new_size = 1; + llvm::SmallVector new_dims; + for (int i = 0; i < out_dims; i++) { + if (input_type.getDimSize(i) != input_type.getDimSize(i + out_dims)) + return op.emitOpError("invalid dimensions size"); + new_size *= input_type.getDimSize(i); + new_dims.push_back(input_type.getDimSize(i)); + } + Value reshaped_input = rewriter.create( + op.getLoc(), + tensorflow::GetTypeFromTFTensorShape({new_size, new_size}, + input_type.getElementType()), + op.getInput()); + auto iota_type = tensorflow::GetTypeFromTFTensorShape( + {new_size, new_size}, rewriter.getIntegerType(32)); + auto iota0 = rewriter.create(op.getLoc(), iota_type, + rewriter.getI64IntegerAttr(0)); + auto iota1 = rewriter.create(op.getLoc(), iota_type, + rewriter.getI64IntegerAttr(1)); + Value compare = rewriter.create(op.getLoc(), iota0, iota1, + ComparisonDirection::EQ); + Value zero = GetScalarConstOfType(input_type.getElementType(), op.getLoc(), + 0, &rewriter); + Value zero_matrix = rewriter.create( + op.getLoc(), reshaped_input.getType(), zero, + GetI64ElementsAttr({new_size, new_size}, &rewriter)); + Value masked = + rewriter.create(op.getLoc(), reshaped_input.getType(), + compare, reshaped_input, zero_matrix); + auto reduce = rewriter.create(op.getLoc(), masked, zero, + GetI64ElementsAttr({0}, &rewriter), + input_type.getElementType()); + assert(!input_type.getElementType().isInteger(1) && + "data type should not be i1"); + BuildReduceBody(input_type.getElementType(), &reduce.getBody(), + &rewriter); + rewriter.replaceOpWithNewOp( + op, + tensorflow::GetTypeFromTFTensorShape(new_dims, + input_type.getElementType()), + reduce.getResult(0)); + return success(); + } +}; + +// Converts TensorFlow MatrixDiagPartOp to HLO ops. +class ConvertMatrixDiagPartV3Op + : public OpRewritePattern { + using Shape = llvm::SmallVector; + + // Parse the "k" parameter. MatrixDiagPartV3 allows to specify the diagonal(s) + // with k. This can be either a single value (for a single diagonal) or a + // tuple of two values (starting and ending diagonal, for a band). + LogicalResult ExtractK(TF::MatrixDiagPartV3Op op, int64_t (*k)[2]) const { + DenseIntElementsAttr kattr; + if (!matchPattern(op.getK(), m_Constant(&kattr))) { + return failure(); + } + DenseIntElementsAttr::iterator it = kattr.begin(); + (*k)[0] = (*it).getSExtValue(); + it++; + if (it == kattr.end()) { + // Handle input like e.g. "k = 5", in which case we extract a single + // diagonal. + (*k)[1] = (*k)[0]; + } else { + // Handle input like e.g. "k = [-1, 1]", in which case we extract a + // band (multiple diagonals). + (*k)[1] = (*it).getSExtValue(); + } + return success(); + } + + // Utility method for broadcasting integer constants to a given shape. + BroadcastOp BroadcastConstant(Location loc, Shape shape, int32_t constant, + int int_size, PatternRewriter &rewriter) const { + return rewriter.create( + loc, + tensorflow::GetTypeFromTFTensorShape(shape, + rewriter.getIntegerType(int_size)), + GetScalarConstOfType(rewriter.getIntegerType(int_size), loc, constant, + &rewriter), + GetI64ElementsAttr(shape, &rewriter)); + } + + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TF::MatrixDiagPartV3Op op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + ShapedType input_type = mlir::dyn_cast(op.getInput().getType()); + + // Align is a string specifying how superdiagonals and subdiagonals should + // be aligned/padded for diagonals that are shorter than max_diag_len. The + // format is "{super}_{sub}", with {super} the superdiagonal alignment and + // {sub} the subdiagonal alignment. "LEFT" means rows will be padded to the + // left, "RIGHT" means rows will be padded ot the right. The default is + // "RIGHT_LEFT". + StringRef align = op->getAttrOfType("align").getValue(); + enum Alignment { kLeft, kRight }; + + // default is RIGHT_LEFT + Alignment superdiagonal_align = kRight; + Alignment subdiagonal_align = kLeft; + + if (align == "RIGHT_LEFT") { + superdiagonal_align = kRight; + subdiagonal_align = kLeft; + } else if (align == "RIGHT_RIGHT") { + superdiagonal_align = kRight; + subdiagonal_align = kRight; + } else if (align == "LEFT_RIGHT") { + superdiagonal_align = kLeft; + subdiagonal_align = kRight; + } else if (align == "LEFT_LEFT") { + superdiagonal_align = kLeft; + subdiagonal_align = kLeft; + } else { + return failure(); // unsupported alignment + } + + // MatrixDiagPart operates on a matrix of shape [I, J, ..., L, M, N], and + // will extract the diagonal(s) out of [M, N], for all [I, J, ..., L]. + if (!input_type || !input_type.hasStaticShape()) return failure(); + int64_t num_dims = input_type.getRank(); + if (num_dims < 2) return failure(); + int64_t rows = input_type.getDimSize(num_dims - 2); // rows + int64_t cols = input_type.getDimSize(num_dims - 1); // cols + + // We extract the diagonals from k[0] up to and including k[1]. + // Addressing is 0 for the main diagonal. (So k = [0, 0] would just extract + // the main diagonal). It's negative for subdiagonals (under and to the left + // of the main diagonal) and positive for superdiagonals (above and to the + // right of the main diagonal). + int64_t k[2]; + if (failed(ExtractK(op, &k))) return failure(); + int num_diags = k[1] - k[0] + 1; + + // Shifting diagonals away from the main diagonal might shorten them. This + // is the longest diagonal we will see. We make this the last dimension of + // the output shape. + int64_t max_diag_len = + std::min(rows + std::min(k[1], static_cast(0)), + cols + std::min(-k[0], static_cast(0))); + + // The first dimension is the index vector dimension we'll use for gather. + // It's 1 here, but will be 2 once we glue x and y together. + Shape indices_shape({1, num_diags, max_diag_len}); + + RankedTensorType iota_type = tensorflow::GetTypeFromTFTensorShape( + indices_shape, rewriter.getIntegerType(32)); + Value iotaM = + rewriter.create(loc, iota_type, rewriter.getI64IntegerAttr(1)); + Value iotaN = + rewriter.create(loc, iota_type, rewriter.getI64IntegerAttr(2)); + + // Boradcasted constants, of the same shape as iotaM and iotaN. + Value b_zero = BroadcastConstant(loc, indices_shape, 0, 32, rewriter); + Value b_false = BroadcastConstant(loc, indices_shape, 0, 1, rewriter); + Value b_true = BroadcastConstant(loc, indices_shape, 1, 1, rewriter); + Value b_k1 = BroadcastConstant(loc, indices_shape, k[1], 32, rewriter); + Value b_rows = BroadcastConstant(loc, indices_shape, rows, 32, rewriter); + Value b_cols = BroadcastConstant(loc, indices_shape, cols, 32, rewriter); + Value b_max_diag_len = + BroadcastConstant(loc, indices_shape, max_diag_len, 32, rewriter); + + // d = k[1] - m + // (A.k.a. the number of the diagonal, depending on m. Note that we + // subtract m here. This means we start with the superdiagonals and + // move downwards towards the subdiagonals. So the start indices will + // be decreasing.) + Value d = rewriter.create(loc, b_k1, iotaM); + Value neg_d = rewriter.create(loc, d); + + // diag_len_d = min(rows + min(d, 0), cols - max(d, 0)) + // (Length of a diagonal for a given d. Same as max_diag_len for m = 0.) + Value diag_len_d = rewriter.create( + loc, + rewriter.create(loc, b_rows, + rewriter.create(loc, d, b_zero)), + rewriter.create(loc, b_cols, + rewriter.create(loc, d, b_zero))); + + // offset is max_diag_len - diag_len_d if we're padding, 0 otherwise. + Value cmp; + if (subdiagonal_align == kRight && superdiagonal_align == kRight) { + cmp = b_true; + } else if (superdiagonal_align == kRight) { + // offset = d>=0 ? max_diag_len - diag_len_d : 0 + cmp = rewriter.create(loc, d, b_zero); + } else if (subdiagonal_align == kRight) { + // offset = d<=0 ? max_diag_len - diag_len_d : 0 + cmp = rewriter.create(loc, d, b_zero); + } else { + // offset = 0 + cmp = b_false; + } + + // This offset shifts the diagonals to the "left" or "right", depending + // on alignment. + Value offset = rewriter.create( + loc, b_zero.getType(), cmp, + rewriter.create(loc, b_max_diag_len, diag_len_d), b_zero); + + // x = max(d, 0) - offset + // y = max(-d, 0) - offset + Value x = rewriter.create( + loc, rewriter.create(loc, d, b_zero), offset); + Value y = rewriter.create( + loc, rewriter.create(loc, neg_d, b_zero), offset); + + Value n_plus_x = rewriter.create(loc, iotaN, x); + Value n_plus_y = rewriter.create(loc, iotaN, y); + + // GatherOp is happy about letting us index out of bounds values, but those + // values will be undefined. So we mask them later. Set up the boolean + // expression that tells us which entries, in the output shape, are out of + // bounds and thus become the padding_value. + Value x_in_bounds = rewriter.create( + loc, + rewriter.create(loc, b_false.getType(), n_plus_x, + b_zero), + rewriter.create(loc, b_false.getType(), n_plus_x, b_cols)); + Value y_in_bounds = rewriter.create( + loc, + rewriter.create(loc, b_false.getType(), n_plus_y, + b_zero), + rewriter.create(loc, b_false.getType(), n_plus_y, b_rows)); + Value in_bounds = rewriter.create( + loc, + tensorflow::GetTypeFromTFTensorShape(Shape({num_diags, max_diag_len}), + rewriter.getIntegerType(1)), + rewriter.create(loc, x_in_bounds, y_in_bounds)); + + // Now combine x and y into the index data structure needed for gather. + Shape concat_shape({2, num_diags, max_diag_len}); + Value start_indices = rewriter.create( + loc, + tensorflow::GetTypeFromTFTensorShape(concat_shape, + rewriter.getIntegerType(32)), + mlir::ValueRange({n_plus_y, n_plus_x}), + mlir::IntegerAttr::get(rewriter.getIntegerType(64), 0)); + + // Shape of the final output. (Except for dimension folding in the + // single diagonal case.) + Shape output_shape; + for (int i = 0; i < num_dims - 2; i++) { + output_shape.push_back(input_type.getDimSize(i)); + } + output_shape.push_back(num_diags); + output_shape.push_back(max_diag_len); + + // A slice is the shape of what GatherOp copies per lookup. So the last + // two dimensions (M, N in the matrix-diag-part docs) are where we go + // through entry by entry. + ArrayRef input_shape = input_type.getShape(); + int input_shape_size = input_shape.size(); + Shape slice_sizes(input_shape.begin(), input_shape.end()); + int slice_dimensions = slice_sizes.size(); + slice_sizes[slice_dimensions - 2] = + std::min((int64_t)1, input_shape[input_shape_size - 2]); + slice_sizes[slice_dimensions - 1] = + std::min((int64_t)1, input_shape[input_shape_size - 1]); + + // Dimensions of the input we won't see in the output (M and N). + SmallVector collapsed_dims( + {slice_dimensions - 2, slice_dimensions - 1}); + + // Which dimensions (in the input) the two offset "columns" map to. + SmallVector start_index_map({num_dims - 2, num_dims - 1}); + + // Gather the diagonal entries. + // TODO(kramm): For a single diagonal, this might be slower than the + // mask + sum approach. Special-case num_diags==1? + auto dims_attr = GatherDimensionNumbersAttr::get( + rewriter.getContext(), + /*offsetDims=*/llvm::to_vector<4>(llvm::seq(0, num_dims - 2)), + /*collapsedSliceDims=*/collapsed_dims, + /*operandBatchingDims=*/{}, + /*startIndicesBatchingDims=*/{}, start_index_map, + /*indexVectorDim=*/0); + Value gather = rewriter.create( + loc, op.getInput(), start_indices, dims_attr, + GetI64ElementsAttr(slice_sizes, &rewriter)); + + // We now need to broadcast the "in_bounds" boolean expression, as well as + // the padding value, to do the final select. + Shape broadcast_bounds; + for (int i = 0; i < output_shape.size() - 2; i++) { + broadcast_bounds.push_back(output_shape[i]); + } + Value b_in_bounds = rewriter.create( + loc, + tensorflow::GetTypeFromTFTensorShape(output_shape, + rewriter.getIntegerType(1)), + in_bounds, GetI64ElementsAttr(broadcast_bounds, &rewriter)); + Value b_padding = rewriter.create( + loc, op.getPaddingValue(), GetI64ElementsAttr(output_shape, &rewriter)); + + // Replace all out-of-bounds values in the result with padding_value. + Value result = + rewriter.create(loc, b_in_bounds, gather, b_padding); + + if (num_diags == 1) { + // matrix_diag_part folds away the 1-sized band dimension if we only + // extract a single diagonal. + result = rewriter.create(loc, op.getType(), result); + } + + rewriter.replaceOp(op, result); + return success(); + } +}; + +// Converts TensorFlow EinsumOp to HLO EinsumOp +class ConvertEinsumOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TF::EinsumOp op, + PatternRewriter &rewriter) const override { + // Prepend `,` to equation if unary einsum. + std::string equation_str = op.getEquation().str(); + llvm::SmallVector inputs; + + // Unary einsum prepends `,` to equation and + // creates a scalar constant 1.0 for first operand. + if (op.getN() == 1) { + equation_str = "," + equation_str; + inputs.push_back(rewriter.create( + op.getLoc(), hlo::getScalarOfType( + mlir::getElementTypeOrSelf(op.getOperand(0)), 1))); + } + // Insert remaining operands into inputs, TF op verifier requires there be + // 0 or 1 operands. + auto operands = op.getInputs(); + inputs.insert(inputs.end(), operands.begin(), operands.end()); + assert(inputs.size() == 2); + + rewriter.replaceOpWithNewOp(op, op.getType(), inputs[0], + inputs[1], equation_str); + return success(); + } +}; + +// Bypasses IdentityN op. +class ConvertIdentityNOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(TF::IdentityNOp op, + PatternRewriter &rewriter) const override { + rewriter.replaceOp(op, op.getOperands()); + return success(); + } +}; + +template +class ConvertFFTOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(OpTy op, + PatternRewriter &rewriter) const override { + auto input_ty = mlir::cast(op.getInput().getType()); + if (!input_ty.hasRank()) { + return failure(); + } + auto input_shape = input_ty.getShape(); + DenseIntElementsAttr fft_length_attr; + if (!matchPattern(op.getFftLength(), m_Constant(&fft_length_attr))) { + return failure(); + } + int64_t fft_length; + if (fft_length_attr.getNumElements() != 0) { + fft_length = fft_length_attr.getValues()[0].getInt(); + } else { + return failure(); + } + + int64_t expected_dim = fft_length; + std::string fft_string = "RFFT"; + if (typeid(OpTy) == typeid(TF::IRFFTOp)) { + expected_dim = fft_length / 2 + 1; + fft_string = "IRFFT"; + } + Location loc = op.getLoc(); + + // The inner-most dim cannot be dynamic. + if (input_ty.isDynamicDim(input_shape.size() - 1)) { + return failure(); + } + + auto expected_shape = llvm::to_vector<4>(input_shape.drop_back()); + expected_shape.push_back(expected_dim); + + // Zero pad or truncate the last axis + Value reshaped = op.getInput(); + SmallVector begin_indices(input_shape.size(), 0); + SmallVector strides(input_shape.size(), 1); + + // Last dim larger than expected_dim, slice the input + if (input_shape.back() > expected_dim) { + reshaped = rewriter.create( + op.getLoc(), + tensorflow::GetTypeFromTFTensorShape(expected_shape, + input_ty.getElementType()), + op.getInput(), GetI64ElementsAttr(begin_indices, &rewriter), + GetI64ElementsAttr(expected_shape, &rewriter), + GetI64ElementsAttr(strides, &rewriter)); + + // Last dim smaller than expected_dim, zero-pad the input + } else if (input_ty.getShape().back() < expected_dim) { + SmallVector no_padding(input_shape.size(), 0); + SmallVector padding(input_shape.size() - 1, 0); + padding.push_back(expected_dim - input_shape.back()); + Value zero = + GetScalarConstOfType(input_ty.getElementType(), loc, 0, &rewriter); + reshaped = rewriter.create( + loc, + tensorflow::GetTypeFromTFTensorShape(expected_shape, + input_ty.getElementType()), + op.getInput(), zero, GetI64ElementsAttr(no_padding, &rewriter), + GetI64ElementsAttr(padding, &rewriter), + GetI64ElementsAttr(no_padding, &rewriter)); + } + + rewriter.replaceOpWithNewOp( + op, op.getType(), reshaped, + FftTypeAttr::get(rewriter.getContext(), + symbolizeFftType(fft_string).value()), + rewriter.getI64TensorAttr(fft_length)); + return success(); + } +}; + +using ConvertRFFTOp = ConvertFFTOp; +using ConvertIRFFTOp = ConvertFFTOp; + +// The base class to convert TensorFlow FusedBatchNormGrad*Op to HLO +// BatchNormGradOp for training and a sequence of binary ops for inference. +// TODO(b/145536565): move to legalize_tf_patterns.td if it applies. +template +class ConvertFusedBatchNormGradBase + : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(FusedBatchNormGradOpT op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value grad = op.getYBackprop(); + Value act = op.getX(); + Value scale = op.getScale(); + Value mean = op.getReserveSpace_1(); + Value var = op.getReserveSpace_2(); + + // TODO(b/141785544): Update this to not require static shapes. + // activation shape needs to be static to convert negative indices in + // TensorFlow to absolute indices required by HLO. + RankedTensorType act_type = mlir::dyn_cast(act.getType()); + if (!act_type) return failure(); + Type act_ele_type = act_type.getElementType(); + // To support mixed precision, the statistics type, which maybe more + // precise than the input types, are used for this op. + Type kernel_type = mlir::cast(scale.getType()).getElementType(); + grad = rewriter.create(loc, grad, kernel_type); + act = rewriter.create(loc, act, kernel_type); + + tensorflow::TensorFormat data_format; + if (!FormatFromString(op.getDataFormat().str(), &data_format)) + return op.emitOpError("invalid data format"); + + auto feature_dim_attr = getFeatureDimensionAttr(rewriter, data_format, act); + auto feature_dim = feature_dim_attr.getValue().getSExtValue(); + + // Gets the result values. + Value x_backprop, scale_backprop, offset_backprop; + if (op.getIsTraining()) { // training + // TODO(b/145536565): handle GPU logic separately. + // Infers the output type with the converted `act`. + Type feature_type = tensorflow::GetTypeFromTFTensorShape( + {GetDimSize(act_type, feature_dim)}, kernel_type); + + SmallVector operand_types = {act.getType(), feature_type, + feature_type}; + auto training_op = rewriter.create( + loc, operand_types, act, scale, mean, var, grad, op.getEpsilon(), + feature_dim); + + x_backprop = training_op.getResult(0); + + scale_backprop = training_op.getResult(1); + + offset_backprop = training_op.getResult(2); + } else { // inference + SmallVector non_feature_dims; + for (int64_t i = 0; i < act_type.getRank(); ++i) { + if (i == feature_dim) continue; + non_feature_dims.push_back(i); + } + auto reduce_dims = GetI64ElementsAttr(non_feature_dims, &rewriter); + auto scalar_broadcast_dims = rewriter.getDenseI64ArrayAttr({}); + + // scratch1 = rsqrt(var + epsilon) + RankedTensorType scalar_float = + tensorflow::GetTypeFromTFTensorShape({}, kernel_type); + auto epsilon = rewriter.create( + loc, DenseFPElementsAttr::get(scalar_float, {op.getEpsilon()})); + auto add_op = rewriter.create( + loc, var, epsilon.getResult(), scalar_broadcast_dims); + + Value scratch1 = rewriter.create(loc, add_op); + + // scratch2 = sum(y_backprop * (x - mean)) + auto sub_op = rewriter.create( + loc, act, + Broadcast1DToFeatureDim(loc, act, mean, feature_dim, rewriter)); + auto weighted_grad = rewriter.create(loc, grad, sub_op); + Value scratch2 = + ApplyReduction(loc, weighted_grad, reduce_dims, &rewriter); + + // x_backprop = y_backprop * (scale * scratch1) + auto scaled_grad = + rewriter.create(loc, op.getScale(), scratch1); + x_backprop = rewriter.create( + loc, grad, + Broadcast1DToFeatureDim(loc, act, scaled_grad, feature_dim, + rewriter)); + + // scale_backprop = scratch2 * scratch1 + scale_backprop = rewriter.create(loc, scratch1, scratch2); + + // offset_backprop = sum(y_backprop) + offset_backprop = ApplyReduction(loc, grad, reduce_dims, &rewriter); + } + + x_backprop = rewriter.create(loc, x_backprop, act_ele_type); + Value last_val[2]; + if (op.getResult(3).use_empty() && op.getResult(4).use_empty()) { + // It doesn't matter what values we provide for the last 2 results. + last_val[0] = last_val[1] = op.getX(); + } else { + auto const_val = rewriter.create( + op.getLoc(), DenseElementsAttr::get( + tensorflow::GetTypeFromTFTensorShape( + {0}, getElementTypeOrSelf(op.getResult(3))), + 0.0)); + auto maybe_cast = [&](Value val, Type t) -> Value { + if (val.getType() == t) return val; + return rewriter.create(op.getLoc(), t, val); + }; + last_val[0] = maybe_cast(const_val, op.getResult(3).getType()); + last_val[1] = maybe_cast(const_val, op.getResult(4).getType()); + } + rewriter.replaceOp( + op, {/*x_backprop=*/x_backprop, + /*scale_backprop=*/scale_backprop, + /*offset_backprop=*/offset_backprop, last_val[0], last_val[1]}); + return success(); + } +}; + +using ConvertFusedBatchNormGradOp = + ConvertFusedBatchNormGradBase; +using ConvertFusedBatchNormGradV2Op = + ConvertFusedBatchNormGradBase; +using ConvertFusedBatchNormGradV3Op = + ConvertFusedBatchNormGradBase; + +// Converts TensorFlow FusedBatchNormV3Op to either HLO BatchNormTrainingOp or +// HLO BatchNormInferenceOp, depending on the value of the 'is_training' +// parameter. +template +class ConvertFusedBatchNormBase : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(FusedBatchNormOpT op, + PatternRewriter &rewriter) const override { + tensorflow::TensorFormat data_format; + if (!FormatFromString(op.getDataFormat().str(), &data_format)) + return op.emitOpError("invalid data format"); + + auto feature_dim = + getFeatureDimensionAttr(rewriter, data_format, op.getX()); + + auto input_type_tensor = mlir::cast(op.getX().getType()); + auto input_element_type = input_type_tensor.getElementType(); + + auto scale_type_tensor = mlir::cast(op.getScale().getType()); + auto scale_element_type = scale_type_tensor.getElementType(); + + auto mean_type_tensor = mlir::cast(op.getMean().getType()); + auto mean_element_type = mean_type_tensor.getElementType(); + // In the training case, dimensions of input tensors must be static. + if (op.getIsTraining() && (!input_type_tensor.hasStaticShape() || + !scale_type_tensor.hasStaticShape() || + !mean_type_tensor.hasStaticShape())) + return failure(); + + // TODO(b/69928690): Support mixed precision in the XLA batch + // normalization operators. As a workaround, create a new x with the same + // element type as scale (which may be more precise than the input type). + Value bn_train_input = rewriter.create( + op.getLoc(), op.getX(), scale_element_type); + TensorType bn_train_input_type_tensor = + mlir::cast(bn_train_input.getType()); + + if (op.getIsTraining()) { + // Training case. + auto operand_shape = bn_train_input_type_tensor.getShape(); + // The mean and variance are each 1 dimensional arrays the size of the + // feature dimension, with the same element type as the operand (x). + // This shape must be constructed manually because the mean and variance + // inputs are empty in the training case. + Type mean_var_type = tensorflow::GetTypeFromTFTensorShape( + {operand_shape[feature_dim.getInt()]}, scale_element_type); + // Op result type is a tuple of 3 values: output with same shape as input; + // batch_mean, and batch_var. + SmallVector operand_types = {bn_train_input_type_tensor, + mean_var_type, mean_var_type}; + auto bn_train_op = rewriter.create( + op.getLoc(), operand_types, bn_train_input, op.getScale(), + op.getOffset(), op.getEpsilon(), feature_dim.getInt()); + // HLO op outputs a tuple of tensors. Extract those results. + Value y_out = bn_train_op.getResult(0); + Value batch_mean = bn_train_op.getResult(1); + Value reserve_space_1 = batch_mean; + Value batch_variance = bn_train_op.getResult(2); + + // Apply Bessel's correction on the variance. + int total_input_size = bn_train_input_type_tensor.getNumElements(); + int total_scale_size = scale_type_tensor.getNumElements(); + int sample_size = + total_scale_size > 0 ? total_input_size / total_scale_size : 0; + int sample_size_minus_one = std::max(1, sample_size - 1); + double factor = static_cast(sample_size) / + static_cast(sample_size_minus_one); + auto factor_const_op = rewriter.create( + op.getLoc(), rewriter.getFloatAttr(scale_element_type, factor)); + + Value corrected_variance = rewriter.create( + op.getLoc(), batch_variance.getType(), batch_variance, + factor_const_op, /*broadcast_dimensions=*/DenseI64ArrayAttr()); + + // Convert back to input type to stay aligned with expected output type + // for TF op. + y_out = rewriter.create(op.getLoc(), y_out, + input_element_type); + + float exponential_avg_factor = + op.getExponentialAvgFactor().convertToFloat(); + if (exponential_avg_factor != 1.0f) { + auto alpha = rewriter.create( + op.getLoc(), rewriter.getFloatAttr(mean_element_type, + 1.0f - exponential_avg_factor)); + auto beta = rewriter.create( + op.getLoc(), + rewriter.getFloatAttr(mean_element_type, exponential_avg_factor)); + + // new_running_mean = alpha * old_mean + beta * batch_mean. + auto alpha_mul_old_mean = rewriter.create( + op.getLoc(), op.getMean().getType(), alpha, op.getMean(), + /*broadcast_dimensions=*/DenseI64ArrayAttr()); + auto beta_mul_batch_mean = rewriter.create( + op.getLoc(), batch_mean.getType(), beta, batch_mean, + /*broadcast_dimensions=*/DenseI64ArrayAttr()); + batch_mean = rewriter.create( + op.getLoc(), alpha_mul_old_mean, beta_mul_batch_mean, + /*broadcast_dimensions=*/DenseI64ArrayAttr()); + + // new_running_variance = alpha * old_variance + beta * batch_variance. + auto alpha_mul_old_variance = rewriter.create( + op.getLoc(), op.getVariance().getType(), alpha, op.getVariance(), + /*broadcast_dimensions=*/DenseI64ArrayAttr()); + auto beta_mul_batch_variance = rewriter.create( + op.getLoc(), corrected_variance.getType(), beta, corrected_variance, + /*broadcast_dimensions=*/DenseI64ArrayAttr()); + corrected_variance = rewriter.create( + op.getLoc(), alpha_mul_old_variance, beta_mul_batch_variance, + /*broadcast_dimensions=*/DenseI64ArrayAttr()); + } + + if (std::is_same::value) { + // FusedBatchNormV2 expects 4 outputs. + // Outputs 3 and 4 are currently marked as "reserved spaces 1 and 2". + // They are used to pass the per-batch mean and variance to the + // gradiant. Here we maintain the same behavior by setting them to the + // mean and variance calculated by BatchNormTraining. + rewriter.replaceOp(op, {y_out, /*batch_mean=*/batch_mean, + /*batch_variance=*/corrected_variance, + /*reserve_space_1=*/reserve_space_1, + /*reserve_space_2=*/batch_variance}); + } else { // TF::FusedBatchNormV3Op + // For FusedBatchNormV3Op, also create a constant tensor to forward to + // last reserve_space_3 output. + auto reserve_space_3_type = + mlir::cast(op.getResult(5).getType()); + int num_elements = reserve_space_3_type.hasStaticShape() + ? reserve_space_3_type.getNumElements() + : 0; + auto const_attr_type = tensorflow::GetTypeFromTFTensorShape( + {num_elements}, getElementTypeOrSelf(reserve_space_3_type)); + Value dummy_const = rewriter.create( + op.getLoc(), DenseElementsAttr::get(const_attr_type, 0.0)); + if (const_attr_type != reserve_space_3_type) + dummy_const = rewriter.create( + op.getLoc(), reserve_space_3_type, dummy_const); + rewriter.replaceOp(op, {y_out, /*batch_mean=*/batch_mean, + /*batch_variance=*/corrected_variance, + /*reserve_space_1=*/reserve_space_1, + /*reserve_space_2=*/batch_variance, + /*reserve_space_3=*/dummy_const}); + } + } else { // Inference case. + auto bn_train_op = rewriter.create( + op.getLoc(), + /*result_type=*/bn_train_input_type_tensor, bn_train_input, + op.getScale(), op.getOffset(), op.getMean(), op.getVariance(), + op.getEpsilon(), feature_dim.getInt()); + + // Convert back to input type to stay aligned with expected output type + // for TF op. + auto y_out = rewriter.create(op.getLoc(), bn_train_op, + input_element_type); + + // The mean, variance, and reserved space outputs of the batch norm op are + // not used for inference. It doesn't matter what values we provide for + // the last 5 results as long as they are of the same type. Forward + // input mean and variance to output mean, variance, reserved_space_1 and + // reserved_space_2. + if (std::is_same::value) { + rewriter.replaceOp(op, {/*y=*/y_out, + /*batch_mean=*/op.getMean(), + /*batch_variance=*/op.getVariance(), + /*reserve_space_1=*/op.getMean(), + /*reserve_space_2=*/op.getVariance()}); + } else { + // For FusedBatchNormV3Op, also create a constant tensor to forward to + // last reserve_space_3 output. + auto reserve_space_3_type = + mlir::cast(op.getResult(5).getType()); + int num_elements = reserve_space_3_type.hasStaticShape() + ? reserve_space_3_type.getNumElements() + : 0; + auto const_attr_type = tensorflow::GetTypeFromTFTensorShape( + {num_elements}, getElementTypeOrSelf(reserve_space_3_type)); + Value dummy_const = rewriter.create( + op.getLoc(), DenseElementsAttr::get(const_attr_type, 0.0)); + if (const_attr_type != reserve_space_3_type) + dummy_const = rewriter.create( + op.getLoc(), reserve_space_3_type, dummy_const); + rewriter.replaceOp(op, {/*y=*/y_out, + /*batch_mean=*/op.getMean(), + /*batch_variance=*/op.getVariance(), + /*reserve_space_1=*/op.getMean(), + /*reserve_space_2=*/op.getVariance(), + /*reserve_space_3=*/dummy_const}); + } + } + return success(); + } +}; + +using ConvertFusedBatchNormV2Op = + ConvertFusedBatchNormBase; +using ConvertFusedBatchNormV3Op = + ConvertFusedBatchNormBase; + +using PaddingArray = std::vector>; + +// Returns padding values for ReduceWindow op as a vector of pairs. +// +// Requires padding to be either 'SAME' or 'VALID' and the number of input +// dimensions to be equal to the size of window dimensions and window strides. +template +static PaddingArray GetReduceWindowPaddingAsArray( + llvm::ArrayRef input_dims, ArrayAttr window_dims, + ArrayAttr window_strides, StringRef padding, Builder *builder) { + if (padding == "VALID") { + return PaddingArray(num_dims, std::make_pair(0, 0)); + } + assert(padding == "SAME"); + llvm::SmallVector input_shape, window_shape, strides; + input_shape.reserve(input_dims.size()); + window_shape.reserve(window_shape.size()); + strides.reserve(window_strides.size()); + + for (const auto &dim : input_dims) input_shape.push_back(dim); + for (Attribute attr : window_dims) + window_shape.push_back(mlir::cast(attr).getInt()); + for (Attribute attr : window_strides) + strides.push_back(mlir::cast(attr).getInt()); + + PaddingArray paddings = ::xla::MakePadding(input_shape, window_shape, strides, + ::xla::Padding::kSame); + return paddings; +} + +// Same as GetReduceWindowPaddingAsArray but returns padding as +// DenseIntElementsAttr. Returns empty attribute for `VALID` padding. +template +static DenseIntElementsAttr GetReduceWindowPaddingAsAttr( + llvm::ArrayRef input_dims, ArrayAttr window_dims, + ArrayAttr window_strides, StringRef padding, Builder *builder) { + if (padding == "VALID") return {}; + assert(padding == "SAME"); + PaddingArray paddings = GetReduceWindowPaddingAsArray( + input_dims, window_dims, window_strides, padding, builder); + int64_t rank = paddings.size(); + llvm::SmallVector flatten_paddings(rank * 2); + for (int i = 0; i < rank; i++) { + flatten_paddings[2 * i] = paddings[i].first; + flatten_paddings[2 * i + 1] = paddings[i].second; + } + return DenseIntElementsAttr::get(tensorflow::GetTypeFromTFTensorShape( + {rank, 2}, builder->getIntegerType(64)), + flatten_paddings); +} + +// Helper function for dividing each entry of `pooled` by the count of its +// corresponding window, i.e., the number of non-padding entries of the window +// which an `AvgPool` operation performed on an `input_shape`-tensor would map +// to this entry, depending on `ksize` and `strides`. This function is used for +// `AvgPool` and `AvgPoolGrad` legalizations. +// `zero` is passed as a parameter because it can be reused from caller level. +// `pooled` must have `RankedTensorType`. +template +Operation *AvgPoolDivideByCount( + Value pooled, const SmallVector &input_shape, + const SmallVector &ksize, + const SmallVector &strides, OpTy op, Value zero, + PatternRewriter &rewriter) { + Location loc = op.getLoc(); + RankedTensorType pooled_type = mlir::cast(pooled.getType()); + Type element_type = pooled_type.getElementType(); + Operation *result = nullptr; + RankedTensorType orig_input_type = + tensorflow::GetTypeFromTFTensorShape(input_shape, element_type); + + if (op.getPadding() == "VALID") { + // All window counts are equal here because we don't have padding + // (each entry of `pooled` corresponds to a window that consists of + // original input entries only). + int64_t window_count = std::accumulate(ksize.begin(), ksize.end(), 1, + std::multiplies()); + // Divide `pooled` by window counts. + Value divisor = + GetScalarConstOfType(element_type, loc, window_count, &rewriter); + auto scalar_broadcast_dims = rewriter.getDenseI64ArrayAttr({}); + result = rewriter.create( + loc, pooled_type, pooled, divisor, scalar_broadcast_dims); + } else { + assert(op.getPadding() == "SAME"); + // For SAME padding, only original entries that contributed to a window + // are counted for the average of this window, not padded entries. + + // Build all-ones tensor of same shape as the original input. + ElementsAttr splat = hlo::getSplat(&rewriter, orig_input_type, 1); + auto all_ones_tensor = rewriter.create(loc, splat); + + // Get padding for the input. + DenseIntElementsAttr input_padding_attr = + GetReduceWindowPaddingAsAttr(input_shape, op.getKsize(), + op.getStrides(), op.getPadding(), + &rewriter); + + // Count the 1's in each window, using the same padding as for the input, + // which gives us the window counts by which `pooled` needs to be divided. + auto divisor = rewriter.create( + loc, pooled_type, + /*operand=*/all_ones_tensor, + /*init_value=*/zero, + /*window_dimensions=*/GetI64ElementsAttr(op.getKsize()), + /*window_strides=*/GetI64ElementsAttr(op.getStrides()), + /*base_dilations=*/DenseIntElementsAttr(), + /*window_dilations=*/DenseIntElementsAttr(), + /*padding=*/input_padding_attr); + BuildReduceBody(element_type, &divisor.getBody(), &rewriter); + + // Divide `pooled` by window counts. + result = rewriter.create(loc, pooled_type, pooled, + divisor.getResult(0)); + } + return result; +} + +Value GetAvgPoolInput(TF::AvgPoolOp op) { return op.getValue(); } +Value GetAvgPoolInput(TF::AvgPool3DOp op) { return op.getInput(); } + +// Converts AvgPool op to HLO ReduceWindow op by setting appropriate window +// dimensions with add as the reduction function. The reduction result is +// then divided by the number of elements in the window. +template +class ConvertAvgPoolOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(OpTy op, + PatternRewriter &rewriter) const override { + Value input_value = GetAvgPoolInput(op); + auto input_type = mlir::dyn_cast(input_value.getType()); + if (!input_type) return failure(); + + // We will do accumulation first; use a larger bitwidth if suitable. + Type input_element_type = input_type.getElementType(); + Type sum_element_type = GetSumAccumulationType(input_element_type); + Type result_type; + + // The result type for reduction and division with the proper element type. + if (auto ranked_type = mlir::dyn_cast(op.getType())) + result_type = tensorflow::GetTypeFromTFTensorShape(ranked_type.getShape(), + sum_element_type); + else + result_type = UnrankedTensorType::get(sum_element_type); + + // Convert if we need enlarge the element type's bitwidth. + if (input_element_type != sum_element_type) + input_value = rewriter.create(op.getLoc(), input_value, + sum_element_type); + + // Create the ReduceWindow op. + Value init = + GetScalarConstOfType(sum_element_type, op.getLoc(), 0, &rewriter); + DenseIntElementsAttr paddings_attr = GetReduceWindowPaddingAsAttr( + input_type.getShape(), op.getKsize(), op.getStrides(), op.getPadding(), + &rewriter); + auto reduce = rewriter.create( + op.getLoc(), result_type, input_value, init, + GetI64ElementsAttr(op.getKsize()), GetI64ElementsAttr(op.getStrides()), + /*base_dilations=*/DenseIntElementsAttr(), + /*window_dilations=*/DenseIntElementsAttr(), paddings_attr); + BuildReduceBody(sum_element_type, &reduce.getBody(), &rewriter); + + // Count the number of elements in the window. The following calculation + // is only valid for no paddings. + SmallVector input_shape( + llvm::to_vector(input_type.getShape())); + SmallVector ksize, strides; + GetI64ArrayAttrValues(op.getKsize(), &ksize); + GetI64ArrayAttrValues(op.getStrides(), &strides); + + Operation *result_op = AvgPoolDivideByCount( + reduce.getResult(0), input_shape, ksize, strides, op, init, rewriter); + + // Convert back if we enlarged the element type's bitwidth. + Value result = result_op->getOpResult(0); + if (input_element_type != sum_element_type) + result = + rewriter.create(op.getLoc(), result, input_element_type); + + rewriter.replaceOp(op, result); + return success(); + } +}; + +using ConvertAvgPool2DOp = ConvertAvgPoolOp; +using ConvertAvgPool3DOp = ConvertAvgPoolOp; + +// `AvgPoolGradOp` is converted to the following operations: +// 1. Divide each entry of the output gradient (the gradient for the previous +// layer in backpropagation order) by the count of the corresponding window +// (i.e., the number of non-padding entries of the window which `AvgPool` +// has mapped to this entry in forward propagation). +// 2. Add appropriate interior and exterior padding for step 3 (see example +// below). +// 3. Convolve the result of step 2. with a kernel consisting of 1's (same shape +// as windows) and stride 1 in each dimension. This is implemented as a +// `ReduceWindowOp` with `AddOp` as body. +// +// Example: +// Let f : R^4 -> R^2 be an average pool function with window size 3, stride 2, +// and SAME padding with 0's. It is defined by +// f(x) = [ (x_1 + x_2 + x_3) / 3 ] ( x = (x_1, x_2, x_3, x_4) ) +// [ (x_3 + x_4 + 0) / 2 ] (the 0 results from right padding) +// Note that for SAME padding in `AvgPool` the padded entries are not counted +// for the average, this is why the second denominator is 2 and not 3. +// The Jacobian Df is +// [ 1/3 1/3 1/3 0 ] +// [ 0 0 1/2 1/2 ] +// +// Note that the Jacobian is constant (this is why `ConvertAvgPoolGradOp` only +// needs the original input shape and not the tensor as argument). +// Let v = [ 4 6 ]^T be the output gradient (^T = transposed). Then the +// average pool gradient is given by +// Df^T * v = [ 4/3 4/3 13/3 3 ]^T +// Instead of a matrix-vector-multiplication we can utilize the sparsity and +// structure of Df by using the 3-step approach from above: +// 1. Divide output gradient v by window counts: [ 4/3 6/2 ]^T +// 2. Add appropriate padding: [ 0 0 4/3 0 3 0 ]^T +// 3. Convolve with kernel [ 1 1 1 ]: [ 4/3 4/3 11/3 3 ]^T +// +// Note that the padding in step 2. is chosen in such a way that the subsequent +// convolution produces the gradient. Higher dimensions, different padding, and +// different windows/strides work in a similar way, the main difference is in +// the computation of the paddings in step 2. +// +// For more details on backpropagation for convolution of which `AvgPoolGrad` +// is a special case see `tensorflow/core/kernels/conv_grad_ops.h`. +// `tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir` has more +// examples for different cases. +template +class ConvertAvgPoolGradOp : public OpRewritePattern { + using DimVector = SmallVector; + + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(OpTy op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + tensorflow::TensorFormat data_format; + if (!FormatFromString(op.getDataFormat().str(), &data_format)) { + return op.emitOpError("invalid data format"); + } + // `out_grad` is the gradient that was propagated via backpropagation from + // the output layer. + Value out_grad = op.getGrad(); + auto out_grad_type = mlir::dyn_cast(out_grad.getType()); + if (!out_grad_type) { + return failure(); + } + Type element_type = out_grad_type.getElementType(); + DenseIntElementsAttr orig_input_shape_attr; + if (!matchPattern(op.getOrigInputShape(), + m_Constant(&orig_input_shape_attr))) { + return failure(); + } + auto orig_input_shape_values = orig_input_shape_attr.getValues(); + DimVector orig_input_shape(orig_input_shape_values.begin(), + orig_input_shape_values.end()); + DimVector ksize, strides; + GetI64ArrayAttrValues(op.getKsize(), &ksize); + GetI64ArrayAttrValues(op.getStrides(), &strides); + Value zero = GetScalarConstOfType(element_type, loc, 0, &rewriter); + + auto out_grad_divided = AvgPoolDivideByCount( + out_grad, orig_input_shape, ksize, strides, op, zero, rewriter); + + // Get same padding as for original input. + PaddingArray orig_padding = GetReduceWindowPaddingAsArray( + orig_input_shape, op.getKsize(), op.getStrides(), op.getPadding(), + &rewriter); + + // Add padding around `out_grad_divided` values in such a way that the + // subsequent `ReduceWindowOp` produces the gradient. + DimVector out_grad_shape( + llvm::to_vector(out_grad_type.getShape())); + DimVector low_padding(num_dims, 0); + DimVector high_padding(num_dims, 0); + DimVector interior_padding(num_dims, 0); + constexpr int num_spatial_dims = num_dims - 2; + for (int i = 0; i < num_spatial_dims; ++i) { + int dim = tensorflow::GetTensorSpatialDimIndex(num_dims, data_format, i); + int orig_input_shape_padded_in_dim = orig_input_shape[dim] + + orig_padding[dim].first + + orig_padding[dim].second; + // Set interior padding such that neighboring entries from + // `out_grad_divided` have distance `strides[dim]` from each other in + // every dimension. + interior_padding[dim] = strides[dim] - 1; + // Set exterior padding in the same way as for convolution gradient + // computation. + auto status = ::xla::ConvGradExtractAndVerifyDimension( + /*input_size=*/orig_input_shape_padded_in_dim, + /*filter_size=*/ksize[dim], + /*output_size=*/out_grad_shape[dim], + /*dilation=*/1, + /*stride=*/strides[dim], + /*padding=*/::xla::Padding::kValid); + if (!status.ok()) { + return failure(); + } + ::xla::SpatialDimensionOutputSizeAndPadding &conv_grad_spatial_dim = + status.value(); + // Subtract the original exterior padding since it doesn't contribute to + // the gradient. Note that we save one `PadOp` and some unnecessary kernel + // computations, compared to the `xla::AvgPoolGrad` implementation, by + // subtracting the original exterior padding before `ReduceWindowOp` + // instead of trimming the result of `ReduceWindowOp` (the final result is + // the same because all strides are 1). + low_padding[dim] = + conv_grad_spatial_dim.pad_before - orig_padding[dim].first; + high_padding[dim] = + conv_grad_spatial_dim.pad_after - orig_padding[dim].second; + + // Update `out_grad_shape` to result shape of following `PadOp`. + out_grad_shape[dim] = low_padding[dim] + high_padding[dim] + + (out_grad_shape[dim] - 1) * strides[dim] + 1; + } + Value reduce_window_input = rewriter.create( + loc, tensorflow::GetTypeFromTFTensorShape(out_grad_shape, element_type), + /*operand=*/out_grad_divided->getOpResult(0), + /*padding_value=*/zero, + /*edge_padding_low=*/GetI64ElementsAttr(low_padding, &rewriter), + /*edge_padding_high=*/GetI64ElementsAttr(high_padding, &rewriter), + /*interior_padding=*/GetI64ElementsAttr(interior_padding, &rewriter)); + + // Compute result by convolving `reduce_window_input` with an all-ones + // kernel, using `ReduceWindowOp` with `AddOp` body. + + Type sum_element_type = GetSumAccumulationType(element_type); + if (element_type != sum_element_type) { + // Convert to appropriate sum accumulation type to avoid precision loss. + reduce_window_input = rewriter.create(loc, reduce_window_input, + sum_element_type); + zero = GetScalarConstOfType(sum_element_type, loc, 0, &rewriter); + } + auto ones = GetI64ElementsAttr(DimVector(num_dims, 1), &rewriter); + auto reduce_window_op = rewriter.create( + loc, + tensorflow::GetTypeFromTFTensorShape(orig_input_shape, + sum_element_type), + /*operand=*/reduce_window_input, + /*init_value=*/zero, + /*window_dimensions=*/GetI64ElementsAttr(op.getKsize()), + /*window_strides=*/ones, + /*base_dilations=*/DenseIntElementsAttr(), + /*window_dilations=*/DenseIntElementsAttr(), + /*padding=*/DenseIntElementsAttr()); + BuildReduceBody(sum_element_type, &reduce_window_op.getBody(), + &rewriter); + Value result = reduce_window_op.getResult(0); + + if (element_type != sum_element_type) { + // Convert back to original element type. + result = rewriter.create(op.getLoc(), result, element_type); + } + rewriter.replaceOp(op, {result}); + return success(); + } +}; + +using ConvertAvgPool2DGradOp = + ConvertAvgPoolGradOp; +using ConvertAvgPool3DGradOp = + ConvertAvgPoolGradOp; + +// Converts MaxPool op to HLO ReduceWindow op by setting appropriate window +// dimensions with max as the reduction function. +// +// Sample result for VALID padding mode: +// +// %init = arith.constant dense<...> : tensor +// %max_pool = "mhlo.reduce"(%inp, %init) ["mhlo.maximum"] +// {window_dimensions = ..., window_strides = ... } +// +template +class ConvertMaxPoolOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(OpTy op, + PatternRewriter &rewriter) const override { + Type element_type = + mlir::cast(op.getInput().getType()).getElementType(); + if (!element_type.isSignlessIntOrFloat()) return failure(); + tensorflow::Padding padding; + if (!GetPaddingFromString(op.getPadding().str(), &padding).ok()) + return failure(); + if (padding == tensorflow::Padding::EXPLICIT) { + return failure(); + } + Location loc = op.getLoc(); + ConstantOp init = GetScalarLimitConstOfType( + element_type, loc, hlo::kInfinityLowest, &rewriter); + + auto input_ty = mlir::dyn_cast(op.getInput().getType()); + if (!input_ty) return failure(); + DenseIntElementsAttr paddings_attr = GetReduceWindowPaddingAsAttr( + input_ty.getShape(), op.getKsize(), op.getStrides(), op.getPadding(), + &rewriter); + auto reduce = rewriter.create( + loc, op.getType(), op.getInput(), init, + GetI64ElementsAttr(op.getKsize()), GetI64ElementsAttr(op.getStrides()), + /*base_dilations=*/DenseIntElementsAttr(), + /*window_dilations=*/DenseIntElementsAttr(), paddings_attr); + BuildReduceBody(element_type, &reduce.getBody(), &rewriter); + + rewriter.replaceOp(op, reduce.getResult(0)); + return success(); + } +}; + +using ConvertMaxPool2DOp = ConvertMaxPoolOp; +using ConvertMaxPool3DOp = ConvertMaxPoolOp; + +// Converts tf.Select (SelectV1) to mhlo.select. It has optional broadcasting on +// the condition only. +class ConvertSelectOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TF::SelectOp op, + PatternRewriter &rewriter) const override { + // This lowering only works on ranked types. + auto cond_type = + mlir::dyn_cast(op.getCondition().getType()); + auto then_type = + mlir::dyn_cast(op.getThenValue().getType()); + auto else_type = + mlir::dyn_cast(op.getElseValue().getType()); + if (!cond_type || !then_type || !else_type) { + return failure(); + } + + ImplicitLocOpBuilder b(op.getLoc(), rewriter); + Value cond_shape = b.createOrFold(op.getCondition()); + Value then_shape = b.createOrFold(op.getThenValue()); + Value else_shape = b.createOrFold(op.getElseValue()); + + // First check that the `then` and `else` shapes are the equal. + Value assumption = + b.createOrFold(ValueRange{then_shape, else_shape}); + // For a vector cond we also verify that the majormost dim of `then` matches + // the vector size. To do that split off the first dim of `then`. + bool needs_broadcast = cond_type.getRank() == 1 && then_type.getRank() != 1; + Value then_shape_split = then_shape; + if (needs_broadcast) { + Value const_one = b.create(1); + Type extent_first = shape::getExtentTensorType(b.getContext(), 1); + Type extent_second = + shape::getExtentTensorType(b.getContext(), then_type.getRank() - 1); + SmallVector then_split; + b.createOrFold(then_split, + TypeRange{extent_first, extent_second}, + then_shape, const_one); + then_shape_split = then_split[0]; + } + // If the condition is not a scalar, check that it matches the other shapes. + if (cond_type.getRank() > 0) { + Value eq_cstr = b.createOrFold( + ValueRange{cond_shape, then_shape_split}); + auto witness = shape::WitnessType::get(b.getContext()); + assumption = b.createOrFold( + witness, ValueRange{assumption, eq_cstr}); + } + auto result_type = mlir::cast(op.getResult().getType()); + auto assuming_op = + b.create(ArrayRef{result_type}, assumption); + + OpBuilder::InsertionGuard guard(b); + b.createBlock(&assuming_op.getDoRegion()); + + // Broadcast the cond if necessary. + Value cond = op.getCondition(); + if (needs_broadcast) { + Value result_extents = b.create( + GetExtentsTensorTypeFor(result_type), then_shape); + cond = b.create( + tensorflow::GetTypeFromTFTensorShape(result_type.getShape(), + b.getI1Type()), + cond, result_extents, + GetI64ElementsAttrForSeq(0, cond_type.getRank(), &b)); + } + Value select = b.create( + result_type, cond, op.getThenValue(), op.getElseValue()); + b.create(select); + rewriter.replaceOp(op, {assuming_op.getResult(0)}); + return success(); + } +}; + +// Converts the tf.Slice op into mhlo.real_dynamic_slice +// TODO(disc): To recover static special case's performance with folding and +// canonicalization. +class ConvertSliceOpDynamic : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TF::SliceOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value input = op.getInput(); + Value begin_indices = op.getBegin(); + Value sizes = op.getSize(); + + auto input_ty = mlir::dyn_cast(input.getType()); + auto begin_type = mlir::dyn_cast(begin_indices.getType()); + auto size_type = mlir::dyn_cast(sizes.getType()); + + if (!input_ty || !begin_type || !size_type || + !begin_type.hasStaticShape() || !size_type.hasStaticShape() || + begin_type.getRank() != 1 || size_type.getRank() != 1) { + return failure(); + } + // TODO(disc): remove static shape check once folding/canonicalization func + // added + DenseIntElementsAttr size_attr; + if (matchPattern(op.getSize(), m_Constant(&size_attr))) { + return failure(); + } + + int rank = begin_type.getDimSize(0); + auto shape_scalar_type = begin_type.getElementType(); + Value one = rewriter.create(loc, 1); + SmallVector stride_values(rank, one); + SmallVector end_values; + SmallVector begin_values; + end_values.reserve(rank); + for (int i = 0; i < rank; ++i) { + SmallVector indices; + indices.push_back(rewriter.create(loc, i)); + auto begin_value = + rewriter.create(loc, begin_indices, indices); + auto size_value = rewriter.create(loc, sizes, indices); + Value minus_one = rewriter.create( + loc, shape_scalar_type, + rewriter.create(loc, -1)); + auto is_minus_one = rewriter.create( + loc, arith::CmpIPredicate::eq, size_value, minus_one); + Value end_value = + rewriter.create(loc, begin_value, size_value); + auto dim_value = rewriter.create( + loc, shape_scalar_type, + rewriter.create(loc, input, i)); + end_value = rewriter.create(loc, is_minus_one, + dim_value, end_value); + auto end_value_casted = rewriter.create( + loc, rewriter.getIndexType(), end_value); + end_values.push_back(end_value_casted); + + auto begin_value_casted = rewriter.create( + loc, rewriter.getIndexType(), begin_value); + begin_values.push_back(begin_value_casted); + } + auto index_ty = rewriter.getIndexType(); + auto start_indices = rewriter.create( + loc, + tensorflow::GetTypeFromTFTensorShape( + {static_cast(begin_values.size())}, index_ty), + begin_values); + auto end_indices = rewriter.create( + loc, + tensorflow::GetTypeFromTFTensorShape( + {static_cast(end_values.size())}, index_ty), + end_values); + auto stride_indices = rewriter.create( + loc, + tensorflow::GetTypeFromTFTensorShape( + {static_cast(stride_values.size())}, index_ty), + stride_values); + + auto d_slice = rewriter.create( + loc, op.getOperation()->getResult(0).getType(), input, start_indices, + end_indices, stride_indices); + rewriter.replaceOp(op, d_slice.getOperation()->getResults()); + return success(); + } +}; + +static void BroadcastBatchMatMulV2Operands(Value lhs, Value rhs, Location loc, + Value *out_lhs, Value *out_rhs, + PatternRewriter *rewriter) { + // The dimension structure of the relevant operands to a tf.BatchMatMulV2 is: + // - lhs: [LHSBATCHDIMS..., LHSROWS, LHSCOLS] + // - rhs: [RHSBATCHDIMS..., RHSROWS, RHSCOLS] + // - result: [broadcast(LHSBATCHDIMS, RHSBATCHDIMS)..., LHSROWS, RHSCOLS] + // To perform the matmul, we need to first broadcast lhs and rhs to a common + // set of leading dimensions before doing the actual matmul. + // That's what the code below does. + // In particular, we populate out_lhs and out_rhs to have dimension structure: + // - out_lhs: [broadcast(LHSBATCHDIMS, RHSBATCHDIMS)..., LHSROWS, LHSCOLS] + // - out_rhs: [broadcast(LHSBATCHDIMS, RHSBATCHDIMS)..., RHSROWS, RHSCOLS] + // To do this, we need to calculate those output shapes, which involves + // slicing off the leading batch dims of each operand, broadcasting them, + // then concatenating the broadcasted leading dims back to the row/col dims. + // Finally, we create a TF::BroadcastTo op that does the actual broadcast. + + // TODO(silvasean): Reduce duplication across reified shape calculations and + // the static computation of output types needed to create ops. + Value lhs_shape = rewriter->create(loc, lhs); + Value rhs_shape = rewriter->create(loc, rhs); + Value const_neg2 = + rewriter->create(loc, rewriter->getIndexAttr(-2)); + auto shape_type = shape::ShapeType::get(rewriter->getContext()); + auto lhs_splitted = rewriter->create( + loc, TypeRange{shape_type, shape_type}, lhs_shape, const_neg2); + auto rhs_splitted = rewriter->create( + loc, TypeRange{shape_type, shape_type}, rhs_shape, const_neg2); + auto lhs_type = mlir::cast(lhs.getType()); + auto rhs_type = mlir::cast(rhs.getType()); + // The last two dimensions are the matrix row/col dimensions. Don't broadcast + // them. + SmallVector result_batch_shape_compile_time_extents; + mlir::OpTrait::util::getBroadcastedShape( + lhs_type.getShape().drop_back(2), rhs_type.getShape().drop_back(2), + result_batch_shape_compile_time_extents); + auto result_batch_shape = rewriter->create( + loc, shape_type, lhs_splitted.getHead(), rhs_splitted.getHead(), + /*error=*/nullptr); + // Lambda which handles the broadcasting of one side to the common + // leading-batch dimensions. + auto broadcast_one_side = [&](Value side, RankedTensorType type, + Value tail_shape, Value *out_side) { + ArrayRef matrix_dims = type.getShape().take_back(2); + auto result_shape = result_batch_shape_compile_time_extents; + result_shape.append(matrix_dims.begin(), matrix_dims.end()); + auto result_type = tensorflow::GetTypeFromTFTensorShape( + result_shape, type.getElementType()); + auto shape = rewriter->create( + loc, shape_type, result_batch_shape, tail_shape); + auto shape_tensor = rewriter->create( + loc, + tensorflow::GetTypeFromTFTensorShape( + {static_cast(result_shape.size())}, + rewriter->getIndexType()), + shape); + *out_side = rewriter->create(loc, result_type, side, + shape_tensor); + }; + broadcast_one_side(lhs, lhs_type, lhs_splitted.getTail(), out_lhs); + broadcast_one_side(rhs, rhs_type, rhs_splitted.getTail(), out_rhs); +} + +class ConvertBatchMatMulV2Op : public OpRewritePattern { + public: + // TODO(hinsu): Legalize this op to Einsum op. HLO Einsum op needs to be moved + // to CHLO and it is missing legalization to MHLO. Once that is done, this + // pattern's benefit can be changed back to one as well as the fallback + // lowering pattern for the op can be removed. + // + // Set benefit of this pattern to zero to prefer the fallback pattern when + // available and applicable. That pattern avoids broadcast on operands and is + // therefore faster. + // + // Native legalization for BatchMatMulV3 needs to be added as well. + explicit ConvertBatchMatMulV2Op(MLIRContext *context) + : OpRewritePattern(context, /*benefit=*/0) {} + + LogicalResult matchAndRewrite(TF::BatchMatMulV2Op op, + PatternRewriter &rewriter) const override { + Value lhs = op.getX(); + Value rhs = op.getY(); + auto lhs_type = mlir::dyn_cast(lhs.getType()); + auto rhs_type = mlir::dyn_cast(rhs.getType()); + if (!lhs_type || !rhs_type) return failure(); + if (mlir::isa(lhs_type.getElementType()) && op.getAdjX()) { + lhs = rewriter.create(op.getLoc(), lhs_type, lhs); + } + if (mlir::isa(rhs_type.getElementType()) && op.getAdjY()) { + rhs = rewriter.create(op.getLoc(), rhs_type, rhs); + } + + // Broadcast both operands. + BroadcastBatchMatMulV2Operands(lhs, rhs, op.getLoc(), &lhs, &rhs, + &rewriter); + lhs_type = mlir::cast(lhs.getType()); + rhs_type = mlir::cast(rhs.getType()); + assert(lhs_type.getRank() == rhs_type.getRank()); + int64_t rank = lhs_type.getRank(); + auto batch_dimensions = llvm::to_vector<4>(llvm::seq(0, rank - 2)); + auto lhs_contracting_dimensions = llvm::to_vector<4>( + llvm::ArrayRef({op.getAdjX() ? rank - 2 : rank - 1})); + auto rhs_contracting_dimensions = llvm::to_vector<4>( + llvm::ArrayRef({op.getAdjY() ? rank - 1 : rank - 2})); + auto dimension_numbers = DotDimensionNumbersAttr::get( + rewriter.getContext(), + /*lhs_batching_dimensions=*/batch_dimensions, + /*rhs_batching_dimensions=*/batch_dimensions, + /*lhs_contracting_dimensions=*/lhs_contracting_dimensions, + /*rhs_contracting_dimensions=*/rhs_contracting_dimensions); + // TODO(silvasean): Emit shape checks for contracting dimensions. + // (The batch dimensions are checked by the broadcasting logic) + rewriter.replaceOpWithNewOp( + op, op.getType(), lhs, rhs, dimension_numbers, + /*precision_config=*/GetPrecisionConfig(&rewriter), + /*algorithm=*/DotAlgorithmAttr{}); + return success(); + } +}; + +// Converts the tf.Split op into a series of HLO slice ops when the tensor to be +// split has fully static shape and the dimension to split is a constant. +// +// The main logic of this pattern is to calculate the index start and end range +// for each slice. And this happens only on the dimension to be split; for all +// other dimensions, all resultant slices' index start and end range covers the +// input tensor's full range. Strides for all resultant slices are all one. +// +// For example, the following source IR: +// +// %dim = "tf.Const"() {value = dense<1> : tensor} : () -> tensor +// %0:3 = "tf.Split"(%dim, %input) : (tensor, tensor<4x6xf32>) -> +// (tensor<4x2xf32>, tensor<4x2xf32>, tensor<4x2xf32>) +// +// will be converted into: +// +// %0 = "mhlo.slice"(%input) { +// limit_indices = dense<[4, 2]> : tensor<2xi64>, +// start_indices = dense<0> : tensor<2xi64>, +// strides = dense<1> : tensor<2xi64>} : +// (tensor<4x6xf32>) -> tensor<4x2xf32> +// %1 = "mhlo.slice"(%input) { +// limit_indices = dense<4> : tensor<2xi64>, +// start_indices = dense<[0, 2]> : tensor<2xi64>, +// strides = dense<1> : tensor<2xi64>} : +// (tensor<4x6xf32>) -> tensor<4x2xf32> +// %2 = "mhlo.slice"(%input) { +// limit_indices = dense<[4, 6]> : tensor<2xi64>, +// start_indices = dense<[0, 4]> : tensor<2xi64>, +// strides = dense<1> : tensor<2xi64>} : +// (tensor<4x6xf32>) -> tensor<4x2xf32> +// TODO(antiagainst): consider lowering into TF ops so the pattern can be more +// applicable. +class ConvertSplitOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TF::SplitOp op, + PatternRewriter &rewriter) const override { + // We can only split inputs that have fully static shape. + auto input_type = mlir::dyn_cast(op.getValue().getType()); + if (!input_type || !input_type.hasStaticShape()) return failure(); + + // We can only match when the split dimension is a constant scalar. + DenseIntElementsAttr split_dim_attr; + if (!matchPattern(op.getSplitDim(), m_Constant(&split_dim_attr))) + return failure(); + + // Get the dimension we are splitting at. Offset properly if it's negative. + int64_t input_rank = input_type.getRank(); + int64_t dim_index = (*split_dim_attr.begin()).getSExtValue(); + if (dim_index < 0) dim_index += input_rank; + + // Calculate the dimension size for each slice along the split dimension. + int64_t input_dim_size = input_type.getDimSize(dim_index); + + int64_t num_splits = op.getNumResults(); + int64_t slice_size = input_dim_size / num_splits; + + // Get each slice's type. + auto slice_shape = llvm::to_vector<4>(input_type.getShape()); + slice_shape[dim_index] = slice_size; + Type slice_type = tensorflow::GetTypeFromTFTensorShape( + slice_shape, input_type.getElementType()); + + // Parameters for constructing each slice. + SmallVector begin_indices(input_rank, 0); + auto end_indices = llvm::to_vector<4>(input_type.getShape()); + SmallVector strides(input_rank, 1); + + // All HLO slice results used to replace the original tf.Split op. + SmallVector slices; + slices.reserve(num_splits); + + for (int i = 0; i < num_splits; ++i) { + begin_indices[dim_index] = i * slice_size; + end_indices[dim_index] = (i + 1) * slice_size; + slices.push_back( + rewriter.create(op.getLoc(), slice_type, op.getValue(), + GetI64ElementsAttr(begin_indices, &rewriter), + GetI64ElementsAttr(end_indices, &rewriter), + GetI64ElementsAttr(strides, &rewriter))); + } + + rewriter.replaceOp(op, slices); + return success(); + } +}; + +// Converts the tf.Split op into a series of mhlo.real_dynamic_slice ops the +// dimension to split is a constant. +// TODO(disc): To recover static special case's performance with folding and +// canonicalization. delete ConvertSplitOp +class ConvertSplitOpDynamic : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TF::SplitOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value input = op.getValue(); + auto input_type = mlir::dyn_cast(input.getType()); + if (!input_type) return failure(); + + // TODO(disc): remove static shape check once folding/canonicalization func + // added and ConvertSplitOp deleted. Calculate the dimension size for each + // slice along the split dimension. We are splitting along the dynamic + // dimension, or using static pattern transform + if (input_type.hasStaticShape()) return failure(); + + // We can only match when the split dimension is a constant scalar. + DenseIntElementsAttr split_dim_attr; + if (!matchPattern(op.getSplitDim(), m_Constant(&split_dim_attr))) + return failure(); + + // Get the dimension we are splitting at. Offset properly if it's negative. + int64_t input_rank = input_type.getRank(); + int64_t dim_index = (*split_dim_attr.begin()).getSExtValue(); + if (dim_index < 0) dim_index += input_rank; + + Value input_dim_size = + rewriter.create(loc, input, dim_index); + // Calculate the dimension size for each slice along the split dimension. + int num_splits = op.getNumResults(); + Value num_splits_value = rewriter.create( + loc, rewriter.getIndexAttr(num_splits)); + Value slice_size = + rewriter.create(loc, input_dim_size, num_splits_value); + + Value zero = rewriter.create(loc, 0); + Value one = rewriter.create(loc, 1); + + SmallVector begin_indices(input_rank, zero); + SmallVector end_indices; + end_indices.reserve(input_rank); + SmallVector strides(input_rank, one); + for (int i = 0; i < input_rank; ++i) { + end_indices.push_back(rewriter.create(loc, input, i)); + } + + // All HLO d_slice results used to replace the original tf.Split op. + SmallVector slices; + slices.reserve(num_splits); + + for (int i = 0; i < num_splits; ++i) { + begin_indices[dim_index] = rewriter.create( + loc, slice_size, rewriter.create(loc, i)); + end_indices[dim_index] = rewriter.create( + loc, slice_size, rewriter.create(loc, i + 1)); + + Type index_ty = rewriter.getIndexType(); + auto begin_value = rewriter.create( + loc, + tensorflow::GetTypeFromTFTensorShape( + {static_cast(begin_indices.size())}, index_ty), + begin_indices); + auto end_value = rewriter.create( + loc, + tensorflow::GetTypeFromTFTensorShape( + {static_cast(end_indices.size())}, index_ty), + end_indices); + auto stride_value = rewriter.create( + loc, + tensorflow::GetTypeFromTFTensorShape( + {static_cast(strides.size())}, index_ty), + strides); + slices.push_back(rewriter.create( + loc, op.getOperation()->getResult(i).getType(), input, begin_value, + end_value, stride_value)); + } + + rewriter.replaceOp(op, slices); + return success(); + } +}; + +// Converts the tf.SplitV op into a series of HLO slice ops when the tensor to +// be split has fully static shape and the dimension to split and split sizes +// are constants. +// +// This is similar to the conversion for tf.Split op other than that the size of +// each chunk on the dimension to split is explicitly given as an op operand +// and they are not necessarily the same. +// +// For example, given the following IR: +// +// %split_sizes = "tf.Const"() {value = dense<[1, -1, 3]> : tensor<3xi32>} +// %split_dim = "tf.Const"() {value = dense<1> : tensor} +// %0:3 = "tf.SplitV"(%input, %split_sizes, %split_dim) : +// (tensor<4x6xf32>, tensor<3xi32>, tensor) -> +// (tensor<4x1xf32>, tensor<4x2xf32>, tensor<4x3xf32>) +// +// We will generate slices following slices: +// %0 = "mhlo.slice"(%input) { +// limit_indices = dense<[4, 1]> : tensor<2xi64>, +// start_indices = dense<0> : tensor<2xi64>, +// strides = dense<1> : tensor<2xi64>} : +// (tensor<4x6xf32>) -> tensor<4x1xf32> +// %1 = "mhlo.slice"(%input) { +// limit_indices = dense<[4, 3]> : tensor<2xi64>, +// start_indices = dense<[0, 1]> : tensor<2xi64>, +// strides = dense<1> : tensor<2xi64>} : +// (tensor<4x6xf32>) -> tensor<4x2xf32> +// %2 = "mhlo.slice"(%input) { +// limit_indices = dense<[4, 6]> : tensor<2xi64>, +// start_indices = dense<[0, 3]> : tensor<2xi64>, +// strides = dense<1> : tensor<2xi64>} : +// (tensor<4x6xf32>) -> tensor<4x3xf32> +class ConvertSplitVOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TF::SplitVOp op, + PatternRewriter &rewriter) const override { + // We can only split inputs that have fully static shape. + // TODO(b/145731001): enhance to support dynamic-shaped inputs. + auto input_type = mlir::dyn_cast(op.getValue().getType()); + if (!input_type || !input_type.hasStaticShape()) return failure(); + + // We can only match when the split dimension is a constant scalar. + DenseIntElementsAttr split_dim_attr; + if (!matchPattern(op.getSplitDim(), m_Constant(&split_dim_attr))) + return failure(); + + // We can only match when the split sizes is a constant int vector. + DenseIntElementsAttr split_sizes_attr; + if (!matchPattern(op.getSizeSplits(), m_Constant(&split_sizes_attr))) + return failure(); + + // Get each chunck's size along the dimension to split. It may contain + // dynamic sizes and we need to update it if so. + SmallVector split_sizes; + int64_t total_dim_size = 0; // Total dimension size assigned to splits + std::optional dynamic_dim_index; + split_sizes.reserve( + mlir::cast(split_sizes_attr.getType()).getNumElements()); + for (const auto &dim : llvm::enumerate(split_sizes_attr)) { + int64_t dim_val = dim.value().getSExtValue(); + split_sizes.push_back(dim_val); + if (dim_val == -1) { + // We cannot have more than one dynamic dimension. + assert(!dynamic_dim_index && "invalid split sizes"); + dynamic_dim_index = dim.index(); + } else { + total_dim_size += dim_val; + } + } + + // Get the dimension we are splitting at. Offset properly if it's negative. + int64_t input_rank = input_type.getRank(); + int64_t dim_index = (*split_dim_attr.begin()).getSExtValue(); + if (dim_index < 0) dim_index += input_rank; + + int64_t input_dim_size = input_type.getDimSize(dim_index); + assert(((dynamic_dim_index && total_dim_size <= input_dim_size) || + (!dynamic_dim_index && total_dim_size == input_dim_size)) && + "invalid split sizes"); + + // Update the dynamic dimension with calculated concrete size. + if (dynamic_dim_index) + split_sizes[*dynamic_dim_index] = input_dim_size - total_dim_size; + + // Parameters for constructing each slice. + SmallVector begin_indices(input_rank, 0); + auto end_indices = llvm::to_vector<4>(input_type.getShape()); + SmallVector strides(input_rank, 1); + + // All HLO slice results used to replace the original tf.Split op. + SmallVector slices; + slices.reserve(op.getNumResults()); + + for (int i = 0, end = op.getNumResults(); i < end; ++i) { + end_indices[dim_index] = begin_indices[dim_index] + split_sizes[i]; + slices.push_back(rewriter.create( + op.getLoc(), op.getValue(), + GetI64ElementsAttr(begin_indices, &rewriter), + GetI64ElementsAttr(end_indices, &rewriter), + GetI64ElementsAttr(strides, &rewriter))); + // Prepare the begin indice for the next slice. + begin_indices[dim_index] = end_indices[dim_index]; + } + + rewriter.replaceOp(op, slices); + return success(); + } +}; + +// Converts StridedSlice op to HLO Slice op along with Reverse op to handle +// negative strides and Reshape op to update the output shape. Indices and +// strides operands are converted to attributes with non-negative indexing. +// +// If the begin input is not a compile time constant, the begin input needs to +// be sliced and the slice needs to be lowered to mhlo.DynamicSlice. In this +// case, strides must have a known value of 1 (otherwise we have insufficient +// information to conform to XLA's op semantics). +// +// For example with an op like following, +// tf.StridedSlice(%input, %begin, %end, %strides) {shrink_axis_mask = 1} +// : tensor -> tensor +// +// If the %begin input is constant, output would be: +// %reversed = "mhlo.Reverse" (%input) {dimensions = ...} +// %sliced = "mhlo.Slice" (%input) +// {start_indices = ..., limit_indices = ..., strides = ...} +// %output = "mhlo.Reshape" (%sliced) : tensor<1xPxf32> -> tensor +// +class ConvertStridedSliceOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult rewriteWithConstantBegin(TF::StridedSliceOp op, + ArrayRef begin_indices, + ArrayRef end_indices, + ArrayRef strides, + RankedTensorType input_ty, + PatternRewriter &rewriter) const { + SmallVector hlo_begin_indices, hlo_end_indices, hlo_strides, + dims_to_reverse; + int64_t input_rank = input_ty.getRank(); + ArrayRef input_shape = input_ty.getShape(); + hlo_begin_indices.reserve(input_rank); + hlo_end_indices.reserve(input_rank); + hlo_strides.reserve(input_rank); + + int64_t indices_elements = begin_indices.size(); + if (input_rank < indices_elements) return failure(); + + // Convert from TensorFlow negative or out of range indices and strides + // values to legal HLO Slice attributes. + for (int i = 0, e = indices_elements; i != e; i++) { + int64_t begin = begin_indices[i]; + int64_t end = end_indices[i]; + int64_t stride = strides[i]; + + if (stride < 0) { + // Negative stride means that the output values are computed starting + // from end until begin. Mark the dimension for reversal before slice + // and compute indices for the reversed input. + dims_to_reverse.push_back(i); + begin = (input_shape[i] - 1) - begin; + end = (input_shape[i] - 1) - end; + stride = -stride; + } + + // Unlike TensorFlow, HLO requires begin and end values to be within + // range. + begin = std::max(int64_t(0), begin); + end = std::max(begin, end); + end = std::min(end, input_shape[i]); + + hlo_begin_indices.push_back(begin); + hlo_end_indices.push_back(end); + hlo_strides.push_back(stride); + } + + Location loc = op.getLoc(); + Value input = op.getInput(); + if (!dims_to_reverse.empty()) + input = rewriter.create( + loc, input_ty, op.getInput(), + GetI64ElementsAttr(dims_to_reverse, &rewriter)); + auto sliced = rewriter.create( + loc, input, GetI64ElementsAttr(hlo_begin_indices, &rewriter), + GetI64ElementsAttr(hlo_end_indices, &rewriter), + GetI64ElementsAttr(hlo_strides, &rewriter)); + + // Reshape slice result so that the shape is updated depending on + // 'new_axis_mask' or 'shrink_axis_mask' attributes. + rewriter.replaceOpWithNewOp(op, op.getType(), sliced); + return success(); + } + + LogicalResult rewriteWithUnknownBegin(TF::StridedSliceOp op, + RankedTensorType input_ty, + RankedTensorType result_ty, + PatternRewriter &rewriter) const { + // If begin and end values are dynamic, we can only support this lowering + // if strides are a known value of 1. + DenseIntElementsAttr sparse_strides_attr; + if (!matchPattern(op.getStrides(), m_Constant(&sparse_strides_attr))) { + return rewriter.notifyMatchFailure( + op, + "requires that strides are known when begin/end values are dynamic"); + } + SmallVector strides; + int64_t stride_value; + for (const APInt &stride : sparse_strides_attr) { + if ((stride_value = stride.getSExtValue()) != 1) { + return rewriter.notifyMatchFailure(op, + "requires that strides are all 1 " + "when begin/end values are dynamic"); + } + strides.push_back(stride_value); + } + + ArrayRef input_shape = input_ty.getShape(); + int last_dim = std::max(static_cast(input_shape.size()) - 1, 0); + + // When begin/end values are dynamic, the ellipsis mask, if set, must refer + // to the last dimension. + int ellipsis_mask = op.getEllipsisMask(); + if (!(ellipsis_mask == 0 || ellipsis_mask == (1 << last_dim))) + return rewriter.notifyMatchFailure( + op, + "requires that ellipsis_mask, if set, refer to the last dimension of " + "input (when begin/end values are dynamic)"); + + // In this case where the begin and end values are dynamic, we only support + // cases where the number of output elements has to be equal to the number + // of input elements that are sliced. Each dimension is either sliced fully + // or sliced with a size of one. + int output_elements = result_ty.getNumElements(); + int input_elements_sliced = 1; + + // Begin must be a ranked, 1-dimensional tensor: This is checked by the + // verifier. + int64_t slicing_dim_size = + mlir::cast(op.getBegin().getType()).getDimSize(0); + uint64_t begin_mask = op.getBeginMask(); + uint64_t end_mask = op.getEndMask(); + const int input_rank = input_shape.size(); + for (int d = 0; d < input_rank; ++d) { + // Each dimension is either sliced fully or has size of one. + if ((((begin_mask >> d) & 1) && ((end_mask >> d) & 1)) || + (d >= slicing_dim_size)) { + input_elements_sliced *= input_shape[d]; + } + } + if (input_elements_sliced != output_elements) { + return rewriter.notifyMatchFailure( + op, + "requires the number of output elements to be equal to the number of " + "input elements sliced (when begin/end values are dynamic)"); + } + + SmallVector slice_begin_indices; + // For the dimensions that are to be sliced, all have slice sizes of 1. + SmallVector slice_sizes; + auto begin_element_ty = + mlir::cast(op.getBegin().getType()).getElementType(); + // Scalar tensor type. + TensorType type = + tensorflow::GetTypeFromTFTensorShape(/*shape=*/{}, begin_element_ty); + Location loc = op.getLoc(); + auto zero = GetScalarConstOfType(begin_element_ty, loc, 0, &rewriter); + for (int d = 0; d < input_rank; ++d) { + if ((((begin_mask >> d) & 1) && ((end_mask >> d) & 1)) || + (d >= slicing_dim_size)) { + slice_begin_indices.push_back(zero); + slice_sizes.push_back(input_shape[d]); + continue; + } + + auto index = rewriter.create( + loc, op.getBegin(), GetI64ElementsAttr({d}, &rewriter), + GetI64ElementsAttr({d + 1}, &rewriter), + GetI64ElementsAttr({1}, &rewriter)); + // Convert index to scalar. + auto reshaped_index = rewriter.create(loc, type, index); + // If the index is negative, wrap it around with dimension size. + auto index_negative = + rewriter.create(loc, reshaped_index, zero); + auto input_val = GetScalarConstOfType(begin_element_ty, loc, + input_shape[d], &rewriter); + auto wrapped_index = + rewriter.create(loc, input_val, reshaped_index); + auto final_index = rewriter.create( + loc, type, index_negative, wrapped_index, reshaped_index); + slice_begin_indices.push_back(final_index); + slice_sizes.push_back(1); + } + + auto slice_sizes_attr = GetI64ElementsAttr(slice_sizes, &rewriter); + auto sliced_type = tensorflow::GetTypeFromTFTensorShape( + slice_sizes, op.getType().getElementType()); + // This must be an xla DynamicSlice op due to the inputs that aren't + // constant. + auto sliced = rewriter.create( + loc, sliced_type, op.getInput(), slice_begin_indices, slice_sizes_attr); + + // Reshape slice result so that the shape is updated depending on + // 'new_axis_mask' or 'shrink_axis_mask' attributes. + rewriter.replaceOpWithNewOp(op, op.getType(), sliced); + return success(); + } + + LogicalResult matchAndRewrite(TF::StridedSliceOp op, + PatternRewriter &rewriter) const override { + // Input shape needs to be static to convert negative indices in TensorFlow + // to absolute indices required by HLO. + // + // TODO(hinsu): Relax this constraint for ops without negative indices and + // strides. + auto input_ty = mlir::dyn_cast(op.getInput().getType()); + if (!input_ty || !input_ty.hasStaticShape()) return failure(); + + // Output shape needs to be static to apply 'new_axis_mask' or + // 'shrink_axis_mask' by reshaping tensor after slice. + // + // TODO(hinsu): Relax this constraint for ops without the above masks. + auto result_ty = mlir::dyn_cast(op.getType()); + if (!result_ty || !result_ty.hasStaticShape()) return failure(); + + DenseIntElementsAttr sparse_begin_attr, sparse_end_attr; + if (!matchPattern(op.getBegin(), m_Constant(&sparse_begin_attr)) || + !matchPattern(op.getEnd(), m_Constant(&sparse_end_attr))) { + return rewriteWithUnknownBegin(op, input_ty, result_ty, rewriter); + } + + SmallVector begin_indices, end_indices, strides; + if (!op.GetSlicedBoundRanges(&begin_indices, &end_indices, &strides)) { + return failure(); + } + return rewriteWithConstantBegin(op, begin_indices, end_indices, strides, + input_ty, rewriter); + } +}; + +// Converts tf.StridedSliceGrad to HLO reshape, reverse and padding ops. +// +// tf.StridedSlice is taking slice of the input tensor. tf.StridedSliceGrad does +// the reverse: it propagates the graident for the sliced tensor to the original +// input tensor by doing padding with zeros. The main logic is calculating the +// indices and strides for padding. +class ConvertStridedSliceGradOp + : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TF::StridedSliceGradOp op, + PatternRewriter &rewriter) const override { + // We need constant input shape to perform padding calculations later. + DenseIntElementsAttr input_shape_attr; + if (!matchPattern(op.getShape(), m_Constant(&input_shape_attr))) + return failure(); + + // We also need constant begin/end indices and strides to perform padding + // calculations. + // Bounded shape after performing strided slice + SmallVector shape; + // Bounded begin, end, and strides for strided slice + SmallVector begin_indices, end_indices, strides; + if (!op.GetSlicedShapeAndBoundRanges(&shape, &begin_indices, &end_indices, + &strides)) + return failure(); + + Value grad = op.getDy(); + Type element_type = mlir::cast(grad.getType()).getElementType(); + + // Perform reshape to undo any new/shrink axes done by strided slice. + grad = rewriter.create( + op.getLoc(), tensorflow::GetTypeFromTFTensorShape(shape, element_type), + grad); + + SmallVector padding_low, padding_high, padding_interm; + SmallVector dims_to_reverse; + padding_low.reserve(shape.size()); + padding_high.reserve(shape.size()); + padding_interm.reserve(shape.size()); + + // Prepare padding parameters for each dimension. + for (int i = 0, e = shape.size(); i < e; ++i) { + int64_t input_dim = (*(input_shape_attr.begin() + i)).getSExtValue(); + if (strides[i] > 0) { + padding_low.push_back(begin_indices[i]); + padding_interm.push_back(strides[i] - 1); + + // Pad the upper dimension up to the expected input shape. It's not + // sufficient simply to use end_indices[i] to compute the padding in + // cases where the stride does not divide evenly into the interval + // between begin_indices[i] and end_indices[i]. + int64_t size = + padding_low[i] + shape[i] + (shape[i] - 1) * padding_interm[i]; + padding_high.push_back(input_dim - size); + } else { + dims_to_reverse.push_back(i); + padding_high.push_back(input_dim - begin_indices[i] - 1); + padding_interm.push_back(-strides[i] - 1); + + // Pad the lower dimension up to the expected input shape. + int64_t size = + padding_high[i] + shape[i] + (shape[i] - 1) * padding_interm[i]; + padding_low.push_back(input_dim - size); + } + } + + if (!dims_to_reverse.empty()) { + grad = rewriter.create( + op.getLoc(), grad.getType(), grad, + GetI64ElementsAttr(dims_to_reverse, &rewriter)); + } + + auto zero = GetScalarConstOfType(element_type, op.getLoc(), 0, &rewriter); + rewriter.replaceOpWithNewOp( + op, op.getType(), grad, zero, + GetI64ElementsAttr(padding_low, &rewriter), + GetI64ElementsAttr(padding_high, &rewriter), + GetI64ElementsAttr(padding_interm, &rewriter)); + return success(); + } +}; + +/// Converts the RangeOp tensorflow op to a mhlo.iota op with a scaling and +/// offset applied to generate the range values. The output tensor needs to +/// have a static shape. +/// +/// For example an op like the following: +/// %result = "tf.Range"(%start, %limit, %delta) {Tidx = "tfdtype$DT_FLOAT"} +/// : (tensor, tensor, tensor) -> tensor<5xf32> +/// +/// Output would be: +/// %iota = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<5xf32> +/// %scaled = "mhlo.multiply"(%iota, %delta) +/// {broadcast_dimensions = dense<[]> : tensor<0xi64>} : +/// (tensor<5xf32>, tensor) -> tensor<5xf32> +/// %result = "mhlo.add"(%scaled, %offset) +/// {broadcast_dimensions = dense<[]> : tensor<0xi64>} : +/// (tensor<5xf32>, tensor) -> tensor<5xf32> +/// +/// Implementation is defined in C++ due to no type interface for the iota op. +class ConvertRangeOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TF::RangeOp op, + PatternRewriter &rewriter) const override { + auto result = op.getResult(); + auto result_type = result.getType(); + if (!mlir::cast(result_type).hasStaticShape()) { + return failure(); + } + + auto iota = rewriter.create(op.getLoc(), result_type, + rewriter.getI64IntegerAttr(0)); + auto scaled = rewriter.create( + op.getLoc(), result_type, iota, op.getDelta(), + hlo::getBroadcastDimensionsAttr(&rewriter, iota, op.getDelta())); + rewriter.replaceOpWithNewOp( + op, result_type, scaled, op.getStart(), + hlo::getBroadcastDimensionsAttr(&rewriter, scaled, op.getStart())); + return success(); + } +}; + +// Converts RangeOp for cases with the length is a dynamic value. The shape of +// the resulting tensor computed, then the start and delta is used with the +// dynamic_iota value to compute the final range value. +// +// For example, the resulting range op value: +// %range = "tf.range"(%start, %limit, %delta) +// +// Is converted to the following. +// %start + %delta * iota(ceil(abs((%limit - %start) / %delta)) +// +// Implementation is defined in C++ due to the complicated type behavior. +class ConvertDynamicRangeOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TF::RangeOp op, + PatternRewriter &rewriter) const override { + auto result = op.getResult(); + auto result_type = mlir::cast(result.getType()); + if (result_type.hasStaticShape()) { + return failure(); + } + + Value start = op.getStart(); + Value delta = op.getDelta(); + Value limit = op.getLimit(); + + // To compute the length we need to use floating point calculations so that + // ceil can be computed for the number of steps. + auto compute_element_type = + mlir::isa(getElementTypeOrSelf(start.getType())) + ? getElementTypeOrSelf(start.getType()) + : rewriter.getF64Type(); + auto compute_type = tensorflow::GetTypeFromTFTensorShape( + mlir::cast(limit.getType()).getShape(), + compute_element_type); + + // Compute the length of the sequence we are going to need. This includes + // some conversion to float for the operations. + // + // %size = ceil(abs((%limit - %start) / %delta)) + auto range = rewriter.create(op.getLoc(), limit, start); + auto abs = rewriter.create(op.getLoc(), range); + + // Delta is not necessarily the same type as start and limit. + auto abs_cast = + rewriter.create(op.getLoc(), compute_type, abs); + auto delta_cast = + rewriter.create(op.getLoc(), compute_type, delta); + + // Compute the total number of integer steps and convert to the HLO + // dimension tensor. + auto normalized = + rewriter.create(op.getLoc(), abs_cast, delta_cast); + auto ceil = rewriter.create(op.getLoc(), normalized); + auto steps = rewriter.create( + op.getLoc(), + tensorflow::GetTypeFromTFTensorShape({}, rewriter.getI64Type()), ceil); + auto reshape = rewriter.create( + op.getLoc(), + tensorflow::GetTypeFromTFTensorShape({1}, rewriter.getI64Type()), + steps); + + // Using the resulting length compute the correct range value: + // + // %range = %start + %delta * iota(%size) + auto out_scalar_type = tensorflow::GetTypeFromTFTensorShape( + {}, getElementTypeOrSelf(result_type)); + auto start_out_cast = + rewriter.create(op.getLoc(), out_scalar_type, start); + auto delta_out_cast = + rewriter.create(op.getLoc(), out_scalar_type, delta); + + auto iota = rewriter.create( + op.getLoc(), result_type, reshape, rewriter.getI64IntegerAttr(0)); + auto scaled = rewriter.create( + op.getLoc(), result_type, iota, delta_out_cast, + hlo::getBroadcastDimensionsAttr(&rewriter, iota, delta_cast)); + rewriter.replaceOpWithNewOp( + op, result_type, scaled, start_out_cast, + hlo::getBroadcastDimensionsAttr(&rewriter, scaled, start_out_cast)); + return success(); + } +}; + +ElementsAttr ConvertAxisAttr(Value val, ElementsAttr attr, Builder *builder) { + auto int_attr = mlir::cast(attr); + auto type = mlir::cast(val.getType()); + + SmallVector axis; + axis.reserve(int_attr.getNumElements()); + + int64_t rank = type.getRank(); + for (auto val : int_attr.getValues()) { + axis.push_back((val.getSExtValue() + rank) % rank); + } + + return builder->getI64TensorAttr(axis); +} + +/// Converts the LinSpace tensorflow op to a mhlo.iota op with a scaling +/// and offset applied to generate the linspace values. The output tensor needs +/// to have a static shape. The implementation is defined in C++ because there +/// is no type inference for the iota op. +class ConvertLinSpaceOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TF::LinSpaceOp op, + PatternRewriter &rewriter) const override { + auto result = op.getResult(); + auto result_type = mlir::dyn_cast(result.getType()); + if (!result_type || !result_type.hasStaticShape()) { + return failure(); + } + + DenseIntElementsAttr num_attr; + if (!matchPattern(op.getNum(), m_Constant(&num_attr))) { + return rewriter.notifyMatchFailure(op, "Num must be a constant scalar"); + } + + if (num_attr.begin() == num_attr.end()) { + return rewriter.notifyMatchFailure(op, "Num must not be empty"); + } + int64_t num = (*num_attr.begin()).getSExtValue(); + + // Calculate the scaling that needs to be applied to the iota. + auto step_numerator = rewriter.create( + op.getLoc(), op.getStart().getType(), op.getStop(), op.getStart(), + hlo::getBroadcastDimensionsAttr(&rewriter, op.getStop(), + op.getStart())); + Value step_denominator = rewriter.create( + op.getLoc(), op.getNum(), result_type.getElementType()); + if (num > 1) { + Value one = GetScalarConstOfType(result_type.getElementType(), + op.getLoc(), 1, &rewriter); + step_denominator = rewriter.create( + op.getLoc(), step_denominator.getType(), step_denominator, one, + hlo::getBroadcastDimensionsAttr(&rewriter, step_denominator, one)); + } + auto step = rewriter.create( + op.getLoc(), step_numerator.getType(), step_numerator, step_denominator, + hlo::getBroadcastDimensionsAttr(&rewriter, step_numerator, + step_denominator)); + + // Scale the iota and add the offset. + auto iota = rewriter.create(op.getLoc(), result_type, + rewriter.getI64IntegerAttr(0)); + auto scaled = rewriter.create( + op.getLoc(), result_type, iota, step, + hlo::getBroadcastDimensionsAttr(&rewriter, iota, step)); + rewriter.replaceOpWithNewOp( + op, result_type, scaled, op.getStart(), + hlo::getBroadcastDimensionsAttr(&rewriter, scaled, op.getStart())); + return success(); + } +}; + +/// Converts a generic OpTy tensorflow op to a mhlo.reduce op over +/// ReductionOp. +/// `is_accumulation` controls whether it uses higher precision for the actual +/// reduction. This is set to false for ops like max where there is no precision +/// concerns. +// +// The Derived class should have a static method to return the initial value to +// use for reduction: +// static Value GetInitialValue(Type reduce_element_type, Location loc, +// PatternRewriter *rewriter); +// The reduce_element_type is guaranteed to be a float, int, or complex type +// suitable for use with GetScalarConstOfType or GetScalarLimitConstOfType. +template +class GenericConvertReductionOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(OpTy op, + PatternRewriter &rewriter) const override { + // TODO(b/141785544): Update this to not require ranked shapes. + // Input shape needs to be ranked to convert negative indices in TensorFlow + // to absolute indices required by HLO. + auto input_ty = mlir::dyn_cast(op.getInput().getType()); + if (!input_ty) return failure(); + ArrayRef input_shape = input_ty.getShape(); + + DenseIntElementsAttr dimensions; + if (!matchPattern(op.getReductionIndices(), m_Constant(&dimensions))) + return failure(); + + // Build the final shape from input_shape and dimensions using a bitmap + // to mark the reduced dimensions. + SmallVector reduced_dimensions_bitmap(input_shape.size(), false); + SmallVector xla_dimensions; + for (const APInt &index_raw : dimensions.getValues()) { + int64_t index = index_raw.getSExtValue(); + int64_t rank = input_shape.size(); + if ((index < -rank || index >= rank)) return failure(); + index = (index + rank) % rank; + reduced_dimensions_bitmap[index] = true; + xla_dimensions.push_back(index); + } + + Location loc = op.getLoc(); + Type element_type = input_ty.getElementType(); + + // Only float, int, and complex types are currently supported. + if (!mlir::isa(element_type) && + !mlir::isa(element_type) && + !mlir::isa(element_type)) { + return rewriter.notifyMatchFailure( + op, "element type must be float, int, or complex type"); + } + + // Convert to an accumulation type to not lose precision when doing + // repeated arithmetic operations. + Type reduce_element_type = + is_accumulation ? GetAccumulationType(element_type) : element_type; + auto casted_input = + rewriter.create(loc, op.getInput(), reduce_element_type); + + // Each reduction op can have a different initial value. + Value init = Derived::GetInitialValue(reduce_element_type, loc, &rewriter); + + auto reduction = rewriter.create( + loc, casted_input.getResult(), init, + GetI64ElementsAttr(xla_dimensions, &rewriter), reduce_element_type); + BuildReduceBody(reduce_element_type, &reduction.getBody(), + &rewriter); + Value result = reduction.getResult(0); + + // The mean op needs to divide by the product of the reduced dimensions. + if (std::is_same::value) { + Value in_shape = rewriter.create(loc, op.getInput()); + Value divisor_count = rewriter.create(loc, 1); + for (size_t i = 0; i < input_shape.size(); ++i) { + if (reduced_dimensions_bitmap[i]) { + Value index = rewriter.create(loc, i); + auto dim = rewriter.create(loc, in_shape, index); + divisor_count = + rewriter.create(loc, divisor_count, dim); + } + } + // HLO ops are only defined on tensors, so we cast the divisor from + // index -> i64 -> tensor<1xi64> -> tensor -> tensor + Value divisor_casted = rewriter.create( + loc, rewriter.getI64Type(), divisor_count); + Value divisor_tensor = rewriter.create( + loc, tensorflow::GetTypeFromTFTensorShape({}, rewriter.getI64Type()), + divisor_casted); + Value divisor = rewriter.create( + loc, tensorflow::GetTypeFromTFTensorShape({}, reduce_element_type), + divisor_tensor); + auto broadcast_dims = rewriter.getDenseI64ArrayAttr({}); + result = rewriter.create(loc, result, divisor, + broadcast_dims); + } + + result = rewriter.create(loc, result, element_type); + + // Need to reshape back after the reduction if we're keeping the reduced + // dimensions. Note that we do this through successive (nominally 1) + // applications of the TF ExpandDims op vs a more labor intensive + // reshape. Various code generation techniques benefit from the knowledge + // that this is a restricted form of shape manipulation that is just adding + // unit dims. + if (op.getKeepDims()) { + for (const auto &dim_is_reduced : + llvm::enumerate(reduced_dimensions_bitmap)) { + if (dim_is_reduced.value()) { + auto index_attr = GetI32ElementsAttr( + {static_cast(dim_is_reduced.index())}, &rewriter); + Value index = rewriter.create(loc, index_attr); + result = rewriter.create(loc, result, index); + } + } + } + rewriter.replaceOp(op, {result}); + + return success(); + } +}; + +// Converts Mean op to HLO Reduce op. +// +// %init = arith.constant dense<...> : tensor +// %sum = "mhlo.reduce"(%inp, %init) ["mhlo.add"] +// {dimensions = ...} +// %divisor = arith.constant dense<...> : tensor +// %mean = "mhlo.divide"(%sum, %divisor) +class ConvertMeanOp + : public GenericConvertReductionOp { + public: + using GenericConvertReductionOp::GenericConvertReductionOp; + static Value GetInitialValue(Type reduce_element_type, Location loc, + PatternRewriter *rewriter) { + return GetScalarNegZeroOfType(reduce_element_type, loc, rewriter); + } +}; + +// Converts Sum op to HLO Reduce op. +// +// %init = arith.constant dense<...> : tensor +// %sum = "mhlo.reduce"(%inp, %init) ["mhlo.add"] +// {dimensions = ...} +class ConvertSumOp + : public GenericConvertReductionOp { + public: + using GenericConvertReductionOp::GenericConvertReductionOp; + + static Value GetInitialValue(Type reduce_element_type, Location loc, + PatternRewriter *rewriter) { + // The neutral element of fp addition is -0.0, not 0.0: '0.0 + -0.0 = 0.0'. + return GetScalarNegZeroOfType(reduce_element_type, loc, rewriter); + } +}; + +// Converts Max op to HLO Reduce op. +// +// %init = arith.constant dense<...> : tensor +// %max = "mhlo.reduce"(%inp, %init) ["mhlo.maximum"] +// {dimensions = ...} +class ConvertMaxOp + : public GenericConvertReductionOp { + public: + using GenericConvertReductionOp::GenericConvertReductionOp; + + static Value GetInitialValue(Type reduce_element_type, Location loc, + PatternRewriter *rewriter) { + return GetScalarLimitConstOfType(reduce_element_type, loc, + hlo::kInfinityLowest, rewriter); + } +}; + +// Converts Min op to HLO Reduce op. +// +// %init = arith.constant dense<...> : tensor +// %min = "mhlo.reduce"(%inp, %init) ["mhlo.minimum"] +// {dimensions = ...} +class ConvertMinOp + : public GenericConvertReductionOp { + public: + using GenericConvertReductionOp::GenericConvertReductionOp; + + static Value GetInitialValue(Type reduce_element_type, Location loc, + PatternRewriter *rewriter) { + return GetScalarLimitConstOfType(reduce_element_type, loc, + hlo::kInfinityMax, rewriter); + } +}; + +// Converts Prod op to HLO Reduce op. +// +// %init = arith.constant dense<...> : tensor +// %prod = "mhlo.reduce"(%inp, %init) ["mhlo.multiply"] +// {dimensions = ...} +class ConvertProdOp + : public GenericConvertReductionOp { + public: + using GenericConvertReductionOp::GenericConvertReductionOp; + + static Value GetInitialValue(Type reduce_element_type, Location loc, + PatternRewriter *rewriter) { + return GetScalarConstOfType(reduce_element_type, loc, 1, rewriter); + } +}; + +// Converts All op to HLO Reduce op. +// +// %init = arith.constant dense<...> : tensor +// %max = "mhlo.reduce"(%inp, %init) ["mhlo.and"] +// {dimensions = ...} +class ConvertAllOp + : public GenericConvertReductionOp { + public: + using GenericConvertReductionOp::GenericConvertReductionOp; + static Value GetInitialValue(Type reduce_element_type, Location loc, + PatternRewriter *rewriter) { + return GetScalarConstOfType(reduce_element_type, loc, 1, rewriter); + } +}; + +// Converts Any op to HLO Reduce op. +// +// %init = arith.constant dense<...> : tensor +// %max = "mhlo.reduce"(%inp, %init) ["mhlo.or"] +// {dimensions = ...} +class ConvertAnyOp + : public GenericConvertReductionOp { + public: + using GenericConvertReductionOp::GenericConvertReductionOp; + static Value GetInitialValue(Type reduce_element_type, Location loc, + PatternRewriter *rewriter) { + return GetScalarConstOfType(reduce_element_type, loc, 0, rewriter); + } +}; + +// Converts tensorflow ArgMin or ArgMax op to mhlo operations that perform +// a reduction on the original input and the corresponding index. The reduction +// sub-computation selects the max (or min) value and the index for the value. +// Derived: is the resulting derived class of this class. +// OpTy: is TF::ArgMaxOp or TF::ArgMinOp. +template +class ConvertArgMinMaxOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(OpTy op, + PatternRewriter &rewriter) const override { + RankedTensorType input_type = + mlir::dyn_cast(op.getInput().getType()); + if (!input_type) { + return failure(); + } + + Type input_element_type = input_type.getElementType(); + // TODO(bixia): Clarify whether tf.ArgMax supports complex data types. If + // tf.ArgMax doesn't support complex data types, this check can be removed. + if (!input_element_type.isSignlessIntOrFloat()) return failure(); + + Location loc = op.getLoc(); + Value init_value = + Derived::GetInitialValue(input_element_type, loc, rewriter); + + RankedTensorType output_type = + mlir::dyn_cast(op.getOutput().getType()); + if (!output_type) { + return rewriter.notifyMatchFailure(op, "requires known rank"); + } + + Type index_element_type = output_type.getElementType(); + Value index_init_value = + GetScalarConstOfType(index_element_type, loc, 0, &rewriter); + + RankedTensorType index_type = tensorflow::GetTypeFromTFTensorShape( + input_type.getShape(), index_element_type); + + std::optional optional_axis = + GetIntegerHLOAxisFromTFAxis(op.getDimension(), input_type.getRank()); + if (!optional_axis.has_value()) + return rewriter.notifyMatchFailure(op, "required axis"); + int64_t axis = optional_axis.value(); + + IntegerAttr iota_dimension = + IntegerAttr::get(rewriter.getIntegerType(64), axis); + Value input_shape = rewriter.create(loc, op.getInput()); + Value index_values = rewriter.create( + loc, index_type, input_shape, iota_dimension); + + Value operands[] = {op.getInput(), index_values}; + Value init_values[] = {init_value, index_init_value}; + DenseIntElementsAttr reduction_dimensions = + GetI64ElementsAttr({axis}, &rewriter); + + auto reduction = rewriter.create( + loc, llvm::ArrayRef(operands), + llvm::ArrayRef(init_values), reduction_dimensions, + TypeRange({input_element_type, index_element_type})); + auto direction = Derived::GetDirection(); + BuildArgMinMaxReductionBody(input_element_type, index_element_type, + direction, &reduction.getBody(), &rewriter); + + rewriter.replaceOp(op, {reduction.getResult(1)}); + return success(); + } +}; + +// Converts tensorflow ArgMax op to mhlo operations. The actual +// implementation is in class ConvertArgMinMaxOp: +// +// %init_index = arith.constant dense<...> : tensor +// %init = arith.constant dense<...> : tensor +// %reduce = "mhlo.reduce"(%selected_input, %select_index, %init, +// %init_index) ["mhlo.arg_max"] +class ConvertArgMaxOp + : public ConvertArgMinMaxOp { + public: + using ConvertArgMinMaxOp::ConvertArgMinMaxOp; + + static Value GetInitialValue(Type reduce_element_type, Location loc, + PatternRewriter &rewriter) { + return GetScalarLimitConstOfType(reduce_element_type, loc, + hlo::kInfinityLowest, &rewriter); + } + + static ComparisonDirection GetDirection() { return ComparisonDirection::GE; } +}; + +// Converts tensorflow ArgMin op to mhlo operations. The actual +// implementation is in class ConvertArgMinMaxOp: +// +// %init_index = arith.constant dense<...> : tensor +// %init = arith.constant dense<...> : tensor +// %reduce = "mhlo.reduce"(%selected_input, %select_index, %init, +// %init_index) ["mhlo.arg_min"] +class ConvertArgMinOp + : public ConvertArgMinMaxOp { + public: + using ConvertArgMinMaxOp::ConvertArgMinMaxOp; + + static Value GetInitialValue(Type reduce_element_type, Location loc, + PatternRewriter &rewriter) { + return GetScalarLimitConstOfType(reduce_element_type, loc, + hlo::kInfinityMax, &rewriter); + } + + static ComparisonDirection GetDirection() { return ComparisonDirection::LE; } +}; + +// Converts TF TensorScatterUpdate/Min/Max/Add/Sub op into Scatter Op with +// assignment: +// +// %result = "mhlo.scatter"(%tensor, %indices, %updates) +// { dimensions = ... } +// +template +class ConvertTensorScatterOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(OpTy op, + PatternRewriter &rewriter) const override { + auto tensor_ty = mlir::dyn_cast(op.getTensor().getType()); + auto indices_ty = + mlir::dyn_cast(op.getIndices().getType()); + auto updates_ty = + mlir::dyn_cast(op.getUpdates().getType()); + + if (!tensor_ty || !indices_ty || !updates_ty) return failure(); + // Last dimension of the indices needs to known at compile time for + // computation of the 'update_window_dims' attribute in the dimensions + // struct. + int64_t num_index_dims = indices_ty.getShape().back(); + if (ShapedType::isDynamic(num_index_dims)) return failure(); + + auto updates = op.getUpdates(); + + // Broadcast scalar `updates` in into expected shape as following shape: + // updates.shape == indices.shape[:-1] + tensor.shape[indices.shape[-1]:] + if (updates_ty.getRank() == 0 && + (std::is_same::value || + std::is_same::value)) { + if (!tensor_ty.hasStaticShape()) { + return failure(); + } + + if (!indices_ty.hasStaticShape()) { + return failure(); + } + + auto tensor_shape = tensor_ty.getShape(); + auto indices_shape = indices_ty.getShape(); + auto index_depth = indices_shape.back(); + llvm::SmallVector expected_update_shape; + + // create the expected update shape which scalar update is broadcasted to + expected_update_shape.append(indices_shape.begin(), + std::prev(indices_shape.end())); + + expected_update_shape.append(std::next(tensor_shape.begin(), index_depth), + tensor_shape.end()); + + auto const_type = tensorflow::GetTypeFromTFTensorShape( + {static_cast(expected_update_shape.size())}, + rewriter.getIntegerType(64)); + + auto const_attr = GetI64ElementsAttr(expected_update_shape, &rewriter); + + auto const_op = + rewriter.create(op->getLoc(), const_type, const_attr); + + auto broadcast_to_type = tensorflow::GetTypeFromTFTensorShape( + llvm::ArrayRef(expected_update_shape), + updates_ty.getElementType()); + + updates = rewriter.create( + op->getLoc(), broadcast_to_type, op.getUpdates(), const_op); + + updates_ty = mlir::dyn_cast(updates.getType()); + } + + int64_t tensor_rank = tensor_ty.getRank(); + int64_t indices_rank = indices_ty.getRank(); + int64_t updates_rank = + mlir::dyn_cast(updates.getType()).getRank(); + + int64_t window_dims = tensor_rank - num_index_dims; + auto dims_attr = ScatterDimensionNumbersAttr::get( + rewriter.getContext(), + llvm::to_vector<4>( + llvm::seq(updates_rank - window_dims, updates_rank)), + llvm::to_vector<4>(llvm::seq(0, num_index_dims)), + /*inputBatchingDims=*/{}, + /*scatterIndicesBatchingDims=*/{}, + llvm::to_vector<4>(llvm::seq(0, num_index_dims)), + indices_rank - 1); + + Location loc = op.getLoc(); + auto scatter = rewriter.create( + loc, op.getType(), ValueRange(Value(op.getTensor())), op.getIndices(), + updates, dims_attr); + Derived::BuildScatterBody(tensor_ty.getElementType(), + &scatter.getUpdateComputation(), loc, rewriter); + + rewriter.replaceOp(op, scatter.getResult(0)); + return success(); + } +}; + +class ConvertTensorScatterUpdateOp + : public ConvertTensorScatterOp { + public: + using ConvertTensorScatterOp::ConvertTensorScatterOp; + + static void BuildScatterBody(Type element_type, Region *region, Location loc, + OpBuilder &builder) { + OpBuilder::InsertionGuard guard(builder); + Block *block = builder.createBlock(region); + Type type = + tensorflow::GetTypeFromTFTensorShape(/*shape=*/{}, element_type); + block->addArguments({type, type}, SmallVector(2, loc)); + builder.create(loc, block->getArgument(1)); + } +}; + +class ConvertTensorScatterAddOp + : public ConvertTensorScatterOp { + public: + using ConvertTensorScatterOp::ConvertTensorScatterOp; + + static void BuildScatterBody(Type element_type, Region *region, Location loc, + OpBuilder &builder) { + OpBuilder::InsertionGuard guard(builder); + Block *block = builder.createBlock(region); + Type type = + tensorflow::GetTypeFromTFTensorShape(/*shape=*/{}, element_type); + block->addArguments({type, type}, SmallVector(2, loc)); + auto add_op = builder.create(loc, block->getArgument(0), + block->getArgument(1)); + builder.create(loc, add_op.getResult()); + } +}; + +class ConvertTensorScatterSubOp + : public ConvertTensorScatterOp { + public: + using ConvertTensorScatterOp::ConvertTensorScatterOp; + + static void BuildScatterBody(Type element_type, Region *region, Location loc, + OpBuilder &builder) { + OpBuilder::InsertionGuard guard(builder); + Block *block = builder.createBlock(region); + Type type = + tensorflow::GetTypeFromTFTensorShape(/*shape=*/{}, element_type); + block->addArguments({type, type}, SmallVector(2, loc)); + auto sub_op = builder.create(loc, block->getArgument(0), + block->getArgument(1)); + builder.create(loc, sub_op.getResult()); + } +}; + +class ConvertTensorScatterMinOp + : public ConvertTensorScatterOp { + public: + using ConvertTensorScatterOp::ConvertTensorScatterOp; + + static void BuildScatterBody(Type element_type, Region *region, Location loc, + OpBuilder &builder) { + OpBuilder::InsertionGuard guard(builder); + Block *block = builder.createBlock(region); + Type type = + tensorflow::GetTypeFromTFTensorShape(/*shape=*/{}, element_type); + block->addArguments({type, type}, SmallVector(2, loc)); + auto min_op = builder.create(loc, block->getArgument(0), + block->getArgument(1)); + builder.create(loc, min_op.getResult()); + } +}; + +class ConvertTensorScatterMaxOp + : public ConvertTensorScatterOp { + public: + using ConvertTensorScatterOp::ConvertTensorScatterOp; + + static void BuildScatterBody(Type element_type, Region *region, Location loc, + OpBuilder &builder) { + OpBuilder::InsertionGuard guard(builder); + Block *block = builder.createBlock(region); + Type type = + tensorflow::GetTypeFromTFTensorShape(/*shape=*/{}, element_type); + block->addArguments({type, type}, SmallVector(2, loc)); + auto max_op = builder.create(loc, block->getArgument(0), + block->getArgument(1)); + builder.create(loc, max_op.getResult()); + } +}; + +// Converts Tile op to HLO BroadcastInDim and Reshape ops. +// For shape [S1, S2] and multiples [M1, M2], +// MS1 = M1 * S1; MS2 = M2 * S2 +// +// %broadcast = mhlo.broadcast_in_dim(%input) { +// broadcast_dimensions = [0, 2] +// } +// %result = "mhlo.reshape"(%broadcast) : (tensor) +// -> tensor +class ConvertTileOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TF::TileOp op, + PatternRewriter &rewriter) const override { + auto input_ty = mlir::dyn_cast(op.getInput().getType()); + if (!input_ty || !input_ty.hasStaticShape()) return failure(); + ArrayRef input_shape = input_ty.getShape(); + Type element_type = input_ty.getElementType(); + + DenseIntElementsAttr multiples; + if (!matchPattern(op.getMultiples(), m_Constant(&multiples)) || + multiples.getType().getRank() != 1) + return failure(); + + const int64_t input_shape_size = input_shape.size(); + if (multiples.getNumElements() != input_shape_size) return failure(); + + SmallVector broadcasted_shape; + SmallVector broadcast_dimensions; + broadcasted_shape.reserve(input_shape.size() * 2); + broadcast_dimensions.reserve(input_shape.size()); + for (auto multiple_and_input : + llvm::zip(multiples.getValues(), input_shape)) { + int64_t multiple = std::get<0>(multiple_and_input).getSExtValue(); + int64_t input_size = std::get<1>(multiple_and_input); + + if (multiple < 0) return failure(); + + // Line input up with the next dimension in broadcasted_shape + // when broadcasting. + int64_t broadcast_dim; + int64_t output_size = input_size * multiple; + if (input_size == 1 || multiple == 1) { + // Special case for when normal broadcasting will just work. + broadcast_dim = broadcasted_shape.size(); + broadcasted_shape.push_back(output_size); + } else { + // Tiling will happen for this dimension during the ReshapeOp below. + broadcasted_shape.push_back(multiple); + broadcast_dim = broadcasted_shape.size(); + broadcasted_shape.push_back(input_size); + } + broadcast_dimensions.push_back(broadcast_dim); + } + Location loc = op.getLoc(); + Type broadcasted_type = + tensorflow::GetTypeFromTFTensorShape(broadcasted_shape, element_type); + Type output_type = op.getType(); + + Value result = rewriter.create( + loc, broadcasted_type, op.getInput(), + GetI64ElementsAttr(broadcast_dimensions, &rewriter)); + + if (output_type != broadcasted_type) { + result = rewriter.create(loc, output_type, result); + } + + rewriter.replaceOp(op, {result}); + + return success(); + } +}; + +// Converts the tf.TileOp op into mhlo.dynamic_reshape +// TODO(disc): To recover static special case's performance with folding and +// canonicalization. +class ConvertTileOpDynamic : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + // clang-format off + // Converts Tile op to HLO DBroadcastInDim and DReshape ops. + // For shape [S1, S2] and multiples [M1, M2], + // MS1 = M1 * S1; MS2 = M2 * S2 + // + // %out_dim_size = [S1, M1, S2, M2] + // %broadcast_dimensions = [1, 3]; + // %broadcast = mhlo.d_broadcast_in_dim(%input, %out_dim_size, %braodcast_dimensions); + // %shape = [MS1, MS2] + // %result = "mhlo.d_reshape"(%broadcast, %shape) : (tensor) -> tensor + // clang-format on + LogicalResult matchAndRewrite(TF::TileOp op, + PatternRewriter &rewriter) const final { + Location loc = op.getLoc(); + Value input = op.getInput(); + Value multiples = op.getMultiples(); + auto input_ty = mlir::dyn_cast(input.getType()); + if (!input_ty) return failure(); + // TODO(disc): Remove this constraint once fold and canonicalization + // implemented. + if (input_ty.hasStaticShape()) return failure(); + + Type element_type = input_ty.getElementType(); + int64_t input_rank = input_ty.getRank(); + SmallVector input_shape_values; + for (int64_t i = 0; i < input_rank; ++i) { + auto dim_size = input_ty.getDimSize(i); + if (dim_size == ShapedType::kDynamic) { + input_shape_values.push_back( + rewriter.create(loc, input, i)); + } else { + input_shape_values.push_back(rewriter.create( + loc, rewriter.getIndexAttr(dim_size))); + } + } + + auto multiples_ty = mlir::dyn_cast(multiples.getType()); + int64_t multiples_rank = multiples_ty.getRank(); + // rank of multiples input of tf.TileOp must be 1 + if (multiples_rank != 1) return failure(); + // multiples input of tf.TileOp must be fixed shaped + if ((!multiples_ty.hasStaticShape()) || + (multiples_ty.getDimSize(0) != input_rank)) { + return failure(); + } + Type index_ty = rewriter.getIndexType(); + // %out_dim_size + SmallVector out_dim_size; + out_dim_size.reserve(input_rank * 2); + for (int64_t dim_idx = 0; dim_idx < input_rank; ++dim_idx) { + Value index = rewriter.create( + loc, rewriter.getIndexAttr(dim_idx)); + Value multiples_size = + rewriter.create(loc, multiples, ValueRange{index}); + Value multiples_size_casted = + rewriter.create(loc, index_ty, multiples_size); + out_dim_size.push_back(multiples_size_casted); + out_dim_size.push_back(input_shape_values[dim_idx]); + } + SmallVector broadcast_dimensions; + broadcast_dimensions.reserve(input_rank); + for (int64_t dim_idx = 0; dim_idx < input_rank; ++dim_idx) { + broadcast_dimensions.push_back(1 + 2 * dim_idx); + } + auto broadcast_dims_attr = + GetI64ElementsAttr(broadcast_dimensions, &rewriter); + + Value out_dim_size_tensor = rewriter.create( + loc, + tensorflow::GetTypeFromTFTensorShape( + {static_cast(out_dim_size.size())}, index_ty), + out_dim_size); + SmallVector broadcast_shape(input_rank * 2, + ShapedType::kDynamic); + RankedTensorType broadcast_type = + tensorflow::GetTypeFromTFTensorShape(broadcast_shape, element_type); + Value broadcast = rewriter.create( + loc, broadcast_type, input, out_dim_size_tensor, broadcast_dims_attr); + + // %shape = [MS1, MS2] + SmallVector shape_values; + shape_values.reserve(input_rank); + for (int64_t i = 0; i < input_rank; ++i) { + Value dim_size_value = rewriter.create( + loc, out_dim_size[2 * i], out_dim_size[2 * i + 1]); + shape_values.push_back(dim_size_value); + } + Value shape = rewriter.create( + loc, tensorflow::GetTypeFromTFTensorShape({input_rank}, index_ty), + shape_values); + rewriter.replaceOpWithNewOp(op, op.getType(), + broadcast, shape); + return success(); + } +}; + +template +class ConvertMaxPoolGradOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(OpTy op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + + Type element_type = + mlir::cast(op.getOrigInput().getType()).getElementType(); + + // Compute paddings using the original input and kernel shape and strides. + // Here, ReduceWindow op as used as the MaxPool op is lowered to the + // ReduceWindow op. + auto input_ty = + mlir::dyn_cast(op.getOrigInput().getType()); + if (!input_ty) return failure(); + DenseIntElementsAttr paddings_attr = GetReduceWindowPaddingAsAttr( + input_ty.getShape(), op.getKsize(), op.getStrides(), op.getPadding(), + &rewriter); + + auto result = rewriter.create( + loc, op.getType(), op.getOrigInput(), op.getGrad(), + GetScalarConstOfType(element_type, loc, 0, &rewriter), + GetI64ElementsAttr(op.getKsize()), GetI64ElementsAttr(op.getStrides()), + paddings_attr); + + BuildReduceBody(element_type, &result.getScatter(), &rewriter); + { + OpBuilder::InsertionGuard guard(rewriter); + Block *block = rewriter.createBlock(&result.getSelect()); + + // Block arguments are scalars of the given element type. + Type type = + tensorflow::GetTypeFromTFTensorShape(/*shape=*/{}, element_type); + block->addArguments({type, type}, SmallVector(2, loc)); + + auto reducer = rewriter.create(loc, block->getArgument(0), + block->getArgument(1), + ComparisonDirection::GE); + rewriter.create(loc, reducer.getResult()); + } + + rewriter.replaceOp(op, result); + + return success(); + } +}; + +using ConvertMaxPool2DGradOp = + ConvertMaxPoolGradOp; +using ConvertMaxPool3DGradOp = + ConvertMaxPoolGradOp; + +// Converts tf.Conv?DBackpropInputOp into: +// %rev_filter = "mhlo.reverse"(%filter) +// %result = "mhlo.convolution"(%out_backprop, %rev_filter) +template +class ConvertConvBackpropInputOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(OpTy op, + PatternRewriter &rewriter) const override { + // Unpack all of the attributes. + tensorflow::TensorFormat data_format; + if (!FormatFromString(op.getDataFormat().str(), &data_format)) + return op.emitOpError("invalid data format"); + constexpr int num_dims = num_spatial_dims + 2; + int batch_dim = GetTensorBatchDimIndex(num_dims, data_format); + + tensorflow::Padding padding; + if (!GetPaddingFromString(op.getPadding().str(), &padding).ok()) + return failure(); + + auto out_backprop_ty = + mlir::dyn_cast(op.getOutBackprop().getType()); + auto filter_ty = mlir::dyn_cast(op.getFilter().getType()); + + // With the exception of out_backprop's batch dimension, out_backprop and + // filter need to have static shape. Filter is validated here, out_backprop + // is mostly validated at use. + if (!out_backprop_ty || !filter_ty || !filter_ty.hasStaticShape()) + return failure(); + + // Compute input_shape by supporting either: + // 1) Fully static shapes, represented as constants. + // 2) Static shapes with a dynamic batch dimension, represented as + // 1D tf.Pack of a batch dimension (can be static or dynamic) + // and other dimensions (can only be static), for example: + // "tf.Pack"(%142, %cst_301, %cst_301, %cst_300) {axis = 0 : i64, ...} + std::vector input_shape; + DenseIntElementsAttr input_shape_attr; + if (matchPattern(op.getInputSizes(), m_Constant(&input_shape_attr)) && + input_shape_attr.getType().getRank() == 1) { + input_shape.insert(input_shape.end(), + input_shape_attr.getValues().begin(), + input_shape_attr.getValues().end()); + } else { + auto pack = op.getInputSizes().template getDefiningOp(); + if (!pack || pack.getAxis() != 0) return failure(); + auto pack_ty = mlir::dyn_cast(pack.getType()); + if (!pack_ty || pack_ty.getRank() != 1) return failure(); + for (auto i = 0; i < pack_ty.getDimSize(0); ++i) { + if (i == batch_dim) { + // We don't use the batch dimension below, so we don't care about + // its size. Might as well populate it with -1. + input_shape.push_back(ShapedType::kDynamic); + } else { + DenseIntElementsAttr input_dims_attr; + if (matchPattern(pack.getValues()[i], m_Constant(&input_dims_attr)) && + input_dims_attr.getType().getRank() == 0) { + input_shape.push_back(input_dims_attr.getSplatValue()); + } else { + return failure(); + } + } + } + } + + auto dilations_attr = GetI64ElementsAttr(op.getDilations()); + std::vector dilations{ + dilations_attr.template getValues().begin(), + dilations_attr.template getValues().end()}; + auto strides_attr = GetI64ElementsAttr(op.getStrides()); + std::vector strides{ + strides_attr.template getValues().begin(), + strides_attr.template getValues().end()}; + + std::vector explicit_paddings; + if (padding == tensorflow::Padding::EXPLICIT) { + // EXPLICIT padding mode and the associated attribute is limited to + // Conv2DBackpropInput. So, fetch attribute by identifier instead of the + // op.explicit_paddings() attribute getter. + ArrayRef explicit_paddings_attr = + op->template getAttrOfType("explicit_paddings").getValue(); + explicit_paddings.reserve(explicit_paddings_attr.size()); + for (Attribute explicit_padding : explicit_paddings_attr) + explicit_paddings.push_back( + mlir::cast(explicit_padding).getInt()); + } + + ArrayRef filter_shape = filter_ty.getShape(); + + // Compute ConvDimensionNumbers, dilation, and padding. + SmallVector spatial_dims; + SmallVector lhs_dilation; + SmallVector rhs_dilation; + SmallVector paddings; + + for (int i : llvm::seq(0, num_spatial_dims)) { + const int64_t spatial_dim = + GetTensorSpatialDimIndex(num_dims, data_format, i); + spatial_dims.push_back(spatial_dim); + + // Prepare metadata indexed by spatial_dim for computing pad_before + // and pad_after. + int64_t input_size = input_shape[spatial_dim]; + if (input_size == ShapedType::kDynamic) return failure(); + int64_t output_size = out_backprop_ty.getDimSize(spatial_dim); + if (output_size == ShapedType::kDynamic) return failure(); + int64_t filter_size = filter_ty.getDimSize(i); + int64_t stride = strides[spatial_dim]; + int64_t dilation = dilations[spatial_dim]; + + // Compute pad_before and pad_after following the logic from + // ConvBackpropComputeDimensionsV2. (Unfortunately, we cannot call + // the function in question because it doesn't work with dynamic dims). + int64_t padding_before = -1, padding_after = -1; + if (padding == tensorflow::Padding::EXPLICIT) { + padding_before = explicit_paddings[2 * spatial_dim]; + padding_after = explicit_paddings[2 * spatial_dim + 1]; + } + int64_t expected_output_size = 0; + auto status = GetWindowedOutputSizeVerbose( + input_size, filter_size, dilation, stride, padding, + &expected_output_size, &padding_before, &padding_after); + if (!status.ok()) return failure(); + if (output_size != expected_output_size) return failure(); + int64_t effective_filter_size = (filter_size - 1) * dilation + 1; + int64_t pad_before = effective_filter_size - 1 - padding_before; + int64_t padded_out_size = input_size + effective_filter_size - 1; + int64_t expanded_output_size = (output_size - 1) * stride + 1; + int64_t pad_after = padded_out_size - expanded_output_size - pad_before; + + // Populate metadata for the upcoming mhlo.conv op using the result of + // the computations performed above. + lhs_dilation.push_back(stride); + rhs_dilation.push_back(dilation); + paddings.push_back(pad_before); + paddings.push_back(pad_after); + } + + RankedTensorType paddings_ty = tensorflow::GetTypeFromTFTensorShape( + {num_spatial_dims, 2}, rewriter.getIntegerType(64)); + auto paddings_attr = DenseIntElementsAttr::get(paddings_ty, paddings); + + Value filter = op.getFilter(); + + const int feature_dim = + tensorflow::GetTensorFeatureDimIndex(num_dims, data_format); + const int64_t in_depth = *(input_shape.begin() + feature_dim); + if (in_depth == ShapedType::kDynamic) return failure(); + const int64_t filter_in_depth = filter_shape[num_spatial_dims]; + const int64_t feature_group_count = in_depth / filter_in_depth; + + if (feature_group_count != 1) { + // 1. Reshape filter from + // [H, W, ..., filter_in_depth, out_depth] to + // [H, W, ..., filter_in_depth, G, out_depth / G]. + auto new_shape = llvm::to_vector<6>(filter_shape); + new_shape.back() = feature_group_count; + new_shape.push_back(filter_shape.back() / feature_group_count); + Type filter_element_ty = filter_ty.getElementType(); + auto ty = + tensorflow::GetTypeFromTFTensorShape(new_shape, filter_element_ty); + filter = rewriter.create(op.getLoc(), ty, filter); + + // 2. Transpose to [H, W, ..., G, filter_in_depth, out_depth / G]. + llvm::SmallVector perm(num_dims + 1); + std::iota(perm.begin(), perm.end(), 0); + std::swap(perm[num_spatial_dims], perm[num_spatial_dims + 1]); + std::swap(new_shape[num_spatial_dims], new_shape[num_spatial_dims + 1]); + ty = tensorflow::GetTypeFromTFTensorShape(new_shape, filter_element_ty); + filter = rewriter.create( + op.getLoc(), ty, filter, GetI64ElementsAttr(perm, &rewriter)); + + // 3. Reshape to [H, W, ..., in_depth, out_depth / G]. + new_shape[num_spatial_dims] *= new_shape[num_spatial_dims + 1]; + new_shape[num_spatial_dims + 1] = new_shape.back(); + new_shape.pop_back(); + ty = tensorflow::GetTypeFromTFTensorShape(new_shape, filter_element_ty); + filter = rewriter.create(op.getLoc(), ty, filter); + } + + SmallVector kernel_spatial_dims; + kernel_spatial_dims.resize(num_spatial_dims); + std::iota(kernel_spatial_dims.begin(), kernel_spatial_dims.end(), 0); + + // Mirror the filter in the spatial dimensions. + filter = rewriter.create( + op.getLoc(), filter, + GetI64ElementsAttr(kernel_spatial_dims, &rewriter)); + + // activation gradients + // = gradients (with padding and dilation) mirrored_weights + Value result = rewriter.create( + op.getLoc(), op.getType(), op.getOutBackprop(), filter, + /*window_strides=*/ + GetI64ElementsAttrForValue(/*size=*/num_spatial_dims, /*val=*/1, + &rewriter), + /*padding=*/paddings_attr, GetI64ElementsAttr(lhs_dilation, &rewriter), + GetI64ElementsAttr(rhs_dilation, &rewriter), + /*window_reversal=*/nullptr, + ConvDimensionNumbersAttr::get( + rewriter.getContext(), + /*inputBatchDimension=*/batch_dim, + /*inputFeatureDimension=*/feature_dim, + /*inputSpatialDimensions=*/spatial_dims, + // TF filter shape is [ H, W, ..., inC, outC ] + // Transpose the input and output features for computing the + // gradient. + /*kernelInputFeatureDimension=*/ + num_spatial_dims + 1, + /*kernelOutputFeatureDimension=*/ + num_spatial_dims, + /*kernelSpatialDimensions=*/kernel_spatial_dims, + /*outputBatchDimension=*/batch_dim, + /*outputFeatureDimension=*/feature_dim, + /*outputSpatialDimensions=*/spatial_dims), + rewriter.getI64IntegerAttr(feature_group_count), + /*batch_group_count=*/rewriter.getI64IntegerAttr(1), + /*precision_config=*/GetPrecisionConfig(&rewriter)); + + rewriter.replaceOp(op, {result}); + + return success(); + } +}; + +using ConvertConv2DBackpropInputOp = + ConvertConvBackpropInputOp; +using ConvertConv3DBackpropInputOp = + ConvertConvBackpropInputOp; + +// Converts tf.Conv?DBackpropFilterOp into: +// %result = "mhlo.convolution"(%input, %out_backprop) +template +class ConvertConvBackpropFilterOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(OpTy op, + PatternRewriter &rewriter) const override { + // Unpack all of the attributes. + tensorflow::TensorFormat data_format; + if (!FormatFromString(op.getDataFormat().str(), &data_format)) + return op.emitOpError("invalid data format"); + + tensorflow::Padding padding; + if (!GetPaddingFromString(op.getPadding().str(), &padding).ok()) + return failure(); + + auto out_backprop_ty = + mlir::dyn_cast(op.getOutBackprop().getType()); + auto input_ty = mlir::dyn_cast(op.getInput().getType()); + + for (RankedTensorType ty : {out_backprop_ty, input_ty}) + if (!ty || !ty.hasStaticShape()) return failure(); + + ArrayRef out_backprop_shape = out_backprop_ty.getShape(); + ArrayRef input_shape = input_ty.getShape(); + + DenseIntElementsAttr filter_shape_attr; + if (!matchPattern(op.getFilterSizes(), m_Constant(&filter_shape_attr)) || + filter_shape_attr.getType().getRank() != 1) + return failure(); + + auto dilations_attr = GetI64ElementsAttr(op.getDilations()); + std::vector dilations{ + dilations_attr.template getValues().begin(), + dilations_attr.template getValues().end()}; + auto strides_attr = GetI64ElementsAttr(op.getStrides()); + std::vector strides{ + strides_attr.template getValues().begin(), + strides_attr.template getValues().end()}; + + std::vector explicit_paddings; + if (padding == tensorflow::Padding::EXPLICIT) { + // EXPLICIT padding mode and the associated attribute is limited to + // Conv2DBackpropFilter. So, fetch attribute by identifier instead of the + // op.explicit_paddings() attribute getter. + ArrayRef explicit_paddings_attr = + op->template getAttrOfType("explicit_paddings").getValue(); + explicit_paddings.reserve(explicit_paddings_attr.size()); + for (Attribute explicit_padding : explicit_paddings_attr) + explicit_paddings.push_back( + mlir::cast(explicit_padding).getInt()); + } + + constexpr int num_dims = num_spatial_dims + 2; + auto filter_shape = filter_shape_attr.getValues(); + + // Reuse dimension computation logic from conv_grad_shape_utils.cc. + tensorflow::ConvBackpropDimensions dims; + if (!tensorflow::ConvBackpropComputeDimensionsV2( + /*label=*/"", num_spatial_dims, + ToTensorShape(input_shape), + ToTensorShape(filter_shape), + ToTensorShape(out_backprop_shape), dilations, + strides, padding, explicit_paddings, data_format, &dims) + .ok()) { + return failure(); + } + + // The activations (inputs) form the LHS of the convolution. + // Activations have shape: [batch, in_rows, in_cols, ..., in_depth] + // For the gradient computation, we need to: + // 1. In the case of group convolution, move the num_groups dimension before + // the batch dimension + // 2. Swap the roles of the batch and feature dimensions. + const int feature_dim = + tensorflow::GetTensorFeatureDimIndex(num_dims, data_format); + const int64_t in_depth = input_shape[feature_dim]; + const int64_t filter_in_depth = *(filter_shape.begin() + num_spatial_dims); + const int64_t batch_group_count = in_depth / filter_in_depth; + + // Compute ConvDimensionNumbers, dilation, and padding. + SmallVector spatial_dims; + SmallVector kernel_spatial_dims; + SmallVector rhs_dilation; + SmallVector paddings; + SmallVector window_strides; + + // The filter gradients are computed by a convolution of the input + // activations and the output gradients, with some appropriate padding. + // See the comment at the top of conv_grad_ops.h for details. + + for (int i : llvm::seq(0, num_spatial_dims)) { + const int64_t dim = + tensorflow::GetTensorSpatialDimIndex(num_dims, data_format, i); + kernel_spatial_dims.push_back(dim); + // Besides padding the input, we will also expand output_rows to + // expanded_out_rows = (output_rows - 1) * stride + 1 + // with zeros in between: + // + // a . . . b . . . c . . . d . . . e + // + // This is done by specifying the window dilation factors in the + // convolution HLO below. + const auto &spatial_dim_i = dims.spatial_dims[i]; + rhs_dilation.push_back(spatial_dim_i.stride); + window_strides.push_back(dilations[dim]); + + // We will also need to pad the input with zeros such that after the + // convolution, we get the right size for the filter. + // The padded_in_rows should be such that when we convolve this with the + // expanded_out_rows as a filter, we should get filter_rows back. + + const int64_t padded_in_size = + spatial_dim_i.expanded_output_size + + (spatial_dim_i.filter_size - 1) * dilations[dim]; + + // However it can be smaller than input_rows: in this + // case it means some of the inputs are not used. + // + // An example is to have input_cols = 3, filter_cols = 2 and stride = 2: + // + // INPUT = [ A B C ] + // + // FILTER = [ x y ] + // + // and the output will only have one column: a = A * x + B * y + // + // and input "C" is not used at all. + // + // We apply negative padding in this case. + const int64_t pad_total = padded_in_size - spatial_dim_i.input_size; + + // + For the EXPLICIT padding, we pad the top/left side with the explicit + // padding and pad the bottom/right side with the remaining space. + // + For the VALID padding, we don't pad anything on the top/left side + // and pad the bottom/right side with the remaining space. + // + For the SAME padding, we pad top/left side the same as bottom/right + // side. + // + // In addition, if the padded input size is smaller than the input size, + // we need to ignore some training elements of the input. We do this by + // applying negative padding on the right/bottom. + const int64_t pad_before = padding == tensorflow::Padding::EXPLICIT + ? explicit_paddings[2 * dim] + : padding == tensorflow::Padding::SAME + ? std::max(pad_total / 2, 0) + : 0; + paddings.push_back(pad_before); + paddings.push_back(pad_total - pad_before); + } + + RankedTensorType paddings_ty = tensorflow::GetTypeFromTFTensorShape( + {num_spatial_dims, 2}, rewriter.getIntegerType(64)); + auto paddings_attr = DenseIntElementsAttr::get(paddings_ty, paddings); + + SmallVector output_spatial_dimensions; + output_spatial_dimensions.resize(num_spatial_dims); + std::iota(output_spatial_dimensions.begin(), + output_spatial_dimensions.end(), 0); + + const int batch_dim = + tensorflow::GetTensorBatchDimIndex(num_dims, data_format); + + Value result = rewriter.create( + op.getLoc(), op.getType(), op.getInput(), op.getOutBackprop(), + /*window_strides=*/GetI64ElementsAttr(window_strides, &rewriter), + /*padding=*/paddings_attr, /*lhs_dilation=*/ + GetI64ElementsAttrForValue(/*size=*/num_spatial_dims, /*val=*/1, + &rewriter), + GetI64ElementsAttr(rhs_dilation, &rewriter), + /*window_reversal=*/nullptr, + ConvDimensionNumbersAttr::get( + rewriter.getContext(), + // Swap batch_dim and feature_dim in the activations. + /*inputBatchDimension=*/feature_dim, + /*inputFeatureDimension=*/batch_dim, + /*inputSpatialDimensions=*/kernel_spatial_dims, + // The gradients become the RHS of the convolution. + // The gradients have shape [batch, out_rows, out_cols, ..., + // out_depth] where the batch becomes the input feature for the + // convolution. + /*kernelInputFeatureDimension=*/batch_dim, + /*kernelOutputFeatureDimension=*/feature_dim, + /*kernelSpatialDimensions=*/kernel_spatial_dims, + /*outputBatchDimension=*/num_spatial_dims, + /*outputFeatureDimension=*/num_spatial_dims + 1, + /*outputSpatialDimensions=*/output_spatial_dimensions), + /*feature_group_count=*/rewriter.getI64IntegerAttr(1), + rewriter.getI64IntegerAttr(batch_group_count), + /*precision_config=*/GetPrecisionConfig(&rewriter)); + + rewriter.replaceOp(op, {result}); + + return success(); + } +}; + +using ConvertConv2DBackpropFilterOp = + ConvertConvBackpropFilterOp; +using ConvertConv3DBackpropFilterOp = + ConvertConvBackpropFilterOp; + +class ConvertOneHotOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TF::OneHotOp op, + PatternRewriter &rewriter) const override { + auto indices_ty = + mlir::dyn_cast(op.getIndices().getType()); + if (!indices_ty || !indices_ty.hasStaticShape()) return failure(); + ArrayRef indices_shape = indices_ty.getShape(); + Type element_type = indices_ty.getElementType(); + + DenseIntElementsAttr depth_attr; + if (!matchPattern(op.getDepth(), m_Constant(&depth_attr))) { + return failure(); + } + + int64_t depth = depth_attr.getValues()[0].getSExtValue(); + int64_t axis = op.getAxis(); + if (axis == -1) axis = indices_shape.size(); + + llvm::SmallVector broadcast_dims(indices_shape.size()); + std::iota(broadcast_dims.begin(), broadcast_dims.begin() + axis, 0); + std::iota(broadcast_dims.begin() + axis, broadcast_dims.end(), axis + 1); + + llvm::SmallVector output_dims = + llvm::to_vector<4>(indices_shape); + output_dims.insert(output_dims.begin() + axis, depth); + + Location loc = op.getLoc(); + + // The iota result is the effective output shape of the computation, + // and indices must be broadcast into it. At this point, this computation + // would need to be reworked quite a bit to support dynamic shapes, so + // just using static broadcasting. + auto index_type = + tensorflow::GetTypeFromTFTensorShape(output_dims, element_type); + auto iota = rewriter.create( + loc, index_type, IntegerAttr::get(rewriter.getIntegerType(64), axis)); + auto broadcast_indices = rewriter.create( + loc, index_type, op.getIndices(), + GetI64ElementsAttr(broadcast_dims, &rewriter)); + + Value compare = rewriter.create( + loc, broadcast_indices, iota, ComparisonDirection::EQ); + Value on_value = rewriter.create( + loc, op.getType(), op.getOnValue(), + GetI64ElementsAttr(output_dims, &rewriter)); + Value off_value = rewriter.create( + loc, op.getType(), op.getOffValue(), + GetI64ElementsAttr(output_dims, &rewriter)); + Value result = rewriter.create(loc, op.getType(), compare, + on_value, off_value); + + rewriter.replaceOp(op, {result}); + + return success(); + } +}; + +// Converts InfeedDequeueTuple to XLA HLO create_token, infeed and +// get_tuple_element ops. +// +// All HLO infeed ops expect a HLO token type operand and produce a tuple +// containing a token. This HLO token type is used to order multiple infeed +// operations within a computation. The token type can come from other +// infeed/outfeed/send/recv ops or can be generated using create_token op with +// no operands. Here we emit a create_token op to generate the token type +// operand of infeed. The mhlo.InfeedOp can produce multiple results and later +// will be exported to XLA infeed op with single tuple return type. +// +// For example the following IR: +// %0:2 = "tf.InfeedDequeueTuple"() : () -> (tensor<3xi32>, tensor<4xf32>) +// +// would be lowered to +// +// %token = "mhlo.create_token"() : () -> !mhlo.token +// %data_and_token = "mhlo.infeed"(%token) {infeed_config = ""} : +// (!mhlo.token) -> tensor<3xi32>, tensor<4xf32>, !mhlo.token> +// +class ConvertInfeedDequeueTupleOp + : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TF::InfeedDequeueTupleOp op, + PatternRewriter &rewriter) const override { + SmallVector result_types; + result_types.reserve(op.getOutputs().size() + 1); + for (const auto &output : op.getOutputs()) { + Type ty = output.getType(); + if (auto tensor_ty = mlir::dyn_cast(ty)) { + if (!tensor_ty.hasStaticShape()) return failure(); + } + result_types.push_back(ty); + } + + // Infeed takes a single token operand. Generate the token using + // create_token op to pass to the infeed op. + auto token = rewriter.create( + op.getLoc(), mhlo::TokenType::get(rewriter.getContext())); + + result_types.push_back(token.getType()); + + ArrayAttr layout; // filled in during the xla-adjust-layout pass + auto data_and_token = + rewriter.create(op.getLoc(), result_types, token, + /*infeed_config=*/rewriter.getStringAttr(""), + /*layout=*/layout); + + result_types.pop_back(); // remove the token type. + + if (op.get_XlaSharding().has_value()) { + // _XlaSharding attribute in TF is a serialized string of the OpSharding + // proto, so convert to a text form here. + ::xla::OpSharding sharding_proto; + if (tensorflow::DecodeShardingAttribute( + op.get_XlaSharding().value().str(), sharding_proto) + .failed()) { + return failure(); + } + // Token is a control signal and not a real data, so arbitrarily assign + // the token to device 0. + if (sharding_proto.type() == ::xla::OpSharding::TUPLE) { + *sharding_proto.add_tuple_shardings() = + ::xla::sharding_builder::AssignDevice(0); + data_and_token->setAttr( + kShardingAttr, + rewriter.getStringAttr(sharding_proto.SerializeAsString())); + } else { + data_and_token->setAttr(kShardingAttr, op.get_XlaShardingAttr()); + } + } + + if (op->hasAttr("layouts")) { + // Append a UnitAttr for the "token" operand of the mhlo.infeed op here to + // avoid compilation failure when exporting "layouts" attribute of the + // corresponding InfeedDequeueTupleOp to a graph node. + data_and_token->setAttr("layout", op->getAttr("layouts")); + } + llvm::SmallVector results; + results.reserve(result_types.size()); + for (const auto &idx_and_type : llvm::enumerate(result_types)) { + results.push_back(data_and_token.getResult(idx_and_type.index())); + } + rewriter.replaceOp(op, ValueRange(results)); + return success(); + } +}; + +// Converts tf.OutfeedEnqueueTuple to XLA HLO tuple, create_token and outfeed +// ops. +// +// XLA HLO outfeed op expects a token, which we generate by emitting an +// create_token op. +// +// For example the following IR: +// "tf.OutfeedEnqueueTuple"(%val_1, %val_2) : (tensor<3xi32>, tensor<4xf32>) -> +// () +// +// would be lowered to +// +// %token = "mhlo.create_token"() : () -> !mhlo.token +// %outfeed_token = "mhlo.outfeed"(%val_1, %val_2, %token) {outfeed_config = ""} +// : +// (tensor<3xi32>, tensor<4xf32>, !mhlo.token) -> !mhlo.token +// +class ConvertOutfeedEnqueueTupleOp + : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TF::OutfeedEnqueueTupleOp op, + PatternRewriter &rewriter) const override { + auto token_type = mhlo::TokenType::get(rewriter.getContext()); + auto token = rewriter.create(op.getLoc(), token_type); + + rewriter.create(op.getLoc(), token_type, op.getInputs(), token, + /*outfeed_config=*/rewriter.getStringAttr("")); + rewriter.eraseOp(op); + return success(); + } +}; + +// Converts tf.TopKV2 to chlo.top_k. +class ConvertTopKV2Op : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TF::TopKV2Op op, + PatternRewriter &rewriter) const override { + // We can only match when the `k` operand is a constant scalar. + DenseIntElementsAttr k_attr; + if (!matchPattern(op.getK(), m_Constant(&k_attr))) return failure(); + int64_t k = (*k_attr.begin()).getSExtValue(); + + TensorType input_type = mlir::cast(op.getInput().getType()); + if (!input_type.hasRank()) return failure(); + int64_t input_rank = input_type.getRank(); + int64_t last_dim_index = input_rank - 1; + int64_t last_dim_size = input_type.getDimSize(last_dim_index); + if (last_dim_size == ShapedType::kDynamic) return failure(); + + rewriter.replaceOpWithNewOp(op, op.getInput(), k); + return success(); + } +}; + +// Converts tf.Unpack to a series of XLA HLO slice ops. +// +// Each slice takes one element along the dimension to unpack and takes the full +// range for all other dimensions. Each slice is then reshaped to drop the +// dimension to unpack (which is always of size 1). +// TODO(antiagainst): consider changing this into a TF internal lowering pass. +class ConvertUnpackOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TF::UnpackOp op, + PatternRewriter &rewriter) const override { + auto value_type = mlir::dyn_cast(op.getValue().getType()); + if (!value_type) return failure(); + + int64_t value_rank = value_type.getRank(); + int64_t axis = op.getAxis(); + if (axis < 0) axis += value_rank; + + // Parameters for constructing each slice. + SmallVector begin_indices(value_rank, 0); + auto end_indices = llvm::to_vector<4>(value_type.getShape()); + SmallVector strides(value_rank, 1); + + // All HLO slice+squeeze results used to replace the original tf.Unpack op. + SmallVector results; + results.reserve(op.getNumResults()); + + for (int i = 0, end = op.getNumResults(); i < end; ++i) { + begin_indices[axis] = i; + end_indices[axis] = i + 1; + + auto slice_op = rewriter.create( + op.getLoc(), op.getValue(), + GetI64ElementsAttr(begin_indices, &rewriter), + GetI64ElementsAttr(end_indices, &rewriter), + GetI64ElementsAttr(strides, &rewriter)); + // Reshape to drop the axis dimension. + auto result = rewriter.create( + op.getLoc(), op.getType(i), slice_op, + rewriter.getI64ArrayAttr(op.getAxis())); + results.push_back(result); + } + + rewriter.replaceOp(op, results); + return success(); + } +}; + +// Converts tf.Unpack to a series of XLA HLO Slice ops. +// TODO(disc): To recover static special case's performance with folding and +// canonicalization. +class ConvertUnpackOpDynamic : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TF::UnpackOp op, + PatternRewriter &rewriter) const override { + auto value_type = mlir::dyn_cast(op.getValue().getType()); + if (!value_type) return failure(); + // TODO(disc): Remove this constraint once fold and canonicalization + // implemented. + if (value_type.hasStaticShape()) return failure(); + + int64_t value_rank = value_type.getRank(); + int64_t axis = op.getAxis(); + if (axis < 0) axis += value_rank; + Location loc = op.getLoc(); + + auto shape_scalar_type = rewriter.getIntegerType(32); + // Parameters for constructing each slice. + SmallVector begin_indices, end_indices, strides; + begin_indices.reserve(value_rank); + end_indices.reserve(value_rank); + strides.reserve(value_rank); + // final output shape + SmallVector shape_values; + shape_values.reserve(value_rank - 1); + // slice shape before reshape, should be like{?, 1, ?, ?} if axis = 1 + SmallVector slice_shape(value_rank, ShapedType::kDynamic); + for (int64_t dim_idx = 0; dim_idx < value_rank; ++dim_idx) { + int64_t dim_size = value_type.getDimSize(dim_idx); + if (dim_size == ShapedType::kDynamic) { + Value dim_i = rewriter.create( + loc, shape_scalar_type, + rewriter.create(loc, op.getOperand(), dim_idx)); + end_indices.push_back(dim_i); + if (dim_idx != axis) { + shape_values.push_back(dim_i); + } + } else { + Value dim_i = rewriter.create( + loc, shape_scalar_type, + rewriter.getIntegerAttr(shape_scalar_type, dim_size)); + end_indices.push_back(dim_i); + if (dim_idx != axis) { + shape_values.push_back(dim_i); + slice_shape[dim_idx] = dim_size; + } else { + slice_shape[dim_idx] = 1; + } + } + begin_indices.push_back( + rewriter.create(loc, 0, 32)); + strides.push_back(rewriter.create(loc, 1, 32)); + } + + SmallVector results; + results.reserve(op.getNumResults()); + Type i32_ty = rewriter.getI32Type(); + for (int64_t i = 0; i < op.getNumResults(); ++i) { + begin_indices[axis] = rewriter.create(loc, i, 32); + end_indices[axis] = rewriter.create(loc, i + 1, 32); + Value slice_op = rewriter.create( + loc, + tensorflow::GetTypeFromTFTensorShape(slice_shape, + value_type.getElementType()), + op.getValue(), + rewriter.create( + loc, + tensorflow::GetTypeFromTFTensorShape( + {static_cast(begin_indices.size())}, i32_ty), + begin_indices), + rewriter.create( + loc, + tensorflow::GetTypeFromTFTensorShape( + {static_cast(end_indices.size())}, i32_ty), + end_indices), + rewriter.create( + loc, + tensorflow::GetTypeFromTFTensorShape( + {static_cast(strides.size())}, i32_ty), + strides)); + // Reshape to drop the axis dimension. + Value new_shape = rewriter.create( + loc, + tensorflow::GetTypeFromTFTensorShape( + {static_cast(shape_values.size())}, i32_ty), + shape_values); + Value reshape_op = rewriter.create(loc, op.getType(i), + slice_op, new_shape); + results.push_back(reshape_op); + } + + rewriter.replaceOp(op, results); + return success(); + } +}; + +// Converts the tf.SigmoidGradOp +// TODO(disc): To recover static special case's performance with folding and +// canonicalization. +class ConvertSigmoidGradOpDynamic : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TF::SigmoidGradOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value y = op.getY(); + Value dy = op.getDy(); + auto tp_y = mlir::dyn_cast(y.getType()); + auto tp_dy = mlir::dyn_cast(dy.getType()); + if (!tp_y || !tp_dy) return failure(); + + // TODO(disc): Remove this constraint once fold and canonicalization + // implemented. + if (tp_y.hasStaticShape() || tp_dy.hasStaticShape()) return failure(); + + Attribute attr; + Type elem_tp = tp_y.getElementType(); + if (elem_tp.isSignlessInteger()) { + attr = rewriter.getIntegerAttr(elem_tp, 1); + } else { + assert(mlir::isa(elem_tp)); + attr = rewriter.getFloatAttr(elem_tp, 1); + } + Value one = rewriter.create( + loc, DenseElementsAttr::get( + tensorflow::GetTypeFromTFTensorShape({}, elem_tp), attr)); + + auto v0 = rewriter.create( + loc, dy, y, hlo::getBroadcastDimensionsAttr(&rewriter, dy, y)); + auto v1 = rewriter.create( + loc, one, y, hlo::getBroadcastDimensionsAttr(&rewriter, one, y)); + auto result = rewriter.create( + loc, v0, v1, hlo::getBroadcastDimensionsAttr(&rewriter, v0, v1)); + + rewriter.replaceOp(op, result.getOperation()->getResults()); + return success(); + } +}; + +// Converts TF unsorted segment reduction ops to XLA HLO scatter op. +// +// TF unsorted segment reduction op peforms the following calculation: +// +// Assume segment ids' shape is [SI0, SI1, ..., SIm] and data's shape is +// [D0, D1, ..., Dn]. Note that segment ids' shape must be a prefix of data's +// shape, so we can have data's shape represented as [SI0, SI1, ..., SIm, +// Dm+1, ..., Dn]. Then +// output[segment_ids[SI_i0, SI_i1, ..., SI_im], D_im+1, ..., D_in] = +// over data[SI_i0, SI_i1, ..., SI_im, D_im+1, ..., D_in] +// where SI_iN is in the range of [0, SIN) and D_iN is in the range of [0, DN). +// +// The op will be translated to XLA HLO scatter with the following parameters: +// * Update window dims is [segment_id_rank, data_rank). +// * Inserted window dims is {0}. +// * Scatter dims to operand dims mapping is {0}. +// * Index vector dim is segment_id_rank. +template +class GenericConvertUnsortedSegmentReductionOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(OpTy op, + PatternRewriter &rewriter) const override { + auto data_type = mlir::dyn_cast(op.getData().getType()); + if (!data_type) return failure(); + int64_t data_rank = data_type.getRank(); + + auto segment_ids_type = + mlir::dyn_cast(op.getSegmentIds().getType()); + if (!segment_ids_type) return failure(); + int64_t segment_ids_rank = segment_ids_type.getRank(); + + DenseIntElementsAttr num_segments_attr; + if (!matchPattern(op.getNumSegments(), m_Constant(&num_segments_attr))) + return failure(); + + // The final shape for TF unsorted segment reduction op is [num_segments] + + // data_shape[segment_ids_rank:]. + SmallVector output_shape; + output_shape.push_back((*num_segments_attr.begin()).getSExtValue()); + auto suffix = data_type.getShape().drop_front(segment_ids_rank); + output_shape.append(suffix.begin(), suffix.end()); + auto output_type = tensorflow::GetTypeFromTFTensorShape( + output_shape, data_type.getElementType()); + + // Broadcast the initial value for reduction. This will become the + // 'operand' parameter to scatter to for the final scatter op. + Value init = ConcreteClass::GetInitialValue(data_type.getElementType(), + op.getLoc(), &rewriter); + auto broadcasted_init = rewriter.create( + op.getLoc(), output_type, init, + GetI64ElementsAttr(output_shape, &rewriter)); + + // Parameters for the generated scatter op. + SmallVector inserted_window_dims(1, 0); + SmallVector scatter_dims_to_operand_dims(1, 0); + int64_t index_vector_dim = segment_ids_rank; + + // Put all parameters in a StructAttr. + auto dims_attr = ScatterDimensionNumbersAttr::get( + rewriter.getContext(), + llvm::to_vector<4>(llvm::seq(segment_ids_rank, data_rank)), + inserted_window_dims, + /*inputBatchingDims=*/{}, + /*scatterIndicesBatchingDims=*/{}, scatter_dims_to_operand_dims, + index_vector_dim); + + auto scatter = rewriter.create( + op.getLoc(), op.getType(), ValueRange(Value(broadcasted_init)), + op.getSegmentIds(), op.getData(), dims_attr); + BuildReduceBody(data_type.getElementType(), + &scatter.getUpdateComputation(), &rewriter); + + rewriter.replaceOp(op, scatter.getResult(0)); + return success(); + } +}; + +class ConvertUnsortedSegmentMaxOp + : public GenericConvertUnsortedSegmentReductionOp< + ConvertUnsortedSegmentMaxOp, TF::UnsortedSegmentMaxOp, MaxOp> { + public: + using GenericConvertUnsortedSegmentReductionOp:: + GenericConvertUnsortedSegmentReductionOp; + + static Value GetInitialValue(Type reduce_element_type, Location loc, + PatternRewriter *rewriter) { + return GetScalarLimitConstOfType(reduce_element_type, loc, hlo::kLowest, + rewriter); + } +}; + +class ConvertUnsortedSegmentMinOp + : public GenericConvertUnsortedSegmentReductionOp< + ConvertUnsortedSegmentMinOp, TF::UnsortedSegmentMinOp, MinOp> { + public: + using GenericConvertUnsortedSegmentReductionOp:: + GenericConvertUnsortedSegmentReductionOp; + + static Value GetInitialValue(Type reduce_element_type, Location loc, + PatternRewriter *rewriter) { + return GetScalarLimitConstOfType(reduce_element_type, loc, hlo::kMax, + rewriter); + } +}; + +class ConvertUnsortedSegmentProdOp + : public GenericConvertUnsortedSegmentReductionOp< + ConvertUnsortedSegmentProdOp, TF::UnsortedSegmentProdOp, MulOp> { + public: + using GenericConvertUnsortedSegmentReductionOp:: + GenericConvertUnsortedSegmentReductionOp; + + static Value GetInitialValue(Type reduce_element_type, Location loc, + PatternRewriter *rewriter) { + return GetScalarConstOfType(reduce_element_type, loc, 1, rewriter); + } +}; + +class ConvertUnsortedSegmentSumOp + : public GenericConvertUnsortedSegmentReductionOp< + ConvertUnsortedSegmentSumOp, TF::UnsortedSegmentSumOp, AddOp> { + public: + using GenericConvertUnsortedSegmentReductionOp:: + GenericConvertUnsortedSegmentReductionOp; + + static Value GetInitialValue(Type reduce_element_type, Location loc, + PatternRewriter *rewriter) { + return GetScalarConstOfType(reduce_element_type, loc, 0, rewriter); + } +}; + +// Converts tf.RandomShuffle op into a series of XLA HLO ops. +// +// tf.RandomShuffle shuffles tensors along the first dimension. If the input +// tensor's rank is 1, then it is translated into HLO sort op(s) according to +// indices randomly generated via HLO rng_uniform ops. Otherwise, it is +// translated into an HLO while op to first emulate shuffling indices using +// HLO dynamic_slice and dynamic_update_slice ops, then finally HLO gather +// with the shuffled indices. +class ConvertRandomShuffleOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TF::RandomShuffleOp op, + PatternRewriter &rewriter) const override { + auto no_op = [&]() { + rewriter.replaceOp(op, op.getValue()); + return success(); + }; + + auto input_type = mlir::dyn_cast(op.getValue().getType()); + if (!input_type) return failure(); + if (input_type.hasStaticShape() && input_type.getNumElements() <= 1) + // No shuffling is required, so copy input directly to output. + return no_op(); + + int64_t input_rank = input_type.getRank(); + int64_t first_dim_size = input_type.getDimSize(0); + if (ShapedType::isDynamic(first_dim_size)) return failure(); + + if (first_dim_size <= 1) + // No shuffling is required, so copy input directly to output. + return no_op(); + + // For vectors, shuffle values by sorting instead of the obvious + // Fisher-Yates algorithm. Fisher-Yates is simple to implement and correct, + // but not easily parallelizable. For a sufficiently parallel architecture, + // it is faster to sort many times, than Fisher-Yates shuffle once. + if (input_rank == 1) { + // Shuffle values by assigning each value a random key and sorting the + // keys. Keys can collide causing detectable patterns in the shuffled + // output. Collisions translates into more ascending sub-sequences in the + // shuffled output than would be expected by chance. To avoid collisions, + // the number of possible key values must be sufficiently large. + + // How are more than 2^32 keys created? In each loop iteration, the + // algorithm sorts by random keys. Conceptually, the earlier iterations + // are sorting on the lower-order bits of larger keys that are never + // actually assembled. + + // The expected number of collisions is n - d + d(1 - 1/d)^n, where d is + // the number of possible keys and n is the number of values. If d = n^2, + // then the limit as n goes to infinity is 1/2. If d = n^3, then the limit + // as n goes to infinity is zero. + + // This implementation ensures that the key-space is greater than or equal + // to the cube of the number of values. The risk of collisions can be + // further reduced by increasing Exponent at the expense of + // performance. + + // For Exponent = 2, the expected number of collisions per shuffle is + // maximized at n = floor((2^32-1)^(1/2)) = 65535 where the expectation is + // about 1/2. + + // For Exponent = 3, the expected number of collisions per shuffle is + // maximized at n = floor((2^32-1)^(1/3)) = 1625 where the expectation is + // about 1/3255. + + // For Exponent = 4, the expected number of collisions per shuffle is + // maximized at n = floor((2^32-1)^(1/4)) = 255 where the expectation is + // about 1/132622. + constexpr int exponent = 3; + int64_t num_elements = input_type.getNumElements(); + uint32_t u32_max = std::numeric_limits::max(); + int rounds = + std::ceil(exponent * std::log(num_elements) / std::log(u32_max)); + + Value current = op.getValue(); + for (int i = 0; i < rounds; ++i) { + auto keys = + CreateRngUniform32(op.getLoc(), num_elements, /*lower_limit=*/0, + /*upper_limit=*/u32_max, &rewriter); + auto sorted = createSortOp( + &rewriter, op.getLoc(), {keys, current}, + {rewriter.getIntegerType(32), input_type.getElementType()}, + /*dimension=*/-1, /*isStable=*/false, + /*direction=*/ComparisonDirection::LT); + current = sorted.getResult(1); + } + rewriter.replaceOp(op, current); + return success(); + } + + // The Fisher-Yates algorithm. + + // Generate range(n) as the initial value for the indices to be swapped. + auto indices_type = tensorflow::GetTypeFromTFTensorShape( + {first_dim_size}, rewriter.getIntegerType(32)); + Value indices = rewriter.create( + op.getLoc(), indices_type, rewriter.getI64IntegerAttr(0)); + + // Generate random numbers to be used as swaps for the indices. + Value swaps = CreateRngUniform32(op.getLoc(), first_dim_size, 0, + first_dim_size, &rewriter); + + // While loop body to perform index swaps. + auto swap_body_fn = [&](Location loc, Value i, ArrayRef old_values, + SmallVectorImpl *new_values, + OpBuilder *builder) { + Value swaps = old_values[0]; + Value indices = old_values[1]; + + auto scalar_i32_type = + tensorflow::GetTypeFromTFTensorShape({}, builder->getIntegerType(32)); + auto one_cross_i64_type = tensorflow::GetTypeFromTFTensorShape( + {1}, builder->getIntegerType(64)); + + auto scalar_one = + DenseIntElementsAttr::get(one_cross_i64_type, ArrayRef(1)); + + // We need to swap the indices[i] with indices[swaps[i]]. First get + // these index values. + Value source_index = + builder->create(loc, indices, i, scalar_one); + Value swap_index = builder->create( + loc, scalar_i32_type, + builder->create(loc, swaps, i, scalar_one)); + Value target_index = builder->create( + loc, indices, swap_index, scalar_one); + + // Then perform the swap. + // indices[i] <- indices[swaps[i]] + indices = builder->create( + loc, indices.getType(), indices, target_index, llvm::ArrayRef(i)); + // indices[swaps[i]] <- indices[i] + indices = builder->create( + loc, indices.getType(), indices, source_index, + llvm::ArrayRef(swap_index)); + + // Update new values. + new_values->assign({swaps, indices}); + }; + + // Create a while op to swap indices. + SmallVector while_output; + CreateWhile32(op.getLoc(), first_dim_size, swap_body_fn, {swaps, indices}, + &while_output, &rewriter); + Value swaped_indices = while_output[1]; + + // Gather the data using the swapped indices as the shuffled order. + auto slice_sizes = tensorflow::ConvertMlirShapeToTF(input_type.getShape()); + slice_sizes[0] = 1; + auto dims_attr = GatherDimensionNumbersAttr::get( + rewriter.getContext(), + /*offsetDims=*/llvm::to_vector<4>(llvm::seq(1, input_rank)), + /*collapsedSliceDims=*/{0}, + /*operandBatchingDims=*/{}, + /*startIndicesBatchingDims=*/{}, + /*startIndexMap=*/{0}, + /*indexVectorDim=*/1); + + SmallVector slice_sizes_values; + for (auto i = 0; i < slice_sizes.size(); ++i) { + if (slice_sizes[i] == tensorflow::kTFDynamicSize) { + Value i_const = rewriter.create( + op.getLoc(), rewriter.getIndexAttr(i)); + Value slice_size_index = + rewriter.create(op.getLoc(), op.getValue(), i_const); + Value index_to_i64 = rewriter.create( + op.getLoc(), rewriter.getI64Type(), slice_size_index); + Value i64_to_tensor = rewriter.create( + op.getLoc(), + tensorflow::GetTypeFromTFTensorShape({1}, rewriter.getI64Type()), + index_to_i64); + slice_sizes_values.push_back(i64_to_tensor); + } else { + slice_sizes_values.push_back(rewriter.create( + op.getLoc(), GetI64ElementsAttr({slice_sizes[i]}, &rewriter))); + } + } + + auto slice_sizes_concat = rewriter.create( + op.getLoc(), slice_sizes_values, rewriter.getI64IntegerAttr(0)); + rewriter.replaceOpWithNewOp( + op, op.getType(), op.getValue(), swaped_indices, slice_sizes_concat, + dims_attr); + + return success(); + } +}; + +// Converts an XlaSharding op to a XLA HLO shard op with sharding attributes. +class ConvertXlaShardingOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TF::XlaShardingOp op, + PatternRewriter &rewriter) const override { + // TODO(b/148313088): define sharding attribute struct in MLIR intead of + // using a string. + if (!op.get_XlaSharding().has_value()) return failure(); + + NamedAttribute call_target_name = rewriter.getNamedAttr( + "call_target_name", rewriter.getStringAttr("Sharding")); + + auto custom_call = rewriter.create( + op.getLoc(), op.getType(), op.getInput(), + ArrayRef{call_target_name}); + custom_call->setAttr(kShardingAttr, op.get_XlaShardingAttr()); + rewriter.replaceOp(op, custom_call.getResult(0)); + + return success(); + } +}; + +// Converts a TF InplaceUpdate op to DynamicUpdateSlice HLO. +class ConvertInplaceUpdateOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TF::InplaceUpdateOp op, + PatternRewriter &rewriter) const override { + auto input = mlir::dyn_cast>(op.getX()); + if (!input) return failure(); + auto indices = op.getI(); + auto updates = op.getV(); + + // Slice each row of `i` and `v` to perform a separate dynamic-update-slice + // on the contents of `x`. + auto input_type = mlir::cast(input.getType()); + auto updates_type = mlir::cast(updates.getType()); + auto indices_type = mlir::cast(indices.getType()); + if (!input_type.hasRank()) return failure(); + if (!updates_type.hasRank() || updates_type.isDynamicDim(0)) + return failure(); + if (!indices_type.hasStaticShape()) return failure(); + + if (indices_type.getRank() != 1) return failure(); + + SmallVector unpacked_indices_type( + indices_type.getDimSize(0), tensorflow::GetTypeFromTFTensorShape( + {}, indices_type.getElementType())); + // Note on zero_attr integer type: DynamicUpdateSlice op start_indices are + // required to have matching types. This rewrite rule creates + // DynamicUpdateSlice ops where the first "start index" is always i32 and + // subsequent ones are constructed based on zero_attr. Thus the type + // for zero_attr needs to be i32 as well. + auto zero_attr = IntegerAttr::get(rewriter.getIntegerType(32), 0); + auto unpacked_indices = rewriter.create( + op.getLoc(), unpacked_indices_type, indices, zero_attr); + + SmallVector split_updates_shape; + split_updates_shape.append(updates_type.getShape().begin(), + updates_type.getShape().end()); + split_updates_shape.front() = 1; + SmallVector split_updates_type; + split_updates_type.resize( + updates_type.getShape().front(), + tensorflow::GetTypeFromTFTensorShape(split_updates_shape, + updates_type.getElementType())); + + auto cst = + rewriter.create(op.getLoc(), zero_attr).getResult(); + auto split_updates = rewriter.create( + op.getLoc(), split_updates_type, cst, updates); + + SmallVector input_indices; + input_indices.resize(input_type.getRank(), cst); + + for (auto pair : + llvm::zip(unpacked_indices.getOutput(), split_updates.getOutput())) { + input_indices.front() = std::get<0>(pair); + input = rewriter.create( + op.getLoc(), op.getType(), input, std::get<1>(pair), input_indices); + } + + rewriter.replaceOp(op, input); + return success(); + } +}; + +// Converts a TF XlaDynamicUpdateSlice op to DynamicUpdateSlice HLO. +class ConvertXlaDynamicUpdateSliceOp + : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TF::XlaDynamicUpdateSliceOp op, + PatternRewriter &rewriter) const override { + auto indices_type = + mlir::dyn_cast(op.getIndices().getType()); + if (!indices_type || !indices_type.hasStaticShape() || + indices_type.getShape().size() != 1) + return failure(); + + SmallVector unpacked_indices_type( + indices_type.getDimSize(0), tensorflow::GetTypeFromTFTensorShape( + {}, indices_type.getElementType())); + auto unpacked_indices = rewriter.create( + op.getLoc(), unpacked_indices_type, op.getIndices(), + IntegerAttr::get(rewriter.getIntegerType(64), 0)); + rewriter.replaceOpWithNewOp( + op, op.getType(), op.getInput(), op.getUpdate(), + unpacked_indices.getOutput()); + return success(); + } +}; + +// Converts a TF XlaReduceScatter op to ReduceScatter HLO. +class ConvertXlaReduceScatterOp + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TF::XlaReduceScatterOp op, + PatternRewriter &rewriter) const override { + DenseIntElementsAttr group_assignment; + if (!matchPattern(op.getGroupAssignment(), m_Constant(&group_assignment))) + return failure(); + auto replica_groups = + mlir::cast(hlo::convertElementsAttr( + group_assignment, rewriter.getIntegerType(64))); + if (replica_groups.getType().getRank() != 2) return failure(); + + APInt scatter_dimension; + if (!matchPattern(op.getScatterDimension(), + m_ConstantInt(&scatter_dimension))) + return failure(); + + Location loc = op.getLoc(); + Type element_type = getElementTypeOrSelf(op.getInput().getType()); + + auto reduce_scatter = rewriter.create( + loc, op.getType(), op.getInput(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), + scatter_dimension.getSExtValue()), + replica_groups, ChannelHandleAttr()); + StringRef reduce_op = op.getReduceOp(); + if (reduce_op == "Add") { + BuildReduceBody(element_type, &reduce_scatter.getComputation(), + &rewriter); + } else if (reduce_op == "Mul") { + BuildReduceBody(element_type, &reduce_scatter.getComputation(), + &rewriter); + } else if (reduce_op == "Min") { + BuildReduceBody(element_type, &reduce_scatter.getComputation(), + &rewriter); + } else if (reduce_op == "Max") { + BuildReduceBody(element_type, &reduce_scatter.getComputation(), + &rewriter); + } else { + // For mean, add replicas in the same group. Then divide the sum by the + // number of replicas in each group below. + assert(reduce_op == "Mean"); + BuildReduceBody(element_type, &reduce_scatter.getComputation(), + &rewriter); + } + Value result = reduce_scatter.getResult(); + + // For mean, divide the merge result by group size. + if (reduce_op == "Mean") { + int64_t replica_group_size = replica_groups.getType().getDimSize(1); + if (replica_group_size == 0) return failure(); + auto divisor = GetScalarConstOfType(element_type, loc, replica_group_size, + &rewriter); + auto broadcast_dims = rewriter.getDenseI64ArrayAttr({}); + result = rewriter.create( + loc, result, divisor.getResult(), broadcast_dims); + } + + rewriter.replaceOp(op, {result}); + return success(); + } +}; + +// Converts tf.XlaReduceWindow to mhlo.ReduceWindow +class ConvertXlaReduceWindowOp + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TF::XlaReduceWindowOp op, + PatternRewriter &rewriter) const override { + DenseElementsAttr window_dimensions, window_strides, base_dilations, + window_dilations, padding; + if (!(matchPattern(op.getWindowDimensions(), + m_Constant(&window_dimensions)) && + matchPattern(op.getWindowStrides(), m_Constant(&window_strides)) && + matchPattern(op.getBaseDilations(), m_Constant(&base_dilations)) && + matchPattern(op.getWindowDilations(), + m_Constant(&window_dilations)) && + matchPattern(op.getPadding(), m_Constant(&padding)))) + return failure(); + + Location loc = op.getLoc(); + + SmallVector result_types{op.getResult().getType()}; + // Create the mhlo.SelectAndScatter op. + auto reduce_window_op = rewriter.create( + loc, result_types, op.getInput(), op.getInitValue(), + mlir::cast(hlo::convertElementsAttr( + window_dimensions, rewriter.getIntegerType(64))), + mlir::cast(hlo::convertElementsAttr( + window_strides, rewriter.getIntegerType(64))), + mlir::cast(hlo::convertElementsAttr( + base_dilations, rewriter.getIntegerType(64))), + mlir::cast(hlo::convertElementsAttr( + window_dilations, rewriter.getIntegerType(64))), + mlir::cast( + hlo::convertElementsAttr(padding, rewriter.getIntegerType(64)))); + // Insert a call to the reducer in the region of the mhlo op. + mlir::SymbolRefAttr func = op.getComputation(); + auto func_op = cast(SymbolTable::lookupSymbolIn( + op->getParentOfType(), func)); + auto func_ty = func_op.getFunctionType(); + BuildBodyWithCall(rewriter, loc, func, func_ty, + &reduce_window_op.getBody()); + + rewriter.replaceOp(op, reduce_window_op.getResults()); + + return success(); + } +}; + +// Converts ClipByValue to XLA's clamp operation. Includes the broadcasting +// semantics for static and dynamic cases. +class ConvertClipByValueOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TF::ClipByValueOp op, + PatternRewriter &rewriter) const override { + Value input = op.getX(); + Value min = op.getClipValueMin(); + Value max = op.getClipValueMax(); + + auto input_ty = mlir::cast(input.getType()); + auto min_ty = mlir::cast(min.getType()); + auto max_ty = mlir::cast(max.getType()); + + if (!input_ty.hasRank() || !min_ty.hasRank() || !max_ty.hasRank()) { + return failure(); + } + + auto shape = rewriter.create( + op.getLoc(), + tensorflow::GetTypeFromTFTensorShape({input_ty.getRank()}, + rewriter.getI32Type()), + input); + + if (min_ty != input_ty) { + min = + rewriter.create(op.getLoc(), input_ty, min, shape); + } + + if (max_ty != input_ty) { + max = + rewriter.create(op.getLoc(), input_ty, max, shape); + } + + rewriter.replaceOpWithNewOp(op, input_ty, min, input, max); + return success(); + } +}; + +// Converts ConstOp to XLA's constant operation and introduces a tensor cast if +// needed. +class ConvertConstOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TF::ConstOp op, + PatternRewriter &rewriter) const override { + // Convert only for valid HLO tensors. + auto ty = mlir::dyn_cast(op.getType()); + if (!ty || + !mlir::isa(ty.getElementType())) + return failure(); + + Location loc = op.getLoc(); + Value result = rewriter.create(loc, op.getValue()); + if (result.getType() != op.getType()) + result = rewriter.create(loc, op.getType(), result); + rewriter.replaceOp(op, result); + return success(); + } +}; + +// Converts the Cumsum or Cumprod TensorFlow op to the HLO ReduceWindow op by +// setting appropriate window dimensions, with the given aggregation op as the +// reduction function. The input tensor needs to have a static shape, and 'axis' +// must be const. The TableGen pattern is not used for this rewrite because it +// involves regions. +template +class ConvertCumOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(OpT op, + PatternRewriter &rewriter) const override { + auto input = mlir::dyn_cast>(op.getX()); + if (!input) return failure(); + auto input_type = mlir::dyn_cast(input.getType()); + if (!input_type || !input_type.hasStaticShape()) { + return failure(); + } + + ArrayRef input_shape = input_type.getShape(); + int64_t rank = input_shape.size(); + + // We can only match when the axis is a constant scalar. + DenseIntElementsAttr axis_attr; + if (!matchPattern(op.getAxis(), m_Constant(&axis_attr))) { + return failure(); + } + + // Get the dimension to apply the reduction on, and offset properly if it is + // negative. + int64_t axis = (*axis_attr.begin()).getSExtValue(); + if (axis < 0) { + axis += rank; + } + + // If we're supposed to sum things up in the reverse direction, we reverse + // the input and then later reverse the output. + if (op.getReverse()) { + llvm::SmallVector dims_to_reverse({axis}); + input = rewriter.create( + op.getLoc(), input, GetI64ElementsAttr(dims_to_reverse, &rewriter)); + } + + // Convert if we need to enlarge the element type's bitwidth to avoid + // precision loss. + Type input_element_type = input_type.getElementType(); + + // TODO(hinsu): Handle complex element types. + if (!input_element_type.isIntOrFloat()) return failure(); + + Type sum_element_type = GetSumAccumulationType(input_element_type); + input = rewriter.create(op.getLoc(), input, sum_element_type); + + SmallVector window_dims(rank, 1); + SmallVector window_strides(rank, 1); + window_dims[axis] = input_shape[axis]; + + SmallVector paddings(rank * 2, 0); + paddings[axis * 2] = + std::max(input_shape[axis] - 1, static_cast(0)); + auto paddings_attr = + DenseIntElementsAttr::get(tensorflow::GetTypeFromTFTensorShape( + {rank, 2}, rewriter.getIntegerType(64)), + paddings); + + int64_t init_value = (std::is_same::value) ? 0 : 1; + Value init = GetScalarConstOfType(sum_element_type, op.getLoc(), init_value, + &rewriter); + + auto reduce = rewriter.create( + op.getLoc(), input.getType(), input, init, + GetI64ElementsAttr(rewriter.getI64ArrayAttr(window_dims)), + GetI64ElementsAttr(rewriter.getI64ArrayAttr(window_strides)), + /*base_dilations=*/DenseIntElementsAttr(), + /*window_dilations=*/DenseIntElementsAttr(), paddings_attr); + BuildReduceBody(sum_element_type, &reduce.getBody(), + &rewriter); + Value result = reduce.getResult(0); + + if (op.getExclusive()) { + // In "exclusive" operation, the output will start with the "init" (0) + // values. There is no way to express that as a ReduceWindowOp, so run the + // normal operation, and then use a PadOp to add the 0 "column" on the + // left and cut away the last column on the right. + llvm::SmallVector low_padding(rank, 0); + llvm::SmallVector high_padding(rank, 0); + llvm::SmallVector interior_padding(rank, 0); + low_padding[axis] = 1; + high_padding[axis] = -1; + result = rewriter.create( + op.getLoc(), result, init, GetI64ElementsAttr(low_padding, &rewriter), + GetI64ElementsAttr(high_padding, &rewriter), + GetI64ElementsAttr(interior_padding, &rewriter)); + } + + // Convert back if we enlarged the element type's bitwidth. + result = + rewriter.create(op.getLoc(), result, input_element_type); + + if (op.getReverse()) { + llvm::SmallVector dims_to_reverse({axis}); + result = rewriter.create( + op.getLoc(), result, GetI64ElementsAttr(dims_to_reverse, &rewriter)); + } + + rewriter.replaceOp(op, result); + return success(); + } +}; + +using ConvertCumsumOp = ConvertCumOp; +using ConvertCumprodOp = ConvertCumOp; + +// Converts the Tensorflow ShapeOp to a sequence of Shape dialect and Standard +// dialect lowerings. This involves extracting the shape type, extracting and +// converting each dimension to a known integer type, and repacking into a final +// tensor. +class ConvertShapeOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TF::ShapeOp op, + PatternRewriter &rewriter) const override { + Value input = op.getInput(); + + auto result_ty = mlir::dyn_cast(op.getResult().getType()); + if (!result_ty) { + return failure(); + } + + auto index_tensor = tensorflow::GetTypeFromTFTensorShape( + result_ty.getShape(), rewriter.getIndexType()); + auto shape_op = + rewriter.create(op.getLoc(), index_tensor, input); + rewriter.replaceOpWithNewOp(op, result_ty, shape_op); + return success(); + } +}; + +class ConvertDynamicExpandDimsOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TF::ExpandDimsOp op, + PatternRewriter &rewriter) const override { + auto input = op.getInput(); + auto input_ty = mlir::cast(input.getType()); + auto result_ty = mlir::cast(op.getType()); + if (!result_ty.hasRank() || !input_ty.hasRank() || + result_ty.hasStaticShape()) { + return failure(); + } + + DenseIntElementsAttr expand_dims_attr; + if (!matchPattern(op.getDim(), m_Constant(&expand_dims_attr))) { + return failure(); + } + + auto shape = rewriter.create( + op.getLoc(), + tensorflow::GetTypeFromTFTensorShape({input_ty.getRank()}, + rewriter.getIndexType()), + input); + auto expand_dims = llvm::to_vector<6>(expand_dims_attr.getValues()); + + llvm::SmallVector dims; + dims.resize(result_ty.getRank()); + + auto inserted_dim = expand_dims[0].getSExtValue(); + + // Handle the negative value use case. + if (inserted_dim < 0) { + inserted_dim += result_ty.getRank(); + // This means the value is completely incorrect, just return. + if (inserted_dim < 0) { + return failure(); + } + } + + dims[inserted_dim] = + rewriter.create(op.getLoc(), 1); + + for (int i = 0; i < dims.size() - 1; i++) { + // Add the extracted dim. + Value index = rewriter.create(op.getLoc(), i); + Value dim = rewriter.create(op.getLoc(), shape, index); + dims[i >= inserted_dim ? i + 1 : i] = dim; + } + + auto from_extents = + rewriter.create(op.getLoc(), dims); + rewriter.replaceOpWithNewOp(op, result_ty, input, + from_extents); + return success(); + } +}; + +class ConvertDynamicSqueezeOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TF::SqueezeOp op, + PatternRewriter &rewriter) const override { + auto input = op.getInput(); + auto input_ty = mlir::cast(input.getType()); + auto result_ty = mlir::cast(op.getType()); + if (!result_ty.hasRank() || !input_ty.hasRank() || + result_ty.hasStaticShape()) { + return failure(); + } + + // The fully dynamic case is unsupported. + if (op.getSqueezeDims().empty()) { + return failure(); + } + + SmallVector squeeze_dims; + int64_t input_rank = input_ty.getRank(); + for (const auto &squeeze_dim_apint : + op.getSqueezeDims().getAsValueRange()) { + int64_t squeeze_dim = squeeze_dim_apint.getSExtValue(); + // Handle negative inputs. + if (squeeze_dim < 0) squeeze_dim += input_rank; + assert(squeeze_dim >= 0 && squeeze_dim < input_rank && + "squeeze dim out of bounds"); + + squeeze_dims.push_back(squeeze_dim); + } + + // Collect the unsqueezed dimensions. + llvm::SmallVector dims; + for (int64_t i = 0; i != input_rank; ++i) { + if (llvm::is_contained(squeeze_dims, i)) continue; + dims.push_back(rewriter.create(op.getLoc(), input, i)); + } + + auto from_extents = + rewriter.create(op.getLoc(), dims); + rewriter.replaceOpWithNewOp(op, result_ty, input, + from_extents); + return success(); + } +}; + +// Converts tf.XlaConvV2 to mhlo.Conv +class ConvertXlaConvV2Op : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TF::XlaConvV2Op op, + PatternRewriter &rewriter) const override { + DenseElementsAttr window_strides_attr, padding_attr, lhs_dilation_attr, + rhs_dilation_attr, feature_group_count_attr; + if (!(matchPattern(op.getWindowStrides(), + m_Constant(&window_strides_attr)) && + matchPattern(op.getPadding(), m_Constant(&padding_attr)) && + matchPattern(op.getLhsDilation(), m_Constant(&lhs_dilation_attr)) && + matchPattern(op.getRhsDilation(), m_Constant(&rhs_dilation_attr)) && + matchPattern(op.getFeatureGroupCount(), + m_Constant(&feature_group_count_attr)))) + return failure(); + + auto window_strides_named_attr = rewriter.getNamedAttr( + "window_strides", + mlir::cast(hlo::convertElementsAttr( + window_strides_attr, rewriter.getIntegerType(64)))); + + auto padding_named_attr = rewriter.getNamedAttr( + "padding", mlir::cast(hlo::convertElementsAttr( + padding_attr, rewriter.getIntegerType(64)))); + + auto lhs_dilation_named_attr = rewriter.getNamedAttr( + "lhs_dilation", + mlir::cast(hlo::convertElementsAttr( + lhs_dilation_attr, rewriter.getIntegerType(64)))); + + auto rhs_dilation_named_attr = rewriter.getNamedAttr( + "rhs_dilation", + mlir::cast(hlo::convertElementsAttr( + rhs_dilation_attr, rewriter.getIntegerType(64)))); + + int64_t feature_group_count_val = + feature_group_count_attr.getValues()[0].getInt(); + auto feature_group_count_named_attr = rewriter.getNamedAttr( + "feature_group_count", + rewriter.getI64IntegerAttr(feature_group_count_val)); + + auto batch_group_count_named_attr = + rewriter.getNamedAttr("batch_group_count", op.getBatchGroupCountAttr()); + + xla::ConvolutionDimensionNumbers dnums; + dnums.ParseFromString(op.getDimensionNumbersAttr().getValue().str()); + auto dimension_numbers_named_attr = rewriter.getNamedAttr( + "dimension_numbers", + xla::ConvertConvDimensionNumbers(dnums, &rewriter)); + + xla::PrecisionConfig precision_config; + precision_config.ParseFromString( + op.getPrecisionConfigAttr().getValue().str()); + auto precision_config_named_attr = rewriter.getNamedAttr( + "precision_config", + xla::ConvertPrecisionConfig(&precision_config, &rewriter)); + + SmallVector operands{op.getLhs(), op.getRhs()}; + NamedAttribute attrs[] = { + window_strides_named_attr, padding_named_attr, + lhs_dilation_named_attr, rhs_dilation_named_attr, + feature_group_count_named_attr, batch_group_count_named_attr, + dimension_numbers_named_attr, precision_config_named_attr}; + rewriter.replaceOpWithNewOp(op, op.getType(), operands, + llvm::ArrayRef(attrs)); + return success(); + } +}; + +// Converts tf.XlaSelectAndScatter to mhlo.SelectAndScatter +class ConvertXlaSelectAndScatterOp + : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TF::XlaSelectAndScatterOp op, + PatternRewriter &rewriter) const override { + ElementsAttr window_dimensions, window_strides, padding; + if (!(matchPattern(op.getWindowDimensions(), + m_Constant(&window_dimensions)) && + matchPattern(op.getWindowStrides(), m_Constant(&window_strides)) && + matchPattern(op.getPadding(), m_Constant(&padding)))) + return failure(); + + Location loc = op.getLoc(); + + SmallVector result_types{op.getResult().getType()}; + // Create the mhlo.SelectAndScatter op. + auto select_and_scatter_op = rewriter.create( + loc, result_types, op.getOperand(), op.getSource(), op.getInitValue(), + mlir::cast(hlo::convertElementsAttr( + window_dimensions, rewriter.getIntegerType(64))), + mlir::cast(hlo::convertElementsAttr( + window_strides, rewriter.getIntegerType(64))), + mlir::cast( + hlo::convertElementsAttr(padding, rewriter.getIntegerType(64)))); + + auto insert_call_to = [&](const mlir::SymbolRefAttr &func, Region *region) { + auto func_op = cast(SymbolTable::lookupSymbolIn( + op->getParentOfType(), func)); + auto func_ty = func_op.getFunctionType(); + BuildBodyWithCall(rewriter, loc, func, func_ty, region); + }; + + // Insert a call to the select function in the select region of the mhlo op. + insert_call_to(op.getSelect(), &select_and_scatter_op.getSelect()); + // Insert a call to the scatter function in the scatter region of the mhlo + // op. + insert_call_to(op.getScatter(), &select_and_scatter_op.getScatter()); + + rewriter.replaceOp(op, select_and_scatter_op.getResult()); + + return success(); + } +}; + +// Convert tf.XlaSort to mhlo.Sort +class ConvertXlaSortOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TF::XlaSortOp op, + PatternRewriter &rewriter) const override { + // Create the sort op. + Type element_type = getElementTypeOrSelf(op.getInput().getType()); + auto sort_op = + createSortOp(&rewriter, op.getLoc(), {op.getInput()}, {element_type}, + /*dimension=*/-1, /*isStable=*/false, + /*direction=*/ComparisonDirection::LT); + rewriter.replaceOp(op, sort_op.getResult(0)); + return success(); + } +}; + +inline std::optional TensorFlowRngAlgToXla( + tensorflow::Algorithm alg) { + if (alg == tensorflow::RNG_ALG_PHILOX) { + return xla::RandomAlgorithm::RNG_PHILOX; + } else if (alg == tensorflow::RNG_ALG_THREEFRY) { + return xla::RandomAlgorithm::RNG_THREE_FRY; + } else if (alg == tensorflow::RNG_ALG_AUTO_SELECT) { + return xla::RandomAlgorithm::RNG_DEFAULT; + } + return std::nullopt; +} + +// Converts tf.XlaRngBitGenerator op to mhlo.RngBitGenerator op. +class ConvertXlaRngBitGeneratorOp + : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TF::XlaRngBitGeneratorOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + DenseElementsAttr algorithm; + if (!(matchPattern(op.getAlgorithm(), m_Constant(&algorithm))) || + algorithm.getType().getRank()) { + return op.emitOpError() << "algorithm must be a constant scalar"; + } + auto alg = static_cast( + algorithm.getValues()[0].getInt()); + auto xla_alg = TensorFlowRngAlgToXla(alg); + if (!xla_alg) { + return op.emitOpError() << "unknown algorithm"; + } + + auto algorithm_attr = mlir::mhlo::RngAlgorithmAttr::get( + rewriter.getContext(), + *mlir::mhlo::symbolizeRngAlgorithm(xla_alg.value())); + auto rng_bit_generator_op = rewriter.create( + loc, op.getResultTypes(), algorithm_attr, op.getInitialState()); + + rewriter.replaceOp(op, rng_bit_generator_op.getResults()); + + return success(); + } +}; + +// Converts tf.XlaVariadicReduceV2 to mhlo.Reduce +class ConvertXlaVariadicReduceV2Op + : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TF::XlaVariadicReduceV2Op op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + + mlir::SymbolRefAttr func = op.getReducer(); + auto func_op = cast(SymbolTable::lookupSymbolIn( + op->getParentOfType(), func)); + auto func_ty = func_op.getFunctionType(); + SmallVector elementTypes{llvm::map_range( + func_ty.getResults(), + [](Type ty) { return mlir::cast(ty).getElementType(); })}; + + // Create the mhlo.reduce op. + auto reduce_op = rewriter.create( + loc, op.getInputs(), op.getInitValues(), + GetI64ElementsAttr(op.getDimensionsToReduce()), elementTypes); + + // Insert a call to the reducer in the region of the mhlo op. + BuildBodyWithCall(rewriter, loc, func, func_ty, &reduce_op.getBody()); + + rewriter.replaceOp(op, reduce_op.getResults()); + + return success(); + } +}; + +// Convert tf.XlaVariadicSort to mhlo.Sort +class ConvertXlaVariadicSortOp + : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TF::XlaVariadicSortOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + ElementsAttr dimension; + matchPattern(op.getDimension(), m_Constant(&dimension)); + // Create the mhlo.sort op. + auto sort_op = rewriter.create( + loc, op.getInputs(), dimension.getValues()[0].getInt(), + op.getIsStable()); + mlir::SymbolRefAttr func = op.getComparator(); + auto func_op = cast(SymbolTable::lookupSymbolIn( + op->getParentOfType(), func)); + auto func_ty = func_op.getFunctionType(); + // Insert a call to the reducer in the region of the mhlo op. + BuildBodyWithCall(rewriter, loc, func, func_ty, &sort_op.getComparator()); + + rewriter.replaceOp(op, sort_op.getResults()); + return success(); + } +}; + +// Convert tf.XlaReducePrecision to mhlo.ReducePrecision +class ConvertXlaReducePrecisionOp + : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TF::XlaReducePrecisionOp op, + PatternRewriter &rewriter) const override { + IntegerType int32_type = rewriter.getIntegerType(32); + APInt exponent_bits = op.getExponentBitsAttr().getValue(); + // Truncating to 32-bits is safe, since pasing any number above the dtype + // size (which is at most 64, for float64) is equivalent to passing the + // dtype size. + IntegerAttr new_exponent_attr = + IntegerAttr::get(int32_type, exponent_bits.truncSSat(32)); + APInt mantissa_bits = op.getMantissaBitsAttr().getValue(); + IntegerAttr new_mantissa_attr = + IntegerAttr::get(int32_type, mantissa_bits.truncSSat(32)); + rewriter.replaceOpWithNewOp( + op, op.getType(), op.getOperand(), new_exponent_attr, + new_mantissa_attr); + return success(); + } +}; + +class LowerYieldOp : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + TF::YieldOp op, TF::YieldOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp(op, adaptor.getOperands()); + return success(); + } +}; + +// Returns a new tensor type from the given type with element type updated to +// the given type. +TensorType UpdateElementTypeTo(Type ty, Type element_ty) { + auto ranked_ty = mlir::dyn_cast(ty); + if (!ranked_ty) { + return UnrankedTensorType::get(element_ty); + } + return RankedTensorType::get(ranked_ty.getShape(), element_ty, + ranked_ty.getEncoding()); +} + +template +class LowerControlFlowOp : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + SrcOpT op, typename SrcOpT::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + DstOpT mhlo_op; + Location loc = op.getLoc(); + + // To handle quant type conversions, use the converted operands' element + // types and original source op's shapes and encoding to get converted op's + // result types. This is only done for the While op for now. + llvm::SmallVector element_types; + int64_t num_results = op.getNumResults(); + if constexpr (std::is_same::value) { + element_types.reserve(num_results); + for (Value value : adaptor.getOperands()) { + element_types.push_back(getElementTypeOrSelf(value.getType())); + } + } + + if constexpr (std::is_same::value) { + // Explicitly handle the Case op because it has variadic regions and takes + // the number of regions as an input along with the operands. + mhlo_op = rewriter.create(loc, op.getResultTypes(), + adaptor.getBranchIndex(), + op.getBranches().size()); + } else if constexpr (std::is_same::value) { + llvm::SmallVector while_result_types; + while_result_types.reserve(num_results); + for (int64_t idx = 0; idx < num_results; ++idx) { + auto ty = UpdateElementTypeTo(op.getType(idx), element_types[idx]); + while_result_types.push_back(ty); + } + + mhlo_op = rewriter.create(loc, TypeRange(while_result_types), + adaptor.getOperands()); + } else { + mhlo_op = rewriter.create(loc, op.getResultTypes(), + adaptor.getOperands()); + } + + int64_t num_regions = op.getNumRegions(); + for (int64_t idx = 0; idx < num_regions; ++idx) { + Region ®ion = mhlo_op.getBodyRegion(idx); + rewriter.inlineRegionBefore(op.getBodyRegion(idx), region, region.end()); + + // Update region's entry blocks argument types to handle quantized element + // types. + if constexpr (std::is_same::value) { + TypeConverter::SignatureConversion signature(num_results); + Block &block = region.front(); + for (const auto &[block_idx, original_ty] : + llvm::enumerate(block.getArgumentTypes())) { + TensorType updated_ty = + UpdateElementTypeTo(original_ty, element_types[block_idx]); + signature.addInputs(block_idx, {updated_ty}); + } + rewriter.applySignatureConversion(®ion.front(), signature); + } + } + + // Replace all uses of `op` results with the newly created op. + rewriter.replaceOp(op, mhlo_op); + return success(); + } +}; + +// Keep all these in the odml namespace to avoid collisions with the tf2xla +// version for now. +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/generated_legalize_tf.inc" + +// LINT.IfChange +void PopulatePatterns(MLIRContext *context, RewritePatternSet *patterns) { + populateWithGenerated(*patterns); + // clang-format off + patterns->add< + ConvertAllOp, + ConvertAnyOp, + ConvertArgMaxOp, + ConvertArgMinOp, + ConvertBatchMatMulV2Op, + ConvertBiasAddOp, + ConvertBroadcastToOp, + ConvertBF16FloorDivOp, + ConvertClipByValueOp, + ConvertConstOp, + ConvertConv2DOp, + ConvertConv3DOp, + ConvertDepthConv2DOp, + ConvertConv2DBackpropFilterOp, + ConvertConv3DBackpropFilterOp, + ConvertConv2DBackpropInputOp, + ConvertConv3DBackpropInputOp, + ConvertCumprodOp, + ConvertCumsumOp, + ConvertDiagPartOp, + ConvertDynamicExpandDimsOp, + ConvertDynamicSqueezeOp, + ConvertEinsumOp, + ConvertRFFTOp, + ConvertIRFFTOp, + ConvertFusedBatchNormGradOp, + ConvertFusedBatchNormGradV2Op, + ConvertFusedBatchNormGradV3Op, + ConvertFusedBatchNormV2Op, + ConvertFusedBatchNormV3Op, + ConvertInfeedDequeueTupleOp, + ConvertIdentityNOp, + ConvertInplaceUpdateOp, + ConvertLinSpaceOp, + ConvertMaxOp, + ConvertMinOp, + ConvertAvgPool2DOp, + ConvertAvgPool3DOp, + ConvertAvgPool2DGradOp, + ConvertAvgPool3DGradOp, + ConvertMaxPool2DOp, + ConvertMaxPool3DOp, + ConvertMaxPool2DGradOp, + ConvertMaxPool3DGradOp, + ConvertMeanOp, + ConvertOneHotOp, + ConvertOutfeedEnqueueTupleOp, + ConvertProdOp, + ConvertDynamicRangeOp, + ConvertMatrixDiagPartV3Op, + ConvertRangeOp, + ConvertSelectOp, + ConvertShapeOp, + ConvertSplitOp, + ConvertSplitVOp, + ConvertStridedSliceOp, + ConvertStridedSliceGradOp, + ConvertSumOp, + ConvertTensorScatterAddOp, + ConvertTensorScatterSubOp, + ConvertTensorScatterMinOp, + ConvertTensorScatterMaxOp, + ConvertTensorScatterUpdateOp, + ConvertTileOp, + ConvertTopKV2Op, + ConvertUnpackOp, + ConvertUnsortedSegmentMaxOp, + ConvertUnsortedSegmentMinOp, + ConvertUnsortedSegmentProdOp, + ConvertUnsortedSegmentSumOp, + ConvertRandomShuffleOp, + ConvertXlaShardingOp, + ConvertXlaDynamicUpdateSliceOp, + ConvertXlaConvV2Op, + ConvertXlaReducePrecisionOp, + ConvertXlaReduceScatterOp, + ConvertXlaReduceWindowOp, + ConvertXlaRngBitGeneratorOp, + ConvertXlaSelectAndScatterOp, + ConvertXlaSortOp, + ConvertXlaVariadicReduceV2Op, + ConvertXlaVariadicSortOp, + ConvertRollOp, + ConvertLeakyReluOp, + ConvertLeakyReluGradOp, + ConvertSplitOpDynamic, + ConvertSliceOpDynamic, + ConvertTileOpDynamic, + ConvertUnpackOpDynamic, + ConvertSigmoidGradOpDynamic, + ConvertConv2DDynamic, + ConvertPadOpDynamic, + ConvertGatherNdOpDynamic, + LowerControlFlowOp, + LowerControlFlowOp, + LowerControlFlowOp, + LowerYieldOp>(context); + // clang-format on +} +// LINT.ThenChange(:MlirAlwaysOps) +} // end namespace +} // end namespace mhlo + +namespace odml { +void PopulateLegalizeTfPatterns(MLIRContext *context, + RewritePatternSet *patterns) { + mlir::mhlo::PopulatePatterns(context, patterns); +} +} // end namespace odml +} // end namespace mlir diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf_passes.h b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf_passes.h new file mode 100644 index 00000000000000..9594769e93f71c --- /dev/null +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf_passes.h @@ -0,0 +1,51 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_TF_PASSES_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_TF_PASSES_H_ + +#include +#include + +#include "llvm/ADT/StringRef.h" +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project + +namespace mlir { + +namespace func { +class FuncOp; +} +class ModuleOp; +class Operation; +template +class OperationPass; +class Pass; + +namespace odml { + +/// Adds the TF to TF lowerings and TF to XLA rewrite patterns to the pattern +/// list. +void PopulateLegalizeTfPatterns(MLIRContext* context, + RewritePatternSet* patterns); + +} // namespace odml +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_TF_PASSES_H_ diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf_patterns.td b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf_patterns.td new file mode 100644 index 00000000000000..185216448a15ed --- /dev/null +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf_patterns.td @@ -0,0 +1,802 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This is the legalization pattern definition file for TF to XLA. + +include "mlir/IR/OpBase.td" +include "mlir/Dialect/Shape/IR/ShapeOps.td" +include "mlir/Dialect/Func/IR/FuncOps.td" +include "mlir/Dialect/Tensor/IR/TensorOps.td" +include "stablehlo/dialect/ChloOps.td" +include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td" +include "mhlo/IR/hlo_ops.td" + +def SignedIntTensor : TensorOf<[I1, I8, I16, I32, I64]>; +def UnsignedIntTensor : TensorOf<[UI8, UI16, UI32, UI64]>; + +// IEEE compliant floating point tensors. +def IEEEFloatTensor : TensorOf<[F16, F32, F64]>; + +//===----------------------------------------------------------------------===// +// BatchNorm op patterns. +//===----------------------------------------------------------------------===// + +def FalseBoolAttr : AttrConstraint().getValue()">>; +def TrueBoolAttr : AttrConstraint().getValue()">>; + +def CastValueToI64: NativeCodeCall< + "CastValueToI64($0.getLoc(), $1, &$_builder)">; + +def CastValueToElementType: NativeCodeCall< + "$_builder.create($0.getLoc(), $1, " + "getElementTypeOrSelf($2.getType()))">; + +// Here, $0 is an ElementsAttr with exactly one element of type integer. $1 is +// the corresponding value of ranked tensor type whose axis is referred in $0. +def GetHLOAxisFromTFAxis : NativeCodeCall< + "GetHLOAxisFromTFAxis(" + "$0, $1.getType().cast().getRank(), &$_builder)">; + +// Same as the above but with $1 of type operand_range from variadic TensorFlow +// input. +def GetHLOAxisFromTFAxisVariadic : NativeCodeCall< + "GetHLOAxisFromTFAxis(" + "$0, (*$1.begin()).getType().cast().getRank(), " + "&$_builder)">; + +def CastElementsToI64Elements : NativeCodeCall< + "hlo::convertElementsAttr(" + "$0.cast(), $_builder.getIntegerType(64)).cast()">; + +def EmptyDotAlgorithmAttr : NativeCodeCall<"mlir::mhlo::DotAlgorithmAttr{}">; + +//===----------------------------------------------------------------------===// +// ApproximateEqual op pattern. +//===----------------------------------------------------------------------===// + +class MHLO_ComparisonDirectionValue : + ConstantAttr; + +class CHLO_ComparisonDirectionValue : + ConstantAttr; + +// TODO(b/228291745): Assert that $x and $y have the same shape. +def : Pat<(TF_ApproximateEqualOp:$result $x, $y, $tolerance), + (CHLO_BroadcastCompareOp + (MHLO_AbsOp:$abs (MHLO_SubtractOp $x, $y)), + (CastValueToElementType $result, (MHLO_ConstantOp $tolerance), $abs), + (NullDenseI64ArrayAttr), + CHLO_ComparisonDirectionValue<"LT">, + (CHLO_DEFAULT_COMPARISON_TYPE))>; + +//===----------------------------------------------------------------------===// +// Assert op pattern. +//===----------------------------------------------------------------------===// + +// HLO and XLA doesn't support Assertions. +def LowerAssert : Pattern<(TF_AssertOp $condition, $data, $summarize), []>; + +//===----------------------------------------------------------------------===// +// Binary op patterns. +//===----------------------------------------------------------------------===// + +// Check that two values can be broadcasted together +def AreBroadcastCompatible : Constraint, + "types must be broadcastable">; + +class DirectBinaryPat + : Pat<(FromOp AnyTensor:$l, AnyTensor:$r), + (ToOp $l, $r, (BinBroadcastDimensions $l, $r))>; + +foreach fromToBinPair = [[TF_AddV2Op, CHLO_BroadcastAddOp], + [TF_Atan2Op, CHLO_BroadcastAtan2Op], + [TF_ComplexOp, CHLO_BroadcastComplexOp], + [TF_DivOp, CHLO_BroadcastDivOp], + [TF_LeftShiftOp, CHLO_BroadcastShiftLeftOp], + [TF_MaximumOp, CHLO_BroadcastMaxOp], + [TF_MinimumOp, CHLO_BroadcastMinOp], + [TF_ModOp, CHLO_BroadcastRemOp], + [TF_MulOp, CHLO_BroadcastMulOp], + [TF_NextAfterOp, CHLO_BroadcastNextAfterOp], + [TF_PolygammaOp, CHLO_BroadcastPolygammaOp], + [TF_PowOp, CHLO_BroadcastPowOp], + [TF_RealDivOp, CHLO_BroadcastDivOp], + [TF_SubOp, CHLO_BroadcastSubOp], + [TF_ZetaOp, CHLO_BroadcastZetaOp]] in + def : DirectBinaryPat; + +def LowerRightShiftSigned : + Pat<(TF_RightShiftOp AnyTensor:$l, AnyTensor:$r), + (CHLO_BroadcastShiftRightArithmeticOp $l, $r, + (BinBroadcastDimensions $l, $r)), + [(SignedIntTensor $r)]>; + +def LowerRightShiftUnsigned : + Pat<(TF_RightShiftOp AnyTensor:$l, AnyTensor:$r), + (CHLO_BroadcastShiftRightLogicalOp $l, $r, + (BinBroadcastDimensions $l, $r)), + [(UnsignedIntTensor $r)]>; + +// Performs a substitution of FloorDiv, pseudo code below: +// +// return floor(div(x, y)) +def : Pat<(TF_FloorDivOp AnyTensor:$l, AnyTensor:$r), + (MHLO_FloorOp + (CHLO_BroadcastDivOp $l, $r, (BinBroadcastDimensions $l, $r))), + [(IEEEFloatTensor $l)]>; + +// Performs a substitution of FloorDiv for integer tensors, which required +// additional correction for a negative numerator / denominator. Equivalent +// pseudocode is shown below: +// +// T z = x / y +// return (z * y != x && (x < 0) != (y < 0)) ? z - 1 : z +// +// BroadcastToDimensions is used to compute the broadcast attr to higher +// dimensions. This computes the broadcast of 'l' to broadcast('l', 'r') +// without returning the broadcast of 'r' to broadcast('l', 'r'). +def : Pat<(TF_FloorDivOp AnyTensor:$l, AnyTensor:$r), + (MHLO_SelectOp + (CHLO_BroadcastAndOp + (CHLO_BroadcastCompareOp + (CHLO_BroadcastMulOp:$mul + (CHLO_BroadcastDivOp:$div $l, $r, + (BinBroadcastDimensions $l, $r)), + $r, (BinBroadcastDimensions $div, $r)), + $l, (BinBroadcastDimensions $mul, $l), CHLO_ComparisonDirectionValue<"NE">, + (CHLO_DEFAULT_COMPARISON_TYPE)), + (CHLO_BroadcastCompareOp + (CHLO_BroadcastCompareOp:$l_cmp $l, + (MHLO_ConstantOp:$l_zeros (GetScalarOfType<0> $l)), + (NullDenseI64ArrayAttr), CHLO_ComparisonDirectionValue<"LT">, + (CHLO_DEFAULT_COMPARISON_TYPE)), + (CHLO_BroadcastCompareOp:$r_cmp $r, + (MHLO_ConstantOp:$r_zeros (GetScalarOfType<0> $r)), + (NullDenseI64ArrayAttr), CHLO_ComparisonDirectionValue<"LT">, + (CHLO_DEFAULT_COMPARISON_TYPE)), + (BinBroadcastDimensions $l_cmp, $r_cmp), CHLO_ComparisonDirectionValue<"NE">, + (CHLO_DEFAULT_COMPARISON_TYPE)), + (NullDenseI64ArrayAttr)), + (CHLO_BroadcastSubOp $div, + (MHLO_ConstantOp:$ones (GetScalarOfType<1> $div)), + (NullDenseI64ArrayAttr)), $div), + [(SignedIntTensor $l)]>; + +// FloorDiv of unsigned is just div. +def : Pat<(TF_FloorDivOp AnyTensor:$l, AnyTensor:$r), + (CHLO_BroadcastDivOp $l, $r, (BinBroadcastDimensions $l, $r)), + [(UnsignedIntTensor $l)]>; + +// Performs a substitution of FloorMod designed to correct for possibly negative +// values. Pseudocode shown below: +// +// T trunc_mod = std::fmod(x, y); +// return trunc_mod != 0 && (y < 0 != trunc_mod < 0) ? trunc_mod + y +// : trunc_mod +def : Pat<(TF_FloorModOp AnyTensor:$l, AnyTensor:$r), + (MHLO_SelectOp + (CHLO_BroadcastAndOp + (CHLO_BroadcastCompareOp + (CHLO_BroadcastRemOp:$rem $l, $r, (BinBroadcastDimensions $l, $r)), + (MHLO_ConstantOp:$l_zeros (GetScalarOfType<0> $l)), + (NullDenseI64ArrayAttr), CHLO_ComparisonDirectionValue<"NE">, + (CHLO_DEFAULT_COMPARISON_TYPE)), + (CHLO_BroadcastCompareOp + (CHLO_BroadcastCompareOp:$r_cmp $r, + (MHLO_ConstantOp:$r_zeros (GetScalarOfType<0> $r)), + (NullDenseI64ArrayAttr), CHLO_ComparisonDirectionValue<"LT">, + (CHLO_DEFAULT_COMPARISON_TYPE)), + (CHLO_BroadcastCompareOp:$rem_cmp $rem, $r_zeros, + (BinBroadcastDimensions $rem, $r_zeros), CHLO_ComparisonDirectionValue<"LT">, + (CHLO_DEFAULT_COMPARISON_TYPE)), + (BinBroadcastDimensions $r_cmp, $rem_cmp), CHLO_ComparisonDirectionValue<"NE">, + (CHLO_DEFAULT_COMPARISON_TYPE)), + (NullDenseI64ArrayAttr)), + (CHLO_BroadcastAddOp $r, + $rem, (BinBroadcastDimensions $r, $rem)), $rem), + [(TensorOf<[I8, I16, I32, I64, F16, F32, F64]> $l)]>; + +// FloorMod of unsigned is just mod. +def : Pat<(TF_FloorModOp AnyTensor:$l, AnyTensor:$r), + (CHLO_BroadcastRemOp $l, $r, (BinBroadcastDimensions $l, $r)), + [(UnsignedIntTensor $l)]>; + +def Get2DTransposePerm: NativeCodeCall< + "Get2DTransposePerm($0, &$_builder)">; + +def : Pat<(TF_RiscAddOp $l, $r), (MHLO_AddOp $l, $r)>; + +def : Pat<(TF_RiscDotOp $a, $b, $transpose_a, $transpose_b), + (MHLO_DotOp + (TF_TransposeOp $a, (TF_ConstOp (Get2DTransposePerm $transpose_a))), + (TF_TransposeOp $b, (TF_ConstOp (Get2DTransposePerm $transpose_b))), + /*precision_config=*/(NullArrayAttr))>; + +//===----------------------------------------------------------------------===// +// Logical & bitwise binary op patterns. +//===----------------------------------------------------------------------===// + +class DirectLogicalBinaryPat + : Pat<(FromOp AnyTensor:$l, AnyTensor:$r), + (ToOp $l, $r, (BinBroadcastDimensions $l, $r)), + [(AnyTypeOf<[SignedIntTensor, UnsignedIntTensor]> $l)]>; + +foreach fromToBinPair = [[TF_LogicalAndOp, CHLO_BroadcastAndOp], + [TF_LogicalOrOp, CHLO_BroadcastOrOp], + [TF_BitwiseAndOp, CHLO_BroadcastAndOp], + [TF_BitwiseOrOp, CHLO_BroadcastOrOp], + [TF_BitwiseXorOp, CHLO_BroadcastXorOp]] in + def : DirectLogicalBinaryPat; + +//===----------------------------------------------------------------------===// +// Compare op patterns. +//===----------------------------------------------------------------------===// + +class DirectComparePat + : Pat<(FromOp AnyTensor:$l, AnyTensor:$r), + (CHLO_BroadcastCompareOp + $l, $r, (BinBroadcastDimensions $l, $r), direction, + (CHLO_DEFAULT_COMPARISON_TYPE))>; + +def : DirectComparePat>; +def : DirectComparePat>; +def : DirectComparePat>; +def : DirectComparePat>; + +class EqualityPat + : Pat<(FromOp AnyTensor:$l, AnyTensor:$r, + TrueBoolAttr:$incompatible_shape_error), + (CHLO_BroadcastCompareOp + $l, $r, (BinBroadcastDimensions $l, $r), direction, + (CHLO_DEFAULT_COMPARISON_TYPE)), + [(MHLO_Tensor $l)]>; + +def : EqualityPat>; +def : EqualityPat>; + +//===----------------------------------------------------------------------===// +// Concat op patterns. +//===----------------------------------------------------------------------===// + +def OneElementAttrPred + : CPred<"$_self.cast().getShapedType().getNumElements() == 1">; + +def OneElementAttr + : ElementsAttrBase, + "Scalar ElementsAttr">; + +def HasRankedFirstOperand + : Constraint()">>; + +def IsShapedTensor + : Constraint()">>; + +// This pattern converts TensorFlow axis format to HLO axis format which +// doesn't wrap around like TensorFlow and is always positive. For this +// conversion, use the first input to get inputs rank. Other inputs need not be +// ranked. +// Defining op for `axis` is TensorFlow constant op in the pattern as during +// the conversion, original Concat op operands still refers to the old ops even +// if HLO constant op is introduced as an replacement for the TensorFlow +// Constant op. +def : Pat<(TF_ConcatV2Op $inputs, (ConstantLikeMatcher OneElementAttr:$axis)), + (MHLO_ConcatenateOp $inputs, + (GetHLOAxisFromTFAxisVariadic $axis, $inputs)), + [(HasRankedFirstOperand $inputs)]>; + +//===----------------------------------------------------------------------===// +// CollectivePermute op patterns. +//===----------------------------------------------------------------------===// + +def : Pat<(TF_CollectivePermuteOp $input, (ConstantLikeMatcher ElementsAttr:$source_target_pairs)), + (MHLO_CollectivePermuteOp $input, + (CastElementsToI64Elements $source_target_pairs), + (NullChannelHandleAttr))>; + +//===----------------------------------------------------------------------===// +// CrossReplicaSum op patterns. +//===----------------------------------------------------------------------===// + +def : Pat<(TF_CrossReplicaSumOp $input, (ConstantLikeMatcher ElementsAttr:$group_assignment)), + (MHLO_CrossReplicaSumOp $input, + (CastElementsToI64Elements $group_assignment))>; + +//===----------------------------------------------------------------------===// +// All2All op patterns. +//===----------------------------------------------------------------------===// + +def ValueToVariadic: NativeCodeCall<"SmallVector{$0}">; +def : Pat<(TF_AllToAllOp AnyRankedTensor:$input, (ConstantLikeMatcher ElementsAttr:$group_assignment), I64Attr:$concat_dimension, $split_dimension, $split_count), + (MHLO_AllToAllOp (ValueToVariadic $input), $split_dimension, $concat_dimension, $split_count, (CastElementsToI64Elements $group_assignment), (NullChannelHandleAttr))>; + +//===----------------------------------------------------------------------===// +// FFT op patterns. +//===----------------------------------------------------------------------===// + +class MHLO_FftTypeValue : + ConstantAttr; + +def GetInnerDimFromValue : NativeCodeCall< + "GetInnerDimFromValue($0.getType().cast(), &$_builder)">; + +def CheckInnerDimStatic + : Constraint(), &$_builder)">>; + +def : Pat<(TF_FFTOp:$res $input), + (MHLO_FftOp $input, MHLO_FftTypeValue<"FFT">, (GetInnerDimFromValue $res)), + [(CheckInnerDimStatic $input)]>; + +def : Pat<(TF_IFFTOp:$res $input), + (MHLO_FftOp $input, MHLO_FftTypeValue<"IFFT">, (GetInnerDimFromValue $res)), + [(CheckInnerDimStatic $input)]>; + +//===----------------------------------------------------------------------===// +// GatherV2 op patterns. +//===----------------------------------------------------------------------===// + +// Here, $params and $indices needs to be ranked so that $axis and $batch_dims +// attributes can be converted from TensorFlow axis format supporting negative +// indexing to the HLO format. +def LegalizeGatherV2 : + Pat<(TF_GatherV2Op AnyRankedTensor:$params, AnyRankedTensor:$indices, + (ConstantLikeMatcher ElementsAttr:$axis), $batch_dims), + (MHLO_TorchIndexSelectOp $params, $indices, + (GetHLOAxisFromTFAxis $axis, $params), + (GetHLOAxisFromTFAxis $batch_dims, $indices))>; + +//===----------------------------------------------------------------------===// +// Pad op patterns. +//===----------------------------------------------------------------------===// + +class SliceDenseIntElementsAttrColumn2D : NativeCodeCall< + "SliceDenseIntElementsAttrColumn2D($0.cast(), " # column # " )">; + +class SliceDenseIntElementsAttr : NativeCodeCall< + "SliceDenseIntElementsAttr($0.cast(), " # index # ", " # axis # ")">; + +// Interior padding attribute based on the TF padding. +def GetInteriorPadding : NativeCodeCall < + "GetInteriorPadding($0.cast())">; + +def : Pat<(TF_PadV2Op $input, (ConstantLikeMatcher ElementsAttr:$padding), $c), + (MHLO_PadOp $input, $c, + (SliceDenseIntElementsAttrColumn2D<"0"> $padding), + (SliceDenseIntElementsAttrColumn2D<"1"> $padding), + (GetInteriorPadding $padding))>; + +//===----------------------------------------------------------------------===// +// Identity op patterns. +//===----------------------------------------------------------------------===// + +foreach src = [TF_IdentityOp, TF_StopGradientOp, TF__EagerConstOp] in + def : Pat<(src $op), (replaceWithValue $op)>; + +// TODO(b/32223192): Support CheckNumerics in HLO. +foreach src = [TF_PreventGradientOp, TF_CheckNumericsOp] in + def : Pat<(src $op, $msg), (replaceWithValue $op)>; + +//===----------------------------------------------------------------------===// +// MatMul op patterns. +//===----------------------------------------------------------------------===// + +def GetPrecisionConfig: NativeCodeCall< + "GetPrecisionConfig(&$_builder)">; + +def : Pat<(TF_MatMulOp $a, $b, $transpose_a, $transpose_b, $grad_a, $grad_b), + (MHLO_DotOp + (TF_TransposeOp $a, (TF_ConstOp (Get2DTransposePerm $transpose_a))), + (TF_TransposeOp $b, (TF_ConstOp (Get2DTransposePerm $transpose_b))), + /*precision_config=*/(GetPrecisionConfig))>; + +//===----------------------------------------------------------------------===// +// Lower `tf.ZerosLike` +//===----------------------------------------------------------------------===// + +def : Pat<(TF_ZerosLikeOp AnyTensor:$arg), + (MHLO_ConstantLike<"0"> $arg)>; + +//===----------------------------------------------------------------------===// +// Lower `tf.OnesLike` +//===----------------------------------------------------------------------===// + +def : Pat<(TF_OnesLikeOp AnyTensor:$arg), + (MHLO_ConstantLike<"1"> $arg)>; + +//===----------------------------------------------------------------------===// +// Elu op patterns. +//===----------------------------------------------------------------------===// + +def : Pat<(TF_EluOp AnyTensor:$features), + (MHLO_SelectOp + (MHLO_CompareOp + $features, + (MHLO_ConstantLike<"0">:$zero $features), + MHLO_ComparisonDirectionValue<"GT">, (MHLO_DEFAULT_COMPARISON_TYPE)), + $features, + (MHLO_Expm1Op $features))>; + +def : Pat<(TF_EluGradOp AnyStaticShapeTensor:$gradients, AnyRankedTensor:$features), + (MHLO_SelectOp + (CHLO_BroadcastCompareOp + $features, + (MHLO_ConstantOp:$zero (GetScalarOfType<0> $features)), + (BinBroadcastDimensions $zero, $features), + CHLO_ComparisonDirectionValue<"GT">, (CHLO_DEFAULT_COMPARISON_TYPE)), + $gradients, + (MHLO_MulOp + $gradients, + (CHLO_BroadcastAddOp + $features, + (MHLO_ConstantOp:$one (GetScalarOfType<1> $features)), + (BinBroadcastDimensions $one, $features))))>; + +//===----------------------------------------------------------------------===// +// Relu op patterns. +//===----------------------------------------------------------------------===// + +// TODO(hinsu): Make these patterns to TF to TF lowering. Relu6 lowering will +// require HLO canonicalization of min and max on a tensor to ClampOp. + +// TODO(hinsu): Lower quantized types after supporting them in GetScalarOfType. +def : Pat<(TF_ReluOp AnyTensor:$input), + (CHLO_BroadcastMaxOp + (MHLO_ConstantOp:$zero (GetScalarOfType<0> $input)), $input, + (BinBroadcastDimensions $zero, $input)), + [(TF_IntOrFpTensor $input)]>; + +// TODO(hinsu): Lower quantized types after supporting them in GetScalarOfType. +def : Pat<(TF_Relu6Op AnyRankedTensor:$input), + (MHLO_ClampOp (MHLO_ConstantOp (GetScalarOfType<0> $input)), $input, + (MHLO_ConstantOp (GetScalarOfType<6> $input))), + [(TF_IntOrFpTensor $input)]>; + +// ReluGrad(gradients, features) = gradients * (features > 0) +// The condition that $gradients and $features need to have the same shape is +// implicitly enforced: $zero is created to have the same shape as $features, +// MHLO_SelectOp enforces that $gradients and $zero have the same shape. +def : Pat<(TF_ReluGradOp AnyTensor:$gradients, AnyTensor:$features), + (MHLO_SelectOp + (MHLO_CompareOp $features, (MHLO_ConstantLike<"0">:$zero $features), + MHLO_ComparisonDirectionValue<"GT">, (MHLO_DEFAULT_COMPARISON_TYPE)), + $gradients, $zero)>; + +//===----------------------------------------------------------------------===// +// Softsign op patterns. +//===----------------------------------------------------------------------===// + +/// Converts a TF::SoftsignOp to HLO. +/// Softsign(features) = features / (1 + abs(features)) +def : Pat<(TF_SoftsignOp AnyTensor:$input), + (MHLO_DivOp + $input, + (MHLO_AddOp (MHLO_ConstantLike<"1"> $input), (MHLO_AbsOp $input)) + ) + >; + +/// Converts a TF::SoftsignGradOp to HLO. +/// SoftsignGrad(gradient, features) = gradient / ((1 + abs(features)) ^ 2) +def : Pattern< + (TF_SoftsignGradOp AnyRankedTensor:$gradients, AnyRankedTensor:$features), + [(CHLO_BroadcastAddOp:$add + (MHLO_ConstantOp:$one (GetScalarOfType<1> $features)), (MHLO_AbsOp $features), + (BinBroadcastDimensions $one, $features) + ), + (CHLO_BroadcastDivOp + $gradients, + (MHLO_MulOp $add, $add), + (BinBroadcastDimensions $gradients, $add) + ) + ]>; + +//===----------------------------------------------------------------------===// +// Slice op patterns. +//===----------------------------------------------------------------------===// + +def UnpackStartingIndices: NativeCodeCall< + "UnpackTensorAlongZeroDim($0.getLoc(), $1, &$_builder).getOutput()">; + +def CanBeTranslatedToDynamicSlice : Constraint())">>; + +def TFSliceSizes2HLOSliceSizes : NativeCodeCall< + "TFSliceSizes2HLOSliceSizes($0, $1, $2.cast()," + "&$_builder)">; + +def : Pat<(TF_SliceOp:$op MHLO_Tensor:$input, MHLO_Tensor:$starting_indices, + (ConstantLikeMatcher AnyAttr:$slice_sizes)), + (MHLO_DynamicSliceOp $input, + (UnpackStartingIndices $op, $starting_indices), + (TFSliceSizes2HLOSliceSizes $input, $starting_indices, $slice_sizes)), + [(CanBeTranslatedToDynamicSlice $input, $starting_indices, + $slice_sizes)]>; + +//===----------------------------------------------------------------------===// +// Select op patterns. +//===----------------------------------------------------------------------===// + + def : Pat<(TF_SelectV2Op MHLO_Tensor:$pred, MHLO_Tensor:$on_true, + MHLO_Tensor:$on_false), + (CHLO_BroadcastSelectOp $pred, $on_true, $on_false)>; + +//===----------------------------------------------------------------------===// +// PartitionedCall and LegacyCall op patterns. +//===----------------------------------------------------------------------===// + +def ArgTypesMatchCallee : Constraint< + // $0 is a resultset (possibly empty), and $_op isn't assigned. So retrieve + // the op using the builder. + CPred<"ArgTypesMatchCallee(&*$_builder.getInsertionPoint(), $1, $2)">>; + +foreach callOp = [TF_PartitionedCallOp, TF_StatefulPartitionedCallOp] in { + def : Pat<(callOp:$op $args, FlatSymbolRefAttr:$f, + $config, $config_proto, $executor_type), + (CallOp $f, $args), + [(ArgTypesMatchCallee $op, $args, $f)]>; +} + +// The extra attr on this op is _disable_call_shape_inference, which we ignore +// in the bridge. +def : Pat<(TF_LegacyCallOp:$op $args, FlatSymbolRefAttr:$f, $attr), + (CallOp $f, $args), + [(ArgTypesMatchCallee $op, $args, $f)]>; + +//===----------------------------------------------------------------------===// +// Reverse op patterns. +//===----------------------------------------------------------------------===// + +// Handles axis conversion for TF reverse. +def ConvertAxisAttr : NativeCodeCall<"ConvertAxisAttr($0, $1.cast(), &$_builder)">; + +def : Pat<(TF_ReverseV2Op AnyRankedTensor:$values, (ConstantLikeMatcher ElementsAttr:$axis)), + (MHLO_ReverseOp $values, (ConvertAxisAttr $values, $axis))>; + +//===----------------------------------------------------------------------===// +// Unary op patterns. +//===----------------------------------------------------------------------===// + +foreach Mapping = [ + [TF_AbsOp, MHLO_AbsOp], + [TF_CeilOp, MHLO_CeilOp], + [TF_ComplexAbsOp, MHLO_AbsOp], + [TF_CosOp, MHLO_CosineOp], + [TF_ExpOp, MHLO_ExpOp], + [TF_Expm1Op, MHLO_Expm1Op], + [TF_ErfOp, MHLO_ErfOp], + [TF_FloorOp, MHLO_FloorOp], + [TF_ImagOp, MHLO_ImagOp], + [TF_InvertOp, MHLO_NotOp], + [TF_IsFiniteOp, MHLO_IsFiniteOp], + [TF_LogOp, MHLO_LogOp], + [TF_Log1pOp, MHLO_Log1pOp], + [TF_LogicalNotOp, MHLO_NotOp], + [TF_NegOp, MHLO_NegOp], + [TF_RealOp, MHLO_RealOp], + [TF_RsqrtOp, MHLO_RsqrtOp], + [TF_SigmoidOp, MHLO_LogisticOp], + [TF_SinOp, MHLO_SineOp], + [TF_SqrtOp, MHLO_SqrtOp], + [TF_TanhOp, MHLO_TanhOp], + [TF_TanOp, MHLO_TanOp] + ] in { + def : Pat<(Mapping[0] MHLO_Tensor:$input), + (Mapping[1] $input)>; +} + +foreach Mapping = [ + [TF_AcosOp, CHLO_AcosOp], + [TF_AcoshOp, CHLO_AcoshOp], + [TF_AsinOp, CHLO_AsinOp], + [TF_AsinhOp, CHLO_AsinhOp], + [TF_AtanOp, CHLO_AtanOp], + [TF_AtanhOp, CHLO_AtanhOp], + [TF_CoshOp, CHLO_CoshOp], + [TF_ConjOp, CHLO_ConjOp], + [TF_DigammaOp, CHLO_DigammaOp], + [TF_ErfcOp, CHLO_ErfcOp], + [TF_IsInfOp, CHLO_IsInfOp], + [TF_LgammaOp, CHLO_LgammaOp], + [TF_SinhOp, CHLO_SinhOp], + ] in { + def : Pat<(Mapping[0] MHLO_AnyTensor:$input), + (Mapping[1] $input)>; +} + +def : Pat<(TF_AngleOp $x), (MHLO_Atan2Op (MHLO_ImagOp $x), (MHLO_RealOp $x))>; + +// TODO(bixia): Lower with Truncate=True for floating point value conversions. +def : Pat<(TF_CastOp $arg, ConstBoolAttrFalse), (MHLO_ConvertOp $arg)>; + +def : Pat<(TF_TransposeOp:$res $arg, (ConstantLikeMatcher ElementsAttr:$permutation)), + (MHLO_TransposeOp $arg, (CastElementsToI64Elements $permutation))>; + + +// Lowering these ops with static shape to mhlo.reshape +foreach TfOp = [TF_ExpandDimsOp, TF_ReshapeOp, TF_SqueezeOp, ] in { + def : Pat<(TfOp:$res MHLO_Tensor:$arg, $ignored), + (MHLO_ReshapeOp $arg), [(AnyStaticShapeTensor $res)], [], + (addBenefit 2)>; +} + +// Returns NaN if x is NaN, 0 if x is 0, -1 if x < 0 and 1 if x > 0. +def : Pat<(TF_SignOp $x), (MHLO_SignOp $x)>; + +def BothElementTypesSameWidthIntOrFloat : Constraint, + "element types must be integers or floats">; + +// TODO(mgester): Due to restrictions of xla::BitcastConvertType we currently +// only lower if both input and output types are int or float and have same width + +def : Pat<(TF_BitcastOp:$res MHLO_Tensor:$arg), + (MHLO_BitcastConvertOp $arg), + [(BothElementTypesSameWidthIntOrFloat $res, $arg)]>; + +// TODO(jpienaar): Lower constant like to constant to broadcast if dynamic +// and going to MHLO. + +//===----------------------------------------------------------------------===// +// Random ops. +//===----------------------------------------------------------------------===// +// TODO(b/148269299): handle random number generator seeds/states correctly. + +class MHLO_RngDistributionValue : + ConstantAttr; + +def : Pat<(TF_RandomUniformOp:$old $shape, $seed, $seed2), + (MHLO_RngOp + (MHLO_ConstantOp + (NativeCodeCall<"$_builder.getFloatAttr(old.getDtype(), 0.0)">)), + (MHLO_ConstantOp + (NativeCodeCall<"$_builder.getFloatAttr(old.getDtype(), 1.0)">)), + (CastValueToI64 $old, $shape), + MHLO_RngDistributionValue<"UNIFORM">), + [(IsShapedTensor $shape)]>; + +def : Pat<(TF_RandomStandardNormalOp:$old $shape, $seed, $seed2), + (MHLO_RngOp + (MHLO_ConstantOp + (NativeCodeCall<"$_builder.getFloatAttr(old.getDtype(), 0.0)">)), + (MHLO_ConstantOp + (NativeCodeCall<"$_builder.getFloatAttr(old.getDtype(), 1.0)">)), + (CastValueToI64 $old, $shape), + MHLO_RngDistributionValue<"NORMAL">), + [(IsShapedTensor $shape)]>; + +//===----------------------------------------------------------------------===// +// Sigmoid grad op. +//===----------------------------------------------------------------------===// + +// TODO(hinsu): Handle unranked inputs by broadcasting constant one to the +// shape of $l instead of having it as a constant. +def : Pat<(TF_SigmoidGradOp AnyRankedTensor:$l, AnyRankedTensor:$r), + (MHLO_MulOp + (MHLO_MulOp $r, $l), + (MHLO_SubtractOp (MHLO_ConstantOp (ConstantSplat<"1"> $l)), $l))>; + +//===----------------------------------------------------------------------===// +// Softplus op. +//===----------------------------------------------------------------------===// + +def EpsilonValue : NativeCodeCall<"GetEpsilonValue($0.getType())">; + +def : Pattern<(TF_SoftplusOp AnyTensor:$features), + [ + (MHLO_ExpOp:$features_exp $features), + (CHLO_BroadcastAddOp:$threshold + (MHLO_LogOp (MHLO_ConstantOp (EpsilonValue $features))), + (MHLO_ConstantOp (GetScalarOfType<2> $features)), + (NullDenseI64ArrayAttr) + ), + (MHLO_SelectOp:$output + (CHLO_BroadcastCompareOp + $features, + (MHLO_NegOp $threshold), + (NullDenseI64ArrayAttr), + CHLO_ComparisonDirectionValue<"GT">, + (CHLO_DEFAULT_COMPARISON_TYPE) + ), + $features, + (MHLO_SelectOp + (CHLO_BroadcastCompareOp + $features, + $threshold, + (NullDenseI64ArrayAttr), + CHLO_ComparisonDirectionValue<"LT">, + (CHLO_DEFAULT_COMPARISON_TYPE) + ), + $features_exp, + (MHLO_Log1pOp $features_exp) + ) + ), + (replaceWithValue $output) + ]>; + +//===----------------------------------------------------------------------===// +// XlaReplicaId op. +//===----------------------------------------------------------------------===// + +def : Pat<(TF_XlaReplicaIdOp), + (TF_CastOp (MHLO_ReplicaIdOp), /*truncate=*/ConstBoolAttrFalse)>; + +//===----------------------------------------------------------------------===// +// XlaGather op. +//===----------------------------------------------------------------------===// + +def ToGatherDimNumsAttr : NativeCodeCall<"GetGatherDimNumsAttr($0, &$_builder)">; + +def HasValidGatherDims : Constraint>; + +def : Pat<(TF_XlaGatherOp $operand, $start_indices, (ConstantLikeMatcher ElementsAttr:$slice_sizes), + $dimension_numbers, $indices_are_sorted), + (MHLO_GatherOp $operand, $start_indices, + (ToGatherDimNumsAttr $dimension_numbers), + (CastElementsToI64Elements $slice_sizes), + $indices_are_sorted), + [(HasValidGatherDims $dimension_numbers)]>; + +//===----------------------------------------------------------------------===// +// XlaDotOp op. +//===----------------------------------------------------------------------===// + +def ToDotDimNumsAttr : NativeCodeCall<"GetDotDimNumsAttr($0, &$_builder)">; + +def ToPrecisionConfigsAttr : NativeCodeCall<"GetPrecisionConfigAttr($0, &$_builder)">; + +def HasValidDotDims : Constraint>; + +def HasValidPrecisionConfig : Constraint>; + +def : Pat<(TF_XlaDotOp $lhs, $rhs, $dimension_numbers, $precision_config), + (MHLO_DotGeneralOp $lhs, $rhs, + (ToDotDimNumsAttr $dimension_numbers), + (ToPrecisionConfigsAttr $precision_config), + (EmptyDotAlgorithmAttr)), + [(HasValidDotDims $dimension_numbers), (HasValidPrecisionConfig $precision_config)]>; + +//===----------------------------------------------------------------------===// +// XlaDotV2Op op. +//===----------------------------------------------------------------------===// + +def : Pat<(TF_XlaDotV2Op $lhs, $rhs, $dimension_numbers, $precision_config), + (MHLO_DotGeneralOp $lhs, $rhs, + (ToDotDimNumsAttr $dimension_numbers), + (ToPrecisionConfigsAttr $precision_config), + (EmptyDotAlgorithmAttr)), + [(HasValidDotDims $dimension_numbers), (HasValidPrecisionConfig $precision_config)]>; + +//===----------------------------------------------------------------------===// +// XlaDynamicSlice op. +//===----------------------------------------------------------------------===// + +def : Pat<(TF_XlaDynamicSliceOp:$op MHLO_Tensor:$input, MHLO_Tensor:$starting_indices, + (ConstantLikeMatcher AnyAttr:$slice_sizes)), + (MHLO_DynamicSliceOp $input, + (UnpackStartingIndices $op, $starting_indices), + (TFSliceSizes2HLOSliceSizes $input, $starting_indices, $slice_sizes))>; + +//===----------------------------------------------------------------------===// +// XlaEisumOp op. +//===----------------------------------------------------------------------===// + +def : Pat<(TF_XlaEinsumOp $lhs, $rhs, $equation), + (MHLO_EinsumOp $lhs, $rhs, $equation)>; + +//===----------------------------------------------------------------------===// +// XlaOptimizationBarrierOp op. +//===----------------------------------------------------------------------===// + +def : Pat<(TF_XlaOptimizationBarrierOp $args), + (MHLO_OptimizationBarrierOp $args)>; diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/tf_stablehlo_pass.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/tf_stablehlo_pass.cc index ae4ee26eab9b8c..e38cad1d4c7edc 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/tf_stablehlo_pass.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/tf_stablehlo_pass.cc @@ -33,11 +33,11 @@ limitations under the License. #include "mlir/Transforms/Passes.h" // from @llvm-project #include "stablehlo/dialect/ChloOps.h" // from @stablehlo #include "stablehlo/dialect/Register.h" // from @stablehlo +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf_passes.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_util.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.h" #include "tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_with_tf2xla_passes.h" -#include "tensorflow/compiler/mlir/tf2xla/transforms/passes.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/mlir_hlo/mhlo/IR/register.h" #include "xla/mlir_hlo/mhlo/transforms/passes.h" @@ -102,7 +102,7 @@ void TFToMhloPass::runOnOperation() { MLIRContext *context = func->getContext(); RewritePatternSet patterns(context); - mhlo::PopulateLegalizeTfPatterns(context, &patterns); + odml::PopulateLegalizeTfPatterns(context, &patterns); TF::PopulateTFLoweringBeforeHLOPatterns(context, &patterns); mhlo::Tf2XlaTypeConverter converter; mhlo::PopulateLegalizeTfWithTf2XlaPatterns( diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/utils.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/utils.cc new file mode 100644 index 00000000000000..b120a6f02e1460 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/utils.cc @@ -0,0 +1,55 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/utils.h" + +#include + +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" +#include "xla/mlir_hlo/utils/hlo_utils.h" + +namespace mlir { +namespace odml { + +mhlo::ConstantOp GetScalarConstOfType(Type ty, Location loc, int64_t raw_value, + OpBuilder* builder) { + return builder->create(loc, + hlo::getScalarOfType(ty, raw_value)); +} + +mhlo::ConstantOp GetScalarNegZeroOfType(Type ty, Location loc, + OpBuilder* builder) { + return builder->create(loc, + hlo::getScalarNegZeroOfType(ty)); +} + +DenseIntElementsAttr GetI64ElementsAttr(ArrayAttr attr) { + RankedTensorType ty = + RankedTensorType::get(static_cast(attr.size()), + IntegerType::get(attr.getContext(), 64)); + return DenseIntElementsAttr::get(ty, attr.getValue()); +} + +DenseIntElementsAttr GetI64ElementsAttr(ArrayRef values, + Builder* builder) { + RankedTensorType ty = RankedTensorType::get( + {static_cast(values.size())}, builder->getIntegerType(64)); + return DenseIntElementsAttr::get(ty, values); +} + +} // namespace odml +} // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/utils.h b/tensorflow/compiler/mlir/lite/stablehlo/transforms/utils.h new file mode 100644 index 00000000000000..13ff4c4767721d --- /dev/null +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/utils.h @@ -0,0 +1,61 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_UTILS_H_ + +#include "llvm/ADT/ArrayRef.h" +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" + +namespace mlir { +namespace odml { + +// Builds body for reduce op by using the template binary op as the +// reducer op. +template +void BuildReduceBody(Type element_type, Region* body, OpBuilder* builder) { + OpBuilder::InsertionGuard guard(*builder); + Block* block = builder->createBlock(body); + + // Block arguments are scalars of the given element type. + Type type = RankedTensorType::get(/*shape=*/{}, element_type); + Location loc = body->getLoc(); + block->addArguments({type, type}, SmallVector(2, loc)); + + auto reducer = + builder->create(loc, block->getArgument(0), block->getArgument(1)); + builder->create(loc, reducer.getResult()); +} + +mhlo::ConstantOp GetScalarConstOfType(Type ty, Location loc, int64_t raw_value, + OpBuilder* builder); + +mhlo::ConstantOp GetScalarNegZeroOfType(Type ty, Location loc, + OpBuilder* builder); + +// Converts an ArrayAttr to a 1D 64-bit dense elements attribute. +DenseIntElementsAttr GetI64ElementsAttr(ArrayAttr attr); +DenseIntElementsAttr GetI64ElementsAttr(llvm::ArrayRef values, + Builder* builder); + +} // namespace odml +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_UTILS_H_ diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/utils_test.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/utils_test.cc new file mode 100644 index 00000000000000..63926df535b6be --- /dev/null +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/utils_test.cc @@ -0,0 +1,83 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/utils.h" + +#include + +#include +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" + +namespace mlir { +namespace odml { +namespace { + +TEST(UtilsTest, GetScalarConstOfType) { + MLIRContext context; + context.loadDialect(); + OpBuilder builder(&context); + Location loc = UnknownLoc::get(&context); + Type ty = builder.getI32Type(); + mhlo::ConstantOp op = GetScalarConstOfType(ty, loc, 123, &builder); + EXPECT_EQ(op.getValue().getValues()[0], 123); + + op->destroy(); +} + +TEST(UtilsTest, GetScalarNegZeroOfType) { + MLIRContext context; + context.loadDialect(); + OpBuilder builder(&context); + Location loc = UnknownLoc::get(&context); + Type ty = builder.getF32Type(); + mhlo::ConstantOp op = GetScalarNegZeroOfType(ty, loc, &builder); + EXPECT_EQ(op.getValue().getValues()[0], -0.f); + + op->destroy(); +} + +TEST(UtilsTest, GetI64ElementsAttr) { + MLIRContext context; + context.loadDialect(); + OpBuilder builder(&context); + Location loc = UnknownLoc::get(&context); + ArrayRef values = {1, 2, 3}; + auto valuesAttr = builder.getI64ArrayAttr(values); + DenseIntElementsAttr attr = GetI64ElementsAttr(valuesAttr); + EXPECT_EQ(attr.getValues()[0], 1); + EXPECT_EQ(attr.getValues()[1], 2); + EXPECT_EQ(attr.getValues()[2], 3); +} + +TEST(UtilsTest, GetI64ElementsAttrBuilder) { + MLIRContext context; + context.loadDialect(); + OpBuilder builder(&context); + Location loc = UnknownLoc::get(&context); + ArrayRef values = {1, 2, 3}; + DenseIntElementsAttr attr = GetI64ElementsAttr(values, &builder); + EXPECT_EQ(attr.getValues()[0], 1); + EXPECT_EQ(attr.getValues()[1], 2); + EXPECT_EQ(attr.getValues()[2], 3); +} + +} // namespace + +} // namespace odml +} // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc index e67e0e45961117..041aa03134924c 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc @@ -64,6 +64,7 @@ limitations under the License. #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf_passes.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h" #include "tensorflow/compiler/mlir/lite/transforms/dilated_conv.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" diff --git a/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf.mlir index 4231990e0769d1..1939a3dc8cd875 100644 --- a/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf.mlir +++ b/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf.mlir @@ -2654,3813 +2654,3 @@ func.func @sigmoid_grad_dynamic(%arg0: tensor, %arg1: tensor) -> t func.return %0 : tensor } -// ----- - -// CHECK-LABEL: @sin -func.func @sin(%arg0: tensor<2xf32>) -> tensor<2xf32> { - // CHECK: mhlo.sine %arg0 : tensor<2xf32> - %0 = "tf.Sin"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> - func.return %0 : tensor<2xf32> -} - -// ----- - -// CHECK-LABEL: func @sin_dynamic -func.func @sin_dynamic(%arg0: tensor) -> tensor { - // CHECK: mhlo.sine %arg0 : tensor - %0 = "tf.Sin"(%arg0) : (tensor) -> tensor - func.return %0 : tensor -} - -// ----- - -// CHECK-LABEL: func @rsqrt -func.func @rsqrt(%arg0: tensor<2xf32>) -> tensor<2xf32> { - // CHECK: mhlo.rsqrt %arg0 : tensor<2xf32> - %0 = "tf.Rsqrt"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> - func.return %0 : tensor<2xf32> -} - -// ----- - -// CHECK-LABEL: func @rsqrt_dynamic -func.func @rsqrt_dynamic(%arg0: tensor) -> tensor { - // CHECK: mhlo.rsqrt %arg0 : tensor - %0 = "tf.Rsqrt"(%arg0) : (tensor) -> tensor - func.return %0 : tensor -} - -// ----- - -// CHECK-LABEL: func @sqrt -func.func @sqrt(%arg0: tensor<2xf32>) -> tensor<2xf32> { - // CHECK: mhlo.sqrt %arg0 : tensor<2xf32> - %0 = "tf.Sqrt"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> - func.return %0 : tensor<2xf32> -} - -// ----- - -// CHECK-LABEL: func @sqrt_dynamic -func.func @sqrt_dynamic(%arg0: tensor) -> tensor { - // CHECK: mhlo.sqrt %arg0 : tensor - %0 = "tf.Sqrt"(%arg0) : (tensor) -> tensor - func.return %0 : tensor -} - -// ----- - -// CHECK-LABEL: func @tanh -func.func @tanh(%arg0: tensor<2xf32>) -> tensor<2xf32> { - // CHECK: mhlo.tanh %arg0 : tensor<2xf32> - %0 = "tf.Tanh"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> - func.return %0 : tensor<2xf32> -} - -// ----- - -// CHECK-LABEL: func @tanh_dynamic -func.func @tanh_dynamic(%arg0: tensor) -> tensor { - // CHECK: mhlo.tanh %arg0 : tensor - %0 = "tf.Tanh"(%arg0) : (tensor) -> tensor - func.return %0 : tensor -} - -// ----- - -// CHECK-LABEL: func @bitcast -func.func @bitcast(%arg0: tensor<2xf32>) -> tensor<2xf32> { - // CHECK: mhlo.bitcast_convert %arg0 : (tensor<2xf32>) -> tensor<2xf32> - %0 = "tf.Bitcast"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> - func.return %0 : tensor<2xf32> -} - -// ----- - -// CHECK-LABEL: func @bitcast_dynamic -func.func @bitcast_dynamic(%arg0: tensor) -> tensor { - // CHECK: mhlo.bitcast_convert %arg0 : (tensor) -> tensor - %0 = "tf.Bitcast"(%arg0) : (tensor) -> tensor - func.return %0 : tensor -} - -// ----- - -// CHECK-LABEL: func @bitcast_same_widths -func.func @bitcast_same_widths(%arg0: tensor<2xf32>) -> tensor<2xi32> { - // CHECK: mhlo.bitcast_convert %arg0 : (tensor<2xf32>) -> tensor<2xi32> - %0 = "tf.Bitcast"(%arg0) : (tensor<2xf32>) -> tensor<2xi32> - func.return %0 : tensor<2xi32> -} - -// ----- - -// CHECK-LABEL: func @bitcast_smaller_input_width -func.func @bitcast_smaller_input_width(%arg0: tensor<8xi8>) -> tensor { - // CHECK: mhlo.bitcast_convert %arg0 : (tensor<8xi8>) -> tensor - %0 = "tf.Bitcast"(%arg0) : (tensor<8xi8>) -> tensor - func.return %0 : tensor -} - -// ----- - -// CHECK-LABEL: func @bitcast_smaller_output_width -func.func @bitcast_smaller_output_width(%arg0: tensor<2xf32>) -> tensor<2x2xf16> { - // CHECK: mhlo.bitcast_convert %arg0 : (tensor<2xf32>) -> tensor<2x2xf16> - %0 = "tf.Bitcast"(%arg0) : (tensor<2xf32>) -> tensor<2x2xf16> - func.return %0 : tensor<2x2xf16> -} - -// ----- - -// CHECK-LABEL: squeeze -func.func @squeeze(%arg0: tensor<1x1x10xf32>) -> tensor<1x10xf32> { - // CHECK: mhlo.reshape - %0 = "tf.Squeeze"(%arg0) : (tensor<1x1x10xf32>) -> tensor<1x10xf32> - func.return %0 : tensor<1x10xf32> -} - -// ----- - -// CHECK-LABEL: squeeze_ranked -func.func @squeeze_ranked(%arg0: tensor) -> tensor { - // CHECK: %[[C2:.*]] = arith.constant 2 : index - // CHECK: %[[D2:.*]] = tensor.dim %arg0, %[[C2]] : tensor - // CHECK: %[[T:.*]] = tensor.from_elements %[[D2]] : tensor<1xindex> - // CHECK: %[[R:.*]] = mhlo.dynamic_reshape %arg0, %[[T]] : (tensor, tensor<1xindex>) -> tensor - // CHECK: return %[[R]] : tensor - %0 = "tf.Squeeze"(%arg0) { squeeze_dims = [0, 1] }: (tensor) -> tensor - func.return %0 : tensor -} - -// ----- - -// CHECK-LABEL: squeeze_ranked_negative -func.func @squeeze_ranked_negative(%arg0: tensor) -> tensor { - // CHECK: %[[C0:.*]] = arith.constant 0 : index - // CHECK: %[[D0:.*]] = tensor.dim %arg0, %[[C0]] : tensor - // CHECK: %[[C2:.*]] = arith.constant 2 : index - // CHECK: %[[D2:.*]] = tensor.dim %arg0, %[[C2]] : tensor - // CHECK: %[[T:.*]] = tensor.from_elements %[[D0]], %[[D2]] : tensor<2xindex> - // CHECK: %[[R:.*]] = mhlo.dynamic_reshape %arg0, %[[T]] : (tensor, tensor<2xindex>) -> tensor - // CHECK: return %[[R]] : tensor - %0 = "tf.Squeeze"(%arg0) { squeeze_dims = [-2] }: (tensor) -> tensor - func.return %0 : tensor -} - -// ----- - -// CHECK-LABEL: squeeze_ranked_dynamic -func.func @squeeze_ranked_dynamic(%arg0: tensor) -> tensor { - // CHECK: "tf.Squeeze" - %0 = "tf.Squeeze"(%arg0) : (tensor) -> tensor - func.return %0 : tensor -} - -// ----- - -// CHECK-LABEL: squeeze_dynamic -func.func @squeeze_dynamic(%arg0: tensor) -> tensor<*xf32> { - // CHECK: "tf.Squeeze" - %0 = "tf.Squeeze"(%arg0) : (tensor) -> tensor<*xf32> - func.return %0 : tensor<*xf32> -} - -// ----- - -// CHECK-LABEL: expand_dims -func.func @expand_dims(%arg0: tensor<2xf32>, %axis: tensor) -> tensor<1x2xf32> { - // CHECK: mhlo.reshape - %0 = "tf.ExpandDims"(%arg0, %axis) : (tensor<2xf32>, tensor) -> tensor<1x2xf32> - func.return %0 : tensor<1x2xf32> -} - -// ----- - -// CHECK-LABEL: expand_dims_dynamic -func.func @expand_dims_dynamic(%arg0: tensor) -> tensor { - %axis = "tf.Const"() {value = dense<1> : tensor} : () -> (tensor) - - // CHECK-DAG: %[[SHAPEOF:.+]] = shape.shape_of %arg0 - // CHECK-DAG: %[[CST0:.+]] = arith.constant 0 - // CHECK-DAG: %[[CST1:.+]] = arith.constant 1 - // CHECK-DAG: %[[GETEXTENT0:.+]] = tensor.extract %[[SHAPEOF]][%[[CST0]]] - // CHECK-DAG: %[[CST1_0:.+]] = arith.constant 1 - // CHECK-DAG: %[[GETEXTENT1:.+]] = tensor.extract %[[SHAPEOF]][%[[CST1_0]]] - // CHECK-DAG: %[[TOEXTENTS:.+]] = tensor.from_elements %[[GETEXTENT0]], %[[CST1]], %[[GETEXTENT1]] - // CHECK-DAG: %[[RESHAPE:.+]] = mhlo.dynamic_reshape %arg0, %[[TOEXTENTS]] - %0 = "tf.ExpandDims"(%arg0, %axis) : (tensor, tensor) -> tensor - - // CHECK: return %[[RESHAPE]] - func.return %0 : tensor -} - -// ----- - -// CHECK-LABEL: expand_dynamic_dims_rank1_axis -func.func @expand_dynamic_dims_rank1_axis(%arg0: tensor) -> tensor { - %axis = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> - - // CHECK-DAG: %[[SHAPEOF:.+]] = shape.shape_of %arg0 - // CHECK-DAG: %[[CST0:.+]] = arith.constant 0 - // CHECK-DAG: %[[CST1:.+]] = arith.constant 1 - // CHECK-DAG: %[[GETEXTENT0:.+]] = tensor.extract %[[SHAPEOF]][%[[CST0]]] - // CHECK-DAG: %[[CST1_0:.+]] = arith.constant 1 - // CHECK-DAG: %[[GETEXTENT1:.+]] = tensor.extract %[[SHAPEOF]][%[[CST1_0]]] - // CHECK-DAG: %[[CST2:.+]] = arith.constant 2 - // CHECK-DAG: %[[GETEXTENT2:.+]] = tensor.extract %[[SHAPEOF]][%[[CST2]]] - // CHECK-DAG: %[[TOEXTENTS:.+]] = tensor.from_elements %[[GETEXTENT0]], %[[CST1]], %[[GETEXTENT1]], %[[GETEXTENT2]] - // CHECK-DAG: %[[RESHAPE:.+]] = mhlo.dynamic_reshape %arg0, %[[TOEXTENTS]] - %0 = "tf.ExpandDims"(%arg0, %axis) : (tensor, tensor<1xi32>) -> tensor - - // CHECK: return %[[RESHAPE]] - func.return %0 : tensor -} - -// ----- - -// CHECK-LABEL: func @sign -// CHECK-SAME: [[ARG:%arg.*]]: tensor<1x2x3x4xf32> -func.func @sign(%arg0: tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> { - // CHECK: [[SIGN:%.*]] = mhlo.sign [[ARG]] - // CHECK: return [[SIGN]] : tensor<1x2x3x4xf32> - %0 = "tf.Sign"(%arg0) : (tensor<1x2x3x4xf32>) -> (tensor<1x2x3x4xf32>) - func.return %0 : tensor<1x2x3x4xf32> -} - -// ----- - -// CHECK-LABEL: func @sign_dynamic -func.func @sign_dynamic(%arg0: tensor) -> tensor { - // CHECK: [[SIGN:%.*]] = mhlo.sign %arg0 : tensor - // CHECK: return [[SIGN]] : tensor - %0 = "tf.Sign"(%arg0) : (tensor) -> (tensor) - func.return %0 : tensor -} - -// ----- - -// CHECK-LABEL: slice_constant_start -func.func @slice_constant_start(%arg0: tensor<4xi32>) -> tensor<2xi32> { - // CHECK: %[[START:.*]] = mhlo.constant dense<1> : tensor - // CHECK-DAG-SAME: {limit_indices = dense<1> : tensor<1xi64>, - // CHECK-DAG-SAME: start_indices = dense<0> : tensor<1xi64>, - // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>} : - // CHECK-DAG-SAME: (tensor<1xi64>) -> tensor<1xi64> - // CHECK-DAG-SAME: (tensor<1xi64>) -> tensor - // CHECK: %[[RESULT:.*]] = "mhlo.dynamic_slice"(%arg0, %[[START]]) - // CHECK-DAG-SAME: {slice_sizes = dense<2> : tensor<1xi64>} : - // CHECK-DAG-SAME: (tensor<4xi32>, tensor) -> tensor<2xi32> - // CHECK: return %[[RESULT]] : tensor<2xi32> - %starts = "tf.Const"() {value = dense<[1]> : tensor<1xi64>} : () -> (tensor<1xi64>) - %sizes = "tf.Const"() {value = dense<[2]> : tensor<1xi64>} : () -> (tensor<1xi64>) - %0 = "tf.Slice"(%arg0, %starts, %sizes) : (tensor<4xi32>, tensor<1xi64>, tensor<1xi64>) -> tensor<2xi32> - func.return %0 : tensor<2xi32> -} - -// ----- - -// CHECK-LABEL: slice_i32_consts -func.func @slice_i32_consts(%arg0: tensor<4xi32>) -> tensor<2xi32> { - // CHECK: %[[START:.*]] = mhlo.constant dense<1> : tensor - // CHECK: "mhlo.dynamic_slice"(%arg0, %[[START]]) <{slice_sizes = dense<2> : tensor<1xi64>}> : (tensor<4xi32>, tensor) -> tensor<2xi32> - %starts = "tf.Const"() {value = dense<[1]> : tensor<1xi32>} : () -> (tensor<1xi32>) - %sizes = "tf.Const"() {value = dense<[2]> : tensor<1xi32>} : () -> (tensor<1xi32>) - %0 = "tf.Slice"(%arg0, %starts, %sizes) : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> - func.return %0 : tensor<2xi32> -} - -// ----- - -// CHECK-LABEL: slice_constant_start_negative_one_size -func.func @slice_constant_start_negative_one_size(%arg0: tensor<4xi32>) -> tensor<3xi32> { - // CHECK: %[[START:.*]] = mhlo.constant dense<1> : tensor - // CHECK: %[[RESULT:.*]] = "mhlo.dynamic_slice"(%arg0, %[[START]]) <{slice_sizes = dense<3> : tensor<1xi64>}> : (tensor<4xi32>, tensor) -> tensor<3xi32> - // CHECK: return %[[RESULT]] : tensor<3xi32> - %starts = "tf.Const"() {value = dense<[1]> : tensor<1xi64>} : () -> (tensor<1xi64>) - %sizes = "tf.Const"() {value = dense<[-1]> : tensor<1xi64>} : () -> (tensor<1xi64>) - %0 = "tf.Slice"(%arg0, %starts, %sizes) : (tensor<4xi32>, tensor<1xi64>, tensor<1xi64>) -> tensor<3xi32> - func.return %0 : tensor<3xi32> -} - -// ----- - -// CHECK-LABEL: slice_constant_start_dynamic_shape -func.func @slice_constant_start_dynamic_shape(%arg0: tensor, %arg1: tensor<2xi64>) -> tensor<1x4xi32> { - // CHECK-DAG: %[[START1:.*]] = mhlo.constant dense<1> : tensor - // CHECK-DAG: %[[START2:.*]] = mhlo.constant dense<0> : tensor - // CHECK: %[[RESULT:.*]] = "mhlo.dynamic_slice" - // CHECK-DAG-SAME: (%arg0, %[[START1]], %[[START2]]) - // CHECK-DAG-SAME: {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : - // CHECK-DAG-SAME: (tensor, tensor, tensor) -> tensor<1x4xi32> - // CHECK: return %[[RESULT]] : tensor<1x4xi32> - %starts = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> (tensor<2xi64>) - %sizes = "tf.Const"() {value = dense<[1, 4]> : tensor<2xi64>} : () -> (tensor<2xi64>) - %0 = "tf.Slice"(%arg0, %starts, %sizes) : (tensor, tensor<2xi64>, tensor<2xi64>) -> tensor<1x4xi32> - func.return %0 : tensor<1x4xi32> -} - -// ----- - -// CHECK-LABEL: slice_variable_start -func.func @slice_variable_start(%arg0: tensor<3x4xi32>, %arg1: tensor<2xi64>) -> tensor<1x4xi32> { - // CHECK: %[[SLICED_START1:.*]] = "mhlo.slice"(%arg1) - // CHECK-DAG-SAME: {limit_indices = dense<1> : tensor<1xi64>, - // CHECK-DAG-SAME: start_indices = dense<0> : tensor<1xi64>, - // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>} : (tensor<2xi64>) -> tensor<1xi64> - // CHECK: %[[RESHAPED_START1:.*]] = mhlo.reshape %[[SLICED_START1]] : (tensor<1xi64>) -> tensor - // CHECK: %[[SLICED_START2:.*]] = "mhlo.slice"(%arg1) - // CHECK-DAG-SAME: {limit_indices = dense<2> : tensor<1xi64>, - // CHECK-DAG-SAME: start_indices = dense<1> : tensor<1xi64>, - // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>} : (tensor<2xi64>) -> tensor<1xi64> - // CHECK: %[[RESHAPED_START2:.*]] = mhlo.reshape %[[SLICED_START2]] : (tensor<1xi64>) -> tensor - // CHECK: %[[RESULT:.*]] = "mhlo.dynamic_slice"(%arg0, %[[RESHAPED_START1]], %[[RESHAPED_START2]]) <{slice_sizes = dense<[1, 4]> : tensor<2xi64>}> : (tensor<3x4xi32>, tensor, tensor) -> tensor<1x4xi32> - // CHECK: return %[[RESULT]] : tensor<1x4xi32> - %sizes = "tf.Const"() {value = dense<[1, 4]> : tensor<2xi64>} : () -> (tensor<2xi64>) - %0 = "tf.Slice"(%arg0, %arg1, %sizes) : (tensor<3x4xi32>, tensor<2xi64>, tensor<2xi64>) -> tensor<1x4xi32> - func.return %0 : tensor<1x4xi32> -} - -// ----- - -// CHECK-LABEL: slice_mhlo_sizes -func.func @slice_mhlo_sizes(%arg0: tensor<1x1024x4xf32>, %arg1: tensor<3xi32>) -> tensor<1x512x4xf32> { - // CHECK-NOT: "tf.Slice" - %0 = "mhlo.constant"() {value = dense<[1, 512, 4]> : tensor<3xi32>} : () -> tensor<3xi32> - %1 = "tf.Slice"(%arg0, %arg1, %0) : (tensor<1x1024x4xf32>, tensor<3xi32>, tensor<3xi32>) -> tensor<1x512x4xf32> - func.return %1 : tensor<1x512x4xf32> -} - -// ----- - -// CHECK-LABEL: slice_variable_start_negative_one_size -func.func @slice_variable_start_negative_one_size(%arg0: tensor<3x4xi32>, %arg1: tensor<2xi64>) -> tensor<1x4xi32> { - // CHECK: %[[RESULT:.*]] = "tf.Slice" - // CHECK: return %[[RESULT]] : tensor<1x4xi32> - %sizes = "tf.Const"() {value = dense<[1, -1]> : tensor<2xi64>} : () -> (tensor<2xi64>) - %0 = "tf.Slice"(%arg0, %arg1, %sizes) : (tensor<3x4xi32>, tensor<2xi64>, tensor<2xi64>) -> tensor<1x4xi32> - func.return %0 : tensor<1x4xi32> -} - -// ----- - -// CHECK-LABEL: slice_real_dynamic_slice -func.func @slice_real_dynamic_slice(%arg0: tensor<4xi32>, %arg1: tensor<1xi64>, %arg2: tensor<1xi64>) -> tensor { - // CHECK: tensor.extract {{.*}} : tensor<1xi64> - // CHECK: tensor.extract {{.*}} : tensor<1xi64> - // CHECK: arith.index_cast {{.*}} : index to i64 - // CHECK: arith.cmpi eq, {{.*}} : i64 - // CHECK: arith.addi {{.*}} : i64 - // CHECK: tensor.dim {{.*}} : tensor<4xi32> - // CHECK: arith.index_cast {{.*}} : index to i64 - // CHECK: select {{.*}} : i64 - // CHECK: arith.index_cast {{.*}} : i64 to index - // CHECK: arith.index_cast {{.*}} : i64 to index - // CHECK: tensor.from_elements {{.*}} : tensor<1xindex> - // CHECK: tensor.from_elements {{.*}} : tensor<1xindex> - // CHECK: tensor.from_elements {{.*}} : tensor<1xindex> - %0 = "tf.Slice"(%arg0, %arg1, %arg2) : (tensor<4xi32>, tensor<1xi64>, tensor<1xi64>) -> tensor - func.return %0 : tensor -} - -//===----------------------------------------------------------------------===// -// StridedSlice op legalizations. -//===----------------------------------------------------------------------===// - -// ----- - -// CHECK-LABEL: simple_strided_slice -func.func @simple_strided_slice(%input: tensor<4x8xf32>) -> tensor<3x2xf32> { - %begin = "tf.Const"() {value = dense<[0, 1]> : tensor<2xi32>} : () -> (tensor<2xi32>) - %end = "tf.Const"() {value = dense<[3, 7]> : tensor<2xi32>} : () -> (tensor<2xi32>) - %strides = "tf.Const"() {value = dense<[1, 3]> : tensor<2xi32>} : () -> (tensor<2xi32>) - - // CHECK: mhlo.slice - // CHECK-DAG-SAME: start_indices = dense<[0, 1]> - // CHECK-DAG-SAME: limit_indices = dense<[3, 7]> - // CHECK-DAG-SAME: strides = dense<[1, 3]> - // CHECK-SAME: -> tensor<3x2xf32> - - %output = "tf.StridedSlice"(%input, %begin, %end, %strides) - : (tensor<4x8xf32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<3x2xf32> - func.return %output : tensor<3x2xf32> -} - -// ----- - -// CHECK-LABEL: dynamic_strided_slice -func.func @dynamic_strided_slice(%input: tensor) -> tensor { - %begin = "tf.Const"() {value = dense<[0, 1]> : tensor<2xi32>} : () -> (tensor<2xi32>) - %end = "tf.Const"() {value = dense<[3, 7]> : tensor<2xi32>} : () -> (tensor<2xi32>) - %strides = "tf.Const"() {value = dense<[1, 3]> : tensor<2xi32>} : () -> (tensor<2xi32>) - - // CHECK: "tf.StridedSlice" - %output = "tf.StridedSlice"(%input, %begin, %end, %strides) - : (tensor, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor - func.return %output : tensor -} - -// ----- - -// CHECK-LABEL: strided_slice_negative_indices -func.func @strided_slice_negative_indices(%input: tensor<4x8xf32>) -> tensor<3x2xf32> { - %begin = "tf.Const"() {value = dense<[-1, -2]> : tensor<2xi32>} : () -> (tensor<2xi32>) - %end = "tf.Const"() {value = dense<[-4, -8]> : tensor<2xi32>} : () -> (tensor<2xi32>) - %strides = "tf.Const"() {value = dense<[-1, -3]> : tensor<2xi32>} : () -> (tensor<2xi32>) - - // CHECK: "mhlo.reverse"(%arg0) <{dimensions = dense<[0, 1]> : tensor<2xi64>}> - - // CHECK: mhlo.slice - // CHECK-DAG-SAME: start_indices = dense<[0, 1]> - // CHECK-DAG-SAME: limit_indices = dense<[3, 7]> - // CHECK-DAG-SAME: strides = dense<[1, 3]> - // CHECK-SAME: -> tensor<3x2xf32> - - %output = "tf.StridedSlice"(%input, %begin, %end, %strides) - : (tensor<4x8xf32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<3x2xf32> - func.return %output : tensor<3x2xf32> -} - -// ----- - -// CHECK-LABEL: dynamic_strided_slice_negative_indices -func.func @dynamic_strided_slice_negative_indices(%input: tensor) -> tensor { - %begin = "tf.Const"() {value = dense<[-1, -2]> : tensor<2xi32>} : () -> (tensor<2xi32>) - %end = "tf.Const"() {value = dense<[-4, -8]> : tensor<2xi32>} : () -> (tensor<2xi32>) - %strides = "tf.Const"() {value = dense<[-1, -3]> : tensor<2xi32>} : () -> (tensor<2xi32>) - - // CHECK: tf.StridedSlice - %output = "tf.StridedSlice"(%input, %begin, %end, %strides) - : (tensor, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor - func.return %output : tensor -} - -// ----- - -// CHECK-LABEL: strided_slice_range_clamping -func.func @strided_slice_range_clamping(%input: tensor<4x8xf32>) -> tensor<1x3xf32> { - %begin = "tf.Const"() {value = dense<[-4, -10]> : tensor<2xi32>} : () -> (tensor<2xi32>) - %end = "tf.Const"() {value = dense<[1, 10]> : tensor<2xi32>} : () -> (tensor<2xi32>) - %strides = "tf.Const"() {value = dense<[1, 3]> : tensor<2xi32>} : () -> (tensor<2xi32>) - - // CHECK: mhlo.slice - // CHECK-DAG-SAME: start_indices = dense<[0, 0]> - // CHECK-DAG-SAME: limit_indices = dense<[1, 8]> - // CHECK-DAG-SAME: strides = dense<[1, 3]> - // CHECK-SAME: -> tensor<1x3xf32> - %output = "tf.StridedSlice"(%input, %begin, %end, %strides) - : (tensor<4x8xf32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<1x3xf32> - func.return %output : tensor<1x3xf32> -} - -// ----- - -// CHECK-LABEL: strided_slice_empty -func.func @strided_slice_empty(%input: tensor<4xf32>) -> tensor<0xf32> { - %begin = "tf.Const"() {value = dense<[-4]> : tensor<1xi32>} : () -> (tensor<1xi32>) - %end = "tf.Const"() {value = dense<[-1]> : tensor<1xi32>} : () -> (tensor<1xi32>) - %strides = "tf.Const"() {value = dense<[-1]> : tensor<1xi32>} : () -> (tensor<1xi32>) - - // CHECK: mhlo.constant dense<> : tensor<0xf32> - %output = "tf.StridedSlice"(%input, %begin, %end, %strides) - : (tensor<4xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xf32> - func.return %output : tensor<0xf32> -} - -// ----- - -// CHECK-LABEL: strided_slice_begin_end_mask -// CHECK-SAME: %[[INPUT:[a-z0-9]+]]: tensor<4x128x1024xf32> -func.func @strided_slice_begin_end_mask(%input: tensor<4x128x1024xf32>) { - - // For StridedSlice - // Dim #: 0, 1, 2 - // Input shape: [4, 128, 1024] - // Begin: 1, 4, -3 - // End: 8, 65, 42 - // Stride: 1, 4, -1 - // Begin mask: 0, 0, 1 (= 1) - // End mask: 1, 0, 0 (= 4) - - // So result shape: - // Dim #0: begin mask (1) -> begin = 0; end 8 canonicalized to 4: so 4 - // Dim #1: 4 to 65 stride 4: so 16 - // Dim #2: begin -3 + 1024 = 1021; end mask (1) -> end = -1: so 1022 - // result shape: [4, 16, 1022] - - %begin = "tf.Const"() {value = dense<[1, 4, -3]> : tensor<3xi32>} : () -> (tensor<3xi32>) - %end = "tf.Const"() {value = dense<[8, 65, 42]> : tensor<3xi32>} : () -> (tensor<3xi32>) - %strides = "tf.Const"() {value = dense<[1, 4, -1]> : tensor<3xi32>} : () -> (tensor<3xi32>) - - // CHECK: %[[REVERSE:.*]] = "mhlo.reverse"(%[[INPUT]]) - - // CHECK: %[[SLICE:.*]] = "mhlo.slice"(%[[REVERSE]]) - // CHECK-DAG-SAME: limit_indices = dense<[4, 65, 1024]> - // CHECK-DAG-SAME: start_indices = dense<[0, 4, 2]> - // CHECK-DAG-SAME: strides = dense<[1, 4, 1]> - // CHECK-SAME: -> tensor<4x16x1022xf32> - - %0 = "tf.StridedSlice"(%input, %begin, %end, %strides) {begin_mask = 1, end_mask = 4} : (tensor<4x128x1024xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<4x16x1022xf32> - - // CHECK: mhlo.reshape %[[SLICE]] - // CHECK-SAME: -> tensor<4x16x1022xf32> - - func.return -} - -// ----- - -// CHECK-LABEL: strided_slice_shrink_axis_mask -// CHECK-SAME: %[[INPUT:.+]]: tensor<4x128x1024xf32> -func.func @strided_slice_shrink_axis_mask(%input: tensor<4x128x1024xf32>) { - - // For StridedSlice - // Dim #: 0, 1, 2 - // Input shape: [4, 128, 1024] - // Begin: 1, 4, -3 - // End: 8, 65, 42 - // Stride: 1, 4, -1 - // Begin mask: 1, 0, 0 (= 1) - // End mask: 0, 0, 1 (= 4) - // Shrink axis mask: 1, 0, 1 (= 5) - - // So result shape: - // Dim #0: shrink axis, take value at [1] - // Dim #1: 4 to 65 stride 4: so 16 - // Dim #2: shrink axis, take value at [-3] - // result shape: [16] - - // As output shape of StridedSlice differs, a reshape will follow. - - %begin = "tf.Const"() {value = dense<[1, 4, -3]> : tensor<3xi32>} : () -> (tensor<3xi32>) - %end = "tf.Const"() {value = dense<[8, 65, 42]> : tensor<3xi32>} : () -> (tensor<3xi32>) - %strides = "tf.Const"() {value = dense<[1, 4, -1]> : tensor<3xi32>} : () -> (tensor<3xi32>) - - // CHECK: %[[SLICE:.*]] = "mhlo.slice"(%[[INPUT]]) - // CHECK-DAG-SAME: limit_indices = dense<[1, 65, 1022]> - // CHECK-DAG-SAME: start_indices = dense<[0, 4, 1021]> - // CHECK-DAG-SAME: strides = dense<[1, 4, 1]> - // CHECK-SAME: -> tensor<1x16x1xf32> - - %0 = "tf.StridedSlice"(%input, %begin, %end, %strides) {begin_mask = 1, end_mask = 4, shrink_axis_mask = 5} : (tensor<4x128x1024xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<16xf32> - - // CHECK: mhlo.reshape %[[SLICE]] - // CHECK-SAME: -> tensor<16xf32> - - func.return -} - -// ----- - -// CHECK-LABEL: strided_slice_ellipsis_mask -// CHECK-SAME: %[[INPUT:[a-z0-9]+]]: tensor<2x4x8x16x32x64xf32> -func.func @strided_slice_ellipsis_mask(%input: tensor<2x4x8x16x32x64xf32>) { - // For StridedSlice input[1, ..., 8:, :10, 2:6:2] - // The ellipsis mask is applied to dim #1, #2, i.e, we get canonicalized - // slice input[1, :, :, 8:, :10, 2:6:2] - - // The start, limit indices and strides attributes of mhlo.slice would - // reflect the canonicalized slice. - // As output shape of StridedSlice differs, a reshape will follow. - - %begin = "tf.Const"() {value = dense<[1, 0, 8, 1, 2]> : tensor<5xi32>} : () -> (tensor<5xi32>) - %end = "tf.Const"() {value = dense<[2, 0, 10, 10, 6]> : tensor<5xi32>} : () -> (tensor<5xi32>) - %strides = "tf.Const"() {value = dense<[1, 1, 1, 1, 2]> : tensor<5xi32>} : () -> (tensor<5xi32>) - - // CHECK: %[[SLICE:.*]] = "mhlo.slice"(%[[INPUT]]) - // CHECK-DAG-SAME: limit_indices = dense<[2, 4, 8, 16, 10, 6]> : tensor<6xi64> - // CHECK-DAG-SAME: start_indices = dense<[1, 0, 0, 8, 0, 2]> : tensor<6xi64> - // CHECK-DAG-SAME: strides = dense<[1, 1, 1, 1, 1, 2]> : tensoe<6xi64> - // CHECK-SAME: -> tensor<1x4x8x8x10x2xf32> - %0 = "tf.StridedSlice"(%input, %begin, %end, %strides) {begin_mask = 8, end_mask = 4, shrink_axis_mask = 1, ellipsis_mask = 2} : (tensor<2x4x8x16x32x64xf32>, tensor<5xi32>, tensor<5xi32>, tensor<5xi32>) -> tensor<4x8x8x10x2xf32> - - // CHECK: mhlo.reshape %[[SLICE]] - // CHECK-SAME: -> tensor<4x8x8x10x2xf32> - - func.return -} - -// ----- - -// CHECK-LABEL: strided_slice_new_axis_mask -// CHECK-SAME: %[[INPUT:[a-z0-9]+]]: tensor<2x4x8x16x32x64xf32> -func.func @strided_slice_new_axis_mask(%input: tensor<2x4x8x16x32x64xf32>) { - // For StridedSlice input[1, tf.new_axis, ..., 8:, :10, 2:6:2, tf.new_axis] - // New axis mask is at index 1 and 6 of sparse spec, so - // new_axis_mask = 2^1 + 2^6 = 66 - // The ellipsis mask is applied to dim #1, #2 of input i.e, we get - // canonicalized slice input[1, :, :, 8:, :10, 2:6:2] - // This is then reshaped to add the new axes. - - // The start, limit indices and strides attributes of mhlo.slice would - // reflect the canonicalized slice. - // As output shape of StridedSlice differs, a reshape will follow to reflect - // new axes added. - - %begin = "tf.Const"() {value = dense<[1, 0, 0, 8, 1, 2, 0]> : tensor<7xi32>} : () -> (tensor<7xi32>) - %end = "tf.Const"() {value = dense<[2, 0, 0, 10, 10, 6, 0]> : tensor<7xi32>} : () -> (tensor<7xi32>) - %strides = "tf.Const"() {value = dense<[1, 1, 1, 1, 1, 2, 1]> : tensor<7xi32>} : () -> (tensor<7xi32>) - - // CHECK: %[[SLICE:.*]] = "mhlo.slice"(%[[INPUT]]) - // CHECK-DAG-SAME: limit_indices = dense<[2, 4, 8, 16, 10, 6]> : tensor<6xi64> - // CHECK-DAG-SAME: start_indices = dense<[1, 0, 0, 8, 0, 2]> : tensor<6xi64> - // CHECK-DAG-SAME: strides = dense<[1, 1, 1, 1, 1, 2]> : tensoe<6xi64> - // CHECK-SAME: -> tensor<1x4x8x8x10x2xf32> - %0 = "tf.StridedSlice"(%input, %begin, %end, %strides) {begin_mask = 16, end_mask = 8, shrink_axis_mask = 1, ellipsis_mask = 4, new_axis_mask = 66} : (tensor<2x4x8x16x32x64xf32>, tensor<7xi32>, tensor<7xi32>, tensor<7xi32>) -> tensor<1x4x8x8x10x2x1xf32> - - // CHECK: mhlo.reshape %[[SLICE]] - // CHECK-SAME: -> tensor<1x4x8x8x10x2x1xf32> - - func.return -} - -// ----- - -// CHECK-LABEL: strided_slice_implicit_ellipsis_mask( -// CHECK-SAME: [[INPUT:%.*]]: tensor<10x16x2xf32> -func.func @strided_slice_implicit_ellipsis_mask(%input: tensor<10x16x2xf32>) -> tensor<2x16x2xf32> { - // StridedSlice gets input[8:10], which is same as input[8:10, ...] - // The start_indices, limit_indices, and strides attribute of mhlo.slice - // reflect the canonicalized slice. - %begin = "tf.Const"() {value = dense<8> : tensor<1xi32>} : () -> tensor<1xi32> - %end = "tf.Const"() {value = dense<10> : tensor<1xi32>} : () -> tensor<1xi32> - %strides = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> - // CHECK: [[SLICE:%.*]] = "mhlo.slice"([[INPUT]]) - // CHECK-DAG-SAME: limit_indices = dense<[10, 16, 2]> : tensor<3xi64> - // CHECK-DAG-SAME: start_indices = dense<[8, 0, 0]> : tensor<3xi64> - // CHECK-DAG-SAME: strides = dense<1> : tensor<3xi64> - // CHECK: [[RESHAPE:%.*]] = mhlo.reshape [[SLICE]] : (tensor<2x16x2xf32>) -> tensor<2x16x2xf32> - %0 = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = f32} : (tensor<10x16x2xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2x16x2xf32> - // CHECK: return [[RESHAPE]] : tensor<2x16x2xf32> - func.return %0 : tensor<2x16x2xf32> -} - -// ----- - -// CHECK-LABEL: strided_slice_nonconstant_begin_end -func.func @strided_slice_nonconstant_begin_end(%arg0: tensor, %arg1: tensor<32x1x97xi32>) -> (tensor<1x97xi32>) { - // In this case, the `begin` and `end` inputs are unknown at compile time -- - // so the StridedSlice needs to slice these vectors and use that as input to - // an HLO dynamic slice. - %begin = "tf.Pack"(%arg0) {N = 1 : i64, T = i32, axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> - %0 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor - %1 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> - %2 = "tf.AddV2"(%arg0, %0) {T = i32, device = ""} : (tensor, tensor) -> tensor - %end = "tf.Pack"(%2) {N = 1 : i64, T = i32, axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> - // CHECK: %[[A:.*]] = mhlo.reshape %arg0 : (tensor) -> tensor<1xi32> - // CHECK-NEXT: %[[BEGIN:.*]] = "mhlo.concatenate"(%[[A]]) - // CHECK-DAG-SAME: {dimension = 0 : i64} : (tensor<1xi32>) -> tensor<1xi32> - // CHECK: %[[ZERO:.*]] = mhlo.constant dense<0> : tensor - // CHECK-NEXT: %[[INDEX:.*]] = "mhlo.slice"(%[[BEGIN]]) - // CHECK-DAG-SAME: {limit_indices = dense<1> : tensor<1xi64>, - // CHECK-DAG-SAME: start_indices = dense<0> : tensor<1xi64>, - // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>} : (tensor<1xi32>) -> tensor<1xi32> - // CHECK-NEXT: %[[INDEX2:.*]] = mhlo.reshape %[[INDEX]] : (tensor<1xi32>) -> tensor - // CHECK-NEXT: %[[CMP:.*]] = chlo.broadcast_compare %[[INDEX2]], %[[ZERO]] - // CHECK-DAG-SAME: {comparison_direction = #mhlo} : (tensor, tensor) -> tensor - // CHECK-NEXT: %[[DIM:.*]] = mhlo.constant dense<32> : tensor - // CHECK-NEXT: %[[WRAP:.*]] = chlo.broadcast_add %[[INDEX2]], %[[DIM]] : (tensor, tensor) -> tensor - // CHECK-NEXT: %[[INDEX3:.*]] = mhlo.select %[[CMP]], %[[WRAP]], %[[INDEX2]] : - // CHECK-DAG-SAME: (tensor, tensor, tensor) -> tensor - // CHECK-NEXT: %[[SLICED:.*]] = "mhlo.dynamic_slice" - // CHECK-DAG-SAME: (%arg1, %[[INDEX3]], %[[ZERO]], %[[ZERO]]) - // CHECK-DAG-SAME: {slice_sizes = dense<[1, 1, 97]> : tensor<3xi64>} : - // CHECK-DAG-SAME: (tensor<32x1x97xi32>, tensor, tensor, tensor) -> tensor<1x1x97xi32> - // CHECK-NEXT: %[[FINAL:.*]] = mhlo.reshape %[[SLICED]] : (tensor<1x1x97xi32>) -> tensor<1x97xi32> - %result = "tf.StridedSlice"(%arg1, %begin, %end, %1) {Index = i32, T = i32, begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32> - // CHECK-NEXT: return %[[FINAL]] : tensor<1x97xi32> - func.return %result : tensor<1x97xi32> -} - -// ----- - -// CHECK-LABEL: strided_slice_nonconstant_begin_end_with_start_end_mask -// CHECK-SAME: (%[[INPUT:.*]]: tensor<32x1x97xi32>, %[[BEGIN:.*]]: tensor<3xi32>, %[[END:.*]]: tensor<3xi32>) -func.func @strided_slice_nonconstant_begin_end_with_start_end_mask(%input: tensor<32x1x97xi32>, %begin: tensor<3xi32>, %end: tensor<3xi32>) -> (tensor<1x97xi32>) { - %strides = "tf.Const"() {value = dense<1> : tensor<3xi32>} : () -> tensor<3xi32> - - // CHECK: %[[ZERO:.*]] = mhlo.constant dense<0> : tensor - // CHECK: %[[INDEX:.*]] = "mhlo.slice"(%[[BEGIN]]) - // CHECK-DAG-SAME: start_indices = dense<0> : tensor<1xi64> - // CHECK-DAG-SAME: limit_indices = dense<1> : tensor<1xi64> - // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64> - // CHECK-NEXT: %[[INDEX2:.*]] = mhlo.reshape %[[INDEX]] : (tensor<1xi32>) -> tensor - // CHECK-NEXT: %[[CMP:.*]] = chlo.broadcast_compare %[[INDEX2]], %[[ZERO]] - // CHECK-DAG-SAME: {comparison_direction = #mhlo} : (tensor, tensor) -> tensor - // CHECK-NEXT: %[[DIM:.*]] = mhlo.constant dense<32> : tensor - // CHECK-NEXT: %[[WRAP:.*]] = chlo.broadcast_add %[[INDEX2]], %[[DIM]] : (tensor, tensor) -> tensor - // CHECK-NEXT: %[[INDEX3:.*]] = mhlo.select %[[CMP]], %[[WRAP]], %[[INDEX2]] : - // CHECK-DAG-SAME: (tensor, tensor, tensor) -> tensor - // CHECK-NEXT: %[[SLICED:.*]] = "mhlo.dynamic_slice" - // CHECK-DAG-SAME: (%arg1, %[[INDEX3]], %[[ZERO]], %[[ZERO]]) - // CHECK-DAG-SAME: {slice_sizes = dense<[1, 1, 97]> : tensor<3xi64>} : - // CHECK-DAG-SAME: (tensor<32x1x97xi32>, tensor, tensor, tensor) -> tensor<1x1x97xi32> - // CHECK-NEXT: %[[FINAL:.*]] = mhlo.reshape %[[SLICED]] : (tensor<1x1x97xi32>) -> tensor<1x97xi32> - %result = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = i32, begin_mask = 6 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 6 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<32x1x97xi32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<1x97xi32> - func.return %result : tensor<1x97xi32> -} - -// ----- - -// CHECK-LABEL: strided_slice_nonconstant_begin_end_stride_1 -func.func @strided_slice_nonconstant_begin_end_stride_1(%input: tensor<32x1x97xi32>, %begin: tensor<1xi32>, %end: tensor<1xi32>, %strides: tensor<1xi32>) -> (tensor<1x97xi32>) { - // Dynamic stride: when `begin` and `end` inputs are unknown at compile time, - // `strides` must be known. - // CHECK: tf.StridedSlice - %result = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = i32, begin_mask = 4 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32> - func.return %result : tensor<1x97xi32> -} - -// ----- - -// CHECK-LABEL: strided_slice_nonconstant_begin_end_stride_2 -func.func @strided_slice_nonconstant_begin_end_stride_2(%input: tensor<32x1x97xi32>, %begin: tensor<1xi32>, %end: tensor<1xi32>) -> (tensor<1x97xi32>) { - // Invalid stride (not equal to 1): when `begin` and `end` inputs are unknown - // at compile time, `strides` must be known to have all 1 values. - %strides = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32> - // CHECK: tf.StridedSlice - %result = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = i32, begin_mask = 4 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32> - func.return %result : tensor<1x97xi32> -} - -// ----- - -// CHECK-LABEL: strided_slice_nonconstant_begin_end_invalid_elem_count -func.func @strided_slice_nonconstant_begin_end_invalid_elem_count(%input: tensor<4x8xf32>, %begin: tensor<2xi64>, %end: tensor<2xi64>) -> tensor<6x10xf32> { - %strides = "tf.Const"() { value = dense<[1, 1]> : tensor<2xi64> } : () -> tensor<2xi64> - // When begin/end are dynamic, the number of output elements must be equal to - // the number of input elements sliced. - // CHECK: tf.StridedSlice - %0 = "tf.StridedSlice"(%input, %begin, %end, %strides) : (tensor<4x8xf32>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<6x10xf32> - func.return %0 : tensor<6x10xf32> -} - -// ----- - -// CHECK-LABEL: strided_slice_nonconstant_begin_end_and_ellipsis_mask -func.func @strided_slice_nonconstant_begin_end_and_ellipsis_mask(%input: tensor<32x1x97xi32>, %begin: tensor<1xi32>, %end: tensor<1xi32>) -> (tensor<1x97xi32>) { - // This ellipsis mask is not supported because it does not refer to the last - // dimension. - // [0, 1, 0] = 2 - %strides = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> - // CHECK: tf.StridedSlice - %result = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = i32, begin_mask = 0 : i64, device = "", ellipsis_mask = 2 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32> - func.return %result : tensor<1x97xi32> -} - -// ----- - -// CHECK-LABEL: strided_slice_nonconstant_begin_end_and_valid_ellipsis_mask -func.func @strided_slice_nonconstant_begin_end_and_valid_ellipsis_mask(%input: tensor<32x1x97xi32>, %begin: tensor<1xi32>, %end: tensor<1xi32>) -> (tensor<1x97xi32>) { - // This ellipsis mask is supported because it refers to the last dimension. - // [1, 0, 0] = 4 - %strides = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> - // CHECK: mhlo.dynamic_slice - %result = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = i32, begin_mask = 0 : i64, device = "", ellipsis_mask = 4 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32> - func.return %result : tensor<1x97xi32> -} - -// ----- - -// CHECK-LABEL: strided_slice_nonconstant_begin_end_and_valid_shrink_axis_mask -func.func @strided_slice_nonconstant_begin_end_and_valid_shrink_axis_mask(%input: tensor<32x1x97xi32>, %begin: tensor<1xi32>, %end: tensor<1xi32>) -> (tensor<1x97xi32>) { - // This shrink_axis mask is supported because it refers to a major dimension. - // [1, 1, 1] = 7 - %strides = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> - // CHECK: mhlo.dynamic_slice - %result = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = i32, begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 7 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32> - func.return %result : tensor<1x97xi32> -} - -//===----------------------------------------------------------------------===// -// Reduction op legalizations. -//===----------------------------------------------------------------------===// - -// ----- - -// CHECK-LABEL: func @mean -func.func @mean(%arg0: tensor<4x8xf16>) -> tensor<4x1xf16> { - // CHECK: %[[CAST:.*]] = mhlo.convert %arg0 : (tensor<4x8xf16>) -> tensor<4x8xf32> - // CHECK: %[[INITIAL:.*]] = mhlo.constant dense<-0.000000e+00> : tensor - // CHECK: %[[REDUCED:.*]] = mhlo.reduce(%[[CAST]] init: %[[INITIAL]]) applies mhlo.add across dimensions = [1] : (tensor<4x8xf32>, tensor) -> tensor<4xf32> - // CHECK: %[[MEAN:.*]] = chlo.broadcast_divide %[[REDUCED]], %{{.*}} {broadcast_dimensions = array} : (tensor<4xf32>, tensor) -> tensor<4xf32> - // CHECK: %[[CAST_BACK:.*]] = mhlo.convert %[[MEAN]] : (tensor<4xf32>) -> tensor<4xf16> - // CHECK: %[[RESULT:.*]] = mhlo.reshape %[[CAST_BACK]] : (tensor<4xf16>) -> tensor<4x1xf16> - // CHECK: return %[[RESULT]] : tensor<4x1xf16> - %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64> - %0 = "tf.Mean"(%arg0, %dimension) { keep_dims = true }: (tensor<4x8xf16>, tensor<1xi64>) -> tensor<4x1xf16> - func.return %0 : tensor<4x1xf16> -} - -// ----- - -// CHECK-LABEL: func @mean_scalar_dim -func.func @mean_scalar_dim(%arg0: tensor<4x8xf16>) -> tensor<4x1xf16> { - // Verify that tf.Mean op with scalar attributes are lowered successfully. - - // CHECK-NOT: tf.Mean - %dimension = "tf.Const"() { value = dense<1> : tensor } : () -> tensor - %0 = "tf.Mean"(%arg0, %dimension) { keep_dims = true }: (tensor<4x8xf16>, tensor) -> tensor<4x1xf16> - func.return %0 : tensor<4x1xf16> -} - -// ----- - -// CHECK-LABEL: func @mean_dynamic -func.func @mean_dynamic(%arg0: tensor) -> tensor { - // CHECK: %[[CAST:.*]] = mhlo.convert %arg0 : (tensor) -> tensor - // CHECK: %[[INITIAL:.*]] = mhlo.constant dense<-0.000000e+00> : tensor - // CHECK: %[[REDUCED:.*]] = mhlo.reduce(%[[CAST]] init: %[[INITIAL]]) applies mhlo.add across dimensions = [1] : (tensor, tensor) -> tensor - // CHECK: %[[SHAPE0:.*]] = shape.shape_of %arg0 : tensor -> tensor<2xindex> - // CHECK-DAG: %[[C1_1:.*]] = arith.constant 1 : index - // CHECK-DAG: %[[C1_2:.*]] = arith.constant 1 : index - // CHECK: %[[REDUCED_DIM:.*]] = tensor.extract %[[SHAPE0]][%[[C1_2]]] : tensor<2xindex> - // CHECK: %[[MUL:.*]] = arith.muli %[[C1_1]], %[[REDUCED_DIM]] : index - // CHECK: %[[INDEX_CAST:.*]] = arith.index_cast %[[MUL]] : index to i64 - // CHECK: %[[TENSOR:.*]] = tensor.from_elements %[[INDEX_CAST]] : tensor - // CHECK: %[[CONVERT:.*]] = mhlo.convert %[[TENSOR]] : (tensor) -> tensor - // CHECK: %[[MEAN:.*]] = chlo.broadcast_divide %[[REDUCED]], %[[CONVERT]] {broadcast_dimensions = array} : (tensor, tensor) -> tensor - // CHECK: %[[MEAN_CONVERTED:.*]] = mhlo.convert %[[MEAN]] : (tensor) -> tensor - // CHECK: %[[SHAPE1:.*]] = shape.shape_of %[[MEAN_CONVERTED]] : tensor -> tensor<1xindex> - // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index - // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index - // CHECK: %[[UNREDUCED_DIM:.*]] = tensor.extract %[[SHAPE1]][%[[C0]]] : tensor<1xindex> - // CHECK: %[[RESULT_SHAPE:.*]] = tensor.from_elements %[[UNREDUCED_DIM]], %[[C1]] : tensor<2xindex> - // CHECK: %[[RESULT:.*]] = mhlo.dynamic_reshape %[[MEAN_CONVERTED]], %[[RESULT_SHAPE]] : (tensor, tensor<2xindex>) -> tensor - // CHECK: return %[[RESULT]] : tensor - %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64> - %0 = "tf.Mean"(%arg0, %dimension) { keep_dims = true }: (tensor, tensor<1xi64>) -> tensor - func.return %0 : tensor -} - -// ----- - -// CHECK-LABEL: func @sum -func.func @sum(%arg0: tensor<4x8xf16>) -> tensor<4x1xf16> { - // CHECK: %[[CAST:.*]] = mhlo.convert %arg0 : (tensor<4x8xf16>) -> tensor<4x8xf32> - // CHECK: %[[INITIAL:.*]] = mhlo.constant dense<-0.000000e+00> : tensor - // CHECK: %[[REDUCED:.*]] = mhlo.reduce(%[[CAST]] init: %[[INITIAL]]) applies mhlo.add across dimensions = [1] : (tensor<4x8xf32>, tensor) -> tensor<4xf32> - // CHECK: %[[CAST_BACK:.*]] = mhlo.convert %[[REDUCED]] : (tensor<4xf32>) -> tensor<4xf16> - // CHECK: %[[RESULT:.*]] = mhlo.reshape %[[CAST_BACK]] : (tensor<4xf16>) -> tensor<4x1xf16> - // CHECK: return %[[RESULT]] : tensor<4x1xf16> - %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64> - %0 = "tf.Sum"(%arg0, %dimension) { keep_dims = true }: (tensor<4x8xf16>, tensor<1xi64>) -> tensor<4x1xf16> - func.return %0 : tensor<4x1xf16> -} - -// ----- - -// CHECK-LABEL: func @sum_dynamic -func.func @sum_dynamic(%arg0: tensor<4x?xf16>) -> tensor<4x1xf16> { - // CHECK: %[[CAST:.*]] = mhlo.convert %arg0 : (tensor<4x?xf16>) -> tensor<4x?xf32> - // CHECK: %[[INITIAL:.*]] = mhlo.constant dense<-0.000000e+00> : tensor - // CHECK: %[[REDUCED:.*]] = mhlo.reduce(%[[CAST]] init: %[[INITIAL]]) applies mhlo.add across dimensions = [1] : (tensor<4x?xf32>, tensor) -> tensor<4xf32> - // CHECK: %[[CAST_BACK:.*]] = mhlo.convert %[[REDUCED]] : (tensor<4xf32>) -> tensor<4xf16> - // CHECK: %[[RESULT:.*]] = mhlo.reshape %[[CAST_BACK]] : (tensor<4xf16>) -> tensor<4x1xf16> - // CHECK: return %[[RESULT]] : tensor<4x1xf16> - %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64> - %0 = "tf.Sum"(%arg0, %dimension) { keep_dims = true }: (tensor<4x?xf16>, tensor<1xi64>) -> tensor<4x1xf16> - func.return %0 : tensor<4x1xf16> -} - -// ----- - -// CHECK-LABEL: func @max -func.func @max(%arg0: tensor<4x8xf16>) -> tensor<4x1xf16> { - // CHECK: %[[CAST:.*]] = mhlo.convert %arg0 : tensor<4x8xf16> - // CHECK: %[[INITIAL:.*]] = mhlo.constant dense<0xFC00> : tensor - // CHECK: %[[REDUCED:.*]] = mhlo.reduce(%[[CAST]] init: %[[INITIAL]]) applies mhlo.maximum across dimensions = [1] : (tensor<4x8xf16>, tensor) -> tensor<4xf16> - // CHECK: %[[CAST_BACK:.*]] = mhlo.convert %[[REDUCED]] : tensor<4xf16> - // CHECK: %[[RESULT:.*]] = mhlo.reshape %[[CAST_BACK]] : (tensor<4xf16>) -> tensor<4x1xf16> - // CHECK: return %[[RESULT]] : tensor<4x1xf16> - %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64> - %0 = "tf.Max"(%arg0, %dimension) { keep_dims = true }: (tensor<4x8xf16>, tensor<1xi64>) -> tensor<4x1xf16> - func.return %0 : tensor<4x1xf16> -} - -// ----- - -// CHECK-LABEL: func @max_qint -// Regression test to ensure we don't crash getting the initial value for -// tf.Max when using quantized integer types. -func.func @max_qint(%arg0: tensor<4x8x!tf_type.qint8>) -> tensor<4x1x!tf_type.qint8> { - %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64> - %0 = "tf.Max"(%arg0, %dimension) { keep_dims = true }: (tensor<4x8x!tf_type.qint8>, tensor<1xi64>) -> tensor<4x1x!tf_type.qint8> - func.return %0 : tensor<4x1x!tf_type.qint8> -} - -// ----- - -// CHECK-LABEL: func @max_dynamic -func.func @max_dynamic(%arg0: tensor<4x?xf16>) -> tensor<4x1xf16> { - // CHECK: %[[CAST:.*]] = mhlo.convert %arg0 : tensor<4x?xf16> - // CHECK: %[[INITIAL:.*]] = mhlo.constant dense<0xFC00> : tensor - // CHECK: %[[REDUCED:.*]] = mhlo.reduce(%[[CAST]] init: %[[INITIAL]]) applies mhlo.maximum across dimensions = [1] : (tensor<4x?xf16>, tensor) -> tensor<4xf16> - // CHECK: %[[CAST_BACK:.*]] = mhlo.convert %[[REDUCED]] : tensor<4xf16> - // CHECK: %[[RESULT:.*]] = mhlo.reshape %[[CAST_BACK]] : (tensor<4xf16>) -> tensor<4x1xf16> - // CHECK: return %[[RESULT]] : tensor<4x1xf16> - %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64> - %0 = "tf.Max"(%arg0, %dimension) { keep_dims = true }: (tensor<4x?xf16>, tensor<1xi64>) -> tensor<4x1xf16> - func.return %0 : tensor<4x1xf16> -} - -// ----- - -// CHECK-LABEL: func @min -func.func @min(%arg0: tensor<4x8xf16>) -> tensor<4x1xf16> { - // CHECK: %[[CAST:.*]] = mhlo.convert %arg0 : tensor<4x8xf16> - // CHECK: %[[INITIAL:.*]] = mhlo.constant dense<0x7C00> : tensor - // CHECK: %[[REDUCED:.*]] = mhlo.reduce(%[[CAST]] init: %[[INITIAL]]) applies mhlo.minimum across dimensions = [1] : (tensor<4x8xf16>, tensor) -> tensor<4xf16> - // CHECK: %[[CAST_BACK:.*]] = mhlo.convert %[[REDUCED]] : tensor<4xf16> - // CHECK: %[[RESULT:.*]] = mhlo.reshape %[[CAST_BACK]] : (tensor<4xf16>) -> tensor<4x1xf16> - // CHECK: return %[[RESULT]] : tensor<4x1xf16> - %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64> - %0 = "tf.Min"(%arg0, %dimension) { keep_dims = true }: (tensor<4x8xf16>, tensor<1xi64>) -> tensor<4x1xf16> - func.return %0 : tensor<4x1xf16> -} - -// ----- - -// CHECK-LABEL: func @min_qint -// Regression test to ensure we don't crash getting the initial value for -// tf.Min when using quantized integer types. -func.func @min_qint(%arg0: tensor<4x8x!tf_type.qint8>) -> tensor<4x1x!tf_type.qint8> { - %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64> - %0 = "tf.Min"(%arg0, %dimension) { keep_dims = true }: (tensor<4x8x!tf_type.qint8>, tensor<1xi64>) -> tensor<4x1x!tf_type.qint8> - func.return %0 : tensor<4x1x!tf_type.qint8> -} - -// ----- - -// CHECK-LABEL: func @prod -func.func @prod(%arg0: tensor<4x8xf16>) -> tensor<4x1xf16> { - // CHECK: %[[CAST:.*]] = mhlo.convert %arg0 : (tensor<4x8xf16>) -> tensor<4x8xf32> - // CHECK: %[[INITIAL:.*]] = mhlo.constant dense<1.000000e+00> : tensor - // CHECK: %[[REDUCED:.*]] = mhlo.reduce(%[[CAST]] init: %[[INITIAL]]) applies mhlo.multiply across dimensions = [1] : (tensor<4x8xf32>, tensor) -> tensor<4xf32> - // CHECK: %[[CAST_BACK:.*]] = mhlo.convert %[[REDUCED]] : (tensor<4xf32>) -> tensor<4xf16> - // CHECK: %[[RESULT:.*]] = mhlo.reshape %[[CAST_BACK]] : (tensor<4xf16>) -> tensor<4x1xf16> - // CHECK: return %[[RESULT]] : tensor<4x1xf16> - %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64> - %0 = "tf.Prod"(%arg0, %dimension) { keep_dims = true }: (tensor<4x8xf16>, tensor<1xi64>) -> tensor<4x1xf16> - func.return %0 : tensor<4x1xf16> -} - -// ----- - -// CHECK-LABEL: func @prod_qint -// Regression test to ensure we don't crash getting the initial value for -// tf.Prod when using quantized integer types. -func.func @prod_qint(%arg0: tensor<4x8x!tf_type.qint8>) -> tensor<4x1x!tf_type.qint8> { - %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64> - %0 = "tf.Prod"(%arg0, %dimension) { keep_dims = true }: (tensor<4x8x!tf_type.qint8>, tensor<1xi64>) -> tensor<4x1x!tf_type.qint8> - func.return %0 : tensor<4x1x!tf_type.qint8> -} - -// ----- - -// CHECK-LABEL: @all -func.func @all(%input: tensor<4x8xi1>) -> tensor<4xi1> { - %dims = "tf.Const"() { value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> - // CHECK: %[[INIT:.*]] = mhlo.constant dense : tensor - // CHECK: %[[REDUCED:.*]] = mhlo.reduce(%{{.*}} init: %[[INIT]]) applies mhlo.and across dimensions = [1] : (tensor<4x8xi1>, tensor) -> tensor<4xi1> - %0 = "tf.All"(%input, %dims) : (tensor<4x8xi1>, tensor<1xi32>) -> tensor<4xi1> - func.return %0 : tensor<4xi1> -} - -// ----- - -// CHECK-LABEL: @all_keep_dim -func.func @all_keep_dim(%input: tensor<4x8xi1>) -> tensor<4x1xi1> { - // CHECK: mhlo.reshape %{{.*}} : (tensor<4xi1>) -> tensor<4x1xi1> - %dims = "tf.Const"() { value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> - %0 = "tf.All"(%input, %dims) {keep_dims = true} : (tensor<4x8xi1>, tensor<1xi32>) -> tensor<4x1xi1> - func.return %0 : tensor<4x1xi1> -} - -// ----- - -// CHECK-LABEL: @all_dynamic -func.func @all_dynamic(%input: tensor<4x?xi1>) -> tensor<4x1xi1> { - %dims = "tf.Const"() { value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> - // CHECK: %[[ARG:.*]] = mhlo.convert %{{.*}} : tensor<4x?xi1> - // CHECK: mhlo.reduce(%[[ARG]] - %0 = "tf.All"(%input, %dims) {keep_dims = true} : (tensor<4x?xi1>, tensor<1xi32>) -> tensor<4x1xi1> - func.return %0 : tensor<4x1xi1> -} - -// ----- - -// CHECK-LABEL: @any -func.func @any(%input: tensor<4x8xi1>) -> tensor<4xi1> { - %dims = "tf.Const"() { value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> - // CHECK: %[[INIT:.*]] = mhlo.constant dense : tensor - // CHECK: mhlo.reduce(%{{.*}} init: %[[INIT]]) applies mhlo.or across dimensions = [1] : (tensor<4x8xi1>, tensor) -> tensor<4xi1> - %0 = "tf.Any"(%input, %dims) : (tensor<4x8xi1>, tensor<1xi32>) -> tensor<4xi1> - func.return %0 : tensor<4xi1> -} - -// ----- - -// CHECK-LABEL: @any_keep_dim -func.func @any_keep_dim(%input: tensor<4x8xi1>) -> tensor<4x1xi1> { - // CHECK: mhlo.reshape %{{.*}} : (tensor<4xi1>) -> tensor<4x1xi1> - %dims = "tf.Const"() { value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> - %0 = "tf.Any"(%input, %dims) {keep_dims = true} : (tensor<4x8xi1>, tensor<1xi32>) -> tensor<4x1xi1> - func.return %0 : tensor<4x1xi1> -} - -// ----- - -// CHECK-LABEL: @any_dynamic -func.func @any_dynamic(%input: tensor<4x?xi1>) -> tensor<4x1xi1> { - %dims = "tf.Const"() { value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> - // CHECK: %[[ARG:.*]] = mhlo.convert %{{.*}} : tensor<4x?xi1> - // CHECK: mhlo.reduce(%[[ARG]] - %0 = "tf.Any"(%input, %dims) {keep_dims = true} : (tensor<4x?xi1>, tensor<1xi32>) -> tensor<4x1xi1> - func.return %0 : tensor<4x1xi1> -} - -//===----------------------------------------------------------------------===// -// Tile op legalizations. -//===----------------------------------------------------------------------===// - -// ----- - -// CHECK-LABEL: func @tile_by_reshape -func.func @tile_by_reshape(%arg0: tensor<4x8xf32>) -> tensor<28x24xf32> { - // CHECK: %[[BROADCASTED:.*]] = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<[1, 3]> : tensor<2xi64>}> : (tensor<4x8xf32>) -> tensor<7x4x3x8xf32> - // CHECK: %[[RESULT:.*]] = mhlo.reshape %[[BROADCASTED]] : (tensor<7x4x3x8xf32>) -> tensor<28x24xf32> - // CHECK: return %[[RESULT]] : tensor<28x24xf32> - %multiples = "tf.Const"() { value = dense<[7,3]> : tensor<2xi64> } : () -> tensor<2xi64> - %0 = "tf.Tile"(%arg0, %multiples) : (tensor<4x8xf32>, tensor<2xi64>) -> tensor<28x24xf32> - func.return %0 : tensor<28x24xf32> -} - -// ----- - -// CHECK-LABEL: func @tile_just_broadcast -func.func @tile_just_broadcast(%arg0: tensor<1x1xf32>) -> tensor<7x3xf32> { - // CHECK: %[[RESULT:.*]] = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}> : (tensor<1x1xf32>) -> tensor<7x3xf32> - // CHECK: return %[[RESULT]] : tensor<7x3xf32> - %multiples = "tf.Const"() { value = dense<[7,3]> : tensor<2xi64> } : () -> tensor<2xi64> - %0 = "tf.Tile"(%arg0, %multiples) : (tensor<1x1xf32>, tensor<2xi64>) -> tensor<7x3xf32> - func.return %0 : tensor<7x3xf32> -} - -// ----- - -// CHECK-LABEL: func @tile_dynamic_shape -func.func @tile_dynamic_shape(%arg0: tensor) -> tensor { - %multiples = "tf.Const"() { value = dense<[7,3]> : tensor<2xi32> } : () -> tensor<2xi32> - // CHECK: tensor.dim {{.*}} : tensor - // CHECK: tensor.from_elements {{.*}} : tensor<4xindex> - // CHECK: "mhlo.dynamic_broadcast_in_dim"({{.*}}) <{broadcast_dimensions = dense<[1, 3]> : tensor<2xi64>}> : (tensor, tensor<4xindex>) -> tensor - // CHECK: muli {{.*}} : index - // CHECK: tensor.from_elements {{.*}} : tensor<2xindex> - // CHECK: mhlo.dynamic_reshape {{.*}} : (tensor, tensor<2xindex>) -> tensor - %0 = "tf.Tile"(%arg0, %multiples) : (tensor, tensor<2xi32>) -> tensor - func.return %0 : tensor -} - -//===----------------------------------------------------------------------===// -// ArgMax/ArgMin op legalizations. -//===----------------------------------------------------------------------===// - -// ----- - -// CHECK-LABEL: func @argmax_i64_input_i32_output_axis_0 -func.func @argmax_i64_input_i32_output_axis_0(%arg0: tensor<3x7xi64>) -> tensor<7xi32> { - // CHECK: %[[INIT:.*]] = mhlo.constant dense<-9223372036854775808> : tensor - // CHECK-NEXT: %[[INDEX_INIT:.*]] = mhlo.constant dense<0> : tensor - // CHECK: %[[SHAPE:.*]] = shape.shape_of %arg0 : tensor<3x7xi64> -> tensor<2xindex> - // CHECK: %[[INDEX:.*]] = "mhlo.dynamic_iota"(%[[SHAPE]]) <{iota_dimension = 0 : i64}> : (tensor<2xindex>) -> tensor<3x7xi32> - // CHECK: %[[REDUCE:.*]]:2 = mhlo.reduce(%arg0 init: %[[INIT]]), (%[[INDEX]] init: %[[INDEX_INIT]]) - // CHECK: (%[[ARG1:.*]]: tensor, %[[ARG3:.*]]: tensor) (%[[ARG2:.*]]: tensor, %[[ARG4:.*]]: tensor) - // CHECK: %[[COMPARE:.*]] = mhlo.compare GE, %[[ARG1]], %[[ARG3]], NOTYPE : (tensor, tensor) -> tensor - // CHECK: %[[RESULT1:.*]] = mhlo.select %[[COMPARE]], %[[ARG1]], %[[ARG3]] : tensor, tensor - // CHECK: %[[COMPARE_EQ:.*]] = mhlo.compare EQ, %[[ARG1]], %[[ARG3]], NOTYPE : (tensor, tensor) -> tensor - // CHECK: %[[MIN:.*]] = mhlo.minimum %[[ARG2]], %[[ARG4]] - // CHECK: %[[RESULT2:.*]] = mhlo.select %[[COMPARE]], %[[ARG2]], %[[ARG4]] : tensor, tensor - // CHECK: %[[RESULT3:.*]] = mhlo.select %[[COMPARE_EQ]], %[[MIN]], %[[RESULT2]] : tensor, tensor - // CHECK: mhlo.return %[[RESULT1]], %[[RESULT3]] : tensor, tensor - // CHECK: return %[[REDUCE]]#1 : tensor<7xi32> - %axis = "tf.Const"() { value = dense<0> : tensor } : () -> tensor - %0 = "tf.ArgMax"(%arg0, %axis) : (tensor<3x7xi64>, tensor) -> tensor<7xi32> - func.return %0 : tensor<7xi32> -} - -// ----- - -// CHECK-LABEL: func @argmax_f32_input_i64_output_axis_1 -func.func @argmax_f32_input_i64_output_axis_1(%arg0: tensor<3x7xf32>) -> tensor<3xi64> { - // CHECK: %[[INIT:.*]] = mhlo.constant dense<0xFF800000> : tensor - // CHECK-NEXT: %[[INDEX_INIT:.*]] = mhlo.constant dense<0> : tensor - // CHECK: %[[SHAPE:.*]] = shape.shape_of %arg0 : tensor<3x7xf32> -> tensor<2xindex> - // CHECK: %[[INDEX:.*]] = "mhlo.dynamic_iota"(%[[SHAPE]]) <{iota_dimension = 1 : i64}> : (tensor<2xindex>) -> tensor<3x7xi64> - // CHECK: %[[REDUCE:.*]]:2 = mhlo.reduce(%arg0 init: %[[INIT]]), (%[[INDEX]] init: %[[INDEX_INIT]]) - // CHECK: return %[[REDUCE]]#1 : tensor<3xi64> - %axis = "tf.Const"() { value = dense<1> : tensor } : () -> tensor - %0 = "tf.ArgMax"(%arg0, %axis) : (tensor<3x7xf32>, tensor) -> tensor<3xi64> - func.return %0 : tensor<3xi64> -} - -// ----- - -// CHECK-LABEL: func @argmax_i1_input_i64_output_axis_1 -func.func @argmax_i1_input_i64_output_axis_1(%arg0: tensor<3x7xi1>) -> tensor<3xi64> { - // CHECK-DAG: %[[INIT:.*]] = mhlo.constant dense : tensor - // CHECK-DAG: %[[INDEX_INIT:.*]] = mhlo.constant dense<0> : tensor - // CHECK: %[[SHAPE:.*]] = shape.shape_of %arg0 : tensor<3x7xi1> -> tensor<2xindex> - // CHECK: %[[INDEX:.*]] = "mhlo.dynamic_iota"(%[[SHAPE]]) <{iota_dimension = 1 : i64}> : (tensor<2xindex>) -> tensor<3x7xi64> - // CHECK: %[[REDUCE:.*]]:2 = mhlo.reduce(%arg0 init: %[[INIT]]), (%[[INDEX]] init: %[[INDEX_INIT]]) - // CHECK: return %[[REDUCE]]#1 : tensor<3xi64> - %axis = "tf.Const"() { value = dense<1> : tensor } : () -> tensor - %0 = "tf.ArgMax"(%arg0, %axis) : (tensor<3x7xi1>, tensor) -> tensor<3xi64> - func.return %0 : tensor<3xi64> -} - -// ----- - -// CHECK-LABEL: func @argmax_dynamic_shape_input_output -func.func @argmax_dynamic_shape_input_output(%arg0: tensor<3x?xi32>) -> tensor { - // CHECK: %[[INIT:.*]] = mhlo.constant dense<-2147483648> : tensor - // CHECK-NEXT: %[[INDEX_INIT:.*]] = mhlo.constant dense<0> : tensor - // CHECK: %[[SHAPE:.*]] = shape.shape_of %arg0 : tensor<3x?xi32> -> tensor<2xindex> - // CHECK: %[[INDEX:.*]] = "mhlo.dynamic_iota"(%[[SHAPE]]) <{iota_dimension = 0 : i64}> : (tensor<2xindex>) -> tensor<3x?xi32> - // CHECK: %[[REDUCE:.*]]:2 = mhlo.reduce(%arg0 init: %[[INIT]]), (%[[INDEX]] init: %[[INDEX_INIT]]) - // CHECK: return %[[REDUCE]]#1 : tensor - %axis = "tf.Const"() { value = dense<0> : tensor } : () -> tensor - %0 = "tf.ArgMax"(%arg0, %axis) : (tensor<3x?xi32>, tensor) -> tensor - func.return %0 : tensor -} - -// ----- - -// CHECK-LABEL: func @argmax_dynamic_shape_input -func.func @argmax_dynamic_shape_input(%arg0: tensor<3x?xi32>) -> tensor<3xi32> { - // CHECK-DAG: %[[INIT:.*]] = mhlo.constant dense<-2147483648> : tensor - // CHECK-DAG: %[[INDEX_INIT:.*]] = mhlo.constant dense<0> : tensor - // CHECK: %[[SHAPE:.*]] = shape.shape_of %arg0 : tensor<3x?xi32> -> tensor<2xindex> - // CHECK: %[[INDEX:.*]] = "mhlo.dynamic_iota"(%[[SHAPE]]) <{iota_dimension = 1 : i64}> : (tensor<2xindex>) -> tensor<3x?xi32> - // CHECK: %[[REDUCE:.*]]:2 = mhlo.reduce(%arg0 init: %[[INIT]]), (%[[INDEX]] init: %[[INDEX_INIT]]) - // CHECK: return %[[REDUCE]]#1 : tensor<3xi32> - %axis = "tf.Const"() { value = dense<1> : tensor } : () -> tensor - %0 = "tf.ArgMax"(%arg0, %axis) : (tensor<3x?xi32>, tensor) -> tensor<3xi32> - func.return %0 : tensor<3xi32> -} - -// ----- - -// CHECK-LABEL: func @argmin_i64_input_i32_output_axis_0 -func.func @argmin_i64_input_i32_output_axis_0(%arg0: tensor<3x7xi64>) -> tensor<7xi32> { - // CHECK: %[[INIT:.*]] = mhlo.constant dense<9223372036854775807> : tensor - // CHECK-NEXT: %[[INDEX_INIT:.*]] = mhlo.constant dense<0> : tensor - // CHECK: %[[SHAPE:.*]] = shape.shape_of %arg0 : tensor<3x7xi64> -> tensor<2xindex> - // CHECK: %[[INDEX:.*]] = "mhlo.dynamic_iota"(%[[SHAPE]]) <{iota_dimension = 0 : i64}> : (tensor<2xindex>) -> tensor<3x7xi32> - // CHECK: %[[REDUCE:.*]]:2 = mhlo.reduce(%arg0 init: %[[INIT]]), (%[[INDEX]] init: %[[INDEX_INIT]]) - // CHECK: (%[[ARG1:.*]]: tensor, %[[ARG3:.*]]: tensor) (%[[ARG2:.*]]: tensor, %[[ARG4:.*]]: tensor) - // CHECK: %[[COMPARE:.*]] = mhlo.compare LE, %[[ARG1]], %[[ARG3]], NOTYPE : (tensor, tensor) -> tensor - // CHECK: %[[RESULT1:.*]] = mhlo.select %[[COMPARE]], %[[ARG1]], %[[ARG3]] : tensor, tensor - // CHECK: %[[COMPARE_EQ:.*]] = mhlo.compare EQ, %[[ARG1]], %[[ARG3]], NOTYPE : (tensor, tensor) -> tensor - // CHECK: %[[MIN:.*]] = mhlo.minimum %[[ARG2]], %[[ARG4]] - // CHECK: %[[RESULT2:.*]] = mhlo.select %[[COMPARE]], %[[ARG2]], %[[ARG4]] : tensor, tensor - // CHECK: %[[RESULT3:.*]] = mhlo.select %[[COMPARE_EQ]], %[[MIN]], %[[RESULT2]] : tensor, tensor - // CHECK: mhlo.return %[[RESULT1]], %[[RESULT3]] : tensor, tensor - // CHECK: return %[[REDUCE]]#1 : tensor<7xi32> - %axis = "tf.Const"() { value = dense<0> : tensor } : () -> tensor - %0 = "tf.ArgMin"(%arg0, %axis) : (tensor<3x7xi64>, tensor) -> tensor<7xi32> - func.return %0 : tensor<7xi32> -} - -//===----------------------------------------------------------------------===// -// Random op legalizations. -//===----------------------------------------------------------------------===// - -// ----- - -// CHECK-LABEL: func @rng_uniform -func.func @rng_uniform(%arg0: tensor<3xi32>) -> tensor<12x?x64xf32> { - // CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor - // CHECK-DAG: %[[ONE:.*]] = mhlo.constant dense<1.000000e+00> : tensor - // CHECK: %[[CONV:.*]] = mhlo.convert %arg0 : (tensor<3xi32>) -> tensor<3xi64> - // CHECK: %[[F32:.*]] = "mhlo.rng"(%[[ZERO]], %[[ONE]], %[[CONV]]) {{.*UNIFORM.*}} -> tensor<12x?x64xf32> - %0 = "tf.RandomUniform"(%arg0) : (tensor<3xi32>) -> tensor<12x?x64xf32> - // CHECK: return %[[F32]] - func.return %0 : tensor<12x?x64xf32> -} - -// ----- - -// CHECK-LABEL: func @random_uniform_simple -func.func @random_uniform_simple(%arg0: tensor<3xi32>) -> tensor<12x?x64xf32> { - // CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor - // CHECK-DAG: %[[ONE:.*]] = mhlo.constant dense<1.000000e+00> : tensor - // CHECK: %[[CONV:.*]] = mhlo.convert %arg0 : (tensor<3xi32>) -> tensor<3xi64> - // CHECK: %[[F32:.*]] = "mhlo.rng"(%[[ZERO]], %[[ONE]], %[[CONV]]) {{.*UNIFORM.*}} -> tensor<12x?x64xf32> - %0 = "tf.RandomUniform"(%arg0) : (tensor<3xi32>) -> tensor<12x?x64xf32> - // CHECK: return %[[F32]] - func.return %0 : tensor<12x?x64xf32> -} - -// ----- - -// CHECK-LABEL: func @random_uniform_with_seeds -func.func @random_uniform_with_seeds(%arg0: tensor<4xi32>) -> tensor<32x12x12x64xf32> { - // CHECK: %0 = mhlo.constant dense<[32, 12, 12, 64]> : tensor<4xi32> - // CHECK-NEXT: %1 = mhlo.constant dense<0.000000e+00> : tensor - // CHECK-NEXT: %2 = mhlo.constant dense<1.000000e+00> : tensor - // CHECK-NEXT: %3 = mhlo.convert %0 : (tensor<4xi32>) -> tensor<4xi64> - // CHECK-NEXT: %4 = "mhlo.rng"(%1, %2, %3) <{rng_distribution = #mhlo.rng_distribution}> : (tensor, tensor, tensor<4xi64>) -> tensor<32x12x12x64xf32> - %cst = "tf.Const"() {value = dense<[32, 12, 12, 64]> : tensor<4xi32>} : () -> tensor<4xi32> - %0 = "tf.RandomUniform"(%cst) {seed = 87654321 : i64, seed2 = 0 : i64} : (tensor<4xi32>) -> tensor<32x12x12x64xf32> - // CHECK: return %4 : tensor<32x12x12x64xf32> - func.return %0 : tensor<32x12x12x64xf32> -} - -// ----- - -// CHECK-LABEL: func @rng_std_normal -func.func @rng_std_normal(%arg0: tensor<3xi32>) -> tensor<12x?x64xf32> { - // CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor - // CHECK-DAG: %[[ONE:.*]] = mhlo.constant dense<1.000000e+00> : tensor - // CHECK: %[[CONV:.*]] = mhlo.convert %arg0 : (tensor<3xi32>) -> tensor<3xi64> - // CHECK: %[[F32:.*]] = "mhlo.rng"(%[[ZERO]], %[[ONE]], %[[CONV]]) {{.*NORMAL.*}} -> tensor<12x?x64xf32> - %0 = "tf.RandomStandardNormal"(%arg0) : (tensor<3xi32>) -> tensor<12x?x64xf32> - // CHECK: return %[[F32]] - func.return %0 : tensor<12x?x64xf32> -} - -//===----------------------------------------------------------------------===// -// Range op legalizations. -//===----------------------------------------------------------------------===// - -// ----- - -// CHECK-LABEL: func @range -// CHECK-SAME: [[START:%.*]]: tensor, [[DELTA:%.*]]: tensor -func.func @range(%arg0: tensor, %arg1: tensor) -> tensor<5xf32> { - %1 = "tf.Const"() {device = "", dtype = "tfdtype$DT_FLOAT", name = "range/limit", value = dense<5.000000e+00> : tensor} : () -> tensor - // CHECK-DAG: [[IOTA:%.*]] = "mhlo.iota" - // CHECK-DAG: [[MUL:%.*]] = chlo.broadcast_multiply [[IOTA]], [[DELTA]] {broadcast_dimensions = array} - // CHECK: chlo.broadcast_add [[MUL]], [[START]] {broadcast_dimensions = array} - %3 = "tf.Range"(%arg0, %1, %arg1) {Tidx = "tfdtype$DT_FLOAT", device = "", name = "range"} : (tensor, tensor, tensor) -> tensor<5xf32> - func.return %3 : tensor<5xf32> -} - -// ----- - -// CHECK-LABEL: func @range_dynamic -// CHECK-SAME: [[START:%.*]]: tensor, [[DELTA:%.*]]: tensor -func.func @range_dynamic(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { - // CHECK-DAG: [[SUB:%.+]] = mhlo.subtract %arg1, %arg0 - // CHECK-DAG: [[ABS1:%.+]] = mhlo.abs [[SUB]] - // CHECK-DAG: [[CONVERT_1:%.+]] = mhlo.convert [[ABS1]] - // CHECK-DAG: [[CONVERT_2:%.+]] = mhlo.convert %arg2 - // CHECK-DAG: [[DIV:%.+]] = mhlo.divide [[CONVERT_1]], [[CONVERT_2]] - // CHECK-DAG: [[CEIL:%.+]] = mhlo.ceil [[DIV]] - // CHECK-DAG: [[CONVERT_3:%.+]] = mhlo.convert [[CEIL]] - // CHECK-DAG: [[RESHAPE:%.+]] = mhlo.reshape [[CONVERT_3]] - // CHECK-DAG: [[IOTA:%.+]] = "mhlo.dynamic_iota"([[RESHAPE]]) <{iota_dimension = 0 : i64}> - // CHECK-DAG: [[CONVERT_3:%.+]] = mhlo.convert %arg0 - // CHECK-DAG: [[CONVERT_4:%.+]] = mhlo.convert %arg2 - // CHECK-DAG: [[MUL:%.+]] = chlo.broadcast_multiply [[IOTA]], [[CONVERT_4]] {broadcast_dimensions = array} - // CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add [[MUL]], [[CONVERT_3]] {broadcast_dimensions = array} - %2 = "tf.Range"(%arg0, %arg1, %arg2) {Tidx = "tfdtype$DT_FLOAT", device = "", name = "range"} : (tensor, tensor, tensor) -> tensor - - // CHECK: return [[ADD]] - func.return %2 : tensor -} - -// ----- - -// CHECK-LABEL: func @range_int_dynamic -// CHECK-SAME: [[START:%.*]]: tensor, [[DELTA:%.*]]: tensor -func.func @range_int_dynamic(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { - // CHECK-DAG: [[SUB:%.+]] = mhlo.subtract %arg1, %arg0 - // CHECK-DAG: [[ABS1:%.+]] = mhlo.abs [[SUB]] - // CHECK-DAG: [[CONVERT_1:%.+]] = mhlo.convert [[ABS1]] - // CHECK-DAG: [[CONVERT_2:%.+]] = mhlo.convert %arg2 - // CHECK-DAG: [[DIV:%.+]] = mhlo.divide [[CONVERT_1]], [[CONVERT_2]] - // CHECK-DAG: [[CEIL:%.+]] = mhlo.ceil [[DIV]] - // CHECK-DAG: [[CONVERT_3:%.+]] = mhlo.convert [[CEIL]] - // CHECK-DAG: [[RESHAPE:%.+]] = mhlo.reshape [[CONVERT_3]] - // CHECK-DAG: [[IOTA:%.+]] = "mhlo.dynamic_iota"([[RESHAPE]]) <{iota_dimension = 0 : i64}> - // CHECK-DAG: [[CONVERT_3:%.+]] = mhlo.convert %arg0 - // CHECK-DAG: [[CONVERT_4:%.+]] = mhlo.convert %arg2 - // CHECK-DAG: [[MUL:%.+]] = chlo.broadcast_multiply [[IOTA]], [[CONVERT_4]] {broadcast_dimensions = array} - // CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add [[MUL]], [[CONVERT_3]] {broadcast_dimensions = array} - %2 = "tf.Range"(%arg0, %arg1, %arg2) {Tidx = "tfdtype$DT_FLOAT", device = "", name = "range"} : (tensor, tensor, tensor) -> tensor - - // CHECK: return [[ADD]] - func.return %2 : tensor -} - -// ----- - -// CHECK-LABEL: func @linspace_static -// CHECK-SAME: [[START:%.*]]: tensor, [[STOP:%.*]]: tensor -func.func @linspace_static(%arg0: tensor, %arg1: tensor) -> tensor<4xf32> { - // CHECK-DAG: [[NUM:%.*]] = mhlo.constant dense<4> - // CHECK-DAG: [[NUM_F32:%.*]] = mhlo.convert [[NUM]] - // CHECK-DAG: [[ONE:%.*]] = mhlo.constant dense<1.000000e+00> - // CHECK-DAG: [[STEP_DENOMINATOR:%.*]] = chlo.broadcast_subtract [[NUM_F32]], [[ONE]] - // CHECK-DAG: [[STEP_NUMERATOR:%.*]] = chlo.broadcast_subtract [[STOP]], [[START]] - // CHECK-DAG: [[STEP:%.*]] = chlo.broadcast_divide [[STEP_NUMERATOR]], [[STEP_DENOMINATOR]] - // CHECK-DAG: [[IOTA:%.*]] = "mhlo.iota"() <{iota_dimension = 0 : i64}> - // CHECK-DAG: [[MUL:%.*]] = chlo.broadcast_multiply [[IOTA]], [[STEP]] {broadcast_dimensions = array} - // CHECK-DAG: [[LINSPACE:%.*]] = chlo.broadcast_add [[MUL]], [[START]] {broadcast_dimensions = array} - // CHECK: return [[LINSPACE]] - %0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = i32, value = dense<4> : tensor} : () -> tensor - %1 = "tf.LinSpace"(%arg0, %arg1, %0) : (tensor, tensor, tensor) -> tensor<4xf32> - func.return %1 : tensor<4xf32> -} - -// ----- - -// CHECK-LABEL: func @linspace_dynamic -func.func @linspace_dynamic(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { - // CHECK: "tf.LinSpace" - %0 = "tf.LinSpace"(%arg0, %arg1, %arg2) : (tensor, tensor, tensor) -> tensor - func.return %0 : tensor -} - -// ----- - -// CHECK-LABEL: func @linspace_invalid_num -func.func @linspace_invalid_num(%arg0: tensor, %arg1: tensor) -> tensor { - // CHECK: mhlo.constant dense<> : tensor<0xi32> - // CHECK: "tf.LinSpace" - %0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = i32, value = dense<> : tensor<0xi32>} : () -> tensor<0xi32> - %1 = "tf.LinSpace"(%arg0, %arg1, %0) : (tensor, tensor, tensor<0xi32>) -> tensor - func.return %1 : tensor -} - -//===----------------------------------------------------------------------===// -// LegacyCall op legalizations. -//===----------------------------------------------------------------------===// - -// ----- - -func.func @identity_func(%arg0: tensor<10x2xf32>) -> tensor<10x2xf32> { - func.return %arg0: tensor<10x2xf32> -} - -// CHECK-LABEL: testSimpleLegacyCallOp -func.func @testSimpleLegacyCallOp(%arg0: tensor<10x2xf32>) -> tensor<10x2xf32> { - // CHECK: %[[RESULT:.*]] = call @identity_func(%arg0) : (tensor<10x2xf32>) -> tensor<10x2xf32> - %0 = "tf.LegacyCall"(%arg0) {f = @identity_func} : (tensor<10x2xf32>) -> tensor<10x2xf32> - // CHECK: return %[[RESULT]] - func.return %0: tensor<10x2xf32> -} - -// ----- - -func.func @select_first(%arg0: tensor<10x2xf32>, %arg1: tensor<10x2xf32>) -> tensor<10x2xf32> { - func.return %arg0: tensor<10x2xf32> -} - -// CHECK-LABEL: testMultiInputLegacyCallOp -func.func @testMultiInputLegacyCallOp(%arg0: tensor<10x2xf32>, %arg1: tensor<10x2xf32>) -> tensor<10x2xf32> { - // CHECK: %[[RESULT:.*]] = call @select_first(%arg0, %arg1) : (tensor<10x2xf32>, tensor<10x2xf32>) -> tensor<10x2xf32> - %0 = "tf.LegacyCall"(%arg0, %arg1) {_disable_call_shape_inference = true, _tpu_replicate = "cluster", device = "", f = @select_first} : (tensor<10x2xf32>, tensor<10x2xf32>) -> tensor<10x2xf32> - // CHECK: return %[[RESULT]] - func.return %0: tensor<10x2xf32> -} - -//===----------------------------------------------------------------------===// -// Conv op legalizations. -//===----------------------------------------------------------------------===// - -// ----- - -// CHECK-LABEL: conv_simple -func.func @conv_simple(%arg0: tensor<256x32x32x6xf32>, %arg1: tensor<3x3x3x16xf32>) -> tensor<256x8x7x16xf32> { - - // CHECK: mhlo.convolution(%arg0, %arg1) - // CHECK-SAME: dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f] - // CHECK-SAME{LITERAL}: window = {stride = [4, 5], pad = [[0, 1], [2, 3]], rhs_dilate = [2, 3]} - // CHECK-SAME: batch_group_count = 1 - // CHECK-SAME: feature_group_count = 2 - - %0 = "tf.Conv2D"(%arg0, %arg1) {data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 5, 1]} : (tensor<256x32x32x6xf32>, tensor<3x3x3x16xf32>) -> tensor<256x8x7x16xf32> - func.return %0 : tensor<256x8x7x16xf32> -} - -// ----- - -// CHECK-LABEL: conv3d_simple -func.func @conv3d_simple(%arg0: tensor<256x32x32x32x6xf32>, %arg1: tensor<3x3x3x3x16xf32>) -> tensor<256x7x6x5x16xf32> { - - // CHECK: mhlo.convolution(%arg0, %arg1) - // CHECK-SAME: dim_numbers = [b, 0, 1, 2, f]x[0, 1, 2, i, o]->[b, 0, 1, 2, f] - // CHECK-SAME{LITERAL}: window = {stride = [5, 6, 7], pad = [[1, 2], [2, 3], [2, 3]], rhs_dilate = [2, 3, 4]} - // CHECK-SAME: batch_group_count = 1 - // CHECK-SAME: feature_group_count = 2 - - %0 = "tf.Conv3D"(%arg0, %arg1) {data_format = "NDHWC", dilations = [1, 2, 3, 4, 1], padding = "SAME", strides = [1, 5, 6, 7, 1]} : (tensor<256x32x32x32x6xf32>, tensor<3x3x3x3x16xf32>) -> tensor<256x7x6x5x16xf32> - func.return %0 : tensor<256x7x6x5x16xf32> -} - -// ----- - -// CHECK-LABEL: depthwiseconv_simple -func.func @depthwiseconv_simple(%arg0: tensor, %arg1: tensor<2x2x3x3xf32>) -> tensor { - // CHECK: %[[RESHAPED_FILTER:.*]] = mhlo.reshape %arg1 : (tensor<2x2x3x3xf32>) -> tensor<2x2x1x9xf32> - // CHECK: mhlo.convolution(%arg0, %[[RESHAPED_FILTER]]) - // CHECK-SAME: feature_group_count = 3 - %0 = "tf.DepthwiseConv2dNative"(%arg0, %arg1) { - data_format = "NHWC", - device = "", - dilations = [1, 1, 1, 1], - explicit_paddings = [], - padding = "VALID", - strides = [1, 1, 1, 1] - } : (tensor, tensor<2x2x3x3xf32>) -> tensor - func.return %0 : tensor -} - -// ----- - -// CHECK-LABEL: conv_valid_padding -func.func @conv_valid_padding(%arg0: tensor<1x4x5x1xf32>, %arg1: tensor<3x3x1x1xf32>) -> tensor<1x2x3x1xf32> { - // CHECK: mhlo.convolution(%arg0, %arg1) - - %0 = "tf.Conv2D"(%arg0, %arg1) {data_format = "NHWC", dilations = [1, 1, 1, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x4x5x1xf32>, tensor<3x3x1x1xf32>) -> tensor<1x2x3x1xf32> - func.return %0 : tensor<1x2x3x1xf32> -} - -// ----- - -// CHECK-LABEL: conv_explicit_paddings -func.func @conv_explicit_paddings(%arg0: tensor<256x32x32x6xf32>, %arg1: tensor<3x3x3x16xf32>) -> tensor<256x9x7x16xf32> { - - // CHECK: mhlo.convolution(%arg0, %arg1) - // CHECK-SAME{LITERAL}: pad = [[6, 0], [3, 3]] - - %0 = "tf.Conv2D"(%arg0, %arg1) {data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "EXPLICIT", explicit_paddings = [0, 0, 6, 0, 3, 3, 0, 0], strides = [1, 4, 5, 1]} : (tensor<256x32x32x6xf32>, tensor<3x3x3x16xf32>) -> tensor<256x9x7x16xf32> - func.return %0 : tensor<256x9x7x16xf32> -} - -// ----- - -// CHECK-LABEL: @conv2d_backprop_input_dynamic -func.func @conv2d_backprop_input_dynamic(%filter: tensor<2x2x1x16xf32>, %out_backprop: tensor) -> tensor { - // CHECK: %[[REV_FILTER:.*]] = "mhlo.reverse"(%arg0) <{dimensions = dense<[0, 1]> : tensor<2xi64>}> - // CHECK: %[[RESULT:.*]] = mhlo.convolution(%arg1, %[[REV_FILTER]]) - // CHECK-SAME: dim_numbers = [b, 0, 1, f]x[0, 1, o, i]->[b, 0, 1, f] - // CHECK-SAME{LITERAL}: window = {stride = [1, 1], pad = [[1, 1], [1, 1]], lhs_dilate = [2, 2], rhs_dilate = [1, 1]} - // CHECK-SAME: batch_group_count = 1 : i64 - // CHECK-SAME: feature_group_count = 1 : i64 - // CHECK: return %[[RESULT]] - %cst_0_1d = "tf.Const"() {device = "", value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32> - %cst_1_0d = "tf.Const"() {device = "", value = dense<1> : tensor} : () -> tensor - %cst_1_1d = "tf.Const"() {device = "", value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> - %cst_512_0d = "tf.Const"() {device = "", value = dense<512> : tensor} : () -> tensor - %out_backprop_shape = "tf.Shape"(%out_backprop) {device = ""} : (tensor) -> tensor<4xi32> - %batch_size = "tf.StridedSlice"(%out_backprop_shape, %cst_0_1d, %cst_1_1d, %cst_1_1d) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %input_shape = "tf.Pack"(%batch_size, %cst_512_0d, %cst_512_0d, %cst_1_0d) {axis = 0 : i64, device = ""} : (tensor, tensor, tensor, tensor) -> tensor<4xi32> - %result = "tf.Conv2DBackpropInput"(%input_shape, %filter, %out_backprop) {data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 2, 2, 1], use_cudnn_on_gpu = true} : (tensor<4xi32>, tensor<2x2x1x16xf32>, tensor) -> tensor - return %result : tensor -} - -// ----- - -// CHECK-LABEL: @conv2d_backprop_input -func.func @conv2d_backprop_input( - %filter: tensor<3x3x1x32xf32>, - %out_backprop: tensor<100x26x26x32xf32> - ) -> tensor<100x28x28x1xf32> { - // CHECK: %[[REV_FILTER:.*]] = "mhlo.reverse"(%arg0) <{dimensions = dense<[0, 1]> : tensor<2xi64>}> - // CHECK: %[[RESULT:.*]] = mhlo.convolution(%arg1, %[[REV_FILTER]]) - // CHECK-SAME: dim_numbers = [b, 0, 1, f]x[0, 1, o, i]->[b, 0, 1, f] - // CHECK-SAME{LITERAL}: window = {stride = [1, 1], pad = [[2, 2], [2, 2]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]} - // CHECK-SAME: batch_group_count = 1 : i64 - // CHECK-SAME: feature_group_count = 1 : i64 - // CHECK: return %[[RESULT]] - %input_sizes = "tf.Const" () { value = dense<[100,28,28,1]> : tensor<4xi32> } : () -> tensor<4xi32> - %result = "tf.Conv2DBackpropInput"(%input_sizes, %filter, %out_backprop) { - data_format = "NHWC", - dilations = [1, 1, 1, 1], - explicit_paddings = [], - padding = "VALID", - strides = [1, 1, 1, 1], - use_cudnn_on_gpu = true - } : (tensor<4xi32>, tensor<3x3x1x32xf32>, tensor<100x26x26x32xf32>) -> tensor<100x28x28x1xf32> - func.return %result : tensor<100x28x28x1xf32> -} - -// ----- - -// CHECK-LABEL: @conv2d_backprop_input_grouped -func.func @conv2d_backprop_input_grouped( - %filter: tensor<2x2x5x21xf32>, - %out_backprop: tensor<5x2x2x21xf32> - ) -> tensor<5x3x3x15xf32> { - %input_sizes = "tf.Const" () { value = dense<[5, 3, 3, 15]> : tensor<4xi32> } : () -> tensor<4xi32> - - // Verify filter transformation for grouped convolution. - - // CHECK: %[[RESHAPE:.*]] = mhlo.reshape %arg0 : (tensor<2x2x5x21xf32>) -> tensor<2x2x5x3x7xf32> - // CHECK: %[[TRANSPOSE:.*]] = "mhlo.transpose"(%[[RESHAPE]]) - // CHECK-SAME: permutation = dense<[0, 1, 3, 2, 4]> - // CHECK-SAME: (tensor<2x2x5x3x7xf32>) -> tensor<2x2x3x5x7xf32> - // CHECK: mhlo.reshape %[[TRANSPOSE]] : (tensor<2x2x3x5x7xf32>) -> tensor<2x2x15x7xf32> - - %result = "tf.Conv2DBackpropInput"(%input_sizes, %filter, %out_backprop) { - data_format = "NHWC", - dilations = [1, 1, 1, 1], - explicit_paddings = [], - padding = "VALID", - strides = [1, 1, 1, 1], - use_cudnn_on_gpu = true - } : (tensor<4xi32>, tensor<2x2x5x21xf32>, tensor<5x2x2x21xf32>) -> tensor<5x3x3x15xf32> - func.return %result : tensor<5x3x3x15xf32> -} - - -// CHECK-LABEL: @conv3d_backprop_input -func.func @conv3d_backprop_input(%filter: tensor<3x3x3x1x6xf32>, %out_backprop: tensor<2x8x8x8x6xf32>) -> tensor<2x8x8x8x1xf32> { - // CHECK: %[[REV_FILTER:.*]] = "mhlo.reverse"(%arg0) <{dimensions = dense<[0, 1, 2]> : tensor<3xi64>}> - // CHECK: %[[RESULT:.*]] = mhlo.convolution(%arg1, %[[REV_FILTER]]) - // CHECK-SAME: dim_numbers = [b, 0, 1, 2, f]x[0, 1, 2, o, i]->[b, 0, 1, 2, f] - // CHECK-SAME{LITERAL}: window = {stride = [1, 1, 1], pad = [[1, 1], [1, 1], [1, 1]], lhs_dilate = [1, 1, 1], rhs_dilate = [1, 1, 1]} - // CHECK-SAME: batch_group_count = 1 : i64, - // CHECK-SAME: feature_group_count = 1 : i64 - - // CHECK: return %[[RESULT]] - %input_sizes = "tf.Const" () {value = dense<[2, 8, 8, 8, 1]> : tensor<5xi32>} : () -> tensor<5xi32> - %result = "tf.Conv3DBackpropInputV2"(%input_sizes, %filter, %out_backprop) {data_format = "NDHWC", dilations = [1, 1, 1, 1, 1], padding = "SAME", strides = [1, 1, 1, 1, 1]} : (tensor<5xi32>, tensor<3x3x3x1x6xf32>, tensor<2x8x8x8x6xf32>) -> tensor<2x8x8x8x1xf32> - func.return %result : tensor<2x8x8x8x1xf32> -} - -// ----- - -// CHECK-LABEL: @conv2d_backprop_filter -func.func @conv2d_backprop_filter( - %input: tensor<100x28x28x1xf32>, - %out_backprop: tensor<100x26x26x32xf32> - ) -> tensor<3x3x1x32xf32> { - // CHECK: %[[RESULT:.*]] = mhlo.convolution(%arg0, %arg1) - // CHECK-SAME: dim_numbers = [f, 0, 1, b]x[i, 0, 1, o]->[0, 1, b, f] - // CHECK-SAME{LITERAL}: window = {stride = [1, 1], pad = [[0, 0], [0, 0]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]} - // CHECK-SAME: batch_group_count = 1 : i64 - // CHECK-SAME: feature_group_count = 1 : i64 - // CHECK: return %[[RESULT]] - %filter_sizes = "tf.Const" () { value = dense<[3,3,1,32]> : tensor<4xi32> } : () -> tensor<4xi32> - %result = "tf.Conv2DBackpropFilter"(%input, %filter_sizes, %out_backprop) { - data_format = "NHWC", - dilations = [1, 1, 1, 1], - explicit_paddings = [], - padding = "VALID", - strides = [1, 1, 1, 1], - use_cudnn_on_gpu = true - } : (tensor<100x28x28x1xf32>, tensor<4xi32>, tensor<100x26x26x32xf32>) -> tensor<3x3x1x32xf32> - func.return %result : tensor<3x3x1x32xf32> -} - -// ----- - -// CHECK-LABEL: @conv2d_backprop_filter_grouped -func.func @conv2d_backprop_filter_grouped( - %input: tensor<1x2x2x2xf32>, - %out_backprop: tensor<1x1x1x2xf32> - ) -> tensor<2x2x1x2xf32> { - - // CHECK: mhlo.convolution(%arg0, %arg1) - // CHECK-SAME: batch_group_count = 2 : i64 - // CHECK-SAME: feature_group_count = 1 : i64 - - %filter_sizes = "tf.Const" () { value = dense<[2, 2, 1, 2]> : tensor<4xi32> } : () -> tensor<4xi32> - %result = "tf.Conv2DBackpropFilter"(%input, %filter_sizes, %out_backprop) { - data_format = "NHWC", - dilations = [1, 1, 1, 1], - explicit_paddings = [], - padding = "VALID", - strides = [1, 1, 1, 1], - use_cudnn_on_gpu = true - } : (tensor<1x2x2x2xf32>, tensor<4xi32>, tensor<1x1x1x2xf32>) -> tensor<2x2x1x2xf32> - func.return %result : tensor<2x2x1x2xf32> -} - - -// CHECK-LABEL: @conv3d_backprop_filter -func.func @conv3d_backprop_filter(%input: tensor<2x8x8x8x1xf32>, %out_backprop: tensor<2x8x8x8x6xf32>) -> tensor<3x3x3x1x6xf32> { - // CHECK: %[[RESULT:.*]] = mhlo.convolution(%arg0, %arg1) - // CHECK-SAME: dim_numbers = [f, 0, 1, 2, b]x[i, 0, 1, 2, o]->[0, 1, 2, b, f] - // CHECK-SAME{LITERAL}: window = {stride = [1, 1, 1], pad = [[1, 1], [1, 1], [1, 1]], lhs_dilate = [1, 1, 1], rhs_dilate = [1, 1, 1]} - // CHECK-SAME: batch_group_count = 1 : i64 - // CHECK-SAME: feature_group_count = 1 : i64 - // CHECK: return %[[RESULT]] - %filter_sizes = "tf.Const"() {value = dense<[3, 3, 3, 1, 6]> : tensor<5xi32>} : () -> tensor<5xi32> - %result = "tf.Conv3DBackpropFilterV2"(%input, %filter_sizes, %out_backprop) {data_format = "NDHWC", dilations = [1, 1, 1, 1, 1], padding = "SAME", strides = [1, 1, 1, 1, 1]} : (tensor<2x8x8x8x1xf32>, tensor<5xi32>, tensor<2x8x8x8x6xf32>) -> tensor<3x3x3x1x6xf32> - func.return %result : tensor<3x3x3x1x6xf32> -} - -// ----- - -// CHECK-LABEL: @collective_permute -func.func @collective_permute(%arg0: tensor<128x32xf32>) -> tensor<128x32xf32> { - %source_target_pairs = "tf.Const" () { - value = dense<[[0, 1], [1, 2], [2, 3]]> : tensor<3x2xi32> - } : () -> tensor<3x2xi32> - - // CHECK: "mhlo.collective_permute" - // CHECK-SAME: source_target_pairs = dense<{{\[}}[0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64> - %0 = "tf.CollectivePermute"(%arg0, %source_target_pairs) { - } : (tensor<128x32xf32>, tensor<3x2xi32>) -> tensor<128x32xf32> - - func.return %0 : tensor<128x32xf32> -} - -// ----- - -// CHECK-LABEL: @cross_replica_sum -func.func @cross_replica_sum(%input: tensor<10xf32>) -> tensor<10xf32> { - %replica_groups = "tf.Const" () { - value = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi32> - } : () -> tensor<2x4xi32> - - // CHECK: mhlo.cross-replica-sum - // CHECK-SAME: replica_groups = dense<{{\[}}[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64> - %result = "tf.CrossReplicaSum" (%input, %replica_groups) : (tensor<10xf32>, tensor<2x4xi32>) -> tensor<10xf32> - func.return %result : tensor<10xf32> -} - -// ----- - -// CHECK-LABEL: conv_dynamic -func.func @conv_dynamic(%arg0: tensor, %arg1: tensor<3x3x3x16xf32>) -> tensor { - // CHECK: "mhlo.dynamic_conv" - // CHECK-SAME: <{batch_group_count = 1 : i64, dimension_numbers = #mhlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>, feature_group_count = 2 : i64, precision_config = [#mhlo, #mhlo], rhs_dilation = dense<[2, 3]> : tensor<2xi64>, window_strides = dense<[4, 5]> : tensor<2xi64>}> : (tensor, tensor<3x3x3x16xf32>, tensor<4xi32>) -> tensor - %0 = "tf.Conv2D"(%arg0, %arg1) {data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 5, 1]} : (tensor, tensor<3x3x3x16xf32>) -> tensor - func.return %0 : tensor -} - -//===----------------------------------------------------------------------===// -// tf.Split legalization -//===----------------------------------------------------------------------===// - -// ----- - -// CHECK-LABEL: @split_not_match_dynamic_split_dim_input -func.func @split_not_match_dynamic_split_dim_input(%input: tensor<4x4xf32>, %split_dim: tensor) -> (tensor<*xf32>, tensor<*xf32>) { - // CHECK: tf.Split - %0:2 = "tf.Split"(%split_dim, %input) : (tensor, tensor<4x4xf32>) -> (tensor<*xf32>, tensor<*xf32>) - func.return %0#0, %0#1 : tensor<*xf32>, tensor<*xf32> -} - -// ----- - -// CHECK-LABEL: @split_not_match_dynamic_input_shape -func.func @split_not_match_dynamic_input_shape(%input: tensor<4x?x4xf32>) -> (tensor<4x?x4xf32>, tensor<4x?x4xf32>) { - %cst = "tf.Const"() {value = dense<1> : tensor} : () -> tensor - // CHECK: tensor.dim {{.*}} : tensor<4x?x4xf32> - // CHECK: arith.divsi {{.*}} : index - // CHECK: tensor.from_elements {{.*}} : tensor<3xindex> - // CHECK: mhlo.real_dynamic_slice {{.*}} : (tensor<4x?x4xf32>, tensor<3xindex>, tensor<3xindex>, tensor<3xindex>) -> tensor<4x?x4xf32> - // CHECK: muli {{.*}} : index - // CHECK: muli {{.*}} : index - // CHECK: tensor.from_elements {{.*}} : tensor<3xindex> - // CHECK: tensor.from_elements {{.*}} : tensor<3xindex> - // CHECK: tensor.from_elements {{.*}} : tensor<3xindex> - // CHECK: mhlo.real_dynamic_slice {{.*}} : (tensor<4x?x4xf32>, tensor<3xindex>, tensor<3xindex>, tensor<3xindex>) -> tensor<4x?x4xf32> - %0:2 = "tf.Split"(%cst, %input) : (tensor, tensor<4x?x4xf32>) -> (tensor<4x?x4xf32>, tensor<4x?x4xf32>) - func.return %0#0, %0#1 : tensor<4x?x4xf32>, tensor<4x?x4xf32> -} - -// ----- - -// CHECK-LABEL: @split_not_match_static_split_dim_size -func.func @split_not_match_static_split_dim_size(%input: tensor<4x?x4xf32>) -> (tensor<2x?x4xf32>, tensor<2x?x4xf32>) { - %cst = "tf.Const"() {value = dense<0> : tensor} : () -> tensor - // CHECK: tensor.dim {{.*}} : tensor<4x?x4xf32> - // CHECK: arith.divsi {{.*}} : index - // CHECK: tensor.from_elements {{.*}} : tensor<3xindex> - // CHECK: mhlo.real_dynamic_slice {{.*}} : (tensor<4x?x4xf32>, tensor<3xindex>, tensor<3xindex>, tensor<3xindex>) -> tensor<2x?x4xf32> - // CHECK: muli {{.*}} : index - // CHECK: muli {{.*}} : index - // CHECK: tensor.from_elements {{.*}} : tensor<3xindex> - // CHECK: tensor.from_elements {{.*}} : tensor<3xindex> - // CHECK: tensor.from_elements {{.*}} : tensor<3xindex> - // CHECK: mhlo.real_dynamic_slice {{.*}} : (tensor<4x?x4xf32>, tensor<3xindex>, tensor<3xindex>, tensor<3xindex>) -> tensor<2x?x4xf32> - %0:2 = "tf.Split"(%cst, %input) : (tensor, tensor<4x?x4xf32>) -> (tensor<2x?x4xf32>, tensor<2x?x4xf32>) - func.return %0#0, %0#1 : tensor<2x?x4xf32>, tensor<2x?x4xf32> -} - -// ----- - -// CHECK-LABEL: @split_match_and_split_into_two -func.func @split_match_and_split_into_two(%input: tensor<4x6xf32>) -> (tensor<2x6xf32>, tensor<2x6xf32>) { - %cst = "tf.Const"() {value = dense<0> : tensor} : () -> tensor - // CHECK: %[[ONE:.*]] = "mhlo.slice"(%{{.*}}) <{limit_indices = dense<[2, 6]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}> : (tensor<4x6xf32>) -> tensor<2x6xf32> - // CHECK: %[[TWO:.*]] = "mhlo.slice"(%{{.*}}) <{limit_indices = dense<[4, 6]> : tensor<2xi64>, start_indices = dense<[2, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}> : (tensor<4x6xf32>) -> tensor<2x6xf32> - %0:2 = "tf.Split"(%cst, %input) : (tensor, tensor<4x6xf32>) -> (tensor<2x6xf32>, tensor<2x6xf32>) - // CHECK: return %[[ONE]], %[[TWO]] - func.return %0#0, %0#1 : tensor<2x6xf32>, tensor<2x6xf32> -} - -// ----- - -// CHECK-LABEL: @split_match_and_split_into_three -// CHECK-SAME: (%[[ARG:.*]]: tensor<4x6xf32>) -func.func @split_match_and_split_into_three(%input: tensor<4x6xf32>) -> (tensor<4x2xf32>, tensor<4x2xf32>, tensor<4x2xf32>) { - %cst = "tf.Const"() {value = dense<1> : tensor} : () -> tensor - // CHECK: %[[ONE:.*]] = "mhlo.slice"(%[[ARG]]) <{limit_indices = dense<[4, 2]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}> : (tensor<4x6xf32>) -> tensor<4x2xf32> - // CHECK: %[[TWO:.*]] = "mhlo.slice"(%[[ARG]]) <{limit_indices = dense<4> : tensor<2xi64>, start_indices = dense<[0, 2]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}> : (tensor<4x6xf32>) -> tensor<4x2xf32> - // CHECK: %[[THREE:.*]] = "mhlo.slice"(%[[ARG]]) <{limit_indices = dense<[4, 6]> : tensor<2xi64>, start_indices = dense<[0, 4]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}> : (tensor<4x6xf32>) -> tensor<4x2xf32> - %0:3 = "tf.Split"(%cst, %input) : (tensor, tensor<4x6xf32>) -> (tensor<4x2xf32>, tensor<4x2xf32>, tensor<4x2xf32>) - // CHECK: return %[[ONE]], %[[TWO]], %[[THREE]] - func.return %0#0, %0#1, %0#2 : tensor<4x2xf32>, tensor<4x2xf32>, tensor<4x2xf32> -} - -//===----------------------------------------------------------------------===// -// tf.TopKV2 legalization -//===----------------------------------------------------------------------===// - -// ----- - -// CHECK-LABEL: topk_v2_non_const_k -func.func @topk_v2_non_const_k(%input: tensor<16xf32>, %k: tensor) -> (tensor, tensor) { - // CHECK: tf.TopKV2 - %0:2 = "tf.TopKV2"(%input, %k): (tensor<16xf32>, tensor) -> (tensor, tensor) - func.return %0#0, %0#1: tensor, tensor -} - -// ----- - -// CHECK-LABEL: topk_v2_unknown_input_last_dim -func.func @topk_v2_unknown_input_last_dim(%input: tensor<16x?xf32>) -> (tensor<16x?xf32>, tensor<16x?xi32>) { - %k = "tf.Const"() {value = dense<8> : tensor} : () -> tensor - // CHECK: tf.TopKV2 - %0:2 = "tf.TopKV2"(%input, %k): (tensor<16x?xf32>, tensor) -> (tensor<16x?xf32>, tensor<16x?xi32>) - func.return %0#0, %0#1: tensor<16x?xf32>, tensor<16x?xi32> -} - -// ----- - -// CHECK-LABEL: topk_v2 -// CHECK-SAME: %[[INPUT:.*]]: tensor<16x16xf32> -func.func @topk_v2(%input: tensor<16x16xf32>) -> (tensor<16x8xf32>, tensor<16x8xi32>) { - %k = "tf.Const"() {value = dense<8> : tensor} : () -> tensor - - // CHECK: chlo.top_k(%[[INPUT]], k = 8) - %0:2 = "tf.TopKV2"(%input, %k): (tensor<16x16xf32>, tensor) -> (tensor<16x8xf32>, tensor<16x8xi32>) - func.return %0#0, %0#1: tensor<16x8xf32>, tensor<16x8xi32> -} - -//===----------------------------------------------------------------------===// -// tf.SplitV legalization -//===----------------------------------------------------------------------===// - -// ----- - -// CHECK-LABEL: @splitv_match_and_split_into_three -// CHECK-SAME: (%[[ARG:.*]]: tensor<4x6xf32>) -func.func @splitv_match_and_split_into_three(%input: tensor<4x6xf32>) -> (tensor<4x1xf32>, tensor<4x2xf32>, tensor<4x3xf32>) { - %split_sizes = "tf.Const"() {value = dense<[1, 2, 3]> : tensor<3xi32>} : () -> tensor<3xi32> - %split_dim = "tf.Const"() {value = dense<1> : tensor} : () -> tensor - // CHECK: %[[ONE:.*]] = "mhlo.slice"(%[[ARG]]) <{limit_indices = dense<[4, 1]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}> : (tensor<4x6xf32>) -> tensor<4x1xf32> - // CHECK: %[[TWO:.*]] = "mhlo.slice"(%[[ARG]]) <{limit_indices = dense<[4, 3]> : tensor<2xi64>, start_indices = dense<[0, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}> : (tensor<4x6xf32>) -> tensor<4x2xf32> - // CHECK: %[[THREE:.*]] = "mhlo.slice"(%[[ARG]]) <{limit_indices = dense<[4, 6]> : tensor<2xi64>, start_indices = dense<[0, 3]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}> : (tensor<4x6xf32>) -> tensor<4x3xf32> - %0:3 = "tf.SplitV"(%input, %split_sizes, %split_dim) : (tensor<4x6xf32>, tensor<3xi32>, tensor) -> (tensor<4x1xf32>, tensor<4x2xf32>, tensor<4x3xf32>) - // CHECK: return %[[ONE]], %[[TWO]], %[[THREE]] - func.return %0#0, %0#1, %0#2 : tensor<4x1xf32>, tensor<4x2xf32>, tensor<4x3xf32> -} - -// ----- - -// CHECK-LABEL: @splitv_dynamic_dim_in_split_sizes -func.func @splitv_dynamic_dim_in_split_sizes(%input: tensor<4x6xf32>) -> (tensor<4x1xf32>, tensor<4x2xf32>, tensor<4x3xf32>) { - %split_sizes = "tf.Const"() {value = dense<[1, -1, 3]> : tensor<3xi32>} : () -> tensor<3xi32> - %split_dim = "tf.Const"() {value = dense<1> : tensor} : () -> tensor - // CHECK: limit_indices = dense<[4, 1]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64> - // CHECK: limit_indices = dense<[4, 3]> : tensor<2xi64>, start_indices = dense<[0, 1]> : tensor<2xi64> - // CHECK: limit_indices = dense<[4, 6]> : tensor<2xi64>, start_indices = dense<[0, 3]> : tensor<2xi64> - %0:3 = "tf.SplitV"(%input, %split_sizes, %split_dim) : (tensor<4x6xf32>, tensor<3xi32>, tensor) -> (tensor<4x1xf32>, tensor<4x2xf32>, tensor<4x3xf32>) - func.return %0#0, %0#1, %0#2 : tensor<4x1xf32>, tensor<4x2xf32>, tensor<4x3xf32> -} - -// ----- - -// CHECK-LABEL: @splitv_dynamic -func.func @splitv_dynamic(%input: tensor) -> (tensor, tensor, tensor) { - %split_sizes = "tf.Const"() {value = dense<[1, 2, 3]> : tensor<3xi32>} : () -> tensor<3xi32> - %split_dim = "tf.Const"() {value = dense<1> : tensor} : () -> tensor - // CHECK: tf.SplitV - %0:3 = "tf.SplitV"(%input, %split_sizes, %split_dim) : (tensor, tensor<3xi32>, tensor) -> (tensor, tensor, tensor) - func.return %0#0, %0#1, %0#2 : tensor, tensor, tensor -} - -//===----------------------------------------------------------------------===// -// tf.Assert legalization -//===----------------------------------------------------------------------===// - -// ----- - -// CHECK-LABEL: @assert -func.func @assert(%arg0: tensor, %arg1: tensor<*xf32>) { - // CHECK-NOT: tf.Assert - "tf.Assert"(%arg0, %arg1) {summarize = 1} : (tensor, tensor<*xf32>) -> () - func.return -} - -//===----------------------------------------------------------------------===// -// tf.Unpack legalization -//===----------------------------------------------------------------------===// - -// ----- - -// CHECK-LABEL: @unpack -func.func @unpack(%input: tensor<4x3x6xf32>) -> (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) { - // CHECK: %[[SLICE1:.*]] = "mhlo.slice"(%{{.*}}) <{limit_indices = dense<[4, 1, 6]> : tensor<3xi64>, start_indices = dense<0> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>}> : (tensor<4x3x6xf32>) -> tensor<4x1x6xf32> - // CHECK: %[[RES1:.*]] = mhlo.reshape %[[SLICE1]] : (tensor<4x1x6xf32>) -> tensor<4x6xf32> - // CHECK: %[[SLICE2:.*]] = "mhlo.slice"(%{{.*}}) <{limit_indices = dense<[4, 2, 6]> : tensor<3xi64>, start_indices = dense<[0, 1, 0]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>}> : (tensor<4x3x6xf32>) -> tensor<4x1x6xf32> - // CHECK: %[[RES2:.*]] = mhlo.reshape %[[SLICE2]] : (tensor<4x1x6xf32>) -> tensor<4x6xf32> - // CHECK: %[[SLICE3:.*]] = "mhlo.slice"(%{{.*}}) <{limit_indices = dense<[4, 3, 6]> : tensor<3xi64>, start_indices = dense<[0, 2, 0]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>}> : (tensor<4x3x6xf32>) -> tensor<4x1x6xf32> - // CHECK: %[[RES3:.*]] = mhlo.reshape %[[SLICE3]] : (tensor<4x1x6xf32>) -> tensor<4x6xf32> - - %0:3 = "tf.Unpack"(%input) {axis = 1} : (tensor<4x3x6xf32>) -> (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) - // return %[[RES1]], %[[RES2]], %[[RES3]] - func.return %0#0, %0#1, %0#2 : tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32> -} - -// ----- - -// CHECK-LABEL: func @unpack_dynamic -func.func @unpack_dynamic(%arg0: tensor) -> (tensor, tensor) { - // CHECK: mhlo.real_dynamic_slice {{.*}} : (tensor, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor - // CHECK: tensor.from_elements {{.*}} : tensor<2xi32> - // CHECK: mhlo.dynamic_reshape {{.*}} : (tensor, tensor<2xi32>) -> tensor - // CHECK: tensor.from_elements {{.*}} : tensor<3xi32> - // CHECK: mhlo.real_dynamic_slice {{.*}} : (tensor, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor - // CHECK: tensor.from_elements {{.*}} : tensor<2xi32> - // CHECK: mhlo.dynamic_reshape {{.*}} : (tensor, tensor<2xi32>) -> tensor - // CHECK: return {{.*}} : tensor, tensor - %0:2 = "tf.Unpack"(%arg0) {axis = -1 : i64} : (tensor) -> (tensor, tensor) - func.return %0#0, %0#1 : tensor, tensor -} - -//===----------------------------------------------------------------------===// -// tf.UnsortedSegment{Max|Min|Prod|Sum} legalization -//===----------------------------------------------------------------------===// - -// ----- - -// CHECK-LABEL: @unsorted_segment_sum -// CHECK-SAME: [[DATA:%.*]]: tensor<8x16x64xf32> -// CHECK-SAME: [[SI:%.*]]: tensor<8x16xi32> -func.func @unsorted_segment_sum(%data: tensor<8x16x64xf32>, %segment_ids : tensor<8x16xi32>) -> (tensor<4x64xf32>) { - %num_segments = "tf.Const"() {value = dense<4> : tensor} : () -> tensor - // CHECK: [[ZERO:%.*]] = mhlo.constant dense<0.000000e+00> : tensor - // CHECK: [[INIT:%.*]] = "mhlo.broadcast"([[ZERO]]) <{broadcast_sizes = dense<[4, 64]> : tensor<2xi64>}> : (tensor) -> tensor<4x64xf32> - // CHECK: [[SCATTER:%.*]] = "mhlo.scatter"([[INIT]], [[SI]], [[DATA]]) - // CHECK-SAME: indices_are_sorted = false, - // CHECK-SAME: scatter_dimension_numbers = - // CHECK-SAME: update_window_dims = [2] - // CHECK-SAME: inserted_window_dims = [0] - // CHECK-SAME: scatter_dims_to_operand_dims = [0] - // CHECK-SAME: index_vector_dim = 2 - // CHECK-SAME: unique_indices = false - // CHECK: ^{{.*}}([[LHS:%.*]]: tensor, [[RHS:%.*]]: tensor): - // CHECK: [[ADD:%.*]] = mhlo.add [[LHS]], [[RHS]] : tensor - // CHECK: mhlo.return [[ADD]] - // CHECK-NEXT: (tensor<4x64xf32>, tensor<8x16xi32>, tensor<8x16x64xf32>) -> tensor<4x64xf32> - // CHECK: return [[SCATTER]] - %0 = "tf.UnsortedSegmentSum"(%data, %segment_ids, %num_segments) : (tensor<8x16x64xf32>, tensor<8x16xi32>, tensor) -> (tensor<4x64xf32>) - func.return %0: tensor<4x64xf32> -} - -// ----- - -// CHECK-LABEL: @unsorted_segment_prod -// CHECK-SAME: [[DATA:%.*]]: tensor<8x?x64xf32> -// CHECK-SAME: [[SI:%.*]]: tensor -func.func @unsorted_segment_prod(%data: tensor<8x?x64xf32>, %segment_ids : tensor) -> (tensor<4x?xf32>) { - %num_segments = "tf.Const"() {value = dense<4> : tensor} : () -> tensor - // CHECK: [[ONE:%.*]] = mhlo.constant dense<1.000000e+00> : tensor - // CHECK: [[INIT:%.*]] = "mhlo.broadcast"([[ONE]]) <{broadcast_sizes = dense<[4, 64]> : tensor<2xi64>}> : (tensor) -> tensor<4x64xf32> - // CHECK: [[SCATTER:%.*]] = "mhlo.scatter"([[INIT]], [[SI]], [[DATA]]) - // CHECK-SAME: indices_are_sorted = false - // CHECK-SAME: scatter_dimension_numbers = - // CHECK-SAME: update_window_dims = [2] - // CHECK-SAME: inserted_window_dims = [0] - // CHECK-SAME: scatter_dims_to_operand_dims = [0] - // CHECK-SAME: index_vector_dim = 2 - // CHECK-SAME: unique_indices = false - // CHECK: ^{{.*}}([[LHS:%.*]]: tensor, [[RHS:%.*]]: tensor): - // CHECK: [[MUL:%.*]] = mhlo.multiply [[LHS]], [[RHS]] : tensor - // CHECK: mhlo.return [[MUL]] - // CHECK-NEXT: (tensor<4x64xf32>, tensor, tensor<8x?x64xf32>) -> tensor<4x?xf32> - // CHECK: return [[SCATTER]] - %0 = "tf.UnsortedSegmentProd"(%data, %segment_ids, %num_segments) : (tensor<8x?x64xf32>, tensor, tensor) -> (tensor<4x?xf32>) - func.return %0: tensor<4x?xf32> -} - -// ----- - -// CHECK-LABEL: @unsorted_segment_min -func.func @unsorted_segment_min(%data: tensor<8x?x64xf32>, %segment_ids : tensor) -> (tensor<4x?xf32>) { - %num_segments = "tf.Const"() {value = dense<4> : tensor} : () -> tensor - // CHECK: mhlo.constant dense<3.40282347E+38> : tensor - // CHECK: mhlo.scatter - // CHECK: mhlo.minimum - %0 = "tf.UnsortedSegmentMin"(%data, %segment_ids, %num_segments) : (tensor<8x?x64xf32>, tensor, tensor) -> (tensor<4x?xf32>) - func.return %0: tensor<4x?xf32> -} - -// ----- - -// CHECK-LABEL: @unsorted_segment_max -func.func @unsorted_segment_max(%data: tensor<8x?x64xf32>, %segment_ids : tensor) -> (tensor<4x?xf32>) { - %num_segments = "tf.Const"() {value = dense<4> : tensor} : () -> tensor - // CHECK: mhlo.constant dense<-3.40282347E+38> : tensor - // CHECK: mhlo.scatter - // CHECK: mhlo.maximum - %0 = "tf.UnsortedSegmentMax"(%data, %segment_ids, %num_segments) : (tensor<8x?x64xf32>, tensor, tensor) -> (tensor<4x?xf32>) - func.return %0: tensor<4x?xf32> -} - -//===----------------------------------------------------------------------===// -// tf.GatherNd legalization -//===----------------------------------------------------------------------===// -// CHECK-LABEL: func @gatherNd_dynamic -func.func @gatherNd_dynamic(%arg0: tensor, %arg1: tensor) -> tensor { - // CHECK: tensor.dim - // CHECK: index_cast - // CHECK: tensor.from_elements - // CHECK: mhlo.dynamic_gather - // CHECK-SAME: dimension_numbers = - // CHECK-SAME: offset_dims = [2] - // CHECK-SAME: collapsed_slice_dims = [0, 1] - // CHECK-SAME: start_index_map = [0, 1] - // CHECK-SAME: index_vector_dim = 2 - // CHECK-SAME: indices_are_sorted = false - %0 = "tf.GatherNd"(%arg0, %arg1) {Tindices = i32, Tparams = i32, device = ""} : (tensor, tensor) -> tensor - func.return %0 : tensor -} - -// ----- - -// CHECK-LABEL: func @gatherNd_static -func.func @gatherNd_static(%arg0: tensor<2x4x128xf32>, %arg1: tensor<2x1xi32>) -> tensor<2x4x128xf32> { - // CHECK: "mhlo.gather"({{.*}}) <{ - // CHECK-SAME: dimension_numbers = - // CHECK-SAME: offset_dims = [1, 2] - // CHECK-SAME: collapsed_slice_dims = [0] - // CHECK-SAME: start_index_map = [0] - // CHECK-SAME: index_vector_dim = 1 - // CHECK-SAME: indices_are_sorted = false - // CHECK-SAME: slice_sizes = dense<[1, 4, 128]> - // CHECK-SAME: (tensor<2x4x128xf32>, tensor<2x1xi32>) -> tensor<2x4x128xf32> - %0 = "tf.GatherNd"(%arg0, %arg1) {Tindices = i32, Tparams = i32, device = ""} : (tensor<2x4x128xf32>, tensor<2x1xi32>) -> tensor<2x4x128xf32> - func.return %0 : tensor<2x4x128xf32> -} - -//===----------------------------------------------------------------------===// -// tf.GatherV2 legalization -//===----------------------------------------------------------------------===// - -// ----- - -// CHECK-LABEL: @gather_v2 -// CHECK-SAME: %[[PARAMS:[a-zA-Z0-9_]+]] -// CHECK-SAME: %[[INDICES:[a-zA-Z0-9_]+]] -func.func @gather_v2(%params: tensor<16x2x3xf32>, %indices: tensor<16x5xi32>) -> tensor<16x2x5xf32> { - // CHECK: mhlo.torch_index_select - // CHECK-SAME: %[[PARAMS]], %[[INDICES]] - // CHECK-SAME: batch_dims = 1 - // CHECK-SAME: dim = 2 - %axis = "tf.Const"() { value = dense<[-1]> : tensor<1xi32> } : () -> tensor<1xi32> - %1 = "tf.GatherV2"(%params, %indices, %axis) {batch_dims = -1 : i64} : (tensor<16x2x3xf32>, tensor<16x5xi32>, tensor<1xi32>) -> tensor<16x2x5xf32> - func.return %1 : tensor<16x2x5xf32> -} - -// ----- - -// CHECK-LABEL: @gather_v2_dynamic -// CHECK-SAME: %[[PARAMS:[a-zA-Z0-9_]+]] -// CHECK-SAME: %[[INDICES:[a-zA-Z0-9_]+]] -func.func @gather_v2_dynamic(%params: tensor, %indices: tensor) -> tensor { - // CHECK: mhlo.torch_index_select - // CHECK-SAME: %[[PARAMS]], %[[INDICES]] - // CHECK-SAME: batch_dims = 1 - // CHECK-SAME: dim = 2 - %axis = "tf.Const"() { value = dense<[-1]> : tensor<1xi32> } : () -> tensor<1xi32> - %1 = "tf.GatherV2"(%params, %indices, %axis) {batch_dims = -1 : i64} : (tensor, tensor, tensor<1xi32>) -> tensor - func.return %1 : tensor -} - -// ----- - -// CHECK-LABEL: @gather_v2_dynamic_index_i64 -// CHECK-SAME: %[[PARAMS:[a-zA-Z0-9_]+]] -// CHECK-SAME: %[[INDICES:[a-zA-Z0-9_]+]] -func.func @gather_v2_dynamic_index_i64(%params: tensor, %indices: tensor) -> tensor { - // CHECK: mhlo.torch_index_select - // CHECK-SAME: %[[PARAMS]], %[[INDICES]] - // CHECK-SAME: batch_dims = 1 - // CHECK-SAME: dim = 2 - %axis = "tf.Const"() { value = dense<[-1]> : tensor<1xi32> } : () -> tensor<1xi32> - %1 = "tf.GatherV2"(%params, %indices, %axis) {batch_dims = -1 : i64} : (tensor, tensor, tensor<1xi32>) -> tensor - func.return %1 : tensor -} - -// ----- - -// CHECK-LABEL: @gather_v2_dynamic_shape -// CHECK-SAME: %[[PARAMS:[a-zA-Z0-9_]+]] -// CHECK-SAME: %[[INDICES:[a-zA-Z0-9_]+]] -func.func @gather_v2_dynamic_shape(%params: tensor, %indices: tensor) -> tensor { - // CHECK: mhlo.torch_index_select - // CHECK-SAME: %[[PARAMS]], %[[INDICES]] - // CHECK-SAME: batch_dims = 1 - // CHECK-SAME: dim = 2 - %axis = "tf.Const"() { value = dense<[-1]> : tensor<1xi32> } : () -> tensor<1xi32> - %1 = "tf.GatherV2"(%params, %indices, %axis) {batch_dims = -1 : i64} : (tensor, tensor, tensor<1xi32>) -> tensor - func.return %1 : tensor -} - -//===----------------------------------------------------------------------===// -// tf.StridedSliceGrad legalization -//===----------------------------------------------------------------------===// - -// ----- - -// CHECK-LABEL: strided_slice_grad -// CHECK-SAME: [[GRAD:%.*]]: tensor<4x16x1022xf32> -func.func @strided_slice_grad(%grad: tensor<4x16x1022xf32>) -> tensor<4x128x1024xf32> { - - // For StridedSlice - // Dim #: 0, 1, 2 - // Input shape: [4, 128, 1024] - // Begin: 1, 4, -3 - // End: 8, 65, 42 - // Stride: 1, 4, -1 - // Begin mask: 1, 0, 0 (= 1) - // End mask: 0, 0, 1 (= 4) - - // So result shape: - // Dim #0: begin mask (1) -> begin = 0; end 8 canonicalized to 4: so 4 - // Dim #1: 4 to 65 stride 4: so 16 - // Dim #2: begin -3 + 1024 = 1021; end mask (1) -> end = -1: so 1022 - // result shape: [4, 16, 1022] - - // To pad back: - // Dim #: 0, 1, 2 - // Pad low: 0, 4, 0 - // Pad interm: 0, 3, 0 - // Pad high: 0, 63, 2 - - %shape = "tf.Const"() {value = dense<[4, 128, 1024]> : tensor<3xi32>} : () -> (tensor<3xi32>) - %begin = "tf.Const"() {value = dense<[1, 4, -3]> : tensor<3xi32>} : () -> (tensor<3xi32>) - %end = "tf.Const"() {value = dense<[8, 65, 42]> : tensor<3xi32>} : () -> (tensor<3xi32>) - %strides = "tf.Const"() {value = dense<[1, 4, -1]> : tensor<3xi32>} : () -> (tensor<3xi32>) - - // CHECK: [[RESHAPE:%.*]] = mhlo.reshape %arg0 : (tensor<4x16x1022xf32>) -> tensor<4x16x1022xf32> - // CHECK: [[REVERSE:%.*]] = "mhlo.reverse"([[RESHAPE]]) <{dimensions = dense<2> : tensor<1xi64>}> : (tensor<4x16x1022xf32>) -> tensor<4x16x1022xf32> - // CHECK: [[ZERO:%.*]] = mhlo.constant dense<0.000000e+00> : tensor - // CHECK: [[PAD:%.*]] = "mhlo.pad"([[REVERSE]], [[ZERO]]) <{edge_padding_high = dense<[0, 63, 2]> : tensor<3xi64>, edge_padding_low = dense<[0, 4, 0]> : tensor<3xi64>, interior_padding = dense<[0, 3, 0]> : tensor<3xi64>}> : (tensor<4x16x1022xf32>, tensor) -> tensor<4x128x1024xf32> - - %0 = "tf.StridedSliceGrad"(%shape, %begin, %end, %strides, %grad) {begin_mask = 1, end_mask = 4} : (tensor<3xi32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>, tensor<4x16x1022xf32>) -> tensor<4x128x1024xf32> - // CHECK: return [[PAD]] - func.return %0: tensor<4x128x1024xf32> -} - -// ----- - -// CHECK-LABEL: strided_slice_grad_shrink_axis_mask -// CHECK-SAME: [[GRAD:%.*]]: tensor<8xf32> -func.func @strided_slice_grad_shrink_axis_mask(%grad: tensor<8xf32>) -> tensor<4x8xf32> { - // Input to StridedSlice was of shape 4x8xf32 - // Strided slice gets input[2:3, 0:8] - // shrink_axis_mask is 1 denoting that dim#0 is shrunk. So the output is 8xf32 - // which is the shape of gradient. - // StridedSliceGrad would reshape the gradient to 1x8xf32 and - // then pad to match the shape of input 4x8xf32. - - %shape = "tf.Const"() {value = dense<[4, 8]> : tensor<2xi32>} : () -> (tensor<2xi32>) - %begin = "tf.Const"() {value = dense<[2, 0]> : tensor<2xi32>} : () -> (tensor<2xi32>) - %end = "tf.Const"() {value = dense<[3, 8]> : tensor<2xi32>} : () -> (tensor<2xi32>) - %strides = "tf.Const"() {value = dense<1> : tensor<2xi32>} : () -> (tensor<2xi32>) - - // CHECK: [[RESHAPE:%.*]] = mhlo.reshape [[GRAD]] : (tensor<8xf32>) -> tensor<1x8xf32> - // CHECK: [[ZEROS:%.*]] = mhlo.constant dense<0.000000e+00> : tensor - // CHECK: [[PAD:%.*]] = "mhlo.pad"([[RESHAPE]], [[ZEROS]]) - // CHECK-DAG-SAME: edge_padding_low = dense<[2, 0]> : tensor<2xi64> - // CHECK-DAG-SAME: edge_padding_high = dense<[1, 0]> : tensor<2xi64> - // CHECK-DAG-SAME: interior_padding = dense<0> : tensor<2xi64> - %0 = "tf.StridedSliceGrad"(%shape, %begin, %end, %strides, %grad) {begin_mask = 0, end_mask = 0, shrink_axis_mask = 1} : (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<8xf32>) -> tensor<4x8xf32> - - // CHECK: return [[PAD]] : tensor<4x8xf32> - func.return %0 : tensor<4x8xf32> -} - -// ----- - -// CHECK-LABEL: strided_slice_grad_new_axis_mask -// CHECK-SAME: [[GRAD:%.*]]: tensor<1x2xf32> -func.func @strided_slice_grad_new_axis_mask(%grad: tensor<1x2xf32>) -> tensor<8xf32> { - // Input to StridedSlice was of shape 8xf32 - // Strided slice gets input[tf.new_axis, 2:4] - // new_axis_mask is 1 denoting new axis is inserted at dim#0. So the output is - // 1x2xf32 which is the shape of gradient. - // StridedSliceGrad would reshape the gradient to 2xf32 and - // then pad to match the shape of input 4x8xf32. - - %shape = "tf.Const"() {value = dense<[8]> : tensor<1xi32>} : () -> (tensor<1xi32>) - %begin = "tf.Const"() {value = dense<[0, 2]> : tensor<2xi32>} : () -> (tensor<2xi32>) - %end = "tf.Const"() {value = dense<[0, 4]> : tensor<2xi32>} : () -> (tensor<2xi32>) - %strides = "tf.Const"() {value = dense<1> : tensor<2xi32>} : () -> (tensor<2xi32>) - - // CHECK: [[RESHAPE:%.*]] = mhlo.reshape [[GRAD]] : (tensor<1x2xf32>) -> tensor<2xf32> - // CHECK: [[ZEROS:%.*]] = mhlo.constant dense<0.000000e+00> : tensor - // CHECK: [[PAD:%.*]] = "mhlo.pad"([[RESHAPE]], [[ZEROS]]) - // CHECK-DAG-SAME: edge_padding_low = dense<2> : tensor<1xi64> - // CHECK-DAG-SAME: edge_padding_high = dense<4> : tensor<1xi64> - // CHECK-DAG-SAME: interior_padding = dense<0> : tensor<1xi64> - %0 = "tf.StridedSliceGrad"(%shape, %begin, %end, %strides, %grad) {begin_mask = 0, end_mask = 0, new_axis_mask = 1} : (tensor<1xi32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<1x2xf32>) -> tensor<8xf32> - - // CHECK: return [[PAD]] : tensor<8xf32> - func.return %0 : tensor<8xf32> -} - -// ----- - -// CHECK-LABEL: strided_slice_grad_ellipsis_mask -// CHECK-SAME: [[GRAD:%.*]]: tensor<2x4x8xf32> -func.func @strided_slice_grad_ellipsis_mask(%grad: tensor<2x4x8xf32>) -> tensor<4x4x8xf32> { - // Input to StridedSlice was of shape 4x4x8xf32 - // Strided slice gets input[2:4, ...] - // ellipsis_mask is 2 denoting that slice contains all elements in dim#1 and - // dim#2, ignoring begin and end indices for these dimensions. So the output - // is 2x4x8xf32 which is the shape of gradient. - // StridedSliceGrad would pad the gradient to match the shape of - // input 4x4x8xf32. - - %shape = "tf.Const"() {value = dense<[4, 4, 8]> : tensor<3xi32>} : () -> (tensor<3xi32>) - %begin = "tf.Const"() {value = dense<[2, 3]> : tensor<2xi32>} : () -> (tensor<2xi32>) - %end = "tf.Const"() {value = dense<[4, 5]> : tensor<2xi32>} : () -> (tensor<2xi32>) - %strides = "tf.Const"() {value = dense<1> : tensor<2xi32>} : () -> (tensor<2xi32>) - - // CHECK: [[RESHAPE:%.*]] = mhlo.reshape [[GRAD]] : (tensor<2x4x8xf32>) -> tensor<2x4x8xf32> - // CHECK: [[ZEROS:%.*]] = mhlo.constant dense<0.000000e+00> : tensor - // CHECK: [[PAD:%.*]] = "mhlo.pad"([[RESHAPE]], [[ZEROS]]) - // CHECK-DAG-SAME: edge_padding_low = dense<[2, 0, 0]> : tensor<3xi64> - // CHECK-DAG-SAME: edge_padding_high = dense<0> : tensor<3xi64> - // CHECK-DAG-SAME: interior_padding = dense<0> : tensor<3xi64> - %0 = "tf.StridedSliceGrad"(%shape, %begin, %end, %strides, %grad) {begin_mask = 0, end_mask = 0, ellipsis_mask = 2} : (tensor<3xi32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<2x4x8xf32>) -> tensor<4x4x8xf32> - - // CHECK: return [[PAD]] : tensor<4x4x8xf32> - func.return %0 : tensor<4x4x8xf32> -} - - -// CHECK-LABEL: strided_slice_grad_all_masks -// CHECK-SAME: [[GRAD:%.*]]: tensor<1x4x8x8x10x2x1xf32> -func.func @strided_slice_grad_all_masks(%grad: tensor<1x4x8x8x10x2x1xf32>) -> tensor<2x4x8x16x32x64xf32> { - // For StridedSlice input[1, tf.new_axis, ..., 8:, :10, 2:6:2, tf.new_axis] - // New axis mask is at index 1 and 6 of sparse spec, so - // new_axis_mask = 2^1 + 2^6 = 66 - // The ellipsis mask is applied to dim #1, #2 of input i.e, we get - // canonicalized slice input[1, :, :, 8:, :10, 2:6:2] - // The StridedSliceGrad op would propogate the gradient for the sliced tensor - // to the original input tensor by padding with zeroes. - - %shape = "tf.Const"() {value = dense<[2, 4, 8, 16, 32, 64]> : tensor<6xi32>} : () -> (tensor<6xi32>) - %begin = "tf.Const"() {value = dense<[1, 0, 0, 8, 1, 2, 0]> : tensor<7xi32>} : () -> (tensor<7xi32>) - %end = "tf.Const"() {value = dense<[2, 0, 0, 10, 10, 6, 0]> : tensor<7xi32>} : () -> (tensor<7xi32>) - %strides = "tf.Const"() {value = dense<[1, 1, 1, 1, 1, 2, 1]> : tensor<7xi32>} : () -> (tensor<7xi32>) - - // Remove 2 new axes (at index 1 and 6) and 1 shrink axis (at index 0) - // CHECK: [[RESHAPE:%.*]] = mhlo.reshape [[GRAD]] : (tensor<1x4x8x8x10x2x1xf32>) -> tensor<1x4x8x8x10x2xf32> - // CHECK: [[ZERO:%.*]] = mhlo.constant dense<0.000000e+00> : tensor - // The edge_padding_low, edge_padding_high and interior_padding attributes of - // mhlo.pad would reflect the padding required to get the shape of the - // input of StridedSlice op. - // CHECK: [[PAD:%.*]] = "mhlo.pad"([[RESHAPE]], [[ZERO]]) - // CHECK-DAG-SAME: edge_padding_low = dense<[1, 0, 0, 8, 0, 2]> : tensor<6xi64> - // CHECK-DAG-SAME: edge_padding_high = dense<[0, 0, 0, 0, 22, 59]> : tensor<6xi64> - // CHECK-DAG-SAME: interior_padding = dense<[0, 0, 0, 0, 0, 1]> : tensor<6xi64> - %0 = "tf.StridedSliceGrad"(%shape, %begin, %end, %strides, %grad) {begin_mask = 16, end_mask = 8, shrink_axis_mask = 1, ellipsis_mask = 4, new_axis_mask = 66} : (tensor<6xi32>, tensor<7xi32>, tensor<7xi32>, tensor<7xi32>, tensor<1x4x8x8x10x2x1xf32>) -> tensor<2x4x8x16x32x64xf32> - - // CHECK: return [[PAD]] : tensor<2x4x8x16x32x64xf32> - func.return %0 : tensor<2x4x8x16x32x64xf32> -} - -// ----- - -// CHECK-LABEL: @tensor_scatter_update -func.func @tensor_scatter_update(%tensor: tensor, %indices: tensor, %updates: tensor) -> tensor { - // CHECK: "mhlo.scatter"(%arg0, %arg1, %arg2) - // CHECK-SAME: indices_are_sorted = false - // CHECK-SAME: scatter_dimension_numbers - // CHECK-SAME: update_window_dims = [1] - // CHECK-SAME: inserted_window_dims = [0, 1] - // CHECK-SAME: scatter_dims_to_operand_dims = [0, 1] - // CHECK-SAME: index_vector_dim = 1 - // CHECK-SAME: unique_indices = false - // CHECK: ^bb0(%arg3: tensor, %arg4: tensor): - // CHECK: mhlo.return %arg4 : tensor - // CHECK: }) - %0 = "tf.TensorScatterUpdate"(%tensor, %indices, %updates) : (tensor, tensor, tensor) -> tensor - func.return %0 : tensor -} - -// ----- - -// CHECK-LABEL: @tensor_scatter_update_scalar_update -func.func @tensor_scatter_update_scalar_update(%tensor: tensor<4x3xi32>, %indices: tensor<2x1xi32>, %updates: tensor) -> tensor<4x3xi32> { - // CHECK: mhlo.constant dense<[2, 3]> : tensor<2xi64> - // CHECK: "mhlo.dynamic_broadcast_in_dim"(%arg2, %0) <{broadcast_dimensions = dense<> : tensor<0xi64>}> : (tensor, tensor<2xi64>) -> tensor<2x3xi32> - // CHECK: "mhlo.scatter" - %0 = "tf.TensorScatterUpdate"(%tensor, %indices, %updates) : (tensor<4x3xi32>, tensor<2x1xi32>, tensor) -> tensor<4x3xi32> - func.return %0 : tensor<4x3xi32> -} - -// ----- - -// CHECK-LABEL: @tensor_scatter_add -func.func @tensor_scatter_add(%tensor: tensor, %indices: tensor, %updates: tensor) -> tensor { - // CHECK: "mhlo.scatter"(%arg0, %arg1, %arg2) - // CHECK-SAME: indices_are_sorted = false - // CHECK-SAME: scatter_dimension_numbers - // CHECK-SAME: update_window_dims = [1] - // CHECK-SAME: inserted_window_dims = [0, 1] - // CHECK-SAME: scatter_dims_to_operand_dims = [0, 1] - // CHECK-SAME: index_vector_dim = 1 - // CHECK-SAME: unique_indices = false - // CHECK: ^bb0(%arg3: tensor, %arg4: tensor): - // CHECK: %1 = mhlo.add %arg3, %arg4 : tensor - // CHECK: mhlo.return %1 : tensor - // CHECK: }) - %0 = "tf.TensorScatterAdd"(%tensor, %indices, %updates) : (tensor, tensor, tensor) -> tensor - func.return %0 : tensor -} - -// ----- - -// CHECK-LABEL: @tensor_scatter_add_scalar_update -func.func @tensor_scatter_add_scalar_update(%tensor: tensor<4x3xi32>, %indices: tensor<2x1xi32>, %updates: tensor) -> tensor<4x3xi32> { - // CHECK: mhlo.constant dense<[2, 3]> : tensor<2xi64> - // CHECK: "mhlo.dynamic_broadcast_in_dim"(%arg2, %0) <{broadcast_dimensions = dense<> : tensor<0xi64>}> : (tensor, tensor<2xi64>) -> tensor<2x3xi32> - // CHECK: "mhlo.scatter - %0 = "tf.TensorScatterAdd"(%tensor, %indices, %updates) : (tensor<4x3xi32>, tensor<2x1xi32>, tensor) -> tensor<4x3xi32> - func.return %0 : tensor<4x3xi32> -} - -// ----- - -// CHECK-LABEL: @tensor_scatter_sub -func.func @tensor_scatter_sub(%tensor: tensor, %indices: tensor, %updates: tensor) -> tensor { - // CHECK: "mhlo.scatter"(%arg0, %arg1, %arg2) - // CHECK-SAME: indices_are_sorted = false - // CHECK-SAME: scatter_dimension_numbers - // CHECK-SAME: update_window_dims = [1] - // CHECK-SAME: inserted_window_dims = [0, 1] - // CHECK-SAME: scatter_dims_to_operand_dims = [0, 1] - // CHECK-SAME: index_vector_dim = 1 - // CHECK-SAME: unique_indices = false - // CHECK: ^bb0(%arg3: tensor, %arg4: tensor): - // CHECK: %1 = mhlo.subtract %arg3, %arg4 : tensor - // CHECK: mhlo.return %1 : tensor - // CHECK: }) - %0 = "tf.TensorScatterSub"(%tensor, %indices, %updates) : (tensor, tensor, tensor) -> tensor - func.return %0 : tensor -} - -// ----- - -// CHECK-LABEL: @tensor_scatter_min -func.func @tensor_scatter_min(%tensor: tensor, %indices: tensor, %updates: tensor) -> tensor { - // CHECK: "mhlo.scatter"(%arg0, %arg1, %arg2) - // CHECK-SAME: indices_are_sorted = false - // CHECK-SAME: scatter_dimension_numbers - // CHECK-SAME: update_window_dims = [1] - // CHECK-SAME: inserted_window_dims = [0, 1] - // CHECK-SAME: scatter_dims_to_operand_dims = [0, 1] - // CHECK-SAME: index_vector_dim = 1 - // CHECK-SAME: unique_indices = false - // CHECK: ^bb0(%arg3: tensor, %arg4: tensor): - // CHECK: %1 = mhlo.minimum %arg3, %arg4 : tensor - // CHECK: mhlo.return %1 : tensor - // CHECK: }) - %0 = "tf.TensorScatterMin"(%tensor, %indices, %updates) : (tensor, tensor, tensor) -> tensor - func.return %0 : tensor -} - -// ----- - -// CHECK-LABEL: @tensor_scatter_max -func.func @tensor_scatter_max(%tensor: tensor, %indices: tensor, %updates: tensor) -> tensor { - // CHECK: "mhlo.scatter"(%arg0, %arg1, %arg2) - // CHECK-SAME: indices_are_sorted = false - // CHECK-SAME: scatter_dimension_numbers - // CHECK-SAME: update_window_dims = [1] - // CHECK-SAME: inserted_window_dims = [0, 1] - // CHECK-SAME: scatter_dims_to_operand_dims = [0, 1] - // CHECK-SAME: index_vector_dim = 1 - // CHECK-SAME: unique_indices = false - // CHECK: ^bb0(%arg3: tensor, %arg4: tensor): - // CHECK: %1 = mhlo.maximum %arg3, %arg4 : tensor - // CHECK: mhlo.return %1 : tensor - // CHECK: }) - %0 = "tf.TensorScatterMax"(%tensor, %indices, %updates) : (tensor, tensor, tensor) -> tensor - func.return %0 : tensor -} - -//===----------------------------------------------------------------------===// -// tf.RandomShuffle legalization -//===----------------------------------------------------------------------===// - -// ----- - -// CHECK-LABEL: @random_shuffle_num_elems_le_1 -func.func @random_shuffle_num_elems_le_1() -> tensor { - // CHECK: [[INPUT:%.*]] = mhlo.constant dense<1.000000e+20> : tensor - // CHECK-NEXT: return [[INPUT]] - %cst = "tf.Const"() {value = dense<1.000000e+20> : tensor} : () -> tensor - %0 = "tf.RandomShuffle"(%cst) {device = "", seed = -4294967297 : i64, seed2 = -2147483649 : i64} : (tensor) -> tensor - return %0 : tensor -} - -// ----- - -// CHECK-LABEL: @random_shuffle_first_dim_1 -// CHECK-SAME: [[INPUT:%.*]]: tensor<1x?xf32> -func.func @random_shuffle_first_dim_1(%input: tensor<1x?xf32>) -> tensor<1x?xf32> { - %0 = "tf.RandomShuffle"(%input) : (tensor<1x?xf32>) -> (tensor<1x?xf32>) - // CHECK-NEXT: return [[INPUT]] - func.return %0: tensor<1x?xf32> -} - -// ----- - -// CHECK-LABEL: @random_shuffle_1D_16 -// CHECK-SAME: [[INPUT:%.*]]: tensor<16xf32> -func.func @random_shuffle_1D_16(%input: tensor<16xf32>) -> tensor<16xf32> { - // CHECK-DAG: [[SHAPE:%.*]] = mhlo.constant dense<16> : tensor<1xi64> - // CHECK-DAG: [[LOWER:%.*]] = mhlo.constant dense<0> : tensor - // CHECK-DAG: [[UPPER:%.*]] = mhlo.constant dense<-1> : tensor - // CHECK: [[RNG:%.*]] = "mhlo.rng"([[LOWER]], [[UPPER]], [[SHAPE]]) <{rng_distribution = #mhlo.rng_distribution}> - // CHECK: [[SORT:%.*]]:2 = "mhlo.sort"([[RNG]], [[INPUT]]) <{dimension = -1 : i64, is_stable = {{.*}}}> ({ - // CHECK: ^{{.*}}([[ARG1:%.*]]: tensor, [[ARG2:%.*]]: tensor, {{.*}}: tensor, {{.*}}: tensor): - // CHECK: mhlo.compare LT, [[ARG1]], [[ARG2]], TOTALORDER - // CHECK: }) : (tensor<16xi32>, tensor<16xf32>) -> (tensor<16xi32>, tensor<16xf32>) - // CHECK: return [[SORT]]#1 - %0 = "tf.RandomShuffle"(%input) : (tensor<16xf32>) -> (tensor<16xf32>) - func.return %0: tensor<16xf32> -} - -// ----- - -// CHECK-LABEL: @random_shuffle_1D_10240 -func.func @random_shuffle_1D_10240(%input: tensor<10240xf32>) -> tensor<10240xf32> { - // CHECK: mhlo.rng{{.*UNIFORM.*}} - // CHECK: mhlo.sort - // CHECK: mhlo.rng{{.*UNIFORM.*}} - // CHECK: mhlo.sort - %0 = "tf.RandomShuffle"(%input) : (tensor<10240xf32>) -> (tensor<10240xf32>) - func.return %0: tensor<10240xf32> -} - -// ----- - -// CHECK-LABEL: @random_shuffle_3D -// CHECK-SAME: [[INPUT:%.*]]: tensor<4x?x16xf32> -func.func @random_shuffle_3D(%input: tensor<4x?x16xf32>) -> tensor<4x?x16xf32> { - // CHECK: [[INDICES:%.*]] = "mhlo.iota"() <{iota_dimension = 0 : i64}> : () -> tensor<4xi32> - - // CHECK-DAG: [[RNG_SHAPE:%.*]] = mhlo.constant dense<4> : tensor<1xi64> - // CHECK-DAG: [[RNG_LOWER:%.*]] = mhlo.constant dense<0> : tensor - // CHECK-DAG: [[RNG_UPPER:%.*]] = mhlo.constant dense<4> : tensor - // CHECK: [[SWAPS:%.*]] = "mhlo.rng"([[RNG_LOWER]], [[RNG_UPPER]], [[RNG_SHAPE]]) <{rng_distribution = #mhlo.rng_distribution}> - - // CHECK: [[IV_INIT:%.*]] = mhlo.constant dense<0> : tensor - - // CHECK: [[WHILE_OUT:%.*]]:3 = mhlo.while([[ITER_ARG0:.*]] = [[IV_INIT]], [[ITER_ARG1:.*]] = [[SWAPS]], [[ITER_ARG2:.*]] = [[INDICES]]) - // CHECK: [[LIMIT:%.*]] = mhlo.constant dense<4> : tensor - // CHECK: [[CMP:%.*]] = mhlo.compare LT, [[ITER_ARG0]], [[LIMIT]], NOTYPE - // CHECK: mhlo.return [[CMP]] - // CHECK: } do { - // CHECK: [[SRC_IDX:%.*]] = "mhlo.dynamic_slice"([[ITER_ARG2]], [[ITER_ARG0]]) <{slice_sizes = dense<1> : tensor<1xi64>}> : (tensor<4xi32>, tensor) -> tensor<1xi32> - // CHECK: [[SWP_IDX:%.*]] = "mhlo.dynamic_slice"([[ITER_ARG1]], [[ITER_ARG0]]) <{slice_sizes = dense<1> : tensor<1xi64>}> : (tensor<4xi32>, tensor) -> tensor<1xi32> - // CHECK: [[SWP:%.*]] = mhlo.reshape [[SWP_IDX]] : (tensor<1xi32>) -> tensor - // CHECK: [[TGT_IDX:%.*]] = "mhlo.dynamic_slice"([[ITER_ARG2]], [[SWP]]) <{slice_sizes = dense<1> : tensor<1xi64>}> - // CHECK: [[INDICES1:%.*]] = mhlo.dynamic_update_slice [[ITER_ARG2]], [[TGT_IDX]], [[ITER_ARG0]] : (tensor<4xi32>, tensor<1xi32>, tensor) -> tensor<4xi32> - // CHECK: [[INDICES2:%.*]] = mhlo.dynamic_update_slice [[INDICES1]], [[SRC_IDX]], [[SWP]] : (tensor<4xi32>, tensor<1xi32>, tensor) -> tensor<4xi32> - // CHECK: [[ONE:%.*]] = mhlo.constant dense<1> : tensor - // CHECK: [[NEW_IV:%.*]] = chlo.broadcast_add [[ITER_ARG0]], [[ONE]] - // CHECK: mhlo.return [[NEW_IV]], [[ITER_ARG1]], [[INDICES2]] - // CHECK: } - - // CHECK: [[CONSTANT1:%.*]] = mhlo.constant dense<1> : tensor<1xi64> - // CHECK: [[ARITH_CONSTANT:%.*]] = arith.constant 1 : index - // CHECK: [[SHAPE_DIM:%.*]] = shape.dim %arg0, [[ARITH_CONSTANT]] : tensor<4x?x16xf32>, index -> index - // CHECK: [[INDEX_CAST:%.*]] = arith.index_cast [[SHAPE_DIM]] : index to i64 - // CHECK: [[FROM_ELEMENTS:%.*]] = tensor.from_elements [[INDEX_CAST]] : tensor<1xi64> - // CHECK: [[CONSTANT2:%.*]] = mhlo.constant dense<16> : tensor<1xi64> - // CHECK: [[CONCATENATE:%.*]] = "mhlo.concatenate"([[CONSTANT1]], [[FROM_ELEMENTS]], [[CONSTANT2]]) <{dimension = 0 : i64}> : (tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<3xi64> - // CHECK: [[DYNAMIC_GATHER:%.*]] = "mhlo.dynamic_gather"([[INPUT]], [[WHILE_OUT]]#2, [[CONCATENATE]]) - // CHECK-SAME: dimension_numbers = - // CHECK-SAME: offset_dims = [1, 2] - // CHECK-SAME: collapsed_slice_dims = [0] - // CHECK-SAME: start_index_map = [0] - // CHECK-SAME: index_vector_dim = 1 - // CHECK-SAME: indices_are_sorted = false - // CHECK-SAME:: (tensor<4x?x16xf32>, tensor<4xi32>, tensor<3xi64>) -> tensor<4x?x16xf32> - - // CHECK: return [[DYNAMIC_GATHER]] - - %0 = "tf.RandomShuffle"(%input) : (tensor<4x?x16xf32>) -> (tensor<4x?x16xf32>) - func.return %0: tensor<4x?x16xf32> -} - -//===----------------------------------------------------------------------===// -// tf.AvgPool legalization -//===----------------------------------------------------------------------===// - -// ----- - -// CHECK-LABEL: @avgpool_valid_padding -// CHECK-SAME: [[ARG:%.+]]: tensor<2x12x21x7xf16> -// CHECK: [[CONV32:%.+]] = mhlo.convert %arg0 : (tensor<2x12x21x7xf16>) -> tensor<2x12x21x7xf32> -// CHECK: [[ZERO:%.+]] = mhlo.constant dense<0.000000e+00> : tensor -// CHECK: [[DIVIDEND:%.+]] = "mhlo.reduce_window"([[CONV32]], [[ZERO]]) -// CHECK-SAME: window_dimensions = dense<[1, 2, 2, 1]> -// CHECK-SAME: window_strides = dense<[1, 4, 4, 1]> -// CHECK: ^bb0([[ARG1:%.+]]: tensor, [[ARG2:%.+]]: tensor): -// CHECK: [[ADD:%.+]] = mhlo.add [[ARG1]], [[ARG2]] -// CHECK: mhlo.return [[ADD]] -// CHECK: }) -// CHECK-SAME: -> tensor<2x3x5x7xf32> -// CHECK: [[COUNT:%.+]] = mhlo.constant dense<4.000000e+00> : tensor -// CHECK: [[DIV_RESULT:%.+]] = chlo.broadcast_divide [[DIVIDEND]], [[COUNT]] -// CHECK-SAME: broadcast_dimensions = array -// CHECK-SAME: -> tensor<2x3x5x7xf32> -// CHECK: [[CONV16:%.+]] = mhlo.convert [[DIV_RESULT]] -// CHECK-SAME: -> tensor<2x3x5x7xf16> -// CHECK: return [[CONV16]] -func.func @avgpool_valid_padding(%arg0: tensor<2x12x21x7xf16>) -> tensor<2x3x5x7xf16> { - %0 = "tf.AvgPool"(%arg0) {data_format = "NHWC", ksize = [1, 2, 2, 1], padding = "VALID", strides = [1, 4, 4, 1]} : (tensor<2x12x21x7xf16>) -> tensor<2x3x5x7xf16> - func.return %0 : tensor<2x3x5x7xf16> -} - -// ----- - -// CHECK-LABEL: @avgpool_3d_valid_padding -// CHECK-SAME: [[ARG:%.+]]: tensor<2x4x12x21x7xf16> -// CHECK: [[CONV32:%.+]] = mhlo.convert %arg0 : (tensor<2x4x12x21x7xf16>) -> tensor<2x4x12x21x7xf32> -// CHECK: [[ZERO:%.+]] = mhlo.constant dense<0.000000e+00> : tensor -// CHECK: [[DIVIDEND:%.+]] = "mhlo.reduce_window"([[CONV32]], [[ZERO]]) -// CHECK-SAME: window_dimensions = dense<[1, 1, 2, 2, 1]> -// CHECK-SAME: window_strides = dense<[1, 1, 4, 4, 1]> -// CHECK: ^bb0([[ARG1:%.+]]: tensor, [[ARG2:%.+]]: tensor): -// CHECK: [[ADD:%.+]] = mhlo.add [[ARG1]], [[ARG2]] -// CHECK: mhlo.return [[ADD]] -// CHECK: }) -// CHECK-SAME: -> tensor<2x4x3x5x7xf32> -// CHECK: [[COUNT:%.+]] = mhlo.constant dense<4.000000e+00> : tensor -// CHECK: [[DIV_RESULT:%.+]] = chlo.broadcast_divide [[DIVIDEND]], [[COUNT]] -// CHECK-SAME: broadcast_dimensions = array -// CHECK-SAME: -> tensor<2x4x3x5x7xf32> -// CHECK: [[CONV16:%.+]] = mhlo.convert [[DIV_RESULT]] -// CHECK-SAME: -> tensor<2x4x3x5x7xf16> -// CHECK: return [[CONV16]] -func.func @avgpool_3d_valid_padding(%arg0: tensor<2x4x12x21x7xf16>) -> tensor<2x4x3x5x7xf16> { - %0 = "tf.AvgPool3D"(%arg0) {data_format = "NDHWC", ksize = [1, 1, 2, 2, 1], padding = "VALID", strides = [1, 1, 4, 4, 1]} : (tensor<2x4x12x21x7xf16>) -> tensor<2x4x3x5x7xf16> - func.return %0 : tensor<2x4x3x5x7xf16> -} - -// ----- - -// CHECK-LABEL: @avgpool_nchw_format -// CHECK-SAME: [[ARG:%.+]]: tensor<2x7x12x21xf16> -// CHECK: [[CONV32:%.+]] = mhlo.convert %arg0 : (tensor<2x7x12x21xf16>) -> tensor<2x7x12x21xf32> -// CHECK: [[ZERO:%.+]] = mhlo.constant dense<0.000000e+00> : tensor -// CHECK: [[DIVIDEND:%.+]] = "mhlo.reduce_window"([[CONV32]], [[ZERO]]) -// CHECK-SAME: window_dimensions = dense<[1, 1, 2, 2]> -// CHECK-SAME: window_strides = dense<[1, 1, 4, 4]> -// CHECK: ^bb0([[ARG1:%.+]]: tensor, [[ARG2:%.+]]: tensor): -// CHECK: [[ADD:%.+]] = mhlo.add [[ARG1]], [[ARG2]] -// CHECK: mhlo.return [[ADD]] -// CHECK: }) -// CHECK-SAME: -> tensor<2x7x3x5xf32> -// CHECK: [[COUNT:%.+]] = mhlo.constant dense<4.000000e+00> : tensor -// CHECK: [[DIV_RESULT:%.+]] = chlo.broadcast_divide [[DIVIDEND]], [[COUNT]] -// CHECK-SAME: broadcast_dimensions = array -// CHECK-SAME: -> tensor<2x7x3x5xf32> -// CHECK: [[CONV16:%.+]] = mhlo.convert [[DIV_RESULT]] -// CHECK-SAME: -> tensor<2x7x3x5xf16> -// CHECK: return [[CONV16]] -func.func @avgpool_nchw_format(%arg0: tensor<2x7x12x21xf16>) -> tensor<2x7x3x5xf16> { - %0 = "tf.AvgPool"(%arg0) {data_format = "NCHW", ksize = [1, 1, 2, 2], padding = "VALID", strides = [1, 1, 4, 4]} : (tensor<2x7x12x21xf16>) -> tensor<2x7x3x5xf16> - func.return %0 : tensor<2x7x3x5xf16> -} - -// ----- - -// CHECK-LABEL: @avgpool_3d_ncdhw_format -// CHECK-SAME: [[ARG:%.+]]: tensor<2x7x4x12x21xf16> -// CHECK: [[CONV32:%.+]] = mhlo.convert %arg0 : (tensor<2x7x4x12x21xf16>) -> tensor<2x7x4x12x21xf32> -// CHECK: [[ZERO:%.+]] = mhlo.constant dense<0.000000e+00> : tensor -// CHECK: [[DIVIDEND:%.+]] = "mhlo.reduce_window"([[CONV32]], [[ZERO]]) -// CHECK-SAME: window_dimensions = dense<[1, 1, 1, 2, 2]> -// CHECK-SAME: window_strides = dense<[1, 1, 1, 4, 4]> -// CHECK: ^bb0([[ARG1:%.+]]: tensor, [[ARG2:%.+]]: tensor): -// CHECK: [[ADD:%.+]] = mhlo.add [[ARG1]], [[ARG2]] -// CHECK: mhlo.return [[ADD]] -// CHECK: }) -// CHECK-SAME: -> tensor<2x7x4x3x5xf32> -// CHECK: [[COUNT:%.+]] = mhlo.constant dense<4.000000e+00> : tensor -// CHECK: [[DIV_RESULT:%.+]] = chlo.broadcast_divide [[DIVIDEND]], [[COUNT]] -// CHECK-SAME: broadcast_dimensions = array -// CHECK-SAME: -> tensor<2x7x4x3x5xf32> -// CHECK: [[CONV16:%.+]] = mhlo.convert [[DIV_RESULT]] -// CHECK-SAME: -> tensor<2x7x4x3x5xf16> -// CHECK: return [[CONV16]] -func.func @avgpool_3d_ncdhw_format(%arg0: tensor<2x7x4x12x21xf16>) -> tensor<2x7x4x3x5xf16> { - %0 = "tf.AvgPool3D"(%arg0) {data_format = "NCDHW", ksize = [1, 1, 1, 2, 2], padding = "VALID", strides = [1, 1, 1, 4, 4]} : (tensor<2x7x4x12x21xf16>) -> tensor<2x7x4x3x5xf16> - func.return %0 : tensor<2x7x4x3x5xf16> -} - -// ----- - -// CHECK-LABEL: @avgpool_same_padding( -// CHECK-SAME: %[[ARG0:.*]]: tensor<2x12x21x7xf32>) -> tensor<2x4x6x7xf32> -// CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor -// CHECK: %[[DIVIDEND:.*]] = "mhlo.reduce_window"(%[[ARG0]], %[[ZERO]]) -// CHECK-SAME: padding = dense<{{\[\[}}0, 0], [1, 1], [0, 1], [0, 0]]> -// CHECK-SAME: window_dimensions = dense<[1, 5, 2, 1]> -// CHECK-SAME: window_strides = dense<[1, 3, 4, 1]> -// CHECK: ^bb0(%[[ARG1:.*]]: tensor, %[[ARG2:.*]]: tensor): -// CHECK: %[[SUM1:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor -// CHECK: mhlo.return %[[SUM1]] : tensor -// CHECK: }) -// CHECK-SAME: -> tensor<2x4x6x7xf32> -// CHECK: %[[ONES:.*]] = mhlo.constant dense<1.000000e+00> : tensor<2x12x21x7xf32> -// CHECK: %[[DIVISOR:.*]] = "mhlo.reduce_window"(%[[ONES]], %[[ZERO]]) -// CHECK-SAME: padding = dense<{{\[\[}}0, 0], [1, 1], [0, 1], [0, 0]]> -// CHECK-SAME: window_dimensions = dense<[1, 5, 2, 1]> -// CHECK-SAME: window_strides = dense<[1, 3, 4, 1]> -// CHECK: ^bb0(%[[ARG3:.*]]: tensor, %[[ARG4:.*]]: tensor): -// CHECK: %[[SUM2:.*]] = mhlo.add %[[ARG3]], %[[ARG4]] : tensor -// CHECK: mhlo.return %[[SUM2]] : tensor -// CHECK: }) -// CHECK-SAME: -> tensor<2x4x6x7xf32> -// CHECK: %[[RESULT:.*]] = mhlo.divide %[[DIVIDEND]], %[[DIVISOR]] : tensor<2x4x6x7xf32> -// CHECK: return %[[RESULT]] : tensor<2x4x6x7xf32> -// CHECK: } -func.func @avgpool_same_padding(%arg0: tensor<2x12x21x7xf32>) -> tensor<2x4x6x7xf32> { - %0 = "tf.AvgPool"(%arg0) {data_format = "NHWC", ksize = [1, 5, 2, 1], padding = "SAME", strides = [1, 3, 4, 1]} : (tensor<2x12x21x7xf32>) -> tensor<2x4x6x7xf32> - func.return %0 : tensor<2x4x6x7xf32> -} - -// ----- - -// CHECK-LABEL: @avgpool_3d_same_padding( -// CHECK-SAME: %[[ARG0:.*]]: tensor<2x4x12x21x7xf32>) -> tensor<2x4x4x6x7xf32> -// CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor -// CHECK: %[[DIVIDEND:.*]] = "mhlo.reduce_window"(%[[ARG0]], %[[ZERO]]) -// CHECK-SAME: padding = dense<{{\[\[}}0, 0], [0, 0], [1, 1], [0, 1], [0, 0]]> -// CHECK-SAME: window_dimensions = dense<[1, 1, 5, 2, 1]> -// CHECK-SAME: window_strides = dense<[1, 1, 3, 4, 1]> -// CHECK: ^bb0(%[[ARG1:.*]]: tensor, %[[ARG2:.*]]: tensor): -// CHECK: %[[SUM1:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor -// CHECK: mhlo.return %[[SUM1]] : tensor -// CHECK: }) -// CHECK-SAME: -> tensor<2x4x4x6x7xf32> -// CHECK: %[[ONES:.*]] = mhlo.constant dense<1.000000e+00> : tensor<2x4x12x21x7xf32> -// CHECK: %[[DIVISOR:.*]] = "mhlo.reduce_window"(%[[ONES]], %[[ZERO]]) -// CHECK-SAME: padding = dense<{{\[\[}}0, 0], [0, 0], [1, 1], [0, 1], [0, 0]]> -// CHECK-SAME: window_dimensions = dense<[1, 1, 5, 2, 1]> -// CHECK-SAME: window_strides = dense<[1, 1, 3, 4, 1]> -// CHECK: ^bb0(%[[ARG3:.*]]: tensor, %[[ARG4:.*]]: tensor): -// CHECK: %[[SUM2:.*]] = mhlo.add %[[ARG3]], %[[ARG4]] : tensor -// CHECK: mhlo.return %[[SUM2]] : tensor -// CHECK: }) -// CHECK-SAME: -> tensor<2x4x4x6x7xf32> -// CHECK: %[[RESULT:.*]] = mhlo.divide %[[DIVIDEND]], %[[DIVISOR]] -// CHECK: return %[[RESULT]] : tensor<2x4x4x6x7xf32> -// CHECK: } -func.func @avgpool_3d_same_padding(%arg0: tensor<2x4x12x21x7xf32>) -> tensor<2x4x4x6x7xf32> { - %0 = "tf.AvgPool3D"(%arg0) {data_format = "NDHWC", ksize = [1, 1, 5, 2, 1], padding = "SAME", strides = [1, 1, 3, 4, 1]} : (tensor<2x4x12x21x7xf32>) -> tensor<2x4x4x6x7xf32> - func.return %0 : tensor<2x4x4x6x7xf32> -} - -//===----------------------------------------------------------------------===// -// AvgPoolGrad op legalizations. -//===----------------------------------------------------------------------===// - -// ----- - -// CHECK-LABEL: @avgpool_grad_valid_padding( -// CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<10x12x16x64xf32>) -> tensor<10x24x32x64xf32> { -// CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor -// CHECK-DAG: %[[DIVISOR:.*]] = mhlo.constant dense<4.000000e+00> : tensor -// CHECK: %[[OUT_GRAD_DIVIDED:.*]] = chlo.broadcast_divide %[[OUT_GRAD]], %[[DIVISOR]] -// CHECK-SAME: broadcast_dimensions = array -// CHECK-SAME: -> tensor<10x12x16x64xf32> -// CHECK: %[[REDUCE_WINDOW_INPUT:.*]] = "mhlo.pad"(%[[OUT_GRAD_DIVIDED]], %[[ZERO]]) -// CHECK-SAME: edge_padding_high = dense<[0, 1, 1, 0]> -// CHECK-SAME: edge_padding_low = dense<[0, 1, 1, 0]> -// CHECK-SAME: interior_padding = dense<[0, 1, 1, 0]> -// CHECK-SAME: -> tensor<10x25x33x64xf32> -// CHECK: %[[RESULT:.*]] = "mhlo.reduce_window"(%[[REDUCE_WINDOW_INPUT]], %[[ZERO]]) -// CHECK-SAME: window_dimensions = dense<[1, 2, 2, 1]> -// CHECK-SAME: window_strides = dense<1> -// CHECK: ^bb0(%[[ARG1:.*]]: tensor, %[[ARG2:.*]]: tensor): -// CHECK: %[[SUM:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor -// CHECK: mhlo.return %[[SUM]] : tensor -// CHECK: }) -// CHECK-SAME: -> tensor<10x24x32x64xf32> -// CHECK: return %[[RESULT]] : tensor<10x24x32x64xf32> -func.func @avgpool_grad_valid_padding(%grad: tensor<10x12x16x64xf32>) -> tensor<10x24x32x64xf32> { - %orig_input_shape = "tf.Const"() {value = dense<[10, 24, 32, 64]> : tensor<4xi32>} : () -> (tensor<4xi32>) - %result = "tf.AvgPoolGrad"(%orig_input_shape, %grad) { - data_format = "NHWC", - ksize = [1, 2, 2, 1], - padding = "VALID", - strides = [1, 2, 2, 1] - } : (tensor<4xi32>, tensor<10x12x16x64xf32>) -> tensor<10x24x32x64xf32> - func.return %result : tensor<10x24x32x64xf32> -} - -// ----- - -// CHECK-LABEL: @avgpool_3d_grad_valid_padding( -// CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<10x8x12x16x64xf32>) -> tensor<10x8x24x32x64xf32> { -// CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor -// CHECK-DAG: %[[DIVISOR:.*]] = mhlo.constant dense<4.000000e+00> : tensor -// CHECK: %[[OUT_GRAD_DIVIDED:.*]] = chlo.broadcast_divide %[[OUT_GRAD]], %[[DIVISOR]] {broadcast_dimensions = array} : (tensor<10x8x12x16x64xf32>, tensor) -> tensor<10x8x12x16x64xf32> -// CHECK: %[[REDUCE_WINDOW_INPUT:.*]] = "mhlo.pad"(%[[OUT_GRAD_DIVIDED]], %[[ZERO]]) -// CHECK-SAME: edge_padding_high = dense<[0, 0, 1, 1, 0]> -// CHECK-SAME: edge_padding_low = dense<[0, 0, 1, 1, 0]> -// CHECK-SAME: interior_padding = dense<[0, 0, 1, 1, 0]> -// CHECK-SAME: -> tensor<10x8x25x33x64xf32> -// CHECK: %[[RESULT:.*]] = "mhlo.reduce_window"(%[[REDUCE_WINDOW_INPUT]], %[[ZERO]]) -// CHECK-SAME: window_dimensions = dense<[1, 1, 2, 2, 1]> -// CHECK-SAME: window_strides = dense<1> -// CHECK: ^bb0(%[[ARG1:.*]]: tensor, %[[ARG2:.*]]: tensor): -// CHECK: %[[SUM:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor -// CHECK: mhlo.return %[[SUM]] : tensor -// CHECK: }) -// CHECK-SAME: -> tensor<10x8x24x32x64xf32> -// CHECK: return %[[RESULT]] : tensor<10x8x24x32x64xf32> -func.func @avgpool_3d_grad_valid_padding(%grad: tensor<10x8x12x16x64xf32>) -> tensor<10x8x24x32x64xf32> { - %orig_input_shape = "tf.Const"() {value = dense<[10, 8, 24, 32, 64]> : tensor<5xi32>} : () -> (tensor<5xi32>) - %result = "tf.AvgPool3DGrad"(%orig_input_shape, %grad) { - data_format = "NDHWC", - ksize = [1, 1, 2, 2, 1], - padding = "VALID", - strides = [1, 1, 2, 2, 1]} : (tensor<5xi32>, tensor<10x8x12x16x64xf32>) -> tensor<10x8x24x32x64xf32> - func.return %result : tensor<10x8x24x32x64xf32> -} - -// ----- - -// CHECK-LABEL: @avgpool_grad_same_padding( -// CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<2x4x7x9xf32>) -> tensor<2x13x25x9xf32> { -// CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor -// CHECK-DAG: %[[ALL_ONES:.*]] = mhlo.constant dense<1.000000e+00> : tensor<2x13x25x9xf32> -// CHECK: %[[DIVISOR:.*]] = "mhlo.reduce_window"(%[[ALL_ONES]], %[[ZERO]]) -// CHECK-SAME: padding = dense<{{\[\[}}0, 0], [0, 1], [1, 1], [0, 0]]> -// CHECK-SAME: window_dimensions = dense<[1, 2, 3, 1]> -// CHECK-SAME: window_strides = dense<[1, 4, 4, 1]> -// CHECK: ^bb0(%[[ARG1:.*]]: tensor, %[[ARG2:.*]]: tensor): -// CHECK: %[[SUM1:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor -// CHECK: mhlo.return %[[SUM1]] : tensor -// CHECK: }) -// CHECK-SAME: -> tensor<2x4x7x9xf32> -// CHECK: %[[OUT_GRAD_DIVIDED:.*]] = mhlo.divide %[[OUT_GRAD]], %[[DIVISOR]] : tensor<2x4x7x9xf32> -// CHECK: %[[REDUCE_WINDOW_INPUT:.*]] = "mhlo.pad"(%[[OUT_GRAD_DIVIDED]], %[[ZERO]]) -// CHECK-SAME: edge_padding_high = dense<[0, 0, 1, 0]> -// CHECK-SAME: edge_padding_low = dense<[0, 1, 1, 0]> -// CHECK-SAME: interior_padding = dense<[0, 3, 3, 0]> -// CHECK-SAME: -> tensor<2x14x27x9xf32> -// CHECK: %[[RESULT:.*]] = "mhlo.reduce_window"(%[[REDUCE_WINDOW_INPUT]], %[[ZERO]]) -// CHECK-SAME: window_dimensions = dense<[1, 2, 3, 1]> -// CHECK-SAME: window_strides = dense<1> -// CHECK: ^bb0(%[[ARG3:.*]]: tensor, %[[ARG4:.*]]: tensor): -// CHECK: %[[SUM2:.*]] = mhlo.add %[[ARG3]], %[[ARG4]] : tensor -// CHECK: mhlo.return %[[SUM2]] : tensor -// CHECK: }) -// CHECK-SAME: -> tensor<2x13x25x9xf32> -// CHECK: return %[[RESULT]] : tensor<2x13x25x9xf32> -func.func @avgpool_grad_same_padding(%grad: tensor<2x4x7x9xf32>) -> tensor<2x13x25x9xf32> { - %orig_input_shape = "tf.Const"() {value = dense<[2, 13, 25, 9]> : tensor<4xi32>} : () -> (tensor<4xi32>) - %result = "tf.AvgPoolGrad"(%orig_input_shape, %grad) { - data_format = "NHWC", - ksize = [1, 2, 3, 1], - padding = "SAME", - strides = [1, 4, 4, 1] - } : (tensor<4xi32>, tensor<2x4x7x9xf32>) -> tensor<2x13x25x9xf32> - func.return %result : tensor<2x13x25x9xf32> -} - -// ----- - -// CHECK-LABEL: @avgpool_3d_grad_same_padding( -// CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<2x8x4x7x9xf32>) -> tensor<2x8x13x25x9xf32> { -// CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor -// CHECK-DAG: %[[ALL_ONES:.*]] = mhlo.constant dense<1.000000e+00> : tensor<2x8x13x25x9xf32> -// CHECK: %[[DIVISOR:.*]] = "mhlo.reduce_window"(%[[ALL_ONES]], %[[ZERO]]) -// CHECK-SAME: padding = dense<{{\[\[}}0, 0], [0, 0], [0, 1], [1, 1], [0, 0]]> -// CHECK-SAME: window_dimensions = dense<[1, 1, 2, 3, 1]> -// CHECK-SAME: window_strides = dense<[1, 1, 4, 4, 1]> -// CHECK: ^bb0(%[[ARG1:.*]]: tensor, %[[ARG2:.*]]: tensor): -// CHECK: %[[SUM1:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor -// CHECK: mhlo.return %[[SUM1]] : tensor -// CHECK: }) -// CHECK-SAME: -> tensor<2x8x4x7x9xf32> -// CHECK: %[[OUT_GRAD_DIVIDED:.*]] = mhlo.divide %[[OUT_GRAD]], %[[DIVISOR]] : tensor<2x8x4x7x9xf32> -// CHECK: %[[REDUCE_WINDOW_INPUT:.*]] = "mhlo.pad"(%[[OUT_GRAD_DIVIDED]], %[[ZERO]]) -// CHECK-SAME: edge_padding_high = dense<[0, 0, 0, 1, 0]> -// CHECK-SAME: edge_padding_low = dense<[0, 0, 1, 1, 0]> -// CHECK-SAME: interior_padding = dense<[0, 0, 3, 3, 0]> -// CHECK-SAME: -> tensor<2x8x14x27x9xf32> -// CHECK: %[[RESULT:.*]] = "mhlo.reduce_window"(%[[REDUCE_WINDOW_INPUT]], %[[ZERO]]) -// CHECK-SAME: window_dimensions = dense<[1, 1, 2, 3, 1]> -// CHECK-SAME: window_strides = dense<1> -// CHECK: ^bb0(%[[ARG3:.*]]: tensor, %[[ARG4:.*]]: tensor): -// CHECK: %[[SUM2:.*]] = mhlo.add %[[ARG3]], %[[ARG4]] : tensor -// CHECK: mhlo.return %[[SUM2]] : tensor -// CHECK: }) -// CHECK-SAME: -> tensor<2x8x13x25x9xf32> -// CHECK: return %[[RESULT]] : tensor<2x8x13x25x9xf32> -func.func @avgpool_3d_grad_same_padding(%grad: tensor<2x8x4x7x9xf32>) -> tensor<2x8x13x25x9xf32> { - %orig_input_shape = "tf.Const"() {value = dense<[2, 8, 13, 25, 9]> : tensor<5xi32>} : () -> (tensor<5xi32>) - %result = "tf.AvgPool3DGrad"(%orig_input_shape, %grad) { - data_format = "NDHWC", - ksize = [1, 1, 2, 3, 1], - padding = "SAME", - strides = [1, 1, 4, 4, 1]} : (tensor<5xi32>, tensor<2x8x4x7x9xf32>) -> tensor<2x8x13x25x9xf32> - func.return %result : tensor<2x8x13x25x9xf32> -} - -// ----- - -// CHECK-LABEL: @avgpool_grad_nchw_format( -// CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<2x9x4x7xf32>) -> tensor<2x9x13x25xf32> { -// CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor -// CHECK-DAG: %[[ALL_ONES:.*]] = mhlo.constant dense<1.000000e+00> : tensor<2x9x13x25xf32> -// CHECK: %[[DIVISOR:.*]] = "mhlo.reduce_window"(%[[ALL_ONES]], %[[ZERO]]) -// CHECK-SAME: padding = dense<{{\[\[}}0, 0], [0, 0], [0, 1], [1, 1]]> -// CHECK-SAME: window_dimensions = dense<[1, 1, 2, 3]> -// CHECK-SAME: window_strides = dense<[1, 1, 4, 4]> -// CHECK: ^bb0(%[[ARG1:.*]]: tensor, %[[ARG2:.*]]: tensor): -// CHECK: %[[SUM1:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor -// CHECK: mhlo.return %[[SUM1]] : tensor -// CHECK: }) -// CHECK-SAME: -> tensor<2x9x4x7xf32> -// CHECK: %[[OUT_GRAD_DIVIDED:.*]] = mhlo.divide %[[OUT_GRAD]], %[[DIVISOR]] : tensor<2x9x4x7xf32> -// CHECK: %[[REDUCE_WINDOW_INPUT:.*]] = "mhlo.pad"(%[[OUT_GRAD_DIVIDED]], %[[ZERO]]) -// CHECK-SAME: edge_padding_high = dense<[0, 0, 0, 1]> -// CHECK-SAME: edge_padding_low = dense<[0, 0, 1, 1]> -// CHECK-SAME: interior_padding = dense<[0, 0, 3, 3]> -// CHECK-SAME: -> tensor<2x9x14x27xf32> -// CHECK: %[[RESULT:.*]] = "mhlo.reduce_window"(%[[REDUCE_WINDOW_INPUT]], %[[ZERO]]) -// CHECK-SAME: window_dimensions = dense<[1, 1, 2, 3]> -// CHECK-SAME: window_strides = dense<1> -// CHECK: ^bb0(%[[ARG3:.*]]: tensor, %[[ARG4:.*]]: tensor): -// CHECK: %[[SUM2:.*]] = mhlo.add %[[ARG3]], %[[ARG4]] : tensor -// CHECK: mhlo.return %[[SUM2]] : tensor -// CHECK: }) -// CHECK-SAME: -> tensor<2x9x13x25xf32> -// CHECK: return %[[RESULT]] : tensor<2x9x13x25xf32> -func.func @avgpool_grad_nchw_format(%grad: tensor<2x9x4x7xf32>) -> tensor<2x9x13x25xf32> { - %orig_input_shape = "tf.Const"() {value = dense<[2, 9, 13, 25]> : tensor<4xi32>} : () -> (tensor<4xi32>) - %result = "tf.AvgPoolGrad"(%orig_input_shape, %grad) { - data_format = "NCHW", - ksize = [1, 1, 2, 3], - padding = "SAME", - strides = [1, 1, 4, 4] - } : (tensor<4xi32>, tensor<2x9x4x7xf32>) -> tensor<2x9x13x25xf32> - func.return %result : tensor<2x9x13x25xf32> -} - -// ----- - -// CHECK-LABEL: @avgpool_3d_grad_ncdwh_format( -// CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<2x9x8x4x7xf32>) -> tensor<2x9x8x13x25xf32> { -// CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor -// CHECK-DAG: %[[ALL_ONES:.*]] = mhlo.constant dense<1.000000e+00> : tensor<2x9x8x13x25xf32> -// CHECK: %[[DIVISOR:.*]] = "mhlo.reduce_window"(%[[ALL_ONES]], %[[ZERO]]) -// CHECK-SAME: padding = dense<{{\[\[}}0, 0], [0, 0], [0, 0], [0, 1], [1, 1]]> -// CHECK-SAME: window_dimensions = dense<[1, 1, 1, 2, 3]> -// CHECK-SAME: window_strides = dense<[1, 1, 1, 4, 4]> -// CHECK: ^bb0(%[[ARG1:.*]]: tensor, %[[ARG2:.*]]: tensor): -// CHECK: %[[SUM1:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor -// CHECK: mhlo.return %[[SUM1]] : tensor -// CHECK: }) -// CHECK-SAME: -> tensor<2x9x8x4x7xf32> -// CHECK: %[[OUT_GRAD_DIVIDED:.*]] = mhlo.divide %[[OUT_GRAD]], %[[DIVISOR]] : tensor<2x9x8x4x7xf32> -// CHECK: %[[REDUCE_WINDOW_INPUT:.*]] = "mhlo.pad"(%[[OUT_GRAD_DIVIDED]], %[[ZERO]]) -// CHECK-SAME: edge_padding_high = dense<[0, 0, 0, 0, 1]> -// CHECK-SAME: edge_padding_low = dense<[0, 0, 0, 1, 1]> -// CHECK-SAME: interior_padding = dense<[0, 0, 0, 3, 3]> -// CHECK-SAME: -> tensor<2x9x8x14x27xf32> -// CHECK: %[[RESULT:.*]] = "mhlo.reduce_window"(%[[REDUCE_WINDOW_INPUT]], %[[ZERO]]) -// CHECK-SAME: window_dimensions = dense<[1, 1, 1, 2, 3]> -// CHECK-SAME: window_strides = dense<1> : tensor<5xi64> -// CHECK: ^bb0(%[[ARG3:.*]]: tensor, %[[ARG4:.*]]: tensor): -// CHECK: %[[SUM2:.*]] = mhlo.add %[[ARG3]], %[[ARG4]] : tensor -// CHECK: mhlo.return %[[SUM2]] : tensor -// CHECK: }) -// CHECK-SAME: -> tensor<2x9x8x13x25xf32> -// CHECK: return %[[RESULT]] : tensor<2x9x8x13x25xf32> -func.func @avgpool_3d_grad_ncdwh_format(%grad: tensor<2x9x8x4x7xf32>) -> tensor<2x9x8x13x25xf32> { - %orig_input_shape = "tf.Const"() {value = dense<[2, 9, 8, 13, 25]> : tensor<5xi32>} : () -> (tensor<5xi32>) - %result = "tf.AvgPool3DGrad"(%orig_input_shape, %grad) { - data_format = "NCDHW", - ksize = [1, 1, 1, 2, 3], - padding = "SAME", - strides = [1, 1, 1, 4, 4]} : (tensor<5xi32>, tensor<2x9x8x4x7xf32>) -> tensor<2x9x8x13x25xf32> - func.return %result : tensor<2x9x8x13x25xf32> -} - -// ----- - -// CHECK-LABEL: @avgpool_grad_bf16( -// CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<10x12x16x64xbf16>) -> tensor<10x24x32x64xbf16> { -// CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor -// CHECK-DAG: %[[DIVISOR:.*]] = mhlo.constant dense<4.000000e+00> : tensor -// CHECK: %[[OUT_GRAD_DIVIDED:.*]] = chlo.broadcast_divide %[[OUT_GRAD]], %[[DIVISOR]] -// CHECK-SAME: broadcast_dimensions = array -// CHECK-SAME: -> tensor<10x12x16x64xbf16> -// CHECK: %[[REDUCE_WINDOW_INPUT:.*]] = "mhlo.pad"(%[[OUT_GRAD_DIVIDED]], %[[ZERO]]) -// CHECK-SAME: edge_padding_high = dense<[0, 1, 1, 0]> -// CHECK-SAME: edge_padding_low = dense<[0, 1, 1, 0]> -// CHECK-SAME: interior_padding = dense<[0, 1, 1, 0]> -// CHECK-SAME: -> tensor<10x25x33x64xbf16> -// CHECK: %[[REDUCE_WINDOW_INPUT_CONVERTED:.*]] = mhlo.convert %[[REDUCE_WINDOW_INPUT]] : (tensor<10x25x33x64xbf16>) -> tensor<10x25x33x64xf32> -// CHECK: %[[ZERO_F32:.*]] = mhlo.constant dense<0.000000e+00> : tensor -// CHECK: %[[RESULT:.*]] = "mhlo.reduce_window"(%[[REDUCE_WINDOW_INPUT_CONVERTED]], %[[ZERO_F32]]) -// CHECK-SAME: window_dimensions = dense<[1, 2, 2, 1]> -// CHECK-SAME: window_strides = dense<1> -// CHECK: ^bb0(%[[ARG1:.*]]: tensor, %[[ARG2:.*]]: tensor): -// CHECK: %[[SUM:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor -// CHECK: mhlo.return %[[SUM]] : tensor -// CHECK: }) -// CHECK-SAME: -> tensor<10x24x32x64xf32> -// CHECK: %[[RESULT_CONVERTED:.*]] = mhlo.convert %[[RESULT]] : (tensor<10x24x32x64xf32>) -> tensor<10x24x32x64xbf16> -// CHECK: return %[[RESULT_CONVERTED]] : tensor<10x24x32x64xbf16> -func.func @avgpool_grad_bf16(%grad: tensor<10x12x16x64xbf16>) -> tensor<10x24x32x64xbf16> { - %orig_input_shape = "tf.Const"() {value = dense<[10, 24, 32, 64]> : tensor<4xi32>} : () -> (tensor<4xi32>) - %result = "tf.AvgPoolGrad"(%orig_input_shape, %grad) { - data_format = "NHWC", - ksize = [1, 2, 2, 1], - padding = "VALID", - strides = [1, 2, 2, 1] - } : (tensor<4xi32>, tensor<10x12x16x64xbf16>) -> tensor<10x24x32x64xbf16> - func.return %result : tensor<10x24x32x64xbf16> -} - -// ----- - -// CHECK-LABEL: xla_sharding -func.func @xla_sharding(%arg0: tensor<4x16xf32>) -> tensor<4x16xf32> { - // CHECK-NEXT: mhlo.custom_call @Sharding(%arg0) {mhlo.sharding = ""} - %0 = "tf.XlaSharding"(%arg0) {_XlaSharding = "", sharding = ""} : (tensor<4x16xf32>) -> tensor<4x16xf32> - func.return %0 : tensor<4x16xf32> -} - -// ----- - -// CHECK-LABEL: inplace_update_one -func.func @inplace_update_one(%arg0: tensor<8x4xf32>, %arg1: tensor<1x4xf32>, %arg2: tensor<1xi32>) -> tensor<8x4xf32> { - // CHECK-DAG: [[CST:%.+]] = mhlo.constant dense<0> - // CHECK-DAG: [[SLICE1:%.+]] = "mhlo.slice"(%arg2) <{limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}> - // CHECK-DAG: [[SLICE2:%.+]] = "mhlo.slice"(%arg1) <{limit_indices = dense<[1, 4]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}> - // CHECK-DAG: [[RESHAPE1:%.+]] = mhlo.reshape [[SLICE1]] - // CHECK-DAG: [[UPDATE:%.+]] = mhlo.dynamic_update_slice %arg0, [[SLICE2]], [[RESHAPE1]], [[CST]] - %0 = "tf.InplaceUpdate"(%arg0, %arg2, %arg1) : (tensor<8x4xf32>, tensor<1xi32>, tensor<1x4xf32>) -> tensor<8x4xf32> - - // CHECK: return [[UPDATE]] - func.return %0 : tensor<8x4xf32> -} - -// ----- - -// CHECK-LABEL: inplace_update_three -func.func @inplace_update_three(%arg0: tensor<8x8x4xf32>, %arg1: tensor<3x8x4xf32>, %arg2: tensor<3xi32>) -> tensor<8x8x4xf32> { - // CHECK-DAG: [[CST:%.+]] = mhlo.constant dense<0> - // CHECK-DAG: [[SLICE1:%.+]] = "mhlo.slice"(%arg2) <{limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}> - // CHECK-DAG: [[SLICE2:%.+]] = "mhlo.slice"(%arg2) <{limit_indices = dense<2> : tensor<1xi64>, start_indices = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}> - // CHECK-DAG: [[SLICE3:%.+]] = "mhlo.slice"(%arg2) <{limit_indices = dense<3> : tensor<1xi64>, start_indices = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}> - // CHECK-DAG: [[SLICE4:%.+]] = "mhlo.slice"(%arg1) <{limit_indices = dense<[1, 8, 4]> : tensor<3xi64>, start_indices = dense<0> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>}> - // CHECK-DAG: [[SLICE5:%.+]] = "mhlo.slice"(%arg1) <{limit_indices = dense<[2, 8, 4]> : tensor<3xi64>, start_indices = dense<[1, 0, 0]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>}> - // CHECK-DAG: [[SLICE6:%.+]] = "mhlo.slice"(%arg1) <{limit_indices = dense<[3, 8, 4]> : tensor<3xi64>, start_indices = dense<[2, 0, 0]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>}> - // CHECK-DAG: [[RESHAPE1:%.+]] = mhlo.reshape [[SLICE1]] - // CHECK-DAG: [[RESHAPE2:%.+]] = mhlo.reshape [[SLICE2]] - // CHECK-DAG: [[RESHAPE3:%.+]] = mhlo.reshape [[SLICE3]] - // CHECK-DAG: [[UPDATE1:%.+]] = mhlo.dynamic_update_slice %arg0, [[SLICE4]], [[RESHAPE1]], [[CST]], [[CST]] - // CHECK-DAG: [[UPDATE2:%.+]] = mhlo.dynamic_update_slice [[UPDATE1]], [[SLICE5]], [[RESHAPE2]], [[CST]], [[CST]] - // CHECK-DAG: [[UPDATE3:%.+]] = mhlo.dynamic_update_slice [[UPDATE2]], [[SLICE6]], [[RESHAPE3]], [[CST]], [[CST]] - %0 = "tf.InplaceUpdate"(%arg0, %arg2, %arg1) : (tensor<8x8x4xf32>, tensor<3xi32>, tensor<3x8x4xf32>) -> tensor<8x8x4xf32> - - // CHECK: return [[UPDATE3]] : tensor<8x8x4xf32> - func.return %0 : tensor<8x8x4xf32> -} - -// ----- - -// CHECK-LABEL: xla_dynamic_update_slice -func.func @xla_dynamic_update_slice(%arg0: tensor<4x16xf32>, %arg1: tensor<2x4xf32>, %arg2: tensor<2xi32>) -> tensor<4x16xf32> { - // CHECK: [[SLICE0:%.+]] = "mhlo.slice"(%arg2) <{limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}> : (tensor<2xi32>) -> tensor<1xi32> - // CHECK: [[RESHAPE0:%.+]] = mhlo.reshape [[SLICE0]] : (tensor<1xi32>) -> tensor - // CHECK: [[SLICE1:%.+]] = "mhlo.slice"(%arg2) <{limit_indices = dense<2> : tensor<1xi64>, start_indices = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}> : (tensor<2xi32>) -> tensor<1xi32> - // CHECK: [[RESHAPE1:%.+]] = mhlo.reshape [[SLICE1]] : (tensor<1xi32>) -> tensor - // CHECK: [[DUS:%.+]] = mhlo.dynamic_update_slice %arg0, %arg1, [[RESHAPE0]], [[RESHAPE1]] : (tensor<4x16xf32>, tensor<2x4xf32>, tensor, tensor) -> tensor<4x16xf32> - // CHECK: return [[DUS]] - %0 = "tf.XlaDynamicUpdateSlice"(%arg0, %arg1, %arg2) : (tensor<4x16xf32>, tensor<2x4xf32>, tensor<2xi32>) -> tensor<4x16xf32> - func.return %0 : tensor<4x16xf32> -} - -// ----- - -// CHECK-LABEL: xla_dynamic_update_slice2 -func.func @xla_dynamic_update_slice2(%arg0: tensor<4xf32>, %arg1: tensor<2xf32>, %arg2: tensor<1xi32>) -> tensor<4xf32> { - // CHECK: [[SLICE0:%.+]] = "mhlo.slice"(%arg2) <{limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}> : (tensor<1xi32>) -> tensor<1xi32> - // CHECK: [[RESHAPE0:%.+]] = mhlo.reshape [[SLICE0]] : (tensor<1xi32>) -> tensor - // CHECK: [[DUS:%.+]] = mhlo.dynamic_update_slice %arg0, %arg1, [[RESHAPE0]] : (tensor<4xf32>, tensor<2xf32>, tensor) -> tensor<4xf32> - // CHECK: return [[DUS]] - %0 = "tf.XlaDynamicUpdateSlice"(%arg0, %arg1, %arg2) : (tensor<4xf32>, tensor<2xf32>, tensor<1xi32>) -> tensor<4xf32> - func.return %0 : tensor<4xf32> -} - -//===----------------------------------------------------------------------===// -// AllToAll op legalizations. -//===----------------------------------------------------------------------===// - -// ----- - -// CHECK-LABEL: func @alltoall_basic -// See https://www.tensorflow.org/api_docs/python/tf/raw_ops/AllToAll -func.func @alltoall_basic(%input: tensor<1x2xf32>) -> tensor<2x1xf32> { - %group_assignment = "tf.Const" () { - value = dense<[[0, 1]]> : tensor<1x2xi32> - } : () -> tensor<1x2xi32> - %result = "tf.AllToAll"(%input, %group_assignment) {T = f32, concat_dimension = 0 : i64, split_count = 2 : i64, split_dimension = 1 : i64} : (tensor<1x2xf32>, tensor<1x2xi32>) -> tensor<2x1xf32> - // CHECK: mhlo.all_to_all - // CHECK-SAME{LITERAL}: replica_groups = dense<[[0, 1]]> : tensor<1x2xi64> - func.return %result : tensor<2x1xf32> -} - - -//===----------------------------------------------------------------------===// -// Cumsum op legalizations. -//===----------------------------------------------------------------------===// - -// ----- - -// CHECK-LABEL: func @cumsum_static -// CHECK-SAME: [[X:%.*]]: tensor<4xf32> -func.func @cumsum_static(%arg0: tensor<4xf32>) -> tensor<4xf32> { - // CHECK: [[AXIS:%.*]] = mhlo.constant dense<0> : tensor - // CHECK: [[CONVERT_X:%.*]] = mhlo.convert [[X]] : tensor<4xf32> - // CHECK: [[INIT:%.*]] = mhlo.constant dense<0.000000e+00> : tensor - // CHECK: [[REDUCE:%.*]] = "mhlo.reduce_window"([[CONVERT_X]], [[INIT]]) <{padding = dense<{{\[\[}}3, 0]]> : tensor<1x2xi64>, window_dimensions = dense<4> : tensor<1xi64>, window_strides = dense<1> : tensor<1xi64>}> ({ - // CHECK: ^bb0([[A:%.*]]: tensor, [[B:%.*]]: tensor): - // CHECK: [[SUM:%.*]] = mhlo.add [[A]], [[B]] : tensor - // CHECK: mhlo.return [[SUM]] : tensor - // CHECK: }) : (tensor<4xf32>, tensor) -> tensor<4xf32> - // CHECK: [[CONVERT_REDUCE:%.*]] = mhlo.convert [[REDUCE]] : tensor<4xf32> - // CHECK: return [[CONVERT_REDUCE]] - %0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = i32, value = dense<0> : tensor} : () -> tensor - %1 = "tf.Cumsum"(%arg0, %0) {exclusive = false, reverse = false} : (tensor<4xf32>, tensor) -> tensor<4xf32> - func.return %1 : tensor<4xf32> -} - -// ----- - -// CHECK-LABEL: func @cumsum_exclusive -// CHECK-SAME: [[X:%.*]]: tensor<4xf32> -func.func @cumsum_exclusive(%arg0: tensor<4xf32>) -> tensor<4xf32> { - // CHECK: [[AXIS:%.*]] = mhlo.constant dense<0> : tensor - // CHECK: [[CONVERT_X:%.*]] = mhlo.convert [[X]] : tensor<4xf32> - // CHECK: [[INIT:%.*]] = mhlo.constant dense<0.000000e+00> : tensor - // CHECK: [[REDUCE:%.*]] = "mhlo.reduce_window"([[CONVERT_X]], [[INIT]]) <{padding = dense<{{\[\[}}3, 0]]> : tensor<1x2xi64>, window_dimensions = dense<4> : tensor<1xi64>, window_strides = dense<1> : tensor<1xi64>}> ({ - // CHECK: ^bb0([[A:%.*]]: tensor, [[B:%.*]]: tensor): - // CHECK: [[SUM:%.*]] = mhlo.add [[A]], [[B]] : tensor - // CHECK: mhlo.return [[SUM]] : tensor - // CHECK: }) : (tensor<4xf32>, tensor) -> tensor<4xf32> - // CHECK: [[PAD:%.*]] = "mhlo.pad"([[REDUCE]], %{{.*}}) <{edge_padding_high = dense<-1> : tensor<1xi64>, edge_padding_low = dense<1> : tensor<1xi64>, interior_padding = dense<0> : tensor<1xi64>}> : (tensor<4xf32>, tensor) -> tensor<4xf32> - // CHECK: [[CONVERT_REDUCE:%.*]] = mhlo.convert [[PAD]] : tensor<4xf32> - // CHECK: return [[CONVERT_REDUCE]] - %0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = i32, value = dense<0> : tensor} : () -> tensor - %1 = "tf.Cumsum"(%arg0, %0) {exclusive = true, reverse = false} : (tensor<4xf32>, tensor) -> tensor<4xf32> - func.return %1 : tensor<4xf32> -} - -// ----- - -// CHECK-LABEL: func @cumsum_reverse -// CHECK-SAME: [[X:%.*]]: tensor<4xf32> -func.func @cumsum_reverse(%arg0: tensor<4xf32>) -> tensor<4xf32> { - // CHECK: [[AXIS:%.*]] = mhlo.constant dense<0> : tensor - // CHECK: [[REVERSE1:%.*]] = "mhlo.reverse"([[X]]) <{dimensions = dense<0> : tensor<1xi64>}> : (tensor<4xf32>) -> tensor<4xf32> - // CHECK: [[CONVERT_X:%.*]] = mhlo.convert [[REVERSE1]] : tensor<4xf32> - // CHECK: [[INIT:%.*]] = mhlo.constant dense<0.000000e+00> : tensor - // CHECK: [[REDUCE:%.*]] = "mhlo.reduce_window"([[CONVERT_X]], [[INIT]]) <{padding = dense<{{\[\[}}3, 0]]> : tensor<1x2xi64>, window_dimensions = dense<4> : tensor<1xi64>, window_strides = dense<1> : tensor<1xi64>}> ({ - // CHECK: ^bb0([[A:%.*]]: tensor, [[B:%.*]]: tensor): - // CHECK: [[SUM:%.*]] = mhlo.add [[A]], [[B]] : tensor - // CHECK: mhlo.return [[SUM]] : tensor - // CHECK: }) : (tensor<4xf32>, tensor) -> tensor<4xf32> - // CHECK: [[CONVERT_REDUCE:%.*]] = mhlo.convert [[REDUCE]] : tensor<4xf32> - // CHECK: [[REVERSE_BACK:%.*]] = "mhlo.reverse"([[CONVERT_REDUCE]]) <{dimensions = dense<0> : tensor<1xi64>}> : (tensor<4xf32>) -> tensor<4xf32> - // CHECK: return [[REVERSE_BACK]] - %0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = i32, value = dense<0> : tensor} : () -> tensor - %1 = "tf.Cumsum"(%arg0, %0) {exclusive = false, reverse = true} : (tensor<4xf32>, tensor) -> tensor<4xf32> - func.return %1 : tensor<4xf32> -} - -// ----- - -// CHECK-LABEL: func @cumsum_exclusive_reverse -// CHECK-SAME: [[X:%.*]]: tensor<4xf32> -func.func @cumsum_exclusive_reverse(%arg0: tensor<4xf32>) -> tensor<4xf32> { - // CHECK: [[AXIS:%.*]] = mhlo.constant dense<0> : tensor - // CHECK: [[REVERSE1:%.*]] = "mhlo.reverse"([[X]]) <{dimensions = dense<0> : tensor<1xi64>}> : (tensor<4xf32>) -> tensor<4xf32> - // CHECK: [[CONVERT_X:%.*]] = mhlo.convert [[REVERSE1]] : tensor<4xf32> - // CHECK: [[INIT:%.*]] = mhlo.constant dense<0.000000e+00> : tensor - // CHECK: [[REDUCE:%.*]] = "mhlo.reduce_window"([[CONVERT_X]], [[INIT]]) <{padding = dense<{{\[\[}}3, 0]]> : tensor<1x2xi64>, window_dimensions = dense<4> : tensor<1xi64>, window_strides = dense<1> : tensor<1xi64>}> ({ - // CHECK: ^bb0([[A:%.*]]: tensor, [[B:%.*]]: tensor): - // CHECK: [[SUM:%.*]] = mhlo.add [[A]], [[B]] : tensor - // CHECK: mhlo.return [[SUM]] : tensor - // CHECK: }) : (tensor<4xf32>, tensor) -> tensor<4xf32> - // CHECK: [[PAD:%.*]] = "mhlo.pad"([[REDUCE]], %{{.*}}) <{edge_padding_high = dense<-1> : tensor<1xi64>, edge_padding_low = dense<1> : tensor<1xi64>, interior_padding = dense<0> : tensor<1xi64>}> : (tensor<4xf32>, tensor) -> tensor<4xf32> - // CHECK: [[CONVERT_REDUCE:%.*]] = mhlo.convert [[PAD]] : tensor<4xf32> - // CHECK: [[REVERSE_BACK:%.*]] = "mhlo.reverse"([[CONVERT_REDUCE]]) <{dimensions = dense<0> : tensor<1xi64>}> : (tensor<4xf32>) -> tensor<4xf32> - // CHECK: return [[REVERSE_BACK]] - %0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = i32, value = dense<0> : tensor} : () -> tensor - %1 = "tf.Cumsum"(%arg0, %0) {exclusive = true, reverse = true} : (tensor<4xf32>, tensor) -> tensor<4xf32> - func.return %1 : tensor<4xf32> -} - -// ----- - -// CHECK-LABEL: func @cumsum_empty -func.func @cumsum_empty(%arg0: tensor<0xf32>) -> tensor<0xf32> { - %0 = "tf.Const"() {value = dense<0> : tensor} : () -> tensor - - // CHECK: mhlo.constant dense<> : tensor<0xf32> - %1 = "tf.Cumsum"(%arg0, %0) : (tensor<0xf32>, tensor) -> tensor<0xf32> - func.return %1 : tensor<0xf32> -} - -// ----- - -// CHECK-LABEL: func @cumsum_dynamic -func.func @cumsum_dynamic(%arg0: tensor, %arg1: tensor) -> tensor { - // CHECK: "tf.Cumsum" - %0 = "tf.Cumsum"(%arg0, %arg1) : (tensor, tensor) -> tensor - func.return %0 : tensor -} - -//===----------------------------------------------------------------------===// -// Cumprod op legalizations. -//===----------------------------------------------------------------------===// - -// ----- - -// CHECK-LABEL: func @cumprod -func.func @cumprod(%arg0: tensor<4xf32>) -> tensor<4xf32> { - // CHECK: [[INIT:%.*]] = mhlo.constant dense<1.000000e+00> : tensor - // CHECK: "mhlo.reduce_window"({{.*}}, [[INIT]]) - // CHECK: mhlo.mul - %0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = i32, value = dense<0> : tensor} : () -> tensor - %1 = "tf.Cumprod"(%arg0, %0) {exclusive = false, reverse = false} : (tensor<4xf32>, tensor) -> tensor<4xf32> - func.return %1 : tensor<4xf32> -} - -//===----------------------------------------------------------------------===// -// tf.Softplus legalization -//===----------------------------------------------------------------------===// - -// ----- - -// CHECK-LABEL: func @softplus_f16 -// CHECK-SAME: ([[FEATURES:%.*]]: tensor<8x16xf16>) -func.func @softplus_f16(%arg0: tensor<8x16xf16>) -> tensor<8x16xf16> { - // CHECK-DAG: [[FEATURES_EXP:%.*]] = mhlo.exponential [[FEATURES]] - // CHECK-DAG: [[EPSILON:%.*]] = mhlo.constant dense<1.220700e-04> : tensor - // CHECK-DAG: [[EPSILON_LOG:%.*]] = mhlo.log [[EPSILON]] - // CHECK-DAG: [[TWO:%.*]] = mhlo.constant dense<2.000000e+00> : tensor - // CHECK: [[THRESHOLD:%.*]] = chlo.broadcast_add [[EPSILON_LOG]], [[TWO]] - // CHECK: [[NEG_THRESHOLD:%.*]] = mhlo.negate [[THRESHOLD]] - // CHECK-DAG: [[COMPARE_GT:%.*]] = chlo.broadcast_compare [[FEATURES]], [[NEG_THRESHOLD]] {comparison_direction = #chlo} - // CHECK-DAG: [[COMPARE_LT:%.*]] = chlo.broadcast_compare [[FEATURES]], [[THRESHOLD]] {comparison_direction = #chlo} - // CHECK-DAG: [[FEATURES_EXP_LOG:%.*]] = mhlo.log_plus_one [[FEATURES_EXP]] - // CHECK: [[ELSE_SELECT:%.*]] = mhlo.select [[COMPARE_LT]], [[FEATURES_EXP]], [[FEATURES_EXP_LOG]] - // CHECK: [[ENTRY_SELECT:%.*]] = mhlo.select [[COMPARE_GT]], [[FEATURES]], [[ELSE_SELECT]] - %0 = "tf.Softplus"(%arg0) : (tensor<8x16xf16>) -> tensor<8x16xf16> - - // CHECK: return [[ENTRY_SELECT]] : tensor<8x16xf16> - func.return %0 : tensor<8x16xf16> -} - -// ----- - -// CHECK-LABEL: func @softplus_bf16 -// CHECK-SAME: ([[FEATURES:%.*]]: tensor<8x16xbf16>) -func.func @softplus_bf16(%arg0: tensor<8x16xbf16>) -> tensor<8x16xbf16> { - // CHECK-DAG: [[FEATURES_EXP:%.*]] = mhlo.exponential [[FEATURES]] - // CHECK-DAG: [[EPSILON:%.*]] = mhlo.constant dense<7.812500e-03> : tensor - // CHECK-DAG: [[EPSILON_LOG:%.*]] = mhlo.log [[EPSILON]] - // CHECK-DAG: [[TWO:%.*]] = mhlo.constant dense<2.000000e+00> : tensor - // CHECK: [[THRESHOLD:%.*]] = chlo.broadcast_add [[EPSILON_LOG]], [[TWO]] - // CHECK: [[NEG_THRESHOLD:%.*]] = mhlo.negate [[THRESHOLD]] - // CHECK-DAG: [[COMPARE_GT:%.*]] = chlo.broadcast_compare [[FEATURES]], [[NEG_THRESHOLD]] {comparison_direction = #chlo} - // CHECK-DAG: [[COMPARE_LT:%.*]] = chlo.broadcast_compare [[FEATURES]], [[THRESHOLD]] {comparison_direction = #chlo} - // CHECK-DAG: [[FEATURES_EXP_LOG:%.*]] = mhlo.log_plus_one [[FEATURES_EXP]] - // CHECK: [[ELSE_SELECT:%.*]] = mhlo.select [[COMPARE_LT]], [[FEATURES_EXP]], [[FEATURES_EXP_LOG]] - // CHECK: [[ENTRY_SELECT:%.*]] = mhlo.select [[COMPARE_GT]], [[FEATURES]], [[ELSE_SELECT]] - %0 = "tf.Softplus"(%arg0) : (tensor<8x16xbf16>) -> tensor<8x16xbf16> - - // CHECK: return [[ENTRY_SELECT]] : tensor<8x16xbf16> - func.return %0 : tensor<8x16xbf16> -} - -// ----- - -// CHECK-LABEL: func @softplus_f32 -// CHECK-SAME: ([[FEATURES:%.*]]: tensor<8x16xf32>) -func.func @softplus_f32(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> { - // CHECK-DAG: [[FEATURES_EXP:%.*]] = mhlo.exponential [[FEATURES]] - // CHECK-DAG: [[EPSILON:%.*]] = mhlo.constant dense<1.1920929E-7> : tensor - // CHECK-DAG: [[EPSILON_LOG:%.*]] = mhlo.log [[EPSILON]] - // CHECK-DAG: [[TWO:%.*]] = mhlo.constant dense<2.000000e+00> : tensor - // CHECK: [[THRESHOLD:%.*]] = chlo.broadcast_add [[EPSILON_LOG]], [[TWO]] - // CHECK: [[NEG_THRESHOLD:%.*]] = mhlo.negate [[THRESHOLD]] - // CHECK-DAG: [[COMPARE_GT:%.*]] = chlo.broadcast_compare [[FEATURES]], [[NEG_THRESHOLD]] {comparison_direction = #chlo} - // CHECK-DAG: [[COMPARE_LT:%.*]] = chlo.broadcast_compare [[FEATURES]], [[THRESHOLD]] {comparison_direction = #chlo} - // CHECK-DAG: [[FEATURES_EXP_LOG:%.*]] = mhlo.log_plus_one [[FEATURES_EXP]] - // CHECK: [[ELSE_SELECT:%.*]] = mhlo.select [[COMPARE_LT]], [[FEATURES_EXP]], [[FEATURES_EXP_LOG]] - // CHECK: [[ENTRY_SELECT:%.*]] = mhlo.select [[COMPARE_GT]], [[FEATURES]], [[ELSE_SELECT]] - %0 = "tf.Softplus"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32> - - // CHECK: return [[ENTRY_SELECT]] : tensor<8x16xf32> - func.return %0 : tensor<8x16xf32> -} - -// ----- - -// CHECK-LABEL: func @softplus_f64 -// CHECK-SAME: ([[FEATURES:%.*]]: tensor<8x16xf64>) -func.func @softplus_f64(%arg0: tensor<8x16xf64>) -> tensor<8x16xf64> { - // CHECK-DAG: [[FEATURES_EXP:%.*]] = mhlo.exponential [[FEATURES]] - // CHECK-DAG: [[EPSILON:%.*]] = mhlo.constant dense<2.2204460492503131E-16> : tensor - // CHECK-DAG: [[EPSILON_LOG:%.*]] = mhlo.log [[EPSILON]] - // CHECK-DAG: [[TWO:%.*]] = mhlo.constant dense<2.000000e+00> : tensor - // CHECK: [[THRESHOLD:%.*]] = chlo.broadcast_add [[EPSILON_LOG]], [[TWO]] - // CHECK: [[NEG_THRESHOLD:%.*]] = mhlo.negate [[THRESHOLD]] - // CHECK-DAG: [[COMPARE_GT:%.*]] = chlo.broadcast_compare [[FEATURES]], [[NEG_THRESHOLD]] {comparison_direction = #chlo} - // CHECK-DAG: [[COMPARE_LT:%.*]] = chlo.broadcast_compare [[FEATURES]], [[THRESHOLD]] {comparison_direction = #chlo} - // CHECK-DAG: [[FEATURES_EXP_LOG:%.*]] = mhlo.log_plus_one [[FEATURES_EXP]] - // CHECK: [[ELSE_SELECT:%.*]] = mhlo.select [[COMPARE_LT]], [[FEATURES_EXP]], [[FEATURES_EXP_LOG]] - // CHECK: [[ENTRY_SELECT:%.*]] = mhlo.select [[COMPARE_GT]], [[FEATURES]], [[ELSE_SELECT]] - %0 = "tf.Softplus"(%arg0) : (tensor<8x16xf64>) -> tensor<8x16xf64> - - // CHECK: return [[ENTRY_SELECT]] : tensor<8x16xf64> - func.return %0 : tensor<8x16xf64> -} - -// ----- - -// CHECK-LABEL: @xla_gather -func.func @xla_gather(%arg0: tensor<200x100x300xf32>, %arg1: tensor<10x2xi32>) -> tensor<1x300x10xf32> { - %cst = "tf.Const"() { value = dense<[1, 1, 300]> : tensor<3xi64> } : () -> tensor<3xi64> - - // CHECK: "mhlo.gather" - // CHECK-SAME: dimension_numbers = - // CHECK-SAME: offset_dims = [0, 1] - // CHECK-SAME: collapsed_slice_dims = [0] - // CHECK-SAME: start_index_map = [0, 1] - // CHECK-SAME: index_vector_dim = 1 - // CHECK-SAME: indices_are_sorted = true - // CHECK-SAME: slice_sizes = dense<[1, 1, 300]> : tensor<3xi64> - - %0 = "tf.XlaGather"(%arg0, %arg1, %cst) {dimension_numbers = "\0A\02\00\01\12\01\00\1A\02\00\01\20\01", indices_are_sorted = true} : (tensor<200x100x300xf32>, tensor<10x2xi32>, tensor<3xi64>) -> tensor<1x300x10xf32> - func.return %0 : tensor<1x300x10xf32> -} - -// ----- - -// CHECK-LABEL: @xla_gather_i32 -func.func @xla_gather_i32(%arg0: tensor<200x100x300xf32>, %arg1: tensor<10x2xi32>) -> tensor<1x300x10xf32> { - %cst = "tf.Const"() { value = dense<[1, 1, 300]> : tensor<3xi32> } : () -> tensor<3xi32> - - // CHECK: "mhlo.gather" - // CHECK-SAME: dimension_numbers = - // CHECK-SAME: offset_dims = [0, 1] - // CHECK-SAME: collapsed_slice_dims = [0] - // CHECK-SAME: start_index_map = [0, 1] - // CHECK-SAME: index_vector_dim = 1 - // CHECK-SAME: indices_are_sorted = true - // CHECK-SAME: slice_sizes = dense<[1, 1, 300]> : tensor<3xi64> - - %0 = "tf.XlaGather"(%arg0, %arg1, %cst) {dimension_numbers = "\0A\02\00\01\12\01\00\1A\02\00\01\20\01", indices_are_sorted = true} : (tensor<200x100x300xf32>, tensor<10x2xi32>, tensor<3xi32>) -> tensor<1x300x10xf32> - func.return %0 : tensor<1x300x10xf32> -} - - -// CHECK: func @stridedslice_with_i32 -func.func @stridedslice_with_i32(%arg0: tensor) -> tensor<4xf32> attributes {tf.entry_function = {control_outputs = "", inputs = "const_0_arg", outputs = "identity_0_retval_RetVal"}} { -// CHECK-NOT: tf.StridedSlice -// CHECK: [[DYNSLICE:%.*]] = "mhlo.dynamic_slice -// CHECK: [[RESHAPE:%.*]] = mhlo.reshape [[DYNSLICE]] -// CHECK: return [[RESHAPE]] - %0 = "tf.Const"() {value = dense<[[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00, 7.000000e+00]]> : tensor<2x4xf32>} : () -> tensor<2x4xf32> - %1 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor - %2 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> - %3 = "tf.AddV2"(%arg0, %1) {_xla_inferred_shapes = [#tf_type.shape<>], device = ""} : (tensor, tensor) -> tensor - %4 = "tf.Pack"(%3) {_xla_inferred_shapes = [#tf_type.shape<1>], axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> - %5 = "tf.Pack"(%arg0) {_xla_inferred_shapes = [#tf_type.shape<1>], axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> - %6 = "tf.StridedSlice"(%0, %5, %4, %2) {_xla_inferred_shapes = [#tf_type.shape<4>], begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2x4xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xf32> - func.return %6 : tensor<4xf32> -} - -func.func @replica_id() -> tensor { - // CHECK: %[[ID:.*]] = mhlo.replica_id : tensor - // CHECK: %[[RESULT:.*]] = mhlo.convert %0 : (tensor) -> tensor - %0 = "tf.XlaReplicaId"() : () -> tensor - func.return %0 : tensor -} - -// CHECK: func @angle_c64 -// CHECK-SAME: ([[ARG0:%.*]]: tensor>) -func.func @angle_c64(%arg0: tensor>) -> tensor { -// CHECK: [[IMAG:%.*]] = mhlo.imag [[ARG0]] -// CHECK: [[REAL:%.*]] = mhlo.real [[ARG0]] -// CHECK: [[ATAN2:%.*]] = mhlo.atan2 [[IMAG]], [[REAL]] - %0 = "tf.Angle"(%arg0): (tensor>) -> tensor - func.return %0 : tensor -} - -//===----------------------------------------------------------------------===// -// tf.ApproximateEqual legalization -//===----------------------------------------------------------------------===// - -// CHECK-LABEL: func @approximateequal_f64 -func.func @approximateequal_f64(%arg0: tensor, %arg1: tensor) -> tensor { - // CHECK: %[[SUB:.*]] = mhlo.subtract %arg0, %arg1 : tensor - // CHECK: %[[ABS:.*]] = mhlo.abs %[[SUB]] : tensor - // CHECK: %[[CST:.*]] = mhlo.constant dense<2.000000e+00> : tensor - // CHECK: %[[CONVERT:.*]] = mhlo.convert %[[CST]] : (tensor) -> tensor - // CHECK: %[[LE:.*]] = chlo.broadcast_compare %[[ABS]], %[[CONVERT]] {comparison_direction = #chlo} : (tensor, tensor) -> tensor - // CHECK: return %[[LE]] : tensor - %equal = "tf.ApproximateEqual"(%arg0, %arg1) { tolerance = 2. : f32 } : (tensor, tensor) -> tensor - func.return %equal : tensor -} - -// CHECK-LABEL: func @approximateequal_i32 -func.func @approximateequal_i32(%arg0: tensor, %arg1: tensor) -> tensor { - // CHECK: %[[SUB:.*]] = mhlo.subtract %arg0, %arg1 : tensor - // CHECK: %[[ABS:.*]] = mhlo.abs %[[SUB]] : tensor - // CHECK: %[[CST:.*]] = mhlo.constant dense<2.000000e+00> : tensor - // CHECK: %[[CONVERT:.*]] = mhlo.convert %[[CST]] : (tensor) -> tensor - // CHECK: %[[LE:.*]] = chlo.broadcast_compare %[[ABS]], %[[CONVERT]] {comparison_direction = #chlo} : (tensor, tensor) -> tensor - // CHECK: return %[[LE]] : tensor - %equal = "tf.ApproximateEqual"(%arg0, %arg1) { tolerance = 2. : f32 } : (tensor, tensor) -> tensor - func.return %equal : tensor -} - -// CHECK-LABEL: func @approximateequal_complex64 -func.func @approximateequal_complex64(%arg0: tensor>, %arg1: tensor>) -> tensor { - // CHECK: %[[SUB:.*]] = mhlo.subtract %arg0, %arg1 : tensor> - // CHECK: %[[ABS:.*]] = mhlo.abs %[[SUB]] : (tensor>) -> tensor - // CHECK: %[[CST:.*]] = mhlo.constant dense<2.000000e+00> : tensor - // CHECK: %[[CONVERT:.*]] = mhlo.convert %[[CST]] : tensor - // CHECK: %[[LE:.*]] = chlo.broadcast_compare %[[ABS]], %[[CONVERT]] {comparison_direction = #chlo} : (tensor, tensor) -> tensor - // CHECK: return %[[LE]] : tensor - %equal = "tf.ApproximateEqual"(%arg0, %arg1) { tolerance = 2. : f32 } : (tensor>, tensor>) -> tensor - func.return %equal : tensor -} - -//===----------------------------------------------------------------------===// -// tf.XlaConvV2 legalization -//===----------------------------------------------------------------------===// - -// ----- - -// CHECK-LABEL: xla_conv_v2 -func.func @xla_conv_v2(%lhs: tensor<8x4x16x16x16xf32>, %rhs: tensor<4x3x3x16x16xf32>) -> (tensor<4x4x14x14x16xf32>) { - %feature_group_count = "tf.Const"() {value = dense<1> : tensor} : () -> tensor - %lhs_dilation = "tf.Const"() {value = dense<[4, 1, 1]> : tensor<3xi32>} : () -> tensor<3xi32> - %rhs_dilation = "tf.Const"() {value = dense<1> : tensor<3xi32>} : () -> tensor<3xi32> - %padding = "tf.Const"() {value = dense<0> : tensor<3x2xi32>} : () -> tensor<3x2xi32> - %strides = "tf.Const"() {value = dense<[3, 1, 1]> : tensor<3xi32>} : () -> tensor<3xi32> - // CHECK: mhlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, 1, 2, f]x[0, 1, 2, i, o]->[b, 0, 1, 2, f], window = {stride = [3, 1, 1], pad = {{\[\[}}0, 0], {{\[}}0, 0], {{\[}}0, 0]], lhs_dilate = [4, 1, 1], rhs_dilate = [1, 1, 1]} {batch_group_count = 2 : i64, feature_group_count = 1 : i64, precision_config = []} : (tensor<8x4x16x16x16xf32>, tensor<4x3x3x16x16xf32>) -> tensor<4x4x14x14x16xf32> - %0 = "tf.XlaConvV2"(%lhs, %rhs, %strides, %padding, %lhs_dilation, %rhs_dilation, %feature_group_count) {batch_group_count = 2 : i64, dimension_numbers = "\18\03 \042\03\00\01\02@\04P\04Z\03\01\02\03b\03\01\02\03", precision_config = ""} : (tensor<8x4x16x16x16xf32>, tensor<4x3x3x16x16xf32>, tensor<3xi32>, tensor<3x2xi32>, tensor<3xi32>, tensor<3xi32>, tensor) -> tensor<4x4x14x14x16xf32> - func.return %0 : tensor<4x4x14x14x16xf32> -} - -//===----------------------------------------------------------------------===// -// tf.XlaDot legalization -//===----------------------------------------------------------------------===// - -// ----- - -// CHECK-LABEL: @xladot_matmul( -// CHECK-SAME: %[[LHS:.*]]: tensor<64x32xi8>, %[[RHS:.*]]: tensor<32x16xi8>) -> tensor<64x16xi32> -func.func @xladot_matmul(%lhs : tensor<64x32xi8>, %rhs : tensor<32x16xi8>) -> tensor<64x16xi32> { - // CHECK: "mhlo.dot_general"(%[[LHS]], %[[RHS]]) <{ - // CHECK-SAME: dot_dimension_numbers = #mhlo.dot< - // CHECK-NOT: lhs_batching_dimensions = - // CHECK-NOT: rhs_batching_dimensions = - // CHECK-SAME: lhs_contracting_dimensions = [1] - // CHECK-SAME: rhs_contracting_dimensions = [0] - // CHECK-SAME: precision_config = [] - %res = "tf.XlaDot"(%lhs, %rhs) {dimension_numbers = "\0A\01\01\12\01\00", precision_config = ""} : (tensor<64x32xi8>, tensor<32x16xi8>) -> tensor<64x16xi32> - func.return %res : tensor<64x16xi32> -} - -//===----------------------------------------------------------------------===// -// tf.XlaDotV2 legalization -//===----------------------------------------------------------------------===// - -// ----- - -// CHECK-LABEL: @xladotv2_matmul( -// CHECK-SAME: %[[LHS:.*]]: tensor<64x32xi8>, %[[RHS:.*]]: tensor<32x16xi8>) -> tensor<64x16xi32> -func.func @xladotv2_matmul(%lhs : tensor<64x32xi8>, %rhs : tensor<32x16xi8>) -> tensor<64x16xi32> { - // CHECK: "mhlo.dot_general"(%[[LHS]], %[[RHS]]) <{ - // CHECK-SAME: dot_dimension_numbers = #mhlo.dot< - // CHECK-NOT: lhs_batching_dimensions = - // CHECK-NOT: rhs_batching_dimensions = - // CHECK-SAME: lhs_contracting_dimensions = [1] - // CHECK-SAME: rhs_contracting_dimensions = [0] - // CHECK-SAME: precision_config = [] - %res = "tf.XlaDotV2"(%lhs, %rhs) {dimension_numbers = "\0A\01\01\12\01\00", precision_config = ""} : (tensor<64x32xi8>, tensor<32x16xi8>) -> tensor<64x16xi32> - func.return %res : tensor<64x16xi32> -} - -//===----------------------------------------------------------------------===// -// tf.XlaDynamicSlice legalization -//===----------------------------------------------------------------------===// -// ----- - -// CHECK-LABEL: xla_dynamic_slice_constant_start -func.func @xla_dynamic_slice_constant_start(%arg0: tensor<4xi32>) -> tensor<2xi32> { - // CHECK: %[[START:.*]] = mhlo.constant dense<1> : tensor - // CHECK-DAG-SAME: {limit_indices = dense<1> : tensor<1xi64>, - // CHECK-DAG-SAME: start_indices = dense<0> : tensor<1xi64>, - // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>} : - // CHECK-DAG-SAME: (tensor<1xi64>) -> tensor<1xi64> - // CHECK-DAG-SAME: (tensor<1xi64>) -> tensor - // CHECK-NEXT: %[[RESULT:.*]] = "mhlo.dynamic_slice"(%arg0, %[[START]]) - // CHECK-DAG-SAME: {slice_sizes = dense<2> : tensor<1xi64>} : - // CHECK-DAG-SAME: (tensor<4xi32>, tensor) -> tensor<2xi32> - // CHECK-NEXT: return %[[RESULT]] : tensor<2xi32> - %starts = "tf.Const"() {value = dense<[1]> : tensor<1xi64>} : () -> (tensor<1xi64>) - %sizes = "tf.Const"() {value = dense<[2]> : tensor<1xi64>} : () -> (tensor<1xi64>) - %0 = "tf.XlaDynamicSlice"(%arg0, %starts, %sizes) : (tensor<4xi32>, tensor<1xi64>, tensor<1xi64>) -> tensor<2xi32> - func.return %0 : tensor<2xi32> -} - -// ----- - -// CHECK-LABEL: xla_dynamic_slice_i32_consts -func.func @xla_dynamic_slice_i32_consts(%arg0: tensor<4xi32>) -> tensor<2xi32> { - // CHECK: %[[START:.*]] = mhlo.constant dense<1> : tensor - // CHECK: "mhlo.dynamic_slice"(%arg0, %[[START]]) <{slice_sizes = dense<2> : tensor<1xi64>}> : (tensor<4xi32>, tensor) -> tensor<2xi32> - %starts = "tf.Const"() {value = dense<[1]> : tensor<1xi32>} : () -> (tensor<1xi32>) - %sizes = "tf.Const"() {value = dense<[2]> : tensor<1xi32>} : () -> (tensor<1xi32>) - %0 = "tf.XlaDynamicSlice"(%arg0, %starts, %sizes) : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> - func.return %0 : tensor<2xi32> -} - -// ----- - -// CHECK-LABEL: xla_dynamic_slice_constant_start_dynamic_shape -func.func @xla_dynamic_slice_constant_start_dynamic_shape(%arg0: tensor, %arg1: tensor<2xi64>) -> tensor<1x4xi32> { - // CHECK-DAG: %[[START1:.*]] = mhlo.constant dense<1> : tensor - // CHECK-DAG: %[[START2:.*]] = mhlo.constant dense<0> : tensor - // CHECK: %[[RESULT:.*]] = "mhlo.dynamic_slice" - // CHECK-DAG-SAME: (%arg0, %[[START1]], %[[START2]]) - // CHECK-DAG-SAME: {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : - // CHECK-DAG-SAME: (tensor, tensor, tensor) -> tensor<1x4xi32> - // CHECK: return %[[RESULT]] : tensor<1x4xi32> - %starts = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> (tensor<2xi64>) - %sizes = "tf.Const"() {value = dense<[1, 4]> : tensor<2xi64>} : () -> (tensor<2xi64>) - %0 = "tf.XlaDynamicSlice"(%arg0, %starts, %sizes) : (tensor, tensor<2xi64>, tensor<2xi64>) -> tensor<1x4xi32> - func.return %0 : tensor<1x4xi32> -} - -// ----- - -// CHECK-LABEL: xla_dynamic_slice_variable_start -func.func @xla_dynamic_slice_variable_start(%arg0: tensor<3x4xi32>, %arg1: tensor<2xi64>) -> tensor<1x4xi32> { - // CHECK: %[[SLICED_START1:.*]] = "mhlo.slice"(%arg1) - // CHECK-DAG-SAME: {limit_indices = dense<1> : tensor<1xi64>, - // CHECK-DAG-SAME: start_indices = dense<0> : tensor<1xi64>, - // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>} : (tensor<2xi64>) -> tensor<1xi64> - // CHECK: %[[RESHAPED_START1:.*]] = mhlo.reshape %[[SLICED_START1]] : (tensor<1xi64>) -> tensor - // CHECK: %[[SLICED_START2:.*]] = "mhlo.slice"(%arg1) - // CHECK-DAG-SAME: {limit_indices = dense<2> : tensor<1xi64>, - // CHECK-DAG-SAME: start_indices = dense<1> : tensor<1xi64>, - // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>} : (tensor<2xi64>) -> tensor<1xi64> - // CHECK: %[[RESHAPED_START2:.*]] = mhlo.reshape %[[SLICED_START2]] : (tensor<1xi64>) -> tensor - // CHECK: %[[RESULT:.*]] = "mhlo.dynamic_slice"(%arg0, %[[RESHAPED_START1]], %[[RESHAPED_START2]]) <{slice_sizes = dense<[1, 4]> : tensor<2xi64>}> : (tensor<3x4xi32>, tensor, tensor) -> tensor<1x4xi32> - // CHECK: return %[[RESULT]] : tensor<1x4xi32> - %sizes = "tf.Const"() {value = dense<[1, 4]> : tensor<2xi64>} : () -> (tensor<2xi64>) - %0 = "tf.XlaDynamicSlice"(%arg0, %arg1, %sizes) : (tensor<3x4xi32>, tensor<2xi64>, tensor<2xi64>) -> tensor<1x4xi32> - func.return %0 : tensor<1x4xi32> -} - -// ----- - -// CHECK-LABEL: xla_dynamic_slice_mhlo_sizes -func.func @xla_dynamic_slice_mhlo_sizes(%arg0: tensor<1x1024x4xf32>, %arg1: tensor<3xi32>) -> tensor<1x512x4xf32> { - // CHECK-NOT: "tf.XlaDynamicSlice" - %0 = "mhlo.constant"() {value = dense<[1, 512, 4]> : tensor<3xi32>} : () -> tensor<3xi32> - %1 = "tf.XlaDynamicSlice"(%arg0, %arg1, %0) : (tensor<1x1024x4xf32>, tensor<3xi32>, tensor<3xi32>) -> tensor<1x512x4xf32> - func.return %1 : tensor<1x512x4xf32> -} - -//===----------------------------------------------------------------------===// -// tf.XlaEinsum legalization -//===----------------------------------------------------------------------===// - -// ----- - -// CHECK-LABEL: func @xlaeinsum -func.func @xlaeinsum(%arg0: tensor<2x3xf32>, %arg1: tensor<3x4xf32>) -> tensor<2x4xf32> { - // CHECK-NEXT: mhlo.einsum - %0 = "tf.XlaEinsum"(%arg0, %arg1) {equation = "ab,bc->ac"} : (tensor<2x3xf32>, tensor<3x4xf32>) -> tensor<2x4xf32> - func.return %0: tensor<2x4xf32> -} - - -//===----------------------------------------------------------------------===// -// tf.XlaReduceWindow legalization -//===----------------------------------------------------------------------===// -// ----- -// CHECK-LABEL: @test_xla_reduce_window -func.func @test_xla_reduce_window(%arg0: tensor<7xf32>, %arg1: tensor) -> tensor<10xf32> { - %cst = "tf.Const"() {value = dense<0> : tensor<1x2xi32>} : () -> tensor<1x2xi32> - %cst_0 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> - %cst_1 = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32> - %cst_2 = "tf.Const"() {value = dense<3> : tensor<1xi32>} : () -> tensor<1xi32> - %cst_3 = "tf.Const"() {value = dense<4> : tensor<1xi32>} : () -> tensor<1xi32> - // CHECK: %[[REDUCE:.*]] = "mhlo.reduce_window"(%arg0, %arg1) <{base_dilations = dense<3> : tensor<1xi64>, padding = dense<0> : tensor<1x2xi64>, window_dilations = dense<4> : tensor<1xi64>, window_dimensions = dense<1> : tensor<1xi64>, window_strides = dense<2> : tensor<1xi64>}> ({ - // CHECK-NEXT: ^{{.*}}(%[[ARG0:.*]]: tensor, %[[ARG1:.*]]: tensor) - // CHECK-NEXT: %[[SUM:.*]] = func.call @sum_reducer3(%[[ARG0]], %[[ARG1]]){{.*}} - // CHECK-NEXT: mhlo.return %[[SUM]] : tensor - // CHECK-NEXT: }) : (tensor<7xf32>, tensor) -> tensor<10xf32> - // CHECK-NEXT: return %[[REDUCE]] - %0 = "tf.XlaReduceWindow"(%arg0, %arg1, %cst_0, %cst_1, %cst_2, %cst_3, %cst) {computation = @sum_reducer3} : (tensor<7xf32>, tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1x2xi32>) -> tensor<10xf32> - func.return %0 : tensor<10xf32> -} - -func.func private @sum_reducer3(%arg0: tensor, %arg1: tensor) -> tensor { - %0 = "tf.AddV2"(%arg0, %arg1) {device = ""} : (tensor, tensor) -> tensor - func.return %0 : tensor -} - -//===----------------------------------------------------------------------===// -// tf.XlaSort legalization -//===----------------------------------------------------------------------===// - -// ----- - -// CHECK-LABEL: @xlasort_int -// CHECK-SAME: %[[INPUT:.*]]: tensor<16xi32> -func.func @xlasort_int(%input: tensor<16xi32>) -> (tensor<16xi32>) { - // CHECK-NEXT: %[[SORT:.*]] = "mhlo.sort"(%[[INPUT]]) <{dimension = -1 : i64, is_stable = false}> ({ - // CHECK-NEXT: ^{{.*}}(%[[LHS:.*]]: tensor, %[[RHS:.*]]: tensor) - // CHECK-NEXT: %[[CMP:.*]] = mhlo.compare LT, %[[LHS]], %[[RHS]], NOTYPE - // CHECK-NEXT: mhlo.return %[[CMP]] - // CHECK-NEXT: }) : (tensor<16xi32>) -> tensor<16xi32> - // CHECK-NEXT: return %[[SORT]] - %output = "tf.XlaSort"(%input) : (tensor<16xi32>) -> (tensor<16xi32>) - func.return %output : tensor<16xi32> -} - -// ----- - -// CHECK-LABEL: @xlasort_float -// CHECK-SAME: %[[INPUT:.*]]: tensor<8xf64> -func.func @xlasort_float(%input: tensor<8xf64>) -> (tensor<8xf64>) { - // CHECK-NEXT: %[[SORT:.*]] = "mhlo.sort"(%[[INPUT]]) <{dimension = -1 : i64, is_stable = false}> ({ - // CHECK-NEXT: ^{{.*}}(%[[LHS:.*]]: tensor, %[[RHS:.*]]: tensor) - // CHECK-NEXT: %[[CMP:.*]] = mhlo.compare LT, %[[LHS]], %[[RHS]], TOTALORDER - // CHECK-NEXT: mhlo.return %[[CMP]] - // CHECK-NEXT: }) : (tensor<8xf64>) -> tensor<8xf64> - // CHECK-NEXT: return %[[SORT]] - %output = "tf.XlaSort"(%input) : (tensor<8xf64>) -> (tensor<8xf64>) - func.return %output : tensor<8xf64> -} - -// ----- - -// CHECK-LABEL: @xlasort_const -func.func @xlasort_const() -> (tensor<2x3xi64>) { - // CHECK: [2, 4, 3], [6, 5, 1] - %input = "tf.Const"() {value = dense<[[2, 4, 3], [6, 5, 1]]> : tensor<2x3xi64>} : () -> (tensor<2x3xi64>) - // CHECK-NEXT: [2, 3, 4], [1, 5, 6] - %output = "tf.XlaSort"(%input): (tensor<2x3xi64>) -> (tensor<2x3xi64>) - func.return %output : tensor<2x3xi64> -} - -//===----------------------------------------------------------------------===// -// tf.XlaRngBitGenerator legalization -//===----------------------------------------------------------------------===// - -// CHECK-LABEL: @xla_rng_bit_generator -// CHECK-SAME: %[[STATE:.*]]: tensor<2xui64> -func.func @xla_rng_bit_generator(%arg0: tensor<2xui64>) -> (tensor<2xui64>, tensor<10x12xui32>) attributes {tf.entry_function = {control_outputs = "", inputs = "_arg0,_arg1,_arg2", outputs = "_retval0,_retval1"}} { - // CHECK-NEXT: %0 = mhlo.constant dense<[10, 12]> : tensor<2xi32> - %cst = "tf.Const"() {value = dense<[10, 12]> : tensor<2xi32>} : () -> tensor<2xi32> - // CHECK-NEXT: %1 = mhlo.constant dense<3> : tensor - %cst_0 = "tf.Const"() {value = dense<3> : tensor} : () -> tensor - // CHECK-NEXT: %[[OUTPUT_STATE:.*]], %[[OUTPUT:.*]] = "mhlo.rng_bit_generator"(%[[STATE]]) <{rng_algorithm = #mhlo.rng_algorithm}> : (tensor<2xui64>) -> (tensor<2xui64>, tensor<10x12xui32>) - // CHECK-NEXT: return %[[OUTPUT_STATE]], %[[OUTPUT]] : tensor<2xui64>, tensor<10x12xui32> - %output_key, %output = "tf.XlaRngBitGenerator"(%cst_0, %arg0, %cst) : (tensor, tensor<2xui64>, tensor<2xi32>) -> (tensor<2xui64>, tensor<10x12xui32>) - func.return %output_key, %output : tensor<2xui64>, tensor<10x12xui32> -} - -//===----------------------------------------------------------------------===// -// tf.XlaVariadicV2 legalization -//===----------------------------------------------------------------------===// - -// ----- -// CHECK-LABEL: @xla_variadic_reduce_v2 -func.func @xla_variadic_reduce_v2(%arg0: tensor<2x3xcomplex>, %arg1: tensor>) -> tensor<3xcomplex> attributes {tf.entry_function = {control_outputs = "", inputs = "_arg0,_arg1", outputs = "_retval0"}} { - // CHECK: %[[REDUCE:.*]] = mhlo.reduce(%arg0 init: %arg1) - // CHECK-SAME: dimensions = [0] - // CHECK-NEXT: (%[[ARG0:.*]]: tensor>, %[[ARG1:.*]]: tensor>) - // CHECK-NEXT: %[[SUM:.*]] = func.call @sum_reducer(%[[ARG0]], %[[ARG1]]){{.*}} - // CHECK-NEXT: mhlo.return %[[SUM]] : tensor> - // CHECK: return %[[REDUCE]] - %0 = "tf.XlaVariadicReduceV2"(%arg0, %arg1) {_XlaHasReferenceVars = false, device = "/job:localhost/replica:0/task:0/device:XLA_GPU:0", dimensions_to_reduce = [0], operandSegmentSizes = array, reducer = @sum_reducer} : (tensor<2x3xcomplex>, tensor>) -> tensor<3xcomplex> - func.return %0 : tensor<3xcomplex> -} - -func.func private @sum_reducer(%arg0: tensor>, %arg1: tensor>) -> tensor> { - %0 = "tf.AddV2"(%arg1, %arg0) : (tensor>, tensor>) -> tensor> - func.return %0 : tensor> -} - -// ----- - -// CHECK-LABEL: @xla_variadic_reduce_v2_dynamic -func.func @xla_variadic_reduce_v2_dynamic(%arg0: tensor, %arg1: tensor) -> tensor attributes {tf.entry_function = {control_outputs = "", inputs = "_arg0,_arg1", outputs = "_retval0"}} { - // CHECK: %[[REDUCE:.*]] = mhlo.reduce(%arg0 init: %arg1) - // CHECK-SAME: dimensions = [0] - // CHECK-NEXT: (%[[ARG0:.*]]: tensor, %[[ARG1:.*]]: tensor) - // CHECK-NEXT: %[[SUM:.*]] = func.call @sum_reducer2(%[[ARG0]], %[[ARG1]]){{.*}} - // CHECK-NEXT: mhlo.return %[[SUM]] : tensor - // CHECK: return %[[REDUCE]] - %0 = "tf.XlaVariadicReduceV2"(%arg0, %arg1) {_XlaHasReferenceVars = false, device = "/job:localhost/replica:0/task:0/device:XLA_GPU:0", dimensions_to_reduce = [0], operandSegmentSizes = array, reducer = @sum_reducer2} : (tensor, tensor) -> tensor - func.return %0 : tensor -} - -func.func private @sum_reducer2(%arg0: tensor, %arg1: tensor) -> tensor { - %0 = "tf.AddV2"(%arg1, %arg0) : (tensor, tensor) -> tensor - func.return %0 : tensor -} - -//===----------------------------------------------------------------------===// -// tf.XlaVariadicSort legalization -//===----------------------------------------------------------------------===// - -// CHECK-LABEL: @xla_variadic_sort -// CHECK-SAME: %[[INPUT:.*]]: tensor<2x3x4xui8> -func.func @xla_variadic_sort(%arg0: tensor<2x3x4xui8>) -> tensor<2x3x4xui8> attributes {tf.entry_function = {control_outputs = "", inputs = "_arg0,_arg1", outputs = "_retval0"}} { - // CHECK-NEXT: {{.*}} = mhlo.constant dense<0> : tensor - %cst = "tf.Const"() {value = dense<0> : tensor} : () -> tensor - // CHECK-NEXT: %[[SORT:.*]] = "mhlo.sort"(%[[INPUT]]) <{dimension = 0 : i64, is_stable = false}> ({ - // CHECK-NEXT: ^{{.*}}(%[[LHS:.*]]: tensor, %[[RHS:.*]]: tensor) - // CHECK-NEXT: %[[CMP:.*]] = func.call @compare_lt(%[[LHS]], %[[RHS]]) : (tensor, tensor) -> tensor - // CHECK-NEXT: mhlo.return %[[CMP]] - // CHECK-NEXT: }) : (tensor<2x3x4xui8>) -> tensor<2x3x4xui8> - // CHECK-NEXT: return %[[SORT]] - %0 = "tf.XlaVariadicSort"(%arg0, %cst) {_XlaHasReferenceVars = false, comparator = @compare_lt, device = "/job:localhost/replica:0/task:0/device:XLA_GPU:0", is_stable = false} : (tensor<2x3x4xui8>, tensor) -> tensor<2x3x4xui8> - func.return %0 : tensor<2x3x4xui8> -} - -func.func private @compare_lt(%arg0: tensor, %arg1: tensor) -> tensor attributes {tf._disable_call_shape_inference = true} { - %0 = "tf.Less"(%arg0, %arg1) {device = ""} : (tensor, tensor) -> tensor - func.return %0 : tensor -} - -//===----------------------------------------------------------------------===// -// tf.NextAfter legalization -//===----------------------------------------------------------------------===// -// CHECK-LABEL: func @nextafter -func.func @nextafter(%arg0: tensor<2xf32>, %arg1 : tensor<2xf32>) -> tensor<2xf32> { - // CHECK-NEXT: %0 = chlo.broadcast_next_after %arg0, %arg1 : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> - // CHECK-NEXT: return %0 : tensor<2xf32> - %0 = "tf.NextAfter"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> - func.return %0: tensor<2xf32> -} - -//===----------------------------------------------------------------------===// -// tf.XlaReduceScatter legalization -//===----------------------------------------------------------------------===// -// CHECK-LABEL: func @xla_reduce_scatter -func.func @xla_reduce_scatter(%arg0: tensor<128x128xf32>) -> tensor<64x128xf32> { - %cst = "tf.Const"() {value = dense<0> : tensor} : () -> tensor - %cst_0 = "tf.Const"() {value = dense<[[0, 4], [1, 5], [2, 6], [3, 7]]> : tensor<4x2xi32>} : () -> tensor<4x2xi32> - // CHECK: "mhlo.reduce_scatter"(%arg0) - // CHECK{LITERAL}: replica_groups = dense<[[0, 4], [1, 5], [2, 6], [3, 7]]> - // CHECK-SAME: scatter_dimension = 0 - // - %1 = "tf.XlaReduceScatter"(%arg0, %cst_0, %cst) {reduce_op = "Add"} : (tensor<128x128xf32>, tensor<4x2xi32>, tensor) -> tensor<64x128xf32> - func.return %1 : tensor<64x128xf32> -} - - -//===----------------------------------------------------------------------===// -// tf.XlaSelectAndScatter legalization -//===----------------------------------------------------------------------===// -func.func @test_xla_select_and_scatter(%arg0: tensor<4x5x1x1xbf16>, %arg1: tensor<2x2x1x1xbf16>, %arg2: tensor) -> tensor { - %cst = "tf.Const"() {value = dense<0> : tensor<4x2xi32>} : () -> tensor<4x2xi32> - %cst_0 = "tf.Const"() {value = dense<[2, 2, 1, 1]> : tensor<4xi32>} : () -> tensor<4xi32> - %cst_1 = "tf.Const"() {value = dense<[2, 3, 1, 1]> : tensor<4xi32>} : () -> tensor<4xi32> - // CHECK: %[[SELECT_AND_SCATTER:.*]] = "mhlo.select_and_scatter"(%arg0, %arg1, %arg2) <{padding = dense<0> : tensor<4x2xi64>, window_dimensions = dense<[2, 3, 1, 1]> : tensor<4xi64>, window_strides = dense<[2, 2, 1, 1]> : tensor<4xi64>}> ({ - // CHECK-NEXT: ^{{.*}}(%[[ARG0:.*]]: tensor, %[[ARG1:.*]]: tensor) - // CHECK-NEXT: %[[RES:.*]] = func.call @ge_select(%[[ARG0]], %[[ARG1]]){{.*}} - // CHECK-NEXT: mhlo.return %[[RES]] : tensor - // CHECK-NEXT: }, { - // CHECK-NEXT: ^{{.*}}(%[[ARG2:.*]]: tensor, %[[ARG3:.*]]: tensor) - // CHECK-NEXT: %[[RES:.*]] = func.call @add_scatter(%[[ARG2]], %[[ARG3]]){{.*}} - // CHECK-NEXT: mhlo.return %[[RES]] : tensor - // CHECK-NEXT: }) : (tensor<4x5x1x1xbf16>, tensor<2x2x1x1xbf16>, tensor) -> tensor - // CHECK-NEXT: return %[[SELECT_AND_SCATTER]] - %0 = "tf.XlaSelectAndScatter"(%arg0, %cst_1, %cst_0, %cst, %arg1, %arg2) {scatter = @add_scatter, select = @ge_select} : (tensor<4x5x1x1xbf16>, tensor<4xi32>, tensor<4xi32>, tensor<4x2xi32>, tensor<2x2x1x1xbf16>, tensor) -> tensor - func.return %0 : tensor -} - -func.func private @add_scatter(%arg0: tensor, %arg1: tensor) -> tensor { - %0 = "tf.AddV2"(%arg0, %arg1) {device = ""} : (tensor, tensor) -> tensor - func.return %0 : tensor -} - -func.func private @ge_select(%arg0: tensor, %arg1: tensor) -> tensor { - %0 = "tf.GreaterEqual"(%arg0, %arg1) {device = ""} : (tensor, tensor) -> tensor - func.return %0 : tensor -} - -//===----------------------------------------------------------------------===// -// tf.XlaOptimizationBarrier legalization -//===----------------------------------------------------------------------===// - -func.func @test_xla_optimization_barrier(%arg0: tensor<4x4xf32>, %arg1: tensor<3x4xi32>) -> (tensor<4x4xf32>, tensor<3x4xi32>) { - // CHECK: %[[OPT_BARRIER:.*]]:2 = mhlo.optimization_barrier %arg0, %arg1 - // CHECK-NEXT: return %[[OPT_BARRIER]]#0, %[[OPT_BARRIER]]#1 - %0, %1 = "tf.XlaOptimizationBarrier"(%arg0, %arg1) : (tensor<4x4xf32>, tensor<3x4xi32>) -> (tensor<4x4xf32>, tensor<3x4xi32>) - func.return %0, %1 : tensor<4x4xf32>, tensor<3x4xi32> -} - -// CHECK-LABEL: @ifRegion -// CHECK-SAME: ([[ARG0:%.+]]: tensor, [[ARG1:%.+]]: tensor) -func.func @ifRegion(%arg0: tensor, %arg1: tensor) -> (tensor) { - // CHECK: [[VAL0:%.+]] = mhlo.compare GT, [[ARG0]], [[ARG1]] - %0 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor - // CHECK: [[VAL1:%.+]] = "mhlo.if"([[VAL0]]) ({ - %1 = "tf.IfRegion"(%0) ({ - // CHECK: [[VAL2:%.+]] = mhlo.log [[ARG0]] - %2 = "tf.Log"(%arg0) : (tensor) -> tensor - // CHECK: mhlo.return [[VAL2]] - "tf.Yield"(%2) : (tensor) -> () - }, { - // CHECK: [[VAL3:%.+]] = mhlo.exponential [[ARG1]] - %2 = "tf.Exp"(%arg1) : (tensor) -> tensor - // CHECK: mhlo.return [[VAL3]] - "tf.Yield"(%2) : (tensor) -> () - // CHECK: }) : (tensor) -> tensor - }) {is_stateless = true} : (tensor) -> tensor - // CHECK: return [[VAL1]] - func.return %1 : tensor -} - -// CHECK-LABEL: func @caseRegion -// CHECK-SAME: ([[BRANCH_INDEX:%.+]]: tensor, [[ARG0:.+]]: tensor, [[ARG1:%.+]]: tensor) -func.func @caseRegion(%index: tensor, %arg0: tensor, %arg1: tensor) -> (tensor, tensor) { - // CHECK: [[VAL1:%.+]]:2 = "mhlo.case"([[BRANCH_INDEX]]) ({ - %0:2 = "tf.CaseRegion"(%index) ({ - // CHECK: [[VAL2:%.+]] = mhlo.exponential [[ARG1]] - %1 = mhlo.exponential %arg1 : (tensor) -> tensor - // CHECK: mhlo.return [[VAL2]], [[ARG1]] - "tf.Yield"(%1, %arg1) : (tensor, tensor) -> () - }, { - // CHECK: [[VAL3:%.+]] = mhlo.log [[ARG0]] - %1 = mhlo.log %arg0 : (tensor) -> tensor - // CHECK: mhlo.return [[VAL3]], [[ARG1]] - "tf.Yield"(%1, %arg1) : (tensor, tensor) -> () - }, { - // CHECK: [[VAL4:%.+]] = mhlo.floor [[ARG0]] - %1 = "mhlo.floor"(%arg0) : (tensor) -> tensor - // CHECK: mhlo.return [[VAL4]], [[ARG1]] - "tf.Yield"(%1, %arg1) : (tensor, tensor) -> () - // CHECK: }) : (tensor) -> (tensor, tensor) - }) {is_stateless = true} : (tensor) -> (tensor, tensor) - // CHECK: return [[VAL1]]#0, [[VAL1]]#1 : tensor, tensor - func.return %0#0, %0#1 : tensor, tensor -} - -// ----- - -// This test case also ensures the mhlo dialect is loaded as a dependency by the -// pass and hence the split here. - -// CHECK-LABEL: func @whileRegion -func.func @whileRegion() -> tensor { - %0 = "tf.Const"() {value = dense<0> : tensor} : () -> tensor - %1 = "tf.Const"() {value = dense<-1> : tensor} : () -> tensor - %2:3 = "tf.WhileRegion"(%0, %1, %0) ({ - ^cond(%carg0: tensor, %carg1: tensor, %carg2: tensor): - %3 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor - "tf.Yield"(%3) : (tensor) -> () - }, { - ^body(%barg0: tensor, %barg1: tensor, %barg2: tensor): - %4 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor - "tf.Yield"(%4, %4, %4) : (tensor, tensor, tensor) -> () - }) {is_stateless = true, parallel_iterations = 10 : i64} : (tensor, tensor, tensor) -> (tensor, tensor, tensor) - func.return %2#2 : tensor -} - -// ----- - -// CHECK-LABEL: func @whileRegionAdd -func.func @whileRegionAdd() -> tensor { - // CHECK: [[VAL0:%.+]] = mhlo.constant - %0 = "tf.Const"() {value = dense<0> : tensor} : () -> tensor - // CHECK: [[VAL1:%.+]] = mhlo.constant - %1 = "tf.Const"() {value = dense<-1> : tensor} : () -> tensor - // CHECK: [[VAL2:%.+]]:3 = mhlo.while([[ITER_ARG0:.*]] = [[VAL0]], [[ITER_ARG1:.*]] = [[VAL1]], [[ITER_ARG2:.*]] = [[VAL0]]) - %2:3 = "tf.WhileRegion"(%0, %1, %0) ({ - ^cond(%carg0: tensor, %carg1: tensor, %carg2: tensor): - // CHECK: [[VAL3:%.+]] = mhlo.constant - %3 = "tf.Const"() {value = dense<10> : tensor} : () -> tensor - // CHECK: [[VAL4:%.+]] = mhlo.compare LT, [[ITER_ARG2]], [[VAL3]] - %4 = "mhlo.compare"(%carg2, %3) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor - // CHECK: mhlo.return [[VAL4]] - "tf.Yield"(%4) : (tensor) -> () - }, { - ^body(%barg0: tensor, %barg1: tensor, %barg2: tensor): - // CHECK: [[VAL5:%.+]] = mhlo.constant - %5 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor - // CHECK: [[VAL6:%.+]] = mhlo.add [[ITER_ARG2]], [[VAL5]] - %6 = mhlo.add %barg2, %5 : tensor - // CHECK: [[VAL7:%.+]] = mhlo.add [[ITER_ARG0]], [[VAL5]] - %7 = mhlo.add %barg0, %5 : tensor - // CHECK: mhlo.return [[VAL7]], [[ITER_ARG1]], [[VAL6]] - "tf.Yield"(%7, %barg1, %6) : (tensor, tensor, tensor) -> () - }) {is_stateless = true, parallel_iterations = 10 : i64} : (tensor, tensor, tensor) -> (tensor, tensor, tensor) - // CHECK: return [[VAL2]]#2 - func.return %2#2 : tensor -} - -// ----- - -// CHECK-LABEL: func @whileRegionImplicitInputs -// CHECK-SAME: ([[ARG0:%.+]]: tensor) -func.func @whileRegionImplicitInputs(%arg0: tensor) -> tensor { - // CHECK: [[VAL0:%.+]] = mhlo.constant dense<0> - %0 = mhlo.constant dense<0> : tensor - // CHECK: [[VAL1:%.+]] = mhlo.constant dense<-1> - %1 = mhlo.constant dense<-1> : tensor - // CHECK: [[VAL2:%.+]] = mhlo.while([[ITER_ARG0:.*]] = [[ARG0]]) - %2 = "tf.WhileRegion"(%arg0) ({ - ^cond(%carg0: tensor): - // CHECK: [[VAL3:%.+]] = mhlo.compare LT, [[ITER_ARG0]], [[VAL0]] - %3 = mhlo.compare LT, %carg0, %0 : (tensor, tensor) -> tensor - // CHECK: mhlo.return [[VAL3]] - "tf.Yield"(%3) : (tensor) -> () - }, { - ^body(%barg0: tensor): - // CHECK: [[VAL3:%.+]] = mhlo.add [[ITER_ARG0]], [[VAL1]] - %3 = mhlo.add %barg0, %1 : tensor - // CHECK: [[VAL4:%.+]] = mhlo.add [[ITER_ARG0]], [[VAL3]] - %4 = mhlo.add %barg0, %3 : tensor - // CHECK: mhlo.return [[VAL4]] - "tf.Yield"(%4) : (tensor) -> () - }) {is_stateless = true, parallel_iterations = 10 : i64} : (tensor) -> tensor - // CHECK: return [[VAL2]] - func.return %2 : tensor -} - -// CHECK-LABEL: func @whileRegionMultipleImplicitInputs -func.func @whileRegionMultipleImplicitInputs() { - // CHECK: [[VAL0:%.+]] = mhlo.constant dense<0> - %0 = mhlo.constant dense<0> : tensor - // CHECK: [[VAL1:%.+]] = mhlo.constant dense<-1> - %1 = mhlo.constant dense<-1> : tensor - // CHECK: mhlo.while() - "tf.WhileRegion"() ({ - // CHECK: [[VAL3:%.+]] = mhlo.compare LT, [[VAL0]], [[VAL1]] - %2 = "mhlo.compare"(%0, %1) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor - // CHECK: mhlo.return [[VAL3]] - "tf.Yield"(%2) : (tensor) -> () - }, { - // CHECK: [[VAL3:%.+]] = mhlo.add [[VAL0]], [[VAL1]] - %2 = mhlo.add %0, %1 : tensor - // CHECK: mhlo.return - "tf.Yield"() : () -> () - }) {is_stateless = true, parallel_iterations = 10 : i64} : () -> () - // CHECK: return - func.return -} From 845733207208bb7347b068b835ad23c0465dcc90 Mon Sep 17 00:00:00 2001 From: pizzud Date: Mon, 30 Sep 2024 11:57:28 -0700 Subject: [PATCH 437/483] cuda_driver_test: Delete the allocated graph. We allocate a graph but never delete it, which is a clear memory leak. PiperOrigin-RevId: 680664528 --- third_party/xla/xla/stream_executor/cuda/BUILD | 6 +++++- .../xla/xla/stream_executor/cuda/cuda_driver_test.cc | 3 +++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/third_party/xla/xla/stream_executor/cuda/BUILD b/third_party/xla/xla/stream_executor/cuda/BUILD index 4255d6566b0a66..0e1a584b00264f 100644 --- a/third_party/xla/xla/stream_executor/cuda/BUILD +++ b/third_party/xla/xla/stream_executor/cuda/BUILD @@ -318,7 +318,10 @@ xla_test( name = "cuda_driver_test", srcs = ["cuda_driver_test.cc"], backends = ["gpu"], - tags = ["cuda-only"], + tags = [ + "cuda-only", + "gpu", + ], deps = [ ":cuda_diagnostics", ":cuda_driver", @@ -326,6 +329,7 @@ xla_test( "//xla/stream_executor/gpu:gpu_driver_header", "//xla/stream_executor/gpu:gpu_types_header", "//xla/stream_executor/gpu:scoped_activate_context", + "@com_google_absl//absl/cleanup", "@com_google_absl//absl/log", "@com_google_googletest//:gtest_main", "@local_config_cuda//cuda:cuda_headers", diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_driver_test.cc b/third_party/xla/xla/stream_executor/cuda/cuda_driver_test.cc index cd7fc58bfe5ca0..a28109078a05c3 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_driver_test.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_driver_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include +#include "absl/cleanup/cleanup.h" #include "absl/log/log.h" #include "third_party/gpus/cuda/include/cuda.h" #include "third_party/gpus/cuda/include/cuda_runtime_api.h" @@ -102,6 +103,8 @@ TEST_F(CudaDriverTest, GraphGetNodeCountTest) { CHECK_CUDA(cuCtxCreate(&context, 0, device)); gpu::GpuGraphHandle graph; TF_CHECK_OK(gpu::GpuDriver::CreateGraph(&graph)); + absl::Cleanup cleanup( + [graph] { TF_CHECK_OK(gpu::GpuDriver::DestroyGraph(graph)); }); EXPECT_THAT(gpu::GpuDriver::GraphGetNodeCount(graph), IsOkAndHolds(0)); gpu::GpuGraphNodeHandle node; TF_CHECK_OK(gpu::GpuDriver::GraphAddEmptyNode(&node, graph, {})); From a5575bbe98a7d657fe7cc2fe971247ae2f4347ac Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Mon, 30 Sep 2024 12:07:45 -0700 Subject: [PATCH 438/483] [XLA:GPU] Fix the derivation for the number of warps for tiled HLO computations. The number of warps used to process a computation determines how many registers we are able to use concurrently. Therefore, looking at the largest (padded) tile size makes sense, since it determines the minimum number of elements that must be live concurrently. Previously, the logic erroneously only looked at the output tile sizes. This approach is not perfect, and may be further improved by e.g. doing a live range analysis on the tiles of the computation. PiperOrigin-RevId: 680668856 --- .../model/gpu_indexing_performance_model.cc | 11 ++++- .../gpu_indexing_performance_model_test.cc | 48 +++++++++++++++++++ 2 files changed, 57 insertions(+), 2 deletions(-) diff --git a/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.cc b/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.cc index f8ac967dd1ac68..22ac7903c0bc4e 100644 --- a/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.cc +++ b/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.cc @@ -524,9 +524,16 @@ GpuPerformanceModelWithIndexingAnalysis::EstimateRunTimeForTriton( LaunchDimensions GpuPerformanceModelWithIndexingAnalysis::GetLaunchDimensionsForTiledFusion( const TiledHloComputation& tiled_hlo_computation) { - const auto* tiled_root = tiled_hlo_computation.GetRoot(); int64_t num_blocks = tiled_hlo_computation.num_output_tiles(); - int64_t num_warps = GetNumWarps(GetPaddedTileSize(tiled_root->tile_sizes())); + + // Decide on the number of warps to use based on the largest live tile size + // at any given point within the computation. + int64_t largest_live_tile_size = 1; + for (const auto& tiled_hlo : tiled_hlo_computation.instructions()) { + largest_live_tile_size = std::max( + largest_live_tile_size, GetPaddedTileSize(tiled_hlo->tile_sizes())); + } + int64_t num_warps = GetNumWarps(largest_live_tile_size); return {static_cast(num_blocks), static_cast(num_warps * WarpSize())}; diff --git a/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model_test.cc b/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model_test.cc index f9f6b05702e79e..0de3856b9864d3 100644 --- a/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model_test.cc +++ b/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model_test.cc @@ -620,6 +620,54 @@ ENTRY main { // and corresponds to 4 warps. EXPECT_EQ(launch_dimensions.num_threads_per_block(), 4 * WarpSize()); } + +TEST_F(GpuIndexingPerformanceModelTest, + NumberOfWarpsDependsOnLargestLiveTileSize) { + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"( +HloModule m + +add { + param_0 = f32[] parameter(0) + param_1 = f32[] parameter(1) + ROOT add = f32[] add(param_0, param_1) +} + +fusion_computation { + param_0 = f32[1,4096] parameter(0) + c0 = f32[] constant(0) + ROOT reduce = f32[1] reduce(param_0, c0), dimensions={1}, to_apply=add +} + +ENTRY main { + param_0 = f32[1,4096] parameter(0) + ROOT fusion = f32[1] fusion(param_0), kind=kCustom, + calls=fusion_computation, + backend_config={"fusion_backend_config": {"kind":"__triton"}} +} +)")); + auto fusion_adaptor = HloFusionAdaptor::ForInstruction( + module->entry_computation()->root_instruction()); + + SymbolicTileAnalysisOrError analysis_or_error = + SymbolicTileAnalysis::AnalyzeFusion( + *fusion_adaptor, &mlir_context_, + /*emitter_specific_constraints_builder=*/nullptr); + ASSERT_TRUE(std::holds_alternative(analysis_or_error)); + + TF_ASSERT_OK_AND_ASSIGN( + TiledHloComputation tiled_hlo_computation, + std::get(analysis_or_error) + .ComputeTiledHloInstructions(/*tile_parameters=*/{1})); + + LaunchDimensions launch_dimensions = GpuPerformanceModelWithIndexingAnalysis:: + GetLaunchDimensionsForTiledFusion(tiled_hlo_computation); + EXPECT_EQ(launch_dimensions.num_blocks(), 1); + + // The largest tile size is 1 * 4096, for which our implementation recommends + // using 4 warps. + EXPECT_EQ(launch_dimensions.num_threads_per_block(), 4 * WarpSize()); +} + class FlopsPerElementTest : public GpuIndexingPerformanceModelTest { public: void CompareFlopsModels(absl::string_view hlo_module_string) { From 94314719412ca442b01c8358a1bfde6f0450b41a Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 30 Sep 2024 12:11:52 -0700 Subject: [PATCH 439/483] Minor fix to only handle sc and tc planes. PiperOrigin-RevId: 680670430 --- tensorflow/core/profiler/utils/op_metrics_db_utils.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/profiler/utils/op_metrics_db_utils.cc b/tensorflow/core/profiler/utils/op_metrics_db_utils.cc index 61e7fa812d7c51..66f3ab1e6a129f 100644 --- a/tensorflow/core/profiler/utils/op_metrics_db_utils.cc +++ b/tensorflow/core/profiler/utils/op_metrics_db_utils.cc @@ -37,7 +37,7 @@ namespace tensorflow { namespace profiler { const absl::string_view kIdle = "IDLE"; -const uint32_t kSparseCoreIndexStart = 1000; +const uint32_t kSparseCoreIndexStart = 1000000; namespace { From 2a158a64eae8a1bb0c7e440a3bbaac6a9442bf3e Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 30 Sep 2024 13:29:21 -0700 Subject: [PATCH 440/483] [XLA:Python] Fix crash introduced by https://github.com/openxla/xla/pull/10150 If a non-weakreferenceable object is passed, remove the entry again, otherwise we leave a garbage entry in the map. (The original code did a double lookup for reasons that I had missed, but this is better.) PiperOrigin-RevId: 680697569 --- .../xla/xla/python/weakref_lru_cache.cc | 36 +++++++++++-------- .../xla/xla/python/weakref_lru_cache_test.py | 14 ++++++++ 2 files changed, 35 insertions(+), 15 deletions(-) diff --git a/third_party/xla/xla/python/weakref_lru_cache.cc b/third_party/xla/xla/python/weakref_lru_cache.cc index 736bf437d106d6..a6fd73bc645b03 100644 --- a/third_party/xla/xla/python/weakref_lru_cache.cc +++ b/third_party/xla/xla/python/weakref_lru_cache.cc @@ -158,21 +158,27 @@ class WeakrefLRUCache : public std::enable_shared_from_this { auto& value = it->second; value.cache = std::make_shared(&lru_list_); - value.weakref = - nb::weakref(key.object, nb::cpp_function([this_weak = weak_from_this(), - key](nb::handle weakref) { - auto cache = this_weak.lock(); - if (cache == nullptr) { - return; - } - auto it = cache->entries_.find(key); - if (it == cache->entries_.end()) { - return; - } - // Create temp-var to avoid re-entrant erase. - auto tmp = std::move(it->second); - cache->entries_.erase(it); - })); + auto weakref_gc_callback = nb::cpp_function( + [this_weak = weak_from_this(), key](nb::handle weakref) { + auto cache = this_weak.lock(); + if (cache == nullptr) { + return; + } + auto it = cache->entries_.find(key); + if (it == cache->entries_.end()) { + return; + } + // Create temp-var to avoid re-entrant erase. + auto tmp = std::move(it->second); + cache->entries_.erase(it); + }); + PyObject* ref = + PyWeakref_NewRef(key.object.ptr(), weakref_gc_callback.ptr()); + if (!ref) { + entries_.erase(it); + throw nb::python_error(); + } + value.weakref = nb::steal(ref); return value.cache; } diff --git a/third_party/xla/xla/python/weakref_lru_cache_test.py b/third_party/xla/xla/python/weakref_lru_cache_test.py index ad5f07bee0bf72..0376cf1d3690dd 100644 --- a/third_party/xla/xla/python/weakref_lru_cache_test.py +++ b/third_party/xla/xla/python/weakref_lru_cache_test.py @@ -15,6 +15,7 @@ import threading import time +import weakref from absl.testing import absltest @@ -111,6 +112,19 @@ class WRKey: cache(wrkey, "arg2") self.assertLen(cache.cache_keys(), 2) + def testNonWeakreferenceableKey(self): + class NonWRKey: + __slots__ = () + + non_wr_key = NonWRKey() + with self.assertRaises(TypeError): + weakref.ref(non_wr_key) + + cache = xla_client.weakref_lru_cache(lambda: None, lambda x: 2048) + for _ in range(100): + with self.assertRaises(TypeError): + cache(non_wr_key) + def testCrashingKey(self): class WRKey: pass From 8fac27b486939f40bc8e362b94a16a4a8bb51869 Mon Sep 17 00:00:00 2001 From: Quentin Khan Date: Mon, 30 Sep 2024 13:56:25 -0700 Subject: [PATCH 441/483] Set the output tensor byte size in reshape to the input tensor byte size. Reshape doesn't alter it's data, and when it is prepared, we know it's inputs have a correct shape/byte size (because dynamicaly shaped tensors stop the prepare cycle and force an eval one before restarting a prepare cycle). This allows the ArenaPlanner to match the dynamic Reshape input and output tensors and reuse the input tensor data for the output tensor. PiperOrigin-RevId: 680707865 --- tensorflow/lite/kernels/reshape.cc | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/tensorflow/lite/kernels/reshape.cc b/tensorflow/lite/kernels/reshape.cc index 83ce8727c03e8e..006af7583218c1 100644 --- a/tensorflow/lite/kernels/reshape.cc +++ b/tensorflow/lite/kernels/reshape.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include "tensorflow/lite/array.h" +#include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/core/c/builtin_op_data.h" #include "tensorflow/lite/core/c/common.h" #include "tensorflow/lite/kernels/internal/tensor.h" @@ -97,7 +98,9 @@ TfLiteStatus ResizeOutput(TfLiteContext* context, TfLiteNode* node) { inline TfLiteIntArray* GetOutputShapeFromTensor(TfLiteContext* context, TfLiteNode* node) { const TfLiteTensor* shape = GetInput(context, node, kShapeTensor); - if (shape == nullptr) return nullptr; + if (shape == nullptr) { + return nullptr; + } TfLiteIntArray* output_shape = TfLiteIntArrayCreate(shape->dims->data[0]); for (int i = 0; i < output_shape->size; ++i) { @@ -160,17 +163,23 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { const TfLiteTensor* input = GetInput(context, node, kInputTensor); const TfLiteTensor* shape = GetInput(context, node, kShapeTensor); if (NumInputs(node) == 1 || IsConstantOrPersistentTensor(shape)) { + op_data->output_shape_known = true; if (IsConstantOrPersistentTensor(input)) { SetTensorToPersistentRo(output); TF_LITE_ENSURE_OK(context, ResizeOutput(context, node)); op_data->output_ptr = output->data.data; memcpy(output->data.data, input->data.data, input->bytes); - return kTfLiteOk; } else { TF_LITE_ENSURE_OK(context, ResizeOutput(context, node)); } + return kTfLiteOk; } else { op_data->output_shape_known = false; + // We know the output bytes size is the same as the input. Setting this + // enables tensor sharing in the ArenaPlanner. + if (output->allocation_type == kTfLiteArenaRw) { + output->bytes = input->bytes; + } return kTfLiteOutputShapeNotKnown; } } @@ -252,7 +261,8 @@ TfLiteRegistration* Register_RESHAPE() { /*version=*/0, /*registration_external=*/nullptr, /*async_kernel=*/nullptr, - kTfLiteInplaceOpInput0Shared | kTfLiteInplaceOpDataUnmodified}; + /*inplace_operator=*/kTfLiteInplaceOpInput0Shared | + kTfLiteInplaceOpDataUnmodified}; return &r; } From 760c99c3eba3df3f4ab427fe1f5c1696fec372de Mon Sep 17 00:00:00 2001 From: David Dunleavy Date: Mon, 30 Sep 2024 14:31:32 -0700 Subject: [PATCH 442/483] Remove superfluous `gpu` tag on `//xla/stream_executor/cuda:cuda_driver_test` PiperOrigin-RevId: 680720952 --- third_party/xla/xla/stream_executor/cuda/BUILD | 1 - 1 file changed, 1 deletion(-) diff --git a/third_party/xla/xla/stream_executor/cuda/BUILD b/third_party/xla/xla/stream_executor/cuda/BUILD index 0e1a584b00264f..35d33e03af429c 100644 --- a/third_party/xla/xla/stream_executor/cuda/BUILD +++ b/third_party/xla/xla/stream_executor/cuda/BUILD @@ -320,7 +320,6 @@ xla_test( backends = ["gpu"], tags = [ "cuda-only", - "gpu", ], deps = [ ":cuda_diagnostics", From ef27c2c8c86d10d31718eefc27ee5d80d96fd25d Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 30 Sep 2024 14:34:00 -0700 Subject: [PATCH 443/483] Remove XLA_BACKEND_SUPPORTS_BFLOAT16 macro. This macro is set for all devices so it can be cleaned up. PiperOrigin-RevId: 680722061 --- .../tests/exhaustive/exhaustive_binary_test_definitions.inc | 4 ---- ...exhaustive_binary_test_f16_and_smaller_instantiation.inc | 4 ---- .../tests/exhaustive/exhaustive_unary_test_definitions.inc | 4 ---- .../exhaustive_unary_test_f32_and_smaller_instantiation.inc | 4 ---- third_party/xla/xla/tests/pad_test.cc | 6 ------ third_party/xla/xla/tests/reduce_window_test.cc | 6 ------ third_party/xla/xla/tests/reshape_test.cc | 5 ----- third_party/xla/xla/tests/reverse_test.cc | 6 ------ 8 files changed, 39 deletions(-) diff --git a/third_party/xla/xla/tests/exhaustive/exhaustive_binary_test_definitions.inc b/third_party/xla/xla/tests/exhaustive/exhaustive_binary_test_definitions.inc index e4feb10c9918cd..1c1967256218a0 100644 --- a/third_party/xla/xla/tests/exhaustive/exhaustive_binary_test_definitions.inc +++ b/third_party/xla/xla/tests/exhaustive/exhaustive_binary_test_definitions.inc @@ -264,13 +264,9 @@ using ExhaustiveF64BinaryTest = Exhaustive32BitOrMoreBinaryTest; #define BINARY_TEST_F16(test_name, ...) #endif -#if defined(XLA_BACKEND_SUPPORTS_BFLOAT16) #define BINARY_TEST_BF16(test_name, ...) \ XLA_TEST_P(ExhaustiveBF16BinaryTest, test_name) \ __VA_ARGS__ -#else -#define BINARY_TEST_BF16(test_name, ...) -#endif #define BINARY_TEST_F32(test_name, ...) \ XLA_TEST_P(ExhaustiveF32BinaryTest, test_name) \ diff --git a/third_party/xla/xla/tests/exhaustive/exhaustive_binary_test_f16_and_smaller_instantiation.inc b/third_party/xla/xla/tests/exhaustive/exhaustive_binary_test_f16_and_smaller_instantiation.inc index 04456e6f3a8eaa..339406cd0da05d 100644 --- a/third_party/xla/xla/tests/exhaustive/exhaustive_binary_test_f16_and_smaller_instantiation.inc +++ b/third_party/xla/xla/tests/exhaustive/exhaustive_binary_test_f16_and_smaller_instantiation.inc @@ -27,12 +27,8 @@ INSTANTIATE_TEST_SUITE_P(F8E5M2, ExhaustiveF8E5M2BinaryTest, GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF8E5M2BinaryTest); #endif -#if defined(XLA_BACKEND_SUPPORTS_BFLOAT16) INSTANTIATE_TEST_SUITE_P(BF16, ExhaustiveBF16BinaryTest, ::testing::ValuesIn(CreateExhaustiveU32Ranges())); -#else -GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveBF16BinaryTest); -#endif #if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16) INSTANTIATE_TEST_SUITE_P(F16, ExhaustiveF16BinaryTest, diff --git a/third_party/xla/xla/tests/exhaustive/exhaustive_unary_test_definitions.inc b/third_party/xla/xla/tests/exhaustive/exhaustive_unary_test_definitions.inc index dc160bac741954..5c023eec713643 100644 --- a/third_party/xla/xla/tests/exhaustive/exhaustive_unary_test_definitions.inc +++ b/third_party/xla/xla/tests/exhaustive/exhaustive_unary_test_definitions.inc @@ -123,13 +123,9 @@ class ExhaustiveF64UnaryTest : public ExhaustiveUnaryTest, #define UNARY_TEST_E5M2(test_name, ...) #endif -#ifdef XLA_BACKEND_SUPPORTS_BFLOAT16 #define UNARY_TEST_BF16(test_name, ...) \ XLA_TEST_P(ExhaustiveBF16UnaryTest, test_name) \ __VA_ARGS__ -#else -#define UNARY_TEST_BF16(test_name, ...) -#endif #if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16) #define UNARY_TEST_F16(test_name, ...) \ diff --git a/third_party/xla/xla/tests/exhaustive/exhaustive_unary_test_f32_and_smaller_instantiation.inc b/third_party/xla/xla/tests/exhaustive/exhaustive_unary_test_f32_and_smaller_instantiation.inc index b0c1a087b9283b..efb173686c9849 100644 --- a/third_party/xla/xla/tests/exhaustive/exhaustive_unary_test_f32_and_smaller_instantiation.inc +++ b/third_party/xla/xla/tests/exhaustive/exhaustive_unary_test_f32_and_smaller_instantiation.inc @@ -27,12 +27,8 @@ INSTANTIATE_TEST_SUITE_P(F8E5M2, ExhaustiveF8E5M2UnaryTest, GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF8E5M2UnaryTest); #endif -#ifdef XLA_BACKEND_SUPPORTS_BFLOAT16 INSTANTIATE_TEST_SUITE_P(BF16, ExhaustiveBF16UnaryTest, ::testing::Values(std::make_pair(0, 1 << 16))); -#else -GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveBF16UnaryTest); -#endif #if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16) INSTANTIATE_TEST_SUITE_P(F16, ExhaustiveF16UnaryTest, diff --git a/third_party/xla/xla/tests/pad_test.cc b/third_party/xla/xla/tests/pad_test.cc index cd039d3daaf931..1e1a4b8a03306a 100644 --- a/third_party/xla/xla/tests/pad_test.cc +++ b/third_party/xla/xla/tests/pad_test.cc @@ -31,13 +31,7 @@ limitations under the License. namespace xla { namespace { -#ifdef XLA_BACKEND_SUPPORTS_BFLOAT16 -// Tests both F32 and BF16. static std::array use_bfloat16_params{false, true}; -#else -// Only tests F32. -static std::array use_bfloat16_params{false}; -#endif class PadTest : public ClientLibraryTestBase { protected: diff --git a/third_party/xla/xla/tests/reduce_window_test.cc b/third_party/xla/xla/tests/reduce_window_test.cc index 4417ded2499353..e0fe47f92d3e69 100644 --- a/third_party/xla/xla/tests/reduce_window_test.cc +++ b/third_party/xla/xla/tests/reduce_window_test.cc @@ -43,13 +43,7 @@ limitations under the License. namespace xla { namespace { -#ifdef XLA_BACKEND_SUPPORTS_BFLOAT16 -// Tests both F32 and BF16. static std::array use_bfloat16_params{false, true}; -#else -// Only tests F32. -static std::array use_bfloat16_params{false}; -#endif class ReduceWindowTestBase : public ClientLibraryTestBase { public: diff --git a/third_party/xla/xla/tests/reshape_test.cc b/third_party/xla/xla/tests/reshape_test.cc index 9e3c09dd12ffc0..28836514705398 100644 --- a/third_party/xla/xla/tests/reshape_test.cc +++ b/third_party/xla/xla/tests/reshape_test.cc @@ -1017,12 +1017,7 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeTrivialR2) { zero_error_spec_, &expected.shape()); } -#ifdef XLA_BACKEND_SUPPORTS_BFLOAT16 INSTANTIATE_TEST_CASE_P(ReshapeTestInstance, ReshapeTest, ::testing::Bool()); -#else -INSTANTIATE_TEST_CASE_P(ReshapeTestInstance, ReshapeTest, - ::testing::ValuesIn(std::vector{false})); -#endif using ReshapeHloTest = HloTestBase; diff --git a/third_party/xla/xla/tests/reverse_test.cc b/third_party/xla/xla/tests/reverse_test.cc index 299ea416e3c9e7..6cd6afa268e385 100644 --- a/third_party/xla/xla/tests/reverse_test.cc +++ b/third_party/xla/xla/tests/reverse_test.cc @@ -30,13 +30,7 @@ limitations under the License. namespace xla { namespace { -#ifdef XLA_BACKEND_SUPPORTS_BFLOAT16 -// Tests both F32 and BF16. static std::array use_bfloat16_params{false, true}; -#else -// Only tests F32. -static std::array use_bfloat16_params{false}; -#endif struct ReverseSpec { std::vector input_dims; From d245f648c5a7e0783dffd21b8a87a497dc185931 Mon Sep 17 00:00:00 2001 From: Tom Natan Date: Mon, 30 Sep 2024 14:56:44 -0700 Subject: [PATCH 444/483] [StableHLO] Fix updated indices_are_sorted and handle case of batching dim size overflowing indices integer type. PiperOrigin-RevId: 680729958 --- third_party/stablehlo/temporary.patch | 616 ++++++++++++++++++ .../xla/third_party/stablehlo/temporary.patch | 616 ++++++++++++++++++ 2 files changed, 1232 insertions(+) diff --git a/third_party/stablehlo/temporary.patch b/third_party/stablehlo/temporary.patch index 8b137891791fe9..b9970316773eaf 100755 --- a/third_party/stablehlo/temporary.patch +++ b/third_party/stablehlo/temporary.patch @@ -1 +1,617 @@ +diff --ruN a/stablehlo/stablehlo/tests/transforms/stablehlo_create_compatibility_expander.mlir b/stablehlo/stablehlo/tests/transforms/stablehlo_create_compatibility_expander.mlir +--- stablehlo/stablehlo/tests/transforms/stablehlo_create_compatibility_expander.mlir ++++ stablehlo/stablehlo/tests/transforms/stablehlo_create_compatibility_expander.mlir +@@ -69,7 +69,7 @@ + index_vector_dim = 3 + >, + slice_sizes = array, +- indices_are_sorted = true ++ indices_are_sorted = false + } : (tensor<3x2x4x7x9xi32>, tensor<4x3x5x2xi32>) -> tensor<4x3x5x8xi32> + func.return %0 : tensor<4x3x5x8xi32> + } +@@ -77,9 +77,9 @@ + // ----- + + // CHECK-LABEL: @gather_with_batching_no_index_vector_dim ++// CHECK-NEXT: %[[reshape:.*]] = stablehlo.reshape %arg1 : (tensor<4x3x5xi32>) -> tensor<4x3x5x1xi32> + // CHECK-NEXT: %[[iota_dim1:.*]] = stablehlo.iota dim = 1 : tensor<4x3x5x1xi32> + // CHECK-NEXT: %[[iota_dim0:.*]] = stablehlo.iota dim = 0 : tensor<4x3x5x1xi32> +-// CHECK-NEXT: %[[reshape:.*]] = stablehlo.reshape %arg1 : (tensor<4x3x5xi32>) -> tensor<4x3x5x1xi32> + // CHECK-NEXT: %[[concat:.*]] = stablehlo.concatenate %[[iota_dim1]], %[[iota_dim0]], %[[reshape]], dim = 3 : (tensor<4x3x5x1xi32>, tensor<4x3x5x1xi32>, tensor<4x3x5x1xi32>) -> tensor<4x3x5x3xi32> + // CHECK-NEXT: %[[gather:.*]] = "stablehlo.gather"(%arg0, %[[concat]]) <{ + // CHECK-SAME: dimension_numbers = #stablehlo.gather< +@@ -102,7 +102,7 @@ + index_vector_dim = 3 + >, + slice_sizes = array, +- indices_are_sorted = true ++ indices_are_sorted = false + }> : (tensor<3x2x4x9xi32>, tensor<4x3x5xi32>) -> tensor<4x3x5x8xi32> + func.return %0 : tensor<4x3x5x8xi32> + } +@@ -133,9 +133,305 @@ + index_vector_dim = 3 + >, + slice_sizes = array, +- indices_are_sorted = true ++ indices_are_sorted = false + }> : (tensor<0x2x9xi32>, tensor<0x3x5x1xi32>) -> tensor<0x3x5x8xi32> + func.return %0 : tensor<0x3x5x8xi32> ++} ++ ++// ----- ++ ++// CHECK-LABEL: @gather_batching_dims_indices_become_unsorted ++// CHECK-NEXT: %[[iota_dim1:.*]] = stablehlo.iota dim = 0 : tensor<3x4x5x1xi32> ++// CHECK-NEXT: %[[iota_dim0:.*]] = stablehlo.iota dim = 1 : tensor<3x4x5x1xi32> ++// CHECK-NEXT: %[[concat:.*]] = stablehlo.concatenate %[[iota_dim1]], %[[iota_dim0]], %arg1, dim = 3 : (tensor<3x4x5x1xi32>, tensor<3x4x5x1xi32>, tensor<3x4x5x2xi32>) -> tensor<3x4x5x4xi32> ++// CHECK-NEXT: %[[gather:.*]] = "stablehlo.gather"(%arg0, %[[concat]]) <{ ++// CHECK-SAME: dimension_numbers = #stablehlo.gather< ++// CHECK-SAME: offset_dims = [3], collapsed_slice_dims = [0, 1, 2, 3], ++// CHECK-SAME: start_index_map = [0, 2, 1, 3], index_vector_dim = 3>, ++// CHECK-SAME: indices_are_sorted = false, ++// CHECK-SAME: slice_sizes = array ++// CHECK-SAME: }> : (tensor<3x2x4x7x9xi32>, tensor<3x4x5x4xi32>) -> tensor<3x4x5x8xi32> ++// CHECK-NEXT: return %[[gather]] : tensor<3x4x5x8xi32> ++func.func @gather_batching_dims_indices_become_unsorted(%arg0: tensor<3x2x4x7x9xi32>, %arg1: tensor<3x4x5x2xi32>) -> tensor<3x4x5x8xi32> { ++ %0 = "stablehlo.gather"(%arg0, %arg1) { ++ dimension_numbers = #stablehlo.gather< ++ offset_dims = [3], ++ collapsed_slice_dims = [1, 3], ++ operand_batching_dims = [0, 2], ++ start_indices_batching_dims = [0, 1], ++ start_index_map = [1, 3], ++ index_vector_dim = 3 ++ >, ++ slice_sizes = array, ++ indices_are_sorted = true ++ } : (tensor<3x2x4x7x9xi32>, tensor<3x4x5x2xi32>) -> tensor<3x4x5x8xi32> ++ func.return %0 : tensor<3x4x5x8xi32> ++} ++ ++// ----- ++ ++// CHECK-LABEL: @gather_batching_dims_indices_become_unsorted_2 ++// CHECK-NEXT: %[[iota_dim1:.*]] = stablehlo.iota dim = 1 : tensor<2x3x5x1xi32> ++// CHECK-NEXT: %[[iota_dim0:.*]] = stablehlo.iota dim = 0 : tensor<2x3x5x1xi32> ++// CHECK-NEXT: %[[concat:.*]] = stablehlo.concatenate %[[iota_dim1]], %[[iota_dim0]], %arg1, dim = 3 : (tensor<2x3x5x1xi32>, tensor<2x3x5x1xi32>, tensor<2x3x5x2xi32>) -> tensor<2x3x5x4xi32> ++// CHECK-NEXT: %[[gather:.*]] = "stablehlo.gather"(%arg0, %[[concat]]) <{ ++// CHECK-SAME: dimension_numbers = #stablehlo.gather< ++// CHECK-SAME: offset_dims = [3], collapsed_slice_dims = [0, 1, 2, 3], ++// CHECK-SAME: start_index_map = [0, 1, 2, 3], index_vector_dim = 3>, ++// CHECK-SAME: indices_are_sorted = false, ++// CHECK-SAME: slice_sizes = array ++// CHECK-SAME: }> : (tensor<3x2x4x7x9xi32>, tensor<2x3x5x4xi32>) -> tensor<2x3x5x8xi32> ++// CHECK-NEXT: return %[[gather]] : tensor<2x3x5x8xi32> ++func.func @gather_batching_dims_indices_become_unsorted_2(%arg0: tensor<3x2x4x7x9xi32>, %arg1: tensor<2x3x5x2xi32>) -> tensor<2x3x5x8xi32> { ++ %0 = "stablehlo.gather"(%arg0, %arg1) { ++ dimension_numbers = #stablehlo.gather< ++ offset_dims = [3], ++ collapsed_slice_dims = [2, 3], ++ operand_batching_dims = [0, 1], ++ start_indices_batching_dims = [1, 0], ++ start_index_map = [2, 3], ++ index_vector_dim = 3 ++ >, ++ slice_sizes = array, ++ indices_are_sorted = true ++ } : (tensor<3x2x4x7x9xi32>, tensor<2x3x5x2xi32>) -> tensor<2x3x5x8xi32> ++ func.return %0 : tensor<2x3x5x8xi32> ++} ++ ++// ----- ++ ++// CHECK-LABEL: @gather_batching_dims_indices_remain_sorted ++// CHECK-NEXT: %[[iota_dim1:.*]] = stablehlo.iota dim = 0 : tensor<2x3x5x1xi32> ++// CHECK-NEXT: %[[iota_dim0:.*]] = stablehlo.iota dim = 2 : tensor<2x3x5x1xi32> ++// CHECK-NEXT: %[[concat:.*]] = stablehlo.concatenate %[[iota_dim1]], %[[iota_dim0]], %arg1, dim = 3 : (tensor<2x3x5x1xi32>, tensor<2x3x5x1xi32>, tensor<2x3x5x2xi32>) -> tensor<2x3x5x4xi32> ++// CHECK-NEXT: %[[gather:.*]] = "stablehlo.gather"(%arg0, %[[concat]]) <{ ++// CHECK-SAME: dimension_numbers = #stablehlo.gather< ++// CHECK-SAME: offset_dims = [3], collapsed_slice_dims = [0, 1, 2, 3], ++// CHECK-SAME: start_index_map = [0, 1, 2, 3], index_vector_dim = 3>, ++// CHECK-SAME: indices_are_sorted = true, ++// CHECK-SAME: slice_sizes = array ++// CHECK-SAME: }> : (tensor<2x5x4x7x9xi32>, tensor<2x3x5x4xi32>) -> tensor<2x3x5x8xi32> ++// CHECK-NEXT: return %[[gather]] : tensor<2x3x5x8xi32> ++func.func @gather_batching_dims_indices_remain_sorted(%arg0: tensor<2x5x4x7x9xi32>, %arg1: tensor<2x3x5x2xi32>) -> tensor<2x3x5x8xi32> { ++ %0 = "stablehlo.gather"(%arg0, %arg1) { ++ dimension_numbers = #stablehlo.gather< ++ offset_dims = [3], ++ collapsed_slice_dims = [2, 3], ++ operand_batching_dims = [0, 1], ++ start_indices_batching_dims = [0, 2], ++ start_index_map = [2, 3], ++ index_vector_dim = 3 ++ >, ++ slice_sizes = array, ++ indices_are_sorted = true ++ } : (tensor<2x5x4x7x9xi32>, tensor<2x3x5x2xi32>) -> tensor<2x3x5x8xi32> ++ func.return %0 : tensor<2x3x5x8xi32> ++} ++ ++// ----- ++ ++// CHECK-LABEL: @gather_batching_dims_indices_remain_unsorted ++// CHECK-NEXT: %[[iota_dim1:.*]] = stablehlo.iota dim = 0 : tensor<2x3x5x1xi32> ++// CHECK-NEXT: %[[iota_dim0:.*]] = stablehlo.iota dim = 2 : tensor<2x3x5x1xi32> ++// CHECK-NEXT: %[[concat:.*]] = stablehlo.concatenate %[[iota_dim1]], %[[iota_dim0]], %arg1, dim = 3 : (tensor<2x3x5x1xi32>, tensor<2x3x5x1xi32>, tensor<2x3x5x2xi32>) -> tensor<2x3x5x4xi32> ++// CHECK-NEXT: %[[gather:.*]] = "stablehlo.gather"(%arg0, %[[concat]]) <{ ++// CHECK-SAME: dimension_numbers = #stablehlo.gather< ++// CHECK-SAME: offset_dims = [3], collapsed_slice_dims = [0, 1, 2, 3], ++// CHECK-SAME: start_index_map = [0, 1, 2, 3], index_vector_dim = 3>, ++// CHECK-SAME: indices_are_sorted = false, ++// CHECK-SAME: slice_sizes = array ++// CHECK-SAME: }> : (tensor<2x5x4x7x9xi32>, tensor<2x3x5x4xi32>) -> tensor<2x3x5x8xi32> ++// CHECK-NEXT: return %[[gather]] : tensor<2x3x5x8xi32> ++func.func @gather_batching_dims_indices_remain_unsorted(%arg0: tensor<2x5x4x7x9xi32>, %arg1: tensor<2x3x5x2xi32>) -> tensor<2x3x5x8xi32> { ++ %0 = "stablehlo.gather"(%arg0, %arg1) { ++ dimension_numbers = #stablehlo.gather< ++ offset_dims = [3], ++ collapsed_slice_dims = [2, 3], ++ operand_batching_dims = [0, 1], ++ start_indices_batching_dims = [0, 2], ++ start_index_map = [2, 3], ++ index_vector_dim = 3 ++ >, ++ slice_sizes = array, ++ indices_are_sorted = false ++ } : (tensor<2x5x4x7x9xi32>, tensor<2x3x5x2xi32>) -> tensor<2x3x5x8xi32> ++ func.return %0 : tensor<2x3x5x8xi32> ++} ++ ++// ----- ++ ++// CHECK-LABEL: @gather_batching_dims_does_not_overflow_indices_type ++// CHECK-NEXT: %[[iota_dim1:.*]] = stablehlo.iota dim = 1 : tensor<4x127x5x1xi8> ++// CHECK-NEXT: %[[iota_dim0:.*]] = stablehlo.iota dim = 0 : tensor<4x127x5x1xi8> ++// CHECK-NEXT: %[[concat:.*]] = stablehlo.concatenate %[[iota_dim1]], %[[iota_dim0]], %arg1, dim = 3 : (tensor<4x127x5x1xi8>, tensor<4x127x5x1xi8>, tensor<4x127x5x2xi8>) -> tensor<4x127x5x4xi8> ++// CHECK-NEXT: %[[gather:.*]] = "stablehlo.gather"(%arg0, %[[concat]]) <{ ++// CHECK-SAME: dimension_numbers = #stablehlo.gather< ++// CHECK-SAME: offset_dims = [3], collapsed_slice_dims = [0, 1, 2, 3], ++// CHECK-SAME: start_index_map = [0, 2, 1, 3], index_vector_dim = 3>, ++// CHECK-SAME: indices_are_sorted = false, ++// CHECK-SAME: slice_sizes = array ++// CHECK-SAME: }> : (tensor<127x2x4x7x9xi32>, tensor<4x127x5x4xi8>) -> tensor<4x127x5x8xi32> ++// CHECK-NEXT: return %[[gather]] : tensor<4x127x5x8xi32> ++func.func @gather_batching_dims_does_not_overflow_indices_type(%arg0: tensor<127x2x4x7x9xi32>, %arg1: tensor<4x127x5x2xi8>) -> tensor<4x127x5x8xi32> { ++ %0 = "stablehlo.gather"(%arg0, %arg1) { ++ dimension_numbers = #stablehlo.gather< ++ offset_dims = [3], ++ collapsed_slice_dims = [1, 3], ++ operand_batching_dims = [0, 2], ++ start_indices_batching_dims = [1, 0], ++ start_index_map = [1, 3], ++ index_vector_dim = 3 ++ >, ++ slice_sizes = array, ++ indices_are_sorted = false ++ } : (tensor<127x2x4x7x9xi32>, tensor<4x127x5x2xi8>) -> tensor<4x127x5x8xi32> ++ func.return %0 : tensor<4x127x5x8xi32> ++} ++ ++// ----- ++ ++// CHECK-LABEL: @gather_batching_dim_overflows_signless_indices_type ++// CHECK-NEXT: %[[convert:.*]] = stablehlo.convert %arg1 : (tensor<4x128x5x2xi8>) -> tensor<4x128x5x2xi32> ++// CHECK-NEXT: %[[iota_dim1:.*]] = stablehlo.iota dim = 1 : tensor<4x128x5x1xi32> ++// CHECK-NEXT: %[[iota_dim0:.*]] = stablehlo.iota dim = 0 : tensor<4x128x5x1xi32> ++// CHECK-NEXT: %[[concat:.*]] = stablehlo.concatenate %[[iota_dim1]], %[[iota_dim0]], %[[convert]], dim = 3 : (tensor<4x128x5x1xi32>, tensor<4x128x5x1xi32>, tensor<4x128x5x2xi32>) -> tensor<4x128x5x4xi32> ++// CHECK-NEXT: %[[gather:.*]] = "stablehlo.gather"(%arg0, %[[concat]]) <{ ++// CHECK-SAME: dimension_numbers = #stablehlo.gather< ++// CHECK-SAME: offset_dims = [3], collapsed_slice_dims = [0, 1, 2, 3], ++// CHECK-SAME: start_index_map = [0, 2, 1, 3], index_vector_dim = 3>, ++// CHECK-SAME: indices_are_sorted = false, ++// CHECK-SAME: slice_sizes = array ++// CHECK-SAME: }> : (tensor<128x2x4x7x9xi32>, tensor<4x128x5x4xi32>) -> tensor<4x128x5x8xi32> ++// CHECK-NEXT: return %[[gather]] : tensor<4x128x5x8xi32> ++func.func @gather_batching_dim_overflows_signless_indices_type(%arg0: tensor<128x2x4x7x9xi32>, %arg1: tensor<4x128x5x2xi8>) -> tensor<4x128x5x8xi32> { ++ %0 = "stablehlo.gather"(%arg0, %arg1) { ++ dimension_numbers = #stablehlo.gather< ++ offset_dims = [3], ++ collapsed_slice_dims = [1, 3], ++ operand_batching_dims = [0, 2], ++ start_indices_batching_dims = [1, 0], ++ start_index_map = [1, 3], ++ index_vector_dim = 3 ++ >, ++ slice_sizes = array, ++ indices_are_sorted = false ++ } : (tensor<128x2x4x7x9xi32>, tensor<4x128x5x2xi8>) -> tensor<4x128x5x8xi32> ++ func.return %0 : tensor<4x128x5x8xi32> ++} ++ ++// ----- ++ ++// CHECK-LABEL: @gather_batching_dim_overflows_unsigned_indices_type ++// CHECK-NEXT: %[[convert:.*]] = stablehlo.convert %arg1 : (tensor<256x4x5x2xui8>) -> tensor<256x4x5x2xi32> ++// CHECK-NEXT: %[[iota_dim0:.*]] = stablehlo.iota dim = 0 : tensor<256x4x5x1xi32> ++// CHECK-NEXT: %[[iota_dim1:.*]] = stablehlo.iota dim = 1 : tensor<256x4x5x1xi32> ++// CHECK-NEXT: %[[concat:.*]] = stablehlo.concatenate %[[iota_dim0]], %[[iota_dim1]], %[[convert]], dim = 3 : (tensor<256x4x5x1xi32>, tensor<256x4x5x1xi32>, tensor<256x4x5x2xi32>) -> tensor<256x4x5x4xi32> ++// CHECK-NEXT: %[[gather:.*]] = "stablehlo.gather"(%arg0, %[[concat]]) <{ ++// CHECK-SAME: dimension_numbers = #stablehlo.gather< ++// CHECK-SAME: offset_dims = [3], collapsed_slice_dims = [0, 1, 2, 3], ++// CHECK-SAME: start_index_map = [0, 2, 1, 3], index_vector_dim = 3>, ++// CHECK-SAME: indices_are_sorted = false, ++// CHECK-SAME: slice_sizes = array ++// CHECK-SAME: }> : (tensor<256x2x4x7x9xi32>, tensor<256x4x5x4xi32>) -> tensor<256x4x5x8xi32> ++// CHECK-NEXT: return %[[gather]] : tensor<256x4x5x8xi32> ++func.func @gather_batching_dim_overflows_unsigned_indices_type(%arg0: tensor<256x2x4x7x9xi32>, %arg1: tensor<256x4x5x2xui8>) -> tensor<256x4x5x8xi32> { ++ %0 = "stablehlo.gather"(%arg0, %arg1) { ++ dimension_numbers = #stablehlo.gather< ++ offset_dims = [3], ++ collapsed_slice_dims = [1, 3], ++ operand_batching_dims = [0, 2], ++ start_indices_batching_dims = [0, 1], ++ start_index_map = [1, 3], ++ index_vector_dim = 3 ++ >, ++ slice_sizes = array, ++ indices_are_sorted = false ++ } : (tensor<256x2x4x7x9xi32>, tensor<256x4x5x2xui8>) -> tensor<256x4x5x8xi32> ++ func.return %0 : tensor<256x4x5x8xi32> ++} ++ ++// ----- ++ ++// CHECK-LABEL: @gather_batching_dim_overflows_indices_type_and_i32 ++// CHECK-NEXT: %[[convert:.*]] = stablehlo.convert %arg1 : (tensor<4x2147483648x5x2xi8>) -> tensor<4x2147483648x5x2xi64> ++// CHECK-NEXT: %[[iota_dim1:.*]] = stablehlo.iota dim = 1 : tensor<4x2147483648x5x1xi64> ++// CHECK-NEXT: %[[iota_dim0:.*]] = stablehlo.iota dim = 0 : tensor<4x2147483648x5x1xi64> ++// CHECK-NEXT: %[[concat:.*]] = stablehlo.concatenate %[[iota_dim1]], %[[iota_dim0]], %[[convert]], dim = 3 : (tensor<4x2147483648x5x1xi64>, tensor<4x2147483648x5x1xi64>, tensor<4x2147483648x5x2xi64>) -> tensor<4x2147483648x5x4xi64> ++// CHECK-NEXT: %[[gather:.*]] = "stablehlo.gather"(%arg0, %[[concat]]) <{ ++// CHECK-SAME: dimension_numbers = #stablehlo.gather< ++// CHECK-SAME: offset_dims = [3], collapsed_slice_dims = [0, 1, 2, 3], ++// CHECK-SAME: start_index_map = [0, 2, 1, 3], index_vector_dim = 3>, ++// CHECK-SAME: indices_are_sorted = false, ++// CHECK-SAME: slice_sizes = array ++// CHECK-SAME: }> : (tensor<2147483648x2x4x7x9xi32>, tensor<4x2147483648x5x4xi64>) -> tensor<4x2147483648x5x8xi32> ++// CHECK-NEXT: return %[[gather]] : tensor<4x2147483648x5x8xi32> ++func.func @gather_batching_dim_overflows_indices_type_and_i32(%arg0: tensor<2147483648x2x4x7x9xi32>, %arg1: tensor<4x2147483648x5x2xi8>) -> tensor<4x2147483648x5x8xi32> { ++ %0 = "stablehlo.gather"(%arg0, %arg1) { ++ dimension_numbers = #stablehlo.gather< ++ offset_dims = [3], ++ collapsed_slice_dims = [1, 3], ++ operand_batching_dims = [0, 2], ++ start_indices_batching_dims = [1, 0], ++ start_index_map = [1, 3], ++ index_vector_dim = 3 ++ >, ++ slice_sizes = array, ++ indices_are_sorted = false ++ } : (tensor<2147483648x2x4x7x9xi32>, tensor<4x2147483648x5x2xi8>) -> tensor<4x2147483648x5x8xi32> ++ func.return %0 : tensor<4x2147483648x5x8xi32> ++} ++ ++// ----- ++ ++// CHECK-LABEL: @gather_batching_dim_dynamic_size ++// CHECK: operand_batching_dims = [0, 2] ++// CHECK: start_indices_batching_dims = [1, 0] ++func.func @gather_batching_dim_dynamic_size(%arg0: tensor, %arg1: tensor<4x?x5x2xi8>) -> tensor<4x?x5x8xi32> { ++ %0 = "stablehlo.gather"(%arg0, %arg1) { ++ dimension_numbers = #stablehlo.gather< ++ offset_dims = [3], ++ collapsed_slice_dims = [1, 3], ++ operand_batching_dims = [0, 2], ++ start_indices_batching_dims = [1, 0], ++ start_index_map = [1, 3], ++ index_vector_dim = 3 ++ >, ++ slice_sizes = array, ++ indices_are_sorted = false ++ } : (tensor, tensor<4x?x5x2xi8>) -> tensor<4x?x5x8xi32> ++ func.return %0 : tensor<4x?x5x8xi32> ++} ++ ++// ----- ++ ++// CHECK-LABEL: @gather_batching_dim_overflows_and_no_index_vector_dim ++// CHECK-NEXT: %[[convert:.*]] = stablehlo.convert %arg1 : (tensor<4x128x5xi8>) -> tensor<4x128x5xi32> ++// CHECK-NEXT: %[[reshape:.*]] = stablehlo.reshape %[[convert]] : (tensor<4x128x5xi32>) -> tensor<4x128x5x1xi32> ++// CHECK-NEXT: %[[iota_dim1:.*]] = stablehlo.iota dim = 1 : tensor<4x128x5x1xi32> ++// CHECK-NEXT: %[[iota_dim0:.*]] = stablehlo.iota dim = 0 : tensor<4x128x5x1xi32> ++// CHECK-NEXT: %[[concat:.*]] = stablehlo.concatenate %[[iota_dim1]], %[[iota_dim0]], %[[reshape]], dim = 3 : (tensor<4x128x5x1xi32>, tensor<4x128x5x1xi32>, tensor<4x128x5x1xi32>) -> tensor<4x128x5x3xi32> ++// CHECK-NEXT: %[[gather:.*]] = "stablehlo.gather"(%arg0, %[[concat]]) <{ ++// CHECK-SAME: dimension_numbers = #stablehlo.gather< ++// CHECK-SAME: offset_dims = [3], collapsed_slice_dims = [0, 1, 2], ++// CHECK-SAME: start_index_map = [0, 2, 1], index_vector_dim = 3>, ++// CHECK-SAME: indices_are_sorted = false, ++// CHECK-SAME: slice_sizes = array ++// CHECK-SAME: }> : (tensor<128x2x4x9xi32>, tensor<4x128x5x3xi32>) -> tensor<4x128x5x8xi32> ++// CHECK-NEXT: return %[[gather]] : tensor<4x128x5x8xi32> ++func.func @gather_batching_dim_overflows_and_no_index_vector_dim(%arg0: tensor<128x2x4x9xi32>, %arg1: tensor<4x128x5xi8>) -> tensor<4x128x5x8xi32> { ++ %0 = "stablehlo.gather"(%arg0, %arg1) { ++ dimension_numbers = #stablehlo.gather< ++ offset_dims = [3], ++ collapsed_slice_dims = [1], ++ operand_batching_dims = [0, 2], ++ start_indices_batching_dims = [1, 0], ++ start_index_map = [1], ++ index_vector_dim = 3 ++ >, ++ slice_sizes = array, ++ indices_are_sorted = false ++ } : (tensor<128x2x4x9xi32>, tensor<4x128x5xi8>) -> tensor<4x128x5x8xi32> ++ func.return %0 : tensor<4x128x5x8xi32> + } + + // ----- +@@ -156,7 +452,7 @@ + // CHECK-NO-DOWNGRADE: input_batching_dims = [0, 2] + // CHECK-NO-DOWNGRADE: scatter_indices_batching_dims = [1, 0] + %0 = "stablehlo.scatter"(%arg0, %arg1, %arg2) <{ +- indices_are_sorted = true, ++ indices_are_sorted = false, + scatter_dimension_numbers = #stablehlo.scatter< + update_window_dims = [3], + inserted_window_dims = [1, 3], +@@ -176,9 +472,9 @@ + // ----- + + // CHECK-LABEL: @scatter_with_batching_no_index_vector_dim ++// CHECK-NEXT: %[[reshape:.*]] = stablehlo.reshape %arg1 : (tensor<4x3x5xi32>) -> tensor<4x3x5x1xi32> + // CHECK-NEXT: %[[iota_dim1:.*]] = stablehlo.iota dim = 1 : tensor<4x3x5x1xi32> + // CHECK-NEXT: %[[iota_dim0:.*]] = stablehlo.iota dim = 0 : tensor<4x3x5x1xi32> +-// CHECK-NEXT: %[[reshape:.*]] = stablehlo.reshape %arg1 : (tensor<4x3x5xi32>) -> tensor<4x3x5x1xi32> + // CHECK-NEXT: %[[concat:.*]] = stablehlo.concatenate %[[iota_dim1]], %[[iota_dim0]], %[[reshape]], dim = 3 : (tensor<4x3x5x1xi32>, tensor<4x3x5x1xi32>, tensor<4x3x5x1xi32>) -> tensor<4x3x5x3xi32> + // CHECK-NEXT: %[[scatter:.*]] = "stablehlo.scatter"(%arg0, %[[concat]], %arg2) <{ + // CHECK-SAME: indices_are_sorted = false, +@@ -192,7 +488,7 @@ + // CHECK-NO-DOWNGRADE: input_batching_dims = [0, 2] + // CHECK-NO-DOWNGRADE: scatter_indices_batching_dims = [1, 0] + %0 = "stablehlo.scatter"(%arg0, %arg1, %arg2) <{ +- indices_are_sorted = true, ++ indices_are_sorted = false, + scatter_dimension_numbers = #stablehlo.scatter< + update_window_dims = [3], + inserted_window_dims = [1], +@@ -208,3 +504,60 @@ + }) : (tensor<3x2x4x9xi32>, tensor<4x3x5xi32>, tensor<4x3x5x8xi32>) -> tensor<3x2x4x9xi32> + func.return %0 : tensor<3x2x4x9xi32> + } ++ ++// ----- ++ ++// CHECK-LABEL: @scatter_batching_dims_indices_remain_sorted ++// CHECK-NEXT: %[[iota_dim1:.*]] = stablehlo.iota dim = 0 : tensor<2x3x5x1xi32> ++// CHECK-NEXT: %[[iota_dim0:.*]] = stablehlo.iota dim = 2 : tensor<2x3x5x1xi32> ++// CHECK-NEXT: %[[concat:.*]] = stablehlo.concatenate %[[iota_dim1]], %[[iota_dim0]], %arg1, dim = 3 : (tensor<2x3x5x1xi32>, tensor<2x3x5x1xi32>, tensor<2x3x5x2xi32>) -> tensor<2x3x5x4xi32> ++// CHECK-NEXT: %[[scatter:.*]] = "stablehlo.scatter"(%arg0, %[[concat]], %arg2) <{ ++// CHECK-SAME: indices_are_sorted = true, ++// CHECK-SAME: dimension_numbers = #stablehlo.scatter< ++// CHECK-SAME: update_window_dims = [3], inserted_window_dims = [0, 1, 2, 3], ++// CHECK-SAME: scatter_dims_to_operand_dims = [0, 1, 2, 3], index_vector_dim = 3>, ++// CHECK-SAME: unique_indices = false}> ++// CHECK: (tensor<2x5x4x7x9xi32>, tensor<2x3x5x4xi32>, tensor<2x3x5x8xi32>) -> tensor<2x5x4x7x9xi32> ++// CHECK-NEXT: return %[[scatter]] : tensor<2x5x4x7x9xi32> ++func.func @scatter_batching_dims_indices_remain_sorted(%arg0: tensor<2x5x4x7x9xi32>, %arg1: tensor<2x3x5x2xi32>, %arg2: tensor<2x3x5x8xi32>) -> tensor<2x5x4x7x9xi32> { ++ %0 = "stablehlo.scatter"(%arg0, %arg1, %arg2) <{ ++ indices_are_sorted = true, ++ scatter_dimension_numbers = #stablehlo.scatter< ++ update_window_dims = [3], ++ inserted_window_dims = [2, 3], ++ input_batching_dims = [0, 1], ++ scatter_indices_batching_dims = [0, 2], ++ scatter_dims_to_operand_dims = [2, 3], ++ index_vector_dim = 3 ++ >, ++ unique_indices = false ++ }> ({ ++ ^bb0(%arg3: tensor, %arg4: tensor): ++ stablehlo.return %arg4 : tensor ++ }) : (tensor<2x5x4x7x9xi32>, tensor<2x3x5x2xi32>, tensor<2x3x5x8xi32>) -> tensor<2x5x4x7x9xi32> ++ func.return %0 : tensor<2x5x4x7x9xi32> ++} ++ ++// ----- ++ ++// CHECK-LABEL: @scatter_batching_dim_dynamic_scatter_indices ++// CHECK: input_batching_dims = [0, 2] ++// CHECK: scatter_indices_batching_dims = [1, 0] ++func.func @scatter_batching_dim_dynamic_scatter_indices(%arg0: tensor, %arg1: tensor<4x?x5x2xi32>, %arg2: tensor<4x?x5x8xi32>) -> tensor { ++ %0 = "stablehlo.scatter"(%arg0, %arg1, %arg2) <{ ++ indices_are_sorted = false, ++ scatter_dimension_numbers = #stablehlo.scatter< ++ update_window_dims = [3], ++ inserted_window_dims = [1, 3], ++ input_batching_dims = [0, 2], ++ scatter_indices_batching_dims = [1, 0], ++ scatter_dims_to_operand_dims = [1, 3], ++ index_vector_dim = 3 ++ >, ++ unique_indices = false ++ }> ({ ++ ^bb0(%arg3: tensor, %arg4: tensor): ++ stablehlo.return %arg4 : tensor ++ }) : (tensor, tensor<4x?x5x2xi32>, tensor<4x?x5x8xi32>) -> tensor ++ func.return %0 : tensor ++} +diff --ruN a/stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpander.cpp b/stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpander.cpp +--- stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpander.cpp ++++ stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpander.cpp +@@ -22,8 +22,11 @@ + #include "llvm/ADT/STLExtras.h" + #include "llvm/ADT/SmallVector.h" + #include "llvm/Support/ErrorHandling.h" ++#include "llvm/Support/MathExtras.h" + #include "mlir/Dialect/Func/IR/FuncOps.h" ++#include "mlir/IR/Builders.h" + #include "mlir/IR/BuiltinAttributes.h" ++#include "mlir/IR/BuiltinTypeInterfaces.h" + #include "mlir/IR/BuiltinTypes.h" + #include "mlir/IR/Diagnostics.h" + #include "mlir/IR/PatternMatch.h" +@@ -75,6 +78,42 @@ + return result; + } + ++bool fitsInIntegralType(int64_t size, IntegerType type) { ++ if (type.isUnsigned()) { ++ return llvm::isUIntN(type.getWidth(), size); ++ } else { ++ return llvm::isIntN(type.getWidth(), size); ++ } ++} ++ ++// If `type` is an integer type in which `size` doesn't fit, promote it to i32 ++// or i64 (depending on `size`). ++Type promoteTypeForSize(Type type, int64_t size, OpBuilder &builder) { ++ // Gather/Scatter should have an integer type, but we check just in case. ++ auto intType = dyn_cast(type); ++ if (!intType || fitsInIntegralType(size, intType)) { ++ return type; ++ } ++ if (fitsInIntegralType(size, builder.getI32Type())) { ++ return builder.getI32Type(); ++ } ++ return builder.getI64Type(); ++} ++ ++// If `indices_batching_dims` and `updated_index_map` are both sorted, then the ++// `indices_are_sorted` property is preserved. ++// ++// This is because each concatenated iota is monotonically increasing, sorted ++// indices batching dims mean their order corresponds to the order of batching ++// dims in the operand, and a sorted updated start index map means the order of ++// the index vector dim corresponds to the order of operand dims. ++bool getUpdatedIndicesAreSorted(bool indices_are_sorted, ++ ArrayRef indices_batching_dims, ++ ArrayRef updated_index_map) { ++ return indices_are_sorted && llvm::is_sorted(indices_batching_dims) && ++ llvm::is_sorted(updated_index_map); ++} ++ + // Returns an updated indices tensor such that an `IotaOp` is prepended for each + // dim in `indicesBatchingDims` with a `ConcatenateOp`. + // +@@ -85,16 +124,31 @@ + PatternRewriter &rewriter) { + Location loc = indices.getLoc(); + auto indicesType = cast(indices.getType()); ++ Type elementType = indicesType.getElementType(); ++ ++ // The batching dim sizes might not fit in the existing element type, ++ // in which case we need to promote it. ++ for (int64_t batchingDim : indicesBatchingDims) { ++ elementType = promoteTypeForSize( ++ elementType, indicesType.getDimSize(batchingDim), rewriter); ++ } ++ if (elementType != indicesType.getElementType()) { ++ indicesType = RankedTensorType::get(indicesType.getShape(), elementType); ++ indices = rewriter.create(loc, indicesType, indices); ++ } ++ + bool indexVectorDimOnLastDim = indexVectorDim == indicesType.getRank(); +- + SmallVector iotaShape(indicesType.getShape()); + if (indexVectorDimOnLastDim) { + iotaShape.push_back(1); + } else { + iotaShape[indexVectorDim] = 1; + } +- auto iotaType = +- RankedTensorType::get(iotaShape, indicesType.getElementType()); ++ auto iotaType = RankedTensorType::get(iotaShape, elementType); ++ ++ if (indexVectorDimOnLastDim) { ++ indices = rewriter.create(loc, iotaType, indices); ++ } + + SmallVector indicesToConcat; + indicesToConcat.reserve(indicesBatchingDims.size() + 1); +@@ -102,12 +156,7 @@ + indicesToConcat.push_back( + rewriter.create(loc, iotaType, batchingDim)); + } +- if (indexVectorDimOnLastDim) { +- indicesToConcat.push_back( +- rewriter.create(loc, iotaType, indices)); +- } else { +- indicesToConcat.push_back(indices); +- } ++ indicesToConcat.push_back(indices); + return rewriter.create(loc, indicesToConcat, indexVectorDim); + } + +@@ -125,9 +174,17 @@ + PatternRewriter &rewriter) const override { + GatherDimensionNumbersAttr dimNumbers = op.getDimensionNumbers(); + ArrayRef operandBatchingDims = dimNumbers.getOperandBatchingDims(); ++ ArrayRef startIndicesBatchingDims = ++ dimNumbers.getStartIndicesBatchingDims(); + if (operandBatchingDims.empty()) { + return rewriter.notifyMatchFailure(op, [](Diagnostic &diag) { + diag << "gather op has no batching dims"; ++ }); ++ } ++ ++ if (!op.getStartIndices().getType().hasStaticShape()) { ++ return rewriter.notifyMatchFailure(op, [](Diagnostic &diag) { ++ diag << "gather op has start indices with dynamic shape, can't expand"; + }); + } + +@@ -136,16 +193,18 @@ + SmallVector newStartIndexMap = + llvm::to_vector(llvm::concat( + operandBatchingDims, dimNumbers.getStartIndexMap())); +- Value newIndices = createConcatIndices( +- op.getStartIndices(), dimNumbers.getIndexVectorDim(), +- dimNumbers.getStartIndicesBatchingDims(), rewriter); ++ Value newIndices = createConcatIndices(op.getStartIndices(), ++ dimNumbers.getIndexVectorDim(), ++ startIndicesBatchingDims, rewriter); + rewriter.replaceOpWithNewOp( + op, op.getOperand(), newIndices, + GatherDimensionNumbersAttr::get( + op.getContext(), dimNumbers.getOffsetDims(), newCollapsedSliceDims, + /*operandBatchingDims=*/{}, /*startIndicesBatchingDims=*/{}, + newStartIndexMap, dimNumbers.getIndexVectorDim()), +- op.getSliceSizes(), /*indicesAreSorted=*/false); ++ op.getSliceSizes(), ++ getUpdatedIndicesAreSorted(op.getIndicesAreSorted(), ++ startIndicesBatchingDims, newStartIndexMap)); + + return success(); + } +@@ -161,9 +220,17 @@ + PatternRewriter &rewriter) const override { + ScatterDimensionNumbersAttr dimNumbers = op.getScatterDimensionNumbers(); + ArrayRef inputBatchingDims = dimNumbers.getInputBatchingDims(); ++ ArrayRef scatterIndicesBatchingDims = ++ dimNumbers.getScatterIndicesBatchingDims(); + if (inputBatchingDims.empty()) { + return rewriter.notifyMatchFailure(op, [](Diagnostic &diag) { + diag << "scatter op has no batching dims"; ++ }); ++ } ++ ++ if (!op.getScatterIndices().getType().hasStaticShape()) { ++ return rewriter.notifyMatchFailure(op, [](Diagnostic &diag) { ++ diag << "gather op has start indices with dynamic shape, can't expand"; + }); + } + +@@ -174,7 +241,7 @@ + inputBatchingDims, dimNumbers.getScatterDimsToOperandDims())); + Value newIndices = createConcatIndices( + op.getScatterIndices(), dimNumbers.getIndexVectorDim(), +- dimNumbers.getScatterIndicesBatchingDims(), rewriter); ++ scatterIndicesBatchingDims, rewriter); + auto newScatterOp = rewriter.create( + op.getLoc(), op->getResultTypes(), op.getInputs(), newIndices, + op.getUpdates(), +@@ -183,7 +250,10 @@ + newInsertedWindowDims, + /*inputBatchingDims=*/{}, /*scatterIndicesBatchingDims=*/{}, + newScatterDimsToOperandDims, dimNumbers.getIndexVectorDim()), +- /*indicesAreSorted=*/false, op.getUniqueIndices()); ++ getUpdatedIndicesAreSorted(op.getIndicesAreSorted(), ++ scatterIndicesBatchingDims, ++ newScatterDimsToOperandDims), ++ op.getUniqueIndices()); + + newScatterOp.getUpdateComputation().takeBody(op.getUpdateComputation()); + rewriter.replaceOp(op, newScatterOp.getResults()); diff --git a/third_party/xla/third_party/stablehlo/temporary.patch b/third_party/xla/third_party/stablehlo/temporary.patch index 8b137891791fe9..b9970316773eaf 100755 --- a/third_party/xla/third_party/stablehlo/temporary.patch +++ b/third_party/xla/third_party/stablehlo/temporary.patch @@ -1 +1,617 @@ +diff --ruN a/stablehlo/stablehlo/tests/transforms/stablehlo_create_compatibility_expander.mlir b/stablehlo/stablehlo/tests/transforms/stablehlo_create_compatibility_expander.mlir +--- stablehlo/stablehlo/tests/transforms/stablehlo_create_compatibility_expander.mlir ++++ stablehlo/stablehlo/tests/transforms/stablehlo_create_compatibility_expander.mlir +@@ -69,7 +69,7 @@ + index_vector_dim = 3 + >, + slice_sizes = array, +- indices_are_sorted = true ++ indices_are_sorted = false + } : (tensor<3x2x4x7x9xi32>, tensor<4x3x5x2xi32>) -> tensor<4x3x5x8xi32> + func.return %0 : tensor<4x3x5x8xi32> + } +@@ -77,9 +77,9 @@ + // ----- + + // CHECK-LABEL: @gather_with_batching_no_index_vector_dim ++// CHECK-NEXT: %[[reshape:.*]] = stablehlo.reshape %arg1 : (tensor<4x3x5xi32>) -> tensor<4x3x5x1xi32> + // CHECK-NEXT: %[[iota_dim1:.*]] = stablehlo.iota dim = 1 : tensor<4x3x5x1xi32> + // CHECK-NEXT: %[[iota_dim0:.*]] = stablehlo.iota dim = 0 : tensor<4x3x5x1xi32> +-// CHECK-NEXT: %[[reshape:.*]] = stablehlo.reshape %arg1 : (tensor<4x3x5xi32>) -> tensor<4x3x5x1xi32> + // CHECK-NEXT: %[[concat:.*]] = stablehlo.concatenate %[[iota_dim1]], %[[iota_dim0]], %[[reshape]], dim = 3 : (tensor<4x3x5x1xi32>, tensor<4x3x5x1xi32>, tensor<4x3x5x1xi32>) -> tensor<4x3x5x3xi32> + // CHECK-NEXT: %[[gather:.*]] = "stablehlo.gather"(%arg0, %[[concat]]) <{ + // CHECK-SAME: dimension_numbers = #stablehlo.gather< +@@ -102,7 +102,7 @@ + index_vector_dim = 3 + >, + slice_sizes = array, +- indices_are_sorted = true ++ indices_are_sorted = false + }> : (tensor<3x2x4x9xi32>, tensor<4x3x5xi32>) -> tensor<4x3x5x8xi32> + func.return %0 : tensor<4x3x5x8xi32> + } +@@ -133,9 +133,305 @@ + index_vector_dim = 3 + >, + slice_sizes = array, +- indices_are_sorted = true ++ indices_are_sorted = false + }> : (tensor<0x2x9xi32>, tensor<0x3x5x1xi32>) -> tensor<0x3x5x8xi32> + func.return %0 : tensor<0x3x5x8xi32> ++} ++ ++// ----- ++ ++// CHECK-LABEL: @gather_batching_dims_indices_become_unsorted ++// CHECK-NEXT: %[[iota_dim1:.*]] = stablehlo.iota dim = 0 : tensor<3x4x5x1xi32> ++// CHECK-NEXT: %[[iota_dim0:.*]] = stablehlo.iota dim = 1 : tensor<3x4x5x1xi32> ++// CHECK-NEXT: %[[concat:.*]] = stablehlo.concatenate %[[iota_dim1]], %[[iota_dim0]], %arg1, dim = 3 : (tensor<3x4x5x1xi32>, tensor<3x4x5x1xi32>, tensor<3x4x5x2xi32>) -> tensor<3x4x5x4xi32> ++// CHECK-NEXT: %[[gather:.*]] = "stablehlo.gather"(%arg0, %[[concat]]) <{ ++// CHECK-SAME: dimension_numbers = #stablehlo.gather< ++// CHECK-SAME: offset_dims = [3], collapsed_slice_dims = [0, 1, 2, 3], ++// CHECK-SAME: start_index_map = [0, 2, 1, 3], index_vector_dim = 3>, ++// CHECK-SAME: indices_are_sorted = false, ++// CHECK-SAME: slice_sizes = array ++// CHECK-SAME: }> : (tensor<3x2x4x7x9xi32>, tensor<3x4x5x4xi32>) -> tensor<3x4x5x8xi32> ++// CHECK-NEXT: return %[[gather]] : tensor<3x4x5x8xi32> ++func.func @gather_batching_dims_indices_become_unsorted(%arg0: tensor<3x2x4x7x9xi32>, %arg1: tensor<3x4x5x2xi32>) -> tensor<3x4x5x8xi32> { ++ %0 = "stablehlo.gather"(%arg0, %arg1) { ++ dimension_numbers = #stablehlo.gather< ++ offset_dims = [3], ++ collapsed_slice_dims = [1, 3], ++ operand_batching_dims = [0, 2], ++ start_indices_batching_dims = [0, 1], ++ start_index_map = [1, 3], ++ index_vector_dim = 3 ++ >, ++ slice_sizes = array, ++ indices_are_sorted = true ++ } : (tensor<3x2x4x7x9xi32>, tensor<3x4x5x2xi32>) -> tensor<3x4x5x8xi32> ++ func.return %0 : tensor<3x4x5x8xi32> ++} ++ ++// ----- ++ ++// CHECK-LABEL: @gather_batching_dims_indices_become_unsorted_2 ++// CHECK-NEXT: %[[iota_dim1:.*]] = stablehlo.iota dim = 1 : tensor<2x3x5x1xi32> ++// CHECK-NEXT: %[[iota_dim0:.*]] = stablehlo.iota dim = 0 : tensor<2x3x5x1xi32> ++// CHECK-NEXT: %[[concat:.*]] = stablehlo.concatenate %[[iota_dim1]], %[[iota_dim0]], %arg1, dim = 3 : (tensor<2x3x5x1xi32>, tensor<2x3x5x1xi32>, tensor<2x3x5x2xi32>) -> tensor<2x3x5x4xi32> ++// CHECK-NEXT: %[[gather:.*]] = "stablehlo.gather"(%arg0, %[[concat]]) <{ ++// CHECK-SAME: dimension_numbers = #stablehlo.gather< ++// CHECK-SAME: offset_dims = [3], collapsed_slice_dims = [0, 1, 2, 3], ++// CHECK-SAME: start_index_map = [0, 1, 2, 3], index_vector_dim = 3>, ++// CHECK-SAME: indices_are_sorted = false, ++// CHECK-SAME: slice_sizes = array ++// CHECK-SAME: }> : (tensor<3x2x4x7x9xi32>, tensor<2x3x5x4xi32>) -> tensor<2x3x5x8xi32> ++// CHECK-NEXT: return %[[gather]] : tensor<2x3x5x8xi32> ++func.func @gather_batching_dims_indices_become_unsorted_2(%arg0: tensor<3x2x4x7x9xi32>, %arg1: tensor<2x3x5x2xi32>) -> tensor<2x3x5x8xi32> { ++ %0 = "stablehlo.gather"(%arg0, %arg1) { ++ dimension_numbers = #stablehlo.gather< ++ offset_dims = [3], ++ collapsed_slice_dims = [2, 3], ++ operand_batching_dims = [0, 1], ++ start_indices_batching_dims = [1, 0], ++ start_index_map = [2, 3], ++ index_vector_dim = 3 ++ >, ++ slice_sizes = array, ++ indices_are_sorted = true ++ } : (tensor<3x2x4x7x9xi32>, tensor<2x3x5x2xi32>) -> tensor<2x3x5x8xi32> ++ func.return %0 : tensor<2x3x5x8xi32> ++} ++ ++// ----- ++ ++// CHECK-LABEL: @gather_batching_dims_indices_remain_sorted ++// CHECK-NEXT: %[[iota_dim1:.*]] = stablehlo.iota dim = 0 : tensor<2x3x5x1xi32> ++// CHECK-NEXT: %[[iota_dim0:.*]] = stablehlo.iota dim = 2 : tensor<2x3x5x1xi32> ++// CHECK-NEXT: %[[concat:.*]] = stablehlo.concatenate %[[iota_dim1]], %[[iota_dim0]], %arg1, dim = 3 : (tensor<2x3x5x1xi32>, tensor<2x3x5x1xi32>, tensor<2x3x5x2xi32>) -> tensor<2x3x5x4xi32> ++// CHECK-NEXT: %[[gather:.*]] = "stablehlo.gather"(%arg0, %[[concat]]) <{ ++// CHECK-SAME: dimension_numbers = #stablehlo.gather< ++// CHECK-SAME: offset_dims = [3], collapsed_slice_dims = [0, 1, 2, 3], ++// CHECK-SAME: start_index_map = [0, 1, 2, 3], index_vector_dim = 3>, ++// CHECK-SAME: indices_are_sorted = true, ++// CHECK-SAME: slice_sizes = array ++// CHECK-SAME: }> : (tensor<2x5x4x7x9xi32>, tensor<2x3x5x4xi32>) -> tensor<2x3x5x8xi32> ++// CHECK-NEXT: return %[[gather]] : tensor<2x3x5x8xi32> ++func.func @gather_batching_dims_indices_remain_sorted(%arg0: tensor<2x5x4x7x9xi32>, %arg1: tensor<2x3x5x2xi32>) -> tensor<2x3x5x8xi32> { ++ %0 = "stablehlo.gather"(%arg0, %arg1) { ++ dimension_numbers = #stablehlo.gather< ++ offset_dims = [3], ++ collapsed_slice_dims = [2, 3], ++ operand_batching_dims = [0, 1], ++ start_indices_batching_dims = [0, 2], ++ start_index_map = [2, 3], ++ index_vector_dim = 3 ++ >, ++ slice_sizes = array, ++ indices_are_sorted = true ++ } : (tensor<2x5x4x7x9xi32>, tensor<2x3x5x2xi32>) -> tensor<2x3x5x8xi32> ++ func.return %0 : tensor<2x3x5x8xi32> ++} ++ ++// ----- ++ ++// CHECK-LABEL: @gather_batching_dims_indices_remain_unsorted ++// CHECK-NEXT: %[[iota_dim1:.*]] = stablehlo.iota dim = 0 : tensor<2x3x5x1xi32> ++// CHECK-NEXT: %[[iota_dim0:.*]] = stablehlo.iota dim = 2 : tensor<2x3x5x1xi32> ++// CHECK-NEXT: %[[concat:.*]] = stablehlo.concatenate %[[iota_dim1]], %[[iota_dim0]], %arg1, dim = 3 : (tensor<2x3x5x1xi32>, tensor<2x3x5x1xi32>, tensor<2x3x5x2xi32>) -> tensor<2x3x5x4xi32> ++// CHECK-NEXT: %[[gather:.*]] = "stablehlo.gather"(%arg0, %[[concat]]) <{ ++// CHECK-SAME: dimension_numbers = #stablehlo.gather< ++// CHECK-SAME: offset_dims = [3], collapsed_slice_dims = [0, 1, 2, 3], ++// CHECK-SAME: start_index_map = [0, 1, 2, 3], index_vector_dim = 3>, ++// CHECK-SAME: indices_are_sorted = false, ++// CHECK-SAME: slice_sizes = array ++// CHECK-SAME: }> : (tensor<2x5x4x7x9xi32>, tensor<2x3x5x4xi32>) -> tensor<2x3x5x8xi32> ++// CHECK-NEXT: return %[[gather]] : tensor<2x3x5x8xi32> ++func.func @gather_batching_dims_indices_remain_unsorted(%arg0: tensor<2x5x4x7x9xi32>, %arg1: tensor<2x3x5x2xi32>) -> tensor<2x3x5x8xi32> { ++ %0 = "stablehlo.gather"(%arg0, %arg1) { ++ dimension_numbers = #stablehlo.gather< ++ offset_dims = [3], ++ collapsed_slice_dims = [2, 3], ++ operand_batching_dims = [0, 1], ++ start_indices_batching_dims = [0, 2], ++ start_index_map = [2, 3], ++ index_vector_dim = 3 ++ >, ++ slice_sizes = array, ++ indices_are_sorted = false ++ } : (tensor<2x5x4x7x9xi32>, tensor<2x3x5x2xi32>) -> tensor<2x3x5x8xi32> ++ func.return %0 : tensor<2x3x5x8xi32> ++} ++ ++// ----- ++ ++// CHECK-LABEL: @gather_batching_dims_does_not_overflow_indices_type ++// CHECK-NEXT: %[[iota_dim1:.*]] = stablehlo.iota dim = 1 : tensor<4x127x5x1xi8> ++// CHECK-NEXT: %[[iota_dim0:.*]] = stablehlo.iota dim = 0 : tensor<4x127x5x1xi8> ++// CHECK-NEXT: %[[concat:.*]] = stablehlo.concatenate %[[iota_dim1]], %[[iota_dim0]], %arg1, dim = 3 : (tensor<4x127x5x1xi8>, tensor<4x127x5x1xi8>, tensor<4x127x5x2xi8>) -> tensor<4x127x5x4xi8> ++// CHECK-NEXT: %[[gather:.*]] = "stablehlo.gather"(%arg0, %[[concat]]) <{ ++// CHECK-SAME: dimension_numbers = #stablehlo.gather< ++// CHECK-SAME: offset_dims = [3], collapsed_slice_dims = [0, 1, 2, 3], ++// CHECK-SAME: start_index_map = [0, 2, 1, 3], index_vector_dim = 3>, ++// CHECK-SAME: indices_are_sorted = false, ++// CHECK-SAME: slice_sizes = array ++// CHECK-SAME: }> : (tensor<127x2x4x7x9xi32>, tensor<4x127x5x4xi8>) -> tensor<4x127x5x8xi32> ++// CHECK-NEXT: return %[[gather]] : tensor<4x127x5x8xi32> ++func.func @gather_batching_dims_does_not_overflow_indices_type(%arg0: tensor<127x2x4x7x9xi32>, %arg1: tensor<4x127x5x2xi8>) -> tensor<4x127x5x8xi32> { ++ %0 = "stablehlo.gather"(%arg0, %arg1) { ++ dimension_numbers = #stablehlo.gather< ++ offset_dims = [3], ++ collapsed_slice_dims = [1, 3], ++ operand_batching_dims = [0, 2], ++ start_indices_batching_dims = [1, 0], ++ start_index_map = [1, 3], ++ index_vector_dim = 3 ++ >, ++ slice_sizes = array, ++ indices_are_sorted = false ++ } : (tensor<127x2x4x7x9xi32>, tensor<4x127x5x2xi8>) -> tensor<4x127x5x8xi32> ++ func.return %0 : tensor<4x127x5x8xi32> ++} ++ ++// ----- ++ ++// CHECK-LABEL: @gather_batching_dim_overflows_signless_indices_type ++// CHECK-NEXT: %[[convert:.*]] = stablehlo.convert %arg1 : (tensor<4x128x5x2xi8>) -> tensor<4x128x5x2xi32> ++// CHECK-NEXT: %[[iota_dim1:.*]] = stablehlo.iota dim = 1 : tensor<4x128x5x1xi32> ++// CHECK-NEXT: %[[iota_dim0:.*]] = stablehlo.iota dim = 0 : tensor<4x128x5x1xi32> ++// CHECK-NEXT: %[[concat:.*]] = stablehlo.concatenate %[[iota_dim1]], %[[iota_dim0]], %[[convert]], dim = 3 : (tensor<4x128x5x1xi32>, tensor<4x128x5x1xi32>, tensor<4x128x5x2xi32>) -> tensor<4x128x5x4xi32> ++// CHECK-NEXT: %[[gather:.*]] = "stablehlo.gather"(%arg0, %[[concat]]) <{ ++// CHECK-SAME: dimension_numbers = #stablehlo.gather< ++// CHECK-SAME: offset_dims = [3], collapsed_slice_dims = [0, 1, 2, 3], ++// CHECK-SAME: start_index_map = [0, 2, 1, 3], index_vector_dim = 3>, ++// CHECK-SAME: indices_are_sorted = false, ++// CHECK-SAME: slice_sizes = array ++// CHECK-SAME: }> : (tensor<128x2x4x7x9xi32>, tensor<4x128x5x4xi32>) -> tensor<4x128x5x8xi32> ++// CHECK-NEXT: return %[[gather]] : tensor<4x128x5x8xi32> ++func.func @gather_batching_dim_overflows_signless_indices_type(%arg0: tensor<128x2x4x7x9xi32>, %arg1: tensor<4x128x5x2xi8>) -> tensor<4x128x5x8xi32> { ++ %0 = "stablehlo.gather"(%arg0, %arg1) { ++ dimension_numbers = #stablehlo.gather< ++ offset_dims = [3], ++ collapsed_slice_dims = [1, 3], ++ operand_batching_dims = [0, 2], ++ start_indices_batching_dims = [1, 0], ++ start_index_map = [1, 3], ++ index_vector_dim = 3 ++ >, ++ slice_sizes = array, ++ indices_are_sorted = false ++ } : (tensor<128x2x4x7x9xi32>, tensor<4x128x5x2xi8>) -> tensor<4x128x5x8xi32> ++ func.return %0 : tensor<4x128x5x8xi32> ++} ++ ++// ----- ++ ++// CHECK-LABEL: @gather_batching_dim_overflows_unsigned_indices_type ++// CHECK-NEXT: %[[convert:.*]] = stablehlo.convert %arg1 : (tensor<256x4x5x2xui8>) -> tensor<256x4x5x2xi32> ++// CHECK-NEXT: %[[iota_dim0:.*]] = stablehlo.iota dim = 0 : tensor<256x4x5x1xi32> ++// CHECK-NEXT: %[[iota_dim1:.*]] = stablehlo.iota dim = 1 : tensor<256x4x5x1xi32> ++// CHECK-NEXT: %[[concat:.*]] = stablehlo.concatenate %[[iota_dim0]], %[[iota_dim1]], %[[convert]], dim = 3 : (tensor<256x4x5x1xi32>, tensor<256x4x5x1xi32>, tensor<256x4x5x2xi32>) -> tensor<256x4x5x4xi32> ++// CHECK-NEXT: %[[gather:.*]] = "stablehlo.gather"(%arg0, %[[concat]]) <{ ++// CHECK-SAME: dimension_numbers = #stablehlo.gather< ++// CHECK-SAME: offset_dims = [3], collapsed_slice_dims = [0, 1, 2, 3], ++// CHECK-SAME: start_index_map = [0, 2, 1, 3], index_vector_dim = 3>, ++// CHECK-SAME: indices_are_sorted = false, ++// CHECK-SAME: slice_sizes = array ++// CHECK-SAME: }> : (tensor<256x2x4x7x9xi32>, tensor<256x4x5x4xi32>) -> tensor<256x4x5x8xi32> ++// CHECK-NEXT: return %[[gather]] : tensor<256x4x5x8xi32> ++func.func @gather_batching_dim_overflows_unsigned_indices_type(%arg0: tensor<256x2x4x7x9xi32>, %arg1: tensor<256x4x5x2xui8>) -> tensor<256x4x5x8xi32> { ++ %0 = "stablehlo.gather"(%arg0, %arg1) { ++ dimension_numbers = #stablehlo.gather< ++ offset_dims = [3], ++ collapsed_slice_dims = [1, 3], ++ operand_batching_dims = [0, 2], ++ start_indices_batching_dims = [0, 1], ++ start_index_map = [1, 3], ++ index_vector_dim = 3 ++ >, ++ slice_sizes = array, ++ indices_are_sorted = false ++ } : (tensor<256x2x4x7x9xi32>, tensor<256x4x5x2xui8>) -> tensor<256x4x5x8xi32> ++ func.return %0 : tensor<256x4x5x8xi32> ++} ++ ++// ----- ++ ++// CHECK-LABEL: @gather_batching_dim_overflows_indices_type_and_i32 ++// CHECK-NEXT: %[[convert:.*]] = stablehlo.convert %arg1 : (tensor<4x2147483648x5x2xi8>) -> tensor<4x2147483648x5x2xi64> ++// CHECK-NEXT: %[[iota_dim1:.*]] = stablehlo.iota dim = 1 : tensor<4x2147483648x5x1xi64> ++// CHECK-NEXT: %[[iota_dim0:.*]] = stablehlo.iota dim = 0 : tensor<4x2147483648x5x1xi64> ++// CHECK-NEXT: %[[concat:.*]] = stablehlo.concatenate %[[iota_dim1]], %[[iota_dim0]], %[[convert]], dim = 3 : (tensor<4x2147483648x5x1xi64>, tensor<4x2147483648x5x1xi64>, tensor<4x2147483648x5x2xi64>) -> tensor<4x2147483648x5x4xi64> ++// CHECK-NEXT: %[[gather:.*]] = "stablehlo.gather"(%arg0, %[[concat]]) <{ ++// CHECK-SAME: dimension_numbers = #stablehlo.gather< ++// CHECK-SAME: offset_dims = [3], collapsed_slice_dims = [0, 1, 2, 3], ++// CHECK-SAME: start_index_map = [0, 2, 1, 3], index_vector_dim = 3>, ++// CHECK-SAME: indices_are_sorted = false, ++// CHECK-SAME: slice_sizes = array ++// CHECK-SAME: }> : (tensor<2147483648x2x4x7x9xi32>, tensor<4x2147483648x5x4xi64>) -> tensor<4x2147483648x5x8xi32> ++// CHECK-NEXT: return %[[gather]] : tensor<4x2147483648x5x8xi32> ++func.func @gather_batching_dim_overflows_indices_type_and_i32(%arg0: tensor<2147483648x2x4x7x9xi32>, %arg1: tensor<4x2147483648x5x2xi8>) -> tensor<4x2147483648x5x8xi32> { ++ %0 = "stablehlo.gather"(%arg0, %arg1) { ++ dimension_numbers = #stablehlo.gather< ++ offset_dims = [3], ++ collapsed_slice_dims = [1, 3], ++ operand_batching_dims = [0, 2], ++ start_indices_batching_dims = [1, 0], ++ start_index_map = [1, 3], ++ index_vector_dim = 3 ++ >, ++ slice_sizes = array, ++ indices_are_sorted = false ++ } : (tensor<2147483648x2x4x7x9xi32>, tensor<4x2147483648x5x2xi8>) -> tensor<4x2147483648x5x8xi32> ++ func.return %0 : tensor<4x2147483648x5x8xi32> ++} ++ ++// ----- ++ ++// CHECK-LABEL: @gather_batching_dim_dynamic_size ++// CHECK: operand_batching_dims = [0, 2] ++// CHECK: start_indices_batching_dims = [1, 0] ++func.func @gather_batching_dim_dynamic_size(%arg0: tensor, %arg1: tensor<4x?x5x2xi8>) -> tensor<4x?x5x8xi32> { ++ %0 = "stablehlo.gather"(%arg0, %arg1) { ++ dimension_numbers = #stablehlo.gather< ++ offset_dims = [3], ++ collapsed_slice_dims = [1, 3], ++ operand_batching_dims = [0, 2], ++ start_indices_batching_dims = [1, 0], ++ start_index_map = [1, 3], ++ index_vector_dim = 3 ++ >, ++ slice_sizes = array, ++ indices_are_sorted = false ++ } : (tensor, tensor<4x?x5x2xi8>) -> tensor<4x?x5x8xi32> ++ func.return %0 : tensor<4x?x5x8xi32> ++} ++ ++// ----- ++ ++// CHECK-LABEL: @gather_batching_dim_overflows_and_no_index_vector_dim ++// CHECK-NEXT: %[[convert:.*]] = stablehlo.convert %arg1 : (tensor<4x128x5xi8>) -> tensor<4x128x5xi32> ++// CHECK-NEXT: %[[reshape:.*]] = stablehlo.reshape %[[convert]] : (tensor<4x128x5xi32>) -> tensor<4x128x5x1xi32> ++// CHECK-NEXT: %[[iota_dim1:.*]] = stablehlo.iota dim = 1 : tensor<4x128x5x1xi32> ++// CHECK-NEXT: %[[iota_dim0:.*]] = stablehlo.iota dim = 0 : tensor<4x128x5x1xi32> ++// CHECK-NEXT: %[[concat:.*]] = stablehlo.concatenate %[[iota_dim1]], %[[iota_dim0]], %[[reshape]], dim = 3 : (tensor<4x128x5x1xi32>, tensor<4x128x5x1xi32>, tensor<4x128x5x1xi32>) -> tensor<4x128x5x3xi32> ++// CHECK-NEXT: %[[gather:.*]] = "stablehlo.gather"(%arg0, %[[concat]]) <{ ++// CHECK-SAME: dimension_numbers = #stablehlo.gather< ++// CHECK-SAME: offset_dims = [3], collapsed_slice_dims = [0, 1, 2], ++// CHECK-SAME: start_index_map = [0, 2, 1], index_vector_dim = 3>, ++// CHECK-SAME: indices_are_sorted = false, ++// CHECK-SAME: slice_sizes = array ++// CHECK-SAME: }> : (tensor<128x2x4x9xi32>, tensor<4x128x5x3xi32>) -> tensor<4x128x5x8xi32> ++// CHECK-NEXT: return %[[gather]] : tensor<4x128x5x8xi32> ++func.func @gather_batching_dim_overflows_and_no_index_vector_dim(%arg0: tensor<128x2x4x9xi32>, %arg1: tensor<4x128x5xi8>) -> tensor<4x128x5x8xi32> { ++ %0 = "stablehlo.gather"(%arg0, %arg1) { ++ dimension_numbers = #stablehlo.gather< ++ offset_dims = [3], ++ collapsed_slice_dims = [1], ++ operand_batching_dims = [0, 2], ++ start_indices_batching_dims = [1, 0], ++ start_index_map = [1], ++ index_vector_dim = 3 ++ >, ++ slice_sizes = array, ++ indices_are_sorted = false ++ } : (tensor<128x2x4x9xi32>, tensor<4x128x5xi8>) -> tensor<4x128x5x8xi32> ++ func.return %0 : tensor<4x128x5x8xi32> + } + + // ----- +@@ -156,7 +452,7 @@ + // CHECK-NO-DOWNGRADE: input_batching_dims = [0, 2] + // CHECK-NO-DOWNGRADE: scatter_indices_batching_dims = [1, 0] + %0 = "stablehlo.scatter"(%arg0, %arg1, %arg2) <{ +- indices_are_sorted = true, ++ indices_are_sorted = false, + scatter_dimension_numbers = #stablehlo.scatter< + update_window_dims = [3], + inserted_window_dims = [1, 3], +@@ -176,9 +472,9 @@ + // ----- + + // CHECK-LABEL: @scatter_with_batching_no_index_vector_dim ++// CHECK-NEXT: %[[reshape:.*]] = stablehlo.reshape %arg1 : (tensor<4x3x5xi32>) -> tensor<4x3x5x1xi32> + // CHECK-NEXT: %[[iota_dim1:.*]] = stablehlo.iota dim = 1 : tensor<4x3x5x1xi32> + // CHECK-NEXT: %[[iota_dim0:.*]] = stablehlo.iota dim = 0 : tensor<4x3x5x1xi32> +-// CHECK-NEXT: %[[reshape:.*]] = stablehlo.reshape %arg1 : (tensor<4x3x5xi32>) -> tensor<4x3x5x1xi32> + // CHECK-NEXT: %[[concat:.*]] = stablehlo.concatenate %[[iota_dim1]], %[[iota_dim0]], %[[reshape]], dim = 3 : (tensor<4x3x5x1xi32>, tensor<4x3x5x1xi32>, tensor<4x3x5x1xi32>) -> tensor<4x3x5x3xi32> + // CHECK-NEXT: %[[scatter:.*]] = "stablehlo.scatter"(%arg0, %[[concat]], %arg2) <{ + // CHECK-SAME: indices_are_sorted = false, +@@ -192,7 +488,7 @@ + // CHECK-NO-DOWNGRADE: input_batching_dims = [0, 2] + // CHECK-NO-DOWNGRADE: scatter_indices_batching_dims = [1, 0] + %0 = "stablehlo.scatter"(%arg0, %arg1, %arg2) <{ +- indices_are_sorted = true, ++ indices_are_sorted = false, + scatter_dimension_numbers = #stablehlo.scatter< + update_window_dims = [3], + inserted_window_dims = [1], +@@ -208,3 +504,60 @@ + }) : (tensor<3x2x4x9xi32>, tensor<4x3x5xi32>, tensor<4x3x5x8xi32>) -> tensor<3x2x4x9xi32> + func.return %0 : tensor<3x2x4x9xi32> + } ++ ++// ----- ++ ++// CHECK-LABEL: @scatter_batching_dims_indices_remain_sorted ++// CHECK-NEXT: %[[iota_dim1:.*]] = stablehlo.iota dim = 0 : tensor<2x3x5x1xi32> ++// CHECK-NEXT: %[[iota_dim0:.*]] = stablehlo.iota dim = 2 : tensor<2x3x5x1xi32> ++// CHECK-NEXT: %[[concat:.*]] = stablehlo.concatenate %[[iota_dim1]], %[[iota_dim0]], %arg1, dim = 3 : (tensor<2x3x5x1xi32>, tensor<2x3x5x1xi32>, tensor<2x3x5x2xi32>) -> tensor<2x3x5x4xi32> ++// CHECK-NEXT: %[[scatter:.*]] = "stablehlo.scatter"(%arg0, %[[concat]], %arg2) <{ ++// CHECK-SAME: indices_are_sorted = true, ++// CHECK-SAME: dimension_numbers = #stablehlo.scatter< ++// CHECK-SAME: update_window_dims = [3], inserted_window_dims = [0, 1, 2, 3], ++// CHECK-SAME: scatter_dims_to_operand_dims = [0, 1, 2, 3], index_vector_dim = 3>, ++// CHECK-SAME: unique_indices = false}> ++// CHECK: (tensor<2x5x4x7x9xi32>, tensor<2x3x5x4xi32>, tensor<2x3x5x8xi32>) -> tensor<2x5x4x7x9xi32> ++// CHECK-NEXT: return %[[scatter]] : tensor<2x5x4x7x9xi32> ++func.func @scatter_batching_dims_indices_remain_sorted(%arg0: tensor<2x5x4x7x9xi32>, %arg1: tensor<2x3x5x2xi32>, %arg2: tensor<2x3x5x8xi32>) -> tensor<2x5x4x7x9xi32> { ++ %0 = "stablehlo.scatter"(%arg0, %arg1, %arg2) <{ ++ indices_are_sorted = true, ++ scatter_dimension_numbers = #stablehlo.scatter< ++ update_window_dims = [3], ++ inserted_window_dims = [2, 3], ++ input_batching_dims = [0, 1], ++ scatter_indices_batching_dims = [0, 2], ++ scatter_dims_to_operand_dims = [2, 3], ++ index_vector_dim = 3 ++ >, ++ unique_indices = false ++ }> ({ ++ ^bb0(%arg3: tensor, %arg4: tensor): ++ stablehlo.return %arg4 : tensor ++ }) : (tensor<2x5x4x7x9xi32>, tensor<2x3x5x2xi32>, tensor<2x3x5x8xi32>) -> tensor<2x5x4x7x9xi32> ++ func.return %0 : tensor<2x5x4x7x9xi32> ++} ++ ++// ----- ++ ++// CHECK-LABEL: @scatter_batching_dim_dynamic_scatter_indices ++// CHECK: input_batching_dims = [0, 2] ++// CHECK: scatter_indices_batching_dims = [1, 0] ++func.func @scatter_batching_dim_dynamic_scatter_indices(%arg0: tensor, %arg1: tensor<4x?x5x2xi32>, %arg2: tensor<4x?x5x8xi32>) -> tensor { ++ %0 = "stablehlo.scatter"(%arg0, %arg1, %arg2) <{ ++ indices_are_sorted = false, ++ scatter_dimension_numbers = #stablehlo.scatter< ++ update_window_dims = [3], ++ inserted_window_dims = [1, 3], ++ input_batching_dims = [0, 2], ++ scatter_indices_batching_dims = [1, 0], ++ scatter_dims_to_operand_dims = [1, 3], ++ index_vector_dim = 3 ++ >, ++ unique_indices = false ++ }> ({ ++ ^bb0(%arg3: tensor, %arg4: tensor): ++ stablehlo.return %arg4 : tensor ++ }) : (tensor, tensor<4x?x5x2xi32>, tensor<4x?x5x8xi32>) -> tensor ++ func.return %0 : tensor ++} +diff --ruN a/stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpander.cpp b/stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpander.cpp +--- stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpander.cpp ++++ stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpander.cpp +@@ -22,8 +22,11 @@ + #include "llvm/ADT/STLExtras.h" + #include "llvm/ADT/SmallVector.h" + #include "llvm/Support/ErrorHandling.h" ++#include "llvm/Support/MathExtras.h" + #include "mlir/Dialect/Func/IR/FuncOps.h" ++#include "mlir/IR/Builders.h" + #include "mlir/IR/BuiltinAttributes.h" ++#include "mlir/IR/BuiltinTypeInterfaces.h" + #include "mlir/IR/BuiltinTypes.h" + #include "mlir/IR/Diagnostics.h" + #include "mlir/IR/PatternMatch.h" +@@ -75,6 +78,42 @@ + return result; + } + ++bool fitsInIntegralType(int64_t size, IntegerType type) { ++ if (type.isUnsigned()) { ++ return llvm::isUIntN(type.getWidth(), size); ++ } else { ++ return llvm::isIntN(type.getWidth(), size); ++ } ++} ++ ++// If `type` is an integer type in which `size` doesn't fit, promote it to i32 ++// or i64 (depending on `size`). ++Type promoteTypeForSize(Type type, int64_t size, OpBuilder &builder) { ++ // Gather/Scatter should have an integer type, but we check just in case. ++ auto intType = dyn_cast(type); ++ if (!intType || fitsInIntegralType(size, intType)) { ++ return type; ++ } ++ if (fitsInIntegralType(size, builder.getI32Type())) { ++ return builder.getI32Type(); ++ } ++ return builder.getI64Type(); ++} ++ ++// If `indices_batching_dims` and `updated_index_map` are both sorted, then the ++// `indices_are_sorted` property is preserved. ++// ++// This is because each concatenated iota is monotonically increasing, sorted ++// indices batching dims mean their order corresponds to the order of batching ++// dims in the operand, and a sorted updated start index map means the order of ++// the index vector dim corresponds to the order of operand dims. ++bool getUpdatedIndicesAreSorted(bool indices_are_sorted, ++ ArrayRef indices_batching_dims, ++ ArrayRef updated_index_map) { ++ return indices_are_sorted && llvm::is_sorted(indices_batching_dims) && ++ llvm::is_sorted(updated_index_map); ++} ++ + // Returns an updated indices tensor such that an `IotaOp` is prepended for each + // dim in `indicesBatchingDims` with a `ConcatenateOp`. + // +@@ -85,16 +124,31 @@ + PatternRewriter &rewriter) { + Location loc = indices.getLoc(); + auto indicesType = cast(indices.getType()); ++ Type elementType = indicesType.getElementType(); ++ ++ // The batching dim sizes might not fit in the existing element type, ++ // in which case we need to promote it. ++ for (int64_t batchingDim : indicesBatchingDims) { ++ elementType = promoteTypeForSize( ++ elementType, indicesType.getDimSize(batchingDim), rewriter); ++ } ++ if (elementType != indicesType.getElementType()) { ++ indicesType = RankedTensorType::get(indicesType.getShape(), elementType); ++ indices = rewriter.create(loc, indicesType, indices); ++ } ++ + bool indexVectorDimOnLastDim = indexVectorDim == indicesType.getRank(); +- + SmallVector iotaShape(indicesType.getShape()); + if (indexVectorDimOnLastDim) { + iotaShape.push_back(1); + } else { + iotaShape[indexVectorDim] = 1; + } +- auto iotaType = +- RankedTensorType::get(iotaShape, indicesType.getElementType()); ++ auto iotaType = RankedTensorType::get(iotaShape, elementType); ++ ++ if (indexVectorDimOnLastDim) { ++ indices = rewriter.create(loc, iotaType, indices); ++ } + + SmallVector indicesToConcat; + indicesToConcat.reserve(indicesBatchingDims.size() + 1); +@@ -102,12 +156,7 @@ + indicesToConcat.push_back( + rewriter.create(loc, iotaType, batchingDim)); + } +- if (indexVectorDimOnLastDim) { +- indicesToConcat.push_back( +- rewriter.create(loc, iotaType, indices)); +- } else { +- indicesToConcat.push_back(indices); +- } ++ indicesToConcat.push_back(indices); + return rewriter.create(loc, indicesToConcat, indexVectorDim); + } + +@@ -125,9 +174,17 @@ + PatternRewriter &rewriter) const override { + GatherDimensionNumbersAttr dimNumbers = op.getDimensionNumbers(); + ArrayRef operandBatchingDims = dimNumbers.getOperandBatchingDims(); ++ ArrayRef startIndicesBatchingDims = ++ dimNumbers.getStartIndicesBatchingDims(); + if (operandBatchingDims.empty()) { + return rewriter.notifyMatchFailure(op, [](Diagnostic &diag) { + diag << "gather op has no batching dims"; ++ }); ++ } ++ ++ if (!op.getStartIndices().getType().hasStaticShape()) { ++ return rewriter.notifyMatchFailure(op, [](Diagnostic &diag) { ++ diag << "gather op has start indices with dynamic shape, can't expand"; + }); + } + +@@ -136,16 +193,18 @@ + SmallVector newStartIndexMap = + llvm::to_vector(llvm::concat( + operandBatchingDims, dimNumbers.getStartIndexMap())); +- Value newIndices = createConcatIndices( +- op.getStartIndices(), dimNumbers.getIndexVectorDim(), +- dimNumbers.getStartIndicesBatchingDims(), rewriter); ++ Value newIndices = createConcatIndices(op.getStartIndices(), ++ dimNumbers.getIndexVectorDim(), ++ startIndicesBatchingDims, rewriter); + rewriter.replaceOpWithNewOp( + op, op.getOperand(), newIndices, + GatherDimensionNumbersAttr::get( + op.getContext(), dimNumbers.getOffsetDims(), newCollapsedSliceDims, + /*operandBatchingDims=*/{}, /*startIndicesBatchingDims=*/{}, + newStartIndexMap, dimNumbers.getIndexVectorDim()), +- op.getSliceSizes(), /*indicesAreSorted=*/false); ++ op.getSliceSizes(), ++ getUpdatedIndicesAreSorted(op.getIndicesAreSorted(), ++ startIndicesBatchingDims, newStartIndexMap)); + + return success(); + } +@@ -161,9 +220,17 @@ + PatternRewriter &rewriter) const override { + ScatterDimensionNumbersAttr dimNumbers = op.getScatterDimensionNumbers(); + ArrayRef inputBatchingDims = dimNumbers.getInputBatchingDims(); ++ ArrayRef scatterIndicesBatchingDims = ++ dimNumbers.getScatterIndicesBatchingDims(); + if (inputBatchingDims.empty()) { + return rewriter.notifyMatchFailure(op, [](Diagnostic &diag) { + diag << "scatter op has no batching dims"; ++ }); ++ } ++ ++ if (!op.getScatterIndices().getType().hasStaticShape()) { ++ return rewriter.notifyMatchFailure(op, [](Diagnostic &diag) { ++ diag << "gather op has start indices with dynamic shape, can't expand"; + }); + } + +@@ -174,7 +241,7 @@ + inputBatchingDims, dimNumbers.getScatterDimsToOperandDims())); + Value newIndices = createConcatIndices( + op.getScatterIndices(), dimNumbers.getIndexVectorDim(), +- dimNumbers.getScatterIndicesBatchingDims(), rewriter); ++ scatterIndicesBatchingDims, rewriter); + auto newScatterOp = rewriter.create( + op.getLoc(), op->getResultTypes(), op.getInputs(), newIndices, + op.getUpdates(), +@@ -183,7 +250,10 @@ + newInsertedWindowDims, + /*inputBatchingDims=*/{}, /*scatterIndicesBatchingDims=*/{}, + newScatterDimsToOperandDims, dimNumbers.getIndexVectorDim()), +- /*indicesAreSorted=*/false, op.getUniqueIndices()); ++ getUpdatedIndicesAreSorted(op.getIndicesAreSorted(), ++ scatterIndicesBatchingDims, ++ newScatterDimsToOperandDims), ++ op.getUniqueIndices()); + + newScatterOp.getUpdateComputation().takeBody(op.getUpdateComputation()); + rewriter.replaceOp(op, newScatterOp.getResults()); From 99d795aba355f4303cb9a4c9ba0c81becfe32f45 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 30 Sep 2024 16:09:59 -0700 Subject: [PATCH 445/483] Reverts 1b240fae66605d39e14d4459450664c0ac20e97f PiperOrigin-RevId: 680754448 --- tensorflow/compiler/mlir/lite/BUILD | 1 - tensorflow/compiler/mlir/lite/stablehlo/BUILD | 93 +- .../mlir/lite/stablehlo/odml_to_stablehlo.cc | 1 - .../lite/stablehlo/tests/legalize-tf.mlir | 2532 ------ .../lite/stablehlo/transforms/legalize_tf.cc | 6911 ----------------- .../stablehlo/transforms/legalize_tf_passes.h | 51 - .../transforms/legalize_tf_patterns.td | 802 -- .../stablehlo/transforms/tf_stablehlo_pass.cc | 4 +- .../mlir/lite/stablehlo/transforms/utils.cc | 55 - .../mlir/lite/stablehlo/transforms/utils.h | 61 - .../lite/stablehlo/transforms/utils_test.cc | 83 - .../mlir/lite/transforms/prepare_tf.cc | 1 - .../mlir/tf2xla/tests/legalize-tf.mlir | 3810 +++++++++ 13 files changed, 3814 insertions(+), 10591 deletions(-) delete mode 100644 tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tf.mlir delete mode 100644 tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf.cc delete mode 100644 tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf_passes.h delete mode 100644 tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf_patterns.td delete mode 100644 tensorflow/compiler/mlir/lite/stablehlo/transforms/utils.cc delete mode 100644 tensorflow/compiler/mlir/lite/stablehlo/transforms/utils.h delete mode 100644 tensorflow/compiler/mlir/lite/stablehlo/transforms/utils_test.cc diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD index 3684b75dd13bbe..7bb70a19f4f116 100644 --- a/tensorflow/compiler/mlir/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/BUILD @@ -906,7 +906,6 @@ cc_library( "//tensorflow/compiler/mlir:op_or_arg_name_mapper", "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps", "//tensorflow/compiler/mlir/lite/schema:schema_fbs", - "//tensorflow/compiler/mlir/lite/stablehlo:legalize_tf", "//tensorflow/compiler/mlir/lite/stablehlo:optimize_layout", "//tensorflow/compiler/mlir/lite/stablehlo:prepare_hlo", "//tensorflow/compiler/mlir/lite/stablehlo:tf_legalize_hlo", diff --git a/tensorflow/compiler/mlir/lite/stablehlo/BUILD b/tensorflow/compiler/mlir/lite/stablehlo/BUILD index 53d08196547f42..a0c3febeead92f 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/BUILD +++ b/tensorflow/compiler/mlir/lite/stablehlo/BUILD @@ -1,6 +1,5 @@ load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") -load("@local_tsl//tsl/platform:build_config_root.bzl", "if_static") -load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test") +load("//tensorflow:tensorflow.bzl", "tf_cc_binary") load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") @@ -120,92 +119,6 @@ cc_library( alwayslink = 1, ) -cc_library( - name = "legalize_utils", - srcs = ["transforms/utils.cc"], - hdrs = ["transforms/utils.h"], - deps = [ - "@llvm-project//llvm:Support", - "@llvm-project//mlir:IR", - "@local_xla//xla/mlir_hlo", - ], -) - -tf_cc_test( - name = "legalize_utils_test", - srcs = ["transforms/utils_test.cc"], - deps = [ - ":legalize_utils", - "@com_google_googletest//:gtest_main", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Support", - "@local_xla//xla/mlir_hlo", - ], -) - -gentbl_cc_library( - name = "legalize_tf_patterns_inc_gen", - compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "transforms/generated_legalize_tf.inc", - ), - ], - tblgen = "@llvm-project//mlir:mlir-tblgen", - td_file = "transforms/legalize_tf_patterns.td", - deps = [ - "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_td_files", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:FuncTdFiles", - "@llvm-project//mlir:TensorOpsTdFiles", - "@local_xla//xla/mlir_hlo:hlo_ops_td_files", - ], -) - -cc_library( - name = "legalize_tf", - srcs = [ - "transforms/generated_legalize_tf.inc", - "transforms/legalize_tf.cc", - ], - hdrs = [ - "transforms/legalize_tf_passes.h", - ], - deps = [ - ":legalize_tf_patterns_inc_gen", - ":legalize_utils", - "//tensorflow/compiler/mlir/tensorflow", - "//tensorflow/compiler/mlir/tensorflow:dynamic_shape_utils", - "//tensorflow/compiler/mlir/tensorflow:xla_sharding_util", - "//tensorflow/core:framework", - "//tensorflow/core/kernels:conv_grad_shape_utils", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:ArithDialect", - "@llvm-project//mlir:Dialect", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:MemRefDialect", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:ShapeDialect", - "@llvm-project//mlir:Support", - "@llvm-project//mlir:TensorDialect", - "@llvm-project//mlir:TransformUtils", - "@local_tsl//tsl/platform:bfloat16", - "@local_tsl//tsl/platform:status", - "@local_tsl//tsl/platform:tensor_float_32_hdr_lib", - "@local_xla//xla:xla_data_proto_cc", - "@local_xla//xla/client:padding", - "@local_xla//xla/client:sharding_builder", - "@local_xla//xla/client/lib:conv_grad_size_util", - "@local_xla//xla/hlo/translate/hlo_to_mhlo:attribute_importer", - "@local_xla//xla/mlir_hlo", - "@local_xla//xla/mlir_hlo:convert_op_folder", - "@local_xla//xla/translate/hlo_to_mhlo:attribute_importer", - "@stablehlo//:chlo_ops", - ] + if_static(["@local_tsl//tsl/platform:tensor_float_32_utils"]), -) - cc_library( name = "tf_stablehlo", srcs = [ @@ -218,7 +131,6 @@ cc_library( "-Ithird_party", ], deps = [ - ":legalize_tf", ":stablehlo_util", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow/transforms:lower_tf_lib", @@ -1041,7 +953,6 @@ tf_cc_binary( " [tf.lite.OpsSet.EXPERIMENTAL_STABLEHLO_OPS]", deps = [ ":check_accepted_ops_pass", - ":legalize_tf", ":op_stat_pass", ":stablehlo_util", ":transforms", @@ -1058,7 +969,7 @@ tf_cc_binary( "//tensorflow/compiler/mlir/tensorflow/transforms:tensorflow_passes", "//tensorflow/compiler/mlir/tensorflow/transforms:tf_graph_optimization_pass", "//tensorflow/compiler/mlir/tf2xla:compile_mlir_util", - "//tensorflow/compiler/mlir/tf2xla/transforms:xla_legalize_tf", + "//tensorflow/compiler/mlir/tf2xla/transforms:legalize_tf", "//tensorflow/core:core_cpu_base", "//tensorflow/core:lib", "//tensorflow/core/ir/types:Dialect", diff --git a/tensorflow/compiler/mlir/lite/stablehlo/odml_to_stablehlo.cc b/tensorflow/compiler/mlir/lite/stablehlo/odml_to_stablehlo.cc index bd18a351bd86e8..c2579fb3619911 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/odml_to_stablehlo.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/odml_to_stablehlo.cc @@ -56,7 +56,6 @@ limitations under the License. #include "tensorflow/cc/saved_model/loader.h" #include "tensorflow/compiler/mlir/init_mlir.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/check_accepted_ops_pass.h" -#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf_passes.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/op_stat_pass.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_util.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/transforms.h" diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tf.mlir deleted file mode 100644 index bc2ce85d20f9f2..00000000000000 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tf.mlir +++ /dev/null @@ -1,2532 +0,0 @@ -// RUN: odml-to-stablehlo-opt --tf-stablehlo \ -// RUN: %s | FILECHECK_OPTS="" FileCheck %s - -//===----------------------------------------------------------------------===// -// BatchNorm op legalizations. -//===----------------------------------------------------------------------===// - -// ----- - -// fusedBatchNormV2 is almost identical to fusedBatchNormV3 (and uses the same -// code), so only do a couple of basic checks. - -// CHECK-LABEL: fusedBatchNormV2_noTraining -func.func @fusedBatchNormV2_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) { - // CHECK: "stablehlo.batch_norm_inference"({{.*}}, %arg1, %arg2, %arg3, %arg4) <{epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}> : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32> - %0:5 = "tf.FusedBatchNormV2"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) - func.return %0#0 : tensor<8x8x8x8xf32> -} - -// ----- - -// CHECK-LABEL: fusedBatchNormV2_training -func.func @fusedBatchNormV2_training(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) { - // CHECK: %[[OUT:.*]], %[[MEAN:.*]], %[[VAR:.*]] = "stablehlo.batch_norm_training"({{.*}}, %arg1, %arg2) <{epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}> : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) - %0:5 = "tf.FusedBatchNormV2"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, exponential_avg_factor = 1.0 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) - func.return %0#0 : tensor<8x8x8x8xf32> -} - -// ----- - -// CHECK-LABEL: fusedBatchNormV3_noTraining -func.func @fusedBatchNormV3_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) { - // CHECK: "stablehlo.batch_norm_inference"({{.*}}, %arg1, %arg2, %arg3, %arg4) <{epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}> : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32> - %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) - func.return %0#0 : tensor<8x8x8x8xf32> -} - -// ----- - -// CHECK-LABEL: fusedBatchNormV3_noTraining_mixedPrecision -// CHECK-SAME: ([[X:%.*]]: tensor<8x8x8x8xbf16>, [[SCALE:%.*]]: tensor<8xf32>, [[OFFSET:%.*]]: tensor<8xf32>, [[MEAN:%.*]]: tensor<8xf32>, [[VARIANCE:%.*]]: tensor<8xf32>) -func.func @fusedBatchNormV3_noTraining_mixedPrecision(%arg0: tensor<8x8x8x8xbf16>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<*xf32>) { - // CHECK: [[DUMMY:%.*]] = stablehlo.constant dense<0.000000e+00> : tensor<0xf32> - // CHECK: [[CONVERT_X:%.*]] = stablehlo.convert [[X]] : (tensor<8x8x8x8xbf16>) -> tensor<8x8x8x8xf32> - // CHECK: [[Y:%.*]] = "stablehlo.batch_norm_inference"([[CONVERT_X]], [[SCALE]], [[OFFSET]], [[MEAN]], [[VARIANCE]]) <{epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}> - %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<*xf32>) - // CHECK: [[Y_CONVERT:%.*]] = stablehlo.convert [[Y]] : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xbf16> - // CHECK: [[DUMMY_CAST:%.*]] = tensor.cast [[DUMMY]] : tensor<0xf32> to tensor<*xf32> - // CHECK: return [[Y_CONVERT]], [[MEAN]], [[VARIANCE]], [[MEAN]], [[VARIANCE]], [[DUMMY_CAST]] - func.return %0#0, %0#1, %0#2, %0#3, %0#4, %0#5 : tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<*xf32> -} - -// ----- - -// CHECK-LABEL: fusedBatchNormV3_training -func.func @fusedBatchNormV3_training(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) { - // CHECK: %[[OUT:.*]], %[[MEAN:.*]], %[[VAR:.*]] = "stablehlo.batch_norm_training"({{.*}}, %arg1, %arg2) <{epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}> : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) - %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, exponential_avg_factor = 1.0 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) - func.return %0#0 : tensor<8x8x8x8xf32> -} - -// ----- - -// CHECK-LABEL: func @fusedBatchNormV3_training_batchVariance -func.func @fusedBatchNormV3_training_batchVariance(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> tensor<8xf32> { - // CHECK: %[[OUT:.*]], %[[MEAN:.*]], %[[VAR:.*]] = "stablehlo.batch_norm_training"({{.*}}, %arg1, %arg2) <{epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}> : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) - %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, exponential_avg_factor = 1.0 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) - // CHECK: return %[[VAR]] - func.return %0#4 : tensor<8xf32> -} - -// ----- - -// CHECK-LABEL: fusedBatchNormV3_training_exponentialAvgFactor -func.func @fusedBatchNormV3_training_exponentialAvgFactor(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) { - // CHECK-DAG: %[[ALPHA:.*]] = stablehlo.constant dense<0.199999988> - // CHECK-DAG: %[[BETA:.*]] = stablehlo.constant dense<8.000000e-01> - // CHECK-DAG: %[[FACTOR:.*]] = stablehlo.constant dense<1.00195694> - // CHECK: %[[OUT:.*]], %[[MEAN:.*]], %[[VAR:.*]] = "stablehlo.batch_norm_training"({{.*}}, %arg1, %arg2) <{epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}> : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) - %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, exponential_avg_factor = 0.8 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) - // CHECK: %[[CORRECTED_VAR:.*]] = stablehlo.multiply %[[VAR]], %[[FACTOR]] - - // CHECK: %[[ALPHA_MUL_OLD_MEAN:.*]] = stablehlo.multiply %arg3, %[[ALPHA]] - // CHECK: %[[BETA_MUL_BATCH_MEAN:.*]] = stablehlo.multiply %[[MEAN]], %[[BETA]] - // CHECK: %[[NEW_BATCH_MEAN:.*]] = stablehlo.add %[[ALPHA_MUL_OLD_MEAN]], %[[BETA_MUL_BATCH_MEAN]] - - // CHECK: %[[ALPHA_MUL_OLD_VAR:.*]] = stablehlo.multiply %arg4, %[[ALPHA]] - // CHECK: %[[BETA_MUL_CORRECTED_VAR:.*]] = stablehlo.multiply %[[CORRECTED_VAR]], %[[BETA]] - // CHECK: %[[NEW_BATCH_VAR:.*]] = stablehlo.add %[[ALPHA_MUL_OLD_VAR]], %[[BETA_MUL_CORRECTED_VAR]] - - // CHECK: return %[[NEW_BATCH_MEAN]], %[[NEW_BATCH_VAR]], %[[MEAN]], %[[VAR]] - func.return %0#1, %0#2, %0#3, %0#4 : tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32> -} - -// ----- - -// CHECK-LABEL: fusedBatchNormV3_training_mixedPrecision -func.func @fusedBatchNormV3_training_mixedPrecision(%arg0: tensor<8x8x8x8xbf16>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xbf16>) { - // CHECK: stablehlo.convert %arg0 : (tensor<8x8x8x8xbf16>) -> tensor<8x8x8x8xf32> - %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, exponential_avg_factor = 1.0 : f32, is_training = true} : (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) - // CHECK: stablehlo.convert {{.*}} : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xbf16> - func.return %0#0 : tensor<8x8x8x8xbf16> -} - -// ----- - -// CHECK-LABEL: fusedBatchNormV3_NCHW -func.func @fusedBatchNormV3_NCHW(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) { - // CHECK: "stablehlo.batch_norm_training"({{.*}}, %arg1, %arg2) <{epsilon = 1.000000e-03 : f32, feature_index = 1 : i64}> : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) - %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NCHW", epsilon = 0.001 : f32, exponential_avg_factor = 1.0 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) - func.return %0#0 : tensor<8x8x8x8xf32> -} - -// ----- - -// CHECK-LABEL: fusedBatchNormV3_NDHWC -func.func @fusedBatchNormV3_NDHWC(%arg0: tensor<8x8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8x8xf32>) { - // CHECK: feature_index = 4 : i64 - %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NDHWC", epsilon = 0.001 : f32, exponential_avg_factor = 1.0 : f32, is_training = true} : (tensor<8x8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) - func.return %0#0 : tensor<8x8x8x8x8xf32> -} - -// ----- - -// CHECK-LABEL: fusedBatchNormV3_noTraining_dynamic_supported -func.func @fusedBatchNormV3_noTraining_dynamic_supported(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor) -> (tensor) { - // CHECK: "stablehlo.batch_norm_inference"({{.*}}, %arg1, %arg2, %arg3, %arg4) <{epsilon = 1.000000e-03 : f32, feature_index = 1 : i64}> : (tensor, tensor, tensor, tensor, tensor) -> tensor - %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NCHW", epsilon = 0.001 : f32, exponential_avg_factor = 1.0 : f32, is_training = false} : (tensor, tensor, tensor, tensor, tensor) -> (tensor, tensor, tensor, tensor, tensor, tensor) - func.return %0#0 : tensor -} - -// ----- - -// CHECK-LABEL: fusedBatchNormV3_training_dynamic_unsupported1 -func.func @fusedBatchNormV3_training_dynamic_unsupported1(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor) -> (tensor) { - // CHECK: tf.FusedBatchNormV3 - %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NCHW", epsilon = 0.001 : f32, exponential_avg_factor = 1.0 : f32, is_training = true} : (tensor, tensor, tensor, tensor, tensor) -> (tensor, tensor, tensor, tensor, tensor, tensor) - func.return %0#0 : tensor -} - -// ----- - -// CHECK-LABEL: fusedBatchNormV3_training_dynamic_unsupported2 -func.func @fusedBatchNormV3_training_dynamic_unsupported2(%arg0: tensor, %arg1: tensor<6xf32>, %arg2: tensor<6xf32>, %arg3: tensor<6xf32>, %arg4: tensor<6xf32>) -> (tensor) { - // CHECK: tf.FusedBatchNormV3 - %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NCHW", epsilon = 0.001 : f32, exponential_avg_factor = 1.0 : f32, is_training = true} : (tensor, tensor<6xf32>, tensor<6xf32>, tensor<6xf32>, tensor<6xf32>) -> (tensor, tensor<6xf32>, tensor<6xf32>, tensor<6xf32>, tensor<6xf32>, tensor<6xf32>) - func.return %0#0 : tensor -} - -// ----- - -// CHECK-LABEL: fusedBatchNormGrad_noTraining -func.func @fusedBatchNormGrad_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) { - // CHECK-NEXT: %[[EPS:.*]] = stablehlo.constant dense<1.000000e-03> : tensor<8xf32> - - // CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %arg4, %[[EPS]] : tensor<8xf32> - // CHECK-NEXT: %[[RSQRT:.*]] = stablehlo.rsqrt %[[ADD]] : tensor<8xf32> - - // CHECK-NEXT: %[[MUL2:.*]] = stablehlo.multiply %arg2, %[[RSQRT]] : tensor<8xf32> - // CHECK-NEXT: %[[BCAST_MUL2:.+]] = stablehlo.broadcast_in_dim %[[MUL2]], {{.*}} : (tensor<8xf32>) -> tensor<8x8x8x8xf32> - // CHECK-NEXT: %[[MUL3:.*]] = stablehlo.multiply %arg0, %[[BCAST_MUL2]] : tensor<8x8x8x8xf32> - // CHECK-NEXT: return %[[MUL3]] : tensor<8x8x8x8xf32> - - %0:5 = "tf.FusedBatchNormGrad"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) - func.return %0#0 : tensor<8x8x8x8xf32> -} - -// ----- - -// CHECK-LABEL: fusedBatchNormGrad_Training -func.func @fusedBatchNormGrad_Training(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) { - // CHECK-NEXT: %[[GRAD_OPERAND:.*]], %[[GRAD_SCALE:.*]], %[[GRAD_OFFSET:.*]] = "stablehlo.batch_norm_grad"(%arg1, %arg2, %arg3, %arg4, %arg0) {{.*}} - // CHECK-NEXT: return %[[GRAD_OPERAND]] : tensor<8x8x8x8xf32> - - %0:5 = "tf.FusedBatchNormGrad"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) - func.return %0#0 : tensor<8x8x8x8xf32> -} - -// ----- - -// CHECK-LABEL: fusedBatchNormGradV2_noTraining -func.func @fusedBatchNormGradV2_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) { - // CHECK-NEXT: %[[EPS:.*]] = stablehlo.constant dense<1.000000e-03> : tensor<8xf32> - // CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %arg4, %[[EPS]] : tensor<8xf32> - // CHECK-NEXT: %[[RSQRT:.*]] = stablehlo.rsqrt %[[ADD]] : tensor<8xf32> - // CHECK-NEXT: %[[MUL:.*]] = stablehlo.multiply %arg2, %[[RSQRT]] : tensor<8xf32> - // CHECK-NEXT: %[[BCAST:.*]] = stablehlo.broadcast_in_dim %[[MUL]], dims = [3] : (tensor<8xf32>) -> tensor<8x8x8x8xf32> - // CHECK-NEXT: %[[MUL2:.*]] = stablehlo.multiply %arg0, %[[BCAST]] : tensor<8x8x8x8xf32> - // CHECK-NEXT: return %[[MUL2:.*]] : tensor<8x8x8x8xf32> - - %0:5 = "tf.FusedBatchNormGradV2"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) - func.return %0#0 : tensor<8x8x8x8xf32> -} - -// ----- - -// CHECK-LABEL: fusedBatchNormGradV2_Training -func.func @fusedBatchNormGradV2_Training(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) { - // CHECK-NEXT: %[[GRAD_OPERAND:.*]], %[[GRAD_SCALE:.*]], %[[GRAD_OFFSET:.*]] = "stablehlo.batch_norm_grad"(%arg1, %arg2, %arg3, %arg4, %arg0) {{.*}} - // CHECK-NEXT: return %[[GRAD_OPERAND]] : tensor<8x8x8x8xf32> - - %0:5 = "tf.FusedBatchNormGradV2"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) - func.return %0#0 : tensor<8x8x8x8xf32> -} - -// ----- - -// CHECK-LABEL: fusedBatchNormGradV2_noTraining_mixed_precision -func.func @fusedBatchNormGradV2_noTraining_mixed_precision(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xbf16>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xbf16>) { - // CHECK-NEST: %[[CST:.*]] = stablehlo.constant dense<1.000000e-03> : tensor<8xf32> - // CHECK-NEST: %[[ADD:.*]] = stablehlo.add %arg4, %[[CST]] : tensor<8xf32> - // CHECK-NEST: %[[RSQRT:.*]] = stablehlo.rsqrt %[[ADD]] : tensor<8xf32> - // CHECK-NEST: %[[MUL:.*]] = stablehlo.multiply %arg2, %[[RSQRT]] : tensor<8xf32> - // CHECK-NEST: %[[BCAST:.*]] = stablehlo.broadcast_in_dim %[[MUL]], dims = [3] : (tensor<8xf32>) -> tensor<8x8x8x8xf32> - // CHECK-NEST: %[[MUL2:.*]] = stablehlo.multiply %arg0, %[[BCAST]] : tensor<8x8x8x8xf32> - // CHECK-NEST: %[[CONVERT:.*]] = stablehlo.convert %[[MUL2]] : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xbf16> - // CHECK-NEST: return %[[CONVERT]] : tensor<8x8x8x8xbf16> - - %0:5 = "tf.FusedBatchNormGradV2"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) - func.return %0#0 : tensor<8x8x8x8xbf16> -} - -// ----- - -// CHECK-LABEL: fusedBatchNormGradV2_Training_mixed_precision -func.func @fusedBatchNormGradV2_Training_mixed_precision(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xbf16>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xbf16>) { - // CHECK-NEXT: %[[CONVERT:.*]] = stablehlo.convert %arg1 : (tensor<8x8x8x8xbf16>) -> tensor<8x8x8x8xf32> - // CHECK-NEXT: %[[GRAD_OPERAND:.*]], %[[GRAD_SCALE:.*]], %[[GRAD_OFFSET:.*]] = "stablehlo.batch_norm_grad"(%[[CONVERT]], %arg2, %arg3, %arg4, %arg0) <{epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}> : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8x8x8x8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) - // CHECK-NEXT: %[[CONVERT:.*]] = stablehlo.convert %[[GRAD_OPERAND]] : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xbf16> - // CHECK-NEXT: return %[[CONVERT]] : tensor<8x8x8x8xbf16> - - %0:5 = "tf.FusedBatchNormGradV2"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) - func.return %0#0 : tensor<8x8x8x8xbf16> -} - -// ----- - -// CHECK-LABEL: fusedBatchNormGradV3_noTraining -func.func @fusedBatchNormGradV3_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>, %arg5: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) { -// CHECK-NEXT: %[[EPS:.*]] = stablehlo.constant dense<1.000000e-03> : tensor<8xf32> -// CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %arg4, %[[EPS]] : tensor<8xf32> -// CHECK-NEXT: %[[RSQRT:.*]] = stablehlo.rsqrt %[[ADD]] : tensor<8xf32> -// CHECK-NEXT: %[[MUL:.*]] = stablehlo.multiply %arg2, %[[RSQRT]] : tensor<8xf32> -// CHECK-NEXT: %[[BCAST:.*]] = stablehlo.broadcast_in_dim %[[MUL]], dims = [3] : (tensor<8xf32>) -> tensor<8x8x8x8xf32> -// CHECK-NEXT: %[[MUL2:.*]] = stablehlo.multiply %arg0, %[[BCAST]] : tensor<8x8x8x8xf32> -// CHECK-NEXT: return %[[MUL2]] : tensor<8x8x8x8xf32> - - %0:5 = "tf.FusedBatchNormGradV3"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) - func.return %0#0 : tensor<8x8x8x8xf32> -} - -// ----- - -// CHECK-LABEL: fusedBatchNormGradV3_Training -func.func @fusedBatchNormGradV3_Training(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>, %arg5: tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<0xf32>, tensor<*xf32>) { - // CHECK-NEXT: %[[EPS:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<0xf32> - // CHECK-NEXT: %[[GRAD_OPERAND:.*]], %[[GRAD_SCALE:.*]], %[[GRAD_OFFSET:.*]] = "stablehlo.batch_norm_grad"(%arg1, %arg2, %arg3, %arg4, %arg0) <{epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}> : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8x8x8x8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) - // CHECK-NEXT: %[[CAST:.*]] = tensor.cast %[[EPS]] : tensor<0xf32> to tensor<*xf32> - // CHECK-NEXT: return %[[GRAD_OPERAND]], %[[EPS]], %[[CAST]] : tensor<8x8x8x8xf32>, tensor<0xf32>, tensor<*xf32> - - %0:5 = "tf.FusedBatchNormGradV3"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<0xf32>, tensor<*xf32>) - func.return %0#0, %0#3, %0#4 : tensor<8x8x8x8xf32>, tensor<0xf32>, tensor<*xf32> -} - -// ----- - -// CHECK-LABEL: fusedBatchNormGradV3_noTraining_mixed_precision -func.func @fusedBatchNormGradV3_noTraining_mixed_precision(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xbf16>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>, %arg5: tensor<8xf32>) -> (tensor<8x8x8x8xbf16>) { - // CHECK-NEXT: %[[EPS:.*]] = stablehlo.constant dense<1.000000e-03> : tensor<8xf32> - // CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %arg4, %[[EPS]] : tensor<8xf32> - // CHECK-NEXT: %[[RSQRT:.*]] = stablehlo.rsqrt %[[ADD]] : tensor<8xf32> - // CHECK-NEXT: %[[MUL:.*]] = stablehlo.multiply %arg2, %[[RSQRT]] : tensor<8xf32> - // CHECK-NEXT: %[[BCAST:.*]] = stablehlo.broadcast_in_dim %[[MUL]], dims = [3] : (tensor<8xf32>) -> tensor<8x8x8x8xf32> - // CHECK-NEXT: %[[MUL2:.*]] = stablehlo.multiply %arg0, %[[BCAST]] : tensor<8x8x8x8xf32> - // CHECK-NEXT: %[[CONVERT:.*]] = stablehlo.convert %[[MUL2]] : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xbf16> - // CHECK-NEXT: return %[[CONVERT]] : tensor<8x8x8x8xbf16> - - %0:5 = "tf.FusedBatchNormGradV3"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) - func.return %0#0 : tensor<8x8x8x8xbf16> -} - -// ----- - -// CHECK-LABEL: fusedBatchNormGradV3_Training_mixed_precision -func.func @fusedBatchNormGradV3_Training_mixed_precision(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xbf16>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>, %arg5: tensor<8xf32>) -> (tensor<8x8x8x8xbf16>) { - // CHECK-NEXT: %[[CONVERT:.*]] = stablehlo.convert %arg1 : (tensor<8x8x8x8xbf16>) -> tensor<8x8x8x8xf32> - // CHECK-NEXT: %[[GRAD_OPERAND:.*]], %[[GRAD_SCALE:.*]], %[[GRAD_OFFSET:.*]] = "stablehlo.batch_norm_grad"(%[[CONVERT]], %arg2, %arg3, %arg4, %arg0) <{epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}> : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8x8x8x8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) - // CHECK-NEXT: %[[CONVERT2:.*]] = stablehlo.convert %[[GRAD_OPERAND]] : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xbf16> - // CHECK-NEXT: return %[[CONVERT2]] : tensor<8x8x8x8xbf16> - - %0:5 = "tf.FusedBatchNormGradV3"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) - func.return %0#0 : tensor<8x8x8x8xbf16> -} - -// ----- - -// CHECK-LABEL: fusedBatchNormGradV3_noTraining_NCHW -func.func @fusedBatchNormGradV3_noTraining_NCHW(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>, %arg5: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) { - // CHECK-NEXT: %[[EPS:.*]] = stablehlo.constant dense<1.000000e-03> : tensor<8xf32> - // CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %arg4, %[[EPS]] : tensor<8xf32> - // CHECK-NEXT: %[[RSQRT:.*]] = stablehlo.rsqrt %[[ADD]] : tensor<8xf32> - // CHECK-NEXT: %[[MUL:.*]] = stablehlo.multiply %arg2, %[[RSQRT]] : tensor<8xf32> - // CHECK-NEXT: %[[BCAST:.*]] = stablehlo.broadcast_in_dim %[[MUL]], dims = [1] : (tensor<8xf32>) -> tensor<8x8x8x8xf32> - // CHECK-NEXT: %[[MUL2:.*]] = stablehlo.multiply %arg0, %[[BCAST]] : tensor<8x8x8x8xf32> - // CHECK-NEXT: return %[[MUL2]] : tensor<8x8x8x8xf32> - - %0:5 = "tf.FusedBatchNormGradV3"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {T = "tfdtype$DT_FLOAT", data_format = "NCHW", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) - func.return %0#0 : tensor<8x8x8x8xf32> -} - -// ----- - -// CHECK-LABEL: fusedBatchNormGradV3_Training_NCHW -func.func @fusedBatchNormGradV3_Training_NCHW(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>, %arg5: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) { - // CHECK-NEXT: %[[GRAD_OPERAND:.*]], %[[GRAD_SCALE:.*]], %[[GRAD_OFFSET:.*]] = "stablehlo.batch_norm_grad"(%arg1, %arg2, %arg3, %arg4, %arg0) <{epsilon = 1.000000e-03 : f32, feature_index = 1 : i64}> : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8x8x8x8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) - // CHECK-NEXT: return %[[GRAD_OPERAND]] : tensor<8x8x8x8xf32> - %0:5 = "tf.FusedBatchNormGradV3"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {T = "tfdtype$DT_FLOAT", data_format = "NCHW", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) - func.return %0#0 : tensor<8x8x8x8xf32> -} - -//===----------------------------------------------------------------------===// -// Bias op legalizations. -//===----------------------------------------------------------------------===// - -// ----- - -// CHECK-LABEL: func @biasAdd_default -func.func @biasAdd_default(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> { - // CHECK-NEXT: %[[BCAST:.*]] = stablehlo.broadcast_in_dim %arg1, dims = [3] : (tensor<32xi32>) -> tensor<1x32x10x32xi32> - // CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %arg0, %[[BCAST]] : tensor<1x32x10x32xi32> - // CHECK-NEXT: return %[[ADD]] : tensor<1x32x10x32xi32> - %0 = "tf.BiasAdd"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT"} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32> - func.return %0 : tensor<1x32x10x32xi32> -} - -// ----- - -// CHECK-LABEL: func @biasAdd_NHWC -func.func @biasAdd_NHWC(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> { - // CHECK-NEXT: %[[BCAST:.*]] = stablehlo.broadcast_in_dim %arg1, dims = [3] : (tensor<32xi32>) -> tensor<1x32x10x32xi32> - // CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %arg0, %[[BCAST]] : tensor<1x32x10x32xi32> - // CHECK-NEXT: return %[[ADD]] : tensor<1x32x10x32xi32> - %0 = "tf.BiasAdd"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", data_format = "NHWC"} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32> - func.return %0 : tensor<1x32x10x32xi32> -} - -// ----- - -// CHECK-LABEL: func @biasAdd_NCHW -func.func @biasAdd_NCHW(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> { - // CHECK-NEXT: %[[BCAST:.*]] = stablehlo.broadcast_in_dim %arg1, dims = [1] : (tensor<32xi32>) -> tensor<1x32x10x32xi32> - // CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %arg0, %[[BCAST]] : tensor<1x32x10x32xi32> - // CHECK-NEXT: return %[[ADD]] : tensor<1x32x10x32xi32> - %0 = "tf.BiasAdd"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", data_format = "NCHW"} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32> - func.return %0 : tensor<1x32x10x32xi32> -} - -// ----- - -// CHECK-LABEL: func @biasAdd_dynamic -func.func @biasAdd_dynamic(%arg0: tensor, %arg1: tensor) -> tensor { - // CHECK-NEXT: %[[SHAPE:.*]] = shape.shape_of %arg0 : tensor -> tensor<4xindex> - // CHECK-NEXT: %[[BCAST:.*]] = stablehlo.dynamic_broadcast_in_dim %arg1, %[[SHAPE]], dims = [1] : (tensor, tensor<4xindex>) -> tensor - // CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %arg0, %[[BCAST]] : tensor - // CHECK-NEXT: return %[[ADD]] : tensor - %0 = "tf.BiasAdd"(%arg0, %arg1) {data_format = "NCHW"} : (tensor, tensor) -> tensor - func.return %0 : tensor -} - -// ----- - -// CHECK-LABEL: func @biasAdd_partial_dynamic -func.func @biasAdd_partial_dynamic(%arg0: tensor, %arg1: tensor<512xi32>) -> tensor { - // CHECK-NEXT: %[[SHAPE:.*]] = shape.shape_of %arg0 : tensor -> tensor<4xindex> - // CHECK-NEXT: %[[BCAST:.*]] = stablehlo.dynamic_broadcast_in_dim %arg1, %[[SHAPE]], dims = [3] : (tensor<512xi32>, tensor<4xindex>) -> tensor - // CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %arg0, %[[BCAST]] : tensor - // CHECK-NEXT: %[[CAST:.*]] = tensor.cast %[[ADD]] : tensor to tensor - // CHECK-NEXT: return %[[CAST]] : tensor - %0 = "tf.BiasAdd"(%arg0, %arg1) {data_format = "NHWC"} : (tensor, tensor<512xi32>) -> tensor - func.return %0 : tensor -} - - -//===----------------------------------------------------------------------===// -// ClipByValue -//===----------------------------------------------------------------------===// - -// ----- - -// CHECK-LABEL: @clip -func.func @clip(%arg0 : tensor, %arg1 : tensor, %arg2 : tensor) -> tensor { - // CHECK: [[VAL:%.+]] = stablehlo.clamp %arg1, %arg0, %arg2 - - %0 = "tf.ClipByValue"(%arg0, %arg1, %arg2) : (tensor, tensor, tensor) -> tensor - // CHECK: return [[VAL]] - func.return %0 : tensor -} - -// ----- - -// CHECK-LABEL: @clip_dynamic -func.func @clip_dynamic(%arg0 : tensor, %arg1 : tensor, %arg2 : tensor) -> tensor { - // CHECK-DAG: [[CLAMP:%.+]] = stablehlo.clamp %arg1, %arg0, %arg2 - %0 = "tf.ClipByValue"(%arg0, %arg1, %arg2) : (tensor, tensor, tensor) -> tensor - - // CHECK: return [[CLAMP]] - func.return %0 : tensor -} - -// ----- - -// CHECK-LABEL: @clip_static_broadcast -func.func @clip_static_broadcast(%arg0 : tensor<5xf32>, %arg1 : tensor, %arg2 : tensor) -> tensor<5xf32> { - // CHECK-DAG: [[BROADCAST_MIN:%.+]] = stablehlo.broadcast_in_dim %arg1, dims = [] : (tensor) -> tensor<5xf32> - // CHECK-DAG: [[BROADCAST_MAX:%.+]] = stablehlo.broadcast_in_dim %arg2, dims = [] : (tensor) -> tensor<5xf32> - // CHECK-DAG: [[CLAMP:%.+]] = stablehlo.clamp [[BROADCAST_MIN]], %arg0, [[BROADCAST_MAX]] - %0 = "tf.ClipByValue"(%arg0, %arg1, %arg2) : (tensor<5xf32>, tensor, tensor) -> tensor<5xf32> - - // CHECK: return [[CLAMP]] - func.return %0 : tensor<5xf32> -} - - -// CHECK-LABEL: @clip_dynamic_broadcast -func.func @clip_dynamic_broadcast(%arg0 : tensor, %arg1 : tensor, %arg2 : tensor) -> tensor { - // CHECK: [[SHP:%.+]] = shape.shape_of %arg0 - // CHECK: [[SHPIDX:%.+]] = arith.index_cast [[SHP]] : tensor<1xindex> to tensor<1xi32> - // CHECK-DAG: [[BROADCAST_MIN:%.+]] = stablehlo.dynamic_broadcast_in_dim %arg1, [[SHPIDX]], dims = [] : (tensor, tensor<1xi32>) -> tensor - // CHECK-DAG: [[BROADCAST_MAX:%.+]] = stablehlo.dynamic_broadcast_in_dim %arg2, [[SHPIDX]], dims = [] : (tensor, tensor<1xi32>) -> tensor - // CHECK-DAG: [[CLAMP:%.+]] = stablehlo.clamp [[BROADCAST_MIN]], %arg0, [[BROADCAST_MAX]] - %0 = "tf.ClipByValue"(%arg0, %arg1, %arg2) : (tensor, tensor, tensor) -> tensor - - // CHECK: return [[CLAMP]] - func.return %0 : tensor -} - -//===----------------------------------------------------------------------===// -// DiagPart -//===----------------------------------------------------------------------===// - -// ----- - -// CHECK-LABEL: func @diag_part -// CHECK-SAME: %[[ARG:.*]]: tensor<4x3x4x3xf32> -func.func @diag_part(%arg0: tensor<4x3x4x3xf32>) -> tensor<4x3xf32> { - // CHECK-NEXT: %[[CST0:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<12x12xf32> - // CHECK-NEXT: %[[CST1:.*]] = stablehlo.constant dense<0.000000e+00> : tensor - // CHECK-NEXT: %[[RESHAPE:.*]] = stablehlo.reshape %arg0 : (tensor<4x3x4x3xf32>) -> tensor<12x12xf32> - // CHECK-NEXT: %[[IOTA:.*]] = stablehlo.iota dim = 0 : tensor<12xi32> - // CHECK-NEXT: %[[BCAST:.*]] = stablehlo.broadcast_in_dim %[[IOTA]], dims = [0] : (tensor<12xi32>) -> tensor<12x12xi32> - // CHECK-NEXT: %[[IOTA2:.*]] = stablehlo.iota dim = 0 : tensor<12xi32> - // CHECK-NEXT: %[[BCAST2:.*]] = stablehlo.broadcast_in_dim %[[IOTA2]], dims = [1] : (tensor<12xi32>) -> tensor<12x12xi32> - // CHECK-NEXT: %[[CMP:.*]] = stablehlo.compare EQ, %[[BCAST]], %[[BCAST2]], NOTYPE : (tensor<12x12xi32>, tensor<12x12xi32>) -> tensor<12x12xi1> - // CHECK-NEXT: %[[SEL:.*]] = stablehlo.select %[[CMP]], %[[RESHAPE]], %[[CST0]] : tensor<12x12xi1>, tensor<12x12xf32> - // CHECK-NEXT: %[[REDUCE:.*]] = stablehlo.reduce(%[[SEL]] init: %[[CST1]]) applies stablehlo.add across dimensions = [0] : (tensor<12x12xf32>, tensor) -> tensor<12xf32> - // CHECK-NEXT: %[[RESHAPE2:.*]] = stablehlo.reshape %[[REDUCE]] : (tensor<12xf32>) -> tensor<4x3xf32> - // CHECK-NEXT: return %[[RESHAPE2]] : tensor<4x3xf32> - - %0 = "tf.DiagPart"(%arg0) : (tensor<4x3x4x3xf32>) -> tensor<4x3xf32> - func.return %0: tensor<4x3xf32> -} - -//===----------------------------------------------------------------------===// -// MatrixDiagPart -//===----------------------------------------------------------------------===// - -// ----- - -// CHECK-LABEL: func @matrix_diag_part -// CHECK-SAME: %[[ARG:.*]]: tensor<7x140x128xi32> -func.func @matrix_diag_part(%arg0: tensor<7x140x128xi32>) -> tensor<7x22x128xi32> { - // CHECK-NEXT: %[[CST0:.*]] = stablehlo.constant dense<42> : tensor<7x22x128xi32> - // CHECK-NEXT: %[[CST1:.*]] = stablehlo.constant dense<128> : tensor<1x22x128xi32> - // CHECK-NEXT: %[[CST2:.*]] = stablehlo.constant dense<140> : tensor<1x22x128xi32> - // CHECK-NEXT: %[[CST3:.*]] = stablehlo.constant dense<11> : tensor<1x22x128xi32> - // CHECK-NEXT: %[[CST4:.*]] = stablehlo.constant dense<0> : tensor<1x22x128xi32> - // CHECK-NEXT: %[[IOTA0:.*]] = stablehlo.iota dim = 0 : tensor<22xi32> - // CHECK-NEXT: %[[BCAST0:.*]] = stablehlo.broadcast_in_dim %[[IOTA0]], dims = [1] : (tensor<22xi32>) -> tensor<1x22x128xi32> - // CHECK-NEXT: %[[IOTA1:.*]] = stablehlo.iota dim = 0 : tensor<128xi32> - // CHECK-NEXT: %[[BCAST1:.*]] = stablehlo.broadcast_in_dim %[[IOTA1]], dims = [2] : (tensor<128xi32>) -> tensor<1x22x128xi32> - // CHECK-NEXT: %[[SUB0:.*]] = stablehlo.subtract %[[CST3]], %[[BCAST0]] : tensor<1x22x128xi32> - // CHECK-NEXT: %[[NEG0:.*]] = stablehlo.negate %[[SUB0]] : tensor<1x22x128xi32> - // CHECK-NEXT: %[[MIN0:.*]] = stablehlo.minimum %[[SUB0]], %[[CST4]] : tensor<1x22x128xi32> - // CHECK-NEXT: %[[ADD0:.*]] = stablehlo.add %[[MIN0]], %[[CST2]] : tensor<1x22x128xi32> - // CHECK-NEXT: %[[MAX0:.*]] = stablehlo.maximum %[[SUB0]], %[[CST4]] : tensor<1x22x128xi32> - // CHECK-NEXT: %[[SUB1:.*]] = stablehlo.subtract %[[CST1]], %[[MAX0]] : tensor<1x22x128xi32> - // CHECK-NEXT: %[[MIN1:.*]] = stablehlo.minimum %[[ADD0]], %[[SUB1]] : tensor<1x22x128xi32> - // CHECK-NEXT: %[[CMP0:.*]] = stablehlo.compare GE, %[[SUB0]], %[[CST4]] : (tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi1> - // CHECK-NEXT: %[[SUB2:.*]] = stablehlo.subtract %[[CST1]], %[[MIN1]] : tensor<1x22x128xi32> - // CHECK-NEXT: %[[SELECT0:.*]] = stablehlo.select %[[CMP0]], %[[SUB2]], %[[CST4]] : tensor<1x22x128xi1>, tensor<1x22x128xi32> - // CHECK-NEXT: %[[MAX1:.*]] = stablehlo.maximum %[[SUB0]], %[[CST4]] : tensor<1x22x128xi32> - // CHECK-NEXT: %[[SUB2:.*]] = stablehlo.subtract %[[MAX1]], %[[SELECT0]] : tensor<1x22x128xi32> - // CHECK-NEXT: %[[MAX2:.*]] = stablehlo.maximum %[[NEG0]], %[[CST4]] : tensor<1x22x128xi32> - // CHECK-NEXT: %[[SUB3:.*]] = stablehlo.subtract %[[MAX2]], %[[SELECT0]] : tensor<1x22x128xi32> - // CHECK-NEXT: %[[ADD1:.*]] = stablehlo.add %[[BCAST1]], %[[SUB2]] : tensor<1x22x128xi32> - // CHECK-NEXT: %[[ADD2:.*]] = stablehlo.add %[[BCAST1]], %[[SUB3]] : tensor<1x22x128xi32> - // CHECK-NEXT: %[[CMP1:.*]] = stablehlo.compare GE, %[[ADD1]], %[[CST4]] : (tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi1> - // CHECK-NEXT: %[[CMP2:.*]] = stablehlo.compare LT, %[[ADD1]], %[[CST1]] : (tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi1> - // CHECK-NEXT: %[[AND0:.*]] = stablehlo.and %[[CMP1]], %[[CMP2]] : tensor<1x22x128xi1> - // CHECK-NEXT: %[[CMP3:.*]] = stablehlo.compare GE, %[[ADD2]], %[[CST4]] : (tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi1> - // CHECK-NEXT: %[[CMP4:.*]] = stablehlo.compare LT, %[[ADD2]], %[[CST2]] : (tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi1> - // CHECK-NEXT: %[[AND1:.*]] = stablehlo.and %[[CMP3]], %[[CMP4]] : tensor<1x22x128xi1> - // CHECK-NEXT: %[[AND2:.*]] = stablehlo.and %[[AND0]], %[[AND1]] : tensor<1x22x128xi1> - // CHECK-NEXT: %[[RESHAPE0:.*]] = stablehlo.reshape %[[AND2]] : (tensor<1x22x128xi1>) -> tensor<22x128xi1> - // CHECK-NEXT: %[[CONCAT0:.*]] = stablehlo.concatenate %[[ADD2]], %[[ADD1]], dim = 0 : (tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<2x22x128xi32> - // CHECK-NEXT: %[[GATHER0:.*]] = "stablehlo.gather"(%arg0, %[[CONCAT0]]) <{dimension_numbers = #{{.*}}, indices_are_sorted = false, slice_sizes = array}> : (tensor<7x140x128xi32>, tensor<2x22x128xi32>) -> tensor<7x22x128xi32> - // CHECK-NEXT: %[[BCAST1:.*]] = stablehlo.broadcast %[[RESHAPE0]], sizes = [7] : (tensor<22x128xi1>) -> tensor<7x22x128xi1> - // CHECK-NEXT: %[[SELECT1:.*]] = stablehlo.select %[[BCAST1]], %[[GATHER0]], %[[CST0]] : tensor<7x22x128xi1>, tensor<7x22x128xi32> - // CHECK-NEXT: return %[[SELECT1]] : tensor<7x22x128xi32> - - %0 = mhlo.constant dense<42> : tensor // padding value - %1 = mhlo.constant dense<[-10, 11]> : tensor<2xi32> // k - %2 = "tf.MatrixDiagPartV3"(%arg0, %1, %0) { - T = i32, align = "RIGHT_LEFT" - } : (tensor<7x140x128xi32>, tensor<2xi32>, tensor) -> tensor<7x22x128xi32> - func.return %2: tensor<7x22x128xi32> -} - -// ----- - -// CHECK-LABEL: func @matrix_diag_part_zero_dim_complex -func.func @matrix_diag_part_zero_dim_complex(%arg0: tensor<4x0xcomplex>) -> tensor<0xcomplex> { - %cst = "tf.Const"() {value = dense<-3> : tensor} : () -> tensor - %cst_0 = "tf.Const"() {value = dense<(0.000000e+00,0.000000e+00)> : tensor>} : () -> tensor> - %0 = "tf.MatrixDiagPartV3"(%arg0, %cst, %cst_0) {align = "RIGHT_LEFT", device = ""} : (tensor<4x0xcomplex>, tensor, tensor>) -> tensor<0xcomplex> - // CHECK: return %{{[0-9]*}} : tensor<0xcomplex> - return %0 : tensor<0xcomplex> -} - -// ----- - -// CHECK-LABEL: func @matrix_diag_part_single_diagonal -func.func @matrix_diag_part_single_diagonal(%arg0: tensor<7x140x128xi32>) -> tensor<7x128xi32> { - // CHECK-NEXT: %[[CST0:.*]] = stablehlo.constant dense<42> : tensor<7x1x128xi32> - // CHECK-NEXT: %[[CST1:.*]] = stablehlo.constant dense<128> : tensor<1x1x128xi32> - // CHECK-NEXT: %[[CST2:.*]] = stablehlo.constant dense<140> : tensor<1x1x128xi32> - // CHECK-NEXT: %[[FALSE:.*]] = stablehlo.constant dense<0> : tensor<1x1x128xi32> - // CHECK-NEXT: %[[IOTA0:.*]] = stablehlo.iota dim = 0 : tensor<128xi32> - // CHECK-NEXT: %[[RESHAPE0:.*]] = stablehlo.reshape %[[IOTA0]] : (tensor<128xi32>) -> tensor<1x1x128xi32> - // CHECK-NEXT: %[[CMP0:.*]] = stablehlo.compare GE, %[[RESHAPE0]], %[[FALSE]] : (tensor<1x1x128xi32>, tensor<1x1x128xi32>) -> tensor<1x1x128xi1> - // CHECK-NEXT: %[[CMP1:.*]] = stablehlo.compare LT, %[[RESHAPE0]], %[[CST1]] : (tensor<1x1x128xi32>, tensor<1x1x128xi32>) -> tensor<1x1x128xi1> - // CHECK-NEXT: %[[AND0:.*]] = stablehlo.and %[[CMP0]], %[[CMP1]] : tensor<1x1x128xi1> - // CHECK-NEXT: %[[CMP2:.*]] = stablehlo.compare GE, %[[RESHAPE0]], %[[FALSE]] : (tensor<1x1x128xi32>, tensor<1x1x128xi32>) -> tensor<1x1x128xi1> - // CHECK-NEXT: %[[CMP3:.*]] = stablehlo.compare LT, %[[RESHAPE0]], %[[CST2]] : (tensor<1x1x128xi32>, tensor<1x1x128xi32>) -> tensor<1x1x128xi1> - // CHECK-NEXT: %[[AND1:.*]] = stablehlo.and %[[CMP2]], %[[CMP3]] : tensor<1x1x128xi1> - // CHECK-NEXT: %[[AND2:.*]] = stablehlo.and %[[AND0]], %[[AND1]] : tensor<1x1x128xi1> - // CHECK-NEXT: %[[RESHAPE1:.*]] = stablehlo.reshape %[[AND2]] : (tensor<1x1x128xi1>) -> tensor<1x128xi1> - // CHECK-NEXT: %[[CONCAT:.*]] = stablehlo.concatenate %[[RESHAPE0]], %[[RESHAPE0]], dim = 0 : (tensor<1x1x128xi32>, tensor<1x1x128xi32>) -> tensor<2x1x128xi32> - // CHECK-NEXT: %[[GATHER:.*]] = "stablehlo.gather"(%arg0, %[[CONCAT]]) <{dimension_numbers = #{{.*}}, indices_are_sorted = false, slice_sizes = array}> : (tensor<7x140x128xi32>, tensor<2x1x128xi32>) -> tensor<7x1x128xi32> - // CHECK-NEXT: %[[BCAST:.*]] = stablehlo.broadcast %[[RESHAPE1]], sizes = [7] : (tensor<1x128xi1>) -> tensor<7x1x128xi1> - // CHECK-NEXT: %[[SELECT0:.*]] = stablehlo.select %[[BCAST]], %[[GATHER]], %[[CST0]] : tensor<7x1x128xi1>, tensor<7x1x128xi32> - // CHECK-NEXT: %[[RESHAPE2:.*]] = stablehlo.reshape %[[SELECT0]] : (tensor<7x1x128xi32>) -> tensor<7x128xi32> - // CHECK-NEXT: return %[[RESHAPE2]] : tensor<7x128xi32> - - %0 = mhlo.constant dense<42> : tensor // padding value - %1 = mhlo.constant dense<0> : tensor<2xi32> // k - %2 = "tf.MatrixDiagPartV3"(%arg0, %1, %0) { - T = i32, align = "RIGHT_LEFT" - } : (tensor<7x140x128xi32>, tensor<2xi32>, tensor) -> tensor<7x128xi32> - func.return %2: tensor<7x128xi32> -} - -// ----- - -// CHECK-LABEL: func @matrix_diag_part_align_ll -func.func @matrix_diag_part_align_ll(%arg0: tensor<7x140x128xi32>) -> tensor<7x22x128xi32> { - // CHECK-NEXT: %[[CST0:.*]] = stablehlo.constant dense<42> : tensor<7x22x128xi32> - // CHECK-NEXT: %[[CST1:.*]] = stablehlo.constant dense<128> : tensor<1x22x128xi32> - // CHECK-NEXT: %[[CST2:.*]] = stablehlo.constant dense<140> : tensor<1x22x128xi32> - // CHECK-NEXT: %[[CST3:.*]] = stablehlo.constant dense<11> : tensor<1x22x128xi32> - // CHECK-NEXT: %[[FALSE:.*]] = stablehlo.constant dense<0> : tensor<1x22x128xi32> - // CHECK-NEXT: %[[IOTA0:.*]] = stablehlo.iota dim = 0 : tensor<22xi32> - // CHECK-NEXT: %[[BCAST0:.*]] = stablehlo.broadcast_in_dim %[[IOTA0]], dims = [1] : (tensor<22xi32>) -> tensor<1x22x128xi32> - // CHECK-NEXT: %[[IOTA1:.*]] = stablehlo.iota dim = 0 : tensor<128xi32> - // CHECK-NEXT: %[[BCAST1:.*]] = stablehlo.broadcast_in_dim %[[IOTA1]], dims = [2] : (tensor<128xi32>) -> tensor<1x22x128xi32> - // CHECK-NEXT: %[[SUB0:.*]] = stablehlo.subtract %[[CST3]], %[[BCAST0]] : tensor<1x22x128xi32> - // CHECK-NEXT: %[[NEG0:.*]] = stablehlo.negate %[[SUB0]] : tensor<1x22x128xi32> - // CHECK-NEXT: %[[MAX0:.*]] = stablehlo.maximum %[[SUB0]], %[[FALSE]] : tensor<1x22x128xi32> - // CHECK-NEXT: %[[SUB1:.*]] = stablehlo.subtract %[[MAX0]], %[[FALSE]] : tensor<1x22x128xi32> - // CHECK-NEXT: %[[MAX1:.*]] = stablehlo.maximum %[[NEG0]], %[[FALSE]] : tensor<1x22x128xi32> - // CHECK-NEXT: %[[SUB2:.*]] = stablehlo.subtract %[[MAX1]], %[[FALSE]] : tensor<1x22x128xi32> - // CHECK-NEXT: %[[ADD0:.*]] = stablehlo.add %[[BCAST1]], %[[SUB1]] : tensor<1x22x128xi32> - // CHECK-NEXT: %[[ADD1:.*]] = stablehlo.add %[[BCAST1]], %[[SUB2]] : tensor<1x22x128xi32> - // CHECK-NEXT: %[[CMP0:.*]] = stablehlo.compare GE, %[[ADD0]], %[[FALSE]] : (tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi1> - // CHECK-NEXT: %[[CMP1:.*]] = stablehlo.compare LT, %[[ADD0]], %[[CST1]] : (tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi1> - // CHECK-NEXT: %[[AND0:.*]] = stablehlo.and %[[CMP0]], %[[CMP1]] : tensor<1x22x128xi1> - // CHECK-NEXT: %[[CMP2:.*]] = stablehlo.compare GE, %[[ADD1]], %[[FALSE]] : (tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi1> - // CHECK-NEXT: %[[CMP3:.*]] = stablehlo.compare LT, %[[ADD1]], %[[CST2]] : (tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi1> - // CHECK-NEXT: %[[AND1:.*]] = stablehlo.and %[[CMP2]], %[[CMP3]] : tensor<1x22x128xi1> - // CHECK-NEXT: %[[AND2:.*]] = stablehlo.and %[[AND0]], %[[AND1]] : tensor<1x22x128xi1> - // CHECK-NEXT: %[[RESHAPE0:.*]] = stablehlo.reshape %[[AND2]] : (tensor<1x22x128xi1>) -> tensor<22x128xi1> - // CHECK-NEXT: %[[CONCAT0:.*]] = stablehlo.concatenate %[[ADD1]], %[[ADD0]], dim = 0 : (tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<2x22x128xi32> - // CHECK-NEXT: %[[GATHER0:.*]] = "stablehlo.gather"(%arg0, %[[CONCAT0]]) <{dimension_numbers = #{{.*}}, indices_are_sorted = false, slice_sizes = array}> : (tensor<7x140x128xi32>, tensor<2x22x128xi32>) -> tensor<7x22x128xi32> - // CHECK-NEXT: %[[BCAST2:.*]] = stablehlo.broadcast %[[RESHAPE0]], sizes = [7] : (tensor<22x128xi1>) -> tensor<7x22x128xi1> - // CHECK-NEXT: %[[SELECT0:.*]] = stablehlo.select %[[BCAST2]], %[[GATHER0]], %[[CST0]] : tensor<7x22x128xi1>, tensor<7x22x128xi32> - // CHECK-NEXT: return %[[SELECT0]] : tensor<7x22x128xi32> - - %0 = mhlo.constant dense<42> : tensor // padding value - %1 = mhlo.constant dense<[-10, 11]> : tensor<2xi32> // k - %2 = "tf.MatrixDiagPartV3"(%arg0, %1, %0) { - T = i32, align = "LEFT_LEFT" - } : (tensor<7x140x128xi32>, tensor<2xi32>, tensor) -> tensor<7x22x128xi32> - func.return %2: tensor<7x22x128xi32> -} - -// ----- - -// CHECK-LABEL: func @matrix_diag_part_align_lr -func.func @matrix_diag_part_align_lr(%arg0: tensor<7x140x128xi32>) -> tensor<7x22x128xi32> { - %0 = mhlo.constant dense<42> : tensor // padding value - %1 = mhlo.constant dense<[-10, 11]> : tensor<2xi32> // k - %2 = "tf.MatrixDiagPartV3"(%arg0, %1, %0) { - T = i32, align = "LEFT_RIGHT" - } : (tensor<7x140x128xi32>, tensor<2xi32>, tensor) -> tensor<7x22x128xi32> - // CHECK: %[[LE:.*]] = stablehlo.compare LE, %{{.*}}, %{{.*}} : (tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi1> - // CHECK: %{{.*}} = stablehlo.select %[[LE]], %{{.*}}, %{{.*}} : tensor<1x22x128xi1>, tensor<1x22x128xi32> - func.return %2: tensor<7x22x128xi32> -} - -// ----- - -// CHECK-LABEL: func @matrix_diag_part_align_rl -func.func @matrix_diag_part_align_rl(%arg0: tensor<7x140x128xi32>) -> tensor<7x22x128xi32> { - %0 = mhlo.constant dense<42> : tensor // padding value - %1 = mhlo.constant dense<[-10, 11]> : tensor<2xi32> // k - %2 = "tf.MatrixDiagPartV3"(%arg0, %1, %0) { - T = i32, align = "RIGHT_LEFT" - } : (tensor<7x140x128xi32>, tensor<2xi32>, tensor) -> tensor<7x22x128xi32> - // CHECK: %[[GE:.*]] = stablehlo.compare GE, %{{.*}}, %{{.*}} : (tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi1> - // CHECK: %{{.*}} = stablehlo.select %[[GE]], %{{.*}}, %{{.*}} : tensor<1x22x128xi1>, tensor<1x22x128xi32> - - func.return %2: tensor<7x22x128xi32> -} - -// ----- - -// CHECK-LABEL: func @matrix_diag_part_align_rr -func.func @matrix_diag_part_align_rr(%arg0: tensor<7x140x128xi32>) -> tensor<7x22x128xi32> { - %0 = mhlo.constant dense<42> : tensor // padding value - %1 = mhlo.constant dense<[-10, 11]> : tensor<2xi32> // k - %2 = "tf.MatrixDiagPartV3"(%arg0, %1, %0) { - T = i32, align = "RIGHT_RIGHT" - } : (tensor<7x140x128xi32>, tensor<2xi32>, tensor) -> tensor<7x22x128xi32> - // CHECK-NOT: MatrixDiagPartV3 - func.return %2: tensor<7x22x128xi32> -} - -// ----- - -// CHECK-LABEL: func @matrix_diag_part_align_7d -// CHECK: (%arg0: tensor<3x5x7x9x11x13x17xf32>) -> tensor<3x5x7x9x11x4x10xf32> -func.func @matrix_diag_part_align_7d(%arg0: tensor<3x5x7x9x11x13x17xf32>) -> tensor<3x5x7x9x11x4x10xf32> { - %0 = mhlo.constant dense<-1.> : tensor // padding value - %1 = mhlo.constant dense<[-6, -3]> : tensor<2xi32> // k - %2 = "tf.MatrixDiagPartV3"(%arg0, %1, %0) { - T = f32, align = "LEFT_RIGHT" - } : (tensor<3x5x7x9x11x13x17xf32>, tensor<2xi32>, tensor) -> tensor<3x5x7x9x11x4x10xf32> - func.return %2: tensor<3x5x7x9x11x4x10xf32> -} - -//===----------------------------------------------------------------------===// -// Erf -//===----------------------------------------------------------------------===// - -// ----- - -// CHECK-LABEL: func @erf -func.func @erf(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> { - // CHECK: mhlo.erf(%arg0) {{.*}} : (tensor<2x3xf32>) -> tensor<2x3xf32> - %0 = "tf.Erf"(%arg0) : (tensor<2x3xf32>) -> tensor<2x3xf32> - func.return %0 : tensor<2x3xf32> -} - -//===----------------------------------------------------------------------===// -// Erfc -//===----------------------------------------------------------------------===// - -// ----- - -// CHECK-LABEL: func @erfc -func.func @erfc(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> { - // CHECK-NOT: tf.Erfc - %0 = "tf.Erfc"(%arg0) : (tensor<2x3xf32>) -> tensor<2x3xf32> - func.return %0 : tensor<2x3xf32> -} - -//===----------------------------------------------------------------------===// -// Einsum. -//===----------------------------------------------------------------------===// - -// ----- - -// CHECK-LABEL: func @einsum -func.func @einsum(%arg0: tensor<2x3xf32>, %arg1: tensor<3x4xf32>) -> tensor<2x4xf32> { - // CHECK: stablehlo.einsum - %0 = "tf.Einsum"(%arg0, %arg1) {equation = "ab,bc->ac"} : (tensor<2x3xf32>, tensor<3x4xf32>) -> tensor<2x4xf32> - func.return %0: tensor<2x4xf32> -} - -// ----- - -// CHECK-LABEL: func @unary_einsum -func.func @unary_einsum(%arg0: tensor<2x3xf32>) -> tensor<2x2xf32> { - // CHECK: stablehlo.constant{{.*}}1.000000e+00 - // CHECK: stablehlo.einsum{{.*}}",ab->aa" - %0 = "tf.Einsum"(%arg0) {equation = "ab->aa"} : (tensor<2x3xf32>) -> tensor<2x2xf32> - func.return %0: tensor<2x2xf32> -} - -//===----------------------------------------------------------------------===// -// FloorDiv and FloorMod. -//===----------------------------------------------------------------------===// - -// ----- - -// CHECK-LABEL: func @floordiv_broadcast_i32 -func.func @floordiv_broadcast_i32(%arg0: tensor<2x3xi32>, %arg1: tensor<3xi32>) -> tensor<2x3xi32> { - // CHECK-NEXT: %[[ONES:.*]] = stablehlo.constant dense<1> : tensor<2x3xi32> - // CHECK-NEXT: %[[ZEROS0:.*]] = stablehlo.constant dense<0> : tensor<3xi32> - // CHECK-NEXT: %[[ZEROS1:.*]] = stablehlo.constant dense<0> : tensor<2x3xi32> - // CHECK-NEXT: %[[BCAST0:.*]] = stablehlo.broadcast_in_dim %arg1, dims = [1] : (tensor<3xi32>) -> tensor<2x3xi32> - // CHECK-NEXT: %[[DIV0:.*]] = stablehlo.divide %arg0, %[[BCAST0]] : tensor<2x3xi32> - // CHECK-NEXT: %[[BCAST1:.*]] = stablehlo.broadcast_in_dim %arg1, dims = [1] : (tensor<3xi32>) -> tensor<2x3xi32> - // CHECK-NEXT: %[[MUL0:.*]] = stablehlo.multiply %[[DIV0]], %[[BCAST1]] : tensor<2x3xi32> - // CHECK-NEXT: %[[CMP0:.*]] = stablehlo.compare NE, %[[MUL0]], %arg0 : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi1> - // CHECK-NEXT: %[[CMP1:.*]] = stablehlo.compare LT, %arg0, %[[ZEROS1]] : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi1> - // CHECK-NEXT: %[[CMP2:.*]] = stablehlo.compare LT, %arg1, %[[ZEROS0]] : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1> - // CHECK-NEXT: %[[BCAST2:.*]] = stablehlo.broadcast_in_dim %[[CMP2]], dims = [1] : (tensor<3xi1>) -> tensor<2x3xi1> - // CHECK-NEXT: %[[CMP3:.*]] = stablehlo.compare NE, %[[CMP1]], %[[BCAST2]] : (tensor<2x3xi1>, tensor<2x3xi1>) -> tensor<2x3xi1> - // CHECK-NEXT: %[[AND0:.*]] = stablehlo.and %[[CMP0]], %[[CMP3]] : tensor<2x3xi1> - // CHECK-NEXT: %[[SUB0:.*]] = stablehlo.subtract %[[DIV0]], %[[ONES]] : tensor<2x3xi32> - // CHECK-NEXT: %[[SELECT0:.*]] = stablehlo.select %[[AND0]], %[[SUB0]], %[[DIV0]] : tensor<2x3xi1>, tensor<2x3xi32> - // CHECK-NEXT: return %[[SELECT0]] : tensor<2x3xi32> - - %0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32> - func.return %0: tensor<2x3xi32> -} - -// ----- - -// CHECK-LABEL: func @floordiv_reverse_broadcast_i32 -func.func @floordiv_reverse_broadcast_i32(%arg0: tensor<3xi32>, %arg1: tensor<2x3xi32>) -> tensor<2x3xi32> { - // CHECK-NEXT: %[[ONES:.*]] = stablehlo.constant dense<1> : tensor<2x3xi32> - // CHECK-NEXT: %[[ZEROS0:.*]] = stablehlo.constant dense<0> : tensor<2x3xi32> - // CHECK-NEXT: %[[ZEROS1:.*]] = stablehlo.constant dense<0> : tensor<3xi32> - // CHECK-NEXT: %[[BCAST0:.*]] = stablehlo.broadcast_in_dim %arg0, dims = [1] : (tensor<3xi32>) -> tensor<2x3xi32> - // CHECK-NEXT: %[[DIV0:.*]] = stablehlo.divide %[[BCAST0]], %arg1 : tensor<2x3xi32> - // CHECK-NEXT: %[[MUL0:.*]] = stablehlo.multiply %[[DIV0]], %arg1 : tensor<2x3xi32> - // CHECK-NEXT: %[[BCAST1:.*]] = stablehlo.broadcast_in_dim %arg0, dims = [1] : (tensor<3xi32>) -> tensor<2x3xi32> - // CHECK-NEXT: %[[CMP0:.*]] = stablehlo.compare NE, %[[MUL0]], %[[BCAST1]] : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi1> - // CHECK-NEXT: %[[CMP1:.*]] = stablehlo.compare LT, %arg0, %[[ZEROS1]] : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1> - // CHECK-NEXT: %[[CMP2:.*]] = stablehlo.compare LT, %arg1, %[[ZEROS0]] : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi1> - // CHECK-NEXT: %[[BCAST2:.*]] = stablehlo.broadcast_in_dim %[[CMP1]], dims = [1] : (tensor<3xi1>) -> tensor<2x3xi1> - // CHECK-NEXT: %[[CMP3:.*]] = stablehlo.compare NE, %[[BCAST2]], %[[CMP2]] : (tensor<2x3xi1>, tensor<2x3xi1>) -> tensor<2x3xi1> - // CHECK-NEXT: %[[AND0:.*]] = stablehlo.and %[[CMP0]], %[[CMP3]] : tensor<2x3xi1> - // CHECK-NEXT: %[[SUB0:.*]] = stablehlo.subtract %[[DIV0]], %[[ONES]] : tensor<2x3xi32> - // CHECK-NEXT: %[[SELECT0:.*]] = stablehlo.select %[[AND0]], %[[SUB0]], %[[DIV0]] : tensor<2x3xi1>, tensor<2x3xi32> - // CHECK-NEXT: return %[[SELECT0]] : tensor<2x3xi32> - - %0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> - func.return %0: tensor<2x3xi32> -} - -// ----- - -// CHECK-LABEL: func @floordiv_f32 -func.func @floordiv_f32(%arg0: tensor<2xf32>) -> tensor<2xf32> { - // CHECK-NEXT: %[[DIV:.*]] = stablehlo.divide %arg0, %arg0 - // CHECK-NEXT: %[[FLOOR:.*]] = stablehlo.floor %[[DIV]] - // CHECK-NEXT: return %[[FLOOR]] : tensor<2xf32> - %0 = "tf.FloorDiv"(%arg0, %arg0) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> - func.return %0: tensor<2xf32> -} - -// ----- - -// CHECK-LABEL: func @floordiv_bf16 -func.func @floordiv_bf16(%arg0: tensor<2xbf16>) -> tensor<2xbf16> { - // CHECK-NEXT: stablehlo.convert - // CHECK-NEXT: stablehlo.convert - // CHECK-NEXT: stablehlo.divide - // CHECK-NEXT: stablehlo.floor - // CHECK-NEXT: stablehlo.convert - // CHECK-NEXT: return - %0 = "tf.FloorDiv"(%arg0, %arg0) : (tensor<2xbf16>, tensor<2xbf16>) -> tensor<2xbf16> - func.return %0: tensor<2xbf16> -} - -// ----- - -// CHECK-LABEL: func @floordiv_f16_broadcast -func.func @floordiv_f16_broadcast(%arg0: tensor<2x3xf16>, %arg1: tensor<3xf16>) -> tensor<2x3xf16> { - // CHECK-NEXT: stablehlo.broadcast_in_dim - // CHECK-NEXT: stablehlo.divide - // CHECK-NEXT: stablehlo.floor - // CHECK-NEXT: return - %0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16> - func.return %0: tensor<2x3xf16> -} - -// ----- - -// CHECK-LABEL: func @floordiv_dynamic -func.func @floordiv_dynamic(%arg0: tensor, %arg1: tensor) -> tensor { - // CHECK: shape.assuming - // CHECK: stablehlo.dynamic_broadcast_in_dim - // CHECK: stablehlo.divide - // CHECK: shape.assuming - // CHECK: stablehlo.dynamic_broadcast_in_dim - // CHECK: stablehlo.multiply - // CHECK: shape.assuming - // CHECK: stablehlo.dynamic_broadcast_in_dim - // CHECK: stablehlo.compare - // CHECK: shape.assuming - // CHECK: stablehlo.dynamic_broadcast_in_dim - // CHECK: stablehlo.and - // - // CHECK: %[[SELECT:.*]] = stablehlo.select - // CHECK: return %[[SELECT]] - %0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor, tensor) -> tensor - func.return %0: tensor -} - -// ----- - -// CHECK-LABEL: func @floordiv_unsigned -func.func @floordiv_unsigned(%arg0: tensor, %arg1: tensor) -> tensor { - // CHECK: %[[RESULT:.*]] = shape.assuming - // CHECK: %[[BCAST0:.*]] = stablehlo.dynamic_broadcast_in_dim %arg0, - // CHECK: %[[BCAST1:.*]] = stablehlo.dynamic_broadcast_in_dim %arg1, - // CHECK: %[[DIV:.*]] = stablehlo.divide %[[BCAST0]], %[[BCAST1]] - // CHECK: shape.assuming_yield %[[DIV]] - // CHECK: return %[[RESULT]] - %0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor, tensor) -> tensor - func.return %0: tensor -} - -// ----- - -// CHECK-LABEL: func @floordiv_int -func.func @floordiv_int(%arg0: tensor, %arg1: tensor) -> tensor { - // CHECK: shape.assuming - // CHECK: stablehlo.dynamic_broadcast_in_dim - // CHECK: stablehlo.divide - // CHECK: shape.assuming - // CHECK: stablehlo.dynamic_broadcast_in_dim - // CHECK: stablehlo.multiply - // CHECK: shape.assuming - // CHECK: stablehlo.dynamic_broadcast_in_dim - // CHECK: stablehlo.compare - // CHECK: shape.assuming - // CHECK: stablehlo.dynamic_broadcast_in_dim - // CHECK: stablehlo.compare - // CHECK: shape.assuming - // CHECK: stablehlo.dynamic_broadcast_in_dim - // CHECK: stablehlo.and - // - // CHECK: %[[SELECT:.*]] = stablehlo.select - // CHECK: return %[[SELECT]] - %0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor, tensor) -> tensor - func.return %0: tensor -} - -// ----- - -// CHECK-LABEL: func @floormod_broadcast_numerator -func.func @floormod_broadcast_numerator(%arg0: tensor<3xi32>, %arg1: tensor<2x3xi32>) -> tensor<2x3xi32> { - // CHECK: %[[BCAST0:.*]] = stablehlo.broadcast_in_dim %arg0, dims = [1] : (tensor<3xi32>) -> tensor<2x3xi32> - // CHECK: %[[REM:.*]] = stablehlo.remainder %[[BCAST0]], %arg1 : tensor<2x3xi32> - // CHECK: %[[AND:.*]] = stablehlo.and - // CHECK: %[[ADD:.*]] = stablehlo.add - // CHECK: %[[SELECT:.*]] = stablehlo.select %[[AND]], %[[ADD]], %[[REM]] - // CHECK-NEXT: return %[[SELECT]] - %0 = "tf.FloorMod"(%arg0, %arg1) : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> - func.return %0: tensor<2x3xi32> -} - -// ----- - -// CHECK-LABEL: func @floormod_broadcast_denominator -func.func @floormod_broadcast_denominator(%arg0: tensor<2x3xi32>, %arg1: tensor<3xi32>) -> tensor<2x3xi32> { - // CHECK: %[[BCAST0:.*]] = stablehlo.broadcast_in_dim %arg1, dims = [1] : (tensor<3xi32>) -> tensor<2x3xi32> - // CHECK: %[[REM:.*]] = stablehlo.remainder %arg0, %[[BCAST0]] - // CHECK: %[[AND:.*]] = stablehlo.and - // CHECK: %[[BCAST1:.*]] = stablehlo.broadcast_in_dim %arg1, dims = [1] : (tensor<3xi32>) -> tensor<2x3xi32> - // CHECK: %[[ADD:.*]] = stablehlo.add %[[BCAST1]], %[[REM]] - // CHECK: %[[SELECT:.*]] = stablehlo.select %[[AND]], %[[ADD]], %[[REM]] - // CHECK-NEXT: return %[[SELECT]] - %0 = "tf.FloorMod"(%arg0, %arg1) : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32> - func.return %0: tensor<2x3xi32> -} - -// ----- - -// CHECK-LABEL: func @floormod_unsigned_broadcast_denominator -func.func @floormod_unsigned_broadcast_denominator(%arg0: tensor<2x3xui32>, %arg1: tensor<3xui32>) -> tensor<2x3xui32> { - // CHECK-NEXT: %[[BCAST0:.*]] = stablehlo.broadcast_in_dim %arg1, dims = [1] : (tensor<3xui32>) -> tensor<2x3xui32> - // CHECK-NEXT: %[[REM:.*]] = stablehlo.remainder %arg0, %[[BCAST0]] : tensor<2x3xui32> - // CHECK-NEXT: return %[[REM]] : tensor<2x3xui32> - %0 = "tf.FloorMod"(%arg0, %arg1) : (tensor<2x3xui32>, tensor<3xui32>) -> tensor<2x3xui32> - func.return %0: tensor<2x3xui32> -} - -// ----- - -// CHECK-LABEL: func @floormod_dynamic_broadcast_numerator -func.func @floormod_dynamic_broadcast_numerator_(%arg0: tensor, %arg1: tensor) -> tensor { - // CHECK: %[[REM:.*]] = shape.assuming {{.*}} { - // CHECK: stablehlo.remainder - // CHECK: shape.assuming {{.*}} { - // CHECK: stablehlo.compare - // CHECK: %[[AND:.*]] = shape.assuming {{.*}} { - // CHECK: stablehlo.and - // CHECK: %[[ADD:.*]] = shape.assuming {{.*}} { - // CHECK: stablehlo.add - // CHECK: %[[SELECT:.*]] = stablehlo.select %[[AND]], %[[ADD]], %[[REM]] - // CHECK-NEXT: return %[[SELECT]] - %0 = "tf.FloorMod"(%arg0, %arg1) : (tensor, tensor) -> tensor - func.return %0: tensor -} - -// ----- - -// CHECK-LABEL: func @floormod_dynamic_broadcast_denominator -func.func @floormod_dynamic_broadcast_denominator_(%arg0: tensor, %arg1: tensor) -> tensor { - // CHECK-NOT: tf.FloorMod - // CHECK: %[[REM:.*]] = shape.assuming {{.*}} { - // CHECK: stablehlo.remainder - // CHECK: shape.assuming {{.*}} { - // CHECK: stablehlo.compare - // CHECK: %[[AND:.*]] = shape.assuming {{.*}} { - // CHECK: stablehlo.and - // CHECK: %[[ADD:.*]] = shape.assuming {{.*}} { - // CHECK: stablehlo.add - // CHECK: %[[SELECT:.*]] = stablehlo.select %[[AND]], %[[ADD]], %[[REM]] - // CHECK-NEXT: return %[[SELECT]] - %0 = "tf.FloorMod"(%arg0, %arg1) : (tensor, tensor) -> tensor - func.return %0: tensor -} - -//===----------------------------------------------------------------------===// -// OnesLike -//===----------------------------------------------------------------------===// - -// ----- - -// CHECK-LABEL: @ones_like -// CHECK-SAME: (%[[ARG:.*]]: tensor<2x?xf32>) -func.func @ones_like(%arg0: tensor<2x?xf32>) -> tensor<2x?xf32> { - // CHECK-NEXT: %[[ONES:.*]] = stablehlo.constant dense<1.000000e+00> : tensor - // CHECK-NEXT: %[[SHAPE:.*]] = shape.shape_of %arg0 : tensor<2x?xf32> -> tensor<2xindex> - // CHECK-NEXT: %[[RESULT:.*]] = stablehlo.dynamic_broadcast_in_dim %[[ONES]], %[[SHAPE]], dims = [] : (tensor, tensor<2xindex>) -> tensor<2x?xf32> - // CHECK-NEXT: return %[[RESULT]] : tensor<2x?xf32> - %0 = "tf.OnesLike"(%arg0) : (tensor<2x?xf32>) -> tensor<2x?xf32> - func.return %0 : tensor<2x?xf32> -} - -//===----------------------------------------------------------------------===// -// ZerosLike -//===----------------------------------------------------------------------===// - -// ----- - -// CHECK-LABEL: @zeros_like -// CHECK-SAME: (%[[ARG:.*]]: tensor<2x?xf32>) -func.func @zeros_like(%arg0: tensor<2x?xf32>) -> tensor<2x?xf32> { - // CHECK-NEXT: %[[ZEROS:.*]] = stablehlo.constant dense<0.000000e+00> : tensor - // CHECK-NEXT: %[[SHAPE:.*]] = shape.shape_of %arg0 : tensor<2x?xf32> -> tensor<2xindex> - // CHECK-NEXT: %[[RESULT:.*]] = stablehlo.dynamic_broadcast_in_dim %[[ZEROS]], %[[SHAPE]], dims = [] : (tensor, tensor<2xindex>) -> tensor<2x?xf32> - // CHECK-NEXT: return %[[RESULT]] : tensor<2x?xf32> - %0 = "tf.ZerosLike"(%arg0) : (tensor<2x?xf32>) -> tensor<2x?xf32> - func.return %0 : tensor<2x?xf32> -} - -//===----------------------------------------------------------------------===// -// BroadcastTo. -//===----------------------------------------------------------------------===// - -// ----- - -// CHECK-LABEL: func @broadcast_to -func.func @broadcast_to(%arg0: tensor<16xf32>) -> tensor<16x16x16x16xf32> { - %cst = "tf.Const"() { value = dense<16> : tensor<4xi32> } : () -> tensor<4xi32> - // CHECK: stablehlo.broadcast_in_dim %arg0 - %0 = "tf.BroadcastTo"(%arg0, %cst) : (tensor<16xf32>, tensor<4xi32>) -> tensor<16x16x16x16xf32> - func.return %0 : tensor<16x16x16x16xf32> -} - -//===----------------------------------------------------------------------===// -// Complex op legalizations. -//===----------------------------------------------------------------------===// - -// ----- - -// CHECK-LABEL: func @complex -func.func @complex(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xcomplex> { - // CHECK: stablehlo.complex - %1 = "tf.Complex"(%arg0, %arg1) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xcomplex> - func.return %1 : tensor<3xcomplex> -} - -// ----- - -// CHECK-LABEL: func @imag -func.func @imag(%arg0: tensor<3xcomplex>) -> tensor<3xf32> { - // CHECK: stablehlo.imag - %1 = "tf.Imag"(%arg0) : (tensor<3xcomplex>) -> tensor<3xf32> - func.return %1 : tensor<3xf32> -} - -// ----- - -// CHECK-LABEL: func @real -func.func @real(%arg0: tensor<3xcomplex>) -> tensor<3xf32> { - // CHECK: stablehlo.real - %1 = "tf.Real"(%arg0) : (tensor<3xcomplex>) -> tensor<3xf32> - func.return %1 : tensor<3xf32> -} - -//===----------------------------------------------------------------------===// -// Concat op legalizations. -//===----------------------------------------------------------------------===// - -// ----- - -// CHECK-LABEL: func @concat_v2 -func.func @concat_v2(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>) -> tensor<6x3xf32> { - // CHECK: stablehlo.concatenate %arg0, %arg1, dim = 0 - %axis = "tf.Const"() { value = dense<0> : tensor } : () -> tensor - %1 = "tf.ConcatV2"(%arg0, %arg1, %axis) : (tensor<3x3xf32>, tensor<3x3xf32>, tensor) -> tensor<6x3xf32> - func.return %1 : tensor<6x3xf32> -} - -// ----- - -// CHECK-LABEL: func @concat_v2_neg_axis -func.func @concat_v2_neg_axis(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>) -> tensor<6x3xf32> { - // CHECK: stablehlo.concatenate %arg0, %arg1, dim = 0 - - %axis = "tf.Const"() { value = dense<-2> : tensor } : () -> tensor - %1 = "tf.ConcatV2"(%arg0, %arg1, %axis) : (tensor<3x3xf32>, tensor<3x3xf32>, tensor) -> tensor<6x3xf32> - func.return %1 : tensor<6x3xf32> -} - -// ----- - -// CHECK-LABEL: func @concat_v2_1d_axis -func.func @concat_v2_1d_axis(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>) -> tensor<3x6xf32> { - // CHECK: stablehlo.concatenate %arg0, %arg1, dim = 1 - - %axis = "tf.Const"() { value = dense<[1]> : tensor<1xi64> } : () -> tensor<1xi64> - %1 = "tf.ConcatV2"(%arg0, %arg1, %axis) : (tensor<3x3xf32>, tensor<3x3xf32>, tensor<1xi64>) -> tensor<3x6xf32> - func.return %1 : tensor<3x6xf32> -} - -// ----- - -// CHECK-LABEL: func @concat_v2_non_const_axis -module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 12 : i32}} { -func.func @concat_v2_non_const_axis(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>, %axis: tensor) -> tensor<3x6xf32> { - // CHECK: "tf.ConcatV2" - %1 = "tf.ConcatV2"(%arg0, %arg1, %axis) : (tensor<3x3xf32>, tensor<3x3xf32>, tensor) -> tensor<3x6xf32> - func.return %1 : tensor<3x6xf32> -} -} - -//===----------------------------------------------------------------------===// -// Pad op legalizations. -//===----------------------------------------------------------------------===// - -// ----- - -// CHECK-LABEL: func @padv2_1D -func.func @padv2_1D(%arg0: tensor<3xf32>, %arg1: tensor) -> tensor<6xf32> { - %padding = "tf.Const"() { value = dense<[[1, 2]]> : tensor<1x2xi64> } : () -> tensor<1x2xi64> - // CHECK: stablehlo.pad %arg0, %arg1, low = [1], high = [2], interior = [0] - %1 = "tf.PadV2"(%arg0, %padding, %arg1) : (tensor<3xf32>, tensor<1x2xi64>, tensor) -> tensor<6xf32> - func.return %1 : tensor<6xf32> -} - -// ----- - -// CHECK-LABEL: func @padv2_2D -func.func @padv2_2D(%arg0: tensor<3x2xf32>, %arg1: tensor) -> tensor<6x9xf32> { - %padding = "tf.Const"() { value = dense<[[1,2],[3,4]]> : tensor<2x2xi64> } : () -> tensor<2x2xi64> - // CHECK: stablehlo.pad %arg0, %arg1, low = [1, 3], high = [2, 4], interior = [0, 0] - %1 = "tf.PadV2"(%arg0, %padding, %arg1) : (tensor<3x2xf32>, tensor<2x2xi64>, tensor) -> tensor<6x9xf32> - func.return %1 : tensor<6x9xf32> -} - -// ----- - -// CHECK-LABEL: func @padv2_i32_paddings -func.func @padv2_i32_paddings(%arg0: tensor<3x2xf32>, %arg1: tensor) -> tensor<6x9xf32> { - %padding = "tf.Const"() { value = dense<[[1,2],[3,4]]> : tensor<2x2xi32> } : () -> tensor<2x2xi32> - // CHECK: stablehlo.pad %arg0, %arg1, low = [1, 3], high = [2, 4], interior = [0, 0] - %1 = "tf.PadV2"(%arg0, %padding, %arg1) : (tensor<3x2xf32>, tensor<2x2xi32>, tensor) -> tensor<6x9xf32> - func.return %1 : tensor<6x9xf32> -} - -// ----- - -// CHECK-LABEL: func @padv2_dynamic -func.func @padv2_dynamic(%arg0: tensor, %arg1: tensor, %arg2: tensor<1x2xi64>) -> tensor { - // CHECK-NEXT: %[[ZEROS:.*]] = stablehlo.constant dense<0> : tensor<1xi64> - // CHECK-NEXT: %[[RESHAPE:.*]] = stablehlo.reshape %arg2 : (tensor<1x2xi64>) -> tensor<2xi64> - // CHECK-NEXT: %[[SLICE0:.*]] = stablehlo.slice %[[RESHAPE]] [0:1] : (tensor<2xi64>) -> tensor<1xi64> - // CHECK-NEXT: %[[SLICE1:.*]] = stablehlo.slice %[[RESHAPE]] [1:2] : (tensor<2xi64>) -> tensor<1xi64> - // CHECK-NEXT: %[[RESULT:.*]] = stablehlo.dynamic_pad %arg0, %arg1, %[[SLICE0]], %[[SLICE1]], %[[ZEROS]] : (tensor, tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor - // CHECK-NEXT: return %[[RESULT]] : tensor - - %1 = "tf.PadV2"(%arg0, %arg2, %arg1) : (tensor, tensor<1x2xi64>, tensor) -> tensor - func.return %1 : tensor -} - -//===----------------------------------------------------------------------===// -// Identity op legalizations. -//===----------------------------------------------------------------------===// - -// ----- - -// CHECK-LABEL: func @identity -func.func @identity(%arg0: tensor<1xi32>) -> tensor<1xi32> { - // CHECK-NEXT: return %arg0 : tensor<1xi32> - %0 = "tf.Identity"(%arg0) : (tensor<1xi32>) -> tensor<1xi32> - func.return %0: tensor<1xi32> -} - -// ----- - -// CHECK-LABEL: func @identityN -func.func @identityN(%arg0: tensor<1xi32>, %arg1: tensor<1xf32>) -> (tensor<1xi32>, tensor<1xf32>) { - // CHECK-NEXT: return %arg0, %arg1 : tensor<1xi32>, tensor<1xf32> - %0:2 = "tf.IdentityN"(%arg0, %arg1) : (tensor<1xi32>, tensor<1xf32>) -> (tensor<1xi32>, tensor<1xf32>) - func.return %0#0, %0#1: tensor<1xi32>, tensor<1xf32> -} - -// ----- - -// CHECK-LABEL: func @stopgradient -func.func @stopgradient(%arg0: tensor<1xi32>) -> tensor<1xi32> { - // CHECK-NEXT: return %arg0 : tensor<1xi32> - %0 = "tf.StopGradient"(%arg0) : (tensor<1xi32>) -> tensor<1xi32> - func.return %0: tensor<1xi32> -} - -// ----- - -// CHECK-LABEL: func @preventgradient -func.func @preventgradient(%arg0: tensor<1xi32>) -> tensor<1xi32> { - // CHECK-NEXT: return %arg0 : tensor<1xi32> - %0 = "tf.PreventGradient"(%arg0) {message = "fin gradients"} : (tensor<1xi32>) -> tensor<1xi32> - func.return %0: tensor<1xi32> -} - -// ----- - -// CHECK-LABEL: func @checkNumerics -func.func @checkNumerics(%arg0: tensor<1xf32>) -> tensor<1xf32> { - // CHECK-NEXT: return %arg0 : tensor<1xf32> - %0 = "tf.CheckNumerics"(%arg0) {message = "check numerics"} : (tensor<1xf32>) -> tensor<1xf32> - func.return %0: tensor<1xf32> -} - -//===----------------------------------------------------------------------===// -// InfeedDequeueTuple legalization -//===----------------------------------------------------------------------===// - -// ----- - -// CHECK-LABEL: func @infeed_dequeue_tuple -func.func @infeed_dequeue_tuple() -> (tensor<1x8x4x4xi32>, tensor<1x100x1xf32>) { - // CHECK: [[TOKEN:%.*]] = stablehlo.create_token : !stablehlo.token - // CHECK: [[INFEED:%.*]]:3 = "stablehlo.infeed"([[TOKEN]]) <{infeed_config = ""{{.*}}}> : (!stablehlo.token) -> (tensor<1x8x4x4xi32>, tensor<1x100x1xf32>, !stablehlo.token) - // CHECK: return [[INFEED]]#0, [[INFEED]]#1 - %0:2 = "tf.InfeedDequeueTuple"() : () -> (tensor<1x8x4x4xi32>, tensor<1x100x1xf32>) - func.return %0#0, %0#1 : tensor<1x8x4x4xi32>, tensor<1x100x1xf32> -} - -// ----- - -// CHECK-LABEL: func @infeed_dequeue_tuple_dynamic_error -func.func @infeed_dequeue_tuple_dynamic_error() -> (tensor<3x3xf32>, tensor<4x?xf32>) { - // We expect legalization to fail for dynamic shapes: - // CHECK: [[INFEED:%.*]] = "tf.InfeedDequeueTuple"{{.*}} - %0:2 = "tf.InfeedDequeueTuple"() : () -> (tensor<3x3xf32>, tensor<4x?xf32>) - func.return %0#0, %0#1 : tensor<3x3xf32>, tensor<4x?xf32> -} - -// The following op sharding is used: -// Proto debug string: -// type: TUPLE -// tuple_shardings { -// type: MAXIMAL -// tile_assignment_dimensions: 1 -// tile_assignment_devices: 0 -// } -// Serialized string: -// "\08\02*\08\08\01\1A\01\01\22\01\00" - -// CHECK-LABEL: infeed_dequeue_tuple_sharding -func.func @infeed_dequeue_tuple_sharding() -> tensor<8xi32> { - // CHECK: "stablehlo.infeed" - // An additional sharding is added at the end to account for token result. - // Proto debug string: - // type: TUPLE - // tuple_shardings { - // type: MAXIMAL - // tile_assignment_dimensions: 1 - // tile_assignment_devices: 0 - // } - // tuple_shardings { - // type: MAXIMAL - // tile_assignment_dimensions: 1 - // tile_assignment_devices: 0 - // } - // CHECK-SAME: mhlo.sharding = "\08\02*\08\08\01\1A\01\01\22\01\00*\08\08\01\1A\01\01\22\01\00" - %0 = "tf.InfeedDequeueTuple"() {_XlaSharding = "\08\02*\08\08\01\1A\01\01\22\01\00"} : () -> tensor<8xi32> - func.return %0 : tensor<8xi32> -} - -//===----------------------------------------------------------------------===// -// Nullary op legalizations. -//===----------------------------------------------------------------------===// - -// ----- - -// CHECK-LABEL: @const -func.func @const() -> tensor<2xi32> { - // CHECK: stablehlo.constant dense<0> : tensor<2xi32> - %0 = "tf.Const"() {device = "", name = "", dtype = "tfdtype$DT_INT32", value = dense<0> : tensor<2xi32>} : () -> (tensor<2xi32>) - func.return %0: tensor<2xi32> -} - -// ----- - -// CHECK-LABEL: @const_dynamic_output -func.func @const_dynamic_output() -> tensor<*xi32> { - // CHECK: [[CONST:%.*]] = stablehlo.constant dense<0> : tensor<2xi32> - // CHECK: [[CAST:%.*]] = tensor.cast [[CONST]] : tensor<2xi32> to tensor<*xi32> - %0 = "tf.Const"() {value = dense<0> : tensor<2xi32>} : () -> (tensor<*xi32>) - // CHECK: return [[CAST]] - func.return %0: tensor<*xi32> -} - -// ----- - -// CHECK-LABEL: @opaque_const -func.func @opaque_const() -> tensor>> { - // CHECK-NOT: stablehlo.constant - %0 = "tf.Const"() {device = "", name = "", dtype = "tfdtype$DT_INT32", value = #tf_type : tensor} : () -> tensor>> - func.return %0 : tensor>> -} - -//===----------------------------------------------------------------------===// -// Matmul op legalizations. -//===----------------------------------------------------------------------===// - -// ----- - -// CHECK-LABEL: matmul_notranspose -// CHECK-SAME: (%[[A:.*]]: tensor<5x7xf32>, %[[B:.*]]: tensor<7x11xf32>) -func.func @matmul_notranspose(%a: tensor<5x7xf32>, %b: tensor<7x11xf32>) -> tensor<5x11xf32> { - // CHECK: stablehlo.dot %[[A]], %[[B]] - %0 = "tf.MatMul"(%a, %b) {transpose_a = false, transpose_b = false} : (tensor<5x7xf32>, tensor<7x11xf32>) -> tensor<5x11xf32> - - func.return %0 : tensor<5x11xf32> -} - -// ----- - -// CHECK-LABEL: matmul_transpose_b -// CHECK-SAME: (%[[A:.*]]: tensor<5x7xf32>, %[[B:.*]]: tensor<11x7xf32>) -func.func @matmul_transpose_b(%a: tensor<5x7xf32>, %b: tensor<11x7xf32>) -> tensor<5x11xf32> { - // CHECK: %[[UPDATED_B:.*]] = stablehlo.transpose %[[B]], dims = [1, 0] - // CHECK: stablehlo.dot %[[A]], %[[UPDATED_B]] - %0 = "tf.MatMul"(%a, %b) {transpose_a = false, transpose_b = true} : (tensor<5x7xf32>, tensor<11x7xf32>) -> tensor<5x11xf32> - - func.return %0 : tensor<5x11xf32> -} - -// ----- - -// CHECK-LABEL: matmul_transpose_both -// CHECK-SAME: (%[[A:.*]]: tensor<7x5xf32>, %[[B:.*]]: tensor<11x7xf32>) -func.func @matmul_transpose_both(%a: tensor<7x5xf32>, %b: tensor<11x7xf32>) -> tensor<5x11xf32> { - // CHECK: %[[UPDATED_A:.*]] = stablehlo.transpose %[[A]] - // CHECK: %[[UPDATED_B:.*]] = stablehlo.transpose %[[B]] - // CHECK: stablehlo.dot %[[UPDATED_A]], %[[UPDATED_B]] - %0 = "tf.MatMul"(%a, %b) {transpose_a = true, transpose_b = true} : (tensor<7x5xf32>, tensor<11x7xf32>) -> tensor<5x11xf32> - - func.return %0 : tensor<5x11xf32> -} - -// Verify that MatMul with ranked inputs are lowered to HLO. -// CHECK-LABEL: matmul_ranked -func.func @matmul_ranked(%a: tensor, %b: tensor<7x?xf32>) -> tensor { - // CHECK: stablehlo.dot - %0 = "tf.MatMul"(%a, %b) {transpose_a = false, transpose_b = false} : (tensor, tensor<7x?xf32>) -> tensor - - func.return %0 : tensor -} - -// Verify SparseMatMul is legalized to dot. -// CHECK-LABEL: test_sparse_mat_mul -func.func @test_sparse_mat_mul(%arg0: tensor<3x4xf32>, %arg1: tensor<4x5xf32>) -> tensor<3x5xf32> { - // CHECK: stablehlo.dot - %0 = "tf.SparseMatMul"(%arg0, %arg1) {a_is_sparse = true, b_is_sparse = false, transpose_a = false, transpose_b = false} : (tensor<3x4xf32>, tensor<4x5xf32>) -> tensor<3x5xf32> - func.return %0: tensor<3x5xf32> -} - -// SparseMatMul where one operand needs to be transposed and the other one not. -// -// CHECK-LABEL: @test_sparse_mat_mul_with_transpose - // CHECK-SAME: %[[ARG0:.*]]: tensor<3x4xf32> - // CHECK-SAME: %[[ARG1:.*]]: tensor<5x4xf32> - // CHECK-SAME: -> tensor<3x5xf32> - // CHECK: %[[TRANSPOSE:.*]] = stablehlo.transpose %[[ARG1]] - // CHECK-SAME: dims = [1, 0] - // CHECK-SAME: -> tensor<4x5xf32> - // CHECK: %[[RESULT:.*]] = stablehlo.dot %[[ARG0]], %[[TRANSPOSE]] - // CHECK-SAME: -> tensor<3x5xf32> - // CHECK: return %[[RESULT]] -func.func @test_sparse_mat_mul_with_transpose(%arg0: tensor<3x4xf32>, %arg1: tensor<5x4xf32>) -> tensor<3x5xf32> { - %0 = "tf.SparseMatMul"(%arg0, %arg1) {a_is_sparse = true, b_is_sparse = false, transpose_a = false, transpose_b = true} : (tensor<3x4xf32>, tensor<5x4xf32>) -> tensor<3x5xf32> - func.return %0: tensor<3x5xf32> -} - -// SparseMatMul where one operand needs to be casted and the other one not. -// -// CHECK-LABEL: @test_sparse_mat_mul_with_cast - // CHECK-SAME: %[[ARG0:.*]]: tensor<3x4xf32> - // CHECK-SAME: %[[ARG1:.*]]: tensor<4x5xbf16> - // CHECK-SAME: -> tensor<3x5xf32> - // CHECK: %[[CAST:.*]] = stablehlo.convert %[[ARG1]] - // CHECK-SAME: -> tensor<4x5xf32> - // CHECK: %[[RESULT:.*]] = stablehlo.dot %[[ARG0]], %[[CAST]] - // CHECK-SAME: -> tensor<3x5xf32> - // CHECK: return %[[RESULT]] -func.func @test_sparse_mat_mul_with_cast(%arg0: tensor<3x4xf32>, %arg1: tensor<4x5xbf16>) -> tensor<3x5xf32> { - %0 = "tf.SparseMatMul"(%arg0, %arg1) {a_is_sparse = true, b_is_sparse = false, transpose_a = false, transpose_b = false} : (tensor<3x4xf32>, tensor<4x5xbf16>) -> tensor<3x5xf32> - func.return %0: tensor<3x5xf32> -} - -//===----------------------------------------------------------------------===// -// MaxPool op legalizations. -//===----------------------------------------------------------------------===// - -// ----- - -// CHECK-LABEL: maxpool_valid_padding -// CHECK-SAME: %[[ARG:.*]]: tensor -func.func @maxpool_valid_padding(%arg0: tensor<2x12x20x7xi32>) -> tensor<2x3x5x7xi32> { - // CHECK: %[[INIT:.*]] = stablehlo.constant dense<-2147483648> : tensor - // CHECK: "stablehlo.reduce_window"(%[[ARG]], %[[INIT]]) - // CHECK-SAME: <{window_dimensions = array, window_strides = array}> - // CHECK: stablehlo.maximum - // CHECK: stablehlo.return - - %0 = "tf.MaxPool"(%arg0) {data_format = "NHWC", ksize = [1, 2, 2, 1], padding = "VALID", strides = [1, 4, 4, 1]} : (tensor<2x12x20x7xi32>) -> tensor<2x3x5x7xi32> - func.return %0 : tensor<2x3x5x7xi32> -} - -// ----- - -// CHECK-LABEL: maxpool_same_padding -// CHECK-SAME: %[[ARG:.*]]: tensor -func.func @maxpool_same_padding(%arg0: tensor<2x13x25x7xi32>) -> tensor<2x4x7x7xi32> { - // CHECK: padding = dense<{{\[\[}}0, 0], [0, 1], [1, 1], [0, 0]]> : tensor<4x2xi64> - - %0 = "tf.MaxPool"(%arg0) {data_format = "NHWC", ksize = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 4, 1]} : (tensor<2x13x25x7xi32>) -> tensor<2x4x7x7xi32> - func.return %0 : tensor<2x4x7x7xi32> -} - -// ----- - -// CHECK-LABEL: maxpool_3d_valid_padding -// CHECK-SAME: %[[ARG:.*]]: tensor -func.func @maxpool_3d_valid_padding(%arg0: tensor<2x8x12x20x7xf32>) -> tensor<2x8x3x5x7xf32> { - // CHECK: %[[INIT:.*]] = stablehlo.constant dense<0xFF800000> : tensor - // CHECK: "stablehlo.reduce_window"(%[[ARG]], %[[INIT]]) - // CHECK-SAME: <{window_dimensions = array, window_strides = array}> - // CHECK: stablehlo.maximum - // CHECK: stablehlo.return - - %0 = "tf.MaxPool3D"(%arg0) {data_format = "NDHWC", ksize = [1, 1, 2, 2, 1], padding = "VALID", strides = [1, 1, 4, 4, 1]} : (tensor<2x8x12x20x7xf32>) -> tensor<2x8x3x5x7xf32> - func.return %0 : tensor<2x8x3x5x7xf32> -} - -// ----- - -// CHECK-LABEL: maxpool_3d_same_padding -// CHECK-SAME: %[[ARG:.*]]: tensor -func.func @maxpool_3d_same_padding(%arg0: tensor<2x8x13x25x7xf32>) -> tensor<2x8x4x7x7xf32> { - // CHECK: padding = dense<{{\[\[}}0, 0], [0, 0], [0, 1], [1, 1], [0, 0]]> : tensor<5x2xi64> - - %0 = "tf.MaxPool3D"(%arg0) {data_format = "NDHWC", ksize = [1, 1, 2, 3, 1], padding = "SAME", strides = [1, 1, 4, 4, 1]} : (tensor<2x8x13x25x7xf32>) -> tensor<2x8x4x7x7xf32> - func.return %0 : tensor<2x8x4x7x7xf32> -} - -// ----- - -// CHECK-LABEL: maxpool_explicit_padding -func.func @maxpool_explicit_padding(%arg0: tensor<2x12x20x7xi32>) -> tensor<2x3x5x7xi32> { - // CHECK: tf.MaxPool - // TODO(b/165938852): need to support explicit padding in max_pool. - - %0 = "tf.MaxPool"(%arg0) {data_format = "NHWC", ksize = [1, 2, 2, 1], padding = "EXPLICIT", strides = [1, 4, 4, 1]} : (tensor<2x12x20x7xi32>) -> tensor<2x3x5x7xi32> - func.return %0 : tensor<2x3x5x7xi32> -} - -//===----------------------------------------------------------------------===// -// MaxPoolGrad op legalizations. -//===----------------------------------------------------------------------===// - -// ----- - -// CHECK-LABEL: @max_pool_grad_valid -// CHECK-SAME: %[[INPUT:.*]]: tensor<10x24x24x64xf32>, %arg1: tensor<10x12x12x64xf32>, %[[GRAD:.*]]: tensor<10x12x12x64xf32> -func.func @max_pool_grad_valid(%orig_input: tensor<10x24x24x64xf32>, %orig_output: tensor<10x12x12x64xf32>, %grad: tensor<10x12x12x64xf32>) -> tensor<10x24x24x64xf32> { - // CHECK: %[[ZERO:.*]] = stablehlo.constant dense<0.000000e+00> : tensor - // CHECK: %[[RESULT:.*]] = "stablehlo.select_and_scatter"(%[[INPUT]], %[[GRAD]], %[[ZERO]]) - // CHECK-SAME: <{window_dimensions = array, window_strides = array}> ({ - // CHECK: ^bb0(%[[VALUE_A:.*]]: tensor, %[[VALUE_B:.*]]: tensor): - // CHECK: %[[SELECT_RESULT:.*]] = stablehlo.compare GE, %[[VALUE_A]], %[[VALUE_B]], NOTYPE : (tensor, tensor) -> tensor - // CHECK: stablehlo.return %[[SELECT_RESULT]] : tensor - // CHECK: }, { - // CHECK: ^bb0(%[[VALUE_A:.*]]: tensor, %[[VALUE_B:.*]]: tensor): - // CHECK: %[[SELECT_RESULT:.*]] = stablehlo.add %[[VALUE_A]], %[[VALUE_B]] : tensor - // CHECK: stablehlo.return %[[SELECT_RESULT]] : tensor - // CHECK: }) : (tensor<10x24x24x64xf32>, tensor<10x12x12x64xf32>, tensor) -> tensor<10x24x24x64xf32> - // CHECK: return %[[RESULT]] : tensor<10x24x24x64xf32> - %result = "tf.MaxPoolGrad"(%orig_input, %orig_output, %grad) { - data_format = "NHWC", - ksize = [1, 2, 2, 1], - padding = "VALID", - strides = [1, 2, 2, 1] - } : (tensor<10x24x24x64xf32>, tensor<10x12x12x64xf32>, tensor<10x12x12x64xf32>) -> tensor<10x24x24x64xf32> - func.return %result : tensor<10x24x24x64xf32> -} - -// ----- - -// CHECK-LABEL: @max_pool_3d_grad_valid -// CHECK-SAME: %[[INPUT:.*]]: tensor<10x8x24x24x64xf32>, %arg1: tensor<10x8x12x12x64xf32>, %[[GRAD:.*]]: tensor<10x8x12x12x64xf32> -func.func @max_pool_3d_grad_valid(%orig_input: tensor<10x8x24x24x64xf32>, %orig_output: tensor<10x8x12x12x64xf32>, %grad: tensor<10x8x12x12x64xf32>) -> tensor<10x8x24x24x64xf32> { - // CHECK: %[[ZERO:.*]] = stablehlo.constant dense<0.000000e+00> : tensor - // CHECK: %[[RESULT:.*]] = "stablehlo.select_and_scatter"(%[[INPUT]], %[[GRAD]], %[[ZERO]]) - // CHECK-SAME: <{window_dimensions = array, window_strides = array}> ({ - // CHECK: ^bb0(%[[VALUE_A:.*]]: tensor, %[[VALUE_B:.*]]: tensor): - // CHECK: %[[SELECT_RESULT:.*]] = stablehlo.compare GE, %[[VALUE_A]], %[[VALUE_B]], NOTYPE : (tensor, tensor) -> tensor - // CHECK: stablehlo.return %[[SELECT_RESULT]] : tensor - // CHECK: }, { - // CHECK: ^bb0(%[[VALUE_A:.*]]: tensor, %[[VALUE_B:.*]]: tensor): - // CHECK: %[[SELECT_RESULT:.*]] = stablehlo.add %[[VALUE_A]], %[[VALUE_B]] : tensor - // CHECK: stablehlo.return %[[SELECT_RESULT]] : tensor - // CHECK: }) : (tensor<10x8x24x24x64xf32>, tensor<10x8x12x12x64xf32>, tensor) -> tensor<10x8x24x24x64xf32> - // CHECK: return %[[RESULT]] : tensor<10x8x24x24x64xf32> - %result = "tf.MaxPool3DGrad"(%orig_input, %orig_output, %grad) {data_format = "NDHWC", ksize = [1, 1, 2, 2, 1], padding = "VALID", strides = [1, 1, 2, 2, 1]} : (tensor<10x8x24x24x64xf32>, tensor<10x8x12x12x64xf32>, tensor<10x8x12x12x64xf32>) -> tensor<10x8x24x24x64xf32> - func.return %result : tensor<10x8x24x24x64xf32> -} - -// ----- - -// CHECK-LABEL: @max_pool_grad_same -func.func @max_pool_grad_same(%orig_input: tensor<2x13x25x7xf32>, %orig_output: tensor<2x4x7x7xf32>, %grad: tensor<2x4x7x7xf32>) -> tensor<2x13x25x7xf32> { - // CHECK: padding = dense<{{\[\[}}0, 0], [0, 1], [1, 1], [0, 0]]> : tensor<4x2xi64> - %result = "tf.MaxPoolGrad"(%orig_input, %orig_output, %grad) { - data_format = "NHWC", - ksize = [1, 2, 3, 1], - padding = "SAME", - strides = [1, 4, 4, 1] - } : (tensor<2x13x25x7xf32>, tensor<2x4x7x7xf32>, tensor<2x4x7x7xf32>) -> tensor<2x13x25x7xf32> - func.return %result : tensor<2x13x25x7xf32> -} - -// ----- - -// CHECK-LABEL: @max_pool_3d_grad_same -func.func @max_pool_3d_grad_same(%orig_input: tensor<2x8x13x25x7xf32>, %orig_output: tensor<2x8x4x7x7xf32>, %grad: tensor<2x8x4x7x7xf32>) -> tensor<2x8x13x25x7xf32> { - // CHECK: padding = dense<{{\[\[}}0, 0], [0, 0], [0, 1], [1, 1], [0, 0]]> : tensor<5x2xi64> - %result = "tf.MaxPool3DGrad"(%orig_input, %orig_output, %grad) {data_format = "NDHWC", ksize = [1, 1, 2, 3, 1], padding = "SAME", strides = [1, 1, 4, 4, 1]} : (tensor<2x8x13x25x7xf32>, tensor<2x8x4x7x7xf32>, tensor<2x8x4x7x7xf32>) -> tensor<2x8x13x25x7xf32> - func.return %result : tensor<2x8x13x25x7xf32> -} - -//===----------------------------------------------------------------------===// -// OneHot op legalizations. -//===----------------------------------------------------------------------===// - -// ----- - -// CHECK-LABEL:one_hot -func.func @one_hot(%indices: tensor<3xi32>, %on_value: tensor, %off_value: tensor) -> tensor<3x5xf32> { - // CHECK-NEXT: %[[IOTA0:.*]] = stablehlo.iota dim = 0 : tensor<5xi32> - // CHECK-NEXT: %[[BCAST0:.*]] = stablehlo.broadcast_in_dim %[[IOTA0]], dims = [1] : (tensor<5xi32>) -> tensor<3x5xi32> - // CHECK-NEXT: %[[BCAST1:.*]] = stablehlo.broadcast_in_dim %arg0, dims = [0] : (tensor<3xi32>) -> tensor<3x5xi32> - // CHECK-NEXT: %[[CMP0:.*]] = stablehlo.compare EQ, %[[BCAST1]], %[[BCAST0]], NOTYPE : (tensor<3x5xi32>, tensor<3x5xi32>) -> tensor<3x5xi1> - // CHECK-NEXT: %[[BCAST2:.*]] = stablehlo.broadcast %arg1, sizes = [3, 5] : (tensor) -> tensor<3x5xf32> - // CHECK-NEXT: %[[BCAST3:.*]] = stablehlo.broadcast %arg2, sizes = [3, 5] : (tensor) -> tensor<3x5xf32> - // CHECK-NEXT: %[[RESULT:.*]] = stablehlo.select %[[CMP0]], %[[BCAST2]], %[[BCAST3]] : tensor<3x5xi1>, tensor<3x5xf32> - // CHECK-NEXT: return %[[RESULT]] : tensor<3x5xf32> - %depth = "tf.Const"() { value = dense<5> : tensor } : () -> tensor - %result = "tf.OneHot"(%indices, %depth, %on_value, %off_value) {axis = -1 : i64} : (tensor<3xi32>, tensor, tensor, tensor) -> tensor<3x5xf32> - func.return %result : tensor<3x5xf32> -} - -//===----------------------------------------------------------------------===// -// tf.OutfeedEnqueueTuple legalization -//===----------------------------------------------------------------------===// - -// ----- - -// CHECK-LABEL: func @outfeed_enqueue_tuple -// CHECK-SAME: [[VAL_0:%.*]]: tensor<3xi32>, [[VAL_1:%.*]]: tensor<4xf32>) -func.func @outfeed_enqueue_tuple(%data_1: tensor<3xi32>, %data_2: tensor<4xf32>) -> () { - // CHECK: [[TOKEN:%.*]] = stablehlo.create_token : !stablehlo.token - // CHECK: "stablehlo.outfeed"([[VAL_0]], [[VAL_1]], [[TOKEN]]) <{outfeed_config = ""}> : (tensor<3xi32>, tensor<4xf32>, !stablehlo.token) -> !stablehlo.token - "tf.OutfeedEnqueueTuple"(%data_1, %data_2) : (tensor<3xi32>, tensor<4xf32>) -> () - func.return -} - -//===----------------------------------------------------------------------===// -// Pack op legalizations. -//===----------------------------------------------------------------------===// - -// ----- - -// CHECK-LABEL: func @pack -func.func @pack(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2x2xi32> { - // CHECK: stablehlo.reshape {{.*}} : (tensor<2xi32>) -> tensor<1x2xi32> - // CHECK: stablehlo.reshape {{.*}} : (tensor<2xi32>) -> tensor<1x2xi32> - // CHECK: stablehlo.concatenate {{.*}}, {{.*}}, dim = 0 - - %0 = "tf.Pack"(%arg0, %arg1) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2x2xi32> - func.return %0 : tensor<2x2xi32> -} - -//===----------------------------------------------------------------------===// -// PartitionedCall op legalization. -//===----------------------------------------------------------------------===// - -// ----- - -// CHECK-LABEL: func @partitioned_call -func.func @partitioned_call(%arg0: tensor) -> tensor { - // CHECK: call @pcall_func(%arg0) : (tensor) -> tensor - %0 = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @pcall_func} : (tensor) -> (tensor) - func.return %0 : tensor -} - - -func.func @pcall_func(%arg0: tensor) -> tensor { - func.return %arg0 : tensor -} - -// ----- - -// CHECK-LABEL: func @partitioned_call_multi_input -func.func @partitioned_call_multi_input(%arg0: tensor, %arg1: tensor) -> tensor { - // CHECK: call @pcall_multi_input(%arg0, %arg1) : (tensor, tensor) -> tensor - %0 = "tf.PartitionedCall"(%arg0, %arg1) {config = "", config_proto = "", executor_type = "", f = @pcall_multi_input} : (tensor, tensor) -> (tensor) - func.return %0 : tensor -} - - -func.func @pcall_multi_input(%arg0: tensor, %arg1: tensor) -> tensor { - func.return %arg0 : tensor -} - -// ----- - -// CHECK-LABEL: func @partitioned_call_multi_in_out -func.func @partitioned_call_multi_in_out(%arg0: tensor, %arg1: tensor) -> (tensor, tensor) { - // CHECK: call @pcall_multi_in_out(%arg0, %arg1) : (tensor, tensor) -> (tensor, tensor) - %0, %1 = "tf.PartitionedCall"(%arg0, %arg1) {config = "", config_proto = "", executor_type = "", f = @pcall_multi_in_out} : (tensor, tensor) -> (tensor, tensor) - func.return %0, %1 : tensor, tensor -} - - -func.func @pcall_multi_in_out(%arg0: tensor, %arg1: tensor) -> (tensor, tensor) { - func.return %arg1, %arg0 : tensor, tensor -} - -// CHECK-LABEL: func @unhandled_partitioned_call -func.func @unhandled_partitioned_call(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> (tensor, tensor) { - // The argument types don't match the parameter types for the - // pcall_multi_in_out function. That's fine for a PartitionedCallOp but not - // for a standard CallOp, so this op can't be lowered. - // CHECK: "tf.PartitionedCall" - %0, %1 = "tf.PartitionedCall"(%arg0, %arg1) {config = "", config_proto = "", executor_type = "", f = @pcall_multi_in_out} : (tensor<*xi32>, tensor<*xi32>) -> (tensor, tensor) - func.return %0, %1 : tensor, tensor -} - - -// CHECK-LABEL: func @unhandled_partitioned_call_2 -func.func @unhandled_partitioned_call_2(%arg0: tensor, %arg1: tensor<*xi32>) -> (tensor, tensor) { - // CHECK: "tf.PartitionedCall" - %0, %1 = "tf.PartitionedCall"(%arg0, %arg1) {config = "", config_proto = "", executor_type = "", f = @pcall_multi_in_out} : (tensor, tensor<*xi32>) -> (tensor, tensor) - func.return %0, %1 : tensor, tensor -} - -// ----- - -// CHECK-LABEL: func @no_args_and_results -func.func @no_args_and_results() { - // CHECK: call @callee() : () -> () - // CHECK: call @callee() : () -> () - // CHECK: call @callee() : () -> () - "tf.PartitionedCall"() {config = "", config_proto = "", executor_type = "", f = @callee} : () -> () - "tf.StatefulPartitionedCall"() {config = "", config_proto = "", executor_type = "", f = @callee} : () -> () - "tf.LegacyCall"() {config = "", config_proto = "", executor_type = "", f = @callee} : () -> () - func.return -} - -func.func @callee() { - func.return -} - -//===----------------------------------------------------------------------===// -// ReverseV2 op legalization. -//===----------------------------------------------------------------------===// - -// ----- - -// CHECK-LABEL: @reverse_func_32 -func.func @reverse_func_32(%arg0: tensor<5xi32>) -> tensor<5xi32> { - %axis = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> (tensor<1xi32>) - - // CHECK: [[VAL:%.+]] = stablehlo.reverse %arg0, dims = [0] : tensor<5xi32> - %reversed = "tf.ReverseV2"(%arg0, %axis) : (tensor<5xi32>, tensor<1xi32>) -> tensor<5xi32> - - // CHECK: return [[VAL]] : tensor<5xi32> - func.return %reversed : tensor<5xi32> -} - -// ----- - -// CHECK-LABEL: @reverse_func_64 -func.func @reverse_func_64(%arg0: tensor<5xi32>) -> tensor<5xi32> { - %axis = "tf.Const"() {value = dense<0> : tensor<1xi64>} : () -> (tensor<1xi64>) - - // CHECK: [[VAL:%.+]] = stablehlo.reverse %arg0, dims = [0] : tensor<5xi32> - %reversed = "tf.ReverseV2"(%arg0, %axis) : (tensor<5xi32>, tensor<1xi64>) -> tensor<5xi32> - - // CHECK: return [[VAL]] : tensor<5xi32> - func.return %reversed : tensor<5xi32> -} - -// ----- - -// CHECK-LABEL: @reverse_func_neg -func.func @reverse_func_neg(%arg0: tensor<5x5xi32>) -> tensor<5x5xi32> { - %axis = "tf.Const"() {value = dense<[-1]> : tensor<1xi32>} : () -> (tensor<1xi32>) - - // CHECK: [[VAL:%.+]] = stablehlo.reverse %arg0, dims = [1] : tensor<5x5xi32> - %reversed = "tf.ReverseV2"(%arg0, %axis) : (tensor<5x5xi32>, tensor<1xi32>) -> tensor<5x5xi32> - - // CHECK: return [[VAL]] : tensor<5x5xi32> - func.return %reversed : tensor<5x5xi32> -} - -//===----------------------------------------------------------------------===// -// StatefulPartitionedCall op legalization. -//===----------------------------------------------------------------------===// - -// ----- - -// CHECK-LABEL: func @stateful_partitioned_call -// CHECK-SAME: [[ARG:%.+]]: tensor -func.func @stateful_partitioned_call(%arg0: tensor) -> tensor { - // CHECK: call @stateful_pcall_func([[ARG]]) : (tensor) -> tensor - %0 = "tf.StatefulPartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @stateful_pcall_func} : (tensor) -> (tensor) - func.return %0 : tensor -} - -func.func @stateful_pcall_func(%arg0: tensor) -> tensor { - func.return %arg0 : tensor -} - -// ----- - -// CHECK-LABEL: func @stateful_partitioned_call_multi_in_out -// CHECK-SAME: ([[ARG0:%.+]]: tensor, [[ARG1:%.+]]: tensor) -func.func @stateful_partitioned_call_multi_in_out(%arg0: tensor, %arg1: tensor) -> (tensor, tensor) { - // CHECK: call @stateful_pcall_multi_in_out([[ARG0]], [[ARG1]]) : (tensor, tensor) -> (tensor, tensor) - %0, %1 = "tf.StatefulPartitionedCall"(%arg0, %arg1) {config = "", config_proto = "", executor_type = "", f = @stateful_pcall_multi_in_out} : (tensor, tensor) -> (tensor, tensor) - func.return %0, %1 : tensor, tensor -} - -func.func @stateful_pcall_multi_in_out(%arg0: tensor, %arg1: tensor) -> (tensor, tensor) { - func.return %arg1, %arg0 : tensor, tensor -} - -//===----------------------------------------------------------------------===// -// Elu op legalizations. -//===----------------------------------------------------------------------===// - -// ----- - -// CHECK-LABEL: func @elu -func.func @elu(%arg0: tensor<1xf32>) -> tensor<1xf32> { - // CHECK-DAG: %[[ZERO:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<1xf32> - // CHECK-DAG: %[[PRED:.*]] = stablehlo.compare GT, %arg0, %[[ZERO]] - // CHECK-DAG: %[[EXP:.*]] = stablehlo.exponential_minus_one %arg0 - // CHECK: %[[RESULT:.*]] = stablehlo.select %[[PRED]], %arg0, %[[EXP]] - // CHECK: return %[[RESULT]] - %0 = "tf.Elu"(%arg0) : (tensor<1xf32>) -> tensor<1xf32> - func.return %0: tensor<1xf32> -} - -// ----- - -// CHECK-LABEL: func @elu_grad -// CHECK-SAME: (%[[GRADIENTS:.*]]: tensor<4x8xf32>, %[[FEATURES:.*]]: tensor) -func.func @elu_grad(%gradients: tensor<4x8xf32>, %features: tensor) -> tensor<4x8xf32> { - // CHECK-DAG: %[[ZERO:.*]] = stablehlo.constant dense<0.000000e+00> : tensor - // CHECK-DAG: %[[ONE:.*]] = stablehlo.constant dense<1.000000e+00> : tensor - // CHECK-DAG: %[[BCAST0:.*]] = stablehlo.dynamic_broadcast_in_dim %[[ZERO]], {{.*}}, dims = [] : (tensor, tensor<2xindex>) -> tensor - // CHECK-DAG: %[[PRED:.*]] = stablehlo.compare GT, %[[FEATURES]], %[[BCAST0]] - // CHECK-DAG: %[[BCAST1:.*]] = stablehlo.dynamic_broadcast_in_dim %[[ONE]], {{.*}}, dims = [] : (tensor, tensor<2xindex>) -> tensor - // CHECK-DAG: %[[ADD1:.*]] = stablehlo.add %[[FEATURES]], %[[BCAST1]] - // CHECK-DAG: %[[MULGRAD:.*]] = stablehlo.multiply %[[GRADIENTS]], %[[ADD1]] : (tensor<4x8xf32>, tensor) -> tensor<4x8xf32> - // CHECK: %[[RESULT:.*]] = stablehlo.select %[[PRED]], %[[GRADIENTS]], %[[MULGRAD]] - // CHECK: return %[[RESULT]] - %2 = "tf.EluGrad"(%gradients, %features) : (tensor<4x8xf32>, tensor) -> tensor<4x8xf32> - func.return %2 : tensor<4x8xf32> -} - -//===----------------------------------------------------------------------===// -// Relu op legalizations. -//===----------------------------------------------------------------------===// - -// ----- - -// CHECK-LABEL: func @relu -func.func @relu(%arg0: tensor<1xi32>) -> tensor<1xi32> { - // CHECK: %[[ZERO:.*]] = stablehlo.constant dense<0> : tensor<1xi32> - // CHECK: stablehlo.maximum %arg0, %[[ZERO]] - %0 = "tf.Relu"(%arg0) : (tensor<1xi32>) -> tensor<1xi32> - func.return %0: tensor<1xi32> -} - -// ----- - -// CHECK-LABEL: func @relu_unsigned -func.func @relu_unsigned(%arg0: tensor) -> tensor { - // CHECK: %[[ZERO:.*]] = stablehlo.constant dense<0> : tensor - // CHECK: %[[BCAST0:.*]] = stablehlo.dynamic_broadcast_in_dim %[[ZERO]], {{.*}}, dims = [] - // CHECK: stablehlo.maximum %arg0, %[[BCAST0]] - %0 = "tf.Relu"(%arg0) : (tensor) -> tensor - func.return %0: tensor -} - -// ----- - -// CHECK-LABEL: func @relu6 -func.func @relu6(%arg0: tensor<1xi32>) -> tensor<1xi32> { - // CHECK-DAG: %[[ZERO:.*]] = stablehlo.constant dense<0> : tensor - // CHECK-DAG: %[[SIX:.*]] = stablehlo.constant dense<6> : tensor - // CHECK: stablehlo.clamp %[[ZERO]], %arg0, %[[SIX]] - %0 = "tf.Relu6"(%arg0) : (tensor<1xi32>) -> tensor<1xi32> - func.return %0: tensor<1xi32> -} - -// ----- - -// CHECK-LABEL: func @relu6_unsigned -func.func @relu6_unsigned(%arg0: tensor) -> tensor { - // CHECK-DAG: %[[ZERO:.*]] = stablehlo.constant dense<0> : tensor - // CHECK-DAG: %[[SIX:.*]] = stablehlo.constant dense<6> : tensor - // CHECK: stablehlo.clamp %[[ZERO]], %arg0, %[[SIX]] - %0 = "tf.Relu6"(%arg0) : (tensor) -> tensor - func.return %0: tensor -} - -// ----- - -// CHECK-LABEL: func @leaky_relu -func.func @leaky_relu(%arg0: tensor<1x4x4x3xf32>) -> tensor<1x4x4x3xf32> attributes {tf.entry_function = {}} { - // CHECK-NEXT: %[[ALPHA:.*]] = stablehlo.constant dense<2.000000e-01> - // CHECK-NEXT: %[[ZERO:.*]] = stablehlo.constant dense<0.000000e+00> - // CHECK-NEXT: %[[LEAKY:.*]] = stablehlo.multiply %arg0, %[[ALPHA]] - // CHECK-NEXT: %[[CMP:.*]] = stablehlo.compare GT, %arg0, %[[ZERO]] - // CHECK-NEXT: %[[RES:.*]] = stablehlo.select %[[CMP]], %arg0, %[[LEAKY]] - // CHECK-NEXT: return %[[RES]] : tensor<1x4x4x3xf32> - %0 = "tf.LeakyRelu"(%arg0) {alpha = 2.000000e-01 : f32, device = ""} : (tensor<1x4x4x3xf32>) -> tensor<1x4x4x3xf32> - func.return %0 : tensor<1x4x4x3xf32> -} - -// ----- - -// CHECK-LABEL: func @leaky_relu_grad -func.func @leaky_relu_grad(%arg0: tensor<1x4x4xf32>, %arg1: tensor<1x4x4xf32>) -> tensor<1x4x4xf32> attributes {tf.entry_function = {}} { - // CHECK-NEXT: %[[ALPHA:.*]] = stablehlo.constant dense<2.000000e-01> - // CHECK-NEXT: %[[ZERO:.*]] = stablehlo.constant dense<0.000000e+00> - // CHECK-NEXT: %[[LEAKYGRAD:.*]] = stablehlo.multiply %[[GRADIENT:.*]], %[[ALPHA]] - // CHECK-NEXT: %[[CMP:.*]] = stablehlo.compare GT, %[[INP:.*]], %[[ZERO]], NOTYPE - // CHECK-NEXT: %[[RES:.*]] = stablehlo.select %[[CMP]], %[[GRADIENT]], %[[LEAKYGRAD]] - // CHECK-NEXT: return %[[RES]] : tensor<1x4x4xf32> - %0 = "tf.LeakyReluGrad"(%arg0, %arg1) {alpha = 2.000000e-01 : f32, device = ""} : (tensor<1x4x4xf32>, tensor<1x4x4xf32>) -> tensor<1x4x4xf32> - func.return %0 : tensor<1x4x4xf32> -} - -// ----- - -// CHECK-LABEL: func @softsign -func.func @softsign(%arg0: tensor<4x10xf32>) -> tensor<4x10xf32> { - // CHECK-NEXT: %[[ONE:.*]] = stablehlo.constant dense<1.000000e+00> - // CHECK-NEXT: %[[ABS:.*]] = stablehlo.abs %{{.*}} - // CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %[[ABS]], %[[ONE]] - // CHECK-NEXT: %[[DIV:.*]] = stablehlo.divide %{{.*}}, %[[ADD]] - // CHECK-NEXT: return %[[DIV]] : tensor<4x10xf32> - %0 = "tf.Softsign"(%arg0) : (tensor<4x10xf32>) -> tensor<4x10xf32> - func.return %0 : tensor<4x10xf32> -} - -// ----- - -// CHECK-LABEL: func @softsign_grad -func.func @softsign_grad(%arg0: tensor<4x10xf32>, %arg1: tensor<4x10xf32>) -> tensor<4x10xf32> { - - // CHECK-NEXT: %[[ONE:.*]] = stablehlo.constant dense<1.000000e+00> - // CHECK-NEXT: %[[ABS:.*]] = stablehlo.abs %{{.*}} : tensor<4x10xf32> - // CHECK-NEXT: %[[BROADCAST_ADD:.*]] = stablehlo.add %[[ABS]], %[[ONE]] - // CHECK-NEXT: %[[MUL:.*]] = stablehlo.multiply %[[BROADCAST_ADD]], %[[BROADCAST_ADD]] - // CHECK-NEXT: %[[BROADCAST_DIV:.*]] = stablehlo.divide %{{.*}}, %[[MUL]] - // CHECK-NEXT: return %[[BROADCAST_DIV]] : tensor<4x10xf32> - %0 = "tf.SoftsignGrad"(%arg0, %arg1) : (tensor<4x10xf32>, tensor<4x10xf32>) -> tensor<4x10xf32> - func.return %0 : tensor<4x10xf32> -} - -//===----------------------------------------------------------------------===// -// Roll op legalizations. -//===----------------------------------------------------------------------===// - -// ----- - -// CHECK-LABEL: func @Roll_0D -func.func @Roll_0D(%arg0: tensor<512xi32>, %shift: tensor) -> tensor<512xi32> { - %axis = "tf.Const"() {value = dense<0> : tensor} : () -> (tensor) - // CHECK: %[[AXIS_SIZE:.*]] = stablehlo.constant dense<512> : tensor - // CHECK: %[[T1:.+]] = stablehlo.remainder %arg1, %[[AXIS_SIZE]] : tensor - // CHECK: %[[T2:.+]] = stablehlo.add %[[T1]], %[[AXIS_SIZE]] : tensor - // CHECK: %[[T3:.+]] = stablehlo.remainder %[[T2]], %[[AXIS_SIZE]] : tensor - // CHECK: %[[CONCAT:.+]] = stablehlo.concatenate %arg0, %arg0, dim = 0 - // CHECK: %[[OFFSET:.+]] = stablehlo.subtract %[[AXIS_SIZE]], %[[T3]] : tensor - // CHECK: stablehlo.dynamic_slice %[[CONCAT]], %[[OFFSET]], sizes = [512] - %0 = "tf.Roll"(%arg0, %shift, %axis) {device = ""} : (tensor<512xi32>, tensor, tensor) -> tensor<512xi32> - func.return %0 : tensor<512xi32> -} - -//===----------------------------------------------------------------------===// -// Select op legalizations. -//===----------------------------------------------------------------------===// - -// ----- - -// CHECK-LABEL: func @select_batch_static -func.func @select_batch_static(%arg0: tensor<2xi1>, %arg1: tensor<2x6x8xi32>, %arg2: tensor<2x6x8xi32>) -> tensor<2x6x8xi32> { - // CHECK: %[[BCAST:.*]] = stablehlo.broadcast_in_dim %arg0, dims = [0] - // CHECK: stablehlo.select %[[BCAST]], %arg1, %arg2 - %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<2xi1>, tensor<2x6x8xi32>, tensor<2x6x8xi32>) -> tensor<2x6x8xi32> - func.return %0: tensor<2x6x8xi32> -} - -// ----- - -// CHECK-LABEL: func @select_batch_static_r1 -func.func @select_batch_static_r1(%arg0: tensor, %arg1: tensor<2x6x8xi32>, %arg2: tensor<2x6x8xi32>) -> tensor<2x6x8xi32> { - // CHECK: stablehlo.select %arg0, %arg1, %arg2 - %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor, tensor<2x6x8xi32>, tensor<2x6x8xi32>) -> tensor<2x6x8xi32> - func.return %0: tensor<2x6x8xi32> -} - -// ----- - -// CHECK-LABEL: func @select_batch_static_all_same -func.func @select_batch_static_all_same(%arg0: tensor<2x6x8xi1>, %arg1: tensor<2x6x8xi32>, %arg2: tensor<2x6x8xi32>) -> tensor<2x6x8xi32> { - // CHECK: stablehlo.select %arg0, %arg1, %arg2 - %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<2x6x8xi1>, tensor<2x6x8xi32>, tensor<2x6x8xi32>) -> tensor<2x6x8xi32> - func.return %0: tensor<2x6x8xi32> -} - -// ----- - -// CHECK-LABEL: func @select_batch_dynamic_r1 -func.func @select_batch_dynamic_r1(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { - // CHECK-NEXT: %[[C1:.*]] = arith.constant 1 : index - // CHECK-NEXT: %[[SHAPE0:.*]] = shape.shape_of %arg0 : tensor -> tensor<1xindex> - // CHECK-NEXT: %[[SHAPE1:.*]] = shape.shape_of %arg1 : tensor -> tensor<3xindex> - // CHECK-NEXT: %[[SHAPE2:.*]] = shape.shape_of %arg2 : tensor -> tensor<3xindex> - // CHECK-NEXT: %[[SHAPEEQ1:.*]] = shape.cstr_eq %[[SHAPE1]], %[[SHAPE2]] : tensor<3xindex>, tensor<3xindex> - // CHECK-NEXT: %[[HEAD:.*]], %[[TAIL:.*]] = "shape.split_at"(%[[SHAPE1]], %[[C1]]) : (tensor<3xindex>, index) -> (tensor<1xindex>, tensor<2xindex>) - // CHECK-NEXT: %[[SHAPEEQ2:.*]] = shape.cstr_eq %[[SHAPE0]], %[[HEAD]] : tensor<1xindex>, tensor<1xindex> - // CHECK-NEXT: %[[SHAPEEQ:.*]] = shape.assuming_all %[[SHAPEEQ1]], %[[SHAPEEQ2]] - // CHECK-NEXT: %[[ASSUMING:.*]] = shape.assuming %[[SHAPEEQ]] -> (tensor) { - // CHECK-NEXT: %[[BCAST:.*]] = stablehlo.dynamic_broadcast_in_dim %arg0, %[[SHAPE1]], dims = [0] - // CHECK-NEXT: %[[SELECT:.*]] = stablehlo.select %[[BCAST]], %arg1, %arg2 : tensor, tensor - // CHECK-NEXT: shape.assuming_yield %[[SELECT]] : tensor - %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor, tensor, tensor) -> tensor - func.return %0: tensor -} - -// ----- - -// CHECK-LABEL: func @select_batch_dynamic -func.func @select_batch_dynamic(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { - // CHECK-NEXT: %[[SHAPE0:.*]] = shape.shape_of %arg0 : tensor -> tensor<3xindex> - // CHECK-NEXT: %[[SHAPE1:.*]] = shape.shape_of %arg1 : tensor -> tensor<3xindex> - // CHECK-NEXT: %[[SHAPE2:.*]] = shape.shape_of %arg2 : tensor -> tensor<3xindex> - // CHECK-NEXT: %[[SHAPEEQ1:.*]] = shape.cstr_eq %[[SHAPE1]], %[[SHAPE2]] : tensor<3xindex>, tensor<3xindex> - // CHECK-NEXT: %[[SHAPEEQ2:.*]] = shape.cstr_eq %[[SHAPE0]], %[[SHAPE1]] : tensor<3xindex>, tensor<3xindex> - // CHECK-NEXT: %[[SHAPEEQ3:.*]] = shape.cstr_eq %[[SHAPE1]], %[[SHAPE2]], %[[SHAPE0]], %[[SHAPE1]] : tensor<3xindex>, tensor<3xindex>, tensor<3xindex>, tensor<3xindex> - // CHECK-NEXT: %[[SHAPEEQ:.*]] = shape.assuming %[[SHAPEEQ3]] - // CHECK-NEXT: %[[SELECT:.*]] = stablehlo.select %arg0, %arg1, %arg2 : tensor, tensor - // CHECK-NEXT: shape.assuming_yield %[[SELECT]] : tensor - %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor, tensor, tensor) -> tensor - func.return %0: tensor -} - -// ----- - -// CHECK-LABEL: testSelectInvalidUnranked -func.func @testSelectInvalidUnranked(%arg0: tensor<6x7xi1>, %arg1: tensor<*xf16>, %arg2: tensor<*xf16>) -> tensor<*xf16> { - // CHECK-NEXT: tf.Select - %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<6x7xi1>, tensor<*xf16>, tensor<*xf16>) -> tensor<*xf16> - func.return %0: tensor<*xf16> -} - -// ----- - -// CHECK-LABEL: testSelectThenUnranked -func.func @testSelectThenUnranked(%arg0: tensor<3xi1>, %arg1: tensor<*xf16>, %arg2: tensor<3x2xf16>) -> tensor<*xf16> { - // CHECK-NEXT: tf.Select - %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<3xi1>, tensor<*xf16>, tensor<3x2xf16>) -> tensor<*xf16> - func.return %0: tensor<*xf16> -} - -// ----- - -// CHECK-LABEL: testSelectElseUnranked -func.func @testSelectElseUnranked(%arg0: tensor<3xi1>, %arg1: tensor<3x2xf16>, %arg2: tensor<*xf16>) -> tensor<*xf16> { - // CHECK-NEXT: tf.Select - %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<3xi1>, tensor<3x2xf16>, tensor<*xf16>) -> tensor<*xf16> - func.return %0: tensor<*xf16> -} - -// ----- - -// CHECK-LABEL: func @selectv2_dynamic_ranked -func.func @selectv2_dynamic_ranked(%arg0: tensor<1xi1>, %arg1: tensor<2x?x8xi32>, %arg2: tensor<2x8x8xi32>) -> tensor<2x?x8xi32> { - // CHECK: stablehlo.select - %0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<1xi1>, tensor<2x?x8xi32>, tensor<2x8x8xi32>) -> tensor<2x?x8xi32> - func.return %0: tensor<2x?x8xi32> -} - -//===----------------------------------------------------------------------===// -// Fast Fourier Transform op legalization. -//===----------------------------------------------------------------------===// - -// ----- - -// CHECK-LABEL: func @fft_1D -func.func @fft_1D(%arg0: tensor<8xcomplex>) -> tensor<8xcomplex> { - // CHECK: stablehlo.fft %arg0, type = FFT, length = [8] - %0 = "tf.FFT"(%arg0) : (tensor<8xcomplex>) -> tensor<8xcomplex> - func.return %0 : tensor<8xcomplex> -} - -// ----- - -// CHECK-LABEL: func @ifft_1D -func.func @ifft_1D(%arg0: tensor<8xcomplex>) -> tensor<8xcomplex> { - // CHECK: stablehlo.fft %arg0, type = IFFT, length = [8] - %0 = "tf.IFFT"(%arg0) : (tensor<8xcomplex>) -> tensor<8xcomplex> - func.return %0 : tensor<8xcomplex> -} - -// ----- - -// CHECK-LABEL: func @rfft_1D -func.func @rfft_1D(%arg0: tensor<8xf32>) -> tensor<5xcomplex> { - %fftlength = "tf.Const"() {value = dense<[8]> : tensor<1xi32>} : () -> (tensor<1xi32>) - // CHECK: stablehlo.fft %arg0, type = RFFT, length = [8] - %0 = "tf.RFFT"(%arg0, %fftlength) : (tensor<8xf32>, tensor<1xi32>) -> tensor<5xcomplex> - func.return %0 : tensor<5xcomplex> -} - -// ----- - -// CHECK-LABEL: func @rfft_1D_padded -func.func @rfft_1D_padded(%arg0: tensor<7xf32>) -> tensor<5xcomplex> { - %fftlength = "tf.Const"() {value = dense<[8]> : tensor<1xi32>} : () -> (tensor<1xi32>) - // CHECK: %[[PADDED:.*]] = stablehlo.pad %arg0, %{{.*}}, low = [0], high = [1], interior = [0] - // CHECK: stablehlo.fft %[[PADDED]], type = RFFT, length = [8] - %0 = "tf.RFFT"(%arg0, %fftlength) : (tensor<7xf32>, tensor<1xi32>) -> tensor<5xcomplex> - func.return %0 : tensor<5xcomplex> -} - -// ----- - -// CHECK-LABEL: func @rfft_1D_sliced -func.func @rfft_1D_sliced(%arg0: tensor<2x9xf32>) -> tensor<2x5xcomplex> { - %fftlength = "tf.Const"() {value = dense<[8]> : tensor<1xi32>} : () -> (tensor<1xi32>) - // CHECK: %[[SLICED:.*]] = stablehlo.slice %arg0 [0:2, 0:8] - // CHECK: stablehlo.fft %[[SLICED]], type = RFFT, length = [8] - %0 = "tf.RFFT"(%arg0, %fftlength) : (tensor<2x9xf32>, tensor<1xi32>) -> tensor<2x5xcomplex> - func.return %0 : tensor<2x5xcomplex> -} - -// ----- - -// CHECK-LABEL: func @irfft_1D -func.func @irfft_1D(%arg0: tensor<8xcomplex>) -> tensor<8xf32> { - %fftlength = "tf.Const"() {value = dense<[8]> : tensor<1xi32>} : () -> (tensor<1xi32>) - // CHECK: %[[SLICED:.*]] = stablehlo.slice %arg0 [0:5] - // CHECK: stablehlo.fft %[[SLICED]], type = IRFFT, length = [8] - %0 = "tf.IRFFT"(%arg0, %fftlength) : (tensor<8xcomplex>, tensor<1xi32>) -> tensor<8xf32> - func.return %0 : tensor<8xf32> -} - -// ----- - -// CHECK-LABEL: fft_1D_dynamic -func.func @fft_1D_dynamic(%arg0: tensor>) -> tensor<8xcomplex> { - // CHECK: "tf.FFT" - %0 = "tf.FFT"(%arg0) : (tensor>) -> tensor<8xcomplex> - func.return %0 : tensor<8xcomplex> -} - -// ----- - -// CHECK-LABEL: rfft_1D_dynamic -func.func @rfft_1D_dynamic(%arg0: tensor) -> tensor<8xcomplex> { - %fftlength = "tf.Const"() {value = dense<[8]> : tensor<1xi32>} : () -> (tensor<1xi32>) - // CHECK: "tf.RFFT" - %0 = "tf.RFFT"(%arg0, %fftlength) : (tensor, tensor<1xi32>) -> tensor<8xcomplex> - func.return %0 : tensor<8xcomplex> -} - -//===----------------------------------------------------------------------===// -// Shape op legalization. -//===----------------------------------------------------------------------===// - -// ----- - -// CHECK-LABEL: func @shape_1D -func.func @shape_1D(%arg0: tensor) -> tensor<1xi32> { - // CHECK: [[SHAPE:%.+]] = shape.shape_of %arg0 - // CHECK: [[TENSOR:%.+]] = arith.index_cast [[SHAPE]] : tensor<1xindex> to tensor<1xi32> - %0 = "tf.Shape"(%arg0) : (tensor) -> tensor<1xi32> - - // CHECK: return [[TENSOR]] - func.return %0 : tensor<1xi32> -} - -// ----- - -// CHECK-LABEL: func @shape_2D -func.func @shape_2D(%arg0: tensor) -> tensor<2xi32> { - // CHECK: [[SHAPE:%.+]] = shape.shape_of %arg0 - // CHECK: [[TENSOR:%.+]] = arith.index_cast [[SHAPE]] : tensor<2xindex> to tensor<2xi32> - %0 = "tf.Shape"(%arg0) : (tensor) -> tensor<2xi32> - - // CHECK: return [[TENSOR]] - func.return %0 : tensor<2xi32> -} - -// ----- - -// CHECK-LABEL: func @shape_rankless -func.func @shape_rankless(%arg0: tensor<*xf32>) -> tensor { - // CHECK: [[SHAPE:%.+]] = shape.shape_of %arg0 - // CHECK: [[TENSOR:%.+]] = arith.index_cast [[SHAPE]] : tensor to tensor - %0 = "tf.Shape"(%arg0) : (tensor<*xf32>) -> tensor - - // CHECK: return [[TENSOR]] - func.return %0 : tensor -} - -//===----------------------------------------------------------------------===// -// Transpose op legalization. -//===----------------------------------------------------------------------===// - -// ----- - -// CHECK-LABEL: @transpose_noop -func.func @transpose_noop(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> { - %permutation = "tf.Const"() {value = dense<[0, 1]> : tensor<2xi64>} : () -> (tensor<2xi64>) - // CHECK: return %arg0 - %0 = "tf.Transpose"(%arg0, %permutation) : (tensor<2x3xf32>, tensor<2xi64>) -> tensor<2x3xf32> - func.return %0 : tensor<2x3xf32> -} - -// ----- - -// CHECK-LABEL: @transpose_2d -func.func @transpose_2d(%arg0: tensor<2x3xf32>) -> tensor<3x2xf32> { - %permutation = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> (tensor<2xi64>) - // CHECK: stablehlo.transpose - %0 = "tf.Transpose"(%arg0, %permutation) : (tensor<2x3xf32>, tensor<2xi64>) -> tensor<3x2xf32> - func.return %0 : tensor<3x2xf32> -} - -// ----- - -// CHECK-LABEL: @transpose_3d_int32 -func.func @transpose_3d_int32(%arg0: tensor<1x2x3xf32>) -> tensor<3x2x1xf32> { - %permutation = "tf.Const"() {value = dense<[2, 1, 0]> : tensor<3xi32>} : () -> (tensor<3xi32>) - // CHECK: stablehlo.transpose - %0 = "tf.Transpose"(%arg0, %permutation) : (tensor<1x2x3xf32>, tensor<3xi32>) -> tensor<3x2x1xf32> - func.return %0 : tensor<3x2x1xf32> -} - -// ----- - -// CHECK-LABEL: @transpose_3d -func.func @transpose_3d(%arg0: tensor<1x2x3xf32>) -> tensor<3x2x1xf32> { - %permutation = "tf.Const"() {value = dense<[2, 1, 0]> : tensor<3xi64>} : () -> (tensor<3xi64>) - // CHECK: stablehlo.transpose - %0 = "tf.Transpose"(%arg0, %permutation) : (tensor<1x2x3xf32>, tensor<3xi64>) -> tensor<3x2x1xf32> - func.return %0 : tensor<3x2x1xf32> -} - -// ----- - -// CHECK-LABEL: @transpose_dynamic_2d -func.func @transpose_dynamic_2d(%arg0: tensor) -> tensor<4x?xf32> { - %permutation = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> (tensor<2xi64>) - // CHECK: stablehlo.transpose - %0 = "tf.Transpose"(%arg0, %permutation) : (tensor, tensor<2xi64>) -> tensor<4x?xf32> - func.return %0 : tensor<4x?xf32> -} - -//===----------------------------------------------------------------------===// -// Unary op legalizations. -//===----------------------------------------------------------------------===// - -// ----- - -// CHECK-LABEL: @abs -func.func @abs(%arg0: tensor<2xf32>) -> tensor<2xf32> { - // CHECK: stablehlo.abs %arg0 : tensor<2xf32> - %0 = "tf.Abs"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> - func.return %0 : tensor<2xf32> -} - -// ----- - -// CHECK-LABEL: func @abs_dynamic -func.func @abs_dynamic(%arg0: tensor) -> tensor { - // CHECK: stablehlo.abs %arg0 : tensor - %0 = "tf.Abs"(%arg0) : (tensor) -> tensor - func.return %0 : tensor -} - -// ----- - -// CHECK-LABEL: @acos -func.func @acos(%arg0: tensor<2xf32>) -> tensor<2xf32> { - // CHECK: %[[TEMP_0:.*]] = stablehlo.constant dense<1.000000e+00> : tensor<2xf32> - // CHECK: %[[TEMP_1:.*]] = stablehlo.subtract %[[TEMP_0]], %arg0 : tensor<2xf32> - // CHECK: %[[TEMP_2:.*]] = stablehlo.add %arg0, %[[TEMP_0]] : tensor<2xf32> - // CHECK: %[[TEMP_3:.*]] = stablehlo.multiply %[[TEMP_1]], %[[TEMP_2]] : tensor<2xf32> - // CHECK: %[[TEMP_4:.*]] = stablehlo.sqrt %[[TEMP_3]] : tensor<2xf32> - // CHECK: %[[TEMP_5:.*]] = stablehlo.atan2 %[[TEMP_4]], %arg0 : tensor<2xf32> - // CHECK: return %[[TEMP_5]] : tensor<2xf32> - %0 = "tf.Acos"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> - func.return %0 : tensor<2xf32> -} - -// ----- - -// CHECK-LABEL: @acos_complex -func.func @acos_complex(%arg0: tensor<2xcomplex>) -> tensor<2xcomplex> { -// CHECK-NEXT: %[[TEMP_1:.*]] = stablehlo.constant dense<4.33680869E-19> : tensor<2xf32> -// CHECK-NEXT: %[[TEMP_2:.*]] = stablehlo.constant dense<0.693147182> : tensor<2xf32> -// CHECK-NEXT: %[[TEMP_3:.*]] = stablehlo.constant dense<2.30584283E+20> : tensor<2xf32> -// CHECK-NEXT: %[[TEMP_4:.*]] = stablehlo.constant dense<2.30584274E+12> : tensor<2xf32> -// CHECK-NEXT: %[[TEMP_5:.*]] = stablehlo.constant dense<2.30584285E+30> : tensor<2xf32> -// CHECK-NEXT: %[[TEMP_6:.*]] = stablehlo.constant dense<1.41421354> : tensor<2xf32> -// CHECK-NEXT: %[[TEMP_7:.*]] = stablehlo.constant dense<2.30584287E+18> : tensor<2xf32> -// CHECK-NEXT: %[[TEMP_8:.*]] = stablehlo.constant dense<1.500000e+00> : tensor<2xf32> -// CHECK-NEXT: %[[TEMP_9:.*]] = stablehlo.constant dense<0x7F800000> : tensor<2xf32> -// CHECK-NEXT: %[[TEMP_10:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<2xf32> -// CHECK-NEXT: %[[TEMP_11:.*]] = stablehlo.constant dense<2.000000e+00> : tensor<2xf32> -// CHECK-NEXT: %[[TEMP_12:.*]] = stablehlo.constant dense<5.000000e-01> : tensor<2xf32> -// CHECK-NEXT: %[[TEMP_13:.*]] = stablehlo.constant dense<1.000000e+00> : tensor<2xf32> -// CHECK-NEXT: %[[TEMP_14:.*]] = stablehlo.real %arg0 : (tensor<2xcomplex>) -> tensor<2xf32> -// CHECK-NEXT: %[[TEMP_15:.*]] = stablehlo.abs %[[TEMP_14]] : tensor<2xf32> -// CHECK-NEXT: %[[TEMP_16:.*]] = stablehlo.imag %arg0 : (tensor<2xcomplex>) -> tensor<2xf32> -// CHECK-NEXT: %[[TEMP_17:.*]] = stablehlo.abs %[[TEMP_16]] : tensor<2xf32> -// CHECK-NEXT: %[[TEMP_18:.*]] = stablehlo.maximum %[[TEMP_15]], %[[TEMP_17]] : tensor<2xf32> -// CHECK-NEXT: %[[TEMP_19:.*]] = stablehlo.compare GE, %[[TEMP_18]], %[[TEMP_7]] : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1> -// CHECK-NEXT: %[[TEMP_20:.*]] = stablehlo.compare LE, %[[TEMP_15]], %[[TEMP_13]] : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1> -// CHECK-NEXT: %[[TEMP_21:.*]] = stablehlo.add %[[TEMP_15]], %[[TEMP_13]] : tensor<2xf32> -// CHECK-NEXT: %[[TEMP_22:.*]] = stablehlo.abs %[[TEMP_21]] : tensor<2xf32> -// CHECK-NEXT: %[[TEMP_23:.*]] = stablehlo.maximum %[[TEMP_22]], %[[TEMP_17]] : tensor<2xf32> -// CHECK-NEXT: %[[TEMP_24:.*]] = stablehlo.minimum %[[TEMP_22]], %[[TEMP_17]] : tensor<2xf32> -// CHECK-NEXT: %[[TEMP_25:.*]] = stablehlo.compare EQ, %[[TEMP_23]], %[[TEMP_24]] : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1> -// CHECK-NEXT: %[[TEMP_26:.*]] = stablehlo.multiply %[[TEMP_23]], %[[TEMP_6]] : tensor<2xf32> -// CHECK-NEXT: %[[TEMP_27:.*]] = stablehlo.divide %[[TEMP_24]], %[[TEMP_23]] : tensor<2xf32> -// CHECK-NEXT: %[[TEMP_28:.*]] = stablehlo.multiply %[[TEMP_27]], %[[TEMP_27]] : tensor<2xf32> -// CHECK-NEXT: %[[TEMP_29:.*]] = stablehlo.add %[[TEMP_28]], %[[TEMP_13]] : tensor<2xf32> -// CHECK-NEXT: %[[TEMP_30:.*]] = stablehlo.sqrt %[[TEMP_29]] : tensor<2xf32> -// CHECK-NEXT: %[[TEMP_31:.*]] = stablehlo.compare EQ, %[[TEMP_30]], %[[TEMP_13]] : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1> -// CHECK-NEXT: %[[TEMP_32:.*]] = stablehlo.compare GT, %[[TEMP_28]], %[[TEMP_10]] : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1> -// CHECK-NEXT: %[[TEMP_33:.*]] = stablehlo.and %[[TEMP_31]], %[[TEMP_32]] : tensor<2xi1> -// CHECK-NEXT: %[[TEMP_34:.*]] = stablehlo.multiply %[[TEMP_23]], %[[TEMP_28]] : tensor<2xf32> -// CHECK-NEXT: %[[TEMP_35:.*]] = stablehlo.divide %[[TEMP_34]], %[[TEMP_11]] : tensor<2xf32> -// CHECK-NEXT: %[[TEMP_36:.*]] = stablehlo.add %[[TEMP_23]], %[[TEMP_35]] : tensor<2xf32> -// CHECK-NEXT: %[[TEMP_37:.*]] = stablehlo.multiply %[[TEMP_23]], %[[TEMP_30]] : tensor<2xf32> -// CHECK-NEXT: %[[TEMP_38:.*]] = stablehlo.select %[[TEMP_33]], %[[TEMP_36]], %[[TEMP_37]] : tensor<2xi1>, tensor<2xf32> -// CHECK-NEXT: %[[TEMP_39:.*]] = stablehlo.select %[[TEMP_25]], %[[TEMP_26]], %[[TEMP_38]] : tensor<2xi1>, tensor<2xf32> -// CHECK-NEXT: %[[TEMP_40:.*]] = stablehlo.subtract %[[TEMP_15]], %[[TEMP_13]] : tensor<2xf32> -// CHECK-NEXT: %[[TEMP_41:.*]] = stablehlo.abs %[[TEMP_40]] : tensor<2xf32> -// CHECK-NEXT: %[[TEMP_42:.*]] = stablehlo.maximum %[[TEMP_41]], %[[TEMP_17]] : tensor<2xf32> -// CHECK-NEXT: %[[TEMP_43:.*]] = stablehlo.minimum %[[TEMP_41]], %[[TEMP_17]] : tensor<2xf32> -// CHECK-NEXT: %[[TEMP_44:.*]] = stablehlo.compare EQ, %[[TEMP_42]], %[[TEMP_43]] : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1> -// CHECK-NEXT: %[[TEMP_45:.*]] = stablehlo.multiply %[[TEMP_42]], %[[TEMP_6]] : tensor<2xf32> -// CHECK-NEXT: %[[TEMP_46:.*]] = stablehlo.divide %[[TEMP_43]], %[[TEMP_42]] : tensor<2xf32> -// CHECK-NEXT: %[[TEMP_47:.*]] = stablehlo.multiply %[[TEMP_46]], %[[TEMP_46]] : tensor<2xf32> -// CHECK-NEXT: %[[TEMP_48:.*]] = stablehlo.add %[[TEMP_47]], %[[TEMP_13]] : tensor<2xf32> -// CHECK-NEXT: %[[TEMP_49:.*]] = stablehlo.sqrt %[[TEMP_48]] : tensor<2xf32> -// CHECK-NEXT: %[[TEMP_50:.*]] = stablehlo.compare EQ, %[[TEMP_49]], %[[TEMP_13]] : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1> -// CHECK-NEXT: %[[TEMP_51:.*]] = stablehlo.compare GT, %[[TEMP_47]], %[[TEMP_10]] : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1> -// CHECK-NEXT: %[[TEMP_52:.*]] = stablehlo.and %[[TEMP_50]], %[[TEMP_51]] : tensor<2xi1> -// CHECK-NEXT: %[[TEMP_53:.*]] = stablehlo.multiply %[[TEMP_42]], %[[TEMP_47]] : tensor<2xf32> -// CHECK-NEXT: %[[TEMP_54:.*]] = stablehlo.divide %[[TEMP_53]], %[[TEMP_11]] : tensor<2xf32> -// CHECK-NEXT: %[[TEMP_55:.*]] = stablehlo.add %[[TEMP_42]], %[[TEMP_54]] : tensor<2xf32> -// CHECK-NEXT: %[[TEMP_56:.*]] = stablehlo.multiply %[[TEMP_42]], %[[TEMP_49]] : tensor<2xf32> -// CHECK-NEXT: %[[TEMP_57:.*]] = stablehlo.select %[[TEMP_52]], %[[TEMP_55]], %[[TEMP_56]] : tensor<2xi1>, tensor<2xf32> -// CHECK-NEXT: %[[TEMP_58:.*]] = stablehlo.select %[[TEMP_44]], %[[TEMP_45]], %[[TEMP_57]] : tensor<2xi1>, tensor<2xf32> -// CHECK-NEXT: %[[TEMP_59:.*]] = stablehlo.add %[[TEMP_39]], %[[TEMP_58]] : tensor<2xf32> -// CHECK-NEXT: %[[TEMP_60:.*]] = stablehlo.multiply %[[TEMP_59]], %[[TEMP_12]] : tensor<2xf32> -// CHECK-NEXT: %[[TEMP_61:.*]] = stablehlo.add %[[TEMP_60]], %[[TEMP_15]] : tensor<2xf32> -// CHECK-NEXT: %[[TEMP_62:.*]] = stablehlo.multiply %[[TEMP_61]], %[[TEMP_12]] : tensor<2xf32> -// CHECK-NEXT: %[[TEMP_63:.*]] = stablehlo.multiply %[[TEMP_17]], %[[TEMP_17]] : tensor<2xf32> -// CHECK-NEXT: %[[TEMP_64:.*]] = stablehlo.add %[[TEMP_39]], %[[TEMP_21]] : tensor<2xf32> -// CHECK-NEXT: %[[TEMP_65:.*]] = stablehlo.divide %[[TEMP_63]], %[[TEMP_64]] : tensor<2xf32> -// CHECK-NEXT: %[[TEMP_66:.*]] = stablehlo.subtract %[[TEMP_58]], %[[TEMP_40]] : tensor<2xf32> -// CHECK-NEXT: %[[TEMP_67:.*]] = stablehlo.add %[[TEMP_65]], %[[TEMP_66]] : tensor<2xf32> -// CHECK-NEXT: %[[TEMP_68:.*]] = stablehlo.multiply %[[TEMP_62]], %[[TEMP_67]] : tensor<2xf32> -// CHECK-NEXT: %[[TEMP_69:.*]] = stablehlo.sqrt %[[TEMP_68]] : tensor<2xf32> -// CHECK-NEXT: %[[TEMP_70:.*]] = stablehlo.divide %[[TEMP_62]], %[[TEMP_64]] : tensor<2xf32> -// CHECK-NEXT: %[[TEMP_71:.*]] = stablehlo.add %[[TEMP_58]], %[[TEMP_40]] : tensor<2xf32> -// CHECK-NEXT: %[[TEMP_72:.*]] = stablehlo.divide %[[TEMP_62]], %[[TEMP_71]] : tensor<2xf32> -// CHECK-NEXT: %[[TEMP_73:.*]] = stablehlo.add %[[TEMP_70]], %[[TEMP_72]] : tensor<2xf32> -// CHECK-NEXT: %[[TEMP_74:.*]] = stablehlo.sqrt %[[TEMP_73]] : tensor<2xf32> -// CHECK-NEXT: %[[TEMP_75:.*]] = stablehlo.multiply %[[TEMP_17]], %[[TEMP_74]] : tensor<2xf32> -// CHECK-NEXT: %[[TEMP_76:.*]] = stablehlo.select %[[TEMP_20]], %[[TEMP_69]], %[[TEMP_75]] : tensor<2xi1>, tensor<2xf32> -// CHECK-NEXT: %[[TEMP_77:.*]] = stablehlo.select %[[TEMP_19]], %[[TEMP_17]], %[[TEMP_76]] : tensor<2xi1>, tensor<2xf32> -// CHECK-NEXT: %[[TEMP_78:.*]] = stablehlo.compare LT, %[[TEMP_15]], %[[TEMP_5]] : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1> -// CHECK-NEXT: %[[TEMP_79:.*]] = stablehlo.select %[[TEMP_78]], %[[TEMP_4]], %[[TEMP_3]] : tensor<2xi1>, tensor<2xf32> -// CHECK-NEXT: %[[TEMP_80:.*]] = stablehlo.compare GE, %[[TEMP_17]], %[[TEMP_79]] : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1> -// CHECK-NEXT: %[[TEMP_81:.*]] = stablehlo.select %[[TEMP_80]], %[[TEMP_17]], %[[TEMP_15]] : tensor<2xi1>, tensor<2xf32> -// CHECK-NEXT: %[[TEMP_82:.*]] = stablehlo.select %[[TEMP_80]], %[[TEMP_79]], %[[TEMP_7]] : tensor<2xi1>, tensor<2xf32> -// CHECK-NEXT: %[[TEMP_83:.*]] = stablehlo.compare GE, %[[TEMP_81]], %[[TEMP_82]] : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1> -// CHECK-NEXT: %[[TEMP_84:.*]] = stablehlo.log %[[TEMP_81]] : tensor<2xf32> -// CHECK-NEXT: %[[TEMP_85:.*]] = stablehlo.add %[[TEMP_84]], %[[TEMP_2]] : tensor<2xf32> -// CHECK-NEXT: %[[TEMP_86:.*]] = stablehlo.compare EQ, %[[TEMP_17]], %[[TEMP_9]] : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1> -// CHECK-NEXT: %[[TEMP_87:.*]] = stablehlo.not %[[TEMP_86]] : tensor<2xi1> -// CHECK-NEXT: %[[TEMP_88:.*]] = stablehlo.and %[[TEMP_80]], %[[TEMP_87]] : tensor<2xi1> -// CHECK-NEXT: %[[TEMP_89:.*]] = stablehlo.divide %[[TEMP_15]], %[[TEMP_17]] : tensor<2xf32> -// CHECK-NEXT: %[[TEMP_90:.*]] = stablehlo.select %[[TEMP_88]], %[[TEMP_89]], %[[TEMP_10]] : tensor<2xi1>, tensor<2xf32> -// CHECK-NEXT: %[[TEMP_91:.*]] = stablehlo.multiply %[[TEMP_90]], %[[TEMP_90]] : tensor<2xf32> -// CHECK-NEXT: %[[TEMP_92:.*]] = stablehlo.log_plus_one %[[TEMP_91]] : tensor<2xf32> -// CHECK-NEXT: %[[TEMP_93:.*]] = stablehlo.multiply %[[TEMP_92]], %[[TEMP_12]] : tensor<2xf32> -// CHECK-NEXT: %[[TEMP_94:.*]] = stablehlo.add %[[TEMP_85]], %[[TEMP_93]] : tensor<2xf32> -// CHECK-NEXT: %[[TEMP_95:.*]] = stablehlo.compare LT, %[[TEMP_17]], %[[TEMP_1]] : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1> -// CHECK-NEXT: %[[TEMP_96:.*]] = stablehlo.compare LT, %[[TEMP_15]], %[[TEMP_13]] : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1> -// CHECK-NEXT: %[[TEMP_97:.*]] = stablehlo.and %[[TEMP_95]], %[[TEMP_96]] : tensor<2xi1> -// CHECK-NEXT: %[[TEMP_98:.*]] = stablehlo.multiply %[[TEMP_21]], %[[TEMP_40]] : tensor<2xf32> -// CHECK-NEXT: %[[TEMP_99:.*]] = stablehlo.add %[[TEMP_60]], %[[TEMP_13]] : tensor<2xf32> -// CHECK-NEXT: %[[TEMP_100:.*]] = stablehlo.divide %[[TEMP_98]], %[[TEMP_99]] : tensor<2xf32> -// CHECK-NEXT: %[[TEMP_101:.*]] = stablehlo.negate %[[TEMP_100]] : tensor<2xf32> -// CHECK-NEXT: %[[TEMP_102:.*]] = stablehlo.compare GE, %[[TEMP_15]], %[[TEMP_13]] : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1> -// CHECK-NEXT: %[[TEMP_103:.*]] = stablehlo.multiply %[[TEMP_63]], %[[TEMP_12]] : tensor<2xf32> -// CHECK-NEXT: %[[TEMP_104:.*]] = stablehlo.divide %[[TEMP_103]], %[[TEMP_64]] : tensor<2xf32> -// CHECK-NEXT: %[[TEMP_105:.*]] = stablehlo.multiply %[[TEMP_71]], %[[TEMP_12]] : tensor<2xf32> -// CHECK-NEXT: %[[TEMP_106:.*]] = stablehlo.add %[[TEMP_104]], %[[TEMP_105]] : tensor<2xf32> -// CHECK-NEXT: %[[TEMP_107:.*]] = stablehlo.compare LE, %[[TEMP_60]], %[[TEMP_8]] : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1> -// CHECK-NEXT: %[[TEMP_108:.*]] = stablehlo.divide %[[TEMP_103]], %[[TEMP_66]] : tensor<2xf32> -// CHECK-NEXT: %[[TEMP_109:.*]] = stablehlo.add %[[TEMP_104]], %[[TEMP_108]] : tensor<2xf32> -// CHECK-NEXT: %[[TEMP_110:.*]] = stablehlo.subtract %[[TEMP_60]], %[[TEMP_13]] : tensor<2xf32> -// CHECK-NEXT: %[[TEMP_111:.*]] = stablehlo.select %[[TEMP_107]], %[[TEMP_109]], %[[TEMP_110]] : tensor<2xi1>, tensor<2xf32> -// CHECK-NEXT: %[[TEMP_112:.*]] = stablehlo.select %[[TEMP_102]], %[[TEMP_106]], %[[TEMP_111]] : tensor<2xi1>, tensor<2xf32> -// CHECK-NEXT: %[[TEMP_113:.*]] = stablehlo.select %[[TEMP_97]], %[[TEMP_101]], %[[TEMP_112]] : tensor<2xi1>, tensor<2xf32> -// CHECK-NEXT: %[[TEMP_114:.*]] = stablehlo.multiply %[[TEMP_113]], %[[TEMP_99]] : tensor<2xf32> -// CHECK-NEXT: %[[TEMP_115:.*]] = stablehlo.sqrt %[[TEMP_114]] : tensor<2xf32> -// CHECK-NEXT: %[[TEMP_116:.*]] = stablehlo.divide %[[TEMP_17]], %[[TEMP_115]] : tensor<2xf32> -// CHECK-NEXT: %[[TEMP_117:.*]] = stablehlo.add %[[TEMP_113]], %[[TEMP_115]] : tensor<2xf32> -// CHECK-NEXT: %[[TEMP_118:.*]] = stablehlo.log_plus_one %[[TEMP_117]] : tensor<2xf32> -// CHECK-NEXT: %[[TEMP_119:.*]] = stablehlo.select %[[TEMP_97]], %[[TEMP_116]], %[[TEMP_118]] : tensor<2xi1>, tensor<2xf32> -// CHECK-NEXT: %[[TEMP_120:.*]] = stablehlo.select %[[TEMP_83]], %[[TEMP_94]], %[[TEMP_119]] : tensor<2xi1>, tensor<2xf32> -// CHECK-NEXT: %[[TEMP_121:.*]] = stablehlo.real %arg0 : (tensor<2xcomplex>) -> tensor<2xf32> -// CHECK-NEXT: %[[TEMP_122:.*]] = stablehlo.atan2 %[[TEMP_77]], %[[TEMP_121]] : tensor<2xf32> -// CHECK-NEXT: %[[TEMP_123:.*]] = stablehlo.imag %arg0 : (tensor<2xcomplex>) -> tensor<2xf32> -// CHECK-NEXT: %[[TEMP_124:.*]] = stablehlo.compare LT, %[[TEMP_123]], %[[TEMP_10]] : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1> -// CHECK-NEXT: %[[TEMP_125:.*]] = stablehlo.negate %[[TEMP_120]] : tensor<2xf32> -// CHECK-NEXT: %[[TEMP_126:.*]] = stablehlo.select %[[TEMP_124]], %[[TEMP_120]], %[[TEMP_125]] : tensor<2xi1>, tensor<2xf32> -// CHECK-NEXT: %[[TEMP_127:.*]] = stablehlo.complex %[[TEMP_122]], %[[TEMP_126]] : tensor<2xcomplex> -// CHECK-NEXT: return %[[TEMP_127]] : tensor<2xcomplex> - - %0 = "tf.Acos"(%arg0) : (tensor<2xcomplex>) -> tensor<2xcomplex> - func.return %0 : tensor<2xcomplex> -} - -// ----- - -// CHECK-LABEL: @acos_dynamic -func.func @acos_dynamic(%arg0: tensor<*xf32>) -> tensor<*xf32> { - // CHECK: "tf.Acos" - %0 = "tf.Acos"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> - func.return %0 : tensor<*xf32> -} - -// ----- - -// CHECK-LABEL: func @cast_dynamic_i2f -func.func @cast_dynamic_i2f(%arg0: tensor) -> tensor { - // CHECK: stablehlo.convert %arg0 : (tensor) -> tensor - %0 = "tf.Cast"(%arg0) : (tensor) -> tensor - func.return %0 : tensor -} - -// ----- - -// CHECK-LABEL: func @cast_i2f -func.func @cast_i2f(%arg0: tensor<2xi32>) -> tensor<2xf32> { - // CHECK: stablehlo.convert %arg0 : (tensor<2xi32>) -> tensor<2xf32> - %0 = "tf.Cast"(%arg0) : (tensor<2xi32>) -> tensor<2xf32> - func.return %0 : tensor<2xf32> -} - -// ----- - -// CHECK-LABEL: func @cast_c2f -func.func @cast_c2f(%arg0: tensor<2xcomplex>) -> tensor<2xf32> { - // CHECK: stablehlo.convert %arg0 : (tensor<2xcomplex>) -> tensor<2xf32> - %0 = "tf.Cast"(%arg0) : (tensor<2xcomplex>) -> tensor<2xf32> - func.return %0 : tensor<2xf32> -} - -// ----- - -// CHECK-LABEL: @ceil -func.func @ceil(%arg0: tensor<2xf32>) -> tensor<2xf32> { - // CHECK: stablehlo.ceil %arg0 : tensor<2xf32> - %0 = "tf.Ceil"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> - func.return %0 : tensor<2xf32> -} - -// ----- - -// CHECK-LABEL: func @ceil_dynamic -func.func @ceil_dynamic(%arg0: tensor) -> tensor { - // CHECK: stablehlo.ceil %arg0 : tensor - %0 = "tf.Ceil"(%arg0) : (tensor) -> tensor - func.return %0 : tensor -} - -// ----- - -// CHECK-LABEL: @complex_abs -func.func @complex_abs(%arg0: tensor<2xcomplex>) -> tensor<2xf32> { - // CHECK: stablehlo.abs %arg0 : (tensor<2xcomplex>) -> tensor<2xf32> - %0 = "tf.ComplexAbs"(%arg0) : (tensor<2xcomplex>) -> tensor<2xf32> - func.return %0 : tensor<2xf32> -} - -// ----- - -// CHECK-LABEL: @cos -func.func @cos(%arg0: tensor<2xf32>) -> tensor<2xf32> { - // CHECK: stablehlo.cosine %arg0 : tensor<2xf32> - %0 = "tf.Cos"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> - func.return %0 : tensor<2xf32> -} - -// ----- - -// CHECK-LABEL: @tan -func.func @tan(%arg0: tensor<2xf32>) -> tensor<2xf32> { - // CHECK: stablehlo.tan %arg0 : tensor<2xf32> - %0 = "tf.Tan"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> - func.return %0 : tensor<2xf32> -} - -// ----- - -// CHECK-LABEL: func @cos_dynamic -func.func @cos_dynamic(%arg0: tensor) -> tensor { - // CHECK: stablehlo.cosine %arg0 : tensor - %0 = "tf.Cos"(%arg0) : (tensor) -> tensor - func.return %0 : tensor -} - -// ----- - -// CHECK-LABEL: @exp -func.func @exp(%arg0: tensor<2xf32>) -> tensor<2xf32> { - // CHECK: stablehlo.exponential %arg0 : tensor<2xf32> - %0 = "tf.Exp"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> - func.return %0 : tensor<2xf32> -} - -// ----- - -// CHECK-LABEL: @expm1 -func.func @expm1(%arg0: tensor<2xf32>) -> tensor<2xf32> { - // CHECK: stablehlo.exponential_minus_one %arg0 : tensor<2xf32> - %0 = "tf.Expm1"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> - func.return %0 : tensor<2xf32> -} - -// ----- - -// CHECK-LABEL: func @exp_dynamic -func.func @exp_dynamic(%arg0: tensor) -> tensor { - // CHECK: stablehlo.exponential %arg0 : tensor - %0 = "tf.Exp"(%arg0) : (tensor) -> tensor - func.return %0 : tensor -} - -// ----- - -// CHECK-LABEL: @floor -func.func @floor(%arg0: tensor<2xf32>) -> tensor<2xf32> { - // CHECK: stablehlo.floor %arg0 : tensor<2xf32> - %0 = "tf.Floor"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> - func.return %0 : tensor<2xf32> -} - -// ----- - -// CHECK-LABEL: func @floor_dynamic -func.func @floor_dynamic(%arg0: tensor) -> tensor { - // CHECK: stablehlo.floor %arg0 : tensor - %0 = "tf.Floor"(%arg0) : (tensor) -> tensor - func.return %0 : tensor -} - -// ----- - -// CHECK-LABEL: @is_finite -func.func @is_finite(%arg0: tensor<2xf32>) -> tensor<2xi1> { - // CHECK: stablehlo.is_finite %arg0 : (tensor<2xf32>) -> tensor<2xi1> - %0 = "tf.IsFinite"(%arg0) : (tensor<2xf32>) -> tensor<2xi1> - func.return %0 : tensor<2xi1> -} - -// ----- - -// CHECK-LABEL: func @is_finite_dynamic -func.func @is_finite_dynamic(%arg0: tensor) -> tensor { - // CHECK: stablehlo.is_finite %arg0 : (tensor) -> tensor - %0 = "tf.IsFinite"(%arg0) : (tensor) -> tensor - func.return %0 : tensor -} - -// ----- - -// CHECK-LABEL: @log -func.func @log(%arg0: tensor<2xf32>) -> tensor<2xf32> { - // CHECK: stablehlo.log %arg0 : tensor<2xf32> - %0 = "tf.Log"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> - func.return %0 : tensor<2xf32> -} - -// ----- - -// CHECK-LABEL: func @log_dynamic -func.func @log_dynamic(%arg0: tensor) -> tensor { - // CHECK: stablehlo.log %arg0 : tensor - %0 = "tf.Log"(%arg0) : (tensor) -> tensor - func.return %0 : tensor -} - -// ----- - -// CHECK-LABEL: @log1p -func.func @log1p(%arg0: tensor<2xf32>) -> tensor<2xf32> { - // CHECK: stablehlo.log_plus_one %arg0 : tensor<2xf32> - %0 = "tf.Log1p"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> - func.return %0 : tensor<2xf32> -} - -// ----- - -// CHECK-LABEL: func @log1p_dynamic -func.func @log1p_dynamic(%arg0: tensor) -> tensor { - // CHECK: stablehlo.log_plus_one %arg0 : tensor - %0 = "tf.Log1p"(%arg0) : (tensor) -> tensor - func.return %0 : tensor -} - -// ----- - -// CHECK-LABEL: @neg -func.func @neg(%arg0: tensor<2xf32>) -> tensor<2xf32> { - // CHECK: stablehlo.negate %arg0 : tensor<2xf32> - %0 = "tf.Neg"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> - func.return %0 : tensor<2xf32> -} - -// ----- - -// CHECK-LABEL: func @neg_dynamic -func.func @neg_dynamic(%arg0: tensor) -> tensor { - // CHECK: stablehlo.negate %arg0 : tensor - %0 = "tf.Neg"(%arg0) : (tensor) -> tensor - func.return %0 : tensor -} - -// ----- - -// CHECK-LABEL: @sigmoid -func.func @sigmoid(%arg0: tensor<2xf32>) -> tensor<2xf32> { - // CHECK: stablehlo.logistic - %0 = "tf.Sigmoid"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> - func.return %0 : tensor<2xf32> -} - -// ----- - -// CHECK-LABEL: @sigmoid_complex -func.func @sigmoid_complex(%arg0: tensor<2xcomplex>) -> tensor<2xcomplex> { - // CHECK: stablehlo.logistic - %0 = "tf.Sigmoid"(%arg0) : (tensor<2xcomplex>) -> tensor<2xcomplex> - func.return %0 : tensor<2xcomplex> -} diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf.cc deleted file mode 100644 index 0e7f1744d5fb63..00000000000000 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf.cc +++ /dev/null @@ -1,6911 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// This file implements logic for lowering TensorFlow dialect to XLA dialect. -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/Sequence.h" -#include "llvm/ADT/SmallVector.h" -#include "llvm/Support/Debug.h" -#include "llvm/Support/ErrorHandling.h" -#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project -#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project -#include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project -#include "mlir/Dialect/Traits.h" // from @llvm-project -#include "mlir/IR/Attributes.h" // from @llvm-project -#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project -#include "mlir/IR/BuiltinOps.h" // from @llvm-project -#include "mlir/IR/BuiltinTypes.h" // from @llvm-project -#include "mlir/IR/Diagnostics.h" // from @llvm-project -#include "mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project -#include "mlir/IR/MLIRContext.h" // from @llvm-project -#include "mlir/IR/Matchers.h" // from @llvm-project -#include "mlir/IR/Operation.h" // from @llvm-project -#include "mlir/IR/PatternMatch.h" // from @llvm-project -#include "mlir/IR/TypeUtilities.h" // from @llvm-project -#include "mlir/IR/Types.h" // from @llvm-project -#include "mlir/Pass/Pass.h" // from @llvm-project -#include "mlir/Support/LLVM.h" // from @llvm-project -#include "mlir/Support/LogicalResult.h" // from @llvm-project -#include "stablehlo/dialect/ChloOps.h" // from @stablehlo -#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf_passes.h" -#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/utils.h" -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" -#include "tensorflow/compiler/mlir/tensorflow/utils/dynamic_shape_utils.h" -#include "tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h" -#include "xla/client/lib/conv_grad_size_util.h" -#include "xla/client/padding.h" -#include "xla/client/sharding_builder.h" -#include "xla/hlo/translate/hlo_to_mhlo/attribute_importer.h" -#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" -#include "xla/mlir_hlo/utils/convert_op_folder.h" -#include "xla/mlir_hlo/utils/hlo_utils.h" -#include "xla/xla_data.pb.h" -#include "tensorflow/core/framework/kernel_shape_util.h" -#include "tensorflow/core/framework/rng_alg.h" -#include "tensorflow/core/kernels/conv_grad_shape_utils.h" -#include "tensorflow/core/util/padding.h" -#include "tensorflow/core/util/tensor_format.h" -#include "tsl/platform/bfloat16.h" -#include "tsl/platform/status.h" -#include "tsl/platform/tensor_float_32_utils.h" - -namespace mlir { -// Keep this in the mlir namespace to allow the use of the mhlo ops. -namespace mhlo { -namespace { - -// The utils are copied into the odml namespace to avoid duplicate names and -// they are imported here to avoid having to change the code below. -using ::mlir::odml::BuildReduceBody; -using ::mlir::odml::GetI64ElementsAttr; -using ::mlir::odml::GetScalarConstOfType; -using ::mlir::odml::GetScalarNegZeroOfType; - -constexpr char kShardingAttr[] = "mhlo.sharding"; - -/// Returns the feature dimension for the given format and input type. -static size_t GetFeatureDimension(tensorflow::TensorFormat format, - RankedTensorType input_ty) { - return GetTensorFeatureDimIndex(input_ty.getRank(), format); -} - -// Gets all integer values from the given attribute and push them to `values`. -void GetI64ArrayAttrValues(Attribute attr, SmallVectorImpl *values) { - auto array_attr = mlir::cast(attr); - values->reserve(array_attr.getValue().size()); - for (Attribute val : array_attr.getValue()) - values->push_back(mlir::cast(val).getValue().getSExtValue()); -} - -// Returns 1D 32-bit dense elements attribute with the given values. -static DenseIntElementsAttr GetI32ElementsAttr(ArrayRef values, - Builder *builder) { - RankedTensorType ty = tensorflow::GetTypeFromTFTensorShape( - {static_cast(values.size())}, builder->getIntegerType(32)); - return DenseIntElementsAttr::get(ty, values); -} - -// Returns a 1-d i64 elements attribute populated with numbers from start to -// end, excluding. -static DenseIntElementsAttr GetI64ElementsAttrForSeq(int start, int end, - Builder *builder) { - int size = end - start; - - SmallVector vals; - vals.resize(size); - std::iota(vals.begin(), vals.end(), start); - - TensorType ty = - tensorflow::GetTypeFromTFTensorShape({size}, builder->getIntegerType(64)); - return DenseIntElementsAttr::get(ty, vals); -} - -// Returns a 1-d i64 elements attribute populated with `val` repeated `size` -// times. -static DenseIntElementsAttr GetI64ElementsAttrForValue(int size, int64_t val, - Builder *builder) { - TensorType ty = - tensorflow::GetTypeFromTFTensorShape({size}, builder->getIntegerType(64)); - return DenseIntElementsAttr::get(ty, val); -} - -// Returns the corresponding type that should be used for performing sum -// accumulation over the given input type. -Type GetSumAccumulationType(Type input_type) { - MLIRContext *ctx = input_type.getContext(); - if (input_type.isBF16() || input_type.isF16()) return FloatType::getF32(ctx); - if (input_type.isSignlessInteger(8) || input_type.isSignlessInteger(16)) - return IntegerType::get(ctx, 32); - return input_type; -} - -// Returns axis in HLO format from TF elements attr with exactly one element or -// is an IntegerAttr, containing axis in the TensorFlow format. TensorFlow -// format supports negative indexing unlike HLO. -static IntegerAttr GetHLOAxisFromTFAxis(Attribute attr, int64_t rank, - Builder *b) { - IntegerAttr intAttr = mlir::dyn_cast_or_null(attr); - if (auto elementAttr = mlir::dyn_cast_or_null(attr)) { - SmallVector index(elementAttr.getShapedType().getRank(), 0); - intAttr = elementAttr.getValues()[index]; - } - - assert(intAttr && "Invalid attribute passed to GetHLOAxisFromTFAxis"); - - int64_t axis = intAttr.getInt(); - if (axis < 0) { - axis += rank; - } - return b->getI64IntegerAttr(axis); -} - -// Returns a PrecisionConfig as an array attribute based on whether TF32 -// execution is enabled -static ArrayAttr GetPrecisionConfig(Builder *builder) { - mlir::mhlo::Precision precision = tsl::tensor_float_32_execution_enabled() - ? mhlo::Precision::DEFAULT - : mlir::mhlo::Precision::HIGHEST; - llvm::SmallVector attr_vec; - const int num_inputs = 2; - for (int i = 0; i < num_inputs; i++) { - attr_vec.push_back( - mlir::mhlo::PrecisionAttr::get(builder->getContext(), precision)); - } - return builder->getArrayAttr(attr_vec); -} - -// If `value` is an IntegerAttr, returns the integer value for the HLO axis -// corresponding to the tensorflow axis. In particular, the tensorflow axis can -// be negative, in which case, the corresponding HLO axis is -// (axis + rank-of-the-tensor). -static std::optional GetIntegerHLOAxisFromTFAxis(Value value, - int64_t rank) { - DenseIntElementsAttr attrs; - if (!matchPattern(value, m_Constant(&attrs)) || - attrs.getType().getRank() != 0) { - return std::nullopt; - } - int64_t axis = attrs.getValues()[0].getInt(); - return axis < 0 ? axis + rank : axis; -} - -/// Returns a `ConvertOp` that casts the elements to a i64 type while retaining -/// the shape of the input value. -static ConvertOp CastValueToI64(Location loc, Value value, - PatternRewriter *rewriter) { - return rewriter->create(loc, value, rewriter->getIntegerType(64)); -} - -// Creates an unpack op along the 0th dimension of the tensor. The `value` input -// must be a ranked tensor. -static TF::UnpackOp UnpackTensorAlongZeroDim(Location loc, Value value, - PatternRewriter *rewriter) { - auto indices_type = mlir::cast(value.getType()); - int num_outputs = indices_type.getShape().front(); - SmallVector unpacked_indices_type( - num_outputs, - tensorflow::GetTypeFromTFTensorShape({}, indices_type.getElementType())); - auto unpacked_indices = rewriter->create( - loc, unpacked_indices_type, value, - IntegerAttr::get(rewriter->getIntegerType(64), 0)); - return unpacked_indices; -} - -// Returns size of dimension at the specified index, if ranked tensor. -// Otherwise, returns -1. -// -// Aborts if the type is ranked but doesn't have the dimension. -int64_t GetDimSize(Type ty, int64_t index) { - RankedTensorType ranked_ty = mlir::dyn_cast(ty); - if (!ranked_ty) return -1; - - return ranked_ty.getDimSize(index); -} - -template -tensorflow::TensorShape ToTensorShape(llvm::ArrayRef sizes) { - return tensorflow::TensorShape( - llvm::SmallVector(sizes.begin(), sizes.end())); -} - -template -tensorflow::TensorShape ToTensorShape( - llvm::iterator_range> sizes) { - return tensorflow::TensorShape( - llvm::SmallVector(sizes.begin(), sizes.end())); -} - -// Returns a limit scalar const op for the given type. -// Requires FloatType or IntegerType -static ConstantOp GetScalarLimitConstOfType(Type ty, Location loc, - hlo::ScalarLimit limit, - OpBuilder *builder) { - return builder->create(loc, hlo::getScalarLimitOfType(ty, limit)); -} - -// Deprecated: This is maintained to aid in porting old code that is not yet -// dynamic shape aware and uses broadcasting modes that CHLO does not support. -// Gets the resulting type from a broadcast between two types for statically -// shaped types. This is to be used for legacy lowerings that both use non -// left-padded broadcasting and static shapes. Its use should not be permitted -// in new code. -// May return nullptr on invalid static broadcast dimensions. -// ABSL_DEPRECATED() -static RankedTensorType GetStaticBroadcastType( - RankedTensorType x, RankedTensorType y, - DenseIntElementsAttr broadcast_dimensions_attr) { - auto element_type = x.getElementType(); - auto shape_x = x.getShape(); - auto shape_y = y.getShape(); - - if (shape_x.size() == shape_y.size()) { - llvm::SmallVector out_shape(shape_x.size()); - for (int i = 0; i < shape_x.size(); i++) { - auto x_val = shape_x[i]; - auto y_val = shape_y[i]; - out_shape[i] = std::max(x_val, y_val); - } - return tensorflow::GetTypeFromTFTensorShape(out_shape, element_type); - } - - auto shape_large = shape_x.size() > shape_y.size() ? shape_x : shape_y; - auto shape_small = shape_x.size() <= shape_y.size() ? shape_x : shape_y; - - llvm::SmallVector broadcast_dimensions; - // Explicit broadcast dimensions. - for (const APInt &int_value : broadcast_dimensions_attr) { - broadcast_dimensions.push_back(int_value.getSExtValue()); - } - if (broadcast_dimensions.size() != shape_small.size()) { - return nullptr; - } - llvm::SmallVector out_shape(shape_large.begin(), - shape_large.end()); - - // Update according to the broadcast dimensions. - for (const auto &index_pair : llvm::enumerate(broadcast_dimensions)) { - auto old_value = out_shape[index_pair.value()]; - auto new_value = shape_small[index_pair.index()]; - out_shape[index_pair.value()] = std::max(old_value, new_value); - } - return tensorflow::GetTypeFromTFTensorShape(out_shape, element_type); -} - -// Deprecated: This is maintained to aid in porting old code that is not yet -// dynamic shape aware and uses broadcasting modes that CHLO does not support. -// Applies static binary broadcasting to a binary elementwise op. -// This is a legacy helper to provide general broadcasting support in legacy, -// static shaped code that relies on non-left-padded broadcasting semantics. -template -static Value StaticBinaryBroadcast(Location loc, Value x, Value y, - DenseIntElementsAttr broadcast_dims, - OpBuilder &builder) { - auto x_type = mlir::cast(x.getType()); - auto y_type = mlir::cast(y.getType()); - auto result_type = GetStaticBroadcastType(x_type, y_type, broadcast_dims); - if (!result_type) { - emitError(loc) << "could not binary broadcast " << x_type << ", " << y_type - << " with broadcast_dims = " << broadcast_dims; - return nullptr; - } - auto larger_broadcast_dims = - GetI64ElementsAttrForSeq(0, result_type.getRank(), &builder); - if (x_type.getRank() < y_type.getRank()) { - if (x_type != result_type) { - x = builder.create(loc, result_type, x, broadcast_dims); - } - if (y_type != result_type) { - y = builder.create(loc, result_type, y, - larger_broadcast_dims); - } - } else { - if (x_type != result_type) { - x = builder.create(loc, result_type, x, - larger_broadcast_dims); - } - if (y_type != result_type) { - y = builder.create(loc, result_type, y, broadcast_dims); - } - } - return builder.create(loc, x, y); -} - -// Gets a 1D tensor type suitable for expressing extents of the given tensor -// value type. If the value type is ranked, the result will be statically -// shaped. Otherwise, it will have a dynamic dimension. -static RankedTensorType GetExtentsTensorTypeFor(TensorType value_type) { - Builder b(value_type.getContext()); - int64_t dim = value_type.hasRank() ? value_type.getRank() : -1; - return tensorflow::GetTypeFromTFTensorShape({dim}, b.getIndexType()); -} - -// Given a value (broadcast_to) and a feature dimension, broadcasts a 1D -// value (broadcast_from) along that feature dimension. This is a shortcut -// for the cases where a 1D tensor must be broadcast along a specific feature -// dimension, which can vary based on data layout, etc. -// -// The extent of `broadcast_from` dim0 must be equal to the extent of the -// feature_dim of `broadcast_to`. -// -// Example: -// [1x2x3x4], [2], 1 -> [1x2x3x4] -// TODO(laurenzo): Swap the order of broadcast_to and broadcast_from for -// consistency. Possibly also rename for clarity. -static Value Broadcast1DToFeatureDim(Location loc, Value broadcast_to, - Value broadcast_from, int64_t feature_dim, - OpBuilder &builder) { - auto broadcast_dims = GetI64ElementsAttr({feature_dim}, &builder); - auto to_type = mlir::cast(broadcast_to.getType()); - auto result_shape = builder.create(loc, broadcast_to); - auto result_extents_type = GetExtentsTensorTypeFor(to_type); - auto result_extents = builder.create( - loc, result_extents_type, result_shape); - return builder.create( - loc, to_type, broadcast_from, result_extents, broadcast_dims); -} - -// Broadcasts `input` to the shape of `broadcast_to` value following -// TF::BroadcastTo semantics. -// -// Requires that input is a ranked tensor. -// -// TODO(hinsu): Utilize TF::ShapeOp followed by TF::BroadcastTo once ShapeOp -// supports unranked inputs in the lowering. -static Value BroadcastToShapeOf(Location loc, Value input, Value broadcast_to, - OpBuilder &builder) { - auto result_shape = builder.create(loc, broadcast_to); - auto to_type = mlir::cast(broadcast_to.getType()); - auto result_extents_type = GetExtentsTensorTypeFor(to_type); - auto result_extents = builder.create( - loc, result_extents_type, result_shape); - int64_t rank = mlir::cast(input.getType()).getRank(); - auto broadcast_dims = GetI64ElementsAttrForSeq(0, rank, &builder); - return builder.create( - loc, to_type, input, result_extents, broadcast_dims); -} - -// Builds a set of operations for applying reduction on the input value. A -// tf.sum op is created and will be legalized to tfl ops automatically. -static Value ApplyReduction(Location loc, Value input, - DenseIntElementsAttr reduce_dims, - OpBuilder *builder) { - auto reduce_dims_op = builder->create(loc, reduce_dims); - return builder->create(loc, input, reduce_dims_op, - builder->getBoolAttr(false)); -} - -// Creates a mhlo.rng_uniform op with `builder` to generate `num_elements` -// 32-bit integer numbers in the range of [`lower_limit`, `upper_limit`). -static mhlo::RngOp CreateRngUniform32(Location loc, int num_elements, - int lower_limit, int upper_limit, - OpBuilder *builder) { - auto shape_tensor = builder->create( - loc, GetI64ElementsAttr({num_elements}, builder)); - - auto lower = builder->create( - loc, builder->getI32IntegerAttr(lower_limit)); - auto upper = builder->create( - loc, builder->getI32IntegerAttr(upper_limit)); - - return builder->create(loc, lower, upper, shape_tensor, - ::mlir::mhlo::RngDistribution::UNIFORM); -} - -using WhileBodyFnType = llvm::function_ref old_values, - SmallVectorImpl *new_values, OpBuilder *builder)>; - -// Creates a mhlo.while op with `builder` to loop `num_interations` times, -// each time calling the given `body_fn` on a set of values to generate a new -// set of values. Returns the final set of values via `final_values`. The -// initial set of values is passed in via `init_values`. -// -// This effectively does: -// -// ```c++ -// SmallVector old_values = init_values; -// SmallVector new_values; -// for (int i = 0; i < num_iterations; ++i) { -// body_fn(old_values, &new_values, ...); -// old_values = new_values; -// } -// ``` -// -// Under the hood an induction variable is prepended to values to control the -// number of iterations, but that is transparent to `body_fn`, which does not -// need to care about that. -static void CreateWhile32(Location loc, int num_iterations, - WhileBodyFnType body_fn, ArrayRef init_values, - SmallVectorImpl *final_values, - OpBuilder *builder) { - int value_count = init_values.size() + 1; - - // Prepend a loop induction variable to the initial values. - SmallVector init_values_with_loop_iv; - SmallVector init_types_with_loop_iv; - init_values_with_loop_iv.reserve(value_count); - init_types_with_loop_iv.reserve(value_count); - - // The initial value for the loop induction variable is 0. - init_values_with_loop_iv.push_back( - builder->create(loc, builder->getI32IntegerAttr(0))); - init_values_with_loop_iv.append(init_values.begin(), init_values.end()); - - // Accumulate types of all the init values. - for (const auto &init_value_with_loop_iv : init_values_with_loop_iv) - init_types_with_loop_iv.push_back(init_value_with_loop_iv.getType()); - - // Create the while op. - auto while_op = builder->create(loc, init_types_with_loop_iv, - init_values_with_loop_iv); - auto ivs_count = init_types_with_loop_iv.size(); - - { - OpBuilder::InsertionGuard guard(*builder); - - // Build up the only block in the condition region. - Region &condition = while_op.getCond(); - Block *block = builder->createBlock(&condition); - block->addArguments(init_types_with_loop_iv, - SmallVector(ivs_count, loc)); - - // Get the loop induction variable and compare it against the upper limit. - auto loop_iv = block->getArgument(0); - auto upper_limit = builder->create( - loc, builder->getI32IntegerAttr(num_iterations)); - Value compare = builder->create(loc, loop_iv, upper_limit, - ComparisonDirection::LT); - - builder->create(loc, compare); - } - - { - OpBuilder::InsertionGuard guard(*builder); - - // Build up the only block in the body region. - Region &body = while_op.getBody(); - Block *block = builder->createBlock(&body); - block->addArguments(init_types_with_loop_iv, - SmallVector(ivs_count, loc)); - - SmallVector new_values; // Generated by this iteration - new_values.reserve(value_count); - - // Feed all values excluding the loop induction variable to body_fn. - body_fn(loc, block->getArgument(0), - ArrayRef(block->getArguments().begin() + 1, - block->getArguments().end()), - &new_values, builder); - - // Increment the loop induction variable by one. - auto one = - builder->create(loc, builder->getI32IntegerAttr(1)); - auto scalar_broadcast_dims = builder->getDenseI64ArrayAttr({}); - auto plus_one = builder->create( - loc, block->getArgument(0), one, scalar_broadcast_dims); - // Prepend with the updated loop induction variable. - new_values.insert(new_values.begin(), plus_one); - - builder->create(loc, new_values); - } - - // TODO(jpienaar): Support multi-operand while op. - final_values->reserve(init_values.size()); - for (int i = 0, e = init_values.size(); i < e; ++i) - final_values->push_back(while_op.getResult(i + 1)); -} - -//===----------------------------------------------------------------------===// -// BatchNorm op utilities. -//===----------------------------------------------------------------------===// - -static IntegerAttr getFeatureDimensionAttr(Builder &b, - tensorflow::TensorFormat format, - Value input) { - return b.getI64IntegerAttr(GetFeatureDimension( - format, mlir::cast(input.getType()))); -} - -//===----------------------------------------------------------------------===// -// FFT op utilities. -//===----------------------------------------------------------------------===// - -// Returns the 1D i64 elements attribute populated with the inner-most dim of -// the value. -static DenseIntElementsAttr GetInnerDimFromValue(ShapedType type, - Builder *builder) { - if (type.getRank() == 0) { - return builder->getI64TensorAttr({}); - } - return builder->getI64TensorAttr(type.getShape().back()); -} - -// Returns True if the inner-most dim is static. -bool CheckInnerDimStatic(ShapedType type, Builder *builder) { - if (!type.hasRank()) { - return false; - } - return !type.isDynamicDim(type.getShape().size() - 1); -} - -//===----------------------------------------------------------------------===// -// MatMul op utilities. -//===----------------------------------------------------------------------===// - -// If the 'transpose' attribute is true returns ElementsAttr to transpose 2D -// matrix. Otherwise, returns ElementsAttr for identity transpose. -static DenseIntElementsAttr Get2DTransposePerm(BoolAttr transpose, Builder *b) { - if (transpose.getValue()) return GetI64ElementsAttr({1, 0}, b); - return GetI64ElementsAttr({0, 1}, b); -} - -//===----------------------------------------------------------------------===// -// Pad op utilities. -//===----------------------------------------------------------------------===// - -// Slices input attribute of rank two and returns the specified column. -// -// Always returns 64 bit integer attribute regardless of bitwidth of the input -// attribute. -static DenseIntElementsAttr SliceDenseIntElementsAttrColumn2D( - ElementsAttr input, int column) { - auto int_attr = mlir::cast(input); - auto shaped_type = int_attr.getType(); - auto shape = shaped_type.getShape(); - - if (shape.size() != 2) return DenseIntElementsAttr(); - - llvm::SmallVector values; - values.reserve(shaped_type.getNumElements() / shape[1]); - - for (const auto &it : llvm::enumerate(int_attr.getValues())) { - if (static_cast(it.index() % shape[1]) == column) { - values.push_back(it.value().getSExtValue()); - } - } - - auto element_type = IntegerType::get(input.getContext(), 64); - return DenseIntElementsAttr::get( - tensorflow::GetTypeFromTFTensorShape({shape[0]}, element_type), values); -} - -// Returns interior padding to use in HLO Pad op based on the TensorFlow padding -// in TensorFlow PadV2 op. -static DenseIntElementsAttr GetInteriorPadding(ElementsAttr tf_padding) { - auto length = tf_padding.getShapedType().getShape()[0]; - auto element_type = IntegerType::get(tf_padding.getContext(), 64); - return DenseIntElementsAttr::get( - tensorflow::GetTypeFromTFTensorShape({length}, element_type), 0); -} - -//===----------------------------------------------------------------------===// -// Binary op utilities. -//===----------------------------------------------------------------------===// - -// Returns whether the two values are guaranteed to be broadcastable to the -// same shape, this broadcasts size 1 tensors up to any rank. Dynamic dimensions -// must be broadcasted with a size 1 tensor or another dynamic dimension. -// Returns false on rankless. -static bool AreBroadcastCompatible(Value x, Value y) { - auto x_rankless = mlir::dyn_cast(x.getType()); - auto y_rankless = mlir::dyn_cast(y.getType()); - if (!x_rankless || !y_rankless) { - return false; - } - - // Check that the shapes can be broadcasted. - auto shape_x = x_rankless.getShape(); - auto shape_y = y_rankless.getShape(); - - int rank_diff = shape_x.size() - shape_y.size(); - int offset_x = rank_diff > 0 ? rank_diff : 0; - int offset_y = rank_diff < 0 ? -rank_diff : 0; - for (int i = 0, s = std::min(shape_x.size(), shape_y.size()); i < s; i++) { - int index_x = i + offset_x; - int index_y = i + offset_y; - if ((shape_x[index_x] == -1 && shape_y[index_y] != 1) || - (shape_y[index_y] == -1 && shape_x[index_x] != 1)) { - return false; - } - } - - return true; -} - -// Return a new TensorType the same rank and dimensions as the input with an -// updated element type. -static Type ChangeTensorElementType(Builder *b, Type tensor_type, - Type element_type) { - RankedTensorType ranked_type = mlir::dyn_cast(tensor_type); - if (ranked_type) { - return tensorflow::GetTypeFromTFTensorShape(ranked_type.getShape(), - element_type); - } - - return UnrankedTensorType::get(element_type); -} - -//===----------------------------------------------------------------------===// -// Softmax op utilities. -//===----------------------------------------------------------------------===// - -// Returns the type to use for accumulating the given type. -static Type GetAccumulationType(Type ty) { - // Upcast 16 bit sum reductions to 32 bit to reduce the precision loss from - // repeated floating point additions. - return (ty.isF16() || ty.isBF16()) ? FloatType::getF32(ty.getContext()) : ty; -} - -//===----------------------------------------------------------------------===// -// Softplus op utilities. -//===----------------------------------------------------------------------===// - -static DenseElementsAttr GetEpsilonValue(Type ty) { - auto element_ty = mlir::cast(ty).getElementType(); - auto scalar_ty = tensorflow::GetTypeFromTFTensorShape({}, element_ty); - if (element_ty.isF16()) { - uint16_t raw_epsilon = Eigen::numext::bit_cast( - Eigen::NumTraits::epsilon()); - auto value = APFloat(APFloat::IEEEhalf(), APInt(16, raw_epsilon)); - return DenseElementsAttr::get(scalar_ty, value); - } else if (element_ty.isBF16()) { - uint16_t raw_epsilon = Eigen::numext::bit_cast( - Eigen::NumTraits::epsilon()); - auto value = APFloat(APFloat::BFloat(), APInt(16, raw_epsilon)); - return DenseElementsAttr::get(scalar_ty, value); - } else if (element_ty.isF32()) { - auto value = APFloat(std::numeric_limits::epsilon()); - return DenseElementsAttr::get(scalar_ty, value); - } else if (element_ty.isF64()) { - auto value = APFloat(std::numeric_limits::epsilon()); - return DenseElementsAttr::get(scalar_ty, value); - } - llvm_unreachable("unsupported element type for tf.SoftPlus"); -} - -//===----------------------------------------------------------------------===// -// ArgMax/ArgMin op utilities. -//===----------------------------------------------------------------------===// - -static void BuildArgMinMaxReductionBody(Type input_element_type, - Type index_element_type, - ComparisonDirection direction, - Region *body, OpBuilder *builder) { - OpBuilder::InsertionGuard insertion_point_gurad(*builder); - - Type input_type = - tensorflow::GetTypeFromTFTensorShape(/*shape=*/{}, input_element_type); - Type index_type = - tensorflow::GetTypeFromTFTensorShape(/*shape=*/{}, index_element_type); - Block *block = builder->createBlock(body); - Location loc = body->getLoc(); - block->addArguments({input_type, index_type, input_type, index_type}, - SmallVector(4, loc)); - - Value lhs_val = block->getArgument(0); - Value lhs_index = block->getArgument(1); - Value rhs_val = block->getArgument(2); - Value rhs_index = block->getArgument(3); - - ImplicitLocOpBuilder b(loc, *builder); - Value compare_dt = b.create(lhs_val, rhs_val, direction); - Value selected_input = - b.create(input_type, compare_dt, lhs_val, rhs_val); - - Value compare_eq = - b.create(lhs_val, rhs_val, ComparisonDirection::EQ); - Value min_index = b.create(lhs_index, rhs_index); - Value min_val_index = - b.create(index_type, compare_dt, lhs_index, rhs_index); - Value selected_index = - b.create(index_type, compare_eq, min_index, min_val_index); - - Value return_values[] = {selected_input, selected_index}; - b.create(return_values); -} - -//===----------------------------------------------------------------------===// -// PartitionedCall op utilities. -//===----------------------------------------------------------------------===// - -// Verify that the arguments to be passed into the function are the same types -// as the function paramter types. -static bool ArgTypesMatchCallee(mlir::Operation *op, OperandRange args, - SymbolRefAttr func) { - auto module = op->getParentOfType(); - auto function = - dyn_cast_or_null(SymbolTable::lookupSymbolIn(module, func)); - FunctionType function_ty = function.getFunctionType(); - - for (auto arg_in : llvm::zip(args, function_ty.getInputs())) { - if (std::get<0>(arg_in).getType() != std::get<1>(arg_in)) { - // Argument type and input type mismatch. - return false; - } - } - return true; -} - -//===----------------------------------------------------------------------===// -// Slice op utilities. -//===----------------------------------------------------------------------===// - -static bool CanBeTranslatedToDynamicSlice(Value input, Value start_indices, - DenseIntElementsAttr slice_sizes) { - auto input_ty = mlir::dyn_cast(input.getType()); - if (!input_ty) return false; - auto start_indices_ty = - mlir::dyn_cast(start_indices.getType()); - if (!start_indices_ty) return false; - - int64_t input_rank = input_ty.getRank(); - ArrayRef input_shape = input_ty.getShape(); - DenseIntElementsAttr constant_start_indices; - bool is_constant_start = - matchPattern(start_indices, m_Constant(&constant_start_indices)); - - for (int64_t i = 0; i < input_rank; ++i) { - int64_t input_size = input_shape[i]; - int64_t slice_size = slice_sizes.getValues()[i].getInt(); - // A slice_size of -1 means "all elements from start_index to the end". - // In order to support these semantics, we need to know both the start index - // and the shape of the input dimension. - if (slice_size < 0 && (!is_constant_start || input_size < 0)) return false; - } - return true; -} - -// TF slice size can be -1, which represents all elements from start_index to -// the end. HLO slice size can't be -1. As such, we need to translate TF slice -// size -1 to HLO slice size. -static DenseIntElementsAttr TFSliceSizes2HLOSliceSizes( - Value input, Value start_indices, DenseIntElementsAttr slice_sizes, - Builder *builder) { - DenseIntElementsAttr constant_start_indices; - if (!matchPattern(start_indices, m_Constant(&constant_start_indices))) { - return mlir::cast( - hlo::convertElementsAttr(slice_sizes, builder->getIntegerType(64))); - } - - auto input_ty = mlir::dyn_cast(input.getType()); - int64_t input_rank = input_ty.getRank(); - ArrayRef input_shape = input_ty.getShape(); - SmallVector normalized_sizes; - - for (int64_t i = 0; i < input_rank; ++i) { - int64_t input_size = input_shape[i]; - int64_t start_index = - constant_start_indices.getValues()[i].getInt(); - int64_t slice_size = slice_sizes.getValues()[i].getInt(); - normalized_sizes.push_back(slice_size == -1 ? input_size - start_index - : slice_size); - } - - return GetI64ElementsAttr(normalized_sizes, builder); -} - -//===----------------------------------------------------------------------===// -// XlaGather op utilities. -//===----------------------------------------------------------------------===// - -bool HasValidGatherDims(StringAttr attr) { - ::xla::GatherDimensionNumbers dims; - return dims.ParseFromString(attr.getValue().str()); -} - -GatherDimensionNumbersAttr GetGatherDimNumsAttr(StringAttr attr, - Builder *builder) { - ::xla::GatherDimensionNumbers dims; - if (!dims.ParseFromString(attr.getValue().str())) return {}; - return ::xla::ConvertGatherDimensionNumbers(dims, builder); -} - -//===----------------------------------------------------------------------===// -// XlaDot op utilities. -//===----------------------------------------------------------------------===// - -bool HasValidDotDims(StringAttr attr) { - ::xla::DotDimensionNumbers dims; - return dims.ParseFromString(attr.getValue().str()); -} - -DotDimensionNumbersAttr GetDotDimNumsAttr(StringAttr attr, Builder *builder) { - ::xla::DotDimensionNumbers dims; - if (!dims.ParseFromString(attr.getValue().str())) return {}; - return ::xla::ConvertDotDimensionNumbers(dims, builder); -} - -bool HasValidPrecisionConfig(StringAttr attr) { - ::xla::PrecisionConfig precision; - return precision.ParseFromString(attr.getValue().str()); -} - -mlir::ArrayAttr GetPrecisionConfigAttr(StringAttr attr, Builder *builder) { - ::xla::PrecisionConfig precision; - if (!precision.ParseFromString(attr.getValue().str())) return {}; - return ::xla::ConvertPrecisionConfig(&precision, builder); -} - -//===----------------------------------------------------------------------===// -// XlaVariadicReduceV2 op utilities. -//===----------------------------------------------------------------------===// - -static void BuildBodyWithCall(PatternRewriter &rewriter, const Location &loc, - mlir::SymbolRefAttr func, - mlir::FunctionType func_ty, Region *body) { - OpBuilder::InsertionGuard guard(rewriter); - - Block *block = rewriter.createBlock(body); - auto inputs = func_ty.getInputs(); - block->addArguments(inputs, SmallVector(inputs.size(), loc)); - mlir::func::CallOp call_op = rewriter.create( - loc, func, func_ty.getResults(), block->getArguments()); - rewriter.create(loc, call_op.getResults()); -} - -//===----------------------------------------------------------------------===// -// Op converters. -//===----------------------------------------------------------------------===// - -NamedAttribute GetConvDimensionNumbersAttr(ArrayRef spatial_dims, - tensorflow::TensorFormat format, - Builder *builder) { - int64_t num_spatial_dims = spatial_dims.size(); - int64_t num_dims = num_spatial_dims + 2; - - int64_t batch_dim = GetTensorBatchDimIndex(num_dims, format); - int64_t feature_dim = GetTensorFeatureDimIndex(num_dims, format); - - // Filters data_format is always HWIO so input channels dimension is after - // all spatial dimensions. - int64_t kernel_input_feature_dim = num_spatial_dims; - int64_t kernel_output_feature_dim = num_spatial_dims + 1; - SmallVector kernel_spatial_dimensions; - kernel_spatial_dimensions.resize(num_spatial_dims); - std::iota(kernel_spatial_dimensions.begin(), kernel_spatial_dimensions.end(), - 0); - - return builder->getNamedAttr( - "dimension_numbers", - ConvDimensionNumbersAttr::get( - builder->getContext(), batch_dim, feature_dim, spatial_dims, - kernel_input_feature_dim, kernel_output_feature_dim, - kernel_spatial_dimensions, batch_dim, feature_dim, spatial_dims)); -} - -// Converts a TF::BiasAddOp to HLO. -// This differs from a normal TF::AddOp with respect to how the data_format -// is handled, which can optionally require a general broadcast of the -// 'bias' term in a way that is not compatible with the standard left-padded -// broadcast semantics (i.e. NCHW will broadcast into dimension 1). -// The correct 'bias' broadcast will be synthesized manually. -class ConvertBiasAddOp : public OpRewritePattern { - public: - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(TF::BiasAddOp op, - PatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - tensorflow::TensorFormat data_format; - if (!FormatFromString(op.getDataFormat().str(), &data_format)) - return op.emitOpError("invalid data format"); - - auto value_type = mlir::dyn_cast(op.getValue().getType()); - if (!value_type) return failure(); - auto feature_dim = GetFeatureDimension(data_format, value_type); - auto bias_broadcast = Broadcast1DToFeatureDim( - loc, op.getValue(), op.getBias(), feature_dim, rewriter); - Value add = rewriter.create(loc, op.getValue(), bias_broadcast); - if (add.getType() != op.getType()) { - add = rewriter.create(loc, op.getType(), add); - } - rewriter.replaceOp(op, {add}); - return success(); - } -}; - -// Conterts tf.Conv2D to mhlo.dynamic_conv. -// TODO(disc): To recover static special case's performance with adding folding, -// canonicalization func and removing ConvertConvOp. -template -class ConvertConvDynamic : public OpRewritePattern { - public: - using OpRewritePattern::OpRewritePattern; - - bool GetPaddingValues(OpT &op, PatternRewriter &rewriter, Value input_size, - Value filter_size, int64_t dilation_rate, - int64_t stride, tensorflow::Padding padding_type, - Type shape_scalar_type, Value *padding_low, - Value *padding_high) const { - // Stride must be > 0 - if (stride <= 0) return false; - // Dilation rate must be >= 1 - if (dilation_rate < 1) return false; - - Location loc = op.getLoc(); - switch (padding_type) { - case tensorflow::Padding::VALID: { - auto zero = - rewriter.create(loc, 0, shape_scalar_type); - *padding_low = *padding_high = zero; - break; - } - case tensorflow::Padding::EXPLICIT: - break; - case tensorflow::Padding::SAME: { - auto zero = - rewriter.create(loc, 0, shape_scalar_type); - auto one = - rewriter.create(loc, 1, shape_scalar_type); - auto two = - rewriter.create(loc, 2, shape_scalar_type); - // See also the parallel implementation in - // GetWindowedOutputSizeFromDimsV2. effective_filter_size = (filter_size - // - 1) * dilation_rate + 1 - Value stride_value = rewriter.create( - loc, stride, shape_scalar_type); - Value dilation_rate_value = rewriter.create( - loc, dilation_rate, shape_scalar_type); - Value effective_filter_size_op = rewriter.create( - loc, one, - rewriter.create( - loc, dilation_rate_value, - rewriter.create(loc, filter_size, one))); - // output_size = (input_size + stride - 1) / stride; - Value output_size = rewriter.create( - loc, - rewriter.create( - loc, input_size, - rewriter.create(loc, stride_value, one)), - stride_value); - // std::max(int64{0}, (output_size - 1) * stride + - // effective_filter_size - input_size); - Value padding_needed = rewriter.create( - loc, - rewriter.create( - loc, effective_filter_size_op, - rewriter.create( - loc, stride_value, - rewriter.create(loc, output_size, one))), - input_size); - Value cond = rewriter.create( - loc, arith::CmpIPredicate::sge, padding_needed, zero); - padding_needed = rewriter.create( - loc, padding_needed.getType(), cond, padding_needed, zero); - *padding_low = - rewriter.create(loc, padding_needed, two); - *padding_high = - rewriter.create(loc, padding_needed, *padding_low); - break; - } - } - return true; - } - - LogicalResult matchAndRewriteDynamicConv(OpT op, - PatternRewriter &rewriter) const { - tensorflow::TensorFormat data_format; - if (!FormatFromString(op.getDataFormat().str(), &data_format)) - return op.emitOpError("invalid data format"); - - tensorflow::Padding padding; - if (!GetPaddingFromString(op.getPadding().str(), &padding).ok()) - return failure(); - - auto input_ty = mlir::dyn_cast(op.getInput().getType()); - auto filter_ty = mlir::dyn_cast(op.getFilter().getType()); - auto result_ty = mlir::dyn_cast(op.getType()); - if (!input_ty || !filter_ty || !result_ty) return failure(); - // TODO(disc): Remove this constraint once fold and canonicalization - // implemented. - if (input_ty.hasStaticShape() && filter_ty.hasStaticShape()) - return failure(); - - ArrayRef dilations = op.getDilations().getValue(); - ArrayRef strides = op.getStrides().getValue(); - ArrayRef explicit_paddings; - if (padding == tensorflow::Padding::EXPLICIT) { - // EXPLICIT padding mode and the associated attribute is attached to - // Conv2D. - explicit_paddings = - op->template getAttrOfType("explicit_paddings").getValue(); - } - - SmallVector spatial_dim_indices; - SmallVector rhs_dilations; - SmallVector window_strides; - SmallVector paddings; - - auto get_int = [](Attribute attr) { - return mlir::cast(attr).getInt(); - }; - - constexpr int num_dims = num_spatial_dims + 2; - - Location loc = op.getLoc(); - auto shape_scalar_type = rewriter.getIntegerType(32); - - auto get_const = [&](int64_t val) { - return rewriter.create(loc, val, - shape_scalar_type); - }; - auto get_dim_value = [&](Value val, int64_t dim) { - Value dim_value = rewriter.create(loc, val, dim); - return rewriter.create(loc, shape_scalar_type, - dim_value); - }; - - for (auto i : llvm::seq(0, num_spatial_dims)) { - const int64_t dim = GetTensorSpatialDimIndex(num_dims, data_format, i); - spatial_dim_indices.push_back(dim); - - const int64_t dilation = get_int(dilations[dim]); - rhs_dilations.push_back(dilation); - const int64_t stride = get_int(strides[dim]); - window_strides.push_back(stride); - - Value pad_low, pad_high; - if (padding == tensorflow::Padding::EXPLICIT) { - pad_low = get_const(get_int(explicit_paddings[2 * dim])); - pad_high = get_const(get_int(explicit_paddings[2 * dim + 1])); - } else { - auto input_size = get_dim_value(op.getInput(), dim); - auto filter_size = get_dim_value(op.getFilter(), i); - if (!GetPaddingValues(op, rewriter, input_size, filter_size, dilation, - stride, padding, shape_scalar_type, &pad_low, - &pad_high)) { - return failure(); - } - } - paddings.push_back(pad_low); - paddings.push_back(pad_high); - } - auto rhs_dilations_attr = rewriter.getNamedAttr( - "rhs_dilation", GetI64ElementsAttr(rhs_dilations, &rewriter)); - - auto window_strides_attr = rewriter.getNamedAttr( - "window_strides", GetI64ElementsAttr(window_strides, &rewriter)); - - auto dimension_numbers_attr = GetConvDimensionNumbersAttr( - spatial_dim_indices, data_format, &rewriter); - - const int64_t input_channels = - GetDimSize(input_ty, GetTensorFeatureDimIndex(num_dims, data_format)); - // Filters data_format is always HWIO so input channels dimension is after - // all spatial dimensions. - const int64_t filter_channels = GetDimSize(filter_ty, num_spatial_dims); - // TensorFlow convolution op verifies that the number of input channels is - // divisible by the number of filter channels. - // For depthwise convolution the feature_group_count argument would be set - // to the input feature dimension. - const int64_t feature_group_count = - depthwise_conv ? input_channels : input_channels / filter_channels; - auto feature_group_count_attr = rewriter.getNamedAttr( - "feature_group_count", rewriter.getI64IntegerAttr(feature_group_count)); - - auto batch_group_count_attr = rewriter.getNamedAttr( - "batch_group_count", rewriter.getI64IntegerAttr(1)); - - auto precision_config_attr = rewriter.getNamedAttr( - "precision_config", GetPrecisionConfig(&rewriter)); - - Value paddings_op = rewriter.create( - op.getLoc(), - tensorflow::GetTypeFromTFTensorShape(2 * num_spatial_dims, - rewriter.getI32Type()), - paddings); - - SmallVector operands(op.getOperands()); - operands.push_back(paddings_op); - // Reshape the filter to {spatial_dims...., 1,in_channels * - // channel_multiplier} - if (depthwise_conv) { - ArrayRef filter_shape = filter_ty.getShape(); - llvm::SmallVector new_shape( - filter_shape.begin(), filter_shape.begin() + num_spatial_dims); - new_shape.push_back(1); - new_shape.push_back(filter_shape[num_spatial_dims] * - filter_shape[num_spatial_dims + 1]); - operands[1] = rewriter.create( - op.getLoc(), - tensorflow::GetTypeFromTFTensorShape(new_shape, - filter_ty.getElementType()), - operands[1]); - } - NamedAttribute attrs[] = {rhs_dilations_attr, window_strides_attr, - dimension_numbers_attr, feature_group_count_attr, - batch_group_count_attr, precision_config_attr}; - rewriter.replaceOpWithNewOp(op, op.getType(), operands, - llvm::ArrayRef(attrs)); - return success(); - } - - LogicalResult matchAndRewrite(OpT op, - PatternRewriter &rewriter) const override { - return matchAndRewriteDynamicConv(op, rewriter); - } -}; - -using ConvertConv2DDynamic = - ConvertConvDynamic; - -// Converts the TensorFlow conv op in template to the generic HLO conv op by -// converting TensorFlow op attributes to HLO op attributes. -// -// Sample result for Conv2D: -// -// %conv = "mhlo.convolution"(%input, %filter) { -// strides = [1, 2], -// paddings = [[1, 0], [1, 1]], -// ... -// } -// -// This pattern is not defined using declarative rewrite rules as computation of -// the paddings attribute anyway requires multiple source op attributes and -// result op attributes. Defining it as declarative rewrite rule will introduce -// some duplication in the C++ helper methods. -template -class ConvertConvOp : public OpRewritePattern { - public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(OpTy op, - PatternRewriter &rewriter) const override { - tensorflow::TensorFormat data_format; - if (!FormatFromString(op.getDataFormat().str(), &data_format)) - return op.emitOpError("invalid data format"); - - tensorflow::Padding padding; - if (!GetPaddingFromString(op.getPadding().str(), &padding).ok()) - return failure(); - - auto input_ty = mlir::dyn_cast(op.getInput().getType()); - auto filter_ty = mlir::dyn_cast(op.getFilter().getType()); - - // With the exception of input's batch dimension, input and filter need to - // have static shape for calculation of HLO paddings and feature group count - // attributes. Filter is validated here, input is mostly validated at use. - if (!input_ty || !filter_ty || !filter_ty.hasStaticShape()) - return failure(); - - ArrayRef dilations = op.getDilations().getValue(); - ArrayRef strides = op.getStrides().getValue(); - ArrayRef explicit_paddings; - if (padding == tensorflow::Padding::EXPLICIT) { - // EXPLICIT padding mode and the associated attribute is limited to - // Conv2D. So, fetch attribute by identifier instead of the - // op.explicit_paddings() attribute getter. - explicit_paddings = - op->template getAttrOfType("explicit_paddings").getValue(); - } - - SmallVector spatial_dim_indices; - SmallVector rhs_dilations; - SmallVector window_strides; - SmallVector paddings; - - auto get_int = [](Attribute attr) { - return mlir::cast(attr).getInt(); - }; - - constexpr int num_dims = num_spatial_dims + 2; - for (auto i : llvm::seq(0, num_spatial_dims)) { - const int64_t dim = GetTensorSpatialDimIndex(num_dims, data_format, i); - spatial_dim_indices.push_back(dim); - - const int64_t dilation = get_int(dilations[dim]); - rhs_dilations.push_back(dilation); - const int64_t stride = get_int(strides[dim]); - window_strides.push_back(stride); - - int64_t pad_low, pad_high; - if (padding == tensorflow::Padding::EXPLICIT) { - pad_low = get_int(explicit_paddings[2 * dim]); - pad_high = get_int(explicit_paddings[2 * dim + 1]); - } else { - int64_t output_size; - int64_t pad_low_int64; - int64_t pad_high_int64; - int64_t input_size = input_ty.getDimSize(dim); - if (input_size == ShapedType::kDynamic) return failure(); - absl::Status status = tensorflow::GetWindowedOutputSizeVerbose( - input_size, filter_ty.getDimSize(i), dilation, stride, padding, - &output_size, &pad_low_int64, &pad_high_int64); - if (!status.ok()) return failure(); - pad_low = pad_low_int64; - pad_high = pad_high_int64; - } - paddings.push_back(pad_low); - paddings.push_back(pad_high); - } - - auto rhs_dilations_attr = rewriter.getNamedAttr( - "rhs_dilation", GetI64ElementsAttr(rhs_dilations, &rewriter)); - - auto window_strides_attr = rewriter.getNamedAttr( - "window_strides", GetI64ElementsAttr(window_strides, &rewriter)); - - auto dimension_numbers_attr = GetConvDimensionNumbersAttr( - spatial_dim_indices, data_format, &rewriter); - - const int64_t input_channels = - GetDimSize(input_ty, GetTensorFeatureDimIndex(num_dims, data_format)); - if (input_channels == ShapedType::kDynamic) return failure(); - // Filters data_format is always HWIO so input channels dimension is after - // all spatial dimensions. - const int64_t filter_channels = GetDimSize(filter_ty, num_spatial_dims); - // TensorFlow convolution op verifies that the number of input channels is - // divisible by the number of filter channels. - // For depthwise convolution the feature_group_count argument would be set - // to the input feature dimension. - const int64_t feature_group_count = - depthwise_conv ? input_channels : input_channels / filter_channels; - auto feature_group_count_attr = rewriter.getNamedAttr( - "feature_group_count", rewriter.getI64IntegerAttr(feature_group_count)); - - auto batch_group_count_attr = rewriter.getNamedAttr( - "batch_group_count", rewriter.getI64IntegerAttr(1)); - - RankedTensorType paddings_ty = tensorflow::GetTypeFromTFTensorShape( - {num_spatial_dims, 2}, rewriter.getIntegerType(64)); - auto paddings_attr = rewriter.getNamedAttr( - "padding", DenseElementsAttr::get(paddings_ty, paddings)); - - auto precision_config_attr = rewriter.getNamedAttr( - "precision_config", GetPrecisionConfig(&rewriter)); - - SmallVector operands(op.getOperands()); - // Reshape the filter to {spatial_dims...., 1,in_channels * - // channel_multiplier} - if (depthwise_conv) { - ArrayRef filter_shape = filter_ty.getShape(); - llvm::SmallVector new_shape( - filter_shape.begin(), filter_shape.begin() + num_spatial_dims); - new_shape.push_back(1); - new_shape.push_back(filter_shape[num_spatial_dims] * - filter_shape[num_spatial_dims + 1]); - operands[1] = rewriter.create( - op.getLoc(), - tensorflow::GetTypeFromTFTensorShape(new_shape, - filter_ty.getElementType()), - operands[1]); - } - NamedAttribute attrs[] = {rhs_dilations_attr, window_strides_attr, - dimension_numbers_attr, feature_group_count_attr, - batch_group_count_attr, paddings_attr, - precision_config_attr}; - rewriter.replaceOpWithNewOp(op, op.getType(), operands, - llvm::ArrayRef(attrs)); - return success(); - } -}; - -using ConvertConv2DOp = ConvertConvOp; -using ConvertConv3DOp = ConvertConvOp; -using ConvertDepthConv2DOp = - ConvertConvOp; - -// Converts tf.PadV2Op to mhlo.DynamicPadOp. Padding values must be const. -class ConvertPadOpDynamic : public OpRewritePattern { - public: - using OpRewritePattern::OpRewritePattern; - // TODO(disc): To recover static special case's performance with folding and - // canonicalization. - LogicalResult matchAndRewrite(TF::PadV2Op op, - PatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - auto input = op.getInput(); - auto paddings = op.getPaddings(); - auto constant_values = op.getConstantValues(); - auto input_type = mlir::dyn_cast(input.getType()); - auto paddings_type = mlir::dyn_cast(paddings.getType()); - if (!input_type || !paddings_type || !paddings_type.hasStaticShape()) - return failure(); - - // TODO(disc): Remove this constraint once fold and canonicalization is - // implemented. - if (input_type.hasStaticShape()) return failure(); - - int input_rank = input_type.getRank(); - // interior padding - std::vector interior_values(input_rank, 0); - auto interior_attr = GetI64ElementsAttr(interior_values, &rewriter); - - Value interior_padding_tensor = - rewriter.create(loc, interior_attr); - Type paddings_elem_ty = paddings_type.getElementType(); - if (!paddings_elem_ty.isInteger(64)) { - interior_padding_tensor = rewriter.create( - loc, interior_padding_tensor, paddings_elem_ty); - } - llvm::SmallVector transposed_shape = {2, input_rank}; - auto transpose_attr = GetI64ElementsAttr({1, 0}, &rewriter); - Value transposed_paddings = - rewriter.create(loc, paddings, transpose_attr); - Value reshaped_paddings = rewriter.create( - loc, - tensorflow::GetTypeFromTFTensorShape({input_rank * 2}, - paddings_elem_ty), - transposed_paddings); - - auto left_padding_start_attr = GetI64ElementsAttr({0}, &rewriter); - auto left_padding_limit_attr = GetI64ElementsAttr({input_rank}, &rewriter); - auto left_padding_stride_attr = GetI64ElementsAttr({1}, &rewriter); - Value left_padding_tensor = rewriter.create( - loc, reshaped_paddings, left_padding_start_attr, - left_padding_limit_attr, left_padding_stride_attr); - - auto right_padding_start_attr = GetI64ElementsAttr({input_rank}, &rewriter); - auto right_padding_limit_attr = - GetI64ElementsAttr({2 * input_rank}, &rewriter); - auto right_padding_stride_attr = GetI64ElementsAttr({1}, &rewriter); - Value right_padding_tensor = rewriter.create( - loc, reshaped_paddings, right_padding_start_attr, - right_padding_limit_attr, right_padding_stride_attr); - - rewriter.replaceOpWithNewOp( - op, op.getType(), input, constant_values, left_padding_tensor, - right_padding_tensor, interior_padding_tensor); - - return success(); - } -}; - -class ConvertGatherNdOpDynamic : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - // Converts tf.GatherNdOp to mhlo.DynamicGatherOp. - // Here we leave 'slice_sizes' as an Attr, without defining a new - // DynamicGatherOp, since GatherDimensionNumbers has already provide enough - // information for shape inference and code generation of mhlo::GatherOp. '?' - // will be filled into slice_sizes for dimensions that are dynamic sized. - // TODO(disc): To recover static special case's performance with folding and - // canonicalization. - LogicalResult matchAndRewrite(TF::GatherNdOp op, - PatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - auto params = op.getParams(); - auto params_ty = mlir::dyn_cast(params.getType()); - auto indices = op.getIndices(); - auto indices_ty = mlir::dyn_cast(indices.getType()); - auto params_rank = params_ty.getRank(); - auto indices_rank = indices_ty.getRank(); - int64_t num_index_dims = indices_ty.getDimSize(indices_rank - 1); - if (!params_ty || !indices_ty) return failure(); - // the last dim of indices of GatherNdOp must be fixed shaped - if (num_index_dims == ShapedType::kDynamic) return failure(); - - SmallVector slice_sizes; - slice_sizes.reserve(params_rank); - for (int64_t i = 0; i < params_rank; ++i) { - if (i < num_index_dims) { - slice_sizes.push_back(1); - } else { - // potentially dynamic - int64_t dim_size = params_ty.getDimSize(i); - slice_sizes.push_back(dim_size); - } - } - SmallVector slice_sizes_vals; - Value slice_sizes_value = nullptr; - for (int64_t i = 0; i < params_rank; ++i) { - if (i < num_index_dims) { - slice_sizes_vals.push_back(rewriter.create( - loc, rewriter.getIntegerAttr(indices_ty.getElementType(), 1))); - } else { - int64_t dim_size = params_ty.getDimSize(i); - if (dim_size != ShapedType::kDynamic) { - slice_sizes_vals.push_back(rewriter.create( - loc, - rewriter.getIntegerAttr(indices_ty.getElementType(), dim_size))); - } else { - slice_sizes_vals.push_back(rewriter.create( - loc, indices_ty.getElementType(), - rewriter.create(loc, params, i))); - } - } - } - slice_sizes_value = - rewriter.create(loc, slice_sizes_vals); - - // collapsed_slice_dims - SmallVector collapsed_slice_dims; - collapsed_slice_dims.reserve(num_index_dims); - for (int64_t i = 0; i < num_index_dims; ++i) { - collapsed_slice_dims.push_back(i); - } - // offset_dims - SmallVector offset_dims; - offset_dims.reserve(params_rank - num_index_dims); - for (int64_t i = num_index_dims; i < params_rank; i++) { - offset_dims.push_back(i + indices_rank - 1 - num_index_dims); - } - // start_index_map - SmallVector start_index_map; - offset_dims.reserve(num_index_dims); - for (int64_t i = 0; i < num_index_dims; i++) { - start_index_map.push_back(i); - } - // index_vector_dim - int64_t index_vector_dim = indices_rank - 1; - - auto dims_attr = GatherDimensionNumbersAttr::get( - rewriter.getContext(), offset_dims, collapsed_slice_dims, - /*operandBatchingDims=*/{}, - /*startIndicesBatchingDims=*/{}, start_index_map, index_vector_dim); - // TODO(disc): Remove this if-statement once fold and canonicalization is - // implemented. - if (params_ty.hasStaticShape() && indices_ty.hasStaticShape()) { - rewriter.replaceOpWithNewOp( - op, op.getType(), op.getParams(), op.getIndices(), dims_attr, - GetI64ElementsAttr(slice_sizes, &rewriter)); - } else { - rewriter.replaceOpWithNewOp( - op, op.getType(), op.getParams(), op.getIndices(), slice_sizes_value, - dims_attr); - } - return success(); - } -}; - -// Converts BF16 FloorDiv op to have casting operators on either end as BF16 -// division can result in strange behavior. -// -// floordiv = cast(floordiv(cast(left), cast(right)))) -// -// %left_cast = cast(%left) -// %right_cast = cast(%right) -// %div = div(%left, %left) -// %floored = floor(%div) -// %floored_cast = cast(%floored) -// -// Required to manually specify the intermediate types. -class ConvertBF16FloorDivOp : public OpRewritePattern { - public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(TF::FloorDivOp op, - PatternRewriter &rewriter) const override { - auto l = mlir::dyn_cast>(op.getX()); - auto r = mlir::dyn_cast>(op.getY()); - if (!l || !r) return failure(); - - auto element_type = getElementTypeOrSelf(l.getType()); - if (!element_type.isBF16()) return failure(); - - auto out_type = op.getZ().getType(); - - l = rewriter.create(op.getLoc(), l, rewriter.getF32Type()); - r = rewriter.create(op.getLoc(), r, rewriter.getF32Type()); - - auto intermediate = rewriter.create( - op.getLoc(), - ChangeTensorElementType(&rewriter, out_type, rewriter.getF32Type()), l, - r); - - auto floor_op = - rewriter.create(op.getLoc(), out_type, intermediate); - rewriter.replaceOp(op, floor_op.getResult()); - return success(); - } -}; - -class ConvertBroadcastToOp : public OpRewritePattern { - public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(TF::BroadcastToOp op, - PatternRewriter &rewriter) const override { - auto input_type = mlir::dyn_cast(op.getInput().getType()); - auto output_type = op.getOutput().getType(); - if (!input_type) { - return rewriter.notifyMatchFailure(op, "requires ranked input shape"); - } - llvm::SmallVector broadcast_dimensions; - if (input_type.getRank() > 0) { - auto ranked_output_type = mlir::dyn_cast(output_type); - if (!ranked_output_type) { - return rewriter.notifyMatchFailure(op, "requires ranked output shape"); - } - auto rank_diff = ranked_output_type.getRank() - input_type.getRank(); - // The tf.BroadcastTo op performs "right-aligned" numpy-style - // broadcasting. - broadcast_dimensions = llvm::to_vector<4>( - llvm::seq(rank_diff, ranked_output_type.getRank())); - } - rewriter.replaceOpWithNewOp( - op, output_type, op.getInput(), op.getShape(), - rewriter.getI64TensorAttr(broadcast_dimensions)); - return success(); - } -}; - -/// Converts a TF::RollOp to HLO. Only support 0D axis and shift case, and axis -/// have to be a constant. -class ConvertRollOp : public OpRewritePattern { - public: - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(TF::RollOp op, - PatternRewriter &rewriter) const override { - auto shift_ty = mlir::dyn_cast(op.getShift().getType()); - if (!shift_ty || shift_ty.getRank() != 0) { - return rewriter.notifyMatchFailure( - op, "require the type of shift to be 0D tensor"); - } - - APInt val; - if (!matchPattern(op.getAxis(), m_ConstantInt(&val))) { - return rewriter.notifyMatchFailure(op, "require axis to be constant"); - } - int axis = val.getSExtValue(); - - auto input_ty = mlir::dyn_cast(op.getInput().getType()); - if (!input_ty || !input_ty.hasStaticShape()) { - return rewriter.notifyMatchFailure( - op, "require the type of input to have static shapes"); - } - ArrayRef input_shape = input_ty.getShape(); - int input_rank = input_ty.getRank(); - if (axis < 0) axis += input_rank; - - // Adjust large offsets into [0, axis_size). This also makes negative - // offsets positive. - // offset = ((offset % axis_size) + axis_size) % axis_size - ImplicitLocOpBuilder b(op.getLoc(), rewriter); - Value offset = op.getShift(); - auto axis_size = b.create(b.getIntegerAttr( - getElementTypeOrSelf(offset.getType()), input_shape[axis])); - offset = b.create( - b.create(b.create(offset, axis_size), axis_size), - axis_size); - - // Stack two copies of the dimension, then slice from the calculated - // offset. This also works if shift is not constant. - // DynamicSliceOp requires the sizes being integer, and we can get the - // information from input shape. - auto concat = b.create( - ValueRange{op.getInput(), op.getInput()}, b.getI64IntegerAttr(axis)); - Value zero = b.create( - b.getIntegerAttr(getElementTypeOrSelf(offset.getType()), 0)); - SmallVector slice_begin_indices(input_rank, zero); - slice_begin_indices[axis] = b.create(axis_size, offset); - rewriter.replaceOpWithNewOp( - op, input_ty, concat, slice_begin_indices, - rewriter.getI64TensorAttr(input_shape)); - return success(); - } -}; - -/// Converts a TF::LeakyReluOp to HLO. -/// LeakyRelu(x) = alpha * x if x < 0 else x. -class ConvertLeakyReluOp : public OpRewritePattern { - public: - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(TF::LeakyReluOp op, - PatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - Value features = op.getFeatures(); - - // Use ConstantLike for `alpha` to match the shape of feature. - auto alphaVal = chlo::getConstantLike( - rewriter, loc, op.getAlpha().convertToFloat(), features); - Value zeroVal = chlo::getConstantLike(rewriter, loc, 0.0, features); - - Value leakyActivationVal = - rewriter.create(loc, features, alphaVal); - - Value compareGtZero = rewriter.create( - loc, features, zeroVal, ComparisonDirection::GT); - - rewriter.replaceOpWithNewOp(op, compareGtZero, features, - leakyActivationVal); - return success(); - } -}; - -/// Converts a TF::LeakyReluGradOp to HLO. -/// LeakyReluGrad(gradient, inputs) = gradient if input > 0 -/// else alpha * gradient. -class ConvertLeakyReluGradOp : public OpRewritePattern { - public: - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(TF::LeakyReluGradOp op, - PatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - Value gradients = op.getGradients(); - Value features = op.getFeatures(); - auto featureType = features.getType(); - - // Use ConstantLike for `alpha` to match the shape of feature. - auto alphaVal = chlo::getConstantLike( - rewriter, loc, op.getAlpha().convertToFloat(), features); - Value zeroVal = chlo::getConstantLike(rewriter, loc, 0.0, features); - - Value leakyGradientVal = - rewriter.create(loc, gradients, alphaVal); - - Value compareGtZero = rewriter.create( - loc, features, zeroVal, ComparisonDirection::GT); - - rewriter.replaceOpWithNewOp(op, featureType, compareGtZero, - gradients, leakyGradientVal); - return success(); - } -}; - -// Converts TensorFlow DiagPartOp to HLO ops using reduction on masked matrix. -// For a Rank-2 input, it creates the following ops: -// %1 = "mhlo.iota"() {iota_dimension = 0 : i64} -// %2 = "mhlo.iota"() {iota_dimension = 1 : i64} -// %3 = "mhlo.compare"(%1, %2) {comparison_direction = "EQ"} -// %4 = mhlo.constant dense<0.000000e+00> : tensor -// %5 = "mhlo.broadcast"(%4) -// %6 = "mhlo.select"(%3, %input, %5) -// %7 = "mhlo.reduce"(%6, %4) ({ -// ^bb0(%arg1: tensor, %arg2: tensor): -// %9 = mhlo.add %arg1, %arg2 : tensor -// "mhlo.return"(%9) : (tensor) -> () -// }) {dimensions = dense<0> : tensor<1xi64>} -// -// If the input's rank N is greater than 2, we will reshape it to R2 first and -// create the above ops, then reshape it back to rank N/2. -class ConvertDiagPartOp : public OpRewritePattern { - public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(TF::DiagPartOp op, - PatternRewriter &rewriter) const override { - auto input_type = mlir::dyn_cast(op.getInput().getType()); - if (!input_type || !input_type.hasStaticShape()) return failure(); - int64_t num_dims = input_type.getRank(); - if (num_dims < 2 || num_dims % 2 != 0) return failure(); - const int64_t out_dims = num_dims / 2; - - int64_t new_size = 1; - llvm::SmallVector new_dims; - for (int i = 0; i < out_dims; i++) { - if (input_type.getDimSize(i) != input_type.getDimSize(i + out_dims)) - return op.emitOpError("invalid dimensions size"); - new_size *= input_type.getDimSize(i); - new_dims.push_back(input_type.getDimSize(i)); - } - Value reshaped_input = rewriter.create( - op.getLoc(), - tensorflow::GetTypeFromTFTensorShape({new_size, new_size}, - input_type.getElementType()), - op.getInput()); - auto iota_type = tensorflow::GetTypeFromTFTensorShape( - {new_size, new_size}, rewriter.getIntegerType(32)); - auto iota0 = rewriter.create(op.getLoc(), iota_type, - rewriter.getI64IntegerAttr(0)); - auto iota1 = rewriter.create(op.getLoc(), iota_type, - rewriter.getI64IntegerAttr(1)); - Value compare = rewriter.create(op.getLoc(), iota0, iota1, - ComparisonDirection::EQ); - Value zero = GetScalarConstOfType(input_type.getElementType(), op.getLoc(), - 0, &rewriter); - Value zero_matrix = rewriter.create( - op.getLoc(), reshaped_input.getType(), zero, - GetI64ElementsAttr({new_size, new_size}, &rewriter)); - Value masked = - rewriter.create(op.getLoc(), reshaped_input.getType(), - compare, reshaped_input, zero_matrix); - auto reduce = rewriter.create(op.getLoc(), masked, zero, - GetI64ElementsAttr({0}, &rewriter), - input_type.getElementType()); - assert(!input_type.getElementType().isInteger(1) && - "data type should not be i1"); - BuildReduceBody(input_type.getElementType(), &reduce.getBody(), - &rewriter); - rewriter.replaceOpWithNewOp( - op, - tensorflow::GetTypeFromTFTensorShape(new_dims, - input_type.getElementType()), - reduce.getResult(0)); - return success(); - } -}; - -// Converts TensorFlow MatrixDiagPartOp to HLO ops. -class ConvertMatrixDiagPartV3Op - : public OpRewritePattern { - using Shape = llvm::SmallVector; - - // Parse the "k" parameter. MatrixDiagPartV3 allows to specify the diagonal(s) - // with k. This can be either a single value (for a single diagonal) or a - // tuple of two values (starting and ending diagonal, for a band). - LogicalResult ExtractK(TF::MatrixDiagPartV3Op op, int64_t (*k)[2]) const { - DenseIntElementsAttr kattr; - if (!matchPattern(op.getK(), m_Constant(&kattr))) { - return failure(); - } - DenseIntElementsAttr::iterator it = kattr.begin(); - (*k)[0] = (*it).getSExtValue(); - it++; - if (it == kattr.end()) { - // Handle input like e.g. "k = 5", in which case we extract a single - // diagonal. - (*k)[1] = (*k)[0]; - } else { - // Handle input like e.g. "k = [-1, 1]", in which case we extract a - // band (multiple diagonals). - (*k)[1] = (*it).getSExtValue(); - } - return success(); - } - - // Utility method for broadcasting integer constants to a given shape. - BroadcastOp BroadcastConstant(Location loc, Shape shape, int32_t constant, - int int_size, PatternRewriter &rewriter) const { - return rewriter.create( - loc, - tensorflow::GetTypeFromTFTensorShape(shape, - rewriter.getIntegerType(int_size)), - GetScalarConstOfType(rewriter.getIntegerType(int_size), loc, constant, - &rewriter), - GetI64ElementsAttr(shape, &rewriter)); - } - - public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(TF::MatrixDiagPartV3Op op, - PatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - ShapedType input_type = mlir::dyn_cast(op.getInput().getType()); - - // Align is a string specifying how superdiagonals and subdiagonals should - // be aligned/padded for diagonals that are shorter than max_diag_len. The - // format is "{super}_{sub}", with {super} the superdiagonal alignment and - // {sub} the subdiagonal alignment. "LEFT" means rows will be padded to the - // left, "RIGHT" means rows will be padded ot the right. The default is - // "RIGHT_LEFT". - StringRef align = op->getAttrOfType("align").getValue(); - enum Alignment { kLeft, kRight }; - - // default is RIGHT_LEFT - Alignment superdiagonal_align = kRight; - Alignment subdiagonal_align = kLeft; - - if (align == "RIGHT_LEFT") { - superdiagonal_align = kRight; - subdiagonal_align = kLeft; - } else if (align == "RIGHT_RIGHT") { - superdiagonal_align = kRight; - subdiagonal_align = kRight; - } else if (align == "LEFT_RIGHT") { - superdiagonal_align = kLeft; - subdiagonal_align = kRight; - } else if (align == "LEFT_LEFT") { - superdiagonal_align = kLeft; - subdiagonal_align = kLeft; - } else { - return failure(); // unsupported alignment - } - - // MatrixDiagPart operates on a matrix of shape [I, J, ..., L, M, N], and - // will extract the diagonal(s) out of [M, N], for all [I, J, ..., L]. - if (!input_type || !input_type.hasStaticShape()) return failure(); - int64_t num_dims = input_type.getRank(); - if (num_dims < 2) return failure(); - int64_t rows = input_type.getDimSize(num_dims - 2); // rows - int64_t cols = input_type.getDimSize(num_dims - 1); // cols - - // We extract the diagonals from k[0] up to and including k[1]. - // Addressing is 0 for the main diagonal. (So k = [0, 0] would just extract - // the main diagonal). It's negative for subdiagonals (under and to the left - // of the main diagonal) and positive for superdiagonals (above and to the - // right of the main diagonal). - int64_t k[2]; - if (failed(ExtractK(op, &k))) return failure(); - int num_diags = k[1] - k[0] + 1; - - // Shifting diagonals away from the main diagonal might shorten them. This - // is the longest diagonal we will see. We make this the last dimension of - // the output shape. - int64_t max_diag_len = - std::min(rows + std::min(k[1], static_cast(0)), - cols + std::min(-k[0], static_cast(0))); - - // The first dimension is the index vector dimension we'll use for gather. - // It's 1 here, but will be 2 once we glue x and y together. - Shape indices_shape({1, num_diags, max_diag_len}); - - RankedTensorType iota_type = tensorflow::GetTypeFromTFTensorShape( - indices_shape, rewriter.getIntegerType(32)); - Value iotaM = - rewriter.create(loc, iota_type, rewriter.getI64IntegerAttr(1)); - Value iotaN = - rewriter.create(loc, iota_type, rewriter.getI64IntegerAttr(2)); - - // Boradcasted constants, of the same shape as iotaM and iotaN. - Value b_zero = BroadcastConstant(loc, indices_shape, 0, 32, rewriter); - Value b_false = BroadcastConstant(loc, indices_shape, 0, 1, rewriter); - Value b_true = BroadcastConstant(loc, indices_shape, 1, 1, rewriter); - Value b_k1 = BroadcastConstant(loc, indices_shape, k[1], 32, rewriter); - Value b_rows = BroadcastConstant(loc, indices_shape, rows, 32, rewriter); - Value b_cols = BroadcastConstant(loc, indices_shape, cols, 32, rewriter); - Value b_max_diag_len = - BroadcastConstant(loc, indices_shape, max_diag_len, 32, rewriter); - - // d = k[1] - m - // (A.k.a. the number of the diagonal, depending on m. Note that we - // subtract m here. This means we start with the superdiagonals and - // move downwards towards the subdiagonals. So the start indices will - // be decreasing.) - Value d = rewriter.create(loc, b_k1, iotaM); - Value neg_d = rewriter.create(loc, d); - - // diag_len_d = min(rows + min(d, 0), cols - max(d, 0)) - // (Length of a diagonal for a given d. Same as max_diag_len for m = 0.) - Value diag_len_d = rewriter.create( - loc, - rewriter.create(loc, b_rows, - rewriter.create(loc, d, b_zero)), - rewriter.create(loc, b_cols, - rewriter.create(loc, d, b_zero))); - - // offset is max_diag_len - diag_len_d if we're padding, 0 otherwise. - Value cmp; - if (subdiagonal_align == kRight && superdiagonal_align == kRight) { - cmp = b_true; - } else if (superdiagonal_align == kRight) { - // offset = d>=0 ? max_diag_len - diag_len_d : 0 - cmp = rewriter.create(loc, d, b_zero); - } else if (subdiagonal_align == kRight) { - // offset = d<=0 ? max_diag_len - diag_len_d : 0 - cmp = rewriter.create(loc, d, b_zero); - } else { - // offset = 0 - cmp = b_false; - } - - // This offset shifts the diagonals to the "left" or "right", depending - // on alignment. - Value offset = rewriter.create( - loc, b_zero.getType(), cmp, - rewriter.create(loc, b_max_diag_len, diag_len_d), b_zero); - - // x = max(d, 0) - offset - // y = max(-d, 0) - offset - Value x = rewriter.create( - loc, rewriter.create(loc, d, b_zero), offset); - Value y = rewriter.create( - loc, rewriter.create(loc, neg_d, b_zero), offset); - - Value n_plus_x = rewriter.create(loc, iotaN, x); - Value n_plus_y = rewriter.create(loc, iotaN, y); - - // GatherOp is happy about letting us index out of bounds values, but those - // values will be undefined. So we mask them later. Set up the boolean - // expression that tells us which entries, in the output shape, are out of - // bounds and thus become the padding_value. - Value x_in_bounds = rewriter.create( - loc, - rewriter.create(loc, b_false.getType(), n_plus_x, - b_zero), - rewriter.create(loc, b_false.getType(), n_plus_x, b_cols)); - Value y_in_bounds = rewriter.create( - loc, - rewriter.create(loc, b_false.getType(), n_plus_y, - b_zero), - rewriter.create(loc, b_false.getType(), n_plus_y, b_rows)); - Value in_bounds = rewriter.create( - loc, - tensorflow::GetTypeFromTFTensorShape(Shape({num_diags, max_diag_len}), - rewriter.getIntegerType(1)), - rewriter.create(loc, x_in_bounds, y_in_bounds)); - - // Now combine x and y into the index data structure needed for gather. - Shape concat_shape({2, num_diags, max_diag_len}); - Value start_indices = rewriter.create( - loc, - tensorflow::GetTypeFromTFTensorShape(concat_shape, - rewriter.getIntegerType(32)), - mlir::ValueRange({n_plus_y, n_plus_x}), - mlir::IntegerAttr::get(rewriter.getIntegerType(64), 0)); - - // Shape of the final output. (Except for dimension folding in the - // single diagonal case.) - Shape output_shape; - for (int i = 0; i < num_dims - 2; i++) { - output_shape.push_back(input_type.getDimSize(i)); - } - output_shape.push_back(num_diags); - output_shape.push_back(max_diag_len); - - // A slice is the shape of what GatherOp copies per lookup. So the last - // two dimensions (M, N in the matrix-diag-part docs) are where we go - // through entry by entry. - ArrayRef input_shape = input_type.getShape(); - int input_shape_size = input_shape.size(); - Shape slice_sizes(input_shape.begin(), input_shape.end()); - int slice_dimensions = slice_sizes.size(); - slice_sizes[slice_dimensions - 2] = - std::min((int64_t)1, input_shape[input_shape_size - 2]); - slice_sizes[slice_dimensions - 1] = - std::min((int64_t)1, input_shape[input_shape_size - 1]); - - // Dimensions of the input we won't see in the output (M and N). - SmallVector collapsed_dims( - {slice_dimensions - 2, slice_dimensions - 1}); - - // Which dimensions (in the input) the two offset "columns" map to. - SmallVector start_index_map({num_dims - 2, num_dims - 1}); - - // Gather the diagonal entries. - // TODO(kramm): For a single diagonal, this might be slower than the - // mask + sum approach. Special-case num_diags==1? - auto dims_attr = GatherDimensionNumbersAttr::get( - rewriter.getContext(), - /*offsetDims=*/llvm::to_vector<4>(llvm::seq(0, num_dims - 2)), - /*collapsedSliceDims=*/collapsed_dims, - /*operandBatchingDims=*/{}, - /*startIndicesBatchingDims=*/{}, start_index_map, - /*indexVectorDim=*/0); - Value gather = rewriter.create( - loc, op.getInput(), start_indices, dims_attr, - GetI64ElementsAttr(slice_sizes, &rewriter)); - - // We now need to broadcast the "in_bounds" boolean expression, as well as - // the padding value, to do the final select. - Shape broadcast_bounds; - for (int i = 0; i < output_shape.size() - 2; i++) { - broadcast_bounds.push_back(output_shape[i]); - } - Value b_in_bounds = rewriter.create( - loc, - tensorflow::GetTypeFromTFTensorShape(output_shape, - rewriter.getIntegerType(1)), - in_bounds, GetI64ElementsAttr(broadcast_bounds, &rewriter)); - Value b_padding = rewriter.create( - loc, op.getPaddingValue(), GetI64ElementsAttr(output_shape, &rewriter)); - - // Replace all out-of-bounds values in the result with padding_value. - Value result = - rewriter.create(loc, b_in_bounds, gather, b_padding); - - if (num_diags == 1) { - // matrix_diag_part folds away the 1-sized band dimension if we only - // extract a single diagonal. - result = rewriter.create(loc, op.getType(), result); - } - - rewriter.replaceOp(op, result); - return success(); - } -}; - -// Converts TensorFlow EinsumOp to HLO EinsumOp -class ConvertEinsumOp : public OpRewritePattern { - public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(TF::EinsumOp op, - PatternRewriter &rewriter) const override { - // Prepend `,` to equation if unary einsum. - std::string equation_str = op.getEquation().str(); - llvm::SmallVector inputs; - - // Unary einsum prepends `,` to equation and - // creates a scalar constant 1.0 for first operand. - if (op.getN() == 1) { - equation_str = "," + equation_str; - inputs.push_back(rewriter.create( - op.getLoc(), hlo::getScalarOfType( - mlir::getElementTypeOrSelf(op.getOperand(0)), 1))); - } - // Insert remaining operands into inputs, TF op verifier requires there be - // 0 or 1 operands. - auto operands = op.getInputs(); - inputs.insert(inputs.end(), operands.begin(), operands.end()); - assert(inputs.size() == 2); - - rewriter.replaceOpWithNewOp(op, op.getType(), inputs[0], - inputs[1], equation_str); - return success(); - } -}; - -// Bypasses IdentityN op. -class ConvertIdentityNOp : public OpRewritePattern { - public: - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(TF::IdentityNOp op, - PatternRewriter &rewriter) const override { - rewriter.replaceOp(op, op.getOperands()); - return success(); - } -}; - -template -class ConvertFFTOp : public OpRewritePattern { - public: - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(OpTy op, - PatternRewriter &rewriter) const override { - auto input_ty = mlir::cast(op.getInput().getType()); - if (!input_ty.hasRank()) { - return failure(); - } - auto input_shape = input_ty.getShape(); - DenseIntElementsAttr fft_length_attr; - if (!matchPattern(op.getFftLength(), m_Constant(&fft_length_attr))) { - return failure(); - } - int64_t fft_length; - if (fft_length_attr.getNumElements() != 0) { - fft_length = fft_length_attr.getValues()[0].getInt(); - } else { - return failure(); - } - - int64_t expected_dim = fft_length; - std::string fft_string = "RFFT"; - if (typeid(OpTy) == typeid(TF::IRFFTOp)) { - expected_dim = fft_length / 2 + 1; - fft_string = "IRFFT"; - } - Location loc = op.getLoc(); - - // The inner-most dim cannot be dynamic. - if (input_ty.isDynamicDim(input_shape.size() - 1)) { - return failure(); - } - - auto expected_shape = llvm::to_vector<4>(input_shape.drop_back()); - expected_shape.push_back(expected_dim); - - // Zero pad or truncate the last axis - Value reshaped = op.getInput(); - SmallVector begin_indices(input_shape.size(), 0); - SmallVector strides(input_shape.size(), 1); - - // Last dim larger than expected_dim, slice the input - if (input_shape.back() > expected_dim) { - reshaped = rewriter.create( - op.getLoc(), - tensorflow::GetTypeFromTFTensorShape(expected_shape, - input_ty.getElementType()), - op.getInput(), GetI64ElementsAttr(begin_indices, &rewriter), - GetI64ElementsAttr(expected_shape, &rewriter), - GetI64ElementsAttr(strides, &rewriter)); - - // Last dim smaller than expected_dim, zero-pad the input - } else if (input_ty.getShape().back() < expected_dim) { - SmallVector no_padding(input_shape.size(), 0); - SmallVector padding(input_shape.size() - 1, 0); - padding.push_back(expected_dim - input_shape.back()); - Value zero = - GetScalarConstOfType(input_ty.getElementType(), loc, 0, &rewriter); - reshaped = rewriter.create( - loc, - tensorflow::GetTypeFromTFTensorShape(expected_shape, - input_ty.getElementType()), - op.getInput(), zero, GetI64ElementsAttr(no_padding, &rewriter), - GetI64ElementsAttr(padding, &rewriter), - GetI64ElementsAttr(no_padding, &rewriter)); - } - - rewriter.replaceOpWithNewOp( - op, op.getType(), reshaped, - FftTypeAttr::get(rewriter.getContext(), - symbolizeFftType(fft_string).value()), - rewriter.getI64TensorAttr(fft_length)); - return success(); - } -}; - -using ConvertRFFTOp = ConvertFFTOp; -using ConvertIRFFTOp = ConvertFFTOp; - -// The base class to convert TensorFlow FusedBatchNormGrad*Op to HLO -// BatchNormGradOp for training and a sequence of binary ops for inference. -// TODO(b/145536565): move to legalize_tf_patterns.td if it applies. -template -class ConvertFusedBatchNormGradBase - : public OpRewritePattern { - public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(FusedBatchNormGradOpT op, - PatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - Value grad = op.getYBackprop(); - Value act = op.getX(); - Value scale = op.getScale(); - Value mean = op.getReserveSpace_1(); - Value var = op.getReserveSpace_2(); - - // TODO(b/141785544): Update this to not require static shapes. - // activation shape needs to be static to convert negative indices in - // TensorFlow to absolute indices required by HLO. - RankedTensorType act_type = mlir::dyn_cast(act.getType()); - if (!act_type) return failure(); - Type act_ele_type = act_type.getElementType(); - // To support mixed precision, the statistics type, which maybe more - // precise than the input types, are used for this op. - Type kernel_type = mlir::cast(scale.getType()).getElementType(); - grad = rewriter.create(loc, grad, kernel_type); - act = rewriter.create(loc, act, kernel_type); - - tensorflow::TensorFormat data_format; - if (!FormatFromString(op.getDataFormat().str(), &data_format)) - return op.emitOpError("invalid data format"); - - auto feature_dim_attr = getFeatureDimensionAttr(rewriter, data_format, act); - auto feature_dim = feature_dim_attr.getValue().getSExtValue(); - - // Gets the result values. - Value x_backprop, scale_backprop, offset_backprop; - if (op.getIsTraining()) { // training - // TODO(b/145536565): handle GPU logic separately. - // Infers the output type with the converted `act`. - Type feature_type = tensorflow::GetTypeFromTFTensorShape( - {GetDimSize(act_type, feature_dim)}, kernel_type); - - SmallVector operand_types = {act.getType(), feature_type, - feature_type}; - auto training_op = rewriter.create( - loc, operand_types, act, scale, mean, var, grad, op.getEpsilon(), - feature_dim); - - x_backprop = training_op.getResult(0); - - scale_backprop = training_op.getResult(1); - - offset_backprop = training_op.getResult(2); - } else { // inference - SmallVector non_feature_dims; - for (int64_t i = 0; i < act_type.getRank(); ++i) { - if (i == feature_dim) continue; - non_feature_dims.push_back(i); - } - auto reduce_dims = GetI64ElementsAttr(non_feature_dims, &rewriter); - auto scalar_broadcast_dims = rewriter.getDenseI64ArrayAttr({}); - - // scratch1 = rsqrt(var + epsilon) - RankedTensorType scalar_float = - tensorflow::GetTypeFromTFTensorShape({}, kernel_type); - auto epsilon = rewriter.create( - loc, DenseFPElementsAttr::get(scalar_float, {op.getEpsilon()})); - auto add_op = rewriter.create( - loc, var, epsilon.getResult(), scalar_broadcast_dims); - - Value scratch1 = rewriter.create(loc, add_op); - - // scratch2 = sum(y_backprop * (x - mean)) - auto sub_op = rewriter.create( - loc, act, - Broadcast1DToFeatureDim(loc, act, mean, feature_dim, rewriter)); - auto weighted_grad = rewriter.create(loc, grad, sub_op); - Value scratch2 = - ApplyReduction(loc, weighted_grad, reduce_dims, &rewriter); - - // x_backprop = y_backprop * (scale * scratch1) - auto scaled_grad = - rewriter.create(loc, op.getScale(), scratch1); - x_backprop = rewriter.create( - loc, grad, - Broadcast1DToFeatureDim(loc, act, scaled_grad, feature_dim, - rewriter)); - - // scale_backprop = scratch2 * scratch1 - scale_backprop = rewriter.create(loc, scratch1, scratch2); - - // offset_backprop = sum(y_backprop) - offset_backprop = ApplyReduction(loc, grad, reduce_dims, &rewriter); - } - - x_backprop = rewriter.create(loc, x_backprop, act_ele_type); - Value last_val[2]; - if (op.getResult(3).use_empty() && op.getResult(4).use_empty()) { - // It doesn't matter what values we provide for the last 2 results. - last_val[0] = last_val[1] = op.getX(); - } else { - auto const_val = rewriter.create( - op.getLoc(), DenseElementsAttr::get( - tensorflow::GetTypeFromTFTensorShape( - {0}, getElementTypeOrSelf(op.getResult(3))), - 0.0)); - auto maybe_cast = [&](Value val, Type t) -> Value { - if (val.getType() == t) return val; - return rewriter.create(op.getLoc(), t, val); - }; - last_val[0] = maybe_cast(const_val, op.getResult(3).getType()); - last_val[1] = maybe_cast(const_val, op.getResult(4).getType()); - } - rewriter.replaceOp( - op, {/*x_backprop=*/x_backprop, - /*scale_backprop=*/scale_backprop, - /*offset_backprop=*/offset_backprop, last_val[0], last_val[1]}); - return success(); - } -}; - -using ConvertFusedBatchNormGradOp = - ConvertFusedBatchNormGradBase; -using ConvertFusedBatchNormGradV2Op = - ConvertFusedBatchNormGradBase; -using ConvertFusedBatchNormGradV3Op = - ConvertFusedBatchNormGradBase; - -// Converts TensorFlow FusedBatchNormV3Op to either HLO BatchNormTrainingOp or -// HLO BatchNormInferenceOp, depending on the value of the 'is_training' -// parameter. -template -class ConvertFusedBatchNormBase : public OpRewritePattern { - public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(FusedBatchNormOpT op, - PatternRewriter &rewriter) const override { - tensorflow::TensorFormat data_format; - if (!FormatFromString(op.getDataFormat().str(), &data_format)) - return op.emitOpError("invalid data format"); - - auto feature_dim = - getFeatureDimensionAttr(rewriter, data_format, op.getX()); - - auto input_type_tensor = mlir::cast(op.getX().getType()); - auto input_element_type = input_type_tensor.getElementType(); - - auto scale_type_tensor = mlir::cast(op.getScale().getType()); - auto scale_element_type = scale_type_tensor.getElementType(); - - auto mean_type_tensor = mlir::cast(op.getMean().getType()); - auto mean_element_type = mean_type_tensor.getElementType(); - // In the training case, dimensions of input tensors must be static. - if (op.getIsTraining() && (!input_type_tensor.hasStaticShape() || - !scale_type_tensor.hasStaticShape() || - !mean_type_tensor.hasStaticShape())) - return failure(); - - // TODO(b/69928690): Support mixed precision in the XLA batch - // normalization operators. As a workaround, create a new x with the same - // element type as scale (which may be more precise than the input type). - Value bn_train_input = rewriter.create( - op.getLoc(), op.getX(), scale_element_type); - TensorType bn_train_input_type_tensor = - mlir::cast(bn_train_input.getType()); - - if (op.getIsTraining()) { - // Training case. - auto operand_shape = bn_train_input_type_tensor.getShape(); - // The mean and variance are each 1 dimensional arrays the size of the - // feature dimension, with the same element type as the operand (x). - // This shape must be constructed manually because the mean and variance - // inputs are empty in the training case. - Type mean_var_type = tensorflow::GetTypeFromTFTensorShape( - {operand_shape[feature_dim.getInt()]}, scale_element_type); - // Op result type is a tuple of 3 values: output with same shape as input; - // batch_mean, and batch_var. - SmallVector operand_types = {bn_train_input_type_tensor, - mean_var_type, mean_var_type}; - auto bn_train_op = rewriter.create( - op.getLoc(), operand_types, bn_train_input, op.getScale(), - op.getOffset(), op.getEpsilon(), feature_dim.getInt()); - // HLO op outputs a tuple of tensors. Extract those results. - Value y_out = bn_train_op.getResult(0); - Value batch_mean = bn_train_op.getResult(1); - Value reserve_space_1 = batch_mean; - Value batch_variance = bn_train_op.getResult(2); - - // Apply Bessel's correction on the variance. - int total_input_size = bn_train_input_type_tensor.getNumElements(); - int total_scale_size = scale_type_tensor.getNumElements(); - int sample_size = - total_scale_size > 0 ? total_input_size / total_scale_size : 0; - int sample_size_minus_one = std::max(1, sample_size - 1); - double factor = static_cast(sample_size) / - static_cast(sample_size_minus_one); - auto factor_const_op = rewriter.create( - op.getLoc(), rewriter.getFloatAttr(scale_element_type, factor)); - - Value corrected_variance = rewriter.create( - op.getLoc(), batch_variance.getType(), batch_variance, - factor_const_op, /*broadcast_dimensions=*/DenseI64ArrayAttr()); - - // Convert back to input type to stay aligned with expected output type - // for TF op. - y_out = rewriter.create(op.getLoc(), y_out, - input_element_type); - - float exponential_avg_factor = - op.getExponentialAvgFactor().convertToFloat(); - if (exponential_avg_factor != 1.0f) { - auto alpha = rewriter.create( - op.getLoc(), rewriter.getFloatAttr(mean_element_type, - 1.0f - exponential_avg_factor)); - auto beta = rewriter.create( - op.getLoc(), - rewriter.getFloatAttr(mean_element_type, exponential_avg_factor)); - - // new_running_mean = alpha * old_mean + beta * batch_mean. - auto alpha_mul_old_mean = rewriter.create( - op.getLoc(), op.getMean().getType(), alpha, op.getMean(), - /*broadcast_dimensions=*/DenseI64ArrayAttr()); - auto beta_mul_batch_mean = rewriter.create( - op.getLoc(), batch_mean.getType(), beta, batch_mean, - /*broadcast_dimensions=*/DenseI64ArrayAttr()); - batch_mean = rewriter.create( - op.getLoc(), alpha_mul_old_mean, beta_mul_batch_mean, - /*broadcast_dimensions=*/DenseI64ArrayAttr()); - - // new_running_variance = alpha * old_variance + beta * batch_variance. - auto alpha_mul_old_variance = rewriter.create( - op.getLoc(), op.getVariance().getType(), alpha, op.getVariance(), - /*broadcast_dimensions=*/DenseI64ArrayAttr()); - auto beta_mul_batch_variance = rewriter.create( - op.getLoc(), corrected_variance.getType(), beta, corrected_variance, - /*broadcast_dimensions=*/DenseI64ArrayAttr()); - corrected_variance = rewriter.create( - op.getLoc(), alpha_mul_old_variance, beta_mul_batch_variance, - /*broadcast_dimensions=*/DenseI64ArrayAttr()); - } - - if (std::is_same::value) { - // FusedBatchNormV2 expects 4 outputs. - // Outputs 3 and 4 are currently marked as "reserved spaces 1 and 2". - // They are used to pass the per-batch mean and variance to the - // gradiant. Here we maintain the same behavior by setting them to the - // mean and variance calculated by BatchNormTraining. - rewriter.replaceOp(op, {y_out, /*batch_mean=*/batch_mean, - /*batch_variance=*/corrected_variance, - /*reserve_space_1=*/reserve_space_1, - /*reserve_space_2=*/batch_variance}); - } else { // TF::FusedBatchNormV3Op - // For FusedBatchNormV3Op, also create a constant tensor to forward to - // last reserve_space_3 output. - auto reserve_space_3_type = - mlir::cast(op.getResult(5).getType()); - int num_elements = reserve_space_3_type.hasStaticShape() - ? reserve_space_3_type.getNumElements() - : 0; - auto const_attr_type = tensorflow::GetTypeFromTFTensorShape( - {num_elements}, getElementTypeOrSelf(reserve_space_3_type)); - Value dummy_const = rewriter.create( - op.getLoc(), DenseElementsAttr::get(const_attr_type, 0.0)); - if (const_attr_type != reserve_space_3_type) - dummy_const = rewriter.create( - op.getLoc(), reserve_space_3_type, dummy_const); - rewriter.replaceOp(op, {y_out, /*batch_mean=*/batch_mean, - /*batch_variance=*/corrected_variance, - /*reserve_space_1=*/reserve_space_1, - /*reserve_space_2=*/batch_variance, - /*reserve_space_3=*/dummy_const}); - } - } else { // Inference case. - auto bn_train_op = rewriter.create( - op.getLoc(), - /*result_type=*/bn_train_input_type_tensor, bn_train_input, - op.getScale(), op.getOffset(), op.getMean(), op.getVariance(), - op.getEpsilon(), feature_dim.getInt()); - - // Convert back to input type to stay aligned with expected output type - // for TF op. - auto y_out = rewriter.create(op.getLoc(), bn_train_op, - input_element_type); - - // The mean, variance, and reserved space outputs of the batch norm op are - // not used for inference. It doesn't matter what values we provide for - // the last 5 results as long as they are of the same type. Forward - // input mean and variance to output mean, variance, reserved_space_1 and - // reserved_space_2. - if (std::is_same::value) { - rewriter.replaceOp(op, {/*y=*/y_out, - /*batch_mean=*/op.getMean(), - /*batch_variance=*/op.getVariance(), - /*reserve_space_1=*/op.getMean(), - /*reserve_space_2=*/op.getVariance()}); - } else { - // For FusedBatchNormV3Op, also create a constant tensor to forward to - // last reserve_space_3 output. - auto reserve_space_3_type = - mlir::cast(op.getResult(5).getType()); - int num_elements = reserve_space_3_type.hasStaticShape() - ? reserve_space_3_type.getNumElements() - : 0; - auto const_attr_type = tensorflow::GetTypeFromTFTensorShape( - {num_elements}, getElementTypeOrSelf(reserve_space_3_type)); - Value dummy_const = rewriter.create( - op.getLoc(), DenseElementsAttr::get(const_attr_type, 0.0)); - if (const_attr_type != reserve_space_3_type) - dummy_const = rewriter.create( - op.getLoc(), reserve_space_3_type, dummy_const); - rewriter.replaceOp(op, {/*y=*/y_out, - /*batch_mean=*/op.getMean(), - /*batch_variance=*/op.getVariance(), - /*reserve_space_1=*/op.getMean(), - /*reserve_space_2=*/op.getVariance(), - /*reserve_space_3=*/dummy_const}); - } - } - return success(); - } -}; - -using ConvertFusedBatchNormV2Op = - ConvertFusedBatchNormBase; -using ConvertFusedBatchNormV3Op = - ConvertFusedBatchNormBase; - -using PaddingArray = std::vector>; - -// Returns padding values for ReduceWindow op as a vector of pairs. -// -// Requires padding to be either 'SAME' or 'VALID' and the number of input -// dimensions to be equal to the size of window dimensions and window strides. -template -static PaddingArray GetReduceWindowPaddingAsArray( - llvm::ArrayRef input_dims, ArrayAttr window_dims, - ArrayAttr window_strides, StringRef padding, Builder *builder) { - if (padding == "VALID") { - return PaddingArray(num_dims, std::make_pair(0, 0)); - } - assert(padding == "SAME"); - llvm::SmallVector input_shape, window_shape, strides; - input_shape.reserve(input_dims.size()); - window_shape.reserve(window_shape.size()); - strides.reserve(window_strides.size()); - - for (const auto &dim : input_dims) input_shape.push_back(dim); - for (Attribute attr : window_dims) - window_shape.push_back(mlir::cast(attr).getInt()); - for (Attribute attr : window_strides) - strides.push_back(mlir::cast(attr).getInt()); - - PaddingArray paddings = ::xla::MakePadding(input_shape, window_shape, strides, - ::xla::Padding::kSame); - return paddings; -} - -// Same as GetReduceWindowPaddingAsArray but returns padding as -// DenseIntElementsAttr. Returns empty attribute for `VALID` padding. -template -static DenseIntElementsAttr GetReduceWindowPaddingAsAttr( - llvm::ArrayRef input_dims, ArrayAttr window_dims, - ArrayAttr window_strides, StringRef padding, Builder *builder) { - if (padding == "VALID") return {}; - assert(padding == "SAME"); - PaddingArray paddings = GetReduceWindowPaddingAsArray( - input_dims, window_dims, window_strides, padding, builder); - int64_t rank = paddings.size(); - llvm::SmallVector flatten_paddings(rank * 2); - for (int i = 0; i < rank; i++) { - flatten_paddings[2 * i] = paddings[i].first; - flatten_paddings[2 * i + 1] = paddings[i].second; - } - return DenseIntElementsAttr::get(tensorflow::GetTypeFromTFTensorShape( - {rank, 2}, builder->getIntegerType(64)), - flatten_paddings); -} - -// Helper function for dividing each entry of `pooled` by the count of its -// corresponding window, i.e., the number of non-padding entries of the window -// which an `AvgPool` operation performed on an `input_shape`-tensor would map -// to this entry, depending on `ksize` and `strides`. This function is used for -// `AvgPool` and `AvgPoolGrad` legalizations. -// `zero` is passed as a parameter because it can be reused from caller level. -// `pooled` must have `RankedTensorType`. -template -Operation *AvgPoolDivideByCount( - Value pooled, const SmallVector &input_shape, - const SmallVector &ksize, - const SmallVector &strides, OpTy op, Value zero, - PatternRewriter &rewriter) { - Location loc = op.getLoc(); - RankedTensorType pooled_type = mlir::cast(pooled.getType()); - Type element_type = pooled_type.getElementType(); - Operation *result = nullptr; - RankedTensorType orig_input_type = - tensorflow::GetTypeFromTFTensorShape(input_shape, element_type); - - if (op.getPadding() == "VALID") { - // All window counts are equal here because we don't have padding - // (each entry of `pooled` corresponds to a window that consists of - // original input entries only). - int64_t window_count = std::accumulate(ksize.begin(), ksize.end(), 1, - std::multiplies()); - // Divide `pooled` by window counts. - Value divisor = - GetScalarConstOfType(element_type, loc, window_count, &rewriter); - auto scalar_broadcast_dims = rewriter.getDenseI64ArrayAttr({}); - result = rewriter.create( - loc, pooled_type, pooled, divisor, scalar_broadcast_dims); - } else { - assert(op.getPadding() == "SAME"); - // For SAME padding, only original entries that contributed to a window - // are counted for the average of this window, not padded entries. - - // Build all-ones tensor of same shape as the original input. - ElementsAttr splat = hlo::getSplat(&rewriter, orig_input_type, 1); - auto all_ones_tensor = rewriter.create(loc, splat); - - // Get padding for the input. - DenseIntElementsAttr input_padding_attr = - GetReduceWindowPaddingAsAttr(input_shape, op.getKsize(), - op.getStrides(), op.getPadding(), - &rewriter); - - // Count the 1's in each window, using the same padding as for the input, - // which gives us the window counts by which `pooled` needs to be divided. - auto divisor = rewriter.create( - loc, pooled_type, - /*operand=*/all_ones_tensor, - /*init_value=*/zero, - /*window_dimensions=*/GetI64ElementsAttr(op.getKsize()), - /*window_strides=*/GetI64ElementsAttr(op.getStrides()), - /*base_dilations=*/DenseIntElementsAttr(), - /*window_dilations=*/DenseIntElementsAttr(), - /*padding=*/input_padding_attr); - BuildReduceBody(element_type, &divisor.getBody(), &rewriter); - - // Divide `pooled` by window counts. - result = rewriter.create(loc, pooled_type, pooled, - divisor.getResult(0)); - } - return result; -} - -Value GetAvgPoolInput(TF::AvgPoolOp op) { return op.getValue(); } -Value GetAvgPoolInput(TF::AvgPool3DOp op) { return op.getInput(); } - -// Converts AvgPool op to HLO ReduceWindow op by setting appropriate window -// dimensions with add as the reduction function. The reduction result is -// then divided by the number of elements in the window. -template -class ConvertAvgPoolOp : public OpRewritePattern { - public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(OpTy op, - PatternRewriter &rewriter) const override { - Value input_value = GetAvgPoolInput(op); - auto input_type = mlir::dyn_cast(input_value.getType()); - if (!input_type) return failure(); - - // We will do accumulation first; use a larger bitwidth if suitable. - Type input_element_type = input_type.getElementType(); - Type sum_element_type = GetSumAccumulationType(input_element_type); - Type result_type; - - // The result type for reduction and division with the proper element type. - if (auto ranked_type = mlir::dyn_cast(op.getType())) - result_type = tensorflow::GetTypeFromTFTensorShape(ranked_type.getShape(), - sum_element_type); - else - result_type = UnrankedTensorType::get(sum_element_type); - - // Convert if we need enlarge the element type's bitwidth. - if (input_element_type != sum_element_type) - input_value = rewriter.create(op.getLoc(), input_value, - sum_element_type); - - // Create the ReduceWindow op. - Value init = - GetScalarConstOfType(sum_element_type, op.getLoc(), 0, &rewriter); - DenseIntElementsAttr paddings_attr = GetReduceWindowPaddingAsAttr( - input_type.getShape(), op.getKsize(), op.getStrides(), op.getPadding(), - &rewriter); - auto reduce = rewriter.create( - op.getLoc(), result_type, input_value, init, - GetI64ElementsAttr(op.getKsize()), GetI64ElementsAttr(op.getStrides()), - /*base_dilations=*/DenseIntElementsAttr(), - /*window_dilations=*/DenseIntElementsAttr(), paddings_attr); - BuildReduceBody(sum_element_type, &reduce.getBody(), &rewriter); - - // Count the number of elements in the window. The following calculation - // is only valid for no paddings. - SmallVector input_shape( - llvm::to_vector(input_type.getShape())); - SmallVector ksize, strides; - GetI64ArrayAttrValues(op.getKsize(), &ksize); - GetI64ArrayAttrValues(op.getStrides(), &strides); - - Operation *result_op = AvgPoolDivideByCount( - reduce.getResult(0), input_shape, ksize, strides, op, init, rewriter); - - // Convert back if we enlarged the element type's bitwidth. - Value result = result_op->getOpResult(0); - if (input_element_type != sum_element_type) - result = - rewriter.create(op.getLoc(), result, input_element_type); - - rewriter.replaceOp(op, result); - return success(); - } -}; - -using ConvertAvgPool2DOp = ConvertAvgPoolOp; -using ConvertAvgPool3DOp = ConvertAvgPoolOp; - -// `AvgPoolGradOp` is converted to the following operations: -// 1. Divide each entry of the output gradient (the gradient for the previous -// layer in backpropagation order) by the count of the corresponding window -// (i.e., the number of non-padding entries of the window which `AvgPool` -// has mapped to this entry in forward propagation). -// 2. Add appropriate interior and exterior padding for step 3 (see example -// below). -// 3. Convolve the result of step 2. with a kernel consisting of 1's (same shape -// as windows) and stride 1 in each dimension. This is implemented as a -// `ReduceWindowOp` with `AddOp` as body. -// -// Example: -// Let f : R^4 -> R^2 be an average pool function with window size 3, stride 2, -// and SAME padding with 0's. It is defined by -// f(x) = [ (x_1 + x_2 + x_3) / 3 ] ( x = (x_1, x_2, x_3, x_4) ) -// [ (x_3 + x_4 + 0) / 2 ] (the 0 results from right padding) -// Note that for SAME padding in `AvgPool` the padded entries are not counted -// for the average, this is why the second denominator is 2 and not 3. -// The Jacobian Df is -// [ 1/3 1/3 1/3 0 ] -// [ 0 0 1/2 1/2 ] -// -// Note that the Jacobian is constant (this is why `ConvertAvgPoolGradOp` only -// needs the original input shape and not the tensor as argument). -// Let v = [ 4 6 ]^T be the output gradient (^T = transposed). Then the -// average pool gradient is given by -// Df^T * v = [ 4/3 4/3 13/3 3 ]^T -// Instead of a matrix-vector-multiplication we can utilize the sparsity and -// structure of Df by using the 3-step approach from above: -// 1. Divide output gradient v by window counts: [ 4/3 6/2 ]^T -// 2. Add appropriate padding: [ 0 0 4/3 0 3 0 ]^T -// 3. Convolve with kernel [ 1 1 1 ]: [ 4/3 4/3 11/3 3 ]^T -// -// Note that the padding in step 2. is chosen in such a way that the subsequent -// convolution produces the gradient. Higher dimensions, different padding, and -// different windows/strides work in a similar way, the main difference is in -// the computation of the paddings in step 2. -// -// For more details on backpropagation for convolution of which `AvgPoolGrad` -// is a special case see `tensorflow/core/kernels/conv_grad_ops.h`. -// `tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir` has more -// examples for different cases. -template -class ConvertAvgPoolGradOp : public OpRewritePattern { - using DimVector = SmallVector; - - public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(OpTy op, - PatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - tensorflow::TensorFormat data_format; - if (!FormatFromString(op.getDataFormat().str(), &data_format)) { - return op.emitOpError("invalid data format"); - } - // `out_grad` is the gradient that was propagated via backpropagation from - // the output layer. - Value out_grad = op.getGrad(); - auto out_grad_type = mlir::dyn_cast(out_grad.getType()); - if (!out_grad_type) { - return failure(); - } - Type element_type = out_grad_type.getElementType(); - DenseIntElementsAttr orig_input_shape_attr; - if (!matchPattern(op.getOrigInputShape(), - m_Constant(&orig_input_shape_attr))) { - return failure(); - } - auto orig_input_shape_values = orig_input_shape_attr.getValues(); - DimVector orig_input_shape(orig_input_shape_values.begin(), - orig_input_shape_values.end()); - DimVector ksize, strides; - GetI64ArrayAttrValues(op.getKsize(), &ksize); - GetI64ArrayAttrValues(op.getStrides(), &strides); - Value zero = GetScalarConstOfType(element_type, loc, 0, &rewriter); - - auto out_grad_divided = AvgPoolDivideByCount( - out_grad, orig_input_shape, ksize, strides, op, zero, rewriter); - - // Get same padding as for original input. - PaddingArray orig_padding = GetReduceWindowPaddingAsArray( - orig_input_shape, op.getKsize(), op.getStrides(), op.getPadding(), - &rewriter); - - // Add padding around `out_grad_divided` values in such a way that the - // subsequent `ReduceWindowOp` produces the gradient. - DimVector out_grad_shape( - llvm::to_vector(out_grad_type.getShape())); - DimVector low_padding(num_dims, 0); - DimVector high_padding(num_dims, 0); - DimVector interior_padding(num_dims, 0); - constexpr int num_spatial_dims = num_dims - 2; - for (int i = 0; i < num_spatial_dims; ++i) { - int dim = tensorflow::GetTensorSpatialDimIndex(num_dims, data_format, i); - int orig_input_shape_padded_in_dim = orig_input_shape[dim] + - orig_padding[dim].first + - orig_padding[dim].second; - // Set interior padding such that neighboring entries from - // `out_grad_divided` have distance `strides[dim]` from each other in - // every dimension. - interior_padding[dim] = strides[dim] - 1; - // Set exterior padding in the same way as for convolution gradient - // computation. - auto status = ::xla::ConvGradExtractAndVerifyDimension( - /*input_size=*/orig_input_shape_padded_in_dim, - /*filter_size=*/ksize[dim], - /*output_size=*/out_grad_shape[dim], - /*dilation=*/1, - /*stride=*/strides[dim], - /*padding=*/::xla::Padding::kValid); - if (!status.ok()) { - return failure(); - } - ::xla::SpatialDimensionOutputSizeAndPadding &conv_grad_spatial_dim = - status.value(); - // Subtract the original exterior padding since it doesn't contribute to - // the gradient. Note that we save one `PadOp` and some unnecessary kernel - // computations, compared to the `xla::AvgPoolGrad` implementation, by - // subtracting the original exterior padding before `ReduceWindowOp` - // instead of trimming the result of `ReduceWindowOp` (the final result is - // the same because all strides are 1). - low_padding[dim] = - conv_grad_spatial_dim.pad_before - orig_padding[dim].first; - high_padding[dim] = - conv_grad_spatial_dim.pad_after - orig_padding[dim].second; - - // Update `out_grad_shape` to result shape of following `PadOp`. - out_grad_shape[dim] = low_padding[dim] + high_padding[dim] + - (out_grad_shape[dim] - 1) * strides[dim] + 1; - } - Value reduce_window_input = rewriter.create( - loc, tensorflow::GetTypeFromTFTensorShape(out_grad_shape, element_type), - /*operand=*/out_grad_divided->getOpResult(0), - /*padding_value=*/zero, - /*edge_padding_low=*/GetI64ElementsAttr(low_padding, &rewriter), - /*edge_padding_high=*/GetI64ElementsAttr(high_padding, &rewriter), - /*interior_padding=*/GetI64ElementsAttr(interior_padding, &rewriter)); - - // Compute result by convolving `reduce_window_input` with an all-ones - // kernel, using `ReduceWindowOp` with `AddOp` body. - - Type sum_element_type = GetSumAccumulationType(element_type); - if (element_type != sum_element_type) { - // Convert to appropriate sum accumulation type to avoid precision loss. - reduce_window_input = rewriter.create(loc, reduce_window_input, - sum_element_type); - zero = GetScalarConstOfType(sum_element_type, loc, 0, &rewriter); - } - auto ones = GetI64ElementsAttr(DimVector(num_dims, 1), &rewriter); - auto reduce_window_op = rewriter.create( - loc, - tensorflow::GetTypeFromTFTensorShape(orig_input_shape, - sum_element_type), - /*operand=*/reduce_window_input, - /*init_value=*/zero, - /*window_dimensions=*/GetI64ElementsAttr(op.getKsize()), - /*window_strides=*/ones, - /*base_dilations=*/DenseIntElementsAttr(), - /*window_dilations=*/DenseIntElementsAttr(), - /*padding=*/DenseIntElementsAttr()); - BuildReduceBody(sum_element_type, &reduce_window_op.getBody(), - &rewriter); - Value result = reduce_window_op.getResult(0); - - if (element_type != sum_element_type) { - // Convert back to original element type. - result = rewriter.create(op.getLoc(), result, element_type); - } - rewriter.replaceOp(op, {result}); - return success(); - } -}; - -using ConvertAvgPool2DGradOp = - ConvertAvgPoolGradOp; -using ConvertAvgPool3DGradOp = - ConvertAvgPoolGradOp; - -// Converts MaxPool op to HLO ReduceWindow op by setting appropriate window -// dimensions with max as the reduction function. -// -// Sample result for VALID padding mode: -// -// %init = arith.constant dense<...> : tensor -// %max_pool = "mhlo.reduce"(%inp, %init) ["mhlo.maximum"] -// {window_dimensions = ..., window_strides = ... } -// -template -class ConvertMaxPoolOp : public OpRewritePattern { - public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(OpTy op, - PatternRewriter &rewriter) const override { - Type element_type = - mlir::cast(op.getInput().getType()).getElementType(); - if (!element_type.isSignlessIntOrFloat()) return failure(); - tensorflow::Padding padding; - if (!GetPaddingFromString(op.getPadding().str(), &padding).ok()) - return failure(); - if (padding == tensorflow::Padding::EXPLICIT) { - return failure(); - } - Location loc = op.getLoc(); - ConstantOp init = GetScalarLimitConstOfType( - element_type, loc, hlo::kInfinityLowest, &rewriter); - - auto input_ty = mlir::dyn_cast(op.getInput().getType()); - if (!input_ty) return failure(); - DenseIntElementsAttr paddings_attr = GetReduceWindowPaddingAsAttr( - input_ty.getShape(), op.getKsize(), op.getStrides(), op.getPadding(), - &rewriter); - auto reduce = rewriter.create( - loc, op.getType(), op.getInput(), init, - GetI64ElementsAttr(op.getKsize()), GetI64ElementsAttr(op.getStrides()), - /*base_dilations=*/DenseIntElementsAttr(), - /*window_dilations=*/DenseIntElementsAttr(), paddings_attr); - BuildReduceBody(element_type, &reduce.getBody(), &rewriter); - - rewriter.replaceOp(op, reduce.getResult(0)); - return success(); - } -}; - -using ConvertMaxPool2DOp = ConvertMaxPoolOp; -using ConvertMaxPool3DOp = ConvertMaxPoolOp; - -// Converts tf.Select (SelectV1) to mhlo.select. It has optional broadcasting on -// the condition only. -class ConvertSelectOp : public OpRewritePattern { - public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(TF::SelectOp op, - PatternRewriter &rewriter) const override { - // This lowering only works on ranked types. - auto cond_type = - mlir::dyn_cast(op.getCondition().getType()); - auto then_type = - mlir::dyn_cast(op.getThenValue().getType()); - auto else_type = - mlir::dyn_cast(op.getElseValue().getType()); - if (!cond_type || !then_type || !else_type) { - return failure(); - } - - ImplicitLocOpBuilder b(op.getLoc(), rewriter); - Value cond_shape = b.createOrFold(op.getCondition()); - Value then_shape = b.createOrFold(op.getThenValue()); - Value else_shape = b.createOrFold(op.getElseValue()); - - // First check that the `then` and `else` shapes are the equal. - Value assumption = - b.createOrFold(ValueRange{then_shape, else_shape}); - // For a vector cond we also verify that the majormost dim of `then` matches - // the vector size. To do that split off the first dim of `then`. - bool needs_broadcast = cond_type.getRank() == 1 && then_type.getRank() != 1; - Value then_shape_split = then_shape; - if (needs_broadcast) { - Value const_one = b.create(1); - Type extent_first = shape::getExtentTensorType(b.getContext(), 1); - Type extent_second = - shape::getExtentTensorType(b.getContext(), then_type.getRank() - 1); - SmallVector then_split; - b.createOrFold(then_split, - TypeRange{extent_first, extent_second}, - then_shape, const_one); - then_shape_split = then_split[0]; - } - // If the condition is not a scalar, check that it matches the other shapes. - if (cond_type.getRank() > 0) { - Value eq_cstr = b.createOrFold( - ValueRange{cond_shape, then_shape_split}); - auto witness = shape::WitnessType::get(b.getContext()); - assumption = b.createOrFold( - witness, ValueRange{assumption, eq_cstr}); - } - auto result_type = mlir::cast(op.getResult().getType()); - auto assuming_op = - b.create(ArrayRef{result_type}, assumption); - - OpBuilder::InsertionGuard guard(b); - b.createBlock(&assuming_op.getDoRegion()); - - // Broadcast the cond if necessary. - Value cond = op.getCondition(); - if (needs_broadcast) { - Value result_extents = b.create( - GetExtentsTensorTypeFor(result_type), then_shape); - cond = b.create( - tensorflow::GetTypeFromTFTensorShape(result_type.getShape(), - b.getI1Type()), - cond, result_extents, - GetI64ElementsAttrForSeq(0, cond_type.getRank(), &b)); - } - Value select = b.create( - result_type, cond, op.getThenValue(), op.getElseValue()); - b.create(select); - rewriter.replaceOp(op, {assuming_op.getResult(0)}); - return success(); - } -}; - -// Converts the tf.Slice op into mhlo.real_dynamic_slice -// TODO(disc): To recover static special case's performance with folding and -// canonicalization. -class ConvertSliceOpDynamic : public OpRewritePattern { - public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(TF::SliceOp op, - PatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - Value input = op.getInput(); - Value begin_indices = op.getBegin(); - Value sizes = op.getSize(); - - auto input_ty = mlir::dyn_cast(input.getType()); - auto begin_type = mlir::dyn_cast(begin_indices.getType()); - auto size_type = mlir::dyn_cast(sizes.getType()); - - if (!input_ty || !begin_type || !size_type || - !begin_type.hasStaticShape() || !size_type.hasStaticShape() || - begin_type.getRank() != 1 || size_type.getRank() != 1) { - return failure(); - } - // TODO(disc): remove static shape check once folding/canonicalization func - // added - DenseIntElementsAttr size_attr; - if (matchPattern(op.getSize(), m_Constant(&size_attr))) { - return failure(); - } - - int rank = begin_type.getDimSize(0); - auto shape_scalar_type = begin_type.getElementType(); - Value one = rewriter.create(loc, 1); - SmallVector stride_values(rank, one); - SmallVector end_values; - SmallVector begin_values; - end_values.reserve(rank); - for (int i = 0; i < rank; ++i) { - SmallVector indices; - indices.push_back(rewriter.create(loc, i)); - auto begin_value = - rewriter.create(loc, begin_indices, indices); - auto size_value = rewriter.create(loc, sizes, indices); - Value minus_one = rewriter.create( - loc, shape_scalar_type, - rewriter.create(loc, -1)); - auto is_minus_one = rewriter.create( - loc, arith::CmpIPredicate::eq, size_value, minus_one); - Value end_value = - rewriter.create(loc, begin_value, size_value); - auto dim_value = rewriter.create( - loc, shape_scalar_type, - rewriter.create(loc, input, i)); - end_value = rewriter.create(loc, is_minus_one, - dim_value, end_value); - auto end_value_casted = rewriter.create( - loc, rewriter.getIndexType(), end_value); - end_values.push_back(end_value_casted); - - auto begin_value_casted = rewriter.create( - loc, rewriter.getIndexType(), begin_value); - begin_values.push_back(begin_value_casted); - } - auto index_ty = rewriter.getIndexType(); - auto start_indices = rewriter.create( - loc, - tensorflow::GetTypeFromTFTensorShape( - {static_cast(begin_values.size())}, index_ty), - begin_values); - auto end_indices = rewriter.create( - loc, - tensorflow::GetTypeFromTFTensorShape( - {static_cast(end_values.size())}, index_ty), - end_values); - auto stride_indices = rewriter.create( - loc, - tensorflow::GetTypeFromTFTensorShape( - {static_cast(stride_values.size())}, index_ty), - stride_values); - - auto d_slice = rewriter.create( - loc, op.getOperation()->getResult(0).getType(), input, start_indices, - end_indices, stride_indices); - rewriter.replaceOp(op, d_slice.getOperation()->getResults()); - return success(); - } -}; - -static void BroadcastBatchMatMulV2Operands(Value lhs, Value rhs, Location loc, - Value *out_lhs, Value *out_rhs, - PatternRewriter *rewriter) { - // The dimension structure of the relevant operands to a tf.BatchMatMulV2 is: - // - lhs: [LHSBATCHDIMS..., LHSROWS, LHSCOLS] - // - rhs: [RHSBATCHDIMS..., RHSROWS, RHSCOLS] - // - result: [broadcast(LHSBATCHDIMS, RHSBATCHDIMS)..., LHSROWS, RHSCOLS] - // To perform the matmul, we need to first broadcast lhs and rhs to a common - // set of leading dimensions before doing the actual matmul. - // That's what the code below does. - // In particular, we populate out_lhs and out_rhs to have dimension structure: - // - out_lhs: [broadcast(LHSBATCHDIMS, RHSBATCHDIMS)..., LHSROWS, LHSCOLS] - // - out_rhs: [broadcast(LHSBATCHDIMS, RHSBATCHDIMS)..., RHSROWS, RHSCOLS] - // To do this, we need to calculate those output shapes, which involves - // slicing off the leading batch dims of each operand, broadcasting them, - // then concatenating the broadcasted leading dims back to the row/col dims. - // Finally, we create a TF::BroadcastTo op that does the actual broadcast. - - // TODO(silvasean): Reduce duplication across reified shape calculations and - // the static computation of output types needed to create ops. - Value lhs_shape = rewriter->create(loc, lhs); - Value rhs_shape = rewriter->create(loc, rhs); - Value const_neg2 = - rewriter->create(loc, rewriter->getIndexAttr(-2)); - auto shape_type = shape::ShapeType::get(rewriter->getContext()); - auto lhs_splitted = rewriter->create( - loc, TypeRange{shape_type, shape_type}, lhs_shape, const_neg2); - auto rhs_splitted = rewriter->create( - loc, TypeRange{shape_type, shape_type}, rhs_shape, const_neg2); - auto lhs_type = mlir::cast(lhs.getType()); - auto rhs_type = mlir::cast(rhs.getType()); - // The last two dimensions are the matrix row/col dimensions. Don't broadcast - // them. - SmallVector result_batch_shape_compile_time_extents; - mlir::OpTrait::util::getBroadcastedShape( - lhs_type.getShape().drop_back(2), rhs_type.getShape().drop_back(2), - result_batch_shape_compile_time_extents); - auto result_batch_shape = rewriter->create( - loc, shape_type, lhs_splitted.getHead(), rhs_splitted.getHead(), - /*error=*/nullptr); - // Lambda which handles the broadcasting of one side to the common - // leading-batch dimensions. - auto broadcast_one_side = [&](Value side, RankedTensorType type, - Value tail_shape, Value *out_side) { - ArrayRef matrix_dims = type.getShape().take_back(2); - auto result_shape = result_batch_shape_compile_time_extents; - result_shape.append(matrix_dims.begin(), matrix_dims.end()); - auto result_type = tensorflow::GetTypeFromTFTensorShape( - result_shape, type.getElementType()); - auto shape = rewriter->create( - loc, shape_type, result_batch_shape, tail_shape); - auto shape_tensor = rewriter->create( - loc, - tensorflow::GetTypeFromTFTensorShape( - {static_cast(result_shape.size())}, - rewriter->getIndexType()), - shape); - *out_side = rewriter->create(loc, result_type, side, - shape_tensor); - }; - broadcast_one_side(lhs, lhs_type, lhs_splitted.getTail(), out_lhs); - broadcast_one_side(rhs, rhs_type, rhs_splitted.getTail(), out_rhs); -} - -class ConvertBatchMatMulV2Op : public OpRewritePattern { - public: - // TODO(hinsu): Legalize this op to Einsum op. HLO Einsum op needs to be moved - // to CHLO and it is missing legalization to MHLO. Once that is done, this - // pattern's benefit can be changed back to one as well as the fallback - // lowering pattern for the op can be removed. - // - // Set benefit of this pattern to zero to prefer the fallback pattern when - // available and applicable. That pattern avoids broadcast on operands and is - // therefore faster. - // - // Native legalization for BatchMatMulV3 needs to be added as well. - explicit ConvertBatchMatMulV2Op(MLIRContext *context) - : OpRewritePattern(context, /*benefit=*/0) {} - - LogicalResult matchAndRewrite(TF::BatchMatMulV2Op op, - PatternRewriter &rewriter) const override { - Value lhs = op.getX(); - Value rhs = op.getY(); - auto lhs_type = mlir::dyn_cast(lhs.getType()); - auto rhs_type = mlir::dyn_cast(rhs.getType()); - if (!lhs_type || !rhs_type) return failure(); - if (mlir::isa(lhs_type.getElementType()) && op.getAdjX()) { - lhs = rewriter.create(op.getLoc(), lhs_type, lhs); - } - if (mlir::isa(rhs_type.getElementType()) && op.getAdjY()) { - rhs = rewriter.create(op.getLoc(), rhs_type, rhs); - } - - // Broadcast both operands. - BroadcastBatchMatMulV2Operands(lhs, rhs, op.getLoc(), &lhs, &rhs, - &rewriter); - lhs_type = mlir::cast(lhs.getType()); - rhs_type = mlir::cast(rhs.getType()); - assert(lhs_type.getRank() == rhs_type.getRank()); - int64_t rank = lhs_type.getRank(); - auto batch_dimensions = llvm::to_vector<4>(llvm::seq(0, rank - 2)); - auto lhs_contracting_dimensions = llvm::to_vector<4>( - llvm::ArrayRef({op.getAdjX() ? rank - 2 : rank - 1})); - auto rhs_contracting_dimensions = llvm::to_vector<4>( - llvm::ArrayRef({op.getAdjY() ? rank - 1 : rank - 2})); - auto dimension_numbers = DotDimensionNumbersAttr::get( - rewriter.getContext(), - /*lhs_batching_dimensions=*/batch_dimensions, - /*rhs_batching_dimensions=*/batch_dimensions, - /*lhs_contracting_dimensions=*/lhs_contracting_dimensions, - /*rhs_contracting_dimensions=*/rhs_contracting_dimensions); - // TODO(silvasean): Emit shape checks for contracting dimensions. - // (The batch dimensions are checked by the broadcasting logic) - rewriter.replaceOpWithNewOp( - op, op.getType(), lhs, rhs, dimension_numbers, - /*precision_config=*/GetPrecisionConfig(&rewriter), - /*algorithm=*/DotAlgorithmAttr{}); - return success(); - } -}; - -// Converts the tf.Split op into a series of HLO slice ops when the tensor to be -// split has fully static shape and the dimension to split is a constant. -// -// The main logic of this pattern is to calculate the index start and end range -// for each slice. And this happens only on the dimension to be split; for all -// other dimensions, all resultant slices' index start and end range covers the -// input tensor's full range. Strides for all resultant slices are all one. -// -// For example, the following source IR: -// -// %dim = "tf.Const"() {value = dense<1> : tensor} : () -> tensor -// %0:3 = "tf.Split"(%dim, %input) : (tensor, tensor<4x6xf32>) -> -// (tensor<4x2xf32>, tensor<4x2xf32>, tensor<4x2xf32>) -// -// will be converted into: -// -// %0 = "mhlo.slice"(%input) { -// limit_indices = dense<[4, 2]> : tensor<2xi64>, -// start_indices = dense<0> : tensor<2xi64>, -// strides = dense<1> : tensor<2xi64>} : -// (tensor<4x6xf32>) -> tensor<4x2xf32> -// %1 = "mhlo.slice"(%input) { -// limit_indices = dense<4> : tensor<2xi64>, -// start_indices = dense<[0, 2]> : tensor<2xi64>, -// strides = dense<1> : tensor<2xi64>} : -// (tensor<4x6xf32>) -> tensor<4x2xf32> -// %2 = "mhlo.slice"(%input) { -// limit_indices = dense<[4, 6]> : tensor<2xi64>, -// start_indices = dense<[0, 4]> : tensor<2xi64>, -// strides = dense<1> : tensor<2xi64>} : -// (tensor<4x6xf32>) -> tensor<4x2xf32> -// TODO(antiagainst): consider lowering into TF ops so the pattern can be more -// applicable. -class ConvertSplitOp : public OpRewritePattern { - public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(TF::SplitOp op, - PatternRewriter &rewriter) const override { - // We can only split inputs that have fully static shape. - auto input_type = mlir::dyn_cast(op.getValue().getType()); - if (!input_type || !input_type.hasStaticShape()) return failure(); - - // We can only match when the split dimension is a constant scalar. - DenseIntElementsAttr split_dim_attr; - if (!matchPattern(op.getSplitDim(), m_Constant(&split_dim_attr))) - return failure(); - - // Get the dimension we are splitting at. Offset properly if it's negative. - int64_t input_rank = input_type.getRank(); - int64_t dim_index = (*split_dim_attr.begin()).getSExtValue(); - if (dim_index < 0) dim_index += input_rank; - - // Calculate the dimension size for each slice along the split dimension. - int64_t input_dim_size = input_type.getDimSize(dim_index); - - int64_t num_splits = op.getNumResults(); - int64_t slice_size = input_dim_size / num_splits; - - // Get each slice's type. - auto slice_shape = llvm::to_vector<4>(input_type.getShape()); - slice_shape[dim_index] = slice_size; - Type slice_type = tensorflow::GetTypeFromTFTensorShape( - slice_shape, input_type.getElementType()); - - // Parameters for constructing each slice. - SmallVector begin_indices(input_rank, 0); - auto end_indices = llvm::to_vector<4>(input_type.getShape()); - SmallVector strides(input_rank, 1); - - // All HLO slice results used to replace the original tf.Split op. - SmallVector slices; - slices.reserve(num_splits); - - for (int i = 0; i < num_splits; ++i) { - begin_indices[dim_index] = i * slice_size; - end_indices[dim_index] = (i + 1) * slice_size; - slices.push_back( - rewriter.create(op.getLoc(), slice_type, op.getValue(), - GetI64ElementsAttr(begin_indices, &rewriter), - GetI64ElementsAttr(end_indices, &rewriter), - GetI64ElementsAttr(strides, &rewriter))); - } - - rewriter.replaceOp(op, slices); - return success(); - } -}; - -// Converts the tf.Split op into a series of mhlo.real_dynamic_slice ops the -// dimension to split is a constant. -// TODO(disc): To recover static special case's performance with folding and -// canonicalization. delete ConvertSplitOp -class ConvertSplitOpDynamic : public OpRewritePattern { - public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(TF::SplitOp op, - PatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - Value input = op.getValue(); - auto input_type = mlir::dyn_cast(input.getType()); - if (!input_type) return failure(); - - // TODO(disc): remove static shape check once folding/canonicalization func - // added and ConvertSplitOp deleted. Calculate the dimension size for each - // slice along the split dimension. We are splitting along the dynamic - // dimension, or using static pattern transform - if (input_type.hasStaticShape()) return failure(); - - // We can only match when the split dimension is a constant scalar. - DenseIntElementsAttr split_dim_attr; - if (!matchPattern(op.getSplitDim(), m_Constant(&split_dim_attr))) - return failure(); - - // Get the dimension we are splitting at. Offset properly if it's negative. - int64_t input_rank = input_type.getRank(); - int64_t dim_index = (*split_dim_attr.begin()).getSExtValue(); - if (dim_index < 0) dim_index += input_rank; - - Value input_dim_size = - rewriter.create(loc, input, dim_index); - // Calculate the dimension size for each slice along the split dimension. - int num_splits = op.getNumResults(); - Value num_splits_value = rewriter.create( - loc, rewriter.getIndexAttr(num_splits)); - Value slice_size = - rewriter.create(loc, input_dim_size, num_splits_value); - - Value zero = rewriter.create(loc, 0); - Value one = rewriter.create(loc, 1); - - SmallVector begin_indices(input_rank, zero); - SmallVector end_indices; - end_indices.reserve(input_rank); - SmallVector strides(input_rank, one); - for (int i = 0; i < input_rank; ++i) { - end_indices.push_back(rewriter.create(loc, input, i)); - } - - // All HLO d_slice results used to replace the original tf.Split op. - SmallVector slices; - slices.reserve(num_splits); - - for (int i = 0; i < num_splits; ++i) { - begin_indices[dim_index] = rewriter.create( - loc, slice_size, rewriter.create(loc, i)); - end_indices[dim_index] = rewriter.create( - loc, slice_size, rewriter.create(loc, i + 1)); - - Type index_ty = rewriter.getIndexType(); - auto begin_value = rewriter.create( - loc, - tensorflow::GetTypeFromTFTensorShape( - {static_cast(begin_indices.size())}, index_ty), - begin_indices); - auto end_value = rewriter.create( - loc, - tensorflow::GetTypeFromTFTensorShape( - {static_cast(end_indices.size())}, index_ty), - end_indices); - auto stride_value = rewriter.create( - loc, - tensorflow::GetTypeFromTFTensorShape( - {static_cast(strides.size())}, index_ty), - strides); - slices.push_back(rewriter.create( - loc, op.getOperation()->getResult(i).getType(), input, begin_value, - end_value, stride_value)); - } - - rewriter.replaceOp(op, slices); - return success(); - } -}; - -// Converts the tf.SplitV op into a series of HLO slice ops when the tensor to -// be split has fully static shape and the dimension to split and split sizes -// are constants. -// -// This is similar to the conversion for tf.Split op other than that the size of -// each chunk on the dimension to split is explicitly given as an op operand -// and they are not necessarily the same. -// -// For example, given the following IR: -// -// %split_sizes = "tf.Const"() {value = dense<[1, -1, 3]> : tensor<3xi32>} -// %split_dim = "tf.Const"() {value = dense<1> : tensor} -// %0:3 = "tf.SplitV"(%input, %split_sizes, %split_dim) : -// (tensor<4x6xf32>, tensor<3xi32>, tensor) -> -// (tensor<4x1xf32>, tensor<4x2xf32>, tensor<4x3xf32>) -// -// We will generate slices following slices: -// %0 = "mhlo.slice"(%input) { -// limit_indices = dense<[4, 1]> : tensor<2xi64>, -// start_indices = dense<0> : tensor<2xi64>, -// strides = dense<1> : tensor<2xi64>} : -// (tensor<4x6xf32>) -> tensor<4x1xf32> -// %1 = "mhlo.slice"(%input) { -// limit_indices = dense<[4, 3]> : tensor<2xi64>, -// start_indices = dense<[0, 1]> : tensor<2xi64>, -// strides = dense<1> : tensor<2xi64>} : -// (tensor<4x6xf32>) -> tensor<4x2xf32> -// %2 = "mhlo.slice"(%input) { -// limit_indices = dense<[4, 6]> : tensor<2xi64>, -// start_indices = dense<[0, 3]> : tensor<2xi64>, -// strides = dense<1> : tensor<2xi64>} : -// (tensor<4x6xf32>) -> tensor<4x3xf32> -class ConvertSplitVOp : public OpRewritePattern { - public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(TF::SplitVOp op, - PatternRewriter &rewriter) const override { - // We can only split inputs that have fully static shape. - // TODO(b/145731001): enhance to support dynamic-shaped inputs. - auto input_type = mlir::dyn_cast(op.getValue().getType()); - if (!input_type || !input_type.hasStaticShape()) return failure(); - - // We can only match when the split dimension is a constant scalar. - DenseIntElementsAttr split_dim_attr; - if (!matchPattern(op.getSplitDim(), m_Constant(&split_dim_attr))) - return failure(); - - // We can only match when the split sizes is a constant int vector. - DenseIntElementsAttr split_sizes_attr; - if (!matchPattern(op.getSizeSplits(), m_Constant(&split_sizes_attr))) - return failure(); - - // Get each chunck's size along the dimension to split. It may contain - // dynamic sizes and we need to update it if so. - SmallVector split_sizes; - int64_t total_dim_size = 0; // Total dimension size assigned to splits - std::optional dynamic_dim_index; - split_sizes.reserve( - mlir::cast(split_sizes_attr.getType()).getNumElements()); - for (const auto &dim : llvm::enumerate(split_sizes_attr)) { - int64_t dim_val = dim.value().getSExtValue(); - split_sizes.push_back(dim_val); - if (dim_val == -1) { - // We cannot have more than one dynamic dimension. - assert(!dynamic_dim_index && "invalid split sizes"); - dynamic_dim_index = dim.index(); - } else { - total_dim_size += dim_val; - } - } - - // Get the dimension we are splitting at. Offset properly if it's negative. - int64_t input_rank = input_type.getRank(); - int64_t dim_index = (*split_dim_attr.begin()).getSExtValue(); - if (dim_index < 0) dim_index += input_rank; - - int64_t input_dim_size = input_type.getDimSize(dim_index); - assert(((dynamic_dim_index && total_dim_size <= input_dim_size) || - (!dynamic_dim_index && total_dim_size == input_dim_size)) && - "invalid split sizes"); - - // Update the dynamic dimension with calculated concrete size. - if (dynamic_dim_index) - split_sizes[*dynamic_dim_index] = input_dim_size - total_dim_size; - - // Parameters for constructing each slice. - SmallVector begin_indices(input_rank, 0); - auto end_indices = llvm::to_vector<4>(input_type.getShape()); - SmallVector strides(input_rank, 1); - - // All HLO slice results used to replace the original tf.Split op. - SmallVector slices; - slices.reserve(op.getNumResults()); - - for (int i = 0, end = op.getNumResults(); i < end; ++i) { - end_indices[dim_index] = begin_indices[dim_index] + split_sizes[i]; - slices.push_back(rewriter.create( - op.getLoc(), op.getValue(), - GetI64ElementsAttr(begin_indices, &rewriter), - GetI64ElementsAttr(end_indices, &rewriter), - GetI64ElementsAttr(strides, &rewriter))); - // Prepare the begin indice for the next slice. - begin_indices[dim_index] = end_indices[dim_index]; - } - - rewriter.replaceOp(op, slices); - return success(); - } -}; - -// Converts StridedSlice op to HLO Slice op along with Reverse op to handle -// negative strides and Reshape op to update the output shape. Indices and -// strides operands are converted to attributes with non-negative indexing. -// -// If the begin input is not a compile time constant, the begin input needs to -// be sliced and the slice needs to be lowered to mhlo.DynamicSlice. In this -// case, strides must have a known value of 1 (otherwise we have insufficient -// information to conform to XLA's op semantics). -// -// For example with an op like following, -// tf.StridedSlice(%input, %begin, %end, %strides) {shrink_axis_mask = 1} -// : tensor -> tensor -// -// If the %begin input is constant, output would be: -// %reversed = "mhlo.Reverse" (%input) {dimensions = ...} -// %sliced = "mhlo.Slice" (%input) -// {start_indices = ..., limit_indices = ..., strides = ...} -// %output = "mhlo.Reshape" (%sliced) : tensor<1xPxf32> -> tensor -// -class ConvertStridedSliceOp : public OpRewritePattern { - public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult rewriteWithConstantBegin(TF::StridedSliceOp op, - ArrayRef begin_indices, - ArrayRef end_indices, - ArrayRef strides, - RankedTensorType input_ty, - PatternRewriter &rewriter) const { - SmallVector hlo_begin_indices, hlo_end_indices, hlo_strides, - dims_to_reverse; - int64_t input_rank = input_ty.getRank(); - ArrayRef input_shape = input_ty.getShape(); - hlo_begin_indices.reserve(input_rank); - hlo_end_indices.reserve(input_rank); - hlo_strides.reserve(input_rank); - - int64_t indices_elements = begin_indices.size(); - if (input_rank < indices_elements) return failure(); - - // Convert from TensorFlow negative or out of range indices and strides - // values to legal HLO Slice attributes. - for (int i = 0, e = indices_elements; i != e; i++) { - int64_t begin = begin_indices[i]; - int64_t end = end_indices[i]; - int64_t stride = strides[i]; - - if (stride < 0) { - // Negative stride means that the output values are computed starting - // from end until begin. Mark the dimension for reversal before slice - // and compute indices for the reversed input. - dims_to_reverse.push_back(i); - begin = (input_shape[i] - 1) - begin; - end = (input_shape[i] - 1) - end; - stride = -stride; - } - - // Unlike TensorFlow, HLO requires begin and end values to be within - // range. - begin = std::max(int64_t(0), begin); - end = std::max(begin, end); - end = std::min(end, input_shape[i]); - - hlo_begin_indices.push_back(begin); - hlo_end_indices.push_back(end); - hlo_strides.push_back(stride); - } - - Location loc = op.getLoc(); - Value input = op.getInput(); - if (!dims_to_reverse.empty()) - input = rewriter.create( - loc, input_ty, op.getInput(), - GetI64ElementsAttr(dims_to_reverse, &rewriter)); - auto sliced = rewriter.create( - loc, input, GetI64ElementsAttr(hlo_begin_indices, &rewriter), - GetI64ElementsAttr(hlo_end_indices, &rewriter), - GetI64ElementsAttr(hlo_strides, &rewriter)); - - // Reshape slice result so that the shape is updated depending on - // 'new_axis_mask' or 'shrink_axis_mask' attributes. - rewriter.replaceOpWithNewOp(op, op.getType(), sliced); - return success(); - } - - LogicalResult rewriteWithUnknownBegin(TF::StridedSliceOp op, - RankedTensorType input_ty, - RankedTensorType result_ty, - PatternRewriter &rewriter) const { - // If begin and end values are dynamic, we can only support this lowering - // if strides are a known value of 1. - DenseIntElementsAttr sparse_strides_attr; - if (!matchPattern(op.getStrides(), m_Constant(&sparse_strides_attr))) { - return rewriter.notifyMatchFailure( - op, - "requires that strides are known when begin/end values are dynamic"); - } - SmallVector strides; - int64_t stride_value; - for (const APInt &stride : sparse_strides_attr) { - if ((stride_value = stride.getSExtValue()) != 1) { - return rewriter.notifyMatchFailure(op, - "requires that strides are all 1 " - "when begin/end values are dynamic"); - } - strides.push_back(stride_value); - } - - ArrayRef input_shape = input_ty.getShape(); - int last_dim = std::max(static_cast(input_shape.size()) - 1, 0); - - // When begin/end values are dynamic, the ellipsis mask, if set, must refer - // to the last dimension. - int ellipsis_mask = op.getEllipsisMask(); - if (!(ellipsis_mask == 0 || ellipsis_mask == (1 << last_dim))) - return rewriter.notifyMatchFailure( - op, - "requires that ellipsis_mask, if set, refer to the last dimension of " - "input (when begin/end values are dynamic)"); - - // In this case where the begin and end values are dynamic, we only support - // cases where the number of output elements has to be equal to the number - // of input elements that are sliced. Each dimension is either sliced fully - // or sliced with a size of one. - int output_elements = result_ty.getNumElements(); - int input_elements_sliced = 1; - - // Begin must be a ranked, 1-dimensional tensor: This is checked by the - // verifier. - int64_t slicing_dim_size = - mlir::cast(op.getBegin().getType()).getDimSize(0); - uint64_t begin_mask = op.getBeginMask(); - uint64_t end_mask = op.getEndMask(); - const int input_rank = input_shape.size(); - for (int d = 0; d < input_rank; ++d) { - // Each dimension is either sliced fully or has size of one. - if ((((begin_mask >> d) & 1) && ((end_mask >> d) & 1)) || - (d >= slicing_dim_size)) { - input_elements_sliced *= input_shape[d]; - } - } - if (input_elements_sliced != output_elements) { - return rewriter.notifyMatchFailure( - op, - "requires the number of output elements to be equal to the number of " - "input elements sliced (when begin/end values are dynamic)"); - } - - SmallVector slice_begin_indices; - // For the dimensions that are to be sliced, all have slice sizes of 1. - SmallVector slice_sizes; - auto begin_element_ty = - mlir::cast(op.getBegin().getType()).getElementType(); - // Scalar tensor type. - TensorType type = - tensorflow::GetTypeFromTFTensorShape(/*shape=*/{}, begin_element_ty); - Location loc = op.getLoc(); - auto zero = GetScalarConstOfType(begin_element_ty, loc, 0, &rewriter); - for (int d = 0; d < input_rank; ++d) { - if ((((begin_mask >> d) & 1) && ((end_mask >> d) & 1)) || - (d >= slicing_dim_size)) { - slice_begin_indices.push_back(zero); - slice_sizes.push_back(input_shape[d]); - continue; - } - - auto index = rewriter.create( - loc, op.getBegin(), GetI64ElementsAttr({d}, &rewriter), - GetI64ElementsAttr({d + 1}, &rewriter), - GetI64ElementsAttr({1}, &rewriter)); - // Convert index to scalar. - auto reshaped_index = rewriter.create(loc, type, index); - // If the index is negative, wrap it around with dimension size. - auto index_negative = - rewriter.create(loc, reshaped_index, zero); - auto input_val = GetScalarConstOfType(begin_element_ty, loc, - input_shape[d], &rewriter); - auto wrapped_index = - rewriter.create(loc, input_val, reshaped_index); - auto final_index = rewriter.create( - loc, type, index_negative, wrapped_index, reshaped_index); - slice_begin_indices.push_back(final_index); - slice_sizes.push_back(1); - } - - auto slice_sizes_attr = GetI64ElementsAttr(slice_sizes, &rewriter); - auto sliced_type = tensorflow::GetTypeFromTFTensorShape( - slice_sizes, op.getType().getElementType()); - // This must be an xla DynamicSlice op due to the inputs that aren't - // constant. - auto sliced = rewriter.create( - loc, sliced_type, op.getInput(), slice_begin_indices, slice_sizes_attr); - - // Reshape slice result so that the shape is updated depending on - // 'new_axis_mask' or 'shrink_axis_mask' attributes. - rewriter.replaceOpWithNewOp(op, op.getType(), sliced); - return success(); - } - - LogicalResult matchAndRewrite(TF::StridedSliceOp op, - PatternRewriter &rewriter) const override { - // Input shape needs to be static to convert negative indices in TensorFlow - // to absolute indices required by HLO. - // - // TODO(hinsu): Relax this constraint for ops without negative indices and - // strides. - auto input_ty = mlir::dyn_cast(op.getInput().getType()); - if (!input_ty || !input_ty.hasStaticShape()) return failure(); - - // Output shape needs to be static to apply 'new_axis_mask' or - // 'shrink_axis_mask' by reshaping tensor after slice. - // - // TODO(hinsu): Relax this constraint for ops without the above masks. - auto result_ty = mlir::dyn_cast(op.getType()); - if (!result_ty || !result_ty.hasStaticShape()) return failure(); - - DenseIntElementsAttr sparse_begin_attr, sparse_end_attr; - if (!matchPattern(op.getBegin(), m_Constant(&sparse_begin_attr)) || - !matchPattern(op.getEnd(), m_Constant(&sparse_end_attr))) { - return rewriteWithUnknownBegin(op, input_ty, result_ty, rewriter); - } - - SmallVector begin_indices, end_indices, strides; - if (!op.GetSlicedBoundRanges(&begin_indices, &end_indices, &strides)) { - return failure(); - } - return rewriteWithConstantBegin(op, begin_indices, end_indices, strides, - input_ty, rewriter); - } -}; - -// Converts tf.StridedSliceGrad to HLO reshape, reverse and padding ops. -// -// tf.StridedSlice is taking slice of the input tensor. tf.StridedSliceGrad does -// the reverse: it propagates the graident for the sliced tensor to the original -// input tensor by doing padding with zeros. The main logic is calculating the -// indices and strides for padding. -class ConvertStridedSliceGradOp - : public OpRewritePattern { - public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(TF::StridedSliceGradOp op, - PatternRewriter &rewriter) const override { - // We need constant input shape to perform padding calculations later. - DenseIntElementsAttr input_shape_attr; - if (!matchPattern(op.getShape(), m_Constant(&input_shape_attr))) - return failure(); - - // We also need constant begin/end indices and strides to perform padding - // calculations. - // Bounded shape after performing strided slice - SmallVector shape; - // Bounded begin, end, and strides for strided slice - SmallVector begin_indices, end_indices, strides; - if (!op.GetSlicedShapeAndBoundRanges(&shape, &begin_indices, &end_indices, - &strides)) - return failure(); - - Value grad = op.getDy(); - Type element_type = mlir::cast(grad.getType()).getElementType(); - - // Perform reshape to undo any new/shrink axes done by strided slice. - grad = rewriter.create( - op.getLoc(), tensorflow::GetTypeFromTFTensorShape(shape, element_type), - grad); - - SmallVector padding_low, padding_high, padding_interm; - SmallVector dims_to_reverse; - padding_low.reserve(shape.size()); - padding_high.reserve(shape.size()); - padding_interm.reserve(shape.size()); - - // Prepare padding parameters for each dimension. - for (int i = 0, e = shape.size(); i < e; ++i) { - int64_t input_dim = (*(input_shape_attr.begin() + i)).getSExtValue(); - if (strides[i] > 0) { - padding_low.push_back(begin_indices[i]); - padding_interm.push_back(strides[i] - 1); - - // Pad the upper dimension up to the expected input shape. It's not - // sufficient simply to use end_indices[i] to compute the padding in - // cases where the stride does not divide evenly into the interval - // between begin_indices[i] and end_indices[i]. - int64_t size = - padding_low[i] + shape[i] + (shape[i] - 1) * padding_interm[i]; - padding_high.push_back(input_dim - size); - } else { - dims_to_reverse.push_back(i); - padding_high.push_back(input_dim - begin_indices[i] - 1); - padding_interm.push_back(-strides[i] - 1); - - // Pad the lower dimension up to the expected input shape. - int64_t size = - padding_high[i] + shape[i] + (shape[i] - 1) * padding_interm[i]; - padding_low.push_back(input_dim - size); - } - } - - if (!dims_to_reverse.empty()) { - grad = rewriter.create( - op.getLoc(), grad.getType(), grad, - GetI64ElementsAttr(dims_to_reverse, &rewriter)); - } - - auto zero = GetScalarConstOfType(element_type, op.getLoc(), 0, &rewriter); - rewriter.replaceOpWithNewOp( - op, op.getType(), grad, zero, - GetI64ElementsAttr(padding_low, &rewriter), - GetI64ElementsAttr(padding_high, &rewriter), - GetI64ElementsAttr(padding_interm, &rewriter)); - return success(); - } -}; - -/// Converts the RangeOp tensorflow op to a mhlo.iota op with a scaling and -/// offset applied to generate the range values. The output tensor needs to -/// have a static shape. -/// -/// For example an op like the following: -/// %result = "tf.Range"(%start, %limit, %delta) {Tidx = "tfdtype$DT_FLOAT"} -/// : (tensor, tensor, tensor) -> tensor<5xf32> -/// -/// Output would be: -/// %iota = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<5xf32> -/// %scaled = "mhlo.multiply"(%iota, %delta) -/// {broadcast_dimensions = dense<[]> : tensor<0xi64>} : -/// (tensor<5xf32>, tensor) -> tensor<5xf32> -/// %result = "mhlo.add"(%scaled, %offset) -/// {broadcast_dimensions = dense<[]> : tensor<0xi64>} : -/// (tensor<5xf32>, tensor) -> tensor<5xf32> -/// -/// Implementation is defined in C++ due to no type interface for the iota op. -class ConvertRangeOp : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(TF::RangeOp op, - PatternRewriter &rewriter) const override { - auto result = op.getResult(); - auto result_type = result.getType(); - if (!mlir::cast(result_type).hasStaticShape()) { - return failure(); - } - - auto iota = rewriter.create(op.getLoc(), result_type, - rewriter.getI64IntegerAttr(0)); - auto scaled = rewriter.create( - op.getLoc(), result_type, iota, op.getDelta(), - hlo::getBroadcastDimensionsAttr(&rewriter, iota, op.getDelta())); - rewriter.replaceOpWithNewOp( - op, result_type, scaled, op.getStart(), - hlo::getBroadcastDimensionsAttr(&rewriter, scaled, op.getStart())); - return success(); - } -}; - -// Converts RangeOp for cases with the length is a dynamic value. The shape of -// the resulting tensor computed, then the start and delta is used with the -// dynamic_iota value to compute the final range value. -// -// For example, the resulting range op value: -// %range = "tf.range"(%start, %limit, %delta) -// -// Is converted to the following. -// %start + %delta * iota(ceil(abs((%limit - %start) / %delta)) -// -// Implementation is defined in C++ due to the complicated type behavior. -class ConvertDynamicRangeOp : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(TF::RangeOp op, - PatternRewriter &rewriter) const override { - auto result = op.getResult(); - auto result_type = mlir::cast(result.getType()); - if (result_type.hasStaticShape()) { - return failure(); - } - - Value start = op.getStart(); - Value delta = op.getDelta(); - Value limit = op.getLimit(); - - // To compute the length we need to use floating point calculations so that - // ceil can be computed for the number of steps. - auto compute_element_type = - mlir::isa(getElementTypeOrSelf(start.getType())) - ? getElementTypeOrSelf(start.getType()) - : rewriter.getF64Type(); - auto compute_type = tensorflow::GetTypeFromTFTensorShape( - mlir::cast(limit.getType()).getShape(), - compute_element_type); - - // Compute the length of the sequence we are going to need. This includes - // some conversion to float for the operations. - // - // %size = ceil(abs((%limit - %start) / %delta)) - auto range = rewriter.create(op.getLoc(), limit, start); - auto abs = rewriter.create(op.getLoc(), range); - - // Delta is not necessarily the same type as start and limit. - auto abs_cast = - rewriter.create(op.getLoc(), compute_type, abs); - auto delta_cast = - rewriter.create(op.getLoc(), compute_type, delta); - - // Compute the total number of integer steps and convert to the HLO - // dimension tensor. - auto normalized = - rewriter.create(op.getLoc(), abs_cast, delta_cast); - auto ceil = rewriter.create(op.getLoc(), normalized); - auto steps = rewriter.create( - op.getLoc(), - tensorflow::GetTypeFromTFTensorShape({}, rewriter.getI64Type()), ceil); - auto reshape = rewriter.create( - op.getLoc(), - tensorflow::GetTypeFromTFTensorShape({1}, rewriter.getI64Type()), - steps); - - // Using the resulting length compute the correct range value: - // - // %range = %start + %delta * iota(%size) - auto out_scalar_type = tensorflow::GetTypeFromTFTensorShape( - {}, getElementTypeOrSelf(result_type)); - auto start_out_cast = - rewriter.create(op.getLoc(), out_scalar_type, start); - auto delta_out_cast = - rewriter.create(op.getLoc(), out_scalar_type, delta); - - auto iota = rewriter.create( - op.getLoc(), result_type, reshape, rewriter.getI64IntegerAttr(0)); - auto scaled = rewriter.create( - op.getLoc(), result_type, iota, delta_out_cast, - hlo::getBroadcastDimensionsAttr(&rewriter, iota, delta_cast)); - rewriter.replaceOpWithNewOp( - op, result_type, scaled, start_out_cast, - hlo::getBroadcastDimensionsAttr(&rewriter, scaled, start_out_cast)); - return success(); - } -}; - -ElementsAttr ConvertAxisAttr(Value val, ElementsAttr attr, Builder *builder) { - auto int_attr = mlir::cast(attr); - auto type = mlir::cast(val.getType()); - - SmallVector axis; - axis.reserve(int_attr.getNumElements()); - - int64_t rank = type.getRank(); - for (auto val : int_attr.getValues()) { - axis.push_back((val.getSExtValue() + rank) % rank); - } - - return builder->getI64TensorAttr(axis); -} - -/// Converts the LinSpace tensorflow op to a mhlo.iota op with a scaling -/// and offset applied to generate the linspace values. The output tensor needs -/// to have a static shape. The implementation is defined in C++ because there -/// is no type inference for the iota op. -class ConvertLinSpaceOp : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(TF::LinSpaceOp op, - PatternRewriter &rewriter) const override { - auto result = op.getResult(); - auto result_type = mlir::dyn_cast(result.getType()); - if (!result_type || !result_type.hasStaticShape()) { - return failure(); - } - - DenseIntElementsAttr num_attr; - if (!matchPattern(op.getNum(), m_Constant(&num_attr))) { - return rewriter.notifyMatchFailure(op, "Num must be a constant scalar"); - } - - if (num_attr.begin() == num_attr.end()) { - return rewriter.notifyMatchFailure(op, "Num must not be empty"); - } - int64_t num = (*num_attr.begin()).getSExtValue(); - - // Calculate the scaling that needs to be applied to the iota. - auto step_numerator = rewriter.create( - op.getLoc(), op.getStart().getType(), op.getStop(), op.getStart(), - hlo::getBroadcastDimensionsAttr(&rewriter, op.getStop(), - op.getStart())); - Value step_denominator = rewriter.create( - op.getLoc(), op.getNum(), result_type.getElementType()); - if (num > 1) { - Value one = GetScalarConstOfType(result_type.getElementType(), - op.getLoc(), 1, &rewriter); - step_denominator = rewriter.create( - op.getLoc(), step_denominator.getType(), step_denominator, one, - hlo::getBroadcastDimensionsAttr(&rewriter, step_denominator, one)); - } - auto step = rewriter.create( - op.getLoc(), step_numerator.getType(), step_numerator, step_denominator, - hlo::getBroadcastDimensionsAttr(&rewriter, step_numerator, - step_denominator)); - - // Scale the iota and add the offset. - auto iota = rewriter.create(op.getLoc(), result_type, - rewriter.getI64IntegerAttr(0)); - auto scaled = rewriter.create( - op.getLoc(), result_type, iota, step, - hlo::getBroadcastDimensionsAttr(&rewriter, iota, step)); - rewriter.replaceOpWithNewOp( - op, result_type, scaled, op.getStart(), - hlo::getBroadcastDimensionsAttr(&rewriter, scaled, op.getStart())); - return success(); - } -}; - -/// Converts a generic OpTy tensorflow op to a mhlo.reduce op over -/// ReductionOp. -/// `is_accumulation` controls whether it uses higher precision for the actual -/// reduction. This is set to false for ops like max where there is no precision -/// concerns. -// -// The Derived class should have a static method to return the initial value to -// use for reduction: -// static Value GetInitialValue(Type reduce_element_type, Location loc, -// PatternRewriter *rewriter); -// The reduce_element_type is guaranteed to be a float, int, or complex type -// suitable for use with GetScalarConstOfType or GetScalarLimitConstOfType. -template -class GenericConvertReductionOp : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(OpTy op, - PatternRewriter &rewriter) const override { - // TODO(b/141785544): Update this to not require ranked shapes. - // Input shape needs to be ranked to convert negative indices in TensorFlow - // to absolute indices required by HLO. - auto input_ty = mlir::dyn_cast(op.getInput().getType()); - if (!input_ty) return failure(); - ArrayRef input_shape = input_ty.getShape(); - - DenseIntElementsAttr dimensions; - if (!matchPattern(op.getReductionIndices(), m_Constant(&dimensions))) - return failure(); - - // Build the final shape from input_shape and dimensions using a bitmap - // to mark the reduced dimensions. - SmallVector reduced_dimensions_bitmap(input_shape.size(), false); - SmallVector xla_dimensions; - for (const APInt &index_raw : dimensions.getValues()) { - int64_t index = index_raw.getSExtValue(); - int64_t rank = input_shape.size(); - if ((index < -rank || index >= rank)) return failure(); - index = (index + rank) % rank; - reduced_dimensions_bitmap[index] = true; - xla_dimensions.push_back(index); - } - - Location loc = op.getLoc(); - Type element_type = input_ty.getElementType(); - - // Only float, int, and complex types are currently supported. - if (!mlir::isa(element_type) && - !mlir::isa(element_type) && - !mlir::isa(element_type)) { - return rewriter.notifyMatchFailure( - op, "element type must be float, int, or complex type"); - } - - // Convert to an accumulation type to not lose precision when doing - // repeated arithmetic operations. - Type reduce_element_type = - is_accumulation ? GetAccumulationType(element_type) : element_type; - auto casted_input = - rewriter.create(loc, op.getInput(), reduce_element_type); - - // Each reduction op can have a different initial value. - Value init = Derived::GetInitialValue(reduce_element_type, loc, &rewriter); - - auto reduction = rewriter.create( - loc, casted_input.getResult(), init, - GetI64ElementsAttr(xla_dimensions, &rewriter), reduce_element_type); - BuildReduceBody(reduce_element_type, &reduction.getBody(), - &rewriter); - Value result = reduction.getResult(0); - - // The mean op needs to divide by the product of the reduced dimensions. - if (std::is_same::value) { - Value in_shape = rewriter.create(loc, op.getInput()); - Value divisor_count = rewriter.create(loc, 1); - for (size_t i = 0; i < input_shape.size(); ++i) { - if (reduced_dimensions_bitmap[i]) { - Value index = rewriter.create(loc, i); - auto dim = rewriter.create(loc, in_shape, index); - divisor_count = - rewriter.create(loc, divisor_count, dim); - } - } - // HLO ops are only defined on tensors, so we cast the divisor from - // index -> i64 -> tensor<1xi64> -> tensor -> tensor - Value divisor_casted = rewriter.create( - loc, rewriter.getI64Type(), divisor_count); - Value divisor_tensor = rewriter.create( - loc, tensorflow::GetTypeFromTFTensorShape({}, rewriter.getI64Type()), - divisor_casted); - Value divisor = rewriter.create( - loc, tensorflow::GetTypeFromTFTensorShape({}, reduce_element_type), - divisor_tensor); - auto broadcast_dims = rewriter.getDenseI64ArrayAttr({}); - result = rewriter.create(loc, result, divisor, - broadcast_dims); - } - - result = rewriter.create(loc, result, element_type); - - // Need to reshape back after the reduction if we're keeping the reduced - // dimensions. Note that we do this through successive (nominally 1) - // applications of the TF ExpandDims op vs a more labor intensive - // reshape. Various code generation techniques benefit from the knowledge - // that this is a restricted form of shape manipulation that is just adding - // unit dims. - if (op.getKeepDims()) { - for (const auto &dim_is_reduced : - llvm::enumerate(reduced_dimensions_bitmap)) { - if (dim_is_reduced.value()) { - auto index_attr = GetI32ElementsAttr( - {static_cast(dim_is_reduced.index())}, &rewriter); - Value index = rewriter.create(loc, index_attr); - result = rewriter.create(loc, result, index); - } - } - } - rewriter.replaceOp(op, {result}); - - return success(); - } -}; - -// Converts Mean op to HLO Reduce op. -// -// %init = arith.constant dense<...> : tensor -// %sum = "mhlo.reduce"(%inp, %init) ["mhlo.add"] -// {dimensions = ...} -// %divisor = arith.constant dense<...> : tensor -// %mean = "mhlo.divide"(%sum, %divisor) -class ConvertMeanOp - : public GenericConvertReductionOp { - public: - using GenericConvertReductionOp::GenericConvertReductionOp; - static Value GetInitialValue(Type reduce_element_type, Location loc, - PatternRewriter *rewriter) { - return GetScalarNegZeroOfType(reduce_element_type, loc, rewriter); - } -}; - -// Converts Sum op to HLO Reduce op. -// -// %init = arith.constant dense<...> : tensor -// %sum = "mhlo.reduce"(%inp, %init) ["mhlo.add"] -// {dimensions = ...} -class ConvertSumOp - : public GenericConvertReductionOp { - public: - using GenericConvertReductionOp::GenericConvertReductionOp; - - static Value GetInitialValue(Type reduce_element_type, Location loc, - PatternRewriter *rewriter) { - // The neutral element of fp addition is -0.0, not 0.0: '0.0 + -0.0 = 0.0'. - return GetScalarNegZeroOfType(reduce_element_type, loc, rewriter); - } -}; - -// Converts Max op to HLO Reduce op. -// -// %init = arith.constant dense<...> : tensor -// %max = "mhlo.reduce"(%inp, %init) ["mhlo.maximum"] -// {dimensions = ...} -class ConvertMaxOp - : public GenericConvertReductionOp { - public: - using GenericConvertReductionOp::GenericConvertReductionOp; - - static Value GetInitialValue(Type reduce_element_type, Location loc, - PatternRewriter *rewriter) { - return GetScalarLimitConstOfType(reduce_element_type, loc, - hlo::kInfinityLowest, rewriter); - } -}; - -// Converts Min op to HLO Reduce op. -// -// %init = arith.constant dense<...> : tensor -// %min = "mhlo.reduce"(%inp, %init) ["mhlo.minimum"] -// {dimensions = ...} -class ConvertMinOp - : public GenericConvertReductionOp { - public: - using GenericConvertReductionOp::GenericConvertReductionOp; - - static Value GetInitialValue(Type reduce_element_type, Location loc, - PatternRewriter *rewriter) { - return GetScalarLimitConstOfType(reduce_element_type, loc, - hlo::kInfinityMax, rewriter); - } -}; - -// Converts Prod op to HLO Reduce op. -// -// %init = arith.constant dense<...> : tensor -// %prod = "mhlo.reduce"(%inp, %init) ["mhlo.multiply"] -// {dimensions = ...} -class ConvertProdOp - : public GenericConvertReductionOp { - public: - using GenericConvertReductionOp::GenericConvertReductionOp; - - static Value GetInitialValue(Type reduce_element_type, Location loc, - PatternRewriter *rewriter) { - return GetScalarConstOfType(reduce_element_type, loc, 1, rewriter); - } -}; - -// Converts All op to HLO Reduce op. -// -// %init = arith.constant dense<...> : tensor -// %max = "mhlo.reduce"(%inp, %init) ["mhlo.and"] -// {dimensions = ...} -class ConvertAllOp - : public GenericConvertReductionOp { - public: - using GenericConvertReductionOp::GenericConvertReductionOp; - static Value GetInitialValue(Type reduce_element_type, Location loc, - PatternRewriter *rewriter) { - return GetScalarConstOfType(reduce_element_type, loc, 1, rewriter); - } -}; - -// Converts Any op to HLO Reduce op. -// -// %init = arith.constant dense<...> : tensor -// %max = "mhlo.reduce"(%inp, %init) ["mhlo.or"] -// {dimensions = ...} -class ConvertAnyOp - : public GenericConvertReductionOp { - public: - using GenericConvertReductionOp::GenericConvertReductionOp; - static Value GetInitialValue(Type reduce_element_type, Location loc, - PatternRewriter *rewriter) { - return GetScalarConstOfType(reduce_element_type, loc, 0, rewriter); - } -}; - -// Converts tensorflow ArgMin or ArgMax op to mhlo operations that perform -// a reduction on the original input and the corresponding index. The reduction -// sub-computation selects the max (or min) value and the index for the value. -// Derived: is the resulting derived class of this class. -// OpTy: is TF::ArgMaxOp or TF::ArgMinOp. -template -class ConvertArgMinMaxOp : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(OpTy op, - PatternRewriter &rewriter) const override { - RankedTensorType input_type = - mlir::dyn_cast(op.getInput().getType()); - if (!input_type) { - return failure(); - } - - Type input_element_type = input_type.getElementType(); - // TODO(bixia): Clarify whether tf.ArgMax supports complex data types. If - // tf.ArgMax doesn't support complex data types, this check can be removed. - if (!input_element_type.isSignlessIntOrFloat()) return failure(); - - Location loc = op.getLoc(); - Value init_value = - Derived::GetInitialValue(input_element_type, loc, rewriter); - - RankedTensorType output_type = - mlir::dyn_cast(op.getOutput().getType()); - if (!output_type) { - return rewriter.notifyMatchFailure(op, "requires known rank"); - } - - Type index_element_type = output_type.getElementType(); - Value index_init_value = - GetScalarConstOfType(index_element_type, loc, 0, &rewriter); - - RankedTensorType index_type = tensorflow::GetTypeFromTFTensorShape( - input_type.getShape(), index_element_type); - - std::optional optional_axis = - GetIntegerHLOAxisFromTFAxis(op.getDimension(), input_type.getRank()); - if (!optional_axis.has_value()) - return rewriter.notifyMatchFailure(op, "required axis"); - int64_t axis = optional_axis.value(); - - IntegerAttr iota_dimension = - IntegerAttr::get(rewriter.getIntegerType(64), axis); - Value input_shape = rewriter.create(loc, op.getInput()); - Value index_values = rewriter.create( - loc, index_type, input_shape, iota_dimension); - - Value operands[] = {op.getInput(), index_values}; - Value init_values[] = {init_value, index_init_value}; - DenseIntElementsAttr reduction_dimensions = - GetI64ElementsAttr({axis}, &rewriter); - - auto reduction = rewriter.create( - loc, llvm::ArrayRef(operands), - llvm::ArrayRef(init_values), reduction_dimensions, - TypeRange({input_element_type, index_element_type})); - auto direction = Derived::GetDirection(); - BuildArgMinMaxReductionBody(input_element_type, index_element_type, - direction, &reduction.getBody(), &rewriter); - - rewriter.replaceOp(op, {reduction.getResult(1)}); - return success(); - } -}; - -// Converts tensorflow ArgMax op to mhlo operations. The actual -// implementation is in class ConvertArgMinMaxOp: -// -// %init_index = arith.constant dense<...> : tensor -// %init = arith.constant dense<...> : tensor -// %reduce = "mhlo.reduce"(%selected_input, %select_index, %init, -// %init_index) ["mhlo.arg_max"] -class ConvertArgMaxOp - : public ConvertArgMinMaxOp { - public: - using ConvertArgMinMaxOp::ConvertArgMinMaxOp; - - static Value GetInitialValue(Type reduce_element_type, Location loc, - PatternRewriter &rewriter) { - return GetScalarLimitConstOfType(reduce_element_type, loc, - hlo::kInfinityLowest, &rewriter); - } - - static ComparisonDirection GetDirection() { return ComparisonDirection::GE; } -}; - -// Converts tensorflow ArgMin op to mhlo operations. The actual -// implementation is in class ConvertArgMinMaxOp: -// -// %init_index = arith.constant dense<...> : tensor -// %init = arith.constant dense<...> : tensor -// %reduce = "mhlo.reduce"(%selected_input, %select_index, %init, -// %init_index) ["mhlo.arg_min"] -class ConvertArgMinOp - : public ConvertArgMinMaxOp { - public: - using ConvertArgMinMaxOp::ConvertArgMinMaxOp; - - static Value GetInitialValue(Type reduce_element_type, Location loc, - PatternRewriter &rewriter) { - return GetScalarLimitConstOfType(reduce_element_type, loc, - hlo::kInfinityMax, &rewriter); - } - - static ComparisonDirection GetDirection() { return ComparisonDirection::LE; } -}; - -// Converts TF TensorScatterUpdate/Min/Max/Add/Sub op into Scatter Op with -// assignment: -// -// %result = "mhlo.scatter"(%tensor, %indices, %updates) -// { dimensions = ... } -// -template -class ConvertTensorScatterOp : public OpRewritePattern { - public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(OpTy op, - PatternRewriter &rewriter) const override { - auto tensor_ty = mlir::dyn_cast(op.getTensor().getType()); - auto indices_ty = - mlir::dyn_cast(op.getIndices().getType()); - auto updates_ty = - mlir::dyn_cast(op.getUpdates().getType()); - - if (!tensor_ty || !indices_ty || !updates_ty) return failure(); - // Last dimension of the indices needs to known at compile time for - // computation of the 'update_window_dims' attribute in the dimensions - // struct. - int64_t num_index_dims = indices_ty.getShape().back(); - if (ShapedType::isDynamic(num_index_dims)) return failure(); - - auto updates = op.getUpdates(); - - // Broadcast scalar `updates` in into expected shape as following shape: - // updates.shape == indices.shape[:-1] + tensor.shape[indices.shape[-1]:] - if (updates_ty.getRank() == 0 && - (std::is_same::value || - std::is_same::value)) { - if (!tensor_ty.hasStaticShape()) { - return failure(); - } - - if (!indices_ty.hasStaticShape()) { - return failure(); - } - - auto tensor_shape = tensor_ty.getShape(); - auto indices_shape = indices_ty.getShape(); - auto index_depth = indices_shape.back(); - llvm::SmallVector expected_update_shape; - - // create the expected update shape which scalar update is broadcasted to - expected_update_shape.append(indices_shape.begin(), - std::prev(indices_shape.end())); - - expected_update_shape.append(std::next(tensor_shape.begin(), index_depth), - tensor_shape.end()); - - auto const_type = tensorflow::GetTypeFromTFTensorShape( - {static_cast(expected_update_shape.size())}, - rewriter.getIntegerType(64)); - - auto const_attr = GetI64ElementsAttr(expected_update_shape, &rewriter); - - auto const_op = - rewriter.create(op->getLoc(), const_type, const_attr); - - auto broadcast_to_type = tensorflow::GetTypeFromTFTensorShape( - llvm::ArrayRef(expected_update_shape), - updates_ty.getElementType()); - - updates = rewriter.create( - op->getLoc(), broadcast_to_type, op.getUpdates(), const_op); - - updates_ty = mlir::dyn_cast(updates.getType()); - } - - int64_t tensor_rank = tensor_ty.getRank(); - int64_t indices_rank = indices_ty.getRank(); - int64_t updates_rank = - mlir::dyn_cast(updates.getType()).getRank(); - - int64_t window_dims = tensor_rank - num_index_dims; - auto dims_attr = ScatterDimensionNumbersAttr::get( - rewriter.getContext(), - llvm::to_vector<4>( - llvm::seq(updates_rank - window_dims, updates_rank)), - llvm::to_vector<4>(llvm::seq(0, num_index_dims)), - /*inputBatchingDims=*/{}, - /*scatterIndicesBatchingDims=*/{}, - llvm::to_vector<4>(llvm::seq(0, num_index_dims)), - indices_rank - 1); - - Location loc = op.getLoc(); - auto scatter = rewriter.create( - loc, op.getType(), ValueRange(Value(op.getTensor())), op.getIndices(), - updates, dims_attr); - Derived::BuildScatterBody(tensor_ty.getElementType(), - &scatter.getUpdateComputation(), loc, rewriter); - - rewriter.replaceOp(op, scatter.getResult(0)); - return success(); - } -}; - -class ConvertTensorScatterUpdateOp - : public ConvertTensorScatterOp { - public: - using ConvertTensorScatterOp::ConvertTensorScatterOp; - - static void BuildScatterBody(Type element_type, Region *region, Location loc, - OpBuilder &builder) { - OpBuilder::InsertionGuard guard(builder); - Block *block = builder.createBlock(region); - Type type = - tensorflow::GetTypeFromTFTensorShape(/*shape=*/{}, element_type); - block->addArguments({type, type}, SmallVector(2, loc)); - builder.create(loc, block->getArgument(1)); - } -}; - -class ConvertTensorScatterAddOp - : public ConvertTensorScatterOp { - public: - using ConvertTensorScatterOp::ConvertTensorScatterOp; - - static void BuildScatterBody(Type element_type, Region *region, Location loc, - OpBuilder &builder) { - OpBuilder::InsertionGuard guard(builder); - Block *block = builder.createBlock(region); - Type type = - tensorflow::GetTypeFromTFTensorShape(/*shape=*/{}, element_type); - block->addArguments({type, type}, SmallVector(2, loc)); - auto add_op = builder.create(loc, block->getArgument(0), - block->getArgument(1)); - builder.create(loc, add_op.getResult()); - } -}; - -class ConvertTensorScatterSubOp - : public ConvertTensorScatterOp { - public: - using ConvertTensorScatterOp::ConvertTensorScatterOp; - - static void BuildScatterBody(Type element_type, Region *region, Location loc, - OpBuilder &builder) { - OpBuilder::InsertionGuard guard(builder); - Block *block = builder.createBlock(region); - Type type = - tensorflow::GetTypeFromTFTensorShape(/*shape=*/{}, element_type); - block->addArguments({type, type}, SmallVector(2, loc)); - auto sub_op = builder.create(loc, block->getArgument(0), - block->getArgument(1)); - builder.create(loc, sub_op.getResult()); - } -}; - -class ConvertTensorScatterMinOp - : public ConvertTensorScatterOp { - public: - using ConvertTensorScatterOp::ConvertTensorScatterOp; - - static void BuildScatterBody(Type element_type, Region *region, Location loc, - OpBuilder &builder) { - OpBuilder::InsertionGuard guard(builder); - Block *block = builder.createBlock(region); - Type type = - tensorflow::GetTypeFromTFTensorShape(/*shape=*/{}, element_type); - block->addArguments({type, type}, SmallVector(2, loc)); - auto min_op = builder.create(loc, block->getArgument(0), - block->getArgument(1)); - builder.create(loc, min_op.getResult()); - } -}; - -class ConvertTensorScatterMaxOp - : public ConvertTensorScatterOp { - public: - using ConvertTensorScatterOp::ConvertTensorScatterOp; - - static void BuildScatterBody(Type element_type, Region *region, Location loc, - OpBuilder &builder) { - OpBuilder::InsertionGuard guard(builder); - Block *block = builder.createBlock(region); - Type type = - tensorflow::GetTypeFromTFTensorShape(/*shape=*/{}, element_type); - block->addArguments({type, type}, SmallVector(2, loc)); - auto max_op = builder.create(loc, block->getArgument(0), - block->getArgument(1)); - builder.create(loc, max_op.getResult()); - } -}; - -// Converts Tile op to HLO BroadcastInDim and Reshape ops. -// For shape [S1, S2] and multiples [M1, M2], -// MS1 = M1 * S1; MS2 = M2 * S2 -// -// %broadcast = mhlo.broadcast_in_dim(%input) { -// broadcast_dimensions = [0, 2] -// } -// %result = "mhlo.reshape"(%broadcast) : (tensor) -// -> tensor -class ConvertTileOp : public OpRewritePattern { - public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(TF::TileOp op, - PatternRewriter &rewriter) const override { - auto input_ty = mlir::dyn_cast(op.getInput().getType()); - if (!input_ty || !input_ty.hasStaticShape()) return failure(); - ArrayRef input_shape = input_ty.getShape(); - Type element_type = input_ty.getElementType(); - - DenseIntElementsAttr multiples; - if (!matchPattern(op.getMultiples(), m_Constant(&multiples)) || - multiples.getType().getRank() != 1) - return failure(); - - const int64_t input_shape_size = input_shape.size(); - if (multiples.getNumElements() != input_shape_size) return failure(); - - SmallVector broadcasted_shape; - SmallVector broadcast_dimensions; - broadcasted_shape.reserve(input_shape.size() * 2); - broadcast_dimensions.reserve(input_shape.size()); - for (auto multiple_and_input : - llvm::zip(multiples.getValues(), input_shape)) { - int64_t multiple = std::get<0>(multiple_and_input).getSExtValue(); - int64_t input_size = std::get<1>(multiple_and_input); - - if (multiple < 0) return failure(); - - // Line input up with the next dimension in broadcasted_shape - // when broadcasting. - int64_t broadcast_dim; - int64_t output_size = input_size * multiple; - if (input_size == 1 || multiple == 1) { - // Special case for when normal broadcasting will just work. - broadcast_dim = broadcasted_shape.size(); - broadcasted_shape.push_back(output_size); - } else { - // Tiling will happen for this dimension during the ReshapeOp below. - broadcasted_shape.push_back(multiple); - broadcast_dim = broadcasted_shape.size(); - broadcasted_shape.push_back(input_size); - } - broadcast_dimensions.push_back(broadcast_dim); - } - Location loc = op.getLoc(); - Type broadcasted_type = - tensorflow::GetTypeFromTFTensorShape(broadcasted_shape, element_type); - Type output_type = op.getType(); - - Value result = rewriter.create( - loc, broadcasted_type, op.getInput(), - GetI64ElementsAttr(broadcast_dimensions, &rewriter)); - - if (output_type != broadcasted_type) { - result = rewriter.create(loc, output_type, result); - } - - rewriter.replaceOp(op, {result}); - - return success(); - } -}; - -// Converts the tf.TileOp op into mhlo.dynamic_reshape -// TODO(disc): To recover static special case's performance with folding and -// canonicalization. -class ConvertTileOpDynamic : public OpRewritePattern { - public: - using OpRewritePattern::OpRewritePattern; - // clang-format off - // Converts Tile op to HLO DBroadcastInDim and DReshape ops. - // For shape [S1, S2] and multiples [M1, M2], - // MS1 = M1 * S1; MS2 = M2 * S2 - // - // %out_dim_size = [S1, M1, S2, M2] - // %broadcast_dimensions = [1, 3]; - // %broadcast = mhlo.d_broadcast_in_dim(%input, %out_dim_size, %braodcast_dimensions); - // %shape = [MS1, MS2] - // %result = "mhlo.d_reshape"(%broadcast, %shape) : (tensor) -> tensor - // clang-format on - LogicalResult matchAndRewrite(TF::TileOp op, - PatternRewriter &rewriter) const final { - Location loc = op.getLoc(); - Value input = op.getInput(); - Value multiples = op.getMultiples(); - auto input_ty = mlir::dyn_cast(input.getType()); - if (!input_ty) return failure(); - // TODO(disc): Remove this constraint once fold and canonicalization - // implemented. - if (input_ty.hasStaticShape()) return failure(); - - Type element_type = input_ty.getElementType(); - int64_t input_rank = input_ty.getRank(); - SmallVector input_shape_values; - for (int64_t i = 0; i < input_rank; ++i) { - auto dim_size = input_ty.getDimSize(i); - if (dim_size == ShapedType::kDynamic) { - input_shape_values.push_back( - rewriter.create(loc, input, i)); - } else { - input_shape_values.push_back(rewriter.create( - loc, rewriter.getIndexAttr(dim_size))); - } - } - - auto multiples_ty = mlir::dyn_cast(multiples.getType()); - int64_t multiples_rank = multiples_ty.getRank(); - // rank of multiples input of tf.TileOp must be 1 - if (multiples_rank != 1) return failure(); - // multiples input of tf.TileOp must be fixed shaped - if ((!multiples_ty.hasStaticShape()) || - (multiples_ty.getDimSize(0) != input_rank)) { - return failure(); - } - Type index_ty = rewriter.getIndexType(); - // %out_dim_size - SmallVector out_dim_size; - out_dim_size.reserve(input_rank * 2); - for (int64_t dim_idx = 0; dim_idx < input_rank; ++dim_idx) { - Value index = rewriter.create( - loc, rewriter.getIndexAttr(dim_idx)); - Value multiples_size = - rewriter.create(loc, multiples, ValueRange{index}); - Value multiples_size_casted = - rewriter.create(loc, index_ty, multiples_size); - out_dim_size.push_back(multiples_size_casted); - out_dim_size.push_back(input_shape_values[dim_idx]); - } - SmallVector broadcast_dimensions; - broadcast_dimensions.reserve(input_rank); - for (int64_t dim_idx = 0; dim_idx < input_rank; ++dim_idx) { - broadcast_dimensions.push_back(1 + 2 * dim_idx); - } - auto broadcast_dims_attr = - GetI64ElementsAttr(broadcast_dimensions, &rewriter); - - Value out_dim_size_tensor = rewriter.create( - loc, - tensorflow::GetTypeFromTFTensorShape( - {static_cast(out_dim_size.size())}, index_ty), - out_dim_size); - SmallVector broadcast_shape(input_rank * 2, - ShapedType::kDynamic); - RankedTensorType broadcast_type = - tensorflow::GetTypeFromTFTensorShape(broadcast_shape, element_type); - Value broadcast = rewriter.create( - loc, broadcast_type, input, out_dim_size_tensor, broadcast_dims_attr); - - // %shape = [MS1, MS2] - SmallVector shape_values; - shape_values.reserve(input_rank); - for (int64_t i = 0; i < input_rank; ++i) { - Value dim_size_value = rewriter.create( - loc, out_dim_size[2 * i], out_dim_size[2 * i + 1]); - shape_values.push_back(dim_size_value); - } - Value shape = rewriter.create( - loc, tensorflow::GetTypeFromTFTensorShape({input_rank}, index_ty), - shape_values); - rewriter.replaceOpWithNewOp(op, op.getType(), - broadcast, shape); - return success(); - } -}; - -template -class ConvertMaxPoolGradOp : public OpRewritePattern { - public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(OpTy op, - PatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - - Type element_type = - mlir::cast(op.getOrigInput().getType()).getElementType(); - - // Compute paddings using the original input and kernel shape and strides. - // Here, ReduceWindow op as used as the MaxPool op is lowered to the - // ReduceWindow op. - auto input_ty = - mlir::dyn_cast(op.getOrigInput().getType()); - if (!input_ty) return failure(); - DenseIntElementsAttr paddings_attr = GetReduceWindowPaddingAsAttr( - input_ty.getShape(), op.getKsize(), op.getStrides(), op.getPadding(), - &rewriter); - - auto result = rewriter.create( - loc, op.getType(), op.getOrigInput(), op.getGrad(), - GetScalarConstOfType(element_type, loc, 0, &rewriter), - GetI64ElementsAttr(op.getKsize()), GetI64ElementsAttr(op.getStrides()), - paddings_attr); - - BuildReduceBody(element_type, &result.getScatter(), &rewriter); - { - OpBuilder::InsertionGuard guard(rewriter); - Block *block = rewriter.createBlock(&result.getSelect()); - - // Block arguments are scalars of the given element type. - Type type = - tensorflow::GetTypeFromTFTensorShape(/*shape=*/{}, element_type); - block->addArguments({type, type}, SmallVector(2, loc)); - - auto reducer = rewriter.create(loc, block->getArgument(0), - block->getArgument(1), - ComparisonDirection::GE); - rewriter.create(loc, reducer.getResult()); - } - - rewriter.replaceOp(op, result); - - return success(); - } -}; - -using ConvertMaxPool2DGradOp = - ConvertMaxPoolGradOp; -using ConvertMaxPool3DGradOp = - ConvertMaxPoolGradOp; - -// Converts tf.Conv?DBackpropInputOp into: -// %rev_filter = "mhlo.reverse"(%filter) -// %result = "mhlo.convolution"(%out_backprop, %rev_filter) -template -class ConvertConvBackpropInputOp : public OpRewritePattern { - public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(OpTy op, - PatternRewriter &rewriter) const override { - // Unpack all of the attributes. - tensorflow::TensorFormat data_format; - if (!FormatFromString(op.getDataFormat().str(), &data_format)) - return op.emitOpError("invalid data format"); - constexpr int num_dims = num_spatial_dims + 2; - int batch_dim = GetTensorBatchDimIndex(num_dims, data_format); - - tensorflow::Padding padding; - if (!GetPaddingFromString(op.getPadding().str(), &padding).ok()) - return failure(); - - auto out_backprop_ty = - mlir::dyn_cast(op.getOutBackprop().getType()); - auto filter_ty = mlir::dyn_cast(op.getFilter().getType()); - - // With the exception of out_backprop's batch dimension, out_backprop and - // filter need to have static shape. Filter is validated here, out_backprop - // is mostly validated at use. - if (!out_backprop_ty || !filter_ty || !filter_ty.hasStaticShape()) - return failure(); - - // Compute input_shape by supporting either: - // 1) Fully static shapes, represented as constants. - // 2) Static shapes with a dynamic batch dimension, represented as - // 1D tf.Pack of a batch dimension (can be static or dynamic) - // and other dimensions (can only be static), for example: - // "tf.Pack"(%142, %cst_301, %cst_301, %cst_300) {axis = 0 : i64, ...} - std::vector input_shape; - DenseIntElementsAttr input_shape_attr; - if (matchPattern(op.getInputSizes(), m_Constant(&input_shape_attr)) && - input_shape_attr.getType().getRank() == 1) { - input_shape.insert(input_shape.end(), - input_shape_attr.getValues().begin(), - input_shape_attr.getValues().end()); - } else { - auto pack = op.getInputSizes().template getDefiningOp(); - if (!pack || pack.getAxis() != 0) return failure(); - auto pack_ty = mlir::dyn_cast(pack.getType()); - if (!pack_ty || pack_ty.getRank() != 1) return failure(); - for (auto i = 0; i < pack_ty.getDimSize(0); ++i) { - if (i == batch_dim) { - // We don't use the batch dimension below, so we don't care about - // its size. Might as well populate it with -1. - input_shape.push_back(ShapedType::kDynamic); - } else { - DenseIntElementsAttr input_dims_attr; - if (matchPattern(pack.getValues()[i], m_Constant(&input_dims_attr)) && - input_dims_attr.getType().getRank() == 0) { - input_shape.push_back(input_dims_attr.getSplatValue()); - } else { - return failure(); - } - } - } - } - - auto dilations_attr = GetI64ElementsAttr(op.getDilations()); - std::vector dilations{ - dilations_attr.template getValues().begin(), - dilations_attr.template getValues().end()}; - auto strides_attr = GetI64ElementsAttr(op.getStrides()); - std::vector strides{ - strides_attr.template getValues().begin(), - strides_attr.template getValues().end()}; - - std::vector explicit_paddings; - if (padding == tensorflow::Padding::EXPLICIT) { - // EXPLICIT padding mode and the associated attribute is limited to - // Conv2DBackpropInput. So, fetch attribute by identifier instead of the - // op.explicit_paddings() attribute getter. - ArrayRef explicit_paddings_attr = - op->template getAttrOfType("explicit_paddings").getValue(); - explicit_paddings.reserve(explicit_paddings_attr.size()); - for (Attribute explicit_padding : explicit_paddings_attr) - explicit_paddings.push_back( - mlir::cast(explicit_padding).getInt()); - } - - ArrayRef filter_shape = filter_ty.getShape(); - - // Compute ConvDimensionNumbers, dilation, and padding. - SmallVector spatial_dims; - SmallVector lhs_dilation; - SmallVector rhs_dilation; - SmallVector paddings; - - for (int i : llvm::seq(0, num_spatial_dims)) { - const int64_t spatial_dim = - GetTensorSpatialDimIndex(num_dims, data_format, i); - spatial_dims.push_back(spatial_dim); - - // Prepare metadata indexed by spatial_dim for computing pad_before - // and pad_after. - int64_t input_size = input_shape[spatial_dim]; - if (input_size == ShapedType::kDynamic) return failure(); - int64_t output_size = out_backprop_ty.getDimSize(spatial_dim); - if (output_size == ShapedType::kDynamic) return failure(); - int64_t filter_size = filter_ty.getDimSize(i); - int64_t stride = strides[spatial_dim]; - int64_t dilation = dilations[spatial_dim]; - - // Compute pad_before and pad_after following the logic from - // ConvBackpropComputeDimensionsV2. (Unfortunately, we cannot call - // the function in question because it doesn't work with dynamic dims). - int64_t padding_before = -1, padding_after = -1; - if (padding == tensorflow::Padding::EXPLICIT) { - padding_before = explicit_paddings[2 * spatial_dim]; - padding_after = explicit_paddings[2 * spatial_dim + 1]; - } - int64_t expected_output_size = 0; - auto status = GetWindowedOutputSizeVerbose( - input_size, filter_size, dilation, stride, padding, - &expected_output_size, &padding_before, &padding_after); - if (!status.ok()) return failure(); - if (output_size != expected_output_size) return failure(); - int64_t effective_filter_size = (filter_size - 1) * dilation + 1; - int64_t pad_before = effective_filter_size - 1 - padding_before; - int64_t padded_out_size = input_size + effective_filter_size - 1; - int64_t expanded_output_size = (output_size - 1) * stride + 1; - int64_t pad_after = padded_out_size - expanded_output_size - pad_before; - - // Populate metadata for the upcoming mhlo.conv op using the result of - // the computations performed above. - lhs_dilation.push_back(stride); - rhs_dilation.push_back(dilation); - paddings.push_back(pad_before); - paddings.push_back(pad_after); - } - - RankedTensorType paddings_ty = tensorflow::GetTypeFromTFTensorShape( - {num_spatial_dims, 2}, rewriter.getIntegerType(64)); - auto paddings_attr = DenseIntElementsAttr::get(paddings_ty, paddings); - - Value filter = op.getFilter(); - - const int feature_dim = - tensorflow::GetTensorFeatureDimIndex(num_dims, data_format); - const int64_t in_depth = *(input_shape.begin() + feature_dim); - if (in_depth == ShapedType::kDynamic) return failure(); - const int64_t filter_in_depth = filter_shape[num_spatial_dims]; - const int64_t feature_group_count = in_depth / filter_in_depth; - - if (feature_group_count != 1) { - // 1. Reshape filter from - // [H, W, ..., filter_in_depth, out_depth] to - // [H, W, ..., filter_in_depth, G, out_depth / G]. - auto new_shape = llvm::to_vector<6>(filter_shape); - new_shape.back() = feature_group_count; - new_shape.push_back(filter_shape.back() / feature_group_count); - Type filter_element_ty = filter_ty.getElementType(); - auto ty = - tensorflow::GetTypeFromTFTensorShape(new_shape, filter_element_ty); - filter = rewriter.create(op.getLoc(), ty, filter); - - // 2. Transpose to [H, W, ..., G, filter_in_depth, out_depth / G]. - llvm::SmallVector perm(num_dims + 1); - std::iota(perm.begin(), perm.end(), 0); - std::swap(perm[num_spatial_dims], perm[num_spatial_dims + 1]); - std::swap(new_shape[num_spatial_dims], new_shape[num_spatial_dims + 1]); - ty = tensorflow::GetTypeFromTFTensorShape(new_shape, filter_element_ty); - filter = rewriter.create( - op.getLoc(), ty, filter, GetI64ElementsAttr(perm, &rewriter)); - - // 3. Reshape to [H, W, ..., in_depth, out_depth / G]. - new_shape[num_spatial_dims] *= new_shape[num_spatial_dims + 1]; - new_shape[num_spatial_dims + 1] = new_shape.back(); - new_shape.pop_back(); - ty = tensorflow::GetTypeFromTFTensorShape(new_shape, filter_element_ty); - filter = rewriter.create(op.getLoc(), ty, filter); - } - - SmallVector kernel_spatial_dims; - kernel_spatial_dims.resize(num_spatial_dims); - std::iota(kernel_spatial_dims.begin(), kernel_spatial_dims.end(), 0); - - // Mirror the filter in the spatial dimensions. - filter = rewriter.create( - op.getLoc(), filter, - GetI64ElementsAttr(kernel_spatial_dims, &rewriter)); - - // activation gradients - // = gradients (with padding and dilation) mirrored_weights - Value result = rewriter.create( - op.getLoc(), op.getType(), op.getOutBackprop(), filter, - /*window_strides=*/ - GetI64ElementsAttrForValue(/*size=*/num_spatial_dims, /*val=*/1, - &rewriter), - /*padding=*/paddings_attr, GetI64ElementsAttr(lhs_dilation, &rewriter), - GetI64ElementsAttr(rhs_dilation, &rewriter), - /*window_reversal=*/nullptr, - ConvDimensionNumbersAttr::get( - rewriter.getContext(), - /*inputBatchDimension=*/batch_dim, - /*inputFeatureDimension=*/feature_dim, - /*inputSpatialDimensions=*/spatial_dims, - // TF filter shape is [ H, W, ..., inC, outC ] - // Transpose the input and output features for computing the - // gradient. - /*kernelInputFeatureDimension=*/ - num_spatial_dims + 1, - /*kernelOutputFeatureDimension=*/ - num_spatial_dims, - /*kernelSpatialDimensions=*/kernel_spatial_dims, - /*outputBatchDimension=*/batch_dim, - /*outputFeatureDimension=*/feature_dim, - /*outputSpatialDimensions=*/spatial_dims), - rewriter.getI64IntegerAttr(feature_group_count), - /*batch_group_count=*/rewriter.getI64IntegerAttr(1), - /*precision_config=*/GetPrecisionConfig(&rewriter)); - - rewriter.replaceOp(op, {result}); - - return success(); - } -}; - -using ConvertConv2DBackpropInputOp = - ConvertConvBackpropInputOp; -using ConvertConv3DBackpropInputOp = - ConvertConvBackpropInputOp; - -// Converts tf.Conv?DBackpropFilterOp into: -// %result = "mhlo.convolution"(%input, %out_backprop) -template -class ConvertConvBackpropFilterOp : public OpRewritePattern { - public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(OpTy op, - PatternRewriter &rewriter) const override { - // Unpack all of the attributes. - tensorflow::TensorFormat data_format; - if (!FormatFromString(op.getDataFormat().str(), &data_format)) - return op.emitOpError("invalid data format"); - - tensorflow::Padding padding; - if (!GetPaddingFromString(op.getPadding().str(), &padding).ok()) - return failure(); - - auto out_backprop_ty = - mlir::dyn_cast(op.getOutBackprop().getType()); - auto input_ty = mlir::dyn_cast(op.getInput().getType()); - - for (RankedTensorType ty : {out_backprop_ty, input_ty}) - if (!ty || !ty.hasStaticShape()) return failure(); - - ArrayRef out_backprop_shape = out_backprop_ty.getShape(); - ArrayRef input_shape = input_ty.getShape(); - - DenseIntElementsAttr filter_shape_attr; - if (!matchPattern(op.getFilterSizes(), m_Constant(&filter_shape_attr)) || - filter_shape_attr.getType().getRank() != 1) - return failure(); - - auto dilations_attr = GetI64ElementsAttr(op.getDilations()); - std::vector dilations{ - dilations_attr.template getValues().begin(), - dilations_attr.template getValues().end()}; - auto strides_attr = GetI64ElementsAttr(op.getStrides()); - std::vector strides{ - strides_attr.template getValues().begin(), - strides_attr.template getValues().end()}; - - std::vector explicit_paddings; - if (padding == tensorflow::Padding::EXPLICIT) { - // EXPLICIT padding mode and the associated attribute is limited to - // Conv2DBackpropFilter. So, fetch attribute by identifier instead of the - // op.explicit_paddings() attribute getter. - ArrayRef explicit_paddings_attr = - op->template getAttrOfType("explicit_paddings").getValue(); - explicit_paddings.reserve(explicit_paddings_attr.size()); - for (Attribute explicit_padding : explicit_paddings_attr) - explicit_paddings.push_back( - mlir::cast(explicit_padding).getInt()); - } - - constexpr int num_dims = num_spatial_dims + 2; - auto filter_shape = filter_shape_attr.getValues(); - - // Reuse dimension computation logic from conv_grad_shape_utils.cc. - tensorflow::ConvBackpropDimensions dims; - if (!tensorflow::ConvBackpropComputeDimensionsV2( - /*label=*/"", num_spatial_dims, - ToTensorShape(input_shape), - ToTensorShape(filter_shape), - ToTensorShape(out_backprop_shape), dilations, - strides, padding, explicit_paddings, data_format, &dims) - .ok()) { - return failure(); - } - - // The activations (inputs) form the LHS of the convolution. - // Activations have shape: [batch, in_rows, in_cols, ..., in_depth] - // For the gradient computation, we need to: - // 1. In the case of group convolution, move the num_groups dimension before - // the batch dimension - // 2. Swap the roles of the batch and feature dimensions. - const int feature_dim = - tensorflow::GetTensorFeatureDimIndex(num_dims, data_format); - const int64_t in_depth = input_shape[feature_dim]; - const int64_t filter_in_depth = *(filter_shape.begin() + num_spatial_dims); - const int64_t batch_group_count = in_depth / filter_in_depth; - - // Compute ConvDimensionNumbers, dilation, and padding. - SmallVector spatial_dims; - SmallVector kernel_spatial_dims; - SmallVector rhs_dilation; - SmallVector paddings; - SmallVector window_strides; - - // The filter gradients are computed by a convolution of the input - // activations and the output gradients, with some appropriate padding. - // See the comment at the top of conv_grad_ops.h for details. - - for (int i : llvm::seq(0, num_spatial_dims)) { - const int64_t dim = - tensorflow::GetTensorSpatialDimIndex(num_dims, data_format, i); - kernel_spatial_dims.push_back(dim); - // Besides padding the input, we will also expand output_rows to - // expanded_out_rows = (output_rows - 1) * stride + 1 - // with zeros in between: - // - // a . . . b . . . c . . . d . . . e - // - // This is done by specifying the window dilation factors in the - // convolution HLO below. - const auto &spatial_dim_i = dims.spatial_dims[i]; - rhs_dilation.push_back(spatial_dim_i.stride); - window_strides.push_back(dilations[dim]); - - // We will also need to pad the input with zeros such that after the - // convolution, we get the right size for the filter. - // The padded_in_rows should be such that when we convolve this with the - // expanded_out_rows as a filter, we should get filter_rows back. - - const int64_t padded_in_size = - spatial_dim_i.expanded_output_size + - (spatial_dim_i.filter_size - 1) * dilations[dim]; - - // However it can be smaller than input_rows: in this - // case it means some of the inputs are not used. - // - // An example is to have input_cols = 3, filter_cols = 2 and stride = 2: - // - // INPUT = [ A B C ] - // - // FILTER = [ x y ] - // - // and the output will only have one column: a = A * x + B * y - // - // and input "C" is not used at all. - // - // We apply negative padding in this case. - const int64_t pad_total = padded_in_size - spatial_dim_i.input_size; - - // + For the EXPLICIT padding, we pad the top/left side with the explicit - // padding and pad the bottom/right side with the remaining space. - // + For the VALID padding, we don't pad anything on the top/left side - // and pad the bottom/right side with the remaining space. - // + For the SAME padding, we pad top/left side the same as bottom/right - // side. - // - // In addition, if the padded input size is smaller than the input size, - // we need to ignore some training elements of the input. We do this by - // applying negative padding on the right/bottom. - const int64_t pad_before = padding == tensorflow::Padding::EXPLICIT - ? explicit_paddings[2 * dim] - : padding == tensorflow::Padding::SAME - ? std::max(pad_total / 2, 0) - : 0; - paddings.push_back(pad_before); - paddings.push_back(pad_total - pad_before); - } - - RankedTensorType paddings_ty = tensorflow::GetTypeFromTFTensorShape( - {num_spatial_dims, 2}, rewriter.getIntegerType(64)); - auto paddings_attr = DenseIntElementsAttr::get(paddings_ty, paddings); - - SmallVector output_spatial_dimensions; - output_spatial_dimensions.resize(num_spatial_dims); - std::iota(output_spatial_dimensions.begin(), - output_spatial_dimensions.end(), 0); - - const int batch_dim = - tensorflow::GetTensorBatchDimIndex(num_dims, data_format); - - Value result = rewriter.create( - op.getLoc(), op.getType(), op.getInput(), op.getOutBackprop(), - /*window_strides=*/GetI64ElementsAttr(window_strides, &rewriter), - /*padding=*/paddings_attr, /*lhs_dilation=*/ - GetI64ElementsAttrForValue(/*size=*/num_spatial_dims, /*val=*/1, - &rewriter), - GetI64ElementsAttr(rhs_dilation, &rewriter), - /*window_reversal=*/nullptr, - ConvDimensionNumbersAttr::get( - rewriter.getContext(), - // Swap batch_dim and feature_dim in the activations. - /*inputBatchDimension=*/feature_dim, - /*inputFeatureDimension=*/batch_dim, - /*inputSpatialDimensions=*/kernel_spatial_dims, - // The gradients become the RHS of the convolution. - // The gradients have shape [batch, out_rows, out_cols, ..., - // out_depth] where the batch becomes the input feature for the - // convolution. - /*kernelInputFeatureDimension=*/batch_dim, - /*kernelOutputFeatureDimension=*/feature_dim, - /*kernelSpatialDimensions=*/kernel_spatial_dims, - /*outputBatchDimension=*/num_spatial_dims, - /*outputFeatureDimension=*/num_spatial_dims + 1, - /*outputSpatialDimensions=*/output_spatial_dimensions), - /*feature_group_count=*/rewriter.getI64IntegerAttr(1), - rewriter.getI64IntegerAttr(batch_group_count), - /*precision_config=*/GetPrecisionConfig(&rewriter)); - - rewriter.replaceOp(op, {result}); - - return success(); - } -}; - -using ConvertConv2DBackpropFilterOp = - ConvertConvBackpropFilterOp; -using ConvertConv3DBackpropFilterOp = - ConvertConvBackpropFilterOp; - -class ConvertOneHotOp : public OpRewritePattern { - public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(TF::OneHotOp op, - PatternRewriter &rewriter) const override { - auto indices_ty = - mlir::dyn_cast(op.getIndices().getType()); - if (!indices_ty || !indices_ty.hasStaticShape()) return failure(); - ArrayRef indices_shape = indices_ty.getShape(); - Type element_type = indices_ty.getElementType(); - - DenseIntElementsAttr depth_attr; - if (!matchPattern(op.getDepth(), m_Constant(&depth_attr))) { - return failure(); - } - - int64_t depth = depth_attr.getValues()[0].getSExtValue(); - int64_t axis = op.getAxis(); - if (axis == -1) axis = indices_shape.size(); - - llvm::SmallVector broadcast_dims(indices_shape.size()); - std::iota(broadcast_dims.begin(), broadcast_dims.begin() + axis, 0); - std::iota(broadcast_dims.begin() + axis, broadcast_dims.end(), axis + 1); - - llvm::SmallVector output_dims = - llvm::to_vector<4>(indices_shape); - output_dims.insert(output_dims.begin() + axis, depth); - - Location loc = op.getLoc(); - - // The iota result is the effective output shape of the computation, - // and indices must be broadcast into it. At this point, this computation - // would need to be reworked quite a bit to support dynamic shapes, so - // just using static broadcasting. - auto index_type = - tensorflow::GetTypeFromTFTensorShape(output_dims, element_type); - auto iota = rewriter.create( - loc, index_type, IntegerAttr::get(rewriter.getIntegerType(64), axis)); - auto broadcast_indices = rewriter.create( - loc, index_type, op.getIndices(), - GetI64ElementsAttr(broadcast_dims, &rewriter)); - - Value compare = rewriter.create( - loc, broadcast_indices, iota, ComparisonDirection::EQ); - Value on_value = rewriter.create( - loc, op.getType(), op.getOnValue(), - GetI64ElementsAttr(output_dims, &rewriter)); - Value off_value = rewriter.create( - loc, op.getType(), op.getOffValue(), - GetI64ElementsAttr(output_dims, &rewriter)); - Value result = rewriter.create(loc, op.getType(), compare, - on_value, off_value); - - rewriter.replaceOp(op, {result}); - - return success(); - } -}; - -// Converts InfeedDequeueTuple to XLA HLO create_token, infeed and -// get_tuple_element ops. -// -// All HLO infeed ops expect a HLO token type operand and produce a tuple -// containing a token. This HLO token type is used to order multiple infeed -// operations within a computation. The token type can come from other -// infeed/outfeed/send/recv ops or can be generated using create_token op with -// no operands. Here we emit a create_token op to generate the token type -// operand of infeed. The mhlo.InfeedOp can produce multiple results and later -// will be exported to XLA infeed op with single tuple return type. -// -// For example the following IR: -// %0:2 = "tf.InfeedDequeueTuple"() : () -> (tensor<3xi32>, tensor<4xf32>) -// -// would be lowered to -// -// %token = "mhlo.create_token"() : () -> !mhlo.token -// %data_and_token = "mhlo.infeed"(%token) {infeed_config = ""} : -// (!mhlo.token) -> tensor<3xi32>, tensor<4xf32>, !mhlo.token> -// -class ConvertInfeedDequeueTupleOp - : public OpRewritePattern { - public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(TF::InfeedDequeueTupleOp op, - PatternRewriter &rewriter) const override { - SmallVector result_types; - result_types.reserve(op.getOutputs().size() + 1); - for (const auto &output : op.getOutputs()) { - Type ty = output.getType(); - if (auto tensor_ty = mlir::dyn_cast(ty)) { - if (!tensor_ty.hasStaticShape()) return failure(); - } - result_types.push_back(ty); - } - - // Infeed takes a single token operand. Generate the token using - // create_token op to pass to the infeed op. - auto token = rewriter.create( - op.getLoc(), mhlo::TokenType::get(rewriter.getContext())); - - result_types.push_back(token.getType()); - - ArrayAttr layout; // filled in during the xla-adjust-layout pass - auto data_and_token = - rewriter.create(op.getLoc(), result_types, token, - /*infeed_config=*/rewriter.getStringAttr(""), - /*layout=*/layout); - - result_types.pop_back(); // remove the token type. - - if (op.get_XlaSharding().has_value()) { - // _XlaSharding attribute in TF is a serialized string of the OpSharding - // proto, so convert to a text form here. - ::xla::OpSharding sharding_proto; - if (tensorflow::DecodeShardingAttribute( - op.get_XlaSharding().value().str(), sharding_proto) - .failed()) { - return failure(); - } - // Token is a control signal and not a real data, so arbitrarily assign - // the token to device 0. - if (sharding_proto.type() == ::xla::OpSharding::TUPLE) { - *sharding_proto.add_tuple_shardings() = - ::xla::sharding_builder::AssignDevice(0); - data_and_token->setAttr( - kShardingAttr, - rewriter.getStringAttr(sharding_proto.SerializeAsString())); - } else { - data_and_token->setAttr(kShardingAttr, op.get_XlaShardingAttr()); - } - } - - if (op->hasAttr("layouts")) { - // Append a UnitAttr for the "token" operand of the mhlo.infeed op here to - // avoid compilation failure when exporting "layouts" attribute of the - // corresponding InfeedDequeueTupleOp to a graph node. - data_and_token->setAttr("layout", op->getAttr("layouts")); - } - llvm::SmallVector results; - results.reserve(result_types.size()); - for (const auto &idx_and_type : llvm::enumerate(result_types)) { - results.push_back(data_and_token.getResult(idx_and_type.index())); - } - rewriter.replaceOp(op, ValueRange(results)); - return success(); - } -}; - -// Converts tf.OutfeedEnqueueTuple to XLA HLO tuple, create_token and outfeed -// ops. -// -// XLA HLO outfeed op expects a token, which we generate by emitting an -// create_token op. -// -// For example the following IR: -// "tf.OutfeedEnqueueTuple"(%val_1, %val_2) : (tensor<3xi32>, tensor<4xf32>) -> -// () -// -// would be lowered to -// -// %token = "mhlo.create_token"() : () -> !mhlo.token -// %outfeed_token = "mhlo.outfeed"(%val_1, %val_2, %token) {outfeed_config = ""} -// : -// (tensor<3xi32>, tensor<4xf32>, !mhlo.token) -> !mhlo.token -// -class ConvertOutfeedEnqueueTupleOp - : public OpRewritePattern { - public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(TF::OutfeedEnqueueTupleOp op, - PatternRewriter &rewriter) const override { - auto token_type = mhlo::TokenType::get(rewriter.getContext()); - auto token = rewriter.create(op.getLoc(), token_type); - - rewriter.create(op.getLoc(), token_type, op.getInputs(), token, - /*outfeed_config=*/rewriter.getStringAttr("")); - rewriter.eraseOp(op); - return success(); - } -}; - -// Converts tf.TopKV2 to chlo.top_k. -class ConvertTopKV2Op : public OpRewritePattern { - public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(TF::TopKV2Op op, - PatternRewriter &rewriter) const override { - // We can only match when the `k` operand is a constant scalar. - DenseIntElementsAttr k_attr; - if (!matchPattern(op.getK(), m_Constant(&k_attr))) return failure(); - int64_t k = (*k_attr.begin()).getSExtValue(); - - TensorType input_type = mlir::cast(op.getInput().getType()); - if (!input_type.hasRank()) return failure(); - int64_t input_rank = input_type.getRank(); - int64_t last_dim_index = input_rank - 1; - int64_t last_dim_size = input_type.getDimSize(last_dim_index); - if (last_dim_size == ShapedType::kDynamic) return failure(); - - rewriter.replaceOpWithNewOp(op, op.getInput(), k); - return success(); - } -}; - -// Converts tf.Unpack to a series of XLA HLO slice ops. -// -// Each slice takes one element along the dimension to unpack and takes the full -// range for all other dimensions. Each slice is then reshaped to drop the -// dimension to unpack (which is always of size 1). -// TODO(antiagainst): consider changing this into a TF internal lowering pass. -class ConvertUnpackOp : public OpRewritePattern { - public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(TF::UnpackOp op, - PatternRewriter &rewriter) const override { - auto value_type = mlir::dyn_cast(op.getValue().getType()); - if (!value_type) return failure(); - - int64_t value_rank = value_type.getRank(); - int64_t axis = op.getAxis(); - if (axis < 0) axis += value_rank; - - // Parameters for constructing each slice. - SmallVector begin_indices(value_rank, 0); - auto end_indices = llvm::to_vector<4>(value_type.getShape()); - SmallVector strides(value_rank, 1); - - // All HLO slice+squeeze results used to replace the original tf.Unpack op. - SmallVector results; - results.reserve(op.getNumResults()); - - for (int i = 0, end = op.getNumResults(); i < end; ++i) { - begin_indices[axis] = i; - end_indices[axis] = i + 1; - - auto slice_op = rewriter.create( - op.getLoc(), op.getValue(), - GetI64ElementsAttr(begin_indices, &rewriter), - GetI64ElementsAttr(end_indices, &rewriter), - GetI64ElementsAttr(strides, &rewriter)); - // Reshape to drop the axis dimension. - auto result = rewriter.create( - op.getLoc(), op.getType(i), slice_op, - rewriter.getI64ArrayAttr(op.getAxis())); - results.push_back(result); - } - - rewriter.replaceOp(op, results); - return success(); - } -}; - -// Converts tf.Unpack to a series of XLA HLO Slice ops. -// TODO(disc): To recover static special case's performance with folding and -// canonicalization. -class ConvertUnpackOpDynamic : public OpRewritePattern { - public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(TF::UnpackOp op, - PatternRewriter &rewriter) const override { - auto value_type = mlir::dyn_cast(op.getValue().getType()); - if (!value_type) return failure(); - // TODO(disc): Remove this constraint once fold and canonicalization - // implemented. - if (value_type.hasStaticShape()) return failure(); - - int64_t value_rank = value_type.getRank(); - int64_t axis = op.getAxis(); - if (axis < 0) axis += value_rank; - Location loc = op.getLoc(); - - auto shape_scalar_type = rewriter.getIntegerType(32); - // Parameters for constructing each slice. - SmallVector begin_indices, end_indices, strides; - begin_indices.reserve(value_rank); - end_indices.reserve(value_rank); - strides.reserve(value_rank); - // final output shape - SmallVector shape_values; - shape_values.reserve(value_rank - 1); - // slice shape before reshape, should be like{?, 1, ?, ?} if axis = 1 - SmallVector slice_shape(value_rank, ShapedType::kDynamic); - for (int64_t dim_idx = 0; dim_idx < value_rank; ++dim_idx) { - int64_t dim_size = value_type.getDimSize(dim_idx); - if (dim_size == ShapedType::kDynamic) { - Value dim_i = rewriter.create( - loc, shape_scalar_type, - rewriter.create(loc, op.getOperand(), dim_idx)); - end_indices.push_back(dim_i); - if (dim_idx != axis) { - shape_values.push_back(dim_i); - } - } else { - Value dim_i = rewriter.create( - loc, shape_scalar_type, - rewriter.getIntegerAttr(shape_scalar_type, dim_size)); - end_indices.push_back(dim_i); - if (dim_idx != axis) { - shape_values.push_back(dim_i); - slice_shape[dim_idx] = dim_size; - } else { - slice_shape[dim_idx] = 1; - } - } - begin_indices.push_back( - rewriter.create(loc, 0, 32)); - strides.push_back(rewriter.create(loc, 1, 32)); - } - - SmallVector results; - results.reserve(op.getNumResults()); - Type i32_ty = rewriter.getI32Type(); - for (int64_t i = 0; i < op.getNumResults(); ++i) { - begin_indices[axis] = rewriter.create(loc, i, 32); - end_indices[axis] = rewriter.create(loc, i + 1, 32); - Value slice_op = rewriter.create( - loc, - tensorflow::GetTypeFromTFTensorShape(slice_shape, - value_type.getElementType()), - op.getValue(), - rewriter.create( - loc, - tensorflow::GetTypeFromTFTensorShape( - {static_cast(begin_indices.size())}, i32_ty), - begin_indices), - rewriter.create( - loc, - tensorflow::GetTypeFromTFTensorShape( - {static_cast(end_indices.size())}, i32_ty), - end_indices), - rewriter.create( - loc, - tensorflow::GetTypeFromTFTensorShape( - {static_cast(strides.size())}, i32_ty), - strides)); - // Reshape to drop the axis dimension. - Value new_shape = rewriter.create( - loc, - tensorflow::GetTypeFromTFTensorShape( - {static_cast(shape_values.size())}, i32_ty), - shape_values); - Value reshape_op = rewriter.create(loc, op.getType(i), - slice_op, new_shape); - results.push_back(reshape_op); - } - - rewriter.replaceOp(op, results); - return success(); - } -}; - -// Converts the tf.SigmoidGradOp -// TODO(disc): To recover static special case's performance with folding and -// canonicalization. -class ConvertSigmoidGradOpDynamic : public OpRewritePattern { - public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(TF::SigmoidGradOp op, - PatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - Value y = op.getY(); - Value dy = op.getDy(); - auto tp_y = mlir::dyn_cast(y.getType()); - auto tp_dy = mlir::dyn_cast(dy.getType()); - if (!tp_y || !tp_dy) return failure(); - - // TODO(disc): Remove this constraint once fold and canonicalization - // implemented. - if (tp_y.hasStaticShape() || tp_dy.hasStaticShape()) return failure(); - - Attribute attr; - Type elem_tp = tp_y.getElementType(); - if (elem_tp.isSignlessInteger()) { - attr = rewriter.getIntegerAttr(elem_tp, 1); - } else { - assert(mlir::isa(elem_tp)); - attr = rewriter.getFloatAttr(elem_tp, 1); - } - Value one = rewriter.create( - loc, DenseElementsAttr::get( - tensorflow::GetTypeFromTFTensorShape({}, elem_tp), attr)); - - auto v0 = rewriter.create( - loc, dy, y, hlo::getBroadcastDimensionsAttr(&rewriter, dy, y)); - auto v1 = rewriter.create( - loc, one, y, hlo::getBroadcastDimensionsAttr(&rewriter, one, y)); - auto result = rewriter.create( - loc, v0, v1, hlo::getBroadcastDimensionsAttr(&rewriter, v0, v1)); - - rewriter.replaceOp(op, result.getOperation()->getResults()); - return success(); - } -}; - -// Converts TF unsorted segment reduction ops to XLA HLO scatter op. -// -// TF unsorted segment reduction op peforms the following calculation: -// -// Assume segment ids' shape is [SI0, SI1, ..., SIm] and data's shape is -// [D0, D1, ..., Dn]. Note that segment ids' shape must be a prefix of data's -// shape, so we can have data's shape represented as [SI0, SI1, ..., SIm, -// Dm+1, ..., Dn]. Then -// output[segment_ids[SI_i0, SI_i1, ..., SI_im], D_im+1, ..., D_in] = -// over data[SI_i0, SI_i1, ..., SI_im, D_im+1, ..., D_in] -// where SI_iN is in the range of [0, SIN) and D_iN is in the range of [0, DN). -// -// The op will be translated to XLA HLO scatter with the following parameters: -// * Update window dims is [segment_id_rank, data_rank). -// * Inserted window dims is {0}. -// * Scatter dims to operand dims mapping is {0}. -// * Index vector dim is segment_id_rank. -template -class GenericConvertUnsortedSegmentReductionOp : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(OpTy op, - PatternRewriter &rewriter) const override { - auto data_type = mlir::dyn_cast(op.getData().getType()); - if (!data_type) return failure(); - int64_t data_rank = data_type.getRank(); - - auto segment_ids_type = - mlir::dyn_cast(op.getSegmentIds().getType()); - if (!segment_ids_type) return failure(); - int64_t segment_ids_rank = segment_ids_type.getRank(); - - DenseIntElementsAttr num_segments_attr; - if (!matchPattern(op.getNumSegments(), m_Constant(&num_segments_attr))) - return failure(); - - // The final shape for TF unsorted segment reduction op is [num_segments] + - // data_shape[segment_ids_rank:]. - SmallVector output_shape; - output_shape.push_back((*num_segments_attr.begin()).getSExtValue()); - auto suffix = data_type.getShape().drop_front(segment_ids_rank); - output_shape.append(suffix.begin(), suffix.end()); - auto output_type = tensorflow::GetTypeFromTFTensorShape( - output_shape, data_type.getElementType()); - - // Broadcast the initial value for reduction. This will become the - // 'operand' parameter to scatter to for the final scatter op. - Value init = ConcreteClass::GetInitialValue(data_type.getElementType(), - op.getLoc(), &rewriter); - auto broadcasted_init = rewriter.create( - op.getLoc(), output_type, init, - GetI64ElementsAttr(output_shape, &rewriter)); - - // Parameters for the generated scatter op. - SmallVector inserted_window_dims(1, 0); - SmallVector scatter_dims_to_operand_dims(1, 0); - int64_t index_vector_dim = segment_ids_rank; - - // Put all parameters in a StructAttr. - auto dims_attr = ScatterDimensionNumbersAttr::get( - rewriter.getContext(), - llvm::to_vector<4>(llvm::seq(segment_ids_rank, data_rank)), - inserted_window_dims, - /*inputBatchingDims=*/{}, - /*scatterIndicesBatchingDims=*/{}, scatter_dims_to_operand_dims, - index_vector_dim); - - auto scatter = rewriter.create( - op.getLoc(), op.getType(), ValueRange(Value(broadcasted_init)), - op.getSegmentIds(), op.getData(), dims_attr); - BuildReduceBody(data_type.getElementType(), - &scatter.getUpdateComputation(), &rewriter); - - rewriter.replaceOp(op, scatter.getResult(0)); - return success(); - } -}; - -class ConvertUnsortedSegmentMaxOp - : public GenericConvertUnsortedSegmentReductionOp< - ConvertUnsortedSegmentMaxOp, TF::UnsortedSegmentMaxOp, MaxOp> { - public: - using GenericConvertUnsortedSegmentReductionOp:: - GenericConvertUnsortedSegmentReductionOp; - - static Value GetInitialValue(Type reduce_element_type, Location loc, - PatternRewriter *rewriter) { - return GetScalarLimitConstOfType(reduce_element_type, loc, hlo::kLowest, - rewriter); - } -}; - -class ConvertUnsortedSegmentMinOp - : public GenericConvertUnsortedSegmentReductionOp< - ConvertUnsortedSegmentMinOp, TF::UnsortedSegmentMinOp, MinOp> { - public: - using GenericConvertUnsortedSegmentReductionOp:: - GenericConvertUnsortedSegmentReductionOp; - - static Value GetInitialValue(Type reduce_element_type, Location loc, - PatternRewriter *rewriter) { - return GetScalarLimitConstOfType(reduce_element_type, loc, hlo::kMax, - rewriter); - } -}; - -class ConvertUnsortedSegmentProdOp - : public GenericConvertUnsortedSegmentReductionOp< - ConvertUnsortedSegmentProdOp, TF::UnsortedSegmentProdOp, MulOp> { - public: - using GenericConvertUnsortedSegmentReductionOp:: - GenericConvertUnsortedSegmentReductionOp; - - static Value GetInitialValue(Type reduce_element_type, Location loc, - PatternRewriter *rewriter) { - return GetScalarConstOfType(reduce_element_type, loc, 1, rewriter); - } -}; - -class ConvertUnsortedSegmentSumOp - : public GenericConvertUnsortedSegmentReductionOp< - ConvertUnsortedSegmentSumOp, TF::UnsortedSegmentSumOp, AddOp> { - public: - using GenericConvertUnsortedSegmentReductionOp:: - GenericConvertUnsortedSegmentReductionOp; - - static Value GetInitialValue(Type reduce_element_type, Location loc, - PatternRewriter *rewriter) { - return GetScalarConstOfType(reduce_element_type, loc, 0, rewriter); - } -}; - -// Converts tf.RandomShuffle op into a series of XLA HLO ops. -// -// tf.RandomShuffle shuffles tensors along the first dimension. If the input -// tensor's rank is 1, then it is translated into HLO sort op(s) according to -// indices randomly generated via HLO rng_uniform ops. Otherwise, it is -// translated into an HLO while op to first emulate shuffling indices using -// HLO dynamic_slice and dynamic_update_slice ops, then finally HLO gather -// with the shuffled indices. -class ConvertRandomShuffleOp : public OpRewritePattern { - public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(TF::RandomShuffleOp op, - PatternRewriter &rewriter) const override { - auto no_op = [&]() { - rewriter.replaceOp(op, op.getValue()); - return success(); - }; - - auto input_type = mlir::dyn_cast(op.getValue().getType()); - if (!input_type) return failure(); - if (input_type.hasStaticShape() && input_type.getNumElements() <= 1) - // No shuffling is required, so copy input directly to output. - return no_op(); - - int64_t input_rank = input_type.getRank(); - int64_t first_dim_size = input_type.getDimSize(0); - if (ShapedType::isDynamic(first_dim_size)) return failure(); - - if (first_dim_size <= 1) - // No shuffling is required, so copy input directly to output. - return no_op(); - - // For vectors, shuffle values by sorting instead of the obvious - // Fisher-Yates algorithm. Fisher-Yates is simple to implement and correct, - // but not easily parallelizable. For a sufficiently parallel architecture, - // it is faster to sort many times, than Fisher-Yates shuffle once. - if (input_rank == 1) { - // Shuffle values by assigning each value a random key and sorting the - // keys. Keys can collide causing detectable patterns in the shuffled - // output. Collisions translates into more ascending sub-sequences in the - // shuffled output than would be expected by chance. To avoid collisions, - // the number of possible key values must be sufficiently large. - - // How are more than 2^32 keys created? In each loop iteration, the - // algorithm sorts by random keys. Conceptually, the earlier iterations - // are sorting on the lower-order bits of larger keys that are never - // actually assembled. - - // The expected number of collisions is n - d + d(1 - 1/d)^n, where d is - // the number of possible keys and n is the number of values. If d = n^2, - // then the limit as n goes to infinity is 1/2. If d = n^3, then the limit - // as n goes to infinity is zero. - - // This implementation ensures that the key-space is greater than or equal - // to the cube of the number of values. The risk of collisions can be - // further reduced by increasing Exponent at the expense of - // performance. - - // For Exponent = 2, the expected number of collisions per shuffle is - // maximized at n = floor((2^32-1)^(1/2)) = 65535 where the expectation is - // about 1/2. - - // For Exponent = 3, the expected number of collisions per shuffle is - // maximized at n = floor((2^32-1)^(1/3)) = 1625 where the expectation is - // about 1/3255. - - // For Exponent = 4, the expected number of collisions per shuffle is - // maximized at n = floor((2^32-1)^(1/4)) = 255 where the expectation is - // about 1/132622. - constexpr int exponent = 3; - int64_t num_elements = input_type.getNumElements(); - uint32_t u32_max = std::numeric_limits::max(); - int rounds = - std::ceil(exponent * std::log(num_elements) / std::log(u32_max)); - - Value current = op.getValue(); - for (int i = 0; i < rounds; ++i) { - auto keys = - CreateRngUniform32(op.getLoc(), num_elements, /*lower_limit=*/0, - /*upper_limit=*/u32_max, &rewriter); - auto sorted = createSortOp( - &rewriter, op.getLoc(), {keys, current}, - {rewriter.getIntegerType(32), input_type.getElementType()}, - /*dimension=*/-1, /*isStable=*/false, - /*direction=*/ComparisonDirection::LT); - current = sorted.getResult(1); - } - rewriter.replaceOp(op, current); - return success(); - } - - // The Fisher-Yates algorithm. - - // Generate range(n) as the initial value for the indices to be swapped. - auto indices_type = tensorflow::GetTypeFromTFTensorShape( - {first_dim_size}, rewriter.getIntegerType(32)); - Value indices = rewriter.create( - op.getLoc(), indices_type, rewriter.getI64IntegerAttr(0)); - - // Generate random numbers to be used as swaps for the indices. - Value swaps = CreateRngUniform32(op.getLoc(), first_dim_size, 0, - first_dim_size, &rewriter); - - // While loop body to perform index swaps. - auto swap_body_fn = [&](Location loc, Value i, ArrayRef old_values, - SmallVectorImpl *new_values, - OpBuilder *builder) { - Value swaps = old_values[0]; - Value indices = old_values[1]; - - auto scalar_i32_type = - tensorflow::GetTypeFromTFTensorShape({}, builder->getIntegerType(32)); - auto one_cross_i64_type = tensorflow::GetTypeFromTFTensorShape( - {1}, builder->getIntegerType(64)); - - auto scalar_one = - DenseIntElementsAttr::get(one_cross_i64_type, ArrayRef(1)); - - // We need to swap the indices[i] with indices[swaps[i]]. First get - // these index values. - Value source_index = - builder->create(loc, indices, i, scalar_one); - Value swap_index = builder->create( - loc, scalar_i32_type, - builder->create(loc, swaps, i, scalar_one)); - Value target_index = builder->create( - loc, indices, swap_index, scalar_one); - - // Then perform the swap. - // indices[i] <- indices[swaps[i]] - indices = builder->create( - loc, indices.getType(), indices, target_index, llvm::ArrayRef(i)); - // indices[swaps[i]] <- indices[i] - indices = builder->create( - loc, indices.getType(), indices, source_index, - llvm::ArrayRef(swap_index)); - - // Update new values. - new_values->assign({swaps, indices}); - }; - - // Create a while op to swap indices. - SmallVector while_output; - CreateWhile32(op.getLoc(), first_dim_size, swap_body_fn, {swaps, indices}, - &while_output, &rewriter); - Value swaped_indices = while_output[1]; - - // Gather the data using the swapped indices as the shuffled order. - auto slice_sizes = tensorflow::ConvertMlirShapeToTF(input_type.getShape()); - slice_sizes[0] = 1; - auto dims_attr = GatherDimensionNumbersAttr::get( - rewriter.getContext(), - /*offsetDims=*/llvm::to_vector<4>(llvm::seq(1, input_rank)), - /*collapsedSliceDims=*/{0}, - /*operandBatchingDims=*/{}, - /*startIndicesBatchingDims=*/{}, - /*startIndexMap=*/{0}, - /*indexVectorDim=*/1); - - SmallVector slice_sizes_values; - for (auto i = 0; i < slice_sizes.size(); ++i) { - if (slice_sizes[i] == tensorflow::kTFDynamicSize) { - Value i_const = rewriter.create( - op.getLoc(), rewriter.getIndexAttr(i)); - Value slice_size_index = - rewriter.create(op.getLoc(), op.getValue(), i_const); - Value index_to_i64 = rewriter.create( - op.getLoc(), rewriter.getI64Type(), slice_size_index); - Value i64_to_tensor = rewriter.create( - op.getLoc(), - tensorflow::GetTypeFromTFTensorShape({1}, rewriter.getI64Type()), - index_to_i64); - slice_sizes_values.push_back(i64_to_tensor); - } else { - slice_sizes_values.push_back(rewriter.create( - op.getLoc(), GetI64ElementsAttr({slice_sizes[i]}, &rewriter))); - } - } - - auto slice_sizes_concat = rewriter.create( - op.getLoc(), slice_sizes_values, rewriter.getI64IntegerAttr(0)); - rewriter.replaceOpWithNewOp( - op, op.getType(), op.getValue(), swaped_indices, slice_sizes_concat, - dims_attr); - - return success(); - } -}; - -// Converts an XlaSharding op to a XLA HLO shard op with sharding attributes. -class ConvertXlaShardingOp : public OpRewritePattern { - public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(TF::XlaShardingOp op, - PatternRewriter &rewriter) const override { - // TODO(b/148313088): define sharding attribute struct in MLIR intead of - // using a string. - if (!op.get_XlaSharding().has_value()) return failure(); - - NamedAttribute call_target_name = rewriter.getNamedAttr( - "call_target_name", rewriter.getStringAttr("Sharding")); - - auto custom_call = rewriter.create( - op.getLoc(), op.getType(), op.getInput(), - ArrayRef{call_target_name}); - custom_call->setAttr(kShardingAttr, op.get_XlaShardingAttr()); - rewriter.replaceOp(op, custom_call.getResult(0)); - - return success(); - } -}; - -// Converts a TF InplaceUpdate op to DynamicUpdateSlice HLO. -class ConvertInplaceUpdateOp : public OpRewritePattern { - public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(TF::InplaceUpdateOp op, - PatternRewriter &rewriter) const override { - auto input = mlir::dyn_cast>(op.getX()); - if (!input) return failure(); - auto indices = op.getI(); - auto updates = op.getV(); - - // Slice each row of `i` and `v` to perform a separate dynamic-update-slice - // on the contents of `x`. - auto input_type = mlir::cast(input.getType()); - auto updates_type = mlir::cast(updates.getType()); - auto indices_type = mlir::cast(indices.getType()); - if (!input_type.hasRank()) return failure(); - if (!updates_type.hasRank() || updates_type.isDynamicDim(0)) - return failure(); - if (!indices_type.hasStaticShape()) return failure(); - - if (indices_type.getRank() != 1) return failure(); - - SmallVector unpacked_indices_type( - indices_type.getDimSize(0), tensorflow::GetTypeFromTFTensorShape( - {}, indices_type.getElementType())); - // Note on zero_attr integer type: DynamicUpdateSlice op start_indices are - // required to have matching types. This rewrite rule creates - // DynamicUpdateSlice ops where the first "start index" is always i32 and - // subsequent ones are constructed based on zero_attr. Thus the type - // for zero_attr needs to be i32 as well. - auto zero_attr = IntegerAttr::get(rewriter.getIntegerType(32), 0); - auto unpacked_indices = rewriter.create( - op.getLoc(), unpacked_indices_type, indices, zero_attr); - - SmallVector split_updates_shape; - split_updates_shape.append(updates_type.getShape().begin(), - updates_type.getShape().end()); - split_updates_shape.front() = 1; - SmallVector split_updates_type; - split_updates_type.resize( - updates_type.getShape().front(), - tensorflow::GetTypeFromTFTensorShape(split_updates_shape, - updates_type.getElementType())); - - auto cst = - rewriter.create(op.getLoc(), zero_attr).getResult(); - auto split_updates = rewriter.create( - op.getLoc(), split_updates_type, cst, updates); - - SmallVector input_indices; - input_indices.resize(input_type.getRank(), cst); - - for (auto pair : - llvm::zip(unpacked_indices.getOutput(), split_updates.getOutput())) { - input_indices.front() = std::get<0>(pair); - input = rewriter.create( - op.getLoc(), op.getType(), input, std::get<1>(pair), input_indices); - } - - rewriter.replaceOp(op, input); - return success(); - } -}; - -// Converts a TF XlaDynamicUpdateSlice op to DynamicUpdateSlice HLO. -class ConvertXlaDynamicUpdateSliceOp - : public OpRewritePattern { - public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(TF::XlaDynamicUpdateSliceOp op, - PatternRewriter &rewriter) const override { - auto indices_type = - mlir::dyn_cast(op.getIndices().getType()); - if (!indices_type || !indices_type.hasStaticShape() || - indices_type.getShape().size() != 1) - return failure(); - - SmallVector unpacked_indices_type( - indices_type.getDimSize(0), tensorflow::GetTypeFromTFTensorShape( - {}, indices_type.getElementType())); - auto unpacked_indices = rewriter.create( - op.getLoc(), unpacked_indices_type, op.getIndices(), - IntegerAttr::get(rewriter.getIntegerType(64), 0)); - rewriter.replaceOpWithNewOp( - op, op.getType(), op.getInput(), op.getUpdate(), - unpacked_indices.getOutput()); - return success(); - } -}; - -// Converts a TF XlaReduceScatter op to ReduceScatter HLO. -class ConvertXlaReduceScatterOp - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(TF::XlaReduceScatterOp op, - PatternRewriter &rewriter) const override { - DenseIntElementsAttr group_assignment; - if (!matchPattern(op.getGroupAssignment(), m_Constant(&group_assignment))) - return failure(); - auto replica_groups = - mlir::cast(hlo::convertElementsAttr( - group_assignment, rewriter.getIntegerType(64))); - if (replica_groups.getType().getRank() != 2) return failure(); - - APInt scatter_dimension; - if (!matchPattern(op.getScatterDimension(), - m_ConstantInt(&scatter_dimension))) - return failure(); - - Location loc = op.getLoc(); - Type element_type = getElementTypeOrSelf(op.getInput().getType()); - - auto reduce_scatter = rewriter.create( - loc, op.getType(), op.getInput(), - rewriter.getIntegerAttr(rewriter.getIntegerType(64), - scatter_dimension.getSExtValue()), - replica_groups, ChannelHandleAttr()); - StringRef reduce_op = op.getReduceOp(); - if (reduce_op == "Add") { - BuildReduceBody(element_type, &reduce_scatter.getComputation(), - &rewriter); - } else if (reduce_op == "Mul") { - BuildReduceBody(element_type, &reduce_scatter.getComputation(), - &rewriter); - } else if (reduce_op == "Min") { - BuildReduceBody(element_type, &reduce_scatter.getComputation(), - &rewriter); - } else if (reduce_op == "Max") { - BuildReduceBody(element_type, &reduce_scatter.getComputation(), - &rewriter); - } else { - // For mean, add replicas in the same group. Then divide the sum by the - // number of replicas in each group below. - assert(reduce_op == "Mean"); - BuildReduceBody(element_type, &reduce_scatter.getComputation(), - &rewriter); - } - Value result = reduce_scatter.getResult(); - - // For mean, divide the merge result by group size. - if (reduce_op == "Mean") { - int64_t replica_group_size = replica_groups.getType().getDimSize(1); - if (replica_group_size == 0) return failure(); - auto divisor = GetScalarConstOfType(element_type, loc, replica_group_size, - &rewriter); - auto broadcast_dims = rewriter.getDenseI64ArrayAttr({}); - result = rewriter.create( - loc, result, divisor.getResult(), broadcast_dims); - } - - rewriter.replaceOp(op, {result}); - return success(); - } -}; - -// Converts tf.XlaReduceWindow to mhlo.ReduceWindow -class ConvertXlaReduceWindowOp - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(TF::XlaReduceWindowOp op, - PatternRewriter &rewriter) const override { - DenseElementsAttr window_dimensions, window_strides, base_dilations, - window_dilations, padding; - if (!(matchPattern(op.getWindowDimensions(), - m_Constant(&window_dimensions)) && - matchPattern(op.getWindowStrides(), m_Constant(&window_strides)) && - matchPattern(op.getBaseDilations(), m_Constant(&base_dilations)) && - matchPattern(op.getWindowDilations(), - m_Constant(&window_dilations)) && - matchPattern(op.getPadding(), m_Constant(&padding)))) - return failure(); - - Location loc = op.getLoc(); - - SmallVector result_types{op.getResult().getType()}; - // Create the mhlo.SelectAndScatter op. - auto reduce_window_op = rewriter.create( - loc, result_types, op.getInput(), op.getInitValue(), - mlir::cast(hlo::convertElementsAttr( - window_dimensions, rewriter.getIntegerType(64))), - mlir::cast(hlo::convertElementsAttr( - window_strides, rewriter.getIntegerType(64))), - mlir::cast(hlo::convertElementsAttr( - base_dilations, rewriter.getIntegerType(64))), - mlir::cast(hlo::convertElementsAttr( - window_dilations, rewriter.getIntegerType(64))), - mlir::cast( - hlo::convertElementsAttr(padding, rewriter.getIntegerType(64)))); - // Insert a call to the reducer in the region of the mhlo op. - mlir::SymbolRefAttr func = op.getComputation(); - auto func_op = cast(SymbolTable::lookupSymbolIn( - op->getParentOfType(), func)); - auto func_ty = func_op.getFunctionType(); - BuildBodyWithCall(rewriter, loc, func, func_ty, - &reduce_window_op.getBody()); - - rewriter.replaceOp(op, reduce_window_op.getResults()); - - return success(); - } -}; - -// Converts ClipByValue to XLA's clamp operation. Includes the broadcasting -// semantics for static and dynamic cases. -class ConvertClipByValueOp : public OpRewritePattern { - public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(TF::ClipByValueOp op, - PatternRewriter &rewriter) const override { - Value input = op.getX(); - Value min = op.getClipValueMin(); - Value max = op.getClipValueMax(); - - auto input_ty = mlir::cast(input.getType()); - auto min_ty = mlir::cast(min.getType()); - auto max_ty = mlir::cast(max.getType()); - - if (!input_ty.hasRank() || !min_ty.hasRank() || !max_ty.hasRank()) { - return failure(); - } - - auto shape = rewriter.create( - op.getLoc(), - tensorflow::GetTypeFromTFTensorShape({input_ty.getRank()}, - rewriter.getI32Type()), - input); - - if (min_ty != input_ty) { - min = - rewriter.create(op.getLoc(), input_ty, min, shape); - } - - if (max_ty != input_ty) { - max = - rewriter.create(op.getLoc(), input_ty, max, shape); - } - - rewriter.replaceOpWithNewOp(op, input_ty, min, input, max); - return success(); - } -}; - -// Converts ConstOp to XLA's constant operation and introduces a tensor cast if -// needed. -class ConvertConstOp : public OpRewritePattern { - public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(TF::ConstOp op, - PatternRewriter &rewriter) const override { - // Convert only for valid HLO tensors. - auto ty = mlir::dyn_cast(op.getType()); - if (!ty || - !mlir::isa(ty.getElementType())) - return failure(); - - Location loc = op.getLoc(); - Value result = rewriter.create(loc, op.getValue()); - if (result.getType() != op.getType()) - result = rewriter.create(loc, op.getType(), result); - rewriter.replaceOp(op, result); - return success(); - } -}; - -// Converts the Cumsum or Cumprod TensorFlow op to the HLO ReduceWindow op by -// setting appropriate window dimensions, with the given aggregation op as the -// reduction function. The input tensor needs to have a static shape, and 'axis' -// must be const. The TableGen pattern is not used for this rewrite because it -// involves regions. -template -class ConvertCumOp : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(OpT op, - PatternRewriter &rewriter) const override { - auto input = mlir::dyn_cast>(op.getX()); - if (!input) return failure(); - auto input_type = mlir::dyn_cast(input.getType()); - if (!input_type || !input_type.hasStaticShape()) { - return failure(); - } - - ArrayRef input_shape = input_type.getShape(); - int64_t rank = input_shape.size(); - - // We can only match when the axis is a constant scalar. - DenseIntElementsAttr axis_attr; - if (!matchPattern(op.getAxis(), m_Constant(&axis_attr))) { - return failure(); - } - - // Get the dimension to apply the reduction on, and offset properly if it is - // negative. - int64_t axis = (*axis_attr.begin()).getSExtValue(); - if (axis < 0) { - axis += rank; - } - - // If we're supposed to sum things up in the reverse direction, we reverse - // the input and then later reverse the output. - if (op.getReverse()) { - llvm::SmallVector dims_to_reverse({axis}); - input = rewriter.create( - op.getLoc(), input, GetI64ElementsAttr(dims_to_reverse, &rewriter)); - } - - // Convert if we need to enlarge the element type's bitwidth to avoid - // precision loss. - Type input_element_type = input_type.getElementType(); - - // TODO(hinsu): Handle complex element types. - if (!input_element_type.isIntOrFloat()) return failure(); - - Type sum_element_type = GetSumAccumulationType(input_element_type); - input = rewriter.create(op.getLoc(), input, sum_element_type); - - SmallVector window_dims(rank, 1); - SmallVector window_strides(rank, 1); - window_dims[axis] = input_shape[axis]; - - SmallVector paddings(rank * 2, 0); - paddings[axis * 2] = - std::max(input_shape[axis] - 1, static_cast(0)); - auto paddings_attr = - DenseIntElementsAttr::get(tensorflow::GetTypeFromTFTensorShape( - {rank, 2}, rewriter.getIntegerType(64)), - paddings); - - int64_t init_value = (std::is_same::value) ? 0 : 1; - Value init = GetScalarConstOfType(sum_element_type, op.getLoc(), init_value, - &rewriter); - - auto reduce = rewriter.create( - op.getLoc(), input.getType(), input, init, - GetI64ElementsAttr(rewriter.getI64ArrayAttr(window_dims)), - GetI64ElementsAttr(rewriter.getI64ArrayAttr(window_strides)), - /*base_dilations=*/DenseIntElementsAttr(), - /*window_dilations=*/DenseIntElementsAttr(), paddings_attr); - BuildReduceBody(sum_element_type, &reduce.getBody(), - &rewriter); - Value result = reduce.getResult(0); - - if (op.getExclusive()) { - // In "exclusive" operation, the output will start with the "init" (0) - // values. There is no way to express that as a ReduceWindowOp, so run the - // normal operation, and then use a PadOp to add the 0 "column" on the - // left and cut away the last column on the right. - llvm::SmallVector low_padding(rank, 0); - llvm::SmallVector high_padding(rank, 0); - llvm::SmallVector interior_padding(rank, 0); - low_padding[axis] = 1; - high_padding[axis] = -1; - result = rewriter.create( - op.getLoc(), result, init, GetI64ElementsAttr(low_padding, &rewriter), - GetI64ElementsAttr(high_padding, &rewriter), - GetI64ElementsAttr(interior_padding, &rewriter)); - } - - // Convert back if we enlarged the element type's bitwidth. - result = - rewriter.create(op.getLoc(), result, input_element_type); - - if (op.getReverse()) { - llvm::SmallVector dims_to_reverse({axis}); - result = rewriter.create( - op.getLoc(), result, GetI64ElementsAttr(dims_to_reverse, &rewriter)); - } - - rewriter.replaceOp(op, result); - return success(); - } -}; - -using ConvertCumsumOp = ConvertCumOp; -using ConvertCumprodOp = ConvertCumOp; - -// Converts the Tensorflow ShapeOp to a sequence of Shape dialect and Standard -// dialect lowerings. This involves extracting the shape type, extracting and -// converting each dimension to a known integer type, and repacking into a final -// tensor. -class ConvertShapeOp : public OpRewritePattern { - public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(TF::ShapeOp op, - PatternRewriter &rewriter) const override { - Value input = op.getInput(); - - auto result_ty = mlir::dyn_cast(op.getResult().getType()); - if (!result_ty) { - return failure(); - } - - auto index_tensor = tensorflow::GetTypeFromTFTensorShape( - result_ty.getShape(), rewriter.getIndexType()); - auto shape_op = - rewriter.create(op.getLoc(), index_tensor, input); - rewriter.replaceOpWithNewOp(op, result_ty, shape_op); - return success(); - } -}; - -class ConvertDynamicExpandDimsOp : public OpRewritePattern { - public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(TF::ExpandDimsOp op, - PatternRewriter &rewriter) const override { - auto input = op.getInput(); - auto input_ty = mlir::cast(input.getType()); - auto result_ty = mlir::cast(op.getType()); - if (!result_ty.hasRank() || !input_ty.hasRank() || - result_ty.hasStaticShape()) { - return failure(); - } - - DenseIntElementsAttr expand_dims_attr; - if (!matchPattern(op.getDim(), m_Constant(&expand_dims_attr))) { - return failure(); - } - - auto shape = rewriter.create( - op.getLoc(), - tensorflow::GetTypeFromTFTensorShape({input_ty.getRank()}, - rewriter.getIndexType()), - input); - auto expand_dims = llvm::to_vector<6>(expand_dims_attr.getValues()); - - llvm::SmallVector dims; - dims.resize(result_ty.getRank()); - - auto inserted_dim = expand_dims[0].getSExtValue(); - - // Handle the negative value use case. - if (inserted_dim < 0) { - inserted_dim += result_ty.getRank(); - // This means the value is completely incorrect, just return. - if (inserted_dim < 0) { - return failure(); - } - } - - dims[inserted_dim] = - rewriter.create(op.getLoc(), 1); - - for (int i = 0; i < dims.size() - 1; i++) { - // Add the extracted dim. - Value index = rewriter.create(op.getLoc(), i); - Value dim = rewriter.create(op.getLoc(), shape, index); - dims[i >= inserted_dim ? i + 1 : i] = dim; - } - - auto from_extents = - rewriter.create(op.getLoc(), dims); - rewriter.replaceOpWithNewOp(op, result_ty, input, - from_extents); - return success(); - } -}; - -class ConvertDynamicSqueezeOp : public OpRewritePattern { - public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(TF::SqueezeOp op, - PatternRewriter &rewriter) const override { - auto input = op.getInput(); - auto input_ty = mlir::cast(input.getType()); - auto result_ty = mlir::cast(op.getType()); - if (!result_ty.hasRank() || !input_ty.hasRank() || - result_ty.hasStaticShape()) { - return failure(); - } - - // The fully dynamic case is unsupported. - if (op.getSqueezeDims().empty()) { - return failure(); - } - - SmallVector squeeze_dims; - int64_t input_rank = input_ty.getRank(); - for (const auto &squeeze_dim_apint : - op.getSqueezeDims().getAsValueRange()) { - int64_t squeeze_dim = squeeze_dim_apint.getSExtValue(); - // Handle negative inputs. - if (squeeze_dim < 0) squeeze_dim += input_rank; - assert(squeeze_dim >= 0 && squeeze_dim < input_rank && - "squeeze dim out of bounds"); - - squeeze_dims.push_back(squeeze_dim); - } - - // Collect the unsqueezed dimensions. - llvm::SmallVector dims; - for (int64_t i = 0; i != input_rank; ++i) { - if (llvm::is_contained(squeeze_dims, i)) continue; - dims.push_back(rewriter.create(op.getLoc(), input, i)); - } - - auto from_extents = - rewriter.create(op.getLoc(), dims); - rewriter.replaceOpWithNewOp(op, result_ty, input, - from_extents); - return success(); - } -}; - -// Converts tf.XlaConvV2 to mhlo.Conv -class ConvertXlaConvV2Op : public OpRewritePattern { - public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(TF::XlaConvV2Op op, - PatternRewriter &rewriter) const override { - DenseElementsAttr window_strides_attr, padding_attr, lhs_dilation_attr, - rhs_dilation_attr, feature_group_count_attr; - if (!(matchPattern(op.getWindowStrides(), - m_Constant(&window_strides_attr)) && - matchPattern(op.getPadding(), m_Constant(&padding_attr)) && - matchPattern(op.getLhsDilation(), m_Constant(&lhs_dilation_attr)) && - matchPattern(op.getRhsDilation(), m_Constant(&rhs_dilation_attr)) && - matchPattern(op.getFeatureGroupCount(), - m_Constant(&feature_group_count_attr)))) - return failure(); - - auto window_strides_named_attr = rewriter.getNamedAttr( - "window_strides", - mlir::cast(hlo::convertElementsAttr( - window_strides_attr, rewriter.getIntegerType(64)))); - - auto padding_named_attr = rewriter.getNamedAttr( - "padding", mlir::cast(hlo::convertElementsAttr( - padding_attr, rewriter.getIntegerType(64)))); - - auto lhs_dilation_named_attr = rewriter.getNamedAttr( - "lhs_dilation", - mlir::cast(hlo::convertElementsAttr( - lhs_dilation_attr, rewriter.getIntegerType(64)))); - - auto rhs_dilation_named_attr = rewriter.getNamedAttr( - "rhs_dilation", - mlir::cast(hlo::convertElementsAttr( - rhs_dilation_attr, rewriter.getIntegerType(64)))); - - int64_t feature_group_count_val = - feature_group_count_attr.getValues()[0].getInt(); - auto feature_group_count_named_attr = rewriter.getNamedAttr( - "feature_group_count", - rewriter.getI64IntegerAttr(feature_group_count_val)); - - auto batch_group_count_named_attr = - rewriter.getNamedAttr("batch_group_count", op.getBatchGroupCountAttr()); - - xla::ConvolutionDimensionNumbers dnums; - dnums.ParseFromString(op.getDimensionNumbersAttr().getValue().str()); - auto dimension_numbers_named_attr = rewriter.getNamedAttr( - "dimension_numbers", - xla::ConvertConvDimensionNumbers(dnums, &rewriter)); - - xla::PrecisionConfig precision_config; - precision_config.ParseFromString( - op.getPrecisionConfigAttr().getValue().str()); - auto precision_config_named_attr = rewriter.getNamedAttr( - "precision_config", - xla::ConvertPrecisionConfig(&precision_config, &rewriter)); - - SmallVector operands{op.getLhs(), op.getRhs()}; - NamedAttribute attrs[] = { - window_strides_named_attr, padding_named_attr, - lhs_dilation_named_attr, rhs_dilation_named_attr, - feature_group_count_named_attr, batch_group_count_named_attr, - dimension_numbers_named_attr, precision_config_named_attr}; - rewriter.replaceOpWithNewOp(op, op.getType(), operands, - llvm::ArrayRef(attrs)); - return success(); - } -}; - -// Converts tf.XlaSelectAndScatter to mhlo.SelectAndScatter -class ConvertXlaSelectAndScatterOp - : public OpRewritePattern { - public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(TF::XlaSelectAndScatterOp op, - PatternRewriter &rewriter) const override { - ElementsAttr window_dimensions, window_strides, padding; - if (!(matchPattern(op.getWindowDimensions(), - m_Constant(&window_dimensions)) && - matchPattern(op.getWindowStrides(), m_Constant(&window_strides)) && - matchPattern(op.getPadding(), m_Constant(&padding)))) - return failure(); - - Location loc = op.getLoc(); - - SmallVector result_types{op.getResult().getType()}; - // Create the mhlo.SelectAndScatter op. - auto select_and_scatter_op = rewriter.create( - loc, result_types, op.getOperand(), op.getSource(), op.getInitValue(), - mlir::cast(hlo::convertElementsAttr( - window_dimensions, rewriter.getIntegerType(64))), - mlir::cast(hlo::convertElementsAttr( - window_strides, rewriter.getIntegerType(64))), - mlir::cast( - hlo::convertElementsAttr(padding, rewriter.getIntegerType(64)))); - - auto insert_call_to = [&](const mlir::SymbolRefAttr &func, Region *region) { - auto func_op = cast(SymbolTable::lookupSymbolIn( - op->getParentOfType(), func)); - auto func_ty = func_op.getFunctionType(); - BuildBodyWithCall(rewriter, loc, func, func_ty, region); - }; - - // Insert a call to the select function in the select region of the mhlo op. - insert_call_to(op.getSelect(), &select_and_scatter_op.getSelect()); - // Insert a call to the scatter function in the scatter region of the mhlo - // op. - insert_call_to(op.getScatter(), &select_and_scatter_op.getScatter()); - - rewriter.replaceOp(op, select_and_scatter_op.getResult()); - - return success(); - } -}; - -// Convert tf.XlaSort to mhlo.Sort -class ConvertXlaSortOp : public OpRewritePattern { - public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(TF::XlaSortOp op, - PatternRewriter &rewriter) const override { - // Create the sort op. - Type element_type = getElementTypeOrSelf(op.getInput().getType()); - auto sort_op = - createSortOp(&rewriter, op.getLoc(), {op.getInput()}, {element_type}, - /*dimension=*/-1, /*isStable=*/false, - /*direction=*/ComparisonDirection::LT); - rewriter.replaceOp(op, sort_op.getResult(0)); - return success(); - } -}; - -inline std::optional TensorFlowRngAlgToXla( - tensorflow::Algorithm alg) { - if (alg == tensorflow::RNG_ALG_PHILOX) { - return xla::RandomAlgorithm::RNG_PHILOX; - } else if (alg == tensorflow::RNG_ALG_THREEFRY) { - return xla::RandomAlgorithm::RNG_THREE_FRY; - } else if (alg == tensorflow::RNG_ALG_AUTO_SELECT) { - return xla::RandomAlgorithm::RNG_DEFAULT; - } - return std::nullopt; -} - -// Converts tf.XlaRngBitGenerator op to mhlo.RngBitGenerator op. -class ConvertXlaRngBitGeneratorOp - : public OpRewritePattern { - public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(TF::XlaRngBitGeneratorOp op, - PatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - DenseElementsAttr algorithm; - if (!(matchPattern(op.getAlgorithm(), m_Constant(&algorithm))) || - algorithm.getType().getRank()) { - return op.emitOpError() << "algorithm must be a constant scalar"; - } - auto alg = static_cast( - algorithm.getValues()[0].getInt()); - auto xla_alg = TensorFlowRngAlgToXla(alg); - if (!xla_alg) { - return op.emitOpError() << "unknown algorithm"; - } - - auto algorithm_attr = mlir::mhlo::RngAlgorithmAttr::get( - rewriter.getContext(), - *mlir::mhlo::symbolizeRngAlgorithm(xla_alg.value())); - auto rng_bit_generator_op = rewriter.create( - loc, op.getResultTypes(), algorithm_attr, op.getInitialState()); - - rewriter.replaceOp(op, rng_bit_generator_op.getResults()); - - return success(); - } -}; - -// Converts tf.XlaVariadicReduceV2 to mhlo.Reduce -class ConvertXlaVariadicReduceV2Op - : public OpRewritePattern { - public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(TF::XlaVariadicReduceV2Op op, - PatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - - mlir::SymbolRefAttr func = op.getReducer(); - auto func_op = cast(SymbolTable::lookupSymbolIn( - op->getParentOfType(), func)); - auto func_ty = func_op.getFunctionType(); - SmallVector elementTypes{llvm::map_range( - func_ty.getResults(), - [](Type ty) { return mlir::cast(ty).getElementType(); })}; - - // Create the mhlo.reduce op. - auto reduce_op = rewriter.create( - loc, op.getInputs(), op.getInitValues(), - GetI64ElementsAttr(op.getDimensionsToReduce()), elementTypes); - - // Insert a call to the reducer in the region of the mhlo op. - BuildBodyWithCall(rewriter, loc, func, func_ty, &reduce_op.getBody()); - - rewriter.replaceOp(op, reduce_op.getResults()); - - return success(); - } -}; - -// Convert tf.XlaVariadicSort to mhlo.Sort -class ConvertXlaVariadicSortOp - : public OpRewritePattern { - public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(TF::XlaVariadicSortOp op, - PatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - ElementsAttr dimension; - matchPattern(op.getDimension(), m_Constant(&dimension)); - // Create the mhlo.sort op. - auto sort_op = rewriter.create( - loc, op.getInputs(), dimension.getValues()[0].getInt(), - op.getIsStable()); - mlir::SymbolRefAttr func = op.getComparator(); - auto func_op = cast(SymbolTable::lookupSymbolIn( - op->getParentOfType(), func)); - auto func_ty = func_op.getFunctionType(); - // Insert a call to the reducer in the region of the mhlo op. - BuildBodyWithCall(rewriter, loc, func, func_ty, &sort_op.getComparator()); - - rewriter.replaceOp(op, sort_op.getResults()); - return success(); - } -}; - -// Convert tf.XlaReducePrecision to mhlo.ReducePrecision -class ConvertXlaReducePrecisionOp - : public OpRewritePattern { - public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(TF::XlaReducePrecisionOp op, - PatternRewriter &rewriter) const override { - IntegerType int32_type = rewriter.getIntegerType(32); - APInt exponent_bits = op.getExponentBitsAttr().getValue(); - // Truncating to 32-bits is safe, since pasing any number above the dtype - // size (which is at most 64, for float64) is equivalent to passing the - // dtype size. - IntegerAttr new_exponent_attr = - IntegerAttr::get(int32_type, exponent_bits.truncSSat(32)); - APInt mantissa_bits = op.getMantissaBitsAttr().getValue(); - IntegerAttr new_mantissa_attr = - IntegerAttr::get(int32_type, mantissa_bits.truncSSat(32)); - rewriter.replaceOpWithNewOp( - op, op.getType(), op.getOperand(), new_exponent_attr, - new_mantissa_attr); - return success(); - } -}; - -class LowerYieldOp : public OpConversionPattern { - public: - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite( - TF::YieldOp op, TF::YieldOp::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - rewriter.replaceOpWithNewOp(op, adaptor.getOperands()); - return success(); - } -}; - -// Returns a new tensor type from the given type with element type updated to -// the given type. -TensorType UpdateElementTypeTo(Type ty, Type element_ty) { - auto ranked_ty = mlir::dyn_cast(ty); - if (!ranked_ty) { - return UnrankedTensorType::get(element_ty); - } - return RankedTensorType::get(ranked_ty.getShape(), element_ty, - ranked_ty.getEncoding()); -} - -template -class LowerControlFlowOp : public OpConversionPattern { - public: - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite( - SrcOpT op, typename SrcOpT::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - DstOpT mhlo_op; - Location loc = op.getLoc(); - - // To handle quant type conversions, use the converted operands' element - // types and original source op's shapes and encoding to get converted op's - // result types. This is only done for the While op for now. - llvm::SmallVector element_types; - int64_t num_results = op.getNumResults(); - if constexpr (std::is_same::value) { - element_types.reserve(num_results); - for (Value value : adaptor.getOperands()) { - element_types.push_back(getElementTypeOrSelf(value.getType())); - } - } - - if constexpr (std::is_same::value) { - // Explicitly handle the Case op because it has variadic regions and takes - // the number of regions as an input along with the operands. - mhlo_op = rewriter.create(loc, op.getResultTypes(), - adaptor.getBranchIndex(), - op.getBranches().size()); - } else if constexpr (std::is_same::value) { - llvm::SmallVector while_result_types; - while_result_types.reserve(num_results); - for (int64_t idx = 0; idx < num_results; ++idx) { - auto ty = UpdateElementTypeTo(op.getType(idx), element_types[idx]); - while_result_types.push_back(ty); - } - - mhlo_op = rewriter.create(loc, TypeRange(while_result_types), - adaptor.getOperands()); - } else { - mhlo_op = rewriter.create(loc, op.getResultTypes(), - adaptor.getOperands()); - } - - int64_t num_regions = op.getNumRegions(); - for (int64_t idx = 0; idx < num_regions; ++idx) { - Region ®ion = mhlo_op.getBodyRegion(idx); - rewriter.inlineRegionBefore(op.getBodyRegion(idx), region, region.end()); - - // Update region's entry blocks argument types to handle quantized element - // types. - if constexpr (std::is_same::value) { - TypeConverter::SignatureConversion signature(num_results); - Block &block = region.front(); - for (const auto &[block_idx, original_ty] : - llvm::enumerate(block.getArgumentTypes())) { - TensorType updated_ty = - UpdateElementTypeTo(original_ty, element_types[block_idx]); - signature.addInputs(block_idx, {updated_ty}); - } - rewriter.applySignatureConversion(®ion.front(), signature); - } - } - - // Replace all uses of `op` results with the newly created op. - rewriter.replaceOp(op, mhlo_op); - return success(); - } -}; - -// Keep all these in the odml namespace to avoid collisions with the tf2xla -// version for now. -#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/generated_legalize_tf.inc" - -// LINT.IfChange -void PopulatePatterns(MLIRContext *context, RewritePatternSet *patterns) { - populateWithGenerated(*patterns); - // clang-format off - patterns->add< - ConvertAllOp, - ConvertAnyOp, - ConvertArgMaxOp, - ConvertArgMinOp, - ConvertBatchMatMulV2Op, - ConvertBiasAddOp, - ConvertBroadcastToOp, - ConvertBF16FloorDivOp, - ConvertClipByValueOp, - ConvertConstOp, - ConvertConv2DOp, - ConvertConv3DOp, - ConvertDepthConv2DOp, - ConvertConv2DBackpropFilterOp, - ConvertConv3DBackpropFilterOp, - ConvertConv2DBackpropInputOp, - ConvertConv3DBackpropInputOp, - ConvertCumprodOp, - ConvertCumsumOp, - ConvertDiagPartOp, - ConvertDynamicExpandDimsOp, - ConvertDynamicSqueezeOp, - ConvertEinsumOp, - ConvertRFFTOp, - ConvertIRFFTOp, - ConvertFusedBatchNormGradOp, - ConvertFusedBatchNormGradV2Op, - ConvertFusedBatchNormGradV3Op, - ConvertFusedBatchNormV2Op, - ConvertFusedBatchNormV3Op, - ConvertInfeedDequeueTupleOp, - ConvertIdentityNOp, - ConvertInplaceUpdateOp, - ConvertLinSpaceOp, - ConvertMaxOp, - ConvertMinOp, - ConvertAvgPool2DOp, - ConvertAvgPool3DOp, - ConvertAvgPool2DGradOp, - ConvertAvgPool3DGradOp, - ConvertMaxPool2DOp, - ConvertMaxPool3DOp, - ConvertMaxPool2DGradOp, - ConvertMaxPool3DGradOp, - ConvertMeanOp, - ConvertOneHotOp, - ConvertOutfeedEnqueueTupleOp, - ConvertProdOp, - ConvertDynamicRangeOp, - ConvertMatrixDiagPartV3Op, - ConvertRangeOp, - ConvertSelectOp, - ConvertShapeOp, - ConvertSplitOp, - ConvertSplitVOp, - ConvertStridedSliceOp, - ConvertStridedSliceGradOp, - ConvertSumOp, - ConvertTensorScatterAddOp, - ConvertTensorScatterSubOp, - ConvertTensorScatterMinOp, - ConvertTensorScatterMaxOp, - ConvertTensorScatterUpdateOp, - ConvertTileOp, - ConvertTopKV2Op, - ConvertUnpackOp, - ConvertUnsortedSegmentMaxOp, - ConvertUnsortedSegmentMinOp, - ConvertUnsortedSegmentProdOp, - ConvertUnsortedSegmentSumOp, - ConvertRandomShuffleOp, - ConvertXlaShardingOp, - ConvertXlaDynamicUpdateSliceOp, - ConvertXlaConvV2Op, - ConvertXlaReducePrecisionOp, - ConvertXlaReduceScatterOp, - ConvertXlaReduceWindowOp, - ConvertXlaRngBitGeneratorOp, - ConvertXlaSelectAndScatterOp, - ConvertXlaSortOp, - ConvertXlaVariadicReduceV2Op, - ConvertXlaVariadicSortOp, - ConvertRollOp, - ConvertLeakyReluOp, - ConvertLeakyReluGradOp, - ConvertSplitOpDynamic, - ConvertSliceOpDynamic, - ConvertTileOpDynamic, - ConvertUnpackOpDynamic, - ConvertSigmoidGradOpDynamic, - ConvertConv2DDynamic, - ConvertPadOpDynamic, - ConvertGatherNdOpDynamic, - LowerControlFlowOp, - LowerControlFlowOp, - LowerControlFlowOp, - LowerYieldOp>(context); - // clang-format on -} -// LINT.ThenChange(:MlirAlwaysOps) -} // end namespace -} // end namespace mhlo - -namespace odml { -void PopulateLegalizeTfPatterns(MLIRContext *context, - RewritePatternSet *patterns) { - mlir::mhlo::PopulatePatterns(context, patterns); -} -} // end namespace odml -} // end namespace mlir diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf_passes.h b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf_passes.h deleted file mode 100644 index 9594769e93f71c..00000000000000 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf_passes.h +++ /dev/null @@ -1,51 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_TF_PASSES_H_ -#define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_TF_PASSES_H_ - -#include -#include - -#include "llvm/ADT/StringRef.h" -#include "mlir/IR/MLIRContext.h" // from @llvm-project -#include "mlir/IR/PatternMatch.h" // from @llvm-project -#include "mlir/Pass/Pass.h" // from @llvm-project -#include "mlir/Pass/PassRegistry.h" // from @llvm-project -#include "mlir/Support/LogicalResult.h" // from @llvm-project -#include "mlir/Transforms/DialectConversion.h" // from @llvm-project - -namespace mlir { - -namespace func { -class FuncOp; -} -class ModuleOp; -class Operation; -template -class OperationPass; -class Pass; - -namespace odml { - -/// Adds the TF to TF lowerings and TF to XLA rewrite patterns to the pattern -/// list. -void PopulateLegalizeTfPatterns(MLIRContext* context, - RewritePatternSet* patterns); - -} // namespace odml -} // namespace mlir - -#endif // TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_TF_PASSES_H_ diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf_patterns.td b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf_patterns.td deleted file mode 100644 index 185216448a15ed..00000000000000 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf_patterns.td +++ /dev/null @@ -1,802 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// This is the legalization pattern definition file for TF to XLA. - -include "mlir/IR/OpBase.td" -include "mlir/Dialect/Shape/IR/ShapeOps.td" -include "mlir/Dialect/Func/IR/FuncOps.td" -include "mlir/Dialect/Tensor/IR/TensorOps.td" -include "stablehlo/dialect/ChloOps.td" -include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td" -include "mhlo/IR/hlo_ops.td" - -def SignedIntTensor : TensorOf<[I1, I8, I16, I32, I64]>; -def UnsignedIntTensor : TensorOf<[UI8, UI16, UI32, UI64]>; - -// IEEE compliant floating point tensors. -def IEEEFloatTensor : TensorOf<[F16, F32, F64]>; - -//===----------------------------------------------------------------------===// -// BatchNorm op patterns. -//===----------------------------------------------------------------------===// - -def FalseBoolAttr : AttrConstraint().getValue()">>; -def TrueBoolAttr : AttrConstraint().getValue()">>; - -def CastValueToI64: NativeCodeCall< - "CastValueToI64($0.getLoc(), $1, &$_builder)">; - -def CastValueToElementType: NativeCodeCall< - "$_builder.create($0.getLoc(), $1, " - "getElementTypeOrSelf($2.getType()))">; - -// Here, $0 is an ElementsAttr with exactly one element of type integer. $1 is -// the corresponding value of ranked tensor type whose axis is referred in $0. -def GetHLOAxisFromTFAxis : NativeCodeCall< - "GetHLOAxisFromTFAxis(" - "$0, $1.getType().cast().getRank(), &$_builder)">; - -// Same as the above but with $1 of type operand_range from variadic TensorFlow -// input. -def GetHLOAxisFromTFAxisVariadic : NativeCodeCall< - "GetHLOAxisFromTFAxis(" - "$0, (*$1.begin()).getType().cast().getRank(), " - "&$_builder)">; - -def CastElementsToI64Elements : NativeCodeCall< - "hlo::convertElementsAttr(" - "$0.cast(), $_builder.getIntegerType(64)).cast()">; - -def EmptyDotAlgorithmAttr : NativeCodeCall<"mlir::mhlo::DotAlgorithmAttr{}">; - -//===----------------------------------------------------------------------===// -// ApproximateEqual op pattern. -//===----------------------------------------------------------------------===// - -class MHLO_ComparisonDirectionValue : - ConstantAttr; - -class CHLO_ComparisonDirectionValue : - ConstantAttr; - -// TODO(b/228291745): Assert that $x and $y have the same shape. -def : Pat<(TF_ApproximateEqualOp:$result $x, $y, $tolerance), - (CHLO_BroadcastCompareOp - (MHLO_AbsOp:$abs (MHLO_SubtractOp $x, $y)), - (CastValueToElementType $result, (MHLO_ConstantOp $tolerance), $abs), - (NullDenseI64ArrayAttr), - CHLO_ComparisonDirectionValue<"LT">, - (CHLO_DEFAULT_COMPARISON_TYPE))>; - -//===----------------------------------------------------------------------===// -// Assert op pattern. -//===----------------------------------------------------------------------===// - -// HLO and XLA doesn't support Assertions. -def LowerAssert : Pattern<(TF_AssertOp $condition, $data, $summarize), []>; - -//===----------------------------------------------------------------------===// -// Binary op patterns. -//===----------------------------------------------------------------------===// - -// Check that two values can be broadcasted together -def AreBroadcastCompatible : Constraint, - "types must be broadcastable">; - -class DirectBinaryPat - : Pat<(FromOp AnyTensor:$l, AnyTensor:$r), - (ToOp $l, $r, (BinBroadcastDimensions $l, $r))>; - -foreach fromToBinPair = [[TF_AddV2Op, CHLO_BroadcastAddOp], - [TF_Atan2Op, CHLO_BroadcastAtan2Op], - [TF_ComplexOp, CHLO_BroadcastComplexOp], - [TF_DivOp, CHLO_BroadcastDivOp], - [TF_LeftShiftOp, CHLO_BroadcastShiftLeftOp], - [TF_MaximumOp, CHLO_BroadcastMaxOp], - [TF_MinimumOp, CHLO_BroadcastMinOp], - [TF_ModOp, CHLO_BroadcastRemOp], - [TF_MulOp, CHLO_BroadcastMulOp], - [TF_NextAfterOp, CHLO_BroadcastNextAfterOp], - [TF_PolygammaOp, CHLO_BroadcastPolygammaOp], - [TF_PowOp, CHLO_BroadcastPowOp], - [TF_RealDivOp, CHLO_BroadcastDivOp], - [TF_SubOp, CHLO_BroadcastSubOp], - [TF_ZetaOp, CHLO_BroadcastZetaOp]] in - def : DirectBinaryPat; - -def LowerRightShiftSigned : - Pat<(TF_RightShiftOp AnyTensor:$l, AnyTensor:$r), - (CHLO_BroadcastShiftRightArithmeticOp $l, $r, - (BinBroadcastDimensions $l, $r)), - [(SignedIntTensor $r)]>; - -def LowerRightShiftUnsigned : - Pat<(TF_RightShiftOp AnyTensor:$l, AnyTensor:$r), - (CHLO_BroadcastShiftRightLogicalOp $l, $r, - (BinBroadcastDimensions $l, $r)), - [(UnsignedIntTensor $r)]>; - -// Performs a substitution of FloorDiv, pseudo code below: -// -// return floor(div(x, y)) -def : Pat<(TF_FloorDivOp AnyTensor:$l, AnyTensor:$r), - (MHLO_FloorOp - (CHLO_BroadcastDivOp $l, $r, (BinBroadcastDimensions $l, $r))), - [(IEEEFloatTensor $l)]>; - -// Performs a substitution of FloorDiv for integer tensors, which required -// additional correction for a negative numerator / denominator. Equivalent -// pseudocode is shown below: -// -// T z = x / y -// return (z * y != x && (x < 0) != (y < 0)) ? z - 1 : z -// -// BroadcastToDimensions is used to compute the broadcast attr to higher -// dimensions. This computes the broadcast of 'l' to broadcast('l', 'r') -// without returning the broadcast of 'r' to broadcast('l', 'r'). -def : Pat<(TF_FloorDivOp AnyTensor:$l, AnyTensor:$r), - (MHLO_SelectOp - (CHLO_BroadcastAndOp - (CHLO_BroadcastCompareOp - (CHLO_BroadcastMulOp:$mul - (CHLO_BroadcastDivOp:$div $l, $r, - (BinBroadcastDimensions $l, $r)), - $r, (BinBroadcastDimensions $div, $r)), - $l, (BinBroadcastDimensions $mul, $l), CHLO_ComparisonDirectionValue<"NE">, - (CHLO_DEFAULT_COMPARISON_TYPE)), - (CHLO_BroadcastCompareOp - (CHLO_BroadcastCompareOp:$l_cmp $l, - (MHLO_ConstantOp:$l_zeros (GetScalarOfType<0> $l)), - (NullDenseI64ArrayAttr), CHLO_ComparisonDirectionValue<"LT">, - (CHLO_DEFAULT_COMPARISON_TYPE)), - (CHLO_BroadcastCompareOp:$r_cmp $r, - (MHLO_ConstantOp:$r_zeros (GetScalarOfType<0> $r)), - (NullDenseI64ArrayAttr), CHLO_ComparisonDirectionValue<"LT">, - (CHLO_DEFAULT_COMPARISON_TYPE)), - (BinBroadcastDimensions $l_cmp, $r_cmp), CHLO_ComparisonDirectionValue<"NE">, - (CHLO_DEFAULT_COMPARISON_TYPE)), - (NullDenseI64ArrayAttr)), - (CHLO_BroadcastSubOp $div, - (MHLO_ConstantOp:$ones (GetScalarOfType<1> $div)), - (NullDenseI64ArrayAttr)), $div), - [(SignedIntTensor $l)]>; - -// FloorDiv of unsigned is just div. -def : Pat<(TF_FloorDivOp AnyTensor:$l, AnyTensor:$r), - (CHLO_BroadcastDivOp $l, $r, (BinBroadcastDimensions $l, $r)), - [(UnsignedIntTensor $l)]>; - -// Performs a substitution of FloorMod designed to correct for possibly negative -// values. Pseudocode shown below: -// -// T trunc_mod = std::fmod(x, y); -// return trunc_mod != 0 && (y < 0 != trunc_mod < 0) ? trunc_mod + y -// : trunc_mod -def : Pat<(TF_FloorModOp AnyTensor:$l, AnyTensor:$r), - (MHLO_SelectOp - (CHLO_BroadcastAndOp - (CHLO_BroadcastCompareOp - (CHLO_BroadcastRemOp:$rem $l, $r, (BinBroadcastDimensions $l, $r)), - (MHLO_ConstantOp:$l_zeros (GetScalarOfType<0> $l)), - (NullDenseI64ArrayAttr), CHLO_ComparisonDirectionValue<"NE">, - (CHLO_DEFAULT_COMPARISON_TYPE)), - (CHLO_BroadcastCompareOp - (CHLO_BroadcastCompareOp:$r_cmp $r, - (MHLO_ConstantOp:$r_zeros (GetScalarOfType<0> $r)), - (NullDenseI64ArrayAttr), CHLO_ComparisonDirectionValue<"LT">, - (CHLO_DEFAULT_COMPARISON_TYPE)), - (CHLO_BroadcastCompareOp:$rem_cmp $rem, $r_zeros, - (BinBroadcastDimensions $rem, $r_zeros), CHLO_ComparisonDirectionValue<"LT">, - (CHLO_DEFAULT_COMPARISON_TYPE)), - (BinBroadcastDimensions $r_cmp, $rem_cmp), CHLO_ComparisonDirectionValue<"NE">, - (CHLO_DEFAULT_COMPARISON_TYPE)), - (NullDenseI64ArrayAttr)), - (CHLO_BroadcastAddOp $r, - $rem, (BinBroadcastDimensions $r, $rem)), $rem), - [(TensorOf<[I8, I16, I32, I64, F16, F32, F64]> $l)]>; - -// FloorMod of unsigned is just mod. -def : Pat<(TF_FloorModOp AnyTensor:$l, AnyTensor:$r), - (CHLO_BroadcastRemOp $l, $r, (BinBroadcastDimensions $l, $r)), - [(UnsignedIntTensor $l)]>; - -def Get2DTransposePerm: NativeCodeCall< - "Get2DTransposePerm($0, &$_builder)">; - -def : Pat<(TF_RiscAddOp $l, $r), (MHLO_AddOp $l, $r)>; - -def : Pat<(TF_RiscDotOp $a, $b, $transpose_a, $transpose_b), - (MHLO_DotOp - (TF_TransposeOp $a, (TF_ConstOp (Get2DTransposePerm $transpose_a))), - (TF_TransposeOp $b, (TF_ConstOp (Get2DTransposePerm $transpose_b))), - /*precision_config=*/(NullArrayAttr))>; - -//===----------------------------------------------------------------------===// -// Logical & bitwise binary op patterns. -//===----------------------------------------------------------------------===// - -class DirectLogicalBinaryPat - : Pat<(FromOp AnyTensor:$l, AnyTensor:$r), - (ToOp $l, $r, (BinBroadcastDimensions $l, $r)), - [(AnyTypeOf<[SignedIntTensor, UnsignedIntTensor]> $l)]>; - -foreach fromToBinPair = [[TF_LogicalAndOp, CHLO_BroadcastAndOp], - [TF_LogicalOrOp, CHLO_BroadcastOrOp], - [TF_BitwiseAndOp, CHLO_BroadcastAndOp], - [TF_BitwiseOrOp, CHLO_BroadcastOrOp], - [TF_BitwiseXorOp, CHLO_BroadcastXorOp]] in - def : DirectLogicalBinaryPat; - -//===----------------------------------------------------------------------===// -// Compare op patterns. -//===----------------------------------------------------------------------===// - -class DirectComparePat - : Pat<(FromOp AnyTensor:$l, AnyTensor:$r), - (CHLO_BroadcastCompareOp - $l, $r, (BinBroadcastDimensions $l, $r), direction, - (CHLO_DEFAULT_COMPARISON_TYPE))>; - -def : DirectComparePat>; -def : DirectComparePat>; -def : DirectComparePat>; -def : DirectComparePat>; - -class EqualityPat - : Pat<(FromOp AnyTensor:$l, AnyTensor:$r, - TrueBoolAttr:$incompatible_shape_error), - (CHLO_BroadcastCompareOp - $l, $r, (BinBroadcastDimensions $l, $r), direction, - (CHLO_DEFAULT_COMPARISON_TYPE)), - [(MHLO_Tensor $l)]>; - -def : EqualityPat>; -def : EqualityPat>; - -//===----------------------------------------------------------------------===// -// Concat op patterns. -//===----------------------------------------------------------------------===// - -def OneElementAttrPred - : CPred<"$_self.cast().getShapedType().getNumElements() == 1">; - -def OneElementAttr - : ElementsAttrBase, - "Scalar ElementsAttr">; - -def HasRankedFirstOperand - : Constraint()">>; - -def IsShapedTensor - : Constraint()">>; - -// This pattern converts TensorFlow axis format to HLO axis format which -// doesn't wrap around like TensorFlow and is always positive. For this -// conversion, use the first input to get inputs rank. Other inputs need not be -// ranked. -// Defining op for `axis` is TensorFlow constant op in the pattern as during -// the conversion, original Concat op operands still refers to the old ops even -// if HLO constant op is introduced as an replacement for the TensorFlow -// Constant op. -def : Pat<(TF_ConcatV2Op $inputs, (ConstantLikeMatcher OneElementAttr:$axis)), - (MHLO_ConcatenateOp $inputs, - (GetHLOAxisFromTFAxisVariadic $axis, $inputs)), - [(HasRankedFirstOperand $inputs)]>; - -//===----------------------------------------------------------------------===// -// CollectivePermute op patterns. -//===----------------------------------------------------------------------===// - -def : Pat<(TF_CollectivePermuteOp $input, (ConstantLikeMatcher ElementsAttr:$source_target_pairs)), - (MHLO_CollectivePermuteOp $input, - (CastElementsToI64Elements $source_target_pairs), - (NullChannelHandleAttr))>; - -//===----------------------------------------------------------------------===// -// CrossReplicaSum op patterns. -//===----------------------------------------------------------------------===// - -def : Pat<(TF_CrossReplicaSumOp $input, (ConstantLikeMatcher ElementsAttr:$group_assignment)), - (MHLO_CrossReplicaSumOp $input, - (CastElementsToI64Elements $group_assignment))>; - -//===----------------------------------------------------------------------===// -// All2All op patterns. -//===----------------------------------------------------------------------===// - -def ValueToVariadic: NativeCodeCall<"SmallVector{$0}">; -def : Pat<(TF_AllToAllOp AnyRankedTensor:$input, (ConstantLikeMatcher ElementsAttr:$group_assignment), I64Attr:$concat_dimension, $split_dimension, $split_count), - (MHLO_AllToAllOp (ValueToVariadic $input), $split_dimension, $concat_dimension, $split_count, (CastElementsToI64Elements $group_assignment), (NullChannelHandleAttr))>; - -//===----------------------------------------------------------------------===// -// FFT op patterns. -//===----------------------------------------------------------------------===// - -class MHLO_FftTypeValue : - ConstantAttr; - -def GetInnerDimFromValue : NativeCodeCall< - "GetInnerDimFromValue($0.getType().cast(), &$_builder)">; - -def CheckInnerDimStatic - : Constraint(), &$_builder)">>; - -def : Pat<(TF_FFTOp:$res $input), - (MHLO_FftOp $input, MHLO_FftTypeValue<"FFT">, (GetInnerDimFromValue $res)), - [(CheckInnerDimStatic $input)]>; - -def : Pat<(TF_IFFTOp:$res $input), - (MHLO_FftOp $input, MHLO_FftTypeValue<"IFFT">, (GetInnerDimFromValue $res)), - [(CheckInnerDimStatic $input)]>; - -//===----------------------------------------------------------------------===// -// GatherV2 op patterns. -//===----------------------------------------------------------------------===// - -// Here, $params and $indices needs to be ranked so that $axis and $batch_dims -// attributes can be converted from TensorFlow axis format supporting negative -// indexing to the HLO format. -def LegalizeGatherV2 : - Pat<(TF_GatherV2Op AnyRankedTensor:$params, AnyRankedTensor:$indices, - (ConstantLikeMatcher ElementsAttr:$axis), $batch_dims), - (MHLO_TorchIndexSelectOp $params, $indices, - (GetHLOAxisFromTFAxis $axis, $params), - (GetHLOAxisFromTFAxis $batch_dims, $indices))>; - -//===----------------------------------------------------------------------===// -// Pad op patterns. -//===----------------------------------------------------------------------===// - -class SliceDenseIntElementsAttrColumn2D : NativeCodeCall< - "SliceDenseIntElementsAttrColumn2D($0.cast(), " # column # " )">; - -class SliceDenseIntElementsAttr : NativeCodeCall< - "SliceDenseIntElementsAttr($0.cast(), " # index # ", " # axis # ")">; - -// Interior padding attribute based on the TF padding. -def GetInteriorPadding : NativeCodeCall < - "GetInteriorPadding($0.cast())">; - -def : Pat<(TF_PadV2Op $input, (ConstantLikeMatcher ElementsAttr:$padding), $c), - (MHLO_PadOp $input, $c, - (SliceDenseIntElementsAttrColumn2D<"0"> $padding), - (SliceDenseIntElementsAttrColumn2D<"1"> $padding), - (GetInteriorPadding $padding))>; - -//===----------------------------------------------------------------------===// -// Identity op patterns. -//===----------------------------------------------------------------------===// - -foreach src = [TF_IdentityOp, TF_StopGradientOp, TF__EagerConstOp] in - def : Pat<(src $op), (replaceWithValue $op)>; - -// TODO(b/32223192): Support CheckNumerics in HLO. -foreach src = [TF_PreventGradientOp, TF_CheckNumericsOp] in - def : Pat<(src $op, $msg), (replaceWithValue $op)>; - -//===----------------------------------------------------------------------===// -// MatMul op patterns. -//===----------------------------------------------------------------------===// - -def GetPrecisionConfig: NativeCodeCall< - "GetPrecisionConfig(&$_builder)">; - -def : Pat<(TF_MatMulOp $a, $b, $transpose_a, $transpose_b, $grad_a, $grad_b), - (MHLO_DotOp - (TF_TransposeOp $a, (TF_ConstOp (Get2DTransposePerm $transpose_a))), - (TF_TransposeOp $b, (TF_ConstOp (Get2DTransposePerm $transpose_b))), - /*precision_config=*/(GetPrecisionConfig))>; - -//===----------------------------------------------------------------------===// -// Lower `tf.ZerosLike` -//===----------------------------------------------------------------------===// - -def : Pat<(TF_ZerosLikeOp AnyTensor:$arg), - (MHLO_ConstantLike<"0"> $arg)>; - -//===----------------------------------------------------------------------===// -// Lower `tf.OnesLike` -//===----------------------------------------------------------------------===// - -def : Pat<(TF_OnesLikeOp AnyTensor:$arg), - (MHLO_ConstantLike<"1"> $arg)>; - -//===----------------------------------------------------------------------===// -// Elu op patterns. -//===----------------------------------------------------------------------===// - -def : Pat<(TF_EluOp AnyTensor:$features), - (MHLO_SelectOp - (MHLO_CompareOp - $features, - (MHLO_ConstantLike<"0">:$zero $features), - MHLO_ComparisonDirectionValue<"GT">, (MHLO_DEFAULT_COMPARISON_TYPE)), - $features, - (MHLO_Expm1Op $features))>; - -def : Pat<(TF_EluGradOp AnyStaticShapeTensor:$gradients, AnyRankedTensor:$features), - (MHLO_SelectOp - (CHLO_BroadcastCompareOp - $features, - (MHLO_ConstantOp:$zero (GetScalarOfType<0> $features)), - (BinBroadcastDimensions $zero, $features), - CHLO_ComparisonDirectionValue<"GT">, (CHLO_DEFAULT_COMPARISON_TYPE)), - $gradients, - (MHLO_MulOp - $gradients, - (CHLO_BroadcastAddOp - $features, - (MHLO_ConstantOp:$one (GetScalarOfType<1> $features)), - (BinBroadcastDimensions $one, $features))))>; - -//===----------------------------------------------------------------------===// -// Relu op patterns. -//===----------------------------------------------------------------------===// - -// TODO(hinsu): Make these patterns to TF to TF lowering. Relu6 lowering will -// require HLO canonicalization of min and max on a tensor to ClampOp. - -// TODO(hinsu): Lower quantized types after supporting them in GetScalarOfType. -def : Pat<(TF_ReluOp AnyTensor:$input), - (CHLO_BroadcastMaxOp - (MHLO_ConstantOp:$zero (GetScalarOfType<0> $input)), $input, - (BinBroadcastDimensions $zero, $input)), - [(TF_IntOrFpTensor $input)]>; - -// TODO(hinsu): Lower quantized types after supporting them in GetScalarOfType. -def : Pat<(TF_Relu6Op AnyRankedTensor:$input), - (MHLO_ClampOp (MHLO_ConstantOp (GetScalarOfType<0> $input)), $input, - (MHLO_ConstantOp (GetScalarOfType<6> $input))), - [(TF_IntOrFpTensor $input)]>; - -// ReluGrad(gradients, features) = gradients * (features > 0) -// The condition that $gradients and $features need to have the same shape is -// implicitly enforced: $zero is created to have the same shape as $features, -// MHLO_SelectOp enforces that $gradients and $zero have the same shape. -def : Pat<(TF_ReluGradOp AnyTensor:$gradients, AnyTensor:$features), - (MHLO_SelectOp - (MHLO_CompareOp $features, (MHLO_ConstantLike<"0">:$zero $features), - MHLO_ComparisonDirectionValue<"GT">, (MHLO_DEFAULT_COMPARISON_TYPE)), - $gradients, $zero)>; - -//===----------------------------------------------------------------------===// -// Softsign op patterns. -//===----------------------------------------------------------------------===// - -/// Converts a TF::SoftsignOp to HLO. -/// Softsign(features) = features / (1 + abs(features)) -def : Pat<(TF_SoftsignOp AnyTensor:$input), - (MHLO_DivOp - $input, - (MHLO_AddOp (MHLO_ConstantLike<"1"> $input), (MHLO_AbsOp $input)) - ) - >; - -/// Converts a TF::SoftsignGradOp to HLO. -/// SoftsignGrad(gradient, features) = gradient / ((1 + abs(features)) ^ 2) -def : Pattern< - (TF_SoftsignGradOp AnyRankedTensor:$gradients, AnyRankedTensor:$features), - [(CHLO_BroadcastAddOp:$add - (MHLO_ConstantOp:$one (GetScalarOfType<1> $features)), (MHLO_AbsOp $features), - (BinBroadcastDimensions $one, $features) - ), - (CHLO_BroadcastDivOp - $gradients, - (MHLO_MulOp $add, $add), - (BinBroadcastDimensions $gradients, $add) - ) - ]>; - -//===----------------------------------------------------------------------===// -// Slice op patterns. -//===----------------------------------------------------------------------===// - -def UnpackStartingIndices: NativeCodeCall< - "UnpackTensorAlongZeroDim($0.getLoc(), $1, &$_builder).getOutput()">; - -def CanBeTranslatedToDynamicSlice : Constraint())">>; - -def TFSliceSizes2HLOSliceSizes : NativeCodeCall< - "TFSliceSizes2HLOSliceSizes($0, $1, $2.cast()," - "&$_builder)">; - -def : Pat<(TF_SliceOp:$op MHLO_Tensor:$input, MHLO_Tensor:$starting_indices, - (ConstantLikeMatcher AnyAttr:$slice_sizes)), - (MHLO_DynamicSliceOp $input, - (UnpackStartingIndices $op, $starting_indices), - (TFSliceSizes2HLOSliceSizes $input, $starting_indices, $slice_sizes)), - [(CanBeTranslatedToDynamicSlice $input, $starting_indices, - $slice_sizes)]>; - -//===----------------------------------------------------------------------===// -// Select op patterns. -//===----------------------------------------------------------------------===// - - def : Pat<(TF_SelectV2Op MHLO_Tensor:$pred, MHLO_Tensor:$on_true, - MHLO_Tensor:$on_false), - (CHLO_BroadcastSelectOp $pred, $on_true, $on_false)>; - -//===----------------------------------------------------------------------===// -// PartitionedCall and LegacyCall op patterns. -//===----------------------------------------------------------------------===// - -def ArgTypesMatchCallee : Constraint< - // $0 is a resultset (possibly empty), and $_op isn't assigned. So retrieve - // the op using the builder. - CPred<"ArgTypesMatchCallee(&*$_builder.getInsertionPoint(), $1, $2)">>; - -foreach callOp = [TF_PartitionedCallOp, TF_StatefulPartitionedCallOp] in { - def : Pat<(callOp:$op $args, FlatSymbolRefAttr:$f, - $config, $config_proto, $executor_type), - (CallOp $f, $args), - [(ArgTypesMatchCallee $op, $args, $f)]>; -} - -// The extra attr on this op is _disable_call_shape_inference, which we ignore -// in the bridge. -def : Pat<(TF_LegacyCallOp:$op $args, FlatSymbolRefAttr:$f, $attr), - (CallOp $f, $args), - [(ArgTypesMatchCallee $op, $args, $f)]>; - -//===----------------------------------------------------------------------===// -// Reverse op patterns. -//===----------------------------------------------------------------------===// - -// Handles axis conversion for TF reverse. -def ConvertAxisAttr : NativeCodeCall<"ConvertAxisAttr($0, $1.cast(), &$_builder)">; - -def : Pat<(TF_ReverseV2Op AnyRankedTensor:$values, (ConstantLikeMatcher ElementsAttr:$axis)), - (MHLO_ReverseOp $values, (ConvertAxisAttr $values, $axis))>; - -//===----------------------------------------------------------------------===// -// Unary op patterns. -//===----------------------------------------------------------------------===// - -foreach Mapping = [ - [TF_AbsOp, MHLO_AbsOp], - [TF_CeilOp, MHLO_CeilOp], - [TF_ComplexAbsOp, MHLO_AbsOp], - [TF_CosOp, MHLO_CosineOp], - [TF_ExpOp, MHLO_ExpOp], - [TF_Expm1Op, MHLO_Expm1Op], - [TF_ErfOp, MHLO_ErfOp], - [TF_FloorOp, MHLO_FloorOp], - [TF_ImagOp, MHLO_ImagOp], - [TF_InvertOp, MHLO_NotOp], - [TF_IsFiniteOp, MHLO_IsFiniteOp], - [TF_LogOp, MHLO_LogOp], - [TF_Log1pOp, MHLO_Log1pOp], - [TF_LogicalNotOp, MHLO_NotOp], - [TF_NegOp, MHLO_NegOp], - [TF_RealOp, MHLO_RealOp], - [TF_RsqrtOp, MHLO_RsqrtOp], - [TF_SigmoidOp, MHLO_LogisticOp], - [TF_SinOp, MHLO_SineOp], - [TF_SqrtOp, MHLO_SqrtOp], - [TF_TanhOp, MHLO_TanhOp], - [TF_TanOp, MHLO_TanOp] - ] in { - def : Pat<(Mapping[0] MHLO_Tensor:$input), - (Mapping[1] $input)>; -} - -foreach Mapping = [ - [TF_AcosOp, CHLO_AcosOp], - [TF_AcoshOp, CHLO_AcoshOp], - [TF_AsinOp, CHLO_AsinOp], - [TF_AsinhOp, CHLO_AsinhOp], - [TF_AtanOp, CHLO_AtanOp], - [TF_AtanhOp, CHLO_AtanhOp], - [TF_CoshOp, CHLO_CoshOp], - [TF_ConjOp, CHLO_ConjOp], - [TF_DigammaOp, CHLO_DigammaOp], - [TF_ErfcOp, CHLO_ErfcOp], - [TF_IsInfOp, CHLO_IsInfOp], - [TF_LgammaOp, CHLO_LgammaOp], - [TF_SinhOp, CHLO_SinhOp], - ] in { - def : Pat<(Mapping[0] MHLO_AnyTensor:$input), - (Mapping[1] $input)>; -} - -def : Pat<(TF_AngleOp $x), (MHLO_Atan2Op (MHLO_ImagOp $x), (MHLO_RealOp $x))>; - -// TODO(bixia): Lower with Truncate=True for floating point value conversions. -def : Pat<(TF_CastOp $arg, ConstBoolAttrFalse), (MHLO_ConvertOp $arg)>; - -def : Pat<(TF_TransposeOp:$res $arg, (ConstantLikeMatcher ElementsAttr:$permutation)), - (MHLO_TransposeOp $arg, (CastElementsToI64Elements $permutation))>; - - -// Lowering these ops with static shape to mhlo.reshape -foreach TfOp = [TF_ExpandDimsOp, TF_ReshapeOp, TF_SqueezeOp, ] in { - def : Pat<(TfOp:$res MHLO_Tensor:$arg, $ignored), - (MHLO_ReshapeOp $arg), [(AnyStaticShapeTensor $res)], [], - (addBenefit 2)>; -} - -// Returns NaN if x is NaN, 0 if x is 0, -1 if x < 0 and 1 if x > 0. -def : Pat<(TF_SignOp $x), (MHLO_SignOp $x)>; - -def BothElementTypesSameWidthIntOrFloat : Constraint, - "element types must be integers or floats">; - -// TODO(mgester): Due to restrictions of xla::BitcastConvertType we currently -// only lower if both input and output types are int or float and have same width - -def : Pat<(TF_BitcastOp:$res MHLO_Tensor:$arg), - (MHLO_BitcastConvertOp $arg), - [(BothElementTypesSameWidthIntOrFloat $res, $arg)]>; - -// TODO(jpienaar): Lower constant like to constant to broadcast if dynamic -// and going to MHLO. - -//===----------------------------------------------------------------------===// -// Random ops. -//===----------------------------------------------------------------------===// -// TODO(b/148269299): handle random number generator seeds/states correctly. - -class MHLO_RngDistributionValue : - ConstantAttr; - -def : Pat<(TF_RandomUniformOp:$old $shape, $seed, $seed2), - (MHLO_RngOp - (MHLO_ConstantOp - (NativeCodeCall<"$_builder.getFloatAttr(old.getDtype(), 0.0)">)), - (MHLO_ConstantOp - (NativeCodeCall<"$_builder.getFloatAttr(old.getDtype(), 1.0)">)), - (CastValueToI64 $old, $shape), - MHLO_RngDistributionValue<"UNIFORM">), - [(IsShapedTensor $shape)]>; - -def : Pat<(TF_RandomStandardNormalOp:$old $shape, $seed, $seed2), - (MHLO_RngOp - (MHLO_ConstantOp - (NativeCodeCall<"$_builder.getFloatAttr(old.getDtype(), 0.0)">)), - (MHLO_ConstantOp - (NativeCodeCall<"$_builder.getFloatAttr(old.getDtype(), 1.0)">)), - (CastValueToI64 $old, $shape), - MHLO_RngDistributionValue<"NORMAL">), - [(IsShapedTensor $shape)]>; - -//===----------------------------------------------------------------------===// -// Sigmoid grad op. -//===----------------------------------------------------------------------===// - -// TODO(hinsu): Handle unranked inputs by broadcasting constant one to the -// shape of $l instead of having it as a constant. -def : Pat<(TF_SigmoidGradOp AnyRankedTensor:$l, AnyRankedTensor:$r), - (MHLO_MulOp - (MHLO_MulOp $r, $l), - (MHLO_SubtractOp (MHLO_ConstantOp (ConstantSplat<"1"> $l)), $l))>; - -//===----------------------------------------------------------------------===// -// Softplus op. -//===----------------------------------------------------------------------===// - -def EpsilonValue : NativeCodeCall<"GetEpsilonValue($0.getType())">; - -def : Pattern<(TF_SoftplusOp AnyTensor:$features), - [ - (MHLO_ExpOp:$features_exp $features), - (CHLO_BroadcastAddOp:$threshold - (MHLO_LogOp (MHLO_ConstantOp (EpsilonValue $features))), - (MHLO_ConstantOp (GetScalarOfType<2> $features)), - (NullDenseI64ArrayAttr) - ), - (MHLO_SelectOp:$output - (CHLO_BroadcastCompareOp - $features, - (MHLO_NegOp $threshold), - (NullDenseI64ArrayAttr), - CHLO_ComparisonDirectionValue<"GT">, - (CHLO_DEFAULT_COMPARISON_TYPE) - ), - $features, - (MHLO_SelectOp - (CHLO_BroadcastCompareOp - $features, - $threshold, - (NullDenseI64ArrayAttr), - CHLO_ComparisonDirectionValue<"LT">, - (CHLO_DEFAULT_COMPARISON_TYPE) - ), - $features_exp, - (MHLO_Log1pOp $features_exp) - ) - ), - (replaceWithValue $output) - ]>; - -//===----------------------------------------------------------------------===// -// XlaReplicaId op. -//===----------------------------------------------------------------------===// - -def : Pat<(TF_XlaReplicaIdOp), - (TF_CastOp (MHLO_ReplicaIdOp), /*truncate=*/ConstBoolAttrFalse)>; - -//===----------------------------------------------------------------------===// -// XlaGather op. -//===----------------------------------------------------------------------===// - -def ToGatherDimNumsAttr : NativeCodeCall<"GetGatherDimNumsAttr($0, &$_builder)">; - -def HasValidGatherDims : Constraint>; - -def : Pat<(TF_XlaGatherOp $operand, $start_indices, (ConstantLikeMatcher ElementsAttr:$slice_sizes), - $dimension_numbers, $indices_are_sorted), - (MHLO_GatherOp $operand, $start_indices, - (ToGatherDimNumsAttr $dimension_numbers), - (CastElementsToI64Elements $slice_sizes), - $indices_are_sorted), - [(HasValidGatherDims $dimension_numbers)]>; - -//===----------------------------------------------------------------------===// -// XlaDotOp op. -//===----------------------------------------------------------------------===// - -def ToDotDimNumsAttr : NativeCodeCall<"GetDotDimNumsAttr($0, &$_builder)">; - -def ToPrecisionConfigsAttr : NativeCodeCall<"GetPrecisionConfigAttr($0, &$_builder)">; - -def HasValidDotDims : Constraint>; - -def HasValidPrecisionConfig : Constraint>; - -def : Pat<(TF_XlaDotOp $lhs, $rhs, $dimension_numbers, $precision_config), - (MHLO_DotGeneralOp $lhs, $rhs, - (ToDotDimNumsAttr $dimension_numbers), - (ToPrecisionConfigsAttr $precision_config), - (EmptyDotAlgorithmAttr)), - [(HasValidDotDims $dimension_numbers), (HasValidPrecisionConfig $precision_config)]>; - -//===----------------------------------------------------------------------===// -// XlaDotV2Op op. -//===----------------------------------------------------------------------===// - -def : Pat<(TF_XlaDotV2Op $lhs, $rhs, $dimension_numbers, $precision_config), - (MHLO_DotGeneralOp $lhs, $rhs, - (ToDotDimNumsAttr $dimension_numbers), - (ToPrecisionConfigsAttr $precision_config), - (EmptyDotAlgorithmAttr)), - [(HasValidDotDims $dimension_numbers), (HasValidPrecisionConfig $precision_config)]>; - -//===----------------------------------------------------------------------===// -// XlaDynamicSlice op. -//===----------------------------------------------------------------------===// - -def : Pat<(TF_XlaDynamicSliceOp:$op MHLO_Tensor:$input, MHLO_Tensor:$starting_indices, - (ConstantLikeMatcher AnyAttr:$slice_sizes)), - (MHLO_DynamicSliceOp $input, - (UnpackStartingIndices $op, $starting_indices), - (TFSliceSizes2HLOSliceSizes $input, $starting_indices, $slice_sizes))>; - -//===----------------------------------------------------------------------===// -// XlaEisumOp op. -//===----------------------------------------------------------------------===// - -def : Pat<(TF_XlaEinsumOp $lhs, $rhs, $equation), - (MHLO_EinsumOp $lhs, $rhs, $equation)>; - -//===----------------------------------------------------------------------===// -// XlaOptimizationBarrierOp op. -//===----------------------------------------------------------------------===// - -def : Pat<(TF_XlaOptimizationBarrierOp $args), - (MHLO_OptimizationBarrierOp $args)>; diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/tf_stablehlo_pass.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/tf_stablehlo_pass.cc index e38cad1d4c7edc..ae4ee26eab9b8c 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/tf_stablehlo_pass.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/tf_stablehlo_pass.cc @@ -33,11 +33,11 @@ limitations under the License. #include "mlir/Transforms/Passes.h" // from @llvm-project #include "stablehlo/dialect/ChloOps.h" // from @stablehlo #include "stablehlo/dialect/Register.h" // from @stablehlo -#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf_passes.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_util.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.h" #include "tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_with_tf2xla_passes.h" +#include "tensorflow/compiler/mlir/tf2xla/transforms/passes.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/mlir_hlo/mhlo/IR/register.h" #include "xla/mlir_hlo/mhlo/transforms/passes.h" @@ -102,7 +102,7 @@ void TFToMhloPass::runOnOperation() { MLIRContext *context = func->getContext(); RewritePatternSet patterns(context); - odml::PopulateLegalizeTfPatterns(context, &patterns); + mhlo::PopulateLegalizeTfPatterns(context, &patterns); TF::PopulateTFLoweringBeforeHLOPatterns(context, &patterns); mhlo::Tf2XlaTypeConverter converter; mhlo::PopulateLegalizeTfWithTf2XlaPatterns( diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/utils.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/utils.cc deleted file mode 100644 index b120a6f02e1460..00000000000000 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/utils.cc +++ /dev/null @@ -1,55 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/utils.h" - -#include - -#include "mlir/IR/Location.h" // from @llvm-project -#include "mlir/IR/Types.h" // from @llvm-project -#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" -#include "xla/mlir_hlo/utils/hlo_utils.h" - -namespace mlir { -namespace odml { - -mhlo::ConstantOp GetScalarConstOfType(Type ty, Location loc, int64_t raw_value, - OpBuilder* builder) { - return builder->create(loc, - hlo::getScalarOfType(ty, raw_value)); -} - -mhlo::ConstantOp GetScalarNegZeroOfType(Type ty, Location loc, - OpBuilder* builder) { - return builder->create(loc, - hlo::getScalarNegZeroOfType(ty)); -} - -DenseIntElementsAttr GetI64ElementsAttr(ArrayAttr attr) { - RankedTensorType ty = - RankedTensorType::get(static_cast(attr.size()), - IntegerType::get(attr.getContext(), 64)); - return DenseIntElementsAttr::get(ty, attr.getValue()); -} - -DenseIntElementsAttr GetI64ElementsAttr(ArrayRef values, - Builder* builder) { - RankedTensorType ty = RankedTensorType::get( - {static_cast(values.size())}, builder->getIntegerType(64)); - return DenseIntElementsAttr::get(ty, values); -} - -} // namespace odml -} // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/utils.h b/tensorflow/compiler/mlir/lite/stablehlo/transforms/utils.h deleted file mode 100644 index 13ff4c4767721d..00000000000000 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/utils.h +++ /dev/null @@ -1,61 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_UTILS_H_ -#define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_UTILS_H_ - -#include "llvm/ADT/ArrayRef.h" -#include "mlir/IR/Builders.h" // from @llvm-project -#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project -#include "mlir/IR/BuiltinTypes.h" // from @llvm-project -#include "mlir/IR/Location.h" // from @llvm-project -#include "mlir/IR/Types.h" // from @llvm-project -#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" - -namespace mlir { -namespace odml { - -// Builds body for reduce op by using the template binary op as the -// reducer op. -template -void BuildReduceBody(Type element_type, Region* body, OpBuilder* builder) { - OpBuilder::InsertionGuard guard(*builder); - Block* block = builder->createBlock(body); - - // Block arguments are scalars of the given element type. - Type type = RankedTensorType::get(/*shape=*/{}, element_type); - Location loc = body->getLoc(); - block->addArguments({type, type}, SmallVector(2, loc)); - - auto reducer = - builder->create(loc, block->getArgument(0), block->getArgument(1)); - builder->create(loc, reducer.getResult()); -} - -mhlo::ConstantOp GetScalarConstOfType(Type ty, Location loc, int64_t raw_value, - OpBuilder* builder); - -mhlo::ConstantOp GetScalarNegZeroOfType(Type ty, Location loc, - OpBuilder* builder); - -// Converts an ArrayAttr to a 1D 64-bit dense elements attribute. -DenseIntElementsAttr GetI64ElementsAttr(ArrayAttr attr); -DenseIntElementsAttr GetI64ElementsAttr(llvm::ArrayRef values, - Builder* builder); - -} // namespace odml -} // namespace mlir - -#endif // TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_UTILS_H_ diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/utils_test.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/utils_test.cc deleted file mode 100644 index 63926df535b6be..00000000000000 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/utils_test.cc +++ /dev/null @@ -1,83 +0,0 @@ -/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/utils.h" - -#include - -#include -#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project -#include "mlir/IR/Location.h" // from @llvm-project -#include "mlir/IR/MLIRContext.h" // from @llvm-project -#include "mlir/IR/Types.h" // from @llvm-project -#include "mlir/Support/LLVM.h" // from @llvm-project -#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" - -namespace mlir { -namespace odml { -namespace { - -TEST(UtilsTest, GetScalarConstOfType) { - MLIRContext context; - context.loadDialect(); - OpBuilder builder(&context); - Location loc = UnknownLoc::get(&context); - Type ty = builder.getI32Type(); - mhlo::ConstantOp op = GetScalarConstOfType(ty, loc, 123, &builder); - EXPECT_EQ(op.getValue().getValues()[0], 123); - - op->destroy(); -} - -TEST(UtilsTest, GetScalarNegZeroOfType) { - MLIRContext context; - context.loadDialect(); - OpBuilder builder(&context); - Location loc = UnknownLoc::get(&context); - Type ty = builder.getF32Type(); - mhlo::ConstantOp op = GetScalarNegZeroOfType(ty, loc, &builder); - EXPECT_EQ(op.getValue().getValues()[0], -0.f); - - op->destroy(); -} - -TEST(UtilsTest, GetI64ElementsAttr) { - MLIRContext context; - context.loadDialect(); - OpBuilder builder(&context); - Location loc = UnknownLoc::get(&context); - ArrayRef values = {1, 2, 3}; - auto valuesAttr = builder.getI64ArrayAttr(values); - DenseIntElementsAttr attr = GetI64ElementsAttr(valuesAttr); - EXPECT_EQ(attr.getValues()[0], 1); - EXPECT_EQ(attr.getValues()[1], 2); - EXPECT_EQ(attr.getValues()[2], 3); -} - -TEST(UtilsTest, GetI64ElementsAttrBuilder) { - MLIRContext context; - context.loadDialect(); - OpBuilder builder(&context); - Location loc = UnknownLoc::get(&context); - ArrayRef values = {1, 2, 3}; - DenseIntElementsAttr attr = GetI64ElementsAttr(values, &builder); - EXPECT_EQ(attr.getValues()[0], 1); - EXPECT_EQ(attr.getValues()[1], 2); - EXPECT_EQ(attr.getValues()[2], 3); -} - -} // namespace - -} // namespace odml -} // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc index 041aa03134924c..e67e0e45961117 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc @@ -64,7 +64,6 @@ limitations under the License. #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" -#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf_passes.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h" #include "tensorflow/compiler/mlir/lite/transforms/dilated_conv.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" diff --git a/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf.mlir index 1939a3dc8cd875..4231990e0769d1 100644 --- a/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf.mlir +++ b/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf.mlir @@ -2654,3 +2654,3813 @@ func.func @sigmoid_grad_dynamic(%arg0: tensor, %arg1: tensor) -> t func.return %0 : tensor } +// ----- + +// CHECK-LABEL: @sin +func.func @sin(%arg0: tensor<2xf32>) -> tensor<2xf32> { + // CHECK: mhlo.sine %arg0 : tensor<2xf32> + %0 = "tf.Sin"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + func.return %0 : tensor<2xf32> +} + +// ----- + +// CHECK-LABEL: func @sin_dynamic +func.func @sin_dynamic(%arg0: tensor) -> tensor { + // CHECK: mhlo.sine %arg0 : tensor + %0 = "tf.Sin"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// ----- + +// CHECK-LABEL: func @rsqrt +func.func @rsqrt(%arg0: tensor<2xf32>) -> tensor<2xf32> { + // CHECK: mhlo.rsqrt %arg0 : tensor<2xf32> + %0 = "tf.Rsqrt"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + func.return %0 : tensor<2xf32> +} + +// ----- + +// CHECK-LABEL: func @rsqrt_dynamic +func.func @rsqrt_dynamic(%arg0: tensor) -> tensor { + // CHECK: mhlo.rsqrt %arg0 : tensor + %0 = "tf.Rsqrt"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// ----- + +// CHECK-LABEL: func @sqrt +func.func @sqrt(%arg0: tensor<2xf32>) -> tensor<2xf32> { + // CHECK: mhlo.sqrt %arg0 : tensor<2xf32> + %0 = "tf.Sqrt"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + func.return %0 : tensor<2xf32> +} + +// ----- + +// CHECK-LABEL: func @sqrt_dynamic +func.func @sqrt_dynamic(%arg0: tensor) -> tensor { + // CHECK: mhlo.sqrt %arg0 : tensor + %0 = "tf.Sqrt"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// ----- + +// CHECK-LABEL: func @tanh +func.func @tanh(%arg0: tensor<2xf32>) -> tensor<2xf32> { + // CHECK: mhlo.tanh %arg0 : tensor<2xf32> + %0 = "tf.Tanh"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + func.return %0 : tensor<2xf32> +} + +// ----- + +// CHECK-LABEL: func @tanh_dynamic +func.func @tanh_dynamic(%arg0: tensor) -> tensor { + // CHECK: mhlo.tanh %arg0 : tensor + %0 = "tf.Tanh"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// ----- + +// CHECK-LABEL: func @bitcast +func.func @bitcast(%arg0: tensor<2xf32>) -> tensor<2xf32> { + // CHECK: mhlo.bitcast_convert %arg0 : (tensor<2xf32>) -> tensor<2xf32> + %0 = "tf.Bitcast"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + func.return %0 : tensor<2xf32> +} + +// ----- + +// CHECK-LABEL: func @bitcast_dynamic +func.func @bitcast_dynamic(%arg0: tensor) -> tensor { + // CHECK: mhlo.bitcast_convert %arg0 : (tensor) -> tensor + %0 = "tf.Bitcast"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// ----- + +// CHECK-LABEL: func @bitcast_same_widths +func.func @bitcast_same_widths(%arg0: tensor<2xf32>) -> tensor<2xi32> { + // CHECK: mhlo.bitcast_convert %arg0 : (tensor<2xf32>) -> tensor<2xi32> + %0 = "tf.Bitcast"(%arg0) : (tensor<2xf32>) -> tensor<2xi32> + func.return %0 : tensor<2xi32> +} + +// ----- + +// CHECK-LABEL: func @bitcast_smaller_input_width +func.func @bitcast_smaller_input_width(%arg0: tensor<8xi8>) -> tensor { + // CHECK: mhlo.bitcast_convert %arg0 : (tensor<8xi8>) -> tensor + %0 = "tf.Bitcast"(%arg0) : (tensor<8xi8>) -> tensor + func.return %0 : tensor +} + +// ----- + +// CHECK-LABEL: func @bitcast_smaller_output_width +func.func @bitcast_smaller_output_width(%arg0: tensor<2xf32>) -> tensor<2x2xf16> { + // CHECK: mhlo.bitcast_convert %arg0 : (tensor<2xf32>) -> tensor<2x2xf16> + %0 = "tf.Bitcast"(%arg0) : (tensor<2xf32>) -> tensor<2x2xf16> + func.return %0 : tensor<2x2xf16> +} + +// ----- + +// CHECK-LABEL: squeeze +func.func @squeeze(%arg0: tensor<1x1x10xf32>) -> tensor<1x10xf32> { + // CHECK: mhlo.reshape + %0 = "tf.Squeeze"(%arg0) : (tensor<1x1x10xf32>) -> tensor<1x10xf32> + func.return %0 : tensor<1x10xf32> +} + +// ----- + +// CHECK-LABEL: squeeze_ranked +func.func @squeeze_ranked(%arg0: tensor) -> tensor { + // CHECK: %[[C2:.*]] = arith.constant 2 : index + // CHECK: %[[D2:.*]] = tensor.dim %arg0, %[[C2]] : tensor + // CHECK: %[[T:.*]] = tensor.from_elements %[[D2]] : tensor<1xindex> + // CHECK: %[[R:.*]] = mhlo.dynamic_reshape %arg0, %[[T]] : (tensor, tensor<1xindex>) -> tensor + // CHECK: return %[[R]] : tensor + %0 = "tf.Squeeze"(%arg0) { squeeze_dims = [0, 1] }: (tensor) -> tensor + func.return %0 : tensor +} + +// ----- + +// CHECK-LABEL: squeeze_ranked_negative +func.func @squeeze_ranked_negative(%arg0: tensor) -> tensor { + // CHECK: %[[C0:.*]] = arith.constant 0 : index + // CHECK: %[[D0:.*]] = tensor.dim %arg0, %[[C0]] : tensor + // CHECK: %[[C2:.*]] = arith.constant 2 : index + // CHECK: %[[D2:.*]] = tensor.dim %arg0, %[[C2]] : tensor + // CHECK: %[[T:.*]] = tensor.from_elements %[[D0]], %[[D2]] : tensor<2xindex> + // CHECK: %[[R:.*]] = mhlo.dynamic_reshape %arg0, %[[T]] : (tensor, tensor<2xindex>) -> tensor + // CHECK: return %[[R]] : tensor + %0 = "tf.Squeeze"(%arg0) { squeeze_dims = [-2] }: (tensor) -> tensor + func.return %0 : tensor +} + +// ----- + +// CHECK-LABEL: squeeze_ranked_dynamic +func.func @squeeze_ranked_dynamic(%arg0: tensor) -> tensor { + // CHECK: "tf.Squeeze" + %0 = "tf.Squeeze"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// ----- + +// CHECK-LABEL: squeeze_dynamic +func.func @squeeze_dynamic(%arg0: tensor) -> tensor<*xf32> { + // CHECK: "tf.Squeeze" + %0 = "tf.Squeeze"(%arg0) : (tensor) -> tensor<*xf32> + func.return %0 : tensor<*xf32> +} + +// ----- + +// CHECK-LABEL: expand_dims +func.func @expand_dims(%arg0: tensor<2xf32>, %axis: tensor) -> tensor<1x2xf32> { + // CHECK: mhlo.reshape + %0 = "tf.ExpandDims"(%arg0, %axis) : (tensor<2xf32>, tensor) -> tensor<1x2xf32> + func.return %0 : tensor<1x2xf32> +} + +// ----- + +// CHECK-LABEL: expand_dims_dynamic +func.func @expand_dims_dynamic(%arg0: tensor) -> tensor { + %axis = "tf.Const"() {value = dense<1> : tensor} : () -> (tensor) + + // CHECK-DAG: %[[SHAPEOF:.+]] = shape.shape_of %arg0 + // CHECK-DAG: %[[CST0:.+]] = arith.constant 0 + // CHECK-DAG: %[[CST1:.+]] = arith.constant 1 + // CHECK-DAG: %[[GETEXTENT0:.+]] = tensor.extract %[[SHAPEOF]][%[[CST0]]] + // CHECK-DAG: %[[CST1_0:.+]] = arith.constant 1 + // CHECK-DAG: %[[GETEXTENT1:.+]] = tensor.extract %[[SHAPEOF]][%[[CST1_0]]] + // CHECK-DAG: %[[TOEXTENTS:.+]] = tensor.from_elements %[[GETEXTENT0]], %[[CST1]], %[[GETEXTENT1]] + // CHECK-DAG: %[[RESHAPE:.+]] = mhlo.dynamic_reshape %arg0, %[[TOEXTENTS]] + %0 = "tf.ExpandDims"(%arg0, %axis) : (tensor, tensor) -> tensor + + // CHECK: return %[[RESHAPE]] + func.return %0 : tensor +} + +// ----- + +// CHECK-LABEL: expand_dynamic_dims_rank1_axis +func.func @expand_dynamic_dims_rank1_axis(%arg0: tensor) -> tensor { + %axis = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> + + // CHECK-DAG: %[[SHAPEOF:.+]] = shape.shape_of %arg0 + // CHECK-DAG: %[[CST0:.+]] = arith.constant 0 + // CHECK-DAG: %[[CST1:.+]] = arith.constant 1 + // CHECK-DAG: %[[GETEXTENT0:.+]] = tensor.extract %[[SHAPEOF]][%[[CST0]]] + // CHECK-DAG: %[[CST1_0:.+]] = arith.constant 1 + // CHECK-DAG: %[[GETEXTENT1:.+]] = tensor.extract %[[SHAPEOF]][%[[CST1_0]]] + // CHECK-DAG: %[[CST2:.+]] = arith.constant 2 + // CHECK-DAG: %[[GETEXTENT2:.+]] = tensor.extract %[[SHAPEOF]][%[[CST2]]] + // CHECK-DAG: %[[TOEXTENTS:.+]] = tensor.from_elements %[[GETEXTENT0]], %[[CST1]], %[[GETEXTENT1]], %[[GETEXTENT2]] + // CHECK-DAG: %[[RESHAPE:.+]] = mhlo.dynamic_reshape %arg0, %[[TOEXTENTS]] + %0 = "tf.ExpandDims"(%arg0, %axis) : (tensor, tensor<1xi32>) -> tensor + + // CHECK: return %[[RESHAPE]] + func.return %0 : tensor +} + +// ----- + +// CHECK-LABEL: func @sign +// CHECK-SAME: [[ARG:%arg.*]]: tensor<1x2x3x4xf32> +func.func @sign(%arg0: tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> { + // CHECK: [[SIGN:%.*]] = mhlo.sign [[ARG]] + // CHECK: return [[SIGN]] : tensor<1x2x3x4xf32> + %0 = "tf.Sign"(%arg0) : (tensor<1x2x3x4xf32>) -> (tensor<1x2x3x4xf32>) + func.return %0 : tensor<1x2x3x4xf32> +} + +// ----- + +// CHECK-LABEL: func @sign_dynamic +func.func @sign_dynamic(%arg0: tensor) -> tensor { + // CHECK: [[SIGN:%.*]] = mhlo.sign %arg0 : tensor + // CHECK: return [[SIGN]] : tensor + %0 = "tf.Sign"(%arg0) : (tensor) -> (tensor) + func.return %0 : tensor +} + +// ----- + +// CHECK-LABEL: slice_constant_start +func.func @slice_constant_start(%arg0: tensor<4xi32>) -> tensor<2xi32> { + // CHECK: %[[START:.*]] = mhlo.constant dense<1> : tensor + // CHECK-DAG-SAME: {limit_indices = dense<1> : tensor<1xi64>, + // CHECK-DAG-SAME: start_indices = dense<0> : tensor<1xi64>, + // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>} : + // CHECK-DAG-SAME: (tensor<1xi64>) -> tensor<1xi64> + // CHECK-DAG-SAME: (tensor<1xi64>) -> tensor + // CHECK: %[[RESULT:.*]] = "mhlo.dynamic_slice"(%arg0, %[[START]]) + // CHECK-DAG-SAME: {slice_sizes = dense<2> : tensor<1xi64>} : + // CHECK-DAG-SAME: (tensor<4xi32>, tensor) -> tensor<2xi32> + // CHECK: return %[[RESULT]] : tensor<2xi32> + %starts = "tf.Const"() {value = dense<[1]> : tensor<1xi64>} : () -> (tensor<1xi64>) + %sizes = "tf.Const"() {value = dense<[2]> : tensor<1xi64>} : () -> (tensor<1xi64>) + %0 = "tf.Slice"(%arg0, %starts, %sizes) : (tensor<4xi32>, tensor<1xi64>, tensor<1xi64>) -> tensor<2xi32> + func.return %0 : tensor<2xi32> +} + +// ----- + +// CHECK-LABEL: slice_i32_consts +func.func @slice_i32_consts(%arg0: tensor<4xi32>) -> tensor<2xi32> { + // CHECK: %[[START:.*]] = mhlo.constant dense<1> : tensor + // CHECK: "mhlo.dynamic_slice"(%arg0, %[[START]]) <{slice_sizes = dense<2> : tensor<1xi64>}> : (tensor<4xi32>, tensor) -> tensor<2xi32> + %starts = "tf.Const"() {value = dense<[1]> : tensor<1xi32>} : () -> (tensor<1xi32>) + %sizes = "tf.Const"() {value = dense<[2]> : tensor<1xi32>} : () -> (tensor<1xi32>) + %0 = "tf.Slice"(%arg0, %starts, %sizes) : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> + func.return %0 : tensor<2xi32> +} + +// ----- + +// CHECK-LABEL: slice_constant_start_negative_one_size +func.func @slice_constant_start_negative_one_size(%arg0: tensor<4xi32>) -> tensor<3xi32> { + // CHECK: %[[START:.*]] = mhlo.constant dense<1> : tensor + // CHECK: %[[RESULT:.*]] = "mhlo.dynamic_slice"(%arg0, %[[START]]) <{slice_sizes = dense<3> : tensor<1xi64>}> : (tensor<4xi32>, tensor) -> tensor<3xi32> + // CHECK: return %[[RESULT]] : tensor<3xi32> + %starts = "tf.Const"() {value = dense<[1]> : tensor<1xi64>} : () -> (tensor<1xi64>) + %sizes = "tf.Const"() {value = dense<[-1]> : tensor<1xi64>} : () -> (tensor<1xi64>) + %0 = "tf.Slice"(%arg0, %starts, %sizes) : (tensor<4xi32>, tensor<1xi64>, tensor<1xi64>) -> tensor<3xi32> + func.return %0 : tensor<3xi32> +} + +// ----- + +// CHECK-LABEL: slice_constant_start_dynamic_shape +func.func @slice_constant_start_dynamic_shape(%arg0: tensor, %arg1: tensor<2xi64>) -> tensor<1x4xi32> { + // CHECK-DAG: %[[START1:.*]] = mhlo.constant dense<1> : tensor + // CHECK-DAG: %[[START2:.*]] = mhlo.constant dense<0> : tensor + // CHECK: %[[RESULT:.*]] = "mhlo.dynamic_slice" + // CHECK-DAG-SAME: (%arg0, %[[START1]], %[[START2]]) + // CHECK-DAG-SAME: {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : + // CHECK-DAG-SAME: (tensor, tensor, tensor) -> tensor<1x4xi32> + // CHECK: return %[[RESULT]] : tensor<1x4xi32> + %starts = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> (tensor<2xi64>) + %sizes = "tf.Const"() {value = dense<[1, 4]> : tensor<2xi64>} : () -> (tensor<2xi64>) + %0 = "tf.Slice"(%arg0, %starts, %sizes) : (tensor, tensor<2xi64>, tensor<2xi64>) -> tensor<1x4xi32> + func.return %0 : tensor<1x4xi32> +} + +// ----- + +// CHECK-LABEL: slice_variable_start +func.func @slice_variable_start(%arg0: tensor<3x4xi32>, %arg1: tensor<2xi64>) -> tensor<1x4xi32> { + // CHECK: %[[SLICED_START1:.*]] = "mhlo.slice"(%arg1) + // CHECK-DAG-SAME: {limit_indices = dense<1> : tensor<1xi64>, + // CHECK-DAG-SAME: start_indices = dense<0> : tensor<1xi64>, + // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>} : (tensor<2xi64>) -> tensor<1xi64> + // CHECK: %[[RESHAPED_START1:.*]] = mhlo.reshape %[[SLICED_START1]] : (tensor<1xi64>) -> tensor + // CHECK: %[[SLICED_START2:.*]] = "mhlo.slice"(%arg1) + // CHECK-DAG-SAME: {limit_indices = dense<2> : tensor<1xi64>, + // CHECK-DAG-SAME: start_indices = dense<1> : tensor<1xi64>, + // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>} : (tensor<2xi64>) -> tensor<1xi64> + // CHECK: %[[RESHAPED_START2:.*]] = mhlo.reshape %[[SLICED_START2]] : (tensor<1xi64>) -> tensor + // CHECK: %[[RESULT:.*]] = "mhlo.dynamic_slice"(%arg0, %[[RESHAPED_START1]], %[[RESHAPED_START2]]) <{slice_sizes = dense<[1, 4]> : tensor<2xi64>}> : (tensor<3x4xi32>, tensor, tensor) -> tensor<1x4xi32> + // CHECK: return %[[RESULT]] : tensor<1x4xi32> + %sizes = "tf.Const"() {value = dense<[1, 4]> : tensor<2xi64>} : () -> (tensor<2xi64>) + %0 = "tf.Slice"(%arg0, %arg1, %sizes) : (tensor<3x4xi32>, tensor<2xi64>, tensor<2xi64>) -> tensor<1x4xi32> + func.return %0 : tensor<1x4xi32> +} + +// ----- + +// CHECK-LABEL: slice_mhlo_sizes +func.func @slice_mhlo_sizes(%arg0: tensor<1x1024x4xf32>, %arg1: tensor<3xi32>) -> tensor<1x512x4xf32> { + // CHECK-NOT: "tf.Slice" + %0 = "mhlo.constant"() {value = dense<[1, 512, 4]> : tensor<3xi32>} : () -> tensor<3xi32> + %1 = "tf.Slice"(%arg0, %arg1, %0) : (tensor<1x1024x4xf32>, tensor<3xi32>, tensor<3xi32>) -> tensor<1x512x4xf32> + func.return %1 : tensor<1x512x4xf32> +} + +// ----- + +// CHECK-LABEL: slice_variable_start_negative_one_size +func.func @slice_variable_start_negative_one_size(%arg0: tensor<3x4xi32>, %arg1: tensor<2xi64>) -> tensor<1x4xi32> { + // CHECK: %[[RESULT:.*]] = "tf.Slice" + // CHECK: return %[[RESULT]] : tensor<1x4xi32> + %sizes = "tf.Const"() {value = dense<[1, -1]> : tensor<2xi64>} : () -> (tensor<2xi64>) + %0 = "tf.Slice"(%arg0, %arg1, %sizes) : (tensor<3x4xi32>, tensor<2xi64>, tensor<2xi64>) -> tensor<1x4xi32> + func.return %0 : tensor<1x4xi32> +} + +// ----- + +// CHECK-LABEL: slice_real_dynamic_slice +func.func @slice_real_dynamic_slice(%arg0: tensor<4xi32>, %arg1: tensor<1xi64>, %arg2: tensor<1xi64>) -> tensor { + // CHECK: tensor.extract {{.*}} : tensor<1xi64> + // CHECK: tensor.extract {{.*}} : tensor<1xi64> + // CHECK: arith.index_cast {{.*}} : index to i64 + // CHECK: arith.cmpi eq, {{.*}} : i64 + // CHECK: arith.addi {{.*}} : i64 + // CHECK: tensor.dim {{.*}} : tensor<4xi32> + // CHECK: arith.index_cast {{.*}} : index to i64 + // CHECK: select {{.*}} : i64 + // CHECK: arith.index_cast {{.*}} : i64 to index + // CHECK: arith.index_cast {{.*}} : i64 to index + // CHECK: tensor.from_elements {{.*}} : tensor<1xindex> + // CHECK: tensor.from_elements {{.*}} : tensor<1xindex> + // CHECK: tensor.from_elements {{.*}} : tensor<1xindex> + %0 = "tf.Slice"(%arg0, %arg1, %arg2) : (tensor<4xi32>, tensor<1xi64>, tensor<1xi64>) -> tensor + func.return %0 : tensor +} + +//===----------------------------------------------------------------------===// +// StridedSlice op legalizations. +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: simple_strided_slice +func.func @simple_strided_slice(%input: tensor<4x8xf32>) -> tensor<3x2xf32> { + %begin = "tf.Const"() {value = dense<[0, 1]> : tensor<2xi32>} : () -> (tensor<2xi32>) + %end = "tf.Const"() {value = dense<[3, 7]> : tensor<2xi32>} : () -> (tensor<2xi32>) + %strides = "tf.Const"() {value = dense<[1, 3]> : tensor<2xi32>} : () -> (tensor<2xi32>) + + // CHECK: mhlo.slice + // CHECK-DAG-SAME: start_indices = dense<[0, 1]> + // CHECK-DAG-SAME: limit_indices = dense<[3, 7]> + // CHECK-DAG-SAME: strides = dense<[1, 3]> + // CHECK-SAME: -> tensor<3x2xf32> + + %output = "tf.StridedSlice"(%input, %begin, %end, %strides) + : (tensor<4x8xf32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<3x2xf32> + func.return %output : tensor<3x2xf32> +} + +// ----- + +// CHECK-LABEL: dynamic_strided_slice +func.func @dynamic_strided_slice(%input: tensor) -> tensor { + %begin = "tf.Const"() {value = dense<[0, 1]> : tensor<2xi32>} : () -> (tensor<2xi32>) + %end = "tf.Const"() {value = dense<[3, 7]> : tensor<2xi32>} : () -> (tensor<2xi32>) + %strides = "tf.Const"() {value = dense<[1, 3]> : tensor<2xi32>} : () -> (tensor<2xi32>) + + // CHECK: "tf.StridedSlice" + %output = "tf.StridedSlice"(%input, %begin, %end, %strides) + : (tensor, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor + func.return %output : tensor +} + +// ----- + +// CHECK-LABEL: strided_slice_negative_indices +func.func @strided_slice_negative_indices(%input: tensor<4x8xf32>) -> tensor<3x2xf32> { + %begin = "tf.Const"() {value = dense<[-1, -2]> : tensor<2xi32>} : () -> (tensor<2xi32>) + %end = "tf.Const"() {value = dense<[-4, -8]> : tensor<2xi32>} : () -> (tensor<2xi32>) + %strides = "tf.Const"() {value = dense<[-1, -3]> : tensor<2xi32>} : () -> (tensor<2xi32>) + + // CHECK: "mhlo.reverse"(%arg0) <{dimensions = dense<[0, 1]> : tensor<2xi64>}> + + // CHECK: mhlo.slice + // CHECK-DAG-SAME: start_indices = dense<[0, 1]> + // CHECK-DAG-SAME: limit_indices = dense<[3, 7]> + // CHECK-DAG-SAME: strides = dense<[1, 3]> + // CHECK-SAME: -> tensor<3x2xf32> + + %output = "tf.StridedSlice"(%input, %begin, %end, %strides) + : (tensor<4x8xf32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<3x2xf32> + func.return %output : tensor<3x2xf32> +} + +// ----- + +// CHECK-LABEL: dynamic_strided_slice_negative_indices +func.func @dynamic_strided_slice_negative_indices(%input: tensor) -> tensor { + %begin = "tf.Const"() {value = dense<[-1, -2]> : tensor<2xi32>} : () -> (tensor<2xi32>) + %end = "tf.Const"() {value = dense<[-4, -8]> : tensor<2xi32>} : () -> (tensor<2xi32>) + %strides = "tf.Const"() {value = dense<[-1, -3]> : tensor<2xi32>} : () -> (tensor<2xi32>) + + // CHECK: tf.StridedSlice + %output = "tf.StridedSlice"(%input, %begin, %end, %strides) + : (tensor, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor + func.return %output : tensor +} + +// ----- + +// CHECK-LABEL: strided_slice_range_clamping +func.func @strided_slice_range_clamping(%input: tensor<4x8xf32>) -> tensor<1x3xf32> { + %begin = "tf.Const"() {value = dense<[-4, -10]> : tensor<2xi32>} : () -> (tensor<2xi32>) + %end = "tf.Const"() {value = dense<[1, 10]> : tensor<2xi32>} : () -> (tensor<2xi32>) + %strides = "tf.Const"() {value = dense<[1, 3]> : tensor<2xi32>} : () -> (tensor<2xi32>) + + // CHECK: mhlo.slice + // CHECK-DAG-SAME: start_indices = dense<[0, 0]> + // CHECK-DAG-SAME: limit_indices = dense<[1, 8]> + // CHECK-DAG-SAME: strides = dense<[1, 3]> + // CHECK-SAME: -> tensor<1x3xf32> + %output = "tf.StridedSlice"(%input, %begin, %end, %strides) + : (tensor<4x8xf32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<1x3xf32> + func.return %output : tensor<1x3xf32> +} + +// ----- + +// CHECK-LABEL: strided_slice_empty +func.func @strided_slice_empty(%input: tensor<4xf32>) -> tensor<0xf32> { + %begin = "tf.Const"() {value = dense<[-4]> : tensor<1xi32>} : () -> (tensor<1xi32>) + %end = "tf.Const"() {value = dense<[-1]> : tensor<1xi32>} : () -> (tensor<1xi32>) + %strides = "tf.Const"() {value = dense<[-1]> : tensor<1xi32>} : () -> (tensor<1xi32>) + + // CHECK: mhlo.constant dense<> : tensor<0xf32> + %output = "tf.StridedSlice"(%input, %begin, %end, %strides) + : (tensor<4xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xf32> + func.return %output : tensor<0xf32> +} + +// ----- + +// CHECK-LABEL: strided_slice_begin_end_mask +// CHECK-SAME: %[[INPUT:[a-z0-9]+]]: tensor<4x128x1024xf32> +func.func @strided_slice_begin_end_mask(%input: tensor<4x128x1024xf32>) { + + // For StridedSlice + // Dim #: 0, 1, 2 + // Input shape: [4, 128, 1024] + // Begin: 1, 4, -3 + // End: 8, 65, 42 + // Stride: 1, 4, -1 + // Begin mask: 0, 0, 1 (= 1) + // End mask: 1, 0, 0 (= 4) + + // So result shape: + // Dim #0: begin mask (1) -> begin = 0; end 8 canonicalized to 4: so 4 + // Dim #1: 4 to 65 stride 4: so 16 + // Dim #2: begin -3 + 1024 = 1021; end mask (1) -> end = -1: so 1022 + // result shape: [4, 16, 1022] + + %begin = "tf.Const"() {value = dense<[1, 4, -3]> : tensor<3xi32>} : () -> (tensor<3xi32>) + %end = "tf.Const"() {value = dense<[8, 65, 42]> : tensor<3xi32>} : () -> (tensor<3xi32>) + %strides = "tf.Const"() {value = dense<[1, 4, -1]> : tensor<3xi32>} : () -> (tensor<3xi32>) + + // CHECK: %[[REVERSE:.*]] = "mhlo.reverse"(%[[INPUT]]) + + // CHECK: %[[SLICE:.*]] = "mhlo.slice"(%[[REVERSE]]) + // CHECK-DAG-SAME: limit_indices = dense<[4, 65, 1024]> + // CHECK-DAG-SAME: start_indices = dense<[0, 4, 2]> + // CHECK-DAG-SAME: strides = dense<[1, 4, 1]> + // CHECK-SAME: -> tensor<4x16x1022xf32> + + %0 = "tf.StridedSlice"(%input, %begin, %end, %strides) {begin_mask = 1, end_mask = 4} : (tensor<4x128x1024xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<4x16x1022xf32> + + // CHECK: mhlo.reshape %[[SLICE]] + // CHECK-SAME: -> tensor<4x16x1022xf32> + + func.return +} + +// ----- + +// CHECK-LABEL: strided_slice_shrink_axis_mask +// CHECK-SAME: %[[INPUT:.+]]: tensor<4x128x1024xf32> +func.func @strided_slice_shrink_axis_mask(%input: tensor<4x128x1024xf32>) { + + // For StridedSlice + // Dim #: 0, 1, 2 + // Input shape: [4, 128, 1024] + // Begin: 1, 4, -3 + // End: 8, 65, 42 + // Stride: 1, 4, -1 + // Begin mask: 1, 0, 0 (= 1) + // End mask: 0, 0, 1 (= 4) + // Shrink axis mask: 1, 0, 1 (= 5) + + // So result shape: + // Dim #0: shrink axis, take value at [1] + // Dim #1: 4 to 65 stride 4: so 16 + // Dim #2: shrink axis, take value at [-3] + // result shape: [16] + + // As output shape of StridedSlice differs, a reshape will follow. + + %begin = "tf.Const"() {value = dense<[1, 4, -3]> : tensor<3xi32>} : () -> (tensor<3xi32>) + %end = "tf.Const"() {value = dense<[8, 65, 42]> : tensor<3xi32>} : () -> (tensor<3xi32>) + %strides = "tf.Const"() {value = dense<[1, 4, -1]> : tensor<3xi32>} : () -> (tensor<3xi32>) + + // CHECK: %[[SLICE:.*]] = "mhlo.slice"(%[[INPUT]]) + // CHECK-DAG-SAME: limit_indices = dense<[1, 65, 1022]> + // CHECK-DAG-SAME: start_indices = dense<[0, 4, 1021]> + // CHECK-DAG-SAME: strides = dense<[1, 4, 1]> + // CHECK-SAME: -> tensor<1x16x1xf32> + + %0 = "tf.StridedSlice"(%input, %begin, %end, %strides) {begin_mask = 1, end_mask = 4, shrink_axis_mask = 5} : (tensor<4x128x1024xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<16xf32> + + // CHECK: mhlo.reshape %[[SLICE]] + // CHECK-SAME: -> tensor<16xf32> + + func.return +} + +// ----- + +// CHECK-LABEL: strided_slice_ellipsis_mask +// CHECK-SAME: %[[INPUT:[a-z0-9]+]]: tensor<2x4x8x16x32x64xf32> +func.func @strided_slice_ellipsis_mask(%input: tensor<2x4x8x16x32x64xf32>) { + // For StridedSlice input[1, ..., 8:, :10, 2:6:2] + // The ellipsis mask is applied to dim #1, #2, i.e, we get canonicalized + // slice input[1, :, :, 8:, :10, 2:6:2] + + // The start, limit indices and strides attributes of mhlo.slice would + // reflect the canonicalized slice. + // As output shape of StridedSlice differs, a reshape will follow. + + %begin = "tf.Const"() {value = dense<[1, 0, 8, 1, 2]> : tensor<5xi32>} : () -> (tensor<5xi32>) + %end = "tf.Const"() {value = dense<[2, 0, 10, 10, 6]> : tensor<5xi32>} : () -> (tensor<5xi32>) + %strides = "tf.Const"() {value = dense<[1, 1, 1, 1, 2]> : tensor<5xi32>} : () -> (tensor<5xi32>) + + // CHECK: %[[SLICE:.*]] = "mhlo.slice"(%[[INPUT]]) + // CHECK-DAG-SAME: limit_indices = dense<[2, 4, 8, 16, 10, 6]> : tensor<6xi64> + // CHECK-DAG-SAME: start_indices = dense<[1, 0, 0, 8, 0, 2]> : tensor<6xi64> + // CHECK-DAG-SAME: strides = dense<[1, 1, 1, 1, 1, 2]> : tensoe<6xi64> + // CHECK-SAME: -> tensor<1x4x8x8x10x2xf32> + %0 = "tf.StridedSlice"(%input, %begin, %end, %strides) {begin_mask = 8, end_mask = 4, shrink_axis_mask = 1, ellipsis_mask = 2} : (tensor<2x4x8x16x32x64xf32>, tensor<5xi32>, tensor<5xi32>, tensor<5xi32>) -> tensor<4x8x8x10x2xf32> + + // CHECK: mhlo.reshape %[[SLICE]] + // CHECK-SAME: -> tensor<4x8x8x10x2xf32> + + func.return +} + +// ----- + +// CHECK-LABEL: strided_slice_new_axis_mask +// CHECK-SAME: %[[INPUT:[a-z0-9]+]]: tensor<2x4x8x16x32x64xf32> +func.func @strided_slice_new_axis_mask(%input: tensor<2x4x8x16x32x64xf32>) { + // For StridedSlice input[1, tf.new_axis, ..., 8:, :10, 2:6:2, tf.new_axis] + // New axis mask is at index 1 and 6 of sparse spec, so + // new_axis_mask = 2^1 + 2^6 = 66 + // The ellipsis mask is applied to dim #1, #2 of input i.e, we get + // canonicalized slice input[1, :, :, 8:, :10, 2:6:2] + // This is then reshaped to add the new axes. + + // The start, limit indices and strides attributes of mhlo.slice would + // reflect the canonicalized slice. + // As output shape of StridedSlice differs, a reshape will follow to reflect + // new axes added. + + %begin = "tf.Const"() {value = dense<[1, 0, 0, 8, 1, 2, 0]> : tensor<7xi32>} : () -> (tensor<7xi32>) + %end = "tf.Const"() {value = dense<[2, 0, 0, 10, 10, 6, 0]> : tensor<7xi32>} : () -> (tensor<7xi32>) + %strides = "tf.Const"() {value = dense<[1, 1, 1, 1, 1, 2, 1]> : tensor<7xi32>} : () -> (tensor<7xi32>) + + // CHECK: %[[SLICE:.*]] = "mhlo.slice"(%[[INPUT]]) + // CHECK-DAG-SAME: limit_indices = dense<[2, 4, 8, 16, 10, 6]> : tensor<6xi64> + // CHECK-DAG-SAME: start_indices = dense<[1, 0, 0, 8, 0, 2]> : tensor<6xi64> + // CHECK-DAG-SAME: strides = dense<[1, 1, 1, 1, 1, 2]> : tensoe<6xi64> + // CHECK-SAME: -> tensor<1x4x8x8x10x2xf32> + %0 = "tf.StridedSlice"(%input, %begin, %end, %strides) {begin_mask = 16, end_mask = 8, shrink_axis_mask = 1, ellipsis_mask = 4, new_axis_mask = 66} : (tensor<2x4x8x16x32x64xf32>, tensor<7xi32>, tensor<7xi32>, tensor<7xi32>) -> tensor<1x4x8x8x10x2x1xf32> + + // CHECK: mhlo.reshape %[[SLICE]] + // CHECK-SAME: -> tensor<1x4x8x8x10x2x1xf32> + + func.return +} + +// ----- + +// CHECK-LABEL: strided_slice_implicit_ellipsis_mask( +// CHECK-SAME: [[INPUT:%.*]]: tensor<10x16x2xf32> +func.func @strided_slice_implicit_ellipsis_mask(%input: tensor<10x16x2xf32>) -> tensor<2x16x2xf32> { + // StridedSlice gets input[8:10], which is same as input[8:10, ...] + // The start_indices, limit_indices, and strides attribute of mhlo.slice + // reflect the canonicalized slice. + %begin = "tf.Const"() {value = dense<8> : tensor<1xi32>} : () -> tensor<1xi32> + %end = "tf.Const"() {value = dense<10> : tensor<1xi32>} : () -> tensor<1xi32> + %strides = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> + // CHECK: [[SLICE:%.*]] = "mhlo.slice"([[INPUT]]) + // CHECK-DAG-SAME: limit_indices = dense<[10, 16, 2]> : tensor<3xi64> + // CHECK-DAG-SAME: start_indices = dense<[8, 0, 0]> : tensor<3xi64> + // CHECK-DAG-SAME: strides = dense<1> : tensor<3xi64> + // CHECK: [[RESHAPE:%.*]] = mhlo.reshape [[SLICE]] : (tensor<2x16x2xf32>) -> tensor<2x16x2xf32> + %0 = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = f32} : (tensor<10x16x2xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2x16x2xf32> + // CHECK: return [[RESHAPE]] : tensor<2x16x2xf32> + func.return %0 : tensor<2x16x2xf32> +} + +// ----- + +// CHECK-LABEL: strided_slice_nonconstant_begin_end +func.func @strided_slice_nonconstant_begin_end(%arg0: tensor, %arg1: tensor<32x1x97xi32>) -> (tensor<1x97xi32>) { + // In this case, the `begin` and `end` inputs are unknown at compile time -- + // so the StridedSlice needs to slice these vectors and use that as input to + // an HLO dynamic slice. + %begin = "tf.Pack"(%arg0) {N = 1 : i64, T = i32, axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> + %0 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %1 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> + %2 = "tf.AddV2"(%arg0, %0) {T = i32, device = ""} : (tensor, tensor) -> tensor + %end = "tf.Pack"(%2) {N = 1 : i64, T = i32, axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> + // CHECK: %[[A:.*]] = mhlo.reshape %arg0 : (tensor) -> tensor<1xi32> + // CHECK-NEXT: %[[BEGIN:.*]] = "mhlo.concatenate"(%[[A]]) + // CHECK-DAG-SAME: {dimension = 0 : i64} : (tensor<1xi32>) -> tensor<1xi32> + // CHECK: %[[ZERO:.*]] = mhlo.constant dense<0> : tensor + // CHECK-NEXT: %[[INDEX:.*]] = "mhlo.slice"(%[[BEGIN]]) + // CHECK-DAG-SAME: {limit_indices = dense<1> : tensor<1xi64>, + // CHECK-DAG-SAME: start_indices = dense<0> : tensor<1xi64>, + // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>} : (tensor<1xi32>) -> tensor<1xi32> + // CHECK-NEXT: %[[INDEX2:.*]] = mhlo.reshape %[[INDEX]] : (tensor<1xi32>) -> tensor + // CHECK-NEXT: %[[CMP:.*]] = chlo.broadcast_compare %[[INDEX2]], %[[ZERO]] + // CHECK-DAG-SAME: {comparison_direction = #mhlo} : (tensor, tensor) -> tensor + // CHECK-NEXT: %[[DIM:.*]] = mhlo.constant dense<32> : tensor + // CHECK-NEXT: %[[WRAP:.*]] = chlo.broadcast_add %[[INDEX2]], %[[DIM]] : (tensor, tensor) -> tensor + // CHECK-NEXT: %[[INDEX3:.*]] = mhlo.select %[[CMP]], %[[WRAP]], %[[INDEX2]] : + // CHECK-DAG-SAME: (tensor, tensor, tensor) -> tensor + // CHECK-NEXT: %[[SLICED:.*]] = "mhlo.dynamic_slice" + // CHECK-DAG-SAME: (%arg1, %[[INDEX3]], %[[ZERO]], %[[ZERO]]) + // CHECK-DAG-SAME: {slice_sizes = dense<[1, 1, 97]> : tensor<3xi64>} : + // CHECK-DAG-SAME: (tensor<32x1x97xi32>, tensor, tensor, tensor) -> tensor<1x1x97xi32> + // CHECK-NEXT: %[[FINAL:.*]] = mhlo.reshape %[[SLICED]] : (tensor<1x1x97xi32>) -> tensor<1x97xi32> + %result = "tf.StridedSlice"(%arg1, %begin, %end, %1) {Index = i32, T = i32, begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32> + // CHECK-NEXT: return %[[FINAL]] : tensor<1x97xi32> + func.return %result : tensor<1x97xi32> +} + +// ----- + +// CHECK-LABEL: strided_slice_nonconstant_begin_end_with_start_end_mask +// CHECK-SAME: (%[[INPUT:.*]]: tensor<32x1x97xi32>, %[[BEGIN:.*]]: tensor<3xi32>, %[[END:.*]]: tensor<3xi32>) +func.func @strided_slice_nonconstant_begin_end_with_start_end_mask(%input: tensor<32x1x97xi32>, %begin: tensor<3xi32>, %end: tensor<3xi32>) -> (tensor<1x97xi32>) { + %strides = "tf.Const"() {value = dense<1> : tensor<3xi32>} : () -> tensor<3xi32> + + // CHECK: %[[ZERO:.*]] = mhlo.constant dense<0> : tensor + // CHECK: %[[INDEX:.*]] = "mhlo.slice"(%[[BEGIN]]) + // CHECK-DAG-SAME: start_indices = dense<0> : tensor<1xi64> + // CHECK-DAG-SAME: limit_indices = dense<1> : tensor<1xi64> + // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64> + // CHECK-NEXT: %[[INDEX2:.*]] = mhlo.reshape %[[INDEX]] : (tensor<1xi32>) -> tensor + // CHECK-NEXT: %[[CMP:.*]] = chlo.broadcast_compare %[[INDEX2]], %[[ZERO]] + // CHECK-DAG-SAME: {comparison_direction = #mhlo} : (tensor, tensor) -> tensor + // CHECK-NEXT: %[[DIM:.*]] = mhlo.constant dense<32> : tensor + // CHECK-NEXT: %[[WRAP:.*]] = chlo.broadcast_add %[[INDEX2]], %[[DIM]] : (tensor, tensor) -> tensor + // CHECK-NEXT: %[[INDEX3:.*]] = mhlo.select %[[CMP]], %[[WRAP]], %[[INDEX2]] : + // CHECK-DAG-SAME: (tensor, tensor, tensor) -> tensor + // CHECK-NEXT: %[[SLICED:.*]] = "mhlo.dynamic_slice" + // CHECK-DAG-SAME: (%arg1, %[[INDEX3]], %[[ZERO]], %[[ZERO]]) + // CHECK-DAG-SAME: {slice_sizes = dense<[1, 1, 97]> : tensor<3xi64>} : + // CHECK-DAG-SAME: (tensor<32x1x97xi32>, tensor, tensor, tensor) -> tensor<1x1x97xi32> + // CHECK-NEXT: %[[FINAL:.*]] = mhlo.reshape %[[SLICED]] : (tensor<1x1x97xi32>) -> tensor<1x97xi32> + %result = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = i32, begin_mask = 6 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 6 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<32x1x97xi32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<1x97xi32> + func.return %result : tensor<1x97xi32> +} + +// ----- + +// CHECK-LABEL: strided_slice_nonconstant_begin_end_stride_1 +func.func @strided_slice_nonconstant_begin_end_stride_1(%input: tensor<32x1x97xi32>, %begin: tensor<1xi32>, %end: tensor<1xi32>, %strides: tensor<1xi32>) -> (tensor<1x97xi32>) { + // Dynamic stride: when `begin` and `end` inputs are unknown at compile time, + // `strides` must be known. + // CHECK: tf.StridedSlice + %result = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = i32, begin_mask = 4 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32> + func.return %result : tensor<1x97xi32> +} + +// ----- + +// CHECK-LABEL: strided_slice_nonconstant_begin_end_stride_2 +func.func @strided_slice_nonconstant_begin_end_stride_2(%input: tensor<32x1x97xi32>, %begin: tensor<1xi32>, %end: tensor<1xi32>) -> (tensor<1x97xi32>) { + // Invalid stride (not equal to 1): when `begin` and `end` inputs are unknown + // at compile time, `strides` must be known to have all 1 values. + %strides = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32> + // CHECK: tf.StridedSlice + %result = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = i32, begin_mask = 4 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32> + func.return %result : tensor<1x97xi32> +} + +// ----- + +// CHECK-LABEL: strided_slice_nonconstant_begin_end_invalid_elem_count +func.func @strided_slice_nonconstant_begin_end_invalid_elem_count(%input: tensor<4x8xf32>, %begin: tensor<2xi64>, %end: tensor<2xi64>) -> tensor<6x10xf32> { + %strides = "tf.Const"() { value = dense<[1, 1]> : tensor<2xi64> } : () -> tensor<2xi64> + // When begin/end are dynamic, the number of output elements must be equal to + // the number of input elements sliced. + // CHECK: tf.StridedSlice + %0 = "tf.StridedSlice"(%input, %begin, %end, %strides) : (tensor<4x8xf32>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<6x10xf32> + func.return %0 : tensor<6x10xf32> +} + +// ----- + +// CHECK-LABEL: strided_slice_nonconstant_begin_end_and_ellipsis_mask +func.func @strided_slice_nonconstant_begin_end_and_ellipsis_mask(%input: tensor<32x1x97xi32>, %begin: tensor<1xi32>, %end: tensor<1xi32>) -> (tensor<1x97xi32>) { + // This ellipsis mask is not supported because it does not refer to the last + // dimension. + // [0, 1, 0] = 2 + %strides = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> + // CHECK: tf.StridedSlice + %result = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = i32, begin_mask = 0 : i64, device = "", ellipsis_mask = 2 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32> + func.return %result : tensor<1x97xi32> +} + +// ----- + +// CHECK-LABEL: strided_slice_nonconstant_begin_end_and_valid_ellipsis_mask +func.func @strided_slice_nonconstant_begin_end_and_valid_ellipsis_mask(%input: tensor<32x1x97xi32>, %begin: tensor<1xi32>, %end: tensor<1xi32>) -> (tensor<1x97xi32>) { + // This ellipsis mask is supported because it refers to the last dimension. + // [1, 0, 0] = 4 + %strides = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> + // CHECK: mhlo.dynamic_slice + %result = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = i32, begin_mask = 0 : i64, device = "", ellipsis_mask = 4 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32> + func.return %result : tensor<1x97xi32> +} + +// ----- + +// CHECK-LABEL: strided_slice_nonconstant_begin_end_and_valid_shrink_axis_mask +func.func @strided_slice_nonconstant_begin_end_and_valid_shrink_axis_mask(%input: tensor<32x1x97xi32>, %begin: tensor<1xi32>, %end: tensor<1xi32>) -> (tensor<1x97xi32>) { + // This shrink_axis mask is supported because it refers to a major dimension. + // [1, 1, 1] = 7 + %strides = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> + // CHECK: mhlo.dynamic_slice + %result = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = i32, begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 7 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32> + func.return %result : tensor<1x97xi32> +} + +//===----------------------------------------------------------------------===// +// Reduction op legalizations. +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: func @mean +func.func @mean(%arg0: tensor<4x8xf16>) -> tensor<4x1xf16> { + // CHECK: %[[CAST:.*]] = mhlo.convert %arg0 : (tensor<4x8xf16>) -> tensor<4x8xf32> + // CHECK: %[[INITIAL:.*]] = mhlo.constant dense<-0.000000e+00> : tensor + // CHECK: %[[REDUCED:.*]] = mhlo.reduce(%[[CAST]] init: %[[INITIAL]]) applies mhlo.add across dimensions = [1] : (tensor<4x8xf32>, tensor) -> tensor<4xf32> + // CHECK: %[[MEAN:.*]] = chlo.broadcast_divide %[[REDUCED]], %{{.*}} {broadcast_dimensions = array} : (tensor<4xf32>, tensor) -> tensor<4xf32> + // CHECK: %[[CAST_BACK:.*]] = mhlo.convert %[[MEAN]] : (tensor<4xf32>) -> tensor<4xf16> + // CHECK: %[[RESULT:.*]] = mhlo.reshape %[[CAST_BACK]] : (tensor<4xf16>) -> tensor<4x1xf16> + // CHECK: return %[[RESULT]] : tensor<4x1xf16> + %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64> + %0 = "tf.Mean"(%arg0, %dimension) { keep_dims = true }: (tensor<4x8xf16>, tensor<1xi64>) -> tensor<4x1xf16> + func.return %0 : tensor<4x1xf16> +} + +// ----- + +// CHECK-LABEL: func @mean_scalar_dim +func.func @mean_scalar_dim(%arg0: tensor<4x8xf16>) -> tensor<4x1xf16> { + // Verify that tf.Mean op with scalar attributes are lowered successfully. + + // CHECK-NOT: tf.Mean + %dimension = "tf.Const"() { value = dense<1> : tensor } : () -> tensor + %0 = "tf.Mean"(%arg0, %dimension) { keep_dims = true }: (tensor<4x8xf16>, tensor) -> tensor<4x1xf16> + func.return %0 : tensor<4x1xf16> +} + +// ----- + +// CHECK-LABEL: func @mean_dynamic +func.func @mean_dynamic(%arg0: tensor) -> tensor { + // CHECK: %[[CAST:.*]] = mhlo.convert %arg0 : (tensor) -> tensor + // CHECK: %[[INITIAL:.*]] = mhlo.constant dense<-0.000000e+00> : tensor + // CHECK: %[[REDUCED:.*]] = mhlo.reduce(%[[CAST]] init: %[[INITIAL]]) applies mhlo.add across dimensions = [1] : (tensor, tensor) -> tensor + // CHECK: %[[SHAPE0:.*]] = shape.shape_of %arg0 : tensor -> tensor<2xindex> + // CHECK-DAG: %[[C1_1:.*]] = arith.constant 1 : index + // CHECK-DAG: %[[C1_2:.*]] = arith.constant 1 : index + // CHECK: %[[REDUCED_DIM:.*]] = tensor.extract %[[SHAPE0]][%[[C1_2]]] : tensor<2xindex> + // CHECK: %[[MUL:.*]] = arith.muli %[[C1_1]], %[[REDUCED_DIM]] : index + // CHECK: %[[INDEX_CAST:.*]] = arith.index_cast %[[MUL]] : index to i64 + // CHECK: %[[TENSOR:.*]] = tensor.from_elements %[[INDEX_CAST]] : tensor + // CHECK: %[[CONVERT:.*]] = mhlo.convert %[[TENSOR]] : (tensor) -> tensor + // CHECK: %[[MEAN:.*]] = chlo.broadcast_divide %[[REDUCED]], %[[CONVERT]] {broadcast_dimensions = array} : (tensor, tensor) -> tensor + // CHECK: %[[MEAN_CONVERTED:.*]] = mhlo.convert %[[MEAN]] : (tensor) -> tensor + // CHECK: %[[SHAPE1:.*]] = shape.shape_of %[[MEAN_CONVERTED]] : tensor -> tensor<1xindex> + // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index + // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index + // CHECK: %[[UNREDUCED_DIM:.*]] = tensor.extract %[[SHAPE1]][%[[C0]]] : tensor<1xindex> + // CHECK: %[[RESULT_SHAPE:.*]] = tensor.from_elements %[[UNREDUCED_DIM]], %[[C1]] : tensor<2xindex> + // CHECK: %[[RESULT:.*]] = mhlo.dynamic_reshape %[[MEAN_CONVERTED]], %[[RESULT_SHAPE]] : (tensor, tensor<2xindex>) -> tensor + // CHECK: return %[[RESULT]] : tensor + %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64> + %0 = "tf.Mean"(%arg0, %dimension) { keep_dims = true }: (tensor, tensor<1xi64>) -> tensor + func.return %0 : tensor +} + +// ----- + +// CHECK-LABEL: func @sum +func.func @sum(%arg0: tensor<4x8xf16>) -> tensor<4x1xf16> { + // CHECK: %[[CAST:.*]] = mhlo.convert %arg0 : (tensor<4x8xf16>) -> tensor<4x8xf32> + // CHECK: %[[INITIAL:.*]] = mhlo.constant dense<-0.000000e+00> : tensor + // CHECK: %[[REDUCED:.*]] = mhlo.reduce(%[[CAST]] init: %[[INITIAL]]) applies mhlo.add across dimensions = [1] : (tensor<4x8xf32>, tensor) -> tensor<4xf32> + // CHECK: %[[CAST_BACK:.*]] = mhlo.convert %[[REDUCED]] : (tensor<4xf32>) -> tensor<4xf16> + // CHECK: %[[RESULT:.*]] = mhlo.reshape %[[CAST_BACK]] : (tensor<4xf16>) -> tensor<4x1xf16> + // CHECK: return %[[RESULT]] : tensor<4x1xf16> + %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64> + %0 = "tf.Sum"(%arg0, %dimension) { keep_dims = true }: (tensor<4x8xf16>, tensor<1xi64>) -> tensor<4x1xf16> + func.return %0 : tensor<4x1xf16> +} + +// ----- + +// CHECK-LABEL: func @sum_dynamic +func.func @sum_dynamic(%arg0: tensor<4x?xf16>) -> tensor<4x1xf16> { + // CHECK: %[[CAST:.*]] = mhlo.convert %arg0 : (tensor<4x?xf16>) -> tensor<4x?xf32> + // CHECK: %[[INITIAL:.*]] = mhlo.constant dense<-0.000000e+00> : tensor + // CHECK: %[[REDUCED:.*]] = mhlo.reduce(%[[CAST]] init: %[[INITIAL]]) applies mhlo.add across dimensions = [1] : (tensor<4x?xf32>, tensor) -> tensor<4xf32> + // CHECK: %[[CAST_BACK:.*]] = mhlo.convert %[[REDUCED]] : (tensor<4xf32>) -> tensor<4xf16> + // CHECK: %[[RESULT:.*]] = mhlo.reshape %[[CAST_BACK]] : (tensor<4xf16>) -> tensor<4x1xf16> + // CHECK: return %[[RESULT]] : tensor<4x1xf16> + %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64> + %0 = "tf.Sum"(%arg0, %dimension) { keep_dims = true }: (tensor<4x?xf16>, tensor<1xi64>) -> tensor<4x1xf16> + func.return %0 : tensor<4x1xf16> +} + +// ----- + +// CHECK-LABEL: func @max +func.func @max(%arg0: tensor<4x8xf16>) -> tensor<4x1xf16> { + // CHECK: %[[CAST:.*]] = mhlo.convert %arg0 : tensor<4x8xf16> + // CHECK: %[[INITIAL:.*]] = mhlo.constant dense<0xFC00> : tensor + // CHECK: %[[REDUCED:.*]] = mhlo.reduce(%[[CAST]] init: %[[INITIAL]]) applies mhlo.maximum across dimensions = [1] : (tensor<4x8xf16>, tensor) -> tensor<4xf16> + // CHECK: %[[CAST_BACK:.*]] = mhlo.convert %[[REDUCED]] : tensor<4xf16> + // CHECK: %[[RESULT:.*]] = mhlo.reshape %[[CAST_BACK]] : (tensor<4xf16>) -> tensor<4x1xf16> + // CHECK: return %[[RESULT]] : tensor<4x1xf16> + %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64> + %0 = "tf.Max"(%arg0, %dimension) { keep_dims = true }: (tensor<4x8xf16>, tensor<1xi64>) -> tensor<4x1xf16> + func.return %0 : tensor<4x1xf16> +} + +// ----- + +// CHECK-LABEL: func @max_qint +// Regression test to ensure we don't crash getting the initial value for +// tf.Max when using quantized integer types. +func.func @max_qint(%arg0: tensor<4x8x!tf_type.qint8>) -> tensor<4x1x!tf_type.qint8> { + %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64> + %0 = "tf.Max"(%arg0, %dimension) { keep_dims = true }: (tensor<4x8x!tf_type.qint8>, tensor<1xi64>) -> tensor<4x1x!tf_type.qint8> + func.return %0 : tensor<4x1x!tf_type.qint8> +} + +// ----- + +// CHECK-LABEL: func @max_dynamic +func.func @max_dynamic(%arg0: tensor<4x?xf16>) -> tensor<4x1xf16> { + // CHECK: %[[CAST:.*]] = mhlo.convert %arg0 : tensor<4x?xf16> + // CHECK: %[[INITIAL:.*]] = mhlo.constant dense<0xFC00> : tensor + // CHECK: %[[REDUCED:.*]] = mhlo.reduce(%[[CAST]] init: %[[INITIAL]]) applies mhlo.maximum across dimensions = [1] : (tensor<4x?xf16>, tensor) -> tensor<4xf16> + // CHECK: %[[CAST_BACK:.*]] = mhlo.convert %[[REDUCED]] : tensor<4xf16> + // CHECK: %[[RESULT:.*]] = mhlo.reshape %[[CAST_BACK]] : (tensor<4xf16>) -> tensor<4x1xf16> + // CHECK: return %[[RESULT]] : tensor<4x1xf16> + %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64> + %0 = "tf.Max"(%arg0, %dimension) { keep_dims = true }: (tensor<4x?xf16>, tensor<1xi64>) -> tensor<4x1xf16> + func.return %0 : tensor<4x1xf16> +} + +// ----- + +// CHECK-LABEL: func @min +func.func @min(%arg0: tensor<4x8xf16>) -> tensor<4x1xf16> { + // CHECK: %[[CAST:.*]] = mhlo.convert %arg0 : tensor<4x8xf16> + // CHECK: %[[INITIAL:.*]] = mhlo.constant dense<0x7C00> : tensor + // CHECK: %[[REDUCED:.*]] = mhlo.reduce(%[[CAST]] init: %[[INITIAL]]) applies mhlo.minimum across dimensions = [1] : (tensor<4x8xf16>, tensor) -> tensor<4xf16> + // CHECK: %[[CAST_BACK:.*]] = mhlo.convert %[[REDUCED]] : tensor<4xf16> + // CHECK: %[[RESULT:.*]] = mhlo.reshape %[[CAST_BACK]] : (tensor<4xf16>) -> tensor<4x1xf16> + // CHECK: return %[[RESULT]] : tensor<4x1xf16> + %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64> + %0 = "tf.Min"(%arg0, %dimension) { keep_dims = true }: (tensor<4x8xf16>, tensor<1xi64>) -> tensor<4x1xf16> + func.return %0 : tensor<4x1xf16> +} + +// ----- + +// CHECK-LABEL: func @min_qint +// Regression test to ensure we don't crash getting the initial value for +// tf.Min when using quantized integer types. +func.func @min_qint(%arg0: tensor<4x8x!tf_type.qint8>) -> tensor<4x1x!tf_type.qint8> { + %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64> + %0 = "tf.Min"(%arg0, %dimension) { keep_dims = true }: (tensor<4x8x!tf_type.qint8>, tensor<1xi64>) -> tensor<4x1x!tf_type.qint8> + func.return %0 : tensor<4x1x!tf_type.qint8> +} + +// ----- + +// CHECK-LABEL: func @prod +func.func @prod(%arg0: tensor<4x8xf16>) -> tensor<4x1xf16> { + // CHECK: %[[CAST:.*]] = mhlo.convert %arg0 : (tensor<4x8xf16>) -> tensor<4x8xf32> + // CHECK: %[[INITIAL:.*]] = mhlo.constant dense<1.000000e+00> : tensor + // CHECK: %[[REDUCED:.*]] = mhlo.reduce(%[[CAST]] init: %[[INITIAL]]) applies mhlo.multiply across dimensions = [1] : (tensor<4x8xf32>, tensor) -> tensor<4xf32> + // CHECK: %[[CAST_BACK:.*]] = mhlo.convert %[[REDUCED]] : (tensor<4xf32>) -> tensor<4xf16> + // CHECK: %[[RESULT:.*]] = mhlo.reshape %[[CAST_BACK]] : (tensor<4xf16>) -> tensor<4x1xf16> + // CHECK: return %[[RESULT]] : tensor<4x1xf16> + %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64> + %0 = "tf.Prod"(%arg0, %dimension) { keep_dims = true }: (tensor<4x8xf16>, tensor<1xi64>) -> tensor<4x1xf16> + func.return %0 : tensor<4x1xf16> +} + +// ----- + +// CHECK-LABEL: func @prod_qint +// Regression test to ensure we don't crash getting the initial value for +// tf.Prod when using quantized integer types. +func.func @prod_qint(%arg0: tensor<4x8x!tf_type.qint8>) -> tensor<4x1x!tf_type.qint8> { + %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64> + %0 = "tf.Prod"(%arg0, %dimension) { keep_dims = true }: (tensor<4x8x!tf_type.qint8>, tensor<1xi64>) -> tensor<4x1x!tf_type.qint8> + func.return %0 : tensor<4x1x!tf_type.qint8> +} + +// ----- + +// CHECK-LABEL: @all +func.func @all(%input: tensor<4x8xi1>) -> tensor<4xi1> { + %dims = "tf.Const"() { value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> + // CHECK: %[[INIT:.*]] = mhlo.constant dense : tensor + // CHECK: %[[REDUCED:.*]] = mhlo.reduce(%{{.*}} init: %[[INIT]]) applies mhlo.and across dimensions = [1] : (tensor<4x8xi1>, tensor) -> tensor<4xi1> + %0 = "tf.All"(%input, %dims) : (tensor<4x8xi1>, tensor<1xi32>) -> tensor<4xi1> + func.return %0 : tensor<4xi1> +} + +// ----- + +// CHECK-LABEL: @all_keep_dim +func.func @all_keep_dim(%input: tensor<4x8xi1>) -> tensor<4x1xi1> { + // CHECK: mhlo.reshape %{{.*}} : (tensor<4xi1>) -> tensor<4x1xi1> + %dims = "tf.Const"() { value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> + %0 = "tf.All"(%input, %dims) {keep_dims = true} : (tensor<4x8xi1>, tensor<1xi32>) -> tensor<4x1xi1> + func.return %0 : tensor<4x1xi1> +} + +// ----- + +// CHECK-LABEL: @all_dynamic +func.func @all_dynamic(%input: tensor<4x?xi1>) -> tensor<4x1xi1> { + %dims = "tf.Const"() { value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> + // CHECK: %[[ARG:.*]] = mhlo.convert %{{.*}} : tensor<4x?xi1> + // CHECK: mhlo.reduce(%[[ARG]] + %0 = "tf.All"(%input, %dims) {keep_dims = true} : (tensor<4x?xi1>, tensor<1xi32>) -> tensor<4x1xi1> + func.return %0 : tensor<4x1xi1> +} + +// ----- + +// CHECK-LABEL: @any +func.func @any(%input: tensor<4x8xi1>) -> tensor<4xi1> { + %dims = "tf.Const"() { value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> + // CHECK: %[[INIT:.*]] = mhlo.constant dense : tensor + // CHECK: mhlo.reduce(%{{.*}} init: %[[INIT]]) applies mhlo.or across dimensions = [1] : (tensor<4x8xi1>, tensor) -> tensor<4xi1> + %0 = "tf.Any"(%input, %dims) : (tensor<4x8xi1>, tensor<1xi32>) -> tensor<4xi1> + func.return %0 : tensor<4xi1> +} + +// ----- + +// CHECK-LABEL: @any_keep_dim +func.func @any_keep_dim(%input: tensor<4x8xi1>) -> tensor<4x1xi1> { + // CHECK: mhlo.reshape %{{.*}} : (tensor<4xi1>) -> tensor<4x1xi1> + %dims = "tf.Const"() { value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> + %0 = "tf.Any"(%input, %dims) {keep_dims = true} : (tensor<4x8xi1>, tensor<1xi32>) -> tensor<4x1xi1> + func.return %0 : tensor<4x1xi1> +} + +// ----- + +// CHECK-LABEL: @any_dynamic +func.func @any_dynamic(%input: tensor<4x?xi1>) -> tensor<4x1xi1> { + %dims = "tf.Const"() { value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> + // CHECK: %[[ARG:.*]] = mhlo.convert %{{.*}} : tensor<4x?xi1> + // CHECK: mhlo.reduce(%[[ARG]] + %0 = "tf.Any"(%input, %dims) {keep_dims = true} : (tensor<4x?xi1>, tensor<1xi32>) -> tensor<4x1xi1> + func.return %0 : tensor<4x1xi1> +} + +//===----------------------------------------------------------------------===// +// Tile op legalizations. +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: func @tile_by_reshape +func.func @tile_by_reshape(%arg0: tensor<4x8xf32>) -> tensor<28x24xf32> { + // CHECK: %[[BROADCASTED:.*]] = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<[1, 3]> : tensor<2xi64>}> : (tensor<4x8xf32>) -> tensor<7x4x3x8xf32> + // CHECK: %[[RESULT:.*]] = mhlo.reshape %[[BROADCASTED]] : (tensor<7x4x3x8xf32>) -> tensor<28x24xf32> + // CHECK: return %[[RESULT]] : tensor<28x24xf32> + %multiples = "tf.Const"() { value = dense<[7,3]> : tensor<2xi64> } : () -> tensor<2xi64> + %0 = "tf.Tile"(%arg0, %multiples) : (tensor<4x8xf32>, tensor<2xi64>) -> tensor<28x24xf32> + func.return %0 : tensor<28x24xf32> +} + +// ----- + +// CHECK-LABEL: func @tile_just_broadcast +func.func @tile_just_broadcast(%arg0: tensor<1x1xf32>) -> tensor<7x3xf32> { + // CHECK: %[[RESULT:.*]] = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}> : (tensor<1x1xf32>) -> tensor<7x3xf32> + // CHECK: return %[[RESULT]] : tensor<7x3xf32> + %multiples = "tf.Const"() { value = dense<[7,3]> : tensor<2xi64> } : () -> tensor<2xi64> + %0 = "tf.Tile"(%arg0, %multiples) : (tensor<1x1xf32>, tensor<2xi64>) -> tensor<7x3xf32> + func.return %0 : tensor<7x3xf32> +} + +// ----- + +// CHECK-LABEL: func @tile_dynamic_shape +func.func @tile_dynamic_shape(%arg0: tensor) -> tensor { + %multiples = "tf.Const"() { value = dense<[7,3]> : tensor<2xi32> } : () -> tensor<2xi32> + // CHECK: tensor.dim {{.*}} : tensor + // CHECK: tensor.from_elements {{.*}} : tensor<4xindex> + // CHECK: "mhlo.dynamic_broadcast_in_dim"({{.*}}) <{broadcast_dimensions = dense<[1, 3]> : tensor<2xi64>}> : (tensor, tensor<4xindex>) -> tensor + // CHECK: muli {{.*}} : index + // CHECK: tensor.from_elements {{.*}} : tensor<2xindex> + // CHECK: mhlo.dynamic_reshape {{.*}} : (tensor, tensor<2xindex>) -> tensor + %0 = "tf.Tile"(%arg0, %multiples) : (tensor, tensor<2xi32>) -> tensor + func.return %0 : tensor +} + +//===----------------------------------------------------------------------===// +// ArgMax/ArgMin op legalizations. +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: func @argmax_i64_input_i32_output_axis_0 +func.func @argmax_i64_input_i32_output_axis_0(%arg0: tensor<3x7xi64>) -> tensor<7xi32> { + // CHECK: %[[INIT:.*]] = mhlo.constant dense<-9223372036854775808> : tensor + // CHECK-NEXT: %[[INDEX_INIT:.*]] = mhlo.constant dense<0> : tensor + // CHECK: %[[SHAPE:.*]] = shape.shape_of %arg0 : tensor<3x7xi64> -> tensor<2xindex> + // CHECK: %[[INDEX:.*]] = "mhlo.dynamic_iota"(%[[SHAPE]]) <{iota_dimension = 0 : i64}> : (tensor<2xindex>) -> tensor<3x7xi32> + // CHECK: %[[REDUCE:.*]]:2 = mhlo.reduce(%arg0 init: %[[INIT]]), (%[[INDEX]] init: %[[INDEX_INIT]]) + // CHECK: (%[[ARG1:.*]]: tensor, %[[ARG3:.*]]: tensor) (%[[ARG2:.*]]: tensor, %[[ARG4:.*]]: tensor) + // CHECK: %[[COMPARE:.*]] = mhlo.compare GE, %[[ARG1]], %[[ARG3]], NOTYPE : (tensor, tensor) -> tensor + // CHECK: %[[RESULT1:.*]] = mhlo.select %[[COMPARE]], %[[ARG1]], %[[ARG3]] : tensor, tensor + // CHECK: %[[COMPARE_EQ:.*]] = mhlo.compare EQ, %[[ARG1]], %[[ARG3]], NOTYPE : (tensor, tensor) -> tensor + // CHECK: %[[MIN:.*]] = mhlo.minimum %[[ARG2]], %[[ARG4]] + // CHECK: %[[RESULT2:.*]] = mhlo.select %[[COMPARE]], %[[ARG2]], %[[ARG4]] : tensor, tensor + // CHECK: %[[RESULT3:.*]] = mhlo.select %[[COMPARE_EQ]], %[[MIN]], %[[RESULT2]] : tensor, tensor + // CHECK: mhlo.return %[[RESULT1]], %[[RESULT3]] : tensor, tensor + // CHECK: return %[[REDUCE]]#1 : tensor<7xi32> + %axis = "tf.Const"() { value = dense<0> : tensor } : () -> tensor + %0 = "tf.ArgMax"(%arg0, %axis) : (tensor<3x7xi64>, tensor) -> tensor<7xi32> + func.return %0 : tensor<7xi32> +} + +// ----- + +// CHECK-LABEL: func @argmax_f32_input_i64_output_axis_1 +func.func @argmax_f32_input_i64_output_axis_1(%arg0: tensor<3x7xf32>) -> tensor<3xi64> { + // CHECK: %[[INIT:.*]] = mhlo.constant dense<0xFF800000> : tensor + // CHECK-NEXT: %[[INDEX_INIT:.*]] = mhlo.constant dense<0> : tensor + // CHECK: %[[SHAPE:.*]] = shape.shape_of %arg0 : tensor<3x7xf32> -> tensor<2xindex> + // CHECK: %[[INDEX:.*]] = "mhlo.dynamic_iota"(%[[SHAPE]]) <{iota_dimension = 1 : i64}> : (tensor<2xindex>) -> tensor<3x7xi64> + // CHECK: %[[REDUCE:.*]]:2 = mhlo.reduce(%arg0 init: %[[INIT]]), (%[[INDEX]] init: %[[INDEX_INIT]]) + // CHECK: return %[[REDUCE]]#1 : tensor<3xi64> + %axis = "tf.Const"() { value = dense<1> : tensor } : () -> tensor + %0 = "tf.ArgMax"(%arg0, %axis) : (tensor<3x7xf32>, tensor) -> tensor<3xi64> + func.return %0 : tensor<3xi64> +} + +// ----- + +// CHECK-LABEL: func @argmax_i1_input_i64_output_axis_1 +func.func @argmax_i1_input_i64_output_axis_1(%arg0: tensor<3x7xi1>) -> tensor<3xi64> { + // CHECK-DAG: %[[INIT:.*]] = mhlo.constant dense : tensor + // CHECK-DAG: %[[INDEX_INIT:.*]] = mhlo.constant dense<0> : tensor + // CHECK: %[[SHAPE:.*]] = shape.shape_of %arg0 : tensor<3x7xi1> -> tensor<2xindex> + // CHECK: %[[INDEX:.*]] = "mhlo.dynamic_iota"(%[[SHAPE]]) <{iota_dimension = 1 : i64}> : (tensor<2xindex>) -> tensor<3x7xi64> + // CHECK: %[[REDUCE:.*]]:2 = mhlo.reduce(%arg0 init: %[[INIT]]), (%[[INDEX]] init: %[[INDEX_INIT]]) + // CHECK: return %[[REDUCE]]#1 : tensor<3xi64> + %axis = "tf.Const"() { value = dense<1> : tensor } : () -> tensor + %0 = "tf.ArgMax"(%arg0, %axis) : (tensor<3x7xi1>, tensor) -> tensor<3xi64> + func.return %0 : tensor<3xi64> +} + +// ----- + +// CHECK-LABEL: func @argmax_dynamic_shape_input_output +func.func @argmax_dynamic_shape_input_output(%arg0: tensor<3x?xi32>) -> tensor { + // CHECK: %[[INIT:.*]] = mhlo.constant dense<-2147483648> : tensor + // CHECK-NEXT: %[[INDEX_INIT:.*]] = mhlo.constant dense<0> : tensor + // CHECK: %[[SHAPE:.*]] = shape.shape_of %arg0 : tensor<3x?xi32> -> tensor<2xindex> + // CHECK: %[[INDEX:.*]] = "mhlo.dynamic_iota"(%[[SHAPE]]) <{iota_dimension = 0 : i64}> : (tensor<2xindex>) -> tensor<3x?xi32> + // CHECK: %[[REDUCE:.*]]:2 = mhlo.reduce(%arg0 init: %[[INIT]]), (%[[INDEX]] init: %[[INDEX_INIT]]) + // CHECK: return %[[REDUCE]]#1 : tensor + %axis = "tf.Const"() { value = dense<0> : tensor } : () -> tensor + %0 = "tf.ArgMax"(%arg0, %axis) : (tensor<3x?xi32>, tensor) -> tensor + func.return %0 : tensor +} + +// ----- + +// CHECK-LABEL: func @argmax_dynamic_shape_input +func.func @argmax_dynamic_shape_input(%arg0: tensor<3x?xi32>) -> tensor<3xi32> { + // CHECK-DAG: %[[INIT:.*]] = mhlo.constant dense<-2147483648> : tensor + // CHECK-DAG: %[[INDEX_INIT:.*]] = mhlo.constant dense<0> : tensor + // CHECK: %[[SHAPE:.*]] = shape.shape_of %arg0 : tensor<3x?xi32> -> tensor<2xindex> + // CHECK: %[[INDEX:.*]] = "mhlo.dynamic_iota"(%[[SHAPE]]) <{iota_dimension = 1 : i64}> : (tensor<2xindex>) -> tensor<3x?xi32> + // CHECK: %[[REDUCE:.*]]:2 = mhlo.reduce(%arg0 init: %[[INIT]]), (%[[INDEX]] init: %[[INDEX_INIT]]) + // CHECK: return %[[REDUCE]]#1 : tensor<3xi32> + %axis = "tf.Const"() { value = dense<1> : tensor } : () -> tensor + %0 = "tf.ArgMax"(%arg0, %axis) : (tensor<3x?xi32>, tensor) -> tensor<3xi32> + func.return %0 : tensor<3xi32> +} + +// ----- + +// CHECK-LABEL: func @argmin_i64_input_i32_output_axis_0 +func.func @argmin_i64_input_i32_output_axis_0(%arg0: tensor<3x7xi64>) -> tensor<7xi32> { + // CHECK: %[[INIT:.*]] = mhlo.constant dense<9223372036854775807> : tensor + // CHECK-NEXT: %[[INDEX_INIT:.*]] = mhlo.constant dense<0> : tensor + // CHECK: %[[SHAPE:.*]] = shape.shape_of %arg0 : tensor<3x7xi64> -> tensor<2xindex> + // CHECK: %[[INDEX:.*]] = "mhlo.dynamic_iota"(%[[SHAPE]]) <{iota_dimension = 0 : i64}> : (tensor<2xindex>) -> tensor<3x7xi32> + // CHECK: %[[REDUCE:.*]]:2 = mhlo.reduce(%arg0 init: %[[INIT]]), (%[[INDEX]] init: %[[INDEX_INIT]]) + // CHECK: (%[[ARG1:.*]]: tensor, %[[ARG3:.*]]: tensor) (%[[ARG2:.*]]: tensor, %[[ARG4:.*]]: tensor) + // CHECK: %[[COMPARE:.*]] = mhlo.compare LE, %[[ARG1]], %[[ARG3]], NOTYPE : (tensor, tensor) -> tensor + // CHECK: %[[RESULT1:.*]] = mhlo.select %[[COMPARE]], %[[ARG1]], %[[ARG3]] : tensor, tensor + // CHECK: %[[COMPARE_EQ:.*]] = mhlo.compare EQ, %[[ARG1]], %[[ARG3]], NOTYPE : (tensor, tensor) -> tensor + // CHECK: %[[MIN:.*]] = mhlo.minimum %[[ARG2]], %[[ARG4]] + // CHECK: %[[RESULT2:.*]] = mhlo.select %[[COMPARE]], %[[ARG2]], %[[ARG4]] : tensor, tensor + // CHECK: %[[RESULT3:.*]] = mhlo.select %[[COMPARE_EQ]], %[[MIN]], %[[RESULT2]] : tensor, tensor + // CHECK: mhlo.return %[[RESULT1]], %[[RESULT3]] : tensor, tensor + // CHECK: return %[[REDUCE]]#1 : tensor<7xi32> + %axis = "tf.Const"() { value = dense<0> : tensor } : () -> tensor + %0 = "tf.ArgMin"(%arg0, %axis) : (tensor<3x7xi64>, tensor) -> tensor<7xi32> + func.return %0 : tensor<7xi32> +} + +//===----------------------------------------------------------------------===// +// Random op legalizations. +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: func @rng_uniform +func.func @rng_uniform(%arg0: tensor<3xi32>) -> tensor<12x?x64xf32> { + // CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor + // CHECK-DAG: %[[ONE:.*]] = mhlo.constant dense<1.000000e+00> : tensor + // CHECK: %[[CONV:.*]] = mhlo.convert %arg0 : (tensor<3xi32>) -> tensor<3xi64> + // CHECK: %[[F32:.*]] = "mhlo.rng"(%[[ZERO]], %[[ONE]], %[[CONV]]) {{.*UNIFORM.*}} -> tensor<12x?x64xf32> + %0 = "tf.RandomUniform"(%arg0) : (tensor<3xi32>) -> tensor<12x?x64xf32> + // CHECK: return %[[F32]] + func.return %0 : tensor<12x?x64xf32> +} + +// ----- + +// CHECK-LABEL: func @random_uniform_simple +func.func @random_uniform_simple(%arg0: tensor<3xi32>) -> tensor<12x?x64xf32> { + // CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor + // CHECK-DAG: %[[ONE:.*]] = mhlo.constant dense<1.000000e+00> : tensor + // CHECK: %[[CONV:.*]] = mhlo.convert %arg0 : (tensor<3xi32>) -> tensor<3xi64> + // CHECK: %[[F32:.*]] = "mhlo.rng"(%[[ZERO]], %[[ONE]], %[[CONV]]) {{.*UNIFORM.*}} -> tensor<12x?x64xf32> + %0 = "tf.RandomUniform"(%arg0) : (tensor<3xi32>) -> tensor<12x?x64xf32> + // CHECK: return %[[F32]] + func.return %0 : tensor<12x?x64xf32> +} + +// ----- + +// CHECK-LABEL: func @random_uniform_with_seeds +func.func @random_uniform_with_seeds(%arg0: tensor<4xi32>) -> tensor<32x12x12x64xf32> { + // CHECK: %0 = mhlo.constant dense<[32, 12, 12, 64]> : tensor<4xi32> + // CHECK-NEXT: %1 = mhlo.constant dense<0.000000e+00> : tensor + // CHECK-NEXT: %2 = mhlo.constant dense<1.000000e+00> : tensor + // CHECK-NEXT: %3 = mhlo.convert %0 : (tensor<4xi32>) -> tensor<4xi64> + // CHECK-NEXT: %4 = "mhlo.rng"(%1, %2, %3) <{rng_distribution = #mhlo.rng_distribution}> : (tensor, tensor, tensor<4xi64>) -> tensor<32x12x12x64xf32> + %cst = "tf.Const"() {value = dense<[32, 12, 12, 64]> : tensor<4xi32>} : () -> tensor<4xi32> + %0 = "tf.RandomUniform"(%cst) {seed = 87654321 : i64, seed2 = 0 : i64} : (tensor<4xi32>) -> tensor<32x12x12x64xf32> + // CHECK: return %4 : tensor<32x12x12x64xf32> + func.return %0 : tensor<32x12x12x64xf32> +} + +// ----- + +// CHECK-LABEL: func @rng_std_normal +func.func @rng_std_normal(%arg0: tensor<3xi32>) -> tensor<12x?x64xf32> { + // CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor + // CHECK-DAG: %[[ONE:.*]] = mhlo.constant dense<1.000000e+00> : tensor + // CHECK: %[[CONV:.*]] = mhlo.convert %arg0 : (tensor<3xi32>) -> tensor<3xi64> + // CHECK: %[[F32:.*]] = "mhlo.rng"(%[[ZERO]], %[[ONE]], %[[CONV]]) {{.*NORMAL.*}} -> tensor<12x?x64xf32> + %0 = "tf.RandomStandardNormal"(%arg0) : (tensor<3xi32>) -> tensor<12x?x64xf32> + // CHECK: return %[[F32]] + func.return %0 : tensor<12x?x64xf32> +} + +//===----------------------------------------------------------------------===// +// Range op legalizations. +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: func @range +// CHECK-SAME: [[START:%.*]]: tensor, [[DELTA:%.*]]: tensor +func.func @range(%arg0: tensor, %arg1: tensor) -> tensor<5xf32> { + %1 = "tf.Const"() {device = "", dtype = "tfdtype$DT_FLOAT", name = "range/limit", value = dense<5.000000e+00> : tensor} : () -> tensor + // CHECK-DAG: [[IOTA:%.*]] = "mhlo.iota" + // CHECK-DAG: [[MUL:%.*]] = chlo.broadcast_multiply [[IOTA]], [[DELTA]] {broadcast_dimensions = array} + // CHECK: chlo.broadcast_add [[MUL]], [[START]] {broadcast_dimensions = array} + %3 = "tf.Range"(%arg0, %1, %arg1) {Tidx = "tfdtype$DT_FLOAT", device = "", name = "range"} : (tensor, tensor, tensor) -> tensor<5xf32> + func.return %3 : tensor<5xf32> +} + +// ----- + +// CHECK-LABEL: func @range_dynamic +// CHECK-SAME: [[START:%.*]]: tensor, [[DELTA:%.*]]: tensor +func.func @range_dynamic(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { + // CHECK-DAG: [[SUB:%.+]] = mhlo.subtract %arg1, %arg0 + // CHECK-DAG: [[ABS1:%.+]] = mhlo.abs [[SUB]] + // CHECK-DAG: [[CONVERT_1:%.+]] = mhlo.convert [[ABS1]] + // CHECK-DAG: [[CONVERT_2:%.+]] = mhlo.convert %arg2 + // CHECK-DAG: [[DIV:%.+]] = mhlo.divide [[CONVERT_1]], [[CONVERT_2]] + // CHECK-DAG: [[CEIL:%.+]] = mhlo.ceil [[DIV]] + // CHECK-DAG: [[CONVERT_3:%.+]] = mhlo.convert [[CEIL]] + // CHECK-DAG: [[RESHAPE:%.+]] = mhlo.reshape [[CONVERT_3]] + // CHECK-DAG: [[IOTA:%.+]] = "mhlo.dynamic_iota"([[RESHAPE]]) <{iota_dimension = 0 : i64}> + // CHECK-DAG: [[CONVERT_3:%.+]] = mhlo.convert %arg0 + // CHECK-DAG: [[CONVERT_4:%.+]] = mhlo.convert %arg2 + // CHECK-DAG: [[MUL:%.+]] = chlo.broadcast_multiply [[IOTA]], [[CONVERT_4]] {broadcast_dimensions = array} + // CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add [[MUL]], [[CONVERT_3]] {broadcast_dimensions = array} + %2 = "tf.Range"(%arg0, %arg1, %arg2) {Tidx = "tfdtype$DT_FLOAT", device = "", name = "range"} : (tensor, tensor, tensor) -> tensor + + // CHECK: return [[ADD]] + func.return %2 : tensor +} + +// ----- + +// CHECK-LABEL: func @range_int_dynamic +// CHECK-SAME: [[START:%.*]]: tensor, [[DELTA:%.*]]: tensor +func.func @range_int_dynamic(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { + // CHECK-DAG: [[SUB:%.+]] = mhlo.subtract %arg1, %arg0 + // CHECK-DAG: [[ABS1:%.+]] = mhlo.abs [[SUB]] + // CHECK-DAG: [[CONVERT_1:%.+]] = mhlo.convert [[ABS1]] + // CHECK-DAG: [[CONVERT_2:%.+]] = mhlo.convert %arg2 + // CHECK-DAG: [[DIV:%.+]] = mhlo.divide [[CONVERT_1]], [[CONVERT_2]] + // CHECK-DAG: [[CEIL:%.+]] = mhlo.ceil [[DIV]] + // CHECK-DAG: [[CONVERT_3:%.+]] = mhlo.convert [[CEIL]] + // CHECK-DAG: [[RESHAPE:%.+]] = mhlo.reshape [[CONVERT_3]] + // CHECK-DAG: [[IOTA:%.+]] = "mhlo.dynamic_iota"([[RESHAPE]]) <{iota_dimension = 0 : i64}> + // CHECK-DAG: [[CONVERT_3:%.+]] = mhlo.convert %arg0 + // CHECK-DAG: [[CONVERT_4:%.+]] = mhlo.convert %arg2 + // CHECK-DAG: [[MUL:%.+]] = chlo.broadcast_multiply [[IOTA]], [[CONVERT_4]] {broadcast_dimensions = array} + // CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add [[MUL]], [[CONVERT_3]] {broadcast_dimensions = array} + %2 = "tf.Range"(%arg0, %arg1, %arg2) {Tidx = "tfdtype$DT_FLOAT", device = "", name = "range"} : (tensor, tensor, tensor) -> tensor + + // CHECK: return [[ADD]] + func.return %2 : tensor +} + +// ----- + +// CHECK-LABEL: func @linspace_static +// CHECK-SAME: [[START:%.*]]: tensor, [[STOP:%.*]]: tensor +func.func @linspace_static(%arg0: tensor, %arg1: tensor) -> tensor<4xf32> { + // CHECK-DAG: [[NUM:%.*]] = mhlo.constant dense<4> + // CHECK-DAG: [[NUM_F32:%.*]] = mhlo.convert [[NUM]] + // CHECK-DAG: [[ONE:%.*]] = mhlo.constant dense<1.000000e+00> + // CHECK-DAG: [[STEP_DENOMINATOR:%.*]] = chlo.broadcast_subtract [[NUM_F32]], [[ONE]] + // CHECK-DAG: [[STEP_NUMERATOR:%.*]] = chlo.broadcast_subtract [[STOP]], [[START]] + // CHECK-DAG: [[STEP:%.*]] = chlo.broadcast_divide [[STEP_NUMERATOR]], [[STEP_DENOMINATOR]] + // CHECK-DAG: [[IOTA:%.*]] = "mhlo.iota"() <{iota_dimension = 0 : i64}> + // CHECK-DAG: [[MUL:%.*]] = chlo.broadcast_multiply [[IOTA]], [[STEP]] {broadcast_dimensions = array} + // CHECK-DAG: [[LINSPACE:%.*]] = chlo.broadcast_add [[MUL]], [[START]] {broadcast_dimensions = array} + // CHECK: return [[LINSPACE]] + %0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = i32, value = dense<4> : tensor} : () -> tensor + %1 = "tf.LinSpace"(%arg0, %arg1, %0) : (tensor, tensor, tensor) -> tensor<4xf32> + func.return %1 : tensor<4xf32> +} + +// ----- + +// CHECK-LABEL: func @linspace_dynamic +func.func @linspace_dynamic(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { + // CHECK: "tf.LinSpace" + %0 = "tf.LinSpace"(%arg0, %arg1, %arg2) : (tensor, tensor, tensor) -> tensor + func.return %0 : tensor +} + +// ----- + +// CHECK-LABEL: func @linspace_invalid_num +func.func @linspace_invalid_num(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: mhlo.constant dense<> : tensor<0xi32> + // CHECK: "tf.LinSpace" + %0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = i32, value = dense<> : tensor<0xi32>} : () -> tensor<0xi32> + %1 = "tf.LinSpace"(%arg0, %arg1, %0) : (tensor, tensor, tensor<0xi32>) -> tensor + func.return %1 : tensor +} + +//===----------------------------------------------------------------------===// +// LegacyCall op legalizations. +//===----------------------------------------------------------------------===// + +// ----- + +func.func @identity_func(%arg0: tensor<10x2xf32>) -> tensor<10x2xf32> { + func.return %arg0: tensor<10x2xf32> +} + +// CHECK-LABEL: testSimpleLegacyCallOp +func.func @testSimpleLegacyCallOp(%arg0: tensor<10x2xf32>) -> tensor<10x2xf32> { + // CHECK: %[[RESULT:.*]] = call @identity_func(%arg0) : (tensor<10x2xf32>) -> tensor<10x2xf32> + %0 = "tf.LegacyCall"(%arg0) {f = @identity_func} : (tensor<10x2xf32>) -> tensor<10x2xf32> + // CHECK: return %[[RESULT]] + func.return %0: tensor<10x2xf32> +} + +// ----- + +func.func @select_first(%arg0: tensor<10x2xf32>, %arg1: tensor<10x2xf32>) -> tensor<10x2xf32> { + func.return %arg0: tensor<10x2xf32> +} + +// CHECK-LABEL: testMultiInputLegacyCallOp +func.func @testMultiInputLegacyCallOp(%arg0: tensor<10x2xf32>, %arg1: tensor<10x2xf32>) -> tensor<10x2xf32> { + // CHECK: %[[RESULT:.*]] = call @select_first(%arg0, %arg1) : (tensor<10x2xf32>, tensor<10x2xf32>) -> tensor<10x2xf32> + %0 = "tf.LegacyCall"(%arg0, %arg1) {_disable_call_shape_inference = true, _tpu_replicate = "cluster", device = "", f = @select_first} : (tensor<10x2xf32>, tensor<10x2xf32>) -> tensor<10x2xf32> + // CHECK: return %[[RESULT]] + func.return %0: tensor<10x2xf32> +} + +//===----------------------------------------------------------------------===// +// Conv op legalizations. +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: conv_simple +func.func @conv_simple(%arg0: tensor<256x32x32x6xf32>, %arg1: tensor<3x3x3x16xf32>) -> tensor<256x8x7x16xf32> { + + // CHECK: mhlo.convolution(%arg0, %arg1) + // CHECK-SAME: dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f] + // CHECK-SAME{LITERAL}: window = {stride = [4, 5], pad = [[0, 1], [2, 3]], rhs_dilate = [2, 3]} + // CHECK-SAME: batch_group_count = 1 + // CHECK-SAME: feature_group_count = 2 + + %0 = "tf.Conv2D"(%arg0, %arg1) {data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 5, 1]} : (tensor<256x32x32x6xf32>, tensor<3x3x3x16xf32>) -> tensor<256x8x7x16xf32> + func.return %0 : tensor<256x8x7x16xf32> +} + +// ----- + +// CHECK-LABEL: conv3d_simple +func.func @conv3d_simple(%arg0: tensor<256x32x32x32x6xf32>, %arg1: tensor<3x3x3x3x16xf32>) -> tensor<256x7x6x5x16xf32> { + + // CHECK: mhlo.convolution(%arg0, %arg1) + // CHECK-SAME: dim_numbers = [b, 0, 1, 2, f]x[0, 1, 2, i, o]->[b, 0, 1, 2, f] + // CHECK-SAME{LITERAL}: window = {stride = [5, 6, 7], pad = [[1, 2], [2, 3], [2, 3]], rhs_dilate = [2, 3, 4]} + // CHECK-SAME: batch_group_count = 1 + // CHECK-SAME: feature_group_count = 2 + + %0 = "tf.Conv3D"(%arg0, %arg1) {data_format = "NDHWC", dilations = [1, 2, 3, 4, 1], padding = "SAME", strides = [1, 5, 6, 7, 1]} : (tensor<256x32x32x32x6xf32>, tensor<3x3x3x3x16xf32>) -> tensor<256x7x6x5x16xf32> + func.return %0 : tensor<256x7x6x5x16xf32> +} + +// ----- + +// CHECK-LABEL: depthwiseconv_simple +func.func @depthwiseconv_simple(%arg0: tensor, %arg1: tensor<2x2x3x3xf32>) -> tensor { + // CHECK: %[[RESHAPED_FILTER:.*]] = mhlo.reshape %arg1 : (tensor<2x2x3x3xf32>) -> tensor<2x2x1x9xf32> + // CHECK: mhlo.convolution(%arg0, %[[RESHAPED_FILTER]]) + // CHECK-SAME: feature_group_count = 3 + %0 = "tf.DepthwiseConv2dNative"(%arg0, %arg1) { + data_format = "NHWC", + device = "", + dilations = [1, 1, 1, 1], + explicit_paddings = [], + padding = "VALID", + strides = [1, 1, 1, 1] + } : (tensor, tensor<2x2x3x3xf32>) -> tensor + func.return %0 : tensor +} + +// ----- + +// CHECK-LABEL: conv_valid_padding +func.func @conv_valid_padding(%arg0: tensor<1x4x5x1xf32>, %arg1: tensor<3x3x1x1xf32>) -> tensor<1x2x3x1xf32> { + // CHECK: mhlo.convolution(%arg0, %arg1) + + %0 = "tf.Conv2D"(%arg0, %arg1) {data_format = "NHWC", dilations = [1, 1, 1, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x4x5x1xf32>, tensor<3x3x1x1xf32>) -> tensor<1x2x3x1xf32> + func.return %0 : tensor<1x2x3x1xf32> +} + +// ----- + +// CHECK-LABEL: conv_explicit_paddings +func.func @conv_explicit_paddings(%arg0: tensor<256x32x32x6xf32>, %arg1: tensor<3x3x3x16xf32>) -> tensor<256x9x7x16xf32> { + + // CHECK: mhlo.convolution(%arg0, %arg1) + // CHECK-SAME{LITERAL}: pad = [[6, 0], [3, 3]] + + %0 = "tf.Conv2D"(%arg0, %arg1) {data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "EXPLICIT", explicit_paddings = [0, 0, 6, 0, 3, 3, 0, 0], strides = [1, 4, 5, 1]} : (tensor<256x32x32x6xf32>, tensor<3x3x3x16xf32>) -> tensor<256x9x7x16xf32> + func.return %0 : tensor<256x9x7x16xf32> +} + +// ----- + +// CHECK-LABEL: @conv2d_backprop_input_dynamic +func.func @conv2d_backprop_input_dynamic(%filter: tensor<2x2x1x16xf32>, %out_backprop: tensor) -> tensor { + // CHECK: %[[REV_FILTER:.*]] = "mhlo.reverse"(%arg0) <{dimensions = dense<[0, 1]> : tensor<2xi64>}> + // CHECK: %[[RESULT:.*]] = mhlo.convolution(%arg1, %[[REV_FILTER]]) + // CHECK-SAME: dim_numbers = [b, 0, 1, f]x[0, 1, o, i]->[b, 0, 1, f] + // CHECK-SAME{LITERAL}: window = {stride = [1, 1], pad = [[1, 1], [1, 1]], lhs_dilate = [2, 2], rhs_dilate = [1, 1]} + // CHECK-SAME: batch_group_count = 1 : i64 + // CHECK-SAME: feature_group_count = 1 : i64 + // CHECK: return %[[RESULT]] + %cst_0_1d = "tf.Const"() {device = "", value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32> + %cst_1_0d = "tf.Const"() {device = "", value = dense<1> : tensor} : () -> tensor + %cst_1_1d = "tf.Const"() {device = "", value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> + %cst_512_0d = "tf.Const"() {device = "", value = dense<512> : tensor} : () -> tensor + %out_backprop_shape = "tf.Shape"(%out_backprop) {device = ""} : (tensor) -> tensor<4xi32> + %batch_size = "tf.StridedSlice"(%out_backprop_shape, %cst_0_1d, %cst_1_1d, %cst_1_1d) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %input_shape = "tf.Pack"(%batch_size, %cst_512_0d, %cst_512_0d, %cst_1_0d) {axis = 0 : i64, device = ""} : (tensor, tensor, tensor, tensor) -> tensor<4xi32> + %result = "tf.Conv2DBackpropInput"(%input_shape, %filter, %out_backprop) {data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 2, 2, 1], use_cudnn_on_gpu = true} : (tensor<4xi32>, tensor<2x2x1x16xf32>, tensor) -> tensor + return %result : tensor +} + +// ----- + +// CHECK-LABEL: @conv2d_backprop_input +func.func @conv2d_backprop_input( + %filter: tensor<3x3x1x32xf32>, + %out_backprop: tensor<100x26x26x32xf32> + ) -> tensor<100x28x28x1xf32> { + // CHECK: %[[REV_FILTER:.*]] = "mhlo.reverse"(%arg0) <{dimensions = dense<[0, 1]> : tensor<2xi64>}> + // CHECK: %[[RESULT:.*]] = mhlo.convolution(%arg1, %[[REV_FILTER]]) + // CHECK-SAME: dim_numbers = [b, 0, 1, f]x[0, 1, o, i]->[b, 0, 1, f] + // CHECK-SAME{LITERAL}: window = {stride = [1, 1], pad = [[2, 2], [2, 2]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]} + // CHECK-SAME: batch_group_count = 1 : i64 + // CHECK-SAME: feature_group_count = 1 : i64 + // CHECK: return %[[RESULT]] + %input_sizes = "tf.Const" () { value = dense<[100,28,28,1]> : tensor<4xi32> } : () -> tensor<4xi32> + %result = "tf.Conv2DBackpropInput"(%input_sizes, %filter, %out_backprop) { + data_format = "NHWC", + dilations = [1, 1, 1, 1], + explicit_paddings = [], + padding = "VALID", + strides = [1, 1, 1, 1], + use_cudnn_on_gpu = true + } : (tensor<4xi32>, tensor<3x3x1x32xf32>, tensor<100x26x26x32xf32>) -> tensor<100x28x28x1xf32> + func.return %result : tensor<100x28x28x1xf32> +} + +// ----- + +// CHECK-LABEL: @conv2d_backprop_input_grouped +func.func @conv2d_backprop_input_grouped( + %filter: tensor<2x2x5x21xf32>, + %out_backprop: tensor<5x2x2x21xf32> + ) -> tensor<5x3x3x15xf32> { + %input_sizes = "tf.Const" () { value = dense<[5, 3, 3, 15]> : tensor<4xi32> } : () -> tensor<4xi32> + + // Verify filter transformation for grouped convolution. + + // CHECK: %[[RESHAPE:.*]] = mhlo.reshape %arg0 : (tensor<2x2x5x21xf32>) -> tensor<2x2x5x3x7xf32> + // CHECK: %[[TRANSPOSE:.*]] = "mhlo.transpose"(%[[RESHAPE]]) + // CHECK-SAME: permutation = dense<[0, 1, 3, 2, 4]> + // CHECK-SAME: (tensor<2x2x5x3x7xf32>) -> tensor<2x2x3x5x7xf32> + // CHECK: mhlo.reshape %[[TRANSPOSE]] : (tensor<2x2x3x5x7xf32>) -> tensor<2x2x15x7xf32> + + %result = "tf.Conv2DBackpropInput"(%input_sizes, %filter, %out_backprop) { + data_format = "NHWC", + dilations = [1, 1, 1, 1], + explicit_paddings = [], + padding = "VALID", + strides = [1, 1, 1, 1], + use_cudnn_on_gpu = true + } : (tensor<4xi32>, tensor<2x2x5x21xf32>, tensor<5x2x2x21xf32>) -> tensor<5x3x3x15xf32> + func.return %result : tensor<5x3x3x15xf32> +} + + +// CHECK-LABEL: @conv3d_backprop_input +func.func @conv3d_backprop_input(%filter: tensor<3x3x3x1x6xf32>, %out_backprop: tensor<2x8x8x8x6xf32>) -> tensor<2x8x8x8x1xf32> { + // CHECK: %[[REV_FILTER:.*]] = "mhlo.reverse"(%arg0) <{dimensions = dense<[0, 1, 2]> : tensor<3xi64>}> + // CHECK: %[[RESULT:.*]] = mhlo.convolution(%arg1, %[[REV_FILTER]]) + // CHECK-SAME: dim_numbers = [b, 0, 1, 2, f]x[0, 1, 2, o, i]->[b, 0, 1, 2, f] + // CHECK-SAME{LITERAL}: window = {stride = [1, 1, 1], pad = [[1, 1], [1, 1], [1, 1]], lhs_dilate = [1, 1, 1], rhs_dilate = [1, 1, 1]} + // CHECK-SAME: batch_group_count = 1 : i64, + // CHECK-SAME: feature_group_count = 1 : i64 + + // CHECK: return %[[RESULT]] + %input_sizes = "tf.Const" () {value = dense<[2, 8, 8, 8, 1]> : tensor<5xi32>} : () -> tensor<5xi32> + %result = "tf.Conv3DBackpropInputV2"(%input_sizes, %filter, %out_backprop) {data_format = "NDHWC", dilations = [1, 1, 1, 1, 1], padding = "SAME", strides = [1, 1, 1, 1, 1]} : (tensor<5xi32>, tensor<3x3x3x1x6xf32>, tensor<2x8x8x8x6xf32>) -> tensor<2x8x8x8x1xf32> + func.return %result : tensor<2x8x8x8x1xf32> +} + +// ----- + +// CHECK-LABEL: @conv2d_backprop_filter +func.func @conv2d_backprop_filter( + %input: tensor<100x28x28x1xf32>, + %out_backprop: tensor<100x26x26x32xf32> + ) -> tensor<3x3x1x32xf32> { + // CHECK: %[[RESULT:.*]] = mhlo.convolution(%arg0, %arg1) + // CHECK-SAME: dim_numbers = [f, 0, 1, b]x[i, 0, 1, o]->[0, 1, b, f] + // CHECK-SAME{LITERAL}: window = {stride = [1, 1], pad = [[0, 0], [0, 0]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]} + // CHECK-SAME: batch_group_count = 1 : i64 + // CHECK-SAME: feature_group_count = 1 : i64 + // CHECK: return %[[RESULT]] + %filter_sizes = "tf.Const" () { value = dense<[3,3,1,32]> : tensor<4xi32> } : () -> tensor<4xi32> + %result = "tf.Conv2DBackpropFilter"(%input, %filter_sizes, %out_backprop) { + data_format = "NHWC", + dilations = [1, 1, 1, 1], + explicit_paddings = [], + padding = "VALID", + strides = [1, 1, 1, 1], + use_cudnn_on_gpu = true + } : (tensor<100x28x28x1xf32>, tensor<4xi32>, tensor<100x26x26x32xf32>) -> tensor<3x3x1x32xf32> + func.return %result : tensor<3x3x1x32xf32> +} + +// ----- + +// CHECK-LABEL: @conv2d_backprop_filter_grouped +func.func @conv2d_backprop_filter_grouped( + %input: tensor<1x2x2x2xf32>, + %out_backprop: tensor<1x1x1x2xf32> + ) -> tensor<2x2x1x2xf32> { + + // CHECK: mhlo.convolution(%arg0, %arg1) + // CHECK-SAME: batch_group_count = 2 : i64 + // CHECK-SAME: feature_group_count = 1 : i64 + + %filter_sizes = "tf.Const" () { value = dense<[2, 2, 1, 2]> : tensor<4xi32> } : () -> tensor<4xi32> + %result = "tf.Conv2DBackpropFilter"(%input, %filter_sizes, %out_backprop) { + data_format = "NHWC", + dilations = [1, 1, 1, 1], + explicit_paddings = [], + padding = "VALID", + strides = [1, 1, 1, 1], + use_cudnn_on_gpu = true + } : (tensor<1x2x2x2xf32>, tensor<4xi32>, tensor<1x1x1x2xf32>) -> tensor<2x2x1x2xf32> + func.return %result : tensor<2x2x1x2xf32> +} + + +// CHECK-LABEL: @conv3d_backprop_filter +func.func @conv3d_backprop_filter(%input: tensor<2x8x8x8x1xf32>, %out_backprop: tensor<2x8x8x8x6xf32>) -> tensor<3x3x3x1x6xf32> { + // CHECK: %[[RESULT:.*]] = mhlo.convolution(%arg0, %arg1) + // CHECK-SAME: dim_numbers = [f, 0, 1, 2, b]x[i, 0, 1, 2, o]->[0, 1, 2, b, f] + // CHECK-SAME{LITERAL}: window = {stride = [1, 1, 1], pad = [[1, 1], [1, 1], [1, 1]], lhs_dilate = [1, 1, 1], rhs_dilate = [1, 1, 1]} + // CHECK-SAME: batch_group_count = 1 : i64 + // CHECK-SAME: feature_group_count = 1 : i64 + // CHECK: return %[[RESULT]] + %filter_sizes = "tf.Const"() {value = dense<[3, 3, 3, 1, 6]> : tensor<5xi32>} : () -> tensor<5xi32> + %result = "tf.Conv3DBackpropFilterV2"(%input, %filter_sizes, %out_backprop) {data_format = "NDHWC", dilations = [1, 1, 1, 1, 1], padding = "SAME", strides = [1, 1, 1, 1, 1]} : (tensor<2x8x8x8x1xf32>, tensor<5xi32>, tensor<2x8x8x8x6xf32>) -> tensor<3x3x3x1x6xf32> + func.return %result : tensor<3x3x3x1x6xf32> +} + +// ----- + +// CHECK-LABEL: @collective_permute +func.func @collective_permute(%arg0: tensor<128x32xf32>) -> tensor<128x32xf32> { + %source_target_pairs = "tf.Const" () { + value = dense<[[0, 1], [1, 2], [2, 3]]> : tensor<3x2xi32> + } : () -> tensor<3x2xi32> + + // CHECK: "mhlo.collective_permute" + // CHECK-SAME: source_target_pairs = dense<{{\[}}[0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64> + %0 = "tf.CollectivePermute"(%arg0, %source_target_pairs) { + } : (tensor<128x32xf32>, tensor<3x2xi32>) -> tensor<128x32xf32> + + func.return %0 : tensor<128x32xf32> +} + +// ----- + +// CHECK-LABEL: @cross_replica_sum +func.func @cross_replica_sum(%input: tensor<10xf32>) -> tensor<10xf32> { + %replica_groups = "tf.Const" () { + value = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi32> + } : () -> tensor<2x4xi32> + + // CHECK: mhlo.cross-replica-sum + // CHECK-SAME: replica_groups = dense<{{\[}}[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64> + %result = "tf.CrossReplicaSum" (%input, %replica_groups) : (tensor<10xf32>, tensor<2x4xi32>) -> tensor<10xf32> + func.return %result : tensor<10xf32> +} + +// ----- + +// CHECK-LABEL: conv_dynamic +func.func @conv_dynamic(%arg0: tensor, %arg1: tensor<3x3x3x16xf32>) -> tensor { + // CHECK: "mhlo.dynamic_conv" + // CHECK-SAME: <{batch_group_count = 1 : i64, dimension_numbers = #mhlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>, feature_group_count = 2 : i64, precision_config = [#mhlo, #mhlo], rhs_dilation = dense<[2, 3]> : tensor<2xi64>, window_strides = dense<[4, 5]> : tensor<2xi64>}> : (tensor, tensor<3x3x3x16xf32>, tensor<4xi32>) -> tensor + %0 = "tf.Conv2D"(%arg0, %arg1) {data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 5, 1]} : (tensor, tensor<3x3x3x16xf32>) -> tensor + func.return %0 : tensor +} + +//===----------------------------------------------------------------------===// +// tf.Split legalization +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: @split_not_match_dynamic_split_dim_input +func.func @split_not_match_dynamic_split_dim_input(%input: tensor<4x4xf32>, %split_dim: tensor) -> (tensor<*xf32>, tensor<*xf32>) { + // CHECK: tf.Split + %0:2 = "tf.Split"(%split_dim, %input) : (tensor, tensor<4x4xf32>) -> (tensor<*xf32>, tensor<*xf32>) + func.return %0#0, %0#1 : tensor<*xf32>, tensor<*xf32> +} + +// ----- + +// CHECK-LABEL: @split_not_match_dynamic_input_shape +func.func @split_not_match_dynamic_input_shape(%input: tensor<4x?x4xf32>) -> (tensor<4x?x4xf32>, tensor<4x?x4xf32>) { + %cst = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + // CHECK: tensor.dim {{.*}} : tensor<4x?x4xf32> + // CHECK: arith.divsi {{.*}} : index + // CHECK: tensor.from_elements {{.*}} : tensor<3xindex> + // CHECK: mhlo.real_dynamic_slice {{.*}} : (tensor<4x?x4xf32>, tensor<3xindex>, tensor<3xindex>, tensor<3xindex>) -> tensor<4x?x4xf32> + // CHECK: muli {{.*}} : index + // CHECK: muli {{.*}} : index + // CHECK: tensor.from_elements {{.*}} : tensor<3xindex> + // CHECK: tensor.from_elements {{.*}} : tensor<3xindex> + // CHECK: tensor.from_elements {{.*}} : tensor<3xindex> + // CHECK: mhlo.real_dynamic_slice {{.*}} : (tensor<4x?x4xf32>, tensor<3xindex>, tensor<3xindex>, tensor<3xindex>) -> tensor<4x?x4xf32> + %0:2 = "tf.Split"(%cst, %input) : (tensor, tensor<4x?x4xf32>) -> (tensor<4x?x4xf32>, tensor<4x?x4xf32>) + func.return %0#0, %0#1 : tensor<4x?x4xf32>, tensor<4x?x4xf32> +} + +// ----- + +// CHECK-LABEL: @split_not_match_static_split_dim_size +func.func @split_not_match_static_split_dim_size(%input: tensor<4x?x4xf32>) -> (tensor<2x?x4xf32>, tensor<2x?x4xf32>) { + %cst = "tf.Const"() {value = dense<0> : tensor} : () -> tensor + // CHECK: tensor.dim {{.*}} : tensor<4x?x4xf32> + // CHECK: arith.divsi {{.*}} : index + // CHECK: tensor.from_elements {{.*}} : tensor<3xindex> + // CHECK: mhlo.real_dynamic_slice {{.*}} : (tensor<4x?x4xf32>, tensor<3xindex>, tensor<3xindex>, tensor<3xindex>) -> tensor<2x?x4xf32> + // CHECK: muli {{.*}} : index + // CHECK: muli {{.*}} : index + // CHECK: tensor.from_elements {{.*}} : tensor<3xindex> + // CHECK: tensor.from_elements {{.*}} : tensor<3xindex> + // CHECK: tensor.from_elements {{.*}} : tensor<3xindex> + // CHECK: mhlo.real_dynamic_slice {{.*}} : (tensor<4x?x4xf32>, tensor<3xindex>, tensor<3xindex>, tensor<3xindex>) -> tensor<2x?x4xf32> + %0:2 = "tf.Split"(%cst, %input) : (tensor, tensor<4x?x4xf32>) -> (tensor<2x?x4xf32>, tensor<2x?x4xf32>) + func.return %0#0, %0#1 : tensor<2x?x4xf32>, tensor<2x?x4xf32> +} + +// ----- + +// CHECK-LABEL: @split_match_and_split_into_two +func.func @split_match_and_split_into_two(%input: tensor<4x6xf32>) -> (tensor<2x6xf32>, tensor<2x6xf32>) { + %cst = "tf.Const"() {value = dense<0> : tensor} : () -> tensor + // CHECK: %[[ONE:.*]] = "mhlo.slice"(%{{.*}}) <{limit_indices = dense<[2, 6]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}> : (tensor<4x6xf32>) -> tensor<2x6xf32> + // CHECK: %[[TWO:.*]] = "mhlo.slice"(%{{.*}}) <{limit_indices = dense<[4, 6]> : tensor<2xi64>, start_indices = dense<[2, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}> : (tensor<4x6xf32>) -> tensor<2x6xf32> + %0:2 = "tf.Split"(%cst, %input) : (tensor, tensor<4x6xf32>) -> (tensor<2x6xf32>, tensor<2x6xf32>) + // CHECK: return %[[ONE]], %[[TWO]] + func.return %0#0, %0#1 : tensor<2x6xf32>, tensor<2x6xf32> +} + +// ----- + +// CHECK-LABEL: @split_match_and_split_into_three +// CHECK-SAME: (%[[ARG:.*]]: tensor<4x6xf32>) +func.func @split_match_and_split_into_three(%input: tensor<4x6xf32>) -> (tensor<4x2xf32>, tensor<4x2xf32>, tensor<4x2xf32>) { + %cst = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + // CHECK: %[[ONE:.*]] = "mhlo.slice"(%[[ARG]]) <{limit_indices = dense<[4, 2]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}> : (tensor<4x6xf32>) -> tensor<4x2xf32> + // CHECK: %[[TWO:.*]] = "mhlo.slice"(%[[ARG]]) <{limit_indices = dense<4> : tensor<2xi64>, start_indices = dense<[0, 2]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}> : (tensor<4x6xf32>) -> tensor<4x2xf32> + // CHECK: %[[THREE:.*]] = "mhlo.slice"(%[[ARG]]) <{limit_indices = dense<[4, 6]> : tensor<2xi64>, start_indices = dense<[0, 4]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}> : (tensor<4x6xf32>) -> tensor<4x2xf32> + %0:3 = "tf.Split"(%cst, %input) : (tensor, tensor<4x6xf32>) -> (tensor<4x2xf32>, tensor<4x2xf32>, tensor<4x2xf32>) + // CHECK: return %[[ONE]], %[[TWO]], %[[THREE]] + func.return %0#0, %0#1, %0#2 : tensor<4x2xf32>, tensor<4x2xf32>, tensor<4x2xf32> +} + +//===----------------------------------------------------------------------===// +// tf.TopKV2 legalization +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: topk_v2_non_const_k +func.func @topk_v2_non_const_k(%input: tensor<16xf32>, %k: tensor) -> (tensor, tensor) { + // CHECK: tf.TopKV2 + %0:2 = "tf.TopKV2"(%input, %k): (tensor<16xf32>, tensor) -> (tensor, tensor) + func.return %0#0, %0#1: tensor, tensor +} + +// ----- + +// CHECK-LABEL: topk_v2_unknown_input_last_dim +func.func @topk_v2_unknown_input_last_dim(%input: tensor<16x?xf32>) -> (tensor<16x?xf32>, tensor<16x?xi32>) { + %k = "tf.Const"() {value = dense<8> : tensor} : () -> tensor + // CHECK: tf.TopKV2 + %0:2 = "tf.TopKV2"(%input, %k): (tensor<16x?xf32>, tensor) -> (tensor<16x?xf32>, tensor<16x?xi32>) + func.return %0#0, %0#1: tensor<16x?xf32>, tensor<16x?xi32> +} + +// ----- + +// CHECK-LABEL: topk_v2 +// CHECK-SAME: %[[INPUT:.*]]: tensor<16x16xf32> +func.func @topk_v2(%input: tensor<16x16xf32>) -> (tensor<16x8xf32>, tensor<16x8xi32>) { + %k = "tf.Const"() {value = dense<8> : tensor} : () -> tensor + + // CHECK: chlo.top_k(%[[INPUT]], k = 8) + %0:2 = "tf.TopKV2"(%input, %k): (tensor<16x16xf32>, tensor) -> (tensor<16x8xf32>, tensor<16x8xi32>) + func.return %0#0, %0#1: tensor<16x8xf32>, tensor<16x8xi32> +} + +//===----------------------------------------------------------------------===// +// tf.SplitV legalization +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: @splitv_match_and_split_into_three +// CHECK-SAME: (%[[ARG:.*]]: tensor<4x6xf32>) +func.func @splitv_match_and_split_into_three(%input: tensor<4x6xf32>) -> (tensor<4x1xf32>, tensor<4x2xf32>, tensor<4x3xf32>) { + %split_sizes = "tf.Const"() {value = dense<[1, 2, 3]> : tensor<3xi32>} : () -> tensor<3xi32> + %split_dim = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + // CHECK: %[[ONE:.*]] = "mhlo.slice"(%[[ARG]]) <{limit_indices = dense<[4, 1]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}> : (tensor<4x6xf32>) -> tensor<4x1xf32> + // CHECK: %[[TWO:.*]] = "mhlo.slice"(%[[ARG]]) <{limit_indices = dense<[4, 3]> : tensor<2xi64>, start_indices = dense<[0, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}> : (tensor<4x6xf32>) -> tensor<4x2xf32> + // CHECK: %[[THREE:.*]] = "mhlo.slice"(%[[ARG]]) <{limit_indices = dense<[4, 6]> : tensor<2xi64>, start_indices = dense<[0, 3]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}> : (tensor<4x6xf32>) -> tensor<4x3xf32> + %0:3 = "tf.SplitV"(%input, %split_sizes, %split_dim) : (tensor<4x6xf32>, tensor<3xi32>, tensor) -> (tensor<4x1xf32>, tensor<4x2xf32>, tensor<4x3xf32>) + // CHECK: return %[[ONE]], %[[TWO]], %[[THREE]] + func.return %0#0, %0#1, %0#2 : tensor<4x1xf32>, tensor<4x2xf32>, tensor<4x3xf32> +} + +// ----- + +// CHECK-LABEL: @splitv_dynamic_dim_in_split_sizes +func.func @splitv_dynamic_dim_in_split_sizes(%input: tensor<4x6xf32>) -> (tensor<4x1xf32>, tensor<4x2xf32>, tensor<4x3xf32>) { + %split_sizes = "tf.Const"() {value = dense<[1, -1, 3]> : tensor<3xi32>} : () -> tensor<3xi32> + %split_dim = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + // CHECK: limit_indices = dense<[4, 1]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64> + // CHECK: limit_indices = dense<[4, 3]> : tensor<2xi64>, start_indices = dense<[0, 1]> : tensor<2xi64> + // CHECK: limit_indices = dense<[4, 6]> : tensor<2xi64>, start_indices = dense<[0, 3]> : tensor<2xi64> + %0:3 = "tf.SplitV"(%input, %split_sizes, %split_dim) : (tensor<4x6xf32>, tensor<3xi32>, tensor) -> (tensor<4x1xf32>, tensor<4x2xf32>, tensor<4x3xf32>) + func.return %0#0, %0#1, %0#2 : tensor<4x1xf32>, tensor<4x2xf32>, tensor<4x3xf32> +} + +// ----- + +// CHECK-LABEL: @splitv_dynamic +func.func @splitv_dynamic(%input: tensor) -> (tensor, tensor, tensor) { + %split_sizes = "tf.Const"() {value = dense<[1, 2, 3]> : tensor<3xi32>} : () -> tensor<3xi32> + %split_dim = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + // CHECK: tf.SplitV + %0:3 = "tf.SplitV"(%input, %split_sizes, %split_dim) : (tensor, tensor<3xi32>, tensor) -> (tensor, tensor, tensor) + func.return %0#0, %0#1, %0#2 : tensor, tensor, tensor +} + +//===----------------------------------------------------------------------===// +// tf.Assert legalization +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: @assert +func.func @assert(%arg0: tensor, %arg1: tensor<*xf32>) { + // CHECK-NOT: tf.Assert + "tf.Assert"(%arg0, %arg1) {summarize = 1} : (tensor, tensor<*xf32>) -> () + func.return +} + +//===----------------------------------------------------------------------===// +// tf.Unpack legalization +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: @unpack +func.func @unpack(%input: tensor<4x3x6xf32>) -> (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) { + // CHECK: %[[SLICE1:.*]] = "mhlo.slice"(%{{.*}}) <{limit_indices = dense<[4, 1, 6]> : tensor<3xi64>, start_indices = dense<0> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>}> : (tensor<4x3x6xf32>) -> tensor<4x1x6xf32> + // CHECK: %[[RES1:.*]] = mhlo.reshape %[[SLICE1]] : (tensor<4x1x6xf32>) -> tensor<4x6xf32> + // CHECK: %[[SLICE2:.*]] = "mhlo.slice"(%{{.*}}) <{limit_indices = dense<[4, 2, 6]> : tensor<3xi64>, start_indices = dense<[0, 1, 0]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>}> : (tensor<4x3x6xf32>) -> tensor<4x1x6xf32> + // CHECK: %[[RES2:.*]] = mhlo.reshape %[[SLICE2]] : (tensor<4x1x6xf32>) -> tensor<4x6xf32> + // CHECK: %[[SLICE3:.*]] = "mhlo.slice"(%{{.*}}) <{limit_indices = dense<[4, 3, 6]> : tensor<3xi64>, start_indices = dense<[0, 2, 0]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>}> : (tensor<4x3x6xf32>) -> tensor<4x1x6xf32> + // CHECK: %[[RES3:.*]] = mhlo.reshape %[[SLICE3]] : (tensor<4x1x6xf32>) -> tensor<4x6xf32> + + %0:3 = "tf.Unpack"(%input) {axis = 1} : (tensor<4x3x6xf32>) -> (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) + // return %[[RES1]], %[[RES2]], %[[RES3]] + func.return %0#0, %0#1, %0#2 : tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32> +} + +// ----- + +// CHECK-LABEL: func @unpack_dynamic +func.func @unpack_dynamic(%arg0: tensor) -> (tensor, tensor) { + // CHECK: mhlo.real_dynamic_slice {{.*}} : (tensor, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor + // CHECK: tensor.from_elements {{.*}} : tensor<2xi32> + // CHECK: mhlo.dynamic_reshape {{.*}} : (tensor, tensor<2xi32>) -> tensor + // CHECK: tensor.from_elements {{.*}} : tensor<3xi32> + // CHECK: mhlo.real_dynamic_slice {{.*}} : (tensor, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor + // CHECK: tensor.from_elements {{.*}} : tensor<2xi32> + // CHECK: mhlo.dynamic_reshape {{.*}} : (tensor, tensor<2xi32>) -> tensor + // CHECK: return {{.*}} : tensor, tensor + %0:2 = "tf.Unpack"(%arg0) {axis = -1 : i64} : (tensor) -> (tensor, tensor) + func.return %0#0, %0#1 : tensor, tensor +} + +//===----------------------------------------------------------------------===// +// tf.UnsortedSegment{Max|Min|Prod|Sum} legalization +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: @unsorted_segment_sum +// CHECK-SAME: [[DATA:%.*]]: tensor<8x16x64xf32> +// CHECK-SAME: [[SI:%.*]]: tensor<8x16xi32> +func.func @unsorted_segment_sum(%data: tensor<8x16x64xf32>, %segment_ids : tensor<8x16xi32>) -> (tensor<4x64xf32>) { + %num_segments = "tf.Const"() {value = dense<4> : tensor} : () -> tensor + // CHECK: [[ZERO:%.*]] = mhlo.constant dense<0.000000e+00> : tensor + // CHECK: [[INIT:%.*]] = "mhlo.broadcast"([[ZERO]]) <{broadcast_sizes = dense<[4, 64]> : tensor<2xi64>}> : (tensor) -> tensor<4x64xf32> + // CHECK: [[SCATTER:%.*]] = "mhlo.scatter"([[INIT]], [[SI]], [[DATA]]) + // CHECK-SAME: indices_are_sorted = false, + // CHECK-SAME: scatter_dimension_numbers = + // CHECK-SAME: update_window_dims = [2] + // CHECK-SAME: inserted_window_dims = [0] + // CHECK-SAME: scatter_dims_to_operand_dims = [0] + // CHECK-SAME: index_vector_dim = 2 + // CHECK-SAME: unique_indices = false + // CHECK: ^{{.*}}([[LHS:%.*]]: tensor, [[RHS:%.*]]: tensor): + // CHECK: [[ADD:%.*]] = mhlo.add [[LHS]], [[RHS]] : tensor + // CHECK: mhlo.return [[ADD]] + // CHECK-NEXT: (tensor<4x64xf32>, tensor<8x16xi32>, tensor<8x16x64xf32>) -> tensor<4x64xf32> + // CHECK: return [[SCATTER]] + %0 = "tf.UnsortedSegmentSum"(%data, %segment_ids, %num_segments) : (tensor<8x16x64xf32>, tensor<8x16xi32>, tensor) -> (tensor<4x64xf32>) + func.return %0: tensor<4x64xf32> +} + +// ----- + +// CHECK-LABEL: @unsorted_segment_prod +// CHECK-SAME: [[DATA:%.*]]: tensor<8x?x64xf32> +// CHECK-SAME: [[SI:%.*]]: tensor +func.func @unsorted_segment_prod(%data: tensor<8x?x64xf32>, %segment_ids : tensor) -> (tensor<4x?xf32>) { + %num_segments = "tf.Const"() {value = dense<4> : tensor} : () -> tensor + // CHECK: [[ONE:%.*]] = mhlo.constant dense<1.000000e+00> : tensor + // CHECK: [[INIT:%.*]] = "mhlo.broadcast"([[ONE]]) <{broadcast_sizes = dense<[4, 64]> : tensor<2xi64>}> : (tensor) -> tensor<4x64xf32> + // CHECK: [[SCATTER:%.*]] = "mhlo.scatter"([[INIT]], [[SI]], [[DATA]]) + // CHECK-SAME: indices_are_sorted = false + // CHECK-SAME: scatter_dimension_numbers = + // CHECK-SAME: update_window_dims = [2] + // CHECK-SAME: inserted_window_dims = [0] + // CHECK-SAME: scatter_dims_to_operand_dims = [0] + // CHECK-SAME: index_vector_dim = 2 + // CHECK-SAME: unique_indices = false + // CHECK: ^{{.*}}([[LHS:%.*]]: tensor, [[RHS:%.*]]: tensor): + // CHECK: [[MUL:%.*]] = mhlo.multiply [[LHS]], [[RHS]] : tensor + // CHECK: mhlo.return [[MUL]] + // CHECK-NEXT: (tensor<4x64xf32>, tensor, tensor<8x?x64xf32>) -> tensor<4x?xf32> + // CHECK: return [[SCATTER]] + %0 = "tf.UnsortedSegmentProd"(%data, %segment_ids, %num_segments) : (tensor<8x?x64xf32>, tensor, tensor) -> (tensor<4x?xf32>) + func.return %0: tensor<4x?xf32> +} + +// ----- + +// CHECK-LABEL: @unsorted_segment_min +func.func @unsorted_segment_min(%data: tensor<8x?x64xf32>, %segment_ids : tensor) -> (tensor<4x?xf32>) { + %num_segments = "tf.Const"() {value = dense<4> : tensor} : () -> tensor + // CHECK: mhlo.constant dense<3.40282347E+38> : tensor + // CHECK: mhlo.scatter + // CHECK: mhlo.minimum + %0 = "tf.UnsortedSegmentMin"(%data, %segment_ids, %num_segments) : (tensor<8x?x64xf32>, tensor, tensor) -> (tensor<4x?xf32>) + func.return %0: tensor<4x?xf32> +} + +// ----- + +// CHECK-LABEL: @unsorted_segment_max +func.func @unsorted_segment_max(%data: tensor<8x?x64xf32>, %segment_ids : tensor) -> (tensor<4x?xf32>) { + %num_segments = "tf.Const"() {value = dense<4> : tensor} : () -> tensor + // CHECK: mhlo.constant dense<-3.40282347E+38> : tensor + // CHECK: mhlo.scatter + // CHECK: mhlo.maximum + %0 = "tf.UnsortedSegmentMax"(%data, %segment_ids, %num_segments) : (tensor<8x?x64xf32>, tensor, tensor) -> (tensor<4x?xf32>) + func.return %0: tensor<4x?xf32> +} + +//===----------------------------------------------------------------------===// +// tf.GatherNd legalization +//===----------------------------------------------------------------------===// +// CHECK-LABEL: func @gatherNd_dynamic +func.func @gatherNd_dynamic(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: tensor.dim + // CHECK: index_cast + // CHECK: tensor.from_elements + // CHECK: mhlo.dynamic_gather + // CHECK-SAME: dimension_numbers = + // CHECK-SAME: offset_dims = [2] + // CHECK-SAME: collapsed_slice_dims = [0, 1] + // CHECK-SAME: start_index_map = [0, 1] + // CHECK-SAME: index_vector_dim = 2 + // CHECK-SAME: indices_are_sorted = false + %0 = "tf.GatherNd"(%arg0, %arg1) {Tindices = i32, Tparams = i32, device = ""} : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// ----- + +// CHECK-LABEL: func @gatherNd_static +func.func @gatherNd_static(%arg0: tensor<2x4x128xf32>, %arg1: tensor<2x1xi32>) -> tensor<2x4x128xf32> { + // CHECK: "mhlo.gather"({{.*}}) <{ + // CHECK-SAME: dimension_numbers = + // CHECK-SAME: offset_dims = [1, 2] + // CHECK-SAME: collapsed_slice_dims = [0] + // CHECK-SAME: start_index_map = [0] + // CHECK-SAME: index_vector_dim = 1 + // CHECK-SAME: indices_are_sorted = false + // CHECK-SAME: slice_sizes = dense<[1, 4, 128]> + // CHECK-SAME: (tensor<2x4x128xf32>, tensor<2x1xi32>) -> tensor<2x4x128xf32> + %0 = "tf.GatherNd"(%arg0, %arg1) {Tindices = i32, Tparams = i32, device = ""} : (tensor<2x4x128xf32>, tensor<2x1xi32>) -> tensor<2x4x128xf32> + func.return %0 : tensor<2x4x128xf32> +} + +//===----------------------------------------------------------------------===// +// tf.GatherV2 legalization +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: @gather_v2 +// CHECK-SAME: %[[PARAMS:[a-zA-Z0-9_]+]] +// CHECK-SAME: %[[INDICES:[a-zA-Z0-9_]+]] +func.func @gather_v2(%params: tensor<16x2x3xf32>, %indices: tensor<16x5xi32>) -> tensor<16x2x5xf32> { + // CHECK: mhlo.torch_index_select + // CHECK-SAME: %[[PARAMS]], %[[INDICES]] + // CHECK-SAME: batch_dims = 1 + // CHECK-SAME: dim = 2 + %axis = "tf.Const"() { value = dense<[-1]> : tensor<1xi32> } : () -> tensor<1xi32> + %1 = "tf.GatherV2"(%params, %indices, %axis) {batch_dims = -1 : i64} : (tensor<16x2x3xf32>, tensor<16x5xi32>, tensor<1xi32>) -> tensor<16x2x5xf32> + func.return %1 : tensor<16x2x5xf32> +} + +// ----- + +// CHECK-LABEL: @gather_v2_dynamic +// CHECK-SAME: %[[PARAMS:[a-zA-Z0-9_]+]] +// CHECK-SAME: %[[INDICES:[a-zA-Z0-9_]+]] +func.func @gather_v2_dynamic(%params: tensor, %indices: tensor) -> tensor { + // CHECK: mhlo.torch_index_select + // CHECK-SAME: %[[PARAMS]], %[[INDICES]] + // CHECK-SAME: batch_dims = 1 + // CHECK-SAME: dim = 2 + %axis = "tf.Const"() { value = dense<[-1]> : tensor<1xi32> } : () -> tensor<1xi32> + %1 = "tf.GatherV2"(%params, %indices, %axis) {batch_dims = -1 : i64} : (tensor, tensor, tensor<1xi32>) -> tensor + func.return %1 : tensor +} + +// ----- + +// CHECK-LABEL: @gather_v2_dynamic_index_i64 +// CHECK-SAME: %[[PARAMS:[a-zA-Z0-9_]+]] +// CHECK-SAME: %[[INDICES:[a-zA-Z0-9_]+]] +func.func @gather_v2_dynamic_index_i64(%params: tensor, %indices: tensor) -> tensor { + // CHECK: mhlo.torch_index_select + // CHECK-SAME: %[[PARAMS]], %[[INDICES]] + // CHECK-SAME: batch_dims = 1 + // CHECK-SAME: dim = 2 + %axis = "tf.Const"() { value = dense<[-1]> : tensor<1xi32> } : () -> tensor<1xi32> + %1 = "tf.GatherV2"(%params, %indices, %axis) {batch_dims = -1 : i64} : (tensor, tensor, tensor<1xi32>) -> tensor + func.return %1 : tensor +} + +// ----- + +// CHECK-LABEL: @gather_v2_dynamic_shape +// CHECK-SAME: %[[PARAMS:[a-zA-Z0-9_]+]] +// CHECK-SAME: %[[INDICES:[a-zA-Z0-9_]+]] +func.func @gather_v2_dynamic_shape(%params: tensor, %indices: tensor) -> tensor { + // CHECK: mhlo.torch_index_select + // CHECK-SAME: %[[PARAMS]], %[[INDICES]] + // CHECK-SAME: batch_dims = 1 + // CHECK-SAME: dim = 2 + %axis = "tf.Const"() { value = dense<[-1]> : tensor<1xi32> } : () -> tensor<1xi32> + %1 = "tf.GatherV2"(%params, %indices, %axis) {batch_dims = -1 : i64} : (tensor, tensor, tensor<1xi32>) -> tensor + func.return %1 : tensor +} + +//===----------------------------------------------------------------------===// +// tf.StridedSliceGrad legalization +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: strided_slice_grad +// CHECK-SAME: [[GRAD:%.*]]: tensor<4x16x1022xf32> +func.func @strided_slice_grad(%grad: tensor<4x16x1022xf32>) -> tensor<4x128x1024xf32> { + + // For StridedSlice + // Dim #: 0, 1, 2 + // Input shape: [4, 128, 1024] + // Begin: 1, 4, -3 + // End: 8, 65, 42 + // Stride: 1, 4, -1 + // Begin mask: 1, 0, 0 (= 1) + // End mask: 0, 0, 1 (= 4) + + // So result shape: + // Dim #0: begin mask (1) -> begin = 0; end 8 canonicalized to 4: so 4 + // Dim #1: 4 to 65 stride 4: so 16 + // Dim #2: begin -3 + 1024 = 1021; end mask (1) -> end = -1: so 1022 + // result shape: [4, 16, 1022] + + // To pad back: + // Dim #: 0, 1, 2 + // Pad low: 0, 4, 0 + // Pad interm: 0, 3, 0 + // Pad high: 0, 63, 2 + + %shape = "tf.Const"() {value = dense<[4, 128, 1024]> : tensor<3xi32>} : () -> (tensor<3xi32>) + %begin = "tf.Const"() {value = dense<[1, 4, -3]> : tensor<3xi32>} : () -> (tensor<3xi32>) + %end = "tf.Const"() {value = dense<[8, 65, 42]> : tensor<3xi32>} : () -> (tensor<3xi32>) + %strides = "tf.Const"() {value = dense<[1, 4, -1]> : tensor<3xi32>} : () -> (tensor<3xi32>) + + // CHECK: [[RESHAPE:%.*]] = mhlo.reshape %arg0 : (tensor<4x16x1022xf32>) -> tensor<4x16x1022xf32> + // CHECK: [[REVERSE:%.*]] = "mhlo.reverse"([[RESHAPE]]) <{dimensions = dense<2> : tensor<1xi64>}> : (tensor<4x16x1022xf32>) -> tensor<4x16x1022xf32> + // CHECK: [[ZERO:%.*]] = mhlo.constant dense<0.000000e+00> : tensor + // CHECK: [[PAD:%.*]] = "mhlo.pad"([[REVERSE]], [[ZERO]]) <{edge_padding_high = dense<[0, 63, 2]> : tensor<3xi64>, edge_padding_low = dense<[0, 4, 0]> : tensor<3xi64>, interior_padding = dense<[0, 3, 0]> : tensor<3xi64>}> : (tensor<4x16x1022xf32>, tensor) -> tensor<4x128x1024xf32> + + %0 = "tf.StridedSliceGrad"(%shape, %begin, %end, %strides, %grad) {begin_mask = 1, end_mask = 4} : (tensor<3xi32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>, tensor<4x16x1022xf32>) -> tensor<4x128x1024xf32> + // CHECK: return [[PAD]] + func.return %0: tensor<4x128x1024xf32> +} + +// ----- + +// CHECK-LABEL: strided_slice_grad_shrink_axis_mask +// CHECK-SAME: [[GRAD:%.*]]: tensor<8xf32> +func.func @strided_slice_grad_shrink_axis_mask(%grad: tensor<8xf32>) -> tensor<4x8xf32> { + // Input to StridedSlice was of shape 4x8xf32 + // Strided slice gets input[2:3, 0:8] + // shrink_axis_mask is 1 denoting that dim#0 is shrunk. So the output is 8xf32 + // which is the shape of gradient. + // StridedSliceGrad would reshape the gradient to 1x8xf32 and + // then pad to match the shape of input 4x8xf32. + + %shape = "tf.Const"() {value = dense<[4, 8]> : tensor<2xi32>} : () -> (tensor<2xi32>) + %begin = "tf.Const"() {value = dense<[2, 0]> : tensor<2xi32>} : () -> (tensor<2xi32>) + %end = "tf.Const"() {value = dense<[3, 8]> : tensor<2xi32>} : () -> (tensor<2xi32>) + %strides = "tf.Const"() {value = dense<1> : tensor<2xi32>} : () -> (tensor<2xi32>) + + // CHECK: [[RESHAPE:%.*]] = mhlo.reshape [[GRAD]] : (tensor<8xf32>) -> tensor<1x8xf32> + // CHECK: [[ZEROS:%.*]] = mhlo.constant dense<0.000000e+00> : tensor + // CHECK: [[PAD:%.*]] = "mhlo.pad"([[RESHAPE]], [[ZEROS]]) + // CHECK-DAG-SAME: edge_padding_low = dense<[2, 0]> : tensor<2xi64> + // CHECK-DAG-SAME: edge_padding_high = dense<[1, 0]> : tensor<2xi64> + // CHECK-DAG-SAME: interior_padding = dense<0> : tensor<2xi64> + %0 = "tf.StridedSliceGrad"(%shape, %begin, %end, %strides, %grad) {begin_mask = 0, end_mask = 0, shrink_axis_mask = 1} : (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<8xf32>) -> tensor<4x8xf32> + + // CHECK: return [[PAD]] : tensor<4x8xf32> + func.return %0 : tensor<4x8xf32> +} + +// ----- + +// CHECK-LABEL: strided_slice_grad_new_axis_mask +// CHECK-SAME: [[GRAD:%.*]]: tensor<1x2xf32> +func.func @strided_slice_grad_new_axis_mask(%grad: tensor<1x2xf32>) -> tensor<8xf32> { + // Input to StridedSlice was of shape 8xf32 + // Strided slice gets input[tf.new_axis, 2:4] + // new_axis_mask is 1 denoting new axis is inserted at dim#0. So the output is + // 1x2xf32 which is the shape of gradient. + // StridedSliceGrad would reshape the gradient to 2xf32 and + // then pad to match the shape of input 4x8xf32. + + %shape = "tf.Const"() {value = dense<[8]> : tensor<1xi32>} : () -> (tensor<1xi32>) + %begin = "tf.Const"() {value = dense<[0, 2]> : tensor<2xi32>} : () -> (tensor<2xi32>) + %end = "tf.Const"() {value = dense<[0, 4]> : tensor<2xi32>} : () -> (tensor<2xi32>) + %strides = "tf.Const"() {value = dense<1> : tensor<2xi32>} : () -> (tensor<2xi32>) + + // CHECK: [[RESHAPE:%.*]] = mhlo.reshape [[GRAD]] : (tensor<1x2xf32>) -> tensor<2xf32> + // CHECK: [[ZEROS:%.*]] = mhlo.constant dense<0.000000e+00> : tensor + // CHECK: [[PAD:%.*]] = "mhlo.pad"([[RESHAPE]], [[ZEROS]]) + // CHECK-DAG-SAME: edge_padding_low = dense<2> : tensor<1xi64> + // CHECK-DAG-SAME: edge_padding_high = dense<4> : tensor<1xi64> + // CHECK-DAG-SAME: interior_padding = dense<0> : tensor<1xi64> + %0 = "tf.StridedSliceGrad"(%shape, %begin, %end, %strides, %grad) {begin_mask = 0, end_mask = 0, new_axis_mask = 1} : (tensor<1xi32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<1x2xf32>) -> tensor<8xf32> + + // CHECK: return [[PAD]] : tensor<8xf32> + func.return %0 : tensor<8xf32> +} + +// ----- + +// CHECK-LABEL: strided_slice_grad_ellipsis_mask +// CHECK-SAME: [[GRAD:%.*]]: tensor<2x4x8xf32> +func.func @strided_slice_grad_ellipsis_mask(%grad: tensor<2x4x8xf32>) -> tensor<4x4x8xf32> { + // Input to StridedSlice was of shape 4x4x8xf32 + // Strided slice gets input[2:4, ...] + // ellipsis_mask is 2 denoting that slice contains all elements in dim#1 and + // dim#2, ignoring begin and end indices for these dimensions. So the output + // is 2x4x8xf32 which is the shape of gradient. + // StridedSliceGrad would pad the gradient to match the shape of + // input 4x4x8xf32. + + %shape = "tf.Const"() {value = dense<[4, 4, 8]> : tensor<3xi32>} : () -> (tensor<3xi32>) + %begin = "tf.Const"() {value = dense<[2, 3]> : tensor<2xi32>} : () -> (tensor<2xi32>) + %end = "tf.Const"() {value = dense<[4, 5]> : tensor<2xi32>} : () -> (tensor<2xi32>) + %strides = "tf.Const"() {value = dense<1> : tensor<2xi32>} : () -> (tensor<2xi32>) + + // CHECK: [[RESHAPE:%.*]] = mhlo.reshape [[GRAD]] : (tensor<2x4x8xf32>) -> tensor<2x4x8xf32> + // CHECK: [[ZEROS:%.*]] = mhlo.constant dense<0.000000e+00> : tensor + // CHECK: [[PAD:%.*]] = "mhlo.pad"([[RESHAPE]], [[ZEROS]]) + // CHECK-DAG-SAME: edge_padding_low = dense<[2, 0, 0]> : tensor<3xi64> + // CHECK-DAG-SAME: edge_padding_high = dense<0> : tensor<3xi64> + // CHECK-DAG-SAME: interior_padding = dense<0> : tensor<3xi64> + %0 = "tf.StridedSliceGrad"(%shape, %begin, %end, %strides, %grad) {begin_mask = 0, end_mask = 0, ellipsis_mask = 2} : (tensor<3xi32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<2x4x8xf32>) -> tensor<4x4x8xf32> + + // CHECK: return [[PAD]] : tensor<4x4x8xf32> + func.return %0 : tensor<4x4x8xf32> +} + + +// CHECK-LABEL: strided_slice_grad_all_masks +// CHECK-SAME: [[GRAD:%.*]]: tensor<1x4x8x8x10x2x1xf32> +func.func @strided_slice_grad_all_masks(%grad: tensor<1x4x8x8x10x2x1xf32>) -> tensor<2x4x8x16x32x64xf32> { + // For StridedSlice input[1, tf.new_axis, ..., 8:, :10, 2:6:2, tf.new_axis] + // New axis mask is at index 1 and 6 of sparse spec, so + // new_axis_mask = 2^1 + 2^6 = 66 + // The ellipsis mask is applied to dim #1, #2 of input i.e, we get + // canonicalized slice input[1, :, :, 8:, :10, 2:6:2] + // The StridedSliceGrad op would propogate the gradient for the sliced tensor + // to the original input tensor by padding with zeroes. + + %shape = "tf.Const"() {value = dense<[2, 4, 8, 16, 32, 64]> : tensor<6xi32>} : () -> (tensor<6xi32>) + %begin = "tf.Const"() {value = dense<[1, 0, 0, 8, 1, 2, 0]> : tensor<7xi32>} : () -> (tensor<7xi32>) + %end = "tf.Const"() {value = dense<[2, 0, 0, 10, 10, 6, 0]> : tensor<7xi32>} : () -> (tensor<7xi32>) + %strides = "tf.Const"() {value = dense<[1, 1, 1, 1, 1, 2, 1]> : tensor<7xi32>} : () -> (tensor<7xi32>) + + // Remove 2 new axes (at index 1 and 6) and 1 shrink axis (at index 0) + // CHECK: [[RESHAPE:%.*]] = mhlo.reshape [[GRAD]] : (tensor<1x4x8x8x10x2x1xf32>) -> tensor<1x4x8x8x10x2xf32> + // CHECK: [[ZERO:%.*]] = mhlo.constant dense<0.000000e+00> : tensor + // The edge_padding_low, edge_padding_high and interior_padding attributes of + // mhlo.pad would reflect the padding required to get the shape of the + // input of StridedSlice op. + // CHECK: [[PAD:%.*]] = "mhlo.pad"([[RESHAPE]], [[ZERO]]) + // CHECK-DAG-SAME: edge_padding_low = dense<[1, 0, 0, 8, 0, 2]> : tensor<6xi64> + // CHECK-DAG-SAME: edge_padding_high = dense<[0, 0, 0, 0, 22, 59]> : tensor<6xi64> + // CHECK-DAG-SAME: interior_padding = dense<[0, 0, 0, 0, 0, 1]> : tensor<6xi64> + %0 = "tf.StridedSliceGrad"(%shape, %begin, %end, %strides, %grad) {begin_mask = 16, end_mask = 8, shrink_axis_mask = 1, ellipsis_mask = 4, new_axis_mask = 66} : (tensor<6xi32>, tensor<7xi32>, tensor<7xi32>, tensor<7xi32>, tensor<1x4x8x8x10x2x1xf32>) -> tensor<2x4x8x16x32x64xf32> + + // CHECK: return [[PAD]] : tensor<2x4x8x16x32x64xf32> + func.return %0 : tensor<2x4x8x16x32x64xf32> +} + +// ----- + +// CHECK-LABEL: @tensor_scatter_update +func.func @tensor_scatter_update(%tensor: tensor, %indices: tensor, %updates: tensor) -> tensor { + // CHECK: "mhlo.scatter"(%arg0, %arg1, %arg2) + // CHECK-SAME: indices_are_sorted = false + // CHECK-SAME: scatter_dimension_numbers + // CHECK-SAME: update_window_dims = [1] + // CHECK-SAME: inserted_window_dims = [0, 1] + // CHECK-SAME: scatter_dims_to_operand_dims = [0, 1] + // CHECK-SAME: index_vector_dim = 1 + // CHECK-SAME: unique_indices = false + // CHECK: ^bb0(%arg3: tensor, %arg4: tensor): + // CHECK: mhlo.return %arg4 : tensor + // CHECK: }) + %0 = "tf.TensorScatterUpdate"(%tensor, %indices, %updates) : (tensor, tensor, tensor) -> tensor + func.return %0 : tensor +} + +// ----- + +// CHECK-LABEL: @tensor_scatter_update_scalar_update +func.func @tensor_scatter_update_scalar_update(%tensor: tensor<4x3xi32>, %indices: tensor<2x1xi32>, %updates: tensor) -> tensor<4x3xi32> { + // CHECK: mhlo.constant dense<[2, 3]> : tensor<2xi64> + // CHECK: "mhlo.dynamic_broadcast_in_dim"(%arg2, %0) <{broadcast_dimensions = dense<> : tensor<0xi64>}> : (tensor, tensor<2xi64>) -> tensor<2x3xi32> + // CHECK: "mhlo.scatter" + %0 = "tf.TensorScatterUpdate"(%tensor, %indices, %updates) : (tensor<4x3xi32>, tensor<2x1xi32>, tensor) -> tensor<4x3xi32> + func.return %0 : tensor<4x3xi32> +} + +// ----- + +// CHECK-LABEL: @tensor_scatter_add +func.func @tensor_scatter_add(%tensor: tensor, %indices: tensor, %updates: tensor) -> tensor { + // CHECK: "mhlo.scatter"(%arg0, %arg1, %arg2) + // CHECK-SAME: indices_are_sorted = false + // CHECK-SAME: scatter_dimension_numbers + // CHECK-SAME: update_window_dims = [1] + // CHECK-SAME: inserted_window_dims = [0, 1] + // CHECK-SAME: scatter_dims_to_operand_dims = [0, 1] + // CHECK-SAME: index_vector_dim = 1 + // CHECK-SAME: unique_indices = false + // CHECK: ^bb0(%arg3: tensor, %arg4: tensor): + // CHECK: %1 = mhlo.add %arg3, %arg4 : tensor + // CHECK: mhlo.return %1 : tensor + // CHECK: }) + %0 = "tf.TensorScatterAdd"(%tensor, %indices, %updates) : (tensor, tensor, tensor) -> tensor + func.return %0 : tensor +} + +// ----- + +// CHECK-LABEL: @tensor_scatter_add_scalar_update +func.func @tensor_scatter_add_scalar_update(%tensor: tensor<4x3xi32>, %indices: tensor<2x1xi32>, %updates: tensor) -> tensor<4x3xi32> { + // CHECK: mhlo.constant dense<[2, 3]> : tensor<2xi64> + // CHECK: "mhlo.dynamic_broadcast_in_dim"(%arg2, %0) <{broadcast_dimensions = dense<> : tensor<0xi64>}> : (tensor, tensor<2xi64>) -> tensor<2x3xi32> + // CHECK: "mhlo.scatter + %0 = "tf.TensorScatterAdd"(%tensor, %indices, %updates) : (tensor<4x3xi32>, tensor<2x1xi32>, tensor) -> tensor<4x3xi32> + func.return %0 : tensor<4x3xi32> +} + +// ----- + +// CHECK-LABEL: @tensor_scatter_sub +func.func @tensor_scatter_sub(%tensor: tensor, %indices: tensor, %updates: tensor) -> tensor { + // CHECK: "mhlo.scatter"(%arg0, %arg1, %arg2) + // CHECK-SAME: indices_are_sorted = false + // CHECK-SAME: scatter_dimension_numbers + // CHECK-SAME: update_window_dims = [1] + // CHECK-SAME: inserted_window_dims = [0, 1] + // CHECK-SAME: scatter_dims_to_operand_dims = [0, 1] + // CHECK-SAME: index_vector_dim = 1 + // CHECK-SAME: unique_indices = false + // CHECK: ^bb0(%arg3: tensor, %arg4: tensor): + // CHECK: %1 = mhlo.subtract %arg3, %arg4 : tensor + // CHECK: mhlo.return %1 : tensor + // CHECK: }) + %0 = "tf.TensorScatterSub"(%tensor, %indices, %updates) : (tensor, tensor, tensor) -> tensor + func.return %0 : tensor +} + +// ----- + +// CHECK-LABEL: @tensor_scatter_min +func.func @tensor_scatter_min(%tensor: tensor, %indices: tensor, %updates: tensor) -> tensor { + // CHECK: "mhlo.scatter"(%arg0, %arg1, %arg2) + // CHECK-SAME: indices_are_sorted = false + // CHECK-SAME: scatter_dimension_numbers + // CHECK-SAME: update_window_dims = [1] + // CHECK-SAME: inserted_window_dims = [0, 1] + // CHECK-SAME: scatter_dims_to_operand_dims = [0, 1] + // CHECK-SAME: index_vector_dim = 1 + // CHECK-SAME: unique_indices = false + // CHECK: ^bb0(%arg3: tensor, %arg4: tensor): + // CHECK: %1 = mhlo.minimum %arg3, %arg4 : tensor + // CHECK: mhlo.return %1 : tensor + // CHECK: }) + %0 = "tf.TensorScatterMin"(%tensor, %indices, %updates) : (tensor, tensor, tensor) -> tensor + func.return %0 : tensor +} + +// ----- + +// CHECK-LABEL: @tensor_scatter_max +func.func @tensor_scatter_max(%tensor: tensor, %indices: tensor, %updates: tensor) -> tensor { + // CHECK: "mhlo.scatter"(%arg0, %arg1, %arg2) + // CHECK-SAME: indices_are_sorted = false + // CHECK-SAME: scatter_dimension_numbers + // CHECK-SAME: update_window_dims = [1] + // CHECK-SAME: inserted_window_dims = [0, 1] + // CHECK-SAME: scatter_dims_to_operand_dims = [0, 1] + // CHECK-SAME: index_vector_dim = 1 + // CHECK-SAME: unique_indices = false + // CHECK: ^bb0(%arg3: tensor, %arg4: tensor): + // CHECK: %1 = mhlo.maximum %arg3, %arg4 : tensor + // CHECK: mhlo.return %1 : tensor + // CHECK: }) + %0 = "tf.TensorScatterMax"(%tensor, %indices, %updates) : (tensor, tensor, tensor) -> tensor + func.return %0 : tensor +} + +//===----------------------------------------------------------------------===// +// tf.RandomShuffle legalization +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: @random_shuffle_num_elems_le_1 +func.func @random_shuffle_num_elems_le_1() -> tensor { + // CHECK: [[INPUT:%.*]] = mhlo.constant dense<1.000000e+20> : tensor + // CHECK-NEXT: return [[INPUT]] + %cst = "tf.Const"() {value = dense<1.000000e+20> : tensor} : () -> tensor + %0 = "tf.RandomShuffle"(%cst) {device = "", seed = -4294967297 : i64, seed2 = -2147483649 : i64} : (tensor) -> tensor + return %0 : tensor +} + +// ----- + +// CHECK-LABEL: @random_shuffle_first_dim_1 +// CHECK-SAME: [[INPUT:%.*]]: tensor<1x?xf32> +func.func @random_shuffle_first_dim_1(%input: tensor<1x?xf32>) -> tensor<1x?xf32> { + %0 = "tf.RandomShuffle"(%input) : (tensor<1x?xf32>) -> (tensor<1x?xf32>) + // CHECK-NEXT: return [[INPUT]] + func.return %0: tensor<1x?xf32> +} + +// ----- + +// CHECK-LABEL: @random_shuffle_1D_16 +// CHECK-SAME: [[INPUT:%.*]]: tensor<16xf32> +func.func @random_shuffle_1D_16(%input: tensor<16xf32>) -> tensor<16xf32> { + // CHECK-DAG: [[SHAPE:%.*]] = mhlo.constant dense<16> : tensor<1xi64> + // CHECK-DAG: [[LOWER:%.*]] = mhlo.constant dense<0> : tensor + // CHECK-DAG: [[UPPER:%.*]] = mhlo.constant dense<-1> : tensor + // CHECK: [[RNG:%.*]] = "mhlo.rng"([[LOWER]], [[UPPER]], [[SHAPE]]) <{rng_distribution = #mhlo.rng_distribution}> + // CHECK: [[SORT:%.*]]:2 = "mhlo.sort"([[RNG]], [[INPUT]]) <{dimension = -1 : i64, is_stable = {{.*}}}> ({ + // CHECK: ^{{.*}}([[ARG1:%.*]]: tensor, [[ARG2:%.*]]: tensor, {{.*}}: tensor, {{.*}}: tensor): + // CHECK: mhlo.compare LT, [[ARG1]], [[ARG2]], TOTALORDER + // CHECK: }) : (tensor<16xi32>, tensor<16xf32>) -> (tensor<16xi32>, tensor<16xf32>) + // CHECK: return [[SORT]]#1 + %0 = "tf.RandomShuffle"(%input) : (tensor<16xf32>) -> (tensor<16xf32>) + func.return %0: tensor<16xf32> +} + +// ----- + +// CHECK-LABEL: @random_shuffle_1D_10240 +func.func @random_shuffle_1D_10240(%input: tensor<10240xf32>) -> tensor<10240xf32> { + // CHECK: mhlo.rng{{.*UNIFORM.*}} + // CHECK: mhlo.sort + // CHECK: mhlo.rng{{.*UNIFORM.*}} + // CHECK: mhlo.sort + %0 = "tf.RandomShuffle"(%input) : (tensor<10240xf32>) -> (tensor<10240xf32>) + func.return %0: tensor<10240xf32> +} + +// ----- + +// CHECK-LABEL: @random_shuffle_3D +// CHECK-SAME: [[INPUT:%.*]]: tensor<4x?x16xf32> +func.func @random_shuffle_3D(%input: tensor<4x?x16xf32>) -> tensor<4x?x16xf32> { + // CHECK: [[INDICES:%.*]] = "mhlo.iota"() <{iota_dimension = 0 : i64}> : () -> tensor<4xi32> + + // CHECK-DAG: [[RNG_SHAPE:%.*]] = mhlo.constant dense<4> : tensor<1xi64> + // CHECK-DAG: [[RNG_LOWER:%.*]] = mhlo.constant dense<0> : tensor + // CHECK-DAG: [[RNG_UPPER:%.*]] = mhlo.constant dense<4> : tensor + // CHECK: [[SWAPS:%.*]] = "mhlo.rng"([[RNG_LOWER]], [[RNG_UPPER]], [[RNG_SHAPE]]) <{rng_distribution = #mhlo.rng_distribution}> + + // CHECK: [[IV_INIT:%.*]] = mhlo.constant dense<0> : tensor + + // CHECK: [[WHILE_OUT:%.*]]:3 = mhlo.while([[ITER_ARG0:.*]] = [[IV_INIT]], [[ITER_ARG1:.*]] = [[SWAPS]], [[ITER_ARG2:.*]] = [[INDICES]]) + // CHECK: [[LIMIT:%.*]] = mhlo.constant dense<4> : tensor + // CHECK: [[CMP:%.*]] = mhlo.compare LT, [[ITER_ARG0]], [[LIMIT]], NOTYPE + // CHECK: mhlo.return [[CMP]] + // CHECK: } do { + // CHECK: [[SRC_IDX:%.*]] = "mhlo.dynamic_slice"([[ITER_ARG2]], [[ITER_ARG0]]) <{slice_sizes = dense<1> : tensor<1xi64>}> : (tensor<4xi32>, tensor) -> tensor<1xi32> + // CHECK: [[SWP_IDX:%.*]] = "mhlo.dynamic_slice"([[ITER_ARG1]], [[ITER_ARG0]]) <{slice_sizes = dense<1> : tensor<1xi64>}> : (tensor<4xi32>, tensor) -> tensor<1xi32> + // CHECK: [[SWP:%.*]] = mhlo.reshape [[SWP_IDX]] : (tensor<1xi32>) -> tensor + // CHECK: [[TGT_IDX:%.*]] = "mhlo.dynamic_slice"([[ITER_ARG2]], [[SWP]]) <{slice_sizes = dense<1> : tensor<1xi64>}> + // CHECK: [[INDICES1:%.*]] = mhlo.dynamic_update_slice [[ITER_ARG2]], [[TGT_IDX]], [[ITER_ARG0]] : (tensor<4xi32>, tensor<1xi32>, tensor) -> tensor<4xi32> + // CHECK: [[INDICES2:%.*]] = mhlo.dynamic_update_slice [[INDICES1]], [[SRC_IDX]], [[SWP]] : (tensor<4xi32>, tensor<1xi32>, tensor) -> tensor<4xi32> + // CHECK: [[ONE:%.*]] = mhlo.constant dense<1> : tensor + // CHECK: [[NEW_IV:%.*]] = chlo.broadcast_add [[ITER_ARG0]], [[ONE]] + // CHECK: mhlo.return [[NEW_IV]], [[ITER_ARG1]], [[INDICES2]] + // CHECK: } + + // CHECK: [[CONSTANT1:%.*]] = mhlo.constant dense<1> : tensor<1xi64> + // CHECK: [[ARITH_CONSTANT:%.*]] = arith.constant 1 : index + // CHECK: [[SHAPE_DIM:%.*]] = shape.dim %arg0, [[ARITH_CONSTANT]] : tensor<4x?x16xf32>, index -> index + // CHECK: [[INDEX_CAST:%.*]] = arith.index_cast [[SHAPE_DIM]] : index to i64 + // CHECK: [[FROM_ELEMENTS:%.*]] = tensor.from_elements [[INDEX_CAST]] : tensor<1xi64> + // CHECK: [[CONSTANT2:%.*]] = mhlo.constant dense<16> : tensor<1xi64> + // CHECK: [[CONCATENATE:%.*]] = "mhlo.concatenate"([[CONSTANT1]], [[FROM_ELEMENTS]], [[CONSTANT2]]) <{dimension = 0 : i64}> : (tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<3xi64> + // CHECK: [[DYNAMIC_GATHER:%.*]] = "mhlo.dynamic_gather"([[INPUT]], [[WHILE_OUT]]#2, [[CONCATENATE]]) + // CHECK-SAME: dimension_numbers = + // CHECK-SAME: offset_dims = [1, 2] + // CHECK-SAME: collapsed_slice_dims = [0] + // CHECK-SAME: start_index_map = [0] + // CHECK-SAME: index_vector_dim = 1 + // CHECK-SAME: indices_are_sorted = false + // CHECK-SAME:: (tensor<4x?x16xf32>, tensor<4xi32>, tensor<3xi64>) -> tensor<4x?x16xf32> + + // CHECK: return [[DYNAMIC_GATHER]] + + %0 = "tf.RandomShuffle"(%input) : (tensor<4x?x16xf32>) -> (tensor<4x?x16xf32>) + func.return %0: tensor<4x?x16xf32> +} + +//===----------------------------------------------------------------------===// +// tf.AvgPool legalization +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: @avgpool_valid_padding +// CHECK-SAME: [[ARG:%.+]]: tensor<2x12x21x7xf16> +// CHECK: [[CONV32:%.+]] = mhlo.convert %arg0 : (tensor<2x12x21x7xf16>) -> tensor<2x12x21x7xf32> +// CHECK: [[ZERO:%.+]] = mhlo.constant dense<0.000000e+00> : tensor +// CHECK: [[DIVIDEND:%.+]] = "mhlo.reduce_window"([[CONV32]], [[ZERO]]) +// CHECK-SAME: window_dimensions = dense<[1, 2, 2, 1]> +// CHECK-SAME: window_strides = dense<[1, 4, 4, 1]> +// CHECK: ^bb0([[ARG1:%.+]]: tensor, [[ARG2:%.+]]: tensor): +// CHECK: [[ADD:%.+]] = mhlo.add [[ARG1]], [[ARG2]] +// CHECK: mhlo.return [[ADD]] +// CHECK: }) +// CHECK-SAME: -> tensor<2x3x5x7xf32> +// CHECK: [[COUNT:%.+]] = mhlo.constant dense<4.000000e+00> : tensor +// CHECK: [[DIV_RESULT:%.+]] = chlo.broadcast_divide [[DIVIDEND]], [[COUNT]] +// CHECK-SAME: broadcast_dimensions = array +// CHECK-SAME: -> tensor<2x3x5x7xf32> +// CHECK: [[CONV16:%.+]] = mhlo.convert [[DIV_RESULT]] +// CHECK-SAME: -> tensor<2x3x5x7xf16> +// CHECK: return [[CONV16]] +func.func @avgpool_valid_padding(%arg0: tensor<2x12x21x7xf16>) -> tensor<2x3x5x7xf16> { + %0 = "tf.AvgPool"(%arg0) {data_format = "NHWC", ksize = [1, 2, 2, 1], padding = "VALID", strides = [1, 4, 4, 1]} : (tensor<2x12x21x7xf16>) -> tensor<2x3x5x7xf16> + func.return %0 : tensor<2x3x5x7xf16> +} + +// ----- + +// CHECK-LABEL: @avgpool_3d_valid_padding +// CHECK-SAME: [[ARG:%.+]]: tensor<2x4x12x21x7xf16> +// CHECK: [[CONV32:%.+]] = mhlo.convert %arg0 : (tensor<2x4x12x21x7xf16>) -> tensor<2x4x12x21x7xf32> +// CHECK: [[ZERO:%.+]] = mhlo.constant dense<0.000000e+00> : tensor +// CHECK: [[DIVIDEND:%.+]] = "mhlo.reduce_window"([[CONV32]], [[ZERO]]) +// CHECK-SAME: window_dimensions = dense<[1, 1, 2, 2, 1]> +// CHECK-SAME: window_strides = dense<[1, 1, 4, 4, 1]> +// CHECK: ^bb0([[ARG1:%.+]]: tensor, [[ARG2:%.+]]: tensor): +// CHECK: [[ADD:%.+]] = mhlo.add [[ARG1]], [[ARG2]] +// CHECK: mhlo.return [[ADD]] +// CHECK: }) +// CHECK-SAME: -> tensor<2x4x3x5x7xf32> +// CHECK: [[COUNT:%.+]] = mhlo.constant dense<4.000000e+00> : tensor +// CHECK: [[DIV_RESULT:%.+]] = chlo.broadcast_divide [[DIVIDEND]], [[COUNT]] +// CHECK-SAME: broadcast_dimensions = array +// CHECK-SAME: -> tensor<2x4x3x5x7xf32> +// CHECK: [[CONV16:%.+]] = mhlo.convert [[DIV_RESULT]] +// CHECK-SAME: -> tensor<2x4x3x5x7xf16> +// CHECK: return [[CONV16]] +func.func @avgpool_3d_valid_padding(%arg0: tensor<2x4x12x21x7xf16>) -> tensor<2x4x3x5x7xf16> { + %0 = "tf.AvgPool3D"(%arg0) {data_format = "NDHWC", ksize = [1, 1, 2, 2, 1], padding = "VALID", strides = [1, 1, 4, 4, 1]} : (tensor<2x4x12x21x7xf16>) -> tensor<2x4x3x5x7xf16> + func.return %0 : tensor<2x4x3x5x7xf16> +} + +// ----- + +// CHECK-LABEL: @avgpool_nchw_format +// CHECK-SAME: [[ARG:%.+]]: tensor<2x7x12x21xf16> +// CHECK: [[CONV32:%.+]] = mhlo.convert %arg0 : (tensor<2x7x12x21xf16>) -> tensor<2x7x12x21xf32> +// CHECK: [[ZERO:%.+]] = mhlo.constant dense<0.000000e+00> : tensor +// CHECK: [[DIVIDEND:%.+]] = "mhlo.reduce_window"([[CONV32]], [[ZERO]]) +// CHECK-SAME: window_dimensions = dense<[1, 1, 2, 2]> +// CHECK-SAME: window_strides = dense<[1, 1, 4, 4]> +// CHECK: ^bb0([[ARG1:%.+]]: tensor, [[ARG2:%.+]]: tensor): +// CHECK: [[ADD:%.+]] = mhlo.add [[ARG1]], [[ARG2]] +// CHECK: mhlo.return [[ADD]] +// CHECK: }) +// CHECK-SAME: -> tensor<2x7x3x5xf32> +// CHECK: [[COUNT:%.+]] = mhlo.constant dense<4.000000e+00> : tensor +// CHECK: [[DIV_RESULT:%.+]] = chlo.broadcast_divide [[DIVIDEND]], [[COUNT]] +// CHECK-SAME: broadcast_dimensions = array +// CHECK-SAME: -> tensor<2x7x3x5xf32> +// CHECK: [[CONV16:%.+]] = mhlo.convert [[DIV_RESULT]] +// CHECK-SAME: -> tensor<2x7x3x5xf16> +// CHECK: return [[CONV16]] +func.func @avgpool_nchw_format(%arg0: tensor<2x7x12x21xf16>) -> tensor<2x7x3x5xf16> { + %0 = "tf.AvgPool"(%arg0) {data_format = "NCHW", ksize = [1, 1, 2, 2], padding = "VALID", strides = [1, 1, 4, 4]} : (tensor<2x7x12x21xf16>) -> tensor<2x7x3x5xf16> + func.return %0 : tensor<2x7x3x5xf16> +} + +// ----- + +// CHECK-LABEL: @avgpool_3d_ncdhw_format +// CHECK-SAME: [[ARG:%.+]]: tensor<2x7x4x12x21xf16> +// CHECK: [[CONV32:%.+]] = mhlo.convert %arg0 : (tensor<2x7x4x12x21xf16>) -> tensor<2x7x4x12x21xf32> +// CHECK: [[ZERO:%.+]] = mhlo.constant dense<0.000000e+00> : tensor +// CHECK: [[DIVIDEND:%.+]] = "mhlo.reduce_window"([[CONV32]], [[ZERO]]) +// CHECK-SAME: window_dimensions = dense<[1, 1, 1, 2, 2]> +// CHECK-SAME: window_strides = dense<[1, 1, 1, 4, 4]> +// CHECK: ^bb0([[ARG1:%.+]]: tensor, [[ARG2:%.+]]: tensor): +// CHECK: [[ADD:%.+]] = mhlo.add [[ARG1]], [[ARG2]] +// CHECK: mhlo.return [[ADD]] +// CHECK: }) +// CHECK-SAME: -> tensor<2x7x4x3x5xf32> +// CHECK: [[COUNT:%.+]] = mhlo.constant dense<4.000000e+00> : tensor +// CHECK: [[DIV_RESULT:%.+]] = chlo.broadcast_divide [[DIVIDEND]], [[COUNT]] +// CHECK-SAME: broadcast_dimensions = array +// CHECK-SAME: -> tensor<2x7x4x3x5xf32> +// CHECK: [[CONV16:%.+]] = mhlo.convert [[DIV_RESULT]] +// CHECK-SAME: -> tensor<2x7x4x3x5xf16> +// CHECK: return [[CONV16]] +func.func @avgpool_3d_ncdhw_format(%arg0: tensor<2x7x4x12x21xf16>) -> tensor<2x7x4x3x5xf16> { + %0 = "tf.AvgPool3D"(%arg0) {data_format = "NCDHW", ksize = [1, 1, 1, 2, 2], padding = "VALID", strides = [1, 1, 1, 4, 4]} : (tensor<2x7x4x12x21xf16>) -> tensor<2x7x4x3x5xf16> + func.return %0 : tensor<2x7x4x3x5xf16> +} + +// ----- + +// CHECK-LABEL: @avgpool_same_padding( +// CHECK-SAME: %[[ARG0:.*]]: tensor<2x12x21x7xf32>) -> tensor<2x4x6x7xf32> +// CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor +// CHECK: %[[DIVIDEND:.*]] = "mhlo.reduce_window"(%[[ARG0]], %[[ZERO]]) +// CHECK-SAME: padding = dense<{{\[\[}}0, 0], [1, 1], [0, 1], [0, 0]]> +// CHECK-SAME: window_dimensions = dense<[1, 5, 2, 1]> +// CHECK-SAME: window_strides = dense<[1, 3, 4, 1]> +// CHECK: ^bb0(%[[ARG1:.*]]: tensor, %[[ARG2:.*]]: tensor): +// CHECK: %[[SUM1:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor +// CHECK: mhlo.return %[[SUM1]] : tensor +// CHECK: }) +// CHECK-SAME: -> tensor<2x4x6x7xf32> +// CHECK: %[[ONES:.*]] = mhlo.constant dense<1.000000e+00> : tensor<2x12x21x7xf32> +// CHECK: %[[DIVISOR:.*]] = "mhlo.reduce_window"(%[[ONES]], %[[ZERO]]) +// CHECK-SAME: padding = dense<{{\[\[}}0, 0], [1, 1], [0, 1], [0, 0]]> +// CHECK-SAME: window_dimensions = dense<[1, 5, 2, 1]> +// CHECK-SAME: window_strides = dense<[1, 3, 4, 1]> +// CHECK: ^bb0(%[[ARG3:.*]]: tensor, %[[ARG4:.*]]: tensor): +// CHECK: %[[SUM2:.*]] = mhlo.add %[[ARG3]], %[[ARG4]] : tensor +// CHECK: mhlo.return %[[SUM2]] : tensor +// CHECK: }) +// CHECK-SAME: -> tensor<2x4x6x7xf32> +// CHECK: %[[RESULT:.*]] = mhlo.divide %[[DIVIDEND]], %[[DIVISOR]] : tensor<2x4x6x7xf32> +// CHECK: return %[[RESULT]] : tensor<2x4x6x7xf32> +// CHECK: } +func.func @avgpool_same_padding(%arg0: tensor<2x12x21x7xf32>) -> tensor<2x4x6x7xf32> { + %0 = "tf.AvgPool"(%arg0) {data_format = "NHWC", ksize = [1, 5, 2, 1], padding = "SAME", strides = [1, 3, 4, 1]} : (tensor<2x12x21x7xf32>) -> tensor<2x4x6x7xf32> + func.return %0 : tensor<2x4x6x7xf32> +} + +// ----- + +// CHECK-LABEL: @avgpool_3d_same_padding( +// CHECK-SAME: %[[ARG0:.*]]: tensor<2x4x12x21x7xf32>) -> tensor<2x4x4x6x7xf32> +// CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor +// CHECK: %[[DIVIDEND:.*]] = "mhlo.reduce_window"(%[[ARG0]], %[[ZERO]]) +// CHECK-SAME: padding = dense<{{\[\[}}0, 0], [0, 0], [1, 1], [0, 1], [0, 0]]> +// CHECK-SAME: window_dimensions = dense<[1, 1, 5, 2, 1]> +// CHECK-SAME: window_strides = dense<[1, 1, 3, 4, 1]> +// CHECK: ^bb0(%[[ARG1:.*]]: tensor, %[[ARG2:.*]]: tensor): +// CHECK: %[[SUM1:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor +// CHECK: mhlo.return %[[SUM1]] : tensor +// CHECK: }) +// CHECK-SAME: -> tensor<2x4x4x6x7xf32> +// CHECK: %[[ONES:.*]] = mhlo.constant dense<1.000000e+00> : tensor<2x4x12x21x7xf32> +// CHECK: %[[DIVISOR:.*]] = "mhlo.reduce_window"(%[[ONES]], %[[ZERO]]) +// CHECK-SAME: padding = dense<{{\[\[}}0, 0], [0, 0], [1, 1], [0, 1], [0, 0]]> +// CHECK-SAME: window_dimensions = dense<[1, 1, 5, 2, 1]> +// CHECK-SAME: window_strides = dense<[1, 1, 3, 4, 1]> +// CHECK: ^bb0(%[[ARG3:.*]]: tensor, %[[ARG4:.*]]: tensor): +// CHECK: %[[SUM2:.*]] = mhlo.add %[[ARG3]], %[[ARG4]] : tensor +// CHECK: mhlo.return %[[SUM2]] : tensor +// CHECK: }) +// CHECK-SAME: -> tensor<2x4x4x6x7xf32> +// CHECK: %[[RESULT:.*]] = mhlo.divide %[[DIVIDEND]], %[[DIVISOR]] +// CHECK: return %[[RESULT]] : tensor<2x4x4x6x7xf32> +// CHECK: } +func.func @avgpool_3d_same_padding(%arg0: tensor<2x4x12x21x7xf32>) -> tensor<2x4x4x6x7xf32> { + %0 = "tf.AvgPool3D"(%arg0) {data_format = "NDHWC", ksize = [1, 1, 5, 2, 1], padding = "SAME", strides = [1, 1, 3, 4, 1]} : (tensor<2x4x12x21x7xf32>) -> tensor<2x4x4x6x7xf32> + func.return %0 : tensor<2x4x4x6x7xf32> +} + +//===----------------------------------------------------------------------===// +// AvgPoolGrad op legalizations. +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: @avgpool_grad_valid_padding( +// CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<10x12x16x64xf32>) -> tensor<10x24x32x64xf32> { +// CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor +// CHECK-DAG: %[[DIVISOR:.*]] = mhlo.constant dense<4.000000e+00> : tensor +// CHECK: %[[OUT_GRAD_DIVIDED:.*]] = chlo.broadcast_divide %[[OUT_GRAD]], %[[DIVISOR]] +// CHECK-SAME: broadcast_dimensions = array +// CHECK-SAME: -> tensor<10x12x16x64xf32> +// CHECK: %[[REDUCE_WINDOW_INPUT:.*]] = "mhlo.pad"(%[[OUT_GRAD_DIVIDED]], %[[ZERO]]) +// CHECK-SAME: edge_padding_high = dense<[0, 1, 1, 0]> +// CHECK-SAME: edge_padding_low = dense<[0, 1, 1, 0]> +// CHECK-SAME: interior_padding = dense<[0, 1, 1, 0]> +// CHECK-SAME: -> tensor<10x25x33x64xf32> +// CHECK: %[[RESULT:.*]] = "mhlo.reduce_window"(%[[REDUCE_WINDOW_INPUT]], %[[ZERO]]) +// CHECK-SAME: window_dimensions = dense<[1, 2, 2, 1]> +// CHECK-SAME: window_strides = dense<1> +// CHECK: ^bb0(%[[ARG1:.*]]: tensor, %[[ARG2:.*]]: tensor): +// CHECK: %[[SUM:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor +// CHECK: mhlo.return %[[SUM]] : tensor +// CHECK: }) +// CHECK-SAME: -> tensor<10x24x32x64xf32> +// CHECK: return %[[RESULT]] : tensor<10x24x32x64xf32> +func.func @avgpool_grad_valid_padding(%grad: tensor<10x12x16x64xf32>) -> tensor<10x24x32x64xf32> { + %orig_input_shape = "tf.Const"() {value = dense<[10, 24, 32, 64]> : tensor<4xi32>} : () -> (tensor<4xi32>) + %result = "tf.AvgPoolGrad"(%orig_input_shape, %grad) { + data_format = "NHWC", + ksize = [1, 2, 2, 1], + padding = "VALID", + strides = [1, 2, 2, 1] + } : (tensor<4xi32>, tensor<10x12x16x64xf32>) -> tensor<10x24x32x64xf32> + func.return %result : tensor<10x24x32x64xf32> +} + +// ----- + +// CHECK-LABEL: @avgpool_3d_grad_valid_padding( +// CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<10x8x12x16x64xf32>) -> tensor<10x8x24x32x64xf32> { +// CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor +// CHECK-DAG: %[[DIVISOR:.*]] = mhlo.constant dense<4.000000e+00> : tensor +// CHECK: %[[OUT_GRAD_DIVIDED:.*]] = chlo.broadcast_divide %[[OUT_GRAD]], %[[DIVISOR]] {broadcast_dimensions = array} : (tensor<10x8x12x16x64xf32>, tensor) -> tensor<10x8x12x16x64xf32> +// CHECK: %[[REDUCE_WINDOW_INPUT:.*]] = "mhlo.pad"(%[[OUT_GRAD_DIVIDED]], %[[ZERO]]) +// CHECK-SAME: edge_padding_high = dense<[0, 0, 1, 1, 0]> +// CHECK-SAME: edge_padding_low = dense<[0, 0, 1, 1, 0]> +// CHECK-SAME: interior_padding = dense<[0, 0, 1, 1, 0]> +// CHECK-SAME: -> tensor<10x8x25x33x64xf32> +// CHECK: %[[RESULT:.*]] = "mhlo.reduce_window"(%[[REDUCE_WINDOW_INPUT]], %[[ZERO]]) +// CHECK-SAME: window_dimensions = dense<[1, 1, 2, 2, 1]> +// CHECK-SAME: window_strides = dense<1> +// CHECK: ^bb0(%[[ARG1:.*]]: tensor, %[[ARG2:.*]]: tensor): +// CHECK: %[[SUM:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor +// CHECK: mhlo.return %[[SUM]] : tensor +// CHECK: }) +// CHECK-SAME: -> tensor<10x8x24x32x64xf32> +// CHECK: return %[[RESULT]] : tensor<10x8x24x32x64xf32> +func.func @avgpool_3d_grad_valid_padding(%grad: tensor<10x8x12x16x64xf32>) -> tensor<10x8x24x32x64xf32> { + %orig_input_shape = "tf.Const"() {value = dense<[10, 8, 24, 32, 64]> : tensor<5xi32>} : () -> (tensor<5xi32>) + %result = "tf.AvgPool3DGrad"(%orig_input_shape, %grad) { + data_format = "NDHWC", + ksize = [1, 1, 2, 2, 1], + padding = "VALID", + strides = [1, 1, 2, 2, 1]} : (tensor<5xi32>, tensor<10x8x12x16x64xf32>) -> tensor<10x8x24x32x64xf32> + func.return %result : tensor<10x8x24x32x64xf32> +} + +// ----- + +// CHECK-LABEL: @avgpool_grad_same_padding( +// CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<2x4x7x9xf32>) -> tensor<2x13x25x9xf32> { +// CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor +// CHECK-DAG: %[[ALL_ONES:.*]] = mhlo.constant dense<1.000000e+00> : tensor<2x13x25x9xf32> +// CHECK: %[[DIVISOR:.*]] = "mhlo.reduce_window"(%[[ALL_ONES]], %[[ZERO]]) +// CHECK-SAME: padding = dense<{{\[\[}}0, 0], [0, 1], [1, 1], [0, 0]]> +// CHECK-SAME: window_dimensions = dense<[1, 2, 3, 1]> +// CHECK-SAME: window_strides = dense<[1, 4, 4, 1]> +// CHECK: ^bb0(%[[ARG1:.*]]: tensor, %[[ARG2:.*]]: tensor): +// CHECK: %[[SUM1:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor +// CHECK: mhlo.return %[[SUM1]] : tensor +// CHECK: }) +// CHECK-SAME: -> tensor<2x4x7x9xf32> +// CHECK: %[[OUT_GRAD_DIVIDED:.*]] = mhlo.divide %[[OUT_GRAD]], %[[DIVISOR]] : tensor<2x4x7x9xf32> +// CHECK: %[[REDUCE_WINDOW_INPUT:.*]] = "mhlo.pad"(%[[OUT_GRAD_DIVIDED]], %[[ZERO]]) +// CHECK-SAME: edge_padding_high = dense<[0, 0, 1, 0]> +// CHECK-SAME: edge_padding_low = dense<[0, 1, 1, 0]> +// CHECK-SAME: interior_padding = dense<[0, 3, 3, 0]> +// CHECK-SAME: -> tensor<2x14x27x9xf32> +// CHECK: %[[RESULT:.*]] = "mhlo.reduce_window"(%[[REDUCE_WINDOW_INPUT]], %[[ZERO]]) +// CHECK-SAME: window_dimensions = dense<[1, 2, 3, 1]> +// CHECK-SAME: window_strides = dense<1> +// CHECK: ^bb0(%[[ARG3:.*]]: tensor, %[[ARG4:.*]]: tensor): +// CHECK: %[[SUM2:.*]] = mhlo.add %[[ARG3]], %[[ARG4]] : tensor +// CHECK: mhlo.return %[[SUM2]] : tensor +// CHECK: }) +// CHECK-SAME: -> tensor<2x13x25x9xf32> +// CHECK: return %[[RESULT]] : tensor<2x13x25x9xf32> +func.func @avgpool_grad_same_padding(%grad: tensor<2x4x7x9xf32>) -> tensor<2x13x25x9xf32> { + %orig_input_shape = "tf.Const"() {value = dense<[2, 13, 25, 9]> : tensor<4xi32>} : () -> (tensor<4xi32>) + %result = "tf.AvgPoolGrad"(%orig_input_shape, %grad) { + data_format = "NHWC", + ksize = [1, 2, 3, 1], + padding = "SAME", + strides = [1, 4, 4, 1] + } : (tensor<4xi32>, tensor<2x4x7x9xf32>) -> tensor<2x13x25x9xf32> + func.return %result : tensor<2x13x25x9xf32> +} + +// ----- + +// CHECK-LABEL: @avgpool_3d_grad_same_padding( +// CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<2x8x4x7x9xf32>) -> tensor<2x8x13x25x9xf32> { +// CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor +// CHECK-DAG: %[[ALL_ONES:.*]] = mhlo.constant dense<1.000000e+00> : tensor<2x8x13x25x9xf32> +// CHECK: %[[DIVISOR:.*]] = "mhlo.reduce_window"(%[[ALL_ONES]], %[[ZERO]]) +// CHECK-SAME: padding = dense<{{\[\[}}0, 0], [0, 0], [0, 1], [1, 1], [0, 0]]> +// CHECK-SAME: window_dimensions = dense<[1, 1, 2, 3, 1]> +// CHECK-SAME: window_strides = dense<[1, 1, 4, 4, 1]> +// CHECK: ^bb0(%[[ARG1:.*]]: tensor, %[[ARG2:.*]]: tensor): +// CHECK: %[[SUM1:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor +// CHECK: mhlo.return %[[SUM1]] : tensor +// CHECK: }) +// CHECK-SAME: -> tensor<2x8x4x7x9xf32> +// CHECK: %[[OUT_GRAD_DIVIDED:.*]] = mhlo.divide %[[OUT_GRAD]], %[[DIVISOR]] : tensor<2x8x4x7x9xf32> +// CHECK: %[[REDUCE_WINDOW_INPUT:.*]] = "mhlo.pad"(%[[OUT_GRAD_DIVIDED]], %[[ZERO]]) +// CHECK-SAME: edge_padding_high = dense<[0, 0, 0, 1, 0]> +// CHECK-SAME: edge_padding_low = dense<[0, 0, 1, 1, 0]> +// CHECK-SAME: interior_padding = dense<[0, 0, 3, 3, 0]> +// CHECK-SAME: -> tensor<2x8x14x27x9xf32> +// CHECK: %[[RESULT:.*]] = "mhlo.reduce_window"(%[[REDUCE_WINDOW_INPUT]], %[[ZERO]]) +// CHECK-SAME: window_dimensions = dense<[1, 1, 2, 3, 1]> +// CHECK-SAME: window_strides = dense<1> +// CHECK: ^bb0(%[[ARG3:.*]]: tensor, %[[ARG4:.*]]: tensor): +// CHECK: %[[SUM2:.*]] = mhlo.add %[[ARG3]], %[[ARG4]] : tensor +// CHECK: mhlo.return %[[SUM2]] : tensor +// CHECK: }) +// CHECK-SAME: -> tensor<2x8x13x25x9xf32> +// CHECK: return %[[RESULT]] : tensor<2x8x13x25x9xf32> +func.func @avgpool_3d_grad_same_padding(%grad: tensor<2x8x4x7x9xf32>) -> tensor<2x8x13x25x9xf32> { + %orig_input_shape = "tf.Const"() {value = dense<[2, 8, 13, 25, 9]> : tensor<5xi32>} : () -> (tensor<5xi32>) + %result = "tf.AvgPool3DGrad"(%orig_input_shape, %grad) { + data_format = "NDHWC", + ksize = [1, 1, 2, 3, 1], + padding = "SAME", + strides = [1, 1, 4, 4, 1]} : (tensor<5xi32>, tensor<2x8x4x7x9xf32>) -> tensor<2x8x13x25x9xf32> + func.return %result : tensor<2x8x13x25x9xf32> +} + +// ----- + +// CHECK-LABEL: @avgpool_grad_nchw_format( +// CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<2x9x4x7xf32>) -> tensor<2x9x13x25xf32> { +// CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor +// CHECK-DAG: %[[ALL_ONES:.*]] = mhlo.constant dense<1.000000e+00> : tensor<2x9x13x25xf32> +// CHECK: %[[DIVISOR:.*]] = "mhlo.reduce_window"(%[[ALL_ONES]], %[[ZERO]]) +// CHECK-SAME: padding = dense<{{\[\[}}0, 0], [0, 0], [0, 1], [1, 1]]> +// CHECK-SAME: window_dimensions = dense<[1, 1, 2, 3]> +// CHECK-SAME: window_strides = dense<[1, 1, 4, 4]> +// CHECK: ^bb0(%[[ARG1:.*]]: tensor, %[[ARG2:.*]]: tensor): +// CHECK: %[[SUM1:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor +// CHECK: mhlo.return %[[SUM1]] : tensor +// CHECK: }) +// CHECK-SAME: -> tensor<2x9x4x7xf32> +// CHECK: %[[OUT_GRAD_DIVIDED:.*]] = mhlo.divide %[[OUT_GRAD]], %[[DIVISOR]] : tensor<2x9x4x7xf32> +// CHECK: %[[REDUCE_WINDOW_INPUT:.*]] = "mhlo.pad"(%[[OUT_GRAD_DIVIDED]], %[[ZERO]]) +// CHECK-SAME: edge_padding_high = dense<[0, 0, 0, 1]> +// CHECK-SAME: edge_padding_low = dense<[0, 0, 1, 1]> +// CHECK-SAME: interior_padding = dense<[0, 0, 3, 3]> +// CHECK-SAME: -> tensor<2x9x14x27xf32> +// CHECK: %[[RESULT:.*]] = "mhlo.reduce_window"(%[[REDUCE_WINDOW_INPUT]], %[[ZERO]]) +// CHECK-SAME: window_dimensions = dense<[1, 1, 2, 3]> +// CHECK-SAME: window_strides = dense<1> +// CHECK: ^bb0(%[[ARG3:.*]]: tensor, %[[ARG4:.*]]: tensor): +// CHECK: %[[SUM2:.*]] = mhlo.add %[[ARG3]], %[[ARG4]] : tensor +// CHECK: mhlo.return %[[SUM2]] : tensor +// CHECK: }) +// CHECK-SAME: -> tensor<2x9x13x25xf32> +// CHECK: return %[[RESULT]] : tensor<2x9x13x25xf32> +func.func @avgpool_grad_nchw_format(%grad: tensor<2x9x4x7xf32>) -> tensor<2x9x13x25xf32> { + %orig_input_shape = "tf.Const"() {value = dense<[2, 9, 13, 25]> : tensor<4xi32>} : () -> (tensor<4xi32>) + %result = "tf.AvgPoolGrad"(%orig_input_shape, %grad) { + data_format = "NCHW", + ksize = [1, 1, 2, 3], + padding = "SAME", + strides = [1, 1, 4, 4] + } : (tensor<4xi32>, tensor<2x9x4x7xf32>) -> tensor<2x9x13x25xf32> + func.return %result : tensor<2x9x13x25xf32> +} + +// ----- + +// CHECK-LABEL: @avgpool_3d_grad_ncdwh_format( +// CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<2x9x8x4x7xf32>) -> tensor<2x9x8x13x25xf32> { +// CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor +// CHECK-DAG: %[[ALL_ONES:.*]] = mhlo.constant dense<1.000000e+00> : tensor<2x9x8x13x25xf32> +// CHECK: %[[DIVISOR:.*]] = "mhlo.reduce_window"(%[[ALL_ONES]], %[[ZERO]]) +// CHECK-SAME: padding = dense<{{\[\[}}0, 0], [0, 0], [0, 0], [0, 1], [1, 1]]> +// CHECK-SAME: window_dimensions = dense<[1, 1, 1, 2, 3]> +// CHECK-SAME: window_strides = dense<[1, 1, 1, 4, 4]> +// CHECK: ^bb0(%[[ARG1:.*]]: tensor, %[[ARG2:.*]]: tensor): +// CHECK: %[[SUM1:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor +// CHECK: mhlo.return %[[SUM1]] : tensor +// CHECK: }) +// CHECK-SAME: -> tensor<2x9x8x4x7xf32> +// CHECK: %[[OUT_GRAD_DIVIDED:.*]] = mhlo.divide %[[OUT_GRAD]], %[[DIVISOR]] : tensor<2x9x8x4x7xf32> +// CHECK: %[[REDUCE_WINDOW_INPUT:.*]] = "mhlo.pad"(%[[OUT_GRAD_DIVIDED]], %[[ZERO]]) +// CHECK-SAME: edge_padding_high = dense<[0, 0, 0, 0, 1]> +// CHECK-SAME: edge_padding_low = dense<[0, 0, 0, 1, 1]> +// CHECK-SAME: interior_padding = dense<[0, 0, 0, 3, 3]> +// CHECK-SAME: -> tensor<2x9x8x14x27xf32> +// CHECK: %[[RESULT:.*]] = "mhlo.reduce_window"(%[[REDUCE_WINDOW_INPUT]], %[[ZERO]]) +// CHECK-SAME: window_dimensions = dense<[1, 1, 1, 2, 3]> +// CHECK-SAME: window_strides = dense<1> : tensor<5xi64> +// CHECK: ^bb0(%[[ARG3:.*]]: tensor, %[[ARG4:.*]]: tensor): +// CHECK: %[[SUM2:.*]] = mhlo.add %[[ARG3]], %[[ARG4]] : tensor +// CHECK: mhlo.return %[[SUM2]] : tensor +// CHECK: }) +// CHECK-SAME: -> tensor<2x9x8x13x25xf32> +// CHECK: return %[[RESULT]] : tensor<2x9x8x13x25xf32> +func.func @avgpool_3d_grad_ncdwh_format(%grad: tensor<2x9x8x4x7xf32>) -> tensor<2x9x8x13x25xf32> { + %orig_input_shape = "tf.Const"() {value = dense<[2, 9, 8, 13, 25]> : tensor<5xi32>} : () -> (tensor<5xi32>) + %result = "tf.AvgPool3DGrad"(%orig_input_shape, %grad) { + data_format = "NCDHW", + ksize = [1, 1, 1, 2, 3], + padding = "SAME", + strides = [1, 1, 1, 4, 4]} : (tensor<5xi32>, tensor<2x9x8x4x7xf32>) -> tensor<2x9x8x13x25xf32> + func.return %result : tensor<2x9x8x13x25xf32> +} + +// ----- + +// CHECK-LABEL: @avgpool_grad_bf16( +// CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<10x12x16x64xbf16>) -> tensor<10x24x32x64xbf16> { +// CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor +// CHECK-DAG: %[[DIVISOR:.*]] = mhlo.constant dense<4.000000e+00> : tensor +// CHECK: %[[OUT_GRAD_DIVIDED:.*]] = chlo.broadcast_divide %[[OUT_GRAD]], %[[DIVISOR]] +// CHECK-SAME: broadcast_dimensions = array +// CHECK-SAME: -> tensor<10x12x16x64xbf16> +// CHECK: %[[REDUCE_WINDOW_INPUT:.*]] = "mhlo.pad"(%[[OUT_GRAD_DIVIDED]], %[[ZERO]]) +// CHECK-SAME: edge_padding_high = dense<[0, 1, 1, 0]> +// CHECK-SAME: edge_padding_low = dense<[0, 1, 1, 0]> +// CHECK-SAME: interior_padding = dense<[0, 1, 1, 0]> +// CHECK-SAME: -> tensor<10x25x33x64xbf16> +// CHECK: %[[REDUCE_WINDOW_INPUT_CONVERTED:.*]] = mhlo.convert %[[REDUCE_WINDOW_INPUT]] : (tensor<10x25x33x64xbf16>) -> tensor<10x25x33x64xf32> +// CHECK: %[[ZERO_F32:.*]] = mhlo.constant dense<0.000000e+00> : tensor +// CHECK: %[[RESULT:.*]] = "mhlo.reduce_window"(%[[REDUCE_WINDOW_INPUT_CONVERTED]], %[[ZERO_F32]]) +// CHECK-SAME: window_dimensions = dense<[1, 2, 2, 1]> +// CHECK-SAME: window_strides = dense<1> +// CHECK: ^bb0(%[[ARG1:.*]]: tensor, %[[ARG2:.*]]: tensor): +// CHECK: %[[SUM:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor +// CHECK: mhlo.return %[[SUM]] : tensor +// CHECK: }) +// CHECK-SAME: -> tensor<10x24x32x64xf32> +// CHECK: %[[RESULT_CONVERTED:.*]] = mhlo.convert %[[RESULT]] : (tensor<10x24x32x64xf32>) -> tensor<10x24x32x64xbf16> +// CHECK: return %[[RESULT_CONVERTED]] : tensor<10x24x32x64xbf16> +func.func @avgpool_grad_bf16(%grad: tensor<10x12x16x64xbf16>) -> tensor<10x24x32x64xbf16> { + %orig_input_shape = "tf.Const"() {value = dense<[10, 24, 32, 64]> : tensor<4xi32>} : () -> (tensor<4xi32>) + %result = "tf.AvgPoolGrad"(%orig_input_shape, %grad) { + data_format = "NHWC", + ksize = [1, 2, 2, 1], + padding = "VALID", + strides = [1, 2, 2, 1] + } : (tensor<4xi32>, tensor<10x12x16x64xbf16>) -> tensor<10x24x32x64xbf16> + func.return %result : tensor<10x24x32x64xbf16> +} + +// ----- + +// CHECK-LABEL: xla_sharding +func.func @xla_sharding(%arg0: tensor<4x16xf32>) -> tensor<4x16xf32> { + // CHECK-NEXT: mhlo.custom_call @Sharding(%arg0) {mhlo.sharding = ""} + %0 = "tf.XlaSharding"(%arg0) {_XlaSharding = "", sharding = ""} : (tensor<4x16xf32>) -> tensor<4x16xf32> + func.return %0 : tensor<4x16xf32> +} + +// ----- + +// CHECK-LABEL: inplace_update_one +func.func @inplace_update_one(%arg0: tensor<8x4xf32>, %arg1: tensor<1x4xf32>, %arg2: tensor<1xi32>) -> tensor<8x4xf32> { + // CHECK-DAG: [[CST:%.+]] = mhlo.constant dense<0> + // CHECK-DAG: [[SLICE1:%.+]] = "mhlo.slice"(%arg2) <{limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}> + // CHECK-DAG: [[SLICE2:%.+]] = "mhlo.slice"(%arg1) <{limit_indices = dense<[1, 4]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}> + // CHECK-DAG: [[RESHAPE1:%.+]] = mhlo.reshape [[SLICE1]] + // CHECK-DAG: [[UPDATE:%.+]] = mhlo.dynamic_update_slice %arg0, [[SLICE2]], [[RESHAPE1]], [[CST]] + %0 = "tf.InplaceUpdate"(%arg0, %arg2, %arg1) : (tensor<8x4xf32>, tensor<1xi32>, tensor<1x4xf32>) -> tensor<8x4xf32> + + // CHECK: return [[UPDATE]] + func.return %0 : tensor<8x4xf32> +} + +// ----- + +// CHECK-LABEL: inplace_update_three +func.func @inplace_update_three(%arg0: tensor<8x8x4xf32>, %arg1: tensor<3x8x4xf32>, %arg2: tensor<3xi32>) -> tensor<8x8x4xf32> { + // CHECK-DAG: [[CST:%.+]] = mhlo.constant dense<0> + // CHECK-DAG: [[SLICE1:%.+]] = "mhlo.slice"(%arg2) <{limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}> + // CHECK-DAG: [[SLICE2:%.+]] = "mhlo.slice"(%arg2) <{limit_indices = dense<2> : tensor<1xi64>, start_indices = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}> + // CHECK-DAG: [[SLICE3:%.+]] = "mhlo.slice"(%arg2) <{limit_indices = dense<3> : tensor<1xi64>, start_indices = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}> + // CHECK-DAG: [[SLICE4:%.+]] = "mhlo.slice"(%arg1) <{limit_indices = dense<[1, 8, 4]> : tensor<3xi64>, start_indices = dense<0> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>}> + // CHECK-DAG: [[SLICE5:%.+]] = "mhlo.slice"(%arg1) <{limit_indices = dense<[2, 8, 4]> : tensor<3xi64>, start_indices = dense<[1, 0, 0]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>}> + // CHECK-DAG: [[SLICE6:%.+]] = "mhlo.slice"(%arg1) <{limit_indices = dense<[3, 8, 4]> : tensor<3xi64>, start_indices = dense<[2, 0, 0]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>}> + // CHECK-DAG: [[RESHAPE1:%.+]] = mhlo.reshape [[SLICE1]] + // CHECK-DAG: [[RESHAPE2:%.+]] = mhlo.reshape [[SLICE2]] + // CHECK-DAG: [[RESHAPE3:%.+]] = mhlo.reshape [[SLICE3]] + // CHECK-DAG: [[UPDATE1:%.+]] = mhlo.dynamic_update_slice %arg0, [[SLICE4]], [[RESHAPE1]], [[CST]], [[CST]] + // CHECK-DAG: [[UPDATE2:%.+]] = mhlo.dynamic_update_slice [[UPDATE1]], [[SLICE5]], [[RESHAPE2]], [[CST]], [[CST]] + // CHECK-DAG: [[UPDATE3:%.+]] = mhlo.dynamic_update_slice [[UPDATE2]], [[SLICE6]], [[RESHAPE3]], [[CST]], [[CST]] + %0 = "tf.InplaceUpdate"(%arg0, %arg2, %arg1) : (tensor<8x8x4xf32>, tensor<3xi32>, tensor<3x8x4xf32>) -> tensor<8x8x4xf32> + + // CHECK: return [[UPDATE3]] : tensor<8x8x4xf32> + func.return %0 : tensor<8x8x4xf32> +} + +// ----- + +// CHECK-LABEL: xla_dynamic_update_slice +func.func @xla_dynamic_update_slice(%arg0: tensor<4x16xf32>, %arg1: tensor<2x4xf32>, %arg2: tensor<2xi32>) -> tensor<4x16xf32> { + // CHECK: [[SLICE0:%.+]] = "mhlo.slice"(%arg2) <{limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}> : (tensor<2xi32>) -> tensor<1xi32> + // CHECK: [[RESHAPE0:%.+]] = mhlo.reshape [[SLICE0]] : (tensor<1xi32>) -> tensor + // CHECK: [[SLICE1:%.+]] = "mhlo.slice"(%arg2) <{limit_indices = dense<2> : tensor<1xi64>, start_indices = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}> : (tensor<2xi32>) -> tensor<1xi32> + // CHECK: [[RESHAPE1:%.+]] = mhlo.reshape [[SLICE1]] : (tensor<1xi32>) -> tensor + // CHECK: [[DUS:%.+]] = mhlo.dynamic_update_slice %arg0, %arg1, [[RESHAPE0]], [[RESHAPE1]] : (tensor<4x16xf32>, tensor<2x4xf32>, tensor, tensor) -> tensor<4x16xf32> + // CHECK: return [[DUS]] + %0 = "tf.XlaDynamicUpdateSlice"(%arg0, %arg1, %arg2) : (tensor<4x16xf32>, tensor<2x4xf32>, tensor<2xi32>) -> tensor<4x16xf32> + func.return %0 : tensor<4x16xf32> +} + +// ----- + +// CHECK-LABEL: xla_dynamic_update_slice2 +func.func @xla_dynamic_update_slice2(%arg0: tensor<4xf32>, %arg1: tensor<2xf32>, %arg2: tensor<1xi32>) -> tensor<4xf32> { + // CHECK: [[SLICE0:%.+]] = "mhlo.slice"(%arg2) <{limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}> : (tensor<1xi32>) -> tensor<1xi32> + // CHECK: [[RESHAPE0:%.+]] = mhlo.reshape [[SLICE0]] : (tensor<1xi32>) -> tensor + // CHECK: [[DUS:%.+]] = mhlo.dynamic_update_slice %arg0, %arg1, [[RESHAPE0]] : (tensor<4xf32>, tensor<2xf32>, tensor) -> tensor<4xf32> + // CHECK: return [[DUS]] + %0 = "tf.XlaDynamicUpdateSlice"(%arg0, %arg1, %arg2) : (tensor<4xf32>, tensor<2xf32>, tensor<1xi32>) -> tensor<4xf32> + func.return %0 : tensor<4xf32> +} + +//===----------------------------------------------------------------------===// +// AllToAll op legalizations. +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: func @alltoall_basic +// See https://www.tensorflow.org/api_docs/python/tf/raw_ops/AllToAll +func.func @alltoall_basic(%input: tensor<1x2xf32>) -> tensor<2x1xf32> { + %group_assignment = "tf.Const" () { + value = dense<[[0, 1]]> : tensor<1x2xi32> + } : () -> tensor<1x2xi32> + %result = "tf.AllToAll"(%input, %group_assignment) {T = f32, concat_dimension = 0 : i64, split_count = 2 : i64, split_dimension = 1 : i64} : (tensor<1x2xf32>, tensor<1x2xi32>) -> tensor<2x1xf32> + // CHECK: mhlo.all_to_all + // CHECK-SAME{LITERAL}: replica_groups = dense<[[0, 1]]> : tensor<1x2xi64> + func.return %result : tensor<2x1xf32> +} + + +//===----------------------------------------------------------------------===// +// Cumsum op legalizations. +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: func @cumsum_static +// CHECK-SAME: [[X:%.*]]: tensor<4xf32> +func.func @cumsum_static(%arg0: tensor<4xf32>) -> tensor<4xf32> { + // CHECK: [[AXIS:%.*]] = mhlo.constant dense<0> : tensor + // CHECK: [[CONVERT_X:%.*]] = mhlo.convert [[X]] : tensor<4xf32> + // CHECK: [[INIT:%.*]] = mhlo.constant dense<0.000000e+00> : tensor + // CHECK: [[REDUCE:%.*]] = "mhlo.reduce_window"([[CONVERT_X]], [[INIT]]) <{padding = dense<{{\[\[}}3, 0]]> : tensor<1x2xi64>, window_dimensions = dense<4> : tensor<1xi64>, window_strides = dense<1> : tensor<1xi64>}> ({ + // CHECK: ^bb0([[A:%.*]]: tensor, [[B:%.*]]: tensor): + // CHECK: [[SUM:%.*]] = mhlo.add [[A]], [[B]] : tensor + // CHECK: mhlo.return [[SUM]] : tensor + // CHECK: }) : (tensor<4xf32>, tensor) -> tensor<4xf32> + // CHECK: [[CONVERT_REDUCE:%.*]] = mhlo.convert [[REDUCE]] : tensor<4xf32> + // CHECK: return [[CONVERT_REDUCE]] + %0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = i32, value = dense<0> : tensor} : () -> tensor + %1 = "tf.Cumsum"(%arg0, %0) {exclusive = false, reverse = false} : (tensor<4xf32>, tensor) -> tensor<4xf32> + func.return %1 : tensor<4xf32> +} + +// ----- + +// CHECK-LABEL: func @cumsum_exclusive +// CHECK-SAME: [[X:%.*]]: tensor<4xf32> +func.func @cumsum_exclusive(%arg0: tensor<4xf32>) -> tensor<4xf32> { + // CHECK: [[AXIS:%.*]] = mhlo.constant dense<0> : tensor + // CHECK: [[CONVERT_X:%.*]] = mhlo.convert [[X]] : tensor<4xf32> + // CHECK: [[INIT:%.*]] = mhlo.constant dense<0.000000e+00> : tensor + // CHECK: [[REDUCE:%.*]] = "mhlo.reduce_window"([[CONVERT_X]], [[INIT]]) <{padding = dense<{{\[\[}}3, 0]]> : tensor<1x2xi64>, window_dimensions = dense<4> : tensor<1xi64>, window_strides = dense<1> : tensor<1xi64>}> ({ + // CHECK: ^bb0([[A:%.*]]: tensor, [[B:%.*]]: tensor): + // CHECK: [[SUM:%.*]] = mhlo.add [[A]], [[B]] : tensor + // CHECK: mhlo.return [[SUM]] : tensor + // CHECK: }) : (tensor<4xf32>, tensor) -> tensor<4xf32> + // CHECK: [[PAD:%.*]] = "mhlo.pad"([[REDUCE]], %{{.*}}) <{edge_padding_high = dense<-1> : tensor<1xi64>, edge_padding_low = dense<1> : tensor<1xi64>, interior_padding = dense<0> : tensor<1xi64>}> : (tensor<4xf32>, tensor) -> tensor<4xf32> + // CHECK: [[CONVERT_REDUCE:%.*]] = mhlo.convert [[PAD]] : tensor<4xf32> + // CHECK: return [[CONVERT_REDUCE]] + %0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = i32, value = dense<0> : tensor} : () -> tensor + %1 = "tf.Cumsum"(%arg0, %0) {exclusive = true, reverse = false} : (tensor<4xf32>, tensor) -> tensor<4xf32> + func.return %1 : tensor<4xf32> +} + +// ----- + +// CHECK-LABEL: func @cumsum_reverse +// CHECK-SAME: [[X:%.*]]: tensor<4xf32> +func.func @cumsum_reverse(%arg0: tensor<4xf32>) -> tensor<4xf32> { + // CHECK: [[AXIS:%.*]] = mhlo.constant dense<0> : tensor + // CHECK: [[REVERSE1:%.*]] = "mhlo.reverse"([[X]]) <{dimensions = dense<0> : tensor<1xi64>}> : (tensor<4xf32>) -> tensor<4xf32> + // CHECK: [[CONVERT_X:%.*]] = mhlo.convert [[REVERSE1]] : tensor<4xf32> + // CHECK: [[INIT:%.*]] = mhlo.constant dense<0.000000e+00> : tensor + // CHECK: [[REDUCE:%.*]] = "mhlo.reduce_window"([[CONVERT_X]], [[INIT]]) <{padding = dense<{{\[\[}}3, 0]]> : tensor<1x2xi64>, window_dimensions = dense<4> : tensor<1xi64>, window_strides = dense<1> : tensor<1xi64>}> ({ + // CHECK: ^bb0([[A:%.*]]: tensor, [[B:%.*]]: tensor): + // CHECK: [[SUM:%.*]] = mhlo.add [[A]], [[B]] : tensor + // CHECK: mhlo.return [[SUM]] : tensor + // CHECK: }) : (tensor<4xf32>, tensor) -> tensor<4xf32> + // CHECK: [[CONVERT_REDUCE:%.*]] = mhlo.convert [[REDUCE]] : tensor<4xf32> + // CHECK: [[REVERSE_BACK:%.*]] = "mhlo.reverse"([[CONVERT_REDUCE]]) <{dimensions = dense<0> : tensor<1xi64>}> : (tensor<4xf32>) -> tensor<4xf32> + // CHECK: return [[REVERSE_BACK]] + %0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = i32, value = dense<0> : tensor} : () -> tensor + %1 = "tf.Cumsum"(%arg0, %0) {exclusive = false, reverse = true} : (tensor<4xf32>, tensor) -> tensor<4xf32> + func.return %1 : tensor<4xf32> +} + +// ----- + +// CHECK-LABEL: func @cumsum_exclusive_reverse +// CHECK-SAME: [[X:%.*]]: tensor<4xf32> +func.func @cumsum_exclusive_reverse(%arg0: tensor<4xf32>) -> tensor<4xf32> { + // CHECK: [[AXIS:%.*]] = mhlo.constant dense<0> : tensor + // CHECK: [[REVERSE1:%.*]] = "mhlo.reverse"([[X]]) <{dimensions = dense<0> : tensor<1xi64>}> : (tensor<4xf32>) -> tensor<4xf32> + // CHECK: [[CONVERT_X:%.*]] = mhlo.convert [[REVERSE1]] : tensor<4xf32> + // CHECK: [[INIT:%.*]] = mhlo.constant dense<0.000000e+00> : tensor + // CHECK: [[REDUCE:%.*]] = "mhlo.reduce_window"([[CONVERT_X]], [[INIT]]) <{padding = dense<{{\[\[}}3, 0]]> : tensor<1x2xi64>, window_dimensions = dense<4> : tensor<1xi64>, window_strides = dense<1> : tensor<1xi64>}> ({ + // CHECK: ^bb0([[A:%.*]]: tensor, [[B:%.*]]: tensor): + // CHECK: [[SUM:%.*]] = mhlo.add [[A]], [[B]] : tensor + // CHECK: mhlo.return [[SUM]] : tensor + // CHECK: }) : (tensor<4xf32>, tensor) -> tensor<4xf32> + // CHECK: [[PAD:%.*]] = "mhlo.pad"([[REDUCE]], %{{.*}}) <{edge_padding_high = dense<-1> : tensor<1xi64>, edge_padding_low = dense<1> : tensor<1xi64>, interior_padding = dense<0> : tensor<1xi64>}> : (tensor<4xf32>, tensor) -> tensor<4xf32> + // CHECK: [[CONVERT_REDUCE:%.*]] = mhlo.convert [[PAD]] : tensor<4xf32> + // CHECK: [[REVERSE_BACK:%.*]] = "mhlo.reverse"([[CONVERT_REDUCE]]) <{dimensions = dense<0> : tensor<1xi64>}> : (tensor<4xf32>) -> tensor<4xf32> + // CHECK: return [[REVERSE_BACK]] + %0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = i32, value = dense<0> : tensor} : () -> tensor + %1 = "tf.Cumsum"(%arg0, %0) {exclusive = true, reverse = true} : (tensor<4xf32>, tensor) -> tensor<4xf32> + func.return %1 : tensor<4xf32> +} + +// ----- + +// CHECK-LABEL: func @cumsum_empty +func.func @cumsum_empty(%arg0: tensor<0xf32>) -> tensor<0xf32> { + %0 = "tf.Const"() {value = dense<0> : tensor} : () -> tensor + + // CHECK: mhlo.constant dense<> : tensor<0xf32> + %1 = "tf.Cumsum"(%arg0, %0) : (tensor<0xf32>, tensor) -> tensor<0xf32> + func.return %1 : tensor<0xf32> +} + +// ----- + +// CHECK-LABEL: func @cumsum_dynamic +func.func @cumsum_dynamic(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "tf.Cumsum" + %0 = "tf.Cumsum"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +//===----------------------------------------------------------------------===// +// Cumprod op legalizations. +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: func @cumprod +func.func @cumprod(%arg0: tensor<4xf32>) -> tensor<4xf32> { + // CHECK: [[INIT:%.*]] = mhlo.constant dense<1.000000e+00> : tensor + // CHECK: "mhlo.reduce_window"({{.*}}, [[INIT]]) + // CHECK: mhlo.mul + %0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = i32, value = dense<0> : tensor} : () -> tensor + %1 = "tf.Cumprod"(%arg0, %0) {exclusive = false, reverse = false} : (tensor<4xf32>, tensor) -> tensor<4xf32> + func.return %1 : tensor<4xf32> +} + +//===----------------------------------------------------------------------===// +// tf.Softplus legalization +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: func @softplus_f16 +// CHECK-SAME: ([[FEATURES:%.*]]: tensor<8x16xf16>) +func.func @softplus_f16(%arg0: tensor<8x16xf16>) -> tensor<8x16xf16> { + // CHECK-DAG: [[FEATURES_EXP:%.*]] = mhlo.exponential [[FEATURES]] + // CHECK-DAG: [[EPSILON:%.*]] = mhlo.constant dense<1.220700e-04> : tensor + // CHECK-DAG: [[EPSILON_LOG:%.*]] = mhlo.log [[EPSILON]] + // CHECK-DAG: [[TWO:%.*]] = mhlo.constant dense<2.000000e+00> : tensor + // CHECK: [[THRESHOLD:%.*]] = chlo.broadcast_add [[EPSILON_LOG]], [[TWO]] + // CHECK: [[NEG_THRESHOLD:%.*]] = mhlo.negate [[THRESHOLD]] + // CHECK-DAG: [[COMPARE_GT:%.*]] = chlo.broadcast_compare [[FEATURES]], [[NEG_THRESHOLD]] {comparison_direction = #chlo} + // CHECK-DAG: [[COMPARE_LT:%.*]] = chlo.broadcast_compare [[FEATURES]], [[THRESHOLD]] {comparison_direction = #chlo} + // CHECK-DAG: [[FEATURES_EXP_LOG:%.*]] = mhlo.log_plus_one [[FEATURES_EXP]] + // CHECK: [[ELSE_SELECT:%.*]] = mhlo.select [[COMPARE_LT]], [[FEATURES_EXP]], [[FEATURES_EXP_LOG]] + // CHECK: [[ENTRY_SELECT:%.*]] = mhlo.select [[COMPARE_GT]], [[FEATURES]], [[ELSE_SELECT]] + %0 = "tf.Softplus"(%arg0) : (tensor<8x16xf16>) -> tensor<8x16xf16> + + // CHECK: return [[ENTRY_SELECT]] : tensor<8x16xf16> + func.return %0 : tensor<8x16xf16> +} + +// ----- + +// CHECK-LABEL: func @softplus_bf16 +// CHECK-SAME: ([[FEATURES:%.*]]: tensor<8x16xbf16>) +func.func @softplus_bf16(%arg0: tensor<8x16xbf16>) -> tensor<8x16xbf16> { + // CHECK-DAG: [[FEATURES_EXP:%.*]] = mhlo.exponential [[FEATURES]] + // CHECK-DAG: [[EPSILON:%.*]] = mhlo.constant dense<7.812500e-03> : tensor + // CHECK-DAG: [[EPSILON_LOG:%.*]] = mhlo.log [[EPSILON]] + // CHECK-DAG: [[TWO:%.*]] = mhlo.constant dense<2.000000e+00> : tensor + // CHECK: [[THRESHOLD:%.*]] = chlo.broadcast_add [[EPSILON_LOG]], [[TWO]] + // CHECK: [[NEG_THRESHOLD:%.*]] = mhlo.negate [[THRESHOLD]] + // CHECK-DAG: [[COMPARE_GT:%.*]] = chlo.broadcast_compare [[FEATURES]], [[NEG_THRESHOLD]] {comparison_direction = #chlo} + // CHECK-DAG: [[COMPARE_LT:%.*]] = chlo.broadcast_compare [[FEATURES]], [[THRESHOLD]] {comparison_direction = #chlo} + // CHECK-DAG: [[FEATURES_EXP_LOG:%.*]] = mhlo.log_plus_one [[FEATURES_EXP]] + // CHECK: [[ELSE_SELECT:%.*]] = mhlo.select [[COMPARE_LT]], [[FEATURES_EXP]], [[FEATURES_EXP_LOG]] + // CHECK: [[ENTRY_SELECT:%.*]] = mhlo.select [[COMPARE_GT]], [[FEATURES]], [[ELSE_SELECT]] + %0 = "tf.Softplus"(%arg0) : (tensor<8x16xbf16>) -> tensor<8x16xbf16> + + // CHECK: return [[ENTRY_SELECT]] : tensor<8x16xbf16> + func.return %0 : tensor<8x16xbf16> +} + +// ----- + +// CHECK-LABEL: func @softplus_f32 +// CHECK-SAME: ([[FEATURES:%.*]]: tensor<8x16xf32>) +func.func @softplus_f32(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> { + // CHECK-DAG: [[FEATURES_EXP:%.*]] = mhlo.exponential [[FEATURES]] + // CHECK-DAG: [[EPSILON:%.*]] = mhlo.constant dense<1.1920929E-7> : tensor + // CHECK-DAG: [[EPSILON_LOG:%.*]] = mhlo.log [[EPSILON]] + // CHECK-DAG: [[TWO:%.*]] = mhlo.constant dense<2.000000e+00> : tensor + // CHECK: [[THRESHOLD:%.*]] = chlo.broadcast_add [[EPSILON_LOG]], [[TWO]] + // CHECK: [[NEG_THRESHOLD:%.*]] = mhlo.negate [[THRESHOLD]] + // CHECK-DAG: [[COMPARE_GT:%.*]] = chlo.broadcast_compare [[FEATURES]], [[NEG_THRESHOLD]] {comparison_direction = #chlo} + // CHECK-DAG: [[COMPARE_LT:%.*]] = chlo.broadcast_compare [[FEATURES]], [[THRESHOLD]] {comparison_direction = #chlo} + // CHECK-DAG: [[FEATURES_EXP_LOG:%.*]] = mhlo.log_plus_one [[FEATURES_EXP]] + // CHECK: [[ELSE_SELECT:%.*]] = mhlo.select [[COMPARE_LT]], [[FEATURES_EXP]], [[FEATURES_EXP_LOG]] + // CHECK: [[ENTRY_SELECT:%.*]] = mhlo.select [[COMPARE_GT]], [[FEATURES]], [[ELSE_SELECT]] + %0 = "tf.Softplus"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32> + + // CHECK: return [[ENTRY_SELECT]] : tensor<8x16xf32> + func.return %0 : tensor<8x16xf32> +} + +// ----- + +// CHECK-LABEL: func @softplus_f64 +// CHECK-SAME: ([[FEATURES:%.*]]: tensor<8x16xf64>) +func.func @softplus_f64(%arg0: tensor<8x16xf64>) -> tensor<8x16xf64> { + // CHECK-DAG: [[FEATURES_EXP:%.*]] = mhlo.exponential [[FEATURES]] + // CHECK-DAG: [[EPSILON:%.*]] = mhlo.constant dense<2.2204460492503131E-16> : tensor + // CHECK-DAG: [[EPSILON_LOG:%.*]] = mhlo.log [[EPSILON]] + // CHECK-DAG: [[TWO:%.*]] = mhlo.constant dense<2.000000e+00> : tensor + // CHECK: [[THRESHOLD:%.*]] = chlo.broadcast_add [[EPSILON_LOG]], [[TWO]] + // CHECK: [[NEG_THRESHOLD:%.*]] = mhlo.negate [[THRESHOLD]] + // CHECK-DAG: [[COMPARE_GT:%.*]] = chlo.broadcast_compare [[FEATURES]], [[NEG_THRESHOLD]] {comparison_direction = #chlo} + // CHECK-DAG: [[COMPARE_LT:%.*]] = chlo.broadcast_compare [[FEATURES]], [[THRESHOLD]] {comparison_direction = #chlo} + // CHECK-DAG: [[FEATURES_EXP_LOG:%.*]] = mhlo.log_plus_one [[FEATURES_EXP]] + // CHECK: [[ELSE_SELECT:%.*]] = mhlo.select [[COMPARE_LT]], [[FEATURES_EXP]], [[FEATURES_EXP_LOG]] + // CHECK: [[ENTRY_SELECT:%.*]] = mhlo.select [[COMPARE_GT]], [[FEATURES]], [[ELSE_SELECT]] + %0 = "tf.Softplus"(%arg0) : (tensor<8x16xf64>) -> tensor<8x16xf64> + + // CHECK: return [[ENTRY_SELECT]] : tensor<8x16xf64> + func.return %0 : tensor<8x16xf64> +} + +// ----- + +// CHECK-LABEL: @xla_gather +func.func @xla_gather(%arg0: tensor<200x100x300xf32>, %arg1: tensor<10x2xi32>) -> tensor<1x300x10xf32> { + %cst = "tf.Const"() { value = dense<[1, 1, 300]> : tensor<3xi64> } : () -> tensor<3xi64> + + // CHECK: "mhlo.gather" + // CHECK-SAME: dimension_numbers = + // CHECK-SAME: offset_dims = [0, 1] + // CHECK-SAME: collapsed_slice_dims = [0] + // CHECK-SAME: start_index_map = [0, 1] + // CHECK-SAME: index_vector_dim = 1 + // CHECK-SAME: indices_are_sorted = true + // CHECK-SAME: slice_sizes = dense<[1, 1, 300]> : tensor<3xi64> + + %0 = "tf.XlaGather"(%arg0, %arg1, %cst) {dimension_numbers = "\0A\02\00\01\12\01\00\1A\02\00\01\20\01", indices_are_sorted = true} : (tensor<200x100x300xf32>, tensor<10x2xi32>, tensor<3xi64>) -> tensor<1x300x10xf32> + func.return %0 : tensor<1x300x10xf32> +} + +// ----- + +// CHECK-LABEL: @xla_gather_i32 +func.func @xla_gather_i32(%arg0: tensor<200x100x300xf32>, %arg1: tensor<10x2xi32>) -> tensor<1x300x10xf32> { + %cst = "tf.Const"() { value = dense<[1, 1, 300]> : tensor<3xi32> } : () -> tensor<3xi32> + + // CHECK: "mhlo.gather" + // CHECK-SAME: dimension_numbers = + // CHECK-SAME: offset_dims = [0, 1] + // CHECK-SAME: collapsed_slice_dims = [0] + // CHECK-SAME: start_index_map = [0, 1] + // CHECK-SAME: index_vector_dim = 1 + // CHECK-SAME: indices_are_sorted = true + // CHECK-SAME: slice_sizes = dense<[1, 1, 300]> : tensor<3xi64> + + %0 = "tf.XlaGather"(%arg0, %arg1, %cst) {dimension_numbers = "\0A\02\00\01\12\01\00\1A\02\00\01\20\01", indices_are_sorted = true} : (tensor<200x100x300xf32>, tensor<10x2xi32>, tensor<3xi32>) -> tensor<1x300x10xf32> + func.return %0 : tensor<1x300x10xf32> +} + + +// CHECK: func @stridedslice_with_i32 +func.func @stridedslice_with_i32(%arg0: tensor) -> tensor<4xf32> attributes {tf.entry_function = {control_outputs = "", inputs = "const_0_arg", outputs = "identity_0_retval_RetVal"}} { +// CHECK-NOT: tf.StridedSlice +// CHECK: [[DYNSLICE:%.*]] = "mhlo.dynamic_slice +// CHECK: [[RESHAPE:%.*]] = mhlo.reshape [[DYNSLICE]] +// CHECK: return [[RESHAPE]] + %0 = "tf.Const"() {value = dense<[[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00, 7.000000e+00]]> : tensor<2x4xf32>} : () -> tensor<2x4xf32> + %1 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %2 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> + %3 = "tf.AddV2"(%arg0, %1) {_xla_inferred_shapes = [#tf_type.shape<>], device = ""} : (tensor, tensor) -> tensor + %4 = "tf.Pack"(%3) {_xla_inferred_shapes = [#tf_type.shape<1>], axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> + %5 = "tf.Pack"(%arg0) {_xla_inferred_shapes = [#tf_type.shape<1>], axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> + %6 = "tf.StridedSlice"(%0, %5, %4, %2) {_xla_inferred_shapes = [#tf_type.shape<4>], begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2x4xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xf32> + func.return %6 : tensor<4xf32> +} + +func.func @replica_id() -> tensor { + // CHECK: %[[ID:.*]] = mhlo.replica_id : tensor + // CHECK: %[[RESULT:.*]] = mhlo.convert %0 : (tensor) -> tensor + %0 = "tf.XlaReplicaId"() : () -> tensor + func.return %0 : tensor +} + +// CHECK: func @angle_c64 +// CHECK-SAME: ([[ARG0:%.*]]: tensor>) +func.func @angle_c64(%arg0: tensor>) -> tensor { +// CHECK: [[IMAG:%.*]] = mhlo.imag [[ARG0]] +// CHECK: [[REAL:%.*]] = mhlo.real [[ARG0]] +// CHECK: [[ATAN2:%.*]] = mhlo.atan2 [[IMAG]], [[REAL]] + %0 = "tf.Angle"(%arg0): (tensor>) -> tensor + func.return %0 : tensor +} + +//===----------------------------------------------------------------------===// +// tf.ApproximateEqual legalization +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: func @approximateequal_f64 +func.func @approximateequal_f64(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: %[[SUB:.*]] = mhlo.subtract %arg0, %arg1 : tensor + // CHECK: %[[ABS:.*]] = mhlo.abs %[[SUB]] : tensor + // CHECK: %[[CST:.*]] = mhlo.constant dense<2.000000e+00> : tensor + // CHECK: %[[CONVERT:.*]] = mhlo.convert %[[CST]] : (tensor) -> tensor + // CHECK: %[[LE:.*]] = chlo.broadcast_compare %[[ABS]], %[[CONVERT]] {comparison_direction = #chlo} : (tensor, tensor) -> tensor + // CHECK: return %[[LE]] : tensor + %equal = "tf.ApproximateEqual"(%arg0, %arg1) { tolerance = 2. : f32 } : (tensor, tensor) -> tensor + func.return %equal : tensor +} + +// CHECK-LABEL: func @approximateequal_i32 +func.func @approximateequal_i32(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: %[[SUB:.*]] = mhlo.subtract %arg0, %arg1 : tensor + // CHECK: %[[ABS:.*]] = mhlo.abs %[[SUB]] : tensor + // CHECK: %[[CST:.*]] = mhlo.constant dense<2.000000e+00> : tensor + // CHECK: %[[CONVERT:.*]] = mhlo.convert %[[CST]] : (tensor) -> tensor + // CHECK: %[[LE:.*]] = chlo.broadcast_compare %[[ABS]], %[[CONVERT]] {comparison_direction = #chlo} : (tensor, tensor) -> tensor + // CHECK: return %[[LE]] : tensor + %equal = "tf.ApproximateEqual"(%arg0, %arg1) { tolerance = 2. : f32 } : (tensor, tensor) -> tensor + func.return %equal : tensor +} + +// CHECK-LABEL: func @approximateequal_complex64 +func.func @approximateequal_complex64(%arg0: tensor>, %arg1: tensor>) -> tensor { + // CHECK: %[[SUB:.*]] = mhlo.subtract %arg0, %arg1 : tensor> + // CHECK: %[[ABS:.*]] = mhlo.abs %[[SUB]] : (tensor>) -> tensor + // CHECK: %[[CST:.*]] = mhlo.constant dense<2.000000e+00> : tensor + // CHECK: %[[CONVERT:.*]] = mhlo.convert %[[CST]] : tensor + // CHECK: %[[LE:.*]] = chlo.broadcast_compare %[[ABS]], %[[CONVERT]] {comparison_direction = #chlo} : (tensor, tensor) -> tensor + // CHECK: return %[[LE]] : tensor + %equal = "tf.ApproximateEqual"(%arg0, %arg1) { tolerance = 2. : f32 } : (tensor>, tensor>) -> tensor + func.return %equal : tensor +} + +//===----------------------------------------------------------------------===// +// tf.XlaConvV2 legalization +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: xla_conv_v2 +func.func @xla_conv_v2(%lhs: tensor<8x4x16x16x16xf32>, %rhs: tensor<4x3x3x16x16xf32>) -> (tensor<4x4x14x14x16xf32>) { + %feature_group_count = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %lhs_dilation = "tf.Const"() {value = dense<[4, 1, 1]> : tensor<3xi32>} : () -> tensor<3xi32> + %rhs_dilation = "tf.Const"() {value = dense<1> : tensor<3xi32>} : () -> tensor<3xi32> + %padding = "tf.Const"() {value = dense<0> : tensor<3x2xi32>} : () -> tensor<3x2xi32> + %strides = "tf.Const"() {value = dense<[3, 1, 1]> : tensor<3xi32>} : () -> tensor<3xi32> + // CHECK: mhlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, 1, 2, f]x[0, 1, 2, i, o]->[b, 0, 1, 2, f], window = {stride = [3, 1, 1], pad = {{\[\[}}0, 0], {{\[}}0, 0], {{\[}}0, 0]], lhs_dilate = [4, 1, 1], rhs_dilate = [1, 1, 1]} {batch_group_count = 2 : i64, feature_group_count = 1 : i64, precision_config = []} : (tensor<8x4x16x16x16xf32>, tensor<4x3x3x16x16xf32>) -> tensor<4x4x14x14x16xf32> + %0 = "tf.XlaConvV2"(%lhs, %rhs, %strides, %padding, %lhs_dilation, %rhs_dilation, %feature_group_count) {batch_group_count = 2 : i64, dimension_numbers = "\18\03 \042\03\00\01\02@\04P\04Z\03\01\02\03b\03\01\02\03", precision_config = ""} : (tensor<8x4x16x16x16xf32>, tensor<4x3x3x16x16xf32>, tensor<3xi32>, tensor<3x2xi32>, tensor<3xi32>, tensor<3xi32>, tensor) -> tensor<4x4x14x14x16xf32> + func.return %0 : tensor<4x4x14x14x16xf32> +} + +//===----------------------------------------------------------------------===// +// tf.XlaDot legalization +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: @xladot_matmul( +// CHECK-SAME: %[[LHS:.*]]: tensor<64x32xi8>, %[[RHS:.*]]: tensor<32x16xi8>) -> tensor<64x16xi32> +func.func @xladot_matmul(%lhs : tensor<64x32xi8>, %rhs : tensor<32x16xi8>) -> tensor<64x16xi32> { + // CHECK: "mhlo.dot_general"(%[[LHS]], %[[RHS]]) <{ + // CHECK-SAME: dot_dimension_numbers = #mhlo.dot< + // CHECK-NOT: lhs_batching_dimensions = + // CHECK-NOT: rhs_batching_dimensions = + // CHECK-SAME: lhs_contracting_dimensions = [1] + // CHECK-SAME: rhs_contracting_dimensions = [0] + // CHECK-SAME: precision_config = [] + %res = "tf.XlaDot"(%lhs, %rhs) {dimension_numbers = "\0A\01\01\12\01\00", precision_config = ""} : (tensor<64x32xi8>, tensor<32x16xi8>) -> tensor<64x16xi32> + func.return %res : tensor<64x16xi32> +} + +//===----------------------------------------------------------------------===// +// tf.XlaDotV2 legalization +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: @xladotv2_matmul( +// CHECK-SAME: %[[LHS:.*]]: tensor<64x32xi8>, %[[RHS:.*]]: tensor<32x16xi8>) -> tensor<64x16xi32> +func.func @xladotv2_matmul(%lhs : tensor<64x32xi8>, %rhs : tensor<32x16xi8>) -> tensor<64x16xi32> { + // CHECK: "mhlo.dot_general"(%[[LHS]], %[[RHS]]) <{ + // CHECK-SAME: dot_dimension_numbers = #mhlo.dot< + // CHECK-NOT: lhs_batching_dimensions = + // CHECK-NOT: rhs_batching_dimensions = + // CHECK-SAME: lhs_contracting_dimensions = [1] + // CHECK-SAME: rhs_contracting_dimensions = [0] + // CHECK-SAME: precision_config = [] + %res = "tf.XlaDotV2"(%lhs, %rhs) {dimension_numbers = "\0A\01\01\12\01\00", precision_config = ""} : (tensor<64x32xi8>, tensor<32x16xi8>) -> tensor<64x16xi32> + func.return %res : tensor<64x16xi32> +} + +//===----------------------------------------------------------------------===// +// tf.XlaDynamicSlice legalization +//===----------------------------------------------------------------------===// +// ----- + +// CHECK-LABEL: xla_dynamic_slice_constant_start +func.func @xla_dynamic_slice_constant_start(%arg0: tensor<4xi32>) -> tensor<2xi32> { + // CHECK: %[[START:.*]] = mhlo.constant dense<1> : tensor + // CHECK-DAG-SAME: {limit_indices = dense<1> : tensor<1xi64>, + // CHECK-DAG-SAME: start_indices = dense<0> : tensor<1xi64>, + // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>} : + // CHECK-DAG-SAME: (tensor<1xi64>) -> tensor<1xi64> + // CHECK-DAG-SAME: (tensor<1xi64>) -> tensor + // CHECK-NEXT: %[[RESULT:.*]] = "mhlo.dynamic_slice"(%arg0, %[[START]]) + // CHECK-DAG-SAME: {slice_sizes = dense<2> : tensor<1xi64>} : + // CHECK-DAG-SAME: (tensor<4xi32>, tensor) -> tensor<2xi32> + // CHECK-NEXT: return %[[RESULT]] : tensor<2xi32> + %starts = "tf.Const"() {value = dense<[1]> : tensor<1xi64>} : () -> (tensor<1xi64>) + %sizes = "tf.Const"() {value = dense<[2]> : tensor<1xi64>} : () -> (tensor<1xi64>) + %0 = "tf.XlaDynamicSlice"(%arg0, %starts, %sizes) : (tensor<4xi32>, tensor<1xi64>, tensor<1xi64>) -> tensor<2xi32> + func.return %0 : tensor<2xi32> +} + +// ----- + +// CHECK-LABEL: xla_dynamic_slice_i32_consts +func.func @xla_dynamic_slice_i32_consts(%arg0: tensor<4xi32>) -> tensor<2xi32> { + // CHECK: %[[START:.*]] = mhlo.constant dense<1> : tensor + // CHECK: "mhlo.dynamic_slice"(%arg0, %[[START]]) <{slice_sizes = dense<2> : tensor<1xi64>}> : (tensor<4xi32>, tensor) -> tensor<2xi32> + %starts = "tf.Const"() {value = dense<[1]> : tensor<1xi32>} : () -> (tensor<1xi32>) + %sizes = "tf.Const"() {value = dense<[2]> : tensor<1xi32>} : () -> (tensor<1xi32>) + %0 = "tf.XlaDynamicSlice"(%arg0, %starts, %sizes) : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> + func.return %0 : tensor<2xi32> +} + +// ----- + +// CHECK-LABEL: xla_dynamic_slice_constant_start_dynamic_shape +func.func @xla_dynamic_slice_constant_start_dynamic_shape(%arg0: tensor, %arg1: tensor<2xi64>) -> tensor<1x4xi32> { + // CHECK-DAG: %[[START1:.*]] = mhlo.constant dense<1> : tensor + // CHECK-DAG: %[[START2:.*]] = mhlo.constant dense<0> : tensor + // CHECK: %[[RESULT:.*]] = "mhlo.dynamic_slice" + // CHECK-DAG-SAME: (%arg0, %[[START1]], %[[START2]]) + // CHECK-DAG-SAME: {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : + // CHECK-DAG-SAME: (tensor, tensor, tensor) -> tensor<1x4xi32> + // CHECK: return %[[RESULT]] : tensor<1x4xi32> + %starts = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> (tensor<2xi64>) + %sizes = "tf.Const"() {value = dense<[1, 4]> : tensor<2xi64>} : () -> (tensor<2xi64>) + %0 = "tf.XlaDynamicSlice"(%arg0, %starts, %sizes) : (tensor, tensor<2xi64>, tensor<2xi64>) -> tensor<1x4xi32> + func.return %0 : tensor<1x4xi32> +} + +// ----- + +// CHECK-LABEL: xla_dynamic_slice_variable_start +func.func @xla_dynamic_slice_variable_start(%arg0: tensor<3x4xi32>, %arg1: tensor<2xi64>) -> tensor<1x4xi32> { + // CHECK: %[[SLICED_START1:.*]] = "mhlo.slice"(%arg1) + // CHECK-DAG-SAME: {limit_indices = dense<1> : tensor<1xi64>, + // CHECK-DAG-SAME: start_indices = dense<0> : tensor<1xi64>, + // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>} : (tensor<2xi64>) -> tensor<1xi64> + // CHECK: %[[RESHAPED_START1:.*]] = mhlo.reshape %[[SLICED_START1]] : (tensor<1xi64>) -> tensor + // CHECK: %[[SLICED_START2:.*]] = "mhlo.slice"(%arg1) + // CHECK-DAG-SAME: {limit_indices = dense<2> : tensor<1xi64>, + // CHECK-DAG-SAME: start_indices = dense<1> : tensor<1xi64>, + // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>} : (tensor<2xi64>) -> tensor<1xi64> + // CHECK: %[[RESHAPED_START2:.*]] = mhlo.reshape %[[SLICED_START2]] : (tensor<1xi64>) -> tensor + // CHECK: %[[RESULT:.*]] = "mhlo.dynamic_slice"(%arg0, %[[RESHAPED_START1]], %[[RESHAPED_START2]]) <{slice_sizes = dense<[1, 4]> : tensor<2xi64>}> : (tensor<3x4xi32>, tensor, tensor) -> tensor<1x4xi32> + // CHECK: return %[[RESULT]] : tensor<1x4xi32> + %sizes = "tf.Const"() {value = dense<[1, 4]> : tensor<2xi64>} : () -> (tensor<2xi64>) + %0 = "tf.XlaDynamicSlice"(%arg0, %arg1, %sizes) : (tensor<3x4xi32>, tensor<2xi64>, tensor<2xi64>) -> tensor<1x4xi32> + func.return %0 : tensor<1x4xi32> +} + +// ----- + +// CHECK-LABEL: xla_dynamic_slice_mhlo_sizes +func.func @xla_dynamic_slice_mhlo_sizes(%arg0: tensor<1x1024x4xf32>, %arg1: tensor<3xi32>) -> tensor<1x512x4xf32> { + // CHECK-NOT: "tf.XlaDynamicSlice" + %0 = "mhlo.constant"() {value = dense<[1, 512, 4]> : tensor<3xi32>} : () -> tensor<3xi32> + %1 = "tf.XlaDynamicSlice"(%arg0, %arg1, %0) : (tensor<1x1024x4xf32>, tensor<3xi32>, tensor<3xi32>) -> tensor<1x512x4xf32> + func.return %1 : tensor<1x512x4xf32> +} + +//===----------------------------------------------------------------------===// +// tf.XlaEinsum legalization +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: func @xlaeinsum +func.func @xlaeinsum(%arg0: tensor<2x3xf32>, %arg1: tensor<3x4xf32>) -> tensor<2x4xf32> { + // CHECK-NEXT: mhlo.einsum + %0 = "tf.XlaEinsum"(%arg0, %arg1) {equation = "ab,bc->ac"} : (tensor<2x3xf32>, tensor<3x4xf32>) -> tensor<2x4xf32> + func.return %0: tensor<2x4xf32> +} + + +//===----------------------------------------------------------------------===// +// tf.XlaReduceWindow legalization +//===----------------------------------------------------------------------===// +// ----- +// CHECK-LABEL: @test_xla_reduce_window +func.func @test_xla_reduce_window(%arg0: tensor<7xf32>, %arg1: tensor) -> tensor<10xf32> { + %cst = "tf.Const"() {value = dense<0> : tensor<1x2xi32>} : () -> tensor<1x2xi32> + %cst_0 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> + %cst_1 = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32> + %cst_2 = "tf.Const"() {value = dense<3> : tensor<1xi32>} : () -> tensor<1xi32> + %cst_3 = "tf.Const"() {value = dense<4> : tensor<1xi32>} : () -> tensor<1xi32> + // CHECK: %[[REDUCE:.*]] = "mhlo.reduce_window"(%arg0, %arg1) <{base_dilations = dense<3> : tensor<1xi64>, padding = dense<0> : tensor<1x2xi64>, window_dilations = dense<4> : tensor<1xi64>, window_dimensions = dense<1> : tensor<1xi64>, window_strides = dense<2> : tensor<1xi64>}> ({ + // CHECK-NEXT: ^{{.*}}(%[[ARG0:.*]]: tensor, %[[ARG1:.*]]: tensor) + // CHECK-NEXT: %[[SUM:.*]] = func.call @sum_reducer3(%[[ARG0]], %[[ARG1]]){{.*}} + // CHECK-NEXT: mhlo.return %[[SUM]] : tensor + // CHECK-NEXT: }) : (tensor<7xf32>, tensor) -> tensor<10xf32> + // CHECK-NEXT: return %[[REDUCE]] + %0 = "tf.XlaReduceWindow"(%arg0, %arg1, %cst_0, %cst_1, %cst_2, %cst_3, %cst) {computation = @sum_reducer3} : (tensor<7xf32>, tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1x2xi32>) -> tensor<10xf32> + func.return %0 : tensor<10xf32> +} + +func.func private @sum_reducer3(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "tf.AddV2"(%arg0, %arg1) {device = ""} : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +//===----------------------------------------------------------------------===// +// tf.XlaSort legalization +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: @xlasort_int +// CHECK-SAME: %[[INPUT:.*]]: tensor<16xi32> +func.func @xlasort_int(%input: tensor<16xi32>) -> (tensor<16xi32>) { + // CHECK-NEXT: %[[SORT:.*]] = "mhlo.sort"(%[[INPUT]]) <{dimension = -1 : i64, is_stable = false}> ({ + // CHECK-NEXT: ^{{.*}}(%[[LHS:.*]]: tensor, %[[RHS:.*]]: tensor) + // CHECK-NEXT: %[[CMP:.*]] = mhlo.compare LT, %[[LHS]], %[[RHS]], NOTYPE + // CHECK-NEXT: mhlo.return %[[CMP]] + // CHECK-NEXT: }) : (tensor<16xi32>) -> tensor<16xi32> + // CHECK-NEXT: return %[[SORT]] + %output = "tf.XlaSort"(%input) : (tensor<16xi32>) -> (tensor<16xi32>) + func.return %output : tensor<16xi32> +} + +// ----- + +// CHECK-LABEL: @xlasort_float +// CHECK-SAME: %[[INPUT:.*]]: tensor<8xf64> +func.func @xlasort_float(%input: tensor<8xf64>) -> (tensor<8xf64>) { + // CHECK-NEXT: %[[SORT:.*]] = "mhlo.sort"(%[[INPUT]]) <{dimension = -1 : i64, is_stable = false}> ({ + // CHECK-NEXT: ^{{.*}}(%[[LHS:.*]]: tensor, %[[RHS:.*]]: tensor) + // CHECK-NEXT: %[[CMP:.*]] = mhlo.compare LT, %[[LHS]], %[[RHS]], TOTALORDER + // CHECK-NEXT: mhlo.return %[[CMP]] + // CHECK-NEXT: }) : (tensor<8xf64>) -> tensor<8xf64> + // CHECK-NEXT: return %[[SORT]] + %output = "tf.XlaSort"(%input) : (tensor<8xf64>) -> (tensor<8xf64>) + func.return %output : tensor<8xf64> +} + +// ----- + +// CHECK-LABEL: @xlasort_const +func.func @xlasort_const() -> (tensor<2x3xi64>) { + // CHECK: [2, 4, 3], [6, 5, 1] + %input = "tf.Const"() {value = dense<[[2, 4, 3], [6, 5, 1]]> : tensor<2x3xi64>} : () -> (tensor<2x3xi64>) + // CHECK-NEXT: [2, 3, 4], [1, 5, 6] + %output = "tf.XlaSort"(%input): (tensor<2x3xi64>) -> (tensor<2x3xi64>) + func.return %output : tensor<2x3xi64> +} + +//===----------------------------------------------------------------------===// +// tf.XlaRngBitGenerator legalization +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: @xla_rng_bit_generator +// CHECK-SAME: %[[STATE:.*]]: tensor<2xui64> +func.func @xla_rng_bit_generator(%arg0: tensor<2xui64>) -> (tensor<2xui64>, tensor<10x12xui32>) attributes {tf.entry_function = {control_outputs = "", inputs = "_arg0,_arg1,_arg2", outputs = "_retval0,_retval1"}} { + // CHECK-NEXT: %0 = mhlo.constant dense<[10, 12]> : tensor<2xi32> + %cst = "tf.Const"() {value = dense<[10, 12]> : tensor<2xi32>} : () -> tensor<2xi32> + // CHECK-NEXT: %1 = mhlo.constant dense<3> : tensor + %cst_0 = "tf.Const"() {value = dense<3> : tensor} : () -> tensor + // CHECK-NEXT: %[[OUTPUT_STATE:.*]], %[[OUTPUT:.*]] = "mhlo.rng_bit_generator"(%[[STATE]]) <{rng_algorithm = #mhlo.rng_algorithm}> : (tensor<2xui64>) -> (tensor<2xui64>, tensor<10x12xui32>) + // CHECK-NEXT: return %[[OUTPUT_STATE]], %[[OUTPUT]] : tensor<2xui64>, tensor<10x12xui32> + %output_key, %output = "tf.XlaRngBitGenerator"(%cst_0, %arg0, %cst) : (tensor, tensor<2xui64>, tensor<2xi32>) -> (tensor<2xui64>, tensor<10x12xui32>) + func.return %output_key, %output : tensor<2xui64>, tensor<10x12xui32> +} + +//===----------------------------------------------------------------------===// +// tf.XlaVariadicV2 legalization +//===----------------------------------------------------------------------===// + +// ----- +// CHECK-LABEL: @xla_variadic_reduce_v2 +func.func @xla_variadic_reduce_v2(%arg0: tensor<2x3xcomplex>, %arg1: tensor>) -> tensor<3xcomplex> attributes {tf.entry_function = {control_outputs = "", inputs = "_arg0,_arg1", outputs = "_retval0"}} { + // CHECK: %[[REDUCE:.*]] = mhlo.reduce(%arg0 init: %arg1) + // CHECK-SAME: dimensions = [0] + // CHECK-NEXT: (%[[ARG0:.*]]: tensor>, %[[ARG1:.*]]: tensor>) + // CHECK-NEXT: %[[SUM:.*]] = func.call @sum_reducer(%[[ARG0]], %[[ARG1]]){{.*}} + // CHECK-NEXT: mhlo.return %[[SUM]] : tensor> + // CHECK: return %[[REDUCE]] + %0 = "tf.XlaVariadicReduceV2"(%arg0, %arg1) {_XlaHasReferenceVars = false, device = "/job:localhost/replica:0/task:0/device:XLA_GPU:0", dimensions_to_reduce = [0], operandSegmentSizes = array, reducer = @sum_reducer} : (tensor<2x3xcomplex>, tensor>) -> tensor<3xcomplex> + func.return %0 : tensor<3xcomplex> +} + +func.func private @sum_reducer(%arg0: tensor>, %arg1: tensor>) -> tensor> { + %0 = "tf.AddV2"(%arg1, %arg0) : (tensor>, tensor>) -> tensor> + func.return %0 : tensor> +} + +// ----- + +// CHECK-LABEL: @xla_variadic_reduce_v2_dynamic +func.func @xla_variadic_reduce_v2_dynamic(%arg0: tensor, %arg1: tensor) -> tensor attributes {tf.entry_function = {control_outputs = "", inputs = "_arg0,_arg1", outputs = "_retval0"}} { + // CHECK: %[[REDUCE:.*]] = mhlo.reduce(%arg0 init: %arg1) + // CHECK-SAME: dimensions = [0] + // CHECK-NEXT: (%[[ARG0:.*]]: tensor, %[[ARG1:.*]]: tensor) + // CHECK-NEXT: %[[SUM:.*]] = func.call @sum_reducer2(%[[ARG0]], %[[ARG1]]){{.*}} + // CHECK-NEXT: mhlo.return %[[SUM]] : tensor + // CHECK: return %[[REDUCE]] + %0 = "tf.XlaVariadicReduceV2"(%arg0, %arg1) {_XlaHasReferenceVars = false, device = "/job:localhost/replica:0/task:0/device:XLA_GPU:0", dimensions_to_reduce = [0], operandSegmentSizes = array, reducer = @sum_reducer2} : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +func.func private @sum_reducer2(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "tf.AddV2"(%arg1, %arg0) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +//===----------------------------------------------------------------------===// +// tf.XlaVariadicSort legalization +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: @xla_variadic_sort +// CHECK-SAME: %[[INPUT:.*]]: tensor<2x3x4xui8> +func.func @xla_variadic_sort(%arg0: tensor<2x3x4xui8>) -> tensor<2x3x4xui8> attributes {tf.entry_function = {control_outputs = "", inputs = "_arg0,_arg1", outputs = "_retval0"}} { + // CHECK-NEXT: {{.*}} = mhlo.constant dense<0> : tensor + %cst = "tf.Const"() {value = dense<0> : tensor} : () -> tensor + // CHECK-NEXT: %[[SORT:.*]] = "mhlo.sort"(%[[INPUT]]) <{dimension = 0 : i64, is_stable = false}> ({ + // CHECK-NEXT: ^{{.*}}(%[[LHS:.*]]: tensor, %[[RHS:.*]]: tensor) + // CHECK-NEXT: %[[CMP:.*]] = func.call @compare_lt(%[[LHS]], %[[RHS]]) : (tensor, tensor) -> tensor + // CHECK-NEXT: mhlo.return %[[CMP]] + // CHECK-NEXT: }) : (tensor<2x3x4xui8>) -> tensor<2x3x4xui8> + // CHECK-NEXT: return %[[SORT]] + %0 = "tf.XlaVariadicSort"(%arg0, %cst) {_XlaHasReferenceVars = false, comparator = @compare_lt, device = "/job:localhost/replica:0/task:0/device:XLA_GPU:0", is_stable = false} : (tensor<2x3x4xui8>, tensor) -> tensor<2x3x4xui8> + func.return %0 : tensor<2x3x4xui8> +} + +func.func private @compare_lt(%arg0: tensor, %arg1: tensor) -> tensor attributes {tf._disable_call_shape_inference = true} { + %0 = "tf.Less"(%arg0, %arg1) {device = ""} : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +//===----------------------------------------------------------------------===// +// tf.NextAfter legalization +//===----------------------------------------------------------------------===// +// CHECK-LABEL: func @nextafter +func.func @nextafter(%arg0: tensor<2xf32>, %arg1 : tensor<2xf32>) -> tensor<2xf32> { + // CHECK-NEXT: %0 = chlo.broadcast_next_after %arg0, %arg1 : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> + // CHECK-NEXT: return %0 : tensor<2xf32> + %0 = "tf.NextAfter"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> + func.return %0: tensor<2xf32> +} + +//===----------------------------------------------------------------------===// +// tf.XlaReduceScatter legalization +//===----------------------------------------------------------------------===// +// CHECK-LABEL: func @xla_reduce_scatter +func.func @xla_reduce_scatter(%arg0: tensor<128x128xf32>) -> tensor<64x128xf32> { + %cst = "tf.Const"() {value = dense<0> : tensor} : () -> tensor + %cst_0 = "tf.Const"() {value = dense<[[0, 4], [1, 5], [2, 6], [3, 7]]> : tensor<4x2xi32>} : () -> tensor<4x2xi32> + // CHECK: "mhlo.reduce_scatter"(%arg0) + // CHECK{LITERAL}: replica_groups = dense<[[0, 4], [1, 5], [2, 6], [3, 7]]> + // CHECK-SAME: scatter_dimension = 0 + // + %1 = "tf.XlaReduceScatter"(%arg0, %cst_0, %cst) {reduce_op = "Add"} : (tensor<128x128xf32>, tensor<4x2xi32>, tensor) -> tensor<64x128xf32> + func.return %1 : tensor<64x128xf32> +} + + +//===----------------------------------------------------------------------===// +// tf.XlaSelectAndScatter legalization +//===----------------------------------------------------------------------===// +func.func @test_xla_select_and_scatter(%arg0: tensor<4x5x1x1xbf16>, %arg1: tensor<2x2x1x1xbf16>, %arg2: tensor) -> tensor { + %cst = "tf.Const"() {value = dense<0> : tensor<4x2xi32>} : () -> tensor<4x2xi32> + %cst_0 = "tf.Const"() {value = dense<[2, 2, 1, 1]> : tensor<4xi32>} : () -> tensor<4xi32> + %cst_1 = "tf.Const"() {value = dense<[2, 3, 1, 1]> : tensor<4xi32>} : () -> tensor<4xi32> + // CHECK: %[[SELECT_AND_SCATTER:.*]] = "mhlo.select_and_scatter"(%arg0, %arg1, %arg2) <{padding = dense<0> : tensor<4x2xi64>, window_dimensions = dense<[2, 3, 1, 1]> : tensor<4xi64>, window_strides = dense<[2, 2, 1, 1]> : tensor<4xi64>}> ({ + // CHECK-NEXT: ^{{.*}}(%[[ARG0:.*]]: tensor, %[[ARG1:.*]]: tensor) + // CHECK-NEXT: %[[RES:.*]] = func.call @ge_select(%[[ARG0]], %[[ARG1]]){{.*}} + // CHECK-NEXT: mhlo.return %[[RES]] : tensor + // CHECK-NEXT: }, { + // CHECK-NEXT: ^{{.*}}(%[[ARG2:.*]]: tensor, %[[ARG3:.*]]: tensor) + // CHECK-NEXT: %[[RES:.*]] = func.call @add_scatter(%[[ARG2]], %[[ARG3]]){{.*}} + // CHECK-NEXT: mhlo.return %[[RES]] : tensor + // CHECK-NEXT: }) : (tensor<4x5x1x1xbf16>, tensor<2x2x1x1xbf16>, tensor) -> tensor + // CHECK-NEXT: return %[[SELECT_AND_SCATTER]] + %0 = "tf.XlaSelectAndScatter"(%arg0, %cst_1, %cst_0, %cst, %arg1, %arg2) {scatter = @add_scatter, select = @ge_select} : (tensor<4x5x1x1xbf16>, tensor<4xi32>, tensor<4xi32>, tensor<4x2xi32>, tensor<2x2x1x1xbf16>, tensor) -> tensor + func.return %0 : tensor +} + +func.func private @add_scatter(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "tf.AddV2"(%arg0, %arg1) {device = ""} : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +func.func private @ge_select(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "tf.GreaterEqual"(%arg0, %arg1) {device = ""} : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +//===----------------------------------------------------------------------===// +// tf.XlaOptimizationBarrier legalization +//===----------------------------------------------------------------------===// + +func.func @test_xla_optimization_barrier(%arg0: tensor<4x4xf32>, %arg1: tensor<3x4xi32>) -> (tensor<4x4xf32>, tensor<3x4xi32>) { + // CHECK: %[[OPT_BARRIER:.*]]:2 = mhlo.optimization_barrier %arg0, %arg1 + // CHECK-NEXT: return %[[OPT_BARRIER]]#0, %[[OPT_BARRIER]]#1 + %0, %1 = "tf.XlaOptimizationBarrier"(%arg0, %arg1) : (tensor<4x4xf32>, tensor<3x4xi32>) -> (tensor<4x4xf32>, tensor<3x4xi32>) + func.return %0, %1 : tensor<4x4xf32>, tensor<3x4xi32> +} + +// CHECK-LABEL: @ifRegion +// CHECK-SAME: ([[ARG0:%.+]]: tensor, [[ARG1:%.+]]: tensor) +func.func @ifRegion(%arg0: tensor, %arg1: tensor) -> (tensor) { + // CHECK: [[VAL0:%.+]] = mhlo.compare GT, [[ARG0]], [[ARG1]] + %0 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor + // CHECK: [[VAL1:%.+]] = "mhlo.if"([[VAL0]]) ({ + %1 = "tf.IfRegion"(%0) ({ + // CHECK: [[VAL2:%.+]] = mhlo.log [[ARG0]] + %2 = "tf.Log"(%arg0) : (tensor) -> tensor + // CHECK: mhlo.return [[VAL2]] + "tf.Yield"(%2) : (tensor) -> () + }, { + // CHECK: [[VAL3:%.+]] = mhlo.exponential [[ARG1]] + %2 = "tf.Exp"(%arg1) : (tensor) -> tensor + // CHECK: mhlo.return [[VAL3]] + "tf.Yield"(%2) : (tensor) -> () + // CHECK: }) : (tensor) -> tensor + }) {is_stateless = true} : (tensor) -> tensor + // CHECK: return [[VAL1]] + func.return %1 : tensor +} + +// CHECK-LABEL: func @caseRegion +// CHECK-SAME: ([[BRANCH_INDEX:%.+]]: tensor, [[ARG0:.+]]: tensor, [[ARG1:%.+]]: tensor) +func.func @caseRegion(%index: tensor, %arg0: tensor, %arg1: tensor) -> (tensor, tensor) { + // CHECK: [[VAL1:%.+]]:2 = "mhlo.case"([[BRANCH_INDEX]]) ({ + %0:2 = "tf.CaseRegion"(%index) ({ + // CHECK: [[VAL2:%.+]] = mhlo.exponential [[ARG1]] + %1 = mhlo.exponential %arg1 : (tensor) -> tensor + // CHECK: mhlo.return [[VAL2]], [[ARG1]] + "tf.Yield"(%1, %arg1) : (tensor, tensor) -> () + }, { + // CHECK: [[VAL3:%.+]] = mhlo.log [[ARG0]] + %1 = mhlo.log %arg0 : (tensor) -> tensor + // CHECK: mhlo.return [[VAL3]], [[ARG1]] + "tf.Yield"(%1, %arg1) : (tensor, tensor) -> () + }, { + // CHECK: [[VAL4:%.+]] = mhlo.floor [[ARG0]] + %1 = "mhlo.floor"(%arg0) : (tensor) -> tensor + // CHECK: mhlo.return [[VAL4]], [[ARG1]] + "tf.Yield"(%1, %arg1) : (tensor, tensor) -> () + // CHECK: }) : (tensor) -> (tensor, tensor) + }) {is_stateless = true} : (tensor) -> (tensor, tensor) + // CHECK: return [[VAL1]]#0, [[VAL1]]#1 : tensor, tensor + func.return %0#0, %0#1 : tensor, tensor +} + +// ----- + +// This test case also ensures the mhlo dialect is loaded as a dependency by the +// pass and hence the split here. + +// CHECK-LABEL: func @whileRegion +func.func @whileRegion() -> tensor { + %0 = "tf.Const"() {value = dense<0> : tensor} : () -> tensor + %1 = "tf.Const"() {value = dense<-1> : tensor} : () -> tensor + %2:3 = "tf.WhileRegion"(%0, %1, %0) ({ + ^cond(%carg0: tensor, %carg1: tensor, %carg2: tensor): + %3 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + "tf.Yield"(%3) : (tensor) -> () + }, { + ^body(%barg0: tensor, %barg1: tensor, %barg2: tensor): + %4 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + "tf.Yield"(%4, %4, %4) : (tensor, tensor, tensor) -> () + }) {is_stateless = true, parallel_iterations = 10 : i64} : (tensor, tensor, tensor) -> (tensor, tensor, tensor) + func.return %2#2 : tensor +} + +// ----- + +// CHECK-LABEL: func @whileRegionAdd +func.func @whileRegionAdd() -> tensor { + // CHECK: [[VAL0:%.+]] = mhlo.constant + %0 = "tf.Const"() {value = dense<0> : tensor} : () -> tensor + // CHECK: [[VAL1:%.+]] = mhlo.constant + %1 = "tf.Const"() {value = dense<-1> : tensor} : () -> tensor + // CHECK: [[VAL2:%.+]]:3 = mhlo.while([[ITER_ARG0:.*]] = [[VAL0]], [[ITER_ARG1:.*]] = [[VAL1]], [[ITER_ARG2:.*]] = [[VAL0]]) + %2:3 = "tf.WhileRegion"(%0, %1, %0) ({ + ^cond(%carg0: tensor, %carg1: tensor, %carg2: tensor): + // CHECK: [[VAL3:%.+]] = mhlo.constant + %3 = "tf.Const"() {value = dense<10> : tensor} : () -> tensor + // CHECK: [[VAL4:%.+]] = mhlo.compare LT, [[ITER_ARG2]], [[VAL3]] + %4 = "mhlo.compare"(%carg2, %3) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor + // CHECK: mhlo.return [[VAL4]] + "tf.Yield"(%4) : (tensor) -> () + }, { + ^body(%barg0: tensor, %barg1: tensor, %barg2: tensor): + // CHECK: [[VAL5:%.+]] = mhlo.constant + %5 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + // CHECK: [[VAL6:%.+]] = mhlo.add [[ITER_ARG2]], [[VAL5]] + %6 = mhlo.add %barg2, %5 : tensor + // CHECK: [[VAL7:%.+]] = mhlo.add [[ITER_ARG0]], [[VAL5]] + %7 = mhlo.add %barg0, %5 : tensor + // CHECK: mhlo.return [[VAL7]], [[ITER_ARG1]], [[VAL6]] + "tf.Yield"(%7, %barg1, %6) : (tensor, tensor, tensor) -> () + }) {is_stateless = true, parallel_iterations = 10 : i64} : (tensor, tensor, tensor) -> (tensor, tensor, tensor) + // CHECK: return [[VAL2]]#2 + func.return %2#2 : tensor +} + +// ----- + +// CHECK-LABEL: func @whileRegionImplicitInputs +// CHECK-SAME: ([[ARG0:%.+]]: tensor) +func.func @whileRegionImplicitInputs(%arg0: tensor) -> tensor { + // CHECK: [[VAL0:%.+]] = mhlo.constant dense<0> + %0 = mhlo.constant dense<0> : tensor + // CHECK: [[VAL1:%.+]] = mhlo.constant dense<-1> + %1 = mhlo.constant dense<-1> : tensor + // CHECK: [[VAL2:%.+]] = mhlo.while([[ITER_ARG0:.*]] = [[ARG0]]) + %2 = "tf.WhileRegion"(%arg0) ({ + ^cond(%carg0: tensor): + // CHECK: [[VAL3:%.+]] = mhlo.compare LT, [[ITER_ARG0]], [[VAL0]] + %3 = mhlo.compare LT, %carg0, %0 : (tensor, tensor) -> tensor + // CHECK: mhlo.return [[VAL3]] + "tf.Yield"(%3) : (tensor) -> () + }, { + ^body(%barg0: tensor): + // CHECK: [[VAL3:%.+]] = mhlo.add [[ITER_ARG0]], [[VAL1]] + %3 = mhlo.add %barg0, %1 : tensor + // CHECK: [[VAL4:%.+]] = mhlo.add [[ITER_ARG0]], [[VAL3]] + %4 = mhlo.add %barg0, %3 : tensor + // CHECK: mhlo.return [[VAL4]] + "tf.Yield"(%4) : (tensor) -> () + }) {is_stateless = true, parallel_iterations = 10 : i64} : (tensor) -> tensor + // CHECK: return [[VAL2]] + func.return %2 : tensor +} + +// CHECK-LABEL: func @whileRegionMultipleImplicitInputs +func.func @whileRegionMultipleImplicitInputs() { + // CHECK: [[VAL0:%.+]] = mhlo.constant dense<0> + %0 = mhlo.constant dense<0> : tensor + // CHECK: [[VAL1:%.+]] = mhlo.constant dense<-1> + %1 = mhlo.constant dense<-1> : tensor + // CHECK: mhlo.while() + "tf.WhileRegion"() ({ + // CHECK: [[VAL3:%.+]] = mhlo.compare LT, [[VAL0]], [[VAL1]] + %2 = "mhlo.compare"(%0, %1) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor + // CHECK: mhlo.return [[VAL3]] + "tf.Yield"(%2) : (tensor) -> () + }, { + // CHECK: [[VAL3:%.+]] = mhlo.add [[VAL0]], [[VAL1]] + %2 = mhlo.add %0, %1 : tensor + // CHECK: mhlo.return + "tf.Yield"() : () -> () + }) {is_stateless = true, parallel_iterations = 10 : i64} : () -> () + // CHECK: return + func.return +} From f734272ff2cdc3a1ade523a7351391d4046e7987 Mon Sep 17 00:00:00 2001 From: Michael Hu Date: Mon, 30 Sep 2024 16:14:19 -0700 Subject: [PATCH 446/483] Internal visibility change. PiperOrigin-RevId: 680755798 --- tensorflow/core/tfrt/graph_executor/BUILD | 1 + tensorflow/core/tfrt/runtime/BUILD | 1 + tensorflow/core/tfrt/saved_model/BUILD | 1 + 3 files changed, 3 insertions(+) diff --git a/tensorflow/core/tfrt/graph_executor/BUILD b/tensorflow/core/tfrt/graph_executor/BUILD index 7c08bae7645d29..ca75590e998fd2 100644 --- a/tensorflow/core/tfrt/graph_executor/BUILD +++ b/tensorflow/core/tfrt/graph_executor/BUILD @@ -12,6 +12,7 @@ package( package_group( name = "friends", packages = [ + # copybara:uncomment "//cloud/ai/platform/prediction/...", # copybara:uncomment "//learning/brain/experimental/tfrt/native_lowering/...", # copybara:uncomment "//learning/brain/tfrt/...", # copybara:uncomment "//learning/serving/servables/tfrt/...", diff --git a/tensorflow/core/tfrt/runtime/BUILD b/tensorflow/core/tfrt/runtime/BUILD index dfdc99cde0ab9f..630336de3550dd 100644 --- a/tensorflow/core/tfrt/runtime/BUILD +++ b/tensorflow/core/tfrt/runtime/BUILD @@ -15,6 +15,7 @@ package_group( "//tensorflow/core/tfrt/...", "//tensorflow/core/runtime_fallback/...", # copybara:uncomment "//tensorflow_serving/...", + # copybara:uncomment "//cloud/ai/platform/prediction/...", # copybara:uncomment "//learning/brain/experimental/tfrt/...", # copybara:uncomment "//learning/brain/tfrt/...", # copybara:uncomment "//learning/infra/mira/...", diff --git a/tensorflow/core/tfrt/saved_model/BUILD b/tensorflow/core/tfrt/saved_model/BUILD index 5261546e6c6a0f..8673fae94ae2c8 100644 --- a/tensorflow/core/tfrt/saved_model/BUILD +++ b/tensorflow/core/tfrt/saved_model/BUILD @@ -10,6 +10,7 @@ package_group( name = "friends", packages = [ # Authorized users go here. + # copybara:uncomment "//cloud/ai/platform/prediction/...", # copybara:uncomment "//learning/brain/experimental/tfrt/...", # copybara:uncomment "//learning/brain/tfrt/...", # copybara:uncomment "//learning/infra/mira/...", From 605c30f71ae57a7507570069c4a72ac0bc181e02 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 30 Sep 2024 16:28:38 -0700 Subject: [PATCH 447/483] Remove error-based bridge fallback PiperOrigin-RevId: 680760270 --- .../mlir/mlir_graph_optimization_pass.cc | 85 +++---------------- .../mlir/mlir_graph_optimization_pass_test.cc | 64 -------------- 2 files changed, 14 insertions(+), 135 deletions(-) diff --git a/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc b/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc index 8d9802deaeaa66..03d3af71e6ecc5 100644 --- a/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc +++ b/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc @@ -245,17 +245,9 @@ Status MlirFunctionOptimizationPass::Run( timings.ReportAndStop(); if (!module_ref_status.ok()) { - // If at least one pass is enabled, return failure to the caller - // immediately. - if (overall_state == MlirOptimizationPassState::Enabled) { - return module_ref_status.status(); - } - // Do not fail, just keep the original TF graph unchanged in fallback mode. - LOG(WARNING) << "Failed to convert graph to MLIR: " - << module_ref_status.status() - << " , continuing without MlirOptimizationPass because " - "fallback enabled."; - return absl::OkStatus(); + LOG(ERROR) << "Failed to convert graph to MLIR: " + << module_ref_status.status(); + return module_ref_status.status(); } mlir::OwningOpRef module_ref = @@ -278,7 +270,7 @@ Status MlirFunctionOptimizationPass::Run( Status pass_status = absl::OkStatus(); auto pass_state = per_pass_state[per_pass_state_index++]; - if (pass_state == MlirOptimizationPassState::Enabled) { + if (pass_state != MlirOptimizationPassState::Disabled) { VLOG(2) << "Run MLIR graph optimization pass: " << StringRefToView(name); VLOG(2) << "Graph #nodes " << (*graph)->num_nodes() << " #edges " << (*graph)->num_edges(); @@ -293,51 +285,18 @@ Status MlirFunctionOptimizationPass::Run( << (*graph)->num_edges(); is_module_updated = true; } - } else if (pass_state == MlirOptimizationPassState::FallbackEnabled) { - VLOG(2) << "Run MLIR graph optimization pass with fallback: " - << StringRefToView(name); - VLOG(2) << "Graph #nodes " << (*graph)->num_nodes() << " #edges " - << (*graph)->num_edges(); - // Make sure when the pass is FallbackEnabled, it only modifies the MLIR - // module in case of no failures. - auto module_ref_clone = module_ref->clone(); - timings.Reset({kTfMlirCategory, name.str() + "_fallback"}); - pass_status = pass_registration.pass->Run( - function_name, config_proto, module_ref_clone, **graph, *flib_def); - timings.ReportAndStop(); - - if (pass_status.ok()) { - VLOG(2) << "Finished MLIR graph optimization pass with fallback: " - << StringRefToView(name); - VLOG(2) << "Graph #nodes " << (*graph)->num_nodes() << " #edges " - << (*graph)->num_edges(); - module_ref = module_ref_clone; - is_module_updated = true; - } else { - module_ref_clone->destroy(); - } } else { VLOG(2) << "MLIR graph optimization pass: " << StringRefToView(name) << " is disabled and will not be run."; } if (!pass_status.ok()) { - // If pass failed and it is: - // FallbackEnabled - only collect metrics, do not propagate - // error to the caller. - // Enabled - return error back to the caller. - if (pass_state == MlirOptimizationPassState::FallbackEnabled) { - LOG(WARNING) << StringRefToView(name) - << " pass failed, continuing without the pass because the " - "pass has fallback enabled"; - mlir_function_pass_fallback_count->GetCell(kFailure)->IncrementBy(1); - } else if (pass_state == MlirOptimizationPassState::Enabled) { + // If pass failed return error back to the caller. + if (pass_state != MlirOptimizationPassState::Disabled) { + LOG(INFO) << StringRefToView(name) + << " pass failed. Try to disbale it."; return pass_status; } - } else { - if (pass_state == MlirOptimizationPassState::FallbackEnabled) { - mlir_function_pass_fallback_count->GetCell(kSuccess)->IncrementBy(1); - } } if (DEBUG_DATA_DUMPER()->ShouldDump(function_name, kDebugGroupMain) || @@ -417,14 +376,9 @@ Status MlirV1CompatGraphOptimizationPass::Run( auto module_ref_status = ConvertGraphToMlir( **options.graph, debug_info, *options.flib_def, import_config, &context); if (!module_ref_status.ok()) { - if (pass_state == MlirOptimizationPassState::Enabled) { - return module_ref_status.status(); - } - LOG(WARNING) << "Failed to convert graph to MLIR: " - << module_ref_status.status() - << " , continuing without MlirOptimizationPass because " - "fallback enabled."; - return absl::OkStatus(); + LOG(ERROR) << "Failed to convert graph to MLIR: " + << module_ref_status.status(); + return module_ref_status.status(); } mlir::OwningOpRef module_ref = @@ -447,20 +401,9 @@ Status MlirV1CompatGraphOptimizationPass::Run( module_ref_clone->destroy(); if (!pass_status.ok()) { - if (pass_state == MlirOptimizationPassState::Enabled) return pass_status; - - if (pass_state == MlirOptimizationPassState::FallbackEnabled) { - LOG(WARNING) << StringRefToView(name) - << " pass failed, continuing without the pass because the " - "pass has fallback enabled"; - mlir_graph_optimization_pass_fallback_count->GetCell(kFailure) - ->IncrementBy(1); - return absl::OkStatus(); - } - } else { - if (pass_state == MlirOptimizationPassState::FallbackEnabled) { - mlir_graph_optimization_pass_fallback_count->GetCell(kSuccess) - ->IncrementBy(1); + if (pass_state == MlirOptimizationPassState::Disabled) { + LOG(INFO) << StringRefToView(name) << " pass failed. Try to disbale it."; + return pass_status; } } diff --git a/tensorflow/compiler/mlir/mlir_graph_optimization_pass_test.cc b/tensorflow/compiler/mlir/mlir_graph_optimization_pass_test.cc index 64e230f448f3fe..c299d72e8f600e 100644 --- a/tensorflow/compiler/mlir/mlir_graph_optimization_pass_test.cc +++ b/tensorflow/compiler/mlir/mlir_graph_optimization_pass_test.cc @@ -246,70 +246,6 @@ TEST_F(MlirGraphOptimizationPassTest, OptimizationPassFailsNoFallback) { verifyCounters(); } -TEST_F(MlirGraphOptimizationPassTest, OptimizationPassFailsDisabledFallback) { - Init(Status(absl::StatusCode::kAborted, "aborted"), - {MlirOptimizationPassState::Disabled, - MlirOptimizationPassState::FallbackEnabled}); - - // We expect the result graph to be exactly the same as the original graph - // so we define the `graph_` by the following `flib` in this test point - // instead of the way we do in the Init method. - FunctionDefLibrary flib; - *flib.add_function() = XTimesTwo(); - FunctionLibraryDefinition flib_def(OpRegistry::Global(), flib); - graph_ = std::make_unique(flib_def); - - GraphDef original_graph_def; - graph_->ToGraphDef(&original_graph_def); - AddModuleModificationPass(MlirOptimizationPassState::FallbackEnabled, - Status(absl::StatusCode::kAborted, "aborted")); - - EXPECT_EQ( - function_optimization_pass_.Run( - "test_func", device_set_, config_proto_, function_options_, &graph_, - flib_.get(), &control_ret_node_names_, &control_rets_updated_), - absl::OkStatus()); - verifyGraph(original_graph_def); - verifyCounters(); -} - -TEST_F(MlirGraphOptimizationPassTest, OptimizationPassDoesNotFailFallback) { - Init(absl::OkStatus(), {MlirOptimizationPassState::FallbackEnabled}); - - GraphDef original_graph_def; - graph_->ToGraphDef(&original_graph_def); - - AddModuleModificationPass(MlirOptimizationPassState::FallbackEnabled, - absl::OkStatus()); - EXPECT_EQ( - function_optimization_pass_.Run( - "test_func", device_set_, config_proto_, function_options_, &graph_, - flib_.get(), &control_ret_node_names_, &control_rets_updated_), - absl::OkStatus()); - - verifyGraph(original_graph_def, true); - verifyCounters(); -} - -TEST_F(MlirGraphOptimizationPassTest, GraphDoesntConvertUpdatesCounter) { - Init(absl::OkStatus(), {MlirOptimizationPassState::FallbackEnabled}); - - graph_ = std::make_unique(OpRegistry::Global()); - control_ret_node_names_.push_back("foo"); - - AddModuleModificationPass(MlirOptimizationPassState::FallbackEnabled, - absl::OkStatus()); - EXPECT_EQ( - function_optimization_pass_.Run( - "test_func", device_set_, config_proto_, function_options_, &graph_, - flib_.get(), &control_ret_node_names_, &control_rets_updated_), - absl::OkStatus()); - - EXPECT_EQ(mlir_function_pass_graph_conversion_count_.Read(kOk), 0); - EXPECT_EQ(mlir_function_pass_graph_conversion_count_.Read(kInvalidArgument), - 1); -} - TEST(MlirOptimizationPassRegistry, RegisterPassesWithTheSamePriorityFails) { MlirOptimizationPassRegistry::Global().Add( 0, std::make_unique>()); From f9d8824e7fa97aeda0dcd49dc60119cc9d87b651 Mon Sep 17 00:00:00 2001 From: David Dunleavy Date: Mon, 30 Sep 2024 16:29:13 -0700 Subject: [PATCH 448/483] Move `tsl/profiler/utils` to `xla/tsl/profiler/utils` PiperOrigin-RevId: 680760468 --- tensorflow/core/framework/BUILD | 2 +- tensorflow/core/framework/allocator_test.cc | 2 +- tensorflow/core/profiler/BUILD | 2 +- tensorflow/core/profiler/backends/gpu/BUILD | 2 +- .../backends/gpu/device_tracer_test.cc | 2 +- tensorflow/core/profiler/convert/BUILD | 132 +++++------ .../core/profiler/convert/dcn_analysis.cc | 6 +- .../profiler/convert/dcn_analysis_test.cc | 4 +- .../convert/dcn_slack_analysis_combiner.cc | 2 +- tensorflow/core/profiler/convert/dcn_utils.cc | 4 +- tensorflow/core/profiler/convert/dcn_utils.h | 2 +- .../core/profiler/convert/dcn_utils_test.cc | 6 +- .../op_stats_to_input_pipeline_analysis.cc | 4 +- .../convert/op_stats_to_overview_page.cc | 6 +- .../convert/preprocess_single_host_xplane.cc | 6 +- .../profiler/convert/process_megascale_dcn.cc | 4 +- .../core/profiler/convert/repository.cc | 2 +- tensorflow/core/profiler/convert/repository.h | 2 +- .../convert/step_events_to_steps_db.cc | 2 +- .../core/profiler/convert/trace_viewer/BUILD | 10 +- .../convert/trace_viewer/trace_events.cc | 2 +- .../convert/trace_viewer/trace_events.h | 2 +- .../trace_viewer/trace_events_to_json.h | 2 +- .../convert/trace_viewer/trace_events_util.cc | 2 +- .../convert/trace_viewer/trace_events_util.h | 2 +- .../trace_viewer/trace_viewer_visibility.cc | 2 +- .../trace_viewer/trace_viewer_visibility.h | 2 +- .../trace_viewer_visibility_test.cc | 2 +- .../core/profiler/convert/xplane_to_hlo.cc | 2 +- .../convert/xplane_to_kernel_stats_db.cc | 4 +- .../convert/xplane_to_memory_profile.cc | 2 +- .../convert/xplane_to_memory_profile_test.cc | 2 +- .../convert/xplane_to_op_metrics_db.cc | 8 +- .../convert/xplane_to_op_metrics_db.h | 2 +- .../profiler/convert/xplane_to_op_stats.cc | 10 +- .../convert/xplane_to_op_stats_test.cc | 2 +- .../profiler/convert/xplane_to_step_events.cc | 10 +- .../convert/xplane_to_step_events_test.cc | 2 +- .../profiler/convert/xplane_to_step_stats.cc | 2 +- .../convert/xplane_to_tf_data_stats.cc | 8 +- .../convert/xplane_to_tf_functions.cc | 4 +- .../convert/xplane_to_tf_functions_test.cc | 2 +- .../convert/xplane_to_trace_container.cc | 12 +- .../convert/xspace_to_dcn_slack_analysis.cc | 14 +- .../convert/xspace_to_dcn_slack_analysis.h | 4 +- tensorflow/core/profiler/lib/BUILD | 2 +- tensorflow/core/profiler/lib/traceme.h | 2 +- tensorflow/core/profiler/rpc/BUILD | 4 +- tensorflow/core/profiler/utils/BUILD | 60 ++--- tensorflow/core/profiler/utils/cost_utils.cc | 2 +- .../core/profiler/utils/derived_timeline.cc | 14 +- .../core/profiler/utils/derived_timeline.h | 4 +- .../profiler/utils/derived_timeline_test.cc | 6 +- .../core/profiler/utils/device_caps_utils.cc | 2 +- tensorflow/core/profiler/utils/event_span.cc | 2 +- tensorflow/core/profiler/utils/event_span.h | 2 +- .../profiler/utils/hardware_type_utils.cc | 2 +- .../utils/hardware_type_utils_test.cc | 2 +- .../core/profiler/utils/hlo_proto_map.cc | 2 +- .../core/profiler/utils/host_offload_utils.cc | 2 +- tensorflow/core/profiler/utils/math_utils.h | 2 +- .../profiler/utils/op_metrics_db_utils.cc | 6 +- .../core/profiler/utils/op_metrics_db_utils.h | 2 +- tensorflow/core/profiler/utils/op_utils.cc | 2 +- tensorflow/core/profiler/utils/op_utils.h | 2 +- .../core/profiler/utils/step_intersection.cc | 2 +- tensorflow/core/profiler/utils/trace_utils.h | 2 +- .../core/profiler/utils/xplane_builder.h | 2 +- .../core/profiler/utils/xplane_schema.h | 2 +- .../core/profiler/utils/xplane_test_utils.h | 2 +- tensorflow/core/profiler/utils/xplane_utils.h | 2 +- .../core/profiler/utils/xplane_visitor.h | 2 +- tensorflow/python/profiler/internal/BUILD | 2 +- .../profiler/internal/profiler_pywrap_impl.cc | 2 +- .../third_party/tsl/tsl/platform/cloud/BUILD | 2 +- .../third_party/tsl/tsl/profiler/lib/BUILD | 8 +- .../tsl/tsl/profiler/lib/profiler_session.cc | 2 +- .../tsl/tsl/profiler/lib/traceme.h | 4 +- .../xla/xla/backends/profiler/cpu/BUILD | 24 +- .../xla/backends/profiler/cpu/host_tracer.cc | 6 +- .../backends/profiler/cpu/host_tracer_test.cc | 8 +- .../profiler/cpu/metadata_collector.cc | 6 +- .../backends/profiler/cpu/metadata_utils.h | 4 +- .../xla/xla/backends/profiler/gpu/BUILD | 34 +-- .../profiler/gpu/cupti_buffer_events.h | 4 +- .../backends/profiler/gpu/cupti_collector.cc | 10 +- .../profiler/gpu/cupti_error_manager_test.cc | 2 +- .../xla/backends/profiler/gpu/cupti_tracer.cc | 2 +- .../profiler/gpu/device_tracer_cuda.cc | 2 +- .../profiler/gpu/device_tracer_rocm.cc | 8 +- .../backends/profiler/gpu/rocm_collector.cc | 8 +- .../backends/profiler/gpu/rocm_collector.h | 2 +- .../xla/backends/profiler/gpu/rocm_tracer.cc | 2 +- .../xla/xla/backends/profiler/plugin/BUILD | 4 +- .../plugin/plugin_tracer_impl_test.cc | 4 +- .../xla/xla/backends/profiler/tpu/BUILD | 2 +- .../xla/backends/profiler/tpu/tpu_tracer.cc | 2 +- third_party/xla/xla/python/BUILD | 16 +- .../xla/xla/python/ifrt_proxy/client/BUILD | 2 +- .../python/ifrt_proxy/client/rpc_helper.cc | 2 +- .../xla/xla/python/profiler/internal/BUILD | 8 +- .../python/profiler/internal/python_hooks.cc | 8 +- .../python/xplane_to_profile_instructions.cc | 10 +- .../xplane_to_profile_instructions_test.cc | 6 +- .../xla/xla/service/gpu/fusions/triton/BUILD | 4 +- .../fusions/triton/kernel_name_tracer_cuda.cc | 2 +- third_party/xla/xla/tsl/lib/gtl/BUILD | 2 +- .../xla/xla/tsl/profiler/backends/cpu/BUILD | 22 +- .../backends/cpu/host_tracer_utils.cc | 8 +- .../backends/cpu/threadpool_listener.cc | 4 +- .../profiler/backends/cpu/traceme_recorder.cc | 4 +- .../backends/cpu/traceme_recorder_test.cc | 4 +- .../xla/xla/tsl/profiler/convert/BUILD | 26 +-- .../post_process_single_host_xplane.cc | 6 +- .../profiler/convert/trace_events_to_json.cc | 4 +- .../convert/xplane_to_trace_events.cc | 10 +- .../convert/xplane_to_trace_events_test.cc | 6 +- third_party/xla/xla/tsl/profiler/rpc/BUILD | 8 +- .../xla/xla/tsl/profiler/rpc/client/BUILD | 10 +- .../profiler/rpc/client/capture_profile.cc | 2 +- .../client/remote_profiler_session_manager.cc | 2 +- .../tsl/profiler/rpc/client/save_profile.cc | 2 +- .../tsl/profiler/rpc/profiler_service_impl.cc | 8 +- .../tsl => xla}/tsl/profiler/utils/BUILD | 210 +++++++++--------- .../tsl/profiler/utils/buffer_pool.cc | 2 +- .../tsl/profiler/utils/buffer_pool.h | 6 +- .../tsl/profiler/utils/buffer_pool_test.cc | 2 +- .../tsl/profiler/utils/device_utils.cc | 4 +- .../tsl/profiler/utils/device_utils.h | 6 +- .../tsl/profiler/utils/device_utils_test.cc | 4 +- .../tsl/profiler/utils/file_system_utils.h | 6 +- .../tsl/profiler/utils/format_utils.h | 6 +- .../tsl/profiler/utils/group_events.cc | 12 +- .../tsl/profiler/utils/group_events.h | 10 +- .../tsl/profiler/utils/group_events_test.cc | 12 +- .../tsl/profiler/utils/lock_free_queue.h | 8 +- .../profiler/utils/lock_free_queue_test.cc | 2 +- .../tsl/profiler/utils/math_utils.h | 6 +- .../tsl => xla}/tsl/profiler/utils/no_init.h | 6 +- .../tsl/profiler/utils/parse_annotation.cc | 2 +- .../tsl/profiler/utils/parse_annotation.h | 6 +- .../profiler/utils/parse_annotation_test.cc | 2 +- .../tsl/profiler/utils/per_thread.h | 6 +- .../tsl/profiler/utils/per_thread_test.cc | 2 +- .../tsl/profiler/utils/preprocess_xplane.cc | 6 +- .../tsl/profiler/utils/preprocess_xplane.h | 16 +- .../profiler/utils/preprocess_xplane_test.cc | 12 +- .../tsl/profiler/utils/session_manager.cc | 2 +- .../tsl/profiler/utils/session_manager.h | 6 +- .../tsl/profiler/utils/tf_op_utils.cc | 2 +- .../tsl/profiler/utils/tf_op_utils.h | 6 +- .../tsl/profiler/utils/tf_op_utils_test.cc | 2 +- .../tsl/profiler/utils/tf_xplane_visitor.h | 10 +- .../tsl/profiler/utils/time_utils.cc | 2 +- .../tsl/profiler/utils/time_utils.h | 8 +- .../tsl => xla}/tsl/profiler/utils/timespan.h | 8 +- .../tsl/profiler/utils/timespan_test.cc | 2 +- .../tsl/profiler/utils/timestamp_utils.cc | 8 +- .../tsl/profiler/utils/timestamp_utils.h | 6 +- .../profiler/utils/timestamp_utils_test.cc | 8 +- .../tsl/profiler/utils/tpu_xplane_utils.cc | 6 +- .../tsl/profiler/utils/tpu_xplane_utils.h | 6 +- .../profiler/utils/tpu_xplane_utils_test.cc | 8 +- .../tsl/profiler/utils/trace_utils.h | 6 +- .../tsl/profiler/utils/xplane_builder.cc | 6 +- .../tsl/profiler/utils/xplane_builder.h | 10 +- .../tsl/profiler/utils/xplane_builder_test.cc | 4 +- .../tsl/profiler/utils/xplane_mutators.h | 8 +- .../tsl/profiler/utils/xplane_schema.cc | 4 +- .../tsl/profiler/utils/xplane_schema.h | 6 +- .../tsl/profiler/utils/xplane_test_utils.cc | 8 +- .../tsl/profiler/utils/xplane_test_utils.h | 10 +- .../tsl/profiler/utils/xplane_utils.cc | 14 +- .../tsl/profiler/utils/xplane_utils.h | 12 +- .../tsl/profiler/utils/xplane_utils_test.cc | 12 +- .../tsl/profiler/utils/xplane_visitor.cc | 2 +- .../tsl/profiler/utils/xplane_visitor.h | 8 +- third_party/xla/xla/xla.bzl | 2 +- 178 files changed, 670 insertions(+), 670 deletions(-) rename third_party/xla/{third_party/tsl => xla}/tsl/profiler/utils/BUILD (69%) rename third_party/xla/{third_party/tsl => xla}/tsl/profiler/utils/buffer_pool.cc (98%) rename third_party/xla/{third_party/tsl => xla}/tsl/profiler/utils/buffer_pool.h (92%) rename third_party/xla/{third_party/tsl => xla}/tsl/profiler/utils/buffer_pool_test.cc (98%) rename third_party/xla/{third_party/tsl => xla}/tsl/profiler/utils/device_utils.cc (92%) rename third_party/xla/{third_party/tsl => xla}/tsl/profiler/utils/device_utils.h (85%) rename third_party/xla/{third_party/tsl => xla}/tsl/profiler/utils/device_utils_test.cc (93%) rename third_party/xla/{third_party/tsl => xla}/tsl/profiler/utils/file_system_utils.h (91%) rename third_party/xla/{third_party/tsl => xla}/tsl/profiler/utils/format_utils.h (91%) rename third_party/xla/{third_party/tsl => xla}/tsl/profiler/utils/group_events.cc (99%) rename third_party/xla/{third_party/tsl => xla}/tsl/profiler/utils/group_events.h (97%) rename third_party/xla/{third_party/tsl => xla}/tsl/profiler/utils/group_events_test.cc (99%) rename third_party/xla/{third_party/tsl => xla}/tsl/profiler/utils/lock_free_queue.h (97%) rename third_party/xla/{third_party/tsl => xla}/tsl/profiler/utils/lock_free_queue_test.cc (99%) rename third_party/xla/{third_party/tsl => xla}/tsl/profiler/utils/math_utils.h (94%) rename third_party/xla/{third_party/tsl => xla}/tsl/profiler/utils/no_init.h (89%) rename third_party/xla/{third_party/tsl => xla}/tsl/profiler/utils/parse_annotation.cc (98%) rename third_party/xla/{third_party/tsl => xla}/tsl/profiler/utils/parse_annotation.h (89%) rename third_party/xla/{third_party/tsl => xla}/tsl/profiler/utils/parse_annotation_test.cc (99%) rename third_party/xla/{third_party/tsl => xla}/tsl/profiler/utils/per_thread.h (96%) rename third_party/xla/{third_party/tsl => xla}/tsl/profiler/utils/per_thread_test.cc (99%) rename third_party/xla/{third_party/tsl => xla}/tsl/profiler/utils/preprocess_xplane.cc (97%) rename third_party/xla/{third_party/tsl => xla}/tsl/profiler/utils/preprocess_xplane.h (98%) rename third_party/xla/{third_party/tsl => xla}/tsl/profiler/utils/preprocess_xplane_test.cc (97%) rename third_party/xla/{third_party/tsl => xla}/tsl/profiler/utils/session_manager.cc (99%) rename third_party/xla/{third_party/tsl => xla}/tsl/profiler/utils/session_manager.h (91%) rename third_party/xla/{third_party/tsl => xla}/tsl/profiler/utils/tf_op_utils.cc (99%) rename third_party/xla/{third_party/tsl => xla}/tsl/profiler/utils/tf_op_utils.h (97%) rename third_party/xla/{third_party/tsl => xla}/tsl/profiler/utils/tf_op_utils_test.cc (99%) rename third_party/xla/{third_party/tsl => xla}/tsl/profiler/utils/tf_xplane_visitor.h (78%) rename third_party/xla/{third_party/tsl => xla}/tsl/profiler/utils/time_utils.cc (96%) rename third_party/xla/{third_party/tsl => xla}/tsl/profiler/utils/time_utils.h (87%) rename third_party/xla/{third_party/tsl => xla}/tsl/profiler/utils/timespan.h (96%) rename third_party/xla/{third_party/tsl => xla}/tsl/profiler/utils/timespan_test.cc (98%) rename third_party/xla/{third_party/tsl => xla}/tsl/profiler/utils/timestamp_utils.cc (89%) rename third_party/xla/{third_party/tsl => xla}/tsl/profiler/utils/timestamp_utils.h (86%) rename third_party/xla/{third_party/tsl => xla}/tsl/profiler/utils/timestamp_utils_test.cc (86%) rename third_party/xla/{third_party/tsl => xla}/tsl/profiler/utils/tpu_xplane_utils.cc (92%) rename third_party/xla/{third_party/tsl => xla}/tsl/profiler/utils/tpu_xplane_utils.h (89%) rename third_party/xla/{third_party/tsl => xla}/tsl/profiler/utils/tpu_xplane_utils_test.cc (93%) rename third_party/xla/{third_party/tsl => xla}/tsl/profiler/utils/trace_utils.h (95%) rename third_party/xla/{third_party/tsl => xla}/tsl/profiler/utils/xplane_builder.cc (97%) rename third_party/xla/{third_party/tsl => xla}/tsl/profiler/utils/xplane_builder.h (98%) rename third_party/xla/{third_party/tsl => xla}/tsl/profiler/utils/xplane_builder_test.cc (97%) rename third_party/xla/{third_party/tsl => xla}/tsl/profiler/utils/xplane_mutators.h (88%) rename third_party/xla/{third_party/tsl => xla}/tsl/profiler/utils/xplane_schema.cc (99%) rename third_party/xla/{third_party/tsl => xla}/tsl/profiler/utils/xplane_schema.h (99%) rename third_party/xla/{third_party/tsl => xla}/tsl/profiler/utils/xplane_test_utils.cc (95%) rename third_party/xla/{third_party/tsl => xla}/tsl/profiler/utils/xplane_test_utils.h (88%) rename third_party/xla/{third_party/tsl => xla}/tsl/profiler/utils/xplane_utils.cc (98%) rename third_party/xla/{third_party/tsl => xla}/tsl/profiler/utils/xplane_utils.h (96%) rename third_party/xla/{third_party/tsl => xla}/tsl/profiler/utils/xplane_utils_test.cc (98%) rename third_party/xla/{third_party/tsl => xla}/tsl/profiler/utils/xplane_visitor.cc (99%) rename third_party/xla/{third_party/tsl => xla}/tsl/profiler/utils/xplane_visitor.h (98%) diff --git a/tensorflow/core/framework/BUILD b/tensorflow/core/framework/BUILD index 5931211307f39b..263ac02be7b6aa 100644 --- a/tensorflow/core/framework/BUILD +++ b/tensorflow/core/framework/BUILD @@ -1410,7 +1410,7 @@ tf_cc_tests( "@com_google_absl//absl/strings", "@com_google_googletest//:gtest", "@eigen_archive//:eigen3", - "@local_tsl//tsl/profiler/utils:xplane_utils", + "@local_xla//xla/tsl/profiler/utils:xplane_utils", ], ) diff --git a/tensorflow/core/framework/allocator_test.cc b/tensorflow/core/framework/allocator_test.cc index 6557a4cec7598e..f1dd62af9e4a3c 100644 --- a/tensorflow/core/framework/allocator_test.cc +++ b/tensorflow/core/framework/allocator_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "xla/tsl/profiler/utils/xplane_utils.h" #include "tensorflow/core/framework/typed_allocator.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/mem.h" @@ -28,7 +29,6 @@ limitations under the License. #include "tensorflow/core/profiler/protobuf/xplane.pb.h" #include "tensorflow/core/profiler/utils/xplane_schema.h" #include "tensorflow/core/profiler/utils/xplane_visitor.h" -#include "tsl/profiler/utils/xplane_utils.h" namespace tensorflow { diff --git a/tensorflow/core/profiler/BUILD b/tensorflow/core/profiler/BUILD index 45b60e7f264704..8b26a517b4bc0b 100644 --- a/tensorflow/core/profiler/BUILD +++ b/tensorflow/core/profiler/BUILD @@ -105,10 +105,10 @@ cc_library( deps = [ "//tensorflow/core/profiler/lib:profiler_factory_impl", "//tensorflow/core/profiler/lib:profiler_session_impl", - "@local_tsl//tsl/profiler/utils:time_utils_impl", "@local_xla//xla/tsl/profiler/backends/cpu:annotation_stack_impl", "@local_xla//xla/tsl/profiler/backends/cpu:threadpool_listener", "@local_xla//xla/tsl/profiler/backends/cpu:traceme_recorder_impl", + "@local_xla//xla/tsl/profiler/utils:time_utils_impl", ], alwayslink = True, ) diff --git a/tensorflow/core/profiler/backends/gpu/BUILD b/tensorflow/core/profiler/backends/gpu/BUILD index 94afc773e15835..8c082a82dcdcbf 100644 --- a/tensorflow/core/profiler/backends/gpu/BUILD +++ b/tensorflow/core/profiler/backends/gpu/BUILD @@ -42,10 +42,10 @@ tf_cuda_cc_test( "//tensorflow/core/profiler/utils:xplane_utils", "//tensorflow/core/profiler/utils:xplane_visitor", "@com_google_absl//absl/strings", - "@local_tsl//tsl/profiler/utils:tf_xplane_visitor", "@local_xla//xla/backends/profiler/gpu:cuda_test", "@local_xla//xla/backends/profiler/gpu:cupti_collector", "@local_xla//xla/backends/profiler/gpu:device_tracer", + "@local_xla//xla/tsl/profiler/utils:tf_xplane_visitor", ] + if_cuda_is_configured([ "@local_config_cuda//cuda:cuda_headers", "@local_config_cuda//cuda:cupti_headers", diff --git a/tensorflow/core/profiler/backends/gpu/device_tracer_test.cc b/tensorflow/core/profiler/backends/gpu/device_tracer_test.cc index 4030e1608a4ea7..e0b7c9b2e060a6 100644 --- a/tensorflow/core/profiler/backends/gpu/device_tracer_test.cc +++ b/tensorflow/core/profiler/backends/gpu/device_tracer_test.cc @@ -28,6 +28,7 @@ limitations under the License. #include "third_party/gpus/cuda/include/cuda_runtime.h" #include "xla/backends/profiler/gpu/cupti_collector.h" #endif // GOOGLE_CUDA +#include "xla/tsl/profiler/utils/tf_xplane_visitor.h" #include "tensorflow/core/common_runtime/direct_session.h" #include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/framework/graph.pb.h" @@ -50,7 +51,6 @@ limitations under the License. #include "tensorflow/core/profiler/utils/xplane_visitor.h" #include "tensorflow/core/public/session_options.h" #include "tensorflow/core/util/device_name_utils.h" -#include "tsl/profiler/utils/tf_xplane_visitor.h" // TODO(b/186367334) #define CUPTI_NVBUG_3299481_WAR (10000 <= CUDA_VERSION && CUDA_VERSION < 11000) diff --git a/tensorflow/core/profiler/convert/BUILD b/tensorflow/core/profiler/convert/BUILD index b1cafe647a323e..fff53bce2aa3c8 100644 --- a/tensorflow/core/profiler/convert/BUILD +++ b/tensorflow/core/profiler/convert/BUILD @@ -30,10 +30,10 @@ cc_library( "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", - "@local_tsl//tsl/profiler/utils:tf_op_utils", - "@local_tsl//tsl/profiler/utils:tf_xplane_visitor", - "@local_tsl//tsl/profiler/utils:timespan", - "@local_tsl//tsl/profiler/utils:xplane_schema", + "@local_xla//xla/tsl/profiler/utils:tf_op_utils", + "@local_xla//xla/tsl/profiler/utils:tf_xplane_visitor", + "@local_xla//xla/tsl/profiler/utils:timespan", + "@local_xla//xla/tsl/profiler/utils:xplane_schema", ], ) @@ -139,9 +139,9 @@ cc_library( "//tensorflow/core/profiler/utils:xplane_schema", "//tensorflow/core/profiler/utils:xplane_utils", "@com_google_absl//absl/strings", - "@local_tsl//tsl/profiler/utils:format_utils", - "@local_tsl//tsl/profiler/utils:tf_op_utils", - "@local_tsl//tsl/profiler/utils:tf_xplane_visitor", + "@local_xla//xla/tsl/profiler/utils:format_utils", + "@local_xla//xla/tsl/profiler/utils:tf_op_utils", + "@local_xla//xla/tsl/profiler/utils:tf_xplane_visitor", ], ) @@ -240,8 +240,8 @@ cc_library( "//tensorflow/core/profiler/utils:op_metrics_db_utils", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", - "@local_tsl//tsl/profiler/utils:format_utils", - "@local_tsl//tsl/profiler/utils:tf_op_utils", + "@local_xla//xla/tsl/profiler/utils:format_utils", + "@local_xla//xla/tsl/profiler/utils:tf_op_utils", "@local_xla//xla/tsl/util:stats_calculator_portable", ], ) @@ -299,7 +299,7 @@ cc_library( "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log", - "@local_tsl//tsl/profiler/utils:timespan", + "@local_xla//xla/tsl/profiler/utils:timespan", ], ) @@ -333,11 +333,11 @@ cc_library( "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings", "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_tsl//tsl/profiler/utils:math_utils", - "@local_tsl//tsl/profiler/utils:tf_xplane_visitor", - "@local_tsl//tsl/profiler/utils:tpu_xplane_utils", - "@local_tsl//tsl/profiler/utils:xplane_schema", - "@local_tsl//tsl/profiler/utils:xplane_utils", + "@local_xla//xla/tsl/profiler/utils:math_utils", + "@local_xla//xla/tsl/profiler/utils:tf_xplane_visitor", + "@local_xla//xla/tsl/profiler/utils:tpu_xplane_utils", + "@local_xla//xla/tsl/profiler/utils:xplane_schema", + "@local_xla//xla/tsl/profiler/utils:xplane_utils", ], ) @@ -384,8 +384,8 @@ tf_cc_test( "//tensorflow/core/profiler/utils:xplane_schema", "//tensorflow/core/profiler/utils:xplane_test_utils", "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_tsl//tsl/profiler/utils:group_events", - "@local_tsl//tsl/profiler/utils:xplane_schema", + "@local_xla//xla/tsl/profiler/utils:group_events", + "@local_xla//xla/tsl/profiler/utils:xplane_schema", ], ) @@ -406,11 +406,11 @@ cc_library( "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", - "@local_tsl//tsl/profiler/utils:tf_op_utils", - "@local_tsl//tsl/profiler/utils:tf_xplane_visitor", - "@local_tsl//tsl/profiler/utils:timespan", - "@local_tsl//tsl/profiler/utils:tpu_xplane_utils", - "@local_tsl//tsl/profiler/utils:xplane_schema", + "@local_xla//xla/tsl/profiler/utils:tf_op_utils", + "@local_xla//xla/tsl/profiler/utils:tf_xplane_visitor", + "@local_xla//xla/tsl/profiler/utils:timespan", + "@local_xla//xla/tsl/profiler/utils:tpu_xplane_utils", + "@local_xla//xla/tsl/profiler/utils:xplane_schema", ], ) @@ -429,7 +429,7 @@ tf_cc_test( "//tensorflow/core/profiler/utils:xplane_schema", "//tensorflow/core/profiler/utils:xplane_test_utils", "@com_google_absl//absl/container:flat_hash_map", - "@local_tsl//tsl/profiler/utils:group_events", + "@local_xla//xla/tsl/profiler/utils:group_events", ], ) @@ -448,8 +448,8 @@ cc_library( "//tensorflow/core/profiler/utils:trace_utils", "//tensorflow/core/profiler/utils:xplane_visitor", "@com_google_absl//absl/strings", - "@local_tsl//tsl/profiler/utils:tf_op_utils", - "@local_tsl//tsl/profiler/utils:tf_xplane_visitor", + "@local_xla//xla/tsl/profiler/utils:tf_op_utils", + "@local_xla//xla/tsl/profiler/utils:tf_xplane_visitor", ], ) @@ -489,8 +489,8 @@ cc_library( "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", - "@local_tsl//tsl/profiler/utils:tf_xplane_visitor", - "@local_tsl//tsl/profiler/utils:timespan", + "@local_xla//xla/tsl/profiler/utils:tf_xplane_visitor", + "@local_xla//xla/tsl/profiler/utils:timespan", ], ) @@ -511,7 +511,7 @@ tf_cc_test( "//tensorflow/core/profiler/utils:xplane_utils", "//tensorflow/core/profiler/utils:xplane_visitor", "@com_google_absl//absl/strings", - "@local_tsl//tsl/profiler/utils:tf_xplane_visitor", + "@local_xla//xla/tsl/profiler/utils:tf_xplane_visitor", ], ) @@ -537,7 +537,7 @@ cc_library( "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_tsl//tsl/profiler/utils:tf_xplane_visitor", + "@local_xla//xla/tsl/profiler/utils:tf_xplane_visitor", ], ) @@ -556,7 +556,7 @@ tf_cc_test( "//tensorflow/core/profiler/utils:xplane_schema", "//tensorflow/core/profiler/utils:xplane_test_utils", "@com_google_absl//absl/strings", - "@local_tsl//tsl/profiler/utils:group_events", + "@local_xla//xla/tsl/profiler/utils:group_events", ], ) @@ -610,9 +610,9 @@ cc_library( "//tensorflow/core/profiler/utils:derived_timeline", "//tensorflow/core/profiler/utils:xplane_schema", "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_tsl//tsl/profiler/utils:group_events", - "@local_tsl//tsl/profiler/utils:preprocess_xplane", - "@local_tsl//tsl/profiler/utils:xplane_utils", + "@local_xla//xla/tsl/profiler/utils:group_events", + "@local_xla//xla/tsl/profiler/utils:preprocess_xplane", + "@local_xla//xla/tsl/profiler/utils:xplane_utils", ], ) @@ -679,10 +679,10 @@ cc_library( "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_tsl//tsl/profiler/utils:group_events", - "@local_tsl//tsl/profiler/utils:tf_op_utils", - "@local_tsl//tsl/profiler/utils:tf_xplane_visitor", - "@local_tsl//tsl/profiler/utils:timespan", + "@local_xla//xla/tsl/profiler/utils:group_events", + "@local_xla//xla/tsl/profiler/utils:tf_op_utils", + "@local_xla//xla/tsl/profiler/utils:tf_xplane_visitor", + "@local_xla//xla/tsl/profiler/utils:timespan", ], ) @@ -722,7 +722,7 @@ cc_library( "//tensorflow/core/profiler/utils:xplane_visitor", "@com_google_absl//absl/strings", "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_tsl//tsl/profiler/utils:tf_xplane_visitor", + "@local_xla//xla/tsl/profiler/utils:tf_xplane_visitor", ], ) @@ -800,8 +800,8 @@ cc_library( "//tensorflow/core/profiler/utils:hlo_proto_map", "@com_google_absl//absl/strings", "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_tsl//tsl/profiler/utils:file_system_utils", "@local_xla//xla/service:hlo_proto_cc", + "@local_xla//xla/tsl/profiler/utils:file_system_utils", ], ) @@ -886,7 +886,7 @@ cc_library( "@com_google_absl//absl/strings", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_tsl//tsl/profiler/utils:file_system_utils", + "@local_xla//xla/tsl/profiler/utils:file_system_utils", ], ) @@ -949,12 +949,12 @@ cc_library( "//tensorflow/core/profiler/utils:xplane_utils", "@com_google_absl//absl/strings", "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_tsl//tsl/profiler/utils:tf_xplane_visitor", - "@local_tsl//tsl/profiler/utils:timespan", - "@local_tsl//tsl/profiler/utils:trace_utils", - "@local_tsl//tsl/profiler/utils:xplane_schema", - "@local_tsl//tsl/profiler/utils:xplane_utils", - "@local_tsl//tsl/profiler/utils:xplane_visitor", + "@local_xla//xla/tsl/profiler/utils:tf_xplane_visitor", + "@local_xla//xla/tsl/profiler/utils:timespan", + "@local_xla//xla/tsl/profiler/utils:trace_utils", + "@local_xla//xla/tsl/profiler/utils:xplane_schema", + "@local_xla//xla/tsl/profiler/utils:xplane_utils", + "@local_xla//xla/tsl/profiler/utils:xplane_visitor", ], ) @@ -965,8 +965,8 @@ cc_library( visibility = ["//visibility:public"], deps = [ "@com_google_absl//absl/strings", - "@local_tsl//tsl/profiler/utils:xplane_schema", - "@local_tsl//tsl/profiler/utils:xplane_visitor", + "@local_xla//xla/tsl/profiler/utils:xplane_schema", + "@local_xla//xla/tsl/profiler/utils:xplane_visitor", ], ) @@ -976,9 +976,9 @@ tf_cc_test( deps = [ ":dcn_utils", "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/profiler/utils:tf_xplane_visitor", - "@local_tsl//tsl/profiler/utils:xplane_builder", - "@local_tsl//tsl/profiler/utils:xplane_schema", + "@local_xla//xla/tsl/profiler/utils:tf_xplane_visitor", + "@local_xla//xla/tsl/profiler/utils:xplane_builder", + "@local_xla//xla/tsl/profiler/utils:xplane_schema", ], ) @@ -993,9 +993,9 @@ cc_library( "//tensorflow/core/profiler/utils:xplane_visitor", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", - "@local_tsl//tsl/profiler/utils:math_utils", - "@local_tsl//tsl/profiler/utils:tpu_xplane_utils", - "@local_tsl//tsl/profiler/utils:xplane_schema", + "@local_xla//xla/tsl/profiler/utils:math_utils", + "@local_xla//xla/tsl/profiler/utils:tpu_xplane_utils", + "@local_xla//xla/tsl/profiler/utils:xplane_schema", ], ) @@ -1007,8 +1007,8 @@ cc_library( ":dcn_analysis", "//tensorflow/core/profiler/utils:xplane_utils", "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_tsl//tsl/profiler/utils:tf_xplane_visitor", - "@local_tsl//tsl/profiler/utils:tpu_xplane_utils", + "@local_xla//xla/tsl/profiler/utils:tf_xplane_visitor", + "@local_xla//xla/tsl/profiler/utils:tpu_xplane_utils", ], ) @@ -1019,8 +1019,8 @@ tf_cc_test( ":dcn_analysis", ":dcn_utils", "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/profiler/utils:tf_xplane_visitor", - "@local_tsl//tsl/profiler/utils:xplane_schema", + "@local_xla//xla/tsl/profiler/utils:tf_xplane_visitor", + "@local_xla//xla/tsl/profiler/utils:xplane_schema", ], ) @@ -1042,17 +1042,17 @@ cc_library( "@local_tsl//tsl/platform:regexp", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_tsl//tsl/profiler/utils:math_utils", - "@local_tsl//tsl/profiler/utils:tf_xplane_visitor", - "@local_tsl//tsl/profiler/utils:timespan", - "@local_tsl//tsl/profiler/utils:tpu_xplane_utils", - "@local_tsl//tsl/profiler/utils:xplane_schema", - "@local_tsl//tsl/profiler/utils:xplane_utils", - "@local_tsl//tsl/profiler/utils:xplane_visitor", "@local_xla//xla:shape_util", "@local_xla//xla:side_effect_util", "@local_xla//xla:xla_data_proto_cc", "@local_xla//xla/hlo/ir:hlo", + "@local_xla//xla/tsl/profiler/utils:math_utils", + "@local_xla//xla/tsl/profiler/utils:tf_xplane_visitor", + "@local_xla//xla/tsl/profiler/utils:timespan", + "@local_xla//xla/tsl/profiler/utils:tpu_xplane_utils", + "@local_xla//xla/tsl/profiler/utils:xplane_schema", + "@local_xla//xla/tsl/profiler/utils:xplane_utils", + "@local_xla//xla/tsl/profiler/utils:xplane_visitor", ], ) @@ -1063,7 +1063,7 @@ cc_library( deps = [ "//tensorflow/core/profiler/protobuf:dcn_slack_analysis_proto_cc", "@com_google_absl//absl/container:flat_hash_map", - "@local_tsl//tsl/profiler/utils:math_utils", + "@local_xla//xla/tsl/profiler/utils:math_utils", ], ) diff --git a/tensorflow/core/profiler/convert/dcn_analysis.cc b/tensorflow/core/profiler/convert/dcn_analysis.cc index fd54adf88d3fb3..5c58cda325cf33 100644 --- a/tensorflow/core/profiler/convert/dcn_analysis.cc +++ b/tensorflow/core/profiler/convert/dcn_analysis.cc @@ -22,11 +22,11 @@ limitations under the License. #include #include "absl/strings/string_view.h" +#include "xla/tsl/profiler/utils/math_utils.h" +#include "xla/tsl/profiler/utils/tpu_xplane_utils.h" +#include "xla/tsl/profiler/utils/xplane_schema.h" #include "tensorflow/core/profiler/convert/dcn_utils.h" #include "tensorflow/core/profiler/utils/xplane_builder.h" -#include "tsl/profiler/utils/math_utils.h" -#include "tsl/profiler/utils/tpu_xplane_utils.h" -#include "tsl/profiler/utils/xplane_schema.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/dcn_analysis_test.cc b/tensorflow/core/profiler/convert/dcn_analysis_test.cc index f89df221444d4b..345b3752b637aa 100644 --- a/tensorflow/core/profiler/convert/dcn_analysis_test.cc +++ b/tensorflow/core/profiler/convert/dcn_analysis_test.cc @@ -19,9 +19,9 @@ limitations under the License. #include #include +#include "xla/tsl/profiler/utils/tf_xplane_visitor.h" +#include "xla/tsl/profiler/utils/xplane_schema.h" #include "tensorflow/core/profiler/convert/dcn_utils.h" -#include "tsl/profiler/utils/tf_xplane_visitor.h" -#include "tsl/profiler/utils/xplane_schema.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/dcn_slack_analysis_combiner.cc b/tensorflow/core/profiler/convert/dcn_slack_analysis_combiner.cc index 32a91c836c1497..6806742f5cec8b 100644 --- a/tensorflow/core/profiler/convert/dcn_slack_analysis_combiner.cc +++ b/tensorflow/core/profiler/convert/dcn_slack_analysis_combiner.cc @@ -14,8 +14,8 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/profiler/convert/dcn_slack_analysis_combiner.h" +#include "xla/tsl/profiler/utils/math_utils.h" #include "tensorflow/core/profiler/protobuf/dcn_slack_analysis.pb.h" -#include "tsl/profiler/utils/math_utils.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/dcn_utils.cc b/tensorflow/core/profiler/convert/dcn_utils.cc index 98ecf7dc14106e..7b41905265c385 100644 --- a/tensorflow/core/profiler/convert/dcn_utils.cc +++ b/tensorflow/core/profiler/convert/dcn_utils.cc @@ -17,8 +17,8 @@ limitations under the License. #include "absl/strings/match.h" #include "absl/strings/string_view.h" -#include "tsl/profiler/utils/xplane_schema.h" -#include "tsl/profiler/utils/xplane_visitor.h" +#include "xla/tsl/profiler/utils/xplane_schema.h" +#include "xla/tsl/profiler/utils/xplane_visitor.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/dcn_utils.h b/tensorflow/core/profiler/convert/dcn_utils.h index 1149daa4b62be7..e0dd3a174df919 100644 --- a/tensorflow/core/profiler/convert/dcn_utils.h +++ b/tensorflow/core/profiler/convert/dcn_utils.h @@ -18,7 +18,7 @@ limitations under the License. #include -#include "tsl/profiler/utils/xplane_visitor.h" +#include "xla/tsl/profiler/utils/xplane_visitor.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/dcn_utils_test.cc b/tensorflow/core/profiler/convert/dcn_utils_test.cc index 27c74d79c66407..1d31fcd1502a6d 100644 --- a/tensorflow/core/profiler/convert/dcn_utils_test.cc +++ b/tensorflow/core/profiler/convert/dcn_utils_test.cc @@ -20,9 +20,9 @@ limitations under the License. #include #include -#include "tsl/profiler/utils/tf_xplane_visitor.h" -#include "tsl/profiler/utils/xplane_builder.h" -#include "tsl/profiler/utils/xplane_schema.h" +#include "xla/tsl/profiler/utils/tf_xplane_visitor.h" +#include "xla/tsl/profiler/utils/xplane_builder.h" +#include "xla/tsl/profiler/utils/xplane_schema.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.cc b/tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.cc index 39b0ef3aebfda6..e13e0cb73a2ab5 100644 --- a/tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.cc +++ b/tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.cc @@ -28,6 +28,8 @@ limitations under the License. #include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "xla/tsl/profiler/utils/format_utils.h" +#include "xla/tsl/profiler/utils/tf_op_utils.h" #include "xla/tsl/util/stats_calculator.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/platform/logging.h" @@ -45,8 +47,6 @@ limitations under the License. #include "tensorflow/core/profiler/utils/html_utils.h" #include "tensorflow/core/profiler/utils/math_utils.h" #include "tensorflow/core/profiler/utils/op_metrics_db_utils.h" -#include "tsl/profiler/utils/format_utils.h" -#include "tsl/profiler/utils/tf_op_utils.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/op_stats_to_overview_page.cc b/tensorflow/core/profiler/convert/op_stats_to_overview_page.cc index 0e0ad42b20a4da..57b974005c3001 100644 --- a/tensorflow/core/profiler/convert/op_stats_to_overview_page.cc +++ b/tensorflow/core/profiler/convert/op_stats_to_overview_page.cc @@ -22,6 +22,9 @@ limitations under the License. #include "google/protobuf/any.pb.h" #include "absl/strings/str_cat.h" +#include "xla/tsl/profiler/utils/format_utils.h" +#include "xla/tsl/profiler/utils/tf_op_utils.h" +#include "xla/tsl/profiler/utils/tf_xplane_visitor.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/profiler/convert/op_metrics_to_record.h" #include "tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.h" @@ -42,9 +45,6 @@ limitations under the License. #include "tensorflow/core/profiler/utils/op_metrics_db_utils.h" #include "tensorflow/core/profiler/utils/xplane_schema.h" #include "tensorflow/core/profiler/utils/xplane_utils.h" -#include "tsl/profiler/utils/format_utils.h" -#include "tsl/profiler/utils/tf_op_utils.h" -#include "tsl/profiler/utils/tf_xplane_visitor.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/preprocess_single_host_xplane.cc b/tensorflow/core/profiler/convert/preprocess_single_host_xplane.cc index 043ae143dc969b..760e6439e90e9a 100644 --- a/tensorflow/core/profiler/convert/preprocess_single_host_xplane.cc +++ b/tensorflow/core/profiler/convert/preprocess_single_host_xplane.cc @@ -16,11 +16,11 @@ limitations under the License. #include +#include "xla/tsl/profiler/utils/group_events.h" +#include "xla/tsl/profiler/utils/preprocess_xplane.h" +#include "xla/tsl/profiler/utils/xplane_utils.h" #include "tensorflow/core/profiler/utils/derived_timeline.h" #include "tensorflow/core/profiler/utils/xplane_schema.h" -#include "tsl/profiler/utils/group_events.h" -#include "tsl/profiler/utils/preprocess_xplane.h" -#include "tsl/profiler/utils/xplane_utils.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/process_megascale_dcn.cc b/tensorflow/core/profiler/convert/process_megascale_dcn.cc index 947c5e54a19568..2d8313bfc9cb82 100644 --- a/tensorflow/core/profiler/convert/process_megascale_dcn.cc +++ b/tensorflow/core/profiler/convert/process_megascale_dcn.cc @@ -16,10 +16,10 @@ limitations under the License. #include +#include "xla/tsl/profiler/utils/tf_xplane_visitor.h" +#include "xla/tsl/profiler/utils/tpu_xplane_utils.h" #include "tensorflow/core/profiler/convert/dcn_analysis.h" #include "tensorflow/core/profiler/utils/xplane_utils.h" -#include "tsl/profiler/utils/tf_xplane_visitor.h" -#include "tsl/profiler/utils/tpu_xplane_utils.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/repository.cc b/tensorflow/core/profiler/convert/repository.cc index abc4a994325d39..fa6f52d3a76754 100644 --- a/tensorflow/core/profiler/convert/repository.cc +++ b/tensorflow/core/profiler/convert/repository.cc @@ -27,6 +27,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/strings/strip.h" +#include "xla/tsl/profiler/utils/file_system_utils.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/path.h" @@ -34,7 +35,6 @@ limitations under the License. #include "tensorflow/core/platform/statusor.h" #include "tsl/platform/errors.h" #include "tsl/profiler/protobuf/xplane.pb.h" -#include "tsl/profiler/utils/file_system_utils.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/repository.h b/tensorflow/core/profiler/convert/repository.h index 55d33af3d4bfbb..d2569467376859 100644 --- a/tensorflow/core/profiler/convert/repository.h +++ b/tensorflow/core/profiler/convert/repository.h @@ -26,13 +26,13 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "xla/tsl/profiler/utils/file_system_utils.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/path.h" #include "tensorflow/core/platform/statusor.h" #include "tsl/platform/env.h" #include "tsl/platform/statusor.h" #include "tsl/profiler/protobuf/xplane.pb.h" -#include "tsl/profiler/utils/file_system_utils.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/step_events_to_steps_db.cc b/tensorflow/core/profiler/convert/step_events_to_steps_db.cc index e04e1f47412adc..46fcb4d9473161 100644 --- a/tensorflow/core/profiler/convert/step_events_to_steps_db.cc +++ b/tensorflow/core/profiler/convert/step_events_to_steps_db.cc @@ -24,6 +24,7 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" #include "absl/log/log.h" +#include "xla/tsl/profiler/utils/timespan.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -31,7 +32,6 @@ limitations under the License. #include "tensorflow/core/profiler/protobuf/steps_db.pb.h" #include "tensorflow/core/profiler/utils/event_span.h" #include "tensorflow/core/profiler/utils/op_metrics_db_utils.h" -#include "tsl/profiler/utils/timespan.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/trace_viewer/BUILD b/tensorflow/core/profiler/convert/trace_viewer/BUILD index eef42433fbdadd..0d426cecf12e6b 100644 --- a/tensorflow/core/profiler/convert/trace_viewer/BUILD +++ b/tensorflow/core/profiler/convert/trace_viewer/BUILD @@ -28,7 +28,7 @@ cc_library( "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log:check", "@com_google_absl//absl/types:optional", - "@local_tsl//tsl/profiler/utils:timespan", + "@local_xla//xla/tsl/profiler/utils:timespan", ], ) @@ -40,7 +40,7 @@ tf_cc_test( "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core/profiler/protobuf:trace_events_proto_cc", - "@local_tsl//tsl/profiler/utils:timespan", + "@local_xla//xla/tsl/profiler/utils:timespan", ], ) @@ -76,7 +76,7 @@ cc_library( "@com_google_absl//absl/types:optional", "@local_tsl//tsl/platform:protobuf", "@local_tsl//tsl/profiler/lib:context_types_hdrs", - "@local_tsl//tsl/profiler/utils:timespan", + "@local_xla//xla/tsl/profiler/utils:timespan", ], ) @@ -98,7 +98,7 @@ cc_library( "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings:string_view", - "@local_tsl//tsl/profiler/utils:timespan", + "@local_xla//xla/tsl/profiler/utils:timespan", ], ) @@ -127,7 +127,7 @@ cc_library( "@com_google_absl//absl/types:optional", "@local_tsl//tsl/platform:status", "@local_tsl//tsl/profiler/lib:context_types_hdrs", - "@local_tsl//tsl/profiler/utils:timespan", "@local_xla//xla/tsl/lib/io:iterator", + "@local_xla//xla/tsl/profiler/utils:timespan", ], ) diff --git a/tensorflow/core/profiler/convert/trace_viewer/trace_events.cc b/tensorflow/core/profiler/convert/trace_viewer/trace_events.cc index 828550de9a4594..994de7aa3fc7c6 100644 --- a/tensorflow/core/profiler/convert/trace_viewer/trace_events.cc +++ b/tensorflow/core/profiler/convert/trace_viewer/trace_events.cc @@ -35,6 +35,7 @@ limitations under the License. #include "xla/tsl/lib/io/table.h" #include "xla/tsl/lib/io/table_builder.h" #include "xla/tsl/lib/io/table_options.h" +#include "xla/tsl/profiler/utils/timespan.h" #include "tensorflow/core/platform/file_system.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" @@ -46,7 +47,6 @@ limitations under the License. #include "tsl/platform/env.h" #include "tsl/platform/errors.h" #include "tsl/platform/file_system.h" -#include "tsl/profiler/utils/timespan.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/trace_viewer/trace_events.h b/tensorflow/core/profiler/convert/trace_viewer/trace_events.h index e5a76838a6fa6b..cbed82e0e51142 100644 --- a/tensorflow/core/profiler/convert/trace_viewer/trace_events.h +++ b/tensorflow/core/profiler/convert/trace_viewer/trace_events.h @@ -35,6 +35,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "xla/tsl/lib/io/table.h" +#include "xla/tsl/profiler/utils/timespan.h" #include "tensorflow/core/profiler/convert/trace_viewer/trace_events_filter_interface.h" #include "tensorflow/core/profiler/convert/trace_viewer/trace_events_util.h" #include "tensorflow/core/profiler/convert/trace_viewer/trace_viewer_visibility.h" @@ -45,7 +46,6 @@ limitations under the License. #include "tsl/platform/file_system.h" #include "tsl/platform/status.h" #include "tsl/profiler/lib/context_types.h" -#include "tsl/profiler/utils/timespan.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/trace_viewer/trace_events_to_json.h b/tensorflow/core/profiler/convert/trace_viewer/trace_events_to_json.h index f5585cb3eb9b08..965b04fbe08edb 100644 --- a/tensorflow/core/profiler/convert/trace_viewer/trace_events_to_json.h +++ b/tensorflow/core/profiler/convert/trace_viewer/trace_events_to_json.h @@ -36,6 +36,7 @@ limitations under the License. #include "absl/strings/strip.h" #include "absl/time/time.h" #include "absl/types/optional.h" +#include "xla/tsl/profiler/utils/timespan.h" #include "tensorflow/core/profiler/convert/trace_viewer/trace_events_util.h" #include "tensorflow/core/profiler/convert/trace_viewer/trace_viewer_color.h" #include "tensorflow/core/profiler/lib/context_types.h" @@ -44,7 +45,6 @@ limitations under the License. #include "tensorflow/core/profiler/protobuf/trace_events_raw.pb.h" #include "tsl/platform/protobuf.h" #include "tsl/profiler/lib/context_types.h" -#include "tsl/profiler/utils/timespan.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/trace_viewer/trace_events_util.cc b/tensorflow/core/profiler/convert/trace_viewer/trace_events_util.cc index e5f84d14efb270..1080db88ee9d0e 100644 --- a/tensorflow/core/profiler/convert/trace_viewer/trace_events_util.cc +++ b/tensorflow/core/profiler/convert/trace_viewer/trace_events_util.cc @@ -20,8 +20,8 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/strings/string_view.h" +#include "xla/tsl/profiler/utils/timespan.h" #include "tensorflow/core/profiler/protobuf/trace_events.pb.h" -#include "tsl/profiler/utils/timespan.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/trace_viewer/trace_events_util.h b/tensorflow/core/profiler/convert/trace_viewer/trace_events_util.h index 0d7e8721e2a6c6..4f0e1dc838b830 100644 --- a/tensorflow/core/profiler/convert/trace_viewer/trace_events_util.h +++ b/tensorflow/core/profiler/convert/trace_viewer/trace_events_util.h @@ -22,8 +22,8 @@ limitations under the License. #include "absl/log/check.h" #include "absl/strings/string_view.h" +#include "xla/tsl/profiler/utils/timespan.h" #include "tensorflow/core/profiler/protobuf/trace_events.pb.h" -#include "tsl/profiler/utils/timespan.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/trace_viewer/trace_viewer_visibility.cc b/tensorflow/core/profiler/convert/trace_viewer/trace_viewer_visibility.cc index 5fe66cd7182f00..c51f18043aa480 100644 --- a/tensorflow/core/profiler/convert/trace_viewer/trace_viewer_visibility.cc +++ b/tensorflow/core/profiler/convert/trace_viewer/trace_viewer_visibility.cc @@ -17,8 +17,8 @@ limitations under the License. #include #include "absl/log/check.h" +#include "xla/tsl/profiler/utils/timespan.h" #include "tensorflow/core/profiler/protobuf/trace_events.pb.h" -#include "tsl/profiler/utils/timespan.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/trace_viewer/trace_viewer_visibility.h b/tensorflow/core/profiler/convert/trace_viewer/trace_viewer_visibility.h index 4257384bcf88b9..da503d417c81e7 100644 --- a/tensorflow/core/profiler/convert/trace_viewer/trace_viewer_visibility.h +++ b/tensorflow/core/profiler/convert/trace_viewer/trace_viewer_visibility.h @@ -24,9 +24,9 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/types/optional.h" +#include "xla/tsl/profiler/utils/timespan.h" #include "tensorflow/core/profiler/convert/trace_viewer/trace_events_filter_interface.h" #include "tensorflow/core/profiler/protobuf/trace_events.pb.h" -#include "tsl/profiler/utils/timespan.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/trace_viewer/trace_viewer_visibility_test.cc b/tensorflow/core/profiler/convert/trace_viewer/trace_viewer_visibility_test.cc index 60a3cdfd939801..e9c4dce6d17a4c 100644 --- a/tensorflow/core/profiler/convert/trace_viewer/trace_viewer_visibility_test.cc +++ b/tensorflow/core/profiler/convert/trace_viewer/trace_viewer_visibility_test.cc @@ -16,9 +16,9 @@ limitations under the License. #include +#include "xla/tsl/profiler/utils/timespan.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/profiler/protobuf/trace_events.pb.h" -#include "tsl/profiler/utils/timespan.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/xplane_to_hlo.cc b/tensorflow/core/profiler/convert/xplane_to_hlo.cc index f4af2784f039f6..792a31701f3dd5 100644 --- a/tensorflow/core/profiler/convert/xplane_to_hlo.cc +++ b/tensorflow/core/profiler/convert/xplane_to_hlo.cc @@ -23,6 +23,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "xla/service/hlo.pb.h" +#include "xla/tsl/profiler/utils/file_system_utils.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/status.h" @@ -30,7 +31,6 @@ limitations under the License. #include "tensorflow/core/profiler/convert/repository.h" #include "tensorflow/core/profiler/utils/hlo_proto_map.h" #include "tsl/profiler/protobuf/xplane.pb.h" -#include "tsl/profiler/utils/file_system_utils.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/xplane_to_kernel_stats_db.cc b/tensorflow/core/profiler/convert/xplane_to_kernel_stats_db.cc index 2f2349884d52e3..2f1f14045567df 100644 --- a/tensorflow/core/profiler/convert/xplane_to_kernel_stats_db.cc +++ b/tensorflow/core/profiler/convert/xplane_to_kernel_stats_db.cc @@ -20,6 +20,8 @@ limitations under the License. #include #include "absl/strings/string_view.h" +#include "xla/tsl/profiler/utils/tf_op_utils.h" +#include "xla/tsl/profiler/utils/tf_xplane_visitor.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/profiler/protobuf/kernel_stats.pb.h" #include "tensorflow/core/profiler/protobuf/xplane.pb.h" @@ -27,8 +29,6 @@ limitations under the License. #include "tensorflow/core/profiler/utils/kernel_stats_utils.h" #include "tensorflow/core/profiler/utils/trace_utils.h" #include "tensorflow/core/profiler/utils/xplane_visitor.h" -#include "tsl/profiler/utils/tf_op_utils.h" -#include "tsl/profiler/utils/tf_xplane_visitor.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/xplane_to_memory_profile.cc b/tensorflow/core/profiler/convert/xplane_to_memory_profile.cc index b289f54baa67f0..59670794734f2c 100644 --- a/tensorflow/core/profiler/convert/xplane_to_memory_profile.cc +++ b/tensorflow/core/profiler/convert/xplane_to_memory_profile.cc @@ -28,6 +28,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" +#include "xla/tsl/profiler/utils/tf_xplane_visitor.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/lib/gtl/map_util.h" @@ -38,7 +39,6 @@ limitations under the License. #include "tensorflow/core/profiler/utils/xplane_schema.h" #include "tensorflow/core/profiler/utils/xplane_utils.h" #include "tensorflow/core/profiler/utils/xplane_visitor.h" -#include "tsl/profiler/utils/tf_xplane_visitor.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/xplane_to_memory_profile_test.cc b/tensorflow/core/profiler/convert/xplane_to_memory_profile_test.cc index 527735095e088c..8d0415db234e94 100644 --- a/tensorflow/core/profiler/convert/xplane_to_memory_profile_test.cc +++ b/tensorflow/core/profiler/convert/xplane_to_memory_profile_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/core/profiler/convert/xplane_to_memory_profile.h" #include "absl/strings/string_view.h" +#include "xla/tsl/profiler/utils/group_events.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/profiler/protobuf/memory_profile.pb.h" @@ -23,7 +24,6 @@ limitations under the License. #include "tensorflow/core/profiler/utils/xplane_builder.h" #include "tensorflow/core/profiler/utils/xplane_schema.h" #include "tensorflow/core/profiler/utils/xplane_test_utils.h" -#include "tsl/profiler/utils/group_events.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.cc b/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.cc index 8b228479872bcc..0384b80e82d980 100644 --- a/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.cc +++ b/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.cc @@ -28,6 +28,10 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" +#include "xla/tsl/profiler/utils/tf_op_utils.h" +#include "xla/tsl/profiler/utils/tf_xplane_visitor.h" +#include "xla/tsl/profiler/utils/timespan.h" +#include "xla/tsl/profiler/utils/xplane_schema.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -41,10 +45,6 @@ limitations under the License. #include "tensorflow/core/profiler/utils/trace_utils.h" #include "tensorflow/core/profiler/utils/xplane_schema.h" #include "tensorflow/core/profiler/utils/xplane_visitor.h" -#include "tsl/profiler/utils/tf_op_utils.h" -#include "tsl/profiler/utils/tf_xplane_visitor.h" -#include "tsl/profiler/utils/timespan.h" -#include "tsl/profiler/utils/xplane_schema.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.h b/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.h index 126f0118da1b60..c5d2a229d52bc4 100644 --- a/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.h +++ b/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.h @@ -18,13 +18,13 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/types/optional.h" +#include "xla/tsl/profiler/utils/tf_op_utils.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/profiler/convert/op_metrics_db_combiner.h" #include "tensorflow/core/profiler/protobuf/op_metrics.pb.h" #include "tensorflow/core/profiler/protobuf/xplane.pb.h" #include "tensorflow/core/profiler/utils/op_utils.h" #include "tensorflow/core/profiler/utils/xplane_visitor.h" -#include "tsl/profiler/utils/tf_op_utils.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/xplane_to_op_stats.cc b/tensorflow/core/profiler/convert/xplane_to_op_stats.cc index 5774010604b4d0..cca87ccc3f668c 100644 --- a/tensorflow/core/profiler/convert/xplane_to_op_stats.cc +++ b/tensorflow/core/profiler/convert/xplane_to_op_stats.cc @@ -22,6 +22,11 @@ limitations under the License. #include "absl/log/check.h" #include "absl/strings/match.h" #include "absl/strings/string_view.h" +#include "xla/tsl/profiler/utils/math_utils.h" +#include "xla/tsl/profiler/utils/tf_xplane_visitor.h" +#include "xla/tsl/profiler/utils/tpu_xplane_utils.h" +#include "xla/tsl/profiler/utils/xplane_schema.h" +#include "xla/tsl/profiler/utils/xplane_utils.h" #include "tensorflow/core/profiler/convert/op_metrics_db_combiner.h" #include "tensorflow/core/profiler/convert/step_events_to_steps_db.h" #include "tensorflow/core/profiler/convert/xplane_to_kernel_stats_db.h" @@ -42,11 +47,6 @@ limitations under the License. #include "tensorflow/core/profiler/utils/xplane_utils.h" #include "tensorflow/core/profiler/utils/xplane_visitor.h" #include "tsl/profiler/protobuf/xplane.pb.h" -#include "tsl/profiler/utils/math_utils.h" -#include "tsl/profiler/utils/tf_xplane_visitor.h" -#include "tsl/profiler/utils/tpu_xplane_utils.h" -#include "tsl/profiler/utils/xplane_schema.h" -#include "tsl/profiler/utils/xplane_utils.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/xplane_to_op_stats_test.cc b/tensorflow/core/profiler/convert/xplane_to_op_stats_test.cc index 01bfb6c9c2f575..68c0b29a9f4481 100644 --- a/tensorflow/core/profiler/convert/xplane_to_op_stats_test.cc +++ b/tensorflow/core/profiler/convert/xplane_to_op_stats_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "xla/tsl/profiler/utils/group_events.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/profiler/convert/multi_xplanes_to_op_stats.h" @@ -36,7 +37,6 @@ limitations under the License. #include "tensorflow/core/profiler/utils/xplane_test_utils.h" #include "tsl/platform/status.h" #include "tsl/profiler/protobuf/xplane.pb.h" -#include "tsl/profiler/utils/group_events.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/xplane_to_step_events.cc b/tensorflow/core/profiler/convert/xplane_to_step_events.cc index e1591a43195e54..47d1aa8c5f3588 100644 --- a/tensorflow/core/profiler/convert/xplane_to_step_events.cc +++ b/tensorflow/core/profiler/convert/xplane_to_step_events.cc @@ -25,6 +25,11 @@ limitations under the License. #include "absl/strings/str_split.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" +#include "xla/tsl/profiler/utils/tf_op_utils.h" +#include "xla/tsl/profiler/utils/tf_xplane_visitor.h" +#include "xla/tsl/profiler/utils/timespan.h" +#include "xla/tsl/profiler/utils/tpu_xplane_utils.h" +#include "xla/tsl/profiler/utils/xplane_schema.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/profiler/protobuf/steps_db.pb.h" #include "tensorflow/core/profiler/protobuf/xplane.pb.h" @@ -33,11 +38,6 @@ limitations under the License. #include "tensorflow/core/profiler/utils/trace_utils.h" #include "tensorflow/core/profiler/utils/xplane_schema.h" #include "tensorflow/core/profiler/utils/xplane_visitor.h" -#include "tsl/profiler/utils/tf_op_utils.h" -#include "tsl/profiler/utils/tf_xplane_visitor.h" -#include "tsl/profiler/utils/timespan.h" -#include "tsl/profiler/utils/tpu_xplane_utils.h" -#include "tsl/profiler/utils/xplane_schema.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/xplane_to_step_events_test.cc b/tensorflow/core/profiler/convert/xplane_to_step_events_test.cc index cf4d2b0af40b06..7f6069ec0f511f 100644 --- a/tensorflow/core/profiler/convert/xplane_to_step_events_test.cc +++ b/tensorflow/core/profiler/convert/xplane_to_step_events_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include "absl/container/flat_hash_map.h" +#include "xla/tsl/profiler/utils/group_events.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/profiler/protobuf/xplane.pb.h" @@ -25,7 +26,6 @@ limitations under the License. #include "tensorflow/core/profiler/utils/xplane_builder.h" #include "tensorflow/core/profiler/utils/xplane_schema.h" #include "tensorflow/core/profiler/utils/xplane_test_utils.h" -#include "tsl/profiler/utils/group_events.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/xplane_to_step_stats.cc b/tensorflow/core/profiler/convert/xplane_to_step_stats.cc index ff084d1e03e1e5..94c7a3da15adfe 100644 --- a/tensorflow/core/profiler/convert/xplane_to_step_stats.cc +++ b/tensorflow/core/profiler/convert/xplane_to_step_stats.cc @@ -24,6 +24,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/strings/strip.h" +#include "xla/tsl/profiler/utils/tf_xplane_visitor.h" #include "tensorflow/core/framework/step_stats.pb.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/profiler/protobuf/xplane.pb.h" @@ -32,7 +33,6 @@ limitations under the License. #include "tensorflow/core/profiler/utils/xplane_schema.h" #include "tensorflow/core/profiler/utils/xplane_utils.h" #include "tensorflow/core/profiler/utils/xplane_visitor.h" -#include "tsl/profiler/utils/tf_xplane_visitor.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/xplane_to_tf_data_stats.cc b/tensorflow/core/profiler/convert/xplane_to_tf_data_stats.cc index 7159570eee4cfe..c566235840ffd3 100644 --- a/tensorflow/core/profiler/convert/xplane_to_tf_data_stats.cc +++ b/tensorflow/core/profiler/convert/xplane_to_tf_data_stats.cc @@ -25,15 +25,15 @@ limitations under the License. #include "absl/strings/str_format.h" #include "absl/strings/str_split.h" #include "absl/strings/string_view.h" +#include "xla/tsl/profiler/utils/group_events.h" +#include "xla/tsl/profiler/utils/tf_op_utils.h" +#include "xla/tsl/profiler/utils/tf_xplane_visitor.h" +#include "xla/tsl/profiler/utils/timespan.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/profiler/protobuf/tf_data_stats.pb.h" #include "tensorflow/core/profiler/utils/html_utils.h" #include "tensorflow/core/profiler/utils/xplane_schema.h" #include "tensorflow/core/profiler/utils/xplane_visitor.h" -#include "tsl/profiler/utils/group_events.h" -#include "tsl/profiler/utils/tf_op_utils.h" -#include "tsl/profiler/utils/tf_xplane_visitor.h" -#include "tsl/profiler/utils/timespan.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/xplane_to_tf_functions.cc b/tensorflow/core/profiler/convert/xplane_to_tf_functions.cc index 533d937792362f..58b64a1696ed9d 100644 --- a/tensorflow/core/profiler/convert/xplane_to_tf_functions.cc +++ b/tensorflow/core/profiler/convert/xplane_to_tf_functions.cc @@ -26,6 +26,8 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" +#include "xla/tsl/profiler/utils/tf_xplane_visitor.h" +#include "xla/tsl/profiler/utils/timespan.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/protobuf.h" @@ -34,8 +36,6 @@ limitations under the License. #include "tensorflow/core/profiler/utils/math_utils.h" #include "tensorflow/core/profiler/utils/xplane_schema.h" #include "tensorflow/core/profiler/utils/xplane_visitor.h" -#include "tsl/profiler/utils/tf_xplane_visitor.h" -#include "tsl/profiler/utils/timespan.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/xplane_to_tf_functions_test.cc b/tensorflow/core/profiler/convert/xplane_to_tf_functions_test.cc index ca85cb6005d32e..c7127b80212372 100644 --- a/tensorflow/core/profiler/convert/xplane_to_tf_functions_test.cc +++ b/tensorflow/core/profiler/convert/xplane_to_tf_functions_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include "absl/strings/string_view.h" +#include "xla/tsl/profiler/utils/tf_xplane_visitor.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/profiler/protobuf/tf_function.pb.h" #include "tensorflow/core/profiler/protobuf/xplane.pb.h" @@ -26,7 +27,6 @@ limitations under the License. #include "tensorflow/core/profiler/utils/xplane_test_utils.h" #include "tensorflow/core/profiler/utils/xplane_utils.h" #include "tensorflow/core/profiler/utils/xplane_visitor.h" -#include "tsl/profiler/utils/tf_xplane_visitor.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/xplane_to_trace_container.cc b/tensorflow/core/profiler/convert/xplane_to_trace_container.cc index cfb4f2ec20cf4b..d442c9eca5047a 100644 --- a/tensorflow/core/profiler/convert/xplane_to_trace_container.cc +++ b/tensorflow/core/profiler/convert/xplane_to_trace_container.cc @@ -22,16 +22,16 @@ limitations under the License. #include #include "absl/strings/string_view.h" +#include "xla/tsl/profiler/utils/tf_xplane_visitor.h" +#include "xla/tsl/profiler/utils/timespan.h" +#include "xla/tsl/profiler/utils/trace_utils.h" +#include "xla/tsl/profiler/utils/xplane_schema.h" +#include "xla/tsl/profiler/utils/xplane_utils.h" +#include "xla/tsl/profiler/utils/xplane_visitor.h" #include "tensorflow/core/profiler/convert/trace_viewer/trace_event_arguments_builder.h" #include "tensorflow/core/profiler/convert/trace_viewer/trace_events_util.h" #include "tensorflow/core/profiler/protobuf/trace_events.pb.h" #include "tensorflow/core/profiler/protobuf/trace_events_raw.pb.h" -#include "tsl/profiler/utils/tf_xplane_visitor.h" -#include "tsl/profiler/utils/timespan.h" -#include "tsl/profiler/utils/trace_utils.h" -#include "tsl/profiler/utils/xplane_schema.h" -#include "tsl/profiler/utils/xplane_utils.h" -#include "tsl/profiler/utils/xplane_visitor.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/xspace_to_dcn_slack_analysis.cc b/tensorflow/core/profiler/convert/xspace_to_dcn_slack_analysis.cc index 82cf25f4e2b180..3a5aeb78da5c14 100644 --- a/tensorflow/core/profiler/convert/xspace_to_dcn_slack_analysis.cc +++ b/tensorflow/core/profiler/convert/xspace_to_dcn_slack_analysis.cc @@ -32,6 +32,13 @@ limitations under the License. #include "xla/hlo/ir/hlo_opcode.h" #include "xla/shape_util.h" #include "xla/side_effect_util.h" +#include "xla/tsl/profiler/utils/math_utils.h" +#include "xla/tsl/profiler/utils/tf_xplane_visitor.h" +#include "xla/tsl/profiler/utils/timespan.h" +#include "xla/tsl/profiler/utils/tpu_xplane_utils.h" +#include "xla/tsl/profiler/utils/xplane_schema.h" +#include "xla/tsl/profiler/utils/xplane_utils.h" +#include "xla/tsl/profiler/utils/xplane_visitor.h" #include "xla/xla_data.pb.h" #include "tensorflow/core/profiler/protobuf/dcn_collective_info.pb.h" #include "tensorflow/core/profiler/protobuf/dcn_slack_analysis.pb.h" @@ -43,13 +50,6 @@ limitations under the License. #include "tsl/platform/regexp.h" #include "tsl/platform/statusor.h" #include "tsl/profiler/protobuf/xplane.pb.h" -#include "tsl/profiler/utils/math_utils.h" -#include "tsl/profiler/utils/tf_xplane_visitor.h" -#include "tsl/profiler/utils/timespan.h" -#include "tsl/profiler/utils/tpu_xplane_utils.h" -#include "tsl/profiler/utils/xplane_schema.h" -#include "tsl/profiler/utils/xplane_utils.h" -#include "tsl/profiler/utils/xplane_visitor.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/xspace_to_dcn_slack_analysis.h b/tensorflow/core/profiler/convert/xspace_to_dcn_slack_analysis.h index daac70f634abca..2f9e5551449cf5 100644 --- a/tensorflow/core/profiler/convert/xspace_to_dcn_slack_analysis.h +++ b/tensorflow/core/profiler/convert/xspace_to_dcn_slack_analysis.h @@ -28,13 +28,13 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/tsl/profiler/utils/timespan.h" +#include "xla/tsl/profiler/utils/xplane_visitor.h" #include "tensorflow/core/profiler/protobuf/dcn_collective_info.pb.h" #include "tensorflow/core/profiler/protobuf/dcn_slack_analysis.pb.h" #include "tensorflow/core/profiler/protobuf/topology.pb.h" #include "tensorflow/core/profiler/utils/hlo_proto_map.h" #include "tsl/profiler/protobuf/xplane.pb.h" -#include "tsl/profiler/utils/timespan.h" -#include "tsl/profiler/utils/xplane_visitor.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/lib/BUILD b/tensorflow/core/profiler/lib/BUILD index 55d726723fc752..6b1ca8e6be8744 100644 --- a/tensorflow/core/profiler/lib/BUILD +++ b/tensorflow/core/profiler/lib/BUILD @@ -144,7 +144,7 @@ cc_library( "@local_tsl//tsl/profiler/lib:traceme_encode", ] + if_not_android([ "@local_xla//xla/tsl/profiler/backends/cpu:traceme_recorder", - "@local_tsl//tsl/profiler/utils:time_utils", + "@local_xla//xla/tsl/profiler/utils:time_utils", ]), ) diff --git a/tensorflow/core/profiler/lib/traceme.h b/tensorflow/core/profiler/lib/traceme.h index 51e7e8ba5fbbe7..23e48948095811 100644 --- a/tensorflow/core/profiler/lib/traceme.h +++ b/tensorflow/core/profiler/lib/traceme.h @@ -20,7 +20,7 @@ limitations under the License. #include "tsl/profiler/lib/traceme.h" #if !defined(IS_MOBILE_PLATFORM) -#include "tsl/profiler/utils/time_utils.h" +#include "xla/tsl/profiler/utils/time_utils.h" #endif // TODO: b/323943471 - This macro should eventually be provided by Abseil. diff --git a/tensorflow/core/profiler/rpc/BUILD b/tensorflow/core/profiler/rpc/BUILD index 96d9a50408aa1a..89e9735fbc2190 100644 --- a/tensorflow/core/profiler/rpc/BUILD +++ b/tensorflow/core/profiler/rpc/BUILD @@ -53,9 +53,9 @@ cc_library( "@local_tsl//tsl/profiler/protobuf:profiler_service_cc_grpc_proto", "@local_tsl//tsl/profiler/protobuf:profiler_service_proto_cc", "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_tsl//tsl/profiler/utils:file_system_utils", - "@local_tsl//tsl/profiler/utils:time_utils", "@local_xla//xla/tsl/profiler/rpc:profiler_service_impl", + "@local_xla//xla/tsl/profiler/utils:file_system_utils", + "@local_xla//xla/tsl/profiler/utils:time_utils", ] + tf_grpc_cc_dependencies(), ) diff --git a/tensorflow/core/profiler/utils/BUILD b/tensorflow/core/profiler/utils/BUILD index 258c83545fe55e..4d9490220cb180 100644 --- a/tensorflow/core/profiler/utils/BUILD +++ b/tensorflow/core/profiler/utils/BUILD @@ -43,7 +43,7 @@ cc_library( "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", - "@local_tsl//tsl/profiler/utils:timespan", + "@local_xla//xla/tsl/profiler/utils:timespan", ], ) @@ -58,7 +58,7 @@ cc_library( "//tensorflow/core/profiler/protobuf:hardware_types_proto_cc", "@com_google_absl//absl/container:btree", "@com_google_absl//absl/strings", - "@local_tsl//tsl/profiler/utils:math_utils", + "@local_xla//xla/tsl/profiler/utils:math_utils", ], ) @@ -70,7 +70,7 @@ tf_cc_test( "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", - "@local_tsl//tsl/profiler/utils:math_utils", + "@local_xla//xla/tsl/profiler/utils:math_utils", ], ) @@ -79,7 +79,7 @@ cc_library( hdrs = ["math_utils.h"], deps = [ "@com_google_absl//absl/base:core_headers", - "@local_tsl//tsl/profiler/utils:math_utils", + "@local_xla//xla/tsl/profiler/utils:math_utils", ], ) @@ -104,9 +104,9 @@ cc_library( "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", - "@local_tsl//tsl/profiler/utils:tf_op_utils", - "@local_tsl//tsl/profiler/utils:xplane_schema", - "@local_tsl//tsl/profiler/utils:xplane_visitor", + "@local_xla//xla/tsl/profiler/utils:tf_op_utils", + "@local_xla//xla/tsl/profiler/utils:xplane_schema", + "@local_xla//xla/tsl/profiler/utils:xplane_visitor", ], ) @@ -133,8 +133,8 @@ cc_library( "//tensorflow/core/profiler/convert:op_metrics_db_combiner", "//tensorflow/core/profiler/protobuf:op_metrics_proto_cc", "@com_google_absl//absl/strings", - "@local_tsl//tsl/profiler/utils:tf_op_utils", - "@local_tsl//tsl/profiler/utils:timespan", + "@local_xla//xla/tsl/profiler/utils:tf_op_utils", + "@local_xla//xla/tsl/profiler/utils:timespan", ], ) @@ -143,7 +143,7 @@ cc_library( hdrs = ["trace_utils.h"], copts = tf_profiler_copts(), deps = [ - "@local_tsl//tsl/profiler/utils:trace_utils", + "@local_xla//xla/tsl/profiler/utils:trace_utils", ], ) @@ -153,7 +153,7 @@ cc_library( copts = tf_profiler_copts(), visibility = [":friends"], deps = [ - "@local_tsl//tsl/profiler/utils:xplane_builder", + "@local_xla//xla/tsl/profiler/utils:xplane_builder", ], ) @@ -164,7 +164,7 @@ cc_library( visibility = [":friends"], deps = [ "@com_google_absl//absl/strings", - "@local_tsl//tsl/profiler/utils:xplane_schema", + "@local_xla//xla/tsl/profiler/utils:xplane_schema", ], ) @@ -174,7 +174,7 @@ cc_library( copts = tf_profiler_copts(), visibility = [":friends"], deps = [ - "@local_tsl//tsl/profiler/utils:xplane_utils", + "@local_xla//xla/tsl/profiler/utils:xplane_utils", ], ) @@ -190,7 +190,7 @@ cc_library( "//tensorflow/core:lib", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:variant", - "@local_tsl//tsl/profiler/utils:xplane_test_utils", + "@local_xla//xla/tsl/profiler/utils:xplane_test_utils", ], ) @@ -200,7 +200,7 @@ cc_library( copts = tf_profiler_copts(), visibility = [":friends"], deps = [ - "@local_tsl//tsl/profiler/utils:xplane_visitor", + "@local_xla//xla/tsl/profiler/utils:xplane_visitor", ], ) @@ -222,7 +222,7 @@ cc_library( "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", - "@local_tsl//tsl/profiler/utils:tf_op_utils", + "@local_xla//xla/tsl/profiler/utils:tf_op_utils", ], ) @@ -241,8 +241,8 @@ cc_library( "@com_google_absl//absl/log", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", - "@local_tsl//tsl/profiler/utils:timespan", "@local_xla//xla:shape_util", + "@local_xla//xla/tsl/profiler/utils:timespan", ], ) @@ -269,14 +269,14 @@ cc_library( "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_tsl//tsl/profiler/utils:group_events", - "@local_tsl//tsl/profiler/utils:tf_op_utils", - "@local_tsl//tsl/profiler/utils:tf_xplane_visitor", - "@local_tsl//tsl/profiler/utils:timespan", - "@local_tsl//tsl/profiler/utils:tpu_xplane_utils", - "@local_tsl//tsl/profiler/utils:trace_utils", - "@local_tsl//tsl/profiler/utils:xplane_schema", "@local_xla//xla/tsl/profiler/convert:xla_op_utils", + "@local_xla//xla/tsl/profiler/utils:group_events", + "@local_xla//xla/tsl/profiler/utils:tf_op_utils", + "@local_xla//xla/tsl/profiler/utils:tf_xplane_visitor", + "@local_xla//xla/tsl/profiler/utils:timespan", + "@local_xla//xla/tsl/profiler/utils:tpu_xplane_utils", + "@local_xla//xla/tsl/profiler/utils:trace_utils", + "@local_xla//xla/tsl/profiler/utils:xplane_schema", "@local_xla//xla/tsl/util:stats_calculator_portable", ], ) @@ -297,9 +297,9 @@ tf_cc_test( "//tensorflow/core/profiler/protobuf:xplane_proto_cc", "@com_google_absl//absl/strings", "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/profiler/utils:group_events", - "@local_tsl//tsl/profiler/utils:tf_xplane_visitor", - "@local_tsl//tsl/profiler/utils:xplane_schema", + "@local_xla//xla/tsl/profiler/utils:group_events", + "@local_xla//xla/tsl/profiler/utils:tf_xplane_visitor", + "@local_xla//xla/tsl/profiler/utils:xplane_schema", ], ) @@ -358,7 +358,7 @@ cc_library( "//tensorflow/core/platform:types", "//tensorflow/core/profiler/protobuf:steps_db_proto_cc", "@com_google_absl//absl/container:flat_hash_map", - "@local_tsl//tsl/profiler/utils:timespan", + "@local_xla//xla/tsl/profiler/utils:timespan", ], ) @@ -386,7 +386,7 @@ cc_library( ":xplane_visitor", "//tensorflow/core/profiler/protobuf:hardware_types_proto_cc", "//tensorflow/core/profiler/protobuf:xplane_proto_cc", - "@local_tsl//tsl/profiler/utils:tf_xplane_visitor", + "@local_xla//xla/tsl/profiler/utils:tf_xplane_visitor", ], ) @@ -419,9 +419,9 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@local_tsl//tsl/profiler/utils:tf_xplane_visitor", "@local_xla//xla/service:hlo_proto_cc", "@local_xla//xla/tsl/profiler/convert:xla_op_utils", + "@local_xla//xla/tsl/profiler/utils:tf_xplane_visitor", ], ) diff --git a/tensorflow/core/profiler/utils/cost_utils.cc b/tensorflow/core/profiler/utils/cost_utils.cc index 2cbd2590f0c525..f1899f17ac30fd 100644 --- a/tensorflow/core/profiler/utils/cost_utils.cc +++ b/tensorflow/core/profiler/utils/cost_utils.cc @@ -25,6 +25,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/strings/strip.h" #include "absl/types/optional.h" +#include "xla/tsl/profiler/utils/tf_op_utils.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/framework/types.pb.h" @@ -35,7 +36,6 @@ limitations under the License. #include "tensorflow/core/platform/types.h" #include "tensorflow/core/profiler/utils/xplane_schema.h" #include "tensorflow/core/profiler/utils/xplane_visitor.h" -#include "tsl/profiler/utils/tf_op_utils.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/utils/derived_timeline.cc b/tensorflow/core/profiler/utils/derived_timeline.cc index 5cf4b280ec2961..9da9ac89efe9e9 100644 --- a/tensorflow/core/profiler/utils/derived_timeline.cc +++ b/tensorflow/core/profiler/utils/derived_timeline.cc @@ -26,6 +26,13 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "xla/tsl/profiler/convert/xla_op_utils.h" +#include "xla/tsl/profiler/utils/group_events.h" +#include "xla/tsl/profiler/utils/tf_op_utils.h" +#include "xla/tsl/profiler/utils/tf_xplane_visitor.h" +#include "xla/tsl/profiler/utils/timespan.h" +#include "xla/tsl/profiler/utils/tpu_xplane_utils.h" +#include "xla/tsl/profiler/utils/trace_utils.h" +#include "xla/tsl/profiler/utils/xplane_schema.h" #include "xla/tsl/util/stats_calculator.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/profiler/protobuf/xplane.pb.h" @@ -40,13 +47,6 @@ limitations under the License. #include "tensorflow/core/profiler/utils/xplane_utils.h" #include "tensorflow/core/profiler/utils/xplane_visitor.h" #include "tsl/profiler/protobuf/xplane.pb.h" -#include "tsl/profiler/utils/group_events.h" -#include "tsl/profiler/utils/tf_op_utils.h" -#include "tsl/profiler/utils/tf_xplane_visitor.h" -#include "tsl/profiler/utils/timespan.h" -#include "tsl/profiler/utils/tpu_xplane_utils.h" -#include "tsl/profiler/utils/trace_utils.h" -#include "tsl/profiler/utils/xplane_schema.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/utils/derived_timeline.h b/tensorflow/core/profiler/utils/derived_timeline.h index 535f5041269a21..7fd06c9ab2f42a 100644 --- a/tensorflow/core/profiler/utils/derived_timeline.h +++ b/tensorflow/core/profiler/utils/derived_timeline.h @@ -24,11 +24,11 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/strings/string_view.h" +#include "xla/tsl/profiler/utils/group_events.h" +#include "xla/tsl/profiler/utils/timespan.h" #include "tensorflow/core/profiler/protobuf/xplane.pb.h" #include "tensorflow/core/profiler/utils/xplane_builder.h" #include "tsl/profiler/protobuf/xplane.pb.h" -#include "tsl/profiler/utils/group_events.h" -#include "tsl/profiler/utils/timespan.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/utils/derived_timeline_test.cc b/tensorflow/core/profiler/utils/derived_timeline_test.cc index 3bc01244a372ee..edda0d4673cdda 100644 --- a/tensorflow/core/profiler/utils/derived_timeline_test.cc +++ b/tensorflow/core/profiler/utils/derived_timeline_test.cc @@ -21,6 +21,9 @@ limitations under the License. #include #include "absl/strings/string_view.h" +#include "xla/tsl/profiler/utils/group_events.h" +#include "xla/tsl/profiler/utils/tf_xplane_visitor.h" +#include "xla/tsl/profiler/utils/xplane_schema.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/profiler/protobuf/xplane.pb.h" @@ -29,9 +32,6 @@ limitations under the License. #include "tensorflow/core/profiler/utils/xplane_schema.h" #include "tensorflow/core/profiler/utils/xplane_test_utils.h" #include "tensorflow/core/profiler/utils/xplane_visitor.h" -#include "tsl/profiler/utils/group_events.h" -#include "tsl/profiler/utils/tf_xplane_visitor.h" -#include "tsl/profiler/utils/xplane_schema.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/utils/device_caps_utils.cc b/tensorflow/core/profiler/utils/device_caps_utils.cc index e795081311f28c..b01bc35f72b3bc 100644 --- a/tensorflow/core/profiler/utils/device_caps_utils.cc +++ b/tensorflow/core/profiler/utils/device_caps_utils.cc @@ -15,10 +15,10 @@ limitations under the License. #include "tensorflow/core/profiler/utils/device_caps_utils.h" +#include "xla/tsl/profiler/utils/tf_xplane_visitor.h" #include "tensorflow/core/profiler/utils/xplane_builder.h" #include "tensorflow/core/profiler/utils/xplane_schema.h" #include "tensorflow/core/profiler/utils/xplane_visitor.h" -#include "tsl/profiler/utils/tf_xplane_visitor.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/utils/event_span.cc b/tensorflow/core/profiler/utils/event_span.cc index cc9e2ed044361b..27ddddf1e4d195 100644 --- a/tensorflow/core/profiler/utils/event_span.cc +++ b/tensorflow/core/profiler/utils/event_span.cc @@ -22,10 +22,10 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "xla/tsl/profiler/utils/timespan.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/profiler/protobuf/op_metrics.pb.h" -#include "tsl/profiler/utils/timespan.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/utils/event_span.h b/tensorflow/core/profiler/utils/event_span.h index 20c8643c5df722..4100390b88959b 100644 --- a/tensorflow/core/profiler/utils/event_span.h +++ b/tensorflow/core/profiler/utils/event_span.h @@ -21,10 +21,10 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/strings/string_view.h" +#include "xla/tsl/profiler/utils/timespan.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/profiler/protobuf/op_metrics.pb.h" #include "tensorflow/core/profiler/protobuf/steps_db.pb.h" -#include "tsl/profiler/utils/timespan.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/utils/hardware_type_utils.cc b/tensorflow/core/profiler/utils/hardware_type_utils.cc index 35bc8a9def667b..85cdb13a03ccea 100644 --- a/tensorflow/core/profiler/utils/hardware_type_utils.cc +++ b/tensorflow/core/profiler/utils/hardware_type_utils.cc @@ -19,11 +19,11 @@ limitations under the License. #include "absl/container/btree_map.h" #include "absl/strings/match.h" +#include "xla/tsl/profiler/utils/math_utils.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/profiler/protobuf/hardware_types.pb.h" #include "tensorflow/core/profiler/utils/xplane_schema.h" -#include "tsl/profiler/utils/math_utils.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/utils/hardware_type_utils_test.cc b/tensorflow/core/profiler/utils/hardware_type_utils_test.cc index f97ccc6fecd40a..9476848a650dcc 100644 --- a/tensorflow/core/profiler/utils/hardware_type_utils_test.cc +++ b/tensorflow/core/profiler/utils/hardware_type_utils_test.cc @@ -15,8 +15,8 @@ limitations under the License. #include "tensorflow/core/profiler/utils/hardware_type_utils.h" +#include "xla/tsl/profiler/utils/math_utils.h" #include "tensorflow/core/platform/test.h" -#include "tsl/profiler/utils/math_utils.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/utils/hlo_proto_map.cc b/tensorflow/core/profiler/utils/hlo_proto_map.cc index 2269acb66eb4c5..bdb16fca3c3fa9 100644 --- a/tensorflow/core/profiler/utils/hlo_proto_map.cc +++ b/tensorflow/core/profiler/utils/hlo_proto_map.cc @@ -29,10 +29,10 @@ limitations under the License. #include "absl/strings/string_view.h" #include "xla/service/hlo.pb.h" #include "xla/tsl/profiler/convert/xla_op_utils.h" +#include "xla/tsl/profiler/utils/tf_xplane_visitor.h" #include "tensorflow/core/profiler/utils/xplane_schema.h" #include "tensorflow/core/profiler/utils/xplane_utils.h" #include "tensorflow/core/profiler/utils/xplane_visitor.h" -#include "tsl/profiler/utils/tf_xplane_visitor.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/utils/host_offload_utils.cc b/tensorflow/core/profiler/utils/host_offload_utils.cc index 44b2eca6dca1ad..312b47f168cc44 100644 --- a/tensorflow/core/profiler/utils/host_offload_utils.cc +++ b/tensorflow/core/profiler/utils/host_offload_utils.cc @@ -30,11 +30,11 @@ limitations under the License. #include "absl/strings/str_format.h" #include "absl/strings/str_split.h" #include "absl/strings/string_view.h" +#include "xla/tsl/profiler/utils/timespan.h" #include "tensorflow/core/profiler/utils/trace_utils.h" #include "tensorflow/core/profiler/utils/xplane_builder.h" #include "tensorflow/core/profiler/utils/xplane_schema.h" #include "tensorflow/core/profiler/utils/xplane_visitor.h" -#include "tsl/profiler/utils/timespan.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/utils/math_utils.h b/tensorflow/core/profiler/utils/math_utils.h index 1cffd53aafef7b..380884eeb994af 100644 --- a/tensorflow/core/profiler/utils/math_utils.h +++ b/tensorflow/core/profiler/utils/math_utils.h @@ -19,7 +19,7 @@ limitations under the License. #include #include "absl/base/macros.h" -#include "tsl/profiler/utils/math_utils.h" +#include "xla/tsl/profiler/utils/math_utils.h" // TODO: b/323943471 - This macro should eventually be provided by Abseil. #ifndef ABSL_DEPRECATE_AND_INLINE diff --git a/tensorflow/core/profiler/utils/op_metrics_db_utils.cc b/tensorflow/core/profiler/utils/op_metrics_db_utils.cc index 66f3ab1e6a129f..5c8f13e58e8e0d 100644 --- a/tensorflow/core/profiler/utils/op_metrics_db_utils.cc +++ b/tensorflow/core/profiler/utils/op_metrics_db_utils.cc @@ -25,13 +25,13 @@ limitations under the License. #include "absl/log/check.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" +#include "xla/tsl/profiler/utils/tf_op_utils.h" +#include "xla/tsl/profiler/utils/xplane_schema.h" +#include "xla/tsl/profiler/utils/xplane_visitor.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/profiler/protobuf/op_metrics.pb.h" #include "tensorflow/core/profiler/utils/math_utils.h" -#include "tsl/profiler/utils/tf_op_utils.h" -#include "tsl/profiler/utils/xplane_schema.h" -#include "tsl/profiler/utils/xplane_visitor.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/utils/op_metrics_db_utils.h b/tensorflow/core/profiler/utils/op_metrics_db_utils.h index a095a8e451cf0f..e3ff3fcc5f6205 100644 --- a/tensorflow/core/profiler/utils/op_metrics_db_utils.h +++ b/tensorflow/core/profiler/utils/op_metrics_db_utils.h @@ -24,10 +24,10 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" +#include "xla/tsl/profiler/utils/xplane_visitor.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/profiler/protobuf/op_metrics.pb.h" -#include "tsl/profiler/utils/xplane_visitor.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/utils/op_utils.cc b/tensorflow/core/profiler/utils/op_utils.cc index 8c8fbfdceb9492..52cbd2192d36f2 100644 --- a/tensorflow/core/profiler/utils/op_utils.cc +++ b/tensorflow/core/profiler/utils/op_utils.cc @@ -20,11 +20,11 @@ limitations under the License. #include #include "absl/strings/string_view.h" +#include "xla/tsl/profiler/utils/tf_op_utils.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/profiler/convert/op_metrics_db_combiner.h" #include "tensorflow/core/profiler/protobuf/op_metrics.pb.h" -#include "tsl/profiler/utils/tf_op_utils.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/utils/op_utils.h b/tensorflow/core/profiler/utils/op_utils.h index dea310aeae16fd..57ce74e0c6ba5c 100644 --- a/tensorflow/core/profiler/utils/op_utils.h +++ b/tensorflow/core/profiler/utils/op_utils.h @@ -17,11 +17,11 @@ limitations under the License. #define TENSORFLOW_CORE_PROFILER_UTILS_OP_UTILS_H_ #include "absl/strings/string_view.h" +#include "xla/tsl/profiler/utils/timespan.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/profiler/protobuf/op_metrics.pb.h" #include "tensorflow/core/profiler/utils/op_metrics_db_utils.h" -#include "tsl/profiler/utils/timespan.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/utils/step_intersection.cc b/tensorflow/core/profiler/utils/step_intersection.cc index 6fbf258b5e6b09..ed246abd9737ae 100644 --- a/tensorflow/core/profiler/utils/step_intersection.cc +++ b/tensorflow/core/profiler/utils/step_intersection.cc @@ -14,9 +14,9 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/profiler/utils/step_intersection.h" +#include "xla/tsl/profiler/utils/timespan.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/platform/logging.h" -#include "tsl/profiler/utils/timespan.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/utils/trace_utils.h b/tensorflow/core/profiler/utils/trace_utils.h index 735a0207db2c27..89e2b4cde93586 100644 --- a/tensorflow/core/profiler/utils/trace_utils.h +++ b/tensorflow/core/profiler/utils/trace_utils.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_CORE_PROFILER_UTILS_TRACE_UTILS_H_ #define TENSORFLOW_CORE_PROFILER_UTILS_TRACE_UTILS_H_ -#include "tsl/profiler/utils/trace_utils.h" +#include "xla/tsl/profiler/utils/trace_utils.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/utils/xplane_builder.h b/tensorflow/core/profiler/utils/xplane_builder.h index 873af726d37eab..c0e2c39b0dc6ac 100644 --- a/tensorflow/core/profiler/utils/xplane_builder.h +++ b/tensorflow/core/profiler/utils/xplane_builder.h @@ -22,7 +22,7 @@ limitations under the License. #include #include -#include "tsl/profiler/utils/xplane_builder.h" +#include "xla/tsl/profiler/utils/xplane_builder.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/utils/xplane_schema.h b/tensorflow/core/profiler/utils/xplane_schema.h index d6efbd1cd7a1b1..cfa748bf04ab8a 100644 --- a/tensorflow/core/profiler/utils/xplane_schema.h +++ b/tensorflow/core/profiler/utils/xplane_schema.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_CORE_PROFILER_UTILS_XPLANE_SCHEMA_H_ #define TENSORFLOW_CORE_PROFILER_UTILS_XPLANE_SCHEMA_H_ -#include "tsl/profiler/utils/xplane_schema.h" +#include "xla/tsl/profiler/utils/xplane_schema.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/utils/xplane_test_utils.h b/tensorflow/core/profiler/utils/xplane_test_utils.h index c3ed5de0f22237..c2619394d88445 100644 --- a/tensorflow/core/profiler/utils/xplane_test_utils.h +++ b/tensorflow/core/profiler/utils/xplane_test_utils.h @@ -19,10 +19,10 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/types/variant.h" +#include "xla/tsl/profiler/utils/xplane_test_utils.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/profiler/utils/xplane_builder.h" #include "tensorflow/core/profiler/utils/xplane_schema.h" -#include "tsl/profiler/utils/xplane_test_utils.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/utils/xplane_utils.h b/tensorflow/core/profiler/utils/xplane_utils.h index 75ed0d1b3ed330..9292ed6a6b8e30 100644 --- a/tensorflow/core/profiler/utils/xplane_utils.h +++ b/tensorflow/core/profiler/utils/xplane_utils.h @@ -20,7 +20,7 @@ limitations under the License. #include #include -#include "tsl/profiler/utils/xplane_utils.h" +#include "xla/tsl/profiler/utils/xplane_utils.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/utils/xplane_visitor.h b/tensorflow/core/profiler/utils/xplane_visitor.h index deebadbdee5c3b..81db4a4f1bd315 100644 --- a/tensorflow/core/profiler/utils/xplane_visitor.h +++ b/tensorflow/core/profiler/utils/xplane_visitor.h @@ -15,7 +15,7 @@ limitations under the License. #ifndef TENSORFLOW_CORE_PROFILER_UTILS_XPLANE_VISITOR_H_ #define TENSORFLOW_CORE_PROFILER_UTILS_XPLANE_VISITOR_H_ -#include "tsl/profiler/utils/xplane_visitor.h" +#include "xla/tsl/profiler/utils/xplane_visitor.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/python/profiler/internal/BUILD b/tensorflow/python/profiler/internal/BUILD index 8e07f478198d0d..7fe4593d15c6a5 100644 --- a/tensorflow/python/profiler/internal/BUILD +++ b/tensorflow/python/profiler/internal/BUILD @@ -179,8 +179,8 @@ cc_library( "@com_google_absl//absl/strings", "@com_google_absl//absl/time", "@com_google_absl//absl/types:variant", - "@local_tsl//tsl/profiler/utils:session_manager", "@local_xla//xla/tsl/profiler/convert:xplane_to_trace_events", "@local_xla//xla/tsl/profiler/rpc/client:capture_profile", + "@local_xla//xla/tsl/profiler/utils:session_manager", ], ) diff --git a/tensorflow/python/profiler/internal/profiler_pywrap_impl.cc b/tensorflow/python/profiler/internal/profiler_pywrap_impl.cc index 2a7af8ccac312c..a35f957f7f7d5b 100644 --- a/tensorflow/python/profiler/internal/profiler_pywrap_impl.cc +++ b/tensorflow/python/profiler/internal/profiler_pywrap_impl.cc @@ -30,6 +30,7 @@ limitations under the License. #include "absl/types/variant.h" #include "xla/tsl/profiler/convert/xplane_to_trace_events.h" #include "xla/tsl/profiler/rpc/client/capture_profile.h" +#include "xla/tsl/profiler/utils/session_manager.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/status.h" @@ -38,7 +39,6 @@ limitations under the License. #include "tensorflow/core/profiler/protobuf/xplane.pb.h" #include "tensorflow/core/profiler/rpc/client/save_profile.h" #include "tensorflow/core/profiler/rpc/profiler_server.h" -#include "tsl/profiler/utils/session_manager.h" namespace tensorflow { namespace profiler { diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/BUILD b/third_party/xla/third_party/tsl/tsl/platform/cloud/BUILD index 42d9a7985119cd..d6a7f8cb6328f3 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/cloud/BUILD +++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/BUILD @@ -396,9 +396,9 @@ tsl_cc_test( "//tsl/platform:strcat", "//tsl/platform:test", "//tsl/platform:test_main", - "//tsl/profiler/utils:time_utils_impl", "@local_xla//xla/tsl/lib/core:status_test_util", "@local_xla//xla/tsl/profiler/backends/cpu:traceme_recorder_impl", + "@local_xla//xla/tsl/profiler/utils:time_utils_impl", ], ) diff --git a/third_party/xla/third_party/tsl/tsl/profiler/lib/BUILD b/third_party/xla/third_party/tsl/tsl/profiler/lib/BUILD index ee6e82ea4fa185..16bc18085000e1 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/lib/BUILD +++ b/third_party/xla/third_party/tsl/tsl/profiler/lib/BUILD @@ -38,7 +38,7 @@ filegroup( "scoped_memory_debug_annotation.h", "traceme.h", "traceme_encode.h", - "//tsl/profiler/utils:mobile_srcs_no_runtime", + "@local_xla//xla/tsl/profiler/utils:mobile_srcs_no_runtime", ], compatible_with = get_compatible_with_portable(), visibility = ["//visibility:public"], @@ -221,7 +221,7 @@ cc_library( "//tsl/platform:platform_port", "//tsl/platform:status", "@local_xla//xla/tsl/profiler/convert:post_process_single_host_xplane", - "//tsl/profiler/utils:time_utils", + "@local_xla//xla/tsl/profiler/utils:time_utils", ]), alwayslink = True, ) @@ -268,11 +268,11 @@ cc_library( "//tsl/platform:logging", "//tsl/platform:macros", "//tsl/platform:types", - "//tsl/profiler/utils:no_init", "@com_google_absl//absl/strings", + "@local_xla//xla/tsl/profiler/utils:no_init", ] + if_not_android([ "@local_xla//xla/tsl/profiler/backends/cpu:traceme_recorder", - "//tsl/profiler/utils:time_utils", + "@local_xla//xla/tsl/profiler/utils:time_utils", ]), ) diff --git a/third_party/xla/third_party/tsl/tsl/profiler/lib/profiler_session.cc b/third_party/xla/third_party/tsl/tsl/profiler/lib/profiler_session.cc index 30e718cc456c94..2932415dceae2e 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/lib/profiler_session.cc +++ b/third_party/xla/third_party/tsl/tsl/profiler/lib/profiler_session.cc @@ -28,12 +28,12 @@ limitations under the License. #if !defined(IS_MOBILE_PLATFORM) #include "xla/tsl/profiler/convert/post_process_single_host_xplane.h" +#include "xla/tsl/profiler/utils/time_utils.h" #include "tsl/platform/host_info.h" #include "tsl/profiler/lib/profiler_collection.h" #include "tsl/profiler/lib/profiler_factory.h" #include "tsl/profiler/lib/profiler_interface.h" #include "tsl/profiler/lib/profiler_lock.h" -#include "tsl/profiler/utils/time_utils.h" #endif namespace tsl { diff --git a/third_party/xla/third_party/tsl/tsl/profiler/lib/traceme.h b/third_party/xla/third_party/tsl/tsl/profiler/lib/traceme.h index 4218d1c848a02a..ac5f0c14aba35c 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/lib/traceme.h +++ b/third_party/xla/third_party/tsl/tsl/profiler/lib/traceme.h @@ -21,14 +21,14 @@ limitations under the License. #include #include "absl/strings/string_view.h" +#include "xla/tsl/profiler/utils/no_init.h" #include "tsl/platform/logging.h" #include "tsl/platform/macros.h" #include "tsl/profiler/lib/traceme_encode.h" // IWYU pragma: export -#include "tsl/profiler/utils/no_init.h" #if !defined(IS_MOBILE_PLATFORM) #include "xla/tsl/profiler/backends/cpu/traceme_recorder.h" -#include "tsl/profiler/utils/time_utils.h" +#include "xla/tsl/profiler/utils/time_utils.h" #endif namespace tsl { diff --git a/third_party/xla/xla/backends/profiler/cpu/BUILD b/third_party/xla/xla/backends/profiler/cpu/BUILD index a67bf1fdf718e6..bd2556a12079f3 100644 --- a/third_party/xla/xla/backends/profiler/cpu/BUILD +++ b/third_party/xla/xla/backends/profiler/cpu/BUILD @@ -36,15 +36,15 @@ cc_library( "//xla/tsl/profiler/backends/cpu:host_tracer_utils", "//xla/tsl/profiler/backends/cpu:threadpool_listener", "//xla/tsl/profiler/backends/cpu:traceme_recorder", + "//xla/tsl/profiler/utils:time_utils", + "//xla/tsl/profiler/utils:xplane_schema", + "//xla/tsl/profiler/utils:xplane_utils", "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/profiler/lib:profiler_collection", "@local_tsl//tsl/profiler/lib:profiler_interface", "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_tsl//tsl/profiler/utils:time_utils", - "@local_tsl//tsl/profiler/utils:xplane_schema", - "@local_tsl//tsl/profiler/utils:xplane_utils", ], ) @@ -96,14 +96,14 @@ cc_library( ":metadata_utils", "//xla/service:hlo_proto_cc", "//xla/service:xla_debug_info_manager", + "//xla/tsl/profiler/utils:xplane_builder", + "//xla/tsl/profiler/utils:xplane_schema", + "//xla/tsl/profiler/utils:xplane_utils", "@com_google_absl//absl/status", "@local_tsl//tsl/profiler/lib:profiler_factory", "@local_tsl//tsl/profiler/lib:profiler_interface", "@local_tsl//tsl/profiler/protobuf:profiler_options_proto_cc", "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_tsl//tsl/profiler/utils:xplane_builder", - "@local_tsl//tsl/profiler/utils:xplane_schema", - "@local_tsl//tsl/profiler/utils:xplane_utils", ], alwayslink = True, ) @@ -117,9 +117,9 @@ cc_library( deps = [ "//xla/service:hlo_proto_cc", "//xla/tsl/profiler/convert:xla_op_utils", + "//xla/tsl/profiler/utils:xplane_builder", + "//xla/tsl/profiler/utils:xplane_schema", "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_tsl//tsl/profiler/utils:xplane_builder", - "@local_tsl//tsl/profiler/utils:xplane_schema", ], ) @@ -129,6 +129,10 @@ xla_cc_test( deps = [ ":host_tracer_impl", "//xla/tsl/lib/core:status_test_util", + "//xla/tsl/profiler/utils:tf_xplane_visitor", + "//xla/tsl/profiler/utils:timespan", + "//xla/tsl/profiler/utils:xplane_schema", + "//xla/tsl/profiler/utils:xplane_visitor", "@com_google_absl//absl/types:optional", "@com_google_googletest//:gtest_main", "@local_tsl//tsl/platform:blocking_counter", @@ -138,9 +142,5 @@ xla_cc_test( "@local_tsl//tsl/profiler/lib:profiler_interface", "@local_tsl//tsl/profiler/lib:traceme", "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_tsl//tsl/profiler/utils:tf_xplane_visitor", - "@local_tsl//tsl/profiler/utils:timespan", - "@local_tsl//tsl/profiler/utils:xplane_schema", - "@local_tsl//tsl/profiler/utils:xplane_visitor", ], ) diff --git a/third_party/xla/xla/backends/profiler/cpu/host_tracer.cc b/third_party/xla/xla/backends/profiler/cpu/host_tracer.cc index 031301290406ab..75843df3160e52 100644 --- a/third_party/xla/xla/backends/profiler/cpu/host_tracer.cc +++ b/third_party/xla/xla/backends/profiler/cpu/host_tracer.cc @@ -24,13 +24,13 @@ limitations under the License. #include "xla/tsl/profiler/backends/cpu/host_tracer_utils.h" #include "xla/tsl/profiler/backends/cpu/threadpool_listener.h" #include "xla/tsl/profiler/backends/cpu/traceme_recorder.h" +#include "xla/tsl/profiler/utils/time_utils.h" +#include "xla/tsl/profiler/utils/xplane_schema.h" +#include "xla/tsl/profiler/utils/xplane_utils.h" #include "tsl/platform/errors.h" #include "tsl/profiler/lib/profiler_collection.h" #include "tsl/profiler/lib/profiler_interface.h" #include "tsl/profiler/protobuf/xplane.pb.h" -#include "tsl/profiler/utils/time_utils.h" -#include "tsl/profiler/utils/xplane_schema.h" -#include "tsl/profiler/utils/xplane_utils.h" namespace xla { namespace profiler { diff --git a/third_party/xla/xla/backends/profiler/cpu/host_tracer_test.cc b/third_party/xla/xla/backends/profiler/cpu/host_tracer_test.cc index 2fca882f9910d8..ad7b241859d478 100644 --- a/third_party/xla/xla/backends/profiler/cpu/host_tracer_test.cc +++ b/third_party/xla/xla/backends/profiler/cpu/host_tracer_test.cc @@ -23,6 +23,10 @@ limitations under the License. #include #include "absl/types/optional.h" #include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/profiler/utils/tf_xplane_visitor.h" +#include "xla/tsl/profiler/utils/timespan.h" +#include "xla/tsl/profiler/utils/xplane_schema.h" +#include "xla/tsl/profiler/utils/xplane_visitor.h" #include "tsl/platform/blocking_counter.h" #include "tsl/platform/env.h" #include "tsl/platform/test.h" @@ -31,10 +35,6 @@ limitations under the License. #include "tsl/profiler/lib/profiler_interface.h" #include "tsl/profiler/lib/traceme.h" #include "tsl/profiler/protobuf/xplane.pb.h" -#include "tsl/profiler/utils/tf_xplane_visitor.h" -#include "tsl/profiler/utils/timespan.h" -#include "tsl/profiler/utils/xplane_schema.h" -#include "tsl/profiler/utils/xplane_visitor.h" namespace xla { namespace profiler { diff --git a/third_party/xla/xla/backends/profiler/cpu/metadata_collector.cc b/third_party/xla/xla/backends/profiler/cpu/metadata_collector.cc index 26735490d8851c..2f75c0e6c64676 100644 --- a/third_party/xla/xla/backends/profiler/cpu/metadata_collector.cc +++ b/third_party/xla/xla/backends/profiler/cpu/metadata_collector.cc @@ -22,13 +22,13 @@ limitations under the License. #include "xla/backends/profiler/cpu/metadata_utils.h" #include "xla/service/hlo.pb.h" #include "xla/service/xla_debug_info_manager.h" +#include "xla/tsl/profiler/utils/xplane_builder.h" +#include "xla/tsl/profiler/utils/xplane_schema.h" +#include "xla/tsl/profiler/utils/xplane_utils.h" #include "tsl/profiler/lib/profiler_factory.h" #include "tsl/profiler/lib/profiler_interface.h" #include "tsl/profiler/protobuf/profiler_options.pb.h" #include "tsl/profiler/protobuf/xplane.pb.h" -#include "tsl/profiler/utils/xplane_builder.h" -#include "tsl/profiler/utils/xplane_schema.h" -#include "tsl/profiler/utils/xplane_utils.h" namespace xla { namespace profiler { diff --git a/third_party/xla/xla/backends/profiler/cpu/metadata_utils.h b/third_party/xla/xla/backends/profiler/cpu/metadata_utils.h index 149e72fe259349..b30da770dda90f 100644 --- a/third_party/xla/xla/backends/profiler/cpu/metadata_utils.h +++ b/third_party/xla/xla/backends/profiler/cpu/metadata_utils.h @@ -18,9 +18,9 @@ limitations under the License. #include "xla/service/hlo.pb.h" #include "xla/tsl/profiler/convert/xla_op_utils.h" +#include "xla/tsl/profiler/utils/xplane_builder.h" +#include "xla/tsl/profiler/utils/xplane_schema.h" #include "tsl/profiler/protobuf/xplane.pb.h" -#include "tsl/profiler/utils/xplane_builder.h" -#include "tsl/profiler/utils/xplane_schema.h" namespace xla { namespace profiler { diff --git a/third_party/xla/xla/backends/profiler/gpu/BUILD b/third_party/xla/xla/backends/profiler/gpu/BUILD index b7fcada623f68e..64bdf6a97636f7 100644 --- a/third_party/xla/xla/backends/profiler/gpu/BUILD +++ b/third_party/xla/xla/backends/profiler/gpu/BUILD @@ -29,6 +29,7 @@ tsl_gpu_library( copts = tf_profiler_copts() + tsl_copts(), deps = [ ":cupti_utils", + "//xla/tsl/profiler/utils:time_utils", "//xla/tsl/util:env_var", "@com_google_absl//absl/container:fixed_array", "@com_google_absl//absl/container:flat_hash_map", @@ -39,7 +40,6 @@ tsl_gpu_library( "@local_tsl//tsl/profiler/lib:profiler_factory", "@local_tsl//tsl/profiler/lib:profiler_interface", "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_tsl//tsl/profiler/utils:time_utils", ] + if_cuda([ # keep sorted ":cupti_buffer_events", @@ -117,7 +117,7 @@ xla_test( ":cupti_wrapper", ":mock_cupti", "@com_google_absl//absl/memory", - "@local_tsl//tsl/profiler/utils:time_utils", + "//xla/tsl/profiler/utils:time_utils", ]), ) @@ -175,6 +175,8 @@ tsl_gpu_library( ":cupti_utils", ":nvtx_utils", "//xla/tsl/profiler/backends/cpu:annotation_stack", + "//xla/tsl/profiler/utils:lock_free_queue", + "//xla/tsl/profiler/utils:per_thread", "@com_google_absl//absl/cleanup", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status", @@ -186,8 +188,6 @@ tsl_gpu_library( "@local_tsl//tsl/platform:macros", "@local_tsl//tsl/platform:platform_port", "@local_tsl//tsl/platform:types", - "@local_tsl//tsl/profiler/utils:lock_free_queue", - "@local_tsl//tsl/profiler/utils:per_thread", ], ) @@ -232,6 +232,10 @@ tsl_gpu_library( deps = [ "//xla/stream_executor/rocm:roctracer_wrapper", "//xla/tsl/profiler/backends/cpu:annotation_stack", + "//xla/tsl/profiler/utils:parse_annotation", + "//xla/tsl/profiler/utils:xplane_builder", + "//xla/tsl/profiler/utils:xplane_schema", + "//xla/tsl/profiler/utils:xplane_utils", "//xla/tsl/util:env_var", "@com_google_absl//absl/container:fixed_array", "@com_google_absl//absl/container:flat_hash_map", @@ -251,10 +255,6 @@ tsl_gpu_library( "@local_tsl//tsl/platform:types", "@local_tsl//tsl/profiler/lib:profiler_factory", "@local_tsl//tsl/profiler/lib:profiler_interface", - "@local_tsl//tsl/profiler/utils:parse_annotation", - "@local_tsl//tsl/profiler/utils:xplane_builder", - "@local_tsl//tsl/profiler/utils:xplane_schema", - "@local_tsl//tsl/profiler/utils:xplane_utils", ], ) @@ -275,6 +275,7 @@ tsl_gpu_library( ":rocm_collector", "//xla/stream_executor/rocm:roctracer_wrapper", "//xla/tsl/profiler/backends/cpu:annotation_stack", + "//xla/tsl/profiler/utils:time_utils", "@com_google_absl//absl/container:fixed_array", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", @@ -288,7 +289,6 @@ tsl_gpu_library( "@local_tsl//tsl/platform:platform_port", "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:types", - "@local_tsl//tsl/profiler/utils:time_utils", ], ) @@ -313,6 +313,12 @@ tsl_gpu_library( deps = [ ":cupti_buffer_events", ":cupti_interface", + "//xla/tsl/profiler/utils:lock_free_queue", + "//xla/tsl/profiler/utils:parse_annotation", + "//xla/tsl/profiler/utils:trace_utils", + "//xla/tsl/profiler/utils:xplane_builder", + "//xla/tsl/profiler/utils:xplane_schema", + "//xla/tsl/profiler/utils:xplane_utils", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/hash", @@ -321,12 +327,6 @@ tsl_gpu_library( "@local_tsl//tsl/platform:mutex", "@local_tsl//tsl/platform:platform_port", "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_tsl//tsl/profiler/utils:lock_free_queue", - "@local_tsl//tsl/profiler/utils:parse_annotation", - "@local_tsl//tsl/profiler/utils:trace_utils", - "@local_tsl//tsl/profiler/utils:xplane_builder", - "@local_tsl//tsl/profiler/utils:xplane_schema", - "@local_tsl//tsl/profiler/utils:xplane_utils", ] + if_cuda([ "//xla/tsl/cuda:cupti", "//xla/tsl/cuda", @@ -341,6 +341,8 @@ tsl_gpu_library( visibility = ["//visibility:public"], deps = [ ":cupti_interface", + "//xla/tsl/profiler/utils:buffer_pool", + "//xla/tsl/profiler/utils:lock_free_queue", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:fixed_array", "@com_google_absl//absl/container:flat_hash_map", @@ -351,8 +353,6 @@ tsl_gpu_library( "@local_tsl//tsl/platform:mutex", "@local_tsl//tsl/platform:platform_port", "@local_tsl//tsl/platform:thread_annotations", - "@local_tsl//tsl/profiler/utils:buffer_pool", - "@local_tsl//tsl/profiler/utils:lock_free_queue", ] + if_cuda(["//xla/tsl/cuda:cupti"]), ) diff --git a/third_party/xla/xla/backends/profiler/gpu/cupti_buffer_events.h b/third_party/xla/xla/backends/profiler/gpu/cupti_buffer_events.h index f58dda54e623c1..d0a48535834024 100644 --- a/third_party/xla/xla/backends/profiler/gpu/cupti_buffer_events.h +++ b/third_party/xla/xla/backends/profiler/gpu/cupti_buffer_events.h @@ -31,10 +31,10 @@ limitations under the License. #include "absl/container/node_hash_set.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "xla/tsl/profiler/utils/buffer_pool.h" +#include "xla/tsl/profiler/utils/lock_free_queue.h" #include "tsl/platform/mutex.h" #include "tsl/platform/thread_annotations.h" -#include "tsl/profiler/utils/buffer_pool.h" -#include "tsl/profiler/utils/lock_free_queue.h" namespace xla { namespace profiler { diff --git a/third_party/xla/xla/backends/profiler/gpu/cupti_collector.cc b/third_party/xla/xla/backends/profiler/gpu/cupti_collector.cc index 49c9c86a49a89a..043bc4250c9681 100644 --- a/third_party/xla/xla/backends/profiler/gpu/cupti_collector.cc +++ b/third_party/xla/xla/backends/profiler/gpu/cupti_collector.cc @@ -28,15 +28,15 @@ limitations under the License. #include "third_party/gpus/cuda/extras/CUPTI/include/cupti_activity.h" #include "third_party/gpus/cuda/include/cuda.h" #include "third_party/gpus/cuda/include/cuda_occupancy.h" +#include "xla/tsl/profiler/utils/parse_annotation.h" +#include "xla/tsl/profiler/utils/trace_utils.h" +#include "xla/tsl/profiler/utils/xplane_builder.h" +#include "xla/tsl/profiler/utils/xplane_schema.h" +#include "xla/tsl/profiler/utils/xplane_utils.h" #include "tsl/platform/abi.h" #include "tsl/platform/host_info.h" #include "tsl/platform/mem.h" #include "tsl/platform/mutex.h" -#include "tsl/profiler/utils/parse_annotation.h" -#include "tsl/profiler/utils/trace_utils.h" -#include "tsl/profiler/utils/xplane_builder.h" -#include "tsl/profiler/utils/xplane_schema.h" -#include "tsl/profiler/utils/xplane_utils.h" namespace xla { namespace profiler { diff --git a/third_party/xla/xla/backends/profiler/gpu/cupti_error_manager_test.cc b/third_party/xla/xla/backends/profiler/gpu/cupti_error_manager_test.cc index a357d9ab41c97b..05aa020d84ab9e 100644 --- a/third_party/xla/xla/backends/profiler/gpu/cupti_error_manager_test.cc +++ b/third_party/xla/xla/backends/profiler/gpu/cupti_error_manager_test.cc @@ -27,8 +27,8 @@ limitations under the License. #include "xla/backends/profiler/gpu/cupti_tracer.h" #include "xla/backends/profiler/gpu/cupti_wrapper.h" #include "xla/backends/profiler/gpu/mock_cupti.h" +#include "xla/tsl/profiler/utils/time_utils.h" #include "tsl/platform/test.h" -#include "tsl/profiler/utils/time_utils.h" namespace xla { namespace profiler { diff --git a/third_party/xla/xla/backends/profiler/gpu/cupti_tracer.cc b/third_party/xla/xla/backends/profiler/gpu/cupti_tracer.cc index 3374181c569204..a21f804781ca6c 100644 --- a/third_party/xla/xla/backends/profiler/gpu/cupti_tracer.cc +++ b/third_party/xla/xla/backends/profiler/gpu/cupti_tracer.cc @@ -31,11 +31,11 @@ limitations under the License. #include "xla/backends/profiler/gpu/cupti_interface.h" #include "xla/backends/profiler/gpu/nvtx_utils.h" #include "xla/tsl/profiler/backends/cpu/annotation_stack.h" +#include "xla/tsl/profiler/utils/per_thread.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" #include "tsl/platform/host_info.h" #include "tsl/platform/logging.h" -#include "tsl/profiler/utils/per_thread.h" namespace xla { namespace profiler { diff --git a/third_party/xla/xla/backends/profiler/gpu/device_tracer_cuda.cc b/third_party/xla/xla/backends/profiler/gpu/device_tracer_cuda.cc index a34b5134b26455..578d4ab6d3021d 100644 --- a/third_party/xla/xla/backends/profiler/gpu/device_tracer_cuda.cc +++ b/third_party/xla/xla/backends/profiler/gpu/device_tracer_cuda.cc @@ -27,6 +27,7 @@ limitations under the License. #include "xla/backends/profiler/gpu/cupti_collector.h" #include "xla/backends/profiler/gpu/cupti_tracer.h" #include "xla/backends/profiler/gpu/cupti_wrapper.h" +#include "xla/tsl/profiler/utils/time_utils.h" #include "xla/tsl/util/env_var.h" #include "tsl/platform/errors.h" #include "tsl/platform/macros.h" @@ -34,7 +35,6 @@ limitations under the License. #include "tsl/profiler/lib/profiler_factory.h" #include "tsl/profiler/lib/profiler_interface.h" #include "tsl/profiler/protobuf/xplane.pb.h" -#include "tsl/profiler/utils/time_utils.h" namespace xla { namespace profiler { diff --git a/third_party/xla/xla/backends/profiler/gpu/device_tracer_rocm.cc b/third_party/xla/xla/backends/profiler/gpu/device_tracer_rocm.cc index aca42312a7d404..e5a91f9431b86a 100644 --- a/third_party/xla/xla/backends/profiler/gpu/device_tracer_rocm.cc +++ b/third_party/xla/xla/backends/profiler/gpu/device_tracer_rocm.cc @@ -27,6 +27,10 @@ limitations under the License. #include "xla/backends/profiler/gpu/rocm_collector.h" #include "xla/backends/profiler/gpu/rocm_tracer.h" #include "xla/tsl/profiler/backends/cpu/annotation_stack.h" +#include "xla/tsl/profiler/utils/parse_annotation.h" +#include "xla/tsl/profiler/utils/xplane_builder.h" +#include "xla/tsl/profiler/utils/xplane_schema.h" +#include "xla/tsl/profiler/utils/xplane_utils.h" #include "xla/tsl/util/env_var.h" #include "tsl/platform/abi.h" #include "tsl/platform/env_time.h" @@ -36,10 +40,6 @@ limitations under the License. #include "tsl/platform/thread_annotations.h" #include "tsl/profiler/lib/profiler_factory.h" #include "tsl/profiler/lib/profiler_interface.h" -#include "tsl/profiler/utils/parse_annotation.h" -#include "tsl/profiler/utils/xplane_builder.h" -#include "tsl/profiler/utils/xplane_schema.h" -#include "tsl/profiler/utils/xplane_utils.h" namespace xla { namespace profiler { diff --git a/third_party/xla/xla/backends/profiler/gpu/rocm_collector.cc b/third_party/xla/xla/backends/profiler/gpu/rocm_collector.cc index d96cfdc8ed23bb..88371e5b09605a 100644 --- a/third_party/xla/xla/backends/profiler/gpu/rocm_collector.cc +++ b/third_party/xla/xla/backends/profiler/gpu/rocm_collector.cc @@ -25,6 +25,10 @@ limitations under the License. #include "absl/types/optional.h" #include "xla/stream_executor/rocm/roctracer_wrapper.h" #include "xla/tsl/profiler/backends/cpu/annotation_stack.h" +#include "xla/tsl/profiler/utils/parse_annotation.h" +#include "xla/tsl/profiler/utils/xplane_builder.h" +#include "xla/tsl/profiler/utils/xplane_schema.h" +#include "xla/tsl/profiler/utils/xplane_utils.h" #include "xla/tsl/util/env_var.h" #include "tsl/platform/abi.h" #include "tsl/platform/env_time.h" @@ -36,10 +40,6 @@ limitations under the License. #include "tsl/platform/types.h" #include "tsl/profiler/lib/profiler_factory.h" #include "tsl/profiler/lib/profiler_interface.h" -#include "tsl/profiler/utils/parse_annotation.h" -#include "tsl/profiler/utils/xplane_builder.h" -#include "tsl/profiler/utils/xplane_schema.h" -#include "tsl/profiler/utils/xplane_utils.h" namespace xla { namespace profiler { diff --git a/third_party/xla/xla/backends/profiler/gpu/rocm_collector.h b/third_party/xla/xla/backends/profiler/gpu/rocm_collector.h index 2c9ccd847aed1e..220fa2bb13e4a2 100644 --- a/third_party/xla/xla/backends/profiler/gpu/rocm_collector.h +++ b/third_party/xla/xla/backends/profiler/gpu/rocm_collector.h @@ -18,7 +18,7 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/container/node_hash_set.h" -#include "tsl/profiler/utils/xplane_builder.h" +#include "xla/tsl/profiler/utils/xplane_builder.h" namespace xla { namespace profiler { diff --git a/third_party/xla/xla/backends/profiler/gpu/rocm_tracer.cc b/third_party/xla/xla/backends/profiler/gpu/rocm_tracer.cc index 5a77b679112e30..fad3e39831c49a 100644 --- a/third_party/xla/xla/backends/profiler/gpu/rocm_tracer.cc +++ b/third_party/xla/xla/backends/profiler/gpu/rocm_tracer.cc @@ -19,12 +19,12 @@ limitations under the License. #include "absl/container/node_hash_map.h" #include "rocm/rocm_config.h" #include "xla/tsl/profiler/backends/cpu/annotation_stack.h" +#include "xla/tsl/profiler/utils/time_utils.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" #include "tsl/platform/macros.h" #include "tsl/platform/mem.h" -#include "tsl/profiler/utils/time_utils.h" namespace xla { namespace profiler { diff --git a/third_party/xla/xla/backends/profiler/plugin/BUILD b/third_party/xla/xla/backends/profiler/plugin/BUILD index b50b08779f2d91..01556a6a0828c5 100644 --- a/third_party/xla/xla/backends/profiler/plugin/BUILD +++ b/third_party/xla/xla/backends/profiler/plugin/BUILD @@ -80,6 +80,8 @@ xla_cc_test( ":plugin_tracer_impl", ":profiler_c_api_hdrs", ":profiler_error", + "//xla/tsl/profiler/utils:xplane_builder", + "//xla/tsl/profiler/utils:xplane_visitor", "@com_google_absl//absl/status", "@com_google_googletest//:gtest_main", "@local_tsl//tsl/platform:logging", @@ -87,7 +89,5 @@ xla_cc_test( "@local_tsl//tsl/profiler/lib:profiler_interface", "@local_tsl//tsl/profiler/protobuf:profiler_options_proto_cc", "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_tsl//tsl/profiler/utils:xplane_builder", - "@local_tsl//tsl/profiler/utils:xplane_visitor", ], ) diff --git a/third_party/xla/xla/backends/profiler/plugin/plugin_tracer_impl_test.cc b/third_party/xla/xla/backends/profiler/plugin/plugin_tracer_impl_test.cc index b9dc203ddad276..a0e1cd45407a92 100644 --- a/third_party/xla/xla/backends/profiler/plugin/plugin_tracer_impl_test.cc +++ b/third_party/xla/xla/backends/profiler/plugin/plugin_tracer_impl_test.cc @@ -25,13 +25,13 @@ limitations under the License. #include "xla/backends/profiler/plugin/plugin_tracer.h" #include "xla/backends/profiler/plugin/profiler_c_api.h" #include "xla/backends/profiler/plugin/profiler_error.h" +#include "xla/tsl/profiler/utils/xplane_builder.h" +#include "xla/tsl/profiler/utils/xplane_visitor.h" #include "tsl/platform/logging.h" #include "tsl/profiler/lib/profiler_factory.h" #include "tsl/profiler/lib/profiler_interface.h" #include "tsl/profiler/protobuf/profiler_options.pb.h" #include "tsl/profiler/protobuf/xplane.pb.h" -#include "tsl/profiler/utils/xplane_builder.h" -#include "tsl/profiler/utils/xplane_visitor.h" namespace xla { namespace profiler { diff --git a/third_party/xla/xla/backends/profiler/tpu/BUILD b/third_party/xla/xla/backends/profiler/tpu/BUILD index bf6dd25a961296..9128dd6ff9339e 100644 --- a/third_party/xla/xla/backends/profiler/tpu/BUILD +++ b/third_party/xla/xla/backends/profiler/tpu/BUILD @@ -19,6 +19,7 @@ cc_library( "//xla/stream_executor/tpu:tpu_profiler_init_fns", "//xla/stream_executor/tpu:tsl_status_helper", "//xla/tsl/c:tsl_status", + "//xla/tsl/profiler/utils:xplane_schema", "@com_google_absl//absl/strings", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:status", @@ -27,7 +28,6 @@ cc_library( "@local_tsl//tsl/profiler/lib:profiler_interface", "@local_tsl//tsl/profiler/protobuf:profiler_options_proto_cc", "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_tsl//tsl/profiler/utils:xplane_schema", ], alwayslink = True, ) diff --git a/third_party/xla/xla/backends/profiler/tpu/tpu_tracer.cc b/third_party/xla/xla/backends/profiler/tpu/tpu_tracer.cc index 7488645cf40f2a..b602cdd79fe603 100644 --- a/third_party/xla/xla/backends/profiler/tpu/tpu_tracer.cc +++ b/third_party/xla/xla/backends/profiler/tpu/tpu_tracer.cc @@ -28,6 +28,7 @@ limitations under the License. #include "xla/stream_executor/tpu/tpu_ops_c_api.h" #include "xla/stream_executor/tpu/tsl_status_helper.h" #include "xla/tsl/c/tsl_status.h" +#include "xla/tsl/profiler/utils/xplane_schema.h" #include "tsl/platform/errors.h" #include "tsl/platform/status.h" #include "tsl/platform/types.h" @@ -35,7 +36,6 @@ limitations under the License. #include "tsl/profiler/lib/profiler_interface.h" #include "tsl/profiler/protobuf/profiler_options.pb.h" #include "tsl/profiler/protobuf/xplane.pb.h" -#include "tsl/profiler/utils/xplane_schema.h" #if !defined(PLATFORM_GOOGLE) #include "xla/stream_executor/tpu/tpu_profiler_init_fns.inc" diff --git a/third_party/xla/xla/python/BUILD b/third_party/xla/xla/python/BUILD index 0c5bed03247840..0cc840ac39c1af 100644 --- a/third_party/xla/xla/python/BUILD +++ b/third_party/xla/xla/python/BUILD @@ -1394,6 +1394,11 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/service:hlo_proto_cc", "//xla/tsl/profiler/convert:xla_op_utils", + "//xla/tsl/profiler/utils:file_system_utils", + "//xla/tsl/profiler/utils:tf_xplane_visitor", + "//xla/tsl/profiler/utils:xplane_schema", + "//xla/tsl/profiler/utils:xplane_utils", + "//xla/tsl/profiler/utils:xplane_visitor", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", @@ -1402,11 +1407,6 @@ cc_library( "@local_tsl//tsl/platform:types", "@local_tsl//tsl/profiler/protobuf:profiled_instructions_proto_cc", "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_tsl//tsl/profiler/utils:file_system_utils", - "@local_tsl//tsl/profiler/utils:tf_xplane_visitor", - "@local_tsl//tsl/profiler/utils:xplane_schema", - "@local_tsl//tsl/profiler/utils:xplane_utils", - "@local_tsl//tsl/profiler/utils:xplane_visitor", ], ) @@ -1419,13 +1419,13 @@ xla_cc_test( "//xla/tests:verified_hlo_module", "//xla/tsl/profiler/convert:xla_op_utils", "//xla/tsl/profiler/rpc/client:save_profile", + "//xla/tsl/profiler/utils:file_system_utils", + "//xla/tsl/profiler/utils:xplane_builder", + "//xla/tsl/profiler/utils:xplane_schema", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_main", "@local_tsl//tsl/profiler/protobuf:profiled_instructions_proto_cc_impl", "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_tsl//tsl/profiler/utils:file_system_utils", - "@local_tsl//tsl/profiler/utils:xplane_builder", - "@local_tsl//tsl/profiler/utils:xplane_schema", ], ) diff --git a/third_party/xla/xla/python/ifrt_proxy/client/BUILD b/third_party/xla/xla/python/ifrt_proxy/client/BUILD index ac827cfa0c643d..04c008c035e3b7 100644 --- a/third_party/xla/xla/python/ifrt_proxy/client/BUILD +++ b/third_party/xla/xla/python/ifrt_proxy/client/BUILD @@ -99,6 +99,7 @@ cc_library( "//xla/python/ifrt_proxy/common:ifrt_service_proto_cc", "//xla/python/ifrt_proxy/common:test_utils", "//xla/python/ifrt_proxy/common:types", + "//xla/tsl/profiler/utils:xplane_schema", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/functional:bind_front", @@ -115,7 +116,6 @@ cc_library( "@local_tsl//tsl/platform:status_to_from_proto", "@local_tsl//tsl/profiler/lib:traceme", "@local_tsl//tsl/profiler/lib:traceme_encode", - "@local_tsl//tsl/profiler/utils:xplane_schema", ] + if_google(["@com_google_absl//absl/types:source_location"]), ) diff --git a/third_party/xla/xla/python/ifrt_proxy/client/rpc_helper.cc b/third_party/xla/xla/python/ifrt_proxy/client/rpc_helper.cc index b2631019a6ac5e..19998ffd34619f 100644 --- a/third_party/xla/xla/python/ifrt_proxy/client/rpc_helper.cc +++ b/third_party/xla/xla/python/ifrt_proxy/client/rpc_helper.cc @@ -38,13 +38,13 @@ #include "xla/python/ifrt_proxy/common/ifrt_service.pb.h" #include "xla/python/ifrt_proxy/common/test_utils.h" #include "xla/python/ifrt_proxy/common/types.h" +#include "xla/tsl/profiler/utils/xplane_schema.h" #include "tsl/platform/env.h" #include "tsl/platform/random.h" #include "tsl/platform/status_to_from_proto.h" #include "tsl/platform/threadpool.h" #include "tsl/profiler/lib/traceme.h" #include "tsl/profiler/lib/traceme_encode.h" -#include "tsl/profiler/utils/xplane_schema.h" namespace xla { namespace ifrt { diff --git a/third_party/xla/xla/python/profiler/internal/BUILD b/third_party/xla/xla/python/profiler/internal/BUILD index aeebb16d8ca5c0..6d926ea97802de 100644 --- a/third_party/xla/xla/python/profiler/internal/BUILD +++ b/third_party/xla/xla/python/profiler/internal/BUILD @@ -21,6 +21,10 @@ cc_library( "//tensorflow/python/profiler/internal:__subpackages__", ]), deps = [ + "//xla/tsl/profiler/utils:time_utils", + "//xla/tsl/profiler/utils:xplane_builder", + "//xla/tsl/profiler/utils:xplane_schema", + "//xla/tsl/profiler/utils:xplane_utils", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log", "@com_google_absl//absl/memory", @@ -30,10 +34,6 @@ cc_library( "@local_tsl//tsl/platform:path", "@local_tsl//tsl/platform:types", "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_tsl//tsl/profiler/utils:time_utils", - "@local_tsl//tsl/profiler/utils:xplane_builder", - "@local_tsl//tsl/profiler/utils:xplane_schema", - "@local_tsl//tsl/profiler/utils:xplane_utils", "@pybind11", ], alwayslink = True, diff --git a/third_party/xla/xla/python/profiler/internal/python_hooks.cc b/third_party/xla/xla/python/profiler/internal/python_hooks.cc index ce66d686a982e9..0da1fe5e0124b5 100644 --- a/third_party/xla/xla/python/profiler/internal/python_hooks.cc +++ b/third_party/xla/xla/python/profiler/internal/python_hooks.cc @@ -21,13 +21,13 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/strings/strip.h" +#include "xla/tsl/profiler/utils/time_utils.h" +#include "xla/tsl/profiler/utils/xplane_builder.h" +#include "xla/tsl/profiler/utils/xplane_schema.h" +#include "xla/tsl/profiler/utils/xplane_utils.h" #include "tsl/platform/env.h" #include "tsl/platform/path.h" #include "tsl/profiler/protobuf/xplane.pb.h" -#include "tsl/profiler/utils/time_utils.h" -#include "tsl/profiler/utils/xplane_builder.h" -#include "tsl/profiler/utils/xplane_schema.h" -#include "tsl/profiler/utils/xplane_utils.h" namespace xla { namespace profiler { diff --git a/third_party/xla/xla/python/xplane_to_profile_instructions.cc b/third_party/xla/xla/python/xplane_to_profile_instructions.cc index a446cc810b0196..b0db73556e367b 100644 --- a/third_party/xla/xla/python/xplane_to_profile_instructions.cc +++ b/third_party/xla/xla/python/xplane_to_profile_instructions.cc @@ -30,15 +30,15 @@ limitations under the License. #include "xla/hlo/ir/hlo_module.h" #include "xla/service/hlo.pb.h" #include "xla/tsl/profiler/convert/xla_op_utils.h" +#include "xla/tsl/profiler/utils/file_system_utils.h" +#include "xla/tsl/profiler/utils/tf_xplane_visitor.h" +#include "xla/tsl/profiler/utils/xplane_schema.h" +#include "xla/tsl/profiler/utils/xplane_utils.h" +#include "xla/tsl/profiler/utils/xplane_visitor.h" #include "xla/xla.pb.h" #include "tsl/platform/env.h" #include "tsl/platform/types.h" #include "tsl/profiler/protobuf/xplane.pb.h" -#include "tsl/profiler/utils/file_system_utils.h" -#include "tsl/profiler/utils/tf_xplane_visitor.h" -#include "tsl/profiler/utils/xplane_schema.h" -#include "tsl/profiler/utils/xplane_utils.h" -#include "tsl/profiler/utils/xplane_visitor.h" namespace xla { namespace { diff --git a/third_party/xla/xla/python/xplane_to_profile_instructions_test.cc b/third_party/xla/xla/python/xplane_to_profile_instructions_test.cc index c0a93291a3ef13..ee77891fb6b61c 100644 --- a/third_party/xla/xla/python/xplane_to_profile_instructions_test.cc +++ b/third_party/xla/xla/python/xplane_to_profile_instructions_test.cc @@ -23,12 +23,12 @@ limitations under the License. #include "xla/tests/verified_hlo_module.h" #include "xla/tsl/profiler/convert/xla_op_utils.h" #include "xla/tsl/profiler/rpc/client/save_profile.h" +#include "xla/tsl/profiler/utils/file_system_utils.h" +#include "xla/tsl/profiler/utils/xplane_builder.h" +#include "xla/tsl/profiler/utils/xplane_schema.h" #include "tsl/platform/test.h" #include "tsl/profiler/protobuf/profiled_instructions.pb.h" #include "tsl/profiler/protobuf/xplane.pb.h" -#include "tsl/profiler/utils/file_system_utils.h" -#include "tsl/profiler/utils/xplane_builder.h" -#include "tsl/profiler/utils/xplane_schema.h" namespace xla { namespace { diff --git a/third_party/xla/xla/service/gpu/fusions/triton/BUILD b/third_party/xla/xla/service/gpu/fusions/triton/BUILD index 6a8c9ff80aec95..8fe71923b2656c 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/BUILD +++ b/third_party/xla/xla/service/gpu/fusions/triton/BUILD @@ -291,7 +291,7 @@ cc_library( deps = [ "//xla/backends/profiler/gpu:cupti_collector", "//xla/backends/profiler/gpu:cupti_tracer", - "@local_tsl//tsl/profiler/utils:time_utils", + "//xla/tsl/profiler/utils:time_utils", ], ) @@ -336,6 +336,7 @@ cc_library( "//xla/tests:filecheck", "//xla/tests:hlo_test_base", "//xla/tests:verified_hlo_module", + "//xla/tsl/profiler/utils:time_utils", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -348,7 +349,6 @@ cc_library( "@llvm-project//mlir:IR", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/profiler/utils:time_utils", ], ) diff --git a/third_party/xla/xla/service/gpu/fusions/triton/kernel_name_tracer_cuda.cc b/third_party/xla/xla/service/gpu/fusions/triton/kernel_name_tracer_cuda.cc index 5074c8c44f0be0..204f1d73710c03 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/kernel_name_tracer_cuda.cc +++ b/third_party/xla/xla/service/gpu/fusions/triton/kernel_name_tracer_cuda.cc @@ -18,7 +18,7 @@ limitations under the License. #include "xla/backends/profiler/gpu/cupti_collector.h" #include "xla/backends/profiler/gpu/cupti_tracer.h" #include "xla/service/gpu/fusions/triton/kernel_name_tracer.h" -#include "tsl/profiler/utils/time_utils.h" +#include "xla/tsl/profiler/utils/time_utils.h" namespace xla::gpu { diff --git a/third_party/xla/xla/tsl/lib/gtl/BUILD b/third_party/xla/xla/tsl/lib/gtl/BUILD index ceac5767c2cf08..9c0c8faf532110 100644 --- a/third_party/xla/xla/tsl/lib/gtl/BUILD +++ b/third_party/xla/xla/tsl/lib/gtl/BUILD @@ -30,7 +30,7 @@ package( "//xla:__subpackages__", "//tensorflow/core/lib/gtl:__subpackages__", "//xla/tsl/distributed_runtime/rpc:__pkg__", - "@local_tsl//tsl/profiler/utils:__pkg__", + "//xla/tsl/profiler/utils:__pkg__", ]), licenses = ["notice"], ) diff --git a/third_party/xla/xla/tsl/profiler/backends/cpu/BUILD b/third_party/xla/xla/tsl/profiler/backends/cpu/BUILD index 28c4a7c57b2142..4121ad5fc3a92b 100644 --- a/third_party/xla/xla/tsl/profiler/backends/cpu/BUILD +++ b/third_party/xla/xla/tsl/profiler/backends/cpu/BUILD @@ -40,6 +40,8 @@ cc_library( "//xla/tsl/profiler:xla_internal", ]), deps = [ + "//xla/tsl/profiler/utils:lock_free_queue", + "//xla/tsl/profiler/utils:per_thread", "@com_google_absl//absl/container:flat_hash_map", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:logging", @@ -47,8 +49,6 @@ cc_library( "@local_tsl//tsl/platform:mutex", "@local_tsl//tsl/platform:thread_annotations", "@local_tsl//tsl/platform:types", - "@local_tsl//tsl/profiler/utils:lock_free_queue", - "@local_tsl//tsl/profiler/utils:per_thread", ], alwayslink = True, ) @@ -59,6 +59,9 @@ tsl_cc_test( deps = [ ":traceme_recorder", ":traceme_recorder_impl", + "//xla/tsl/profiler/utils:math_utils", + "//xla/tsl/profiler/utils:time_utils", + "//xla/tsl/profiler/utils:time_utils_impl", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", "@local_tsl//tsl/platform:env", @@ -68,9 +71,6 @@ tsl_cc_test( "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_main", "@local_tsl//tsl/platform:types", - "@local_tsl//tsl/profiler/utils:math_utils", - "@local_tsl//tsl/profiler/utils:time_utils", - "@local_tsl//tsl/profiler/utils:time_utils_impl", ], ) @@ -120,13 +120,13 @@ cc_library( ]), deps = [ ":traceme_recorder", + "//xla/tsl/profiler/utils:parse_annotation", + "//xla/tsl/profiler/utils:tf_op_utils", + "//xla/tsl/profiler/utils:xplane_builder", + "//xla/tsl/profiler/utils:xplane_utils", "@com_google_absl//absl/strings", "@local_tsl//tsl/platform:types", "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_tsl//tsl/profiler/utils:parse_annotation", - "@local_tsl//tsl/profiler/utils:tf_op_utils", - "@local_tsl//tsl/profiler/utils:xplane_builder", - "@local_tsl//tsl/profiler/utils:xplane_utils", ], ) @@ -144,6 +144,8 @@ cc_library( deps = [ ":threadpool_listener_state", ":traceme_recorder", + "//xla/tsl/profiler/utils:time_utils", + "//xla/tsl/profiler/utils:xplane_schema", "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@local_tsl//tsl/platform:logging", @@ -153,8 +155,6 @@ cc_library( "@local_tsl//tsl/profiler/lib:profiler_interface", "@local_tsl//tsl/profiler/lib:traceme_encode", "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_tsl//tsl/profiler/utils:time_utils", - "@local_tsl//tsl/profiler/utils:xplane_schema", ], ) diff --git a/third_party/xla/xla/tsl/profiler/backends/cpu/host_tracer_utils.cc b/third_party/xla/xla/tsl/profiler/backends/cpu/host_tracer_utils.cc index 4195d8df45c953..3ee8fae3f04883 100644 --- a/third_party/xla/xla/tsl/profiler/backends/cpu/host_tracer_utils.cc +++ b/third_party/xla/xla/tsl/profiler/backends/cpu/host_tracer_utils.cc @@ -19,12 +19,12 @@ limitations under the License. #include "absl/strings/string_view.h" #include "xla/tsl/profiler/backends/cpu/traceme_recorder.h" +#include "xla/tsl/profiler/utils/parse_annotation.h" +#include "xla/tsl/profiler/utils/tf_op_utils.h" +#include "xla/tsl/profiler/utils/xplane_builder.h" +#include "xla/tsl/profiler/utils/xplane_utils.h" #include "tsl/platform/types.h" #include "tsl/profiler/protobuf/xplane.pb.h" -#include "tsl/profiler/utils/parse_annotation.h" -#include "tsl/profiler/utils/tf_op_utils.h" -#include "tsl/profiler/utils/xplane_builder.h" -#include "tsl/profiler/utils/xplane_utils.h" namespace tsl { namespace profiler { diff --git a/third_party/xla/xla/tsl/profiler/backends/cpu/threadpool_listener.cc b/third_party/xla/xla/tsl/profiler/backends/cpu/threadpool_listener.cc index 3772efb86adbb9..af9fc451b2d238 100644 --- a/third_party/xla/xla/tsl/profiler/backends/cpu/threadpool_listener.cc +++ b/third_party/xla/xla/tsl/profiler/backends/cpu/threadpool_listener.cc @@ -21,14 +21,14 @@ limitations under the License. #include "absl/status/status.h" #include "xla/tsl/profiler/backends/cpu/threadpool_listener_state.h" #include "xla/tsl/profiler/backends/cpu/traceme_recorder.h" +#include "xla/tsl/profiler/utils/time_utils.h" +#include "xla/tsl/profiler/utils/xplane_schema.h" #include "tsl/platform/logging.h" #include "tsl/platform/tracing.h" #include "tsl/platform/types.h" #include "tsl/profiler/lib/context_types.h" #include "tsl/profiler/lib/traceme_encode.h" #include "tsl/profiler/protobuf/xplane.pb.h" -#include "tsl/profiler/utils/time_utils.h" -#include "tsl/profiler/utils/xplane_schema.h" namespace tsl { namespace profiler { diff --git a/third_party/xla/xla/tsl/profiler/backends/cpu/traceme_recorder.cc b/third_party/xla/xla/tsl/profiler/backends/cpu/traceme_recorder.cc index 1e0b77e92ed7b4..1279b80dd62dbe 100644 --- a/third_party/xla/xla/tsl/profiler/backends/cpu/traceme_recorder.cc +++ b/third_party/xla/xla/tsl/profiler/backends/cpu/traceme_recorder.cc @@ -25,12 +25,12 @@ limitations under the License. #include #include "absl/container/flat_hash_map.h" +#include "xla/tsl/profiler/utils/lock_free_queue.h" +#include "xla/tsl/profiler/utils/per_thread.h" #include "tsl/platform/env.h" #include "tsl/platform/logging.h" #include "tsl/platform/macros.h" #include "tsl/platform/types.h" -#include "tsl/profiler/utils/lock_free_queue.h" -#include "tsl/profiler/utils/per_thread.h" namespace tsl { namespace profiler { diff --git a/third_party/xla/xla/tsl/profiler/backends/cpu/traceme_recorder_test.cc b/third_party/xla/xla/tsl/profiler/backends/cpu/traceme_recorder_test.cc index bdaf0e9ae3f81d..9fa89ed3d5e400 100644 --- a/third_party/xla/xla/tsl/profiler/backends/cpu/traceme_recorder_test.cc +++ b/third_party/xla/xla/tsl/profiler/backends/cpu/traceme_recorder_test.cc @@ -23,14 +23,14 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/strings/numbers.h" #include "absl/strings/str_cat.h" +#include "xla/tsl/profiler/utils/math_utils.h" +#include "xla/tsl/profiler/utils/time_utils.h" #include "tsl/platform/env.h" #include "tsl/platform/logging.h" #include "tsl/platform/notification.h" #include "tsl/platform/test.h" #include "tsl/platform/threadpool.h" #include "tsl/platform/types.h" -#include "tsl/profiler/utils/math_utils.h" -#include "tsl/profiler/utils/time_utils.h" namespace tsl { namespace profiler { diff --git a/third_party/xla/xla/tsl/profiler/convert/BUILD b/third_party/xla/xla/tsl/profiler/convert/BUILD index edda2c08a65d07..7b1eb9abd00ee2 100644 --- a/third_party/xla/xla/tsl/profiler/convert/BUILD +++ b/third_party/xla/xla/tsl/profiler/convert/BUILD @@ -55,11 +55,11 @@ cc_library( copts = tf_profiler_copts(), visibility = internal_visibility(["//xla/tsl/profiler:internal"]), deps = [ + "//xla/tsl/profiler/utils:timestamp_utils", + "//xla/tsl/profiler/utils:xplane_schema", + "//xla/tsl/profiler/utils:xplane_utils", "@local_tsl//tsl/platform:types", "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_tsl//tsl/profiler/utils:timestamp_utils", - "@local_tsl//tsl/profiler/utils:xplane_schema", - "@local_tsl//tsl/profiler/utils:xplane_utils", ], ) @@ -73,14 +73,14 @@ cc_library( ]), deps = [ ":trace_container", + "//xla/tsl/profiler/utils:format_utils", + "//xla/tsl/profiler/utils:math_utils", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/strings", "@jsoncpp_git//:jsoncpp", "@local_tsl//tsl/platform:protobuf", "@local_tsl//tsl/platform:types", "@local_tsl//tsl/profiler/protobuf:trace_events_proto_cc", - "@local_tsl//tsl/profiler/utils:format_utils", - "@local_tsl//tsl/profiler/utils:math_utils", ], ) @@ -120,16 +120,16 @@ cc_library( ]), deps = [ ":trace_container", + "//xla/tsl/profiler/utils:tf_xplane_visitor", + "//xla/tsl/profiler/utils:trace_utils", + "//xla/tsl/profiler/utils:xplane_schema", + "//xla/tsl/profiler/utils:xplane_utils", + "//xla/tsl/profiler/utils:xplane_visitor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", "@local_tsl//tsl/platform:types", "@local_tsl//tsl/profiler/protobuf:trace_events_proto_cc", "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_tsl//tsl/profiler/utils:tf_xplane_visitor", - "@local_tsl//tsl/profiler/utils:trace_utils", - "@local_tsl//tsl/profiler/utils:xplane_schema", - "@local_tsl//tsl/profiler/utils:xplane_utils", - "@local_tsl//tsl/profiler/utils:xplane_visitor", ], ) @@ -139,12 +139,12 @@ tsl_cc_test( srcs = ["xplane_to_trace_events_test.cc"], deps = [ ":xplane_to_trace_events", + "//xla/tsl/profiler/utils:trace_utils", + "//xla/tsl/profiler/utils:xplane_builder", + "//xla/tsl/profiler/utils:xplane_schema", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_main", "@local_tsl//tsl/profiler/protobuf:trace_events_proto_cc", "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_tsl//tsl/profiler/utils:trace_utils", - "@local_tsl//tsl/profiler/utils:xplane_builder", - "@local_tsl//tsl/profiler/utils:xplane_schema", ], ) diff --git a/third_party/xla/xla/tsl/profiler/convert/post_process_single_host_xplane.cc b/third_party/xla/xla/tsl/profiler/convert/post_process_single_host_xplane.cc index ec1c4e72ba970a..864da925423d99 100644 --- a/third_party/xla/xla/tsl/profiler/convert/post_process_single_host_xplane.cc +++ b/third_party/xla/xla/tsl/profiler/convert/post_process_single_host_xplane.cc @@ -17,11 +17,11 @@ limitations under the License. #include #include +#include "xla/tsl/profiler/utils/timestamp_utils.h" +#include "xla/tsl/profiler/utils/xplane_schema.h" +#include "xla/tsl/profiler/utils/xplane_utils.h" #include "tsl/platform/types.h" #include "tsl/profiler/protobuf/xplane.pb.h" -#include "tsl/profiler/utils/timestamp_utils.h" -#include "tsl/profiler/utils/xplane_schema.h" -#include "tsl/profiler/utils/xplane_utils.h" namespace tsl { namespace profiler { diff --git a/third_party/xla/xla/tsl/profiler/convert/trace_events_to_json.cc b/third_party/xla/xla/tsl/profiler/convert/trace_events_to_json.cc index 24d55cc81998be..d9bc3319fdbb5d 100644 --- a/third_party/xla/xla/tsl/profiler/convert/trace_events_to_json.cc +++ b/third_party/xla/xla/tsl/profiler/convert/trace_events_to_json.cc @@ -22,11 +22,11 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/strings/str_cat.h" #include "json/json.h" +#include "xla/tsl/profiler/utils/format_utils.h" +#include "xla/tsl/profiler/utils/math_utils.h" #include "tsl/platform/protobuf.h" #include "tsl/platform/types.h" #include "tsl/profiler/protobuf/trace_events.pb.h" -#include "tsl/profiler/utils/format_utils.h" -#include "tsl/profiler/utils/math_utils.h" namespace tsl { namespace profiler { diff --git a/third_party/xla/xla/tsl/profiler/convert/xplane_to_trace_events.cc b/third_party/xla/xla/tsl/profiler/convert/xplane_to_trace_events.cc index fbc6ea7da1a8d3..c37951436d7168 100644 --- a/third_party/xla/xla/tsl/profiler/convert/xplane_to_trace_events.cc +++ b/third_party/xla/xla/tsl/profiler/convert/xplane_to_trace_events.cc @@ -24,14 +24,14 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/types/optional.h" +#include "xla/tsl/profiler/utils/tf_xplane_visitor.h" +#include "xla/tsl/profiler/utils/trace_utils.h" +#include "xla/tsl/profiler/utils/xplane_schema.h" +#include "xla/tsl/profiler/utils/xplane_utils.h" +#include "xla/tsl/profiler/utils/xplane_visitor.h" #include "tsl/platform/types.h" #include "tsl/profiler/protobuf/trace_events.pb.h" #include "tsl/profiler/protobuf/xplane.pb.h" -#include "tsl/profiler/utils/tf_xplane_visitor.h" -#include "tsl/profiler/utils/trace_utils.h" -#include "tsl/profiler/utils/xplane_schema.h" -#include "tsl/profiler/utils/xplane_utils.h" -#include "tsl/profiler/utils/xplane_visitor.h" namespace tsl { namespace profiler { diff --git a/third_party/xla/xla/tsl/profiler/convert/xplane_to_trace_events_test.cc b/third_party/xla/xla/tsl/profiler/convert/xplane_to_trace_events_test.cc index 8a724596767a31..6e0d3955c84cbf 100644 --- a/third_party/xla/xla/tsl/profiler/convert/xplane_to_trace_events_test.cc +++ b/third_party/xla/xla/tsl/profiler/convert/xplane_to_trace_events_test.cc @@ -18,12 +18,12 @@ limitations under the License. #include #include +#include "xla/tsl/profiler/utils/trace_utils.h" +#include "xla/tsl/profiler/utils/xplane_builder.h" +#include "xla/tsl/profiler/utils/xplane_schema.h" #include "tsl/platform/test.h" #include "tsl/profiler/protobuf/trace_events.pb.h" #include "tsl/profiler/protobuf/xplane.pb.h" -#include "tsl/profiler/utils/trace_utils.h" -#include "tsl/profiler/utils/xplane_builder.h" -#include "tsl/profiler/utils/xplane_schema.h" namespace tsl { namespace profiler { diff --git a/third_party/xla/xla/tsl/profiler/rpc/BUILD b/third_party/xla/xla/tsl/profiler/rpc/BUILD index 6f96ed053dc550..2b902f0c35d391 100644 --- a/third_party/xla/xla/tsl/profiler/rpc/BUILD +++ b/third_party/xla/xla/tsl/profiler/rpc/BUILD @@ -29,6 +29,10 @@ cc_library( ]), deps = [ "//xla/tsl/profiler/rpc/client:save_profile", + "//xla/tsl/profiler/utils:file_system_utils", + "//xla/tsl/profiler/utils:math_utils", + "//xla/tsl/profiler/utils:time_utils", + "//xla/tsl/profiler/utils:xplane_utils", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", @@ -43,10 +47,6 @@ cc_library( "@local_tsl//tsl/profiler/protobuf:profiler_service_cc_grpc_proto", "@local_tsl//tsl/profiler/protobuf:profiler_service_proto_cc", "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_tsl//tsl/profiler/utils:file_system_utils", - "@local_tsl//tsl/profiler/utils:math_utils", - "@local_tsl//tsl/profiler/utils:time_utils", - "@local_tsl//tsl/profiler/utils:xplane_utils", ] + tsl_grpc_cc_dependencies(), ) diff --git a/third_party/xla/xla/tsl/profiler/rpc/client/BUILD b/third_party/xla/xla/tsl/profiler/rpc/client/BUILD index 9a891967712b0e..0c0afab5305816 100644 --- a/third_party/xla/xla/tsl/profiler/rpc/client/BUILD +++ b/third_party/xla/xla/tsl/profiler/rpc/client/BUILD @@ -36,6 +36,7 @@ cc_library( ":save_profile", "//xla/tsl/profiler/convert:trace_events_to_json", "//xla/tsl/profiler/convert:xplane_to_trace_events", + "//xla/tsl/profiler/utils:session_manager", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", "@com_google_absl//absl/time", @@ -47,7 +48,6 @@ cc_library( "@local_tsl//tsl/profiler/protobuf:profiler_options_proto_cc", "@local_tsl//tsl/profiler/protobuf:profiler_service_proto_cc", "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_tsl//tsl/profiler/utils:session_manager", ], ) @@ -66,6 +66,7 @@ cc_library( deps = [ "//xla/tsl/lib/io:zlib_compression_options", "//xla/tsl/lib/io:zlib_outputbuffer", + "//xla/tsl/profiler/utils:file_system_utils", "@com_google_absl//absl/strings", "@com_google_absl//absl/time", "@local_tsl//tsl/platform:env", @@ -75,7 +76,6 @@ cc_library( "@local_tsl//tsl/platform:types", "@local_tsl//tsl/profiler/protobuf:profiler_service_proto_cc", "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_tsl//tsl/profiler/utils:file_system_utils", ], ) @@ -161,6 +161,7 @@ tsl_cc_test( ":profiler_client_test_util", "//xla/tsl/profiler/rpc:profiler_server_impl", "//xla/tsl/profiler/rpc:profiler_service_impl", + "//xla/tsl/profiler/utils:time_utils_impl", "@com_google_absl//absl/time", "@local_tsl//tsl/platform:env_impl", "@local_tsl//tsl/platform:errors", @@ -170,7 +171,6 @@ tsl_cc_test( "@local_tsl//tsl/platform:types", "@local_tsl//tsl/profiler/lib:profiler_factory_impl", "@local_tsl//tsl/profiler/lib:profiler_session_impl", - "@local_tsl//tsl/profiler/utils:time_utils_impl", ] + tf_protos_profiler_service(), ) @@ -181,6 +181,7 @@ cc_library( copts = tf_profiler_copts(), deps = [ ":profiler_client_for_pybind", + "//xla/tsl/profiler/utils:time_utils", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/time", @@ -192,7 +193,6 @@ cc_library( "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:thread_annotations", "@local_tsl//tsl/platform:types", - "@local_tsl//tsl/profiler/utils:time_utils", ], ) @@ -205,6 +205,7 @@ tsl_cc_test( ":remote_profiler_session_manager", "//xla/tsl/profiler/rpc:profiler_server_impl", "//xla/tsl/profiler/rpc:profiler_service_impl", + "//xla/tsl/profiler/utils:time_utils_impl", "@com_google_absl//absl/status", "@com_google_absl//absl/time", "@local_tsl//tsl/platform:env_impl", @@ -216,6 +217,5 @@ tsl_cc_test( "@local_tsl//tsl/profiler/lib:profiler_factory_impl", "@local_tsl//tsl/profiler/lib:profiler_session_impl", "@local_tsl//tsl/profiler/protobuf:profiler_options_proto_cc", - "@local_tsl//tsl/profiler/utils:time_utils_impl", ] + tf_protos_profiler_service(), ) diff --git a/third_party/xla/xla/tsl/profiler/rpc/client/capture_profile.cc b/third_party/xla/xla/tsl/profiler/rpc/client/capture_profile.cc index b78ceadcda2b93..84dd66b6e2f118 100644 --- a/third_party/xla/xla/tsl/profiler/rpc/client/capture_profile.cc +++ b/third_party/xla/xla/tsl/profiler/rpc/client/capture_profile.cc @@ -31,6 +31,7 @@ limitations under the License. #include "xla/tsl/profiler/rpc/client/profiler_client.h" #include "xla/tsl/profiler/rpc/client/remote_profiler_session_manager.h" #include "xla/tsl/profiler/rpc/client/save_profile.h" +#include "xla/tsl/profiler/utils/session_manager.h" #include "tsl/platform/errors.h" #include "tsl/platform/host_info.h" #include "tsl/platform/status.h" @@ -38,7 +39,6 @@ limitations under the License. #include "tsl/profiler/protobuf/profiler_analysis.pb.h" #include "tsl/profiler/protobuf/profiler_options.pb.h" #include "tsl/profiler/protobuf/profiler_service.pb.h" -#include "tsl/profiler/utils/session_manager.h" namespace tsl { namespace profiler { diff --git a/third_party/xla/xla/tsl/profiler/rpc/client/remote_profiler_session_manager.cc b/third_party/xla/xla/tsl/profiler/rpc/client/remote_profiler_session_manager.cc index bc987157013dcf..2eb7e0d6743180 100644 --- a/third_party/xla/xla/tsl/profiler/rpc/client/remote_profiler_session_manager.cc +++ b/third_party/xla/xla/tsl/profiler/rpc/client/remote_profiler_session_manager.cc @@ -23,11 +23,11 @@ limitations under the License. #include "absl/time/clock.h" #include "absl/time/time.h" #include "xla/tsl/profiler/rpc/client/profiler_client.h" +#include "xla/tsl/profiler/utils/time_utils.h" #include "tsl/platform/env_time.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" #include "tsl/platform/types.h" -#include "tsl/profiler/utils/time_utils.h" namespace tsl { namespace profiler { diff --git a/third_party/xla/xla/tsl/profiler/rpc/client/save_profile.cc b/third_party/xla/xla/tsl/profiler/rpc/client/save_profile.cc index acdceeaa6e0c10..bc8bf69f492bfd 100644 --- a/third_party/xla/xla/tsl/profiler/rpc/client/save_profile.cc +++ b/third_party/xla/xla/tsl/profiler/rpc/client/save_profile.cc @@ -30,6 +30,7 @@ limitations under the License. #include "absl/time/time.h" #include "xla/tsl/lib/io/zlib_compression_options.h" #include "xla/tsl/lib/io/zlib_outputbuffer.h" +#include "xla/tsl/profiler/utils/file_system_utils.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" #include "tsl/platform/file_system.h" @@ -37,7 +38,6 @@ limitations under the License. #include "tsl/platform/status.h" #include "tsl/profiler/protobuf/profiler_service.pb.h" #include "tsl/profiler/protobuf/xplane.pb.h" -#include "tsl/profiler/utils/file_system_utils.h" namespace tsl { namespace profiler { diff --git a/third_party/xla/xla/tsl/profiler/rpc/profiler_service_impl.cc b/third_party/xla/xla/tsl/profiler/rpc/profiler_service_impl.cc index 3f4e3669174f5e..d359f0bdadb1fd 100644 --- a/third_party/xla/xla/tsl/profiler/rpc/profiler_service_impl.cc +++ b/third_party/xla/xla/tsl/profiler/rpc/profiler_service_impl.cc @@ -21,6 +21,10 @@ limitations under the License. #include "absl/strings/str_replace.h" #include "grpcpp/support/status.h" #include "xla/tsl/profiler/rpc/client/save_profile.h" +#include "xla/tsl/profiler/utils/file_system_utils.h" +#include "xla/tsl/profiler/utils/math_utils.h" +#include "xla/tsl/profiler/utils/time_utils.h" +#include "xla/tsl/profiler/utils/xplane_utils.h" #include "tsl/platform/env.h" #include "tsl/platform/env_time.h" #include "tsl/platform/errors.h" @@ -32,10 +36,6 @@ limitations under the License. #include "tsl/profiler/protobuf/profiler_service.grpc.pb.h" #include "tsl/profiler/protobuf/profiler_service.pb.h" #include "tsl/profiler/protobuf/xplane.pb.h" -#include "tsl/profiler/utils/file_system_utils.h" -#include "tsl/profiler/utils/math_utils.h" -#include "tsl/profiler/utils/time_utils.h" -#include "tsl/profiler/utils/xplane_utils.h" namespace tsl { namespace profiler { diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/BUILD b/third_party/xla/xla/tsl/profiler/utils/BUILD similarity index 69% rename from third_party/xla/third_party/tsl/tsl/profiler/utils/BUILD rename to third_party/xla/xla/tsl/profiler/utils/BUILD index 6539e90ea0157c..281fca182068f9 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/BUILD +++ b/third_party/xla/xla/tsl/profiler/utils/BUILD @@ -1,14 +1,14 @@ +load("@local_tsl//tsl/platform:build_config.bzl", "tsl_cc_test") +load("@local_tsl//tsl/platform:build_config_root.bzl", "if_static") load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") -load("@local_xla//xla/tsl:tsl.bzl", "internal_visibility") -load("@local_xla//xla/tsl:tsl.default.bzl", "filegroup", "get_compatible_with_portable") -load("@local_xla//xla/tsl/profiler/builds:build_config.bzl", "tf_profiler_copts") -load("//tsl/platform:build_config.bzl", "tsl_cc_test") -load("//tsl/platform:build_config_root.bzl", "if_static") +load("//xla/tsl:tsl.bzl", "internal_visibility") +load("//xla/tsl:tsl.default.bzl", "filegroup", "get_compatible_with_portable") +load("//xla/tsl/profiler/builds:build_config.bzl", "tf_profiler_copts") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = internal_visibility([ - "@local_xla//xla/tsl/profiler:internal", + "//xla/tsl/profiler:internal", ]), licenses = ["notice"], ) @@ -16,7 +16,7 @@ package( package_group( name = "friends", includes = [ - "@local_xla//xla/tsl/profiler:friends", + "//xla/tsl/profiler:friends", ], ) @@ -29,7 +29,7 @@ cc_library( name = "format_utils", hdrs = ["format_utils.h"], deps = [ - "//tsl/platform:logging", + "@local_tsl//tsl/platform:logging", ], ) @@ -53,9 +53,9 @@ cc_library( ], copts = tf_profiler_copts(), visibility = internal_visibility([ - "@local_xla//xla:__subpackages__", - "//tsl/platform/cloud:__pkg__", - "@local_xla//xla/tsl/profiler:internal", + "//xla:__subpackages__", + "@local_tsl//tsl/platform/cloud:__pkg__", + "//xla/tsl/profiler:internal", ]), deps = [ ":math_utils", @@ -71,9 +71,9 @@ cc_library( visibility = internal_visibility([":friends"]), deps = [ ":math_utils", - "//tsl/platform:logging", - "//tsl/platform:types", "@com_google_absl//absl/strings", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:types", ], ) @@ -82,8 +82,8 @@ tsl_cc_test( srcs = ["timespan_test.cc"], deps = [ ":timespan", - "//tsl/platform:test", - "//tsl/platform:test_main", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", ], ) @@ -93,9 +93,9 @@ cc_library( hdrs = ["tf_op_utils.h"], copts = tf_profiler_copts(), deps = [ - "//tsl/platform:macros", - "//tsl/platform:regexp", "@com_google_absl//absl/strings", + "@local_tsl//tsl/platform:macros", + "@local_tsl//tsl/platform:regexp", ], ) @@ -105,9 +105,9 @@ tsl_cc_test( srcs = ["tf_op_utils_test.cc"], deps = [ ":tf_op_utils", - "//tsl/platform:test", - "//tsl/platform:test_main", "@com_google_absl//absl/strings", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", ], ) @@ -119,15 +119,15 @@ cc_library( visibility = internal_visibility([":friends"]), deps = [ ":tf_op_utils", - "//tsl/platform:logging", - "//tsl/platform:macros", - "//tsl/platform:types", - "//tsl/profiler/lib:context_types_hdrs", + "//xla/tsl/lib/gtl:map_util", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/hash", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", - "@local_xla//xla/tsl/lib/gtl:map_util", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:macros", + "@local_tsl//tsl/platform:types", + "@local_tsl//tsl/profiler/lib:context_types_hdrs", ], ) @@ -139,12 +139,12 @@ cc_library( visibility = internal_visibility([":friends"]), deps = [ ":timespan", - "//tsl/platform:logging", - "//tsl/platform:types", - "//tsl/profiler/protobuf:xplane_proto_cc", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:types", + "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", ], ) @@ -157,14 +157,14 @@ cc_library( deps = [ ":math_utils", ":timespan", - "//tsl/platform:macros", - "//tsl/platform:protobuf", - "//tsl/platform:types", - "//tsl/profiler/protobuf:xplane_proto_cc", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/meta:type_traits", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", + "@local_tsl//tsl/platform:macros", + "@local_tsl//tsl/platform:protobuf", + "@local_tsl//tsl/platform:types", + "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", ], ) @@ -175,10 +175,10 @@ tsl_cc_test( deps = [ ":xplane_builder", ":xplane_visitor", - "//tsl/platform:test", - "//tsl/platform:test_main", - "//tsl/profiler/protobuf:xplane_proto_cc", "@com_google_absl//absl/strings", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", + "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", ], ) @@ -187,12 +187,12 @@ cc_library( hdrs = ["trace_utils.h"], copts = tf_profiler_copts(), visibility = internal_visibility([ - "@local_xla//xla/backends/profiler/gpu:__pkg__", - "@local_xla//xla/tsl/profiler:internal", + "//xla/backends/profiler/gpu:__pkg__", + "//xla/tsl/profiler:internal", ]), deps = [ - "//tsl/platform:types", "@com_google_absl//absl/strings", + "@local_tsl//tsl/platform:types", ], ) @@ -210,16 +210,16 @@ cc_library( ":xplane_builder", ":xplane_schema", ":xplane_visitor", - "//tsl/platform:fingerprint", - "//tsl/platform:types", - "//tsl/profiler/lib:context_types", - "//tsl/profiler/protobuf:xplane_proto_cc", + "//xla/tsl/util:stats_calculator_portable", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", "@com_google_absl//absl/strings", - "@local_xla//xla/tsl/util:stats_calculator_portable", + "@local_tsl//tsl/platform:fingerprint", + "@local_tsl//tsl/platform:types", + "@local_tsl//tsl/profiler/lib:context_types", + "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", ], ) @@ -232,14 +232,14 @@ tsl_cc_test( ":xplane_schema", ":xplane_utils", ":xplane_visitor", - "//tsl/platform:test", - "//tsl/platform:test_main", - "//tsl/platform:types", - "//tsl/profiler/protobuf:xplane_proto_cc", - "//tsl/profiler/utils:tf_xplane_visitor", + "//xla/tsl/profiler/utils:tf_xplane_visitor", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", + "@local_tsl//tsl/platform:types", + "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", ], ) @@ -251,7 +251,7 @@ cc_library( deps = [ ":xplane_schema", ":xplane_visitor", - "//tsl/profiler/protobuf:xplane_proto_cc", + "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", ], ) @@ -271,9 +271,9 @@ tsl_cc_test( srcs = ["parse_annotation_test.cc"], deps = [ ":parse_annotation", - "//tsl/platform:test", - "//tsl/platform:test_main", "@com_google_absl//absl/strings", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", ], ) @@ -289,18 +289,18 @@ cc_library( ":xplane_schema", ":xplane_utils", ":xplane_visitor", - "//tsl/platform:dso_loader", - "//tsl/platform:env", - "//tsl/platform:logging", - "//tsl/platform:types", - "//tsl/profiler/protobuf:xplane_proto_cc", + "//xla/tsl/lib/gtl:map_util", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/functional:bind_front", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", - "@local_xla//xla/tsl/lib/gtl:map_util", + "@local_tsl//tsl/platform:dso_loader", + "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:types", + "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", ], ) @@ -315,11 +315,11 @@ cc_library( ":xplane_builder", ":xplane_schema", ":xplane_utils", - "//tsl/platform:types", - "//tsl/profiler/protobuf:xplane_proto_cc", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:variant", + "@local_tsl//tsl/platform:types", + "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", ], ) @@ -333,13 +333,13 @@ tsl_cc_test( ":xplane_schema", ":xplane_test_utils", ":xplane_visitor", - "//tsl/platform:env_impl", - "//tsl/platform:test", - "//tsl/platform:test_main", - "//tsl/platform:types", - "//tsl/profiler/protobuf:xplane_proto_cc", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", + "@local_tsl//tsl/platform:env_impl", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", + "@local_tsl//tsl/platform:types", + "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", ], ) @@ -350,10 +350,10 @@ cc_library( deps = [ ":xplane_schema", ":xplane_utils", - "//tsl/platform:regexp", - "//tsl/profiler/protobuf:xplane_proto_cc", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", + "@local_tsl//tsl/platform:regexp", + "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", ], ) @@ -365,10 +365,10 @@ tsl_cc_test( ":xplane_schema", ":xplane_utils", ":xplane_visitor", - "//tsl/platform:test", - "//tsl/platform:test_main", - "//tsl/profiler/protobuf:xplane_proto_cc", "@com_google_absl//absl/strings", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", + "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", ], ) @@ -377,12 +377,12 @@ cc_library( hdrs = ["file_system_utils.h"], copts = tf_profiler_copts(), visibility = internal_visibility([ - "@local_xla//xla/python:__pkg__", - "@local_xla//xla/tsl/profiler:internal", + "//xla/python:__pkg__", + "//xla/tsl/profiler:internal", ]), deps = [ - "//tsl/platform", "@com_google_absl//absl/strings", + "@local_tsl//tsl/platform", ], ) @@ -392,14 +392,14 @@ cc_library( hdrs = ["buffer_pool.h"], copts = tf_profiler_copts(), visibility = internal_visibility([ - "@local_xla//xla/backends/profiler/gpu:__pkg__", - "@local_xla//xla/tsl/profiler:internal", + "//xla/backends/profiler/gpu:__pkg__", + "//xla/tsl/profiler:internal", ]), deps = [ - "//tsl/platform:logging", - "//tsl/platform:mutex", - "//tsl/platform:platform_port", - "//tsl/platform:thread_annotations", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:mutex", + "@local_tsl//tsl/platform:platform_port", + "@local_tsl//tsl/platform:thread_annotations", ], ) @@ -408,8 +408,8 @@ tsl_cc_test( srcs = ["buffer_pool_test.cc"], deps = [ ":buffer_pool", - "//tsl/platform:test", - "//tsl/platform:test_main", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", ], ) @@ -425,14 +425,14 @@ cc_library( ":xplane_builder", ":xplane_mutators", ":xplane_schema", - "//tsl/profiler/lib:context_types", - "//tsl/profiler/protobuf:xplane_proto_cc", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/hash", "@com_google_absl//absl/log", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", + "@local_tsl//tsl/profiler/lib:context_types", + "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", ], ) @@ -446,12 +446,12 @@ tsl_cc_test( ":xplane_schema", ":xplane_test_utils", ":xplane_visitor", - "//tsl/platform:test", - "//tsl/platform:test_main", - "//tsl/profiler/lib:connected_traceme", - "//tsl/profiler/protobuf:xplane_proto_cc", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/hash", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", + "@local_tsl//tsl/profiler/lib:connected_traceme", + "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", ], ) @@ -460,11 +460,11 @@ cc_library( srcs = ["session_manager.cc"], hdrs = ["session_manager.h"], deps = [ - "//tsl/platform:errors", - "//tsl/platform:status", - "//tsl/profiler/lib:profiler_session", - "//tsl/profiler/protobuf:profiler_options_proto_cc", "@com_google_absl//absl/container:flat_hash_map", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:status", + "@local_tsl//tsl/profiler/lib:profiler_session", + "@local_tsl//tsl/profiler/protobuf:profiler_options_proto_cc", ], ) @@ -476,8 +476,8 @@ cc_library( ":xplane_builder", ":xplane_schema", ":xplane_utils", - "//tsl/profiler/protobuf:xplane_proto_cc", "@com_google_absl//absl/log", + "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", ], ) @@ -489,8 +489,8 @@ tsl_cc_test( ":xplane_schema", ":xplane_utils", ":xplane_visitor", - "//tsl/platform:test", - "//tsl/platform:test_main", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", ], ) @@ -516,8 +516,8 @@ cc_library( hdrs = ["lock_free_queue.h"], deps = [ ":no_init", - "//tsl/platform:logging", - "//tsl/platform:macros", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:macros", ], ) @@ -528,11 +528,11 @@ tsl_cc_test( srcs = ["lock_free_queue_test.cc"], deps = [ ":lock_free_queue", - "//tsl/platform:env", - "//tsl/platform:env_impl", - "//tsl/platform:test", - "//tsl/platform:test_main", "@com_google_absl//absl/synchronization", + "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:env_impl", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", ], ) @@ -559,13 +559,13 @@ tsl_cc_test( srcs = ["per_thread_test.cc"], deps = [ ":per_thread", - "//tsl/platform:env", - "//tsl/platform:env_impl", - "//tsl/platform:test", - "//tsl/platform:test_main", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", + "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:env_impl", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", ], ) @@ -575,8 +575,8 @@ cc_library( hdrs = ["device_utils.h"], deps = [ ":xplane_schema", - "//tsl/profiler/protobuf:xplane_proto_cc", "@com_google_absl//absl/strings", + "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", ], ) @@ -586,9 +586,9 @@ tsl_cc_test( deps = [ ":device_utils", ":xplane_schema", - "//tsl/platform:test", - "//tsl/platform:test_main", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", ], ) diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/buffer_pool.cc b/third_party/xla/xla/tsl/profiler/utils/buffer_pool.cc similarity index 98% rename from third_party/xla/third_party/tsl/tsl/profiler/utils/buffer_pool.cc rename to third_party/xla/xla/tsl/profiler/utils/buffer_pool.cc index e7811327c475b2..f16fe91d573a8b 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/buffer_pool.cc +++ b/third_party/xla/xla/tsl/profiler/utils/buffer_pool.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/profiler/utils/buffer_pool.h" +#include "xla/tsl/profiler/utils/buffer_pool.h" #include diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/buffer_pool.h b/third_party/xla/xla/tsl/profiler/utils/buffer_pool.h similarity index 92% rename from third_party/xla/third_party/tsl/tsl/profiler/utils/buffer_pool.h rename to third_party/xla/xla/tsl/profiler/utils/buffer_pool.h index dcfd5b0acb6a1b..5482b7cd8bc261 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/buffer_pool.h +++ b/third_party/xla/xla/tsl/profiler/utils/buffer_pool.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_PROFILER_UTILS_BUFFER_POOL_H_ -#define TENSORFLOW_TSL_PROFILER_UTILS_BUFFER_POOL_H_ +#ifndef XLA_TSL_PROFILER_UTILS_BUFFER_POOL_H_ +#define XLA_TSL_PROFILER_UTILS_BUFFER_POOL_H_ #include @@ -59,4 +59,4 @@ class BufferPool { } // namespace profiler } // namespace tsl -#endif // TENSORFLOW_TSL_PROFILER_UTILS_BUFFER_POOL_H_ +#endif // XLA_TSL_PROFILER_UTILS_BUFFER_POOL_H_ diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/buffer_pool_test.cc b/third_party/xla/xla/tsl/profiler/utils/buffer_pool_test.cc similarity index 98% rename from third_party/xla/third_party/tsl/tsl/profiler/utils/buffer_pool_test.cc rename to third_party/xla/xla/tsl/profiler/utils/buffer_pool_test.cc index ec1696a12e08b7..4e5dbab63085de 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/buffer_pool_test.cc +++ b/third_party/xla/xla/tsl/profiler/utils/buffer_pool_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/profiler/utils/buffer_pool.h" +#include "xla/tsl/profiler/utils/buffer_pool.h" #include "tsl/platform/test.h" diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/device_utils.cc b/third_party/xla/xla/tsl/profiler/utils/device_utils.cc similarity index 92% rename from third_party/xla/third_party/tsl/tsl/profiler/utils/device_utils.cc rename to third_party/xla/xla/tsl/profiler/utils/device_utils.cc index 9caedcc47be08c..945a157ec9c456 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/device_utils.cc +++ b/third_party/xla/xla/tsl/profiler/utils/device_utils.cc @@ -13,11 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/profiler/utils/device_utils.h" +#include "xla/tsl/profiler/utils/device_utils.h" #include "absl/strings/match.h" +#include "xla/tsl/profiler/utils/xplane_schema.h" #include "tsl/profiler/protobuf/xplane.pb.h" -#include "tsl/profiler/utils/xplane_schema.h" namespace tsl { namespace profiler { diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/device_utils.h b/third_party/xla/xla/tsl/profiler/utils/device_utils.h similarity index 85% rename from third_party/xla/third_party/tsl/tsl/profiler/utils/device_utils.h rename to third_party/xla/xla/tsl/profiler/utils/device_utils.h index 33c331a0790a6e..825a9fe975437d 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/device_utils.h +++ b/third_party/xla/xla/tsl/profiler/utils/device_utils.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_PROFILER_UTILS_DEVICE_UTILS_H_ -#define TENSORFLOW_TSL_PROFILER_UTILS_DEVICE_UTILS_H_ +#ifndef XLA_TSL_PROFILER_UTILS_DEVICE_UTILS_H_ +#define XLA_TSL_PROFILER_UTILS_DEVICE_UTILS_H_ #include "tsl/profiler/protobuf/xplane.pb.h" @@ -34,4 +34,4 @@ DeviceType GetDeviceType(const tensorflow::profiler::XPlane& plane); } // namespace profiler } // namespace tsl -#endif // TENSORFLOW_TSL_PROFILER_UTILS_DEVICE_UTILS_H_ +#endif // XLA_TSL_PROFILER_UTILS_DEVICE_UTILS_H_ diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/device_utils_test.cc b/third_party/xla/xla/tsl/profiler/utils/device_utils_test.cc similarity index 93% rename from third_party/xla/third_party/tsl/tsl/profiler/utils/device_utils_test.cc rename to third_party/xla/xla/tsl/profiler/utils/device_utils_test.cc index e01680678c2b19..6f872dc5713bed 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/device_utils_test.cc +++ b/third_party/xla/xla/tsl/profiler/utils/device_utils_test.cc @@ -13,12 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/profiler/utils/device_utils.h" +#include "xla/tsl/profiler/utils/device_utils.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "xla/tsl/profiler/utils/xplane_schema.h" #include "tsl/platform/test.h" -#include "tsl/profiler/utils/xplane_schema.h" namespace tsl { namespace profiler { diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/file_system_utils.h b/third_party/xla/xla/tsl/profiler/utils/file_system_utils.h similarity index 91% rename from third_party/xla/third_party/tsl/tsl/profiler/utils/file_system_utils.h rename to third_party/xla/xla/tsl/profiler/utils/file_system_utils.h index 6d7c937908f43c..522b5284afec49 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/file_system_utils.h +++ b/third_party/xla/xla/tsl/profiler/utils/file_system_utils.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_PROFILER_UTILS_FILE_SYSTEM_UTILS_H_ -#define TENSORFLOW_TSL_PROFILER_UTILS_FILE_SYSTEM_UTILS_H_ +#ifndef XLA_TSL_PROFILER_UTILS_FILE_SYSTEM_UTILS_H_ +#define XLA_TSL_PROFILER_UTILS_FILE_SYSTEM_UTILS_H_ #include #include @@ -66,4 +66,4 @@ std::string ProfilerJoinPath(const T&... args) { } // namespace profiler } // namespace tsl -#endif // TENSORFLOW_TSL_PROFILER_UTILS_FILE_SYSTEM_UTILS_H_ +#endif // XLA_TSL_PROFILER_UTILS_FILE_SYSTEM_UTILS_H_ diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/format_utils.h b/third_party/xla/xla/tsl/profiler/utils/format_utils.h similarity index 91% rename from third_party/xla/third_party/tsl/tsl/profiler/utils/format_utils.h rename to third_party/xla/xla/tsl/profiler/utils/format_utils.h index 4a1be939de9ccd..d93d69e8592d70 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/format_utils.h +++ b/third_party/xla/xla/tsl/profiler/utils/format_utils.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_PROFILER_UTILS_FORMAT_UTILS_H_ -#define TENSORFLOW_TSL_PROFILER_UTILS_FORMAT_UTILS_H_ +#ifndef XLA_TSL_PROFILER_UTILS_FORMAT_UTILS_H_ +#define XLA_TSL_PROFILER_UTILS_FORMAT_UTILS_H_ #include @@ -60,4 +60,4 @@ inline std::string MaxPrecision(double d) { } // namespace profiler } // namespace tsl -#endif // TENSORFLOW_TSL_PROFILER_UTILS_FORMAT_UTILS_H_ +#endif // XLA_TSL_PROFILER_UTILS_FORMAT_UTILS_H_ diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/group_events.cc b/third_party/xla/xla/tsl/profiler/utils/group_events.cc similarity index 99% rename from third_party/xla/third_party/tsl/tsl/profiler/utils/group_events.cc rename to third_party/xla/xla/tsl/profiler/utils/group_events.cc index d8f3d4ad94c12e..393e170b839446 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/group_events.cc +++ b/third_party/xla/xla/tsl/profiler/utils/group_events.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/profiler/utils/group_events.h" +#include "xla/tsl/profiler/utils/group_events.h" #include #include @@ -32,14 +32,14 @@ limitations under the License. #include "absl/functional/bind_front.h" #include "absl/strings/str_cat.h" #include "xla/tsl/lib/gtl/map_util.h" +#include "xla/tsl/profiler/utils/tf_xplane_visitor.h" +#include "xla/tsl/profiler/utils/xplane_builder.h" +#include "xla/tsl/profiler/utils/xplane_schema.h" +#include "xla/tsl/profiler/utils/xplane_utils.h" +#include "xla/tsl/profiler/utils/xplane_visitor.h" #include "tsl/platform/dso_loader.h" #include "tsl/platform/env.h" #include "tsl/platform/types.h" -#include "tsl/profiler/utils/tf_xplane_visitor.h" -#include "tsl/profiler/utils/xplane_builder.h" -#include "tsl/profiler/utils/xplane_schema.h" -#include "tsl/profiler/utils/xplane_utils.h" -#include "tsl/profiler/utils/xplane_visitor.h" namespace tsl { namespace profiler { diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/group_events.h b/third_party/xla/xla/tsl/profiler/utils/group_events.h similarity index 97% rename from third_party/xla/third_party/tsl/tsl/profiler/utils/group_events.h rename to third_party/xla/xla/tsl/profiler/utils/group_events.h index c77c32a623ea08..52a73529fb734c 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/group_events.h +++ b/third_party/xla/xla/tsl/profiler/utils/group_events.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_PROFILER_UTILS_GROUP_EVENTS_H_ -#define TENSORFLOW_TSL_PROFILER_UTILS_GROUP_EVENTS_H_ +#ifndef XLA_TSL_PROFILER_UTILS_GROUP_EVENTS_H_ +#define XLA_TSL_PROFILER_UTILS_GROUP_EVENTS_H_ #include #include @@ -28,11 +28,11 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" +#include "xla/tsl/profiler/utils/xplane_builder.h" +#include "xla/tsl/profiler/utils/xplane_visitor.h" #include "tsl/platform/logging.h" #include "tsl/platform/types.h" #include "tsl/profiler/protobuf/xplane.pb.h" -#include "tsl/profiler/utils/xplane_builder.h" -#include "tsl/profiler/utils/xplane_visitor.h" namespace tsl { namespace profiler { @@ -258,4 +258,4 @@ void GroupTpuEventsOSS( } // namespace profiler } // namespace tsl -#endif // TENSORFLOW_TSL_PROFILER_UTILS_GROUP_EVENTS_H_ +#endif // XLA_TSL_PROFILER_UTILS_GROUP_EVENTS_H_ diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/group_events_test.cc b/third_party/xla/xla/tsl/profiler/utils/group_events_test.cc similarity index 99% rename from third_party/xla/third_party/tsl/tsl/profiler/utils/group_events_test.cc rename to third_party/xla/xla/tsl/profiler/utils/group_events_test.cc index e3607626263004..e8c3306ee4ea3d 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/group_events_test.cc +++ b/third_party/xla/xla/tsl/profiler/utils/group_events_test.cc @@ -13,20 +13,20 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/profiler/utils/group_events.h" +#include "xla/tsl/profiler/utils/group_events.h" #include #include "absl/container/flat_hash_map.h" #include "absl/strings/string_view.h" +#include "xla/tsl/profiler/utils/tf_xplane_visitor.h" +#include "xla/tsl/profiler/utils/xplane_builder.h" +#include "xla/tsl/profiler/utils/xplane_schema.h" +#include "xla/tsl/profiler/utils/xplane_test_utils.h" +#include "xla/tsl/profiler/utils/xplane_visitor.h" #include "tsl/platform/test.h" #include "tsl/platform/types.h" #include "tsl/profiler/protobuf/xplane.pb.h" -#include "tsl/profiler/utils/tf_xplane_visitor.h" -#include "tsl/profiler/utils/xplane_builder.h" -#include "tsl/profiler/utils/xplane_schema.h" -#include "tsl/profiler/utils/xplane_test_utils.h" -#include "tsl/profiler/utils/xplane_visitor.h" namespace tsl { namespace profiler { diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/lock_free_queue.h b/third_party/xla/xla/tsl/profiler/utils/lock_free_queue.h similarity index 97% rename from third_party/xla/third_party/tsl/tsl/profiler/utils/lock_free_queue.h rename to third_party/xla/xla/tsl/profiler/utils/lock_free_queue.h index d8f197be6c314b..9f22aa8b8e5094 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/lock_free_queue.h +++ b/third_party/xla/xla/tsl/profiler/utils/lock_free_queue.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_PROFILER_UTILS_LOCK_FREE_QUEUE_H_ -#define TENSORFLOW_TSL_PROFILER_UTILS_LOCK_FREE_QUEUE_H_ +#ifndef XLA_TSL_PROFILER_UTILS_LOCK_FREE_QUEUE_H_ +#define XLA_TSL_PROFILER_UTILS_LOCK_FREE_QUEUE_H_ #include @@ -23,9 +23,9 @@ limitations under the License. #include #include +#include "xla/tsl/profiler/utils/no_init.h" #include "tsl/platform/logging.h" #include "tsl/platform/macros.h" -#include "tsl/profiler/utils/no_init.h" namespace tsl { namespace profiler { @@ -311,4 +311,4 @@ class LockFreeQueue final } // namespace profiler } // namespace tsl -#endif // TENSORFLOW_TSL_PROFILER_UTILS_LOCK_FREE_QUEUE_H_ +#endif // XLA_TSL_PROFILER_UTILS_LOCK_FREE_QUEUE_H_ diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/lock_free_queue_test.cc b/third_party/xla/xla/tsl/profiler/utils/lock_free_queue_test.cc similarity index 99% rename from third_party/xla/third_party/tsl/tsl/profiler/utils/lock_free_queue_test.cc rename to third_party/xla/xla/tsl/profiler/utils/lock_free_queue_test.cc index 78a8b07d7bdc20..fd8ccdfb659207 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/lock_free_queue_test.cc +++ b/third_party/xla/xla/tsl/profiler/utils/lock_free_queue_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/profiler/utils/lock_free_queue.h" +#include "xla/tsl/profiler/utils/lock_free_queue.h" #include #include diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/math_utils.h b/third_party/xla/xla/tsl/profiler/utils/math_utils.h similarity index 94% rename from third_party/xla/third_party/tsl/tsl/profiler/utils/math_utils.h rename to third_party/xla/xla/tsl/profiler/utils/math_utils.h index 06b3495576d887..cd9e8685e8c35c 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/math_utils.h +++ b/third_party/xla/xla/tsl/profiler/utils/math_utils.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_PROFILER_UTILS_MATH_UTILS_H_ -#define TENSORFLOW_TSL_PROFILER_UTILS_MATH_UTILS_H_ +#ifndef XLA_TSL_PROFILER_UTILS_MATH_UTILS_H_ +#define XLA_TSL_PROFILER_UTILS_MATH_UTILS_H_ #include @@ -70,4 +70,4 @@ inline double GibibytesPerSecond(double gigabytes, double ns) { } // namespace profiler } // namespace tsl -#endif // TENSORFLOW_TSL_PROFILER_UTILS_MATH_UTILS_H_ +#endif // XLA_TSL_PROFILER_UTILS_MATH_UTILS_H_ diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/no_init.h b/third_party/xla/xla/tsl/profiler/utils/no_init.h similarity index 89% rename from third_party/xla/third_party/tsl/tsl/profiler/utils/no_init.h rename to third_party/xla/xla/tsl/profiler/utils/no_init.h index 5beb1908380c90..6f7d6aa95d1b25 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/no_init.h +++ b/third_party/xla/xla/tsl/profiler/utils/no_init.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_PROFILER_UTILS_NO_INIT_H_ -#define TENSORFLOW_TSL_PROFILER_UTILS_NO_INIT_H_ +#ifndef XLA_TSL_PROFILER_UTILS_NO_INIT_H_ +#define XLA_TSL_PROFILER_UTILS_NO_INIT_H_ #include @@ -48,4 +48,4 @@ union NoInit { } // namespace profiler } // namespace tsl -#endif // TENSORFLOW_TSL_PROFILER_UTILS_NO_INIT_H_ +#endif // XLA_TSL_PROFILER_UTILS_NO_INIT_H_ diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/parse_annotation.cc b/third_party/xla/xla/tsl/profiler/utils/parse_annotation.cc similarity index 98% rename from third_party/xla/third_party/tsl/tsl/profiler/utils/parse_annotation.cc rename to third_party/xla/xla/tsl/profiler/utils/parse_annotation.cc index 986f9b08e65a32..67328c1ea6e9bc 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/parse_annotation.cc +++ b/third_party/xla/xla/tsl/profiler/utils/parse_annotation.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/profiler/utils/parse_annotation.h" +#include "xla/tsl/profiler/utils/parse_annotation.h" #include #include diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/parse_annotation.h b/third_party/xla/xla/tsl/profiler/utils/parse_annotation.h similarity index 89% rename from third_party/xla/third_party/tsl/tsl/profiler/utils/parse_annotation.h rename to third_party/xla/xla/tsl/profiler/utils/parse_annotation.h index 1552a2f271140b..8d755f7e64fb4c 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/parse_annotation.h +++ b/third_party/xla/xla/tsl/profiler/utils/parse_annotation.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_PROFILER_UTILS_PARSE_ANNOTATION_H_ -#define TENSORFLOW_TSL_PROFILER_UTILS_PARSE_ANNOTATION_H_ +#ifndef XLA_TSL_PROFILER_UTILS_PARSE_ANNOTATION_H_ +#define XLA_TSL_PROFILER_UTILS_PARSE_ANNOTATION_H_ #include @@ -48,4 +48,4 @@ std::vector ParseAnnotationStack( } // namespace profiler } // namespace tsl -#endif // TENSORFLOW_TSL_PROFILER_UTILS_PARSE_ANNOTATION_H_ +#endif // XLA_TSL_PROFILER_UTILS_PARSE_ANNOTATION_H_ diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/parse_annotation_test.cc b/third_party/xla/xla/tsl/profiler/utils/parse_annotation_test.cc similarity index 99% rename from third_party/xla/third_party/tsl/tsl/profiler/utils/parse_annotation_test.cc rename to third_party/xla/xla/tsl/profiler/utils/parse_annotation_test.cc index 93730c851d4311..6225916ef96cfc 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/parse_annotation_test.cc +++ b/third_party/xla/xla/tsl/profiler/utils/parse_annotation_test.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/profiler/utils/parse_annotation.h" +#include "xla/tsl/profiler/utils/parse_annotation.h" #include diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/per_thread.h b/third_party/xla/xla/tsl/profiler/utils/per_thread.h similarity index 96% rename from third_party/xla/third_party/tsl/tsl/profiler/utils/per_thread.h rename to third_party/xla/xla/tsl/profiler/utils/per_thread.h index 3163fd890c1c7f..f3e9d79242ceab 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/per_thread.h +++ b/third_party/xla/xla/tsl/profiler/utils/per_thread.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_PROFILER_UTILS_PER_THREAD_H_ -#define TENSORFLOW_TSL_PROFILER_UTILS_PER_THREAD_H_ +#ifndef XLA_TSL_PROFILER_UTILS_PER_THREAD_H_ +#define XLA_TSL_PROFILER_UTILS_PER_THREAD_H_ #include #include @@ -143,4 +143,4 @@ class PerThread { } // namespace profiler } // namespace tsl -#endif // TENSORFLOW_TSL_PROFILER_UTILS_PER_THREAD_H_ +#endif // XLA_TSL_PROFILER_UTILS_PER_THREAD_H_ diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/per_thread_test.cc b/third_party/xla/xla/tsl/profiler/utils/per_thread_test.cc similarity index 99% rename from third_party/xla/third_party/tsl/tsl/profiler/utils/per_thread_test.cc rename to third_party/xla/xla/tsl/profiler/utils/per_thread_test.cc index 15af47c3195a47..9007319c4d0c74 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/per_thread_test.cc +++ b/third_party/xla/xla/tsl/profiler/utils/per_thread_test.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/profiler/utils/per_thread.h" +#include "xla/tsl/profiler/utils/per_thread.h" #include #include diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/preprocess_xplane.cc b/third_party/xla/xla/tsl/profiler/utils/preprocess_xplane.cc similarity index 97% rename from third_party/xla/third_party/tsl/tsl/profiler/utils/preprocess_xplane.cc rename to third_party/xla/xla/tsl/profiler/utils/preprocess_xplane.cc index 3d06a05609a118..9925276dacfdde 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/preprocess_xplane.cc +++ b/third_party/xla/xla/tsl/profiler/utils/preprocess_xplane.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/profiler/utils/preprocess_xplane.h" +#include "xla/tsl/profiler/utils/preprocess_xplane.h" #include #include @@ -21,10 +21,10 @@ limitations under the License. #include #include "absl/container/flat_hash_map.h" +#include "xla/tsl/profiler/utils/xplane_builder.h" +#include "xla/tsl/profiler/utils/xplane_schema.h" #include "tsl/profiler/lib/context_types.h" #include "tsl/profiler/protobuf/xplane.pb.h" -#include "tsl/profiler/utils/xplane_builder.h" -#include "tsl/profiler/utils/xplane_schema.h" namespace tsl { namespace profiler { diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/preprocess_xplane.h b/third_party/xla/xla/tsl/profiler/utils/preprocess_xplane.h similarity index 98% rename from third_party/xla/third_party/tsl/tsl/profiler/utils/preprocess_xplane.h rename to third_party/xla/xla/tsl/profiler/utils/preprocess_xplane.h index 9d027c780c9bfb..c64a6d02417e48 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/preprocess_xplane.h +++ b/third_party/xla/xla/tsl/profiler/utils/preprocess_xplane.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_PROFILER_UTILS_PREPROCESS_XPLANE_H_ -#define TENSORFLOW_TSL_PROFILER_UTILS_PREPROCESS_XPLANE_H_ +#ifndef XLA_TSL_PROFILER_UTILS_PREPROCESS_XPLANE_H_ +#define XLA_TSL_PROFILER_UTILS_PREPROCESS_XPLANE_H_ #include #include @@ -31,13 +31,13 @@ limitations under the License. #include "absl/strings/match.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" +#include "xla/tsl/profiler/utils/tpu_xplane_utils.h" +#include "xla/tsl/profiler/utils/trace_utils.h" +#include "xla/tsl/profiler/utils/xplane_builder.h" +#include "xla/tsl/profiler/utils/xplane_mutators.h" +#include "xla/tsl/profiler/utils/xplane_schema.h" #include "tsl/profiler/lib/context_types.h" #include "tsl/profiler/protobuf/xplane.pb.h" -#include "tsl/profiler/utils/tpu_xplane_utils.h" -#include "tsl/profiler/utils/trace_utils.h" -#include "tsl/profiler/utils/xplane_builder.h" -#include "tsl/profiler/utils/xplane_mutators.h" -#include "tsl/profiler/utils/xplane_schema.h" namespace tsl { namespace profiler { @@ -533,4 +533,4 @@ void PreprocessXPlane(XPlane* plane); } // namespace profiler } // namespace tsl -#endif // TENSORFLOW_TSL_PROFILER_UTILS_PREPROCESS_XPLANE_H_ +#endif // XLA_TSL_PROFILER_UTILS_PREPROCESS_XPLANE_H_ diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/preprocess_xplane_test.cc b/third_party/xla/xla/tsl/profiler/utils/preprocess_xplane_test.cc similarity index 97% rename from third_party/xla/third_party/tsl/tsl/profiler/utils/preprocess_xplane_test.cc rename to third_party/xla/xla/tsl/profiler/utils/preprocess_xplane_test.cc index 9712893645090c..d18d6452a6a85d 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/preprocess_xplane_test.cc +++ b/third_party/xla/xla/tsl/profiler/utils/preprocess_xplane_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/profiler/utils/preprocess_xplane.h" +#include "xla/tsl/profiler/utils/preprocess_xplane.h" #include #include @@ -21,14 +21,14 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/hash/hash.h" +#include "xla/tsl/profiler/utils/tf_xplane_visitor.h" +#include "xla/tsl/profiler/utils/xplane_builder.h" +#include "xla/tsl/profiler/utils/xplane_schema.h" +#include "xla/tsl/profiler/utils/xplane_test_utils.h" +#include "xla/tsl/profiler/utils/xplane_visitor.h" #include "tsl/platform/test.h" #include "tsl/profiler/lib/connected_traceme.h" #include "tsl/profiler/protobuf/xplane.pb.h" -#include "tsl/profiler/utils/tf_xplane_visitor.h" -#include "tsl/profiler/utils/xplane_builder.h" -#include "tsl/profiler/utils/xplane_schema.h" -#include "tsl/profiler/utils/xplane_test_utils.h" -#include "tsl/profiler/utils/xplane_visitor.h" namespace tsl { namespace profiler { diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/session_manager.cc b/third_party/xla/xla/tsl/profiler/utils/session_manager.cc similarity index 99% rename from third_party/xla/third_party/tsl/tsl/profiler/utils/session_manager.cc rename to third_party/xla/xla/tsl/profiler/utils/session_manager.cc index 7fd31dab970104..d45b6edd83efba 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/session_manager.cc +++ b/third_party/xla/xla/tsl/profiler/utils/session_manager.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/profiler/utils/session_manager.h" +#include "xla/tsl/profiler/utils/session_manager.h" #include #include diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/session_manager.h b/third_party/xla/xla/tsl/profiler/utils/session_manager.h similarity index 91% rename from third_party/xla/third_party/tsl/tsl/profiler/utils/session_manager.h rename to third_party/xla/xla/tsl/profiler/utils/session_manager.h index 9a6a6300cef51f..fd8c60cbc63d13 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/session_manager.h +++ b/third_party/xla/xla/tsl/profiler/utils/session_manager.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_PROFILER_UTILS_SESSION_MANAGER_H_ -#define TENSORFLOW_TSL_PROFILER_UTILS_SESSION_MANAGER_H_ +#ifndef XLA_TSL_PROFILER_UTILS_SESSION_MANAGER_H_ +#define XLA_TSL_PROFILER_UTILS_SESSION_MANAGER_H_ #include #include @@ -52,4 +52,4 @@ absl::Status ValidateHostPortPair(absl::string_view host_port); } // namespace profiler } // namespace tsl -#endif // TENSORFLOW_TSL_PROFILER_UTILS_SESSION_MANAGER_H_ +#endif // XLA_TSL_PROFILER_UTILS_SESSION_MANAGER_H_ diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/tf_op_utils.cc b/third_party/xla/xla/tsl/profiler/utils/tf_op_utils.cc similarity index 99% rename from third_party/xla/third_party/tsl/tsl/profiler/utils/tf_op_utils.cc rename to third_party/xla/xla/tsl/profiler/utils/tf_op_utils.cc index 4129e2ae8fa7c7..981de3f3141d32 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/tf_op_utils.cc +++ b/third_party/xla/xla/tsl/profiler/utils/tf_op_utils.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/profiler/utils/tf_op_utils.h" +#include "xla/tsl/profiler/utils/tf_op_utils.h" #include #include diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/tf_op_utils.h b/third_party/xla/xla/tsl/profiler/utils/tf_op_utils.h similarity index 97% rename from third_party/xla/third_party/tsl/tsl/profiler/utils/tf_op_utils.h rename to third_party/xla/xla/tsl/profiler/utils/tf_op_utils.h index 85230331cc6e06..078d4d7c3b6f9c 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/tf_op_utils.h +++ b/third_party/xla/xla/tsl/profiler/utils/tf_op_utils.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_PROFILER_UTILS_TF_OP_UTILS_H_ -#define TENSORFLOW_TSL_PROFILER_UTILS_TF_OP_UTILS_H_ +#ifndef XLA_TSL_PROFILER_UTILS_TF_OP_UTILS_H_ +#define XLA_TSL_PROFILER_UTILS_TF_OP_UTILS_H_ #include #include @@ -151,4 +151,4 @@ bool IsJaxOpNameAndType(absl::string_view op_name, absl::string_view op_type); } // namespace profiler } // namespace tsl -#endif // TENSORFLOW_TSL_PROFILER_UTILS_TF_OP_UTILS_H_ +#endif // XLA_TSL_PROFILER_UTILS_TF_OP_UTILS_H_ diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/tf_op_utils_test.cc b/third_party/xla/xla/tsl/profiler/utils/tf_op_utils_test.cc similarity index 99% rename from third_party/xla/third_party/tsl/tsl/profiler/utils/tf_op_utils_test.cc rename to third_party/xla/xla/tsl/profiler/utils/tf_op_utils_test.cc index d03be20b197ca6..aef2bbc686f4d8 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/tf_op_utils_test.cc +++ b/third_party/xla/xla/tsl/profiler/utils/tf_op_utils_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/profiler/utils/tf_op_utils.h" +#include "xla/tsl/profiler/utils/tf_op_utils.h" #include diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/tf_xplane_visitor.h b/third_party/xla/xla/tsl/profiler/utils/tf_xplane_visitor.h similarity index 78% rename from third_party/xla/third_party/tsl/tsl/profiler/utils/tf_xplane_visitor.h rename to third_party/xla/xla/tsl/profiler/utils/tf_xplane_visitor.h index 59dbbcc100fbfc..f902562935743b 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/tf_xplane_visitor.h +++ b/third_party/xla/xla/tsl/profiler/utils/tf_xplane_visitor.h @@ -13,12 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_PROFILER_UTILS_TF_XPLANE_VISITOR_H_ -#define TENSORFLOW_TSL_PROFILER_UTILS_TF_XPLANE_VISITOR_H_ +#ifndef XLA_TSL_PROFILER_UTILS_TF_XPLANE_VISITOR_H_ +#define XLA_TSL_PROFILER_UTILS_TF_XPLANE_VISITOR_H_ +#include "xla/tsl/profiler/utils/xplane_schema.h" +#include "xla/tsl/profiler/utils/xplane_visitor.h" #include "tsl/profiler/protobuf/xplane.pb.h" -#include "tsl/profiler/utils/xplane_schema.h" -#include "tsl/profiler/utils/xplane_visitor.h" namespace tsl { namespace profiler { @@ -32,4 +32,4 @@ inline XPlaneVisitor CreateTfXPlaneVisitor( } // namespace profiler } // namespace tsl -#endif // TENSORFLOW_TSL_PROFILER_UTILS_TF_XPLANE_VISITOR_H_ +#endif // XLA_TSL_PROFILER_UTILS_TF_XPLANE_VISITOR_H_ diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/time_utils.cc b/third_party/xla/xla/tsl/profiler/utils/time_utils.cc similarity index 96% rename from third_party/xla/third_party/tsl/tsl/profiler/utils/time_utils.cc rename to third_party/xla/xla/tsl/profiler/utils/time_utils.cc index 03d9973df0562f..a101ec2335070a 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/time_utils.cc +++ b/third_party/xla/xla/tsl/profiler/utils/time_utils.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/profiler/utils/time_utils.h" +#include "xla/tsl/profiler/utils/time_utils.h" #include "absl/time/clock.h" #include "absl/time/time.h" diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/time_utils.h b/third_party/xla/xla/tsl/profiler/utils/time_utils.h similarity index 87% rename from third_party/xla/third_party/tsl/tsl/profiler/utils/time_utils.h rename to third_party/xla/xla/tsl/profiler/utils/time_utils.h index 3cd30214f49975..65c12c70005b30 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/time_utils.h +++ b/third_party/xla/xla/tsl/profiler/utils/time_utils.h @@ -13,12 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_PROFILER_UTILS_TIME_UTILS_H_ -#define TENSORFLOW_TSL_PROFILER_UTILS_TIME_UTILS_H_ +#ifndef XLA_TSL_PROFILER_UTILS_TIME_UTILS_H_ +#define XLA_TSL_PROFILER_UTILS_TIME_UTILS_H_ #include -#include "tsl/profiler/utils/math_utils.h" +#include "xla/tsl/profiler/utils/math_utils.h" namespace tsl { namespace profiler { @@ -40,4 +40,4 @@ inline void SpinForMicros(int64_t us) { SpinForNanos(us * 1000); } } // namespace profiler } // namespace tsl -#endif // TENSORFLOW_TSL_PROFILER_UTILS_TIME_UTILS_H_ +#endif // XLA_TSL_PROFILER_UTILS_TIME_UTILS_H_ diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/timespan.h b/third_party/xla/xla/tsl/profiler/utils/timespan.h similarity index 96% rename from third_party/xla/third_party/tsl/tsl/profiler/utils/timespan.h rename to third_party/xla/xla/tsl/profiler/utils/timespan.h index ee4c1d646ab3a3..d1883b8566a6ae 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/timespan.h +++ b/third_party/xla/xla/tsl/profiler/utils/timespan.h @@ -13,16 +13,16 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_PROFILER_UTILS_TIMESPAN_H_ -#define TENSORFLOW_TSL_PROFILER_UTILS_TIMESPAN_H_ +#ifndef XLA_TSL_PROFILER_UTILS_TIMESPAN_H_ +#define XLA_TSL_PROFILER_UTILS_TIMESPAN_H_ #include #include #include "absl/strings/str_cat.h" +#include "xla/tsl/profiler/utils/math_utils.h" #include "tsl/platform/logging.h" #include "tsl/platform/types.h" -#include "tsl/profiler/utils/math_utils.h" namespace tsl { namespace profiler { @@ -131,4 +131,4 @@ inline Timespan MilliSpan(double start_ms, double end_ms) { } // namespace profiler } // namespace tsl -#endif // TENSORFLOW_TSL_PROFILER_UTILS_TIMESPAN_H_ +#endif // XLA_TSL_PROFILER_UTILS_TIMESPAN_H_ diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/timespan_test.cc b/third_party/xla/xla/tsl/profiler/utils/timespan_test.cc similarity index 98% rename from third_party/xla/third_party/tsl/tsl/profiler/utils/timespan_test.cc rename to third_party/xla/xla/tsl/profiler/utils/timespan_test.cc index f729f088b14222..57d7876365c904 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/timespan_test.cc +++ b/third_party/xla/xla/tsl/profiler/utils/timespan_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/profiler/utils/timespan.h" +#include "xla/tsl/profiler/utils/timespan.h" #include "tsl/platform/test.h" diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/timestamp_utils.cc b/third_party/xla/xla/tsl/profiler/utils/timestamp_utils.cc similarity index 89% rename from third_party/xla/third_party/tsl/tsl/profiler/utils/timestamp_utils.cc rename to third_party/xla/xla/tsl/profiler/utils/timestamp_utils.cc index ea208ed309c468..17b728f7f9cad1 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/timestamp_utils.cc +++ b/third_party/xla/xla/tsl/profiler/utils/timestamp_utils.cc @@ -13,15 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/profiler/utils/timestamp_utils.h" +#include "xla/tsl/profiler/utils/timestamp_utils.h" #include #include "absl/log/log.h" +#include "xla/tsl/profiler/utils/xplane_builder.h" +#include "xla/tsl/profiler/utils/xplane_schema.h" +#include "xla/tsl/profiler/utils/xplane_utils.h" #include "tsl/profiler/protobuf/xplane.pb.h" -#include "tsl/profiler/utils/xplane_builder.h" -#include "tsl/profiler/utils/xplane_schema.h" -#include "tsl/profiler/utils/xplane_utils.h" namespace tsl { namespace profiler { diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/timestamp_utils.h b/third_party/xla/xla/tsl/profiler/utils/timestamp_utils.h similarity index 86% rename from third_party/xla/third_party/tsl/tsl/profiler/utils/timestamp_utils.h rename to third_party/xla/xla/tsl/profiler/utils/timestamp_utils.h index 87013c97a6f5b0..a2b61672fbabba 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/timestamp_utils.h +++ b/third_party/xla/xla/tsl/profiler/utils/timestamp_utils.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_PROFILER_UTILS_TIMESTAMP_UTILS_H_ -#define TENSORFLOW_TSL_PROFILER_UTILS_TIMESTAMP_UTILS_H_ +#ifndef XLA_TSL_PROFILER_UTILS_TIMESTAMP_UTILS_H_ +#define XLA_TSL_PROFILER_UTILS_TIMESTAMP_UTILS_H_ #include @@ -30,4 +30,4 @@ void SetSessionTimestamps(uint64_t start_walltime_ns, uint64_t stop_walltime_ns, } // namespace profiler } // namespace tsl -#endif // TENSORFLOW_TSL_PROFILER_UTILS_TIMESTAMP_UTILS_H_ +#endif // XLA_TSL_PROFILER_UTILS_TIMESTAMP_UTILS_H_ diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/timestamp_utils_test.cc b/third_party/xla/xla/tsl/profiler/utils/timestamp_utils_test.cc similarity index 86% rename from third_party/xla/third_party/tsl/tsl/profiler/utils/timestamp_utils_test.cc rename to third_party/xla/xla/tsl/profiler/utils/timestamp_utils_test.cc index 893e31ebb5ec59..dd2e434adbc0f3 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/timestamp_utils_test.cc +++ b/third_party/xla/xla/tsl/profiler/utils/timestamp_utils_test.cc @@ -13,12 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/profiler/utils/timestamp_utils.h" +#include "xla/tsl/profiler/utils/timestamp_utils.h" +#include "xla/tsl/profiler/utils/xplane_schema.h" +#include "xla/tsl/profiler/utils/xplane_utils.h" +#include "xla/tsl/profiler/utils/xplane_visitor.h" #include "tsl/platform/test.h" -#include "tsl/profiler/utils/xplane_schema.h" -#include "tsl/profiler/utils/xplane_utils.h" -#include "tsl/profiler/utils/xplane_visitor.h" namespace tsl { namespace profiler { diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/tpu_xplane_utils.cc b/third_party/xla/xla/tsl/profiler/utils/tpu_xplane_utils.cc similarity index 92% rename from third_party/xla/third_party/tsl/tsl/profiler/utils/tpu_xplane_utils.cc rename to third_party/xla/xla/tsl/profiler/utils/tpu_xplane_utils.cc index 9274a1da941743..d456164be18a46 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/tpu_xplane_utils.cc +++ b/third_party/xla/xla/tsl/profiler/utils/tpu_xplane_utils.cc @@ -12,16 +12,16 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/profiler/utils/tpu_xplane_utils.h" +#include "xla/tsl/profiler/utils/tpu_xplane_utils.h" #include #include #include "absl/strings/string_view.h" +#include "xla/tsl/profiler/utils/xplane_schema.h" +#include "xla/tsl/profiler/utils/xplane_utils.h" #include "tsl/platform/regexp.h" #include "tsl/profiler/protobuf/xplane.pb.h" -#include "tsl/profiler/utils/xplane_schema.h" -#include "tsl/profiler/utils/xplane_utils.h" namespace tsl { namespace profiler { diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/tpu_xplane_utils.h b/third_party/xla/xla/tsl/profiler/utils/tpu_xplane_utils.h similarity index 89% rename from third_party/xla/third_party/tsl/tsl/profiler/utils/tpu_xplane_utils.h rename to third_party/xla/xla/tsl/profiler/utils/tpu_xplane_utils.h index 2fb7c677e3a058..3f6adb498cd270 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/tpu_xplane_utils.h +++ b/third_party/xla/xla/tsl/profiler/utils/tpu_xplane_utils.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_PROFILER_UTILS_TPU_XPLANE_UTILS_H_ -#define TENSORFLOW_TSL_PROFILER_UTILS_TPU_XPLANE_UTILS_H_ +#ifndef XLA_TSL_PROFILER_UTILS_TPU_XPLANE_UTILS_H_ +#define XLA_TSL_PROFILER_UTILS_TPU_XPLANE_UTILS_H_ #include #include @@ -43,4 +43,4 @@ std::optional GetSparseCoreId(absl::string_view plane_name); } // namespace profiler } // namespace tsl -#endif // TENSORFLOW_TSL_PROFILER_UTILS_TPU_XPLANE_UTILS_H_ +#endif // XLA_TSL_PROFILER_UTILS_TPU_XPLANE_UTILS_H_ diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/tpu_xplane_utils_test.cc b/third_party/xla/xla/tsl/profiler/utils/tpu_xplane_utils_test.cc similarity index 93% rename from third_party/xla/third_party/tsl/tsl/profiler/utils/tpu_xplane_utils_test.cc rename to third_party/xla/xla/tsl/profiler/utils/tpu_xplane_utils_test.cc index e5bcd73c339be9..fc341c98582cc9 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/tpu_xplane_utils_test.cc +++ b/third_party/xla/xla/tsl/profiler/utils/tpu_xplane_utils_test.cc @@ -12,16 +12,16 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/profiler/utils/tpu_xplane_utils.h" +#include "xla/tsl/profiler/utils/tpu_xplane_utils.h" #include #include "absl/strings/str_cat.h" +#include "xla/tsl/profiler/utils/xplane_schema.h" +#include "xla/tsl/profiler/utils/xplane_utils.h" +#include "xla/tsl/profiler/utils/xplane_visitor.h" #include "tsl/platform/test.h" #include "tsl/profiler/protobuf/xplane.pb.h" -#include "tsl/profiler/utils/xplane_schema.h" -#include "tsl/profiler/utils/xplane_utils.h" -#include "tsl/profiler/utils/xplane_visitor.h" namespace tsl { namespace profiler { diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/trace_utils.h b/third_party/xla/xla/tsl/profiler/utils/trace_utils.h similarity index 95% rename from third_party/xla/third_party/tsl/tsl/profiler/utils/trace_utils.h rename to third_party/xla/xla/tsl/profiler/utils/trace_utils.h index 98e18973a5ea79..ef53e611ab95fa 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/trace_utils.h +++ b/third_party/xla/xla/tsl/profiler/utils/trace_utils.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_PROFILER_UTILS_TRACE_UTILS_H_ -#define TENSORFLOW_TSL_PROFILER_UTILS_TRACE_UTILS_H_ +#ifndef XLA_TSL_PROFILER_UTILS_TRACE_UTILS_H_ +#define XLA_TSL_PROFILER_UTILS_TRACE_UTILS_H_ #include @@ -81,4 +81,4 @@ static inline std::optional ParseDeviceOrdinal( } // namespace profiler } // namespace tsl -#endif // TENSORFLOW_TSL_PROFILER_UTILS_TRACE_UTILS_H_ +#endif // XLA_TSL_PROFILER_UTILS_TRACE_UTILS_H_ diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_builder.cc b/third_party/xla/xla/tsl/profiler/utils/xplane_builder.cc similarity index 97% rename from third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_builder.cc rename to third_party/xla/xla/tsl/profiler/utils/xplane_builder.cc index 6def9d1c768d47..ebcc5884ffe3e7 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_builder.cc +++ b/third_party/xla/xla/tsl/profiler/utils/xplane_builder.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/profiler/utils/xplane_builder.h" +#include "xla/tsl/profiler/utils/xplane_builder.h" #include #include @@ -22,10 +22,10 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" +#include "xla/tsl/profiler/utils/math_utils.h" +#include "xla/tsl/profiler/utils/timespan.h" #include "tsl/platform/types.h" #include "tsl/profiler/protobuf/xplane.pb.h" -#include "tsl/profiler/utils/math_utils.h" -#include "tsl/profiler/utils/timespan.h" namespace tsl { namespace profiler { diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_builder.h b/third_party/xla/xla/tsl/profiler/utils/xplane_builder.h similarity index 98% rename from third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_builder.h rename to third_party/xla/xla/tsl/profiler/utils/xplane_builder.h index 4a837c86ac00aa..522b34612d48bc 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_builder.h +++ b/third_party/xla/xla/tsl/profiler/utils/xplane_builder.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_PROFILER_UTILS_XPLANE_BUILDER_H_ -#define TENSORFLOW_TSL_PROFILER_UTILS_XPLANE_BUILDER_H_ +#ifndef XLA_TSL_PROFILER_UTILS_XPLANE_BUILDER_H_ +#define XLA_TSL_PROFILER_UTILS_XPLANE_BUILDER_H_ #include @@ -27,12 +27,12 @@ limitations under the License. #include "absl/strings/numbers.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" +#include "xla/tsl/profiler/utils/math_utils.h" +#include "xla/tsl/profiler/utils/timespan.h" #include "tsl/platform/macros.h" #include "tsl/platform/protobuf.h" #include "tsl/platform/types.h" #include "tsl/profiler/protobuf/xplane.pb.h" -#include "tsl/profiler/utils/math_utils.h" -#include "tsl/profiler/utils/timespan.h" namespace tsl { namespace profiler { @@ -454,4 +454,4 @@ absl::string_view XStatsBuilder::StrOrRefValue(const XStat& stat) { } // namespace profiler } // namespace tsl -#endif // TENSORFLOW_TSL_PROFILER_UTILS_XPLANE_BUILDER_H_ +#endif // XLA_TSL_PROFILER_UTILS_XPLANE_BUILDER_H_ diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_builder_test.cc b/third_party/xla/xla/tsl/profiler/utils/xplane_builder_test.cc similarity index 97% rename from third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_builder_test.cc rename to third_party/xla/xla/tsl/profiler/utils/xplane_builder_test.cc index 87c433773472b6..ee2c8e4df0400b 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_builder_test.cc +++ b/third_party/xla/xla/tsl/profiler/utils/xplane_builder_test.cc @@ -12,14 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/profiler/utils/xplane_builder.h" +#include "xla/tsl/profiler/utils/xplane_builder.h" #include #include "absl/strings/string_view.h" +#include "xla/tsl/profiler/utils/xplane_visitor.h" #include "tsl/platform/test.h" #include "tsl/profiler/protobuf/xplane.pb.h" -#include "tsl/profiler/utils/xplane_visitor.h" namespace tsl { namespace profiler { diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_mutators.h b/third_party/xla/xla/tsl/profiler/utils/xplane_mutators.h similarity index 88% rename from third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_mutators.h rename to third_party/xla/xla/tsl/profiler/utils/xplane_mutators.h index e558936b2130e6..0873aee9a41f3a 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_mutators.h +++ b/third_party/xla/xla/tsl/profiler/utils/xplane_mutators.h @@ -13,13 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_PROFILER_UTILS_XPLANE_MUTATORS_H_ -#define TENSORFLOW_TSL_PROFILER_UTILS_XPLANE_MUTATORS_H_ +#ifndef XLA_TSL_PROFILER_UTILS_XPLANE_MUTATORS_H_ +#define XLA_TSL_PROFILER_UTILS_XPLANE_MUTATORS_H_ #include #include -#include "tsl/profiler/utils/xplane_builder.h" +#include "xla/tsl/profiler/utils/xplane_builder.h" namespace tsl { namespace profiler { @@ -60,4 +60,4 @@ class XplaneEventMutatorFactory { } // namespace profiler } // namespace tsl -#endif // TENSORFLOW_TSL_PROFILER_UTILS_XPLANE_MUTATORS_H_ +#endif // XLA_TSL_PROFILER_UTILS_XPLANE_MUTATORS_H_ diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_schema.cc b/third_party/xla/xla/tsl/profiler/utils/xplane_schema.cc similarity index 99% rename from third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_schema.cc rename to third_party/xla/xla/tsl/profiler/utils/xplane_schema.cc index 81e370f3cc0c7f..d1ea0faf889d07 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_schema.cc +++ b/third_party/xla/xla/tsl/profiler/utils/xplane_schema.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/profiler/utils/xplane_schema.h" +#include "xla/tsl/profiler/utils/xplane_schema.h" #include #include @@ -23,7 +23,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "xla/tsl/lib/gtl/map_util.h" -#include "tsl/profiler/utils/tf_op_utils.h" +#include "xla/tsl/profiler/utils/tf_op_utils.h" namespace tsl { namespace profiler { diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_schema.h b/third_party/xla/xla/tsl/profiler/utils/xplane_schema.h similarity index 99% rename from third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_schema.h rename to third_party/xla/xla/tsl/profiler/utils/xplane_schema.h index 0d51de4d2905a5..d61b59b7464143 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_schema.h +++ b/third_party/xla/xla/tsl/profiler/utils/xplane_schema.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_PROFILER_UTILS_XPLANE_SCHEMA_H_ -#define TENSORFLOW_TSL_PROFILER_UTILS_XPLANE_SCHEMA_H_ +#ifndef XLA_TSL_PROFILER_UTILS_XPLANE_SCHEMA_H_ +#define XLA_TSL_PROFILER_UTILS_XPLANE_SCHEMA_H_ #include #include @@ -536,4 +536,4 @@ TF_CONST_INIT extern const absl::string_view kThreadpoolListenerRegion; } // namespace profiler } // namespace tsl -#endif // TENSORFLOW_TSL_PROFILER_UTILS_XPLANE_SCHEMA_H_ +#endif // XLA_TSL_PROFILER_UTILS_XPLANE_SCHEMA_H_ diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_test_utils.cc b/third_party/xla/xla/tsl/profiler/utils/xplane_test_utils.cc similarity index 95% rename from third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_test_utils.cc rename to third_party/xla/xla/tsl/profiler/utils/xplane_test_utils.cc index f80e7e15bae4a7..548b444b912263 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_test_utils.cc +++ b/third_party/xla/xla/tsl/profiler/utils/xplane_test_utils.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/profiler/utils/xplane_test_utils.h" +#include "xla/tsl/profiler/utils/xplane_test_utils.h" #include #include @@ -20,11 +20,11 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/strings/string_view.h" +#include "xla/tsl/profiler/utils/xplane_builder.h" +#include "xla/tsl/profiler/utils/xplane_schema.h" +#include "xla/tsl/profiler/utils/xplane_utils.h" #include "tsl/platform/types.h" #include "tsl/profiler/protobuf/xplane.pb.h" -#include "tsl/profiler/utils/xplane_builder.h" -#include "tsl/profiler/utils/xplane_schema.h" -#include "tsl/profiler/utils/xplane_utils.h" namespace tsl { namespace profiler { diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_test_utils.h b/third_party/xla/xla/tsl/profiler/utils/xplane_test_utils.h similarity index 88% rename from third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_test_utils.h rename to third_party/xla/xla/tsl/profiler/utils/xplane_test_utils.h index 4568feed8d0439..b2e5e58494c67a 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_test_utils.h +++ b/third_party/xla/xla/tsl/profiler/utils/xplane_test_utils.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_PROFILER_UTILS_XPLANE_TEST_UTILS_H_ -#define TENSORFLOW_TSL_PROFILER_UTILS_XPLANE_TEST_UTILS_H_ +#ifndef XLA_TSL_PROFILER_UTILS_XPLANE_TEST_UTILS_H_ +#define XLA_TSL_PROFILER_UTILS_XPLANE_TEST_UTILS_H_ #include #include @@ -21,9 +21,9 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/types/variant.h" +#include "xla/tsl/profiler/utils/xplane_builder.h" +#include "xla/tsl/profiler/utils/xplane_schema.h" #include "tsl/platform/types.h" -#include "tsl/profiler/utils/xplane_builder.h" -#include "tsl/profiler/utils/xplane_schema.h" namespace tsl { namespace profiler { @@ -58,4 +58,4 @@ void CreateTfFunctionCallEvent(XPlaneBuilder* plane_builder, } // namespace profiler } // namespace tsl -#endif // TENSORFLOW_TSL_PROFILER_UTILS_XPLANE_TEST_UTILS_H_ +#endif // XLA_TSL_PROFILER_UTILS_XPLANE_TEST_UTILS_H_ diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_utils.cc b/third_party/xla/xla/tsl/profiler/utils/xplane_utils.cc similarity index 98% rename from third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_utils.cc rename to third_party/xla/xla/tsl/profiler/utils/xplane_utils.cc index a327aab2d92c08..1efe3be09ce5a5 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_utils.cc +++ b/third_party/xla/xla/tsl/profiler/utils/xplane_utils.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/profiler/utils/xplane_utils.h" +#include "xla/tsl/profiler/utils/xplane_utils.h" #include #include @@ -28,17 +28,17 @@ limitations under the License. #include "absl/log/log.h" #include "absl/strings/match.h" #include "absl/strings/string_view.h" +#include "xla/tsl/profiler/utils/math_utils.h" +#include "xla/tsl/profiler/utils/tf_xplane_visitor.h" +#include "xla/tsl/profiler/utils/timespan.h" +#include "xla/tsl/profiler/utils/xplane_builder.h" +#include "xla/tsl/profiler/utils/xplane_schema.h" +#include "xla/tsl/profiler/utils/xplane_visitor.h" #include "xla/tsl/util/stats_calculator.h" #include "tsl/platform/fingerprint.h" #include "tsl/platform/types.h" #include "tsl/profiler/lib/context_types.h" #include "tsl/profiler/protobuf/xplane.pb.h" -#include "tsl/profiler/utils/math_utils.h" -#include "tsl/profiler/utils/tf_xplane_visitor.h" -#include "tsl/profiler/utils/timespan.h" -#include "tsl/profiler/utils/xplane_builder.h" -#include "tsl/profiler/utils/xplane_schema.h" -#include "tsl/profiler/utils/xplane_visitor.h" namespace tsl { namespace profiler { diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_utils.h b/third_party/xla/xla/tsl/profiler/utils/xplane_utils.h similarity index 96% rename from third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_utils.h rename to third_party/xla/xla/tsl/profiler/utils/xplane_utils.h index 8ea1429c1d90d2..7992d795ea0318 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_utils.h +++ b/third_party/xla/xla/tsl/profiler/utils/xplane_utils.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_PROFILER_UTILS_XPLANE_UTILS_H_ -#define TENSORFLOW_TSL_PROFILER_UTILS_XPLANE_UTILS_H_ +#ifndef XLA_TSL_PROFILER_UTILS_XPLANE_UTILS_H_ +#define XLA_TSL_PROFILER_UTILS_XPLANE_UTILS_H_ #include #include @@ -24,11 +24,11 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/strings/string_view.h" +#include "xla/tsl/profiler/utils/timespan.h" +#include "xla/tsl/profiler/utils/trace_utils.h" +#include "xla/tsl/profiler/utils/xplane_visitor.h" #include "tsl/platform/types.h" #include "tsl/profiler/protobuf/xplane.pb.h" -#include "tsl/profiler/utils/timespan.h" -#include "tsl/profiler/utils/trace_utils.h" -#include "tsl/profiler/utils/xplane_visitor.h" namespace tsl { namespace profiler { @@ -222,4 +222,4 @@ bool IsDevicePlane(const XPlane& plane); } // namespace profiler } // namespace tsl -#endif // TENSORFLOW_TSL_PROFILER_UTILS_XPLANE_UTILS_H_ +#endif // XLA_TSL_PROFILER_UTILS_XPLANE_UTILS_H_ diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_utils_test.cc b/third_party/xla/xla/tsl/profiler/utils/xplane_utils_test.cc similarity index 98% rename from third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_utils_test.cc rename to third_party/xla/xla/tsl/profiler/utils/xplane_utils_test.cc index bdb5a3c4da4b03..cc350f83bccd24 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_utils_test.cc +++ b/third_party/xla/xla/tsl/profiler/utils/xplane_utils_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/profiler/utils/xplane_utils.h" +#include "xla/tsl/profiler/utils/xplane_utils.h" #include #include @@ -24,14 +24,14 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" +#include "xla/tsl/profiler/utils/math_utils.h" +#include "xla/tsl/profiler/utils/tf_xplane_visitor.h" +#include "xla/tsl/profiler/utils/xplane_builder.h" +#include "xla/tsl/profiler/utils/xplane_schema.h" +#include "xla/tsl/profiler/utils/xplane_visitor.h" #include "tsl/platform/test.h" #include "tsl/platform/types.h" #include "tsl/profiler/protobuf/xplane.pb.h" -#include "tsl/profiler/utils/math_utils.h" -#include "tsl/profiler/utils/tf_xplane_visitor.h" -#include "tsl/profiler/utils/xplane_builder.h" -#include "tsl/profiler/utils/xplane_schema.h" -#include "tsl/profiler/utils/xplane_visitor.h" namespace tsl { namespace profiler { diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_visitor.cc b/third_party/xla/xla/tsl/profiler/utils/xplane_visitor.cc similarity index 99% rename from third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_visitor.cc rename to third_party/xla/xla/tsl/profiler/utils/xplane_visitor.cc index 7d0a723221ffc5..b7bfad3f7211eb 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_visitor.cc +++ b/third_party/xla/xla/tsl/profiler/utils/xplane_visitor.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/profiler/utils/xplane_visitor.h" +#include "xla/tsl/profiler/utils/xplane_visitor.h" #include #include diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_visitor.h b/third_party/xla/xla/tsl/profiler/utils/xplane_visitor.h similarity index 98% rename from third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_visitor.h rename to third_party/xla/xla/tsl/profiler/utils/xplane_visitor.h index eb0eb01a6b17c2..a9c8510355cde2 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_visitor.h +++ b/third_party/xla/xla/tsl/profiler/utils/xplane_visitor.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_PROFILER_UTILS_XPLANE_VISITOR_H_ -#define TENSORFLOW_TSL_PROFILER_UTILS_XPLANE_VISITOR_H_ +#ifndef XLA_TSL_PROFILER_UTILS_XPLANE_VISITOR_H_ +#define XLA_TSL_PROFILER_UTILS_XPLANE_VISITOR_H_ #include @@ -25,9 +25,9 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" +#include "xla/tsl/profiler/utils/timespan.h" #include "tsl/platform/types.h" #include "tsl/profiler/protobuf/xplane.pb.h" -#include "tsl/profiler/utils/timespan.h" namespace tsl { namespace profiler { @@ -359,4 +359,4 @@ void XEventMetadataVisitor::ForEachChild( } // namespace profiler } // namespace tsl -#endif // TENSORFLOW_TSL_PROFILER_UTILS_XPLANE_VISITOR_H_ +#endif // XLA_TSL_PROFILER_UTILS_XPLANE_VISITOR_H_ diff --git a/third_party/xla/xla/xla.bzl b/third_party/xla/xla/xla.bzl index 43f002ab499ee1..c1193cb53fe951 100644 --- a/third_party/xla/xla/xla.bzl +++ b/third_party/xla/xla/xla.bzl @@ -53,7 +53,7 @@ _XLA_SHARED_OBJECT_SENSITIVE_DEPS = if_static(extra_deps = [], otherwise = [ "//xla/tsl/profiler/backends/cpu:traceme_recorder_impl", "@local_tsl//tsl/profiler/protobuf:profiler_options_proto_cc_impl", "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc_impl", - "@local_tsl//tsl/profiler/utils:time_utils_impl", + "//xla/tsl/profiler/utils:time_utils_impl", "//xla/tsl/protobuf:protos_all_cc_impl", ]) + if_cuda_is_configured([ Label("//xla/stream_executor/cuda:all_runtime"), From aac1b267312c160ed98ba9d30c5e68b6231c8fec Mon Sep 17 00:00:00 2001 From: Joshua Lang Date: Mon, 30 Sep 2024 16:56:53 -0700 Subject: [PATCH 449/483] Rename cuda_libdevice_path to cuda_root_path PiperOrigin-RevId: 680768657 --- .../compiler/mlir/tools/kernel_gen/transforms/BUILD | 2 +- .../transforms/gpu_kernel_to_blob_pass.cc | 1 - tensorflow/core/platform/build_config.bzl | 4 ++-- third_party/xla/third_party/tsl/tsl/platform/BUILD | 10 +++++----- .../third_party/tsl/tsl/platform/build_config.bzl | 4 ++-- .../{cuda_libdevice_path.h => cuda_root_path.h} | 6 +++--- .../xla/third_party/tsl/tsl/platform/default/BUILD | 6 +++--- .../tsl/tsl/platform/default/build_config.bzl | 4 ++-- .../{cuda_libdevice_path.cc => cuda_root_path.cc} | 2 +- .../xla/xla/service/gpu/llvm_gpu_backend/BUILD | 13 +++++++++++-- .../service/gpu/llvm_gpu_backend/gpu_backend_lib.cc | 2 +- third_party/xla/xla/stream_executor/cuda/BUILD | 2 +- .../xla/stream_executor/cuda/cuda_asm_compiler.cc | 2 +- 13 files changed, 33 insertions(+), 25 deletions(-) rename third_party/xla/third_party/tsl/tsl/platform/{cuda_libdevice_path.h => cuda_root_path.h} (90%) rename third_party/xla/third_party/tsl/tsl/platform/default/{cuda_libdevice_path.cc => cuda_root_path.cc} (98%) diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD index 489e13d172c059..dee541c450dabd 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD @@ -167,7 +167,7 @@ cc_library( "@local_xla//xla/service/gpu:target_constants", "@local_xla//xla/service/gpu/llvm_gpu_backend", ] + if_cuda_is_configured([ - "@local_tsl//tsl/platform:cuda_libdevice_path", + "@local_tsl//tsl/platform:cuda_root_path", "@local_xla//xla/stream_executor/cuda:cuda_asm_compiler", ]) + if_rocm_is_configured([ "@local_xla//xla/stream_executor/gpu:asm_compiler", diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/gpu_kernel_to_blob_pass.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/gpu_kernel_to_blob_pass.cc index e4b4c6cd84dc96..d5b0ce09538dbc 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/gpu_kernel_to_blob_pass.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/gpu_kernel_to_blob_pass.cc @@ -36,7 +36,6 @@ limitations under the License. #include "tensorflow/core/platform/path.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/statusor.h" -#include "tsl/platform/cuda_libdevice_path.h" #if GOOGLE_CUDA #include "xla/stream_executor/cuda/cuda_asm_compiler.h" diff --git a/tensorflow/core/platform/build_config.bzl b/tensorflow/core/platform/build_config.bzl index de0453b6deac98..dd10f841a5235c 100644 --- a/tensorflow/core/platform/build_config.bzl +++ b/tensorflow/core/platform/build_config.bzl @@ -11,7 +11,7 @@ load( _tf_additional_rpc_deps = "tf_additional_rpc_deps", _tf_additional_tensor_coding_deps = "tf_additional_tensor_coding_deps", _tf_additional_test_deps = "tf_additional_test_deps", - _tf_cuda_libdevice_path_deps = "tf_cuda_libdevice_path_deps", + _tf_cuda_root_path_deps = "tf_cuda_root_path_deps", _tf_fingerprint_deps = "tf_fingerprint_deps", _tf_google_mobile_srcs_no_runtime = "tf_google_mobile_srcs_no_runtime", _tf_google_mobile_srcs_only_runtime = "tf_google_mobile_srcs_only_runtime", @@ -53,7 +53,7 @@ tf_additional_lib_hdrs = _tf_additional_lib_hdrs tf_additional_rpc_deps = _tf_additional_rpc_deps tf_additional_tensor_coding_deps = _tf_additional_tensor_coding_deps tf_additional_test_deps = _tf_additional_test_deps -tf_cuda_libdevice_path_deps = _tf_cuda_libdevice_path_deps +tf_cuda_root_path_deps = _tf_cuda_root_path_deps tf_fingerprint_deps = _tf_fingerprint_deps tf_google_mobile_srcs_no_runtime = _tf_google_mobile_srcs_no_runtime tf_google_mobile_srcs_only_runtime = _tf_google_mobile_srcs_only_runtime diff --git a/third_party/xla/third_party/tsl/tsl/platform/BUILD b/third_party/xla/third_party/tsl/tsl/platform/BUILD index f7d995e3e4065f..d9c4c7baca7d4a 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/BUILD +++ b/third_party/xla/third_party/tsl/tsl/platform/BUILD @@ -21,7 +21,7 @@ load( load("@local_xla//xla/tsl:tsl.default.bzl", "get_compatible_with_portable") load( "//tsl/platform:build_config.bzl", - "tf_cuda_libdevice_path_deps", + "tf_cuda_root_path_deps", "tf_error_logging_deps", "tf_fingerprint_deps", "tf_google_mobile_srcs_no_runtime", @@ -675,7 +675,7 @@ exports_files( "cpu_info.h", "crash_analysis.h", "criticality.h", - "cuda_libdevice_path.h", + "cuda_root_path.h", "demangle.h", "env.cc", "env.h", @@ -1197,10 +1197,10 @@ tsl_cc_test( ) cc_library( - name = "cuda_libdevice_path", + name = "cuda_root_path", compatible_with = get_compatible_with_portable(), - textual_hdrs = ["cuda_libdevice_path.h"], - deps = tf_cuda_libdevice_path_deps(), + textual_hdrs = ["cuda_root_path.h"], + deps = tf_cuda_root_path_deps(), ) cc_library( diff --git a/third_party/xla/third_party/tsl/tsl/platform/build_config.bzl b/third_party/xla/third_party/tsl/tsl/platform/build_config.bzl index bec0e8403b2488..4a22f84baf1493 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/build_config.bzl +++ b/third_party/xla/third_party/tsl/tsl/platform/build_config.bzl @@ -11,7 +11,7 @@ load( _tf_additional_rpc_deps = "tf_additional_rpc_deps", _tf_additional_tensor_coding_deps = "tf_additional_tensor_coding_deps", _tf_additional_test_deps = "tf_additional_test_deps", - _tf_cuda_libdevice_path_deps = "tf_cuda_libdevice_path_deps", + _tf_cuda_root_path_deps = "tf_cuda_root_path_deps", _tf_error_logging_deps = "tf_error_logging_deps", _tf_fingerprint_deps = "tf_fingerprint_deps", _tf_google_mobile_srcs_no_runtime = "tf_google_mobile_srcs_no_runtime", @@ -49,7 +49,7 @@ tf_additional_lib_hdrs = _tf_additional_lib_hdrs tf_additional_rpc_deps = _tf_additional_rpc_deps tf_additional_tensor_coding_deps = _tf_additional_tensor_coding_deps tf_additional_test_deps = _tf_additional_test_deps -tf_cuda_libdevice_path_deps = _tf_cuda_libdevice_path_deps +tf_cuda_root_path_deps = _tf_cuda_root_path_deps tf_error_logging_deps = _tf_error_logging_deps tf_fingerprint_deps = _tf_fingerprint_deps tf_google_mobile_srcs_no_runtime = _tf_google_mobile_srcs_no_runtime diff --git a/third_party/xla/third_party/tsl/tsl/platform/cuda_libdevice_path.h b/third_party/xla/third_party/tsl/tsl/platform/cuda_root_path.h similarity index 90% rename from third_party/xla/third_party/tsl/tsl/platform/cuda_libdevice_path.h rename to third_party/xla/third_party/tsl/tsl/platform/cuda_root_path.h index d8c2b6d01daf43..65a9ca5a7acb0c 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/cuda_libdevice_path.h +++ b/third_party/xla/third_party/tsl/tsl/platform/cuda_root_path.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_PLATFORM_CUDA_LIBDEVICE_PATH_H_ -#define TENSORFLOW_TSL_PLATFORM_CUDA_LIBDEVICE_PATH_H_ +#ifndef TENSORFLOW_TSL_PLATFORM_CUDA_ROOT_PATH_H_ +#define TENSORFLOW_TSL_PLATFORM_CUDA_ROOT_PATH_H_ #include #include @@ -46,4 +46,4 @@ bool PreferPtxasFromPath(); } // namespace tsl -#endif // TENSORFLOW_TSL_PLATFORM_CUDA_LIBDEVICE_PATH_H_ +#endif // TENSORFLOW_TSL_PLATFORM_CUDA_ROOT_PATH_H_ diff --git a/third_party/xla/third_party/tsl/tsl/platform/default/BUILD b/third_party/xla/third_party/tsl/tsl/platform/default/BUILD index 785a44cfff0e7c..c824290dd85b7b 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/default/BUILD +++ b/third_party/xla/third_party/tsl/tsl/platform/default/BUILD @@ -56,9 +56,9 @@ cc_library( ) cc_library( - name = "cuda_libdevice_path", - srcs = ["cuda_libdevice_path.cc"], - hdrs = ["//tsl/platform:cuda_libdevice_path.h"], + name = "cuda_root_path", + srcs = ["cuda_root_path.cc"], + hdrs = ["//tsl/platform:cuda_root_path.h"], compatible_with = [], data = if_cuda_tools([ "@cuda_nvcc//:nvvm", diff --git a/third_party/xla/third_party/tsl/tsl/platform/default/build_config.bzl b/third_party/xla/third_party/tsl/tsl/platform/default/build_config.bzl index 32b91de84c7e93..397f216698c27b 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/default/build_config.bzl +++ b/third_party/xla/third_party/tsl/tsl/platform/default/build_config.bzl @@ -845,5 +845,5 @@ def tf_google_mobile_srcs_no_runtime(): def tf_google_mobile_srcs_only_runtime(): return [] -def tf_cuda_libdevice_path_deps(): - return tf_platform_deps("cuda_libdevice_path") +def tf_cuda_root_path_deps(): + return tf_platform_deps("cuda_root_path") diff --git a/third_party/xla/third_party/tsl/tsl/platform/default/cuda_libdevice_path.cc b/third_party/xla/third_party/tsl/tsl/platform/default/cuda_root_path.cc similarity index 98% rename from third_party/xla/third_party/tsl/tsl/platform/default/cuda_libdevice_path.cc rename to third_party/xla/third_party/tsl/tsl/platform/default/cuda_root_path.cc index ac0a804b4dfd42..ca6da0e5532eaa 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/default/cuda_libdevice_path.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/default/cuda_root_path.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/platform/cuda_libdevice_path.h" +#include "tsl/platform/cuda_root_path.h" #include diff --git a/third_party/xla/xla/service/gpu/llvm_gpu_backend/BUILD b/third_party/xla/xla/service/gpu/llvm_gpu_backend/BUILD index b1f957b0a4a812..5dd3f329182e22 100644 --- a/third_party/xla/xla/service/gpu/llvm_gpu_backend/BUILD +++ b/third_party/xla/xla/service/gpu/llvm_gpu_backend/BUILD @@ -11,7 +11,11 @@ load( "if_cuda_is_configured", ) load("//xla:xla.bzl", "xla_cc_test") -load("//xla/tsl:tsl.bzl", "internal_visibility") +load( + "//xla/tsl:tsl.bzl", + "if_google", + "internal_visibility", +) package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -36,6 +40,11 @@ cc_library( "gpu_backend_lib.h", "utils.h", ], + data = if_cuda_is_configured( + if_google( + ["@local_config_cuda//cuda:runtime_libdevice"], + ), + ), local_defines = if_cuda_is_configured([ "GOOGLE_CUDA=1", ]) + if_rocm_is_configured(["TENSORFLOW_USE_ROCM=1"]), @@ -76,7 +85,7 @@ cc_library( "@llvm-project//llvm:Target", "@llvm-project//mlir:NVVMDialect", "@local_config_cuda//cuda:cuda_headers", - "@local_tsl//tsl/platform:cuda_libdevice_path", + "@local_tsl//tsl/platform:cuda_root_path", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", diff --git a/third_party/xla/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc b/third_party/xla/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc index bfa3d41fed3a4d..402a0957ef1e8a 100644 --- a/third_party/xla/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc +++ b/third_party/xla/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc @@ -86,7 +86,7 @@ limitations under the License. #include "xla/tsl/util/env_var.h" #include "xla/util.h" #include "xla/xla.pb.h" -#include "tsl/platform/cuda_libdevice_path.h" +#include "tsl/platform/cuda_root_path.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" diff --git a/third_party/xla/xla/stream_executor/cuda/BUILD b/third_party/xla/xla/stream_executor/cuda/BUILD index 35d33e03af429c..bb96c2ab09824f 100644 --- a/third_party/xla/xla/stream_executor/cuda/BUILD +++ b/third_party/xla/xla/stream_executor/cuda/BUILD @@ -937,7 +937,7 @@ cc_library( "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", "@local_config_cuda//cuda:cuda_headers", - "@local_tsl//tsl/platform:cuda_libdevice_path", + "@local_tsl//tsl/platform:cuda_root_path", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:path", diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_asm_compiler.cc b/third_party/xla/xla/stream_executor/cuda/cuda_asm_compiler.cc index d788a8dd077fe6..c0be91811703b2 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_asm_compiler.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_asm_compiler.cc @@ -57,7 +57,7 @@ limitations under the License. #include "xla/stream_executor/semantic_version.h" #include "xla/stream_executor/stream_executor.h" #include "xla/util.h" -#include "tsl/platform/cuda_libdevice_path.h" +#include "tsl/platform/cuda_root_path.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" #include "tsl/platform/path.h" From 2846a34d64ddf8aaacac25c7ea3bfbf355d2ead6 Mon Sep 17 00:00:00 2001 From: Arian Arfaian Date: Mon, 30 Sep 2024 17:04:40 -0700 Subject: [PATCH 450/483] Add pattern to match GELU with tf.Erfc implementation. PiperOrigin-RevId: 680771213 --- .../compiler/mlir/lite/tests/optimize.mlir | 14 ++++++++++++++ .../mlir/lite/transforms/optimize_patterns.td | 17 +++++++++++++++++ 2 files changed, 31 insertions(+) diff --git a/tensorflow/compiler/mlir/lite/tests/optimize.mlir b/tensorflow/compiler/mlir/lite/tests/optimize.mlir index d70b18c9fa6036..284889ab0b4de4 100644 --- a/tensorflow/compiler/mlir/lite/tests/optimize.mlir +++ b/tensorflow/compiler/mlir/lite/tests/optimize.mlir @@ -3636,6 +3636,20 @@ func.func @gelu(%arg0: tensor<3xf32>) -> tensor<3xf32> { // CHECK: "tfl.gelu"(%arg0) <{approximate = false}> : (tensor<3xf32>) -> tensor<3xf32> } +func.func @gelu_erfc(%arg0: tensor<3xf32>) -> tensor<3xf32> { + %cst = arith.constant dense<0.707106769> : tensor + %cst_0 = arith.constant dense<5.000000e-01> : tensor + %2 = "tfl.neg"(%arg0) : (tensor<3xf32>) -> tensor<3xf32> + %3 = "tfl.mul"(%2, %cst) <{fused_activation_function = "NONE"}> : (tensor<3xf32>, tensor) -> tensor<3xf32> + %4 = "tf.Erfc"(%3) : (tensor<3xf32>) -> tensor<3xf32> + %5 = "tfl.mul"(%arg0, %cst_0) <{fused_activation_function = "NONE"}> : (tensor<3xf32>, tensor) -> tensor<3xf32> + %6 = "tfl.mul"(%5, %4) <{fused_activation_function = "NONE"}> : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32> + func.return %6 : tensor<3xf32> + +// CHECK-LABEL:gelu_erfc +// CHECK: "tfl.gelu"(%arg0) <{approximate = false}> : (tensor<3xf32>) -> tensor<3xf32> +} + func.func @gelu_no_match(%arg0: tensor<3xf32>) -> tensor<3xf32> { %cst = arith.constant dense<0.707106769> : tensor %cst_0 = arith.constant dense<5.000000e-01> : tensor diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td index c1cb138b2cadd0..ebf2508ab56dbd 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td +++ b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td @@ -1441,6 +1441,23 @@ def MatchGelu : Pat< (HasOneUse $mul_out1), ]>; +// For Gelu, replaces +// 0.5 * x * ( erfc( -x * sqrt_1_2 ) ) +def MatchGeluWithErfc : Pat< + (TFL_MulOp + (TFL_MulOp:$mul_out $arg0, (Arith_ConstantOp F32ElementsAttr:$Cst_1_2), TFL_AF_None), + (TF_ErfcOp:$erfc_out + (TFL_MulOp:$mul_out1 + (TFL_NegOp $arg0), + (Arith_ConstantOp F32ElementsAttr:$Cst_sqrt_1_2), TFL_AF_None)), TFL_AF_None), + (TFL_GeluOp $arg0, ConstBoolAttrFalse), + [(FloatValueEquals<"0.5"> $Cst_1_2), + (FloatValueEquals<"0.707106769"> $Cst_sqrt_1_2), + (HasOneUse $mul_out), + (HasOneUse $erfc_out), + (HasOneUse $mul_out1), + ]>; + // Fetches the output of FC op, from the provided arguments. def GetFcOutput : NativeCodeCall< "GetFcOutput(&$_builder, $0, $1, $2, $3, $4, $5, $6, $7)">; From 50b0da21c4513ff7c0980223ce0a68934eaff419 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 30 Sep 2024 18:37:34 -0700 Subject: [PATCH 451/483] [XLA:Python] Improve the error message for the case where the previous permissive None treedef behavior is encountered. PiperOrigin-RevId: 680797730 --- third_party/xla/xla/python/pytree.cc | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/third_party/xla/xla/python/pytree.cc b/third_party/xla/xla/python/pytree.cc index 5592a96454821d..e5662f9f5da674 100644 --- a/third_party/xla/xla/python/pytree.cc +++ b/third_party/xla/xla/python/pytree.cc @@ -595,9 +595,14 @@ nb::list PyTreeDef::FlattenUpTo(nb::handle xs) const { case PyTreeKind::kNone: if (!object.is_none()) { - throw std::invalid_argument( - absl::StrFormat("Expected None, got %s.", - nb::cast(nb::repr(object)))); + throw std::invalid_argument(absl::StrFormat( + "Expected None, got %s.\n\n" + "In previous releases of JAX, flatten-up-to used to " + "consider None to be a tree-prefix of non-None values. To obtain " + "the previous behavior, you can usually write:\n" + " jax.tree.map(lambda x, y: None if x is None else f(x, y), a, " + "b, is_leaf=lambda x: x is None)", + nb::cast(nb::repr(object)))); } break; From dfe4765e06fc95015858ba15cf9756463b88e77b Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 30 Sep 2024 21:10:45 -0700 Subject: [PATCH 452/483] Automated Code Change PiperOrigin-RevId: 680840038 --- tensorflow/compiler/jit/BUILD | 6 +++++ .../increase_dynamism_for_auto_jit_pass.cc | 25 +++++++++++++++---- .../jit/increase_dynamism_for_auto_jit_pass.h | 1 + ...ncrease_dynamism_for_auto_jit_pass_test.cc | 19 +++++++++++++- 4 files changed, 45 insertions(+), 6 deletions(-) diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index 968aeb2f028d64..819ebafb446559 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -1145,7 +1145,10 @@ cc_library( "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", @@ -1315,10 +1318,13 @@ tf_cc_test( "//tensorflow/core:session_options", "//tensorflow/core:test", "//tensorflow/core:testlib", + "//tensorflow/core/common_runtime:device_set", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest", "@local_xla//xla:test", ], ) diff --git a/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.cc b/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.cc index 0ef7156ef9f593..c9e9c55ef1f2cc 100644 --- a/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.cc +++ b/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.cc @@ -14,25 +14,40 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.h" + #include + #include "absl/algorithm/container.h" -#include "absl/container/inlined_vector.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/str_cat.h" -#include "absl/strings/str_replace.h" +#include "absl/strings/string_view.h" #include "absl/types/optional.h" +#include "tensorflow/cc/framework/ops.h" +#include "tensorflow/cc/framework/scope.h" #include "tensorflow/cc/framework/scope_internal.h" #include "tensorflow/cc/ops/array_ops.h" #include "tensorflow/cc/ops/const_op.h" #include "tensorflow/cc/ops/math_ops.h" #include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/jit/xla_cluster_util.h" -#include "tensorflow/compiler/tf2xla/cc/ops/xla_ops.h" #include "xla/status_macros.h" -#include "tensorflow/core/common_runtime/shape_refiner.h" +#include "tensorflow/core/common_runtime/optimization_registry.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/graph/algorithm.h" -#include "tensorflow/core/public/session_options.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/statusor.h" +#include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/device_name_utils.h" #include "tensorflow/core/util/dump_graph.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.h b/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.h index 818ca948d64b03..ad90bb1bbec647 100644 --- a/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.h +++ b/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_JIT_INCREASE_DYNAMISM_FOR_AUTO_JIT_PASS_H_ #include "tensorflow/core/common_runtime/optimization_registry.h" +#include "tensorflow/core/platform/status.h" namespace tensorflow { diff --git a/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass_test.cc b/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass_test.cc index e864ef1dd12ae9..3c3047a4f98fa2 100644 --- a/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass_test.cc +++ b/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass_test.cc @@ -15,14 +15,31 @@ limitations under the License. #include "tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.h" +#include +#include "absl/status/status.h" #include "tensorflow/cc/framework/ops.h" +#include "tensorflow/cc/framework/scope.h" #include "tensorflow/cc/ops/array_ops.h" #include "tensorflow/cc/ops/const_op.h" #include "tensorflow/compiler/jit/node_matchers.h" #include "tensorflow/compiler/jit/xla_cluster_util.h" -#include "tensorflow/core/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" +#include "tensorflow/core/common_runtime/device_set.h" +#include "tensorflow/core/common_runtime/optimization_registry.h" +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/device.h" +#include "tensorflow/core/framework/device_attributes.pb.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/protobuf/config.pb.h" #include "tensorflow/core/public/session_options.h" +#include "tsl/platform/errors.h" namespace tensorflow { namespace { From 331c4fc6f9c6a35e2e0162ac49186cebd11d57c7 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 30 Sep 2024 21:31:41 -0700 Subject: [PATCH 453/483] Automated Code Change PiperOrigin-RevId: 680845639 --- tensorflow/core/grappler/optimizers/dependency_optimizer.cc | 2 +- tensorflow/core/grappler/optimizers/evaluation_utils.cc | 4 ++-- tensorflow/core/grappler/optimizers/evaluation_utils.h | 4 ++-- tensorflow/core/grappler/optimizers/loop_optimizer.cc | 2 +- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/tensorflow/core/grappler/optimizers/dependency_optimizer.cc b/tensorflow/core/grappler/optimizers/dependency_optimizer.cc index 7eeb154f46fb2c..bfe53994837059 100644 --- a/tensorflow/core/grappler/optimizers/dependency_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/dependency_optimizer.cc @@ -540,7 +540,7 @@ Status DependencyOptimizer::TransitiveReduction() { // highest index of a target of any control output from each node. int num_controls = 0; std::vector> outputs(num_nodes); - std::vector, 2>> control_outputs( + std::vector, 2UL>> control_outputs( num_nodes); // target_range[i] contains the range of node indices for which to compute // longest paths starting from node i. diff --git a/tensorflow/core/grappler/optimizers/evaluation_utils.cc b/tensorflow/core/grappler/optimizers/evaluation_utils.cc index 855635aefc1bb2..b17fee735cfd6c 100644 --- a/tensorflow/core/grappler/optimizers/evaluation_utils.cc +++ b/tensorflow/core/grappler/optimizers/evaluation_utils.cc @@ -26,7 +26,7 @@ limitations under the License. namespace tensorflow { namespace grappler { -using TensorVector = gtl::InlinedVector; +using TensorVector = absl::InlinedVector; // In order to avoid the overhead of creating a large thread pool, we set a // small default thread count. This value should be revised should DeviceSimple @@ -81,7 +81,7 @@ Status EvaluateNode(const NodeDef& node, const TensorVector& inputs, params.op_kernel = op_kernel.get(); params.resource_manager = resource_mgr; - gtl::InlinedVector output_attrs; + absl::InlinedVector output_attrs; const int num_outputs = op_kernel->num_outputs(); for (int i = 0; i < num_outputs; i++) { AllocatorAttributes attr; diff --git a/tensorflow/core/grappler/optimizers/evaluation_utils.h b/tensorflow/core/grappler/optimizers/evaluation_utils.h index a146c9a5cad1ef..3c208d6110a24f 100644 --- a/tensorflow/core/grappler/optimizers/evaluation_utils.h +++ b/tensorflow/core/grappler/optimizers/evaluation_utils.h @@ -55,9 +55,9 @@ class DeviceSimple : public DeviceBase { }; Status EvaluateNode(const NodeDef& node, - const gtl::InlinedVector& inputs, + const absl::InlinedVector& inputs, DeviceBase* cpu_device, ResourceMgr* resource_mgr, - gtl::InlinedVector* output); + absl::InlinedVector* output); } // end namespace grappler } // end namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/loop_optimizer.cc b/tensorflow/core/grappler/optimizers/loop_optimizer.cc index 94c2c22f472f19..b32b6ab850467e 100644 --- a/tensorflow/core/grappler/optimizers/loop_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/loop_optimizer.cc @@ -55,7 +55,7 @@ namespace tensorflow { namespace grappler { namespace { -using TensorVector = gtl::InlinedVector; +using TensorVector = absl::InlinedVector; class LoopInvariantNodeMotionOptimizer { public: From 0dc32f2f5dd3f96a3c26924bb7cf31c32ee19762 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 30 Sep 2024 22:02:19 -0700 Subject: [PATCH 454/483] Automated Code Change PiperOrigin-RevId: 680853324 --- .../experimental/shlo/quantized_tensor_element_type_test.cc | 2 -- 1 file changed, 2 deletions(-) diff --git a/tensorflow/lite/experimental/shlo/quantized_tensor_element_type_test.cc b/tensorflow/lite/experimental/shlo/quantized_tensor_element_type_test.cc index 10c5a878ddd931..de2bcc3f0d079d 100644 --- a/tensorflow/lite/experimental/shlo/quantized_tensor_element_type_test.cc +++ b/tensorflow/lite/experimental/shlo/quantized_tensor_element_type_test.cc @@ -28,9 +28,7 @@ namespace shlo_ref { namespace { using testing::Each; -using testing::ElementsAre; using testing::ElementsAreArray; -using testing::Eq; using testing::FloatEq; using testing::Pointwise; From d6a4de2f6c77afb80cf94667edeaa9d3187ff644 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 30 Sep 2024 22:06:46 -0700 Subject: [PATCH 455/483] Automated Code Change PiperOrigin-RevId: 680854863 --- .../gpu/transforms/all_gather_dynamic_slice_simplifier.cc | 3 ++- .../gpu/transforms/all_gather_dynamic_slice_simplifier.h | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/third_party/xla/xla/service/gpu/transforms/all_gather_dynamic_slice_simplifier.cc b/third_party/xla/xla/service/gpu/transforms/all_gather_dynamic_slice_simplifier.cc index 4035b80606cdff..adf8d870b836ac 100644 --- a/third_party/xla/xla/service/gpu/transforms/all_gather_dynamic_slice_simplifier.cc +++ b/third_party/xla/xla/service/gpu/transforms/all_gather_dynamic_slice_simplifier.cc @@ -58,7 +58,8 @@ bool AllGatherDynamicSliceSimplifier::InstructionMatchesPattern( return match; } -StatusOr AllGatherDynamicSliceSimplifier::ExpandInstruction( +absl::StatusOr +AllGatherDynamicSliceSimplifier::ExpandInstruction( HloInstruction* instruction) { HloDynamicSliceInstruction* dynamic_slice = Cast(instruction); diff --git a/third_party/xla/xla/service/gpu/transforms/all_gather_dynamic_slice_simplifier.h b/third_party/xla/xla/service/gpu/transforms/all_gather_dynamic_slice_simplifier.h index f0fb673ad1f6fa..52328f583e519b 100644 --- a/third_party/xla/xla/service/gpu/transforms/all_gather_dynamic_slice_simplifier.h +++ b/third_party/xla/xla/service/gpu/transforms/all_gather_dynamic_slice_simplifier.h @@ -39,7 +39,7 @@ class AllGatherDynamicSliceSimplifier : public OpExpanderPass { protected: bool InstructionMatchesPattern(HloInstruction* instruction) override; - StatusOr ExpandInstruction( + absl::StatusOr ExpandInstruction( HloInstruction* instruction) override; }; From a40fa3a9ef22c3429cdfc26daa61a111606bffb0 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 30 Sep 2024 22:08:33 -0700 Subject: [PATCH 456/483] Automated Code Change PiperOrigin-RevId: 680855416 --- .../core/grappler/optimizers/arithmetic_optimizer_test.cc | 2 +- .../core/grappler/optimizers/auto_mixed_precision_lists.h | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc index bd10921cb877a1..22bfd0fea50aa6 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc @@ -4496,7 +4496,7 @@ TEST_F(ArithmeticOptimizerTest, RemoveStackStridedSliceSameAxis) { } else if (node.name() == "pc_slice_out") { ASSERT_EQ(node.input_size(), 1); EXPECT_EQ(node.input(0), "c"); - } else if (str_util::EndsWith(node.name(), "_out")) { + } else if (absl::EndsWith(node.name(), "_out")) { ASSERT_EQ(node.input_size(), 1); EXPECT_EQ( absl::StrCat(node.input(0), "_out"), diff --git a/tensorflow/core/grappler/optimizers/auto_mixed_precision_lists.h b/tensorflow/core/grappler/optimizers/auto_mixed_precision_lists.h index b99f3864d8daea..5e39c80f9fd424 100644 --- a/tensorflow/core/grappler/optimizers/auto_mixed_precision_lists.h +++ b/tensorflow/core/grappler/optimizers/auto_mixed_precision_lists.h @@ -102,7 +102,7 @@ class AutoMixedPrecisionListsFp16 : public AutoMixedPrecisionLists { TF_CHECK_OK( ReadStringFromEnvVar("TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_LEVEL", "", &optimization_level)); - optimization_level = str_util::Uppercase(optimization_level); + optimization_level = absl::AsciiStrToUpper(optimization_level); return optimization_level == "TENSOR_CORES_ONLY"; } From cc3d3977997b3a2f75ce8e4c278470f7a4d970af Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 1 Oct 2024 00:49:43 -0700 Subject: [PATCH 457/483] Automated Code Change PiperOrigin-RevId: 680897506 --- third_party/xla/xla/examples/axpy/stablehlo_compile_test.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/third_party/xla/xla/examples/axpy/stablehlo_compile_test.cc b/third_party/xla/xla/examples/axpy/stablehlo_compile_test.cc index 49a99ee88a679c..b4c84bb0242c85 100644 --- a/third_party/xla/xla/examples/axpy/stablehlo_compile_test.cc +++ b/third_party/xla/xla/examples/axpy/stablehlo_compile_test.cc @@ -38,7 +38,6 @@ limitations under the License. #include "xla/pjrt/pjrt_executable.h" #include "xla/pjrt/pjrt_stream_executor_client.h" #include "xla/service/platform_util.h" -#include "xla/service/stream_pool.h" #include "xla/stream_executor/platform.h" #include "xla/tests/literal_test_util.h" #include "xla/tsl/lib/core/status_test_util.h" From 017fb01300cdd273749c173f7425d3e09d9fe3de Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 1 Oct 2024 01:35:15 -0700 Subject: [PATCH 458/483] Automated Code Change PiperOrigin-RevId: 680910396 --- .../core/common_runtime/base_collective_executor.cc | 13 +++++++------ tensorflow/core/common_runtime/direct_session.cc | 10 +++++----- 2 files changed, 12 insertions(+), 11 deletions(-) diff --git a/tensorflow/core/common_runtime/base_collective_executor.cc b/tensorflow/core/common_runtime/base_collective_executor.cc index cef1f1a7e2b57b..29ae90d454b2d5 100644 --- a/tensorflow/core/common_runtime/base_collective_executor.cc +++ b/tensorflow/core/common_runtime/base_collective_executor.cc @@ -338,9 +338,10 @@ void BaseCollectiveExecutor::ExecuteAsync(OpKernelContext* ctx, core::ScopedUnref unref(col_impl); tsl::profiler::TraceMeConsumer consumer( [ctx, col_ctx] { - string op = profiler::TraceMeOp(ctx->op_kernel().name_view(), - ctx->op_kernel().type_string_view()); - return profiler::TraceMeEncode( + string op = + tsl::profiler::TraceMeOp(ctx->op_kernel().name_view(), + ctx->op_kernel().type_string_view()); + return tsl::profiler::TraceMeEncode( std::move(op), {{"step_id", ctx->step_id()}, {"iter_id", ctx->frame_iter().iter_id}, @@ -369,9 +370,9 @@ void BaseCollectiveExecutor::CompleteParamsAsync( // callback is responsible for invoking done() at the end. const auto is_callback_called = std::make_shared>(false); int64_t trace_id = tsl::profiler::TraceMe::ActivityStart([cp]() { - return profiler::TraceMeEncode("CollectiveExecutor::CompleteParams", - {{"group_key", cp->group.group_key}, - {"group_size", cp->group.group_size}}); + return tsl::profiler::TraceMeEncode("CollectiveExecutor::CompleteParams", + {{"group_key", cp->group.group_key}, + {"group_size", cp->group.group_size}}); }); auto done_safe = [this, is_callback_called, cancel_mgr, trace_id, diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc index ecef91df59b923..7b2496637d66b1 100644 --- a/tensorflow/core/common_runtime/direct_session.cc +++ b/tensorflow/core/common_runtime/direct_session.cc @@ -525,12 +525,12 @@ Status DirectSession::RunInternal( options_.config.experimental().session_metadata(); string model_id = strings::StrCat(model_metadata.name(), ":", model_metadata.version()); - return profiler::TraceMeEncode("SessionRun", - {{"id", step_id}, - {"_r", 1} /*root_event*/, - {"model_id", model_id}}); + return tsl::profiler::TraceMeEncode("SessionRun", + {{"id", step_id}, + {"_r", 1} /*root_event*/, + {"model_id", model_id}}); } else { - return profiler::TraceMeEncode( + return tsl::profiler::TraceMeEncode( "SessionRun", {{"id", step_id}, {"_r", 1} /*root_event*/}); } }, From 83177ba4f9eeafe00596b99f9b99f26ab4c992c2 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 1 Oct 2024 01:36:05 -0700 Subject: [PATCH 459/483] Automated Code Change PiperOrigin-RevId: 680910687 --- tensorflow/core/tpu/graph_rewrite/BUILD | 2 ++ .../core/tpu/graph_rewrite/distributed_tpu_rewrite_helpers.cc | 2 +- .../core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.cc | 1 + .../core/tpu/graph_rewrite/encapsulate_tpu_computations_pass.cc | 2 +- .../tpu/graph_rewrite/encapsulate_tpu_computations_pass_test.cc | 1 + 5 files changed, 6 insertions(+), 2 deletions(-) diff --git a/tensorflow/core/tpu/graph_rewrite/BUILD b/tensorflow/core/tpu/graph_rewrite/BUILD index 73fbacd589160b..36cc709525ed15 100644 --- a/tensorflow/core/tpu/graph_rewrite/BUILD +++ b/tensorflow/core/tpu/graph_rewrite/BUILD @@ -76,6 +76,7 @@ cc_library( "//tensorflow/core/platform:strcat", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", "@local_tsl//tsl/platform:errors", "@local_xla//xla:status_macros", ], @@ -163,6 +164,7 @@ tf_cc_test( "//tensorflow/core:testlib", "//tensorflow/core/common_runtime:optimization_registry", "//tensorflow/core/config:flag_defs", + "//tensorflow/core/framework:types_proto_cc", ], ) diff --git a/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_helpers.cc b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_helpers.cc index bca30520071c66..1862edab9cd38a 100644 --- a/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_helpers.cc +++ b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_helpers.cc @@ -21,6 +21,7 @@ limitations under the License. #include "absl/log/check.h" #include "absl/status/status.h" +#include "absl/strings/str_join.h" #include "xla/status_macros.h" #include "tensorflow/core/common_runtime/device_set.h" #include "tensorflow/core/framework/device.h" @@ -29,7 +30,6 @@ limitations under the License. #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/status.h" -#include "tensorflow/core/platform/str_util.h" #include "tensorflow/core/platform/strcat.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/device_name_utils.h" diff --git a/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.cc b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.cc index a13b3caba2fc17..ff4c030a0d560d 100644 --- a/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.cc +++ b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.cc @@ -39,6 +39,7 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/status/status.h" #include "absl/strings/escaping.h" +#include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" diff --git a/tensorflow/core/tpu/graph_rewrite/encapsulate_tpu_computations_pass.cc b/tensorflow/core/tpu/graph_rewrite/encapsulate_tpu_computations_pass.cc index 9370cd6b01ab1c..bb643300dfc8e4 100644 --- a/tensorflow/core/tpu/graph_rewrite/encapsulate_tpu_computations_pass.cc +++ b/tensorflow/core/tpu/graph_rewrite/encapsulate_tpu_computations_pass.cc @@ -34,6 +34,7 @@ limitations under the License. #include "absl/container/node_hash_map.h" #include "absl/log/log.h" #include "absl/status/status.h" +#include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" @@ -62,7 +63,6 @@ limitations under the License. #include "tensorflow/core/graph/graph_node_util.h" #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/gtl/flatset.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/protobuf/config.pb.h" #include "tensorflow/core/public/session_options.h" diff --git a/tensorflow/core/tpu/graph_rewrite/encapsulate_tpu_computations_pass_test.cc b/tensorflow/core/tpu/graph_rewrite/encapsulate_tpu_computations_pass_test.cc index a21cdaec4dbc72..a08a5e6be10a01 100644 --- a/tensorflow/core/tpu/graph_rewrite/encapsulate_tpu_computations_pass_test.cc +++ b/tensorflow/core/tpu/graph_rewrite/encapsulate_tpu_computations_pass_test.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/optimization_registry.h" #include "tensorflow/core/config/flag_defs.h" #include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/testlib.h" From bf54aa6b89db855bd663e785f5c4520d28572745 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 1 Oct 2024 01:38:28 -0700 Subject: [PATCH 460/483] Automated Code Change PiperOrigin-RevId: 680911349 --- tensorflow/core/framework/BUILD | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorflow/core/framework/BUILD b/tensorflow/core/framework/BUILD index 263ac02be7b6aa..62bdeb4e00572b 100644 --- a/tensorflow/core/framework/BUILD +++ b/tensorflow/core/framework/BUILD @@ -722,6 +722,7 @@ cc_library( "//learning/deepmind/tensorflow/queues:__pkg__", "//learning/deepmind/tensorflow/sstable:__pkg__", "//learning/deepmind/video/tensorflow:__pkg__", + "//learning/sibyl/tfx/state/kernels:__pkg__", "//learning/sibyl/tfx/transformation/kernels:__pkg__", "//tensorflow/compiler/mlir/tools/kernel_gen:__pkg__", "//tensorflow/compiler/tf2xla:__pkg__", From d67cd391710c69f76456ce03fb18682b700f22f0 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 1 Oct 2024 01:42:30 -0700 Subject: [PATCH 461/483] Automated Code Change PiperOrigin-RevId: 680912601 --- .../grappler/optimizers/arithmetic_optimizer.cc | 14 +++++++------- .../core/grappler/optimizers/constant_folding.cc | 2 +- .../core/grappler/optimizers/constant_folding.h | 7 ++++--- 3 files changed, 12 insertions(+), 11 deletions(-) diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc index f3079b745d029c..27c8acfc854cd3 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc @@ -2015,7 +2015,7 @@ class RemoveRedundantReshapeOrBroadcastTo : public ArithmeticOptimizerStage { // chain of unary elementwise ops that are not outputs. if (IsReshape(*node)) { bool skip = false; - gtl::InlinedVector nodes_in_chain; + absl::InlinedVector nodes_in_chain; const auto predicate_fn = [this, node, &skip, &nodes_in_chain](const NodeDef& input) { nodes_in_chain.push_back(&input); @@ -3838,7 +3838,7 @@ class RemoveStackSliceSameAxis : public ArithmeticOptimizerStage { } auto copy_tensor_values_to_vector = - [node](const Tensor& t, gtl::InlinedVector* vec) { + [node](const Tensor& t, absl::InlinedVector* vec) { if (t.dtype() == DT_INT32) { auto t_flat = t.flat(); vec->assign(&t_flat(0), &t_flat(t.NumElements())); @@ -3853,8 +3853,8 @@ class RemoveStackSliceSameAxis : public ArithmeticOptimizerStage { return absl::OkStatus(); }; - gtl::InlinedVector slice_begin_vec; - gtl::InlinedVector slice_size_vec; + absl::InlinedVector slice_begin_vec; + absl::InlinedVector slice_size_vec; TF_RETURN_IF_ERROR( copy_tensor_values_to_vector(slice_begin_t, &slice_begin_vec)); TF_RETURN_IF_ERROR( @@ -3958,9 +3958,9 @@ class RemoveStackSliceSameAxis : public ArithmeticOptimizerStage { bool is_identity; bool is_simple_slice; bool slice_dim0; - gtl::InlinedVector slice_begin_vec; - gtl::InlinedVector slice_end_vec; - gtl::InlinedVector slice_strides_vec; + absl::InlinedVector slice_begin_vec; + absl::InlinedVector slice_end_vec; + absl::InlinedVector slice_strides_vec; TF_RETURN_IF_ERROR(ValidateStridedSliceOp( &slice_begin_t, &slice_end_t, slice_strides_t, pack_output_shape, begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask, diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc index 90eb63a22bcaa7..ad69a3f5bd80c2 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding.cc @@ -61,7 +61,7 @@ limitations under the License. namespace tensorflow { namespace grappler { -using TensorVector = gtl::InlinedVector; +using TensorVector = absl::InlinedVector; // We only fold/materialize constants smaller than 100kB. const int64_t kMaxConstantSize = 100 * 1024; diff --git a/tensorflow/core/grappler/optimizers/constant_folding.h b/tensorflow/core/grappler/optimizers/constant_folding.h index 9305aa09764f3d..5a31b65717f91b 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.h +++ b/tensorflow/core/grappler/optimizers/constant_folding.h @@ -89,8 +89,8 @@ class ConstantFolding : public GraphOptimizer { const GraphProperties* properties) const; Status EvaluateNode(const NodeDef& node, - const gtl::InlinedVector& inputs, - gtl::InlinedVector* output) const; + const absl::InlinedVector& inputs, + absl::InlinedVector* output) const; Status EvaluateOneFoldable(const NodeDef& node, std::vector* outputs, bool* result_too_large); @@ -232,7 +232,8 @@ class ConstantFolding : public GraphOptimizer { // input dimensions to reduce along are all of size 1 and keep_dims is true). bool IsReductionSimplifiableToIdentity( const NodeDef& node, const TensorShapeProto& input_shape, bool keep_dims, - const gtl::InlinedVector& reduction_indices_vector) const; + const absl::InlinedVector& reduction_indices_vector) + const; // Changes a reduction into an Identity op, returning true on success. bool ReplaceReductionWithIdentity(NodeDef* node) const; From 9e0ca3947b3b9796046b64486db0e32e8d127ef8 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 1 Oct 2024 01:48:21 -0700 Subject: [PATCH 462/483] Automated Code Change PiperOrigin-RevId: 680914425 --- tensorflow/lite/core/acceleration/configuration/BUILD | 8 ++++++++ .../core/acceleration/configuration/delegate_registry.cc | 1 + .../core/acceleration/configuration/delegate_registry.h | 1 + .../lite/core/acceleration/configuration/nnapi_plugin.cc | 2 ++ .../lite/core/acceleration/configuration/nnapi_plugin.h | 1 + .../core/acceleration/configuration/nnapi_plugin_test.cc | 5 ++++- .../configuration/stable_delegate_registry.cc | 1 + .../acceleration/configuration/stable_delegate_registry.h | 1 + .../configuration/stable_delegate_registry_test.cc | 1 + .../core/acceleration/configuration/xnnpack_plugin.cc | 2 +- .../acceleration/configuration/xnnpack_plugin_test.cc | 4 ++-- 11 files changed, 23 insertions(+), 4 deletions(-) diff --git a/tensorflow/lite/core/acceleration/configuration/BUILD b/tensorflow/lite/core/acceleration/configuration/BUILD index cd2c147603d710..6e9eb895ebd0f1 100644 --- a/tensorflow/lite/core/acceleration/configuration/BUILD +++ b/tensorflow/lite/core/acceleration/configuration/BUILD @@ -19,6 +19,7 @@ cc_library( deps = [ "//tensorflow/lite/acceleration/configuration:configuration_fbs", "//tensorflow/lite/core/c:common", + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/synchronization", ], ) @@ -38,6 +39,7 @@ cc_library( "//tensorflow/lite/core/c:common", "//tensorflow/lite/delegates/nnapi:nnapi_delegate", "//tensorflow/lite/nnapi:nnapi_implementation_headers", + "//tensorflow/lite/nnapi:nnapi_lib", "@com_google_absl//absl/memory", ], alwayslink = 1, # For registration to always run. @@ -61,6 +63,9 @@ cc_test( "//tensorflow/lite/delegates/nnapi:nnapi_delegate", "//tensorflow/lite/delegates/nnapi:nnapi_delegate_mock_test", "//tensorflow/lite/kernels:test_util", + "//tensorflow/lite/nnapi:nnapi_implementation_headers", + "//tensorflow/lite/nnapi:nnapi_lib", + "//tensorflow/lite/schema:schema_fbs", "@com_google_googletest//:gtest_main", "@flatbuffers", ], @@ -76,6 +81,7 @@ cc_library( deps = [ "//tensorflow/lite/core/acceleration/configuration/c:stable_delegate", "//tensorflow/lite/core/shims:tflite_use_opaque_delegate", # buildcleaner: keep + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/synchronization", ], ) @@ -85,6 +91,7 @@ cc_test( srcs = ["stable_delegate_registry_test.cc"], deps = [ ":stable_delegate_registry", + "//tensorflow/lite/core/acceleration/configuration/c:stable_delegate", "@com_google_googletest//:gtest_main", ], ) @@ -99,6 +106,7 @@ cc_library( deps = [ "//tensorflow/lite:minimal_logging", "//tensorflow/lite/acceleration/configuration:configuration_fbs", + "//tensorflow/lite/c:c_api_types", "//tensorflow/lite/core/acceleration/configuration:delegate_registry", "//tensorflow/lite/delegates/xnnpack:xnnpack_delegate", "@com_google_absl//absl/base:log_severity", diff --git a/tensorflow/lite/core/acceleration/configuration/delegate_registry.cc b/tensorflow/lite/core/acceleration/configuration/delegate_registry.cc index b28759b6af77b9..71ee43e2b5f935 100644 --- a/tensorflow/lite/core/acceleration/configuration/delegate_registry.cc +++ b/tensorflow/lite/core/acceleration/configuration/delegate_registry.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include "absl/synchronization/mutex.h" +#include "tensorflow/lite/acceleration/configuration/configuration_generated.h" namespace tflite { namespace delegates { diff --git a/tensorflow/lite/core/acceleration/configuration/delegate_registry.h b/tensorflow/lite/core/acceleration/configuration/delegate_registry.h index e3dc41e5dd707f..742e74389927a5 100644 --- a/tensorflow/lite/core/acceleration/configuration/delegate_registry.h +++ b/tensorflow/lite/core/acceleration/configuration/delegate_registry.h @@ -27,6 +27,7 @@ limitations under the License. #include #include +#include "absl/base/thread_annotations.h" #include "absl/synchronization/mutex.h" #include "tensorflow/lite/acceleration/configuration/configuration_generated.h" #include "tensorflow/lite/core/c/common.h" diff --git a/tensorflow/lite/core/acceleration/configuration/nnapi_plugin.cc b/tensorflow/lite/core/acceleration/configuration/nnapi_plugin.cc index 2adfaa5b2ae1ff..34dd7bbed229f6 100644 --- a/tensorflow/lite/core/acceleration/configuration/nnapi_plugin.cc +++ b/tensorflow/lite/core/acceleration/configuration/nnapi_plugin.cc @@ -17,6 +17,8 @@ limitations under the License. #include "tensorflow/lite/core/acceleration/configuration/nnapi_plugin.h" +#include "tensorflow/lite/core/acceleration/configuration/delegate_registry.h" + namespace tflite { namespace delegates { diff --git a/tensorflow/lite/core/acceleration/configuration/nnapi_plugin.h b/tensorflow/lite/core/acceleration/configuration/nnapi_plugin.h index 03721d73a98df4..8b86801be3d28c 100644 --- a/tensorflow/lite/core/acceleration/configuration/nnapi_plugin.h +++ b/tensorflow/lite/core/acceleration/configuration/nnapi_plugin.h @@ -34,6 +34,7 @@ limitations under the License. #include "tensorflow/lite/core/acceleration/configuration/delegate_registry.h" #include "tensorflow/lite/core/c/common.h" #include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h" +#include "tensorflow/lite/nnapi/NeuralNetworksTypes.h" #include "tensorflow/lite/nnapi/nnapi_implementation.h" namespace tflite { diff --git a/tensorflow/lite/core/acceleration/configuration/nnapi_plugin_test.cc b/tensorflow/lite/core/acceleration/configuration/nnapi_plugin_test.cc index 75179c020dc3cf..57a3042737600a 100644 --- a/tensorflow/lite/core/acceleration/configuration/nnapi_plugin_test.cc +++ b/tensorflow/lite/core/acceleration/configuration/nnapi_plugin_test.cc @@ -19,8 +19,8 @@ limitations under the License. #include #include +#include "flatbuffers/buffer.h" // from @flatbuffers #include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers -#include "flatbuffers/flatbuffers.h" // from @flatbuffers #include "tensorflow/lite/acceleration/configuration/configuration_generated.h" #include "tensorflow/lite/core/acceleration/configuration/delegate_registry.h" #include "tensorflow/lite/core/c/common.h" @@ -29,6 +29,9 @@ limitations under the License. #include "tensorflow/lite/delegates/nnapi/nnapi_delegate_kernel.h" #include "tensorflow/lite/delegates/nnapi/nnapi_delegate_mock_test.h" #include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/nnapi/NeuralNetworksTypes.h" +#include "tensorflow/lite/nnapi/nnapi_implementation.h" +#include "tensorflow/lite/schema/schema_generated.h" // Tests for checking that the NNAPI Delegate plugin correctly handles all the // options from the flatbuffer. diff --git a/tensorflow/lite/core/acceleration/configuration/stable_delegate_registry.cc b/tensorflow/lite/core/acceleration/configuration/stable_delegate_registry.cc index e5203762d2affa..87284f3bcfe074 100644 --- a/tensorflow/lite/core/acceleration/configuration/stable_delegate_registry.cc +++ b/tensorflow/lite/core/acceleration/configuration/stable_delegate_registry.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include "absl/synchronization/mutex.h" +#include "tensorflow/lite/core/acceleration/configuration/c/stable_delegate.h" namespace tflite { namespace delegates { diff --git a/tensorflow/lite/core/acceleration/configuration/stable_delegate_registry.h b/tensorflow/lite/core/acceleration/configuration/stable_delegate_registry.h index ede67164500795..25ac647290fb49 100644 --- a/tensorflow/lite/core/acceleration/configuration/stable_delegate_registry.h +++ b/tensorflow/lite/core/acceleration/configuration/stable_delegate_registry.h @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "absl/base/thread_annotations.h" #include "absl/synchronization/mutex.h" #include "tensorflow/lite/core/acceleration/configuration/c/stable_delegate.h" diff --git a/tensorflow/lite/core/acceleration/configuration/stable_delegate_registry_test.cc b/tensorflow/lite/core/acceleration/configuration/stable_delegate_registry_test.cc index a3b6725599e33d..c3a8335345ebdf 100644 --- a/tensorflow/lite/core/acceleration/configuration/stable_delegate_registry_test.cc +++ b/tensorflow/lite/core/acceleration/configuration/stable_delegate_registry_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/lite/core/acceleration/configuration/stable_delegate_registry.h" #include +#include "tensorflow/lite/core/acceleration/configuration/c/stable_delegate.h" namespace { diff --git a/tensorflow/lite/core/acceleration/configuration/xnnpack_plugin.cc b/tensorflow/lite/core/acceleration/configuration/xnnpack_plugin.cc index 81566bcd7f186c..f0ee70606286ea 100644 --- a/tensorflow/lite/core/acceleration/configuration/xnnpack_plugin.cc +++ b/tensorflow/lite/core/acceleration/configuration/xnnpack_plugin.cc @@ -14,8 +14,8 @@ limitations under the License. ==============================================================================*/ #include -#include "absl/memory/memory.h" #include "tensorflow/lite/acceleration/configuration/configuration_generated.h" +#include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/core/acceleration/configuration/delegate_registry.h" #include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" diff --git a/tensorflow/lite/core/acceleration/configuration/xnnpack_plugin_test.cc b/tensorflow/lite/core/acceleration/configuration/xnnpack_plugin_test.cc index 2aa1d95a44f10d..47b03a32179e75 100644 --- a/tensorflow/lite/core/acceleration/configuration/xnnpack_plugin_test.cc +++ b/tensorflow/lite/core/acceleration/configuration/xnnpack_plugin_test.cc @@ -15,9 +15,9 @@ limitations under the License. // Some very simple unit tests of the (C++) XNNPack Delegate Plugin. -#include #include -#include "flatbuffers/flatbuffers.h" // from @flatbuffers +#include "flatbuffers/buffer.h" // from @flatbuffers +#include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers #include "pthreadpool.h" // from @pthreadpool #include "tensorflow/lite/acceleration/configuration/configuration_generated.h" #include "tensorflow/lite/core/acceleration/configuration/delegate_registry.h" From a7a799eb503acc1b4c04a53295c1f6e6f907b93c Mon Sep 17 00:00:00 2001 From: Emmanuel Ferdman Date: Tue, 1 Oct 2024 01:53:29 -0700 Subject: [PATCH 463/483] PR #17746: Update XlaBuilder reference Imported from GitHub PR https://github.com/openxla/xla/pull/17746 # PR Summary PR #17622 moved the location of `xla_builder.cc`. This PR adjusts sources to changes. Copybara import of the project: -- 679a3871a7cfa10160a123fddded728fd6a61853 by Emmanuel Ferdman : Update XlaBuilder reference Signed-off-by: Emmanuel Ferdman Merging this change closes #17746 PiperOrigin-RevId: 680916113 --- third_party/xla/docs/broadcasting.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/xla/docs/broadcasting.md b/third_party/xla/docs/broadcasting.md index f47a6f7a642a66..0f781f06e869fd 100644 --- a/third_party/xla/docs/broadcasting.md +++ b/third_party/xla/docs/broadcasting.md @@ -99,7 +99,7 @@ dimensions 1 and 2 of the cuboid. This type of broadcast is used in the binary ops in `XlaBuilder`, if the `broadcast_dimensions` argument is given. For example, see -[XlaBuilder::Add](https://github.com/openxla/xla/blob/main/xla/client/xla_builder.cc). +[XlaBuilder::Add](https://github.com/openxla/xla/blob/main/xla/hlo/builder/xla_builder.cc). In the XLA source code, this type of broadcasting is sometimes called "InDim" broadcasting. From 5728c70b69da8f52be711ef781f9a444940f8c57 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 1 Oct 2024 02:02:51 -0700 Subject: [PATCH 464/483] compat: Update forward compatibility horizon to 2024-10-01 PiperOrigin-RevId: 680919002 --- tensorflow/python/compat/compat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index cd736eb9b0e74b..c14bf3a8442c06 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -29,7 +29,7 @@ # This value changes every day with an automatic CL. It can be modified in code # via `forward_compatibility_horizon()` or with the environment variable # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date. -_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2024, 9, 30) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2024, 10, 1) _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS" _FORWARD_COMPATIBILITY_DATE_NUMBER = None From ebcb823a33f0dcf28b35c42724324ffeb0196804 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 1 Oct 2024 02:02:52 -0700 Subject: [PATCH 465/483] Update GraphDef version to 2002. PiperOrigin-RevId: 680919010 --- tensorflow/core/public/version.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index 641712a9d171ce..222796622f632e 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -108,7 +108,7 @@ limitations under the License. #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0 -#define TF_GRAPH_DEF_VERSION 2001 // Updated: 2024/9/30 +#define TF_GRAPH_DEF_VERSION 2002 // Updated: 2024/10/1 // Checkpoint compatibility versions (the versions field in SavedSliceMeta). // From 8520eeee6beec3c3f0a0366fb059f55e9117bf0f Mon Sep 17 00:00:00 2001 From: Harsha H S Date: Tue, 1 Oct 2024 02:20:33 -0700 Subject: [PATCH 466/483] PR #17544: [ROCm] Pass AMDGPU_TARGETS to crosstool wrapper Imported from GitHub PR https://github.com/openxla/xla/pull/17544 Passing amdgpu targets to crosstool wrapper which calls hipcc can restrict the kernels generated to specific set of supported amdgpu architectures. Copybara import of the project: -- aba828b02a32aeca576086e8e41aa8b6f70e4f39 by Harsha HS : [ROCm] Pass AMDGPU_TARGETS to crosstool wrapper Passing amdgpu targets to crosstool wrapper which calls hipcc can restrict the kernels generated to specific set of supported amdgpu architectures. Merging this change closes #17544 PiperOrigin-RevId: 680924738 --- third_party/xla/build_tools/rocm/run_xla.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/third_party/xla/build_tools/rocm/run_xla.sh b/third_party/xla/build_tools/rocm/run_xla.sh index 23cc801fc260f0..c7fe4b7a77e4f3 100755 --- a/third_party/xla/build_tools/rocm/run_xla.sh +++ b/third_party/xla/build_tools/rocm/run_xla.sh @@ -67,6 +67,7 @@ bazel \ --local_test_jobs=${N_TEST_JOBS} \ --test_env=TF_TESTS_PER_GPU=$TF_TESTS_PER_GPU \ --test_env=TF_GPU_COUNT=$TF_GPU_COUNT \ + --action_env=TF_ROCM_AMDGPU_TARGETS=gfx90a \ --action_env=XLA_FLAGS=--xla_gpu_force_compilation_parallelism=16 \ --action_env=XLA_FLAGS=--xla_gpu_enable_llvm_module_compilation_parallelism=true \ --run_under=//tools/ci_build/gpu_build:parallel_gpu_execute \ From 6e9f297c3fda181fdf6ae2feeeb76168aec1d919 Mon Sep 17 00:00:00 2001 From: Milica Makevic Date: Tue, 1 Oct 2024 12:03:50 +0000 Subject: [PATCH 467/483] Fix merge conflicts --- .bazelrc | 6 --- third_party/xla/xla/client/lib/BUILD | 42 ------------------- .../xla/xla/service/gpu/gpu_compiler_test.cc | 8 ---- third_party/xla/xla/service/gpu/tests/BUILD | 14 ------- third_party/xla/xla/stream_executor/gpu/BUILD | 12 ------ 5 files changed, 82 deletions(-) diff --git a/.bazelrc b/.bazelrc index d0aeffff3a5497..b857fe975bebdd 100644 --- a/.bazelrc +++ b/.bazelrc @@ -358,16 +358,10 @@ build:linux --copt="-Werror=unused-result" # Add switch as an error on Linux. build:linux --copt="-Wswitch" build:linux --copt="-Werror=switch" -<<<<<<< HEAD -# Required for building with clang -build:linux --copt="-Wno-error=unused-but-set-variable" # We have some invalid linker scripts in the build, # so we need to disable this check build:linux --linkopt=-Wl,--undefined-version build:linux --host_linkopt=-Wl,--undefined-version -======= - ->>>>>>> upstream/master # Linux ARM64 specific options build:linux_arm64 --copt="-mtune=generic" --copt="-march=armv8-a" --copt="-O3" diff --git a/third_party/xla/xla/client/lib/BUILD b/third_party/xla/xla/client/lib/BUILD index ca91aeb0b8325a..3642ff36e9d34f 100644 --- a/third_party/xla/xla/client/lib/BUILD +++ b/third_party/xla/xla/client/lib/BUILD @@ -240,49 +240,7 @@ cc_library( name = "self_adjoint_eig", hdrs = ["self_adjoint_eig.h"], deps = [ -<<<<<<< HEAD - ":slicing", - "//xla:shape_util", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/client:xla_builder", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:str_format", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_test( - name = "self_adjoint_eig_test", - srcs = ["self_adjoint_eig_test.cc"], - real_hardware_only = True, - shard_count = 5, - tags = ["optonly"], - local_defines = if_rocm_is_configured(["TENSORFLOW_USE_ROCM=1",]), - deps = [ - ":arithmetic", - ":constants", - ":math", - ":matrix", - ":self_adjoint_eig", - "//xla:array", - "//xla:array2d", - "//xla:array3d", - "//xla:error_spec", - "//xla:shape_util", - "//xla:test", - "//xla:types", - "//xla:xla_data_proto_cc", - "//xla/client:xla_builder", - "//xla/tests:client_library_test_base", - "//xla/tests:test_macros_header", - "//xla/tests:xla_internal_test_main", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", -======= "//xla/hlo/builder/lib:self_adjoint_eig", ->>>>>>> upstream/master ], ) diff --git a/third_party/xla/xla/service/gpu/gpu_compiler_test.cc b/third_party/xla/xla/service/gpu/gpu_compiler_test.cc index 723fb32018131b..c2c523d4b93363 100644 --- a/third_party/xla/xla/service/gpu/gpu_compiler_test.cc +++ b/third_party/xla/xla/service/gpu/gpu_compiler_test.cc @@ -90,20 +90,12 @@ class GpuCompilerTest : public HloTestBase { return tensorflow::down_cast(compiler) ->RunPostSchedulingPipelines(module, 4 * 1024 * 1024, gpu_device_info); } -<<<<<<< HEAD - const auto& device_desc() { - return backend().default_stream_executor()->GetDeviceDescription(); - } - const se::GpuComputeCapability& GpuComputeComp() { - return device_desc().gpu_compute_capability(); -======= const stream_executor::GpuComputeCapability& GpuComputeComp() { return backend() .default_stream_executor() ->GetDeviceDescription() .gpu_compute_capability(); ->>>>>>> upstream/master } }; diff --git a/third_party/xla/xla/service/gpu/tests/BUILD b/third_party/xla/xla/service/gpu/tests/BUILD index 63944d01c73e20..0577e171ce2afa 100644 --- a/third_party/xla/xla/service/gpu/tests/BUILD +++ b/third_party/xla/xla/service/gpu/tests/BUILD @@ -678,16 +678,6 @@ lit_test_suite( ], default_tags = tf_cuda_tests_tags(), hermetic_cuda_data_dir = "%S/../../../../../cuda_nvcc", -<<<<<<< HEAD -======= - tags_override = { - "element_wise_row_vectorization.hlo": ["cuda-only"], - "scatter_bf16.hlo": ["cuda-only"], - "single_instruction.hlo": ["cuda-only"], - "reduce_unnested.hlo": ["cuda-only"], - "reduction_vectorization_sm_all.hlo": ["cuda-only"], - }, ->>>>>>> upstream/master tools = [ "//xla/tools:hlo-opt", "@llvm-project//llvm:FileCheck", @@ -797,10 +787,6 @@ xla_test( "gpu_a100", "gpu_h100", ], -<<<<<<< HEAD -======= - tags = ["cuda-only"], ->>>>>>> upstream/master deps = if_cuda_is_configured( [ ":gpu_codegen_test", diff --git a/third_party/xla/xla/stream_executor/gpu/BUILD b/third_party/xla/xla/stream_executor/gpu/BUILD index 13ce03c42229c1..3a3938adf31f65 100644 --- a/third_party/xla/xla/stream_executor/gpu/BUILD +++ b/third_party/xla/xla/stream_executor/gpu/BUILD @@ -613,23 +613,11 @@ tsl_gpu_library( tsl_gpu_library( name = "gpu_cudamallocasync_allocator", -<<<<<<< HEAD - srcs = [ - "gpu_cudamallocasync_allocator.cc", - ], - hdrs = ["gpu_cudamallocasync_allocator.h"], - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ - "TENSORFLOW_USE_ROCM=1", - ]), - cuda_deps = [ - "//xla/stream_executor/cuda:cuda_executor", -======= srcs = ["gpu_cudamallocasync_allocator.cc"], hdrs = ["gpu_cudamallocasync_allocator.h"], tags = [ "cuda-only", "gpu", ->>>>>>> upstream/master ], deps = [ ":gpu_init_impl", From da5df2547190976d4a3d07d302d0e45cd30ae0c2 Mon Sep 17 00:00:00 2001 From: Milica Makevic Date: Tue, 1 Oct 2024 12:25:21 +0000 Subject: [PATCH 468/483] Fix dependency issue - tsl/protobuf has been moved to xla/tsl/protobuf --- tensorflow/core/grappler/optimizers/BUILD | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD index e12fb53b24c2d7..6d303acefb88f6 100644 --- a/tensorflow/core/grappler/optimizers/BUILD +++ b/tensorflow/core/grappler/optimizers/BUILD @@ -1191,7 +1191,7 @@ cc_library( ["//tensorflow/core/platform:tensor_float_32_hdr_lib"], ) + if_rocm_is_configured([ "//tensorflow/core/platform:stream_executor", - "@local_tsl//tsl/protobuf:dnn_proto_cc", + "@local_xla//xla/tsl/protobuf:dnn_proto_cc", "@local_xla//xla/stream_executor/rocm:miopen_plugin" ]), ) From ff760cd260d0331705293f43685f2f022f9dcc14 Mon Sep 17 00:00:00 2001 From: Milica Makevic Date: Tue, 1 Oct 2024 12:51:21 +0000 Subject: [PATCH 469/483] Change no_rocm tag to cuda-only in *.bazerc files and test scripts --- .../tools/ci_build/builds/run_pip_tests.sh | 2 +- .../ci_build/linux/rocm/rocm_py310_pip.sh | 2 +- .../ci_build/linux/rocm/rocm_py36_pip.sh | 2 +- .../ci_build/linux/rocm/rocm_py37_pip.sh | 2 +- .../ci_build/linux/rocm/rocm_py38_pip.sh | 2 +- .../ci_build/linux/rocm/rocm_py39_pip.sh | 2 +- .../tools/ci_build/linux/rocm/run_cpu.sh | 2 +- .../ci_build/linux/rocm/run_gpu_multi.sh | 2 +- .../ci_build/linux/rocm/run_gpu_single.sh | 2 +- .../tools/ci_build/linux/rocm/run_xla.sh | 4 +-- .../devel.usertools/cpu.bazelrc | 2 +- .../devel.usertools/gpu.bazelrc | 16 +++++------ .../devel.usertools/gpu_gcc.bazelrc | 28 +++++++++---------- 13 files changed, 34 insertions(+), 34 deletions(-) diff --git a/tensorflow/tools/ci_build/builds/run_pip_tests.sh b/tensorflow/tools/ci_build/builds/run_pip_tests.sh index d91bfa5ae7d610..563055ec2d61a1 100755 --- a/tensorflow/tools/ci_build/builds/run_pip_tests.sh +++ b/tensorflow/tools/ci_build/builds/run_pip_tests.sh @@ -97,7 +97,7 @@ if [[ ${IS_GPU} == "1" ]] || [[ ${IS_ROCM} == "1" ]]; then PIP_TEST_FILTER_TAG="-no_gpu,-no_pip_gpu,${PIP_TEST_FILTER_TAG}" fi if [[ ${IS_ROCM} == "1" ]]; then - PIP_TEST_FILTER_TAG="gpu,requires-gpu,-no_gpu,-no_oss,-oss_serial,-no_oss_py36,-no_rocm,${PIP_TEST_FILTER_TAG}" + PIP_TEST_FILTER_TAG="gpu,requires-gpu,-no_gpu,-no_oss,-oss_serial,-no_oss_py36,-cuda-only,${PIP_TEST_FILTER_TAG}" fi if [[ ${IS_MAC} == "1" ]]; then # TODO(b/122370901): Fix nomac, no_mac inconsistency. diff --git a/tensorflow/tools/ci_build/linux/rocm/rocm_py310_pip.sh b/tensorflow/tools/ci_build/linux/rocm/rocm_py310_pip.sh index 544f28df76a15a..bbaa595a2bb202 100755 --- a/tensorflow/tools/ci_build/linux/rocm/rocm_py310_pip.sh +++ b/tensorflow/tools/ci_build/linux/rocm/rocm_py310_pip.sh @@ -55,7 +55,7 @@ else fi # # Export optional variables for running pip.sh -export TF_TEST_FILTER_TAGS='gpu,requires-gpu,-no_gpu,-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_rocm${TF_TEST_FILTER_TAGS_ROCM_VERSION_SPECIFIC}' +export TF_TEST_FILTER_TAGS='gpu,requires-gpu,-no_gpu,-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-cuda-only${TF_TEST_FILTER_TAGS_ROCM_VERSION_SPECIFIC}' export TF_BUILD_FLAGS="--config=release_base " export TF_TEST_FLAGS="--test_tag_filters=${TF_TEST_FILTER_TAGS} --build_tag_filters=${TF_TEST_FILTER_TAGS} \ --test_env=TF2_BEHAVIOR=1 \ diff --git a/tensorflow/tools/ci_build/linux/rocm/rocm_py36_pip.sh b/tensorflow/tools/ci_build/linux/rocm/rocm_py36_pip.sh index 596703e59c9e34..32fbf7404c1525 100755 --- a/tensorflow/tools/ci_build/linux/rocm/rocm_py36_pip.sh +++ b/tensorflow/tools/ci_build/linux/rocm/rocm_py36_pip.sh @@ -43,7 +43,7 @@ export N_TEST_JOBS=$(expr ${TF_GPU_COUNT} \* ${TF_TESTS_PER_GPU}) source tensorflow/tools/ci_build/build_scripts/DEFAULT_TEST_TARGETS.sh # # Export optional variables for running pip.sh -export TF_TEST_FILTER_TAGS='gpu,requires-gpu,-no_gpu,-no_oss,-oss_excluded,-oss_serial,-no_oss_py36,-no_rocm' +export TF_TEST_FILTER_TAGS='gpu,requires-gpu,-no_gpu,-no_oss,-oss_excluded,-oss_serial,-no_oss_py36,-cuda-only' export TF_BUILD_FLAGS="--config=release_base " export TF_TEST_FLAGS="--test_tag_filters=${TF_TEST_FILTER_TAGS} --build_tag_filters=${TF_TEST_FILTER_TAGS} \ --test_env=TF2_BEHAVIOR=1 \ diff --git a/tensorflow/tools/ci_build/linux/rocm/rocm_py37_pip.sh b/tensorflow/tools/ci_build/linux/rocm/rocm_py37_pip.sh index 64879e14392f6b..e84ae26993a139 100755 --- a/tensorflow/tools/ci_build/linux/rocm/rocm_py37_pip.sh +++ b/tensorflow/tools/ci_build/linux/rocm/rocm_py37_pip.sh @@ -57,7 +57,7 @@ fi # # Export optional variables for running pip.sh -export TF_TEST_FILTER_TAGS='gpu,requires-gpu,-no_gpu,-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_rocm'${TF_TEST_FILTER_TAGS_ROCM_VERSION_SPECIFIC} +export TF_TEST_FILTER_TAGS='gpu,requires-gpu,-no_gpu,-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-cuda-only'${TF_TEST_FILTER_TAGS_ROCM_VERSION_SPECIFIC} export TF_BUILD_FLAGS="--config=release_base " export TF_TEST_FLAGS="--test_tag_filters=${TF_TEST_FILTER_TAGS} --build_tag_filters=${TF_TEST_FILTER_TAGS} \ --test_env=TF2_BEHAVIOR=1 \ diff --git a/tensorflow/tools/ci_build/linux/rocm/rocm_py38_pip.sh b/tensorflow/tools/ci_build/linux/rocm/rocm_py38_pip.sh index 9f3dd586a359e0..96f33b4ee10e41 100755 --- a/tensorflow/tools/ci_build/linux/rocm/rocm_py38_pip.sh +++ b/tensorflow/tools/ci_build/linux/rocm/rocm_py38_pip.sh @@ -57,7 +57,7 @@ fi # # Export optional variables for running pip.sh -export TF_TEST_FILTER_TAGS='gpu,requires-gpu,-no_gpu,-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_rocm'${TF_TEST_FILTER_TAGS_ROCM_VERSION_SPECIFIC} +export TF_TEST_FILTER_TAGS='gpu,requires-gpu,-no_gpu,-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-cuda-only'${TF_TEST_FILTER_TAGS_ROCM_VERSION_SPECIFIC} export TF_BUILD_FLAGS="--config=release_base " export TF_TEST_FLAGS="--test_tag_filters=${TF_TEST_FILTER_TAGS} --build_tag_filters=${TF_TEST_FILTER_TAGS} \ --test_env=TF2_BEHAVIOR=1 \ diff --git a/tensorflow/tools/ci_build/linux/rocm/rocm_py39_pip.sh b/tensorflow/tools/ci_build/linux/rocm/rocm_py39_pip.sh index 5aee933eb1afc3..ef6a0d6b83f906 100755 --- a/tensorflow/tools/ci_build/linux/rocm/rocm_py39_pip.sh +++ b/tensorflow/tools/ci_build/linux/rocm/rocm_py39_pip.sh @@ -55,7 +55,7 @@ else fi # # Export optional variables for running pip.sh -export TF_TEST_FILTER_TAGS='gpu,requires-gpu,-no_gpu,-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_rocm'${TF_TEST_FILTER_TAGS_ROCM_VERSION_SPECIFIC} +export TF_TEST_FILTER_TAGS='gpu,requires-gpu,-no_gpu,-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-cuda-only'${TF_TEST_FILTER_TAGS_ROCM_VERSION_SPECIFIC} export TF_BUILD_FLAGS="--config=release_base " export TF_TEST_FLAGS="--test_tag_filters=${TF_TEST_FILTER_TAGS} --build_tag_filters=${TF_TEST_FILTER_TAGS} \ --test_env=TF2_BEHAVIOR=1 \ diff --git a/tensorflow/tools/ci_build/linux/rocm/run_cpu.sh b/tensorflow/tools/ci_build/linux/rocm/run_cpu.sh index 5e57a56104da25..2e70b1b16ce26b 100755 --- a/tensorflow/tools/ci_build/linux/rocm/run_cpu.sh +++ b/tensorflow/tools/ci_build/linux/rocm/run_cpu.sh @@ -34,7 +34,7 @@ yes "" | $PYTHON_BIN_PATH configure.py bazel test \ -k \ - --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-multi_gpu,-tpu,-no_rocm,-benchmark-test,-v1only \ + --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-multi_gpu,-tpu,-cuda-only,-benchmark-test,-v1only \ --jobs=${N_BUILD_JOBS} \ --local_test_jobs=${N_BUILD_JOBS} \ --test_timeout 600,900,2400,7200 \ diff --git a/tensorflow/tools/ci_build/linux/rocm/run_gpu_multi.sh b/tensorflow/tools/ci_build/linux/rocm/run_gpu_multi.sh index 20302eba135dad..9d251e16e8ad5f 100755 --- a/tensorflow/tools/ci_build/linux/rocm/run_gpu_multi.sh +++ b/tensorflow/tools/ci_build/linux/rocm/run_gpu_multi.sh @@ -68,7 +68,7 @@ else bazel test \ --config=rocm \ -k \ - --test_tag_filters=-no_gpu,-no_rocm \ + --test_tag_filters=-no_gpu,-cuda-only \ --jobs=30 \ --local_ram_resources=60000 \ --local_cpu_resources=15 \ diff --git a/tensorflow/tools/ci_build/linux/rocm/run_gpu_single.sh b/tensorflow/tools/ci_build/linux/rocm/run_gpu_single.sh index b02becccc9a2b3..713246c4beaad1 100755 --- a/tensorflow/tools/ci_build/linux/rocm/run_gpu_single.sh +++ b/tensorflow/tools/ci_build/linux/rocm/run_gpu_single.sh @@ -57,7 +57,7 @@ yes "" | $PYTHON_BIN_PATH configure.py bazel test \ --config=rocm \ -k \ - --test_tag_filters=gpu,-no_oss,-oss_excluded,-oss_serial,-no_gpu,-no_rocm,-benchmark-test,-rocm_multi_gpu,-tpu,-v1only \ + --test_tag_filters=gpu,-no_oss,-oss_excluded,-oss_serial,-no_gpu,-cuda-only,-benchmark-test,-rocm_multi_gpu,-tpu,-v1only \ --jobs=${N_BUILD_JOBS} \ --local_test_jobs=${N_TEST_JOBS} \ --test_env=TF_GPU_COUNT=$TF_GPU_COUNT \ diff --git a/tensorflow/tools/ci_build/linux/rocm/run_xla.sh b/tensorflow/tools/ci_build/linux/rocm/run_xla.sh index c5a045037ceac5..8f9b63966ad61c 100755 --- a/tensorflow/tools/ci_build/linux/rocm/run_xla.sh +++ b/tensorflow/tools/ci_build/linux/rocm/run_xla.sh @@ -75,8 +75,8 @@ else bazel \ test \ -k \ - --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,gpu,requires-gpu,-no_gpu,-no_rocm --keep_going \ - --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,gpu,requires-gpu,-no_gpu,-no_rocm \ + --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,gpu,requires-gpu,-no_gpu,-cuda-only --keep_going \ + --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,gpu,requires-gpu,-no_gpu,-cuda-only \ --config=rocm \ --test_output=errors \ --local_test_jobs=${N_TEST_JOBS} \ diff --git a/tensorflow/tools/tf_sig_build_dockerfiles/devel.usertools/cpu.bazelrc b/tensorflow/tools/tf_sig_build_dockerfiles/devel.usertools/cpu.bazelrc index d41aa5293688a1..0f7d397efa4fe2 100644 --- a/tensorflow/tools/tf_sig_build_dockerfiles/devel.usertools/cpu.bazelrc +++ b/tensorflow/tools/tf_sig_build_dockerfiles/devel.usertools/cpu.bazelrc @@ -37,7 +37,7 @@ build:build_event_export --build_event_json_file=/tf/pkg/bep.json build:rbe --config=rbe_linux_cpu # For continuous builds -test:pycpp_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-tpu,-no_rocm,-benchmark-test,-v1only +test:pycpp_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-tpu,-cuda-only,-benchmark-test,-v1only test:pycpp_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only test:pycpp_filters --test_lang_filters=cc,py --test_size_filters=small,medium test:pycpp --config=pycpp_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... diff --git a/tensorflow/tools/tf_sig_build_dockerfiles/devel.usertools/gpu.bazelrc b/tensorflow/tools/tf_sig_build_dockerfiles/devel.usertools/gpu.bazelrc index c80d2d496aaf71..30b4c4a67c8d8e 100644 --- a/tensorflow/tools/tf_sig_build_dockerfiles/devel.usertools/gpu.bazelrc +++ b/tensorflow/tools/tf_sig_build_dockerfiles/devel.usertools/gpu.bazelrc @@ -32,14 +32,14 @@ test:nonpip_filters --test_lang_filters=py --test_size_filters=small,medium test:nonpip --config=nonpip_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... # "nonpip_large" will run tests marked as large as well -test:nonpip_filters_large --test_tag_filters=gpu,requires-gpu,-no_gpu,-no_oss,-oss_serial,-no_cuda11,-no_rocm,-benchmark-test,-tpu,-v1only -test:nonpip_filters_large --build_tag_filters=gpu,requires-gpu,-no_gpu,-no_oss,-oss_serial,-no_cuda11,-no_rocm +test:nonpip_filters_large --test_tag_filters=gpu,requires-gpu,-no_gpu,-no_oss,-oss_serial,-no_cuda11,-cuda-only,-benchmark-test,-tpu,-v1only +test:nonpip_filters_large --build_tag_filters=gpu,requires-gpu,-no_gpu,-no_oss,-oss_serial,-no_cuda11,-cuda-only test:nonpip_filters_large --test_lang_filters=py --flaky_test_attempts=2 --test_size_filters=small,medium,large test:nonpip_large --config=nonpip_filters_large -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/compiler/xrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... # "nonpip_filter_multi_gpu" will run a defined set of multi-gpu tests -test:nonpip_filters_multi_gpu --test_tag_filters=-no_gpu,-no_rocm -test:nonpip_filters_multi_gpu --build_tag_filters=-no_gpu,-no_rocm +test:nonpip_filters_multi_gpu --test_tag_filters=-no_gpu,-cuda-only +test:nonpip_filters_multi_gpu --build_tag_filters=-no_gpu,-cuda-only test:nonpip_filters_multi_gpu --test_lang_filters=py --flaky_test_attempts=2 --test_size_filters=small,medium,large --test_env=TF_PER_DEVICE_MEMORY_LIMIT_MB=2048 test:nonpip_multi_gpu --config=nonpip_filters_multi_gpu -- \ //tensorflow/core/nccl:nccl_manager_test_2gpu \ @@ -75,12 +75,12 @@ build:build_event_export --build_event_json_file=/tf/pkg/bep.json build:rbe --config=rbe_linux_cuda # For continuous builds -test:pycpp_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-benchmark-test,-v1only,gpu,-no_gpu,-no_gpu_presubmit,-no_cuda11,-no_rocm -test:pycpp_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-benchmark-test,-v1only,gpu,-no_gpu,-no_gpu_presubmit,-no_cuda11,-no_rocm +test:pycpp_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-benchmark-test,-v1only,gpu,-no_gpu,-no_gpu_presubmit,-no_cuda11,-cuda-only +test:pycpp_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-benchmark-test,-v1only,gpu,-no_gpu,-no_gpu_presubmit,-no_cuda11,-cuda-only test:pycpp_filters --test_lang_filters=cc,py --test_size_filters=small,medium test:pycpp --config=pycpp_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/dtensor/python/tests:multi_client_test_2gpus -//tensorflow/dtensor/python/tests:multi_client_test_nccl_2gpus -//tensorflow/python/distribute/experimental:multi_worker_mirrored_strategy_test_2gpus # For XLA (rocm) -test:xla_cpp_filters --test_tag_filters=gpu,requires-gpu-amd,-requires-gpu-nvidia,-no_oss,-oss_excluded,-oss_serial,-no_gpu,-no_rocm,-requires-gpu-sm60,-requires-gpu-sm60-only,-requires-gpu-sm70,-requires-gpu-sm70-only,-requires-gpu-sm80,-requires-gpu-sm80-only,-requires-gpu-sm86,-requires-gpu-sm86-only,-requires-gpu-sm89,-requires-gpu-sm89-only,-requires-gpu-sm90,-requires-gpu-sm90-only --keep_going -test:xla_cpp_filters --build_tag_filters=gpu,requires-gpu-amd,-requires-gpu-nvidia,-no_oss,-oss_excluded,-oss_serial,-no_gpu,-no_rocm,-requires-gpu-sm60,-requires-gpu-sm60-only,-requires-gpu-sm70,-requires-gpu-sm70-only,-requires-gpu-sm80,-requires-gpu-sm80-only,-requires-gpu-sm86,-requires-gpu-sm86-only,-requires-gpu-sm89,-requires-gpu-sm89-only,-requires-gpu-sm90,-requires-gpu-sm90-only +test:xla_cpp_filters --test_tag_filters=gpu,requires-gpu-amd,-requires-gpu-nvidia,-no_oss,-oss_excluded,-oss_serial,-no_gpu,-cuda-only,-requires-gpu-sm60,-requires-gpu-sm60-only,-requires-gpu-sm70,-requires-gpu-sm70-only,-requires-gpu-sm80,-requires-gpu-sm80-only,-requires-gpu-sm86,-requires-gpu-sm86-only,-requires-gpu-sm89,-requires-gpu-sm89-only,-requires-gpu-sm90,-requires-gpu-sm90-only --keep_going +test:xla_cpp_filters --build_tag_filters=gpu,requires-gpu-amd,-requires-gpu-nvidia,-no_oss,-oss_excluded,-oss_serial,-no_gpu,-cuda-only,-requires-gpu-sm60,-requires-gpu-sm60-only,-requires-gpu-sm70,-requires-gpu-sm70-only,-requires-gpu-sm80,-requires-gpu-sm80-only,-requires-gpu-sm86,-requires-gpu-sm86-only,-requires-gpu-sm89,-requires-gpu-sm89-only,-requires-gpu-sm90,-requires-gpu-sm90-only test:xla_cpp --config=xla_cpp_filters -- //xla/... //build_tools/... diff --git a/tensorflow/tools/tf_sig_build_dockerfiles/devel.usertools/gpu_gcc.bazelrc b/tensorflow/tools/tf_sig_build_dockerfiles/devel.usertools/gpu_gcc.bazelrc index 936a84f77847b8..476f33ca75c1bf 100644 --- a/tensorflow/tools/tf_sig_build_dockerfiles/devel.usertools/gpu_gcc.bazelrc +++ b/tensorflow/tools/tf_sig_build_dockerfiles/devel.usertools/gpu_gcc.bazelrc @@ -66,20 +66,20 @@ test --test_summary=short # Pass --config=nonpip to run the same suite of tests. If you want to run just # one test for investigation, you don't need --config=nonpip; just run the # bazel test invocation as normal. -test:nonpip_filters --test_tag_filters=gpu,requires-gpu,-no_gpu,-no_oss,-oss_excluded,-oss_serial,-no_cuda11,-no_rocm,-no_oss_py38,-no_oss_py39,-no_oss_py310 -test:nonpip_filters --build_tag_filters=gpu,requires-gpu,-no_gpu,-no_oss,-oss_excluded,-oss_serial,-no_cuda11,-no_rocm,-no_oss_py38,-no_oss_py39,-no_oss_py310 +test:nonpip_filters --test_tag_filters=gpu,requires-gpu,-no_gpu,-no_oss,-oss_excluded,-oss_serial,-no_cuda11,-cuda-only,-no_oss_py38,-no_oss_py39,-no_oss_py310 +test:nonpip_filters --build_tag_filters=gpu,requires-gpu,-no_gpu,-no_oss,-oss_excluded,-oss_serial,-no_cuda11,-cuda-only,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:nonpip_filters --test_lang_filters=py --test_size_filters=small,medium test:nonpip --config=nonpip_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/compiler/xrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... # "nonpip_large" will run tests marked as large as well -test:nonpip_filters_large --test_tag_filters=gpu,requires-gpu,-no_gpu,-no_oss,-oss_serial,-no_cuda11,-no_rocm,-benchmark-test,-tpu,-v1only -test:nonpip_filters_large --build_tag_filters=gpu,requires-gpu,-no_gpu,-no_oss,-oss_serial,-no_cuda11,-no_rocm +test:nonpip_filters_large --test_tag_filters=gpu,requires-gpu,-no_gpu,-no_oss,-oss_serial,-no_cuda11,-cuda-only,-benchmark-test,-tpu,-v1only +test:nonpip_filters_large --build_tag_filters=gpu,requires-gpu,-no_gpu,-no_oss,-oss_serial,-no_cuda11,-cuda-only test:nonpip_filters_large --test_lang_filters=py --flaky_test_attempts=2 --test_size_filters=small,medium,large test:nonpip_large --config=nonpip_filters_large -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/compiler/xrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... # "nonpip_filter_multi_gpu" will run a defined set of multi-gpu tests -test:nonpip_filters_multi_gpu --test_tag_filters=-no_gpu,-no_rocm -test:nonpip_filters_multi_gpu --build_tag_filters=-no_gpu,-no_rocm +test:nonpip_filters_multi_gpu --test_tag_filters=-no_gpu,-cuda-only +test:nonpip_filters_multi_gpu --build_tag_filters=-no_gpu,-cuda-only test:nonpip_filters_multi_gpu --test_lang_filters=py --flaky_test_attempts=2 --test_size_filters=small,medium,large --test_env=TF_PER_DEVICE_MEMORY_LIMIT_MB=2048 test:nonpip_multi_gpu --config=nonpip_filters_multi_gpu -- \ //tensorflow/core/nccl:nccl_manager_test_2gpu \ @@ -121,8 +121,8 @@ test:pip_venv --action_env PYTHON_LIB_PATH="/bazel_pip/lib/python3/site-packages test:pip_venv --python_path="/bazel_pip/bin/python3" test:pip_venv --define=no_tensorflow_py_deps=true # Yes, we don't exclude the gpu tests on pip for some reason. -test:pip_filters --test_tag_filters=gpu,requires-gpu,-no_gpu,-no_oss,-oss_excluded,-oss_serial,-no_cuda11,-no_pip,-nopip,-no_rocm,-no_oss_py38,-no_oss_py39,-no_oss_py310 -test:pip_filters --build_tag_filters=gpu,requires-gpu,-no_gpu,-no_oss,-oss_excluded,-oss_serial,-no_cuda11,-no_pip,-nopip,-no_rocm,-no_oss_py38,-no_oss_py39,-no_oss_py310 +test:pip_filters --test_tag_filters=gpu,requires-gpu,-no_gpu,-no_oss,-oss_excluded,-oss_serial,-no_cuda11,-no_pip,-nopip,-cuda-only,-no_oss_py38,-no_oss_py39,-no_oss_py310 +test:pip_filters --build_tag_filters=gpu,requires-gpu,-no_gpu,-no_oss,-oss_excluded,-oss_serial,-no_cuda11,-no_pip,-nopip,-cuda-only,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:pip_filters --test_lang_filters=py --test_size_filters=small,medium test:pip --config=pip_venv --config=pip_filters -- //bazel_pip/tensorflow/... -//bazel_pip/tensorflow/python/integration_testing/... -//bazel_pip/tensorflow/compiler/tf2tensorrt/... -//bazel_pip/tensorflow/compiler/xrt/... -//bazel_pip/tensorflow/core/tpu/... -//bazel_pip/tensorflow/lite/... -//tensorflow/tools/toolchains/... @@ -166,17 +166,17 @@ build:rbe --repo_env=TF_NCCL_CONFIG_REPO="@sigbuild-r2.14_config_nccl" build:rbe --repo_env=TF_PYTHON_CONFIG_REPO="@sigbuild-r2.14_config_python" # For continuous builds -test:pycpp_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-benchmark-test,-v1only,gpu,-no_gpu,-no_gpu_presubmit,-no_cuda11,-no_rocm -test:pycpp_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-benchmark-test,-v1only,gpu,-no_gpu,-no_gpu_presubmit,-no_cuda11,-no_rocm +test:pycpp_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-benchmark-test,-v1only,gpu,-no_gpu,-no_gpu_presubmit,-no_cuda11,-cuda-only +test:pycpp_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-benchmark-test,-v1only,gpu,-no_gpu,-no_gpu_presubmit,-no_cuda11,-cuda-only test:pycpp_filters --test_lang_filters=cc,py --test_size_filters=small,medium test:pycpp --config=pycpp_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/compiler/xrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -test:pycpp_large_filters --test_tag_filters=-no_oss,-oss_serial,-benchmark-test,-v1only,gpu,-no_gpu,-no_gpu_presubmit,-no_cuda11,-no_rocm -test:pycpp_large_filters --build_tag_filters=-no_oss,-oss_serial,-benchmark-test,-v1only,gpu,-no_gpu,-no_gpu_presubmit,-no_cuda11,-no_rocm +test:pycpp_large_filters --test_tag_filters=-no_oss,-oss_serial,-benchmark-test,-v1only,gpu,-no_gpu,-no_gpu_presubmit,-no_cuda11,-cuda-only +test:pycpp_large_filters --build_tag_filters=-no_oss,-oss_serial,-benchmark-test,-v1only,gpu,-no_gpu,-no_gpu_presubmit,-no_cuda11,-cuda-only test:pycpp_large_filters --test_lang_filters=cc,py --flaky_test_attempts=3 --test_size_filters=small,medium,large test:pycpp_large --config=pycpp_large_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/compiler/xrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... # For XLA (rocm) -test:xla_cpp_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,gpu,requires-gpu,-no_gpu,-no_rocm --keep_going -test:xla_cpp_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,gpu,requires-gpu,-no_gpu,-no_rocm +test:xla_cpp_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,gpu,requires-gpu,-no_gpu,-cuda-only --keep_going +test:xla_cpp_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,gpu,requires-gpu,-no_gpu,-cuda-only test:xla_cpp --config=xla_cpp_filters -- //xla/... //build_tools/... From de1e449a903ffc8053b220c0950ce407ed95136e Mon Sep 17 00:00:00 2001 From: Milica Makevic Date: Mon, 7 Oct 2024 09:27:58 +0000 Subject: [PATCH 470/483] Bring back cuda_only_cc_library --- .../xla/xla/stream_executor/build_defs.bzl | 30 +++++++++++++++ .../xla/xla/stream_executor/cuda/BUILD | 38 +++++++++++-------- 2 files changed, 52 insertions(+), 16 deletions(-) diff --git a/third_party/xla/xla/stream_executor/build_defs.bzl b/third_party/xla/xla/stream_executor/build_defs.bzl index 3204b886c651ff..5c153b733ba0c7 100644 --- a/third_party/xla/xla/stream_executor/build_defs.bzl +++ b/third_party/xla/xla/stream_executor/build_defs.bzl @@ -1,5 +1,6 @@ """Configurations for StreamExecutor builds""" +load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda_is_configured") load( "@local_config_rocm//rocm:build_defs.bzl", _if_cuda_or_rocm = "if_cuda_or_rocm", @@ -63,5 +64,34 @@ def gpu_only_cc_library(name, tags = [], **kwargs): target_compatible_with = kwargs.get("target_compatible_with"), ) +def cuda_only_cc_library(name, tags = [], **kwargs): + """A library that only gets compiled when CUDA is configured, otherwise it's an empty target. + Args: + name: Name of the target + tags: Tags being applied to the implementation target + **kwargs: Accepts all arguments that a `cc_library` would also accept + """ + if not native.package_name().startswith("xla/stream_executor"): + fail("cuda_only_cc_library may only be used in `xla/stream_executor/...`.") + + cc_library( + name = "%s_non_cuda" % name, + tags = ["manual"], + ) + cc_library( + name = "%s_cuda_only" % name, + tags = tags + ["manual", "cuda-only"], + **kwargs + ) + native.alias( + name = name, + actual = if_cuda_is_configured(":%s_cuda_only" % name, ":%s_non_cuda" % name), + visibility = kwargs.get("visibility"), + compatible_with = kwargs.get("compatible_with"), + restricted_to = kwargs.get("restricted_to"), + target_compatible_with = kwargs.get("target_compatible_with"), + ) + + def stream_executor_build_defs_bzl_deps(): return [] diff --git a/third_party/xla/xla/stream_executor/cuda/BUILD b/third_party/xla/xla/stream_executor/cuda/BUILD index bb96c2ab09824f..68fdd4a79ad45b 100644 --- a/third_party/xla/xla/stream_executor/cuda/BUILD +++ b/third_party/xla/xla/stream_executor/cuda/BUILD @@ -10,14 +10,20 @@ load( ) load( "@local_tsl//tsl/platform/default:cuda_build_defs.bzl", + "if_cuda_is_configured", "if_cuda_newer_than", ) load( "//xla:xla.bzl", "xla_cc_test", ) +load( + "//xla/service/gpu:build_defs.bzl", + "gpu_kernel_library", +) load( "//xla/stream_executor:build_defs.bzl", + "cuda_only_cc_library", "stream_executor_friends", "tf_additional_cuda_platform_deps", "tf_additional_cudnn_plugin_copts", @@ -81,7 +87,7 @@ cc_library( deps = ["//xla/stream_executor:platform"], ) -cc_library( +cuda_only_cc_library( name = "cuda_platform", srcs = ["cuda_platform.cc"], hdrs = ["cuda_platform.h"], @@ -121,7 +127,7 @@ cc_library( alwayslink = True, # Registers itself with the PlatformManager. ) -cc_library( +cuda_only_cc_library( name = "cuda_diagnostics", srcs = ["cuda_diagnostics.cc"], hdrs = ["cuda_diagnostics.h"], @@ -159,7 +165,7 @@ cc_library( ), ) -cc_library( +cuda_only_cc_library( name = "cuda_driver", srcs = ["cuda_driver.cc"], hdrs = ["cuda_driver.h"], @@ -204,7 +210,7 @@ cc_library( ], ) -cc_library( +cuda_only_cc_library( name = "cuda_status", srcs = ["cuda_status.cc"], hdrs = ["cuda_status.h"], @@ -220,7 +226,7 @@ cc_library( ], ) -cc_library( +cuda_only_cc_library( name = "cuda_runtime", srcs = ["cuda_runtime.cc"], hdrs = ["cuda_runtime.h"], @@ -239,7 +245,7 @@ cc_library( ], ) -cc_library( +cuda_only_cc_library( name = "cuda_collectives", hdrs = ["cuda_collectives.h"], tags = [ @@ -339,7 +345,7 @@ xla_test( ], ) -cc_library( +cuda_only_cc_library( name = "cublas_lt_header", hdrs = [ "cuda_blas_lt.h", @@ -365,7 +371,7 @@ cc_library( ], ) -cc_library( +cuda_only_cc_library( name = "cublas_plugin", srcs = [ "cuda_blas.cc", @@ -431,7 +437,7 @@ cc_library( alwayslink = True, ) -cc_library( +cuda_only_cc_library( name = "cuda_blas_utils", srcs = ["cuda_blas_utils.cc"], hdrs = ["cuda_blas_utils.h"], @@ -451,7 +457,7 @@ cc_library( ], ) -cc_library( +cuda_only_cc_library( name = "cufft_plugin", srcs = ["cuda_fft.cc"], hdrs = ["cuda_fft.h"], @@ -484,7 +490,7 @@ cc_library( alwayslink = True, ) -cuda_library( +gpu_kernel_library( name = "delay_kernel_cuda", srcs = [ "delay_kernel.h", @@ -508,7 +514,7 @@ cuda_library( ], ) -cc_library( +cuda_only_cc_library( name = "cudnn_plugin", srcs = ["cuda_dnn.cc"], hdrs = ["cuda_dnn.h"], @@ -569,7 +575,7 @@ cc_library( alwayslink = True, ) -cc_library( +cuda_only_cc_library( name = "cuda_kernel", srcs = ["cuda_kernel.cc"], hdrs = ["cuda_kernel.h"], @@ -637,7 +643,7 @@ cc_library( ], ) -cc_library( +cuda_only_cc_library( name = "cuda_event", srcs = ["cuda_event.cc"], hdrs = ["cuda_event.h"], @@ -880,7 +886,7 @@ xla_cc_test( ], ) -cc_library( +cuda_only_cc_library( name = "cuda_asm_compiler", srcs = ["cuda_asm_compiler.cc"], hdrs = ["cuda_asm_compiler.h"], @@ -948,7 +954,7 @@ cc_library( ], ) -cc_library( +cuda_only_cc_library( name = "cuda_executor", srcs = [ "cuda_executor.cc", From b4e5e53a176a581606faf5615019d6acb16a26fb Mon Sep 17 00:00:00 2001 From: Milica Makevic Date: Mon, 7 Oct 2024 09:33:58 +0000 Subject: [PATCH 471/483] Fix build issues caused by making cuda_deps actually cuda specific --- tensorflow/core/grappler/BUILD | 6 ++---- tensorflow/tensorflow.bzl | 7 +++---- 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/tensorflow/core/grappler/BUILD b/tensorflow/core/grappler/BUILD index aa6834aafc5d8a..d8b6d1b228ef61 100644 --- a/tensorflow/core/grappler/BUILD +++ b/tensorflow/core/grappler/BUILD @@ -76,15 +76,13 @@ tf_cuda_library( name = "devices", srcs = ["devices.cc"], hdrs = ["devices.h"], - cuda_deps = [ - "@local_xla//xla/stream_executor/gpu:gpu_init", - "//tensorflow/core/platform:stream_executor", - ], visibility = ["//visibility:public"], deps = [ "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "@com_google_absl//absl/log", + "@local_xla//xla/stream_executor/gpu:gpu_init", + "//tensorflow/core/platform:stream_executor", ], ) diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl index aa40b33739a7e3..59674b9e0b660b 100644 --- a/tensorflow/tensorflow.bzl +++ b/tensorflow/tensorflow.bzl @@ -2027,7 +2027,7 @@ def tf_kernel_library( [prefix + "*impl.h"], exclude = [prefix + "*test*", prefix + "*.cu.h"], ) - cuda_deps = [clean_dep("//tensorflow/core:gpu_lib")] + cuda_or_rocm_deps = [clean_dep("//tensorflow/core:gpu_lib")] if gpu_srcs: for gpu_src in gpu_srcs: if gpu_src.endswith(".cc") and not gpu_src.endswith(".cu.cc"): @@ -2040,7 +2040,7 @@ def tf_kernel_library( copts = gpu_copts, **kwargs ) - cuda_deps.extend([":" + name + "_gpu"]) + cuda_or_rocm_deps.extend([":" + name + "_gpu"]) kwargs["tags"] = kwargs.get("tags", []) + [ "req_dep=%s" % clean_dep("//tensorflow/core:gpu_lib"), "req_dep=@local_config_cuda//cuda:cuda_headers", @@ -2051,10 +2051,9 @@ def tf_kernel_library( hdrs = hdrs, textual_hdrs = textual_hdrs, copts = copts, - cuda_deps = cuda_deps + gpu_deps, linkstatic = 1, # Needed since alwayslink is broken in bazel b/27630669 alwayslink = alwayslink, - deps = deps, + deps = deps + gpu_deps + cuda_or_rocm_deps, compatible_with = compatible_with, **kwargs ) From ce1b20f1aa8f9d16665bceb12d6e73f242138144 Mon Sep 17 00:00:00 2001 From: Milica Makevic Date: Mon, 7 Oct 2024 09:51:46 +0000 Subject: [PATCH 472/483] Enable gpu_kernel_tiling_test --- third_party/xla/xla/service/gpu/tests/BUILD | 1 - 1 file changed, 1 deletion(-) diff --git a/third_party/xla/xla/service/gpu/tests/BUILD b/third_party/xla/xla/service/gpu/tests/BUILD index 0577e171ce2afa..7ec7cef84993c5 100644 --- a/third_party/xla/xla/service/gpu/tests/BUILD +++ b/third_party/xla/xla/service/gpu/tests/BUILD @@ -399,7 +399,6 @@ xla_test( "gpu_p100", "gpu_amd_any", ] + if_oss(["gpu_any"]), - tags = ["no_rocm"], # TODO(rocm): weekly sync 24-08-20 deps = [ ":gpu_codegen_test", "//xla:error_spec", From 874c457af7ed851372994166cdfa211ba3028e95 Mon Sep 17 00:00:00 2001 From: Milica Makevic Date: Tue, 8 Oct 2024 07:03:59 +0000 Subject: [PATCH 473/483] Fix ops_testutils deps issue --- tensorflow/core/kernels/BUILD | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index ed240a3e54936e..81decd9f1019f9 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -467,14 +467,12 @@ tf_cuda_library( testonly = 1, srcs = ["ops_testutil.cc"], hdrs = ["ops_testutil.h"], - cuda_deps = [ - "//tensorflow/core:gpu_lib", - "//tensorflow/core:gpu_runtime", - ], deps = [ "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", + "//tensorflow/core:gpu_lib", + "//tensorflow/core:gpu_runtime", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", From 451d15bb5fe71f4596be0e9efe67175cf6e58919 Mon Sep 17 00:00:00 2001 From: Milica Makevic Date: Tue, 8 Oct 2024 08:30:46 +0000 Subject: [PATCH 474/483] Enable some of the unit tests --- tensorflow/compiler/mlir/lite/debug/BUILD | 1 - tensorflow/compiler/mlir/lite/tests/BUILD | 3 --- tensorflow/compiler/mlir/lite/tests/end2end/BUILD | 4 ---- tensorflow/compiler/mlir/quantization/stablehlo/BUILD | 2 +- tensorflow/compiler/mlir/tensorflow/tests/BUILD | 2 -- tensorflow/compiler/tests/BUILD | 7 +++---- tensorflow/core/kernels/mlir_generated/BUILD | 3 +-- tensorflow/dtensor/python/tests/BUILD | 5 ++--- tensorflow/lite/python/kernel_tests/signal/BUILD | 1 - tensorflow/python/client/BUILD | 2 -- tensorflow/python/saved_model/BUILD | 1 - third_party/xla/xla/service/gpu/tests/BUILD | 3 +-- 12 files changed, 8 insertions(+), 26 deletions(-) diff --git a/tensorflow/compiler/mlir/lite/debug/BUILD b/tensorflow/compiler/mlir/lite/debug/BUILD index c8bfd87e378aa1..4cd1b6e7eddb35 100644 --- a/tensorflow/compiler/mlir/lite/debug/BUILD +++ b/tensorflow/compiler/mlir/lite/debug/BUILD @@ -44,7 +44,6 @@ cc_library( tf_cc_test( name = "debug_test", srcs = ["debug_test.cc"], - tags = ["no_rocm"], deps = [ ":debug", ":debug_options_proto_cc", diff --git a/tensorflow/compiler/mlir/lite/tests/BUILD b/tensorflow/compiler/mlir/lite/tests/BUILD index 72efe28296cf55..9ec3f85c3545cd 100644 --- a/tensorflow/compiler/mlir/lite/tests/BUILD +++ b/tensorflow/compiler/mlir/lite/tests/BUILD @@ -19,9 +19,6 @@ glob_lit_tests( "raise-custom-ops.mlir": "medium", }, tags_override = { - "legalize-tf.mlir": ["no_rocm"], - "optimize.mlir": ["no_rocm"], - "prepare-tf.mlir": ["no_rocm"], "const-fold.mlir": ["no_mac_arm64"], }, test_file_exts = ["mlir"], diff --git a/tensorflow/compiler/mlir/lite/tests/end2end/BUILD b/tensorflow/compiler/mlir/lite/tests/end2end/BUILD index 529fe381fa3bfe..f3a41720ae33e1 100644 --- a/tensorflow/compiler/mlir/lite/tests/end2end/BUILD +++ b/tensorflow/compiler/mlir/lite/tests/end2end/BUILD @@ -15,10 +15,6 @@ glob_lit_tests( size_override = { "quant_stats.pbtxt": "medium", }, - tags_override = { - "add.pbtxt": ["no_rocm"], - "fake_quant_per_channel.pbtxt": ["no_rocm"], - }, test_file_exts = [ "pbtxt", ], diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/BUILD b/tensorflow/compiler/mlir/quantization/stablehlo/BUILD index 9eea3596a84296..ba674527171041 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/BUILD +++ b/tensorflow/compiler/mlir/quantization/stablehlo/BUILD @@ -381,7 +381,7 @@ tf_cc_test( srcs = [ "passes/bridge/convert_tf_quant_to_mhlo_int_test.cc", ], - tags = ["nomac", "no_rocm"], # TODO(b/297362678): re-enable mac test. + tags = ["nomac"], # TODO(b/297362678): re-enable mac test. deps = [ ":bridge_passes", "//tensorflow/compiler/mlir/quantization/tensorflow/cc:constant_fold", diff --git a/tensorflow/compiler/mlir/tensorflow/tests/BUILD b/tensorflow/compiler/mlir/tensorflow/tests/BUILD index 9ac51ab0ecb6aa..a446200b75bfee 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/tests/BUILD @@ -17,8 +17,6 @@ glob_lit_tests( "layout_optimization_to_nhwc.mlir": "medium", }, tags_override = { - "optimize.mlir": ["no_rocm"], - "tf_optimize.mlir": ["no_rocm"], "tf-reduce-identity.mlir": ["no_windows"], }, test_file_exts = ["mlir"], diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index e4fb5a414c614f..cc4fcf9c87c0ee 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -907,7 +907,6 @@ tf_xla_py_strict_test( shard_count = 12, tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip - "no_rocm", "optonly", ], deps = [ @@ -1791,7 +1790,7 @@ tf_xla_py_strict_test( python_version = "PY3", shard_count = 20, tags = [ - "no_rocm", + "cuda-only", "no_aarch64", # TODO(b/348125886) "no_cuda_asan", # times out "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -2866,7 +2865,7 @@ tf_cuda_cc_test( tags = [ "config-cuda-only", "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip - "no_rocm", # ROCmSoftwarePlatform #958 + "cuda-only", # ROCmSoftwarePlatform #958 "noasan", # TODO(b/201651800) "requires-gpu-nvidia", ] + tf_cuda_tests_tags(), @@ -2887,7 +2886,7 @@ tf_cuda_cc_test( tags = [ "config-cuda-only", "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip - "no_rocm", # ROCmSoftwarePlatform #958 + "cuda-only", # ROCmSoftwarePlatform #958 "noasan", # TODO(b/201651800) "requires-gpu-nvidia", ] + tf_cuda_tests_tags(), diff --git a/tensorflow/core/kernels/mlir_generated/BUILD b/tensorflow/core/kernels/mlir_generated/BUILD index b88ee74c0999ef..5f306f29b9146e 100644 --- a/tensorflow/core/kernels/mlir_generated/BUILD +++ b/tensorflow/core/kernels/mlir_generated/BUILD @@ -455,7 +455,6 @@ tf_cuda_cc_test( tags = tf_cuda_tests_tags() + [ "no_cuda", # TODO(b/196608406): re-enable "no_cuda_asan", # TODO(b/171341759): re-enable. - "no_rocm", # fail on CI ], deps = [ ":base_ops_test", @@ -545,7 +544,7 @@ tf_cuda_cc_test( tags = tf_cuda_tests_tags() + [ "no_cuda", # TODO(b/196608406): re-enable "no_cuda_asan", # TODO(b/171341759): re-enable. - "no_rocm", + "cuda-only", ], deps = [ ":base_binary_ops_test", diff --git a/tensorflow/dtensor/python/tests/BUILD b/tensorflow/dtensor/python/tests/BUILD index 8e61e6083a849d..c358acd9755d6d 100644 --- a/tensorflow/dtensor/python/tests/BUILD +++ b/tensorflow/dtensor/python/tests/BUILD @@ -555,7 +555,6 @@ dtensor_test( "NCCL_P2P_DISABLE": "1", # FIXME(b/251183104): p2p detection in cuda 10.1+ is broken. }, tags = [ - "no_rocm", "no_windows", "nosan", # b/195537906 ], @@ -749,7 +748,7 @@ dtensor_test( TPU_V3_DONUT_BACKEND: 32, }, tags = [ - "no_rocm", + "cuda-only", ], deps = [ ":test_util", @@ -804,7 +803,7 @@ dtensor_test( }, tags = [ "no_oss_py38", # TODO(b/267017937) - "no_rocm", + "cuda-only", ], deps = [ ":test_util", diff --git a/tensorflow/lite/python/kernel_tests/signal/BUILD b/tensorflow/lite/python/kernel_tests/signal/BUILD index a6128e6be32f54..ec560911cc686d 100644 --- a/tensorflow/lite/python/kernel_tests/signal/BUILD +++ b/tensorflow/lite/python/kernel_tests/signal/BUILD @@ -24,7 +24,6 @@ cuda_py_strict_test( python_version = "PY3", shard_count = 4, tags = [ - "no_rocm", "no_windows_gpu", ], deps = [ diff --git a/tensorflow/python/client/BUILD b/tensorflow/python/client/BUILD index 1c7805a4d5158a..f8a5c6a1627a55 100644 --- a/tensorflow/python/client/BUILD +++ b/tensorflow/python/client/BUILD @@ -437,9 +437,7 @@ tf_py_strict_test( python_version = "PY3", tags = [ "no_gpu", - "no_rocm", "no_windows", - "no_rocm", ], deps = [ ":session", diff --git a/tensorflow/python/saved_model/BUILD b/tensorflow/python/saved_model/BUILD index e909b79797f022..12ad8fadc6c0e0 100644 --- a/tensorflow/python/saved_model/BUILD +++ b/tensorflow/python/saved_model/BUILD @@ -597,7 +597,6 @@ cuda_py_strict_test( tags = [ "no_gpu", # TODO(b/136560979): flaky "no_mac", # TODO(b/124822121): Re-enable this test. - "no_rocm", ], deps = [ ":load", diff --git a/third_party/xla/xla/service/gpu/tests/BUILD b/third_party/xla/xla/service/gpu/tests/BUILD index 7ec7cef84993c5..6c0edab8c37d2f 100644 --- a/third_party/xla/xla/service/gpu/tests/BUILD +++ b/third_party/xla/xla/service/gpu/tests/BUILD @@ -74,7 +74,6 @@ xla_test( local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), tags = [ "notsan", # TODO(b/345034145): Fix tsan error. - "no_rocm", # TODO(rocm): sync 24-08-20 ], deps = if_gpu_is_configured( #keep sorted @@ -322,7 +321,7 @@ xla_test( "gpu_a100", "gpu_v100", ], - tags = ["no_rocm"], # TODO(rocm) 240729 Test checks only for cuda capability + tags = ["cuda-only"], # TODO(rocm) 240729 Test checks only for cuda capability deps = [ ":gpu_codegen_test", "//xla:shape_util", From 8fc123214df0e473a758e50e30fb183cd7a237e9 Mon Sep 17 00:00:00 2001 From: Milica Makevic Date: Tue, 8 Oct 2024 08:31:23 +0000 Subject: [PATCH 475/483] Change remaining "no_rocm" tags to "cuda-only" --- .../mlir/tensorflow/tests/tf_saved_model/BUILD | 2 +- tensorflow/core/kernels/mkl/BUILD | 2 +- tensorflow/core/nccl/BUILD | 2 +- tensorflow/core/profiler/backends/gpu/BUILD | 2 +- tensorflow/core/util/autotune_maps/BUILD | 2 +- tensorflow/dtensor/cc/BUILD | 2 +- tensorflow/python/compiler/tensorrt/test/BUILD | 2 +- tensorflow/python/debug/lib/BUILD | 4 ++-- tensorflow/python/distribute/BUILD | 4 ++-- tensorflow/python/feature_column/BUILD | 4 ++-- tensorflow/python/framework/BUILD | 2 +- tensorflow/python/kernel_tests/image_ops/BUILD | 2 +- tensorflow/python/kernel_tests/linalg/BUILD | 2 +- tensorflow/python/kernel_tests/math_ops/BUILD | 2 +- tensorflow/python/kernel_tests/nn_ops/BUILD | 6 +++--- tensorflow/python/kernel_tests/sparse_ops/BUILD | 2 +- tensorflow/python/ops/BUILD | 13 ++++++------- tensorflow/python/ops/numpy_ops/tests/BUILD | 2 +- tensorflow/python/ops/parallel_for/BUILD | 2 +- tensorflow/python/tools/BUILD | 8 ++++---- tensorflow/python/training/BUILD | 2 +- tensorflow/tools/docs/BUILD | 4 ++-- third_party/xla/xla/service/gpu/BUILD | 10 +++------- .../xla/xla/service/gpu/fusions/triton/BUILD | 6 +++--- third_party/xla/xla/tests/BUILD | 7 ++----- 25 files changed, 44 insertions(+), 52 deletions(-) diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/BUILD b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/BUILD index f2f686f9822927..b80b275c560741 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/BUILD @@ -296,7 +296,7 @@ glob_lit_tests( default_tags = [ "no_mac", # TODO(b/191167848) "no_oss", # TODO(b/190855110) - "no_rocm", + "cuda-only", ], driver = "@llvm-project//mlir:run_lit.sh", exclude = [ diff --git a/tensorflow/core/kernels/mkl/BUILD b/tensorflow/core/kernels/mkl/BUILD index 2f099cc93b70b0..ba75ffc4e6b429 100644 --- a/tensorflow/core/kernels/mkl/BUILD +++ b/tensorflow/core/kernels/mkl/BUILD @@ -427,7 +427,7 @@ tf_cc_test_mkl( size = "small", srcs = ["mkl_fused_batch_norm_op_test.cc"], linkstatic = 1, - tags = ["no_rocm"], # fails on AMD Rome CPUs as of 2021-03-29 + tags = ["cuda-only"], # fails on AMD Rome CPUs as of 2021-03-29 deps = [ ":mkl_conv_op", ":mkl_fused_batch_norm_op", diff --git a/tensorflow/core/nccl/BUILD b/tensorflow/core/nccl/BUILD index 99dc7559d59982..b8471de41fe6fc 100644 --- a/tensorflow/core/nccl/BUILD +++ b/tensorflow/core/nccl/BUILD @@ -62,7 +62,7 @@ tf_cuda_cc_test( "multi_gpu", "no_oss", "notap", - "no_rocm", # flaky on CI as of 2022-05-30 + "cuda-only", # flaky on CI as of 2022-05-30 ], deps = [ "//tensorflow/core:test", diff --git a/tensorflow/core/profiler/backends/gpu/BUILD b/tensorflow/core/profiler/backends/gpu/BUILD index c803eb6b16d1b4..14a102d0af5221 100644 --- a/tensorflow/core/profiler/backends/gpu/BUILD +++ b/tensorflow/core/profiler/backends/gpu/BUILD @@ -18,7 +18,7 @@ tf_cuda_cc_test( tags = tf_cuda_tests_tags() + [ "gpu_cupti", "nomac", - "no_rocm", # flaky on CI + "cuda-only", # flaky on CI ], deps = [ "//tensorflow/cc:cc_ops", diff --git a/tensorflow/core/util/autotune_maps/BUILD b/tensorflow/core/util/autotune_maps/BUILD index 1d5ce0d8676788..42a3388c3a22ed 100644 --- a/tensorflow/core/util/autotune_maps/BUILD +++ b/tensorflow/core/util/autotune_maps/BUILD @@ -193,7 +193,7 @@ tf_cuda_only_cc_test( size = "small", srcs = ["autotune_serialize_test.cc"], features = ["-layering_check"], - tags = ["no_rocm"], + tags = ["cuda-only"], deps = [ ":autotune_serialize", ":conv_autotune_maps", diff --git a/tensorflow/dtensor/cc/BUILD b/tensorflow/dtensor/cc/BUILD index cf69454562b81e..39d1ea0c2e88f2 100644 --- a/tensorflow/dtensor/cc/BUILD +++ b/tensorflow/dtensor/cc/BUILD @@ -224,7 +224,7 @@ tf_kernel_library( "dtensor_tpu_kernels.cc", ], tags = [ - "no_rocm", + "cuda-only", "tpu", ], # Disable building of TPU kernels on non-TPU platforms. deps = [ diff --git a/tensorflow/python/compiler/tensorrt/test/BUILD b/tensorflow/python/compiler/tensorrt/test/BUILD index b99bcab6abe090..d19c57d4f21e67 100644 --- a/tensorflow/python/compiler/tensorrt/test/BUILD +++ b/tensorflow/python/compiler/tensorrt/test/BUILD @@ -75,7 +75,7 @@ filegroup( base_tags = [ "no_cuda_on_cpu_tap", - "no_rocm", + "cuda-only", "no_windows", "nomac", # TODO(b/303453873): Re-enable tests once TensorRT has been updated diff --git a/tensorflow/python/debug/lib/BUILD b/tensorflow/python/debug/lib/BUILD index 2ea56871c33a42..d4dce56e1aa53f 100644 --- a/tensorflow/python/debug/lib/BUILD +++ b/tensorflow/python/debug/lib/BUILD @@ -358,7 +358,7 @@ cuda_py_strict_test( shard_count = 4, tags = [ "no_windows", # TODO(b/142475891): Enable this test on Windows. - "no_rocm", #TODO(ROCm) Re-enable after issue is fixed. + "cuda-only", #TODO(ROCm) Re-enable after issue is fixed. ], xla_enable_strict_auto_jit = False, # Node names are different with autojit deps = [ @@ -390,7 +390,7 @@ cuda_py_strict_test( python_version = "PY3", tags = [ "no_windows_gpu", - "no_rocm", #TODO(ROCm) Re-enable after issue is fixed. + "cuda-only", #TODO(ROCm) Re-enable after issue is fixed. ], deps = [ ":debug_events_reader", diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD index a7960d75d0e0ca..09f7dcb2a52ccf 100644 --- a/tensorflow/python/distribute/BUILD +++ b/tensorflow/python/distribute/BUILD @@ -746,7 +746,7 @@ distribute_py_strict_test( "multi_and_single_gpu", "no_cuda_asan", # b/213388775 "no_oss", # b/241013307 - "no_rocm", + "cuda-only", "notap", # Flaky; TODO(b/289970206) ], tpu_tags = [ @@ -2581,7 +2581,7 @@ distribute_py_strict_test( "multi_and_single_gpu", "nomac", # TODO(b/201788023): Attempt MultiProcessCluster to fix this. "notpu", - "no_rocm", #times out + "cuda-only", #times out ], deps = [ ":distribute_lib", diff --git a/tensorflow/python/feature_column/BUILD b/tensorflow/python/feature_column/BUILD index 8bc042c20c7423..23a0f665fcec13 100644 --- a/tensorflow/python/feature_column/BUILD +++ b/tensorflow/python/feature_column/BUILD @@ -169,7 +169,7 @@ tf_py_strict_test( "no_cuda_on_cpu_tap", "no_oss", # TODO(b/206860622): Broken with numpy 1.20+ "no_pip", - "no_rocm", + "cuda-only", "no_windows", ], deps = [ @@ -215,7 +215,7 @@ tf_py_strict_test( "no_cuda_on_cpu_tap", "no_oss", # TODO(b/206860622): Broken with numpy 1.20+ "no_pip", - "no_rocm", + "cuda-only", "no_windows", ], deps = [":feature_column_v2_test_main_lib"], diff --git a/tensorflow/python/framework/BUILD b/tensorflow/python/framework/BUILD index d83d2a022b7828..5f8b951fb92e73 100644 --- a/tensorflow/python/framework/BUILD +++ b/tensorflow/python/framework/BUILD @@ -1462,7 +1462,7 @@ cuda_py_strict_test( python_version = "PY3", tags = [ "no_pip", # test_ops are not available in pip - "no_rocm", + "cuda-only", ], deps = [ ":config", diff --git a/tensorflow/python/kernel_tests/image_ops/BUILD b/tensorflow/python/kernel_tests/image_ops/BUILD index c63a8bccd5d5e9..0b2bc33fc7e844 100644 --- a/tensorflow/python/kernel_tests/image_ops/BUILD +++ b/tensorflow/python/kernel_tests/image_ops/BUILD @@ -140,7 +140,7 @@ cuda_py_strict_test( shard_count = 15, tags = [ "no_oss", # b/241024908 - "no_rocm", + "cuda-only", "nomac", # b/181799478 "notap", # b/31080670 ], diff --git a/tensorflow/python/kernel_tests/linalg/BUILD b/tensorflow/python/kernel_tests/linalg/BUILD index 3f6dfc1fe3e6a3..88cc811f08fa7d 100644 --- a/tensorflow/python/kernel_tests/linalg/BUILD +++ b/tensorflow/python/kernel_tests/linalg/BUILD @@ -262,7 +262,7 @@ cuda_py_strict_test( shard_count = 50, tags = [ "no_cuda11", # TODO(b/197522782): reenable test after fixing. - "no_rocm", # extremely slow, thousands of subtests, many triggering + "cuda-only", # extremely slow, thousands of subtests, many triggering # llvm invocations "optonly", # times out, b/79171797 ], diff --git a/tensorflow/python/kernel_tests/math_ops/BUILD b/tensorflow/python/kernel_tests/math_ops/BUILD index 111657b83114fb..23504efdeea273 100644 --- a/tensorflow/python/kernel_tests/math_ops/BUILD +++ b/tensorflow/python/kernel_tests/math_ops/BUILD @@ -263,7 +263,7 @@ cuda_py_strict_test( name = "cwise_ops_binary_test", size = "medium", srcs = ["cwise_ops_binary_test.py"], - tags = ["no_rocm"], #TODO(rocm): weekly sync 240919 + tags = ["cuda-only"], #TODO(rocm): weekly sync 240919 shard_count = 50, # b/140155647: Error just outside of tolerance xla_enable_strict_auto_jit = False, diff --git a/tensorflow/python/kernel_tests/nn_ops/BUILD b/tensorflow/python/kernel_tests/nn_ops/BUILD index dcc1602b83dcd6..f09519d9adb240 100644 --- a/tensorflow/python/kernel_tests/nn_ops/BUILD +++ b/tensorflow/python/kernel_tests/nn_ops/BUILD @@ -294,7 +294,7 @@ cuda_py_strict_test( shard_count = 4, tags = [ "no_mac_arm64", - "no_rocm", + "cuda-only", "optonly", # times out ], deps = [ @@ -405,7 +405,7 @@ cuda_py_strict_test( srcs = ["cudnn_d9m_test.py"], tags = [ "no_cuda_asan", # TODO(b/171509035): re-enable. - "no_rocm", #This is test is specific to CUDA and enables determinism through a CUDA specific env var. + "cuda-only", #This is test is specific to CUDA and enables determinism through a CUDA specific env var. ], deps = [ ":cudnn_deterministic_base", @@ -437,7 +437,7 @@ cuda_py_strict_test( size = "medium", # http://b/30603882 timeout = "long", srcs = ["depthwise_conv_op_d9m_test.py"], - tags = ["no_rocm"], + tags = ["cuda-only"], shard_count = 8, deps = [ ":depthwise_conv_op_base", diff --git a/tensorflow/python/kernel_tests/sparse_ops/BUILD b/tensorflow/python/kernel_tests/sparse_ops/BUILD index ec88a5c81cf44f..e5e172e3fd9a9f 100644 --- a/tensorflow/python/kernel_tests/sparse_ops/BUILD +++ b/tensorflow/python/kernel_tests/sparse_ops/BUILD @@ -108,7 +108,7 @@ cuda_py_strict_test( shard_count = 5, tags = [ "optonly", # b/77589990 - "no_rocm" + "cuda-only" ], deps = [ "//tensorflow/python/eager:def_function", diff --git a/tensorflow/python/ops/BUILD b/tensorflow/python/ops/BUILD index a6486af0837f67..e557f734d96d32 100644 --- a/tensorflow/python/ops/BUILD +++ b/tensorflow/python/ops/BUILD @@ -1017,7 +1017,7 @@ tf_py_strict_test( srcs = ["collective_ops_test.py"], python_version = "PY3", tags = [ - "no_rocm", + "cuda-only", ], deps = [ ":array_ops", @@ -1048,7 +1048,7 @@ tf_py_strict_test( python_version = "PY3", tags = [ "no_pip", - "no_rocm", + "cuda-only", "no_windows", "nomac", ], @@ -3630,7 +3630,7 @@ cuda_py_strict_test( python_version = "PY3", tags = [ "no_windows_gpu", - "no_rocm", #TODO(rocm): weekly sync 240919 + "cuda-only", #TODO(rocm): weekly sync 240919 ], deps = [ ":array_ops", @@ -3715,7 +3715,7 @@ cuda_py_strict_test( python_version = "PY3", shard_count = 4, tags = [ - "no_rocm", + "cuda-only", ], deps = [ ":nn_grad", @@ -3740,7 +3740,7 @@ cuda_py_strict_test( python_version = "PY3", shard_count = 24, tags = [ - "no_rocm", + "cuda-only", ], deps = [ ":array_ops", @@ -3891,7 +3891,7 @@ cuda_py_strict_test( python_version = "PY3", shard_count = 10, tags = [ - "no_rocm", + "cuda-only", "no_windows_gpu", ], deps = [ @@ -4702,7 +4702,6 @@ cuda_py_strict_test( python_version = "PY3", shard_count = 10, tags = [ - "no_rocm", "no_windows_gpu", ], deps = [ diff --git a/tensorflow/python/ops/numpy_ops/tests/BUILD b/tensorflow/python/ops/numpy_ops/tests/BUILD index 70c9b958895d71..f8c90ed4d907b4 100644 --- a/tensorflow/python/ops/numpy_ops/tests/BUILD +++ b/tensorflow/python/ops/numpy_ops/tests/BUILD @@ -206,7 +206,7 @@ py_strict_test( tags = [ "gpu", "no_pip", - "no_rocm", # Disabling due to excessive test length + "cuda-only", # Disabling due to excessive test length "requires-gpu-nvidia", ], deps = [ diff --git a/tensorflow/python/ops/parallel_for/BUILD b/tensorflow/python/ops/parallel_for/BUILD index b1af2b06fd8a73..c1212af898667f 100644 --- a/tensorflow/python/ops/parallel_for/BUILD +++ b/tensorflow/python/ops/parallel_for/BUILD @@ -131,7 +131,7 @@ cuda_py_strict_test( shard_count = 16, tags = [ "no_oss", - "no_rocm", + "cuda-only", ], deps = [ ":control_flow_ops", diff --git a/tensorflow/python/tools/BUILD b/tensorflow/python/tools/BUILD index 324fcb8389757f..be5a29d969d12d 100644 --- a/tensorflow/python/tools/BUILD +++ b/tensorflow/python/tools/BUILD @@ -499,7 +499,7 @@ genrule( cmd = ( "$(location :make_aot_compile_models) --out_dir $(@D)" ), - tags = ["no_rocm"], + tags = ["cuda-only"], tools = [":make_aot_compile_models"], ) @@ -514,7 +514,7 @@ saved_model_compile_aot( directory = "//tensorflow/python/tools:x_matmul_y_large", filegroups = [":aot_saved_models"], force_without_xla_support_flag = False, - tags = ["no_rocm"], + tags = ["cuda-only"], ) saved_model_compile_aot( @@ -524,7 +524,7 @@ saved_model_compile_aot( filegroups = [":aot_saved_models"], force_without_xla_support_flag = False, multithreading = True, - tags = ["no_rocm"], + tags = ["cuda-only"], ) saved_model_compile_aot( @@ -533,7 +533,7 @@ saved_model_compile_aot( directory = "//tensorflow/python/tools:x_matmul_y_small", filegroups = [":aot_saved_models"], force_without_xla_support_flag = False, - tags = ["no_rocm"], + tags = ["cuda-only"], ) saved_model_compile_aot( diff --git a/tensorflow/python/training/BUILD b/tensorflow/python/training/BUILD index 23460be0233c0d..01a60776a5079e 100644 --- a/tensorflow/python/training/BUILD +++ b/tensorflow/python/training/BUILD @@ -1241,7 +1241,7 @@ cuda_py_strict_test( size = "medium", srcs = ["basic_loops_test.py"], python_version = "PY3", - tags = ["no_rocm"], #TODO(ROCm) Re-enable after issue is fixed. + tags = ["cuda-only"], #TODO(ROCm) Re-enable after issue is fixed. deps = [ ":basic_loops", ":supervisor", diff --git a/tensorflow/tools/docs/BUILD b/tensorflow/tools/docs/BUILD index 45081a2a695599..2fefc752ae4940 100644 --- a/tensorflow/tools/docs/BUILD +++ b/tensorflow/tools/docs/BUILD @@ -55,7 +55,7 @@ py_strict_test( tags = [ "no_oss", # b/275546007 "no_pip", - "no_rocm", # No need to rerun this test for ROCm config. + "cuda-only", # No need to rerun this test for ROCm config. "no_windows", # numpy prints differently on windows. "noasan", ], @@ -103,7 +103,7 @@ py_strict_test( python_version = "PY3", tags = [ "no_pip", - "no_rocm", + "cuda-only", "no_windows", # numpy prints differently on windows. "noasan", "nomsan", diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index af513d9c436ce2..0d11a5baecff06 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -1604,8 +1604,8 @@ xla_test( name = "gpu_compiler_test", srcs = ["gpu_compiler_test.cc"], backend_tags = { - "gpu_a100": ["no_rocm"], - "gpu_v100": ["no_rocm"], + "gpu_a100": ["cuda-only"], + "gpu_v100": ["cuda-only"], }, backends = ["gpu"], data = ["gpu_compiler_test_autotune_db.textproto"], @@ -1657,7 +1657,6 @@ xla_test( name = "gpu_offloading_test", srcs = ["gpu_offloading_test.cc"], backends = ["gpu"], - tags = ["no_rocm"], #TODO(rocm): weekly sync deps = [ ":backend_configs_cc", "//xla:autotune_results_proto_cc", @@ -2535,7 +2534,7 @@ xla_test( name = "float_support_test", srcs = ["float_support_test.cc"], backend_tags = {"gpu": [ - "no_rocm" + "cuda-only" ]}, backends = [ "gpu_a100", @@ -2875,9 +2874,6 @@ xla_test( "gpu_a100", "gpu_amd_any", ], - tags = [ - "no_rocm", #TODO(rocm): TEMP, weekly sync - ], deps = [ "//xla:literal", "//xla:xla_proto_cc", diff --git a/third_party/xla/xla/service/gpu/fusions/triton/BUILD b/third_party/xla/xla/service/gpu/fusions/triton/BUILD index 4d063e53792b94..0646af1f3db93f 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/BUILD +++ b/third_party/xla/xla/service/gpu/fusions/triton/BUILD @@ -419,7 +419,7 @@ xla_test( ], shard_count = 10, tags = ["no_mac", - "no_rocm"], # TODO(rocm) 240729 + "cuda-only"], # TODO(rocm) 240729 deps = [ ":triton_support", ":triton_test_utils", @@ -474,7 +474,7 @@ xla_cc_test( # TODO(b/353912594): this test does not need to run on GPU, but it is broken on CPU in OSS. # Force it to run on GPU temporarily in order to get important OSS coverage. tags = ["gpu"] + - ["no_rocm"], # TODO(rocm) 240729 + ["cuda-only"], # TODO(rocm) 240729 deps = [ ":triton_fusion_emitter", ":triton_support", @@ -507,7 +507,7 @@ xla_test( "gpu_amd_any", ], tags = ["no_mac", - "no_rocm"], # TODO(rocm) 240729 + "cuda-only"], # TODO(rocm) 240729 deps = [ ":kernel_name_tracer", ":triton_fusion_emitter", diff --git a/third_party/xla/xla/tests/BUILD b/third_party/xla/xla/tests/BUILD index 2c498c166b601d..834eb4c42be0d9 100644 --- a/third_party/xla/xla/tests/BUILD +++ b/third_party/xla/xla/tests/BUILD @@ -1409,7 +1409,7 @@ xla_test( name = "convolution_cudnn_test", timeout = "long", srcs = ["convolution_cudnn_test.cc"], - tags = ["no_rocm"], # No int8 + tags = ["cuda-only"], # No int8 backends = [ "gpu_v100", "gpu_a100", @@ -2525,9 +2525,7 @@ xla_test( "cpu", "gpu", ], - tags = [ - "no_rocm", # doesn't work in CI with reduced resources - "test_xla_cpu_thunks"], + tags = ["test_xla_cpu_thunks"], deps = [ ":hlo_test_base", ":xla_internal_test_main", @@ -3141,7 +3139,6 @@ xla_test( shard_count = 3, tags = [ "optonly", - "no_rocm", #TODO(rocm): TEMP, sync 24-06-24 "test_xla_cpu_thunks", ], deps = [ From c1526607359e50818653502f6831e3aab2760689 Mon Sep 17 00:00:00 2001 From: Milica Makevic Date: Tue, 8 Oct 2024 11:18:31 +0000 Subject: [PATCH 476/483] Disable pjrt_c_api_gpu_test --- third_party/xla/xla/pjrt/c/BUILD | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/third_party/xla/xla/pjrt/c/BUILD b/third_party/xla/xla/pjrt/c/BUILD index 023de87343880c..7bf109fab09850 100644 --- a/third_party/xla/xla/pjrt/c/BUILD +++ b/third_party/xla/xla/pjrt/c/BUILD @@ -365,7 +365,9 @@ xla_test( name = "pjrt_c_api_gpu_test", srcs = ["pjrt_c_api_gpu_test.cc"], backends = ["gpu"], - tags = if_google([ + tags = [ + "cuda-only", #TODO(rocm): weekly sync 241001 + ] + if_google([ "config-cuda-only", ]), deps = [ From 5dbad84c23c11d622239f5f25832135208d69ac4 Mon Sep 17 00:00:00 2001 From: Milica Makevic Date: Wed, 9 Oct 2024 12:51:55 +0000 Subject: [PATCH 477/483] Disable failing unit tests --- third_party/xla/xla/service/gpu/autotuning/BUILD | 1 + third_party/xla/xla/service/gpu/transforms/BUILD | 1 + third_party/xla/xla/tests/BUILD | 8 +++++++- 3 files changed, 9 insertions(+), 1 deletion(-) diff --git a/third_party/xla/xla/service/gpu/autotuning/BUILD b/third_party/xla/xla/service/gpu/autotuning/BUILD index 7a51cf0c3c8e11..8e2c82a9c3a51c 100644 --- a/third_party/xla/xla/service/gpu/autotuning/BUILD +++ b/third_party/xla/xla/service/gpu/autotuning/BUILD @@ -514,6 +514,7 @@ xla_cc_test( ], tags = [ "gpu", + "cuda-only", #TODO(rocm): weekly sync 24-10-01 ], deps = [ ":autotuner_util", diff --git a/third_party/xla/xla/service/gpu/transforms/BUILD b/third_party/xla/xla/service/gpu/transforms/BUILD index 1deb2aa6dddb23..c1543f6674d5bd 100644 --- a/third_party/xla/xla/service/gpu/transforms/BUILD +++ b/third_party/xla/xla/service/gpu/transforms/BUILD @@ -1708,6 +1708,7 @@ cc_library( xla_test( name = "gemm_rewriter_test", srcs = ["gemm_rewriter_test.cc"], + tags = ["cuda-only",] #TODO(rocm): weekly sync 24-01-10 backends = ["gpu"], shard_count = 5, deps = [ diff --git a/third_party/xla/xla/tests/BUILD b/third_party/xla/xla/tests/BUILD index 834eb4c42be0d9..d9f824858721df 100644 --- a/third_party/xla/xla/tests/BUILD +++ b/third_party/xla/xla/tests/BUILD @@ -931,6 +931,7 @@ xla_test( tags = [ "optonly", "test_xla_cpu_thunks", + "cuda-only", #TODO(rocm): weekly sync 24-10-01 ], deps = [ ":client_library_test_base", @@ -990,6 +991,7 @@ xla_test( # TODO(b/151340488): Timed out on 2020-03-12. "nozapfhahn", "test_xla_cpu_thunks", + "cuda-only", #TODO(rocm): weekly sync 24-10-01 ], deps = [ ":client_library_test_base", @@ -1121,6 +1123,7 @@ xla_test( tags = [ "optonly", "test_xla_cpu_thunks", + "cuda-only", #TODO(rocm): weekly sync 24-10-01 ], deps = [ ":client_library_test_base", @@ -2269,7 +2272,10 @@ xla_test( xla_test( name = "convert_test", srcs = ["convert_test.cc"], - tags = ["test_xla_cpu_thunks"], + tags = [ + "test_xla_cpu_thunks", + "cuda-only", #TODO(rocm): weekly sync 24-10-01 + ], deps = [ ":client_library_test_base", ":test_macros_header", From 5c1aacac4c44b8b99b8a17cd4cf12b44c0e93b21 Mon Sep 17 00:00:00 2001 From: Milica Makevic Date: Wed, 9 Oct 2024 13:52:55 +0000 Subject: [PATCH 478/483] Skip failing subtests instead of whole tests --- .../xla/xla/service/gpu/transforms/BUILD | 1 - .../gpu/transforms/gemm_rewriter_test.cc | 36 +++++++++++++++++++ third_party/xla/xla/tests/BUILD | 2 -- .../xla/xla/tests/dot_operation_test.cc | 2 ++ 4 files changed, 38 insertions(+), 3 deletions(-) diff --git a/third_party/xla/xla/service/gpu/transforms/BUILD b/third_party/xla/xla/service/gpu/transforms/BUILD index c1543f6674d5bd..1deb2aa6dddb23 100644 --- a/third_party/xla/xla/service/gpu/transforms/BUILD +++ b/third_party/xla/xla/service/gpu/transforms/BUILD @@ -1708,7 +1708,6 @@ cc_library( xla_test( name = "gemm_rewriter_test", srcs = ["gemm_rewriter_test.cc"], - tags = ["cuda-only",] #TODO(rocm): weekly sync 24-01-10 backends = ["gpu"], shard_count = 5, deps = [ diff --git a/third_party/xla/xla/service/gpu/transforms/gemm_rewriter_test.cc b/third_party/xla/xla/service/gpu/transforms/gemm_rewriter_test.cc index 1df17a83a59f05..6acc0570f17d14 100644 --- a/third_party/xla/xla/service/gpu/transforms/gemm_rewriter_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/gemm_rewriter_test.cc @@ -6022,6 +6022,10 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDMatrixBiasPaddedF8) { } TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABScaledDF8) { + if(IsRocm()) { + // TODO(rocm): weekly sync 24-10-01 + GTEST_SKIP() << "Currently failing on ROCm!"; + } const char* hlo_text = R"( HloModule test @@ -6230,6 +6234,10 @@ TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABScaledF32DMatrixBiasF8) { } TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDF8) { + if(IsRocm()) { + // TODO(rocm): weekly sync 24-10-01 + GTEST_SKIP() << "Currently failing on ROCm!"; + } const char* hlo_text = R"( HloModule test @@ -6339,6 +6347,10 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABInvScaledDF8) { } TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDReluActivationF8) { + if(IsRocm()) { + // TODO(rocm): weekly sync 24-10-01 + GTEST_SKIP() << "Currently failing on ROCm!"; + } const char* hlo_text = R"( HloModule test ENTRY test { @@ -6406,6 +6418,10 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDReluActivationF8) { } TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDMatrixBiasWithDAmaxF8) { + if(IsRocm()) { + // TODO(rocm): weekly sync 24-10-01 + GTEST_SKIP() << "Currently failing on ROCm!"; + } const char* hlo_text = R"( HloModule test @@ -6486,6 +6502,10 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDMatrixBiasWithDAmaxF8) { } TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDVectorBiasF8) { + if(IsRocm()) { + // TODO(rocm): weekly sync 24-10-01 + GTEST_SKIP() << "Currently failing on ROCm!"; + } const char* hlo_text = R"( HloModule test @@ -7300,6 +7320,10 @@ TEST_P(ParameterizedFp8GemmRewriteTest, } TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDWithDAmaxF8) { + if(IsRocm()) { + // TODO(rocm): weekly sync 24-10-01 + GTEST_SKIP() << "Currently failing on ROCm!"; + } const char* hlo_text = R"( HloModule test @@ -7377,6 +7401,10 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDWithDAmaxF8) { TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDWithDAmaxF8WithF16Intermediates) { + if(IsRocm()) { + // TODO(rocm): weekly sync 24-10-01 + GTEST_SKIP() << "Currently failing on ROCm!"; + } // This is the same as ScaledABScaledDWithDAmaxF8, but uses F16 intermediate // values instead of F32 intermediate values. const char* hlo_text = R"( @@ -7459,6 +7487,10 @@ TEST_P(ParameterizedFp8GemmRewriteTest, TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDReluActivationWithDAmaxF8) { + if(IsRocm()) { + // TODO(rocm): weekly sync 24-10-01 + GTEST_SKIP() << "Currently failing on ROCm!"; + } const char* hlo_text = R"( HloModule test @@ -7734,6 +7766,10 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDF8TF32E5M2) { } TEST_P(ParameterizedFp8GemmRewriteTest, FnuzTypeF8) { + if(IsRocm()) { + // TODO(rocm): weekly sync 24-10-01 + GTEST_SKIP() << "Currently failing on ROCm!"; + } // Test that FNUZ FP8 gemms are not rewritten, as cuBLAS does not support them const char* hlo_text = R"( HloModule test diff --git a/third_party/xla/xla/tests/BUILD b/third_party/xla/xla/tests/BUILD index d9f824858721df..78175b0a13c722 100644 --- a/third_party/xla/xla/tests/BUILD +++ b/third_party/xla/xla/tests/BUILD @@ -931,7 +931,6 @@ xla_test( tags = [ "optonly", "test_xla_cpu_thunks", - "cuda-only", #TODO(rocm): weekly sync 24-10-01 ], deps = [ ":client_library_test_base", @@ -991,7 +990,6 @@ xla_test( # TODO(b/151340488): Timed out on 2020-03-12. "nozapfhahn", "test_xla_cpu_thunks", - "cuda-only", #TODO(rocm): weekly sync 24-10-01 ], deps = [ ":client_library_test_base", diff --git a/third_party/xla/xla/tests/dot_operation_test.cc b/third_party/xla/xla/tests/dot_operation_test.cc index 526a13b62d5db3..c4dba6bac696ec 100644 --- a/third_party/xla/xla/tests/dot_operation_test.cc +++ b/third_party/xla/xla/tests/dot_operation_test.cc @@ -314,6 +314,8 @@ class ParametricDotTest : public DotOperationTest, std::string_view name( ::testing::UnitTest::GetInstance()->current_test_info()->name()); if (name.find("TestF16/270x270x520_MajorToMinor") != std::string::npos) { + // TODO(rocm): weekly sync 24-10-01 + GTEST_SKIP() << "Currently failing on ROCm!"; execution_options_.mutable_debug_options()->set_xla_gpu_autotune_level( 0); DotTestParam param = GetParam(); From f79915189d42c069fbc993874587bae28579e339 Mon Sep 17 00:00:00 2001 From: Milica Makevic Date: Wed, 9 Oct 2024 16:02:42 +0000 Subject: [PATCH 479/483] Disable triton_fusion_emitter_device_legacy_test --- third_party/xla/xla/service/gpu/fusions/triton/BUILD | 1 + 1 file changed, 1 insertion(+) diff --git a/third_party/xla/xla/service/gpu/fusions/triton/BUILD b/third_party/xla/xla/service/gpu/fusions/triton/BUILD index 0646af1f3db93f..f3c29ec00e0c0d 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/BUILD +++ b/third_party/xla/xla/service/gpu/fusions/triton/BUILD @@ -205,6 +205,7 @@ xla_test( shard_count = 20, tags = [ "no_mac", + "cuda-only", #TODO(rocm): weekly-sync 24-10-01 ], deps = [ ":kernel_name_tracer", From d377ce14977da5f7b6224f855c11858a5962475f Mon Sep 17 00:00:00 2001 From: Harsha H S Date: Mon, 7 Oct 2024 01:03:54 -0700 Subject: [PATCH 480/483] PR #17900: [ROCm] Fix build break in executor and kernel test introduced in f896afd Imported from GitHub PR https://github.com/openxla/xla/pull/17900 Copybara import of the project: -- f9bd89ce7fa5fd297baaef4e5936847abc4d59f9 by Harsha HS : [ROCm] Fix build break in executor and kernel test introduced in f896afd Merging this change closes #17900 PiperOrigin-RevId: 683073855 --- third_party/xla/xla/stream_executor/rocm/BUILD | 5 ++++- .../stream_executor/rocm/rocm_executor_test.cc | 17 ++++++----------- .../stream_executor/rocm/rocm_kernel_test.cc | 1 - 3 files changed, 10 insertions(+), 13 deletions(-) diff --git a/third_party/xla/xla/stream_executor/rocm/BUILD b/third_party/xla/xla/stream_executor/rocm/BUILD index 467ff4fa6e73cc..f5dc8198ddd3a6 100644 --- a/third_party/xla/xla/stream_executor/rocm/BUILD +++ b/third_party/xla/xla/stream_executor/rocm/BUILD @@ -222,7 +222,10 @@ xla_test( srcs = ["rocm_executor_test.cc"], backends = ["gpu_amd_any"], tags = ["rocm-only"], - deps = [":rocm_executor"], + deps = [ + ":rocm_executor", + "@com_google_googletest//:gtest_main", + ], ) cc_library( diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_executor_test.cc b/third_party/xla/xla/stream_executor/rocm/rocm_executor_test.cc index 0716b5c3d0ee17..1ed4fed1b0e462 100644 --- a/third_party/xla/xla/stream_executor/rocm/rocm_executor_test.cc +++ b/third_party/xla/xla/stream_executor/rocm/rocm_executor_test.cc @@ -17,22 +17,17 @@ limitations under the License. #include #include -#include "tsl/platform/status_matchers.h" -#include "tsl/platform/test.h" +#include "xla/stream_executor/device_description.h" namespace stream_executor::gpu { namespace { -using testing::Field; using testing::Ge; using testing::IsEmpty; using testing::Not; -using testing::VariantWith; TEST(RocmExecutorTest, CreateDeviceDescription) { - TF_ASSERT_OK(GpuDriver::Init()); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, - CudaExecutor::CreateDeviceDescription(0)); + RocmExecutor::CreateDeviceDescription(0)); constexpr SemanticVersion kNullVersion{0, 0, 0}; EXPECT_NE(result->runtime_version(), kNullVersion); @@ -44,10 +39,10 @@ TEST(RocmExecutorTest, CreateDeviceDescription) { EXPECT_THAT(result->model_str(), Not(IsEmpty())); EXPECT_THAT(result->device_vendor(), "Advanced Micro Devices, Inc"); - EXPECT_THAT(result->gpu_compute_capability(), - VariantWith( - Field("gcn_arch_name", &RocmComputeCapability::gcn_arch_name, - Not(IsEmpty())))); + EXPECT_THAT( + std::get_if(&result->gpu_compute_capability()) + ->gcn_arch_name(), + Not(IsEmpty())); } } // namespace diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_kernel_test.cc b/third_party/xla/xla/stream_executor/rocm/rocm_kernel_test.cc index cfff348f9b5b11..f3cc0b6ea781de 100644 --- a/third_party/xla/xla/stream_executor/rocm/rocm_kernel_test.cc +++ b/third_party/xla/xla/stream_executor/rocm/rocm_kernel_test.cc @@ -16,7 +16,6 @@ limitations under the License. #include "xla/stream_executor/rocm/rocm_kernel.h" #include -#include "rocm/include/hip/hip_runtime.h" #include "xla/stream_executor/gpu/gpu_executor.h" #include "xla/stream_executor/gpu/gpu_test_kernels.h" #include "xla/stream_executor/launch_dim.h" From 7c9276b529a0f544c7a7c6e83bf9c02b3dbb54ae Mon Sep 17 00:00:00 2001 From: Henning Becker Date: Tue, 1 Oct 2024 08:58:45 -0700 Subject: [PATCH 481/483] Remove more unnecessary forward declarations from stream_executor Many of them weren't even needed anymore. For some others I included the header with the full type declaration. Previously this was not possible in many cases due to cyclic dependencies which have been removed. PiperOrigin-RevId: 681040224 --- third_party/xla/xla/stream_executor/BUILD | 4 ++++ third_party/xla/xla/stream_executor/blas.h | 11 ++--------- third_party/xla/xla/stream_executor/cuda/BUILD | 1 + .../xla/stream_executor/cuda/cuda_blas_lt.h | 5 +---- .../xla/xla/stream_executor/cuda/cuda_fft.h | 18 ++++++++---------- .../xla/xla/stream_executor/device_memory.h | 2 -- third_party/xla/xla/stream_executor/dnn.h | 6 ++---- .../xla/xla/stream_executor/gpu/gpu_stream.h | 2 -- .../xla/xla/stream_executor/gpu/gpu_timer.h | 3 +-- third_party/xla/xla/stream_executor/kernel.h | 2 -- third_party/xla/xla/stream_executor/rocm/BUILD | 2 ++ .../xla/xla/stream_executor/rocm/hip_blas_lt.h | 5 +---- .../xla/xla/stream_executor/rocm/rocm_blas.h | 3 +-- .../xla/xla/stream_executor/rocm/rocm_fft.h | 5 +---- 14 files changed, 24 insertions(+), 45 deletions(-) diff --git a/third_party/xla/xla/stream_executor/BUILD b/third_party/xla/xla/stream_executor/BUILD index a450d0e7b3c42c..1c7f1fdd42b369 100644 --- a/third_party/xla/xla/stream_executor/BUILD +++ b/third_party/xla/xla/stream_executor/BUILD @@ -411,6 +411,8 @@ cc_library( ":data_type", ":device_memory", ":numeric_options", + ":scratch_allocator", + ":stream", "//xla/stream_executor/platform", "//xla/tsl/protobuf:dnn_proto_cc", "@com_google_absl//absl/log", @@ -430,6 +432,8 @@ cc_library( ":device_description_proto_cc", ":device_memory", ":numeric_options", + ":scratch_allocator", + ":stream", "//xla/stream_executor/platform", "//xla/tsl/lib/strings:proto_serialization", "//xla/tsl/protobuf:dnn_proto_cc", diff --git a/third_party/xla/xla/stream_executor/blas.h b/third_party/xla/xla/stream_executor/blas.h index 73814f0e467a3e..3153628b43fc9f 100644 --- a/third_party/xla/xla/stream_executor/blas.h +++ b/third_party/xla/xla/stream_executor/blas.h @@ -35,6 +35,8 @@ limitations under the License. #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/numeric_options.h" #include "xla/stream_executor/platform/port.h" +#include "xla/stream_executor/scratch_allocator.h" +#include "xla/stream_executor/stream.h" #include "xla/tsl/protobuf/dnn.pb.h" #include "tsl/platform/errors.h" @@ -50,15 +52,6 @@ struct MatrixDescriptor; struct OutputMatrixDescriptor; } // namespace gpu -class Stream; -class ScratchAllocator; - -template -class DeviceMemory; - -template -class HostOrDeviceScalar; - template using DeviceMemorySlice = absl::Span *const>; diff --git a/third_party/xla/xla/stream_executor/cuda/BUILD b/third_party/xla/xla/stream_executor/cuda/BUILD index 68fdd4a79ad45b..5d8b163e80f84a 100644 --- a/third_party/xla/xla/stream_executor/cuda/BUILD +++ b/third_party/xla/xla/stream_executor/cuda/BUILD @@ -362,6 +362,7 @@ cuda_only_cc_library( "//xla/stream_executor:blas", "//xla/stream_executor:scratch_allocator", "//xla/stream_executor/gpu:gpu_blas_lt", + "//xla/stream_executor/gpu:gpu_executor_header", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_blas_lt.h b/third_party/xla/xla/stream_executor/cuda/cuda_blas_lt.h index 3d61c816024af9..e25b9e32797ece 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_blas_lt.h +++ b/third_party/xla/xla/stream_executor/cuda/cuda_blas_lt.h @@ -33,14 +33,11 @@ limitations under the License. #include "xla/stream_executor/blas.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/gpu/gpu_blas_lt.h" +#include "xla/stream_executor/gpu/gpu_executor.h" #include "xla/stream_executor/scratch_allocator.h" #include "xla/types.h" namespace stream_executor { -namespace gpu { -class GpuExecutor; -} // namespace gpu - namespace cuda { class BlasLt : public gpu::BlasLt { diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_fft.h b/third_party/xla/xla/stream_executor/cuda/cuda_fft.h index 56fc78fb360219..10d2e3fdf83a74 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_fft.h +++ b/third_party/xla/xla/stream_executor/cuda/cuda_fft.h @@ -27,17 +27,15 @@ limitations under the License. #include "absl/status/status.h" #include "third_party/gpus/cuda/include/cufft.h" #include "xla/stream_executor/fft.h" +#include "xla/stream_executor/gpu/gpu_executor.h" #include "xla/stream_executor/platform/port.h" #include "xla/stream_executor/scratch_allocator.h" +#include "xla/stream_executor/stream.h" namespace stream_executor { -class Stream; - namespace gpu { -class GpuExecutor; - // CUDAFftPlan uses deferred initialization. Only a single call of // Initialize() is allowed to properly create cufft plan and set member // variable is_initialized_ to true. Newly added interface that uses member @@ -116,17 +114,17 @@ class CUDAFft : public fft::FftSupport { // This is for complex to complex FFT, when the direction is required. template - bool DoFftWithDirectionInternal(Stream *stream, fft::Plan *plan, + bool DoFftWithDirectionInternal(Stream* stream, fft::Plan* plan, FuncT cufft_exec, - const DeviceMemory &input, - DeviceMemory *output); + const DeviceMemory& input, + DeviceMemory* output); // This is for complex to real or real to complex FFT, when the direction // is implied. template - bool DoFftInternal(Stream *stream, fft::Plan *plan, FuncT cufft_exec, - const DeviceMemory &input, - DeviceMemory *output); + bool DoFftInternal(Stream* stream, fft::Plan* plan, FuncT cufft_exec, + const DeviceMemory& input, + DeviceMemory* output); CUDAFft(const CUDAFft&) = delete; void operator=(const CUDAFft&) = delete; diff --git a/third_party/xla/xla/stream_executor/device_memory.h b/third_party/xla/xla/stream_executor/device_memory.h index 5334e79f4565c6..43b645b4c345df 100644 --- a/third_party/xla/xla/stream_executor/device_memory.h +++ b/third_party/xla/xla/stream_executor/device_memory.h @@ -35,8 +35,6 @@ limitations under the License. namespace stream_executor { -class DeviceMemoryAllocator; - // void*-analogous device memory allocation. For the typed variation, see // DeviceMemory. // diff --git a/third_party/xla/xla/stream_executor/dnn.h b/third_party/xla/xla/stream_executor/dnn.h index e6d3e67c68c87d..f6eea6ddd47e8e 100644 --- a/third_party/xla/xla/stream_executor/dnn.h +++ b/third_party/xla/xla/stream_executor/dnn.h @@ -44,6 +44,8 @@ limitations under the License. #include "xla/stream_executor/device_description.pb.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/numeric_options.h" +#include "xla/stream_executor/scratch_allocator.h" +#include "xla/stream_executor/stream.h" #include "xla/tsl/protobuf/dnn.pb.h" #include "tsl/platform/logging.h" @@ -53,10 +55,6 @@ struct half; namespace stream_executor { -class HostBuffer; -class Stream; -class ScratchAllocator; - namespace dnn { // Specifies an index to use when accessing specific spatial dimensions. diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_stream.h b/third_party/xla/xla/stream_executor/gpu/gpu_stream.h index 249fbf78877a4e..8c82e7f8b9f755 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_stream.h +++ b/third_party/xla/xla/stream_executor/gpu/gpu_stream.h @@ -43,8 +43,6 @@ limitations under the License. namespace stream_executor { namespace gpu { -class GpuExecutor; - // Wraps a GpuStreamHandle in order to satisfy the platform-independent // StreamInterface. // diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_timer.h b/third_party/xla/xla/stream_executor/gpu/gpu_timer.h index 14002f61bc7478..ef0508589f8b66 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_timer.h +++ b/third_party/xla/xla/stream_executor/gpu/gpu_timer.h @@ -25,6 +25,7 @@ limitations under the License. #include "xla/stream_executor/gpu/context.h" #include "xla/stream_executor/gpu/gpu_event.h" #include "xla/stream_executor/gpu/gpu_semaphore.h" +#include "xla/stream_executor/gpu/gpu_stream.h" namespace xla { namespace gpu { @@ -35,8 +36,6 @@ class DeterminismTest; namespace stream_executor { namespace gpu { -class GpuStream; - // When a timer is created it launches a delay kernel into the given stream and // queues a start event immediately afterwards. This delay kernel blocks // execution on the stream until GetElapsedDuration() is called, at which point diff --git a/third_party/xla/xla/stream_executor/kernel.h b/third_party/xla/xla/stream_executor/kernel.h index f03b373b5a4d76..6076717d430598 100644 --- a/third_party/xla/xla/stream_executor/kernel.h +++ b/third_party/xla/xla/stream_executor/kernel.h @@ -94,8 +94,6 @@ limitations under the License. namespace stream_executor { -class Kernel; - //===----------------------------------------------------------------------===// // Kernel metadata //===----------------------------------------------------------------------===// diff --git a/third_party/xla/xla/stream_executor/rocm/BUILD b/third_party/xla/xla/stream_executor/rocm/BUILD index f5dc8198ddd3a6..3060bd91700e85 100644 --- a/third_party/xla/xla/stream_executor/rocm/BUILD +++ b/third_party/xla/xla/stream_executor/rocm/BUILD @@ -695,6 +695,7 @@ cc_library( "//xla/stream_executor:event_based_timer", "//xla/stream_executor:host_or_device_scalar", "//xla/stream_executor/gpu:gpu_blas_lt", + "//xla/stream_executor/gpu:gpu_executor_header", "//xla/stream_executor/gpu:gpu_helpers_header", "//xla/stream_executor/gpu:gpu_stream_header", "//xla/stream_executor/gpu:scoped_activate_context", @@ -732,6 +733,7 @@ cc_library( "//xla/stream_executor:blas", "//xla/stream_executor:host_or_device_scalar", "//xla/stream_executor/gpu:gpu_blas_lt", + "//xla/stream_executor/gpu:gpu_executor_header", "//xla/stream_executor/platform", "//xla/stream_executor/platform:dso_loader", "@com_google_absl//absl/status", diff --git a/third_party/xla/xla/stream_executor/rocm/hip_blas_lt.h b/third_party/xla/xla/stream_executor/rocm/hip_blas_lt.h index c839f51866ac0b..eee64fd5446f64 100644 --- a/third_party/xla/xla/stream_executor/rocm/hip_blas_lt.h +++ b/third_party/xla/xla/stream_executor/rocm/hip_blas_lt.h @@ -18,6 +18,7 @@ limitations under the License. #include "xla/stream_executor/blas.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/gpu/gpu_blas_lt.h" +#include "xla/stream_executor/gpu/gpu_executor.h" #include "xla/stream_executor/host_or_device_scalar.h" #include "xla/types.h" @@ -28,10 +29,6 @@ limitations under the License. namespace stream_executor { -namespace gpu { -class GpuExecutor; -} // namespace gpu - namespace rocm { class BlasLt : public gpu::BlasLt { diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_blas.h b/third_party/xla/xla/stream_executor/rocm/rocm_blas.h index 6199d0e551a815..329ad155fd1a4f 100644 --- a/third_party/xla/xla/stream_executor/rocm/rocm_blas.h +++ b/third_party/xla/xla/stream_executor/rocm/rocm_blas.h @@ -33,6 +33,7 @@ limitations under the License. #endif #include "xla/stream_executor/blas.h" #include "xla/stream_executor/gpu/gpu_blas_lt.h" +#include "xla/stream_executor/gpu/gpu_executor.h" #include "xla/stream_executor/platform/port.h" #include "xla/stream_executor/plugin_registry.h" #if TF_HIPBLASLT @@ -75,8 +76,6 @@ using RocBlasType_t = rocblas_float_complex, std::complex, rocblas_double_complex>::type; -class GpuExecutor; - // BLAS plugin for ROCM platform via rocBLAS library. // // This satisfies the platform-agnostic BlasSupport interface. diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_fft.h b/third_party/xla/xla/stream_executor/rocm/rocm_fft.h index dad6f3e0864f19..fa51950f824a6b 100644 --- a/third_party/xla/xla/stream_executor/rocm/rocm_fft.h +++ b/third_party/xla/xla/stream_executor/rocm/rocm_fft.h @@ -33,6 +33,7 @@ limitations under the License. #endif #include "xla/stream_executor/fft.h" +#include "xla/stream_executor/gpu/gpu_executor.h" #include "xla/stream_executor/platform/port.h" #include "xla/stream_executor/plugin_registry.h" #include "xla/stream_executor/scratch_allocator.h" @@ -40,12 +41,8 @@ limitations under the License. namespace stream_executor { -class Stream; - namespace gpu { -class GpuExecutor; - // ROCMFftPlan uses deferred initialization. Only a single call of // Initialize() is allowed to properly create hipfft plan and set member // variable is_initialized_ to true. Newly added interface that uses member From f5f981d433e600facd628e5568dc7412f12e9481 Mon Sep 17 00:00:00 2001 From: Milica Makevic Date: Thu, 10 Oct 2024 09:59:51 +0000 Subject: [PATCH 482/483] Remove unnecessary hip_runtime.h --- third_party/xla/xla/stream_executor/rocm/rocm_kernel.h | 1 - 1 file changed, 1 deletion(-) diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_kernel.h b/third_party/xla/xla/stream_executor/rocm/rocm_kernel.h index 26d20b667e7609..985b6f2f85a1a3 100644 --- a/third_party/xla/xla/stream_executor/rocm/rocm_kernel.h +++ b/third_party/xla/xla/stream_executor/rocm/rocm_kernel.h @@ -26,7 +26,6 @@ limitations under the License. #include #include "absl/status/statusor.h" -#include "rocm/include/hip/hip_runtime.h" #include "xla/stream_executor/gpu/gpu_executor.h" #include "xla/stream_executor/gpu/gpu_kernel.h" #include "xla/stream_executor/launch_dim.h" From b540099535f9b04a9d171507a8da9c9ab9d3fdc5 Mon Sep 17 00:00:00 2001 From: Milica Makevic Date: Thu, 10 Oct 2024 11:13:11 +0000 Subject: [PATCH 483/483] Enable gemm_rewriter_test (https://github.com/openxla/xla/pull/18062) --- .../gpu/transforms/gemm_rewriter_test.cc | 60 +++---------------- 1 file changed, 8 insertions(+), 52 deletions(-) diff --git a/third_party/xla/xla/service/gpu/transforms/gemm_rewriter_test.cc b/third_party/xla/xla/service/gpu/transforms/gemm_rewriter_test.cc index 6acc0570f17d14..1128662094d37a 100644 --- a/third_party/xla/xla/service/gpu/transforms/gemm_rewriter_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/gemm_rewriter_test.cc @@ -6022,10 +6022,6 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDMatrixBiasPaddedF8) { } TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABScaledDF8) { - if(IsRocm()) { - // TODO(rocm): weekly sync 24-10-01 - GTEST_SKIP() << "Currently failing on ROCm!"; - } const char* hlo_text = R"( HloModule test @@ -6060,9 +6056,8 @@ TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABScaledDF8) { ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2) ; CHECK-NEXT: [[P2_INV:%[^ ]+]] = f32[] divide([[C0]], [[P2]]) ; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1) -; CHECK-NEXT: [[C2:%[^ ]+]] = f32[] constant(1) ; CHECK-PTX-NEXT: [[OUT:%[^ ]+]] = (<>[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2_INV]], [[C1]], [[C2]]), -; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2_INV]], [[C1]], [[C2]]), +; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2_INV]], [[C1]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -6234,10 +6229,6 @@ TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABScaledF32DMatrixBiasF8) { } TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDF8) { - if(IsRocm()) { - // TODO(rocm): weekly sync 24-10-01 - GTEST_SKIP() << "Currently failing on ROCm!"; - } const char* hlo_text = R"( HloModule test @@ -6282,8 +6273,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDF8) { ; CHECK-PTX-NEXT: [[P4:%[^ ]+]] = f32[] parameter(4) ; CHECK-PTX-NEXT: [[P4_INV:%[^ ]+]] = f32[] divide([[C2]], [[P4]]) ; CHECK-PTX-NEXT: [[OUT:%[^ ]+]] = (<>[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[P4_INV]]), -; CHECK-GCN-NEXT: [[C1:%[^ ]+]] = f32[] constant(1) -; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]]), +; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -6347,10 +6337,6 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABInvScaledDF8) { } TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDReluActivationF8) { - if(IsRocm()) { - // TODO(rocm): weekly sync 24-10-01 - GTEST_SKIP() << "Currently failing on ROCm!"; - } const char* hlo_text = R"( HloModule test ENTRY test { @@ -6396,8 +6382,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDReluActivationF8) { ; CHECK-PTX-NEXT: [[P4:%[^ ]+]] = f32[] parameter(4) ; CHECK-PTX-NEXT: [[P4_INV:%[^ ]+]] = f32[] divide([[C2]], [[P4]]) ; CHECK-PTX-NEXT: [[OUT:%[^ ]+]] = (<>[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[P4_INV]]), -; CHECK-CGN-NEXT: [[C1:%[^ ]+]] = f32[] constant(1) -; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]]), +; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -6418,10 +6403,6 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDReluActivationF8) { } TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDMatrixBiasWithDAmaxF8) { - if(IsRocm()) { - // TODO(rocm): weekly sync 24-10-01 - GTEST_SKIP() << "Currently failing on ROCm!"; - } const char* hlo_text = R"( HloModule test @@ -6481,7 +6462,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDMatrixBiasWithDAmaxF8) { ; CHECK-PTX: [[P4:%[^ ]+]] = f16[] parameter(5) ; CHECK-PTX: [[OUT:%[^ ]+]] = (<>[16,16]{1,0}, f32[], s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[C0]], [[DUMMY0:%[^ ]+]], [[DUMMY1:%[^ ]+]], /*index=5*/[[DUMMY2:%[^ ]+]]), ; CHECK-NOT: output_to_operand_aliasing -; CHECK-GCN: [[OUT:%[^ ]+]] = (f16[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[C0]], [[DUMMY0:%[^ ]+]], [[DUMMY1:%[^ ]+]], /*index=5*/[[DUMMY2:%[^ ]+]]), +; CHECK-GCN: [[OUT:%[^ ]+]] = (f16[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[C0]], [[DUMMY0:%[^ ]+]], [[DUMMY1:%[^ ]+]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -6502,10 +6483,6 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDMatrixBiasWithDAmaxF8) { } TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDVectorBiasF8) { - if(IsRocm()) { - // TODO(rocm): weekly sync 24-10-01 - GTEST_SKIP() << "Currently failing on ROCm!"; - } const char* hlo_text = R"( HloModule test @@ -6558,8 +6535,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDVectorBiasF8) { ; CHECK-PTX-NEXT: [[DV:%[^ ]+]] = f16[] divide([[C2]], [[P4]]) ; CHECK-PTX-NEXT: [[CV2:%[^ ]+]] = f32[] convert([[DV]]) ; CHECK-PTX: [[OUT:%[^ ]+]] = (<>[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[CV]], [[CV1]], [[VB]], /*index=5*/[[CV2]]), -; CHECK-GCN: [[C:%[^ ]+]] = f32[] constant(1) -; CHECK-GCN: [[OUT:%[^ ]+]] = (f16[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[CV]], [[CV1]], [[C]], /*index=5*/[[C]], [[VB]]), +; CHECK-GCN: [[OUT:%[^ ]+]] = (f16[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[CV]], [[CV1]], [[VB]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -7320,10 +7296,6 @@ TEST_P(ParameterizedFp8GemmRewriteTest, } TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDWithDAmaxF8) { - if(IsRocm()) { - // TODO(rocm): weekly sync 24-10-01 - GTEST_SKIP() << "Currently failing on ROCm!"; - } const char* hlo_text = R"( HloModule test @@ -7378,8 +7350,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDWithDAmaxF8) { ; CHECK-PTX-NEXT: [[P4:%[^ ]+]] = f32[] parameter(4) ; CHECK-PTX-NEXT: [[P4_INV:%[^ ]+]] = f32[] divide([[C2]], [[P4]]) ; CHECK-PTX-NEXT: [[OUT:%[^ ]+]] = (<>[16,16]{1,0}, f32[], s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[P4_INV]]), -; CHECK-GCN-NEXT: [[C1:%[^ ]+]] = f32[] constant(1) -; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]]), +; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -7401,10 +7372,6 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDWithDAmaxF8) { TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDWithDAmaxF8WithF16Intermediates) { - if(IsRocm()) { - // TODO(rocm): weekly sync 24-10-01 - GTEST_SKIP() << "Currently failing on ROCm!"; - } // This is the same as ScaledABScaledDWithDAmaxF8, but uses F16 intermediate // values instead of F32 intermediate values. const char* hlo_text = R"( @@ -7464,8 +7431,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ; CHECK-PTX-NEXT: [[P4_INV:%[^ ]+]] = f16[] divide([[C2]], [[P4]]) ; CHECK-PTX-NEXT: [[P4_INV_CONVERT:%[^ ]+]] = f32[] convert([[P4_INV]]) ; CHECK-PTX-NEXT: [[OUT:%[^ ]+]] = (<>[16,16]{1,0}, f32[], s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2_CONVERT]], [[P3_CONVERT]], [[P4_INV_CONVERT]]), -; CHECK-CGN-NEXT: [[C1:%[^ ]+]] = f32[] constant(1) -; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f16[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2_CONVERT]], [[P3_CONVERT]], [[C1]]), +; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f16[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2_CONVERT]], [[P3_CONVERT]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -7487,10 +7453,6 @@ TEST_P(ParameterizedFp8GemmRewriteTest, TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDReluActivationWithDAmaxF8) { - if(IsRocm()) { - // TODO(rocm): weekly sync 24-10-01 - GTEST_SKIP() << "Currently failing on ROCm!"; - } const char* hlo_text = R"( HloModule test @@ -7547,8 +7509,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ; CHECK-PTX-NEXT: [[P4:%[^ ]+]] = f32[] parameter(4) ; CHECK-PTX-NEXT: [[P4_INV:%[^ ]+]] = f32[] divide([[C2]], [[P4]]) ; CHECK-PTX-NEXT: [[OUT:%[^ ]+]] = (<>[16,16]{1,0}, f32[], s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[P4_INV]]), -; CHECK-CGN-NEXT: [[C1:%[^ ]+]] = f32[] constant(1) -; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]]), +; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -7766,10 +7727,6 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDF8TF32E5M2) { } TEST_P(ParameterizedFp8GemmRewriteTest, FnuzTypeF8) { - if(IsRocm()) { - // TODO(rocm): weekly sync 24-10-01 - GTEST_SKIP() << "Currently failing on ROCm!"; - } // Test that FNUZ FP8 gemms are not rewritten, as cuBLAS does not support them const char* hlo_text = R"( HloModule test @@ -7822,7 +7779,6 @@ TEST_P(ParameterizedFp8GemmRewriteTest, FnuzTypeF8) { ; CHECK-GCN-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[16,32]{1,0} transpose([[P1]]) ; CHECK-GCN-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2) ; CHECK-GCN-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3) -; CHECK-GCN-NEXT: [[C1:%[^ ]+]] = f32[] constant(1) ; CHECK-PTX: custom_call_target="<>", ; CHECK-GCN: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={